#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project :trinity_client
@File    :service_pose_transform.py
@Author  :Zhou Chengrong (周成融)
@Date    :2023/7/26 12:01:05
@detail  :Pose-transform task service. Loads a source image from MinIO, letterboxes it to 512x768,
          runs the Triton model "animatex_1" asynchronously, uploads the generated frames
          (first image, GIF, MP4) and tracks the task status in Redis.
"""
import json
import logging
import time
from io import BytesIO

import imageio
import numpy as np
import redis
import tritonclient.grpc as grpcclient
from PIL import Image
from minio import Minio
from tritonclient.utils import np_to_triton_dtype

from app.core.config import settings, PS_RABBITMQ_QUEUES, PT_MODEL_URL
from app.schemas.pose_transform import PoseTransformModel
from app.service.generate_image.utils.mq import publish_status
from app.service.generate_image.utils.pose_transform_upload import upload_gif, upload_video, upload_first_image
from app.service.utils.new_oss_client import oss_get_image

logger = logging.getLogger()

minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET,
                     secure=settings.MINIO_SECURE)

# Lifetime (seconds) of a task-status key in Redis. A plain Redis SET discards any TTL the key
# already had, so every status write passes ``ex=`` instead of relying on a one-off EXPIRE.
TASK_STATUS_TTL = 600


class PoseTransformService:
    """Run one pose-transform inference task and track its status in Redis.

    The status is stored under ``tasks_id`` as a JSON object with the keys ``tasks_id``,
    ``status`` (PENDING / SUCCESS / FAILURE here, or REVOKED when written by ``infer_cancel``),
    ``message``, ``gif_url``, ``video_url`` and ``image_url``.
    """

    def __init__(self, request_data):
        """Load and preprocess the source image and register the task as PENDING in Redis.

        :param request_data: request object exposing ``image_url``, ``pose_id`` and ``tasks_id``
            (e.g. ``PoseTransformModel``).
        """
        self.grpc_client = grpcclient.InferenceServerClient(url=PT_MODEL_URL)
        self.redis_client = redis.StrictRedis(host=settings.REDIS_HOST, port=settings.REDIS_PORT,
                                              db=settings.REDIS_DB, decode_responses=True)
        self.category = "pose_transform"
        self.image_url = request_data.image_url
        self.pose_num = request_data.pose_id
        self.image = pre_processing_image(request_data.image_url)
        self.tasks_id = request_data.tasks_id
        # The user id is the part of the task id after its last '-'.
        self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
        self.pose_transform_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending",
                                    'gif_url': '', 'video_url': '', 'image_url': ''}
        self._save_status()

    def _save_status(self):
        """Persist ``self.pose_transform_data`` to Redis, (re)applying the status TTL."""
        self.redis_client.set(self.tasks_id, json.dumps(self.pose_transform_data), ex=TASK_STATUS_TTL)

    def _mark_failure(self, error):
        """Record the task as FAILURE with ``str(error)`` as its message."""
        self.pose_transform_data['status'] = "FAILURE"
        self.pose_transform_data['message'] = str(error)
        self._save_status()

    def _is_revoked(self):
        """Return True if the task's status in Redis is REVOKED (written by ``infer_cancel``)."""
        status_data = self.redis_client.get(self.tasks_id)
        return status_data is not None and json.loads(status_data).get('status') == "REVOKED"

    def _store_result(self, result):
        """Upload the generated frames (first image, GIF, MP4) and mark the task SUCCESS."""
        # Reverse the channel order of the last axis (presumably BGR -> RGB; confirm against the model).
        # NOTE(review): np.squeeze also drops the frame axis if the model returns a single frame,
        # which would make this 4-axis index fail — confirm the model always returns >1 frame.
        frames = np.squeeze(result.as_numpy("generated_image_list").astype(np.uint8))[:, :, :, ::-1]

        # First frame as a still image.
        first_image = Image.fromarray(frames[0])
        first_image_url = upload_first_image(first_image, user_id=self.user_id,
                                             category=f"{self.category}_first_img",
                                             file_name=f"{self.tasks_id}.png")

        # All frames as an animated GIF.
        gif_buffer = BytesIO()
        imageio.mimsave(gif_buffer, frames, format='GIF', fps=5)
        gif_buffer.seek(0)
        gif_url = upload_gif(gif_buffer=gif_buffer, user_id=self.user_id, category=f"{self.category}_gif",
                             file_name=f"{self.tasks_id}.gif")

        # All frames as an MP4 video.
        video_url = upload_video(frames=frames, user_id=self.user_id, category=f"{self.category}_video",
                                 file_name=f"{self.tasks_id}.mp4")

        self.pose_transform_data['status'] = "SUCCESS"
        self.pose_transform_data['message'] = "success"
        self.pose_transform_data['gif_url'] = str(gif_url)
        self.pose_transform_data['video_url'] = str(video_url)
        self.pose_transform_data['image_url'] = str(first_image_url)
        self._save_status()

    def callback(self, result, error):
        """Completion callback for Triton ``async_infer``; records the outcome in Redis.

        This runs on the gRPC client's thread, so an exception raised here never reaches
        ``get_result``. Every failure is therefore caught and recorded as FAILURE; otherwise the
        task would stay PENDING until the polling loop in ``get_result`` gives up.

        :param result: the inference result (unused when ``error`` is set).
        :param error: falsy on success, otherwise the inference error.
        """
        try:
            # A revoked task has its request cancelled by get_result(); do not let the resulting
            # cancellation error, or a late result, overwrite the REVOKED status.
            if self._is_revoked():
                return
            if error:
                self._mark_failure(error)
            else:
                self._store_result(result)
        except Exception as e:
            logger.exception("pose transform callback failed for task %s", self.tasks_id)
            self._mark_failure(e)

    def read_tasks_status(self):
        """Return the task status as ``(dict, raw JSON string)``.

        Reads Redis; if the key has expired or been deleted, falls back to this instance's
        in-memory status instead of failing on ``json.loads(None)``.
        """
        status_data = self.redis_client.get(self.tasks_id)
        if status_data is None:
            status_data = json.dumps(self.pose_transform_data)
        return json.loads(status_data), status_data

    def get_result(self):
        """Submit the inference request and block until the task reaches a final state.

        Polls the status once per second. A REVOKED status cancels the in-flight request. The
        final status is published to ``PS_RABBITMQ_QUEUES`` unless ``settings.DEBUG`` is set.

        :return: the final status dict as returned by ``read_tasks_status``.
        :raises Exception: if building or submitting the request fails; the task is marked
            FAILURE first and the original exception is chained as the cause.
        """
        try:
            pose_num_obj = np.array([self.pose_num], dtype="object").reshape((-1, 1))
            input_pose_num = grpcclient.InferInput("pose_num", pose_num_obj.shape,
                                                   np_to_triton_dtype(pose_num_obj.dtype))
            input_pose_num.set_data_from_numpy(pose_num_obj)

            # A batch of one 768x512 RGB image, as produced by pre_processing_image.
            image_files_obj = np.array([self.image.astype(np.uint8)], dtype=np.uint8).reshape((-1, 768, 512, 3))
            input_image_files = grpcclient.InferInput("image_file", image_files_obj.shape, "UINT8")
            input_image_files.set_data_from_numpy(image_files_obj)

            # NOTE(review): if client_timeout is in seconds (as in tritonclient's gRPC API), 60000 is
            # ~16.7 hours; the polling loop below has the same bound. Confirm this is intended.
            ctx = self.grpc_client.async_infer(model_name="animatex_1", inputs=[input_pose_num, input_image_files],
                                               callback=self.callback, client_timeout=60000)
            time_out = 60000  # maximum number of one-second polls
            while time_out > 0:
                pose_transform_data, _ = self.read_tasks_status()
                if pose_transform_data['status'] in ["REVOKED", "FAILURE"]:
                    ctx.cancel()
                    break
                elif pose_transform_data['status'] == "SUCCESS":
                    break
                time_out -= 1
                time.sleep(1)
            pose_transform_data, _ = self.read_tasks_status()
            return pose_transform_data
        except Exception as e:
            self._mark_failure(e)
            raise Exception(str(e)) from e
        finally:
            dict_pose_transform_data, str_pose_transform_data = self.read_tasks_status()
            if not settings.DEBUG:
                # NOTE(review): str_pose_transform_data is already a JSON string, so this publishes a
                # JSON-encoded string (double encoding). Kept as-is in case the consumer relies on it.
                publish_status(json.dumps(str_pose_transform_data), PS_RABBITMQ_QUEUES)
                logger.info(
                    f" [x] Sent to: {PS_RABBITMQ_QUEUES} data:@@@@ {json.dumps(dict_pose_transform_data, indent=4)}")


def infer_cancel(tasks_id):
    """Mark a task as REVOKED in Redis; a running ``get_result`` will then cancel its request.

    :param tasks_id: id of the task to revoke.
    :return: the status dict that was written.
    """
    redis_client = redis.StrictRedis(host=settings.REDIS_HOST, port=settings.REDIS_PORT, db=settings.REDIS_DB,
                                     decode_responses=True)
    data = {'tasks_id': tasks_id, 'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'}
    # Re-apply the TTL: a plain SET would drop the key's expiry and keep it forever.
    redis_client.set(tasks_id, json.dumps(data), ex=TASK_STATUS_TTL)
    return data


def pre_processing_image(image_url):
    """Load an image from MinIO and letterbox it onto a 512x768 (width x height) white canvas.

    The image is scaled, preserving its aspect ratio, to fit inside the canvas and is then centred.
    Transparent areas show the white background.

    :param image_url: ``"<bucket>/<object name>"``.
    :return: an RGB array of shape (768, 512, 3).
    """
    image = oss_get_image(oss_client=minio_client, bucket=image_url.split('/')[0],
                          object_name=image_url[image_url.find('/') + 1:], data_type="PIL")

    # Target canvas size.
    target_width = 512
    target_height = 768

    # Use the smaller scale ratio so the whole image fits inside the canvas.
    original_width, original_height = image.size
    scale_ratio = min(target_width / original_width, target_height / original_height)

    # Clamp to at least 1 px: very elongated images would otherwise round a side down to 0,
    # which PIL cannot resize to.
    new_width = max(1, int(original_width * scale_ratio))
    new_height = max(1, int(original_height * scale_ratio))
    resized_image = image.resize((new_width, new_height))

    # Fully transparent white 512x768 canvas.
    result_image = Image.new("RGBA", (target_width, target_height), (255, 255, 255, 0))

    # Offsets that centre the resized image on the canvas.
    x_offset = (target_width - new_width) // 2
    y_offset = (target_height - new_height) // 2

    # Normalise to RGBA and paste using the alpha channel as the mask. Opaque modes paste exactly
    # as before (their alpha is 255 everywhere). Modes with transparency (P, LA) now composite over
    # the white background instead of leaking the raw colour of transparent pixels.
    if resized_image.mode != "RGBA":
        resized_image = resized_image.convert("RGBA")
    result_image.paste(resized_image, (x_offset, y_offset), mask=resized_image.split()[3])

    return np.array(result_image.convert("RGB"))


if __name__ == '__main__':
    rd = PoseTransformModel(
        tasks_id="123-89",
        image_url='aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png',
        pose_id="1"
    )
    server = PoseTransformService(rd)
    print(server.get_result())