From 3e796ac1625b68e6c439b3214e15aef0a4a58344 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Wed, 23 Oct 2024 15:16:58 +0800 Subject: [PATCH 01/39] =?UTF-8?q?feat=20fix=20=20=20=20=E5=8F=96=E6=B6=88?= =?UTF-8?q?=E5=BD=92=E4=B8=80=E5=8C=96=E9=A2=84=E5=A4=84=E7=90=86=EF=BC=8C?= =?UTF-8?q?=E5=9B=A0=E4=B8=BA=E8=AE=AD=E7=BB=83=E6=97=B6=E6=B2=A1=E6=9C=89?= =?UTF-8?q?=E5=81=9A=E8=AF=A5=E6=93=8D=E4=BD=9C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/service/design_fast/utils/design_ensemble.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/service/design_fast/utils/design_ensemble.py b/app/service/design_fast/utils/design_ensemble.py index f4f6a34..267ea00 100644 --- a/app/service/design_fast/utils/design_ensemble.py +++ b/app/service/design_fast/utils/design_ensemble.py @@ -85,7 +85,7 @@ def seg_preprocess(img_path): if ori_shape != (img_scale_w, img_scale_h): # mmcv.imresize(img, img_scale_h, img_scale_w) # 老代码 引以为戒!哈哈哈~ h和w写反了 img = cv2.resize(img, (img_scale_h, img_scale_w)) - img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True) + # img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True) preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0) return preprocessed_img, ori_shape From db30823cf319d5296ad8971d30ce693b6ecc99b0 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Fri, 25 Oct 2024 10:19:29 +0800 Subject: [PATCH 02/39] =?UTF-8?q?feat=20=20=20=E6=96=B0=E5=A2=9Esketch=20?= =?UTF-8?q?=E8=83=8C=E5=90=8E=E8=A7=86=E8=A7=92=E5=9B=BE=20fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/service/design_fast/design_generate.py | 1 + app/service/design_fast/item.py | 4 +- app/service/design_fast/pipeline/__init__.py | 2 + .../design_fast/pipeline/back_perspective.py | 79 +++++++++++++++++++ app/service/design_fast/utils/organize.py | 5 +- app/service/utils/new_oss_client.py | 2 +- 6 files changed, 89 insertions(+), 4 deletions(-) create mode 100644 app/service/design_fast/pipeline/back_perspective.py diff --git a/app/service/design_fast/design_generate.py b/app/service/design_fast/design_generate.py index ac1f79c..244e09c 100644 --- a/app/service/design_fast/design_generate.py +++ b/app/service/design_fast/design_generate.py @@ -81,6 +81,7 @@ def design_generate(request_data): 'mask_url': lay['mask_url'], 'image_url': lay['image_url'] if 'image_url' in lay.keys() else None, 'pattern_image_url': lay['pattern_image_url'] if 'pattern_image_url' in lay.keys() else None, + 'back_perspective_url': lay['back_perspective_url'] if 'back_perspective_url' in lay.keys() else None, }) items_response['synthesis_url'] = synthesis(layers, new_size, basic) else: diff --git a/app/service/design_fast/item.py b/app/service/design_fast/item.py index e10320d..c1bd467 100644 --- a/app/service/design_fast/item.py +++ b/app/service/design_fast/item.py @@ -1,4 +1,4 @@ -from app.service.design_fast.pipeline import LoadImage, KeyPoint, Segmentation, Color, PrintPainting, Scaling, Split, LoadBodyImage, ContourDetection +from app.service.design_fast.pipeline import LoadImage, KeyPoint, Segmentation, Color, PrintPainting, Scaling, Split, LoadBodyImage, ContourDetection, BackPerspective class BaseItem: @@ -16,6 +16,7 @@ class TopItem(BaseItem): LoadImage(minio_client), KeyPoint(), Segmentation(minio_client), + BackPerspective(minio_client), Color(minio_client), PrintPainting(minio_client), Scaling(), @@ -36,6 +37,7 @@ class BottomItem(BaseItem): KeyPoint(), ContourDetection(), # Segmentation(), + BackPerspective(minio_client), Color(minio_client), PrintPainting(minio_client), Scaling(), diff --git a/app/service/design_fast/pipeline/__init__.py b/app/service/design_fast/pipeline/__init__.py index ec55933..f265bbe 100644 --- a/app/service/design_fast/pipeline/__init__.py +++ b/app/service/design_fast/pipeline/__init__.py @@ -1,3 +1,4 @@ +from .back_perspective import BackPerspective from .color import Color from .contour_detection import ContourDetection from .keypoint import KeyPoint @@ -13,6 +14,7 @@ __all__ = [ 'KeyPoint', 'ContourDetection', 'Segmentation', + 'BackPerspective', 'Color', 'PrintPainting', 'Scaling', diff --git a/app/service/design_fast/pipeline/back_perspective.py b/app/service/design_fast/pipeline/back_perspective.py new file mode 100644 index 0000000..5ddd37c --- /dev/null +++ b/app/service/design_fast/pipeline/back_perspective.py @@ -0,0 +1,79 @@ +import cv2 +import numpy as np + +from app.service.design_fast.utils.design_ensemble import get_seg_result +from app.service.utils.new_oss_client import oss_upload_image + + +class BackPerspective: + def __init__(self, minio_client): + self.minio_client = minio_client + + def __call__(self, result): + + # 如果sketch为系统图 查看是否有对应的 背后视角图 + if result['path'].split('/')[0] == 'aida-sys-image': + file_path = result['path'].replace("images", 'images_back', 1) + if self.is_file_exists(bucket_name='aida-sys-image', file_name=file_path[file_path.find('/') + 1:]): + result['back_perspective_url'] = file_path + return result + else: + seg_result = get_seg_result("1", result['image'])[0] + elif result['name'] in ['blouse', 'outwear', 'dress', 'tops']: + seg_result = result['seg_result'] + else: + seg_result = get_seg_result("1", result['image'])[0] + + m = self.thicken_contours_and_display(seg_result, thickness=10, color=(0, 0, 0)) + back_sketch = result['image'].copy() + back_sketch[m > 100] = 255 + # 上传背后视角图 + _, img_encoded = cv2.imencode(".jpg", back_sketch) + + resp = oss_upload_image(self.minio_client, bucket='test', object_name=result['path'], image_bytes=img_encoded.tobytes()) + result['back_perspective_url'] = f"{resp.bucket_name}/{resp.object_name}" + return result + + def thicken_contours_and_display(self, mask, thickness=10, color=(0, 0, 0)): + mask = mask.astype(np.uint8) * 255 + # 查找轮廓 + contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + + # 创建一个彩色副本用于绘制轮廓 + mask_color = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR) + + def thicken_contour_inward(contour, thick): + # 创建一个空白的黑色图像与原始掩码大小相同 + blank = np.zeros_like(mask) + # 在空白图像上绘制白色的轮廓 + cv2.drawContours(blank, [contour], -1, 255, thickness=thick) + # 找到轮廓的中心(可以用重心等方法近似) + M = cv2.moments(contour) + cx = int(M['m10'] / M['m00']) + cy = int(M['m01'] / M['m00']) + # 进行距离变换,离中心越近的值越小 + dist_transform = cv2.distanceTransform(255 - blank, cv2.DIST_L2, 5) + # 根据距离变换的值来决定是否保留像素,离中心近的像素更容易被保留 + result = np.zeros_like(mask) + for i in range(dist_transform.shape[0]): + for j in range(dist_transform.shape[1]): + if dist_transform[i, j] < thick: + result[i, j] = 255 + return result + + for contour in contours: + thickened_contour = thicken_contour_inward(contour, thickness) + mask_color[thickened_contour > 0] = color + + _, binary_result = cv2.threshold(mask_color, 127, 255, cv2.THRESH_BINARY) + + # 转换为掩码形式 + mask_result = cv2.cvtColor(binary_result, cv2.COLOR_BGR2GRAY) + return mask_result + + def is_file_exists(self, bucket_name, file_name): + try: + self.minio_client.stat_object(bucket_name, file_name) + return True + except Exception: + return False diff --git a/app/service/design_fast/utils/organize.py b/app/service/design_fast/utils/organize.py index 8190de0..ad3cff3 100644 --- a/app/service/design_fast/utils/organize.py +++ b/app/service/design_fast/utils/organize.py @@ -33,8 +33,8 @@ def organize_clothing(layer): mask=cv2.resize(layer['mask'], layer["front_image"].size), gradient_string=layer['gradient_string'] if 'gradient_string' in layer.keys() else "", pattern_image_url=layer['pattern_image_url'], - pattern_image=layer['pattern_image'] - + pattern_image=layer['pattern_image'], + back_perspective_url=layer['back_perspective_url'] if 'back_perspective_url' in layer.keys() else "" ) # 后片数据 back_layer = dict(priority=-layer.get("priority", 0) if layer.get("layer_order", False) else PRIORITY_DICT.get(f'{layer["name"].lower()}_back', None), @@ -50,6 +50,7 @@ def organize_clothing(layer): mask=cv2.resize(layer['mask'], layer["front_image"].size), gradient_string=layer['gradient_string'] if 'gradient_string' in layer.keys() else "", pattern_image_url=layer['pattern_image_url'], + back_perspective_url=layer['back_perspective_url'] if 'back_perspective_url' in layer.keys() else "" ) return front_layer, back_layer diff --git a/app/service/utils/new_oss_client.py b/app/service/utils/new_oss_client.py index 95a0fbf..178caae 100644 --- a/app/service/utils/new_oss_client.py +++ b/app/service/utils/new_oss_client.py @@ -82,7 +82,7 @@ if __name__ == '__main__': # url = "aida-users/89/sketchboard/female/Dress/e6724ab7-8d3f-4677-abe0-c3e42ab7af85.jpeg" # url = "aida-users/87/print/956614a2-7e75-4fbe-9ed0-c1831e37a2c9-4-87.png" # url = "aida-users/89/single_logo/123-89.png" - url = "aida-users/31/sketchboard/female/dress/6edcbf92-7da9-4809-a0a8-a4b4f06dec1e0628000041.jpg" + url = "aida-sys-image/images_back/female/trousers/0825000630.jpg" # url = "aida-collection-element/12148/Sketchboard/95ea577b-305b-4a62-b30a-39c0dd3ddb3f.png" read_type = "cv2" if read_type == "cv2": From b9d2b510a363fa496827e0be4f3fd2c18be1857b Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Fri, 25 Oct 2024 10:31:04 +0800 Subject: [PATCH 03/39] =?UTF-8?q?feat=20=20=20design=20=E6=97=A0=E8=89=B2?= =?UTF-8?q?=E6=A8=A1=E5=BC=8F=20fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/service/design_fast/pipeline/color.py | 7 +++++++ app/service/utils/new_oss_client.py | 2 +- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/app/service/design_fast/pipeline/color.py b/app/service/design_fast/pipeline/color.py index 546c671..3033bb5 100644 --- a/app/service/design_fast/pipeline/color.py +++ b/app/service/design_fast/pipeline/color.py @@ -14,11 +14,18 @@ class Color: def __call__(self, result): dim_image_h, dim_image_w = result['image'].shape[0:2] + # 渐变色 if "gradient" in result.keys() and result['gradient'] != "": bucket_name = result['gradient'].split('/')[0] object_name = result['gradient'][result['gradient'].find('/') + 1:] pattern = self.get_gradient(bucket_name=bucket_name, object_name=object_name) resize_pattern = cv2.resize(pattern, (dim_image_w, dim_image_h), interpolation=cv2.INTER_AREA) + # 无色 + elif "color" not in result.keys() or result['color'] == "": + result['final_image'] = result['pattern_image'] = result['single_image'] = result['image'] + result['alpha'] = 100 / 255.0 + return result + # 正常颜色 else: pattern = self.get_pattern(result['color']) resize_pattern = cv2.resize(pattern, (dim_image_w, dim_image_h), interpolation=cv2.INTER_AREA) diff --git a/app/service/utils/new_oss_client.py b/app/service/utils/new_oss_client.py index 178caae..6d644a5 100644 --- a/app/service/utils/new_oss_client.py +++ b/app/service/utils/new_oss_client.py @@ -82,7 +82,7 @@ if __name__ == '__main__': # url = "aida-users/89/sketchboard/female/Dress/e6724ab7-8d3f-4677-abe0-c3e42ab7af85.jpeg" # url = "aida-users/87/print/956614a2-7e75-4fbe-9ed0-c1831e37a2c9-4-87.png" # url = "aida-users/89/single_logo/123-89.png" - url = "aida-sys-image/images_back/female/trousers/0825000630.jpg" + url = "aida-results/result_e961eed6-9278-11ef-a957-0826ae3ad6b3.png" # url = "aida-collection-element/12148/Sketchboard/95ea577b-305b-4a62-b30a-39c0dd3ddb3f.png" read_type = "cv2" if read_type == "cv2": From aca90159d3a89de1ebb4844f37f9f0e4edbfdb6e Mon Sep 17 00:00:00 2001 From: xupei Date: Tue, 29 Oct 2024 16:50:46 +0800 Subject: [PATCH 04/39] =?UTF-8?q?=E4=BB=8E=E5=90=91=E9=87=8F=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=E5=BA=93=E4=B8=AD=E6=A3=80=E7=B4=A2=E5=9B=BE=E7=89=87?= =?UTF-8?q?=E5=B9=B6=E9=9B=86=E6=88=90=E5=88=B0chat-robot?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/api_query_image.py | 36 ++++++++ app/api/api_route.py | 3 +- app/schemas/query_image.py | 6 ++ app/service/chat_robot/script/main.py | 2 +- app/service/chat_robot/script/prompt.py | 49 ++++++---- .../chat_robot/script/service/CallQWen.py | 57 ++++++++++-- app/service/search_image_with_text/service.py | 89 +++++++++++++++++++ 7 files changed, 217 insertions(+), 25 deletions(-) create mode 100644 app/api/api_query_image.py create mode 100644 app/schemas/query_image.py create mode 100644 app/service/search_image_with_text/service.py diff --git a/app/api/api_query_image.py b/app/api/api_query_image.py new file mode 100644 index 0000000..d27c67b --- /dev/null +++ b/app/api/api_query_image.py @@ -0,0 +1,36 @@ +import json +import logging +from http.client import HTTPException + +from fastapi import APIRouter + +from app.schemas.query_image import QueryImageModel +from app.schemas.response_template import ResponseModel +from app.service.search_image_with_text.service import query + +router = APIRouter() +logger = logging.getLogger() + + +@router.post("/query_image") +def query_image(request_data: QueryImageModel): + """ + 对话机器人 + 创建一个具有以下参数的请求体: + - **gender**: 性别 + - **content**: 用户输入的内容 + + 示例参数: + { + "gender": "male", + "content": "give me a long sleeve blouse", + } + """ + try: + logger.info(f"query_image request item is : @@@@@@:{json.dumps(request_data.dict())}") + data = query(request_data.gender, request_data.content) + logger.info(f"query_image response @@@@@@:{json.dumps(data)}") + except Exception as e: + logger.warning(f"query_image Run Exception @@@@@@:{e}") + raise HTTPException(status_code=404, detail=str(e)) + return ResponseModel(data=data) diff --git a/app/api/api_route.py b/app/api/api_route.py index 7ee774d..0da3a66 100644 --- a/app/api/api_route.py +++ b/app/api/api_route.py @@ -1,6 +1,6 @@ from fastapi import APIRouter -from app.api import api_attribute_retrieve +from app.api import api_attribute_retrieve, api_query_image from app.api import api_brighten from app.api import api_chat_robot from app.api import api_design @@ -23,3 +23,4 @@ router.include_router(api_prompt_generation.router, tags=['prompt_generation'], router.include_router(api_design_pre_processing.router, tags=['design_pre_processing'], prefix="/api") router.include_router(api_image2sketch.router, tags=['api_image2sketch'], prefix="/api") router.include_router(api_brighten.router, tags=['api_brighten'], prefix="/api") +router.include_router(api_query_image.router, tags=['api_query_image'], prefix="/api") \ No newline at end of file diff --git a/app/schemas/query_image.py b/app/schemas/query_image.py new file mode 100644 index 0000000..147603f --- /dev/null +++ b/app/schemas/query_image.py @@ -0,0 +1,6 @@ +from pydantic import BaseModel + + +class QueryImageModel(BaseModel): + gender: str + content: str diff --git a/app/service/chat_robot/script/main.py b/app/service/chat_robot/script/main.py index cabe372..3342a5c 100644 --- a/app/service/chat_robot/script/main.py +++ b/app/service/chat_robot/script/main.py @@ -100,7 +100,7 @@ def chat(post_data): # session_key=f"buffer:{user_id}:{session_id}", # ) - final_outputs = CallQWen.call_with_messages(input_message) + final_outputs = CallQWen.call_with_messages(input_message, gender) # api_response = { # 'user_id': user_id, # 'session_id': session_id, diff --git a/app/service/chat_robot/script/prompt.py b/app/service/chat_robot/script/prompt.py index a88044d..ad6ac9e 100644 --- a/app/service/chat_robot/script/prompt.py +++ b/app/service/chat_robot/script/prompt.py @@ -1,16 +1,31 @@ +# FASHION_CHAT_BOT_PREFIX = """ +# You are a helpful assistant for fashion designers. You can chat with the users or answer their query as much as you can. +# The most crucial aspect is to accurately determine whether the user's inquiry requires a internet search or querying the database. +# Remember your answer should be very precise and the final output answer should not exceed 20 words. +# +# You may encounter the following types of questions: +# 1) If the query related to clothing retrieval, you are an agent designed to interact with a SQL database. +# Given an input question, create a syntactically correct mysql query to run, always fetching random data from tables. +# Unless the user specifies a specific number of examples they wish to obtain,always limit your query to at most 4 results. +# Never query for all the columns from a specific table, only ask for the relevant columns given the question. +# You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again. +# DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database. +# If the question does not seem related to the database, just return "I don't know" as the answer. +# +# 2) If the query related to current events, you should use internet_search to seek help from the internet. +# +# 3) If the query is just casual conversation, engage in the conversation as a fashion designer assistant. +# +# Be careful to use the tools, since you are actually a chat bot. Tools can only be used when essential. +# """ + FASHION_CHAT_BOT_PREFIX = """ You are a helpful assistant for fashion designers. You can chat with the users or answer their query as much as you can. The most crucial aspect is to accurately determine whether the user's inquiry requires a internet search or querying the database. Remember your answer should be very precise and the final output answer should not exceed 20 words. You may encounter the following types of questions: -1) If the query related to clothing retrieval, you are an agent designed to interact with a SQL database. -Given an input question, create a syntactically correct mysql query to run, always fetching random data from tables. -Unless the user specifies a specific number of examples they wish to obtain,always limit your query to at most 4 results. -Never query for all the columns from a specific table, only ask for the relevant columns given the question. -You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again. -DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database. -If the question does not seem related to the database, just return "I don't know" as the answer. +1) If you need to query information related to clothing retrieval, please use the get_image_from_vector_db tool. 2) If the query related to current events, you should use internet_search to seek help from the internet. @@ -37,15 +52,19 @@ ANSWER_FORMAT_SUFFIX = """ My final answer are limited to 20 words and be as much precise as possible. """ +# TOOLS_FUNCTIONS_SUFFIX = ( +# "If the input involves clothing queries," +# "I should look at the tables in the database to see what I can query. Then I should query the schema of the most relevant tables." +# "All SQL statements must use 'ORDER BY RAND()', for example:" +# "Example Input 1: 'SELECT img_name FROM skirt WHERE opening_type = 'Button' ORDER BY RAND() LIMIT 1'" +# "Example Input 2: 'SELECT img_name FROM top WHERE sleeve_length = 'Long' AND type = 'Blouse' ORDER BY RAND() LIMIT 2'" +# "If the input does not involve clothing queries, " +# "I should engage in conversation as an assistant or search from internet with internet_search tool." +# "If the database query returns no results, please respond directly with: 'Apologies, I couldn't find any images that match your description. Could you please give me more details about the clothing you're searching for?'" +# "Upon mentioning words related to 'tutorial' in the input, I should use tutorial_tool " +# ) TOOLS_FUNCTIONS_SUFFIX = ( - "If the input involves clothing queries," - "I should look at the tables in the database to see what I can query. Then I should query the schema of the most relevant tables." - "All SQL statements must use 'ORDER BY RAND()', for example:" - "Example Input 1: 'SELECT img_name FROM skirt WHERE opening_type = 'Button' ORDER BY RAND() LIMIT 1'" - "Example Input 2: 'SELECT img_name FROM top WHERE sleeve_length = 'Long' AND type = 'Blouse' ORDER BY RAND() LIMIT 2'" - "If the input does not involve clothing queries, " - "I should engage in conversation as an assistant or search from internet with internet_search tool." - "If the database query returns no results, please respond directly with: 'Apologies, I couldn't find any images that match your description. Could you please give me more details about the clothing you're searching for?'" + "If the input involves clothing queries,please use the get_image_from_vector_db tool." "Upon mentioning words related to 'tutorial' in the input, I should use tutorial_tool " ) diff --git a/app/service/chat_robot/script/service/CallQWen.py b/app/service/chat_robot/script/service/CallQWen.py index d2e2c06..33dcd04 100644 --- a/app/service/chat_robot/script/service/CallQWen.py +++ b/app/service/chat_robot/script/service/CallQWen.py @@ -8,6 +8,7 @@ from app.core.config import * from app.service.chat_robot.script.callbacks.qwen_callback_handler import QWenCallbackHandler from app.service.chat_robot.script.database import CustomDatabase from app.service.chat_robot.script.prompt import FASHION_CHAT_BOT_PREFIX, TOOLS_FUNCTIONS_SUFFIX, TUTORIAL_TOOL_RETURN +from app.service.search_image_with_text.service import query get_database_table_description = "Input is an empty string, output is a comma separated list of tables in the database." @@ -32,6 +33,12 @@ query_database_description = ( "order by rand() LIMIT 2'" ) +query_vector_db_description = ( + "Use this tool to find the clothing images that users need. " + "If the user's input includes clothing types such as blouse, skirt, dress, outerwear, pants, or trousers, please use this tool. " + "The input for the tool is the string provided by the user." +) + tutorial_description = ("Utilize this tool to retrieve specific statements related to user guidance tutorials." "Input is an empty string") @@ -105,15 +112,37 @@ tools = [ "function": { "name": "tutorial_tool", "description": tutorial_description, + # "parameters": { + # "type": "object", + # "properties": { + # "sql_string": { + # "type": "string", + # "description": "由模型生成的sql语句" + # } + # } + # }, + } + }, + { + "type": "function", + "function": { + "name": "get_image_from_vector_db", + "description": query_vector_db_description, "parameters": { - "type": "object", - "properties": { - "sql_string": { - "type": "string", - "description": "由模型生成的sql语句" + "parameters": { + "type": "object", + "properties": { + "gender": { + "type": "string", + "description": "性别" + }, + "content": { + "type": "string", + "description": "用户描述" + } } - } - }, + }, + } } } ] @@ -150,6 +179,10 @@ def query_database(sql_string): return CustomDatabase.run(db, sql_string) +def get_image_from_vector_db(gender, content): + return query(gender, content) + + @retry(exceptions=NewConnectionError, tries=3, delay=1) def get_response(messages): response = Generation.call( @@ -164,7 +197,8 @@ def get_response(messages): return response -def call_with_messages(message): +def call_with_messages(message, gender): + user_input = message print('\n') # messages = [ # { @@ -235,6 +269,12 @@ def call_with_messages(message): tool_info = {"name": "tutorial_tool", "role": "tool", 'content': tutorial_tool()} flag = False result_content = tool_info['content'] + elif assistant_output.tool_calls[0]['function']['name'] == 'get_image_from_vector_db': + tool_info = {"name": "get_image_from_vector_db", "role": "tool", + 'content': get_image_from_vector_db(gender, user_input)} + flag = False + result_content = tool_info['content'] + response_type = "image" print(f"工具输出信息:{tool_info['content']}\n") messages.append(tool_info) @@ -257,5 +297,6 @@ def call_with_messages(message): def tutorial_tool(): return TUTORIAL_TOOL_RETURN + if __name__ == '__main__': call_with_messages() diff --git a/app/service/search_image_with_text/service.py b/app/service/search_image_with_text/service.py new file mode 100644 index 0000000..98f6ac4 --- /dev/null +++ b/app/service/search_image_with_text/service.py @@ -0,0 +1,89 @@ +import chromadb +import hashlib + +import pandas as pd +from chromadb.config import Settings +from chromadb.utils.embedding_functions.ollama_embedding_function import OllamaEmbeddingFunction +from tqdm import tqdm + +# 读取 csv 文件 +csv_file_path = r'D:/Files/csv/output/output.csv' +image_path = r'D:/images-clean' + +df = pd.read_csv(csv_file_path, encoding='Windows-1252') + +# 创建 Chroma 客户端 +client = chromadb.Client(Settings(is_persistent=True, persist_directory="/vector_db")) +# client = chromadb.Client(Settings(is_persistent=True, persist_directory="./service/search_image_with_text/vector_db")) +# client = chromadb.Client(Settings(is_persistent=True, persist_directory="D:/workspace/AiDLab/vector_db")) +# 创建集合 +embedding_fn = OllamaEmbeddingFunction(url="http://localhost:11434/api/embeddings", model_name="mxbai-embed-large") + + +def create_collection(): + collection = client.get_or_create_collection("sub_sketches_description", embedding_function=embedding_fn) + + # 存储数据,包括自定义属性 + images_description = [] + images_metadata = [] + ids = [] + batch_size = 41666 # 最大批量大小 + for index, row in tqdm(df.iterrows()): + # 将图片的md5作为id + with open(image_path + row['path'], 'rb') as f: + image_data = f.read() + md5_value = hashlib.md5(image_data).hexdigest() + ids.append(md5_value) + images_description.append(row['description']) + images_metadata.append({ + "gender": row['gender'], + "path": row['path'] + }) + + # 将数据添加到集合 + # 每达到 batch_size 就执行一次 upsert + if len(ids) >= batch_size: + collection.upsert( + ids=list(ids), + documents=images_description, + metadatas=images_metadata # 添加自定义属性 + ) + # 清空列表以准备下一批数据 + ids.clear() + images_description.clear() + images_metadata.clear() + + if ids: + collection.upsert( + ids=list(ids), + documents=images_description, + metadatas=images_metadata # 添加自定义属性 + ) + + print("Data successfully stored in the vector database.") + + +def query(gender, content): + collection = client.get_collection("sub_sketches_description", embedding_function=embedding_fn) + # 6. 查询相似内容 + user_gender = gender # 用户输入的性别 + user_content = content # 用户输入的内容 + + results = collection.query( + query_texts=user_content, + n_results=5, # 返回前 5 个结果 + where={"gender": user_gender} # 根据性别过滤 + ) + + # 输出结果 + resp = [] + for document, result in zip(results['documents'][0], results['metadatas'][0]): + # print("Path:", result['path']) + # print("Content:", document) + resp.append(result['path']) + return resp + + +if __name__ == '__main__': + # create_collection() + query("female", "I need a long sleeve dress") From ba529edfaa6be141233ee75764c22a62ebf0f169 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Tue, 29 Oct 2024 16:58:21 +0800 Subject: [PATCH 05/39] =?UTF-8?q?feat=20=20=20dockerfile=20=E4=BF=AE?= =?UTF-8?q?=E6=94=B9=20fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .dockerignore | 2 ++ 1 file changed, 2 insertions(+) create mode 100644 .dockerignore diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..0b6bf22 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,2 @@ +seg_cache +test \ No newline at end of file From 11329bac3e6edfc1c76c2acc5d37c14aad8dad07 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Tue, 29 Oct 2024 16:58:37 +0800 Subject: [PATCH 06/39] =?UTF-8?q?feat=20=20=20dockerfile=20=E4=BF=AE?= =?UTF-8?q?=E6=94=B9=20fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- requirements.txt | Bin 1828 -> 1860 bytes 1 file changed, 0 insertions(+), 0 deletions(-) diff --git a/requirements.txt b/requirements.txt index 6c9e38f1ded86de71e2126d5c357903ea0d08a05..73507145f0a1adf6986bae737597a22a911f640e 100644 GIT binary patch delta 44 ycmZ3&cZ6@lELQnsh75)xhJ1!xhD3%Gh9rhM23rOL20aE-AU0$$-8_@En-Ku{m Date: Tue, 29 Oct 2024 17:17:30 +0800 Subject: [PATCH 07/39] =?UTF-8?q?feat=20=20=20dockerfile=20=E4=BF=AE?= =?UTF-8?q?=E6=94=B9=20fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/service/search_image_with_text/service.py | 88 +++++++++---------- 1 file changed, 44 insertions(+), 44 deletions(-) diff --git a/app/service/search_image_with_text/service.py b/app/service/search_image_with_text/service.py index 98f6ac4..47a9dde 100644 --- a/app/service/search_image_with_text/service.py +++ b/app/service/search_image_with_text/service.py @@ -7,10 +7,10 @@ from chromadb.utils.embedding_functions.ollama_embedding_function import OllamaE from tqdm import tqdm # 读取 csv 文件 -csv_file_path = r'D:/Files/csv/output/output.csv' -image_path = r'D:/images-clean' +# csv_file_path = r'D:/Files/csv/output/output.csv' +# image_path = r'D:/images-clean' -df = pd.read_csv(csv_file_path, encoding='Windows-1252') +# df = pd.read_csv(csv_file_path, encoding='Windows-1252') # 创建 Chroma 客户端 client = chromadb.Client(Settings(is_persistent=True, persist_directory="/vector_db")) @@ -20,47 +20,47 @@ client = chromadb.Client(Settings(is_persistent=True, persist_directory="/vector embedding_fn = OllamaEmbeddingFunction(url="http://localhost:11434/api/embeddings", model_name="mxbai-embed-large") -def create_collection(): - collection = client.get_or_create_collection("sub_sketches_description", embedding_function=embedding_fn) - - # 存储数据,包括自定义属性 - images_description = [] - images_metadata = [] - ids = [] - batch_size = 41666 # 最大批量大小 - for index, row in tqdm(df.iterrows()): - # 将图片的md5作为id - with open(image_path + row['path'], 'rb') as f: - image_data = f.read() - md5_value = hashlib.md5(image_data).hexdigest() - ids.append(md5_value) - images_description.append(row['description']) - images_metadata.append({ - "gender": row['gender'], - "path": row['path'] - }) - - # 将数据添加到集合 - # 每达到 batch_size 就执行一次 upsert - if len(ids) >= batch_size: - collection.upsert( - ids=list(ids), - documents=images_description, - metadatas=images_metadata # 添加自定义属性 - ) - # 清空列表以准备下一批数据 - ids.clear() - images_description.clear() - images_metadata.clear() - - if ids: - collection.upsert( - ids=list(ids), - documents=images_description, - metadatas=images_metadata # 添加自定义属性 - ) - - print("Data successfully stored in the vector database.") +# def create_collection(): +# collection = client.get_or_create_collection("sub_sketches_description", embedding_function=embedding_fn) +# +# # 存储数据,包括自定义属性 +# images_description = [] +# images_metadata = [] +# ids = [] +# batch_size = 41666 # 最大批量大小 +# for index, row in tqdm(df.iterrows()): +# # 将图片的md5作为id +# with open(image_path + row['path'], 'rb') as f: +# image_data = f.read() +# md5_value = hashlib.md5(image_data).hexdigest() +# ids.append(md5_value) +# images_description.append(row['description']) +# images_metadata.append({ +# "gender": row['gender'], +# "path": row['path'] +# }) +# +# # 将数据添加到集合 +# # 每达到 batch_size 就执行一次 upsert +# if len(ids) >= batch_size: +# collection.upsert( +# ids=list(ids), +# documents=images_description, +# metadatas=images_metadata # 添加自定义属性 +# ) +# # 清空列表以准备下一批数据 +# ids.clear() +# images_description.clear() +# images_metadata.clear() +# +# if ids: +# collection.upsert( +# ids=list(ids), +# documents=images_description, +# metadatas=images_metadata # 添加自定义属性 +# ) +# +# print("Data successfully stored in the vector database.") def query(gender, content): From 76a5e97ab8575a51c04d3269b1c67f032e1cbca9 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Tue, 29 Oct 2024 17:27:05 +0800 Subject: [PATCH 08/39] feat fix 1 --- app/api/api_query_image.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/app/api/api_query_image.py b/app/api/api_query_image.py index d27c67b..ca0dbe6 100644 --- a/app/api/api_query_image.py +++ b/app/api/api_query_image.py @@ -1,8 +1,7 @@ import json import logging -from http.client import HTTPException -from fastapi import APIRouter +from fastapi import APIRouter, HTTPException from app.schemas.query_image import QueryImageModel from app.schemas.response_template import ResponseModel From 93b284721f9483224604293a161cc869d92203ef Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Tue, 29 Oct 2024 17:44:10 +0800 Subject: [PATCH 09/39] feat fix 1 --- app/service/search_image_with_text/service.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/app/service/search_image_with_text/service.py b/app/service/search_image_with_text/service.py index 47a9dde..36a86a8 100644 --- a/app/service/search_image_with_text/service.py +++ b/app/service/search_image_with_text/service.py @@ -17,7 +17,8 @@ client = chromadb.Client(Settings(is_persistent=True, persist_directory="/vector # client = chromadb.Client(Settings(is_persistent=True, persist_directory="./service/search_image_with_text/vector_db")) # client = chromadb.Client(Settings(is_persistent=True, persist_directory="D:/workspace/AiDLab/vector_db")) # 创建集合 -embedding_fn = OllamaEmbeddingFunction(url="http://localhost:11434/api/embeddings", model_name="mxbai-embed-large") +# embedding_fn = OllamaEmbeddingFunction(url="http://localhost:11434/api/embeddings", model_name="mxbai-embed-large") +embedding_fn = OllamaEmbeddingFunction(url="http://10.1.1.240:11434/api/embeddings", model_name="mxbai-embed-large") # def create_collection(): From 6f940736c076cc92d33bb04469451c6b60a2fd43 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Wed, 30 Oct 2024 10:34:53 +0800 Subject: [PATCH 10/39] =?UTF-8?q?feat=20fix=20=20=20=20=20=E6=9A=82?= =?UTF-8?q?=E6=97=B6=E5=8F=96=E6=B6=88design=E8=A2=AB=E5=90=8E=E8=A7=86?= =?UTF-8?q?=E8=A7=92=E5=9B=BE=20=E5=BE=85=E5=AE=9A?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/service/design_fast/design_generate.py | 2 +- app/service/design_fast/item.py | 4 ++-- app/service/design_fast/utils/organize.py | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/app/service/design_fast/design_generate.py b/app/service/design_fast/design_generate.py index 244e09c..582de4c 100644 --- a/app/service/design_fast/design_generate.py +++ b/app/service/design_fast/design_generate.py @@ -81,7 +81,7 @@ def design_generate(request_data): 'mask_url': lay['mask_url'], 'image_url': lay['image_url'] if 'image_url' in lay.keys() else None, 'pattern_image_url': lay['pattern_image_url'] if 'pattern_image_url' in lay.keys() else None, - 'back_perspective_url': lay['back_perspective_url'] if 'back_perspective_url' in lay.keys() else None, + # 'back_perspective_url': lay['back_perspective_url'] if 'back_perspective_url' in lay.keys() else None, }) items_response['synthesis_url'] = synthesis(layers, new_size, basic) else: diff --git a/app/service/design_fast/item.py b/app/service/design_fast/item.py index c1bd467..f7af700 100644 --- a/app/service/design_fast/item.py +++ b/app/service/design_fast/item.py @@ -16,7 +16,7 @@ class TopItem(BaseItem): LoadImage(minio_client), KeyPoint(), Segmentation(minio_client), - BackPerspective(minio_client), + # BackPerspective(minio_client), Color(minio_client), PrintPainting(minio_client), Scaling(), @@ -37,7 +37,7 @@ class BottomItem(BaseItem): KeyPoint(), ContourDetection(), # Segmentation(), - BackPerspective(minio_client), + # BackPerspective(minio_client), Color(minio_client), PrintPainting(minio_client), Scaling(), diff --git a/app/service/design_fast/utils/organize.py b/app/service/design_fast/utils/organize.py index ad3cff3..92be044 100644 --- a/app/service/design_fast/utils/organize.py +++ b/app/service/design_fast/utils/organize.py @@ -34,7 +34,7 @@ def organize_clothing(layer): gradient_string=layer['gradient_string'] if 'gradient_string' in layer.keys() else "", pattern_image_url=layer['pattern_image_url'], pattern_image=layer['pattern_image'], - back_perspective_url=layer['back_perspective_url'] if 'back_perspective_url' in layer.keys() else "" + # back_perspective_url=layer['back_perspective_url'] if 'back_perspective_url' in layer.keys() else "" ) # 后片数据 back_layer = dict(priority=-layer.get("priority", 0) if layer.get("layer_order", False) else PRIORITY_DICT.get(f'{layer["name"].lower()}_back', None), @@ -50,7 +50,7 @@ def organize_clothing(layer): mask=cv2.resize(layer['mask'], layer["front_image"].size), gradient_string=layer['gradient_string'] if 'gradient_string' in layer.keys() else "", pattern_image_url=layer['pattern_image_url'], - back_perspective_url=layer['back_perspective_url'] if 'back_perspective_url' in layer.keys() else "" + # back_perspective_url=layer['back_perspective_url'] if 'back_perspective_url' in layer.keys() else "" ) return front_layer, back_layer From c508ddea48206cdb54f3573f748b38a998a04e2a Mon Sep 17 00:00:00 2001 From: xupei Date: Wed, 30 Oct 2024 15:00:30 +0800 Subject: [PATCH 11/39] =?UTF-8?q?=E4=BB=8E=E5=90=91=E9=87=8F=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=E5=BA=93=E4=B8=AD=E6=A3=80=E7=B4=A2=E5=9B=BE=E7=89=87?= =?UTF-8?q?=E5=B9=B6=E9=9B=86=E6=88=90=E5=88=B0chat-robot?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/service/search_image_with_text/service.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/service/search_image_with_text/service.py b/app/service/search_image_with_text/service.py index 36a86a8..2274c51 100644 --- a/app/service/search_image_with_text/service.py +++ b/app/service/search_image_with_text/service.py @@ -67,7 +67,7 @@ embedding_fn = OllamaEmbeddingFunction(url="http://10.1.1.240:11434/api/embeddin def query(gender, content): collection = client.get_collection("sub_sketches_description", embedding_function=embedding_fn) # 6. 查询相似内容 - user_gender = gender # 用户输入的性别 + user_gender = gender.lower() # 用户输入的性别 user_content = content # 用户输入的内容 results = collection.query( From d92c59383b772e8e7c4ffd11721aa5bb24003b22 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Fri, 8 Nov 2024 14:05:09 +0800 Subject: [PATCH 12/39] =?UTF-8?q?feat=20=20=20=20design=E7=9B=B8=E5=85=B3?= =?UTF-8?q?=E6=A8=A1=E5=9E=8B=E8=BF=81=E7=A7=BB4090=E6=B5=8B=E8=AF=95=20fi?= =?UTF-8?q?x?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/core/config.py | 6 +----- app/service/attribute/service_att_recognition.py | 2 +- app/service/attribute/service_category_recognition.py | 2 +- app/service/generate_image/utils/image_processing.py | 4 ++-- 4 files changed, 5 insertions(+), 9 deletions(-) diff --git a/app/core/config.py b/app/core/config.py index 35c12b7..5909a3a 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -93,9 +93,6 @@ OPENAI_MODEL_LIST = {"gpt-3.5-turbo-0613", "gpt-4-0613", "gpt-4-32k-0613", } -# attribute service config -ATT_TRITON_URL = "10.1.1.240:10000" - # SR service config SR_MODEL_NAME = "super_resolution" SR_TRITON_URL = "10.1.1.240:10031" @@ -132,7 +129,6 @@ GRI_MODEL_NAME_SINGLE = 'stable_diffusion_1_5_relight' GRI_MODEL_URL = '10.1.1.240:10051' # SEG service config -SEG_MODEL_URL = '10.1.1.240:10000' SEGMENTATION = { "new_model_name": "seg_knet", "name": "seg_ocrnet_hr18", @@ -141,7 +137,7 @@ SEGMENTATION = { } # DESIGN config -DESIGN_MODEL_URL = '10.1.1.240:10000' +DESIGN_MODEL_URL = '10.1.1.243:10000' AIDA_CLOTHING = "aida-clothing" KEYPOINT_RESULT_TABLE_FIELD_SET = ('neckline_left', 'neckline_right', 'shoulder_left', 'shoulder_right', 'armpit_left', 'armpit_right', 'cuff_left_in', 'cuff_left_out', 'cuff_right_in', 'cuff_right_out', 'waistband_left', 'waistband_right') diff --git a/app/service/attribute/service_att_recognition.py b/app/service/attribute/service_att_recognition.py index 1251891..f93146e 100644 --- a/app/service/attribute/service_att_recognition.py +++ b/app/service/attribute/service_att_recognition.py @@ -28,7 +28,7 @@ class AttributeRecognition: } ) self.const = const - self.triton_client = httpclient.InferenceServerClient(url=f"{ATT_TRITON_URL}") + self.triton_client = httpclient.InferenceServerClient(url=f"{DESIGN_MODEL_URL}") def get_result(self): for sketch in self.request_data: diff --git a/app/service/attribute/service_category_recognition.py b/app/service/attribute/service_category_recognition.py index f917af2..7c277c9 100644 --- a/app/service/attribute/service_category_recognition.py +++ b/app/service/attribute/service_category_recognition.py @@ -26,7 +26,7 @@ class CategoryRecognition: self.attr_type = pd.read_csv(CATEGORY_PATH) # self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) self.request_data = [] - self.triton_client = httpclient.InferenceServerClient(url=ATT_TRITON_URL) + self.triton_client = httpclient.InferenceServerClient(url=DESIGN_MODEL_URL) for sketch in request_data: self.request_data.append( { diff --git a/app/service/generate_image/utils/image_processing.py b/app/service/generate_image/utils/image_processing.py index af36188..02d8bee 100644 --- a/app/service/generate_image/utils/image_processing.py +++ b/app/service/generate_image/utils/image_processing.py @@ -81,7 +81,7 @@ def get_contours(image): def seg_infer_image(image_obj): image, ori_shape = seg_preprocess(image_obj) - client = httpclient.InferenceServerClient(url=f"{SEG_MODEL_URL}") + client = httpclient.InferenceServerClient(url=f"{DESIGN_MODEL_URL}") transformed_img = image.astype(np.float32) # 输入集 inputs = [ @@ -250,7 +250,7 @@ def generate_category_recognition(image, gender): return preprocessed_img preprocessed_img = preprocess(image) - triton_client = httpclient.InferenceServerClient(url=ATT_TRITON_URL) + triton_client = httpclient.InferenceServerClient(url=DESIGN_MODEL_URL) inputs = [ httpclient.InferInput("input__0", preprocessed_img.shape, datatype="FP32") From 82c23717b0301bb4cccd79bb742457b8d9dc51ad Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Fri, 8 Nov 2024 14:34:32 +0800 Subject: [PATCH 13/39] =?UTF-8?q?feat=20=20=20=20design=E7=9B=B8=E5=85=B3?= =?UTF-8?q?=E6=A8=A1=E5=9E=8B=E8=BF=81=E7=A7=BB4090=E6=B5=8B=E8=AF=95=20fi?= =?UTF-8?q?x?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/service/search_image_with_text/service.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/service/search_image_with_text/service.py b/app/service/search_image_with_text/service.py index 2274c51..35c5955 100644 --- a/app/service/search_image_with_text/service.py +++ b/app/service/search_image_with_text/service.py @@ -18,7 +18,7 @@ client = chromadb.Client(Settings(is_persistent=True, persist_directory="/vector # client = chromadb.Client(Settings(is_persistent=True, persist_directory="D:/workspace/AiDLab/vector_db")) # 创建集合 # embedding_fn = OllamaEmbeddingFunction(url="http://localhost:11434/api/embeddings", model_name="mxbai-embed-large") -embedding_fn = OllamaEmbeddingFunction(url="http://10.1.1.240:11434/api/embeddings", model_name="mxbai-embed-large") +embedding_fn = OllamaEmbeddingFunction(url="http://10.1.1.243:11434/api/embeddings", model_name="mxbai-embed-large") # def create_collection(): From 696daea7750bfc21b38f461da6f53ade3054612c Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Fri, 8 Nov 2024 14:35:23 +0800 Subject: [PATCH 14/39] =?UTF-8?q?feat=20=20=20=20design=E7=9B=B8=E5=85=B3?= =?UTF-8?q?=E6=A8=A1=E5=9E=8B=E8=BF=81=E7=A7=BB4090=E6=B5=8B=E8=AF=95=20fi?= =?UTF-8?q?x?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/core/config.py | 3 ++- app/service/search_image_with_text/service.py | 4 +++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/app/core/config.py b/app/core/config.py index 5909a3a..37592c3 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -135,7 +135,8 @@ SEGMENTATION = { "input": "seg_input__0", "output": "seg_output__0", } - +# ollama config +OLLAMA_URL = "http://10.1.1.243:11434/api/embeddings" # DESIGN config DESIGN_MODEL_URL = '10.1.1.243:10000' AIDA_CLOTHING = "aida-clothing" diff --git a/app/service/search_image_with_text/service.py b/app/service/search_image_with_text/service.py index 35c5955..edd4d93 100644 --- a/app/service/search_image_with_text/service.py +++ b/app/service/search_image_with_text/service.py @@ -6,6 +6,8 @@ from chromadb.config import Settings from chromadb.utils.embedding_functions.ollama_embedding_function import OllamaEmbeddingFunction from tqdm import tqdm +from app.core.config import OLLAMA_URL + # 读取 csv 文件 # csv_file_path = r'D:/Files/csv/output/output.csv' # image_path = r'D:/images-clean' @@ -18,7 +20,7 @@ client = chromadb.Client(Settings(is_persistent=True, persist_directory="/vector # client = chromadb.Client(Settings(is_persistent=True, persist_directory="D:/workspace/AiDLab/vector_db")) # 创建集合 # embedding_fn = OllamaEmbeddingFunction(url="http://localhost:11434/api/embeddings", model_name="mxbai-embed-large") -embedding_fn = OllamaEmbeddingFunction(url="http://10.1.1.243:11434/api/embeddings", model_name="mxbai-embed-large") +embedding_fn = OllamaEmbeddingFunction(url=OLLAMA_URL, model_name="mxbai-embed-large") # def create_collection(): From 4cc993cf27e116e7b034a65bd6077c448b8bb759 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Tue, 19 Nov 2024 10:14:52 +0800 Subject: [PATCH 15/39] =?UTF-8?q?feat=20=20=20=20design=20=E9=80=8F?= =?UTF-8?q?=E6=98=8E=E5=92=8C=E9=80=89=E5=8F=96=E9=80=8F=E6=98=8E=20fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/api_design.py | 3 --- app/core/config.py | 4 +-- app/service/design_fast/pipeline/split.py | 21 +++++++++++++++- app/service/design_fast/utils/transparent.py | 26 ++++++++++++++++++++ app/service/utils/new_oss_client.py | 7 +++--- 5 files changed, 52 insertions(+), 9 deletions(-) create mode 100644 app/service/design_fast/utils/transparent.py diff --git a/app/api/api_design.py b/app/api/api_design.py index aa9fe43..af79f05 100644 --- a/app/api/api_design.py +++ b/app/api/api_design.py @@ -67,7 +67,6 @@ def design(request_data: DesignModel): 0 ], "path": "aida-sys-image/images/female/trousers/0825000630.jpg", - "seg_mask_url": "test/result.png", "print": { "element": { "element_angle_list": [], @@ -104,7 +103,6 @@ def design(request_data: DesignModel): 0 ], "path": "aida-sys-image/images/female/blouse/0902003811.jpg", - "seg_mask_url": "test/result.png", "print": { "element": { "element_angle_list": [], @@ -141,7 +139,6 @@ def design(request_data: DesignModel): 0 ], "path": "aida-sys-image/images/female/outwear/0825000410.jpg", - "seg_mask_url": "test/result.png", "print": { "element": { "element_angle_list": [], diff --git a/app/core/config.py b/app/core/config.py index 37592c3..63f1bda 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -20,7 +20,7 @@ class Settings(BaseSettings): OSS = "minio" -DEBUG = False +DEBUG = True if DEBUG: LOGS_PATH = "logs/" CATEGORY_PATH = "service/attribute/config/descriptor/category/category_dis.csv" @@ -110,7 +110,7 @@ GI_SYS_IMAGE_URL = "aida-sys-image/generate_image/white_image.jpg" SLOGAN_RABBITMQ_QUEUES = os.getenv("SLOGAN_RABBITMQ_QUEUES", f"Slogan{RABBITMQ_ENV}") # Generate Single Logo service config -GSL_MODEL_URL = '10.1.1.240:10041' +GSL_MODEL_URL = '10.1.1.243:10041' GSL_MINIO_BUCKET = "aida-users" GSL_MODEL_NAME = 'stable_diffusion_xl_transparent' GEN_SINGLE_LOGO_RABBITMQ_QUEUES = os.getenv("GEN_SINGLE_LOGO_RABBITMQ_QUEUES", f"GenSingleLogo{RABBITMQ_ENV}") diff --git a/app/service/design_fast/pipeline/split.py b/app/service/design_fast/pipeline/split.py index 737b50e..2f67027 100644 --- a/app/service/design_fast/pipeline/split.py +++ b/app/service/design_fast/pipeline/split.py @@ -8,9 +8,10 @@ from cv2 import cvtColor, COLOR_BGR2RGBA from app.core.config import AIDA_CLOTHING from app.service.design_fast.utils.conversion_image import rgb_to_rgba +from app.service.design_fast.utils.transparent import sketch_to_transparent from app.service.design_fast.utils.upload_image import upload_png_mask from app.service.utils.generate_uuid import generate_uuid -from app.service.utils.new_oss_client import oss_upload_image +from app.service.utils.new_oss_client import oss_upload_image, oss_get_image class Split(object): @@ -30,6 +31,24 @@ class Split(object): front_mask = cv2.resize(front_mask, new_size) result_front_image[front_mask != 0] = rgba_image[front_mask != 0] result_front_image_pil = Image.fromarray(cvtColor(result_front_image, COLOR_BGR2RGBA)) + if 'transparent' in result.keys(): + # 用户自选区域transparent + transparent = result['transparent'] + if transparent['mask_url'] is not None and transparent['mask_url'] != "": + # 预处理用户自选区mask + seg_mask = oss_get_image(oss_client=self.minio_client, bucket=transparent['mask_url'].split('/')[0], object_name=transparent['mask_url'][transparent['mask_url'].find('/') + 1:], data_type="cv2") + seg_mask = cv2.resize(seg_mask, new_size, interpolation=cv2.INTER_NEAREST) + # 转换颜色空间为 RGB(OpenCV 默认是 BGR) + image_rgb = cv2.cvtColor(seg_mask, cv2.COLOR_BGR2RGB) + + r, g, b = cv2.split(image_rgb) + blue_mask = b > r + + # 创建红色和绿色掩码 + transparent_mask = np.array(blue_mask, dtype=np.uint8) * 255 + result_front_image_pil = sketch_to_transparent(result_front_image_pil, transparent_mask, transparent["scale"]) + else: + result_front_image_pil = sketch_to_transparent(result_front_image_pil, front_mask, transparent["scale"]) result['front_image'], result["front_image_url"], _ = upload_png_mask(self.minio_client, result_front_image_pil, f'{generate_uuid()}', mask=None) height, width = front_mask.shape diff --git a/app/service/design_fast/utils/transparent.py b/app/service/design_fast/utils/transparent.py new file mode 100644 index 0000000..3f73807 --- /dev/null +++ b/app/service/design_fast/utils/transparent.py @@ -0,0 +1,26 @@ +from PIL import Image + + +def sketch_to_transparent(image, mask, transparency): + # 打开原始图片 + image = image.convert("RGBA") + # 打开mask图片,假设mask图片是灰度图,白色区域为要处理的区域,黑色区域为保留的区域 + mask = Image.fromarray(mask) + + # 根据透明度调整因子,将透明度转换为0-255之间的值 + alpha_value = int((1 - transparency) * 255.0) + + # 获取图片的像素数据 + image_pixels = image.load() + mask_pixels = mask.load() + + width, height = image.size + + for y in range(height): + for x in range(width): + # 如果mask区域对应的像素为白色(值大于128,这里假设白色为要处理的区域,可根据实际情况调整) + if mask_pixels[x, y] > 128: + r, g, b, a = image_pixels[x, y] + image_pixels[x, y] = (r, g, b, alpha_value) + + return image diff --git a/app/service/utils/new_oss_client.py b/app/service/utils/new_oss_client.py index 6d644a5..23e0f8a 100644 --- a/app/service/utils/new_oss_client.py +++ b/app/service/utils/new_oss_client.py @@ -82,13 +82,14 @@ if __name__ == '__main__': # url = "aida-users/89/sketchboard/female/Dress/e6724ab7-8d3f-4677-abe0-c3e42ab7af85.jpeg" # url = "aida-users/87/print/956614a2-7e75-4fbe-9ed0-c1831e37a2c9-4-87.png" # url = "aida-users/89/single_logo/123-89.png" - url = "aida-results/result_e961eed6-9278-11ef-a957-0826ae3ad6b3.png" + url = "aida-results/result_94d3fc82-a560-11ef-b2c1-0826ae3ad6b3.png" + # url = "aida-collection-element/12148/Sketchboard/95ea577b-305b-4a62-b30a-39c0dd3ddb3f.png" - read_type = "cv2" + read_type = "2" if read_type == "cv2": img = oss_get_image(oss_client=minio_client, bucket=url.split('/')[0], object_name=url[url.find('/') + 1:], data_type=read_type) cv2.imshow("", img) cv2.waitKey(0) else: img = oss_get_image(oss_client=minio_client, bucket=url.split('/')[0], object_name=url[url.find('/') + 1:], data_type=read_type) - img.show() + img.save("原图.png") From 26be5279df62907779dfd106c79c269318909879 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Tue, 19 Nov 2024 10:20:25 +0800 Subject: [PATCH 16/39] =?UTF-8?q?feat=20=20=20=20design=20=E9=80=8F?= =?UTF-8?q?=E6=98=8E=E5=92=8C=E9=80=89=E5=8F=96=E9=80=8F=E6=98=8E=20fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/api_design.py | 4 ++++ app/core/config.py | 2 +- app/service/utils/new_oss_client.py | 4 ++-- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/app/api/api_design.py b/app/api/api_design.py index af79f05..b0231a2 100644 --- a/app/api/api_design.py +++ b/app/api/api_design.py @@ -164,6 +164,10 @@ def design(request_data: DesignModel): 1.0, 1.0 ], + "transparent":{ + "mask_url":"test/transparent_test/transparent_mask.png", + "scale":0.1 + }, "type": "Outwear" }, { diff --git a/app/core/config.py b/app/core/config.py index 63f1bda..5701e92 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -20,7 +20,7 @@ class Settings(BaseSettings): OSS = "minio" -DEBUG = True +DEBUG = False if DEBUG: LOGS_PATH = "logs/" CATEGORY_PATH = "service/attribute/config/descriptor/category/category_dis.csv" diff --git a/app/service/utils/new_oss_client.py b/app/service/utils/new_oss_client.py index 23e0f8a..8dbf2fc 100644 --- a/app/service/utils/new_oss_client.py +++ b/app/service/utils/new_oss_client.py @@ -82,7 +82,7 @@ if __name__ == '__main__': # url = "aida-users/89/sketchboard/female/Dress/e6724ab7-8d3f-4677-abe0-c3e42ab7af85.jpeg" # url = "aida-users/87/print/956614a2-7e75-4fbe-9ed0-c1831e37a2c9-4-87.png" # url = "aida-users/89/single_logo/123-89.png" - url = "aida-results/result_94d3fc82-a560-11ef-b2c1-0826ae3ad6b3.png" + url ="aida-results/result_91559b60-a61c-11ef-af8e-0242ac150002.png" # url = "aida-collection-element/12148/Sketchboard/95ea577b-305b-4a62-b30a-39c0dd3ddb3f.png" read_type = "2" @@ -92,4 +92,4 @@ if __name__ == '__main__': cv2.waitKey(0) else: img = oss_get_image(oss_client=minio_client, bucket=url.split('/')[0], object_name=url[url.find('/') + 1:], data_type=read_type) - img.save("原图.png") + img.show() From 754c6eff87f9d144c4f9d757e4f3bf8007700370 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Tue, 19 Nov 2024 17:35:10 +0800 Subject: [PATCH 17/39] =?UTF-8?q?feat=20=20=204090=20triton=20server=20?= =?UTF-8?q?=E6=B5=8B=E8=AF=95=20fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/core/config.py | 6 +- .../service_generate_product_image.py | 59 ++++++++++--------- app/service/utils/new_oss_client.py | 2 +- 3 files changed, 36 insertions(+), 31 deletions(-) diff --git a/app/core/config.py b/app/core/config.py index 5701e92..48af4a1 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -117,10 +117,10 @@ GEN_SINGLE_LOGO_RABBITMQ_QUEUES = os.getenv("GEN_SINGLE_LOGO_RABBITMQ_QUEUES", f # Generate Product service config GPI_RABBITMQ_QUEUES = os.getenv("GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES", f"ToProductImage{RABBITMQ_ENV}") -GPI_MODEL_NAME_OVERALL = 'diffusion_ensemble_all' -GPI_MODEL_NAME_SINGLE = 'stable_diffusion_1_5_cnet' +GPI_MODEL_NAME_OVERALL = 'stable_diffusion_xl_cnet' +GPI_MODEL_NAME_SINGLE = 'stable_diffusion_xl_cnet' -GPI_MODEL_URL = '10.1.1.240:10041' +GPI_MODEL_URL = '10.1.1.243:10051' # Generate Single Logo service config GRI_RABBITMQ_QUEUES = os.getenv("GEN_RELIGHT_IMAGE_RABBITMQ_QUEUES", f"Relight{RABBITMQ_ENV}") diff --git a/app/service/generate_image/service_generate_product_image.py b/app/service/generate_image/service_generate_product_image.py index 5ea6f83..da4bb4b 100644 --- a/app/service/generate_image/service_generate_product_image.py +++ b/app/service/generate_image/service_generate_product_image.py @@ -15,7 +15,7 @@ import cv2 import numpy as np import redis import tritonclient.grpc as grpcclient -from PIL import Image, ImageOps +from PIL import Image from tritonclient.utils import np_to_triton_dtype from app.core.config import * @@ -41,7 +41,7 @@ class GenerateProductImage: self.batch_size = 1 self.product_type = request_data.product_type self.prompt = request_data.prompt - self.image, self.image_size = pre_processing_image(request_data.image_url) + self.image, self.image_size, self.left, self.top = pre_processing_image(request_data.image_url) self.tasks_id = request_data.tasks_id self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:] self.gen_product_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': ''} @@ -58,9 +58,10 @@ class GenerateProductImage: if self.product_type == "single": image = result.as_numpy("generated_cnet_image") else: - image = result.as_numpy("generated_inpaint_image") + image = result.as_numpy("generated_cnet_image") image_result = Image.fromarray(np.squeeze(image.astype(np.uint8))).resize(self.image_size) - image_url = upload_SDXL_image(image_result, user_id=self.user_id, category=f"{self.category}", file_name=f"{self.tasks_id}.png") + cropped_image = post_processing_image(image_result, self.left, self.top) + image_url = upload_SDXL_image(cropped_image, user_id=self.user_id, category=f"{self.category}", file_name=f"{self.tasks_id}.png") self.gen_product_data['status'] = "SUCCESS" self.gen_product_data['message'] = "success" self.gen_product_data['image_url'] = str(image_url) @@ -74,7 +75,7 @@ class GenerateProductImage: try: prompts = [self.prompt] * self.batch_size self.image = cv2.cvtColor(self.image, cv2.COLOR_BGR2RGB) - self.image = cv2.resize(self.image, (512, 768)) + self.image = cv2.resize(self.image, (1024, 1024)) images = [self.image.astype(np.uint8)] * self.batch_size if self.product_type == "single": @@ -82,9 +83,9 @@ class GenerateProductImage: image_obj = np.array(images, dtype=np.uint8).reshape((-1, 768, 512, 3)) image_strength_obj = np.array(self.image_strength, dtype=np.float32).reshape(-1, 1) else: - text_obj = np.array(prompts, dtype="object").reshape(1) - image_obj = np.array(images, dtype=np.uint8).reshape((768, 512, 3)) - image_strength_obj = np.array(self.image_strength, dtype=np.float32).reshape((1)) + text_obj = np.array(prompts, dtype="object").reshape((-1, 1)) + image_obj = np.array(images, dtype=np.uint8).reshape((-1, 1024, 1024, 3)) + image_strength_obj = np.array(self.image_strength, dtype=np.float32).reshape((-1, 1)) # 假设 prompts、images 和 self.image_strength 已经定义 @@ -136,22 +137,13 @@ def infer_cancel(tasks_id): def pre_processing_image(image_url): image = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url[image_url.find('/') + 1:], data_type="PIL") + # resize 原图至1024*1024 + image = image.resize((int(1024 / image.height * image.width), 1024)) + # 原始图片的尺寸 width, height = image.size - # 计算长宽比为 3:2 的新尺寸 - desired_ratio = 2 / 3 - current_ratio = width / height - - if current_ratio > desired_ratio: - # 原始图片更宽,需要在上下添加 padding - new_width = width - new_height = int(width / desired_ratio) - else: - # 原始图片更高或者长宽比已经为 3:2 - new_height = height - new_width = int(height * desired_ratio) - + new_height, new_width = 1024, 1024 # 创建一个新的画布,大小为添加 padding 后的尺寸,并设置为白色背景 pad_image = Image.new('RGBA', (new_width, new_height), (0, 0, 0, 0)) @@ -160,9 +152,9 @@ def pre_processing_image(image_url): top = (new_height - height) // 2 pad_image.paste(image, (left, top)) - # 将画布 resize 成宽度 500,长度 750 - resized_image = pad_image.resize((500, 750)) - image_size = (512, 768) + # 将画布 resize 成宽度 1024,长度 1024 + resized_image = pad_image.resize((1024, 1024)) + image_size = (1024, 1024) if resized_image.mode in ('RGBA', 'LA') or (resized_image.mode == 'P' and 'transparency' in resized_image.info): # 创建白色背景 @@ -171,15 +163,28 @@ def pre_processing_image(image_url): background.paste(resized_image, mask=resized_image.split()[3]) image = np.array(background) image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) - return image, image_size + return image, image_size, left, top + + +def post_processing_image(image, left, top): + width, height = image.size + # 计算裁剪后的宽度和坐标 + new_width = width - 2 * left + right = left + new_width + + # 进行裁剪操作 + cropped_image = image.crop((left, 0, right, height)) + + # 保存裁剪后的图像,将此处的 'cropped_image.jpg' 替换为你想要保存的文件名 + return cropped_image if __name__ == '__main__': rd = GenerateProductImageModel( tasks_id="123-89", # prompt="", - image_strength=0.9, - prompt=" the best quality, masterpiece. detailed, high-res, simple background, studio photography, extremely detailed, updo, detailed face, face, close-up, HDR, UHD, 8K realistic, Highly detailed, simple background, Studio lighting", + image_strength=0.65, + prompt="The best quality, masterpiece, real image. A handsome man wearing blouse, outwear, trousers, 8K realistic, HUD", image_url="aida-results/result_00097282-ebb2-11ee-a822-b48351119060.png", product_type="overall" ) diff --git a/app/service/utils/new_oss_client.py b/app/service/utils/new_oss_client.py index 8dbf2fc..f402a14 100644 --- a/app/service/utils/new_oss_client.py +++ b/app/service/utils/new_oss_client.py @@ -82,7 +82,7 @@ if __name__ == '__main__': # url = "aida-users/89/sketchboard/female/Dress/e6724ab7-8d3f-4677-abe0-c3e42ab7af85.jpeg" # url = "aida-users/87/print/956614a2-7e75-4fbe-9ed0-c1831e37a2c9-4-87.png" # url = "aida-users/89/single_logo/123-89.png" - url ="aida-results/result_91559b60-a61c-11ef-af8e-0242ac150002.png" + url ="aida-results/result_27915298-a656-11ef-b4f3-0242ac150002.png" # url = "aida-collection-element/12148/Sketchboard/95ea577b-305b-4a62-b30a-39c0dd3ddb3f.png" read_type = "2" From 7eb9b18f8f8cbb00505da5be453d28622bcf058c Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Tue, 26 Nov 2024 16:08:10 +0800 Subject: [PATCH 18/39] =?UTF-8?q?feat=20=20=201.design=20=E5=8F=96?= =?UTF-8?q?=E6=B6=88=E8=BF=9B=E5=BA=A6=E6=A8=A1=E5=BC=8F=E6=B5=8B=E8=AF=95?= =?UTF-8?q?=202.=E7=BB=93=E6=9E=9C=E4=BB=A5stream=E5=8F=91=E9=80=81?= =?UTF-8?q?=E5=88=B0java=203.=E6=96=B0=E5=A2=9E=E9=85=8D=E9=A5=B0=E7=B1=BB?= =?UTF-8?q?=20fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/api_design.py | 182 +++++++++++++++++- app/service/design_fast/design_generate.py | 147 +++++++++++++- app/service/design_fast/item.py | 23 ++- app/service/design_fast/pipeline/loading.py | 2 + app/service/design_fast/pipeline/scale.py | 4 +- app/service/design_fast/pipeline/split.py | 2 +- app/service/design_fast/utils/organize.py | 39 ++++ .../design_fast/utils/synthesis_item.py | 15 ++ app/service/utils/new_oss_client.py | 2 +- 9 files changed, 405 insertions(+), 11 deletions(-) diff --git a/app/api/api_design.py b/app/api/api_design.py index b0231a2..665d544 100644 --- a/app/api/api_design.py +++ b/app/api/api_design.py @@ -2,13 +2,13 @@ import json import logging import os -from fastapi import APIRouter, HTTPException, UploadFile, File, Form +from fastapi import APIRouter, HTTPException, UploadFile, File, Form, BackgroundTasks from app.schemas.design import DesignModel, DesignProgressModel, ModelProgressModel, DBGConfigModel from app.schemas.response_template import ResponseModel from app.service.design.model_process_service import model_transpose from app.service.design_batch.service import start_design_batch_generate -from app.service.design_fast.design_generate import design_generate +from app.service.design_fast.design_generate import design_generate, design_generate_v2 from app.service.design_fast.utils.redis_utils import Redis router = APIRouter() @@ -16,7 +16,7 @@ logger = logging.getLogger() @router.post("/design") -def design(request_data: DesignModel): +def design(request_data: DesignModel, background_tasks: BackgroundTasks): """ 创建一个具有以下参数的请求体: 示例参数: @@ -196,6 +196,182 @@ def design(request_data: DesignModel): return ResponseModel(data=data) +@router.post("/design_v2") +async def design_v2(request_data: DesignModel, background_tasks: BackgroundTasks): + """ + 创建一个具有以下参数的请求体: + 示例参数: + { + "objects": [ + { + "basic": { + "body_point_test": { + "waistband_right": [ + 200, + 241 + ], + "hand_point_right": [ + 223, + 297 + ], + "waistband_left": [ + 112, + 241 + ], + "hand_point_left": [ + 92, + 305 + ], + "shoulder_left": [ + 99, + 116 + ], + "shoulder_right": [ + 215, + 116 + ] + }, + "layer_order": true, + "scale_bag": 0.7, + "scale_earrings": 0.16, + "self_template": true, + "single_overall": "overall", + "switch_category": "" + }, + "items": [ + { + "businessId": 270372, + "color": "30 28 28", + "image_id": 69780, + "offset": [ + 0, + 0 + ], + "path": "aida-sys-image/images/female/trousers/0825000630.jpg", + "print": { + "element": { + "element_angle_list": [], + "element_path_list": [], + "element_scale_list": [], + "location": [] + }, + "overall": { + "location": [], + "print_angle_list": [], + "print_path_list": [], + "print_scale_list": [] + }, + "single": { + "location": [], + "print_angle_list": [], + "print_path_list": [], + "print_scale_list": [] + } + }, + "priority": 10, + "resize_scale": [ + 1.0, + 1.0 + ], + "type": "Trousers" + }, + { + "businessId": 270373, + "color": "30 28 28", + "image_id": 98243, + "offset": [ + 0, + 0 + ], + "path": "aida-sys-image/images/female/blouse/0902003811.jpg", + "print": { + "element": { + "element_angle_list": [], + "element_path_list": [], + "element_scale_list": [], + "location": [] + }, + "overall": { + "location": [], + "print_angle_list": [], + "print_path_list": [], + "print_scale_list": [] + }, + "single": { + "location": [], + "print_angle_list": [], + "print_path_list": [], + "print_scale_list": [] + } + }, + "priority": 11, + "resize_scale": [ + 1.0, + 1.0 + ], + "type": "Blouse" + }, + { + "businessId": 270374, + "color": "172 68 68", + "image_id": 98244, + "offset": [ + 0, + 0 + ], + "path": "aida-sys-image/images/female/outwear/0825000410.jpg", + "print": { + "element": { + "element_angle_list": [], + "element_path_list": [], + "element_scale_list": [], + "location": [] + }, + "overall": { + "location": [], + "print_angle_list": [], + "print_path_list": [], + "print_scale_list": [] + }, + "single": { + "location": [], + "print_angle_list": [], + "print_path_list": [], + "print_scale_list": [] + } + }, + "priority": 12, + "resize_scale": [ + 1.0, + 1.0 + ], + "transparent":{ + "mask_url":"test/transparent_test/transparent_mask.png", + "scale":0.1 + }, + "type": "Outwear" + }, + { + "body_path": "aida-sys-image/models/female/5bdfe7ca-64eb-44e4-b03d-8e517520c795.png", + "image_id": 96090, + "type": "Body" + } + ] + } + ], + "process_id": "83" + } + """ + try: + # 异步 + logger.info(f"generate_image request item is : @@@@@@:{json.dumps(request_data.dict())}") + background_tasks.add_task(design_generate_v2, request_data) + except Exception as e: + logger.warning(f"design Run Exception @@@@@@:{e}") + raise HTTPException(status_code=404, detail=str(e)) + return ResponseModel() + + @router.post('/get_progress') def get_progress(request_data: DesignProgressModel): """ diff --git a/app/service/design_fast/design_generate.py b/app/service/design_fast/design_generate.py index 582de4c..80edd96 100644 --- a/app/service/design_fast/design_generate.py +++ b/app/service/design_fast/design_generate.py @@ -5,8 +5,8 @@ import time from minio import Minio from app.core.config import * -from app.service.design_fast.item import BodyItem, TopItem, BottomItem -from app.service.design_fast.utils.organize import organize_body, organize_clothing +from app.service.design_fast.item import BodyItem, TopItem, BottomItem, AccessoriesItem +from app.service.design_fast.utils.organize import organize_body, organize_clothing, organize_accessories from app.service.design_fast.utils.progress import final_progress, update_progress from app.service.design_fast.utils.synthesis_item import synthesis, synthesis_single, update_base_size_priority from app.service.utils.decorator import RunTime @@ -26,9 +26,14 @@ def process_item(item, basic): elif item['type'].lower() in ['blouse', 'outwear', 'dress', 'tops']: top_server = TopItem(data=item, basic=basic, minio_client=minio_client) item_data = top_server.process() - else: + elif item['type'].lower() in ['skirt', 'trousers', 'bottoms']: bottom_server = BottomItem(data=item, basic=basic, minio_client=minio_client) item_data = bottom_server.process() + elif item['type'].lower() in ['accessories']: + bottom_server = AccessoriesItem(data=item, basic=basic, minio_client=minio_client) + item_data = bottom_server.process() + else: + raise NotImplementedError(f"Item type {item['type']} not implemented") return item_data @@ -38,6 +43,10 @@ def process_layer(item, layers): body_layer = organize_body(item) layers.append(body_layer) return item['body_image'].size + elif item['name'] == 'accessories': + front_layer, back_layer = organize_accessories(item) + layers.append(front_layer) + layers.append(back_layer) else: front_layer, back_layer = organize_clothing(item) layers.append(front_layer) @@ -57,7 +66,7 @@ def design_generate(request_data): def process_object(step, object): nonlocal active_threads basic = object['basic'] - items_response = {'layers': []} + items_response = {'layers': [], 'objectSign': object['objectSign'] if 'objectSign' in object.keys() else ""} if basic['single_overall'] == "overall": item_results = [] for item in object['items']: @@ -126,6 +135,136 @@ def design_generate(request_data): return object_response +@RunTime +def design_generate_v2(request_data): + objects_data = request_data.dict()['objects'] + # process_id = request_data.dict()['process_id'] + # object_response = {} + threads = [] + active_threads = 0 + lock = threading.Lock() + + # total = len(objects_data) + + def process_object(step, object): + nonlocal active_threads + basic = object['basic'] + items_response = {'layers': [], 'objectSign': object['objectSign'] if 'objectSign' in object.keys() else ""} + if basic['single_overall'] == "overall": + item_results = [] + for item in object['items']: + item_results.append(process_item(item, basic)) + layers = [] + body_size = None + for item in item_results: + body_size = process_layer(item, layers) + layers = sorted(layers, key=lambda s: s.get("priority", float('inf'))) + + layers, new_size = update_base_size_priority(layers, body_size) + + for lay in layers: + items_response['layers'].append({ + 'image_category': "body" if lay['name'] == 'mannequin' else lay['name'], + 'position': lay['position'], + 'priority': lay.get("priority", None), + 'resize_scale': lay['resize_scale'] if "resize_scale" in lay.keys() else None, + 'image_size': lay['image'] if lay['image'] is None else lay['image'].size, + 'gradient_string': lay['gradient_string'] if 'gradient_string' in lay.keys() else "", + 'mask_url': lay['mask_url'], + 'image_url': lay['image_url'] if 'image_url' in lay.keys() else None, + 'pattern_image_url': lay['pattern_image_url'] if 'pattern_image_url' in lay.keys() else None, + # 'back_perspective_url': lay['back_perspective_url'] if 'back_perspective_url' in lay.keys() else None, + }) + items_response['synthesis_url'] = synthesis(layers, new_size, basic) + else: + item_result = process_item(object['items'][0], basic) + items_response['layers'].append({ + 'image_category': f"{item_result['name']}_front", + 'image_size': item_result['back_image'].size if item_result['back_image'] else None, + 'position': None, + 'priority': 0, + 'image_url': item_result['front_image_url'], + 'mask_url': item_result['mask_url'], + "gradient_string": item_result['gradient_string'] if 'gradient_string' in item_result.keys() else "", + 'pattern_image_url': item_result['pattern_image_url'] if 'pattern_image_url' in item_result.keys() else None, + }) + items_response['layers'].append({ + 'image_category': f"{item_result['name']}_back", + 'image_size': item_result['front_image'].size if item_result['front_image'] else None, + 'position': None, + 'priority': 0, + 'image_url': item_result['back_image_url'], + 'mask_url': item_result['mask_url'], + "gradient_string": item_result['gradient_string'] if 'gradient_string' in item_result.keys() else "", + 'pattern_image_url': item_result['pattern_image_url'] if 'pattern_image_url' in item_result.keys() else None, + }) + items_response['synthesis_url'] = synthesis_single(item_result['front_image'], item_result['back_image']) + + # 发送结果给java端 + url = "https://3998-117-143-125-51.ngrok-free.app/api/third/party/receiveDesignResults" + headers = { + 'Accept': "*/*", + 'Accept-Encoding': "gzip, deflate, br", + 'User-Agent': "PostmanRuntime-ApipostRuntime/1.1.0", + 'Connection': "keep-alive", + 'Content-Type': "application/json" + } + response = post_request(url, json_data=items_response, headers=headers) + if response: + # 打印结果 + logger.info(response.text) + logger.info(items_response) + + # update_progress(process_id, total) + + # with lock: + # object_response[step] = items_response + # active_threads -= 1 + + for step, object in enumerate(objects_data): + t = threading.Thread(target=process_object, args=(step, object)) + threads.append(t) + t.start() + # with lock: + # active_threads += 1 + + # for t in threads: + # t.join() + # final_progress(process_id) + # return object_response + + +import requests + + +def post_request(url, data=None, json_data=None, headers=None, auth=None, timeout=5): + """ + 发送POST请求的封装函数 + + :param url: 接口的URL地址 + :param data: 要发送的数据(字典形式,用于表单数据等,会自动编码) + :param json_data: 要发送的JSON数据(字典形式,会自动转换为JSON字符串) + :param headers: 请求头字典 + :param auth: 认证信息(如 ('username', 'password') 形式用于基本认证) + :param timeout: 超时时间,单位为秒 + :return: 返回接口的响应对象 + """ + try: + response = requests.post( + url, + data=data, + json=json_data, + headers=headers, + auth=auth, + timeout=timeout + ) + response.raise_for_status() # 如果请求失败,抛出异常 + return response + except requests.RequestException as e: + print(f"POST请求出错: {e}") + return None + + if __name__ == '__main__': object_data = { "objects": [ diff --git a/app/service/design_fast/item.py b/app/service/design_fast/item.py index f7af700..ec18b17 100644 --- a/app/service/design_fast/item.py +++ b/app/service/design_fast/item.py @@ -1,4 +1,4 @@ -from app.service.design_fast.pipeline import LoadImage, KeyPoint, Segmentation, Color, PrintPainting, Scaling, Split, LoadBodyImage, ContourDetection, BackPerspective +from app.service.design_fast.pipeline import LoadImage, KeyPoint, Segmentation, Color, PrintPainting, Scaling, Split, LoadBodyImage, ContourDetection class BaseItem: @@ -9,6 +9,27 @@ class BaseItem: self.result.update(basic) +class AccessoriesItem(BaseItem): + def __init__(self, data, basic, minio_client): + super().__init__(data, basic) + self.Accessories_pipeline = [ + LoadImage(minio_client), + # KeyPoint(), + ContourDetection(), + # Segmentation(minio_client), + # BackPerspective(minio_client), + Color(minio_client), + PrintPainting(minio_client), + Scaling(), + Split(minio_client) + ] + + def process(self): + for item in self.Accessories_pipeline: + self.result = item(self.result) + return self.result + + class TopItem(BaseItem): def __init__(self, data, basic, minio_client): super().__init__(data, basic) diff --git a/app/service/design_fast/pipeline/loading.py b/app/service/design_fast/pipeline/loading.py index 0ce0dfa..5a55d9d 100644 --- a/app/service/design_fast/pipeline/loading.py +++ b/app/service/design_fast/pipeline/loading.py @@ -74,6 +74,8 @@ class LoadImage: keypoint = 'head_point' elif name == 'earring': keypoint = 'ear_point' + elif name == 'accessories': + keypoint = "accessories" else: raise KeyError(f"{name} does not belong to item category list: blouse, outwear, dress, trousers, skirt, " f"bag, shoes, hairstyle, earring.") diff --git a/app/service/design_fast/pipeline/scale.py b/app/service/design_fast/pipeline/scale.py index 732fcd8..c901aa7 100644 --- a/app/service/design_fast/pipeline/scale.py +++ b/app/service/design_fast/pipeline/scale.py @@ -18,7 +18,7 @@ class Scaling: - int(result['body_point_test'][result['keypoint'] + '_right'][0])) ** 2 + 1 ) - + if distance_clo == 0: result['scale'] = 1 else: @@ -46,4 +46,6 @@ class Scaling: result['scale'] = result['scale_bag'] elif result['keypoint'] == 'ear_point': result['scale'] = result['scale_earrings'] + else: + result['scale'] = 1 return result diff --git a/app/service/design_fast/pipeline/split.py b/app/service/design_fast/pipeline/split.py index 2f67027..344c5c5 100644 --- a/app/service/design_fast/pipeline/split.py +++ b/app/service/design_fast/pipeline/split.py @@ -21,7 +21,7 @@ class Split(object): def __call__(self, result): try: - if result['name'] in ('outwear', 'dress', 'blouse', 'skirt', 'trousers', 'tops', 'bottoms'): + if result['name'] in ('outwear', 'dress', 'blouse', 'skirt', 'trousers', 'tops', 'bottoms','accessories'): front_mask = result['front_mask'] back_mask = result['back_mask'] rgba_image = rgb_to_rgba(result['final_image'], front_mask + back_mask) diff --git a/app/service/design_fast/utils/organize.py b/app/service/design_fast/utils/organize.py index 92be044..33edc4f 100644 --- a/app/service/design_fast/utils/organize.py +++ b/app/service/design_fast/utils/organize.py @@ -55,6 +55,45 @@ def organize_clothing(layer): return front_layer, back_layer +def organize_accessories(layer): + # 起始坐标 + start_point = (0, 0) + # 前片数据 + front_layer = dict(priority=layer['priority'] if layer.get("layer_order", False) else PRIORITY_DICT.get(f'{layer["name"].lower()}_front', None), + name=f'{layer["name"].lower()}_front', + image=layer["front_image"], + # mask_image=layer['front_mask_image'], + image_url=layer['front_image_url'], + mask_url=layer['mask_url'], + sacle=layer['scale'], + clothes_keypoint=(0, 0), + position=start_point, + resize_scale=layer["resize_scale"], + mask=cv2.resize(layer['mask'], layer["front_image"].size), + gradient_string=layer['gradient_string'] if 'gradient_string' in layer.keys() else "", + pattern_image_url=layer['pattern_image_url'], + pattern_image=layer['pattern_image'], + # back_perspective_url=layer['back_perspective_url'] if 'back_perspective_url' in layer.keys() else "" + ) + # 后片数据 + back_layer = dict(priority=-layer.get("priority", 0) if layer.get("layer_order", False) else PRIORITY_DICT.get(f'{layer["name"].lower()}_back', None), + name=f'{layer["name"].lower()}_back', + image=layer["back_image"], + # mask_image=layer['back_mask_image'], + image_url=layer['back_image_url'], + mask_url=layer['mask_url'], + sacle=layer['scale'], + clothes_keypoint=(0, 0), + position=start_point, + resize_scale=layer["resize_scale"], + mask=cv2.resize(layer['mask'], layer["front_image"].size), + gradient_string=layer['gradient_string'] if 'gradient_string' in layer.keys() else "", + pattern_image_url=layer['pattern_image_url'], + # back_perspective_url=layer['back_perspective_url'] if 'back_perspective_url' in layer.keys() else "" + ) + return front_layer, back_layer + + def calculate_start_point(keypoint_type, scale, clothes_point, body_point, offset, resize_scale): """ Align left diff --git a/app/service/design_fast/utils/synthesis_item.py b/app/service/design_fast/utils/synthesis_item.py index f5d505f..d7711f3 100644 --- a/app/service/design_fast/utils/synthesis_item.py +++ b/app/service/design_fast/utils/synthesis_item.py @@ -79,9 +79,11 @@ def synthesis(data, size, basic_info): _, binary_body_mask = cv2.threshold(body_mask, 127, 255, cv2.THRESH_BINARY) top_outer_mask = np.array(binary_body_mask) bottom_outer_mask = np.array(binary_body_mask) + accessories_outer_mask = np.array(binary_body_mask) top = True bottom = True + accessories = True i = len(data) while i: i -= 1 @@ -109,10 +111,23 @@ def synthesis(data, size, basic_info): background = np.zeros_like(top_outer_mask) background[all_y_start:all_y_end, all_x_start:all_x_end] = sketch_mask[mask_y_start:mask_y_end, mask_x_start:mask_x_end] bottom_outer_mask = background + bottom_outer_mask + elif accessories and data[i]['name'] in ['accessories_front']: + mask_shape = data[i]['mask'].shape + y_offset, x_offset = data[i]['adaptive_position'] + # 初始化叠加区域的起始和结束位置 + all_y_start, all_y_end, mask_y_start, mask_y_end = positioning(all_mask_shape=all_mask_shape[0], mask_shape=mask_shape[0], offset=y_offset) + all_x_start, all_x_end, mask_x_start, mask_x_end = positioning(all_mask_shape=all_mask_shape[1], mask_shape=mask_shape[1], offset=x_offset) + # 将叠加区域赋值为相应的像素值 + _, sketch_mask = cv2.threshold(data[i]['mask'], 127, 255, cv2.THRESH_BINARY) + background = np.zeros_like(top_outer_mask) + background[all_y_start:all_y_end, all_x_start:all_x_end] = sketch_mask[mask_y_start:mask_y_end, mask_x_start:mask_x_end] + accessories_outer_mask = background + accessories_outer_mask + pass elif bottom is False and top is False: break all_mask = cv2.bitwise_or(top_outer_mask, bottom_outer_mask) + all_mask = cv2.bitwise_or(all_mask, accessories_outer_mask) for layer in data: if layer['image'] is not None: diff --git a/app/service/utils/new_oss_client.py b/app/service/utils/new_oss_client.py index f402a14..a338adb 100644 --- a/app/service/utils/new_oss_client.py +++ b/app/service/utils/new_oss_client.py @@ -82,7 +82,7 @@ if __name__ == '__main__': # url = "aida-users/89/sketchboard/female/Dress/e6724ab7-8d3f-4677-abe0-c3e42ab7af85.jpeg" # url = "aida-users/87/print/956614a2-7e75-4fbe-9ed0-c1831e37a2c9-4-87.png" # url = "aida-users/89/single_logo/123-89.png" - url ="aida-results/result_27915298-a656-11ef-b4f3-0242ac150002.png" + url ="aida-results/result_461110a5-aba2-11ef-83e7-0826ae3ad6b3.png" # url = "aida-collection-element/12148/Sketchboard/95ea577b-305b-4a62-b30a-39c0dd3ddb3f.png" read_type = "2" From ba39cf4cbd5a39dd416124edd604ab51795239f8 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Wed, 27 Nov 2024 11:15:48 +0800 Subject: [PATCH 19/39] =?UTF-8?q?feat=20=20=201.design=20=E5=8F=96?= =?UTF-8?q?=E6=B6=88=E8=BF=9B=E5=BA=A6=E6=A8=A1=E5=BC=8F=E6=B5=8B=E8=AF=95?= =?UTF-8?q?=202.=E7=BB=93=E6=9E=9C=E4=BB=A5stream=E5=8F=91=E9=80=81?= =?UTF-8?q?=E5=88=B0java=203.=E6=96=B0=E5=A2=9E=E9=85=8D=E9=A5=B0=E7=B1=BB?= =?UTF-8?q?=20fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/service/design_fast/pipeline/scale.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/app/service/design_fast/pipeline/scale.py b/app/service/design_fast/pipeline/scale.py index c901aa7..d1c7a36 100644 --- a/app/service/design_fast/pipeline/scale.py +++ b/app/service/design_fast/pipeline/scale.py @@ -46,6 +46,16 @@ class Scaling: result['scale'] = result['scale_bag'] elif result['keypoint'] == 'ear_point': result['scale'] = result['scale_earrings'] + elif result['keypoint'] == 'accessories': + # 由于没有识别配饰keypoint的模型 所以统一将配饰的两个关键点设定为 (0,0) (0,img.width) + # 模特的关键点设定为(0,0) (0,320/2) 距离比例简写为 160 / img.width + distance_clo = result['img_shape'][1] + distance_bdy = 320 / 2 + + if distance_clo == 0: + result['scale'] = 1 + else: + result['scale'] = distance_bdy / distance_clo else: result['scale'] = 1 return result From 80a2a8c0760bf0fcdf94a2ed8b50cffdb59a6cb7 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Wed, 27 Nov 2024 16:30:20 +0800 Subject: [PATCH 20/39] =?UTF-8?q?feat=20=20design=20stream=20=E6=A8=A1?= =?UTF-8?q?=E5=BC=8F=E6=96=B0=E5=A2=9ErequestId=E5=8F=82=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit fix --- app/service/design_fast/design_generate.py | 6 +++++- app/service/utils/new_oss_client.py | 2 +- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/app/service/design_fast/design_generate.py b/app/service/design_fast/design_generate.py index 80edd96..be015ac 100644 --- a/app/service/design_fast/design_generate.py +++ b/app/service/design_fast/design_generate.py @@ -149,7 +149,11 @@ def design_generate_v2(request_data): def process_object(step, object): nonlocal active_threads basic = object['basic'] - items_response = {'layers': [], 'objectSign': object['objectSign'] if 'objectSign' in object.keys() else ""} + items_response = { + 'layers': [], + 'objectSign': object['objectSign'] if 'objectSign' in object.keys() else "", + 'requestId': object['requestId'] if 'requestId' in object.keys() else "" + } if basic['single_overall'] == "overall": item_results = [] for item in object['items']: diff --git a/app/service/utils/new_oss_client.py b/app/service/utils/new_oss_client.py index a338adb..0ead375 100644 --- a/app/service/utils/new_oss_client.py +++ b/app/service/utils/new_oss_client.py @@ -82,7 +82,7 @@ if __name__ == '__main__': # url = "aida-users/89/sketchboard/female/Dress/e6724ab7-8d3f-4677-abe0-c3e42ab7af85.jpeg" # url = "aida-users/87/print/956614a2-7e75-4fbe-9ed0-c1831e37a2c9-4-87.png" # url = "aida-users/89/single_logo/123-89.png" - url ="aida-results/result_461110a5-aba2-11ef-83e7-0826ae3ad6b3.png" + url ="aida-results/result_68756122-ac6b-11ef-8bf8-0826ae3ad6b3.png" # url = "aida-collection-element/12148/Sketchboard/95ea577b-305b-4a62-b30a-39c0dd3ddb3f.png" read_type = "2" From 2b9ab7fe7a8a88a550358ad3e1f2ac635192f7f9 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Thu, 28 Nov 2024 11:32:43 +0800 Subject: [PATCH 21/39] =?UTF-8?q?feat=20=20generate=20product=20img=20?= =?UTF-8?q?=E9=83=A8=E7=BD=B2=20fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/core/config.py | 4 +--- .../service_generate_product_image.py | 24 ++++++------------- app/service/utils/new_oss_client.py | 2 +- 3 files changed, 9 insertions(+), 21 deletions(-) diff --git a/app/core/config.py b/app/core/config.py index 48af4a1..ff6f0a1 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -117,9 +117,7 @@ GEN_SINGLE_LOGO_RABBITMQ_QUEUES = os.getenv("GEN_SINGLE_LOGO_RABBITMQ_QUEUES", f # Generate Product service config GPI_RABBITMQ_QUEUES = os.getenv("GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES", f"ToProductImage{RABBITMQ_ENV}") -GPI_MODEL_NAME_OVERALL = 'stable_diffusion_xl_cnet' -GPI_MODEL_NAME_SINGLE = 'stable_diffusion_xl_cnet' - +GPI_MODEL_NAME_OVERALL = 'sdxl_ensemble_all' GPI_MODEL_URL = '10.1.1.243:10051' # Generate Single Logo service config diff --git a/app/service/generate_image/service_generate_product_image.py b/app/service/generate_image/service_generate_product_image.py index da4bb4b..16b814b 100644 --- a/app/service/generate_image/service_generate_product_image.py +++ b/app/service/generate_image/service_generate_product_image.py @@ -55,10 +55,7 @@ class GenerateProductImage: self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data)) else: # pil图像转成numpy数组 - if self.product_type == "single": - image = result.as_numpy("generated_cnet_image") - else: - image = result.as_numpy("generated_cnet_image") + image = result.as_numpy("generated_inpaint_image") image_result = Image.fromarray(np.squeeze(image.astype(np.uint8))).resize(self.image_size) cropped_image = post_processing_image(image_result, self.left, self.top) image_url = upload_SDXL_image(cropped_image, user_id=self.user_id, category=f"{self.category}", file_name=f"{self.tasks_id}.png") @@ -78,14 +75,9 @@ class GenerateProductImage: self.image = cv2.resize(self.image, (1024, 1024)) images = [self.image.astype(np.uint8)] * self.batch_size - if self.product_type == "single": - text_obj = np.array(prompts, dtype="object").reshape(-1, 1) - image_obj = np.array(images, dtype=np.uint8).reshape((-1, 768, 512, 3)) - image_strength_obj = np.array(self.image_strength, dtype=np.float32).reshape(-1, 1) - else: - text_obj = np.array(prompts, dtype="object").reshape((-1, 1)) - image_obj = np.array(images, dtype=np.uint8).reshape((-1, 1024, 1024, 3)) - image_strength_obj = np.array(self.image_strength, dtype=np.float32).reshape((-1, 1)) + text_obj = np.array(prompts, dtype="object").reshape((1)) + image_obj = np.array(images, dtype=np.uint8).reshape((1024, 1024, 3)) + image_strength_obj = np.array(self.image_strength, dtype=np.float32).reshape((1)) # 假设 prompts、images 和 self.image_strength 已经定义 @@ -95,13 +87,11 @@ class GenerateProductImage: input_text.set_data_from_numpy(text_obj) input_image.set_data_from_numpy(image_obj) - inputs = [input_text, input_image, input_image_strength] input_image_strength.set_data_from_numpy(image_strength_obj) - if self.product_type == "single": - ctx = self.grpc_client.async_infer(model_name=GPI_MODEL_NAME_SINGLE, inputs=inputs, callback=self.callback) - else: - ctx = self.grpc_client.async_infer(model_name=GPI_MODEL_NAME_OVERALL, inputs=inputs, callback=self.callback) + inputs = [input_text, input_image, input_image_strength] + + ctx = self.grpc_client.async_infer(model_name=GPI_MODEL_NAME_OVERALL, inputs=inputs, callback=self.callback) time_out = 600 while time_out > 0: diff --git a/app/service/utils/new_oss_client.py b/app/service/utils/new_oss_client.py index 0ead375..5067d15 100644 --- a/app/service/utils/new_oss_client.py +++ b/app/service/utils/new_oss_client.py @@ -82,7 +82,7 @@ if __name__ == '__main__': # url = "aida-users/89/sketchboard/female/Dress/e6724ab7-8d3f-4677-abe0-c3e42ab7af85.jpeg" # url = "aida-users/87/print/956614a2-7e75-4fbe-9ed0-c1831e37a2c9-4-87.png" # url = "aida-users/89/single_logo/123-89.png" - url ="aida-results/result_68756122-ac6b-11ef-8bf8-0826ae3ad6b3.png" + url ="aida-results/result_00097282-ebb2-11ee-a822-b48351119060.png" # url = "aida-collection-element/12148/Sketchboard/95ea577b-305b-4a62-b30a-39c0dd3ddb3f.png" read_type = "2" From e1f19f62b37ad7542c1b055dcf7960439761e0a7 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Sun, 1 Dec 2024 14:24:48 +0800 Subject: [PATCH 22/39] =?UTF-8?q?feat=20=20generate=20img=20fast=20version?= =?UTF-8?q?=20=E9=83=A8=E7=BD=B2=20fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/core/config.py | 1 + app/schemas/generate_image.py | 1 + .../generate_image/service_generate_image.py | 9 ++++-- .../service_generate_product_image.py | 32 +++++++++++++------ app/service/utils/new_oss_client.py | 2 +- 5 files changed, 32 insertions(+), 13 deletions(-) diff --git a/app/core/config.py b/app/core/config.py index ff6f0a1..97e014d 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -101,6 +101,7 @@ SR_RABBITMQ_QUEUES = os.getenv("SR_RABBITMQ_QUEUES", f"SuperResolution{RABBITMQ_ # GenerateImage service config GI_MODEL_NAME = 'stable_diffusion_xl' +FAST_GI_MODEL_URL = '10.1.1.243:10011' GI_MODEL_URL = '10.1.1.240:10041' GI_MINIO_BUCKET = "aida-users" GI_RABBITMQ_QUEUES = os.getenv("GI_RABBITMQ_QUEUES", f"GenerateImage{RABBITMQ_ENV}") diff --git a/app/schemas/generate_image.py b/app/schemas/generate_image.py index 3dd7cf8..11e295f 100644 --- a/app/schemas/generate_image.py +++ b/app/schemas/generate_image.py @@ -8,6 +8,7 @@ class GenerateImageModel(BaseModel): mode: str category: str gender: str + version: str class GenerateSingleLogoImageModel(BaseModel): diff --git a/app/service/generate_image/service_generate_image.py b/app/service/generate_image/service_generate_image.py index dac211c..7d2937b 100644 --- a/app/service/generate_image/service_generate_image.py +++ b/app/service/generate_image/service_generate_image.py @@ -35,7 +35,11 @@ class GenerateImage: # self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) # self.channel = self.connection.channel() # self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) - self.grpc_client = grpcclient.InferenceServerClient(url=GI_MODEL_URL) + if request_data.version == "fast": + self.grpc_client = grpcclient.InferenceServerClient(url=FAST_GI_MODEL_URL) + else: + self.grpc_client = grpcclient.InferenceServerClient(url=GI_MODEL_URL) + self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) if request_data.mode == "img2img": # cv2 读图片是BGR PIL读图片是RGB @@ -185,7 +189,8 @@ if __name__ == '__main__': image_url="aida-collection-element/87/Printboard/842c09cf-7297-42d9-9e6e-9c17d4a13cb5.jpg", mode='txt2img', category="test", - gender="male" + gender="male", + version="fast" ) server = GenerateImage(rd) print(server.get_result()) diff --git a/app/service/generate_image/service_generate_product_image.py b/app/service/generate_image/service_generate_product_image.py index 16b814b..ebffb05 100644 --- a/app/service/generate_image/service_generate_product_image.py +++ b/app/service/generate_image/service_generate_product_image.py @@ -55,7 +55,10 @@ class GenerateProductImage: self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data)) else: # pil图像转成numpy数组 - image = result.as_numpy("generated_inpaint_image") + if self.product_type == "single": + image = result.as_numpy("generated_inpaint_image") + else: + image = result.as_numpy("generated_cnet_image") image_result = Image.fromarray(np.squeeze(image.astype(np.uint8))).resize(self.image_size) cropped_image = post_processing_image(image_result, self.left, self.top) image_url = upload_SDXL_image(cropped_image, user_id=self.user_id, category=f"{self.category}", file_name=f"{self.tasks_id}.png") @@ -75,9 +78,14 @@ class GenerateProductImage: self.image = cv2.resize(self.image, (1024, 1024)) images = [self.image.astype(np.uint8)] * self.batch_size - text_obj = np.array(prompts, dtype="object").reshape((1)) - image_obj = np.array(images, dtype=np.uint8).reshape((1024, 1024, 3)) - image_strength_obj = np.array(self.image_strength, dtype=np.float32).reshape((1)) + if self.product_type == "single": + text_obj = np.array(prompts, dtype="object").reshape(-1, 1) + image_obj = np.array(images, dtype=np.uint8).reshape((-1, 1024, 1024, 3)) + image_strength_obj = np.array(self.image_strength, dtype=np.float32).reshape(-1, 1) + else: + text_obj = np.array(prompts, dtype="object").reshape((1)) + image_obj = np.array(images, dtype=np.uint8).reshape((1024, 1024, 3)) + image_strength_obj = np.array(self.image_strength, dtype=np.float32).reshape((1)) # 假设 prompts、images 和 self.image_strength 已经定义 @@ -91,7 +99,11 @@ class GenerateProductImage: inputs = [input_text, input_image, input_image_strength] - ctx = self.grpc_client.async_infer(model_name=GPI_MODEL_NAME_OVERALL, inputs=inputs, callback=self.callback) + + if self.product_type == "single": + ctx = self.grpc_client.async_infer(model_name="stable_diffusion_xl_cnet_inpaint", inputs=inputs, callback=self.callback) + else: + ctx = self.grpc_client.async_infer(model_name=GPI_MODEL_NAME_OVERALL, inputs=inputs, callback=self.callback) time_out = 600 while time_out > 0: @@ -173,10 +185,10 @@ if __name__ == '__main__': rd = GenerateProductImageModel( tasks_id="123-89", # prompt="", - image_strength=0.65, - prompt="The best quality, masterpiece, real image. A handsome man wearing blouse, outwear, trousers, 8K realistic, HUD", - image_url="aida-results/result_00097282-ebb2-11ee-a822-b48351119060.png", - product_type="overall" + image_strength=0.7, + prompt="The best quality, masterpiece,outwear, 8K realistic, HUD", + image_url="aida-results/result_836dce70-ad59-11ef-86ab-0242ac130002.png", + product_type="single" ) server = GenerateProductImage(rd) - print(server.get_result()) + print(server.get_result()) \ No newline at end of file diff --git a/app/service/utils/new_oss_client.py b/app/service/utils/new_oss_client.py index 5067d15..6dd22bd 100644 --- a/app/service/utils/new_oss_client.py +++ b/app/service/utils/new_oss_client.py @@ -82,7 +82,7 @@ if __name__ == '__main__': # url = "aida-users/89/sketchboard/female/Dress/e6724ab7-8d3f-4677-abe0-c3e42ab7af85.jpeg" # url = "aida-users/87/print/956614a2-7e75-4fbe-9ed0-c1831e37a2c9-4-87.png" # url = "aida-users/89/single_logo/123-89.png" - url ="aida-results/result_00097282-ebb2-11ee-a822-b48351119060.png" + url ="aida-users/89/test/123-89.png" # url = "aida-collection-element/12148/Sketchboard/95ea577b-305b-4a62-b30a-39c0dd3ddb3f.png" read_type = "2" From 5920c78a6dfc95c7b25637d504f01b0643490b05 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Sun, 1 Dec 2024 15:30:32 +0800 Subject: [PATCH 23/39] =?UTF-8?q?feat=20=20flux=20=E9=83=A8=E7=BD=B2=20fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/core/config.py | 7 +++++-- app/service/generate_image/service_generate_image.py | 8 ++++---- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/app/core/config.py b/app/core/config.py index 97e014d..dd56258 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -100,9 +100,12 @@ SR_MINIO_BUCKET = "aida-users" SR_RABBITMQ_QUEUES = os.getenv("SR_RABBITMQ_QUEUES", f"SuperResolution{RABBITMQ_ENV}") # GenerateImage service config -GI_MODEL_NAME = 'stable_diffusion_xl' FAST_GI_MODEL_URL = '10.1.1.243:10011' -GI_MODEL_URL = '10.1.1.240:10041' +FAST_GI_MODEL_NAME = 'stable_diffusion_xl' + +GI_MODEL_URL = '10.1.1.240:10061' +GI_MODEL_NAME = 'flux' + GI_MINIO_BUCKET = "aida-users" GI_RABBITMQ_QUEUES = os.getenv("GI_RABBITMQ_QUEUES", f"GenerateImage{RABBITMQ_ENV}") GI_SYS_IMAGE_URL = "aida-sys-image/generate_image/white_image.jpg" diff --git a/app/service/generate_image/service_generate_image.py b/app/service/generate_image/service_generate_image.py index 7d2937b..8cf7cf9 100644 --- a/app/service/generate_image/service_generate_image.py +++ b/app/service/generate_image/service_generate_image.py @@ -138,8 +138,8 @@ class GenerateImage: image_obj = np.array(images, dtype=np.float16).reshape((-1, 1024, 1024, 3)) input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype)) - input_image = grpcclient.InferInput("input_image", image_obj.shape, "FP16") - input_mode = grpcclient.InferInput("mode", mode_obj.shape, np_to_triton_dtype(text_obj.dtype)) + input_image = grpcclient.InferInput("input_image", image_obj.shape, np_to_triton_dtype(image_obj.dtype)) + input_mode = grpcclient.InferInput("mode", mode_obj.shape, np_to_triton_dtype(mode_obj.dtype)) input_text.set_data_from_numpy(text_obj) input_image.set_data_from_numpy(image_obj) @@ -185,12 +185,12 @@ def infer_cancel(tasks_id): if __name__ == '__main__': rd = GenerateImageModel( tasks_id="123-89", - prompt='skeleton sitting by the side of a river looking soulful, concert poster, 4k, artistic', + prompt='a fabric print, flower, yellow, 4k, hud', image_url="aida-collection-element/87/Printboard/842c09cf-7297-42d9-9e6e-9c17d4a13cb5.jpg", mode='txt2img', category="test", gender="male", - version="fast" + version="high" ) server = GenerateImage(rd) print(server.get_result()) From 216ca4587d7d65b164fb4fbc0338958f82c61a8f Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Sun, 1 Dec 2024 15:35:09 +0800 Subject: [PATCH 24/39] =?UTF-8?q?feat=20=20flux=20=E9=83=A8=E7=BD=B2=20fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/service/generate_image/service_generate_image.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/app/service/generate_image/service_generate_image.py b/app/service/generate_image/service_generate_image.py index 8cf7cf9..d34db5e 100644 --- a/app/service/generate_image/service_generate_image.py +++ b/app/service/generate_image/service_generate_image.py @@ -35,6 +35,7 @@ class GenerateImage: # self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) # self.channel = self.connection.channel() # self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) + self.version = request_data.version if request_data.version == "fast": self.grpc_client = grpcclient.InferenceServerClient(url=FAST_GI_MODEL_URL) else: @@ -146,7 +147,10 @@ class GenerateImage: input_mode.set_data_from_numpy(mode_obj) inputs = [input_text, input_image, input_mode] - ctx = self.grpc_client.async_infer(model_name=GI_MODEL_NAME, inputs=inputs, callback=self.callback) + if self.version == "fast": + ctx = self.grpc_client.async_infer(model_name=GI_MODEL_NAME, inputs=inputs, callback=self.callback) + else: + ctx = self.grpc_client.async_infer(model_name=FAST_GI_MODEL_NAME, inputs=inputs, callback=self.callback) time_out = 600 generate_data = None while time_out > 0: From a9003cc81f24cf519b4efea1a538115ff601541e Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Sun, 1 Dec 2024 15:36:24 +0800 Subject: [PATCH 25/39] =?UTF-8?q?feat=20=20flux=20=E9=83=A8=E7=BD=B2=20fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/service/generate_image/service_generate_image.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/app/service/generate_image/service_generate_image.py b/app/service/generate_image/service_generate_image.py index d34db5e..b49433e 100644 --- a/app/service/generate_image/service_generate_image.py +++ b/app/service/generate_image/service_generate_image.py @@ -148,9 +148,10 @@ class GenerateImage: inputs = [input_text, input_image, input_mode] if self.version == "fast": - ctx = self.grpc_client.async_infer(model_name=GI_MODEL_NAME, inputs=inputs, callback=self.callback) - else: ctx = self.grpc_client.async_infer(model_name=FAST_GI_MODEL_NAME, inputs=inputs, callback=self.callback) + else: + ctx = self.grpc_client.async_infer(model_name=GI_MODEL_NAME, inputs=inputs, callback=self.callback) + time_out = 600 generate_data = None while time_out > 0: From b52c87bd15cc5d4c53f8837ee4f4790ea45499b4 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Sun, 1 Dec 2024 17:39:22 +0800 Subject: [PATCH 26/39] =?UTF-8?q?feat=20=20product=20=E4=BF=AE=E5=A4=8D=20?= =?UTF-8?q?fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../generate_image/service_generate_product_image.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/app/service/generate_image/service_generate_product_image.py b/app/service/generate_image/service_generate_product_image.py index ebffb05..60b19cb 100644 --- a/app/service/generate_image/service_generate_product_image.py +++ b/app/service/generate_image/service_generate_product_image.py @@ -55,10 +55,7 @@ class GenerateProductImage: self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data)) else: # pil图像转成numpy数组 - if self.product_type == "single": - image = result.as_numpy("generated_inpaint_image") - else: - image = result.as_numpy("generated_cnet_image") + image = result.as_numpy("generated_inpaint_image") image_result = Image.fromarray(np.squeeze(image.astype(np.uint8))).resize(self.image_size) cropped_image = post_processing_image(image_result, self.left, self.top) image_url = upload_SDXL_image(cropped_image, user_id=self.user_id, category=f"{self.category}", file_name=f"{self.tasks_id}.png") @@ -99,7 +96,6 @@ class GenerateProductImage: inputs = [input_text, input_image, input_image_strength] - if self.product_type == "single": ctx = self.grpc_client.async_infer(model_name="stable_diffusion_xl_cnet_inpaint", inputs=inputs, callback=self.callback) else: @@ -187,8 +183,8 @@ if __name__ == '__main__': # prompt="", image_strength=0.7, prompt="The best quality, masterpiece,outwear, 8K realistic, HUD", - image_url="aida-results/result_836dce70-ad59-11ef-86ab-0242ac130002.png", - product_type="single" + image_url="aida-results/result_53381ada-ac64-11ef-ae9d-0242ac150002.png", + product_type="overall" ) server = GenerateProductImage(rd) - print(server.get_result()) \ No newline at end of file + print(server.get_result()) From f362c795e4e17f76ed0b22673a2f3abe6a9c0fe8 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Sun, 1 Dec 2024 17:54:42 +0800 Subject: [PATCH 27/39] =?UTF-8?q?feat=20=20design=20triton=20=E6=9B=B4?= =?UTF-8?q?=E6=8D=A2=E4=B8=BAA6000=20fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/core/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/core/config.py b/app/core/config.py index dd56258..d369ff2 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -140,7 +140,7 @@ SEGMENTATION = { # ollama config OLLAMA_URL = "http://10.1.1.243:11434/api/embeddings" # DESIGN config -DESIGN_MODEL_URL = '10.1.1.243:10000' +DESIGN_MODEL_URL = '10.1.1.240:10000' AIDA_CLOTHING = "aida-clothing" KEYPOINT_RESULT_TABLE_FIELD_SET = ('neckline_left', 'neckline_right', 'shoulder_left', 'shoulder_right', 'armpit_left', 'armpit_right', 'cuff_left_in', 'cuff_left_out', 'cuff_right_in', 'cuff_right_out', 'waistband_left', 'waistband_right') From 96a444ca831367c42a4b473aff52151ba90c8601 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Sun, 1 Dec 2024 18:46:53 +0800 Subject: [PATCH 28/39] =?UTF-8?q?feat=20=20product=20=E5=90=8E=E5=A4=84?= =?UTF-8?q?=E7=90=86=E5=9B=BE=E7=89=87size=E6=94=B9=E4=B8=BA320*700=20fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../service_generate_product_image.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/app/service/generate_image/service_generate_product_image.py b/app/service/generate_image/service_generate_product_image.py index 60b19cb..606af59 100644 --- a/app/service/generate_image/service_generate_product_image.py +++ b/app/service/generate_image/service_generate_product_image.py @@ -165,15 +165,15 @@ def pre_processing_image(image_url): def post_processing_image(image, left, top): - width, height = image.size - # 计算裁剪后的宽度和坐标 - new_width = width - 2 * left - right = left + new_width + resized_image = image.resize((int(image.width * (700 / image.height)), 700)) + # 计算裁剪的坐标 + left = (resized_image.width - 320) // 2 + upper = 0 + right = left + 320 + lower = 700 - # 进行裁剪操作 - cropped_image = image.crop((left, 0, right, height)) - - # 保存裁剪后的图像,将此处的 'cropped_image.jpg' 替换为你想要保存的文件名 + # 进行裁剪 + cropped_image = resized_image.crop((left, upper, right, lower)) return cropped_image From b5b52f5fbe5409cb829860f807ece2c39a36f714 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Sun, 1 Dec 2024 18:51:53 +0800 Subject: [PATCH 29/39] =?UTF-8?q?feat=20=20product=20=E5=90=8E=E5=A4=84?= =?UTF-8?q?=E7=90=86=E5=9B=BE=E7=89=87size=E6=94=B9=E4=B8=BA320*700=20fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../generate_image/service_generate_product_image.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/app/service/generate_image/service_generate_product_image.py b/app/service/generate_image/service_generate_product_image.py index 606af59..1c20a13 100644 --- a/app/service/generate_image/service_generate_product_image.py +++ b/app/service/generate_image/service_generate_product_image.py @@ -165,12 +165,12 @@ def pre_processing_image(image_url): def post_processing_image(image, left, top): - resized_image = image.resize((int(image.width * (700 / image.height)), 700)) + resized_image = image.resize((int(image.width * (768 / image.height)), 768)) # 计算裁剪的坐标 - left = (resized_image.width - 320) // 2 + left = (resized_image.width - 512) // 2 upper = 0 - right = left + 320 - lower = 700 + right = left + 512 + lower = 768 # 进行裁剪 cropped_image = resized_image.crop((left, upper, right, lower)) From 6a0c581e5df5d0a63c766b6d2c7527a7bc014c9c Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Sun, 1 Dec 2024 19:03:55 +0800 Subject: [PATCH 30/39] =?UTF-8?q?feat=20=20=E5=8F=96=E6=B6=88flux=20?= =?UTF-8?q?=E7=9A=84=E6=B1=A1=E7=82=B9=E5=88=A4=E6=96=AD=20fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/service/generate_image/service_generate_image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/service/generate_image/service_generate_image.py b/app/service/generate_image/service_generate_image.py index b49433e..daad5c6 100644 --- a/app/service/generate_image/service_generate_image.py +++ b/app/service/generate_image/service_generate_image.py @@ -91,7 +91,7 @@ class GenerateImage: image = result.as_numpy("generated_image") image_result = cv2.cvtColor(np.squeeze(image.astype(np.uint8)), cv2.COLOR_RGB2BGR) is_smudge = True - if self.category == "sketch": + if self.category == "sketch" and self.version == "fast": # 色阶调整 cutoff = 1 levels_img = autoLevels(image_result, cutoff) From 87d38480888e585f2bc0e75d37dc2752c10a7aa4 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Sun, 1 Dec 2024 19:49:27 +0800 Subject: [PATCH 31/39] =?UTF-8?q?feat=20=20flux=20=E5=8F=96=E6=B6=88?= =?UTF-8?q?=E6=B1=A1=E7=82=B9=E6=A3=80=E6=B5=8B=20=E5=A2=9E=E5=8A=A0?= =?UTF-8?q?=E7=B1=BB=E5=88=AB=E5=88=A4=E6=96=AD=20fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../generate_image/service_generate_image.py | 41 +++++++++++-------- 1 file changed, 23 insertions(+), 18 deletions(-) diff --git a/app/service/generate_image/service_generate_image.py b/app/service/generate_image/service_generate_image.py index daad5c6..2ba7862 100644 --- a/app/service/generate_image/service_generate_image.py +++ b/app/service/generate_image/service_generate_image.py @@ -92,23 +92,28 @@ class GenerateImage: image_result = cv2.cvtColor(np.squeeze(image.astype(np.uint8)), cv2.COLOR_RGB2BGR) is_smudge = True if self.category == "sketch" and self.version == "fast": - # 色阶调整 - cutoff = 1 - levels_img = autoLevels(image_result, cutoff) - # 亮度调整 - luminance = luminance_adjust(0.3, levels_img) - # 去背景 - remove_bg_image = remove_background(luminance) - # 人脸检测 - # if face_detect_pic(remove_bg_image, self.user_id, self.category, self.tasks_id) > 0: - # is_smudge = False - # else: - # 污点/ - is_smudge, not_smudge_image = stain_detection(remove_bg_image, self.user_id, self.category, self.tasks_id) - # 类型识别 - category, scores, not_smudge_image = generate_category_recognition(image=remove_bg_image, gender=self.gender) - self.generate_data['category'] = str(category) - image_result = not_smudge_image + if self.version == "fast": + # 色阶调整 + cutoff = 1 + levels_img = autoLevels(image_result, cutoff) + # 亮度调整 + luminance = luminance_adjust(0.3, levels_img) + # 去背景 + remove_bg_image = remove_background(luminance) + # 人脸检测 + # if face_detect_pic(remove_bg_image, self.user_id, self.category, self.tasks_id) > 0: + # is_smudge = False + # else: + # 污点/ + is_smudge, not_smudge_image = stain_detection(remove_bg_image, self.user_id, self.category, self.tasks_id) + # 类型识别 + category, scores, not_smudge_image = generate_category_recognition(image=remove_bg_image, gender=self.gender) + self.generate_data['category'] = str(category) + image_result = not_smudge_image + else: + category, scores, not_smudge_image = generate_category_recognition(image=image_result, gender=self.gender) + self.generate_data['category'] = str(category) + image_result = not_smudge_image if is_smudge: # 无污点 # image_result = adjust_contrast(image_result) image_url = upload_png_sd(image_result, user_id=self.user_id, category=f"{self.category}", file_name=f"{self.tasks_id}.png") @@ -190,7 +195,7 @@ def infer_cancel(tasks_id): if __name__ == '__main__': rd = GenerateImageModel( tasks_id="123-89", - prompt='a fabric print, flower, yellow, 4k, hud', + prompt='a single item of sketch of Wabi-sabi, skirt, tiered, 4k, white background', image_url="aida-collection-element/87/Printboard/842c09cf-7297-42d9-9e6e-9c17d4a13cb5.jpg", mode='txt2img', category="test", From 0f17d8b6544288b3f8a3269ca75f96ea04b4e5bc Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Sun, 1 Dec 2024 19:52:33 +0800 Subject: [PATCH 32/39] =?UTF-8?q?feat=20=20flux=20=E5=8F=96=E6=B6=88?= =?UTF-8?q?=E6=B1=A1=E7=82=B9=E6=A3=80=E6=B5=8B=20=E5=A2=9E=E5=8A=A0?= =?UTF-8?q?=E7=B1=BB=E5=88=AB=E5=88=A4=E6=96=AD=20fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/service/generate_image/service_generate_image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/service/generate_image/service_generate_image.py b/app/service/generate_image/service_generate_image.py index 2ba7862..86912f8 100644 --- a/app/service/generate_image/service_generate_image.py +++ b/app/service/generate_image/service_generate_image.py @@ -91,7 +91,7 @@ class GenerateImage: image = result.as_numpy("generated_image") image_result = cv2.cvtColor(np.squeeze(image.astype(np.uint8)), cv2.COLOR_RGB2BGR) is_smudge = True - if self.category == "sketch" and self.version == "fast": + if self.category == "sketch": if self.version == "fast": # 色阶调整 cutoff = 1 From b7736553ffcc1a0a0a1c2c33d4dc426b041d261f Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Sun, 1 Dec 2024 20:17:39 +0800 Subject: [PATCH 33/39] =?UTF-8?q?feat=20design=20stream=20=E9=83=A8?= =?UTF-8?q?=E7=BD=B2=20fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/service/design_fast/design_generate.py | 24 +--------------------- app/service/utils/new_oss_client.py | 2 +- 2 files changed, 2 insertions(+), 24 deletions(-) diff --git a/app/service/design_fast/design_generate.py b/app/service/design_fast/design_generate.py index be015ac..f4012cf 100644 --- a/app/service/design_fast/design_generate.py +++ b/app/service/design_fast/design_generate.py @@ -2,6 +2,7 @@ import logging import threading import time +import requests from minio import Minio from app.core.config import * @@ -138,16 +139,9 @@ def design_generate(request_data): @RunTime def design_generate_v2(request_data): objects_data = request_data.dict()['objects'] - # process_id = request_data.dict()['process_id'] - # object_response = {} threads = [] - active_threads = 0 - lock = threading.Lock() - - # total = len(objects_data) def process_object(step, object): - nonlocal active_threads basic = object['basic'] items_response = { 'layers': [], @@ -219,26 +213,10 @@ def design_generate_v2(request_data): logger.info(response.text) logger.info(items_response) - # update_progress(process_id, total) - - # with lock: - # object_response[step] = items_response - # active_threads -= 1 - for step, object in enumerate(objects_data): t = threading.Thread(target=process_object, args=(step, object)) threads.append(t) t.start() - # with lock: - # active_threads += 1 - - # for t in threads: - # t.join() - # final_progress(process_id) - # return object_response - - -import requests def post_request(url, data=None, json_data=None, headers=None, auth=None, timeout=5): diff --git a/app/service/utils/new_oss_client.py b/app/service/utils/new_oss_client.py index 6dd22bd..4b3cbb1 100644 --- a/app/service/utils/new_oss_client.py +++ b/app/service/utils/new_oss_client.py @@ -82,7 +82,7 @@ if __name__ == '__main__': # url = "aida-users/89/sketchboard/female/Dress/e6724ab7-8d3f-4677-abe0-c3e42ab7af85.jpeg" # url = "aida-users/87/print/956614a2-7e75-4fbe-9ed0-c1831e37a2c9-4-87.png" # url = "aida-users/89/single_logo/123-89.png" - url ="aida-users/89/test/123-89.png" + url = "aida-users/89/test/123-89.png" # url = "aida-collection-element/12148/Sketchboard/95ea577b-305b-4a62-b30a-39c0dd3ddb3f.png" read_type = "2" From 20fd1e5d62e74e2c293980074c2215eabe33360c Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Sun, 1 Dec 2024 20:44:19 +0800 Subject: [PATCH 34/39] =?UTF-8?q?feat=20=20generate=20img=20api=20?= =?UTF-8?q?=E6=B3=A8=E9=87=8A=E4=BF=AE=E6=94=B9=20fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/api_generate_image.py | 1 + 1 file changed, 1 insertion(+) diff --git a/app/api/api_generate_image.py b/app/api/api_generate_image.py index 3dee667..b3ea61c 100644 --- a/app/api/api_generate_image.py +++ b/app/api/api_generate_image.py @@ -26,6 +26,7 @@ def generate_image(request_item: GenerateImageModel, background_tasks: Backgroun - **mode**: 生成模式,img2img或者txt2img - **category**: 生成图片的类别,sketch print 等等 - **gender**: 生成sketch专用,服装类别 + - **version**: 使用模型版本 fast 或者 high 示例参数: { From 532e56af6303a845405d82672e521b7660da9231 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Sun, 1 Dec 2024 20:44:45 +0800 Subject: [PATCH 35/39] =?UTF-8?q?feat=20=20generate=20img=20api=20?= =?UTF-8?q?=E6=B3=A8=E9=87=8A=E4=BF=AE=E6=94=B9=20fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/api_generate_image.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/app/api/api_generate_image.py b/app/api/api_generate_image.py index b3ea61c..53790a3 100644 --- a/app/api/api_generate_image.py +++ b/app/api/api_generate_image.py @@ -35,7 +35,8 @@ def generate_image(request_item: GenerateImageModel, background_tasks: Backgroun "image_url": "aida-collection-element/87/Printboard/842c09cf-7297-42d9-9e6e-9c17d4a13cb5.jpg", "mode": "img2img", "category": "sketch", - "gender": "male" + "gender": "male", + "version": "fast" } """ try: From bc9aa034458be245ad7cc8d7be6258da14e83e60 Mon Sep 17 00:00:00 2001 From: xupei Date: Mon, 2 Dec 2024 18:20:45 +0800 Subject: [PATCH 36/39] =?UTF-8?q?=E7=BF=BB=E8=AF=91=20=E6=B7=BB=E5=8A=A0?= =?UTF-8?q?=E4=BD=BF=E7=94=A8llama3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../chatgpt_for_translation.py | 49 ++++++++++++++++++- 1 file changed, 48 insertions(+), 1 deletion(-) diff --git a/app/service/prompt_generation/chatgpt_for_translation.py b/app/service/prompt_generation/chatgpt_for_translation.py index 193bcfc..05d85fb 100644 --- a/app/service/prompt_generation/chatgpt_for_translation.py +++ b/app/service/prompt_generation/chatgpt_for_translation.py @@ -1,5 +1,7 @@ -import logging +import json + +import requests from dashscope import Generation from requests import RequestException from retry import retry @@ -15,6 +17,15 @@ from app.core.config import QWEN_API_KEY # openai_api_key=OPENAI_API_KEY, # temperature=0) +prefix_for_llama = ( + """ + Translate everything within the brackets [] into English. + Never translate or modify any English input. + The input must be fully translated into coherent English sentences. + Please only output the translated result.\n + """ + ) + def translate_to_en(text): template = ( @@ -52,6 +63,12 @@ def translate_to_en(text): print("input : {}, translate result : {}".format(text, assistant_output.content)) return assistant_output.content + # llama3专用 + # data = get_translation_from_llama3(text) + # translation = data + # # print("Response from llama3 : " + translation) + # return translation + @retry(exceptions=RequestException, tries=3, delay=1) def get_response(messages): @@ -65,6 +82,36 @@ def get_response(messages): ) return response + +def get_translation_from_llama3(text): + url = "http://localhost:11434/api/generate" + # url = "http://10.1.1.240:1143/api/generate" + + prompt = f"System: {prefix_for_llama}\nUser:[{text}]" + + # 创建请求的负载 + payload = { + "model": "llama3.2", + "prompt": prompt, + "stream": False + } + + # 将负载转换为 JSON 格式 + headers = {'Content-Type': 'application/json'} + response = requests.post(url, data=json.dumps(payload), headers=headers) + + # 处理响应 + if response.status_code == 200: + # print("Response from server:") + # print(response.json()) + resp = json.loads(response.content).get("response") + print("input : {}, translate result : {}".format(text, resp)) + return resp + else: + print(f"Request failed with status code {response.status_code}") + print(response.text) + + def main(): """Main function""" text = translate_to_en("fire") From ea3e6667a051c2003da5055032ea486f0b338b6a Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Mon, 2 Dec 2024 20:13:33 +0800 Subject: [PATCH 37/39] =?UTF-8?q?feat=20=20translator=20=E5=88=87=E6=8D=A2?= =?UTF-8?q?ollama=20fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../chatgpt_for_translation.py | 67 +++++++++---------- 1 file changed, 33 insertions(+), 34 deletions(-) diff --git a/app/service/prompt_generation/chatgpt_for_translation.py b/app/service/prompt_generation/chatgpt_for_translation.py index 05d85fb..5d720b9 100644 --- a/app/service/prompt_generation/chatgpt_for_translation.py +++ b/app/service/prompt_generation/chatgpt_for_translation.py @@ -1,6 +1,5 @@ import json - import requests from dashscope import Generation from requests import RequestException @@ -17,35 +16,35 @@ from app.core.config import QWEN_API_KEY # openai_api_key=OPENAI_API_KEY, # temperature=0) -prefix_for_llama = ( - """ - Translate everything within the brackets [] into English. - Never translate or modify any English input. - The input must be fully translated into coherent English sentences. - Please only output the translated result.\n - """ - ) +# prefix_for_llama = ( +# """ +# Translate everything within the brackets [] into English. +# Never translate or modify any English input. +# The input must be fully translated into coherent English sentences. +# Please only output the translated result.\n +# """ +# ) def translate_to_en(text): - template = ( - """You are a translation expert, proficient in various languages. - And can translate various languages into English. - Please translate to grammatically correct English regardless of the input language. - If the input is already in English, or consists of letters or numbers such as "cat", "abc", or "1", - output the input text exactly as it is without any modifications or additions. - If there are grammatical errors, correct them and then output the sentence.""" - ) - - prefix = ( - """ - Translate everything within the brackets [] into English. - Never translate or modify any English input. - The input must be fully translated into coherent English sentences. - Never present the translation results in the format - "The translation of \"Material suave\" into English would be \"Smooth material.\"". Instead, directly output "Smooth material". - """ - ) + # template = ( + # """You are a translation expert, proficient in various languages. + # And can translate various languages into English. + # Please translate to grammatically correct English regardless of the input language. + # If the input is already in English, or consists of letters or numbers such as "cat", "abc", or "1", + # output the input text exactly as it is without any modifications or additions. + # If there are grammatical errors, correct them and then output the sentence.""" + # ) + # + # prefix = ( + # """ + # Translate everything within the brackets [] into English. + # Never translate or modify any English input. + # The input must be fully translated into coherent English sentences. + # Never present the translation results in the format + # "The translation of \"Material suave\" into English would be \"Smooth material.\"". Instead, directly output "Smooth material". + # """ + # ) messages = [ # { # Translate the entire text and ensure the output is a complete and coherent sentence in English. @@ -54,7 +53,7 @@ def translate_to_en(text): # }, { # "content": input('请输入:'), # 用户message - "content": prefix + text, # 用户message + "content": text, # 用户message "role": "user" } ] @@ -74,7 +73,7 @@ def translate_to_en(text): def get_response(messages): response = Generation.call( model='qwen-turbo', - api_key= QWEN_API_KEY, + api_key=QWEN_API_KEY, messages=messages, # seed=random.randint(1, 10000), # 设置随机数种子seed,如果没有设置,则随机数种子默认为1234 result_format='message', # 将输出设置为message形式 @@ -84,15 +83,15 @@ def get_response(messages): def get_translation_from_llama3(text): - url = "http://localhost:11434/api/generate" + url = "http://10.1.1.240:11434/api/generate" # url = "http://10.1.1.240:1143/api/generate" - prompt = f"System: {prefix_for_llama}\nUser:[{text}]" + # prompt = f"System: {prefix_for_llama}\nUser:[{text}]" # 创建请求的负载 payload = { - "model": "llama3.2", - "prompt": prompt, + "model": "translator", + "prompt": f"[{text}]", "stream": False } @@ -114,7 +113,7 @@ def get_translation_from_llama3(text): def main(): """Main function""" - text = translate_to_en("fire") + text = get_translation_from_llama3("[火焰]") print(text) From 5491c54bda681448fe93632e4898de6af82c58d4 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Mon, 2 Dec 2024 20:31:46 +0800 Subject: [PATCH 38/39] =?UTF-8?q?feat=20=20translator=20=E5=88=87=E6=8D=A2?= =?UTF-8?q?ollama=20fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/api_prompt_generation.py | 4 ++-- app/service/prompt_generation/chatgpt_for_translation.py | 8 +++++++- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/app/api/api_prompt_generation.py b/app/api/api_prompt_generation.py index 59e5779..11733e8 100644 --- a/app/api/api_prompt_generation.py +++ b/app/api/api_prompt_generation.py @@ -6,7 +6,7 @@ from fastapi import APIRouter, HTTPException from app.schemas.prompt_generation import PromptGenerationImageModel from app.schemas.response_template import ResponseModel -from app.service.prompt_generation.chatgpt_for_translation import translate_to_en +from app.service.prompt_generation.chatgpt_for_translation import translate_to_en, get_translation_from_llama3 router = APIRouter() logger = logging.getLogger() @@ -26,7 +26,7 @@ def prompt_generation(request_data: PromptGenerationImageModel): """ try: logger.info(f"prompt_generation request item is : @@@@@@:{request_data}") - data = translate_to_en("[" + request_data.text + "]") + data = get_translation_from_llama3("[" + request_data.text + "]") logger.info(f"prompt_generation response @@@@@@:{data}") except Exception as e: logger.warning(f"prompt_generation Run Exception @@@@@@:{e}") diff --git a/app/service/prompt_generation/chatgpt_for_translation.py b/app/service/prompt_generation/chatgpt_for_translation.py index 5d720b9..e541781 100644 --- a/app/service/prompt_generation/chatgpt_for_translation.py +++ b/app/service/prompt_generation/chatgpt_for_translation.py @@ -1,4 +1,6 @@ import json +import logging +import time import requests from dashscope import Generation @@ -7,6 +9,8 @@ from retry import retry from app.core.config import QWEN_API_KEY +logger = logging.getLogger(__name__) + # os.environ["http_proxy"] = "http://127.0.0.1:7890" # os.environ["https_proxy"] = "http://127.0.0.1:7890" @@ -83,6 +87,7 @@ def get_response(messages): def get_translation_from_llama3(text): + start_time = time.time() url = "http://10.1.1.240:11434/api/generate" # url = "http://10.1.1.240:1143/api/generate" @@ -98,15 +103,16 @@ def get_translation_from_llama3(text): # 将负载转换为 JSON 格式 headers = {'Content-Type': 'application/json'} response = requests.post(url, data=json.dumps(payload), headers=headers) - # 处理响应 if response.status_code == 200: # print("Response from server:") # print(response.json()) resp = json.loads(response.content).get("response") + logger.info(f"translation server runtime is {time.time() - start_time} , response is {resp}") print("input : {}, translate result : {}".format(text, resp)) return resp else: + logger.info(f"translation server runtime is {time.time() - start_time} , response is {response.content}") print(f"Request failed with status code {response.status_code}") print(response.text) From c0abc8ffe945a3b9f11e0f45ffcee0500c4be9f9 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Mon, 2 Dec 2024 22:42:03 +0800 Subject: [PATCH 39/39] =?UTF-8?q?feat=20=20OLLAMA=5FURL=20=E5=88=87?= =?UTF-8?q?=E6=8D=A2=E5=88=B0A6000=20fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/core/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/core/config.py b/app/core/config.py index d369ff2..7629429 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -138,7 +138,7 @@ SEGMENTATION = { "output": "seg_output__0", } # ollama config -OLLAMA_URL = "http://10.1.1.243:11434/api/embeddings" +OLLAMA_URL = "http://10.1.1.240:11434/api/embeddings" # DESIGN config DESIGN_MODEL_URL = '10.1.1.240:10000' AIDA_CLOTHING = "aida-clothing"