Files
AiDA_Python/app/service/generate_image/service_pose_transform.py

186 lines
7.5 KiB
Python
Raw Permalink Normal View History

#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project trinity_client
@File service_pose_transform.py
@Author 周成融
@Date 2023/7/26 12:01:05
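@detail Pose-transform service: sends an image and a pose id to the Triton model "animatex_1", uploads the generated frames (first frame, GIF, MP4) and tracks the task status in Redis.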
"""
import json
import logging
import time
from io import BytesIO
import imageio
import numpy as np
import redis
import tritonclient.grpc as grpcclient
from PIL import Image
from tritonclient.utils import np_to_triton_dtype
from app.core.config import *
from app.schemas.pose_transform import PoseTransformModel
from app.service.generate_image.utils.mq import publish_status
from app.service.generate_image.utils.pose_transform_upload import upload_gif, upload_video, upload_first_image
from app.service.utils.oss_client import oss_get_image
# Module-level logger; records still propagate to the root logger's handlers.
logger = logging.getLogger(__name__)
class PoseTransformService:
    """Run one pose-transform inference on Triton and track its status in Redis.

    The task status is stored in Redis under ``tasks_id`` as a JSON object with the keys
    ``tasks_id``, ``status``, ``message``, ``gif_url``, ``video_url`` and ``image_url``.
    ``status`` starts as ``PENDING`` and becomes ``SUCCESS`` or ``FAILURE``. ``infer_cancel``
    can set it to ``REVOKED`` from outside.
    """

    # Seconds a status key lives in Redis after each write.
    STATUS_TTL = 600
    # Passed to Triton as ``client_timeout``; also the number of 1-second status polls.
    INFER_TIMEOUT = 60000

    def __init__(self, request_data):
        """Prepare the input image and register the task as PENDING in Redis.

        :param request_data: request object exposing ``image_url``, ``pose_id`` and
            ``tasks_id`` (e.g. ``PoseTransformModel``).
        """
        self.grpc_client = grpcclient.InferenceServerClient(url=PT_MODEL_URL)
        self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
        self.category = "pose_transform"
        self.image_url = request_data.image_url
        self.pose_num = request_data.pose_id
        self.image = pre_processing_image(request_data.image_url)
        self.tasks_id = request_data.tasks_id
        # The user id is the suffix after the last '-' in the task id (the whole id if there is none).
        self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
        self.pose_transform_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending",
                                    'gif_url': '', 'video_url': '', 'image_url': ''}
        self._save_status()

    def _save_status(self):
        """Write the local status dict to Redis and (re)apply the TTL.

        A plain Redis SET discards any existing TTL, so every write passes ``ex``.
        Otherwise final statuses would never expire.
        """
        self.redis_client.set(self.tasks_id, json.dumps(self.pose_transform_data), ex=self.STATUS_TTL)

    def _mark_failure(self, error):
        """Record status FAILURE with ``str(error)`` as the message."""
        self.pose_transform_data['status'] = "FAILURE"
        self.pose_transform_data['message'] = str(error)
        self._save_status()

    def _is_revoked(self):
        """Return True if the stored status is REVOKED (as written by ``infer_cancel``)."""
        return self.read_tasks_status()[0].get('status') == "REVOKED"

    def _upload_results(self, result):
        """Upload the generated frames and return the URL fields for the status record."""
        # NOTE(review): assumes the squeezed output is (frames, H, W, C); the last axis is
        # reversed (channel order swap) before encoding. A single-frame output would be
        # squeezed to rank 3 and fail here — confirm the model always returns several frames.
        frames = np.squeeze(result.as_numpy("generated_image_list").astype(np.uint8))[:, :, :, ::-1]
        # First frame as a still image.
        first_image_url = upload_first_image(Image.fromarray(frames[0]), user_id=self.user_id,
                                             category=f"{self.category}_first_img",
                                             file_name=f"{self.tasks_id}.png")
        # All frames as a GIF.
        gif_buffer = BytesIO()
        imageio.mimsave(gif_buffer, frames, format='GIF', fps=5)
        gif_buffer.seek(0)
        gif_url = upload_gif(gif_buffer=gif_buffer, user_id=self.user_id, category=f"{self.category}_gif",
                             file_name=f"{self.tasks_id}.gif")
        # All frames as a video.
        video_url = upload_video(frames=frames, user_id=self.user_id, category=f"{self.category}_video",
                                 file_name=f"{self.tasks_id}.mp4")
        return {'gif_url': str(gif_url), 'video_url': str(video_url), 'image_url': str(first_image_url)}

    def callback(self, result, error):
        """Triton completion callback: upload the outputs and record the final status.

        On success, the frames are uploaded (first-frame PNG, GIF, MP4) and the status becomes
        SUCCESS. An inference error, or any exception raised while post-processing or
        uploading, sets FAILURE so the poller in ``get_result`` does not wait on PENDING.
        A task that was already REVOKED is left untouched.
        """
        try:
            if self._is_revoked():
                return
            if error:
                self._mark_failure(error)
                return
            urls = self._upload_results(result)
        except Exception as e:
            # This runs on a Triton worker thread; nothing upstream would catch the error.
            logger.exception("pose transform post-processing failed for task %s", self.tasks_id)
            self._mark_failure(e)
            return
        self.pose_transform_data.update(status="SUCCESS", message="success", **urls)
        self._save_status()

    def read_tasks_status(self):
        """Return the stored status as ``(dict, raw_json_str)``.

        If the key is missing (e.g. its TTL elapsed), fall back to the local in-memory
        status instead of failing on ``json.loads(None)``.
        """
        status_data = self.redis_client.get(self.tasks_id)
        if status_data is None:
            status_data = json.dumps(self.pose_transform_data)
        return json.loads(status_data), status_data

    def _build_inputs(self):
        """Build the Triton inputs: ``pose_num`` with shape (1, 1) and ``image_file`` with shape (1, 768, 512, 3)."""
        pose_num_obj = np.array([self.pose_num], dtype="object").reshape((-1, 1))
        input_pose_num = grpcclient.InferInput("pose_num", pose_num_obj.shape,
                                               np_to_triton_dtype(pose_num_obj.dtype))
        input_pose_num.set_data_from_numpy(pose_num_obj)
        image_files_obj = np.array([self.image.astype(np.uint8)], dtype=np.uint8).reshape((-1, 768, 512, 3))
        input_image_files = grpcclient.InferInput("image_file", image_files_obj.shape, "UINT8")
        input_image_files.set_data_from_numpy(image_files_obj)
        return [input_pose_num, input_image_files]

    def get_result(self):
        """Submit the inference, poll Redis until it finishes, and return the final status.

        Polls once per second for up to ``INFER_TIMEOUT`` iterations. A REVOKED or FAILURE
        status cancels the in-flight request. If polling runs out while the task is still
        PENDING, the request is cancelled and the task is marked FAILURE. The final status is
        published to ``PS_RABBITMQ_QUEUES`` (unless ``DEBUG``) whether or not an error occurred.

        :return: the status dict read back from Redis.
        :raises Exception: re-raised (chained) for any error during submission or polling,
            after the task has been marked FAILURE.
        """
        try:
            ctx = self.grpc_client.async_infer(model_name="animatex_1", inputs=self._build_inputs(),
                                               callback=self.callback, client_timeout=self.INFER_TIMEOUT)
            for _ in range(self.INFER_TIMEOUT):
                status = self.read_tasks_status()[0]['status']
                if status in ("REVOKED", "FAILURE"):
                    ctx.cancel()
                    break
                if status == "SUCCESS":
                    break
                time.sleep(1)
            else:
                # Polling budget exhausted: don't report a still-PENDING task as the final result.
                status = self.read_tasks_status()[0]['status']
                if status != "SUCCESS":
                    ctx.cancel()
                if status == "PENDING":
                    self._mark_failure("pose transform timed out")
            pose_transform_data, _ = self.read_tasks_status()
            return pose_transform_data
        except Exception as e:
            self._mark_failure(e)
            raise Exception(str(e)) from e
        finally:
            dict_pose_transform_data, str_pose_transform_data = self.read_tasks_status()
            if not DEBUG:
                # NOTE(review): str_pose_transform_data is already a JSON string, so this encodes it a
                # second time. Kept as-is because queue consumers may rely on it — confirm before changing.
                publish_status(json.dumps(str_pose_transform_data), PS_RABBITMQ_QUEUES)
                logger.info(
                    f" [x] Sent to {PS_RABBITMQ_QUEUES} data@@@@ {json.dumps(dict_pose_transform_data, indent=4)}")
def infer_cancel(tasks_id, ttl=600):
    """Mark a pose-transform task as REVOKED in Redis.

    ``PoseTransformService.get_result`` polls this key and cancels the in-flight Triton
    request when it sees REVOKED.

    :param tasks_id: task id, used as the Redis key.
    :param ttl: seconds before the REVOKED record expires (default 600, the same lifetime
        given to status keys when a task is created).
    :return: the status dict that was written.
    """
    redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
    data = {'tasks_id': tasks_id, 'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'}
    # A plain SET leaves the key without an expiry, so the record would never be cleaned up.
    redis_client.set(tasks_id, json.dumps(data), ex=ttl)
    return data
def pre_processing_image(image_url, target_width=512, target_height=768):
    """Download an image from OSS and letterbox it onto a fixed-size white canvas.

    The image is scaled (up or down) by the largest factor that fits it inside
    ``target_width`` x ``target_height`` while keeping its aspect ratio, then centred.
    Uncovered and transparent areas end up white.

    :param image_url: ``"<bucket>/<object_name>"``. The bucket is the text before the first
        ``/`` and the object name is everything after it.
    :param target_width: canvas width in pixels (default 512).
    :param target_height: canvas height in pixels (default 768).
    :return: ``numpy.ndarray`` of shape ``(target_height, target_width, 3)`` in RGB order.
    :raises ValueError: if ``image_url`` contains no ``/``.
    """
    bucket, sep, object_name = image_url.partition('/')
    if not sep:
        raise ValueError(f"image_url must look like '<bucket>/<object_name>', got {image_url!r}")
    image = oss_get_image(bucket=bucket, object_name=object_name, data_type="PIL")

    original_width, original_height = image.size
    # The smaller of the two ratios guarantees the whole image fits inside the canvas.
    scale_ratio = min(target_width / original_width, target_height / original_height)
    # Clamp to at least 1 px so extreme aspect ratios don't produce a zero-sized resize.
    new_width = max(1, int(original_width * scale_ratio))
    new_height = max(1, int(original_height * scale_ratio))
    # Normalize to RGBA so every mode (P with transparency, LA, ...) carries an explicit alpha
    # band. RGB images get a fully opaque alpha and therefore paste unchanged.
    resized_image = image.convert("RGBA").resize((new_width, new_height))

    # Fully transparent white canvas. The final RGB conversion drops alpha, leaving white.
    canvas = Image.new("RGBA", (target_width, target_height), (255, 255, 255, 0))
    # Centre the resized image; an RGBA mask makes paste use its alpha band.
    offset = ((target_width - new_width) // 2, (target_height - new_height) // 2)
    canvas.paste(resized_image, offset, mask=resized_image)
    return np.array(canvas.convert("RGB"))
if __name__ == '__main__':
    # Manual smoke test: run a single pose-transform task end to end and print its final status.
    demo_request = PoseTransformModel(
        tasks_id="123-89",
        image_url='aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png',
        pose_id="1",
    )
    service = PoseTransformService(demo_request)
    final_status = service.get_result()
    print(final_status)