diff --git a/app/schemas/generate_image.py b/app/schemas/generate_image.py index 99d1836..a989f2e 100644 --- a/app/schemas/generate_image.py +++ b/app/schemas/generate_image.py @@ -1,3 +1,5 @@ +from typing import List + from pydantic import BaseModel @@ -43,13 +45,18 @@ class GenerateRelightImageModel(BaseModel): """ -class BatchGenerateProductImageModel(BaseModel): +class ProductItemModel(BaseModel): tasks_id: str + image_strength: float prompt: str image_url: str - image_strength: float product_type: str - batch_size: int + + +class BatchGenerateProductImageModel(BaseModel): + batch_tasks_id: str + user_id: str + batch_data_list: List[ProductItemModel] class BatchGenerateRelightImageModel(BaseModel): diff --git a/app/service/generate_batch_image/service_batch_generate_product_image.py b/app/service/generate_batch_image/service_batch_generate_product_image.py index f09fbd5..46a5695 100644 --- a/app/service/generate_batch_image/service_batch_generate_product_image.py +++ b/app/service/generate_batch_image/service_batch_generate_product_image.py @@ -19,7 +19,7 @@ from celery import Celery from tritonclient.utils import np_to_triton_dtype from app.core.config import * -from app.schemas.generate_image import BatchGenerateProductImageModel +from app.schemas.generate_image import BatchGenerateProductImageModel, ProductItemModel from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image from app.service.utils.oss_client import oss_get_image @@ -35,38 +35,38 @@ category = "product_image" @celery_app.task def batch_generate_product(batch_request_data): + batch_size = len(batch_request_data['batch_data_list']) logger.info(f"batch_generate_product batch_request_data:{json.dumps(batch_request_data, indent=4)}") - tasks_id = batch_request_data['tasks_id'] - user_id = tasks_id.rsplit('-', 1)[1] - batch_size = batch_request_data['batch_size'] - image = pre_processing_image(batch_request_data['image_url']) - image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB) - images = [image.astype(np.uint8)] * 1 + batch_tasks_id = batch_request_data['batch_tasks_id'] + user_id = batch_request_data['user_id'] + result_data_list = [] - prompts = [batch_request_data['prompt']] * 1 + for i, data in enumerate(batch_request_data['batch_data_list']): + tasks_id = data['tasks_id'] + image = pre_processing_image(data['image_url']) + image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB) + images = [image.astype(np.uint8)] * 1 + prompts = [data['prompt']] * 1 + if data['product_type'] == "single": + text_obj = np.array(prompts, dtype="object").reshape((-1, 1)) + image_obj = np.array(images, dtype=np.uint8).reshape((-1, 768, 512, 3)) + image_strength_obj = np.array(data['image_strength'], dtype=np.float32).reshape((-1, 1)) + else: + text_obj = np.array(prompts, dtype="object").reshape((1)) + image_obj = np.array(images, dtype=np.uint8).reshape((768, 512, 3)) + image_strength_obj = np.array(data['image_strength'], dtype=np.float32).reshape((1)) + input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype)) + input_image = grpcclient.InferInput("input_image", image_obj.shape, "UINT8") + input_image_strength = grpcclient.InferInput("image_strength", image_strength_obj.shape, np_to_triton_dtype(image_strength_obj.dtype)) - if batch_request_data['product_type'] == "single": - text_obj = np.array(prompts, dtype="object").reshape((-1, 1)) - image_obj = np.array(images, dtype=np.uint8).reshape((-1, 768, 512, 3)) - image_strength_obj = np.array(batch_request_data['image_strength'], dtype=np.float32).reshape((-1, 1)) - else: - text_obj = np.array(prompts, dtype="object").reshape((1)) - image_obj = np.array(images, dtype=np.uint8).reshape((768, 512, 3)) - image_strength_obj = np.array(batch_request_data['image_strength'], dtype=np.float32).reshape((1)) - input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype)) - input_image = grpcclient.InferInput("input_image", image_obj.shape, "UINT8") - input_image_strength = grpcclient.InferInput("image_strength", image_strength_obj.shape, np_to_triton_dtype(image_strength_obj.dtype)) + input_text.set_data_from_numpy(text_obj) + input_image.set_data_from_numpy(image_obj) + input_image_strength.set_data_from_numpy(image_strength_obj) - input_text.set_data_from_numpy(text_obj) - input_image.set_data_from_numpy(image_obj) - input_image_strength.set_data_from_numpy(image_strength_obj) + inputs = [input_text, input_image, input_image_strength] - inputs = [input_text, input_image, input_image_strength] - - image_url_list = [] - for i in range(batch_size): try: - if batch_request_data['product_type'] == "single": + if data['product_type'] == "single": result = grpc_client.infer(model_name=GPI_MODEL_NAME_SINGLE, inputs=inputs, priority=100) image = result.as_numpy("generated_cnet_image") else: @@ -77,7 +77,7 @@ def batch_generate_product(batch_request_data): if 'mask_list' in str(e): e_text_obj = np.array(prompts, dtype="object").reshape((-1, 1)) e_image_obj = np.array(images, dtype=np.uint8).reshape((-1, 768, 512, 3)) - e_image_strength_obj = np.array(batch_request_data['image_strength'], dtype=np.float32).reshape((-1, 1)) + e_image_strength_obj = np.array(data['image_strength'], dtype=np.float32).reshape((-1, 1)) e_input_text = grpcclient.InferInput("prompt", e_text_obj.shape, np_to_triton_dtype(e_text_obj.dtype)) e_input_image = grpcclient.InferInput("input_image", e_image_obj.shape, "UINT8") @@ -96,18 +96,29 @@ def batch_generate_product(batch_request_data): if isinstance(image_result, Image.Image): image_url = upload_SDXL_image(image_result, user_id=user_id, category=f"{category}", file_name=f"{tasks_id}-batch-{i}.png") - image_url_list.append(image_url) + data['product_img'] = image_url + result_data_list.append(data) else: image_url = image_result - if DEBUG is False: - if i + 1 < batch_size: - publish_status(tasks_id, f"{i + 1}/{batch_size}", image_url) - logger.info(f" [x]Queue : {BATCH_GPI_RABBITMQ_QUEUES} | tasks_id:{tasks_id} | progress:{i + 1}/{batch_size} | image_url:{image_url}") - # print(f" [x]Queue : {BATCH_GPI_RABBITMQ_QUEUES} | tasks_id:{tasks_id} | progress:{i + 1}/{batch_size} | image_url:{image_url}") - else: - publish_status(tasks_id, f"OK", image_url_list) - logger.info(f" [x]Queue : {BATCH_GPI_RABBITMQ_QUEUES} | tasks_id:{tasks_id} | progress:OK | image_url:{image_url}") - # print(f" [x]Queue : {BATCH_GPI_RABBITMQ_QUEUES} | tasks_id:{tasks_id} | progress:OK | image_url:{image_url}") + data['product_img'] = image_url + result_data_list.append(data) + + # 发送每条结果 + if DEBUG: + logger.info(f" [x]Queue : {BATCH_GPI_RABBITMQ_QUEUES} | tasks_id:{tasks_id} | progress:{i + 1}/{batch_size} | result_data:{data}") + print(f" [x]Queue : {BATCH_GPI_RABBITMQ_QUEUES} | tasks_id:{tasks_id} | progress:{i + 1}/{batch_size} | result_data:{data}") + else: + publish_status(tasks_id, f"{i + 1}/{batch_size}", image_url) + logger.info(f" [x]Queue : {BATCH_GPI_RABBITMQ_QUEUES} | tasks_id:{tasks_id} | progress:{i + 1}/{batch_size} | image_url:{image_url}") + + # 任务完成,发送所有数据结果 + if DEBUG: + print(result_data_list) + logger.info(f" [x]Queue : {BATCH_GPI_RABBITMQ_QUEUES} | batch_tasks_id:{batch_tasks_id} | progress:OK | result_data_list:{result_data_list}") + print(f" [x]Queue : {BATCH_GPI_RABBITMQ_QUEUES} | batch_tasks_id:{batch_tasks_id} | progress:OK | result_data_list:{result_data_list}") + else: + publish_status(batch_tasks_id, f"OK", result_data_list) + logger.info(f" [x]Queue : {BATCH_GPI_RABBITMQ_QUEUES} | batch_tasks_id:{batch_tasks_id} | progress:OK | result_data_list:{result_data_list}") def pre_processing_image(image_url): @@ -180,12 +191,52 @@ def publish_status(task_id, progress, result): if __name__ == '__main__': + # rd = BatchGenerateProductImageModel( + # tasks_id="123-15-51-89", + # image_strength=0.7, + # prompt=" The best quality, masterpiece, real image.Outwear,high quality clothing details,8K realistic,HDR", + # image_url="aida-results/result_40b1a2fe-e220-11ef-9bfa-0242ac150003.png", + # product_type="overall", + # batch_size=20 + # ) + # batch_generate_product(rd.dict()) + # rd = { + # "user_id": "89", + # "batch_data_list": [ + # { + # "tasks_id": "A-123-15-51-89", + # "image_strength": 0.7, + # "prompt": " The best quality, ma123sterpiece, real image.Outwear,high quality clothing details,8K realistic,HDR", + # "image_url": "aida-results/result_40b1a2fe-e220-11ef-9bfa-0242ac150003.png", + # "product_type": "overall", + # }, + # { + # "tasks_id": "B-123-15-51-89", + # "image_strength": 0.7, + # "prompt": " The best quality, masterpiece, real image.Outwear123,high quality clothing details,8K realistic,HDR", + # "image_url": "aida-results/result_40b1a2fe-e220-11ef-9bfa-0242ac150003.png", + # "product_type": "overall", + # } + # ] + # } rd = BatchGenerateProductImageModel( - tasks_id="123-15-51-89", - image_strength=0.7, - prompt=" The best quality, masterpiece, real image.Outwear,high quality clothing details,8K realistic,HDR", - image_url="aida-results/result_40b1a2fe-e220-11ef-9bfa-0242ac150003.png", - product_type="overall", - batch_size=20 + batch_tasks_id="abcd", + user_id="89", + batch_data_list=[ + ProductItemModel( + tasks_id="123-5464", + image_strength=0.7, + product_type="overall", + image_url="aida-results/result_40b1a2fe-e220-11ef-9bfa-0242ac150003.png", + prompt="123" + ), + ProductItemModel( + tasks_id="123-5464123", + image_strength=0.7, + product_type="overall", + image_url="aida-results/result_40b1a2fe-e220-11ef-9bfa-0242ac150003.png", + prompt="123" + ) + ] ) batch_generate_product(rd.dict())