From 423ff8dd26ee8705f924761334ed91d5728f96ff Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Thu, 12 Sep 2024 10:06:25 +0800 Subject: [PATCH] feat design batch fix --- app/api/api_design.py | 41 ++++++- app/schemas/design.py | 51 ++------ .../design/service_design_batch_generate.py | 114 ++++++++++++++++++ 3 files changed, 160 insertions(+), 46 deletions(-) create mode 100644 app/service/design/service_design_batch_generate.py diff --git a/app/api/api_design.py b/app/api/api_design.py index bc3d1b9..5210477 100644 --- a/app/api/api_design.py +++ b/app/api/api_design.py @@ -1,12 +1,14 @@ import json import logging +import os -from fastapi import APIRouter, HTTPException +from fastapi import APIRouter, HTTPException, UploadFile, File, Form -from app.schemas.design import DesignModel, DesignProgressModel, ModelProgressModel +from app.schemas.design import DesignModel, DesignProgressModel, ModelProgressModel, DBGConfigModel from app.schemas.response_template import ResponseModel from app.service.design.model_process_service import model_transpose from app.service.design.service import generate +from app.service.design.service_design_batch_generate import start_design_batch_generate from app.service.design.utils.redis_utils import Redis router = APIRouter() @@ -238,3 +240,38 @@ def model_process(request_data: ModelProgressModel): logger.warning(f"model_process Run Exception @@@@@@:{e}") raise HTTPException(status_code=404, detail=str(e)) return ResponseModel(data=data) + + +# ############################################################## + + +@router.post("/design_batch_generate") +async def design(file: UploadFile = File(...), + tasks_id: str = Form(...), + user_id: str = Form(...), + priority: int = Form(...), + total: int = Form(...) + ): + # file_content = await file.read() + dbg_config = DBGConfigModel( + tasks_id=tasks_id, + user_id=user_id, + priority=priority, + total=total + ) + contents = await file.read() + file_name = file.filename + await save_request_file(contents, file_name) + + return await start_design_batch_generate(dbg_config, contents) + + +async def save_request_file(contents, file_name): + # 创建保存文件的目录(如果不存在) + save_dir = os.path.join(os.getcwd(), "design_batch", "request_data") + if not os.path.exists(save_dir): + os.makedirs(save_dir) + # 处理文件 + file_path = os.path.join(save_dir, file_name) + with open(file_path, "wb") as f: + f.write(contents) diff --git a/app/schemas/design.py b/app/schemas/design.py index edcc392..763e0a0 100644 --- a/app/schemas/design.py +++ b/app/schemas/design.py @@ -1,50 +1,6 @@ from pydantic import BaseModel -# class BodyPointModel(BaseModel): -# waistband_right: list[int] -# hand_point_right: list[int] -# waistband_left: list[int] -# hand_point_left: list[int] -# shoulder_left: list[int] -# shoulder_right: list[int] -# -# -# class BasicModel(BaseModel): -# body_point: BodyPointModel -# layer_order: bool -# scale_bag: float -# scale_earrings: float -# self_template: bool -# single_overall: str -# switch_category: str -# body_path: str -# -# -# class PrintModel(BaseModel): -# if_single: bool -# print_path_list: list[str] -# -# -# class ItemModel(BaseModel): -# color: str -# image_id: str -# offset: list[int] -# path: str -# print: PrintModel -# resize_scale: float -# type: str -# -# -# class CollocationModel(BaseModel): -# basic: BasicModel -# item: list[ItemModel] -# -# -# class DesignModel(BaseModel): -# object: list[CollocationModel] -# process_id: str - class DesignModel(BaseModel): objects: list[dict] process_id: str @@ -56,3 +12,10 @@ class DesignProgressModel(BaseModel): class ModelProgressModel(BaseModel): model_path: str + + +class DBGConfigModel(BaseModel): + tasks_id: str + user_id: str + priority: int + total: int diff --git a/app/service/design/service_design_batch_generate.py b/app/service/design/service_design_batch_generate.py new file mode 100644 index 0000000..0696176 --- /dev/null +++ b/app/service/design/service_design_batch_generate.py @@ -0,0 +1,114 @@ +import json + +import pika +from celery import Celery + +from app.service.design.design_batch.items.item import process_layer, process_item, update_base_size_priority +from app.service.design.utils.synthesis_item import synthesis_single, synthesis + +celery_app = Celery('clothes_generation', broker='amqp://guest:guest@localhost:5672//') + + + +@celery_app.task +def design_batch_generate(design_objects, total_steps, task_id): + objects_response = [] + for step, object in enumerate(design_objects): + basic = object['basic'] + items_response = {'layers': []} + if basic['single_overall'] == "overall": + item_results = [process_item(item, basic) for item in object['items']] + layers = [] + futures = [] + body_size = None + for item in item_results: + futures = [process_layer(item, layers)] + for future in futures: + if future is not None: + body_size = future + layers = sorted(layers, key=lambda s: s.get("priority", float('inf'))) + + layers, new_size = update_base_size_priority(layers, body_size) + + for lay in layers: + items_response['layers'].append({ + 'image_category': lay['name'], + 'position': lay['position'], + 'priority': lay.get("priority", None), + 'resize_scale': lay['resize_scale'] if "resize_scale" in lay.keys() else None, + 'image_size': lay['image'] if lay['image'] is None else lay['image'].size, + 'gradient_string': lay['gradient_string'] if 'gradient_string' in lay.keys() else "", + 'mask_url': lay['mask_url'], + 'image_url': lay['image_url'] if 'image_url' in lay.keys() else None, + 'pattern_image_url': lay['pattern_image_url'] if 'pattern_image_url' in lay.keys() else None, + + # 'image': lay['image'], + # 'mask_image': lay['mask_image'], + }) + items_response['synthesis_url'] = synthesis(layers, new_size, basic) + else: + item_results = process_item(object['items'][0], basic) + items_response['layers'].append({ + 'image_category': f"{item_results['name']}_front", + 'image_size': item_results['back_image'].size if item_results['back_image'] else None, + 'position': None, + 'priority': 0, + 'image_url': item_results['front_image_url'], + 'mask_url': item_results['mask_url'], + "gradient_string": item_results['gradient_string'] if 'gradient_string' in item_results.keys() else "", + 'pattern_image_url': item_results['pattern_image_url'] if 'pattern_image_url' in item_results.keys() else None, + + }) + items_response['layers'].append({ + 'image_category': f"{item_results['name']}_back", + 'image_size': item_results['front_image'].size if item_results['front_image'] else None, + 'position': None, + 'priority': 0, + 'image_url': item_results['back_image_url'], + 'mask_url': item_results['mask_url'], + "gradient_string": item_results['gradient_string'] if 'gradient_string' in item_results.keys() else "", + 'pattern_image_url': item_results['pattern_image_url'] if 'pattern_image_url' in item_results.keys() else None, + + }) + items_response['synthesis_url'] = synthesis_single(item_results['front_image'], item_results['back_image']) + objects_response.append(items_response) + publish_status(task_id, f"{step + 1}/{total_steps}", objects_response) + print(objects_response) + return objects_response + + +def publish_status(task_id, progress, result): + connection = pika.BlockingConnection(pika.ConnectionParameters('localhost')) + channel = connection.channel() + channel.queue_declare(queue='DesignBatch', durable=True) + message = {'task_id': task_id, 'progress': progress, "result": result} + print(message) + channel.basic_publish(exchange='', + routing_key='DesignBatch', + body=json.dumps(message), + properties=pika.BasicProperties( + delivery_mode=2, + )) + connection.close() + + +async def start_design_batch_generate(data, file): + generate_clothes_task = design_batch_generate.delay(json.loads(file.decode())['objects'], data.total, data.tasks_id) + print(generate_clothes_task) + publish_status(data.tasks_id, "0/100", "") + return {"task_id": data.tasks_id} +# +# +# if __name__ == '__main__': +# data = {"objects": [{"basic": {"body_point_test": {"waistband_right": [200, 241], "hand_point_right": [223, 297], "waistband_left": [112, 241], "hand_point_left": [92, 305], "shoulder_left": [99, 116], "shoulder_right": [215, 116]}, "layer_order": True, "scale_bag": 0.7, "scale_earrings": 0.16, "self_template": True, "single_overall": "overall", "switch_category": ""}, "items": [ +# {"businessId": 270372, "color": "30 28 28", "image_id": 69780, "offset": [0, 0], "path": "aida-sys-image/images/female/trousers/0825000630.jpg", "seg_mask_url": "test/result.png", +# "print": {"element": {"element_angle_list": [], "element_path_list": [], "element_scale_list": [], "location": []}, "overall": {"location": [], "print_angle_list": [], "print_path_list": [], "print_scale_list": []}, "single": {"location": [], "print_angle_list": [], "print_path_list": [], "print_scale_list": []}}, "priority": 10, "resize_scale": [1.0, 1.0], "type": "Trousers"}, +# {"businessId": 270373, "color": "30 28 28", "image_id": 98243, "offset": [0, 0], "path": "aida-sys-image/images/female/blouse/0902003811.jpg", "seg_mask_url": "test/result.png", +# "print": {"element": {"element_angle_list": [], "element_path_list": [], "element_scale_list": [], "location": []}, "overall": {"location": [], "print_angle_list": [], "print_path_list": [], "print_scale_list": []}, "single": {"location": [], "print_angle_list": [], "print_path_list": [], "print_scale_list": []}}, "priority": 11, "resize_scale": [1.0, 1.0], "type": "Blouse"}, +# {"businessId": 270374, "color": "172 68 68", "image_id": 98244, "offset": [0, 0], "path": "aida-sys-image/images/female/outwear/0825000410.jpg", "seg_mask_url": "test/result.png", +# "print": {"element": {"element_angle_list": [], "element_path_list": [], "element_scale_list": [], "location": []}, "overall": {"location": [], "print_angle_list": [], "print_path_list": [], "print_scale_list": []}, "single": {"location": [], "print_angle_list": [], "print_path_list": [], "print_scale_list": []}}, "priority": 12, "resize_scale": [1.0, 1.0], "type": "Outwear"}, +# {"body_path": "aida-sys-image/models/female/5bdfe7ca-64eb-44e4-b03d-8e517520c795.png", "image_id": 96090, "type": "Body"}]}], "process_id": "83"} +# total_steps = 1 +# task_id = 1 +# design_batch_generate(data['objects'], total_steps, task_id) +# # publish_status(task_id="0/100", progress=100)