feat(新功能): batch generate product 入参回参修改
fix(修复bug): docs(文档变更): refactor(重构): test(增加测试):
This commit is contained in:
@@ -1,3 +1,5 @@
|
|||||||
|
from typing import List
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
@@ -43,13 +45,18 @@ class GenerateRelightImageModel(BaseModel):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
class BatchGenerateProductImageModel(BaseModel):
|
class ProductItemModel(BaseModel):
|
||||||
tasks_id: str
|
tasks_id: str
|
||||||
|
image_strength: float
|
||||||
prompt: str
|
prompt: str
|
||||||
image_url: str
|
image_url: str
|
||||||
image_strength: float
|
|
||||||
product_type: str
|
product_type: str
|
||||||
batch_size: int
|
|
||||||
|
|
||||||
|
class BatchGenerateProductImageModel(BaseModel):
|
||||||
|
batch_tasks_id: str
|
||||||
|
user_id: str
|
||||||
|
batch_data_list: List[ProductItemModel]
|
||||||
|
|
||||||
|
|
||||||
class BatchGenerateRelightImageModel(BaseModel):
|
class BatchGenerateRelightImageModel(BaseModel):
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ from celery import Celery
|
|||||||
from tritonclient.utils import np_to_triton_dtype
|
from tritonclient.utils import np_to_triton_dtype
|
||||||
|
|
||||||
from app.core.config import *
|
from app.core.config import *
|
||||||
from app.schemas.generate_image import BatchGenerateProductImageModel
|
from app.schemas.generate_image import BatchGenerateProductImageModel, ProductItemModel
|
||||||
from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image
|
from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image
|
||||||
from app.service.utils.oss_client import oss_get_image
|
from app.service.utils.oss_client import oss_get_image
|
||||||
|
|
||||||
@@ -35,24 +35,26 @@ category = "product_image"
|
|||||||
|
|
||||||
@celery_app.task
|
@celery_app.task
|
||||||
def batch_generate_product(batch_request_data):
|
def batch_generate_product(batch_request_data):
|
||||||
|
batch_size = len(batch_request_data['batch_data_list'])
|
||||||
logger.info(f"batch_generate_product batch_request_data:{json.dumps(batch_request_data, indent=4)}")
|
logger.info(f"batch_generate_product batch_request_data:{json.dumps(batch_request_data, indent=4)}")
|
||||||
tasks_id = batch_request_data['tasks_id']
|
batch_tasks_id = batch_request_data['batch_tasks_id']
|
||||||
user_id = tasks_id.rsplit('-', 1)[1]
|
user_id = batch_request_data['user_id']
|
||||||
batch_size = batch_request_data['batch_size']
|
result_data_list = []
|
||||||
image = pre_processing_image(batch_request_data['image_url'])
|
|
||||||
|
for i, data in enumerate(batch_request_data['batch_data_list']):
|
||||||
|
tasks_id = data['tasks_id']
|
||||||
|
image = pre_processing_image(data['image_url'])
|
||||||
image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
|
image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
|
||||||
images = [image.astype(np.uint8)] * 1
|
images = [image.astype(np.uint8)] * 1
|
||||||
|
prompts = [data['prompt']] * 1
|
||||||
prompts = [batch_request_data['prompt']] * 1
|
if data['product_type'] == "single":
|
||||||
|
|
||||||
if batch_request_data['product_type'] == "single":
|
|
||||||
text_obj = np.array(prompts, dtype="object").reshape((-1, 1))
|
text_obj = np.array(prompts, dtype="object").reshape((-1, 1))
|
||||||
image_obj = np.array(images, dtype=np.uint8).reshape((-1, 768, 512, 3))
|
image_obj = np.array(images, dtype=np.uint8).reshape((-1, 768, 512, 3))
|
||||||
image_strength_obj = np.array(batch_request_data['image_strength'], dtype=np.float32).reshape((-1, 1))
|
image_strength_obj = np.array(data['image_strength'], dtype=np.float32).reshape((-1, 1))
|
||||||
else:
|
else:
|
||||||
text_obj = np.array(prompts, dtype="object").reshape((1))
|
text_obj = np.array(prompts, dtype="object").reshape((1))
|
||||||
image_obj = np.array(images, dtype=np.uint8).reshape((768, 512, 3))
|
image_obj = np.array(images, dtype=np.uint8).reshape((768, 512, 3))
|
||||||
image_strength_obj = np.array(batch_request_data['image_strength'], dtype=np.float32).reshape((1))
|
image_strength_obj = np.array(data['image_strength'], dtype=np.float32).reshape((1))
|
||||||
input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype))
|
input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype))
|
||||||
input_image = grpcclient.InferInput("input_image", image_obj.shape, "UINT8")
|
input_image = grpcclient.InferInput("input_image", image_obj.shape, "UINT8")
|
||||||
input_image_strength = grpcclient.InferInput("image_strength", image_strength_obj.shape, np_to_triton_dtype(image_strength_obj.dtype))
|
input_image_strength = grpcclient.InferInput("image_strength", image_strength_obj.shape, np_to_triton_dtype(image_strength_obj.dtype))
|
||||||
@@ -63,10 +65,8 @@ def batch_generate_product(batch_request_data):
|
|||||||
|
|
||||||
inputs = [input_text, input_image, input_image_strength]
|
inputs = [input_text, input_image, input_image_strength]
|
||||||
|
|
||||||
image_url_list = []
|
|
||||||
for i in range(batch_size):
|
|
||||||
try:
|
try:
|
||||||
if batch_request_data['product_type'] == "single":
|
if data['product_type'] == "single":
|
||||||
result = grpc_client.infer(model_name=GPI_MODEL_NAME_SINGLE, inputs=inputs, priority=100)
|
result = grpc_client.infer(model_name=GPI_MODEL_NAME_SINGLE, inputs=inputs, priority=100)
|
||||||
image = result.as_numpy("generated_cnet_image")
|
image = result.as_numpy("generated_cnet_image")
|
||||||
else:
|
else:
|
||||||
@@ -77,7 +77,7 @@ def batch_generate_product(batch_request_data):
|
|||||||
if 'mask_list' in str(e):
|
if 'mask_list' in str(e):
|
||||||
e_text_obj = np.array(prompts, dtype="object").reshape((-1, 1))
|
e_text_obj = np.array(prompts, dtype="object").reshape((-1, 1))
|
||||||
e_image_obj = np.array(images, dtype=np.uint8).reshape((-1, 768, 512, 3))
|
e_image_obj = np.array(images, dtype=np.uint8).reshape((-1, 768, 512, 3))
|
||||||
e_image_strength_obj = np.array(batch_request_data['image_strength'], dtype=np.float32).reshape((-1, 1))
|
e_image_strength_obj = np.array(data['image_strength'], dtype=np.float32).reshape((-1, 1))
|
||||||
|
|
||||||
e_input_text = grpcclient.InferInput("prompt", e_text_obj.shape, np_to_triton_dtype(e_text_obj.dtype))
|
e_input_text = grpcclient.InferInput("prompt", e_text_obj.shape, np_to_triton_dtype(e_text_obj.dtype))
|
||||||
e_input_image = grpcclient.InferInput("input_image", e_image_obj.shape, "UINT8")
|
e_input_image = grpcclient.InferInput("input_image", e_image_obj.shape, "UINT8")
|
||||||
@@ -96,18 +96,29 @@ def batch_generate_product(batch_request_data):
|
|||||||
|
|
||||||
if isinstance(image_result, Image.Image):
|
if isinstance(image_result, Image.Image):
|
||||||
image_url = upload_SDXL_image(image_result, user_id=user_id, category=f"{category}", file_name=f"{tasks_id}-batch-{i}.png")
|
image_url = upload_SDXL_image(image_result, user_id=user_id, category=f"{category}", file_name=f"{tasks_id}-batch-{i}.png")
|
||||||
image_url_list.append(image_url)
|
data['product_img'] = image_url
|
||||||
|
result_data_list.append(data)
|
||||||
else:
|
else:
|
||||||
image_url = image_result
|
image_url = image_result
|
||||||
if DEBUG is False:
|
data['product_img'] = image_url
|
||||||
if i + 1 < batch_size:
|
result_data_list.append(data)
|
||||||
|
|
||||||
|
# 发送每条结果
|
||||||
|
if DEBUG:
|
||||||
|
logger.info(f" [x]Queue : {BATCH_GPI_RABBITMQ_QUEUES} | tasks_id:{tasks_id} | progress:{i + 1}/{batch_size} | result_data:{data}")
|
||||||
|
print(f" [x]Queue : {BATCH_GPI_RABBITMQ_QUEUES} | tasks_id:{tasks_id} | progress:{i + 1}/{batch_size} | result_data:{data}")
|
||||||
|
else:
|
||||||
publish_status(tasks_id, f"{i + 1}/{batch_size}", image_url)
|
publish_status(tasks_id, f"{i + 1}/{batch_size}", image_url)
|
||||||
logger.info(f" [x]Queue : {BATCH_GPI_RABBITMQ_QUEUES} | tasks_id:{tasks_id} | progress:{i + 1}/{batch_size} | image_url:{image_url}")
|
logger.info(f" [x]Queue : {BATCH_GPI_RABBITMQ_QUEUES} | tasks_id:{tasks_id} | progress:{i + 1}/{batch_size} | image_url:{image_url}")
|
||||||
# print(f" [x]Queue : {BATCH_GPI_RABBITMQ_QUEUES} | tasks_id:{tasks_id} | progress:{i + 1}/{batch_size} | image_url:{image_url}")
|
|
||||||
|
# 任务完成,发送所有数据结果
|
||||||
|
if DEBUG:
|
||||||
|
print(result_data_list)
|
||||||
|
logger.info(f" [x]Queue : {BATCH_GPI_RABBITMQ_QUEUES} | batch_tasks_id:{batch_tasks_id} | progress:OK | result_data_list:{result_data_list}")
|
||||||
|
print(f" [x]Queue : {BATCH_GPI_RABBITMQ_QUEUES} | batch_tasks_id:{batch_tasks_id} | progress:OK | result_data_list:{result_data_list}")
|
||||||
else:
|
else:
|
||||||
publish_status(tasks_id, f"OK", image_url_list)
|
publish_status(batch_tasks_id, f"OK", result_data_list)
|
||||||
logger.info(f" [x]Queue : {BATCH_GPI_RABBITMQ_QUEUES} | tasks_id:{tasks_id} | progress:OK | image_url:{image_url}")
|
logger.info(f" [x]Queue : {BATCH_GPI_RABBITMQ_QUEUES} | batch_tasks_id:{batch_tasks_id} | progress:OK | result_data_list:{result_data_list}")
|
||||||
# print(f" [x]Queue : {BATCH_GPI_RABBITMQ_QUEUES} | tasks_id:{tasks_id} | progress:OK | image_url:{image_url}")
|
|
||||||
|
|
||||||
|
|
||||||
def pre_processing_image(image_url):
|
def pre_processing_image(image_url):
|
||||||
@@ -180,12 +191,52 @@ def publish_status(task_id, progress, result):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
# rd = BatchGenerateProductImageModel(
|
||||||
|
# tasks_id="123-15-51-89",
|
||||||
|
# image_strength=0.7,
|
||||||
|
# prompt=" The best quality, masterpiece, real image.Outwear,high quality clothing details,8K realistic,HDR",
|
||||||
|
# image_url="aida-results/result_40b1a2fe-e220-11ef-9bfa-0242ac150003.png",
|
||||||
|
# product_type="overall",
|
||||||
|
# batch_size=20
|
||||||
|
# )
|
||||||
|
# batch_generate_product(rd.dict())
|
||||||
|
# rd = {
|
||||||
|
# "user_id": "89",
|
||||||
|
# "batch_data_list": [
|
||||||
|
# {
|
||||||
|
# "tasks_id": "A-123-15-51-89",
|
||||||
|
# "image_strength": 0.7,
|
||||||
|
# "prompt": " The best quality, ma123sterpiece, real image.Outwear,high quality clothing details,8K realistic,HDR",
|
||||||
|
# "image_url": "aida-results/result_40b1a2fe-e220-11ef-9bfa-0242ac150003.png",
|
||||||
|
# "product_type": "overall",
|
||||||
|
# },
|
||||||
|
# {
|
||||||
|
# "tasks_id": "B-123-15-51-89",
|
||||||
|
# "image_strength": 0.7,
|
||||||
|
# "prompt": " The best quality, masterpiece, real image.Outwear123,high quality clothing details,8K realistic,HDR",
|
||||||
|
# "image_url": "aida-results/result_40b1a2fe-e220-11ef-9bfa-0242ac150003.png",
|
||||||
|
# "product_type": "overall",
|
||||||
|
# }
|
||||||
|
# ]
|
||||||
|
# }
|
||||||
rd = BatchGenerateProductImageModel(
|
rd = BatchGenerateProductImageModel(
|
||||||
tasks_id="123-15-51-89",
|
batch_tasks_id="abcd",
|
||||||
|
user_id="89",
|
||||||
|
batch_data_list=[
|
||||||
|
ProductItemModel(
|
||||||
|
tasks_id="123-5464",
|
||||||
image_strength=0.7,
|
image_strength=0.7,
|
||||||
prompt=" The best quality, masterpiece, real image.Outwear,high quality clothing details,8K realistic,HDR",
|
|
||||||
image_url="aida-results/result_40b1a2fe-e220-11ef-9bfa-0242ac150003.png",
|
|
||||||
product_type="overall",
|
product_type="overall",
|
||||||
batch_size=20
|
image_url="aida-results/result_40b1a2fe-e220-11ef-9bfa-0242ac150003.png",
|
||||||
|
prompt="123"
|
||||||
|
),
|
||||||
|
ProductItemModel(
|
||||||
|
tasks_id="123-5464123",
|
||||||
|
image_strength=0.7,
|
||||||
|
product_type="overall",
|
||||||
|
image_url="aida-results/result_40b1a2fe-e220-11ef-9bfa-0242ac150003.png",
|
||||||
|
prompt="123"
|
||||||
|
)
|
||||||
|
]
|
||||||
)
|
)
|
||||||
batch_generate_product(rd.dict())
|
batch_generate_product(rd.dict())
|
||||||
|
|||||||
Reference in New Issue
Block a user