feat(新功能): pose transform 部署

fix(修复bug):
docs(文档变更):
refactor(重构):
test(增加测试):
This commit is contained in:
zchengrong
2025-04-07 13:35:01 +08:00
parent ae38a3a357
commit ddadf3e287
3 changed files with 177 additions and 49 deletions

View File

@@ -148,6 +148,7 @@ GRI_MODEL_URL = '10.1.1.240:10051'
# Pose Transform service config # Pose Transform service config
PS_RABBITMQ_QUEUES = os.getenv("PS_RABBITMQ_QUEUES", f"PoseTransform{RABBITMQ_ENV}") PS_RABBITMQ_QUEUES = os.getenv("PS_RABBITMQ_QUEUES", f"PoseTransform{RABBITMQ_ENV}")
PT_MODEL_URL = '10.1.1.243:10061'
# SEG service config # SEG service config
SEGMENTATION = { SEGMENTATION = {

View File

@@ -9,16 +9,19 @@
""" """
import json import json
import logging import logging
import time
from io import BytesIO
import cv2 import imageio
import numpy as np import numpy as np
import redis import redis
import tritonclient.grpc as grpcclient import tritonclient.grpc as grpcclient
from PIL import Image from PIL import Image
from tritonclient.utils import np_to_triton_dtype
from app.core.config import * from app.core.config import *
from app.schemas.pose_transform import PoseTransformModel from app.schemas.pose_transform import PoseTransformModel
from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image from app.service.generate_image.utils.pose_transform_upload import upload_gif, upload_video, upload_first_image
from app.service.utils.oss_client import oss_get_image from app.service.utils.oss_client import oss_get_image
logger = logging.getLogger() logger = logging.getLogger()
@@ -29,33 +32,48 @@ class PoseTransformService:
if DEBUG is False: if DEBUG is False:
self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
self.channel = self.connection.channel() self.channel = self.connection.channel()
self.grpc_client = grpcclient.InferenceServerClient(url=GRI_MODEL_URL) self.grpc_client = grpcclient.InferenceServerClient(url=PT_MODEL_URL)
self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
self.category = "pose_transform" self.category = "pose_transform"
self.batch_size = 1
self.seed = "1"
self.image_url = request_data.image_url self.image_url = request_data.image_url
self.image = oss_get_image(bucket=self.image_url.split('/')[0], object_name=self.image_url[self.image_url.find('/') + 1:], data_type="cv2") self.pose_num = request_data.pose_id
self.image = pre_processing_image(request_data.image_url)
self.tasks_id = request_data.tasks_id self.tasks_id = request_data.tasks_id
self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:] self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
self.gen_product_data = {'tasks_id': self.tasks_id, 'status': 'SUCCESS', 'message': "success", 'gif_url': 'test/mannequin_name.png', 'video_url': 'test/mannequin_name.png', 'image_url': 'test/mannequin_name.png'} self.pose_transform_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'gif_url': '', 'video_url': '', 'image_url': ''}
self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data)) self.redis_client.set(self.tasks_id, json.dumps(self.pose_transform_data))
self.redis_client.expire(self.tasks_id, 600) self.redis_client.expire(self.tasks_id, 600)
def callback(self, result, error): def callback(self, result, error):
if error: if error:
self.gen_product_data['status'] = "FAILURE" self.pose_transform_data['status'] = "FAILURE"
self.gen_product_data['message'] = str(error) self.pose_transform_data['message'] = str(error)
self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data)) self.redis_client.set(self.tasks_id, json.dumps(self.pose_transform_data))
else: else:
image = result.as_numpy("generated_inpaint_image") result_data = np.squeeze(result.as_numpy("generated_image_list").astype(np.uint8))[:, :, :, ::-1]
image_result = Image.fromarray(np.squeeze(image.astype(np.uint8)))
image_url = upload_SDXL_image(image_result, user_id=self.user_id, category=f"{self.category}", file_name=f"{self.tasks_id}.png") # 第一帧图像
self.gen_product_data['status'] = "SUCCESS" first_image = Image.fromarray(result_data[0])
self.gen_product_data['message'] = "success" first_image_url = upload_first_image(first_image, user_id=self.user_id, category=f"{self.category}_first_img", file_name=f"{self.tasks_id}.png")
self.gen_product_data['image_url'] = str(image_url)
self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data)) # 上传GIF
gif_buffer = BytesIO()
imageio.mimsave(gif_buffer, result_data, format='GIF', fps=5)
gif_buffer.seek(0)
gif_url = upload_gif(gif_buffer=gif_buffer, user_id=self.user_id, category=f"{self.category}_gif", file_name=f"{self.tasks_id}.gif")
# 上传video
video_url = upload_video(frames=result_data, user_id=self.user_id, category=f"{self.category}_video", file_name=f"{self.tasks_id}.mp4")
self.pose_transform_data['status'] = "SUCCESS"
self.pose_transform_data['message'] = "success"
self.pose_transform_data['gif_url'] = str(gif_url)
self.pose_transform_data['video_url'] = str(video_url)
self.pose_transform_data['image_url'] = str(first_image_url)
self.redis_client.set(self.tasks_id, json.dumps(self.pose_transform_data))
def read_tasks_status(self): def read_tasks_status(self):
status_data = self.redis_client.get(self.tasks_id) status_data = self.redis_client.get(self.tasks_id)
@@ -63,51 +81,92 @@ class PoseTransformService:
def get_result(self): def get_result(self):
try: try:
image = cv2.cvtColor(self.image, cv2.COLOR_BGR2RGB) pose_num = [self.pose_num] * 1
image = cv2.resize(image, (512, 768)) pose_num_obj = np.array(pose_num, dtype="object").reshape((-1, 1))
images = [image.astype(np.uint8)] * self.batch_size input_pose_num = grpcclient.InferInput("pose_num", pose_num_obj.shape, np_to_triton_dtype(pose_num_obj.dtype))
input_pose_num.set_data_from_numpy(pose_num_obj)
image_obj = np.array(images, dtype=np.uint8).reshape((768, 512, 3)) image_files = [self.image.astype(np.uint8)] * 1
image_files_obj = np.array(image_files, dtype=np.uint8).reshape((-1, 768, 512, 3))
input_image_files = grpcclient.InferInput("image_file", image_files_obj.shape, "UINT8")
input_image_files.set_data_from_numpy(image_files_obj)
input_image = grpcclient.InferInput("input_image", image_obj.shape, "UINT8") ctx = self.grpc_client.async_infer(model_name="animatex_1", inputs=[input_pose_num, input_image_files], callback=self.callback)
time_out = 6000
input_image.set_data_from_numpy(image_obj) while time_out > 0:
pose_transform_data, _ = self.read_tasks_status()
inputs = [input_image] if pose_transform_data['status'] in ["REVOKED", "FAILURE"]:
# ctx = self.grpc_client.async_infer(model_name=GRI_MODEL_NAME_OVERALL, inputs=inputs, callback=self.callback) ctx.cancel()
break
# time_out = 600 elif pose_transform_data['status'] == "SUCCESS":
# while time_out > 0: break
# gen_product_data, _ = self.read_tasks_status() time_out -= 1
# if gen_product_data['status'] in ["REVOKED", "FAILURE", "NO_FACE"]: time.sleep(0.1)
# ctx.cancel() pose_transform_data, _ = self.read_tasks_status()
# break return pose_transform_data
# elif gen_product_data['status'] == "SUCCESS":
# break
# time_out -= 1
# time.sleep(0.1)
gen_product_data, _ = self.read_tasks_status()
return gen_product_data
except Exception as e: except Exception as e:
self.gen_product_data['status'] = "FAILURE" self.pose_transform_data['status'] = "FAILURE"
self.gen_product_data['message'] = str(e) self.pose_transform_data['message'] = str(e)
self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data)) self.redis_client.set(self.tasks_id, json.dumps(self.pose_transform_data))
raise Exception(str(e)) raise Exception(str(e))
finally: finally:
dict_gen_product_data, str_gen_product_data = self.read_tasks_status() dict_pose_transform_data, str_pose_transform_data = self.read_tasks_status()
if DEBUG is False: if DEBUG is False:
self.channel.basic_publish(exchange='', routing_key=PS_RABBITMQ_QUEUES, body=str_gen_product_data) self.channel.basic_publish(exchange='', routing_key=PS_RABBITMQ_QUEUES, body=str_pose_transform_data)
logger.info(f" [x] Sent to {PS_RABBITMQ_QUEUES} data@@@@ {json.dumps(dict_gen_product_data, indent=4)}") logger.info(f" [x] Sent to {PS_RABBITMQ_QUEUES} data@@@@ {json.dumps(dict_pose_transform_data, indent=4)}")
def infer_cancel(tasks_id): def infer_cancel(tasks_id):
redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
data = {'tasks_id': tasks_id, 'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'} data = {'tasks_id': tasks_id, 'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'}
gen_product_data = json.dumps(data) pose_transform_data = json.dumps(data)
redis_client.set(tasks_id, gen_product_data) redis_client.set(tasks_id, pose_transform_data)
return data return data
def pre_processing_image(image_url):
    """Fetch an image from OSS and letterbox it onto a 512x768 canvas.

    The source image is scaled with its aspect ratio preserved so it fits
    entirely inside the target canvas, centred on a white fully-transparent
    RGBA background, then flattened to RGB.

    :param image_url: OSS path of the form "<bucket>/<object_name>".
    :return: numpy array of shape (768, 512, 3) holding the letterboxed image.
    """
    source = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url[image_url.find('/') + 1:], data_type="PIL")
    target_w, target_h = 512, 768
    src_w, src_h = source.size
    # Pick the smaller ratio so the whole image fits without cropping.
    ratio = min(target_w / src_w, target_h / src_h)
    fit_w = int(src_w * ratio)
    fit_h = int(src_h * ratio)
    fitted = source.resize((fit_w, fit_h))
    # White, fully transparent canvas of the target size.
    canvas = Image.new("RGBA", (target_w, target_h), (255, 255, 255, 0))
    # Offsets that centre the fitted image on the canvas.
    offset = ((target_w - fit_w) // 2, (target_h - fit_h) // 2)
    if fitted.mode == "RGBA":
        # An RGBA mask image makes paste() use its alpha band as the mask.
        canvas.paste(fitted, offset, mask=fitted)
    else:
        canvas.paste(fitted, offset)
    return np.array(canvas.convert("RGB"))
if __name__ == '__main__': if __name__ == '__main__':
rd = PoseTransformModel( rd = PoseTransformModel(
tasks_id="123-89", tasks_id="123-89",

View File

@@ -0,0 +1,68 @@
import io
import logging
import imageio
import numpy as np
# import boto3
from minio import Minio
from app.core.config import *
from app.service.utils.new_oss_client import oss_upload_image
# Shared MinIO client used by every upload helper in this module.
# NOTE(review): MINIO_URL / MINIO_ACCESS / MINIO_SECRET / MINIO_SECURE are
# presumably supplied by the `from app.core.config import *` above — confirm.
minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
def upload_first_image(image, user_id, category, file_name):
    """Encode a PIL image as PNG and upload it to MinIO via `oss_upload_image`.

    :param image: PIL.Image to upload.
    :param user_id: owner id, used as the top-level object prefix.
    :param category: sub-folder under the user prefix.
    :param file_name: object file name (e.g. "<task_id>.png").
    :return: "aida-users/<object_name>" on success; None on failure
             (errors are logged, not raised — best-effort upload).
    """
    try:
        image_data = io.BytesIO()
        image.save(image_data, format='PNG')
        image_data.seek(0)
        image_bytes = image_data.read()
        object_name = f'{user_id}/{category}/{file_name}'
        # Return value of the upload call was never used; don't bind it.
        oss_upload_image(oss_client=minio_client, bucket=GI_MINIO_BUCKET, object_name=object_name, image_bytes=image_bytes)
        return f"aida-users/{object_name}"
    except Exception as e:
        # BUG FIX: the log message previously said "upload_png_mask"
        # (copy-paste from another helper), misattributing failures.
        logging.warning(f"upload_first_image runtime exception : {e}")
def upload_gif(gif_buffer, user_id, category, file_name):
    """Upload an in-memory GIF to the "aida-users" MinIO bucket.

    :param gif_buffer: io.BytesIO positioned at the start of the GIF bytes.
    :param user_id: owner id, used as the top-level object prefix.
    :param category: sub-folder under the user prefix.
    :param file_name: object file name (e.g. "<task_id>.gif").
    :return: "aida-users/<object_name>" on success; None on failure
             (errors are logged, not raised — best-effort upload).
    """
    try:
        object_name = f'{user_id}/{category}/{file_name}'
        minio_client.put_object(
            bucket_name="aida-users",
            object_name=object_name,
            data=gif_buffer,
            length=gif_buffer.getbuffer().nbytes,
            content_type="image/gif",
        )
        return f"aida-users/{object_name}"
    except Exception as e:
        logging.warning(f"upload_gif runtime exception : {e}")
def upload_video(frames, user_id, category, file_name):
    """Encode frames into an in-memory MP4 (24 fps) and upload it to MinIO.

    :param frames: iterable of HxWx3 uint8 frames.
    :param user_id: owner id, used as the top-level object prefix.
    :param category: sub-folder under the user prefix.
    :param file_name: object file name (e.g. "<task_id>.mp4").
    :return: "aida-users/<object_name>" on success; None on failure
             (errors are logged, not raised — best-effort upload).
    """
    try:
        buffer = io.BytesIO()
        # Encode all frames before touching the network.
        with imageio.get_writer(buffer, format='mp4', fps=24) as writer:
            for frame in frames:
                writer.append_data(frame)
        buffer.seek(0)
        object_name = f'{user_id}/{category}/{file_name}'
        # Stream the encoded video straight from memory to MinIO.
        minio_client.put_object(
            "aida-users",
            object_name,
            buffer,
            length=buffer.getbuffer().nbytes,
            content_type='video/mp4',
        )
        return f"aida-users/{object_name}"
    except Exception as e:
        logging.warning(f"upload_video runtime exception : {e}")
if __name__ == '__main__':
    # Smoke-test the video upload path with a few random frames.
    demo_frames = np.random.randint(0, 256, size=(4, 512, 512, 3), dtype=np.uint8)
    print(upload_video(demo_frames, user_id=89, category='test', file_name="1.mp4"))