feat(新功能): batch generate product 入参回参修改

fix(修复bug):
docs(文档变更):
refactor(重构):
test(增加测试):
This commit is contained in:
zchengrong
2025-06-04 15:55:55 +08:00
parent f234ae29ff
commit 90f9879edb
2 changed files with 105 additions and 47 deletions

View File

@@ -1,3 +1,5 @@
from typing import List
from pydantic import BaseModel from pydantic import BaseModel
@@ -43,13 +45,18 @@ class GenerateRelightImageModel(BaseModel):
""" """
class BatchGenerateProductImageModel(BaseModel): class ProductItemModel(BaseModel):
tasks_id: str tasks_id: str
image_strength: float
prompt: str prompt: str
image_url: str image_url: str
image_strength: float
product_type: str product_type: str
batch_size: int
class BatchGenerateProductImageModel(BaseModel):
batch_tasks_id: str
user_id: str
batch_data_list: List[ProductItemModel]
class BatchGenerateRelightImageModel(BaseModel): class BatchGenerateRelightImageModel(BaseModel):

View File

@@ -19,7 +19,7 @@ from celery import Celery
from tritonclient.utils import np_to_triton_dtype from tritonclient.utils import np_to_triton_dtype
from app.core.config import * from app.core.config import *
from app.schemas.generate_image import BatchGenerateProductImageModel from app.schemas.generate_image import BatchGenerateProductImageModel, ProductItemModel
from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image
from app.service.utils.oss_client import oss_get_image from app.service.utils.oss_client import oss_get_image
@@ -35,38 +35,38 @@ category = "product_image"
@celery_app.task @celery_app.task
def batch_generate_product(batch_request_data): def batch_generate_product(batch_request_data):
batch_size = len(batch_request_data['batch_data_list'])
logger.info(f"batch_generate_product batch_request_data:{json.dumps(batch_request_data, indent=4)}") logger.info(f"batch_generate_product batch_request_data:{json.dumps(batch_request_data, indent=4)}")
tasks_id = batch_request_data['tasks_id'] batch_tasks_id = batch_request_data['batch_tasks_id']
user_id = tasks_id.rsplit('-', 1)[1] user_id = batch_request_data['user_id']
batch_size = batch_request_data['batch_size'] result_data_list = []
image = pre_processing_image(batch_request_data['image_url'])
image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
images = [image.astype(np.uint8)] * 1
prompts = [batch_request_data['prompt']] * 1 for i, data in enumerate(batch_request_data['batch_data_list']):
tasks_id = data['tasks_id']
image = pre_processing_image(data['image_url'])
image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
images = [image.astype(np.uint8)] * 1
prompts = [data['prompt']] * 1
if data['product_type'] == "single":
text_obj = np.array(prompts, dtype="object").reshape((-1, 1))
image_obj = np.array(images, dtype=np.uint8).reshape((-1, 768, 512, 3))
image_strength_obj = np.array(data['image_strength'], dtype=np.float32).reshape((-1, 1))
else:
text_obj = np.array(prompts, dtype="object").reshape((1))
image_obj = np.array(images, dtype=np.uint8).reshape((768, 512, 3))
image_strength_obj = np.array(data['image_strength'], dtype=np.float32).reshape((1))
input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype))
input_image = grpcclient.InferInput("input_image", image_obj.shape, "UINT8")
input_image_strength = grpcclient.InferInput("image_strength", image_strength_obj.shape, np_to_triton_dtype(image_strength_obj.dtype))
if batch_request_data['product_type'] == "single": input_text.set_data_from_numpy(text_obj)
text_obj = np.array(prompts, dtype="object").reshape((-1, 1)) input_image.set_data_from_numpy(image_obj)
image_obj = np.array(images, dtype=np.uint8).reshape((-1, 768, 512, 3)) input_image_strength.set_data_from_numpy(image_strength_obj)
image_strength_obj = np.array(batch_request_data['image_strength'], dtype=np.float32).reshape((-1, 1))
else:
text_obj = np.array(prompts, dtype="object").reshape((1))
image_obj = np.array(images, dtype=np.uint8).reshape((768, 512, 3))
image_strength_obj = np.array(batch_request_data['image_strength'], dtype=np.float32).reshape((1))
input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype))
input_image = grpcclient.InferInput("input_image", image_obj.shape, "UINT8")
input_image_strength = grpcclient.InferInput("image_strength", image_strength_obj.shape, np_to_triton_dtype(image_strength_obj.dtype))
input_text.set_data_from_numpy(text_obj) inputs = [input_text, input_image, input_image_strength]
input_image.set_data_from_numpy(image_obj)
input_image_strength.set_data_from_numpy(image_strength_obj)
inputs = [input_text, input_image, input_image_strength]
image_url_list = []
for i in range(batch_size):
try: try:
if batch_request_data['product_type'] == "single": if data['product_type'] == "single":
result = grpc_client.infer(model_name=GPI_MODEL_NAME_SINGLE, inputs=inputs, priority=100) result = grpc_client.infer(model_name=GPI_MODEL_NAME_SINGLE, inputs=inputs, priority=100)
image = result.as_numpy("generated_cnet_image") image = result.as_numpy("generated_cnet_image")
else: else:
@@ -77,7 +77,7 @@ def batch_generate_product(batch_request_data):
if 'mask_list' in str(e): if 'mask_list' in str(e):
e_text_obj = np.array(prompts, dtype="object").reshape((-1, 1)) e_text_obj = np.array(prompts, dtype="object").reshape((-1, 1))
e_image_obj = np.array(images, dtype=np.uint8).reshape((-1, 768, 512, 3)) e_image_obj = np.array(images, dtype=np.uint8).reshape((-1, 768, 512, 3))
e_image_strength_obj = np.array(batch_request_data['image_strength'], dtype=np.float32).reshape((-1, 1)) e_image_strength_obj = np.array(data['image_strength'], dtype=np.float32).reshape((-1, 1))
e_input_text = grpcclient.InferInput("prompt", e_text_obj.shape, np_to_triton_dtype(e_text_obj.dtype)) e_input_text = grpcclient.InferInput("prompt", e_text_obj.shape, np_to_triton_dtype(e_text_obj.dtype))
e_input_image = grpcclient.InferInput("input_image", e_image_obj.shape, "UINT8") e_input_image = grpcclient.InferInput("input_image", e_image_obj.shape, "UINT8")
@@ -96,18 +96,29 @@ def batch_generate_product(batch_request_data):
if isinstance(image_result, Image.Image): if isinstance(image_result, Image.Image):
image_url = upload_SDXL_image(image_result, user_id=user_id, category=f"{category}", file_name=f"{tasks_id}-batch-{i}.png") image_url = upload_SDXL_image(image_result, user_id=user_id, category=f"{category}", file_name=f"{tasks_id}-batch-{i}.png")
image_url_list.append(image_url) data['product_img'] = image_url
result_data_list.append(data)
else: else:
image_url = image_result image_url = image_result
if DEBUG is False: data['product_img'] = image_url
if i + 1 < batch_size: result_data_list.append(data)
publish_status(tasks_id, f"{i + 1}/{batch_size}", image_url)
logger.info(f" [x]Queue : {BATCH_GPI_RABBITMQ_QUEUES} | tasks_id{tasks_id} | progress{i + 1}/{batch_size} | image_url{image_url}") # 发送每条结果
# print(f" [x]Queue : {BATCH_GPI_RABBITMQ_QUEUES} | tasks_id{tasks_id} | progress{i + 1}/{batch_size} | image_url{image_url}") if DEBUG:
else: logger.info(f" [x]Queue : {BATCH_GPI_RABBITMQ_QUEUES} | tasks_id{tasks_id} | progress{i + 1}/{batch_size} | result_data{data}")
publish_status(tasks_id, f"OK", image_url_list) print(f" [x]Queue : {BATCH_GPI_RABBITMQ_QUEUES} | tasks_id{tasks_id} | progress{i + 1}/{batch_size} | result_data{data}")
logger.info(f" [x]Queue : {BATCH_GPI_RABBITMQ_QUEUES} | tasks_id{tasks_id} | progressOK | image_url{image_url}") else:
# print(f" [x]Queue : {BATCH_GPI_RABBITMQ_QUEUES} | tasks_id{tasks_id} | progressOK | image_url{image_url}") publish_status(tasks_id, f"{i + 1}/{batch_size}", image_url)
logger.info(f" [x]Queue : {BATCH_GPI_RABBITMQ_QUEUES} | tasks_id{tasks_id} | progress{i + 1}/{batch_size} | image_url{image_url}")
# 任务完成,发送所有数据结果
if DEBUG:
print(result_data_list)
logger.info(f" [x]Queue : {BATCH_GPI_RABBITMQ_QUEUES} | batch_tasks_id{batch_tasks_id} | progressOK | result_data_list{result_data_list}")
print(f" [x]Queue : {BATCH_GPI_RABBITMQ_QUEUES} | batch_tasks_id{batch_tasks_id} | progressOK | result_data_list{result_data_list}")
else:
publish_status(batch_tasks_id, f"OK", result_data_list)
logger.info(f" [x]Queue : {BATCH_GPI_RABBITMQ_QUEUES} | batch_tasks_id{batch_tasks_id} | progressOK | result_data_list{result_data_list}")
def pre_processing_image(image_url): def pre_processing_image(image_url):
@@ -180,12 +191,52 @@ def publish_status(task_id, progress, result):
if __name__ == '__main__': if __name__ == '__main__':
# rd = BatchGenerateProductImageModel(
# tasks_id="123-15-51-89",
# image_strength=0.7,
# prompt=" The best quality, masterpiece, real image.Outwear,high quality clothing details,8K realistic,HDR",
# image_url="aida-results/result_40b1a2fe-e220-11ef-9bfa-0242ac150003.png",
# product_type="overall",
# batch_size=20
# )
# batch_generate_product(rd.dict())
# rd = {
# "user_id": "89",
# "batch_data_list": [
# {
# "tasks_id": "A-123-15-51-89",
# "image_strength": 0.7,
# "prompt": " The best quality, ma123sterpiece, real image.Outwear,high quality clothing details,8K realistic,HDR",
# "image_url": "aida-results/result_40b1a2fe-e220-11ef-9bfa-0242ac150003.png",
# "product_type": "overall",
# },
# {
# "tasks_id": "B-123-15-51-89",
# "image_strength": 0.7,
# "prompt": " The best quality, masterpiece, real image.Outwear123,high quality clothing details,8K realistic,HDR",
# "image_url": "aida-results/result_40b1a2fe-e220-11ef-9bfa-0242ac150003.png",
# "product_type": "overall",
# }
# ]
# }
rd = BatchGenerateProductImageModel( rd = BatchGenerateProductImageModel(
tasks_id="123-15-51-89", batch_tasks_id="abcd",
image_strength=0.7, user_id="89",
prompt=" The best quality, masterpiece, real image.Outwear,high quality clothing details,8K realistic,HDR", batch_data_list=[
image_url="aida-results/result_40b1a2fe-e220-11ef-9bfa-0242ac150003.png", ProductItemModel(
product_type="overall", tasks_id="123-5464",
batch_size=20 image_strength=0.7,
product_type="overall",
image_url="aida-results/result_40b1a2fe-e220-11ef-9bfa-0242ac150003.png",
prompt="123"
),
ProductItemModel(
tasks_id="123-5464123",
image_strength=0.7,
product_type="overall",
image_url="aida-results/result_40b1a2fe-e220-11ef-9bfa-0242ac150003.png",
prompt="123"
)
]
) )
batch_generate_product(rd.dict()) batch_generate_product(rd.dict())