From dd0781b9aee1b25c92028e6fbe085545cf0ffd7a Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Mon, 17 Jun 2024 16:42:33 +0800 Subject: [PATCH 01/21] =?UTF-8?q?feat=20=20=E6=9B=B4=E6=96=B0=E5=93=8D?= =?UTF-8?q?=E5=BA=94=E6=A8=A1=E6=9D=BF=20fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/core/config.py | 8 ++-- .../design/items/pipelines/keypoints.py | 39 ++++++++----------- 2 files changed, 20 insertions(+), 27 deletions(-) diff --git a/app/core/config.py b/app/core/config.py index 651dd8b..b293cef 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -23,11 +23,11 @@ DEBUG = False if DEBUG: LOGS_PATH = "logs/" CATEGORY_PATH = "service/attribute/config/descriptor/category/category_dis.csv" - FACE_CLASSIFIER = "service/generate_image/utils/haarcascade_frontalface_alt.xml" + # FACE_CLASSIFIER = "service/generate_image/utils/haarcascade_frontalface_alt.xml" else: LOGS_PATH = "app/logs/" CATEGORY_PATH = "app/service/attribute/config/descriptor/category/category_dis.csv" - FACE_CLASSIFIER = 'app/service/generate_image/utils/haarcascade_frontalface_alt.xml' + # FACE_CLASSIFIER = 'app/service/generate_image/utils/haarcascade_frontalface_alt.xml' # RABBITMQ_ENV = "" # 生产环境 # RABBITMQ_ENV = "-dev" # 开发环境 @@ -60,9 +60,9 @@ RABBITMQ_PARAMS = { } # milvus 配置 -MILVUS_DB_HOST = "10.1.1.240" +MILVUS_URL = "http://10.1.1.240:19530http://127.0.0.1:8000/docs#/design/design_api_design_post" +MILVUS_TOKEN = "root:Milvus" MILVUS_ALIAS = "default" -MILVUS_PORT = "19530" MILVUS_TABLE_KEYPOINT = "keypoint_cache" MILVUS_TABLE_SEG = "seg_cache" diff --git a/app/service/design/items/pipelines/keypoints.py b/app/service/design/items/pipelines/keypoints.py index 4d0a081..4a9e4d1 100644 --- a/app/service/design/items/pipelines/keypoints.py +++ b/app/service/design/items/pipelines/keypoints.py @@ -14,17 +14,17 @@ class KeypointDetection(object): path here: abstract path """ - def __init__(self): - self.client = MilvusClient( - 
uri="http://10.1.1.240:19530", - token="root:Milvus", - db_name=MILVUS_ALIAS - ) + # def __init__(self): + # self.client = MilvusClient( + # uri="http://10.1.1.240:19530", + # token="root:Milvus", + # db_name=MILVUS_ALIAS + # ) - def __del__(self): - # start_time = time.time() - self.client.close() - # print(f"client close time : {time.time() - start_time}") + # def __del__(self): + # start_time = time.time() + # self.client.close() + # print(f"client close time : {time.time() - start_time}") # @ RunTime def __call__(self, result): @@ -69,24 +69,19 @@ class KeypointDetection(object): "keypoint_vector": result.tolist() } ] - client = MilvusClient( - uri="http://10.1.1.240:19530", - token="root:Milvus", - db_name=MILVUS_ALIAS - ) try: + client = MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS) start_time = time.time() res = client.upsert( collection_name=MILVUS_TABLE_KEYPOINT, data=data, ) # logging.info(f"save keypoint time : {time.time() - start_time}") + client.close() return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist())) except Exception as e: logging.info(f"save keypoint cache milvus error : {e}") return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist())) - finally: - client.close() @staticmethod def update_keypoint_cache(keypoint_id, infer_result, search_result, site): @@ -102,12 +97,9 @@ class KeypointDetection(object): "keypoint_vector": result.tolist() } ] - client = MilvusClient( - uri="http://10.1.1.240:19530", - token="root:Milvus", - db_name=MILVUS_ALIAS - ) + try: + client = MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS) # connections.connect(alias=MILVUS_ALIAS, host=MILVUS_DB_HOST, port=MILVUS_PORT) start_time = time.time() # collection = Collection(MILVUS_TABLE_KEYPOINT) # Get an existing collection. 
@@ -125,8 +117,9 @@ class KeypointDetection(object): # @ RunTime def keypoint_cache(self, result, site): try: + client = MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS) keypoint_id = result['image_id'] - res = self.client.query( + res = client.query( collection_name=MILVUS_TABLE_KEYPOINT, # ids=[keypoint_id], filter=f"keypoint_id == {keypoint_id}", From e29bed20f7983e804998e9da784d15eff088e175 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Mon, 17 Jun 2024 16:59:08 +0800 Subject: [PATCH 02/21] =?UTF-8?q?feat=20=20=E6=9B=B4=E6=96=B0=E5=93=8D?= =?UTF-8?q?=E5=BA=94=E6=A8=A1=E6=9D=BF=20fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/api_attribute_retrieve.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/app/api/api_attribute_retrieve.py b/app/api/api_attribute_retrieve.py index 267b796..89a5e3f 100644 --- a/app/api/api_attribute_retrieve.py +++ b/app/api/api_attribute_retrieve.py @@ -1,5 +1,7 @@ import json import logging +import os + from fastapi import APIRouter, HTTPException from app.schemas.attribute_retrieve import * @@ -17,6 +19,8 @@ logger = logging.getLogger() def attribute_recognition(request_item: list[AttributeRecognitionModel]): try: logger.info(f"attribute_recognition request item is : @@@@@@:{request_item}") + logger.info(const.top_description_list) + logger.info(os.getcwd()) service = AttributeRecognition(const=const, request_data=request_item) data = service.get_result() logger.info(f"attribute_recognition response @@@@@@:{json.dumps(data, indent=4)}") From 557e3cd1007ab9e23513066a43ca1fc571c55a72 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Mon, 17 Jun 2024 17:02:26 +0800 Subject: [PATCH 03/21] =?UTF-8?q?feat=20=20=E6=9B=B4=E6=96=B0=E5=93=8D?= =?UTF-8?q?=E5=BA=94=E6=A8=A1=E6=9D=BF=20fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/service/attribute/config/const.py | 64 +++++++++++++-------------- 1 
file changed, 32 insertions(+), 32 deletions(-) diff --git a/app/service/attribute/config/const.py b/app/service/attribute/config/const.py index 24d9412..738e486 100644 --- a/app/service/attribute/config/const.py +++ b/app/service/attribute/config/const.py @@ -1,13 +1,13 @@ -top_description_list = ['service/attribute/config/descriptor/top/length.csv', - 'service/attribute/config/descriptor/top/type.csv', - 'service/attribute/config/descriptor/top/sleeve_length.csv', - 'service/attribute/config/descriptor/top/sleeve_shape.csv', - 'service/attribute/config/descriptor/top/sleeve_shoulder.csv', - 'service/attribute/config/descriptor/top/neckline.csv', - 'service/attribute/config/descriptor/top/design.csv', - 'service/attribute/config/descriptor/top/opening_type.csv', - 'service/attribute/config/descriptor/top/silhouette.csv', - 'service/attribute/config/descriptor/top/collar.csv'] +top_description_list = ['app/service/attribute/config/descriptor/top/length.csv', + 'app/service/attribute/config/descriptor/top/type.csv', + 'app/service/attribute/config/descriptor/top/sleeve_length.csv', + 'app/service/attribute/config/descriptor/top/sleeve_shape.csv', + 'app/service/attribute/config/descriptor/top/sleeve_shoulder.csv', + 'app/service/attribute/config/descriptor/top/neckline.csv', + 'app/service/attribute/config/descriptor/top/design.csv', + 'app/service/attribute/config/descriptor/top/opening_type.csv', + 'app/service/attribute/config/descriptor/top/silhouette.csv', + 'app/service/attribute/config/descriptor/top/collar.csv'] top_model_list = ['attr_retrieve_T_length', 'attr_retrieve_T_type', @@ -22,11 +22,11 @@ top_model_list = ['attr_retrieve_T_length', ] bottom_description_list = [ - 'service/attribute/config/descriptor/bottom/subtype.csv', - 'service/attribute/config/descriptor/bottom/length.csv', - 'service/attribute/config/descriptor/bottom/silhouette.csv', - 'service/attribute/config/descriptor/bottom/opening_type.csv', - 
'service/attribute/config/descriptor/bottom/design.csv'] + 'app/service/attribute/config/descriptor/bottom/subtype.csv', + 'app/service/attribute/config/descriptor/bottom/length.csv', + 'app/service/attribute/config/descriptor/bottom/silhouette.csv', + 'app/service/attribute/config/descriptor/bottom/opening_type.csv', + 'app/service/attribute/config/descriptor/bottom/design.csv'] bottom_model_list = [ 'attr_retrieve_B_subtype', @@ -35,14 +35,14 @@ bottom_model_list = [ 'attr_recong_B_optype', 'attr_retrieve_B_design'] -outwear_description_list = ['service/attribute/config/descriptor/outwear/length.csv', - 'service/attribute/config/descriptor/outwear/sleeve_length.csv', - 'service/attribute/config/descriptor/outwear/sleeve_shape.csv', - 'service/attribute/config/descriptor/outwear/sleeve_shoulder.csv', - 'service/attribute/config/descriptor/outwear/collar.csv', - 'service/attribute/config/descriptor/outwear/design.csv', - 'service/attribute/config/descriptor/outwear/opening_type.csv', - 'service/attribute/config/descriptor/outwear/silhouette.csv', ] +outwear_description_list = ['app/service/attribute/config/descriptor/outwear/length.csv', + 'app/service/attribute/config/descriptor/outwear/sleeve_length.csv', + 'app/service/attribute/config/descriptor/outwear/sleeve_shape.csv', + 'app/service/attribute/config/descriptor/outwear/sleeve_shoulder.csv', + 'app/service/attribute/config/descriptor/outwear/collar.csv', + 'app/service/attribute/config/descriptor/outwear/design.csv', + 'app/service/attribute/config/descriptor/outwear/opening_type.csv', + 'app/service/attribute/config/descriptor/outwear/silhouette.csv', ] outwear_model_list = ['attr_recong_O_length', 'attr_retrieve_O_sleeve_length', @@ -53,15 +53,15 @@ outwear_model_list = ['attr_recong_O_length', 'attr_recong_O_optype', 'attr_retrieve_O_silhouette'] -dress_description_list = [ # 'service/attribute/config/descriptor/dress/D_length.csv', - 'service/attribute/config/descriptor/dress/sleeve_length.csv', - 
'service/attribute/config/descriptor/dress/sleeve_shape.csv', - # 'service/attribute/config/descriptor/dress/D_sleeve_shoulder.csv', - 'service/attribute/config/descriptor/dress/neckline.csv', - 'service/attribute/config/descriptor/dress/collar.csv', - 'service/attribute/config/descriptor/dress/design.csv', - 'service/attribute/config/descriptor/dress/silhouette.csv', - 'service/attribute/config/descriptor/dress/type.csv'] +dress_description_list = [ # 'app/service/attribute/config/descriptor/dress/D_length.csv', + 'app/service/attribute/config/descriptor/dress/sleeve_length.csv', + 'app/service/attribute/config/descriptor/dress/sleeve_shape.csv', + # 'app/service/attribute/config/descriptor/dress/D_sleeve_shoulder.csv', + 'app/service/attribute/config/descriptor/dress/neckline.csv', + 'app/service/attribute/config/descriptor/dress/collar.csv', + 'app/service/attribute/config/descriptor/dress/design.csv', + 'app/service/attribute/config/descriptor/dress/silhouette.csv', + 'app/service/attribute/config/descriptor/dress/type.csv'] dress_model_list = [ # 'attr_recong_D_length', 'attr_retrieve_D_sleeve_length', From e0d9512b26987d3917ea979c5b0a48277346ea2c Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Mon, 17 Jun 2024 17:04:20 +0800 Subject: [PATCH 04/21] =?UTF-8?q?feat=20=20=E6=9B=B4=E6=96=B0=E5=93=8D?= =?UTF-8?q?=E5=BA=94=E6=A8=A1=E6=9D=BF=20fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/api_attribute_retrieve.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/app/api/api_attribute_retrieve.py b/app/api/api_attribute_retrieve.py index 89a5e3f..ef3955f 100644 --- a/app/api/api_attribute_retrieve.py +++ b/app/api/api_attribute_retrieve.py @@ -1,9 +1,6 @@ import json import logging -import os - from fastapi import APIRouter, HTTPException - from app.schemas.attribute_retrieve import * from app.schemas.response_template import ResponseModel from app.service.attribute.config import const @@ -19,8 
+16,6 @@ logger = logging.getLogger() def attribute_recognition(request_item: list[AttributeRecognitionModel]): try: logger.info(f"attribute_recognition request item is : @@@@@@:{request_item}") - logger.info(const.top_description_list) - logger.info(os.getcwd()) service = AttributeRecognition(const=const, request_data=request_item) data = service.get_result() logger.info(f"attribute_recognition response @@@@@@:{json.dumps(data, indent=4)}") From 88014bdb4c55851ffcaf76f21376073cc01de7b9 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Mon, 17 Jun 2024 17:14:10 +0800 Subject: [PATCH 05/21] =?UTF-8?q?feat=20=20=E6=9B=B4=E6=96=B0=E5=93=8D?= =?UTF-8?q?=E5=BA=94=E6=A8=A1=E6=9D=BF=20fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/api_attribute_retrieve.py | 11 ++++++++--- app/core/config.py | 2 +- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/app/api/api_attribute_retrieve.py b/app/api/api_attribute_retrieve.py index ef3955f..7a14e9d 100644 --- a/app/api/api_attribute_retrieve.py +++ b/app/api/api_attribute_retrieve.py @@ -1,9 +1,11 @@ import json import logging from fastapi import APIRouter, HTTPException + +from app.core.config import DEBUG from app.schemas.attribute_retrieve import * from app.schemas.response_template import ResponseModel -from app.service.attribute.config import const +from app.service.attribute.config import const, local_debug_const from app.service.attribute.service_att_recognition import AttributeRecognition from app.service.attribute.service_category_recognition import CategoryRecognition @@ -16,13 +18,16 @@ logger = logging.getLogger() def attribute_recognition(request_item: list[AttributeRecognitionModel]): try: logger.info(f"attribute_recognition request item is : @@@@@@:{request_item}") - service = AttributeRecognition(const=const, request_data=request_item) + if DEBUG: + service = AttributeRecognition(const=local_debug_const, request_data=request_item) + else: + 
service = AttributeRecognition(const=const, request_data=request_item) data = service.get_result() logger.info(f"attribute_recognition response @@@@@@:{json.dumps(data, indent=4)}") except Exception as e: logger.warning(f"attribute_recognition Run Exception @@@@@@:{e}") raise HTTPException(status_code=404, detail=str(e)) - return ResponseModel(data=data) + return ResponseModel(data={"list": data}) # 类别识别 diff --git a/app/core/config.py b/app/core/config.py index b293cef..08c0998 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -19,7 +19,7 @@ class Settings(BaseSettings): LOGGING_CONFIG_FILE = os.path.join(BASE_DIR, 'logging_env.py') -DEBUG = False +DEBUG = True if DEBUG: LOGS_PATH = "logs/" CATEGORY_PATH = "service/attribute/config/descriptor/category/category_dis.csv" From 6861e89f8d2faf12bd150ceb07dd14611080107c Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Mon, 17 Jun 2024 17:14:49 +0800 Subject: [PATCH 06/21] =?UTF-8?q?feat=20=20=E6=9B=B4=E6=96=B0=E5=93=8D?= =?UTF-8?q?=E5=BA=94=E6=A8=A1=E6=9D=BF=20fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/core/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/core/config.py b/app/core/config.py index 08c0998..b293cef 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -19,7 +19,7 @@ class Settings(BaseSettings): LOGGING_CONFIG_FILE = os.path.join(BASE_DIR, 'logging_env.py') -DEBUG = True +DEBUG = False if DEBUG: LOGS_PATH = "logs/" CATEGORY_PATH = "service/attribute/config/descriptor/category/category_dis.csv" From a09476354e76b87c30eab0cb92010a04f72b51af Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Mon, 17 Jun 2024 17:17:36 +0800 Subject: [PATCH 07/21] =?UTF-8?q?feat=20=20=E6=9B=B4=E6=96=B0=E5=93=8D?= =?UTF-8?q?=E5=BA=94=E6=A8=A1=E6=9D=BF=20fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/core/config.py | 2 +- 1 file changed, 1 insertion(+), 1 
deletion(-) diff --git a/app/core/config.py b/app/core/config.py index b293cef..0af065b 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -60,7 +60,7 @@ RABBITMQ_PARAMS = { } # milvus 配置 -MILVUS_URL = "http://10.1.1.240:19530http://127.0.0.1:8000/docs#/design/design_api_design_post" +MILVUS_URL = "http://10.1.1.240:19530" MILVUS_TOKEN = "root:Milvus" MILVUS_ALIAS = "default" MILVUS_TABLE_KEYPOINT = "keypoint_cache" From a0993d7e3a4dca9541cad1d7c206e8395e13818c Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Mon, 17 Jun 2024 17:34:57 +0800 Subject: [PATCH 08/21] =?UTF-8?q?feat=20=20=E6=9B=B4=E6=96=B0=E5=93=8D?= =?UTF-8?q?=E5=BA=94=E6=A8=A1=E6=9D=BF=20fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/service/design/items/pipelines/keypoints.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/service/design/items/pipelines/keypoints.py b/app/service/design/items/pipelines/keypoints.py index 4a9e4d1..6cf1141 100644 --- a/app/service/design/items/pipelines/keypoints.py +++ b/app/service/design/items/pipelines/keypoints.py @@ -55,7 +55,7 @@ class KeypointDetection(object): @staticmethod # @ RunTime - def save_keypoint_cache(keypoint_id, cache, site, KEYPOINT_RESULT_TABLE_FIELD_SET=None): + def save_keypoint_cache(keypoint_id, cache, site): if site == "down": zeros = np.zeros(20, dtype=int) result = np.concatenate([zeros, cache.flatten()]) From 63db4b891798e2f4d9b1c3d6e585607885b3f2ff Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Tue, 18 Jun 2024 10:50:15 +0800 Subject: [PATCH 09/21] =?UTF-8?q?feat=20fix=20=20design=20=E8=BF=9B?= =?UTF-8?q?=E5=BA=A6=E6=8E=A5=E5=8F=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/api_design.py | 17 ++++++++++++++++- app/schemas/design.py | 4 ++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/app/api/api_design.py b/app/api/api_design.py index c77d4c2..ecac4f5 100644 --- 
a/app/api/api_design.py +++ b/app/api/api_design.py @@ -4,9 +4,10 @@ import time from fastapi import APIRouter, HTTPException -from app.schemas.design import DesignModel +from app.schemas.design import DesignModel, DesignProgressModel from app.schemas.response_template import ResponseModel from app.service.design.service import generate +from app.service.design.utils.redis_utils import Redis router = APIRouter() logger = logging.getLogger() @@ -22,3 +23,17 @@ def design(request_data: DesignModel): logger.warning(f"design Run Exception @@@@@@:{e}") raise HTTPException(status_code=404, detail=str(e)) return ResponseModel(data=data) + + +@router.post('/get_progress') +def get_progress(request_data: DesignProgressModel): + try: + logger.info(f"get_progress request item is : @@@@@@:{request_data}") + process_id = request_data.process_id + r = Redis() + data = r.read(key=process_id) + logging.info(f"get_progress process_id @@@@@@ : {process_id} , progress : {data}") + except Exception as e: + logger.warning(f"design Run Exception @@@@@@:{e}") + raise HTTPException(status_code=404, detail=str(e)) + return ResponseModel(data=data) diff --git a/app/schemas/design.py b/app/schemas/design.py index b203970..994deb4 100644 --- a/app/schemas/design.py +++ b/app/schemas/design.py @@ -48,3 +48,7 @@ from pydantic import BaseModel class DesignModel(BaseModel): objects: list[dict] process_id: str + + +class DesignProgressModel(BaseModel): + process_id: str From 61ae688dd60eef4d0dca52d9d34d0ec5f7566625 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Tue, 18 Jun 2024 14:31:11 +0800 Subject: [PATCH 10/21] =?UTF-8?q?feat=20fix=20=20design=20keypoint=20?= =?UTF-8?q?=E5=8F=96=E6=B6=88=E8=AE=B0=E5=BD=95keypoint?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/service/design/items/pipelines/keypoints.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/app/service/design/items/pipelines/keypoints.py 
b/app/service/design/items/pipelines/keypoints.py index 6cf1141..956e052 100644 --- a/app/service/design/items/pipelines/keypoints.py +++ b/app/service/design/items/pipelines/keypoints.py @@ -34,9 +34,9 @@ class KeypointDetection(object): site = 'up' if result['name'] in ['blouse', 'outwear', 'dress', 'tops'] else 'down' # keypoint_cache = search_keypoint_cache(result["image_id"], site) - keypoint_cache = self.keypoint_cache(result, site) + # keypoint_cache = self.keypoint_cache(result, site) # 取消向量查询 直接过模型推理 - # keypoint_cache = False + keypoint_cache = False if keypoint_cache is False: keypoint_infer_result, site = self.infer_keypoint_result(result) From 8476bb3727e3ed974dad3f09d2ca3050e8e3da7e Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Wed, 19 Jun 2024 10:53:11 +0800 Subject: [PATCH 11/21] feat fix --- app/api/api_design.py | 4 +++- app/service/design/items/pipelines/keypoints.py | 4 ++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/app/api/api_design.py b/app/api/api_design.py index ecac4f5..cdbd1f5 100644 --- a/app/api/api_design.py +++ b/app/api/api_design.py @@ -32,8 +32,10 @@ def get_progress(request_data: DesignProgressModel): process_id = request_data.process_id r = Redis() data = r.read(key=process_id) + if data is None: + raise ValueError("The progress must be numbers ") logging.info(f"get_progress process_id @@@@@@ : {process_id} , progress : {data}") except Exception as e: - logger.warning(f"design Run Exception @@@@@@:{e}") + logger.warning(f"get_progress Run Exception @@@@@@:{e}") raise HTTPException(status_code=404, detail=str(e)) return ResponseModel(data=data) diff --git a/app/service/design/items/pipelines/keypoints.py b/app/service/design/items/pipelines/keypoints.py index 956e052..6cf1141 100644 --- a/app/service/design/items/pipelines/keypoints.py +++ b/app/service/design/items/pipelines/keypoints.py @@ -34,9 +34,9 @@ class KeypointDetection(object): site = 'up' if result['name'] in ['blouse', 'outwear', 'dress', 
'tops'] else 'down' # keypoint_cache = search_keypoint_cache(result["image_id"], site) - # keypoint_cache = self.keypoint_cache(result, site) + keypoint_cache = self.keypoint_cache(result, site) # 取消向量查询 直接过模型推理 - keypoint_cache = False + # keypoint_cache = False if keypoint_cache is False: keypoint_infer_result, site = self.infer_keypoint_result(result) From d04c3857fcaa8e672c7b290f288081fe9d895e64 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Wed, 19 Jun 2024 16:44:04 +0800 Subject: [PATCH 12/21] =?UTF-8?q?feat=20=20=E4=BA=A7=E5=93=81=E5=9B=BE?= =?UTF-8?q?=E6=89=93=E5=85=89=E6=A8=A1=E5=9E=8B=E9=83=A8=E7=BD=B2=20fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/core/config.py | 2 +- .../service_generate_product_image.py | 9 +- .../service_generate_relight_image.py | 111 ++++++------------ 3 files changed, 41 insertions(+), 81 deletions(-) diff --git a/app/core/config.py b/app/core/config.py index 0af065b..3932bf5 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -124,7 +124,7 @@ GPI_MODEL_URL = '10.1.1.240:10061' # Generate Single Logo service config GRI_RABBITMQ_QUEUES = os.getenv("GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES", f"Relight{RABBITMQ_ENV}") -GRI_MODEL_NAME = 'stable_diffusion_1_5' +GRI_MODEL_NAME = 'diffusion_relight_ensemble' GRI_MODEL_URL = '10.1.1.150:8001' # SEG service config diff --git a/app/service/generate_image/service_generate_product_image.py b/app/service/generate_image/service_generate_product_image.py index ce449ea..2416d2c 100644 --- a/app/service/generate_image/service_generate_product_image.py +++ b/app/service/generate_image/service_generate_product_image.py @@ -20,7 +20,7 @@ from minio import Minio from tritonclient.utils import np_to_triton_dtype from app.core.config import * -from app.schemas.generate_image import GenerateImageModel +from app.schemas.generate_image import GenerateImageModel, GenerateProductImageModel from app.service.generate_image.utils.upload_sd_image 
import upload_SDXL_image logger = logging.getLogger() @@ -166,10 +166,11 @@ def infer_cancel(tasks_id): if __name__ == '__main__': - rd = GenerateImageModel( + rd = GenerateProductImageModel( tasks_id="123-89", - prompt="best quality, masterpiece. detailed, high-res, simple background, studio photography, extremely detailed, updo, detailed face, face, close-up, HDR, UHD, 8K realistic, Highly detailed, simple background, Studio lighting", - image_url="aida-results/result_067f2f7e-21ba-11ef-8cf5-0242ac170002.png", + prompt="", + # prompt="best quality, masterpiece. detailed, high-res, simple background, studio photography, extremely detailed, updo, detailed face, face, close-up, HDR, UHD, 8K realistic, Highly detailed, simple background, Studio lighting", + image_url="aida-results/result_00097282-ebb2-11ee-a822-b48351119060.png", ) server = GenerateProductImage(rd) print(server.get_result()) diff --git a/app/service/generate_image/service_generate_relight_image.py b/app/service/generate_image/service_generate_relight_image.py index 0eacec9..7c7f4b1 100644 --- a/app/service/generate_image/service_generate_relight_image.py +++ b/app/service/generate_image/service_generate_relight_image.py @@ -38,9 +38,10 @@ class GenerateRelightImage: self.batch_size = 1 self.prompt = request_data.prompt self.seed = "12345" + self.negative_prompt = 'lowres, bad anatomy, bad hands, cropped, worst quality' + self.direction = "Right Light" # TODO aida design 结果图背景改为白色 - # self.image, self.image_size = self.get_image(request_data.image_url) - self.image = request_data.image_url + self.image = self.get_image(request_data.image_url) # TODO image 填充并resize成512*768 self.tasks_id = request_data.tasks_id @@ -51,37 +52,8 @@ class GenerateRelightImage: def get_image(self, image_url): response = self.minio_client.get_object(image_url.split('/')[0], image_url[image_url.find('/') + 1:]) - image_bytes = io.BytesIO(response.read()) - - # 转换为PIL图像对象 - image = Image.open(image_bytes) - target_height = 
768 - target_width = 512 - - aspect_ratio = image.width / image.height - new_width = int(target_height * aspect_ratio) - - resized_image = image.resize((new_width, target_height)) - left = (target_width - resized_image.width) // 2 - top = (target_height - resized_image.height) // 2 - right = target_width - resized_image.width - left - bottom = target_height - resized_image.height - top - image = ImageOps.expand(resized_image, (left, top, right, bottom), fill="white") - image_size = image.size - if image.mode in ('RGBA', 'LA') or (image.mode == 'P' and 'transparency' in image.info): - # 创建白色背景 - background = Image.new("RGB", image.size, (255, 255, 255)) - # 将图片粘贴到白色背景上 - background.paste(image, mask=image.split()[3]) - image = np.array(background) - image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) - - # image_file = BytesIO(response.data) - # image_array = np.asarray(bytearray(image_file.read()), dtype=np.uint8) - # image_cv2 = cv2.imdecode(image_array, cv2.IMREAD_COLOR) - # image = cv2.cvtColor(image_cv2, cv2.COLOR_BGR2RGB) - # image = cv2.resize(image_rbg, (1024, 1024)) - return image, image_size + image = cv2.imdecode(np.frombuffer(response.data, np.uint8), 1) + return image def callback(self, result, error): if error: @@ -92,7 +64,7 @@ class GenerateRelightImage: else: # pil图像转成numpy数组 image = result.as_numpy("generated_inpaint_image") - image_result = Image.fromarray(np.squeeze(image.astype(np.uint8))).resize(self.image_size) + image_result = Image.fromarray(np.squeeze(image.astype(np.uint8))) image_url = upload_SDXL_image(image_result, user_id=self.user_id, category=f"{self.category}", object_name=f"{self.tasks_id}.png") # logger.info(f"upload image SUCCESS : {image_url}") @@ -114,47 +86,33 @@ class GenerateRelightImage: def get_result(self): try: - direction = "Right Light" - negative_prompt = 'lowres, bad anatomy, bad hands, cropped, worst quality' - self.prompt = 'beautiful woman, detailed face, sunshine, outdoor, warm atmosphere' prompts = [self.prompt] * 
self.batch_size - text_obj = np.array(prompts, dtype="object").reshape((-1, 1)) - input_text = grpcclient.InferInput( - "prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype) - ) + image = cv2.cvtColor(self.image, cv2.COLOR_BGR2RGB) + image = cv2.resize(image, (512, 768)) + images = [image.astype(np.uint8)] * self.batch_size + seeds = [self.seed] * self.batch_size + nagetive_prompts = [self.negative_prompt] * self.batch_size + directions = [self.direction] * self.batch_size + + text_obj = np.array(prompts, dtype="object").reshape((1)) + image_obj = np.array(images, dtype=np.uint8).reshape((768, 512, 3)) + na_text_obj = np.array(nagetive_prompts, dtype="object").reshape((1)) + seed_obj = np.array(seeds, dtype="object").reshape((1)) + direction_obj = np.array(directions, dtype="object").reshape((1)) + + input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype)) + input_image = grpcclient.InferInput("input_image", image_obj.shape, "UINT8") + input_natext = grpcclient.InferInput("negative_prompt", na_text_obj.shape, np_to_triton_dtype(na_text_obj.dtype)) + input_seed = grpcclient.InferInput("seed", seed_obj.shape, np_to_triton_dtype(seed_obj.dtype)) + input_direction = grpcclient.InferInput("direction", direction_obj.shape, np_to_triton_dtype(direction_obj.dtype)) + input_text.set_data_from_numpy(text_obj) + input_image.set_data_from_numpy(image_obj) + input_natext.set_data_from_numpy(na_text_obj) + input_seed.set_data_from_numpy(seed_obj) + input_direction.set_data_from_numpy(direction_obj) - negative_prompts = [negative_prompt] * self.batch_size - text_obj_neg = np.array(negative_prompts, dtype="object").reshape((-1, 1)) - input_text_neg = grpcclient.InferInput( - "negative_prompt", text_obj_neg.shape, np_to_triton_dtype(text_obj_neg.dtype) - ) - input_text_neg.set_data_from_numpy(text_obj_neg) - - seed = np.array(self.seed, dtype="object").reshape((-1, 1)) - input_seed = grpcclient.InferInput( - "seed", seed.shape, 
np_to_triton_dtype(seed.dtype) - ) - input_seed.set_data_from_numpy(seed) - - input_images = [self.image] * self.batch_size - text_obj_images = np.array(input_images, dtype="object").reshape((-1, 1)) - input_input_images = grpcclient.InferInput( - "input_image", text_obj_images.shape, np_to_triton_dtype(text_obj_images.dtype) - ) - input_input_images.set_data_from_numpy(text_obj_images) - - directions = [direction] * self.batch_size - text_obj_directions = np.array(directions, dtype="object").reshape((-1, 1)) - input_directions = grpcclient.InferInput( - "direction", text_obj_directions.shape, np_to_triton_dtype(text_obj_directions.dtype) - ) - input_directions.set_data_from_numpy(text_obj_directions) - - output_img = grpcclient.InferRequestedOutput("generated_image") - request_start = time.time() - - inputs = [input_text, input_text_neg, input_input_images, input_seed, input_directions] + inputs = [input_text, input_natext, input_image, input_seed, input_direction] ctx = self.infer(inputs) time_out = 600 @@ -179,9 +137,9 @@ class GenerateRelightImage: finally: dict_gen_product_data, str_gen_product_data = self.read_tasks_status() if DEBUG is False: - self.channel.basic_publish(exchange='', routing_key=GPI_RABBITMQ_QUEUES, body=str_gen_product_data) + self.channel.basic_publish(exchange='', routing_key=GRI_RABBITMQ_QUEUES, body=str_gen_product_data) # self.channel.basic_publish(exchange='', routing_key=GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES, body=str_gen_product_data) - logger.info(f" [x] Sent to: {GPI_RABBITMQ_QUEUES} data:@@@@ {json.dumps(dict_gen_product_data, indent=4)}") + logger.info(f" [x] Sent to: {GRI_RABBITMQ_QUEUES} data:@@@@ {json.dumps(dict_gen_product_data, indent=4)}") def infer_cancel(tasks_id): @@ -195,8 +153,9 @@ def infer_cancel(tasks_id): if __name__ == '__main__': rd = GenerateRelightImageModel( tasks_id="123-89", - prompt="beautiful woman, detailed face, sunshine, outdoor, warm atmosphere", - image_url="/workspace/i3.png", + # prompt="beautiful 
woman, detailed face, sunshine, outdoor, warm atmosphere", + prompt="", + image_url='aida-users/89/product_image/123-89.png' ) server = GenerateRelightImage(rd) print(server.get_result()) From 64e85a9c72c9208a841a7926164f65f4d00272d7 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Wed, 19 Jun 2024 16:58:21 +0800 Subject: [PATCH 13/21] =?UTF-8?q?feat=20=20=E4=BA=A7=E5=93=81=E5=9B=BE?= =?UTF-8?q?=E6=89=93=E5=85=89=E6=A8=A1=E5=9E=8B=E9=83=A8=E7=BD=B2=20fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/api_test.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/app/api/api_test.py b/app/api/api_test.py index 0504349..86ed25c 100644 --- a/app/api/api_test.py +++ b/app/api/api_test.py @@ -1,6 +1,6 @@ import logging from fastapi import APIRouter -from app.core.config import SR_RABBITMQ_QUEUES, GI_RABBITMQ_QUEUES, GPI_RABBITMQ_QUEUES +from app.core.config import SR_RABBITMQ_QUEUES, GI_RABBITMQ_QUEUES, GPI_RABBITMQ_QUEUES, GRI_RABBITMQ_QUEUES from fastapi import FastAPI, HTTPException from app.schemas.response_template import ResponseModel @@ -15,6 +15,7 @@ def test(id: int): "SR_RABBITMQ_QUEUES message": SR_RABBITMQ_QUEUES, "GI_RABBITMQ_QUEUES": GI_RABBITMQ_QUEUES, "GPI_RABBITMQ_QUEUES": GPI_RABBITMQ_QUEUES, + "GRI_RABBITMQ_QUEUES": GRI_RABBITMQ_QUEUES, } logger.info(data) if id == 1: From 20b0f81fce2e13b77e95a53b0dd000983d75044d Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Wed, 19 Jun 2024 16:59:46 +0800 Subject: [PATCH 14/21] =?UTF-8?q?feat=20=20=E4=BA=A7=E5=93=81=E5=9B=BE?= =?UTF-8?q?=E6=89=93=E5=85=89=E6=A8=A1=E5=9E=8B=E9=83=A8=E7=BD=B2=20fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/core/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/core/config.py b/app/core/config.py index 3932bf5..b574845 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -123,7 +123,7 @@ GPI_MODEL_NAME = 
'diffusion_ensemble_all' GPI_MODEL_URL = '10.1.1.240:10061' # Generate Single Logo service config -GRI_RABBITMQ_QUEUES = os.getenv("GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES", f"Relight{RABBITMQ_ENV}") +GRI_RABBITMQ_QUEUES = os.getenv("GEN_RELIGHT_IMAGE_RABBITMQ_QUEUES", f"Relight{RABBITMQ_ENV}") GRI_MODEL_NAME = 'diffusion_relight_ensemble' GRI_MODEL_URL = '10.1.1.150:8001' From d0597f4b4c6a27bbcf6d907152e869ac7256c27a Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Thu, 20 Jun 2024 16:23:02 +0800 Subject: [PATCH 15/21] =?UTF-8?q?feat=20=20=E4=BA=A7=E5=93=81=E5=9B=BE?= =?UTF-8?q?=E6=89=93=E5=85=89=E6=A8=A1=E5=9E=8B=E9=83=A8=E7=BD=B2=20fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/core/config.py | 6 +- .../generate_image/service_generate_image.py | 17 ++-- .../service_generate_product_image.py | 89 +++++++------------ .../service_generate_relight_image.py | 34 ++----- .../service_generate_single_logo.py | 22 +---- .../generate_image/utils/upload_sd_image.py | 37 ++++---- app/service/utils/oss_client.py | 70 +++++++++++++++ 7 files changed, 146 insertions(+), 129 deletions(-) create mode 100644 app/service/utils/oss_client.py diff --git a/app/core/config.py b/app/core/config.py index b574845..4e74711 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -19,6 +19,7 @@ class Settings(BaseSettings): LOGGING_CONFIG_FILE = os.path.join(BASE_DIR, 'logging_env.py') +OSS = "minio" DEBUG = False if DEBUG: LOGS_PATH = "logs/" @@ -47,7 +48,7 @@ S3_AWS_SECRET_ACCESS_KEY = "LNIwFFB27/QedtZ+Q/viVUoX9F5x1DbuM8N0DkD8" S3_REGION_NAME = "ap-east-1" # redis 配置 -REDIS_HOST = "10.1.1.240" +REDIS_HOST = "10.1.1.150" REDIS_PORT = "6379" REDIS_DB = "2" @@ -120,7 +121,8 @@ GEN_SINGLE_LOGO_RABBITMQ_QUEUES = os.getenv("GEN_SINGLE_LOGO_RABBITMQ_QUEUES", f # Generate Single Logo service config GPI_RABBITMQ_QUEUES = os.getenv("GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES", f"ToProductImage{RABBITMQ_ENV}") GPI_MODEL_NAME = 'diffusion_ensemble_all' 
-GPI_MODEL_URL = '10.1.1.240:10061' +# GPI_MODEL_URL = '10.1.1.240:10061' +GPI_MODEL_URL = '10.1.1.150:8001' # Generate Single Logo service config GRI_RABBITMQ_QUEUES = os.getenv("GEN_RELIGHT_IMAGE_RABBITMQ_QUEUES", f"Relight{RABBITMQ_ENV}") diff --git a/app/service/generate_image/service_generate_image.py b/app/service/generate_image/service_generate_image.py index 6f8d092..889aed7 100644 --- a/app/service/generate_image/service_generate_image.py +++ b/app/service/generate_image/service_generate_image.py @@ -25,6 +25,7 @@ from app.schemas.generate_image import GenerateImageModel from app.service.generate_image.utils.adjust_contrast import adjust_contrast from app.service.generate_image.utils.image_processing import remove_background, stain_detection, generate_category_recognition, autoLevels, luminance_adjust, face_detect_pic from app.service.generate_image.utils.upload_sd_image import upload_png_sd, upload_stain_png_sd +from app.service.utils.oss_client import get_image logger = logging.getLogger() @@ -36,7 +37,7 @@ class GenerateImage: self.channel = self.connection.channel() # self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) # self.channel = self.connection.channel() - self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) + # self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) self.grpc_client = grpcclient.InferenceServerClient(url=GI_MODEL_URL) self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) if request_data.mode == "img2img": @@ -63,10 +64,13 @@ class GenerateImage: # Read data from response. 
# read image use cv2 try: - response = self.minio_client.get_object(image_url.split('/')[0], image_url[image_url.find('/') + 1:]) - image_file = BytesIO(response.data) - image_array = np.asarray(bytearray(image_file.read()), dtype=np.uint8) - image_cv2 = cv2.imdecode(image_array, cv2.IMREAD_COLOR) + # response = self.minio_client.get_object(image_url.split('/')[0], image_url[image_url.find('/') + 1:]) + # image_file = BytesIO(response.data) + # image_array = np.asarray(bytearray(image_file.read()), dtype=np.uint8) + # image_cv2 = cv2.imdecode(image_array, cv2.IMREAD_COLOR) + # image_rbg = cv2.cvtColor(image_cv2, cv2.COLOR_BGR2RGB) + + image_cv2 = get_image(object_name=image_url, data_type="cv2") image_rbg = cv2.cvtColor(image_cv2, cv2.COLOR_BGR2RGB) image = cv2.resize(image_rbg, (1024, 1024)) except minio.error.S3Error: @@ -189,7 +193,8 @@ if __name__ == '__main__': prompt='skeleton sitting by the side of a river looking soulful, concert poster, 4k, artistic', image_url="", mode='txt2img', - category="test" + category="test", + gender="male" ) server = GenerateImage(rd) print(server.get_result()) diff --git a/app/service/generate_image/service_generate_product_image.py b/app/service/generate_image/service_generate_product_image.py index 2416d2c..dcdf09f 100644 --- a/app/service/generate_image/service_generate_product_image.py +++ b/app/service/generate_image/service_generate_product_image.py @@ -18,10 +18,10 @@ import numpy as np from PIL import Image, ImageOps from minio import Minio from tritonclient.utils import np_to_triton_dtype - from app.core.config import * -from app.schemas.generate_image import GenerateImageModel, GenerateProductImageModel +from app.schemas.generate_image import GenerateProductImageModel from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image +from app.service.utils.oss_client import oss_get_image logger = logging.getLogger() @@ -33,69 +33,29 @@ class GenerateProductImage: self.channel = self.connection.channel() # 
self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) # self.channel = self.connection.channel() - self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) + # self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) self.grpc_client = grpcclient.InferenceServerClient(url=GPI_MODEL_URL) self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) self.category = "product_image" self.batch_size = 1 self.prompt = request_data.prompt - # TODO aida design 结果图背景改为白色 - self.image, self.image_size = self.get_image(request_data.image_url) - # TODO image 填充并resize成512*768 - + self.image, self.image_size = pre_processing_image(request_data.image_url) self.tasks_id = request_data.tasks_id self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:] self.gen_product_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': ''} self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data)) self.redis_client.expire(self.tasks_id, 600) - def get_image(self, image_url): - response = self.minio_client.get_object(image_url.split('/')[0], image_url[image_url.find('/') + 1:]) - image_bytes = io.BytesIO(response.read()) - - # 转换为PIL图像对象 - image = Image.open(image_bytes) - target_height = 768 - target_width = 512 - - aspect_ratio = image.width / image.height - new_width = int(target_height * aspect_ratio) - - resized_image = image.resize((new_width, target_height)) - left = (target_width - resized_image.width) // 2 - top = (target_height - resized_image.height) // 2 - right = target_width - resized_image.width - left - bottom = target_height - resized_image.height - top - image = ImageOps.expand(resized_image, (left, top, right, bottom), fill="white") - image_size = image.size - if image.mode in ('RGBA', 'LA') or (image.mode == 'P' and 'transparency' in image.info): 
- # 创建白色背景 - background = Image.new("RGB", image.size, (255, 255, 255)) - # 将图片粘贴到白色背景上 - background.paste(image, mask=image.split()[3]) - image = np.array(background) - image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) - - # image_file = BytesIO(response.data) - # image_array = np.asarray(bytearray(image_file.read()), dtype=np.uint8) - # image_cv2 = cv2.imdecode(image_array, cv2.IMREAD_COLOR) - # image = cv2.cvtColor(image_cv2, cv2.COLOR_BGR2RGB) - # image = cv2.resize(image_rbg, (1024, 1024)) - return image, image_size - def callback(self, result, error): if error: self.gen_product_data['status'] = "FAILURE" self.gen_product_data['message'] = str(error) - # self.gen_product_data['data'] = str(error) self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data)) else: # pil图像转成numpy数组 image = result.as_numpy("generated_inpaint_image") image_result = Image.fromarray(np.squeeze(image.astype(np.uint8))).resize(self.image_size) - - image_url = upload_SDXL_image(image_result, user_id=self.user_id, category=f"{self.category}", object_name=f"{self.tasks_id}.png") - # logger.info(f"upload image SUCCESS : {image_url}") + image_url = upload_SDXL_image(image_result, user_id=self.user_id, category=f"{self.category}", file_name=f"{self.tasks_id}.png") self.gen_product_data['status'] = "SUCCESS" self.gen_product_data['message'] = "success" self.gen_product_data['image_url'] = str(image_url) @@ -105,13 +65,6 @@ class GenerateProductImage: status_data = self.redis_client.get(self.tasks_id) return json.loads(status_data), status_data - def infer(self, inputs): - return self.grpc_client.async_infer( - model_name=GPI_MODEL_NAME, - inputs=inputs, - callback=self.callback - ) - def get_result(self): try: prompts = [self.prompt] * self.batch_size @@ -129,11 +82,10 @@ class GenerateProductImage: input_image.set_data_from_numpy(image_obj) inputs = [input_text, input_image] - ctx = self.infer(inputs) + ctx = self.grpc_client.async_infer(model_name=GPI_MODEL_NAME, inputs=inputs, 
callback=self.callback) time_out = 600 while time_out > 0: gen_product_data, _ = self.read_tasks_status() - # logger.info(gen_product_data) if gen_product_data['status'] in ["REVOKED", "FAILURE"]: ctx.cancel() break @@ -141,7 +93,6 @@ class GenerateProductImage: break time_out -= 1 time.sleep(0.1) - # logger.info(time_out, gen_product_data) gen_product_data, _ = self.read_tasks_status() return gen_product_data except Exception as e: @@ -153,7 +104,6 @@ class GenerateProductImage: dict_gen_product_data, str_gen_product_data = self.read_tasks_status() if DEBUG is False: self.channel.basic_publish(exchange='', routing_key=GPI_RABBITMQ_QUEUES, body=str_gen_product_data) - # self.channel.basic_publish(exchange='', routing_key=GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES, body=str_gen_product_data) logger.info(f" [x] Sent to: {GPI_RABBITMQ_QUEUES} data:@@@@ {json.dumps(dict_gen_product_data, indent=4)}") @@ -165,11 +115,36 @@ def infer_cancel(tasks_id): return data +def pre_processing_image(image_url): + image = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url[image_url.find('/') + 1:], data_type="PIL") + + # resize 图片内尺寸 并贴到768-512的纯白图像上 + target_height = 768 + target_width = 512 + aspect_ratio = image.width / image.height + new_width = int(target_height * aspect_ratio) + resized_image = image.resize((new_width, target_height)) + left = (target_width - resized_image.width) // 2 + top = (target_height - resized_image.height) // 2 + right = target_width - resized_image.width - left + bottom = target_height - resized_image.height - top + image = ImageOps.expand(resized_image, (left, top, right, bottom), fill="white") + image_size = image.size + if image.mode in ('RGBA', 'LA') or (image.mode == 'P' and 'transparency' in image.info): + # 创建白色背景 + background = Image.new("RGB", image.size, (255, 255, 255)) + # 将图片粘贴到白色背景上 + background.paste(image, mask=image.split()[3]) + image = np.array(background) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + return image, 
image_size + + if __name__ == '__main__': rd = GenerateProductImageModel( tasks_id="123-89", prompt="", - # prompt="best quality, masterpiece. detailed, high-res, simple background, studio photography, extremely detailed, updo, detailed face, face, close-up, HDR, UHD, 8K realistic, Highly detailed, simple background, Studio lighting", + # prompt=" the best quality, masterpiece. detailed, high-res, simple background, studio photography, extremely detailed, updo, detailed face, face, close-up, HDR, UHD, 8K realistic, Highly detailed, simple background, Studio lighting", image_url="aida-results/result_00097282-ebb2-11ee-a822-b48351119060.png", ) server = GenerateProductImage(rd) diff --git a/app/service/generate_image/service_generate_relight_image.py b/app/service/generate_image/service_generate_relight_image.py index 7c7f4b1..8793c42 100644 --- a/app/service/generate_image/service_generate_relight_image.py +++ b/app/service/generate_image/service_generate_relight_image.py @@ -22,6 +22,7 @@ from tritonclient.utils import np_to_triton_dtype from app.core.config import * from app.schemas.generate_image import GenerateRelightImageModel from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image +from app.service.utils.oss_client import oss_get_image logger = logging.getLogger() @@ -31,43 +32,34 @@ class GenerateRelightImage: if DEBUG is False: self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) self.channel = self.connection.channel() - self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) + # self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) self.grpc_client = grpcclient.InferenceServerClient(url=GRI_MODEL_URL) self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) self.category = "relight_image" self.batch_size = 1 self.prompt = request_data.prompt - 
self.seed = "12345" + self.seed = "1" self.negative_prompt = 'lowres, bad anatomy, bad hands, cropped, worst quality' self.direction = "Right Light" - # TODO aida design 结果图背景改为白色 - self.image = self.get_image(request_data.image_url) - # TODO image 填充并resize成512*768 - + self.image_url = request_data.image_url + self.image = oss_get_image(bucket=self.image_url.split('/')[0], object_name=self.image_url[self.image_url.find('/') + 1:], data_type="cv2") self.tasks_id = request_data.tasks_id self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:] self.gen_product_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': ''} self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data)) self.redis_client.expire(self.tasks_id, 600) - def get_image(self, image_url): - response = self.minio_client.get_object(image_url.split('/')[0], image_url[image_url.find('/') + 1:]) - image = cv2.imdecode(np.frombuffer(response.data, np.uint8), 1) - return image - def callback(self, result, error): if error: self.gen_product_data['status'] = "FAILURE" self.gen_product_data['message'] = str(error) - # self.gen_product_data['data'] = str(error) self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data)) else: # pil图像转成numpy数组 image = result.as_numpy("generated_inpaint_image") image_result = Image.fromarray(np.squeeze(image.astype(np.uint8))) - image_url = upload_SDXL_image(image_result, user_id=self.user_id, category=f"{self.category}", object_name=f"{self.tasks_id}.png") - # logger.info(f"upload image SUCCESS : {image_url}") + image_url = upload_SDXL_image(image_result, user_id=self.user_id, category=f"{self.category}", file_name=f"{self.tasks_id}.png") self.gen_product_data['status'] = "SUCCESS" self.gen_product_data['message'] = "success" self.gen_product_data['image_url'] = str(image_url) @@ -77,13 +69,6 @@ class GenerateRelightImage: status_data = self.redis_client.get(self.tasks_id) return json.loads(status_data), status_data - 
def infer(self, inputs): - return self.grpc_client.async_infer( - model_name=GRI_MODEL_NAME, - inputs=inputs, - callback=self.callback - ) - def get_result(self): try: prompts = [self.prompt] * self.batch_size @@ -114,11 +99,10 @@ class GenerateRelightImage: inputs = [input_text, input_natext, input_image, input_seed, input_direction] - ctx = self.infer(inputs) + ctx = self.grpc_client.async_infer(model_name=GRI_MODEL_NAME, inputs=inputs, callback=self.callback) time_out = 600 while time_out > 0: gen_product_data, _ = self.read_tasks_status() - # logger.info(gen_product_data) if gen_product_data['status'] in ["REVOKED", "FAILURE"]: ctx.cancel() break @@ -126,7 +110,6 @@ class GenerateRelightImage: break time_out -= 1 time.sleep(0.1) - # logger.info(time_out, gen_product_data) gen_product_data, _ = self.read_tasks_status() return gen_product_data except Exception as e: @@ -138,7 +121,6 @@ class GenerateRelightImage: dict_gen_product_data, str_gen_product_data = self.read_tasks_status() if DEBUG is False: self.channel.basic_publish(exchange='', routing_key=GRI_RABBITMQ_QUEUES, body=str_gen_product_data) - # self.channel.basic_publish(exchange='', routing_key=GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES, body=str_gen_product_data) logger.info(f" [x] Sent to: {GRI_RABBITMQ_QUEUES} data:@@@@ {json.dumps(dict_gen_product_data, indent=4)}") @@ -154,7 +136,7 @@ if __name__ == '__main__': rd = GenerateRelightImageModel( tasks_id="123-89", # prompt="beautiful woman, detailed face, sunshine, outdoor, warm atmosphere", - prompt="", + prompt="Colorful black", image_url='aida-users/89/product_image/123-89.png' ) server = GenerateRelightImage(rd) diff --git a/app/service/generate_image/service_generate_single_logo.py b/app/service/generate_image/service_generate_single_logo.py index f3d1719..e3def3e 100644 --- a/app/service/generate_image/service_generate_single_logo.py +++ b/app/service/generate_image/service_generate_single_logo.py @@ -31,8 +31,6 @@ class GenerateSingleLogoImage: if DEBUG 
is False: self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) self.channel = self.connection.channel() - # self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) - # self.channel = self.connection.channel() self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) self.grpc_client = grpcclient.InferenceServerClient(url=GSL_MODEL_URL) self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) @@ -51,23 +49,15 @@ class GenerateSingleLogoImage: status_data = self.redis_client.get(self.tasks_id) return json.loads(status_data), status_data - def infer(self, inputs): - return self.grpc_client.async_infer( - model_name=GSL_MODEL_NAME, - inputs=inputs, - callback=self.callback - ) - def callback(self, result, error): if error: self.gen_single_logo_data['status'] = "FAILURE" self.gen_single_logo_data['message'] = str(error) - # self.generate_data['data'] = str(error) self.redis_client.set(self.tasks_id, json.dumps(self.gen_single_logo_data)) else: image = result.as_numpy("generated_image") image_result = Image.fromarray(np.squeeze(image.astype(np.uint8))) - image_url = upload_SDXL_image(image_result, user_id=self.user_id, category=f"{self.category}", object_name=f"{self.tasks_id}.png") + image_url = upload_SDXL_image(image_result, user_id=self.user_id, category=f"{self.category}", file_name=f"{self.tasks_id}.png") self.gen_single_logo_data['status'] = "SUCCESS" self.gen_single_logo_data['message'] = "success" self.gen_single_logo_data['image_url'] = str(image_url) @@ -81,25 +71,19 @@ class GenerateSingleLogoImage: input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype)) input_text.set_data_from_numpy(text_obj) - # negative_prompts text_obj_neg = np.array(self.negative_prompts, dtype="object").reshape((-1, 1)) - # print('text obj neg: ', text_obj_neg) input_text_neg = 
grpcclient.InferInput("negative_prompt", text_obj_neg.shape, np_to_triton_dtype(text_obj_neg.dtype)) input_text_neg.set_data_from_numpy(text_obj_neg) - # seed seed = np.array(self.seed, dtype="object").reshape((-1, 1)) input_seed = grpcclient.InferInput("seed", seed.shape, np_to_triton_dtype(seed.dtype)) input_seed.set_data_from_numpy(seed) - inputs = [input_text, input_text_neg, input_seed] - - ctx = self.infer(inputs) + ctx = self.grpc_client.async_infer(model_name=GSL_MODEL_NAME, inputs=inputs, callback=self.callback) time_out = 600 generate_data = None while time_out > 0: generate_data, _ = self.read_tasks_status() - # logger.info(generate_data) if generate_data['status'] in ["REVOKED", "FAILURE"]: ctx.cancel() break @@ -107,7 +91,6 @@ class GenerateSingleLogoImage: break time_out -= 1 time.sleep(0.1) - # logger.info(time_out, generate_data) return generate_data except Exception as e: raise Exception(str(e)) @@ -115,7 +98,6 @@ class GenerateSingleLogoImage: dict_generate_data, str_generate_data = self.read_tasks_status() if DEBUG is False: self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=str_generate_data) - # self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=str_generate_data) logger.info(f" [x] Sent {json.dumps(dict_generate_data, indent=4)}") diff --git a/app/service/generate_image/utils/upload_sd_image.py b/app/service/generate_image/utils/upload_sd_image.py index ec476f9..a63488c 100644 --- a/app/service/generate_image/utils/upload_sd_image.py +++ b/app/service/generate_image/utils/upload_sd_image.py @@ -16,8 +16,11 @@ from PIL import Image from minio import Minio from app.core.config import * +from app.service.utils.oss_client import oss_upload_image minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) + + # s3 = boto3.client('s3', aws_access_key_id=S3_ACCESS_KEY, aws_secret_access_key=S3_AWS_SECRET_ACCESS_KEY, region_name=S3_REGION_NAME) @@ -34,36 
+37,34 @@ minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET # except Exception as e: # print(f'上传到 S3 失败: {e}') -def upload_SDXL_image(image, user_id, category, object_name): +def upload_SDXL_image(image, user_id, category, file_name): try: image_data = io.BytesIO() image.save(image_data, format='PNG') image_data.seek(0) image_bytes = image_data.read() - minio_req = minio_client.put_object( - GI_MINIO_BUCKET, - f'{user_id}/{category}/{object_name}', - io.BytesIO(image_bytes), - len(image_bytes), - content_type='image/jpeg' - ) - image_url = f"aida-users/{minio_req.object_name}" + + # minio_req = minio_client.put_object( + # GI_MINIO_BUCKET, + # f'{user_id}/{category}/{file_name}', + # io.BytesIO(image_bytes), + # len(image_bytes), + # content_type='image/jpeg' + # ) + object_name = f'{user_id}/{category}/{file_name}' + req = oss_upload_image(bucket=GI_MINIO_BUCKET, object_name=object_name, image_bytes=image_bytes) + image_url = f"aida-users/{object_name}" return image_url except Exception as e: logging.warning(f"upload_png_mask runtime exception : {e}") -def upload_png_sd(image, user_id, category, object_name): +def upload_png_sd(image, user_id, category, file_name): try: _, img_byte_array = cv2.imencode('.jpg', image) - minio_req = minio_client.put_object( - GI_MINIO_BUCKET, - f'{user_id}/{category}/{object_name}', - io.BytesIO(img_byte_array), - len(img_byte_array), - content_type='image/jpeg' - ) - image_url = f"aida-users/{minio_req.object_name}" + object_name = f'{user_id}/{category}/{file_name}' + req = oss_upload_image(bucket=GI_MINIO_BUCKET, object_name=object_name, image_bytes=img_byte_array) + image_url = f"aida-users/{object_name}" return image_url except Exception as e: logging.warning(f"upload_png_mask runtime exception : {e}") diff --git a/app/service/utils/oss_client.py b/app/service/utils/oss_client.py new file mode 100644 index 0000000..b2d3b7d --- /dev/null +++ b/app/service/utils/oss_client.py @@ -0,0 +1,70 @@ 
+import io +import logging +from io import BytesIO + +import boto3 +import cv2 +import numpy as np +from PIL import Image +from minio import Minio + +from app.core.config import * + +logger = logging.getLogger() + + +# 获取图片 +def oss_get_image(bucket, object_name, data_type): + image_object = None + + try: + if OSS == "minio": + oss_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) + image_data = oss_client.get_object(bucket_name=bucket, object_name=object_name) + else: + oss_client = boto3.client('s3', aws_access_key_id=S3_ACCESS_KEY, aws_secret_access_key=S3_AWS_SECRET_ACCESS_KEY, region_name=S3_REGION_NAME) + image_data = oss_client.get_object(Bucket=bucket, Key=object_name)['Body'] + + if data_type == "cv2": + image_bytes = image_data.read() + image_array = np.frombuffer(image_bytes, np.uint8) # 转成8位无符号整型 + image_object = cv2.imdecode(image_array, cv2.IMREAD_COLOR) + else: + data_bytes = BytesIO(image_data.read()) + image_object = Image.open(data_bytes) + except Exception as e: + logger.warning(f"{OSS} | 获取图片出现异常 ######: {e}") + return image_object + + +def oss_upload_image(bucket, object_name, image_bytes): + req = None + try: + if OSS == "minio": + oss_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) + req = oss_client.put_object(bucket_name=bucket, object_name=object_name, data=io.BytesIO(image_bytes), length=len(image_bytes), content_type='image/png') + else: + oss_client = boto3.client('s3', aws_access_key_id=S3_ACCESS_KEY, aws_secret_access_key=S3_AWS_SECRET_ACCESS_KEY, region_name=S3_REGION_NAME) + req = oss_client.put_object(Bucket=AIDA_CLOTHING, Key=object_name, Body=image_bytes, ContentType='image/png') + except Exception as e: + logger.warning(f"{OSS} | 上传图片出现异常 ######: {e}") + return req + + +if __name__ == '__main__': + # url = "aida-results/result_0002186a-e631-11ee-86a6-b48351119060.png" + # url = 
"aida-collection-element/11523/Moodboard/f60af0d2-94c2-48f9-90ff-74b8e8a481b5.jpg" + # url = "aida-sys-image/images/female/outwear/0628000054.jpg" + # url = "aida-users/89/product_image/string-89.png" + # url = "aida-users/89/single_logo/123-89.png" + # url = 'aida-users/89/relight_image/123-89.png' + # url = 'aida-users/89/relight_image/123-89.png' + url = 'aida-users/89/relight_image/123-89.png' + read_type = "PIL" + if read_type == "cv2": + img = oss_get_image(bucket=url.split('/')[0], object_name=url[url.find('/') + 1:], data_type=read_type) + cv2.imshow("", img) + cv2.waitKey(0) + else: + img = oss_get_image(bucket=url.split('/')[0], object_name=url[url.find('/') + 1:], data_type=read_type) + img.show() From 2df1518a9957cda75b3df1bad3c30362dc50a9b7 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Fri, 21 Jun 2024 17:13:39 +0800 Subject: [PATCH 16/21] feat fix minio and s3 --- app/api/api_test.py | 3 +- .../attribute/service_att_recognition.py | 11 +- .../attribute/service_category_recognition.py | 10 +- .../design/items/pipelines/keypoints.py | 8 +- app/service/design/items/pipelines/loading.py | 60 +++++---- .../design/items/pipelines/painting.py | 77 +++++++----- app/service/design/items/pipelines/split.py | 2 +- app/service/design_pre_processing/service.py | 117 ++++++++++++------ .../generate_image/service_generate_image.py | 13 +- app/service/super_resolution/service.py | 33 +++-- app/service/utils/oss_client.py | 12 +- 11 files changed, 200 insertions(+), 146 deletions(-) diff --git a/app/api/api_test.py b/app/api/api_test.py index 86ed25c..0ff521a 100644 --- a/app/api/api_test.py +++ b/app/api/api_test.py @@ -1,6 +1,6 @@ import logging from fastapi import APIRouter -from app.core.config import SR_RABBITMQ_QUEUES, GI_RABBITMQ_QUEUES, GPI_RABBITMQ_QUEUES, GRI_RABBITMQ_QUEUES +from app.core.config import SR_RABBITMQ_QUEUES, GI_RABBITMQ_QUEUES, GPI_RABBITMQ_QUEUES, GRI_RABBITMQ_QUEUES, OSS from fastapi import FastAPI, HTTPException from 
app.schemas.response_template import ResponseModel @@ -16,6 +16,7 @@ def test(id: int): "GI_RABBITMQ_QUEUES": GI_RABBITMQ_QUEUES, "GPI_RABBITMQ_QUEUES": GPI_RABBITMQ_QUEUES, "GRI_RABBITMQ_QUEUES": GRI_RABBITMQ_QUEUES, + "local_oss_server": OSS } logger.info(data) if id == 1: diff --git a/app/service/attribute/service_att_recognition.py b/app/service/attribute/service_att_recognition.py index da71c16..ddcfd1c 100644 --- a/app/service/attribute/service_att_recognition.py +++ b/app/service/attribute/service_att_recognition.py @@ -11,12 +11,12 @@ from minio import Minio import tritonclient.http as httpclient from app.core.config import * from app.schemas.attribute_retrieve import AttributeRecognitionModel +from app.service.utils.oss_client import oss_get_image class AttributeRecognition: def __init__(self, const, request_data): - self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) - logging.info("实例化完成") + # self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) self.request_data = [] for i, sketch in enumerate(request_data): self.request_data.append( @@ -97,9 +97,10 @@ class AttributeRecognition: return res def get_image(self, url): - response = self.minio_client.get_object(url.split("/", 1)[0], url.split("/", 1)[1]) - img = np.frombuffer(response.data, np.uint8) # 转成8位无符号整型 - img = cv2.imdecode(img, cv2.IMREAD_COLOR) # 解码 + # response = self.minio_client.get_object(url.split("/", 1)[0], url.split("/", 1)[1]) + # img = np.frombuffer(response.data, np.uint8) # 转成8位无符号整型 + # img = cv2.imdecode(img, cv2.IMREAD_COLOR) # + img = oss_get_image(bucket=url.split("/", 1)[0], object_name=url.split("/", 1)[1], data_type="cv2") img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) return img diff --git a/app/service/attribute/service_category_recognition.py b/app/service/attribute/service_category_recognition.py index 18ee043..fb997e9 100644 --- 
a/app/service/attribute/service_category_recognition.py +++ b/app/service/attribute/service_category_recognition.py @@ -18,12 +18,13 @@ import torch from app.core.config import * from app.schemas.attribute_retrieve import CategoryRecognitionModel +from app.service.utils.oss_client import oss_get_image class CategoryRecognition: def __init__(self, request_data): self.attr_type = pd.read_csv(CATEGORY_PATH) - self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) + # self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) self.request_data = [] self.triton_client = httpclient.InferenceServerClient(url=ATT_TRITON_URL) for sketch in request_data: @@ -51,9 +52,10 @@ class CategoryRecognition: def get_image(self, url): # Get data of an object. # Read data from response. - response = self.minio_client.get_object(url.split("/", 1)[0], url.split("/", 1)[1]) - img = np.frombuffer(response.data, np.uint8) # 转成8位无符号整型 - img = cv2.imdecode(img, cv2.IMREAD_COLOR) # 解码 + # response = self.minio_client.get_object(url.split("/", 1)[0], url.split("/", 1)[1]) + # img = np.frombuffer(response.data, np.uint8) # 转成8位无符号整型 + # img = cv2.imdecode(img, cv2.IMREAD_COLOR) # 解码 + img = oss_get_image(bucket=url.split("/", 1)[0], object_name=url.split("/", 1)[1], data_type="cv2") img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) return img diff --git a/app/service/design/items/pipelines/keypoints.py b/app/service/design/items/pipelines/keypoints.py index 6cf1141..1f53ced 100644 --- a/app/service/design/items/pipelines/keypoints.py +++ b/app/service/design/items/pipelines/keypoints.py @@ -1,5 +1,6 @@ import logging import time + import numpy as np from pymilvus import MilvusClient @@ -71,11 +72,8 @@ class KeypointDetection(object): ] try: client = MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS) - start_time = time.time() - res = client.upsert( - 
collection_name=MILVUS_TABLE_KEYPOINT, - data=data, - ) + # start_time = time.time() + res = client.upsert(collection_name=MILVUS_TABLE_KEYPOINT, data=data) # logging.info(f"save keypoint time : {time.time() - start_time}") client.close() return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist())) diff --git a/app/service/design/items/pipelines/loading.py b/app/service/design/items/pipelines/loading.py index 2697006..a1a49a5 100644 --- a/app/service/design/items/pipelines/loading.py +++ b/app/service/design/items/pipelines/loading.py @@ -1,6 +1,5 @@ import io import logging -import time import cv2 import numpy as np @@ -8,6 +7,7 @@ from PIL import Image from minio import Minio from app.core.config import * +from app.service.utils.oss_client import oss_get_image from ..builder import PIPELINES @@ -70,11 +70,7 @@ class LoadImageFromFile(object): class LoadBodyImageFromFile(object): def __init__(self, body_path): self.body_path = body_path - self.minioClient = Minio( - f"{MINIO_URL}", - access_key=MINIO_ACCESS, - secret_key=MINIO_SECRET, - secure=MINIO_SECURE) + # self.minioClient = Minio(f"{MINIO_URL}", access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) # response = self.minioClient.get_object("aida-mannequins", "model_1693218345.2714431.png") @@ -82,33 +78,33 @@ class LoadBodyImageFromFile(object): def __call__(self, result): result["image_url"] = result['body_path'] = self.body_path result["name"] = "mannequin" - if not result['image_url'].lower().endswith(".png"): - logging.info(1) - bucket = self.body_path.split("/", 1)[0] - object_name = self.body_path.split("/", 1)[1] - new_object_name = f'{object_name[:object_name.rfind(".")]}.png' - image = self.minioClient.get_object(bucket, object_name) - image = Image.open(io.BytesIO(image.data)) - image = image.convert("RGBA") - data = image.getdata() - # - new_data = [] - for item in data: - if item[0] >= 230 and item[1] >= 230 and item[2] >= 230: - 
new_data.append((255, 255, 255, 0)) - else: - new_data.append(item) - image.putdata(new_data) - image_data = io.BytesIO() - image.save(image_data, format='PNG') - image_data.seek(0) - image_bytes = image_data.read() - image_path = f"{bucket}/{self.minioClient.put_object(bucket, new_object_name, io.BytesIO(image_bytes), len(image_bytes), content_type='image/png').object_name}" - self.body_path = image_path - result["image_url"] = result['body_path'] = self.body_path - response = self.minioClient.get_object(self.body_path.split("/", 1)[0], self.body_path.split("/", 1)[1]) + # if not result['image_url'].lower().endswith(".png"): + # bucket = self.body_path.split("/", 1)[0] + # object_name = self.body_path.split("/", 1)[1] + # new_object_name = f'{object_name[:object_name.rfind(".")]}.png' + # image = self.minioClient.get_object(bucket, object_name) + # image = Image.open(io.BytesIO(image.data)) + # image = image.convert("RGBA") + # data = image.getdata() + # # + # new_data = [] + # for item in data: + # if item[0] >= 230 and item[1] >= 230 and item[2] >= 230: + # new_data.append((255, 255, 255, 0)) + # else: + # new_data.append(item) + # image.putdata(new_data) + # image_data = io.BytesIO() + # image.save(image_data, format='PNG') + # image_data.seek(0) + # image_bytes = image_data.read() + # image_path = f"{bucket}/{self.minioClient.put_object(bucket, new_object_name, io.BytesIO(image_bytes), len(image_bytes), content_type='image/png').object_name}" + # self.body_path = image_path + # result["image_url"] = result['body_path'] = self.body_path + # response = self.minioClient.get_object(self.body_path.split("/", 1)[0], self.body_path.split("/", 1)[1]) # put_image_time = time.time() - result['body_image'] = Image.open(io.BytesIO(response.read())) + # result['body_image'] = Image.open(io.BytesIO(response.read())) + result['body_image'] = oss_get_image(bucket=self.body_path.split("/", 1)[0], object_name=self.body_path.split("/", 1)[1], data_type="PIL") # 
logging.info(f"Image.open time is : {time.time() - put_image_time}") return result diff --git a/app/service/design/items/pipelines/painting.py b/app/service/design/items/pipelines/painting.py index 6d88411..3c9c233 100644 --- a/app/service/design/items/pipelines/painting.py +++ b/app/service/design/items/pipelines/painting.py @@ -1,19 +1,16 @@ import random -from io import BytesIO + # import boto3 import cv2 import numpy as np from PIL import Image -from minio import Minio -from app.core.config import * +from app.service.utils.oss_client import oss_get_image from ..builder import PIPELINES -minio_client = Minio( - MINIO_URL, - access_key=MINIO_ACCESS, - secret_key=MINIO_SECRET, - secure=MINIO_SECURE) + +# minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) + # s3 = boto3.client('s3', aws_access_key_id=S3_ACCESS_KEY, aws_secret_access_key=S3_AWS_SECRET_ACCESS_KEY, region_name=S3_REGION_NAME) @@ -56,17 +53,18 @@ class Painting(object): @staticmethod def get_gradient(bucket_name, object_name): - image_data = minio_client.get_object(bucket_name, object_name) + # image_data = minio_client.get_object(bucket_name, object_name) # image_data = s3.get_object(Bucket=bucket_name, Key=object_name)['Body'] # 从数据流中读取图像 - image_bytes = image_data.read() + # image_bytes = image_data.read() # 将图像数据转换为numpy数组 - image_array = np.asarray(bytearray(image_bytes), dtype=np.uint8) + # image_array = np.asarray(bytearray(image_bytes), dtype=np.uint8) # 使用OpenCV解码图像数组 - image = cv2.imdecode(image_array, cv2.IMREAD_COLOR) + # image = cv2.imdecode(image_array, cv2.IMREAD_COLOR) + image = oss_get_image(bucket=bucket_name, object_name=object_name, data_type="cv2") return image @staticmethod @@ -494,16 +492,20 @@ class PrintPainting(object): if not 'IfSingle' in print_dict.keys(): print_dict['IfSingle'] = False - data = minio_client.get_object(print_dict['print_path_list'][0].split("/", 1)[0], print_dict['print_path_list'][0].split("/", 1)[1]) - # 
data = s3.get_object(Bucket=print_dict['print_path_list'][0].split("/", 1)[0], Key=print_dict['print_path_list'][0].split("/", 1)[1])['Body'] + # data = minio_client.get_object(print_dict['print_path_list'][0].split("/", 1)[0], print_dict['print_path_list'][0].split("/", 1)[1]) + # data_bytes = BytesIO(data.read()) + # image = Image.open(data_bytes) + # image_mode = image.mode - data_bytes = BytesIO(data.read()) - image = Image.open(data_bytes) - image_mode = image.mode + bucket_name = print_dict['print_path_list'][0].split("/", 1)[0] + object_name = print_dict['print_path_list'][0].split("/", 1)[1] + image = oss_get_image(bucket=bucket_name, object_name=object_name, data_type="cv2") # 判断图片格式,如果是RGBA 则贴在一张纯白图片上 防止透明转黑 - if image_mode == "RGBA": - new_background = Image.new('RGB', image.size, (255, 255, 255)) - new_background.paste(image, mask=image.split()[3]) + if image.shape[2] == 4: + image_rgb = cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA) + image_pil = Image.fromarray(image_rgb) + new_background = Image.new('RGB', image_pil.size, (255, 255, 255)) + new_background.paste(image_pil, mask=image_pil.split()[3]) image = new_background print_dict['image'] = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR) @@ -577,21 +579,30 @@ class PrintPainting(object): @staticmethod def read_image(image_url): - data = minio_client.get_object(image_url.split("/", 1)[0], image_url.split("/", 1)[1]) - # data = s3.get_object(Bucket=image_url.split("/", 1)[0], Key=image_url.split("/", 1)[1])['Body'] - - data_bytes = BytesIO(data.read()) - image = Image.open(data_bytes) - image_mode = image.mode - # 判断图片格式,如果是RGBA 则贴在一张纯白图片上 防止透明转黑 - if image_mode == "RGBA": - # new_background = Image.new('RGB', image.size, (255, 255, 255)) - # new_background.paste(image, mask=image.split()[3]) - # image = new_background - return image, image_mode - image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR) + image = oss_get_image(bucket=image_url.split("/", 1)[0], object_name=image_url.split("/", 1)[1], 
data_type="cv2") + if image.shape[2] == 4: + image_rgb = cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA) + image = Image.fromarray(image_rgb) + image_mode = "RGBA" + else: + image_mode = "RGB" return image, image_mode + # data = minio_client.get_object(image_url.split("/", 1)[0], image_url.split("/", 1)[1]) + # # data = s3.get_object(Bucket=image_url.split("/", 1)[0], Key=image_url.split("/", 1)[1])['Body'] + # + # data_bytes = BytesIO(data.read()) + # image = Image.open(data_bytes) + # image_mode = image.mode + # # 判断图片格式,如果是RGBA 则贴在一张纯白图片上 防止透明转黑 + # if image_mode == "RGBA": + # # new_background = Image.new('RGB', image.size, (255, 255, 255)) + # # new_background.paste(image, mask=image.split()[3]) + # # image = new_background + # return image, image_mode + # image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR) + # return image, "RGB" + # @staticmethod # def read_image(image_url): # response = requests.get(image_url) diff --git a/app/service/design/items/pipelines/split.py b/app/service/design/items/pipelines/split.py index e46a3e1..0183352 100644 --- a/app/service/design/items/pipelines/split.py +++ b/app/service/design/items/pipelines/split.py @@ -41,7 +41,7 @@ class Split(object): else: back_mask = result['back_mask'] - rgba_image = rgb_to_rgba((result['final_image'].shape[0], result['final_image'].shape[1]), result['final_image'], result['mask']) + rgba_image = rgb_to_rgba((result['final_image'].shape[0], result['final_image'].shape[1]), result['final_image'], result['mask']) result_front_image = np.zeros_like(rgba_image) result_front_image[front_mask != 0] = rgba_image[front_mask != 0] diff --git a/app/service/design_pre_processing/service.py b/app/service/design_pre_processing/service.py index e655087..88ed739 100644 --- a/app/service/design_pre_processing/service.py +++ b/app/service/design_pre_processing/service.py @@ -13,15 +13,12 @@ import io from app.core.config import * from app.service.design.utils.design_ensemble import get_keypoint_result +from 
app.service.utils.oss_client import oss_get_image, oss_upload_image class DesignPreprocessing: - def __init__(self): - self.minio_client = Minio( - MINIO_URL, - access_key=MINIO_ACCESS, - secret_key=MINIO_SECRET, - secure=MINIO_SECURE) + # def __init__(self): + # self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) # @ RunTime def pipeline(self, image_list): @@ -51,8 +48,9 @@ class DesignPreprocessing: def read_image(self, image_list): for obj in image_list: - file = self.minio_client.get_object(obj['image_url'].split("/", 1)[0], obj['image_url'].split("/", 1)[1]).data - image = cv2.imdecode(np.frombuffer(file, np.uint8), 1) + # file = self.minio_client.get_object(obj['image_url'].split("/", 1)[0], obj['image_url'].split("/", 1)[1]).data + # image = cv2.imdecode(np.frombuffer(file, np.uint8), 1) + image = oss_get_image(bucket=obj['image_url'].split("/", 1)[0], object_name=obj['image_url'].split("/", 1)[1], data_type="cv2") if len(image.shape) == 2: image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB) elif image.shape[2] == 4: # 如果是四通道 mask @@ -125,7 +123,10 @@ class DesignPreprocessing: try: # 覆盖到minio image_bytes = cv2.imencode(".jpg", item['obj'])[1].tobytes() - self.minio_client.put_object(item['image_url'].split("/", 1)[0], item['image_url'].split("/", 1)[1], io.BytesIO(image_bytes), len(image_bytes), content_type="image/jpeg", ) + # self.minio_client.put_object(item['image_url'].split("/", 1)[0], item['image_url'].split("/", 1)[1], io.BytesIO(image_bytes), len(image_bytes), content_type="image/jpeg", ) + bucket_name = item['image_url'].split("/", 1)[0] + object_name = item['image_url'].split("/", 1)[1] + oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes) print(f"Object '{item['image_url'].split('/', 1)[1]}' overwritten successfully.") except ResponseError as err: print(f"Error: {err}") @@ -165,36 +166,76 @@ class DesignPreprocessing: # @ RunTime def composing_image(self, 
image_list): for image in image_list: - if image['site'] == 'down': - image_width = image['obj'].shape[1] - waist_width = image['keypoint_result']['waistband_right'][1] - image['keypoint_result']['waistband_left'][1] - scale = 0.4 - if waist_width / scale >= image['obj'].shape[1]: - add_width = int((waist_width / scale - image_width) / 2) - ret = cv2.copyMakeBorder(image['obj'], 0, 0, add_width, add_width, cv2.BORDER_CONSTANT, value=(256, 256, 256)) - if IF_DEBUG_SHOW: - cv2.imshow("composing_image", ret) - cv2.waitKey(0) - image_bytes = cv2.imencode(".jpg", ret)[1].tobytes() - image['show_image_url'] = f"{image['image_url'].split('/', 1)[0]}/{self.minio_client.put_object(image['image_url'].split('/', 1)[0], image['image_url'].split('/', 1)[1].replace('.', '-show.'), io.BytesIO(image_bytes), len(image_bytes), content_type='image/jpeg').object_name}" - else: - image_bytes = cv2.imencode(".jpg", image['obj'])[1].tobytes() - image['show_image_url'] = f"{image['image_url'].split('/', 1)[0]}/{self.minio_client.put_object(image['image_url'].split('/', 1)[0], image['image_url'].split('/', 1)[1].replace('.', '-show.'), io.BytesIO(image_bytes), len(image_bytes), content_type='image/jpeg').object_name}" + ''' 比例相同 整合上下装代码''' + image_width = image['obj'].shape[1] + waist_width = image['keypoint_result']['waistband_right'][1] - image['keypoint_result']['waistband_left'][1] + scale = 0.4 + if waist_width / scale >= image_width: + add_width = int((waist_width / scale - image_width) / 2) + ret = cv2.copyMakeBorder(image['obj'], 0, 0, add_width, add_width, cv2.BORDER_CONSTANT, value=(256, 256, 256)) + if IF_DEBUG_SHOW: + cv2.imshow("composing_image", ret) + cv2.waitKey(0) + image_bytes = cv2.imencode(".jpg", ret)[1].tobytes() + # image['show_image_url'] = f"{image['image_url'].split('/', 1)[0]}/{self.minio_client.put_object(image['image_url'].split('/', 1)[0], image['image_url'].split('/', 1)[1].replace('.', '-show.'), io.BytesIO(image_bytes), len(image_bytes), 
content_type='image/jpeg').object_name}" + bucket_name = image['image_url'].split('/', 1)[0] + object_name = image['image_url'].split('/', 1)[1] + oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes) + image['show_image_url'] = f"{bucket_name}/{object_name}" else: - scale = 0.4 - image_width = image['obj'].shape[1] - waist_width = image['keypoint_result']['armpit_right'][1] - image['keypoint_result']['armpit_left'][1] - if waist_width / scale >= image_width: - add_width = int((waist_width / scale - image_width) / 2) - ret = cv2.copyMakeBorder(image['obj'], 0, 0, add_width, add_width, cv2.BORDER_CONSTANT, value=(256, 256, 256)) - if IF_DEBUG_SHOW: - cv2.imshow("composing_image", ret) - cv2.waitKey(0) - image_bytes = cv2.imencode(".jpg", ret)[1].tobytes() - image['show_image_url'] = f"{image['image_url'].split('/', 1)[0]}/{self.minio_client.put_object(image['image_url'].split('/', 1)[0], image['image_url'].split('/', 1)[1].replace('.', '-show.'), io.BytesIO(image_bytes), len(image_bytes), content_type='image/jpeg').object_name}" - else: - image_bytes = cv2.imencode(".jpg", image['obj'])[1].tobytes() - image['show_image_url'] = f"{image['image_url'].split('/', 1)[0]}/{self.minio_client.put_object(image['image_url'].split('/', 1)[0], image['image_url'].split('/', 1)[1].replace('.', '-show.'), io.BytesIO(image_bytes), len(image_bytes), content_type='image/jpeg').object_name}" + image_bytes = cv2.imencode(".jpg", image['obj'])[1].tobytes() + # image['show_image_url'] = f"{image['image_url'].split('/', 1)[0]}/{self.minio_client.put_object(image['image_url'].split('/', 1)[0], image['image_url'].split('/', 1)[1].replace('.', '-show.'), io.BytesIO(image_bytes), len(image_bytes), content_type='image/jpeg').object_name}" + bucket_name = image['image_url'].split('/', 1)[0] + object_name = image['image_url'].split('/', 1)[1] + oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes) + image['show_image_url'] = 
f"{bucket_name}/{object_name}" + + # if image['site'] == 'down': + # image_width = image['obj'].shape[1] + # waist_width = image['keypoint_result']['waistband_right'][1] - image['keypoint_result']['waistband_left'][1] + # scale = 0.4 + # if waist_width / scale >= image_width: + # add_width = int((waist_width / scale - image_width) / 2) + # ret = cv2.copyMakeBorder(image['obj'], 0, 0, add_width, add_width, cv2.BORDER_CONSTANT, value=(256, 256, 256)) + # if IF_DEBUG_SHOW: + # cv2.imshow("composing_image", ret) + # cv2.waitKey(0) + # image_bytes = cv2.imencode(".jpg", ret)[1].tobytes() + # # image['show_image_url'] = f"{image['image_url'].split('/', 1)[0]}/{self.minio_client.put_object(image['image_url'].split('/', 1)[0], image['image_url'].split('/', 1)[1].replace('.', '-show.'), io.BytesIO(image_bytes), len(image_bytes), content_type='image/jpeg').object_name}" + # bucket_name = image['image_url'].split('/', 1)[0] + # object_name = image['image_url'].split('/', 1)[1] + # oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes) + # image['show_image_url'] = f"{bucket_name}/{object_name}" + # else: + # image_bytes = cv2.imencode(".jpg", image['obj'])[1].tobytes() + # # image['show_image_url'] = f"{image['image_url'].split('/', 1)[0]}/{self.minio_client.put_object(image['image_url'].split('/', 1)[0], image['image_url'].split('/', 1)[1].replace('.', '-show.'), io.BytesIO(image_bytes), len(image_bytes), content_type='image/jpeg').object_name}" + # bucket_name = image['image_url'].split('/', 1)[0] + # object_name = image['image_url'].split('/', 1)[1] + # oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes) + # image['show_image_url'] = f"{bucket_name}/{object_name}" + # else: + # image_width = image['obj'].shape[1] + # waist_width = image['keypoint_result']['waistband_right'][1] - image['keypoint_result']['waistband_left'][1] + # scale = 0.4 + # if waist_width / scale >= image_width: + # add_width = 
int((waist_width / scale - image_width) / 2) + # ret = cv2.copyMakeBorder(image['obj'], 0, 0, add_width, add_width, cv2.BORDER_CONSTANT, value=(256, 256, 256)) + # if IF_DEBUG_SHOW: + # cv2.imshow("composing_image", ret) + # cv2.waitKey(0) + # image_bytes = cv2.imencode(".jpg", ret)[1].tobytes() + # # image['show_image_url'] = f"{image['image_url'].split('/', 1)[0]}/{self.minio_client.put_object(image['image_url'].split('/', 1)[0], image['image_url'].split('/', 1)[1].replace('.', '-show.'), io.BytesIO(image_bytes), len(image_bytes), content_type='image/jpeg').object_name}" + # bucket_name = image['image_url'].split('/', 1)[0] + # object_name = image['image_url'].split('/', 1)[1] + # oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes) + # image['show_image_url'] = f"{bucket_name}/{object_name}" + # else: + # image_bytes = cv2.imencode(".jpg", image['obj'])[1].tobytes() + # # image['show_image_url'] = f"{image['image_url'].split('/', 1)[0]}/{self.minio_client.put_object(image['image_url'].split('/', 1)[0], image['image_url'].split('/', 1)[1].replace('.', '-show.'), io.BytesIO(image_bytes), len(image_bytes), content_type='image/jpeg').object_name}" + # bucket_name = image['image_url'].split('/', 1)[0] + # object_name = image['image_url'].split('/', 1)[1] + # oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes) + # image['show_image_url'] = f"{bucket_name}/{object_name}" return image_list @staticmethod diff --git a/app/service/generate_image/service_generate_image.py b/app/service/generate_image/service_generate_image.py index 889aed7..d193de7 100644 --- a/app/service/generate_image/service_generate_image.py +++ b/app/service/generate_image/service_generate_image.py @@ -10,22 +10,17 @@ import json import logging import time -from io import BytesIO - import cv2 import minio import redis import tritonclient.grpc as grpcclient import numpy as np -from minio import Minio from tritonclient.utils import 
np_to_triton_dtype - from app.core.config import * from app.schemas.generate_image import GenerateImageModel -from app.service.generate_image.utils.adjust_contrast import adjust_contrast from app.service.generate_image.utils.image_processing import remove_background, stain_detection, generate_category_recognition, autoLevels, luminance_adjust, face_detect_pic -from app.service.generate_image.utils.upload_sd_image import upload_png_sd, upload_stain_png_sd -from app.service.utils.oss_client import get_image +from app.service.generate_image.utils.upload_sd_image import upload_png_sd +from app.service.utils.oss_client import oss_get_image logger = logging.getLogger() @@ -70,7 +65,7 @@ class GenerateImage: # image_cv2 = cv2.imdecode(image_array, cv2.IMREAD_COLOR) # image_rbg = cv2.cvtColor(image_cv2, cv2.COLOR_BGR2RGB) - image_cv2 = get_image(object_name=image_url, data_type="cv2") + image_cv2 = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url, data_type="cv2") image_rbg = cv2.cvtColor(image_cv2, cv2.COLOR_BGR2RGB) image = cv2.resize(image_rbg, (1024, 1024)) except minio.error.S3Error: @@ -197,4 +192,4 @@ if __name__ == '__main__': gender="male" ) server = GenerateImage(rd) - print(server.get_result()) + print(server.get_result()) \ No newline at end of file diff --git a/app/service/super_resolution/service.py b/app/service/super_resolution/service.py index e87f1a7..f864d01 100644 --- a/app/service/super_resolution/service.py +++ b/app/service/super_resolution/service.py @@ -1,17 +1,15 @@ -import io +import json import logging import time -import minio.error -import redis -import json import cv2 +import minio.error import numpy as np +import redis import torch import tritonclient.grpc as grpcclient -from minio import Minio from app.core.config import * from app.schemas.super_resolution import SuperResolutionModel -from app.service.utils.decorator import RunTime +from app.service.utils.oss_client import oss_get_image, oss_upload_image logger = 
logging.getLogger() @@ -24,7 +22,7 @@ class SuperResolution: self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:] self.sr_image_url = data.sr_image_url self.sr_xn = data.sr_xn - self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) + # self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) self.redis_client.set(self.tasks_id, json.dumps({'status': 'PENDING', 'message': "pending", 'data': ''})) self.redis_client.expire(self.tasks_id, 600) self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) @@ -33,16 +31,25 @@ class SuperResolution: # @RunTime def read_image(self): try: - image_data = self.minio_client.get_object(self.sr_image_url.split("/", 1)[0], self.sr_image_url.split("/", 1)[1]) + img = oss_get_image(bucket=self.sr_image_url.split("/", 1)[0], object_name=self.sr_image_url.split("/", 1)[1], data_type="cv2") except minio.error.S3Error as e: sr_data = json.dumps({'tasks_id': self.tasks_id, 'status': 'ERROR', 'message': f'{e}'}) self.channel.basic_publish(exchange='', routing_key=SR_RABBITMQ_QUEUES, body=sr_data) logger.info(f" [x] Sent {sr_data}") raise FileNotFoundError(f"Image '{self.sr_image_url.split('/', 1)[1]}' not found in bucket '{self.sr_image_url.split('/', 1)[0]}'") - img = np.frombuffer(image_data.data, np.uint8) # 转成8位无符号整型 - img = cv2.imdecode(img, cv2.IMREAD_COLOR).astype(np.float32) / 255. 
# 解码 return img + # try: + # image_data = self.minio_client.get_object(self.sr_image_url.split("/", 1)[0], self.sr_image_url.split("/", 1)[1]) + # except minio.error.S3Error as e: + # sr_data = json.dumps({'tasks_id': self.tasks_id, 'status': 'ERROR', 'message': f'{e}'}) + # self.channel.basic_publish(exchange='', routing_key=SR_RABBITMQ_QUEUES, body=sr_data) + # logger.info(f" [x] Sent {sr_data}") + # raise FileNotFoundError(f"Image '{self.sr_image_url.split('/', 1)[1]}' not found in bucket '{self.sr_image_url.split('/', 1)[0]}'") + # img = np.frombuffer(image_data.data, np.uint8) # 转成8位无符号整型 + # img = cv2.imdecode(img, cv2.IMREAD_COLOR).astype(np.float32) / 255. # 解码 + # return img + def read_tasks_status(self): status_data = json.loads(self.redis_client.get(self.tasks_id)) logging.info(f"{self.tasks_id} ===> {status_data}") @@ -101,8 +108,10 @@ class SuperResolution: def upload_img_sr(self, image): try: image_bytes = cv2.imencode('.jpg', image)[1].tobytes() - res = self.minio_client.put_object(f'{SR_MINIO_BUCKET}', f'{self.user_id}/sr/output/{self.tasks_id}.jpg', io.BytesIO(image_bytes), len(image_bytes), content_type='image/png') - image_url = f"aida-users/{res.object_name}" + # res = self.minio_client.put_object(f'{SR_MINIO_BUCKET}', f'{self.user_id}/sr/output/{self.tasks_id}.jpg', io.BytesIO(image_bytes), len(image_bytes), content_type='image/png') + object_name = f'{self.user_id}/sr/output/{self.tasks_id}.jpg' + oss_upload_image(bucket=SR_MINIO_BUCKET, object_name=object_name, image_bytes=image_bytes) + image_url = f"{SR_MINIO_BUCKET}/{object_name}" return image_url except Exception as e: logger.warning(f"upload_png_mask runtime exception : {e}") diff --git a/app/service/utils/oss_client.py b/app/service/utils/oss_client.py index b2d3b7d..e293117 100644 --- a/app/service/utils/oss_client.py +++ b/app/service/utils/oss_client.py @@ -15,8 +15,8 @@ logger = logging.getLogger() # 获取图片 def oss_get_image(bucket, object_name, data_type): + # cv2 默认全通道读取 
image_object = None - try: if OSS == "minio": oss_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) @@ -24,11 +24,10 @@ def oss_get_image(bucket, object_name, data_type): else: oss_client = boto3.client('s3', aws_access_key_id=S3_ACCESS_KEY, aws_secret_access_key=S3_AWS_SECRET_ACCESS_KEY, region_name=S3_REGION_NAME) image_data = oss_client.get_object(Bucket=bucket, Key=object_name)['Body'] - if data_type == "cv2": image_bytes = image_data.read() image_array = np.frombuffer(image_bytes, np.uint8) # 转成8位无符号整型 - image_object = cv2.imdecode(image_array, cv2.IMREAD_COLOR) + image_object = cv2.imdecode(image_array, cv2.IMREAD_UNCHANGED) else: data_bytes = BytesIO(image_data.read()) image_object = Image.open(data_bytes) @@ -56,11 +55,12 @@ if __name__ == '__main__': # url = "aida-collection-element/11523/Moodboard/f60af0d2-94c2-48f9-90ff-74b8e8a481b5.jpg" # url = "aida-sys-image/images/female/outwear/0628000054.jpg" # url = "aida-users/89/product_image/string-89.png" - # url = "aida-users/89/single_logo/123-89.png" + url = "aida-users/89/single_logo/123-89.png" # url = 'aida-users/89/relight_image/123-89.png' # url = 'aida-users/89/relight_image/123-89.png' - url = 'aida-users/89/relight_image/123-89.png' - read_type = "PIL" + # url = 'aida-users/89/relight_image/123-89.png' + # url = "aida-users/89/sketchboard/female/Dress/e6724ab7-8d3f-4677-abe0-c3e42ab7af85.jpeg" + read_type = "cv2" if read_type == "cv2": img = oss_get_image(bucket=url.split('/')[0], object_name=url[url.find('/') + 1:], data_type=read_type) cv2.imshow("", img) From 484659122aeb71b0041dbb05ff139b049e7e5e3c Mon Sep 17 00:00:00 2001 From: zchen Date: Sat, 22 Jun 2024 17:16:52 +0800 Subject: [PATCH 17/21] =?UTF-8?q?generate=20=E6=A8=A1=E5=9E=8B=E5=90=8D?= =?UTF-8?q?=E7=A7=B0=E6=9B=B4=E6=96=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/core/config.py | 7 +++---- 1 file changed, 3 insertions(+), 4 
deletions(-) diff --git a/app/core/config.py b/app/core/config.py index 4e74711..96dbaad 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -113,7 +113,7 @@ GI_SYS_IMAGE_URL = "aida-sys-image/generate_image/white_image.jpg" SLOGAN_RABBITMQ_QUEUES = os.getenv("SLOGAN_RABBITMQ_QUEUES", f"Slogan{RABBITMQ_ENV}") # Generate Single Logo service config -GSL_MODEL_URL = '10.1.1.240:10051' +GSL_MODEL_URL = '10.1.1.240:10041' GSL_MINIO_BUCKET = "aida-users" GSL_MODEL_NAME = 'stable_diffusion_xl' GEN_SINGLE_LOGO_RABBITMQ_QUEUES = os.getenv("GEN_SINGLE_LOGO_RABBITMQ_QUEUES", f"GenSingleLogo{RABBITMQ_ENV}") @@ -121,13 +121,12 @@ GEN_SINGLE_LOGO_RABBITMQ_QUEUES = os.getenv("GEN_SINGLE_LOGO_RABBITMQ_QUEUES", f # Generate Single Logo service config GPI_RABBITMQ_QUEUES = os.getenv("GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES", f"ToProductImage{RABBITMQ_ENV}") GPI_MODEL_NAME = 'diffusion_ensemble_all' -# GPI_MODEL_URL = '10.1.1.240:10061' -GPI_MODEL_URL = '10.1.1.150:8001' +GPI_MODEL_URL = '10.1.1.240:10041' # Generate Single Logo service config GRI_RABBITMQ_QUEUES = os.getenv("GEN_RELIGHT_IMAGE_RABBITMQ_QUEUES", f"Relight{RABBITMQ_ENV}") GRI_MODEL_NAME = 'diffusion_relight_ensemble' -GRI_MODEL_URL = '10.1.1.150:8001' +GRI_MODEL_URL = '10.1.1.240:10041' # SEG service config SEG_MODEL_URL = '10.1.1.240:10000' From dcfe0f71abc5a97551be2a5ddd1e6db6dec24895 Mon Sep 17 00:00:00 2001 From: zchen Date: Sat, 22 Jun 2024 17:27:01 +0800 Subject: [PATCH 18/21] =?UTF-8?q?generate=20=E6=A8=A1=E5=9E=8B=E5=90=8D?= =?UTF-8?q?=E7=A7=B0=E6=9B=B4=E6=96=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- requirements.txt | Bin 1232 -> 1246 bytes 1 file changed, 0 insertions(+), 0 deletions(-) diff --git a/requirements.txt b/requirements.txt index 7b3fa73dfc24137723e7ae7bbbff5c0ccbedbba0..68c778ce116e570ae67b442b8196e96fbd5a3079 100644 GIT binary patch delta 21 bcmcb>d5?305DRA#Lq0 Date: Sun, 23 Jun 2024 15:38:33 +0800 Subject: [PATCH 19/21] 
=?UTF-8?q?generate=20=E6=A8=A1=E5=9E=8B=E5=90=8D?= =?UTF-8?q?=E7=A7=B0=E6=9B=B4=E6=96=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/core/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/core/config.py b/app/core/config.py index 96dbaad..69fb6c2 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -115,7 +115,7 @@ SLOGAN_RABBITMQ_QUEUES = os.getenv("SLOGAN_RABBITMQ_QUEUES", f"Slogan{RABBITMQ_E # Generate Single Logo service config GSL_MODEL_URL = '10.1.1.240:10041' GSL_MINIO_BUCKET = "aida-users" -GSL_MODEL_NAME = 'stable_diffusion_xl' +GSL_MODEL_NAME = 'stable_diffusion_xl_transparent' GEN_SINGLE_LOGO_RABBITMQ_QUEUES = os.getenv("GEN_SINGLE_LOGO_RABBITMQ_QUEUES", f"GenSingleLogo{RABBITMQ_ENV}") # Generate Single Logo service config From 7266de9a484ae83f79a0190f80b835235c9b1672 Mon Sep 17 00:00:00 2001 From: zchen Date: Sun, 23 Jun 2024 16:06:09 +0800 Subject: [PATCH 20/21] =?UTF-8?q?generate=20=E6=A8=A1=E5=9E=8B=E5=90=8D?= =?UTF-8?q?=E7=A7=B0=E6=9B=B4=E6=96=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/core/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/core/config.py b/app/core/config.py index 69fb6c2..cfc04f5 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -126,7 +126,7 @@ GPI_MODEL_URL = '10.1.1.240:10041' # Generate Single Logo service config GRI_RABBITMQ_QUEUES = os.getenv("GEN_RELIGHT_IMAGE_RABBITMQ_QUEUES", f"Relight{RABBITMQ_ENV}") GRI_MODEL_NAME = 'diffusion_relight_ensemble' -GRI_MODEL_URL = '10.1.1.240:10041' +GRI_MODEL_URL = '10.1.1.240:10051' # SEG service config SEG_MODEL_URL = '10.1.1.240:10000' From 6c7c6b47af87f0d5aedbce28a3b1361fb6865ca1 Mon Sep 17 00:00:00 2001 From: zchen Date: Sun, 23 Jun 2024 16:07:11 +0800 Subject: [PATCH 21/21] =?UTF-8?q?generate=20=E6=A8=A1=E5=9E=8B=E5=90=8D?= =?UTF-8?q?=E7=A7=B0=E6=9B=B4=E6=96=B0?= MIME-Version: 1.0 Content-Type: 
text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- requirements.txt | Bin 1246 -> 1232 bytes 1 file changed, 0 insertions(+), 0 deletions(-) diff --git a/requirements.txt b/requirements.txt index 68c778ce116e570ae67b442b8196e96fbd5a3079..7b3fa73dfc24137723e7ae7bbbff5c0ccbedbba0 100644 GIT binary patch delta 11 Scmcb|d4Y3-5X)o*mVE#kNdx=< delta 21 bcmcb>d5?305DRA#Lq0