From 482137ef8c4ba6708cdaf6b6a56e1dc2245b41d0 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Mon, 9 Dec 2024 13:34:08 +0800 Subject: [PATCH 01/37] =?UTF-8?q?design=20stream=E6=B5=81=20java=E5=9C=B0?= =?UTF-8?q?=E5=9D=80=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/service/design_fast/design_generate.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/app/service/design_fast/design_generate.py b/app/service/design_fast/design_generate.py index f4012cf..04270b9 100644 --- a/app/service/design_fast/design_generate.py +++ b/app/service/design_fast/design_generate.py @@ -197,9 +197,8 @@ def design_generate_v2(request_data): 'pattern_image_url': item_result['pattern_image_url'] if 'pattern_image_url' in item_result.keys() else None, }) items_response['synthesis_url'] = synthesis_single(item_result['front_image'], item_result['back_image']) - # 发送结果给java端 - url = "https://3998-117-143-125-51.ngrok-free.app/api/third/party/receiveDesignResults" + url = "https://develop.api.aida.com.hk/api/third/party/receiveDesignResults" headers = { 'Accept': "*/*", 'Accept-Encoding': "gzip, deflate, br", From 6e67248e5121b4a1da17c53a98235b53afef951a Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Mon, 9 Dec 2024 13:37:21 +0800 Subject: [PATCH 02/37] =?UTF-8?q?design=20stream=E6=B5=81=20java=E5=9C=B0?= =?UTF-8?q?=E5=9D=80=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/api_test.py | 3 ++- app/core/config.py | 3 +++ app/service/design_fast/design_generate.py | 2 +- 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/app/api/api_test.py b/app/api/api_test.py index 1271f95..a273b11 100644 --- a/app/api/api_test.py +++ b/app/api/api_test.py @@ -4,7 +4,7 @@ import logging from fastapi import APIRouter from fastapi import HTTPException -from app.core.config import SR_RABBITMQ_QUEUES, GI_RABBITMQ_QUEUES, GPI_RABBITMQ_QUEUES, GRI_RABBITMQ_QUEUES, OSS +from app.core.config import SR_RABBITMQ_QUEUES, GI_RABBITMQ_QUEUES, GPI_RABBITMQ_QUEUES, GRI_RABBITMQ_QUEUES, OSS, JAVA_STREAM_API_URL from app.schemas.response_template import ResponseModel logger = logging.getLogger() @@ -18,6 +18,7 @@ def test(id: int): "GI_RABBITMQ_QUEUES": GI_RABBITMQ_QUEUES, "GPI_RABBITMQ_QUEUES": GPI_RABBITMQ_QUEUES, "GRI_RABBITMQ_QUEUES": GRI_RABBITMQ_QUEUES, + "JAVA_STREAM_API_URL": JAVA_STREAM_API_URL, "local_oss_server": OSS } logger.info(json.dumps(data)) diff --git a/app/core/config.py b/app/core/config.py index 7629429..1d11a0b 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -34,6 +34,9 @@ else: RABBITMQ_ENV = "-dev" # 开发环境 # RABBITMQ_ENV = "-local" # 本地测试环境 +JAVA_STREAM_API_URL = os.getenv("JAVA_STREAM_API_URL", "https://api.aida.com.hk/api/third/party/receiveDesignResults") + + settings = Settings() # minio 配置 diff --git a/app/service/design_fast/design_generate.py b/app/service/design_fast/design_generate.py index 04270b9..7876239 100644 --- a/app/service/design_fast/design_generate.py +++ b/app/service/design_fast/design_generate.py @@ -198,7 +198,7 @@ def design_generate_v2(request_data): }) items_response['synthesis_url'] = synthesis_single(item_result['front_image'], item_result['back_image']) # 发送结果给java端 - url = "https://develop.api.aida.com.hk/api/third/party/receiveDesignResults" + url = JAVA_STREAM_API_URL headers = { 'Accept': "*/*", 'Accept-Encoding': "gzip, deflate, br", From bcc46d16b34fd26fa6c09227c58281ce226c2a48 Mon Sep 17 00:00:00 
2001 From: zhouchengrong Date: Mon, 9 Dec 2024 14:28:42 +0800 Subject: [PATCH 03/37] =?UTF-8?q?design=20stream=E6=B5=81=20java=E5=9C=B0?= =?UTF-8?q?=E5=9D=80=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/api_design.py | 1 + 1 file changed, 1 insertion(+) diff --git a/app/api/api_design.py b/app/api/api_design.py index 665d544..fa4ad29 100644 --- a/app/api/api_design.py +++ b/app/api/api_design.py @@ -364,6 +364,7 @@ async def design_v2(request_data: DesignModel, background_tasks: BackgroundTasks """ try: # 异步 + logger.info(f"generate_image request item is : @@@@@@:{request_data}") logger.info(f"generate_image request item is : @@@@@@:{json.dumps(request_data.dict())}") background_tasks.add_task(design_generate_v2, request_data) except Exception as e: From 0eef1d82880ebb8d1df2538871f6bf6764243567 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Mon, 9 Dec 2024 14:34:16 +0800 Subject: [PATCH 04/37] =?UTF-8?q?design=20stream=E6=B5=81=20java=E5=9C=B0?= =?UTF-8?q?=E5=9D=80=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/api_design.py | 1 - app/schemas/design.py | 1 + 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/app/api/api_design.py b/app/api/api_design.py index fa4ad29..665d544 100644 --- a/app/api/api_design.py +++ b/app/api/api_design.py @@ -364,7 +364,6 @@ async def design_v2(request_data: DesignModel, background_tasks: BackgroundTasks """ try: # 异步 - logger.info(f"generate_image request item is : @@@@@@:{request_data}") logger.info(f"generate_image request item is : @@@@@@:{json.dumps(request_data.dict())}") background_tasks.add_task(design_generate_v2, request_data) except Exception as e: diff --git a/app/schemas/design.py b/app/schemas/design.py index 7ebd8e6..dab80d2 100644 --- a/app/schemas/design.py +++ b/app/schemas/design.py @@ -4,6 +4,7 @@ from pydantic import BaseModel class DesignModel(BaseModel): objects: list[dict] process_id: str + requestId: str class DesignProgressModel(BaseModel): From 2c0a7729b85e354d7278d6c3501e4a41b2136a64 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Mon, 9 Dec 2024 14:48:54 +0800 Subject: [PATCH 05/37] =?UTF-8?q?design=20stream=E6=B5=81=20java=E5=9C=B0?= =?UTF-8?q?=E5=9D=80=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/service/design_fast/design_generate.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/app/service/design_fast/design_generate.py b/app/service/design_fast/design_generate.py index 7876239..4fc0726 100644 --- a/app/service/design_fast/design_generate.py +++ b/app/service/design_fast/design_generate.py @@ -139,6 +139,7 @@ def design_generate(request_data): @RunTime def design_generate_v2(request_data): objects_data = request_data.dict()['objects'] + request_id = request_data.requestId threads = [] def process_object(step, object): @@ -146,7 +147,7 @@ def design_generate_v2(request_data): items_response = { 'layers': [], 'objectSign': object['objectSign'] if 'objectSign' in object.keys() else "", - 'requestId': object['requestId'] if 'requestId' in object.keys() else "" + 'requestId': request_id } if basic['single_overall'] == "overall": item_results = [] @@ -206,11 +207,11 @@ def design_generate_v2(request_data): 'Connection': "keep-alive", 'Content-Type': "application/json" } + logger.info(items_response) response = post_request(url, json_data=items_response, headers=headers) 
if response: # 打印结果 logger.info(response.text) - logger.info(items_response) for step, object in enumerate(objects_data): t = threading.Thread(target=process_object, args=(step, object)) From e6f0ee7f3a89644a615a04509e2a299f59ebad57 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Wed, 11 Dec 2024 11:06:20 +0800 Subject: [PATCH 06/37] design design batch --- app/core/config.py | 2 +- .../design_batch/design_batch_celery.py | 5 +- app/service/design_batch/item.py | 25 +++++- app/service/design_batch/pipeline/__init__.py | 2 + .../design_batch/pipeline/back_perspective.py | 79 +++++++++++++++++++ app/service/design_batch/pipeline/color.py | 7 ++ app/service/design_batch/pipeline/keypoint.py | 10 ++- app/service/design_batch/pipeline/loading.py | 5 ++ app/service/design_batch/pipeline/scale.py | 12 +++ .../design_batch/pipeline/segmentation.py | 29 +++++-- app/service/design_batch/pipeline/split.py | 27 ++++++- app/service/design_batch/service.py | 2 +- app/service/design_batch/utils/MQ.py | 9 ++- .../design_batch/utils/design_ensemble.py | 2 +- app/service/design_batch/utils/organize.py | 44 ++++++++++- .../design_batch/utils/synthesis_item.py | 21 ++++- app/service/design_batch/utils/transparent.py | 26 ++++++ 17 files changed, 281 insertions(+), 26 deletions(-) create mode 100644 app/service/design_batch/pipeline/back_perspective.py create mode 100644 app/service/design_batch/utils/transparent.py diff --git a/app/core/config.py b/app/core/config.py index 1d11a0b..9f11f75 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -20,7 +20,7 @@ class Settings(BaseSettings): OSS = "minio" -DEBUG = False +DEBUG = True if DEBUG: LOGS_PATH = "logs/" CATEGORY_PATH = "service/attribute/config/descriptor/category/category_dis.csv" diff --git a/app/service/design_batch/design_batch_celery.py b/app/service/design_batch/design_batch_celery.py index 3f12862..9cca005 100644 --- a/app/service/design_batch/design_batch_celery.py +++ b/app/service/design_batch/design_batch_celery.py @@ -12,7 +12,7 @@ from app.service.design_batch.utils.save_json import oss_upload_json from app.service.design_batch.utils.synthesis_item import update_base_size_priority, synthesis, synthesis_single id_lock = threading.Lock() -celery_app = Celery('tasks', broker='amqp://guest:guest@10.1.2.213:5672//', backend='rpc://') +celery_app = Celery('tasks', broker='amqp://guest:guest@10.1.2.190:5672//', backend='rpc://') celery_app.conf.worker_log_format = '%(asctime)s %(filename)s [line:%(lineno)d] %(levelname)s %(message)s' celery_app.conf.worker_hijack_root_logger = False logging.getLogger('pika').setLevel(logging.WARNING) @@ -46,7 +46,7 @@ def process_layer(item, layers): layers.append(back_layer) -@celery_app.task +# @celery_app.task def batch_design(objects_data, tasks_id, json_name): object_response = [] threads = [] @@ -108,6 +108,7 @@ def batch_design(objects_data, tasks_id, json_name): with lock: object_response.append(items_response) + logger.info(items_response) publish_status(tasks_id, step + 1, items_response) active_threads -= 1 diff --git a/app/service/design_batch/item.py b/app/service/design_batch/item.py index cad1488..ec18b17 100644 --- a/app/service/design_batch/item.py +++ b/app/service/design_batch/item.py @@ -1,4 +1,4 @@ -from app.service.design_batch.pipeline import * +from app.service.design_fast.pipeline import LoadImage, KeyPoint, Segmentation, Color, PrintPainting, Scaling, Split, LoadBodyImage, ContourDetection class BaseItem: @@ -9,6 +9,27 @@ class BaseItem: self.result.update(basic) +class 
AccessoriesItem(BaseItem): + def __init__(self, data, basic, minio_client): + super().__init__(data, basic) + self.Accessories_pipeline = [ + LoadImage(minio_client), + # KeyPoint(), + ContourDetection(), + # Segmentation(minio_client), + # BackPerspective(minio_client), + Color(minio_client), + PrintPainting(minio_client), + Scaling(), + Split(minio_client) + ] + + def process(self): + for item in self.Accessories_pipeline: + self.result = item(self.result) + return self.result + + class TopItem(BaseItem): def __init__(self, data, basic, minio_client): super().__init__(data, basic) @@ -16,6 +37,7 @@ class TopItem(BaseItem): LoadImage(minio_client), KeyPoint(), Segmentation(minio_client), + # BackPerspective(minio_client), Color(minio_client), PrintPainting(minio_client), Scaling(), @@ -36,6 +58,7 @@ class BottomItem(BaseItem): KeyPoint(), ContourDetection(), # Segmentation(), + # BackPerspective(minio_client), Color(minio_client), PrintPainting(minio_client), Scaling(), diff --git a/app/service/design_batch/pipeline/__init__.py b/app/service/design_batch/pipeline/__init__.py index ec55933..f265bbe 100644 --- a/app/service/design_batch/pipeline/__init__.py +++ b/app/service/design_batch/pipeline/__init__.py @@ -1,3 +1,4 @@ +from .back_perspective import BackPerspective from .color import Color from .contour_detection import ContourDetection from .keypoint import KeyPoint @@ -13,6 +14,7 @@ __all__ = [ 'KeyPoint', 'ContourDetection', 'Segmentation', + 'BackPerspective', 'Color', 'PrintPainting', 'Scaling', diff --git a/app/service/design_batch/pipeline/back_perspective.py b/app/service/design_batch/pipeline/back_perspective.py new file mode 100644 index 0000000..5ddd37c --- /dev/null +++ b/app/service/design_batch/pipeline/back_perspective.py @@ -0,0 +1,79 @@ +import cv2 +import numpy as np + +from app.service.design_fast.utils.design_ensemble import get_seg_result +from app.service.utils.new_oss_client import oss_upload_image + + +class BackPerspective: + def __init__(self, minio_client): + self.minio_client = minio_client + + def __call__(self, result): + + # 如果sketch为系统图 查看是否有对应的 背后视角图 + if result['path'].split('/')[0] == 'aida-sys-image': + file_path = result['path'].replace("images", 'images_back', 1) + if self.is_file_exists(bucket_name='aida-sys-image', file_name=file_path[file_path.find('/') + 1:]): + result['back_perspective_url'] = file_path + return result + else: + seg_result = get_seg_result("1", result['image'])[0] + elif result['name'] in ['blouse', 'outwear', 'dress', 'tops']: + seg_result = result['seg_result'] + else: + seg_result = get_seg_result("1", result['image'])[0] + + m = self.thicken_contours_and_display(seg_result, thickness=10, color=(0, 0, 0)) + back_sketch = result['image'].copy() + back_sketch[m > 100] = 255 + # 上传背后视角图 + _, img_encoded = cv2.imencode(".jpg", back_sketch) + + resp = oss_upload_image(self.minio_client, bucket='test', object_name=result['path'], image_bytes=img_encoded.tobytes()) + result['back_perspective_url'] = f"{resp.bucket_name}/{resp.object_name}" + return result + + def thicken_contours_and_display(self, mask, thickness=10, color=(0, 0, 0)): + mask = mask.astype(np.uint8) * 255 + # 查找轮廓 + contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + + # 创建一个彩色副本用于绘制轮廓 + mask_color = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR) + + def thicken_contour_inward(contour, thick): + # 创建一个空白的黑色图像与原始掩码大小相同 + blank = np.zeros_like(mask) + # 在空白图像上绘制白色的轮廓 + cv2.drawContours(blank, [contour], -1, 255, thickness=thick) + # 
找到轮廓的中心(可以用重心等方法近似) + M = cv2.moments(contour) + cx = int(M['m10'] / M['m00']) + cy = int(M['m01'] / M['m00']) + # 进行距离变换,离中心越近的值越小 + dist_transform = cv2.distanceTransform(255 - blank, cv2.DIST_L2, 5) + # 根据距离变换的值来决定是否保留像素,离中心近的像素更容易被保留 + result = np.zeros_like(mask) + for i in range(dist_transform.shape[0]): + for j in range(dist_transform.shape[1]): + if dist_transform[i, j] < thick: + result[i, j] = 255 + return result + + for contour in contours: + thickened_contour = thicken_contour_inward(contour, thickness) + mask_color[thickened_contour > 0] = color + + _, binary_result = cv2.threshold(mask_color, 127, 255, cv2.THRESH_BINARY) + + # 转换为掩码形式 + mask_result = cv2.cvtColor(binary_result, cv2.COLOR_BGR2GRAY) + return mask_result + + def is_file_exists(self, bucket_name, file_name): + try: + self.minio_client.stat_object(bucket_name, file_name) + return True + except Exception: + return False diff --git a/app/service/design_batch/pipeline/color.py b/app/service/design_batch/pipeline/color.py index 546c671..3033bb5 100644 --- a/app/service/design_batch/pipeline/color.py +++ b/app/service/design_batch/pipeline/color.py @@ -14,11 +14,18 @@ class Color: def __call__(self, result): dim_image_h, dim_image_w = result['image'].shape[0:2] + # 渐变色 if "gradient" in result.keys() and result['gradient'] != "": bucket_name = result['gradient'].split('/')[0] object_name = result['gradient'][result['gradient'].find('/') + 1:] pattern = self.get_gradient(bucket_name=bucket_name, object_name=object_name) resize_pattern = cv2.resize(pattern, (dim_image_w, dim_image_h), interpolation=cv2.INTER_AREA) + # 无色 + elif "color" not in result.keys() or result['color'] == "": + result['final_image'] = result['pattern_image'] = result['single_image'] = result['image'] + result['alpha'] = 100 / 255.0 + return result + # 正常颜色 else: pattern = self.get_pattern(result['color']) resize_pattern = cv2.resize(pattern, (dim_image_w, dim_image_h), interpolation=cv2.INTER_AREA) diff --git a/app/service/design_batch/pipeline/keypoint.py b/app/service/design_batch/pipeline/keypoint.py index 313a613..73d7586 100644 --- a/app/service/design_batch/pipeline/keypoint.py +++ b/app/service/design_batch/pipeline/keypoint.py @@ -4,7 +4,8 @@ import numpy as np from pymilvus import MilvusClient from app.core.config import * -from app.service.design_batch.utils.design_ensemble import get_keypoint_result +from app.service.design_fast.utils.design_ensemble import get_keypoint_result +from app.service.utils.decorator import ClassCallRunTime, RunTime logger = logging.getLogger(__name__) @@ -16,14 +17,15 @@ class KeyPoint: def get_name(cls): return cls.name + @ClassCallRunTime def __call__(self, result): if result['name'] in ['blouse', 'skirt', 'dress', 'outwear', 'trousers', 'tops', 'bottoms']: # 查询是否有数据 且类别相同 相同则直接读 不同则推理后更新 # result['clothes_keypoint'] = self.infer_keypoint_result(result) site = 'up' if result['name'] in ['blouse', 'outwear', 'dress', 'tops'] else 'down' # keypoint_cache = search_keypoint_cache(result["image_id"], site) - keypoint_cache = self.keypoint_cache(result, site) + # keypoint_cache = self.keypoint_cache(result, site) + keypoint_cache = False # 取消向量查询 直接过模型推理 - # keypoint_cache = False if keypoint_cache is False: keypoint_infer_result, site = self.infer_keypoint_result(result) result['clothes_keypoint'] = self.save_keypoint_cache(result["image_id"], keypoint_infer_result, site) @@ -87,7 +89,7 @@ class KeyPoint: logger.info(f"save keypoint cache milvus error : {e}") return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, 
result.reshape(12, 2).astype(int).tolist())) - # @ RunTime + @RunTime def keypoint_cache(self, result, site): try: client = MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS) diff --git a/app/service/design_batch/pipeline/loading.py b/app/service/design_batch/pipeline/loading.py index 8f02378..5a55d9d 100644 --- a/app/service/design_batch/pipeline/loading.py +++ b/app/service/design_batch/pipeline/loading.py @@ -1,6 +1,9 @@ +import io import logging import cv2 +import numpy as np +from PIL import Image from app.service.utils.new_oss_client import oss_get_image @@ -71,6 +74,8 @@ class LoadImage: keypoint = 'head_point' elif name == 'earring': keypoint = 'ear_point' + elif name == 'accessories': + keypoint = "accessories" else: raise KeyError(f"{name} does not belong to item category list: blouse, outwear, dress, trousers, skirt, " f"bag, shoes, hairstyle, earring.") diff --git a/app/service/design_batch/pipeline/scale.py b/app/service/design_batch/pipeline/scale.py index 1908a9c..d1c7a36 100644 --- a/app/service/design_batch/pipeline/scale.py +++ b/app/service/design_batch/pipeline/scale.py @@ -46,4 +46,16 @@ class Scaling: result['scale'] = result['scale_bag'] elif result['keypoint'] == 'ear_point': result['scale'] = result['scale_earrings'] + elif result['keypoint'] == 'accessories': + # 由于没有识别配饰keypoint的模型 所以统一将配饰的两个关键点设定为 (0,0) (0,img.width) + # 模特的关键点设定为(0,0) (0,320/2) 距离比例简写为 160 / img.width + distance_clo = result['img_shape'][1] + distance_bdy = 320 / 2 + + if distance_clo == 0: + result['scale'] = 1 + else: + result['scale'] = distance_bdy / distance_clo + else: + result['scale'] = 1 return result diff --git a/app/service/design_batch/pipeline/segmentation.py b/app/service/design_batch/pipeline/segmentation.py index cba3446..ebf02b4 100644 --- a/app/service/design_batch/pipeline/segmentation.py +++ b/app/service/design_batch/pipeline/segmentation.py @@ -5,7 +5,8 @@ import cv2 import numpy as np from app.core.config import SEG_CACHE_PATH -from app.service.design_batch.utils.design_ensemble import get_seg_result +from app.service.design_fast.utils.design_ensemble import get_seg_result +from app.service.utils.decorator import ClassCallRunTime from app.service.utils.new_oss_client import oss_get_image logger = logging.getLogger() @@ -15,6 +16,7 @@ class Segmentation: def __init__(self, minio_client): self.minio_client = minio_client + @ClassCallRunTime def __call__(self, result): if "seg_mask_url" in result.keys() and result['seg_mask_url'] != "": seg_mask = oss_get_image(oss_client=self.minio_client, bucket=result['seg_mask_url'].split('/')[0], object_name=result['seg_mask_url'][result['seg_mask_url'].find('/') + 1:], data_type="cv2") @@ -31,13 +33,26 @@ class Segmentation: result['back_mask'] = np.array(green_mask, dtype=np.uint8) * 255 result['mask'] = result['front_mask'] + result['back_mask'] else: - # 本地查询seg 缓存是否存在 - _, seg_result = self.load_seg_result(result["image_id"]) - result['seg_result'] = seg_result - if not _: + # preview 过模型 不缓存 + if "preview_submit" in result.keys() and result['preview_submit'] == "preview": + # 推理获得seg 结果 + seg_result = get_seg_result(result["image_id"], result['image'])[0] + # submit 过模型 缓存 + elif "preview_submit" in result.keys() and result['preview_submit'] == "submit": # 推理获得seg 结果 seg_result = get_seg_result(result["image_id"], result['image'])[0] self.save_seg_result(seg_result, result['image_id']) + # null 正常流程 加载本地缓存 无缓存则过模型 + else: + # 本地查询seg 缓存是否存在 + _, seg_result = self.load_seg_result(result["image_id"]) + # 
判断缓存和实际图片size是否相同 + if not _ or result["image"].shape[:2] != seg_result.shape: + # 推理获得seg 结果 + seg_result = get_seg_result(result["image_id"], result['image'])[0] + self.save_seg_result(seg_result, result['image_id']) + result['seg_result'] = seg_result + # 处理前片后片 temp_front = seg_result == 1.0 result['front_mask'] = (255 * (temp_front + 0).astype(np.uint8)) @@ -48,7 +63,7 @@ class Segmentation: @staticmethod def save_seg_result(seg_result, image_id): - file_path = f"seg_cache/{image_id}.npy" + file_path = f"{SEG_CACHE_PATH}{image_id}.npy" try: np.save(file_path, seg_result) logger.info(f"保存成功 :{os.path.abspath(file_path)}") @@ -57,7 +72,7 @@ class Segmentation: @staticmethod def load_seg_result(image_id): - file_path = f"seg_cache/{image_id}.npy" + file_path = f"{SEG_CACHE_PATH}{image_id}.npy" logger.info(f"load seg file name is :{SEG_CACHE_PATH}{image_id}.npy") try: seg_result = np.load(file_path) diff --git a/app/service/design_batch/pipeline/split.py b/app/service/design_batch/pipeline/split.py index 5dbcef5..344c5c5 100644 --- a/app/service/design_batch/pipeline/split.py +++ b/app/service/design_batch/pipeline/split.py @@ -7,10 +7,11 @@ from PIL import Image from cv2 import cvtColor, COLOR_BGR2RGBA from app.core.config import AIDA_CLOTHING -from app.service.design_batch.utils.conversion_image import rgb_to_rgba -from app.service.design_batch.utils.upload_image import upload_png_mask +from app.service.design_fast.utils.conversion_image import rgb_to_rgba +from app.service.design_fast.utils.transparent import sketch_to_transparent +from app.service.design_fast.utils.upload_image import upload_png_mask from app.service.utils.generate_uuid import generate_uuid -from app.service.utils.new_oss_client import oss_upload_image +from app.service.utils.new_oss_client import oss_upload_image, oss_get_image class Split(object): @@ -20,7 +21,7 @@ class Split(object): def __call__(self, result): try: - if result['name'] in ('outwear', 'dress', 'blouse', 'skirt', 'trousers', 'tops', 'bottoms'): + if result['name'] in ('outwear', 'dress', 'blouse', 'skirt', 'trousers', 'tops', 'bottoms','accessories'): front_mask = result['front_mask'] back_mask = result['back_mask'] rgba_image = rgb_to_rgba(result['final_image'], front_mask + back_mask) @@ -30,6 +31,24 @@ class Split(object): front_mask = cv2.resize(front_mask, new_size) result_front_image[front_mask != 0] = rgba_image[front_mask != 0] result_front_image_pil = Image.fromarray(cvtColor(result_front_image, COLOR_BGR2RGBA)) + if 'transparent' in result.keys(): + # 用户自选区域transparent + transparent = result['transparent'] + if transparent['mask_url'] is not None and transparent['mask_url'] != "": + # 预处理用户自选区mask + seg_mask = oss_get_image(oss_client=self.minio_client, bucket=transparent['mask_url'].split('/')[0], object_name=transparent['mask_url'][transparent['mask_url'].find('/') + 1:], data_type="cv2") + seg_mask = cv2.resize(seg_mask, new_size, interpolation=cv2.INTER_NEAREST) + # 转换颜色空间为 RGB(OpenCV 默认是 BGR) + image_rgb = cv2.cvtColor(seg_mask, cv2.COLOR_BGR2RGB) + + r, g, b = cv2.split(image_rgb) + blue_mask = b > r + + # 创建红色和绿色掩码 + transparent_mask = np.array(blue_mask, dtype=np.uint8) * 255 + result_front_image_pil = sketch_to_transparent(result_front_image_pil, transparent_mask, transparent["scale"]) + else: + result_front_image_pil = sketch_to_transparent(result_front_image_pil, front_mask, transparent["scale"]) result['front_image'], result["front_image_url"], _ = upload_png_mask(self.minio_client, result_front_image_pil, f'{generate_uuid()}', 
mask=None) height, width = front_mask.shape diff --git a/app/service/design_batch/service.py b/app/service/design_batch/service.py index ca6908e..e2a9b23 100644 --- a/app/service/design_batch/service.py +++ b/app/service/design_batch/service.py @@ -5,7 +5,7 @@ from app.service.design_batch.utils.MQ import publish_status async def start_design_batch_generate(data, file): - generate_clothes_task = batch_design.delay(json.loads(file.decode())['objects'], data.total, data.tasks_id) + generate_clothes_task = batch_design(json.loads(file.decode())['objects'], data.total, data.file_name) print(generate_clothes_task) publish_status(data.tasks_id, "0/100", "") return {"task_id": data.tasks_id} diff --git a/app/service/design_batch/utils/MQ.py b/app/service/design_batch/utils/MQ.py index 50e98c2..d787bcb 100644 --- a/app/service/design_batch/utils/MQ.py +++ b/app/service/design_batch/utils/MQ.py @@ -2,9 +2,12 @@ import json import pika +from app.core.config import RABBITMQ_PARAMS + def publish_status(task_id, progress, result): - connection = pika.BlockingConnection(pika.ConnectionParameters('10.1.2.213')) + connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) + # connection = pika.BlockingConnection(pika.ConnectionParameters('10.1.2.190')) channel = connection.channel() channel.queue_declare(queue='DesignBatch', durable=True) message = {'task_id': task_id, 'progress': progress, "result": result} @@ -15,3 +18,7 @@ def publish_status(task_id, progress, result): delivery_mode=2, )) connection.close() + + +if __name__ == '__main__': + publish_status("1", "1", "1") diff --git a/app/service/design_batch/utils/design_ensemble.py b/app/service/design_batch/utils/design_ensemble.py index f4f6a34..267ea00 100644 --- a/app/service/design_batch/utils/design_ensemble.py +++ b/app/service/design_batch/utils/design_ensemble.py @@ -85,7 +85,7 @@ def seg_preprocess(img_path): if ori_shape != (img_scale_w, img_scale_h): # mmcv.imresize(img, img_scale_h, img_scale_w) # 老代码 引以为戒!哈哈哈~ h和w写反了 img = cv2.resize(img, (img_scale_h, img_scale_w)) - img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True) + # img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True) preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0) return preprocessed_img, ori_shape diff --git a/app/service/design_batch/utils/organize.py b/app/service/design_batch/utils/organize.py index 8190de0..33edc4f 100644 --- a/app/service/design_batch/utils/organize.py +++ b/app/service/design_batch/utils/organize.py @@ -33,8 +33,8 @@ def organize_clothing(layer): mask=cv2.resize(layer['mask'], layer["front_image"].size), gradient_string=layer['gradient_string'] if 'gradient_string' in layer.keys() else "", pattern_image_url=layer['pattern_image_url'], - pattern_image=layer['pattern_image'] - + pattern_image=layer['pattern_image'], + # back_perspective_url=layer['back_perspective_url'] if 'back_perspective_url' in layer.keys() else "" ) # 后片数据 back_layer = dict(priority=-layer.get("priority", 0) if layer.get("layer_order", False) else PRIORITY_DICT.get(f'{layer["name"].lower()}_back', None), @@ -50,6 +50,46 @@ def organize_clothing(layer): mask=cv2.resize(layer['mask'], layer["front_image"].size), gradient_string=layer['gradient_string'] if 'gradient_string' in layer.keys() else "", pattern_image_url=layer['pattern_image_url'], + # back_perspective_url=layer['back_perspective_url'] if 
'back_perspective_url' in layer.keys() else "" + ) + return front_layer, back_layer + + +def organize_accessories(layer): + # 起始坐标 + start_point = (0, 0) + # 前片数据 + front_layer = dict(priority=layer['priority'] if layer.get("layer_order", False) else PRIORITY_DICT.get(f'{layer["name"].lower()}_front', None), + name=f'{layer["name"].lower()}_front', + image=layer["front_image"], + # mask_image=layer['front_mask_image'], + image_url=layer['front_image_url'], + mask_url=layer['mask_url'], + sacle=layer['scale'], + clothes_keypoint=(0, 0), + position=start_point, + resize_scale=layer["resize_scale"], + mask=cv2.resize(layer['mask'], layer["front_image"].size), + gradient_string=layer['gradient_string'] if 'gradient_string' in layer.keys() else "", + pattern_image_url=layer['pattern_image_url'], + pattern_image=layer['pattern_image'], + # back_perspective_url=layer['back_perspective_url'] if 'back_perspective_url' in layer.keys() else "" + ) + # 后片数据 + back_layer = dict(priority=-layer.get("priority", 0) if layer.get("layer_order", False) else PRIORITY_DICT.get(f'{layer["name"].lower()}_back', None), + name=f'{layer["name"].lower()}_back', + image=layer["back_image"], + # mask_image=layer['back_mask_image'], + image_url=layer['back_image_url'], + mask_url=layer['mask_url'], + sacle=layer['scale'], + clothes_keypoint=(0, 0), + position=start_point, + resize_scale=layer["resize_scale"], + mask=cv2.resize(layer['mask'], layer["front_image"].size), + gradient_string=layer['gradient_string'] if 'gradient_string' in layer.keys() else "", + pattern_image_url=layer['pattern_image_url'], + # back_perspective_url=layer['back_perspective_url'] if 'back_perspective_url' in layer.keys() else "" ) return front_layer, back_layer diff --git a/app/service/design_batch/utils/synthesis_item.py b/app/service/design_batch/utils/synthesis_item.py index 272ab23..d7711f3 100644 --- a/app/service/design_batch/utils/synthesis_item.py +++ b/app/service/design_batch/utils/synthesis_item.py @@ -79,9 +79,11 @@ def synthesis(data, size, basic_info): _, binary_body_mask = cv2.threshold(body_mask, 127, 255, cv2.THRESH_BINARY) top_outer_mask = np.array(binary_body_mask) bottom_outer_mask = np.array(binary_body_mask) + accessories_outer_mask = np.array(binary_body_mask) top = True bottom = True + accessories = True i = len(data) while i: i -= 1 @@ -98,7 +100,7 @@ def synthesis(data, size, basic_info): background[all_y_start:all_y_end, all_x_start:all_x_end] = sketch_mask[mask_y_start:mask_y_end, mask_x_start:mask_x_end] top_outer_mask = background + top_outer_mask elif bottom and data[i]['name'] in ["trousers_front", "skirt_front", "bottoms_front", "dress_front"]: - bottom = False + # bottom = False mask_shape = data[i]['mask'].shape y_offset, x_offset = data[i]['adaptive_position'] # 初始化叠加区域的起始和结束位置 @@ -109,10 +111,23 @@ def synthesis(data, size, basic_info): background = np.zeros_like(top_outer_mask) background[all_y_start:all_y_end, all_x_start:all_x_end] = sketch_mask[mask_y_start:mask_y_end, mask_x_start:mask_x_end] bottom_outer_mask = background + bottom_outer_mask + elif accessories and data[i]['name'] in ['accessories_front']: + mask_shape = data[i]['mask'].shape + y_offset, x_offset = data[i]['adaptive_position'] + # 初始化叠加区域的起始和结束位置 + all_y_start, all_y_end, mask_y_start, mask_y_end = positioning(all_mask_shape=all_mask_shape[0], mask_shape=mask_shape[0], offset=y_offset) + all_x_start, all_x_end, mask_x_start, mask_x_end = positioning(all_mask_shape=all_mask_shape[1], mask_shape=mask_shape[1], offset=x_offset) + # 
将叠加区域赋值为相应的像素值 + _, sketch_mask = cv2.threshold(data[i]['mask'], 127, 255, cv2.THRESH_BINARY) + background = np.zeros_like(top_outer_mask) + background[all_y_start:all_y_end, all_x_start:all_x_end] = sketch_mask[mask_y_start:mask_y_end, mask_x_start:mask_x_end] + accessories_outer_mask = background + accessories_outer_mask + pass elif bottom is False and top is False: break all_mask = cv2.bitwise_or(top_outer_mask, bottom_outer_mask) + all_mask = cv2.bitwise_or(all_mask, accessories_outer_mask) for layer in data: if layer['image'] is not None: @@ -185,12 +200,14 @@ def update_base_size_priority(layers, size): # 计算透明背景图片的宽度 min_x = min(info['position'][1] for info in layers) x_list = [] + new_height = 700 for info in layers: if info['image'] is not None: x_list.append(info['position'][1] + info['image'].width) + if info['name'] == 'mannequin': + new_height = info['image'].height max_x = max(x_list) new_width = max_x - min_x - new_height = 700 # 更新坐标 for info in layers: info['adaptive_position'] = (info['position'][0], info['position'][1] - min_x) diff --git a/app/service/design_batch/utils/transparent.py b/app/service/design_batch/utils/transparent.py new file mode 100644 index 0000000..3f73807 --- /dev/null +++ b/app/service/design_batch/utils/transparent.py @@ -0,0 +1,26 @@ +from PIL import Image + + +def sketch_to_transparent(image, mask, transparency): + # 打开原始图片 + image = image.convert("RGBA") + # 打开mask图片,假设mask图片是灰度图,白色区域为要处理的区域,黑色区域为保留的区域 + mask = Image.fromarray(mask) + + # 根据透明度调整因子,将透明度转换为0-255之间的值 + alpha_value = int((1 - transparency) * 255.0) + + # 获取图片的像素数据 + image_pixels = image.load() + mask_pixels = mask.load() + + width, height = image.size + + for y in range(height): + for x in range(width): + # 如果mask区域对应的像素为白色(值大于128,这里假设白色为要处理的区域,可根据实际情况调整) + if mask_pixels[x, y] > 128: + r, g, b, a = image_pixels[x, y] + image_pixels[x, y] = (r, g, b, alpha_value) + + return image From 88d4af499a7b5fee74bc1aeab01b961cf3b6fca8 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Wed, 11 Dec 2024 11:06:37 +0800 Subject: [PATCH 07/37] design design batch --- app/core/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/core/config.py b/app/core/config.py index 9f11f75..1d11a0b 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -20,7 +20,7 @@ class Settings(BaseSettings): OSS = "minio" -DEBUG = True +DEBUG = False if DEBUG: LOGS_PATH = "logs/" CATEGORY_PATH = "service/attribute/config/descriptor/category/category_dis.csv" From 350c50888bac0a31aeff6c838b67cac23186708d Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Wed, 11 Dec 2024 11:16:41 +0800 Subject: [PATCH 08/37] design design batch --- app/api/api_design.py | 2 +- app/service/design_batch/utils/save_json.py | 10 ++++++++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/app/api/api_design.py b/app/api/api_design.py index 665d544..f12e170 100644 --- a/app/api/api_design.py +++ b/app/api/api_design.py @@ -445,7 +445,7 @@ async def design(file: UploadFile = File(...), async def save_request_file(contents, file_name): # 创建保存文件的目录(如果不存在) - save_dir = os.path.join(os.getcwd(), "design_batch", "request_data") + save_dir = os.path.join(os.getcwd(), "service/design_batch", "request_data") if not os.path.exists(save_dir): os.makedirs(save_dir) # 处理文件 diff --git a/app/service/design_batch/utils/save_json.py b/app/service/design_batch/utils/save_json.py index 9acd916..f8f2925 100644 --- a/app/service/design_batch/utils/save_json.py +++ b/app/service/design_batch/utils/save_json.py @@ -1,13 
+1,19 @@ import json import logging +import os logger = logging.getLogger() def oss_upload_json(oss_client, json_data, object_name): try: - with open(f"app/service/design_batch/response_json/{object_name}", 'w') as file: + save_dir = os.path.join(os.getcwd(), "service/design_batch", "response_data") + if not os.path.exists(save_dir): + os.makedirs(save_dir) + # 处理文件 + file_path = os.path.join(save_dir, object_name) + with open(file_path, 'w') as file: json.dump(json_data, file, indent=4) - oss_client.fput_object("test", object_name, f"app/service/design_batch/response_json/{object_name}") + oss_client.fput_object("test", object_name, file_path) except Exception as e: logger.warning(str(e)) From 7d54d398c549ad4383fc61e0e854a06634868223 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Wed, 11 Dec 2024 11:21:03 +0800 Subject: [PATCH 09/37] design design batch --- app/service/design_batch/design_batch_celery.py | 2 +- app/service/design_batch/service.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/app/service/design_batch/design_batch_celery.py b/app/service/design_batch/design_batch_celery.py index 9cca005..d1b4240 100644 --- a/app/service/design_batch/design_batch_celery.py +++ b/app/service/design_batch/design_batch_celery.py @@ -46,7 +46,7 @@ def process_layer(item, layers): layers.append(back_layer) -# @celery_app.task +@celery_app.task def batch_design(objects_data, tasks_id, json_name): object_response = [] threads = [] diff --git a/app/service/design_batch/service.py b/app/service/design_batch/service.py index e2a9b23..fcd278c 100644 --- a/app/service/design_batch/service.py +++ b/app/service/design_batch/service.py @@ -5,7 +5,7 @@ from app.service.design_batch.utils.MQ import publish_status async def start_design_batch_generate(data, file): - generate_clothes_task = batch_design(json.loads(file.decode())['objects'], data.total, data.file_name) + generate_clothes_task = batch_design.delay(json.loads(file.decode())['objects'], data.total, data.file_name) print(generate_clothes_task) publish_status(data.tasks_id, "0/100", "") return {"task_id": data.tasks_id} From 2a06825446a835c0cfaefdc1fb68fe811d86c7b6 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Wed, 11 Dec 2024 11:35:41 +0800 Subject: [PATCH 10/37] design design batch --- app/service/design_batch/design_batch_celery.py | 2 +- app/service/design_batch/service.py | 2 +- app/service/design_batch/utils/save_json.py | 4 +++- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/app/service/design_batch/design_batch_celery.py b/app/service/design_batch/design_batch_celery.py index d1b4240..e8e4b9d 100644 --- a/app/service/design_batch/design_batch_celery.py +++ b/app/service/design_batch/design_batch_celery.py @@ -108,7 +108,7 @@ def batch_design(objects_data, tasks_id, json_name): with lock: object_response.append(items_response) - logger.info(items_response) + # logger.info(items_response) publish_status(tasks_id, step + 1, items_response) active_threads -= 1 diff --git a/app/service/design_batch/service.py b/app/service/design_batch/service.py index fcd278c..e2a9b23 100644 --- a/app/service/design_batch/service.py +++ b/app/service/design_batch/service.py @@ -5,7 +5,7 @@ from app.service.design_batch.utils.MQ import publish_status async def start_design_batch_generate(data, file): - generate_clothes_task = batch_design.delay(json.loads(file.decode())['objects'], data.total, data.file_name) + generate_clothes_task = batch_design(json.loads(file.decode())['objects'], data.total, data.file_name) 
print(generate_clothes_task) publish_status(data.tasks_id, "0/100", "") return {"task_id": data.tasks_id} diff --git a/app/service/design_batch/utils/save_json.py b/app/service/design_batch/utils/save_json.py index f8f2925..915df69 100644 --- a/app/service/design_batch/utils/save_json.py +++ b/app/service/design_batch/utils/save_json.py @@ -1,3 +1,4 @@ +import io import json import logging import os @@ -14,6 +15,7 @@ def oss_upload_json(oss_client, json_data, object_name): file_path = os.path.join(save_dir, object_name) with open(file_path, 'w') as file: json.dump(json_data, file, indent=4) - oss_client.fput_object("test", object_name, file_path) + json_bytes = json.dumps(json_data).encode('utf-8') + oss_client.put_object("test", object_name, io.BytesIO(json_bytes), length=len(json_bytes), content_type="application/json") except Exception as e: logger.warning(str(e)) From 91a21fe336581ee119a01af11e12be1aa153c462 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Wed, 11 Dec 2024 11:36:52 +0800 Subject: [PATCH 11/37] design design batch --- app/service/design_batch/service.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/service/design_batch/service.py b/app/service/design_batch/service.py index e2a9b23..fcd278c 100644 --- a/app/service/design_batch/service.py +++ b/app/service/design_batch/service.py @@ -5,7 +5,7 @@ from app.service.design_batch.utils.MQ import publish_status async def start_design_batch_generate(data, file): - generate_clothes_task = batch_design(json.loads(file.decode())['objects'], data.total, data.file_name) + generate_clothes_task = batch_design.delay(json.loads(file.decode())['objects'], data.total, data.file_name) print(generate_clothes_task) publish_status(data.tasks_id, "0/100", "") return {"task_id": data.tasks_id} From f7a6711434bdb73b1ec6001e6376a1af50110fee Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Wed, 11 Dec 2024 11:38:46 +0800 Subject: [PATCH 12/37] design design batch --- app/service/design_batch/service.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/service/design_batch/service.py b/app/service/design_batch/service.py index fcd278c..e2a9b23 100644 --- a/app/service/design_batch/service.py +++ b/app/service/design_batch/service.py @@ -5,7 +5,7 @@ from app.service.design_batch.utils.MQ import publish_status async def start_design_batch_generate(data, file): - generate_clothes_task = batch_design.delay(json.loads(file.decode())['objects'], data.total, data.file_name) + generate_clothes_task = batch_design(json.loads(file.decode())['objects'], data.total, data.file_name) print(generate_clothes_task) publish_status(data.tasks_id, "0/100", "") return {"task_id": data.tasks_id} From ea71a0a9a97c26c86cddbc529ad9a5e2abc949cc Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Wed, 11 Dec 2024 13:50:40 +0800 Subject: [PATCH 13/37] Revert "design design batch" This reverts commit f7a6711434bdb73b1ec6001e6376a1af50110fee. 
--- app/service/design_batch/service.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/service/design_batch/service.py b/app/service/design_batch/service.py index e2a9b23..fcd278c 100644 --- a/app/service/design_batch/service.py +++ b/app/service/design_batch/service.py @@ -5,7 +5,7 @@ from app.service.design_batch.utils.MQ import publish_status async def start_design_batch_generate(data, file): - generate_clothes_task = batch_design(json.loads(file.decode())['objects'], data.total, data.file_name) + generate_clothes_task = batch_design.delay(json.loads(file.decode())['objects'], data.total, data.file_name) print(generate_clothes_task) publish_status(data.tasks_id, "0/100", "") return {"task_id": data.tasks_id} From 731f07d252bea5da62dc3ff6bb62417a636a35e4 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Wed, 11 Dec 2024 13:50:42 +0800 Subject: [PATCH 14/37] Revert "design design batch" This reverts commit 91a21fe336581ee119a01af11e12be1aa153c462. --- app/service/design_batch/service.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/service/design_batch/service.py b/app/service/design_batch/service.py index fcd278c..e2a9b23 100644 --- a/app/service/design_batch/service.py +++ b/app/service/design_batch/service.py @@ -5,7 +5,7 @@ from app.service.design_batch.utils.MQ import publish_status async def start_design_batch_generate(data, file): - generate_clothes_task = batch_design.delay(json.loads(file.decode())['objects'], data.total, data.file_name) + generate_clothes_task = batch_design(json.loads(file.decode())['objects'], data.total, data.file_name) print(generate_clothes_task) publish_status(data.tasks_id, "0/100", "") return {"task_id": data.tasks_id} From 84fe2663f490ddf95fc1adae7fb69bd79645e628 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Wed, 11 Dec 2024 13:51:22 +0800 Subject: [PATCH 15/37] Revert "design design batch" This reverts commit e6f0ee7f --- .../design_batch/design_batch_celery.py | 3 +- app/service/design_batch/item.py | 25 +----- app/service/design_batch/pipeline/__init__.py | 2 - .../design_batch/pipeline/back_perspective.py | 79 ------------------- app/service/design_batch/pipeline/color.py | 7 -- app/service/design_batch/pipeline/keypoint.py | 10 +-- app/service/design_batch/pipeline/loading.py | 5 -- app/service/design_batch/pipeline/scale.py | 12 --- .../design_batch/pipeline/segmentation.py | 29 ++----- app/service/design_batch/pipeline/split.py | 27 +------ app/service/design_batch/service.py | 2 +- app/service/design_batch/utils/MQ.py | 9 +-- .../design_batch/utils/design_ensemble.py | 2 +- app/service/design_batch/utils/organize.py | 44 +---------- .../design_batch/utils/synthesis_item.py | 21 +---- app/service/design_batch/utils/transparent.py | 26 ------ 16 files changed, 24 insertions(+), 279 deletions(-) delete mode 100644 app/service/design_batch/pipeline/back_perspective.py delete mode 100644 app/service/design_batch/utils/transparent.py diff --git a/app/service/design_batch/design_batch_celery.py b/app/service/design_batch/design_batch_celery.py index e8e4b9d..3f12862 100644 --- a/app/service/design_batch/design_batch_celery.py +++ b/app/service/design_batch/design_batch_celery.py @@ -12,7 +12,7 @@ from app.service.design_batch.utils.save_json import oss_upload_json from app.service.design_batch.utils.synthesis_item import update_base_size_priority, synthesis, synthesis_single id_lock = threading.Lock() -celery_app = Celery('tasks', broker='amqp://guest:guest@10.1.2.190:5672//', backend='rpc://') 
+celery_app = Celery('tasks', broker='amqp://guest:guest@10.1.2.213:5672//', backend='rpc://') celery_app.conf.worker_log_format = '%(asctime)s %(filename)s [line:%(lineno)d] %(levelname)s %(message)s' celery_app.conf.worker_hijack_root_logger = False logging.getLogger('pika').setLevel(logging.WARNING) @@ -108,7 +108,6 @@ def batch_design(objects_data, tasks_id, json_name): with lock: object_response.append(items_response) - # logger.info(items_response) publish_status(tasks_id, step + 1, items_response) active_threads -= 1 diff --git a/app/service/design_batch/item.py b/app/service/design_batch/item.py index ec18b17..cad1488 100644 --- a/app/service/design_batch/item.py +++ b/app/service/design_batch/item.py @@ -1,4 +1,4 @@ -from app.service.design_fast.pipeline import LoadImage, KeyPoint, Segmentation, Color, PrintPainting, Scaling, Split, LoadBodyImage, ContourDetection +from app.service.design_batch.pipeline import * class BaseItem: @@ -9,27 +9,6 @@ class BaseItem: self.result.update(basic) -class AccessoriesItem(BaseItem): - def __init__(self, data, basic, minio_client): - super().__init__(data, basic) - self.Accessories_pipeline = [ - LoadImage(minio_client), - # KeyPoint(), - ContourDetection(), - # Segmentation(minio_client), - # BackPerspective(minio_client), - Color(minio_client), - PrintPainting(minio_client), - Scaling(), - Split(minio_client) - ] - - def process(self): - for item in self.Accessories_pipeline: - self.result = item(self.result) - return self.result - - class TopItem(BaseItem): def __init__(self, data, basic, minio_client): super().__init__(data, basic) @@ -37,7 +16,6 @@ class TopItem(BaseItem): LoadImage(minio_client), KeyPoint(), Segmentation(minio_client), - # BackPerspective(minio_client), Color(minio_client), PrintPainting(minio_client), Scaling(), @@ -58,7 +36,6 @@ class BottomItem(BaseItem): KeyPoint(), ContourDetection(), # Segmentation(), - # BackPerspective(minio_client), Color(minio_client), PrintPainting(minio_client), Scaling(), diff --git a/app/service/design_batch/pipeline/__init__.py b/app/service/design_batch/pipeline/__init__.py index f265bbe..ec55933 100644 --- a/app/service/design_batch/pipeline/__init__.py +++ b/app/service/design_batch/pipeline/__init__.py @@ -1,4 +1,3 @@ -from .back_perspective import BackPerspective from .color import Color from .contour_detection import ContourDetection from .keypoint import KeyPoint @@ -14,7 +13,6 @@ __all__ = [ 'KeyPoint', 'ContourDetection', 'Segmentation', - 'BackPerspective', 'Color', 'PrintPainting', 'Scaling', diff --git a/app/service/design_batch/pipeline/back_perspective.py b/app/service/design_batch/pipeline/back_perspective.py deleted file mode 100644 index 5ddd37c..0000000 --- a/app/service/design_batch/pipeline/back_perspective.py +++ /dev/null @@ -1,79 +0,0 @@ -import cv2 -import numpy as np - -from app.service.design_fast.utils.design_ensemble import get_seg_result -from app.service.utils.new_oss_client import oss_upload_image - - -class BackPerspective: - def __init__(self, minio_client): - self.minio_client = minio_client - - def __call__(self, result): - - # 如果sketch为系统图 查看是否有对应的 背后视角图 - if result['path'].split('/')[0] == 'aida-sys-image': - file_path = result['path'].replace("images", 'images_back', 1) - if self.is_file_exists(bucket_name='aida-sys-image', file_name=file_path[file_path.find('/') + 1:]): - result['back_perspective_url'] = file_path - return result - else: - seg_result = get_seg_result("1", result['image'])[0] - elif result['name'] in ['blouse', 'outwear', 'dress', 
'tops']: - seg_result = result['seg_result'] - else: - seg_result = get_seg_result("1", result['image'])[0] - - m = self.thicken_contours_and_display(seg_result, thickness=10, color=(0, 0, 0)) - back_sketch = result['image'].copy() - back_sketch[m > 100] = 255 - # 上传背后视角图 - _, img_encoded = cv2.imencode(".jpg", back_sketch) - - resp = oss_upload_image(self.minio_client, bucket='test', object_name=result['path'], image_bytes=img_encoded.tobytes()) - result['back_perspective_url'] = f"{resp.bucket_name}/{resp.object_name}" - return result - - def thicken_contours_and_display(self, mask, thickness=10, color=(0, 0, 0)): - mask = mask.astype(np.uint8) * 255 - # 查找轮廓 - contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) - - # 创建一个彩色副本用于绘制轮廓 - mask_color = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR) - - def thicken_contour_inward(contour, thick): - # 创建一个空白的黑色图像与原始掩码大小相同 - blank = np.zeros_like(mask) - # 在空白图像上绘制白色的轮廓 - cv2.drawContours(blank, [contour], -1, 255, thickness=thick) - # 找到轮廓的中心(可以用重心等方法近似) - M = cv2.moments(contour) - cx = int(M['m10'] / M['m00']) - cy = int(M['m01'] / M['m00']) - # 进行距离变换,离中心越近的值越小 - dist_transform = cv2.distanceTransform(255 - blank, cv2.DIST_L2, 5) - # 根据距离变换的值来决定是否保留像素,离中心近的像素更容易被保留 - result = np.zeros_like(mask) - for i in range(dist_transform.shape[0]): - for j in range(dist_transform.shape[1]): - if dist_transform[i, j] < thick: - result[i, j] = 255 - return result - - for contour in contours: - thickened_contour = thicken_contour_inward(contour, thickness) - mask_color[thickened_contour > 0] = color - - _, binary_result = cv2.threshold(mask_color, 127, 255, cv2.THRESH_BINARY) - - # 转换为掩码形式 - mask_result = cv2.cvtColor(binary_result, cv2.COLOR_BGR2GRAY) - return mask_result - - def is_file_exists(self, bucket_name, file_name): - try: - self.minio_client.stat_object(bucket_name, file_name) - return True - except Exception: - return False diff --git a/app/service/design_batch/pipeline/color.py b/app/service/design_batch/pipeline/color.py index 3033bb5..546c671 100644 --- a/app/service/design_batch/pipeline/color.py +++ b/app/service/design_batch/pipeline/color.py @@ -14,18 +14,11 @@ class Color: def __call__(self, result): dim_image_h, dim_image_w = result['image'].shape[0:2] - # 渐变色 if "gradient" in result.keys() and result['gradient'] != "": bucket_name = result['gradient'].split('/')[0] object_name = result['gradient'][result['gradient'].find('/') + 1:] pattern = self.get_gradient(bucket_name=bucket_name, object_name=object_name) resize_pattern = cv2.resize(pattern, (dim_image_w, dim_image_h), interpolation=cv2.INTER_AREA) - # 无色 - elif "color" not in result.keys() or result['color'] == "": - result['final_image'] = result['pattern_image'] = result['single_image'] = result['image'] - result['alpha'] = 100 / 255.0 - return result - # 正常颜色 else: pattern = self.get_pattern(result['color']) resize_pattern = cv2.resize(pattern, (dim_image_w, dim_image_h), interpolation=cv2.INTER_AREA) diff --git a/app/service/design_batch/pipeline/keypoint.py b/app/service/design_batch/pipeline/keypoint.py index 73d7586..313a613 100644 --- a/app/service/design_batch/pipeline/keypoint.py +++ b/app/service/design_batch/pipeline/keypoint.py @@ -4,8 +4,7 @@ import numpy as np from pymilvus import MilvusClient from app.core.config import * -from app.service.design_fast.utils.design_ensemble import get_keypoint_result -from app.service.utils.decorator import ClassCallRunTime, RunTime +from app.service.design_batch.utils.design_ensemble import get_keypoint_result 
logger = logging.getLogger(__name__) @@ -17,15 +16,14 @@ class KeyPoint: def get_name(cls): return cls.name - @ClassCallRunTime def __call__(self, result): if result['name'] in ['blouse', 'skirt', 'dress', 'outwear', 'trousers', 'tops', 'bottoms']: # 查询是否有数据 且类别相同 相同则直接读 不同则推理后更新 # result['clothes_keypoint'] = self.infer_keypoint_result(result) site = 'up' if result['name'] in ['blouse', 'outwear', 'dress', 'tops'] else 'down' # keypoint_cache = search_keypoint_cache(result["image_id"], site) - # keypoint_cache = self.keypoint_cache(result, site) - keypoint_cache = False + keypoint_cache = self.keypoint_cache(result, site) # 取消向量查询 直接过模型推理 + # keypoint_cache = False if keypoint_cache is False: keypoint_infer_result, site = self.infer_keypoint_result(result) result['clothes_keypoint'] = self.save_keypoint_cache(result["image_id"], keypoint_infer_result, site) @@ -89,7 +87,7 @@ class KeyPoint: logger.info(f"save keypoint cache milvus error : {e}") return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist())) - @RunTime + # @ RunTime def keypoint_cache(self, result, site): try: client = MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS) diff --git a/app/service/design_batch/pipeline/loading.py b/app/service/design_batch/pipeline/loading.py index 5a55d9d..8f02378 100644 --- a/app/service/design_batch/pipeline/loading.py +++ b/app/service/design_batch/pipeline/loading.py @@ -1,9 +1,6 @@ -import io import logging import cv2 -import numpy as np -from PIL import Image from app.service.utils.new_oss_client import oss_get_image @@ -74,8 +71,6 @@ class LoadImage: keypoint = 'head_point' elif name == 'earring': keypoint = 'ear_point' - elif name == 'accessories': - keypoint = "accessories" else: raise KeyError(f"{name} does not belong to item category list: blouse, outwear, dress, trousers, skirt, " f"bag, shoes, hairstyle, earring.") diff --git a/app/service/design_batch/pipeline/scale.py b/app/service/design_batch/pipeline/scale.py index d1c7a36..1908a9c 100644 --- a/app/service/design_batch/pipeline/scale.py +++ b/app/service/design_batch/pipeline/scale.py @@ -46,16 +46,4 @@ class Scaling: result['scale'] = result['scale_bag'] elif result['keypoint'] == 'ear_point': result['scale'] = result['scale_earrings'] - elif result['keypoint'] == 'accessories': - # 由于没有识别配饰keypoint的模型 所以统一将配饰的两个关键点设定为 (0,0) (0,img.width) - # 模特的关键点设定为(0,0) (0,320/2) 距离比例简写为 160 / img.width - distance_clo = result['img_shape'][1] - distance_bdy = 320 / 2 - - if distance_clo == 0: - result['scale'] = 1 - else: - result['scale'] = distance_bdy / distance_clo - else: - result['scale'] = 1 return result diff --git a/app/service/design_batch/pipeline/segmentation.py b/app/service/design_batch/pipeline/segmentation.py index ebf02b4..cba3446 100644 --- a/app/service/design_batch/pipeline/segmentation.py +++ b/app/service/design_batch/pipeline/segmentation.py @@ -5,8 +5,7 @@ import cv2 import numpy as np from app.core.config import SEG_CACHE_PATH -from app.service.design_fast.utils.design_ensemble import get_seg_result -from app.service.utils.decorator import ClassCallRunTime +from app.service.design_batch.utils.design_ensemble import get_seg_result from app.service.utils.new_oss_client import oss_get_image logger = logging.getLogger() @@ -16,7 +15,6 @@ class Segmentation: def __init__(self, minio_client): self.minio_client = minio_client - @ClassCallRunTime def __call__(self, result): if "seg_mask_url" in result.keys() and result['seg_mask_url'] != "": seg_mask = 
oss_get_image(oss_client=self.minio_client, bucket=result['seg_mask_url'].split('/')[0], object_name=result['seg_mask_url'][result['seg_mask_url'].find('/') + 1:], data_type="cv2") @@ -33,26 +31,13 @@ class Segmentation: result['back_mask'] = np.array(green_mask, dtype=np.uint8) * 255 result['mask'] = result['front_mask'] + result['back_mask'] else: - # preview 过模型 不缓存 - if "preview_submit" in result.keys() and result['preview_submit'] == "preview": - # 推理获得seg 结果 - seg_result = get_seg_result(result["image_id"], result['image'])[0] - # submit 过模型 缓存 - elif "preview_submit" in result.keys() and result['preview_submit'] == "submit": + # 本地查询seg 缓存是否存在 + _, seg_result = self.load_seg_result(result["image_id"]) + result['seg_result'] = seg_result + if not _: # 推理获得seg 结果 seg_result = get_seg_result(result["image_id"], result['image'])[0] self.save_seg_result(seg_result, result['image_id']) - # null 正常流程 加载本地缓存 无缓存则过模型 - else: - # 本地查询seg 缓存是否存在 - _, seg_result = self.load_seg_result(result["image_id"]) - # 判断缓存和实际图片size是否相同 - if not _ or result["image"].shape[:2] != seg_result.shape: - # 推理获得seg 结果 - seg_result = get_seg_result(result["image_id"], result['image'])[0] - self.save_seg_result(seg_result, result['image_id']) - result['seg_result'] = seg_result - # 处理前片后片 temp_front = seg_result == 1.0 result['front_mask'] = (255 * (temp_front + 0).astype(np.uint8)) @@ -63,7 +48,7 @@ class Segmentation: @staticmethod def save_seg_result(seg_result, image_id): - file_path = f"{SEG_CACHE_PATH}{image_id}.npy" + file_path = f"seg_cache/{image_id}.npy" try: np.save(file_path, seg_result) logger.info(f"保存成功 :{os.path.abspath(file_path)}") @@ -72,7 +57,7 @@ class Segmentation: @staticmethod def load_seg_result(image_id): - file_path = f"{SEG_CACHE_PATH}{image_id}.npy" + file_path = f"seg_cache/{image_id}.npy" logger.info(f"load seg file name is :{SEG_CACHE_PATH}{image_id}.npy") try: seg_result = np.load(file_path) diff --git a/app/service/design_batch/pipeline/split.py b/app/service/design_batch/pipeline/split.py index 344c5c5..5dbcef5 100644 --- a/app/service/design_batch/pipeline/split.py +++ b/app/service/design_batch/pipeline/split.py @@ -7,11 +7,10 @@ from PIL import Image from cv2 import cvtColor, COLOR_BGR2RGBA from app.core.config import AIDA_CLOTHING -from app.service.design_fast.utils.conversion_image import rgb_to_rgba -from app.service.design_fast.utils.transparent import sketch_to_transparent -from app.service.design_fast.utils.upload_image import upload_png_mask +from app.service.design_batch.utils.conversion_image import rgb_to_rgba +from app.service.design_batch.utils.upload_image import upload_png_mask from app.service.utils.generate_uuid import generate_uuid -from app.service.utils.new_oss_client import oss_upload_image, oss_get_image +from app.service.utils.new_oss_client import oss_upload_image class Split(object): @@ -21,7 +20,7 @@ class Split(object): def __call__(self, result): try: - if result['name'] in ('outwear', 'dress', 'blouse', 'skirt', 'trousers', 'tops', 'bottoms','accessories'): + if result['name'] in ('outwear', 'dress', 'blouse', 'skirt', 'trousers', 'tops', 'bottoms'): front_mask = result['front_mask'] back_mask = result['back_mask'] rgba_image = rgb_to_rgba(result['final_image'], front_mask + back_mask) @@ -31,24 +30,6 @@ class Split(object): front_mask = cv2.resize(front_mask, new_size) result_front_image[front_mask != 0] = rgba_image[front_mask != 0] result_front_image_pil = Image.fromarray(cvtColor(result_front_image, COLOR_BGR2RGBA)) - if 'transparent' in 
result.keys(): - # 用户自选区域transparent - transparent = result['transparent'] - if transparent['mask_url'] is not None and transparent['mask_url'] != "": - # 预处理用户自选区mask - seg_mask = oss_get_image(oss_client=self.minio_client, bucket=transparent['mask_url'].split('/')[0], object_name=transparent['mask_url'][transparent['mask_url'].find('/') + 1:], data_type="cv2") - seg_mask = cv2.resize(seg_mask, new_size, interpolation=cv2.INTER_NEAREST) - # 转换颜色空间为 RGB(OpenCV 默认是 BGR) - image_rgb = cv2.cvtColor(seg_mask, cv2.COLOR_BGR2RGB) - - r, g, b = cv2.split(image_rgb) - blue_mask = b > r - - # 创建红色和绿色掩码 - transparent_mask = np.array(blue_mask, dtype=np.uint8) * 255 - result_front_image_pil = sketch_to_transparent(result_front_image_pil, transparent_mask, transparent["scale"]) - else: - result_front_image_pil = sketch_to_transparent(result_front_image_pil, front_mask, transparent["scale"]) result['front_image'], result["front_image_url"], _ = upload_png_mask(self.minio_client, result_front_image_pil, f'{generate_uuid()}', mask=None) height, width = front_mask.shape diff --git a/app/service/design_batch/service.py b/app/service/design_batch/service.py index e2a9b23..ca6908e 100644 --- a/app/service/design_batch/service.py +++ b/app/service/design_batch/service.py @@ -5,7 +5,7 @@ from app.service.design_batch.utils.MQ import publish_status async def start_design_batch_generate(data, file): - generate_clothes_task = batch_design(json.loads(file.decode())['objects'], data.total, data.file_name) + generate_clothes_task = batch_design.delay(json.loads(file.decode())['objects'], data.total, data.tasks_id) print(generate_clothes_task) publish_status(data.tasks_id, "0/100", "") return {"task_id": data.tasks_id} diff --git a/app/service/design_batch/utils/MQ.py b/app/service/design_batch/utils/MQ.py index d787bcb..50e98c2 100644 --- a/app/service/design_batch/utils/MQ.py +++ b/app/service/design_batch/utils/MQ.py @@ -2,12 +2,9 @@ import json import pika -from app.core.config import RABBITMQ_PARAMS - def publish_status(task_id, progress, result): - connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) - # connection = pika.BlockingConnection(pika.ConnectionParameters('10.1.2.190')) + connection = pika.BlockingConnection(pika.ConnectionParameters('10.1.2.213')) channel = connection.channel() channel.queue_declare(queue='DesignBatch', durable=True) message = {'task_id': task_id, 'progress': progress, "result": result} @@ -18,7 +15,3 @@ def publish_status(task_id, progress, result): delivery_mode=2, )) connection.close() - - -if __name__ == '__main__': - publish_status("1", "1", "1") diff --git a/app/service/design_batch/utils/design_ensemble.py b/app/service/design_batch/utils/design_ensemble.py index 267ea00..f4f6a34 100644 --- a/app/service/design_batch/utils/design_ensemble.py +++ b/app/service/design_batch/utils/design_ensemble.py @@ -85,7 +85,7 @@ def seg_preprocess(img_path): if ori_shape != (img_scale_w, img_scale_h): # mmcv.imresize(img, img_scale_h, img_scale_w) # 老代码 引以为戒!哈哈哈~ h和w写反了 img = cv2.resize(img, (img_scale_h, img_scale_w)) - # img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True) + img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True) preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0) return preprocessed_img, ori_shape diff --git a/app/service/design_batch/utils/organize.py 
b/app/service/design_batch/utils/organize.py index 33edc4f..8190de0 100644 --- a/app/service/design_batch/utils/organize.py +++ b/app/service/design_batch/utils/organize.py @@ -33,8 +33,8 @@ def organize_clothing(layer): mask=cv2.resize(layer['mask'], layer["front_image"].size), gradient_string=layer['gradient_string'] if 'gradient_string' in layer.keys() else "", pattern_image_url=layer['pattern_image_url'], - pattern_image=layer['pattern_image'], - # back_perspective_url=layer['back_perspective_url'] if 'back_perspective_url' in layer.keys() else "" + pattern_image=layer['pattern_image'] + ) # 后片数据 back_layer = dict(priority=-layer.get("priority", 0) if layer.get("layer_order", False) else PRIORITY_DICT.get(f'{layer["name"].lower()}_back', None), @@ -50,46 +50,6 @@ def organize_clothing(layer): mask=cv2.resize(layer['mask'], layer["front_image"].size), gradient_string=layer['gradient_string'] if 'gradient_string' in layer.keys() else "", pattern_image_url=layer['pattern_image_url'], - # back_perspective_url=layer['back_perspective_url'] if 'back_perspective_url' in layer.keys() else "" - ) - return front_layer, back_layer - - -def organize_accessories(layer): - # 起始坐标 - start_point = (0, 0) - # 前片数据 - front_layer = dict(priority=layer['priority'] if layer.get("layer_order", False) else PRIORITY_DICT.get(f'{layer["name"].lower()}_front', None), - name=f'{layer["name"].lower()}_front', - image=layer["front_image"], - # mask_image=layer['front_mask_image'], - image_url=layer['front_image_url'], - mask_url=layer['mask_url'], - sacle=layer['scale'], - clothes_keypoint=(0, 0), - position=start_point, - resize_scale=layer["resize_scale"], - mask=cv2.resize(layer['mask'], layer["front_image"].size), - gradient_string=layer['gradient_string'] if 'gradient_string' in layer.keys() else "", - pattern_image_url=layer['pattern_image_url'], - pattern_image=layer['pattern_image'], - # back_perspective_url=layer['back_perspective_url'] if 'back_perspective_url' in layer.keys() else "" - ) - # 后片数据 - back_layer = dict(priority=-layer.get("priority", 0) if layer.get("layer_order", False) else PRIORITY_DICT.get(f'{layer["name"].lower()}_back', None), - name=f'{layer["name"].lower()}_back', - image=layer["back_image"], - # mask_image=layer['back_mask_image'], - image_url=layer['back_image_url'], - mask_url=layer['mask_url'], - sacle=layer['scale'], - clothes_keypoint=(0, 0), - position=start_point, - resize_scale=layer["resize_scale"], - mask=cv2.resize(layer['mask'], layer["front_image"].size), - gradient_string=layer['gradient_string'] if 'gradient_string' in layer.keys() else "", - pattern_image_url=layer['pattern_image_url'], - # back_perspective_url=layer['back_perspective_url'] if 'back_perspective_url' in layer.keys() else "" ) return front_layer, back_layer diff --git a/app/service/design_batch/utils/synthesis_item.py b/app/service/design_batch/utils/synthesis_item.py index d7711f3..272ab23 100644 --- a/app/service/design_batch/utils/synthesis_item.py +++ b/app/service/design_batch/utils/synthesis_item.py @@ -79,11 +79,9 @@ def synthesis(data, size, basic_info): _, binary_body_mask = cv2.threshold(body_mask, 127, 255, cv2.THRESH_BINARY) top_outer_mask = np.array(binary_body_mask) bottom_outer_mask = np.array(binary_body_mask) - accessories_outer_mask = np.array(binary_body_mask) top = True bottom = True - accessories = True i = len(data) while i: i -= 1 @@ -100,7 +98,7 @@ def synthesis(data, size, basic_info): background[all_y_start:all_y_end, all_x_start:all_x_end] = 
sketch_mask[mask_y_start:mask_y_end, mask_x_start:mask_x_end] top_outer_mask = background + top_outer_mask elif bottom and data[i]['name'] in ["trousers_front", "skirt_front", "bottoms_front", "dress_front"]: - # bottom = False + bottom = False mask_shape = data[i]['mask'].shape y_offset, x_offset = data[i]['adaptive_position'] # 初始化叠加区域的起始和结束位置 @@ -111,23 +109,10 @@ def synthesis(data, size, basic_info): background = np.zeros_like(top_outer_mask) background[all_y_start:all_y_end, all_x_start:all_x_end] = sketch_mask[mask_y_start:mask_y_end, mask_x_start:mask_x_end] bottom_outer_mask = background + bottom_outer_mask - elif accessories and data[i]['name'] in ['accessories_front']: - mask_shape = data[i]['mask'].shape - y_offset, x_offset = data[i]['adaptive_position'] - # 初始化叠加区域的起始和结束位置 - all_y_start, all_y_end, mask_y_start, mask_y_end = positioning(all_mask_shape=all_mask_shape[0], mask_shape=mask_shape[0], offset=y_offset) - all_x_start, all_x_end, mask_x_start, mask_x_end = positioning(all_mask_shape=all_mask_shape[1], mask_shape=mask_shape[1], offset=x_offset) - # 将叠加区域赋值为相应的像素值 - _, sketch_mask = cv2.threshold(data[i]['mask'], 127, 255, cv2.THRESH_BINARY) - background = np.zeros_like(top_outer_mask) - background[all_y_start:all_y_end, all_x_start:all_x_end] = sketch_mask[mask_y_start:mask_y_end, mask_x_start:mask_x_end] - accessories_outer_mask = background + accessories_outer_mask - pass elif bottom is False and top is False: break all_mask = cv2.bitwise_or(top_outer_mask, bottom_outer_mask) - all_mask = cv2.bitwise_or(all_mask, accessories_outer_mask) for layer in data: if layer['image'] is not None: @@ -200,14 +185,12 @@ def update_base_size_priority(layers, size): # 计算透明背景图片的宽度 min_x = min(info['position'][1] for info in layers) x_list = [] - new_height = 700 for info in layers: if info['image'] is not None: x_list.append(info['position'][1] + info['image'].width) - if info['name'] == 'mannequin': - new_height = info['image'].height max_x = max(x_list) new_width = max_x - min_x + new_height = 700 # 更新坐标 for info in layers: info['adaptive_position'] = (info['position'][0], info['position'][1] - min_x) diff --git a/app/service/design_batch/utils/transparent.py b/app/service/design_batch/utils/transparent.py deleted file mode 100644 index 3f73807..0000000 --- a/app/service/design_batch/utils/transparent.py +++ /dev/null @@ -1,26 +0,0 @@ -from PIL import Image - - -def sketch_to_transparent(image, mask, transparency): - # 打开原始图片 - image = image.convert("RGBA") - # 打开mask图片,假设mask图片是灰度图,白色区域为要处理的区域,黑色区域为保留的区域 - mask = Image.fromarray(mask) - - # 根据透明度调整因子,将透明度转换为0-255之间的值 - alpha_value = int((1 - transparency) * 255.0) - - # 获取图片的像素数据 - image_pixels = image.load() - mask_pixels = mask.load() - - width, height = image.size - - for y in range(height): - for x in range(width): - # 如果mask区域对应的像素为白色(值大于128,这里假设白色为要处理的区域,可根据实际情况调整) - if mask_pixels[x, y] > 128: - r, g, b, a = image_pixels[x, y] - image_pixels[x, y] = (r, g, b, alpha_value) - - return image From be16c95faa3a74f2767ed9ccdbe2f407ced14412 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Wed, 11 Dec 2024 14:40:58 +0800 Subject: [PATCH 16/37] design design batch --- app/service/design_batch/design_batch_celery.py | 4 ++-- app/service/design_batch/service.py | 2 +- app/service/design_batch/test.py | 2 +- app/service/design_batch/utils/MQ.py | 4 +++- 4 files changed, 7 insertions(+), 5 deletions(-) diff --git a/app/service/design_batch/design_batch_celery.py b/app/service/design_batch/design_batch_celery.py index 3f12862..06ccc5e 
100644 --- a/app/service/design_batch/design_batch_celery.py +++ b/app/service/design_batch/design_batch_celery.py @@ -12,7 +12,7 @@ from app.service.design_batch.utils.save_json import oss_upload_json from app.service.design_batch.utils.synthesis_item import update_base_size_priority, synthesis, synthesis_single id_lock = threading.Lock() -celery_app = Celery('tasks', broker='amqp://guest:guest@10.1.2.213:5672//', backend='rpc://') +celery_app = Celery('tasks', broker=f'amqp://rabbit:123456@18.167.251.121:5672//', backend='rpc://', BROKER_CONNECTION_RETRY_ON_STARTUP=True) celery_app.conf.worker_log_format = '%(asctime)s %(filename)s [line:%(lineno)d] %(levelname)s %(message)s' celery_app.conf.worker_hijack_root_logger = False logging.getLogger('pika').setLevel(logging.WARNING) @@ -120,7 +120,7 @@ def batch_design(objects_data, tasks_id, json_name): for t in threads: t.join() - + logger.debug(object_response) oss_upload_json(minio_client, object_response, json_name) publish_status(tasks_id, "ok", json_name) return object_response diff --git a/app/service/design_batch/service.py b/app/service/design_batch/service.py index ca6908e..e9fb814 100644 --- a/app/service/design_batch/service.py +++ b/app/service/design_batch/service.py @@ -5,7 +5,7 @@ from app.service.design_batch.utils.MQ import publish_status async def start_design_batch_generate(data, file): - generate_clothes_task = batch_design.delay(json.loads(file.decode())['objects'], data.total, data.tasks_id) + generate_clothes_task = batch_design.delay(json.loads(file.decode())['objects'], data.tasks_id, data.file_name) print(generate_clothes_task) publish_status(data.tasks_id, "0/100", "") return {"task_id": data.tasks_id} diff --git a/app/service/design_batch/test.py b/app/service/design_batch/test.py index 6b94bc6..2e74cc9 100644 --- a/app/service/design_batch/test.py +++ b/app/service/design_batch/test.py @@ -157,6 +157,6 @@ if __name__ == '__main__': ], "process_id": "83" } - task_id = 1 + task_id = 10086 json_name = "test.json" batch_design.delay(data['objects'], task_id, json_name) diff --git a/app/service/design_batch/utils/MQ.py b/app/service/design_batch/utils/MQ.py index 50e98c2..1b64bf3 100644 --- a/app/service/design_batch/utils/MQ.py +++ b/app/service/design_batch/utils/MQ.py @@ -2,9 +2,11 @@ import json import pika +from app.core.config import RABBITMQ_PARAMS + def publish_status(task_id, progress, result): - connection = pika.BlockingConnection(pika.ConnectionParameters('10.1.2.213')) + connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) channel = connection.channel() channel.queue_declare(queue='DesignBatch', durable=True) message = {'task_id': task_id, 'progress': progress, "result": result} From 89bd88ffee347ec1d4b3435580c7d58aef44a04f Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Fri, 20 Dec 2024 09:48:10 +0800 Subject: [PATCH 17/37] =?UTF-8?q?design=20brand=20dna=20=E9=83=A8=E7=BD=B2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/api_brand_dna.py | 34 ++++ app/schemas/brand_dna.py | 6 + app/service/brand_dna/service.py | 335 +++++++++++++++++++++++++++++++ 3 files changed, 375 insertions(+) create mode 100644 app/api/api_brand_dna.py create mode 100644 app/schemas/brand_dna.py create mode 100644 app/service/brand_dna/service.py diff --git a/app/api/api_brand_dna.py b/app/api/api_brand_dna.py new file mode 100644 index 0000000..6b19416 --- /dev/null +++ b/app/api/api_brand_dna.py @@ -0,0 +1,34 @@ +import json +import logging + 
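+# Router for the brand-DNA feature: exposes POST /seg_product, which runs product segmentation and (optionally) attribute extraction via the BrandDna service.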
+from fastapi import APIRouter, HTTPException + +from app.schemas.brand_dna import BrandDnaModel +from app.schemas.response_template import ResponseModel +from app.service.brand_dna.service import BrandDna + +router = APIRouter() +logger = logging.getLogger() + + +@router.post("/seg_product") +def image2sketch(request_item: BrandDnaModel): + """ + 创建一个具有以下参数的请求体: + - **image_url**: 提取图片url + - **is_brand_dna**: 是否提取属性 + + 示例参数: + { + "image_url": "test/image2sketch/real_Dress_3200fecdc83d0c556c2bd96aedbd7fbf.jpg_Img.jpg", + "is_brand_dna": False + } + """ + try: + logger.info(f"brand dna request item is : @@@@@@:{json.dumps(request_item.dict())}") + service = BrandDna(request_item) + result_url = service.get_result() + except Exception as e: + logger.warning(f"brand dna Run Exception @@@@@@:{e}") + raise HTTPException(status_code=404, detail=str(e)) + return ResponseModel(data=result_url) diff --git a/app/schemas/brand_dna.py b/app/schemas/brand_dna.py new file mode 100644 index 0000000..c5ae2ab --- /dev/null +++ b/app/schemas/brand_dna.py @@ -0,0 +1,6 @@ +from pydantic import BaseModel + + +class BrandDnaModel(BaseModel): + image_url: str + is_brand_dna: bool diff --git a/app/service/brand_dna/service.py b/app/service/brand_dna/service.py new file mode 100644 index 0000000..012e682 --- /dev/null +++ b/app/service/brand_dna/service.py @@ -0,0 +1,335 @@ +import logging + +import cv2 +import mmcv +import numpy as np +import pandas as pd +import torch +import torch.nn.functional as F +import tritonclient.http as httpclient +from minio import Minio + +from app.core.config import MINIO_URL, MINIO_ACCESS, MINIO_SECRET, MINIO_SECURE, DESIGN_MODEL_URL +from app.schemas.brand_dna import BrandDnaModel +from app.service.attribute.config import local_debug_const +from app.service.utils.generate_uuid import generate_uuid +from app.service.utils.new_oss_client import oss_upload_image, oss_get_image + +logger = logging.getLogger() + +minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) + + +class BrandDna: + def __init__(self, request_item): + self.sketch_bucket = "test" + self.image_url = request_item.image_url + self.is_brand_dna = request_item.is_brand_dna + # self.attr_type = pd.read_csv(CATEGORY_PATH) + self.attr_type = pd.read_csv(r"E:\workspace\trinity_client_aida\app\service\attribute\config\descriptor\category\category_dis.csv") + self.att_client = httpclient.InferenceServerClient(url=DESIGN_MODEL_URL) + self.seg_client = httpclient.InferenceServerClient(url='10.1.1.243:30000') + # self.const = const + self.const = local_debug_const + + # 获取结果 + def get_result(self): + mask, image = self.get_seg_mask() + cv2.imshow("", image) + cv2.waitKey(0) + + height, width, channels = image.shape + result_dict = [] + white_img = np.ones((height, width, channels), dtype=image.dtype) * 255 + mask_image = np.zeros((height, width, 3)) + + for value in np.unique(mask): + if value == 1: + outwear_img = white_img.copy() + outwear_mask_img = mask_image.copy() + outwear_img[mask == value] = image[mask == value] + outwear_mask_img[mask == value] = [0, 0, 255] + + cv2.imshow("", outwear_img) + cv2.waitKey(0) + + # 预处理之后的input img + preprocess_img = self.category_preprocess(outwear_img) + # 类别检测 + category = self.recognition_category(preprocess_img) + if category == 'Trousers' or category == 'Skirt': + male_category = 'Bottoms' + elif category == 'Blouse' or category == 'Dress': + male_category = 'Tops' + else: + male_category = 'Outwear' + + attribute = {} + mask_url = 
"" + img_url = "" + # 属性检测 + if self.is_brand_dna: + attribute = self.get_recognition_attribute_result(category, preprocess_img) + else: + img_url = self.put_image(outwear_img, f"img/{generate_uuid()}") + mask_url = self.put_image(outwear_mask_img, f"mask/{generate_uuid()}") + + result_dict.append( + { + 'category_female': category, + 'category_male': male_category, + 'mask_url': mask_url, + 'img_url': img_url, + 'attribute': attribute + } + ) + if value == 2: + tops_img = white_img.copy() + tops_mask_img = mask_image.copy() + tops_img[mask == value] = image[mask == value] + tops_mask_img[mask == value] = [0, 0, 255] + + cv2.imshow("", tops_img) + cv2.waitKey(0) + + # 预处理之后的input img + preprocess_img = self.category_preprocess(tops_img) + # 类别检测 + category = self.recognition_category(preprocess_img) + if category == 'Trousers' or category == 'Skirt': + male_category = 'Bottoms' + elif category == 'Blouse' or category == 'Dress': + male_category = 'Tops' + else: + male_category = 'Outwear' + + # 属性检测 + attribute = {} + img_url = "" + mask_url = "" + # 属性检测 + if self.is_brand_dna: + attribute = self.get_recognition_attribute_result(category, preprocess_img) + else: + mask_url = self.put_image(tops_mask_img, f"mask/{generate_uuid()}") + img_url = self.put_image(tops_img, f"img/{generate_uuid()}") + + result_dict.append( + { + 'category_female': category, + 'category_male': male_category, + 'mask_url': mask_url, + 'img_url': img_url, + 'attribute': attribute + } + ) + if value == 3: + bottoms_img = white_img.copy() + bottoms_mask_img = mask_image.copy() + bottoms_img[mask == value] = image[mask == value] + bottoms_mask_img[mask == value] = [0, 0, 255] + + cv2.imshow("", bottoms_img) + cv2.waitKey(0) + + # 预处理之后的input img + preprocess_img = self.category_preprocess(bottoms_img) + # 类别检测 + category = self.recognition_category(preprocess_img) + if category == 'Trousers' or category == 'Skirt': + male_category = 'Bottoms' + elif category == 'Blouse' or category == 'Dress': + male_category = 'Tops' + else: + male_category = 'Outwear' + + attribute = {} + img_url = "" + mask_url = "" + # 属性检测 + if self.is_brand_dna: + attribute = self.get_recognition_attribute_result(category, preprocess_img) + else: + img_url = self.put_image(bottoms_img, f"img/{generate_uuid()}") + mask_url = self.put_image(bottoms_mask_img, f"mask/{generate_uuid()}") + + result_dict.append( + { + 'category_female': category, + 'category_male': male_category, + 'mask_url': mask_url, + 'img_url': img_url, + 'attribute': attribute + } + ) + return result_dict + + # 获取product mask + def get_seg_mask(self): + input_image = self.get_image() + input_img, ori_shape = self.seg_product_preprocess(input_image) + transformed_img = input_img.astype(np.float32) + + inputs = [httpclient.InferInput(f"seg_input__0", transformed_img.shape, datatype="FP32")] + inputs[0].set_data_from_numpy(transformed_img, binary_data=True) + outputs = [httpclient.InferRequestedOutput(f"seg_output__0", binary_data=True)] + results = self.seg_client.infer(model_name=f"seg_product", inputs=inputs, outputs=outputs) + inference_output1 = results.as_numpy("seg_output__0") + mask = self.product_postprocess(inference_output1, ori_shape)[0] + return mask, input_image + + # 获取图片 + def get_image(self): + image = oss_get_image(oss_client=minio_client, bucket=self.image_url.split('/')[0], object_name=self.image_url[self.image_url.find('/') + 1:], data_type="cv2") + # 将其转换为彩色图像 + if len(image.shape) == 3 and image.shape[2] == 4: + image = cv2.cvtColor(image, cv2.COLOR_BGRA2BGR) 
+ elif len(image.shape) == 2: + image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR) + return image + # return cv2.imread(self.image_url) + + # 图片上传 + def put_image(self, image, object_name): + try: + image_bytes = cv2.imencode('.jpg', image)[1].tobytes() + oss_upload_image(oss_client=minio_client, bucket=self.sketch_bucket, object_name=f"{object_name}.jpg", image_bytes=image_bytes) + return f"{self.sketch_bucket}/{object_name}.jpg" + except Exception as e: + logger.warning(e) + + # 服装分割预处理 + @staticmethod + def seg_product_preprocess(image): + img = mmcv.imread(image) + ori_shape = img.shape[:2] + img_scale_w, img_scale_h = ori_shape + if ori_shape[0] > 1024: + img_scale_w = 1024 + if ori_shape[1] > 1024: + img_scale_h = 1024 + # 如果图片size任意一边 大于 1024, 则会resize 成1024 + if ori_shape != (img_scale_w, img_scale_h): + # mmcv.imresize(img, img_scale_h, img_scale_w) # 老代码 引以为戒!哈哈哈~ h和w写反了 + img = cv2.resize(img, (img_scale_h, img_scale_w)) + img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True) + preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0) + return preprocessed_img, ori_shape + + # 类别检测后处理 + @staticmethod + def product_postprocess(output, ori_shape): + seg_logit = F.interpolate(torch.tensor(output).float(), size=ori_shape, scale_factor=None, mode='bilinear', align_corners=False) + seg_pred = seg_logit.cpu().numpy() + return seg_pred[0] + + # 类别检测模型预处理 + @staticmethod + def category_preprocess(img): + img = mmcv.imread(img) + # ori_shape = img.shape[:2] + img_scale = (224, 224) + img = cv2.resize(img, img_scale) + img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True) + preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0) + return preprocessed_img + + # 类别检测 + def recognition_category(self, image): + inputs = [ + httpclient.InferInput("input__0", image.shape, datatype="FP32") + ] + inputs[0].set_data_from_numpy(image, binary_data=True) + results = self.att_client.infer(model_name="attr_retrieve_category", inputs=inputs) + inference_output = torch.from_numpy(results.as_numpy(f'output__0')) + + scores = inference_output.detach().numpy() + + colattr = list(self.attr_type['labelName']) + maxsc = np.max(scores[0][:5]) + indexs = np.argwhere(scores == maxsc)[:, 1] + + return colattr[indexs[0]] + + # 属性检测 + def recognition_attribute(self, model_name, description, image): + attr_type = pd.read_csv(description) + inputs = [ + httpclient.InferInput("input__0", image.shape, datatype="FP32") + ] + inputs[0].set_data_from_numpy(image, binary_data=True) + results = self.att_client.infer(model_name=model_name, inputs=inputs) + inference_output = torch.from_numpy(results.as_numpy(f"output__0")) + scores = inference_output.detach().numpy() + colattr = list(attr_type['labelName']) + task = description.split('/')[-1][:-4] + maxsc = np.max(scores[0][:5]) + indexs = np.argwhere(scores == maxsc)[:, 1] + attr = { + task: [] + } + for i in range(len(indexs)): + atr = colattr[indexs[i]] + attr[task].append(atr) + return attr + + # 获取属性检测结果 + def get_recognition_attribute_result(self, category, input_img): + if category == "Blouse": + attr_dict = {} + for i in range(len(self.const.top_description_list)): + attr_description = self.const.top_description_list[i] + attr_model_path = self.const.top_model_list[i] + present_dict = self.recognition_attribute(attr_model_path, attr_description, input_img) + attr_dict = self.merge(attr_dict, present_dict) + + elif 
category == 'Trousers' or category == "Skirt": + attr_dict = {} + for i in range(len(self.const.bottom_description_list)): + attr_description = self.const.bottom_description_list[i] + attr_model_path = self.const.bottom_model_list[i] + present_dict = self.recognition_attribute(attr_model_path, attr_description, input_img) + attr_dict = self.merge(attr_dict, present_dict) + + elif category == 'Dress': + attr_dict = {} + for i in range(len(self.const.dress_description_list)): + attr_description = self.const.dress_description_list[i] + attr_model_path = self.const.dress_model_list[i] + present_dict = self.recognition_attribute(attr_model_path, attr_description, input_img) + attr_dict = self.merge(attr_dict, present_dict) + + elif category == 'Outwear': + attr_dict = {} + for i in range(len(self.const.outwear_description_list)): + attr_description = self.const.outwear_description_list[i] + attr_model_path = self.const.outwear_model_list[i] + present_dict = self.recognition_attribute(attr_model_path, attr_description, input_img) + attr_dict = self.merge(attr_dict, present_dict) + else: + attr_dict = {} + return attr_dict + + @staticmethod + def merge(dict1, dict2): + res = {**dict1, **dict2} + return res + + +if __name__ == '__main__': + # for path in os.listdir('./test_img'): + # img_path = os.path.join(r'./test_img', path) + # request_item = BrandDnaModel( + # image_url=img_path, + # is_brand_dna=True + # ) + # service = BrandDna(request_item) + # result_url = service.get_result() + # print(result_url) + request_item = BrandDnaModel( + image_url="aida-users/60/product_image/07cb5d5d-5022-44cc-b0d3-cc986cfebad1-2-60.png", + is_brand_dna=True + ) + service = BrandDna(request_item) + result_url = service.get_result() + print(result_url) From 4896975c5b715e328830d9c76c176b23549609fb Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Fri, 20 Dec 2024 17:50:53 +0800 Subject: [PATCH 18/37] =?UTF-8?q?design=20brand=20dna=20=E9=83=A8=E7=BD=B2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/api_route.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/app/api/api_route.py b/app/api/api_route.py index 0da3a66..973a940 100644 --- a/app/api/api_route.py +++ b/app/api/api_route.py @@ -1,6 +1,7 @@ from fastapi import APIRouter from app.api import api_attribute_retrieve, api_query_image +from app.api import api_brand_dna from app.api import api_brighten from app.api import api_chat_robot from app.api import api_design @@ -23,4 +24,5 @@ router.include_router(api_prompt_generation.router, tags=['prompt_generation'], router.include_router(api_design_pre_processing.router, tags=['design_pre_processing'], prefix="/api") router.include_router(api_image2sketch.router, tags=['api_image2sketch'], prefix="/api") router.include_router(api_brighten.router, tags=['api_brighten'], prefix="/api") -router.include_router(api_query_image.router, tags=['api_query_image'], prefix="/api") \ No newline at end of file +router.include_router(api_query_image.router, tags=['api_query_image'], prefix="/api") +router.include_router(api_brand_dna.router, tags=['api_brand_dna'], prefix="/api") From cce606d2a83dd1e316a0d2b02127e04545d429db Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Mon, 23 Dec 2024 10:19:51 +0800 Subject: [PATCH 19/37] =?UTF-8?q?fix=20design=20stream=20=E4=B8=8E?= =?UTF-8?q?=E6=AD=A3=E5=B8=B8=E7=89=88=E6=9C=AC=E5=8C=BA=E5=88=86=20?= =?UTF-8?q?=E9=83=A8=E7=BD=B2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 
Content-Transfer-Encoding: 8bit --- app/api/api_design.py | 4 ++-- app/schemas/design.py | 5 +++++ 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/app/api/api_design.py b/app/api/api_design.py index f12e170..ee4a651 100644 --- a/app/api/api_design.py +++ b/app/api/api_design.py @@ -4,7 +4,7 @@ import os from fastapi import APIRouter, HTTPException, UploadFile, File, Form, BackgroundTasks -from app.schemas.design import DesignModel, DesignProgressModel, ModelProgressModel, DBGConfigModel +from app.schemas.design import DesignModel, DesignProgressModel, ModelProgressModel, DBGConfigModel, DesignStreamModel from app.schemas.response_template import ResponseModel from app.service.design.model_process_service import model_transpose from app.service.design_batch.service import start_design_batch_generate @@ -197,7 +197,7 @@ def design(request_data: DesignModel, background_tasks: BackgroundTasks): @router.post("/design_v2") -async def design_v2(request_data: DesignModel, background_tasks: BackgroundTasks): +async def design_v2(request_data: DesignStreamModel, background_tasks: BackgroundTasks): """ 创建一个具有以下参数的请求体: 示例参数: diff --git a/app/schemas/design.py b/app/schemas/design.py index dab80d2..98d0a29 100644 --- a/app/schemas/design.py +++ b/app/schemas/design.py @@ -4,6 +4,11 @@ from pydantic import BaseModel class DesignModel(BaseModel): objects: list[dict] process_id: str + + +class DesignStreamModel(BaseModel): + objects: list[dict] + process_id: str requestId: str From 65f0678ce57fedbf1c6bfe434ee1f88cc7581540 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Mon, 23 Dec 2024 10:21:09 +0800 Subject: [PATCH 20/37] =?UTF-8?q?fix=20design=20stream=20=E4=B8=8E?= =?UTF-8?q?=E6=AD=A3=E5=B8=B8=E7=89=88=E6=9C=AC=E5=8C=BA=E5=88=86=20?= =?UTF-8?q?=E9=83=A8=E7=BD=B2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 3f9e525..5a29a1f 100644 --- a/.gitignore +++ b/.gitignore @@ -136,4 +136,9 @@ app/logs/* /qodana.yaml .pth .pytorch -*.png \ No newline at end of file +*.png +*.pth +*.db +*.npy +*.pytorch +*.jpg \ No newline at end of file From 1ff5de111b2331d2237ff3322eea3fa8885aa4e7 Mon Sep 17 00:00:00 2001 From: xupei Date: Fri, 3 Jan 2025 14:02:38 +0800 Subject: [PATCH 21/37] =?UTF-8?q?BUGFIX:chat=20robot=E6=B7=BB=E5=8A=A0inte?= =?UTF-8?q?rnet=5Fsearch?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../chat_robot/script/service/CallQWen.py | 44 +++++++++---------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/app/service/chat_robot/script/service/CallQWen.py b/app/service/chat_robot/script/service/CallQWen.py index 33dcd04..781ba22 100644 --- a/app/service/chat_robot/script/service/CallQWen.py +++ b/app/service/chat_robot/script/service/CallQWen.py @@ -1,4 +1,5 @@ import json +import logging from dashscope import Generation from retry import retry @@ -159,10 +160,11 @@ def search_from_internet(message): model='qwen-turbo', api_key=QWEN_API_KEY, messages=message, - tools=tools, + prompt='The output must be in English.Keep the final result under 200 words.' 
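+        # The tool-calling and search arguments below are left commented out, so this call sends only the messages plus the prompt constraint above (English, under 200 words) to qwen-turbo.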
+ # tools=tools, # seed=random.randint(1, 10000), # 设置随机数种子seed,如果没有设置,则随机数种子默认为1234 - result_format='message', # 将输出设置为message形式 - enable_search='True' + # result_format='message', # 将输出设置为message形式 + # enable_search='True' ) return response @@ -198,14 +200,9 @@ def get_response(messages): def call_with_messages(message, gender): + global tool_info user_input = message print('\n') - # messages = [ - # { - # "content": input('请输入:'), # 提问示例:"现在几点了?" "一个小时后几点" "北京天气如何?" - # "role": "user" - # } - # ] messages = [ { @@ -223,14 +220,10 @@ def call_with_messages(message, gender): } ] - # 模型的第一轮调用 - # first_response = get_response(messages) - # assistant_output = first_response.output.choices[0].message - # print(f"\n大模型第一轮输出信息:{first_response}\n") - # messages.append(assistant_output) flag = True count = 1 - result_content = "我是一个时尚AI助手,请问有什么可以帮您" + # result_content = "我是一个时尚AI助手,请问有什么可以帮您" + result_content = "I am a fashion AI assistant, how can I help you?" response_type = "chat" while flag and count <= 3: @@ -244,11 +237,15 @@ def call_with_messages(message, gender): print(f"最终答案:{assistant_output.content}") # 此处直接返回模型的回复,您可以根据您的业务,选择当无需调用工具时最终回复的内容 result_content = assistant_output.content break - # 如果模型选择的工具是search_from_internet - # elif assistant_output.tool_calls[0]['function']['name'] == 'search_from_internet': - # tool_info = {"name": "search_from_internet", "role": "tool"} - # user_input = json.loads(assistant_output.tool_calls[0]['function']['arguments'])['user_input'] - # tool_info['content'] = search_from_internet(user_input) + # 如果模型选择的工具是internet_search + elif assistant_output.tool_calls[0]['function']['name'] == 'internet_search': + tool_info = {"name": "search_from_internet", "role": "tool"} + message = [ + {'role': 'assistant', 'content': user_input} + ] + tool_info['content'] = search_from_internet(message) + flag = False + result_content = tool_info['content'].output.text # 如果模型选择的工具是get_database_table elif assistant_output.tool_calls[0]['function']['name'] == 'get_database_table': tool_info = {"name": "get_database_table", "role": "tool", 'content': get_database_table()} @@ -275,13 +272,16 @@ def call_with_messages(message, gender): flag = False result_content = tool_info['content'] response_type = "image" + else : + tool_info = {"name": assistant_output.tool_calls[0]['function']['name'], 'content': 'null'} + logging.info(assistant_output.tool_calls[0]['function']['name'] + "(unknown tools)") + flag = False print(f"工具输出信息:{tool_info['content']}\n") messages.append(tool_info) count += 1 - final_output = {"output": result_content} - final_output["response_type"] = response_type + final_output = {"output": result_content, "response_type": response_type} QWenCallbackHandler.on_chain_end(qwen, final_output) # 模型的第二轮调用,对工具的输出进行总结 From ff06db7ec110a5ee41b90677da020ee5b748d6a1 Mon Sep 17 00:00:00 2001 From: xupei Date: Fri, 3 Jan 2025 14:43:05 +0800 Subject: [PATCH 22/37] =?UTF-8?q?=E4=BC=98=E5=8C=96=EF=BC=9A=E4=BD=BF?= =?UTF-8?q?=E7=94=A8=E6=A8=A1=E5=9E=8B=E6=8F=90=E5=8F=96=E5=90=8E=E7=9A=84?= =?UTF-8?q?=E7=94=A8=E6=88=B7=E8=BE=93=E5=85=A5=EF=BC=8C=E5=86=8D=E8=BE=93?= =?UTF-8?q?=E5=85=A5=E5=88=B0=E6=A8=A1=E5=9E=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../chat_robot/script/service/CallQWen.py | 36 ++++++++++--------- 1 file changed, 19 insertions(+), 17 deletions(-) diff --git a/app/service/chat_robot/script/service/CallQWen.py b/app/service/chat_robot/script/service/CallQWen.py index 781ba22..9f4abca 100644 --- 
a/app/service/chat_robot/script/service/CallQWen.py +++ b/app/service/chat_robot/script/service/CallQWen.py @@ -240,35 +240,37 @@ def call_with_messages(message, gender): # 如果模型选择的工具是internet_search elif assistant_output.tool_calls[0]['function']['name'] == 'internet_search': tool_info = {"name": "search_from_internet", "role": "tool"} + content = json.loads(assistant_output.tool_calls[0]['function']['arguments']) message = [ - {'role': 'assistant', 'content': user_input} + {'role': 'assistant', 'content': content['query']} ] tool_info['content'] = search_from_internet(message) flag = False result_content = tool_info['content'].output.text # 如果模型选择的工具是get_database_table - elif assistant_output.tool_calls[0]['function']['name'] == 'get_database_table': - tool_info = {"name": "get_database_table", "role": "tool", 'content': get_database_table()} - # 如果模型选择的工具是get_table_info - elif assistant_output.tool_calls[0]['function']['name'] == 'get_table_info': - tool_info = {"name": "get_table_info", "role": "tool"} - table_names = json.loads(assistant_output.tool_calls[0]['function']['arguments'])['table_names'] - tool_info['content'] = get_table_info(table_names) - # 如果模型选择的工具是query_database - elif assistant_output.tool_calls[0]['function']['name'] == 'query_database': - tool_info = {"name": "query_database", "role": "tool"} - sql_string = json.loads(assistant_output.tool_calls[0]['function']['arguments'])['sql_string'] - tool_info['content'] = query_database(sql_string) - flag = False - result_content = tool_info['content'] - response_type = "image" + # elif assistant_output.tool_calls[0]['function']['name'] == 'get_database_table': + # tool_info = {"name": "get_database_table", "role": "tool", 'content': get_database_table()} + # # 如果模型选择的工具是get_table_info + # elif assistant_output.tool_calls[0]['function']['name'] == 'get_table_info': + # tool_info = {"name": "get_table_info", "role": "tool"} + # table_names = json.loads(assistant_output.tool_calls[0]['function']['arguments'])['table_names'] + # tool_info['content'] = get_table_info(table_names) + # # 如果模型选择的工具是query_database + # elif assistant_output.tool_calls[0]['function']['name'] == 'query_database': + # tool_info = {"name": "query_database", "role": "tool"} + # sql_string = json.loads(assistant_output.tool_calls[0]['function']['arguments'])['sql_string'] + # tool_info['content'] = query_database(sql_string) + # flag = False + # result_content = tool_info['content'] + # response_type = "image" elif assistant_output.tool_calls[0]['function']['name'] == 'tutorial_tool': tool_info = {"name": "tutorial_tool", "role": "tool", 'content': tutorial_tool()} flag = False result_content = tool_info['content'] elif assistant_output.tool_calls[0]['function']['name'] == 'get_image_from_vector_db': + content = json.loads(assistant_output.tool_calls[0]['function']['arguments']) tool_info = {"name": "get_image_from_vector_db", "role": "tool", - 'content': get_image_from_vector_db(gender, user_input)} + 'content': get_image_from_vector_db(gender, content['parameters']['content'])} flag = False result_content = tool_info['content'] response_type = "image" From 0e5f1ae1fadd3781945501788f260b199b88fa83 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Wed, 8 Jan 2025 15:44:59 +0800 Subject: [PATCH 23/37] =?UTF-8?q?feat=EF=BC=88=E6=96=B0=E5=8A=9F=E8=83=BD?= =?UTF-8?q?=EF=BC=89:=20=20sketch=20=E5=A4=9A=E8=A7=86=E8=A7=92=E5=9B=BE?= =?UTF-8?q?=E7=94=9F=E6=88=90=E5=8A=9F=E8=83=BD=E6=8E=A5=E5=8F=A3=20fix?= 
=?UTF-8?q?=EF=BC=88=E4=BF=AE=E5=A4=8Dbug=EF=BC=89:=20docs=EF=BC=88?= =?UTF-8?q?=E6=96=87=E6=A1=A3=E5=8F=98=E6=9B=B4=EF=BC=89:=20refactor?= =?UTF-8?q?=EF=BC=88=E9=87=8D=E6=9E=84=EF=BC=89:=20test(=E5=A2=9E=E5=8A=A0?= =?UTF-8?q?=E6=B5=8B=E8=AF=95):?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/api_generate_image.py | 41 +++++- app/core/config.py | 6 + app/schemas/generate_image.py | 5 + .../service_generate_multi_view.py | 126 ++++++++++++++++++ app/service/utils/new_oss_client.py | 2 +- 5 files changed, 178 insertions(+), 2 deletions(-) create mode 100644 app/service/generate_image/service_generate_multi_view.py diff --git a/app/api/api_generate_image.py b/app/api/api_generate_image.py index 53790a3..a37bec3 100644 --- a/app/api/api_generate_image.py +++ b/app/api/api_generate_image.py @@ -3,9 +3,10 @@ import logging from fastapi import APIRouter, BackgroundTasks, HTTPException -from app.schemas.generate_image import GenerateImageModel, GenerateProductImageModel, GenerateSingleLogoImageModel, GenerateRelightImageModel +from app.schemas.generate_image import GenerateImageModel, GenerateProductImageModel, GenerateSingleLogoImageModel, GenerateRelightImageModel, GenerateMultiViewModel from app.schemas.response_template import ResponseModel from app.service.generate_image.service_generate_image import GenerateImage, infer_cancel as generate_image_infer_cancel +from app.service.generate_image.service_generate_multi_view import GenerateMultiView, infer_cancel as generate_multi_view_cancel from app.service.generate_image.service_generate_product_image import GenerateProductImage, infer_cancel as generate_product_image_cancel from app.service.generate_image.service_generate_relight_image import GenerateRelightImage, infer_cancel as generate_relight_image_cancel from app.service.generate_image.service_generate_single_logo import GenerateSingleLogoImage, infer_cancel as generate_single_logo_cancel @@ -61,6 +62,44 @@ def generate_image(tasks_id: str): return ResponseModel(data=data['data']) +'''multi view''' + + +@router.post("/generate_multi_view") +def generate_multi_view(request_item: GenerateMultiViewModel, background_tasks: BackgroundTasks): + """ + 创建一个具有以下参数的请求体: + - **tasks_id**: 任务id 用于取消生成任务和获取生成结果 + - **image_url**: 前视角图的输入,minio或S3 url 地址 + + 示例参数: + { + "tasks_id": "123-89", + "image_url": "aida-collection-element/87/Printboard/842c09cf-7297-42d9-9e6e-9c17d4a13cb5.jpg" + } + """ + try: + logger.info(f"generate_multi_view request item is : @@@@@@:{json.dumps(request_item.dict())}") + service = GenerateMultiView(request_item) + background_tasks.add_task(service.get_result) + except Exception as e: + logger.warning(f"generate_multi_view Run Exception @@@@@@:{e}") + raise HTTPException(status_code=404, detail=str(e)) + return ResponseModel() + + +@router.get("/generate_multi_view_cancel/{tasks_id}") +def generate_multi_view(tasks_id: str): + try: + logger.info(f"generate_cancel request item is : @@@@@@:{tasks_id}") + data = generate_multi_view_cancel(tasks_id) + logger.info(f"generate_cancel response @@@@@@:{data}") + except Exception as e: + logger.warning(f"generate_cancel Run Exception @@@@@@:{e}") + raise HTTPException(status_code=404, detail=str(e)) + return ResponseModel(data=data['data']) + + '''single logo''' diff --git a/app/core/config.py b/app/core/config.py index 1d11a0b..7456912 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -109,6 +109,12 @@ FAST_GI_MODEL_NAME = 'stable_diffusion_xl' GI_MODEL_URL = 
'10.1.1.240:10061' GI_MODEL_NAME = 'flux' +GMV_MODEL_URL = '10.1.1.243:10081' +GMV_MODEL_NAME = 'multi_view' + +GMV_RABBITMQ_QUEUES = os.getenv("GMV_RABBITMQ_QUEUES", f"GenerateMultiView{RABBITMQ_ENV}") + + GI_MINIO_BUCKET = "aida-users" GI_RABBITMQ_QUEUES = os.getenv("GI_RABBITMQ_QUEUES", f"GenerateImage{RABBITMQ_ENV}") GI_SYS_IMAGE_URL = "aida-sys-image/generate_image/white_image.jpg" diff --git a/app/schemas/generate_image.py b/app/schemas/generate_image.py index 11e295f..7181418 100644 --- a/app/schemas/generate_image.py +++ b/app/schemas/generate_image.py @@ -1,6 +1,11 @@ from pydantic import BaseModel +class GenerateMultiViewModel(BaseModel): + tasks_id: str + image_url: str + + class GenerateImageModel(BaseModel): tasks_id: str prompt: str diff --git a/app/service/generate_image/service_generate_multi_view.py b/app/service/generate_image/service_generate_multi_view.py new file mode 100644 index 0000000..c930ab2 --- /dev/null +++ b/app/service/generate_image/service_generate_multi_view.py @@ -0,0 +1,126 @@ +#!/usr/bin/env python +# -*- coding: UTF-8 -*- +""" +@Project :trinity_client +@File :service_att_recognition.py +@Author :周成融 +@Date :2023/7/26 12:01:05 +@detail : +""" +import json +import logging +import time + +import numpy as np +import redis +import tritonclient.grpc as grpcclient + +from app.core.config import * +from app.schemas.generate_image import GenerateMultiViewModel +from app.service.generate_image.utils.upload_sd_image import upload_png_sd +from app.service.utils.oss_client import oss_get_image + +logger = logging.getLogger() + + +class GenerateMultiView: + def __init__(self, request_data): + if DEBUG is False: + self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) + self.channel = self.connection.channel() + # self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) + # self.channel = self.connection.channel() + # self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) + self.grpc_client = grpcclient.InferenceServerClient(url=GMV_MODEL_URL) + + self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) + self.image = self.get_image(request_data.image_url) + self.tasks_id = request_data.tasks_id + self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:] + self.generate_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': ''} + self.redis_client.set(self.tasks_id, json.dumps(self.generate_data)) + self.redis_client.expire(self.tasks_id, 600) + + def get_image(self, image_url): + try: + image = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url[image_url.find('/') + 1:], data_type="PIL") + return image + except Exception as e: + logger.error(e) + + def callback(self, result, error): + if error: + self.generate_data['status'] = "FAILURE" + self.generate_data['message'] = str(error) + # self.generate_data['data'] = str(error) + self.redis_client.set(self.tasks_id, json.dumps(self.generate_data)) + else: + # pil图像转成numpy数组 + images = result.as_numpy("generated_image") + # for id, img in enumerate(images): + # cv2.imwrite(f"{id}.png", img) + # image_url = "" + image_url = upload_png_sd(images[6], user_id=self.user_id, category="multi_view", file_name=f"{self.tasks_id}.png") + # logger.info(f"upload image SUCCESS : {image_url}") + self.generate_data['status'] = "SUCCESS" + self.generate_data['message'] = "success" + self.generate_data['image_url'] = 
str(image_url) + self.redis_client.set(self.tasks_id, json.dumps(self.generate_data)) + + def read_tasks_status(self): + status_data = self.redis_client.get(self.tasks_id) + return json.loads(status_data), status_data + + def get_result(self): + try: + images = [np.array(self.image).astype(np.uint8)] * 1 + + image_obj = np.array(images, dtype=np.uint8) + + input_image = grpcclient.InferInput("input_image", image_obj.shape, "UINT8") + + input_image.set_data_from_numpy(image_obj) + + inputs = [input_image] + ctx = self.grpc_client.async_infer(model_name=GMV_MODEL_NAME, inputs=inputs, callback=self.callback) + + time_out = 600 + generate_data = None + while time_out > 0: + generate_data, _ = self.read_tasks_status() + if generate_data['status'] in ["REVOKED", "FAILURE"]: + ctx.cancel() + break + elif generate_data['status'] == "SUCCESS": + break + time_out -= 1 + time.sleep(0.1) + return generate_data + except Exception as e: + self.generate_data['status'] = "FAILURE" + self.generate_data['message'] = str(e) + self.redis_client.set(self.tasks_id, json.dumps(self.generate_data)) + raise Exception(str(e)) + finally: + dict_generate_data, str_generate_data = self.read_tasks_status() + if DEBUG is False: + self.channel.basic_publish(exchange='', routing_key=GMV_RABBITMQ_QUEUES, body=str_generate_data) + # self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=str_generate_data) + logger.info(f" [x] Sent {json.dumps(dict_generate_data, indent=4)}") + + +def infer_cancel(tasks_id): + redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) + data = {'tasks_id': tasks_id, 'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'} + generate_data = json.dumps(data) + redis_client.set(tasks_id, generate_data) + return data + + +if __name__ == '__main__': + rd = GenerateMultiViewModel( + tasks_id="123-89", + image_url="aida-sys-image/images/female/outwear/0628000123.jpg", + ) + server = GenerateMultiView(rd) + print(server.get_result()) diff --git a/app/service/utils/new_oss_client.py b/app/service/utils/new_oss_client.py index 4b3cbb1..7939333 100644 --- a/app/service/utils/new_oss_client.py +++ b/app/service/utils/new_oss_client.py @@ -82,7 +82,7 @@ if __name__ == '__main__': # url = "aida-users/89/sketchboard/female/Dress/e6724ab7-8d3f-4677-abe0-c3e42ab7af85.jpeg" # url = "aida-users/87/print/956614a2-7e75-4fbe-9ed0-c1831e37a2c9-4-87.png" # url = "aida-users/89/single_logo/123-89.png" - url = "aida-users/89/test/123-89.png" + url = "aida-users/89/123-89.png" # url = "aida-collection-element/12148/Sketchboard/95ea577b-305b-4a62-b30a-39c0dd3ddb3f.png" read_type = "2" From 9e1a1996c7e747b468713eabbe65233bbc163c4b Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Fri, 10 Jan 2025 15:02:28 +0800 Subject: [PATCH 24/37] =?UTF-8?q?feat=EF=BC=88=E6=96=B0=E5=8A=9F=E8=83=BD?= =?UTF-8?q?=EF=BC=89:=20fix=EF=BC=88=E4=BF=AE=E5=A4=8Dbug=EF=BC=89:=20RunT?= =?UTF-8?q?ime=E5=88=A4=E5=AE=9A=E4=BF=AE=E6=94=B9=E4=B8=BA0.05=E7=A7=92?= =?UTF-8?q?=E5=86=85=E8=A7=A6=E5=8F=91=E6=89=93=E5=8D=B0=20docs=EF=BC=88?= =?UTF-8?q?=E6=96=87=E6=A1=A3=E5=8F=98=E6=9B=B4=EF=BC=89:=20refactor?= =?UTF-8?q?=EF=BC=88=E9=87=8D=E6=9E=84=EF=BC=89:=20test(=E5=A2=9E=E5=8A=A0?= =?UTF-8?q?=E6=B5=8B=E8=AF=95):?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/service/utils/decorator.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/app/service/utils/decorator.py b/app/service/utils/decorator.py index 
3e86182..bc7e9f1 100644 --- a/app/service/utils/decorator.py +++ b/app/service/utils/decorator.py @@ -7,9 +7,9 @@ def RunTime(func): t1 = time.time() res = func(*args, **kwargs) t2 = time.time() - # if t2 - t1 > 0.05: - # logging.info(f"function:【{func.__name__}】,runtime:【{str(t2 - t1)}】s") - logging.info(f"function:【{func.__name__}】,runtime:【{str(t2 - t1)}】s") + if t2 - t1 > 0.05: + logging.info(f"function:【{func.__name__}】,runtime:【{str(t2 - t1)}】s") + # logging.info(f"function:【{func.__name__}】,runtime:【{str(t2 - t1)}】s") return res return wrapper From f6b2e283f93eedb903b63a039950cfda7e727298 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Fri, 10 Jan 2025 15:04:13 +0800 Subject: [PATCH 25/37] =?UTF-8?q?feat=EF=BC=88=E6=96=B0=E5=8A=9F=E8=83=BD?= =?UTF-8?q?=EF=BC=89:=20fix=EF=BC=88=E4=BF=AE=E5=A4=8Dbug=EF=BC=89:=20RunT?= =?UTF-8?q?ime=E5=88=A4=E5=AE=9A=E4=BF=AE=E6=94=B9=E4=B8=BA0.05=E7=A7=92?= =?UTF-8?q?=E5=86=85=E8=A7=A6=E5=8F=91=E6=89=93=E5=8D=B0=20docs=EF=BC=88?= =?UTF-8?q?=E6=96=87=E6=A1=A3=E5=8F=98=E6=9B=B4=EF=BC=89:=20refactor?= =?UTF-8?q?=EF=BC=88=E9=87=8D=E6=9E=84=EF=BC=89:=20test(=E5=A2=9E=E5=8A=A0?= =?UTF-8?q?=E6=B5=8B=E8=AF=95):?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/service/design/items/pipelines/segmentation.py | 2 +- app/service/design_batch/pipeline/segmentation.py | 2 +- app/service/design_fast/pipeline/segmentation.py | 2 +- app/service/design_pre_processing/service.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/app/service/design/items/pipelines/segmentation.py b/app/service/design/items/pipelines/segmentation.py index 19eb1fd..abd30e6 100644 --- a/app/service/design/items/pipelines/segmentation.py +++ b/app/service/design/items/pipelines/segmentation.py @@ -64,7 +64,7 @@ class Segmentation(object): seg_result = np.load(file_path) return True, seg_result except FileNotFoundError: - logger.warning("文件不存在") + # logger.warning("文件不存在") return False, None except Exception as e: logger.error(f"加载失败: {e}") diff --git a/app/service/design_batch/pipeline/segmentation.py b/app/service/design_batch/pipeline/segmentation.py index cba3446..8ebcd3a 100644 --- a/app/service/design_batch/pipeline/segmentation.py +++ b/app/service/design_batch/pipeline/segmentation.py @@ -63,7 +63,7 @@ class Segmentation: seg_result = np.load(file_path) return True, seg_result except FileNotFoundError: - logger.warning("文件不存在") + # logger.warning("文件不存在") return False, None except Exception as e: logger.error(f"加载失败: {e}") diff --git a/app/service/design_fast/pipeline/segmentation.py b/app/service/design_fast/pipeline/segmentation.py index ebf02b4..5e392ed 100644 --- a/app/service/design_fast/pipeline/segmentation.py +++ b/app/service/design_fast/pipeline/segmentation.py @@ -78,7 +78,7 @@ class Segmentation: seg_result = np.load(file_path) return True, seg_result except FileNotFoundError: - logger.warning("文件不存在") + # logger.warning("文件不存在") return False, None except Exception as e: logger.error(f"加载失败: {e}") diff --git a/app/service/design_pre_processing/service.py b/app/service/design_pre_processing/service.py index 16ca870..e6dc951 100644 --- a/app/service/design_pre_processing/service.py +++ b/app/service/design_pre_processing/service.py @@ -266,7 +266,7 @@ class DesignPreprocessing: seg_result = np.load(file_path) return True, seg_result except FileNotFoundError: - logging.info("文件不存在") + # logging.info("文件不存在") return False, None except Exception as e: logging.warning(f"加载失败: {e}") From 
48cfb4c1caa69ae2100d956fbd25e7fb56d323ef Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Fri, 10 Jan 2025 15:07:06 +0800 Subject: [PATCH 26/37] =?UTF-8?q?feat=EF=BC=88=E6=96=B0=E5=8A=9F=E8=83=BD?= =?UTF-8?q?=EF=BC=89:=20fix=EF=BC=88=E4=BF=AE=E5=A4=8Dbug=EF=BC=89:=20?= =?UTF-8?q?=E5=85=B3=E9=97=AD=E8=B0=83=E8=AF=95=E6=97=A5=E5=BF=97=E6=89=93?= =?UTF-8?q?=E5=8D=B0=20docs=EF=BC=88=E6=96=87=E6=A1=A3=E5=8F=98=E6=9B=B4?= =?UTF-8?q?=EF=BC=89:=20refactor=EF=BC=88=E9=87=8D=E6=9E=84=EF=BC=89:=20te?= =?UTF-8?q?st(=E5=A2=9E=E5=8A=A0=E6=B5=8B=E8=AF=95):?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/service/design_batch/pipeline/segmentation.py | 2 +- app/service/design_fast/design_generate.py | 2 +- app/service/design_fast/pipeline/segmentation.py | 2 +- app/service/utils/decorator.py | 3 ++- 4 files changed, 5 insertions(+), 4 deletions(-) diff --git a/app/service/design_batch/pipeline/segmentation.py b/app/service/design_batch/pipeline/segmentation.py index 8ebcd3a..ca7da1c 100644 --- a/app/service/design_batch/pipeline/segmentation.py +++ b/app/service/design_batch/pipeline/segmentation.py @@ -58,7 +58,7 @@ class Segmentation: @staticmethod def load_seg_result(image_id): file_path = f"seg_cache/{image_id}.npy" - logger.info(f"load seg file name is :{SEG_CACHE_PATH}{image_id}.npy") + # logger.info(f"load seg file name is :{SEG_CACHE_PATH}{image_id}.npy") try: seg_result = np.load(file_path) return True, seg_result diff --git a/app/service/design_fast/design_generate.py b/app/service/design_fast/design_generate.py index 4fc0726..2f7fa93 100644 --- a/app/service/design_fast/design_generate.py +++ b/app/service/design_fast/design_generate.py @@ -207,7 +207,7 @@ def design_generate_v2(request_data): 'Connection': "keep-alive", 'Content-Type': "application/json" } - logger.info(items_response) + # logger.info(items_response) response = post_request(url, json_data=items_response, headers=headers) if response: # 打印结果 diff --git a/app/service/design_fast/pipeline/segmentation.py b/app/service/design_fast/pipeline/segmentation.py index 5e392ed..4828b33 100644 --- a/app/service/design_fast/pipeline/segmentation.py +++ b/app/service/design_fast/pipeline/segmentation.py @@ -73,7 +73,7 @@ class Segmentation: @staticmethod def load_seg_result(image_id): file_path = f"{SEG_CACHE_PATH}{image_id}.npy" - logger.info(f"load seg file name is :{SEG_CACHE_PATH}{image_id}.npy") + # logger.info(f"load seg file name is :{SEG_CACHE_PATH}{image_id}.npy") try: seg_result = np.load(file_path) return True, seg_result diff --git a/app/service/utils/decorator.py b/app/service/utils/decorator.py index bc7e9f1..c0164ab 100644 --- a/app/service/utils/decorator.py +++ b/app/service/utils/decorator.py @@ -22,7 +22,8 @@ def ClassCallRunTime(func): end_time = time.time() execution_time = end_time - start_time class_name = args[0].__class__.__name__ # 获取类名 - print(f"class name: {class_name} , run time is : {execution_time} s") + if execution_time > 0.05: + logging.info(f"class name: {class_name} , run time is : {execution_time} s") return result return wrapper From 72a8f3f920c8bda553cb191bb7e68d40bff3f581 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Fri, 10 Jan 2025 15:08:18 +0800 Subject: [PATCH 27/37] =?UTF-8?q?feat=EF=BC=88=E6=96=B0=E5=8A=9F=E8=83=BD?= =?UTF-8?q?=EF=BC=89:=20fix=EF=BC=88=E4=BF=AE=E5=A4=8Dbug=EF=BC=89:=20?= =?UTF-8?q?=E5=85=B3=E9=97=AD=E8=B0=83=E8=AF=95=E6=97=A5=E5=BF=97=E6=89=93?= =?UTF-8?q?=E5=8D=B0=20docs=EF=BC=88=E6=96=87=E6=A1=A3=E5=8F=98=E6=9B=B4?= 
=?UTF-8?q?=EF=BC=89:=20refactor=EF=BC=88=E9=87=8D=E6=9E=84=EF=BC=89:=20te?= =?UTF-8?q?st(=E5=A2=9E=E5=8A=A0=E6=B5=8B=E8=AF=95):?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/api_attribute_retrieve.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/app/api/api_attribute_retrieve.py b/app/api/api_attribute_retrieve.py index 5c15efe..2a5ad4d 100644 --- a/app/api/api_attribute_retrieve.py +++ b/app/api/api_attribute_retrieve.py @@ -35,13 +35,13 @@ def attribute_recognition(request_item: list[AttributeRecognitionModel]): """ try: for item in request_item: - logger.info(f"attribute_recognition request item is : @@@@@@:{json.dumps(item.dict())}") + logger.debug(f"attribute_recognition request item is : @@@@@@:{json.dumps(item.dict())}") if DEBUG: service = AttributeRecognition(const=local_debug_const, request_data=request_item) else: service = AttributeRecognition(const=const, request_data=request_item) data = service.get_result() - logger.info(f"attribute_recognition response @@@@@@:{json.dumps(data)}") + logger.debug(f"attribute_recognition response @@@@@@:{json.dumps(data)}") except Exception as e: logger.warning(f"attribute_recognition Run Exception @@@@@@:{e}") raise HTTPException(status_code=404, detail=str(e)) From f36edbe24817e2e26c8cca7c249dfc77700fcb43 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Fri, 10 Jan 2025 15:09:59 +0800 Subject: [PATCH 28/37] =?UTF-8?q?feat=EF=BC=88=E6=96=B0=E5=8A=9F=E8=83=BD?= =?UTF-8?q?=EF=BC=89:=20fix=EF=BC=88=E4=BF=AE=E5=A4=8Dbug=EF=BC=89:=20?= =?UTF-8?q?=E5=85=B3=E9=97=AD=E8=B0=83=E8=AF=95=E6=97=A5=E5=BF=97=E6=89=93?= =?UTF-8?q?=E5=8D=B0=20docs=EF=BC=88=E6=96=87=E6=A1=A3=E5=8F=98=E6=9B=B4?= =?UTF-8?q?=EF=BC=89:=20refactor=EF=BC=88=E9=87=8D=E6=9E=84=EF=BC=89:=20te?= =?UTF-8?q?st(=E5=A2=9E=E5=8A=A0=E6=B5=8B=E8=AF=95):?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/service/design/items/pipelines/segmentation.py | 2 +- app/service/design_batch/pipeline/segmentation.py | 2 +- app/service/design_fast/pipeline/segmentation.py | 2 +- app/service/design_pre_processing/service.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/app/service/design/items/pipelines/segmentation.py b/app/service/design/items/pipelines/segmentation.py index abd30e6..7ed43e5 100644 --- a/app/service/design/items/pipelines/segmentation.py +++ b/app/service/design/items/pipelines/segmentation.py @@ -53,7 +53,7 @@ class Segmentation(object): file_path = f"{SEG_CACHE_PATH}{image_id}.npy" try: np.save(file_path, seg_result) - logger.info(f"保存成功 :{os.path.abspath(file_path)}") + logger.debug(f"保存成功 :{os.path.abspath(file_path)}") except Exception as e: logger.error(f"保存失败: {e}") diff --git a/app/service/design_batch/pipeline/segmentation.py b/app/service/design_batch/pipeline/segmentation.py index ca7da1c..aa05c0d 100644 --- a/app/service/design_batch/pipeline/segmentation.py +++ b/app/service/design_batch/pipeline/segmentation.py @@ -51,7 +51,7 @@ class Segmentation: file_path = f"seg_cache/{image_id}.npy" try: np.save(file_path, seg_result) - logger.info(f"保存成功 :{os.path.abspath(file_path)}") + logger.debug(f"保存成功 :{os.path.abspath(file_path)}") except Exception as e: logger.error(f"保存失败: {e}") diff --git a/app/service/design_fast/pipeline/segmentation.py b/app/service/design_fast/pipeline/segmentation.py index 4828b33..1c3f5a0 100644 --- a/app/service/design_fast/pipeline/segmentation.py +++ 
b/app/service/design_fast/pipeline/segmentation.py @@ -66,7 +66,7 @@ class Segmentation: file_path = f"{SEG_CACHE_PATH}{image_id}.npy" try: np.save(file_path, seg_result) - logger.info(f"保存成功 :{os.path.abspath(file_path)}") + logger.debug(f"保存成功 :{os.path.abspath(file_path)}") except Exception as e: logger.error(f"保存失败: {e}") diff --git a/app/service/design_pre_processing/service.py b/app/service/design_pre_processing/service.py index e6dc951..636360c 100644 --- a/app/service/design_pre_processing/service.py +++ b/app/service/design_pre_processing/service.py @@ -277,7 +277,7 @@ class DesignPreprocessing: file_path = f"{SEG_CACHE_PATH}{image_id}.npy" try: np.save(file_path, seg_result) - logging.info(f"保存成功,{os.path.abspath(file_path)}") + logging.debug(f"保存成功,{os.path.abspath(file_path)}") except Exception as e: logging.warning(f"保存失败: {e}") From 9c52811c050a1b6e373e08bf9fd3700e2bb098b1 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Mon, 13 Jan 2025 15:41:37 +0800 Subject: [PATCH 29/37] =?UTF-8?q?feat=EF=BC=88=E6=96=B0=E5=8A=9F=E8=83=BD?= =?UTF-8?q?=EF=BC=89:=20fix=EF=BC=88=E4=BF=AE=E5=A4=8Dbug=EF=BC=89:=20desi?= =?UTF-8?q?gn=20=E5=88=86=E5=89=B2=E9=A2=84=E5=A4=84=E7=90=86=E6=96=B0?= =?UTF-8?q?=E5=A2=9E25padding=EF=BC=8C=E5=90=8E=E5=A4=84=E7=90=86=E5=8F=96?= =?UTF-8?q?=E6=B6=88=E6=8F=92=E5=80=BC=E5=A4=84=E7=90=86=20docs=EF=BC=88?= =?UTF-8?q?=E6=96=87=E6=A1=A3=E5=8F=98=E6=9B=B4=EF=BC=89:=20refactor?= =?UTF-8?q?=EF=BC=88=E9=87=8D=E6=9E=84=EF=BC=89:=20test(=E5=A2=9E=E5=8A=A0?= =?UTF-8?q?=E6=B5=8B=E8=AF=95):?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/service/design_fast/pipeline/segmentation.py | 10 +++++----- app/service/design_fast/utils/design_ensemble.py | 10 ++++++---- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/app/service/design_fast/pipeline/segmentation.py b/app/service/design_fast/pipeline/segmentation.py index 1c3f5a0..0c9c51e 100644 --- a/app/service/design_fast/pipeline/segmentation.py +++ b/app/service/design_fast/pipeline/segmentation.py @@ -36,11 +36,11 @@ class Segmentation: # preview 过模型 不缓存 if "preview_submit" in result.keys() and result['preview_submit'] == "preview": # 推理获得seg 结果 - seg_result = get_seg_result(result["image_id"], result['image'])[0] + seg_result = get_seg_result(result["image_id"], result['image']) # submit 过模型 缓存 elif "preview_submit" in result.keys() and result['preview_submit'] == "submit": # 推理获得seg 结果 - seg_result = get_seg_result(result["image_id"], result['image'])[0] + seg_result = get_seg_result(result["image_id"], result['image']) self.save_seg_result(seg_result, result['image_id']) # null 正常流程 加载本地缓存 无缓存则过模型 else: @@ -49,14 +49,14 @@ class Segmentation: # 判断缓存和实际图片size是否相同 if not _ or result["image"].shape[:2] != seg_result.shape: # 推理获得seg 结果 - seg_result = get_seg_result(result["image_id"], result['image'])[0] + seg_result = get_seg_result(result["image_id"], result['image']) self.save_seg_result(seg_result, result['image_id']) result['seg_result'] = seg_result # 处理前片后片 - temp_front = seg_result == 1.0 + temp_front = seg_result == 1 result['front_mask'] = (255 * (temp_front + 0).astype(np.uint8)) - temp_back = seg_result == 2.0 + temp_back = seg_result == 2 result['back_mask'] = (255 * (temp_back + 0).astype(np.uint8)) result['mask'] = result['front_mask'] + result['back_mask'] return result diff --git a/app/service/design_fast/utils/design_ensemble.py b/app/service/design_fast/utils/design_ensemble.py index 267ea00..9f30d0c 100644 --- 
a/app/service/design_fast/utils/design_ensemble.py +++ b/app/service/design_fast/utils/design_ensemble.py @@ -13,7 +13,6 @@ import cv2 import mmcv import numpy as np import torch -import torch.nn.functional as F import tritonclient.http as httpclient from app.core.config import * @@ -85,6 +84,9 @@ def seg_preprocess(img_path): if ori_shape != (img_scale_w, img_scale_h): # mmcv.imresize(img, img_scale_h, img_scale_w) # 老代码 引以为戒!哈哈哈~ h和w写反了 img = cv2.resize(img, (img_scale_h, img_scale_w)) + + # 扩充25的白边 + img = cv2.copyMakeBorder(img, 25, 25, 25, 25, cv2.BORDER_CONSTANT, value=[255, 255, 255]) # img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True) preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0) return preprocessed_img, ori_shape @@ -114,9 +116,9 @@ def get_seg_result(image_id, image): # no cache def seg_postprocess(image_id, output, ori_shape): - seg_logit = F.interpolate(torch.tensor(output).float(), size=ori_shape, scale_factor=None, mode='bilinear', align_corners=False) - seg_pred = seg_logit.cpu().numpy() - return seg_pred[0] + seg_logit = cv2.resize(output[0][0].astype(np.uint8), (ori_shape[1] + 50, ori_shape[0] + 50)) + seg_logit = seg_logit[25: - 25, 25: - 25] + return seg_logit def key_point_show(image_path, key_point_result=None): From 94000ceb21c70fd575a1bae97e1e857c9fd04874 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Mon, 13 Jan 2025 15:58:58 +0800 Subject: [PATCH 30/37] =?UTF-8?q?feat=EF=BC=88=E6=96=B0=E5=8A=9F=E8=83=BD?= =?UTF-8?q?=EF=BC=89:=20fix=EF=BC=88=E4=BF=AE=E5=A4=8Dbug=EF=BC=89:=20desi?= =?UTF-8?q?gn=20=E5=88=86=E5=89=B2=E9=A2=84=E5=A4=84=E7=90=86=E6=96=B0?= =?UTF-8?q?=E5=A2=9E25padding=EF=BC=8C=E5=90=8E=E5=A4=84=E7=90=86=E5=8F=96?= =?UTF-8?q?=E6=B6=88=E6=8F=92=E5=80=BC=E5=A4=84=E7=90=86=20docs=EF=BC=88?= =?UTF-8?q?=E6=96=87=E6=A1=A3=E5=8F=98=E6=9B=B4=EF=BC=89:=20refactor?= =?UTF-8?q?=EF=BC=88=E9=87=8D=E6=9E=84=EF=BC=89:=20test(=E5=A2=9E=E5=8A=A0?= =?UTF-8?q?=E6=B5=8B=E8=AF=95):?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/service/design_fast/utils/design_ensemble.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/service/design_fast/utils/design_ensemble.py b/app/service/design_fast/utils/design_ensemble.py index 9f30d0c..bfc50c6 100644 --- a/app/service/design_fast/utils/design_ensemble.py +++ b/app/service/design_fast/utils/design_ensemble.py @@ -87,7 +87,7 @@ def seg_preprocess(img_path): # 扩充25的白边 img = cv2.copyMakeBorder(img, 25, 25, 25, 25, cv2.BORDER_CONSTANT, value=[255, 255, 255]) - # img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True) + img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True) preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0) return preprocessed_img, ori_shape From edcd9b68cc8c09ca5b46d53f9c9c908d9b4ac427 Mon Sep 17 00:00:00 2001 From: xupei Date: Tue, 14 Jan 2025 09:40:15 +0800 Subject: [PATCH 31/37] =?UTF-8?q?=E7=BF=BB=E8=AF=91=E5=89=8D=E7=BD=AE?= =?UTF-8?q?=E8=AF=AD=E8=A8=80=E5=88=A4=E6=96=AD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/api_prompt_generation.py | 2 +- app/service/chat_robot/script/prompt.py | 2 ++ .../chat_robot/script/service/CallQWen.py | 25 ++++++++++++++-- .../chatgpt_for_translation.py | 29 ++++++++++++++++++- 4 files changed, 53 insertions(+), 5 
deletions(-) diff --git a/app/api/api_prompt_generation.py b/app/api/api_prompt_generation.py index 11733e8..b731e33 100644 --- a/app/api/api_prompt_generation.py +++ b/app/api/api_prompt_generation.py @@ -26,7 +26,7 @@ def prompt_generation(request_data: PromptGenerationImageModel): """ try: logger.info(f"prompt_generation request item is : @@@@@@:{request_data}") - data = get_translation_from_llama3("[" + request_data.text + "]") + data = get_translation_from_llama3(request_data.text) logger.info(f"prompt_generation response @@@@@@:{data}") except Exception as e: logger.warning(f"prompt_generation Run Exception @@@@@@:{e}") diff --git a/app/service/chat_robot/script/prompt.py b/app/service/chat_robot/script/prompt.py index ad6ac9e..121e57d 100644 --- a/app/service/chat_robot/script/prompt.py +++ b/app/service/chat_robot/script/prompt.py @@ -69,3 +69,5 @@ TOOLS_FUNCTIONS_SUFFIX = ( ) TUTORIAL_TOOL_RETURN = "Commencing the systematic tutorial guide now." + +GET_LANGUAGE_PREFIX = "Please identify the language. Only output the language name" diff --git a/app/service/chat_robot/script/service/CallQWen.py b/app/service/chat_robot/script/service/CallQWen.py index 9f4abca..cb7669b 100644 --- a/app/service/chat_robot/script/service/CallQWen.py +++ b/app/service/chat_robot/script/service/CallQWen.py @@ -8,7 +8,8 @@ from urllib3.exceptions import NewConnectionError from app.core.config import * from app.service.chat_robot.script.callbacks.qwen_callback_handler import QWenCallbackHandler from app.service.chat_robot.script.database import CustomDatabase -from app.service.chat_robot.script.prompt import FASHION_CHAT_BOT_PREFIX, TOOLS_FUNCTIONS_SUFFIX, TUTORIAL_TOOL_RETURN +from app.service.chat_robot.script.prompt import FASHION_CHAT_BOT_PREFIX, TOOLS_FUNCTIONS_SUFFIX, TUTORIAL_TOOL_RETURN, \ + GET_LANGUAGE_PREFIX from app.service.search_image_with_text.service import query get_database_table_description = "Input is an empty string, output is a comma separated list of tables in the database." 
@@ -274,7 +275,7 @@ def call_with_messages(message, gender): flag = False result_content = tool_info['content'] response_type = "image" - else : + else: tool_info = {"name": assistant_output.tool_calls[0]['function']['name'], 'content': 'null'} logging.info(assistant_output.tool_calls[0]['function']['name'] + "(unknown tools)") flag = False @@ -300,5 +301,23 @@ def tutorial_tool(): return TUTORIAL_TOOL_RETURN +def get_language(message: str) -> str: + messages = [ + { + "content": message, # 用户message + "role": "user" + }, + { + "content": GET_LANGUAGE_PREFIX, # ai message + "role": "assistant" + } + ] + + first_response = get_response(messages) + assistant_output = first_response.output.choices[0].message.content + logging.info(f"大模型输出信息:{first_response}\n判断用户输入的语言为:{assistant_output}") + return assistant_output + + if __name__ == '__main__': - call_with_messages() + get_language("") diff --git a/app/service/prompt_generation/chatgpt_for_translation.py b/app/service/prompt_generation/chatgpt_for_translation.py index e541781..e668500 100644 --- a/app/service/prompt_generation/chatgpt_for_translation.py +++ b/app/service/prompt_generation/chatgpt_for_translation.py @@ -8,6 +8,7 @@ from requests import RequestException from retry import retry from app.core.config import QWEN_API_KEY +from app.service.chat_robot.script.service.CallQWen import get_language logger = logging.getLogger(__name__) @@ -93,7 +94,13 @@ def get_translation_from_llama3(text): # prompt = f"System: {prefix_for_llama}\nUser:[{text}]" - # 创建请求的负载 + # 先获取用户输入文本的语言 + language = get_language(text) + + if 'English' in language: + return text + + # 创建请求的负载 translator是自定义的翻译模型 payload = { "model": "translator", "prompt": f"[{text}]", @@ -117,6 +124,26 @@ def get_translation_from_llama3(text): print(response.text) +# 在llama3中创建一个翻译模型 +# def create_model_with_llama(text): +# url = "http://localhost:11434/api/create" +# # url = "http://10.1.1.240:1143/api/generate" +# +# # prompt = f"System: {prefix_for_llama}\nUser:[{text}]" +# +# # 创建翻译器的配置文件 +# payload = { +# "model": "translator", +# "modelfile": "FROM llama3\nSYSTEM Translate everything within the brackets [] into English." +# "Never translate or modify any English input." +# "The input must be fully translated into coherent English sentences." 
+# } +# +# # 将负载转换为 JSON 格式 +# headers = {'Content-Type': 'application/json'} +# response = requests.post(url, data=json.dumps(payload), headers=headers) + + def main(): """Main function""" text = get_translation_from_llama3("[火焰]") From 001ae5b10b0123b90ac2ffef7bab0284685f10a7 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Mon, 13 Jan 2025 15:41:37 +0800 Subject: [PATCH 32/37] =?UTF-8?q?feat=EF=BC=88=E6=96=B0=E5=8A=9F=E8=83=BD?= =?UTF-8?q?=EF=BC=89:=20design=20=E9=80=8F=E6=98=8E=E9=80=89=E5=8C=BA?= =?UTF-8?q?=E5=8A=9F=E8=83=BD=E6=8E=A5=E5=8F=A3=E8=AF=B4=E6=98=8E=20fix?= =?UTF-8?q?=EF=BC=88=E4=BF=AE=E5=A4=8Dbug=EF=BC=89:=20docs=EF=BC=88?= =?UTF-8?q?=E6=96=87=E6=A1=A3=E5=8F=98=E6=9B=B4=EF=BC=89:=20refactor?= =?UTF-8?q?=EF=BC=88=E9=87=8D=E6=9E=84=EF=BC=89:=20test(=E5=A2=9E=E5=8A=A0?= =?UTF-8?q?=E6=B5=8B=E8=AF=95):?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/api_design.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/app/api/api_design.py b/app/api/api_design.py index ee4a651..1c77ed8 100644 --- a/app/api/api_design.py +++ b/app/api/api_design.py @@ -18,6 +18,14 @@ logger = logging.getLogger() @router.post("/design") def design(request_data: DesignModel, background_tasks: BackgroundTasks): """ + objects.items.transparent: + "transparent":{ + "mask_url":"test/transparent_test/transparent_mask.png", + "scale":0.1 + }, + mask_url 为空"" -> 单件衣服透明 + mask_url 非空"mask_url" -> 区域透明 + 创建一个具有以下参数的请求体: 示例参数: { From f21b7203bb1a7e7c9a6ca2280413a75f3dd8b756 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Mon, 20 Jan 2025 11:26:42 +0800 Subject: [PATCH 33/37] =?UTF-8?q?feat=EF=BC=88=E6=96=B0=E5=8A=9F=E8=83=BD?= =?UTF-8?q?=EF=BC=89:=20fix=EF=BC=88=E4=BF=AE=E5=A4=8Dbug=EF=BC=89:=20?= =?UTF-8?q?=E5=85=B3=E9=94=AE=E7=82=B9=E6=A8=A1=E5=9E=8B=E9=A2=84=E5=A4=84?= =?UTF-8?q?=E7=90=86=E5=90=8E=E5=A4=84=E7=90=86=E5=A2=9E=E5=8A=A0=E7=99=BD?= =?UTF-8?q?=E8=BE=B9=E6=A1=86=20docs=EF=BC=88=E6=96=87=E6=A1=A3=E5=8F=98?= =?UTF-8?q?=E6=9B=B4=EF=BC=89:=20refactor=EF=BC=88=E9=87=8D=E6=9E=84?= =?UTF-8?q?=EF=BC=89:=20test(=E5=A2=9E=E5=8A=A0=E6=B5=8B=E8=AF=95):?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/service/design_fast/utils/design_ensemble.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/app/service/design_fast/utils/design_ensemble.py b/app/service/design_fast/utils/design_ensemble.py index bfc50c6..8eef4f2 100644 --- a/app/service/design_fast/utils/design_ensemble.py +++ b/app/service/design_fast/utils/design_ensemble.py @@ -25,6 +25,7 @@ from app.core.config import * def keypoint_preprocess(img_path): img = mmcv.imread(img_path) + img = cv2.copyMakeBorder(img, 25, 25, 25, 25, cv2.BORDER_CONSTANT, value=[255, 255, 255]) img_scale = (256, 256) h, w = img.shape[:2] img = cv2.resize(img, img_scale) @@ -62,7 +63,11 @@ def keypoint_postprocess(output, scale_factor): scale_matrix = np.diag(scale_factor) nan = np.isinf(scale_matrix) scale_matrix[nan] = 0 - return np.ceil(np.dot(segment_result, scale_matrix) * 4) + # 应用缩放因子 + scaled_result = np.ceil(np.dot(segment_result, scale_matrix) * 4) + # 补偿边框偏移 + compensated_result = scaled_result - 25 + return compensated_result """ From 07e09dc99d58ce9d8c816561da9d166f39e5237e Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Mon, 3 Feb 2025 14:33:06 +0800 Subject: [PATCH 34/37] =?UTF-8?q?feat=EF=BC=88=E6=96=B0=E5=8A=9F=E8=83=BD?= =?UTF-8?q?=EF=BC=89:=20fix=EF=BC=88=E4=BF=AE=E5=A4=8Dbug=EF=BC=89:=20docs?= 
=?UTF-8?q?=EF=BC=88=E6=96=87=E6=A1=A3=E5=8F=98=E6=9B=B4=EF=BC=89:=20refac?= =?UTF-8?q?tor=EF=BC=88=E9=87=8D=E6=9E=84=EF=BC=89:=20test(=E5=A2=9E?= =?UTF-8?q?=E5=8A=A0=E6=B5=8B=E8=AF=95):=20=E6=97=A7=E7=89=88product=20?= =?UTF-8?q?=E6=B5=8B=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../service_generate_product_image.py | 246 +++++++++++++++--- 1 file changed, 215 insertions(+), 31 deletions(-) diff --git a/app/service/generate_image/service_generate_product_image.py b/app/service/generate_image/service_generate_product_image.py index 1c20a13..681e2b5 100644 --- a/app/service/generate_image/service_generate_product_image.py +++ b/app/service/generate_image/service_generate_product_image.py @@ -1,4 +1,196 @@ -#!/usr/bin/env python +# #!/usr/bin/env python +# # -*- coding: UTF-8 -*- +# """ +# @Project :trinity_client +# @File :service_att_recognition.py +# @Author :周成融 +# @Date :2023/7/26 12:01:05 +# @detail : +# """ +# import json +# import logging +# import time +# +# import cv2 +# import numpy as np +# import redis +# import tritonclient.grpc as grpcclient +# from PIL import Image +# from tritonclient.utils import np_to_triton_dtype +# +# from app.core.config import * +# from app.schemas.generate_image import GenerateProductImageModel +# from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image +# from app.service.utils.oss_client import oss_get_image +# +# logger = logging.getLogger() +# +# +# class GenerateProductImage: +# def __init__(self, request_data): +# if DEBUG is False: +# self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) +# self.channel = self.connection.channel() +# # self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) +# # self.channel = self.connection.channel() +# # self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) +# self.grpc_client = grpcclient.InferenceServerClient(url=GPI_MODEL_URL) +# self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) +# self.category = "product_image" +# self.image_strength = request_data.image_strength +# self.batch_size = 1 +# self.product_type = request_data.product_type +# self.prompt = request_data.prompt +# self.image, self.image_size, self.left, self.top = pre_processing_image(request_data.image_url) +# self.tasks_id = request_data.tasks_id +# self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:] +# self.gen_product_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': ''} +# self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data)) +# self.redis_client.expire(self.tasks_id, 600) +# +# def callback(self, result, error): +# if error: +# self.gen_product_data['status'] = "FAILURE" +# self.gen_product_data['message'] = str(error) +# self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data)) +# else: +# # pil图像转成numpy数组 +# image = result.as_numpy("generated_inpaint_image") +# image_result = Image.fromarray(np.squeeze(image.astype(np.uint8))).resize(self.image_size) +# cropped_image = post_processing_image(image_result, self.left, self.top) +# image_url = upload_SDXL_image(cropped_image, user_id=self.user_id, category=f"{self.category}", file_name=f"{self.tasks_id}.png") +# self.gen_product_data['status'] = "SUCCESS" +# self.gen_product_data['message'] = "success" +# self.gen_product_data['image_url'] = 
str(image_url) +# self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data)) +# +# def read_tasks_status(self): +# status_data = self.redis_client.get(self.tasks_id) +# return json.loads(status_data), status_data +# +# def get_result(self): +# try: +# prompts = [self.prompt] * self.batch_size +# self.image = cv2.cvtColor(self.image, cv2.COLOR_BGR2RGB) +# self.image = cv2.resize(self.image, (1024, 1024)) +# images = [self.image.astype(np.uint8)] * self.batch_size +# +# if self.product_type == "single": +# text_obj = np.array(prompts, dtype="object").reshape(-1, 1) +# image_obj = np.array(images, dtype=np.uint8).reshape((-1, 1024, 1024, 3)) +# image_strength_obj = np.array(self.image_strength, dtype=np.float32).reshape(-1, 1) +# else: +# text_obj = np.array(prompts, dtype="object").reshape((1)) +# image_obj = np.array(images, dtype=np.uint8).reshape((1024, 1024, 3)) +# image_strength_obj = np.array(self.image_strength, dtype=np.float32).reshape((1)) +# +# # 假设 prompts、images 和 self.image_strength 已经定义 +# +# input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype)) +# input_image = grpcclient.InferInput("input_image", image_obj.shape, "UINT8") +# input_image_strength = grpcclient.InferInput("image_strength", image_strength_obj.shape, np_to_triton_dtype(image_strength_obj.dtype)) +# +# input_text.set_data_from_numpy(text_obj) +# input_image.set_data_from_numpy(image_obj) +# input_image_strength.set_data_from_numpy(image_strength_obj) +# +# inputs = [input_text, input_image, input_image_strength] +# +# if self.product_type == "single": +# ctx = self.grpc_client.async_infer(model_name="stable_diffusion_xl_cnet_inpaint", inputs=inputs, callback=self.callback) +# else: +# ctx = self.grpc_client.async_infer(model_name=GPI_MODEL_NAME_OVERALL, inputs=inputs, callback=self.callback) +# +# time_out = 600 +# while time_out > 0: +# gen_product_data, _ = self.read_tasks_status() +# if gen_product_data['status'] in ["REVOKED", "FAILURE"]: +# ctx.cancel() +# break +# elif gen_product_data['status'] == "SUCCESS": +# break +# time_out -= 1 +# time.sleep(0.1) +# gen_product_data, _ = self.read_tasks_status() +# return gen_product_data +# except Exception as e: +# self.gen_product_data['status'] = "FAILURE" +# self.gen_product_data['message'] = str(e) +# self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data)) +# raise Exception(str(e)) +# finally: +# dict_gen_product_data, str_gen_product_data = self.read_tasks_status() +# if DEBUG is False: +# self.channel.basic_publish(exchange='', routing_key=GPI_RABBITMQ_QUEUES, body=str_gen_product_data) +# logger.info(f" [x] Sent to: {GPI_RABBITMQ_QUEUES} data:@@@@ {json.dumps(dict_gen_product_data, indent=4)}") +# +# +# def infer_cancel(tasks_id): +# redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) +# data = {'tasks_id': tasks_id, 'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'} +# gen_product_data = json.dumps(data) +# redis_client.set(tasks_id, gen_product_data) +# return data +# +# +# def pre_processing_image(image_url): +# image = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url[image_url.find('/') + 1:], data_type="PIL") +# # resize 原图至1024*1024 +# image = image.resize((int(1024 / image.height * image.width), 1024)) +# +# # 原始图片的尺寸 +# width, height = image.size +# +# new_height, new_width = 1024, 1024 +# # 创建一个新的画布,大小为添加 padding 后的尺寸,并设置为白色背景 +# pad_image = Image.new('RGBA', (new_width, new_height), (0, 0, 0, 0)) 
+# +# # 将原始图片粘贴到新的画布中心 +# left = (new_width - width) // 2 +# top = (new_height - height) // 2 +# pad_image.paste(image, (left, top)) +# +# # 将画布 resize 成宽度 1024,长度 1024 +# resized_image = pad_image.resize((1024, 1024)) +# image_size = (1024, 1024) +# +# if resized_image.mode in ('RGBA', 'LA') or (resized_image.mode == 'P' and 'transparency' in resized_image.info): +# # 创建白色背景 +# background = Image.new("RGB", image_size, (255, 255, 255)) +# # 将图片粘贴到白色背景上 +# background.paste(resized_image, mask=resized_image.split()[3]) +# image = np.array(background) +# image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) +# return image, image_size, left, top +# +# +# def post_processing_image(image, left, top): +# resized_image = image.resize((int(image.width * (768 / image.height)), 768)) +# # 计算裁剪的坐标 +# left = (resized_image.width - 512) // 2 +# upper = 0 +# right = left + 512 +# lower = 768 +# +# # 进行裁剪 +# cropped_image = resized_image.crop((left, upper, right, lower)) +# return cropped_image +# +# +# if __name__ == '__main__': +# rd = GenerateProductImageModel( +# tasks_id="123-89", +# # prompt="", +# image_strength=0.7, +# prompt="The best quality, masterpiece,outwear, 8K realistic, HUD", +# image_url="aida-results/result_53381ada-ac64-11ef-ae9d-0242ac150002.png", +# product_type="overall" +# ) +# server = GenerateProductImage(rd) +# print(server.get_result()) + +# 旧版product +# !/usr/bin/env python # -*- coding: UTF-8 -*- """ @Project :trinity_client @@ -34,14 +226,14 @@ class GenerateProductImage: # self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) # self.channel = self.connection.channel() # self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) - self.grpc_client = grpcclient.InferenceServerClient(url=GPI_MODEL_URL) + self.grpc_client = grpcclient.InferenceServerClient(url="10.1.1.243:18001") self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) self.category = "product_image" self.image_strength = request_data.image_strength self.batch_size = 1 self.product_type = request_data.product_type self.prompt = request_data.prompt - self.image, self.image_size, self.left, self.top = pre_processing_image(request_data.image_url) + self.image = pre_processing_image(request_data.image_url) self.tasks_id = request_data.tasks_id self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:] self.gen_product_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': ''} @@ -56,9 +248,9 @@ class GenerateProductImage: else: # pil图像转成numpy数组 image = result.as_numpy("generated_inpaint_image") - image_result = Image.fromarray(np.squeeze(image.astype(np.uint8))).resize(self.image_size) - cropped_image = post_processing_image(image_result, self.left, self.top) - image_url = upload_SDXL_image(cropped_image, user_id=self.user_id, category=f"{self.category}", file_name=f"{self.tasks_id}.png") + image_result = Image.fromarray(np.squeeze(image.astype(np.uint8))) + # cropped_image = post_processing_image(image_result, self.left, self.top) + image_url = upload_SDXL_image(image_result, user_id=self.user_id, category=f"{self.category}", file_name=f"{self.tasks_id}.png") self.gen_product_data['status'] = "SUCCESS" self.gen_product_data['message'] = "success" self.gen_product_data['image_url'] = str(image_url) @@ -72,7 +264,7 @@ class GenerateProductImage: try: prompts = [self.prompt] * self.batch_size self.image = cv2.cvtColor(self.image, cv2.COLOR_BGR2RGB) - 
self.image = cv2.resize(self.image, (1024, 1024)) + # self.image = cv2.resize(self.image, (1024, 1024)) images = [self.image.astype(np.uint8)] * self.batch_size if self.product_type == "single": @@ -81,7 +273,7 @@ class GenerateProductImage: image_strength_obj = np.array(self.image_strength, dtype=np.float32).reshape(-1, 1) else: text_obj = np.array(prompts, dtype="object").reshape((1)) - image_obj = np.array(images, dtype=np.uint8).reshape((1024, 1024, 3)) + image_obj = np.array(images, dtype=np.uint8).reshape((768, 512, 3)) image_strength_obj = np.array(self.image_strength, dtype=np.float32).reshape((1)) # 假设 prompts、images 和 self.image_strength 已经定义 @@ -99,7 +291,7 @@ class GenerateProductImage: if self.product_type == "single": ctx = self.grpc_client.async_infer(model_name="stable_diffusion_xl_cnet_inpaint", inputs=inputs, callback=self.callback) else: - ctx = self.grpc_client.async_infer(model_name=GPI_MODEL_NAME_OVERALL, inputs=inputs, callback=self.callback) + ctx = self.grpc_client.async_infer(model_name="diffusion_ensemble_all", inputs=inputs, callback=self.callback) time_out = 600 while time_out > 0: @@ -135,33 +327,25 @@ def infer_cancel(tasks_id): def pre_processing_image(image_url): image = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url[image_url.find('/') + 1:], data_type="PIL") - # resize 原图至1024*1024 - image = image.resize((int(1024 / image.height * image.width), 1024)) - - # 原始图片的尺寸 + # 调整图片高度为768像素,保持宽高比 width, height = image.size + new_height = 768 + new_width = int(width * (new_height / height)) + resized_image = image.resize((new_width, new_height)) - new_height, new_width = 1024, 1024 - # 创建一个新的画布,大小为添加 padding 后的尺寸,并设置为白色背景 - pad_image = Image.new('RGBA', (new_width, new_height), (0, 0, 0, 0)) + # 创建一个512x768的透明图片 + result_image = Image.new("RGBA", (512, 768), (0, 0, 0, 0)) - # 将原始图片粘贴到新的画布中心 - left = (new_width - width) // 2 - top = (new_height - height) // 2 - pad_image.paste(image, (left, top)) + # 计算需要粘贴的位置,使图片居中 + x_offset = (512 - new_width) // 2 + y_offset = 0 - # 将画布 resize 成宽度 1024,长度 1024 - resized_image = pad_image.resize((1024, 1024)) - image_size = (1024, 1024) + # 将调整大小后的图片粘贴到透明图片上 + result_image.paste(resized_image, (x_offset, y_offset)) - if resized_image.mode in ('RGBA', 'LA') or (resized_image.mode == 'P' and 'transparency' in resized_image.info): - # 创建白色背景 - background = Image.new("RGB", image_size, (255, 255, 255)) - # 将图片粘贴到白色背景上 - background.paste(resized_image, mask=resized_image.split()[3]) - image = np.array(background) - image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) - return image, image_size, left, top + image = np.array(result_image) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + return image def post_processing_image(image, left, top): From e1c00bbd669a00028003e0f0388786714399ae64 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Mon, 3 Feb 2025 19:06:49 +0800 Subject: [PATCH 35/37] =?UTF-8?q?feat=EF=BC=88=E6=96=B0=E5=8A=9F=E8=83=BD?= =?UTF-8?q?=EF=BC=89:=20fix=EF=BC=88=E4=BF=AE=E5=A4=8Dbug=EF=BC=89:=20docs?= =?UTF-8?q?=EF=BC=88=E6=96=87=E6=A1=A3=E5=8F=98=E6=9B=B4=EF=BC=89:=20refac?= =?UTF-8?q?tor=EF=BC=88=E9=87=8D=E6=9E=84=EF=BC=89:=20test(=E5=A2=9E?= =?UTF-8?q?=E5=8A=A0=E6=B5=8B=E8=AF=95):=20=E6=97=A7=E7=89=88product=20?= =?UTF-8?q?=E6=B5=8B=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../generate_image/service_generate_product_image.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git 
a/app/service/generate_image/service_generate_product_image.py b/app/service/generate_image/service_generate_product_image.py index 681e2b5..0507953 100644 --- a/app/service/generate_image/service_generate_product_image.py +++ b/app/service/generate_image/service_generate_product_image.py @@ -334,17 +334,18 @@ def pre_processing_image(image_url): resized_image = image.resize((new_width, new_height)) # 创建一个512x768的透明图片 - result_image = Image.new("RGBA", (512, 768), (0, 0, 0, 0)) + result_image = Image.new("RGBA", (512, 768), (255, 255, 255, 255)) # 计算需要粘贴的位置,使图片居中 x_offset = (512 - new_width) // 2 y_offset = 0 # 将调整大小后的图片粘贴到透明图片上 - result_image.paste(resized_image, (x_offset, y_offset)) + result_image.paste(resized_image, (x_offset, y_offset), mask=resized_image.split()[3]) image = np.array(result_image) - image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + + # image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA) return image @@ -366,8 +367,8 @@ if __name__ == '__main__': tasks_id="123-89", # prompt="", image_strength=0.7, - prompt="The best quality, masterpiece,outwear, 8K realistic, HUD", - image_url="aida-results/result_53381ada-ac64-11ef-ae9d-0242ac150002.png", + prompt="The best quality, masterpiece, real image.,high quality clothing details,8K realistic,HDR", + image_url="aida-users/11633/toProductImageElement/46166c36-c584-4e0f-b9fe-50615ec03ef3.png", product_type="overall" ) server = GenerateProductImage(rd) From 74008c4586d6d486584d4f409f86672470bbb970 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Mon, 3 Feb 2025 19:13:48 +0800 Subject: [PATCH 36/37] =?UTF-8?q?feat=EF=BC=88=E6=96=B0=E5=8A=9F=E8=83=BD?= =?UTF-8?q?=EF=BC=89:=20fix=EF=BC=88=E4=BF=AE=E5=A4=8Dbug=EF=BC=89:=20docs?= =?UTF-8?q?=EF=BC=88=E6=96=87=E6=A1=A3=E5=8F=98=E6=9B=B4=EF=BC=89:=20refac?= =?UTF-8?q?tor=EF=BC=88=E9=87=8D=E6=9E=84=EF=BC=89:=20test(=E5=A2=9E?= =?UTF-8?q?=E5=8A=A0=E6=B5=8B=E8=AF=95):=20=E6=97=A7=E7=89=88product=20?= =?UTF-8?q?=E6=B5=8B=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/service/generate_image/service_generate_product_image.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/app/service/generate_image/service_generate_product_image.py b/app/service/generate_image/service_generate_product_image.py index 0507953..22f7306 100644 --- a/app/service/generate_image/service_generate_product_image.py +++ b/app/service/generate_image/service_generate_product_image.py @@ -263,7 +263,8 @@ class GenerateProductImage: def get_result(self): try: prompts = [self.prompt] * self.batch_size - self.image = cv2.cvtColor(self.image, cv2.COLOR_BGR2RGB) + + self.image = cv2.cvtColor(self.image, cv2.COLOR_RGBA2RGB) # self.image = cv2.resize(self.image, (1024, 1024)) images = [self.image.astype(np.uint8)] * self.batch_size From 48ae1cfb75ec5329dc22c0aafc46c2fcf08f799c Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Mon, 3 Feb 2025 19:36:56 +0800 Subject: [PATCH 37/37] =?UTF-8?q?feat=EF=BC=88=E6=96=B0=E5=8A=9F=E8=83=BD?= =?UTF-8?q?=EF=BC=89:=20fix=EF=BC=88=E4=BF=AE=E5=A4=8Dbug=EF=BC=89:=20docs?= =?UTF-8?q?=EF=BC=88=E6=96=87=E6=A1=A3=E5=8F=98=E6=9B=B4=EF=BC=89:=20refac?= =?UTF-8?q?tor=EF=BC=88=E9=87=8D=E6=9E=84=EF=BC=89:=20test(=E5=A2=9E?= =?UTF-8?q?=E5=8A=A0=E6=B5=8B=E8=AF=95):=20=E6=97=A7=E7=89=88product=20?= =?UTF-8?q?=E6=B5=8B=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../service_generate_product_image.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) 
diff --git a/app/service/generate_image/service_generate_product_image.py b/app/service/generate_image/service_generate_product_image.py index 22f7306..287a983 100644 --- a/app/service/generate_image/service_generate_product_image.py +++ b/app/service/generate_image/service_generate_product_image.py @@ -247,7 +247,10 @@ class GenerateProductImage: self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data)) else: # pil图像转成numpy数组 - image = result.as_numpy("generated_inpaint_image") + if self.product_type == "single": + image = result.as_numpy("generated_cnet_image") + else: + image = result.as_numpy("generated_inpaint_image") image_result = Image.fromarray(np.squeeze(image.astype(np.uint8))) # cropped_image = post_processing_image(image_result, self.left, self.top) image_url = upload_SDXL_image(image_result, user_id=self.user_id, category=f"{self.category}", file_name=f"{self.tasks_id}.png") @@ -269,9 +272,9 @@ class GenerateProductImage: images = [self.image.astype(np.uint8)] * self.batch_size if self.product_type == "single": - text_obj = np.array(prompts, dtype="object").reshape(-1, 1) - image_obj = np.array(images, dtype=np.uint8).reshape((-1, 1024, 1024, 3)) - image_strength_obj = np.array(self.image_strength, dtype=np.float32).reshape(-1, 1) + text_obj = np.array(prompts, dtype="object").reshape((-1, 1)) + image_obj = np.array(images, dtype=np.uint8).reshape((-1, 768, 512, 3)) + image_strength_obj = np.array(self.image_strength, dtype=np.float32).reshape((-1, 1)) else: text_obj = np.array(prompts, dtype="object").reshape((1)) image_obj = np.array(images, dtype=np.uint8).reshape((768, 512, 3)) @@ -290,7 +293,7 @@ class GenerateProductImage: inputs = [input_text, input_image, input_image_strength] if self.product_type == "single": - ctx = self.grpc_client.async_infer(model_name="stable_diffusion_xl_cnet_inpaint", inputs=inputs, callback=self.callback) + ctx = self.grpc_client.async_infer(model_name="stable_diffusion_1_5_cnet", inputs=inputs, callback=self.callback) else: ctx = self.grpc_client.async_infer(model_name="diffusion_ensemble_all", inputs=inputs, callback=self.callback) @@ -369,8 +372,8 @@ if __name__ == '__main__': # prompt="", image_strength=0.7, prompt="The best quality, masterpiece, real image.,high quality clothing details,8K realistic,HDR", - image_url="aida-users/11633/toProductImageElement/46166c36-c584-4e0f-b9fe-50615ec03ef3.png", - product_type="overall" + image_url="aida-results/result_40c7924e-e220-11ef-8ea2-0242ac150003.png", + product_type="single" ) server = GenerateProductImage(rd) print(server.get_result())
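
Note on the padding changes: the segmentation and keypoint patches above (PATCH 29, 30 and 33) all follow the same pattern — pad the model input with a 25-pixel white border, run inference, then undo the padding in post-processing (crop the border off the class map, subtract 25 from the keypoint coordinates). Below is a minimal standalone sketch of that round-trip; pad_for_inference / unpad_seg / unpad_keypoints are hypothetical helper names standing in for the actual Triton pre/post-processing in design_ensemble.py, not functions from the repository.

import cv2
import numpy as np

PAD = 25  # white border added before inference, removed afterwards

def pad_for_inference(img):
    # Add a 25px white border on every side so garments touching the image
    # edge still get clean segmentation / keypoint predictions.
    return cv2.copyMakeBorder(img, PAD, PAD, PAD, PAD,
                              cv2.BORDER_CONSTANT, value=[255, 255, 255])

def unpad_seg(class_map, ori_shape):
    # Resize the predicted class map back to the padded original size,
    # then crop the 25px border off again (replaces the old bilinear
    # interpolation step; nearest-neighbour keeps the class labels intact).
    h, w = ori_shape
    resized = cv2.resize(class_map.astype(np.uint8),
                         (w + 2 * PAD, h + 2 * PAD),
                         interpolation=cv2.INTER_NEAREST)
    return resized[PAD:-PAD, PAD:-PAD]

def unpad_keypoints(points):
    # Keypoints were predicted on the padded image, so shift them back.
    return np.asarray(points) - PAD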
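
Note on the translation pre-check: PATCH 31 short-circuits translation when the input is already English — get_language() asks the LLM to name the language, and get_translation_from_llama3() only calls the local "translator" model when the answer is not English. A hedged sketch of that control flow follows; detect_language and translate_to_english are stand-ins injected as parameters, not the real service calls.

def translate_if_needed(text, detect_language, translate_to_english):
    # Ask the model which language the text is in; skip the translation
    # round-trip when it already reports English.
    language = detect_language(text)
    if "English" in language:
        return text
    return translate_to_english(text)

if __name__ == "__main__":
    # Trivial stand-ins, just to show the flow.
    detect = lambda t: "English" if t.isascii() else "Chinese"
    translate = lambda t: f"<translated: {t}>"
    print(translate_if_needed("fire pattern", detect, translate))
    print(translate_if_needed("火焰", detect, translate))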
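
Note on the old product-image path: PATCHES 34-37 restore the legacy pipeline — the source image is scaled to 768 px tall, pasted centred onto a 512x768 white canvas using its own alpha channel as the paste mask, and converted to RGB before being packed into the Triton input tensor. A sketch of that pre-processing is below, assuming a PIL image has already been fetched from OSS; preprocess_product_image is an illustrative name, not the function in service_generate_product_image.py.

import numpy as np
from PIL import Image

def preprocess_product_image(image):
    # Scale to 768px tall, keeping the aspect ratio.
    new_height = 768
    new_width = int(image.width * (new_height / image.height))
    resized = image.resize((new_width, new_height)).convert("RGBA")

    # White 512x768 canvas; the resized image is pasted centred horizontally,
    # with its own alpha channel as the mask so transparent areas stay white.
    canvas = Image.new("RGBA", (512, 768), (255, 255, 255, 255))
    x_offset = (512 - new_width) // 2
    canvas.paste(resized, (x_offset, 0), mask=resized.split()[3])

    # Drop the alpha channel before building the UINT8 input tensor.
    return np.array(canvas.convert("RGB"), dtype=np.uint8)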