diff --git a/app/api/api_pose_transform.py b/app/api/api_pose_transform.py index fe5fc5a..4b66467 100644 --- a/app/api/api_pose_transform.py +++ b/app/api/api_pose_transform.py @@ -24,7 +24,8 @@ def pose_transform(request_item: PoseTransformModel, background_tasks: Backgroun { "tasks_id": "123-89", "image_url": "aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png", - "pose_id": "1" + "pose_id": "1", + "result_type" : "gif" } """ try: diff --git a/app/core/config.py b/app/core/config.py index 5a1e2a3..662d7e2 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -146,6 +146,11 @@ GRI_MODEL_NAME_OVERALL = 'diffusion_relight_ensemble' GRI_MODEL_NAME_SINGLE = 'stable_diffusion_1_5_relight' GRI_MODEL_URL = '10.1.1.240:10051' + +# Pose Transform service config + +PS_RABBITMQ_QUEUES = os.getenv("PS_RABBITMQ_QUEUES", f"PoseTransform{RABBITMQ_ENV}") + # SEG service config SEGMENTATION = { "new_model_name": "seg_knet", diff --git a/app/schemas/pose_transform.py b/app/schemas/pose_transform.py index 045d8b9..05db63f 100644 --- a/app/schemas/pose_transform.py +++ b/app/schemas/pose_transform.py @@ -5,3 +5,4 @@ class PoseTransformModel(BaseModel): image_url: str tasks_id: str pose_id: str + result_type: str diff --git a/app/service/generate_image/service_pose_transform.py b/app/service/generate_image/service_pose_transform.py index f2948b3..8de243e 100644 --- a/app/service/generate_image/service_pose_transform.py +++ b/app/service/generate_image/service_pose_transform.py @@ -38,7 +38,12 @@ class PoseTransformService: self.image = oss_get_image(bucket=self.image_url.split('/')[0], object_name=self.image_url[self.image_url.find('/') + 1:], data_type="cv2") self.tasks_id = request_data.tasks_id self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:] - self.gen_product_data = {'tasks_id': self.tasks_id, 'status': 'SUCCESS', 'message': "success", 'image_url': ''} + self.result_type = request_data.result_type + if self.result_type == "gif": + self.gen_product_data = {'tasks_id': self.tasks_id, 'status': 'SUCCESS', 'message': "success", 'gif_url': 'test/mannequin_name.png', 'video_url': '', 'type': self.result_type} + else: + self.gen_product_data = {'tasks_id': self.tasks_id, 'status': 'SUCCESS', 'message': "success", 'gif_url': '', 'video_url': 'test/mannequin_name.png', 'type': self.result_type} + self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data)) self.redis_client.expire(self.tasks_id, 600) @@ -95,8 +100,8 @@ class PoseTransformService: finally: dict_gen_product_data, str_gen_product_data = self.read_tasks_status() if DEBUG is False: - self.channel.basic_publish(exchange='', routing_key=GRI_RABBITMQ_QUEUES, body=str_gen_product_data) - logger.info(f" [x] Sent to: {GRI_RABBITMQ_QUEUES} data:@@@@ {json.dumps(dict_gen_product_data, indent=4)}") + self.channel.basic_publish(exchange='', routing_key=PS_RABBITMQ_QUEUES, body=str_gen_product_data) + logger.info(f" [x] Sent to: {PS_RABBITMQ_QUEUES} data:@@@@ {json.dumps(dict_gen_product_data, indent=4)}") def infer_cancel(tasks_id): @@ -111,7 +116,8 @@ if __name__ == '__main__': rd = PoseTransformModel( tasks_id="123-89", image_url='aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png', - pose_id="1" + pose_id="1", + result_type="gif", ) server = PoseTransformService(rd) print(server.get_result())