feat(新功能): pose transform 部署

fix(修复bug):
docs(文档变更):
refactor(重构):
test(增加测试):
This commit is contained in:
zchengrong
2025-04-07 13:35:01 +08:00
parent ae38a3a357
commit ddadf3e287
3 changed files with 177 additions and 49 deletions

View File

@@ -148,6 +148,7 @@ GRI_MODEL_URL = '10.1.1.240:10051'
# Pose Transform service config
PS_RABBITMQ_QUEUES = os.getenv("PS_RABBITMQ_QUEUES", f"PoseTransform{RABBITMQ_ENV}")
PT_MODEL_URL = '10.1.1.243:10061'
# SEG service config
SEGMENTATION = {

View File

@@ -9,16 +9,19 @@
"""
import json
import logging
import time
from io import BytesIO
import cv2
import imageio
import numpy as np
import redis
import tritonclient.grpc as grpcclient
from PIL import Image
from tritonclient.utils import np_to_triton_dtype
from app.core.config import *
from app.schemas.pose_transform import PoseTransformModel
from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image
from app.service.generate_image.utils.pose_transform_upload import upload_gif, upload_video, upload_first_image
from app.service.utils.oss_client import oss_get_image
logger = logging.getLogger()
@@ -29,33 +32,48 @@ class PoseTransformService:
if DEBUG is False:
self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
self.channel = self.connection.channel()
self.grpc_client = grpcclient.InferenceServerClient(url=GRI_MODEL_URL)
self.grpc_client = grpcclient.InferenceServerClient(url=PT_MODEL_URL)
self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
self.category = "pose_transform"
self.batch_size = 1
self.seed = "1"
self.image_url = request_data.image_url
self.image = oss_get_image(bucket=self.image_url.split('/')[0], object_name=self.image_url[self.image_url.find('/') + 1:], data_type="cv2")
self.pose_num = request_data.pose_id
self.image = pre_processing_image(request_data.image_url)
self.tasks_id = request_data.tasks_id
self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
self.gen_product_data = {'tasks_id': self.tasks_id, 'status': 'SUCCESS', 'message': "success", 'gif_url': 'test/mannequin_name.png', 'video_url': 'test/mannequin_name.png', 'image_url': 'test/mannequin_name.png'}
self.pose_transform_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'gif_url': '', 'video_url': '', 'image_url': ''}
self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data))
self.redis_client.set(self.tasks_id, json.dumps(self.pose_transform_data))
self.redis_client.expire(self.tasks_id, 600)
def callback(self, result, error):
if error:
self.gen_product_data['status'] = "FAILURE"
self.gen_product_data['message'] = str(error)
self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data))
self.pose_transform_data['status'] = "FAILURE"
self.pose_transform_data['message'] = str(error)
self.redis_client.set(self.tasks_id, json.dumps(self.pose_transform_data))
else:
image = result.as_numpy("generated_inpaint_image")
image_result = Image.fromarray(np.squeeze(image.astype(np.uint8)))
image_url = upload_SDXL_image(image_result, user_id=self.user_id, category=f"{self.category}", file_name=f"{self.tasks_id}.png")
self.gen_product_data['status'] = "SUCCESS"
self.gen_product_data['message'] = "success"
self.gen_product_data['image_url'] = str(image_url)
self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data))
result_data = np.squeeze(result.as_numpy("generated_image_list").astype(np.uint8))[:, :, :, ::-1]
# 第一帧图像
first_image = Image.fromarray(result_data[0])
first_image_url = upload_first_image(first_image, user_id=self.user_id, category=f"{self.category}_first_img", file_name=f"{self.tasks_id}.png")
# 上传GIF
gif_buffer = BytesIO()
imageio.mimsave(gif_buffer, result_data, format='GIF', fps=5)
gif_buffer.seek(0)
gif_url = upload_gif(gif_buffer=gif_buffer, user_id=self.user_id, category=f"{self.category}_gif", file_name=f"{self.tasks_id}.gif")
# 上传video
video_url = upload_video(frames=result_data, user_id=self.user_id, category=f"{self.category}_video", file_name=f"{self.tasks_id}.mp4")
self.pose_transform_data['status'] = "SUCCESS"
self.pose_transform_data['message'] = "success"
self.pose_transform_data['gif_url'] = str(gif_url)
self.pose_transform_data['video_url'] = str(video_url)
self.pose_transform_data['image_url'] = str(first_image_url)
self.redis_client.set(self.tasks_id, json.dumps(self.pose_transform_data))
def read_tasks_status(self):
status_data = self.redis_client.get(self.tasks_id)
@@ -63,51 +81,92 @@ class PoseTransformService:
def get_result(self):
try:
image = cv2.cvtColor(self.image, cv2.COLOR_BGR2RGB)
image = cv2.resize(image, (512, 768))
images = [image.astype(np.uint8)] * self.batch_size
pose_num = [self.pose_num] * 1
pose_num_obj = np.array(pose_num, dtype="object").reshape((-1, 1))
input_pose_num = grpcclient.InferInput("pose_num", pose_num_obj.shape, np_to_triton_dtype(pose_num_obj.dtype))
input_pose_num.set_data_from_numpy(pose_num_obj)
image_obj = np.array(images, dtype=np.uint8).reshape((768, 512, 3))
image_files = [self.image.astype(np.uint8)] * 1
image_files_obj = np.array(image_files, dtype=np.uint8).reshape((-1, 768, 512, 3))
input_image_files = grpcclient.InferInput("image_file", image_files_obj.shape, "UINT8")
input_image_files.set_data_from_numpy(image_files_obj)
input_image = grpcclient.InferInput("input_image", image_obj.shape, "UINT8")
input_image.set_data_from_numpy(image_obj)
inputs = [input_image]
# ctx = self.grpc_client.async_infer(model_name=GRI_MODEL_NAME_OVERALL, inputs=inputs, callback=self.callback)
# time_out = 600
# while time_out > 0:
# gen_product_data, _ = self.read_tasks_status()
# if gen_product_data['status'] in ["REVOKED", "FAILURE", "NO_FACE"]:
# ctx.cancel()
# break
# elif gen_product_data['status'] == "SUCCESS":
# break
# time_out -= 1
# time.sleep(0.1)
gen_product_data, _ = self.read_tasks_status()
return gen_product_data
ctx = self.grpc_client.async_infer(model_name="animatex_1", inputs=[input_pose_num, input_image_files], callback=self.callback)
time_out = 6000
while time_out > 0:
pose_transform_data, _ = self.read_tasks_status()
if pose_transform_data['status'] in ["REVOKED", "FAILURE"]:
ctx.cancel()
break
elif pose_transform_data['status'] == "SUCCESS":
break
time_out -= 1
time.sleep(0.1)
pose_transform_data, _ = self.read_tasks_status()
return pose_transform_data
except Exception as e:
self.gen_product_data['status'] = "FAILURE"
self.gen_product_data['message'] = str(e)
self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data))
self.pose_transform_data['status'] = "FAILURE"
self.pose_transform_data['message'] = str(e)
self.redis_client.set(self.tasks_id, json.dumps(self.pose_transform_data))
raise Exception(str(e))
finally:
dict_gen_product_data, str_gen_product_data = self.read_tasks_status()
dict_pose_transform_data, str_pose_transform_data = self.read_tasks_status()
if DEBUG is False:
self.channel.basic_publish(exchange='', routing_key=PS_RABBITMQ_QUEUES, body=str_gen_product_data)
logger.info(f" [x] Sent to {PS_RABBITMQ_QUEUES} data@@@@ {json.dumps(dict_gen_product_data, indent=4)}")
self.channel.basic_publish(exchange='', routing_key=PS_RABBITMQ_QUEUES, body=str_pose_transform_data)
logger.info(f" [x] Sent to {PS_RABBITMQ_QUEUES} data@@@@ {json.dumps(dict_pose_transform_data, indent=4)}")
def infer_cancel(tasks_id):
redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
data = {'tasks_id': tasks_id, 'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'}
gen_product_data = json.dumps(data)
redis_client.set(tasks_id, gen_product_data)
pose_transform_data = json.dumps(data)
redis_client.set(tasks_id, pose_transform_data)
return data
def pre_processing_image(image_url):
image = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url[image_url.find('/') + 1:], data_type="PIL")
# 目标图片的尺寸
target_width = 512
target_height = 768
# 原始图片的尺寸
original_width, original_height = image.size
# 计算宽度和高度的缩放比例
width_ratio = target_width / original_width
height_ratio = target_height / original_height
# 选择较小的缩放比例,确保图片能完整放入目标图片中
scale_ratio = min(width_ratio, height_ratio)
# 计算调整后的尺寸
new_width = int(original_width * scale_ratio)
new_height = int(original_height * scale_ratio)
# 调整图片大小
resized_image = image.resize((new_width, new_height))
# 创建一个 512x768 的透明图片
result_image = Image.new("RGBA", (target_width, target_height), (255, 255, 255, 0))
# 计算需要粘贴的位置,使图片居中
x_offset = (target_width - new_width) // 2
y_offset = (target_height - new_height) // 2
# 将调整大小后的图片粘贴到透明图片上
if resized_image.mode == "RGBA":
result_image.paste(resized_image, (x_offset, y_offset), mask=resized_image.split()[3])
else:
result_image.paste(resized_image, (x_offset, y_offset))
result_image = result_image.convert("RGB")
image = np.array(result_image)
# image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA)
return image
if __name__ == '__main__':
rd = PoseTransformModel(
tasks_id="123-89",

View File

@@ -0,0 +1,68 @@
import io
import logging
import imageio
import numpy as np
# import boto3
from minio import Minio
from app.core.config import *
from app.service.utils.new_oss_client import oss_upload_image
minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
def upload_first_image(image, user_id, category, file_name):
try:
image_data = io.BytesIO()
image.save(image_data, format='PNG')
image_data.seek(0)
image_bytes = image_data.read()
object_name = f'{user_id}/{category}/{file_name}'
req = oss_upload_image(oss_client=minio_client, bucket=GI_MINIO_BUCKET, object_name=object_name, image_bytes=image_bytes)
image_url = f"aida-users/{object_name}"
return image_url
except Exception as e:
logging.warning(f"upload_png_mask runtime exception : {e}")
def upload_gif(gif_buffer, user_id, category, file_name):
try:
object_name = f'{user_id}/{category}/{file_name}'
req = minio_client.put_object(
"aida-users",
object_name,
gif_buffer,
length=gif_buffer.getbuffer().nbytes,
content_type="image/gif"
)
return f"aida-users/{object_name}"
except Exception as e:
logging.warning(f"upload_gif runtime exception : {e}")
def upload_video(frames, user_id, category, file_name):
try:
video_buffer = io.BytesIO()
with imageio.get_writer(video_buffer, format='mp4', fps=24) as writer:
for frame in frames:
writer.append_data(frame)
video_buffer.seek(0)
object_name = f'{user_id}/{category}/{file_name}'
# 上传视频流到MinIO
minio_client.put_object(
bucket_name="aida-users",
object_name=object_name,
data=video_buffer,
length=video_buffer.getbuffer().nbytes,
content_type='video/mp4'
)
return f"aida-users/{object_name}"
except Exception as e:
logging.warning(f"upload_video runtime exception : {e}")
if __name__ == '__main__':
images = np.random.randint(0, 256, size=(4, 512, 512, 3), dtype=np.uint8)
print(upload_video(images, user_id=89, category='test', file_name="1.mp4"))