Files
AiDA_Python/app/service/design/service.py
2024-07-19 15:10:28 +08:00

181 lines
7.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import concurrent.futures
import io
import cv2
from PIL import Image
from app.core.config import PRIORITY_DICT
from app.service.design.core.layer import Layer
from app.service.design.items import build_item
from app.service.design.utils.redis_utils import Redis
from app.service.design.utils.synthesis_item import synthesis, synthesis_single
from app.service.utils.decorator import RunTime
from app.service.utils.oss_client import oss_upload_image
def process_item(item, layers):
    """Run *item*'s processing step and register its output on *layers*.

    Returns the ``.size`` of the item's ``'body_image'`` when the item is the
    ``"mannequin"``; every other item yields None.
    """
    item.process()
    item.organize(layers)
    result = item.result
    if result['name'] != "mannequin":
        return None
    return result['body_image'].size
def update_progress(process_id, total):
    """Advance the Redis-stored progress percentage of *process_id* by one step.

    Intended to be called once per finished object out of *total*.  Each call
    adds ``100 // total`` to the stored value and caps the result at 100.  With
    a single object (or none) progress jumps straight to 100.

    Returns the value read from Redis *before* this update (whatever
    ``Redis.read`` returns; falsy when nothing has been stored yet).

    NOTE(review): this read-modify-write is not atomic.  ``generate`` runs
    ``process_object`` (which calls this) in a thread pool, so concurrent
    increments can be lost.  ``final_progress`` still forces 100 at the end.
    """
    redis_client = Redis()
    progress = redis_client.read(key=process_id)
    if total <= 1:
        # Single (or empty) batch: avoid division by zero and finish at once.
        new_value = 100
    else:
        step = 100 // total
        # A falsy stored value means this is the first finished object.
        # Clamp so repeated steps can never push the percentage past 100.
        new_value = min(int(progress) + step, 100) if progress else step
    redis_client.write(key=process_id, value=new_value)
    return progress
def final_progress(process_id):
    """Mark *process_id* as finished by storing 100 in Redis.

    Returns the progress value that was stored before it was overwritten.
    """
    client = Redis()
    previous = client.read(key=process_id)
    client.write(key=process_id, value=100)
    return previous
@RunTime
def generate(request_data):
    """Render every object of a design request and upload the produced images.

    Args:
        request_data: object exposing ``.dict()`` (presumably a pydantic model).
            The resulting dict needs a ``'process_id'`` key and an
            ``'objects'`` list of per-object configs.

    Each object config is handled concurrently by ``process_object``.  Once all
    objects are done, progress for ``process_id`` is set to 100 and the
    collected images are uploaded with ``process_images``.

    Returns:
        dict mapping each object's index in ``objects`` to its response dict.

    Raises:
        AssertionError: if ``process_id`` is missing.  Raised explicitly rather
            than via ``assert`` so the check survives ``python -O``.
    """
    request_data = request_data.dict()
    if "process_id" not in request_data:
        raise AssertionError("Need process_id parameters")
    objects = request_data['objects']
    process_id = request_data['process_id']
    total = len(objects)

    return_response = {}
    return_png_mask = []
    with concurrent.futures.ThreadPoolExecutor() as executor:
        # Submit one processing task per object, remembering its index.
        futures = {
            executor.submit(process_object, cfg, process_id, total): index
            for index, cfg in enumerate(objects)
        }
        # Collect results as they complete; the mapping restores each index.
        for future in concurrent.futures.as_completed(futures):
            items_response, uploaded_images = future.result()
            return_response[futures[future]] = items_response
            return_png_mask.extend(uploaded_images)

    final_progress(process_id)
    # Uploads run after all objects are rendered.  The returned URLs are not
    # needed by the response.
    process_images(return_png_mask)
    return return_response
def _collect_layer_uploads(layers):
    """Return upload jobs (``{'image_obj', 'image_url'}``) for the images, pattern images and masks of *layers*."""
    uploads = []
    for layer in layers:
        if layer.get('image') is not None:
            uploads.append({'image_obj': layer['image'], 'image_url': layer['image_url']})
        if layer.get('pattern_image') is not None:
            uploads.append({'image_obj': layer['pattern_image'], 'image_url': layer['pattern_image_url']})
        # A mask is only uploaded when it also has a destination URL.
        if layer.get('mask') is not None and layer.get('mask_url') is not None:
            uploads.append({'image_obj': layer['mask'], 'image_url': layer['mask_url']})
    return uploads


def _describe_layer(layer):
    """Build the API-facing summary of one composed layer (image data itself is not included)."""
    image = layer.get('image')
    return {
        'image_category': layer['name'],
        'position': layer['position'],
        'priority': layer.get("priority", None),
        'resize_scale': layer.get('resize_scale'),
        'image_size': None if image is None else image.size,
        'gradient_string': layer.get('gradient_string', ""),
        'mask_url': layer.get('mask_url'),
        'image_url': layer.get('image_url'),
        'pattern_image_url': layer.get('pattern_image_url'),
    }


def _process_overall(cfg, basic_info):
    """Render every item of an 'overall' object and compose them into one image.

    Returns ``(layer_summaries, synthesis_url, upload_jobs)``.
    """
    items = [build_item(x, default_args=basic_info) for x in cfg.get('items')]
    layer_store = Layer()
    body_size = None
    for item in items:
        # process_item only returns a value (the body image size) for the
        # mannequin.  Keep the first non-None size whatever its position.
        size = process_item(item, layer_store)
        if size is not None:
            body_size = size
    # Custom ordering by each layer's own 'priority', otherwise by PRIORITY_DICT.
    if basic_info.get('layer_order', False):
        layers = sorted(layer_store.layer, key=lambda s: s.get("priority", float('inf')))
    else:
        layers = sorted(layer_store.layer, key=lambda s: PRIORITY_DICT.get(s['name'], float('inf')))
    upload_jobs = _collect_layer_uploads(layers)
    synthesis_url = synthesis(layers, body_size)
    # Summaries are built after synthesis, matching the original order, in case
    # synthesis updates the layer dicts.
    summaries = [_describe_layer(layer) for layer in layers]
    return summaries, synthesis_url, upload_jobs


def _process_single(cfg, basic_info):
    """Render only the item whose type equals ``basic_info['switch_category']``.

    Returns ``(layer_summaries, synthesis_url)``, where the summaries hold one
    front layer and one back layer.

    Raises:
        AssertionError: if no item has the requested type.  Raised explicitly so
            the check survives ``python -O``.
    """
    switch_category = basic_info['switch_category']
    item_cfgs = cfg.get('items')
    if switch_category not in [x['type'] for x in item_cfgs]:
        raise AssertionError("Lack of switch_category parameters ")
    # Only the first matching item is rendered.
    item_cfg = next(x for x in item_cfgs if x['type'] == switch_category)
    item = build_item(item_cfg, default_args=basic_info)
    item.process()
    result = item.result
    front_image = result['front_image']
    back_image = result['back_image']
    gradient_string = result.get('gradient_string', "")
    summaries = [
        {
            'image_category': f"{result['name']}_front",
            # Each side reports the size of its own image.
            'image_size': front_image.size if front_image is not None else None,
            'position': None,
            'priority': 0,
            'image_url': result['front_image_url'],
            'mask_url': result['front_mask_url'],
            "gradient_string": gradient_string,
        },
        {
            'image_category': f"{result['name']}_back",
            'image_size': back_image.size if back_image is not None else None,
            'position': None,
            'priority': 0,
            'image_url': result['back_image_url'],
            'mask_url': result['back_mask_url'],
            "gradient_string": gradient_string,
        },
    ]
    return summaries, synthesis_single(front_image, back_image)


def process_object(cfg, process_id, total):
    """Render one design object and advance the request's progress.

    Args:
        cfg: object config with a ``'basic'`` dict and an ``'items'`` list of
            item configs.  ``cfg['basic']['single_overall']`` selects
            ``'overall'`` (compose all items) or ``'single'`` (render only the
            ``switch_category`` item).
        process_id: key under which progress is stored via ``update_progress``.
        total: number of objects in the request; it sizes the progress step.

    Returns:
        ``(items_response, uploaded_images)``.
        - ``items_response`` holds ``'layers'`` (a list of layer summaries) and,
          when the mode is recognised, ``'synthesis_url'``.
        - ``uploaded_images`` is the list of upload jobs for
          ``process_images``; it is only non-empty in 'overall' mode.

    Raises:
        AssertionError: in 'single' mode when no item matches ``switch_category``.
    """
    # Copy so that forcing ``debug`` off does not mutate the caller's config.
    basic_info = dict(cfg.get('basic'))
    mode = basic_info['single_overall']
    items_response = {'layers': []}
    uploaded_images = []
    if mode == 'overall':
        basic_info['debug'] = False
        layers, synthesis_url, uploaded_images = _process_overall(cfg, basic_info)
        items_response['layers'] = layers
        items_response['synthesis_url'] = synthesis_url
    elif mode == 'single':
        basic_info['debug'] = False
        layers, synthesis_url = _process_single(cfg, basic_info)
        items_response['layers'] = layers
        items_response['synthesis_url'] = synthesis_url
    # Unknown modes fall through with an empty response; progress still advances.
    update_progress(process_id, total)
    return items_response, uploaded_images
@RunTime
def process_images(images):
    """Upload every job in *images* concurrently via ``upload_images``.

    Returns the ``upload_images`` results in the same order as *images*.
    An exception raised by an upload propagates out of this call.
    """
    with concurrent.futures.ThreadPoolExecutor() as pool:
        pending = pool.map(upload_images, images)
        uploaded = list(pending)
    return uploaded
@RunTime
def upload_images(image_obj):
    """Encode one image as PNG and upload it to OSS.

    Args:
        image_obj: dict with two keys.
            - ``'image_url'``: of the form ``"<bucket>/<object name>"``.
            - ``'image_obj'``: either a PIL image, which is saved as-is, or an
              OpenCV/numpy image, which is treated as a mask (see below).

    Returns:
        The ``'image_url'`` that was uploaded to.

    Raises:
        ValueError: if ``'image_url'`` contains no ``"/"`` or PNG encoding fails.
    """
    image_url = image_obj['image_url']
    bucket_name, separator, object_name = image_url.partition("/")
    if not separator:
        raise ValueError(f"image_url must look like '<bucket>/<object>': {image_url!r}")

    image = image_obj['image_obj']
    if isinstance(image, Image.Image):
        buffer = io.BytesIO()
        image.save(buffer, format='PNG')
        image_bytes = buffer.getvalue()
    else:
        # Invert the 3-channel mask and convert it to 4 channels (BGRA, alpha
        # 255).  Pixels whose blue channel is 0 after inversion (black) become
        # fully transparent; the rest stay opaque.
        # NOTE(review): COLOR_BGR2BGRA requires a 3-channel input — confirm masks
        # are never single-channel.
        mask_inverted = cv2.bitwise_not(image)
        rgba_image = cv2.cvtColor(mask_inverted, cv2.COLOR_BGR2BGRA)
        rgba_image[rgba_image[:, :, 0] == 0] = [0, 0, 0, 0]
        ok, encoded = cv2.imencode('.png', rgba_image)
        if not ok:
            raise ValueError(f"PNG encoding failed for {image_url!r}")
        # Hand over plain bytes, the same type as the PIL branch.
        image_bytes = encoded.tobytes()

    oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes)
    return image_url