Merge branch 'refs/heads/develop'

# Conflicts:
#	app/service/design/service.py
#	app/service/design_pre_processing/service.py
#	app/service/utils/oss_client.py
#	requirements.txt
This commit is contained in:
zhouchengrong
2024-10-20 09:36:38 +08:00
95 changed files with 8936 additions and 1062 deletions

6
.gitignore vendored
View File

@@ -120,10 +120,10 @@ dmypy.json
#runtime produce
test
seg_cache
logs
seg_result/
seg_result
*.png
uwsgi
*.yaml
*.yml
@@ -133,5 +133,7 @@ Dockerfile
app/logs
app/logs/*
*.log
*.jpg
/qodana.yaml
.pth
.pytorch
*.png

59
app/api/api_brighten.py Normal file
View File

@@ -0,0 +1,59 @@
import io
import json
import logging
import time
from PIL import ImageEnhance
from fastapi import APIRouter, HTTPException
from app.schemas.brighten import BrightenModel
from app.schemas.response_template import ResponseModel
from app.service.utils.oss_client import oss_get_image, oss_upload_image
router = APIRouter()
logger = logging.getLogger()
def increase_brightness(img, factor):
    """Apply a PIL brightness enhancement to ``img`` and return the result.

    Args:
        img: Image handed to ``ImageEnhance.Brightness``.
        factor: Enhancement factor passed to ``enhance``; the endpoint
            documentation in this file describes 1.0 as the original
            brightness and 1.5 as 50% brighter.

    Returns:
        Whatever ``ImageEnhance.Brightness(img).enhance(factor)`` returns.
    """
    return ImageEnhance.Brightness(img).enhance(factor)
@router.post("/brighten")
async def brighten(request_item: BrightenModel):
    """
    Brighten an image stored in OSS and upload the brightened PNG.

    Create a request body with the following fields:
    - **image_url**: URL of the image to brighten
    - **brighten_value**: brightness factor; 1.0 is the original brightness, 1.5 increases brightness by 50%

    Example body:
    {
        "image_url": "aida-users/89/relight_image/3850e17b-3efd-4597-90ef-2a7bcd1a1a0b-0-89.png",
        "brighten_value": 1.5
    }
    """
    try:
        started_at = time.time()
        logger.info(f"brighten request item is : @@@@@@:{json.dumps(request_item.dict())}")

        # Text before the first '/' is the bucket; everything after it is the object name.
        source_url = request_item.image_url
        bucket_name = source_url.split('/')[0]
        object_key = source_url[source_url.find('/') + 1:]

        source_image = oss_get_image(bucket=bucket_name, object_name=object_key, data_type="PIL")
        brightened = increase_brightness(source_image, request_item.brighten_value)

        # Serialise the result as PNG entirely in memory.
        png_buffer = io.BytesIO()
        brightened.save(png_buffer, format='PNG')

        # NOTE(review): the upload targets the same bucket/object the source was read from.
        upload_result = oss_upload_image(bucket=bucket_name, object_name=object_key, image_bytes=png_buffer.getvalue())
        brighten_url = f"{upload_result.bucket_name}/{upload_result.object_name}"
        logger.info(f"run time is : {time.time() - started_at}")
    except Exception as e:
        # Every failure is logged and reported to the client as a 404 carrying the error text.
        logger.warning(f"brighten Run Exception @@@@@@:{e}")
        raise HTTPException(status_code=404, detail=str(e))
    return ResponseModel(data=brighten_url)
if __name__ == '__main__':
    # Manual smoke test: fetch one sample image from OSS, brighten it by the
    # same factor as the endpoint example, and open it in the local viewer.
    # Nothing is uploaded from this path.
    request_item = BrightenModel(image_url="aida-users/89/relight_image/3850e17b-3efd-4597-90ef-2a7bcd1a1a0b-0-89.png",
                                 brighten_value=1.5)
    # Bucket is the text before the first '/', object name is everything after it.
    image = oss_get_image(bucket=request_item.image_url.split('/')[0], object_name=request_item.image_url[request_item.image_url.find('/') + 1:], data_type="PIL")
    new_image = increase_brightness(image, request_item.brighten_value)
    new_image.show()

View File

@@ -1,13 +1,15 @@
import json
import logging
import os
from fastapi import APIRouter, HTTPException
from fastapi import APIRouter, HTTPException, UploadFile, File, Form
from app.schemas.design import DesignModel, DesignProgressModel, ModelProgressModel
from app.schemas.design import DesignModel, DesignProgressModel, ModelProgressModel, DBGConfigModel
from app.schemas.response_template import ResponseModel
from app.service.design.model_process_service import model_transpose
from app.service.design.service import generate
from app.service.design.utils.redis_utils import Redis
from app.service.design_batch.service import start_design_batch_generate
from app.service.design_fast.design_generate import design_generate
from app.service.design_fast.utils.redis_utils import Redis
router = APIRouter()
logger = logging.getLogger()
@@ -24,28 +26,28 @@ def design(request_data: DesignModel):
"basic": {
"body_point_test": {
"waistband_right": [
203,
249
200,
241
],
"hand_point_right": [
229,
343
223,
297
],
"waistband_left": [
119,
248
112,
241
],
"hand_point_left": [
97,
343
92,
305
],
"shoulder_left": [
108,
107
99,
116
],
"shoulder_right": [
212,
107
215,
116
]
},
"layer_order": true,
@@ -57,65 +59,33 @@ def design(request_data: DesignModel):
},
"items": [
{
"businessId": 255303,
"color": "139 148 156",
"image_id": 95159,
"businessId": 270372,
"color": "30 28 28",
"image_id": 69780,
"offset": [
0,
0
],
"path": "aida-users/89/sketch/c89d75f3-581f-4edd-9f8e-b08e84a2cbe7-3-89.png",
"path": "aida-sys-image/images/female/trousers/0825000630.jpg",
"seg_mask_url": "test/result.png",
"print": {
"single": {
"location": [
[
200.0,
200.0
]
],
"print_angle_list": [
0.0
],
"print_path_list": [
"aida-users/89/slogan_image/ce0b2423-9e5a-466f-9611-c254940a7819-1-89.png"
],
"print_scale_list": [
1.0
]
"element": {
"element_angle_list": [],
"element_path_list": [],
"element_scale_list": [],
"location": []
},
"overall": {
"location": [
[
512.0,
512.0
]
],
"print_angle_list": [
0.0
],
"print_path_list": [
"aida-users/89/print/468643b4-bc2d-41b2-9a16-79766606a2db-3-89.png"
],
"print_scale_list": [
1.0
]
"location": [],
"print_angle_list": [],
"print_path_list": [],
"print_scale_list": []
},
"element": {
"element_angle_list": [
0.0
],
"element_path_list": [
"aida-users/88/designelements/Embroidery/a4d9605a-675e-4606-93e0-77ca6baaf55f.png"
],
"element_scale_list": [
0.2731036750637755
],
"location": [
[
228.63694825464364,
406.4843844199667
]
]
"single": {
"location": [],
"print_angle_list": [],
"print_path_list": [],
"print_scale_list": []
}
},
"priority": 10,
@@ -123,22 +93,101 @@ def design(request_data: DesignModel):
1.0,
1.0
],
"type": "Dress"
"type": "Trousers"
},
{
"body_path": "aida-sys-image/models/female/2e4815b9-1191-419d-94ed-5771239ca4a5.png",
"image_id": 67277,
"businessId": 270373,
"color": "30 28 28",
"image_id": 98243,
"offset": [
0,
0
],
"path": "aida-sys-image/images/female/blouse/0902003811.jpg",
"seg_mask_url": "test/result.png",
"print": {
"element": {
"element_angle_list": [],
"element_path_list": [],
"element_scale_list": [],
"location": []
},
"overall": {
"location": [],
"print_angle_list": [],
"print_path_list": [],
"print_scale_list": []
},
"single": {
"location": [],
"print_angle_list": [],
"print_path_list": [],
"print_scale_list": []
}
},
"priority": 11,
"resize_scale": [
1.0,
1.0
],
"type": "Blouse"
},
{
"businessId": 270374,
"color": "172 68 68",
"image_id": 98244,
"offset": [
0,
0
],
"path": "aida-sys-image/images/female/outwear/0825000410.jpg",
"seg_mask_url": "test/result.png",
"print": {
"element": {
"element_angle_list": [],
"element_path_list": [],
"element_scale_list": [],
"location": []
},
"overall": {
"location": [],
"print_angle_list": [],
"print_path_list": [],
"print_scale_list": []
},
"single": {
"location": [],
"print_angle_list": [],
"print_path_list": [],
"print_scale_list": []
}
},
"priority": 12,
"resize_scale": [
1.0,
1.0
],
"type": "Outwear"
},
{
"body_path": "aida-sys-image/models/female/5bdfe7ca-64eb-44e4-b03d-8e517520c795.png",
"image_id": 96090,
"type": "Body"
}
]
}
],
"process_id": "89"
"process_id": "83"
}
"""
# logger.info(f"design request item is : @@@@@@:{json.dumps(request_data.dict())}")
# data = generate(request_data=request_data)
# logger.info(f"design response @@@@@@:{json.dumps(data)}")
#
try:
logger.info(f"design request item is : @@@@@@:{json.dumps(request_data.dict())}")
data = generate(request_data=request_data)
data = design_generate(request_data=request_data)
logger.info(f"design response @@@@@@:{json.dumps(data)}")
except Exception as e:
logger.warning(f"design Run Exception @@@@@@:{e}")
@@ -193,3 +242,36 @@ def model_process(request_data: ModelProgressModel):
logger.warning(f"model_process Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))
return ResponseModel(data=data)
# ##############################################################
@router.post("/design_batch_generate")
async def design(file: UploadFile = File(...),
                 tasks_id: str = Form(...),
                 user_id: str = Form(...),
                 file_name: str = Form(...),
                 total: int = Form(...)
                 ):
    """Accept a batch-generation upload, keep a copy on disk and start the batch job.

    The form fields are packed into a ``DBGConfigModel``; the uploaded file is
    read fully into memory, written out via ``save_request_file`` and then
    handed, together with the config, to ``start_design_batch_generate``.

    Returns:
        Whatever ``start_design_batch_generate`` returns.
    """
    # NOTE(review): this handler reuses the function name `design`, which the
    # /design handler earlier in this module also uses — confirm that the
    # shadowing is intentional.
    form_fields = {
        "tasks_id": tasks_id,
        "user_id": user_id,
        "file_name": file_name,
        "total": total,
    }
    batch_config = DBGConfigModel(**form_fields)

    payload = await file.read()
    # The copy on disk is named after the uploaded file itself, not the
    # `file_name` form field stored in the config.
    await save_request_file(payload, file.filename)
    return await start_design_batch_generate(batch_config, payload)
async def save_request_file(contents, file_name):
    """Write an uploaded request payload to ``<cwd>/design_batch/request_data``.

    Args:
        contents: Raw bytes to write.
        file_name: Client-supplied file name. Only its final path component
            is used, so directory parts such as ``../`` cannot move the file
            outside the save directory.

    Raises:
        ValueError: If ``file_name`` has no usable final component
            (``None``, empty, ``.`` or ``..``).
        OSError: If the directory or the file cannot be created.
    """
    # Create the destination directory if it does not exist yet; exist_ok
    # avoids the check-then-create race between concurrent requests.
    save_dir = os.path.join(os.getcwd(), "design_batch", "request_data")
    os.makedirs(save_dir, exist_ok=True)

    # The name is untrusted input: normalise Windows separators and keep only
    # the last component so it can never traverse out of save_dir.
    safe_name = os.path.basename(str(file_name or "").replace("\\", "/"))
    if safe_name in ("", ".", ".."):
        raise ValueError(f"invalid upload file name: {file_name!r}")

    # Write the payload.
    file_path = os.path.join(save_dir, safe_name)
    with open(file_path, "wb") as f:
        f.write(contents)

View File

@@ -0,0 +1,38 @@
import json
import logging
from fastapi import APIRouter, HTTPException
from app.schemas.image2sketch import Image2SketchModel
from app.schemas.response_template import ResponseModel
from app.service.lineart.service import LineArtService
router = APIRouter()
logger = logging.getLogger()
@router.post("/image2sketch")
def image2sketch(request_item: Image2SketchModel):
    """
    Run the line-art service on an image and return the resulting sketch URL.

    Create a request body with the following fields:
    - **image_url**: URL of the image to extract the sketch from
    - **default_style**: original, 1, 2, 3, 4, 5
    - **sketch_bucket**: bucket the sketch is saved to
    - **sketch_name**: object name the sketch is saved under

    Example body:
    {
        "image_url": "test/image2sketch/real_Dress_3200fecdc83d0c556c2bd96aedbd7fbf.jpg_Img.jpg",
        "default_style": 0,
        "sketch_bucket": "test",
        "sketch_name": "image2sketch/area_fill_img.png"
    }
    """
    try:
        logger.info(f"image2sketch request item is : @@@@@@:{json.dumps(request_item.dict())}")
        result_url = LineArtService(request_item).get_result()
    except Exception as e:
        # Every failure is logged and reported to the client as a 404 carrying the error text.
        logger.warning(f"image2sketch Run Exception @@@@@@:{e}")
        raise HTTPException(status_code=404, detail=str(e))
    return ResponseModel(data=result_url)

View File

@@ -1,14 +1,15 @@
from fastapi import APIRouter
from app.api import api_test
from app.api import api_super_resolution
from app.api import api_generate_image
from app.api import api_attribute_retrieve
from app.api import api_design
from app.api import api_brighten
from app.api import api_chat_robot
from app.api import api_prompt_generation
from app.api import api_design
from app.api import api_design_pre_processing
from app.api import api_generate_image
from app.api import api_image2sketch
from app.api import api_prompt_generation
from app.api import api_super_resolution
from app.api import api_test
router = APIRouter()
@@ -20,3 +21,5 @@ router.include_router(api_design.router, tags=['design'], prefix="/api")
router.include_router(api_chat_robot.router, tags=['chat_robot'], prefix="/api")
router.include_router(api_prompt_generation.router, tags=['prompt_generation'], prefix="/api")
router.include_router(api_design_pre_processing.router, tags=['design_pre_processing'], prefix="/api")
router.include_router(api_image2sketch.router, tags=['api_image2sketch'], prefix="/api")
router.include_router(api_brighten.router, tags=['api_brighten'], prefix="/api")

View File

@@ -24,11 +24,11 @@ DEBUG = False
if DEBUG:
LOGS_PATH = "logs/"
CATEGORY_PATH = "service/attribute/config/descriptor/category/category_dis.csv"
# FACE_CLASSIFIER = "service/generate_image/utils/haarcascade_frontalface_alt.xml"
SEG_CACHE_PATH = "../seg_cache/"
else:
LOGS_PATH = "app/logs/"
CATEGORY_PATH = "app/service/attribute/config/descriptor/category/category_dis.csv"
# FACE_CLASSIFIER = 'app/service/generate_image/utils/haarcascade_frontalface_alt.xml'
SEG_CACHE_PATH = "/seg_cache/"
# RABBITMQ_ENV = "" # 生产环境
RABBITMQ_ENV = "-dev" # 开发环境
@@ -64,7 +64,7 @@ RABBITMQ_PARAMS = {
MILVUS_URL = "http://10.1.1.240:19530"
MILVUS_TOKEN = "root:Milvus"
MILVUS_ALIAS = "default"
MILVUS_TABLE_KEYPOINT = "keypoint_cache"
MILVUS_TABLE_KEYPOINT = "keypoint_cache_2"
MILVUS_TABLE_SEG = "seg_cache"
# Mysql 配置

View File

@@ -0,0 +1,90 @@
{
"objects": [
{
"basic": {
"body_point_test": {
"waistband_right": [
201,
242
],
"hand_point_right": [
222,
312
],
"waistband_left": [
114,
243
],
"hand_point_left": [
94,
310
],
"shoulder_left": [
102,
116
],
"shoulder_right": [
211,
115
]
},
"layer_order": true,
"scale_bag": 0.7,
"scale_earrings": 0.16,
"self_template": true,
"single_overall": "overall",
"switch_category": ""
},
"items": [
{
"businessId": 264931,
"color": "145 220 232",
"image_id": 96844,
"offset": [
0,
0
],
"path": "aida-users/87/sketch/2aa7aad5-74bb-41fa-9cdf-f06611b3e89a-2-87.png",
"print": {
"element": {
"element_angle_list": [],
"element_path_list": [],
"element_scale_list": [],
"location": []
},
"overall": {
"location": [],
"print_angle_list": [],
"print_path_list": [],
"print_scale_list": []
},
"single": {
"location": [],
"print_angle_list": [],
"print_path_list": [],
"print_scale_list": []
}
},
"priority": 10,
"resize_scale": [
1.0,
1.0
],
"type": "Dress"
},
{
"body_path": "aida-sys-image/models/female/79805ec3-3f01-466d-91e0-36028d079699.png",
"image_id": 95444,
"type": "Body"
}
]
}
],
"process_id": "87",
"tasks_id": ,
}
//用 openai jsonl
//

6
app/schemas/brighten.py Normal file
View File

@@ -0,0 +1,6 @@
from pydantic import BaseModel
class BrightenModel(BaseModel):
    """Request body for the ``/brighten`` endpoint."""
    # Location of the image to brighten; the endpoint treats the text before
    # the first '/' as the bucket and the remainder as the object name.
    image_url: str
    # Brightness factor: 1.0 keeps the original brightness, 1.5 is 50% brighter.
    brighten_value: float

View File

@@ -1,50 +1,6 @@
from pydantic import BaseModel
# class BodyPointModel(BaseModel):
# waistband_right: list[int]
# hand_point_right: list[int]
# waistband_left: list[int]
# hand_point_left: list[int]
# shoulder_left: list[int]
# shoulder_right: list[int]
#
#
# class BasicModel(BaseModel):
# body_point: BodyPointModel
# layer_order: bool
# scale_bag: float
# scale_earrings: float
# self_template: bool
# single_overall: str
# switch_category: str
# body_path: str
#
#
# class PrintModel(BaseModel):
# if_single: bool
# print_path_list: list[str]
#
#
# class ItemModel(BaseModel):
# color: str
# image_id: str
# offset: list[int]
# path: str
# print: PrintModel
# resize_scale: float
# type: str
#
#
# class CollocationModel(BaseModel):
# basic: BasicModel
# item: list[ItemModel]
#
#
# class DesignModel(BaseModel):
# object: list[CollocationModel]
# process_id: str
class DesignModel(BaseModel):
    """Request body accepted by the ``design`` endpoint."""
    # Free-form dictionaries, one per object to generate; their inner
    # structure is not validated by this schema.
    objects: list[dict]
    # Identifier sent by the caller as a string (e.g. "83" in the endpoint example).
    process_id: str
@@ -56,3 +12,10 @@ class DesignProgressModel(BaseModel):
class ModelProgressModel(BaseModel):
    """Request body accepted by the ``model_process`` endpoint."""
    # Path of the model to process — presumably a "<bucket>/<object>" location
    # like the other paths in this API; TODO confirm against the service.
    model_path: str
class DBGConfigModel(BaseModel):
    """Configuration of a design batch generation request.

    Built by the ``/design_batch_generate`` endpoint from its form fields of
    the same names.
    """
    tasks_id: str
    user_id: str
    # Value of the `file_name` form field (not the uploaded file's own name).
    file_name: str
    # NOTE(review): presumably the number of designs in the batch — confirm.
    total: int

View File

@@ -0,0 +1,8 @@
from pydantic import BaseModel
class Image2SketchModel(BaseModel):
    """Request body for the ``/image2sketch`` endpoint."""
    # URL of the image to extract the sketch from.
    image_url: str
    # Style selector (original, 1-5 per the endpoint documentation).
    # NOTE(review): the endpoint's example body passes the integer 0 here
    # while the field is declared as str — confirm the validator accepts it.
    default_style: str
    # Bucket the sketch is saved to.
    sketch_bucket: str
    # Object name the sketch is saved under.
    sketch_name: str

View File

@@ -1,771 +0,0 @@
{
"objects": [
{
"basic": {
"body_point_test": {
"waistband_right": [
336,
264
],
"hand_point_right": [
350,
303
],
"waistband_left": [
245,
274
],
"hand_point_left": [
219,
315
],
"shoulder_left": [
227,
155
],
"shoulder_right": [
338,
149
]
},
"layer_order": false,
"scale_bag": 0.7,
"scale_earrings": 0.16,
"self_template": true,
"single_overall": "overall",
"switch_category": ""
},
"items": [
{
"businessId": 493827,
"color": "127 61 21",
"elementId": 493827,
"icon": "none",
"image_id": 110201,
"offset": [
1,
1
],
"path": "aida-users/31/sketch/62302527-2910-4740-808d-2cb8221daa34-3-31.png",
"print": {
"IfSingle": false,
"print_path_list": []
},
"resize_scale": [
1.0,
1.0
],
"type": "Dress"
},
{
"body_path": "aida-users/31/models/female/845046c7-4f62-4f54-a4a9-c26d49c6969335b5b3a9-d335-4871-a46c-3cc3caf07da259629dfd1f1f555a2e2a9def7e719366.png",
"image_id": 82966,
"offset": [
1,
1
],
"resize_scale": [
1.0,
1.0
],
"type": "Body"
}
]
},
{
"basic": {
"body_point_test": {
"waistband_right": [
336,
264
],
"hand_point_right": [
350,
303
],
"waistband_left": [
245,
274
],
"hand_point_left": [
219,
315
],
"shoulder_left": [
227,
155
],
"shoulder_right": [
338,
149
]
},
"layer_order": false,
"scale_bag": 0.7,
"scale_earrings": 0.16,
"self_template": true,
"single_overall": "overall",
"switch_category": ""
},
"items": [
{
"color": "27 25 23",
"icon": "none",
"image_id": 110202,
"offset": [
1,
1
],
"path": "aida-sys-image/images/female/skirt/0916000602.jpg",
"print": {
"IfSingle": false,
"print_path_list": []
},
"resize_scale": [
1.0,
1.0
],
"type": "Skirt"
},
{
"businessId": 493825,
"color": "229 214 200",
"elementId": 493825,
"icon": "none",
"image_id": 107101,
"offset": [
1,
1
],
"path": "aida-users/31/sketchboard/female/Blouse/de8f5656-d7ae-4642-bc90-f7f9d85da09b.jpg",
"print": {
"IfSingle": false,
"print_path_list": []
},
"resize_scale": [
1.0,
1.0
],
"type": "Blouse"
},
{
"businessId": 493824,
"color": "76 124 124",
"elementId": 493824,
"icon": "none",
"image_id": 104522,
"offset": [
1,
1
],
"path": "aida-users/31/sketch/3e82214a-0191-11ef-96d2-b48351119060_1.png",
"print": {
"IfSingle": false,
"print_path_list": []
},
"resize_scale": [
1.0,
1.0
],
"type": "Outwear"
},
{
"body_path": "aida-users/31/models/female/845046c7-4f62-4f54-a4a9-c26d49c6969335b5b3a9-d335-4871-a46c-3cc3caf07da259629dfd1f1f555a2e2a9def7e719366.png",
"image_id": 82966,
"offset": [
1,
1
],
"resize_scale": [
1.0,
1.0
],
"type": "Body"
}
]
},
{
"basic": {
"body_point_test": {
"waistband_right": [
336,
264
],
"hand_point_right": [
350,
303
],
"waistband_left": [
245,
274
],
"hand_point_left": [
219,
315
],
"shoulder_left": [
227,
155
],
"shoulder_right": [
338,
149
]
},
"layer_order": false,
"scale_bag": 0.7,
"scale_earrings": 0.16,
"self_template": true,
"single_overall": "overall",
"switch_category": ""
},
"items": [
{
"color": "229 214 200",
"icon": "none",
"image_id": 110203,
"offset": [
1,
1
],
"path": "aida-sys-image/images/female/blouse/0825001576.jpg",
"print": {
"IfSingle": false,
"print_path_list": []
},
"resize_scale": [
1.0,
1.0
],
"type": "Blouse"
},
{
"color": "76 124 124",
"icon": "none",
"image_id": 96071,
"offset": [
1,
1
],
"path": "aida-sys-image/images/female/skirt/903000097.jpg",
"print": {
"IfSingle": false,
"print_path_list": []
},
"resize_scale": [
1.0,
1.0
],
"type": "Skirt"
},
{
"color": "209 125 29",
"icon": "none",
"image_id": 93798,
"offset": [
1,
1
],
"path": "aida-sys-image/images/female/outwear/outwear_p4_561.jpg",
"print": {
"IfSingle": false,
"print_path_list": []
},
"resize_scale": [
1.0,
1.0
],
"type": "Outwear"
},
{
"body_path": "aida-users/31/models/female/845046c7-4f62-4f54-a4a9-c26d49c6969335b5b3a9-d335-4871-a46c-3cc3caf07da259629dfd1f1f555a2e2a9def7e719366.png",
"image_id": 82966,
"offset": [
1,
1
],
"resize_scale": [
1.0,
1.0
],
"type": "Body"
}
]
},
{
"basic": {
"body_point_test": {
"waistband_right": [
336,
264
],
"hand_point_right": [
350,
303
],
"waistband_left": [
245,
274
],
"hand_point_left": [
219,
315
],
"shoulder_left": [
227,
155
],
"shoulder_right": [
338,
149
]
},
"layer_order": false,
"scale_bag": 0.7,
"scale_earrings": 0.16,
"self_template": true,
"single_overall": "overall",
"switch_category": ""
},
"items": [
{
"businessId": 493824,
"color": "209 125 29",
"elementId": 493824,
"icon": "none",
"image_id": 104522,
"offset": [
1,
1
],
"path": "aida-users/31/sketch/3e82214a-0191-11ef-96d2-b48351119060_1.png",
"print": {
"IfSingle": false,
"print_path_list": []
},
"resize_scale": [
1.0,
1.0
],
"type": "Outwear"
},
{
"color": "118 123 115",
"icon": "none",
"image_id": 110204,
"offset": [
1,
1
],
"path": "aida-sys-image/images/female/blouse/0902000457.jpg",
"print": {
"IfSingle": false,
"print_path_list": []
},
"resize_scale": [
1.0,
1.0
],
"type": "Blouse"
},
{
"color": "118 123 115",
"icon": "none",
"image_id": 79259,
"offset": [
1,
1
],
"path": "aida-sys-image/images/female/trousers/826000094.jpg",
"print": {
"IfSingle": false,
"print_path_list": []
},
"resize_scale": [
1.0,
1.0
],
"type": "Trousers"
},
{
"body_path": "aida-users/31/models/female/845046c7-4f62-4f54-a4a9-c26d49c6969335b5b3a9-d335-4871-a46c-3cc3caf07da259629dfd1f1f555a2e2a9def7e719366.png",
"image_id": 82966,
"offset": [
1,
1
],
"resize_scale": [
1.0,
1.0
],
"type": "Body"
}
]
},
{
"basic": {
"body_point_test": {
"waistband_right": [
336,
264
],
"hand_point_right": [
350,
303
],
"waistband_left": [
245,
274
],
"hand_point_left": [
219,
315
],
"shoulder_left": [
227,
155
],
"shoulder_right": [
338,
149
]
},
"layer_order": false,
"scale_bag": 0.7,
"scale_earrings": 0.16,
"self_template": true,
"single_overall": "overall",
"switch_category": ""
},
"items": [
{
"color": "127 61 21",
"icon": "none",
"image_id": 96038,
"offset": [
1,
1
],
"path": "aida-sys-image/images/female/dress/0902003549.jpg",
"print": {
"IfSingle": false,
"print_path_list": []
},
"resize_scale": [
1.0,
1.0
],
"type": "Dress"
},
{
"body_path": "aida-users/31/models/female/845046c7-4f62-4f54-a4a9-c26d49c6969335b5b3a9-d335-4871-a46c-3cc3caf07da259629dfd1f1f555a2e2a9def7e719366.png",
"image_id": 82966,
"offset": [
1,
1
],
"resize_scale": [
1.0,
1.0
],
"type": "Body"
}
]
},
{
"basic": {
"body_point_test": {
"waistband_right": [
336,
264
],
"hand_point_right": [
350,
303
],
"waistband_left": [
245,
274
],
"hand_point_left": [
219,
315
],
"shoulder_left": [
227,
155
],
"shoulder_right": [
338,
149
]
},
"layer_order": false,
"scale_bag": 0.7,
"scale_earrings": 0.16,
"self_template": true,
"single_overall": "overall",
"switch_category": ""
},
"items": [
{
"businessId": 493822,
"color": "127 61 21",
"elementId": 493822,
"icon": "none",
"image_id": 62309,
"offset": [
1,
1
],
"path": "aida-users/31/sketchboard/female/trousers/c37c2ea6-8955-4b40-8339-c737e672ca3d.jpg",
"print": {
"IfSingle": false,
"print_path_list": []
},
"resize_scale": [
1.0,
1.0
],
"type": "Trousers"
},
{
"businessId": 493825,
"color": "118 123 115",
"elementId": 493825,
"icon": "none",
"image_id": 107101,
"offset": [
1,
1
],
"path": "aida-users/31/sketchboard/female/Blouse/de8f5656-d7ae-4642-bc90-f7f9d85da09b.jpg",
"print": {
"IfSingle": false,
"print_path_list": []
},
"resize_scale": [
1.0,
1.0
],
"type": "Blouse"
},
{
"body_path": "aida-users/31/models/female/845046c7-4f62-4f54-a4a9-c26d49c6969335b5b3a9-d335-4871-a46c-3cc3caf07da259629dfd1f1f555a2e2a9def7e719366.png",
"image_id": 82966,
"offset": [
1,
1
],
"resize_scale": [
1.0,
1.0
],
"type": "Body"
}
]
},
{
"basic": {
"body_point_test": {
"waistband_right": [
336,
264
],
"hand_point_right": [
350,
303
],
"waistband_left": [
245,
274
],
"hand_point_left": [
219,
315
],
"shoulder_left": [
227,
155
],
"shoulder_right": [
338,
149
]
},
"layer_order": false,
"scale_bag": 0.7,
"scale_earrings": 0.16,
"self_template": true,
"single_overall": "overall",
"switch_category": ""
},
"items": [
{
"businessId": 493826,
"color": "127 61 21",
"elementId": 493826,
"icon": "none",
"image_id": 107105,
"offset": [
1,
1
],
"path": "aida-users/31/sketchboard/female/Skirt/58710352-6301-450d-b69a-fb2922b5429a.png",
"print": {
"IfSingle": false,
"print_path_list": []
},
"resize_scale": [
1.0,
1.0
],
"type": "Skirt"
},
{
"color": "118 123 115",
"icon": "none",
"image_id": 79114,
"offset": [
1,
1
],
"path": "aida-sys-image/images/female/blouse/903000169.jpg",
"print": {
"IfSingle": false,
"print_path_list": []
},
"resize_scale": [
1.0,
1.0
],
"type": "Blouse"
},
{
"color": "229 214 200",
"icon": "none",
"image_id": 90573,
"offset": [
1,
1
],
"path": "aida-sys-image/images/female/outwear/0628000541.jpg",
"print": {
"IfSingle": false,
"print_path_list": []
},
"resize_scale": [
1.0,
1.0
],
"type": "Outwear"
},
{
"body_path": "aida-users/31/models/female/845046c7-4f62-4f54-a4a9-c26d49c6969335b5b3a9-d335-4871-a46c-3cc3caf07da259629dfd1f1f555a2e2a9def7e719366.png",
"image_id": 82966,
"offset": [
1,
1
],
"resize_scale": [
1.0,
1.0
],
"type": "Body"
}
]
},
{
"basic": {
"body_point_test": {
"waistband_right": [
336,
264
],
"hand_point_right": [
350,
303
],
"waistband_left": [
245,
274
],
"hand_point_left": [
219,
315
],
"shoulder_left": [
227,
155
],
"shoulder_right": [
338,
149
]
},
"layer_order": false,
"scale_bag": 0.7,
"scale_earrings": 0.16,
"self_template": true,
"single_overall": "overall",
"switch_category": ""
},
"items": [
{
"color": "229 214 200",
"icon": "none",
"image_id": 110205,
"offset": [
1,
1
],
"path": "aida-sys-image/images/female/trousers/0916000217.jpg",
"print": {
"IfSingle": false,
"print_path_list": []
},
"resize_scale": [
1.0,
1.0
],
"type": "Trousers"
},
{
"businessId": 493825,
"color": "209 125 29",
"elementId": 493825,
"icon": "none",
"image_id": 107101,
"offset": [
1,
1
],
"path": "aida-users/31/sketchboard/female/Blouse/de8f5656-d7ae-4642-bc90-f7f9d85da09b.jpg",
"print": {
"IfSingle": false,
"print_path_list": []
},
"resize_scale": [
1.0,
1.0
],
"type": "Blouse"
},
{
"body_path": "aida-users/31/models/female/845046c7-4f62-4f54-a4a9-c26d49c6969335b5b3a9-d335-4871-a46c-3cc3caf07da259629dfd1f1f555a2e2a9def7e719366.png",
"image_id": 82966,
"offset": [
1,
1
],
"resize_scale": [
1.0,
1.0
],
"type": "Body"
}
]
}
],
"process_id": "6878547032381675"
}

View File

@@ -10,6 +10,7 @@ class Bottom(Clothing):
dict(type='LoadImageFromFile', path=kwargs['path'], color=kwargs['color'], print_dict=kwargs['print']),
dict(type='KeypointDetection'),
dict(type='ContourDetection'),
# dict(type='Segmentation'),
dict(type='Painting', painting_flag=True),
dict(type='PrintPainting', print_flag=True),
dict(type='Scaling'),

View File

@@ -30,14 +30,15 @@ class Clothing(object):
image=self.result["front_image"],
# mask_image=self.result['front_mask_image'],
image_url=self.result['front_image_url'],
mask_url=self.result['front_mask_url'],
mask_url=self.result['mask_url'],
sacle=self.result['scale'],
clothes_keypoint=self.result['clothes_keypoint'],
position=start_point,
resize_scale=self.result["resize_scale"],
mask=cv2.resize(self.result['mask'], self.result["front_image"].size),
gradient_string=self.result['gradient_string'] if 'gradient_string' in self.result.keys() else "",
pattern_image_url=self.result['pattern_image_url']
pattern_image_url=self.result['pattern_image_url'],
pattern_image=self.result['pattern_image']
)
layer.insert(front_layer)
@@ -47,14 +48,14 @@ class Clothing(object):
image=self.result["back_image"],
# mask_image=self.result['back_mask_image'],
image_url=self.result['back_image_url'],
mask_url=self.result['back_mask_url'],
mask_url=self.result['mask_url'],
sacle=self.result['scale'],
clothes_keypoint=self.result['clothes_keypoint'],
position=start_point,
resize_scale=self.result["resize_scale"],
mask=cv2.resize(self.result['mask'], self.result["front_image"].size),
gradient_string=self.result['gradient_string'] if 'gradient_string' in self.result.keys() else "",
pattern_image_url=self.result['pattern_image_url']
pattern_image_url=self.result['pattern_image_url'],
)
layer.insert(back_layer)

View File

@@ -43,7 +43,8 @@ class ContourDetection(object):
result['mask'] = Mask
else:
result['mask'] = cv2.bitwise_and(Mask, result['pre_mask'])
result['front_mask'] = result['mask']
result['back_mask'] = result['mask']
return result
@staticmethod

View File

@@ -5,6 +5,7 @@ import numpy as np
from pymilvus import MilvusClient
from app.core.config import *
from app.service.utils.decorator import RunTime, ClassCallRunTime
from ..builder import PIPELINES
from ...utils.design_ensemble import get_keypoint_result
@@ -27,7 +28,7 @@ class KeypointDetection(object):
# self.client.close()
# print(f"client close time : {time.time() - start_time}")
# @ RunTime
# @ClassCallRunTime
def __call__(self, result):
# logging.info("KeypointDetection run ")
if result['name'] in ['blouse', 'skirt', 'dress', 'outwear', 'trousers', 'tops', 'bottoms']: # 查询是否有数据 且类别相同 相同则直接读 不同则推理后更新

View File

@@ -12,6 +12,7 @@ class LoadImageFromFile(object):
self.print_dict = print_dict
# self.minio_client = Minio(f"{MINIO_URL}", access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
# @ClassCallRunTime
def __call__(self, result):
result['image'], result['pre_mask'] = self.read_image(self.path)
result['gray'] = cv2.cvtColor(result['image'], cv2.COLOR_BGR2GRAY)
@@ -45,15 +46,18 @@ class LoadImageFromFile(object):
@staticmethod
def read_image(image_path):
    """Load an image from OSS as a 3-channel array plus an optional alpha mask.

    Args:
        image_path: Location in ``<bucket>/<object name>`` form; it is split
            at the first '/'.

    Returns:
        ``(image, image_mask)``. ``image_mask`` is the fourth channel of a
        4-channel source and ``None`` otherwise. Images whose height and
        width are both at most 50 pixels are upscaled by a factor of two,
        and the mask is upscaled with them so the two stay the same size.
    """
    image_mask = None
    image = oss_get_image(bucket=image_path.split("/", 1)[0], object_name=image_path.split("/", 1)[1], data_type="cv2")
    if len(image.shape) == 2:
        # Single-channel input: expand to three channels.
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    if image.shape[2] == 4:
        # Four channels: the last one is the mask.
        image_mask = image[:, :, 3]
        image = image[:, :, :3]

    # Compare each dimension explicitly. A tuple comparison such as
    # `shape[:2] <= (50, 50)` is lexicographic and would decide on the height
    # alone whenever the height differs from 50.
    height, width = image.shape[:2]
    if height <= 50 and width <= 50:
        # cv2.resize expects (width, height).
        new_size = (width * 2, height * 2)
        image = cv2.resize(image, new_size, interpolation=cv2.INTER_LINEAR)
        if image_mask is not None:
            # Keep the mask aligned with the resized image; nearest-neighbour
            # avoids inventing intermediate mask values.
            image_mask = cv2.resize(image_mask, new_size, interpolation=cv2.INTER_NEAREST)
    return image, image_mask

View File

@@ -1,3 +1,4 @@
import logging
import random
import cv2
@@ -7,13 +8,15 @@ from PIL import Image
from app.service.utils.oss_client import oss_get_image
from ..builder import PIPELINES
logger = logging.getLogger()
@PIPELINES.register_module()
class Painting(object):
def __init__(self, painting_flag=True):
self.painting_flag = painting_flag
# @ RunTime
# @ClassCallRunTime
def __call__(self, result):
if result['name'] not in ['hairstyle', 'earring'] and self.painting_flag and result['color'] != 'none':
dim_image_h, dim_image_w = result['image'].shape[0:2]
@@ -86,7 +89,7 @@ class PrintPainting(object):
def __init__(self, print_flag=True):
self.print_flag = print_flag
# @ RunTime
# @ClassCallRunTime
def __call__(self, result):
single_print = result['print']['single']
overall_print = result['print']['overall']
@@ -236,7 +239,6 @@ class PrintPainting(object):
print_background = cv2.cvtColor(np.array(source_image_pil), cv2.COLOR_RGBA2BGR)
mask_background = cv2.cvtColor(np.array(source_image_pil_mask), cv2.COLOR_RGBA2BGR)
print(1)
else:
mask = self.get_mask_inv(image)
mask = np.expand_dims(mask, axis=2)

View File

@@ -2,6 +2,7 @@ import math
import cv2
from app.service.utils.decorator import ClassCallRunTime
from ..builder import PIPELINES
@@ -10,7 +11,7 @@ class Scaling(object):
def __init__(self):
pass
# @ RunTime
# @ClassCallRunTime
def __call__(self, result):
if result['keypoint'] in ['waistband', 'shoulder', 'head_point']:
# milvus_db_keypoint_cache

View File

@@ -1,14 +1,71 @@
import logging
import os
import cv2
import numpy as np
from app.core.config import SEG_CACHE_PATH
from app.service.utils.decorator import ClassCallRunTime
from app.service.utils.oss_client import oss_get_image
from ..builder import PIPELINES
from ...utils.design_ensemble import get_seg_result
logger = logging.getLogger()
@PIPELINES.register_module()
class Segmentation(object):
def __init__(self, device='cpu', show=False, debug=None):
self.show = show
self.device = device
self.debug = debug
@ClassCallRunTime
def __call__(self, result):
result['seg_result'] = get_seg_result(result["image_id"], result['image'])
if "seg_mask_url" in result.keys() and result['seg_mask_url'] != "":
seg_mask = oss_get_image(bucket=result['seg_mask_url'].split('/')[0], object_name=result['seg_mask_url'][result['seg_mask_url'].find('/') + 1:], data_type="cv2")
seg_mask = cv2.resize(seg_mask, (result['img_shape'][1], result['img_shape'][0]), interpolation=cv2.INTER_NEAREST)
# 转换颜色空间为 RGBOpenCV 默认是 BGR
image_rgb = cv2.cvtColor(seg_mask, cv2.COLOR_BGR2RGB)
r, g, b = cv2.split(image_rgb)
red_mask = r > g
green_mask = g > r
# 创建红色和绿色掩码
result['front_mask'] = np.array(red_mask, dtype=np.uint8) * 255
result['back_mask'] = np.array(green_mask, dtype=np.uint8) * 255
result['mask'] = result['front_mask'] + result['back_mask']
else:
# 本地查询seg 缓存是否存在
_, seg_result = self.load_seg_result(result["image_id"])
result['seg_result'] = seg_result
if not _:
# 推理获得seg 结果
seg_result = get_seg_result(result["image_id"], result['image'])[0]
self.save_seg_result(seg_result, result['image_id'])
# 处理前片后片
temp_front = seg_result == 1.0
result['front_mask'] = (255 * (temp_front + 0).astype(np.uint8))
temp_back = seg_result == 2.0
result['back_mask'] = (255 * (temp_back + 0).astype(np.uint8))
result['mask'] = result['front_mask'] + result['back_mask']
return result
@staticmethod
def save_seg_result(seg_result, image_id):
    """Store a segmentation result as ``<SEG_CACHE_PATH><image_id>.npy``.

    Best effort: any failure is logged as an error and not re-raised.
    """
    cache_file = f"{SEG_CACHE_PATH}{image_id}.npy"
    try:
        np.save(cache_file, seg_result)
        saved_to = os.path.abspath(cache_file)
        logger.info(f"保存成功 {saved_to}")
    except Exception as e:
        logger.error(f"保存失败: {e}")
@staticmethod
def load_seg_result(image_id):
    """Read a cached segmentation result from ``<SEG_CACHE_PATH><image_id>.npy``.

    Returns:
        ``(True, array)`` when the file was loaded, ``(False, None)`` when it
        is missing (logged as a warning) or could not be loaded (logged as an
        error).
    """
    cache_file = f"{SEG_CACHE_PATH}{image_id}.npy"
    found, cached = False, None
    try:
        cached = np.load(cache_file)
        found = True
    except FileNotFoundError:
        logger.warning("文件不存在")
    except Exception as e:
        logger.error(f"加载失败: {e}")
    return found, cached

View File

@@ -1,3 +1,4 @@
import io
import logging
import cv2
@@ -5,7 +6,9 @@ import numpy as np
from PIL import Image
from cv2 import cvtColor, COLOR_BGR2RGBA
from app.core.config import AIDA_CLOTHING
from app.service.utils.generate_uuid import generate_uuid
from app.service.utils.oss_client import oss_upload_image
from ..builder import PIPELINES
from ...utils.conversion_image import rgb_to_rgba
from ...utils.upload_image import upload_png_mask
@@ -17,32 +20,14 @@ class Split(object):
Split image into front and back layer according to the segmentation result
"""
# @ClassCallRunTime
# KNet
def __call__(self, result):
try:
if 'mask' not in result.keys():
raise KeyError(f'Cannot find mask in result dict, please check ContourDetection is included in process pipelines.')
if 'seg_result' not in result.keys(): # 没过seg模型
result['front_mask'] = result['mask'].copy()
result['back_mask'] = np.zeros_like(result['mask'])
else:
temp_front = result['seg_result'] == 1
result['front_mask'] = (result['mask'] * (temp_front + 0).astype(np.uint8))
temp_back = result['seg_result'] == 2
result['back_mask'] = (result['mask'] * (temp_back + 0).astype(np.uint8))
if result['name'] in ('outwear', 'dress', 'blouse', 'skirt', 'trousers', 'tops', 'bottoms'):
if len(result['front_mask'].shape) > 2:
front_mask = result['front_mask'][0]
else:
front_mask = result['front_mask']
if len(result['back_mask'].shape) > 2:
back_mask = result['back_mask'][0]
else:
back_mask = result['back_mask']
# rgba_image = rgb_to_rgba((result['final_image'].shape[0], result['final_image'].shape[1]), result['final_image'], front_mask + back_mask)
front_mask = result['front_mask']
back_mask = result['back_mask']
rgba_image = rgb_to_rgba(result['final_image'], front_mask + back_mask)
new_size = (int(rgba_image.shape[1] * result["scale"] * result["resize_scale"][0]), int(rgba_image.shape[0] * result["scale"] * result["resize_scale"][1]))
rgba_image = cv2.resize(rgba_image, new_size)
@@ -50,23 +35,45 @@ class Split(object):
front_mask = cv2.resize(front_mask, new_size)
result_front_image[front_mask != 0] = rgba_image[front_mask != 0]
result_front_image_pil = Image.fromarray(cvtColor(result_front_image, COLOR_BGR2RGBA))
result['front_image'], result["front_image_url"], result["front_mask_url"] = upload_png_mask(result_front_image_pil, f'{generate_uuid()}', mask=front_mask)
result['front_image'], result["front_image_url"], _ = upload_png_mask(result_front_image_pil, f'{generate_uuid()}', mask=None)
height, width = front_mask.shape
mask_image = np.zeros((height, width, 3))
mask_image[front_mask != 0] = [0, 0, 255]
if result["name"] in ('blouse', 'dress', 'outwear', 'tops'):
result_back_image = np.zeros_like(rgba_image)
back_mask = cv2.resize(back_mask, new_size)
result_back_image[back_mask != 0] = rgba_image[back_mask != 0]
result_back_image_pil = Image.fromarray(cvtColor(result_back_image, COLOR_BGR2RGBA))
result['back_image'], result["back_image_url"], result["back_mask_url"] = upload_png_mask(result_back_image_pil, f'{generate_uuid()}', mask=back_mask)
result['back_image'], result["back_image_url"], _ = upload_png_mask(result_back_image_pil, f'{generate_uuid()}', mask=None)
mask_image[back_mask != 0] = [0, 255, 0]
rbga_mask = rgb_to_rgba(mask_image, front_mask + back_mask)
mask_pil = Image.fromarray(cvtColor(rbga_mask.astype(np.uint8), COLOR_BGR2RGBA))
image_data = io.BytesIO()
mask_pil.save(image_data, format='PNG')
image_data.seek(0)
image_bytes = image_data.read()
req = oss_upload_image(bucket=AIDA_CLOTHING, object_name=f"mask/mask_{generate_uuid()}.png", image_bytes=image_bytes)
result['mask_url'] = req.bucket_name + "/" + req.object_name
else:
rbga_mask = rgb_to_rgba(mask_image, front_mask)
mask_pil = Image.fromarray(cvtColor(rbga_mask.astype(np.uint8), COLOR_BGR2RGBA))
image_data = io.BytesIO()
mask_pil.save(image_data, format='PNG')
image_data.seek(0)
image_bytes = image_data.read()
req = oss_upload_image(bucket=AIDA_CLOTHING, object_name=f"mask/mask_{generate_uuid()}.png", image_bytes=image_bytes)
result['mask_url'] = req.bucket_name + "/" + req.object_name
result['back_image'] = None
result["back_image_url"] = None
result["back_mask_url"] = None
result['back_mask_image'] = None
# 创建中间图层
# result["back_mask_url"] = None
# result['back_mask_image'] = None
# 创建中间图层
result_pattern_image_rgba = rgb_to_rgba(result['pattern_image'], result['mask'])
result_pattern_image_pil = Image.fromarray(cvtColor(result_pattern_image_rgba, COLOR_BGR2RGBA))
_, result['pattern_image_url'], _ = upload_png_mask(result_pattern_image_pil, f'{generate_uuid()}')
result['pattern_image'], result['pattern_image_url'], _ = upload_png_mask(result_pattern_image_pil, f'{generate_uuid()}')
return result
except Exception as e:
logging.warning(f"split runtime exception : {e} image_id : {result['image_id']}")

View File

@@ -9,8 +9,8 @@ class Top(Clothing):
pipeline = [
dict(type='LoadImageFromFile', path=kwargs['path'], color=kwargs['color'], print_dict=kwargs['print']),
dict(type='KeypointDetection'),
dict(type='ContourDetection'),
dict(type='Segmentation', device='cpu', show=False, debug=kwargs['debug']),
# dict(type='ContourDetection'),
dict(type='Segmentation'),
dict(type='Painting', painting_flag=True),
dict(type='PrintPainting', print_flag=True),
# dict(type='ImageShow', key=['image', 'mask', 'seg_visualize', 'pattern_image']),

View File

@@ -1,4 +1,7 @@
import concurrent.futures
import io
import cv2
from app.core.config import PRIORITY_DICT
from app.service.design.core.layer import Layer
@@ -6,6 +9,7 @@ from app.service.design.items import build_item
from app.service.design.utils.redis_utils import Redis
from app.service.design.utils.synthesis_item import synthesis, synthesis_single
from app.service.utils.decorator import RunTime
from app.service.utils.oss_client import oss_upload_image
def process_item(item, layers):
@@ -23,7 +27,7 @@ def update_progress(process_id, total):
if int(progress) <= 100:
r.write(key=process_id, value=int(progress) + int(100 / total))
else:
r.write(key=process_id, value=100)
r.write(key=process_id, value=99)
return progress
elif total == 1:
r.write(key=process_id, value=100)
@@ -43,6 +47,7 @@ def final_progress(process_id):
@RunTime
def generate(request_data):
return_response = {}
return_png_mask = []
request_data = request_data.dict()
assert "process_id" in request_data.keys(), "Need process_id parameters"
@@ -55,14 +60,15 @@ def generate(request_data):
# 获取处理结果
for future in concurrent.futures.as_completed(futures):
obj = futures[future]
result = future.result()
return_response[obj] = result
return_response[obj] = future.result()[0]
return_png_mask.extend(future.result()[1])
# upload_results = process_images(return_png_mask)
final_progress(process_id)
return return_response
def process_object(cfg, process_id, total):
uploaded_images = []
basic_info = cfg.get('basic')
items_response = {
'layers': []
@@ -83,8 +89,17 @@ def process_object(cfg, process_id, total):
layers = sorted(layers.layer, key=lambda s: s.get("priority", float('inf')))
else:
layers = sorted(layers.layer, key=lambda x: PRIORITY_DICT.get(x['name'], float('inf')))
# 上传所有图片
# for layer in layers:
# if 'image' in layer.keys() and layer['image'] is not None:
# uploaded_images.append({'image_obj': layer['image'], 'image_url': layer['image_url'], 'image_type': 'image'})
# if 'pattern_image' in layer.keys() and layer['pattern_image'] is not None:
# uploaded_images.append({'image_obj': layer['pattern_image'], 'image_url': layer['pattern_image_url'], 'image_type': 'pattern_image'})
# if 'mask' in layer.keys() and layer['mask'] is not None and layer['mask_url'] is not None:
# uploaded_images.append({'image_obj': layer['mask'], 'image_url': layer['mask_url'], 'image_type': 'mask'})
layers, new_size = update_base_size_priority(layers, body_size)
# 合成
items_response['synthesis_url'] = synthesis(layers, body_size)
items_response['synthesis_url'] = synthesis(layers, new_size, basic_info)
for lay in layers:
items_response['layers'].append({
@@ -114,9 +129,10 @@ def process_object(cfg, process_id, total):
'position': None,
'priority': 0,
'image_url': item.result['front_image_url'],
'mask_url': item.result['front_mask_url'],
'mask_url': item.result['mask_url'],
"gradient_string": item.result['gradient_string'] if 'gradient_string' in item.result.keys() else "",
'pattern_image_url': item.result['pattern_image_url'] if 'pattern_image_url' in item.result.keys() else None,
})
items_response['layers'].append({
'image_category': f"{item.result['name']}_back",
@@ -124,11 +140,58 @@ def process_object(cfg, process_id, total):
'position': None,
'priority': 0,
'image_url': item.result['back_image_url'],
'mask_url': item.result['back_mask_url'],
'mask_url': item.result['mask_url'],
"gradient_string": item.result['gradient_string'] if 'gradient_string' in item.result.keys() else "",
'pattern_image_url': item.result['pattern_image_url'] if 'pattern_image_url' in item.result.keys() else None,
})
items_response['synthesis_url'] = synthesis_single(item.result['front_image'], item.result['back_image'])
break
update_progress(process_id, total)
return items_response
return items_response, uploaded_images
@RunTime
def process_images(images):
    """Upload every image descriptor in ``images`` through a thread pool.

    Each descriptor is handed to ``upload_images``; the returned list holds
    the per-descriptor return values in the same order as ``images``.
    """
    with concurrent.futures.ThreadPoolExecutor() as pool:
        return list(pool.map(upload_images, images))
# @RunTime
def upload_images(image_obj):
    """Encode one image descriptor as PNG and upload it to object storage.

    ``image_obj`` is a dict with:
      * ``image_url``  -- "<bucket>/<object name>" destination,
      * ``image_type`` -- ``'image'`` / ``'pattern_image'`` (``image_obj`` is
        saved through its ``save`` method) or anything else (treated as a mask
        array and processed with OpenCV),
      * ``image_obj``  -- the image itself.

    Returns the descriptor's ``image_url``.
    """
    url_parts = image_obj['image_url'].split("/", 1)
    bucket_name = url_parts[0]
    object_name = url_parts[1]
    if image_obj['image_type'] in ('image', 'pattern_image'):
        buffer = io.BytesIO()
        image_obj['image_obj'].save(buffer, format='PNG')
        payload = buffer.getvalue()
    else:
        inverted = cv2.bitwise_not(image_obj['image_obj'])
        # Expand the 3-channel mask to 4 channels, then zero every pixel whose
        # first channel is 0 so that it becomes fully transparent.
        rgba_image = cv2.cvtColor(inverted, cv2.COLOR_BGR2BGRA)
        rgba_image[rgba_image[:, :, 0] == 0] = [0, 0, 0, 0]
        payload = cv2.imencode('.png', rgba_image)[1]
    oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=payload)
    return image_obj['image_url']
def update_base_size_priority(layers, size):
    """Compute the canvas size spanning all layers and shift them to its origin.

    ``position[1]`` of each layer is treated as its horizontal offset: the
    canvas width is the largest ``position[1] + image.width`` over layers that
    have an image, minus the smallest ``position[1]`` over all layers. The
    canvas height is fixed at 700. Every layer dict gains an
    ``'adaptive_position'`` entry: its position with that smallest offset
    subtracted from the second component.

    Args:
        layers: list of layer dicts with ``'position'`` and ``'image'`` keys.
        size: currently unused.

    Returns:
        ``(layers, (new_width, new_height))``; ``layers`` is the same list,
        updated in place.
    """
    # Width of the transparent background image.
    left_edge = min(layer['position'][1] for layer in layers)
    right_edges = [
        layer['position'][1] + layer['image'].width
        for layer in layers
        if layer['image'] is not None
    ]
    canvas_width = max(right_edges) - left_edge
    canvas_height = 700
    # Re-express every position relative to the left-most layer.
    for layer in layers:
        row, col = layer['position'][0], layer['position'][1]
        layer['adaptive_position'] = (row, col - left_edge)
    return layers, (canvas_width, canvas_height)

View File

@@ -59,14 +59,26 @@ def positioning(all_mask_shape, mask_shape, offset):
# @RunTime
def synthesis(data, size):
def synthesis(data, size, basic_info):
# 创建底图
base_image = Image.new('RGBA', size, (0, 0, 0, 0))
try:
all_mask_shape = (size[1], size[0])
top_outer_mask = np.zeros(all_mask_shape, dtype=np.uint8)
bottom_outer_mask = np.zeros(all_mask_shape, dtype=np.uint8)
body_mask = None
for d in data:
if d['name'] == 'body':
# 创建一个新的宽高透明图像, 把模特贴上去获取mask
transparent_image = Image.new("RGBA", size, (0, 0, 0, 0))
transparent_image.paste(d['image'], (d['adaptive_position'][1], d['adaptive_position'][0]), d['image']) # 此处可变数组会被paste篡改值所以使用下标获取position
body_mask = np.array(transparent_image.split()[3])
# 根据新的坐标获取新的肩点
left_shoulder = [x + y for x, y in zip(basic_info['body_point_test']['shoulder_left'], [d['adaptive_position'][1], d['adaptive_position'][0]])]
right_shoulder = [x + y for x, y in zip(basic_info['body_point_test']['shoulder_right'], [d['adaptive_position'][1], d['adaptive_position'][0]])]
body_mask[:min(left_shoulder[1], right_shoulder[1]), left_shoulder[0]:right_shoulder[0]] = 255
_, binary_body_mask = cv2.threshold(body_mask, 127, 255, cv2.THRESH_BINARY)
top_outer_mask = np.array(binary_body_mask)
bottom_outer_mask = np.array(binary_body_mask)
top = True
bottom = True
@@ -76,21 +88,27 @@ def synthesis(data, size):
if top and data[i]['name'] in ["blouse_front", "outwear_front", "dress_front", "tops_front"]:
top = False
mask_shape = data[i]['mask'].shape
y_offset, x_offset = data[i]['position']
y_offset, x_offset = data[i]['adaptive_position']
# 初始化叠加区域的起始和结束位置
all_y_start, all_y_end, mask_y_start, mask_y_end = positioning(all_mask_shape=all_mask_shape[0], mask_shape=mask_shape[0], offset=y_offset)
all_x_start, all_x_end, mask_x_start, mask_x_end = positioning(all_mask_shape=all_mask_shape[1], mask_shape=mask_shape[1], offset=x_offset)
# 将叠加区域赋值为相应的像素值
top_outer_mask[all_y_start:all_y_end, all_x_start:all_x_end] = data[i]['mask'][mask_y_start:mask_y_end, mask_x_start:mask_x_end]
elif bottom and data[i]['name'] in ["trousers_front", "skirt_front", "bottoms_front"]:
_, sketch_mask = cv2.threshold(data[i]['mask'], 127, 255, cv2.THRESH_BINARY)
background = np.zeros_like(top_outer_mask)
background[all_y_start:all_y_end, all_x_start:all_x_end] = sketch_mask[mask_y_start:mask_y_end, mask_x_start:mask_x_end]
top_outer_mask = background + top_outer_mask
elif bottom and data[i]['name'] in ["trousers_front", "skirt_front", "bottoms_front", "dress_front"]:
bottom = False
mask_shape = data[i]['mask'].shape
y_offset, x_offset = data[i]['position']
y_offset, x_offset = data[i]['adaptive_position']
# 初始化叠加区域的起始和结束位置
all_y_start, all_y_end, mask_y_start, mask_y_end = positioning(all_mask_shape=all_mask_shape[0], mask_shape=mask_shape[0], offset=y_offset)
all_x_start, all_x_end, mask_x_start, mask_x_end = positioning(all_mask_shape=all_mask_shape[1], mask_shape=mask_shape[1], offset=x_offset)
# 将叠加区域赋值为相应的像素值
bottom_outer_mask[all_y_start:all_y_end, all_x_start:all_x_end] = data[i]['mask'][mask_y_start:mask_y_end, mask_x_start:mask_x_end]
_, sketch_mask = cv2.threshold(data[i]['mask'], 127, 255, cv2.THRESH_BINARY)
background = np.zeros_like(top_outer_mask)
background[all_y_start:all_y_end, all_x_start:all_x_end] = sketch_mask[mask_y_start:mask_y_end, mask_x_start:mask_x_end]
bottom_outer_mask = background + bottom_outer_mask
elif bottom is False and top is False:
break
@@ -100,13 +118,13 @@ def synthesis(data, size):
if layer['image'] is not None:
if layer['name'] != "body":
test_image = Image.new('RGBA', size, (0, 0, 0, 0))
test_image.paste(layer['image'], (layer['position'][1], layer['position'][0]), layer['image'])
# mask_data = np.where(all_mask > 0, 255, 0).astype(np.uint8)
# mask_alpha = Image.fromarray(mask_data)
# cropped_image = Image.composite(test_image, Image.new("RGBA", test_image.size, (255, 255, 255, 0)), mask_alpha)
base_image.paste(test_image, (0, 0), test_image)
test_image.paste(layer['image'], (layer['adaptive_position'][1], layer['adaptive_position'][0]), layer['image'])
mask_data = np.where(all_mask > 0, 255, 0).astype(np.uint8)
mask_alpha = Image.fromarray(mask_data)
cropped_image = Image.composite(test_image, Image.new("RGBA", test_image.size, (255, 255, 255, 0)), mask_alpha)
base_image.paste(test_image, (0, 0), cropped_image) # test_image 已经按照坐标贴到最大宽值的图片上 坐着这里坐标为00
else:
base_image.paste(layer['image'], (layer['position'][1], layer['position'][0]), layer['image'])
base_image.paste(layer['image'], (layer['adaptive_position'][1], layer['adaptive_position'][0]), layer['image'])
result_image = base_image

View File

@@ -0,0 +1,126 @@
import logging
import threading
from celery import Celery
from minio import Minio
from app.core.config import *
from app.service.design_batch.item import BodyItem, TopItem, BottomItem
from app.service.design_batch.utils.MQ import publish_status
from app.service.design_batch.utils.organize import organize_body, organize_clothing
from app.service.design_batch.utils.save_json import oss_upload_json
from app.service.design_batch.utils.synthesis_item import update_base_size_priority, synthesis, synthesis_single
# Module-level lock.
# NOTE(review): not referenced by any function visible in this module -- confirm it is still needed.
id_lock = threading.Lock()
# Celery application that hosts the batch_design task.
# NOTE(review): the broker URL, including its credentials, is hard-coded here;
# consider reading it from app.core.config like the MinIO settings below.
celery_app = Celery('tasks', broker='amqp://guest:guest@10.1.2.213:5672//', backend='rpc://')
celery_app.conf.worker_log_format = '%(asctime)s %(filename)s [line:%(lineno)d] %(levelname)s %(message)s'
celery_app.conf.worker_hijack_root_logger = False
# Only show pika messages at WARNING level or above.
logging.getLogger('pika').setLevel(logging.WARNING)
# Root logger.
logger = logging.getLogger()
# Single MinIO client handed to every item pipeline created in this module.
minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
def process_item(item, basic):
    """Process a single item of a project and return the pipeline result.

    The item class is chosen from ``item['type']``: exactly ``"Body"`` selects
    ``BodyItem``; blouse / outwear / dress / tops (case-insensitive) select
    ``TopItem``; anything else selects ``BottomItem``. The chosen class is
    built with the module-level ``minio_client`` and its ``process()`` result
    is returned.
    """
    item_type = item['type']
    if item_type == "Body":
        server_cls = BodyItem
    elif item_type.lower() in ['blouse', 'outwear', 'dress', 'tops']:
        server_cls = TopItem
    else:
        server_cls = BottomItem
    server = server_cls(data=item, basic=basic, minio_client=minio_client)
    return server.process()
def process_layer(item, layers):
    """Append the layer(s) built from a processed item to ``layers``.

    A ``"mannequin"`` item contributes one body layer and the function returns
    ``item['body_image'].size``. Any other item contributes a front layer and
    a back layer (in that order) and the function returns ``None``.
    """
    if item['name'] != "mannequin":
        front_layer, back_layer = organize_clothing(item)
        layers.append(front_layer)
        layers.append(back_layer)
        return None
    layers.append(organize_body(item))
    return item['body_image'].size
@celery_app.task
def batch_design(objects_data, tasks_id, json_name):
    """Render every design object of a batch and publish the results.

    One thread is started per entry of ``objects_data``. Each thread builds
    the response for its object, appends it to the shared response list and
    publishes a per-object status message (``step + 1``). When all threads
    have finished, the collected responses are uploaded as JSON under
    ``json_name`` and a final ``"ok"`` status is published.

    Args:
        objects_data: list of dicts with a ``'basic'`` dict (must contain
            ``'single_overall'``) and an ``'items'`` list.
        tasks_id: identifier passed to ``publish_status``.
        json_name: object name used for the uploaded JSON.

    Returns:
        The list of per-object responses, in completion order.

    NOTE(review): an exception raised inside a worker thread is not caught
    here, so that object's entry is simply absent from the response.
    """
    object_response = []
    threads = []
    lock = threading.Lock()

    def overall_layer_entry(lay):
        # Response entry describing one composed layer.
        image = lay['image']
        return {
            'image_category': lay['name'],
            'position': lay['position'],
            'priority': lay.get("priority", None),
            'resize_scale': lay.get('resize_scale'),
            'image_size': None if image is None else image.size,
            'gradient_string': lay.get('gradient_string', ""),
            'mask_url': lay['mask_url'],
            'image_url': lay.get('image_url'),
            'pattern_image_url': lay.get('pattern_image_url'),
        }

    def single_layer_entry(item_result, side):
        # Response entry for the front or back image of a single garment.
        # The size reported is that of the SAME side's image (the previous
        # code reported the opposite side's size).
        image = item_result[f'{side}_image']
        return {
            'image_category': f"{item_result['name']}_{side}",
            'image_size': image.size if image else None,
            'position': None,
            'priority': 0,
            'image_url': item_result[f'{side}_image_url'],
            'mask_url': item_result['mask_url'],
            "gradient_string": item_result.get('gradient_string', ""),
            'pattern_image_url': item_result.get('pattern_image_url'),
        }

    def process_object(step, design_object):
        basic = design_object['basic']
        items_response = {'layers': []}
        if basic['single_overall'] == "overall":
            item_results = [process_item(item, basic) for item in design_object['items']]
            layers = []
            body_size = None
            for item in item_results:
                layer_size = process_layer(item, layers)
                # Only the mannequin yields a size; clothing items return None
                # and must not wipe out a size found earlier.
                if layer_size is not None:
                    body_size = layer_size
            layers = sorted(layers, key=lambda s: s.get("priority", float('inf')))
            layers, new_size = update_base_size_priority(layers, body_size)
            items_response['layers'].extend(overall_layer_entry(lay) for lay in layers)
            items_response['synthesis_url'] = synthesis(layers, new_size, basic)
        else:
            item_result = process_item(design_object['items'][0], basic)
            items_response['layers'].append(single_layer_entry(item_result, 'front'))
            items_response['layers'].append(single_layer_entry(item_result, 'back'))
            items_response['synthesis_url'] = synthesis_single(item_result['front_image'], item_result['back_image'])
        # Appending and publishing are serialised across worker threads.
        with lock:
            object_response.append(items_response)
            publish_status(tasks_id, step + 1, items_response)

    for step, design_object in enumerate(objects_data):
        worker = threading.Thread(target=process_object, args=(step, design_object))
        threads.append(worker)
        worker.start()
    for worker in threads:
        worker.join()
    oss_upload_json(minio_client, object_response, json_name)
    publish_status(tasks_id, "ok", json_name)
    return object_response

View File

@@ -0,0 +1,61 @@
from app.service.design_batch.pipeline import *
class BaseItem:
    """Common state for one item travelling through a pipeline.

    ``self.result`` starts as a shallow copy of ``data`` in which the
    ``'type'`` entry is replaced by ``'name'`` (the lower-cased type), and is
    then updated with every entry of ``basic``.
    """

    def __init__(self, data, basic):
        merged = data.copy()
        merged['name'] = data['type'].lower()
        merged.pop("type")
        merged.update(basic)
        self.result = merged
class TopItem(BaseItem):
    """Item whose pipeline includes the ``Segmentation`` stage.

    Stages run in order: LoadImage, KeyPoint, Segmentation, Color,
    PrintPainting, Scaling, Split.
    """

    def __init__(self, data, basic, minio_client):
        super().__init__(data, basic)
        stages = [
            LoadImage(minio_client),
            KeyPoint(),
            Segmentation(minio_client),
            Color(minio_client),
            PrintPainting(minio_client),
            Scaling(),
            Split(minio_client),
        ]
        self.top_pipeline = stages

    def process(self):
        """Feed ``self.result`` through every stage in order and return it."""
        for stage in self.top_pipeline:
            self.result = stage(self.result)
        return self.result
class BottomItem(BaseItem):
    """Item whose pipeline uses ``ContourDetection`` instead of segmentation.

    Stages run in order: LoadImage, KeyPoint, ContourDetection, Color,
    PrintPainting, Scaling, Split.
    """

    def __init__(self, data, basic, minio_client):
        super().__init__(data, basic)
        stages = [
            LoadImage(minio_client),
            KeyPoint(),
            ContourDetection(),
            Color(minio_client),
            PrintPainting(minio_client),
            Scaling(),
            Split(minio_client),
        ]
        self.bottom_pipeline = stages

    def process(self):
        """Feed ``self.result`` through every stage in order and return it."""
        for stage in self.bottom_pipeline:
            self.result = stage(self.result)
        return self.result
class BodyItem(BaseItem):
    """Item whose pipeline consists of the ``LoadBodyImage`` stage only."""

    def __init__(self, data, basic, minio_client):
        super().__init__(data, basic)
        # The attribute keeps the name ``top_pipeline`` used by the original.
        self.top_pipeline = [LoadBodyImage(minio_client)]

    def process(self):
        """Feed ``self.result`` through every stage in order and return it."""
        for stage in self.top_pipeline:
            self.result = stage(self.result)
        return self.result

View File

@@ -0,0 +1,20 @@
# Pipeline stage classes re-exported by this package.
from .color import Color
from .contour_detection import ContourDetection
from .keypoint import KeyPoint
from .loading import LoadImage, LoadBodyImage
from .print_painting import PrintPainting
from .scale import Scaling
from .segmentation import Segmentation
from .split import Split

# Names exported by ``from <this package> import *``.
__all__ = [
    'LoadBodyImage', 'LoadImage',
    'KeyPoint',
    'ContourDetection',
    'Segmentation',
    'Color',
    'PrintPainting',
    'Scaling',
    'Split',
]

View File

@@ -0,0 +1,62 @@
import logging
import cv2
import numpy as np
from app.service.utils.new_oss_client import oss_get_image
logger = logging.getLogger()
class Color:
    """Pipeline stage that fills the garment with a solid colour or a gradient.

    Reads ``result['image']``, ``result['mask']``, ``result['gray']`` and
    either ``result['gradient']`` ("<bucket>/<object name>", used when present
    and non-empty) or ``result['color']`` (an "R G B" string). Writes
    ``'pattern_image'``, ``'final_image'``, ``'single_image'`` and ``'alpha'``.
    """

    def __init__(self, minio_client):
        self.minio_client = minio_client

    def __call__(self, result):
        dim_image_h, dim_image_w = result['image'].shape[0:2]
        if "gradient" in result.keys() and result['gradient'] != "":
            bucket_name = result['gradient'].split('/')[0]
            object_name = result['gradient'][result['gradient'].find('/') + 1:]
            pattern = self.get_gradient(bucket_name=bucket_name, object_name=object_name)
        else:
            pattern = self.get_pattern(result['color'])
        # Both sources are stretched to the sketch size the same way.
        resize_pattern = cv2.resize(pattern, (dim_image_w, dim_image_h), interpolation=cv2.INTER_AREA)

        # Replicate the single-channel mask / grey image across three channels.
        mask_3c = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
        gray_3c = np.expand_dims(result['gray'], axis=2).repeat(3, axis=2)
        # Colour scaled by the mask and by the grey value of the sketch.
        shaded = resize_pattern * (mask_3c / 255) * (gray_3c / 255)
        result['pattern_image'] = shaded.astype(np.uint8)
        result['final_image'] = result['pattern_image']

        # 'single_image': the coloured garment on a white (255) canvas.
        canvas = np.full_like(result['final_image'], 255)
        inverse_3c = np.expand_dims(cv2.bitwise_not(result['mask']), axis=2).repeat(3, axis=2)
        background = (canvas * (inverse_3c / 255)).astype(np.uint8)
        foreground = (result['final_image'] * (mask_3c / 255)).astype(np.uint8)
        result['single_image'] = cv2.add(background, foreground)
        result['alpha'] = 100 / 255.0
        return result

    def get_gradient(self, bucket_name, object_name):
        """Fetch the gradient image and return it with three channels."""
        image = oss_get_image(oss_client=self.minio_client, bucket=bucket_name, object_name=object_name, data_type="cv2")
        if image.ndim == 2:
            # A single-channel gradient would otherwise fail on shape[2] below.
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
        elif image.shape[2] == 4:
            image = cv2.cvtColor(image, cv2.COLOR_BGRA2BGR)
        return image

    @staticmethod
    def crop_image(image, image_size_h, image_size_w):
        """Return a window of ``image`` starting at a random offset.

        The offsets are drawn from ``[0, size // 5 - 6)``; for sizes too small
        to give a positive bound the offset is 0 instead of ``randint`` raising.
        """
        x_offset = np.random.randint(low=0, high=max(int(image_size_h / 5) - 6, 1))
        y_offset = np.random.randint(low=0, high=max(int(image_size_w / 5) - 6, 1))
        image = image[x_offset: x_offset + image_size_h, y_offset: y_offset + image_size_w, :]
        return image

    @staticmethod
    def get_pattern(single_color):
        """Build a 1x1 BGR uint8 pixel from an "R G B" string.

        Raises:
            TypeError: if ``single_color`` is ``None``.
            ValueError: if the string does not hold exactly three integers.
        """
        if single_color is None:
            # The previous ``raise False`` also surfaced as a TypeError, but
            # with a message about exceptions rather than about the input.
            raise TypeError("single_color must be an 'R G B' string, got None")
        # split() without a separator tolerates repeated or surrounding spaces.
        R, G, B = single_color.split()
        pattern = np.zeros([1, 1, 3], np.uint8)
        pattern[0, 0, 0] = int(B)
        pattern[0, 0, 1] = int(G)
        pattern[0, 0, 2] = int(R)
        return pattern

View File

@@ -0,0 +1,37 @@
import cv2
import numpy as np
class ContourDetection:
    """Pipeline stage that derives the garment mask from its outer contour.

    Writes ``result['mask']`` and sets both ``result['front_mask']`` and
    ``result['back_mask']`` to that same mask.
    """

    def __call__(self, result):
        source = result['image']
        ordered = self.get_contours(source)
        if len(ordered):
            # Fill a polygon approximation of the largest contour with 255.
            outline_mask = np.zeros(source.shape[:2], np.uint8)
            largest = ordered[0]
            tolerance = 0.001 * cv2.arcLength(largest, True)
            polygon = cv2.approxPolyDP(largest, tolerance, True)
            cv2.drawContours(outline_mask, [polygon], -1, 255, -1)
        else:
            # No contour found: the whole image counts as garment.
            outline_mask = np.ones(source.shape[:2], np.uint8) * 255
        # TODO: fix the case where some images come out transparent (planned for the next release)
        # img2gray = cv2.cvtColor(result['image'], cv2.COLOR_BGR2GRAY)
        # ret, Mask = cv2.threshold(img2gray, 126, 255, cv2.THRESH_BINARY)
        # Mask = cv2.bitwise_not(Mask)
        alpha_mask = result['pre_mask']
        if alpha_mask is None:
            result['mask'] = outline_mask
        else:
            # Keep only pixels present in both the contour mask and pre_mask.
            result['mask'] = cv2.bitwise_and(outline_mask, alpha_mask)
        result['front_mask'] = result['mask']
        result['back_mask'] = result['mask']
        return result

    @staticmethod
    def get_contours(image):
        """Return the external contours of ``image``, largest area first."""
        grayscale = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        edges = cv2.Canny(grayscale, 10, 150)
        structuring = np.ones((5, 5), np.uint8)
        # Dilate, then erode, with the same 5x5 kernel.
        thickened = cv2.dilate(edges, kernel=structuring, iterations=1)
        closed = cv2.erode(thickened, kernel=structuring, iterations=1)
        found, _ = cv2.findContours(closed, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        return sorted(found, key=cv2.contourArea, reverse=True)

View File

@@ -0,0 +1,114 @@
import logging
import numpy as np
from pymilvus import MilvusClient
from app.core.config import *
from app.service.design_batch.utils.design_ensemble import get_keypoint_result
logger = logging.getLogger(__name__)
class KeyPoint:
    """Pipeline stage that attaches garment keypoints to ``result['clothes_keypoint']``.

    Keypoints are cached per image in the Milvus collection
    ``MILVUS_TABLE_KEYPOINT`` as one 24-value vector: the first 20 values
    belong to site ``"up"`` and the last 4 to site ``"down"``. A stored row is
    tagged ``"up"``, ``"down"`` or ``"all"`` (both parts filled). The value
    handed to the pipeline is a dict mapping ``KEYPOINT_RESULT_TABLE_FIELD_SET``
    to the 12 two-value points of that vector.

    Every Milvus client opened here is closed again, also when a call fails.
    """

    name = "KeyPoint"
    _CLOTHING = ('blouse', 'skirt', 'dress', 'outwear', 'trousers', 'tops', 'bottoms')
    _UPPER = ('blouse', 'outwear', 'dress', 'tops')

    @classmethod
    def get_name(cls):
        return cls.name

    def __call__(self, result):
        """Add ``'clothes_keypoint'`` for clothing items; pass others through."""
        if result['name'] not in self._CLOTHING:
            return result
        site = 'up' if result['name'] in self._UPPER else 'down'
        keypoints = self.keypoint_cache(result, site)
        if keypoints is False:
            # The cache path failed: infer directly and try to store the result.
            keypoint_infer_result, site = self.infer_keypoint_result(result)
            keypoints = self.save_keypoint_cache(result["image_id"], keypoint_infer_result, site)
        result['clothes_keypoint'] = keypoints
        return result

    @staticmethod
    def infer_keypoint_result(result):
        """Run keypoint inference; returns ``(inference result, site)``."""
        site = 'up' if result['name'] in ['blouse', 'outwear', 'dress', 'tops'] else 'down'
        keypoint_infer_result = get_keypoint_result(result["image"], site)
        return keypoint_infer_result, site

    @staticmethod
    def save_keypoint_cache(keypoint_id, cache, site):
        """Zero-pad ``cache`` to 24 values, store it tagged ``site``, return the keypoint dict.

        A failing Milvus write is logged and otherwise ignored.
        """
        flat = cache.flatten()
        if site == "down":
            vector = np.concatenate([np.zeros(20, dtype=int), flat])
        else:
            vector = np.concatenate([flat, np.zeros(4, dtype=int)])
        KeyPoint._upsert(keypoint_id, site, vector)
        return KeyPoint._to_keypoint_dict(vector)

    @staticmethod
    def update_keypoint_cache(keypoint_id, infer_result, search_result, site):
        """Merge a fresh inference with the stored other half and store it as ``"all"``.

        A failing Milvus write is logged and otherwise ignored.
        """
        if site == "up":
            # "up" was inferred, so the stored vector supplies the "down" part.
            vector = np.concatenate([infer_result.flatten(), search_result[-4:]])
        else:
            # "down" was inferred, so the stored vector supplies the "up" part.
            vector = np.concatenate([search_result[:20], infer_result.flatten()])
        KeyPoint._upsert(keypoint_id, "all", vector)
        return KeyPoint._to_keypoint_dict(vector)

    def keypoint_cache(self, result, site):
        """Return the keypoint dict for ``result``, using and refreshing the cache.

        Returns ``False`` when anything on this path raises (the error is logged).
        """
        try:
            rows = self._query(result['image_id'])
            if len(rows) == 0:
                # Nothing stored yet: infer and save.
                keypoint_infer_result, site = self.infer_keypoint_result(result)
                return self.save_keypoint_cache(result['image_id'], keypoint_infer_result, site)
            stored_site = rows[0]["keypoint_site"]
            if stored_site == "all" or stored_site == site:
                # The stored row already covers the requested site.
                return self._to_keypoint_dict(np.array(rows[0]['keypoint_vector']))
            # The stored row holds the other site only: infer this one and merge.
            keypoint_infer_result, site = self.infer_keypoint_result(result)
            return self.update_keypoint_cache(result["image_id"], keypoint_infer_result, rows[0]['keypoint_vector'], site)
        except Exception as e:
            logger.info(f"search keypoint cache milvus error {e}")
            return False

    @staticmethod
    def _to_keypoint_dict(vector):
        # Map the field names onto the 12 two-value points of a 24-value vector.
        return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, vector.reshape(12, 2).astype(int).tolist()))

    @staticmethod
    def _close_quietly(client):
        # Close a Milvus client; a failure to close is logged, never raised.
        if client is None:
            return
        try:
            client.close()
        except Exception as e:
            logger.info(f"close milvus client error : {e}")

    @staticmethod
    def _query(keypoint_id):
        # Fetch the stored rows for one image id, closing the client afterwards.
        client = MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS)
        try:
            # NOTE(review): the id is interpolated unquoted, which presumes a
            # numeric keypoint_id -- confirm against the collection schema.
            return client.query(
                collection_name=MILVUS_TABLE_KEYPOINT,
                filter=f"keypoint_id == {keypoint_id}",
                output_fields=['keypoint_vector', 'keypoint_site']
            )
        finally:
            KeyPoint._close_quietly(client)

    @staticmethod
    def _upsert(keypoint_id, site, vector):
        # Best-effort write of one cache row; errors are logged, not raised.
        data = [
            {"keypoint_id": keypoint_id,
             "keypoint_site": site,
             "keypoint_vector": vector.tolist()
             }
        ]
        client = None
        try:
            client = MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS)
            client.upsert(collection_name=MILVUS_TABLE_KEYPOINT, data=data)
        except Exception as e:
            logger.info(f"save keypoint cache milvus error : {e}")
        finally:
            KeyPoint._close_quietly(client)

View File

@@ -0,0 +1,77 @@
import logging
import cv2
from app.service.utils.new_oss_client import oss_get_image
logger = logging.getLogger()
class LoadBodyImage:
    """Pipeline stage that loads the body image named by ``result['body_path']``.

    ``body_path`` has the form "<bucket>/<object name>". The stage sets
    ``result['name']`` to ``"mannequin"`` and stores the image, requested as
    data type ``"PIL"``, in ``result['body_image']``.
    """

    name = "LoadBodyImage"

    def __init__(self, minio_client):
        self.minio_client = minio_client

    @classmethod
    def get_name(cls):
        return cls.name

    def __call__(self, result):
        result["name"] = "mannequin"
        path_parts = result['body_path'].split("/", 1)
        result['body_image'] = oss_get_image(
            oss_client=self.minio_client,
            bucket=path_parts[0],
            object_name=path_parts[1],
            data_type="PIL",
        )
        return result
class LoadImage:
    """Pipeline stage that loads a garment sketch and seeds the result dict.

    Writes ``'image'``, ``'pre_mask'`` (the alpha channel of a 4-channel image,
    else ``None``), ``'gray'``, ``'keypoint'``, ``'img_shape'`` and
    ``'ori_shape'``.
    """

    name = "LoadImage"
    # Images whose height AND width are at most this many pixels are doubled.
    _TINY_SIDE = 50
    # (item names, keypoint name) pairs used by get_keypoint.
    _KEYPOINT_GROUPS = (
        (('blouse', 'outwear', 'dress', 'tops'), 'shoulder'),
        (('trousers', 'skirt', 'bottoms'), 'waistband'),
        (('bag',), 'hand_point'),
        (('shoes',), 'toe'),
        (('hairstyle',), 'head_point'),
        (('earring',), 'ear_point'),
    )

    def __init__(self, minio_client):
        self.minio_client = minio_client

    @classmethod
    def get_name(cls):
        return cls.name

    def __call__(self, result):
        result['image'], result['pre_mask'] = self.read_image(result['path'])
        result['gray'] = cv2.cvtColor(result['image'], cv2.COLOR_BGR2GRAY)
        result['keypoint'] = self.get_keypoint(result['name'])
        result['img_shape'] = result['image'].shape
        result['ori_shape'] = result['image'].shape
        return result

    def read_image(self, image_path):
        """Fetch "<bucket>/<object name>" and return ``(image, alpha_mask_or_None)``.

        A single-channel image is expanded to three channels; a 4-channel
        image is split into its first three channels and its alpha channel.
        Tiny images are doubled in size, together with their alpha mask so
        that both keep the same height and width.
        """
        image_mask = None
        path_parts = image_path.split("/", 1)
        image = oss_get_image(oss_client=self.minio_client, bucket=path_parts[0], object_name=path_parts[1], data_type="cv2")
        if len(image.shape) == 2:
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        if image.shape[2] == 4:  # four channels: the last one is the mask
            image_mask = image[:, :, 3]
            image = image[:, :, :3]
        if self._is_tiny(image.shape):
            # cv2.resize takes the target size as (width, height).
            new_size = (image.shape[1] * 2, image.shape[0] * 2)
            image = cv2.resize(image, new_size, interpolation=cv2.INTER_LINEAR)
            if image_mask is not None:
                # Nearest-neighbour keeps the mask values unblended.
                image_mask = cv2.resize(image_mask, new_size, interpolation=cv2.INTER_NEAREST)
        return image, image_mask

    @staticmethod
    def _is_tiny(shape):
        # True when both height and width are <= _TINY_SIDE. The previous
        # ``shape[:2] <= (50, 50)`` compared the tuples lexicographically, so
        # the width was only consulted when the height was exactly 50.
        height, width = shape[:2]
        return height <= LoadImage._TINY_SIDE and width <= LoadImage._TINY_SIDE

    @staticmethod
    def get_keypoint(name):
        """Return the keypoint name for an item category.

        Raises:
            KeyError: if ``name`` is not a known category.
        """
        for names, keypoint in LoadImage._KEYPOINT_GROUPS:
            if name in names:
                return keypoint
        raise KeyError(f"{name} does not belong to item category list: blouse, outwear, dress, trousers, skirt, "
                       f"bag, shoes, hairstyle, earring.")

View File

@@ -0,0 +1,524 @@
import random
import cv2
import numpy as np
from PIL import Image
from app.service.utils.new_oss_client import oss_get_image
class PrintPainting:
def __init__(self, minio_client):
self.minio_client = minio_client
# Apply the requested prints to the garment described by ``result``.
#
# Reads result['print'] (sub-dicts 'single', 'overall', 'element'),
# result['pattern_image'], result['mask'], result['gray'] and, for the element
# stage, result['final_image'].
# Writes result['single_image'], result['print_image'] and result['final_image']
# (the overall stage also overwrites result['pattern_image']).
# Three stages run in order, each only when its path list is non-empty:
#   1. overall print - a tiled print, optionally rotated,
#   2. single prints - individually placed prints, shaded by result['gray'],
#   3. elements      - individually placed elements, pasted without gray shading.
def __call__(self, result):
single_print = result['print']['single']
overall_print = result['print']['overall']
element_print = result['print']['element']
result['single_image'] = None
result['print_image'] = None
# ---- stage 1: overall (tiled) print ----
if overall_print['print_path_list']:
painting_dict = {'dim_image_h': result['pattern_image'].shape[0], 'dim_image_w': result['pattern_image'].shape[1]}
result['print_image'] = result['pattern_image']
# Only the first entry of print_angle_list is used for the overall print.
if "print_angle_list" in overall_print.keys() and overall_print['print_angle_list'][0] != 0:
painting_dict = self.painting_collection(painting_dict, overall_print, print_trigger=True)
painting_dict['tile_print'] = self.rotate_crop_image(img=painting_dict['tile_print'], angle=-overall_print['print_angle_list'][0], crop=True)
painting_dict['mask_inv_print'] = self.rotate_crop_image(img=painting_dict['mask_inv_print'], angle=-overall_print['print_angle_list'][0], crop=True)
# resize to the sketch size
painting_dict['tile_print'] = self.resize_and_crop(img=painting_dict['tile_print'], target_width=painting_dict['dim_image_w'], target_height=painting_dict['dim_image_h'])
painting_dict['mask_inv_print'] = self.resize_and_crop(img=painting_dict['mask_inv_print'], target_width=painting_dict['dim_image_w'], target_height=painting_dict['dim_image_h'])
else:
painting_dict = self.painting_collection(painting_dict, overall_print, print_trigger=True, is_single=False)
result['print_image'] = self.printpaint(result, painting_dict, print_=True)
result['single_image'] = result['final_image'] = result['pattern_image'] = result['print_image']
# ---- stage 2: individually placed single prints ----
if single_print['print_path_list']:
# Black canvases the size of the pattern image: one collects print pixels, one collects their mask.
print_background = np.zeros((result['pattern_image'].shape[0], result['pattern_image'].shape[1], 3), dtype=np.uint8)
mask_background = np.zeros((result['pattern_image'].shape[0], result['pattern_image'].shape[1], 3), dtype=np.uint8)
for i in range(len(single_print['print_path_list'])):
image, image_mode = self.read_image(single_print['print_path_list'][i])
if image_mode == "RGBA":
# RGBA prints are handled with PIL: scale, rotate, then paste using the alpha channel as the paste mask.
new_size = (int(image.width * single_print['print_scale_list'][i]), int(image.height * single_print['print_scale_list'][i]))
mask = image.split()[3]
resized_source = image.resize(new_size)
resized_source_mask = mask.resize(new_size)
rotated_resized_source = resized_source.rotate(-single_print['print_angle_list'][i])
rotated_resized_source_mask = resized_source_mask.rotate(-single_print['print_angle_list'][i])
source_image_pil = Image.fromarray(cv2.cvtColor(print_background, cv2.COLOR_BGR2RGB))
source_image_pil_mask = Image.fromarray(cv2.cvtColor(mask_background, cv2.COLOR_BGR2RGB))
source_image_pil.paste(rotated_resized_source, (int(single_print['location'][i][0]), int(single_print['location'][i][1])), rotated_resized_source)
source_image_pil_mask.paste(rotated_resized_source_mask, (int(single_print['location'][i][0]), int(single_print['location'][i][1])), rotated_resized_source_mask)
print_background = cv2.cvtColor(np.array(source_image_pil), cv2.COLOR_RGBA2BGR)
mask_background = cv2.cvtColor(np.array(source_image_pil_mask), cv2.COLOR_RGBA2BGR)
# Binarise the pasted mask at 124 (this thresholding is not done in the element stage below).
ret, mask_background = cv2.threshold(mask_background, 124, 255, cv2.THRESH_BINARY)
else:
# Prints without alpha: get_mask_inv supplies the background mask, which is inverted to select the print pixels.
mask = self.get_mask_inv(image)
mask = np.expand_dims(mask, axis=2)
mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR)
mask = cv2.bitwise_not(mask)
# coordinates must be recomputed after rotation
rotate_mask, _ = self.img_rotate(mask, single_print['print_angle_list'][i], single_print['print_scale_list'][i])
rotate_image, rotated_new_size = self.img_rotate(image, single_print['print_angle_list'][i], single_print['print_scale_list'][i])
# x, y = int(result['print']['location'][i][0] - rotated_new_size[0] - (rotate_mask.shape[0] - image.shape[0]) / 2), int(result['print']['location'][i][1] - rotated_new_size[1] - (rotate_mask.shape[1] - image.shape[1]) / 2)
x, y = int(single_print['location'][i][0] - rotated_new_size[0]), int(single_print['location'][i][1] - rotated_new_size[1])
image_x = print_background.shape[1]
image_y = print_background.shape[0]
print_x = rotate_image.shape[1]
print_y = rotate_image.shape[0]
# buggy earlier version, kept for reference
# if x + print_x > image_x:
# rotate_image = rotate_image[:, :x + print_x - image_x]
# rotate_mask = rotate_mask[:, :x + print_x - image_x]
# #
# if y + print_y > image_y:
# rotate_image = rotate_image[:y + print_y - image_y]
# rotate_mask = rotate_mask[:y + print_y - image_y]
# the two rounds of checks must not be done independently
# The first round of ifs (108 and 115) checked the bottom and right borders and the second round checked the top-left; cropping the right side first and then shifting the region left breaks the result
# shift first, then check, then crop
# if the print was rotated or touches the border, check whether the left and top bounds are below 0
if x <= 0:
rotate_image = rotate_image[:, -x:]
rotate_mask = rotate_mask[:, -x:]
start_x = x = 0
else:
start_x = x
if y <= 0:
rotate_image = rotate_image[-y:, :]
rotate_mask = rotate_mask[-y:, :]
start_y = y = 0
else:
start_y = y
# ------------------
# if the print extends beyond the image, crop the print
# NOTE(review): print_x / print_y were measured before the left/top crop above — confirm this is intended.
if x + print_x > image_x:
rotate_image = rotate_image[:, :image_x - x]
rotate_mask = rotate_mask[:, :image_x - x]
if y + print_y > image_y:
rotate_image = rotate_image[:image_y - y, :]
rotate_mask = rotate_mask[:image_y - y, :]
# mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]] = cv2.bitwise_xor(mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]], rotate_mask)
# print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = cv2.add(print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]], rotate_image)
# mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]] = rotate_mask
# print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = rotate_image
mask_background = self.stack_prin(mask_background, result['pattern_image'], rotate_mask, start_y, y, start_x, x)
print_background = self.stack_prin(print_background, result['pattern_image'], rotate_image, start_y, y, start_x, x)
# gray_image = cv2.cvtColor(mask_background, cv2.COLOR_BGR2GRAY)
# print_background = cv2.bitwise_and(print_background, print_background, mask=gray_image)
# Composite: restrict the print mask to the garment mask, shade the print with the gray image, add it onto the pattern image.
print_mask = cv2.bitwise_and(result['mask'], cv2.cvtColor(mask_background, cv2.COLOR_BGR2GRAY))
img_fg = cv2.bitwise_or(print_background, print_background, mask=print_mask)
img_bg = cv2.bitwise_and(result['pattern_image'], result['pattern_image'], mask=cv2.bitwise_not(print_mask))
mask_mo = np.expand_dims(print_mask, axis=2).repeat(3, axis=2)
gray_mo = np.expand_dims(result['gray'], axis=2).repeat(3, axis=2)
img_fg = (img_fg * (mask_mo / 255) * (gray_mo / 255)).astype(np.uint8)
result['final_image'] = cv2.add(img_bg, img_fg)
# single_image = final image inside the garment mask, white (255) outside it.
canvas = np.full_like(result['final_image'], 255)
temp_bg = np.expand_dims(cv2.bitwise_not(result['mask']), axis=2).repeat(3, axis=2)
tmp1 = (canvas * (temp_bg / 255)).astype(np.uint8)
temp_fg = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
tmp2 = (result['final_image'] * (temp_fg / 255)).astype(np.uint8)
result['single_image'] = cv2.add(tmp1, tmp2)
# ---- stage 3: individually placed elements ----
if element_print['element_path_list']:
# Same placement procedure as stage 2, using the element_* lists and canvases sized from final_image.
print_background = np.zeros((result['final_image'].shape[0], result['final_image'].shape[1], 3), dtype=np.uint8)
mask_background = np.zeros((result['final_image'].shape[0], result['final_image'].shape[1], 3), dtype=np.uint8)
for i in range(len(element_print['element_path_list'])):
image, image_mode = self.read_image(element_print['element_path_list'][i])
if image_mode == "RGBA":
new_size = (int(image.width * element_print['element_scale_list'][i]), int(image.height * element_print['element_scale_list'][i]))
mask = image.split()[3]
resized_source = image.resize(new_size)
resized_source_mask = mask.resize(new_size)
rotated_resized_source = resized_source.rotate(-element_print['element_angle_list'][i])
rotated_resized_source_mask = resized_source_mask.rotate(-element_print['element_angle_list'][i])
source_image_pil = Image.fromarray(cv2.cvtColor(print_background, cv2.COLOR_BGR2RGB))
source_image_pil_mask = Image.fromarray(cv2.cvtColor(mask_background, cv2.COLOR_BGR2RGB))
source_image_pil.paste(rotated_resized_source, (int(element_print['location'][i][0]), int(element_print['location'][i][1])), rotated_resized_source)
source_image_pil_mask.paste(rotated_resized_source_mask, (int(element_print['location'][i][0]), int(element_print['location'][i][1])), rotated_resized_source_mask)
print_background = cv2.cvtColor(np.array(source_image_pil), cv2.COLOR_RGBA2BGR)
mask_background = cv2.cvtColor(np.array(source_image_pil_mask), cv2.COLOR_RGBA2BGR)
else:
mask = self.get_mask_inv(image)
mask = np.expand_dims(mask, axis=2)
mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR)
mask = cv2.bitwise_not(mask)
# coordinates must be recomputed after rotation
rotate_mask, _ = self.img_rotate(mask, element_print['element_angle_list'][i], element_print['element_scale_list'][i])
rotate_image, rotated_new_size = self.img_rotate(image, element_print['element_angle_list'][i], element_print['element_scale_list'][i])
# x, y = int(result['print']['location'][i][0] - rotated_new_size[0] - (rotate_mask.shape[0] - image.shape[0]) / 2), int(result['print']['location'][i][1] - rotated_new_size[1] - (rotate_mask.shape[1] - image.shape[1]) / 2)
x, y = int(element_print['location'][i][0] - rotated_new_size[0]), int(element_print['location'][i][1] - rotated_new_size[1])
image_x = print_background.shape[1]
image_y = print_background.shape[0]
print_x = rotate_image.shape[1]
print_y = rotate_image.shape[0]
# buggy earlier version, kept for reference
# if x + print_x > image_x:
# rotate_image = rotate_image[:, :x + print_x - image_x]
# rotate_mask = rotate_mask[:, :x + print_x - image_x]
# #
# if y + print_y > image_y:
# rotate_image = rotate_image[:y + print_y - image_y]
# rotate_mask = rotate_mask[:y + print_y - image_y]
# the two rounds of checks must not be done independently
# The first round of ifs (108 and 115) checked the bottom and right borders and the second round checked the top-left; cropping the right side first and then shifting the region left breaks the result
# shift first, then check, then crop
# if the print was rotated or touches the border, check whether the left and top bounds are below 0
if x <= 0:
rotate_image = rotate_image[:, -x:]
rotate_mask = rotate_mask[:, -x:]
start_x = x = 0
else:
start_x = x
if y <= 0:
rotate_image = rotate_image[-y:, :]
rotate_mask = rotate_mask[-y:, :]
start_y = y = 0
else:
start_y = y
# ------------------
# if the print extends beyond the image, crop the print
if x + print_x > image_x:
rotate_image = rotate_image[:, :image_x - x]
rotate_mask = rotate_mask[:, :image_x - x]
if y + print_y > image_y:
rotate_image = rotate_image[:image_y - y, :]
rotate_mask = rotate_mask[:image_y - y, :]
# mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]] = cv2.bitwise_xor(mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]], rotate_mask)
# print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = cv2.add(print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]], rotate_image)
# mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]] = rotate_mask
# print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = rotate_image
mask_background = self.stack_prin(mask_background, result['pattern_image'], rotate_mask, start_y, y, start_x, x)
print_background = self.stack_prin(print_background, result['pattern_image'], rotate_image, start_y, y, start_x, x)
# gray_image = cv2.cvtColor(mask_background, cv2.COLOR_BGR2GRAY)
# print_background = cv2.bitwise_and(print_background, print_background, mask=gray_image)
print_mask = cv2.bitwise_and(result['mask'], cv2.cvtColor(mask_background, cv2.COLOR_BGR2GRAY))
img_fg = cv2.bitwise_or(print_background, print_background, mask=print_mask)
# TODO element loses information
# Unlike stage 2, the background is taken from final_image and the element is added without gray shading.
three_channel_image = cv2.merge([cv2.bitwise_not(print_mask), cv2.bitwise_not(print_mask), cv2.bitwise_not(print_mask)])
img_bg = cv2.bitwise_and(result['final_image'], three_channel_image)
# mask_mo = np.expand_dims(print_mask, axis=2).repeat(3, axis=2)
# gray_mo = np.expand_dims(result['gray'], axis=2).repeat(3, axis=2)
# img_fg = (img_fg * (mask_mo / 255) * (gray_mo / 255)).astype(np.uint8)
result['final_image'] = cv2.add(img_bg, img_fg)
# single_image = final image inside the garment mask, white (255) outside it.
canvas = np.full_like(result['final_image'], 255)
temp_bg = np.expand_dims(cv2.bitwise_not(result['mask']), axis=2).repeat(3, axis=2)
tmp1 = (canvas * (temp_bg / 255)).astype(np.uint8)
temp_fg = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
tmp2 = (result['final_image'] * (temp_fg / 255)).astype(np.uint8)
result['single_image'] = cv2.add(tmp1, tmp2)
return result
@staticmethod
def stack_prin(print_background, pattern_image, rotate_image, start_y, y, start_x, x):
    """Overlay ``rotate_image`` on ``print_background`` at the given position.

    A black layer with the height and width of ``pattern_image`` receives
    ``rotate_image`` at rows ``start_y:y + height`` and columns
    ``start_x:x + width``. Layer pixels whose gray value is above 1 replace the
    corresponding pixels of ``print_background``; all other pixels are kept.

    Returns:
        The combined image.
    """
    canvas_h, canvas_w = pattern_image.shape[0], pattern_image.shape[1]
    layer = np.zeros((canvas_h, canvas_w, 3), dtype=np.uint8)
    patch_h, patch_w = rotate_image.shape[0], rotate_image.shape[1]
    layer[start_y:y + patch_h, start_x:x + patch_w] = rotate_image
    # Pixels that the new layer actually covers (gray value above 1).
    _, covered = cv2.threshold(cv2.cvtColor(layer, cv2.COLOR_BGR2GRAY), 1, 255, cv2.THRESH_BINARY)
    kept_background = cv2.bitwise_and(print_background, print_background, mask=cv2.bitwise_not(covered))
    new_foreground = cv2.bitwise_and(layer, layer, mask=covered)
    return kept_background + new_foreground
def painting_collection(self, painting_dict, print_dict, print_trigger=False, is_single=False):
    """Load the print described by ``print_dict`` and add its tiles to ``painting_dict``.

    When ``print_trigger`` is falsy the dict is returned untouched. Otherwise
    the keys ``Trigger``, ``location``, ``mask_inv_print``, ``tile_print``,
    ``dim_print_h`` and ``dim_print_w`` are set. For a non-single print a new
    ``self.random_seed`` is drawn and the tiles are built with ``trigger=True``;
    if the first entry of ``print_angle_list`` is non-zero the tile canvas is a
    square whose side is the larger image dimension.

    Returns:
        ``painting_dict`` (mutated in place).
    """
    if not print_trigger:
        return painting_dict

    print_ = self.get_print(print_dict)
    painting_dict['Trigger'] = not is_single
    painting_dict['location'] = print_['location']
    single_mask_inv_print = self.get_mask_inv(print_['image'])

    image_h = painting_dict['dim_image_h']
    image_w = painting_dict['dim_image_w']
    dim_max = max(image_h, image_w)
    side = int(dim_max * print_['scale'] / 5)
    dim_pattern = (side, side)

    if is_single:
        canvas_h, canvas_w = image_h, image_w
        tile_options = {}
    else:
        self.random_seed = random.randint(0, 1000)
        # A rotated overall print is tiled onto a square so it can be cropped afterwards.
        rotated = "print_angle_list" in print_dict.keys() and print_dict['print_angle_list'][0] != 0
        canvas_h, canvas_w = (dim_max, dim_max) if rotated else (image_h, image_w)
        tile_options = {'trigger': True}

    painting_dict['mask_inv_print'] = self.tile_image(single_mask_inv_print, dim_pattern, print_['scale'], canvas_h, canvas_w, painting_dict['location'], **tile_options)
    painting_dict['tile_print'] = self.tile_image(print_['image'], dim_pattern, print_['scale'], canvas_h, canvas_w, painting_dict['location'], **tile_options)
    painting_dict['dim_print_h'], painting_dict['dim_print_w'] = dim_pattern
    return painting_dict
def tile_image(self, pattern, dim, scale, dim_image_h, dim_image_w, location, trigger=False):
    """Resize ``pattern`` to ``dim`` and, when ``trigger`` is set, tile and crop it.

    Without ``trigger`` the resized pattern is returned as is. With ``trigger``
    the resized pattern is repeated ``int(6 / scale) + 4`` times along both
    image axes and then cut to ``dim_image_h`` x ``dim_image_w`` by
    ``crop_image`` using ``location``.
    """
    resized = cv2.resize(pattern, dim, interpolation=cv2.INTER_AREA)
    if not trigger:
        return resized

    repeats = int((5 + 1) / scale) + 4
    rank = len(pattern.shape)
    tiled = None
    if rank == 2:
        tiled = np.tile(resized, (repeats, repeats))
    if rank == 3:
        # Do not repeat along the channel axis.
        tiled = np.tile(resized, (repeats, repeats, 1))
    return self.crop_image(tiled, dim_image_h, dim_image_w, location, resized.shape)
def get_mask_inv(self, print_):
    """Return a single-channel mask of the white background of ``print_``.

    If the top-left pixel is not (255, 255, 255) an all-zero mask is returned.
    Otherwise the image is converted to LAB and the mask is 255 wherever a
    pixel lies inside the per-channel window that ``get_low_high_lab`` builds
    around the top-left pixel's LAB value.
    """
    corner = print_[0][0]
    if not (corner[0] == 255 and corner[1] == 255 and corner[2] == 255):
        return np.zeros(print_.shape[:2], dtype=np.uint8)

    lab_image = cv2.cvtColor(print_, cv2.COLOR_BGR2LAB)
    lab_corner = lab_image[0][0]
    l_high, l_low = self.get_low_high_lab(lab_corner[0], L=True)
    a_high, a_low = self.get_low_high_lab(lab_corner[1])
    b_high, b_low = self.get_low_high_lab(lab_corner[2])
    lower_bound = np.array([l_low, a_low, b_low])
    upper_bound = np.array([l_high, a_high, b_high])
    return cv2.inRange(lab_image, lower_bound, upper_bound)
@staticmethod
def printpaint(result, painting_dict, print_=False):
    """Blend a print onto ``result['pattern_image']`` and return the new image.

    Args:
        result: dict providing ``mask``, ``gray`` and ``pattern_image`` (and
            ``final_image`` when the tiled print is not used).
        painting_dict: dict built by ``painting_collection``; ``Trigger``,
            ``tile_print``, ``mask_inv_print``, ``location``, ``dim_print_h``
            and ``dim_print_w`` are read depending on the branch taken.
        print_: whether a print is being applied at all.

    Returns:
        The pattern image outside the print mask plus the gray-shaded print
        inside it.

    Raises:
        ValueError: if a non-tiled print is requested but
            ``painting_dict['location']`` is missing or has no length.
    """
    tiled = print_ and painting_dict['Trigger']
    if tiled:
        # Print area = garment mask minus the print's own background.
        print_mask = cv2.bitwise_and(result['mask'], cv2.bitwise_not(painting_dict['mask_inv_print']))
        img_fg = cv2.bitwise_and(painting_dict['tile_print'], painting_dict['tile_print'], mask=print_mask)
    else:
        print_mask = result['mask']
        img_fg = result['final_image']
    if print_ and not tiled:
        # Previously a bare ``except`` ran ``assert '<message>'``, which never
        # fails, and the code then crashed in ``range(None)``. Fail with the
        # intended message instead.
        try:
            location_count = len(painting_dict['location'])
        except (KeyError, TypeError) as err:
            raise ValueError('there must be parameter of location if choose IfSingle') from err
        for i in range(location_count):
            start_h, start_w = int(painting_dict['location'][i][1]), int(painting_dict['location'][i][0])
            length_h = min(start_h + painting_dict['dim_print_h'], img_fg.shape[0])
            length_w = min(start_w + painting_dict['dim_print_w'], img_fg.shape[1])
            change_region = img_fg[start_h: length_h, start_w: length_w, :]
            # NOTE(review): change_mask and mask are computed but never used,
            # and the assignment below writes the region back onto itself.
            # Kept as in the original (which was marked "problem in change_mask").
            change_mask = print_mask[start_h: length_h, start_w: length_w]
            _, change_mask = cv2.threshold(change_mask, 220, 255, cv2.THRESH_BINARY)
            mask = cv2.bitwise_not(painting_dict['mask_inv_print'])
            img_fg[start_h:start_h + painting_dict['dim_print_h'], start_w:start_w + painting_dict['dim_print_w'], :] = change_region
    clothes_mask_print = cv2.bitwise_not(print_mask)
    img_bg = cv2.bitwise_and(result['pattern_image'], result['pattern_image'], mask=clothes_mask_print)
    # Shade the print with the gray image, restricted to the print mask.
    mask_mo = np.expand_dims(print_mask, axis=2).repeat(3, axis=2)
    gray_mo = np.expand_dims(result['gray'], axis=2).repeat(3, axis=2)
    img_fg = (img_fg * (mask_mo / 255) * (gray_mo / 255)).astype(np.uint8)
    print_image = cv2.add(img_bg, img_fg)
    return print_image
def get_print(self, print_dict):
    """Download the first print of ``print_dict`` and store it as a BGR array.

    Sets ``print_dict['scale']`` (first entry of ``print_scale_list``, at
    least 0.3; 0.3 when the list is missing or empty) and
    ``print_dict['image']``.

    Returns:
        ``print_dict`` (mutated in place).
    """
    scales = print_dict.get('print_scale_list')
    # An empty list is treated like a missing one instead of raising IndexError.
    if not scales or scales[0] < 0.3:
        print_dict['scale'] = 0.3
    else:
        print_dict['scale'] = scales[0]

    path_parts = print_dict['print_path_list'][0].split("/", 1)
    bucket_name = path_parts[0]
    object_name = path_parts[1]
    image = oss_get_image(oss_client=self.minio_client, bucket=bucket_name, object_name=object_name, data_type="PIL")
    if image.mode == "RGBA":
        # Paste onto pure white so transparent areas do not turn black.
        new_background = Image.new('RGB', image.size, (255, 255, 255))
        new_background.paste(image, mask=image.split()[3])
        image = new_background
    elif image.mode != "RGB":
        # Other modes (e.g. L, P, CMYK) do not have the three-channel RGB
        # layout that the RGB->BGR conversion below expects.
        image = image.convert("RGB")
    print_dict['image'] = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
    return print_dict
def crop_image(self, image, image_size_h, image_size_w, location, print_shape):
    """Cut an ``image_size_h`` x ``image_size_w`` window out of a tiled image.

    The window start is derived from ``location[0]``: index 1 is used for the
    row axis and index 0 for the column axis, each taken modulo the tile size
    along that axis so the tiling appears shifted by that amount.

    Args:
        image: 2-D or 3-D array; other ranks are returned unchanged.
        image_size_h: window height.
        image_size_w: window width.
        location: sequence whose first item holds the two coordinates.
        print_shape: shape of one tile, ``(height, width, ...)``.

    Returns:
        The cropped array.
    """
    print_h = print_shape[0]
    print_w = print_shape[1]
    # Kept from the original: re-seeds the global RNG although the random
    # offsets that used it are no longer drawn here.
    random.seed(self.random_seed)
    # Rows wrap on the tile height and columns on the tile width. The earlier
    # code used the tile width for both subtractions and mixed width and
    # height in the column modulo, which is only right for square tiles.
    row_offset = print_h - int(location[0][1] % print_h)
    col_offset = print_w - int(location[0][0] % print_w)
    if len(image.shape) in (2, 3):
        image = image[row_offset: row_offset + image_size_h, col_offset: col_offset + image_size_w]
    return image
@staticmethod
def get_low_high_lab(Lab_value, L=False):
if L:
high = Lab_value + 30 if Lab_value + 30 < 255 else 255
low = Lab_value - 30 if Lab_value - 30 > 0 else 0
else:
high = Lab_value + 30 if Lab_value + 30 < 255 else 255
low = Lab_value - 30 if Lab_value - 30 > 0 else 0
return high, low
@staticmethod
def img_rotate(image, angel, scale):
    """Rotate and scale ``image`` onto a canvas sized for the transformed image.

    Args:
        image (np.array): source image.
        angel (float): rotation angle in degrees; its negation is passed to
            ``cv2.getRotationMatrix2D``.
        scale (float): scale factor applied together with the rotation.

    Returns:
        tuple: ``(rotated_image, (dx, dy))`` where ``dx`` / ``dy`` are half the
        difference between the rotated canvas size and the scaled source size
        along the width / height.
    """
    src_h, src_w = image.shape[:2]
    matrix = cv2.getRotationMatrix2D((src_w // 2, src_h // 2), -angel, scale)
    abs_cos = np.abs(matrix[0, 0])
    abs_sin = np.abs(matrix[0, 1])
    # Canvas size after the transform.
    out_h = int(src_w * abs_sin + src_h * abs_cos)
    out_w = int(src_h * abs_sin + src_w * abs_cos)
    # Shift the result so it stays centred on the new canvas.
    matrix[0, 2] += (out_w - src_w) // 2
    matrix[1, 2] += (out_h - src_h) // 2
    rotated = cv2.warpAffine(image, matrix, (out_w, out_h))
    offset_x = (rotated.shape[1] - image.shape[1] * scale) // 2
    offset_y = (rotated.shape[0] - image.shape[0] * scale) // 2
    return rotated, (offset_x, offset_y)
@staticmethod
def rotate_crop_image(img, angle, crop):
    """Rotate ``img`` about its centre and optionally crop away the empty corners.

    Args:
        img: image array; the output canvas keeps its height and width.
        angle: rotation angle in degrees (reduced modulo 360).
        crop: when truthy, return the centred region given by the crop factor
            computed below instead of the full rotated canvas.

    Returns:
        The rotated (and possibly cropped) image.
    """
    # NumPy shapes are (rows, cols), i.e. (height, width). The previous
    # ``w, h = img.shape[:2]`` swapped them, which transposed the canvas for
    # non-square images.
    h, w = img.shape[:2]
    # Rotation is periodic in 360 degrees.
    angle %= 360
    M_rotation = cv2.getRotationMatrix2D((w / 2, h / 2), angle, 1)
    img_rotated = cv2.warpAffine(img, M_rotation, (w, h))
    if not crop:
        return img_rotated

    # The crop geometry is periodic in 180 degrees and symmetric about 90.
    # Fold on the reduced angle: testing the full angle (as before) produced
    # a negative crop factor for angles in [180, 270).
    angle_crop = angle % 180
    if angle_crop > 90:
        angle_crop = 180 - angle_crop
    theta = angle_crop * np.pi / 180
    hw_ratio = float(h) / float(w)
    tan_theta = np.tan(theta)
    numerator = np.cos(theta) + np.sin(theta) * tan_theta
    # Long-side / short-side ratio.
    r = hw_ratio if h > w else 1 / hw_ratio
    denominator = r * tan_theta + 1
    crop_mult = numerator / denominator
    # Centred crop window.
    w_crop = int(crop_mult * w)
    h_crop = int(crop_mult * h)
    x0 = int((w - w_crop) / 2)
    y0 = int((h - h_crop) / 2)
    return img_rotated[y0:y0 + h_crop, x0:x0 + w_crop]
def read_image(self, image_url):
    """Download ``image_url`` ("<bucket>/<object name>") and report its mode.

    Returns:
        ``(image, image_mode)``. For a 4-channel source ``image`` is a PIL
        image and ``image_mode`` is ``"RGBA"``; otherwise ``image`` is the
        array as read (3 channels) and ``image_mode`` is ``"RGB"``.
    """
    path_parts = image_url.split("/", 1)
    image = oss_get_image(oss_client=self.minio_client, bucket=path_parts[0], object_name=path_parts[1], data_type="cv2")
    if len(image.shape) == 2:
        # A single-channel image has no third axis, so ``image.shape[2]``
        # below would raise IndexError. Expand it to three channels first.
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
    if image.shape[2] == 4:
        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA)
        return Image.fromarray(image_rgb), "RGBA"
    return image, "RGB"
@staticmethod
def resize_and_crop(img, target_width, target_height):
    """Scale ``img`` to cover the target size, then centre-crop the overflow.

    If the source is relatively wider than the target it is resized to the
    target height and cropped horizontally; otherwise it is resized to the
    target width and cropped vertically.

    Returns:
        The resized and cropped image.
    """
    src_h, src_w = img.shape[:2]
    target_ratio = target_width / target_height
    source_ratio = src_w / src_h

    if source_ratio > target_ratio:
        # Source is wider: match the height, trim the sides.
        scaled_w = int(src_w * (target_height / src_h))
        scaled = cv2.resize(img, (scaled_w, target_height))
        left = (scaled_w - target_width) // 2
        return scaled[:, left:left + target_width]

    # Source is taller (or equal): match the width, trim top and bottom.
    scaled_h = int(src_h * (target_width / src_w))
    scaled = cv2.resize(img, (target_width, scaled_h))
    top = (scaled_h - target_height) // 2
    return scaled[top:top + target_height, :]

View File

@@ -0,0 +1,49 @@
import math
import cv2
class Scaling:
    """Pipeline step that stores a scale factor in ``result['scale']``.

    The factor depends on ``result['keypoint']``:

    * ``waistband`` / ``shoulder`` / ``head_point``: body distance divided by
      the garment keypoint distance (1 when the garment distance is 0).
    * ``toe``: foot length divided by the bounding-box width of the largest
      contour found in ``result['gray']`` (1 when no contour is found).
    * ``hand_point`` / ``ear_point``: ``result['scale_bag']`` /
      ``result['scale_earrings']``.

    For any other keypoint ``result`` is returned unchanged.
    """

    def __call__(self, result):
        keypoint = result['keypoint']
        if keypoint in ['waistband', 'shoulder', 'head_point']:
            left_name = keypoint + '_left'
            right_name = keypoint + '_right'
            # milvus_db_keypoint_cache
            clothes_left = result['clothes_keypoint'][left_name]
            clothes_right = result['clothes_keypoint'][right_name]
            distance_clo = math.sqrt(
                (int(clothes_left[0]) - int(clothes_right[0])) ** 2
                + (int(clothes_left[1]) - int(clothes_right[1])) ** 2
            )
            body_left = result['body_point_test'][left_name]
            body_right = result['body_point_test'][right_name]
            # Only the x difference is used for the body, plus a constant 1
            # under the square root (as in the original formula).
            distance_bdy = math.sqrt((int(body_left[0]) - int(body_right[0])) ** 2 + 1)
            if distance_clo == 0:
                result['scale'] = 1
            else:
                result['scale'] = distance_bdy / distance_clo
        elif keypoint == 'toe':
            foot = result['body_point_test']['foot_length']
            distance_bdy = math.sqrt(
                (int(foot[0]) - int(foot[2])) ** 2
                + (int(foot[1]) - int(foot[3])) ** 2
            )
            blurred = cv2.GaussianBlur(result['gray'], (3, 3), 0)
            edges = cv2.Canny(blurred, 10, 200)
            edges = cv2.dilate(edges, None)
            edges = cv2.erode(edges, None)
            contours, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            if not contours:
                # No edges found (e.g. a blank image): fall back to the same
                # neutral scale as the zero-distance case above instead of
                # raising IndexError.
                result['scale'] = 1
            else:
                largest = max(contours, key=cv2.contourArea)
                _, _, width, _ = cv2.boundingRect(largest)
                result['scale'] = distance_bdy / width
        elif keypoint == 'hand_point':
            result['scale'] = result['scale_bag']
        elif keypoint == 'ear_point':
            result['scale'] = result['scale_earrings']
        return result

View File

@@ -0,0 +1,70 @@
import logging
import os
import cv2
import numpy as np
from app.core.config import SEG_CACHE_PATH
from app.service.design_batch.utils.design_ensemble import get_seg_result
from app.service.utils.new_oss_client import oss_get_image
logger = logging.getLogger()
class Segmentation:
    """Pipeline step that produces front / back / combined garment masks.

    The masks come either from a colour-coded mask image referenced by
    ``result['seg_mask_url']`` or from a segmentation result that is cached on
    disk as ``seg_cache/<image_id>.npy`` and computed by ``get_seg_result``
    when no cache file can be loaded.
    """

    def __init__(self, minio_client):
        """Keep the storage client that is handed to ``oss_get_image``."""
        self.minio_client = minio_client

    def __call__(self, result):
        """Fill ``front_mask``, ``back_mask`` and ``mask`` in ``result`` and return it."""
        seg_mask_url = result.get("seg_mask_url")
        if seg_mask_url:
            seg_mask = oss_get_image(oss_client=self.minio_client, bucket=seg_mask_url.split('/')[0], object_name=seg_mask_url[seg_mask_url.find('/') + 1:], data_type="cv2")
            seg_mask = cv2.resize(seg_mask, (result['img_shape'][1], result['img_shape'][0]), interpolation=cv2.INTER_NEAREST)
            # Convert to RGB (the image is read as BGR).
            image_rgb = cv2.cvtColor(seg_mask, cv2.COLOR_BGR2RGB)
            r, g, _ = cv2.split(image_rgb)
            # Pixels with more red than green form the front mask, and
            # pixels with more green than red form the back mask.
            result['front_mask'] = np.array(r > g, dtype=np.uint8) * 255
            result['back_mask'] = np.array(g > r, dtype=np.uint8) * 255
        else:
            # Look for a locally cached segmentation first.
            cached, seg_result = self.load_seg_result(result["image_id"])
            if not cached:
                # Run inference and cache the outcome.
                seg_result = get_seg_result(result["image_id"], result['image'])[0]
                self.save_seg_result(seg_result, result['image_id'])
            # Stored after inference so it is not left as None on a cache
            # miss (it was previously assigned before the inference call).
            result['seg_result'] = seg_result
            # Label 1 marks the front piece and label 2 the back piece.
            result['front_mask'] = (seg_result == 1.0).astype(np.uint8) * 255
            result['back_mask'] = (seg_result == 2.0).astype(np.uint8) * 255
        result['mask'] = result['front_mask'] + result['back_mask']
        return result

    @staticmethod
    def save_seg_result(seg_result, image_id):
        """Best-effort save of ``seg_result`` to the cache; failures are only logged."""
        file_path = f"seg_cache/{image_id}.npy"
        try:
            # Create the cache directory on first use so np.save does not fail.
            os.makedirs(os.path.dirname(file_path), exist_ok=True)
            np.save(file_path, seg_result)
            logger.info(f"保存成功 {os.path.abspath(file_path)}")
        except Exception as e:
            logger.error(f"保存失败: {e}")

    @staticmethod
    def load_seg_result(image_id):
        """Load a cached segmentation.

        Returns:
            ``(True, array)`` on success, ``(False, None)`` when the file is
            missing or cannot be loaded.
        """
        file_path = f"seg_cache/{image_id}.npy"
        # Log the path that is actually read (the message used to be built
        # from SEG_CACHE_PATH, which is not the path used below).
        logger.info(f"load seg file name is :{file_path}")
        try:
            seg_result = np.load(file_path)
            return True, seg_result
        except FileNotFoundError:
            logger.warning("文件不存在")
            return False, None
        except Exception as e:
            logger.error(f"加载失败: {e}")
            return False, None

View File

@@ -0,0 +1,74 @@
import io
import logging
import cv2
import numpy as np
from PIL import Image
from cv2 import cvtColor, COLOR_BGR2RGBA
from app.core.config import AIDA_CLOTHING
from app.service.design_batch.utils.conversion_image import rgb_to_rgba
from app.service.design_batch.utils.upload_image import upload_png_mask
from app.service.utils.generate_uuid import generate_uuid
from app.service.utils.new_oss_client import oss_upload_image
# Pipeline step that cuts the finished garment image into front / back layers,
# uploads them together with a colour-coded mask image, and uploads the
# pattern layer.
class Split(object):
# Keep the storage client that is passed to upload_png_mask / oss_upload_image.
def __init__(self, minio_client):
self.minio_client = minio_client
# Reads result['name'], 'front_mask', 'back_mask', 'final_image', 'scale',
# 'resize_scale', 'pattern_image' and 'mask'; writes 'front_image',
# 'front_image_url', 'back_image', 'back_image_url', 'mask_url',
# 'pattern_image' and 'pattern_image_url'.
# NOTE(review): the except branch at the end only logs and has no return, so a
# failed call evaluates to None — confirm that callers handle this.
def __call__(self, result):
try:
if result['name'] in ('outwear', 'dress', 'blouse', 'skirt', 'trousers', 'tops', 'bottoms'):
front_mask = result['front_mask']
back_mask = result['back_mask']
# Add an alpha channel that is opaque wherever either mask is set.
rgba_image = rgb_to_rgba(result['final_image'], front_mask + back_mask)
# Target size as (width, height): image size * scale * per-axis resize_scale.
new_size = (int(rgba_image.shape[1] * result["scale"] * result["resize_scale"][0]), int(rgba_image.shape[0] * result["scale"] * result["resize_scale"][1]))
rgba_image = cv2.resize(rgba_image, new_size)
result_front_image = np.zeros_like(rgba_image)
front_mask = cv2.resize(front_mask, new_size)
# Copy only the pixels covered by the (resized) front mask.
result_front_image[front_mask != 0] = rgba_image[front_mask != 0]
result_front_image_pil = Image.fromarray(cvtColor(result_front_image, COLOR_BGR2RGBA))
result['front_image'], result["front_image_url"], _ = upload_png_mask(self.minio_client, result_front_image_pil, f'{generate_uuid()}', mask=None)
height, width = front_mask.shape
# Colour-coded mask image: front pixels are written as [0, 0, 255].
mask_image = np.zeros((height, width, 3))
mask_image[front_mask != 0] = [0, 0, 255]
# Only these categories also get a back layer.
if result["name"] in ('blouse', 'dress', 'outwear', 'tops'):
result_back_image = np.zeros_like(rgba_image)
back_mask = cv2.resize(back_mask, new_size)
result_back_image[back_mask != 0] = rgba_image[back_mask != 0]
result_back_image_pil = Image.fromarray(cvtColor(result_back_image, COLOR_BGR2RGBA))
result['back_image'], result["back_image_url"], _ = upload_png_mask(self.minio_client, result_back_image_pil, f'{generate_uuid()}', mask=None)
# Back pixels are written as [0, 255, 0].
mask_image[back_mask != 0] = [0, 255, 0]
rbga_mask = rgb_to_rgba(mask_image, front_mask + back_mask)
mask_pil = Image.fromarray(cvtColor(rbga_mask.astype(np.uint8), COLOR_BGR2RGBA))
# Encode the mask as PNG in memory and upload it.
image_data = io.BytesIO()
mask_pil.save(image_data, format='PNG')
image_data.seek(0)
image_bytes = image_data.read()
req = oss_upload_image(oss_client=self.minio_client, bucket=AIDA_CLOTHING, object_name=f"mask/mask_{generate_uuid()}.png", image_bytes=image_bytes)
result['mask_url'] = req.bucket_name + "/" + req.object_name
else:
# No back layer: the mask image only carries the front mask.
rbga_mask = rgb_to_rgba(mask_image, front_mask)
mask_pil = Image.fromarray(cvtColor(rbga_mask.astype(np.uint8), COLOR_BGR2RGBA))
image_data = io.BytesIO()
mask_pil.save(image_data, format='PNG')
image_data.seek(0)
image_bytes = image_data.read()
req = oss_upload_image(oss_client=self.minio_client, bucket=AIDA_CLOTHING, object_name=f"mask/mask_{generate_uuid()}.png", image_bytes=image_bytes)
result['mask_url'] = req.bucket_name + "/" + req.object_name
result['back_image'] = None
result["back_image_url"] = None
# result["back_mask_url"] = None
# result['back_mask_image'] = None
# create the intermediate (pattern) layer
result_pattern_image_rgba = rgb_to_rgba(result['pattern_image'], result['mask'])
result_pattern_image_pil = Image.fromarray(cvtColor(result_pattern_image_rgba, COLOR_BGR2RGBA))
result['pattern_image'], result['pattern_image_url'], _ = upload_png_mask(self.minio_client, result_pattern_image_pil, f'{generate_uuid()}')
return result
except Exception as e:
# Any failure above is logged as a warning together with the image id.
logging.warning(f"split runtime exception : {e} image_id : {result['image_id']}")

View File

@@ -0,0 +1,11 @@
import json
from app.service.design_batch.design_batch_celery import batch_design
from app.service.design_batch.utils.MQ import publish_status
async def start_design_batch_generate(data, file):
    """Queue a batch-design job and publish its initial progress.

    Args:
        data: request object providing ``total`` and ``tasks_id``.
        file: bytes of a JSON document with an ``objects`` entry.

    Returns:
        ``{"task_id": data.tasks_id}``.
    """
    payload = json.loads(file.decode())
    design_objects = payload['objects']
    generate_clothes_task = batch_design.delay(design_objects, data.total, data.tasks_id)
    print(generate_clothes_task)
    # Announce the job as started with 0/100 progress and an empty result.
    publish_status(data.tasks_id, "0/100", "")
    return {"task_id": data.tasks_id}

View File

@@ -0,0 +1,162 @@
from app.service.design_batch.design_batch_celery import batch_design
if __name__ == '__main__':
    # Manual smoke test: queue one batch-design job with a fixed sample payload.

    def _no_prints():
        """Return a fresh print spec in which every list is empty."""
        return {
            "element": {"element_angle_list": [], "element_path_list": [], "element_scale_list": [], "location": []},
            "overall": {"location": [], "print_angle_list": [], "print_path_list": [], "print_scale_list": []},
            "single": {"location": [], "print_angle_list": [], "print_path_list": [], "print_scale_list": []},
        }

    def _garment(business_id, color, image_id, path, priority, garment_type):
        """Return one garment item with zero offset, unit resize and no prints."""
        return {
            "businessId": business_id,
            "color": color,
            "image_id": image_id,
            "offset": [0, 0],
            "path": path,
            "print": _no_prints(),
            "priority": priority,
            "resize_scale": [1.0, 1.0],
            "type": garment_type,
        }

    data = {
        "objects": [
            {
                "basic": {
                    "body_point_test": {
                        "waistband_right": [200, 241],
                        "hand_point_right": [223, 297],
                        "waistband_left": [112, 241],
                        "hand_point_left": [92, 305],
                        "shoulder_left": [99, 116],
                        "shoulder_right": [215, 116],
                    },
                    "layer_order": True,
                    "scale_bag": 0.7,
                    "scale_earrings": 0.16,
                    "self_template": True,
                    "single_overall": "overall",
                    "switch_category": "",
                },
                "items": [
                    _garment(270372, "30 28 28", 69780, "aida-sys-image/images/female/trousers/0825000630.jpg", 10, "Trousers"),
                    _garment(270373, "30 28 28", 98243, "aida-sys-image/images/female/blouse/0902003811.jpg", 11, "Blouse"),
                    _garment(270374, "172 68 68", 98244, "aida-sys-image/images/female/outwear/0825000410.jpg", 12, "Outwear"),
                    {
                        "body_path": "aida-sys-image/models/female/5bdfe7ca-64eb-44e4-b03d-8e517520c795.png",
                        "image_id": 96090,
                        "type": "Body",
                    },
                ],
            }
        ],
        "process_id": "83",
    }
    task_id = 1
    json_name = "test.json"
    batch_design.delay(data['objects'], task_id, json_name)

View File

@@ -0,0 +1,17 @@
import json
import pika
def publish_status(task_id, progress, result):
    """Publish one batch-design status message to the ``DesignBatch`` queue.

    Args:
        task_id: Identifier of the batch task the message refers to.
        progress: Progress value to report (must be JSON serialisable).
        result: Result payload to report (must be JSON serialisable).

    The connection is opened per call and is always closed, including when
    declaring the queue or publishing fails.
    """
    connection = pika.BlockingConnection(pika.ConnectionParameters('10.1.2.213'))
    try:
        channel = connection.channel()
        # Durable queue + delivery_mode=2 (persistent message).
        channel.queue_declare(queue='DesignBatch', durable=True)
        message = {'task_id': task_id, 'progress': progress, "result": result}
        channel.basic_publish(
            exchange='',
            routing_key='DesignBatch',
            body=json.dumps(message),
            properties=pika.BasicProperties(
                delivery_mode=2,
            ),
        )
    finally:
        # Previously skipped when publish raised, leaking the connection.
        connection.close()

View File

@@ -0,0 +1,31 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project trinity_client
@File conversion_image.py
@Author :周成融
@Date 2023/8/21 10:40:29
@detail
"""
import numpy as np
# def rgb_to_rgba(rgb_size, rgb_image, mask):
# alpha_channel = np.full(rgb_size, 255, dtype=np.uint8)
# # 创建四通道的结果图像
# rgba_image = np.dstack((rgb_image, alpha_channel))
# alpha_channel = np.where(mask > 0, 255, 0)
# # 更新RGBA图像的透明度通道
# rgba_image[:, :, 3] = alpha_channel
# return rgba_image
def rgb_to_rgba(rgb_image, mask):
    """Stack an alpha channel derived from ``mask`` onto ``rgb_image``.

    Pixels where ``mask`` is greater than zero become fully opaque (255);
    every other pixel becomes fully transparent (0).

    Returns:
        The depth-wise stack of ``rgb_image`` and the uint8 alpha channel.
    """
    # Start fully transparent, then mark the masked pixels as opaque.
    alpha_channel = np.zeros(np.shape(mask), dtype=np.uint8)
    alpha_channel[mask > 0] = 255
    return np.dstack((rgb_image, alpha_channel))
if __name__ == '__main__':
    # NOTE(review): leftover scratch entry point; opening an empty path
    # fails at runtime, so this does nothing useful when run as a script.
    image = open("")

View File

@@ -0,0 +1,143 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project trinity_client
@File design_ensemble.py
@Author :周成融
@Date 2023/8/16 19:36:21
@detail :发起请求 获取推理结果
"""
import logging
import cv2
import mmcv
import numpy as np
import torch
import torch.nn.functional as F
import tritonclient.http as httpclient
from app.core.config import *
"""
keypoint
预处理 推理 后处理
"""
def keypoint_preprocess(img_path):
    """Prepare an image for the keypoint model.

    The image is read with ``mmcv.imread``, resized to 256x256, normalised
    with the fixed mean/std below (converted to RGB) and reshaped to a
    ``(1, C, H, W)`` batch.

    Returns:
        ``(batch, (w_scale, h_scale))`` where the scales are
        ``256 / original_width`` and ``256 / original_height``.
    """
    target_w, target_h = 256, 256
    source = mmcv.imread(img_path)
    src_h, src_w = source.shape[:2]
    resized = cv2.resize(source, (target_w, target_h))
    normalized = mmcv.imnormalize(
        resized,
        mean=np.array([123.675, 116.28, 103.53]),
        std=np.array([58.395, 57.12, 57.375]),
        to_rgb=True,
    )
    # HWC -> CHW, then add the leading batch axis.
    batch = np.expand_dims(normalized.transpose(2, 0, 1), axis=0)
    return batch, (target_w / src_w, target_h / src_h)
# @ RunTime
# 推理
def get_keypoint_result(image, site):
    """Run keypoint inference for ``image`` on the Triton server.

    Args:
        image: Input forwarded to ``keypoint_preprocess``.
        site: Model variant; the model queried is
            ``keypoint_{site}_ocrnet_hr18``.

    Returns:
        The post-processed keypoints, or ``None`` when any step fails (the
        error is logged as a warning).
    """
    try:
        batch, scale_factor = keypoint_preprocess(image)
        triton = httpclient.InferenceServerClient(url=DESIGN_MODEL_URL)
        payload = batch.astype(np.float32)

        model_input = httpclient.InferInput("input", payload.shape, datatype="FP32")
        model_input.set_data_from_numpy(payload, binary_data=True)
        model_output = httpclient.InferRequestedOutput("output", binary_data=True)

        response = triton.infer(
            model_name=f"keypoint_{site}_ocrnet_hr18",
            inputs=[model_input],
            outputs=[model_output],
        )
        heatmaps = torch.from_numpy(response.as_numpy('output'))
        return keypoint_postprocess(heatmaps, scale_factor)
    except Exception as e:
        logging.warning(f"get_keypoint_result : {e}")
        return None
def keypoint_postprocess(output, scale_factor):
    """Convert keypoint heatmaps into ``(row, col)`` coordinates.

    Args:
        output: Tensor of shape ``(batch, keypoints, height, width)``.
        scale_factor: ``(w_scale, h_scale)`` as returned by
            ``keypoint_preprocess``.

    Returns:
        Float ndarray of shape ``(batch, keypoints, 2)`` holding
        ``ceil(coord / scale * 4)`` for the arg-max of every heatmap.
        The factor 4 presumably maps heatmap cells back to the 256px model
        input -- TODO confirm against the model's output stride.
    """
    heatmap_w = output.size(3)
    flat = output.view(output.size(0), output.size(1), -1)
    max_indices = torch.argmax(flat, dim=2).unsqueeze(dim=2)
    # Row must be an integer: plain `/` is true division on integer tensors
    # and used to leak the column offset into the row as a fraction.
    rows = torch.div(max_indices, heatmap_w, rounding_mode='floor')
    cols = max_indices % heatmap_w
    segment_result = torch.cat((rows, cols), dim=2).numpy()
    # Coordinates are (row, col), so the (w, h) scales are reversed first.
    inverse_scale = [1 / x for x in scale_factor[::-1]]
    scale_matrix = np.diag(inverse_scale)
    scale_matrix[np.isinf(scale_matrix)] = 0
    return np.ceil(np.dot(segment_result, scale_matrix) * 4)
"""
seg
预处理 推理 后处理
"""
# KNet
def seg_preprocess(img_path):
    """Prepare an image for the segmentation model.

    Each side longer than 1024 pixels is clamped to 1024 (the two sides are
    clamped independently), then the image is normalised and reshaped to a
    ``(1, C, H, W)`` batch.

    Returns:
        ``(batch, ori_shape)`` where ``ori_shape`` is the original
        ``(height, width)``.
    """
    max_side = 1024
    img = mmcv.imread(img_path)
    ori_shape = img.shape[:2]
    target_h = min(ori_shape[0], max_side)
    target_w = min(ori_shape[1], max_side)
    if ori_shape != (target_h, target_w):
        # cv2.resize takes (width, height); an older version of this code
        # passed the two the wrong way round.
        img = cv2.resize(img, (target_w, target_h))
    img = mmcv.imnormalize(
        img,
        mean=np.array([123.675, 116.28, 103.53]),
        std=np.array([58.395, 57.12, 57.375]),
        to_rgb=True,
    )
    batch = np.expand_dims(img.transpose(2, 0, 1), axis=0)
    return batch, ori_shape
# @ RunTime
def get_seg_result(image_id, image):
    """Run segmentation inference for ``image`` on the Triton server.

    Tensor names and the model name are taken from the ``SEGMENTATION``
    config mapping. Errors are not caught here.

    Returns:
        The value produced by ``seg_postprocess`` for the model output.
    """
    batch, ori_shape = seg_preprocess(image)
    triton = httpclient.InferenceServerClient(url=f"{DESIGN_MODEL_URL}")
    payload = batch.astype(np.float32)

    # Input set
    model_input = httpclient.InferInput(SEGMENTATION['input'], payload.shape, datatype="FP32")
    model_input.set_data_from_numpy(payload, binary_data=True)
    # Output set
    model_output = httpclient.InferRequestedOutput(SEGMENTATION['output'], binary_data=True)

    # Inference
    response = triton.infer(
        model_name=SEGMENTATION['new_model_name'],
        inputs=[model_input],
        outputs=[model_output],
    )
    logits = response.as_numpy(SEGMENTATION['output'])
    return seg_postprocess(int(image_id), logits, ori_shape)
# no cache
def seg_postprocess(image_id, output, ori_shape):
    """Resize segmentation logits back to the original image size.

    Args:
        image_id: Unused; kept for interface compatibility.
        output: Model output with a leading batch axis, ``(N, C, H, W)``.
        ori_shape: Target ``(height, width)``.

    Returns:
        ndarray with the first batch element, bilinearly resized to
        ``ori_shape``.
    """
    logits = torch.tensor(output).float()
    resized = F.interpolate(logits, size=ori_shape, mode='bilinear', align_corners=False)
    return resized.cpu().numpy()[0]
def key_point_show(image_path, key_point_result=None):
    """Debug helper: draw keypoints on an image and show it in a window.

    Args:
        image_path: Path of the image to load with ``cv2.imread``.
        key_point_result: Iterable of 2-item points; each point is reversed
            before being handed to ``cv2.circle``.

    Blocks until a key is pressed in the OpenCV window.
    """
    canvas = cv2.imread(image_path)
    radius = 1
    bgr_red = (0, 0, 255)  # BGR
    thickness = 4  # may be 0, 4 or 8
    for point in key_point_result:
        cv2.circle(canvas, point[::-1], radius, bgr_red, thickness)
    cv2.imshow("0", canvas)
    cv2.waitKey(0)
if __name__ == '__main__':
    # Manual smoke test: infer "up" keypoints for a local image and display them.
    image = cv2.imread("9070101c-e5be-49b5-9602-4113a968969b.png")
    a = get_keypoint_result(image, "up")
    new_list = []
    # NOTE(review): this prints the builtin ``list`` type, not ``new_list``.
    print(list)
    # NOTE(review): ``a`` is None when inference failed, so ``a[0]`` would raise.
    for i in a[0]:
        new_list.append((int(i[0]), int(i[1])))
    key_point_show("9070101c-e5be-49b5-9602-4113a968969b.png", new_list)
    # a = get_seg_result(1, image)
    print(a)

View File

@@ -0,0 +1,77 @@
import cv2
from app.core.config import PRIORITY_DICT
def organize_body(layer):
    """Build the compositing record for the body layer.

    The record always has priority 0, position ``(0, 0)`` and no mask. The
    ``sacle`` key spelling is kept as-is because it is part of the record
    layout produced by this module.
    """
    return {
        "priority": 0,
        "name": layer["name"].lower(),
        "image": layer['body_image'],
        "image_url": layer['body_path'],
        "mask_image": None,
        "mask_url": None,
        "sacle": 1,
        "position": (0, 0),
    }
def organize_clothing(layer):
    """Build the front and back compositing records for one garment.

    Priority comes from the layer itself when ``layer_order`` is truthy
    (the back piece gets the negated value), otherwise from
    ``PRIORITY_DICT`` keyed by ``<name>_front`` / ``<name>_back``.

    Returns:
        ``(front_layer, back_layer)`` dictionaries. Both masks are resized
        to the size of the front image.
    """
    # Start coordinate shared by both pieces
    start_point = calculate_start_point(
        layer['keypoint'],
        layer['scale'],
        layer['clothes_keypoint'],
        layer['body_point_test'],
        layer["offset"],
        layer["resize_scale"],
    )

    garment = layer["name"].lower()
    front_name = f'{garment}_front'
    back_name = f'{garment}_back'

    if layer.get("layer_order", False):
        front_priority = layer['priority']
        back_priority = -layer.get("priority", 0)
    else:
        front_priority = PRIORITY_DICT.get(front_name, None)
        back_priority = PRIORITY_DICT.get(back_name, None)

    # Front piece
    front_layer = {
        "priority": front_priority,
        "name": front_name,
        "image": layer["front_image"],
        "image_url": layer['front_image_url'],
        "mask_url": layer['mask_url'],
        "sacle": layer['scale'],
        "clothes_keypoint": layer['clothes_keypoint'],
        "position": start_point,
        "resize_scale": layer["resize_scale"],
        "mask": cv2.resize(layer['mask'], layer["front_image"].size),
        "gradient_string": layer.get('gradient_string', ""),
        "pattern_image_url": layer['pattern_image_url'],
        "pattern_image": layer['pattern_image'],
    }
    # Back piece (no 'pattern_image' entry)
    back_layer = {
        "priority": back_priority,
        "name": back_name,
        "image": layer["back_image"],
        "image_url": layer['back_image_url'],
        "mask_url": layer['mask_url'],
        "sacle": layer['scale'],
        "clothes_keypoint": layer['clothes_keypoint'],
        "position": start_point,
        "resize_scale": layer["resize_scale"],
        "mask": cv2.resize(layer['mask'], layer["front_image"].size),
        "gradient_string": layer.get('gradient_string', ""),
        "pattern_image_url": layer['pattern_image_url'],
    }
    return front_layer, back_layer
def calculate_start_point(keypoint_type, scale, clothes_point, body_point, offset, resize_scale):
    """Compute where a garment must be pasted so its left keypoint meets the body's.

    Args:
        keypoint_type: Keypoint family, e.g. "waistband" | "shoulder" |
            "ear_point"; the ``<keypoint_type>_left`` entry is used.
        scale: Scale applied to the garment keypoint.
        clothes_point: Mapping of keypoint name to a 2-item coordinate.
        body_point: Mapping of keypoint name to a 2-item coordinate.
        offset: 2-item manual offset added to the body coordinate.
        resize_scale: Unused; kept for interface compatibility.

    Returns:
        ``(int(body[1] + offset[1] - int(clothes[0]) * scale),
        int(body[0] + offset[0] - int(clothes[1]) * scale))``.
    """
    anchor = f'{keypoint_type}_left'
    body_anchor = body_point[anchor]
    clothes_anchor = clothes_point[anchor]
    first = body_anchor[1] + offset[1] - int(clothes_anchor[0]) * scale  # y
    second = body_anchor[0] + offset[0] - int(clothes_anchor[1]) * scale  # x
    return int(first), int(second)

View File

@@ -0,0 +1,30 @@
import logging
from app.service.design_fast.utils.redis_utils import Redis
logger = logging.getLogger(__name__)
def update_progress(process_id, total):
    """Advance the stored progress of ``process_id`` by one item out of ``total``.

    Args:
        process_id: Redis key holding the progress value.
        total: Number of items in the job; each call adds ``int(100 / total)``.

    Returns:
        The progress value read *before* this update (``None`` if absent).

    The written value never exceeds 100: if a step would overshoot, 99 is
    written instead. Previously only the old value was checked, so values
    above 100 could be stored.
    NOTE(review): the read-then-write is not atomic.
    """
    r = Redis()
    progress = r.read(key=process_id)
    if total == 1:
        r.write(key=process_id, value=100)
    elif progress:
        next_value = int(progress) + int(100 / total)
        r.write(key=process_id, value=next_value if next_value <= 100 else 99)
    else:
        r.write(key=process_id, value=int(100 / total))
    return progress
def final_progress(process_id):
    """Mark ``process_id`` as complete (100) and return the previous value."""
    store = Redis()
    previous = store.read(key=process_id)
    store.write(key=process_id, value=100)
    return previous

View File

@@ -0,0 +1,99 @@
import redis
from app.core.config import REDIS_HOST, REDIS_PORT
class Redis(object):
    """Convenience wrapper for Redis string and hash operations.

    Every method creates its own ``redis.StrictRedis`` client for database 0
    on ``REDIS_HOST``/``REDIS_PORT``.
    """

    @staticmethod
    def _get_r():
        """Return a new client bound to database 0."""
        return redis.StrictRedis(REDIS_HOST, REDIS_PORT, 0)

    @classmethod
    def write(cls, key, value, expire=None):
        """Set ``key`` to ``value``; expires after ``expire`` seconds (100 when falsy)."""
        cls._get_r().set(key, value, ex=expire or 100)

    @classmethod
    def read(cls, key):
        """Return the UTF-8 decoded value of ``key``, or the raw falsy reply."""
        raw = cls._get_r().get(key)
        if not raw:
            return raw
        return raw.decode('utf-8')

    @classmethod
    def hset(cls, name, key, value):
        """Set field ``key`` of hash ``name`` to ``value``."""
        cls._get_r().hset(name, key, value)

    @classmethod
    def hget(cls, name, key):
        """Return the UTF-8 decoded field ``key`` of hash ``name``, or the raw falsy reply."""
        raw = cls._get_r().hget(name, key)
        if not raw:
            return raw
        return raw.decode('utf-8')

    @classmethod
    def hgetall(cls, name):
        """Return all fields of hash ``name`` exactly as the client returns them."""
        return cls._get_r().hgetall(name)

    @classmethod
    def delete(cls, *names):
        """Delete one or more keys."""
        cls._get_r().delete(*names)

    @classmethod
    def hdel(cls, name, key):
        """Delete field ``key`` from hash ``name``."""
        cls._get_r().hdel(name, key)

    @classmethod
    def expire(cls, name, expire=None):
        """Set the time-to-live of ``name`` to ``expire`` seconds (100 when falsy)."""
        cls._get_r().expire(name, expire or 100)
if __name__ == '__main__':
    # Manual check: write a sample key using the default expiry.
    redis_client = Redis()
    # print(redis_client.write(key="1230", value=0))
    redis_client.write(key="1230", value=10)
    # print(redis_client.read(key="1230"))

View File

@@ -0,0 +1,13 @@
import json
import logging
logger = logging.getLogger()
def oss_upload_json(oss_client, json_data, object_name):
    """Dump ``json_data`` to a local file and upload it to the ``test`` bucket.

    The file is written under ``app/service/design_batch/response_json/``
    and uploaded with ``oss_client.fput_object``. Any failure is logged as
    a warning and swallowed (best effort).
    """
    try:
        local_path = f"app/service/design_batch/response_json/{object_name}"
        with open(local_path, 'w') as handle:
            json.dump(json_data, handle, indent=4)
        # Upload only after the file has been closed and flushed.
        oss_client.fput_object("test", object_name, local_path)
    except Exception as e:
        logger.warning(str(e))

View File

@@ -0,0 +1,197 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project trinity_client
@File synthesis_item.py
@Author :周成融
@Date 2023/8/26 14:13:04
@detail
"""
import io
import logging
import cv2
import numpy as np
from PIL import Image
from app.service.utils.generate_uuid import generate_uuid
from app.service.utils.oss_client import oss_upload_image
def positioning(all_mask_shape, mask_shape, offset):
    """Compute the overlapping slice bounds along one axis.

    Args:
        all_mask_shape: Length of the canvas along this axis.
        mask_shape: Length of the layer mask along this axis.
        offset: Position of the mask's origin on the canvas (may be negative).

    Returns:
        ``(all_start, all_end, mask_start, mask_end)``: the canvas slice and
        the mask slice to copy from.
    """
    if offset == 0:
        overlap = min(all_mask_shape, mask_shape)
        return 0, overlap, 0, overlap

    if offset > 0:
        # Mask starts inside (or beyond) the canvas.
        if offset > all_mask_shape:
            mask_end = 0
        else:
            mask_end = min(all_mask_shape - offset, mask_shape)
        return (
            min(offset, all_mask_shape),
            min(offset + mask_shape, all_mask_shape),
            0,
            mask_end,
        )

    if offset < 0:
        shift = abs(offset)
        if shift > mask_shape:
            # Mask lies entirely before the canvas: empty slices.
            return 0, 0, mask_shape, mask_shape
        remaining = mask_shape - shift
        all_end = all_mask_shape if remaining > all_mask_shape else remaining
        mask_end = all_mask_shape + shift if remaining >= all_mask_shape else mask_shape
        return 0, all_end, shift, mask_end

    return 0, 0, 0, 0
# @RunTime
def synthesis(data, size, basic_info):
    """Composite all layers onto one transparent canvas and upload the PNG.

    Args:
        data: Layer dicts; the keys read here are ``name``, ``image``,
            ``adaptive_position`` and (for garment fronts) ``mask``.
        size: ``(width, height)`` of the output canvas.
        basic_info: Mapping whose ``body_point_test`` entry provides
            ``shoulder_left`` / ``shoulder_right``.

    Returns:
        ``"aida-results/result_<uuid>.png"`` on success. Any exception is
        logged as a warning and the function then returns ``None``.
    """
    # Create the base canvas
    base_image = Image.new('RGBA', size, (0, 0, 0, 0))
    try:
        all_mask_shape = (size[1], size[0])
        body_mask = None
        for d in data:
            if d['name'] == 'body' or d['name'] == 'mannequin':
                # Paste the model onto a fresh transparent canvas and use its alpha as the body mask
                transparent_image = Image.new("RGBA", size, (0, 0, 0, 0))
                transparent_image.paste(d['image'], (d['adaptive_position'][1], d['adaptive_position'][0]), d['image'])  # position is read by index because paste may modify a mutable sequence argument
                body_mask = np.array(transparent_image.split()[3])
                # Shift the shoulder points by the body's position on the canvas
                left_shoulder = [x + y for x, y in zip(basic_info['body_point_test']['shoulder_left'], [d['adaptive_position'][1], d['adaptive_position'][0]])]
                right_shoulder = [x + y for x, y in zip(basic_info['body_point_test']['shoulder_right'], [d['adaptive_position'][1], d['adaptive_position'][0]])]
                # Everything above the higher shoulder, between the shoulders, counts as body
                body_mask[:min(left_shoulder[1], right_shoulder[1]), left_shoulder[0]:right_shoulder[0]] = 255
        # NOTE(review): body_mask is still None when no body/mannequin layer exists;
        # the resulting error is swallowed by the except below.
        _, binary_body_mask = cv2.threshold(body_mask, 127, 255, cv2.THRESH_BINARY)
        top_outer_mask = np.array(binary_body_mask)
        bottom_outer_mask = np.array(binary_body_mask)
        top = True
        bottom = True
        # Walk the layers from last to first, taking the first top and first bottom garment found
        i = len(data)
        while i:
            i -= 1
            if top and data[i]['name'] in ["blouse_front", "outwear_front", "dress_front", "tops_front"]:
                top = False
                mask_shape = data[i]['mask'].shape
                y_offset, x_offset = data[i]['adaptive_position']
                # Start/end bounds of the overlapping region
                all_y_start, all_y_end, mask_y_start, mask_y_end = positioning(all_mask_shape=all_mask_shape[0], mask_shape=mask_shape[0], offset=y_offset)
                all_x_start, all_x_end, mask_x_start, mask_x_end = positioning(all_mask_shape=all_mask_shape[1], mask_shape=mask_shape[1], offset=x_offset)
                # Copy the garment mask into the overlapping region
                _, sketch_mask = cv2.threshold(data[i]['mask'], 127, 255, cv2.THRESH_BINARY)
                background = np.zeros_like(top_outer_mask)
                background[all_y_start:all_y_end, all_x_start:all_x_end] = sketch_mask[mask_y_start:mask_y_end, mask_x_start:mask_x_end]
                top_outer_mask = background + top_outer_mask
            elif bottom and data[i]['name'] in ["trousers_front", "skirt_front", "bottoms_front", "dress_front"]:
                bottom = False
                mask_shape = data[i]['mask'].shape
                y_offset, x_offset = data[i]['adaptive_position']
                # Start/end bounds of the overlapping region
                all_y_start, all_y_end, mask_y_start, mask_y_end = positioning(all_mask_shape=all_mask_shape[0], mask_shape=mask_shape[0], offset=y_offset)
                all_x_start, all_x_end, mask_x_start, mask_x_end = positioning(all_mask_shape=all_mask_shape[1], mask_shape=mask_shape[1], offset=x_offset)
                # Copy the garment mask into the overlapping region
                _, sketch_mask = cv2.threshold(data[i]['mask'], 127, 255, cv2.THRESH_BINARY)
                background = np.zeros_like(top_outer_mask)
                background[all_y_start:all_y_end, all_x_start:all_x_end] = sketch_mask[mask_y_start:mask_y_end, mask_x_start:mask_x_end]
                bottom_outer_mask = background + bottom_outer_mask
            elif bottom is False and top is False:
                break
        # Union of body, outermost top and outermost bottom
        all_mask = cv2.bitwise_or(top_outer_mask, bottom_outer_mask)
        for layer in data:
            if layer['image'] is not None:
                if layer['name'] != "body":
                    test_image = Image.new('RGBA', size, (0, 0, 0, 0))
                    test_image.paste(layer['image'], (layer['adaptive_position'][1], layer['adaptive_position'][0]), layer['image'])
                    mask_data = np.where(all_mask > 0, 255, 0).astype(np.uint8)
                    mask_alpha = Image.fromarray(mask_data)
                    cropped_image = Image.composite(test_image, Image.new("RGBA", test_image.size, (255, 255, 255, 0)), mask_alpha)
                    base_image.paste(test_image, (0, 0), cropped_image)  # test_image is already positioned on the full-size canvas, so paste at (0, 0)
                else:
                    base_image.paste(layer['image'], (layer['adaptive_position'][1], layer['adaptive_position'][0]), layer['image'])
        result_image = base_image
        image_data = io.BytesIO()
        result_image.save(image_data, format='PNG')
        image_data.seek(0)
        # oss upload
        image_bytes = image_data.read()
        bucket_name = "aida-results"
        object_name = f'result_{generate_uuid()}.png'
        req = oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes)
        return f"{bucket_name}/{object_name}"
    except Exception as e:
        logging.warning(f"synthesis runtime exception : {e}")
def synthesis_single(front_image, back_image):
    """Merge the front and back images of one garment and upload the PNG.

    Args:
        front_image: Image used as the base, or a falsy value.
        back_image: Image pasted at (0, 0) using itself as the paste mask,
            or a falsy value.

    Returns:
        ``"aida-results/result_<uuid>.png"``.

    When only ``back_image`` is given it is uploaded on its own; previously
    that case failed with an AttributeError on ``None.paste``. When a front
    image is given it is modified in place by the paste, as before.
    """
    result_image = None
    if front_image:
        result_image = front_image
    if back_image:
        if result_image is None:
            # No front piece: the back image is the whole result.
            result_image = back_image
        else:
            result_image.paste(back_image, (0, 0), back_image)

    image_data = io.BytesIO()
    result_image.save(image_data, format='PNG')
    image_bytes = image_data.getvalue()

    # oss upload
    bucket_name = 'aida-results'
    object_name = f'result_{generate_uuid()}.png'
    oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes)
    return f"{bucket_name}/{object_name}"
def update_base_size_priority(layers, size):
    """Compute the canvas size and shift every layer so the leftmost starts at 0.

    Args:
        layers: Layer dicts with ``position`` and ``image`` entries; each
            gets an ``adaptive_position`` entry added in place.
        size: Unused; kept for interface compatibility.

    Returns:
        ``(layers, (new_width, 700))`` where ``new_width`` spans from the
        smallest ``position[1]`` to the largest right edge of any layer
        that has an image.
    """
    # Width of the transparent background
    left_edge = min(entry['position'][1] for entry in layers)
    right_edge = max(
        entry['position'][1] + entry['image'].width
        for entry in layers
        if entry['image'] is not None
    )
    canvas_size = (right_edge - left_edge, 700)
    # Update coordinates
    for entry in layers:
        first, second = entry['position'][0], entry['position'][1]
        entry['adaptive_position'] = (first, second - left_edge)
    return layers, canvas_size

View File

@@ -13,11 +13,11 @@ import logging
import cv2
from app.core.config import *
from app.service.utils.oss_client import oss_upload_image
from app.service.utils.new_oss_client import oss_upload_image
# @RunTime
def upload_png_mask(front_image, object_name, mask=None):
def upload_png_mask(minio_client, front_image, object_name, mask=None):
try:
mask_url = None
if mask is not None:
@@ -25,20 +25,14 @@ def upload_png_mask(front_image, object_name, mask=None):
# 将掩模的3通道转换为4通道白色部分不透明黑色部分透明
rgba_image = cv2.cvtColor(mask_inverted, cv2.COLOR_BGR2BGRA)
rgba_image[rgba_image[:, :, 0] == 0] = [0, 0, 0, 0]
# image_bytes = io.BytesIO()
# image_bytes.write(cv2.imencode('.png', rgba_image)[1].tobytes())
# image_bytes.seek(0)
# mask_url = f"{AIDA_CLOTHING}/{minio_client.put_object('aida-clothing', f'mask/mask_{object_name}.png', image_bytes, len(image_bytes.getvalue()), content_type='image/png').object_name}"
# oss upload ####################
req = oss_upload_image(bucket=AIDA_CLOTHING, object_name=f"mask/mask_{object_name}.png", image_bytes=cv2.imencode('.png', rgba_image)[1])
req = oss_upload_image(oss_client=minio_client, bucket=AIDA_CLOTHING, object_name=f"mask/mask_{object_name}.png", image_bytes=cv2.imencode('.png', rgba_image)[1])
mask_url = f"{AIDA_CLOTHING}/mask/mask_{object_name}.png"
image_data = io.BytesIO()
front_image.save(image_data, format='PNG')
image_data.seek(0)
image_bytes = image_data.read()
# image_url = f"{AIDA_CLOTHING}/{minio_client.put_object('aida-clothing', f'image/image_{object_name}.png', io.BytesIO(image_bytes), len(image_bytes), content_type='image/png').object_name}"
req = oss_upload_image(bucket=AIDA_CLOTHING, object_name=f"image/image_{object_name}.png", image_bytes=image_bytes)
req = oss_upload_image(oss_client=minio_client, bucket=AIDA_CLOTHING, object_name=f"image/image_{object_name}.png", image_bytes=image_bytes)
image_url = f"{AIDA_CLOTHING}/image/image_{object_name}.png"
return front_image, image_url, mask_url
except Exception as e:

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,61 @@
from app.service.design_fast.pipeline import LoadImage, KeyPoint, Segmentation, Color, PrintPainting, Scaling, Split, LoadBodyImage, ContourDetection
class BaseItem:
    """Common state for pipeline items: the working ``result`` dictionary."""

    def __init__(self, data, basic):
        """Build ``result`` from ``data`` and ``basic`` without mutating ``data``.

        ``type`` is replaced by a lower-cased ``name`` entry, then the
        entries of ``basic`` are merged in (overriding on key clashes).
        """
        working = data.copy()
        working['name'] = working.pop("type").lower()
        working.update(basic)
        self.result = working
class TopItem(BaseItem):
    """Item processed by the pipeline that uses the ``Segmentation`` stage."""

    def __init__(self, data, basic, minio_client):
        super().__init__(data, basic)
        # Stages run in this order; each receives and returns the result dict.
        self.top_pipeline = [
            LoadImage(minio_client),
            KeyPoint(),
            Segmentation(minio_client),
            Color(minio_client),
            PrintPainting(minio_client),
            Scaling(),
            Split(minio_client),
        ]

    def process(self):
        """Run every stage in order and return the final result dict."""
        for stage in self.top_pipeline:
            self.result = stage(self.result)
        return self.result
class BottomItem(BaseItem):
    """Item processed by the pipeline that uses ``ContourDetection`` for its mask."""

    def __init__(self, data, basic, minio_client):
        super().__init__(data, basic)
        # ContourDetection takes the place of the Segmentation stage here.
        self.bottom_pipeline = [
            LoadImage(minio_client),
            KeyPoint(),
            ContourDetection(),
            Color(minio_client),
            PrintPainting(minio_client),
            Scaling(),
            Split(minio_client),
        ]

    def process(self):
        """Run every stage in order and return the final result dict."""
        for stage in self.bottom_pipeline:
            self.result = stage(self.result)
        return self.result
class BodyItem(BaseItem):
    """Body item; its pipeline consists of the ``LoadBodyImage`` stage only."""

    def __init__(self, data, basic, minio_client):
        super().__init__(data, basic)
        self.top_pipeline = [LoadBodyImage(minio_client)]

    def process(self):
        """Run every stage in order and return the final result dict."""
        for stage in self.top_pipeline:
            self.result = stage(self.result)
        return self.result

View File

@@ -0,0 +1,20 @@
from .color import Color
from .contour_detection import ContourDetection
from .keypoint import KeyPoint
from .keypoint import KeyPoint
from .loading import LoadImage, LoadBodyImage
from .print_painting import PrintPainting
from .scale import Scaling
from .segmentation import Segmentation
from .split import Split
__all__ = [
'LoadBodyImage', 'LoadImage',
'KeyPoint',
'ContourDetection',
'Segmentation',
'Color',
'PrintPainting',
'Scaling',
'Split'
]

View File

@@ -0,0 +1,62 @@
import logging
import cv2
import numpy as np
from app.service.utils.new_oss_client import oss_get_image
logger = logging.getLogger()
class Color:
    """Pipeline stage that fills a garment with a flat colour or a gradient image."""

    def __init__(self, minio_client):
        # Object-storage client passed through to ``oss_get_image``.
        self.minio_client = minio_client

    def __call__(self, result):
        """Colour the garment and store the derived images in ``result``.

        Reads ``image``, ``mask``, ``gray`` and either ``gradient`` (a
        ``bucket/object`` path, when present and non-empty) or ``color``.
        Writes ``pattern_image``, ``final_image``, ``single_image`` and
        ``alpha``; returns the same ``result`` dict.
        """
        dim_image_h, dim_image_w = result['image'].shape[0:2]
        if "gradient" in result.keys() and result['gradient'] != "":
            bucket_name = result['gradient'].split('/')[0]
            object_name = result['gradient'][result['gradient'].find('/') + 1:]
            pattern = self.get_gradient(bucket_name=bucket_name, object_name=object_name)
        else:
            pattern = self.get_pattern(result['color'])
        resize_pattern = cv2.resize(pattern, (dim_image_w, dim_image_h), interpolation=cv2.INTER_AREA)

        # Shade the pattern with the grey sketch, restricted to the mask.
        closed_mo = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
        gray_mo = np.expand_dims(result['gray'], axis=2).repeat(3, axis=2)
        get_image_fir = resize_pattern * (closed_mo / 255) * (gray_mo / 255)
        result['pattern_image'] = get_image_fir.astype(np.uint8)
        result['final_image'] = result['pattern_image']

        # single_image: the coloured garment on a white background.
        canvas = np.full_like(result['final_image'], 255)
        temp_bg = np.expand_dims(cv2.bitwise_not(result['mask']), axis=2).repeat(3, axis=2)
        tmp1 = (canvas * (temp_bg / 255)).astype(np.uint8)
        temp_fg = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
        tmp2 = (result['final_image'] * (temp_fg / 255)).astype(np.uint8)
        result['single_image'] = cv2.add(tmp1, tmp2)
        result['alpha'] = 100 / 255.0
        return result

    def get_gradient(self, bucket_name, object_name):
        """Fetch the gradient image from object storage, dropping any alpha channel."""
        image = oss_get_image(oss_client=self.minio_client, bucket=bucket_name, object_name=object_name, data_type="cv2")
        if image.shape[2] == 4:
            image = cv2.cvtColor(image, cv2.COLOR_BGRA2BGR)
        return image

    @staticmethod
    def crop_image(image, image_size_h, image_size_w):
        """Return a crop of the requested size taken at a random offset."""
        x_offset = np.random.randint(low=0, high=int(image_size_h / 5) - 6)
        y_offset = np.random.randint(low=0, high=int(image_size_w / 5) - 6)
        image = image[x_offset: x_offset + image_size_h, y_offset: y_offset + image_size_w, :]
        return image

    @staticmethod
    def get_pattern(single_color):
        """Build a 1x1 BGR pixel from an ``"R G B"`` string.

        Raises:
            TypeError: if ``single_color`` is ``None``. The original
                ``raise False`` also surfaced as a TypeError, but with an
                unrelated message.
            ValueError: if the string does not hold three space-separated
                integers.
        """
        if single_color is None:
            raise TypeError("color must be an 'R G B' string, got None")
        R, G, B = single_color.split(' ')
        pattern = np.zeros([1, 1, 3], np.uint8)
        pattern[0, 0, 0] = int(B)
        pattern[0, 0, 1] = int(G)
        pattern[0, 0, 2] = int(R)
        return pattern

View File

@@ -0,0 +1,37 @@
import cv2
import numpy as np
class ContourDetection:
    """Pipeline stage that derives the garment mask from its largest outer contour."""

    def __call__(self, result):
        """Fill ``mask``, ``front_mask`` and ``back_mask`` in ``result``.

        The mask is the filled polygon approximation of the largest contour,
        or fully white when no contour is found; it is AND-ed with
        ``pre_mask`` when that is not ``None``.
        """
        image = result['image']
        contours = self.get_contours(image)
        if len(contours):
            mask = np.zeros(image.shape[:2], np.uint8)
            largest = contours[0]
            epsilon = 0.001 * cv2.arcLength(largest, True)
            polygon = cv2.approxPolyDP(largest, epsilon, True)
            cv2.drawContours(mask, [polygon], -1, 255, -1)
        else:
            mask = np.ones(image.shape[:2], np.uint8) * 255
            # TODO: fix images that come out partly transparent; planned for the next release
            # img2gray = cv2.cvtColor(result['image'], cv2.COLOR_BGR2GRAY)
            # ret, Mask = cv2.threshold(img2gray, 126, 255, cv2.THRESH_BINARY)
            # Mask = cv2.bitwise_not(Mask)

        pre_mask = result['pre_mask']
        result['mask'] = mask if pre_mask is None else cv2.bitwise_and(mask, pre_mask)
        result['front_mask'] = result['mask']
        result['back_mask'] = result['mask']
        return result

    @staticmethod
    def get_contours(image):
        """Return the external contours of ``image``, largest area first."""
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        edges = cv2.Canny(gray, 10, 150)
        # Dilate then erode with a 5x5 kernel to close small gaps in the edges.
        kernel = np.ones((5, 5), np.uint8)
        edges = cv2.dilate(edges, kernel=kernel, iterations=1)
        edges = cv2.erode(edges, kernel=kernel, iterations=1)
        found, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        return sorted(found, key=cv2.contourArea, reverse=True)

View File

@@ -0,0 +1,116 @@
import logging
import numpy as np
from pymilvus import MilvusClient
from app.core.config import *
from app.service.design_fast.utils.design_ensemble import get_keypoint_result
from app.service.utils.decorator import ClassCallRunTime, RunTime
logger = logging.getLogger(__name__)
class KeyPoint:
    """Pipeline stage that infers garment keypoints and stores them in Milvus."""

    name = "KeyPoint"

    @classmethod
    def get_name(cls):
        return cls.name

    @ClassCallRunTime
    def __call__(self, result):
        """Set ``result['clothes_keypoint']`` for supported garment names.

        The Milvus lookup (``keypoint_cache``) is currently bypassed: the
        model is always queried and the result is upserted afterwards.
        """
        if result['name'] in ['blouse', 'skirt', 'dress', 'outwear', 'trousers', 'tops', 'bottoms']:
            keypoint_infer_result, site = self.infer_keypoint_result(result)
            result['clothes_keypoint'] = self.save_keypoint_cache(result["image_id"], keypoint_infer_result, site)
        return result

    @staticmethod
    def infer_keypoint_result(result):
        """Run model inference; returns ``(keypoints, site)`` with site "up" or "down"."""
        site = 'up' if result['name'] in ['blouse', 'outwear', 'dress', 'tops'] else 'down'
        keypoint_infer_result = get_keypoint_result(result["image"], site)  # inference result
        return keypoint_infer_result, site

    @staticmethod
    def _as_keypoint_dict(vector):
        """Map a flat 24-value vector onto the configured keypoint field names."""
        return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, vector.reshape(12, 2).astype(int).tolist()))

    @staticmethod
    def _upsert_keypoint(keypoint_id, site, vector):
        """Best-effort upsert of one keypoint row; errors are logged, not raised."""
        data = [
            {"keypoint_id": keypoint_id,
             "keypoint_site": site,
             "keypoint_vector": vector.tolist()
             }
        ]
        try:
            client = MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS)
            try:
                client.upsert(collection_name=MILVUS_TABLE_KEYPOINT, data=data)
            finally:
                # Always release the client, also when the upsert fails.
                client.close()
        except Exception as e:
            logger.info(f"save keypoint cache milvus error : {e}")

    @staticmethod
    def save_keypoint_cache(keypoint_id, cache, site):
        """Pad the inferred keypoints to the full vector, upsert it and return it as a dict.

        "down" results are prefixed with 20 zeros and "up" results suffixed
        with 4 zeros, so both fill the 12x2 layout.
        """
        if site == "down":
            zeros = np.zeros(20, dtype=int)
            result = np.concatenate([zeros, cache.flatten()])
        else:
            zeros = np.zeros(4, dtype=int)
            result = np.concatenate([cache.flatten(), zeros])
        KeyPoint._upsert_keypoint(keypoint_id, site, result)
        return KeyPoint._as_keypoint_dict(result)

    @staticmethod
    def update_keypoint_cache(keypoint_id, infer_result, search_result, site):
        """Merge a fresh inference with the stored other half and upsert it as site "all"."""
        if site == "up":
            # "up" was inferred, so the stored vector supplies the "down" part
            result = np.concatenate([infer_result.flatten(), search_result[-4:]])
        else:
            # "down" was inferred, so the stored vector supplies the "up" part
            result = np.concatenate([search_result[:20], infer_result.flatten()])
        KeyPoint._upsert_keypoint(keypoint_id, "all", result)
        return KeyPoint._as_keypoint_dict(result)

    @RunTime
    def keypoint_cache(self, result, site):
        """Return cached keypoints for ``result['image_id']``, inferring when needed.

        Returns:
            The keypoint dict, or ``False`` when the lookup fails (the
            error is logged).
        """
        try:
            keypoint_id = result['image_id']
            client = MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS)
            try:
                res = client.query(
                    collection_name=MILVUS_TABLE_KEYPOINT,
                    filter=f"keypoint_id == {keypoint_id}",
                    output_fields=['keypoint_vector', 'keypoint_site']
                )
            finally:
                client.close()

            if len(res) == 0:
                # Nothing stored: infer and save
                keypoint_infer_result, site = self.infer_keypoint_result(result)
                return self.save_keypoint_cache(result['image_id'], keypoint_infer_result, site)
            if res[0]["keypoint_site"] == "all" or res[0]["keypoint_site"] == site:
                # Stored row already covers the requested site
                return self._as_keypoint_dict(np.array(res[0]['keypoint_vector']))
            # Stored row covers the other site only: infer and upgrade the row to "all"
            keypoint_infer_result, site = self.infer_keypoint_result(result)
            return self.update_keypoint_cache(result["image_id"], keypoint_infer_result, res[0]['keypoint_vector'], site)
        except Exception as e:
            logger.info(f"search keypoint cache milvus error {e}")
            return False

View File

@@ -0,0 +1,80 @@
import io
import logging
import cv2
import numpy as np
from PIL import Image
from app.service.utils.new_oss_client import oss_get_image
logger = logging.getLogger()
class LoadBodyImage:
    """Pipeline stage that loads the body (mannequin) image from object storage."""

    name = "LoadBodyImage"

    def __init__(self, minio_client):
        self.minio_client = minio_client

    @classmethod
    def get_name(cls):
        return cls.name

    def __call__(self, result):
        """Rename the item to ``mannequin`` and load ``body_image`` as a PIL image.

        ``result['body_path']`` has the form ``<bucket>/<object name>``.
        """
        result["name"] = "mannequin"
        path_parts = result['body_path'].split("/", 1)
        bucket = path_parts[0]
        object_name = path_parts[1]
        result['body_image'] = oss_get_image(
            oss_client=self.minio_client,
            bucket=bucket,
            object_name=object_name,
            data_type="PIL",
        )
        return result
class LoadImage:
    """Pipeline stage that loads a garment image and its basic metadata."""

    name = "LoadImage"

    def __init__(self, minio_client):
        self.minio_client = minio_client

    @classmethod
    def get_name(cls):
        return cls.name

    def __call__(self, result):
        """Populate ``image``, ``pre_mask``, ``gray``, ``keypoint`` and the shape entries."""
        result['image'], result['pre_mask'] = self.read_image(result['path'])
        result['gray'] = cv2.cvtColor(result['image'], cv2.COLOR_BGR2GRAY)
        result['keypoint'] = self.get_keypoint(result['name'])
        result['img_shape'] = result['image'].shape
        result['ori_shape'] = result['image'].shape
        return result

    def read_image(self, image_path):
        """Load ``<bucket>/<object name>`` and split off an alpha mask if present.

        Returns:
            ``(image, image_mask)``: a 3-channel image and the alpha channel
            of a 4-channel source (``None`` otherwise). Small images are
            doubled in size, and the mask is now doubled with them so both
            keep the same height and width.
        """
        image_mask = None
        image = oss_get_image(oss_client=self.minio_client, bucket=image_path.split("/", 1)[0], object_name=image_path.split("/", 1)[1], data_type="cv2")
        if len(image.shape) == 2:
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        if image.shape[2] == 4:  # four channels: the last one is the mask
            image_mask = image[:, :, 3]
            image = image[:, :, :3]
        # NOTE(review): tuple comparison is lexicographic, so any image with
        # height < 50 matches regardless of width -- confirm the intended rule.
        if image.shape[:2] <= (50, 50):
            # New size as (width, height)
            new_size = (image.shape[1] * 2, image.shape[0] * 2)
            image = cv2.resize(image, new_size, interpolation=cv2.INTER_LINEAR)
            if image_mask is not None:
                # Keep the mask aligned with the upscaled image.
                image_mask = cv2.resize(np.ascontiguousarray(image_mask), new_size, interpolation=cv2.INTER_NEAREST)
        return image, image_mask

    @staticmethod
    def get_keypoint(name):
        """Return the anchor keypoint family used for garment ``name``.

        Raises:
            KeyError: if ``name`` is not a supported category.
        """
        if name == 'blouse' or name == 'outwear' or name == 'dress' or name == 'tops':
            keypoint = 'shoulder'
        elif name == 'trousers' or name == 'skirt' or name == 'bottoms':
            keypoint = 'waistband'
        elif name == 'bag':
            keypoint = 'hand_point'
        elif name == 'shoes':
            keypoint = 'toe'
        elif name == 'hairstyle':
            keypoint = 'head_point'
        elif name == 'earring':
            keypoint = 'ear_point'
        else:
            raise KeyError(f"{name} does not belong to item category list: blouse, outwear, dress, trousers, skirt, "
                           f"bag, shoes, hairstyle, earring.")
        return keypoint

View File

@@ -0,0 +1,524 @@
import random
import cv2
import numpy as np
from PIL import Image
from app.service.utils.new_oss_client import oss_get_image
class PrintPainting:
def __init__(self, minio_client):
# Keep the object-storage client; it is passed as `oss_client` to every
# oss_get_image call made by this class.
self.minio_client = minio_client
def __call__(self, result):
# Paint the prints described by result['print'] onto the garment image.
# Three print groups are read: 'overall' (tiled over the garment), 'single'
# (individually placed prints) and 'element' (individually placed overlays).
# Reads result['pattern_image'], result['mask'] and result['gray']; writes
# result['print_image'], result['single_image'], result['final_image'] and
# (for overall prints) result['pattern_image']. Returns the same mapping.
single_print = result['print']['single']
overall_print = result['print']['overall']
element_print = result['print']['element']
result['single_image'] = None
result['print_image'] = None
# ---- overall print: tile the print and blend it through the garment mask ----
if overall_print['print_path_list']:
painting_dict = {'dim_image_h': result['pattern_image'].shape[0], 'dim_image_w': result['pattern_image'].shape[1]}
result['print_image'] = result['pattern_image']
if "print_angle_list" in overall_print.keys() and overall_print['print_angle_list'][0] != 0:
# Rotated overall print: build the tiling, rotate and crop it, then fit it back to the image size.
painting_dict = self.painting_collection(painting_dict, overall_print, print_trigger=True)
painting_dict['tile_print'] = self.rotate_crop_image(img=painting_dict['tile_print'], angle=-overall_print['print_angle_list'][0], crop=True)
painting_dict['mask_inv_print'] = self.rotate_crop_image(img=painting_dict['mask_inv_print'], angle=-overall_print['print_angle_list'][0], crop=True)
# resize to the sketch size
painting_dict['tile_print'] = self.resize_and_crop(img=painting_dict['tile_print'], target_width=painting_dict['dim_image_w'], target_height=painting_dict['dim_image_h'])
painting_dict['mask_inv_print'] = self.resize_and_crop(img=painting_dict['mask_inv_print'], target_width=painting_dict['dim_image_w'], target_height=painting_dict['dim_image_h'])
else:
painting_dict = self.painting_collection(painting_dict, overall_print, print_trigger=True, is_single=False)
result['print_image'] = self.printpaint(result, painting_dict, print_=True)
result['single_image'] = result['final_image'] = result['pattern_image'] = result['print_image']
# ---- single prints: compose every print on a black canvas, plus a matching mask canvas ----
if single_print['print_path_list']:
print_background = np.zeros((result['pattern_image'].shape[0], result['pattern_image'].shape[1], 3), dtype=np.uint8)
mask_background = np.zeros((result['pattern_image'].shape[0], result['pattern_image'].shape[1], 3), dtype=np.uint8)
for i in range(len(single_print['print_path_list'])):
image, image_mode = self.read_image(single_print['print_path_list'][i])
if image_mode == "RGBA":
# RGBA prints are handled with PIL: scale, rotate, then paste using the alpha channel as paste mask.
new_size = (int(image.width * single_print['print_scale_list'][i]), int(image.height * single_print['print_scale_list'][i]))
mask = image.split()[3]
resized_source = image.resize(new_size)
resized_source_mask = mask.resize(new_size)
rotated_resized_source = resized_source.rotate(-single_print['print_angle_list'][i])
rotated_resized_source_mask = resized_source_mask.rotate(-single_print['print_angle_list'][i])
source_image_pil = Image.fromarray(cv2.cvtColor(print_background, cv2.COLOR_BGR2RGB))
source_image_pil_mask = Image.fromarray(cv2.cvtColor(mask_background, cv2.COLOR_BGR2RGB))
source_image_pil.paste(rotated_resized_source, (int(single_print['location'][i][0]), int(single_print['location'][i][1])), rotated_resized_source)
source_image_pil_mask.paste(rotated_resized_source_mask, (int(single_print['location'][i][0]), int(single_print['location'][i][1])), rotated_resized_source_mask)
print_background = cv2.cvtColor(np.array(source_image_pil), cv2.COLOR_RGBA2BGR)
mask_background = cv2.cvtColor(np.array(source_image_pil_mask), cv2.COLOR_RGBA2BGR)
ret, mask_background = cv2.threshold(mask_background, 124, 255, cv2.THRESH_BINARY)
else:
# Prints without alpha: derive a mask via get_mask_inv and invert it.
mask = self.get_mask_inv(image)
mask = np.expand_dims(mask, axis=2)
mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR)
mask = cv2.bitwise_not(mask)
# coordinates must be recomputed after rotation
rotate_mask, _ = self.img_rotate(mask, single_print['print_angle_list'][i], single_print['print_scale_list'][i])
rotate_image, rotated_new_size = self.img_rotate(image, single_print['print_angle_list'][i], single_print['print_scale_list'][i])
x, y = int(single_print['location'][i][0] - rotated_new_size[0]), int(single_print['location'][i][1] - rotated_new_size[1])
image_x = print_background.shape[1]
image_y = print_background.shape[0]
print_x = rotate_image.shape[1]
print_y = rotate_image.shape[0]
# An earlier version cropped the right/bottom overflow first and was marked as buggy.
# The two checks cannot be done independently: cropping the right side first and then
# shifting for the left/top overflow corrupts the region. Order used here: shift first,
# then check, and crop last.
# If the print was rotated or touches the border, check whether the left/top bound is below 0.
if x <= 0:
rotate_image = rotate_image[:, -x:]
rotate_mask = rotate_mask[:, -x:]
start_x = x = 0
else:
start_x = x
if y <= 0:
rotate_image = rotate_image[-y:, :]
rotate_mask = rotate_mask[-y:, :]
start_y = y = 0
else:
start_y = y
# ------------------
# if the print extends past the image, crop the print
if x + print_x > image_x:
rotate_image = rotate_image[:, :image_x - x]
rotate_mask = rotate_mask[:, :image_x - x]
if y + print_y > image_y:
rotate_image = rotate_image[:image_y - y, :]
rotate_mask = rotate_mask[:image_y - y, :]
mask_background = self.stack_prin(mask_background, result['pattern_image'], rotate_mask, start_y, y, start_x, x)
print_background = self.stack_prin(print_background, result['pattern_image'], rotate_image, start_y, y, start_x, x)
# Restrict the prints to the garment mask, shade them with the gray image and merge with the background.
print_mask = cv2.bitwise_and(result['mask'], cv2.cvtColor(mask_background, cv2.COLOR_BGR2GRAY))
img_fg = cv2.bitwise_or(print_background, print_background, mask=print_mask)
img_bg = cv2.bitwise_and(result['pattern_image'], result['pattern_image'], mask=cv2.bitwise_not(print_mask))
mask_mo = np.expand_dims(print_mask, axis=2).repeat(3, axis=2)
gray_mo = np.expand_dims(result['gray'], axis=2).repeat(3, axis=2)
img_fg = (img_fg * (mask_mo / 255) * (gray_mo / 255)).astype(np.uint8)
result['final_image'] = cv2.add(img_bg, img_fg)
# single_image: the final image inside the garment mask on a white canvas.
canvas = np.full_like(result['final_image'], 255)
temp_bg = np.expand_dims(cv2.bitwise_not(result['mask']), axis=2).repeat(3, axis=2)
tmp1 = (canvas * (temp_bg / 255)).astype(np.uint8)
temp_fg = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
tmp2 = (result['final_image'] * (temp_fg / 255)).astype(np.uint8)
result['single_image'] = cv2.add(tmp1, tmp2)
# ---- element overlays: same placement logic as single prints, using the element_* keys ----
if element_print['element_path_list']:
print_background = np.zeros((result['final_image'].shape[0], result['final_image'].shape[1], 3), dtype=np.uint8)
mask_background = np.zeros((result['final_image'].shape[0], result['final_image'].shape[1], 3), dtype=np.uint8)
for i in range(len(element_print['element_path_list'])):
image, image_mode = self.read_image(element_print['element_path_list'][i])
if image_mode == "RGBA":
new_size = (int(image.width * element_print['element_scale_list'][i]), int(image.height * element_print['element_scale_list'][i]))
mask = image.split()[3]
resized_source = image.resize(new_size)
resized_source_mask = mask.resize(new_size)
rotated_resized_source = resized_source.rotate(-element_print['element_angle_list'][i])
rotated_resized_source_mask = resized_source_mask.rotate(-element_print['element_angle_list'][i])
source_image_pil = Image.fromarray(cv2.cvtColor(print_background, cv2.COLOR_BGR2RGB))
source_image_pil_mask = Image.fromarray(cv2.cvtColor(mask_background, cv2.COLOR_BGR2RGB))
source_image_pil.paste(rotated_resized_source, (int(element_print['location'][i][0]), int(element_print['location'][i][1])), rotated_resized_source)
source_image_pil_mask.paste(rotated_resized_source_mask, (int(element_print['location'][i][0]), int(element_print['location'][i][1])), rotated_resized_source_mask)
print_background = cv2.cvtColor(np.array(source_image_pil), cv2.COLOR_RGBA2BGR)
mask_background = cv2.cvtColor(np.array(source_image_pil_mask), cv2.COLOR_RGBA2BGR)
else:
mask = self.get_mask_inv(image)
mask = np.expand_dims(mask, axis=2)
mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR)
mask = cv2.bitwise_not(mask)
# coordinates must be recomputed after rotation
rotate_mask, _ = self.img_rotate(mask, element_print['element_angle_list'][i], element_print['element_scale_list'][i])
rotate_image, rotated_new_size = self.img_rotate(image, element_print['element_angle_list'][i], element_print['element_scale_list'][i])
x, y = int(element_print['location'][i][0] - rotated_new_size[0]), int(element_print['location'][i][1] - rotated_new_size[1])
image_x = print_background.shape[1]
image_y = print_background.shape[0]
print_x = rotate_image.shape[1]
print_y = rotate_image.shape[0]
# Same ordering as for single prints: shift first, then check, and crop last.
# If the print was rotated or touches the border, check whether the left/top bound is below 0.
if x <= 0:
rotate_image = rotate_image[:, -x:]
rotate_mask = rotate_mask[:, -x:]
start_x = x = 0
else:
start_x = x
if y <= 0:
rotate_image = rotate_image[-y:, :]
rotate_mask = rotate_mask[-y:, :]
start_y = y = 0
else:
start_y = y
# ------------------
# if the print extends past the image, crop the print
if x + print_x > image_x:
rotate_image = rotate_image[:, :image_x - x]
rotate_mask = rotate_mask[:, :image_x - x]
if y + print_y > image_y:
rotate_image = rotate_image[:image_y - y, :]
rotate_mask = rotate_mask[:image_y - y, :]
mask_background = self.stack_prin(mask_background, result['pattern_image'], rotate_mask, start_y, y, start_x, x)
print_background = self.stack_prin(print_background, result['pattern_image'], rotate_image, start_y, y, start_x, x)
print_mask = cv2.bitwise_and(result['mask'], cv2.cvtColor(mask_background, cv2.COLOR_BGR2GRAY))
img_fg = cv2.bitwise_or(print_background, print_background, mask=print_mask)
# TODO: element loses information
three_channel_image = cv2.merge([cv2.bitwise_not(print_mask), cv2.bitwise_not(print_mask), cv2.bitwise_not(print_mask)])
img_bg = cv2.bitwise_and(result['final_image'], three_channel_image)
# Unlike the single-print branch, the element foreground is not multiplied by result['gray'] here.
result['final_image'] = cv2.add(img_bg, img_fg)
canvas = np.full_like(result['final_image'], 255)
temp_bg = np.expand_dims(cv2.bitwise_not(result['mask']), axis=2).repeat(3, axis=2)
tmp1 = (canvas * (temp_bg / 255)).astype(np.uint8)
temp_fg = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
tmp2 = (result['final_image'] * (temp_fg / 255)).astype(np.uint8)
result['single_image'] = cv2.add(tmp1, tmp2)
return result
@staticmethod
def stack_prin(print_background, pattern_image, rotate_image, start_y, y, start_x, x):
    """Overlay ``rotate_image`` onto ``print_background`` at the given offset.

    A black layer with the height/width of ``pattern_image`` receives the
    patch; every layer pixel whose gray value exceeds the threshold of 1
    replaces the corresponding background pixel. Returns the merged image.
    """
    canvas_h, canvas_w = pattern_image.shape[0], pattern_image.shape[1]
    patch_h, patch_w = rotate_image.shape[0], rotate_image.shape[1]
    layer = np.zeros((canvas_h, canvas_w, 3), dtype=np.uint8)
    layer[start_y:y + patch_h, start_x:x + patch_w] = rotate_image
    # Pixels covered by the new patch.
    _, occupied = cv2.threshold(cv2.cvtColor(layer, cv2.COLOR_BGR2GRAY), 1, 255, cv2.THRESH_BINARY)
    kept_background = cv2.bitwise_and(print_background, print_background, mask=cv2.bitwise_not(occupied))
    new_foreground = cv2.bitwise_and(layer, layer, mask=occupied)
    return kept_background + new_foreground
def painting_collection(self, painting_dict, print_dict, print_trigger=False, is_single=False):
    """Fill ``painting_dict`` with the print tile and its inverse mask.

    Nothing is done unless ``print_trigger`` is true. Otherwise the print is
    loaded through ``get_print`` and the keys ``Trigger``, ``location``,
    ``mask_inv_print``, ``tile_print``, ``dim_print_h`` and ``dim_print_w``
    are written. For non-single prints a fresh ``self.random_seed`` is drawn
    and the print is tiled; when a non-zero first angle is present the tiling
    is made square (longest image side) so it can be rotated and cropped.
    """
    if not print_trigger:
        return painting_dict

    print_info = self.get_print(print_dict)
    painting_dict['Trigger'] = not is_single
    painting_dict['location'] = print_info['location']
    mask_source = self.get_mask_inv(print_info['image'])
    longest_side = max(painting_dict['dim_image_h'], painting_dict['dim_image_w'])
    side = int(longest_side * print_info['scale'] / 5)
    dim_pattern = (side, side)

    tile_h, tile_w = painting_dict['dim_image_h'], painting_dict['dim_image_w']
    tile_options = {}
    if not is_single:
        self.random_seed = random.randint(0, 1000)
        tile_options['trigger'] = True
        # Overall print with an angle: use a square canvas to simplify cropping.
        if "print_angle_list" in print_dict and print_dict['print_angle_list'][0] != 0:
            tile_h = tile_w = longest_side

    for key, source in (('mask_inv_print', mask_source), ('tile_print', print_info['image'])):
        painting_dict[key] = self.tile_image(source, dim_pattern, print_info['scale'], tile_h, tile_w, painting_dict['location'], **tile_options)
    painting_dict['dim_print_h'], painting_dict['dim_print_w'] = dim_pattern
    return painting_dict
def tile_image(self, pattern, dim, scale, dim_image_h, dim_image_w, location, trigger=False):
    """Resize ``pattern`` to ``dim`` and, when ``trigger`` is set, tile and crop it.

    Without ``trigger`` the resized pattern is returned as is. With
    ``trigger`` it is repeated ``int(6 / scale) + 4`` times along both image
    axes and then cut to ``dim_image_h`` x ``dim_image_w`` by ``crop_image``.
    """
    resized = cv2.resize(pattern, dim, interpolation=cv2.INTER_AREA)
    if not trigger:
        return resized

    tiled = None
    rank = len(pattern.shape)
    if rank in (2, 3):
        repeats = int((5 + 1) / scale) + 4
        reps = (repeats, repeats) if rank == 2 else (repeats, repeats, 1)
        tiled = np.tile(resized, reps)
    return self.crop_image(tiled, dim_image_h, dim_image_w, location, resized.shape)
def get_mask_inv(self, print_):
    """Return a mask of pixels close to the print's background colour.

    Only prints whose top-left pixel is pure white (255, 255, 255) are
    analysed: the image is converted to LAB and every pixel within the
    tolerance returned by ``get_low_high_lab`` around the top-left pixel is
    marked via ``cv2.inRange``. For any other top-left colour an all-zero
    mask with the print's height and width is returned.
    """
    corner = print_[0][0]
    if not (corner[0] == 255 and corner[1] == 255 and corner[2] == 255):
        return np.zeros(print_.shape[:2], dtype=np.uint8)

    lab_image = cv2.cvtColor(print_, cv2.COLOR_BGR2LAB)
    background = lab_image[0][0]
    l_high, l_low = self.get_low_high_lab(background[0], L=True)
    a_high, a_low = self.get_low_high_lab(background[1])
    b_high, b_low = self.get_low_high_lab(background[2])
    lower_bound = np.array([l_low, a_low, b_low])
    upper_bound = np.array([l_high, a_high, b_high])
    return cv2.inRange(lab_image, lower_bound, upper_bound)
@staticmethod
def printpaint(result, painting_dict, print_=False):
    """Blend a print into ``result['pattern_image']`` through the garment mask.

    With ``print_`` set and ``painting_dict['Trigger']`` true, the tiled print
    is limited to the garment mask minus the print's background mask. Otherwise
    ``result['final_image']`` is used as foreground with the plain garment
    mask. The foreground is shaded by ``result['gray']`` and added to the
    untouched background.

    Raises:
        ValueError: when ``print_`` is set, ``Trigger`` is false and
            ``painting_dict`` carries no usable ``location``. The original
            code hid this behind a bare ``except`` and an always-true
            ``assert`` and then failed later on ``range(None)``.
    """
    if print_ and painting_dict['Trigger']:
        print_mask = cv2.bitwise_and(result['mask'], cv2.bitwise_not(painting_dict['mask_inv_print']))
        img_fg = cv2.bitwise_and(painting_dict['tile_print'], painting_dict['tile_print'], mask=print_mask)
    else:
        print_mask = result['mask']
        img_fg = result['final_image']
        if print_:
            try:
                placements = len(painting_dict['location'])
            except (KeyError, TypeError) as err:
                raise ValueError('there must be parameter of location if choose IfSingle') from err
            dim_h, dim_w = painting_dict['dim_print_h'], painting_dict['dim_print_w']
            for i in range(placements):
                start_h, start_w = int(painting_dict['location'][i][1]), int(painting_dict['location'][i][0])
                end_h = min(start_h + dim_h, img_fg.shape[0])
                end_w = min(start_w + dim_w, img_fg.shape[1])
                # NOTE(review): the region is written back unchanged, so this
                # loop currently does not alter img_fg — confirm intent.
                change_region = img_fg[start_h: end_h, start_w: end_w, :]
                img_fg[start_h:start_h + dim_h, start_w:start_w + dim_w, :] = change_region

    clothes_mask_print = cv2.bitwise_not(print_mask)
    img_bg = cv2.bitwise_and(result['pattern_image'], result['pattern_image'], mask=clothes_mask_print)
    # Weight the foreground by the mask and by the gray image (both scaled to 0..1).
    mask_mo = np.expand_dims(print_mask, axis=2).repeat(3, axis=2)
    gray_mo = np.expand_dims(result['gray'], axis=2).repeat(3, axis=2)
    img_fg = (img_fg * (mask_mo / 255) * (gray_mo / 255)).astype(np.uint8)
    return cv2.add(img_bg, img_fg)
def get_print(self, print_dict):
    """Resolve the scale of the first print, download it and store it as BGR.

    Writes ``print_dict['scale']`` (the first entry of ``print_scale_list``,
    or 0.3 when the list is absent or the entry is below 0.3) and
    ``print_dict['image']``, then returns ``print_dict``.
    """
    if 'print_scale_list' in print_dict and not print_dict['print_scale_list'][0] < 0.3:
        print_dict['scale'] = print_dict['print_scale_list'][0]
    else:
        print_dict['scale'] = 0.3

    path_parts = print_dict['print_path_list'][0].split("/", 1)
    pil_image = oss_get_image(oss_client=self.minio_client, bucket=path_parts[0], object_name=path_parts[1], data_type="PIL")
    # RGBA prints are pasted onto a white canvas so transparent areas do not turn black.
    if pil_image.mode == "RGBA":
        white_canvas = Image.new('RGB', pil_image.size, (255, 255, 255))
        white_canvas.paste(pil_image, mask=pil_image.split()[3])
        pil_image = white_canvas
    print_dict['image'] = cv2.cvtColor(np.asarray(pil_image), cv2.COLOR_RGB2BGR)
    return print_dict
def crop_image(self, image, image_size_h, image_size_w, location, print_shape):
    """Cut an ``image_size_h`` x ``image_size_w`` window out of a tiled print.

    The window start is derived from ``location[0]`` modulo the size of one
    tile (``print_shape`` is ``(height, width, ...)``): ``location[0][1]``
    drives the row offset and ``location[0][0]`` the column offset.

    The original code used the tile width for the row offset and mixed width
    and height for the column offset; results are unchanged for square tiles
    and corrected for non-square ones.
    """
    print_h, print_w = print_shape[0], print_shape[1]
    # Kept for compatibility: the global RNG is reseeded with the seed drawn
    # in painting_collection, although no random offset is drawn here.
    random.seed(self.random_seed)
    row_offset = print_h - int(location[0][1] % print_h)
    col_offset = print_w - int(location[0][0] % print_w)
    if len(image.shape) in (2, 3):
        image = image[row_offset: row_offset + image_size_h, col_offset: col_offset + image_size_w]
    return image
@staticmethod
def get_low_high_lab(Lab_value, L=False):
if L:
high = Lab_value + 30 if Lab_value + 30 < 255 else 255
low = Lab_value - 30 if Lab_value - 30 > 0 else 0
else:
high = Lab_value + 30 if Lab_value + 30 < 255 else 255
low = Lab_value - 30 if Lab_value - 30 > 0 else 0
return high, low
@staticmethod
def img_rotate(image, angel, scale):
    """Rotate and scale an image on a canvas enlarged to hold the result.

    Args:
        image (np.array): source image.
        angel (float): rotation angle; it is negated before being passed to
            ``cv2.getRotationMatrix2D``.
        scale (float): scale factor applied together with the rotation.

    Returns:
        tuple: the warped image and an ``(x, y)`` pair holding half the
        difference between the output size and the scaled input size.
    """
    src_h, src_w = image.shape[:2]
    matrix = cv2.getRotationMatrix2D((src_w // 2, src_h // 2), -angel, scale)
    # Size of the output canvas for the rotated image.
    abs_cos = np.abs(matrix[0, 0])
    abs_sin = np.abs(matrix[0, 1])
    out_h = int((src_w * abs_sin + (src_h * abs_cos)))
    out_w = int((src_h * abs_sin + (src_w * abs_cos)))
    # Shift the transform so the content is centred on the new canvas.
    matrix[0, 2] += (out_w - src_w) // 2
    matrix[1, 2] += (out_h - src_h) // 2
    warped = cv2.warpAffine(image, matrix, (out_w, out_h))
    offset_x = (warped.shape[1] - image.shape[1] * scale) // 2
    offset_y = (warped.shape[0] - image.shape[0] * scale) // 2
    return warped, (offset_x, offset_y)
@staticmethod
def rotate_crop_image(img, angle, crop):
    """Rotate ``img`` about its centre and optionally crop away the empty corners.

    Args:
        img: image array; ``img.shape[:2]`` is read as ``(height, width)``.
        angle: rotation angle in degrees, reduced modulo 360.
        crop: when true, the centred rectangle computed from the angle is cut
            out so that the black border introduced by the rotation is removed.

    Fixes over the original:
      * height and width were unpacked in swapped order (``w, h = shape[:2]``),
        which only worked for square inputs;
      * the fold to the 0..90 degree range tested ``angle`` instead of the
        already reduced ``angle_crop``, so angles in [180, 270] gave a
        negative crop factor.
    """
    h, w = img.shape[:2]
    # Rotation has a period of 360 degrees.
    angle %= 360
    M_rotation = cv2.getRotationMatrix2D((w / 2, h / 2), angle, 1)
    img_rotated = cv2.warpAffine(img, M_rotation, (w, h))
    if crop:
        # The crop geometry has a period of 180 degrees and is mirrored around 90.
        angle_crop = angle % 180
        if angle_crop > 90:
            angle_crop = 180 - angle_crop
        theta = angle_crop * np.pi / 180
        hw_ratio = float(h) / float(w)
        tan_theta = np.tan(theta)
        # Numerator and denominator of the side-length factor of the crop.
        numerator = np.cos(theta) + np.sin(theta) * tan_theta
        r = hw_ratio if h > w else 1 / hw_ratio
        denominator = r * tan_theta + 1
        crop_mult = numerator / denominator
        w_crop = int(crop_mult * w)
        h_crop = int(crop_mult * h)
        x0 = int((w - w_crop) / 2)
        y0 = int((h - h_crop) / 2)
        img_rotated = img_rotated[y0:y0 + h_crop, x0:x0 + w_crop]
    return img_rotated
def read_image(self, image_url):
    """Download a print and report whether it carries an alpha channel.

    Returns ``(image, mode)``: for four-channel input a PIL image in RGBA
    order with mode ``"RGBA"``; otherwise the array as downloaded with mode
    ``"RGB"``.

    Single-channel (2-D) arrays are expanded to three channels first. The
    original indexed ``image.shape[2]`` unconditionally, which raises
    IndexError for such input; ``LoadImage.read_image`` already guards the
    same download this way.
    """
    url_parts = image_url.split("/", 1)
    image = oss_get_image(oss_client=self.minio_client, bucket=url_parts[0], object_name=url_parts[1], data_type="cv2")
    if len(image.shape) == 2:
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
    if image.shape[2] == 4:
        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA)
        image = Image.fromarray(image_rgb)
        image_mode = "RGBA"
    else:
        image_mode = "RGB"
    return image, image_mode
@staticmethod
def resize_and_crop(img, target_width, target_height):
    """Scale ``img`` to cover the target size, then centre-crop the excess.

    If the source is relatively wider than the target it is resized to the
    target height and the width is cropped; otherwise it is resized to the
    target width and the height is cropped.
    """
    src_h, src_w = img.shape[:2]
    target_ratio = target_width / target_height
    source_ratio = src_w / src_h

    if source_ratio > target_ratio:
        # Source is wider: match the height, trim left and right.
        scaled_w = int(src_w * (target_height / src_h))
        scaled = cv2.resize(img, (scaled_w, target_height))
        left = (scaled_w - target_width) // 2
        return scaled[:, left:left + target_width]

    # Source is taller (or same ratio): match the width, trim top and bottom.
    scaled_h = int(src_h * (target_width / src_w))
    scaled = cv2.resize(img, (target_width, scaled_h))
    top = (scaled_h - target_height) // 2
    return scaled[top:top + target_height, :]

View File

@@ -0,0 +1,49 @@
import math
import cv2
class Scaling:
    """Pipeline stage that derives ``result['scale']`` for the current item."""

    def __call__(self, result):
        """Compute the body-to-item scale according to ``result['keypoint']``.

        * ``waistband`` / ``shoulder`` / ``head_point``: ratio of the body
          keypoint distance to the clothes keypoint distance (1 when the
          clothes distance is 0).
        * ``toe``: foot length divided by the width of the largest contour
          found in ``result['gray']``; falls back to 1 when no contour is
          found (the original raised IndexError in that case).
        * ``hand_point`` / ``ear_point``: copied from ``scale_bag`` /
          ``scale_earrings``.

        Any other keypoint leaves ``result`` untouched.
        """
        keypoint = result['keypoint']
        if keypoint in ['waistband', 'shoulder', 'head_point']:
            left_key, right_key = keypoint + '_left', keypoint + '_right'
            clo_left = result['clothes_keypoint'][left_key]
            clo_right = result['clothes_keypoint'][right_key]
            distance_clo = math.sqrt(
                (int(clo_left[0]) - int(clo_right[0])) ** 2
                + (int(clo_left[1]) - int(clo_right[1])) ** 2
            )
            body_left = result['body_point_test'][left_key]
            body_right = result['body_point_test'][right_key]
            # NOTE(review): only the x difference is used here, plus a constant 1
            # under the square root — kept as in the original; confirm intent.
            distance_bdy = math.sqrt((int(body_left[0]) - int(body_right[0])) ** 2 + 1)
            result['scale'] = 1 if distance_clo == 0 else distance_bdy / distance_clo
        elif keypoint == 'toe':
            foot = result['body_point_test']['foot_length']
            distance_bdy = math.sqrt(
                (int(foot[0]) - int(foot[2])) ** 2
                + (int(foot[1]) - int(foot[3])) ** 2
            )
            blurred = cv2.GaussianBlur(result['gray'], (3, 3), 0)
            edges = cv2.Canny(blurred, 10, 200)
            edges = cv2.dilate(edges, None)
            edges = cv2.erode(edges, None)
            contours, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            if len(contours) > 0:
                # Only the largest contour is needed, so avoid sorting all of them.
                _, _, width, _ = cv2.boundingRect(max(contours, key=cv2.contourArea))
                result['scale'] = distance_bdy / width
            else:
                result['scale'] = 1
        elif keypoint == 'hand_point':
            result['scale'] = result['scale_bag']
        elif keypoint == 'ear_point':
            result['scale'] = result['scale_earrings']
        return result

View File

@@ -0,0 +1,85 @@
import logging
import os
import cv2
import numpy as np
from app.core.config import SEG_CACHE_PATH
from app.service.design_fast.utils.design_ensemble import get_seg_result
from app.service.utils.decorator import ClassCallRunTime
from app.service.utils.new_oss_client import oss_get_image
logger = logging.getLogger()
class Segmentation:
    """Pipeline stage that produces front/back garment masks.

    Masks come either from an uploaded colour mask (``seg_mask_url``) or from
    the segmentation model, whose output may be cached on disk as ``.npy``.
    """

    def __init__(self, minio_client):
        self.minio_client = minio_client

    @ClassCallRunTime
    def __call__(self, result):
        """Fill ``front_mask``, ``back_mask`` and ``mask`` on ``result`` and return it."""
        if "seg_mask_url" in result and result['seg_mask_url'] != "":
            self._masks_from_uploaded(result)
        else:
            self._masks_from_model(result)
        result['mask'] = result['front_mask'] + result['back_mask']
        return result

    def _masks_from_uploaded(self, result):
        """Derive the masks from a downloaded colour mask: red > green is front, green > red is back."""
        mask_url = result['seg_mask_url']
        seg_mask = oss_get_image(oss_client=self.minio_client, bucket=mask_url.split('/')[0], object_name=mask_url[mask_url.find('/') + 1:], data_type="cv2")
        target_size = (result['img_shape'][1], result['img_shape'][0])
        seg_mask = cv2.resize(seg_mask, target_size, interpolation=cv2.INTER_NEAREST)
        # Convert to RGB channel order before comparing channels.
        r, g, b = cv2.split(cv2.cvtColor(seg_mask, cv2.COLOR_BGR2RGB))
        result['front_mask'] = np.array(r > g, dtype=np.uint8) * 255
        result['back_mask'] = np.array(g > r, dtype=np.uint8) * 255

    def _masks_from_model(self, result):
        """Obtain the label map (model or cache) and split labels 1 / 2 into front / back."""
        mode = result['preview_submit'] if "preview_submit" in result else None
        if mode == "preview":
            # preview: run the model, do not cache
            seg_result = get_seg_result(result["image_id"], result['image'])[0]
        elif mode == "submit":
            # submit: run the model and cache the output
            seg_result = get_seg_result(result["image_id"], result['image'])[0]
            self.save_seg_result(seg_result, result['image_id'])
        else:
            # default flow: use the local cache, re-run the model if it is missing
            # or its shape differs from the image
            cache_hit, seg_result = self.load_seg_result(result["image_id"])
            if not cache_hit or result["image"].shape[:2] != seg_result.shape:
                seg_result = get_seg_result(result["image_id"], result['image'])[0]
                self.save_seg_result(seg_result, result['image_id'])
        result['seg_result'] = seg_result
        result['front_mask'] = np.where(seg_result == 1.0, 255, 0).astype(np.uint8)
        result['back_mask'] = np.where(seg_result == 2.0, 255, 0).astype(np.uint8)

    @staticmethod
    def save_seg_result(seg_result, image_id):
        """Write the label map to ``{SEG_CACHE_PATH}{image_id}.npy``; failures are only logged."""
        file_path = f"{SEG_CACHE_PATH}{image_id}.npy"
        try:
            np.save(file_path, seg_result)
            logger.info(f"保存成功 {os.path.abspath(file_path)}")
        except Exception as e:
            logger.error(f"保存失败: {e}")

    @staticmethod
    def load_seg_result(image_id):
        """Return ``(True, array)`` when the cache file loads, else ``(False, None)``."""
        file_path = f"{SEG_CACHE_PATH}{image_id}.npy"
        logger.info(f"load seg file name is :{SEG_CACHE_PATH}{image_id}.npy")
        try:
            cached = np.load(file_path)
        except FileNotFoundError:
            logger.warning("文件不存在")
            return False, None
        except Exception as e:
            logger.error(f"加载失败: {e}")
            return False, None
        return True, cached

View File

@@ -0,0 +1,74 @@
import io
import logging
import cv2
import numpy as np
from PIL import Image
from cv2 import cvtColor, COLOR_BGR2RGBA
from app.core.config import AIDA_CLOTHING
from app.service.design_fast.utils.conversion_image import rgb_to_rgba
from app.service.design_fast.utils.upload_image import upload_png_mask
from app.service.utils.generate_uuid import generate_uuid
from app.service.utils.new_oss_client import oss_upload_image
class Split(object):
# Pipeline stage that splits the painted garment into front/back RGBA layers,
# uploads them together with a colour-coded mask, and uploads the pattern layer.
def __init__(self, minio_client):
# Storage client passed to upload_png_mask / oss_upload_image.
self.minio_client = minio_client
def __call__(self, result):
# Reads result['final_image'], 'front_mask', 'back_mask', 'scale', 'resize_scale',
# 'pattern_image' and 'mask'; writes the *_image / *_url keys seen below.
# NOTE(review): the except branch at the end only logs a warning; no value is
# returned from it.
try:
if result['name'] in ('outwear', 'dress', 'blouse', 'skirt', 'trousers', 'tops', 'bottoms'):
front_mask = result['front_mask']
back_mask = result['back_mask']
# Alpha comes from the union of both masks; the image is then scaled by scale * resize_scale.
rgba_image = rgb_to_rgba(result['final_image'], front_mask + back_mask)
new_size = (int(rgba_image.shape[1] * result["scale"] * result["resize_scale"][0]), int(rgba_image.shape[0] * result["scale"] * result["resize_scale"][1]))
rgba_image = cv2.resize(rgba_image, new_size)
# Front layer: copy only the pixels selected by the resized front mask.
result_front_image = np.zeros_like(rgba_image)
front_mask = cv2.resize(front_mask, new_size)
result_front_image[front_mask != 0] = rgba_image[front_mask != 0]
result_front_image_pil = Image.fromarray(cvtColor(result_front_image, COLOR_BGR2RGBA))
result['front_image'], result["front_image_url"], _ = upload_png_mask(self.minio_client, result_front_image_pil, f'{generate_uuid()}', mask=None)
# Colour-coded mask image: front pixels are written as [0, 0, 255].
height, width = front_mask.shape
mask_image = np.zeros((height, width, 3))
mask_image[front_mask != 0] = [0, 0, 255]
if result["name"] in ('blouse', 'dress', 'outwear', 'tops'):
# These categories also get a back layer; back pixels are written as [0, 255, 0].
result_back_image = np.zeros_like(rgba_image)
back_mask = cv2.resize(back_mask, new_size)
result_back_image[back_mask != 0] = rgba_image[back_mask != 0]
result_back_image_pil = Image.fromarray(cvtColor(result_back_image, COLOR_BGR2RGBA))
result['back_image'], result["back_image_url"], _ = upload_png_mask(self.minio_client, result_back_image_pil, f'{generate_uuid()}', mask=None)
mask_image[back_mask != 0] = [0, 255, 0]
rbga_mask = rgb_to_rgba(mask_image, front_mask + back_mask)
mask_pil = Image.fromarray(cvtColor(rbga_mask.astype(np.uint8), COLOR_BGR2RGBA))
# Encode the mask as PNG in memory and upload it.
image_data = io.BytesIO()
mask_pil.save(image_data, format='PNG')
image_data.seek(0)
image_bytes = image_data.read()
req = oss_upload_image(oss_client=self.minio_client, bucket=AIDA_CLOTHING, object_name=f"mask/mask_{generate_uuid()}.png", image_bytes=image_bytes)
result['mask_url'] = req.bucket_name + "/" + req.object_name
else:
# Front-only categories: the mask alpha uses the front mask alone and no back layer is stored.
rbga_mask = rgb_to_rgba(mask_image, front_mask)
mask_pil = Image.fromarray(cvtColor(rbga_mask.astype(np.uint8), COLOR_BGR2RGBA))
image_data = io.BytesIO()
mask_pil.save(image_data, format='PNG')
image_data.seek(0)
image_bytes = image_data.read()
req = oss_upload_image(oss_client=self.minio_client, bucket=AIDA_CLOTHING, object_name=f"mask/mask_{generate_uuid()}.png", image_bytes=image_bytes)
result['mask_url'] = req.bucket_name + "/" + req.object_name
result['back_image'] = None
result["back_image_url"] = None
# create the intermediate (pattern) layer
result_pattern_image_rgba = rgb_to_rgba(result['pattern_image'], result['mask'])
result_pattern_image_pil = Image.fromarray(cvtColor(result_pattern_image_rgba, COLOR_BGR2RGBA))
result['pattern_image'], result['pattern_image_url'], _ = upload_png_mask(self.minio_client, result_pattern_image_pil, f'{generate_uuid()}')
return result
except Exception as e:
logging.warning(f"split runtime exception : {e} image_id : {result['image_id']}")

View File

@@ -0,0 +1,31 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project trinity_client
@File conversion_image.py
@Author :周成融
@Date 2023/8/21 10:40:29
@detail
"""
import numpy as np
# def rgb_to_rgba(rgb_size, rgb_image, mask):
# alpha_channel = np.full(rgb_size, 255, dtype=np.uint8)
# # 创建四通道的结果图像
# rgba_image = np.dstack((rgb_image, alpha_channel))
# alpha_channel = np.where(mask > 0, 255, 0)
# # 更新RGBA图像的透明度通道
# rgba_image[:, :, 3] = alpha_channel
# return rgba_image
def rgb_to_rgba(rgb_image, mask):
    """Append an alpha channel derived from ``mask`` to ``rgb_image``.

    Alpha is 255 wherever ``mask`` is greater than 0 and 0 elsewhere; the
    channel is stacked behind the existing channels with ``np.dstack``.
    """
    opaque = mask > 0
    alpha = opaque.astype(np.uint8) * 255
    return np.dstack((rgb_image, alpha))
if __name__ == '__main__':
    # Manual smoke check. The previous stub called open(""), which always
    # raises FileNotFoundError and never closed the handle.
    demo_rgb = np.zeros((2, 2, 3), dtype=np.uint8)
    demo_mask = np.array([[0, 255], [255, 0]], dtype=np.uint8)
    print(rgb_to_rgba(demo_rgb, demo_mask)[:, :, 3])

View File

@@ -0,0 +1,143 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project trinity_client
@File design_ensemble.py
@Author :周成融
@Date 2023/8/16 19:36:21
@detail :发起请求 获取推理结果
"""
import logging
import cv2
import mmcv
import numpy as np
import torch
import torch.nn.functional as F
import tritonclient.http as httpclient
from app.core.config import *
"""
keypoint
预处理 推理 后处理
"""
def keypoint_preprocess(img_path):
    """Read an image and build the keypoint-model input batch.

    The image is resized to 256x256, normalized with fixed per-channel
    mean/std (converted to RGB), moved to channel-first order and given a
    leading batch axis.

    Args:
        img_path: whatever ``mmcv.imread`` accepts as its image argument.

    Returns:
        ``(batch, (w_scale, h_scale))`` where the scales are the ratios
        ``256 / original_width`` and ``256 / original_height``.
    """
    target_w, target_h = 256, 256
    source = mmcv.imread(img_path)
    src_h, src_w = source.shape[:2]
    resized = cv2.resize(source, (target_w, target_h))
    scales = (target_w / src_w, target_h / src_h)
    normalized = mmcv.imnormalize(
        resized,
        mean=np.array([123.675, 116.28, 103.53]),
        std=np.array([58.395, 57.12, 57.375]),
        to_rgb=True,
    )
    batch = np.expand_dims(normalized.transpose(2, 0, 1), axis=0)
    return batch, scales
# @ RunTime
# 推理
def get_keypoint_result(image, site):
    """Run keypoint inference for one image on the Triton server.

    Args:
        image: image input forwarded to ``keypoint_preprocess``.
        site: string inserted into the model name
            ``keypoint_{site}_ocrnet_hr18``.

    Returns:
        The array produced by ``keypoint_postprocess``, or ``None`` when any
        step raises (the error is logged as a warning, not propagated).
    """
    try:
        batch, scale_factor = keypoint_preprocess(image)
        triton = httpclient.InferenceServerClient(url=DESIGN_MODEL_URL)
        payload = batch.astype(np.float32)

        infer_input = httpclient.InferInput("input", payload.shape, datatype="FP32")
        infer_input.set_data_from_numpy(payload, binary_data=True)
        requested = httpclient.InferRequestedOutput("output", binary_data=True)

        response = triton.infer(
            model_name=f"keypoint_{site}_ocrnet_hr18",
            inputs=[infer_input],
            outputs=[requested],
        )
        heatmaps = torch.from_numpy(response.as_numpy("output"))
        return keypoint_postprocess(heatmaps, scale_factor)
    except Exception as e:
        logging.warning(f"get_keypoint_result : {e}")
        return None
def keypoint_postprocess(output, scale_factor):
    """Turn keypoint heatmaps into (row, col) coordinates of the source image.

    Args:
        output: 4-D torch tensor indexed as (batch, keypoint, height, width);
            the arg-max of each (height, width) map is taken as the keypoint.
        scale_factor: ``(w_scale, h_scale)`` as returned by
            ``keypoint_preprocess``.

    Returns:
        numpy array of shape (batch, keypoint, 2) holding ``(row, col)``
        after undoing the resize scale, multiplying by 4 and rounding up.
    """
    heatmap_width = output.size(3)
    flat = output.view(output.size(0), output.size(1), -1)
    peak_index = torch.argmax(flat, dim=2).unsqueeze(dim=2)

    # The row of a flattened index is index // width. The previous `/` is a
    # true division on current torch and produced row + col / width, which
    # shifted the row coordinate before the final ceil.
    rows = torch.div(peak_index, heatmap_width, rounding_mode="floor")
    cols = peak_index % heatmap_width
    coords = torch.cat((rows, cols), dim=2).numpy()

    # scale_factor is (w, h); coordinates are (row, col), hence the reversal.
    inverse_scale = [1 / s for s in scale_factor[::-1]]
    scale_matrix = np.diag(inverse_scale)
    scale_matrix[np.isinf(scale_matrix)] = 0
    # NOTE(review): the factor 4 presumably maps heatmap cells back to the
    # 256x256 network input — confirm against the model's output stride.
    return np.ceil(np.dot(coords, scale_matrix) * 4)
"""
seg
预处理 推理 后处理
"""
# KNet
def seg_preprocess(img_path):
    """Read an image and build the segmentation-model input batch.

    Each side longer than 1024 pixels is clamped to 1024 independently (the
    aspect ratio is not preserved). The image is then normalized with fixed
    per-channel mean/std (converted to RGB), moved to channel-first order and
    given a leading batch axis.

    Args:
        img_path: whatever ``mmcv.imread`` accepts as its image argument.

    Returns:
        ``(batch, ori_shape)`` where ``ori_shape`` is the ``(height, width)``
        of the image before any resizing.
    """
    max_side = 1024
    picture = mmcv.imread(img_path)
    ori_shape = picture.shape[:2]
    target_h = min(ori_shape[0], max_side)
    target_w = min(ori_shape[1], max_side)
    if (target_h, target_w) != ori_shape:
        # cv2.resize takes its size as (width, height).
        picture = cv2.resize(picture, (target_w, target_h))
    normalized = mmcv.imnormalize(
        picture,
        mean=np.array([123.675, 116.28, 103.53]),
        std=np.array([58.395, 57.12, 57.375]),
        to_rgb=True,
    )
    batch = np.expand_dims(normalized.transpose(2, 0, 1), axis=0)
    return batch, ori_shape
# @ RunTime
def get_seg_result(image_id, image):
    """Run segmentation inference for one image on the Triton server.

    Args:
        image_id: identifier converted with ``int()`` and handed to
            ``seg_postprocess``.
        image: image input forwarded to ``seg_preprocess``.

    Returns:
        The array returned by ``seg_postprocess`` (model output resized to
        the original image shape). Errors are not caught here.
    """
    batch, ori_shape = seg_preprocess(image)
    triton = httpclient.InferenceServerClient(url=f"{DESIGN_MODEL_URL}")
    payload = batch.astype(np.float32)

    # Single FP32 input tensor, sent as binary data.
    infer_input = httpclient.InferInput(SEGMENTATION['input'], payload.shape, datatype="FP32")
    infer_input.set_data_from_numpy(payload, binary_data=True)

    # Single requested output tensor.
    requested = httpclient.InferRequestedOutput(SEGMENTATION['output'], binary_data=True)

    response = triton.infer(
        model_name=SEGMENTATION['new_model_name'],
        inputs=[infer_input],
        outputs=[requested],
    )
    raw_output = response.as_numpy(SEGMENTATION['output'])
    return seg_postprocess(int(image_id), raw_output, ori_shape)
# no cache
def seg_postprocess(image_id, output, ori_shape):
    """Resize the raw segmentation output back to the original image size.

    Args:
        image_id: accepted for interface compatibility; not used here.
        output: array-like model output; it is converted to a float tensor
            and bilinearly interpolated over its last two axes.
        ori_shape: target ``(height, width)``.

    Returns:
        numpy array: the first batch element of the interpolated output.
    """
    logits = torch.tensor(output).float()
    resized = F.interpolate(
        logits,
        size=ori_shape,
        scale_factor=None,
        mode='bilinear',
        align_corners=False,
    )
    return resized.cpu().numpy()[0]
def key_point_show(image_path, key_point_result=None):
    """Debug helper: draw keypoints on an image and show it in a window.

    Args:
        image_path: path handed to ``cv2.imread``.
        key_point_result: iterable of 2-element points; each point is
            reversed before being passed to ``cv2.circle``. ``None`` (the
            default) now draws nothing instead of raising ``TypeError``.

    Blocks until a key is pressed (``cv2.waitKey(0)``).
    """
    img = cv2.imread(image_path)
    # Explicit None check: truth-testing a numpy array would raise.
    points_list = [] if key_point_result is None else key_point_result
    point_size = 1
    point_color = (0, 0, 255)  # BGR
    thickness = 4  # outline thickness passed to cv2.circle
    for point in points_list:
        cv2.circle(img, point[::-1], point_size, point_color, thickness)
    cv2.imshow("0", img)
    cv2.waitKey(0)
if __name__ == '__main__':
    # Manual smoke test: run keypoint inference on a local sample image and
    # display the detected points.
    image = cv2.imread("9070101c-e5be-49b5-9602-4113a968969b.png")
    a = get_keypoint_result(image, "up")
    new_list = []
    # NOTE(review): this prints the builtin `list` type, not `new_list`.
    print(list)
    # NOTE(review): get_keypoint_result returns None on failure, in which
    # case a[0] raises TypeError.
    for i in a[0]:
        new_list.append((int(i[0]), int(i[1])))
    key_point_show("9070101c-e5be-49b5-9602-4113a968969b.png", new_list)
    # a = get_seg_result(1, image)
    print(a)

View File

@@ -0,0 +1,77 @@
import cv2
from app.core.config import PRIORITY_DICT
def organize_body(layer):
    """Build the layer record for the body image.

    Args:
        layer: mapping providing ``name``, ``body_image`` and ``body_path``.

    Returns:
        dict with priority 0, the lower-cased name, the body image and its
        URL, no mask, scale 1 and position ``(0, 0)``.
    """
    return {
        "priority": 0,
        "name": layer["name"].lower(),
        "image": layer['body_image'],
        "image_url": layer['body_path'],
        "mask_image": None,
        "mask_url": None,
        # The key spelling "sacle" is reproduced as-is: it is a runtime key.
        "sacle": 1,
        "position": (0, 0),
    }
def organize_clothing(layer):
    """Build the front and back layer records for one clothing item.

    The priority of each record comes from ``layer['priority']`` (negated
    for the back piece, defaulting to 0 there) when ``layer_order`` is
    truthy, otherwise from ``PRIORITY_DICT`` keyed by ``<name>_front`` /
    ``<name>_back``.

    Args:
        layer: mapping with the clothing item's images, URLs, mask,
            keypoints, scale, offset and resize scale.

    Returns:
        ``(front_layer, back_layer)`` dicts. Both share the same start
        position and mask URL; only the front record carries
        ``pattern_image``.
    """
    # Start coordinate shared by both pieces.
    start_point = calculate_start_point(
        layer['keypoint'],
        layer['scale'],
        layer['clothes_keypoint'],
        layer['body_point_test'],
        layer["offset"],
        layer["resize_scale"],
    )

    # Front piece.
    front_name = f'{layer["name"].lower()}_front'
    if layer.get("layer_order", False):
        front_priority = layer['priority']
    else:
        front_priority = PRIORITY_DICT.get(front_name, None)
    front_layer = {
        "priority": front_priority,
        "name": front_name,
        "image": layer["front_image"],
        "image_url": layer['front_image_url'],
        "mask_url": layer['mask_url'],
        "sacle": layer['scale'],
        "clothes_keypoint": layer['clothes_keypoint'],
        "position": start_point,
        "resize_scale": layer["resize_scale"],
        "mask": cv2.resize(layer['mask'], layer["front_image"].size),
        "gradient_string": layer.get('gradient_string', ""),
        "pattern_image_url": layer['pattern_image_url'],
        "pattern_image": layer['pattern_image'],
    }

    # Back piece.
    back_name = f'{layer["name"].lower()}_back'
    if layer.get("layer_order", False):
        back_priority = -layer.get("priority", 0)
    else:
        back_priority = PRIORITY_DICT.get(back_name, None)
    back_layer = {
        "priority": back_priority,
        "name": back_name,
        "image": layer["back_image"],
        "image_url": layer['back_image_url'],
        "mask_url": layer['mask_url'],
        "sacle": layer['scale'],
        "clothes_keypoint": layer['clothes_keypoint'],
        "position": start_point,
        "resize_scale": layer["resize_scale"],
        # The mask is sized from the front image for the back piece as well.
        "mask": cv2.resize(layer['mask'], layer["front_image"].size),
        "gradient_string": layer.get('gradient_string', ""),
        "pattern_image_url": layer['pattern_image_url'],
    }
    return front_layer, back_layer
def calculate_start_point(keypoint_type, scale, clothes_point, body_point, offset, resize_scale):
    """Compute where a clothing image starts so its left keypoint aligns.

    Only the ``"<keypoint_type>_left"`` entry of each point mapping is used.

    Args:
        keypoint_type: prefix of the keypoint key, e.g. "waistband",
            "shoulder" or "ear_point".
        scale: factor applied to the clothing keypoint.
        clothes_point: mapping from keypoint key to the clothing keypoint;
            elements 0 and 1 are read and truncated with ``int()``.
        body_point: mapping from keypoint key to the body keypoint;
            elements 0 and 1 are read.
        offset: 2-element sequence added to the body keypoint.
        resize_scale: accepted for interface compatibility; not used here.

    Returns:
        tuple ``(first, second)`` of ints where
        ``first  = body[1] + offset[1] - int(clothes[0]) * scale`` and
        ``second = body[0] + offset[0] - int(clothes[1]) * scale``.
    """
    key = f'{keypoint_type}_left'
    body = body_point[key]
    clothes = clothes_point[key]
    first = int(body[1] + offset[1] - int(clothes[0]) * scale)   # y
    second = int(body[0] + offset[0] - int(clothes[1]) * scale)  # x
    return (first, second)

View File

@@ -0,0 +1,30 @@
import logging
from app.service.design_fast.utils.redis_utils import Redis
logger = logging.getLogger(__name__)
def update_progress(process_id, total):
    """Advance the progress value stored in Redis under ``process_id``.

    * ``total == 1``: the value is set to 100.
    * a value already exists: it is increased by ``int(100 / total)`` while
      it is at most 100, otherwise it is reset to 99.
    * no value yet: it is initialised to ``int(100 / total)``.

    Args:
        process_id: Redis key holding the progress.
        total: number of steps the whole job is divided into.

    Returns:
        The progress value read before this update (``None`` if absent).
    """
    store = Redis()
    previous = store.read(key=process_id)
    if total == 1:
        store.write(key=process_id, value=100)
    elif previous:
        current = int(previous)
        if current <= 100:
            updated = current + int(100 / total)
        else:
            updated = 99
        store.write(key=process_id, value=updated)
    else:
        store.write(key=process_id, value=int(100 / total))
    return previous
def final_progress(process_id):
    """Mark the job stored under ``process_id`` as finished (value 100).

    Returns:
        The progress value read before it was overwritten.
    """
    store = Redis()
    previous = store.read(key=process_id)
    store.write(key=process_id, value=100)
    return previous

View File

@@ -0,0 +1,99 @@
import redis
from app.core.config import REDIS_HOST, REDIS_PORT
class Redis(object):
    """Convenience wrappers around common Redis operations.

    Every call builds a new ``redis.StrictRedis`` client for database 0 on
    ``REDIS_HOST``:``REDIS_PORT``.
    """

    @staticmethod
    def _get_r():
        """Create a client for database 0 on the configured host and port."""
        return redis.StrictRedis(REDIS_HOST, REDIS_PORT, 0)

    @classmethod
    def write(cls, key, value, expire=None):
        """Set ``key`` to ``value`` with a TTL in seconds.

        A falsy ``expire`` falls back to the default of 100 seconds.
        """
        ttl = expire or 100
        cls._get_r().set(key, value, ex=ttl)

    @classmethod
    def read(cls, key):
        """Return the value of ``key`` decoded as UTF-8.

        A falsy raw value (missing key, empty bytes) is returned unchanged.
        """
        raw = cls._get_r().get(key)
        if raw:
            return raw.decode('utf-8')
        return raw

    @classmethod
    def hset(cls, name, key, value):
        """Set field ``key`` of hash ``name`` to ``value``."""
        cls._get_r().hset(name, key, value)

    @classmethod
    def hget(cls, name, key):
        """Return field ``key`` of hash ``name`` decoded as UTF-8.

        A falsy raw value is returned unchanged.
        """
        raw = cls._get_r().hget(name, key)
        if raw:
            return raw.decode('utf-8')
        return raw

    @classmethod
    def hgetall(cls, name):
        """Return every field/value pair of hash ``name`` as the client gives it."""
        return cls._get_r().hgetall(name)

    @classmethod
    def delete(cls, *names):
        """Delete one or more keys."""
        cls._get_r().delete(*names)

    @classmethod
    def hdel(cls, name, key):
        """Delete field ``key`` from hash ``name``."""
        cls._get_r().hdel(name, key)

    @classmethod
    def expire(cls, name, expire=None):
        """Set the TTL of ``name`` in seconds.

        A falsy ``expire`` falls back to the default of 100 seconds.
        """
        ttl = expire or 100
        cls._get_r().expire(name, ttl)
if __name__ == '__main__':
    # Manual check: write a sample key; it expires after the default TTL.
    redis_client = Redis()
    # print(redis_client.write(key="1230", value=0))
    redis_client.write(key="1230", value=10)
    # print(redis_client.read(key="1230"))

View File

@@ -0,0 +1,199 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project trinity_client
@File synthesis_item.py
@Author :周成融
@Date 2023/8/26 14:13:04
@detail
"""
import io
import logging
import cv2
import numpy as np
from PIL import Image
from app.service.utils.generate_uuid import generate_uuid
from app.service.utils.oss_client import oss_upload_image
def positioning(all_mask_shape, mask_shape, offset):
    """Compute matching 1-D slice bounds for pasting a mask onto a canvas.

    Works along one axis: the canvas is ``all_mask_shape`` long, the mask is
    ``mask_shape`` long and is placed at ``offset`` on the canvas (negative
    offsets start before the canvas).

    Returns:
        ``(all_start, all_end, mask_start, mask_end)`` such that
        ``canvas[all_start:all_end]`` and ``mask[mask_start:mask_end]`` cover
        the overlapping region; empty slices are returned when the mask lies
        entirely outside the canvas.
    """
    if offset == 0:
        overlap = min(all_mask_shape, mask_shape)
        return 0, overlap, 0, overlap

    if offset > 0:
        canvas_start = min(offset, all_mask_shape)
        canvas_end = min(offset + mask_shape, all_mask_shape)
        if offset > all_mask_shape:
            piece_end = 0
        else:
            piece_end = min(all_mask_shape - offset, mask_shape)
        return canvas_start, canvas_end, 0, piece_end

    if offset < 0:
        shift = abs(offset)
        if shift > mask_shape:
            # The mask ends before the canvas begins.
            return 0, 0, mask_shape, mask_shape
        visible = mask_shape - shift
        canvas_end = min(visible, all_mask_shape) if visible > all_mask_shape else visible
        piece_end = all_mask_shape + shift if visible >= all_mask_shape else mask_shape
        return 0, canvas_end, shift, piece_end

    # Reached only when offset compares as neither ==, > nor < 0.
    return 0, 0, 0, 0
# @RunTime
def synthesis(data, size, basic_info):
    """Composite all layers onto one transparent canvas and upload the PNG.

    Args:
        data: list of layer dicts; the keys read here are ``name``,
            ``image``, ``mask`` and ``adaptive_position`` (stored as (y, x)).
        size: ``(width, height)`` of the output canvas.
        basic_info: mapping whose ``body_point_test`` entry provides
            ``shoulder_left`` / ``shoulder_right`` points.

    Returns:
        ``"aida-results/<object name>"`` of the uploaded image. If anything
        raises, the exception is logged as a warning and the function
        implicitly returns ``None``.
    """
    # Create the transparent base canvas.
    base_image = Image.new('RGBA', size, (0, 0, 0, 0))
    try:
        all_mask_shape = (size[1], size[0])
        body_mask = None
        for d in data:
            if d['name'] == 'body' or d['name'] == 'mannequin':
                # Paste the model onto a fresh transparent canvas of the full
                # size and take its alpha channel as the body mask.
                transparent_image = Image.new("RGBA", size, (0, 0, 0, 0))
                transparent_image.paste(d['image'], (d['adaptive_position'][1], d['adaptive_position'][0]), d['image'])  # position is read by index because paste() would otherwise modify a mutable list passed to it
                body_mask = np.array(transparent_image.split()[3])
                # Shift the shoulder points by the body's position on the canvas.
                left_shoulder = [x + y for x, y in zip(basic_info['body_point_test']['shoulder_left'], [d['adaptive_position'][1], d['adaptive_position'][0]])]
                right_shoulder = [x + y for x, y in zip(basic_info['body_point_test']['shoulder_right'], [d['adaptive_position'][1], d['adaptive_position'][0]])]
                # Fill everything above the shoulders, between them, as body.
                body_mask[:min(left_shoulder[1], right_shoulder[1]), left_shoulder[0]:right_shoulder[0]] = 255
        # NOTE(review): body_mask is still None when no body/mannequin layer
        # exists; cv2.threshold then raises and the except below swallows it.
        _, binary_body_mask = cv2.threshold(body_mask, 127, 255, cv2.THRESH_BINARY)
        top_outer_mask = np.array(binary_body_mask)
        bottom_outer_mask = np.array(binary_body_mask)

        # Walk the layers from last to first and pick the first matching top
        # garment and the first matching bottom garment.
        top = True
        bottom = True
        i = len(data)
        while i:
            i -= 1
            if top and data[i]['name'] in ["blouse_front", "outwear_front", "dress_front", "tops_front"]:
                top = False
                mask_shape = data[i]['mask'].shape
                y_offset, x_offset = data[i]['adaptive_position']
                # Start/end bounds of the overlapping region on each axis.
                all_y_start, all_y_end, mask_y_start, mask_y_end = positioning(all_mask_shape=all_mask_shape[0], mask_shape=mask_shape[0], offset=y_offset)
                all_x_start, all_x_end, mask_x_start, mask_x_end = positioning(all_mask_shape=all_mask_shape[1], mask_shape=mask_shape[1], offset=x_offset)
                # Copy the binarized garment mask into the overlapping region.
                _, sketch_mask = cv2.threshold(data[i]['mask'], 127, 255, cv2.THRESH_BINARY)
                background = np.zeros_like(top_outer_mask)
                background[all_y_start:all_y_end, all_x_start:all_x_end] = sketch_mask[mask_y_start:mask_y_end, mask_x_start:mask_x_end]
                top_outer_mask = background + top_outer_mask
            elif bottom and data[i]['name'] in ["trousers_front", "skirt_front", "bottoms_front", "dress_front"]:
                bottom = False
                mask_shape = data[i]['mask'].shape
                y_offset, x_offset = data[i]['adaptive_position']
                # Start/end bounds of the overlapping region on each axis.
                all_y_start, all_y_end, mask_y_start, mask_y_end = positioning(all_mask_shape=all_mask_shape[0], mask_shape=mask_shape[0], offset=y_offset)
                all_x_start, all_x_end, mask_x_start, mask_x_end = positioning(all_mask_shape=all_mask_shape[1], mask_shape=mask_shape[1], offset=x_offset)
                # Copy the binarized garment mask into the overlapping region.
                _, sketch_mask = cv2.threshold(data[i]['mask'], 127, 255, cv2.THRESH_BINARY)
                background = np.zeros_like(top_outer_mask)
                background[all_y_start:all_y_end, all_x_start:all_x_end] = sketch_mask[mask_y_start:mask_y_end, mask_x_start:mask_x_end]
                bottom_outer_mask = background + bottom_outer_mask
            elif bottom is False and top is False:
                break

        # Union of the body plus outermost top and bottom garment masks.
        all_mask = cv2.bitwise_or(top_outer_mask, bottom_outer_mask)

        for layer in data:
            if layer['image'] is not None:
                if layer['name'] != "body":
                    # Non-body layers are clipped to all_mask before pasting.
                    test_image = Image.new('RGBA', size, (0, 0, 0, 0))
                    test_image.paste(layer['image'], (layer['adaptive_position'][1], layer['adaptive_position'][0]), layer['image'])
                    mask_data = np.where(all_mask > 0, 255, 0).astype(np.uint8)
                    mask_alpha = Image.fromarray(mask_data)
                    cropped_image = Image.composite(test_image, Image.new("RGBA", test_image.size, (255, 255, 255, 0)), mask_alpha)
                    base_image.paste(test_image, (0, 0), cropped_image)  # test_image is already positioned on a full-size canvas, so it is pasted at (0, 0)
                else:
                    base_image.paste(layer['image'], (layer['adaptive_position'][1], layer['adaptive_position'][0]), layer['image'])

        result_image = base_image
        image_data = io.BytesIO()
        result_image.save(image_data, format='PNG')
        image_data.seek(0)

        # oss upload
        image_bytes = image_data.read()
        bucket_name = "aida-results"
        object_name = f'result_{generate_uuid()}.png'
        req = oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes)
        return f"{bucket_name}/{object_name}"
    except Exception as e:
        logging.warning(f"synthesis runtime exception : {e}")
def synthesis_single(front_image, back_image):
    """Merge a front and a back image into one PNG and upload it.

    When both images are given, ``back_image`` is pasted onto
    ``front_image`` at (0, 0) using itself as the mask; note that this
    modifies ``front_image`` in place. When only one image is given, that
    image is uploaded as-is. Previously a missing front image made the
    function call ``.paste`` / ``.save`` on ``None``.

    Args:
        front_image: PIL image or a falsy value when absent.
        back_image: PIL image or a falsy value when absent.

    Returns:
        ``"aida-results/<object name>"`` of the uploaded image.

    Raises:
        ValueError: if neither image is provided.
    """
    result_image = front_image if front_image else None
    if back_image:
        if result_image is None:
            result_image = back_image
        else:
            result_image.paste(back_image, (0, 0), back_image)
    if result_image is None:
        raise ValueError("synthesis_single requires front_image or back_image")

    buffer = io.BytesIO()
    result_image.save(buffer, format='PNG')
    image_bytes = buffer.getvalue()

    # oss upload
    bucket_name = 'aida-results'
    object_name = f'result_{generate_uuid()}.png'
    oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes)
    return f"{bucket_name}/{object_name}"
def update_base_size_priority(layers, size):
    """Compute the canvas size and per-layer positions relative to it.

    The canvas width spans from the smallest x position of any layer to the
    largest right edge (x position + image width) of any layer that has an
    image. The height is the mannequin image's height, or 700 when no
    mannequin layer is found.

    Args:
        layers: list of layer dicts with ``name``, ``image`` and ``position``
            (stored as (y, x)); each dict gains an ``adaptive_position`` key.
        size: accepted for interface compatibility; not used here.

    Returns:
        ``(layers, (new_width, new_height))``.
    """
    left_edge = min(entry['position'][1] for entry in layers)

    right_edges = []
    canvas_height = 700
    for entry in layers:
        picture = entry['image']
        if picture is None:
            continue
        right_edges.append(entry['position'][1] + picture.width)
        if entry['name'] == 'mannequin':
            canvas_height = picture.height
    canvas_width = max(right_edges) - left_edge

    # Shift every layer so the left-most one starts at x == 0.
    for entry in layers:
        row, col = entry['position'][0], entry['position'][1]
        entry['adaptive_position'] = (row, col - left_edge)
    return layers, (canvas_width, canvas_height)

View File

@@ -0,0 +1,39 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project trinity_client
@File upload_image.py
@Author :周成融
@Date 2023/8/28 13:49:20
@detail
"""
import io
import logging
import cv2
from app.core.config import *
from app.service.utils.new_oss_client import oss_upload_image
# @RunTime
def upload_png_mask(minio_client, front_image, object_name, mask=None):
    """Upload an image (and optionally its mask) as PNG objects.

    Args:
        minio_client: client object forwarded to ``oss_upload_image``.
        front_image: image object providing ``save(fp, format='PNG')``.
        object_name: name fragment used in both object keys.
        mask: optional mask array; when given it is inverted, converted to
            four channels, and every pixel whose first channel is 0 after
            inversion is made fully transparent before uploading.

    Returns:
        ``(front_image, image_url, mask_url)``; ``mask_url`` is ``None``
        when no mask was given. If anything raises, the exception is logged
        as a warning and ``None`` is returned instead of a tuple.
    """
    try:
        mask_url = None
        if mask is not None:
            inverted = cv2.bitwise_not(mask)
            # 3-channel mask -> 4 channels, then clear the zero pixels.
            rgba_mask = cv2.cvtColor(inverted, cv2.COLOR_BGR2BGRA)
            transparent = rgba_mask[:, :, 0] == 0
            rgba_mask[transparent] = [0, 0, 0, 0]
            encoded_mask = cv2.imencode('.png', rgba_mask)[1]
            oss_upload_image(oss_client=minio_client, bucket=AIDA_CLOTHING, object_name=f"mask/mask_{object_name}.png", image_bytes=encoded_mask)
            mask_url = f"{AIDA_CLOTHING}/mask/mask_{object_name}.png"

        buffer = io.BytesIO()
        front_image.save(buffer, format='PNG')
        png_bytes = buffer.getvalue()
        oss_upload_image(oss_client=minio_client, bucket=AIDA_CLOTHING, object_name=f"image/image_{object_name}.png", image_bytes=png_bytes)
        image_url = f"{AIDA_CLOTHING}/image/image_{object_name}.png"
        return front_image, image_url, mask_url
    except Exception as e:
        logging.warning(f"upload_png_mask runtime exception : {e}")
        return None

View File

@@ -5,13 +5,16 @@ import cv2
import numpy as np
import torch
import tritonclient.grpc as grpcclient
from pymilvus import MilvusClient
from urllib3.exceptions import ResponseError
from app.core.config import *
from app.schemas.pre_processing import DesignPreProcessingModel
from app.service.design.utils.design_ensemble import get_keypoint_result
from app.service.design_fast.utils.design_ensemble import get_seg_result, get_keypoint_result
from app.service.utils.oss_client import oss_get_image, oss_upload_image
logger = logging.getLogger()
class DesignPreprocessing:
# def __init__(self):
@@ -20,19 +23,19 @@ class DesignPreprocessing:
# @ RunTime
def pipeline(self, image_list):
sketches_list = self.read_image(image_list)
logging.info("read image success")
# logging.info("read image success")
bounding_box_sketches_list = self.bounding_box(sketches_list)
logging.info("bounding box image success")
# logging.info("bounding box image success")
# super_resolution_list = self.super_resolution(bounding_box_sketches_list)
# logging.info("super_resolution_list image success")
infer_sketches_list = self.infer_image(bounding_box_sketches_list)
logging.info("infer image success")
# logging.info("infer image success")
result = self.composing_image(infer_sketches_list)
logging.info("Replenish white edge image success")
# logging.info("Replenish white edge image success")
for d in result:
if 'image_obj' in d:
@@ -59,6 +62,7 @@ class DesignPreprocessing:
def bounding_box(self, image_list):
for item in image_list:
image = item['image_obj']
height, width = image.shape[:2]
# 使用Canny边缘检测来检测物体的轮廓
edges = cv2.Canny(image, 50, 150)
# 查找轮廓
@@ -82,16 +86,25 @@ class DesignPreprocessing:
if len(contours) > 0:
cropped_image = image[y_min:y_max, x_min:x_max]
item['obj'] = cropped_image # 新shape图像
# 取消直接覆盖新增size判断
# try:
# # 覆盖到minio
# image_bytes = cv2.imencode(".jpg", cropped_image)[1].tobytes()
# self.minio_client.put_object(item['image_url'].split("/", 1)[0], item['image_url'].split("/", 1)[1], io.BytesIO(image_bytes), len(image_bytes), content_type="image/jpeg", )
# print(f"Object '{item['image_url'].split('/', 1)[1]}' overwritten successfully.")
# except ResponseError as err:
# print(f"Error: {err}")
else:
item['obj'] = image
padding_top = max(20 - y_min, 0)
padding_bottom = max(20 - (height - y_max), 0)
padding_left = max(20 - x_min, 0)
padding_right = max(20 - (width - x_max), 0)
# 添加padding
padded_image = cv2.copyMakeBorder(
image,
padding_top,
padding_bottom,
padding_left,
padding_right,
cv2.BORDER_CONSTANT,
value=(255, 255, 255)
)
item['obj'] = padded_image
return image_list
def super_resolution(self, image_list):
@@ -99,7 +112,7 @@ class DesignPreprocessing:
# 判断 两边是否同时都小于512 因为此处做四倍超分
if item['obj'].shape[0] <= 512 and item['obj'].shape[1] <= 512:
# 如果任意一边小于256则超分
if item['obj'].shape[0] <= 256 or item['obj'].shape[1] <= 256:
if item['obj'].shape[0] <= 200 or item['obj'].shape[1] <= 200:
# 超分
img = item['obj'].astype(np.float32) / 255.
sample = np.transpose(img if img.shape[2] == 1 else img[:, :, [2, 1, 0]], (2, 0, 1))
@@ -124,13 +137,14 @@ class DesignPreprocessing:
bucket_name = item['image_url'].split("/", 1)[0]
object_name = item['image_url'].split("/", 1)[1]
oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes)
print(f"Object '{item['image_url'].split('/', 1)[1]}' overwritten successfully.")
logging.info(f"Object '{item['image_url'].split('/', 1)[1]}' overwritten successfully.")
except ResponseError as err:
print(f"Error: {err}")
logging.warning(f"Error: {err}")
return image_list
# @ RunTime
def infer_image(self, image_list):
seg_result = None
for sketch in image_list:
# 小写
image_category = sketch['image_category'].lower()
@@ -138,6 +152,15 @@ class DesignPreprocessing:
sketch['site'] = 'up' if image_category in ['blouse', 'outwear', 'dress', 'tops'] else 'down'
# 推理得到keypoint
sketch['keypoint_result'] = self.keypoint_cache(sketch)
if sketch['site'] == 'up':
_, seg_cache = self.load_seg_result(sketch['image_id'])
if not _:
# 推理获得seg 结果
seg_result = get_seg_result(sketch["image_id"], sketch['obj'])[0]
self.save_seg_result(seg_result, sketch['image_id'])
logger.info(f"{sketch['image_id']} image size is :{sketch['obj'].shape} , seg cache size is :{seg_result.shape}")
else:
logger.info(f"{sketch['image_id']} image size is :{sketch['obj'].shape} , seg cache size is :{seg_cache.shape}")
if IF_DEBUG_SHOW:
debug_show_image = sketch['obj'].copy()
@@ -149,6 +172,7 @@ class DesignPreprocessing:
points_list.append((int(i[1]), int(i[0])))
for point in points_list:
cv2.circle(debug_show_image, point, point_size, point_color, thickness)
cv2.imshow("seg_result", seg_result)
cv2.imshow("", debug_show_image)
cv2.waitKey(0)
# # 关键点在上部则推理seg
@@ -236,58 +260,37 @@ class DesignPreprocessing:
return image_list
@staticmethod
def select_seg_result(image_id, image_obj):
def load_seg_result(image_id):
file_path = f"{SEG_CACHE_PATH}{image_id}.npy"
try:
# 如果shape不匹配 返回false
result = np.load(f"seg_result/{image_id}.npy").astype(np.int64)
if result.shape[1] == image_obj.shape[0] and result.shape[2] == image_obj.shape[1]:
return result
else:
return False
except FileNotFoundError as e:
logging.warning(f"{image_id} Image segmentation results cache file does not exist : {e}")
return False
seg_result = np.load(file_path)
return True, seg_result
except FileNotFoundError:
logging.info("文件不存在")
return False, None
except Exception as e:
logging.warning(f"加载失败: {e}")
return False, None
@staticmethod
def search_seg_result(image_id, ori_shape):
def save_seg_result(seg_result, image_id):
file_path = f"{SEG_CACHE_PATH}{image_id}.npy"
try:
# connections.connect(alias=MILVUS_ALIAS, host=MILVUS_DB_HOST, port=MILVUS_PORT)
# collection = Collection(MILVUS_TABLE_SEG) # Get an existing collection.
# collection.load()
# start_time = time.time()
# res = collection.query(
# expr=f"seg_id == {image_id}",
# offset=0,
# limit=10,
# output_fields=["seg_cache"],
# )
# logging.info(f"search seg cache time {time.time() - start_time}")
# if len(res):
# vector = np.reshape(res[0]['seg_cache'] + res[1]['seg_cache'], (224, 224))
# array_2d_exact = F.interpolate(torch.tensor(vector).unsqueeze(0).unsqueeze(0), size=ori_shape, mode='bilinear', align_corners=False)
# array_2d_exact = array_2d_exact.squeeze().numpy()
# return array_2d_exact
# else:
return False
np.save(file_path, seg_result)
logging.info(f"保存成功,{os.path.abspath(file_path)}")
except Exception as e:
logging.warning(f"{image_id} Image segmentation results cache file does not exist : {e}")
return False
logging.warning(f"保存失败: {e}")
def keypoint_cache(self, sketch):
try:
# connections.connect(alias=MILVUS_ALIAS, host=MILVUS_DB_HOST, port=MILVUS_PORT)
# collection = Collection(MILVUS_TABLE_KEYPOINT) # Get an existing collection.
# collection.load()
start_time = time.time()
# res = collection.query(
# expr=f"keypoint_id == {sketch['image_id']}",
# offset=0,
# limit=1,
# output_fields=["keypoint_cache", "keypoint_site"],
# )
res = []
logging.info(f"search keypoint time : {time.time() - start_time}")
client = MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS)
keypoint_id = sketch['image_id']
res = client.query(
collection_name=MILVUS_TABLE_KEYPOINT,
# ids=[keypoint_id],
filter=f"keypoint_id == {keypoint_id}",
output_fields=['keypoint_vector', 'keypoint_site']
)
if len(res) == 0:
# 没有结果 直接推理拿结果 并保存
keypoint_infer_result = self.infer_keypoint_result(sketch)
@@ -348,7 +351,7 @@ class DesignPreprocessing:
]
try:
# connections.connect(alias=MILVUS_ALIAS, host=MILVUS_DB_HOST, port=MILVUS_PORT)
start_time = time.time()
# start_time = time.time()
# collection = Collection(MILVUS_TABLE_KEYPOINT) # Get an existing collection.
# mr = collection.upsert(data)
# logging.info(f"save keypoint time : {time.time() - start_time}")
@@ -362,9 +365,9 @@ if __name__ == '__main__':
data = {
"sketches": [
{
"image_category": "dress",
"image_id": "107903",
"image_url": "aida-sys-image/images/female/dress/0628000000.jpg"
"image_category": "blouse",
"image_id": "123123123",
"image_url": "test/0628000198.jpg"
}
]
}

View File

@@ -0,0 +1,45 @@
import os
from minio import Minio
from minio.error import S3Error
# NOTE(review): the endpoint and access/secret keys are hard-coded and
# committed to the repository; they should be rotated and loaded from the
# environment or a secrets store instead.
MINIO_URL = "www.minio.aida.com.hk:12024"
MINIO_ACCESS = 'vXKFLSJkYeEq2DrSZvkB'
MINIO_SECRET = 'uKTZT3x7C43WvPN9QTc99DiRkwddWZrG9Uh3JVlR'
MINIO_SECURE = True
# Configure the MinIO client (created at import time).
minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
# Download helper
def download_folder(bucket_name, folder_name, local_dir):
    """Download every object under a prefix into a local directory.

    The part of each object name after ``folder_name`` becomes the path
    below ``local_dir``. Leading slashes are stripped from that part so a
    prefix given without a trailing "/" can no longer make ``os.path.join``
    discard ``local_dir``. Objects whose name equals the prefix (directory
    markers) are skipped.

    Args:
        bucket_name: bucket to read from.
        folder_name: object-name prefix to download.
        local_dir: destination directory; created if missing.

    ``S3Error`` is caught and printed; other exceptions propagate.
    """
    try:
        # Make sure the destination directory exists.
        os.makedirs(local_dir, exist_ok=True)

        # Walk every object below the prefix.
        objects = minio_client.list_objects(bucket_name, prefix=folder_name, recursive=True)
        for obj in objects:
            relative_name = obj.object_name[len(folder_name):].lstrip("/")
            if not relative_name:
                continue

            local_file_path = os.path.join(local_dir, relative_name)
            local_file_dir = os.path.dirname(local_file_path)
            if local_file_dir:
                os.makedirs(local_file_dir, exist_ok=True)

            # Download the object.
            minio_client.fget_object(bucket_name, obj.object_name, local_file_path)
            print(f"Downloaded {obj.object_name} to {local_file_path}")
    except S3Error as e:
        print(f"Error occurred: {e}")
# Usage example.
# NOTE(review): these statements run at import time (there is no
# `if __name__ == "__main__"` guard), so importing this module starts the
# download.
bucket_name = "test"  # replace with your bucket name
folder_name = "checkpoints/"  # prefix of the weights folder
local_dir = "app/service/image2sketch/checkpoints"  # replace with the local directory to save into
download_folder(bucket_name, folder_name, local_dir)

Binary file not shown.

After

Width:  |  Height:  |  Size: 101 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 376 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 57 KiB

View File

@@ -0,0 +1,89 @@
import os
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
from .models import create_model
def tensor2im(input_image, imtype=np.uint8):
    """Convert a batched image tensor into a numpy image array.

    * numpy arrays are only cast to ``imtype``;
    * torch tensors: the first batch element is taken, a single channel is
      tiled to three, the axes are reordered to (H, W, C) and values are
      mapped with ``(x + 1) / 2 * 255`` before casting to ``imtype``;
    * any other object is returned unchanged.
    """
    if isinstance(input_image, np.ndarray):
        return input_image.astype(imtype)
    if not isinstance(input_image, torch.Tensor):
        return input_image

    array = input_image.data[0].cpu().float().numpy()
    if array.shape[0] == 1:  # grayscale to RGB
        array = np.tile(array, (3, 1, 1))
    # Channel-first -> channel-last, then rescale.
    array = (np.transpose(array, (1, 2, 0)) + 1) / 2.0 * 255.0
    return array.astype(imtype)
def save_image(image_numpy, image_path, w, h, aspect_ratio=1.0):
    """Resize a numpy image to ``(w, h)`` and write it to disk.

    Parameters:
        image_numpy (numpy array) -- input numpy array
        image_path (str)          -- the path of the image
        w, h (int)                -- output width and height
        aspect_ratio              -- accepted but not used
    """
    Image.fromarray(image_numpy).resize((w, h)).save(image_path)
def save_img(image_tensor, w, h, filename):
    """Convert ``image_tensor`` with ``tensor2im``, save it resized to
    ``(w, h)`` under ``filename`` and print a confirmation."""
    array = tensor2im(image_tensor)
    save_image(array, filename, w, h, aspect_ratio=1.0)
    print("Image saved as {}".format(filename))
def load_img(filepath):
    """Open an image as 8-bit grayscale (mode 'L').

    Returns:
        ``(image, width, height)``.
    """
    gray = Image.open(filepath).convert('L')
    width, height = gray.size[0], gray.size[1]
    return gray, width, height
if __name__ == '__main__':
    # Manual inference script: load a style reference and a content image,
    # run the model once and save the generated sketch.
    # NOTE(review): img_A / img_B are assigned but not used below; the paths
    # further down are machine-specific absolute paths.
    img_A = "/workspace/Semi_ref2sketch_code/datasets/ref_unpair/testA/real_Dress_732caedc416a0cbfedd0e6528040eac7.jpg_Img.jpg"
    img_B = "/workspace/Semi_ref2sketch_code/datasets/ref_unpair/testC/style_3.png"
    from opt import Config

    opt = Config()  # get test options
    # hard-code some parameters for test
    opt.num_threads = 0  # test code only supports num_threads = 0
    opt.batch_size = 1  # test code only supports batch_size = 1
    opt.serial_batches = True  # disable data shuffling; comment this line if results on randomly chosen images are needed.
    opt.no_flip = True  # no flip; comment this line if results on flipped images are needed.
    opt.display_id = -1  # no visdom display; the test code saves the results to a HTML file.
    device = torch.device("cuda:0")
    model = create_model(opt)  # create a model given opt.model and other options
    model.setup(opt)
    # ToTensor followed by single-channel normalization with mean 0.5 / std 0.5.
    transform_list = [transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
    transform = transforms.Compose(transform_list)
    if opt.eval:
        model.eval()
    data = {}
    print(os.getcwd())
    # Chained assignment: B is bound to the whole (image, width, height)
    # tuple while `reference` receives the image.
    B = reference, _, _ = load_img(r"/app/service/image2sketch/datasets/ref_unpair/testC/style_3.png")
    style_img = transform(reference)
    data['B'] = style_img
    data['B'] = data['B'].unsqueeze(0).to(device)
    A = Image.open(r"E:\workspace\trinity_client_aida\app\service\image2sketch\datasets\ref_unpair\testA\real_Dress_3200fecdc83d0c556c2bd96aedbd7fbf.jpg_Img.jpg")
    # Remember the original size so the result can be saved at that size.
    width = A.size[0]
    height = A.size[1]
    # data['A'] = A.resize((512, 512))
    data['A'] = transform(A)
    data['A'] = data['A'].unsqueeze(0).to(device)
    model.set_input(data)
    model.test()  # run inference
    visuals = model.get_current_visuals()  # get image results
    save_img(visuals['content_output'].cpu(), width, height, "result/result.jpg")

View File

@@ -0,0 +1,49 @@
import importlib
from app.service.image2sketch.models import unpaired_model as modellib
from .base_model import BaseModel
def find_model_using_name(model_name):
    """Look up the model class matching ``model_name``.

    The search is limited to the statically imported ``unpaired_model``
    module (``modellib``). A class matches when its name, compared
    case-insensitively, equals ``model_name`` without underscores followed
    by ``"model"``, and it is a subclass of ``BaseModel``.

    Args:
        model_name: model identifier, e.g. ``"unpaired"``.

    Returns:
        The matching class. If none is found, a message is printed and the
        process exits via ``exit(0)``.
    """
    # model_filename = "." + model_name + "_model"
    # modellib = importlib.import_module(model_filename)
    target_model_name = model_name.replace('_', '') + 'model'
    wanted = target_model_name.lower()

    model = None
    for name, cls in modellib.__dict__.items():
        # isinstance guard: issubclass raises TypeError for non-class objects.
        if name.lower() == wanted and isinstance(cls, type) and issubclass(cls, BaseModel):
            model = cls

    if model is None:
        # The module name comes from modellib itself; the former
        # `model_filename` variable is no longer defined and raised NameError.
        print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (modellib.__name__, target_model_name))
        exit(0)
    return model
def get_option_setter(model_name):
    """Return the ``modify_commandline_options`` attribute of the model
    class resolved from ``model_name``."""
    return find_model_using_name(model_name).modify_commandline_options
def create_model(opt):
    """Instantiate the model selected by ``opt.model``.

    The class is resolved with ``find_model_using_name`` and constructed
    with ``opt`` as its only argument; the created class name is printed.

    Args:
        opt: options object; ``opt.model`` names the model to build.

    Returns:
        The new model instance.
    """
    model_class = find_model_using_name(opt.model)
    created = model_class(opt)
    print("model [%s] was created" % type(created).__name__)
    return created

View File

@@ -0,0 +1,230 @@
import os
import torch
from collections import OrderedDict
from abc import ABC, abstractmethod
from . import networks
class BaseModel(ABC):
    """Abstract base class shared by the models of this package.

    Subclasses must implement ``set_input``, ``forward`` and
    ``optimize_parameters``; ``modify_commandline_options`` may be overridden.
    A subclass ``__init__`` should first call ``BaseModel.__init__(self, opt)``
    and then fill ``loss_names``, ``model_names``, ``visual_names`` and
    ``optimizers``. Networks are read from attributes named ``'net' + name``
    and losses from attributes named ``'loss_' + name``.
    """

    def __init__(self, opt):
        """Store the options and set up the device, checkpoint dir and name lists.

        Parameters:
            opt -- option object; this method reads ``gpu_ids``, ``isTrain``,
                   ``checkpoints_dir``, ``name`` and ``preprocess``
        """
        self.opt = opt
        self.gpu_ids = opt.gpu_ids
        self.isTrain = opt.isTrain
        # Use the first listed GPU when there is one, otherwise the CPU.
        if self.gpu_ids:
            self.device = torch.device('cuda:{}'.format(self.gpu_ids[0]))
        else:
            self.device = torch.device('cpu')
        # All checkpoints of this run live under <checkpoints_dir>/<name>.
        self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)
        # cudnn.benchmark is switched on for every mode except 'scale_width'.
        if opt.preprocess != 'scale_width':
            torch.backends.cudnn.benchmark = True
        self.loss_names = []
        self.model_names = []
        self.visual_names = []
        self.optimizers = []
        self.image_paths = []
        self.metric = 0  # passed to the scheduler for the 'plateau' policy

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Hook for model-specific options; the base version returns ``parser`` unchanged.

        Parameters:
            parser   -- the option parser
            is_train -- training/test flag (unused here)
        """
        return parser

    @abstractmethod
    def set_input(self, input):
        """Receive one batch of input data (to be implemented by subclasses).

        Parameters:
            input -- the data handed over by the caller
        """
        pass

    @abstractmethod
    def forward(self):
        """Run the forward pass; ``test`` calls this under ``torch.no_grad()``."""
        pass

    @abstractmethod
    def optimize_parameters(self):
        """Perform one optimisation step (to be implemented by subclasses)."""
        pass

    def setup(self, opt):
        """Create schedulers (training), load weights (test / resume) and print networks.

        Parameters:
            opt -- option object; reads ``continue_train``, ``load_iter``,
                   ``epoch`` and ``verbose``
        """
        if self.isTrain:
            self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
        if not self.isTrain or opt.continue_train:
            # A positive load_iter selects an iteration checkpoint, else the epoch label.
            if opt.load_iter > 0:
                load_suffix = 'iter_%d' % opt.load_iter
            else:
                load_suffix = opt.epoch
            self.load_networks(load_suffix)
        self.print_networks(opt.verbose)

    def eval(self):
        """Put every network listed in ``model_names`` into eval mode."""
        for name in self.model_names:
            if not isinstance(name, str):
                continue
            getattr(self, 'net' + name).eval()

    def test(self):
        """Run ``forward`` without gradient tracking, then ``compute_visuals``."""
        with torch.no_grad():
            self.forward()
            self.compute_visuals()

    def compute_visuals(self):
        """Hook for extra visualisation outputs; does nothing by default."""
        pass

    def get_image_paths(self):
        """Return ``self.image_paths``."""
        return self.image_paths

    def update_learning_rate(self):
        """Step every scheduler and print the learning rate before and after.

        With ``opt.lr_policy == 'plateau'`` the scheduler is stepped with
        ``self.metric``. The printed rate is that of the first optimizer's
        first parameter group.
        """
        lr_before = self.optimizers[0].param_groups[0]['lr']
        plateau = None
        for scheduler in self.schedulers:
            plateau = self.opt.lr_policy == 'plateau'
            if plateau:
                scheduler.step(self.metric)
            else:
                scheduler.step()
        lr_after = self.optimizers[0].param_groups[0]['lr']
        print('learning rate %.7f -> %.7f' % (lr_before, lr_after))

    def get_current_visuals(self):
        """Return an OrderedDict mapping each name in ``visual_names`` to ``getattr(self, name)``."""
        return OrderedDict(
            (name, getattr(self, name))
            for name in self.visual_names
            if isinstance(name, str)
        )

    def get_current_losses(self):
        """Return an OrderedDict mapping each loss name to ``float(self.loss_<name>)``."""
        return OrderedDict(
            (name, float(getattr(self, 'loss_' + name)))
            for name in self.loss_names
            if isinstance(name, str)
        )

    def save_networks(self, epoch):
        """Write the state dict of every network to ``save_dir``.

        Parameters:
            epoch -- label used in the file name '%s_net_%s.pth' % (epoch, name)
        """
        for name in self.model_names:
            if not isinstance(name, str):
                continue
            target = os.path.join(self.save_dir, '%s_net_%s.pth' % (epoch, name))
            net = getattr(self, 'net' + name)
            on_gpu = len(self.gpu_ids) > 0 and torch.cuda.is_available()
            if on_gpu:
                # GPU path saves net.module on the CPU, then moves the net back.
                torch.save(net.module.cpu().state_dict(), target)
                net.cuda(self.gpu_ids[0])
            else:
                torch.save(net.cpu().state_dict(), target)

    def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
        """Drop InstanceNorm entries from ``state_dict`` that the module cannot load.

        Walks ``keys`` (a dotted state-dict key split on '.') down the module
        tree. At the final key, for modules whose class name starts with
        'InstanceNorm', it removes 'running_mean'/'running_var' when the module
        attribute is None, and always removes 'num_batches_tracked'.
        """
        key = keys[i]
        if i + 1 != len(keys):
            # Not at the leaf yet: descend into the named child.
            self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)
            return
        is_instance_norm = module.__class__.__name__.startswith('InstanceNorm')
        if is_instance_norm and (key == 'running_mean' or key == 'running_var'):
            if getattr(module, key) is None:
                state_dict.pop('.'.join(keys))
        if is_instance_norm and (key == 'num_batches_tracked'):
            state_dict.pop('.'.join(keys))

    def load_networks(self, epoch):
        """Load the state dict of every network from ``save_dir``.

        Parameters:
            epoch -- label used in the file name '%s_net_%s.pth' % (epoch, name)
        """
        for name in self.model_names:
            if not isinstance(name, str):
                continue
            load_path = os.path.join(self.save_dir, '%s_net_%s.pth' % (epoch, name))
            net = getattr(self, 'net' + name)
            if isinstance(net, torch.nn.DataParallel):
                net = net.module  # load into the wrapped network
            print('loading the model from %s' % load_path)
            state_dict = torch.load(load_path, map_location=str(self.device))
            if hasattr(state_dict, '_metadata'):
                del state_dict._metadata
            # Iterate over a copy of the keys because entries may be popped.
            for dotted_key in list(state_dict.keys()):
                self.__patch_instance_norm_state_dict(state_dict, net, dotted_key.split('.'))
            net.load_state_dict(state_dict)

    def print_networks(self, verbose):
        """Print the parameter count of every network, and the network itself if ``verbose``.

        Parameters:
            verbose (bool) -- also print the network architecture
        """
        print('---------- Networks initialized -------------')
        for name in self.model_names:
            if not isinstance(name, str):
                continue
            net = getattr(self, 'net' + name)
            num_params = sum(param.numel() for param in net.parameters())
            if verbose:
                print(net)
            print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
        print('-----------------------------------------------')

    def set_requires_grad(self, nets, requires_grad=False):
        """Set ``requires_grad`` on every parameter of the given networks.

        Parameters:
            nets          -- one network or a list of networks; None entries are skipped
            requires_grad -- value assigned to each parameter's ``requires_grad``
        """
        net_list = nets if isinstance(nets, list) else [nets]
        for net in net_list:
            if net is None:
                continue
            for param in net.parameters():
                param.requires_grad = requires_grad

View File

@@ -0,0 +1,354 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
class CNR2d(nn.Module):
    """Convolution followed by optional normalisation, activation and dropout.

    An empty list ``[]`` is the "not set" marker for ``norm``, ``relu``,
    ``drop`` and ``bias``: the corresponding layer is skipped, and an unset
    ``bias`` defaults to False for ``norm == 'bnorm'`` and True otherwise.
    """

    def __init__(self, nch_in, nch_out, kernel_size=4, stride=1, padding=1, norm='bnorm', relu=0.0, drop=[], bias=[]):
        super().__init__()

        if bias == []:
            bias = norm != 'bnorm'

        stack = [Conv2d(nch_in, nch_out, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)]
        if norm != []:
            stack.append(Norm2d(nch_out, norm))
        if relu != []:
            stack.append(ReLU(relu))
        if drop != []:
            stack.append(nn.Dropout2d(drop))

        self.cbr = nn.Sequential(*stack)

    def forward(self, x):
        """Apply the layer stack to ``x``."""
        return self.cbr(x)
class DECNR2d(nn.Module):
    """Transposed convolution followed by optional normalisation, activation and dropout.

    An empty list ``[]`` is the "not set" marker for ``norm``, ``relu``,
    ``drop`` and ``bias``: the corresponding layer is skipped, and an unset
    ``bias`` defaults to False for ``norm == 'bnorm'`` and True otherwise.
    """

    def __init__(self, nch_in, nch_out, kernel_size=4, stride=1, padding=1, output_padding=0, norm='bnorm', relu=0.0, drop=[], bias=[]):
        super().__init__()

        if bias == []:
            bias = norm != 'bnorm'

        stack = [Deconv2d(nch_in, nch_out, kernel_size=kernel_size, stride=stride, padding=padding, output_padding=output_padding, bias=bias)]
        if norm != []:
            stack.append(Norm2d(nch_out, norm))
        if relu != []:
            stack.append(ReLU(relu))
        if drop != []:
            stack.append(nn.Dropout2d(drop))

        self.decbr = nn.Sequential(*stack)

    def forward(self, x):
        """Apply the layer stack to ``x``."""
        return self.decbr(x)
class ResBlock(nn.Module):
    """Residual block: ``x + f(x)`` where ``f`` is pad-conv-norm-act, optional
    dropout, then pad-conv-norm without activation.

    Both convolutions are built as ``CNR2d(nch_in, nch_out, ...)`` with
    ``padding=0``; padding is applied by a separate ``Padding`` layer.
    """

    def __init__(self, nch_in, nch_out, kernel_size=3, stride=1, padding=1, padding_mode='reflection', norm='inorm', relu=0.0, drop=[], bias=[]):
        super().__init__()

        # Resolved like in CNR2d, but the value is not forwarded to the convs.
        if bias == []:
            bias = norm != 'bnorm'

        # First conv (with activation).
        blocks = [
            Padding(padding, padding_mode=padding_mode),
            CNR2d(nch_in, nch_out, kernel_size=kernel_size, stride=stride, padding=0, norm=norm, relu=relu),
        ]
        if drop != []:
            blocks.append(nn.Dropout2d(drop))
        # Second conv (no activation).
        blocks.append(Padding(padding, padding_mode=padding_mode))
        blocks.append(CNR2d(nch_in, nch_out, kernel_size=kernel_size, stride=stride, padding=0, norm=norm, relu=[]))

        self.resblk = nn.Sequential(*blocks)

    def forward(self, x):
        """Return ``x`` plus the output of the residual branch."""
        residual = self.resblk(x)
        return x + residual
class ResBlock_cat(nn.Module):
    """Residual block whose branch consumes ``x`` and ``y`` concatenated on dim 1.

    The first convolution therefore takes ``nch_in * 2`` input channels; the
    second one takes ``nch_in``. The branch output is added to ``x``.
    """

    def __init__(self, nch_in, nch_out, kernel_size=3, stride=1, padding=1, padding_mode='reflection', norm='inorm', relu=0.0, drop=[], bias=[]):
        super().__init__()

        # Resolved like in CNR2d, but the value is not forwarded to the convs.
        if bias == []:
            bias = norm != 'bnorm'

        # First conv on the concatenated input (with activation).
        blocks = [
            Padding(padding, padding_mode=padding_mode),
            CNR2d(nch_in*2, nch_out, kernel_size=kernel_size, stride=stride, padding=0, norm=norm, relu=relu),
        ]
        if drop != []:
            blocks.append(nn.Dropout2d(drop))
        # Second conv (no activation).
        blocks.append(Padding(padding, padding_mode=padding_mode))
        blocks.append(CNR2d(nch_in, nch_out, kernel_size=kernel_size, stride=stride, padding=0, norm=norm, relu=[]))

        self.resblk = nn.Sequential(*blocks)

    def forward(self,x,y):
        """Return ``x`` plus the branch applied to ``cat([x, y], dim=1)``."""
        joined = torch.cat([x,y],dim=1)
        return x + self.resblk(joined)
class LinearBlock(nn.Module):
    """Fully connected layer with optional normalisation and activation.

    ``norm`` is one of 'bn', 'in', 'ln', 'sn' or 'none'; ``activation`` is one
    of 'relu', 'lrelu', 'prelu', 'selu', 'tanh' or 'none'. Any other value
    trips an ``assert``.
    NOTE(review): the 'sn' and 'ln' branches use ``SpectralNorm`` and
    ``LayerNorm``, which are not defined in the visible part of this file.
    """

    def __init__(self, input_dim, output_dim, norm='none', activation='relu'):
        super().__init__()

        # Fully connected layer (always with bias), optionally spectral-normalised.
        if norm == 'sn':
            self.fc = SpectralNorm(nn.Linear(input_dim, output_dim, bias=True))
        else:
            self.fc = nn.Linear(input_dim, output_dim, bias=True)

        # Normalisation over the output features.
        if norm == 'bn':
            self.norm = nn.BatchNorm1d(output_dim)
        elif norm == 'in':
            self.norm = nn.InstanceNorm1d(output_dim)
        elif norm == 'ln':
            self.norm = LayerNorm(output_dim)
        elif norm == 'none' or norm == 'sn':
            self.norm = None
        else:
            assert 0, "Unsupported normalization: {}".format(norm)

        # Activation.
        if activation == 'relu':
            self.activation = nn.ReLU(inplace=True)
        elif activation == 'lrelu':
            self.activation = nn.LeakyReLU(0.2, inplace=True)
        elif activation == 'prelu':
            self.activation = nn.PReLU()
        elif activation == 'selu':
            self.activation = nn.SELU(inplace=True)
        elif activation == 'tanh':
            self.activation = nn.Tanh()
        elif activation == 'none':
            self.activation = None
        else:
            assert 0, "Unsupported activation: {}".format(activation)

    def forward(self, x):
        """Apply fc, then the normalisation and activation that are configured."""
        y = self.fc(x)
        for stage in (self.norm, self.activation):
            if stage:
                y = stage(y)
        return y
class MLP(nn.Module):
    """Stack of ``LinearBlock`` layers.

    Layout: one input block (``input_dim -> dim``), ``n_blk - 2`` hidden
    blocks (``dim -> dim``) and one output block (``dim -> output_dim``)
    without normalisation or activation.
    """

    def __init__(self, input_dim, output_dim, dim, n_blk, norm='none', activ='relu'):
        super().__init__()
        blocks = [LinearBlock(input_dim, dim, norm=norm, activation=activ)]
        blocks.extend(
            LinearBlock(dim, dim, norm=norm, activation=activ)
            for _ in range(n_blk - 2)
        )
        # Output block: no normalisation, no activation.
        blocks.append(LinearBlock(dim, output_dim, norm='none', activation='none'))
        self.model = nn.Sequential(*blocks)

    def forward(self, x):
        """Flatten ``x`` to ``(x.size(0), -1)`` and run it through the stack."""
        flat = x.view(x.size(0), -1)
        return self.model(flat)
class CNR1d(nn.Module):
    """Linear layer followed by optional normalisation, activation and dropout.

    An empty list ``[]`` for ``norm``, ``relu`` or ``drop`` skips that layer.
    The linear layer has no bias when ``norm == 'bnorm'``.
    NOTE(review): the normalisation layer is ``Norm2d`` and the dropout is
    ``nn.Dropout2d`` even though the block starts with ``nn.Linear`` - confirm
    this is intended.
    """

    def __init__(self, nch_in, nch_out, norm='bnorm', relu=0.0, drop=[]):
        super().__init__()

        use_bias = norm != 'bnorm'

        stack = [nn.Linear(nch_in, nch_out, bias=use_bias)]
        if norm != []:
            stack.append(Norm2d(nch_out, norm))
        if relu != []:
            stack.append(ReLU(relu))
        if drop != []:
            stack.append(nn.Dropout2d(drop))

        self.cbr = nn.Sequential(*stack)

    def forward(self, x):
        """Apply the layer stack to ``x``."""
        return self.cbr(x)
class Conv2d(nn.Module):
    """Thin wrapper that stores an ``nn.Conv2d`` as ``self.conv``."""

    def __init__(self, nch_in, nch_out, kernel_size=4, stride=1, padding=1, bias=True):
        super().__init__()
        conv_kwargs = dict(kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)
        self.conv = nn.Conv2d(nch_in, nch_out, **conv_kwargs)

    def forward(self, x):
        """Apply the wrapped convolution to ``x``."""
        out = self.conv(x)
        return out
class Deconv2d(nn.Module):
    """Thin wrapper that stores an ``nn.ConvTranspose2d`` as ``self.deconv``."""

    def __init__(self, nch_in, nch_out, kernel_size=4, stride=1, padding=1, output_padding=0, bias=True):
        super().__init__()
        deconv_kwargs = dict(
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
            bias=bias,
        )
        self.deconv = nn.ConvTranspose2d(nch_in, nch_out, **deconv_kwargs)

    def forward(self, x):
        """Apply the wrapped transposed convolution to ``x``."""
        out = self.deconv(x)
        return out
class Linear(nn.Module):
    """Thin wrapper that stores an ``nn.Linear`` as ``self.linear``."""

    def __init__(self, nch_in, nch_out):
        super().__init__()
        self.linear = nn.Linear(nch_in, nch_out)

    def forward(self, x):
        """Apply the wrapped linear layer to ``x``."""
        out = self.linear(x)
        return out
class Norm2d(nn.Module):
    """2-D normalisation chosen by name: 'bnorm' -> BatchNorm2d, 'inorm' -> InstanceNorm2d.

    For any other ``norm_mode`` no ``self.norm`` attribute is created, so
    calling ``forward`` fails.
    """

    def __init__(self, nch, norm_mode):
        super().__init__()
        choices = (('bnorm', nn.BatchNorm2d), ('inorm', nn.InstanceNorm2d))
        for mode, layer_cls in choices:
            if norm_mode == mode:
                self.norm = layer_cls(nch)
                break

    def forward(self, x):
        """Apply the selected normalisation to ``x``."""
        out = self.norm(x)
        return out
class ReLU(nn.Module):
    """Activation chosen by slope: ``relu > 0`` -> in-place LeakyReLU with that
    slope, ``relu == 0`` -> in-place ReLU.

    For any other value no ``self.relu`` attribute is created.
    """

    def __init__(self, relu):
        super().__init__()
        if relu > 0:
            activation = nn.LeakyReLU(relu, True)
        elif relu == 0:
            activation = nn.ReLU(True)
        else:
            return
        self.relu = activation

    def forward(self, x):
        """Apply the selected activation to ``x`` (in place)."""
        out = self.relu(x)
        return out
class Padding(nn.Module):
    """2-D padding chosen by name: 'reflection', 'replication', 'constant'
    (filled with ``value``) or 'zeros'.

    For any other ``padding_mode`` no ``self.padding`` attribute is created.
    """

    def __init__(self, padding, padding_mode='zeros', value=0):
        super().__init__()
        if padding_mode == 'reflection':
            pad_layer = nn.ReflectionPad2d(padding)
        elif padding_mode == 'replication':
            pad_layer = nn.ReplicationPad2d(padding)
        elif padding_mode == 'constant':
            pad_layer = nn.ConstantPad2d(padding, value)
        elif padding_mode == 'zeros':
            pad_layer = nn.ZeroPad2d(padding)
        else:
            return
        self.padding = pad_layer

    def forward(self, x):
        """Pad ``x`` with the selected layer."""
        out = self.padding(x)
        return out
class Pooling2d(nn.Module):
    """Downsampling chosen by ``type``: 'avg' -> AvgPool2d, 'max' -> MaxPool2d,
    'conv' -> Conv2d with kernel and stride ``pool`` on ``nch`` channels.

    For any other ``type`` no ``self.pooling`` attribute is created.
    """

    def __init__(self, nch=[], pool=2, type='avg'):
        super().__init__()
        if type == 'avg':
            pool_layer = nn.AvgPool2d(pool)
        elif type == 'max':
            pool_layer = nn.MaxPool2d(pool)
        elif type == 'conv':
            pool_layer = nn.Conv2d(nch, nch, kernel_size=pool, stride=pool)
        else:
            return
        self.pooling = pool_layer

    def forward(self, x):
        """Apply the selected pooling layer to ``x``."""
        out = self.pooling(x)
        return out
class UnPooling2d(nn.Module):
    """Upsampling chosen by ``type``.

    'nearest'  -> ``nn.Upsample(scale_factor=pool, mode='nearest')``
    'bilinear' -> ``nn.Upsample(scale_factor=pool, mode='bilinear', align_corners=True)``
    'conv'     -> ``nn.ConvTranspose2d(nch, nch, kernel_size=pool, stride=pool)``

    For any other ``type`` no ``self.unpooling`` attribute is created.

    Parameters:
        nch  -- channel count, used only by the 'conv' variant
        pool -- scale factor (kernel size and stride for 'conv')
        type -- 'nearest' | 'bilinear' | 'conv'
    """

    def __init__(self, nch=[], pool=2, type='nearest'):
        super().__init__()

        if type == 'nearest':
            # align_corners is only valid for interpolating modes; passing it
            # with 'nearest' makes F.interpolate raise ValueError in forward().
            self.unpooling = nn.Upsample(scale_factor=pool, mode='nearest')
        elif type == 'bilinear':
            self.unpooling = nn.Upsample(scale_factor=pool, mode='bilinear', align_corners=True)
        elif type == 'conv':
            self.unpooling = nn.ConvTranspose2d(nch, nch, kernel_size=pool, stride=pool)

    def forward(self, x):
        """Upsample ``x`` with the selected layer."""
        return self.unpooling(x)
class Concat(nn.Module):
    """Pad ``x1`` to the size of ``x2`` on dims 2 and 3, then concatenate
    ``[x2, x1]`` along dim 1."""

    def __init__(self):
        super().__init__()

    def forward(self, x1, x2):
        """Return ``cat([x2, padded x1], dim=1)``.

        The size difference on each of dims 2 and 3 is split between the two
        sides, with the extra element (if odd) going to the trailing side.
        """
        gap_h = x2.size()[2] - x1.size()[2]
        gap_w = x2.size()[3] - x1.size()[3]

        # F.pad order: (left, right, top, bottom)
        left, top = gap_w // 2, gap_h // 2
        pad_spec = [left, gap_w - left, top, gap_h - top]
        x1 = F.pad(x1, pad_spec)

        return torch.cat([x2, x1], dim=1)
class TV1dLoss(nn.Module):
    """Mean absolute difference between neighbouring entries along dim 1."""

    def __init__(self):
        super().__init__()

    def forward(self, input):
        """Return ``mean(|input[:, :-1] - input[:, 1:]|)``."""
        neighbour_diff = input[:, :-1] - input[:, 1:]
        return torch.mean(torch.abs(neighbour_diff))
class TV2dLoss(nn.Module):
    """Sum of the mean absolute neighbour differences along dim 3 and dim 2
    of a 4-D input."""

    def __init__(self):
        super().__init__()

    def forward(self, input):
        """Return the dim-3 term plus the dim-2 term."""
        along_last = torch.mean(torch.abs(input[:, :, :, :-1] - input[:, :, :, 1:]))
        along_rows = torch.mean(torch.abs(input[:, :, :-1, :] - input[:, :, 1:, :]))
        return along_last + along_rows
class SSIM2dLoss(nn.Module):
    """Placeholder loss: ``forward`` ignores its arguments and returns 0."""

    def __init__(self):
        super().__init__()

    def forward(self, input, targer):
        """Return the constant 0."""
        return 0

View File

@@ -0,0 +1,734 @@
import functools
from torch.nn import init
from torch.optim import lr_scheduler
from .layer import *
###############################################################################
# Helper Functions
###############################################################################
class Identity(nn.Module):
    """Module whose ``forward`` returns its input unchanged."""

    def forward(self, x):
        """Return ``x`` itself."""
        passthrough = x
        return passthrough
def get_norm_layer(norm_type='instance'):
    """Return a factory for a normalisation layer.

    Parameters:
        norm_type (str) -- batch | instance | none

    'batch' gives ``BatchNorm2d`` with affine parameters and running
    statistics; 'instance' gives ``InstanceNorm2d`` with neither; 'none' gives
    a function that ignores its argument and returns an ``Identity`` module.
    Any other value raises NotImplementedError.
    """
    if norm_type == 'batch':
        return functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
    if norm_type == 'instance':
        return functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
    if norm_type == 'none':
        def norm_layer(x):
            return Identity()
        return norm_layer
    raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
def get_scheduler(optimizer, opt):
    """Return a learning rate scheduler for ``optimizer``.

    Parameters:
        optimizer -- the optimizer of the network
        opt       -- option object; ``opt.lr_policy`` selects the scheduler:
                     linear | step | plateau | cosine

    'linear' keeps the initial rate for the first ``opt.n_epochs`` epochs
    (offset by ``opt.epoch_count``) and then decays it linearly over
    ``opt.n_epochs_decay`` epochs. 'step' uses ``StepLR`` with
    ``opt.lr_decay_iters`` and gamma 0.1, 'plateau' uses ``ReduceLROnPlateau``
    and 'cosine' uses ``CosineAnnealingLR`` with ``T_max=opt.n_epochs``.

    Raises:
        NotImplementedError -- for any other ``opt.lr_policy``
    """
    if opt.lr_policy == 'linear':
        def lambda_rule(epoch):
            lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.n_epochs) / float(opt.n_epochs_decay + 1)
            return lr_l
        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
    elif opt.lr_policy == 'step':
        scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
    elif opt.lr_policy == 'plateau':
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
    elif opt.lr_policy == 'cosine':
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.n_epochs, eta_min=0)
    else:
        # Raise (not return) the error so an unknown policy fails here instead
        # of later, when the exception object is used as a scheduler.
        raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
    return scheduler
def init_weights(net, init_type='normal', init_gain=0.02):
    """Initialise the weights of ``net`` in place.

    Parameters:
        net       -- network to initialise
        init_type -- normal | xavier | kaiming | orthogonal
        init_gain -- scale for normal, xavier and orthogonal

    Modules with a ``weight`` attribute whose class name contains 'Conv' or
    'Linear' get the selected initialisation and a zero bias. Modules whose
    class name contains 'BatchNorm2d' get weights drawn from N(1.0, init_gain)
    and a zero bias. An unknown ``init_type`` raises NotImplementedError.
    """
    def init_func(m):
        cls_name = m.__class__.__name__
        is_conv_or_linear = 'Conv' in cls_name or 'Linear' in cls_name
        if hasattr(m, 'weight') and is_conv_or_linear:
            weight = m.weight.data
            if init_type == 'normal':
                init.normal_(weight, 0.0, init_gain)
            elif init_type == 'xavier':
                init.xavier_normal_(weight, gain=init_gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(weight, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(weight, gain=init_gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif 'BatchNorm2d' in cls_name:
            # BatchNorm weight is a vector, so only the normal distribution applies.
            init.normal_(m.weight.data, 1.0, init_gain)
            init.constant_(m.bias.data, 0.0)

    print('initialize network with %s' % init_type)
    net.apply(init_func)
def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
    """Place ``net`` on its device(s) and initialise its weights.

    Parameters:
        net       -- the network to initialise
        init_type -- normal | xavier | kaiming | orthogonal
        init_gain -- scale for normal, xavier and orthogonal
        gpu_ids   -- GPU ids; when non-empty the net is moved to the first one
                     and wrapped in ``torch.nn.DataParallel``

    Returns:
        the initialised network (the DataParallel wrapper when GPUs are used).
    """
    use_gpu = len(gpu_ids) > 0
    if use_gpu:
        assert (torch.cuda.is_available())
        net.to(gpu_ids[0])
        net = torch.nn.DataParallel(net, gpu_ids)
    init_weights(net, init_type, init_gain=init_gain)
    return net
def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02, gpu_ids=[]):
    """Create and initialise a generator.

    Parameters:
        input_nc, output_nc, ngf -- forwarded to ``ref_unpair``
        netG        -- ref_unpair_cbam_cat | ref_unpair_recon | triplet
        norm        -- only validated through ``get_norm_layer``; the networks
                       built here always use 'inorm' (ref_unpair) or their own
                       fixed setting (triplet)
        use_dropout -- accepted for interface compatibility; unused
        init_type, init_gain, gpu_ids -- forwarded to ``init_net``

    Returns:
        the network returned by ``init_net``.

    Raises:
        NotImplementedError -- for an unknown ``netG`` (or an unknown ``norm``)
    """
    # Result is unused; the call is kept because it rejects unknown norm names.
    get_norm_layer(norm_type=norm)

    if netG == 'ref_unpair_cbam_cat':
        net = ref_unpair(input_nc, output_nc, ngf, norm='inorm', status='ref_unpair_cbam_cat')
    elif netG == 'ref_unpair_recon':
        net = ref_unpair(input_nc, output_nc, ngf, norm='inorm', status='ref_unpair_recon')
    elif netG == 'triplet':
        # triplet.__init__ takes no arguments; passing any raised TypeError.
        net = triplet()
    else:
        raise NotImplementedError('Generator model name [%s] is not recognized' % netG)
    return init_net(net, init_type, init_gain, gpu_ids)
class AdaIN(nn.Module):
    """Re-normalise ``x`` so its statistics over dims 2 and 3 become those of ``y``."""

    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        """Return ``(x - mean_x) / std_x * std_y + mean_y``.

        Means and standard deviations are taken over dims 2 and 3; 1e-5 is
        added to both standard deviations.
        """
        eps = 1e-5
        spatial = [2, 3]

        mean_x = torch.mean(x, dim=spatial)[..., None, None]
        mean_y = torch.mean(y, dim=spatial)[..., None, None]
        std_x = torch.std(x, dim=spatial)[..., None, None] + eps
        std_y = torch.std(y, dim=spatial)[..., None, None] + eps

        return (x - mean_x) / std_x * std_y + mean_y
class HED(nn.Module):
    """Five-stage convolutional network with one score head per stage.

    Each stage is 3x3 conv + ReLU pairs (stages two to five start with a 2x2
    max-pool). Every stage output goes through a 1x1 conv to a single
    channel, is resized bilinearly to the input size, and the five maps are
    fused by a 1x1 conv followed by a sigmoid.
    """

    @staticmethod
    def _vgg_stage(in_channels, out_channels, num_convs, pool=True):
        """Build one stage: optional max-pool, then ``num_convs`` conv+ReLU pairs."""
        layers = [nn.MaxPool2d(kernel_size=2, stride=2)] if pool else []
        channels = in_channels
        for _ in range(num_convs):
            layers.append(nn.Conv2d(in_channels=channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1))
            layers.append(nn.ReLU(inplace=False))
            channels = out_channels
        return nn.Sequential(*layers)

    @staticmethod
    def _score_head(in_channels):
        """Build the 1x1 conv that maps a stage output to one channel."""
        return nn.Conv2d(in_channels=in_channels, out_channels=1, kernel_size=1, stride=1, padding=0)

    def __init__(self):
        super().__init__()

        # Stages are created in the same order as they are applied.
        self.moduleVggOne = self._vgg_stage(3, 64, 2, pool=False)
        self.moduleVggTwo = self._vgg_stage(64, 128, 2)
        self.moduleVggThr = self._vgg_stage(128, 256, 3)
        self.moduleVggFou = self._vgg_stage(256, 512, 3)
        self.moduleVggFiv = self._vgg_stage(512, 512, 3)

        self.moduleScoreOne = self._score_head(64)
        self.moduleScoreTwo = self._score_head(128)
        self.moduleScoreThr = self._score_head(256)
        self.moduleScoreFou = self._score_head(512)
        self.moduleScoreFiv = self._score_head(512)

        self.moduleCombine = nn.Sequential(
            nn.Conv2d(in_channels=5, out_channels=1, kernel_size=1, stride=1, padding=0),
            nn.Sigmoid()
        )

    def forward(self, tensorInput):
        """Return the fused single-channel map for ``tensorInput``.

        Channels 2, 1 and 0 of the input are scaled by 255, shifted by a
        per-channel constant and stacked in that order before the first stage.
        """
        channel_offsets = ((2, 104.00698793), (1, 116.66876762), (0, 122.67891434))
        shifted = [
            (tensorInput[:, idx:idx + 1, :, :] * 255.0) - offset
            for idx, offset in channel_offsets
        ]
        tensorInput = torch.cat(shifted, 1)
        out_size = (tensorInput.size(2), tensorInput.size(3))

        stages = (self.moduleVggOne, self.moduleVggTwo, self.moduleVggThr, self.moduleVggFou, self.moduleVggFiv)
        heads = (self.moduleScoreOne, self.moduleScoreTwo, self.moduleScoreThr, self.moduleScoreFou, self.moduleScoreFiv)

        # Run the stages in sequence, keeping every intermediate output.
        features = []
        current = tensorInput
        for stage in stages:
            current = stage(current)
            features.append(current)

        # Score each stage, then resize every score map to the input size.
        scores = [head(feat) for head, feat in zip(heads, features)]
        scores = [
            nn.functional.interpolate(input=score, size=out_size, mode='bilinear', align_corners=False)
            for score in scores
        ]

        return self.moduleCombine(torch.cat(scores, 1))
def define_HED(init_weights_, gpu_ids_=[]):
    """Create an ``HED`` network and optionally load weights into it.

    Parameters:
        init_weights_ -- path of a state dict to load, or None to skip loading
        gpu_ids_      -- GPU ids; when non-empty the net is moved to the first
                         one and wrapped in ``torch.nn.DataParallel``

    Returns:
        the network (the DataParallel wrapper when GPUs are used).
    """
    net = HED()

    if len(gpu_ids_) > 0:
        assert (torch.cuda.is_available())
        net.to(gpu_ids_[0])
        net = torch.nn.DataParallel(net, gpu_ids_)

    if init_weights_ is None:
        return net

    if gpu_ids_:
        device = torch.device('cuda:{}'.format(gpu_ids_[0]))
    else:
        device = torch.device('cpu')
    print('Loading model from: %s' % init_weights_)
    state_dict = torch.load(init_weights_, map_location=str(device))
    # Load into the wrapped module when the net is a DataParallel wrapper.
    target = net.module if isinstance(net, torch.nn.DataParallel) else net
    target.load_state_dict(state_dict)
    print('load the weights successfully')
    return net
def define_styletps(init_weights_, gpu_ids_=[], shape=False):
    """Create a ``triplet`` network and optionally load weights into it.

    Parameters:
        init_weights_ -- path of a state dict to load, or None to skip loading
        gpu_ids_      -- GPU ids; when non-empty the net is moved to the first
                         one and wrapped in ``torch.nn.DataParallel``
        shape         -- the network is only created when ``shape == False``

    NOTE(review): when ``shape`` is not False, ``net`` stays None, so the GPU
    and weight-loading steps below would fail on it - confirm the intended use.
    """
    net = triplet() if shape == False else None

    if len(gpu_ids_) > 0:
        assert (torch.cuda.is_available())
        net.to(gpu_ids_[0])
        net = torch.nn.DataParallel(net, gpu_ids_)

    if init_weights_ is None:
        return net

    if gpu_ids_:
        device = torch.device('cuda:{}'.format(gpu_ids_[0]))
    else:
        device = torch.device('cpu')
    print('Loading model from: %s' % init_weights_)
    state_dict = torch.load(init_weights_, map_location=str(device))
    if isinstance(net, torch.nn.DataParallel):
        net.module.load_state_dict(state_dict)
    else:
        net.load_state_dict(state_dict)
    print('load the weights successfully')
    return net
class triplet(nn.Module):
    """Shared encoder applied to three inputs.

    Each input goes through three ``CNR2d`` layers (1 -> 64 -> 128 -> 256
    channels, batch norm), adaptive average pooling to 1x1, flattening and a
    256 -> 128 linear layer. The same weights are used for all three inputs.
    """

    def __init__(self):
        super().__init__()

        self.nch_in = 1
        self.nch_out = 1
        self.nch_ker = 64
        self.norm = 'bnorm'
        self.bias = self.norm != 'bnorm'

        k = self.nch_ker
        self.conv0 = CNR2d(self.nch_in, k, kernel_size=7, stride=1, padding=3, norm=self.norm, relu=0.0)
        self.conv1 = CNR2d(k, 2 * k, kernel_size=4, stride=2, padding=1, norm=self.norm, relu=0.0)
        self.conv2 = CNR2d(2 * k, 4 * k, kernel_size=4, stride=2, padding=1, norm=self.norm, relu=0.0)
        self.final_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.linear = nn.Linear(256, 128)

    def _embed(self, t):
        """Run one input through the shared encoder."""
        for layer in (self.conv0, self.conv1, self.conv2, self.final_pool):
            t = layer(t)
        return self.linear(torch.flatten(t, 1))

    def forward(self, x, y, z):
        """Return the embeddings of ``x``, ``y`` and ``z`` (computed in that order)."""
        x = self._embed(x)
        y = self._embed(y)
        z = self._embed(z)
        return x, y, z
class MLP(nn.Module):
    """Stack of ``LinearBlock`` layers.

    Layout: one input block (``input_dim -> dim``), ``n_blk - 2`` hidden
    blocks (``dim -> dim``) and one output block (``dim -> output_dim``)
    without normalisation or activation.
    """

    def __init__(self, input_dim, output_dim, dim, n_blk, norm='none', activ='relu'):
        super().__init__()
        hidden = [
            LinearBlock(dim, dim, norm=norm, activation=activ)
            for _ in range(n_blk - 2)
        ]
        first = LinearBlock(input_dim, dim, norm=norm, activation=activ)
        # Output block: no normalisation, no activation.
        last = LinearBlock(dim, output_dim, norm='none', activation='none')
        self.model = nn.Sequential(first, *hidden, last)

    def forward(self, x):
        """Flatten ``x`` to ``(x.size(0), -1)`` and run it through the stack."""
        batch = x.size(0)
        return self.model(x.view(batch, -1))
class ref_unpair(nn.Module):
def __init__(self, nch_in, nch_out, nch_ker=64, norm='bnorm', nblk=4, status='ref_unpair'):
    """Build the encoders, the residual part and the decoder.

    Parameters:
        nch_in  -- content input channels (forced to 1 for 'ref_unpair_recon')
        nch_out -- ignored; output channels are 3 for 'ref_unpair_recon', else 1
        nch_ker -- ignored; the base width is always 64
        norm    -- normalisation name passed to the CNR2d/DECNR2d/ResBlock layers
        nblk    -- number of ResBlocks when status is not 'ref_unpair_cbam_cat'
        status  -- 'ref_unpair_cbam_cat' adds the CBAM modules, the style
                   encoder and four ResBlock_cat blocks
    """
    super(ref_unpair, self).__init__()
    nch_ker = 64  # the nch_ker argument is always overridden
    cbam_cat = status == 'ref_unpair_cbam_cat'

    self.nch_in = nch_in
    self.nchs_in = 1  # style input channels
    self.status = status
    if self.status == 'ref_unpair_recon':
        self.nch_out = 3
        self.nch_in = 1
    else:
        self.nch_out = 1
    self.nch_ker = nch_ker
    self.norm = norm
    self.nblk = nblk
    self.dec0 = []

    def stem(n_in):
        # 7x7, stride-1 entry layer of an encoder.
        return CNR2d(n_in, self.nch_ker, kernel_size=7, stride=1, padding=3, norm=self.norm, relu=0.0)

    def down(mult):
        # 4x4, stride-2 layer that doubles the channel count.
        return CNR2d(mult * self.nch_ker, 2 * mult * self.nch_ker, kernel_size=4, stride=2, padding=1, norm=self.norm, relu=0.0)

    def up(mult_in, mult_out):
        # 4x4, stride-2 transposed layer of the decoder.
        return DECNR2d(mult_in * self.nch_ker, mult_out * self.nch_ker, kernel_size=4, stride=2, padding=1, norm=self.norm, relu=0.0)

    if cbam_cat:
        self.cbam_c = CBAM(nch_ker * 8, 16, 3, cbam_status="channel")
        self.cbam_s = CBAM(nch_ker * 8, 16, 3, cbam_status="spatial")
        # Style encoder.
        self.enc1_s = stem(self.nchs_in)
        self.enc2_s = down(1)
        self.enc3_s = down(2)
        self.enc4_s = down(4)

    self.bias = norm != 'bnorm'

    # Content encoder.
    self.enc1_c = stem(self.nch_in)
    self.enc2_c = down(1)
    self.enc3_c = down(2)
    self.enc4_c = down(4)

    width = 8 * self.nch_ker
    if cbam_cat:
        for idx in range(1, 5):
            setattr(self, 'res_cat%d' % idx, ResBlock_cat(width, width, kernel_size=3, stride=1, padding=1, norm=self.norm, relu=0.0, padding_mode='reflection'))

    if self.nblk and not cbam_cat:
        res = [
            ResBlock(width, width, kernel_size=3, stride=1, padding=1, norm=self.norm, relu=0.0, padding_mode='reflection')
            for _ in range(self.nblk)
        ]
        self.res1 = nn.Sequential(*res)

    # Decoder; the layer list is also kept on the instance as self.dec0.
    self.dec0 += [up(8, 4)]
    self.dec0 += [up(4, 2)]
    self.dec0 += [up(2, 1)]
    self.dec0 += [DECNR2d(1 * self.nch_ker, 1 * self.nch_ker, kernel_size=7, stride=1, padding=3, norm=self.norm, relu=0.0)]
    self.dec0 += [nn.Conv2d(1 * self.nch_ker, self.nch_out, kernel_size=3, stride=1, padding=1)]
    self.dec = nn.Sequential(*self.dec0)
def forward(self, content, style):
    """Encode ``content``, optionally restyle it from ``style``, then decode.

    Parameters:
        content -- input passed through the four content encoders
        style   -- reference input; only read when ``self.status`` is
                   'ref_unpair_cbam_cat'

    Returns the decoder output squashed with ``tanh``.
    """
    feat = content
    for encoder in (self.enc1_c, self.enc2_c, self.enc3_c, self.enc4_c):
        feat = encoder(feat)

    uses_reference = self.status == 'ref_unpair_cbam_cat'
    if uses_reference:
        # Attention-augmented copy of the content features.
        attended_content = feat + self.cbam_s(feat)

        style_feat = style
        for encoder in (self.enc1_s, self.enc2_s, self.enc3_s, self.enc4_s):
            style_feat = encoder(style_feat)
        attended_style = style_feat + self.cbam_c(style_feat)

        # Two AdaIN results: plain features and attention-augmented features.
        fused = self.adaptive_instance_normalization(feat, style_feat)
        attended = self.adaptive_instance_normalization(attended_content, attended_style)

        # Every concat-residual block receives the same attended tensor.
        for block in (self.res_cat1, self.res_cat2, self.res_cat3, self.res_cat4):
            fused = block(fused, attended)
    else:
        fused = feat

    if self.nblk and not uses_reference:
        # NOTE(review): the residual stack is executed but its output is
        # never fed to the decoder. Kept as-is because existing checkpoints
        # may have been trained with this behaviour -- confirm intent.
        self.res1(fused)

    return torch.tanh(self.dec(fused))
def calc_mean_std(self, feat, eps=1e-5):
    """Return the per-sample, per-channel mean and std of a 4-D tensor.

    Parameters:
        feat -- tensor with exactly four dimensions; statistics are taken
                over everything after the first two dimensions
        eps  -- added to the variance before the square root so the std
                is never zero

    Returns ``(mean, std)``, each shaped ``(N, C, 1, 1)``.
    """
    shape = feat.size()
    assert (len(shape) == 4)
    batch, channels = shape[:2]
    flat = feat.view(batch, channels, -1)
    mean = flat.mean(dim=2).view(batch, channels, 1, 1)
    std = (flat.var(dim=2) + eps).sqrt().view(batch, channels, 1, 1)
    return mean, std
def adaptive_instance_normalization(self, content_feat, style_feat):
    """Re-normalise ``content_feat`` to the channel statistics of ``style_feat``.

    The content features are whitened with their own per-channel mean/std
    (from ``calc_mean_std``) and then scaled and shifted with the style
    statistics. Both inputs must agree on their first two dimensions.
    """
    assert (content_feat.size()[:2] == style_feat.size()[:2])
    style_mean, style_std = self.calc_mean_std(style_feat)
    content_mean, content_std = self.calc_mean_std(content_feat)
    # The (N, C, 1, 1) statistics broadcast over the remaining dimensions.
    whitened = (content_feat - content_mean) / content_std
    return whitened * style_std + style_mean
def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal', init_gain=0.02, gpu_ids=[]):
    """Build a discriminator by name and hand it to ``init_net``.

    Parameters:
        input_nc   -- number of input channels
        ndf        -- base filter count passed to the discriminator
        netD       -- 'basic' (3-layer PatchGAN), 'n_layers' (PatchGAN with
                      ``n_layers_D`` layers) or 'pixel' (1x1 PatchGAN)
        n_layers_D -- depth used only when ``netD == 'n_layers'``
        norm       -- normalisation name resolved by ``get_norm_layer``
        init_type, init_gain, gpu_ids -- forwarded to ``init_net``

    Raises NotImplementedError for any other ``netD`` value.
    """
    norm_layer = get_norm_layer(norm_type=norm)
    if netD == 'pixel':
        # Classifies each pixel independently.
        net = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer)
    elif netD in ('basic', 'n_layers'):
        # 'basic' is the PatchGAN with a fixed depth of 3.
        depth = 3 if netD == 'basic' else n_layers_D
        net = NLayerDiscriminator(input_nc, ndf, depth, norm_layer=norm_layer)
    else:
        raise NotImplementedError('Discriminator model name [%s] is not recognized' % netD)
    return init_net(net, init_type, init_gain, gpu_ids)
##############################################################################
# Classes
##############################################################################
class GANLoss(nn.Module):
    """Adversarial objective that creates its own target tensors.

    Callers pass the discriminator output and a real/fake flag; the class
    expands the matching label to the prediction's shape.
    """

    def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
        """Select the objective and register the two label buffers.

        Parameters:
            gan_mode          -- 'vanilla' (BCE with logits), 'lsgan' (MSE)
                                 or 'wgangp' (mean critic score)
            target_real_label -- label value used for real inputs
            target_fake_label -- label value used for fake inputs

        Raises NotImplementedError for any other ``gan_mode``.

        Note: the discriminator should not end in a sigmoid; 'vanilla'
        applies it inside BCEWithLogitsLoss and 'lsgan' needs none.
        """
        super().__init__()
        self.register_buffer('real_label', torch.tensor(target_real_label))
        self.register_buffer('fake_label', torch.tensor(target_fake_label))
        self.gan_mode = gan_mode
        if gan_mode == 'wgangp':
            # No criterion object: the loss is computed inline in __call__.
            self.loss = None
        elif gan_mode == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif gan_mode == 'lsgan':
            self.loss = nn.MSELoss()
        else:
            raise NotImplementedError('gan mode %s not implemented' % gan_mode)

    def get_target_tensor(self, prediction, target_is_real):
        """Return the real or fake label expanded to ``prediction``'s shape."""
        label = self.real_label if target_is_real else self.fake_label
        return label.expand_as(prediction)

    def __call__(self, prediction, target_is_real):
        """Compute the loss for ``prediction`` against the chosen label."""
        if self.gan_mode == 'wgangp':
            score = prediction.mean()
            loss = -score if target_is_real else score
        elif self.gan_mode in ['lsgan', 'vanilla']:
            loss = self.loss(prediction, self.get_target_tensor(prediction, target_is_real))
        return loss
def cal_gradient_penalty(netD, real_data, fake_data, device, type='mixed', constant=1.0, lambda_gp=10.0):
    """Compute a gradient penalty on the discriminator's input gradients.

    Parameters:
        netD      -- discriminator, called on the sampled points
        real_data -- real batch
        fake_data -- generated batch
        device    -- device for the random mixing weights and grad outputs
        type      -- where to evaluate gradients: 'real', 'fake', or 'mixed'
                     (a random per-sample interpolation of the two)
        constant  -- target value for the gradient norm
        lambda_gp -- penalty weight; a value that is not > 0 disables it

    Returns ``(penalty, gradients)`` with the gradients flattened per
    sample, or ``(0.0, None)`` when the penalty is disabled.
    Raises NotImplementedError for an unknown ``type``.
    """
    if not lambda_gp > 0.0:
        return 0.0, None

    if type == 'real':
        sample_points = real_data
    elif type == 'fake':
        sample_points = fake_data
    elif type == 'mixed':
        batch = real_data.shape[0]
        # One mixing weight per sample, replicated over every element.
        mix = torch.rand(batch, 1, device=device)
        mix = mix.expand(batch, real_data.nelement() // batch).contiguous().view(*real_data.shape)
        sample_points = mix * real_data + ((1 - mix) * fake_data)
    else:
        raise NotImplementedError('{} not implemented'.format(type))

    # For 'real'/'fake' this flips requires_grad on the caller's tensor.
    sample_points.requires_grad_(True)
    scores = netD(sample_points)
    grads = torch.autograd.grad(
        outputs=scores,
        inputs=sample_points,
        grad_outputs=torch.ones(scores.size()).to(device),
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )
    flat_grads = grads[0].view(real_data.size(0), -1)
    # The tiny offset keeps the norm's own gradient finite at zero.
    penalty = (((flat_grads + 1e-16).norm(2, dim=1) - constant) ** 2).mean() * lambda_gp
    return penalty, flat_grads
class NLayerDiscriminator(nn.Module):
    """PatchGAN discriminator: a stack of strided 4x4 convolutions."""

    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
        """Assemble the convolution stack.

        Parameters:
            input_nc   -- number of channels in the input images
            ndf        -- number of filters produced by the first conv
            n_layers   -- number of conv stages before the 1-channel head
            norm_layer -- normalisation layer class or functools.partial
        """
        super().__init__()
        # Conv bias is only enabled when the norm layer is InstanceNorm2d.
        norm_cls = norm_layer.func if type(norm_layer) == functools.partial else norm_layer
        use_bias = norm_cls == nn.InstanceNorm2d

        kw, padw = 4, 1
        layers = [
            nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
            nn.LeakyReLU(0.2, True),
        ]

        # Stride-2 stages: the width doubles each time, capped at 8 * ndf.
        in_mult = 1
        for depth in range(1, n_layers):
            out_mult = min(2 ** depth, 8)
            layers.extend([
                nn.Conv2d(ndf * in_mult, ndf * out_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * out_mult),
                nn.LeakyReLU(0.2, True),
            ])
            in_mult = out_mult

        # One more stage at stride 1, then the 1-channel prediction map.
        out_mult = min(2 ** n_layers, 8)
        layers.extend([
            nn.Conv2d(ndf * in_mult, ndf * out_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * out_mult),
            nn.LeakyReLU(0.2, True),
        ])
        layers.append(nn.Conv2d(ndf * out_mult, 1, kernel_size=kw, stride=1, padding=padw))
        self.model = nn.Sequential(*layers)

    def forward(self, input):
        """Return the patch-wise prediction map for ``input``."""
        return self.model(input)
class PixelDiscriminator(nn.Module):
    """1x1 PatchGAN discriminator (pixelGAN): scores every pixel separately."""

    def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d):
        """Assemble three 1x1 convolutions.

        Parameters:
            input_nc   -- number of channels in the input images
            ndf        -- number of filters produced by the first conv
            norm_layer -- normalisation layer class or functools.partial
        """
        super().__init__()
        # Conv bias is only enabled when the norm layer is InstanceNorm2d.
        norm_cls = norm_layer.func if type(norm_layer) == functools.partial else norm_layer
        use_bias = norm_cls == nn.InstanceNorm2d

        self.net = nn.Sequential(
            nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias),
            norm_layer(ndf * 2),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias),
        )

    def forward(self, input):
        """Return the per-pixel prediction map for ``input``."""
        return self.net(input)
class CBAM(nn.Module):
    """Attention block combining channel and/or spatial attention.

    ``cbam_status`` selects what ``forward`` applies:
        "cbam"    -- channel attention, then spatial attention on the result
        "spatial" -- spatial attention only
        "channel" -- channel attention only
    """

    def __init__(self, n_channels_in, reduction_ratio, kernel_size, cbam_status):
        """Create both attention sub-modules and remember the mode.

        Parameters:
            n_channels_in   -- channel count of the feature maps
            reduction_ratio -- bottleneck reduction for channel attention
            kernel_size     -- conv kernel size for spatial attention
            cbam_status     -- one of "cbam", "spatial", "channel"
        """
        super(CBAM, self).__init__()
        self.n_channels_in = n_channels_in
        self.reduction_ratio = reduction_ratio
        self.kernel_size = kernel_size
        self.channel_attention = ChannelAttention_nopara(n_channels_in, reduction_ratio)
        self.spatial_attention = SpatialAttention_nopara(kernel_size)
        self.status = cbam_status

    def forward(self, x):
        """Return ``x`` re-weighted by the attention map(s) for the mode.

        Raises ValueError when ``self.status`` is not a supported mode;
        this used to fail with an UnboundLocalError on the result variable.
        """
        # Original author's note: the full "cbam" mode is not used in this version.
        if self.status == "cbam":
            channel_refined = self.channel_attention(x) * x
            return self.spatial_attention(channel_refined) * channel_refined
        if self.status == "spatial":
            return self.spatial_attention(x) * x
        if self.status == "channel":
            return self.channel_attention(x) * x
        raise ValueError(
            f"Unsupported cbam_status {self.status!r}; "
            "expected 'cbam', 'spatial' or 'channel'"
        )
class SpatialAttention_nopara(nn.Module):
    """Spatial attention: a conv over channel-wise max and average maps."""

    def __init__(self, kernel_size):
        """Create the 2-to-1 channel conv; ``kernel_size`` must be odd."""
        super().__init__()
        self.kernel_size = kernel_size
        assert kernel_size % 2 == 1, "Odd kernel size required"
        # Half-kernel padding keeps the spatial size unchanged.
        self.conv = nn.Conv2d(
            in_channels=2,
            out_channels=1,
            kernel_size=kernel_size,
            padding=int((kernel_size - 1) / 2),
        )

    def forward(self, x):
        """Return an attention map in (0, 1) with the same shape as ``x``."""
        pooled = torch.cat(
            [self.agg_channel(x, "max"), self.agg_channel(x, "avg")], dim=1
        )
        # Sigmoid is elementwise, so gating before replicating across the
        # channels yields the same values as replicating first.
        gate = torch.sigmoid(self.conv(pooled))
        return gate.repeat(1, x.size()[1], 1, 1)

    def agg_channel(self, x, pool="max"):
        """Collapse the channel axis of ``x`` to one map by max or average."""
        b, c, h, w = x.size()
        # Put channels last so a 1-D pool of width ``c`` reduces them to one.
        per_pixel = x.view(b, c, h * w).permute(0, 2, 1)
        if pool == "max":
            per_pixel = F.max_pool1d(per_pixel, c)
        elif pool == "avg":
            per_pixel = F.avg_pool1d(per_pixel, c)
        return per_pixel.permute(0, 2, 1).view(b, 1, h, w)
class ChannelAttention_nopara(nn.Module):
    """Channel attention: a shared MLP over globally pooled descriptors."""

    def __init__(self, n_channels_in, reduction_ratio):
        """Build the bottleneck MLP.

        Parameters:
            n_channels_in   -- channel count of the incoming feature maps
            reduction_ratio -- divisor for the hidden width of the MLP
        """
        super().__init__()
        self.n_channels_in = n_channels_in
        self.reduction_ratio = reduction_ratio
        self.middle_layer_size = int(self.n_channels_in / float(self.reduction_ratio))
        self.bottleneck = nn.Sequential(
            nn.Linear(self.n_channels_in, self.middle_layer_size),
            nn.ReLU(),
            nn.Linear(self.middle_layer_size, self.n_channels_in),
        )

    def forward(self, x):
        """Return per-channel gates in (0, 1), shaped ``(N, C, 1, 1)``."""
        # Pooling with a window equal to the full spatial extent gives one
        # value per channel.
        window = (x.size()[2], x.size()[3])
        batch = x.size()[0]
        avg_desc = F.avg_pool2d(x, window).view(batch, -1)
        max_desc = F.max_pool2d(x, window).view(batch, -1)
        # Both descriptors go through the same MLP and are summed.
        gate = torch.sigmoid(self.bottleneck(avg_desc) + self.bottleneck(max_desc))
        return gate.unsqueeze(2).unsqueeze(3)

View File

@@ -0,0 +1,86 @@
import torch
import torchvision
class VGGPerceptualLoss(torch.nn.Module):
    """L1 distance between VGG16 feature maps of two images.

    The first 23 layers of ``vgg16().features`` are split into four
    consecutive blocks; the loss sums L1 differences of the activations
    (and optionally of their Gram matrices) after selected blocks.
    """

    # (start, stop) layer indices of each feature block.
    _BLOCK_SLICES = ((0, 4), (4, 9), (9, 16), (16, 23))

    def __init__(self, resize=True):
        """Load pretrained VGG16 and freeze it.

        Parameters:
            resize -- when True, inputs are resized to 224x224 in forward
        """
        super(VGGPerceptualLoss, self).__init__()
        # Load the pretrained network once and slice it; the original
        # loaded the full model four times to take four slices.
        features = torchvision.models.vgg16(pretrained=True).features
        blocks = [features[start:stop].eval() for start, stop in self._BLOCK_SLICES]
        for block in blocks:
            # Iterate over parameters: iterating the Sequential itself yields
            # layers, so the old loop only set a stray attribute on each layer
            # and left every weight trainable.
            for param in block.parameters():
                param.requires_grad = False
        self.blocks = torch.nn.ModuleList(blocks)
        self.transform = torch.nn.functional.interpolate
        # Normalisation constants stay Parameters (same state_dict keys and
        # device handling) but are frozen so an optimizer cannot change them.
        self.mean = torch.nn.Parameter(
            torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1), requires_grad=False)
        self.std = torch.nn.Parameter(
            torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1), requires_grad=False)
        self.resize = resize

    def forward(self, input, target, feature_layers=(0, 1, 2, 3), style_layers=()):
        """Return the summed loss between ``input`` and ``target``.

        Parameters:
            input, target  -- image batches; when the channel count is not 3
                              each is tiled three times along the channel
                              axis (meant for 1-channel images)
            feature_layers -- block indices whose activations are compared
            style_layers   -- block indices whose Gram matrices are compared
        """
        if input.shape[1] != 3:
            input = input.repeat(1, 3, 1, 1)
            target = target.repeat(1, 3, 1, 1)
        input = (input - self.mean) / self.std
        target = (target - self.mean) / self.std
        if self.resize:
            input = self.transform(input, mode='bilinear', size=(224, 224), align_corners=False)
            target = self.transform(target, mode='bilinear', size=(224, 224), align_corners=False)
        loss = 0.0
        x, y = input, target
        for i, block in enumerate(self.blocks):
            x = block(x)
            y = block(y)
            if i in feature_layers:
                loss += torch.nn.functional.l1_loss(x, y)
            if i in style_layers:
                # Gram matrix: channel-by-channel inner products.
                act_x = x.reshape(x.shape[0], x.shape[1], -1)
                act_y = y.reshape(y.shape[0], y.shape[1], -1)
                gram_x = act_x @ act_x.permute(0, 2, 1)
                gram_y = act_y @ act_y.permute(0, 2, 1)
                loss += torch.nn.functional.l1_loss(gram_x, gram_y)
        return loss
class VGGstyleLoss(torch.nn.Module):
    """L1 distance between VGG16 features (and Gram matrices) of two images.

    The first 23 layers of ``vgg16().features`` are split into four
    consecutive blocks; the loss sums L1 differences after selected blocks.
    """

    # (start, stop) layer indices of each feature block.
    _BLOCK_SLICES = ((0, 4), (4, 9), (9, 16), (16, 23))

    def __init__(self, resize=True):
        """Load pretrained VGG16 and freeze it.

        Parameters:
            resize -- when True, inputs are resized to 224x224 in forward
        """
        super(VGGstyleLoss, self).__init__()
        # Load the pretrained network once and slice it instead of loading
        # the full model once per slice.
        features = torchvision.models.vgg16(pretrained=True).features
        blocks = [features[start:stop].eval() for start, stop in self._BLOCK_SLICES]
        for block in blocks:
            # Freeze the weights themselves; looping over the Sequential
            # yields layers, not parameters, and froze nothing.
            for param in block.parameters():
                param.requires_grad = False
        self.blocks = torch.nn.ModuleList(blocks)
        self.transform = torch.nn.functional.interpolate
        # Frozen normalisation constants (state_dict keys unchanged).
        self.mean = torch.nn.Parameter(
            torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1), requires_grad=False)
        self.std = torch.nn.Parameter(
            torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1), requires_grad=False)
        self.resize = resize

    def forward(self, input, target, feature_layers=(0, 1, 2, 3), style_layers=()):
        """Return the summed loss between ``input`` and ``target``.

        Parameters:
            input, target  -- image batches; when the channel count is not 3
                              each is tiled three times along the channel axis
            feature_layers -- block indices whose activations are compared
            style_layers   -- block indices whose Gram matrices are compared
        """
        if input.shape[1] != 3:
            input = input.repeat(1, 3, 1, 1)
            target = target.repeat(1, 3, 1, 1)
        input = (input - self.mean) / self.std
        target = (target - self.mean) / self.std
        if self.resize:
            input = self.transform(input, mode='bilinear', size=(224, 224), align_corners=False)
            target = self.transform(target, mode='bilinear', size=(224, 224), align_corners=False)
        loss = 0.0
        x, y = input, target
        for i, block in enumerate(self.blocks):
            x = block(x)
            y = block(y)
            if i in feature_layers:
                loss += torch.nn.functional.l1_loss(x, y)
            if i in style_layers:
                # Gram matrix: channel-by-channel inner products.
                act_x = x.reshape(x.shape[0], x.shape[1], -1)
                act_y = y.reshape(y.shape[0], y.shape[1], -1)
                gram_x = act_x @ act_x.permute(0, 2, 1)
                gram_y = act_y @ act_y.permute(0, 2, 1)
                loss += torch.nn.functional.l1_loss(gram_x, gram_y)
        return loss

View File

@@ -0,0 +1,82 @@
import torch
from .base_model import BaseModel
from . import networks
class TemplateModel(BaseModel):
    """Skeleton model: one generator trained with a weighted L1 loss."""

    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        """Adjust option defaults and add this model's own options.

        Parameters:
            parser   -- the option parser
            is_train -- True in the training phase; training-only options
                        are added only then

        Returns the modified parser.
        """
        # This model defaults to the aligned dataset.
        parser.set_defaults(dataset_mode='aligned')
        if is_train:
            parser.add_argument('--lambda_regression', type=float, default=1.0, help='weight for the regression loss')
        return parser

    def __init__(self, opt):
        """Set up names, the generator and (for training) loss and optimizer.

        Parameters:
            opt -- training/test options
        """
        BaseModel.__init__(self, opt)
        # Names the base class uses for saving/loading, display and logging.
        # NOTE(review): if the base class prefixes entries with 'loss_' when
        # reading losses, 'loss_G' would resolve to 'loss_loss_G' -- verify.
        self.model_names = ['G']
        self.visual_names = ['data_A', 'data_B', 'output']
        self.loss_names = ['loss_G']

        self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, gpu_ids=self.gpu_ids)
        if self.isTrain:
            # Loss and optimizer exist only in the training phase.
            self.criterionLoss = torch.nn.L1Loss()
            self.optimizer = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizers = [self.optimizer]

    def set_input(self, input):
        """Move one dataloader batch onto the model's device.

        Parameters:
            input -- dict holding the data and its metadata; the
                     ``direction`` option decides which side is the source
        """
        source, target = ('A', 'B') if self.opt.direction == 'AtoB' else ('B', 'A')
        self.data_A = input[source].to(self.device)
        self.data_B = input[target].to(self.device)
        self.image_paths = input[source + '_paths']

    def forward(self):
        """Run the generator on ``data_A`` and store the result in ``output``."""
        self.output = self.netG(self.data_A)

    def backward(self):
        """Compute the weighted L1 loss against ``data_B`` and backpropagate."""
        regression = self.criterionLoss(self.output, self.data_B)
        self.loss_G = regression * self.opt.lambda_regression
        self.loss_G.backward()

    def optimize_parameters(self):
        """Run one training step: forward, zero grads, backward, update."""
        self.forward()
        self.optimizer.zero_grad()
        self.backward()
        self.optimizer.step()

View File

@@ -0,0 +1,45 @@
from .base_model import BaseModel
from . import networks
class TestModel(BaseModel):
    """Inference-only model that runs a single generator in one direction.

    Sets ``--dataset_mode single`` by default, which loads images from one
    collection only.
    """

    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        """Force the single-image dataset and add ``--model_suffix``."""
        assert not is_train, 'TestModel cannot be used during training time'
        parser.set_defaults(dataset_mode='single')
        parser.add_argument('--model_suffix', type=str, default='', help='In checkpoints_dir, [epoch]_net_G[model_suffix].pth will be loaded as the generator.')
        return parser

    def __init__(self, opt):
        """Build the generator; only valid when ``opt.isTrain`` is false."""
        assert(not opt.isTrain)
        BaseModel.__init__(self, opt)
        suffix = opt.model_suffix
        self.loss_names = []                  # nothing to report at test time
        self.visual_names = ['real', 'fake']  # images exposed for saving/display
        self.model_names = ['G' + suffix]     # only the generator is needed
        self.netG = networks.define_G(
            opt.input_nc, opt.output_nc, opt.ngf, opt.netG,
            opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
        # Also expose the network as netG<suffix> so it can be found under
        # the suffixed model name.
        setattr(self, 'netG' + suffix, self.netG)

    def set_input(self, input):
        """Take the 'A' image and its paths from the batch dict."""
        self.real = input['A'].to(self.device)
        self.image_paths = input['A_paths']

    def forward(self):
        """Run the generator on ``real`` and store the result in ``fake``."""
        self.fake = self.netG(self.real)

    def optimize_parameters(self):
        """Do nothing: the test model is never optimised."""
        pass

View File

@@ -0,0 +1,68 @@
import torch
from .base_model import BaseModel
from . import networks
from util.image_pool import ImagePool
class TripletModel(BaseModel):
    """Embedding model trained with a triplet margin loss."""

    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        """Set triplet defaults and, for training, add ``--lambda_L1``."""
        parser.set_defaults(norm='batch', netG='triplet', dataset_mode='triplet')
        if is_train:
            parser.set_defaults(pool_size=0, gan_mode='vanilla')
            parser.add_argument('--lambda_L1', type=float, default=100.0, help='weight for L1 loss')
        return parser

    def __init__(self, opt):
        """Build the 1-in/1-out generator and, for training, losses and optimizer."""
        BaseModel.__init__(self, opt)
        self.loss_names = ['G_triplet']
        self.visual_names = ['x', 'y']
        # The same single network is listed in both train and test phases.
        self.model_names = ['G']
        self.netG = networks.define_G(
            1, 1, opt.ngf, opt.netG, opt.norm,
            not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
        if not self.isTrain:
            return

        # Buffers for previously generated images.
        self.fake_A_pool = ImagePool(opt.pool_size)
        self.fake_B_pool = ImagePool(opt.pool_size)
        self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)
        self.criterionL1 = torch.nn.L1Loss()
        self.triplet = torch.nn.TripletMarginLoss(margin=3.0)
        self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
        self.optimizers.append(self.optimizer_G)

    def set_input(self, input):
        """Move the A/B/C tensors of one batch onto the model's device."""
        first, second = ('A', 'B') if self.opt.direction == 'AtoB' else ('B', 'A')
        self.real_A = input[first].to(self.device)
        self.real_B = input[second].to(self.device)
        self.real_C = input['C'].to(self.device)
        self.image_paths = input[first + '_paths']

    def forward(self):
        """Run all three inputs through the network; store x, y and z."""
        self.x, self.y, self.z = self.netG(self.real_A, self.real_B, self.real_C)

    def backward_G(self):
        """Compute the triplet loss over (x, y, z) and backpropagate."""
        triplet_loss = self.triplet(self.x, self.y, self.z)
        self.loss_G_triplet_1 = triplet_loss
        self.loss_G_triplet = triplet_loss
        self.loss_G = triplet_loss
        triplet_loss.backward()

    def optimize_parameters(self):
        """Zero grads, backpropagate and step the optimizer.

        NOTE(review): ``forward()`` is not called here, so x/y/z must have
        been produced by an earlier call -- confirm against the training loop.
        """
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()

View File

@@ -0,0 +1,144 @@
import torch
from . import networks
from .base_model import BaseModel
from .perceptual import VGGPerceptualLoss
from ..util.image_pool import ImagePool
class UnpairedModel(BaseModel):
"""Reference-guided unpaired model with two generators and one discriminator.

``netG_A`` maps (real_A, real_B) to ``content_output``; ``netG_B`` maps that
output back to ``rec_output``. Training combines a GAN loss, an L1 loss on
``styletps`` outputs, and two VGG perceptual losses (reconstruction and
HED line maps).
"""
@staticmethod
def modify_commandline_options(parser, is_train=True):
"""Set this model's option defaults and, for training, add ``--lambda_L1``.

NOTE(review): ``lambda_L1`` is not read anywhere in this class.
"""
parser.set_defaults(norm='batch', netG='ref_unpair_cbam_cat', netG2='ref_unpair_recon', dataset_mode='unaligned')
if is_train:
parser.set_defaults(pool_size=0, gan_mode='vanilla')
parser.add_argument('--lambda_L1', type=float, default=100.0, help='weight for L1 loss')
return parser
def __init__(self, opt):
"""Build the generators and, for training, discriminators, losses and optimizers.

Parameters:
opt -- training/test options
"""
BaseModel.__init__(self, opt)
# specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
self.loss_names = ['G_GAN', 'G_L1_1', 'G_Rec', 'G_line', 'D_real', 'D_fake']
self.visual_names = ['real_A', 'content_output', 'real_B']
if self.isTrain:
self.model_names = ['G_A', 'G_B', 'D']
else: # during test time, only load G
self.model_names = ['G_A', 'G_B']
# define networks (both generators; opt.netG for G_A, opt.netG2 for G_B)
self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm,
not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
self.netG_B = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG2, opt.norm,
not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
if self.isTrain: # define the discriminator used by backward_D/backward_G; it is built for 1 input channel
self.netD = networks.define_D(1, opt.ndf, opt.netD,
opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
# NOTE(review): styletps and HED are only used in backward_G; confirm whether
# these two constructions are meant to be guarded by isTrain. Both load
# weights from paths relative to the current working directory.
self.styletps = networks.define_styletps(init_weights_='./checkpoints/contrastive_pretrained.pth', gpu_ids_=self.gpu_ids, shape=False)
self.HED = networks.define_HED(init_weights_='./checkpoints/network-bsds500.pytorch', gpu_ids_=self.gpu_ids)
if self.isTrain: # define discriminators
# NOTE(review): netD_A / netD_B are not referenced by any method of this
# class and are not listed in model_names.
self.netD_A = networks.define_D(opt.output_nc, opt.ndf, opt.netD,
opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
self.netD_B = networks.define_D(opt.input_nc, opt.ndf, opt.netD,
opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
if self.isTrain:
self.fake_A_pool = ImagePool(opt.pool_size) # create image buffer to store previously generated images
self.fake_B_pool = ImagePool(opt.pool_size) # create image buffer to store previously generated images
# define loss functions
# (only criterionGAN, criterionL1_1, per_loss_2 and per_loss_3 are used by backward_D/backward_G)
self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)
self.criterionL1_1 = torch.nn.L1Loss()
self.criterionL1_2 = torch.nn.L1Loss()
self.criterionL1_3 = torch.nn.L1Loss()
self.per_loss_1 = VGGPerceptualLoss().to(self.device)
self.per_loss_2 = VGGPerceptualLoss().to(self.device)
self.per_loss_3 = VGGPerceptualLoss().to(self.device)
# one Adam optimizer per trained network
self.optimizer_GA = torch.optim.Adam(self.netG_A.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
self.optimizer_GB = torch.optim.Adam(self.netG_B.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
self.optimizers.append(self.optimizer_GA)
self.optimizers.append(self.optimizer_GB)
self.optimizers.append(self.optimizer_D)
def set_input(self, input):
"""Unpack input data from the dataloader and perform necessary pre-processing steps.
Parameters:
input (dict): include the data itself and its metadata information.
The option 'direction' can be used to swap images in domain A and domain B.
"""
AtoB = self.opt.direction == 'AtoB'
self.real_A = input['A' if AtoB else 'B'].to(self.device)
self.real_B = input['B' if AtoB else 'A'].to(self.device)
# image paths are intentionally not stored by this model:
# self.image_paths = input['A_paths' if AtoB else 'B_paths']
def forward(self):
"""Run forward pass; called by both functions <optimize_parameters> and <test>."""
# G_A turns real_A into the output, using real_B as the second input
self.content_output = self.netG_A(self.real_A, self.real_B)
# G_B receives the generated output as both of its inputs
self.rec_output = self.netG_B(self.content_output, self.content_output)
def update_process(self, epoch, total_epoch):
"""Record the current and total epoch counts read by backward_G."""
self.epoch_count = epoch
self.epoch_count_total = total_epoch
def backward_D(self):
"""Calculate the GAN loss for the discriminator and backpropagate it.

Real samples are ``real_B``; fake samples are the detached
``content_output``. Returns the combined discriminator loss.
"""
# Real
pred_real = self.netD(self.real_B)
self.loss_D_real = self.criterionGAN(pred_real, True)
# Fake (detached so no gradient reaches the generators)
pred_fake = self.netD(self.content_output.detach())
self.loss_D_fake = self.criterionGAN(pred_fake, False)
# Combined loss and calculate gradients
loss_D = (self.loss_D_real + self.loss_D_fake) * 0.5
loss_D.backward()
return loss_D
def backward_G(self):
"""Calculate the generator losses (GAN, L1, reconstruction, line) and backpropagate.

Requires ``update_process`` to have set ``epoch_count`` and
``epoch_count_total`` beforehand.
"""
pred_fake = self.netD(self.content_output)
self.loss_G_GAN = self.criterionGAN(pred_fake, True)
# NOTE(review): despite its name, content_output_line is the HED output of real_A
self.content_output_line = self.HED(self.real_A)
self.rec_output_line = self.HED(self.rec_output)
self.t1, self.t2, _ = self.styletps(self.content_output, self.real_B, self.real_B)
# weight for the two perceptual terms: 5 at epoch 0, decreasing linearly to 0.5 at the final epoch
decay_lambda = 5 - ((self.epoch_count * 4.5) / self.epoch_count_total)
self.loss_G_L1_1 = self.criterionL1_1(self.t1, self.t2) * 10
self.loss_G_Rec = self.per_loss_2(self.real_A, self.rec_output) * decay_lambda
self.loss_G_line = self.per_loss_3(self.content_output_line, self.rec_output_line) * decay_lambda
self.loss_G = self.loss_G_GAN + self.loss_G_L1_1 + self.loss_G_Rec + self.loss_G_line
self.loss_G.backward()
def optimize_parameters(self):
"""Run one training step: forward, update D, then update both generators."""
self.forward() # compute content_output and rec_output
# update D
self.set_requires_grad(self.netD, True) # enable backprop for D
self.optimizer_D.zero_grad() # set D's gradients to zero
self.backward_D() # calculate gradients for D
self.optimizer_D.step() # update D's weights
# update G
self.set_requires_grad(self.netD, False) # D requires no gradients when optimizing G
self.optimizer_GA.zero_grad() # set G_A's gradients to zero
self.optimizer_GB.zero_grad() # set G_B's gradients to zero
self.backward_G() # calculate gradients for G_A and G_B
self.optimizer_GA.step() # update G_A's weights
self.optimizer_GB.step() # update G_B's weights

View File

@@ -0,0 +1,57 @@
from app.core.config import DEBUG
class Config:
    """Fixed option set for image-to-sketch inference.

    Stands in for parsed command-line options: every attribute is a plain
    value, and asset paths are chosen by the ``DEBUG`` flag.
    """

    def __init__(self):
        """Populate all options with their fixed values."""
        # --- basic parameters ---
        self.dataroot = "app/service/image2sketch/datasets/ref_unpair"
        self.name = 'semi_unpair'
        self.gpu_ids = [0]

        # --- model parameters ---
        self.model = 'unpaired'
        self.input_nc = 3
        self.output_nc = 3
        self.ngf = 64
        self.ndf = 64
        self.netD = 'basic'
        self.netG = 'ref_unpair_cbam_cat'
        self.netG2 = 'ref_unpair_recon'
        self.n_layers_D = 3
        self.norm = 'instance'
        self.init_type = 'normal'
        self.init_gain = 0.02
        self.no_dropout = False  # corresponds to `--no_dropout`

        # --- dataset parameters ---
        self.dataset_mode = 'single'
        self.direction = 'AtoB'
        self.serial_batches = True  # corresponds to `--serial_batches`
        self.num_threads = 4
        self.batch_size = 4
        self.load_size = 512
        self.crop_size = 512
        self.max_dataset_size = float("inf")
        self.preprocess = 'resize_and_crop'
        self.no_flip = False  # corresponds to `--no_flip`
        self.display_winsize = 256

        # --- additional parameters ---
        self.epoch = '100'
        self.load_iter = 0
        self.verbose = False  # corresponds to `--verbose`
        self.suffix = ''
        self.isTrain = False
        self.results_dir = 'service/image2sketch/results'
        self.aspect_ratio = 1.0
        self.phase = 'test'
        self.eval = False
        self.num_test = 1000
        # NOTE(review): 'morm' looks like a typo of 'norm' (already set to
        # 'instance' above); left unchanged because it is a runtime name.
        self.morm = 'batch'

        # Asset paths: without the 'app/' prefix when DEBUG is on, with it
        # otherwise.
        if DEBUG:
            self.style_image1 = "service/image2sketch/datasets/ref_unpair/testC/style_1.jpg"
            self.style_image2 = "service/image2sketch/datasets/ref_unpair/testC/style_2.jpeg"
            self.style_image3 = "service/image2sketch/datasets/ref_unpair/testC/style_3.png"
            self.checkpoints_dir = 'service/image2sketch/checkpoints/'
        else:
            self.checkpoints_dir = 'app/service/image2sketch/checkpoints/'
            self.style_image1 = "app/service/image2sketch/datasets/ref_unpair/testC/style_1.jpg"
            self.style_image2 = "app/service/image2sketch/datasets/ref_unpair/testC/style_2.jpeg"
            self.style_image3 = "app/service/image2sketch/datasets/ref_unpair/testC/style_3.png"

View File

@@ -0,0 +1,88 @@
import logging
import cv2
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
from app.schemas.image2sketch import Image2SketchModel
from app.service.image2sketch.infer import tensor2im
from app.service.image2sketch.models import create_model
from app.service.image2sketch.opt import Config
from app.service.utils.oss_client import oss_get_image, oss_upload_image
logger = logging.getLogger()
def tensor2im(input_image, imtype=np.uint8):
    """Convert the first image of a tensor batch to an (H, W, C) array.

    Parameters:
        input_image -- torch tensor batch, numpy array, or anything else
        imtype      -- dtype of the returned array

    A tensor is taken at batch index 0, expanded to three channels when it
    has one, and mapped with (v + 1) / 2 * 255. A numpy array is only cast
    to ``imtype``. Any other object is returned unchanged.
    """
    if isinstance(input_image, np.ndarray):
        # Already an array: only the dtype is changed.
        return input_image.astype(imtype)
    if not isinstance(input_image, torch.Tensor):
        return input_image

    chw = input_image.data[0].cpu().float().numpy()
    if chw.shape[0] == 1:
        # Grayscale to RGB by repeating the single channel.
        chw = np.tile(chw, (3, 1, 1))
    # Channels last, then rescale.
    hwc = (np.transpose(chw, (1, 2, 0)) + 1) / 2.0 * 255.0
    return hwc.astype(imtype)
class Image2SketchServer:
    """Run the image-to-sketch model for a single request.

    The constructor builds the model and prepares both input tensors;
    ``get_result`` runs inference and uploads the encoded sketch.
    """

    def __init__(self, request_data):
        """Prepare model and inputs.

        Parameters:
            request_data -- request object; the attributes read here are
                            ``image_url``, ``style_image_url``,
                            ``sketch_bucket``, ``sketch_name`` and
                            ``default_style``
        """
        self.image_url = request_data.image_url
        self.style_image_url = request_data.style_image_url
        self.sketch_bucket = request_data.sketch_bucket
        self.sketch_name = request_data.sketch_name

        self.opt = Config()
        self.opt.num_threads = 0  # test code only supports num_threads = 0
        self.opt.batch_size = 1  # test code only supports batch_size = 1
        self.opt.serial_batches = True  # disable data shuffling
        self.opt.no_flip = True  # no flip augmentation
        self.opt.display_id = -1  # no visdom display

        self.data = {}
        device = torch.device("cuda:0")
        self.model = create_model(self.opt)
        self.model.setup(self.opt)

        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ])

        # "1"/"2"/"3" pick one of the style images configured on the options
        # object; any other value loads the style image from object storage.
        style_key = request_data.default_style
        if style_key in ("1", "2", "3"):
            style_source = Image.open(getattr(self.opt, "style_image" + style_key))
        else:
            style_url = self.style_image_url
            style_source = oss_get_image(bucket=style_url.split('/')[0], object_name=style_url[style_url.find('/') + 1:], data_type="PIL")
        style_gray = style_source.convert('L')
        self.data['B'] = transform(style_gray).unsqueeze(0).to(device)

        content, self.width, self.height = self.get_image(self.image_url)
        self.data['A'] = transform(content).unsqueeze(0).to(device)

    def get_result(self):
        """Run inference, upload the JPEG-encoded result and return "<bucket>/<object>"."""
        self.model.set_input(self.data)
        self.model.test()  # run inference
        visuals = self.model.get_current_visuals()  # get image results
        sketch = tensor2im(visuals['content_output'].cpu())
        encoded = cv2.imencode(".jpg", sketch)[1].tobytes()
        req = oss_upload_image(bucket=self.sketch_bucket, object_name=self.sketch_name, image_bytes=encoded)
        return f"{req.bucket_name}/{req.object_name}"

    def get_image(self, image_url):
        """Fetch ``image_url`` ("<bucket>/<object key>") and return (RGB image, width, height)."""
        bucket = image_url.split('/')[0]
        object_name = image_url[image_url.find('/') + 1:]
        rgb = oss_get_image(bucket=bucket, object_name=object_name, data_type="PIL").convert('RGB')
        return rgb, rgb.size[0], rgb.size[1]
if __name__ == '__main__':
    # Manual smoke test: run one conversion and print the uploaded object path.
    demo_request = Image2SketchModel(
        image_url="test/real_Dress_790b2c6e370644e134df7abdfe7e54d9.jpg_Img.jpg",
        sketch_bucket="test",
        sketch_name="test123.jpg",
    )
    print(Image2SketchServer(demo_request).get_result())

View File

@@ -0,0 +1 @@
"""This package includes a miscellaneous collection of useful helper functions."""

View File

@@ -0,0 +1,110 @@
from __future__ import print_function
import os
import tarfile
import requests
from warnings import warn
from zipfile import ZipFile
from bs4 import BeautifulSoup
from os.path import abspath, isdir, join, basename
class GetData(object):
    """A Python script for downloading CycleGAN or pix2pix datasets.

    Parameters:
        technique (str) -- One of: 'cyclegan' or 'pix2pix'.
        verbose (bool)  -- If True, print additional information.

    Examples:
        >>> from util.get_data import GetData
        >>> gd = GetData(technique='cyclegan')
        >>> new_data_path = gd.get(save_path='./datasets')  # options will be displayed.

    Alternatively, You can use bash scripts: 'scripts/download_pix2pix_model.sh'
    and 'scripts/download_cyclegan_model.sh'.
    """

    def __init__(self, technique='cyclegan', verbose=True):
        url_dict = {
            'pix2pix': 'http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/',
            'cyclegan': 'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets'
        }
        # An unknown technique yields ``None`` here (unchanged behaviour).
        self.url = url_dict.get(technique.lower())
        self._verbose = verbose

    def _print(self, text):
        """Print ``text`` only when verbose output was requested."""
        if self._verbose:
            print(text)

    @staticmethod
    def _get_options(r):
        """Return the link texts in response ``r`` that look like archives."""
        soup = BeautifulSoup(r.text, 'lxml')
        return [h.text for h in soup.find_all('a', href=True)
                if h.text.endswith(('.zip', 'tar.gz'))]

    def _present_options(self):
        """List the available archives and return the one the user picks."""
        r = requests.get(self.url)
        options = self._get_options(r)
        print('Options:\n')
        for i, o in enumerate(options):
            print("{0}: {1}".format(i, o))
        choice = input("\nPlease enter the number of the "
                       "dataset above you wish to download:")
        return options[int(choice)]

    def _download_data(self, dataset_url, save_path):
        """Download ``dataset_url`` into ``save_path`` and unpack it there.

        The downloaded archive is a temporary file: it is removed once
        unpacking finished or failed.

        Raises:
            ValueError -- the file is neither a ``.tar.gz`` nor a ``.zip``.
        """
        os.makedirs(save_path, exist_ok=True)

        base = basename(dataset_url)
        temp_save_path = join(save_path, base)

        # Fetch before opening the target so a failed request does not leave
        # an empty file behind.
        r = requests.get(dataset_url)
        with open(temp_save_path, "wb") as f:
            f.write(r.content)

        try:
            if base.endswith('.tar.gz'):
                obj = tarfile.open(temp_save_path)
            elif base.endswith('.zip'):
                obj = ZipFile(temp_save_path, 'r')
            else:
                raise ValueError("Unknown File Type: {0}.".format(base))

            # The context manager closes the archive even if extraction fails.
            # NOTE(review): extractall() trusts the member paths inside the
            # archive; only use this with archives from a trusted source.
            with obj:
                self._print("Unpacking Data...")
                obj.extractall(save_path)
        finally:
            os.remove(temp_save_path)

    def get(self, save_path, dataset=None):
        """Download a dataset.

        Parameters:
            save_path (str) -- A directory to save the data to.
            dataset (str)   -- (optional). A specific dataset to download.
                               Note: this must include the file extension.
                               If None, options will be presented for you
                               to choose from.

        Returns:
            save_path_full (str) -- the absolute path to the downloaded data.
        """
        selected_dataset = self._present_options() if dataset is None else dataset

        save_path_full = join(save_path, selected_dataset.split('.')[0])

        if isdir(save_path_full):
            warn("\n'{0}' already exists. Voiding Download.".format(
                save_path_full))
        else:
            self._print('Downloading Data...')
            url = "{0}/{1}".format(self.url, selected_dataset)
            self._download_data(url, save_path=save_path)

        return abspath(save_path_full)

View File

@@ -0,0 +1,86 @@
import dominate
from dominate.tags import meta, h3, table, tr, td, p, a, img, br
import os
class HTML:
    """This HTML class allows us to save images and write texts into a single HTML file.

    It consists of functions such as <add_header> (add a text header to the HTML file),
    <add_images> (add a row of images to the HTML file), and <save> (save the HTML to the disk).
    It is based on Python library 'dominate', a Python library for creating and manipulating HTML documents using a DOM API.
    """

    def __init__(self, web_dir, title, refresh=0):
        """Initialize the HTML classes

        Parameters:
            web_dir (str) -- a directory that stores the webpage. HTML file will be created at <web_dir>/index.html; images will be saved at <web_dir>/images/
            title (str)   -- the webpage name
            refresh (int) -- how often the website refresh itself; if 0; no refreshing
        """
        self.title = title
        self.web_dir = web_dir
        self.img_dir = os.path.join(self.web_dir, 'images')
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(self.web_dir, exist_ok=True)
        os.makedirs(self.img_dir, exist_ok=True)

        self.doc = dominate.document(title=title)
        if refresh > 0:
            with self.doc.head:
                meta(http_equiv="refresh", content=str(refresh))

    def get_image_dir(self):
        """Return the directory that stores images"""
        return self.img_dir

    def add_header(self, text):
        """Insert a header to the HTML file

        Parameters:
            text (str) -- the header text
        """
        with self.doc:
            h3(text)

    def add_images(self, ims, txts, links, width=400):
        """add images to the HTML file

        Parameters:
            ims (str list)   -- a list of image paths
            txts (str list)  -- a list of image names shown on the website
            links (str list) -- a list of hyperref links; when you click an image, it will redirect you to a new page
            width (int)      -- displayed image width in pixels
        """
        self.t = table(border=1, style="table-layout: fixed;")  # Insert a table
        self.doc.add(self.t)
        with self.t:
            with tr():
                for im, txt, link in zip(ims, txts, links):
                    with td(style="word-wrap: break-word;", halign="center", valign="top"):
                        with p():
                            with a(href=os.path.join('images', link)):
                                img(style="width:%dpx" % width, src=os.path.join('images', im))
                            br()
                            p(txt)

    def save(self):
        """save the current content to the HTML file"""
        html_file = '%s/index.html' % self.web_dir
        # Context manager: the handle is closed even if render()/write() raises.
        with open(html_file, 'wt') as f:
            f.write(self.doc.render())
if __name__ == '__main__':  # we show an example usage here.
    html = HTML('web/', 'test_html')
    html.add_header('hello world')

    # Four placeholder entries; each image links to itself.
    ims = ['image_%d.png' % n for n in range(4)]
    txts = ['text_%d' % n for n in range(4)]
    links = ['image_%d.png' % n for n in range(4)]
    html.add_images(ims, txts, links)
    html.save()

View File

@@ -0,0 +1,54 @@
import random
import torch
class ImagePool():
    """Buffer of previously generated images.

    The buffer lets discriminators be updated with a history of generated
    images instead of only the ones produced by the latest generators.
    """

    def __init__(self, pool_size):
        """Create the pool.

        Parameters:
            pool_size (int) -- the size of image buffer, if pool_size=0, no buffer will be created
        """
        self.pool_size = pool_size
        if self.pool_size > 0:  # create an empty pool
            self.num_imgs = 0
            self.images = []

    def query(self, images):
        """Return images from the pool.

        Parameters:
            images: the latest generated images from the generator

        With ``pool_size == 0`` the input is returned as is. While the buffer
        is not full every incoming image is stored and returned. Once it is
        full, each incoming image is either returned directly, or swapped with
        a randomly chosen stored image which is returned instead (50% each).
        """
        if self.pool_size == 0:
            return images

        selected = []
        for image in images:
            current = torch.unsqueeze(image.data, 0)
            if self.num_imgs < self.pool_size:
                # Buffer not full yet: keep the image and hand it back.
                self.num_imgs += 1
                self.images.append(current)
                selected.append(current)
                continue

            if random.uniform(0, 1) > 0.5:
                # Return a stored image and put the current one in its slot.
                slot = random.randint(0, self.pool_size - 1)  # randint is inclusive
                selected.append(self.images[slot].clone())
                self.images[slot] = current
            else:
                selected.append(current)

        return torch.cat(selected, 0)

View File

@@ -0,0 +1,103 @@
"""This module contains simple helper functions """
from __future__ import print_function
import torch
import numpy as np
from PIL import Image
import os
def tensor2im(input_image, imtype=np.uint8):
    """Converts a Tensor array into a numpy image array.

    Parameters:
        input_image (tensor) -- the input image tensor array
        imtype (type)        -- the desired type of the converted numpy array

    A numpy input is only cast to ``imtype``; an input that is neither a
    numpy array nor a tensor is returned untouched.
    """
    if isinstance(input_image, np.ndarray):  # already a numpy array: cast only
        return input_image.astype(imtype)
    if not isinstance(input_image, torch.Tensor):
        return input_image

    image_tensor = input_image.data  # get the data from a variable
    first = image_tensor[0].cpu().float().numpy()  # first item as a numpy array
    if first.shape[0] == 1:  # grayscale to RGB
        first = np.tile(first, (3, 1, 1))
    # post-processing: transpose to channel-last and rescale -1..1 to 0..255
    channel_last = np.transpose(first, (1, 2, 0))
    return ((channel_last + 1) / 2.0 * 255.0).astype(imtype)
def diagnose_network(net, name='network'):
    """Calculate and print the mean of average absolute(gradients)

    Parameters:
        net (torch network) -- Torch network
        name (str)          -- the name of the network

    Parameters without a gradient are ignored; when no parameter has one,
    the printed value is 0.0.
    """
    grads = [param.grad for param in net.parameters() if param.grad is not None]

    mean = 0.0
    for grad in grads:
        mean += torch.mean(torch.abs(grad.data))
    if grads:
        mean = mean / len(grads)

    print(name)
    print(mean)
def save_image(image_numpy, image_path, aspect_ratio=1.0):
    """Save a numpy image to the disk

    Parameters:
        image_numpy (numpy array) -- input numpy array; the first two axes are
                                     taken as (height, width)
        image_path (str)          -- the path of the image
        aspect_ratio (float)      -- > 1.0 stretches the height by this factor,
                                     < 1.0 stretches the width by its inverse,
                                     1.0 saves the image unchanged
    """
    image_pil = Image.fromarray(image_numpy)
    # shape[:2] also accepts 2-D (grayscale) arrays, which a 3-way unpack rejects.
    h, w = image_numpy.shape[:2]

    # PIL's resize() takes (width, height). The previous code passed the array
    # height as the width, which transposed the dimensions of non-square
    # images; square images get the same result as before.
    if aspect_ratio > 1.0:
        image_pil = image_pil.resize((w, int(h * aspect_ratio)), Image.BICUBIC)
    if aspect_ratio < 1.0:
        image_pil = image_pil.resize((int(w / aspect_ratio), h), Image.BICUBIC)
    image_pil.save(image_path)
def print_numpy(x, val=True, shp=False):
    """Print the mean, min, max, median, std, and size of a numpy array

    Parameters:
        x (numpy array) -- the array to summarise (converted to float64 first)
        val (bool)      -- if print the values of the numpy array
        shp (bool)      -- if print the shape of the numpy array
    """
    arr = x.astype(np.float64)
    if shp:
        print('shape,', arr.shape)
    if not val:
        return

    flat = arr.flatten()
    stats = (np.mean(flat), np.min(flat), np.max(flat), np.median(flat), np.std(flat))
    print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % stats)
def mkdirs(paths):
    """create empty directories if they don't exist

    Parameters:
        paths (str list) -- a list of directory paths; any non-list value is
                            treated as one single path
    """
    is_path_list = isinstance(paths, list) and not isinstance(paths, str)
    targets = paths if is_path_list else [paths]
    for path in targets:
        mkdir(path)
def mkdir(path):
    """create a single empty directory if it didn't exist

    Parameters:
        path (str) -- a single directory path; missing parents are created too
    """
    if os.path.exists(path):
        return
    os.makedirs(path)

View File

@@ -0,0 +1,223 @@
import numpy as np
import os
import sys
import ntpath
import time
from . import util, html
from subprocess import Popen, PIPE
# Exception base class used in the visdom "except" clauses: the generic
# Exception on Python 2, ConnectionError otherwise.
VisdomExceptionBase = Exception if sys.version_info[0] == 2 else ConnectionError
def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256):
    """Save images to the disk.

    Parameters:
        webpage (the HTML class) -- the HTML webpage class that stores these images (see html.py for more details)
        visuals (OrderedDict)    -- an ordered dictionary that stores (name, images (either tensor or numpy) ) pairs
        image_path (str)         -- the string is used to create image paths; only its first element is used
        aspect_ratio (float)     -- the aspect ratio of saved images
        width (int)              -- display width passed on to the webpage

    This function will save images stored in 'visuals' to the HTML file specified by 'webpage'.
    """
    image_dir = webpage.get_image_dir()
    name = os.path.splitext(ntpath.basename(image_path[0]))[0]
    webpage.add_header(name)

    image_names, labels = [], []
    for label, im_data in visuals.items():
        image_name = '%s_%s.png' % (name, label)
        util.save_image(
            util.tensor2im(im_data),
            os.path.join(image_dir, image_name),
            aspect_ratio=aspect_ratio,
        )
        image_names.append(image_name)
        labels.append(label)

    # Every image links to itself, so the link list mirrors the image list.
    webpage.add_images(image_names, labels, list(image_names), width=width)
class Visualizer():
    """This class collects loss history and prints/saves logging information.

    The visdom display code that used to live in this class had been disabled
    by wrapping it in string literals; those dead blocks were removed. No
    visdom connection is made anywhere in this class.
    """

    def __init__(self, opt):
        """Initialize the Visualizer class

        Parameters:
            opt -- stores all the experiment flags; the attributes read here
                   are display_id, isTrain, no_html, display_winsize, name,
                   display_port and checkpoints_dir
        Step 1: Cache the training/test options
        Step 2: create the web/image directories when HTML output is enabled
        Step 3: create a logging file to store training losses
        """
        self.opt = opt  # cache the option
        self.display_id = opt.display_id
        self.use_html = opt.isTrain and not opt.no_html
        self.win_size = opt.display_winsize
        self.name = opt.name
        self.port = opt.display_port
        self.saved = False

        if self.use_html:  # images will be saved under <checkpoints_dir>/<name>/web/images/
            self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
            self.img_dir = os.path.join(self.web_dir, 'images')
            print('create web directory %s...' % self.web_dir)
            util.mkdirs([self.web_dir, self.img_dir])

        # create a logging file to store training losses
        self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
        # Without HTML output nothing above creates <checkpoints_dir>/<name>,
        # so make sure it exists before the log file is opened.
        os.makedirs(os.path.dirname(self.log_name), exist_ok=True)
        with open(self.log_name, "a") as log_file:
            now = time.strftime("%c")
            log_file.write('================ Training Loss (%s) ================\n' % now)

    def reset(self):
        """Reset the self.saved status"""
        self.saved = False

    def plot_current_losses(self, epoch, counter_ratio, losses):
        """Record the current losses in ``self.plot_data`` (nothing is drawn).

        Parameters:
            epoch (int)           -- current epoch
            counter_ratio (float) -- progress (percentage) in the current epoch, between 0 to 1
            losses (OrderedDict)  -- training losses stored in the format of (name, float) pairs
        """
        if not hasattr(self, 'plot_data'):
            # The legend is fixed by the keys of the first call.
            self.plot_data = {'X': [], 'Y': [], 'legend': list(losses.keys())}
        self.plot_data['X'].append(epoch + counter_ratio)
        self.plot_data['Y'].append([losses[k] for k in self.plot_data['legend']])

    # losses: same format as |losses| of plot_current_losses
    def print_current_losses(self, epoch, iters, losses, t_comp, t_data):
        """print current losses on console; also save the losses to the disk

        Parameters:
            epoch (int)          -- current epoch
            iters (int)          -- current training iteration during this epoch (reset to 0 at the end of every epoch)
            losses (OrderedDict) -- training losses stored in the format of (name, float) pairs
            t_comp (float)       -- computational time per data point (normalized by batch_size)
            t_data (float)       -- data loading time per data point (normalized by batch_size)
        """
        message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, iters, t_comp, t_data)
        for k, v in losses.items():
            message += '%s: %.3f ' % (k, v)

        print(message)  # print the message
        with open(self.log_name, "a") as log_file:
            log_file.write('%s\n' % message)  # save the message

View File

@@ -0,0 +1,45 @@
import os
from minio import Minio
from minio.error import S3Error
# SECURITY(review): the endpoint and credentials below were committed in plain
# text. They are kept as defaults so existing runs behave the same, but each
# value can now be overridden through an environment variable of the same
# name. The committed key pair should be rotated and the defaults removed.
MINIO_URL = os.environ.get("MINIO_URL", "www.minio.aida.com.hk:12024")
MINIO_ACCESS = os.environ.get("MINIO_ACCESS", 'vXKFLSJkYeEq2DrSZvkB')
MINIO_SECRET = os.environ.get("MINIO_SECRET", 'uKTZT3x7C43WvPN9QTc99DiRkwddWZrG9Uh3JVlR')
# Defaults to True; "1", "true" or "yes" (any case) keep TLS enabled.
MINIO_SECURE = os.environ.get("MINIO_SECURE", "true").strip().lower() in ("1", "true", "yes")
# Configure the MinIO client
minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
# Download helper
def download_folder(bucket_name, folder_name, local_dir):
    """Download every object under a prefix into a local directory.

    Parameters:
        bucket_name (str) -- bucket to read from
        folder_name (str) -- object-name prefix; it is stripped from each
                             object name to build the local relative path
        local_dir (str)   -- local directory that receives the files

    S3 errors are printed, not raised.
    """
    try:
        # Make sure the local target directory exists
        os.makedirs(local_dir, exist_ok=True)

        # Walk all objects below the prefix
        objects = minio_client.list_objects(bucket_name, prefix=folder_name, recursive=True)
        for obj in objects:
            # Strip the prefix. A leading "/" (prefix given without a trailing
            # slash) must be removed as well: os.path.join() discards
            # local_dir when the second argument is absolute, which would
            # write the file relative to the filesystem root.
            relative_name = obj.object_name[len(folder_name):].lstrip('/')
            if not relative_name:
                # The object name equals the prefix itself; there is no file
                # name left to write to.
                continue

            local_file_path = os.path.join(local_dir, relative_name)
            # Make sure the parent directory of the file exists
            os.makedirs(os.path.dirname(local_file_path), exist_ok=True)

            # Download the object
            minio_client.fget_object(bucket_name, obj.object_name, local_file_path)
            print(f"Downloaded {obj.object_name} to {local_file_path}")
    except S3Error as e:
        print(f"Error occurred: {e}")
# Example usage (runs whenever this module is executed or imported)
bucket_name = "test"  # replace with your bucket name
folder_name = "checkpoints/lineart/"  # prefix of the weights folder
local_dir = "app/service/image2sketch_2"  # replace with the local target directory
download_folder(
    bucket_name=bucket_name,
    folder_name=folder_name,
    local_dir=local_dir,
)

View File

@@ -0,0 +1,142 @@
import cv2
import numpy
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
from app.service.utils.oss_client import oss_get_image, oss_upload_image
# Normalisation layer class used by every block of the generator.
norm_layer = nn.InstanceNorm2d
# (original weight, eroded weight) pairs, indexed by the thickness level.
weights = [
    (0.7, 0.3),
    (0.5, 0.5),
    (0.3, 0.7),
    (0.1, 0.9),
    (0, 1),
]
# 3x3 structuring element for cv2.erode.
kernel = np.ones((3, 3), dtype=np.uint8)
class ResidualBlock(nn.Module):
    """Two padded 3x3 convolutions whose output is added to the block input."""

    def __init__(self, in_features):
        """Build the block; the channel count ``in_features`` is kept unchanged."""
        super().__init__()
        self.conv_block = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            norm_layer(in_features),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            norm_layer(in_features),
        )

    def forward(self, x):
        """Return ``x`` plus the output of the convolution stack."""
        residual = self.conv_block(x)
        return x + residual
class Generator(nn.Module):
    """Encoder / residual / decoder generator.

    The sub-module names (``model0`` .. ``model4``) and the order of the
    layers inside each of them are part of the checkpoint layout and must
    not change.
    """

    def __init__(self, input_nc, output_nc, n_residual_blocks=9, sigmoid=True):
        """Build the network.

        Parameters:
            input_nc (int)          -- number of input channels
            output_nc (int)         -- number of output channels
            n_residual_blocks (int) -- number of residual blocks in the middle
            sigmoid (bool)          -- append a Sigmoid to the output layer
        """
        super().__init__()

        # Initial convolution block
        self.model0 = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, 64, 7),
            norm_layer(64),
            nn.ReLU(inplace=True),
        )

        # Downsampling: two stride-2 convolutions, doubling the channels each time
        channels = 64
        down = []
        for _ in range(2):
            doubled = channels * 2
            down.extend([
                nn.Conv2d(channels, doubled, 3, stride=2, padding=1),
                norm_layer(doubled),
                nn.ReLU(inplace=True),
            ])
            channels = doubled
        self.model1 = nn.Sequential(*down)

        # Residual blocks
        self.model2 = nn.Sequential(*[ResidualBlock(channels) for _ in range(n_residual_blocks)])

        # Upsampling: two stride-2 transposed convolutions, halving the channels each time
        up = []
        for _ in range(2):
            halved = channels // 2
            up.extend([
                nn.ConvTranspose2d(channels, halved, 3, stride=2, padding=1, output_padding=1),
                norm_layer(halved),
                nn.ReLU(inplace=True),
            ])
            channels = halved
        self.model3 = nn.Sequential(*up)

        # Output layer
        tail = [nn.ReflectionPad2d(3), nn.Conv2d(64, output_nc, 7)]
        if sigmoid:
            tail.append(nn.Sigmoid())
        self.model4 = nn.Sequential(*tail)

    def forward(self, x, cond=None):
        """Run ``x`` through the five stages in order; ``cond`` is not used."""
        out = x
        for stage in (self.model0, self.model1, self.model2, self.model3, self.model4):
            out = stage(out)
        return out
# Module-level generator (3 input channels, 1 output channel, 3 residual
# blocks). Weights are loaded onto the CPU once, at import time.
# NOTE(review): torch.load unpickles the checkpoint; only load trusted files.
model1 = Generator(input_nc=3, output_nc=1, n_residual_blocks=3)
model1.load_state_dict(
    torch.load(
        'app/service/image2sketch_2/model.pth',
        map_location=torch.device('cpu'),
    )
)
model1.eval()
def predict(input_img, width):
    """Run the module-level generator on one image.

    Parameters:
        input_img -- image accepted by the torchvision transforms
                     (presumably a PIL image; confirm against the caller)
        width     -- size argument passed to ``transforms.Resize``

    Returns:
        the generated drawing as a numpy array
    """
    preprocess = transforms.Compose([
        transforms.Resize(width, Image.BICUBIC),
        transforms.ToTensor(),
    ])
    batch = torch.unsqueeze(preprocess(input_img), 0)
    with torch.no_grad():
        drawing = model1(batch)[0].detach()
    # Convert to a PIL image first, then to an ndarray
    return numpy.array(transforms.ToPILImage()(drawing))
def get_image(image_url):
    """Fetch ``image_url`` ("<bucket>/<object key>") from object storage.

    Returns:
        (image, width, height) -- the image converted to RGB and its size
    """
    bucket = image_url.split('/')[0]
    object_name = image_url[image_url.find('/') + 1:]
    rgb = oss_get_image(bucket=bucket, object_name=object_name, data_type="PIL").convert('RGB')
    return rgb, rgb.size[0], rgb.size[1]
def processing_pipeline(image_url, thickness, sketch_bucket, sketch_name):
    """Extract a sketch from an image, optionally thicken it, and upload it.

    Parameters:
        image_url     -- "<bucket>/<object key>" of the source image
        thickness     -- value convertible to int; 0 skips the thickening
                         step, any other value indexes the module-level
                         ``weights`` table
        sketch_bucket -- bucket that receives the result
        sketch_name   -- object name of the result

    Returns:
        "<bucket>/<object name>" of the uploaded JPEG
    """
    level = int(thickness)

    # Extract the sketch
    image, width, _ = get_image(image_url)
    sketch_image = predict(image, width)

    # Set the line thickness
    if level != 0:
        eroded = cv2.erode(sketch_image, kernel, iterations=1)
        # Blend the original with the eroded image using the weights of this level
        base_weight, eroded_weight = weights[level]
        sketch_image = cv2.addWeighted(sketch_image, base_weight, eroded, eroded_weight, 0)

    # Upload to MinIO
    encoded = cv2.imencode(".jpg", sketch_image)[1].tobytes()
    req = oss_upload_image(bucket=sketch_bucket, object_name=sketch_name, image_bytes=encoded)
    return f"{req.bucket_name}/{req.object_name}"
if __name__ == '__main__':
    # Manual smoke test: thickness level 1, result uploaded to test/test123.jpg.
    print(processing_pipeline(
        "aida-users/89/relight_image/d5f0d967-f8e8-424d-98f9-a8ad8313deec-0-89.png",
        1,
        "test",
        "test123.jpg",
    ))

View File

@@ -0,0 +1,99 @@
import logging
import cv2
import mmcv
import numpy as np
import torch
import torch.nn.functional as F
import tritonclient.http as httpclient
from app.core.config import DESIGN_MODEL_URL
from app.schemas.image2sketch import Image2SketchModel
from app.service.utils.oss_client import oss_get_image, oss_upload_image
logger = logging.getLogger()
class LineArtService:
    """Produce a line-art image through the remote "lineart" inference model."""

    def __init__(self, request_item):
        """Read the request.

        Parameters:
            request_item -- request object; the attributes read here are
                            ``default_style`` (converted to int),
                            ``image_url``, ``sketch_bucket`` and
                            ``sketch_name``
        """
        self.line_style = int(request_item.default_style)
        self.image_url = request_item.image_url
        self.sketch_bucket = request_item.sketch_bucket
        self.sketch_name = request_item.sketch_name
        # (original weight, eroded weight) pairs, indexed by line_style.
        self.weights = [(0.7, 0.3), (0.5, 0.5), (0.3, 0.7), (0.1, 0.9), (0, 1)]

    def get_result(self):
        """Run inference, apply the line style, upload; returns what ``put_image`` returns."""
        client = httpclient.InferenceServerClient(url=DESIGN_MODEL_URL)
        source = self.get_image()
        batch, ori_shape = self.line_art_preprocess(source)
        batch = batch.astype(np.float32)

        infer_input = httpclient.InferInput("input__0", batch.shape, datatype="FP32")
        infer_input.set_data_from_numpy(batch, binary_data=True)
        requested = [httpclient.InferRequestedOutput("output__0", binary_data=True)]
        response = client.infer(model_name="lineart", inputs=[infer_input], outputs=requested)
        raw_output = response.as_numpy("output__0")

        line_art = self.line_art_postprocess(raw_output, ori_shape)
        line_art = (line_art[0] * 255.0).round().astype(np.uint8)

        if self.line_style != 0:
            logger.info(self.line_style)
            erode_kernel = np.ones((3, 3), np.uint8)
            eroded = cv2.erode(line_art, erode_kernel, iterations=1)
            # Blend the original with the eroded image using the weights of this style
            base_weight, eroded_weight = self.weights[self.line_style]
            line_art = cv2.addWeighted(line_art, base_weight, eroded, eroded_weight, 0)
        return self.put_image(line_art)

    def get_image(self):
        """Fetch the source image and make sure it has three colour channels."""
        bucket = self.image_url.split('/')[0]
        object_name = self.image_url[self.image_url.find('/') + 1:]
        image = oss_get_image(bucket=bucket, object_name=object_name, data_type="cv2")
        # Convert to a colour image
        if len(image.shape) == 3 and image.shape[2] == 4:
            return cv2.cvtColor(image, cv2.COLOR_BGRA2BGR)
        if len(image.shape) == 2:
            return cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
        return image

    def put_image(self, image):
        """Upload ``image`` as "<sketch_name>.jpg".

        Returns "<bucket>/<sketch_name>.jpg"; on any exception the error is
        logged as a warning and ``None`` is returned.
        """
        try:
            encoded = cv2.imencode('.jpg', image)[1].tobytes()
            oss_upload_image(bucket=self.sketch_bucket, object_name=f"{self.sketch_name}.jpg", image_bytes=encoded)
            return f"{self.sketch_bucket}/{self.sketch_name}.jpg"
        except Exception as e:
            logger.warning(e)

    @staticmethod
    def line_art_preprocess(image):
        """Normalise ``image`` for the model.

        Returns:
            (batch, ori_shape) -- a 1 x C x H x W array and the original
            (height, width) before any resizing
        """
        img = mmcv.imread(image)
        ori_shape = img.shape[:2]
        # Any side longer than 1024 is capped at 1024.
        target_h = min(ori_shape[0], 1024)
        target_w = min(ori_shape[1], 1024)
        if ori_shape != (target_h, target_w):
            # cv2.resize takes (width, height)
            img = cv2.resize(img, (target_w, target_h))
        img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
        batch = np.expand_dims(img.transpose(2, 0, 1), axis=0)
        return batch, ori_shape

    @staticmethod
    def line_art_postprocess(output, ori_shape):
        """Bilinearly resize the model output to ``ori_shape`` and drop the leading axis."""
        logits = torch.tensor(output).float()
        resized = F.interpolate(logits, size=ori_shape, scale_factor=None, mode='bilinear', align_corners=False)
        return resized.cpu().numpy()[0]
if __name__ == '__main__':
    # Manual smoke test against a sample sketchboard image with line style 4.
    request_item = Image2SketchModel(
        image_url="aida-collection-element/87/Sketchboard/555a443f-fd6b-4cd7-8147-b92d55513af0.png",
        default_style="4",
        sketch_bucket="test",
        sketch_name="test123",
    )
    print(LineArtService(request_item).get_result())

View File

@@ -1,5 +1,5 @@
import time
import logging
import time
def RunTime(func):
@@ -7,8 +7,22 @@ def RunTime(func):
t1 = time.time()
res = func(*args, **kwargs)
t2 = time.time()
if t2 - t1 > 0.05:
logging.info(f"function{func.__name__}】,runtime{str(t2 - t1)}】s")
# if t2 - t1 > 0.05:
# logging.info(f"function【{func.__name__}】,runtime{str(t2 - t1)}】s")
logging.info(f"function{func.__name__}】,runtime{str(t2 - t1)}】s")
return res
return wrapper
def ClassCallRunTime(func):
    """Decorator that prints how long a method call took.

    The printed label is the class name of the first positional argument
    (``self`` for methods). When the call has no positional argument the
    qualified name of the wrapped callable is used instead, so the timing
    line can no longer fail after the wrapped call already succeeded.

    Parameters:
        func -- the callable to wrap

    Returns:
        a wrapper that forwards all arguments and returns ``func``'s result
    """
    import functools  # local import: the file-level import block is not touched

    @functools.wraps(func)  # keep the wrapped callable's name and docstring
    def wrapper(*args, **kwargs):
        start_time = time.time()
        result = func(*args, **kwargs)
        end_time = time.time()
        execution_time = end_time - start_time
        if args:
            class_name = args[0].__class__.__name__  # class name of the receiver
        else:
            class_name = getattr(func, "__qualname__", repr(func))
        print(f"class name: {class_name} , run time is : {execution_time} s")
        return result

    return wrapper

View File

@@ -0,0 +1,94 @@
import io
import logging
from io import BytesIO
import cv2
import numpy as np
import urllib3
from PIL import Image
from minio import Minio
from app.core.config import *
from app.service.utils.decorator import RunTime
# Shared MinIO client; the MINIO_* settings are presumably provided by the
# star import of app.core.config above.
minio_client = Minio(
    MINIO_URL,
    access_key=MINIO_ACCESS,
    secret_key=MINIO_SECRET,
    secure=MINIO_SECURE,
)
# Custom Retry class
class CustomRetry(urllib3.Retry):
    """``urllib3.Retry`` that logs every retry attempt."""

    def increment(self, method=None, url=None, response=None, error=None, **kwargs):
        """Advance the retry state and log the attempt.

        Returns the new retry object produced by the parent implementation.
        Exceptions raised by the parent (e.g. when retries are exhausted)
        propagate before anything is logged.
        """
        # Delegate to the parent implementation
        new_retry = super().increment(method, url, response, error, **kwargs)
        # Every increment() returns a fresh Retry object, so the old
        # ``self.total - new_retry.total`` was always 1. ``history`` holds
        # one entry per attempt so far, which is the real retry count.
        attempt = len(new_retry.history)
        # Log the retry information
        logger.info(f"重试连接: {method} {url},错误: {error},重试次数: {attempt}")
        return new_retry
logger = logging.getLogger()

# Connect timeout of 1 second, read timeout of 10 seconds.
timeout = urllib3.Timeout(connect=1, read=10.0)

# Pooled HTTP client with certificate verification and logged retries.
# NOTE(review): in the code visible here this client is not passed to the
# Minio constructor — confirm whether it is meant to be.
http_client = urllib3.PoolManager(
    num_pools=10,  # connection pool size
    maxsize=10,
    timeout=timeout,
    cert_reqs='CERT_REQUIRED',  # certificate verification is required
    retries=CustomRetry(total=5, backoff_factor=0.2, status_forcelist=[500, 502, 503, 504]),
)
# Fetch an image
@RunTime
def oss_get_image(oss_client, bucket, object_name, data_type):
    """Download an object and decode it as an image.

    Parameters:
        oss_client  -- client object providing ``get_object``
        bucket      -- bucket name
        object_name -- object key inside the bucket
        data_type   -- "cv2" decodes with OpenCV keeping all channels
                       (16-bit data is scaled down to 8-bit); any other value
                       opens the data with PIL

    Returns:
        the decoded image, or ``None`` when fetching or decoding failed (the
        error is logged as a warning)
    """
    image_object = None
    image_data = None
    try:
        image_data = oss_client.get_object(bucket_name=bucket, object_name=object_name)
        image_bytes = image_data.read()
        if data_type == "cv2":
            image_array = np.frombuffer(image_bytes, np.uint8)  # view as unsigned 8-bit
            # cv2 reads all channels as stored
            image_object = cv2.imdecode(image_array, cv2.IMREAD_UNCHANGED)
            if image_object.dtype == np.uint16:
                image_object = (image_object / 256).astype('uint8')
        else:
            image_object = Image.open(BytesIO(image_bytes))
    except Exception as e:
        logger.warning(f"{OSS} | 获取图片出现异常 ######: {e}")
    finally:
        # The response holds a pooled connection; close it and hand the
        # connection back so repeated calls do not exhaust the pool.
        if image_data is not None:
            image_data.close()
            image_data.release_conn()
    return image_object
@RunTime
def oss_upload_image(oss_client, bucket, object_name, image_bytes, content_type='image/png'):
    """Upload encoded image bytes as an object.

    Parameters:
        oss_client   -- client object providing ``put_object``
        bucket       -- bucket name
        object_name  -- object key inside the bucket
        image_bytes  -- the encoded image
        content_type -- MIME type stored with the object; defaults to the
                        previously hard-coded 'image/png' so existing callers
                        are unaffected, while callers uploading other formats
                        can now label them correctly

    Returns:
        the result of ``put_object``, or ``None`` when the upload failed (the
        error is logged as a warning)
    """
    req = None
    try:
        req = oss_client.put_object(bucket_name=bucket, object_name=object_name, data=io.BytesIO(image_bytes), length=len(image_bytes), content_type=content_type)
    except Exception as e:
        logger.warning(f"{OSS} | 上传图片出现异常 ######: {e}")
    return req
if __name__ == '__main__':
    # Manual smoke test: download one object and display it.
    # Alternative sample keys kept for convenience:
    # url = "aida-results/result_0002186a-e631-11ee-86a6-b48351119060.png"
    # url = "aida-collection-element/11523/Moodboard/f60af0d2-94c2-48f9-90ff-74b8e8a481b5.jpg"
    # url = "aida-sys-image/images/female/outwear/0628000054.jpg"
    # url = "aida-users/89/product_image/string-89.png"
    # url = "test/845046c7-4f62-4f54-a4a9-c26d49c6969335b5b3a9-d335-4871-a46c-3cc3caf07da259629dfd1f1f555a2e2a9def7e719366.png"
    # url = 'aida-users/89/relight_image/123-89.png'
    # url = "aida-users/89/sketchboard/female/Dress/e6724ab7-8d3f-4677-abe0-c3e42ab7af85.jpeg"
    # url = "aida-users/87/print/956614a2-7e75-4fbe-9ed0-c1831e37a2c9-4-87.png"
    # url = "aida-users/89/single_logo/123-89.png"
    # url = "aida-collection-element/12148/Sketchboard/95ea577b-305b-4a62-b30a-39c0dd3ddb3f.png"
    url = "aida-users/31/sketchboard/female/dress/6edcbf92-7da9-4809-a0a8-a4b4f06dec1e0628000041.jpg"
    read_type = "cv2"

    # The first path segment is the bucket, the remainder is the object key.
    bucket_name, _, object_key = url.partition('/')
    img = oss_get_image(oss_client=minio_client, bucket=bucket_name, object_name=object_key, data_type=read_type)

    if read_type == "cv2":
        cv2.imshow("", img)
        cv2.waitKey(0)
    else:
        img.show()

View File

@@ -1,16 +1,38 @@
import io
import logging
from io import BytesIO
import boto3
import cv2
import numpy as np
import urllib3
from PIL import Image
from minio import Minio
from app.core.config import *
# Custom Retry class
class CustomRetry(urllib3.Retry):
    """urllib3 retry policy that logs every retry attempt."""

    def increment(self, method=None, url=None, response=None, error=None, **kwargs):
        """Register one retry, log it, and return the new retry object.

        Delegates to ``urllib3.Retry.increment``; any exception it raises
        (for example when the retries are exhausted) propagates unchanged.
        """
        new_retry = super().increment(method, url, response, error, **kwargs)
        # increment() lowers ``total`` by exactly one per call, so
        # ``self.total - new_retry.total`` would always log 1. The history
        # tuple gains one entry per attempt and gives the real count.
        logger.info(f"重试连接: {method} {url},错误: {error},重试次数: {len(new_retry.history)}")
        return new_retry
# Root logger, shared by the helpers in this module.
logger = logging.getLogger()

# Connect timeout 1 s, read timeout 10 s.
timeout = urllib3.Timeout(connect=1, read=10.0)

# Shared HTTP pool with logged retries on 5xx responses.
http_client = urllib3.PoolManager(
    num_pools=10,  # number of connection pools kept by the manager
    maxsize=10,
    timeout=timeout,
    cert_reqs='CERT_REQUIRED',  # require certificate verification
    retries=CustomRetry(
        total=5,
        backoff_factor=0.2,
        status_forcelist=[500, 502, 503, 504],
    ),
)
# 获取图片
@@ -18,12 +40,8 @@ def oss_get_image(bucket, object_name, data_type):
# cv2 默认全通道读取
image_object = None
try:
if OSS == "minio":
oss_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
image_data = oss_client.get_object(bucket_name=bucket, object_name=object_name)
else:
oss_client = boto3.client('s3', aws_access_key_id=S3_ACCESS_KEY, aws_secret_access_key=S3_AWS_SECRET_ACCESS_KEY, region_name=S3_REGION_NAME)
image_data = oss_client.get_object(Bucket=bucket, Key=object_name)['Body']
oss_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE, http_client=http_client)
image_data = oss_client.get_object(bucket_name=bucket, object_name=object_name)
if data_type == "cv2":
image_bytes = image_data.read()
image_array = np.frombuffer(image_bytes, np.uint8) # 转成8位无符号整型
@@ -41,12 +59,8 @@ def oss_get_image(bucket, object_name, data_type):
# Upload image_bytes to bucket/object_name with content type 'image/png'.
# Returns the put_object result, or None when an exception was caught and logged.
# NOTE(review): this span looks like a diff hunk whose +/- markers were lost. The
# `if OSS == "minio"` / boto3 branches and the unconditional Minio upload after them
# appear to be the old and new versions of the same code; read literally, the object
# would be uploaded twice. Confirm which version is intended.
# NOTE(review): the unconditional Minio client below is created without
# http_client, unlike the client built in the download hunk above — confirm
# whether the shared pool should be passed here as well.
def oss_upload_image(bucket, object_name, image_bytes):
req = None
try:
if OSS == "minio":
oss_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
req = oss_client.put_object(bucket_name=bucket, object_name=object_name, data=io.BytesIO(image_bytes), length=len(image_bytes), content_type='image/png')
else:
oss_client = boto3.client('s3', aws_access_key_id=S3_ACCESS_KEY, aws_secret_access_key=S3_AWS_SECRET_ACCESS_KEY, region_name=S3_REGION_NAME)
req = oss_client.put_object(Bucket=bucket, Key=object_name, Body=io.BytesIO(image_bytes), ContentType='image/png')
oss_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
req = oss_client.put_object(bucket_name=bucket, object_name=object_name, data=io.BytesIO(image_bytes), length=len(image_bytes), content_type='image/png')
except Exception as e:
logger.warning(f"{OSS} | 上传图片出现异常 ######: {e}")
return req
@@ -64,8 +78,8 @@ if __name__ == '__main__':
# url = "aida-users/89/sketchboard/female/Dress/e6724ab7-8d3f-4677-abe0-c3e42ab7af85.jpeg"
# url = "aida-users/87/print/956614a2-7e75-4fbe-9ed0-c1831e37a2c9-4-87.png"
# url = "aida-users/89/single_logo/123-89.png"
# url = "aida-users/89/product_image/string-89.png"
url = "aida-collection-element/12148/Sketchboard/95ea577b-305b-4a62-b30a-39c0dd3ddb3f.png"
url = "aida-results/result_e2673d92-8d25-11ef-be24-0826ae3ad6b3.png"
# url = "aida-collection-element/12148/Sketchboard/95ea577b-305b-4a62-b30a-39c0dd3ddb3f.png"
read_type = "cv2"
if read_type == "cv2":
img = oss_get_image(bucket=url.split('/')[0], object_name=url[url.find('/') + 1:], data_type=read_type)

Binary file not shown.