# AiDA_Python/app/service/design_fast/design_generate.py
#
# NOTE: this module contains non-ASCII (CJK) text in comments and log messages; code-hosting
# viewers may flag these as "ambiguous Unicode characters".
import logging
import threading
import time
import requests
from minio import Minio
from app.core.config import settings
from app.service.design_fast.item import BodyItem, TopItem, BottomItem, OthersItem, TopMergeItem, BottomMergeItem, OthersMergeItem
from app.service.design_fast.utils.organize import organize_body, organize_clothing, organize_others
from app.service.design_fast.utils.progress import final_progress, update_progress
from app.service.design_fast.utils.synthesis_item import synthesis, synthesis_single, update_base_size_priority, merge
from app.service.utils.decorator import RunTime
# NOTE(review): id_lock is not referenced anywhere in the visible part of this module; it may be
# imported elsewhere — confirm before removing.
id_lock = threading.Lock()
# Module-named logger rather than the root logger, so records carry this module's name and can be
# configured independently; records still propagate to the root logger's handlers.
logger = logging.getLogger(__name__)
# Single MinIO client built at import time from application settings; it is handed to every
# item processor instantiated in process_item.
minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)
def process_item(item, basic, design_type):
    """Run the item processor matching an item's type and the design mode.

    Args:
        item: Item payload dict; its ``"type"`` value (case-insensitive) selects the processor.
        basic: The owning object's ``basic`` settings, forwarded unchanged to the processor.
        design_type: ``"merge"`` selects the merge processors; any other value (including
            ``"default"``) selects the regular design processors.

    Returns:
        Whatever the selected processor's ``process()`` returns.

    Raises:
        NotImplementedError: If no processor is registered for the item type in this mode.
    """
    # Lower-cased item type -> processor class for regular designs.
    design_map = {
        "body": BodyItem,
        "blouse": TopItem,
        "outwear": TopItem,
        "dress": TopItem,
        "tops": TopItem,
        "skirt": BottomItem,
        "trousers": BottomItem,
        "bottoms": BottomItem,
        "others": OthersItem,
    }
    # "<type>_merge" -> processor class for merge designs (the body reuses BodyItem).
    merge_map = {
        "body_merge": BodyItem,
        "blouse_merge": TopMergeItem,
        "outwear_merge": TopMergeItem,
        "dress_merge": TopMergeItem,
        "tops_merge": TopMergeItem,
        "skirt_merge": BottomMergeItem,
        "trousers_merge": BottomMergeItem,
        "bottoms_merge": BottomMergeItem,
        "others_merge": OthersMergeItem,
    }

    # Decide map and lookup key together: only "merge" is special, every other mode
    # (previously split into two identical branches) uses the regular map.
    item_type = item["type"].lower()
    if design_type == "merge":
        mapping, item_type_key = merge_map, f"{item_type}_merge"
    else:
        mapping, item_type_key = design_map, item_type

    handler_class = mapping.get(item_type_key)
    if not handler_class:
        raise NotImplementedError(f"Item type {item['type']} not implemented for design_type={design_type}")

    # All processor classes are assumed to share the (data, basic, minio_client) constructor.
    server = handler_class(data=item, basic=basic, minio_client=minio_client)
    return server.process()
def process_layer(item, layers):
    """Append the layer(s) produced from one processed item to ``layers``.

    A ``"mannequin"`` item contributes a single body layer (via ``organize_body``) and the
    size of its ``body_image`` is returned. Any other item contributes a front layer followed
    by a back layer — built by ``organize_others`` for ``"others"``/``"others_merge"`` items
    and by ``organize_clothing`` otherwise — and ``None`` is returned.
    """
    kind = item["name"]
    if kind == "mannequin":
        layers.append(organize_body(item))
        return item["body_image"].size

    builder = organize_others if kind in ("others", "others_merge") else organize_clothing
    front, back = builder(item)
    layers.extend((front, back))
    return None
@RunTime
def design_generate(request_data):
    """Generate every object of a request in parallel and collect the results.

    One worker thread is started per entry of ``objects``. The call blocks until all of them
    have finished, then calls ``final_progress``. Each successful worker calls
    ``update_progress`` once.

    Args:
        request_data: Request model whose ``.dict()`` contains ``"objects"`` (object payloads
            with ``basic``, ``items`` and an optional ``objectSign``) and ``"process_id"``.

    Returns:
        dict mapping each object's index in ``objects`` to its response dict (keys ``layers``,
        ``objectSign``, ``synthesis_url``). An object whose processing raised is logged and
        has no entry.
    """
    payload = request_data.dict()  # serialize the request once, not once per field
    objects_data = payload["objects"]
    process_id = payload["process_id"]
    total = len(objects_data)
    object_response = {}
    lock = threading.Lock()

    def overall_layer_entry(lay):
        """Serialize one composed layer of an "overall" object for the response."""
        return {
            "image_category": "body" if lay["name"] == "mannequin" else lay["name"],
            "position": lay["position"],
            "priority": lay.get("priority"),
            "resize_scale": lay.get("resize_scale"),
            "image_size": None if lay["image"] is None else lay["image"].size,
            "gradient_string": lay.get("gradient_string", ""),
            "mask_url": lay["mask_url"],
            "image_url": lay.get("image_url"),
            "pattern_overall_image_url": lay.get("pattern_overall_image_url"),
            "pattern_print_image_url": lay.get("pattern_print_image_url"),
            "transpose": lay.get("transpose"),
            "rotate": lay.get("rotate"),
        }

    def single_layer_entry(item_result, side):
        """Serialize the ``side`` ("front" or "back") view of a "single" object.

        ``image_size`` is taken from the same side's image; previously the front entry
        reported the back image's size and vice versa.
        """
        image = item_result[f"{side}_image"]
        return {
            "image_category": f"{item_result['name']}_{side}",
            "image_size": image.size if image else None,
            "position": None,
            "priority": 0,
            "image_url": item_result[f"{side}_image_url"],
            "mask_url": item_result["mask_url"],
            "gradient_string": item_result.get("gradient_string", ""),
            "pattern_overall_image_url": item_result.get("pattern_overall_image_url"),
            "pattern_print_image_url": item_result.get("pattern_print_image_url"),
        }

    def build_object_response(obj):
        """Process all items of one object and assemble its response dict."""
        basic = obj["basic"]
        design_type = basic.get("design_type", "default")
        items_response = {"layers": [], "objectSign": obj.get("objectSign", "")}
        if basic["single_overall"] == "overall":
            item_results = [process_item(item, basic, design_type) for item in obj["items"]]
            layers = []
            for item_result in item_results:
                process_layer(item_result, layers)
            # Ascending priority; layers without a priority sort last.
            layers.sort(key=lambda lay: lay.get("priority", float("inf")))
            layers, new_size = update_base_size_priority(layers)
            items_response["layers"].extend(overall_layer_entry(lay) for lay in layers)
            # "merge" designs are composed with merge(); every other design_type with synthesis().
            compose = merge if design_type == "merge" else synthesis
            items_response["synthesis_url"] = compose(layers, new_size, basic)
        else:
            item_result = process_item(obj["items"][0], basic, design_type)
            items_response["layers"].extend(single_layer_entry(item_result, side) for side in ("front", "back"))
            items_response["synthesis_url"] = synthesis_single(item_result["front_image"], item_result["back_image"])
        return items_response

    def process_object(step, obj):
        """Worker: build one object's response, report progress and store it under ``step``."""
        try:
            items_response = build_object_response(obj)
        except Exception:
            # Thread boundary: otherwise the failure only reaches threading.excepthook and the
            # object silently disappears from the returned mapping.
            logger.exception("design_generate: object %s failed", step)
            return
        update_progress(process_id, total)
        with lock:
            object_response[step] = items_response

    threads = [
        threading.Thread(target=process_object, args=(step, obj))
        for step, obj in enumerate(objects_data)
    ]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()
    final_progress(process_id)
    return object_response
@RunTime
def design_generate_v2(request_data):
    """Start generating every object of a request and POST each result to the callback URL.

    One worker thread is started per entry of ``objects``. This function does not wait for
    them and returns ``None`` immediately. Each worker sends its own object's response to
    ``request_data.callback_url`` via ``post_request``.

    Args:
        request_data: Request model providing ``.dict()["objects"]``, ``callback_url`` and
            ``requestId``; the request id is echoed back in every callback payload.
    """
    objects_data = request_data.dict()["objects"]
    callback_url = request_data.callback_url
    request_id = request_data.requestId

    def overall_layer_entry(lay):
        """Serialize one composed layer of an "overall" object for the callback payload."""
        return {
            "image_category": "body" if lay["name"] == "mannequin" else lay["name"],
            "position": lay["position"],
            "priority": lay.get("priority"),
            "resize_scale": lay.get("resize_scale"),
            "image_size": None if lay["image"] is None else lay["image"].size,
            "gradient_string": lay.get("gradient_string", ""),
            "mask_url": lay["mask_url"],
            "image_url": lay.get("image_url"),
            "pattern_overall_image_url": lay.get("pattern_overall_image_url"),
            "pattern_print_image_url": lay.get("pattern_print_image_url"),
        }

    def single_layer_entry(item_result, side):
        """Serialize the ``side`` ("front" or "back") view of a "single" object.

        ``image_size`` is taken from the same side's image; previously the front entry
        reported the back image's size and vice versa.
        """
        image = item_result[f"{side}_image"]
        return {
            "image_category": f"{item_result['name']}_{side}",
            "image_size": image.size if image else None,
            "position": None,
            "priority": 0,
            "image_url": item_result[f"{side}_image_url"],
            "mask_url": item_result["mask_url"],
            "gradient_string": item_result.get("gradient_string", ""),
            "pattern_overall_image_url": item_result.get("pattern_overall_image_url"),
            "pattern_print_image_url": item_result.get("pattern_print_image_url"),
        }

    def build_object_response(obj):
        """Process all items of one object and assemble its callback payload."""
        basic = obj["basic"]
        design_type = basic.get("design_type", "default")
        items_response = {
            "layers": [],
            "objectSign": obj.get("objectSign", ""),
            "requestId": request_id,
        }
        if basic["single_overall"] == "overall":
            item_results = [process_item(item, basic, design_type) for item in obj["items"]]
            layers = []
            for item_result in item_results:
                process_layer(item_result, layers)
            # Ascending priority; layers without a priority sort last.
            layers.sort(key=lambda lay: lay.get("priority", float("inf")))
            layers, new_size = update_base_size_priority(layers)
            items_response["layers"].extend(overall_layer_entry(lay) for lay in layers)
            # NOTE(review): unlike design_generate, merge designs are still composed with
            # synthesis() here; confirm whether merge() should be used for design_type == "merge".
            items_response["synthesis_url"] = synthesis(layers, new_size, basic)
        else:
            item_result = process_item(obj["items"][0], basic, design_type)
            items_response["layers"].extend(single_layer_entry(item_result, side) for side in ("front", "back"))
            items_response["synthesis_url"] = synthesis_single(item_result["front_image"], item_result["back_image"])
        return items_response

    def process_object(obj, callback_url):
        """Worker: build one object's payload and send it to the Java callback."""
        try:
            items_response = build_object_response(obj)
        except Exception:
            # Thread boundary: log instead of letting the thread die silently.
            # NOTE(review): no callback is sent for a failed object, so the Java side never hears
            # about it; consider posting an error payload.
            logger.exception(
                "design_generate_v2: object %r of request %s failed", obj.get("objectSign", ""), request_id
            )
            return
        # Send the result to the Java side.
        logger.info("java 回调 -> %s", callback_url)
        headers = {
            "Accept": "*/*",
            "Accept-Encoding": "gzip, deflate, br",
            "User-Agent": "PostmanRuntime-ApipostRuntime/1.1.0",
            "Connection": "keep-alive",
            "Content-Type": "application/json",
        }
        response = post_request(callback_url, json_data=items_response, headers=headers)
        if response is not None:
            logger.info(response.text)

    # Fire-and-forget: results are delivered through the callback, so threads are not joined.
    for obj in objects_data:
        threading.Thread(target=process_object, args=(obj, callback_url)).start()
def post_request(url, data=None, json_data=None, headers=None, auth=None, timeout=5):
    """Send a POST request and return the response, or ``None`` if it failed.

    Args:
        url: Target URL.
        data: Body passed as ``requests.post(data=...)`` (e.g. a dict sent form-encoded).
        json_data: Object serialized as the JSON body (``requests.post(json=...)``).
        headers: Request headers dict.
        auth: Authentication passed through to requests, e.g. ``('username', 'password')``
            for HTTP basic auth.
        timeout: Timeout in seconds.

    Returns:
        The ``requests.Response`` when the request succeeded and ``raise_for_status()`` did
        not raise; ``None`` on any ``requests.RequestException`` (connection error, timeout,
        or a 4xx/5xx status).
    """
    try:
        response = requests.post(url, data=data, json=json_data, headers=headers, auth=auth, timeout=timeout)
        response.raise_for_status()  # turn 4xx/5xx statuses into an HTTPError
    except requests.RequestException as e:
        # Best-effort by design: report through the module logger (not stdout) and let the
        # caller handle the None.
        logger.error("POST请求出错: %s", e)
        return None
    return response