feat(新功能): batch generate relight 入参回参修改

fix(修复bug):
docs(文档变更):
refactor(重构):
test(增加测试):
This commit is contained in:
zchengrong
2025-06-05 15:14:36 +08:00
parent 90f9879edb
commit 12bb128351
2 changed files with 96 additions and 61 deletions

View File

@@ -45,6 +45,7 @@ class GenerateRelightImageModel(BaseModel):
""" """
# product任务子项
class ProductItemModel(BaseModel): class ProductItemModel(BaseModel):
tasks_id: str tasks_id: str
image_strength: float image_strength: float
@@ -53,16 +54,24 @@ class ProductItemModel(BaseModel):
product_type: str product_type: str
# product批处理 集合
class BatchGenerateProductImageModel(BaseModel): class BatchGenerateProductImageModel(BaseModel):
batch_tasks_id: str batch_tasks_id: str
user_id: str user_id: str
batch_data_list: List[ProductItemModel] batch_data_list: List[ProductItemModel]
class BatchGenerateRelightImageModel(BaseModel): # relight任务子项
class RelightItemModel(BaseModel):
tasks_id: str tasks_id: str
prompt: str prompt: str
image_url: str image_url: str
direction: str direction: str
product_type: str product_type: str
batch_size: int
# relight批处理集合
class BatchGenerateRelightImageModel(BaseModel):
batch_tasks_id: str
user_id: str
batch_data_list: List[RelightItemModel]

View File

@@ -18,7 +18,7 @@ from celery import Celery
from tritonclient.utils import np_to_triton_dtype from tritonclient.utils import np_to_triton_dtype
from app.core.config import * from app.core.config import *
from app.schemas.generate_image import BatchGenerateRelightImageModel from app.schemas.generate_image import BatchGenerateRelightImageModel, RelightItemModel
from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image
from app.service.utils.oss_client import oss_get_image from app.service.utils.oss_client import oss_get_image
@@ -34,55 +34,58 @@ category = "relight_image"
@celery_app.task @celery_app.task
def batch_generate_relight(batch_request_data): def batch_generate_relight(batch_request_data):
batch_size = len(batch_request_data['batch_data_list'])
logger.info(f"batch_generate_relight batch_request_data: {json.dumps(batch_request_data, indent=4)}") logger.info(f"batch_generate_relight batch_request_data: {json.dumps(batch_request_data, indent=4)}")
batch_tasks_id = batch_request_data['batch_tasks_id']
user_id = batch_request_data['user_id']
result_data_list = []
negative_prompt = 'lowres, bad anatomy, bad hands, cropped, worst quality' negative_prompt = 'lowres, bad anatomy, bad hands, cropped, worst quality'
direction = batch_request_data['direction']
seed = "1" seed = "1"
prompt = batch_request_data['prompt']
product_type = batch_request_data['product_type']
image_url = batch_request_data['image_url']
image = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url.split('/', 1)[1], data_type="cv2")
tasks_id = batch_request_data['tasks_id']
user_id = tasks_id.rsplit('-', 1)[1]
batch_size = batch_request_data['batch_size']
prompts = [prompt] * 1 for i, data in enumerate(batch_request_data['batch_data_list']):
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) direction = data['direction']
image = cv2.resize(image, (512, 768))
images = [image.astype(np.uint8)] * 1
seeds = [seed] * 1
nagetive_prompts = [negative_prompt] * 1
directions = [direction] * 1
if product_type == 'single': prompt = data['prompt']
text_obj = np.array(prompts, dtype="object").reshape((-1, 1)) product_type = data['product_type']
image_obj = np.array(images, dtype=np.uint8).reshape((-1, 768, 512, 3)) image_url = data['image_url']
na_text_obj = np.array(nagetive_prompts, dtype="object").reshape((-1, 1)) image = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url.split('/', 1)[1], data_type="cv2")
seed_obj = np.array(seeds, dtype="object").reshape((-1, 1)) tasks_id = data['tasks_id']
direction_obj = np.array(directions, dtype="object").reshape((-1, 1))
else:
text_obj = np.array(prompts, dtype="object").reshape((1))
image_obj = np.array(images, dtype=np.uint8).reshape((768, 512, 3))
na_text_obj = np.array(nagetive_prompts, dtype="object").reshape((1))
seed_obj = np.array(seeds, dtype="object").reshape((1))
direction_obj = np.array(directions, dtype="object").reshape((1))
input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype))
input_image = grpcclient.InferInput("input_image", image_obj.shape, "UINT8")
input_natext = grpcclient.InferInput("negative_prompt", na_text_obj.shape, np_to_triton_dtype(na_text_obj.dtype))
input_seed = grpcclient.InferInput("seed", seed_obj.shape, np_to_triton_dtype(seed_obj.dtype))
input_direction = grpcclient.InferInput("direction", direction_obj.shape, np_to_triton_dtype(direction_obj.dtype))
input_text.set_data_from_numpy(text_obj) prompts = [prompt] * 1
input_image.set_data_from_numpy(image_obj) image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
input_natext.set_data_from_numpy(na_text_obj) image = cv2.resize(image, (512, 768))
input_seed.set_data_from_numpy(seed_obj) images = [image.astype(np.uint8)] * 1
input_direction.set_data_from_numpy(direction_obj) seeds = [seed] * 1
nagetive_prompts = [negative_prompt] * 1
directions = [direction] * 1
inputs = [input_text, input_natext, input_image, input_seed, input_direction] if product_type == 'single':
image_url_list = [] text_obj = np.array(prompts, dtype="object").reshape((-1, 1))
for i in range(batch_size): image_obj = np.array(images, dtype=np.uint8).reshape((-1, 768, 512, 3))
na_text_obj = np.array(nagetive_prompts, dtype="object").reshape((-1, 1))
seed_obj = np.array(seeds, dtype="object").reshape((-1, 1))
direction_obj = np.array(directions, dtype="object").reshape((-1, 1))
else:
text_obj = np.array(prompts, dtype="object").reshape((1))
image_obj = np.array(images, dtype=np.uint8).reshape((768, 512, 3))
na_text_obj = np.array(nagetive_prompts, dtype="object").reshape((1))
seed_obj = np.array(seeds, dtype="object").reshape((1))
direction_obj = np.array(directions, dtype="object").reshape((1))
input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype))
input_image = grpcclient.InferInput("input_image", image_obj.shape, "UINT8")
input_natext = grpcclient.InferInput("negative_prompt", na_text_obj.shape, np_to_triton_dtype(na_text_obj.dtype))
input_seed = grpcclient.InferInput("seed", seed_obj.shape, np_to_triton_dtype(seed_obj.dtype))
input_direction = grpcclient.InferInput("direction", direction_obj.shape, np_to_triton_dtype(direction_obj.dtype))
input_text.set_data_from_numpy(text_obj)
input_image.set_data_from_numpy(image_obj)
input_natext.set_data_from_numpy(na_text_obj)
input_seed.set_data_from_numpy(seed_obj)
input_direction.set_data_from_numpy(direction_obj)
inputs = [input_text, input_natext, input_image, input_seed, input_direction]
try: try:
if batch_request_data['product_type'] == "single": if data['product_type'] == "single":
result = grpc_client.infer(model_name=GRI_MODEL_NAME_SINGLE, inputs=inputs, priority=100) result = grpc_client.infer(model_name=GRI_MODEL_NAME_SINGLE, inputs=inputs, priority=100)
image = result.as_numpy("generated_relight_image") image = result.as_numpy("generated_relight_image")
else: else:
@@ -121,18 +124,29 @@ def batch_generate_relight(batch_request_data):
logger.error(e) logger.error(e)
if isinstance(image_result, Image.Image): if isinstance(image_result, Image.Image):
image_url = upload_SDXL_image(image_result, user_id=user_id, category=f"{category}", file_name=f"{tasks_id}-batch-{i}.png") image_url = upload_SDXL_image(image_result, user_id=user_id, category=f"{category}", file_name=f"{tasks_id}-batch-{i}.png")
image_url_list.append(image_url) data['relight_img'] = image_url
result_data_list.append(data)
else: else:
image_url = image_result image_url = image_result
if DEBUG is False: data['relight_img'] = image_url
if i + 1 < batch_size: result_data_list.append(data)
publish_status(tasks_id, f"{i + 1}/{batch_size}", image_url)
logger.info(f" [x]Queue : {BATCH_GRI_RABBITMQ_QUEUES} | tasks_id{tasks_id} | progress{i + 1}/{batch_size} | image_url{image_url}") # 发送每条结果
# print(f" [x]Queue : {BATCH_GRI_RABBITMQ_QUEUES} | tasks_id{tasks_id} | progress{i + 1}/{batch_size} | image_url{image_url}") if DEBUG:
else: logger.info(f" [x]Queue : {BATCH_GRI_RABBITMQ_QUEUES} | tasks_id{tasks_id} | progress{i + 1}/{batch_size} | result_data{data}")
publish_status(tasks_id, f"OK", image_url_list) print(f" [x]Queue : {BATCH_GRI_RABBITMQ_QUEUES} | tasks_id{tasks_id} | progress{i + 1}/{batch_size} | result_data{data}")
logger.info(f" [x]Queue : {BATCH_GRI_RABBITMQ_QUEUES} | tasks_id{tasks_id} | progressOK | image_url{image_url}") else:
# print(f" [x]Queue : {BATCH_GRI_RABBITMQ_QUEUES} | tasks_id{tasks_id} | progressOK | image_url{image_url}") publish_status(tasks_id, f"{i + 1}/{batch_size}", image_url)
logger.info(f" [x]Queue : {BATCH_GRI_RABBITMQ_QUEUES} | tasks_id{tasks_id} | progress{i + 1}/{batch_size} | image_url{image_url}")
# 任务完成,发送所有数据结果
if DEBUG:
print(result_data_list)
logger.info(f" [x]Queue : {BATCH_GRI_RABBITMQ_QUEUES} | batch_tasks_id{batch_tasks_id} | progressOK | result_data_list{result_data_list}")
print(f" [x]Queue : {BATCH_GRI_RABBITMQ_QUEUES} | batch_tasks_id{batch_tasks_id} | progressOK | result_data_list{result_data_list}")
else:
publish_status(batch_tasks_id, f"OK", result_data_list)
logger.info(f" [x]Queue : {BATCH_GRI_RABBITMQ_QUEUES} | batch_tasks_id{batch_tasks_id} | progressOK | result_data_list{result_data_list}")
def publish_status(task_id, progress, result): def publish_status(task_id, progress, result):
@@ -151,12 +165,24 @@ def publish_status(task_id, progress, result):
if __name__ == '__main__': if __name__ == '__main__':
rd = BatchGenerateRelightImageModel( rd = BatchGenerateRelightImageModel(
tasks_id="123-89", batch_tasks_id="abcd",
# prompt="beautiful woman, detailed face, sunshine, outdoor, warm atmosphere", user_id="89",
prompt="Colorful black", batch_data_list=[
image_url='aida-users/89/clothing_seg/283c5c82-1a92-11f0-b72a-0242ac150002.png', RelightItemModel(
direction="Right Light", tasks_id="123-5464",
product_type="overall", product_type="overall",
batch_size=10 image_url="aida-users/89/product_image/02894523-19b5-46eb-a9c6-2f512f5fec84-0-89.png",
prompt="Colorful black",
direction="Right Light",
),
RelightItemModel(
tasks_id="123-5464123",
product_type="overall",
image_url="aida-users/89/product_image/02894523-19b5-46eb-a9c6-2f512f5fec84-0-89.png",
direction="Right Light",
prompt="Colorful black",
)
]
) )
batch_generate_relight(rd.dict()) batch_generate_relight(rd.dict())