diff --git a/app/core/config.py b/app/core/config.py index 6ac56e3..ac9181f 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -148,6 +148,7 @@ GRI_MODEL_URL = '10.1.1.240:10051' # Pose Transform service config PS_RABBITMQ_QUEUES = os.getenv("PS_RABBITMQ_QUEUES", f"PoseTransform{RABBITMQ_ENV}") +PT_MODEL_URL = '10.1.1.243:10061' # SEG service config SEGMENTATION = { diff --git a/app/service/generate_image/service_pose_transform.py b/app/service/generate_image/service_pose_transform.py index 6c1c1c9..8a5e4c8 100644 --- a/app/service/generate_image/service_pose_transform.py +++ b/app/service/generate_image/service_pose_transform.py @@ -9,16 +9,19 @@ """ import json import logging +import time +from io import BytesIO -import cv2 +import imageio import numpy as np import redis import tritonclient.grpc as grpcclient from PIL import Image +from tritonclient.utils import np_to_triton_dtype from app.core.config import * from app.schemas.pose_transform import PoseTransformModel -from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image +from app.service.generate_image.utils.pose_transform_upload import upload_gif, upload_video, upload_first_image from app.service.utils.oss_client import oss_get_image logger = logging.getLogger() @@ -29,33 +32,48 @@ class PoseTransformService: if DEBUG is False: self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) self.channel = self.connection.channel() - self.grpc_client = grpcclient.InferenceServerClient(url=GRI_MODEL_URL) + self.grpc_client = grpcclient.InferenceServerClient(url=PT_MODEL_URL) self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) self.category = "pose_transform" - self.batch_size = 1 - self.seed = "1" self.image_url = request_data.image_url - self.image = oss_get_image(bucket=self.image_url.split('/')[0], object_name=self.image_url[self.image_url.find('/') + 1:], data_type="cv2") + self.pose_num = request_data.pose_id + self.image = pre_processing_image(request_data.image_url) self.tasks_id = request_data.tasks_id self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:] - self.gen_product_data = {'tasks_id': self.tasks_id, 'status': 'SUCCESS', 'message': "success", 'gif_url': 'test/mannequin_name.png', 'video_url': 'test/mannequin_name.png', 'image_url': 'test/mannequin_name.png'} + self.pose_transform_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'gif_url': '', 'video_url': '', 'image_url': ''} - self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data)) + self.redis_client.set(self.tasks_id, json.dumps(self.pose_transform_data)) self.redis_client.expire(self.tasks_id, 600) def callback(self, result, error): if error: - self.gen_product_data['status'] = "FAILURE" - self.gen_product_data['message'] = str(error) - self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data)) + self.pose_transform_data['status'] = "FAILURE" + self.pose_transform_data['message'] = str(error) + self.redis_client.set(self.tasks_id, json.dumps(self.pose_transform_data)) else: - image = result.as_numpy("generated_inpaint_image") - image_result = Image.fromarray(np.squeeze(image.astype(np.uint8))) - image_url = upload_SDXL_image(image_result, user_id=self.user_id, category=f"{self.category}", file_name=f"{self.tasks_id}.png") - self.gen_product_data['status'] = "SUCCESS" - self.gen_product_data['message'] = "success" - self.gen_product_data['image_url'] = str(image_url) - self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data)) + result_data = np.squeeze(result.as_numpy("generated_image_list").astype(np.uint8))[:, :, :, ::-1] + + # 第一帧图像 + first_image = Image.fromarray(result_data[0]) + first_image_url = upload_first_image(first_image, user_id=self.user_id, category=f"{self.category}_first_img", file_name=f"{self.tasks_id}.png") + + # 上传GIF + gif_buffer = BytesIO() + imageio.mimsave(gif_buffer, result_data, format='GIF', fps=5) + gif_buffer.seek(0) + gif_url = upload_gif(gif_buffer=gif_buffer, user_id=self.user_id, category=f"{self.category}_gif", file_name=f"{self.tasks_id}.gif") + + # 上传video + video_url = upload_video(frames=result_data, user_id=self.user_id, category=f"{self.category}_video", file_name=f"{self.tasks_id}.mp4") + + self.pose_transform_data['status'] = "SUCCESS" + self.pose_transform_data['message'] = "success" + self.pose_transform_data['gif_url'] = str(gif_url) + self.pose_transform_data['video_url'] = str(video_url) + self.pose_transform_data['image_url'] = str(first_image_url) + + + self.redis_client.set(self.tasks_id, json.dumps(self.pose_transform_data)) def read_tasks_status(self): status_data = self.redis_client.get(self.tasks_id) @@ -63,51 +81,92 @@ class PoseTransformService: def get_result(self): try: - image = cv2.cvtColor(self.image, cv2.COLOR_BGR2RGB) - image = cv2.resize(image, (512, 768)) - images = [image.astype(np.uint8)] * self.batch_size + pose_num = [self.pose_num] * 1 + pose_num_obj = np.array(pose_num, dtype="object").reshape((-1, 1)) + input_pose_num = grpcclient.InferInput("pose_num", pose_num_obj.shape, np_to_triton_dtype(pose_num_obj.dtype)) + input_pose_num.set_data_from_numpy(pose_num_obj) - image_obj = np.array(images, dtype=np.uint8).reshape((768, 512, 3)) + image_files = [self.image.astype(np.uint8)] * 1 + image_files_obj = np.array(image_files, dtype=np.uint8).reshape((-1, 768, 512, 3)) + input_image_files = grpcclient.InferInput("image_file", image_files_obj.shape, "UINT8") + input_image_files.set_data_from_numpy(image_files_obj) - input_image = grpcclient.InferInput("input_image", image_obj.shape, "UINT8") - - input_image.set_data_from_numpy(image_obj) - - inputs = [input_image] - # ctx = self.grpc_client.async_infer(model_name=GRI_MODEL_NAME_OVERALL, inputs=inputs, callback=self.callback) - - # time_out = 600 - # while time_out > 0: - # gen_product_data, _ = self.read_tasks_status() - # if gen_product_data['status'] in ["REVOKED", "FAILURE", "NO_FACE"]: - # ctx.cancel() - # break - # elif gen_product_data['status'] == "SUCCESS": - # break - # time_out -= 1 - # time.sleep(0.1) - gen_product_data, _ = self.read_tasks_status() - return gen_product_data + ctx = self.grpc_client.async_infer(model_name="animatex_1", inputs=[input_pose_num, input_image_files], callback=self.callback) + time_out = 6000 + while time_out > 0: + pose_transform_data, _ = self.read_tasks_status() + if pose_transform_data['status'] in ["REVOKED", "FAILURE"]: + ctx.cancel() + break + elif pose_transform_data['status'] == "SUCCESS": + break + time_out -= 1 + time.sleep(0.1) + pose_transform_data, _ = self.read_tasks_status() + return pose_transform_data except Exception as e: - self.gen_product_data['status'] = "FAILURE" - self.gen_product_data['message'] = str(e) - self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data)) + self.pose_transform_data['status'] = "FAILURE" + self.pose_transform_data['message'] = str(e) + self.redis_client.set(self.tasks_id, json.dumps(self.pose_transform_data)) raise Exception(str(e)) finally: - dict_gen_product_data, str_gen_product_data = self.read_tasks_status() + dict_pose_transform_data, str_pose_transform_data = self.read_tasks_status() if DEBUG is False: - self.channel.basic_publish(exchange='', routing_key=PS_RABBITMQ_QUEUES, body=str_gen_product_data) - logger.info(f" [x] Sent to: {PS_RABBITMQ_QUEUES} data:@@@@ {json.dumps(dict_gen_product_data, indent=4)}") + self.channel.basic_publish(exchange='', routing_key=PS_RABBITMQ_QUEUES, body=str_pose_transform_data) + logger.info(f" [x] Sent to: {PS_RABBITMQ_QUEUES} data:@@@@ {json.dumps(dict_pose_transform_data, indent=4)}") def infer_cancel(tasks_id): redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) data = {'tasks_id': tasks_id, 'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'} - gen_product_data = json.dumps(data) - redis_client.set(tasks_id, gen_product_data) + pose_transform_data = json.dumps(data) + redis_client.set(tasks_id, pose_transform_data) return data +def pre_processing_image(image_url): + image = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url[image_url.find('/') + 1:], data_type="PIL") + # 目标图片的尺寸 + target_width = 512 + target_height = 768 + + # 原始图片的尺寸 + original_width, original_height = image.size + + # 计算宽度和高度的缩放比例 + width_ratio = target_width / original_width + height_ratio = target_height / original_height + + # 选择较小的缩放比例,确保图片能完整放入目标图片中 + scale_ratio = min(width_ratio, height_ratio) + + # 计算调整后的尺寸 + new_width = int(original_width * scale_ratio) + new_height = int(original_height * scale_ratio) + + # 调整图片大小 + resized_image = image.resize((new_width, new_height)) + + # 创建一个 512x768 的透明图片 + result_image = Image.new("RGBA", (target_width, target_height), (255, 255, 255, 0)) + + # 计算需要粘贴的位置,使图片居中 + x_offset = (target_width - new_width) // 2 + y_offset = (target_height - new_height) // 2 + + # 将调整大小后的图片粘贴到透明图片上 + if resized_image.mode == "RGBA": + result_image.paste(resized_image, (x_offset, y_offset), mask=resized_image.split()[3]) + else: + result_image.paste(resized_image, (x_offset, y_offset)) + result_image = result_image.convert("RGB") + image = np.array(result_image) + + # image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA) + + return image + + if __name__ == '__main__': rd = PoseTransformModel( tasks_id="123-89", diff --git a/app/service/generate_image/utils/pose_transform_upload.py b/app/service/generate_image/utils/pose_transform_upload.py new file mode 100644 index 0000000..86c3e6e --- /dev/null +++ b/app/service/generate_image/utils/pose_transform_upload.py @@ -0,0 +1,68 @@ +import io +import logging + +import imageio +import numpy as np +# import boto3 +from minio import Minio + +from app.core.config import * +from app.service.utils.new_oss_client import oss_upload_image + +minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) + + +def upload_first_image(image, user_id, category, file_name): + try: + image_data = io.BytesIO() + image.save(image_data, format='PNG') + image_data.seek(0) + image_bytes = image_data.read() + object_name = f'{user_id}/{category}/{file_name}' + req = oss_upload_image(oss_client=minio_client, bucket=GI_MINIO_BUCKET, object_name=object_name, image_bytes=image_bytes) + image_url = f"aida-users/{object_name}" + return image_url + except Exception as e: + logging.warning(f"upload_png_mask runtime exception : {e}") + + +def upload_gif(gif_buffer, user_id, category, file_name): + try: + object_name = f'{user_id}/{category}/{file_name}' + req = minio_client.put_object( + "aida-users", + object_name, + gif_buffer, + length=gif_buffer.getbuffer().nbytes, + content_type="image/gif" + ) + return f"aida-users/{object_name}" + except Exception as e: + logging.warning(f"upload_gif runtime exception : {e}") + + +def upload_video(frames, user_id, category, file_name): + try: + video_buffer = io.BytesIO() + with imageio.get_writer(video_buffer, format='mp4', fps=24) as writer: + for frame in frames: + writer.append_data(frame) + video_buffer.seek(0) + + object_name = f'{user_id}/{category}/{file_name}' + # 上传视频流到MinIO + minio_client.put_object( + bucket_name="aida-users", + object_name=object_name, + data=video_buffer, + length=video_buffer.getbuffer().nbytes, + content_type='video/mp4' + ) + return f"aida-users/{object_name}" + except Exception as e: + logging.warning(f"upload_video runtime exception : {e}") + + +if __name__ == '__main__': + images = np.random.randint(0, 256, size=(4, 512, 512, 3), dtype=np.uint8) + print(upload_video(images, user_id=89, category='test', file_name="1.mp4"))