From 12bb12835126e89e5238e7b6530e5be7e3504a23 Mon Sep 17 00:00:00 2001 From: zchengrong <124802516+zchengrong@users.noreply.github.com> Date: Thu, 5 Jun 2025 15:14:36 +0800 Subject: [PATCH] =?UTF-8?q?feat=EF=BC=88=E6=96=B0=E5=8A=9F=E8=83=BD?= =?UTF-8?q?=EF=BC=89:=20=20batch=20generate=20relight=20=E5=85=A5=E5=8F=82?= =?UTF-8?q?=E5=9B=9E=E5=8F=82=E4=BF=AE=E6=94=B9=20fix=EF=BC=88=E4=BF=AE?= =?UTF-8?q?=E5=A4=8Dbug=EF=BC=89:=20docs=EF=BC=88=E6=96=87=E6=A1=A3?= =?UTF-8?q?=E5=8F=98=E6=9B=B4=EF=BC=89:=20refactor=EF=BC=88=E9=87=8D?= =?UTF-8?q?=E6=9E=84=EF=BC=89:=20test(=E5=A2=9E=E5=8A=A0=E6=B5=8B=E8=AF=95?= =?UTF-8?q?):?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/schemas/generate_image.py | 13 +- .../service_batch_generate_relight_image.py | 144 +++++++++++------- 2 files changed, 96 insertions(+), 61 deletions(-) diff --git a/app/schemas/generate_image.py b/app/schemas/generate_image.py index a989f2e..7d1d864 100644 --- a/app/schemas/generate_image.py +++ b/app/schemas/generate_image.py @@ -45,6 +45,7 @@ class GenerateRelightImageModel(BaseModel): """ +# product任务子项 class ProductItemModel(BaseModel): tasks_id: str image_strength: float @@ -53,16 +54,24 @@ class ProductItemModel(BaseModel): product_type: str +# product批处理 集合 class BatchGenerateProductImageModel(BaseModel): batch_tasks_id: str user_id: str batch_data_list: List[ProductItemModel] -class BatchGenerateRelightImageModel(BaseModel): +# relight任务子项 +class RelightItemModel(BaseModel): tasks_id: str prompt: str image_url: str direction: str product_type: str - batch_size: int + + +# relight批处理集合 +class BatchGenerateRelightImageModel(BaseModel): + batch_tasks_id: str + user_id: str + batch_data_list: List[RelightItemModel] diff --git a/app/service/generate_batch_image/service_batch_generate_relight_image.py b/app/service/generate_batch_image/service_batch_generate_relight_image.py index 83a5701..d75b0a7 100644 --- a/app/service/generate_batch_image/service_batch_generate_relight_image.py +++ b/app/service/generate_batch_image/service_batch_generate_relight_image.py @@ -18,7 +18,7 @@ from celery import Celery from tritonclient.utils import np_to_triton_dtype from app.core.config import * -from app.schemas.generate_image import BatchGenerateRelightImageModel +from app.schemas.generate_image import BatchGenerateRelightImageModel, RelightItemModel from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image from app.service.utils.oss_client import oss_get_image @@ -34,55 +34,58 @@ category = "relight_image" @celery_app.task def batch_generate_relight(batch_request_data): + batch_size = len(batch_request_data['batch_data_list']) logger.info(f"batch_generate_relight batch_request_data: {json.dumps(batch_request_data, indent=4)}") + batch_tasks_id = batch_request_data['batch_tasks_id'] + user_id = batch_request_data['user_id'] + result_data_list = [] negative_prompt = 'lowres, bad anatomy, bad hands, cropped, worst quality' - direction = batch_request_data['direction'] seed = "1" - prompt = batch_request_data['prompt'] - product_type = batch_request_data['product_type'] - image_url = batch_request_data['image_url'] - image = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url.split('/', 1)[1], data_type="cv2") - tasks_id = batch_request_data['tasks_id'] - user_id = tasks_id.rsplit('-', 1)[1] - batch_size = batch_request_data['batch_size'] - prompts = [prompt] * 1 - image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) - image = cv2.resize(image, (512, 768)) - images = [image.astype(np.uint8)] * 1 - seeds = [seed] * 1 - nagetive_prompts = [negative_prompt] * 1 - directions = [direction] * 1 + for i, data in enumerate(batch_request_data['batch_data_list']): + direction = data['direction'] - if product_type == 'single': - text_obj = np.array(prompts, dtype="object").reshape((-1, 1)) - image_obj = np.array(images, dtype=np.uint8).reshape((-1, 768, 512, 3)) - na_text_obj = np.array(nagetive_prompts, dtype="object").reshape((-1, 1)) - seed_obj = np.array(seeds, dtype="object").reshape((-1, 1)) - direction_obj = np.array(directions, dtype="object").reshape((-1, 1)) - else: - text_obj = np.array(prompts, dtype="object").reshape((1)) - image_obj = np.array(images, dtype=np.uint8).reshape((768, 512, 3)) - na_text_obj = np.array(nagetive_prompts, dtype="object").reshape((1)) - seed_obj = np.array(seeds, dtype="object").reshape((1)) - direction_obj = np.array(directions, dtype="object").reshape((1)) - input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype)) - input_image = grpcclient.InferInput("input_image", image_obj.shape, "UINT8") - input_natext = grpcclient.InferInput("negative_prompt", na_text_obj.shape, np_to_triton_dtype(na_text_obj.dtype)) - input_seed = grpcclient.InferInput("seed", seed_obj.shape, np_to_triton_dtype(seed_obj.dtype)) - input_direction = grpcclient.InferInput("direction", direction_obj.shape, np_to_triton_dtype(direction_obj.dtype)) + prompt = data['prompt'] + product_type = data['product_type'] + image_url = data['image_url'] + image = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url.split('/', 1)[1], data_type="cv2") + tasks_id = data['tasks_id'] - input_text.set_data_from_numpy(text_obj) - input_image.set_data_from_numpy(image_obj) - input_natext.set_data_from_numpy(na_text_obj) - input_seed.set_data_from_numpy(seed_obj) - input_direction.set_data_from_numpy(direction_obj) + prompts = [prompt] * 1 + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + image = cv2.resize(image, (512, 768)) + images = [image.astype(np.uint8)] * 1 + seeds = [seed] * 1 + nagetive_prompts = [negative_prompt] * 1 + directions = [direction] * 1 - inputs = [input_text, input_natext, input_image, input_seed, input_direction] - image_url_list = [] - for i in range(batch_size): + if product_type == 'single': + text_obj = np.array(prompts, dtype="object").reshape((-1, 1)) + image_obj = np.array(images, dtype=np.uint8).reshape((-1, 768, 512, 3)) + na_text_obj = np.array(nagetive_prompts, dtype="object").reshape((-1, 1)) + seed_obj = np.array(seeds, dtype="object").reshape((-1, 1)) + direction_obj = np.array(directions, dtype="object").reshape((-1, 1)) + else: + text_obj = np.array(prompts, dtype="object").reshape((1)) + image_obj = np.array(images, dtype=np.uint8).reshape((768, 512, 3)) + na_text_obj = np.array(nagetive_prompts, dtype="object").reshape((1)) + seed_obj = np.array(seeds, dtype="object").reshape((1)) + direction_obj = np.array(directions, dtype="object").reshape((1)) + input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype)) + input_image = grpcclient.InferInput("input_image", image_obj.shape, "UINT8") + input_natext = grpcclient.InferInput("negative_prompt", na_text_obj.shape, np_to_triton_dtype(na_text_obj.dtype)) + input_seed = grpcclient.InferInput("seed", seed_obj.shape, np_to_triton_dtype(seed_obj.dtype)) + input_direction = grpcclient.InferInput("direction", direction_obj.shape, np_to_triton_dtype(direction_obj.dtype)) + + input_text.set_data_from_numpy(text_obj) + input_image.set_data_from_numpy(image_obj) + input_natext.set_data_from_numpy(na_text_obj) + input_seed.set_data_from_numpy(seed_obj) + input_direction.set_data_from_numpy(direction_obj) + + inputs = [input_text, input_natext, input_image, input_seed, input_direction] try: - if batch_request_data['product_type'] == "single": + if data['product_type'] == "single": result = grpc_client.infer(model_name=GRI_MODEL_NAME_SINGLE, inputs=inputs, priority=100) image = result.as_numpy("generated_relight_image") else: @@ -121,18 +124,29 @@ def batch_generate_relight(batch_request_data): logger.error(e) if isinstance(image_result, Image.Image): image_url = upload_SDXL_image(image_result, user_id=user_id, category=f"{category}", file_name=f"{tasks_id}-batch-{i}.png") - image_url_list.append(image_url) + data['relight_img'] = image_url + + result_data_list.append(data) else: image_url = image_result - if DEBUG is False: - if i + 1 < batch_size: - publish_status(tasks_id, f"{i + 1}/{batch_size}", image_url) - logger.info(f" [x]Queue : {BATCH_GRI_RABBITMQ_QUEUES} | tasks_id:{tasks_id} | progress:{i + 1}/{batch_size} | image_url:{image_url}") - # print(f" [x]Queue : {BATCH_GRI_RABBITMQ_QUEUES} | tasks_id:{tasks_id} | progress:{i + 1}/{batch_size} | image_url:{image_url}") - else: - publish_status(tasks_id, f"OK", image_url_list) - logger.info(f" [x]Queue : {BATCH_GRI_RABBITMQ_QUEUES} | tasks_id:{tasks_id} | progress:OK | image_url:{image_url}") - # print(f" [x]Queue : {BATCH_GRI_RABBITMQ_QUEUES} | tasks_id:{tasks_id} | progress:OK | image_url:{image_url}") + data['relight_img'] = image_url + result_data_list.append(data) + + # 发送每条结果 + if DEBUG: + logger.info(f" [x]Queue : {BATCH_GRI_RABBITMQ_QUEUES} | tasks_id:{tasks_id} | progress:{i + 1}/{batch_size} | result_data:{data}") + print(f" [x]Queue : {BATCH_GRI_RABBITMQ_QUEUES} | tasks_id:{tasks_id} | progress:{i + 1}/{batch_size} | result_data:{data}") + else: + publish_status(tasks_id, f"{i + 1}/{batch_size}", image_url) + logger.info(f" [x]Queue : {BATCH_GRI_RABBITMQ_QUEUES} | tasks_id:{tasks_id} | progress:{i + 1}/{batch_size} | image_url:{image_url}") + # 任务完成,发送所有数据结果 + if DEBUG: + print(result_data_list) + logger.info(f" [x]Queue : {BATCH_GRI_RABBITMQ_QUEUES} | batch_tasks_id:{batch_tasks_id} | progress:OK | result_data_list:{result_data_list}") + print(f" [x]Queue : {BATCH_GRI_RABBITMQ_QUEUES} | batch_tasks_id:{batch_tasks_id} | progress:OK | result_data_list:{result_data_list}") + else: + publish_status(batch_tasks_id, f"OK", result_data_list) + logger.info(f" [x]Queue : {BATCH_GRI_RABBITMQ_QUEUES} | batch_tasks_id:{batch_tasks_id} | progress:OK | result_data_list:{result_data_list}") def publish_status(task_id, progress, result): @@ -151,12 +165,24 @@ def publish_status(task_id, progress, result): if __name__ == '__main__': rd = BatchGenerateRelightImageModel( - tasks_id="123-89", - # prompt="beautiful woman, detailed face, sunshine, outdoor, warm atmosphere", - prompt="Colorful black", - image_url='aida-users/89/clothing_seg/283c5c82-1a92-11f0-b72a-0242ac150002.png', - direction="Right Light", - product_type="overall", - batch_size=10 + batch_tasks_id="abcd", + user_id="89", + batch_data_list=[ + RelightItemModel( + tasks_id="123-5464", + product_type="overall", + image_url="aida-users/89/product_image/02894523-19b5-46eb-a9c6-2f512f5fec84-0-89.png", + prompt="Colorful black", + direction="Right Light", + ), + RelightItemModel( + tasks_id="123-5464123", + product_type="overall", + image_url="aida-users/89/product_image/02894523-19b5-46eb-a9c6-2f512f5fec84-0-89.png", + direction="Right Light", + prompt="Colorful black", + ) + ] ) + batch_generate_relight(rd.dict())