feat(新功能): pose transform 部署

fix(修复bug):
docs(文档变更):
refactor(重构):
test(增加测试):
This commit is contained in:
zchengrong
2025-04-07 16:47:27 +08:00
parent 3ad724fe9f
commit 24eb43e2f0
2 changed files with 6 additions and 8 deletions

View File

@@ -29,9 +29,6 @@ logger = logging.getLogger()
class PoseTransformService:
def __init__(self, request_data):
if DEBUG is False:
self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
self.channel = self.connection.channel()
self.grpc_client = grpcclient.InferenceServerClient(url=PT_MODEL_URL)
self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
self.category = "pose_transform"
@@ -72,7 +69,6 @@ class PoseTransformService:
self.pose_transform_data['video_url'] = str(video_url)
self.pose_transform_data['image_url'] = str(first_image_url)
self.redis_client.set(self.tasks_id, json.dumps(self.pose_transform_data))
def read_tasks_status(self):
@@ -91,8 +87,8 @@ class PoseTransformService:
input_image_files = grpcclient.InferInput("image_file", image_files_obj.shape, "UINT8")
input_image_files.set_data_from_numpy(image_files_obj)
ctx = self.grpc_client.async_infer(model_name="animatex_1", inputs=[input_pose_num, input_image_files], callback=self.callback)
time_out = 6000
ctx = self.grpc_client.async_infer(model_name="animatex_1", inputs=[input_pose_num, input_image_files], callback=self.callback, client_timeout=60000)
time_out = 60000
while time_out > 0:
pose_transform_data, _ = self.read_tasks_status()
if pose_transform_data['status'] in ["REVOKED", "FAILURE"]:
@@ -112,7 +108,9 @@ class PoseTransformService:
finally:
dict_pose_transform_data, str_pose_transform_data = self.read_tasks_status()
if DEBUG is False:
self.channel.basic_publish(exchange='', routing_key=PS_RABBITMQ_QUEUES, body=str_pose_transform_data)
connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
channel = connection.channel()
channel.basic_publish(exchange='', routing_key=PS_RABBITMQ_QUEUES, body=str_pose_transform_data)
logger.info(f" [x] Sent to {PS_RABBITMQ_QUEUES} data@@@@ {json.dumps(dict_pose_transform_data, indent=4)}")

View File

@@ -47,7 +47,7 @@ def upload_video(frames, user_id, category, file_name):
# 生成内存中的视频字节流
video_buffer = io.BytesIO()
with imageio.get_writer(video_buffer, format="mp4", fps=24) as writer:
for img in images:
for img in frames:
writer.append_data(img)
writer.close()
video_bytes = video_buffer.getvalue()