diff --git a/.gitignore b/.gitignore
index 1bf82fb..3f9e525 100644
--- a/.gitignore
+++ b/.gitignore
@@ -120,10 +120,10 @@ dmypy.json
 #runtime produce
 test
+seg_cache
 logs
 seg_result/
 seg_result
-*.png
 uwsgi
 *.yaml
 *.yml
@@ -133,5 +133,7 @@ Dockerfile
 app/logs
 app/logs/*
 *.log
-*.jpg
 /qodana.yaml
+.pth
+.pytorch
+*.png
\ No newline at end of file
diff --git a/app/api/api_brighten.py b/app/api/api_brighten.py
new file mode 100644
index 0000000..cc5a03f
--- /dev/null
+++ b/app/api/api_brighten.py
@@ -0,0 +1,59 @@
+import io
+import json
+import logging
+import time
+
+from PIL import ImageEnhance
+from fastapi import APIRouter, HTTPException
+
+from app.schemas.brighten import BrightenModel
+from app.schemas.response_template import ResponseModel
+from app.service.utils.oss_client import oss_get_image, oss_upload_image
+
+router = APIRouter()
+logger = logging.getLogger()
+
+
+def increase_brightness(img, factor):
+    enhancer = ImageEnhance.Brightness(img)
+    bright_img = enhancer.enhance(factor)
+    return bright_img
+
+
+@router.post("/brighten")
+async def brighten(request_item: BrightenModel):
+    """
+    Create a request body with the following parameters:
+    - **image_url**: URL of the image to brighten
+    - **brighten_value**: brightness factor; 1.0 keeps the original brightness, 1.5 increases it by 50%
+
+    Example parameters:
+    {
+        "image_url": "aida-users/89/relight_image/3850e17b-3efd-4597-90ef-2a7bcd1a1a0b-0-89.png",
+        "brighten_value": 1.5
+    }
+    """
+    try:
+        start_time = time.time()
+        logger.info(f"brighten request item is : @@@@@@:{json.dumps(request_item.dict())}")
+        image = oss_get_image(bucket=request_item.image_url.split('/')[0], object_name=request_item.image_url[request_item.image_url.find('/') + 1:], data_type="PIL")
+        new_image = increase_brightness(image, request_item.brighten_value)
+        image_data = io.BytesIO()
+        new_image.save(image_data, format='PNG')
+        image_data.seek(0)
+        image_bytes = image_data.read()
+        req = oss_upload_image(bucket=request_item.image_url.split('/')[0], object_name=request_item.image_url[request_item.image_url.find('/') + 1:], image_bytes=image_bytes)
+        brighten_url = f"{req.bucket_name}/{req.object_name}"
+        logger.info(f"run time is : {time.time() - start_time}")
+    except Exception as e:
+        logger.warning(f"brighten Run Exception @@@@@@:{e}")
+        raise HTTPException(status_code=404, detail=str(e))
+    return ResponseModel(data=brighten_url)
+
+
+if __name__ == '__main__':
+    request_item = BrightenModel(image_url="aida-users/89/relight_image/3850e17b-3efd-4597-90ef-2a7bcd1a1a0b-0-89.png",
+                                 brighten_value=1.5)
+    image = oss_get_image(bucket=request_item.image_url.split('/')[0], object_name=request_item.image_url[request_item.image_url.find('/') + 1:], data_type="PIL")
+    new_image = increase_brightness(image, request_item.brighten_value)
+    new_image.show()
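A quick smoke test for the new /brighten route (reviewer sketch, not part of the change; base URL and port are assumptions):

import requests

payload = {
    "image_url": "aida-users/89/relight_image/3850e17b-3efd-4597-90ef-2a7bcd1a1a0b-0-89.png",
    "brighten_value": 1.5,
}
# The router is mounted under the /api prefix in api_route.py.
resp = requests.post("http://localhost:8000/api/brighten", json=payload)
print(resp.json())  # ResponseModel wrapping the "bucket/object" path of the brightened image

Note that the handler uploads the result back to the same bucket and object name, so the original image is overwritten in place.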
diff --git a/app/api/api_design.py b/app/api/api_design.py
index 5ce6096..aa9fe43 100644
--- a/app/api/api_design.py
+++ b/app/api/api_design.py
@@ -1,13 +1,15 @@
 import json
 import logging
+import os
 
-from fastapi import APIRouter, HTTPException
+from fastapi import APIRouter, HTTPException, UploadFile, File, Form
 
-from app.schemas.design import DesignModel, DesignProgressModel, ModelProgressModel
+from app.schemas.design import DesignModel, DesignProgressModel, ModelProgressModel, DBGConfigModel
 from app.schemas.response_template import ResponseModel
 from app.service.design.model_process_service import model_transpose
-from app.service.design.service import generate
-from app.service.design.utils.redis_utils import Redis
+from app.service.design_batch.service import start_design_batch_generate
+from app.service.design_fast.design_generate import design_generate
+from app.service.design_fast.utils.redis_utils import Redis
 
 router = APIRouter()
 logger = logging.getLogger()
@@ -24,28 +26,28 @@ def design(request_data: DesignModel):
         "basic": {
           "body_point_test": {
             "waistband_right": [
-              203,
-              249
+              200,
+              241
             ],
             "hand_point_right": [
-              229,
-              343
+              223,
+              297
             ],
             "waistband_left": [
-              119,
-              248
+              112,
+              241
             ],
             "hand_point_left": [
-              97,
-              343
+              92,
+              305
             ],
             "shoulder_left": [
-              108,
-              107
+              99,
+              116
             ],
             "shoulder_right": [
-              212,
-              107
+              215,
+              116
             ]
           },
          "layer_order": true,
@@ -57,65 +59,33 @@ def design(request_data: DesignModel):
         },
         "items": [
           {
-            "businessId": 255303,
-            "color": "139 148 156",
-            "image_id": 95159,
+            "businessId": 270372,
+            "color": "30 28 28",
+            "image_id": 69780,
             "offset": [
               0,
               0
             ],
-            "path": "aida-users/89/sketch/c89d75f3-581f-4edd-9f8e-b08e84a2cbe7-3-89.png",
+            "path": "aida-sys-image/images/female/trousers/0825000630.jpg",
+            "seg_mask_url": "test/result.png",
             "print": {
-              "single": {
-                "location": [
-                  [
-                    200.0,
-                    200.0
-                  ]
-                ],
-                "print_angle_list": [
-                  0.0
-                ],
-                "print_path_list": [
-                  "aida-users/89/slogan_image/ce0b2423-9e5a-466f-9611-c254940a7819-1-89.png"
-                ],
-                "print_scale_list": [
-                  1.0
-                ]
+              "element": {
+                "element_angle_list": [],
+                "element_path_list": [],
+                "element_scale_list": [],
+                "location": []
               },
               "overall": {
-                "location": [
-                  [
-                    512.0,
-                    512.0
-                  ]
-                ],
-                "print_angle_list": [
-                  0.0
-                ],
-                "print_path_list": [
-                  "aida-users/89/print/468643b4-bc2d-41b2-9a16-79766606a2db-3-89.png"
-                ],
-                "print_scale_list": [
-                  1.0
-                ]
+                "location": [],
+                "print_angle_list": [],
+                "print_path_list": [],
+                "print_scale_list": []
              },
-              "element": {
-                "element_angle_list": [
-                  0.0
-                ],
-                "element_path_list": [
-                  "aida-users/88/designelements/Embroidery/a4d9605a-675e-4606-93e0-77ca6baaf55f.png"
-                ],
-                "element_scale_list": [
-                  0.2731036750637755
-                ],
-                "location": [
-                  [
-                    228.63694825464364,
-                    406.4843844199667
-                  ]
-                ]
+              "single": {
+                "location": [],
+                "print_angle_list": [],
+                "print_path_list": [],
+                "print_scale_list": []
              }
            },
            "priority": 10,
@@ -123,22 +93,101 @@ def design(request_data: DesignModel):
               1.0,
               1.0
             ],
-            "type": "Dress"
+            "type": "Trousers"
           },
           {
-            "body_path": "aida-sys-image/models/female/2e4815b9-1191-419d-94ed-5771239ca4a5.png",
-            "image_id": 67277,
+            "businessId": 270373,
+            "color": "30 28 28",
+            "image_id": 98243,
+            "offset": [
+              0,
+              0
+            ],
+            "path": "aida-sys-image/images/female/blouse/0902003811.jpg",
+            "seg_mask_url": "test/result.png",
+            "print": {
+              "element": {
+                "element_angle_list": [],
+                "element_path_list": [],
+                "element_scale_list": [],
+                "location": []
+              },
+              "overall": {
+                "location": [],
+                "print_angle_list": [],
+                "print_path_list": [],
+                "print_scale_list": []
+              },
+              "single": {
+                "location": [],
+                "print_angle_list": [],
+                "print_path_list": [],
+                "print_scale_list": []
+              }
+            },
+            "priority": 11,
+            "resize_scale": [
+              1.0,
+              1.0
+            ],
+            "type": "Blouse"
+          },
+          {
+            "businessId": 270374,
+            "color": "172 68 68",
+            "image_id": 98244,
+            "offset": [
+              0,
+              0
+            ],
+            "path": "aida-sys-image/images/female/outwear/0825000410.jpg",
+            "seg_mask_url": "test/result.png",
+            "print": {
+              "element": {
+                "element_angle_list": [],
+                "element_path_list": [],
+                "element_scale_list": [],
+                "location": []
+              },
+              "overall": {
+                "location": [],
+                "print_angle_list": [],
+                "print_path_list": [],
+                "print_scale_list": []
+              },
+              "single": {
+                "location": [],
+                "print_angle_list": [],
+                "print_path_list": [],
+                "print_scale_list": []
+              }
+            },
+            "priority": 12,
+            "resize_scale": [
+              1.0,
+              1.0
+            ],
+            "type": "Outwear"
+          },
+          {
+            "body_path": "aida-sys-image/models/female/5bdfe7ca-64eb-44e4-b03d-8e517520c795.png",
+            "image_id": 96090,
             "type": "Body"
           }
         ]
       }
     ],
-    "process_id": "89"
+    "process_id": "83"
   }
   """
     try:
         logger.info(f"design request item is : @@@@@@:{json.dumps(request_data.dict())}")
-        data = generate(request_data=request_data)
+        data = design_generate(request_data=request_data)
         logger.info(f"design response @@@@@@:{json.dumps(data)}")
     except Exception as e:
         logger.warning(f"design Run Exception @@@@@@:{e}")
@@ -193,3 +242,36 @@ def model_process(request_data: ModelProgressModel):
         logger.warning(f"model_process Run Exception @@@@@@:{e}")
         raise HTTPException(status_code=404, detail=str(e))
     return ResponseModel(data=data)
+
+
+# ##############################################################
+
+
+@router.post("/design_batch_generate")
+async def design_batch_generate(file: UploadFile = File(...),
+                                tasks_id: str = Form(...),
+                                user_id: str = Form(...),
+                                file_name: str = Form(...),
+                                total: int = Form(...)
+                                ):
+    dbg_config = DBGConfigModel(
+        tasks_id=tasks_id,
+        user_id=user_id,
+        file_name=file_name,
+        total=total
+    )
+    contents = await file.read()
+    file_name = file.filename
+    await save_request_file(contents, file_name)
+    return await start_design_batch_generate(dbg_config, contents)
+
+
+async def save_request_file(contents, file_name):
+    # Create the directory for saved files (if it does not exist)
+    save_dir = os.path.join(os.getcwd(), "design_batch", "request_data")
+    if not os.path.exists(save_dir):
+        os.makedirs(save_dir)
+    # Write the uploaded bytes to disk
+    file_path = os.path.join(save_dir, file_name)
+    with open(file_path, "wb") as f:
+        f.write(contents)
diff --git a/app/api/api_image2sketch.py b/app/api/api_image2sketch.py
new file mode 100644
index 0000000..cac7652
--- /dev/null
+++ b/app/api/api_image2sketch.py
@@ -0,0 +1,38 @@
+import json
+import logging
+
+from fastapi import APIRouter, HTTPException
+
+from app.schemas.image2sketch import Image2SketchModel
+from app.schemas.response_template import ResponseModel
+from app.service.lineart.service import LineArtService
+
+router = APIRouter()
+logger = logging.getLogger()
+
+
+@router.post("/image2sketch")
+def image2sketch(request_item: Image2SketchModel):
+    """
+    Create a request body with the following parameters:
+    - **image_url**: URL of the image to extract the sketch from
+    - **default_style**: original, 1, 2, 3, 4, 5
+    - **sketch_bucket**: bucket where the sketch is saved
+    - **sketch_name**: object name under which the sketch is saved
+
+    Example parameters:
+    {
+        "image_url": "test/image2sketch/real_Dress_3200fecdc83d0c556c2bd96aedbd7fbf.jpg_Img.jpg",
+        "default_style": 0,
+        "sketch_bucket": "test",
+        "sketch_name": "image2sketch/area_fill_img.png"
+    }
+    """
+    try:
+        logger.info(f"image2sketch request item is : @@@@@@:{json.dumps(request_item.dict())}")
+        service = LineArtService(request_item)
+        result_url = service.get_result()
+    except Exception as e:
+        logger.warning(f"image2sketch Run Exception @@@@@@:{e}")
+        raise HTTPException(status_code=404, detail=str(e))
+    return ResponseModel(data=result_url)
diff --git a/app/api/api_route.py b/app/api/api_route.py
index c2bd2d2..7ee774d 100644
--- a/app/api/api_route.py
+++ b/app/api/api_route.py
@@ -1,14 +1,15 @@
 from fastapi import APIRouter
 
-from app.api import api_test
-from app.api import api_super_resolution
-from app.api import api_generate_image
 from app.api import api_attribute_retrieve
-from app.api import api_design
+from app.api import api_brighten
 from app.api import api_chat_robot
-from app.api import api_prompt_generation
+from app.api import api_design
 from app.api import api_design_pre_processing
-
+from app.api import api_generate_image
+from app.api import api_image2sketch
+from app.api import api_prompt_generation
+from app.api import api_super_resolution
+from app.api import api_test
 
 router = APIRouter()
@@ -20,3 +21,5 @@ router.include_router(api_design.router, tags=['design'], prefix="/api")
 router.include_router(api_chat_robot.router, tags=['chat_robot'], prefix="/api")
 router.include_router(api_prompt_generation.router, tags=['prompt_generation'], prefix="/api")
 router.include_router(api_design_pre_processing.router, tags=['design_pre_processing'], prefix="/api")
+router.include_router(api_image2sketch.router, tags=['api_image2sketch'], prefix="/api")
+router.include_router(api_brighten.router, tags=['api_brighten'], prefix="/api")
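One inconsistency worth noting: the image2sketch docstring example passes "default_style": 0, while Image2SketchModel (added later in this diff) declares default_style: str. Under pydantic v1 the integer is silently coerced to a string; pydantic v2 would reject it by default. A minimal check:

from pydantic import BaseModel


class Image2SketchModel(BaseModel):
    default_style: str


print(repr(Image2SketchModel(default_style=0).default_style))  # '0' on pydantic v1; ValidationError on v2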
diff --git a/app/core/config.py b/app/core/config.py
index 3dc2132..35c12b7 100644
--- a/app/core/config.py
+++ b/app/core/config.py
@@ -24,11 +24,11 @@ DEBUG = False
 if DEBUG:
     LOGS_PATH = "logs/"
     CATEGORY_PATH = "service/attribute/config/descriptor/category/category_dis.csv"
-    # FACE_CLASSIFIER = "service/generate_image/utils/haarcascade_frontalface_alt.xml"
+    SEG_CACHE_PATH = "../seg_cache/"
 else:
     LOGS_PATH = "app/logs/"
     CATEGORY_PATH = "app/service/attribute/config/descriptor/category/category_dis.csv"
-    # FACE_CLASSIFIER = 'app/service/generate_image/utils/haarcascade_frontalface_alt.xml'
+    SEG_CACHE_PATH = "/seg_cache/"
 
 # RABBITMQ_ENV = ""  # production
 RABBITMQ_ENV = "-dev"  # development
@@ -64,7 +64,7 @@
 MILVUS_URL = "http://10.1.1.240:19530"
 MILVUS_TOKEN = "root:Milvus"
 MILVUS_ALIAS = "default"
-MILVUS_TABLE_KEYPOINT = "keypoint_cache"
+MILVUS_TABLE_KEYPOINT = "keypoint_cache_2"
 MILVUS_TABLE_SEG = "seg_cache"
 
 # MySQL configuration
diff --git a/app/design_batch/request_data/requests_data.json b/app/design_batch/request_data/requests_data.json
new file mode 100644
index 0000000..1dba8d1
--- /dev/null
+++ b/app/design_batch/request_data/requests_data.json
@@ -0,0 +1,90 @@
+{
+    "objects": [
+        {
+            "basic": {
+                "body_point_test": {
+                    "waistband_right": [
+                        201,
+                        242
+                    ],
+                    "hand_point_right": [
+                        222,
+                        312
+                    ],
+                    "waistband_left": [
+                        114,
+                        243
+                    ],
+                    "hand_point_left": [
+                        94,
+                        310
+                    ],
+                    "shoulder_left": [
+                        102,
+                        116
+                    ],
+                    "shoulder_right": [
+                        211,
+                        115
+                    ]
+                },
+                "layer_order": true,
+                "scale_bag": 0.7,
+                "scale_earrings": 0.16,
+                "self_template": true,
+                "single_overall": "overall",
+                "switch_category": ""
+            },
+            "items": [
+                {
+                    "businessId": 264931,
+                    "color": "145 220 232",
+                    "image_id": 96844,
+                    "offset": [
+                        0,
+                        0
+                    ],
+                    "path": "aida-users/87/sketch/2aa7aad5-74bb-41fa-9cdf-f06611b3e89a-2-87.png",
+                    "print": {
+                        "element": {
+                            "element_angle_list": [],
+                            "element_path_list": [],
+                            "element_scale_list": [],
+                            "location": []
+                        },
+                        "overall": {
+                            "location": [],
+                            "print_angle_list": [],
+                            "print_path_list": [],
+                            "print_scale_list": []
+                        },
+                        "single": {
+                            "location": [],
+                            "print_angle_list": [],
+                            "print_path_list": [],
+                            "print_scale_list": []
+                        }
+                    },
+                    "priority": 10,
+                    "resize_scale": [
+                        1.0,
+                        1.0
+                    ],
+                    "type": "Dress"
+                },
+                {
+                    "body_path": "aida-sys-image/models/female/79805ec3-3f01-466d-91e0-36028d079699.png",
+                    "image_id": 95444,
+                    "type": "Body"
+                }
+            ]
+        }
+
+    ],
+    "process_id": "87",
+    "tasks_id": ,
+}
+
+
+// use OpenAI-style JSONL
+//
\ No newline at end of file
diff --git a/app/schemas/brighten.py b/app/schemas/brighten.py
new file mode 100644
index 0000000..e407905
--- /dev/null
+++ b/app/schemas/brighten.py
@@ -0,0 +1,6 @@
+from pydantic import BaseModel
+
+
+class BrightenModel(BaseModel):
+    image_url: str
+    brighten_value: float
diff --git a/app/schemas/design.py b/app/schemas/design.py
index edcc392..7ebd8e6 100644
--- a/app/schemas/design.py
+++ b/app/schemas/design.py
@@ -1,50 +1,6 @@
 from pydantic import BaseModel
 
-
-# class BodyPointModel(BaseModel):
-#     waistband_right: list[int]
-#     hand_point_right: list[int]
-#     waistband_left: list[int]
-#     hand_point_left: list[int]
-#     shoulder_left: list[int]
-#     shoulder_right: list[int]
-#
-#
-# class BasicModel(BaseModel):
-#     body_point: BodyPointModel
-#     layer_order: bool
-#     scale_bag: float
-#     scale_earrings: float
-#     self_template: bool
-#     single_overall: str
-#     switch_category: str
-#     body_path: str
-#
-#
-# class PrintModel(BaseModel):
-#     if_single: bool
-#     print_path_list: list[str]
-#
-#
-# class ItemModel(BaseModel):
-#     color: str
-#     image_id: str
-#     offset: list[int]
-#     path: str
-#     print: PrintModel
-#     resize_scale: float
-#     type: str
-#
-#
-# class CollocationModel(BaseModel):
-#     basic: BasicModel
-#     item: list[ItemModel]
-#
-#
-# class DesignModel(BaseModel):
-#     object: list[CollocationModel]
-#     process_id: str
-
 class DesignModel(BaseModel):
     objects: list[dict]
     process_id: str
@@ -56,3 +12,10 @@ class DesignProgressModel(BaseModel):
 
 class ModelProgressModel(BaseModel):
     model_path: str
+
+
+class DBGConfigModel(BaseModel):
+    tasks_id: str
+    user_id: str
+    file_name: str
+    total: int
diff --git a/app/schemas/image2sketch.py b/app/schemas/image2sketch.py
new file mode 100644
index 0000000..dbbbbb5
--- /dev/null
+++ b/app/schemas/image2sketch.py
@@ -0,0 +1,8 @@
+from pydantic import BaseModel
+
+
+class Image2SketchModel(BaseModel):
+    image_url: str
+    default_style: str
+    sketch_bucket: str
+    sketch_name: str
diff --git a/app/service/design/fastapi_request.json b/app/service/design/fastapi_request.json deleted file mode 100644 index 8c27a56..0000000 --- a/app/service/design/fastapi_request.json +++ /dev/null @@ -1,771 +0,0 @@ -{ - "objects": [ - { - "basic": { - "body_point_test": { - "waistband_right": [ - 336, - 264 - ], - "hand_point_right": [ - 350, - 303 - ], - "waistband_left": [ - 245, - 274 - ], - "hand_point_left": [ - 219, - 315 - ], - "shoulder_left": [ - 227, - 155 - ], - "shoulder_right": [ - 338, - 149 - ] - }, - "layer_order": false, - "scale_bag": 0.7, - "scale_earrings": 0.16, - "self_template": true, - "single_overall": "overall", - "switch_category": "" - }, - "items": [ - { - "businessId": 493827, - "color": "127 61 21", - "elementId": 493827, - "icon": "none", - "image_id": 110201, - "offset": [ - 1, - 1 - ], - "path": "aida-users/31/sketch/62302527-2910-4740-808d-2cb8221daa34-3-31.png", - "print": { - "IfSingle": false, - "print_path_list": [] - }, - "resize_scale": [ - 1.0, - 1.0 - ], - "type": "Dress" - }, - { - "body_path": "aida-users/31/models/female/845046c7-4f62-4f54-a4a9-c26d49c6969335b5b3a9-d335-4871-a46c-3cc3caf07da259629dfd1f1f555a2e2a9def7e719366.png", - "image_id": 82966, - "offset": [ - 1, - 1 - ], - "resize_scale": [ - 1.0, - 1.0 - ], - "type": "Body" - } - ] - }, - { - "basic": { - "body_point_test": { - "waistband_right": [ - 336, - 264 - ], - "hand_point_right": [ - 350, - 303 - ], - "waistband_left": [ - 245, - 274 - ], - "hand_point_left": [ - 219, - 315 - ], - "shoulder_left": [ - 227, - 155 - ], - "shoulder_right": [ - 338, - 149 - ] - }, - "layer_order": false, - "scale_bag": 0.7, - "scale_earrings": 0.16, - 
"self_template": true, - "single_overall": "overall", - "switch_category": "" - }, - "items": [ - { - "color": "27 25 23", - "icon": "none", - "image_id": 110202, - "offset": [ - 1, - 1 - ], - "path": "aida-sys-image/images/female/skirt/0916000602.jpg", - "print": { - "IfSingle": false, - "print_path_list": [] - }, - "resize_scale": [ - 1.0, - 1.0 - ], - "type": "Skirt" - }, - { - "businessId": 493825, - "color": "229 214 200", - "elementId": 493825, - "icon": "none", - "image_id": 107101, - "offset": [ - 1, - 1 - ], - "path": "aida-users/31/sketchboard/female/Blouse/de8f5656-d7ae-4642-bc90-f7f9d85da09b.jpg", - "print": { - "IfSingle": false, - "print_path_list": [] - }, - "resize_scale": [ - 1.0, - 1.0 - ], - "type": "Blouse" - }, - { - "businessId": 493824, - "color": "76 124 124", - "elementId": 493824, - "icon": "none", - "image_id": 104522, - "offset": [ - 1, - 1 - ], - "path": "aida-users/31/sketch/3e82214a-0191-11ef-96d2-b48351119060_1.png", - "print": { - "IfSingle": false, - "print_path_list": [] - }, - "resize_scale": [ - 1.0, - 1.0 - ], - "type": "Outwear" - }, - { - "body_path": "aida-users/31/models/female/845046c7-4f62-4f54-a4a9-c26d49c6969335b5b3a9-d335-4871-a46c-3cc3caf07da259629dfd1f1f555a2e2a9def7e719366.png", - "image_id": 82966, - "offset": [ - 1, - 1 - ], - "resize_scale": [ - 1.0, - 1.0 - ], - "type": "Body" - } - ] - }, - { - "basic": { - "body_point_test": { - "waistband_right": [ - 336, - 264 - ], - "hand_point_right": [ - 350, - 303 - ], - "waistband_left": [ - 245, - 274 - ], - "hand_point_left": [ - 219, - 315 - ], - "shoulder_left": [ - 227, - 155 - ], - "shoulder_right": [ - 338, - 149 - ] - }, - "layer_order": false, - "scale_bag": 0.7, - "scale_earrings": 0.16, - "self_template": true, - "single_overall": "overall", - "switch_category": "" - }, - "items": [ - { - "color": "229 214 200", - "icon": "none", - "image_id": 110203, - "offset": [ - 1, - 1 - ], - "path": "aida-sys-image/images/female/blouse/0825001576.jpg", - "print": { - "IfSingle": false, - "print_path_list": [] - }, - "resize_scale": [ - 1.0, - 1.0 - ], - "type": "Blouse" - }, - { - "color": "76 124 124", - "icon": "none", - "image_id": 96071, - "offset": [ - 1, - 1 - ], - "path": "aida-sys-image/images/female/skirt/903000097.jpg", - "print": { - "IfSingle": false, - "print_path_list": [] - }, - "resize_scale": [ - 1.0, - 1.0 - ], - "type": "Skirt" - }, - { - "color": "209 125 29", - "icon": "none", - "image_id": 93798, - "offset": [ - 1, - 1 - ], - "path": "aida-sys-image/images/female/outwear/outwear_p4_561.jpg", - "print": { - "IfSingle": false, - "print_path_list": [] - }, - "resize_scale": [ - 1.0, - 1.0 - ], - "type": "Outwear" - }, - { - "body_path": "aida-users/31/models/female/845046c7-4f62-4f54-a4a9-c26d49c6969335b5b3a9-d335-4871-a46c-3cc3caf07da259629dfd1f1f555a2e2a9def7e719366.png", - "image_id": 82966, - "offset": [ - 1, - 1 - ], - "resize_scale": [ - 1.0, - 1.0 - ], - "type": "Body" - } - ] - }, - { - "basic": { - "body_point_test": { - "waistband_right": [ - 336, - 264 - ], - "hand_point_right": [ - 350, - 303 - ], - "waistband_left": [ - 245, - 274 - ], - "hand_point_left": [ - 219, - 315 - ], - "shoulder_left": [ - 227, - 155 - ], - "shoulder_right": [ - 338, - 149 - ] - }, - "layer_order": false, - "scale_bag": 0.7, - "scale_earrings": 0.16, - "self_template": true, - "single_overall": "overall", - "switch_category": "" - }, - "items": [ - { - "businessId": 493824, - "color": "209 125 29", - "elementId": 493824, - "icon": "none", - "image_id": 104522, - "offset": [ - 1, - 1 - ], 
- "path": "aida-users/31/sketch/3e82214a-0191-11ef-96d2-b48351119060_1.png", - "print": { - "IfSingle": false, - "print_path_list": [] - }, - "resize_scale": [ - 1.0, - 1.0 - ], - "type": "Outwear" - }, - { - "color": "118 123 115", - "icon": "none", - "image_id": 110204, - "offset": [ - 1, - 1 - ], - "path": "aida-sys-image/images/female/blouse/0902000457.jpg", - "print": { - "IfSingle": false, - "print_path_list": [] - }, - "resize_scale": [ - 1.0, - 1.0 - ], - "type": "Blouse" - }, - { - "color": "118 123 115", - "icon": "none", - "image_id": 79259, - "offset": [ - 1, - 1 - ], - "path": "aida-sys-image/images/female/trousers/826000094.jpg", - "print": { - "IfSingle": false, - "print_path_list": [] - }, - "resize_scale": [ - 1.0, - 1.0 - ], - "type": "Trousers" - }, - { - "body_path": "aida-users/31/models/female/845046c7-4f62-4f54-a4a9-c26d49c6969335b5b3a9-d335-4871-a46c-3cc3caf07da259629dfd1f1f555a2e2a9def7e719366.png", - "image_id": 82966, - "offset": [ - 1, - 1 - ], - "resize_scale": [ - 1.0, - 1.0 - ], - "type": "Body" - } - ] - }, - { - "basic": { - "body_point_test": { - "waistband_right": [ - 336, - 264 - ], - "hand_point_right": [ - 350, - 303 - ], - "waistband_left": [ - 245, - 274 - ], - "hand_point_left": [ - 219, - 315 - ], - "shoulder_left": [ - 227, - 155 - ], - "shoulder_right": [ - 338, - 149 - ] - }, - "layer_order": false, - "scale_bag": 0.7, - "scale_earrings": 0.16, - "self_template": true, - "single_overall": "overall", - "switch_category": "" - }, - "items": [ - { - "color": "127 61 21", - "icon": "none", - "image_id": 96038, - "offset": [ - 1, - 1 - ], - "path": "aida-sys-image/images/female/dress/0902003549.jpg", - "print": { - "IfSingle": false, - "print_path_list": [] - }, - "resize_scale": [ - 1.0, - 1.0 - ], - "type": "Dress" - }, - { - "body_path": "aida-users/31/models/female/845046c7-4f62-4f54-a4a9-c26d49c6969335b5b3a9-d335-4871-a46c-3cc3caf07da259629dfd1f1f555a2e2a9def7e719366.png", - "image_id": 82966, - "offset": [ - 1, - 1 - ], - "resize_scale": [ - 1.0, - 1.0 - ], - "type": "Body" - } - ] - }, - { - "basic": { - "body_point_test": { - "waistband_right": [ - 336, - 264 - ], - "hand_point_right": [ - 350, - 303 - ], - "waistband_left": [ - 245, - 274 - ], - "hand_point_left": [ - 219, - 315 - ], - "shoulder_left": [ - 227, - 155 - ], - "shoulder_right": [ - 338, - 149 - ] - }, - "layer_order": false, - "scale_bag": 0.7, - "scale_earrings": 0.16, - "self_template": true, - "single_overall": "overall", - "switch_category": "" - }, - "items": [ - { - "businessId": 493822, - "color": "127 61 21", - "elementId": 493822, - "icon": "none", - "image_id": 62309, - "offset": [ - 1, - 1 - ], - "path": "aida-users/31/sketchboard/female/trousers/c37c2ea6-8955-4b40-8339-c737e672ca3d.jpg", - "print": { - "IfSingle": false, - "print_path_list": [] - }, - "resize_scale": [ - 1.0, - 1.0 - ], - "type": "Trousers" - }, - { - "businessId": 493825, - "color": "118 123 115", - "elementId": 493825, - "icon": "none", - "image_id": 107101, - "offset": [ - 1, - 1 - ], - "path": "aida-users/31/sketchboard/female/Blouse/de8f5656-d7ae-4642-bc90-f7f9d85da09b.jpg", - "print": { - "IfSingle": false, - "print_path_list": [] - }, - "resize_scale": [ - 1.0, - 1.0 - ], - "type": "Blouse" - }, - { - "body_path": "aida-users/31/models/female/845046c7-4f62-4f54-a4a9-c26d49c6969335b5b3a9-d335-4871-a46c-3cc3caf07da259629dfd1f1f555a2e2a9def7e719366.png", - "image_id": 82966, - "offset": [ - 1, - 1 - ], - "resize_scale": [ - 1.0, - 1.0 - ], - "type": "Body" - } - ] - }, - { - "basic": { - 
"body_point_test": { - "waistband_right": [ - 336, - 264 - ], - "hand_point_right": [ - 350, - 303 - ], - "waistband_left": [ - 245, - 274 - ], - "hand_point_left": [ - 219, - 315 - ], - "shoulder_left": [ - 227, - 155 - ], - "shoulder_right": [ - 338, - 149 - ] - }, - "layer_order": false, - "scale_bag": 0.7, - "scale_earrings": 0.16, - "self_template": true, - "single_overall": "overall", - "switch_category": "" - }, - "items": [ - { - "businessId": 493826, - "color": "127 61 21", - "elementId": 493826, - "icon": "none", - "image_id": 107105, - "offset": [ - 1, - 1 - ], - "path": "aida-users/31/sketchboard/female/Skirt/58710352-6301-450d-b69a-fb2922b5429a.png", - "print": { - "IfSingle": false, - "print_path_list": [] - }, - "resize_scale": [ - 1.0, - 1.0 - ], - "type": "Skirt" - }, - { - "color": "118 123 115", - "icon": "none", - "image_id": 79114, - "offset": [ - 1, - 1 - ], - "path": "aida-sys-image/images/female/blouse/903000169.jpg", - "print": { - "IfSingle": false, - "print_path_list": [] - }, - "resize_scale": [ - 1.0, - 1.0 - ], - "type": "Blouse" - }, - { - "color": "229 214 200", - "icon": "none", - "image_id": 90573, - "offset": [ - 1, - 1 - ], - "path": "aida-sys-image/images/female/outwear/0628000541.jpg", - "print": { - "IfSingle": false, - "print_path_list": [] - }, - "resize_scale": [ - 1.0, - 1.0 - ], - "type": "Outwear" - }, - { - "body_path": "aida-users/31/models/female/845046c7-4f62-4f54-a4a9-c26d49c6969335b5b3a9-d335-4871-a46c-3cc3caf07da259629dfd1f1f555a2e2a9def7e719366.png", - "image_id": 82966, - "offset": [ - 1, - 1 - ], - "resize_scale": [ - 1.0, - 1.0 - ], - "type": "Body" - } - ] - }, - { - "basic": { - "body_point_test": { - "waistband_right": [ - 336, - 264 - ], - "hand_point_right": [ - 350, - 303 - ], - "waistband_left": [ - 245, - 274 - ], - "hand_point_left": [ - 219, - 315 - ], - "shoulder_left": [ - 227, - 155 - ], - "shoulder_right": [ - 338, - 149 - ] - }, - "layer_order": false, - "scale_bag": 0.7, - "scale_earrings": 0.16, - "self_template": true, - "single_overall": "overall", - "switch_category": "" - }, - "items": [ - { - "color": "229 214 200", - "icon": "none", - "image_id": 110205, - "offset": [ - 1, - 1 - ], - "path": "aida-sys-image/images/female/trousers/0916000217.jpg", - "print": { - "IfSingle": false, - "print_path_list": [] - }, - "resize_scale": [ - 1.0, - 1.0 - ], - "type": "Trousers" - }, - { - "businessId": 493825, - "color": "209 125 29", - "elementId": 493825, - "icon": "none", - "image_id": 107101, - "offset": [ - 1, - 1 - ], - "path": "aida-users/31/sketchboard/female/Blouse/de8f5656-d7ae-4642-bc90-f7f9d85da09b.jpg", - "print": { - "IfSingle": false, - "print_path_list": [] - }, - "resize_scale": [ - 1.0, - 1.0 - ], - "type": "Blouse" - }, - { - "body_path": "aida-users/31/models/female/845046c7-4f62-4f54-a4a9-c26d49c6969335b5b3a9-d335-4871-a46c-3cc3caf07da259629dfd1f1f555a2e2a9def7e719366.png", - "image_id": 82966, - "offset": [ - 1, - 1 - ], - "resize_scale": [ - 1.0, - 1.0 - ], - "type": "Body" - } - ] - } - ], - "process_id": "6878547032381675" -} \ No newline at end of file diff --git a/app/service/design/items/bottom.py b/app/service/design/items/bottom.py index eb575fb..e01ec02 100644 --- a/app/service/design/items/bottom.py +++ b/app/service/design/items/bottom.py @@ -10,6 +10,7 @@ class Bottom(Clothing): dict(type='LoadImageFromFile', path=kwargs['path'], color=kwargs['color'], print_dict=kwargs['print']), dict(type='KeypointDetection'), dict(type='ContourDetection'), + # dict(type='Segmentation'), 
             dict(type='Painting', painting_flag=True),
             dict(type='PrintPainting', print_flag=True),
             dict(type='Scaling'),
diff --git a/app/service/design/items/clothing.py b/app/service/design/items/clothing.py
index f9f9561..953cecf 100644
--- a/app/service/design/items/clothing.py
+++ b/app/service/design/items/clothing.py
@@ -30,14 +30,15 @@ class Clothing(object):
                 image=self.result["front_image"],
                 # mask_image=self.result['front_mask_image'],
                 image_url=self.result['front_image_url'],
-                mask_url=self.result['front_mask_url'],
+                mask_url=self.result['mask_url'],
                 sacle=self.result['scale'],
                 clothes_keypoint=self.result['clothes_keypoint'],
                 position=start_point,
                 resize_scale=self.result["resize_scale"],
                 mask=cv2.resize(self.result['mask'], self.result["front_image"].size),
                 gradient_string=self.result['gradient_string'] if 'gradient_string' in self.result.keys() else "",
-                pattern_image_url=self.result['pattern_image_url']
+                pattern_image_url=self.result['pattern_image_url'],
+                pattern_image=self.result['pattern_image']
             )
             layer.insert(front_layer)
@@ -47,14 +48,14 @@ class Clothing(object):
                 image=self.result["back_image"],
                 # mask_image=self.result['back_mask_image'],
                 image_url=self.result['back_image_url'],
-                mask_url=self.result['back_mask_url'],
+                mask_url=self.result['mask_url'],
                 sacle=self.result['scale'],
                 clothes_keypoint=self.result['clothes_keypoint'],
                 position=start_point,
                 resize_scale=self.result["resize_scale"],
                 mask=cv2.resize(self.result['mask'], self.result["front_image"].size),
                 gradient_string=self.result['gradient_string'] if 'gradient_string' in self.result.keys() else "",
-                pattern_image_url=self.result['pattern_image_url']
+                pattern_image_url=self.result['pattern_image_url'],
             )
             layer.insert(back_layer)
diff --git a/app/service/design/items/pipelines/contour_detection.py b/app/service/design/items/pipelines/contour_detection.py
index 018dbca..487d2d6 100644
--- a/app/service/design/items/pipelines/contour_detection.py
+++ b/app/service/design/items/pipelines/contour_detection.py
@@ -43,7 +43,8 @@ class ContourDetection(object):
             result['mask'] = Mask
         else:
             result['mask'] = cv2.bitwise_and(Mask, result['pre_mask'])
-
+        result['front_mask'] = result['mask']
+        result['back_mask'] = result['mask']
         return result
 
     @staticmethod
diff --git a/app/service/design/items/pipelines/keypoints.py b/app/service/design/items/pipelines/keypoints.py
index 1f53ced..fded7de 100644
--- a/app/service/design/items/pipelines/keypoints.py
+++ b/app/service/design/items/pipelines/keypoints.py
@@ -5,6 +5,7 @@ import numpy as np
 from pymilvus import MilvusClient
 
 from app.core.config import *
+from app.service.utils.decorator import RunTime, ClassCallRunTime
 from ..builder import PIPELINES
 from ...utils.design_ensemble import get_keypoint_result
@@ -27,7 +28,7 @@ class KeypointDetection(object):
         # self.client.close()
         # print(f"client close time : {time.time() - start_time}")
 
-    # @ RunTime
+    # @ClassCallRunTime
     def __call__(self, result):
         # logging.info("KeypointDetection run ")
         if result['name'] in ['blouse', 'skirt', 'dress', 'outwear', 'trousers', 'tops', 'bottoms']:  # check whether a cached entry of the same category exists: read it directly on a hit, otherwise run inference and update the cache
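The keypoint cache lives in the Milvus collection named by MILVUS_TABLE_KEYPOINT (renamed to "keypoint_cache_2" in config.py above). The lookup below is illustrative only; the real query and field names live in get_keypoint_result and are not shown in this diff:

from pymilvus import MilvusClient

client = MilvusClient(uri="http://10.1.1.240:19530", token="root:Milvus")
# Hypothetical field names; on a hit the stored keypoints are read by image_id.
rows = client.query(collection_name="keypoint_cache_2",
                    filter="image_id == 96844",
                    output_fields=["keypoints"])
print(rows[0]["keypoints"] if rows else "miss -> run keypoint inference, then update the cache")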
diff --git a/app/service/design/items/pipelines/loading.py b/app/service/design/items/pipelines/loading.py
index d792646..04dc4d8 100644
--- a/app/service/design/items/pipelines/loading.py
+++ b/app/service/design/items/pipelines/loading.py
@@ -12,6 +12,7 @@ class LoadImageFromFile(object):
         self.print_dict = print_dict
         # self.minio_client = Minio(f"{MINIO_URL}", access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
 
+    # @ClassCallRunTime
     def __call__(self, result):
         result['image'], result['pre_mask'] = self.read_image(self.path)
         result['gray'] = cv2.cvtColor(result['image'], cv2.COLOR_BGR2GRAY)
@@ -45,15 +46,18 @@ class LoadImageFromFile(object):
     @staticmethod
     def read_image(image_path):
         image_mask = None
-        # file = self.minio_client.get_object(image_path.split("/", 1)[0], image_path.split("/", 1)[1]).data
-        # image = cv2.imdecode(np.frombuffer(file, np.uint8), 1)
-
         image = oss_get_image(bucket=image_path.split("/", 1)[0], object_name=image_path.split("/", 1)[1], data_type="cv2")
         if len(image.shape) == 2:
             image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
         if image.shape[2] == 4:  # four channels: the alpha channel carries the mask
             image_mask = image[:, :, 3]
             image = image[:, :, :3]
+
+        if image.shape[0] <= 50 and image.shape[1] <= 50:  # element-wise check; a plain tuple compare would be lexicographic
+            # compute the new size
+            new_size = (image.shape[1] * 2, image.shape[0] * 2)
+            # upscale small images 2x
+            image = cv2.resize(image, new_size, interpolation=cv2.INTER_LINEAR)
         return image, image_mask
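Tuple comparison in Python is lexicographic, not element-wise, which is why the small-image guard above checks each dimension explicitly; a one-liner makes the difference obvious:

print((10, 400) <= (50, 50))   # True: 10 < 50 decides the comparison, the width is never checked
print(10 <= 50 and 400 <= 50)  # False: the element-wise intent behind "image is small"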
diff --git a/app/service/design/items/pipelines/painting.py b/app/service/design/items/pipelines/painting.py
index 0fd2897..993697c 100644
--- a/app/service/design/items/pipelines/painting.py
+++ b/app/service/design/items/pipelines/painting.py
@@ -1,3 +1,4 @@
+import logging
 import random
 
 import cv2
@@ -7,13 +8,15 @@ from PIL import Image
 from app.service.utils.oss_client import oss_get_image
 from ..builder import PIPELINES
 
+logger = logging.getLogger()
+
 
 @PIPELINES.register_module()
 class Painting(object):
     def __init__(self, painting_flag=True):
         self.painting_flag = painting_flag
 
-    # @ RunTime
+    # @ClassCallRunTime
     def __call__(self, result):
         if result['name'] not in ['hairstyle', 'earring'] and self.painting_flag and result['color'] != 'none':
             dim_image_h, dim_image_w = result['image'].shape[0:2]
@@ -86,7 +89,7 @@ class PrintPainting(object):
     def __init__(self, print_flag=True):
         self.print_flag = print_flag
 
-    # @ RunTime
+    # @ClassCallRunTime
     def __call__(self, result):
         single_print = result['print']['single']
         overall_print = result['print']['overall']
@@ -236,7 +239,6 @@ class PrintPainting(object):
 
             print_background = cv2.cvtColor(np.array(source_image_pil), cv2.COLOR_RGBA2BGR)
             mask_background = cv2.cvtColor(np.array(source_image_pil_mask), cv2.COLOR_RGBA2BGR)
-            print(1)
         else:
             mask = self.get_mask_inv(image)
             mask = np.expand_dims(mask, axis=2)
diff --git a/app/service/design/items/pipelines/scale.py b/app/service/design/items/pipelines/scale.py
index d101530..edd98c9 100644
--- a/app/service/design/items/pipelines/scale.py
+++ b/app/service/design/items/pipelines/scale.py
@@ -2,6 +2,7 @@ import math
 
 import cv2
 
+from app.service.utils.decorator import ClassCallRunTime
 from ..builder import PIPELINES
 
 
@@ -10,7 +11,7 @@ class Scaling(object):
     def __init__(self):
         pass
 
-    # @ RunTime
+    # @ClassCallRunTime
     def __call__(self, result):
         if result['keypoint'] in ['waistband', 'shoulder', 'head_point']:
             # milvus_db_keypoint_cache
diff --git a/app/service/design/items/pipelines/segmentation.py b/app/service/design/items/pipelines/segmentation.py
index d9f8ac0..19eb1fd 100644
--- a/app/service/design/items/pipelines/segmentation.py
+++ b/app/service/design/items/pipelines/segmentation.py
@@ -1,14 +1,71 @@
+import logging
+import os
+
+import cv2
+import numpy as np
+
+from app.core.config import SEG_CACHE_PATH
+from app.service.utils.decorator import ClassCallRunTime
+from app.service.utils.oss_client import oss_get_image
 from ..builder import PIPELINES
 from ...utils.design_ensemble import get_seg_result
 
+logger = logging.getLogger()
+
 
 @PIPELINES.register_module()
 class Segmentation(object):
-    def __init__(self, device='cpu', show=False, debug=None):
-        self.show = show
-        self.device = device
-        self.debug = debug
 
+    @ClassCallRunTime
     def __call__(self, result):
-        result['seg_result'] = get_seg_result(result["image_id"], result['image'])
+        if "seg_mask_url" in result.keys() and result['seg_mask_url'] != "":
+            seg_mask = oss_get_image(bucket=result['seg_mask_url'].split('/')[0], object_name=result['seg_mask_url'][result['seg_mask_url'].find('/') + 1:], data_type="cv2")
+            seg_mask = cv2.resize(seg_mask, (result['img_shape'][1], result['img_shape'][0]), interpolation=cv2.INTER_NEAREST)
+            # convert the color space to RGB (OpenCV loads BGR by default)
+            image_rgb = cv2.cvtColor(seg_mask, cv2.COLOR_BGR2RGB)
+
+            r, g, b = cv2.split(image_rgb)
+            red_mask = r > g
+            green_mask = g > r
+
+            # build the red (front) and green (back) masks
+            result['front_mask'] = np.array(red_mask, dtype=np.uint8) * 255
+            result['back_mask'] = np.array(green_mask, dtype=np.uint8) * 255
+            result['mask'] = result['front_mask'] + result['back_mask']
+        else:
+            # check the local seg cache first
+            _, seg_result = self.load_seg_result(result["image_id"])
+            result['seg_result'] = seg_result
+            if not _:
+                # cache miss: run inference to get the seg result
+                seg_result = get_seg_result(result["image_id"], result['image'])[0]
+                self.save_seg_result(seg_result, result['image_id'])
+            # split into front and back pieces
+            temp_front = seg_result == 1.0
+            result['front_mask'] = (255 * (temp_front + 0).astype(np.uint8))
+            temp_back = seg_result == 2.0
+            result['back_mask'] = (255 * (temp_back + 0).astype(np.uint8))
+            result['mask'] = result['front_mask'] + result['back_mask']
         return result
+
+    @staticmethod
+    def save_seg_result(seg_result, image_id):
+        file_path = f"{SEG_CACHE_PATH}{image_id}.npy"
+        try:
+            np.save(file_path, seg_result)
+            logger.info(f"seg cache saved: {os.path.abspath(file_path)}")
+        except Exception as e:
+            logger.error(f"failed to save seg cache: {e}")
+
+    @staticmethod
+    def load_seg_result(image_id):
+        file_path = f"{SEG_CACHE_PATH}{image_id}.npy"
+        try:
+            seg_result = np.load(file_path)
+            return True, seg_result
+        except FileNotFoundError:
+            logger.warning("seg cache file not found")
+            return False, None
+        except Exception as e:
+            logger.error(f"failed to load seg cache: {e}")
+            return False, None
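A quick self-check of the red-front / green-back convention decoded above (and produced by Split later in this diff); the 1x2 test image is hypothetical:

import cv2
import numpy as np

# One red and one green pixel, in OpenCV's BGR channel order.
seg_mask = np.array([[[0, 0, 255], [0, 255, 0]]], dtype=np.uint8)
r, g, b = cv2.split(cv2.cvtColor(seg_mask, cv2.COLOR_BGR2RGB))
front_mask = np.array(r > g, dtype=np.uint8) * 255
back_mask = np.array(g > r, dtype=np.uint8) * 255
print(front_mask.tolist(), back_mask.tolist())  # [[255, 0]] [[0, 255]]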
diff --git a/app/service/design/items/pipelines/split.py b/app/service/design/items/pipelines/split.py
index efa20e4..3485453 100644
--- a/app/service/design/items/pipelines/split.py
+++ b/app/service/design/items/pipelines/split.py
@@ -1,3 +1,4 @@
+import io
 import logging
 
 import cv2
@@ -5,7 +6,9 @@ import numpy as np
 from PIL import Image
 from cv2 import cvtColor, COLOR_BGR2RGBA
 
+from app.core.config import AIDA_CLOTHING
 from app.service.utils.generate_uuid import generate_uuid
+from app.service.utils.oss_client import oss_upload_image
 from ..builder import PIPELINES
 from ...utils.conversion_image import rgb_to_rgba
 from ...utils.upload_image import upload_png_mask
@@ -17,32 +20,14 @@ class Split(object):
     """
     Split image into front and back layer according to the segmentation result
     """
 
+    # @ClassCallRunTime
     # KNet
     def __call__(self, result):
         try:
-            if 'mask' not in result.keys():
-                raise KeyError(f'Cannot find mask in result dict, please check ContourDetection is included in process pipelines.')
-            if 'seg_result' not in result.keys():  # the item did not go through the seg model
-                result['front_mask'] = result['mask'].copy()
-                result['back_mask'] = np.zeros_like(result['mask'])
-            else:
-                temp_front = result['seg_result'] == 1
-                result['front_mask'] = (result['mask'] * (temp_front + 0).astype(np.uint8))
-                temp_back = result['seg_result'] == 2
-                result['back_mask'] = (result['mask'] * (temp_back + 0).astype(np.uint8))
             if result['name'] in ('outwear', 'dress', 'blouse', 'skirt', 'trousers', 'tops', 'bottoms'):
-                if len(result['front_mask'].shape) > 2:
-                    front_mask = result['front_mask'][0]
-                else:
-                    front_mask = result['front_mask']
-
-                if len(result['back_mask'].shape) > 2:
-                    back_mask = result['back_mask'][0]
-                else:
-                    back_mask = result['back_mask']
-
-                # rgba_image = rgb_to_rgba((result['final_image'].shape[0], result['final_image'].shape[1]), result['final_image'], front_mask + back_mask)
+                front_mask = result['front_mask']
+                back_mask = result['back_mask']
                 rgba_image = rgb_to_rgba(result['final_image'], front_mask + back_mask)
                 new_size = (int(rgba_image.shape[1] * result["scale"] * result["resize_scale"][0]), int(rgba_image.shape[0] * result["scale"] * result["resize_scale"][1]))
                 rgba_image = cv2.resize(rgba_image, new_size)
@@ -50,23 +35,45 @@ class Split(object):
                 front_mask = cv2.resize(front_mask, new_size)
                 result_front_image[front_mask != 0] = rgba_image[front_mask != 0]
                 result_front_image_pil = Image.fromarray(cvtColor(result_front_image, COLOR_BGR2RGBA))
-                result['front_image'], result["front_image_url"], result["front_mask_url"] = upload_png_mask(result_front_image_pil, f'{generate_uuid()}', mask=front_mask)
+                result['front_image'], result["front_image_url"], _ = upload_png_mask(result_front_image_pil, f'{generate_uuid()}', mask=None)
+
+                height, width = front_mask.shape
+                mask_image = np.zeros((height, width, 3))
+                mask_image[front_mask != 0] = [0, 0, 255]
+
                 if result["name"] in ('blouse', 'dress', 'outwear', 'tops'):
                     result_back_image = np.zeros_like(rgba_image)
                     back_mask = cv2.resize(back_mask, new_size)
                     result_back_image[back_mask != 0] = rgba_image[back_mask != 0]
                     result_back_image_pil = Image.fromarray(cvtColor(result_back_image, COLOR_BGR2RGBA))
-                    result['back_image'], result["back_image_url"], result["back_mask_url"] = upload_png_mask(result_back_image_pil, f'{generate_uuid()}', mask=back_mask)
+                    result['back_image'], result["back_image_url"], _ = upload_png_mask(result_back_image_pil, f'{generate_uuid()}', mask=None)
+                    mask_image[back_mask != 0] = [0, 255, 0]
+
+                    rgba_mask = rgb_to_rgba(mask_image, front_mask + back_mask)
+                    mask_pil = Image.fromarray(cvtColor(rgba_mask.astype(np.uint8), COLOR_BGR2RGBA))
+                    image_data = io.BytesIO()
+                    mask_pil.save(image_data, format='PNG')
+                    image_data.seek(0)
+                    image_bytes = image_data.read()
+                    req = oss_upload_image(bucket=AIDA_CLOTHING, object_name=f"mask/mask_{generate_uuid()}.png", image_bytes=image_bytes)
+                    result['mask_url'] = req.bucket_name + "/" + req.object_name
                 else:
+                    rgba_mask = rgb_to_rgba(mask_image, front_mask)
+                    mask_pil = Image.fromarray(cvtColor(rgba_mask.astype(np.uint8), COLOR_BGR2RGBA))
+                    image_data = io.BytesIO()
+                    mask_pil.save(image_data, format='PNG')
+                    image_data.seek(0)
+                    image_bytes = image_data.read()
+                    req = oss_upload_image(bucket=AIDA_CLOTHING, object_name=f"mask/mask_{generate_uuid()}.png", image_bytes=image_bytes)
+                    result['mask_url'] = req.bucket_name + "/" + req.object_name
                     result['back_image'] = None
                     result["back_image_url"] = None
-                    result["back_mask_url"] = None
-                    result['back_mask_image'] = None
-
-                # build the intermediate (pattern) layer
+                    # result["back_mask_url"] = None
+                    # result['back_mask_image'] = None
+                # build the intermediate (pattern) layer
                 result_pattern_image_rgba = rgb_to_rgba(result['pattern_image'], result['mask'])
                 result_pattern_image_pil = Image.fromarray(cvtColor(result_pattern_image_rgba, COLOR_BGR2RGBA))
-                _, result['pattern_image_url'], _ = upload_png_mask(result_pattern_image_pil, f'{generate_uuid()}')
+                result['pattern_image'], result['pattern_image_url'], _ = upload_png_mask(result_pattern_image_pil, f'{generate_uuid()}')
             return result
         except Exception as e:
             logging.warning(f"split runtime exception : {e}  image_id : {result['image_id']}")
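The PNG-encode-and-upload sequence now appears twice in Split (and again in api_brighten above); a small consolidating helper might look like this (reviewer sketch; the helper name is hypothetical, oss_upload_image is the existing client):

import io

from PIL import Image

from app.service.utils.oss_client import oss_upload_image


def upload_pil_png(pil_image: Image.Image, bucket: str, object_name: str) -> str:
    # Serialize a PIL image to PNG bytes and push it to OSS in one place.
    buf = io.BytesIO()
    pil_image.save(buf, format='PNG')
    req = oss_upload_image(bucket=bucket, object_name=object_name, image_bytes=buf.getvalue())
    return f"{req.bucket_name}/{req.object_name}"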
diff --git a/app/service/design/items/top.py b/app/service/design/items/top.py
index 135328f..fc0d2a5 100644
--- a/app/service/design/items/top.py
+++ b/app/service/design/items/top.py
@@ -9,8 +9,8 @@ class Top(Clothing):
         pipeline = [
             dict(type='LoadImageFromFile', path=kwargs['path'], color=kwargs['color'], print_dict=kwargs['print']),
             dict(type='KeypointDetection'),
-            dict(type='ContourDetection'),
-            dict(type='Segmentation', device='cpu', show=False, debug=kwargs['debug']),
+            # dict(type='ContourDetection'),
+            dict(type='Segmentation'),
             dict(type='Painting', painting_flag=True),
             dict(type='PrintPainting', print_flag=True),
             # dict(type='ImageShow', key=['image', 'mask', 'seg_visualize', 'pattern_image']),
diff --git a/app/service/design/service.py b/app/service/design/service.py
index ac17351..ba7e987 100644
--- a/app/service/design/service.py
+++ b/app/service/design/service.py
@@ -1,4 +1,7 @@
 import concurrent.futures
+import io
+
+import cv2
 
 from app.core.config import PRIORITY_DICT
 from app.service.design.core.layer import Layer
@@ -6,6 +9,7 @@ from app.service.design.items import build_item
 from app.service.design.utils.redis_utils import Redis
 from app.service.design.utils.synthesis_item import synthesis, synthesis_single
 from app.service.utils.decorator import RunTime
+from app.service.utils.oss_client import oss_upload_image
 
 
 def process_item(item, layers):
@@ -23,7 +27,7 @@ def update_progress(process_id, total):
         if int(progress) <= 100:
             r.write(key=process_id, value=int(progress) + int(100 / total))
         else:
-            r.write(key=process_id, value=100)
+            r.write(key=process_id, value=99)
         return progress
     elif total == 1:
         r.write(key=process_id, value=100)
@@ -43,6 +47,7 @@ def final_progress(process_id):
 @RunTime
 def generate(request_data):
     return_response = {}
+    return_png_mask = []
     request_data = request_data.dict()
 
     assert "process_id" in request_data.keys(), "Need process_id parameters"
@@ -55,14 +60,15 @@ def generate(request_data):
         # collect the results
         for future in concurrent.futures.as_completed(futures):
             obj = futures[future]
-
-            result = future.result()
-            return_response[obj] = result
+            return_response[obj] = future.result()[0]
+            return_png_mask.extend(future.result()[1])
+    # upload_results = process_images(return_png_mask)
     final_progress(process_id)
     return return_response
 
 
 def process_object(cfg, process_id, total):
+    uploaded_images = []
     basic_info = cfg.get('basic')
     items_response = {
         'layers': []
@@ -83,8 +89,17 @@ def process_object(cfg, process_id, total):
         layers = sorted(layers.layer, key=lambda s: s.get("priority", float('inf')))
     else:
         layers = sorted(layers.layer, key=lambda x: PRIORITY_DICT.get(x['name'], float('inf')))
+    # upload all images
+    # for layer in layers:
+    #     if 'image' in layer.keys() and layer['image'] is not None:
+    #         uploaded_images.append({'image_obj': layer['image'], 'image_url': layer['image_url'], 'image_type': 'image'})
+    #     if 'pattern_image' in layer.keys() and layer['pattern_image'] is not None:
+    #         uploaded_images.append({'image_obj': layer['pattern_image'], 'image_url': layer['pattern_image_url'], 'image_type': 'pattern_image'})
+    #     if 'mask' in layer.keys() and layer['mask'] is not None and layer['mask_url'] is not None:
+    #         uploaded_images.append({'image_obj': layer['mask'], 'image_url': layer['mask_url'], 'image_type': 'mask'})
+
+    layers, new_size = update_base_size_priority(layers, body_size)
     # composite
-    items_response['synthesis_url'] = synthesis(layers, body_size)
+    items_response['synthesis_url'] = synthesis(layers, new_size, basic_info)
 
     for lay in layers:
         items_response['layers'].append({
@@ -114,9 +129,10 @@ def process_object(cfg, process_id, total):
             'position': None,
             'priority': 0,
             'image_url': item.result['front_image_url'],
-            'mask_url': item.result['front_mask_url'],
+            'mask_url': item.result['mask_url'],
             "gradient_string": item.result['gradient_string'] if 'gradient_string' in item.result.keys() else "",
             'pattern_image_url': item.result['pattern_image_url'] if 'pattern_image_url' in item.result.keys() else None,
+
         })
         items_response['layers'].append({
             'image_category': f"{item.result['name']}_back",
@@ -124,11 +140,58 @@ def process_object(cfg, process_id, total):
             'position': None,
             'priority': 0,
             'image_url': item.result['back_image_url'],
-            'mask_url': item.result['back_mask_url'],
+            'mask_url': item.result['mask_url'],
             "gradient_string": item.result['gradient_string'] if 'gradient_string' in item.result.keys() else "",
             'pattern_image_url': item.result['pattern_image_url'] if 'pattern_image_url' in item.result.keys() else None,
+
         })
         items_response['synthesis_url'] = synthesis_single(item.result['front_image'], item.result['back_image'])
         break
     update_progress(process_id, total)
-    return items_response
+    return items_response, uploaded_images
+
+
+@RunTime
+def process_images(images):
+    with concurrent.futures.ThreadPoolExecutor() as executor:
+        results = list(executor.map(upload_images, images))
+        # results = []
+        # for image in images:
+        #     results.append(upload_images(image))
+    return results
+
+
+# @RunTime
+def upload_images(image_obj):
+    bucket_name = image_obj['image_url'].split("/", 1)[0]
+    object_name = image_obj['image_url'].split("/", 1)[1]
+    if image_obj['image_type'] == 'image' or image_obj['image_type'] == 'pattern_image':
+        image_data = io.BytesIO()
+        image_obj['image_obj'].save(image_data, format='PNG')
+        image_data.seek(0)
+        image_bytes = image_data.read()
+        req = oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes)
+        return image_obj['image_url']
+    else:
+        mask_inverted = cv2.bitwise_not(image_obj['image_obj'])
+        # convert the 3-channel mask to 4 channels: white stays opaque, black becomes transparent
+        rgba_image = cv2.cvtColor(mask_inverted, cv2.COLOR_BGR2BGRA)
+        rgba_image[rgba_image[:, :, 0] == 0] = [0, 0, 0, 0]
+        req = oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=cv2.imencode('.png', rgba_image)[1])
+        return image_obj['image_url']
+
+
+def update_base_size_priority(layers, size):
+    # compute the width of the transparent background image
+    min_x = min(info['position'][1] for info in layers)
+    x_list = []
+    for info in layers:
+        if info['image'] is not None:
+            x_list.append(info['position'][1] + info['image'].width)
+    max_x = max(x_list)
+    new_width = max_x - min_x
+    new_height = 700
+    # update coordinates
+    for info in layers:
+        info['adaptive_position'] = (info['position'][0], info['position'][1] - min_x)
+    return layers, (new_width, new_height)
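A quick numeric walk-through of the shift performed by update_base_size_priority (hypothetical layers; SimpleNamespace stands in for the pasted PIL images):

from types import SimpleNamespace

layers = [
    {'position': (0, 30), 'image': SimpleNamespace(width=200)},
    {'position': (0, 10), 'image': SimpleNamespace(width=100)},
]
min_x = min(info['position'][1] for info in layers)                         # 10
max_x = max(info['position'][1] + info['image'].width for info in layers)  # 230
print(max_x - min_x)  # 220 -> width of the new transparent canvas
for info in layers:
    info['adaptive_position'] = (info['position'][0], info['position'][1] - min_x)
print([info['adaptive_position'] for info in layers])  # [(0, 20), (0, 0)]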
diff --git a/app/service/design/utils/synthesis_item.py b/app/service/design/utils/synthesis_item.py
index 7bedbe6..03df2d9 100644
--- a/app/service/design/utils/synthesis_item.py
+++ b/app/service/design/utils/synthesis_item.py
@@ -59,14 +59,26 @@ def positioning(all_mask_shape, mask_shape, offset):
 
 
 # @RunTime
-def synthesis(data, size):
+def synthesis(data, size, basic_info):
     # create the base canvas
     base_image = Image.new('RGBA', size, (0, 0, 0, 0))
     try:
         all_mask_shape = (size[1], size[0])
-        top_outer_mask = np.zeros(all_mask_shape, dtype=np.uint8)
-        bottom_outer_mask = np.zeros(all_mask_shape, dtype=np.uint8)
+        body_mask = None
+        for d in data:
+            if d['name'] == 'body':
+                # paste the model onto a new full-size transparent canvas to obtain its mask
+                transparent_image = Image.new("RGBA", size, (0, 0, 0, 0))
+                transparent_image.paste(d['image'], (d['adaptive_position'][1], d['adaptive_position'][0]), d['image'])  # paste can mutate mutable sequences, so read position by index here
+                body_mask = np.array(transparent_image.split()[3])
+
+                # recompute the shoulder points in the new coordinates
+                left_shoulder = [x + y for x, y in zip(basic_info['body_point_test']['shoulder_left'], [d['adaptive_position'][1], d['adaptive_position'][0]])]
+                right_shoulder = [x + y for x, y in zip(basic_info['body_point_test']['shoulder_right'], [d['adaptive_position'][1], d['adaptive_position'][0]])]
+                body_mask[:min(left_shoulder[1], right_shoulder[1]), left_shoulder[0]:right_shoulder[0]] = 255
+        _, binary_body_mask = cv2.threshold(body_mask, 127, 255, cv2.THRESH_BINARY)
+        top_outer_mask = np.array(binary_body_mask)
+        bottom_outer_mask = np.array(binary_body_mask)
 
         top = True
         bottom = True
@@ -76,21 +88,27 @@ def synthesis(data, size, basic_info):
             if top and data[i]['name'] in ["blouse_front", "outwear_front", "dress_front", "tops_front"]:
                 top = False
                 mask_shape = data[i]['mask'].shape
-                y_offset, x_offset = data[i]['position']
+                y_offset, x_offset = data[i]['adaptive_position']
                 # initialize the start/end of the overlay region
                 all_y_start, all_y_end, mask_y_start, mask_y_end = positioning(all_mask_shape=all_mask_shape[0], mask_shape=mask_shape[0], offset=y_offset)
                 all_x_start, all_x_end, mask_x_start, mask_x_end = positioning(all_mask_shape=all_mask_shape[1], mask_shape=mask_shape[1], offset=x_offset)
                 # assign the overlay region the corresponding pixel values
-                top_outer_mask[all_y_start:all_y_end, all_x_start:all_x_end] = data[i]['mask'][mask_y_start:mask_y_end, mask_x_start:mask_x_end]
-            elif bottom and data[i]['name'] in ["trousers_front", "skirt_front", "bottoms_front"]:
+                _, sketch_mask = cv2.threshold(data[i]['mask'], 127, 255, cv2.THRESH_BINARY)
+                background = np.zeros_like(top_outer_mask)
+                background[all_y_start:all_y_end, all_x_start:all_x_end] = sketch_mask[mask_y_start:mask_y_end, mask_x_start:mask_x_end]
+                top_outer_mask = background + top_outer_mask
+            elif bottom and data[i]['name'] in ["trousers_front", "skirt_front", "bottoms_front", "dress_front"]:
                 bottom = False
                 mask_shape = data[i]['mask'].shape
-                y_offset, x_offset = data[i]['position']
+                y_offset, x_offset = data[i]['adaptive_position']
                 # initialize the start/end of the overlay region
                 all_y_start, all_y_end, mask_y_start, mask_y_end = positioning(all_mask_shape=all_mask_shape[0], mask_shape=mask_shape[0], offset=y_offset)
                 all_x_start, all_x_end, mask_x_start, mask_x_end = positioning(all_mask_shape=all_mask_shape[1], mask_shape=mask_shape[1], offset=x_offset)
                 # assign the overlay region the corresponding pixel values
-                bottom_outer_mask[all_y_start:all_y_end, all_x_start:all_x_end] = data[i]['mask'][mask_y_start:mask_y_end, mask_x_start:mask_x_end]
+                _, sketch_mask = cv2.threshold(data[i]['mask'], 127, 255, cv2.THRESH_BINARY)
+                background = np.zeros_like(top_outer_mask)
+                background[all_y_start:all_y_end, all_x_start:all_x_end] = sketch_mask[mask_y_start:mask_y_end, mask_x_start:mask_x_end]
+                bottom_outer_mask = background + bottom_outer_mask
             elif bottom is False and top is False:
                 break
@@ -100,13 +118,13 @@ def synthesis(data, size, basic_info):
             if layer['image'] is not None:
                 if layer['name'] != "body":
                     test_image = Image.new('RGBA', size, (0, 0, 0, 0))
-                    test_image.paste(layer['image'], (layer['position'][1], layer['position'][0]), layer['image'])
-                    # mask_data = np.where(all_mask > 0, 255, 0).astype(np.uint8)
-                    # mask_alpha = Image.fromarray(mask_data)
-                    # cropped_image = Image.composite(test_image, Image.new("RGBA", test_image.size, (255, 255, 255, 0)), mask_alpha)
-                    base_image.paste(test_image, (0, 0), test_image)
+                    test_image.paste(layer['image'], (layer['adaptive_position'][1], layer['adaptive_position'][0]), layer['image'])
+                    mask_data = np.where(all_mask > 0, 255, 0).astype(np.uint8)
+                    mask_alpha = Image.fromarray(mask_data)
+                    cropped_image = Image.composite(test_image, Image.new("RGBA", test_image.size, (255, 255, 255, 0)), mask_alpha)
+                    base_image.paste(test_image, (0, 0), cropped_image)  # test_image has already been pasted at its coordinates onto the full-width canvas, so it is pasted here at (0, 0)
                 else:
-                    base_image.paste(layer['image'], (layer['position'][1], layer['position'][0]), layer['image'])
+                    base_image.paste(layer['image'], (layer['adaptive_position'][1], layer['adaptive_position'][0]), layer['image'])
 
         result_image = base_image
diff --git a/app/service/design_batch/design_batch_celery.py b/app/service/design_batch/design_batch_celery.py
new file mode 100644
index 0000000..3f12862
--- /dev/null
+++ b/app/service/design_batch/design_batch_celery.py
@@ -0,0 +1,126 @@
+import logging
+import threading
+
+from celery import Celery
+from minio import Minio
+
+from app.core.config import *
+from app.service.design_batch.item import BodyItem, TopItem, BottomItem
+from app.service.design_batch.utils.MQ import publish_status
+from app.service.design_batch.utils.organize import organize_body, organize_clothing
+from app.service.design_batch.utils.save_json import oss_upload_json
+from app.service.design_batch.utils.synthesis_item import update_base_size_priority, synthesis, synthesis_single
+
+id_lock = threading.Lock()
+celery_app = Celery('tasks', broker='amqp://guest:guest@10.1.2.213:5672//', backend='rpc://')
+celery_app.conf.worker_log_format = '%(asctime)s %(filename)s [line:%(lineno)d] %(levelname)s %(message)s'
+celery_app.conf.worker_hijack_root_logger = False
+logging.getLogger('pika').setLevel(logging.WARNING)
+logger = logging.getLogger()
+minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
+
+
+def process_item(item, basic):
+    # process a single item of the project
+    if item['type'] == "Body":
+        body_server = BodyItem(data=item, basic=basic, minio_client=minio_client)
+        item_data = body_server.process()
+    elif item['type'].lower() in ['blouse', 'outwear', 'dress', 'tops']:
+        top_server = TopItem(data=item, basic=basic, minio_client=minio_client)
+        item_data = top_server.process()
+    else:
+        bottom_server = BottomItem(data=item, basic=basic, minio_client=minio_client)
+        item_data = bottom_server.process()
+    return item_data
+
+
+def process_layer(item, layers):
+    # after an item finishes processing, assemble its layer data
+    if item['name'] == "mannequin":
+        body_layer = organize_body(item)
+        layers.append(body_layer)
+        return item['body_image'].size
+    else:
+        front_layer, back_layer = organize_clothing(item)
+        layers.append(front_layer)
+        layers.append(back_layer)
+
+
+@celery_app.task
+def batch_design(objects_data, tasks_id, json_name):
+    object_response = []
+    threads = []
+    active_threads = 0
+    lock = threading.Lock()
+
+    def process_object(step, object):
+        nonlocal active_threads
+        basic = object['basic']
+        items_response = {'layers': []}
+        if basic['single_overall'] == "overall":
+            item_results = []
+            for item in object['items']:
+                item_results.append(process_item(item, basic))
+            layers = []
+            body_size = None
+            for item in item_results:
+                body_size = process_layer(item, layers)
+            layers = sorted(layers, key=lambda s: s.get("priority", float('inf')))
+
+            layers, new_size = update_base_size_priority(layers, body_size)
+
+            for lay in layers:
+                items_response['layers'].append({
+                    'image_category': lay['name'],
+                    'position': lay['position'],
+                    'priority': lay.get("priority", None),
+                    'resize_scale': lay['resize_scale'] if "resize_scale" in lay.keys() else None,
+                    'image_size': lay['image'] if lay['image'] is None else lay['image'].size,
+                    'gradient_string': lay['gradient_string'] if 'gradient_string' in lay.keys() else "",
+                    'mask_url': lay['mask_url'],
+                    'image_url': lay['image_url'] if 'image_url' in lay.keys() else None,
+                    'pattern_image_url': lay['pattern_image_url'] if 'pattern_image_url' in lay.keys() else None,
+                })
+            items_response['synthesis_url'] = synthesis(layers, new_size, basic)
+        else:
+            item_result = process_item(object['items'][0], basic)
+            items_response['layers'].append({
+                'image_category': f"{item_result['name']}_front",
+                'image_size': item_result['back_image'].size if item_result['back_image'] else None,
+                'position': None,
+                'priority': 0,
+                'image_url': item_result['front_image_url'],
+                'mask_url': item_result['mask_url'],
+                "gradient_string": item_result['gradient_string'] if 'gradient_string' in item_result.keys() else "",
+                'pattern_image_url': item_result['pattern_image_url'] if 'pattern_image_url' in item_result.keys() else None,
+            })
+            items_response['layers'].append({
+                'image_category': f"{item_result['name']}_back",
+                'image_size': item_result['front_image'].size if item_result['front_image'] else None,
+                'position': None,
+                'priority': 0,
+                'image_url': item_result['back_image_url'],
+                'mask_url': item_result['mask_url'],
+                "gradient_string": item_result['gradient_string'] if 'gradient_string' in item_result.keys() else "",
+                'pattern_image_url': item_result['pattern_image_url'] if 'pattern_image_url' in item_result.keys() else None,
+            })
+            items_response['synthesis_url'] = synthesis_single(item_result['front_image'], item_result['back_image'])
+
+        with lock:
+            object_response.append(items_response)
+            publish_status(tasks_id, step + 1, items_response)
+            active_threads -= 1
+
+    for step, object in enumerate(objects_data):
+        t = threading.Thread(target=process_object, args=(step, object))
+        threads.append(t)
+        t.start()
+        with lock:
+            active_threads += 1
+
+    for t in threads:
+        t.join()
+
+    oss_upload_json(minio_client, object_response, json_name)
+    publish_status(tasks_id, "ok", json_name)
+    return object_response
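Dispatching the task from the API layer would look roughly like this (sketch; the argument values are hypothetical and presumably assembled in start_design_batch_generate):

from app.service.design_batch.design_batch_celery import batch_design

# Fire-and-forget enqueue to the 'tasks' Celery app; a worker runs batch_design,
# publishes per-object progress via publish_status, and uploads the final JSON to OSS.
async_result = batch_design.delay(objects_data, tasks_id, json_name)
print(async_result.id)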
diff --git a/app/service/design_batch/item.py b/app/service/design_batch/item.py
new file mode 100644
index 0000000..cad1488
--- /dev/null
+++ b/app/service/design_batch/item.py
@@ -0,0 +1,61 @@
+from app.service.design_batch.pipeline import *
+
+
+class BaseItem:
+    def __init__(self, data, basic):
+        self.result = data.copy()
+        self.result['name'] = data['type'].lower()
+        self.result.pop("type")
+        self.result.update(basic)
+
+
+class TopItem(BaseItem):
+    def __init__(self, data, basic, minio_client):
+        super().__init__(data, basic)
+        self.top_pipeline = [
+            LoadImage(minio_client),
+            KeyPoint(),
+            Segmentation(minio_client),
+            Color(minio_client),
+            PrintPainting(minio_client),
+            Scaling(),
+            Split(minio_client)
+        ]
+
+    def process(self):
+        for item in self.top_pipeline:
+            self.result = item(self.result)
+        return self.result
+
+
+class BottomItem(BaseItem):
+    def __init__(self, data, basic, minio_client):
+        super().__init__(data, basic)
+        self.bottom_pipeline = [
+            LoadImage(minio_client),
+            KeyPoint(),
+            ContourDetection(),
+            # Segmentation(),
+            Color(minio_client),
+            PrintPainting(minio_client),
+            Scaling(),
+            Split(minio_client)
+        ]
+
+    def process(self):
+        for item in self.bottom_pipeline:
+            self.result = item(self.result)
+        return self.result
+
+
+class BodyItem(BaseItem):
+    def __init__(self, data, basic, minio_client):
+        super().__init__(data, basic)
+        self.top_pipeline = [
+            LoadBodyImage(minio_client),
+        ]
+
+    def process(self):
+        for item in self.top_pipeline:
+            self.result = item(self.result)
+        return self.result
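The item classes above all follow one pattern: a list of callable stages, each receiving and returning a shared, mutable result dict. A minimal self-contained sketch of the same pattern (stage names here are illustrative only):

# Minimal sketch of the stage-pipeline pattern used by TopItem/BottomItem/BodyItem.
class UppercaseName:
    def __call__(self, result):
        result['name'] = result['name'].upper()   # each stage mutates and returns the dict
        return result

class AddGreeting:
    def __call__(self, result):
        result['greeting'] = f"hello, {result['name']}"
        return result

pipeline = [UppercaseName(), AddGreeting()]
result = {'name': 'blouse'}
for stage in pipeline:
    result = stage(result)
print(result['greeting'])   # hello, BLOUSE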
diff --git a/app/service/design_batch/pipeline/__init__.py b/app/service/design_batch/pipeline/__init__.py
new file mode 100644
index 0000000..ec55933
--- /dev/null
+++ b/app/service/design_batch/pipeline/__init__.py
@@ -0,0 +1,19 @@
+from .color import Color
+from .contour_detection import ContourDetection
+from .keypoint import KeyPoint
+from .loading import LoadImage, LoadBodyImage
+from .print_painting import PrintPainting
+from .scale import Scaling
+from .segmentation import Segmentation
+from .split import Split
+
+__all__ = [
+    'LoadBodyImage', 'LoadImage',
+    'KeyPoint',
+    'ContourDetection',
+    'Segmentation',
+    'Color',
+    'PrintPainting',
+    'Scaling',
+    'Split'
+]
diff --git a/app/service/design_batch/pipeline/color.py b/app/service/design_batch/pipeline/color.py
new file mode 100644
index 0000000..546c671
--- /dev/null
+++ b/app/service/design_batch/pipeline/color.py
@@ -0,0 +1,62 @@
+import logging
+
+import cv2
+import numpy as np
+
+from app.service.utils.new_oss_client import oss_get_image
+
+logger = logging.getLogger()
+
+
+class Color:
+    def __init__(self, minio_client):
+        self.minio_client = minio_client
+
+    def __call__(self, result):
+        dim_image_h, dim_image_w = result['image'].shape[0:2]
+        if "gradient" in result.keys() and result['gradient'] != "":
+            bucket_name = result['gradient'].split('/')[0]
+            object_name = result['gradient'][result['gradient'].find('/') + 1:]
+            pattern = self.get_gradient(bucket_name=bucket_name, object_name=object_name)
+            resize_pattern = cv2.resize(pattern, (dim_image_w, dim_image_h), interpolation=cv2.INTER_AREA)
+        else:
+            pattern = self.get_pattern(result['color'])
+            resize_pattern = cv2.resize(pattern, (dim_image_w, dim_image_h), interpolation=cv2.INTER_AREA)
+        closed_mo = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
+        gray_mo = np.expand_dims(result['gray'], axis=2).repeat(3, axis=2)
+        get_image_fir = resize_pattern * (closed_mo / 255) * (gray_mo / 255)
+        result['pattern_image'] = get_image_fir.astype(np.uint8)
+        result['final_image'] = result['pattern_image']
+        canvas = np.full_like(result['final_image'], 255)
+        temp_bg = np.expand_dims(cv2.bitwise_not(result['mask']), axis=2).repeat(3, axis=2)
+        tmp1 = (canvas * (temp_bg / 255)).astype(np.uint8)
+        temp_fg = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
+        tmp2 = (result['final_image'] * (temp_fg / 255)).astype(np.uint8)
+        result['single_image'] = cv2.add(tmp1, tmp2)
+        result['alpha'] = 100 / 255.0
+        return result
+
+    def get_gradient(self, bucket_name, object_name):
+        # Fetch the gradient pattern image
+        image = oss_get_image(oss_client=self.minio_client, bucket=bucket_name, object_name=object_name, data_type="cv2")
+        if image.shape[2] == 4:
+            image = cv2.cvtColor(image, cv2.COLOR_BGRA2BGR)
+        return image
+
+    @staticmethod
+    def crop_image(image, image_size_h, image_size_w):
+        x_offset = np.random.randint(low=0, high=int(image_size_h / 5) - 6)
+        y_offset = np.random.randint(low=0, high=int(image_size_w / 5) - 6)
+        image = image[x_offset: x_offset + image_size_h, y_offset: y_offset + image_size_w, :]
+        return image
+
+    @staticmethod
+    def get_pattern(single_color):
+        if single_color is None:
+            raise ValueError("a color string is required when no gradient is given")
+        R, G, B = single_color.split(' ')
+        pattern = np.zeros([1, 1, 3], np.uint8)
+        pattern[0, 0, 0] = int(B)
+        pattern[0, 0, 1] = int(G)
+        pattern[0, 0, 2] = int(R)
+        return pattern
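The core colorization step in Color.__call__ multiplies the resized pattern by the normalized mask and the normalized grayscale sketch, so the sketch's shading survives the recolor. A toy numeric sketch of that formula:

# Toy illustration of pattern * (mask/255) * (gray/255): shading is preserved.
import numpy as np

pattern = np.full((2, 2, 3), [30, 28, 28], dtype=np.uint8)   # flat garment color (BGR)
mask = np.array([[255, 255], [255, 0]], dtype=np.uint8)      # garment region
gray = np.array([[255, 128], [64, 255]], dtype=np.uint8)     # sketch shading

mask3 = np.expand_dims(mask, 2).repeat(3, 2)
gray3 = np.expand_dims(gray, 2).repeat(3, 2)
out = (pattern * (mask3 / 255) * (gray3 / 255)).astype(np.uint8)
# A fully lit pixel keeps the color, darker sketch pixels darken it, masked-out pixels go black.
print(out[0, 0], out[0, 1], out[1, 1])   # [30 28 28] [15 14 14] [0 0 0]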
diff --git a/app/service/design_batch/pipeline/contour_detection.py b/app/service/design_batch/pipeline/contour_detection.py
new file mode 100644
index 0000000..2b76c0b
--- /dev/null
+++ b/app/service/design_batch/pipeline/contour_detection.py
@@ -0,0 +1,37 @@
+import cv2
+import numpy as np
+
+
+class ContourDetection:
+    def __call__(self, result):
+        Contour = self.get_contours(result['image'])
+        Mask = np.zeros(result['image'].shape[:2], np.uint8)
+        if len(Contour):
+            Max_contour = Contour[0]
+            Epsilon = 0.001 * cv2.arcLength(Max_contour, True)
+            Approx = cv2.approxPolyDP(Max_contour, Epsilon, True)
+            cv2.drawContours(Mask, [Approx], -1, 255, -1)
+        else:
+            Mask = np.ones(result['image'].shape[:2], np.uint8) * 255
+        # TODO: fix images that come out partially transparent; planned for the next release
+        # img2gray = cv2.cvtColor(result['image'], cv2.COLOR_BGR2GRAY)
+        # ret, Mask = cv2.threshold(img2gray, 126, 255, cv2.THRESH_BINARY)
+        # Mask = cv2.bitwise_not(Mask)
+        if result['pre_mask'] is None:
+            result['mask'] = Mask
+        else:
+            result['mask'] = cv2.bitwise_and(Mask, result['pre_mask'])
+        result['front_mask'] = result['mask']
+        result['back_mask'] = result['mask']
+        return result
+
+    @staticmethod
+    def get_contours(image):
+        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
+        Edge = cv2.Canny(gray, 10, 150)
+        kernel = np.ones((5, 5), np.uint8)
+        Edge = cv2.dilate(Edge, kernel=kernel, iterations=1)
+        Edge = cv2.erode(Edge, kernel=kernel, iterations=1)
+        Contour, _ = cv2.findContours(Edge, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+        Contour = sorted(Contour, key=cv2.contourArea, reverse=True)
+        return Contour
diff --git a/app/service/design_batch/pipeline/keypoint.py b/app/service/design_batch/pipeline/keypoint.py
new file mode 100644
index 0000000..313a613
--- /dev/null
+++ b/app/service/design_batch/pipeline/keypoint.py
@@ -0,0 +1,114 @@
+import logging
+
+import numpy as np
+from pymilvus import MilvusClient
+
+from app.core.config import *
+from app.service.design_batch.utils.design_ensemble import get_keypoint_result
+
+logger = logging.getLogger(__name__)
+
+
+class KeyPoint:
+    name = "KeyPoint"
+
+    @classmethod
+    def get_name(cls):
+        return cls.name
+
+    def __call__(self, result):
+        if result['name'] in ['blouse', 'skirt', 'dress', 'outwear', 'trousers', 'tops', 'bottoms']:  # check the cache for an entry of the same category: read it directly on a hit, otherwise infer and update
+            # result['clothes_keypoint'] = self.infer_keypoint_result(result)
+            site = 'up' if result['name'] in ['blouse', 'outwear', 'dress', 'tops'] else 'down'
+            # keypoint_cache = search_keypoint_cache(result["image_id"], site)
+            keypoint_cache = self.keypoint_cache(result, site)
+            # to skip the vector lookup and always run model inference:
+            # keypoint_cache = False
+            if keypoint_cache is False:
+                keypoint_infer_result, site = self.infer_keypoint_result(result)
+                result['clothes_keypoint'] = self.save_keypoint_cache(result["image_id"], keypoint_infer_result, site)
+            else:
+                result['clothes_keypoint'] = keypoint_cache
+        return result
+
+    @staticmethod
+    def infer_keypoint_result(result):
+        site = 'up' if result['name'] in ['blouse', 'outwear', 'dress', 'tops'] else 'down'
+        keypoint_infer_result = get_keypoint_result(result["image"], site)  # inference result
+        return keypoint_infer_result, site
+
+    @staticmethod
+    def save_keypoint_cache(keypoint_id, cache, site):
+        if site == "down":
+            zeros = np.zeros(20, dtype=int)
+            result = np.concatenate([zeros, cache.flatten()])
+        else:
+            zeros = np.zeros(4, dtype=int)
+            result = np.concatenate([cache.flatten(), zeros])
+        # (to disable vector caching, return the result directly instead of upserting below)
+        data = [
+            {"keypoint_id": keypoint_id,
+             "keypoint_site": site,
+             "keypoint_vector": result.tolist()
+             }
+        ]
+        try:
+            client = MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS)
+            res = client.upsert(collection_name=MILVUS_TABLE_KEYPOINT, data=data)
+            client.close()
+            return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
+        except Exception as e:
+            logger.info(f"save keypoint cache milvus error : {e}")
+            return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
+
+    @staticmethod
+    def update_keypoint_cache(keypoint_id, infer_result, search_result, site):
+        if site == "up":
+            # 'up' was requested, i.e. the inferred half is 'up', so the cached half is 'down'
+            result = np.concatenate([infer_result.flatten(), search_result[-4:]])
+        else:
+            # 'down' was requested, i.e. the inferred half is 'down', so the cached half is 'up'
+            result = np.concatenate([search_result[:20], infer_result.flatten()])
+        data = [
+            {"keypoint_id": keypoint_id,
+             "keypoint_site": "all",
+             "keypoint_vector": result.tolist()
+             }
+        ]
+
+        try:
+            client = MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS)
+            client.upsert(
+                collection_name=MILVUS_TABLE_KEYPOINT,
+                data=data
+            )
+            return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
+        except Exception as e:
+            logger.info(f"save keypoint cache milvus error : {e}")
+            return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
+
+    # @ RunTime
+    def keypoint_cache(self, result, site):
+        try:
+            client = MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS)
+            keypoint_id = result['image_id']
+            res = client.query(
+                collection_name=MILVUS_TABLE_KEYPOINT,
+                # ids=[keypoint_id],
+                filter=f"keypoint_id == {keypoint_id}",
+                output_fields=['keypoint_vector', 'keypoint_site']
+            )
+            if len(res) == 0:
+                # no cache hit: run inference, save, and return the result
+                keypoint_infer_result, site = self.infer_keypoint_result(result)
+                return self.save_keypoint_cache(result['image_id'], keypoint_infer_result, site)
+            elif res[0]["keypoint_site"] == "all" or res[0]["keypoint_site"] == site:
+                # the cached site matches the requested one (or is 'all'): return it directly
+                return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, np.array(res[0]['keypoint_vector']).astype(int).reshape(12, 2).tolist()))
+            elif res[0]["keypoint_site"] != site:
+                # the cached site differs from the requested one: infer the missing half and upgrade the entry to 'all'
+                keypoint_infer_result, site = self.infer_keypoint_result(result)
+                return self.update_keypoint_cache(result["image_id"], keypoint_infer_result, res[0]['keypoint_vector'], site)
+        except Exception as e:
+            logger.info(f"search keypoint cache milvus error {e}")
+            return False
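The cache stores one flat 24-value vector per garment: ten 'up' points (20 values) followed by two 'down' points (4 values), zero-padded on whichever half was not inferred. A small sketch of the packing/unpacking; the field names are a stand-in for KEYPOINT_RESULT_TABLE_FIELD_SET, which is assumed to hold 12 names:

# Sketch of the 24-value keypoint vector layout used for the Milvus cache.
import numpy as np

up_points = np.arange(20).reshape(10, 2)    # inferred 'up' keypoints (10 x,y pairs)
vector = np.concatenate([up_points.flatten(), np.zeros(4, dtype=int)])   # pad the 'down' slots

# Unpacking mirrors save_keypoint_cache(): 12 (x, y) rows zipped with the field names.
field_names = [f"point_{i}" for i in range(12)]   # stand-in for KEYPOINT_RESULT_TABLE_FIELD_SET
keypoints = dict(zip(field_names, vector.reshape(12, 2).astype(int).tolist()))
print(keypoints["point_0"], keypoints["point_11"])   # [0, 1] [0, 0]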
"LoadImage" + + def __init__(self, minio_client): + self.minio_client = minio_client + + @classmethod + def get_name(cls): + return cls.name + + def __call__(self, result): + result['image'], result['pre_mask'] = self.read_image(result['path']) + result['gray'] = cv2.cvtColor(result['image'], cv2.COLOR_BGR2GRAY) + result['keypoint'] = self.get_keypoint(result['name']) + result['img_shape'] = result['image'].shape + result['ori_shape'] = result['image'].shape + return result + + def read_image(self, image_path): + image_mask = None + image = oss_get_image(oss_client=self.minio_client, bucket=image_path.split("/", 1)[0], object_name=image_path.split("/", 1)[1], data_type="cv2") + if len(image.shape) == 2: + image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB) + if image.shape[2] == 4: # 如果是四通道 mask + image_mask = image[:, :, 3] + image = image[:, :, :3] + + if image.shape[:2] <= (50, 50): + # 计算新尺寸 + new_size = (image.shape[1] * 2, image.shape[0] * 2) + # 调整大小 + image = cv2.resize(image, new_size, interpolation=cv2.INTER_LINEAR) + return image, image_mask + + @staticmethod + def get_keypoint(name): + if name == 'blouse' or name == 'outwear' or name == 'dress' or name == 'tops': + keypoint = 'shoulder' + elif name == 'trousers' or name == 'skirt' or name == 'bottoms': + keypoint = 'waistband' + elif name == 'bag': + keypoint = 'hand_point' + elif name == 'shoes': + keypoint = 'toe' + elif name == 'hairstyle': + keypoint = 'head_point' + elif name == 'earring': + keypoint = 'ear_point' + else: + raise KeyError(f"{name} does not belong to item category list: blouse, outwear, dress, trousers, skirt, " + f"bag, shoes, hairstyle, earring.") + return keypoint diff --git a/app/service/design_batch/pipeline/print_painting.py b/app/service/design_batch/pipeline/print_painting.py new file mode 100644 index 0000000..6fe40d8 --- /dev/null +++ b/app/service/design_batch/pipeline/print_painting.py @@ -0,0 +1,524 @@ +import random + +import cv2 +import numpy as np +from PIL import Image + +from app.service.utils.new_oss_client import oss_get_image + + +class PrintPainting: + def __init__(self, minio_client): + self.minio_client = minio_client + + def __call__(self, result): + single_print = result['print']['single'] + overall_print = result['print']['overall'] + element_print = result['print']['element'] + result['single_image'] = None + result['print_image'] = None + if overall_print['print_path_list']: + painting_dict = {'dim_image_h': result['pattern_image'].shape[0], 'dim_image_w': result['pattern_image'].shape[1]} + result['print_image'] = result['pattern_image'] + if "print_angle_list" in overall_print.keys() and overall_print['print_angle_list'][0] != 0: + painting_dict = self.painting_collection(painting_dict, overall_print, print_trigger=True) + painting_dict['tile_print'] = self.rotate_crop_image(img=painting_dict['tile_print'], angle=-overall_print['print_angle_list'][0], crop=True) + painting_dict['mask_inv_print'] = self.rotate_crop_image(img=painting_dict['mask_inv_print'], angle=-overall_print['print_angle_list'][0], crop=True) + + # resize 到sketch大小 + painting_dict['tile_print'] = self.resize_and_crop(img=painting_dict['tile_print'], target_width=painting_dict['dim_image_w'], target_height=painting_dict['dim_image_h']) + painting_dict['mask_inv_print'] = self.resize_and_crop(img=painting_dict['mask_inv_print'], target_width=painting_dict['dim_image_w'], target_height=painting_dict['dim_image_h']) + else: + painting_dict = self.painting_collection(painting_dict, overall_print, print_trigger=True, 
+        if single_print['print_path_list']:
+            print_background = np.zeros((result['pattern_image'].shape[0], result['pattern_image'].shape[1], 3), dtype=np.uint8)
+            mask_background = np.zeros((result['pattern_image'].shape[0], result['pattern_image'].shape[1], 3), dtype=np.uint8)
+            for i in range(len(single_print['print_path_list'])):
+                image, image_mode = self.read_image(single_print['print_path_list'][i])
+                if image_mode == "RGBA":
+                    new_size = (int(image.width * single_print['print_scale_list'][i]), int(image.height * single_print['print_scale_list'][i]))
+
+                    mask = image.split()[3]
+                    resized_source = image.resize(new_size)
+                    resized_source_mask = mask.resize(new_size)
+
+                    rotated_resized_source = resized_source.rotate(-single_print['print_angle_list'][i])
+                    rotated_resized_source_mask = resized_source_mask.rotate(-single_print['print_angle_list'][i])
+
+                    source_image_pil = Image.fromarray(cv2.cvtColor(print_background, cv2.COLOR_BGR2RGB))
+                    source_image_pil_mask = Image.fromarray(cv2.cvtColor(mask_background, cv2.COLOR_BGR2RGB))
+
+                    source_image_pil.paste(rotated_resized_source, (int(single_print['location'][i][0]), int(single_print['location'][i][1])), rotated_resized_source)
+                    source_image_pil_mask.paste(rotated_resized_source_mask, (int(single_print['location'][i][0]), int(single_print['location'][i][1])), rotated_resized_source_mask)
+
+                    print_background = cv2.cvtColor(np.array(source_image_pil), cv2.COLOR_RGBA2BGR)
+                    mask_background = cv2.cvtColor(np.array(source_image_pil_mask), cv2.COLOR_RGBA2BGR)
+                    ret, mask_background = cv2.threshold(mask_background, 124, 255, cv2.THRESH_BINARY)
+                else:
+                    mask = self.get_mask_inv(image)
+                    mask = np.expand_dims(mask, axis=2)
+                    mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR)
+                    mask = cv2.bitwise_not(mask)
+                    # coordinates must be recomputed after rotation
+                    rotate_mask, _ = self.img_rotate(mask, single_print['print_angle_list'][i], single_print['print_scale_list'][i])
+                    rotate_image, rotated_new_size = self.img_rotate(image, single_print['print_angle_list'][i], single_print['print_scale_list'][i])
+                    # x, y = int(result['print']['location'][i][0] - rotated_new_size[0] - (rotate_mask.shape[0] - image.shape[0]) / 2), int(result['print']['location'][i][1] - rotated_new_size[1] - (rotate_mask.shape[1] - image.shape[1]) / 2)
+                    x, y = int(single_print['location'][i][0] - rotated_new_size[0]), int(single_print['location'][i][1] - rotated_new_size[1])
+
+                    image_x = print_background.shape[1]
+                    image_y = print_background.shape[0]
+                    print_x = rotate_image.shape[1]
+                    print_y = rotate_image.shape[0]
+
+                    # buggy, kept for reference:
+                    # if x + print_x > image_x:
+                    #     rotate_image = rotate_image[:, :x + print_x - image_x]
+                    #     rotate_mask = rotate_mask[:, :x + print_x - image_x]
+                    # #
+                    # if y + print_y > image_y:
+                    #     rotate_image = rotate_image[:y + print_y - image_y]
+                    #     rotate_mask = rotate_mask[:y + print_y - image_y]
+
+                    # These checks cannot run independently: the first pair of ifs (around lines 108 and 115)
+                    # tests the bottom/right bounds, the second pair the top/left. In that order, cropping the
+                    # right edge first and then shifting left would corrupt the region.
+                    # So: shift first, then check the bounds, and crop last.
+
+                    # If the print was rotated or sits flush against an edge, check whether the left/top bounds fall below 0
+                    if x <= 0:
+                        rotate_image = rotate_image[:, -x:]
+                        rotate_mask = rotate_mask[:, -x:]
+                        start_x = x = 0
+                    else:
+                        start_x = x
+
+                    if y <= 0:
+                        rotate_image = rotate_image[-y:, :]
+                        rotate_mask = rotate_mask[-y:, :]
+                        start_y = y = 0
+                    else:
+                        start_y = y
+
+                    # ------------------
+                    # If the print is larger than the image, crop the print
+
+                    if x + print_x > image_x:
+                        rotate_image = rotate_image[:, :image_x - x]
+                        rotate_mask = rotate_mask[:, :image_x - x]
+
+                    if y + print_y > image_y:
+                        rotate_image = rotate_image[:image_y - y, :]
+                        rotate_mask = rotate_mask[:image_y - y, :]
+
+                    # mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]] = cv2.bitwise_xor(mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]], rotate_mask)
+                    # print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = cv2.add(print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]], rotate_image)
+
+                    # mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]] = rotate_mask
+                    # print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = rotate_image
+                    mask_background = self.stack_prin(mask_background, result['pattern_image'], rotate_mask, start_y, y, start_x, x)
+                    print_background = self.stack_prin(print_background, result['pattern_image'], rotate_image, start_y, y, start_x, x)
+
+            # gray_image = cv2.cvtColor(mask_background, cv2.COLOR_BGR2GRAY)
+            # print_background = cv2.bitwise_and(print_background, print_background, mask=gray_image)
+
+            print_mask = cv2.bitwise_and(result['mask'], cv2.cvtColor(mask_background, cv2.COLOR_BGR2GRAY))
+            img_fg = cv2.bitwise_or(print_background, print_background, mask=print_mask)
+            img_bg = cv2.bitwise_and(result['pattern_image'], result['pattern_image'], mask=cv2.bitwise_not(print_mask))
+            mask_mo = np.expand_dims(print_mask, axis=2).repeat(3, axis=2)
+            gray_mo = np.expand_dims(result['gray'], axis=2).repeat(3, axis=2)
+            img_fg = (img_fg * (mask_mo / 255) * (gray_mo / 255)).astype(np.uint8)
+            result['final_image'] = cv2.add(img_bg, img_fg)
+            canvas = np.full_like(result['final_image'], 255)
+            temp_bg = np.expand_dims(cv2.bitwise_not(result['mask']), axis=2).repeat(3, axis=2)
+            tmp1 = (canvas * (temp_bg / 255)).astype(np.uint8)
+            temp_fg = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
+            tmp2 = (result['final_image'] * (temp_fg / 255)).astype(np.uint8)
+            result['single_image'] = cv2.add(tmp1, tmp2)
+
+        if element_print['element_path_list']:
+            print_background = np.zeros((result['final_image'].shape[0], result['final_image'].shape[1], 3), dtype=np.uint8)
+            mask_background = np.zeros((result['final_image'].shape[0], result['final_image'].shape[1], 3), dtype=np.uint8)
+            for i in range(len(element_print['element_path_list'])):
+                image, image_mode = self.read_image(element_print['element_path_list'][i])
+                if image_mode == "RGBA":
+                    new_size = (int(image.width * element_print['element_scale_list'][i]), int(image.height * element_print['element_scale_list'][i]))
+
+                    mask = image.split()[3]
+                    resized_source = image.resize(new_size)
+                    resized_source_mask = mask.resize(new_size)
+
+                    rotated_resized_source = resized_source.rotate(-element_print['element_angle_list'][i])
+                    rotated_resized_source_mask = resized_source_mask.rotate(-element_print['element_angle_list'][i])
+
+                    source_image_pil = Image.fromarray(cv2.cvtColor(print_background, cv2.COLOR_BGR2RGB))
+                    source_image_pil_mask = Image.fromarray(cv2.cvtColor(mask_background, cv2.COLOR_BGR2RGB))
+
+                    source_image_pil.paste(rotated_resized_source, (int(element_print['location'][i][0]), int(element_print['location'][i][1])), rotated_resized_source)
+                    source_image_pil_mask.paste(rotated_resized_source_mask, (int(element_print['location'][i][0]), int(element_print['location'][i][1])), rotated_resized_source_mask)
+
+                    print_background = cv2.cvtColor(np.array(source_image_pil), cv2.COLOR_RGBA2BGR)
+                    mask_background = cv2.cvtColor(np.array(source_image_pil_mask), cv2.COLOR_RGBA2BGR)
+                else:
+                    mask = self.get_mask_inv(image)
+                    mask = np.expand_dims(mask, axis=2)
+                    mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR)
+                    mask = cv2.bitwise_not(mask)
+                    # coordinates must be recomputed after rotation
+                    rotate_mask, _ = self.img_rotate(mask, element_print['element_angle_list'][i], element_print['element_scale_list'][i])
+                    rotate_image, rotated_new_size = self.img_rotate(image, element_print['element_angle_list'][i], element_print['element_scale_list'][i])
+                    # x, y = int(result['print']['location'][i][0] - rotated_new_size[0] - (rotate_mask.shape[0] - image.shape[0]) / 2), int(result['print']['location'][i][1] - rotated_new_size[1] - (rotate_mask.shape[1] - image.shape[1]) / 2)
+                    x, y = int(element_print['location'][i][0] - rotated_new_size[0]), int(element_print['location'][i][1] - rotated_new_size[1])
+
+                    image_x = print_background.shape[1]
+                    image_y = print_background.shape[0]
+                    print_x = rotate_image.shape[1]
+                    print_y = rotate_image.shape[0]
+
+                    # buggy, kept for reference:
+                    # if x + print_x > image_x:
+                    #     rotate_image = rotate_image[:, :x + print_x - image_x]
+                    #     rotate_mask = rotate_mask[:, :x + print_x - image_x]
+                    # #
+                    # if y + print_y > image_y:
+                    #     rotate_image = rotate_image[:y + print_y - image_y]
+                    #     rotate_mask = rotate_mask[:y + print_y - image_y]
+
+                    # These checks cannot run independently: the first pair of ifs (around lines 108 and 115)
+                    # tests the bottom/right bounds, the second pair the top/left. In that order, cropping the
+                    # right edge first and then shifting left would corrupt the region.
+                    # So: shift first, then check the bounds, and crop last.
+
+                    # If the print was rotated or sits flush against an edge, check whether the left/top bounds fall below 0
+                    if x <= 0:
+                        rotate_image = rotate_image[:, -x:]
+                        rotate_mask = rotate_mask[:, -x:]
+                        start_x = x = 0
+                    else:
+                        start_x = x
+
+                    if y <= 0:
+                        rotate_image = rotate_image[-y:, :]
+                        rotate_mask = rotate_mask[-y:, :]
+                        start_y = y = 0
+                    else:
+                        start_y = y
+
+                    # ------------------
+                    # If the print is larger than the image, crop the print
+
+                    if x + print_x > image_x:
+                        rotate_image = rotate_image[:, :image_x - x]
+                        rotate_mask = rotate_mask[:, :image_x - x]
+
+                    if y + print_y > image_y:
+                        rotate_image = rotate_image[:image_y - y, :]
+                        rotate_mask = rotate_mask[:image_y - y, :]
+
+                    # mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]] = cv2.bitwise_xor(mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]], rotate_mask)
+                    # print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = cv2.add(print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]], rotate_image)
+
+                    # mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]] = rotate_mask
+                    # print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = rotate_image
+                    mask_background = self.stack_prin(mask_background, result['pattern_image'], rotate_mask, start_y, y, start_x, x)
+                    print_background = self.stack_prin(print_background, result['pattern_image'], rotate_image, start_y, y, start_x, x)
+
+            # gray_image = cv2.cvtColor(mask_background, cv2.COLOR_BGR2GRAY)
+            # print_background = cv2.bitwise_and(print_background, print_background, mask=gray_image)
+
+            print_mask = cv2.bitwise_and(result['mask'], cv2.cvtColor(mask_background, cv2.COLOR_BGR2GRAY))
+            img_fg = cv2.bitwise_or(print_background, print_background, mask=print_mask)
+            # TODO: the element path loses shading information here
+            three_channel_image = cv2.merge([cv2.bitwise_not(print_mask), cv2.bitwise_not(print_mask), cv2.bitwise_not(print_mask)])
+            img_bg = cv2.bitwise_and(result['final_image'], three_channel_image)
+            # mask_mo = np.expand_dims(print_mask, axis=2).repeat(3, axis=2)
+            # gray_mo = np.expand_dims(result['gray'], axis=2).repeat(3, axis=2)
+            # img_fg = (img_fg * (mask_mo / 255) * (gray_mo / 255)).astype(np.uint8)
+            result['final_image'] = cv2.add(img_bg, img_fg)
+            canvas = np.full_like(result['final_image'], 255)
+            temp_bg = np.expand_dims(cv2.bitwise_not(result['mask']), axis=2).repeat(3, axis=2)
+            tmp1 = (canvas * (temp_bg / 255)).astype(np.uint8)
+            temp_fg = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
+            tmp2 = (result['final_image'] * (temp_fg / 255)).astype(np.uint8)
+            result['single_image'] = cv2.add(tmp1, tmp2)
+        return result
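The shift-then-check-then-crop logic shared by the single and element branches above reduces to clamping a paste rectangle against the canvas. A compact, self-contained sketch of the same arithmetic (function and variable names are illustrative):

# Sketch of clamping a possibly off-canvas paste position, as in the branches above.
import numpy as np

def clamp_paste(canvas_h, canvas_w, patch, x, y):
    # Shift first: drop the part of the patch above/left of the canvas.
    if x < 0:
        patch = patch[:, -x:]
        x = 0
    if y < 0:
        patch = patch[-y:, :]
        y = 0
    # Then crop whatever still overflows the right/bottom edges.
    patch = patch[:canvas_h - y, :canvas_w - x]
    return patch, x, y

patch = np.ones((8, 8), dtype=np.uint8)
clamped, x, y = clamp_paste(10, 10, patch, x=-3, y=7)
print(clamped.shape, x, y)   # (3, 5) 0 7 -> 5 columns survive the left clip, 3 rows fit below y=7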
+    @staticmethod
+    def stack_prin(print_background, pattern_image, rotate_image, start_y, y, start_x, x):
+        temp_print = np.zeros((pattern_image.shape[0], pattern_image.shape[1], 3), dtype=np.uint8)
+        temp_print[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = rotate_image
+        img2gray = cv2.cvtColor(temp_print, cv2.COLOR_BGR2GRAY)
+        ret, mask_ = cv2.threshold(img2gray, 1, 255, cv2.THRESH_BINARY)
+        mask_inv = cv2.bitwise_not(mask_)
+        img1_bg = cv2.bitwise_and(print_background, print_background, mask=mask_inv)
+        img2_fg = cv2.bitwise_and(temp_print, temp_print, mask=mask_)
+        print_background = img1_bg + img2_fg
+        return print_background
+
+    def painting_collection(self, painting_dict, print_dict, print_trigger=False, is_single=False):
+        if print_trigger:
+            print_ = self.get_print(print_dict)
+            painting_dict['Trigger'] = not is_single
+            painting_dict['location'] = print_['location']
+            single_mask_inv_print = self.get_mask_inv(print_['image'])
+            dim_max = max(painting_dict['dim_image_h'], painting_dict['dim_image_w'])
+            dim_pattern = (int(dim_max * print_['scale'] / 5), int(dim_max * print_['scale'] / 5))
+            if not is_single:
+                self.random_seed = random.randint(0, 1000)
+                # If the print mode is overall and an angle is set, tile the print into a square to simplify cropping
+                if "print_angle_list" in print_dict.keys() and print_dict['print_angle_list'][0] != 0:
+                    painting_dict['mask_inv_print'] = self.tile_image(single_mask_inv_print, dim_pattern, print_['scale'], dim_max, dim_max, painting_dict['location'], trigger=True)
+                    painting_dict['tile_print'] = self.tile_image(print_['image'], dim_pattern, print_['scale'], dim_max, dim_max, painting_dict['location'], trigger=True)
+                else:
+                    painting_dict['mask_inv_print'] = self.tile_image(single_mask_inv_print, dim_pattern, print_['scale'], painting_dict['dim_image_h'], painting_dict['dim_image_w'], painting_dict['location'], trigger=True)
+                    painting_dict['tile_print'] = self.tile_image(print_['image'], dim_pattern, print_['scale'], painting_dict['dim_image_h'], painting_dict['dim_image_w'], painting_dict['location'], trigger=True)
+            else:
+                painting_dict['mask_inv_print'] = self.tile_image(single_mask_inv_print, dim_pattern, print_['scale'], painting_dict['dim_image_h'], painting_dict['dim_image_w'], painting_dict['location'])
+                painting_dict['tile_print'] = self.tile_image(print_['image'], dim_pattern, print_['scale'], painting_dict['dim_image_h'], painting_dict['dim_image_w'], painting_dict['location'])
+            painting_dict['dim_print_h'], painting_dict['dim_print_w'] = dim_pattern
+        return painting_dict
+    def tile_image(self, pattern, dim, scale, dim_image_h, dim_image_w, location, trigger=False):
+        tile = None
+        if not trigger:
+            tile = cv2.resize(pattern, dim, interpolation=cv2.INTER_AREA)
+        else:
+            resize_pattern = cv2.resize(pattern, dim, interpolation=cv2.INTER_AREA)
+            if len(pattern.shape) == 2:
+                tile = np.tile(resize_pattern, (int((5 + 1) / scale) + 4, int((5 + 1) / scale) + 4))
+            if len(pattern.shape) == 3:
+                tile = np.tile(resize_pattern, (int((5 + 1) / scale) + 4, int((5 + 1) / scale) + 4, 1))
+            tile = self.crop_image(tile, dim_image_h, dim_image_w, location, resize_pattern.shape)
+        return tile
+
+    def get_mask_inv(self, print_):
+        if print_[0][0][0] == 255 and print_[0][0][1] == 255 and print_[0][0][2] == 255:
+            bg_color = cv2.cvtColor(print_, cv2.COLOR_BGR2LAB)[0][0]
+            print_tile = cv2.cvtColor(print_, cv2.COLOR_BGR2LAB)
+            bg_l, bg_a, bg_b = bg_color[0], bg_color[1], bg_color[2]
+            bg_L_high, bg_L_low = self.get_low_high_lab(bg_l, L=True)
+            bg_a_high, bg_a_low = self.get_low_high_lab(bg_a)
+            bg_b_high, bg_b_low = self.get_low_high_lab(bg_b)
+            lower = np.array([bg_L_low, bg_a_low, bg_b_low])
+            upper = np.array([bg_L_high, bg_a_high, bg_b_high])
+            mask_inv = cv2.inRange(print_tile, lower, upper)
+            return mask_inv
+        else:
+            # bg_color = cv2.cvtColor(print_, cv2.COLOR_BGR2LAB)[0][0]
+            # print_tile = cv2.cvtColor(print_, cv2.COLOR_BGR2LAB)
+            # bg_l, bg_a, bg_b = bg_color[0], bg_color[1], bg_color[2]
+            # bg_L_high, bg_L_low = self.get_low_high_lab(bg_l, L=True)
+            # bg_a_high, bg_a_low = self.get_low_high_lab(bg_a)
+            # bg_b_high, bg_b_low = self.get_low_high_lab(bg_b)
+            # lower = np.array([bg_L_low, bg_a_low, bg_b_low])
+            # upper = np.array([bg_L_high, bg_a_high, bg_b_high])
+
+            # print_tile = cv2.cvtColor(print_, cv2.COLOR_BGR2LAB)
+            # mask_inv = cv2.cvtColor(print_tile, cv2.COLOR_BGR2GRAY)
+
+            # mask_inv = cv2.cvtColor(print_, cv2.COLOR_BGR2GRAY)
+            mask_inv = np.zeros(print_.shape[:2], dtype=np.uint8)
+            return mask_inv
+
+    @staticmethod
+    def printpaint(result, painting_dict, print_=False):
+
+        if print_ and painting_dict['Trigger']:
+            print_mask = cv2.bitwise_and(result['mask'], cv2.bitwise_not(painting_dict['mask_inv_print']))
+            img_fg = cv2.bitwise_and(painting_dict['tile_print'], painting_dict['tile_print'], mask=print_mask)
+        else:
+            print_mask = result['mask']
+            img_fg = result['final_image']
+        if print_ and not painting_dict['Trigger']:
+            index_ = None
+            try:
+                index_ = len(painting_dict['location'])
+            except Exception:
+                raise ValueError("a 'location' parameter is required when single mode is chosen")
+
+            for i in range(index_):
+                start_h, start_w = int(painting_dict['location'][i][1]), int(painting_dict['location'][i][0])
+
+                length_h = min(start_h + painting_dict['dim_print_h'], img_fg.shape[0])
+                length_w = min(start_w + painting_dict['dim_print_w'], img_fg.shape[1])
+
+                change_region = img_fg[start_h: length_h, start_w: length_w, :]
+                # problem in change_mask
+                change_mask = print_mask[start_h: length_h, start_w: length_w]
+                # get real part into change mask
+                _, change_mask = cv2.threshold(change_mask, 220, 255, cv2.THRESH_BINARY)
+                mask = cv2.bitwise_not(painting_dict['mask_inv_print'])
+                img_fg[start_h:start_h + painting_dict['dim_print_h'], start_w:start_w + painting_dict['dim_print_w'], :] = change_region
+
+        clothes_mask_print = cv2.bitwise_not(print_mask)
+
+        img_bg = cv2.bitwise_and(result['pattern_image'], result['pattern_image'], mask=clothes_mask_print)
+        mask_mo = np.expand_dims(print_mask, axis=2).repeat(3, axis=2)
+        gray_mo = np.expand_dims(result['gray'], axis=2).repeat(3, axis=2)
+        img_fg = (img_fg * (mask_mo / 255) * (gray_mo / 255)).astype(np.uint8)
+        print_image = cv2.add(img_bg, img_fg)
+        return print_image
+
+    def get_print(self, print_dict):
+        if 'print_scale_list' not in print_dict.keys() or print_dict['print_scale_list'][0] < 0.3:
+            print_dict['scale'] = 0.3
+        else:
+            print_dict['scale'] = print_dict['print_scale_list'][0]
+
+        bucket_name = print_dict['print_path_list'][0].split("/", 1)[0]
+        object_name = print_dict['print_path_list'][0].split("/", 1)[1]
+        image = oss_get_image(oss_client=self.minio_client, bucket=bucket_name, object_name=object_name, data_type="PIL")
+        # If the image is RGBA, paste it onto a pure white background so transparency does not turn black
+        if image.mode == "RGBA":
+            new_background = Image.new('RGB', image.size, (255, 255, 255))
+            new_background.paste(image, mask=image.split()[3])
+            image = new_background
+        print_dict['image'] = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
+        return print_dict
+    def crop_image(self, image, image_size_h, image_size_w, location, print_shape):
+        print_w = print_shape[1]
+        print_h = print_shape[0]
+
+        random.seed(self.random_seed)
+        # logging.info(f'overall print location : {location}')
+        # x_offset = random.randint(0, image.shape[0] - image_size_h)
+        # y_offset = random.randint(0, image.shape[1] - image_size_w)
+
+        # 1. Take the offset modulo the resized print's width/height to get the effective offset
+        x_offset = print_w - int(location[0][1] % print_w)
+        y_offset = print_w - int(location[0][0] % print_h)
+
+        # y_offset = int(location[0][0])
+        # x_offset = int(location[0][1])
+
+        if len(image.shape) == 2:
+            image = image[x_offset: x_offset + image_size_h, y_offset: y_offset + image_size_w]
+        elif len(image.shape) == 3:
+            image = image[x_offset: x_offset + image_size_h, y_offset: y_offset + image_size_w, :]
+        return image
+
+    @staticmethod
+    def get_low_high_lab(Lab_value, L=False):
+        if L:
+            high = Lab_value + 30 if Lab_value + 30 < 255 else 255
+            low = Lab_value - 30 if Lab_value - 30 > 0 else 0
+        else:
+            high = Lab_value + 30 if Lab_value + 30 < 255 else 255
+            low = Lab_value - 30 if Lab_value - 30 > 0 else 0
+        return high, low
+
+    @staticmethod
+    def img_rotate(image, angel, scale):
+        """Rotate an image clockwise by an arbitrary angle.
+
+        Args:
+            image (np.array): the source image
+            angel (float): the rotation angle
+            scale (float): scale factor applied during the rotation
+
+        Returns:
+            [array]: the rotated image, plus the padding offsets added around it
+        """
+
+        h, w = image.shape[:2]
+        center = (w // 2, h // 2)
+        # if type(angel) is not int:
+        #     angel = 0
+        M = cv2.getRotationMatrix2D(center, -angel, scale)
+        # adjust the rotated image's width and height
+        rotated_h = int((w * np.abs(M[0, 1]) + (h * np.abs(M[0, 0]))))
+        rotated_w = int((h * np.abs(M[0, 1]) + (w * np.abs(M[0, 0]))))
+        M[0, 2] += (rotated_w - w) // 2
+        M[1, 2] += (rotated_h - h) // 2
+        # rotate the image
+        rotated_img = cv2.warpAffine(image, M, (rotated_w, rotated_h))
+
+        return rotated_img, ((rotated_img.shape[1] - image.shape[1] * scale) // 2, (rotated_img.shape[0] - image.shape[0] * scale) // 2)
+        # return rotated_img, (0, 0)
+
+    @staticmethod
+    def rotate_crop_image(img, angle, crop):
+        """
+        angle: the rotation angle in degrees
+        crop: whether to crop away the black borders (boolean)
+        """
+        crop_image = lambda img, x0, y0, w, h: img[y0:y0 + h, x0:x0 + w]
+        h, w = img.shape[:2]
+        # the rotation angle has a period of 360°
+        angle %= 360
+        # compute the affine transform matrix
+        M_rotation = cv2.getRotationMatrix2D((w / 2, h / 2), angle, 1)
+        # apply the rotation
+        img_rotated = cv2.warpAffine(img, M_rotation, (w, h))
+
+        # optionally remove the black borders
+        if crop:
+            # for cropping, the effective period is 180°
+            angle_crop = angle % 180
+            if angle > 90:
+                angle_crop = 180 - angle_crop
+            # convert the angle to radians
+            theta = angle_crop * np.pi / 180
+            # height/width ratio
+            hw_ratio = float(h) / float(w)
+            # numerator of the crop side-length coefficient
+            tan_theta = np.tan(theta)
+            numerator = np.cos(theta) + np.sin(theta) * np.tan(theta)
+
+            # aspect-ratio term of the denominator
+            r = hw_ratio if h > w else 1 / hw_ratio
+            # denominator
+            denominator = r * tan_theta + 1
+            # final side-length coefficient
+            crop_mult = numerator / denominator
+
+            # crop region
+            w_crop = int(crop_mult * w)
+            h_crop = int(crop_mult * h)
+            x0 = int((w - w_crop) / 2)
+            y0 = int((h - h_crop) / 2)
+
+            img_rotated = crop_image(img_rotated, x0, y0, w_crop, h_crop)
+
+        return img_rotated
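The crop coefficient above comes from fitting the largest axis-aligned rectangle with the original aspect ratio inside the rotated frame: crop_mult = (cos θ + sin θ · tan θ) / (r · tan θ + 1), with r derived from the height/width ratio. A quick numeric sanity check for a square image:

# Sanity check of crop_mult for a square image rotated by 30 degrees.
import numpy as np

theta = np.deg2rad(30)
numerator = np.cos(theta) + np.sin(theta) * np.tan(theta)
crop_mult = numerator / (1 * np.tan(theta) + 1)   # r = 1 for a square image
print(round(crop_mult, 4))                        # 0.7321 == 1 / (cos 30° + sin 30°), the known square result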
+    def read_image(self, image_url):
+        image = oss_get_image(oss_client=self.minio_client, bucket=image_url.split("/", 1)[0], object_name=image_url.split("/", 1)[1], data_type="cv2")
+        if image.shape[2] == 4:
+            image_rgb = cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA)
+            image = Image.fromarray(image_rgb)
+            image_mode = "RGBA"
+        else:
+            image_mode = "RGB"
+        return image, image_mode
+
+    @staticmethod
+    def resize_and_crop(img, target_width, target_height):
+        # original image size
+        original_height, original_width = img.shape[:2]
+
+        # target aspect ratio
+        target_ratio = target_width / target_height
+
+        # original aspect ratio
+        original_ratio = original_width / original_height
+
+        # resize
+        if original_ratio > target_ratio:
+            # the source is wider: resize by height, then crop the width
+            new_height = target_height
+            new_width = int(original_width * (target_height / original_height))
+            resized_img = cv2.resize(img, (new_width, new_height))
+            # crop the width
+            start_x = (new_width - target_width) // 2
+            cropped_img = resized_img[:, start_x:start_x + target_width]
+        else:
+            # the source is taller: resize by width, then crop the height
+            new_width = target_width
+            new_height = int(original_height * (target_width / original_width))
+            resized_img = cv2.resize(img, (new_width, new_height))
+            # crop the height
+            start_y = (new_height - target_height) // 2
+            cropped_img = resized_img[start_y:start_y + target_height, :]
+
+        return cropped_img
diff --git a/app/service/design_batch/pipeline/scale.py b/app/service/design_batch/pipeline/scale.py
new file mode 100644
index 0000000..1908a9c
--- /dev/null
+++ b/app/service/design_batch/pipeline/scale.py
@@ -0,0 +1,49 @@
+import math
+
+import cv2
+
+
+class Scaling:
+    def __call__(self, result):
+        if result['keypoint'] in ['waistband', 'shoulder', 'head_point']:
+            # milvus_db_keypoint_cache
+            distance_clo = math.sqrt(
+                (int(result['clothes_keypoint'][result['keypoint'] + '_left'][0]) - int(result['clothes_keypoint'][result['keypoint'] + '_right'][0])) ** 2
+                +
+                (int(result['clothes_keypoint'][result['keypoint'] + '_left'][1]) - int(result['clothes_keypoint'][result['keypoint'] + '_right'][1])) ** 2
+            )
+
+            distance_bdy = math.sqrt(
+                (int(result['body_point_test'][result['keypoint'] + '_left'][0])
+                 -
+                 int(result['body_point_test'][result['keypoint'] + '_right'][0])) ** 2 + 1
+            )
+
+            if distance_clo == 0:
+                result['scale'] = 1
+            else:
+                result['scale'] = distance_bdy / distance_clo
+        elif result['keypoint'] == 'toe':
+            distance_bdy = math.sqrt(
+                (int(result['body_point_test']['foot_length'][0]) - int(result['body_point_test']['foot_length'][2])) ** 2
+                +
+                (int(result['body_point_test']['foot_length'][1]) - int(result['body_point_test']['foot_length'][3])) ** 2
+            )
+
+            Blur = cv2.GaussianBlur(result['gray'], (3, 3), 0)
+            Edge = cv2.Canny(Blur, 10, 200)
+            Edge = cv2.dilate(Edge, None)
+            Edge = cv2.erode(Edge, None)
+            Contour, _ = cv2.findContours(Edge, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+            Contours = sorted(Contour, key=cv2.contourArea, reverse=True)
+
+            Max_contour = Contours[0]
+            x, y, w, h = cv2.boundingRect(Max_contour)
+            width = w
+            distance_clo = width
+            result['scale'] = distance_bdy / distance_clo
+        elif result['keypoint'] == 'hand_point':
+            result['scale'] = result['scale_bag']
+        elif result['keypoint'] == 'ear_point':
+            result['scale'] = result['scale_earrings']
+        return result
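Scaling boils down to the ratio between the body's keypoint span and the garment's keypoint span. A trivial worked case, reusing the shoulder points from the test payload (the garment points are made up):

# Worked example of the shoulder-based scale: body span / garment span.
import math

clothes = {'shoulder_left': (108, 107), 'shoulder_right': (212, 107)}
body = {'shoulder_left': (99, 116), 'shoulder_right': (215, 116)}

span_clothes = math.hypot(clothes['shoulder_left'][0] - clothes['shoulder_right'][0],
                          clothes['shoulder_left'][1] - clothes['shoulder_right'][1])   # 104.0
span_body = math.sqrt((body['shoulder_left'][0] - body['shoulder_right'][0]) ** 2 + 1)  # ~116.0 (the +1 avoids a zero span)
print(round(span_body / span_clothes, 3))   # ~1.115: the garment is enlarged about 12%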
diff --git a/app/service/design_batch/pipeline/segmentation.py b/app/service/design_batch/pipeline/segmentation.py
new file mode 100644
index 0000000..cba3446
--- /dev/null
+++ b/app/service/design_batch/pipeline/segmentation.py
@@ -0,0 +1,70 @@
+import logging
+import os
+
+import cv2
+import numpy as np
+
+from app.core.config import SEG_CACHE_PATH
+from app.service.design_batch.utils.design_ensemble import get_seg_result
+from app.service.utils.new_oss_client import oss_get_image
+
+logger = logging.getLogger()
+
+
+class Segmentation:
+    def __init__(self, minio_client):
+        self.minio_client = minio_client
+
+    def __call__(self, result):
+        if "seg_mask_url" in result.keys() and result['seg_mask_url'] != "":
+            seg_mask = oss_get_image(oss_client=self.minio_client, bucket=result['seg_mask_url'].split('/')[0], object_name=result['seg_mask_url'][result['seg_mask_url'].find('/') + 1:], data_type="cv2")
+            seg_mask = cv2.resize(seg_mask, (result['img_shape'][1], result['img_shape'][0]), interpolation=cv2.INTER_NEAREST)
+            # convert the color space to RGB (OpenCV defaults to BGR)
+            image_rgb = cv2.cvtColor(seg_mask, cv2.COLOR_BGR2RGB)
+
+            r, g, b = cv2.split(image_rgb)
+            red_mask = r > g
+            green_mask = g > r
+
+            # build the red (front) and green (back) masks
+            result['front_mask'] = np.array(red_mask, dtype=np.uint8) * 255
+            result['back_mask'] = np.array(green_mask, dtype=np.uint8) * 255
+            result['mask'] = result['front_mask'] + result['back_mask']
+        else:
+            # check whether a local seg cache exists
+            _, seg_result = self.load_seg_result(result["image_id"])
+            result['seg_result'] = seg_result
+            if not _:
+                # run inference to obtain the seg result
+                seg_result = get_seg_result(result["image_id"], result['image'])[0]
+                self.save_seg_result(seg_result, result['image_id'])
+            # split into front and back pieces
+            temp_front = seg_result == 1.0
+            result['front_mask'] = (255 * (temp_front + 0).astype(np.uint8))
+            temp_back = seg_result == 2.0
+            result['back_mask'] = (255 * (temp_back + 0).astype(np.uint8))
+            result['mask'] = result['front_mask'] + result['back_mask']
+        return result
+
+    @staticmethod
+    def save_seg_result(seg_result, image_id):
+        file_path = f"seg_cache/{image_id}.npy"
+        try:
+            np.save(file_path, seg_result)
+            logger.info(f"saved successfully: {os.path.abspath(file_path)}")
+        except Exception as e:
+            logger.error(f"save failed: {e}")
+
+    @staticmethod
+    def load_seg_result(image_id):
+        file_path = f"seg_cache/{image_id}.npy"
+        logger.info(f"load seg file name is: {file_path}")
+        try:
+            seg_result = np.load(file_path)
+            return True, seg_result
+        except FileNotFoundError:
+            logger.warning("seg cache file does not exist")
+            return False, None
+        except Exception as e:
+            logger.error(f"load failed: {e}")
+            return False, None
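The seg-mask convention above encodes the front piece in red and the back piece in green; a tiny decode example under that assumption:

# Decode a red/green seg mask into front/back binary masks, as above.
import numpy as np

rgb = np.zeros((2, 2, 3), dtype=np.uint8)
rgb[0, 0] = [200, 10, 0]   # red-ish pixel -> front
rgb[1, 1] = [10, 200, 0]   # green-ish pixel -> back

r, g = rgb[:, :, 0], rgb[:, :, 1]
front = (r > g).astype(np.uint8) * 255
back = (g > r).astype(np.uint8) * 255
print(front[0, 0], back[1, 1])   # 255 255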
diff --git a/app/service/design_batch/pipeline/split.py b/app/service/design_batch/pipeline/split.py
new file mode 100644
index 0000000..5dbcef5
--- /dev/null
+++ b/app/service/design_batch/pipeline/split.py
@@ -0,0 +1,74 @@
+import io
+import logging
+
+import cv2
+import numpy as np
+from PIL import Image
+from cv2 import cvtColor, COLOR_BGR2RGBA
+
+from app.core.config import AIDA_CLOTHING
+from app.service.design_batch.utils.conversion_image import rgb_to_rgba
+from app.service.design_batch.utils.upload_image import upload_png_mask
+from app.service.utils.generate_uuid import generate_uuid
+from app.service.utils.new_oss_client import oss_upload_image
+
+
+class Split(object):
+    def __init__(self, minio_client):
+        self.minio_client = minio_client
+
+    def __call__(self, result):
+        try:
+
+            if result['name'] in ('outwear', 'dress', 'blouse', 'skirt', 'trousers', 'tops', 'bottoms'):
+                front_mask = result['front_mask']
+                back_mask = result['back_mask']
+                rgba_image = rgb_to_rgba(result['final_image'], front_mask + back_mask)
+                new_size = (int(rgba_image.shape[1] * result["scale"] * result["resize_scale"][0]), int(rgba_image.shape[0] * result["scale"] * result["resize_scale"][1]))
+                rgba_image = cv2.resize(rgba_image, new_size)
+                result_front_image = np.zeros_like(rgba_image)
+                front_mask = cv2.resize(front_mask, new_size)
+                result_front_image[front_mask != 0] = rgba_image[front_mask != 0]
+                result_front_image_pil = Image.fromarray(cvtColor(result_front_image, COLOR_BGR2RGBA))
+                result['front_image'], result["front_image_url"], _ = upload_png_mask(self.minio_client, result_front_image_pil, f'{generate_uuid()}', mask=None)
+
+                height, width = front_mask.shape
+                mask_image = np.zeros((height, width, 3))
+                mask_image[front_mask != 0] = [0, 0, 255]
+
+                if result["name"] in ('blouse', 'dress', 'outwear', 'tops'):
+                    result_back_image = np.zeros_like(rgba_image)
+                    back_mask = cv2.resize(back_mask, new_size)
+                    result_back_image[back_mask != 0] = rgba_image[back_mask != 0]
+                    result_back_image_pil = Image.fromarray(cvtColor(result_back_image, COLOR_BGR2RGBA))
+                    result['back_image'], result["back_image_url"], _ = upload_png_mask(self.minio_client, result_back_image_pil, f'{generate_uuid()}', mask=None)
+                    mask_image[back_mask != 0] = [0, 255, 0]
+
+                    rgba_mask = rgb_to_rgba(mask_image, front_mask + back_mask)
+                    mask_pil = Image.fromarray(cvtColor(rgba_mask.astype(np.uint8), COLOR_BGR2RGBA))
+                    image_data = io.BytesIO()
+                    mask_pil.save(image_data, format='PNG')
+                    image_data.seek(0)
+                    image_bytes = image_data.read()
+                    req = oss_upload_image(oss_client=self.minio_client, bucket=AIDA_CLOTHING, object_name=f"mask/mask_{generate_uuid()}.png", image_bytes=image_bytes)
+                    result['mask_url'] = req.bucket_name + "/" + req.object_name
+                else:
+                    rgba_mask = rgb_to_rgba(mask_image, front_mask)
+                    mask_pil = Image.fromarray(cvtColor(rgba_mask.astype(np.uint8), COLOR_BGR2RGBA))
+                    image_data = io.BytesIO()
+                    mask_pil.save(image_data, format='PNG')
+                    image_data.seek(0)
+                    image_bytes = image_data.read()
+                    req = oss_upload_image(oss_client=self.minio_client, bucket=AIDA_CLOTHING, object_name=f"mask/mask_{generate_uuid()}.png", image_bytes=image_bytes)
+                    result['mask_url'] = req.bucket_name + "/" + req.object_name
+                    result['back_image'] = None
+                    result["back_image_url"] = None
+                    # result["back_mask_url"] = None
+                    # result['back_mask_image'] = None
+            # create the intermediate (pattern) layer
+            result_pattern_image_rgba = rgb_to_rgba(result['pattern_image'], result['mask'])
+            result_pattern_image_pil = Image.fromarray(cvtColor(result_pattern_image_rgba, COLOR_BGR2RGBA))
+            result['pattern_image'], result['pattern_image_url'], _ = upload_png_mask(self.minio_client, result_pattern_image_pil, f'{generate_uuid()}')
+            return result
+        except Exception as e:
+            logging.warning(f"split runtime exception : {e} image_id : {result['image_id']}")
diff --git a/app/service/design_batch/service.py b/app/service/design_batch/service.py
new file mode 100644
index 0000000..ca6908e
--- /dev/null
+++ b/app/service/design_batch/service.py
@@ -0,0 +1,11 @@
+import json
+
+from app.service.design_batch.design_batch_celery import batch_design
+from app.service.design_batch.utils.MQ import publish_status
+
+
+async def start_design_batch_generate(data, file):
+    generate_clothes_task = batch_design.delay(json.loads(file.decode())['objects'], data.total, data.tasks_id)
+    print(generate_clothes_task)
+    publish_status(data.tasks_id, "0/100", "")
+    return {"task_id": data.tasks_id}
diff --git a/app/service/design_batch/test.py b/app/service/design_batch/test.py
new file mode 100644
index 0000000..6b94bc6
--- /dev/null
+++ b/app/service/design_batch/test.py
@@ -0,0 +1,162 @@
+from app.service.design_batch.design_batch_celery import batch_design
+
+if __name__ == 
'__main__': + data = { + "objects": [ + { + "basic": { + "body_point_test": { + "waistband_right": [ + 200, + 241 + ], + "hand_point_right": [ + 223, + 297 + ], + "waistband_left": [ + 112, + 241 + ], + "hand_point_left": [ + 92, + 305 + ], + "shoulder_left": [ + 99, + 116 + ], + "shoulder_right": [ + 215, + 116 + ] + }, + "layer_order": True, + "scale_bag": 0.7, + "scale_earrings": 0.16, + "self_template": True, + "single_overall": "overall", + "switch_category": "" + }, + "items": [ + { + "businessId": 270372, + "color": "30 28 28", + "image_id": 69780, + "offset": [ + 0, + 0 + ], + "path": "aida-sys-image/images/female/trousers/0825000630.jpg", + "print": { + "element": { + "element_angle_list": [], + "element_path_list": [], + "element_scale_list": [], + "location": [] + }, + "overall": { + "location": [], + "print_angle_list": [], + "print_path_list": [], + "print_scale_list": [] + }, + "single": { + "location": [], + "print_angle_list": [], + "print_path_list": [], + "print_scale_list": [] + } + }, + "priority": 10, + "resize_scale": [ + 1.0, + 1.0 + ], + "type": "Trousers" + }, + { + "businessId": 270373, + "color": "30 28 28", + "image_id": 98243, + "offset": [ + 0, + 0 + ], + "path": "aida-sys-image/images/female/blouse/0902003811.jpg", + "print": { + "element": { + "element_angle_list": [], + "element_path_list": [], + "element_scale_list": [], + "location": [] + }, + "overall": { + "location": [], + "print_angle_list": [], + "print_path_list": [], + "print_scale_list": [] + }, + "single": { + "location": [], + "print_angle_list": [], + "print_path_list": [], + "print_scale_list": [] + } + }, + "priority": 11, + "resize_scale": [ + 1.0, + 1.0 + ], + "type": "Blouse" + }, + { + "businessId": 270374, + "color": "172 68 68", + "image_id": 98244, + "offset": [ + 0, + 0 + ], + "path": "aida-sys-image/images/female/outwear/0825000410.jpg", + "print": { + "element": { + "element_angle_list": [], + "element_path_list": [], + "element_scale_list": [], + "location": [] + }, + "overall": { + "location": [], + "print_angle_list": [], + "print_path_list": [], + "print_scale_list": [] + }, + "single": { + "location": [], + "print_angle_list": [], + "print_path_list": [], + "print_scale_list": [] + } + }, + "priority": 12, + "resize_scale": [ + 1.0, + 1.0 + ], + "type": "Outwear" + }, + { + "body_path": "aida-sys-image/models/female/5bdfe7ca-64eb-44e4-b03d-8e517520c795.png", + "image_id": 96090, + "type": "Body" + } + ] + } + ], + "process_id": "83" + } + task_id = 1 + json_name = "test.json" + batch_design.delay(data['objects'], task_id, json_name) diff --git a/app/service/design_batch/utils/MQ.py b/app/service/design_batch/utils/MQ.py new file mode 100644 index 0000000..50e98c2 --- /dev/null +++ b/app/service/design_batch/utils/MQ.py @@ -0,0 +1,17 @@ +import json + +import pika + + +def publish_status(task_id, progress, result): + connection = pika.BlockingConnection(pika.ConnectionParameters('10.1.2.213')) + channel = connection.channel() + channel.queue_declare(queue='DesignBatch', durable=True) + message = {'task_id': task_id, 'progress': progress, "result": result} + channel.basic_publish(exchange='', + routing_key='DesignBatch', + body=json.dumps(message), + properties=pika.BasicProperties( + delivery_mode=2, + )) + connection.close() diff --git a/app/service/design/core/__init__.py b/app/service/design_batch/utils/__init__.py similarity index 100% rename from app/service/design/core/__init__.py rename to app/service/design_batch/utils/__init__.py diff --git 
a/app/service/design_batch/utils/conversion_image.py b/app/service/design_batch/utils/conversion_image.py
new file mode 100644
index 0000000..11e39ae
--- /dev/null
+++ b/app/service/design_batch/utils/conversion_image.py
@@ -0,0 +1,31 @@
+#!/usr/bin/env python
+# -*- coding: UTF-8 -*-
+"""
+@Project :trinity_client
+@File    :conversion_image.py
+@Author  :周成融
+@Date    :2023/8/21 10:40:29
+@detail  :
+"""
+import numpy as np
+
+
+# def rgb_to_rgba(rgb_size, rgb_image, mask):
+#     alpha_channel = np.full(rgb_size, 255, dtype=np.uint8)
+#     # build the four-channel result image
+#     rgba_image = np.dstack((rgb_image, alpha_channel))
+#     alpha_channel = np.where(mask > 0, 255, 0)
+#     # update the RGBA image's alpha channel
+#     rgba_image[:, :, 3] = alpha_channel
+#     return rgba_image
+
+def rgb_to_rgba(rgb_image, mask):
+    # build the alpha channel from the mask
+    alpha_channel = np.where(mask > 0, 255, 0).astype(np.uint8)
+    # merge the RGB image with the alpha channel
+    rgba_image = np.dstack((rgb_image, alpha_channel))
+    return rgba_image
+
+
+if __name__ == '__main__':
+    image = open("")
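A quick usage sketch of rgb_to_rgba as defined above, with synthetic inputs:

# rgb_to_rgba turns a mask into the alpha channel of an RGB image.
import numpy as np

from app.service.design_batch.utils.conversion_image import rgb_to_rgba

rgb = np.full((2, 2, 3), 128, dtype=np.uint8)
mask = np.array([[255, 0], [0, 255]], dtype=np.uint8)
rgba = rgb_to_rgba(rgb, mask)
print(rgba.shape, rgba[0, 0, 3], rgba[0, 1, 3])   # (2, 2, 4) 255 0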
diff --git a/app/service/design_batch/utils/design_ensemble.py b/app/service/design_batch/utils/design_ensemble.py
new file mode 100644
index 0000000..f4f6a34
--- /dev/null
+++ b/app/service/design_batch/utils/design_ensemble.py
@@ -0,0 +1,142 @@
+#!/usr/bin/env python
+# -*- coding: UTF-8 -*-
+"""
+@Project :trinity_client
+@File    :design_ensemble.py
+@Author  :周成融
+@Date    :2023/8/16 19:36:21
+@detail  :issue inference requests and fetch the results
+"""
+import logging
+
+import cv2
+import mmcv
+import numpy as np
+import torch
+import torch.nn.functional as F
+import tritonclient.http as httpclient
+
+from app.core.config import *
+
+"""
+    keypoint
+    preprocessing, inference, postprocessing
+"""
+
+
+def keypoint_preprocess(img_path):
+    img = mmcv.imread(img_path)
+    img_scale = (256, 256)
+    h, w = img.shape[:2]
+    img = cv2.resize(img, img_scale)
+    w_scale = img_scale[0] / w
+    h_scale = img_scale[1] / h
+    img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
+    preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
+    return preprocessed_img, (w_scale, h_scale)
+
+
+# @ RunTime
+# inference
+def get_keypoint_result(image, site):
+    keypoint_result = None
+    try:
+        image, scale_factor = keypoint_preprocess(image)
+        client = httpclient.InferenceServerClient(url=DESIGN_MODEL_URL)
+        transformed_img = image.astype(np.float32)
+        inputs = [httpclient.InferInput(f"input", transformed_img.shape, datatype="FP32")]
+        inputs[0].set_data_from_numpy(transformed_img, binary_data=True)
+        outputs = [httpclient.InferRequestedOutput(f"output", binary_data=True)]
+        results = client.infer(model_name=f"keypoint_{site}_ocrnet_hr18", inputs=inputs, outputs=outputs)
+        inference_output = torch.from_numpy(results.as_numpy(f'output'))
+        keypoint_result = keypoint_postprocess(inference_output, scale_factor)
+    except Exception as e:
+        logging.warning(f"get_keypoint_result : {e}")
+    return keypoint_result
+
+
+def keypoint_postprocess(output, scale_factor):
+    max_indices = torch.argmax(output.view(output.size(0), output.size(1), -1), dim=2).unsqueeze(dim=2)
+    max_coords = torch.cat((max_indices / output.size(3), max_indices % output.size(3)), dim=2)
+    segment_result = max_coords.numpy()
+    scale_factor = [1 / x for x in scale_factor[::-1]]
+    scale_matrix = np.diag(scale_factor)
+    nan = np.isinf(scale_matrix)
+    scale_matrix[nan] = 0
+    return np.ceil(np.dot(segment_result, scale_matrix) * 4)
+
+
+"""
+    seg
+    preprocessing, inference, postprocessing
+"""
+
+
+# KNet
+def seg_preprocess(img_path):
+    img = mmcv.imread(img_path)
+    ori_shape = img.shape[:2]
+    img_scale_w, img_scale_h = ori_shape
+    if ori_shape[0] > 1024:
+        img_scale_w = 1024
+    if ori_shape[1] > 1024:
+        img_scale_h = 1024
+    # if either side of the image exceeds 1024, it is resized down to 1024
+    if ori_shape != (img_scale_w, img_scale_h):
+        # mmcv.imresize(img, img_scale_h, img_scale_w)  # old code, kept as a cautionary tale: h and w were swapped
+        img = cv2.resize(img, (img_scale_h, img_scale_w))
+    img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
+    preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
+    return preprocessed_img, ori_shape
+
+
+# @ RunTime
+def get_seg_result(image_id, image):
+    image, ori_shape = seg_preprocess(image)
+    client = httpclient.InferenceServerClient(url=f"{DESIGN_MODEL_URL}")
+    transformed_img = image.astype(np.float32)
+    # inputs
+    inputs = [
+        httpclient.InferInput(SEGMENTATION['input'], transformed_img.shape, datatype="FP32")
+    ]
+    inputs[0].set_data_from_numpy(transformed_img, binary_data=True)
+    # outputs
+    outputs = [
+        httpclient.InferRequestedOutput(SEGMENTATION['output'], binary_data=True),
+    ]
+    results = client.infer(model_name=SEGMENTATION['new_model_name'], inputs=inputs, outputs=outputs)
+    # run inference
+    # and fetch the result
+    inference_output1 = results.as_numpy(SEGMENTATION['output'])
+    seg_result = seg_postprocess(int(image_id), inference_output1, ori_shape)
+    return seg_result
+
+
+# no cache
+def seg_postprocess(image_id, output, ori_shape):
+    seg_logit = F.interpolate(torch.tensor(output).float(), size=ori_shape, scale_factor=None, mode='bilinear', align_corners=False)
+    seg_pred = seg_logit.cpu().numpy()
+    return seg_pred[0]
+
+
+def key_point_show(image_path, key_point_result=None):
+    img = cv2.imread(image_path)
+    points_list = key_point_result
+    point_size = 1
+    point_color = (0, 0, 255)  # BGR
+    thickness = 4  # may be 0, 4, or 8
+    for point in points_list:
+        cv2.circle(img, point[::-1], point_size, point_color, thickness)
+    cv2.imshow("0", img)
+    cv2.waitKey(0)
+
+
+if __name__ == '__main__':
+    image = cv2.imread("9070101c-e5be-49b5-9602-4113a968969b.png")
+    a = get_keypoint_result(image, "up")
+    new_list = []
+    for i in a[0]:
+        new_list.append((int(i[0]), int(i[1])))
+    key_point_show("9070101c-e5be-49b5-9602-4113a968969b.png", new_list)
+    # a = get_seg_result(1, image)
+    print(a)
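End to end, the helpers above go image → preprocess → Triton HTTP inference → postprocess. A hedged driver sketch using only the functions defined in this file (the image path is a placeholder):

# Hypothetical driver for the keypoint helpers defined above.
import cv2

from app.service.design_batch.utils.design_ensemble import get_keypoint_result

image = cv2.imread("example_garment.png")       # placeholder path
points = get_keypoint_result(image, site="up")  # 'up' for blouse/outwear/dress/tops
if points is not None:
    # points[0] holds (row, col) pairs scaled back to the original resolution
    print([(int(p[0]), int(p[1])) for p in points[0]])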
diff --git a/app/service/design_batch/utils/organize.py b/app/service/design_batch/utils/organize.py
new file mode 100644
index 0000000..8190de0
--- /dev/null
+++ b/app/service/design_batch/utils/organize.py
@@ -0,0 +1,77 @@
+import cv2
+
+from app.core.config import PRIORITY_DICT
+
+
+def organize_body(layer):
+    body_layer = dict(priority=0,
+                      name=layer["name"].lower(),
+                      image=layer['body_image'],
+                      image_url=layer['body_path'],
+                      mask_image=None,
+                      mask_url=None,
+                      sacle=1,
+                      # mask=layer['body_mask'],
+                      position=(0, 0))
+    return body_layer
+
+
+def organize_clothing(layer):
+    # starting coordinates
+    start_point = calculate_start_point(layer['keypoint'], layer['scale'], layer['clothes_keypoint'], layer['body_point_test'], layer["offset"], layer["resize_scale"])
+    # front-piece layer data
+    front_layer = dict(priority=layer['priority'] if layer.get("layer_order", False) else PRIORITY_DICT.get(f'{layer["name"].lower()}_front', None),
+                       name=f'{layer["name"].lower()}_front',
+                       image=layer["front_image"],
+                       # mask_image=layer['front_mask_image'],
+                       image_url=layer['front_image_url'],
+                       mask_url=layer['mask_url'],
+                       sacle=layer['scale'],
+                       clothes_keypoint=layer['clothes_keypoint'],
+                       position=start_point,
+                       resize_scale=layer["resize_scale"],
+                       mask=cv2.resize(layer['mask'], layer["front_image"].size),
+                       gradient_string=layer['gradient_string'] if 'gradient_string' in layer.keys() else "",
+                       pattern_image_url=layer['pattern_image_url'],
+                       pattern_image=layer['pattern_image']
+
+                       )
+    # back-piece layer data
+    back_layer = dict(priority=-layer.get("priority", 0) if layer.get("layer_order", False) else PRIORITY_DICT.get(f'{layer["name"].lower()}_back', None),
+                      name=f'{layer["name"].lower()}_back',
+                      image=layer["back_image"],
+                      # mask_image=layer['back_mask_image'],
+                      image_url=layer['back_image_url'],
+                      mask_url=layer['mask_url'],
+                      sacle=layer['scale'],
+                      clothes_keypoint=layer['clothes_keypoint'],
+                      position=start_point,
+                      resize_scale=layer["resize_scale"],
+                      mask=cv2.resize(layer['mask'], layer["front_image"].size),
+                      gradient_string=layer['gradient_string'] if 'gradient_string' in layer.keys() else "",
+                      pattern_image_url=layer['pattern_image_url'],
+                      )
+    return front_layer, back_layer
+
+
+def calculate_start_point(keypoint_type, scale, clothes_point, body_point, offset, resize_scale):
+    """
+    Align left
+    Args:
+        keypoint_type: string, "waistband" | "shoulder" | "ear_point"
+        scale: float
+        clothes_point: dict{'left': [x1, y1, z1], 'right': [x2, y2, z2]}
+        body_point: dict, containing keypoint data of body figure
+
+    Returns:
+        start_point: tuple (x', y')
+            x' = y_body - y1 * scale + offset
+            y' = x_body - x1 * scale + offset
+
+    """
+    side_indicator = f'{keypoint_type}_left'
+    start_point = (
+        int(body_point[side_indicator][1] + offset[1] - int(clothes_point[side_indicator][0]) * scale),  # y
+        int(body_point[side_indicator][0] + offset[0] - int(clothes_point[side_indicator][1]) * scale)  # x
+    )
+    return start_point
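A worked example of calculate_start_point: with shoulder alignment, the garment's left shoulder is moved onto the body's left shoulder after scaling. The body point below reuses the test payload; the garment keypoint and scale are made up:

# Worked example of calculate_start_point for shoulder alignment.
from app.service.design_batch.utils.organize import calculate_start_point

clothes_point = {'shoulder_left': [100, 90]}   # hypothetical garment keypoint (row, col)
body_point = {'shoulder_left': [99, 116]}      # body keypoint from the test payload (x, y)
start = calculate_start_point('shoulder', 1.116, clothes_point, body_point,
                              offset=(0, 0), resize_scale=(1.0, 1.0))
print(start)   # (4, -1): paste the scaled garment 4px down and 1px left of the canvas origin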
diff --git a/app/service/design_batch/utils/redis_utils.py b/app/service/design_batch/utils/redis_utils.py
new file mode 100644
index 0000000..012fbe0
--- /dev/null
+++ b/app/service/design_batch/utils/redis_utils.py
@@ -0,0 +1,99 @@
+import redis
+
+from app.core.config import REDIS_HOST, REDIS_PORT
+
+
+class Redis(object):
+    """
+    Redis database helpers
+    """
+
+    @staticmethod
+    def _get_r():
+        host = REDIS_HOST
+        port = REDIS_PORT
+        db = 0
+        r = redis.StrictRedis(host, port, db)
+        return r
+
+    @classmethod
+    def write(cls, key, value, expire=None):
+        """
+        Write a key/value pair
+        """
+        # use the caller's TTL when given, otherwise fall back to 100 seconds
+        if expire:
+            expire_in_seconds = expire
+        else:
+            expire_in_seconds = 100
+        r = cls._get_r()
+        r.set(key, value, ex=expire_in_seconds)
+
+    @classmethod
+    def read(cls, key):
+        """
+        Read the value stored for a key
+        """
+        r = cls._get_r()
+        value = r.get(key)
+        return value.decode('utf-8') if value else value
+
+    @classmethod
+    def hset(cls, name, key, value):
+        """
+        Write a field into a hash
+        """
+        r = cls._get_r()
+        r.hset(name, key, value)
+
+    @classmethod
+    def hget(cls, name, key):
+        """
+        Read a field from a hash
+        """
+        r = cls._get_r()
+        value = r.hget(name, key)
+        return value.decode('utf-8') if value else value
+
+    @classmethod
+    def hgetall(cls, name):
+        """
+        Get all fields and values of a hash
+        """
+        r = cls._get_r()
+        return r.hgetall(name)
+
+    @classmethod
+    def delete(cls, *names):
+        """
+        Delete one or more keys
+        """
+        r = cls._get_r()
+        r.delete(*names)
+
+    @classmethod
+    def hdel(cls, name, key):
+        """
+        Delete a field from a hash
+        """
+        r = cls._get_r()
+        r.hdel(name, key)
+
+    @classmethod
+    def expire(cls, name, expire=None):
+        """
+        Set a key's time to live
+        """
+        if expire:
+            expire_in_seconds = expire
+        else:
+            expire_in_seconds = 100
+        r = cls._get_r()
+        r.expire(name, expire_in_seconds)
+
+
+if __name__ == '__main__':
+    redis_client = Redis()
+    # print(redis_client.write(key="1230", value=0))
+    redis_client.write(key="1230", value=10)
+    # print(redis_client.read(key="1230"))
diff --git a/app/service/design_batch/utils/save_json.py b/app/service/design_batch/utils/save_json.py
new file mode 100644
index 0000000..9acd916
--- /dev/null
+++ b/app/service/design_batch/utils/save_json.py
@@ -0,0 +1,13 @@
+import json
+import logging
+
+logger = logging.getLogger()
+
+
+def oss_upload_json(oss_client, json_data, object_name):
+    try:
+        with open(f"app/service/design_batch/response_json/{object_name}", 'w') as file:
+            json.dump(json_data, file, indent=4)
+        oss_client.fput_object("test", object_name, f"app/service/design_batch/response_json/{object_name}")
+    except Exception as e:
+        logger.warning(str(e))
diff --git a/app/service/design_batch/utils/synthesis_item.py b/app/service/design_batch/utils/synthesis_item.py
new file mode 100644
index 0000000..272ab23
--- /dev/null
+++ b/app/service/design_batch/utils/synthesis_item.py
@@ -0,0 +1,197 @@
+#!/usr/bin/env python
+# -*- coding: UTF-8 -*-
+"""
+@Project :trinity_client
+@File    :synthesis_item.py
+@Author  :周成融
+@Date    :2023/8/26 14:13:04
+@detail  :layer compositing helpers for batch design synthesis
+"""
+import io
+import logging
+
+import cv2
+import numpy as np
+from PIL import Image
+
+from app.service.utils.generate_uuid import generate_uuid
+from app.service.utils.oss_client import oss_upload_image
+
+
+def positioning(all_mask_shape, mask_shape, offset):
+    """
+    Clamp a 1-D placement: given the canvas extent (all_mask_shape), the mask extent
+    (mask_shape) and the placement offset, return the overlapping index ranges
+    (all_start, all_end, mask_start, mask_end) on the canvas and the mask.
+    """
+    all_start = 0
+    all_end = 0
+    mask_start = 0
+    mask_end = 0
+    if offset == 0:
+        all_start = 0
+        all_end = min(all_mask_shape, mask_shape)
+
+        mask_start = 0
+        mask_end = min(all_mask_shape, mask_shape)
+    elif offset > 0:
+        all_start = min(offset, all_mask_shape)
+        all_end = min(offset + mask_shape, all_mask_shape)
+
+        mask_start = 0
+        mask_end = 0 if offset > all_mask_shape else min(all_mask_shape - offset, mask_shape)
+    elif offset < 0:
+        if abs(offset) > mask_shape:
+            all_start = 0
+            all_end = 0
+        else:
+            all_start = 0
+            if mask_shape - abs(offset) > all_mask_shape:
+                all_end = min(mask_shape - abs(offset), all_mask_shape)
+            else:
+                all_end = mask_shape - abs(offset)
+
+        if abs(offset) > mask_shape:
+            mask_start = mask_shape
+            mask_end = mask_shape
+        else:
+            mask_start = abs(offset)
+            if mask_shape - abs(offset) >= all_mask_shape:
+                mask_end = all_mask_shape + abs(offset)
+            else:
+                mask_end = mask_shape
+    return all_start, all_end, mask_start, mask_end
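positioning clamps a 1-D placement to the canvas; a few concrete calls, hand-checked against the branches above, make the three offset cases easier to verify:

    from app.service.design_batch.utils.synthesis_item import positioning

    # mask fully inside the canvas
    print(positioning(all_mask_shape=100, mask_shape=40, offset=0))    # (0, 40, 0, 40)
    # mask hangs off the far edge: only its first 20 pixels fit
    print(positioning(all_mask_shape=100, mask_shape=40, offset=80))   # (80, 100, 0, 20)
    # mask starts 10 pixels before the canvas: its first 10 pixels are cropped
    print(positioning(all_mask_shape=100, mask_shape=40, offset=-10))  # (0, 30, 10, 40)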
+
+
+# @RunTime
+def synthesis(data, size, basic_info):
+    # create the base canvas
+    base_image = Image.new('RGBA', size, (0, 0, 0, 0))
+    try:
+        all_mask_shape = (size[1], size[0])
+        body_mask = None
+        for d in data:
+            if d['name'] == 'body' or d['name'] == 'mannequin':
+                # paste the mannequin onto a fresh transparent canvas to obtain its alpha mask
+                transparent_image = Image.new("RGBA", size, (0, 0, 0, 0))
+                transparent_image.paste(d['image'], (d['adaptive_position'][1], d['adaptive_position'][0]), d['image'])  # paste can mutate a mutable position sequence, so index into it explicitly
+                body_mask = np.array(transparent_image.split()[3])
+
+                # derive the new shoulder points from the new coordinates
+                left_shoulder = [x + y for x, y in zip(basic_info['body_point_test']['shoulder_left'], [d['adaptive_position'][1], d['adaptive_position'][0]])]
+                right_shoulder = [x + y for x, y in zip(basic_info['body_point_test']['shoulder_right'], [d['adaptive_position'][1], d['adaptive_position'][0]])]
+                body_mask[:min(left_shoulder[1], right_shoulder[1]), left_shoulder[0]:right_shoulder[0]] = 255
+        _, binary_body_mask = cv2.threshold(body_mask, 127, 255, cv2.THRESH_BINARY)
+        top_outer_mask = np.array(binary_body_mask)
+        bottom_outer_mask = np.array(binary_body_mask)
+
+        top = True
+        bottom = True
+        i = len(data)
+        while i:
+            i -= 1
+            if top and data[i]['name'] in ["blouse_front", "outwear_front", "dress_front", "tops_front"]:
+                top = False
+                mask_shape = data[i]['mask'].shape
+                y_offset, x_offset = data[i]['adaptive_position']
+                # initialise the start/end bounds of the overlay region
+                all_y_start, all_y_end, mask_y_start, mask_y_end = positioning(all_mask_shape=all_mask_shape[0], mask_shape=mask_shape[0], offset=y_offset)
+                all_x_start, all_x_end, mask_x_start, mask_x_end = positioning(all_mask_shape=all_mask_shape[1], mask_shape=mask_shape[1], offset=x_offset)
+                # copy the corresponding pixels into the overlay region
+                _, sketch_mask = cv2.threshold(data[i]['mask'], 127, 255, cv2.THRESH_BINARY)
+                background = np.zeros_like(top_outer_mask)
+                background[all_y_start:all_y_end, all_x_start:all_x_end] = sketch_mask[mask_y_start:mask_y_end, mask_x_start:mask_x_end]
+                top_outer_mask = background + top_outer_mask
+            elif bottom and data[i]['name'] in ["trousers_front", "skirt_front", "bottoms_front", "dress_front"]:
+                bottom = False
+                mask_shape = data[i]['mask'].shape
+                y_offset, x_offset = data[i]['adaptive_position']
+                # initialise the start/end bounds of the overlay region
+                all_y_start, all_y_end, mask_y_start, mask_y_end = positioning(all_mask_shape=all_mask_shape[0], mask_shape=mask_shape[0], offset=y_offset)
+                all_x_start, all_x_end, mask_x_start, mask_x_end = positioning(all_mask_shape=all_mask_shape[1], mask_shape=mask_shape[1], offset=x_offset)
+                # copy the corresponding pixels into the overlay region
+                _, sketch_mask = cv2.threshold(data[i]['mask'], 127, 255, cv2.THRESH_BINARY)
+                background = np.zeros_like(top_outer_mask)
+                background[all_y_start:all_y_end, all_x_start:all_x_end] = sketch_mask[mask_y_start:mask_y_end, mask_x_start:mask_x_end]
+                bottom_outer_mask = background + bottom_outer_mask
+            elif bottom is False and top is False:
+                break
+
+        all_mask = cv2.bitwise_or(top_outer_mask, bottom_outer_mask)
+
+        for layer in data:
+            if layer['image'] is not None:
+                if layer['name'] != "body":
+                    test_image = Image.new('RGBA', size, (0, 0, 0, 0))
+                    test_image.paste(layer['image'], (layer['adaptive_position'][1], layer['adaptive_position'][0]), layer['image'])
+                    mask_data = np.where(all_mask > 0, 255, 0).astype(np.uint8)
+                    mask_alpha = Image.fromarray(mask_data)
+                    cropped_image = Image.composite(test_image, Image.new("RGBA", test_image.size, (255, 255, 255, 0)), mask_alpha)
+                    base_image.paste(test_image, (0, 0), cropped_image)  # test_image is already pasted at its own coordinates on the widest canvas, so paste at (0, 0) here
+                else:
+                    base_image.paste(layer['image'], (layer['adaptive_position'][1], layer['adaptive_position'][0]), layer['image'])
+
+        result_image = base_image
+
+        image_data = io.BytesIO()
+        result_image.save(image_data, format='PNG')
+        image_data.seek(0)
+
+        # oss upload
+        image_bytes = image_data.read()
+        bucket_name = "aida-results"
+        object_name = f'result_{generate_uuid()}.png'
+        req = oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes)
+        return f"{bucket_name}/{object_name}"
+        # return f"aida-results/{minio_client.put_object('aida-results', f'result_{generate_uuid()}.png', io.BytesIO(image_bytes), len(image_bytes), content_type='image/png').object_name}"
+
+        # object_name = f'result_{generate_uuid()}.png'
+        # response = s3.put_object(Bucket="aida-results", Key=object_name, Body=data, ContentType='image/png')
+        # object_url = f"aida-results/{object_name}"
+        # if response['ResponseMetadata']['HTTPStatusCode'] == 200:
+        #     return object_url
+        # else:
+        #     return ""
+
+    except Exception as e:
+        logging.warning(f"synthesis runtime exception : {e}")
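The clothing branch above clips every garment layer to the union mask via Image.composite followed by paste. A toy repro of that clipping idiom in isolation (synthetic sizes and colours, not the production path):

    import numpy as np
    from PIL import Image

    size = (64, 64)
    base = Image.new('RGBA', size, (0, 0, 0, 0))
    garment = Image.new('RGBA', size, (200, 30, 30, 255))      # stands in for a pasted clothing layer
    union_mask = np.zeros((64, 64), np.uint8)
    union_mask[16:48, 16:48] = 255                             # stands in for all_mask
    mask_alpha = Image.fromarray(np.where(union_mask > 0, 255, 0).astype(np.uint8))
    clipped = Image.composite(garment, Image.new('RGBA', size, (255, 255, 255, 0)), mask_alpha)
    base.paste(garment, (0, 0), clipped)                       # only the masked region reaches the canvas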
+
+
+def synthesis_single(front_image, back_image):
+    result_image = None
+    if front_image:
+        result_image = front_image
+    if back_image:
+        result_image.paste(back_image, (0, 0), back_image)
+
+    # with io.BytesIO() as output:
+    #     result_image.save(output, format='PNG')
+    #     data = output.getvalue()
+    #     object_name = f'result_{generate_uuid()}.png'
+    #     response = s3.put_object(Bucket="aida-results", Key=object_name, Body=data, ContentType='image/png')
+    #     object_url = f"aida-results/{object_name}"
+    #     if response['ResponseMetadata']['HTTPStatusCode'] == 200:
+    #         return object_url
+    #     else:
+    #         return ""
+    image_data = io.BytesIO()
+    result_image.save(image_data, format='PNG')
+    image_data.seek(0)
+    image_bytes = image_data.read()
+    # return f"aida-results/{minio_client.put_object('aida-results', f'result_{generate_uuid()}.png', io.BytesIO(image_bytes), len(image_bytes), content_type='image/png').object_name}"
+    # oss upload
+    bucket_name = 'aida-results'
+    object_name = f'result_{generate_uuid()}.png'
+    req = oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes)
+    return f"{bucket_name}/{object_name}"
+
+
+def update_base_size_priority(layers, size):
+    # compute the width of the transparent background canvas
+    min_x = min(info['position'][1] for info in layers)
+    x_list = []
+    for info in layers:
+        if info['image'] is not None:
+            x_list.append(info['position'][1] + info['image'].width)
+    max_x = max(x_list)
+    new_width = max_x - min_x
+    new_height = 700
+    # shift every layer's coordinates relative to the new left edge
+    for info in layers:
+        info['adaptive_position'] = (info['position'][0], info['position'][1] - min_x)
+    return layers, (new_width, new_height)
diff --git a/app/service/design/utils/upload_image.py b/app/service/design_batch/utils/upload_image.py
similarity index 51%
rename from app/service/design/utils/upload_image.py
rename to app/service/design_batch/utils/upload_image.py
index 3571816..2c79f9f 100644
--- a/app/service/design/utils/upload_image.py
+++ b/app/service/design_batch/utils/upload_image.py
@@ -13,11 +13,11 @@
 import logging
 
 import cv2
 
 from app.core.config import *
-from app.service.utils.oss_client import oss_upload_image
+from app.service.utils.new_oss_client import oss_upload_image
 
 
 # @RunTime
-def upload_png_mask(front_image, object_name, mask=None):
+def upload_png_mask(minio_client, front_image, object_name, mask=None):
     try:
         mask_url = None
         if mask is not None:
@@ -25,20 +25,14 @@
             # convert the 3-channel mask to 4 channels: white stays opaque, black becomes transparent
             rgba_image = cv2.cvtColor(mask_inverted, cv2.COLOR_BGR2BGRA)
             rgba_image[rgba_image[:, :, 0] == 0] = [0, 0, 0, 0]
-            # image_bytes = io.BytesIO()
-            # image_bytes.write(cv2.imencode('.png', rgba_image)[1].tobytes())
-            # image_bytes.seek(0)
-            # mask_url = f"{AIDA_CLOTHING}/{minio_client.put_object('aida-clothing', f'mask/mask_{object_name}.png', image_bytes, len(image_bytes.getvalue()), content_type='image/png').object_name}"
-            # oss upload ####################
-            req = oss_upload_image(bucket=AIDA_CLOTHING, object_name=f"mask/mask_{object_name}.png", image_bytes=cv2.imencode('.png', rgba_image)[1])
+            req = oss_upload_image(oss_client=minio_client, bucket=AIDA_CLOTHING, object_name=f"mask/mask_{object_name}.png", image_bytes=cv2.imencode('.png', rgba_image)[1])
             mask_url = f"{AIDA_CLOTHING}/mask/mask_{object_name}.png"
         image_data = io.BytesIO()
         front_image.save(image_data, format='PNG')
         image_data.seek(0)
         image_bytes = image_data.read()
-        # image_url = f"{AIDA_CLOTHING}/{minio_client.put_object('aida-clothing', f'image/image_{object_name}.png', io.BytesIO(image_bytes), len(image_bytes), content_type='image/png').object_name}"
-        req = oss_upload_image(bucket=AIDA_CLOTHING, object_name=f"image/image_{object_name}.png", image_bytes=image_bytes)
+        req = oss_upload_image(oss_client=minio_client, bucket=AIDA_CLOTHING, object_name=f"image/image_{object_name}.png", image_bytes=image_bytes)
         image_url = f"{AIDA_CLOTHING}/image/image_{object_name}.png"
         return front_image, image_url, mask_url
     except Exception as e:
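upload_png_mask keys transparency off the blue channel of the inverted mask when building the RGBA PNG. The same conversion on a synthetic 2×2 mask (the extra GRAY2BGR step stands in for the 3-channel mask the function actually receives):

    import cv2
    import numpy as np

    mask_inverted = np.array([[0, 255], [255, 0]], np.uint8)       # 0 = transparent, 255 = opaque
    mask_bgr = cv2.cvtColor(mask_inverted, cv2.COLOR_GRAY2BGR)
    rgba = cv2.cvtColor(mask_bgr, cv2.COLOR_BGR2BGRA)              # alpha starts fully opaque
    rgba[rgba[:, :, 0] == 0] = [0, 0, 0, 0]                        # zero-blue pixels become fully transparent
    print(rgba[:, :, 3])                                           # [[  0 255] [255   0]]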
diff --git a/app/service/design_fast/design_generate.py b/app/service/design_fast/design_generate.py
new file mode 100644
index 0000000..ac1f79c
--- /dev/null
+++ b/app/service/design_fast/design_generate.py
@@ -0,0 +1,1465 @@
+import logging
+import threading
+import time
+
+from minio import Minio
+
+from app.core.config import *
+from app.service.design_fast.item import BodyItem, TopItem, BottomItem
+from app.service.design_fast.utils.organize import organize_body, organize_clothing
+from app.service.design_fast.utils.progress import final_progress, update_progress
+from app.service.design_fast.utils.synthesis_item import synthesis, synthesis_single, update_base_size_priority
+from app.service.utils.decorator import RunTime
+
+id_lock = threading.Lock()
+
+logger = logging.getLogger()
+
+minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
+
+
+def process_item(item, basic):
+    # process a single item within a project
+    if item['type'] == "Body":
+        body_server = BodyItem(data=item, basic=basic, minio_client=minio_client)
+        item_data = body_server.process()
+    elif item['type'].lower() in ['blouse', 'outwear', 'dress', 'tops']:
+        top_server = TopItem(data=item, basic=basic, minio_client=minio_client)
+        item_data = top_server.process()
+    else:
+        bottom_server = BottomItem(data=item, basic=basic, minio_client=minio_client)
+        item_data = bottom_server.process()
+    return item_data
+
+
+def process_layer(item, layers):
+    # once an item has been processed, assemble its layer data
+    if item['name'] == "mannequin":
+        body_layer = organize_body(item)
+        layers.append(body_layer)
+        return item['body_image'].size
+    else:
+        front_layer, back_layer = organize_clothing(item)
+        layers.append(front_layer)
+        layers.append(back_layer)
+
+
+@RunTime
+def design_generate(request_data):
+    objects_data = request_data.dict()['objects']
+    process_id = request_data.dict()['process_id']
+    object_response = {}
+    threads = []
+    active_threads = 0
+    lock = threading.Lock()
+    total = len(objects_data)
+
+    def process_object(step, object):
+        nonlocal active_threads
+        basic = object['basic']
+        items_response = {'layers': []}
+        if basic['single_overall'] == "overall":
+            item_results = []
+            for item in object['items']:
+                item_results.append(process_item(item, basic))
+            layers = []
+            body_size = None
+            for item in item_results:
+                body_size = 
process_layer(item, layers) + layers = sorted(layers, key=lambda s: s.get("priority", float('inf'))) + + layers, new_size = update_base_size_priority(layers, body_size) + + for lay in layers: + items_response['layers'].append({ + 'image_category': "body" if lay['name'] == 'mannequin' else lay['name'], + 'position': lay['position'], + 'priority': lay.get("priority", None), + 'resize_scale': lay['resize_scale'] if "resize_scale" in lay.keys() else None, + 'image_size': lay['image'] if lay['image'] is None else lay['image'].size, + 'gradient_string': lay['gradient_string'] if 'gradient_string' in lay.keys() else "", + 'mask_url': lay['mask_url'], + 'image_url': lay['image_url'] if 'image_url' in lay.keys() else None, + 'pattern_image_url': lay['pattern_image_url'] if 'pattern_image_url' in lay.keys() else None, + }) + items_response['synthesis_url'] = synthesis(layers, new_size, basic) + else: + item_result = process_item(object['items'][0], basic) + items_response['layers'].append({ + 'image_category': f"{item_result['name']}_front", + 'image_size': item_result['back_image'].size if item_result['back_image'] else None, + 'position': None, + 'priority': 0, + 'image_url': item_result['front_image_url'], + 'mask_url': item_result['mask_url'], + "gradient_string": item_result['gradient_string'] if 'gradient_string' in item_result.keys() else "", + 'pattern_image_url': item_result['pattern_image_url'] if 'pattern_image_url' in item_result.keys() else None, + }) + items_response['layers'].append({ + 'image_category': f"{item_result['name']}_back", + 'image_size': item_result['front_image'].size if item_result['front_image'] else None, + 'position': None, + 'priority': 0, + 'image_url': item_result['back_image_url'], + 'mask_url': item_result['mask_url'], + "gradient_string": item_result['gradient_string'] if 'gradient_string' in item_result.keys() else "", + 'pattern_image_url': item_result['pattern_image_url'] if 'pattern_image_url' in item_result.keys() else None, + }) + items_response['synthesis_url'] = synthesis_single(item_result['front_image'], item_result['back_image']) + update_progress(process_id, total) + + with lock: + object_response[step] = items_response + active_threads -= 1 + + for step, object in enumerate(objects_data): + t = threading.Thread(target=process_object, args=(step, object)) + threads.append(t) + t.start() + with lock: + active_threads += 1 + + for t in threads: + t.join() + final_progress(process_id) + return object_response + + +if __name__ == '__main__': + object_data = { + "objects": [ + { + "basic": { + "body_point_test": { + "waistband_right": [ + 203, + 249 + ], + "hand_point_right": [ + 229, + 343 + ], + "waistband_left": [ + 119, + 248 + ], + "hand_point_left": [ + 97, + 343 + ], + "shoulder_left": [ + 108, + 107 + ], + "shoulder_right": [ + 212, + 107 + ] + }, + "layer_order": False, + "scale_bag": 0.7, + "scale_earrings": 0.16, + "self_template": True, + "single_overall": "overall", + "switch_category": "" + }, + "items": [ + { + "color": "28 26 26", + "icon": "none", + "image_id": 98419, + "offset": [ + 1, + 1 + ], + "path": "aida-sys-image/images/female/dress/0825000526.jpg", + "print": { + "element": { + "element_angle_list": [], + "element_path_list": [], + "element_scale_list": [], + "location": [] + }, + "overall": { + "location": [ + [ + 0.0, + 0.0 + ] + ], + "print_angle_list": [ + 0.0, + 0.0 + ], + "print_path_list": [], + "print_scale_list": [ + 0.0, + 0.0 + ] + }, + "single": { + "location": [], + "print_angle_list": [], + "print_path_list": [], + 
"print_scale_list": [] + } + }, + "resize_scale": [ + 1.0, + 1.0 + ], + "type": "Dress" + }, + { + "body_path": "aida-sys-image/models/female/2e4815b9-1191-419d-94ed-5771239ca4a5.png", + "image_id": 67277, + "offset": [ + 1, + 1 + ], + "resize_scale": [ + 1.0, + 1.0 + ], + "type": "Body" + } + ] + }, + { + "basic": { + "body_point_test": { + "waistband_right": [ + 203, + 249 + ], + "hand_point_right": [ + 229, + 343 + ], + "waistband_left": [ + 119, + 248 + ], + "hand_point_left": [ + 97, + 343 + ], + "shoulder_left": [ + 108, + 107 + ], + "shoulder_right": [ + 212, + 107 + ] + }, + "layer_order": False, + "scale_bag": 0.7, + "scale_earrings": 0.16, + "self_template": True, + "single_overall": "overall", + "switch_category": "" + }, + "items": [ + { + "color": "28 26 26", + "icon": "none", + "image_id": 98420, + "offset": [ + 1, + 1 + ], + "path": "aida-sys-image/images/female/skirt/903000127.jpg", + "print": { + "element": { + "element_angle_list": [], + "element_path_list": [], + "element_scale_list": [], + "location": [] + }, + "overall": { + "location": [ + [ + 0.0, + 0.0 + ] + ], + "print_angle_list": [ + 0.0, + 0.0 + ], + "print_path_list": [], + "print_scale_list": [ + 0.0, + 0.0 + ] + }, + "single": { + "location": [], + "print_angle_list": [], + "print_path_list": [], + "print_scale_list": [] + } + }, + "resize_scale": [ + 1.0, + 1.0 + ], + "type": "Skirt" + }, + { + "color": "28 26 26", + "icon": "none", + "image_id": 69140, + "offset": [ + 1, + 1 + ], + "path": "aida-sys-image/images/female/blouse/0902001100.jpg", + "print": { + "element": { + "element_angle_list": [], + "element_path_list": [], + "element_scale_list": [], + "location": [] + }, + "overall": { + "location": [ + [ + 0.0, + 0.0 + ] + ], + "print_angle_list": [ + 0.0, + 0.0 + ], + "print_path_list": [], + "print_scale_list": [ + 0.0, + 0.0 + ] + }, + "single": { + "location": [], + "print_angle_list": [], + "print_path_list": [], + "print_scale_list": [] + } + }, + "resize_scale": [ + 1.0, + 1.0 + ], + "type": "Blouse" + }, + { + "color": "28 26 26", + "icon": "none", + "image_id": 81604, + "offset": [ + 1, + 1 + ], + "path": "aida-sys-image/images/female/outwear/outwear_p5_729.jpg", + "print": { + "element": { + "element_angle_list": [], + "element_path_list": [], + "element_scale_list": [], + "location": [] + }, + "overall": { + "location": [ + [ + 0.0, + 0.0 + ] + ], + "print_angle_list": [ + 0.0, + 0.0 + ], + "print_path_list": [], + "print_scale_list": [ + 0.0, + 0.0 + ] + }, + "single": { + "location": [], + "print_angle_list": [], + "print_path_list": [], + "print_scale_list": [] + } + }, + "resize_scale": [ + 1.0, + 1.0 + ], + "type": "Outwear" + }, + { + "body_path": "aida-sys-image/models/female/2e4815b9-1191-419d-94ed-5771239ca4a5.png", + "image_id": 67277, + "offset": [ + 1, + 1 + ], + "resize_scale": [ + 1.0, + 1.0 + ], + "type": "Body" + } + ] + }, + { + "basic": { + "body_point_test": { + "waistband_right": [ + 203, + 249 + ], + "hand_point_right": [ + 229, + 343 + ], + "waistband_left": [ + 119, + 248 + ], + "hand_point_left": [ + 97, + 343 + ], + "shoulder_left": [ + 108, + 107 + ], + "shoulder_right": [ + 212, + 107 + ] + }, + "layer_order": False, + "scale_bag": 0.7, + "scale_earrings": 0.16, + "self_template": True, + "single_overall": "overall", + "switch_category": "" + }, + "items": [ + { + "color": "28 26 26", + "icon": "none", + "image_id": 63964, + "offset": [ + 1, + 1 + ], + "path": "aida-sys-image/images/female/outwear/0825001572.jpg", + "print": { + "element": { + "element_angle_list": 
[], + "element_path_list": [], + "element_scale_list": [], + "location": [] + }, + "overall": { + "location": [ + [ + 0.0, + 0.0 + ] + ], + "print_angle_list": [ + 0.0, + 0.0 + ], + "print_path_list": [], + "print_scale_list": [ + 0.0, + 0.0 + ] + }, + "single": { + "location": [], + "print_angle_list": [], + "print_path_list": [], + "print_scale_list": [] + } + }, + "resize_scale": [ + 1.0, + 1.0 + ], + "type": "Outwear" + }, + { + "color": "28 26 26", + "icon": "none", + "image_id": 98421, + "offset": [ + 1, + 1 + ], + "path": "aida-sys-image/images/female/blouse/blouse_506.jpg", + "print": { + "element": { + "element_angle_list": [], + "element_path_list": [], + "element_scale_list": [], + "location": [] + }, + "overall": { + "location": [ + [ + 0.0, + 0.0 + ] + ], + "print_angle_list": [ + 0.0, + 0.0 + ], + "print_path_list": [], + "print_scale_list": [ + 0.0, + 0.0 + ] + }, + "single": { + "location": [], + "print_angle_list": [], + "print_path_list": [], + "print_scale_list": [] + } + }, + "resize_scale": [ + 1.0, + 1.0 + ], + "type": "Blouse" + }, + { + "color": "28 26 26", + "icon": "none", + "image_id": 98422, + "offset": [ + 1, + 1 + ], + "path": "aida-sys-image/images/female/trousers/0628001244.jpg", + "print": { + "element": { + "element_angle_list": [], + "element_path_list": [], + "element_scale_list": [], + "location": [] + }, + "overall": { + "location": [ + [ + 0.0, + 0.0 + ] + ], + "print_angle_list": [ + 0.0, + 0.0 + ], + "print_path_list": [], + "print_scale_list": [ + 0.0, + 0.0 + ] + }, + "single": { + "location": [], + "print_angle_list": [], + "print_path_list": [], + "print_scale_list": [] + } + }, + "resize_scale": [ + 1.0, + 1.0 + ], + "type": "Trousers" + }, + { + "body_path": "aida-sys-image/models/female/2e4815b9-1191-419d-94ed-5771239ca4a5.png", + "image_id": 67277, + "offset": [ + 1, + 1 + ], + "resize_scale": [ + 1.0, + 1.0 + ], + "type": "Body" + } + ] + }, + { + "basic": { + "body_point_test": { + "waistband_right": [ + 203, + 249 + ], + "hand_point_right": [ + 229, + 343 + ], + "waistband_left": [ + 119, + 248 + ], + "hand_point_left": [ + 97, + 343 + ], + "shoulder_left": [ + 108, + 107 + ], + "shoulder_right": [ + 212, + 107 + ] + }, + "layer_order": False, + "scale_bag": 0.7, + "scale_earrings": 0.16, + "self_template": True, + "single_overall": "overall", + "switch_category": "" + }, + "items": [ + { + "color": "28 26 26", + "icon": "none", + "image_id": 79927, + "offset": [ + 1, + 1 + ], + "path": "aida-sys-image/images/female/outwear/0825000378.jpg", + "print": { + "element": { + "element_angle_list": [], + "element_path_list": [], + "element_scale_list": [], + "location": [] + }, + "overall": { + "location": [ + [ + 0.0, + 0.0 + ] + ], + "print_angle_list": [ + 0.0, + 0.0 + ], + "print_path_list": [], + "print_scale_list": [ + 0.0, + 0.0 + ] + }, + "single": { + "location": [], + "print_angle_list": [], + "print_path_list": [], + "print_scale_list": [] + } + }, + "resize_scale": [ + 1.0, + 1.0 + ], + "type": "Outwear" + }, + { + "color": "28 26 26", + "icon": "none", + "image_id": 67473, + "offset": [ + 1, + 1 + ], + "path": "aida-sys-image/images/female/blouse/0825001350.jpg", + "print": { + "element": { + "element_angle_list": [], + "element_path_list": [], + "element_scale_list": [], + "location": [] + }, + "overall": { + "location": [ + [ + 0.0, + 0.0 + ] + ], + "print_angle_list": [ + 0.0, + 0.0 + ], + "print_path_list": [], + "print_scale_list": [ + 0.0, + 0.0 + ] + }, + "single": { + "location": [], + "print_angle_list": [], + 
"print_path_list": [], + "print_scale_list": [] + } + }, + "resize_scale": [ + 1.0, + 1.0 + ], + "type": "Blouse" + }, + { + "color": "28 26 26", + "icon": "none", + "image_id": 80046, + "offset": [ + 1, + 1 + ], + "path": "aida-sys-image/images/female/skirt/0628001443.jpg", + "print": { + "element": { + "element_angle_list": [], + "element_path_list": [], + "element_scale_list": [], + "location": [] + }, + "overall": { + "location": [ + [ + 0.0, + 0.0 + ] + ], + "print_angle_list": [ + 0.0, + 0.0 + ], + "print_path_list": [], + "print_scale_list": [ + 0.0, + 0.0 + ] + }, + "single": { + "location": [], + "print_angle_list": [], + "print_path_list": [], + "print_scale_list": [] + } + }, + "resize_scale": [ + 1.0, + 1.0 + ], + "type": "Skirt" + }, + { + "body_path": "aida-sys-image/models/female/2e4815b9-1191-419d-94ed-5771239ca4a5.png", + "image_id": 67277, + "offset": [ + 1, + 1 + ], + "resize_scale": [ + 1.0, + 1.0 + ], + "type": "Body" + } + ] + }, + { + "basic": { + "body_point_test": { + "waistband_right": [ + 203, + 249 + ], + "hand_point_right": [ + 229, + 343 + ], + "waistband_left": [ + 119, + 248 + ], + "hand_point_left": [ + 97, + 343 + ], + "shoulder_left": [ + 108, + 107 + ], + "shoulder_right": [ + 212, + 107 + ] + }, + "layer_order": False, + "scale_bag": 0.7, + "scale_earrings": 0.16, + "self_template": True, + "single_overall": "overall", + "switch_category": "" + }, + "items": [ + { + "color": "28 26 26", + "icon": "none", + "image_id": 84148, + "offset": [ + 1, + 1 + ], + "path": "aida-sys-image/images/female/trousers/0628000751.jpeg", + "print": { + "element": { + "element_angle_list": [], + "element_path_list": [], + "element_scale_list": [], + "location": [] + }, + "overall": { + "location": [ + [ + 0.0, + 0.0 + ] + ], + "print_angle_list": [ + 0.0, + 0.0 + ], + "print_path_list": [], + "print_scale_list": [ + 0.0, + 0.0 + ] + }, + "single": { + "location": [], + "print_angle_list": [], + "print_path_list": [], + "print_scale_list": [] + } + }, + "resize_scale": [ + 1.0, + 1.0 + ], + "type": "Trousers" + }, + { + "color": "28 26 26", + "icon": "none", + "image_id": 97321, + "offset": [ + 1, + 1 + ], + "path": "aida-sys-image/images/female/blouse/0902000222.jpg", + "print": { + "element": { + "element_angle_list": [], + "element_path_list": [], + "element_scale_list": [], + "location": [] + }, + "overall": { + "location": [ + [ + 0.0, + 0.0 + ] + ], + "print_angle_list": [ + 0.0, + 0.0 + ], + "print_path_list": [], + "print_scale_list": [ + 0.0, + 0.0 + ] + }, + "single": { + "location": [], + "print_angle_list": [], + "print_path_list": [], + "print_scale_list": [] + } + }, + "resize_scale": [ + 1.0, + 1.0 + ], + "type": "Blouse" + }, + { + "color": "28 26 26", + "icon": "none", + "image_id": 90718, + "offset": [ + 1, + 1 + ], + "path": "aida-sys-image/images/female/outwear/0825000314.jpg", + "print": { + "element": { + "element_angle_list": [], + "element_path_list": [], + "element_scale_list": [], + "location": [] + }, + "overall": { + "location": [ + [ + 0.0, + 0.0 + ] + ], + "print_angle_list": [ + 0.0, + 0.0 + ], + "print_path_list": [], + "print_scale_list": [ + 0.0, + 0.0 + ] + }, + "single": { + "location": [], + "print_angle_list": [], + "print_path_list": [], + "print_scale_list": [] + } + }, + "resize_scale": [ + 1.0, + 1.0 + ], + "type": "Outwear" + }, + { + "body_path": "aida-sys-image/models/female/2e4815b9-1191-419d-94ed-5771239ca4a5.png", + "image_id": 67277, + "offset": [ + 1, + 1 + ], + "resize_scale": [ + 1.0, + 1.0 + ], + "type": "Body" + } + ] + 
}, + { + "basic": { + "body_point_test": { + "waistband_right": [ + 203, + 249 + ], + "hand_point_right": [ + 229, + 343 + ], + "waistband_left": [ + 119, + 248 + ], + "hand_point_left": [ + 97, + 343 + ], + "shoulder_left": [ + 108, + 107 + ], + "shoulder_right": [ + 212, + 107 + ] + }, + "layer_order": False, + "scale_bag": 0.7, + "scale_earrings": 0.16, + "self_template": True, + "single_overall": "overall", + "switch_category": "" + }, + "items": [ + { + "color": "28 26 26", + "icon": "none", + "image_id": 86403, + "offset": [ + 1, + 1 + ], + "path": "aida-sys-image/images/female/skirt/0902000231.jpg", + "print": { + "element": { + "element_angle_list": [], + "element_path_list": [], + "element_scale_list": [], + "location": [] + }, + "overall": { + "location": [ + [ + 0.0, + 0.0 + ] + ], + "print_angle_list": [ + 0.0, + 0.0 + ], + "print_path_list": [], + "print_scale_list": [ + 0.0, + 0.0 + ] + }, + "single": { + "location": [], + "print_angle_list": [], + "print_path_list": [], + "print_scale_list": [] + } + }, + "resize_scale": [ + 1.0, + 1.0 + ], + "type": "Skirt" + }, + { + "color": "28 26 26", + "icon": "none", + "image_id": 87135, + "offset": [ + 1, + 1 + ], + "path": "aida-sys-image/images/female/blouse/0902001315.jpg", + "print": { + "element": { + "element_angle_list": [], + "element_path_list": [], + "element_scale_list": [], + "location": [] + }, + "overall": { + "location": [ + [ + 0.0, + 0.0 + ] + ], + "print_angle_list": [ + 0.0, + 0.0 + ], + "print_path_list": [], + "print_scale_list": [ + 0.0, + 0.0 + ] + }, + "single": { + "location": [], + "print_angle_list": [], + "print_path_list": [], + "print_scale_list": [] + } + }, + "resize_scale": [ + 1.0, + 1.0 + ], + "type": "Blouse" + }, + { + "color": "28 26 26", + "icon": "none", + "image_id": 87428, + "offset": [ + 1, + 1 + ], + "path": "aida-sys-image/images/female/outwear/0902000566.jpg", + "print": { + "element": { + "element_angle_list": [], + "element_path_list": [], + "element_scale_list": [], + "location": [] + }, + "overall": { + "location": [ + [ + 0.0, + 0.0 + ] + ], + "print_angle_list": [ + 0.0, + 0.0 + ], + "print_path_list": [], + "print_scale_list": [ + 0.0, + 0.0 + ] + }, + "single": { + "location": [], + "print_angle_list": [], + "print_path_list": [], + "print_scale_list": [] + } + }, + "resize_scale": [ + 1.0, + 1.0 + ], + "type": "Outwear" + }, + { + "body_path": "aida-sys-image/models/female/2e4815b9-1191-419d-94ed-5771239ca4a5.png", + "image_id": 67277, + "offset": [ + 1, + 1 + ], + "resize_scale": [ + 1.0, + 1.0 + ], + "type": "Body" + } + ] + }, + { + "basic": { + "body_point_test": { + "waistband_right": [ + 203, + 249 + ], + "hand_point_right": [ + 229, + 343 + ], + "waistband_left": [ + 119, + 248 + ], + "hand_point_left": [ + 97, + 343 + ], + "shoulder_left": [ + 108, + 107 + ], + "shoulder_right": [ + 212, + 107 + ] + }, + "layer_order": False, + "scale_bag": 0.7, + "scale_earrings": 0.16, + "self_template": True, + "single_overall": "overall", + "switch_category": "" + }, + "items": [ + { + "color": "28 26 26", + "icon": "none", + "image_id": 98423, + "offset": [ + 1, + 1 + ], + "path": "aida-sys-image/images/female/dress/0916001596.jpg", + "print": { + "element": { + "element_angle_list": [], + "element_path_list": [], + "element_scale_list": [], + "location": [] + }, + "overall": { + "location": [ + [ + 0.0, + 0.0 + ] + ], + "print_angle_list": [ + 0.0, + 0.0 + ], + "print_path_list": [], + "print_scale_list": [ + 0.0, + 0.0 + ] + }, + "single": { + "location": [], + "print_angle_list": 
[], + "print_path_list": [], + "print_scale_list": [] + } + }, + "resize_scale": [ + 1.0, + 1.0 + ], + "type": "Dress" + }, + { + "body_path": "aida-sys-image/models/female/2e4815b9-1191-419d-94ed-5771239ca4a5.png", + "image_id": 67277, + "offset": [ + 1, + 1 + ], + "resize_scale": [ + 1.0, + 1.0 + ], + "type": "Body" + } + ] + }, + { + "basic": { + "body_point_test": { + "waistband_right": [ + 203, + 249 + ], + "hand_point_right": [ + 229, + 343 + ], + "waistband_left": [ + 119, + 248 + ], + "hand_point_left": [ + 97, + 343 + ], + "shoulder_left": [ + 108, + 107 + ], + "shoulder_right": [ + 212, + 107 + ] + }, + "layer_order": False, + "scale_bag": 0.7, + "scale_earrings": 0.16, + "self_template": True, + "single_overall": "overall", + "switch_category": "" + }, + "items": [ + { + "color": "28 26 26", + "icon": "none", + "image_id": 86345, + "offset": [ + 1, + 1 + ], + "path": "aida-sys-image/images/female/outwear/0825000695.jpg", + "print": { + "element": { + "element_angle_list": [], + "element_path_list": [], + "element_scale_list": [], + "location": [] + }, + "overall": { + "location": [ + [ + 0.0, + 0.0 + ] + ], + "print_angle_list": [ + 0.0, + 0.0 + ], + "print_path_list": [], + "print_scale_list": [ + 0.0, + 0.0 + ] + }, + "single": { + "location": [], + "print_angle_list": [], + "print_path_list": [], + "print_scale_list": [] + } + }, + "resize_scale": [ + 1.0, + 1.0 + ], + "type": "Outwear" + }, + { + "color": "28 26 26", + "icon": "none", + "image_id": 78743, + "offset": [ + 1, + 1 + ], + "path": "aida-sys-image/images/female/blouse/0902001412.jpg", + "print": { + "element": { + "element_angle_list": [], + "element_path_list": [], + "element_scale_list": [], + "location": [] + }, + "overall": { + "location": [ + [ + 0.0, + 0.0 + ] + ], + "print_angle_list": [ + 0.0, + 0.0 + ], + "print_path_list": [], + "print_scale_list": [ + 0.0, + 0.0 + ] + }, + "single": { + "location": [], + "print_angle_list": [], + "print_path_list": [], + "print_scale_list": [] + } + }, + "resize_scale": [ + 1.0, + 1.0 + ], + "type": "Blouse" + }, + { + "color": "28 26 26", + "icon": "none", + "image_id": 68988, + "offset": [ + 1, + 1 + ], + "path": "aida-sys-image/images/female/trousers/0825000403.jpg", + "print": { + "element": { + "element_angle_list": [], + "element_path_list": [], + "element_scale_list": [], + "location": [] + }, + "overall": { + "location": [ + [ + 0.0, + 0.0 + ] + ], + "print_angle_list": [ + 0.0, + 0.0 + ], + "print_path_list": [], + "print_scale_list": [ + 0.0, + 0.0 + ] + }, + "single": { + "location": [], + "print_angle_list": [], + "print_path_list": [], + "print_scale_list": [] + } + }, + "resize_scale": [ + 1.0, + 1.0 + ], + "type": "Trousers" + }, + { + "body_path": "aida-sys-image/models/female/2e4815b9-1191-419d-94ed-5771239ca4a5.png", + "image_id": 67277, + "offset": [ + 1, + 1 + ], + "resize_scale": [ + 1.0, + 1.0 + ], + "type": "Body" + } + ] + } + ], + "process_id": "123" + } + start_time = time.time() + X = design_generate(object_data) + print(time.time() - start_time) + print(X) diff --git a/app/service/design_fast/item.py b/app/service/design_fast/item.py new file mode 100644 index 0000000..e10320d --- /dev/null +++ b/app/service/design_fast/item.py @@ -0,0 +1,61 @@ +from app.service.design_fast.pipeline import LoadImage, KeyPoint, Segmentation, Color, PrintPainting, Scaling, Split, LoadBodyImage, ContourDetection + + +class BaseItem: + def __init__(self, data, basic): + self.result = data.copy() + self.result['name'] = data['type'].lower() + 
self.result.pop("type")
+        self.result.update(basic)
+
+
+class TopItem(BaseItem):
+    def __init__(self, data, basic, minio_client):
+        super().__init__(data, basic)
+        self.top_pipeline = [
+            LoadImage(minio_client),
+            KeyPoint(),
+            Segmentation(minio_client),
+            Color(minio_client),
+            PrintPainting(minio_client),
+            Scaling(),
+            Split(minio_client)
+        ]
+
+    def process(self):
+        for item in self.top_pipeline:
+            self.result = item(self.result)
+        return self.result
+
+
+class BottomItem(BaseItem):
+    def __init__(self, data, basic, minio_client):
+        super().__init__(data, basic)
+        self.bottom_pipeline = [
+            LoadImage(minio_client),
+            KeyPoint(),
+            ContourDetection(),
+            # Segmentation(),
+            Color(minio_client),
+            PrintPainting(minio_client),
+            Scaling(),
+            Split(minio_client)
+        ]
+
+    def process(self):
+        for item in self.bottom_pipeline:
+            self.result = item(self.result)
+        return self.result
+
+
+class BodyItem(BaseItem):
+    def __init__(self, data, basic, minio_client):
+        super().__init__(data, basic)
+        self.top_pipeline = [
+            LoadBodyImage(minio_client),
+        ]
+
+    def process(self):
+        for item in self.top_pipeline:
+            self.result = item(self.result)
+        return self.result
diff --git a/app/service/design_fast/pipeline/__init__.py b/app/service/design_fast/pipeline/__init__.py
new file mode 100644
index 0000000..ec55933
--- /dev/null
+++ b/app/service/design_fast/pipeline/__init__.py
@@ -0,0 +1,19 @@
+from .color import Color
+from .contour_detection import ContourDetection
+from .keypoint import KeyPoint
+from .loading import LoadImage, LoadBodyImage
+from .print_painting import PrintPainting
+from .scale import Scaling
+from .segmentation import Segmentation
+from .split import Split
+
+__all__ = [
+    'LoadBodyImage', 'LoadImage',
+    'KeyPoint',
+    'ContourDetection',
+    'Segmentation',
+    'Color',
+    'PrintPainting',
+    'Scaling',
+    'Split'
+]
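Each item class above is just a list of callables folded over a shared result dict, which is what makes stages like Segmentation easy to comment in and out. The same fold with stub stages:

    class AddGray:
        def __call__(self, result):
            result['gray'] = 'stub-gray'
            return result

    class AddMask:
        def __call__(self, result):
            result['mask'] = f"mask-of-{result['gray']}"
            return result

    result = {'name': 'blouse'}
    for stage in [AddGray(), AddMask()]:  # the same fold TopItem.process performs
        result = stage(result)
    print(result)  # {'name': 'blouse', 'gray': 'stub-gray', 'mask': 'mask-of-stub-gray'}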
diff --git a/app/service/design_fast/pipeline/color.py b/app/service/design_fast/pipeline/color.py
new file mode 100644
index 0000000..546c671
--- /dev/null
+++ b/app/service/design_fast/pipeline/color.py
@@ -0,0 +1,62 @@
+import logging
+
+import cv2
+import numpy as np
+
+from app.service.utils.new_oss_client import oss_get_image
+
+logger = logging.getLogger()
+
+
+class Color:
+    def __init__(self, minio_client):
+        self.minio_client = minio_client
+
+    def __call__(self, result):
+        dim_image_h, dim_image_w = result['image'].shape[0:2]
+        if "gradient" in result.keys() and result['gradient'] != "":
+            bucket_name = result['gradient'].split('/')[0]
+            object_name = result['gradient'][result['gradient'].find('/') + 1:]
+            pattern = self.get_gradient(bucket_name=bucket_name, object_name=object_name)
+            resize_pattern = cv2.resize(pattern, (dim_image_w, dim_image_h), interpolation=cv2.INTER_AREA)
+        else:
+            pattern = self.get_pattern(result['color'])
+            resize_pattern = cv2.resize(pattern, (dim_image_w, dim_image_h), interpolation=cv2.INTER_AREA)
+        closed_mo = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
+        gray_mo = np.expand_dims(result['gray'], axis=2).repeat(3, axis=2)
+        get_image_fir = resize_pattern * (closed_mo / 255) * (gray_mo / 255)
+        result['pattern_image'] = get_image_fir.astype(np.uint8)
+        result['final_image'] = result['pattern_image']
+        canvas = np.full_like(result['final_image'], 255)
+        temp_bg = np.expand_dims(cv2.bitwise_not(result['mask']), axis=2).repeat(3, axis=2)
+        tmp1 = (canvas * (temp_bg / 255)).astype(np.uint8)
+        temp_fg = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
+        tmp2 = (result['final_image'] * (temp_fg / 255)).astype(np.uint8)
+        result['single_image'] = cv2.add(tmp1, tmp2)
+        result['alpha'] = 100 / 255.0
+        return result
+
+    def get_gradient(self, bucket_name, object_name):
+        # fetch the gradient pattern
+        image = oss_get_image(oss_client=self.minio_client, bucket=bucket_name, object_name=object_name, data_type="cv2")
+        if image.shape[2] == 4:
+            image = cv2.cvtColor(image, cv2.COLOR_BGRA2BGR)
+        return image
+
+    @staticmethod
+    def crop_image(image, image_size_h, image_size_w):
+        x_offset = np.random.randint(low=0, high=int(image_size_h / 5) - 6)
+        y_offset = np.random.randint(low=0, high=int(image_size_w / 5) - 6)
+        image = image[x_offset: x_offset + image_size_h, y_offset: y_offset + image_size_w, :]
+        return image
+
+    @staticmethod
+    def get_pattern(single_color):
+        if single_color is None:
+            raise ValueError("single_color is required when no gradient is given")
+        R, G, B = single_color.split(' ')
+        pattern = np.zeros([1, 1, 3], np.uint8)
+        pattern[0, 0, 0] = int(B)
+        pattern[0, 0, 1] = int(G)
+        pattern[0, 0, 2] = int(R)
+        return pattern
diff --git a/app/service/design_fast/pipeline/contour_detection.py b/app/service/design_fast/pipeline/contour_detection.py
new file mode 100644
index 0000000..2b76c0b
--- /dev/null
+++ b/app/service/design_fast/pipeline/contour_detection.py
@@ -0,0 +1,37 @@
+import cv2
+import numpy as np
+
+
+class ContourDetection:
+    def __call__(self, result):
+        Contour = self.get_contours(result['image'])
+        Mask = np.zeros(result['image'].shape[:2], np.uint8)
+        if len(Contour):
+            Max_contour = Contour[0]
+            Epsilon = 0.001 * cv2.arcLength(Max_contour, True)
+            Approx = cv2.approxPolyDP(Max_contour, Epsilon, True)
+            cv2.drawContours(Mask, [Approx], -1, 255, -1)
+        else:
+            Mask = np.ones(result['image'].shape[:2], np.uint8) * 255
+        # TODO fix images that come out partially transparent; planned for the next release
+        # img2gray = cv2.cvtColor(result['image'], cv2.COLOR_BGR2GRAY)
+        # ret, Mask = cv2.threshold(img2gray, 126, 255, cv2.THRESH_BINARY)
+        # Mask = cv2.bitwise_not(Mask)
+        if result['pre_mask'] is None:
+            result['mask'] = Mask
+        else:
+            result['mask'] = cv2.bitwise_and(Mask, result['pre_mask'])
+        result['front_mask'] = result['mask']
+        result['back_mask'] = result['mask']
+        return result
+
+    @staticmethod
+    def get_contours(image):
+        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
+        Edge = cv2.Canny(gray, 10, 150)
+        kernel = np.ones((5, 5), np.uint8)
+        Edge = cv2.dilate(Edge, kernel=kernel, iterations=1)
+        Edge = cv2.erode(Edge, kernel=kernel, iterations=1)
+        Contour, _ = cv2.findContours(Edge, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+        Contour = sorted(Contour, key=cv2.contourArea, reverse=True)
+        return Contour
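ContourDetection builds the garment mask by filling the largest Canny contour and falls back to an all-white mask when no contour is found. A compact repro of that edge, contour and filled-mask chain on a synthetic swatch (same kernel and thresholds as the class):

    import cv2
    import numpy as np

    img = np.full((100, 100, 3), 255, np.uint8)
    cv2.rectangle(img, (20, 20), (80, 80), (30, 30, 30), -1)   # dark "garment" on a white ground
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    edge = cv2.Canny(gray, 10, 150)
    kernel = np.ones((5, 5), np.uint8)
    edge = cv2.erode(cv2.dilate(edge, kernel), kernel)         # close small gaps in the edges
    contours, _ = cv2.findContours(edge, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    contours = sorted(contours, key=cv2.contourArea, reverse=True)
    mask = np.zeros(img.shape[:2], np.uint8)
    cv2.drawContours(mask, [contours[0]], -1, 255, -1)         # fill the largest contour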
diff --git a/app/service/design_fast/pipeline/keypoint.py b/app/service/design_fast/pipeline/keypoint.py
new file mode 100644
index 0000000..73d7586
--- /dev/null
+++ b/app/service/design_fast/pipeline/keypoint.py
@@ -0,0 +1,116 @@
+import logging
+
+import numpy as np
+from pymilvus import MilvusClient
+
+from app.core.config import *
+from app.service.design_fast.utils.design_ensemble import get_keypoint_result
+from app.service.utils.decorator import ClassCallRunTime, RunTime
+
+logger = logging.getLogger(__name__)
+
+
+class KeyPoint:
+    name = "KeyPoint"
+
+    @classmethod
+    def get_name(cls):
+        return cls.name
+
+    @ClassCallRunTime
+    def __call__(self, result):
+        if result['name'] in ['blouse', 'skirt', 'dress', 'outwear', 'trousers', 'tops', 'bottoms']:  # check the cache: a same-category hit is read back directly, otherwise infer and update
+            # result['clothes_keypoint'] = self.infer_keypoint_result(result)
+            site = 'up' if result['name'] in ['blouse', 'outwear', 'dress', 'tops'] else 'down'
+            # keypoint_cache = search_keypoint_cache(result["image_id"], site)
+            # keypoint_cache = self.keypoint_cache(result, site)
+            keypoint_cache = False
+            # vector lookup disabled; always run model inference
+            if keypoint_cache is False:
+                keypoint_infer_result, site = self.infer_keypoint_result(result)
+                result['clothes_keypoint'] = self.save_keypoint_cache(result["image_id"], keypoint_infer_result, site)
+            else:
+                result['clothes_keypoint'] = keypoint_cache
+        return result
+
+    @staticmethod
+    def infer_keypoint_result(result):
+        site = 'up' if result['name'] in ['blouse', 'outwear', 'dress', 'tops'] else 'down'
+        keypoint_infer_result = get_keypoint_result(result["image"], site)  # inference result
+        return keypoint_infer_result, site
+
+    @staticmethod
+    def save_keypoint_cache(keypoint_id, cache, site):
+        if site == "down":
+            zeros = np.zeros(20, dtype=int)
+            result = np.concatenate([zeros, cache.flatten()])
+        else:
+            zeros = np.zeros(4, dtype=int)
+            result = np.concatenate([cache.flatten(), zeros])
+        # persistence is best-effort; the keypoint dict is returned either way
+        data = [
+            {"keypoint_id": keypoint_id,
+             "keypoint_site": site,
+             "keypoint_vector": result.tolist()
+             }
+        ]
+        try:
+            client = MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS)
+            res = client.upsert(collection_name=MILVUS_TABLE_KEYPOINT, data=data)
+            client.close()
+            return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
+        except Exception as e:
+            logger.warning(f"save keypoint cache milvus error : {e}")
+            return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
+
+    @staticmethod
+    def update_keypoint_cache(keypoint_id, infer_result, search_result, site):
+        if site == "up":
+            # 'up' was inferred, so the cached half being reused is 'down'
+            result = np.concatenate([infer_result.flatten(), search_result[-4:]])
+        else:
+            # 'down' was inferred, so the cached half being reused is 'up'
+            result = np.concatenate([search_result[:20], infer_result.flatten()])
+        data = [
+            {"keypoint_id": keypoint_id,
+             "keypoint_site": "all",
+             "keypoint_vector": result.tolist()
+             }
+        ]
+
+        try:
+            client = MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS)
+            client.upsert(
+                collection_name=MILVUS_TABLE_KEYPOINT,
+                data=data
+            )
+            return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
+        except Exception as e:
+            logger.warning(f"save keypoint cache milvus error : {e}")
+            return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
+
+    @RunTime
+    def keypoint_cache(self, result, site):
+        try:
+            client = MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS)
+            keypoint_id = result['image_id']
+            res = client.query(
+                collection_name=MILVUS_TABLE_KEYPOINT,
+                # ids=[keypoint_id],
+                filter=f"keypoint_id == {keypoint_id}",
+                output_fields=['keypoint_vector', 'keypoint_site']
+            )
+            if len(res) == 0:
+                # no cached entry: infer, save and return
+                keypoint_infer_result, site = self.infer_keypoint_result(result)
+                return self.save_keypoint_cache(result['image_id'], keypoint_infer_result, site)
+            elif res[0]["keypoint_site"] == "all" or res[0]["keypoint_site"] == site:
+                # the cached site matches the request (or covers both): return it directly
+                return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, np.array(res[0]['keypoint_vector']).astype(int).reshape(12, 2).tolist()))
+            elif res[0]["keypoint_site"] != site:
+                # the cached site differs from the request: infer the missing half and upgrade the entry to 'all'
+                keypoint_infer_result, site = self.infer_keypoint_result(result)
+                return self.update_keypoint_cache(result["image_id"], keypoint_infer_result, res[0]['keypoint_vector'], site)
+        except Exception as e:
+            logger.warning(f"search keypoint cache milvus error {e}")
+            return False
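The cache contract in KeyPoint is easiest to see end to end: query Milvus by keypoint_id, fall back to inference on a miss, and upsert the flattened 24-dim vector. A trimmed sketch of that round trip (assumes the same MILVUS_* settings and collection schema used above; the id and zero vector are placeholders):

    from pymilvus import MilvusClient

    from app.core.config import MILVUS_URL, MILVUS_TOKEN, MILVUS_ALIAS, MILVUS_TABLE_KEYPOINT

    client = MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS)
    hits = client.query(collection_name=MILVUS_TABLE_KEYPOINT,
                        filter="keypoint_id == 98419",          # placeholder id
                        output_fields=["keypoint_vector", "keypoint_site"])
    if not hits:
        # cache miss: run the keypoint model here, then persist its flattened 24-dim output
        client.upsert(collection_name=MILVUS_TABLE_KEYPOINT,
                      data=[{"keypoint_id": 98419, "keypoint_site": "up",
                             "keypoint_vector": [0.0] * 24}])
    client.close()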
diff --git a/app/service/design_fast/pipeline/loading.py b/app/service/design_fast/pipeline/loading.py
new file mode 100644
index 0000000..0ce0dfa
--- /dev/null
+++ b/app/service/design_fast/pipeline/loading.py
@@ -0,0 +1,80 @@
+import io
+import logging
+
+import cv2
+import numpy as np
+from PIL import Image
+
+from app.service.utils.new_oss_client import oss_get_image
+
+logger = logging.getLogger()
+
+
+class LoadBodyImage:
+    name = "LoadBodyImage"
+
+    def __init__(self, minio_client):
+        self.minio_client = minio_client
+
+    @classmethod
+    def get_name(cls):
+        return cls.name
+
+    def __call__(self, result):
+        result["name"] = "mannequin"
+        result['body_image'] = oss_get_image(oss_client=self.minio_client, bucket=result['body_path'].split("/", 1)[0], object_name=result['body_path'].split("/", 1)[1], data_type="PIL")
+        return result
+
+
+class LoadImage:
+    name = "LoadImage"
+
+    def __init__(self, minio_client):
+        self.minio_client = minio_client
+
+    @classmethod
+    def get_name(cls):
+        return cls.name
+
+    def __call__(self, result):
+        result['image'], result['pre_mask'] = self.read_image(result['path'])
+        result['gray'] = cv2.cvtColor(result['image'], cv2.COLOR_BGR2GRAY)
+        result['keypoint'] = self.get_keypoint(result['name'])
+        result['img_shape'] = result['image'].shape
+        result['ori_shape'] = result['image'].shape
+        return result
+
+    def read_image(self, image_path):
+        image_mask = None
+        image = oss_get_image(oss_client=self.minio_client, bucket=image_path.split("/", 1)[0], object_name=image_path.split("/", 1)[1], data_type="cv2")
+        if len(image.shape) == 2:
+            image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
+        if image.shape[2] == 4:  # four-channel input: the alpha channel is the mask
+            image_mask = image[:, :, 3]
+            image = image[:, :, :3]
+
+        if image.shape[:2] <= (50, 50):  # note: tuple comparison, so this effectively keys off the height
+            # compute the new size and upscale tiny images
+            new_size = (image.shape[1] * 2, image.shape[0] * 2)
+            image = cv2.resize(image, new_size, interpolation=cv2.INTER_LINEAR)
+        return image, image_mask
+
+    @staticmethod
+    def get_keypoint(name):
+        if name == 'blouse' or name == 'outwear' or name == 'dress' or name == 'tops':
+            keypoint = 'shoulder'
+        elif name == 'trousers' or name == 'skirt' or name == 'bottoms':
+            keypoint = 'waistband'
+        elif name == 'bag':
+            keypoint = 'hand_point'
+        elif name == 'shoes':
+            keypoint = 'toe'
+        elif name == 'hairstyle':
+            keypoint = 'head_point'
+        elif name == 'earring':
+            keypoint = 'ear_point'
+        else:
+            raise KeyError(f"{name} does not belong to item category list: blouse, outwear, dress, trousers, skirt, "
+                           f"bag, shoes, hairstyle, earring.")
+        return keypoint
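read_image peels the alpha channel off four-channel inputs and carries it forward as pre_mask. The same split in isolation on a synthetic BGRA array:

    import numpy as np

    bgra = np.zeros((4, 4, 4), np.uint8)
    bgra[1:3, 1:3] = [10, 20, 30, 255]    # an opaque patch on a transparent background
    pre_mask = bgra[:, :, 3]              # the alpha channel becomes pre_mask
    bgr = bgra[:, :, :3]                  # the colour channels continue down the pipeline
    assert pre_mask.max() == 255 and bgr.shape == (4, 4, 3)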
diff --git a/app/service/design_fast/pipeline/print_painting.py b/app/service/design_fast/pipeline/print_painting.py
new file mode 100644
index 0000000..6fe40d8
--- /dev/null
+++ b/app/service/design_fast/pipeline/print_painting.py
@@ -0,0 +1,524 @@
+import random
+
+import cv2
+import numpy as np
+from PIL import Image
+
+from app.service.utils.new_oss_client import oss_get_image
+
+
+class PrintPainting:
+    def __init__(self, minio_client):
+        self.minio_client = minio_client
+
+    def __call__(self, result):
+        single_print = result['print']['single']
+        overall_print = result['print']['overall']
+        element_print = result['print']['element']
+        result['single_image'] = None
+        result['print_image'] = None
+        if overall_print['print_path_list']:
+            painting_dict = {'dim_image_h': result['pattern_image'].shape[0], 'dim_image_w': result['pattern_image'].shape[1]}
+            result['print_image'] = result['pattern_image']
+            if "print_angle_list" in overall_print.keys() and overall_print['print_angle_list'][0] != 0:
+                painting_dict = self.painting_collection(painting_dict, overall_print, print_trigger=True)
+                painting_dict['tile_print'] = self.rotate_crop_image(img=painting_dict['tile_print'], angle=-overall_print['print_angle_list'][0], crop=True)
+                painting_dict['mask_inv_print'] = self.rotate_crop_image(img=painting_dict['mask_inv_print'], angle=-overall_print['print_angle_list'][0], crop=True)
+
+                # resize to the sketch size
+                painting_dict['tile_print'] = self.resize_and_crop(img=painting_dict['tile_print'], target_width=painting_dict['dim_image_w'], target_height=painting_dict['dim_image_h'])
+                painting_dict['mask_inv_print'] = self.resize_and_crop(img=painting_dict['mask_inv_print'], target_width=painting_dict['dim_image_w'], target_height=painting_dict['dim_image_h'])
+            else:
+                painting_dict = self.painting_collection(painting_dict, overall_print, print_trigger=True, is_single=False)
+            result['print_image'] = self.printpaint(result, painting_dict, print_=True)
+            result['single_image'] = result['final_image'] = result['pattern_image'] = result['print_image']
+
+        if single_print['print_path_list']:
+            print_background = np.zeros((result['pattern_image'].shape[0], result['pattern_image'].shape[1], 3), dtype=np.uint8)
+            mask_background = np.zeros((result['pattern_image'].shape[0], result['pattern_image'].shape[1], 3), dtype=np.uint8)
+            for i in range(len(single_print['print_path_list'])):
+                image, image_mode = self.read_image(single_print['print_path_list'][i])
+                if image_mode == "RGBA":
+                    new_size = (int(image.width * single_print['print_scale_list'][i]), int(image.height * single_print['print_scale_list'][i]))
+
+                    mask = image.split()[3]
+                    resized_source = image.resize(new_size)
+                    resized_source_mask = mask.resize(new_size)
+
+                    rotated_resized_source = resized_source.rotate(-single_print['print_angle_list'][i])
+                    rotated_resized_source_mask = resized_source_mask.rotate(-single_print['print_angle_list'][i])
+
+                    source_image_pil = Image.fromarray(cv2.cvtColor(print_background, cv2.COLOR_BGR2RGB))
+                    source_image_pil_mask = Image.fromarray(cv2.cvtColor(mask_background, cv2.COLOR_BGR2RGB))
+
+                    source_image_pil.paste(rotated_resized_source, (int(single_print['location'][i][0]), int(single_print['location'][i][1])), rotated_resized_source)
+                    source_image_pil_mask.paste(rotated_resized_source_mask, (int(single_print['location'][i][0]), int(single_print['location'][i][1])), rotated_resized_source_mask)
+
+                    print_background = cv2.cvtColor(np.array(source_image_pil), cv2.COLOR_RGBA2BGR)
+                    mask_background = cv2.cvtColor(np.array(source_image_pil_mask), cv2.COLOR_RGBA2BGR)
+                    ret, mask_background = cv2.threshold(mask_background, 124, 255, cv2.THRESH_BINARY)
+                else:
+                    mask = self.get_mask_inv(image)
+                    mask = np.expand_dims(mask, axis=2)
+                    mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR)
+                    mask = cv2.bitwise_not(mask)
+                    # coordinates must be recomputed after rotation
+                    rotate_mask, _ = self.img_rotate(mask, single_print['print_angle_list'][i], single_print['print_scale_list'][i])
+                    rotate_image, rotated_new_size = self.img_rotate(image, single_print['print_angle_list'][i], single_print['print_scale_list'][i])
+                    # x, y = int(result['print']['location'][i][0] - rotated_new_size[0] - (rotate_mask.shape[0] - image.shape[0]) / 2), int(result['print']['location'][i][1] - rotated_new_size[1] - (rotate_mask.shape[1] - image.shape[1]) / 2)
+                    x, y = int(single_print['location'][i][0] - rotated_new_size[0]), int(single_print['location'][i][1] - rotated_new_size[1])
+
+                    image_x = print_background.shape[1]
+                    image_y = print_background.shape[0]
+                    print_x = rotate_image.shape[1]
+                    print_y = rotate_image.shape[0]
+
+                    # buggy, kept for reference:
+                    # if x + print_x > image_x:
+                    #     rotate_image = rotate_image[:, :x + print_x - image_x]
+                    #     rotate_mask = rotate_mask[:, :x + print_x - image_x]
+                    # #
+                    # if y + print_y > image_y:
+                    #     rotate_image = rotate_image[:y + print_y - image_y]
+                    #     rotate_mask = rotate_mask[:y + print_y - image_y]
+
+                    # these steps must not run side by side:
+                    # the first round of ifs (formerly lines 108 and 115) checked the bottom/right bounds and the second round the top-left;
+                    # done that way, cropping the right edge first and then shifting left corrupts the region,
+                    # so: shift first, then check, and crop last
+
+                    # if the print was rotated or sits flush against an edge, check whether the left/top bounds fall below 0
+                    if x <= 0:
+                        rotate_image = rotate_image[:, -x:]
+                        rotate_mask = rotate_mask[:, -x:]
+                        start_x = x = 0
+                    else:
+                        start_x = x
+
+                    if y <= 0:
+                        rotate_image = rotate_image[-y:, :]
+                        rotate_mask = rotate_mask[-y:, :]
+                        start_y = y = 0
+                    else:
+                        start_y = y
+
+                    # ------------------
+                    # if the print is larger than the image, crop the print
+
+                    if x + print_x > image_x:
+                        rotate_image = rotate_image[:, :image_x - x]
+                        rotate_mask = rotate_mask[:, :image_x - x]
+
+                    if y + print_y > image_y:
+                        rotate_image = rotate_image[:image_y - y, :]
+                        rotate_mask = rotate_mask[:image_y - y, :]
+
+                    # mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]] = cv2.bitwise_xor(mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]], rotate_mask)
+                    # print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = cv2.add(print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]], rotate_image)
+
+                    # mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]] = rotate_mask
+                    # print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = rotate_image
+                    mask_background = self.stack_prin(mask_background, result['pattern_image'], rotate_mask, start_y, y, start_x, x)
+                    print_background = self.stack_prin(print_background, result['pattern_image'], rotate_image, start_y, y, start_x, x)
+
+            # gray_image = cv2.cvtColor(mask_background, cv2.COLOR_BGR2GRAY)
+            # print_background = cv2.bitwise_and(print_background, print_background, mask=gray_image)
+
+            print_mask = cv2.bitwise_and(result['mask'], cv2.cvtColor(mask_background, cv2.COLOR_BGR2GRAY))
+            img_fg = cv2.bitwise_or(print_background, print_background, mask=print_mask)
+            img_bg = cv2.bitwise_and(result['pattern_image'], result['pattern_image'], mask=cv2.bitwise_not(print_mask))
+            mask_mo = np.expand_dims(print_mask, axis=2).repeat(3, axis=2)
+            gray_mo = np.expand_dims(result['gray'], axis=2).repeat(3, axis=2)
+            img_fg = (img_fg * (mask_mo / 255) * (gray_mo / 255)).astype(np.uint8)
+            result['final_image'] = cv2.add(img_bg, img_fg)
+            canvas = np.full_like(result['final_image'], 255)
+            temp_bg = np.expand_dims(cv2.bitwise_not(result['mask']), axis=2).repeat(3, axis=2)
+            tmp1 = (canvas * (temp_bg / 255)).astype(np.uint8)
+            temp_fg = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
+            tmp2 = (result['final_image'] * (temp_fg / 255)).astype(np.uint8)
+            result['single_image'] = cv2.add(tmp1, tmp2)
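The shift-then-crop ordering that the comments above insist on is easiest to verify with concrete numbers: a 100 px-wide print placed at x = -20 on a 60 px-wide canvas (hand-checked against the branches above):

    import numpy as np

    canvas_w, x = 60, -20
    strip = np.arange(100)[None, :]       # a 1 x 100 "print" whose values are column indices
    print_w = strip.shape[1]
    if x <= 0:                            # shift: drop the part hanging off the left edge
        strip = strip[:, -x:]
        start_x = x = 0
    if x + print_w > canvas_w:            # crop: drop the part hanging off the right edge
        strip = strip[:, :canvas_w - x]
    print(start_x, strip[0, 0], strip[0, -1])   # 0 20 79 -> canvas columns 0..59 receive print columns 20..79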
+
+        if element_print['element_path_list']:
+            print_background = np.zeros((result['final_image'].shape[0], result['final_image'].shape[1], 3), dtype=np.uint8)
+            mask_background = np.zeros((result['final_image'].shape[0], result['final_image'].shape[1], 3), dtype=np.uint8)
+            for i in range(len(element_print['element_path_list'])):
+                image, image_mode = self.read_image(element_print['element_path_list'][i])
+                if image_mode == "RGBA":
+                    new_size = (int(image.width * element_print['element_scale_list'][i]), int(image.height * element_print['element_scale_list'][i]))
+
+                    mask = image.split()[3]
+                    resized_source = image.resize(new_size)
+                    resized_source_mask = mask.resize(new_size)
+
+                    rotated_resized_source = resized_source.rotate(-element_print['element_angle_list'][i])
+                    rotated_resized_source_mask = resized_source_mask.rotate(-element_print['element_angle_list'][i])
+
+                    source_image_pil = Image.fromarray(cv2.cvtColor(print_background, cv2.COLOR_BGR2RGB))
+                    source_image_pil_mask = Image.fromarray(cv2.cvtColor(mask_background, cv2.COLOR_BGR2RGB))
+
+                    source_image_pil.paste(rotated_resized_source, (int(element_print['location'][i][0]), int(element_print['location'][i][1])), rotated_resized_source)
+                    source_image_pil_mask.paste(rotated_resized_source_mask, (int(element_print['location'][i][0]), int(element_print['location'][i][1])), rotated_resized_source_mask)
+
+                    print_background = cv2.cvtColor(np.array(source_image_pil), cv2.COLOR_RGBA2BGR)
+                    mask_background = cv2.cvtColor(np.array(source_image_pil_mask), cv2.COLOR_RGBA2BGR)
+                else:
+                    mask = self.get_mask_inv(image)
+                    mask = np.expand_dims(mask, axis=2)
+                    mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR)
+                    mask = cv2.bitwise_not(mask)
+                    # coordinates must be recomputed after rotation
+                    rotate_mask, _ = self.img_rotate(mask, element_print['element_angle_list'][i], element_print['element_scale_list'][i])
+                    rotate_image, rotated_new_size = self.img_rotate(image, element_print['element_angle_list'][i], element_print['element_scale_list'][i])
+                    # x, y = int(result['print']['location'][i][0] - rotated_new_size[0] - (rotate_mask.shape[0] - image.shape[0]) / 2), int(result['print']['location'][i][1] - rotated_new_size[1] - (rotate_mask.shape[1] - image.shape[1]) / 2)
+                    x, y = int(element_print['location'][i][0] - rotated_new_size[0]), int(element_print['location'][i][1] - rotated_new_size[1])
+
+                    image_x = print_background.shape[1]
+                    image_y = print_background.shape[0]
+                    print_x = rotate_image.shape[1]
+                    print_y = rotate_image.shape[0]
+
+                    # buggy, kept for reference:
+                    # if x + print_x > image_x:
+                    #     rotate_image = rotate_image[:, :x + print_x - image_x]
+                    #     rotate_mask = rotate_mask[:, :x + print_x - image_x]
+                    # #
+                    # if y + print_y > image_y:
+                    #     rotate_image = rotate_image[:y + print_y - image_y]
+                    #     rotate_mask = rotate_mask[:y + print_y - image_y]
+
+                    # these steps must not run side by side:
+                    # the first round of ifs (formerly lines 108 and 115) checked the bottom/right bounds and the second round the top-left;
+                    # done that way, cropping the right edge first and then shifting left corrupts the region,
+                    # so: shift first, then check, and crop last
+
+                    # if the print was rotated or sits flush against an edge, check whether the left/top bounds fall below 0
+                    if x <= 0:
+                        rotate_image = rotate_image[:, -x:]
+                        rotate_mask = rotate_mask[:, -x:]
+                        start_x = x = 0
+                    else:
+                        start_x = x
+
+                    if y <= 0:
+                        rotate_image = rotate_image[-y:, :]
+                        rotate_mask = rotate_mask[-y:, :]
+                        start_y = y = 0
+                    else:
+                        start_y = y
+
+                    # ------------------
+                    # if the print is larger than the image, crop the print
+
+                    if x + print_x > image_x:
+                        rotate_image = rotate_image[:, :image_x - x]
+                        rotate_mask = rotate_mask[:, :image_x - x]
+
+                    if y + print_y > image_y:
+                        rotate_image = rotate_image[:image_y - y, :]
+                        rotate_mask = rotate_mask[:image_y - y, :]
+
+                    # mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]] = cv2.bitwise_xor(mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]], rotate_mask)
+                    # print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = cv2.add(print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]], rotate_image)
+
+                    # mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]] = rotate_mask
start_x:x + rotate_image.shape[1]] = rotate_image + mask_background = self.stack_prin(mask_background, result['pattern_image'], rotate_mask, start_y, y, start_x, x) + print_background = self.stack_prin(print_background, result['pattern_image'], rotate_image, start_y, y, start_x, x) + + # gray_image = cv2.cvtColor(mask_background, cv2.COLOR_BGR2GRAY) + # print_background = cv2.bitwise_and(print_background, print_background, mask=gray_image) + + print_mask = cv2.bitwise_and(result['mask'], cv2.cvtColor(mask_background, cv2.COLOR_BGR2GRAY)) + img_fg = cv2.bitwise_or(print_background, print_background, mask=print_mask) + # TODO element 丢失信息 + three_channel_image = cv2.merge([cv2.bitwise_not(print_mask), cv2.bitwise_not(print_mask), cv2.bitwise_not(print_mask)]) + img_bg = cv2.bitwise_and(result['final_image'], three_channel_image) + # mask_mo = np.expand_dims(print_mask, axis=2).repeat(3, axis=2) + # gray_mo = np.expand_dims(result['gray'], axis=2).repeat(3, axis=2) + # img_fg = (img_fg * (mask_mo / 255) * (gray_mo / 255)).astype(np.uint8) + result['final_image'] = cv2.add(img_bg, img_fg) + canvas = np.full_like(result['final_image'], 255) + temp_bg = np.expand_dims(cv2.bitwise_not(result['mask']), axis=2).repeat(3, axis=2) + tmp1 = (canvas * (temp_bg / 255)).astype(np.uint8) + temp_fg = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2) + tmp2 = (result['final_image'] * (temp_fg / 255)).astype(np.uint8) + result['single_image'] = cv2.add(tmp1, tmp2) + return result + + @staticmethod + def stack_prin(print_background, pattern_image, rotate_image, start_y, y, start_x, x): + temp_print = np.zeros((pattern_image.shape[0], pattern_image.shape[1], 3), dtype=np.uint8) + temp_print[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = rotate_image + img2gray = cv2.cvtColor(temp_print, cv2.COLOR_BGR2GRAY) + ret, mask_ = cv2.threshold(img2gray, 1, 255, cv2.THRESH_BINARY) + mask_inv = cv2.bitwise_not(mask_) + img1_bg = cv2.bitwise_and(print_background, print_background, mask=mask_inv) + img2_fg = cv2.bitwise_and(temp_print, temp_print, mask=mask_) + print_background = img1_bg + img2_fg + return print_background + + def painting_collection(self, painting_dict, print_dict, print_trigger=False, is_single=False): + if print_trigger: + print_ = self.get_print(print_dict) + painting_dict['Trigger'] = not is_single + painting_dict['location'] = print_['location'] + single_mask_inv_print = self.get_mask_inv(print_['image']) + dim_max = max(painting_dict['dim_image_h'], painting_dict['dim_image_w']) + dim_pattern = (int(dim_max * print_['scale'] / 5), int(dim_max * print_['scale'] / 5)) + if not is_single: + self.random_seed = random.randint(0, 1000) + # 如果print 模式为overall 且 有角度的话 , 组合的print为正方形,方便裁剪 + if "print_angle_list" in print_dict.keys() and print_dict['print_angle_list'][0] != 0: + painting_dict['mask_inv_print'] = self.tile_image(single_mask_inv_print, dim_pattern, print_['scale'], dim_max, dim_max, painting_dict['location'], trigger=True) + painting_dict['tile_print'] = self.tile_image(print_['image'], dim_pattern, print_['scale'], dim_max, dim_max, painting_dict['location'], trigger=True) + else: + painting_dict['mask_inv_print'] = self.tile_image(single_mask_inv_print, dim_pattern, print_['scale'], painting_dict['dim_image_h'], painting_dict['dim_image_w'], painting_dict['location'], trigger=True) + painting_dict['tile_print'] = self.tile_image(print_['image'], dim_pattern, print_['scale'], painting_dict['dim_image_h'], painting_dict['dim_image_w'], painting_dict['location'], 
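+
+    # Illustrative note (not part of the original patch): tile_image() below resizes the
+    # pattern to roughly 1/5 of the garment's longer side, tiles it enough times to cover
+    # the garment, and crop_image() then slides the tiled sheet by the remainder of the
+    # print location against the tile size. A rough sketch of the intended arithmetic,
+    # with hypothetical numbers:
+    #
+    #   dim_max = 1000; scale = 0.5             # garment long side, print scale
+    #   tile side = int(1000 * 0.5 / 5) = 100   # dim_pattern
+    #   repeats = int((5 + 1) / 0.5) + 4 = 16   # 16 tiles of 100 px cover 1000 px + margin
+    #   offset = tile_side - location % tile_side  # see crop_image()
+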
+    def tile_image(self, pattern, dim, scale, dim_image_h, dim_image_w, location, trigger=False):
+        tile = None
+        if not trigger:
+            tile = cv2.resize(pattern, dim, interpolation=cv2.INTER_AREA)
+        else:
+            resize_pattern = cv2.resize(pattern, dim, interpolation=cv2.INTER_AREA)
+            if len(pattern.shape) == 2:
+                tile = np.tile(resize_pattern, (int((5 + 1) / scale) + 4, int((5 + 1) / scale) + 4))
+            if len(pattern.shape) == 3:
+                tile = np.tile(resize_pattern, (int((5 + 1) / scale) + 4, int((5 + 1) / scale) + 4, 1))
+            tile = self.crop_image(tile, dim_image_h, dim_image_w, location, resize_pattern.shape)
+        return tile
+
+    def get_mask_inv(self, print_):
+        if print_[0][0][0] == 255 and print_[0][0][1] == 255 and print_[0][0][2] == 255:
+            # White background: key it out in LAB space within a tolerance window around
+            # the top-left background colour.
+            bg_color = cv2.cvtColor(print_, cv2.COLOR_BGR2LAB)[0][0]
+            print_tile = cv2.cvtColor(print_, cv2.COLOR_BGR2LAB)
+            bg_l, bg_a, bg_b = bg_color[0], bg_color[1], bg_color[2]
+            bg_L_high, bg_L_low = self.get_low_high_lab(bg_l, L=True)
+            bg_a_high, bg_a_low = self.get_low_high_lab(bg_a)
+            bg_b_high, bg_b_low = self.get_low_high_lab(bg_b)
+            lower = np.array([bg_L_low, bg_a_low, bg_b_low])
+            upper = np.array([bg_L_high, bg_a_high, bg_b_high])
+            mask_inv = cv2.inRange(print_tile, lower, upper)
+            return mask_inv
+        else:
+            # No white background to key out: return an empty (all-opaque) mask.
+            mask_inv = np.zeros(print_.shape[:2], dtype=np.uint8)
+            return mask_inv
+
+    @staticmethod
+    def printpaint(result, painting_dict, print_=False):
+
+        if print_ and painting_dict['Trigger']:
+            print_mask = cv2.bitwise_and(result['mask'], cv2.bitwise_not(painting_dict['mask_inv_print']))
+            img_fg = cv2.bitwise_and(painting_dict['tile_print'], painting_dict['tile_print'], mask=print_mask)
+        else:
+            print_mask = result['mask']
+            img_fg = result['final_image']
+        if print_ and not painting_dict['Trigger']:
+            try:
+                index_ = len(painting_dict['location'])
+            except TypeError:
+                raise ValueError('a location parameter is required when IfSingle is chosen')
+
+            for i in range(index_):
+                start_h, start_w = int(painting_dict['location'][i][1]), int(painting_dict['location'][i][0])
+
+                length_h = min(start_h + painting_dict['dim_print_h'], img_fg.shape[0])
+                length_w = min(start_w + painting_dict['dim_print_w'], img_fg.shape[1])
+
+                change_region = img_fg[start_h: length_h, start_w: length_w, :]
+                # problem in change_mask
+                change_mask = print_mask[start_h: length_h, start_w: length_w]
+                # get the real part into the change mask
+                _, change_mask = cv2.threshold(change_mask, 220, 255, cv2.THRESH_BINARY)
+                mask = cv2.bitwise_not(painting_dict['mask_inv_print'])
+                img_fg[start_h:start_h + painting_dict['dim_print_h'], start_w:start_w + painting_dict['dim_print_w'], :] = change_region
+
+        clothes_mask_print = cv2.bitwise_not(print_mask)
+
+        img_bg = cv2.bitwise_and(result['pattern_image'], result['pattern_image'], mask=clothes_mask_print)
+        mask_mo = np.expand_dims(print_mask, axis=2).repeat(3, axis=2)
+        gray_mo = np.expand_dims(result['gray'], axis=2).repeat(3, axis=2)
+        img_fg = (img_fg * (mask_mo / 255) * (gray_mo / 255)).astype(np.uint8)
+        print_image = cv2.add(img_bg, img_fg)
+        return print_image
+
+    def get_print(self, print_dict):
+        # Clamp the print scale to a minimum of 0.3.
+        if 'print_scale_list' not in print_dict.keys() or print_dict['print_scale_list'][0] < 0.3:
+            print_dict['scale'] = 0.3
+        else:
+            print_dict['scale'] = print_dict['print_scale_list'][0]
+
+        bucket_name = print_dict['print_path_list'][0].split("/", 1)[0]
+        object_name = print_dict['print_path_list'][0].split("/", 1)[1]
+        image = oss_get_image(oss_client=self.minio_client, bucket=bucket_name, object_name=object_name, data_type="PIL")
+        # Check the image format: paste RGBA images onto a plain white canvas so that
+        # transparent areas do not turn black.
+        if image.mode == "RGBA":
+            new_background = Image.new('RGB', image.size, (255, 255, 255))
+            new_background.paste(image, mask=image.split()[3])
+            image = new_background
+        print_dict['image'] = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
+        return print_dict
+
+    def crop_image(self, image, image_size_h, image_size_w, location, print_shape):
+        print_w = print_shape[1]
+        print_h = print_shape[0]
+
+        random.seed(self.random_seed)
+        # 1. Take the print location modulo the resized print size to get the real offset.
+        # Note: the tiles produced here are square (dim_pattern), so print_w == print_h.
+        x_offset = print_w - int(location[0][1] % print_w)
+        y_offset = print_w - int(location[0][0] % print_h)
+
+        if len(image.shape) == 2:
+            image = image[x_offset: x_offset + image_size_h, y_offset: y_offset + image_size_w]
+        elif len(image.shape) == 3:
+            image = image[x_offset: x_offset + image_size_h, y_offset: y_offset + image_size_w, :]
+        return image
+
+    @staticmethod
+    def get_low_high_lab(Lab_value, L=False):
+        # The L and a/b channels currently share the same +/-30 window; the L flag is
+        # kept for call-site compatibility.
+        high = Lab_value + 30 if Lab_value + 30 < 255 else 255
+        low = Lab_value - 30 if Lab_value - 30 > 0 else 0
+        return high, low
+
+    @staticmethod
+    def img_rotate(image, angle, scale):
+        """Rotate an image clockwise by an arbitrary angle.
+
+        Args:
+            image (np.array): the original image
+            angle (float): the rotation angle in degrees (positive values rotate clockwise)
+
+        Returns:
+            [array]: the rotated image
+        """
+
+        h, w = image.shape[:2]
+        center = (w // 2, h // 2)
+        M = cv2.getRotationMatrix2D(center, -angle, scale)
+        # Enlarge the output canvas so the rotated image is not clipped.
+        rotated_h = int((w * np.abs(M[0, 1]) + (h * np.abs(M[0, 0]))))
+        rotated_w = int((h * np.abs(M[0, 1]) + (w * np.abs(M[0, 0]))))
+        M[0, 2] += (rotated_w - w) // 2
+        M[1, 2] += (rotated_h - h) // 2
+        # Rotate the image.
+        rotated_img = cv2.warpAffine(image, M, (rotated_w, rotated_h))
+
+        return rotated_img, ((rotated_img.shape[1] - image.shape[1] * scale) // 2, (rotated_img.shape[0] - image.shape[0] * scale) // 2)
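+
+    # Illustrative note (not part of the original patch): img_rotate() grows the canvas
+    # to the rotated bounding box, |w*sin(t)| + |h*cos(t)| high by |h*sin(t)| + |w*cos(t)|
+    # wide, and shifts the transform so the result stays centred. With hypothetical
+    # numbers (scale = 1):
+    #
+    #   w, h = 200, 100; t = 90 degrees
+    #   rotated_w = |100 * 1| + |200 * 0| = 100
+    #   rotated_h = |200 * 1| + |100 * 0| = 200
+    #
+    # i.e. a 200x100 image rotated 90 degrees lands on a 100x200 canvas with no clipping.
+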
+    @staticmethod
+    def rotate_crop_image(img, angle, crop):
+        """
+        angle: the rotation angle
+        crop: whether to crop away the black borders (bool)
+        """
+        crop_image = lambda img, x0, y0, w, h: img[y0:y0 + h, x0:x0 + w]
+        h, w = img.shape[:2]
+        # The rotation angle has a period of 360 degrees.
+        angle %= 360
+        # Compute the affine transform matrix.
+        M_rotation = cv2.getRotationMatrix2D((w / 2, h / 2), angle, 1)
+        # Apply the rotation.
+        img_rotated = cv2.warpAffine(img, M_rotation, (w, h))
+
+        # Remove the black borders if requested.
+        if crop:
+            # The effective period for cropping is 180 degrees.
+            angle_crop = angle % 180
+            if angle_crop > 90:
+                angle_crop = 180 - angle_crop
+            # Convert the angle to radians.
+            theta = angle_crop * np.pi / 180
+            # Height/width ratio.
+            hw_ratio = float(h) / float(w)
+            # Numerator of the crop-edge coefficient.
+            tan_theta = np.tan(theta)
+            numerator = np.cos(theta) + np.sin(theta) * np.tan(theta)
+
+            # Denominator term that depends on the aspect ratio.
+            r = hw_ratio if h > w else 1 / hw_ratio
+            # Denominator.
+            denominator = r * tan_theta + 1
+            # Final crop-edge coefficient.
+            crop_mult = numerator / denominator
+
+            # Crop region.
+            w_crop = int(crop_mult * w)
+            h_crop = int(crop_mult * h)
+            x0 = int((w - w_crop) / 2)
+            y0 = int((h - h_crop) / 2)
+
+            img_rotated = crop_image(img_rotated, x0, y0, w_crop, h_crop)
+
+        return img_rotated
+
+    def read_image(self, image_url):
+        image = oss_get_image(oss_client=self.minio_client, bucket=image_url.split("/", 1)[0], object_name=image_url.split("/", 1)[1], data_type="cv2")
+        if image.shape[2] == 4:
+            image_rgb = cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA)
+            image = Image.fromarray(image_rgb)
+            image_mode = "RGBA"
+        else:
+            image_mode = "RGB"
+        return image, image_mode
+
+    @staticmethod
+    def resize_and_crop(img, target_width, target_height):
+        # Original image size.
+        original_height, original_width = img.shape[:2]
+
+        # Aspect ratio of the target size.
+        target_ratio = target_width / target_height
+
+        # Aspect ratio of the original image.
+        original_ratio = original_width / original_height
+
+        # Resize, then centre-crop the longer dimension.
+        if original_ratio > target_ratio:
+            # The original is wider: resize by height, then crop the width.
+            new_height = target_height
+            new_width = int(original_width * (target_height / original_height))
+            resized_img = cv2.resize(img, (new_width, new_height))
+            # Crop the width.
+            start_x = (new_width - target_width) // 2
+            cropped_img = resized_img[:, start_x:start_x + target_width]
+        else:
+            # The original is taller: resize by width, then crop the height.
+            new_width = target_width
+            new_height = int(original_height * (target_width / original_width))
+            resized_img = cv2.resize(img, (new_width, new_height))
+            # Crop the height.
+            start_y = (new_height - target_height) // 2
+            cropped_img = resized_img[start_y:start_y + target_height, :]
+
+        return cropped_img
diff --git a/app/service/design_fast/pipeline/scale.py b/app/service/design_fast/pipeline/scale.py
new file mode 100644
index 0000000..732fcd8
--- /dev/null
+++ b/app/service/design_fast/pipeline/scale.py
@@ -0,0 +1,49 @@
+import math
+
+import cv2
+
+
+class Scaling:
+    def __call__(self, result):
+        if result['keypoint'] in ['waistband', 'shoulder', 'head_point']:
+            # milvus_db_keypoint_cache
+            distance_clo = math.sqrt(
+                (int(result['clothes_keypoint'][result['keypoint'] + '_left'][0]) - int(result['clothes_keypoint'][result['keypoint'] + '_right'][0])) ** 2
+                +
+                (int(result['clothes_keypoint'][result['keypoint'] + '_left'][1]) - int(result['clothes_keypoint'][result['keypoint'] + '_right'][1])) ** 2
+            )
+
+            distance_bdy = math.sqrt(
+                (int(result['body_point_test'][result['keypoint'] + '_left'][0])
+                 -
+                 int(result['body_point_test'][result['keypoint'] + '_right'][0])) ** 2 + 1
+            )
+
+            if distance_clo == 0:
+                result['scale'] = 1
+            else:
+                result['scale'] = distance_bdy / distance_clo
+        elif result['keypoint'] == 'toe':
+            distance_bdy = math.sqrt(
+                (int(result['body_point_test']['foot_length'][0]) - int(result['body_point_test']['foot_length'][2])) ** 2
+                +
+                (int(result['body_point_test']['foot_length'][1]) - int(result['body_point_test']['foot_length'][3])) ** 2
+            )
+
+            Blur = cv2.GaussianBlur(result['gray'], (3, 3), 0)
+            Edge = cv2.Canny(Blur, 10, 200)
+            Edge = cv2.dilate(Edge, None)
+            Edge = cv2.erode(Edge, None)
+            Contour, _ = cv2.findContours(Edge, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+            Contours = sorted(Contour, key=cv2.contourArea, reverse=True)
+
+            Max_contour = Contours[0]
+            x, y, w, h = cv2.boundingRect(Max_contour)
+            width = w
+            distance_clo = width
+            result['scale'] = distance_bdy / distance_clo
+        elif result['keypoint'] == 'hand_point':
+            result['scale'] = result['scale_bag']
+        elif result['keypoint'] == 'ear_point':
+            result['scale'] = result['scale_earrings']
+        return result
diff --git a/app/service/design_fast/pipeline/segmentation.py b/app/service/design_fast/pipeline/segmentation.py
new file mode 100644
index 0000000..ebf02b4
--- /dev/null
+++ b/app/service/design_fast/pipeline/segmentation.py
@@ -0,0 +1,85 @@
+import logging
+import os
+
+import cv2
+import numpy as np
+
+from app.core.config import SEG_CACHE_PATH
+from app.service.design_fast.utils.design_ensemble import get_seg_result
+from app.service.utils.decorator import ClassCallRunTime
+from app.service.utils.new_oss_client import oss_get_image
+
+logger = logging.getLogger()
+
+
+class Segmentation:
+    def __init__(self, minio_client):
+        self.minio_client = minio_client
+
+    @ClassCallRunTime
+    def __call__(self, result):
+        if "seg_mask_url" in result.keys() and result['seg_mask_url'] != "":
+            seg_mask = oss_get_image(oss_client=self.minio_client, bucket=result['seg_mask_url'].split('/')[0], object_name=result['seg_mask_url'][result['seg_mask_url'].find('/') + 1:], data_type="cv2")
+            seg_mask = cv2.resize(seg_mask, (result['img_shape'][1], result['img_shape'][0]), interpolation=cv2.INTER_NEAREST)
+            # Convert the colour space to RGB (OpenCV defaults to BGR).
+            image_rgb = cv2.cvtColor(seg_mask, cv2.COLOR_BGR2RGB)
+
+            r, g, b = cv2.split(image_rgb)
+            red_mask = r > g
+            green_mask = g > r
+
+            # Build the red (front) and green (back) masks.
+            result['front_mask'] = np.array(red_mask, dtype=np.uint8) * 255
+            result['back_mask'] = np.array(green_mask, dtype=np.uint8) * 255
+            result['mask'] = result['front_mask'] + result['back_mask']
+        else:
+            # "preview": run the model, do not cache.
+            if "preview_submit" in result.keys() and result['preview_submit'] == "preview":
+                # Run inference to get the segmentation result.
+                seg_result = get_seg_result(result["image_id"], result['image'])[0]
+            # "submit": run the model and cache the result.
+            elif "preview_submit" in result.keys() and result['preview_submit'] == "submit":
+                # Run inference to get the segmentation result.
+                seg_result = get_seg_result(result["image_id"], result['image'])[0]
+                self.save_seg_result(seg_result, result['image_id'])
+            # Normal flow: load the local cache, fall back to the model when it is missing.
+            else:
+                # Check whether a local segmentation cache exists.
+                _, seg_result = self.load_seg_result(result["image_id"])
+                # Check that the cache matches the actual image size.
+                if not _ or result["image"].shape[:2] != seg_result.shape:
+                    # Run inference to get the segmentation result.
+                    seg_result = get_seg_result(result["image_id"], result['image'])[0]
+                    self.save_seg_result(seg_result, result['image_id'])
+            result['seg_result'] = seg_result
+
+            # Split into front and back pieces.
+            temp_front = seg_result == 1.0
+            result['front_mask'] = (255 * (temp_front + 0).astype(np.uint8))
+            temp_back = seg_result == 2.0
+            result['back_mask'] = (255 * (temp_back + 0).astype(np.uint8))
+            result['mask'] = result['front_mask'] + result['back_mask']
+        return result
+
+    @staticmethod
+    def save_seg_result(seg_result, image_id):
+        file_path = f"{SEG_CACHE_PATH}{image_id}.npy"
+        try:
+            np.save(file_path, seg_result)
+            logger.info(f"seg cache saved: {os.path.abspath(file_path)}")
+        except Exception as e:
+            logger.error(f"seg cache save failed: {e}")
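+
+    # Illustrative note (not part of the original patch): the cache key is the image_id
+    # alone, so a re-uploaded image that keeps its id but changes size would collide;
+    # __call__ guards against this by comparing image.shape[:2] with the cached array's
+    # shape and re-running inference on a mismatch, e.g.:
+    #
+    #   ok, cached = self.load_seg_result(42)        # hypothetical id
+    #   if not ok or image.shape[:2] != cached.shape:
+    #       cached = get_seg_result(42, image)[0]    # cache miss -> infer
+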
+    @staticmethod
+    def load_seg_result(image_id):
+        file_path = f"{SEG_CACHE_PATH}{image_id}.npy"
+        logger.info(f"load seg file name is: {SEG_CACHE_PATH}{image_id}.npy")
+        try:
+            seg_result = np.load(file_path)
+            return True, seg_result
+        except FileNotFoundError:
+            logger.warning("seg cache file does not exist")
+            return False, None
+        except Exception as e:
+            logger.error(f"seg cache load failed: {e}")
+            return False, None
diff --git a/app/service/design_fast/pipeline/split.py b/app/service/design_fast/pipeline/split.py
new file mode 100644
index 0000000..737b50e
--- /dev/null
+++ b/app/service/design_fast/pipeline/split.py
@@ -0,0 +1,74 @@
+import io
+import logging
+
+import cv2
+import numpy as np
+from PIL import Image
+from cv2 import cvtColor, COLOR_BGR2RGBA
+
+from app.core.config import AIDA_CLOTHING
+from app.service.design_fast.utils.conversion_image import rgb_to_rgba
+from app.service.design_fast.utils.upload_image import upload_png_mask
+from app.service.utils.generate_uuid import generate_uuid
+from app.service.utils.new_oss_client import oss_upload_image
+
+
+class Split(object):
+    def __init__(self, minio_client):
+        self.minio_client = minio_client
+
+    def __call__(self, result):
+        try:
+            if result['name'] in ('outwear', 'dress', 'blouse', 'skirt', 'trousers', 'tops', 'bottoms'):
+                front_mask = result['front_mask']
+                back_mask = result['back_mask']
+                rgba_image = rgb_to_rgba(result['final_image'], front_mask + back_mask)
+                new_size = (int(rgba_image.shape[1] * result["scale"] * result["resize_scale"][0]), int(rgba_image.shape[0] * result["scale"] * result["resize_scale"][1]))
+                rgba_image = cv2.resize(rgba_image, new_size)
+                result_front_image = np.zeros_like(rgba_image)
+                front_mask = cv2.resize(front_mask, new_size)
+                result_front_image[front_mask != 0] = rgba_image[front_mask != 0]
+                result_front_image_pil = Image.fromarray(cvtColor(result_front_image, COLOR_BGR2RGBA))
+                result['front_image'], result["front_image_url"], _ = upload_png_mask(self.minio_client, result_front_image_pil, f'{generate_uuid()}', mask=None)
+
+                height, width = front_mask.shape
+                mask_image = np.zeros((height, width, 3))
+                mask_image[front_mask != 0] = [0, 0, 255]
+
+                if result["name"] in ('blouse', 'dress', 'outwear', 'tops'):
+                    result_back_image = np.zeros_like(rgba_image)
+                    back_mask = cv2.resize(back_mask, new_size)
+                    result_back_image[back_mask != 0] = rgba_image[back_mask != 0]
+                    result_back_image_pil = Image.fromarray(cvtColor(result_back_image, COLOR_BGR2RGBA))
+                    result['back_image'], result["back_image_url"], _ = upload_png_mask(self.minio_client, result_back_image_pil, f'{generate_uuid()}', mask=None)
+                    mask_image[back_mask != 0] = [0, 255, 0]
+
+                    rgba_mask = rgb_to_rgba(mask_image, front_mask + back_mask)
+                    mask_pil = Image.fromarray(cvtColor(rgba_mask.astype(np.uint8), COLOR_BGR2RGBA))
+                    image_data = io.BytesIO()
+                    mask_pil.save(image_data, format='PNG')
+                    image_data.seek(0)
+                    image_bytes = image_data.read()
+                    req = oss_upload_image(oss_client=self.minio_client, bucket=AIDA_CLOTHING, object_name=f"mask/mask_{generate_uuid()}.png", image_bytes=image_bytes)
+                    result['mask_url'] = req.bucket_name + "/" + req.object_name
+                else:
+                    rgba_mask = rgb_to_rgba(mask_image, front_mask)
+                    mask_pil = Image.fromarray(cvtColor(rgba_mask.astype(np.uint8), COLOR_BGR2RGBA))
+                    image_data = io.BytesIO()
+                    mask_pil.save(image_data, format='PNG')
+                    image_data.seek(0)
+                    image_bytes = image_data.read()
+                    req = oss_upload_image(oss_client=self.minio_client, bucket=AIDA_CLOTHING, object_name=f"mask/mask_{generate_uuid()}.png", image_bytes=image_bytes)
+                    result['mask_url'] = req.bucket_name + "/" + req.object_name
+                    result['back_image'] = None
+                    result["back_image_url"] = None
+            # Build the intermediate pattern layer.
+            result_pattern_image_rgba = rgb_to_rgba(result['pattern_image'], result['mask'])
+            result_pattern_image_pil = Image.fromarray(cvtColor(result_pattern_image_rgba, COLOR_BGR2RGBA))
+            result['pattern_image'], result['pattern_image_url'], _ = upload_png_mask(self.minio_client, result_pattern_image_pil, f'{generate_uuid()}')
+            return result
+        except Exception as e:
+            logging.warning(f"split runtime exception : {e} image_id : {result['image_id']}")
diff --git a/app/service/design/utils/__init__.py b/app/service/design_fast/utils/__init__.py
similarity index 100%
rename from app/service/design/utils/__init__.py
rename to app/service/design_fast/utils/__init__.py
diff --git a/app/service/design_fast/utils/conversion_image.py b/app/service/design_fast/utils/conversion_image.py
new file mode 100644
index 0000000..11e39ae
--- /dev/null
+++ b/app/service/design_fast/utils/conversion_image.py
@@ -0,0 +1,31 @@
+#!/usr/bin/env python
+# -*- coding: UTF-8 -*-
+"""
+@Project    :trinity_client
+@File       :conversion_image.py
+@Author     :周成融
+@Date       :2023/8/21 10:40:29
+@detail     :convert an RGB image plus a mask into an RGBA image
+"""
+import numpy as np
+
+
+def rgb_to_rgba(rgb_image, mask):
+    # Build the alpha channel: opaque where the mask is set, transparent elsewhere.
+    alpha_channel = np.where(mask > 0, 255, 0).astype(np.uint8)
+    # Merge the RGB image with the alpha channel.
+    rgba_image = np.dstack((rgb_image, alpha_channel))
+    return rgba_image
diff --git a/app/service/design_fast/utils/design_ensemble.py b/app/service/design_fast/utils/design_ensemble.py
new file mode 100644
index 0000000..f4f6a34
--- /dev/null
+++ b/app/service/design_fast/utils/design_ensemble.py
@@ -0,0 +1,143 @@
+#!/usr/bin/env python
+# -*- coding: UTF-8 -*-
+"""
+@Project    :trinity_client
+@File       :design_ensemble.py
+@Author     :周成融
+@Date       :2023/8/16 19:36:21
+@detail     :send inference requests and collect the results
+"""
+import logging
+
+import cv2
+import mmcv
+import numpy as np
+import torch
+import torch.nn.functional as F
+import tritonclient.http as httpclient
+
+from app.core.config import *
+
+"""
+    keypoint
+    preprocess, inference, postprocess
+"""
+
+
+def keypoint_preprocess(img_path):
+    img = mmcv.imread(img_path)
+    img_scale = (256, 256)
+    h, w = img.shape[:2]
+    img = cv2.resize(img, img_scale)
+    w_scale = img_scale[0] / w
+    h_scale = img_scale[1] / h
+    img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
+    preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
+    return preprocessed_img, (w_scale, h_scale)
+
+
+# inference
+def get_keypoint_result(image, site):
+    keypoint_result = None
+    try:
+        image, scale_factor = keypoint_preprocess(image)
+        client = httpclient.InferenceServerClient(url=DESIGN_MODEL_URL)
+        transformed_img = image.astype(np.float32)
+        inputs = [httpclient.InferInput("input", transformed_img.shape, datatype="FP32")]
+        inputs[0].set_data_from_numpy(transformed_img, binary_data=True)
+        outputs = [httpclient.InferRequestedOutput("output", binary_data=True)]
+        results = client.infer(model_name=f"keypoint_{site}_ocrnet_hr18", inputs=inputs, outputs=outputs)
+        inference_output = torch.from_numpy(results.as_numpy('output'))
+        keypoint_result = keypoint_postprocess(inference_output, scale_factor)
+    except Exception as e:
+        logging.warning(f"get_keypoint_result : {e}")
+    return keypoint_result
+def keypoint_postprocess(output, scale_factor):
+    max_indices = torch.argmax(output.view(output.size(0), output.size(1), -1), dim=2).unsqueeze(dim=2)
+    max_coords = torch.cat((max_indices / output.size(3), max_indices % output.size(3)), dim=2)
+    segment_result = max_coords.numpy()
+    scale_factor = [1 / x for x in scale_factor[::-1]]
+    scale_matrix = np.diag(scale_factor)
+    nan = np.isinf(scale_matrix)
+    scale_matrix[nan] = 0
+    return np.ceil(np.dot(segment_result, scale_matrix) * 4)
+
+
+"""
+    seg
+    preprocess, inference, postprocess
+"""
+
+
+# KNet
+def seg_preprocess(img_path):
+    img = mmcv.imread(img_path)
+    ori_shape = img.shape[:2]
+    img_scale_w, img_scale_h = ori_shape
+    if ori_shape[0] > 1024:
+        img_scale_w = 1024
+    if ori_shape[1] > 1024:
+        img_scale_h = 1024
+    # If either side of the image exceeds 1024, resize that side down to 1024.
+    if ori_shape != (img_scale_w, img_scale_h):
+        # mmcv.imresize(img, img_scale_h, img_scale_w)  # old code, kept as a cautionary tale: h and w were swapped
+        img = cv2.resize(img, (img_scale_h, img_scale_w))
+    img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
+    preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
+    return preprocessed_img, ori_shape
+
+
+def get_seg_result(image_id, image):
+    image, ori_shape = seg_preprocess(image)
+    client = httpclient.InferenceServerClient(url=f"{DESIGN_MODEL_URL}")
+    transformed_img = image.astype(np.float32)
+    # Inputs.
+    inputs = [
+        httpclient.InferInput(SEGMENTATION['input'], transformed_img.shape, datatype="FP32")
+    ]
+    inputs[0].set_data_from_numpy(transformed_img, binary_data=True)
+    # Outputs.
+    outputs = [
+        httpclient.InferRequestedOutput(SEGMENTATION['output'], binary_data=True),
+    ]
+    # Run inference and fetch the result.
+    results = client.infer(model_name=SEGMENTATION['new_model_name'], inputs=inputs, outputs=outputs)
+    inference_output1 = results.as_numpy(SEGMENTATION['output'])
+    seg_result = seg_postprocess(int(image_id), inference_output1, ori_shape)
+    return seg_result
+
+
+# no cache
+def seg_postprocess(image_id, output, ori_shape):
+    seg_logit = F.interpolate(torch.tensor(output).float(), size=ori_shape, scale_factor=None, mode='bilinear', align_corners=False)
+    seg_pred = seg_logit.cpu().numpy()
+    return seg_pred[0]
+
+
+def key_point_show(image_path, key_point_result=None):
+    img = cv2.imread(image_path)
+    points_list = key_point_result
+    point_size = 1
+    point_color = (0, 0, 255)  # BGR
+    thickness = 4  # may be 0, 4 or 8
+    for point in points_list:
+        cv2.circle(img, point[::-1], point_size, point_color, thickness)
+    cv2.imshow("0", img)
+    cv2.waitKey(0)
+
+
+if __name__ == '__main__':
+    image = cv2.imread("9070101c-e5be-49b5-9602-4113a968969b.png")
+    a = get_keypoint_result(image, "up")
+    new_list = []
+    for i in a[0]:
+        new_list.append((int(i[0]), int(i[1])))
+    key_point_show("9070101c-e5be-49b5-9602-4113a968969b.png", new_list)
+    # a = get_seg_result(1, image)
+    print(a)
diff --git a/app/service/design_fast/utils/organize.py b/app/service/design_fast/utils/organize.py
new file mode 100644
index 0000000..8190de0
--- /dev/null
+++ b/app/service/design_fast/utils/organize.py
@@ -0,0 +1,77 @@
+import cv2
+
+from app.core.config import PRIORITY_DICT
+
+
+def organize_body(layer):
+    body_layer = dict(priority=0,
+                      name=layer["name"].lower(),
+                      image=layer['body_image'],
+                      image_url=layer['body_path'],
+                      mask_image=None,
+                      mask_url=None,
+                      scale=1,
+                      # mask=layer['body_mask'],
+                      position=(0, 0))
+    return body_layer
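+
+# Illustrative note (not part of the original patch): each layer is a plain dict keyed by
+# "priority" (body fixed at 0; garment fronts use the request's priority and backs its
+# negative when layer_order is set, otherwise PRIORITY_DICT supplies the value), and is
+# presumably consumed in priority order by the composition step. A minimal hypothetical
+# layer:
+#
+#   layer = {"priority": 11, "name": "blouse_front", "image": pil_img,
+#            "image_url": "aida-clothing/image/image_xxx.png", "position": (120, 40)}
+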
+def organize_clothing(layer):
+    # Starting coordinates.
+    start_point = calculate_start_point(layer['keypoint'], layer['scale'], layer['clothes_keypoint'], layer['body_point_test'], layer["offset"], layer["resize_scale"])
+    # Front piece.
+    front_layer = dict(priority=layer['priority'] if layer.get("layer_order", False) else PRIORITY_DICT.get(f'{layer["name"].lower()}_front', None),
+                       name=f'{layer["name"].lower()}_front',
+                       image=layer["front_image"],
+                       image_url=layer['front_image_url'],
+                       mask_url=layer['mask_url'],
+                       scale=layer['scale'],
+                       clothes_keypoint=layer['clothes_keypoint'],
+                       position=start_point,
+                       resize_scale=layer["resize_scale"],
+                       mask=cv2.resize(layer['mask'], layer["front_image"].size),
+                       gradient_string=layer['gradient_string'] if 'gradient_string' in layer.keys() else "",
+                       pattern_image_url=layer['pattern_image_url'],
+                       pattern_image=layer['pattern_image'])
+    # Back piece.
+    back_layer = dict(priority=-layer.get("priority", 0) if layer.get("layer_order", False) else PRIORITY_DICT.get(f'{layer["name"].lower()}_back', None),
+                      name=f'{layer["name"].lower()}_back',
+                      image=layer["back_image"],
+                      image_url=layer['back_image_url'],
+                      mask_url=layer['mask_url'],
+                      scale=layer['scale'],
+                      clothes_keypoint=layer['clothes_keypoint'],
+                      position=start_point,
+                      resize_scale=layer["resize_scale"],
+                      mask=cv2.resize(layer['mask'], layer["front_image"].size),
+                      gradient_string=layer['gradient_string'] if 'gradient_string' in layer.keys() else "",
+                      pattern_image_url=layer['pattern_image_url'],
+                      )
+    return front_layer, back_layer
+
+
+def calculate_start_point(keypoint_type, scale, clothes_point, body_point, offset, resize_scale):
+    """
+    Align left.
+    Args:
+        keypoint_type: string, "waistband" | "shoulder" | "ear_point"
+        scale: float
+        clothes_point: dict {'left': [x1, y1, z1], 'right': [x2, y2, z2]}
+        body_point: dict containing the keypoint data of the body figure
+        offset: (x, y) pixel offset
+        resize_scale: [sx, sy] resize factors (currently unused here)
+
+    Returns:
+        start_point: tuple (x', y')
+            x' = y_body - y1 * scale + offset
+            y' = x_body - x1 * scale + offset
+    """
+    side_indicator = f'{keypoint_type}_left'
+    start_point = (
+        int(body_point[side_indicator][1] + offset[1] - int(clothes_point[side_indicator][0]) * scale),  # y
+        int(body_point[side_indicator][0] + offset[0] - int(clothes_point[side_indicator][1]) * scale)  # x
+    )
+    return start_point
diff --git a/app/service/design_fast/utils/progress.py b/app/service/design_fast/utils/progress.py
new file mode 100644
index 0000000..0f2c9cf
--- /dev/null
+++ b/app/service/design_fast/utils/progress.py
@@ -0,0 +1,30 @@
+import logging
+
+from app.service.design_fast.utils.redis_utils import Redis
+
+logger = logging.getLogger(__name__)
+
+
+def update_progress(process_id, total):
+    r = Redis()
+    progress = r.read(key=process_id)
+    if progress and total != 1:
+        # Advance by one step, but cap at 99 so only final_progress() reports 100.
+        r.write(key=process_id, value=min(int(progress) + int(100 / total), 99))
+        return progress
+    elif total == 1:
+        r.write(key=process_id, value=100)
+        return progress
+    else:
+        r.write(key=process_id, value=int(100 / total))
+        return progress
+
+
+def final_progress(process_id):
+    r = Redis()
+    progress = r.read(key=process_id)
+    r.write(key=process_id, value=100)
+    return progress
diff --git a/app/service/design_fast/utils/redis_utils.py b/app/service/design_fast/utils/redis_utils.py
new file mode 100644
index 0000000..012fbe0
--- /dev/null
+++ b/app/service/design_fast/utils/redis_utils.py
@@ -0,0 +1,99 @@
+import redis
+
+from app.core.config import REDIS_HOST, REDIS_PORT
+
+
+class Redis(object):
+    """
+    Redis database operations.
+    """
+
+    @staticmethod
+    def _get_r():
+        host = REDIS_HOST
+        port = REDIS_PORT
+        db = 0
+        r = redis.StrictRedis(host, port, db)
+        return r
+
+    @classmethod
+    def write(cls, key, value, expire=None):
+        """
+        Write a key-value pair.
+        """
+        # Use the given expiry if provided, otherwise fall back to the default.
+        if expire:
+            expire_in_seconds = expire
+        else:
+            expire_in_seconds = 100
+        r = cls._get_r()
+        r.set(key, value, ex=expire_in_seconds)
+
+    @classmethod
+    def read(cls, key):
+        """
+        Read the value for a key.
+        """
+        r = cls._get_r()
+        value = r.get(key)
+        return value.decode('utf-8') if value else value
+
+    @classmethod
+    def hset(cls, name, key, value):
+        """
+        Write a field to a hash table.
+        """
+        r = cls._get_r()
+        r.hset(name, key, value)
+
+    @classmethod
+    def hget(cls, name, key):
+        """
+        Read a field from the given hash table.
+        """
+        r = cls._get_r()
+        value = r.hget(name, key)
+        return value.decode('utf-8') if value else value
+
+    @classmethod
+    def hgetall(cls, name):
+        """
+        Get all values of the given hash table.
+        """
+        r = cls._get_r()
+        return r.hgetall(name)
+
+    @classmethod
+    def delete(cls, *names):
+        """
+        Delete one or more keys.
+        """
+        r = cls._get_r()
+        r.delete(*names)
+
+    @classmethod
+    def hdel(cls, name, key):
+        """
+        Delete a field from the given hash table.
+        """
+        r = cls._get_r()
+        r.hdel(name, key)
+
+    @classmethod
+    def expire(cls, name, expire=None):
+        """
+        Set an expiry time.
+        """
+        if expire:
+            expire_in_seconds = expire
+        else:
+            expire_in_seconds = 100
+        r = cls._get_r()
+        r.expire(name, expire_in_seconds)
+
+
+if __name__ == '__main__':
+    redis_client = Redis()
+    redis_client.write(key="1230", value=10)
diff --git a/app/service/design_fast/utils/synthesis_item.py b/app/service/design_fast/utils/synthesis_item.py
new file mode 100644
index 0000000..08bf4ec
--- /dev/null
+++ b/app/service/design_fast/utils/synthesis_item.py
@@ -0,0 +1,199 @@
+#!/usr/bin/env python
+# -*- coding: UTF-8 -*-
+"""
+@Project    :trinity_client
+@File       :synthesis_item.py
+@Author     :周成融
+@Date       :2023/8/26 14:13:04
+@detail     :compose the individual layers into the final design image
+"""
+import io
+import logging
+
+import cv2
+import numpy as np
+from PIL import Image
+
+from app.service.utils.generate_uuid import generate_uuid
+from app.service.utils.oss_client import oss_upload_image
+
+
+def positioning(all_mask_shape, mask_shape, offset):
+    all_start = 0
+    all_end = 0
+    mask_start = 0
+    mask_end = 0
+    if offset == 0:
+        all_start = 0
+        all_end = min(all_mask_shape, mask_shape)
+
+        mask_start = 0
+        mask_end = min(all_mask_shape, mask_shape)
+    elif offset > 0:
+        all_start = min(offset, all_mask_shape)
+        all_end = min(offset + mask_shape, all_mask_shape)
+
+        mask_start = 0
+        mask_end = 0 if offset > all_mask_shape else min(all_mask_shape - offset, mask_shape)
+    elif offset < 0:
+        if abs(offset) > mask_shape:
+            all_start = 0
+            all_end = 0
+        else:
+            all_start = 0
+            if mask_shape - abs(offset) > all_mask_shape:
+                all_end = min(mask_shape - abs(offset), all_mask_shape)
+            else:
+                all_end = mask_shape - abs(offset)
+
+        if abs(offset) > mask_shape:
+            mask_start = mask_shape
+            mask_end = mask_shape
+        else:
+            mask_start = abs(offset)
+            if mask_shape - abs(offset) >= all_mask_shape:
+                mask_end = all_mask_shape + abs(offset)
+            else:
+                mask_end = mask_shape
+    return all_start, all_end, mask_start, mask_end
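+
+# Illustrative note (not part of the original patch): positioning() clips a mask of
+# length mask_shape placed at offset into a canvas of length all_mask_shape, returning
+# the slice bounds for both. With hypothetical numbers: canvas 1000, mask 300, offset
+# 900 gives (900, 1000, 0, 100) -- only the first 100 px of the mask fit; offset -50
+# gives (0, 250, 50, 300) -- the first 50 px of the mask fall outside and are skipped.
+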
+# @RunTime
+def synthesis(data, size, basic_info):
+    # Create the base canvas.
+    base_image = Image.new('RGBA', size, (0, 0, 0, 0))
+    try:
+        all_mask_shape = (size[1], size[0])
+        body_mask = None
+        for d in data:
+            if d['name'] == 'body' or d['name'] == 'mannequin':
+                # Paste the model onto a fresh transparent image of the full size to get its mask.
+                transparent_image = Image.new("RGBA", size, (0, 0, 0, 0))
+                transparent_image.paste(d['image'], (d['adaptive_position'][1], d['adaptive_position'][0]), d['image'])  # paste can mutate a mutable sequence here, so read the position by index
+                body_mask = np.array(transparent_image.split()[3])
+
+                # Recompute the shoulder points in the new coordinates.
+                left_shoulder = [x + y for x, y in zip(basic_info['body_point_test']['shoulder_left'], [d['adaptive_position'][1], d['adaptive_position'][0]])]
+                right_shoulder = [x + y for x, y in zip(basic_info['body_point_test']['shoulder_right'], [d['adaptive_position'][1], d['adaptive_position'][0]])]
+                body_mask[:min(left_shoulder[1], right_shoulder[1]), left_shoulder[0]:right_shoulder[0]] = 255
+                _, binary_body_mask = cv2.threshold(body_mask, 127, 255, cv2.THRESH_BINARY)
+                top_outer_mask = np.array(binary_body_mask)
+                bottom_outer_mask = np.array(binary_body_mask)
+
+        top = True
+        bottom = True
+        i = len(data)
+        while i:
+            i -= 1
+            if top and data[i]['name'] in ["blouse_front", "outwear_front", "dress_front", "tops_front"]:
+                top = False
+                mask_shape = data[i]['mask'].shape
+                y_offset, x_offset = data[i]['adaptive_position']
+                # Initialise the start and end positions of the overlay region.
+                all_y_start, all_y_end, mask_y_start, mask_y_end = positioning(all_mask_shape=all_mask_shape[0], mask_shape=mask_shape[0], offset=y_offset)
+                all_x_start, all_x_end, mask_x_start, mask_x_end = positioning(all_mask_shape=all_mask_shape[1], mask_shape=mask_shape[1], offset=x_offset)
+                # Write the clipped garment mask into a canvas-sized background.
+                _, sketch_mask = cv2.threshold(data[i]['mask'], 127, 255, cv2.THRESH_BINARY)
+                background = np.zeros_like(top_outer_mask)
+                background[all_y_start:all_y_end, all_x_start:all_x_end] = sketch_mask[mask_y_start:mask_y_end, mask_x_start:mask_x_end]
+                top_outer_mask = background + top_outer_mask
+            elif bottom and data[i]['name'] in ["trousers_front", "skirt_front", "bottoms_front", "dress_front"]:
+                bottom = False
+                mask_shape = data[i]['mask'].shape
+                y_offset, x_offset = data[i]['adaptive_position']
+                # Initialise the start and end positions of the overlay region.
+                all_y_start, all_y_end, mask_y_start, mask_y_end = positioning(all_mask_shape=all_mask_shape[0], mask_shape=mask_shape[0], offset=y_offset)
+                all_x_start, all_x_end, mask_x_start, mask_x_end = positioning(all_mask_shape=all_mask_shape[1], mask_shape=mask_shape[1], offset=x_offset)
+                # Write the clipped garment mask into a canvas-sized background.
+                _, sketch_mask = cv2.threshold(data[i]['mask'], 127, 255, cv2.THRESH_BINARY)
+                background = np.zeros_like(top_outer_mask)
+                background[all_y_start:all_y_end, all_x_start:all_x_end] = sketch_mask[mask_y_start:mask_y_end, mask_x_start:mask_x_end]
+                bottom_outer_mask = background + bottom_outer_mask
+            elif bottom is False and top is False:
+                break
+
+        all_mask = cv2.bitwise_or(top_outer_mask, bottom_outer_mask)
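+
+        # Illustrative note (not part of the original patch): all_mask is the union of
+        # the body/mannequin silhouette and the outermost top and bottom garment fronts.
+        # Every non-body layer below is composited through this mask, so accessories and
+        # inner layers are clipped to the dressed silhouette instead of floating outside it.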
+
+        for layer in data:
+            if layer['image'] is not None:
+                if layer['name'] != "body":
+                    test_image = Image.new('RGBA', size, (0, 0, 0, 0))
+                    test_image.paste(layer['image'], (layer['adaptive_position'][1], layer['adaptive_position'][0]), layer['image'])
+                    mask_data = np.where(all_mask > 0, 255, 0).astype(np.uint8)
+                    mask_alpha = Image.fromarray(mask_data)
+                    cropped_image = Image.composite(test_image, Image.new("RGBA", test_image.size, (255, 255, 255, 0)), mask_alpha)
+                    base_image.paste(test_image, (0, 0), cropped_image)  # test_image is already placed at its coordinates on the full-size canvas, so paste at (0, 0)
+                else:
+                    base_image.paste(layer['image'], (layer['adaptive_position'][1], layer['adaptive_position'][0]), layer['image'])
+
+        result_image = base_image
+
+        image_data = io.BytesIO()
+        result_image.save(image_data, format='PNG')
+        image_data.seek(0)
+
+        # Upload to OSS.
+        image_bytes = image_data.read()
+        bucket_name = "aida-results"
+        object_name = f'result_{generate_uuid()}.png'
+        req = oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes)
+        return f"{bucket_name}/{object_name}"
+
+    except Exception as e:
+        logging.warning(f"synthesis runtime exception : {e}")
+
+
+def synthesis_single(front_image, back_image):
+    result_image = None
+    if front_image:
+        result_image = front_image
+    if back_image:
+        result_image.paste(back_image, (0, 0), back_image)
+
+    image_data = io.BytesIO()
+    result_image.save(image_data, format='PNG')
+    image_data.seek(0)
+    image_bytes = image_data.read()
+    # Upload to OSS.
+    bucket_name = 'aida-results'
+    object_name = f'result_{generate_uuid()}.png'
+    req = oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes)
+    return f"{bucket_name}/{object_name}"
+
+
+def update_base_size_priority(layers, size):
+    # Compute the width of the transparent background image.
+    min_x = min(info['position'][1] for info in layers)
+    x_list = []
+    new_height = 700
+    for info in layers:
+        if info['image'] is not None:
+            x_list.append(info['position'][1] + info['image'].width)
+            if info['name'] == 'mannequin':
+                new_height = info['image'].height
+    max_x = max(x_list)
+    new_width = max_x - min_x
+    # Update the coordinates.
+    for info in layers:
+        info['adaptive_position'] = (info['position'][0], info['position'][1] - min_x)
+    return layers, (new_width, new_height)
diff --git a/app/service/design_fast/utils/upload_image.py b/app/service/design_fast/utils/upload_image.py
new file mode 100644
index 0000000..2c79f9f
--- /dev/null
+++ b/app/service/design_fast/utils/upload_image.py
@@ -0,0 +1,39 @@
+#!/usr/bin/env python
+# -*- coding: UTF-8 -*-
+"""
+@Project    :trinity_client
+@File       :upload_image.py
+@Author     :周成融
+@Date       :2023/8/28 13:49:20
+@detail     :upload a PNG image (and optionally its mask) to OSS
+"""
+import io
+import logging
+
+import cv2
+
+from app.core.config import *
+from app.service.utils.new_oss_client import oss_upload_image
+
+
+def upload_png_mask(minio_client, front_image, object_name, mask=None):
+    try:
+        mask_url = None
+        if mask is not None:
+            mask_inverted = cv2.bitwise_not(mask)
+            # Convert the 3-channel mask to 4 channels: white stays opaque, black becomes transparent.
+            rgba_image = cv2.cvtColor(mask_inverted, cv2.COLOR_BGR2BGRA)
+            rgba_image[rgba_image[:, :, 0] == 0] = [0, 0, 0, 0]
+            req = oss_upload_image(oss_client=minio_client, bucket=AIDA_CLOTHING, object_name=f"mask/mask_{object_name}.png", image_bytes=cv2.imencode('.png', rgba_image)[1])
+            mask_url = f"{AIDA_CLOTHING}/mask/mask_{object_name}.png"
+
+        image_data = io.BytesIO()
+        front_image.save(image_data, format='PNG')
+        image_data.seek(0)
+        image_bytes = image_data.read()
+        req = oss_upload_image(oss_client=minio_client, bucket=AIDA_CLOTHING, object_name=f"image/image_{object_name}.png", image_bytes=image_bytes)
+        image_url = f"{AIDA_CLOTHING}/image/image_{object_name}.png"
+        return front_image, image_url, mask_url
+    except Exception as e:
+        logging.warning(f"upload_png_mask runtime exception : {e}")
diff --git a/app/service/design_pre_processing/service.py b/app/service/design_pre_processing/service.py
index dd2a543..16ca870 100644
--- a/app/service/design_pre_processing/service.py
+++ b/app/service/design_pre_processing/service.py
@@ -5,13 +5,16 @@ import cv2
 import numpy as np
 import torch
 import tritonclient.grpc as grpcclient
+from pymilvus import MilvusClient
 from urllib3.exceptions import ResponseError
 
 from app.core.config import *
 from app.schemas.pre_processing import DesignPreProcessingModel
-from app.service.design.utils.design_ensemble import get_keypoint_result
+from app.service.design_fast.utils.design_ensemble import get_seg_result, get_keypoint_result
 from app.service.utils.oss_client import oss_get_image, oss_upload_image
 
+logger = logging.getLogger()
+
 
 class DesignPreprocessing:
     # def __init__(self):
@@ -20,19 +23,19 @@ class DesignPreprocessing:
     # @ RunTime
     def pipeline(self, image_list):
         sketches_list = self.read_image(image_list)
-        logging.info("read image success")
+        # logging.info("read image success")
 
         bounding_box_sketches_list = self.bounding_box(sketches_list)
-        logging.info("bounding box image success")
+        # logging.info("bounding box image success")
 
         # super_resolution_list = self.super_resolution(bounding_box_sketches_list)
         # logging.info("super_resolution_list image success")
 
         infer_sketches_list = self.infer_image(bounding_box_sketches_list)
-        logging.info("infer image success")
+        # logging.info("infer image success")
 
         result = self.composing_image(infer_sketches_list)
-        logging.info("Replenish white edge image success")
+        # logging.info("Replenish white edge image success")
 
         for d in result:
             if 'image_obj' in d:
@@ -59,6 +62,7 @@ class DesignPreprocessing:
     def bounding_box(self, image_list):
         for item in image_list:
             image = item['image_obj']
+            height, width = image.shape[:2]
             # 使用Canny边缘检测来检测物体的轮廓
             edges = cv2.Canny(image, 50, 150)
             # 查找轮廓
@@ -82,16 +86,25 @@ class DesignPreprocessing:
             if len(contours) > 0:
                 cropped_image = image[y_min:y_max, x_min:x_max]
                 item['obj'] = cropped_image  # 新shape图像
-            # 取消直接覆盖,新增size判断
-            # try:
-            #     # 覆盖到minio
-            #     image_bytes = cv2.imencode(".jpg", cropped_image)[1].tobytes()
-            #     self.minio_client.put_object(item['image_url'].split("/", 1)[0], item['image_url'].split("/", 1)[1], io.BytesIO(image_bytes), len(image_bytes), content_type="image/jpeg", )
-            #     print(f"Object '{item['image_url'].split('/', 1)[1]}' overwritten successfully.")
-            # except ResponseError as err:
-            #     print(f"Error: {err}")
             else:
                 item['obj'] = image
+
+            padding_top = max(20 - y_min, 0)
+            padding_bottom = max(20 - (height - y_max), 0)
+            padding_left = max(20 - x_min, 0)
+            padding_right = max(20 - (width - x_max), 0)
+
+            # Add padding.
+            padded_image = cv2.copyMakeBorder(
+                image,
+                padding_top,
+                padding_bottom,
+                padding_left,
+                padding_right,
+                cv2.BORDER_CONSTANT,
+                value=(255, 255, 255)
+            )
+            item['obj'] = padded_image
         return image_list
 
     def super_resolution(self, image_list):
@@ -99,7 +112,7 @@ class DesignPreprocessing:
             # 判断 两边是否同时都小于512 因为此处做四倍超分
             if item['obj'].shape[0] <= 512 and item['obj'].shape[1] <= 512:
                 # 如果任意一边小于256则超分
-                if item['obj'].shape[0] <= 256 or item['obj'].shape[1] <= 256:
+                if item['obj'].shape[0] <= 200 or item['obj'].shape[1] <= 200:
                     # 超分
                     img = item['obj'].astype(np.float32) / 255.
                     sample = np.transpose(img if img.shape[2] == 1 else img[:, :, [2, 1, 0]], (2, 0, 1))
@@ -124,13 +137,14 @@ class DesignPreprocessing:
                 bucket_name = item['image_url'].split("/", 1)[0]
                 object_name = item['image_url'].split("/", 1)[1]
                 oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes)
-                print(f"Object '{item['image_url'].split('/', 1)[1]}' overwritten successfully.")
+                logging.info(f"Object '{item['image_url'].split('/', 1)[1]}' overwritten successfully.")
             except ResponseError as err:
-                print(f"Error: {err}")
+                logging.warning(f"Error: {err}")
         return image_list
 
     # @ RunTime
     def infer_image(self, image_list):
+        seg_result = None
         for sketch in image_list:
             # 小写
             image_category = sketch['image_category'].lower()
@@ -138,6 +152,15 @@ class DesignPreprocessing:
             sketch['site'] = 'up' if image_category in ['blouse', 'outwear', 'dress', 'tops'] else 'down'
             # 推理得到keypoint
             sketch['keypoint_result'] = self.keypoint_cache(sketch)
+            if sketch['site'] == 'up':
+                _, seg_cache = self.load_seg_result(sketch['image_id'])
+                if not _:
+                    # Run inference to get the segmentation result.
+                    seg_result = get_seg_result(sketch["image_id"], sketch['obj'])[0]
+                    self.save_seg_result(seg_result, sketch['image_id'])
+                    logger.info(f"{sketch['image_id']} image size is :{sketch['obj'].shape} , seg cache size is :{seg_result.shape}")
+                else:
+                    logger.info(f"{sketch['image_id']} image size is :{sketch['obj'].shape} , seg cache size is :{seg_cache.shape}")
 
             if IF_DEBUG_SHOW:
                 debug_show_image = sketch['obj'].copy()
@@ -149,6 +172,7 @@ class DesignPreprocessing:
                     points_list.append((int(i[1]), int(i[0])))
                 for point in points_list:
                     cv2.circle(debug_show_image, point, point_size, point_color, thickness)
+                cv2.imshow("seg_result", seg_result)
                 cv2.imshow("", debug_show_image)
                 cv2.waitKey(0)
             # # 关键点在上部则推理seg
@@ -236,58 +260,37 @@ class DesignPreprocessing:
         return image_list
 
     @staticmethod
-    def select_seg_result(image_id, image_obj):
+    def load_seg_result(image_id):
+        file_path = f"{SEG_CACHE_PATH}{image_id}.npy"
         try:
-            # 如果shape不匹配 返回false
-            result = np.load(f"seg_result/{image_id}.npy").astype(np.int64)
-            if result.shape[1] == image_obj.shape[0] and result.shape[2] == image_obj.shape[1]:
-                return result
-            else:
-                return False
-        except FileNotFoundError as e:
-            logging.warning(f"{image_id} Image segmentation results cache file does not exist : {e}")
-            return False
+            seg_result = np.load(file_path)
+            return True, seg_result
+        except FileNotFoundError:
+            logging.info("seg cache file does not exist")
+            return False, None
+        except Exception as e:
+            logging.warning(f"seg cache load failed: {e}")
+            return False, None
 
     @staticmethod
-    def search_seg_result(image_id, ori_shape):
+    def save_seg_result(seg_result, image_id):
+        file_path = f"{SEG_CACHE_PATH}{image_id}.npy"
         try:
-            # connections.connect(alias=MILVUS_ALIAS, host=MILVUS_DB_HOST, port=MILVUS_PORT)
-            # collection = Collection(MILVUS_TABLE_SEG)  # Get an existing collection.
-            # collection.load()
-            # start_time = time.time()
-            # res = collection.query(
-            #     expr=f"seg_id == {image_id}",
-            #     offset=0,
-            #     limit=10,
-            #     output_fields=["seg_cache"],
-            # )
-            # logging.info(f"search seg cache time : {time.time() - start_time}")
-
-            # if len(res):
-            #     vector = np.reshape(res[0]['seg_cache'] + res[1]['seg_cache'], (224, 224))
-            #     array_2d_exact = F.interpolate(torch.tensor(vector).unsqueeze(0).unsqueeze(0), size=ori_shape, mode='bilinear', align_corners=False)
-            #     array_2d_exact = array_2d_exact.squeeze().numpy()
-            #     return array_2d_exact
-            # else:
-            return False
+            np.save(file_path, seg_result)
+            logging.info(f"seg cache saved: {os.path.abspath(file_path)}")
         except Exception as e:
-            logging.warning(f"{image_id} Image segmentation results cache file does not exist : {e}")
-            return False
+            logging.warning(f"seg cache save failed: {e}")
 
     def keypoint_cache(self, sketch):
         try:
-            # connections.connect(alias=MILVUS_ALIAS, host=MILVUS_DB_HOST, port=MILVUS_PORT)
-            # collection = Collection(MILVUS_TABLE_KEYPOINT)  # Get an existing collection.
-            # collection.load()
-            start_time = time.time()
-            # res = collection.query(
-            #     expr=f"keypoint_id == {sketch['image_id']}",
-            #     offset=0,
-            #     limit=1,
-            #     output_fields=["keypoint_cache", "keypoint_site"],
-            # )
-            res = []
-            logging.info(f"search keypoint time : {time.time() - start_time}")
+            client = MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS)
+            keypoint_id = sketch['image_id']
+            res = client.query(
+                collection_name=MILVUS_TABLE_KEYPOINT,
+                # ids=[keypoint_id],
+                filter=f"keypoint_id == {keypoint_id}",
+                output_fields=['keypoint_vector', 'keypoint_site']
+            )
             if len(res) == 0:
                 # 没有结果 直接推理拿结果 并保存
                 keypoint_infer_result = self.infer_keypoint_result(sketch)
@@ -348,7 +351,7 @@ class DesignPreprocessing:
         ]
         try:
             # connections.connect(alias=MILVUS_ALIAS, host=MILVUS_DB_HOST, port=MILVUS_PORT)
-            start_time = time.time()
+            # start_time = time.time()
             # collection = Collection(MILVUS_TABLE_KEYPOINT)  # Get an existing collection.
             # mr = collection.upsert(data)
             # logging.info(f"save keypoint time : {time.time() - start_time}")
@@ -362,9 +365,9 @@ if __name__ == '__main__':
     data = {
         "sketches": [
             {
-                "image_category": "dress",
-                "image_id": "107903",
-                "image_url": "aida-sys-image/images/female/dress/0628000000.jpg"
+                "image_category": "blouse",
+                "image_id": "123123123",
+                "image_url": "test/0628000198.jpg"
             }
         ]
     }
diff --git a/app/service/image2sketch/checkpoints/download_checkpoints.py b/app/service/image2sketch/checkpoints/download_checkpoints.py
new file mode 100644
index 0000000..03cc2c6
--- /dev/null
+++ b/app/service/image2sketch/checkpoints/download_checkpoints.py
@@ -0,0 +1,45 @@
+import os
+
+from minio import Minio
+from minio.error import S3Error
+
+MINIO_URL = "www.minio.aida.com.hk:12024"
+MINIO_ACCESS = 'vXKFLSJkYeEq2DrSZvkB'
+MINIO_SECRET = 'uKTZT3x7C43WvPN9QTc99DiRkwddWZrG9Uh3JVlR'
+MINIO_SECURE = True
+# Configure the MinIO client.
+minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
+
+
+# Download helper.
+def download_folder(bucket_name, folder_name, local_dir):
+    try:
+        # Make sure the local directory exists.
+        if not os.path.exists(local_dir):
+            os.makedirs(local_dir)
+
+        # Iterate over the objects in MinIO.
+        objects = minio_client.list_objects(bucket_name, prefix=folder_name, recursive=True)
+        for obj in objects:
+            # Build the local file path.
+            local_file_path = os.path.join(local_dir, obj.object_name[len(folder_name):])
+            local_file_dir = os.path.dirname(local_file_path)
+
+            # Make sure the local directory exists.
+            if not os.path.exists(local_file_dir):
+                os.makedirs(local_file_dir)
+
+            # Download the file.
+            minio_client.fget_object(bucket_name, obj.object_name, local_file_path)
+            print(f"Downloaded {obj.object_name} to {local_file_path}")
+
+    except S3Error as e:
+        print(f"Error occurred: {e}")
+
+
+# Usage example.
+bucket_name = "test"  # replace with your bucket name
+folder_name = "checkpoints/"  # path of the checkpoints folder
+local_dir = "app/service/image2sketch/checkpoints"  # replace with the local directory to save to
+
+download_folder(bucket_name, folder_name, local_dir)
diff --git a/app/service/image2sketch/datasets/ref_unpair/testC/style_1.jpg b/app/service/image2sketch/datasets/ref_unpair/testC/style_1.jpg
new file mode 100644
index 0000000..3a66b7f
Binary files /dev/null and b/app/service/image2sketch/datasets/ref_unpair/testC/style_1.jpg differ
diff --git a/app/service/image2sketch/datasets/ref_unpair/testC/style_2.jpeg b/app/service/image2sketch/datasets/ref_unpair/testC/style_2.jpeg
new file mode 100644
index 0000000..0347322
Binary files /dev/null and b/app/service/image2sketch/datasets/ref_unpair/testC/style_2.jpeg differ
diff --git a/app/service/image2sketch/datasets/ref_unpair/testC/style_3.png b/app/service/image2sketch/datasets/ref_unpair/testC/style_3.png
new file mode 100644
index 0000000..8d8bcf4
Binary files /dev/null and b/app/service/image2sketch/datasets/ref_unpair/testC/style_3.png differ
np.tile(image_numpy, (3, 1, 1)) + image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 # post-processing: tranpose and scaling + else: # if it is a numpy array, do nothing + image_numpy = input_image + return image_numpy.astype(imtype) + + +def save_image(image_numpy, image_path, w, h, aspect_ratio=1.0): + """Save a numpy image to the disk + + Parameters: + image_numpy (numpy array) -- input numpy array + image_path (str) -- the path of the image + """ + + image_pil = Image.fromarray(image_numpy) + image_pil = image_pil.resize((w, h)) + image_pil.save(image_path) + + +def save_img(image_tensor, w, h, filename): + image_pil = tensor2im(image_tensor) + + save_image(image_pil, filename, w, h, aspect_ratio=1.0) + print("Image saved as {}".format(filename)) + + +def load_img(filepath): + img = Image.open(filepath).convert('L') + # print(img.size) + width = img.size[0] + height = img.size[1] + # img = img.resize((512, 512), Image.BICUBIC) + return img, width, height + + +if __name__ == '__main__': + img_A = "/workspace/Semi_ref2sketch_code/datasets/ref_unpair/testA/real_Dress_732caedc416a0cbfedd0e6528040eac7.jpg_Img.jpg" + img_B = "/workspace/Semi_ref2sketch_code/datasets/ref_unpair/testC/style_3.png" + from opt import Config + + opt = Config() # get test options + # hard-code some parameters for test + opt.num_threads = 0 # test code only supports num_threads = 0 + opt.batch_size = 1 # test code only supports batch_size = 1 + opt.serial_batches = True # disable data shuffling; comment this line if results on randomly chosen images are needed. + opt.no_flip = True # no flip; comment this line if results on flipped images are needed. + opt.display_id = -1 # no visdom display; the test code saves the results to a HTML file. + device = torch.device("cuda:0") + model = create_model(opt) # create a model given opt.model and other options + model.setup(opt) + transform_list = [transforms.ToTensor(), transforms.Normalize([0.5], [0.5])] + transform = transforms.Compose(transform_list) + if opt.eval: + model.eval() + data = {} + print(os.getcwd()) + B = reference, _, _ = load_img(r"/app/service/image2sketch/datasets/ref_unpair/testC/style_3.png") + style_img = transform(reference) + data['B'] = style_img + data['B'] = data['B'].unsqueeze(0).to(device) + A = Image.open(r"E:\workspace\trinity_client_aida\app\service\image2sketch\datasets\ref_unpair\testA\real_Dress_3200fecdc83d0c556c2bd96aedbd7fbf.jpg_Img.jpg") + width = A.size[0] + height = A.size[1] + # data['A'] = A.resize((512, 512)) + data['A'] = transform(A) + data['A'] = data['A'].unsqueeze(0).to(device) + model.set_input(data) + model.test() # run inference + visuals = model.get_current_visuals() # get image results + save_img(visuals['content_output'].cpu(), width, height, "result/result.jpg") diff --git a/app/service/image2sketch/models/__init__.py b/app/service/image2sketch/models/__init__.py new file mode 100644 index 0000000..809105c --- /dev/null +++ b/app/service/image2sketch/models/__init__.py @@ -0,0 +1,49 @@ +import importlib + +from app.service.image2sketch.models import unpaired_model as modellib +from .base_model import BaseModel + + +def find_model_using_name(model_name): + """Import the module "models/[model_name]_model.py". + + In the file, the class called DatasetNameModel() will + be instantiated. It has to be a subclass of BaseModel, + and it is case-insensitive. + """ + # model_filename = "." 
diff --git a/app/service/image2sketch/models/__init__.py b/app/service/image2sketch/models/__init__.py
new file mode 100644
index 0000000..809105c
--- /dev/null
+++ b/app/service/image2sketch/models/__init__.py
@@ -0,0 +1,49 @@
+import importlib
+
+from app.service.image2sketch.models import unpaired_model as modellib
+from .base_model import BaseModel
+
+
+def find_model_using_name(model_name):
+    """Import the module "models/[model_name]_model.py".
+
+    In the file, the class called DatasetNameModel() will
+    be instantiated. It has to be a subclass of BaseModel,
+    and it is case-insensitive.
+    """
+    # model_filename = "." + model_name + "_model"
+    # modellib = importlib.import_module(model_filename)
+    model = None
+    target_model_name = model_name.replace('_', '') + 'model'
+    for name, cls in modellib.__dict__.items():
+        if name.lower() == target_model_name.lower() \
+                and issubclass(cls, BaseModel):
+            model = cls
+
+    if model is None:
+        print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (modellib.__name__, target_model_name))
+        exit(0)
+
+    return model
+
+
+def get_option_setter(model_name):
+    """Return the static method <modify_commandline_options> of the model class."""
+    model_class = find_model_using_name(model_name)
+    return model_class.modify_commandline_options
+
+
+def create_model(opt):
+    """Create a model given the option.
+
+    This function wraps the class CustomDatasetDataLoader.
+    This is the main interface between this package and 'train.py'/'test.py'
+
+    Example:
+        >>> from .models import create_model
+        >>> model = create_model(opt)
+    """
+    model = find_model_using_name(opt.model)
+    instance = model(opt)
+    print("model [%s] was created" % type(instance).__name__)
+    return instance
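+
+
+# Name resolution example (illustrative): find_model_using_name('unpaired')
+# looks for a class whose lowercased name equals 'unpairedmodel'
+# ('unpaired'.replace('_', '') + 'model'), which is how opt.model = 'unpaired'
+# selects UnpairedModel from unpaired_model.py.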
diff --git a/app/service/image2sketch/models/base_model.py b/app/service/image2sketch/models/base_model.py
new file mode 100644
index 0000000..6de961b
--- /dev/null
+++ b/app/service/image2sketch/models/base_model.py
@@ -0,0 +1,230 @@
+import os
+import torch
+from collections import OrderedDict
+from abc import ABC, abstractmethod
+from . import networks
+
+
+class BaseModel(ABC):
+    """This class is an abstract base class (ABC) for models.
+    To create a subclass, you need to implement the following five functions:
+        -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
+        -- <set_input>: unpack data from dataset and apply preprocessing.
+        -- <forward>: produce intermediate results.
+        -- <optimize_parameters>: calculate losses, gradients, and update network weights.
+        -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
+    """
+
+    def __init__(self, opt):
+        """Initialize the BaseModel class.
+
+        Parameters:
+            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
+
+        When creating your custom class, you need to implement your own initialization.
+        In this function, you should first call <BaseModel.__init__(self, opt)>
+        Then, you need to define four lists:
+            -- self.loss_names (str list): specify the training losses that you want to plot and save.
+            -- self.model_names (str list): define networks used in our training.
+            -- self.visual_names (str list): specify the images that you want to display and save.
+            -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
+        """
+        self.opt = opt
+        self.gpu_ids = opt.gpu_ids
+        self.isTrain = opt.isTrain
+        self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')  # get device name: CPU or GPU
+        self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)  # save all the checkpoints to save_dir
+        if opt.preprocess != 'scale_width':  # with [scale_width], input images might have different sizes, which hurts the performance of cudnn.benchmark.
+            torch.backends.cudnn.benchmark = True
+        self.loss_names = []
+        self.model_names = []
+        self.visual_names = []
+        self.optimizers = []
+        self.image_paths = []
+        self.metric = 0  # used for learning rate policy 'plateau'
+
+    @staticmethod
+    def modify_commandline_options(parser, is_train):
+        """Add new model-specific options, and rewrite default values for existing options.
+
+        Parameters:
+            parser          -- original option parser
+            is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
+
+        Returns:
+            the modified parser.
+        """
+        return parser
+
+    @abstractmethod
+    def set_input(self, input):
+        """Unpack input data from the dataloader and perform necessary pre-processing steps.
+
+        Parameters:
+            input (dict): includes the data itself and its metadata information.
+        """
+        pass
+
+    @abstractmethod
+    def forward(self):
+        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
+        pass
+
+    @abstractmethod
+    def optimize_parameters(self):
+        """Calculate losses, gradients, and update network weights; called in every training iteration"""
+        pass
+
+    def setup(self, opt):
+        """Load and print networks; create schedulers
+
+        Parameters:
+            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
+        """
+        if self.isTrain:
+            self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
+        if not self.isTrain or opt.continue_train:
+            load_suffix = 'iter_%d' % opt.load_iter if opt.load_iter > 0 else opt.epoch
+            self.load_networks(load_suffix)
+        self.print_networks(opt.verbose)
+
+    def eval(self):
+        """Make models eval mode during test time"""
+        for name in self.model_names:
+            if isinstance(name, str):
+                net = getattr(self, 'net' + name)
+                net.eval()
+
+    def test(self):
+        """Forward function used in test time.
+
+        This function wraps <forward> function in no_grad() so we don't save intermediate steps for backprop
+        It also calls <compute_visuals> to produce additional visualization results
+        """
+        with torch.no_grad():
+            self.forward()
+            self.compute_visuals()
+
+    def compute_visuals(self):
+        """Calculate additional output images for visdom and HTML visualization"""
+        pass
+
+    def get_image_paths(self):
+        """Return image paths that are used to load current data"""
+        return self.image_paths
+
+    def update_learning_rate(self):
+        """Update learning rates for all the networks; called at the end of every epoch"""
+        old_lr = self.optimizers[0].param_groups[0]['lr']
+        for scheduler in self.schedulers:
+            if self.opt.lr_policy == 'plateau':
+                scheduler.step(self.metric)
+            else:
+                scheduler.step()
+
+        lr = self.optimizers[0].param_groups[0]['lr']
+        print('learning rate %.7f -> %.7f' % (old_lr, lr))
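+
+    # Typical call pattern (an illustrative sketch assuming the usual
+    # CycleGAN-style training loop, which is not part of this diff):
+    #   for epoch in range(opt.epoch_count, opt.n_epochs + opt.n_epochs_decay + 1):
+    #       for data in dataset:
+    #           model.set_input(data)
+    #           model.optimize_parameters()
+    #       model.update_learning_rate()  # one scheduler step per epoch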
+
+    def get_current_visuals(self):
+        """Return visualization images. train.py will display these images with visdom, and save the images to a HTML"""
+        visual_ret = OrderedDict()
+        for name in self.visual_names:
+            if isinstance(name, str):
+                visual_ret[name] = getattr(self, name)
+        return visual_ret
+
+    def get_current_losses(self):
+        """Return training losses / errors. train.py will print out these errors on console, and save them to a file"""
+        errors_ret = OrderedDict()
+        for name in self.loss_names:
+            if isinstance(name, str):
+                errors_ret[name] = float(getattr(self, 'loss_' + name))  # float(...) works for both scalar tensor and float number
+        return errors_ret
+
+    def save_networks(self, epoch):
+        """Save all the networks to the disk.
+
+        Parameters:
+            epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
+        """
+        for name in self.model_names:
+            if isinstance(name, str):
+                save_filename = '%s_net_%s.pth' % (epoch, name)
+                save_path = os.path.join(self.save_dir, save_filename)
+                net = getattr(self, 'net' + name)
+
+                if len(self.gpu_ids) > 0 and torch.cuda.is_available():
+                    torch.save(net.module.cpu().state_dict(), save_path)
+                    net.cuda(self.gpu_ids[0])
+                else:
+                    torch.save(net.cpu().state_dict(), save_path)
+
+    def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
+        """Fix InstanceNorm checkpoints incompatibility (prior to 0.4)"""
+        key = keys[i]
+        if i + 1 == len(keys):  # at the end, pointing to a parameter/buffer
+            if module.__class__.__name__.startswith('InstanceNorm') and \
+                    (key == 'running_mean' or key == 'running_var'):
+                if getattr(module, key) is None:
+                    state_dict.pop('.'.join(keys))
+            if module.__class__.__name__.startswith('InstanceNorm') and \
+                    (key == 'num_batches_tracked'):
+                state_dict.pop('.'.join(keys))
+        else:
+            self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)
+
+    def load_networks(self, epoch):
+        """Load all the networks from the disk.
+
+        Parameters:
+            epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
+        """
+        for name in self.model_names:
+            if isinstance(name, str):
+                load_filename = '%s_net_%s.pth' % (epoch, name)
+                load_path = os.path.join(self.save_dir, load_filename)
+                net = getattr(self, 'net' + name)
+                if isinstance(net, torch.nn.DataParallel):
+                    net = net.module
+                print('loading the model from %s' % load_path)
+                # if you are using PyTorch newer than 0.4 (e.g., built from
+                # GitHub source), you can remove str() on self.device
+                state_dict = torch.load(load_path, map_location=str(self.device))
+                if hasattr(state_dict, '_metadata'):
+                    del state_dict._metadata
+
+                # patch InstanceNorm checkpoints prior to 0.4
+                for key in list(state_dict.keys()):  # need to copy keys here because we mutate in loop
+                    self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
+                net.load_state_dict(state_dict)
+
+    def print_networks(self, verbose):
+        """Print the total number of parameters in the network and (if verbose) network architecture
+
+        Parameters:
+            verbose (bool) -- if verbose: print the network architecture
+        """
+        print('---------- Networks initialized -------------')
+        for name in self.model_names:
+            if isinstance(name, str):
+                net = getattr(self, 'net' + name)
+                num_params = 0
+                for param in net.parameters():
+                    num_params += param.numel()
+                if verbose:
+                    print(net)
+                print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
+        print('-----------------------------------------------')
+
+    def set_requires_grad(self, nets, requires_grad=False):
+        """Set requires_grad=False for all the networks to avoid unnecessary computations
+        Parameters:
+            nets (network list)   -- a list of networks
+            requires_grad (bool)  -- whether the networks require gradients or not
+        """
+        if not isinstance(nets, list):
+            nets = [nets]
+        for net in nets:
+            if net is not None:
+                for param in net.parameters():
+                    param.requires_grad = requires_grad
diff --git a/app/service/image2sketch/models/layer.py b/app/service/image2sketch/models/layer.py
new file mode 100644
index 0000000..df96a35
--- /dev/null
+++ b/app/service/image2sketch/models/layer.py
@@ -0,0 +1,354 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class CNR2d(nn.Module):
+    def 
__init__(self, nch_in, nch_out, kernel_size=4, stride=1, padding=1, norm='bnorm', relu=0.0, drop=[], bias=[]): + super().__init__() + + if bias == []: + if norm == 'bnorm': + bias = False + else: + bias = True + + layers = [] + layers += [Conv2d(nch_in, nch_out, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)] + + if norm != []: + layers += [Norm2d(nch_out, norm)] + + if relu != []: + layers += [ReLU(relu)] + + if drop != []: + layers += [nn.Dropout2d(drop)] + + self.cbr = nn.Sequential(*layers) + + def forward(self, x): + return self.cbr(x) + + +class DECNR2d(nn.Module): + def __init__(self, nch_in, nch_out, kernel_size=4, stride=1, padding=1, output_padding=0, norm='bnorm', relu=0.0, drop=[], bias=[]): + super().__init__() + + if bias == []: + if norm == 'bnorm': + bias = False + else: + bias = True + + layers = [] + layers += [Deconv2d(nch_in, nch_out, kernel_size=kernel_size, stride=stride, padding=padding, output_padding=output_padding, bias=bias)] + + if norm != []: + layers += [Norm2d(nch_out, norm)] + + if relu != []: + layers += [ReLU(relu)] + + if drop != []: + layers += [nn.Dropout2d(drop)] + + self.decbr = nn.Sequential(*layers) + + def forward(self, x): + return self.decbr(x) + + +class ResBlock(nn.Module): + def __init__(self, nch_in, nch_out, kernel_size=3, stride=1, padding=1, padding_mode='reflection', norm='inorm', relu=0.0, drop=[], bias=[]): + super().__init__() + + if bias == []: + if norm == 'bnorm': + bias = False + else: + bias = True + + layers = [] + + # 1st conv + layers += [Padding(padding, padding_mode=padding_mode)] + layers += [CNR2d(nch_in, nch_out, kernel_size=kernel_size, stride=stride, padding=0, norm=norm, relu=relu)] + + if drop != []: + layers += [nn.Dropout2d(drop)] + + # 2nd conv + layers += [Padding(padding, padding_mode=padding_mode)] + layers += [CNR2d(nch_in, nch_out, kernel_size=kernel_size, stride=stride, padding=0, norm=norm, relu=[])] + + self.resblk = nn.Sequential(*layers) + + def forward(self, x): + return x + self.resblk(x) + + +class ResBlock_cat(nn.Module): + def __init__(self, nch_in, nch_out, kernel_size=3, stride=1, padding=1, padding_mode='reflection', norm='inorm', relu=0.0, drop=[], bias=[]): + super().__init__() + + if bias == []: + if norm == 'bnorm': + bias = False + else: + bias = True + + layers = [] + + # 1st conv + layers += [Padding(padding, padding_mode=padding_mode)] + layers += [CNR2d(nch_in*2, nch_out, kernel_size=kernel_size, stride=stride, padding=0, norm=norm, relu=relu)] + + if drop != []: + layers += [nn.Dropout2d(drop)] + + # 2nd conv + layers += [Padding(padding, padding_mode=padding_mode)] + layers += [CNR2d(nch_in, nch_out, kernel_size=kernel_size, stride=stride, padding=0, norm=norm, relu=[])] + + self.resblk = nn.Sequential(*layers) + + def forward(self,x,y): + output = x + self.resblk(torch.cat([x,y],dim=1)) + return output + +class LinearBlock(nn.Module): + def __init__(self, input_dim, output_dim, norm='none', activation='relu'): + super(LinearBlock, self).__init__() + use_bias = True + # initialize fully connected layer + if norm == 'sn': + self.fc = SpectralNorm(nn.Linear(input_dim, output_dim, bias=use_bias)) + else: + self.fc = nn.Linear(input_dim, output_dim, bias=use_bias) + + # initialize normalization + norm_dim = output_dim + if norm == 'bn': + self.norm = nn.BatchNorm1d(norm_dim) + elif norm == 'in': + self.norm = nn.InstanceNorm1d(norm_dim) + elif norm == 'ln': + self.norm = LayerNorm(norm_dim) + elif norm == 'none' or norm == 'sn': + self.norm = None + else: + assert 0, 
"Unsupported normalization: {}".format(norm) + + # initialize activation + if activation == 'relu': + self.activation = nn.ReLU(inplace=True) + elif activation == 'lrelu': + self.activation = nn.LeakyReLU(0.2, inplace=True) + elif activation == 'prelu': + self.activation = nn.PReLU() + elif activation == 'selu': + self.activation = nn.SELU(inplace=True) + elif activation == 'tanh': + self.activation = nn.Tanh() + elif activation == 'none': + self.activation = None + else: + assert 0, "Unsupported activation: {}".format(activation) + + def forward(self, x): + out = self.fc(x) + if self.norm: + out = self.norm(out) + if self.activation: + out = self.activation(out) + return out + +class MLP(nn.Module): + def __init__(self, input_dim, output_dim, dim, n_blk, norm='none', activ='relu'): + + super(MLP, self).__init__() + self.model = [] + self.model += [LinearBlock(input_dim, dim, norm=norm, activation=activ)] + for i in range(n_blk - 2): + self.model += [LinearBlock(dim, dim, norm=norm, activation=activ)] + self.model += [LinearBlock(dim, output_dim, norm='none', activation='none')] # no output activations + self.model = nn.Sequential(*self.model) + + def forward(self, x): + return self.model(x.view(x.size(0), -1)) + +class CNR1d(nn.Module): + def __init__(self, nch_in, nch_out, norm='bnorm', relu=0.0, drop=[]): + super().__init__() + + if norm == 'bnorm': + bias = False + else: + bias = True + + layers = [] + layers += [nn.Linear(nch_in, nch_out, bias=bias)] + + if norm != []: + layers += [Norm2d(nch_out, norm)] + + if relu != []: + layers += [ReLU(relu)] + + if drop != []: + layers += [nn.Dropout2d(drop)] + + self.cbr = nn.Sequential(*layers) + + def forward(self, x): + return self.cbr(x) + + +class Conv2d(nn.Module): + def __init__(self, nch_in, nch_out, kernel_size=4, stride=1, padding=1, bias=True): + super(Conv2d, self).__init__() + self.conv = nn.Conv2d(nch_in, nch_out, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias) + + def forward(self, x): + return self.conv(x) + + +class Deconv2d(nn.Module): + def __init__(self, nch_in, nch_out, kernel_size=4, stride=1, padding=1, output_padding=0, bias=True): + super(Deconv2d, self).__init__() + self.deconv = nn.ConvTranspose2d(nch_in, nch_out, kernel_size=kernel_size, stride=stride, padding=padding, output_padding=output_padding, bias=bias) + + # layers = [nn.Upsample(scale_factor=2, mode='bilinear'), + # nn.ReflectionPad2d(1), + # nn.Conv2d(nch_in , nch_out, kernel_size=3, stride=1, padding=0)] + # + # self.deconv = nn.Sequential(*layers) + + def forward(self, x): + return self.deconv(x) + + +class Linear(nn.Module): + def __init__(self, nch_in, nch_out): + super(Linear, self).__init__() + self.linear = nn.Linear(nch_in, nch_out) + + def forward(self, x): + return self.linear(x) + + +class Norm2d(nn.Module): + def __init__(self, nch, norm_mode): + super(Norm2d, self).__init__() + if norm_mode == 'bnorm': + self.norm = nn.BatchNorm2d(nch) + elif norm_mode == 'inorm': + self.norm = nn.InstanceNorm2d(nch) + + def forward(self, x): + return self.norm(x) + + +class ReLU(nn.Module): + def __init__(self, relu): + super(ReLU, self).__init__() + if relu > 0: + self.relu = nn.LeakyReLU(relu, True) + elif relu == 0: + self.relu = nn.ReLU(True) + + def forward(self, x): + return self.relu(x) + + +class Padding(nn.Module): + def __init__(self, padding, padding_mode='zeros', value=0): + super(Padding, self).__init__() + if padding_mode == 'reflection': + self. 
padding = nn.ReflectionPad2d(padding)
+        elif padding_mode == 'replication':
+            self.padding = nn.ReplicationPad2d(padding)
+        elif padding_mode == 'constant':
+            self.padding = nn.ConstantPad2d(padding, value)
+        elif padding_mode == 'zeros':
+            self.padding = nn.ZeroPad2d(padding)
+
+    def forward(self, x):
+        return self.padding(x)
+
+
+class Pooling2d(nn.Module):
+    def __init__(self, nch=[], pool=2, type='avg'):
+        super().__init__()
+
+        if type == 'avg':
+            self.pooling = nn.AvgPool2d(pool)
+        elif type == 'max':
+            self.pooling = nn.MaxPool2d(pool)
+        elif type == 'conv':
+            self.pooling = nn.Conv2d(nch, nch, kernel_size=pool, stride=pool)
+
+    def forward(self, x):
+        return self.pooling(x)
+
+
+class UnPooling2d(nn.Module):
+    def __init__(self, nch=[], pool=2, type='nearest'):
+        super().__init__()
+
+        if type == 'nearest':
+            self.unpooling = nn.Upsample(scale_factor=pool, mode='nearest')  # align_corners is only valid for interpolating modes, not 'nearest'
+        elif type == 'bilinear':
+            self.unpooling = nn.Upsample(scale_factor=pool, mode='bilinear', align_corners=True)
+        elif type == 'conv':
+            self.unpooling = nn.ConvTranspose2d(nch, nch, kernel_size=pool, stride=pool)
+
+    def forward(self, x):
+        return self.unpooling(x)
+
+
+class Concat(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x1, x2):
+        diffy = x2.size()[2] - x1.size()[2]
+        diffx = x2.size()[3] - x1.size()[3]
+
+        x1 = F.pad(x1, [diffx // 2, diffx - diffx // 2,
+                        diffy // 2, diffy - diffy // 2])
+
+        return torch.cat([x2, x1], dim=1)
+
+
+class TV1dLoss(nn.Module):
+    def __init__(self):
+        super(TV1dLoss, self).__init__()
+
+    def forward(self, input):
+        # loss = torch.mean(torch.abs(input[:, :, :, :-1] - input[:, :, :, 1:])) + \
+        #        torch.mean(torch.abs(input[:, :, :-1, :] - input[:, :, 1:, :]))
+        loss = torch.mean(torch.abs(input[:, :-1] - input[:, 1:]))
+
+        return loss
+
+
+class TV2dLoss(nn.Module):
+    def __init__(self):
+        super(TV2dLoss, self).__init__()
+
+    def forward(self, input):
+        loss = torch.mean(torch.abs(input[:, :, :, :-1] - input[:, :, :, 1:])) + \
+               torch.mean(torch.abs(input[:, :, :-1, :] - input[:, :, 1:, :]))
+        return loss
+
+
+class SSIM2dLoss(nn.Module):
+    def __init__(self):
+        super(SSIM2dLoss, self).__init__()
+
+    def forward(self, input, target):
+        # placeholder: SSIM is not implemented yet, so this loss is always zero
+        loss = 0
+        return loss
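+
+
+# Usage sketch (illustrative): the total-variation losses above act as
+# smoothness regularisers on generated images, penalising absolute
+# differences between neighbouring pixels, e.g.
+#   tv_loss = TV2dLoss()(fake_sketch)  # fake_sketch: an (N, C, H, W) tensor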
+ """ + if norm_type == 'batch': + norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True) + elif norm_type == 'instance': + norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False) + elif norm_type == 'none': + def norm_layer(x): + return Identity() + else: + raise NotImplementedError('normalization layer [%s] is not found' % norm_type) + return norm_layer + + +def get_scheduler(optimizer, opt): + """Return a learning rate scheduler + + Parameters: + optimizer -- the optimizer of the network + opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.  + opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine + + For 'linear', we keep the same learning rate for the first epochs + and linearly decay the rate to zero over the next epochs. + For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers. + See https://pytorch.org/docs/stable/optim.html for more details. + """ + if opt.lr_policy == 'linear': + def lambda_rule(epoch): + lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.n_epochs) / float(opt.n_epochs_decay + 1) + return lr_l + + scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) + elif opt.lr_policy == 'step': + scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1) + elif opt.lr_policy == 'plateau': + scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5) + elif opt.lr_policy == 'cosine': + scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.n_epochs, eta_min=0) + else: + return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy) + return scheduler + + +def init_weights(net, init_type='normal', init_gain=0.02): + """Initialize network weights. + + Parameters: + net (network) -- network to be initialized + init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal + init_gain (float) -- scaling factor for normal, xavier and orthogonal. + + We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might + work better for some applications. Feel free to try yourself. + """ + + def init_func(m): # define the initialization function + classname = m.__class__.__name__ + if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): + if init_type == 'normal': + init.normal_(m.weight.data, 0.0, init_gain) + elif init_type == 'xavier': + init.xavier_normal_(m.weight.data, gain=init_gain) + elif init_type == 'kaiming': + init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') + elif init_type == 'orthogonal': + init.orthogonal_(m.weight.data, gain=init_gain) + else: + raise NotImplementedError('initialization method [%s] is not implemented' % init_type) + if hasattr(m, 'bias') and m.bias is not None: + init.constant_(m.bias.data, 0.0) + elif classname.find('BatchNorm2d') != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies. + init.normal_(m.weight.data, 1.0, init_gain) + init.constant_(m.bias.data, 0.0) + + print('initialize network with %s' % init_type) + net.apply(init_func) # apply the initialization function + + +def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]): + """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. 
initialize the network weights + Parameters: + net (network) -- the network to be initialized + init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal + gain (float) -- scaling factor for normal, xavier and orthogonal. + gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2 + + Return an initialized network. + """ + if len(gpu_ids) > 0: + assert (torch.cuda.is_available()) + net.to(gpu_ids[0]) + net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs + init_weights(net, init_type, init_gain=init_gain) + return net + + +def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02, gpu_ids=[]): + net = None + norm_layer = get_norm_layer(norm_type=norm) + + if netG == 'ref_unpair_cbam_cat': + net = ref_unpair(input_nc, output_nc, ngf, norm='inorm', status='ref_unpair_cbam_cat') + elif netG == 'ref_unpair_recon': + net = ref_unpair(input_nc, output_nc, ngf, norm='inorm', status='ref_unpair_recon') + elif netG == 'triplet': + net = triplet(input_nc, output_nc, ngf, norm='inorm') + + else: + raise NotImplementedError('Generator model name [%s] is not recognized' % netG) + return init_net(net, init_type, init_gain, gpu_ids) + + +class AdaIN(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + eps = 1e-5 + mean_x = torch.mean(x, dim=[2, 3]) + mean_y = torch.mean(y, dim=[2, 3]) + + std_x = torch.std(x, dim=[2, 3]) + std_y = torch.std(y, dim=[2, 3]) + + mean_x = mean_x.unsqueeze(-1).unsqueeze(-1) + mean_y = mean_y.unsqueeze(-1).unsqueeze(-1) + + std_x = std_x.unsqueeze(-1).unsqueeze(-1) + eps + std_y = std_y.unsqueeze(-1).unsqueeze(-1) + eps + + out = (x - mean_x) / std_x * std_y + mean_y + + return out + + +class HED(nn.Module): + def __init__(self): + super(HED, self).__init__() + + self.moduleVggOne = nn.Sequential( + nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1), + nn.ReLU(inplace=False), + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1), + nn.ReLU(inplace=False) + ) + + self.moduleVggTwo = nn.Sequential( + nn.MaxPool2d(kernel_size=2, stride=2), + nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1), + nn.ReLU(inplace=False), + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1), + nn.ReLU(inplace=False) + ) + + self.moduleVggThr = nn.Sequential( + nn.MaxPool2d(kernel_size=2, stride=2), + nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1), + nn.ReLU(inplace=False), + nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1), + nn.ReLU(inplace=False), + nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1), + nn.ReLU(inplace=False) + ) + + self.moduleVggFou = nn.Sequential( + nn.MaxPool2d(kernel_size=2, stride=2), + nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1), + nn.ReLU(inplace=False), + nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1), + nn.ReLU(inplace=False), + nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1), + nn.ReLU(inplace=False) + ) + + self.moduleVggFiv = nn.Sequential( + nn.MaxPool2d(kernel_size=2, stride=2), + nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1), + nn.ReLU(inplace=False), + nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1), + nn.ReLU(inplace=False), + nn.Conv2d(in_channels=512, 
out_channels=512, kernel_size=3, stride=1, padding=1), + nn.ReLU(inplace=False) + ) + + self.moduleScoreOne = nn.Conv2d(in_channels=64, out_channels=1, kernel_size=1, stride=1, padding=0) + self.moduleScoreTwo = nn.Conv2d(in_channels=128, out_channels=1, kernel_size=1, stride=1, padding=0) + self.moduleScoreThr = nn.Conv2d(in_channels=256, out_channels=1, kernel_size=1, stride=1, padding=0) + self.moduleScoreFou = nn.Conv2d(in_channels=512, out_channels=1, kernel_size=1, stride=1, padding=0) + self.moduleScoreFiv = nn.Conv2d(in_channels=512, out_channels=1, kernel_size=1, stride=1, padding=0) + + self.moduleCombine = nn.Sequential( + nn.Conv2d(in_channels=5, out_channels=1, kernel_size=1, stride=1, padding=0), + nn.Sigmoid() + ) + + def forward(self, tensorInput): + tensorBlue = (tensorInput[:, 2:3, :, :] * 255.0) - 104.00698793 + tensorGreen = (tensorInput[:, 1:2, :, :] * 255.0) - 116.66876762 + tensorRed = (tensorInput[:, 0:1, :, :] * 255.0) - 122.67891434 + tensorInput = torch.cat([tensorBlue, tensorGreen, tensorRed], 1) + + tensorVggOne = self.moduleVggOne(tensorInput) + tensorVggTwo = self.moduleVggTwo(tensorVggOne) + tensorVggThr = self.moduleVggThr(tensorVggTwo) + tensorVggFou = self.moduleVggFou(tensorVggThr) + tensorVggFiv = self.moduleVggFiv(tensorVggFou) + + tensorScoreOne = self.moduleScoreOne(tensorVggOne) + tensorScoreTwo = self.moduleScoreTwo(tensorVggTwo) + tensorScoreThr = self.moduleScoreThr(tensorVggThr) + tensorScoreFou = self.moduleScoreFou(tensorVggFou) + tensorScoreFiv = self.moduleScoreFiv(tensorVggFiv) + + tensorScoreOne = nn.functional.interpolate(input=tensorScoreOne, size=(tensorInput.size(2), tensorInput.size(3)), mode='bilinear', align_corners=False) + tensorScoreTwo = nn.functional.interpolate(input=tensorScoreTwo, size=(tensorInput.size(2), tensorInput.size(3)), mode='bilinear', align_corners=False) + tensorScoreThr = nn.functional.interpolate(input=tensorScoreThr, size=(tensorInput.size(2), tensorInput.size(3)), mode='bilinear', align_corners=False) + tensorScoreFou = nn.functional.interpolate(input=tensorScoreFou, size=(tensorInput.size(2), tensorInput.size(3)), mode='bilinear', align_corners=False) + tensorScoreFiv = nn.functional.interpolate(input=tensorScoreFiv, size=(tensorInput.size(2), tensorInput.size(3)), mode='bilinear', align_corners=False) + + return self.moduleCombine(torch.cat([tensorScoreOne, tensorScoreTwo, tensorScoreThr, tensorScoreFou, tensorScoreFiv], 1)) + # return self.moduleCombine(torch.cat([ tensorScoreOne, tensorScoreTwo, tensorScoreThr, tensorScoreOne, tensorScoreTwo ], 1)) + + # return torch.sigmoid(tensorScoreOne),torch.sigmoid(tensorScoreTwo),torch.sigmoid(tensorScoreThr),torch.sigmoid(tensorScoreFou),torch.sigmoid(tensorScoreFiv),self.moduleCombine(torch.cat([ tensorScoreOne, tensorScoreTwo, tensorScoreThr, tensorScoreFou, tensorScoreFiv ], 1)) + # return torch.sigmoid(tensorScoreTwo) + + +def define_HED(init_weights_, gpu_ids_=[]): + net = HED() + + if len(gpu_ids_) > 0: + assert (torch.cuda.is_available()) + net.to(gpu_ids_[0]) + net = torch.nn.DataParallel(net, gpu_ids_) # multi-GPUs + + if not init_weights_ == None: + device = torch.device('cuda:{}'.format(gpu_ids_[0])) if gpu_ids_ else torch.device('cpu') + print('Loading model from: %s' % init_weights_) + state_dict = torch.load(init_weights_, map_location=str(device)) + if isinstance(net, torch.nn.DataParallel): + net.module.load_state_dict(state_dict) + else: + net.load_state_dict(state_dict) + print('load the weights successfully') + + return net + + +def 
define_styletps(init_weights_, gpu_ids_=[], shape=False): + net = None + if shape == False: + net = triplet() + if len(gpu_ids_) > 0: + assert (torch.cuda.is_available()) + net.to(gpu_ids_[0]) + net = torch.nn.DataParallel(net, gpu_ids_) # multi-GPUs + + if not init_weights_ == None: + device = torch.device('cuda:{}'.format(gpu_ids_[0])) if gpu_ids_ else torch.device('cpu') + print('Loading model from: %s' % init_weights_) + state_dict = torch.load(init_weights_, map_location=str(device)) + if isinstance(net, torch.nn.DataParallel): + net.module.load_state_dict(state_dict) + else: + net.load_state_dict(state_dict) + print('load the weights successfully') + + return net + + +class triplet(nn.Module): + def __init__(self): # mnblk=4 + super(triplet, self).__init__() + + # self.channels = nch_in + self.nch_in = 1 + self.nch_out = 1 + self.nch_ker = 64 + self.norm = 'bnorm' + # self.nblk = nblk + + if self.norm == 'bnorm': + self.bias = False + else: + self.bias = True + + self.conv0 = CNR2d(self.nch_in, self.nch_ker, kernel_size=7, stride=1, padding=3, norm=self.norm, relu=0.0) + self.conv1 = CNR2d(self.nch_ker, 2 * self.nch_ker, kernel_size=4, stride=2, padding=1, norm=self.norm, relu=0.0) + self.conv2 = CNR2d(2 * self.nch_ker, 4 * self.nch_ker, kernel_size=4, stride=2, padding=1, norm=self.norm, relu=0.0) + + self.final_pool = nn.AdaptiveAvgPool2d((1, 1)) + self.linear = nn.Linear(256, 128) + + def forward(self, x, y, z): + + x = self.conv0(x) + x = self.conv1(x) + x = self.conv2(x) + x = self.final_pool(x) + x = torch.flatten(x, 1) + x = self.linear(x) + + y = self.conv0(y) + y = self.conv1(y) + y = self.conv2(y) + y = self.final_pool(y) + y = torch.flatten(y, 1) + y = self.linear(y) + + z = self.conv0(z) + z = self.conv1(z) + z = self.conv2(z) + z = self.final_pool(z) + z = torch.flatten(z, 1) + z = self.linear(z) + + return x, y, z + + +class MLP(nn.Module): + def __init__(self, input_dim, output_dim, dim, n_blk, norm='none', activ='relu'): + super(MLP, self).__init__() + self.model = [] + self.model += [LinearBlock(input_dim, dim, norm=norm, activation=activ)] + for i in range(n_blk - 2): + self.model += [LinearBlock(dim, dim, norm=norm, activation=activ)] + self.model += [LinearBlock(dim, output_dim, norm='none', activation='none')] # no output activations + self.model = nn.Sequential(*self.model) + + def forward(self, x): + return self.model(x.view(x.size(0), -1)) + + +class ref_unpair(nn.Module): + def __init__(self, nch_in, nch_out, nch_ker=64, norm='bnorm', nblk=4, status='ref_unpair'): + super(ref_unpair, self).__init__() + + nch_ker = 64 + # self.channels = nch_in + self.nch_in = nch_in + self.nchs_in = 1 + self.status = status + + if self.status == 'ref_unpair_recon': + self.nch_out = 3 + self.nch_in = 1 + else: + self.nch_out = 1 + + self.nch_ker = nch_ker + self.norm = norm + self.nblk = nblk + self.dec0 = [] + + if status == 'ref_unpair_cbam_cat': + self.cbam_c = CBAM(nch_ker * 8, 16, 3, cbam_status="channel") + self.cbam_s = CBAM(nch_ker * 8, 16, 3, cbam_status="spatial") + + self.enc1_s = CNR2d(self.nchs_in, self.nch_ker, kernel_size=7, stride=1, padding=3, norm=self.norm, relu=0.0) + self.enc2_s = CNR2d(self.nch_ker, 2 * self.nch_ker, kernel_size=4, stride=2, padding=1, norm=self.norm, relu=0.0) + self.enc3_s = CNR2d(2 * self.nch_ker, 4 * self.nch_ker, kernel_size=4, stride=2, padding=1, norm=self.norm, relu=0.0) + self.enc4_s = CNR2d(4 * self.nch_ker, 8 * self.nch_ker, kernel_size=4, stride=2, padding=1, norm=self.norm, relu=0.0) + + if norm == 'bnorm': + self.bias = 
False
+        else:
+            self.bias = True
+
+        self.enc1_c = CNR2d(self.nch_in, self.nch_ker, kernel_size=7, stride=1, padding=3, norm=self.norm, relu=0.0)
+        self.enc2_c = CNR2d(self.nch_ker, 2 * self.nch_ker, kernel_size=4, stride=2, padding=1, norm=self.norm, relu=0.0)
+        self.enc3_c = CNR2d(2 * self.nch_ker, 4 * self.nch_ker, kernel_size=4, stride=2, padding=1, norm=self.norm, relu=0.0)
+        self.enc4_c = CNR2d(4 * self.nch_ker, 8 * self.nch_ker, kernel_size=4, stride=2, padding=1, norm=self.norm, relu=0.0)
+
+        if status == 'ref_unpair_cbam_cat':
+            self.res_cat1 = ResBlock_cat(8 * self.nch_ker, 8 * self.nch_ker, kernel_size=3, stride=1, padding=1, norm=self.norm, relu=0.0, padding_mode='reflection')
+            self.res_cat2 = ResBlock_cat(8 * self.nch_ker, 8 * self.nch_ker, kernel_size=3, stride=1, padding=1, norm=self.norm, relu=0.0, padding_mode='reflection')
+            self.res_cat3 = ResBlock_cat(8 * self.nch_ker, 8 * self.nch_ker, kernel_size=3, stride=1, padding=1, norm=self.norm, relu=0.0, padding_mode='reflection')
+            self.res_cat4 = ResBlock_cat(8 * self.nch_ker, 8 * self.nch_ker, kernel_size=3, stride=1, padding=1, norm=self.norm, relu=0.0, padding_mode='reflection')
+
+        if self.nblk and status != 'ref_unpair_cbam_cat':
+            res = []
+            for i in range(self.nblk):
+                res += [ResBlock(8 * self.nch_ker, 8 * self.nch_ker, kernel_size=3, stride=1, padding=1, norm=self.norm, relu=0.0, padding_mode='reflection')]
+            self.res1 = nn.Sequential(*res)
+
+        # self.dec0 += [DECNR2d(16 * self.nch_ker, 8 * self.nch_ker, kernel_size=4, stride=2, padding=1, norm=self.norm, relu=0.0)]
+        self.dec0 += [DECNR2d(8 * self.nch_ker, 4 * self.nch_ker, kernel_size=4, stride=2, padding=1, norm=self.norm, relu=0.0)]
+        self.dec0 += [DECNR2d(4 * self.nch_ker, 2 * self.nch_ker, kernel_size=4, stride=2, padding=1, norm=self.norm, relu=0.0)]
+        self.dec0 += [DECNR2d(2 * self.nch_ker, 1 * self.nch_ker, kernel_size=4, stride=2, padding=1, norm=self.norm, relu=0.0)]
+        self.dec0 += [DECNR2d(1 * self.nch_ker, 1 * self.nch_ker, kernel_size=7, stride=1, padding=3, norm=self.norm, relu=0.0)]
+        self.dec0 += [nn.Conv2d(1 * self.nch_ker, self.nch_out, kernel_size=3, stride=1, padding=1)]
+
+        self.dec = nn.Sequential(*self.dec0)
+
+    def forward(self, content, style):
+
+        content_cs = self.enc1_c(content)
+        content_cs = self.enc2_c(content_cs)
+        content_cs = self.enc3_c(content_cs)
+        content_cs = self.enc4_c(content_cs)
+        # content_cs = self.enc5_c(content_cs)
+
+        if self.status == 'ref_unpair_cbam_cat':
+            cbam_content_cs = self.cbam_s(content_cs)
+            sp_content_cs = content_cs + cbam_content_cs
+
+            style_cs = self.enc1_s(style)
+            style_cs = self.enc2_s(style_cs)
+            style_cs = self.enc3_s(style_cs)
+            style_cs = self.enc4_s(style_cs)
+
+            cbam_style_cs = self.cbam_c(style_cs)
+            ch_style_cs = style_cs + cbam_style_cs
+
+            content_output = self.adaptive_instance_normalization(content_cs, style_cs)
+            cbam_content_output = self.adaptive_instance_normalization(sp_content_cs, ch_style_cs)
+
+            content_output = self.res_cat1(content_output, cbam_content_output)
+            content_output = self.res_cat2(content_output, cbam_content_output)
+            content_output = self.res_cat3(content_output, cbam_content_output)
+            content_output = self.res_cat4(content_output, cbam_content_output)
+
+        else:
+            content_output = content_cs
+
+        if self.nblk and self.status != 'ref_unpair_cbam_cat':
+            content_output = self.res1(content_output)  # feed the residual blocks' output to the decoder
+
+        content_output = self.dec(content_output)
+
+        content_output = torch.tanh(content_output)
+
+        return content_output
+
+    def calc_mean_std(self, feat, eps=1e-5):
+        # eps is a small value added to the variance to avoid divide-by-zero.
+        size = feat.size()
+        assert (len(size) == 4)
+        N, C = size[:2]
+        feat_var = feat.view(N, C, -1).var(dim=2) + eps
+        feat_std = feat_var.sqrt().view(N, C, 1, 1)
+        feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1)
+        return feat_mean, feat_std
+
+    def adaptive_instance_normalization(self, content_feat, style_feat):
+        assert (content_feat.size()[:2] == style_feat.size()[:2])
+        size = content_feat.size()
+        style_mean, style_std = self.calc_mean_std(style_feat)
+        content_mean, content_std = self.calc_mean_std(content_feat)
+
+        normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
+        return normalized_feat * style_std.expand(size) + style_mean.expand(size)
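+
+
+# Worked example of the AdaIN transfer above (illustrative): for a content
+# feature channel with mean 2 and std 4 and a style feature channel with
+# mean 0 and std 1, each value x is mapped to (x - 2) / 4 * 1 + 0, so the
+# spatial layout of the content is kept while its first- and second-order
+# statistics are replaced by those of the style reference.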
+
+
+def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal', init_gain=0.02, gpu_ids=[]):
+    net = None
+    norm_layer = get_norm_layer(norm_type=norm)
+
+    if netD == 'basic':  # default PatchGAN classifier
+        net = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer)
+    elif netD == 'n_layers':  # more options
+        net = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer)
+    elif netD == 'pixel':  # classify if each pixel is real or fake
+        net = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer)
+    else:
+        raise NotImplementedError('Discriminator model name [%s] is not recognized' % netD)
+    return init_net(net, init_type, init_gain, gpu_ids)
+
+
+##############################################################################
+# Classes
+##############################################################################
+class GANLoss(nn.Module):
+    """Define different GAN objectives.
+
+    The GANLoss class abstracts away the need to create the target label tensor
+    that has the same size as the input.
+    """
+
+    def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
+        """ Initialize the GANLoss class.
+
+        Parameters:
+            gan_mode (str)           -- the type of GAN objective. It currently supports vanilla, lsgan, and wgangp.
+            target_real_label (bool) -- label for a real image
+            target_fake_label (bool) -- label of a fake image
+
+        Note: Do not use sigmoid as the last layer of Discriminator.
+        LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss.
+        """
+        super(GANLoss, self).__init__()
+        self.register_buffer('real_label', torch.tensor(target_real_label))
+        self.register_buffer('fake_label', torch.tensor(target_fake_label))
+        self.gan_mode = gan_mode
+        if gan_mode == 'lsgan':
+            self.loss = nn.MSELoss()
+        elif gan_mode == 'vanilla':
+            self.loss = nn.BCEWithLogitsLoss()
+        elif gan_mode in ['wgangp']:
+            self.loss = None
+        else:
+            raise NotImplementedError('gan mode %s not implemented' % gan_mode)
+
+    def get_target_tensor(self, prediction, target_is_real):
+        if target_is_real:
+            target_tensor = self.real_label
+        else:
+            target_tensor = self.fake_label
+        return target_tensor.expand_as(prediction)
+
+    def __call__(self, prediction, target_is_real):
+        if self.gan_mode in ['lsgan', 'vanilla']:
+            target_tensor = self.get_target_tensor(prediction, target_is_real)
+            loss = self.loss(prediction, target_tensor)
+        elif self.gan_mode == 'wgangp':
+            if target_is_real:
+                loss = -prediction.mean()
+            else:
+                loss = prediction.mean()
+        return loss
+
+
+def cal_gradient_penalty(netD, real_data, fake_data, device, type='mixed', constant=1.0, lambda_gp=10.0):
+    if lambda_gp > 0.0:
+        if type == 'real':  # either use real images, fake images, or a linear interpolation of two.
+ interpolatesv = real_data + elif type == 'fake': + interpolatesv = fake_data + elif type == 'mixed': + alpha = torch.rand(real_data.shape[0], 1, device=device) + alpha = alpha.expand(real_data.shape[0], real_data.nelement() // real_data.shape[0]).contiguous().view(*real_data.shape) + interpolatesv = alpha * real_data + ((1 - alpha) * fake_data) + else: + raise NotImplementedError('{} not implemented'.format(type)) + interpolatesv.requires_grad_(True) + disc_interpolates = netD(interpolatesv) + gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=interpolatesv, + grad_outputs=torch.ones(disc_interpolates.size()).to(device), + create_graph=True, retain_graph=True, only_inputs=True) + gradients = gradients[0].view(real_data.size(0), -1) # flat the data + gradient_penalty = (((gradients + 1e-16).norm(2, dim=1) - constant) ** 2).mean() * lambda_gp # added eps + return gradient_penalty, gradients + else: + return 0.0, None + + +class NLayerDiscriminator(nn.Module): + """Defines a PatchGAN discriminator""" + + def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d): + """Construct a PatchGAN discriminator + + Parameters: + input_nc (int) -- the number of channels in input images + ndf (int) -- the number of filters in the last conv layer + n_layers (int) -- the number of conv layers in the discriminator + norm_layer -- normalization layer + """ + super(NLayerDiscriminator, self).__init__() + if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters + use_bias = norm_layer.func == nn.InstanceNorm2d + else: + use_bias = norm_layer == nn.InstanceNorm2d + kw = 4 + padw = 1 + sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] + nf_mult = 1 + nf_mult_prev = 1 + for n in range(1, n_layers): # gradually increase the number of filters + nf_mult_prev = nf_mult + nf_mult = min(2 ** n, 8) + sequence += [ + nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), + norm_layer(ndf * nf_mult), + nn.LeakyReLU(0.2, True) + ] + + nf_mult_prev = nf_mult + nf_mult = min(2 ** n_layers, 8) + sequence += [ + nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), + norm_layer(ndf * nf_mult), + nn.LeakyReLU(0.2, True) + ] + + sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map + self.model = nn.Sequential(*sequence) + + def forward(self, input): + """Standard forward.""" + return self.model(input) + + +class PixelDiscriminator(nn.Module): + """Defines a 1x1 PatchGAN discriminator (pixelGAN)""" + + def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d): + """Construct a 1x1 PatchGAN discriminator + + Parameters: + input_nc (int) -- the number of channels in input images + ndf (int) -- the number of filters in the last conv layer + norm_layer -- normalization layer + """ + super(PixelDiscriminator, self).__init__() + if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters + use_bias = norm_layer.func == nn.InstanceNorm2d + else: + use_bias = norm_layer == nn.InstanceNorm2d + + self.net = [ + nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0), + nn.LeakyReLU(0.2, True), + nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias), + norm_layer(ndf * 2), + nn.LeakyReLU(0.2, True), + nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias)] + + self.net 
= nn.Sequential(*self.net) + + def forward(self, input): + """Standard forward.""" + return self.net(input) + + +class CBAM(nn.Module): + def __init__(self, n_channels_in, reduction_ratio, kernel_size, cbam_status): + super(CBAM, self).__init__() + self.n_channels_in = n_channels_in + self.reduction_ratio = reduction_ratio + self.kernel_size = kernel_size + self.channel_attention = ChannelAttention_nopara(n_channels_in, reduction_ratio) + self.spatial_attention = SpatialAttention_nopara(kernel_size) + self.status = cbam_status + + def forward(self, x): + ## We don't use cbam in this version + if self.status == "cbam": + chan_att = self.channel_attention(x) + fp = chan_att * x + spat_att = self.spatial_attention(fp) + fpp = spat_att * fp + + if self.status == "spatial": + spat_att = self.spatial_attention(x) # * s_para_1d + fpp = spat_att * x + if self.status == "channel": + chan_att = self.channel_attention(x) # * c_para_1d + fpp = chan_att * x + + return fpp # ,c_wgt,s_wgt + + +class SpatialAttention_nopara(nn.Module): + def __init__(self, kernel_size): + super(SpatialAttention_nopara, self).__init__() + self.kernel_size = kernel_size + assert kernel_size % 2 == 1, "Odd kernel size required" + self.conv = nn.Conv2d(in_channels=2, out_channels=1, kernel_size=kernel_size, padding=int((kernel_size - 1) / 2)) + + def forward(self, x): + max_pool = self.agg_channel(x, "max") + avg_pool = self.agg_channel(x, "avg") + pool = torch.cat([max_pool, avg_pool], dim=1) + conv = self.conv(pool) + conv = conv.repeat(1, x.size()[1], 1, 1) + att = torch.sigmoid(conv) + return att + + def agg_channel(self, x, pool="max"): + b, c, h, w = x.size() + x = x.view(b, c, h * w) + x = x.permute(0, 2, 1) + if pool == "max": + x = F.max_pool1d(x, c) + elif pool == "avg": + x = F.avg_pool1d(x, c) + x = x.permute(0, 2, 1) + x = x.view(b, 1, h, w) + return x + + +class ChannelAttention_nopara(nn.Module): + def __init__(self, n_channels_in, reduction_ratio): + super(ChannelAttention_nopara, self).__init__() + self.n_channels_in = n_channels_in + self.reduction_ratio = reduction_ratio + self.middle_layer_size = int(self.n_channels_in / float(self.reduction_ratio)) + self.bottleneck = nn.Sequential( + nn.Linear(self.n_channels_in, self.middle_layer_size), + nn.ReLU(), + nn.Linear(self.middle_layer_size, self.n_channels_in) + ) + + def forward(self, x): + kernel = (x.size()[2], x.size()[3]) + avg_pool = F.avg_pool2d(x, kernel) + max_pool = F.max_pool2d(x, kernel) + avg_pool = avg_pool.view(avg_pool.size()[0], -1) + max_pool = max_pool.view(max_pool.size()[0], -1) + avg_pool_bck = self.bottleneck(avg_pool) + max_pool_bck = self.bottleneck(max_pool) + pool_sum = avg_pool_bck + max_pool_bck + sig_pool = torch.sigmoid(pool_sum) + sig_pool = sig_pool.unsqueeze(2).unsqueeze(3) + # out = sig_pool.repeat(1,1,kernel[0], kernel[1]) + + return sig_pool diff --git a/app/service/image2sketch/models/perceptual.py b/app/service/image2sketch/models/perceptual.py new file mode 100644 index 0000000..666fab8 --- /dev/null +++ b/app/service/image2sketch/models/perceptual.py @@ -0,0 +1,86 @@ +import torch +import torchvision + +class VGGPerceptualLoss(torch.nn.Module): + def __init__(self, resize=True): + super(VGGPerceptualLoss, self).__init__() + blocks = [] + blocks.append(torchvision.models.vgg16(pretrained=True).features[:4].eval()) + blocks.append(torchvision.models.vgg16(pretrained=True).features[4:9].eval()) + blocks.append(torchvision.models.vgg16(pretrained=True).features[9:16].eval()) + 
blocks.append(torchvision.models.vgg16(pretrained=True).features[16:23].eval()) + for bl in blocks: + for p in bl: + p.requires_grad = False + self.blocks = torch.nn.ModuleList(blocks) + self.transform = torch.nn.functional.interpolate + self.mean = torch.nn.Parameter(torch.tensor([0.485, 0.456, 0.406]).view(1,3,1,1)) + self.std = torch.nn.Parameter(torch.tensor([0.229, 0.224, 0.225]).view(1,3,1,1)) + self.resize = resize + + def forward(self, input, target, feature_layers=[0, 1, 2, 3], style_layers=[]): + if input.shape[1] != 3: + input = input.repeat(1, 3, 1, 1) + target = target.repeat(1, 3, 1, 1) + input = (input-self.mean) / self.std + target = (target-self.mean) / self.std + if self.resize: + input = self.transform(input, mode='bilinear', size=(224, 224), align_corners=False) + target = self.transform(target, mode='bilinear', size=(224, 224), align_corners=False) + loss = 0.0 + x = input + y = target + for i, block in enumerate(self.blocks): + x = block(x) + y = block(y) + if i in feature_layers: + loss += torch.nn.functional.l1_loss(x, y) + if i in style_layers: + act_x = x.reshape(x.shape[0], x.shape[1], -1) + act_y = y.reshape(y.shape[0], y.shape[1], -1) + gram_x = act_x @ act_x.permute(0, 2, 1) + gram_y = act_y @ act_y.permute(0, 2, 1) + loss += torch.nn.functional.l1_loss(gram_x, gram_y) + return loss + +class VGGstyleLoss(torch.nn.Module): + def __init__(self, resize=True): + super(VGGstyleLoss, self).__init__() + blocks = [] + blocks.append(torchvision.models.vgg16(pretrained=True).features[:4].eval()) + blocks.append(torchvision.models.vgg16(pretrained=True).features[4:9].eval()) + blocks.append(torchvision.models.vgg16(pretrained=True).features[9:16].eval()) + blocks.append(torchvision.models.vgg16(pretrained=True).features[16:23].eval()) + for bl in blocks: + for p in bl: + p.requires_grad = False + self.blocks = torch.nn.ModuleList(blocks) + self.transform = torch.nn.functional.interpolate + self.mean = torch.nn.Parameter(torch.tensor([0.485, 0.456, 0.406]).view(1,3,1,1)) + self.std = torch.nn.Parameter(torch.tensor([0.229, 0.224, 0.225]).view(1,3,1,1)) + self.resize = resize + + def forward(self, input, target, feature_layers=[0,1,2,3], style_layers=[]): + if input.shape[1] != 3: + input = input.repeat(1, 3, 1, 1) + target = target.repeat(1, 3, 1, 1) + input = (input-self.mean) / self.std + target = (target-self.mean) / self.std + if self.resize: + input = self.transform(input, mode='bilinear', size=(224, 224), align_corners=False) + target = self.transform(target, mode='bilinear', size=(224, 224), align_corners=False) + loss = 0.0 + x = input + y = target + for i, block in enumerate(self.blocks): + x = block(x) + y = block(y) + if i in feature_layers: + loss += torch.nn.functional.l1_loss(x, y) + if i in style_layers: + act_x = x.reshape(x.shape[0], x.shape[1], -1) + act_y = y.reshape(y.shape[0], y.shape[1], -1) + gram_x = act_x @ act_x.permute(0, 2, 1) + gram_y = act_y @ act_y.permute(0, 2, 1) + loss += torch.nn.functional.l1_loss(gram_x, gram_y) + return loss diff --git a/app/service/image2sketch/models/template_model.py b/app/service/image2sketch/models/template_model.py new file mode 100644 index 0000000..45c68b2 --- /dev/null +++ b/app/service/image2sketch/models/template_model.py @@ -0,0 +1,82 @@ +import torch +from .base_model import BaseModel +from . import networks + + +class TemplateModel(BaseModel): + @staticmethod + def modify_commandline_options(parser, is_train=True): + """Add new model-specific options and rewrite default values for existing options. 
+ + Parameters: + parser -- the option parser + is_train -- if it is training phase or test phase. You can use this flag to add training-specific or test-specific options. + + Returns: + the modified parser. + """ + parser.set_defaults(dataset_mode='aligned') # You can rewrite default values for this model. For example, this model usually uses aligned dataset as its dataset. + if is_train: + parser.add_argument('--lambda_regression', type=float, default=1.0, help='weight for the regression loss') # You can define new arguments for this model. + + return parser + + def __init__(self, opt): + """Initialize this model class. + + Parameters: + opt -- training/test options + + A few things can be done here. + - (required) call the initialization function of BaseModel + - define loss function, visualization images, model names, and optimizers + """ + BaseModel.__init__(self, opt) # call the initialization method of BaseModel + # specify the training losses you want to print out. The program will call base_model.get_current_losses to plot the losses to the console and save them to the disk. + self.loss_names = ['loss_G'] + # specify the images you want to save and display. The program will call base_model.get_current_visuals to save and display these images. + self.visual_names = ['data_A', 'data_B', 'output'] + # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks to save and load networks. + # you can use opt.isTrain to specify different behaviors for training and test. For example, some networks will not be used during test, and you don't need to load them. + self.model_names = ['G'] + # define networks; you can use opt.isTrain to specify different behaviors for training and test. + self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, gpu_ids=self.gpu_ids) + if self.isTrain: # only defined during training time + # define your loss functions. You can use losses provided by torch.nn such as torch.nn.L1Loss. + # We also provide a GANLoss class "networks.GANLoss". self.criterionGAN = networks.GANLoss().to(self.device) + self.criterionLoss = torch.nn.L1Loss() + # define and initialize optimizers. You can define one optimizer for each network. + # If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example. + self.optimizer = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) + self.optimizers = [self.optimizer] + + # Our program will automatically call to define schedulers, load networks, and print networks + + def set_input(self, input): + """Unpack input data from the dataloader and perform necessary pre-processing steps. + + Parameters: + input: a dictionary that contains the data itself and its metadata information. + """ + AtoB = self.opt.direction == 'AtoB' # use to swap data_A and data_B + self.data_A = input['A' if AtoB else 'B'].to(self.device) # get image data A + self.data_B = input['B' if AtoB else 'A'].to(self.device) # get image data B + self.image_paths = input['A_paths' if AtoB else 'B_paths'] # get image paths + + def forward(self): + """Run forward pass. 
This will be called by both functions <optimize_parameters> and <test>."""
+        self.output = self.netG(self.data_A)  # generate output image given the input data_A
+
+    def backward(self):
+        """Calculate losses, gradients, and update network weights; called in every training iteration"""
+        # calculate the intermediate results if necessary; here self.output has been computed during function <forward>
+        # calculate loss given the input and intermediate results
+        self.loss_G = self.criterionLoss(self.output, self.data_B) * self.opt.lambda_regression
+        self.loss_G.backward()  # calculate gradients of network G w.r.t. loss_G
+
+    def optimize_parameters(self):
+        """Update network weights; it will be called in every training iteration."""
+        self.forward()  # first call forward to calculate intermediate results
+        self.optimizer.zero_grad()  # clear network G's existing gradients
+        self.backward()  # calculate gradients for network G
+        self.optimizer.step()  # update gradients for network G
diff --git a/app/service/image2sketch/models/test_model.py b/app/service/image2sketch/models/test_model.py
new file mode 100644
index 0000000..2f70821
--- /dev/null
+++ b/app/service/image2sketch/models/test_model.py
@@ -0,0 +1,45 @@
+from .base_model import BaseModel
+from . import networks
+
+
+class TestModel(BaseModel):
+    """ This TestModel can be used to generate CycleGAN results for only one direction.
+    This model will automatically set '--dataset_mode single', which only loads the images from one collection.
+
+    See the test instruction for more details.
+    """
+
+    @staticmethod
+    def modify_commandline_options(parser, is_train=True):
+        assert not is_train, 'TestModel cannot be used during training time'
+        parser.set_defaults(dataset_mode='single')
+        parser.add_argument('--model_suffix', type=str, default='', help='In checkpoints_dir, [epoch]_net_G[model_suffix].pth will be loaded as the generator.')
+
+        return parser
+
+    def __init__(self, opt):
+        assert (not opt.isTrain)
+        BaseModel.__init__(self, opt)
+        # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
+        self.loss_names = []
+        # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
+        self.visual_names = ['real', 'fake']
+        # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>
+        self.model_names = ['G' + opt.model_suffix]  # only generator is needed.
+        self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG,
+                                      opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
+
+        # assigns the model to self.netG_[suffix] so that it can be loaded
+        # please see <BaseModel.load_networks>
+        setattr(self, 'netG' + opt.model_suffix, self.netG)  # store netG in self.
+
+    def set_input(self, input):
+        self.real = input['A'].to(self.device)
+        self.image_paths = input['A_paths']
+
+    def forward(self):
+        """Run forward pass."""
+        self.fake = self.netG(self.real)  # G(real)
+
+    def optimize_parameters(self):
+        """No optimization for test model."""
+        pass
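+
+
+# Illustrative use (assuming the standard CycleGAN-style options object):
+# with opt.model_suffix = '_A', setup() loads '[epoch]_net_G_A.pth' into
+# self.netG_A, so a single direction G_A: A -> B can be tested in isolation.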
diff --git a/app/service/image2sketch/models/triplet_model.py b/app/service/image2sketch/models/triplet_model.py
new file mode 100644
index 0000000..a667d49
--- /dev/null
+++ b/app/service/image2sketch/models/triplet_model.py
@@ -0,0 +1,68 @@
+import torch
+from .base_model import BaseModel
+from . import networks
+from ..util.image_pool import ImagePool
+
+
+class TripletModel(BaseModel):
+
+    @staticmethod
+    def modify_commandline_options(parser, is_train=True):
+        parser.set_defaults(norm='batch', netG='triplet', dataset_mode='triplet')
+        if is_train:
+            parser.set_defaults(pool_size=0, gan_mode='vanilla')
+            parser.add_argument('--lambda_L1', type=float, default=100.0, help='weight for L1 loss')
+
+        return parser
+
+    def __init__(self, opt):
+        BaseModel.__init__(self, opt)
+
+        self.loss_names = ['G_triplet']
+        self.visual_names = ['x', 'y']
+        self.model_names = ['G']  # the same generator is saved for both training and test
+        self.netG = networks.define_G(1, 1, opt.ngf, opt.netG, opt.norm,
+                                      not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
+
+        if self.isTrain:
+            self.fake_A_pool = ImagePool(opt.pool_size)  # create image buffer to store previously generated images
+            self.fake_B_pool = ImagePool(opt.pool_size)  # create image buffer to store previously generated images
+
+            self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)
+            self.criterionL1 = torch.nn.L1Loss()
+
+            self.triplet = torch.nn.TripletMarginLoss(margin=3.0)
+            self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
+            self.optimizers.append(self.optimizer_G)
+
+    def set_input(self, input):
+        AtoB = self.opt.direction == 'AtoB'
+        self.real_A = input['A' if AtoB else 'B'].to(self.device)
+        self.real_B = input['B' if AtoB else 'A'].to(self.device)
+        self.real_C = input['C'].to(self.device)
+
+        self.image_paths = input['A_paths' if AtoB else 'B_paths']
+
+    def forward(self):
+        self.x, self.y, self.z = self.netG(self.real_A, self.real_B, self.real_C)
+
+    def backward_G(self):
+        self.loss_G_triplet = self.triplet(self.x, self.y, self.z)
+        self.loss_G = self.loss_G_triplet
+        self.loss_G.backward()
+
+    def optimize_parameters(self):
+        self.forward()  # compute the embeddings x, y, z for the current batch
+        self.optimizer_G.zero_grad()
+        self.backward_G()
+        self.optimizer_G.step()
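TripletModel's only training signal is torch.nn.TripletMarginLoss with margin=3.0, i.e. mean(max(d(a, p) - d(a, n) + margin, 0)) with d the L2 distance: the anchor embedding x is pulled toward the positive y and pushed at least margin away from the negative z. A small numeric check of that formula:

import torch

triplet = torch.nn.TripletMarginLoss(margin=3.0)  # same margin as the model above

anchor = torch.zeros(1, 4)
positive = torch.zeros(1, 4)             # d(a, p) = 0
far_negative = torch.full((1, 4), 2.0)   # d(a, n) = sqrt(4 * 2**2) = 4
near_negative = torch.full((1, 4), 0.5)  # d(a, n) = sqrt(4 * 0.5**2) = 1

print(triplet(anchor, positive, far_negative).item())   # max(0 - 4 + 3, 0) = 0
print(triplet(anchor, positive, near_negative).item())  # max(0 - 1 + 3, 0) = 2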
diff --git a/app/service/image2sketch/models/unpaired_model.py b/app/service/image2sketch/models/unpaired_model.py
new file mode 100644
index 0000000..9c043ca
--- /dev/null
+++ b/app/service/image2sketch/models/unpaired_model.py
@@ -0,0 +1,144 @@
+import torch
+
+from . import networks
+from .base_model import BaseModel
+from .perceptual import VGGPerceptualLoss
+from ..util.image_pool import ImagePool
+
+
+class UnpairedModel(BaseModel):
+
+    @staticmethod
+    def modify_commandline_options(parser, is_train=True):
+        parser.set_defaults(norm='batch', netG='ref_unpair_cbam_cat', netG2='ref_unpair_recon', dataset_mode='unaligned')
+        if is_train:
+            parser.set_defaults(pool_size=0, gan_mode='vanilla')
+            parser.add_argument('--lambda_L1', type=float, default=100.0, help='weight for L1 loss')
+
+        return parser
+
+    def __init__(self, opt):
+        BaseModel.__init__(self, opt)
+        # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
+        self.loss_names = ['G_GAN', 'G_L1_1', 'G_Rec', 'G_line', 'D_real', 'D_fake']
+        self.visual_names = ['real_A', 'content_output', 'real_B']
+
+        if self.isTrain:
+            self.model_names = ['G_A', 'G_B', 'D']
+        else:  # during test time, only load G
+            self.model_names = ['G_A', 'G_B']
+        # define networks (both generator and discriminator)
+        self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm,
+                                        not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
+        self.netG_B = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG2, opt.norm,
+                                        not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
+
+        if self.isTrain:  # define a discriminator; here D judges the single-channel sketch output on its own, so it takes 1 input channel
+            self.netD = networks.define_D(1, opt.ndf, opt.netD,
+                                          opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
+            self.styletps = networks.define_styletps(init_weights_='./checkpoints/contrastive_pretrained.pth', gpu_ids_=self.gpu_ids, shape=False)
+            self.HED = networks.define_HED(init_weights_='./checkpoints/network-bsds500.pytorch', gpu_ids_=self.gpu_ids)
+
+        if self.isTrain:  # define discriminators
+            self.netD_A = networks.define_D(opt.output_nc, opt.ndf, opt.netD,
+                                            opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
+            self.netD_B = networks.define_D(opt.input_nc, opt.ndf, opt.netD,
+                                            opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
+
+        if self.isTrain:
+            self.fake_A_pool = ImagePool(opt.pool_size)  # create image buffer to store previously generated images
+            self.fake_B_pool = ImagePool(opt.pool_size)  # create image buffer to store previously generated images
+            # define loss functions
+            self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)
+            self.criterionL1_1 = torch.nn.L1Loss()
+            self.criterionL1_2 = torch.nn.L1Loss()
+            self.criterionL1_3 = torch.nn.L1Loss()
+            self.per_loss_1 = VGGPerceptualLoss().to(self.device)
+            self.per_loss_2 = VGGPerceptualLoss().to(self.device)
+            self.per_loss_3 = VGGPerceptualLoss().to(self.device)
+
+            self.optimizer_GA = torch.optim.Adam(self.netG_A.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
+            self.optimizer_GB = torch.optim.Adam(self.netG_B.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
+
+            self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
+            self.optimizers.append(self.optimizer_GA)
+            self.optimizers.append(self.optimizer_GB)
+
+            self.optimizers.append(self.optimizer_D)
+
+    def set_input(self, input):
+        """Unpack input data from the dataloader and perform necessary pre-processing steps.
+
+        Parameters:
+            input (dict): include the data itself and its metadata information.
+
+        The option 'direction' can be used to swap images in domain A and domain B.
+ """ + AtoB = self.opt.direction == 'AtoB' + self.real_A = input['A' if AtoB else 'B'].to(self.device) + self.real_B = input['B' if AtoB else 'A'].to(self.device) + # self.image_paths = input['A_paths' if AtoB else 'B_paths'] + + def forward(self): + """Run forward pass; called by both functions and .""" + self.content_output = self.netG_A(self.real_A, self.real_B) + self.rec_output = self.netG_B(self.content_output, self.content_output) + + def update_process(self, epoch, total_epoch): + self.epoch_count = epoch + self.epoch_count_total = total_epoch + + def backward_D(self): + """Calculate GAN loss for the discriminator + + Parameters: + netD (network) -- the discriminator D + real (tensor array) -- real images + fake (tensor array) -- images generated by a generator + + Return the discriminator loss. + We also call loss_D.backward() to calculate the gradients. + """ + # Real + pred_real = self.netD(self.real_B) + self.loss_D_real = self.criterionGAN(pred_real, True) + # Fake + pred_fake = self.netD(self.content_output.detach()) + self.loss_D_fake = self.criterionGAN(pred_fake, False) + # Combined loss and calculate gradients + loss_D = (self.loss_D_real + self.loss_D_fake) * 0.5 + loss_D.backward() + return loss_D + + def backward_G(self): + """Calculate GAN and L1 loss for the generator""" + + pred_fake = self.netD(self.content_output) + self.loss_G_GAN = self.criterionGAN(pred_fake, True) + + self.content_output_line = self.HED(self.real_A) + self.rec_output_line = self.HED(self.rec_output) + self.t1, self.t2, _ = self.styletps(self.content_output, self.real_B, self.real_B) + + decay_lambda = 5 - ((self.epoch_count * 4.5) / self.epoch_count_total) + self.loss_G_L1_1 = self.criterionL1_1(self.t1, self.t2) * 10 + self.loss_G_Rec = self.per_loss_2(self.real_A, self.rec_output) * decay_lambda + self.loss_G_line = self.per_loss_3(self.content_output_line, self.rec_output_line) * decay_lambda + + self.loss_G = self.loss_G_GAN + self.loss_G_L1_1 + self.loss_G_Rec + self.loss_G_line + self.loss_G.backward() + + def optimize_parameters(self): + self.forward() # compute fake images: G(A) + # update D + self.set_requires_grad(self.netD, True) # enable backprop for D + self.optimizer_D.zero_grad() # set D's gradients to zero + self.backward_D() # calculate gradients for backward_D_unsuper + self.optimizer_D.step() # update D's weights + # update G + self.set_requires_grad(self.netD, False) # D requires no gradients when optimizing G + self.optimizer_GA.zero_grad() # set G's gradients to zero + self.optimizer_GB.zero_grad() # set G's gradients to zero + self.backward_G() # calculate graidents for G + self.optimizer_GA.step() # udpate G's weights + self.optimizer_GB.step() # udpate G's weights diff --git a/app/service/image2sketch/opt.py b/app/service/image2sketch/opt.py new file mode 100644 index 0000000..eb453fb --- /dev/null +++ b/app/service/image2sketch/opt.py @@ -0,0 +1,57 @@ +from app.core.config import DEBUG + + +class Config: + def __init__(self): + # 基本参数 + self.dataroot = "app/service/image2sketch/datasets/ref_unpair" + self.name = 'semi_unpair' + self.gpu_ids = [0] + # 模型参数 + self.model = 'unpaired' + self.input_nc = 3 + self.output_nc = 3 + self.ngf = 64 + self.ndf = 64 + self.netD = 'basic' + self.netG = 'ref_unpair_cbam_cat' + self.netG2 = 'ref_unpair_recon' + self.n_layers_D = 3 + self.norm = 'instance' + self.init_type = 'normal' + self.init_gain = 0.02 + self.no_dropout = False # 对应 `--no_dropout` + # 数据集参数 + self.dataset_mode = 'single' + self.direction = 'AtoB' + 
diff --git a/app/service/image2sketch/opt.py b/app/service/image2sketch/opt.py
new file mode 100644
index 0000000..eb453fb
--- /dev/null
+++ b/app/service/image2sketch/opt.py
@@ -0,0 +1,57 @@
+from app.core.config import DEBUG
+
+
+class Config:
+    def __init__(self):
+        # basic parameters
+        self.dataroot = "app/service/image2sketch/datasets/ref_unpair"
+        self.name = 'semi_unpair'
+        self.gpu_ids = [0]
+        # model parameters
+        self.model = 'unpaired'
+        self.input_nc = 3
+        self.output_nc = 3
+        self.ngf = 64
+        self.ndf = 64
+        self.netD = 'basic'
+        self.netG = 'ref_unpair_cbam_cat'
+        self.netG2 = 'ref_unpair_recon'
+        self.n_layers_D = 3
+        self.norm = 'instance'
+        self.init_type = 'normal'
+        self.init_gain = 0.02
+        self.no_dropout = False  # corresponds to `--no_dropout`
+        # dataset parameters
+        self.dataset_mode = 'single'
+        self.direction = 'AtoB'
+        self.serial_batches = True  # corresponds to `--serial_batches`
+        self.num_threads = 4
+        self.batch_size = 4
+        self.load_size = 512
+        self.crop_size = 512
+        self.max_dataset_size = float("inf")
+        self.preprocess = 'resize_and_crop'
+        self.no_flip = False  # corresponds to `--no_flip`
+        self.display_winsize = 256
+        # additional parameters
+        self.epoch = '100'
+        self.load_iter = 0
+        self.verbose = False  # corresponds to `--verbose`
+        self.suffix = ''
+        self.isTrain = False
+        self.results_dir = 'service/image2sketch/results'
+        self.aspect_ratio = 1.0
+        self.phase = 'test'
+        self.eval = False
+        self.num_test = 1000
+        self.morm = 'batch'  # apparently unused; possibly a typo of 'norm'
+        if DEBUG:
+            self.style_image1 = "service/image2sketch/datasets/ref_unpair/testC/style_1.jpg"
+            self.style_image2 = "service/image2sketch/datasets/ref_unpair/testC/style_2.jpeg"
+            self.style_image3 = "service/image2sketch/datasets/ref_unpair/testC/style_3.png"
+            self.checkpoints_dir = 'service/image2sketch/checkpoints/'
+        else:
+            self.checkpoints_dir = 'app/service/image2sketch/checkpoints/'
+            self.style_image1 = "app/service/image2sketch/datasets/ref_unpair/testC/style_1.jpg"
+            self.style_image2 = "app/service/image2sketch/datasets/ref_unpair/testC/style_2.jpeg"
+            self.style_image3 = "app/service/image2sketch/datasets/ref_unpair/testC/style_3.png"
diff --git a/app/service/image2sketch/server.py b/app/service/image2sketch/server.py
new file mode 100644
index 0000000..3094eea
--- /dev/null
+++ b/app/service/image2sketch/server.py
@@ -0,0 +1,88 @@
+import logging
+
+import cv2
+import numpy as np
+import torch
+import torchvision.transforms as transforms
+from PIL import Image
+
+from app.schemas.image2sketch import Image2SketchModel
+# from app.service.image2sketch.infer import tensor2im  # unused: shadowed by the local tensor2im below
+from app.service.image2sketch.models import create_model
+from app.service.image2sketch.opt import Config
+from app.service.utils.oss_client import oss_get_image, oss_upload_image
+
+logger = logging.getLogger()
+
+
+def tensor2im(input_image, imtype=np.uint8):
+    if not isinstance(input_image, np.ndarray):
+        if isinstance(input_image, torch.Tensor):  # get the data from a variable
+            image_tensor = input_image.data
+        else:
+            return input_image
+        image_numpy = image_tensor[0].cpu().float().numpy()  # convert it into a numpy array
+        if image_numpy.shape[0] == 1:  # grayscale to RGB
+            image_numpy = np.tile(image_numpy, (3, 1, 1))
+        image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0  # post-processing: transpose and scaling
+    else:  # if it is a numpy array, do nothing
+        image_numpy = input_image
+    return image_numpy.astype(imtype)
+
+
+class Image2SketchServer:
+    def __init__(self, request_data):
+        self.image_url = request_data.image_url
+        self.style_image_url = request_data.style_image_url
+        self.sketch_bucket = request_data.sketch_bucket
+        self.sketch_name = request_data.sketch_name
+        self.opt = Config()
+        self.opt.num_threads = 0        # test code only supports num_threads = 0
+        self.opt.batch_size = 1         # test code only supports batch_size = 1
+        self.opt.serial_batches = True  # disable data shuffling; comment this line if results on randomly chosen images are needed.
+        self.opt.no_flip = True         # no flip; comment this line if results on flipped images are needed.
+        self.opt.display_id = -1        # no visdom display; the test code saves the results to an HTML file.
+        self.data = {}
+        device = torch.device("cuda:0")
+        self.model = create_model(self.opt)
+        self.model.setup(self.opt)
+        transform_list = [transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
+        transform = transforms.Compose(transform_list)
+        if request_data.default_style == "1":
+            style_img = Image.open(self.opt.style_image1).convert('L')
+        elif request_data.default_style == "2":
+            style_img = Image.open(self.opt.style_image2).convert('L')
+        elif request_data.default_style == "3":
+            style_img = Image.open(self.opt.style_image3).convert('L')
+        else:
+            style_img = oss_get_image(bucket=self.style_image_url.split('/')[0], object_name=self.style_image_url[self.style_image_url.find('/') + 1:], data_type="PIL")
+            style_img = style_img.convert('L')
+        style_img = transform(style_img)
+        self.data['B'] = style_img
+        self.data['B'] = self.data['B'].unsqueeze(0).to(device)
+        A, self.width, self.height = self.get_image(self.image_url)
+        self.data['A'] = transform(A)
+        self.data['A'] = self.data['A'].unsqueeze(0).to(device)
+
+    def get_result(self):
+        self.model.set_input(self.data)
+        self.model.test()  # run inference
+        visuals = self.model.get_current_visuals()  # get image results
+        image_numpy = tensor2im(visuals['content_output'].cpu())
+        image_bytes = cv2.imencode(".jpg", image_numpy)[1].tobytes()
+        req = oss_upload_image(bucket=self.sketch_bucket, object_name=self.sketch_name, image_bytes=image_bytes)
+        return f"{req.bucket_name}/{req.object_name}"
+
+    def get_image(self, image_url):
+        image = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url[image_url.find('/') + 1:], data_type="PIL")
+        image = image.convert('RGB')
+        width = image.size[0]
+        height = image.size[1]
+        return image, width, height
+
+
+if __name__ == '__main__':
+    data = Image2SketchModel(image_url="test/real_Dress_790b2c6e370644e134df7abdfe7e54d9.jpg_Img.jpg", sketch_bucket="test", sketch_name="test123.jpg")
+    server = Image2SketchServer(data)
+    sketch_url = server.get_result()
+    print(sketch_url)
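The server normalizes inputs with Normalize([0.5], [0.5]), so the generator sees values in [-1, 1], and tensor2im inverts that with (x + 1) / 2 * 255. A one-pixel round trip of the two mappings:

x = 0.6                                # a pixel in [0, 1] after ToTensor
normalized = (x - 0.5) / 0.5           # -> 0.2, the value the network sees
restored = (normalized + 1) / 2 * 255  # -> 153.0, the value tensor2im emits
print(normalized, restored)            # 0.2 153.0 (= 0.6 * 255)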
+ """ + + def __init__(self, technique='cyclegan', verbose=True): + url_dict = { + 'pix2pix': 'http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/', + 'cyclegan': 'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets' + } + self.url = url_dict.get(technique.lower()) + self._verbose = verbose + + def _print(self, text): + if self._verbose: + print(text) + + @staticmethod + def _get_options(r): + soup = BeautifulSoup(r.text, 'lxml') + options = [h.text for h in soup.find_all('a', href=True) + if h.text.endswith(('.zip', 'tar.gz'))] + return options + + def _present_options(self): + r = requests.get(self.url) + options = self._get_options(r) + print('Options:\n') + for i, o in enumerate(options): + print("{0}: {1}".format(i, o)) + choice = input("\nPlease enter the number of the " + "dataset above you wish to download:") + return options[int(choice)] + + def _download_data(self, dataset_url, save_path): + if not isdir(save_path): + os.makedirs(save_path) + + base = basename(dataset_url) + temp_save_path = join(save_path, base) + + with open(temp_save_path, "wb") as f: + r = requests.get(dataset_url) + f.write(r.content) + + if base.endswith('.tar.gz'): + obj = tarfile.open(temp_save_path) + elif base.endswith('.zip'): + obj = ZipFile(temp_save_path, 'r') + else: + raise ValueError("Unknown File Type: {0}.".format(base)) + + self._print("Unpacking Data...") + obj.extractall(save_path) + obj.close() + os.remove(temp_save_path) + + def get(self, save_path, dataset=None): + """ + + Download a dataset. + + Parameters: + save_path (str) -- A directory to save the data to. + dataset (str) -- (optional). A specific dataset to download. + Note: this must include the file extension. + If None, options will be presented for you + to choose from. + + Returns: + save_path_full (str) -- the absolute path to the downloaded data. + + """ + if dataset is None: + selected_dataset = self._present_options() + else: + selected_dataset = dataset + + save_path_full = join(save_path, selected_dataset.split('.')[0]) + + if isdir(save_path_full): + warn("\n'{0}' already exists. Voiding Download.".format( + save_path_full)) + else: + self._print('Downloading Data...') + url = "{0}/{1}".format(self.url, selected_dataset) + self._download_data(url, save_path=save_path) + + return abspath(save_path_full) diff --git a/app/service/image2sketch/util/html.py b/app/service/image2sketch/util/html.py new file mode 100644 index 0000000..cc3262a --- /dev/null +++ b/app/service/image2sketch/util/html.py @@ -0,0 +1,86 @@ +import dominate +from dominate.tags import meta, h3, table, tr, td, p, a, img, br +import os + + +class HTML: + """This HTML class allows us to save images and write texts into a single HTML file. + + It consists of functions such as (add a text header to the HTML file), + (add a row of images to the HTML file), and (save the HTML to the disk). + It is based on Python library 'dominate', a Python library for creating and manipulating HTML documents using a DOM API. + """ + + def __init__(self, web_dir, title, refresh=0): + """Initialize the HTML classes + + Parameters: + web_dir (str) -- a directory that stores the webpage. 
diff --git a/app/service/image2sketch/util/html.py b/app/service/image2sketch/util/html.py
new file mode 100644
index 0000000..cc3262a
--- /dev/null
+++ b/app/service/image2sketch/util/html.py
@@ -0,0 +1,86 @@
+import dominate
+from dominate.tags import meta, h3, table, tr, td, p, a, img, br
+import os
+
+
+class HTML:
+    """This HTML class allows us to save images and write texts into a single HTML file.
+
+    It consists of functions such as <add_header> (add a text header to the HTML file),
+    <add_images> (add a row of images to the HTML file), and <save> (save the HTML to the disk).
+    It is based on Python library 'dominate', a Python library for creating and manipulating HTML documents using a DOM API.
+    """
+
+    def __init__(self, web_dir, title, refresh=0):
+        """Initialize the HTML classes
+
+        Parameters:
+            web_dir (str) -- a directory that stores the webpage. The HTML file will be created at <web_dir>/index.html; images will be saved at <web_dir>/images/
+            title (str)   -- the webpage name
+            refresh (int) -- how often the website refreshes itself; if 0, no refreshing
+        """
+        self.title = title
+        self.web_dir = web_dir
+        self.img_dir = os.path.join(self.web_dir, 'images')
+        if not os.path.exists(self.web_dir):
+            os.makedirs(self.web_dir)
+        if not os.path.exists(self.img_dir):
+            os.makedirs(self.img_dir)
+
+        self.doc = dominate.document(title=title)
+        if refresh > 0:
+            with self.doc.head:
+                meta(http_equiv="refresh", content=str(refresh))
+
+    def get_image_dir(self):
+        """Return the directory that stores images"""
+        return self.img_dir
+
+    def add_header(self, text):
+        """Insert a header to the HTML file
+
+        Parameters:
+            text (str) -- the header text
+        """
+        with self.doc:
+            h3(text)
+
+    def add_images(self, ims, txts, links, width=400):
+        """Add images to the HTML file
+
+        Parameters:
+            ims (str list)   -- a list of image paths
+            txts (str list)  -- a list of image names shown on the website
+            links (str list) -- a list of hyperref links; when you click an image, it will redirect you to a new page
+        """
+        self.t = table(border=1, style="table-layout: fixed;")  # Insert a table
+        self.doc.add(self.t)
+        with self.t:
+            with tr():
+                for im, txt, link in zip(ims, txts, links):
+                    with td(style="word-wrap: break-word;", halign="center", valign="top"):
+                        with p():
+                            with a(href=os.path.join('images', link)):
+                                img(style="width:%dpx" % width, src=os.path.join('images', im))
+                            br()
+                            p(txt)
+
+    def save(self):
+        """Save the current content to the HTML file"""
+        html_file = '%s/index.html' % self.web_dir
+        f = open(html_file, 'wt')
+        f.write(self.doc.render())
+        f.close()
+
+
+if __name__ == '__main__':  # we show an example usage here.
+    html = HTML('web/', 'test_html')
+    html.add_header('hello world')
+
+    ims, txts, links = [], [], []
+    for n in range(4):
+        ims.append('image_%d.png' % n)
+        txts.append('text_%d' % n)
+        links.append('image_%d.png' % n)
+    html.add_images(ims, txts, links)
+    html.save()
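The wrapper above leans on dominate's context-manager nesting: entering a tag object makes it the parent of every tag created inside the with block, which is exactly how add_images builds its table. The mechanism in miniature:

import dominate
from dominate.tags import h3, table, td, tr

doc = dominate.document(title="demo")
with doc:                  # tags created below attach to the document body
    h3("results")
    with table(border=1):  # nested managers mirror nested HTML
        with tr():
            td("cell A")
            td("cell B")
print(doc.render())        # serialized <html>...</html> markup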
+ """ + if self.pool_size == 0: # if the buffer size is 0, do nothing + return images + return_images = [] + for image in images: + image = torch.unsqueeze(image.data, 0) + if self.num_imgs < self.pool_size: # if the buffer is not full; keep inserting current images to the buffer + self.num_imgs = self.num_imgs + 1 + self.images.append(image) + return_images.append(image) + else: + p = random.uniform(0, 1) + if p > 0.5: # by 50% chance, the buffer will return a previously stored image, and insert the current image into the buffer + random_id = random.randint(0, self.pool_size - 1) # randint is inclusive + tmp = self.images[random_id].clone() + self.images[random_id] = image + return_images.append(tmp) + else: # by another 50% chance, the buffer will return the current image + return_images.append(image) + return_images = torch.cat(return_images, 0) # collect all the images and return + return return_images diff --git a/app/service/image2sketch/util/util.py b/app/service/image2sketch/util/util.py new file mode 100644 index 0000000..b050c13 --- /dev/null +++ b/app/service/image2sketch/util/util.py @@ -0,0 +1,103 @@ +"""This module contains simple helper functions """ +from __future__ import print_function +import torch +import numpy as np +from PIL import Image +import os + + +def tensor2im(input_image, imtype=np.uint8): + """"Converts a Tensor array into a numpy image array. + + Parameters: + input_image (tensor) -- the input image tensor array + imtype (type) -- the desired type of the converted numpy array + """ + if not isinstance(input_image, np.ndarray): + if isinstance(input_image, torch.Tensor): # get the data from a variable + image_tensor = input_image.data + else: + return input_image + image_numpy = image_tensor[0].cpu().float().numpy() # convert it into a numpy array + if image_numpy.shape[0] == 1: # grayscale to RGB + image_numpy = np.tile(image_numpy, (3, 1, 1)) + image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 # post-processing: tranpose and scaling + else: # if it is a numpy array, do nothing + image_numpy = input_image + return image_numpy.astype(imtype) + + +def diagnose_network(net, name='network'): + """Calculate and print the mean of average absolute(gradients) + + Parameters: + net (torch network) -- Torch network + name (str) -- the name of the network + """ + mean = 0.0 + count = 0 + for param in net.parameters(): + if param.grad is not None: + mean += torch.mean(torch.abs(param.grad.data)) + count += 1 + if count > 0: + mean = mean / count + print(name) + print(mean) + + +def save_image(image_numpy, image_path, aspect_ratio=1.0): + """Save a numpy image to the disk + + Parameters: + image_numpy (numpy array) -- input numpy array + image_path (str) -- the path of the image + """ + + image_pil = Image.fromarray(image_numpy) + h, w, _ = image_numpy.shape + + if aspect_ratio > 1.0: + image_pil = image_pil.resize((h, int(w * aspect_ratio)), Image.BICUBIC) + if aspect_ratio < 1.0: + image_pil = image_pil.resize((int(h / aspect_ratio), w), Image.BICUBIC) + image_pil.save(image_path) + + +def print_numpy(x, val=True, shp=False): + """Print the mean, min, max, median, std, and size of a numpy array + + Parameters: + val (bool) -- if print the values of the numpy array + shp (bool) -- if print the shape of the numpy array + """ + x = x.astype(np.float64) + if shp: + print('shape,', x.shape) + if val: + x = x.flatten() + print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % ( + np.mean(x), np.min(x), np.max(x), np.median(x), 
diff --git a/app/service/image2sketch/util/util.py b/app/service/image2sketch/util/util.py
new file mode 100644
index 0000000..b050c13
--- /dev/null
+++ b/app/service/image2sketch/util/util.py
@@ -0,0 +1,103 @@
+"""This module contains simple helper functions """
+from __future__ import print_function
+import torch
+import numpy as np
+from PIL import Image
+import os
+
+
+def tensor2im(input_image, imtype=np.uint8):
+    """Converts a Tensor array into a numpy image array.
+
+    Parameters:
+        input_image (tensor) -- the input image tensor array
+        imtype (type)        -- the desired type of the converted numpy array
+    """
+    if not isinstance(input_image, np.ndarray):
+        if isinstance(input_image, torch.Tensor):  # get the data from a variable
+            image_tensor = input_image.data
+        else:
+            return input_image
+        image_numpy = image_tensor[0].cpu().float().numpy()  # convert it into a numpy array
+        if image_numpy.shape[0] == 1:  # grayscale to RGB
+            image_numpy = np.tile(image_numpy, (3, 1, 1))
+        image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0  # post-processing: transpose and scaling
+    else:  # if it is a numpy array, do nothing
+        image_numpy = input_image
+    return image_numpy.astype(imtype)
+
+
+def diagnose_network(net, name='network'):
+    """Calculate and print the mean of average absolute(gradients)
+
+    Parameters:
+        net (torch network) -- Torch network
+        name (str)          -- the name of the network
+    """
+    mean = 0.0
+    count = 0
+    for param in net.parameters():
+        if param.grad is not None:
+            mean += torch.mean(torch.abs(param.grad.data))
+            count += 1
+    if count > 0:
+        mean = mean / count
+    print(name)
+    print(mean)
+
+
+def save_image(image_numpy, image_path, aspect_ratio=1.0):
+    """Save a numpy image to the disk
+
+    Parameters:
+        image_numpy (numpy array) -- input numpy array
+        image_path (str)          -- the path of the image
+    """
+
+    image_pil = Image.fromarray(image_numpy)
+    h, w, _ = image_numpy.shape
+
+    if aspect_ratio > 1.0:
+        image_pil = image_pil.resize((h, int(w * aspect_ratio)), Image.BICUBIC)
+    if aspect_ratio < 1.0:
+        image_pil = image_pil.resize((int(h / aspect_ratio), w), Image.BICUBIC)
+    image_pil.save(image_path)
+
+
+def print_numpy(x, val=True, shp=False):
+    """Print the mean, min, max, median, std, and size of a numpy array
+
+    Parameters:
+        val (bool) -- if print the values of the numpy array
+        shp (bool) -- if print the shape of the numpy array
+    """
+    x = x.astype(np.float64)
+    if shp:
+        print('shape,', x.shape)
+    if val:
+        x = x.flatten()
+        print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % (
+            np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x)))
+
+
+def mkdirs(paths):
+    """create empty directories if they don't exist
+
+    Parameters:
+        paths (str list) -- a list of directory paths
+    """
+    if isinstance(paths, list) and not isinstance(paths, str):
+        for path in paths:
+            mkdir(path)
+    else:
+        mkdir(paths)
+
+
+def mkdir(path):
+    """create a single empty directory if it didn't exist
+
+    Parameters:
+        path (str) -- a single directory path
+    """
+    if not os.path.exists(path):
+        os.makedirs(path)
diff --git a/app/service/image2sketch/util/visualizer.py b/app/service/image2sketch/util/visualizer.py
new file mode 100644
index 0000000..239c5ee
--- /dev/null
+++ b/app/service/image2sketch/util/visualizer.py
@@ -0,0 +1,223 @@
+import numpy as np
+import os
+import sys
+import ntpath
+import time
+from . import util, html
+from subprocess import Popen, PIPE
+
+
+if sys.version_info[0] == 2:
+    VisdomExceptionBase = Exception
+else:
+    VisdomExceptionBase = ConnectionError
+
+
+def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256):
+    """Save images to the disk.
+
+    Parameters:
+        webpage (the HTML class) -- the HTML webpage class that stores these images (see html.py for more details)
+        visuals (OrderedDict)    -- an ordered dictionary that stores (name, images (either tensor or numpy)) pairs
+        image_path (str)         -- the string is used to create image paths
+        aspect_ratio (float)     -- the aspect ratio of saved images
+        width (int)              -- the images will be resized to width x width
+
+    This function will save images stored in 'visuals' to the HTML file specified by 'webpage'.
+    """
+    image_dir = webpage.get_image_dir()
+    short_path = ntpath.basename(image_path[0])
+    name = os.path.splitext(short_path)[0]
+
+    webpage.add_header(name)
+    ims, txts, links = [], [], []
+
+    for label, im_data in visuals.items():
+        im = util.tensor2im(im_data)
+        image_name = '%s_%s.png' % (name, label)
+        save_path = os.path.join(image_dir, image_name)
+        util.save_image(im, save_path, aspect_ratio=aspect_ratio)
+        ims.append(image_name)
+        txts.append(label)
+        links.append(image_name)
+    webpage.add_images(ims, txts, links, width=width)
+
+
+class Visualizer():
+    """This class includes several functions that can display/save images and print/save logging information.
+
+    It uses a Python library 'visdom' for display, and a Python library 'dominate' (wrapped in 'HTML') for creating HTML files with images.
+    """
+
+    def __init__(self, opt):
+        """Initialize the Visualizer class
+
+        Parameters:
+            opt -- stores all the experiment flags; needs to be a subclass of BaseOptions
+        Step 1: Cache the training/test options
+        Step 2: connect to a visdom server
+        Step 3: create an HTML object for saving HTML files
+        Step 4: create a logging file to store training losses
+        """
+        self.opt = opt  # cache the option
+        self.display_id = opt.display_id
+        self.use_html = opt.isTrain and not opt.no_html
+        self.win_size = opt.display_winsize
+        self.name = opt.name
+        self.port = opt.display_port
+        self.saved = False
+        '''
+        if self.display_id > 0:  # connect to a visdom server given <display_port> and <display_server>
+            import visdom
+            self.ncols = opt.display_ncols
+            self.vis = visdom.Visdom(server=opt.display_server, port=opt.display_port, env=opt.display_env)
+            if not self.vis.check_connection():
+                self.create_visdom_connections()
+        '''
+        if self.use_html:  # create an HTML object at <checkpoints_dir>/web/; images will be saved under <checkpoints_dir>/web/images/
+            self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
+            self.img_dir = os.path.join(self.web_dir, 'images')
+            print('create web directory %s...' % self.web_dir)
+            util.mkdirs([self.web_dir, self.img_dir])
+        # create a logging file to store training losses
+        self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
+        with open(self.log_name, "a") as log_file:
+            now = time.strftime("%c")
+            log_file.write('================ Training Loss (%s) ================\n' % now)
+
+    def reset(self):
+        """Reset the self.saved status"""
+        self.saved = False
+    '''
+    def create_visdom_connections(self):
+        """If the program could not connect to Visdom server, this function will start a new server at port <self.port>"""
+        cmd = sys.executable + ' -m visdom.server -p %d &>/dev/null &' % self.port
+        print('\n\nCould not connect to Visdom server. \n Trying to start a server....')
+        print('Command: %s' % cmd)
+        Popen(cmd, shell=True, stdout=PIPE, stderr=PIPE)
+
+    def display_current_results(self, visuals, epoch, save_result):
+        """Display current results on visdom; save current results to an HTML file.
+
+        Parameters:
+            visuals (OrderedDict) -- dictionary of images to display or save
+            epoch (int)           -- the current epoch
+            save_result (bool)    -- if save the current results to an HTML file
+        """
+        if self.display_id > 0:  # show images in the browser using visdom
+            ncols = self.ncols
+            if ncols > 0:  # show all the images in one visdom panel
+                ncols = min(ncols, len(visuals))
+                h, w = next(iter(visuals.values())).shape[:2]
+                table_css = """<style>
+                        table {border-collapse: separate; border-spacing: 4px; white-space: nowrap; text-align: center}
+                        table td {width: %dpx; height: %dpx; padding: 4px; outline: 4px solid black}
+                        </style>""" % (w, h)  # create a table css
+                # create a table of images.
+                title = self.name
+                label_html = ''
+                label_html_row = ''
+                images = []
+                idx = 0
+                for label, image in visuals.items():
+                    image_numpy = util.tensor2im(image)
+                    label_html_row += '<td>%s</td>' % label
+                    images.append(image_numpy.transpose([2, 0, 1]))
+                    idx += 1
+                    if idx % ncols == 0:
+                        label_html += '<tr>%s</tr>' % label_html_row
+                        label_html_row = ''
+                white_image = np.ones_like(image_numpy.transpose([2, 0, 1])) * 255
+                while idx % ncols != 0:
+                    images.append(white_image)
+                    label_html_row += '<td></td>'
+                    idx += 1
+                if label_html_row != '':
+                    label_html += '<tr>%s</tr>' % label_html_row
+                try:
+                    self.vis.images(images, nrow=ncols, win=self.display_id + 1,
+                                    padding=2, opts=dict(title=title + ' images'))
+                    label_html = '<table>%s</table>' % label_html
+                    self.vis.text(table_css + label_html, win=self.display_id + 2,
+                                  opts=dict(title=title + ' labels'))
+                except VisdomExceptionBase:
+                    self.create_visdom_connections()
+
+            else:  # show each image in a separate visdom panel;
+                idx = 1
+                try:
+                    for label, image in visuals.items():
+                        image_numpy = util.tensor2im(image)
+                        self.vis.image(image_numpy.transpose([2, 0, 1]), opts=dict(title=label),
+                                       win=self.display_id + idx)
+                        idx += 1
+                except VisdomExceptionBase:
+                    self.create_visdom_connections()
+
+        if self.use_html and (save_result or not self.saved):  # save images to an HTML file if they haven't been saved.
+            self.saved = True
+            # save images to the disk
+            for label, image in visuals.items():
+                image_numpy = util.tensor2im(image)
+                img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label))
+                util.save_image(image_numpy, img_path)
+
+            # update website
+            webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=1)
+            for n in range(epoch, 0, -1):
+                webpage.add_header('epoch [%d]' % n)
+                ims, txts, links = [], [], []
+
+                for label, image_numpy in visuals.items():
+                    image_numpy = util.tensor2im(image)
+                    img_path = 'epoch%.3d_%s.png' % (n, label)
+                    ims.append(img_path)
+                    txts.append(label)
+                    links.append(img_path)
+                webpage.add_images(ims, txts, links, width=self.win_size)
+            webpage.save()
+    '''
+    def plot_current_losses(self, epoch, counter_ratio, losses):
+        """Display the current losses on the visdom display: dictionary of error labels and values
+
+        Parameters:
+            epoch (int)           -- current epoch
+            counter_ratio (float) -- progress (percentage) in the current epoch, between 0 and 1
+            losses (OrderedDict)  -- training losses stored in the format of (name, float) pairs
+        """
+        if not hasattr(self, 'plot_data'):
+            self.plot_data = {'X': [], 'Y': [], 'legend': list(losses.keys())}
+        self.plot_data['X'].append(epoch + counter_ratio)
+        self.plot_data['Y'].append([losses[k] for k in self.plot_data['legend']])
+        '''
+        try:
+            self.vis.line(
+                X=np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1),
+                Y=np.array(self.plot_data['Y']),
+                opts={
+                    'title': self.name + ' loss over time',
+                    'legend': self.plot_data['legend'],
+                    'xlabel': 'epoch',
+                    'ylabel': 'loss'},
+                win=self.display_id)
+        except VisdomExceptionBase:
+            self.create_visdom_connections()
+        '''
+    # losses: same format as |losses| of plot_current_losses
+    def print_current_losses(self, epoch, iters, losses, t_comp, t_data):
+        """print current losses on console; also save the losses to the disk
+
+        Parameters:
+            epoch (int)          -- current epoch
+            iters (int)          -- current training iteration during this epoch (reset to 0 at the end of every epoch)
+            losses (OrderedDict) -- training losses stored in the format of (name, float) pairs
+            t_comp (float)       -- computational time per data point (normalized by batch_size)
+            t_data (float)       -- data loading time per data point (normalized by batch_size)
+        """
+        message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, iters, t_comp, t_data)
+        for k, v in losses.items():
+            message += '%s: %.3f ' % (k, v)
+
+        print(message)  # print the message
+        with open(self.log_name, "a") as log_file:
+            log_file.write('%s\n' % message)  # save the message
diff --git a/app/service/image2sketch_2/download_checkpoints.py b/app/service/image2sketch_2/download_checkpoints.py
new file mode 100644
index 0000000..9048c34
--- /dev/null
+++ b/app/service/image2sketch_2/download_checkpoints.py
@@ -0,0 +1,45 @@
+import os
+
+from minio import Minio
+from minio.error import S3Error
+
+MINIO_URL = "www.minio.aida.com.hk:12024"
+MINIO_ACCESS = 'vXKFLSJkYeEq2DrSZvkB'
+MINIO_SECRET = 'uKTZT3x7C43WvPN9QTc99DiRkwddWZrG9Uh3JVlR'
+MINIO_SECURE = True
+# configure the MinIO client
+minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
+
+
+# download function
+def download_folder(bucket_name, folder_name, local_dir):
+    try:
+        # make sure the local directory exists
+        if not os.path.exists(local_dir):
+            os.makedirs(local_dir)
+
+        # iterate over the objects stored in MinIO
+        objects = minio_client.list_objects(bucket_name, prefix=folder_name, recursive=True)
+        for obj in objects:
+            # build the local file path
+            local_file_path = os.path.join(local_dir, obj.object_name[len(folder_name):])
+            local_file_dir = os.path.dirname(local_file_path)
+
+            # make sure the local directory exists
+            if not os.path.exists(local_file_dir):
+                os.makedirs(local_file_dir)
+
+            # download the file
+            minio_client.fget_object(bucket_name, obj.object_name, local_file_path)
+            print(f"Downloaded {obj.object_name} to {local_file_path}")
+
+    except S3Error as e:
+        print(f"Error occurred: {e}")
+
+
+# usage example
+bucket_name = "test"  # replace with your bucket name
+folder_name = "checkpoints/lineart/"  # path of the checkpoint folder
+local_dir = "app/service/image2sketch_2"  # replace with the local directory you want to save to
+
+download_folder(bucket_name, folder_name, local_dir)
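The script above embeds MinIO credentials in source. A common alternative is to read them from the environment so the same script works across deployments and the secrets stay out of version control; the variable names below are illustrative assumptions, not from this patch:

import os

from minio import Minio

minio_client = Minio(
    os.environ["MINIO_URL"],  # hypothetical environment variable names
    access_key=os.environ["MINIO_ACCESS"],
    secret_key=os.environ["MINIO_SECRET"],
    secure=os.environ.get("MINIO_SECURE", "true").lower() == "true",
)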
diff --git a/app/service/image2sketch_2/server.py b/app/service/image2sketch_2/server.py
new file mode 100644
index 0000000..41c0278
--- /dev/null
+++ b/app/service/image2sketch_2/server.py
@@ -0,0 +1,142 @@
+import cv2
+import numpy
+import numpy as np
+import torch
+import torch.nn as nn
+import torchvision.transforms as transforms
+from PIL import Image
+
+from app.service.utils.oss_client import oss_get_image, oss_upload_image
+
+norm_layer = nn.InstanceNorm2d
+
+weights = [(0.7, 0.3), (0.5, 0.5), (0.3, 0.7), (0.1, 0.9), (0, 1)]
+kernel = np.ones((3, 3), np.uint8)
+
+
+class ResidualBlock(nn.Module):
+    def __init__(self, in_features):
+        super(ResidualBlock, self).__init__()
+
+        conv_block = [nn.ReflectionPad2d(1),
+                      nn.Conv2d(in_features, in_features, 3),
+                      norm_layer(in_features),
+                      nn.ReLU(inplace=True),
+                      nn.ReflectionPad2d(1),
+                      nn.Conv2d(in_features, in_features, 3),
+                      norm_layer(in_features)
+                      ]
+
+        self.conv_block = nn.Sequential(*conv_block)
+
+    def forward(self, x):
+        return x + self.conv_block(x)
+
+
+class Generator(nn.Module):
+    def __init__(self, input_nc, output_nc, n_residual_blocks=9, sigmoid=True):
+        super(Generator, self).__init__()
+
+        # Initial convolution block
+        model0 = [nn.ReflectionPad2d(3),
+                  nn.Conv2d(input_nc, 64, 7),
+                  norm_layer(64),
+                  nn.ReLU(inplace=True)]
+        self.model0 = nn.Sequential(*model0)
+
+        # Downsampling
+        model1 = []
+        in_features = 64
+        out_features = in_features * 2
+        for _ in range(2):
+            model1 += [nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
+                       norm_layer(out_features),
+                       nn.ReLU(inplace=True)]
+            in_features = out_features
+            out_features = in_features * 2
+        self.model1 = nn.Sequential(*model1)
+
+        model2 = []
+        # Residual blocks
+        for _ in range(n_residual_blocks):
+            model2 += [ResidualBlock(in_features)]
+        self.model2 = nn.Sequential(*model2)
+
+        # Upsampling
+        model3 = []
+        out_features = in_features // 2
+        for _ in range(2):
+            model3 += [nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),
+                       norm_layer(out_features),
+                       nn.ReLU(inplace=True)]
+            in_features = out_features
+            out_features = in_features // 2
+        self.model3 = nn.Sequential(*model3)
+
+        # Output layer
+        model4 = [nn.ReflectionPad2d(3),
+                  nn.Conv2d(64, output_nc, 7)]
+        if sigmoid:
+            model4 += [nn.Sigmoid()]
+
+        self.model4 = nn.Sequential(*model4)
+
+    def forward(self, x, cond=None):
+        out = self.model0(x)
+        out = self.model1(out)
+        out = self.model2(out)
+        out = self.model3(out)
+        out = self.model4(out)
+
+        return out
+
+
+model1 = Generator(3, 1, 3)
+model1.load_state_dict(torch.load('app/service/image2sketch_2/model.pth', map_location=torch.device('cpu')))
+model1.eval()
+
+
+def predict(input_img, width):
+    transform = transforms.Compose([transforms.Resize(width, Image.BICUBIC), transforms.ToTensor()])
+    input_img = transform(input_img)
+    input_img = torch.unsqueeze(input_img, 0)
+
+    with torch.no_grad():
+        drawing = model1(input_img)[0].detach()
+
+    drawing = transforms.ToPILImage()(drawing)
+
+    # convert to an ndarray
+    drawing = numpy.array(drawing)
+    return drawing
+
+
+def get_image(image_url):
+    image = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url[image_url.find('/') + 1:], data_type="PIL")
+    image = image.convert('RGB')
+    width = image.size[0]
+    height = image.size[1]
+    return image, width, height
+
+
+def processing_pipeline(image_url, thickness, sketch_bucket, sketch_name):
+    thickness = int(thickness)
+    # extract the sketch
+    image, width, height = get_image(image_url)
+    sketch_image = predict(image, width)
+
+    # set the line thickness
+    if thickness != 0:
+        dilated = cv2.erode(sketch_image, kernel, iterations=1)
+        # blend the original with the eroded image (erosion thickens the dark strokes) using the preset weights
+        sketch_image = cv2.addWeighted(sketch_image, weights[thickness][0], dilated, weights[thickness][1], 0)
+
+    # upload to MinIO
+    image_bytes = cv2.imencode(".jpg", sketch_image)[1].tobytes()
+    req = oss_upload_image(bucket=sketch_bucket, object_name=sketch_name, image_bytes=image_bytes)
+    return f"{req.bucket_name}/{req.object_name}"
+
+
+if __name__ == '__main__':
+    result_url = processing_pipeline("aida-users/89/relight_image/d5f0d967-f8e8-424d-98f9-a8ad8313deec-0-89.png", 1, "test", "test123.jpg")
+    print(result_url)
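The thickness control above can look inverted: it erodes, then blends, to get thicker lines. That works because the sketch is dark-on-white and erosion takes the local minimum, so black strokes grow by the kernel radius. A minimal check:

import cv2
import numpy as np

img = np.full((7, 7), 255, np.uint8)  # white canvas
img[:, 3] = 0                         # one-pixel-wide black line
kernel = np.ones((3, 3), np.uint8)
thick = cv2.erode(img, kernel, iterations=1)
print(int((img == 0).sum()), int((thick == 0).sum()))  # 7 -> 21 dark pixels
blended = cv2.addWeighted(img, 0.5, thick, 0.5, 0)     # weights[1] == (0.5, 0.5)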
diff --git a/app/service/lineart/service.py b/app/service/lineart/service.py
new file mode 100644
index 0000000..d822dfa
--- /dev/null
+++ b/app/service/lineart/service.py
@@ -0,0 +1,99 @@
+import logging
+
+import cv2
+import mmcv
+import numpy as np
+import torch
+import torch.nn.functional as F
+import tritonclient.http as httpclient
+
+from app.core.config import DESIGN_MODEL_URL
+from app.schemas.image2sketch import Image2SketchModel
+from app.service.utils.oss_client import oss_get_image, oss_upload_image
+
+logger = logging.getLogger()
+
+
+class LineArtService:
+    def __init__(self, request_item):
+        self.line_style = int(request_item.default_style)
+        self.image_url = request_item.image_url
+        self.sketch_bucket = request_item.sketch_bucket
+        self.sketch_name = request_item.sketch_name
+        self.weights = [(0.7, 0.3), (0.5, 0.5), (0.3, 0.7), (0.1, 0.9), (0, 1)]
+
+    def get_result(self):
+        client = httpclient.InferenceServerClient(url=DESIGN_MODEL_URL)
+        input_image = self.get_image()
+        input_img, ori_shape = self.line_art_preprocess(input_image)
+        transformed_img = input_img.astype(np.float32)
+
+        inputs = [httpclient.InferInput("input__0", transformed_img.shape, datatype="FP32")]
+        inputs[0].set_data_from_numpy(transformed_img, binary_data=True)
+        outputs = [httpclient.InferRequestedOutput("output__0", binary_data=True)]
+        results = client.infer(model_name="lineart", inputs=inputs, outputs=outputs)
+        inference_output1 = results.as_numpy("output__0")
+        line_art_result = self.line_art_postprocess(inference_output1, ori_shape)
+
+        line_art_result = (line_art_result[0] * 255.0).round().astype(np.uint8)
+        if self.line_style != 0:
+            logger.info(self.line_style)
+            kernel = np.ones((3, 3), np.uint8)
+            dilated = cv2.erode(line_art_result, kernel, iterations=1)
+            # blend the original with the eroded image using different weights
+            line_art_result = cv2.addWeighted(line_art_result, self.weights[self.line_style][0], dilated, self.weights[self.line_style][1], 0)
+        # cv2.imshow("", line_art_result)
+        # cv2.waitKey(0)
+        return self.put_image(line_art_result)
+
+    def get_image(self):
+        image = oss_get_image(bucket=self.image_url.split('/')[0], object_name=self.image_url[self.image_url.find('/') + 1:], data_type="cv2")
+        # convert it to a 3-channel color image
+        if len(image.shape) == 3 and image.shape[2] == 4:
+            image = cv2.cvtColor(image, cv2.COLOR_BGRA2BGR)
+        elif len(image.shape) == 2:
+            image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
+        return image
+
+    def put_image(self, image):
+        try:
+            image_bytes = cv2.imencode('.jpg', image)[1].tobytes()
+            oss_upload_image(bucket=self.sketch_bucket, object_name=f"{self.sketch_name}.jpg", image_bytes=image_bytes)
+            return f"{self.sketch_bucket}/{self.sketch_name}.jpg"
+        except Exception as e:
+            logger.warning(e)
+
+    @staticmethod
+    def line_art_preprocess(image):
+        img = mmcv.imread(image)
+        ori_shape = img.shape[:2]
+        img_scale_w, img_scale_h = ori_shape  # ori_shape is (h, w), so img_scale_w holds the height and img_scale_h the width
+        if ori_shape[0] > 1024:
+            img_scale_w = 1024
+        if ori_shape[1] > 1024:
+            img_scale_h = 1024
+        # if either side of the image is larger than 1024, resize that side down to 1024
+        if ori_shape != (img_scale_w, img_scale_h):
+            # mmcv.imresize(img, img_scale_h, img_scale_w)  # old code, a cautionary tale: h and w were swapped
+            img = cv2.resize(img, (img_scale_h, img_scale_w))
+        img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
+        preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
+        return preprocessed_img, ori_shape
+
+    @staticmethod
+    def line_art_postprocess(output, ori_shape):
+        seg_logit = F.interpolate(torch.tensor(output).float(), size=ori_shape, scale_factor=None, mode='bilinear', align_corners=False)
+        seg_pred = seg_logit.cpu().numpy()
+        return seg_pred[0]
+
+
+if __name__ == '__main__':
+    request_item = Image2SketchModel(
+        image_url="aida-collection-element/87/Sketchboard/555a443f-fd6b-4cd7-8147-b92d55513af0.png",
+        default_style="4",
+        sketch_bucket="test",
+        sketch_name="test123"
+    )
+    service = LineArtService(request_item)
+    result_url = service.get_result()
+    print(result_url)
diff --git a/app/service/utils/decorator.py b/app/service/utils/decorator.py
index 294b54b..3e86182 100644
--- a/app/service/utils/decorator.py
+++ b/app/service/utils/decorator.py
@@ -1,5 +1,5 @@
-import time
 import logging
+import time
 
 
 def RunTime(func):
@@ -7,8 +7,22 @@
         t1 = time.time()
         res = func(*args, **kwargs)
        t2 = time.time()
-        if t2 - t1 > 0.05:
-            logging.info(f"function:【{func.__name__}】,runtime:【{str(t2 - t1)}】s")
+        # if t2 - t1 > 0.05:
+        #     logging.info(f"function:【{func.__name__}】,runtime:【{str(t2 - t1)}】s")
+        logging.info(f"function:【{func.__name__}】,runtime:【{str(t2 - t1)}】s")
         return res
 
     return wrapper
+
+
+def ClassCallRunTime(func):
+    def wrapper(*args, **kwargs):
+        start_time = time.time()
+        result = func(*args, **kwargs)
+        end_time = time.time()
+        execution_time = end_time - start_time
+        class_name = args[0].__class__.__name__  # get the class name
+        print(f"class name: {class_name} , run time is : {execution_time} s")
+        return result
+
+    return wrapper
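Both decorators above return a bare wrapper, so the decorated function loses its __name__ and __doc__ for introspection. The standard fix is functools.wraps; a sketch of the same timing decorator with it, offered as an alternative rather than part of this patch:

import functools
import logging
import time

def run_time(func):
    @functools.wraps(func)  # preserves func.__name__, __doc__, etc.
    def wrapper(*args, **kwargs):
        t1 = time.time()
        res = func(*args, **kwargs)
        logging.info(f"function:【{func.__name__}】,runtime:【{time.time() - t1}】s")
        return res
    return wrapper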
diff --git a/app/service/utils/new_oss_client.py b/app/service/utils/new_oss_client.py
new file mode 100644
index 0000000..95a0fbf
--- /dev/null
+++ b/app/service/utils/new_oss_client.py
@@ -0,0 +1,94 @@
+import io
+import logging
+from io import BytesIO
+
+import cv2
+import numpy as np
+import urllib3
+from PIL import Image
+from minio import Minio
+
+from app.core.config import *
+from app.service.utils.decorator import RunTime
+
+minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
+
+
+# custom Retry class
+class CustomRetry(urllib3.Retry):
+    def increment(self, method=None, url=None, response=None, error=None, **kwargs):
+        # call the parent class's increment method
+        new_retry = super(CustomRetry, self).increment(method, url, response, error, **kwargs)
+        # log the retry
+        logger.info(f"retrying connection: {method} {url}, error: {error}, retry count: {self.total - new_retry.total}")
+        return new_retry
+
+
+logger = logging.getLogger()
+timeout = urllib3.Timeout(connect=1, read=10.0)  # 1 s connect timeout, 10 s read timeout
+http_client = urllib3.PoolManager(
+    num_pools=10,  # connection pool size
+    maxsize=10,
+    timeout=timeout,
+    cert_reqs='CERT_REQUIRED',  # require certificate verification
+    retries=CustomRetry(
+        total=5,
+        backoff_factor=0.2,
+        status_forcelist=[500, 502, 503, 504],
+    ),
+)
+
+
+# fetch an image
+@RunTime
+def oss_get_image(oss_client, bucket, object_name, data_type):
+    # cv2 reads all channels by default
+    image_object = None
+    try:
+        image_data = oss_client.get_object(bucket_name=bucket, object_name=object_name)
+        if data_type == "cv2":
+            image_bytes = image_data.read()
+            image_array = np.frombuffer(image_bytes, np.uint8)  # convert to 8-bit unsigned integers
+            image_object = cv2.imdecode(image_array, cv2.IMREAD_UNCHANGED)
+            if image_object.dtype == np.uint16:
+                image_object = (image_object / 256).astype('uint8')
+        else:
+            data_bytes = BytesIO(image_data.read())
+            image_object = Image.open(data_bytes)
+    except Exception as e:
+        logger.warning(f"{OSS} | exception while fetching image ######: {e}")
+    return image_object
+
+
+@RunTime
+def oss_upload_image(oss_client, bucket, object_name, image_bytes):
+    req = None
+    try:
+        req = oss_client.put_object(bucket_name=bucket, object_name=object_name, data=io.BytesIO(image_bytes), length=len(image_bytes), content_type='image/png')
+    except Exception as e:
+        logger.warning(f"{OSS} | exception while uploading image ######: {e}")
+    return req
+
+
+if __name__ == '__main__':
+    # url = "aida-results/result_0002186a-e631-11ee-86a6-b48351119060.png"
+    # url = "aida-collection-element/11523/Moodboard/f60af0d2-94c2-48f9-90ff-74b8e8a481b5.jpg"
+    # url = "aida-sys-image/images/female/outwear/0628000054.jpg"
+    # url = "aida-users/89/product_image/string-89.png"
+    # url = "test/845046c7-4f62-4f54-a4a9-c26d49c6969335b5b3a9-d335-4871-a46c-3cc3caf07da259629dfd1f1f555a2e2a9def7e719366.png"
+    # url = 'aida-users/89/relight_image/123-89.png'
+    # url = 'aida-users/89/relight_image/123-89.png'
+    # url = 'aida-users/89/relight_image/123-89.png'
+    # url = "aida-users/89/sketchboard/female/Dress/e6724ab7-8d3f-4677-abe0-c3e42ab7af85.jpeg"
+    # url = "aida-users/87/print/956614a2-7e75-4fbe-9ed0-c1831e37a2c9-4-87.png"
+    # url = "aida-users/89/single_logo/123-89.png"
+    url = "aida-users/31/sketchboard/female/dress/6edcbf92-7da9-4809-a0a8-a4b4f06dec1e0628000041.jpg"
+    # url = "aida-collection-element/12148/Sketchboard/95ea577b-305b-4a62-b30a-39c0dd3ddb3f.png"
+    read_type = "cv2"
+    if read_type == "cv2":
+        img = oss_get_image(oss_client=minio_client, bucket=url.split('/')[0], object_name=url[url.find('/') + 1:], data_type=read_type)
+        cv2.imshow("", img)
+        cv2.waitKey(0)
+    else:
+        img = oss_get_image(oss_client=minio_client, bucket=url.split('/')[0], object_name=url[url.find('/') + 1:], data_type=read_type)
+        img.show()
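CustomRetry above only adds logging; the backoff itself comes from urllib3's Retry, which sleeps roughly backoff_factor * 2 ** (n - 1) seconds before the n-th consecutive retry (the exact treatment of the first retry varies between urllib3 versions). With backoff_factor=0.2:

BACKOFF_FACTOR = 0.2  # matches the Retry configuration above

for n in range(1, 6):
    print(n, BACKOFF_FACTOR * 2 ** (n - 1))  # 0.2 0.4 0.8 1.6 3.2 seconds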
b/app/service/utils/oss_client.py
@@ -1,16 +1,38 @@
 import io
 import logging
 from io import BytesIO
-
-import boto3
 import cv2
 import numpy as np
+import urllib3
 from PIL import Image
 from minio import Minio
 
 from app.core.config import *
+
+
+# custom Retry class
+class CustomRetry(urllib3.Retry):
+    def increment(self, method=None, url=None, response=None, error=None, **kwargs):
+        # call the parent class's increment method
+        new_retry = super(CustomRetry, self).increment(method, url, response, error, **kwargs)
+        # log the retry
+        logger.info(f"retrying connection: {method} {url}, error: {error}, retry count: {self.total - new_retry.total}")
+        return new_retry
+
+
 logger = logging.getLogger()
+timeout = urllib3.Timeout(connect=1, read=10.0)  # 1 s connect timeout, 10 s read timeout
+http_client = urllib3.PoolManager(
+    num_pools=10,  # connection pool size
+    maxsize=10,
+    timeout=timeout,
+    cert_reqs='CERT_REQUIRED',  # require certificate verification
+    retries=CustomRetry(
+        total=5,
+        backoff_factor=0.2,
+        status_forcelist=[500, 502, 503, 504],
+    ),
+)
 
 
 # fetch an image
@@ -18,12 +40,8 @@ def oss_get_image(bucket, object_name, data_type):
     # cv2 reads all channels by default
     image_object = None
     try:
-        if OSS == "minio":
-            oss_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
-            image_data = oss_client.get_object(bucket_name=bucket, object_name=object_name)
-        else:
-            oss_client = boto3.client('s3', aws_access_key_id=S3_ACCESS_KEY, aws_secret_access_key=S3_AWS_SECRET_ACCESS_KEY, region_name=S3_REGION_NAME)
-            image_data = oss_client.get_object(Bucket=bucket, Key=object_name)['Body']
+        oss_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE, http_client=http_client)
+        image_data = oss_client.get_object(bucket_name=bucket, object_name=object_name)
         if data_type == "cv2":
             image_bytes = image_data.read()
             image_array = np.frombuffer(image_bytes, np.uint8)  # convert to 8-bit unsigned integers
@@ -41,12 +59,8 @@ def oss_get_image(bucket, object_name, data_type):
 def oss_upload_image(bucket, object_name, image_bytes):
     req = None
     try:
-        if OSS == "minio":
-            oss_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
-            req = oss_client.put_object(bucket_name=bucket, object_name=object_name, data=io.BytesIO(image_bytes), length=len(image_bytes), content_type='image/png')
-        else:
-            oss_client = boto3.client('s3', aws_access_key_id=S3_ACCESS_KEY, aws_secret_access_key=S3_AWS_SECRET_ACCESS_KEY, region_name=S3_REGION_NAME)
-            req = oss_client.put_object(Bucket=bucket, Key=object_name, Body=io.BytesIO(image_bytes), ContentType='image/png')
+        oss_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
+        req = oss_client.put_object(bucket_name=bucket, object_name=object_name, data=io.BytesIO(image_bytes), length=len(image_bytes), content_type='image/png')
     except Exception as e:
         logger.warning(f"{OSS} | exception while uploading image ######: {e}")
     return req
@@ -64,8 +78,8 @@ if __name__ == '__main__':
     # url = "aida-users/89/sketchboard/female/Dress/e6724ab7-8d3f-4677-abe0-c3e42ab7af85.jpeg"
     # url = "aida-users/87/print/956614a2-7e75-4fbe-9ed0-c1831e37a2c9-4-87.png"
     # url = "aida-users/89/single_logo/123-89.png"
-    # url = "aida-users/89/product_image/string-89.png"
-    url = "aida-collection-element/12148/Sketchboard/95ea577b-305b-4a62-b30a-39c0dd3ddb3f.png"
+    url = "aida-results/result_e2673d92-8d25-11ef-be24-0826ae3ad6b3.png"
+    # url = "aida-collection-element/12148/Sketchboard/95ea577b-305b-4a62-b30a-39c0dd3ddb3f.png"
     read_type = "cv2"
     if read_type == "cv2":
         img = oss_get_image(bucket=url.split('/')[0], object_name=url[url.find('/') + 1:], data_type=read_type)
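Note that oss_upload_image still builds a fresh Minio client (without the pooled http_client) on every call, while oss_get_image now reuses the shared pool. Caching one module-level client for both helpers, as new_oss_client.py already does, would avoid a TLS handshake per upload; a sketch (the helper name is illustrative):

from minio import Minio

from app.core.config import MINIO_ACCESS, MINIO_SECRET, MINIO_SECURE, MINIO_URL

_SHARED_CLIENT = Minio(MINIO_URL, access_key=MINIO_ACCESS,
                       secret_key=MINIO_SECRET, secure=MINIO_SECURE)

def get_oss_client() -> Minio:
    # one client per process, shared by the upload and download helpers
    return _SHARED_CLIENT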
        img.show()
diff --git a/requirements.txt b/requirements.txt
index 1bfec2b..6c9e38f 100644
Binary files a/requirements.txt and b/requirements.txt differ