From 24eb43e2f024cb9d5cdc4d0c6d2160eeec882e11 Mon Sep 17 00:00:00 2001 From: zchengrong <124802516+zchengrong@users.noreply.github.com> Date: Mon, 7 Apr 2025 16:47:27 +0800 Subject: [PATCH] =?UTF-8?q?feat=EF=BC=88=E6=96=B0=E5=8A=9F=E8=83=BD?= =?UTF-8?q?=EF=BC=89:=20=20pose=20transform=20=E9=83=A8=E7=BD=B2=20fix?= =?UTF-8?q?=EF=BC=88=E4=BF=AE=E5=A4=8Dbug=EF=BC=89:=20docs=EF=BC=88?= =?UTF-8?q?=E6=96=87=E6=A1=A3=E5=8F=98=E6=9B=B4=EF=BC=89:=20refactor?= =?UTF-8?q?=EF=BC=88=E9=87=8D=E6=9E=84=EF=BC=89:=20test(=E5=A2=9E=E5=8A=A0?= =?UTF-8?q?=E6=B5=8B=E8=AF=95):?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/service/generate_image/service_pose_transform.py | 12 +++++------- .../generate_image/utils/pose_transform_upload.py | 2 +- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/app/service/generate_image/service_pose_transform.py b/app/service/generate_image/service_pose_transform.py index 8a5e4c8..3fc65c6 100644 --- a/app/service/generate_image/service_pose_transform.py +++ b/app/service/generate_image/service_pose_transform.py @@ -29,9 +29,6 @@ logger = logging.getLogger() class PoseTransformService: def __init__(self, request_data): - if DEBUG is False: - self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) - self.channel = self.connection.channel() self.grpc_client = grpcclient.InferenceServerClient(url=PT_MODEL_URL) self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) self.category = "pose_transform" @@ -72,7 +69,6 @@ class PoseTransformService: self.pose_transform_data['video_url'] = str(video_url) self.pose_transform_data['image_url'] = str(first_image_url) - self.redis_client.set(self.tasks_id, json.dumps(self.pose_transform_data)) def read_tasks_status(self): @@ -91,8 +87,8 @@ class PoseTransformService: input_image_files = grpcclient.InferInput("image_file", image_files_obj.shape, "UINT8") input_image_files.set_data_from_numpy(image_files_obj) - ctx = self.grpc_client.async_infer(model_name="animatex_1", inputs=[input_pose_num, input_image_files], callback=self.callback) - time_out = 6000 + ctx = self.grpc_client.async_infer(model_name="animatex_1", inputs=[input_pose_num, input_image_files], callback=self.callback, client_timeout=60000) + time_out = 60000 while time_out > 0: pose_transform_data, _ = self.read_tasks_status() if pose_transform_data['status'] in ["REVOKED", "FAILURE"]: @@ -112,7 +108,9 @@ class PoseTransformService: finally: dict_pose_transform_data, str_pose_transform_data = self.read_tasks_status() if DEBUG is False: - self.channel.basic_publish(exchange='', routing_key=PS_RABBITMQ_QUEUES, body=str_pose_transform_data) + connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) + channel = connection.channel() + channel.basic_publish(exchange='', routing_key=PS_RABBITMQ_QUEUES, body=str_pose_transform_data) logger.info(f" [x] Sent to: {PS_RABBITMQ_QUEUES} data:@@@@ {json.dumps(dict_pose_transform_data, indent=4)}") diff --git a/app/service/generate_image/utils/pose_transform_upload.py b/app/service/generate_image/utils/pose_transform_upload.py index 7e97a26..69708f6 100644 --- a/app/service/generate_image/utils/pose_transform_upload.py +++ b/app/service/generate_image/utils/pose_transform_upload.py @@ -47,7 +47,7 @@ def upload_video(frames, user_id, category, file_name): # 生成内存中的视频字节流 video_buffer = io.BytesIO() with imageio.get_writer(video_buffer, format="mp4", fps=24) as writer: - for img in images: + for img in frames: writer.append_data(img) writer.close() video_bytes = video_buffer.getvalue()