From d04c3857fcaa8e672c7b290f288081fe9d895e64 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Wed, 19 Jun 2024 16:44:04 +0800 Subject: [PATCH] =?UTF-8?q?feat=20=20=E4=BA=A7=E5=93=81=E5=9B=BE=E6=89=93?= =?UTF-8?q?=E5=85=89=E6=A8=A1=E5=9E=8B=E9=83=A8=E7=BD=B2=20fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/core/config.py | 2 +- .../service_generate_product_image.py | 9 +- .../service_generate_relight_image.py | 111 ++++++------------ 3 files changed, 41 insertions(+), 81 deletions(-) diff --git a/app/core/config.py b/app/core/config.py index 0af065b..3932bf5 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -124,7 +124,7 @@ GPI_MODEL_URL = '10.1.1.240:10061' # Generate Single Logo service config GRI_RABBITMQ_QUEUES = os.getenv("GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES", f"Relight{RABBITMQ_ENV}") -GRI_MODEL_NAME = 'stable_diffusion_1_5' +GRI_MODEL_NAME = 'diffusion_relight_ensemble' GRI_MODEL_URL = '10.1.1.150:8001' # SEG service config diff --git a/app/service/generate_image/service_generate_product_image.py b/app/service/generate_image/service_generate_product_image.py index ce449ea..2416d2c 100644 --- a/app/service/generate_image/service_generate_product_image.py +++ b/app/service/generate_image/service_generate_product_image.py @@ -20,7 +20,7 @@ from minio import Minio from tritonclient.utils import np_to_triton_dtype from app.core.config import * -from app.schemas.generate_image import GenerateImageModel +from app.schemas.generate_image import GenerateImageModel, GenerateProductImageModel from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image logger = logging.getLogger() @@ -166,10 +166,11 @@ def infer_cancel(tasks_id): if __name__ == '__main__': - rd = GenerateImageModel( + rd = GenerateProductImageModel( tasks_id="123-89", - prompt="best quality, masterpiece. detailed, high-res, simple background, studio photography, extremely detailed, updo, detailed face, face, close-up, HDR, UHD, 8K realistic, Highly detailed, simple background, Studio lighting", - image_url="aida-results/result_067f2f7e-21ba-11ef-8cf5-0242ac170002.png", + prompt="", + # prompt="best quality, masterpiece. detailed, high-res, simple background, studio photography, extremely detailed, updo, detailed face, face, close-up, HDR, UHD, 8K realistic, Highly detailed, simple background, Studio lighting", + image_url="aida-results/result_00097282-ebb2-11ee-a822-b48351119060.png", ) server = GenerateProductImage(rd) print(server.get_result()) diff --git a/app/service/generate_image/service_generate_relight_image.py b/app/service/generate_image/service_generate_relight_image.py index 0eacec9..7c7f4b1 100644 --- a/app/service/generate_image/service_generate_relight_image.py +++ b/app/service/generate_image/service_generate_relight_image.py @@ -38,9 +38,10 @@ class GenerateRelightImage: self.batch_size = 1 self.prompt = request_data.prompt self.seed = "12345" + self.negative_prompt = 'lowres, bad anatomy, bad hands, cropped, worst quality' + self.direction = "Right Light" # TODO aida design 结果图背景改为白色 - # self.image, self.image_size = self.get_image(request_data.image_url) - self.image = request_data.image_url + self.image = self.get_image(request_data.image_url) # TODO image 填充并resize成512*768 self.tasks_id = request_data.tasks_id @@ -51,37 +52,8 @@ class GenerateRelightImage: def get_image(self, image_url): response = self.minio_client.get_object(image_url.split('/')[0], image_url[image_url.find('/') + 1:]) - image_bytes = io.BytesIO(response.read()) - - # 转换为PIL图像对象 - image = Image.open(image_bytes) - target_height = 768 - target_width = 512 - - aspect_ratio = image.width / image.height - new_width = int(target_height * aspect_ratio) - - resized_image = image.resize((new_width, target_height)) - left = (target_width - resized_image.width) // 2 - top = (target_height - resized_image.height) // 2 - right = target_width - resized_image.width - left - bottom = target_height - resized_image.height - top - image = ImageOps.expand(resized_image, (left, top, right, bottom), fill="white") - image_size = image.size - if image.mode in ('RGBA', 'LA') or (image.mode == 'P' and 'transparency' in image.info): - # 创建白色背景 - background = Image.new("RGB", image.size, (255, 255, 255)) - # 将图片粘贴到白色背景上 - background.paste(image, mask=image.split()[3]) - image = np.array(background) - image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) - - # image_file = BytesIO(response.data) - # image_array = np.asarray(bytearray(image_file.read()), dtype=np.uint8) - # image_cv2 = cv2.imdecode(image_array, cv2.IMREAD_COLOR) - # image = cv2.cvtColor(image_cv2, cv2.COLOR_BGR2RGB) - # image = cv2.resize(image_rbg, (1024, 1024)) - return image, image_size + image = cv2.imdecode(np.frombuffer(response.data, np.uint8), 1) + return image def callback(self, result, error): if error: @@ -92,7 +64,7 @@ class GenerateRelightImage: else: # pil图像转成numpy数组 image = result.as_numpy("generated_inpaint_image") - image_result = Image.fromarray(np.squeeze(image.astype(np.uint8))).resize(self.image_size) + image_result = Image.fromarray(np.squeeze(image.astype(np.uint8))) image_url = upload_SDXL_image(image_result, user_id=self.user_id, category=f"{self.category}", object_name=f"{self.tasks_id}.png") # logger.info(f"upload image SUCCESS : {image_url}") @@ -114,47 +86,33 @@ class GenerateRelightImage: def get_result(self): try: - direction = "Right Light" - negative_prompt = 'lowres, bad anatomy, bad hands, cropped, worst quality' - self.prompt = 'beautiful woman, detailed face, sunshine, outdoor, warm atmosphere' prompts = [self.prompt] * self.batch_size - text_obj = np.array(prompts, dtype="object").reshape((-1, 1)) - input_text = grpcclient.InferInput( - "prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype) - ) + image = cv2.cvtColor(self.image, cv2.COLOR_BGR2RGB) + image = cv2.resize(image, (512, 768)) + images = [image.astype(np.uint8)] * self.batch_size + seeds = [self.seed] * self.batch_size + nagetive_prompts = [self.negative_prompt] * self.batch_size + directions = [self.direction] * self.batch_size + + text_obj = np.array(prompts, dtype="object").reshape((1)) + image_obj = np.array(images, dtype=np.uint8).reshape((768, 512, 3)) + na_text_obj = np.array(nagetive_prompts, dtype="object").reshape((1)) + seed_obj = np.array(seeds, dtype="object").reshape((1)) + direction_obj = np.array(directions, dtype="object").reshape((1)) + + input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype)) + input_image = grpcclient.InferInput("input_image", image_obj.shape, "UINT8") + input_natext = grpcclient.InferInput("negative_prompt", na_text_obj.shape, np_to_triton_dtype(na_text_obj.dtype)) + input_seed = grpcclient.InferInput("seed", seed_obj.shape, np_to_triton_dtype(seed_obj.dtype)) + input_direction = grpcclient.InferInput("direction", direction_obj.shape, np_to_triton_dtype(direction_obj.dtype)) + input_text.set_data_from_numpy(text_obj) + input_image.set_data_from_numpy(image_obj) + input_natext.set_data_from_numpy(na_text_obj) + input_seed.set_data_from_numpy(seed_obj) + input_direction.set_data_from_numpy(direction_obj) - negative_prompts = [negative_prompt] * self.batch_size - text_obj_neg = np.array(negative_prompts, dtype="object").reshape((-1, 1)) - input_text_neg = grpcclient.InferInput( - "negative_prompt", text_obj_neg.shape, np_to_triton_dtype(text_obj_neg.dtype) - ) - input_text_neg.set_data_from_numpy(text_obj_neg) - - seed = np.array(self.seed, dtype="object").reshape((-1, 1)) - input_seed = grpcclient.InferInput( - "seed", seed.shape, np_to_triton_dtype(seed.dtype) - ) - input_seed.set_data_from_numpy(seed) - - input_images = [self.image] * self.batch_size - text_obj_images = np.array(input_images, dtype="object").reshape((-1, 1)) - input_input_images = grpcclient.InferInput( - "input_image", text_obj_images.shape, np_to_triton_dtype(text_obj_images.dtype) - ) - input_input_images.set_data_from_numpy(text_obj_images) - - directions = [direction] * self.batch_size - text_obj_directions = np.array(directions, dtype="object").reshape((-1, 1)) - input_directions = grpcclient.InferInput( - "direction", text_obj_directions.shape, np_to_triton_dtype(text_obj_directions.dtype) - ) - input_directions.set_data_from_numpy(text_obj_directions) - - output_img = grpcclient.InferRequestedOutput("generated_image") - request_start = time.time() - - inputs = [input_text, input_text_neg, input_input_images, input_seed, input_directions] + inputs = [input_text, input_natext, input_image, input_seed, input_direction] ctx = self.infer(inputs) time_out = 600 @@ -179,9 +137,9 @@ class GenerateRelightImage: finally: dict_gen_product_data, str_gen_product_data = self.read_tasks_status() if DEBUG is False: - self.channel.basic_publish(exchange='', routing_key=GPI_RABBITMQ_QUEUES, body=str_gen_product_data) + self.channel.basic_publish(exchange='', routing_key=GRI_RABBITMQ_QUEUES, body=str_gen_product_data) # self.channel.basic_publish(exchange='', routing_key=GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES, body=str_gen_product_data) - logger.info(f" [x] Sent to: {GPI_RABBITMQ_QUEUES} data:@@@@ {json.dumps(dict_gen_product_data, indent=4)}") + logger.info(f" [x] Sent to: {GRI_RABBITMQ_QUEUES} data:@@@@ {json.dumps(dict_gen_product_data, indent=4)}") def infer_cancel(tasks_id): @@ -195,8 +153,9 @@ def infer_cancel(tasks_id): if __name__ == '__main__': rd = GenerateRelightImageModel( tasks_id="123-89", - prompt="beautiful woman, detailed face, sunshine, outdoor, warm atmosphere", - image_url="/workspace/i3.png", + # prompt="beautiful woman, detailed face, sunshine, outdoor, warm atmosphere", + prompt="", + image_url='aida-users/89/product_image/123-89.png' ) server = GenerateRelightImage(rd) print(server.get_result())