feat(batch-generate): add product, relight and pose-transform batch generation services; set batch-generate gRPC priority to 100 and single-generate priority to 1
This commit is contained in:
24
app/service/generate_batch_image/service.py
Normal file
24
app/service/generate_batch_image/service.py
Normal file
@@ -0,0 +1,24 @@
|
||||
from app.service.generate_batch_image.service_batch_generate_product_image import batch_generate_product, publish_status as product_publish_status
|
||||
from app.service.generate_batch_image.service_batch_generate_relight_image import batch_generate_relight, publish_status as relight_publish_status
|
||||
from app.service.generate_batch_image.service_batch_pose_transform import batch_generate_pose_transform, publish_status as pose_transform_publish_status
|
||||
|
||||
|
||||
async def start_product_batch_generate(data):
    """Queue a batch product-image generation job on the Celery worker.

    Dispatches the request payload to the ``batch_generate_product`` task,
    publishes an initial ``0/<batch_size>`` progress message, and returns the
    task id with the Celery task state.
    """
    async_result = batch_generate_product.delay(data.dict())
    print(async_result)
    # Announce the job immediately so consumers see progress "0/<batch_size>".
    product_publish_status(data.tasks_id, f"0/{data.batch_size}", "")
    return {"task_id": data.tasks_id, "state": async_result.state}
|
||||
|
||||
|
||||
async def start_relight_batch_generate(data):
    """Queue a batch relight-image generation job on the Celery worker.

    Dispatches the request payload to the ``batch_generate_relight`` task,
    publishes an initial ``0/<batch_size>`` progress message, and returns the
    task id with the Celery task state.
    """
    async_result = batch_generate_relight.delay(data.dict())
    print(async_result)
    # Announce the job immediately so consumers see progress "0/<batch_size>".
    relight_publish_status(data.tasks_id, f"0/{data.batch_size}", "")
    return {"task_id": data.tasks_id, "state": async_result.state}
|
||||
|
||||
|
||||
async def start_pose_transform_batch_generate(data):
    """Queue a batch pose-transform generation job on the Celery worker.

    Dispatches the request payload to the ``batch_generate_pose_transform``
    task, publishes an initial ``0/<batch_size>`` progress message, and
    returns the task id with the Celery task state.
    """
    async_result = batch_generate_pose_transform.delay(data.dict())
    print(async_result)
    # Announce the job immediately so consumers see progress "0/<batch_size>".
    pose_transform_publish_status(data.tasks_id, f"0/{data.batch_size}", "")
    return {"task_id": data.tasks_id, "state": async_result.state}
|
||||
@@ -0,0 +1,191 @@
|
||||
# 旧版product
|
||||
# !/usr/bin/env python
|
||||
# -*- coding: UTF-8 -*-
|
||||
"""
|
||||
@Project :trinity_client
|
||||
@File :service_att_recognition.py
|
||||
@Author :周成融
|
||||
@Date :2023/7/26 12:01:05
|
||||
@detail :
|
||||
"""
|
||||
import json
|
||||
import logging
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import tritonclient.grpc as grpcclient
|
||||
from PIL import Image
|
||||
from celery import Celery
|
||||
from tritonclient.utils import np_to_triton_dtype
|
||||
|
||||
from app.core.config import *
|
||||
from app.schemas.generate_image import BatchGenerateProductImageModel
|
||||
from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image
|
||||
from app.service.utils.oss_client import oss_get_image
|
||||
|
||||
# Celery app for the product-image batch queue.
# SECURITY NOTE(review): broker credentials and host are hard-coded in the URL;
# they should come from configuration/environment — confirm and move.
celery_app = Celery('product_tasks', broker=f'amqp://rabbit:123456@18.167.251.121:5672//', backend='rpc://', BROKER_CONNECTION_RETRY_ON_STARTUP=True)
# Route this module's tasks to a dedicated worker queue.
celery_app.conf.task_default_queue = 'queue_product'
celery_app.conf.worker_log_format = '%(asctime)s %(filename)s [line:%(lineno)d] %(levelname)s %(message)s'
# Keep the application's root-logger configuration instead of Celery's.
celery_app.conf.worker_hijack_root_logger = False
logger = logging.getLogger()
# Silence pika's verbose connection chatter.
logging.getLogger('pika').setLevel(logging.WARNING)
# Shared Triton inference client for this worker process.
grpc_client = grpcclient.InferenceServerClient(url=GPI_MODEL_URL)
# OSS category (path segment) used when uploading generated images.
category = "product_image"
|
||||
|
||||
|
||||
@celery_app.task
def batch_generate_product(batch_request_data):
    """Celery task: generate ``batch_size`` product images via Triton and publish progress.

    Args:
        batch_request_data: dict form of BatchGenerateProductImageModel with
            keys ``tasks_id``, ``batch_size``, ``image_url``, ``prompt``,
            ``product_type`` ("single" or other) and ``image_strength``.

    Side effects: uploads each generated image to OSS and publishes a
    progress message per iteration via ``publish_status`` (unless DEBUG).
    """
    logger.info(f"batch_generate_product batch_request_data:{batch_request_data}")
    tasks_id = batch_request_data['tasks_id']
    # The user id is encoded as the last '-'-separated segment of the task id.
    user_id = tasks_id.rsplit('-', 1)[1]
    batch_size = batch_request_data['batch_size']
    # Letterboxed to 512x768 RGBA by pre_processing_image, then flattened to RGB.
    image = pre_processing_image(batch_request_data['image_url'])
    image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
    images = [image.astype(np.uint8)] * 1

    prompts = [batch_request_data['prompt']] * 1

    # The "single" model expects batched inputs (leading batch axis); the
    # overall model expects unbatched tensors.
    if batch_request_data['product_type'] == "single":
        text_obj = np.array(prompts, dtype="object").reshape((-1, 1))
        image_obj = np.array(images, dtype=np.uint8).reshape((-1, 768, 512, 3))
        image_strength_obj = np.array(batch_request_data['image_strength'], dtype=np.float32).reshape((-1, 1))
    else:
        text_obj = np.array(prompts, dtype="object").reshape((1))
        image_obj = np.array(images, dtype=np.uint8).reshape((768, 512, 3))
        image_strength_obj = np.array(batch_request_data['image_strength'], dtype=np.float32).reshape((1))
    # Build Triton input descriptors matching the model's input names.
    input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype))
    input_image = grpcclient.InferInput("input_image", image_obj.shape, "UINT8")
    input_image_strength = grpcclient.InferInput("image_strength", image_strength_obj.shape, np_to_triton_dtype(image_strength_obj.dtype))

    input_text.set_data_from_numpy(text_obj)
    input_image.set_data_from_numpy(image_obj)
    input_image_strength.set_data_from_numpy(image_strength_obj)

    inputs = [input_text, input_image, input_image_strength]

    image_url_list = []
    for i in range(batch_size):
        try:
            # priority=100: batch jobs run at high Triton scheduling priority.
            if batch_request_data['product_type'] == "single":
                result = grpc_client.infer(model_name=GPI_MODEL_NAME_SINGLE, inputs=inputs, priority=100)
                image = result.as_numpy("generated_cnet_image")
            else:
                result = grpc_client.infer(model_name=GPI_MODEL_NAME_OVERALL, inputs=inputs, priority=100)
                image = result.as_numpy("generated_inpaint_image")
            image_result = Image.fromarray(np.squeeze(image.astype(np.uint8)))
        except Exception as e:
            # Fallback: a 'mask_list' error from the overall model means the
            # request should be retried against the single model, which needs
            # batched (-1, ...) shapes — rebuild the inputs accordingly.
            if 'mask_list' in str(e):
                e_text_obj = np.array(prompts, dtype="object").reshape((-1, 1))
                e_image_obj = np.array(images, dtype=np.uint8).reshape((-1, 768, 512, 3))
                e_image_strength_obj = np.array(batch_request_data['image_strength'], dtype=np.float32).reshape((-1, 1))

                e_input_text = grpcclient.InferInput("prompt", e_text_obj.shape, np_to_triton_dtype(e_text_obj.dtype))
                e_input_image = grpcclient.InferInput("input_image", e_image_obj.shape, "UINT8")
                e_input_image_strength = grpcclient.InferInput("image_strength", e_image_strength_obj.shape, np_to_triton_dtype(e_image_strength_obj.dtype))

                e_input_text.set_data_from_numpy(e_text_obj)
                e_input_image.set_data_from_numpy(e_image_obj)
                e_input_image_strength.set_data_from_numpy(e_image_strength_obj)

                result = grpc_client.infer(model_name=GPI_MODEL_NAME_SINGLE, inputs=[e_input_text, e_input_image, e_input_image_strength], priority=100)
                image = result.as_numpy("generated_cnet_image")
                image_result = Image.fromarray(np.squeeze(image.astype(np.uint8)))
            else:
                # Any other failure: carry the error text as the "result" so it
                # is published to the status queue in place of an image URL.
                image_result = str(e)
                logger.error(image_result)

        if isinstance(image_result, Image.Image):
            image_url = upload_SDXL_image(image_result, user_id=user_id, category=f"{category}", file_name=f"{tasks_id}-batch-{i}.png")
            image_url_list.append(image_url)
        else:
            # Failed iteration: publish the error string instead of a URL.
            image_url = image_result
        if DEBUG is False:
            # Intermediate iterations publish per-image progress; the final
            # iteration publishes "OK" with the full URL list.
            if i + 1 < batch_size:
                publish_status(tasks_id, f"{i + 1}/{batch_size}", image_url)
                logger.info(f" [x] {tasks_id}:tasks_id *** progress:{i + 1}/{batch_size} *** image_url:{image_url}")
                print(f" [x] {tasks_id}:tasks_id *** progress:{i + 1}/{batch_size} *** image_url:{image_url}")
            else:
                publish_status(tasks_id, f"OK", image_url_list)
                logger.info(f" [x] {tasks_id}:tasks_id *** progress:OK *** image_url:{image_url_list}")
                print(f" [x] {tasks_id}:tasks_id *** progress:OK *** image_url:{image_url_list}")
|
||||
|
||||
|
||||
def pre_processing_image(image_url):
    """Fetch an image from OSS and letterbox it onto a 512x768 transparent canvas.

    The source image is scaled uniformly so it fits entirely inside 512x768,
    pasted centered onto an RGBA canvas (preserving alpha when present), and
    returned as a numpy array.
    """
    bucket = image_url.split('/')[0]
    object_name = image_url[image_url.find('/') + 1:]
    image = oss_get_image(bucket=bucket, object_name=object_name, data_type="PIL")

    # Target canvas dimensions.
    target_width, target_height = 512, 768
    original_width, original_height = image.size

    # Uniform scale: the smaller ratio guarantees the whole image fits.
    scale_ratio = min(target_width / original_width, target_height / original_height)
    new_width = int(original_width * scale_ratio)
    new_height = int(original_height * scale_ratio)
    resized_image = image.resize((new_width, new_height))

    # Transparent 512x768 canvas; paste the scaled image centered.
    result_image = Image.new("RGBA", (target_width, target_height), (255, 255, 255, 0))
    x_offset = (target_width - new_width) // 2
    y_offset = (target_height - new_height) // 2
    if resized_image.mode == "RGBA":
        # Use the source alpha channel as the paste mask to keep transparency.
        result_image.paste(resized_image, (x_offset, y_offset), mask=resized_image.split()[3])
    else:
        result_image.paste(resized_image, (x_offset, y_offset))

    return np.array(result_image)
|
||||
|
||||
|
||||
def post_processing_image(image, left, top):
    """Scale *image* to 768 px height (keeping aspect ratio) and center-crop to 512x768.

    NOTE(review): the ``left`` and ``top`` parameters are dead — the original
    code overwrote ``left`` immediately and never read ``top``. They are kept
    in the signature for backward compatibility with existing callers;
    confirm they can be removed.

    Args:
        image: PIL image of any size.
        left, top: ignored (see note above).

    Returns:
        A 512x768 PIL image, cropped from the horizontal center.
    """
    # Resize so the height is exactly 768, preserving aspect ratio.
    resized_image = image.resize((int(image.width * (768 / image.height)), 768))
    # Center the 512-px-wide crop window horizontally.
    crop_left = (resized_image.width - 512) // 2
    cropped_image = resized_image.crop((crop_left, 0, crop_left + 512, 768))
    return cropped_image
|
||||
|
||||
|
||||
def publish_status(task_id, progress, result):
    """Publish a persistent progress/result message for *task_id* to the batch GPI queue.

    Args:
        task_id: identifier of the batch task.
        progress: progress string (e.g. "3/20" or "OK").
        result: image URL, error string, or final list of URLs.

    Fix: the connection is now closed in a ``finally`` block so it is not
    leaked when ``queue_declare``/``basic_publish`` raises.
    """
    connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
    try:
        channel = connection.channel()
        channel.queue_declare(queue=BATCH_GPI_RABBITMQ_QUEUES, durable=True)
        message = {'task_id': task_id, 'progress': progress, "result": result}
        channel.basic_publish(exchange='',
                              routing_key=BATCH_GPI_RABBITMQ_QUEUES,
                              body=json.dumps(message),
                              properties=pika.BasicProperties(
                                  delivery_mode=2,  # persistent message
                              ))
    finally:
        connection.close()
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Manual smoke test: run the product batch task synchronously (no Celery
    # worker) against a fixed OSS image.
    rd = BatchGenerateProductImageModel(
        tasks_id="123-15-51-89",
        image_strength=0.7,
        prompt=" The best quality, masterpiece, real image.Outwear,high quality clothing details,8K realistic,HDR",
        image_url="aida-results/result_40b1a2fe-e220-11ef-9bfa-0242ac150003.png",
        product_type="overall",
        batch_size=20
    )
    batch_generate_product(rd.dict())
|
||||
@@ -0,0 +1,162 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: UTF-8 -*-
|
||||
"""
|
||||
@Project :trinity_client
|
||||
@File :service_att_recognition.py
|
||||
@Author :周成融
|
||||
@Date :2023/7/26 12:01:05
|
||||
@detail :
|
||||
"""
|
||||
import json
|
||||
import logging
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import tritonclient.grpc as grpcclient
|
||||
from PIL import Image
|
||||
from celery import Celery
|
||||
from tritonclient.utils import np_to_triton_dtype
|
||||
|
||||
from app.core.config import *
|
||||
from app.schemas.generate_image import BatchGenerateRelightImageModel
|
||||
from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image
|
||||
from app.service.utils.oss_client import oss_get_image
|
||||
|
||||
logger = logging.getLogger()
# Celery app for the relight-image batch queue.
# SECURITY NOTE(review): broker credentials and host are hard-coded in the URL;
# they should come from configuration/environment — confirm and move.
celery_app = Celery('relight_tasks', broker=f'amqp://rabbit:123456@18.167.251.121:5672//', backend='rpc://', BROKER_CONNECTION_RETRY_ON_STARTUP=True)
# Route this module's tasks to a dedicated worker queue.
celery_app.conf.task_default_queue = 'queue_relight'
celery_app.conf.worker_log_format = '%(asctime)s %(filename)s [line:%(lineno)d] %(levelname)s %(message)s'
# Keep the application's root-logger configuration instead of Celery's.
celery_app.conf.worker_hijack_root_logger = False
# Silence pika's verbose connection chatter.
logging.getLogger('pika').setLevel(logging.WARNING)
# Shared Triton inference client for this worker process.
grpc_client = grpcclient.InferenceServerClient(url=GRI_MODEL_URL)
# OSS category (path segment) used when uploading generated images.
category = "relight_image"
|
||||
|
||||
|
||||
@celery_app.task
def batch_generate_relight(batch_request_data):
    """Celery task: generate ``batch_size`` relit images via Triton and publish progress.

    Args:
        batch_request_data: dict form of BatchGenerateRelightImageModel with
            keys ``tasks_id``, ``batch_size``, ``image_url``, ``prompt``,
            ``direction``, and ``product_type`` ("single" or other).

    Side effects: uploads each generated image to OSS and publishes a
    progress message per iteration via ``publish_status`` (unless DEBUG).
    """
    logger.info(f"batch_generate_relight batch_request_data: {batch_request_data}")
    negative_prompt = 'lowres, bad anatomy, bad hands, cropped, worst quality'
    direction = batch_request_data['direction']
    # Fixed seed string — every iteration uses the same seed input.
    seed = "1"
    prompt = batch_request_data['prompt']
    product_type = batch_request_data['product_type']
    image_url = batch_request_data['image_url']
    image = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url.split('/', 1)[1], data_type="cv2")
    tasks_id = batch_request_data['tasks_id']
    # The user id is encoded as the last '-'-separated segment of the task id.
    user_id = tasks_id.rsplit('-', 1)[1]
    batch_size = batch_request_data['batch_size']

    prompts = [prompt] * 1
    # OpenCV loads BGR; the model expects RGB at 512x768.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = cv2.resize(image, (512, 768))
    images = [image.astype(np.uint8)] * 1
    seeds = [seed] * 1
    nagetive_prompts = [negative_prompt] * 1
    directions = [direction] * 1

    # The "single" model expects batched inputs (leading batch axis); the
    # overall model expects unbatched tensors.
    if product_type == 'single':
        text_obj = np.array(prompts, dtype="object").reshape((-1, 1))
        image_obj = np.array(images, dtype=np.uint8).reshape((-1, 768, 512, 3))
        na_text_obj = np.array(nagetive_prompts, dtype="object").reshape((-1, 1))
        seed_obj = np.array(seeds, dtype="object").reshape((-1, 1))
        direction_obj = np.array(directions, dtype="object").reshape((-1, 1))
    else:
        text_obj = np.array(prompts, dtype="object").reshape((1))
        image_obj = np.array(images, dtype=np.uint8).reshape((768, 512, 3))
        na_text_obj = np.array(nagetive_prompts, dtype="object").reshape((1))
        seed_obj = np.array(seeds, dtype="object").reshape((1))
        direction_obj = np.array(directions, dtype="object").reshape((1))
    # Build Triton input descriptors matching the model's input names.
    input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype))
    input_image = grpcclient.InferInput("input_image", image_obj.shape, "UINT8")
    input_natext = grpcclient.InferInput("negative_prompt", na_text_obj.shape, np_to_triton_dtype(na_text_obj.dtype))
    input_seed = grpcclient.InferInput("seed", seed_obj.shape, np_to_triton_dtype(seed_obj.dtype))
    input_direction = grpcclient.InferInput("direction", direction_obj.shape, np_to_triton_dtype(direction_obj.dtype))

    input_text.set_data_from_numpy(text_obj)
    input_image.set_data_from_numpy(image_obj)
    input_natext.set_data_from_numpy(na_text_obj)
    input_seed.set_data_from_numpy(seed_obj)
    input_direction.set_data_from_numpy(direction_obj)

    inputs = [input_text, input_natext, input_image, input_seed, input_direction]
    image_url_list = []
    for i in range(batch_size):
        try:
            # priority=100: batch jobs run at high Triton scheduling priority.
            if batch_request_data['product_type'] == "single":
                result = grpc_client.infer(model_name=GRI_MODEL_NAME_SINGLE, inputs=inputs, priority=100)
                image = result.as_numpy("generated_relight_image")
            else:
                result = grpc_client.infer(model_name=GRI_MODEL_NAME_OVERALL, inputs=inputs, priority=100)
                image = result.as_numpy("generated_inpaint_image")
            image_result = Image.fromarray(np.squeeze(image.astype(np.uint8)))

        except Exception as e:
            print(e)
            # Fallback: a 'mask_list' error from the overall model means the
            # request should be retried against the single model, which needs
            # batched (-1, ...) shapes — rebuild all five inputs accordingly.
            if 'mask_list' in str(e):
                e_text_obj = np.array(prompts, dtype="object").reshape((-1, 1))
                e_image_obj = np.array(images, dtype=np.uint8).reshape((-1, 768, 512, 3))
                e_na_text_obj = np.array(nagetive_prompts, dtype="object").reshape((-1, 1))
                e_seed_obj = np.array(seeds, dtype="object").reshape((-1, 1))
                e_direction_obj = np.array(directions, dtype="object").reshape((-1, 1))

                e_input_text = grpcclient.InferInput("prompt", e_text_obj.shape, np_to_triton_dtype(e_text_obj.dtype))
                e_input_image = grpcclient.InferInput("input_image", e_image_obj.shape, "UINT8")
                e_input_natext = grpcclient.InferInput("negative_prompt", e_na_text_obj.shape, np_to_triton_dtype(e_na_text_obj.dtype))
                e_input_seed = grpcclient.InferInput("seed", e_seed_obj.shape, np_to_triton_dtype(e_seed_obj.dtype))
                e_input_direction = grpcclient.InferInput("direction", e_direction_obj.shape, np_to_triton_dtype(e_direction_obj.dtype))

                e_input_text.set_data_from_numpy(e_text_obj)
                e_input_image.set_data_from_numpy(e_image_obj)
                e_input_natext.set_data_from_numpy(e_na_text_obj)
                e_input_seed.set_data_from_numpy(e_seed_obj)
                e_input_direction.set_data_from_numpy(e_direction_obj)

                e_inputs = [e_input_text, e_input_natext, e_input_image, e_input_seed, e_input_direction]

                result = grpc_client.infer(model_name=GRI_MODEL_NAME_SINGLE, inputs=e_inputs, priority=100)
                image = result.as_numpy("generated_relight_image")
                image_result = Image.fromarray(np.squeeze(image.astype(np.uint8)))
            else:
                # Any other failure: carry the error text as the "result" so it
                # is published to the status queue in place of an image URL.
                image_result = str(e)
                logger.error(e)
        if isinstance(image_result, Image.Image):
            image_url = upload_SDXL_image(image_result, user_id=user_id, category=f"{category}", file_name=f"{tasks_id}-batch-{i}.png")
            image_url_list.append(image_url)
        else:
            # Failed iteration: publish the error string instead of a URL.
            image_url = image_result
        if DEBUG is False:
            # Intermediate iterations publish per-image progress; the final
            # iteration publishes "OK" with the full URL list.
            if i + 1 < batch_size:
                publish_status(tasks_id, f"{i + 1}/{batch_size}", image_url)
                logger.info(f" [x] {tasks_id}:tasks_id *** progress:{i + 1}/{batch_size} *** image_url:{image_url}")
                print(f" [x] {tasks_id}:tasks_id *** progress:{i + 1}/{batch_size} *** image_url:{image_url}")
            else:
                publish_status(tasks_id, f"OK", image_url_list)
                logger.info(f" [x] {tasks_id}:tasks_id *** progress:OK *** image_url:{image_url_list}")
                print(f" [x] {tasks_id}:tasks_id *** progress:OK *** image_url:{image_url_list}")
|
||||
|
||||
|
||||
def publish_status(task_id, progress, result):
    """Publish a persistent progress/result message for *task_id* to the batch GRI queue.

    Args:
        task_id: identifier of the batch task.
        progress: progress string (e.g. "3/10" or "OK").
        result: image URL, error string, or final list of URLs.

    Fix: the connection is now closed in a ``finally`` block so it is not
    leaked when ``queue_declare``/``basic_publish`` raises.
    """
    connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
    try:
        channel = connection.channel()
        channel.queue_declare(queue=BATCH_GRI_RABBITMQ_QUEUES, durable=True)
        message = {'task_id': task_id, 'progress': progress, "result": result}
        channel.basic_publish(exchange='',
                              routing_key=BATCH_GRI_RABBITMQ_QUEUES,
                              body=json.dumps(message),
                              properties=pika.BasicProperties(
                                  delivery_mode=2,  # persistent message
                              ))
    finally:
        connection.close()
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Manual smoke test: run the relight batch task synchronously (no Celery
    # worker) against a fixed OSS image.
    rd = BatchGenerateRelightImageModel(
        tasks_id="123-89",
        # prompt="beautiful woman, detailed face, sunshine, outdoor, warm atmosphere",
        prompt="Colorful black",
        image_url='aida-users/89/clothing_seg/283c5c82-1a92-11f0-b72a-0242ac150002.png',
        direction="Right Light",
        product_type="overall",
        batch_size=10
    )
    batch_generate_relight(rd.dict())
|
||||
176
app/service/generate_batch_image/service_batch_pose_transform.py
Normal file
176
app/service/generate_batch_image/service_batch_pose_transform.py
Normal file
@@ -0,0 +1,176 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: UTF-8 -*-
|
||||
"""
|
||||
@Project :trinity_client
|
||||
@File :service_att_recognition.py
|
||||
@Author :周成融
|
||||
@Date :2023/7/26 12:01:05
|
||||
@detail :
|
||||
"""
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
from io import BytesIO
|
||||
|
||||
import imageio
|
||||
import numpy as np
|
||||
import tritonclient.grpc as grpcclient
|
||||
from PIL import Image
|
||||
from celery import Celery
|
||||
from minio import Minio
|
||||
from tritonclient.utils import np_to_triton_dtype
|
||||
|
||||
from app.core.config import *
|
||||
from app.schemas.pose_transform import BatchPoseTransformModel
|
||||
from app.service.generate_image.utils.pose_transform_upload import upload_gif, upload_video
|
||||
from app.service.utils.new_oss_client import oss_upload_image
|
||||
from app.service.utils.oss_client import oss_get_image
|
||||
|
||||
# MinIO client used for uploading first-frame images of generated animations.
minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)

logger = logging.getLogger()
# Celery app for the pose-transform batch queue.
# SECURITY NOTE(review): broker credentials and host are hard-coded in the URL;
# they should come from configuration/environment — confirm and move.
celery_app = Celery('tasks', broker=f'amqp://rabbit:123456@18.167.251.121:5672//', backend='rpc://', BROKER_CONNECTION_RETRY_ON_STARTUP=True)
# Route this module's tasks to a dedicated worker queue.
# NOTE(review): queue name says "post_transform" while the module is about
# "pose" transform — confirm the spelling matches the worker configuration.
celery_app.conf.task_default_queue = 'queue_post_transform'
celery_app.conf.worker_log_format = '%(asctime)s %(filename)s [line:%(lineno)d] %(levelname)s %(message)s'
# Keep the application's root-logger configuration instead of Celery's.
celery_app.conf.worker_hijack_root_logger = False
# Silence pika's verbose connection chatter.
logging.getLogger('pika').setLevel(logging.WARNING)
# Shared Triton inference client for this worker process.
grpc_client = grpcclient.InferenceServerClient(url=PT_MODEL_URL)
# OSS category (path segment prefix) used when uploading generated media.
category = "pose_transform"
|
||||
|
||||
|
||||
def upload_first_image(image, user_id, category, file_name):
    """Upload a PIL image as PNG to OSS and return its public-style path.

    Args:
        image: PIL image (first animation frame).
        user_id: owner id, used as the first path segment.
        category: path segment grouping the upload (e.g. "pose_transform_first_img").
        file_name: target file name (should end in .png).

    Returns:
        The "aida-users/<user>/<category>/<file>" path on success, or ``None``
        if the upload fails (best-effort: the error is logged, not raised).

    Fixes: log message previously referenced the wrong function
    ("upload_png_mask"); the unused ``req`` variable and implicit ``None``
    return are now explicit.
    """
    try:
        image_data = io.BytesIO()
        image.save(image_data, format='PNG')
        # getvalue() avoids the seek(0)/read() dance on the in-memory buffer.
        image_bytes = image_data.getvalue()
        object_name = f'{user_id}/{category}/{file_name}'
        oss_upload_image(oss_client=minio_client, bucket=GI_MINIO_BUCKET, object_name=object_name, image_bytes=image_bytes)
        return f"aida-users/{object_name}"
    except Exception as e:
        logging.warning(f"upload_first_image runtime exception : {e}")
        return None
|
||||
|
||||
|
||||
def pre_processing_image(image_url):
    """Fetch an image from OSS and letterbox it onto a 512x768 canvas, returned as RGB.

    The source image is scaled uniformly so it fits entirely inside 512x768,
    pasted centered onto an RGBA canvas (preserving alpha when present),
    flattened to RGB, and returned as a numpy array.
    """
    bucket = image_url.split('/')[0]
    object_name = image_url[image_url.find('/') + 1:]
    image = oss_get_image(bucket=bucket, object_name=object_name, data_type="PIL")

    # Target canvas dimensions.
    target_width, target_height = 512, 768
    original_width, original_height = image.size

    # Uniform scale: the smaller ratio guarantees the whole image fits.
    scale_ratio = min(target_width / original_width, target_height / original_height)
    new_width = int(original_width * scale_ratio)
    new_height = int(original_height * scale_ratio)
    resized_image = image.resize((new_width, new_height))

    # Transparent 512x768 canvas; paste the scaled image centered.
    result_image = Image.new("RGBA", (target_width, target_height), (255, 255, 255, 0))
    x_offset = (target_width - new_width) // 2
    y_offset = (target_height - new_height) // 2
    if resized_image.mode == "RGBA":
        # Use the source alpha channel as the paste mask to keep transparency.
        result_image.paste(resized_image, (x_offset, y_offset), mask=resized_image.split()[3])
    else:
        result_image.paste(resized_image, (x_offset, y_offset))

    # The pose-transform model expects 3-channel input, so drop alpha.
    result_image = result_image.convert("RGB")
    return np.array(result_image)
|
||||
|
||||
|
||||
@celery_app.task
def batch_generate_pose_transform(batch_request_data):
    """Celery task: generate ``batch_size`` pose-transform animations via Triton.

    Args:
        batch_request_data: dict form of BatchPoseTransformModel with keys
            ``tasks_id``, ``batch_size``, ``image_url`` and ``pose_id``.

    For each iteration, uploads the first frame (PNG), a GIF, and an MP4 of
    the generated sequence, then publishes a progress message with the three
    URLs via ``publish_status`` (unless DEBUG).
    """
    logger.info(f"batch_generate_pose_transform batch_request_data: {batch_request_data}")
    batch_size = batch_request_data['batch_size']
    image_url = batch_request_data['image_url']
    # Letterboxed to 512x768 RGB by pre_processing_image.
    image = pre_processing_image(image_url)
    pose_num = batch_request_data['pose_id']
    tasks_id = batch_request_data['tasks_id']
    # The user id is encoded as the last '-'-separated segment of the task id.
    user_id = tasks_id.rsplit('-', 1)[1]

    # Build the two Triton inputs: pose id (as object/bytes) and image tensor.
    pose_num = [pose_num] * 1
    pose_num_obj = np.array(pose_num, dtype="object").reshape((-1, 1))
    input_pose_num = grpcclient.InferInput("pose_num", pose_num_obj.shape, np_to_triton_dtype(pose_num_obj.dtype))
    input_pose_num.set_data_from_numpy(pose_num_obj)

    image_files = [image.astype(np.uint8)] * 1
    image_files_obj = np.array(image_files, dtype=np.uint8).reshape((-1, 768, 512, 3))
    input_image_files = grpcclient.InferInput("image_file", image_files_obj.shape, "UINT8")
    input_image_files.set_data_from_numpy(image_files_obj)

    result_url_list = []
    for i in range(batch_size):
        try:
            # priority=100: batch jobs run at high Triton scheduling priority;
            # long client timeout because animation generation is slow.
            result = grpc_client.infer(model_name="animatex_1", inputs=[input_pose_num, input_image_files], client_timeout=60000, priority=100)
            # [..., ::-1] reverses the channel axis (BGR<->RGB swap) for all frames.
            result_data = np.squeeze(result.as_numpy("generated_image_list").astype(np.uint8))[:, :, :, ::-1]
            # First frame, uploaded as a standalone preview image.
            first_image = Image.fromarray(result_data[0])
            first_image_url = upload_first_image(first_image, user_id=user_id, category=f"{category}_first_img", file_name=f"{tasks_id}_batch_{i}.png")

            # Upload GIF built from the full frame sequence.
            gif_buffer = BytesIO()
            imageio.mimsave(gif_buffer, result_data, format='GIF', fps=5)
            gif_buffer.seek(0)
            gif_url = upload_gif(gif_buffer=gif_buffer, user_id=user_id, category=f"{category}_gif", file_name=f"{tasks_id}_batch_{i}.gif")

            # Upload MP4 built from the same frames.
            video_url = upload_video(frames=result_data, user_id=user_id, category=f"{category}_video", file_name=f"{tasks_id}_batch_{i}.mp4")
            data = {
                "gif_url": gif_url,
                "video_url": video_url,
                "first_image_url": first_image_url,
            }
        except Exception as e:
            # Best-effort: a failed iteration publishes an empty dict.
            print(e)
            data = {}
        result_url_list.append(data)
        if DEBUG is False:
            # Intermediate iterations publish per-item progress; the final
            # iteration publishes "OK" with the full result list.
            if i + 1 < batch_size:
                publish_status(tasks_id, f"{i + 1}/{batch_size}", data)
                logger.info(f" [x] {tasks_id}:tasks_id *** progress:{i + 1}/{batch_size} *** image_url:{data}")
                print(f" [x] {tasks_id}:tasks_id *** progress:{i + 1}/{batch_size} *** image_url:{data}")
            else:
                publish_status(tasks_id, f"OK", result_url_list)
                logger.info(f" [x] {tasks_id}:tasks_id *** progress:OK *** image_url:{result_url_list}")
                print(f" [x] {tasks_id}:tasks_id *** progress:OK *** image_url:{result_url_list}")
|
||||
|
||||
|
||||
def publish_status(task_id, progress, result):
    """Publish a persistent progress/result message for *task_id*.

    NOTE(review): this pose-transform module publishes to
    ``BATCH_GRI_RABBITMQ_QUEUES`` — the relight queue constant — which looks
    like a copy-paste from the relight service; confirm whether a dedicated
    pose-transform queue constant should be used instead.

    Args:
        task_id: identifier of the batch task.
        progress: progress string (e.g. "3/10" or "OK").
        result: per-item URL dict or the final list of dicts.

    Fix: the connection is now closed in a ``finally`` block so it is not
    leaked when ``queue_declare``/``basic_publish`` raises.
    """
    connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
    try:
        channel = connection.channel()
        channel.queue_declare(queue=BATCH_GRI_RABBITMQ_QUEUES, durable=True)
        message = {'task_id': task_id, 'progress': progress, "result": result}
        channel.basic_publish(exchange='',
                              routing_key=BATCH_GRI_RABBITMQ_QUEUES,
                              body=json.dumps(message),
                              properties=pika.BasicProperties(
                                  delivery_mode=2,  # persistent message
                              ))
    finally:
        connection.close()
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Manual smoke test: run the pose-transform batch task synchronously (no
    # Celery worker) against a fixed OSS image.
    rd = BatchPoseTransformModel(
        tasks_id="123-89",
        image_url='aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png',
        pose_id="1",
        batch_size=10
    )
    batch_generate_pose_transform(rd.dict())
|
||||
Reference in New Issue
Block a user