diff --git a/app/api/api_attribute_retrieve.py b/app/api/api_attribute_retrieve.py index 267b796..7a14e9d 100644 --- a/app/api/api_attribute_retrieve.py +++ b/app/api/api_attribute_retrieve.py @@ -2,9 +2,10 @@ import json import logging from fastapi import APIRouter, HTTPException +from app.core.config import DEBUG from app.schemas.attribute_retrieve import * from app.schemas.response_template import ResponseModel -from app.service.attribute.config import const +from app.service.attribute.config import const, local_debug_const from app.service.attribute.service_att_recognition import AttributeRecognition from app.service.attribute.service_category_recognition import CategoryRecognition @@ -17,13 +18,16 @@ logger = logging.getLogger() def attribute_recognition(request_item: list[AttributeRecognitionModel]): try: logger.info(f"attribute_recognition request item is : @@@@@@:{request_item}") - service = AttributeRecognition(const=const, request_data=request_item) + if DEBUG: + service = AttributeRecognition(const=local_debug_const, request_data=request_item) + else: + service = AttributeRecognition(const=const, request_data=request_item) data = service.get_result() logger.info(f"attribute_recognition response @@@@@@:{json.dumps(data, indent=4)}") except Exception as e: logger.warning(f"attribute_recognition Run Exception @@@@@@:{e}") raise HTTPException(status_code=404, detail=str(e)) - return ResponseModel(data=data) + return ResponseModel(data={"list": data}) # 类别识别 diff --git a/app/api/api_design.py b/app/api/api_design.py index c77d4c2..cdbd1f5 100644 --- a/app/api/api_design.py +++ b/app/api/api_design.py @@ -4,9 +4,10 @@ import time from fastapi import APIRouter, HTTPException -from app.schemas.design import DesignModel +from app.schemas.design import DesignModel, DesignProgressModel from app.schemas.response_template import ResponseModel from app.service.design.service import generate +from app.service.design.utils.redis_utils import Redis router = APIRouter() logger = logging.getLogger() @@ -22,3 +23,19 @@ def design(request_data: DesignModel): logger.warning(f"design Run Exception @@@@@@:{e}") raise HTTPException(status_code=404, detail=str(e)) return ResponseModel(data=data) + + +@router.post('/get_progress') +def get_progress(request_data: DesignProgressModel): + try: + logger.info(f"get_progress request item is : @@@@@@:{request_data}") + process_id = request_data.process_id + r = Redis() + data = r.read(key=process_id) + if data is None: + raise ValueError("The progress must be numbers ") + logging.info(f"get_progress process_id @@@@@@ : {process_id} , progress : {data}") + except Exception as e: + logger.warning(f"get_progress Run Exception @@@@@@:{e}") + raise HTTPException(status_code=404, detail=str(e)) + return ResponseModel(data=data) diff --git a/app/api/api_test.py b/app/api/api_test.py index 0504349..0ff521a 100644 --- a/app/api/api_test.py +++ b/app/api/api_test.py @@ -1,6 +1,6 @@ import logging from fastapi import APIRouter -from app.core.config import SR_RABBITMQ_QUEUES, GI_RABBITMQ_QUEUES, GPI_RABBITMQ_QUEUES +from app.core.config import SR_RABBITMQ_QUEUES, GI_RABBITMQ_QUEUES, GPI_RABBITMQ_QUEUES, GRI_RABBITMQ_QUEUES, OSS from fastapi import FastAPI, HTTPException from app.schemas.response_template import ResponseModel @@ -15,6 +15,8 @@ def test(id: int): "SR_RABBITMQ_QUEUES message": SR_RABBITMQ_QUEUES, "GI_RABBITMQ_QUEUES": GI_RABBITMQ_QUEUES, "GPI_RABBITMQ_QUEUES": GPI_RABBITMQ_QUEUES, + "GRI_RABBITMQ_QUEUES": GRI_RABBITMQ_QUEUES, + "local_oss_server": OSS } logger.info(data) if id == 1: diff --git a/app/core/config.py b/app/core/config.py index f35eb9c..cfc04f5 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -19,15 +19,16 @@ class Settings(BaseSettings): LOGGING_CONFIG_FILE = os.path.join(BASE_DIR, 'logging_env.py') +OSS = "minio" DEBUG = False if DEBUG: LOGS_PATH = "logs/" CATEGORY_PATH = "service/attribute/config/descriptor/category/category_dis.csv" - FACE_CLASSIFIER = "service/generate_image/utils/haarcascade_frontalface_alt.xml" + # FACE_CLASSIFIER = "service/generate_image/utils/haarcascade_frontalface_alt.xml" else: LOGS_PATH = "app/logs/" CATEGORY_PATH = "app/service/attribute/config/descriptor/category/category_dis.csv" - FACE_CLASSIFIER = 'app/service/generate_image/utils/haarcascade_frontalface_alt.xml' + # FACE_CLASSIFIER = 'app/service/generate_image/utils/haarcascade_frontalface_alt.xml' # RABBITMQ_ENV = "" # 生产环境 # RABBITMQ_ENV = "-dev" # 开发环境 @@ -47,7 +48,7 @@ S3_AWS_SECRET_ACCESS_KEY = "LNIwFFB27/QedtZ+Q/viVUoX9F5x1DbuM8N0DkD8" S3_REGION_NAME = "ap-east-1" # redis 配置 -REDIS_HOST = "10.1.1.240" +REDIS_HOST = "10.1.1.150" REDIS_PORT = "6379" REDIS_DB = "2" @@ -60,9 +61,9 @@ RABBITMQ_PARAMS = { } # milvus 配置 -MILVUS_DB_HOST = "10.1.1.240" +MILVUS_URL = "http://10.1.1.240:19530" +MILVUS_TOKEN = "root:Milvus" MILVUS_ALIAS = "default" -MILVUS_PORT = "19530" MILVUS_TABLE_KEYPOINT = "keypoint_cache" MILVUS_TABLE_SEG = "seg_cache" @@ -123,8 +124,8 @@ GPI_MODEL_NAME = 'diffusion_ensemble_all' GPI_MODEL_URL = '10.1.1.240:10041' # Generate Single Logo service config -GRI_RABBITMQ_QUEUES = os.getenv("GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES", f"Relight{RABBITMQ_ENV}") -GRI_MODEL_NAME = 'stable_diffusion_1_5' +GRI_RABBITMQ_QUEUES = os.getenv("GEN_RELIGHT_IMAGE_RABBITMQ_QUEUES", f"Relight{RABBITMQ_ENV}") +GRI_MODEL_NAME = 'diffusion_relight_ensemble' GRI_MODEL_URL = '10.1.1.240:10051' # SEG service config diff --git a/app/schemas/design.py b/app/schemas/design.py index b203970..994deb4 100644 --- a/app/schemas/design.py +++ b/app/schemas/design.py @@ -48,3 +48,7 @@ from pydantic import BaseModel class DesignModel(BaseModel): objects: list[dict] process_id: str + + +class DesignProgressModel(BaseModel): + process_id: str diff --git a/app/service/attribute/config/const.py b/app/service/attribute/config/const.py index 24d9412..738e486 100644 --- a/app/service/attribute/config/const.py +++ b/app/service/attribute/config/const.py @@ -1,13 +1,13 @@ -top_description_list = ['service/attribute/config/descriptor/top/length.csv', - 'service/attribute/config/descriptor/top/type.csv', - 'service/attribute/config/descriptor/top/sleeve_length.csv', - 'service/attribute/config/descriptor/top/sleeve_shape.csv', - 'service/attribute/config/descriptor/top/sleeve_shoulder.csv', - 'service/attribute/config/descriptor/top/neckline.csv', - 'service/attribute/config/descriptor/top/design.csv', - 'service/attribute/config/descriptor/top/opening_type.csv', - 'service/attribute/config/descriptor/top/silhouette.csv', - 'service/attribute/config/descriptor/top/collar.csv'] +top_description_list = ['app/service/attribute/config/descriptor/top/length.csv', + 'app/service/attribute/config/descriptor/top/type.csv', + 'app/service/attribute/config/descriptor/top/sleeve_length.csv', + 'app/service/attribute/config/descriptor/top/sleeve_shape.csv', + 'app/service/attribute/config/descriptor/top/sleeve_shoulder.csv', + 'app/service/attribute/config/descriptor/top/neckline.csv', + 'app/service/attribute/config/descriptor/top/design.csv', + 'app/service/attribute/config/descriptor/top/opening_type.csv', + 'app/service/attribute/config/descriptor/top/silhouette.csv', + 'app/service/attribute/config/descriptor/top/collar.csv'] top_model_list = ['attr_retrieve_T_length', 'attr_retrieve_T_type', @@ -22,11 +22,11 @@ top_model_list = ['attr_retrieve_T_length', ] bottom_description_list = [ - 'service/attribute/config/descriptor/bottom/subtype.csv', - 'service/attribute/config/descriptor/bottom/length.csv', - 'service/attribute/config/descriptor/bottom/silhouette.csv', - 'service/attribute/config/descriptor/bottom/opening_type.csv', - 'service/attribute/config/descriptor/bottom/design.csv'] + 'app/service/attribute/config/descriptor/bottom/subtype.csv', + 'app/service/attribute/config/descriptor/bottom/length.csv', + 'app/service/attribute/config/descriptor/bottom/silhouette.csv', + 'app/service/attribute/config/descriptor/bottom/opening_type.csv', + 'app/service/attribute/config/descriptor/bottom/design.csv'] bottom_model_list = [ 'attr_retrieve_B_subtype', @@ -35,14 +35,14 @@ bottom_model_list = [ 'attr_recong_B_optype', 'attr_retrieve_B_design'] -outwear_description_list = ['service/attribute/config/descriptor/outwear/length.csv', - 'service/attribute/config/descriptor/outwear/sleeve_length.csv', - 'service/attribute/config/descriptor/outwear/sleeve_shape.csv', - 'service/attribute/config/descriptor/outwear/sleeve_shoulder.csv', - 'service/attribute/config/descriptor/outwear/collar.csv', - 'service/attribute/config/descriptor/outwear/design.csv', - 'service/attribute/config/descriptor/outwear/opening_type.csv', - 'service/attribute/config/descriptor/outwear/silhouette.csv', ] +outwear_description_list = ['app/service/attribute/config/descriptor/outwear/length.csv', + 'app/service/attribute/config/descriptor/outwear/sleeve_length.csv', + 'app/service/attribute/config/descriptor/outwear/sleeve_shape.csv', + 'app/service/attribute/config/descriptor/outwear/sleeve_shoulder.csv', + 'app/service/attribute/config/descriptor/outwear/collar.csv', + 'app/service/attribute/config/descriptor/outwear/design.csv', + 'app/service/attribute/config/descriptor/outwear/opening_type.csv', + 'app/service/attribute/config/descriptor/outwear/silhouette.csv', ] outwear_model_list = ['attr_recong_O_length', 'attr_retrieve_O_sleeve_length', @@ -53,15 +53,15 @@ outwear_model_list = ['attr_recong_O_length', 'attr_recong_O_optype', 'attr_retrieve_O_silhouette'] -dress_description_list = [ # 'service/attribute/config/descriptor/dress/D_length.csv', - 'service/attribute/config/descriptor/dress/sleeve_length.csv', - 'service/attribute/config/descriptor/dress/sleeve_shape.csv', - # 'service/attribute/config/descriptor/dress/D_sleeve_shoulder.csv', - 'service/attribute/config/descriptor/dress/neckline.csv', - 'service/attribute/config/descriptor/dress/collar.csv', - 'service/attribute/config/descriptor/dress/design.csv', - 'service/attribute/config/descriptor/dress/silhouette.csv', - 'service/attribute/config/descriptor/dress/type.csv'] +dress_description_list = [ # 'app/service/attribute/config/descriptor/dress/D_length.csv', + 'app/service/attribute/config/descriptor/dress/sleeve_length.csv', + 'app/service/attribute/config/descriptor/dress/sleeve_shape.csv', + # 'app/service/attribute/config/descriptor/dress/D_sleeve_shoulder.csv', + 'app/service/attribute/config/descriptor/dress/neckline.csv', + 'app/service/attribute/config/descriptor/dress/collar.csv', + 'app/service/attribute/config/descriptor/dress/design.csv', + 'app/service/attribute/config/descriptor/dress/silhouette.csv', + 'app/service/attribute/config/descriptor/dress/type.csv'] dress_model_list = [ # 'attr_recong_D_length', 'attr_retrieve_D_sleeve_length', diff --git a/app/service/attribute/service_att_recognition.py b/app/service/attribute/service_att_recognition.py index da71c16..ddcfd1c 100644 --- a/app/service/attribute/service_att_recognition.py +++ b/app/service/attribute/service_att_recognition.py @@ -11,12 +11,12 @@ from minio import Minio import tritonclient.http as httpclient from app.core.config import * from app.schemas.attribute_retrieve import AttributeRecognitionModel +from app.service.utils.oss_client import oss_get_image class AttributeRecognition: def __init__(self, const, request_data): - self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) - logging.info("实例化完成") + # self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) self.request_data = [] for i, sketch in enumerate(request_data): self.request_data.append( @@ -97,9 +97,10 @@ class AttributeRecognition: return res def get_image(self, url): - response = self.minio_client.get_object(url.split("/", 1)[0], url.split("/", 1)[1]) - img = np.frombuffer(response.data, np.uint8) # 转成8位无符号整型 - img = cv2.imdecode(img, cv2.IMREAD_COLOR) # 解码 + # response = self.minio_client.get_object(url.split("/", 1)[0], url.split("/", 1)[1]) + # img = np.frombuffer(response.data, np.uint8) # 转成8位无符号整型 + # img = cv2.imdecode(img, cv2.IMREAD_COLOR) # + img = oss_get_image(bucket=url.split("/", 1)[0], object_name=url.split("/", 1)[1], data_type="cv2") img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) return img diff --git a/app/service/attribute/service_category_recognition.py b/app/service/attribute/service_category_recognition.py index 18ee043..fb997e9 100644 --- a/app/service/attribute/service_category_recognition.py +++ b/app/service/attribute/service_category_recognition.py @@ -18,12 +18,13 @@ import torch from app.core.config import * from app.schemas.attribute_retrieve import CategoryRecognitionModel +from app.service.utils.oss_client import oss_get_image class CategoryRecognition: def __init__(self, request_data): self.attr_type = pd.read_csv(CATEGORY_PATH) - self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) + # self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) self.request_data = [] self.triton_client = httpclient.InferenceServerClient(url=ATT_TRITON_URL) for sketch in request_data: @@ -51,9 +52,10 @@ class CategoryRecognition: def get_image(self, url): # Get data of an object. # Read data from response. - response = self.minio_client.get_object(url.split("/", 1)[0], url.split("/", 1)[1]) - img = np.frombuffer(response.data, np.uint8) # 转成8位无符号整型 - img = cv2.imdecode(img, cv2.IMREAD_COLOR) # 解码 + # response = self.minio_client.get_object(url.split("/", 1)[0], url.split("/", 1)[1]) + # img = np.frombuffer(response.data, np.uint8) # 转成8位无符号整型 + # img = cv2.imdecode(img, cv2.IMREAD_COLOR) # 解码 + img = oss_get_image(bucket=url.split("/", 1)[0], object_name=url.split("/", 1)[1], data_type="cv2") img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) return img diff --git a/app/service/design/items/pipelines/keypoints.py b/app/service/design/items/pipelines/keypoints.py index 4d0a081..1f53ced 100644 --- a/app/service/design/items/pipelines/keypoints.py +++ b/app/service/design/items/pipelines/keypoints.py @@ -1,5 +1,6 @@ import logging import time + import numpy as np from pymilvus import MilvusClient @@ -14,17 +15,17 @@ class KeypointDetection(object): path here: abstract path """ - def __init__(self): - self.client = MilvusClient( - uri="http://10.1.1.240:19530", - token="root:Milvus", - db_name=MILVUS_ALIAS - ) + # def __init__(self): + # self.client = MilvusClient( + # uri="http://10.1.1.240:19530", + # token="root:Milvus", + # db_name=MILVUS_ALIAS + # ) - def __del__(self): - # start_time = time.time() - self.client.close() - # print(f"client close time : {time.time() - start_time}") + # def __del__(self): + # start_time = time.time() + # self.client.close() + # print(f"client close time : {time.time() - start_time}") # @ RunTime def __call__(self, result): @@ -55,7 +56,7 @@ class KeypointDetection(object): @staticmethod # @ RunTime - def save_keypoint_cache(keypoint_id, cache, site, KEYPOINT_RESULT_TABLE_FIELD_SET=None): + def save_keypoint_cache(keypoint_id, cache, site): if site == "down": zeros = np.zeros(20, dtype=int) result = np.concatenate([zeros, cache.flatten()]) @@ -69,24 +70,16 @@ class KeypointDetection(object): "keypoint_vector": result.tolist() } ] - client = MilvusClient( - uri="http://10.1.1.240:19530", - token="root:Milvus", - db_name=MILVUS_ALIAS - ) try: - start_time = time.time() - res = client.upsert( - collection_name=MILVUS_TABLE_KEYPOINT, - data=data, - ) + client = MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS) + # start_time = time.time() + res = client.upsert(collection_name=MILVUS_TABLE_KEYPOINT, data=data) # logging.info(f"save keypoint time : {time.time() - start_time}") + client.close() return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist())) except Exception as e: logging.info(f"save keypoint cache milvus error : {e}") return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist())) - finally: - client.close() @staticmethod def update_keypoint_cache(keypoint_id, infer_result, search_result, site): @@ -102,12 +95,9 @@ class KeypointDetection(object): "keypoint_vector": result.tolist() } ] - client = MilvusClient( - uri="http://10.1.1.240:19530", - token="root:Milvus", - db_name=MILVUS_ALIAS - ) + try: + client = MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS) # connections.connect(alias=MILVUS_ALIAS, host=MILVUS_DB_HOST, port=MILVUS_PORT) start_time = time.time() # collection = Collection(MILVUS_TABLE_KEYPOINT) # Get an existing collection. @@ -125,8 +115,9 @@ class KeypointDetection(object): # @ RunTime def keypoint_cache(self, result, site): try: + client = MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS) keypoint_id = result['image_id'] - res = self.client.query( + res = client.query( collection_name=MILVUS_TABLE_KEYPOINT, # ids=[keypoint_id], filter=f"keypoint_id == {keypoint_id}", diff --git a/app/service/design/items/pipelines/loading.py b/app/service/design/items/pipelines/loading.py index 2697006..a1a49a5 100644 --- a/app/service/design/items/pipelines/loading.py +++ b/app/service/design/items/pipelines/loading.py @@ -1,6 +1,5 @@ import io import logging -import time import cv2 import numpy as np @@ -8,6 +7,7 @@ from PIL import Image from minio import Minio from app.core.config import * +from app.service.utils.oss_client import oss_get_image from ..builder import PIPELINES @@ -70,11 +70,7 @@ class LoadImageFromFile(object): class LoadBodyImageFromFile(object): def __init__(self, body_path): self.body_path = body_path - self.minioClient = Minio( - f"{MINIO_URL}", - access_key=MINIO_ACCESS, - secret_key=MINIO_SECRET, - secure=MINIO_SECURE) + # self.minioClient = Minio(f"{MINIO_URL}", access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) # response = self.minioClient.get_object("aida-mannequins", "model_1693218345.2714431.png") @@ -82,33 +78,33 @@ class LoadBodyImageFromFile(object): def __call__(self, result): result["image_url"] = result['body_path'] = self.body_path result["name"] = "mannequin" - if not result['image_url'].lower().endswith(".png"): - logging.info(1) - bucket = self.body_path.split("/", 1)[0] - object_name = self.body_path.split("/", 1)[1] - new_object_name = f'{object_name[:object_name.rfind(".")]}.png' - image = self.minioClient.get_object(bucket, object_name) - image = Image.open(io.BytesIO(image.data)) - image = image.convert("RGBA") - data = image.getdata() - # - new_data = [] - for item in data: - if item[0] >= 230 and item[1] >= 230 and item[2] >= 230: - new_data.append((255, 255, 255, 0)) - else: - new_data.append(item) - image.putdata(new_data) - image_data = io.BytesIO() - image.save(image_data, format='PNG') - image_data.seek(0) - image_bytes = image_data.read() - image_path = f"{bucket}/{self.minioClient.put_object(bucket, new_object_name, io.BytesIO(image_bytes), len(image_bytes), content_type='image/png').object_name}" - self.body_path = image_path - result["image_url"] = result['body_path'] = self.body_path - response = self.minioClient.get_object(self.body_path.split("/", 1)[0], self.body_path.split("/", 1)[1]) + # if not result['image_url'].lower().endswith(".png"): + # bucket = self.body_path.split("/", 1)[0] + # object_name = self.body_path.split("/", 1)[1] + # new_object_name = f'{object_name[:object_name.rfind(".")]}.png' + # image = self.minioClient.get_object(bucket, object_name) + # image = Image.open(io.BytesIO(image.data)) + # image = image.convert("RGBA") + # data = image.getdata() + # # + # new_data = [] + # for item in data: + # if item[0] >= 230 and item[1] >= 230 and item[2] >= 230: + # new_data.append((255, 255, 255, 0)) + # else: + # new_data.append(item) + # image.putdata(new_data) + # image_data = io.BytesIO() + # image.save(image_data, format='PNG') + # image_data.seek(0) + # image_bytes = image_data.read() + # image_path = f"{bucket}/{self.minioClient.put_object(bucket, new_object_name, io.BytesIO(image_bytes), len(image_bytes), content_type='image/png').object_name}" + # self.body_path = image_path + # result["image_url"] = result['body_path'] = self.body_path + # response = self.minioClient.get_object(self.body_path.split("/", 1)[0], self.body_path.split("/", 1)[1]) # put_image_time = time.time() - result['body_image'] = Image.open(io.BytesIO(response.read())) + # result['body_image'] = Image.open(io.BytesIO(response.read())) + result['body_image'] = oss_get_image(bucket=self.body_path.split("/", 1)[0], object_name=self.body_path.split("/", 1)[1], data_type="PIL") # logging.info(f"Image.open time is : {time.time() - put_image_time}") return result diff --git a/app/service/design/items/pipelines/painting.py b/app/service/design/items/pipelines/painting.py index 6d88411..3c9c233 100644 --- a/app/service/design/items/pipelines/painting.py +++ b/app/service/design/items/pipelines/painting.py @@ -1,19 +1,16 @@ import random -from io import BytesIO + # import boto3 import cv2 import numpy as np from PIL import Image -from minio import Minio -from app.core.config import * +from app.service.utils.oss_client import oss_get_image from ..builder import PIPELINES -minio_client = Minio( - MINIO_URL, - access_key=MINIO_ACCESS, - secret_key=MINIO_SECRET, - secure=MINIO_SECURE) + +# minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) + # s3 = boto3.client('s3', aws_access_key_id=S3_ACCESS_KEY, aws_secret_access_key=S3_AWS_SECRET_ACCESS_KEY, region_name=S3_REGION_NAME) @@ -56,17 +53,18 @@ class Painting(object): @staticmethod def get_gradient(bucket_name, object_name): - image_data = minio_client.get_object(bucket_name, object_name) + # image_data = minio_client.get_object(bucket_name, object_name) # image_data = s3.get_object(Bucket=bucket_name, Key=object_name)['Body'] # 从数据流中读取图像 - image_bytes = image_data.read() + # image_bytes = image_data.read() # 将图像数据转换为numpy数组 - image_array = np.asarray(bytearray(image_bytes), dtype=np.uint8) + # image_array = np.asarray(bytearray(image_bytes), dtype=np.uint8) # 使用OpenCV解码图像数组 - image = cv2.imdecode(image_array, cv2.IMREAD_COLOR) + # image = cv2.imdecode(image_array, cv2.IMREAD_COLOR) + image = oss_get_image(bucket=bucket_name, object_name=object_name, data_type="cv2") return image @staticmethod @@ -494,16 +492,20 @@ class PrintPainting(object): if not 'IfSingle' in print_dict.keys(): print_dict['IfSingle'] = False - data = minio_client.get_object(print_dict['print_path_list'][0].split("/", 1)[0], print_dict['print_path_list'][0].split("/", 1)[1]) - # data = s3.get_object(Bucket=print_dict['print_path_list'][0].split("/", 1)[0], Key=print_dict['print_path_list'][0].split("/", 1)[1])['Body'] + # data = minio_client.get_object(print_dict['print_path_list'][0].split("/", 1)[0], print_dict['print_path_list'][0].split("/", 1)[1]) + # data_bytes = BytesIO(data.read()) + # image = Image.open(data_bytes) + # image_mode = image.mode - data_bytes = BytesIO(data.read()) - image = Image.open(data_bytes) - image_mode = image.mode + bucket_name = print_dict['print_path_list'][0].split("/", 1)[0] + object_name = print_dict['print_path_list'][0].split("/", 1)[1] + image = oss_get_image(bucket=bucket_name, object_name=object_name, data_type="cv2") # 判断图片格式,如果是RGBA 则贴在一张纯白图片上 防止透明转黑 - if image_mode == "RGBA": - new_background = Image.new('RGB', image.size, (255, 255, 255)) - new_background.paste(image, mask=image.split()[3]) + if image.shape[2] == 4: + image_rgb = cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA) + image_pil = Image.fromarray(image_rgb) + new_background = Image.new('RGB', image_pil.size, (255, 255, 255)) + new_background.paste(image_pil, mask=image.split()[3]) image = new_background print_dict['image'] = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR) @@ -577,21 +579,30 @@ class PrintPainting(object): @staticmethod def read_image(image_url): - data = minio_client.get_object(image_url.split("/", 1)[0], image_url.split("/", 1)[1]) - # data = s3.get_object(Bucket=image_url.split("/", 1)[0], Key=image_url.split("/", 1)[1])['Body'] - - data_bytes = BytesIO(data.read()) - image = Image.open(data_bytes) - image_mode = image.mode - # 判断图片格式,如果是RGBA 则贴在一张纯白图片上 防止透明转黑 - if image_mode == "RGBA": - # new_background = Image.new('RGB', image.size, (255, 255, 255)) - # new_background.paste(image, mask=image.split()[3]) - # image = new_background - return image, image_mode - image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR) + image = oss_get_image(bucket=image_url.split("/", 1)[0], object_name=image_url.split("/", 1)[1], data_type="cv2") + if image.shape[2] == 4: + image_rgb = cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA) + image = Image.fromarray(image_rgb) + image_mode = "RGBA" + else: + image_mode = "RGB" return image, image_mode + # data = minio_client.get_object(image_url.split("/", 1)[0], image_url.split("/", 1)[1]) + # # data = s3.get_object(Bucket=image_url.split("/", 1)[0], Key=image_url.split("/", 1)[1])['Body'] + # + # data_bytes = BytesIO(data.read()) + # image = Image.open(data_bytes) + # image_mode = image.mode + # # 判断图片格式,如果是RGBA 则贴在一张纯白图片上 防止透明转黑 + # if image_mode == "RGBA": + # # new_background = Image.new('RGB', image.size, (255, 255, 255)) + # # new_background.paste(image, mask=image.split()[3]) + # # image = new_background + # return image, image_mode + # image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR) + # return image, "RGB" + # @staticmethod # def read_image(image_url): # response = requests.get(image_url) diff --git a/app/service/design/items/pipelines/split.py b/app/service/design/items/pipelines/split.py index e46a3e1..0183352 100644 --- a/app/service/design/items/pipelines/split.py +++ b/app/service/design/items/pipelines/split.py @@ -41,7 +41,7 @@ class Split(object): else: back_mask = result['back_mask'] - rgba_image = rgb_to_rgba((result['final_image'].shape[0], result['final_image'].shape[1]), result['final_image'], result['mask']) + rgba_image = rgb_to_rgba((result['final_image'].shape[0], result['final_image'].shape[1]), re4sult['final_image'], result['mask']) result_front_image = np.zeros_like(rgba_image) result_front_image[front_mask != 0] = rgba_image[front_mask != 0] diff --git a/app/service/design_pre_processing/service.py b/app/service/design_pre_processing/service.py index e655087..88ed739 100644 --- a/app/service/design_pre_processing/service.py +++ b/app/service/design_pre_processing/service.py @@ -13,15 +13,12 @@ import io from app.core.config import * from app.service.design.utils.design_ensemble import get_keypoint_result +from app.service.utils.oss_client import oss_get_image, oss_upload_image class DesignPreprocessing: - def __init__(self): - self.minio_client = Minio( - MINIO_URL, - access_key=MINIO_ACCESS, - secret_key=MINIO_SECRET, - secure=MINIO_SECURE) + # def __init__(self): + # self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) # @ RunTime def pipeline(self, image_list): @@ -51,8 +48,9 @@ class DesignPreprocessing: def read_image(self, image_list): for obj in image_list: - file = self.minio_client.get_object(obj['image_url'].split("/", 1)[0], obj['image_url'].split("/", 1)[1]).data - image = cv2.imdecode(np.frombuffer(file, np.uint8), 1) + # file = self.minio_client.get_object(obj['image_url'].split("/", 1)[0], obj['image_url'].split("/", 1)[1]).data + # image = cv2.imdecode(np.frombuffer(file, np.uint8), 1) + image = oss_get_image(bucket=obj['image_url'].split("/", 1)[0], object_name=obj['image_url'].split("/", 1)[1], data_type="cv2") if len(image.shape) == 2: image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB) elif image.shape[2] == 4: # 如果是四通道 mask @@ -125,7 +123,10 @@ class DesignPreprocessing: try: # 覆盖到minio image_bytes = cv2.imencode(".jpg", item['obj'])[1].tobytes() - self.minio_client.put_object(item['image_url'].split("/", 1)[0], item['image_url'].split("/", 1)[1], io.BytesIO(image_bytes), len(image_bytes), content_type="image/jpeg", ) + # self.minio_client.put_object(item['image_url'].split("/", 1)[0], item['image_url'].split("/", 1)[1], io.BytesIO(image_bytes), len(image_bytes), content_type="image/jpeg", ) + bucket_name = item['image_url'].split("/", 1)[0] + object_name = item['image_url'].split("/", 1)[1] + oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes) print(f"Object '{item['image_url'].split('/', 1)[1]}' overwritten successfully.") except ResponseError as err: print(f"Error: {err}") @@ -165,36 +166,76 @@ class DesignPreprocessing: # @ RunTime def composing_image(self, image_list): for image in image_list: - if image['site'] == 'down': - image_width = image['obj'].shape[1] - waist_width = image['keypoint_result']['waistband_right'][1] - image['keypoint_result']['waistband_left'][1] - scale = 0.4 - if waist_width / scale >= image['obj'].shape[1]: - add_width = int((waist_width / scale - image_width) / 2) - ret = cv2.copyMakeBorder(image['obj'], 0, 0, add_width, add_width, cv2.BORDER_CONSTANT, value=(256, 256, 256)) - if IF_DEBUG_SHOW: - cv2.imshow("composing_image", ret) - cv2.waitKey(0) - image_bytes = cv2.imencode(".jpg", ret)[1].tobytes() - image['show_image_url'] = f"{image['image_url'].split('/', 1)[0]}/{self.minio_client.put_object(image['image_url'].split('/', 1)[0], image['image_url'].split('/', 1)[1].replace('.', '-show.'), io.BytesIO(image_bytes), len(image_bytes), content_type='image/jpeg').object_name}" - else: - image_bytes = cv2.imencode(".jpg", image['obj'])[1].tobytes() - image['show_image_url'] = f"{image['image_url'].split('/', 1)[0]}/{self.minio_client.put_object(image['image_url'].split('/', 1)[0], image['image_url'].split('/', 1)[1].replace('.', '-show.'), io.BytesIO(image_bytes), len(image_bytes), content_type='image/jpeg').object_name}" + ''' 比例相同 整合上下装代码''' + image_width = image['obj'].shape[1] + waist_width = image['keypoint_result']['waistband_right'][1] - image['keypoint_result']['waistband_left'][1] + scale = 0.4 + if waist_width / scale >= image_width: + add_width = int((waist_width / scale - image_width) / 2) + ret = cv2.copyMakeBorder(image['obj'], 0, 0, add_width, add_width, cv2.BORDER_CONSTANT, value=(256, 256, 256)) + if IF_DEBUG_SHOW: + cv2.imshow("composing_image", ret) + cv2.waitKey(0) + image_bytes = cv2.imencode(".jpg", ret)[1].tobytes() + # image['show_image_url'] = f"{image['image_url'].split('/', 1)[0]}/{self.minio_client.put_object(image['image_url'].split('/', 1)[0], image['image_url'].split('/', 1)[1].replace('.', '-show.'), io.BytesIO(image_bytes), len(image_bytes), content_type='image/jpeg').object_name}" + bucket_name = image['image_url'].split('/', 1)[0] + object_name = image['image_url'].split('/', 1)[1] + oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes) + image['show_image_url'] = f"{bucket_name}/{object_name}" else: - scale = 0.4 - image_width = image['obj'].shape[1] - waist_width = image['keypoint_result']['armpit_right'][1] - image['keypoint_result']['armpit_left'][1] - if waist_width / scale >= image_width: - add_width = int((waist_width / scale - image_width) / 2) - ret = cv2.copyMakeBorder(image['obj'], 0, 0, add_width, add_width, cv2.BORDER_CONSTANT, value=(256, 256, 256)) - if IF_DEBUG_SHOW: - cv2.imshow("composing_image", ret) - cv2.waitKey(0) - image_bytes = cv2.imencode(".jpg", ret)[1].tobytes() - image['show_image_url'] = f"{image['image_url'].split('/', 1)[0]}/{self.minio_client.put_object(image['image_url'].split('/', 1)[0], image['image_url'].split('/', 1)[1].replace('.', '-show.'), io.BytesIO(image_bytes), len(image_bytes), content_type='image/jpeg').object_name}" - else: - image_bytes = cv2.imencode(".jpg", image['obj'])[1].tobytes() - image['show_image_url'] = f"{image['image_url'].split('/', 1)[0]}/{self.minio_client.put_object(image['image_url'].split('/', 1)[0], image['image_url'].split('/', 1)[1].replace('.', '-show.'), io.BytesIO(image_bytes), len(image_bytes), content_type='image/jpeg').object_name}" + image_bytes = cv2.imencode(".jpg", image['obj'])[1].tobytes() + # image['show_image_url'] = f"{image['image_url'].split('/', 1)[0]}/{self.minio_client.put_object(image['image_url'].split('/', 1)[0], image['image_url'].split('/', 1)[1].replace('.', '-show.'), io.BytesIO(image_bytes), len(image_bytes), content_type='image/jpeg').object_name}" + bucket_name = image['image_url'].split('/', 1)[0] + object_name = image['image_url'].split('/', 1)[1] + oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes) + image['show_image_url'] = f"{bucket_name}/{object_name}" + + # if image['site'] == 'down': + # image_width = image['obj'].shape[1] + # waist_width = image['keypoint_result']['waistband_right'][1] - image['keypoint_result']['waistband_left'][1] + # scale = 0.4 + # if waist_width / scale >= image_width: + # add_width = int((waist_width / scale - image_width) / 2) + # ret = cv2.copyMakeBorder(image['obj'], 0, 0, add_width, add_width, cv2.BORDER_CONSTANT, value=(256, 256, 256)) + # if IF_DEBUG_SHOW: + # cv2.imshow("composing_image", ret) + # cv2.waitKey(0) + # image_bytes = cv2.imencode(".jpg", ret)[1].tobytes() + # # image['show_image_url'] = f"{image['image_url'].split('/', 1)[0]}/{self.minio_client.put_object(image['image_url'].split('/', 1)[0], image['image_url'].split('/', 1)[1].replace('.', '-show.'), io.BytesIO(image_bytes), len(image_bytes), content_type='image/jpeg').object_name}" + # bucket_name = image['image_url'].split('/', 1)[0] + # object_name = image['image_url'].split('/', 1)[1] + # oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes) + # image['show_image_url'] = f"{bucket_name}/{object_name}" + # else: + # image_bytes = cv2.imencode(".jpg", image['obj'])[1].tobytes() + # # image['show_image_url'] = f"{image['image_url'].split('/', 1)[0]}/{self.minio_client.put_object(image['image_url'].split('/', 1)[0], image['image_url'].split('/', 1)[1].replace('.', '-show.'), io.BytesIO(image_bytes), len(image_bytes), content_type='image/jpeg').object_name}" + # bucket_name = image['image_url'].split('/', 1)[0] + # object_name = image['image_url'].split('/', 1)[1] + # oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes) + # image['show_image_url'] = f"{bucket_name}/{object_name}" + # else: + # image_width = image['obj'].shape[1] + # waist_width = image['keypoint_result']['waistband_right'][1] - image['keypoint_result']['waistband_left'][1] + # scale = 0.4 + # if waist_width / scale >= image_width: + # add_width = int((waist_width / scale - image_width) / 2) + # ret = cv2.copyMakeBorder(image['obj'], 0, 0, add_width, add_width, cv2.BORDER_CONSTANT, value=(256, 256, 256)) + # if IF_DEBUG_SHOW: + # cv2.imshow("composing_image", ret) + # cv2.waitKey(0) + # image_bytes = cv2.imencode(".jpg", ret)[1].tobytes() + # # image['show_image_url'] = f"{image['image_url'].split('/', 1)[0]}/{self.minio_client.put_object(image['image_url'].split('/', 1)[0], image['image_url'].split('/', 1)[1].replace('.', '-show.'), io.BytesIO(image_bytes), len(image_bytes), content_type='image/jpeg').object_name}" + # bucket_name = image['image_url'].split('/', 1)[0] + # object_name = image['image_url'].split('/', 1)[1] + # oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes) + # image['show_image_url'] = f"{bucket_name}/{object_name}" + # else: + # image_bytes = cv2.imencode(".jpg", image['obj'])[1].tobytes() + # # image['show_image_url'] = f"{image['image_url'].split('/', 1)[0]}/{self.minio_client.put_object(image['image_url'].split('/', 1)[0], image['image_url'].split('/', 1)[1].replace('.', '-show.'), io.BytesIO(image_bytes), len(image_bytes), content_type='image/jpeg').object_name}" + # bucket_name = image['image_url'].split('/', 1)[0] + # object_name = image['image_url'].split('/', 1)[1] + # oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes) + # image['show_image_url'] = f"{bucket_name}/{object_name}" return image_list @staticmethod diff --git a/app/service/generate_image/service_generate_image.py b/app/service/generate_image/service_generate_image.py index 6f8d092..d193de7 100644 --- a/app/service/generate_image/service_generate_image.py +++ b/app/service/generate_image/service_generate_image.py @@ -10,21 +10,17 @@ import json import logging import time -from io import BytesIO - import cv2 import minio import redis import tritonclient.grpc as grpcclient import numpy as np -from minio import Minio from tritonclient.utils import np_to_triton_dtype - from app.core.config import * from app.schemas.generate_image import GenerateImageModel -from app.service.generate_image.utils.adjust_contrast import adjust_contrast from app.service.generate_image.utils.image_processing import remove_background, stain_detection, generate_category_recognition, autoLevels, luminance_adjust, face_detect_pic -from app.service.generate_image.utils.upload_sd_image import upload_png_sd, upload_stain_png_sd +from app.service.generate_image.utils.upload_sd_image import upload_png_sd +from app.service.utils.oss_client import oss_get_image logger = logging.getLogger() @@ -36,7 +32,7 @@ class GenerateImage: self.channel = self.connection.channel() # self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) # self.channel = self.connection.channel() - self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) + # self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) self.grpc_client = grpcclient.InferenceServerClient(url=GI_MODEL_URL) self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) if request_data.mode == "img2img": @@ -63,10 +59,13 @@ class GenerateImage: # Read data from response. # read image use cv2 try: - response = self.minio_client.get_object(image_url.split('/')[0], image_url[image_url.find('/') + 1:]) - image_file = BytesIO(response.data) - image_array = np.asarray(bytearray(image_file.read()), dtype=np.uint8) - image_cv2 = cv2.imdecode(image_array, cv2.IMREAD_COLOR) + # response = self.minio_client.get_object(image_url.split('/')[0], image_url[image_url.find('/') + 1:]) + # image_file = BytesIO(response.data) + # image_array = np.asarray(bytearray(image_file.read()), dtype=np.uint8) + # image_cv2 = cv2.imdecode(image_array, cv2.IMREAD_COLOR) + # image_rbg = cv2.cvtColor(image_cv2, cv2.COLOR_BGR2RGB) + + image_cv2 = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url, data_type="cv2") image_rbg = cv2.cvtColor(image_cv2, cv2.COLOR_BGR2RGB) image = cv2.resize(image_rbg, (1024, 1024)) except minio.error.S3Error: @@ -189,7 +188,8 @@ if __name__ == '__main__': prompt='skeleton sitting by the side of a river looking soulful, concert poster, 4k, artistic', image_url="", mode='txt2img', - category="test" + category="test", + gender="male" ) server = GenerateImage(rd) - print(server.get_result()) + print(server.get_result()) \ No newline at end of file diff --git a/app/service/generate_image/service_generate_product_image.py b/app/service/generate_image/service_generate_product_image.py index ce449ea..dcdf09f 100644 --- a/app/service/generate_image/service_generate_product_image.py +++ b/app/service/generate_image/service_generate_product_image.py @@ -18,10 +18,10 @@ import numpy as np from PIL import Image, ImageOps from minio import Minio from tritonclient.utils import np_to_triton_dtype - from app.core.config import * -from app.schemas.generate_image import GenerateImageModel +from app.schemas.generate_image import GenerateProductImageModel from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image +from app.service.utils.oss_client import oss_get_image logger = logging.getLogger() @@ -33,69 +33,29 @@ class GenerateProductImage: self.channel = self.connection.channel() # self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) # self.channel = self.connection.channel() - self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) + # self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) self.grpc_client = grpcclient.InferenceServerClient(url=GPI_MODEL_URL) self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) self.category = "product_image" self.batch_size = 1 self.prompt = request_data.prompt - # TODO aida design 结果图背景改为白色 - self.image, self.image_size = self.get_image(request_data.image_url) - # TODO image 填充并resize成512*768 - + self.image, self.image_size = pre_processing_image(request_data.image_url) self.tasks_id = request_data.tasks_id self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:] self.gen_product_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': ''} self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data)) self.redis_client.expire(self.tasks_id, 600) - def get_image(self, image_url): - response = self.minio_client.get_object(image_url.split('/')[0], image_url[image_url.find('/') + 1:]) - image_bytes = io.BytesIO(response.read()) - - # 转换为PIL图像对象 - image = Image.open(image_bytes) - target_height = 768 - target_width = 512 - - aspect_ratio = image.width / image.height - new_width = int(target_height * aspect_ratio) - - resized_image = image.resize((new_width, target_height)) - left = (target_width - resized_image.width) // 2 - top = (target_height - resized_image.height) // 2 - right = target_width - resized_image.width - left - bottom = target_height - resized_image.height - top - image = ImageOps.expand(resized_image, (left, top, right, bottom), fill="white") - image_size = image.size - if image.mode in ('RGBA', 'LA') or (image.mode == 'P' and 'transparency' in image.info): - # 创建白色背景 - background = Image.new("RGB", image.size, (255, 255, 255)) - # 将图片粘贴到白色背景上 - background.paste(image, mask=image.split()[3]) - image = np.array(background) - image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) - - # image_file = BytesIO(response.data) - # image_array = np.asarray(bytearray(image_file.read()), dtype=np.uint8) - # image_cv2 = cv2.imdecode(image_array, cv2.IMREAD_COLOR) - # image = cv2.cvtColor(image_cv2, cv2.COLOR_BGR2RGB) - # image = cv2.resize(image_rbg, (1024, 1024)) - return image, image_size - def callback(self, result, error): if error: self.gen_product_data['status'] = "FAILURE" self.gen_product_data['message'] = str(error) - # self.gen_product_data['data'] = str(error) self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data)) else: # pil图像转成numpy数组 image = result.as_numpy("generated_inpaint_image") image_result = Image.fromarray(np.squeeze(image.astype(np.uint8))).resize(self.image_size) - - image_url = upload_SDXL_image(image_result, user_id=self.user_id, category=f"{self.category}", object_name=f"{self.tasks_id}.png") - # logger.info(f"upload image SUCCESS : {image_url}") + image_url = upload_SDXL_image(image_result, user_id=self.user_id, category=f"{self.category}", file_name=f"{self.tasks_id}.png") self.gen_product_data['status'] = "SUCCESS" self.gen_product_data['message'] = "success" self.gen_product_data['image_url'] = str(image_url) @@ -105,13 +65,6 @@ class GenerateProductImage: status_data = self.redis_client.get(self.tasks_id) return json.loads(status_data), status_data - def infer(self, inputs): - return self.grpc_client.async_infer( - model_name=GPI_MODEL_NAME, - inputs=inputs, - callback=self.callback - ) - def get_result(self): try: prompts = [self.prompt] * self.batch_size @@ -129,11 +82,10 @@ class GenerateProductImage: input_image.set_data_from_numpy(image_obj) inputs = [input_text, input_image] - ctx = self.infer(inputs) + ctx = self.grpc_client.async_infer(model_name=GPI_MODEL_NAME, inputs=inputs, callback=self.callback) time_out = 600 while time_out > 0: gen_product_data, _ = self.read_tasks_status() - # logger.info(gen_product_data) if gen_product_data['status'] in ["REVOKED", "FAILURE"]: ctx.cancel() break @@ -141,7 +93,6 @@ class GenerateProductImage: break time_out -= 1 time.sleep(0.1) - # logger.info(time_out, gen_product_data) gen_product_data, _ = self.read_tasks_status() return gen_product_data except Exception as e: @@ -153,7 +104,6 @@ class GenerateProductImage: dict_gen_product_data, str_gen_product_data = self.read_tasks_status() if DEBUG is False: self.channel.basic_publish(exchange='', routing_key=GPI_RABBITMQ_QUEUES, body=str_gen_product_data) - # self.channel.basic_publish(exchange='', routing_key=GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES, body=str_gen_product_data) logger.info(f" [x] Sent to: {GPI_RABBITMQ_QUEUES} data:@@@@ {json.dumps(dict_gen_product_data, indent=4)}") @@ -165,11 +115,37 @@ def infer_cancel(tasks_id): return data +def pre_processing_image(image_url): + image = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url[image_url.find('/') + 1:], data_type="PIL") + + # resize 图片内尺寸 并贴到768-512的纯白图像上 + target_height = 768 + target_width = 512 + aspect_ratio = image.width / image.height + new_width = int(target_height * aspect_ratio) + resized_image = image.resize((new_width, target_height)) + left = (target_width - resized_image.width) // 2 + top = (target_height - resized_image.height) // 2 + right = target_width - resized_image.width - left + bottom = target_height - resized_image.height - top + image = ImageOps.expand(resized_image, (left, top, right, bottom), fill="white") + image_size = image.size + if image.mode in ('RGBA', 'LA') or (image.mode == 'P' and 'transparency' in image.info): + # 创建白色背景 + background = Image.new("RGB", image.size, (255, 255, 255)) + # 将图片粘贴到白色背景上 + background.paste(image, mask=image.split()[3]) + image = np.array(background) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + return image, image_size + + if __name__ == '__main__': - rd = GenerateImageModel( + rd = GenerateProductImageModel( tasks_id="123-89", - prompt="best quality, masterpiece. detailed, high-res, simple background, studio photography, extremely detailed, updo, detailed face, face, close-up, HDR, UHD, 8K realistic, Highly detailed, simple background, Studio lighting", - image_url="aida-results/result_067f2f7e-21ba-11ef-8cf5-0242ac170002.png", + prompt="", + # prompt=" the best quality, masterpiece. detailed, high-res, simple background, studio photography, extremely detailed, updo, detailed face, face, close-up, HDR, UHD, 8K realistic, Highly detailed, simple background, Studio lighting", + image_url="aida-results/result_00097282-ebb2-11ee-a822-b48351119060.png", ) server = GenerateProductImage(rd) print(server.get_result()) diff --git a/app/service/generate_image/service_generate_relight_image.py b/app/service/generate_image/service_generate_relight_image.py index 0eacec9..8793c42 100644 --- a/app/service/generate_image/service_generate_relight_image.py +++ b/app/service/generate_image/service_generate_relight_image.py @@ -22,6 +22,7 @@ from tritonclient.utils import np_to_triton_dtype from app.core.config import * from app.schemas.generate_image import GenerateRelightImageModel from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image +from app.service.utils.oss_client import oss_get_image logger = logging.getLogger() @@ -31,71 +32,34 @@ class GenerateRelightImage: if DEBUG is False: self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) self.channel = self.connection.channel() - self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) + # self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) self.grpc_client = grpcclient.InferenceServerClient(url=GRI_MODEL_URL) self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) self.category = "relight_image" self.batch_size = 1 self.prompt = request_data.prompt - self.seed = "12345" - # TODO aida design 结果图背景改为白色 - # self.image, self.image_size = self.get_image(request_data.image_url) - self.image = request_data.image_url - # TODO image 填充并resize成512*768 - + self.seed = "1" + self.negative_prompt = 'lowres, bad anatomy, bad hands, cropped, worst quality' + self.direction = "Right Light" + self.image_url = request_data.image_url + self.image = oss_get_image(bucket=self.image_url.split('/')[0], object_name=self.image_url[self.image_url.find('/') + 1:], data_type="cv2") self.tasks_id = request_data.tasks_id self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:] self.gen_product_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': ''} self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data)) self.redis_client.expire(self.tasks_id, 600) - def get_image(self, image_url): - response = self.minio_client.get_object(image_url.split('/')[0], image_url[image_url.find('/') + 1:]) - image_bytes = io.BytesIO(response.read()) - - # 转换为PIL图像对象 - image = Image.open(image_bytes) - target_height = 768 - target_width = 512 - - aspect_ratio = image.width / image.height - new_width = int(target_height * aspect_ratio) - - resized_image = image.resize((new_width, target_height)) - left = (target_width - resized_image.width) // 2 - top = (target_height - resized_image.height) // 2 - right = target_width - resized_image.width - left - bottom = target_height - resized_image.height - top - image = ImageOps.expand(resized_image, (left, top, right, bottom), fill="white") - image_size = image.size - if image.mode in ('RGBA', 'LA') or (image.mode == 'P' and 'transparency' in image.info): - # 创建白色背景 - background = Image.new("RGB", image.size, (255, 255, 255)) - # 将图片粘贴到白色背景上 - background.paste(image, mask=image.split()[3]) - image = np.array(background) - image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) - - # image_file = BytesIO(response.data) - # image_array = np.asarray(bytearray(image_file.read()), dtype=np.uint8) - # image_cv2 = cv2.imdecode(image_array, cv2.IMREAD_COLOR) - # image = cv2.cvtColor(image_cv2, cv2.COLOR_BGR2RGB) - # image = cv2.resize(image_rbg, (1024, 1024)) - return image, image_size - def callback(self, result, error): if error: self.gen_product_data['status'] = "FAILURE" self.gen_product_data['message'] = str(error) - # self.gen_product_data['data'] = str(error) self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data)) else: # pil图像转成numpy数组 image = result.as_numpy("generated_inpaint_image") - image_result = Image.fromarray(np.squeeze(image.astype(np.uint8))).resize(self.image_size) + image_result = Image.fromarray(np.squeeze(image.astype(np.uint8))) - image_url = upload_SDXL_image(image_result, user_id=self.user_id, category=f"{self.category}", object_name=f"{self.tasks_id}.png") - # logger.info(f"upload image SUCCESS : {image_url}") + image_url = upload_SDXL_image(image_result, user_id=self.user_id, category=f"{self.category}", file_name=f"{self.tasks_id}.png") self.gen_product_data['status'] = "SUCCESS" self.gen_product_data['message'] = "success" self.gen_product_data['image_url'] = str(image_url) @@ -105,62 +69,40 @@ class GenerateRelightImage: status_data = self.redis_client.get(self.tasks_id) return json.loads(status_data), status_data - def infer(self, inputs): - return self.grpc_client.async_infer( - model_name=GRI_MODEL_NAME, - inputs=inputs, - callback=self.callback - ) - def get_result(self): try: - direction = "Right Light" - negative_prompt = 'lowres, bad anatomy, bad hands, cropped, worst quality' - self.prompt = 'beautiful woman, detailed face, sunshine, outdoor, warm atmosphere' prompts = [self.prompt] * self.batch_size - text_obj = np.array(prompts, dtype="object").reshape((-1, 1)) - input_text = grpcclient.InferInput( - "prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype) - ) + image = cv2.cvtColor(self.image, cv2.COLOR_BGR2RGB) + image = cv2.resize(image, (512, 768)) + images = [image.astype(np.uint8)] * self.batch_size + seeds = [self.seed] * self.batch_size + nagetive_prompts = [self.negative_prompt] * self.batch_size + directions = [self.direction] * self.batch_size + + text_obj = np.array(prompts, dtype="object").reshape((1)) + image_obj = np.array(images, dtype=np.uint8).reshape((768, 512, 3)) + na_text_obj = np.array(nagetive_prompts, dtype="object").reshape((1)) + seed_obj = np.array(seeds, dtype="object").reshape((1)) + direction_obj = np.array(directions, dtype="object").reshape((1)) + + input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype)) + input_image = grpcclient.InferInput("input_image", image_obj.shape, "UINT8") + input_natext = grpcclient.InferInput("negative_prompt", na_text_obj.shape, np_to_triton_dtype(na_text_obj.dtype)) + input_seed = grpcclient.InferInput("seed", seed_obj.shape, np_to_triton_dtype(seed_obj.dtype)) + input_direction = grpcclient.InferInput("direction", direction_obj.shape, np_to_triton_dtype(direction_obj.dtype)) + input_text.set_data_from_numpy(text_obj) + input_image.set_data_from_numpy(image_obj) + input_natext.set_data_from_numpy(na_text_obj) + input_seed.set_data_from_numpy(seed_obj) + input_direction.set_data_from_numpy(direction_obj) - negative_prompts = [negative_prompt] * self.batch_size - text_obj_neg = np.array(negative_prompts, dtype="object").reshape((-1, 1)) - input_text_neg = grpcclient.InferInput( - "negative_prompt", text_obj_neg.shape, np_to_triton_dtype(text_obj_neg.dtype) - ) - input_text_neg.set_data_from_numpy(text_obj_neg) + inputs = [input_text, input_natext, input_image, input_seed, input_direction] - seed = np.array(self.seed, dtype="object").reshape((-1, 1)) - input_seed = grpcclient.InferInput( - "seed", seed.shape, np_to_triton_dtype(seed.dtype) - ) - input_seed.set_data_from_numpy(seed) - - input_images = [self.image] * self.batch_size - text_obj_images = np.array(input_images, dtype="object").reshape((-1, 1)) - input_input_images = grpcclient.InferInput( - "input_image", text_obj_images.shape, np_to_triton_dtype(text_obj_images.dtype) - ) - input_input_images.set_data_from_numpy(text_obj_images) - - directions = [direction] * self.batch_size - text_obj_directions = np.array(directions, dtype="object").reshape((-1, 1)) - input_directions = grpcclient.InferInput( - "direction", text_obj_directions.shape, np_to_triton_dtype(text_obj_directions.dtype) - ) - input_directions.set_data_from_numpy(text_obj_directions) - - output_img = grpcclient.InferRequestedOutput("generated_image") - request_start = time.time() - - inputs = [input_text, input_text_neg, input_input_images, input_seed, input_directions] - - ctx = self.infer(inputs) + ctx = self.grpc_client.async_infer(model_name=GRI_MODEL_NAME, inputs=inputs, callback=self.callback) time_out = 600 while time_out > 0: gen_product_data, _ = self.read_tasks_status() - # logger.info(gen_product_data) if gen_product_data['status'] in ["REVOKED", "FAILURE"]: ctx.cancel() break @@ -168,7 +110,6 @@ class GenerateRelightImage: break time_out -= 1 time.sleep(0.1) - # logger.info(time_out, gen_product_data) gen_product_data, _ = self.read_tasks_status() return gen_product_data except Exception as e: @@ -179,9 +120,8 @@ class GenerateRelightImage: finally: dict_gen_product_data, str_gen_product_data = self.read_tasks_status() if DEBUG is False: - self.channel.basic_publish(exchange='', routing_key=GPI_RABBITMQ_QUEUES, body=str_gen_product_data) - # self.channel.basic_publish(exchange='', routing_key=GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES, body=str_gen_product_data) - logger.info(f" [x] Sent to: {GPI_RABBITMQ_QUEUES} data:@@@@ {json.dumps(dict_gen_product_data, indent=4)}") + self.channel.basic_publish(exchange='', routing_key=GRI_RABBITMQ_QUEUES, body=str_gen_product_data) + logger.info(f" [x] Sent to: {GRI_RABBITMQ_QUEUES} data:@@@@ {json.dumps(dict_gen_product_data, indent=4)}") def infer_cancel(tasks_id): @@ -195,8 +135,9 @@ def infer_cancel(tasks_id): if __name__ == '__main__': rd = GenerateRelightImageModel( tasks_id="123-89", - prompt="beautiful woman, detailed face, sunshine, outdoor, warm atmosphere", - image_url="/workspace/i3.png", + # prompt="beautiful woman, detailed face, sunshine, outdoor, warm atmosphere", + prompt="Colorful black", + image_url='aida-users/89/product_image/123-89.png' ) server = GenerateRelightImage(rd) print(server.get_result()) diff --git a/app/service/generate_image/service_generate_single_logo.py b/app/service/generate_image/service_generate_single_logo.py index f3d1719..e3def3e 100644 --- a/app/service/generate_image/service_generate_single_logo.py +++ b/app/service/generate_image/service_generate_single_logo.py @@ -31,8 +31,6 @@ class GenerateSingleLogoImage: if DEBUG is False: self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) self.channel = self.connection.channel() - # self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) - # self.channel = self.connection.channel() self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) self.grpc_client = grpcclient.InferenceServerClient(url=GSL_MODEL_URL) self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) @@ -51,23 +49,15 @@ class GenerateSingleLogoImage: status_data = self.redis_client.get(self.tasks_id) return json.loads(status_data), status_data - def infer(self, inputs): - return self.grpc_client.async_infer( - model_name=GSL_MODEL_NAME, - inputs=inputs, - callback=self.callback - ) - def callback(self, result, error): if error: self.gen_single_logo_data['status'] = "FAILURE" self.gen_single_logo_data['message'] = str(error) - # self.generate_data['data'] = str(error) self.redis_client.set(self.tasks_id, json.dumps(self.gen_single_logo_data)) else: image = result.as_numpy("generated_image") image_result = Image.fromarray(np.squeeze(image.astype(np.uint8))) - image_url = upload_SDXL_image(image_result, user_id=self.user_id, category=f"{self.category}", object_name=f"{self.tasks_id}.png") + image_url = upload_SDXL_image(image_result, user_id=self.user_id, category=f"{self.category}", file_name=f"{self.tasks_id}.png") self.gen_single_logo_data['status'] = "SUCCESS" self.gen_single_logo_data['message'] = "success" self.gen_single_logo_data['image_url'] = str(image_url) @@ -81,25 +71,19 @@ class GenerateSingleLogoImage: input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype)) input_text.set_data_from_numpy(text_obj) - # negative_prompts text_obj_neg = np.array(self.negative_prompts, dtype="object").reshape((-1, 1)) - # print('text obj neg: ', text_obj_neg) input_text_neg = grpcclient.InferInput("negative_prompt", text_obj_neg.shape, np_to_triton_dtype(text_obj_neg.dtype)) input_text_neg.set_data_from_numpy(text_obj_neg) - # seed seed = np.array(self.seed, dtype="object").reshape((-1, 1)) input_seed = grpcclient.InferInput("seed", seed.shape, np_to_triton_dtype(seed.dtype)) input_seed.set_data_from_numpy(seed) - inputs = [input_text, input_text_neg, input_seed] - - ctx = self.infer(inputs) + ctx = self.grpc_client.async_infer(model_name=GSL_MODEL_NAME, inputs=inputs, callback=self.callback) time_out = 600 generate_data = None while time_out > 0: generate_data, _ = self.read_tasks_status() - # logger.info(generate_data) if generate_data['status'] in ["REVOKED", "FAILURE"]: ctx.cancel() break @@ -107,7 +91,6 @@ class GenerateSingleLogoImage: break time_out -= 1 time.sleep(0.1) - # logger.info(time_out, generate_data) return generate_data except Exception as e: raise Exception(str(e)) @@ -115,7 +98,6 @@ class GenerateSingleLogoImage: dict_generate_data, str_generate_data = self.read_tasks_status() if DEBUG is False: self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=str_generate_data) - # self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=str_generate_data) logger.info(f" [x] Sent {json.dumps(dict_generate_data, indent=4)}") diff --git a/app/service/generate_image/utils/upload_sd_image.py b/app/service/generate_image/utils/upload_sd_image.py index ec476f9..a63488c 100644 --- a/app/service/generate_image/utils/upload_sd_image.py +++ b/app/service/generate_image/utils/upload_sd_image.py @@ -16,8 +16,11 @@ from PIL import Image from minio import Minio from app.core.config import * +from app.service.utils.oss_client import oss_upload_image minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) + + # s3 = boto3.client('s3', aws_access_key_id=S3_ACCESS_KEY, aws_secret_access_key=S3_AWS_SECRET_ACCESS_KEY, region_name=S3_REGION_NAME) @@ -34,36 +37,34 @@ minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET # except Exception as e: # print(f'上传到 S3 失败: {e}') -def upload_SDXL_image(image, user_id, category, object_name): +def upload_SDXL_image(image, user_id, category, file_name): try: image_data = io.BytesIO() image.save(image_data, format='PNG') image_data.seek(0) image_bytes = image_data.read() - minio_req = minio_client.put_object( - GI_MINIO_BUCKET, - f'{user_id}/{category}/{object_name}', - io.BytesIO(image_bytes), - len(image_bytes), - content_type='image/jpeg' - ) - image_url = f"aida-users/{minio_req.object_name}" + + # minio_req = minio_client.put_object( + # GI_MINIO_BUCKET, + # f'{user_id}/{category}/{file_name}', + # io.BytesIO(image_bytes), + # len(image_bytes), + # content_type='image/jpeg' + # ) + object_name = f'{user_id}/{category}/{file_name}' + req = oss_upload_image(bucket=GI_MINIO_BUCKET, object_name=object_name, image_bytes=image_bytes) + image_url = f"aida-users/{object_name}" return image_url except Exception as e: logging.warning(f"upload_png_mask runtime exception : {e}") -def upload_png_sd(image, user_id, category, object_name): +def upload_png_sd(image, user_id, category, file_name): try: _, img_byte_array = cv2.imencode('.jpg', image) - minio_req = minio_client.put_object( - GI_MINIO_BUCKET, - f'{user_id}/{category}/{object_name}', - io.BytesIO(img_byte_array), - len(img_byte_array), - content_type='image/jpeg' - ) - image_url = f"aida-users/{minio_req.object_name}" + object_name = f'{user_id}/{category}/{file_name}' + req = oss_upload_image(bucket=GI_MINIO_BUCKET, object_name=object_name, image_bytes=img_byte_array) + image_url = f"aida-users/{object_name}" return image_url except Exception as e: logging.warning(f"upload_png_mask runtime exception : {e}") diff --git a/app/service/super_resolution/service.py b/app/service/super_resolution/service.py index e87f1a7..f864d01 100644 --- a/app/service/super_resolution/service.py +++ b/app/service/super_resolution/service.py @@ -1,17 +1,15 @@ -import io +import json import logging import time -import minio.error -import redis -import json import cv2 +import minio.error import numpy as np +import redis import torch import tritonclient.grpc as grpcclient -from minio import Minio from app.core.config import * from app.schemas.super_resolution import SuperResolutionModel -from app.service.utils.decorator import RunTime +from app.service.utils.oss_client import oss_get_image, oss_upload_image logger = logging.getLogger() @@ -24,7 +22,7 @@ class SuperResolution: self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:] self.sr_image_url = data.sr_image_url self.sr_xn = data.sr_xn - self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) + # self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) self.redis_client.set(self.tasks_id, json.dumps({'status': 'PENDING', 'message': "pending", 'data': ''})) self.redis_client.expire(self.tasks_id, 600) self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) @@ -33,16 +31,25 @@ class SuperResolution: # @RunTime def read_image(self): try: - image_data = self.minio_client.get_object(self.sr_image_url.split("/", 1)[0], self.sr_image_url.split("/", 1)[1]) + img = oss_get_image(bucket=self.sr_image_url.split("/", 1)[0], object_name=self.sr_image_url.split("/", 1)[1], data_type="cv2") except minio.error.S3Error as e: sr_data = json.dumps({'tasks_id': self.tasks_id, 'status': 'ERROR', 'message': f'{e}'}) self.channel.basic_publish(exchange='', routing_key=SR_RABBITMQ_QUEUES, body=sr_data) logger.info(f" [x] Sent {sr_data}") raise FileNotFoundError(f"Image '{self.sr_image_url.split('/', 1)[1]}' not found in bucket '{self.sr_image_url.split('/', 1)[0]}'") - img = np.frombuffer(image_data.data, np.uint8) # 转成8位无符号整型 - img = cv2.imdecode(img, cv2.IMREAD_COLOR).astype(np.float32) / 255. # 解码 return img + # try: + # image_data = self.minio_client.get_object(self.sr_image_url.split("/", 1)[0], self.sr_image_url.split("/", 1)[1]) + # except minio.error.S3Error as e: + # sr_data = json.dumps({'tasks_id': self.tasks_id, 'status': 'ERROR', 'message': f'{e}'}) + # self.channel.basic_publish(exchange='', routing_key=SR_RABBITMQ_QUEUES, body=sr_data) + # logger.info(f" [x] Sent {sr_data}") + # raise FileNotFoundError(f"Image '{self.sr_image_url.split('/', 1)[1]}' not found in bucket '{self.sr_image_url.split('/', 1)[0]}'") + # img = np.frombuffer(image_data.data, np.uint8) # 转成8位无符号整型 + # img = cv2.imdecode(img, cv2.IMREAD_COLOR).astype(np.float32) / 255. # 解码 + # return img + def read_tasks_status(self): status_data = json.loads(self.redis_client.get(self.tasks_id)) logging.info(f"{self.tasks_id} ===> {status_data}") @@ -101,8 +108,10 @@ class SuperResolution: def upload_img_sr(self, image): try: image_bytes = cv2.imencode('.jpg', image)[1].tobytes() - res = self.minio_client.put_object(f'{SR_MINIO_BUCKET}', f'{self.user_id}/sr/output/{self.tasks_id}.jpg', io.BytesIO(image_bytes), len(image_bytes), content_type='image/png') - image_url = f"aida-users/{res.object_name}" + # res = self.minio_client.put_object(f'{SR_MINIO_BUCKET}', f'{self.user_id}/sr/output/{self.tasks_id}.jpg', io.BytesIO(image_bytes), len(image_bytes), content_type='image/png') + object_name = f'{self.user_id}/sr/output/{self.tasks_id}.jpg' + oss_upload_image(bucket=SR_MINIO_BUCKET, object_name=object_name, image_bytes=image_bytes) + image_url = f"{SR_MINIO_BUCKET}/{object_name}" return image_url except Exception as e: logger.warning(f"upload_png_mask runtime exception : {e}") diff --git a/app/service/utils/oss_client.py b/app/service/utils/oss_client.py new file mode 100644 index 0000000..e293117 --- /dev/null +++ b/app/service/utils/oss_client.py @@ -0,0 +1,70 @@ +import io +import logging +from io import BytesIO + +import boto3 +import cv2 +import numpy as np +from PIL import Image +from minio import Minio + +from app.core.config import * + +logger = logging.getLogger() + + +# 获取图片 +def oss_get_image(bucket, object_name, data_type): + # cv2 默认全通道读取 + image_object = None + try: + if OSS == "minio": + oss_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) + image_data = oss_client.get_object(bucket_name=bucket, object_name=object_name) + else: + oss_client = boto3.client('s3', aws_access_key_id=S3_ACCESS_KEY, aws_secret_access_key=S3_AWS_SECRET_ACCESS_KEY, region_name=S3_REGION_NAME) + image_data = oss_client.get_object(Bucket=bucket, Key=object_name)['Body'] + if data_type == "cv2": + image_bytes = image_data.read() + image_array = np.frombuffer(image_bytes, np.uint8) # 转成8位无符号整型 + image_object = cv2.imdecode(image_array, cv2.IMREAD_UNCHANGED) + else: + data_bytes = BytesIO(image_data.read()) + image_object = Image.open(data_bytes) + except Exception as e: + logger.warning(f"{OSS} | 获取图片出现异常 ######: {e}") + return image_object + + +def oss_upload_image(bucket, object_name, image_bytes): + req = None + try: + if OSS == "minio": + oss_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) + req = oss_client.put_object(bucket_name=bucket, object_name=object_name, data=io.BytesIO(image_bytes), length=len(image_bytes), content_type='image/png') + else: + oss_client = boto3.client('s3', aws_access_key_id=S3_ACCESS_KEY, aws_secret_access_key=S3_AWS_SECRET_ACCESS_KEY, region_name=S3_REGION_NAME) + req = oss_client.put_object(Bucket=AIDA_CLOTHING, Key=object_name, Body=image_bytes, ContentType='image/png') + except Exception as e: + logger.warning(f"{OSS} | 上传图片出现异常 ######: {e}") + return req + + +if __name__ == '__main__': + # url = "aida-results/result_0002186a-e631-11ee-86a6-b48351119060.png" + # url = "aida-collection-element/11523/Moodboard/f60af0d2-94c2-48f9-90ff-74b8e8a481b5.jpg" + # url = "aida-sys-image/images/female/outwear/0628000054.jpg" + # url = "aida-users/89/product_image/string-89.png" + url = "aida-users/89/single_logo/123-89.png" + # url = 'aida-users/89/relight_image/123-89.png' + # url = 'aida-users/89/relight_image/123-89.png' + # url = 'aida-users/89/relight_image/123-89.png' + # url = "aida-users/89/sketchboard/female/Dress/e6724ab7-8d3f-4677-abe0-c3e42ab7af85.jpeg" + read_type = "cv2" + if read_type == "cv2": + img = oss_get_image(bucket=url.split('/')[0], object_name=url[url.find('/') + 1:], data_type=read_type) + cv2.imshow("", img) + cv2.waitKey(0) + else: + img = oss_get_image(bucket=url.split('/')[0], object_name=url[url.find('/') + 1:], data_type=read_type) + img.show()