From 9d0689d98ea204c6a83bbd7f8f51d94aaa0bf598 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Fri, 5 Jul 2024 15:45:48 +0800 Subject: [PATCH] =?UTF-8?q?feat=20fix=20=20relight=20=E6=96=B0=E5=A2=9Esin?= =?UTF-8?q?gle=20item=E6=A8=A1=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/api_generate_image.py | 5 ++- app/core/config.py | 3 +- app/schemas/generate_image.py | 1 + .../service_generate_relight_image.py | 33 ++++++++++++++----- 4 files changed, 32 insertions(+), 10 deletions(-) diff --git a/app/api/api_generate_image.py b/app/api/api_generate_image.py index 95d8c50..3dee667 100644 --- a/app/api/api_generate_image.py +++ b/app/api/api_generate_image.py @@ -155,13 +155,16 @@ def generate_relight_image(request_item: GenerateRelightImageModel, background_t - **prompt**: 想要生成图片的描述词 - **image_url**: 被生成图片的S3或minio url地址 - **direction**: 光源方向 Right Light Left Light Top Light Bottom Light + - **product_type**: 输入single item 还是 overall item + 示例参数: { "tasks_id": "123-89", "prompt": "beautiful woman, detailed face, sunshine, outdoor, warm atmosphere", "image_url": "aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png", - "direction": "Right Light" + "direction": "Right Light", + "product_type": "overall" } """ try: diff --git a/app/core/config.py b/app/core/config.py index 4caaf13..a01a2c0 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -127,7 +127,8 @@ GPI_MODEL_URL = '10.1.1.240:10041' # Generate Single Logo service config GRI_RABBITMQ_QUEUES = os.getenv("GEN_RELIGHT_IMAGE_RABBITMQ_QUEUES", f"Relight{RABBITMQ_ENV}") -GRI_MODEL_NAME = 'diffusion_relight_ensemble' +GRI_MODEL_NAME_OVERALL = 'diffusion_relight_ensemble' +GRI_MODEL_NAME_SINGLE = 'stable_diffusion_1_5_relight' GRI_MODEL_URL = '10.1.1.240:10051' # SEG service config diff --git a/app/schemas/generate_image.py b/app/schemas/generate_image.py index 7e7beb5..3dd7cf8 100644 --- a/app/schemas/generate_image.py +++ b/app/schemas/generate_image.py @@ -29,3 +29,4 @@ class GenerateRelightImageModel(BaseModel): prompt: str image_url: str direction: str + product_type: str diff --git a/app/service/generate_image/service_generate_relight_image.py b/app/service/generate_image/service_generate_relight_image.py index 6f51435..e0729ba 100644 --- a/app/service/generate_image/service_generate_relight_image.py +++ b/app/service/generate_image/service_generate_relight_image.py @@ -38,6 +38,7 @@ class GenerateRelightImage: self.batch_size = 1 self.prompt = request_data.prompt self.seed = "1" + self.product_type = request_data.product_type self.negative_prompt = 'lowres, bad anatomy, bad hands, cropped, worst quality' self.direction = request_data.direction self.image_url = request_data.image_url @@ -55,7 +56,11 @@ class GenerateRelightImage: self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data)) else: # pil图像转成numpy数组 - image = result.as_numpy("generated_inpaint_image") + if self.product_type == 'single': + image = result.as_numpy("generated_relight_image") + else: + image = result.as_numpy("generated_inpaint_image") + image_result = Image.fromarray(np.squeeze(image.astype(np.uint8))) image_url = upload_SDXL_image(image_result, user_id=self.user_id, category=f"{self.category}", file_name=f"{self.tasks_id}.png") @@ -78,11 +83,18 @@ class GenerateRelightImage: nagetive_prompts = [self.negative_prompt] * self.batch_size directions = [self.direction] * self.batch_size - text_obj = np.array(prompts, dtype="object").reshape((1)) - image_obj = np.array(images, dtype=np.uint8).reshape((768, 512, 3)) - na_text_obj = np.array(nagetive_prompts, dtype="object").reshape((1)) - seed_obj = np.array(seeds, dtype="object").reshape((1)) - direction_obj = np.array(directions, dtype="object").reshape((1)) + if self.product_type == 'single': + text_obj = np.array(prompts, dtype="object").reshape((-1, 1)) + image_obj = np.array(images, dtype=np.uint8).reshape((-1, 768, 512, 3)) + na_text_obj = np.array(nagetive_prompts, dtype="object").reshape((-1, 1)) + seed_obj = np.array(seeds, dtype="object").reshape((-1, 1)) + direction_obj = np.array(directions, dtype="object").reshape((-1, 1)) + else: + text_obj = np.array(prompts, dtype="object").reshape((1)) + image_obj = np.array(images, dtype=np.uint8).reshape((768, 512, 3)) + na_text_obj = np.array(nagetive_prompts, dtype="object").reshape((1)) + seed_obj = np.array(seeds, dtype="object").reshape((1)) + direction_obj = np.array(directions, dtype="object").reshape((1)) input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype)) input_image = grpcclient.InferInput("input_image", image_obj.shape, "UINT8") @@ -97,8 +109,11 @@ class GenerateRelightImage: input_direction.set_data_from_numpy(direction_obj) inputs = [input_text, input_natext, input_image, input_seed, input_direction] + if self.product_type == 'single': + ctx = self.grpc_client.async_infer(model_name=GRI_MODEL_NAME_SINGLE, inputs=inputs, callback=self.callback) + else: + ctx = self.grpc_client.async_infer(model_name=GRI_MODEL_NAME_OVERALL, inputs=inputs, callback=self.callback) - ctx = self.grpc_client.async_infer(model_name=GRI_MODEL_NAME, inputs=inputs, callback=self.callback) time_out = 600 while time_out > 0: gen_product_data, _ = self.read_tasks_status() @@ -136,7 +151,9 @@ if __name__ == '__main__': tasks_id="123-89", # prompt="beautiful woman, detailed face, sunshine, outdoor, warm atmosphere", prompt="Colorful black", - image_url='aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png' + image_url='aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png', + direction="Right Light", + product_type="single" ) server = GenerateRelightImage(rd) print(server.get_result())