diff --git a/app/core/config.py b/app/core/config.py index 2da37a3..62b9cdf 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -19,7 +19,7 @@ class Settings(BaseSettings): LOGGING_CONFIG_FILE = os.path.join(BASE_DIR, 'logging_env.py') -DEBUG = False +DEBUG = True if DEBUG: LOGS_PATH = "logs/" CATEGORY_PATH = "service/attribute/config/descriptor/category/category_dis.csv" @@ -119,6 +119,9 @@ GEN_SINGLE_LOGO_RABBITMQ_QUEUES = os.getenv("GEN_SINGLE_LOGO_RABBITMQ_QUEUES", f # Generate Single Logo service config GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES = os.getenv("GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES", f"GenProductImage{RABBITMQ_ENV}") +GPI_MODEL_NAME = 'diffusion_ensemble_all' +GPI_MODEL_URL = '10.1.1.240:10061' + # SEG service config SEG_MODEL_URL = '10.1.1.240:10000' diff --git a/app/schemas/generate_image.py b/app/schemas/generate_image.py index 49cf9ce..fee4a92 100644 --- a/app/schemas/generate_image.py +++ b/app/schemas/generate_image.py @@ -5,9 +5,6 @@ class GenerateImageModel(BaseModel): tasks_id: str prompt: str image_url: str - mode: str - category: str - gender: str class GenerateSingleLogoImageModel(BaseModel): diff --git a/app/service/generate_image/service_generate_product_image.py b/app/service/generate_image/service_generate_product_image.py index ea875bd..84f7940 100644 --- a/app/service/generate_image/service_generate_product_image.py +++ b/app/service/generate_image/service_generate_product_image.py @@ -7,38 +7,42 @@ @Date :2023/7/26 12:01:05 @detail : """ +import io import json import logging import time -from io import BytesIO - import cv2 -import minio import redis import tritonclient.grpc as grpcclient import numpy as np +from PIL import Image, ImageOps from minio import Minio from tritonclient.utils import np_to_triton_dtype from app.core.config import * from app.schemas.generate_image import GenerateImageModel -from app.service.generate_image.utils.adjust_contrast import adjust_contrast -from app.service.generate_image.utils.image_processing 
import remove_background, stain_detection, generate_category_recognition, autoLevels, luminance_adjust, face_detect_pic -from app.service.generate_image.utils.upload_sd_image import upload_png_sd, upload_stain_png_sd +from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image logger = logging.getLogger() class GenerateProductImage: def __init__(self, request_data): - # if DEBUG is False: - # self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) - # self.channel = self.connection.channel() - self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) - self.channel = self.connection.channel() + if DEBUG is False: + self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) + self.channel = self.connection.channel() + # self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) + # self.channel = self.connection.channel() self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) - self.grpc_client = grpcclient.InferenceServerClient(url=GI_MODEL_URL) + self.grpc_client = grpcclient.InferenceServerClient(url=GPI_MODEL_URL) self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) + self.category = "product_image" + self.batch_size = 1 + self.prompt = request_data.prompt + # TODO aida design 结果图背景改为白色 + self.image, self.image_size = self.get_image(request_data.image_url) + # TODO image 填充并resize成512*768 + self.tasks_id = request_data.tasks_id self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:] self.gen_product_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': ''} @@ -46,63 +50,56 @@ class GenerateProductImage: self.redis_client.expire(self.tasks_id, 600) def get_image(self, image_url): - # Get data of an object. - # Read data from response. 
- # read image use cv2 - try: - response = self.minio_client.get_object(image_url.split('/')[0], image_url[image_url.find('/') + 1:]) - image_file = BytesIO(response.data) - image_array = np.asarray(bytearray(image_file.read()), dtype=np.uint8) - image_cv2 = cv2.imdecode(image_array, cv2.IMREAD_COLOR) - image_rbg = cv2.cvtColor(image_cv2, cv2.COLOR_BGR2RGB) - image = cv2.resize(image_rbg, (1024, 1024)) - except minio.error.S3Error: - image = np.random.randint(0, 256, (1024, 1024, 3), dtype=np.uint8) - return image + response = self.minio_client.get_object(image_url.split('/')[0], image_url[image_url.find('/') + 1:]) + image_bytes = io.BytesIO(response.read()) + + # 转换为PIL图像对象 + image = Image.open(image_bytes) + target_height = 768 + target_width = 512 + + aspect_ratio = image.width / image.height + new_width = int(target_height * aspect_ratio) + + resized_image = image.resize((new_width, target_height)) + left = (target_width - resized_image.width) // 2 + top = (target_height - resized_image.height) // 2 + right = target_width - resized_image.width - left + bottom = target_height - resized_image.height - top + image = ImageOps.expand(resized_image, (left, top, right, bottom), fill="white") + image_size = image.size + if image.mode in ('RGBA', 'LA') or (image.mode == 'P' and 'transparency' in image.info): + # 创建白色背景 + background = Image.new("RGB", image.size, (255, 255, 255)) + # 将图片粘贴到白色背景上 + background.paste(image, mask=image.split()[3]) + image = np.array(background) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + + # image_file = BytesIO(response.data) + # image_array = np.asarray(bytearray(image_file.read()), dtype=np.uint8) + # image_cv2 = cv2.imdecode(image_array, cv2.IMREAD_COLOR) + # image = cv2.cvtColor(image_cv2, cv2.COLOR_BGR2RGB) + # image = cv2.resize(image_rbg, (1024, 1024)) + return image, image_size def callback(self, result, error): if error: - self.generate_data['status'] = "FAILURE" - self.generate_data['message'] = str(error) - # 
self.generate_data['data'] = str(error) - self.redis_client.set(self.tasks_id, json.dumps(self.generate_data)) + self.gen_product_data['status'] = "FAILURE" + self.gen_product_data['message'] = str(error) + # self.gen_product_data['data'] = str(error) + self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data)) else: # pil图像转成numpy数组 - image = result.as_numpy("generated_image") - image_result = cv2.cvtColor(np.squeeze(image.astype(np.uint8)), cv2.COLOR_RGB2BGR) - is_smudge = True - if self.category == "sketch": - # 色阶调整 - cutoff = 1 - levels_img = autoLevels(image_result, cutoff) - # 亮度调整 - luminance = luminance_adjust(0.3, levels_img) - # 去背景 - remove_bg_image = remove_background(luminance) - # 人脸检测 - if face_detect_pic(remove_bg_image, self.user_id, self.category, self.tasks_id) > 0: - is_smudge = False - else: - # 污点/ - is_smudge, not_smudge_image = stain_detection(remove_bg_image, self.user_id, self.category, self.tasks_id) - # 类型识别 - category, scores, not_smudge_image = generate_category_recognition(image=remove_bg_image, gender=self.gender) - self.generate_data['category'] = str(category) - image_result = not_smudge_image - if is_smudge: # 无污点 - # image_result = adjust_contrast(image_result) - image_url = upload_png_sd(image_result, user_id=self.user_id, category=f"{self.category}", object_name=f"{self.tasks_id}.png") - # logger.info(f"upload image SUCCESS : {image_url}") - self.generate_data['status'] = "SUCCESS" - self.generate_data['message'] = "success" - self.generate_data['image_url'] = str(image_url) - self.redis_client.set(self.tasks_id, json.dumps(self.generate_data)) - else: # 有污点 保存图片到本地 测试用 - self.generate_data['status'] = "SUCCESS" - self.generate_data['message'] = "success" - self.generate_data['image_url'] = str(GI_SYS_IMAGE_URL) - self.redis_client.set(self.tasks_id, json.dumps(self.generate_data)) - # logger.info(f"stain_detection result : {self.generate_data}") + image = result.as_numpy("generated_inpaint_image") + image_result 
= Image.fromarray(np.squeeze(image.astype(np.uint8))).resize(self.image_size) + + image_url = upload_SDXL_image(image_result, user_id=self.user_id, category=f"{self.category}", object_name=f"{self.tasks_id}.png") + # logger.info(f"upload image SUCCESS : {image_url}") + self.gen_product_data['status'] = "SUCCESS" + self.gen_product_data['message'] = "success" + self.gen_product_data['image_url'] = str(image_url) + self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data)) def read_tasks_status(self): status_data = self.redis_client.get(self.tasks_id) @@ -110,46 +107,43 @@ class GenerateProductImage: def infer(self, inputs): return self.grpc_client.async_infer( - model_name=GI_MODEL_NAME, + model_name=GPI_MODEL_NAME, inputs=inputs, callback=self.callback ) def get_result(self): try: - # prompts = [self.prompt] * self.batch_size - # modes = [self.mode] * self.batch_size - # images = [self.image.astype(np.float16)] * self.batch_size - # - # text_obj = np.array(prompts, dtype="object").reshape((-1, 1)) - # mode_obj = np.array(modes, dtype="object").reshape((-1, 1)) - # image_obj = np.array(images, dtype=np.float16).reshape((-1, 1024, 1024, 3)) - # - # input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype)) - # input_image = grpcclient.InferInput("input_image", image_obj.shape, "FP16") - # input_mode = grpcclient.InferInput("mode", mode_obj.shape, np_to_triton_dtype(text_obj.dtype)) - # - # input_text.set_data_from_numpy(text_obj) - # input_image.set_data_from_numpy(image_obj) - # input_mode.set_data_from_numpy(mode_obj) - # - # inputs = [input_text, input_image, input_mode] - # ctx = self.infer(inputs) - # time_out = 600 - # generate_data = None - # while time_out > 0: - # generate_data, _ = self.read_tasks_status() - # # logger.info(generate_data) - # if generate_data['status'] in ["REVOKED", "FAILURE"]: - # ctx.cancel() - # break - # elif generate_data['status'] == "SUCCESS": - # break - # time_out -= 1 - # 
time.sleep(0.1) - # # logger.info(time_out, generate_data) - generate_data, _ = self.read_tasks_status() - return generate_data + prompts = [self.prompt] * self.batch_size + self.image = cv2.cvtColor(self.image, cv2.COLOR_BGR2RGB) + self.image = cv2.resize(self.image, (512, 768)) + images = [self.image.astype(np.uint8)] * self.batch_size + + text_obj = np.array(prompts, dtype="object").reshape(1) + image_obj = np.array(images, dtype=np.uint8).reshape((768, 512, 3)) + + input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype)) + input_image = grpcclient.InferInput("input_image", image_obj.shape, "UINT8") + + input_text.set_data_from_numpy(text_obj) + input_image.set_data_from_numpy(image_obj) + inputs = [input_text, input_image] + + ctx = self.infer(inputs) + time_out = 600 + while time_out > 0: + gen_product_data, _ = self.read_tasks_status() + # logger.info(gen_product_data) + if gen_product_data['status'] in ["REVOKED", "FAILURE"]: + ctx.cancel() + break + elif gen_product_data['status'] == "SUCCESS": + break + time_out -= 1 + time.sleep(0.1) + # logger.info(time_out, gen_product_data) + gen_product_data, _ = self.read_tasks_status() + return gen_product_data except Exception as e: self.gen_product_data['status'] = "FAILURE" self.gen_product_data['message'] = str(e) @@ -157,25 +151,25 @@ class GenerateProductImage: raise Exception(str(e)) finally: dict_gen_product_data, str_gen_product_data = self.read_tasks_status() - # if DEBUG is False: - # self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=str_generate_data) - self.channel.basic_publish(exchange='', routing_key=GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES, body=str_gen_product_data) + if DEBUG is False: + self.channel.basic_publish(exchange='', routing_key=GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES, body=str_gen_product_data) + # self.channel.basic_publish(exchange='', routing_key=GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES, body=str_gen_product_data) logger.info(f" [x] Sent 
{json.dumps(dict_gen_product_data, indent=4)}") def infer_cancel(tasks_id): redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) data = {'tasks_id': tasks_id, 'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'} - generate_data = json.dumps(data) - redis_client.set(tasks_id, generate_data) + gen_product_data = json.dumps(data) + redis_client.set(tasks_id, gen_product_data) return data if __name__ == '__main__': rd = GenerateImageModel( tasks_id="123-89", - prompt='skeleton sitting by the side of a river looking soulful, concert poster, 4k, artistic', - image_url="", + prompt="best quality, masterpiece. detailed, high-res, simple background, studio photography, extremely detailed, updo, detailed face, face, close-up, HDR, UHD, 8K realistic, Highly detailed, simple background, Studio lighting", + image_url="aida-results/result_067f2f7e-21ba-11ef-8cf5-0242ac170002.png", ) - server = GenerateImage(rd) + server = GenerateProductImage(rd) print(server.get_result()) diff --git a/app/service/generate_image/service_generate_single_logo.py b/app/service/generate_image/service_generate_single_logo.py index ea28099..f3d1719 100644 --- a/app/service/generate_image/service_generate_single_logo.py +++ b/app/service/generate_image/service_generate_single_logo.py @@ -21,7 +21,7 @@ from tritonclient.utils import np_to_triton_dtype from app.core.config import * import tritonclient.grpc as grpcclient from app.schemas.generate_image import GenerateSingleLogoImageModel -from app.service.generate_image.utils.upload_sd_image import upload_png_sd, upload_single_logo +from app.service.generate_image.utils.upload_sd_image import upload_png_sd, upload_SDXL_image logger = logging.getLogger() @@ -67,7 +67,7 @@ class GenerateSingleLogoImage: else: image = result.as_numpy("generated_image") image_result = Image.fromarray(np.squeeze(image.astype(np.uint8))) - image_url = upload_single_logo(image_result, user_id=self.user_id, 
category=f"{self.category}", object_name=f"{self.tasks_id}.png") + image_url = upload_SDXL_image(image_result, user_id=self.user_id, category=f"{self.category}", object_name=f"{self.tasks_id}.png") self.gen_single_logo_data['status'] = "SUCCESS" self.gen_single_logo_data['message'] = "success" self.gen_single_logo_data['image_url'] = str(image_url) @@ -131,7 +131,7 @@ if __name__ == '__main__': rd = GenerateSingleLogoImageModel( tasks_id="123-89", prompt='an apple', - seed="1", + seed="2", ) server = GenerateSingleLogoImage(rd) print(server.get_result()) diff --git a/app/service/generate_image/utils/upload_sd_image.py b/app/service/generate_image/utils/upload_sd_image.py index 7cb7f3e..2773aa2 100644 --- a/app/service/generate_image/utils/upload_sd_image.py +++ b/app/service/generate_image/utils/upload_sd_image.py @@ -34,7 +34,7 @@ s3 = boto3.client('s3', aws_access_key_id=S3_ACCESS_KEY, aws_secret_access_key=S # except Exception as e: # print(f'上传到 S3 失败: {e}') -def upload_single_logo(image, user_id, category, object_name): +def upload_SDXL_image(image, user_id, category, object_name): try: image_data = io.BytesIO() image.save(image_data, format='PNG')