feat(新功能): generate product relight pose_transform 开发,设置 batch generate 的优先级为 100,single generate 的优先级为 1

fix(修复bug):
docs(文档变更):
refactor(重构):
test(增加测试):
This commit is contained in:
zchengrong
2025-04-21 10:04:40 +08:00
parent 88c9d6ef93
commit 4e55275e6e
12 changed files with 712 additions and 7 deletions

View File

@@ -3,8 +3,10 @@ import logging
from fastapi import APIRouter, BackgroundTasks, HTTPException from fastapi import APIRouter, BackgroundTasks, HTTPException
from app.schemas.generate_image import GenerateImageModel, GenerateProductImageModel, GenerateSingleLogoImageModel, GenerateRelightImageModel, GenerateMultiViewModel from app.schemas.generate_image import GenerateImageModel, GenerateProductImageModel, GenerateSingleLogoImageModel, GenerateRelightImageModel, GenerateMultiViewModel, BatchGenerateProductImageModel, BatchGenerateRelightImageModel
from app.schemas.pose_transform import BatchPoseTransformModel
from app.schemas.response_template import ResponseModel from app.schemas.response_template import ResponseModel
from app.service.generate_batch_image.service import start_product_batch_generate, start_relight_batch_generate, start_pose_transform_batch_generate
from app.service.generate_image.service_generate_image import GenerateImage, infer_cancel as generate_image_infer_cancel from app.service.generate_image.service_generate_image import GenerateImage, infer_cancel as generate_image_infer_cancel
from app.service.generate_image.service_generate_multi_view import GenerateMultiView, infer_cancel as generate_multi_view_cancel from app.service.generate_image.service_generate_multi_view import GenerateMultiView, infer_cancel as generate_multi_view_cancel
from app.service.generate_image.service_generate_product_image import GenerateProductImage, infer_cancel as generate_product_image_cancel from app.service.generate_image.service_generate_product_image import GenerateProductImage, infer_cancel as generate_product_image_cancel
@@ -228,3 +230,21 @@ def generate_relight_image(tasks_id: str):
logger.warning(f"generate_relight_image_cancel_cancel Run Exception @@@@@@:{e}") logger.warning(f"generate_relight_image_cancel_cancel Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e)) raise HTTPException(status_code=404, detail=str(e))
return ResponseModel(data=data['data']) return ResponseModel(data=data['data'])
"""batch generate img"""
@router.post("/batch_generate_product_image")
async def design(request_batch_item: BatchGenerateProductImageModel):
return await start_product_batch_generate(request_batch_item)
@router.post("/batch_generate_relight_image")
async def design(request_batch_item: BatchGenerateRelightImageModel):
return await start_relight_batch_generate(request_batch_item)
@router.post("/batch_generate_pose_transform_image")
async def design(request_batch_item: BatchPoseTransformModel):
return await start_pose_transform_batch_generate(request_batch_item)

View File

@@ -135,12 +135,14 @@ GEN_SINGLE_LOGO_RABBITMQ_QUEUES = os.getenv("GEN_SINGLE_LOGO_RABBITMQ_QUEUES", f
# Generate Product service config 旧版product img 模型 # Generate Product service config 旧版product img 模型
GPI_RABBITMQ_QUEUES = os.getenv("GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES", f"ToProductImage{RABBITMQ_ENV}") GPI_RABBITMQ_QUEUES = os.getenv("GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES", f"ToProductImage{RABBITMQ_ENV}")
BATCH_GPI_RABBITMQ_QUEUES = os.getenv("BATCH_GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES", f"BatchToProductImage{RABBITMQ_ENV}")
GPI_MODEL_NAME_OVERALL = 'diffusion_ensemble_all' GPI_MODEL_NAME_OVERALL = 'diffusion_ensemble_all'
GPI_MODEL_NAME_SINGLE = 'stable_diffusion_1_5_cnet' GPI_MODEL_NAME_SINGLE = 'stable_diffusion_1_5_cnet'
GPI_MODEL_URL = '10.1.1.243:10051' GPI_MODEL_URL = '10.1.1.243:10051'
# Generate Single Logo service config # Generate Single Logo service config
GRI_RABBITMQ_QUEUES = os.getenv("GEN_RELIGHT_IMAGE_RABBITMQ_QUEUES", f"Relight{RABBITMQ_ENV}") GRI_RABBITMQ_QUEUES = os.getenv("GEN_RELIGHT_IMAGE_RABBITMQ_QUEUES", f"Relight{RABBITMQ_ENV}")
BATCH_GRI_RABBITMQ_QUEUES = os.getenv("BATCH_GEN_RELIGHT_IMAGE_RABBITMQ_QUEUES", f"BatchRelight{RABBITMQ_ENV}")
GRI_MODEL_NAME_OVERALL = 'diffusion_relight_ensemble' GRI_MODEL_NAME_OVERALL = 'diffusion_relight_ensemble'
GRI_MODEL_NAME_SINGLE = 'stable_diffusion_1_5_relight' GRI_MODEL_NAME_SINGLE = 'stable_diffusion_1_5_relight'
GRI_MODEL_URL = '10.1.1.240:10051' GRI_MODEL_URL = '10.1.1.240:10051'
@@ -148,6 +150,7 @@ GRI_MODEL_URL = '10.1.1.240:10051'
# Pose Transform service config # Pose Transform service config
PS_RABBITMQ_QUEUES = os.getenv("PS_RABBITMQ_QUEUES", f"PoseTransform{RABBITMQ_ENV}") PS_RABBITMQ_QUEUES = os.getenv("PS_RABBITMQ_QUEUES", f"PoseTransform{RABBITMQ_ENV}")
BATCH_PS_RABBITMQ_QUEUES = os.getenv("BATCH_PS_RABBITMQ_QUEUES", f"BatchPoseTransform{RABBITMQ_ENV}")
PT_MODEL_URL = '10.1.1.243:10061' PT_MODEL_URL = '10.1.1.243:10061'
# SEG service config # SEG service config

View File

@@ -36,3 +36,26 @@ class GenerateRelightImageModel(BaseModel):
image_url: str image_url: str
direction: str direction: str
product_type: str product_type: str
"""
batch generate image
"""
class BatchGenerateProductImageModel(BaseModel):
    """Request body for /batch_generate_product_image."""
    # Task identifier; downstream code treats the segment after the last '-' as the user id.
    tasks_id: str
    # Positive prompt forwarded to the diffusion model.
    prompt: str
    # OSS location as "<bucket>/<object_name>".
    image_url: str
    # img2img strength passed to the model as a float32 tensor.
    image_strength: float
    # "single" selects the batched single model; any other value the ensemble.
    product_type: str
    # Number of images to generate for this task.
    batch_size: int
class BatchGenerateRelightImageModel(BaseModel):
    """Request body for /batch_generate_relight_image."""
    # Task identifier; downstream code treats the segment after the last '-' as the user id.
    tasks_id: str
    # Positive prompt forwarded to the relight model.
    prompt: str
    # OSS location as "<bucket>/<object_name>".
    image_url: str
    # Light direction string forwarded to the model (e.g. "Right Light").
    direction: str
    # "single" selects the batched single model; any other value the ensemble.
    product_type: str
    # Number of images to generate for this task.
    batch_size: int

View File

@@ -5,3 +5,10 @@ class PoseTransformModel(BaseModel):
image_url: str image_url: str
tasks_id: str tasks_id: str
pose_id: str pose_id: str
class BatchPoseTransformModel(BaseModel):
    """Request body for /batch_generate_pose_transform_image."""
    # OSS location as "<bucket>/<object_name>".
    image_url: str
    # Task identifier; downstream code treats the segment after the last '-' as the user id.
    tasks_id: str
    # Pose preset identifier forwarded to the "animatex_1" model.
    pose_id: str
    # Number of pose-transform runs for this task.
    batch_size: int

View File

@@ -0,0 +1,24 @@
from app.service.generate_batch_image.service_batch_generate_product_image import batch_generate_product, publish_status as product_publish_status
from app.service.generate_batch_image.service_batch_generate_relight_image import batch_generate_relight, publish_status as relight_publish_status
from app.service.generate_batch_image.service_batch_pose_transform import batch_generate_pose_transform, publish_status as pose_transform_publish_status
async def start_product_batch_generate(data):
    """Enqueue a batch product-image Celery task and report initial progress.

    data: BatchGenerateProductImageModel. Publishes a "0/<batch_size>"
    progress message, then returns {"task_id", "state"}.
    """
    # Fix: removed leftover debug print of the AsyncResult.
    async_result = batch_generate_product.delay(data.dict())
    product_publish_status(data.tasks_id, f"0/{data.batch_size}", "")
    return {"task_id": data.tasks_id, "state": async_result.state}
async def start_relight_batch_generate(data):
    """Enqueue a batch relight-image Celery task and report initial progress.

    data: BatchGenerateRelightImageModel. Publishes a "0/<batch_size>"
    progress message, then returns {"task_id", "state"}.
    """
    # Fix: removed leftover debug print; renamed the misleading
    # `generate_clothes_task` local.
    async_result = batch_generate_relight.delay(data.dict())
    relight_publish_status(data.tasks_id, f"0/{data.batch_size}", "")
    return {"task_id": data.tasks_id, "state": async_result.state}
async def start_pose_transform_batch_generate(data):
    """Enqueue a batch pose-transform Celery task and report initial progress.

    data: BatchPoseTransformModel. Publishes a "0/<batch_size>" progress
    message, then returns {"task_id", "state"}.
    """
    # Fix: removed leftover debug print; renamed the misleading
    # `generate_clothes_task` local.
    async_result = batch_generate_pose_transform.delay(data.dict())
    pose_transform_publish_status(data.tasks_id, f"0/{data.batch_size}", "")
    return {"task_id": data.tasks_id, "state": async_result.state}

View File

@@ -0,0 +1,191 @@
# 旧版product
# !/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project trinity_client
@File service_att_recognition.py
@Author :周成融
@Date 2023/7/26 12:01:05
@detail
"""
import json
import logging
import cv2
import numpy as np
import tritonclient.grpc as grpcclient
from PIL import Image
from celery import Celery
from tritonclient.utils import np_to_triton_dtype
from app.core.config import *
from app.schemas.generate_image import BatchGenerateProductImageModel
from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image
from app.service.utils.oss_client import oss_get_image
# Celery app for the batch product-image worker.
# NOTE(review): broker URL and credentials are hard-coded here — should come from config.
celery_app = Celery('product_tasks', broker=f'amqp://rabbit:123456@18.167.251.121:5672//', backend='rpc://', BROKER_CONNECTION_RETRY_ON_STARTUP=True)
celery_app.conf.task_default_queue = 'queue_product'
celery_app.conf.worker_log_format = '%(asctime)s %(filename)s [line:%(lineno)d] %(levelname)s %(message)s'
celery_app.conf.worker_hijack_root_logger = False
logger = logging.getLogger()  # root logger (worker_hijack_root_logger disabled above)
logging.getLogger('pika').setLevel(logging.WARNING)  # silence noisy pika logs
grpc_client = grpcclient.InferenceServerClient(url=GPI_MODEL_URL)  # Triton gRPC client
category = "product_image"  # OSS category folder for uploaded results
@celery_app.task
def batch_generate_product(batch_request_data):
    """Generate `batch_size` product images for one task and publish progress.

    batch_request_data is a BatchGenerateProductImageModel dict
    (tasks_id, prompt, image_url, image_strength, product_type, batch_size).
    Each successful image is uploaded and its URL published as "i/batch_size"
    progress; the final message is "OK" with the list of all successful URLs.
    Progress is only published when DEBUG is False.
    """
    logger.info(f"batch_generate_product batch_request_data:{batch_request_data}")
    tasks_id = batch_request_data['tasks_id']
    # By convention the segment after the last '-' in tasks_id is the user id.
    user_id = tasks_id.rsplit('-', 1)[1]
    batch_size = batch_request_data['batch_size']
    # Letterbox the source onto a 512x768 RGBA canvas, then drop alpha.
    image = pre_processing_image(batch_request_data['image_url'])
    image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
    images = [image.astype(np.uint8)] * 1
    prompts = [batch_request_data['prompt']] * 1
    # The "single" model takes batched (N, ...) tensors; the ensemble takes unbatched ones.
    if batch_request_data['product_type'] == "single":
        text_obj = np.array(prompts, dtype="object").reshape((-1, 1))
        image_obj = np.array(images, dtype=np.uint8).reshape((-1, 768, 512, 3))
        image_strength_obj = np.array(batch_request_data['image_strength'], dtype=np.float32).reshape((-1, 1))
    else:
        text_obj = np.array(prompts, dtype="object").reshape((1))
        image_obj = np.array(images, dtype=np.uint8).reshape((768, 512, 3))
        image_strength_obj = np.array(batch_request_data['image_strength'], dtype=np.float32).reshape((1))
    input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype))
    input_image = grpcclient.InferInput("input_image", image_obj.shape, "UINT8")
    input_image_strength = grpcclient.InferInput("image_strength", image_strength_obj.shape, np_to_triton_dtype(image_strength_obj.dtype))
    input_text.set_data_from_numpy(text_obj)
    input_image.set_data_from_numpy(image_obj)
    input_image_strength.set_data_from_numpy(image_strength_obj)
    inputs = [input_text, input_image, input_image_strength]
    image_url_list = []
    for i in range(batch_size):
        try:
            # NOTE(review): batch inference uses priority=100 while interactive
            # requests use priority=1 (see commit message); in Triton a *lower*
            # value is higher priority, so batch work presumably yields to
            # interactive requests — confirm intent.
            if batch_request_data['product_type'] == "single":
                result = grpc_client.infer(model_name=GPI_MODEL_NAME_SINGLE, inputs=inputs, priority=100)
                image = result.as_numpy("generated_cnet_image")
            else:
                result = grpc_client.infer(model_name=GPI_MODEL_NAME_OVERALL, inputs=inputs, priority=100)
                image = result.as_numpy("generated_inpaint_image")
            image_result = Image.fromarray(np.squeeze(image.astype(np.uint8)))
        except Exception as e:
            # Presumably the ensemble raises an error mentioning 'mask_list'
            # for some inputs; retry with the single model and batched tensors.
            if 'mask_list' in str(e):
                e_text_obj = np.array(prompts, dtype="object").reshape((-1, 1))
                e_image_obj = np.array(images, dtype=np.uint8).reshape((-1, 768, 512, 3))
                e_image_strength_obj = np.array(batch_request_data['image_strength'], dtype=np.float32).reshape((-1, 1))
                e_input_text = grpcclient.InferInput("prompt", e_text_obj.shape, np_to_triton_dtype(e_text_obj.dtype))
                e_input_image = grpcclient.InferInput("input_image", e_image_obj.shape, "UINT8")
                e_input_image_strength = grpcclient.InferInput("image_strength", e_image_strength_obj.shape, np_to_triton_dtype(e_image_strength_obj.dtype))
                e_input_text.set_data_from_numpy(e_text_obj)
                e_input_image.set_data_from_numpy(e_image_obj)
                e_input_image_strength.set_data_from_numpy(e_image_strength_obj)
                result = grpc_client.infer(model_name=GPI_MODEL_NAME_SINGLE, inputs=[e_input_text, e_input_image, e_input_image_strength], priority=100)
                image = result.as_numpy("generated_cnet_image")
                image_result = Image.fromarray(np.squeeze(image.astype(np.uint8)))
            else:
                # Best-effort: keep the error text as this slot's "result".
                image_result = str(e)
                logger.error(image_result)
        if isinstance(image_result, Image.Image):
            image_url = upload_SDXL_image(image_result, user_id=user_id, category=f"{category}", file_name=f"{tasks_id}-batch-{i}.png")
            image_url_list.append(image_url)
        else:
            # Failed slot: the error string is published in place of a URL.
            image_url = image_result
        if DEBUG is False:
            if i + 1 < batch_size:
                publish_status(tasks_id, f"{i + 1}/{batch_size}", image_url)
                logger.info(f" [x] {tasks_id}tasks_id *** progress{i + 1}/{batch_size} *** image_url{image_url}")
                print(f" [x] {tasks_id}tasks_id *** progress{i + 1}/{batch_size} *** image_url{image_url}")
            else:
                # Last slot: publish "OK" with every successful URL collected.
                publish_status(tasks_id, f"OK", image_url_list)
                logger.info(f" [x] {tasks_id}tasks_id *** progressOK *** image_url{image_url_list}")
                print(f" [x] {tasks_id}tasks_id *** progressOK *** image_url{image_url_list}")
def pre_processing_image(image_url):
    """Download an OSS image and letterbox it, centered, onto a 512x768
    transparent canvas; returns the canvas as an RGBA numpy array."""
    source = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url[image_url.find('/') + 1:], data_type="PIL")
    canvas_w, canvas_h = 512, 768
    src_w, src_h = source.size
    # Uniform scale so the whole image fits inside the canvas.
    ratio = min(canvas_w / src_w, canvas_h / src_h)
    scaled = source.resize((int(src_w * ratio), int(src_h * ratio)))
    # Transparent canvas the scaled image is pasted onto.
    canvas = Image.new("RGBA", (canvas_w, canvas_h), (255, 255, 255, 0))
    paste_x = (canvas_w - scaled.width) // 2
    paste_y = (canvas_h - scaled.height) // 2
    if scaled.mode == "RGBA":
        # Use the source alpha channel as the paste mask.
        canvas.paste(scaled, (paste_x, paste_y), mask=scaled.split()[3])
    else:
        canvas.paste(scaled, (paste_x, paste_y))
    return np.array(canvas)
def post_processing_image(image, left, top):
    """Resize *image* to 768px tall (keeping aspect ratio) and center-crop to 512x768.

    NOTE(review): the `left` and `top` parameters are never used — `left` is
    immediately overwritten below and `top` is ignored. Confirm intent before
    removing them from the signature.
    """
    resized_image = image.resize((int(image.width * (768 / image.height)), 768))
    # Compute a horizontally centered 512-wide crop box.
    left = (resized_image.width - 512) // 2
    upper = 0
    right = left + 512
    lower = 768
    # Perform the crop.
    cropped_image = resized_image.crop((left, upper, right, lower))
    return cropped_image
def publish_status(task_id, progress, result):
    """Publish a progress message for *task_id* to the batch product-image queue."""
    # NOTE(review): `pika` is expected to come in via `from app.core.config import *` — confirm.
    conn = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
    ch = conn.channel()
    ch.queue_declare(queue=BATCH_GPI_RABBITMQ_QUEUES, durable=True)
    payload = json.dumps({'task_id': task_id, 'progress': progress, "result": result})
    ch.basic_publish(
        exchange='',
        routing_key=BATCH_GPI_RABBITMQ_QUEUES,
        body=payload,
        properties=pika.BasicProperties(delivery_mode=2),  # persistent delivery
    )
    conn.close()
if __name__ == '__main__':
    # Manual smoke test: build a sample request and run the task synchronously
    # (direct call, not .delay(), so no Celery worker is required).
    rd = BatchGenerateProductImageModel(
        tasks_id="123-15-51-89",
        image_strength=0.7,
        prompt=" The best quality, masterpiece, real image.Outwear,high quality clothing details,8K realistic,HDR",
        image_url="aida-results/result_40b1a2fe-e220-11ef-9bfa-0242ac150003.png",
        product_type="overall",
        batch_size=20
    )
    batch_generate_product(rd.dict())

View File

@@ -0,0 +1,162 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project trinity_client
@File service_att_recognition.py
@Author :周成融
@Date 2023/7/26 12:01:05
@detail
"""
import json
import logging
import cv2
import numpy as np
import tritonclient.grpc as grpcclient
from PIL import Image
from celery import Celery
from tritonclient.utils import np_to_triton_dtype
from app.core.config import *
from app.schemas.generate_image import BatchGenerateRelightImageModel
from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image
from app.service.utils.oss_client import oss_get_image
logger = logging.getLogger()  # root logger (worker_hijack_root_logger disabled below)
# Celery app for the batch relight worker.
# NOTE(review): broker URL and credentials are hard-coded here — should come from config.
celery_app = Celery('relight_tasks', broker=f'amqp://rabbit:123456@18.167.251.121:5672//', backend='rpc://', BROKER_CONNECTION_RETRY_ON_STARTUP=True)
celery_app.conf.task_default_queue = 'queue_relight'
celery_app.conf.worker_log_format = '%(asctime)s %(filename)s [line:%(lineno)d] %(levelname)s %(message)s'
celery_app.conf.worker_hijack_root_logger = False
logging.getLogger('pika').setLevel(logging.WARNING)  # silence noisy pika logs
grpc_client = grpcclient.InferenceServerClient(url=GRI_MODEL_URL)  # Triton gRPC client
category = "relight_image"  # OSS category folder for uploaded results
@celery_app.task
def batch_generate_relight(batch_request_data):
    """Generate `batch_size` relit images for one task and publish progress.

    batch_request_data is a BatchGenerateRelightImageModel dict
    (tasks_id, prompt, image_url, direction, product_type, batch_size).
    Each successful image is uploaded and its URL published as "i/batch_size"
    progress; the final message is "OK" with the list of all successful URLs.
    Progress is only published when DEBUG is False.
    """
    logger.info(f"batch_generate_relight batch_request_data: {batch_request_data}")
    negative_prompt = 'lowres, bad anatomy, bad hands, cropped, worst quality'
    direction = batch_request_data['direction']
    seed = "1"  # fixed seed, sent to the model as a string tensor
    prompt = batch_request_data['prompt']
    product_type = batch_request_data['product_type']
    image_url = batch_request_data['image_url']
    # image_url is "<bucket>/<object_name>".
    image = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url.split('/', 1)[1], data_type="cv2")
    tasks_id = batch_request_data['tasks_id']
    # By convention the segment after the last '-' in tasks_id is the user id.
    user_id = tasks_id.rsplit('-', 1)[1]
    batch_size = batch_request_data['batch_size']
    prompts = [prompt] * 1
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = cv2.resize(image, (512, 768))
    images = [image.astype(np.uint8)] * 1
    seeds = [seed] * 1
    nagetive_prompts = [negative_prompt] * 1
    directions = [direction] * 1
    # The "single" model takes batched (N, ...) tensors; the ensemble takes unbatched ones.
    if product_type == 'single':
        text_obj = np.array(prompts, dtype="object").reshape((-1, 1))
        image_obj = np.array(images, dtype=np.uint8).reshape((-1, 768, 512, 3))
        na_text_obj = np.array(nagetive_prompts, dtype="object").reshape((-1, 1))
        seed_obj = np.array(seeds, dtype="object").reshape((-1, 1))
        direction_obj = np.array(directions, dtype="object").reshape((-1, 1))
    else:
        text_obj = np.array(prompts, dtype="object").reshape((1))
        image_obj = np.array(images, dtype=np.uint8).reshape((768, 512, 3))
        na_text_obj = np.array(nagetive_prompts, dtype="object").reshape((1))
        seed_obj = np.array(seeds, dtype="object").reshape((1))
        direction_obj = np.array(directions, dtype="object").reshape((1))
    input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype))
    input_image = grpcclient.InferInput("input_image", image_obj.shape, "UINT8")
    input_natext = grpcclient.InferInput("negative_prompt", na_text_obj.shape, np_to_triton_dtype(na_text_obj.dtype))
    input_seed = grpcclient.InferInput("seed", seed_obj.shape, np_to_triton_dtype(seed_obj.dtype))
    input_direction = grpcclient.InferInput("direction", direction_obj.shape, np_to_triton_dtype(direction_obj.dtype))
    input_text.set_data_from_numpy(text_obj)
    input_image.set_data_from_numpy(image_obj)
    input_natext.set_data_from_numpy(na_text_obj)
    input_seed.set_data_from_numpy(seed_obj)
    input_direction.set_data_from_numpy(direction_obj)
    inputs = [input_text, input_natext, input_image, input_seed, input_direction]
    image_url_list = []
    for i in range(batch_size):
        try:
            # NOTE(review): batch inference uses priority=100 while interactive
            # requests use priority=1; in Triton a *lower* value is higher
            # priority, so batch work presumably yields to interactive — confirm.
            if batch_request_data['product_type'] == "single":
                result = grpc_client.infer(model_name=GRI_MODEL_NAME_SINGLE, inputs=inputs, priority=100)
                image = result.as_numpy("generated_relight_image")
            else:
                result = grpc_client.infer(model_name=GRI_MODEL_NAME_OVERALL, inputs=inputs, priority=100)
                image = result.as_numpy("generated_inpaint_image")
            image_result = Image.fromarray(np.squeeze(image.astype(np.uint8)))
        except Exception as e:
            print(e)
            # Presumably the ensemble raises an error mentioning 'mask_list'
            # for some inputs; retry with the single model and batched tensors.
            if 'mask_list' in str(e):
                e_text_obj = np.array(prompts, dtype="object").reshape((-1, 1))
                e_image_obj = np.array(images, dtype=np.uint8).reshape((-1, 768, 512, 3))
                e_na_text_obj = np.array(nagetive_prompts, dtype="object").reshape((-1, 1))
                e_seed_obj = np.array(seeds, dtype="object").reshape((-1, 1))
                e_direction_obj = np.array(directions, dtype="object").reshape((-1, 1))
                e_input_text = grpcclient.InferInput("prompt", e_text_obj.shape, np_to_triton_dtype(e_text_obj.dtype))
                e_input_image = grpcclient.InferInput("input_image", e_image_obj.shape, "UINT8")
                e_input_natext = grpcclient.InferInput("negative_prompt", e_na_text_obj.shape, np_to_triton_dtype(e_na_text_obj.dtype))
                e_input_seed = grpcclient.InferInput("seed", e_seed_obj.shape, np_to_triton_dtype(e_seed_obj.dtype))
                e_input_direction = grpcclient.InferInput("direction", e_direction_obj.shape, np_to_triton_dtype(e_direction_obj.dtype))
                e_input_text.set_data_from_numpy(e_text_obj)
                e_input_image.set_data_from_numpy(e_image_obj)
                e_input_natext.set_data_from_numpy(e_na_text_obj)
                e_input_seed.set_data_from_numpy(e_seed_obj)
                e_input_direction.set_data_from_numpy(e_direction_obj)
                e_inputs = [e_input_text, e_input_natext, e_input_image, e_input_seed, e_input_direction]
                result = grpc_client.infer(model_name=GRI_MODEL_NAME_SINGLE, inputs=e_inputs, priority=100)
                image = result.as_numpy("generated_relight_image")
                image_result = Image.fromarray(np.squeeze(image.astype(np.uint8)))
            else:
                # Best-effort: keep the error text as this slot's "result".
                image_result = str(e)
                logger.error(e)
        if isinstance(image_result, Image.Image):
            image_url = upload_SDXL_image(image_result, user_id=user_id, category=f"{category}", file_name=f"{tasks_id}-batch-{i}.png")
            image_url_list.append(image_url)
        else:
            # Failed slot: the error string is published in place of a URL.
            image_url = image_result
        if DEBUG is False:
            if i + 1 < batch_size:
                publish_status(tasks_id, f"{i + 1}/{batch_size}", image_url)
                logger.info(f" [x] {tasks_id}tasks_id *** progress{i + 1}/{batch_size} *** image_url{image_url}")
                print(f" [x] {tasks_id}tasks_id *** progress{i + 1}/{batch_size} *** image_url{image_url}")
            else:
                # Last slot: publish "OK" with every successful URL collected.
                publish_status(tasks_id, f"OK", image_url_list)
                logger.info(f" [x] {tasks_id}tasks_id *** progressOK *** image_url{image_url_list}")
                print(f" [x] {tasks_id}tasks_id *** progressOK *** image_url{image_url_list}")
def publish_status(task_id, progress, result):
    """Publish a progress message for *task_id* to the batch relight queue."""
    # NOTE(review): `pika` is expected to come in via `from app.core.config import *` — confirm.
    conn = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
    ch = conn.channel()
    ch.queue_declare(queue=BATCH_GRI_RABBITMQ_QUEUES, durable=True)
    payload = json.dumps({'task_id': task_id, 'progress': progress, "result": result})
    ch.basic_publish(
        exchange='',
        routing_key=BATCH_GRI_RABBITMQ_QUEUES,
        body=payload,
        properties=pika.BasicProperties(delivery_mode=2),  # persistent delivery
    )
    conn.close()
if __name__ == '__main__':
    # Manual smoke test: build a sample request and run the task synchronously
    # (direct call, not .delay(), so no Celery worker is required).
    rd = BatchGenerateRelightImageModel(
        tasks_id="123-89",
        # prompt="beautiful woman, detailed face, sunshine, outdoor, warm atmosphere",
        prompt="Colorful black",
        image_url='aida-users/89/clothing_seg/283c5c82-1a92-11f0-b72a-0242ac150002.png',
        direction="Right Light",
        product_type="overall",
        batch_size=10
    )
    batch_generate_relight(rd.dict())

View File

@@ -0,0 +1,176 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project trinity_client
@File service_att_recognition.py
@Author :周成融
@Date 2023/7/26 12:01:05
@detail
"""
import io
import json
import logging
from io import BytesIO
import imageio
import numpy as np
import tritonclient.grpc as grpcclient
from PIL import Image
from celery import Celery
from minio import Minio
from tritonclient.utils import np_to_triton_dtype
from app.core.config import *
from app.schemas.pose_transform import BatchPoseTransformModel
from app.service.generate_image.utils.pose_transform_upload import upload_gif, upload_video
from app.service.utils.new_oss_client import oss_upload_image
from app.service.utils.oss_client import oss_get_image
minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)  # shared MinIO client for uploads
logger = logging.getLogger()  # root logger (worker_hijack_root_logger disabled below)
# Celery app for the batch pose-transform worker.
# NOTE(review): broker URL and credentials are hard-coded here — should come from config.
celery_app = Celery('tasks', broker=f'amqp://rabbit:123456@18.167.251.121:5672//', backend='rpc://', BROKER_CONNECTION_RETRY_ON_STARTUP=True)
celery_app.conf.task_default_queue = 'queue_post_transform'
celery_app.conf.worker_log_format = '%(asctime)s %(filename)s [line:%(lineno)d] %(levelname)s %(message)s'
celery_app.conf.worker_hijack_root_logger = False
logging.getLogger('pika').setLevel(logging.WARNING)  # silence noisy pika logs
grpc_client = grpcclient.InferenceServerClient(url=PT_MODEL_URL)  # Triton gRPC client
category = "pose_transform"  # OSS category folder for uploaded results
def upload_first_image(image, user_id, category, file_name):
    """Upload a PIL image as PNG to MinIO and return its OSS-style URL.

    Returns None (after logging a warning) if the upload fails — callers
    treat a missing URL as a failed upload.
    """
    try:
        buffer = io.BytesIO()
        image.save(buffer, format='PNG')
        object_name = f'{user_id}/{category}/{file_name}'
        # Fix: dropped the unused `req` binding around this call.
        oss_upload_image(oss_client=minio_client, bucket=GI_MINIO_BUCKET, object_name=object_name, image_bytes=buffer.getvalue())
        return f"aida-users/{object_name}"
    except Exception as e:
        # Fix: the log message previously named the wrong function ("upload_png_mask").
        logger.warning(f"upload_first_image runtime exception : {e}")
def pre_processing_image(image_url):
    """Download an OSS image, letterbox it centered onto a 512x768 canvas,
    and return the canvas as an RGB numpy array."""
    source = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url[image_url.find('/') + 1:], data_type="PIL")
    canvas_w, canvas_h = 512, 768
    src_w, src_h = source.size
    # Uniform scale so the whole image fits inside the canvas.
    ratio = min(canvas_w / src_w, canvas_h / src_h)
    scaled = source.resize((int(src_w * ratio), int(src_h * ratio)))
    # Transparent canvas the scaled image is pasted onto.
    canvas = Image.new("RGBA", (canvas_w, canvas_h), (255, 255, 255, 0))
    paste_x = (canvas_w - scaled.width) // 2
    paste_y = (canvas_h - scaled.height) // 2
    if scaled.mode == "RGBA":
        # Use the source alpha channel as the paste mask.
        canvas.paste(scaled, (paste_x, paste_y), mask=scaled.split()[3])
    else:
        canvas.paste(scaled, (paste_x, paste_y))
    # The pose model expects 3-channel input, so drop alpha here.
    return np.array(canvas.convert("RGB"))
@celery_app.task
def batch_generate_pose_transform(batch_request_data):
    """Run `batch_size` pose-transform ("animatex_1") inferences for one task.

    batch_request_data is a BatchPoseTransformModel dict (image_url, tasks_id,
    pose_id, batch_size). For each run the first frame is uploaded as PNG and
    all frames as a GIF and an MP4; a dict of the three URLs is published as
    progress (only when DEBUG is False), with "OK" plus all dicts at the end.
    """
    logger.info(f"batch_generate_pose_transform batch_request_data: {batch_request_data}")
    batch_size = batch_request_data['batch_size']
    image_url = batch_request_data['image_url']
    # Letterbox onto a 512x768 canvas and convert to RGB.
    image = pre_processing_image(image_url)
    pose_num = batch_request_data['pose_id']
    tasks_id = batch_request_data['tasks_id']
    # By convention the segment after the last '-' in tasks_id is the user id.
    user_id = tasks_id.rsplit('-', 1)[1]
    pose_num = [pose_num] * 1
    pose_num_obj = np.array(pose_num, dtype="object").reshape((-1, 1))
    input_pose_num = grpcclient.InferInput("pose_num", pose_num_obj.shape, np_to_triton_dtype(pose_num_obj.dtype))
    input_pose_num.set_data_from_numpy(pose_num_obj)
    image_files = [image.astype(np.uint8)] * 1
    image_files_obj = np.array(image_files, dtype=np.uint8).reshape((-1, 768, 512, 3))
    input_image_files = grpcclient.InferInput("image_file", image_files_obj.shape, "UINT8")
    input_image_files.set_data_from_numpy(image_files_obj)
    result_url_list = []
    for i in range(batch_size):
        try:
            # NOTE(review): batch inference uses priority=100 vs 1 for interactive
            # requests; in Triton a *lower* value is higher priority — confirm intent.
            result = grpc_client.infer(model_name="animatex_1", inputs=[input_pose_num, input_image_files], client_timeout=60000, priority=100)
            # Frame stack; [..., ::-1] reverses the channel axis (presumably BGR->RGB — confirm).
            result_data = np.squeeze(result.as_numpy("generated_image_list").astype(np.uint8))[:, :, :, ::-1]
            # First frame as a standalone preview image.
            first_image = Image.fromarray(result_data[0])
            first_image_url = upload_first_image(first_image, user_id=user_id, category=f"{category}_first_img", file_name=f"{tasks_id}_batch_{i}.png")
            # Upload all frames as an animated GIF.
            gif_buffer = BytesIO()
            imageio.mimsave(gif_buffer, result_data, format='GIF', fps=5)
            gif_buffer.seek(0)
            gif_url = upload_gif(gif_buffer=gif_buffer, user_id=user_id, category=f"{category}_gif", file_name=f"{tasks_id}_batch_{i}.gif")
            # Upload all frames as an MP4 video.
            video_url = upload_video(frames=result_data, user_id=user_id, category=f"{category}_video", file_name=f"{tasks_id}_batch_{i}.mp4")
            data = {
                "gif_url": gif_url,
                "video_url": video_url,
                "first_image_url": first_image_url,
            }
        except Exception as e:
            # Best-effort: a failed slot contributes an empty dict.
            print(e)
            data = {}
        result_url_list.append(data)
        if DEBUG is False:
            if i + 1 < batch_size:
                publish_status(tasks_id, f"{i + 1}/{batch_size}", data)
                logger.info(f" [x] {tasks_id}tasks_id *** progress{i + 1}/{batch_size} *** image_url{data}")
                print(f" [x] {tasks_id}tasks_id *** progress{i + 1}/{batch_size} *** image_url{data}")
            else:
                # Last slot: publish "OK" with every result dict collected.
                publish_status(tasks_id, f"OK", result_url_list)
                logger.info(f" [x] {tasks_id}tasks_id *** progressOK *** image_url{result_url_list}")
                print(f" [x] {tasks_id}tasks_id *** progressOK *** image_url{result_url_list}")
def publish_status(task_id, progress, result):
    """Publish a progress message for *task_id* to the batch pose-transform queue.

    Fix: this service was publishing to the batch *relight* queue
    (BATCH_GRI_RABBITMQ_QUEUES), a copy-paste from the relight service; it now
    uses BATCH_PS_RABBITMQ_QUEUES, which this commit defines in config for
    pose transform and which was otherwise unused.
    """
    connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
    channel = connection.channel()
    channel.queue_declare(queue=BATCH_PS_RABBITMQ_QUEUES, durable=True)
    message = {'task_id': task_id, 'progress': progress, "result": result}
    channel.basic_publish(exchange='',
                          routing_key=BATCH_PS_RABBITMQ_QUEUES,
                          body=json.dumps(message),
                          properties=pika.BasicProperties(
                              delivery_mode=2,  # persistent delivery
                          ))
    connection.close()
if __name__ == '__main__':
rd = BatchPoseTransformModel(
tasks_id="123-89",
image_url='aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png',
pose_id="1",
batch_size=10
)
batch_generate_pose_transform(rd.dict())

View File

@@ -153,9 +153,9 @@ class GenerateImage:
inputs = [input_text, input_image, input_mode] inputs = [input_text, input_image, input_mode]
if self.version == "fast": if self.version == "fast":
ctx = self.grpc_client.async_infer(model_name=FAST_GI_MODEL_NAME, inputs=inputs, callback=self.callback) ctx = self.grpc_client.async_infer(model_name=FAST_GI_MODEL_NAME, inputs=inputs, callback=self.callback, priority=1)
else: else:
ctx = self.grpc_client.async_infer(model_name=GI_MODEL_NAME, inputs=inputs, callback=self.callback) ctx = self.grpc_client.async_infer(model_name=GI_MODEL_NAME, inputs=inputs, callback=self.callback, priority=1)
time_out = 600 time_out = 600
generate_data = None generate_data = None

View File

@@ -295,9 +295,9 @@ class GenerateProductImage:
inputs = [input_text, input_image, input_image_strength] inputs = [input_text, input_image, input_image_strength]
if self.product_type == "single": if self.product_type == "single":
ctx = self.grpc_client.async_infer(model_name=GPI_MODEL_NAME_SINGLE, inputs=inputs, callback=self.callback) ctx = self.grpc_client.async_infer(model_name=GPI_MODEL_NAME_SINGLE, inputs=inputs, callback=self.callback, priority=1)
else: else:
ctx = self.grpc_client.async_infer(model_name=GPI_MODEL_NAME_OVERALL, inputs=inputs, callback=self.callback) ctx = self.grpc_client.async_infer(model_name=GPI_MODEL_NAME_OVERALL, inputs=inputs, callback=self.callback, priority=1)
time_out = 600 time_out = 600
while time_out > 0: while time_out > 0:

View File

@@ -114,9 +114,9 @@ class GenerateRelightImage:
inputs = [input_text, input_natext, input_image, input_seed, input_direction] inputs = [input_text, input_natext, input_image, input_seed, input_direction]
if self.product_type == 'single': if self.product_type == 'single':
ctx = self.grpc_client.async_infer(model_name=GRI_MODEL_NAME_SINGLE, inputs=inputs, callback=self.callback) ctx = self.grpc_client.async_infer(model_name=GRI_MODEL_NAME_SINGLE, inputs=inputs, callback=self.callback, priority=1)
else: else:
ctx = self.grpc_client.async_infer(model_name=GRI_MODEL_NAME_OVERALL, inputs=inputs, callback=self.callback) ctx = self.grpc_client.async_infer(model_name=GRI_MODEL_NAME_OVERALL, inputs=inputs, callback=self.callback, priority=1)
time_out = 600 time_out = 600
while time_out > 0: while time_out > 0:

View File

@@ -0,0 +1,99 @@
import redis
from app.core.config import REDIS_HOST, REDIS_PORT
class Redis(object):
    """Small helper wrapping redis.StrictRedis for string and hash operations.

    Keys written via `write`/`expire` default to a 100-second TTL.
    """

    @staticmethod
    def _get_r():
        # A fresh client per call; StrictRedis manages its own connection pool.
        return redis.StrictRedis(REDIS_HOST, REDIS_PORT, 0)

    @classmethod
    def write(cls, key, value, expire=None):
        """Set key=value with a TTL (defaults to 100 seconds)."""
        ttl = expire if expire else 100
        cls._get_r().set(key, value, ex=ttl)

    @classmethod
    def read(cls, key):
        """Return a key's value decoded to str, or None if missing."""
        raw = cls._get_r().get(key)
        return raw.decode('utf-8') if raw else raw

    @classmethod
    def hset(cls, name, key, value):
        """Set a single field in hash *name*."""
        cls._get_r().hset(name, key, value)

    @classmethod
    def hget(cls, name, key):
        """Return one field from hash *name* decoded to str, or None if missing."""
        raw = cls._get_r().hget(name, key)
        return raw.decode('utf-8') if raw else raw

    @classmethod
    def hgetall(cls, name):
        """Return the whole hash *name* (raw bytes, as returned by redis-py)."""
        return cls._get_r().hgetall(name)

    @classmethod
    def delete(cls, *names):
        """Delete one or more keys."""
        cls._get_r().delete(*names)

    @classmethod
    def hdel(cls, name, key):
        """Delete a single field from hash *name*."""
        cls._get_r().hdel(name, key)

    @classmethod
    def expire(cls, name, expire=None):
        """Set a TTL on *name* (defaults to 100 seconds)."""
        cls._get_r().expire(name, expire if expire else 100)
if __name__ == '__main__':
    # Manual smoke test against a live Redis instance.
    redis_client = Redis()
    # print(redis_client.write(key="1230", value=0))
    redis_client.write(key="1230", value=10)
    # print(redis_client.read(key="1230"))