From 0e5f1ae1fadd3781945501788f260b199b88fa83 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Wed, 8 Jan 2025 15:44:59 +0800 Subject: [PATCH 1/8] =?UTF-8?q?feat=EF=BC=88=E6=96=B0=E5=8A=9F=E8=83=BD?= =?UTF-8?q?=EF=BC=89:=20=20sketch=20=E5=A4=9A=E8=A7=86=E8=A7=92=E5=9B=BE?= =?UTF-8?q?=E7=94=9F=E6=88=90=E5=8A=9F=E8=83=BD=E6=8E=A5=E5=8F=A3=20fix?= =?UTF-8?q?=EF=BC=88=E4=BF=AE=E5=A4=8Dbug=EF=BC=89:=20docs=EF=BC=88?= =?UTF-8?q?=E6=96=87=E6=A1=A3=E5=8F=98=E6=9B=B4=EF=BC=89:=20refactor?= =?UTF-8?q?=EF=BC=88=E9=87=8D=E6=9E=84=EF=BC=89:=20test(=E5=A2=9E=E5=8A=A0?= =?UTF-8?q?=E6=B5=8B=E8=AF=95):?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/api_generate_image.py | 41 +++++- app/core/config.py | 6 + app/schemas/generate_image.py | 5 + .../service_generate_multi_view.py | 126 ++++++++++++++++++ app/service/utils/new_oss_client.py | 2 +- 5 files changed, 178 insertions(+), 2 deletions(-) create mode 100644 app/service/generate_image/service_generate_multi_view.py diff --git a/app/api/api_generate_image.py b/app/api/api_generate_image.py index 53790a3..a37bec3 100644 --- a/app/api/api_generate_image.py +++ b/app/api/api_generate_image.py @@ -3,9 +3,10 @@ import logging from fastapi import APIRouter, BackgroundTasks, HTTPException -from app.schemas.generate_image import GenerateImageModel, GenerateProductImageModel, GenerateSingleLogoImageModel, GenerateRelightImageModel +from app.schemas.generate_image import GenerateImageModel, GenerateProductImageModel, GenerateSingleLogoImageModel, GenerateRelightImageModel, GenerateMultiViewModel from app.schemas.response_template import ResponseModel from app.service.generate_image.service_generate_image import GenerateImage, infer_cancel as generate_image_infer_cancel +from app.service.generate_image.service_generate_multi_view import GenerateMultiView, infer_cancel as generate_multi_view_cancel from app.service.generate_image.service_generate_product_image import GenerateProductImage, infer_cancel as generate_product_image_cancel from app.service.generate_image.service_generate_relight_image import GenerateRelightImage, infer_cancel as generate_relight_image_cancel from app.service.generate_image.service_generate_single_logo import GenerateSingleLogoImage, infer_cancel as generate_single_logo_cancel @@ -61,6 +62,44 @@ def generate_image(tasks_id: str): return ResponseModel(data=data['data']) +'''multi view''' + + +@router.post("/generate_multi_view") +def generate_multi_view(request_item: GenerateMultiViewModel, background_tasks: BackgroundTasks): + """ + 创建一个具有以下参数的请求体: + - **tasks_id**: 任务id 用于取消生成任务和获取生成结果 + - **image_url**: 前视角图的输入,minio或S3 url 地址 + + 示例参数: + { + "tasks_id": "123-89", + "image_url": "aida-collection-element/87/Printboard/842c09cf-7297-42d9-9e6e-9c17d4a13cb5.jpg" + } + """ + try: + logger.info(f"generate_multi_view request item is : @@@@@@:{json.dumps(request_item.dict())}") + service = GenerateMultiView(request_item) + background_tasks.add_task(service.get_result) + except Exception as e: + logger.warning(f"generate_multi_view Run Exception @@@@@@:{e}") + raise HTTPException(status_code=404, detail=str(e)) + return ResponseModel() + + +@router.get("/generate_multi_view_cancel/{tasks_id}") +def generate_multi_view(tasks_id: str): + try: + logger.info(f"generate_cancel request item is : @@@@@@:{tasks_id}") + data = generate_multi_view_cancel(tasks_id) + logger.info(f"generate_cancel response @@@@@@:{data}") + except Exception as e: + logger.warning(f"generate_cancel Run Exception @@@@@@:{e}") + raise HTTPException(status_code=404, detail=str(e)) + return ResponseModel(data=data['data']) + + '''single logo''' diff --git a/app/core/config.py b/app/core/config.py index 1d11a0b..7456912 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -109,6 +109,12 @@ FAST_GI_MODEL_NAME = 'stable_diffusion_xl' GI_MODEL_URL = '10.1.1.240:10061' GI_MODEL_NAME = 'flux' +GMV_MODEL_URL = '10.1.1.243:10081' +GMV_MODEL_NAME = 'multi_view' + +GMV_RABBITMQ_QUEUES = os.getenv("GMV_RABBITMQ_QUEUES", f"GenerateMultiView{RABBITMQ_ENV}") + + GI_MINIO_BUCKET = "aida-users" GI_RABBITMQ_QUEUES = os.getenv("GI_RABBITMQ_QUEUES", f"GenerateImage{RABBITMQ_ENV}") GI_SYS_IMAGE_URL = "aida-sys-image/generate_image/white_image.jpg" diff --git a/app/schemas/generate_image.py b/app/schemas/generate_image.py index 11e295f..7181418 100644 --- a/app/schemas/generate_image.py +++ b/app/schemas/generate_image.py @@ -1,6 +1,11 @@ from pydantic import BaseModel +class GenerateMultiViewModel(BaseModel): + tasks_id: str + image_url: str + + class GenerateImageModel(BaseModel): tasks_id: str prompt: str diff --git a/app/service/generate_image/service_generate_multi_view.py b/app/service/generate_image/service_generate_multi_view.py new file mode 100644 index 0000000..c930ab2 --- /dev/null +++ b/app/service/generate_image/service_generate_multi_view.py @@ -0,0 +1,126 @@ +#!/usr/bin/env python +# -*- coding: UTF-8 -*- +""" +@Project :trinity_client +@File :service_att_recognition.py +@Author :周成融 +@Date :2023/7/26 12:01:05 +@detail : +""" +import json +import logging +import time + +import numpy as np +import redis +import tritonclient.grpc as grpcclient + +from app.core.config import * +from app.schemas.generate_image import GenerateMultiViewModel +from app.service.generate_image.utils.upload_sd_image import upload_png_sd +from app.service.utils.oss_client import oss_get_image + +logger = logging.getLogger() + + +class GenerateMultiView: + def __init__(self, request_data): + if DEBUG is False: + self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) + self.channel = self.connection.channel() + # self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) + # self.channel = self.connection.channel() + # self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) + self.grpc_client = grpcclient.InferenceServerClient(url=GMV_MODEL_URL) + + self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) + self.image = self.get_image(request_data.image_url) + self.tasks_id = request_data.tasks_id + self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:] + self.generate_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': ''} + self.redis_client.set(self.tasks_id, json.dumps(self.generate_data)) + self.redis_client.expire(self.tasks_id, 600) + + def get_image(self, image_url): + try: + image = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url[image_url.find('/') + 1:], data_type="PIL") + return image + except Exception as e: + logger.error(e) + + def callback(self, result, error): + if error: + self.generate_data['status'] = "FAILURE" + self.generate_data['message'] = str(error) + # self.generate_data['data'] = str(error) + self.redis_client.set(self.tasks_id, json.dumps(self.generate_data)) + else: + # pil图像转成numpy数组 + images = result.as_numpy("generated_image") + # for id, img in enumerate(images): + # cv2.imwrite(f"{id}.png", img) + # image_url = "" + image_url = upload_png_sd(images[6], user_id=self.user_id, category="multi_view", file_name=f"{self.tasks_id}.png") + # logger.info(f"upload image SUCCESS : {image_url}") + self.generate_data['status'] = "SUCCESS" + self.generate_data['message'] = "success" + self.generate_data['image_url'] = str(image_url) + self.redis_client.set(self.tasks_id, json.dumps(self.generate_data)) + + def read_tasks_status(self): + status_data = self.redis_client.get(self.tasks_id) + return json.loads(status_data), status_data + + def get_result(self): + try: + images = [np.array(self.image).astype(np.uint8)] * 1 + + image_obj = np.array(images, dtype=np.uint8) + + input_image = grpcclient.InferInput("input_image", image_obj.shape, "UINT8") + + input_image.set_data_from_numpy(image_obj) + + inputs = [input_image] + ctx = self.grpc_client.async_infer(model_name=GMV_MODEL_NAME, inputs=inputs, callback=self.callback) + + time_out = 600 + generate_data = None + while time_out > 0: + generate_data, _ = self.read_tasks_status() + if generate_data['status'] in ["REVOKED", "FAILURE"]: + ctx.cancel() + break + elif generate_data['status'] == "SUCCESS": + break + time_out -= 1 + time.sleep(0.1) + return generate_data + except Exception as e: + self.generate_data['status'] = "FAILURE" + self.generate_data['message'] = str(e) + self.redis_client.set(self.tasks_id, json.dumps(self.generate_data)) + raise Exception(str(e)) + finally: + dict_generate_data, str_generate_data = self.read_tasks_status() + if DEBUG is False: + self.channel.basic_publish(exchange='', routing_key=GMV_RABBITMQ_QUEUES, body=str_generate_data) + # self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=str_generate_data) + logger.info(f" [x] Sent {json.dumps(dict_generate_data, indent=4)}") + + +def infer_cancel(tasks_id): + redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) + data = {'tasks_id': tasks_id, 'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'} + generate_data = json.dumps(data) + redis_client.set(tasks_id, generate_data) + return data + + +if __name__ == '__main__': + rd = GenerateMultiViewModel( + tasks_id="123-89", + image_url="aida-sys-image/images/female/outwear/0628000123.jpg", + ) + server = GenerateMultiView(rd) + print(server.get_result()) diff --git a/app/service/utils/new_oss_client.py b/app/service/utils/new_oss_client.py index 4b3cbb1..7939333 100644 --- a/app/service/utils/new_oss_client.py +++ b/app/service/utils/new_oss_client.py @@ -82,7 +82,7 @@ if __name__ == '__main__': # url = "aida-users/89/sketchboard/female/Dress/e6724ab7-8d3f-4677-abe0-c3e42ab7af85.jpeg" # url = "aida-users/87/print/956614a2-7e75-4fbe-9ed0-c1831e37a2c9-4-87.png" # url = "aida-users/89/single_logo/123-89.png" - url = "aida-users/89/test/123-89.png" + url = "aida-users/89/123-89.png" # url = "aida-collection-element/12148/Sketchboard/95ea577b-305b-4a62-b30a-39c0dd3ddb3f.png" read_type = "2" From 9e1a1996c7e747b468713eabbe65233bbc163c4b Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Fri, 10 Jan 2025 15:02:28 +0800 Subject: [PATCH 2/8] =?UTF-8?q?feat=EF=BC=88=E6=96=B0=E5=8A=9F=E8=83=BD?= =?UTF-8?q?=EF=BC=89:=20fix=EF=BC=88=E4=BF=AE=E5=A4=8Dbug=EF=BC=89:=20RunT?= =?UTF-8?q?ime=E5=88=A4=E5=AE=9A=E4=BF=AE=E6=94=B9=E4=B8=BA0.05=E7=A7=92?= =?UTF-8?q?=E5=86=85=E8=A7=A6=E5=8F=91=E6=89=93=E5=8D=B0=20docs=EF=BC=88?= =?UTF-8?q?=E6=96=87=E6=A1=A3=E5=8F=98=E6=9B=B4=EF=BC=89:=20refactor?= =?UTF-8?q?=EF=BC=88=E9=87=8D=E6=9E=84=EF=BC=89:=20test(=E5=A2=9E=E5=8A=A0?= =?UTF-8?q?=E6=B5=8B=E8=AF=95):?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/service/utils/decorator.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/app/service/utils/decorator.py b/app/service/utils/decorator.py index 3e86182..bc7e9f1 100644 --- a/app/service/utils/decorator.py +++ b/app/service/utils/decorator.py @@ -7,9 +7,9 @@ def RunTime(func): t1 = time.time() res = func(*args, **kwargs) t2 = time.time() - # if t2 - t1 > 0.05: - # logging.info(f"function:【{func.__name__}】,runtime:【{str(t2 - t1)}】s") - logging.info(f"function:【{func.__name__}】,runtime:【{str(t2 - t1)}】s") + if t2 - t1 > 0.05: + logging.info(f"function:【{func.__name__}】,runtime:【{str(t2 - t1)}】s") + # logging.info(f"function:【{func.__name__}】,runtime:【{str(t2 - t1)}】s") return res return wrapper From f6b2e283f93eedb903b63a039950cfda7e727298 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Fri, 10 Jan 2025 15:04:13 +0800 Subject: [PATCH 3/8] =?UTF-8?q?feat=EF=BC=88=E6=96=B0=E5=8A=9F=E8=83=BD?= =?UTF-8?q?=EF=BC=89:=20fix=EF=BC=88=E4=BF=AE=E5=A4=8Dbug=EF=BC=89:=20RunT?= =?UTF-8?q?ime=E5=88=A4=E5=AE=9A=E4=BF=AE=E6=94=B9=E4=B8=BA0.05=E7=A7=92?= =?UTF-8?q?=E5=86=85=E8=A7=A6=E5=8F=91=E6=89=93=E5=8D=B0=20docs=EF=BC=88?= =?UTF-8?q?=E6=96=87=E6=A1=A3=E5=8F=98=E6=9B=B4=EF=BC=89:=20refactor?= =?UTF-8?q?=EF=BC=88=E9=87=8D=E6=9E=84=EF=BC=89:=20test(=E5=A2=9E=E5=8A=A0?= =?UTF-8?q?=E6=B5=8B=E8=AF=95):?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/service/design/items/pipelines/segmentation.py | 2 +- app/service/design_batch/pipeline/segmentation.py | 2 +- app/service/design_fast/pipeline/segmentation.py | 2 +- app/service/design_pre_processing/service.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/app/service/design/items/pipelines/segmentation.py b/app/service/design/items/pipelines/segmentation.py index 19eb1fd..abd30e6 100644 --- a/app/service/design/items/pipelines/segmentation.py +++ b/app/service/design/items/pipelines/segmentation.py @@ -64,7 +64,7 @@ class Segmentation(object): seg_result = np.load(file_path) return True, seg_result except FileNotFoundError: - logger.warning("文件不存在") + # logger.warning("文件不存在") return False, None except Exception as e: logger.error(f"加载失败: {e}") diff --git a/app/service/design_batch/pipeline/segmentation.py b/app/service/design_batch/pipeline/segmentation.py index cba3446..8ebcd3a 100644 --- a/app/service/design_batch/pipeline/segmentation.py +++ b/app/service/design_batch/pipeline/segmentation.py @@ -63,7 +63,7 @@ class Segmentation: seg_result = np.load(file_path) return True, seg_result except FileNotFoundError: - logger.warning("文件不存在") + # logger.warning("文件不存在") return False, None except Exception as e: logger.error(f"加载失败: {e}") diff --git a/app/service/design_fast/pipeline/segmentation.py b/app/service/design_fast/pipeline/segmentation.py index ebf02b4..5e392ed 100644 --- a/app/service/design_fast/pipeline/segmentation.py +++ b/app/service/design_fast/pipeline/segmentation.py @@ -78,7 +78,7 @@ class Segmentation: seg_result = np.load(file_path) return True, seg_result except FileNotFoundError: - logger.warning("文件不存在") + # logger.warning("文件不存在") return False, None except Exception as e: logger.error(f"加载失败: {e}") diff --git a/app/service/design_pre_processing/service.py b/app/service/design_pre_processing/service.py index 16ca870..e6dc951 100644 --- a/app/service/design_pre_processing/service.py +++ b/app/service/design_pre_processing/service.py @@ -266,7 +266,7 @@ class DesignPreprocessing: seg_result = np.load(file_path) return True, seg_result except FileNotFoundError: - logging.info("文件不存在") + # logging.info("文件不存在") return False, None except Exception as e: logging.warning(f"加载失败: {e}") From 48cfb4c1caa69ae2100d956fbd25e7fb56d323ef Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Fri, 10 Jan 2025 15:07:06 +0800 Subject: [PATCH 4/8] =?UTF-8?q?feat=EF=BC=88=E6=96=B0=E5=8A=9F=E8=83=BD?= =?UTF-8?q?=EF=BC=89:=20fix=EF=BC=88=E4=BF=AE=E5=A4=8Dbug=EF=BC=89:=20?= =?UTF-8?q?=E5=85=B3=E9=97=AD=E8=B0=83=E8=AF=95=E6=97=A5=E5=BF=97=E6=89=93?= =?UTF-8?q?=E5=8D=B0=20docs=EF=BC=88=E6=96=87=E6=A1=A3=E5=8F=98=E6=9B=B4?= =?UTF-8?q?=EF=BC=89:=20refactor=EF=BC=88=E9=87=8D=E6=9E=84=EF=BC=89:=20te?= =?UTF-8?q?st(=E5=A2=9E=E5=8A=A0=E6=B5=8B=E8=AF=95):?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/service/design_batch/pipeline/segmentation.py | 2 +- app/service/design_fast/design_generate.py | 2 +- app/service/design_fast/pipeline/segmentation.py | 2 +- app/service/utils/decorator.py | 3 ++- 4 files changed, 5 insertions(+), 4 deletions(-) diff --git a/app/service/design_batch/pipeline/segmentation.py b/app/service/design_batch/pipeline/segmentation.py index 8ebcd3a..ca7da1c 100644 --- a/app/service/design_batch/pipeline/segmentation.py +++ b/app/service/design_batch/pipeline/segmentation.py @@ -58,7 +58,7 @@ class Segmentation: @staticmethod def load_seg_result(image_id): file_path = f"seg_cache/{image_id}.npy" - logger.info(f"load seg file name is :{SEG_CACHE_PATH}{image_id}.npy") + # logger.info(f"load seg file name is :{SEG_CACHE_PATH}{image_id}.npy") try: seg_result = np.load(file_path) return True, seg_result diff --git a/app/service/design_fast/design_generate.py b/app/service/design_fast/design_generate.py index 4fc0726..2f7fa93 100644 --- a/app/service/design_fast/design_generate.py +++ b/app/service/design_fast/design_generate.py @@ -207,7 +207,7 @@ def design_generate_v2(request_data): 'Connection': "keep-alive", 'Content-Type': "application/json" } - logger.info(items_response) + # logger.info(items_response) response = post_request(url, json_data=items_response, headers=headers) if response: # 打印结果 diff --git a/app/service/design_fast/pipeline/segmentation.py b/app/service/design_fast/pipeline/segmentation.py index 5e392ed..4828b33 100644 --- a/app/service/design_fast/pipeline/segmentation.py +++ b/app/service/design_fast/pipeline/segmentation.py @@ -73,7 +73,7 @@ class Segmentation: @staticmethod def load_seg_result(image_id): file_path = f"{SEG_CACHE_PATH}{image_id}.npy" - logger.info(f"load seg file name is :{SEG_CACHE_PATH}{image_id}.npy") + # logger.info(f"load seg file name is :{SEG_CACHE_PATH}{image_id}.npy") try: seg_result = np.load(file_path) return True, seg_result diff --git a/app/service/utils/decorator.py b/app/service/utils/decorator.py index bc7e9f1..c0164ab 100644 --- a/app/service/utils/decorator.py +++ b/app/service/utils/decorator.py @@ -22,7 +22,8 @@ def ClassCallRunTime(func): end_time = time.time() execution_time = end_time - start_time class_name = args[0].__class__.__name__ # 获取类名 - print(f"class name: {class_name} , run time is : {execution_time} s") + if execution_time > 0.05: + logging.info(f"class name: {class_name} , run time is : {execution_time} s") return result return wrapper From 72a8f3f920c8bda553cb191bb7e68d40bff3f581 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Fri, 10 Jan 2025 15:08:18 +0800 Subject: [PATCH 5/8] =?UTF-8?q?feat=EF=BC=88=E6=96=B0=E5=8A=9F=E8=83=BD?= =?UTF-8?q?=EF=BC=89:=20fix=EF=BC=88=E4=BF=AE=E5=A4=8Dbug=EF=BC=89:=20?= =?UTF-8?q?=E5=85=B3=E9=97=AD=E8=B0=83=E8=AF=95=E6=97=A5=E5=BF=97=E6=89=93?= =?UTF-8?q?=E5=8D=B0=20docs=EF=BC=88=E6=96=87=E6=A1=A3=E5=8F=98=E6=9B=B4?= =?UTF-8?q?=EF=BC=89:=20refactor=EF=BC=88=E9=87=8D=E6=9E=84=EF=BC=89:=20te?= =?UTF-8?q?st(=E5=A2=9E=E5=8A=A0=E6=B5=8B=E8=AF=95):?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/api_attribute_retrieve.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/app/api/api_attribute_retrieve.py b/app/api/api_attribute_retrieve.py index 5c15efe..2a5ad4d 100644 --- a/app/api/api_attribute_retrieve.py +++ b/app/api/api_attribute_retrieve.py @@ -35,13 +35,13 @@ def attribute_recognition(request_item: list[AttributeRecognitionModel]): """ try: for item in request_item: - logger.info(f"attribute_recognition request item is : @@@@@@:{json.dumps(item.dict())}") + logger.debug(f"attribute_recognition request item is : @@@@@@:{json.dumps(item.dict())}") if DEBUG: service = AttributeRecognition(const=local_debug_const, request_data=request_item) else: service = AttributeRecognition(const=const, request_data=request_item) data = service.get_result() - logger.info(f"attribute_recognition response @@@@@@:{json.dumps(data)}") + logger.debug(f"attribute_recognition response @@@@@@:{json.dumps(data)}") except Exception as e: logger.warning(f"attribute_recognition Run Exception @@@@@@:{e}") raise HTTPException(status_code=404, detail=str(e)) From f36edbe24817e2e26c8cca7c249dfc77700fcb43 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Fri, 10 Jan 2025 15:09:59 +0800 Subject: [PATCH 6/8] =?UTF-8?q?feat=EF=BC=88=E6=96=B0=E5=8A=9F=E8=83=BD?= =?UTF-8?q?=EF=BC=89:=20fix=EF=BC=88=E4=BF=AE=E5=A4=8Dbug=EF=BC=89:=20?= =?UTF-8?q?=E5=85=B3=E9=97=AD=E8=B0=83=E8=AF=95=E6=97=A5=E5=BF=97=E6=89=93?= =?UTF-8?q?=E5=8D=B0=20docs=EF=BC=88=E6=96=87=E6=A1=A3=E5=8F=98=E6=9B=B4?= =?UTF-8?q?=EF=BC=89:=20refactor=EF=BC=88=E9=87=8D=E6=9E=84=EF=BC=89:=20te?= =?UTF-8?q?st(=E5=A2=9E=E5=8A=A0=E6=B5=8B=E8=AF=95):?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/service/design/items/pipelines/segmentation.py | 2 +- app/service/design_batch/pipeline/segmentation.py | 2 +- app/service/design_fast/pipeline/segmentation.py | 2 +- app/service/design_pre_processing/service.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/app/service/design/items/pipelines/segmentation.py b/app/service/design/items/pipelines/segmentation.py index abd30e6..7ed43e5 100644 --- a/app/service/design/items/pipelines/segmentation.py +++ b/app/service/design/items/pipelines/segmentation.py @@ -53,7 +53,7 @@ class Segmentation(object): file_path = f"{SEG_CACHE_PATH}{image_id}.npy" try: np.save(file_path, seg_result) - logger.info(f"保存成功 :{os.path.abspath(file_path)}") + logger.debug(f"保存成功 :{os.path.abspath(file_path)}") except Exception as e: logger.error(f"保存失败: {e}") diff --git a/app/service/design_batch/pipeline/segmentation.py b/app/service/design_batch/pipeline/segmentation.py index ca7da1c..aa05c0d 100644 --- a/app/service/design_batch/pipeline/segmentation.py +++ b/app/service/design_batch/pipeline/segmentation.py @@ -51,7 +51,7 @@ class Segmentation: file_path = f"seg_cache/{image_id}.npy" try: np.save(file_path, seg_result) - logger.info(f"保存成功 :{os.path.abspath(file_path)}") + logger.debug(f"保存成功 :{os.path.abspath(file_path)}") except Exception as e: logger.error(f"保存失败: {e}") diff --git a/app/service/design_fast/pipeline/segmentation.py b/app/service/design_fast/pipeline/segmentation.py index 4828b33..1c3f5a0 100644 --- a/app/service/design_fast/pipeline/segmentation.py +++ b/app/service/design_fast/pipeline/segmentation.py @@ -66,7 +66,7 @@ class Segmentation: file_path = f"{SEG_CACHE_PATH}{image_id}.npy" try: np.save(file_path, seg_result) - logger.info(f"保存成功 :{os.path.abspath(file_path)}") + logger.debug(f"保存成功 :{os.path.abspath(file_path)}") except Exception as e: logger.error(f"保存失败: {e}") diff --git a/app/service/design_pre_processing/service.py b/app/service/design_pre_processing/service.py index e6dc951..636360c 100644 --- a/app/service/design_pre_processing/service.py +++ b/app/service/design_pre_processing/service.py @@ -277,7 +277,7 @@ class DesignPreprocessing: file_path = f"{SEG_CACHE_PATH}{image_id}.npy" try: np.save(file_path, seg_result) - logging.info(f"保存成功,{os.path.abspath(file_path)}") + logging.debug(f"保存成功,{os.path.abspath(file_path)}") except Exception as e: logging.warning(f"保存失败: {e}") From 9c52811c050a1b6e373e08bf9fd3700e2bb098b1 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Mon, 13 Jan 2025 15:41:37 +0800 Subject: [PATCH 7/8] =?UTF-8?q?feat=EF=BC=88=E6=96=B0=E5=8A=9F=E8=83=BD?= =?UTF-8?q?=EF=BC=89:=20fix=EF=BC=88=E4=BF=AE=E5=A4=8Dbug=EF=BC=89:=20desi?= =?UTF-8?q?gn=20=E5=88=86=E5=89=B2=E9=A2=84=E5=A4=84=E7=90=86=E6=96=B0?= =?UTF-8?q?=E5=A2=9E25padding=EF=BC=8C=E5=90=8E=E5=A4=84=E7=90=86=E5=8F=96?= =?UTF-8?q?=E6=B6=88=E6=8F=92=E5=80=BC=E5=A4=84=E7=90=86=20docs=EF=BC=88?= =?UTF-8?q?=E6=96=87=E6=A1=A3=E5=8F=98=E6=9B=B4=EF=BC=89:=20refactor?= =?UTF-8?q?=EF=BC=88=E9=87=8D=E6=9E=84=EF=BC=89:=20test(=E5=A2=9E=E5=8A=A0?= =?UTF-8?q?=E6=B5=8B=E8=AF=95):?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/service/design_fast/pipeline/segmentation.py | 10 +++++----- app/service/design_fast/utils/design_ensemble.py | 10 ++++++---- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/app/service/design_fast/pipeline/segmentation.py b/app/service/design_fast/pipeline/segmentation.py index 1c3f5a0..0c9c51e 100644 --- a/app/service/design_fast/pipeline/segmentation.py +++ b/app/service/design_fast/pipeline/segmentation.py @@ -36,11 +36,11 @@ class Segmentation: # preview 过模型 不缓存 if "preview_submit" in result.keys() and result['preview_submit'] == "preview": # 推理获得seg 结果 - seg_result = get_seg_result(result["image_id"], result['image'])[0] + seg_result = get_seg_result(result["image_id"], result['image']) # submit 过模型 缓存 elif "preview_submit" in result.keys() and result['preview_submit'] == "submit": # 推理获得seg 结果 - seg_result = get_seg_result(result["image_id"], result['image'])[0] + seg_result = get_seg_result(result["image_id"], result['image']) self.save_seg_result(seg_result, result['image_id']) # null 正常流程 加载本地缓存 无缓存则过模型 else: @@ -49,14 +49,14 @@ class Segmentation: # 判断缓存和实际图片size是否相同 if not _ or result["image"].shape[:2] != seg_result.shape: # 推理获得seg 结果 - seg_result = get_seg_result(result["image_id"], result['image'])[0] + seg_result = get_seg_result(result["image_id"], result['image']) self.save_seg_result(seg_result, result['image_id']) result['seg_result'] = seg_result # 处理前片后片 - temp_front = seg_result == 1.0 + temp_front = seg_result == 1 result['front_mask'] = (255 * (temp_front + 0).astype(np.uint8)) - temp_back = seg_result == 2.0 + temp_back = seg_result == 2 result['back_mask'] = (255 * (temp_back + 0).astype(np.uint8)) result['mask'] = result['front_mask'] + result['back_mask'] return result diff --git a/app/service/design_fast/utils/design_ensemble.py b/app/service/design_fast/utils/design_ensemble.py index 267ea00..9f30d0c 100644 --- a/app/service/design_fast/utils/design_ensemble.py +++ b/app/service/design_fast/utils/design_ensemble.py @@ -13,7 +13,6 @@ import cv2 import mmcv import numpy as np import torch -import torch.nn.functional as F import tritonclient.http as httpclient from app.core.config import * @@ -85,6 +84,9 @@ def seg_preprocess(img_path): if ori_shape != (img_scale_w, img_scale_h): # mmcv.imresize(img, img_scale_h, img_scale_w) # 老代码 引以为戒!哈哈哈~ h和w写反了 img = cv2.resize(img, (img_scale_h, img_scale_w)) + + # 扩充25的白边 + img = cv2.copyMakeBorder(img, 25, 25, 25, 25, cv2.BORDER_CONSTANT, value=[255, 255, 255]) # img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True) preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0) return preprocessed_img, ori_shape @@ -114,9 +116,9 @@ def get_seg_result(image_id, image): # no cache def seg_postprocess(image_id, output, ori_shape): - seg_logit = F.interpolate(torch.tensor(output).float(), size=ori_shape, scale_factor=None, mode='bilinear', align_corners=False) - seg_pred = seg_logit.cpu().numpy() - return seg_pred[0] + seg_logit = cv2.resize(output[0][0].astype(np.uint8), (ori_shape[1] + 50, ori_shape[0] + 50)) + seg_logit = seg_logit[25: - 25, 25: - 25] + return seg_logit def key_point_show(image_path, key_point_result=None): From 94000ceb21c70fd575a1bae97e1e857c9fd04874 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Mon, 13 Jan 2025 15:58:58 +0800 Subject: [PATCH 8/8] =?UTF-8?q?feat=EF=BC=88=E6=96=B0=E5=8A=9F=E8=83=BD?= =?UTF-8?q?=EF=BC=89:=20fix=EF=BC=88=E4=BF=AE=E5=A4=8Dbug=EF=BC=89:=20desi?= =?UTF-8?q?gn=20=E5=88=86=E5=89=B2=E9=A2=84=E5=A4=84=E7=90=86=E6=96=B0?= =?UTF-8?q?=E5=A2=9E25padding=EF=BC=8C=E5=90=8E=E5=A4=84=E7=90=86=E5=8F=96?= =?UTF-8?q?=E6=B6=88=E6=8F=92=E5=80=BC=E5=A4=84=E7=90=86=20docs=EF=BC=88?= =?UTF-8?q?=E6=96=87=E6=A1=A3=E5=8F=98=E6=9B=B4=EF=BC=89:=20refactor?= =?UTF-8?q?=EF=BC=88=E9=87=8D=E6=9E=84=EF=BC=89:=20test(=E5=A2=9E=E5=8A=A0?= =?UTF-8?q?=E6=B5=8B=E8=AF=95):?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/service/design_fast/utils/design_ensemble.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/service/design_fast/utils/design_ensemble.py b/app/service/design_fast/utils/design_ensemble.py index 9f30d0c..bfc50c6 100644 --- a/app/service/design_fast/utils/design_ensemble.py +++ b/app/service/design_fast/utils/design_ensemble.py @@ -87,7 +87,7 @@ def seg_preprocess(img_path): # 扩充25的白边 img = cv2.copyMakeBorder(img, 25, 25, 25, 25, cv2.BORDER_CONSTANT, value=[255, 255, 255]) - # img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True) + img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True) preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0) return preprocessed_img, ori_shape