feat(新功能): pose transform 逻辑修改

fix(修复bug):
docs(文档变更):
refactor(重构):
test(增加测试):
This commit is contained in:
zhouchengrong
2025-03-20 10:25:33 +08:00
parent 6f48626005
commit dab155d200
4 changed files with 18 additions and 5 deletions

View File

@@ -24,7 +24,8 @@ def pose_transform(request_item: PoseTransformModel, background_tasks: Backgroun
{
"tasks_id": "123-89",
"image_url": "aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png",
"pose_id": "1"
"pose_id": "1",
"result_type" : "gif"
}
"""
try:

View File

@@ -146,6 +146,11 @@ GRI_MODEL_NAME_OVERALL = 'diffusion_relight_ensemble'
GRI_MODEL_NAME_SINGLE = 'stable_diffusion_1_5_relight'
GRI_MODEL_URL = '10.1.1.240:10051'
# Pose Transform service config
PS_RABBITMQ_QUEUES = os.getenv("PS_RABBITMQ_QUEUES", f"PoseTransform{RABBITMQ_ENV}")
# SEG service config
SEGMENTATION = {
"new_model_name": "seg_knet",

View File

@@ -5,3 +5,4 @@ class PoseTransformModel(BaseModel):
image_url: str
tasks_id: str
pose_id: str
result_type: str

View File

@@ -38,7 +38,12 @@ class PoseTransformService:
self.image = oss_get_image(bucket=self.image_url.split('/')[0], object_name=self.image_url[self.image_url.find('/') + 1:], data_type="cv2")
self.tasks_id = request_data.tasks_id
self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
self.gen_product_data = {'tasks_id': self.tasks_id, 'status': 'SUCCESS', 'message': "success", 'image_url': ''}
self.result_type = request_data.result_type
if self.result_type == "gif":
self.gen_product_data = {'tasks_id': self.tasks_id, 'status': 'SUCCESS', 'message': "success", 'gif_url': 'test/mannequin_name.png', 'video_url': '', 'type': self.result_type}
else:
self.gen_product_data = {'tasks_id': self.tasks_id, 'status': 'SUCCESS', 'message': "success", 'gif_url': '', 'video_url': 'test/mannequin_name.png', 'type': self.result_type}
self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data))
self.redis_client.expire(self.tasks_id, 600)
@@ -95,8 +100,8 @@ class PoseTransformService:
finally:
dict_gen_product_data, str_gen_product_data = self.read_tasks_status()
if DEBUG is False:
self.channel.basic_publish(exchange='', routing_key=GRI_RABBITMQ_QUEUES, body=str_gen_product_data)
logger.info(f" [x] Sent to {GRI_RABBITMQ_QUEUES} data@@@@ {json.dumps(dict_gen_product_data, indent=4)}")
self.channel.basic_publish(exchange='', routing_key=PS_RABBITMQ_QUEUES, body=str_gen_product_data)
logger.info(f" [x] Sent to {PS_RABBITMQ_QUEUES} data@@@@ {json.dumps(dict_gen_product_data, indent=4)}")
def infer_cancel(tasks_id):
@@ -111,7 +116,8 @@ if __name__ == '__main__':
rd = PoseTransformModel(
tasks_id="123-89",
image_url='aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png',
pose_id="1"
pose_id="1",
result_type="gif",
)
server = PoseTransformService(rd)
print(server.get_result())