feat(新功能): pose transform 接口

fix(修复bug):
docs(文档变更):
refactor(重构):
test(增加测试):
This commit is contained in:
zhouchengrong
2025-03-17 11:14:54 +08:00
parent 00b8e9fb02
commit b4671a3793
4 changed files with 176 additions and 3 deletions

View File

@@ -0,0 +1,49 @@
import json
import logging
from fastapi import APIRouter, BackgroundTasks, HTTPException
from app.schemas.pose_transform import PoseTransformModel
from app.schemas.response_template import ResponseModel
from app.service.generate_image.service_pose_transform import PoseTransformService, infer_cancel as pose_transform_infer_cancel
router = APIRouter()
logger = logging.getLogger()
@router.post("/pose_transform")
def pose_transform(request_item: PoseTransformModel, background_tasks: BackgroundTasks):
"""
创建一个具有以下参数的请求体:
- **tasks_id**: 任务id 用于取消生成任务和获取生成结果
- **image_url**: 被生成图片的S3或minio url地址
- **pose_id**: 1
示例参数:
{
"tasks_id": "123-89",
"image_url": "aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png",
"pose_id": "1"
}
"""
try:
logger.info(f"pose_transform request item is : @@@@@@:{json.dumps(request_item.dict())}")
service = PoseTransformService(request_item)
background_tasks.add_task(service.get_result)
except Exception as e:
logger.warning(f"pose_transform Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))
return ResponseModel()
@router.get("/pose_transform_cancel/{tasks_id}")
def pose_transform_cancel(tasks_id: str):
try:
logger.info(f"pose_transform_cancel request item is : @@@@@@:{tasks_id}")
data = pose_transform_infer_cancel(tasks_id)
logger.info(f"pose_transform_cancel response @@@@@@:{data}")
except Exception as e:
logger.warning(f"pose_transform_cancel Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))
return ResponseModel(data=data['data'])

View File

@@ -1,5 +1,6 @@
from fastapi import APIRouter
from app.api import api_agent_generate_image
from app.api import api_attribute_retrieve, api_query_image
from app.api import api_brand_dna
from app.api import api_brighten
@@ -9,11 +10,9 @@ from app.api import api_design_pre_processing
from app.api import api_generate_image
from app.api import api_image2sketch
from app.api import api_mannequins_edit
from app.api import api_pose_transform
from app.api import api_prompt_generation
from app.api import api_recommendation
from app.api import api_super_resolution
from app.api import api_agent_generate_image
from app.api import api_test
router = APIRouter()
@@ -33,3 +32,4 @@ router.include_router(api_brand_dna.router, tags=['api_brand_dna'], prefix="/api
# router.include_router(api_recommendation.router, tags=['api_recommendation'], prefix="/api")
router.include_router(api_mannequins_edit.router, tags=['api_mannequins_edit'], prefix="/api")
router.include_router(api_agent_generate_image.router, tags=['api_agent_generate_image'], prefix="/api")
router.include_router(api_pose_transform.router, tags=['api_pose_transform'], prefix="/api")