From ed908d04722cd363ac1fdf0916058b576e91527c Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Fri, 19 Jul 2024 15:10:28 +0800 Subject: [PATCH] =?UTF-8?q?feat=20=20=E6=B5=8B=E8=AF=95=E6=89=B9=E9=87=8F?= =?UTF-8?q?=E4=B8=8A=E4=BC=A0=E5=9B=BE=E7=89=87=20fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 1 + app/api/api_design.py | 17 +++--- app/core/config.py | 4 +- app/service/design/items/clothing.py | 5 +- .../design/items/pipelines/keypoints.py | 3 +- app/service/design/items/pipelines/loading.py | 2 + .../design/items/pipelines/painting.py | 6 +- app/service/design/items/pipelines/scale.py | 3 +- .../design/items/pipelines/segmentation.py | 34 ++++++++++- app/service/design/items/pipelines/split.py | 4 +- app/service/design/service.py | 54 +++++++++++++++-- app/service/design/utils/upload_image.py | 60 +++++++++++-------- app/service/utils/decorator.py | 15 ++++- 13 files changed, 160 insertions(+), 48 deletions(-) diff --git a/.gitignore b/.gitignore index 1bf82fb..87a4934 100644 --- a/.gitignore +++ b/.gitignore @@ -120,6 +120,7 @@ dmypy.json #runtime produce test +seg_cache logs seg_result/ seg_result diff --git a/app/api/api_design.py b/app/api/api_design.py index 5ce6096..d4537c1 100644 --- a/app/api/api_design.py +++ b/app/api/api_design.py @@ -136,13 +136,16 @@ def design(request_data: DesignModel): "process_id": "89" } """ - try: - logger.info(f"design request item is : @@@@@@:{json.dumps(request_data.dict())}") - data = generate(request_data=request_data) - logger.info(f"design response @@@@@@:{json.dumps(data)}") - except Exception as e: - logger.warning(f"design Run Exception @@@@@@:{e}") - raise HTTPException(status_code=404, detail=str(e)) + logger.info(f"design request item is : @@@@@@:{json.dumps(request_data.dict())}") + data = generate(request_data=request_data) + logger.info(f"design response @@@@@@:{json.dumps(data)}") + # try: + # logger.info(f"design request item is : @@@@@@:{json.dumps(request_data.dict())}") + # data = generate(request_data=request_data) + # logger.info(f"design response @@@@@@:{json.dumps(data)}") + # except Exception as e: + # logger.warning(f"design Run Exception @@@@@@:{e}") + # raise HTTPException(status_code=404, detail=str(e)) return ResponseModel(data=data) diff --git a/app/core/config.py b/app/core/config.py index 8b7e7e8..0e32724 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -24,11 +24,11 @@ DEBUG = False if DEBUG: LOGS_PATH = "logs/" CATEGORY_PATH = "service/attribute/config/descriptor/category/category_dis.csv" - # FACE_CLASSIFIER = "service/generate_image/utils/haarcascade_frontalface_alt.xml" + SEG_CACHE_PATH = "../seg_cache/" else: LOGS_PATH = "app/logs/" CATEGORY_PATH = "app/service/attribute/config/descriptor/category/category_dis.csv" - # FACE_CLASSIFIER = 'app/service/generate_image/utils/haarcascade_frontalface_alt.xml' + SEG_CACHE_PATH = "/seg_cache/" # RABBITMQ_ENV = "" # 生产环境 RABBITMQ_ENV = "-dev" # 开发环境 diff --git a/app/service/design/items/clothing.py b/app/service/design/items/clothing.py index f9f9561..7dd845b 100644 --- a/app/service/design/items/clothing.py +++ b/app/service/design/items/clothing.py @@ -37,7 +37,8 @@ class Clothing(object): resize_scale=self.result["resize_scale"], mask=cv2.resize(self.result['mask'], self.result["front_image"].size), gradient_string=self.result['gradient_string'] if 'gradient_string' in self.result.keys() else "", - pattern_image_url=self.result['pattern_image_url'] + pattern_image_url=self.result['pattern_image_url'], + pattern_image=self.result['pattern_image'] ) layer.insert(front_layer) @@ -54,7 +55,7 @@ class Clothing(object): resize_scale=self.result["resize_scale"], mask=cv2.resize(self.result['mask'], self.result["front_image"].size), gradient_string=self.result['gradient_string'] if 'gradient_string' in self.result.keys() else "", - pattern_image_url=self.result['pattern_image_url'] + pattern_image_url=self.result['pattern_image_url'], ) layer.insert(back_layer) diff --git a/app/service/design/items/pipelines/keypoints.py b/app/service/design/items/pipelines/keypoints.py index 1f53ced..1a264d6 100644 --- a/app/service/design/items/pipelines/keypoints.py +++ b/app/service/design/items/pipelines/keypoints.py @@ -5,6 +5,7 @@ import numpy as np from pymilvus import MilvusClient from app.core.config import * +from app.service.utils.decorator import RunTime, ClassCallRunTime from ..builder import PIPELINES from ...utils.design_ensemble import get_keypoint_result @@ -27,7 +28,7 @@ class KeypointDetection(object): # self.client.close() # print(f"client close time : {time.time() - start_time}") - # @ RunTime + @ ClassCallRunTime def __call__(self, result): # logging.info("KeypointDetection run ") if result['name'] in ['blouse', 'skirt', 'dress', 'outwear', 'trousers', 'tops', 'bottoms']: # 查询是否有数据 且类别相同 相同则直接读 不同则推理后更新 diff --git a/app/service/design/items/pipelines/loading.py b/app/service/design/items/pipelines/loading.py index d792646..ad5aec2 100644 --- a/app/service/design/items/pipelines/loading.py +++ b/app/service/design/items/pipelines/loading.py @@ -1,5 +1,6 @@ import cv2 +from app.service.utils.decorator import RunTime, ClassCallRunTime from app.service.utils.oss_client import oss_get_image from ..builder import PIPELINES @@ -12,6 +13,7 @@ class LoadImageFromFile(object): self.print_dict = print_dict # self.minio_client = Minio(f"{MINIO_URL}", access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) + @ClassCallRunTime def __call__(self, result): result['image'], result['pre_mask'] = self.read_image(self.path) result['gray'] = cv2.cvtColor(result['image'], cv2.COLOR_BGR2GRAY) diff --git a/app/service/design/items/pipelines/painting.py b/app/service/design/items/pipelines/painting.py index 0fd2897..8e4e524 100644 --- a/app/service/design/items/pipelines/painting.py +++ b/app/service/design/items/pipelines/painting.py @@ -4,6 +4,7 @@ import cv2 import numpy as np from PIL import Image +from app.service.utils.decorator import RunTime, ClassCallRunTime from app.service.utils.oss_client import oss_get_image from ..builder import PIPELINES @@ -13,7 +14,7 @@ class Painting(object): def __init__(self, painting_flag=True): self.painting_flag = painting_flag - # @ RunTime + @ClassCallRunTime def __call__(self, result): if result['name'] not in ['hairstyle', 'earring'] and self.painting_flag and result['color'] != 'none': dim_image_h, dim_image_w = result['image'].shape[0:2] @@ -86,7 +87,7 @@ class PrintPainting(object): def __init__(self, print_flag=True): self.print_flag = print_flag - # @ RunTime + @ClassCallRunTime def __call__(self, result): single_print = result['print']['single'] overall_print = result['print']['overall'] @@ -236,7 +237,6 @@ class PrintPainting(object): print_background = cv2.cvtColor(np.array(source_image_pil), cv2.COLOR_RGBA2BGR) mask_background = cv2.cvtColor(np.array(source_image_pil_mask), cv2.COLOR_RGBA2BGR) - print(1) else: mask = self.get_mask_inv(image) mask = np.expand_dims(mask, axis=2) diff --git a/app/service/design/items/pipelines/scale.py b/app/service/design/items/pipelines/scale.py index d101530..43604cb 100644 --- a/app/service/design/items/pipelines/scale.py +++ b/app/service/design/items/pipelines/scale.py @@ -2,6 +2,7 @@ import math import cv2 +from app.service.utils.decorator import ClassCallRunTime from ..builder import PIPELINES @@ -10,7 +11,7 @@ class Scaling(object): def __init__(self): pass - # @ RunTime + @ClassCallRunTime def __call__(self, result): if result['keypoint'] in ['waistband', 'shoulder', 'head_point']: # milvus_db_keypoint_cache diff --git a/app/service/design/items/pipelines/segmentation.py b/app/service/design/items/pipelines/segmentation.py index d9f8ac0..4e6e0d0 100644 --- a/app/service/design/items/pipelines/segmentation.py +++ b/app/service/design/items/pipelines/segmentation.py @@ -1,3 +1,9 @@ +import os + +import numpy as np + +from app.core.config import SEG_CACHE_PATH +from app.service.utils.decorator import ClassCallRunTime from ..builder import PIPELINES from ...utils.design_ensemble import get_seg_result @@ -9,6 +15,32 @@ class Segmentation(object): self.device = device self.debug = debug + @ClassCallRunTime def __call__(self, result): - result['seg_result'] = get_seg_result(result["image_id"], result['image']) + _, seg_result = self.load_seg_result(result["image_id"]) + if not _: + result['seg_result'] = get_seg_result(result["image_id"], result['image']) + self.save_seg_result(result['seg_result'][0], result['image_id']) return result + + @staticmethod + def save_seg_result(seg_result, image_id): + file_path = f"{SEG_CACHE_PATH}{image_id}.npy" + try: + np.save(file_path, seg_result) + print("保存成功", os.path.abspath(file_path)) + except Exception as e: + print(f"保存失败: {e}") + + @staticmethod + def load_seg_result(image_id): + file_path = f"{SEG_CACHE_PATH}{image_id}.npy" + try: + seg_result = np.load(file_path) + return True, seg_result + except FileNotFoundError: + print("文件不存在") + return False, None + except Exception as e: + print(f"加载失败: {e}") + return False, None diff --git a/app/service/design/items/pipelines/split.py b/app/service/design/items/pipelines/split.py index efa20e4..dd9becb 100644 --- a/app/service/design/items/pipelines/split.py +++ b/app/service/design/items/pipelines/split.py @@ -5,6 +5,7 @@ import numpy as np from PIL import Image from cv2 import cvtColor, COLOR_BGR2RGBA +from app.service.utils.decorator import ClassCallRunTime from app.service.utils.generate_uuid import generate_uuid from ..builder import PIPELINES from ...utils.conversion_image import rgb_to_rgba @@ -17,6 +18,7 @@ class Split(object): Split image into front and back layer according to the segmentation result """ + @ClassCallRunTime # KNet def __call__(self, result): try: @@ -66,7 +68,7 @@ class Split(object): # 创建中间图层 result_pattern_image_rgba = rgb_to_rgba(result['pattern_image'], result['mask']) result_pattern_image_pil = Image.fromarray(cvtColor(result_pattern_image_rgba, COLOR_BGR2RGBA)) - _, result['pattern_image_url'], _ = upload_png_mask(result_pattern_image_pil, f'{generate_uuid()}') + result['pattern'], result['pattern_image_url'], _ = upload_png_mask(result_pattern_image_pil, f'{generate_uuid()}') return result except Exception as e: logging.warning(f"split runtime exception : {e} image_id : {result['image_id']}") diff --git a/app/service/design/service.py b/app/service/design/service.py index 54cb45b..3d71326 100644 --- a/app/service/design/service.py +++ b/app/service/design/service.py @@ -1,4 +1,8 @@ import concurrent.futures +import io + +import cv2 +from PIL import Image from app.core.config import PRIORITY_DICT from app.service.design.core.layer import Layer @@ -6,6 +10,7 @@ from app.service.design.items import build_item from app.service.design.utils.redis_utils import Redis from app.service.design.utils.synthesis_item import synthesis, synthesis_single from app.service.utils.decorator import RunTime +from app.service.utils.oss_client import oss_upload_image def process_item(item, layers): @@ -43,6 +48,7 @@ def final_progress(process_id): @RunTime def generate(request_data): return_response = {} + return_png_mask = [] request_data = request_data.dict() assert "process_id" in request_data.keys(), "Need process_id parameters" @@ -55,14 +61,15 @@ def generate(request_data): # 获取处理结果 for future in concurrent.futures.as_completed(futures): obj = futures[future] - - result = future.result() - return_response[obj] = result + return_response[obj] = future.result()[0] + return_png_mask.extend(future.result()[1]) final_progress(process_id) + upload_results = process_images(return_png_mask) return return_response def process_object(cfg, process_id, total): + uploaded_images = [] basic_info = cfg.get('basic') items_response = { 'layers': [] @@ -83,6 +90,15 @@ def process_object(cfg, process_id, total): layers = sorted(layers.layer, key=lambda s: s.get("priority", float('inf'))) else: layers = sorted(layers.layer, key=lambda x: PRIORITY_DICT.get(x['name'], float('inf'))) + # 上传所有图片 + for layer in layers: + if 'image' in layer.keys() and layer['image'] is not None: + uploaded_images.append({'image_obj': layer['image'], 'image_url': layer['image_url']}) + if 'pattern_image' in layer.keys() and layer['pattern_image'] is not None: + uploaded_images.append({'image_obj': layer['pattern_image'], 'image_url': layer['pattern_image_url']}) + if 'mask' in layer.keys() and layer['mask'] is not None and layer['mask_url'] is not None: + uploaded_images.append({'image_obj': layer['mask'], 'image_url': layer['mask_url']}) + # 合成 items_response['synthesis_url'] = synthesis(layers, body_size) @@ -131,4 +147,34 @@ def process_object(cfg, process_id, total): items_response['synthesis_url'] = synthesis_single(item.result['front_image'], item.result['back_image']) break update_progress(process_id, total) - return items_response + return items_response, uploaded_images + + +@RunTime +def process_images(images): + with concurrent.futures.ThreadPoolExecutor() as executor: + results = list(executor.map(upload_images, images)) + # results = [] + # for image in images: + # results.append(upload_images(image)) + return results + + +@RunTime +def upload_images(image_obj): + bucket_name = image_obj['image_url'].split("/", 1)[0] + object_name = image_obj['image_url'].split("/", 1)[1] + if isinstance(image_obj['image_obj'], Image.Image): + image_data = io.BytesIO() + image_obj['image_obj'].save(image_data, format='PNG') + image_data.seek(0) + image_bytes = image_data.read() + req = oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes) + return image_obj['image_url'] + else: + mask_inverted = cv2.bitwise_not(image_obj['image_obj']) + # 将掩模的3通道转换为4通道,白色部分不透明,黑色部分透明 + rgba_image = cv2.cvtColor(mask_inverted, cv2.COLOR_BGR2BGRA) + rgba_image[rgba_image[:, :, 0] == 0] = [0, 0, 0, 0] + req = oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=cv2.imencode('.png', rgba_image)[1]) + return image_obj['image_url'] diff --git a/app/service/design/utils/upload_image.py b/app/service/design/utils/upload_image.py index 3571816..610c188 100644 --- a/app/service/design/utils/upload_image.py +++ b/app/service/design/utils/upload_image.py @@ -13,33 +13,43 @@ import logging import cv2 from app.core.config import * +from app.service.utils.decorator import RunTime from app.service.utils.oss_client import oss_upload_image # @RunTime -def upload_png_mask(front_image, object_name, mask=None): - try: - mask_url = None - if mask is not None: - mask_inverted = cv2.bitwise_not(mask) - # 将掩模的3通道转换为4通道,白色部分不透明,黑色部分透明 - rgba_image = cv2.cvtColor(mask_inverted, cv2.COLOR_BGR2BGRA) - rgba_image[rgba_image[:, :, 0] == 0] = [0, 0, 0, 0] - # image_bytes = io.BytesIO() - # image_bytes.write(cv2.imencode('.png', rgba_image)[1].tobytes()) - # image_bytes.seek(0) - # mask_url = f"{AIDA_CLOTHING}/{minio_client.put_object('aida-clothing', f'mask/mask_{object_name}.png', image_bytes, len(image_bytes.getvalue()), content_type='image/png').object_name}" - # oss upload #################### - req = oss_upload_image(bucket=AIDA_CLOTHING, object_name=f"mask/mask_{object_name}.png", image_bytes=cv2.imencode('.png', rgba_image)[1]) - mask_url = f"{AIDA_CLOTHING}/mask/mask_{object_name}.png" +# def upload_png_mask(front_image, object_name, mask=None): +# try: +# mask_url = None +# if mask is not None: +# mask_inverted = cv2.bitwise_not(mask) +# # 将掩模的3通道转换为4通道,白色部分不透明,黑色部分透明 +# rgba_image = cv2.cvtColor(mask_inverted, cv2.COLOR_BGR2BGRA) +# rgba_image[rgba_image[:, :, 0] == 0] = [0, 0, 0, 0] +# # image_bytes = io.BytesIO() +# # image_bytes.write(cv2.imencode('.png', rgba_image)[1].tobytes()) +# # image_bytes.seek(0) +# # mask_url = f"{AIDA_CLOTHING}/{minio_client.put_object('aida-clothing', f'mask/mask_{object_name}.png', image_bytes, len(image_bytes.getvalue()), content_type='image/png').object_name}" +# # oss upload #################### +# req = oss_upload_image(bucket=AIDA_CLOTHING, object_name=f"mask/mask_{object_name}.png", image_bytes=cv2.imencode('.png', rgba_image)[1]) +# mask_url = f"{AIDA_CLOTHING}/mask/mask_{object_name}.png" +# +# image_data = io.BytesIO() +# front_image.save(image_data, format='PNG') +# image_data.seek(0) +# image_bytes = image_data.read() +# # image_url = f"{AIDA_CLOTHING}/{minio_client.put_object('aida-clothing', f'image/image_{object_name}.png', io.BytesIO(image_bytes), len(image_bytes), content_type='image/png').object_name}" +# req = oss_upload_image(bucket=AIDA_CLOTHING, object_name=f"image/image_{object_name}.png", image_bytes=image_bytes) +# image_url = f"{AIDA_CLOTHING}/image/image_{object_name}.png" +# return front_image, image_url, mask_url +# except Exception as e: +# logging.warning(f"upload_png_mask runtime exception : {e}") - image_data = io.BytesIO() - front_image.save(image_data, format='PNG') - image_data.seek(0) - image_bytes = image_data.read() - # image_url = f"{AIDA_CLOTHING}/{minio_client.put_object('aida-clothing', f'image/image_{object_name}.png', io.BytesIO(image_bytes), len(image_bytes), content_type='image/png').object_name}" - req = oss_upload_image(bucket=AIDA_CLOTHING, object_name=f"image/image_{object_name}.png", image_bytes=image_bytes) - image_url = f"{AIDA_CLOTHING}/image/image_{object_name}.png" - return front_image, image_url, mask_url - except Exception as e: - logging.warning(f"upload_png_mask runtime exception : {e}") + +@RunTime +def upload_png_mask(front_image, object_name, mask=None): + mask_url = None + if mask is not None: + mask_url = f"{AIDA_CLOTHING}/mask/mask_{object_name}.png" + image_url = f"{AIDA_CLOTHING}/image/image_{object_name}.png" + return front_image, image_url, mask_url diff --git a/app/service/utils/decorator.py b/app/service/utils/decorator.py index 294b54b..fcf8666 100644 --- a/app/service/utils/decorator.py +++ b/app/service/utils/decorator.py @@ -1,5 +1,5 @@ -import time import logging +import time def RunTime(func): @@ -12,3 +12,16 @@ def RunTime(func): return res return wrapper + + +def ClassCallRunTime(func): + def wrapper(*args, **kwargs): + start_time = time.time() + result = func(*args, **kwargs) + end_time = time.time() + execution_time = end_time - start_time + class_name = args[0].__class__.__name__ # 获取类名 + print(f"class name: {class_name} , run time is : {execution_time} s") + return result + + return wrapper