diff --git a/app/api/api_design.py b/app/api/api_design.py index 1c77ed8..03e0b25 100644 --- a/app/api/api_design.py +++ b/app/api/api_design.py @@ -433,12 +433,12 @@ def model_process(request_data: ModelProgressModel): @router.post("/design_batch_generate") -async def design(file: UploadFile = File(...), - tasks_id: str = Form(...), - user_id: str = Form(...), - file_name: str = Form(...), - total: int = Form(...) - ): +async def design_batch(file: UploadFile = File(...), + tasks_id: str = Form(...), + user_id: str = Form(...), + file_name: str = Form(...), + total: int = Form(...) + ): dbg_config = DBGConfigModel( tasks_id=tasks_id, user_id=user_id, diff --git a/app/api/api_generate_image.py b/app/api/api_generate_image.py index f151b91..2706abd 100644 --- a/app/api/api_generate_image.py +++ b/app/api/api_generate_image.py @@ -236,15 +236,71 @@ def generate_relight_image(tasks_id: str): @router.post("/batch_generate_product_image") -async def design(request_batch_item: BatchGenerateProductImageModel): +async def batch_generate_product(request_batch_item: BatchGenerateProductImageModel): + """ + 创建一个具有以下参数的请求体: + - **tasks_id**: 任务id 用于获取生成结果 + - **prompt**: 想要生成图片的描述词 + - **image_url**: 被生成图片的S3或minio url地址 + - **image_strength**: 生成强度,越低越接近原图 + - **product_type**: 输入single item 还是 overall item + - **batch_size**: 批生成数量 + + + 示例参数: + { + "tasks_id": "123-89", + "prompt": "the best quality, masterpiece. detailed, high-res, simple background, studio photography, extremely detailed, updo, detailed face, face, close-up, HDR, UHD, 8K realistic, Highly detailed, simple background, Studio lighting", + "image_url": "aida-results/result_00097282-ebb2-11ee-a822-b48351119060.png", + "image_strength": 0.8, + "product_type": "overall", + "batch_size": 1 + } + """ return await start_product_batch_generate(request_batch_item) @router.post("/batch_generate_relight_image") -async def design(request_batch_item: BatchGenerateRelightImageModel): +async def batch_generate_relight(request_batch_item: BatchGenerateRelightImageModel): + """ + 创建一个具有以下参数的请求体: + - **tasks_id**: 任务id 用于获取生成结果 + - **prompt**: 想要生成图片的描述词 + - **image_url**: 被生成图片的S3或minio url地址 + - **direction**: 光源方向 Right Light Left Light Top Light Bottom Light + - **product_type**: 输入single item 还是 overall item + - **batch_size**: 批生成数量 + + + 示例参数: + { + "tasks_id": "123-89", + "prompt": "beautiful woman, detailed face, sunshine, outdoor, warm atmosphere", + "image_url": "aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png", + "direction": "Right Light", + "product_type": "overall", + "batch_size": 1 + } + """ return await start_relight_batch_generate(request_batch_item) @router.post("/batch_generate_pose_transform_image") -async def design(request_batch_item: BatchPoseTransformModel): +async def batch_generate_pose_transform(request_batch_item: BatchPoseTransformModel): + """ + 创建一个具有以下参数的请求体: + - **tasks_id**: 任务id 用于取消生成任务和获取生成结果 + - **image_url**: 被生成图片的S3或minio url地址 + - **pose_id**: 1 + - **batch_size**: 批生成数量 + + + 示例参数: + { + "tasks_id": "123-89", + "image_url": "aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png", + "pose_id": "1", + "batch_size": 1 + } + """ return await start_pose_transform_batch_generate(request_batch_item) diff --git a/app/core/config.py b/app/core/config.py index aaf32d7..9650023 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -162,6 +162,11 @@ SEGMENTATION = { } # ollama config OLLAMA_URL = "http://10.1.1.240:11434/api/embeddings" + + +# design batch +BATCH_DESIGN_RABBITMQ_QUEUES = os.getenv("BATCH_DESIGN_RABBITMQ_QUEUES", f"Design{RABBITMQ_ENV}") + # DESIGN config DESIGN_MODEL_URL = '10.1.1.240:10000' AIDA_CLOTHING = "aida-clothing" diff --git a/app/service/design_batch/utils/MQ.py b/app/service/design_batch/utils/MQ.py index 1b64bf3..4fc839b 100644 --- a/app/service/design_batch/utils/MQ.py +++ b/app/service/design_batch/utils/MQ.py @@ -2,16 +2,16 @@ import json import pika -from app.core.config import RABBITMQ_PARAMS +from app.core.config import RABBITMQ_PARAMS, BATCH_DESIGN_RABBITMQ_QUEUES def publish_status(task_id, progress, result): connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) channel = connection.channel() - channel.queue_declare(queue='DesignBatch', durable=True) + channel.queue_declare(queue=BATCH_DESIGN_RABBITMQ_QUEUES, durable=True) message = {'task_id': task_id, 'progress': progress, "result": result} channel.basic_publish(exchange='', - routing_key='DesignBatch', + routing_key=BATCH_DESIGN_RABBITMQ_QUEUES, body=json.dumps(message), properties=pika.BasicProperties( delivery_mode=2,