Merge branch 'refs/heads/develop'
# Conflicts: # app/service/design/service.py # app/service/design_pre_processing/service.py # app/service/utils/oss_client.py # requirements.txt
This commit is contained in:
6
.gitignore
vendored
6
.gitignore
vendored
@@ -120,10 +120,10 @@ dmypy.json
|
||||
|
||||
#runtime produce
|
||||
test
|
||||
seg_cache
|
||||
logs
|
||||
seg_result/
|
||||
seg_result
|
||||
*.png
|
||||
uwsgi
|
||||
*.yaml
|
||||
*.yml
|
||||
@@ -133,5 +133,7 @@ Dockerfile
|
||||
app/logs
|
||||
app/logs/*
|
||||
*.log
|
||||
*.jpg
|
||||
/qodana.yaml
|
||||
.pth
|
||||
.pytorch
|
||||
*.png
|
||||
59
app/api/api_brighten.py
Normal file
59
app/api/api_brighten.py
Normal file
@@ -0,0 +1,59 @@
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
|
||||
from PIL import ImageEnhance
|
||||
from fastapi import APIRouter, HTTPException
|
||||
|
||||
from app.schemas.brighten import BrightenModel
|
||||
from app.schemas.response_template import ResponseModel
|
||||
from app.service.utils.oss_client import oss_get_image, oss_upload_image
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
def increase_brightness(img, factor):
|
||||
enhancer = ImageEnhance.Brightness(img)
|
||||
bright_img = enhancer.enhance(factor)
|
||||
return bright_img
|
||||
|
||||
|
||||
@router.post("/brighten")
|
||||
async def brighten(request_item: BrightenModel):
|
||||
"""
|
||||
创建一个具有以下参数的请求体:
|
||||
- **image_url**: 提亮图片url
|
||||
- **brighten_value**: 提高亮度的比重 亮度因子 1.0 表示原始亮度,1.5 表示增加 50% 的亮度
|
||||
|
||||
示例参数:
|
||||
{
|
||||
"image_url": "aida-users/89/relight_image/3850e17b-3efd-4597-90ef-2a7bcd1a1a0b-0-89.png",
|
||||
"brighten_value": 1.5
|
||||
}
|
||||
"""
|
||||
try:
|
||||
start_time = time.time()
|
||||
logger.info(f"brighten request item is : @@@@@@:{json.dumps(request_item.dict())}")
|
||||
image = oss_get_image(bucket=request_item.image_url.split('/')[0], object_name=request_item.image_url[request_item.image_url.find('/') + 1:], data_type="PIL")
|
||||
new_image = increase_brightness(image, request_item.brighten_value)
|
||||
image_data = io.BytesIO()
|
||||
new_image.save(image_data, format='PNG')
|
||||
image_data.seek(0)
|
||||
image_bytes = image_data.read()
|
||||
req = oss_upload_image(bucket=request_item.image_url.split('/')[0], object_name=request_item.image_url[request_item.image_url.find('/') + 1:], image_bytes=image_bytes)
|
||||
brighten_url = f"{req.bucket_name}/{req.object_name}"
|
||||
logger.info(f"run time is : {time.time() - start_time}")
|
||||
except Exception as e:
|
||||
logger.warning(f"brighten Run Exception @@@@@@:{e}")
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
return ResponseModel(data=brighten_url)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
request_item = BrightenModel(image_url="aida-users/89/relight_image/3850e17b-3efd-4597-90ef-2a7bcd1a1a0b-0-89.png",
|
||||
brighten_value=1.5)
|
||||
image = oss_get_image(bucket=request_item.image_url.split('/')[0], object_name=request_item.image_url[request_item.image_url.find('/') + 1:], data_type="PIL")
|
||||
new_image = increase_brightness(image, request_item.brighten_value)
|
||||
new_image.show()
|
||||
@@ -1,13 +1,15 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi import APIRouter, HTTPException, UploadFile, File, Form
|
||||
|
||||
from app.schemas.design import DesignModel, DesignProgressModel, ModelProgressModel
|
||||
from app.schemas.design import DesignModel, DesignProgressModel, ModelProgressModel, DBGConfigModel
|
||||
from app.schemas.response_template import ResponseModel
|
||||
from app.service.design.model_process_service import model_transpose
|
||||
from app.service.design.service import generate
|
||||
from app.service.design.utils.redis_utils import Redis
|
||||
from app.service.design_batch.service import start_design_batch_generate
|
||||
from app.service.design_fast.design_generate import design_generate
|
||||
from app.service.design_fast.utils.redis_utils import Redis
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger()
|
||||
@@ -24,28 +26,28 @@ def design(request_data: DesignModel):
|
||||
"basic": {
|
||||
"body_point_test": {
|
||||
"waistband_right": [
|
||||
203,
|
||||
249
|
||||
200,
|
||||
241
|
||||
],
|
||||
"hand_point_right": [
|
||||
229,
|
||||
343
|
||||
223,
|
||||
297
|
||||
],
|
||||
"waistband_left": [
|
||||
119,
|
||||
248
|
||||
112,
|
||||
241
|
||||
],
|
||||
"hand_point_left": [
|
||||
97,
|
||||
343
|
||||
92,
|
||||
305
|
||||
],
|
||||
"shoulder_left": [
|
||||
108,
|
||||
107
|
||||
99,
|
||||
116
|
||||
],
|
||||
"shoulder_right": [
|
||||
212,
|
||||
107
|
||||
215,
|
||||
116
|
||||
]
|
||||
},
|
||||
"layer_order": true,
|
||||
@@ -57,65 +59,33 @@ def design(request_data: DesignModel):
|
||||
},
|
||||
"items": [
|
||||
{
|
||||
"businessId": 255303,
|
||||
"color": "139 148 156",
|
||||
"image_id": 95159,
|
||||
"businessId": 270372,
|
||||
"color": "30 28 28",
|
||||
"image_id": 69780,
|
||||
"offset": [
|
||||
0,
|
||||
0
|
||||
],
|
||||
"path": "aida-users/89/sketch/c89d75f3-581f-4edd-9f8e-b08e84a2cbe7-3-89.png",
|
||||
"path": "aida-sys-image/images/female/trousers/0825000630.jpg",
|
||||
"seg_mask_url": "test/result.png",
|
||||
"print": {
|
||||
"single": {
|
||||
"location": [
|
||||
[
|
||||
200.0,
|
||||
200.0
|
||||
]
|
||||
],
|
||||
"print_angle_list": [
|
||||
0.0
|
||||
],
|
||||
"print_path_list": [
|
||||
"aida-users/89/slogan_image/ce0b2423-9e5a-466f-9611-c254940a7819-1-89.png"
|
||||
],
|
||||
"print_scale_list": [
|
||||
1.0
|
||||
]
|
||||
"element": {
|
||||
"element_angle_list": [],
|
||||
"element_path_list": [],
|
||||
"element_scale_list": [],
|
||||
"location": []
|
||||
},
|
||||
"overall": {
|
||||
"location": [
|
||||
[
|
||||
512.0,
|
||||
512.0
|
||||
]
|
||||
],
|
||||
"print_angle_list": [
|
||||
0.0
|
||||
],
|
||||
"print_path_list": [
|
||||
"aida-users/89/print/468643b4-bc2d-41b2-9a16-79766606a2db-3-89.png"
|
||||
],
|
||||
"print_scale_list": [
|
||||
1.0
|
||||
]
|
||||
"location": [],
|
||||
"print_angle_list": [],
|
||||
"print_path_list": [],
|
||||
"print_scale_list": []
|
||||
},
|
||||
"element": {
|
||||
"element_angle_list": [
|
||||
0.0
|
||||
],
|
||||
"element_path_list": [
|
||||
"aida-users/88/designelements/Embroidery/a4d9605a-675e-4606-93e0-77ca6baaf55f.png"
|
||||
],
|
||||
"element_scale_list": [
|
||||
0.2731036750637755
|
||||
],
|
||||
"location": [
|
||||
[
|
||||
228.63694825464364,
|
||||
406.4843844199667
|
||||
]
|
||||
]
|
||||
"single": {
|
||||
"location": [],
|
||||
"print_angle_list": [],
|
||||
"print_path_list": [],
|
||||
"print_scale_list": []
|
||||
}
|
||||
},
|
||||
"priority": 10,
|
||||
@@ -123,22 +93,101 @@ def design(request_data: DesignModel):
|
||||
1.0,
|
||||
1.0
|
||||
],
|
||||
"type": "Dress"
|
||||
"type": "Trousers"
|
||||
},
|
||||
{
|
||||
"body_path": "aida-sys-image/models/female/2e4815b9-1191-419d-94ed-5771239ca4a5.png",
|
||||
"image_id": 67277,
|
||||
"businessId": 270373,
|
||||
"color": "30 28 28",
|
||||
"image_id": 98243,
|
||||
"offset": [
|
||||
0,
|
||||
0
|
||||
],
|
||||
"path": "aida-sys-image/images/female/blouse/0902003811.jpg",
|
||||
"seg_mask_url": "test/result.png",
|
||||
"print": {
|
||||
"element": {
|
||||
"element_angle_list": [],
|
||||
"element_path_list": [],
|
||||
"element_scale_list": [],
|
||||
"location": []
|
||||
},
|
||||
"overall": {
|
||||
"location": [],
|
||||
"print_angle_list": [],
|
||||
"print_path_list": [],
|
||||
"print_scale_list": []
|
||||
},
|
||||
"single": {
|
||||
"location": [],
|
||||
"print_angle_list": [],
|
||||
"print_path_list": [],
|
||||
"print_scale_list": []
|
||||
}
|
||||
},
|
||||
"priority": 11,
|
||||
"resize_scale": [
|
||||
1.0,
|
||||
1.0
|
||||
],
|
||||
"type": "Blouse"
|
||||
},
|
||||
{
|
||||
"businessId": 270374,
|
||||
"color": "172 68 68",
|
||||
"image_id": 98244,
|
||||
"offset": [
|
||||
0,
|
||||
0
|
||||
],
|
||||
"path": "aida-sys-image/images/female/outwear/0825000410.jpg",
|
||||
"seg_mask_url": "test/result.png",
|
||||
"print": {
|
||||
"element": {
|
||||
"element_angle_list": [],
|
||||
"element_path_list": [],
|
||||
"element_scale_list": [],
|
||||
"location": []
|
||||
},
|
||||
"overall": {
|
||||
"location": [],
|
||||
"print_angle_list": [],
|
||||
"print_path_list": [],
|
||||
"print_scale_list": []
|
||||
},
|
||||
"single": {
|
||||
"location": [],
|
||||
"print_angle_list": [],
|
||||
"print_path_list": [],
|
||||
"print_scale_list": []
|
||||
}
|
||||
},
|
||||
"priority": 12,
|
||||
"resize_scale": [
|
||||
1.0,
|
||||
1.0
|
||||
],
|
||||
"type": "Outwear"
|
||||
},
|
||||
{
|
||||
"body_path": "aida-sys-image/models/female/5bdfe7ca-64eb-44e4-b03d-8e517520c795.png",
|
||||
"image_id": 96090,
|
||||
"type": "Body"
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"process_id": "89"
|
||||
"process_id": "83"
|
||||
}
|
||||
"""
|
||||
# logger.info(f"design request item is : @@@@@@:{json.dumps(request_data.dict())}")
|
||||
# data = generate(request_data=request_data)
|
||||
# logger.info(f"design response @@@@@@:{json.dumps(data)}")
|
||||
#
|
||||
|
||||
try:
|
||||
logger.info(f"design request item is : @@@@@@:{json.dumps(request_data.dict())}")
|
||||
data = generate(request_data=request_data)
|
||||
data = design_generate(request_data=request_data)
|
||||
logger.info(f"design response @@@@@@:{json.dumps(data)}")
|
||||
except Exception as e:
|
||||
logger.warning(f"design Run Exception @@@@@@:{e}")
|
||||
@@ -193,3 +242,36 @@ def model_process(request_data: ModelProgressModel):
|
||||
logger.warning(f"model_process Run Exception @@@@@@:{e}")
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
return ResponseModel(data=data)
|
||||
|
||||
|
||||
# ##############################################################
|
||||
|
||||
|
||||
@router.post("/design_batch_generate")
|
||||
async def design(file: UploadFile = File(...),
|
||||
tasks_id: str = Form(...),
|
||||
user_id: str = Form(...),
|
||||
file_name: str = Form(...),
|
||||
total: int = Form(...)
|
||||
):
|
||||
dbg_config = DBGConfigModel(
|
||||
tasks_id=tasks_id,
|
||||
user_id=user_id,
|
||||
file_name=file_name,
|
||||
total=total
|
||||
)
|
||||
contents = await file.read()
|
||||
file_name = file.filename
|
||||
await save_request_file(contents, file_name)
|
||||
return await start_design_batch_generate(dbg_config, contents)
|
||||
|
||||
|
||||
async def save_request_file(contents, file_name):
|
||||
# 创建保存文件的目录(如果不存在)
|
||||
save_dir = os.path.join(os.getcwd(), "design_batch", "request_data")
|
||||
if not os.path.exists(save_dir):
|
||||
os.makedirs(save_dir)
|
||||
# 处理文件
|
||||
file_path = os.path.join(save_dir, file_name)
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(contents)
|
||||
|
||||
38
app/api/api_image2sketch.py
Normal file
38
app/api/api_image2sketch.py
Normal file
@@ -0,0 +1,38 @@
|
||||
import json
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
|
||||
from app.schemas.image2sketch import Image2SketchModel
|
||||
from app.schemas.response_template import ResponseModel
|
||||
from app.service.lineart.service import LineArtService
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
@router.post("/image2sketch")
|
||||
def image2sketch(request_item: Image2SketchModel):
|
||||
"""
|
||||
创建一个具有以下参数的请求体:
|
||||
- **image_url**: 提取图片url
|
||||
- **default_style**: 原始、 1、2、3、4、5
|
||||
- **sketch_bucket**: sketch保存的bucket
|
||||
- **sketch_name**: sketch保存的object name
|
||||
|
||||
示例参数:
|
||||
{
|
||||
"image_url": "test/image2sketch/real_Dress_3200fecdc83d0c556c2bd96aedbd7fbf.jpg_Img.jpg",
|
||||
"default_style": 0,
|
||||
"sketch_bucket": "test",
|
||||
"sketch_name": "image2sketch/area_fill_img.png"
|
||||
}
|
||||
"""
|
||||
try:
|
||||
logger.info(f"image2sketch request item is : @@@@@@:{json.dumps(request_item.dict())}")
|
||||
service = LineArtService(request_item)
|
||||
result_url = service.get_result()
|
||||
except Exception as e:
|
||||
logger.warning(f"image2sketch Run Exception @@@@@@:{e}")
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
return ResponseModel(data=result_url)
|
||||
@@ -1,14 +1,15 @@
|
||||
from fastapi import APIRouter
|
||||
|
||||
from app.api import api_test
|
||||
from app.api import api_super_resolution
|
||||
from app.api import api_generate_image
|
||||
from app.api import api_attribute_retrieve
|
||||
from app.api import api_design
|
||||
from app.api import api_brighten
|
||||
from app.api import api_chat_robot
|
||||
from app.api import api_prompt_generation
|
||||
from app.api import api_design
|
||||
from app.api import api_design_pre_processing
|
||||
|
||||
from app.api import api_generate_image
|
||||
from app.api import api_image2sketch
|
||||
from app.api import api_prompt_generation
|
||||
from app.api import api_super_resolution
|
||||
from app.api import api_test
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -20,3 +21,5 @@ router.include_router(api_design.router, tags=['design'], prefix="/api")
|
||||
router.include_router(api_chat_robot.router, tags=['chat_robot'], prefix="/api")
|
||||
router.include_router(api_prompt_generation.router, tags=['prompt_generation'], prefix="/api")
|
||||
router.include_router(api_design_pre_processing.router, tags=['design_pre_processing'], prefix="/api")
|
||||
router.include_router(api_image2sketch.router, tags=['api_image2sketch'], prefix="/api")
|
||||
router.include_router(api_brighten.router, tags=['api_brighten'], prefix="/api")
|
||||
|
||||
@@ -24,11 +24,11 @@ DEBUG = False
|
||||
if DEBUG:
|
||||
LOGS_PATH = "logs/"
|
||||
CATEGORY_PATH = "service/attribute/config/descriptor/category/category_dis.csv"
|
||||
# FACE_CLASSIFIER = "service/generate_image/utils/haarcascade_frontalface_alt.xml"
|
||||
SEG_CACHE_PATH = "../seg_cache/"
|
||||
else:
|
||||
LOGS_PATH = "app/logs/"
|
||||
CATEGORY_PATH = "app/service/attribute/config/descriptor/category/category_dis.csv"
|
||||
# FACE_CLASSIFIER = 'app/service/generate_image/utils/haarcascade_frontalface_alt.xml'
|
||||
SEG_CACHE_PATH = "/seg_cache/"
|
||||
|
||||
# RABBITMQ_ENV = "" # 生产环境
|
||||
RABBITMQ_ENV = "-dev" # 开发环境
|
||||
@@ -64,7 +64,7 @@ RABBITMQ_PARAMS = {
|
||||
MILVUS_URL = "http://10.1.1.240:19530"
|
||||
MILVUS_TOKEN = "root:Milvus"
|
||||
MILVUS_ALIAS = "default"
|
||||
MILVUS_TABLE_KEYPOINT = "keypoint_cache"
|
||||
MILVUS_TABLE_KEYPOINT = "keypoint_cache_2"
|
||||
MILVUS_TABLE_SEG = "seg_cache"
|
||||
|
||||
# Mysql 配置
|
||||
|
||||
90
app/design_batch/request_data/requests_data.json
Normal file
90
app/design_batch/request_data/requests_data.json
Normal file
@@ -0,0 +1,90 @@
|
||||
{
|
||||
"objects": [
|
||||
{
|
||||
"basic": {
|
||||
"body_point_test": {
|
||||
"waistband_right": [
|
||||
201,
|
||||
242
|
||||
],
|
||||
"hand_point_right": [
|
||||
222,
|
||||
312
|
||||
],
|
||||
"waistband_left": [
|
||||
114,
|
||||
243
|
||||
],
|
||||
"hand_point_left": [
|
||||
94,
|
||||
310
|
||||
],
|
||||
"shoulder_left": [
|
||||
102,
|
||||
116
|
||||
],
|
||||
"shoulder_right": [
|
||||
211,
|
||||
115
|
||||
]
|
||||
},
|
||||
"layer_order": true,
|
||||
"scale_bag": 0.7,
|
||||
"scale_earrings": 0.16,
|
||||
"self_template": true,
|
||||
"single_overall": "overall",
|
||||
"switch_category": ""
|
||||
},
|
||||
"items": [
|
||||
{
|
||||
"businessId": 264931,
|
||||
"color": "145 220 232",
|
||||
"image_id": 96844,
|
||||
"offset": [
|
||||
0,
|
||||
0
|
||||
],
|
||||
"path": "aida-users/87/sketch/2aa7aad5-74bb-41fa-9cdf-f06611b3e89a-2-87.png",
|
||||
"print": {
|
||||
"element": {
|
||||
"element_angle_list": [],
|
||||
"element_path_list": [],
|
||||
"element_scale_list": [],
|
||||
"location": []
|
||||
},
|
||||
"overall": {
|
||||
"location": [],
|
||||
"print_angle_list": [],
|
||||
"print_path_list": [],
|
||||
"print_scale_list": []
|
||||
},
|
||||
"single": {
|
||||
"location": [],
|
||||
"print_angle_list": [],
|
||||
"print_path_list": [],
|
||||
"print_scale_list": []
|
||||
}
|
||||
},
|
||||
"priority": 10,
|
||||
"resize_scale": [
|
||||
1.0,
|
||||
1.0
|
||||
],
|
||||
"type": "Dress"
|
||||
},
|
||||
{
|
||||
"body_path": "aida-sys-image/models/female/79805ec3-3f01-466d-91e0-36028d079699.png",
|
||||
"image_id": 95444,
|
||||
"type": "Body"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
],
|
||||
"process_id": "87",
|
||||
"tasks_id": ,
|
||||
}
|
||||
|
||||
|
||||
//用 openai jsonl
|
||||
//
|
||||
6
app/schemas/brighten.py
Normal file
6
app/schemas/brighten.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class BrightenModel(BaseModel):
|
||||
image_url: str
|
||||
brighten_value: float
|
||||
@@ -1,50 +1,6 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
# class BodyPointModel(BaseModel):
|
||||
# waistband_right: list[int]
|
||||
# hand_point_right: list[int]
|
||||
# waistband_left: list[int]
|
||||
# hand_point_left: list[int]
|
||||
# shoulder_left: list[int]
|
||||
# shoulder_right: list[int]
|
||||
#
|
||||
#
|
||||
# class BasicModel(BaseModel):
|
||||
# body_point: BodyPointModel
|
||||
# layer_order: bool
|
||||
# scale_bag: float
|
||||
# scale_earrings: float
|
||||
# self_template: bool
|
||||
# single_overall: str
|
||||
# switch_category: str
|
||||
# body_path: str
|
||||
#
|
||||
#
|
||||
# class PrintModel(BaseModel):
|
||||
# if_single: bool
|
||||
# print_path_list: list[str]
|
||||
#
|
||||
#
|
||||
# class ItemModel(BaseModel):
|
||||
# color: str
|
||||
# image_id: str
|
||||
# offset: list[int]
|
||||
# path: str
|
||||
# print: PrintModel
|
||||
# resize_scale: float
|
||||
# type: str
|
||||
#
|
||||
#
|
||||
# class CollocationModel(BaseModel):
|
||||
# basic: BasicModel
|
||||
# item: list[ItemModel]
|
||||
#
|
||||
#
|
||||
# class DesignModel(BaseModel):
|
||||
# object: list[CollocationModel]
|
||||
# process_id: str
|
||||
|
||||
class DesignModel(BaseModel):
|
||||
objects: list[dict]
|
||||
process_id: str
|
||||
@@ -56,3 +12,10 @@ class DesignProgressModel(BaseModel):
|
||||
|
||||
class ModelProgressModel(BaseModel):
|
||||
model_path: str
|
||||
|
||||
|
||||
class DBGConfigModel(BaseModel):
|
||||
tasks_id: str
|
||||
user_id: str
|
||||
file_name: str
|
||||
total: int
|
||||
|
||||
8
app/schemas/image2sketch.py
Normal file
8
app/schemas/image2sketch.py
Normal file
@@ -0,0 +1,8 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class Image2SketchModel(BaseModel):
|
||||
image_url: str
|
||||
default_style: str
|
||||
sketch_bucket: str
|
||||
sketch_name: str
|
||||
@@ -1,771 +0,0 @@
|
||||
{
|
||||
"objects": [
|
||||
{
|
||||
"basic": {
|
||||
"body_point_test": {
|
||||
"waistband_right": [
|
||||
336,
|
||||
264
|
||||
],
|
||||
"hand_point_right": [
|
||||
350,
|
||||
303
|
||||
],
|
||||
"waistband_left": [
|
||||
245,
|
||||
274
|
||||
],
|
||||
"hand_point_left": [
|
||||
219,
|
||||
315
|
||||
],
|
||||
"shoulder_left": [
|
||||
227,
|
||||
155
|
||||
],
|
||||
"shoulder_right": [
|
||||
338,
|
||||
149
|
||||
]
|
||||
},
|
||||
"layer_order": false,
|
||||
"scale_bag": 0.7,
|
||||
"scale_earrings": 0.16,
|
||||
"self_template": true,
|
||||
"single_overall": "overall",
|
||||
"switch_category": ""
|
||||
},
|
||||
"items": [
|
||||
{
|
||||
"businessId": 493827,
|
||||
"color": "127 61 21",
|
||||
"elementId": 493827,
|
||||
"icon": "none",
|
||||
"image_id": 110201,
|
||||
"offset": [
|
||||
1,
|
||||
1
|
||||
],
|
||||
"path": "aida-users/31/sketch/62302527-2910-4740-808d-2cb8221daa34-3-31.png",
|
||||
"print": {
|
||||
"IfSingle": false,
|
||||
"print_path_list": []
|
||||
},
|
||||
"resize_scale": [
|
||||
1.0,
|
||||
1.0
|
||||
],
|
||||
"type": "Dress"
|
||||
},
|
||||
{
|
||||
"body_path": "aida-users/31/models/female/845046c7-4f62-4f54-a4a9-c26d49c6969335b5b3a9-d335-4871-a46c-3cc3caf07da259629dfd1f1f555a2e2a9def7e719366.png",
|
||||
"image_id": 82966,
|
||||
"offset": [
|
||||
1,
|
||||
1
|
||||
],
|
||||
"resize_scale": [
|
||||
1.0,
|
||||
1.0
|
||||
],
|
||||
"type": "Body"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"basic": {
|
||||
"body_point_test": {
|
||||
"waistband_right": [
|
||||
336,
|
||||
264
|
||||
],
|
||||
"hand_point_right": [
|
||||
350,
|
||||
303
|
||||
],
|
||||
"waistband_left": [
|
||||
245,
|
||||
274
|
||||
],
|
||||
"hand_point_left": [
|
||||
219,
|
||||
315
|
||||
],
|
||||
"shoulder_left": [
|
||||
227,
|
||||
155
|
||||
],
|
||||
"shoulder_right": [
|
||||
338,
|
||||
149
|
||||
]
|
||||
},
|
||||
"layer_order": false,
|
||||
"scale_bag": 0.7,
|
||||
"scale_earrings": 0.16,
|
||||
"self_template": true,
|
||||
"single_overall": "overall",
|
||||
"switch_category": ""
|
||||
},
|
||||
"items": [
|
||||
{
|
||||
"color": "27 25 23",
|
||||
"icon": "none",
|
||||
"image_id": 110202,
|
||||
"offset": [
|
||||
1,
|
||||
1
|
||||
],
|
||||
"path": "aida-sys-image/images/female/skirt/0916000602.jpg",
|
||||
"print": {
|
||||
"IfSingle": false,
|
||||
"print_path_list": []
|
||||
},
|
||||
"resize_scale": [
|
||||
1.0,
|
||||
1.0
|
||||
],
|
||||
"type": "Skirt"
|
||||
},
|
||||
{
|
||||
"businessId": 493825,
|
||||
"color": "229 214 200",
|
||||
"elementId": 493825,
|
||||
"icon": "none",
|
||||
"image_id": 107101,
|
||||
"offset": [
|
||||
1,
|
||||
1
|
||||
],
|
||||
"path": "aida-users/31/sketchboard/female/Blouse/de8f5656-d7ae-4642-bc90-f7f9d85da09b.jpg",
|
||||
"print": {
|
||||
"IfSingle": false,
|
||||
"print_path_list": []
|
||||
},
|
||||
"resize_scale": [
|
||||
1.0,
|
||||
1.0
|
||||
],
|
||||
"type": "Blouse"
|
||||
},
|
||||
{
|
||||
"businessId": 493824,
|
||||
"color": "76 124 124",
|
||||
"elementId": 493824,
|
||||
"icon": "none",
|
||||
"image_id": 104522,
|
||||
"offset": [
|
||||
1,
|
||||
1
|
||||
],
|
||||
"path": "aida-users/31/sketch/3e82214a-0191-11ef-96d2-b48351119060_1.png",
|
||||
"print": {
|
||||
"IfSingle": false,
|
||||
"print_path_list": []
|
||||
},
|
||||
"resize_scale": [
|
||||
1.0,
|
||||
1.0
|
||||
],
|
||||
"type": "Outwear"
|
||||
},
|
||||
{
|
||||
"body_path": "aida-users/31/models/female/845046c7-4f62-4f54-a4a9-c26d49c6969335b5b3a9-d335-4871-a46c-3cc3caf07da259629dfd1f1f555a2e2a9def7e719366.png",
|
||||
"image_id": 82966,
|
||||
"offset": [
|
||||
1,
|
||||
1
|
||||
],
|
||||
"resize_scale": [
|
||||
1.0,
|
||||
1.0
|
||||
],
|
||||
"type": "Body"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"basic": {
|
||||
"body_point_test": {
|
||||
"waistband_right": [
|
||||
336,
|
||||
264
|
||||
],
|
||||
"hand_point_right": [
|
||||
350,
|
||||
303
|
||||
],
|
||||
"waistband_left": [
|
||||
245,
|
||||
274
|
||||
],
|
||||
"hand_point_left": [
|
||||
219,
|
||||
315
|
||||
],
|
||||
"shoulder_left": [
|
||||
227,
|
||||
155
|
||||
],
|
||||
"shoulder_right": [
|
||||
338,
|
||||
149
|
||||
]
|
||||
},
|
||||
"layer_order": false,
|
||||
"scale_bag": 0.7,
|
||||
"scale_earrings": 0.16,
|
||||
"self_template": true,
|
||||
"single_overall": "overall",
|
||||
"switch_category": ""
|
||||
},
|
||||
"items": [
|
||||
{
|
||||
"color": "229 214 200",
|
||||
"icon": "none",
|
||||
"image_id": 110203,
|
||||
"offset": [
|
||||
1,
|
||||
1
|
||||
],
|
||||
"path": "aida-sys-image/images/female/blouse/0825001576.jpg",
|
||||
"print": {
|
||||
"IfSingle": false,
|
||||
"print_path_list": []
|
||||
},
|
||||
"resize_scale": [
|
||||
1.0,
|
||||
1.0
|
||||
],
|
||||
"type": "Blouse"
|
||||
},
|
||||
{
|
||||
"color": "76 124 124",
|
||||
"icon": "none",
|
||||
"image_id": 96071,
|
||||
"offset": [
|
||||
1,
|
||||
1
|
||||
],
|
||||
"path": "aida-sys-image/images/female/skirt/903000097.jpg",
|
||||
"print": {
|
||||
"IfSingle": false,
|
||||
"print_path_list": []
|
||||
},
|
||||
"resize_scale": [
|
||||
1.0,
|
||||
1.0
|
||||
],
|
||||
"type": "Skirt"
|
||||
},
|
||||
{
|
||||
"color": "209 125 29",
|
||||
"icon": "none",
|
||||
"image_id": 93798,
|
||||
"offset": [
|
||||
1,
|
||||
1
|
||||
],
|
||||
"path": "aida-sys-image/images/female/outwear/outwear_p4_561.jpg",
|
||||
"print": {
|
||||
"IfSingle": false,
|
||||
"print_path_list": []
|
||||
},
|
||||
"resize_scale": [
|
||||
1.0,
|
||||
1.0
|
||||
],
|
||||
"type": "Outwear"
|
||||
},
|
||||
{
|
||||
"body_path": "aida-users/31/models/female/845046c7-4f62-4f54-a4a9-c26d49c6969335b5b3a9-d335-4871-a46c-3cc3caf07da259629dfd1f1f555a2e2a9def7e719366.png",
|
||||
"image_id": 82966,
|
||||
"offset": [
|
||||
1,
|
||||
1
|
||||
],
|
||||
"resize_scale": [
|
||||
1.0,
|
||||
1.0
|
||||
],
|
||||
"type": "Body"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"basic": {
|
||||
"body_point_test": {
|
||||
"waistband_right": [
|
||||
336,
|
||||
264
|
||||
],
|
||||
"hand_point_right": [
|
||||
350,
|
||||
303
|
||||
],
|
||||
"waistband_left": [
|
||||
245,
|
||||
274
|
||||
],
|
||||
"hand_point_left": [
|
||||
219,
|
||||
315
|
||||
],
|
||||
"shoulder_left": [
|
||||
227,
|
||||
155
|
||||
],
|
||||
"shoulder_right": [
|
||||
338,
|
||||
149
|
||||
]
|
||||
},
|
||||
"layer_order": false,
|
||||
"scale_bag": 0.7,
|
||||
"scale_earrings": 0.16,
|
||||
"self_template": true,
|
||||
"single_overall": "overall",
|
||||
"switch_category": ""
|
||||
},
|
||||
"items": [
|
||||
{
|
||||
"businessId": 493824,
|
||||
"color": "209 125 29",
|
||||
"elementId": 493824,
|
||||
"icon": "none",
|
||||
"image_id": 104522,
|
||||
"offset": [
|
||||
1,
|
||||
1
|
||||
],
|
||||
"path": "aida-users/31/sketch/3e82214a-0191-11ef-96d2-b48351119060_1.png",
|
||||
"print": {
|
||||
"IfSingle": false,
|
||||
"print_path_list": []
|
||||
},
|
||||
"resize_scale": [
|
||||
1.0,
|
||||
1.0
|
||||
],
|
||||
"type": "Outwear"
|
||||
},
|
||||
{
|
||||
"color": "118 123 115",
|
||||
"icon": "none",
|
||||
"image_id": 110204,
|
||||
"offset": [
|
||||
1,
|
||||
1
|
||||
],
|
||||
"path": "aida-sys-image/images/female/blouse/0902000457.jpg",
|
||||
"print": {
|
||||
"IfSingle": false,
|
||||
"print_path_list": []
|
||||
},
|
||||
"resize_scale": [
|
||||
1.0,
|
||||
1.0
|
||||
],
|
||||
"type": "Blouse"
|
||||
},
|
||||
{
|
||||
"color": "118 123 115",
|
||||
"icon": "none",
|
||||
"image_id": 79259,
|
||||
"offset": [
|
||||
1,
|
||||
1
|
||||
],
|
||||
"path": "aida-sys-image/images/female/trousers/826000094.jpg",
|
||||
"print": {
|
||||
"IfSingle": false,
|
||||
"print_path_list": []
|
||||
},
|
||||
"resize_scale": [
|
||||
1.0,
|
||||
1.0
|
||||
],
|
||||
"type": "Trousers"
|
||||
},
|
||||
{
|
||||
"body_path": "aida-users/31/models/female/845046c7-4f62-4f54-a4a9-c26d49c6969335b5b3a9-d335-4871-a46c-3cc3caf07da259629dfd1f1f555a2e2a9def7e719366.png",
|
||||
"image_id": 82966,
|
||||
"offset": [
|
||||
1,
|
||||
1
|
||||
],
|
||||
"resize_scale": [
|
||||
1.0,
|
||||
1.0
|
||||
],
|
||||
"type": "Body"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"basic": {
|
||||
"body_point_test": {
|
||||
"waistband_right": [
|
||||
336,
|
||||
264
|
||||
],
|
||||
"hand_point_right": [
|
||||
350,
|
||||
303
|
||||
],
|
||||
"waistband_left": [
|
||||
245,
|
||||
274
|
||||
],
|
||||
"hand_point_left": [
|
||||
219,
|
||||
315
|
||||
],
|
||||
"shoulder_left": [
|
||||
227,
|
||||
155
|
||||
],
|
||||
"shoulder_right": [
|
||||
338,
|
||||
149
|
||||
]
|
||||
},
|
||||
"layer_order": false,
|
||||
"scale_bag": 0.7,
|
||||
"scale_earrings": 0.16,
|
||||
"self_template": true,
|
||||
"single_overall": "overall",
|
||||
"switch_category": ""
|
||||
},
|
||||
"items": [
|
||||
{
|
||||
"color": "127 61 21",
|
||||
"icon": "none",
|
||||
"image_id": 96038,
|
||||
"offset": [
|
||||
1,
|
||||
1
|
||||
],
|
||||
"path": "aida-sys-image/images/female/dress/0902003549.jpg",
|
||||
"print": {
|
||||
"IfSingle": false,
|
||||
"print_path_list": []
|
||||
},
|
||||
"resize_scale": [
|
||||
1.0,
|
||||
1.0
|
||||
],
|
||||
"type": "Dress"
|
||||
},
|
||||
{
|
||||
"body_path": "aida-users/31/models/female/845046c7-4f62-4f54-a4a9-c26d49c6969335b5b3a9-d335-4871-a46c-3cc3caf07da259629dfd1f1f555a2e2a9def7e719366.png",
|
||||
"image_id": 82966,
|
||||
"offset": [
|
||||
1,
|
||||
1
|
||||
],
|
||||
"resize_scale": [
|
||||
1.0,
|
||||
1.0
|
||||
],
|
||||
"type": "Body"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"basic": {
|
||||
"body_point_test": {
|
||||
"waistband_right": [
|
||||
336,
|
||||
264
|
||||
],
|
||||
"hand_point_right": [
|
||||
350,
|
||||
303
|
||||
],
|
||||
"waistband_left": [
|
||||
245,
|
||||
274
|
||||
],
|
||||
"hand_point_left": [
|
||||
219,
|
||||
315
|
||||
],
|
||||
"shoulder_left": [
|
||||
227,
|
||||
155
|
||||
],
|
||||
"shoulder_right": [
|
||||
338,
|
||||
149
|
||||
]
|
||||
},
|
||||
"layer_order": false,
|
||||
"scale_bag": 0.7,
|
||||
"scale_earrings": 0.16,
|
||||
"self_template": true,
|
||||
"single_overall": "overall",
|
||||
"switch_category": ""
|
||||
},
|
||||
"items": [
|
||||
{
|
||||
"businessId": 493822,
|
||||
"color": "127 61 21",
|
||||
"elementId": 493822,
|
||||
"icon": "none",
|
||||
"image_id": 62309,
|
||||
"offset": [
|
||||
1,
|
||||
1
|
||||
],
|
||||
"path": "aida-users/31/sketchboard/female/trousers/c37c2ea6-8955-4b40-8339-c737e672ca3d.jpg",
|
||||
"print": {
|
||||
"IfSingle": false,
|
||||
"print_path_list": []
|
||||
},
|
||||
"resize_scale": [
|
||||
1.0,
|
||||
1.0
|
||||
],
|
||||
"type": "Trousers"
|
||||
},
|
||||
{
|
||||
"businessId": 493825,
|
||||
"color": "118 123 115",
|
||||
"elementId": 493825,
|
||||
"icon": "none",
|
||||
"image_id": 107101,
|
||||
"offset": [
|
||||
1,
|
||||
1
|
||||
],
|
||||
"path": "aida-users/31/sketchboard/female/Blouse/de8f5656-d7ae-4642-bc90-f7f9d85da09b.jpg",
|
||||
"print": {
|
||||
"IfSingle": false,
|
||||
"print_path_list": []
|
||||
},
|
||||
"resize_scale": [
|
||||
1.0,
|
||||
1.0
|
||||
],
|
||||
"type": "Blouse"
|
||||
},
|
||||
{
|
||||
"body_path": "aida-users/31/models/female/845046c7-4f62-4f54-a4a9-c26d49c6969335b5b3a9-d335-4871-a46c-3cc3caf07da259629dfd1f1f555a2e2a9def7e719366.png",
|
||||
"image_id": 82966,
|
||||
"offset": [
|
||||
1,
|
||||
1
|
||||
],
|
||||
"resize_scale": [
|
||||
1.0,
|
||||
1.0
|
||||
],
|
||||
"type": "Body"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"basic": {
|
||||
"body_point_test": {
|
||||
"waistband_right": [
|
||||
336,
|
||||
264
|
||||
],
|
||||
"hand_point_right": [
|
||||
350,
|
||||
303
|
||||
],
|
||||
"waistband_left": [
|
||||
245,
|
||||
274
|
||||
],
|
||||
"hand_point_left": [
|
||||
219,
|
||||
315
|
||||
],
|
||||
"shoulder_left": [
|
||||
227,
|
||||
155
|
||||
],
|
||||
"shoulder_right": [
|
||||
338,
|
||||
149
|
||||
]
|
||||
},
|
||||
"layer_order": false,
|
||||
"scale_bag": 0.7,
|
||||
"scale_earrings": 0.16,
|
||||
"self_template": true,
|
||||
"single_overall": "overall",
|
||||
"switch_category": ""
|
||||
},
|
||||
"items": [
|
||||
{
|
||||
"businessId": 493826,
|
||||
"color": "127 61 21",
|
||||
"elementId": 493826,
|
||||
"icon": "none",
|
||||
"image_id": 107105,
|
||||
"offset": [
|
||||
1,
|
||||
1
|
||||
],
|
||||
"path": "aida-users/31/sketchboard/female/Skirt/58710352-6301-450d-b69a-fb2922b5429a.png",
|
||||
"print": {
|
||||
"IfSingle": false,
|
||||
"print_path_list": []
|
||||
},
|
||||
"resize_scale": [
|
||||
1.0,
|
||||
1.0
|
||||
],
|
||||
"type": "Skirt"
|
||||
},
|
||||
{
|
||||
"color": "118 123 115",
|
||||
"icon": "none",
|
||||
"image_id": 79114,
|
||||
"offset": [
|
||||
1,
|
||||
1
|
||||
],
|
||||
"path": "aida-sys-image/images/female/blouse/903000169.jpg",
|
||||
"print": {
|
||||
"IfSingle": false,
|
||||
"print_path_list": []
|
||||
},
|
||||
"resize_scale": [
|
||||
1.0,
|
||||
1.0
|
||||
],
|
||||
"type": "Blouse"
|
||||
},
|
||||
{
|
||||
"color": "229 214 200",
|
||||
"icon": "none",
|
||||
"image_id": 90573,
|
||||
"offset": [
|
||||
1,
|
||||
1
|
||||
],
|
||||
"path": "aida-sys-image/images/female/outwear/0628000541.jpg",
|
||||
"print": {
|
||||
"IfSingle": false,
|
||||
"print_path_list": []
|
||||
},
|
||||
"resize_scale": [
|
||||
1.0,
|
||||
1.0
|
||||
],
|
||||
"type": "Outwear"
|
||||
},
|
||||
{
|
||||
"body_path": "aida-users/31/models/female/845046c7-4f62-4f54-a4a9-c26d49c6969335b5b3a9-d335-4871-a46c-3cc3caf07da259629dfd1f1f555a2e2a9def7e719366.png",
|
||||
"image_id": 82966,
|
||||
"offset": [
|
||||
1,
|
||||
1
|
||||
],
|
||||
"resize_scale": [
|
||||
1.0,
|
||||
1.0
|
||||
],
|
||||
"type": "Body"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"basic": {
|
||||
"body_point_test": {
|
||||
"waistband_right": [
|
||||
336,
|
||||
264
|
||||
],
|
||||
"hand_point_right": [
|
||||
350,
|
||||
303
|
||||
],
|
||||
"waistband_left": [
|
||||
245,
|
||||
274
|
||||
],
|
||||
"hand_point_left": [
|
||||
219,
|
||||
315
|
||||
],
|
||||
"shoulder_left": [
|
||||
227,
|
||||
155
|
||||
],
|
||||
"shoulder_right": [
|
||||
338,
|
||||
149
|
||||
]
|
||||
},
|
||||
"layer_order": false,
|
||||
"scale_bag": 0.7,
|
||||
"scale_earrings": 0.16,
|
||||
"self_template": true,
|
||||
"single_overall": "overall",
|
||||
"switch_category": ""
|
||||
},
|
||||
"items": [
|
||||
{
|
||||
"color": "229 214 200",
|
||||
"icon": "none",
|
||||
"image_id": 110205,
|
||||
"offset": [
|
||||
1,
|
||||
1
|
||||
],
|
||||
"path": "aida-sys-image/images/female/trousers/0916000217.jpg",
|
||||
"print": {
|
||||
"IfSingle": false,
|
||||
"print_path_list": []
|
||||
},
|
||||
"resize_scale": [
|
||||
1.0,
|
||||
1.0
|
||||
],
|
||||
"type": "Trousers"
|
||||
},
|
||||
{
|
||||
"businessId": 493825,
|
||||
"color": "209 125 29",
|
||||
"elementId": 493825,
|
||||
"icon": "none",
|
||||
"image_id": 107101,
|
||||
"offset": [
|
||||
1,
|
||||
1
|
||||
],
|
||||
"path": "aida-users/31/sketchboard/female/Blouse/de8f5656-d7ae-4642-bc90-f7f9d85da09b.jpg",
|
||||
"print": {
|
||||
"IfSingle": false,
|
||||
"print_path_list": []
|
||||
},
|
||||
"resize_scale": [
|
||||
1.0,
|
||||
1.0
|
||||
],
|
||||
"type": "Blouse"
|
||||
},
|
||||
{
|
||||
"body_path": "aida-users/31/models/female/845046c7-4f62-4f54-a4a9-c26d49c6969335b5b3a9-d335-4871-a46c-3cc3caf07da259629dfd1f1f555a2e2a9def7e719366.png",
|
||||
"image_id": 82966,
|
||||
"offset": [
|
||||
1,
|
||||
1
|
||||
],
|
||||
"resize_scale": [
|
||||
1.0,
|
||||
1.0
|
||||
],
|
||||
"type": "Body"
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"process_id": "6878547032381675"
|
||||
}
|
||||
@@ -10,6 +10,7 @@ class Bottom(Clothing):
|
||||
dict(type='LoadImageFromFile', path=kwargs['path'], color=kwargs['color'], print_dict=kwargs['print']),
|
||||
dict(type='KeypointDetection'),
|
||||
dict(type='ContourDetection'),
|
||||
# dict(type='Segmentation'),
|
||||
dict(type='Painting', painting_flag=True),
|
||||
dict(type='PrintPainting', print_flag=True),
|
||||
dict(type='Scaling'),
|
||||
|
||||
@@ -30,14 +30,15 @@ class Clothing(object):
|
||||
image=self.result["front_image"],
|
||||
# mask_image=self.result['front_mask_image'],
|
||||
image_url=self.result['front_image_url'],
|
||||
mask_url=self.result['front_mask_url'],
|
||||
mask_url=self.result['mask_url'],
|
||||
sacle=self.result['scale'],
|
||||
clothes_keypoint=self.result['clothes_keypoint'],
|
||||
position=start_point,
|
||||
resize_scale=self.result["resize_scale"],
|
||||
mask=cv2.resize(self.result['mask'], self.result["front_image"].size),
|
||||
gradient_string=self.result['gradient_string'] if 'gradient_string' in self.result.keys() else "",
|
||||
pattern_image_url=self.result['pattern_image_url']
|
||||
pattern_image_url=self.result['pattern_image_url'],
|
||||
pattern_image=self.result['pattern_image']
|
||||
|
||||
)
|
||||
layer.insert(front_layer)
|
||||
@@ -47,14 +48,14 @@ class Clothing(object):
|
||||
image=self.result["back_image"],
|
||||
# mask_image=self.result['back_mask_image'],
|
||||
image_url=self.result['back_image_url'],
|
||||
mask_url=self.result['back_mask_url'],
|
||||
mask_url=self.result['mask_url'],
|
||||
sacle=self.result['scale'],
|
||||
clothes_keypoint=self.result['clothes_keypoint'],
|
||||
position=start_point,
|
||||
resize_scale=self.result["resize_scale"],
|
||||
mask=cv2.resize(self.result['mask'], self.result["front_image"].size),
|
||||
gradient_string=self.result['gradient_string'] if 'gradient_string' in self.result.keys() else "",
|
||||
pattern_image_url=self.result['pattern_image_url']
|
||||
pattern_image_url=self.result['pattern_image_url'],
|
||||
)
|
||||
layer.insert(back_layer)
|
||||
|
||||
|
||||
@@ -43,7 +43,8 @@ class ContourDetection(object):
|
||||
result['mask'] = Mask
|
||||
else:
|
||||
result['mask'] = cv2.bitwise_and(Mask, result['pre_mask'])
|
||||
|
||||
result['front_mask'] = result['mask']
|
||||
result['back_mask'] = result['mask']
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -5,6 +5,7 @@ import numpy as np
|
||||
from pymilvus import MilvusClient
|
||||
|
||||
from app.core.config import *
|
||||
from app.service.utils.decorator import RunTime, ClassCallRunTime
|
||||
from ..builder import PIPELINES
|
||||
from ...utils.design_ensemble import get_keypoint_result
|
||||
|
||||
@@ -27,7 +28,7 @@ class KeypointDetection(object):
|
||||
# self.client.close()
|
||||
# print(f"client close time : {time.time() - start_time}")
|
||||
|
||||
# @ RunTime
|
||||
# @ClassCallRunTime
|
||||
def __call__(self, result):
|
||||
# logging.info("KeypointDetection run ")
|
||||
if result['name'] in ['blouse', 'skirt', 'dress', 'outwear', 'trousers', 'tops', 'bottoms']: # 查询是否有数据 且类别相同 相同则直接读 不同则推理后更新
|
||||
|
||||
@@ -12,6 +12,7 @@ class LoadImageFromFile(object):
|
||||
self.print_dict = print_dict
|
||||
# self.minio_client = Minio(f"{MINIO_URL}", access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||
|
||||
# @ClassCallRunTime
|
||||
def __call__(self, result):
|
||||
result['image'], result['pre_mask'] = self.read_image(self.path)
|
||||
result['gray'] = cv2.cvtColor(result['image'], cv2.COLOR_BGR2GRAY)
|
||||
@@ -45,15 +46,18 @@ class LoadImageFromFile(object):
|
||||
@staticmethod
|
||||
def read_image(image_path):
|
||||
image_mask = None
|
||||
# file = self.minio_client.get_object(image_path.split("/", 1)[0], image_path.split("/", 1)[1]).data
|
||||
# image = cv2.imdecode(np.frombuffer(file, np.uint8), 1)
|
||||
|
||||
image = oss_get_image(bucket=image_path.split("/", 1)[0], object_name=image_path.split("/", 1)[1], data_type="cv2")
|
||||
if len(image.shape) == 2:
|
||||
image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
|
||||
if image.shape[2] == 4: # 如果是四通道 mask
|
||||
image_mask = image[:, :, 3]
|
||||
image = image[:, :, :3]
|
||||
|
||||
if image.shape[:2] <= (50, 50):
|
||||
# 计算新尺寸
|
||||
new_size = (image.shape[1] * 2, image.shape[0] * 2)
|
||||
# 调整大小
|
||||
image = cv2.resize(image, new_size, interpolation=cv2.INTER_LINEAR)
|
||||
return image, image_mask
|
||||
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import logging
|
||||
import random
|
||||
|
||||
import cv2
|
||||
@@ -7,13 +8,15 @@ from PIL import Image
|
||||
from app.service.utils.oss_client import oss_get_image
|
||||
from ..builder import PIPELINES
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
class Painting(object):
|
||||
def __init__(self, painting_flag=True):
|
||||
self.painting_flag = painting_flag
|
||||
|
||||
# @ RunTime
|
||||
# @ClassCallRunTime
|
||||
def __call__(self, result):
|
||||
if result['name'] not in ['hairstyle', 'earring'] and self.painting_flag and result['color'] != 'none':
|
||||
dim_image_h, dim_image_w = result['image'].shape[0:2]
|
||||
@@ -86,7 +89,7 @@ class PrintPainting(object):
|
||||
def __init__(self, print_flag=True):
|
||||
self.print_flag = print_flag
|
||||
|
||||
# @ RunTime
|
||||
# @ClassCallRunTime
|
||||
def __call__(self, result):
|
||||
single_print = result['print']['single']
|
||||
overall_print = result['print']['overall']
|
||||
@@ -236,7 +239,6 @@ class PrintPainting(object):
|
||||
|
||||
print_background = cv2.cvtColor(np.array(source_image_pil), cv2.COLOR_RGBA2BGR)
|
||||
mask_background = cv2.cvtColor(np.array(source_image_pil_mask), cv2.COLOR_RGBA2BGR)
|
||||
print(1)
|
||||
else:
|
||||
mask = self.get_mask_inv(image)
|
||||
mask = np.expand_dims(mask, axis=2)
|
||||
|
||||
@@ -2,6 +2,7 @@ import math
|
||||
|
||||
import cv2
|
||||
|
||||
from app.service.utils.decorator import ClassCallRunTime
|
||||
from ..builder import PIPELINES
|
||||
|
||||
|
||||
@@ -10,7 +11,7 @@ class Scaling(object):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
# @ RunTime
|
||||
# @ClassCallRunTime
|
||||
def __call__(self, result):
|
||||
if result['keypoint'] in ['waistband', 'shoulder', 'head_point']:
|
||||
# milvus_db_keypoint_cache
|
||||
|
||||
@@ -1,14 +1,71 @@
|
||||
import logging
|
||||
import os
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from app.core.config import SEG_CACHE_PATH
|
||||
from app.service.utils.decorator import ClassCallRunTime
|
||||
from app.service.utils.oss_client import oss_get_image
|
||||
from ..builder import PIPELINES
|
||||
from ...utils.design_ensemble import get_seg_result
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
class Segmentation(object):
|
||||
def __init__(self, device='cpu', show=False, debug=None):
|
||||
self.show = show
|
||||
self.device = device
|
||||
self.debug = debug
|
||||
|
||||
@ClassCallRunTime
|
||||
def __call__(self, result):
|
||||
result['seg_result'] = get_seg_result(result["image_id"], result['image'])
|
||||
if "seg_mask_url" in result.keys() and result['seg_mask_url'] != "":
|
||||
seg_mask = oss_get_image(bucket=result['seg_mask_url'].split('/')[0], object_name=result['seg_mask_url'][result['seg_mask_url'].find('/') + 1:], data_type="cv2")
|
||||
seg_mask = cv2.resize(seg_mask, (result['img_shape'][1], result['img_shape'][0]), interpolation=cv2.INTER_NEAREST)
|
||||
# 转换颜色空间为 RGB(OpenCV 默认是 BGR)
|
||||
image_rgb = cv2.cvtColor(seg_mask, cv2.COLOR_BGR2RGB)
|
||||
|
||||
r, g, b = cv2.split(image_rgb)
|
||||
red_mask = r > g
|
||||
green_mask = g > r
|
||||
|
||||
# 创建红色和绿色掩码
|
||||
result['front_mask'] = np.array(red_mask, dtype=np.uint8) * 255
|
||||
result['back_mask'] = np.array(green_mask, dtype=np.uint8) * 255
|
||||
result['mask'] = result['front_mask'] + result['back_mask']
|
||||
else:
|
||||
# 本地查询seg 缓存是否存在
|
||||
_, seg_result = self.load_seg_result(result["image_id"])
|
||||
result['seg_result'] = seg_result
|
||||
if not _:
|
||||
# 推理获得seg 结果
|
||||
seg_result = get_seg_result(result["image_id"], result['image'])[0]
|
||||
self.save_seg_result(seg_result, result['image_id'])
|
||||
# 处理前片后片
|
||||
temp_front = seg_result == 1.0
|
||||
result['front_mask'] = (255 * (temp_front + 0).astype(np.uint8))
|
||||
temp_back = seg_result == 2.0
|
||||
result['back_mask'] = (255 * (temp_back + 0).astype(np.uint8))
|
||||
result['mask'] = result['front_mask'] + result['back_mask']
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def save_seg_result(seg_result, image_id):
|
||||
file_path = f"{SEG_CACHE_PATH}{image_id}.npy"
|
||||
try:
|
||||
np.save(file_path, seg_result)
|
||||
logger.info(f"保存成功 :{os.path.abspath(file_path)}")
|
||||
except Exception as e:
|
||||
logger.error(f"保存失败: {e}")
|
||||
|
||||
@staticmethod
|
||||
def load_seg_result(image_id):
|
||||
file_path = f"{SEG_CACHE_PATH}{image_id}.npy"
|
||||
try:
|
||||
seg_result = np.load(file_path)
|
||||
return True, seg_result
|
||||
except FileNotFoundError:
|
||||
logger.warning("文件不存在")
|
||||
return False, None
|
||||
except Exception as e:
|
||||
logger.error(f"加载失败: {e}")
|
||||
return False, None
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import io
|
||||
import logging
|
||||
|
||||
import cv2
|
||||
@@ -5,7 +6,9 @@ import numpy as np
|
||||
from PIL import Image
|
||||
from cv2 import cvtColor, COLOR_BGR2RGBA
|
||||
|
||||
from app.core.config import AIDA_CLOTHING
|
||||
from app.service.utils.generate_uuid import generate_uuid
|
||||
from app.service.utils.oss_client import oss_upload_image
|
||||
from ..builder import PIPELINES
|
||||
from ...utils.conversion_image import rgb_to_rgba
|
||||
from ...utils.upload_image import upload_png_mask
|
||||
@@ -17,32 +20,14 @@ class Split(object):
|
||||
Split image into front and back layer according to the segmentation result
|
||||
"""
|
||||
|
||||
# @ClassCallRunTime
|
||||
# KNet
|
||||
def __call__(self, result):
|
||||
try:
|
||||
if 'mask' not in result.keys():
|
||||
raise KeyError(f'Cannot find mask in result dict, please check ContourDetection is included in process pipelines.')
|
||||
if 'seg_result' not in result.keys(): # 没过seg模型
|
||||
result['front_mask'] = result['mask'].copy()
|
||||
result['back_mask'] = np.zeros_like(result['mask'])
|
||||
else:
|
||||
temp_front = result['seg_result'] == 1
|
||||
result['front_mask'] = (result['mask'] * (temp_front + 0).astype(np.uint8))
|
||||
temp_back = result['seg_result'] == 2
|
||||
result['back_mask'] = (result['mask'] * (temp_back + 0).astype(np.uint8))
|
||||
|
||||
if result['name'] in ('outwear', 'dress', 'blouse', 'skirt', 'trousers', 'tops', 'bottoms'):
|
||||
if len(result['front_mask'].shape) > 2:
|
||||
front_mask = result['front_mask'][0]
|
||||
else:
|
||||
front_mask = result['front_mask']
|
||||
|
||||
if len(result['back_mask'].shape) > 2:
|
||||
back_mask = result['back_mask'][0]
|
||||
else:
|
||||
back_mask = result['back_mask']
|
||||
|
||||
# rgba_image = rgb_to_rgba((result['final_image'].shape[0], result['final_image'].shape[1]), result['final_image'], front_mask + back_mask)
|
||||
front_mask = result['front_mask']
|
||||
back_mask = result['back_mask']
|
||||
rgba_image = rgb_to_rgba(result['final_image'], front_mask + back_mask)
|
||||
new_size = (int(rgba_image.shape[1] * result["scale"] * result["resize_scale"][0]), int(rgba_image.shape[0] * result["scale"] * result["resize_scale"][1]))
|
||||
rgba_image = cv2.resize(rgba_image, new_size)
|
||||
@@ -50,23 +35,45 @@ class Split(object):
|
||||
front_mask = cv2.resize(front_mask, new_size)
|
||||
result_front_image[front_mask != 0] = rgba_image[front_mask != 0]
|
||||
result_front_image_pil = Image.fromarray(cvtColor(result_front_image, COLOR_BGR2RGBA))
|
||||
result['front_image'], result["front_image_url"], result["front_mask_url"] = upload_png_mask(result_front_image_pil, f'{generate_uuid()}', mask=front_mask)
|
||||
result['front_image'], result["front_image_url"], _ = upload_png_mask(result_front_image_pil, f'{generate_uuid()}', mask=None)
|
||||
|
||||
height, width = front_mask.shape
|
||||
mask_image = np.zeros((height, width, 3))
|
||||
mask_image[front_mask != 0] = [0, 0, 255]
|
||||
|
||||
if result["name"] in ('blouse', 'dress', 'outwear', 'tops'):
|
||||
result_back_image = np.zeros_like(rgba_image)
|
||||
back_mask = cv2.resize(back_mask, new_size)
|
||||
result_back_image[back_mask != 0] = rgba_image[back_mask != 0]
|
||||
result_back_image_pil = Image.fromarray(cvtColor(result_back_image, COLOR_BGR2RGBA))
|
||||
result['back_image'], result["back_image_url"], result["back_mask_url"] = upload_png_mask(result_back_image_pil, f'{generate_uuid()}', mask=back_mask)
|
||||
result['back_image'], result["back_image_url"], _ = upload_png_mask(result_back_image_pil, f'{generate_uuid()}', mask=None)
|
||||
mask_image[back_mask != 0] = [0, 255, 0]
|
||||
|
||||
rbga_mask = rgb_to_rgba(mask_image, front_mask + back_mask)
|
||||
mask_pil = Image.fromarray(cvtColor(rbga_mask.astype(np.uint8), COLOR_BGR2RGBA))
|
||||
image_data = io.BytesIO()
|
||||
mask_pil.save(image_data, format='PNG')
|
||||
image_data.seek(0)
|
||||
image_bytes = image_data.read()
|
||||
req = oss_upload_image(bucket=AIDA_CLOTHING, object_name=f"mask/mask_{generate_uuid()}.png", image_bytes=image_bytes)
|
||||
result['mask_url'] = req.bucket_name + "/" + req.object_name
|
||||
else:
|
||||
rbga_mask = rgb_to_rgba(mask_image, front_mask)
|
||||
mask_pil = Image.fromarray(cvtColor(rbga_mask.astype(np.uint8), COLOR_BGR2RGBA))
|
||||
image_data = io.BytesIO()
|
||||
mask_pil.save(image_data, format='PNG')
|
||||
image_data.seek(0)
|
||||
image_bytes = image_data.read()
|
||||
req = oss_upload_image(bucket=AIDA_CLOTHING, object_name=f"mask/mask_{generate_uuid()}.png", image_bytes=image_bytes)
|
||||
result['mask_url'] = req.bucket_name + "/" + req.object_name
|
||||
result['back_image'] = None
|
||||
result["back_image_url"] = None
|
||||
result["back_mask_url"] = None
|
||||
result['back_mask_image'] = None
|
||||
|
||||
# 创建中间图层
|
||||
# result["back_mask_url"] = None
|
||||
# result['back_mask_image'] = None
|
||||
# 创建中间图层
|
||||
result_pattern_image_rgba = rgb_to_rgba(result['pattern_image'], result['mask'])
|
||||
result_pattern_image_pil = Image.fromarray(cvtColor(result_pattern_image_rgba, COLOR_BGR2RGBA))
|
||||
_, result['pattern_image_url'], _ = upload_png_mask(result_pattern_image_pil, f'{generate_uuid()}')
|
||||
result['pattern_image'], result['pattern_image_url'], _ = upload_png_mask(result_pattern_image_pil, f'{generate_uuid()}')
|
||||
return result
|
||||
except Exception as e:
|
||||
logging.warning(f"split runtime exception : {e} image_id : {result['image_id']}")
|
||||
|
||||
@@ -9,8 +9,8 @@ class Top(Clothing):
|
||||
pipeline = [
|
||||
dict(type='LoadImageFromFile', path=kwargs['path'], color=kwargs['color'], print_dict=kwargs['print']),
|
||||
dict(type='KeypointDetection'),
|
||||
dict(type='ContourDetection'),
|
||||
dict(type='Segmentation', device='cpu', show=False, debug=kwargs['debug']),
|
||||
# dict(type='ContourDetection'),
|
||||
dict(type='Segmentation'),
|
||||
dict(type='Painting', painting_flag=True),
|
||||
dict(type='PrintPainting', print_flag=True),
|
||||
# dict(type='ImageShow', key=['image', 'mask', 'seg_visualize', 'pattern_image']),
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
import concurrent.futures
|
||||
import io
|
||||
|
||||
import cv2
|
||||
|
||||
from app.core.config import PRIORITY_DICT
|
||||
from app.service.design.core.layer import Layer
|
||||
@@ -6,6 +9,7 @@ from app.service.design.items import build_item
|
||||
from app.service.design.utils.redis_utils import Redis
|
||||
from app.service.design.utils.synthesis_item import synthesis, synthesis_single
|
||||
from app.service.utils.decorator import RunTime
|
||||
from app.service.utils.oss_client import oss_upload_image
|
||||
|
||||
|
||||
def process_item(item, layers):
|
||||
@@ -23,7 +27,7 @@ def update_progress(process_id, total):
|
||||
if int(progress) <= 100:
|
||||
r.write(key=process_id, value=int(progress) + int(100 / total))
|
||||
else:
|
||||
r.write(key=process_id, value=100)
|
||||
r.write(key=process_id, value=99)
|
||||
return progress
|
||||
elif total == 1:
|
||||
r.write(key=process_id, value=100)
|
||||
@@ -43,6 +47,7 @@ def final_progress(process_id):
|
||||
@RunTime
|
||||
def generate(request_data):
|
||||
return_response = {}
|
||||
return_png_mask = []
|
||||
request_data = request_data.dict()
|
||||
assert "process_id" in request_data.keys(), "Need process_id parameters"
|
||||
|
||||
@@ -55,14 +60,15 @@ def generate(request_data):
|
||||
# 获取处理结果
|
||||
for future in concurrent.futures.as_completed(futures):
|
||||
obj = futures[future]
|
||||
|
||||
result = future.result()
|
||||
return_response[obj] = result
|
||||
return_response[obj] = future.result()[0]
|
||||
return_png_mask.extend(future.result()[1])
|
||||
# upload_results = process_images(return_png_mask)
|
||||
final_progress(process_id)
|
||||
return return_response
|
||||
|
||||
|
||||
def process_object(cfg, process_id, total):
|
||||
uploaded_images = []
|
||||
basic_info = cfg.get('basic')
|
||||
items_response = {
|
||||
'layers': []
|
||||
@@ -83,8 +89,17 @@ def process_object(cfg, process_id, total):
|
||||
layers = sorted(layers.layer, key=lambda s: s.get("priority", float('inf')))
|
||||
else:
|
||||
layers = sorted(layers.layer, key=lambda x: PRIORITY_DICT.get(x['name'], float('inf')))
|
||||
# 上传所有图片
|
||||
# for layer in layers:
|
||||
# if 'image' in layer.keys() and layer['image'] is not None:
|
||||
# uploaded_images.append({'image_obj': layer['image'], 'image_url': layer['image_url'], 'image_type': 'image'})
|
||||
# if 'pattern_image' in layer.keys() and layer['pattern_image'] is not None:
|
||||
# uploaded_images.append({'image_obj': layer['pattern_image'], 'image_url': layer['pattern_image_url'], 'image_type': 'pattern_image'})
|
||||
# if 'mask' in layer.keys() and layer['mask'] is not None and layer['mask_url'] is not None:
|
||||
# uploaded_images.append({'image_obj': layer['mask'], 'image_url': layer['mask_url'], 'image_type': 'mask'})
|
||||
layers, new_size = update_base_size_priority(layers, body_size)
|
||||
# 合成
|
||||
items_response['synthesis_url'] = synthesis(layers, body_size)
|
||||
items_response['synthesis_url'] = synthesis(layers, new_size, basic_info)
|
||||
|
||||
for lay in layers:
|
||||
items_response['layers'].append({
|
||||
@@ -114,9 +129,10 @@ def process_object(cfg, process_id, total):
|
||||
'position': None,
|
||||
'priority': 0,
|
||||
'image_url': item.result['front_image_url'],
|
||||
'mask_url': item.result['front_mask_url'],
|
||||
'mask_url': item.result['mask_url'],
|
||||
"gradient_string": item.result['gradient_string'] if 'gradient_string' in item.result.keys() else "",
|
||||
'pattern_image_url': item.result['pattern_image_url'] if 'pattern_image_url' in item.result.keys() else None,
|
||||
|
||||
})
|
||||
items_response['layers'].append({
|
||||
'image_category': f"{item.result['name']}_back",
|
||||
@@ -124,11 +140,58 @@ def process_object(cfg, process_id, total):
|
||||
'position': None,
|
||||
'priority': 0,
|
||||
'image_url': item.result['back_image_url'],
|
||||
'mask_url': item.result['back_mask_url'],
|
||||
'mask_url': item.result['mask_url'],
|
||||
"gradient_string": item.result['gradient_string'] if 'gradient_string' in item.result.keys() else "",
|
||||
'pattern_image_url': item.result['pattern_image_url'] if 'pattern_image_url' in item.result.keys() else None,
|
||||
|
||||
})
|
||||
items_response['synthesis_url'] = synthesis_single(item.result['front_image'], item.result['back_image'])
|
||||
break
|
||||
update_progress(process_id, total)
|
||||
return items_response
|
||||
return items_response, uploaded_images
|
||||
|
||||
|
||||
@RunTime
|
||||
def process_images(images):
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
results = list(executor.map(upload_images, images))
|
||||
# results = []
|
||||
# for image in images:
|
||||
# results.append(upload_images(image))
|
||||
return results
|
||||
|
||||
|
||||
# @RunTime
|
||||
def upload_images(image_obj):
|
||||
bucket_name = image_obj['image_url'].split("/", 1)[0]
|
||||
object_name = image_obj['image_url'].split("/", 1)[1]
|
||||
if image_obj['image_type'] == 'image' or image_obj['image_type'] == 'pattern_image':
|
||||
image_data = io.BytesIO()
|
||||
image_obj['image_obj'].save(image_data, format='PNG')
|
||||
image_data.seek(0)
|
||||
image_bytes = image_data.read()
|
||||
req = oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes)
|
||||
return image_obj['image_url']
|
||||
else:
|
||||
mask_inverted = cv2.bitwise_not(image_obj['image_obj'])
|
||||
# 将掩模的3通道转换为4通道,白色部分不透明,黑色部分透明
|
||||
rgba_image = cv2.cvtColor(mask_inverted, cv2.COLOR_BGR2BGRA)
|
||||
rgba_image[rgba_image[:, :, 0] == 0] = [0, 0, 0, 0]
|
||||
req = oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=cv2.imencode('.png', rgba_image)[1])
|
||||
return image_obj['image_url']
|
||||
|
||||
|
||||
def update_base_size_priority(layers, size):
|
||||
# 计算透明背景图片的宽度
|
||||
min_x = min(info['position'][1] for info in layers)
|
||||
x_list = []
|
||||
for info in layers:
|
||||
if info['image'] is not None:
|
||||
x_list.append(info['position'][1] + info['image'].width)
|
||||
max_x = max(x_list)
|
||||
new_width = max_x - min_x
|
||||
new_height = 700
|
||||
# 更新坐标
|
||||
for info in layers:
|
||||
info['adaptive_position'] = (info['position'][0], info['position'][1] - min_x)
|
||||
return layers, (new_width, new_height)
|
||||
|
||||
@@ -59,14 +59,26 @@ def positioning(all_mask_shape, mask_shape, offset):
|
||||
|
||||
|
||||
# @RunTime
|
||||
def synthesis(data, size):
|
||||
def synthesis(data, size, basic_info):
|
||||
# 创建底图
|
||||
base_image = Image.new('RGBA', size, (0, 0, 0, 0))
|
||||
try:
|
||||
|
||||
all_mask_shape = (size[1], size[0])
|
||||
top_outer_mask = np.zeros(all_mask_shape, dtype=np.uint8)
|
||||
bottom_outer_mask = np.zeros(all_mask_shape, dtype=np.uint8)
|
||||
body_mask = None
|
||||
for d in data:
|
||||
if d['name'] == 'body':
|
||||
# 创建一个新的宽高透明图像, 把模特贴上去获取mask
|
||||
transparent_image = Image.new("RGBA", size, (0, 0, 0, 0))
|
||||
transparent_image.paste(d['image'], (d['adaptive_position'][1], d['adaptive_position'][0]), d['image']) # 此处可变数组会被paste篡改值,所以使用下标获取position
|
||||
body_mask = np.array(transparent_image.split()[3])
|
||||
|
||||
# 根据新的坐标获取新的肩点
|
||||
left_shoulder = [x + y for x, y in zip(basic_info['body_point_test']['shoulder_left'], [d['adaptive_position'][1], d['adaptive_position'][0]])]
|
||||
right_shoulder = [x + y for x, y in zip(basic_info['body_point_test']['shoulder_right'], [d['adaptive_position'][1], d['adaptive_position'][0]])]
|
||||
body_mask[:min(left_shoulder[1], right_shoulder[1]), left_shoulder[0]:right_shoulder[0]] = 255
|
||||
_, binary_body_mask = cv2.threshold(body_mask, 127, 255, cv2.THRESH_BINARY)
|
||||
top_outer_mask = np.array(binary_body_mask)
|
||||
bottom_outer_mask = np.array(binary_body_mask)
|
||||
|
||||
top = True
|
||||
bottom = True
|
||||
@@ -76,21 +88,27 @@ def synthesis(data, size):
|
||||
if top and data[i]['name'] in ["blouse_front", "outwear_front", "dress_front", "tops_front"]:
|
||||
top = False
|
||||
mask_shape = data[i]['mask'].shape
|
||||
y_offset, x_offset = data[i]['position']
|
||||
y_offset, x_offset = data[i]['adaptive_position']
|
||||
# 初始化叠加区域的起始和结束位置
|
||||
all_y_start, all_y_end, mask_y_start, mask_y_end = positioning(all_mask_shape=all_mask_shape[0], mask_shape=mask_shape[0], offset=y_offset)
|
||||
all_x_start, all_x_end, mask_x_start, mask_x_end = positioning(all_mask_shape=all_mask_shape[1], mask_shape=mask_shape[1], offset=x_offset)
|
||||
# 将叠加区域赋值为相应的像素值
|
||||
top_outer_mask[all_y_start:all_y_end, all_x_start:all_x_end] = data[i]['mask'][mask_y_start:mask_y_end, mask_x_start:mask_x_end]
|
||||
elif bottom and data[i]['name'] in ["trousers_front", "skirt_front", "bottoms_front"]:
|
||||
_, sketch_mask = cv2.threshold(data[i]['mask'], 127, 255, cv2.THRESH_BINARY)
|
||||
background = np.zeros_like(top_outer_mask)
|
||||
background[all_y_start:all_y_end, all_x_start:all_x_end] = sketch_mask[mask_y_start:mask_y_end, mask_x_start:mask_x_end]
|
||||
top_outer_mask = background + top_outer_mask
|
||||
elif bottom and data[i]['name'] in ["trousers_front", "skirt_front", "bottoms_front", "dress_front"]:
|
||||
bottom = False
|
||||
mask_shape = data[i]['mask'].shape
|
||||
y_offset, x_offset = data[i]['position']
|
||||
y_offset, x_offset = data[i]['adaptive_position']
|
||||
# 初始化叠加区域的起始和结束位置
|
||||
all_y_start, all_y_end, mask_y_start, mask_y_end = positioning(all_mask_shape=all_mask_shape[0], mask_shape=mask_shape[0], offset=y_offset)
|
||||
all_x_start, all_x_end, mask_x_start, mask_x_end = positioning(all_mask_shape=all_mask_shape[1], mask_shape=mask_shape[1], offset=x_offset)
|
||||
# 将叠加区域赋值为相应的像素值
|
||||
bottom_outer_mask[all_y_start:all_y_end, all_x_start:all_x_end] = data[i]['mask'][mask_y_start:mask_y_end, mask_x_start:mask_x_end]
|
||||
_, sketch_mask = cv2.threshold(data[i]['mask'], 127, 255, cv2.THRESH_BINARY)
|
||||
background = np.zeros_like(top_outer_mask)
|
||||
background[all_y_start:all_y_end, all_x_start:all_x_end] = sketch_mask[mask_y_start:mask_y_end, mask_x_start:mask_x_end]
|
||||
bottom_outer_mask = background + bottom_outer_mask
|
||||
elif bottom is False and top is False:
|
||||
break
|
||||
|
||||
@@ -100,13 +118,13 @@ def synthesis(data, size):
|
||||
if layer['image'] is not None:
|
||||
if layer['name'] != "body":
|
||||
test_image = Image.new('RGBA', size, (0, 0, 0, 0))
|
||||
test_image.paste(layer['image'], (layer['position'][1], layer['position'][0]), layer['image'])
|
||||
# mask_data = np.where(all_mask > 0, 255, 0).astype(np.uint8)
|
||||
# mask_alpha = Image.fromarray(mask_data)
|
||||
# cropped_image = Image.composite(test_image, Image.new("RGBA", test_image.size, (255, 255, 255, 0)), mask_alpha)
|
||||
base_image.paste(test_image, (0, 0), test_image)
|
||||
test_image.paste(layer['image'], (layer['adaptive_position'][1], layer['adaptive_position'][0]), layer['image'])
|
||||
mask_data = np.where(all_mask > 0, 255, 0).astype(np.uint8)
|
||||
mask_alpha = Image.fromarray(mask_data)
|
||||
cropped_image = Image.composite(test_image, Image.new("RGBA", test_image.size, (255, 255, 255, 0)), mask_alpha)
|
||||
base_image.paste(test_image, (0, 0), cropped_image) # test_image 已经按照坐标贴到最大宽值的图片上 坐着这里坐标为00
|
||||
else:
|
||||
base_image.paste(layer['image'], (layer['position'][1], layer['position'][0]), layer['image'])
|
||||
base_image.paste(layer['image'], (layer['adaptive_position'][1], layer['adaptive_position'][0]), layer['image'])
|
||||
|
||||
result_image = base_image
|
||||
|
||||
|
||||
126
app/service/design_batch/design_batch_celery.py
Normal file
126
app/service/design_batch/design_batch_celery.py
Normal file
@@ -0,0 +1,126 @@
|
||||
import logging
|
||||
import threading
|
||||
|
||||
from celery import Celery
|
||||
from minio import Minio
|
||||
|
||||
from app.core.config import *
|
||||
from app.service.design_batch.item import BodyItem, TopItem, BottomItem
|
||||
from app.service.design_batch.utils.MQ import publish_status
|
||||
from app.service.design_batch.utils.organize import organize_body, organize_clothing
|
||||
from app.service.design_batch.utils.save_json import oss_upload_json
|
||||
from app.service.design_batch.utils.synthesis_item import update_base_size_priority, synthesis, synthesis_single
|
||||
|
||||
id_lock = threading.Lock()
|
||||
celery_app = Celery('tasks', broker='amqp://guest:guest@10.1.2.213:5672//', backend='rpc://')
|
||||
celery_app.conf.worker_log_format = '%(asctime)s %(filename)s [line:%(lineno)d] %(levelname)s %(message)s'
|
||||
celery_app.conf.worker_hijack_root_logger = False
|
||||
logging.getLogger('pika').setLevel(logging.WARNING)
|
||||
logger = logging.getLogger()
|
||||
minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||
|
||||
|
||||
def process_item(item, basic):
|
||||
# 处理project中单个item
|
||||
if item['type'] == "Body":
|
||||
body_server = BodyItem(data=item, basic=basic, minio_client=minio_client)
|
||||
item_data = body_server.process()
|
||||
elif item['type'].lower() in ['blouse', 'outwear', 'dress', 'tops']:
|
||||
top_server = TopItem(data=item, basic=basic, minio_client=minio_client)
|
||||
item_data = top_server.process()
|
||||
else:
|
||||
bottom_server = BottomItem(data=item, basic=basic, minio_client=minio_client)
|
||||
item_data = bottom_server.process()
|
||||
return item_data
|
||||
|
||||
|
||||
def process_layer(item, layers):
|
||||
# item处理结束后 对图层数据组装
|
||||
if item['name'] == "mannequin":
|
||||
body_layer = organize_body(item)
|
||||
layers.append(body_layer)
|
||||
return item['body_image'].size
|
||||
else:
|
||||
front_layer, back_layer = organize_clothing(item)
|
||||
layers.append(front_layer)
|
||||
layers.append(back_layer)
|
||||
|
||||
|
||||
@celery_app.task
|
||||
def batch_design(objects_data, tasks_id, json_name):
|
||||
object_response = []
|
||||
threads = []
|
||||
active_threads = 0
|
||||
lock = threading.Lock()
|
||||
|
||||
def process_object(step, object):
|
||||
nonlocal active_threads
|
||||
basic = object['basic']
|
||||
items_response = {'layers': []}
|
||||
if basic['single_overall'] == "overall":
|
||||
item_results = []
|
||||
for item in object['items']:
|
||||
item_results.append(process_item(item, basic))
|
||||
layers = []
|
||||
body_size = None
|
||||
for item in item_results:
|
||||
body_size = process_layer(item, layers)
|
||||
layers = sorted(layers, key=lambda s: s.get("priority", float('inf')))
|
||||
|
||||
layers, new_size = update_base_size_priority(layers, body_size)
|
||||
|
||||
for lay in layers:
|
||||
items_response['layers'].append({
|
||||
'image_category': lay['name'],
|
||||
'position': lay['position'],
|
||||
'priority': lay.get("priority", None),
|
||||
'resize_scale': lay['resize_scale'] if "resize_scale" in lay.keys() else None,
|
||||
'image_size': lay['image'] if lay['image'] is None else lay['image'].size,
|
||||
'gradient_string': lay['gradient_string'] if 'gradient_string' in lay.keys() else "",
|
||||
'mask_url': lay['mask_url'],
|
||||
'image_url': lay['image_url'] if 'image_url' in lay.keys() else None,
|
||||
'pattern_image_url': lay['pattern_image_url'] if 'pattern_image_url' in lay.keys() else None,
|
||||
})
|
||||
items_response['synthesis_url'] = synthesis(layers, new_size, basic)
|
||||
else:
|
||||
item_result = process_item(object['items'][0], basic)
|
||||
items_response['layers'].append({
|
||||
'image_category': f"{item_result['name']}_front",
|
||||
'image_size': item_result['back_image'].size if item_result['back_image'] else None,
|
||||
'position': None,
|
||||
'priority': 0,
|
||||
'image_url': item_result['front_image_url'],
|
||||
'mask_url': item_result['mask_url'],
|
||||
"gradient_string": item_result['gradient_string'] if 'gradient_string' in item_result.keys() else "",
|
||||
'pattern_image_url': item_result['pattern_image_url'] if 'pattern_image_url' in item_result.keys() else None,
|
||||
})
|
||||
items_response['layers'].append({
|
||||
'image_category': f"{item_result['name']}_back",
|
||||
'image_size': item_result['front_image'].size if item_result['front_image'] else None,
|
||||
'position': None,
|
||||
'priority': 0,
|
||||
'image_url': item_result['back_image_url'],
|
||||
'mask_url': item_result['mask_url'],
|
||||
"gradient_string": item_result['gradient_string'] if 'gradient_string' in item_result.keys() else "",
|
||||
'pattern_image_url': item_result['pattern_image_url'] if 'pattern_image_url' in item_result.keys() else None,
|
||||
})
|
||||
items_response['synthesis_url'] = synthesis_single(item_result['front_image'], item_result['back_image'])
|
||||
|
||||
with lock:
|
||||
object_response.append(items_response)
|
||||
publish_status(tasks_id, step + 1, items_response)
|
||||
active_threads -= 1
|
||||
|
||||
for step, object in enumerate(objects_data):
|
||||
t = threading.Thread(target=process_object, args=(step, object))
|
||||
threads.append(t)
|
||||
t.start()
|
||||
with lock:
|
||||
active_threads += 1
|
||||
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
oss_upload_json(minio_client, object_response, json_name)
|
||||
publish_status(tasks_id, "ok", json_name)
|
||||
return object_response
|
||||
61
app/service/design_batch/item.py
Normal file
61
app/service/design_batch/item.py
Normal file
@@ -0,0 +1,61 @@
|
||||
from app.service.design_batch.pipeline import *
|
||||
|
||||
|
||||
class BaseItem:
|
||||
def __init__(self, data, basic):
|
||||
self.result = data.copy()
|
||||
self.result['name'] = data['type'].lower()
|
||||
self.result.pop("type")
|
||||
self.result.update(basic)
|
||||
|
||||
|
||||
class TopItem(BaseItem):
|
||||
def __init__(self, data, basic, minio_client):
|
||||
super().__init__(data, basic)
|
||||
self.top_pipeline = [
|
||||
LoadImage(minio_client),
|
||||
KeyPoint(),
|
||||
Segmentation(minio_client),
|
||||
Color(minio_client),
|
||||
PrintPainting(minio_client),
|
||||
Scaling(),
|
||||
Split(minio_client)
|
||||
]
|
||||
|
||||
def process(self):
|
||||
for item in self.top_pipeline:
|
||||
self.result = item(self.result)
|
||||
return self.result
|
||||
|
||||
|
||||
class BottomItem(BaseItem):
|
||||
def __init__(self, data, basic, minio_client):
|
||||
super().__init__(data, basic)
|
||||
self.bottom_pipeline = [
|
||||
LoadImage(minio_client),
|
||||
KeyPoint(),
|
||||
ContourDetection(),
|
||||
# Segmentation(),
|
||||
Color(minio_client),
|
||||
PrintPainting(minio_client),
|
||||
Scaling(),
|
||||
Split(minio_client)
|
||||
]
|
||||
|
||||
def process(self):
|
||||
for item in self.bottom_pipeline:
|
||||
self.result = item(self.result)
|
||||
return self.result
|
||||
|
||||
|
||||
class BodyItem(BaseItem):
|
||||
def __init__(self, data, basic, minio_client):
|
||||
super().__init__(data, basic)
|
||||
self.top_pipeline = [
|
||||
LoadBodyImage(minio_client),
|
||||
]
|
||||
|
||||
def process(self):
|
||||
for item in self.top_pipeline:
|
||||
self.result = item(self.result)
|
||||
return self.result
|
||||
20
app/service/design_batch/pipeline/__init__.py
Normal file
20
app/service/design_batch/pipeline/__init__.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from .color import Color
|
||||
from .contour_detection import ContourDetection
|
||||
from .keypoint import KeyPoint
|
||||
from .keypoint import KeyPoint
|
||||
from .loading import LoadImage, LoadBodyImage
|
||||
from .print_painting import PrintPainting
|
||||
from .scale import Scaling
|
||||
from .segmentation import Segmentation
|
||||
from .split import Split
|
||||
|
||||
__all__ = [
|
||||
'LoadBodyImage', 'LoadImage',
|
||||
'KeyPoint',
|
||||
'ContourDetection',
|
||||
'Segmentation',
|
||||
'Color',
|
||||
'PrintPainting',
|
||||
'Scaling',
|
||||
'Split'
|
||||
]
|
||||
62
app/service/design_batch/pipeline/color.py
Normal file
62
app/service/design_batch/pipeline/color.py
Normal file
@@ -0,0 +1,62 @@
|
||||
import logging
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from app.service.utils.new_oss_client import oss_get_image
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
class Color:
|
||||
def __init__(self, minio_client):
|
||||
self.minio_client = minio_client
|
||||
|
||||
def __call__(self, result):
|
||||
dim_image_h, dim_image_w = result['image'].shape[0:2]
|
||||
if "gradient" in result.keys() and result['gradient'] != "":
|
||||
bucket_name = result['gradient'].split('/')[0]
|
||||
object_name = result['gradient'][result['gradient'].find('/') + 1:]
|
||||
pattern = self.get_gradient(bucket_name=bucket_name, object_name=object_name)
|
||||
resize_pattern = cv2.resize(pattern, (dim_image_w, dim_image_h), interpolation=cv2.INTER_AREA)
|
||||
else:
|
||||
pattern = self.get_pattern(result['color'])
|
||||
resize_pattern = cv2.resize(pattern, (dim_image_w, dim_image_h), interpolation=cv2.INTER_AREA)
|
||||
closed_mo = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
|
||||
gray_mo = np.expand_dims(result['gray'], axis=2).repeat(3, axis=2)
|
||||
get_image_fir = resize_pattern * (closed_mo / 255) * (gray_mo / 255)
|
||||
result['pattern_image'] = get_image_fir.astype(np.uint8)
|
||||
result['final_image'] = result['pattern_image']
|
||||
canvas = np.full_like(result['final_image'], 255)
|
||||
temp_bg = np.expand_dims(cv2.bitwise_not(result['mask']), axis=2).repeat(3, axis=2)
|
||||
tmp1 = (canvas * (temp_bg / 255)).astype(np.uint8)
|
||||
temp_fg = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
|
||||
tmp2 = (result['final_image'] * (temp_fg / 255)).astype(np.uint8)
|
||||
result['single_image'] = cv2.add(tmp1, tmp2)
|
||||
result['alpha'] = 100 / 255.0
|
||||
return result
|
||||
|
||||
def get_gradient(self, bucket_name, object_name):
|
||||
# 获取渐变色图案
|
||||
image = oss_get_image(oss_client=self.minio_client, bucket=bucket_name, object_name=object_name, data_type="cv2")
|
||||
if image.shape[2] == 4:
|
||||
image = cv2.cvtColor(image, cv2.COLOR_BGRA2BGR)
|
||||
return image
|
||||
|
||||
@staticmethod
|
||||
def crop_image(image, image_size_h, image_size_w):
|
||||
x_offset = np.random.randint(low=0, high=int(image_size_h / 5) - 6)
|
||||
y_offset = np.random.randint(low=0, high=int(image_size_w / 5) - 6)
|
||||
image = image[x_offset: x_offset + image_size_h, y_offset: y_offset + image_size_w, :]
|
||||
return image
|
||||
|
||||
@staticmethod
|
||||
def get_pattern(single_color):
|
||||
if single_color is None:
|
||||
raise False
|
||||
R, G, B = single_color.split(' ')
|
||||
pattern = np.zeros([1, 1, 3], np.uint8)
|
||||
pattern[0, 0, 0] = int(B)
|
||||
pattern[0, 0, 1] = int(G)
|
||||
pattern[0, 0, 2] = int(R)
|
||||
return pattern
|
||||
37
app/service/design_batch/pipeline/contour_detection.py
Normal file
37
app/service/design_batch/pipeline/contour_detection.py
Normal file
@@ -0,0 +1,37 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
|
||||
class ContourDetection:
|
||||
def __call__(self, result):
|
||||
Contour = self.get_contours(result['image'])
|
||||
Mask = np.zeros(result['image'].shape[:2], np.uint8)
|
||||
if len(Contour):
|
||||
Max_contour = Contour[0]
|
||||
Epsilon = 0.001 * cv2.arcLength(Max_contour, True)
|
||||
Approx = cv2.approxPolyDP(Max_contour, Epsilon, True)
|
||||
cv2.drawContours(Mask, [Approx], -1, 255, -1)
|
||||
else:
|
||||
Mask = np.ones(result['image'].shape[:2], np.uint8) * 255
|
||||
# TODO 修复部分图片出现透明的情况 下版本上线
|
||||
# img2gray = cv2.cvtColor(result['image'], cv2.COLOR_BGR2GRAY)
|
||||
# ret, Mask = cv2.threshold(img2gray, 126, 255, cv2.THRESH_BINARY)
|
||||
# Mask = cv2.bitwise_not(Mask)
|
||||
if result['pre_mask'] is None:
|
||||
result['mask'] = Mask
|
||||
else:
|
||||
result['mask'] = cv2.bitwise_and(Mask, result['pre_mask'])
|
||||
result['front_mask'] = result['mask']
|
||||
result['back_mask'] = result['mask']
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def get_contours(image):
|
||||
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
|
||||
Edge = cv2.Canny(gray, 10, 150)
|
||||
kernel = np.ones((5, 5), np.uint8)
|
||||
Edge = cv2.dilate(Edge, kernel=kernel, iterations=1)
|
||||
Edge = cv2.erode(Edge, kernel=kernel, iterations=1)
|
||||
Contour, _ = cv2.findContours(Edge, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||
Contour = sorted(Contour, key=cv2.contourArea, reverse=True)
|
||||
return Contour
|
||||
114
app/service/design_batch/pipeline/keypoint.py
Normal file
114
app/service/design_batch/pipeline/keypoint.py
Normal file
@@ -0,0 +1,114 @@
|
||||
import logging
|
||||
|
||||
import numpy as np
|
||||
from pymilvus import MilvusClient
|
||||
|
||||
from app.core.config import *
|
||||
from app.service.design_batch.utils.design_ensemble import get_keypoint_result
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class KeyPoint:
|
||||
name = "KeyPoint"
|
||||
|
||||
@classmethod
|
||||
def get_name(cls):
|
||||
return cls.name
|
||||
|
||||
def __call__(self, result):
|
||||
if result['name'] in ['blouse', 'skirt', 'dress', 'outwear', 'trousers', 'tops', 'bottoms']: # 查询是否有数据 且类别相同 相同则直接读 不同则推理后更新
|
||||
# result['clothes_keypoint'] = self.infer_keypoint_result(result)
|
||||
site = 'up' if result['name'] in ['blouse', 'outwear', 'dress', 'tops'] else 'down'
|
||||
# keypoint_cache = search_keypoint_cache(result["image_id"], site)
|
||||
keypoint_cache = self.keypoint_cache(result, site)
|
||||
# 取消向量查询 直接过模型推理
|
||||
# keypoint_cache = False
|
||||
if keypoint_cache is False:
|
||||
keypoint_infer_result, site = self.infer_keypoint_result(result)
|
||||
result['clothes_keypoint'] = self.save_keypoint_cache(result["image_id"], keypoint_infer_result, site)
|
||||
else:
|
||||
result['clothes_keypoint'] = keypoint_cache
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def infer_keypoint_result(result):
|
||||
site = 'up' if result['name'] in ['blouse', 'outwear', 'dress', 'tops'] else 'down'
|
||||
keypoint_infer_result = get_keypoint_result(result["image"], site) # 推理结果
|
||||
return keypoint_infer_result, site
|
||||
|
||||
@staticmethod
|
||||
def save_keypoint_cache(keypoint_id, cache, site):
|
||||
if site == "down":
|
||||
zeros = np.zeros(20, dtype=int)
|
||||
result = np.concatenate([zeros, cache.flatten()])
|
||||
else:
|
||||
zeros = np.zeros(4, dtype=int)
|
||||
result = np.concatenate([cache.flatten(), zeros])
|
||||
# 取消向量保存 直接拿结果
|
||||
data = [
|
||||
{"keypoint_id": keypoint_id,
|
||||
"keypoint_site": site,
|
||||
"keypoint_vector": result.tolist()
|
||||
}
|
||||
]
|
||||
try:
|
||||
client = MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS)
|
||||
res = client.upsert(collection_name=MILVUS_TABLE_KEYPOINT, data=data)
|
||||
client.close()
|
||||
return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
|
||||
except Exception as e:
|
||||
logger.info(f"save keypoint cache milvus error : {e}")
|
||||
return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
|
||||
|
||||
@staticmethod
|
||||
def update_keypoint_cache(keypoint_id, infer_result, search_result, site):
|
||||
if site == "up":
|
||||
# 需要的是up 即推理出来的是up 那么查询的就是down
|
||||
result = np.concatenate([infer_result.flatten(), search_result[-4:]])
|
||||
else:
|
||||
# 需要的是down 即推理出来的是down 那么查询的就是up
|
||||
result = np.concatenate([search_result[:20], infer_result.flatten()])
|
||||
data = [
|
||||
{"keypoint_id": keypoint_id,
|
||||
"keypoint_site": "all",
|
||||
"keypoint_vector": result.tolist()
|
||||
}
|
||||
]
|
||||
|
||||
try:
|
||||
client = MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS)
|
||||
client.upsert(
|
||||
collection_name=MILVUS_TABLE_KEYPOINT,
|
||||
data=data
|
||||
)
|
||||
return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
|
||||
except Exception as e:
|
||||
logger.info(f"save keypoint cache milvus error : {e}")
|
||||
return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
|
||||
|
||||
# @ RunTime
|
||||
def keypoint_cache(self, result, site):
|
||||
try:
|
||||
client = MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS)
|
||||
keypoint_id = result['image_id']
|
||||
res = client.query(
|
||||
collection_name=MILVUS_TABLE_KEYPOINT,
|
||||
# ids=[keypoint_id],
|
||||
filter=f"keypoint_id == {keypoint_id}",
|
||||
output_fields=['keypoint_vector', 'keypoint_site']
|
||||
)
|
||||
if len(res) == 0:
|
||||
# 没有结果 直接推理拿结果 并保存
|
||||
keypoint_infer_result, site = self.infer_keypoint_result(result)
|
||||
return self.save_keypoint_cache(result['image_id'], keypoint_infer_result, site)
|
||||
elif res[0]["keypoint_site"] == "all" or res[0]["keypoint_site"] == site:
|
||||
# 需要的类型和查询的类型一致,或者查询的类型为all 则直接返回查询的结果
|
||||
return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, np.array(res[0]['keypoint_vector']).astype(int).reshape(12, 2).tolist()))
|
||||
elif res[0]["keypoint_site"] != site:
|
||||
# 需要的类型和查询到的不一致,则更新类型为all
|
||||
keypoint_infer_result, site = self.infer_keypoint_result(result)
|
||||
return self.update_keypoint_cache(result["image_id"], keypoint_infer_result, res[0]['keypoint_vector'], site)
|
||||
except Exception as e:
|
||||
logger.info(f"search keypoint cache milvus error {e}")
|
||||
return False
|
||||
77
app/service/design_batch/pipeline/loading.py
Normal file
77
app/service/design_batch/pipeline/loading.py
Normal file
@@ -0,0 +1,77 @@
|
||||
import logging
|
||||
|
||||
import cv2
|
||||
|
||||
from app.service.utils.new_oss_client import oss_get_image
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
class LoadBodyImage:
|
||||
name = "LoadBodyImage"
|
||||
|
||||
def __init__(self, minio_client):
|
||||
self.minio_client = minio_client
|
||||
|
||||
@classmethod
|
||||
def get_name(cls):
|
||||
return cls.name
|
||||
|
||||
def __call__(self, result):
|
||||
result["name"] = "mannequin"
|
||||
result['body_image'] = oss_get_image(oss_client=self.minio_client, bucket=result['body_path'].split("/", 1)[0], object_name=result['body_path'].split("/", 1)[1], data_type="PIL")
|
||||
return result
|
||||
|
||||
|
||||
class LoadImage:
|
||||
name = "LoadImage"
|
||||
|
||||
def __init__(self, minio_client):
|
||||
self.minio_client = minio_client
|
||||
|
||||
@classmethod
|
||||
def get_name(cls):
|
||||
return cls.name
|
||||
|
||||
def __call__(self, result):
|
||||
result['image'], result['pre_mask'] = self.read_image(result['path'])
|
||||
result['gray'] = cv2.cvtColor(result['image'], cv2.COLOR_BGR2GRAY)
|
||||
result['keypoint'] = self.get_keypoint(result['name'])
|
||||
result['img_shape'] = result['image'].shape
|
||||
result['ori_shape'] = result['image'].shape
|
||||
return result
|
||||
|
||||
def read_image(self, image_path):
|
||||
image_mask = None
|
||||
image = oss_get_image(oss_client=self.minio_client, bucket=image_path.split("/", 1)[0], object_name=image_path.split("/", 1)[1], data_type="cv2")
|
||||
if len(image.shape) == 2:
|
||||
image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
|
||||
if image.shape[2] == 4: # 如果是四通道 mask
|
||||
image_mask = image[:, :, 3]
|
||||
image = image[:, :, :3]
|
||||
|
||||
if image.shape[:2] <= (50, 50):
|
||||
# 计算新尺寸
|
||||
new_size = (image.shape[1] * 2, image.shape[0] * 2)
|
||||
# 调整大小
|
||||
image = cv2.resize(image, new_size, interpolation=cv2.INTER_LINEAR)
|
||||
return image, image_mask
|
||||
|
||||
@staticmethod
|
||||
def get_keypoint(name):
|
||||
if name == 'blouse' or name == 'outwear' or name == 'dress' or name == 'tops':
|
||||
keypoint = 'shoulder'
|
||||
elif name == 'trousers' or name == 'skirt' or name == 'bottoms':
|
||||
keypoint = 'waistband'
|
||||
elif name == 'bag':
|
||||
keypoint = 'hand_point'
|
||||
elif name == 'shoes':
|
||||
keypoint = 'toe'
|
||||
elif name == 'hairstyle':
|
||||
keypoint = 'head_point'
|
||||
elif name == 'earring':
|
||||
keypoint = 'ear_point'
|
||||
else:
|
||||
raise KeyError(f"{name} does not belong to item category list: blouse, outwear, dress, trousers, skirt, "
|
||||
f"bag, shoes, hairstyle, earring.")
|
||||
return keypoint
|
||||
524
app/service/design_batch/pipeline/print_painting.py
Normal file
524
app/service/design_batch/pipeline/print_painting.py
Normal file
@@ -0,0 +1,524 @@
|
||||
import random
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from app.service.utils.new_oss_client import oss_get_image
|
||||
|
||||
|
||||
class PrintPainting:
|
||||
def __init__(self, minio_client):
|
||||
self.minio_client = minio_client
|
||||
|
||||
def __call__(self, result):
|
||||
single_print = result['print']['single']
|
||||
overall_print = result['print']['overall']
|
||||
element_print = result['print']['element']
|
||||
result['single_image'] = None
|
||||
result['print_image'] = None
|
||||
if overall_print['print_path_list']:
|
||||
painting_dict = {'dim_image_h': result['pattern_image'].shape[0], 'dim_image_w': result['pattern_image'].shape[1]}
|
||||
result['print_image'] = result['pattern_image']
|
||||
if "print_angle_list" in overall_print.keys() and overall_print['print_angle_list'][0] != 0:
|
||||
painting_dict = self.painting_collection(painting_dict, overall_print, print_trigger=True)
|
||||
painting_dict['tile_print'] = self.rotate_crop_image(img=painting_dict['tile_print'], angle=-overall_print['print_angle_list'][0], crop=True)
|
||||
painting_dict['mask_inv_print'] = self.rotate_crop_image(img=painting_dict['mask_inv_print'], angle=-overall_print['print_angle_list'][0], crop=True)
|
||||
|
||||
# resize 到sketch大小
|
||||
painting_dict['tile_print'] = self.resize_and_crop(img=painting_dict['tile_print'], target_width=painting_dict['dim_image_w'], target_height=painting_dict['dim_image_h'])
|
||||
painting_dict['mask_inv_print'] = self.resize_and_crop(img=painting_dict['mask_inv_print'], target_width=painting_dict['dim_image_w'], target_height=painting_dict['dim_image_h'])
|
||||
else:
|
||||
painting_dict = self.painting_collection(painting_dict, overall_print, print_trigger=True, is_single=False)
|
||||
result['print_image'] = self.printpaint(result, painting_dict, print_=True)
|
||||
result['single_image'] = result['final_image'] = result['pattern_image'] = result['print_image']
|
||||
|
||||
if single_print['print_path_list']:
|
||||
print_background = np.zeros((result['pattern_image'].shape[0], result['pattern_image'].shape[1], 3), dtype=np.uint8)
|
||||
mask_background = np.zeros((result['pattern_image'].shape[0], result['pattern_image'].shape[1], 3), dtype=np.uint8)
|
||||
for i in range(len(single_print['print_path_list'])):
|
||||
image, image_mode = self.read_image(single_print['print_path_list'][i])
|
||||
if image_mode == "RGBA":
|
||||
new_size = (int(image.width * single_print['print_scale_list'][i]), int(image.height * single_print['print_scale_list'][i]))
|
||||
|
||||
mask = image.split()[3]
|
||||
resized_source = image.resize(new_size)
|
||||
resized_source_mask = mask.resize(new_size)
|
||||
|
||||
rotated_resized_source = resized_source.rotate(-single_print['print_angle_list'][i])
|
||||
rotated_resized_source_mask = resized_source_mask.rotate(-single_print['print_angle_list'][i])
|
||||
|
||||
source_image_pil = Image.fromarray(cv2.cvtColor(print_background, cv2.COLOR_BGR2RGB))
|
||||
source_image_pil_mask = Image.fromarray(cv2.cvtColor(mask_background, cv2.COLOR_BGR2RGB))
|
||||
|
||||
source_image_pil.paste(rotated_resized_source, (int(single_print['location'][i][0]), int(single_print['location'][i][1])), rotated_resized_source)
|
||||
source_image_pil_mask.paste(rotated_resized_source_mask, (int(single_print['location'][i][0]), int(single_print['location'][i][1])), rotated_resized_source_mask)
|
||||
|
||||
print_background = cv2.cvtColor(np.array(source_image_pil), cv2.COLOR_RGBA2BGR)
|
||||
mask_background = cv2.cvtColor(np.array(source_image_pil_mask), cv2.COLOR_RGBA2BGR)
|
||||
ret, mask_background = cv2.threshold(mask_background, 124, 255, cv2.THRESH_BINARY)
|
||||
else:
|
||||
mask = self.get_mask_inv(image)
|
||||
mask = np.expand_dims(mask, axis=2)
|
||||
mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR)
|
||||
mask = cv2.bitwise_not(mask)
|
||||
# 旋转后的坐标需要重新算
|
||||
rotate_mask, _ = self.img_rotate(mask, single_print['print_angle_list'][i], single_print['print_scale_list'][i])
|
||||
rotate_image, rotated_new_size = self.img_rotate(image, single_print['print_angle_list'][i], single_print['print_scale_list'][i])
|
||||
# x, y = int(result['print']['location'][i][0] - rotated_new_size[0] - (rotate_mask.shape[0] - image.shape[0]) / 2), int(result['print']['location'][i][1] - rotated_new_size[1] - (rotate_mask.shape[1] - image.shape[1]) / 2)
|
||||
x, y = int(single_print['location'][i][0] - rotated_new_size[0]), int(single_print['location'][i][1] - rotated_new_size[1])
|
||||
|
||||
image_x = print_background.shape[1]
|
||||
image_y = print_background.shape[0]
|
||||
print_x = rotate_image.shape[1]
|
||||
print_y = rotate_image.shape[0]
|
||||
|
||||
# 有bug
|
||||
# if x + print_x > image_x:
|
||||
# rotate_image = rotate_image[:, :x + print_x - image_x]
|
||||
# rotate_mask = rotate_mask[:, :x + print_x - image_x]
|
||||
# #
|
||||
# if y + print_y > image_y:
|
||||
# rotate_image = rotate_image[:y + print_y - image_y]
|
||||
# rotate_mask = rotate_mask[:y + print_y - image_y]
|
||||
|
||||
# 不能是并行
|
||||
# 当前第一轮的if (108以及115)是判断有没有过下界和右界。第二轮的是判断左上有没有超出。 如果这个样子的话,先裁了右边,再左移,region就会有问题
|
||||
# 先挪 再判断 最后裁剪
|
||||
|
||||
# 如果print旋转了 或者 print贴边了 则需要判断 判断左界和上界是否小于0
|
||||
if x <= 0:
|
||||
rotate_image = rotate_image[:, -x:]
|
||||
rotate_mask = rotate_mask[:, -x:]
|
||||
start_x = x = 0
|
||||
else:
|
||||
start_x = x
|
||||
|
||||
if y <= 0:
|
||||
rotate_image = rotate_image[-y:, :]
|
||||
rotate_mask = rotate_mask[-y:, :]
|
||||
start_y = y = 0
|
||||
else:
|
||||
start_y = y
|
||||
|
||||
# ------------------
|
||||
# 如果print-size大于image-size 则需要裁剪print
|
||||
|
||||
if x + print_x > image_x:
|
||||
rotate_image = rotate_image[:, :image_x - x]
|
||||
rotate_mask = rotate_mask[:, :image_x - x]
|
||||
|
||||
if y + print_y > image_y:
|
||||
rotate_image = rotate_image[:image_y - y, :]
|
||||
rotate_mask = rotate_mask[:image_y - y, :]
|
||||
|
||||
# mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]] = cv2.bitwise_xor(mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]], rotate_mask)
|
||||
# print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = cv2.add(print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]], rotate_image)
|
||||
|
||||
# mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]] = rotate_mask
|
||||
# print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = rotate_image
|
||||
mask_background = self.stack_prin(mask_background, result['pattern_image'], rotate_mask, start_y, y, start_x, x)
|
||||
print_background = self.stack_prin(print_background, result['pattern_image'], rotate_image, start_y, y, start_x, x)
|
||||
|
||||
# gray_image = cv2.cvtColor(mask_background, cv2.COLOR_BGR2GRAY)
|
||||
# print_background = cv2.bitwise_and(print_background, print_background, mask=gray_image)
|
||||
|
||||
print_mask = cv2.bitwise_and(result['mask'], cv2.cvtColor(mask_background, cv2.COLOR_BGR2GRAY))
|
||||
img_fg = cv2.bitwise_or(print_background, print_background, mask=print_mask)
|
||||
img_bg = cv2.bitwise_and(result['pattern_image'], result['pattern_image'], mask=cv2.bitwise_not(print_mask))
|
||||
mask_mo = np.expand_dims(print_mask, axis=2).repeat(3, axis=2)
|
||||
gray_mo = np.expand_dims(result['gray'], axis=2).repeat(3, axis=2)
|
||||
img_fg = (img_fg * (mask_mo / 255) * (gray_mo / 255)).astype(np.uint8)
|
||||
result['final_image'] = cv2.add(img_bg, img_fg)
|
||||
canvas = np.full_like(result['final_image'], 255)
|
||||
temp_bg = np.expand_dims(cv2.bitwise_not(result['mask']), axis=2).repeat(3, axis=2)
|
||||
tmp1 = (canvas * (temp_bg / 255)).astype(np.uint8)
|
||||
temp_fg = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
|
||||
tmp2 = (result['final_image'] * (temp_fg / 255)).astype(np.uint8)
|
||||
result['single_image'] = cv2.add(tmp1, tmp2)
|
||||
|
||||
if element_print['element_path_list']:
|
||||
print_background = np.zeros((result['final_image'].shape[0], result['final_image'].shape[1], 3), dtype=np.uint8)
|
||||
mask_background = np.zeros((result['final_image'].shape[0], result['final_image'].shape[1], 3), dtype=np.uint8)
|
||||
for i in range(len(element_print['element_path_list'])):
|
||||
image, image_mode = self.read_image(element_print['element_path_list'][i])
|
||||
if image_mode == "RGBA":
|
||||
new_size = (int(image.width * element_print['element_scale_list'][i]), int(image.height * element_print['element_scale_list'][i]))
|
||||
|
||||
mask = image.split()[3]
|
||||
resized_source = image.resize(new_size)
|
||||
resized_source_mask = mask.resize(new_size)
|
||||
|
||||
rotated_resized_source = resized_source.rotate(-element_print['element_angle_list'][i])
|
||||
rotated_resized_source_mask = resized_source_mask.rotate(-element_print['element_angle_list'][i])
|
||||
|
||||
source_image_pil = Image.fromarray(cv2.cvtColor(print_background, cv2.COLOR_BGR2RGB))
|
||||
source_image_pil_mask = Image.fromarray(cv2.cvtColor(mask_background, cv2.COLOR_BGR2RGB))
|
||||
|
||||
source_image_pil.paste(rotated_resized_source, (int(element_print['location'][i][0]), int(element_print['location'][i][1])), rotated_resized_source)
|
||||
source_image_pil_mask.paste(rotated_resized_source_mask, (int(element_print['location'][i][0]), int(element_print['location'][i][1])), rotated_resized_source_mask)
|
||||
|
||||
print_background = cv2.cvtColor(np.array(source_image_pil), cv2.COLOR_RGBA2BGR)
|
||||
mask_background = cv2.cvtColor(np.array(source_image_pil_mask), cv2.COLOR_RGBA2BGR)
|
||||
else:
|
||||
mask = self.get_mask_inv(image)
|
||||
mask = np.expand_dims(mask, axis=2)
|
||||
mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR)
|
||||
mask = cv2.bitwise_not(mask)
|
||||
# 旋转后的坐标需要重新算
|
||||
rotate_mask, _ = self.img_rotate(mask, element_print['element_angle_list'][i], element_print['element_scale_list'][i])
|
||||
rotate_image, rotated_new_size = self.img_rotate(image, element_print['element_angle_list'][i], element_print['element_scale_list'][i])
|
||||
# x, y = int(result['print']['location'][i][0] - rotated_new_size[0] - (rotate_mask.shape[0] - image.shape[0]) / 2), int(result['print']['location'][i][1] - rotated_new_size[1] - (rotate_mask.shape[1] - image.shape[1]) / 2)
|
||||
x, y = int(element_print['location'][i][0] - rotated_new_size[0]), int(element_print['location'][i][1] - rotated_new_size[1])
|
||||
|
||||
image_x = print_background.shape[1]
|
||||
image_y = print_background.shape[0]
|
||||
print_x = rotate_image.shape[1]
|
||||
print_y = rotate_image.shape[0]
|
||||
|
||||
# 有bug
|
||||
# if x + print_x > image_x:
|
||||
# rotate_image = rotate_image[:, :x + print_x - image_x]
|
||||
# rotate_mask = rotate_mask[:, :x + print_x - image_x]
|
||||
# #
|
||||
# if y + print_y > image_y:
|
||||
# rotate_image = rotate_image[:y + print_y - image_y]
|
||||
# rotate_mask = rotate_mask[:y + print_y - image_y]
|
||||
|
||||
# 不能是并行
|
||||
# 当前第一轮的if (108以及115)是判断有没有过下界和右界。第二轮的是判断左上有没有超出。 如果这个样子的话,先裁了右边,再左移,region就会有问题
|
||||
# 先挪 再判断 最后裁剪
|
||||
|
||||
# 如果print旋转了 或者 print贴边了 则需要判断 判断左界和上界是否小于0
|
||||
if x <= 0:
|
||||
rotate_image = rotate_image[:, -x:]
|
||||
rotate_mask = rotate_mask[:, -x:]
|
||||
start_x = x = 0
|
||||
else:
|
||||
start_x = x
|
||||
|
||||
if y <= 0:
|
||||
rotate_image = rotate_image[-y:, :]
|
||||
rotate_mask = rotate_mask[-y:, :]
|
||||
start_y = y = 0
|
||||
else:
|
||||
start_y = y
|
||||
|
||||
# ------------------
|
||||
# 如果print-size大于image-size 则需要裁剪print
|
||||
|
||||
if x + print_x > image_x:
|
||||
rotate_image = rotate_image[:, :image_x - x]
|
||||
rotate_mask = rotate_mask[:, :image_x - x]
|
||||
|
||||
if y + print_y > image_y:
|
||||
rotate_image = rotate_image[:image_y - y, :]
|
||||
rotate_mask = rotate_mask[:image_y - y, :]
|
||||
|
||||
# mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]] = cv2.bitwise_xor(mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]], rotate_mask)
|
||||
# print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = cv2.add(print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]], rotate_image)
|
||||
|
||||
# mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]] = rotate_mask
|
||||
# print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = rotate_image
|
||||
mask_background = self.stack_prin(mask_background, result['pattern_image'], rotate_mask, start_y, y, start_x, x)
|
||||
print_background = self.stack_prin(print_background, result['pattern_image'], rotate_image, start_y, y, start_x, x)
|
||||
|
||||
# gray_image = cv2.cvtColor(mask_background, cv2.COLOR_BGR2GRAY)
|
||||
# print_background = cv2.bitwise_and(print_background, print_background, mask=gray_image)
|
||||
|
||||
print_mask = cv2.bitwise_and(result['mask'], cv2.cvtColor(mask_background, cv2.COLOR_BGR2GRAY))
|
||||
img_fg = cv2.bitwise_or(print_background, print_background, mask=print_mask)
|
||||
# TODO element 丢失信息
|
||||
three_channel_image = cv2.merge([cv2.bitwise_not(print_mask), cv2.bitwise_not(print_mask), cv2.bitwise_not(print_mask)])
|
||||
img_bg = cv2.bitwise_and(result['final_image'], three_channel_image)
|
||||
# mask_mo = np.expand_dims(print_mask, axis=2).repeat(3, axis=2)
|
||||
# gray_mo = np.expand_dims(result['gray'], axis=2).repeat(3, axis=2)
|
||||
# img_fg = (img_fg * (mask_mo / 255) * (gray_mo / 255)).astype(np.uint8)
|
||||
result['final_image'] = cv2.add(img_bg, img_fg)
|
||||
canvas = np.full_like(result['final_image'], 255)
|
||||
temp_bg = np.expand_dims(cv2.bitwise_not(result['mask']), axis=2).repeat(3, axis=2)
|
||||
tmp1 = (canvas * (temp_bg / 255)).astype(np.uint8)
|
||||
temp_fg = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
|
||||
tmp2 = (result['final_image'] * (temp_fg / 255)).astype(np.uint8)
|
||||
result['single_image'] = cv2.add(tmp1, tmp2)
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def stack_prin(print_background, pattern_image, rotate_image, start_y, y, start_x, x):
|
||||
temp_print = np.zeros((pattern_image.shape[0], pattern_image.shape[1], 3), dtype=np.uint8)
|
||||
temp_print[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = rotate_image
|
||||
img2gray = cv2.cvtColor(temp_print, cv2.COLOR_BGR2GRAY)
|
||||
ret, mask_ = cv2.threshold(img2gray, 1, 255, cv2.THRESH_BINARY)
|
||||
mask_inv = cv2.bitwise_not(mask_)
|
||||
img1_bg = cv2.bitwise_and(print_background, print_background, mask=mask_inv)
|
||||
img2_fg = cv2.bitwise_and(temp_print, temp_print, mask=mask_)
|
||||
print_background = img1_bg + img2_fg
|
||||
return print_background
|
||||
|
||||
def painting_collection(self, painting_dict, print_dict, print_trigger=False, is_single=False):
|
||||
if print_trigger:
|
||||
print_ = self.get_print(print_dict)
|
||||
painting_dict['Trigger'] = not is_single
|
||||
painting_dict['location'] = print_['location']
|
||||
single_mask_inv_print = self.get_mask_inv(print_['image'])
|
||||
dim_max = max(painting_dict['dim_image_h'], painting_dict['dim_image_w'])
|
||||
dim_pattern = (int(dim_max * print_['scale'] / 5), int(dim_max * print_['scale'] / 5))
|
||||
if not is_single:
|
||||
self.random_seed = random.randint(0, 1000)
|
||||
# 如果print 模式为overall 且 有角度的话 , 组合的print为正方形,方便裁剪
|
||||
if "print_angle_list" in print_dict.keys() and print_dict['print_angle_list'][0] != 0:
|
||||
painting_dict['mask_inv_print'] = self.tile_image(single_mask_inv_print, dim_pattern, print_['scale'], dim_max, dim_max, painting_dict['location'], trigger=True)
|
||||
painting_dict['tile_print'] = self.tile_image(print_['image'], dim_pattern, print_['scale'], dim_max, dim_max, painting_dict['location'], trigger=True)
|
||||
else:
|
||||
painting_dict['mask_inv_print'] = self.tile_image(single_mask_inv_print, dim_pattern, print_['scale'], painting_dict['dim_image_h'], painting_dict['dim_image_w'], painting_dict['location'], trigger=True)
|
||||
painting_dict['tile_print'] = self.tile_image(print_['image'], dim_pattern, print_['scale'], painting_dict['dim_image_h'], painting_dict['dim_image_w'], painting_dict['location'], trigger=True)
|
||||
else:
|
||||
painting_dict['mask_inv_print'] = self.tile_image(single_mask_inv_print, dim_pattern, print_['scale'], painting_dict['dim_image_h'], painting_dict['dim_image_w'], painting_dict['location'])
|
||||
painting_dict['tile_print'] = self.tile_image(print_['image'], dim_pattern, print_['scale'], painting_dict['dim_image_h'], painting_dict['dim_image_w'], painting_dict['location'])
|
||||
painting_dict['dim_print_h'], painting_dict['dim_print_w'] = dim_pattern
|
||||
return painting_dict
|
||||
|
||||
def tile_image(self, pattern, dim, scale, dim_image_h, dim_image_w, location, trigger=False):
    """Resize ``pattern`` to ``dim``; when ``trigger`` is set, additionally
    tile it across a large canvas and crop back to (dim_image_h, dim_image_w).
    """
    if not trigger:
        # Single-placement mode: a plain resize is all that is needed.
        return cv2.resize(pattern, dim, interpolation=cv2.INTER_AREA)

    resized = cv2.resize(pattern, dim, interpolation=cv2.INTER_AREA)
    # Enough repetitions to cover the canvas with margin for the crop offset.
    reps = int((5 + 1) / scale) + 4
    tiled = None
    if len(pattern.shape) == 2:
        tiled = np.tile(resized, (reps, reps))
    if len(pattern.shape) == 3:
        tiled = np.tile(resized, (reps, reps, 1))
    return self.crop_image(tiled, dim_image_h, dim_image_w, location, resized.shape)
||||
|
||||
def get_mask_inv(self, print_):
    """Return a uint8 mask marking the print's background pixels.

    If the top-left pixel is pure white (BGR 255,255,255) the background is
    assumed white: an inRange mask over a LAB window (+/-30) around that
    pixel's colour is returned. Otherwise the print is treated as having no
    removable background and an all-zero mask is returned.
    """
    corner = print_[0][0]
    if corner[0] == 255 and corner[1] == 255 and corner[2] == 255:
        lab = cv2.cvtColor(print_, cv2.COLOR_BGR2LAB)
        bg_l, bg_a, bg_b = lab[0][0][0], lab[0][0][1], lab[0][0][2]
        l_high, l_low = self.get_low_high_lab(bg_l, L=True)
        a_high, a_low = self.get_low_high_lab(bg_a)
        b_high, b_low = self.get_low_high_lab(bg_b)
        lower = np.array([l_low, a_low, b_low])
        upper = np.array([l_high, a_high, b_high])
        return cv2.inRange(lab, lower, upper)
    # Non-white corner: keep the whole print (empty background mask).
    return np.zeros(print_.shape[:2], dtype=np.uint8)
||||
|
||||
@staticmethod
def printpaint(result, painting_dict, print_=False):
    """Compose the print layer onto the garment image.

    Args:
        result: pipeline dict holding 'mask', 'final_image', 'pattern_image'
            and 'gray' (images are expected to share the garment size).
        painting_dict: output of ``painting_collection`` ('Trigger',
            'mask_inv_print', 'tile_print', 'location', 'dim_print_h/w').
        print_: whether a print should be applied at all.

    Returns:
        The blended BGR image: print foreground over the pattern background,
        modulated by the shading channel 'gray'.

    Raises:
        ValueError: in single-print mode when 'location' is missing/unsized.
    """
    if print_ and painting_dict['Trigger']:
        # All-over print: keep print pixels only where the garment mask and
        # the print's foreground (inverse of its background mask) overlap.
        print_mask = cv2.bitwise_and(result['mask'], cv2.bitwise_not(painting_dict['mask_inv_print']))
        img_fg = cv2.bitwise_and(painting_dict['tile_print'], painting_dict['tile_print'], mask=print_mask)
    else:
        print_mask = result['mask']
        img_fg = result['final_image']

    if print_ and not painting_dict['Trigger']:
        # Single-print mode: one paste region per requested location.
        try:
            placements = len(painting_dict['location'])
        except (KeyError, TypeError) as exc:
            # Fixed: this used to be ``assert f'...'`` — a truthy string that
            # never fires — which then crashed later on ``range(None)``.
            raise ValueError("'location' is required when IfSingle is chosen") from exc

        for i in range(placements):
            start_h, start_w = int(painting_dict['location'][i][1]), int(painting_dict['location'][i][0])
            # Clamp the paste region to the image bounds.
            length_h = min(start_h + painting_dict['dim_print_h'], img_fg.shape[0])
            length_w = min(start_w + painting_dict['dim_print_w'], img_fg.shape[1])
            # NOTE(review): copying a view of img_fg back onto the same
            # region is effectively a no-op; presumably the print tile was
            # meant to be written here — confirm intended behaviour.
            change_region = img_fg[start_h: length_h, start_w: length_w, :]
            img_fg[start_h:start_h + painting_dict['dim_print_h'], start_w:start_w + painting_dict['dim_print_w'], :] = change_region

    # Background: the pattern image wherever the print mask is off.
    clothes_mask_print = cv2.bitwise_not(print_mask)
    img_bg = cv2.bitwise_and(result['pattern_image'], result['pattern_image'], mask=clothes_mask_print)
    # Modulate the foreground by the mask and the shading (gray) channel.
    mask_mo = np.expand_dims(print_mask, axis=2).repeat(3, axis=2)
    gray_mo = np.expand_dims(result['gray'], axis=2).repeat(3, axis=2)
    img_fg = (img_fg * (mask_mo / 255) * (gray_mo / 255)).astype(np.uint8)
    return cv2.add(img_bg, img_fg)
||||
|
||||
def get_print(self, print_dict):
    """Fetch the print image from OSS and normalise it.

    Enforces a minimum scale of 0.3, flattens RGBA prints onto a white
    background (so transparent areas do not turn black) and stores the BGR
    array under 'image'. Returns the mutated ``print_dict``.
    """
    if 'print_scale_list' not in print_dict.keys() or print_dict['print_scale_list'][0] < 0.3:
        print_dict['scale'] = 0.3
    else:
        print_dict['scale'] = print_dict['print_scale_list'][0]

    path = print_dict['print_path_list'][0]
    bucket_name = path.split("/", 1)[0]
    object_name = path.split("/", 1)[1]
    image = oss_get_image(oss_client=self.minio_client, bucket=bucket_name, object_name=object_name, data_type="PIL")

    # RGBA prints are pasted onto white so transparency does not become black.
    if image.mode == "RGBA":
        white_canvas = Image.new('RGB', image.size, (255, 255, 255))
        white_canvas.paste(image, mask=image.split()[3])
        image = white_canvas

    print_dict['image'] = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
    return print_dict
|
||||
|
||||
def crop_image(self, image, image_size_h, image_size_w, location, print_shape):
    """Crop an (image_size_h, image_size_w) window out of the tiled print.

    The offset is derived from the first anchor in ``location`` modulo the
    resized print's dimensions so the tiling appears anchored at that point.

    NOTE(review): ``y_offset`` mixes ``print_w`` with ``% print_h`` (and
    ``x_offset`` mods the row coordinate by ``print_w``) — this looks like a
    w/h mix-up; confirm the intended axis mapping before changing it.
    NOTE(review): ``random.seed`` is called but no random draw follows —
    presumably left over from a removed random-offset implementation.
    """
    print_w = print_shape[1]
    print_h = print_shape[0]

    random.seed(self.random_seed)

    # 1. Take the anchor position modulo the resized print size to get the
    #    actual crop offset into the tiled canvas.
    x_offset = print_w - int(location[0][1] % print_w)
    y_offset = print_w - int(location[0][0] % print_h)

    if len(image.shape) == 2:
        image = image[x_offset: x_offset + image_size_h, y_offset: y_offset + image_size_w]
    elif len(image.shape) == 3:
        image = image[x_offset: x_offset + image_size_h, y_offset: y_offset + image_size_w, :]
    return image
|
||||
|
||||
@staticmethod
|
||||
def get_low_high_lab(Lab_value, L=False):
|
||||
if L:
|
||||
high = Lab_value + 30 if Lab_value + 30 < 255 else 255
|
||||
low = Lab_value - 30 if Lab_value - 30 > 0 else 0
|
||||
else:
|
||||
high = Lab_value + 30 if Lab_value + 30 < 255 else 255
|
||||
low = Lab_value - 30 if Lab_value - 30 > 0 else 0
|
||||
return high, low
|
||||
|
||||
@staticmethod
def img_rotate(image, angel, scale):
    """Rotate ``image`` clockwise by ``angel`` degrees with ``scale``,
    enlarging the canvas so no corner is clipped.

    Args:
        image (np.array): source image.
        angel (float): clockwise rotation angle in degrees.
        scale (float): uniform scale applied by the rotation matrix.

    Returns:
        (rotated image, (x_margin, y_margin)) — the margins locate the
        scaled original inside the enlarged canvas.
    """
    height, width = image.shape[:2]
    pivot = (width // 2, height // 2)
    rot_mat = cv2.getRotationMatrix2D(pivot, -angel, scale)
    # Grow the output canvas to hold the rotated bounding box.
    cos_abs = np.abs(rot_mat[0, 0])
    sin_abs = np.abs(rot_mat[0, 1])
    new_h = int((width * sin_abs + (height * cos_abs)))
    new_w = int((height * sin_abs + (width * cos_abs)))
    # Shift the transform so the result stays centred in the new canvas.
    rot_mat[0, 2] += (new_w - width) // 2
    rot_mat[1, 2] += (new_h - height) // 2
    rotated = cv2.warpAffine(image, rot_mat, (new_w, new_h))
    margins = ((rotated.shape[1] - image.shape[1] * scale) // 2,
               (rotated.shape[0] - image.shape[0] * scale) // 2)
    return rotated, margins
|
||||
|
||||
@staticmethod
def rotate_crop_image(img, angle, crop):
    """Rotate ``img`` by ``angle`` degrees (same canvas size) and optionally
    centre-crop away the black wedges introduced by the rotation.

    Args:
        angle: rotation angle in degrees.
        crop: bool — whether to crop the black borders.

    NOTE(review): ``w, h = img.shape[:2]`` assigns height to ``w`` and width
    to ``h`` (shape[:2] is (rows, cols)); the names look swapped — confirm.
    NOTE(review): the guard ``if angle > 90`` probably intended
    ``angle_crop > 90`` (for angles in (180, 270] the fold-back is wrong).
    NOTE(review): the local lambda shadows the ``crop_image`` method name.
    """
    crop_image = lambda img, x0, y0, w, h: img[y0:y0 + h, x0:x0 + w]
    w, h = img.shape[:2]
    # Rotation is periodic in 360 degrees.
    angle %= 360
    # Affine matrix about the image centre, no scaling.
    M_rotation = cv2.getRotationMatrix2D((w / 2, h / 2), angle, 1)
    img_rotated = cv2.warpAffine(img, M_rotation, (w, h))

    # Optionally remove the black borders.
    if crop:
        # The crop geometry repeats every 180 degrees.
        angle_crop = angle % 180
        if angle > 90:
            angle_crop = 180 - angle_crop
        # Degrees -> radians.
        theta = angle_crop * np.pi / 180
        # Aspect ratio.
        hw_ratio = float(h) / float(w)
        # Numerator of the crop side-length factor.
        tan_theta = np.tan(theta)
        numerator = np.cos(theta) + np.sin(theta) * np.tan(theta)

        # Denominator term that depends on the aspect ratio.
        r = hw_ratio if h > w else 1 / hw_ratio
        denominator = r * tan_theta + 1
        # Final side-length factor.
        crop_mult = numerator / denominator

        # Resulting crop rectangle, centred.
        w_crop = int(crop_mult * w)
        h_crop = int(crop_mult * h)
        x0 = int((w - w_crop) / 2)
        y0 = int((h - h_crop) / 2)

        img_rotated = crop_image(img_rotated, x0, y0, w_crop, h_crop)

    return img_rotated
|
||||
|
||||
def read_image(self, image_url):
    """Download a ``bucket/object`` image from OSS.

    Returns:
        (image, mode): a PIL RGBA image and "RGBA" when the stored file has
        an alpha channel, otherwise the raw cv2 (BGR) array and "RGB".
    """
    parts = image_url.split("/", 1)
    image = oss_get_image(oss_client=self.minio_client, bucket=parts[0], object_name=parts[1], data_type="cv2")
    if image.shape[2] == 4:
        # Preserve transparency: hand back a PIL RGBA image.
        image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA))
        return image, "RGBA"
    return image, "RGB"
|
||||
|
||||
@staticmethod
def resize_and_crop(img, target_width, target_height):
    """Scale ``img`` to fully cover target_width x target_height, then
    centre-crop the overflow (aspect preserved: no distortion or bars)."""
    src_h, src_w = img.shape[:2]
    # Compare aspect ratios to decide which axis constrains the resize.
    if src_w / src_h > target_width / target_height:
        # Source is relatively wider: match height, crop excess width.
        scaled_w = int(src_w * (target_height / src_h))
        scaled = cv2.resize(img, (scaled_w, target_height))
        x0 = (scaled_w - target_width) // 2
        return scaled[:, x0:x0 + target_width]
    # Source is relatively taller: match width, crop excess height.
    scaled_h = int(src_h * (target_width / src_w))
    scaled = cv2.resize(img, (target_width, scaled_h))
    y0 = (scaled_h - target_height) // 2
    return scaled[y0:y0 + target_height, :]
|
||||
49
app/service/design_batch/pipeline/scale.py
Normal file
49
app/service/design_batch/pipeline/scale.py
Normal file
@@ -0,0 +1,49 @@
|
||||
import math
|
||||
|
||||
import cv2
|
||||
|
||||
|
||||
class Scaling:
    """Compute the clothes-to-body scale factor for a layer.

    The scale is the ratio between a reference distance measured on the body
    keypoints and the matching distance measured on the clothes keypoints
    (shoes are measured from the image contour; bags and earrings use preset
    constants carried in the result dict).
    """

    def __call__(self, result):
        kp = result['keypoint']
        if kp in ['waistband', 'shoulder', 'head_point']:
            clo = result['clothes_keypoint']
            bdy = result['body_point_test']
            left = kp + '_left'
            right = kp + '_right'
            clo_dist = math.sqrt(
                (int(clo[left][0]) - int(clo[right][0])) ** 2
                + (int(clo[left][1]) - int(clo[right][1])) ** 2
            )
            # NOTE(review): only the x component (plus 1) is used for the
            # body distance — confirm the y term was dropped on purpose.
            bdy_dist = math.sqrt(
                (int(bdy[left][0]) - int(bdy[right][0])) ** 2 + 1
            )
            result['scale'] = 1 if clo_dist == 0 else bdy_dist / clo_dist
        elif kp == 'toe':
            foot = result['body_point_test']['foot_length']
            bdy_dist = math.sqrt(
                (int(foot[0]) - int(foot[2])) ** 2
                + (int(foot[1]) - int(foot[3])) ** 2
            )
            # Measure the shoe width as the bounding-box width of the
            # largest contour of the edge map.
            blurred = cv2.GaussianBlur(result['gray'], (3, 3), 0)
            edges = cv2.Canny(blurred, 10, 200)
            edges = cv2.dilate(edges, None)
            edges = cv2.erode(edges, None)
            contours, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            biggest = sorted(contours, key=cv2.contourArea, reverse=True)[0]
            _, _, box_w, _ = cv2.boundingRect(biggest)
            result['scale'] = bdy_dist / box_w
        elif kp == 'hand_point':
            result['scale'] = result['scale_bag']
        elif kp == 'ear_point':
            result['scale'] = result['scale_earrings']
        return result
|
||||
70
app/service/design_batch/pipeline/segmentation.py
Normal file
70
app/service/design_batch/pipeline/segmentation.py
Normal file
@@ -0,0 +1,70 @@
|
||||
import logging
|
||||
import os
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from app.core.config import SEG_CACHE_PATH
|
||||
from app.service.design_batch.utils.design_ensemble import get_seg_result
|
||||
from app.service.utils.new_oss_client import oss_get_image
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
class Segmentation:
    """Produce front/back garment masks for a layer.

    Either decodes a user-supplied mask image from OSS (red = front piece,
    green = back piece) or runs the segmentation model, with a local .npy
    cache keyed by image_id.
    """

    def __init__(self, minio_client):
        # OSS/minio client used to fetch user-drawn masks.
        self.minio_client = minio_client

    def __call__(self, result):
        """Populate result['front_mask' / 'back_mask' / 'mask' / 'seg_result']."""
        if "seg_mask_url" in result.keys() and result['seg_mask_url'] != "":
            # Pre-drawn mask path: fetch, resize to the garment image size.
            seg_mask = oss_get_image(oss_client=self.minio_client, bucket=result['seg_mask_url'].split('/')[0], object_name=result['seg_mask_url'][result['seg_mask_url'].find('/') + 1:], data_type="cv2")
            seg_mask = cv2.resize(seg_mask, (result['img_shape'][1], result['img_shape'][0]), interpolation=cv2.INTER_NEAREST)
            # OpenCV loads BGR; convert so the split below yields R, G, B.
            image_rgb = cv2.cvtColor(seg_mask, cv2.COLOR_BGR2RGB)

            r, g, b = cv2.split(image_rgb)
            red_mask = r > g
            green_mask = g > r

            # Red pixels mark the front piece, green pixels the back piece.
            result['front_mask'] = np.array(red_mask, dtype=np.uint8) * 255
            result['back_mask'] = np.array(green_mask, dtype=np.uint8) * 255
            result['mask'] = result['front_mask'] + result['back_mask']
        else:
            # Check the local seg cache first.
            hit, seg_result = self.load_seg_result(result["image_id"])
            if not hit:
                # Cache miss: run model inference and cache the result.
                seg_result = get_seg_result(result["image_id"], result['image'])[0]
                self.save_seg_result(seg_result, result['image_id'])
            # Fixed: 'seg_result' was previously stored before a cache miss
            # was re-inferred, leaving a stale None in the result dict.
            result['seg_result'] = seg_result
            # Class 1.0 marks the front piece, class 2.0 the back piece.
            temp_front = seg_result == 1.0
            result['front_mask'] = (255 * (temp_front + 0).astype(np.uint8))
            temp_back = seg_result == 2.0
            result['back_mask'] = (255 * (temp_back + 0).astype(np.uint8))
            result['mask'] = result['front_mask'] + result['back_mask']
        return result

    @staticmethod
    def save_seg_result(seg_result, image_id):
        """Cache the raw segmentation array as seg_cache/<image_id>.npy."""
        file_path = f"seg_cache/{image_id}.npy"
        try:
            np.save(file_path, seg_result)
            logger.info(f"保存成功 :{os.path.abspath(file_path)}")
        except Exception as e:
            logger.error(f"保存失败: {e}")

    @staticmethod
    def load_seg_result(image_id):
        """Load a cached segmentation array; returns (hit, array_or_None)."""
        file_path = f"seg_cache/{image_id}.npy"
        # Fixed: the log claimed SEG_CACHE_PATH while np.load used the
        # relative "seg_cache/" path; log the path that is actually read.
        logger.info(f"load seg file name is :{file_path}")
        try:
            seg_result = np.load(file_path)
            return True, seg_result
        except FileNotFoundError:
            logger.warning("文件不存在")
            return False, None
        except Exception as e:
            logger.error(f"加载失败: {e}")
            return False, None
||||
74
app/service/design_batch/pipeline/split.py
Normal file
74
app/service/design_batch/pipeline/split.py
Normal file
@@ -0,0 +1,74 @@
|
||||
import io
|
||||
import logging
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from cv2 import cvtColor, COLOR_BGR2RGBA
|
||||
|
||||
from app.core.config import AIDA_CLOTHING
|
||||
from app.service.design_batch.utils.conversion_image import rgb_to_rgba
|
||||
from app.service.design_batch.utils.upload_image import upload_png_mask
|
||||
from app.service.utils.generate_uuid import generate_uuid
|
||||
from app.service.utils.new_oss_client import oss_upload_image
|
||||
|
||||
|
||||
class Split(object):
    """Split a composed garment image into front/back RGBA layers and upload
    them (plus a red/green piece mask and the pattern layer) to OSS.

    NOTE(review): the broad except at the bottom swallows every error and
    implicitly returns None — downstream code indexing the result will then
    fail with an unrelated TypeError; consider re-raising.
    NOTE(review): the nesting of the "middle layer" upload (pattern image)
    relative to the garment-name check could not be confirmed from this
    view — verify it should run for every layer type.
    """

    def __init__(self, minio_client):
        # OSS/minio client used for all uploads.
        self.minio_client = minio_client

    def __call__(self, result):
        try:
            if result['name'] in ('outwear', 'dress', 'blouse', 'skirt', 'trousers', 'tops', 'bottoms'):
                front_mask = result['front_mask']
                back_mask = result['back_mask']
                # RGBA garment image, scaled to the body-relative size.
                rgba_image = rgb_to_rgba(result['final_image'], front_mask + back_mask)
                new_size = (int(rgba_image.shape[1] * result["scale"] * result["resize_scale"][0]), int(rgba_image.shape[0] * result["scale"] * result["resize_scale"][1]))
                rgba_image = cv2.resize(rgba_image, new_size)
                # Front piece: garment pixels where the front mask is set.
                result_front_image = np.zeros_like(rgba_image)
                front_mask = cv2.resize(front_mask, new_size)
                result_front_image[front_mask != 0] = rgba_image[front_mask != 0]
                result_front_image_pil = Image.fromarray(cvtColor(result_front_image, COLOR_BGR2RGBA))
                result['front_image'], result["front_image_url"], _ = upload_png_mask(self.minio_client, result_front_image_pil, f'{generate_uuid()}', mask=None)

                # Piece mask image: red marks the front piece.
                height, width = front_mask.shape
                mask_image = np.zeros((height, width, 3))
                mask_image[front_mask != 0] = [0, 0, 255]

                if result["name"] in ('blouse', 'dress', 'outwear', 'tops'):
                    # Garments with a distinct back piece: extract and upload it.
                    result_back_image = np.zeros_like(rgba_image)
                    back_mask = cv2.resize(back_mask, new_size)
                    result_back_image[back_mask != 0] = rgba_image[back_mask != 0]
                    result_back_image_pil = Image.fromarray(cvtColor(result_back_image, COLOR_BGR2RGBA))
                    result['back_image'], result["back_image_url"], _ = upload_png_mask(self.minio_client, result_back_image_pil, f'{generate_uuid()}', mask=None)
                    # Green marks the back piece in the combined mask.
                    mask_image[back_mask != 0] = [0, 255, 0]

                    rbga_mask = rgb_to_rgba(mask_image, front_mask + back_mask)
                    mask_pil = Image.fromarray(cvtColor(rbga_mask.astype(np.uint8), COLOR_BGR2RGBA))
                    image_data = io.BytesIO()
                    mask_pil.save(image_data, format='PNG')
                    image_data.seek(0)
                    image_bytes = image_data.read()
                    req = oss_upload_image(oss_client=self.minio_client, bucket=AIDA_CLOTHING, object_name=f"mask/mask_{generate_uuid()}.png", image_bytes=image_bytes)
                    result['mask_url'] = req.bucket_name + "/" + req.object_name
                else:
                    # Front-only garments: upload the front-only mask and
                    # null out the back-piece fields.
                    rbga_mask = rgb_to_rgba(mask_image, front_mask)
                    mask_pil = Image.fromarray(cvtColor(rbga_mask.astype(np.uint8), COLOR_BGR2RGBA))
                    image_data = io.BytesIO()
                    mask_pil.save(image_data, format='PNG')
                    image_data.seek(0)
                    image_bytes = image_data.read()
                    req = oss_upload_image(oss_client=self.minio_client, bucket=AIDA_CLOTHING, object_name=f"mask/mask_{generate_uuid()}.png", image_bytes=image_bytes)
                    result['mask_url'] = req.bucket_name + "/" + req.object_name
                    result['back_image'] = None
                    result["back_image_url"] = None
            # Create the middle (pattern) layer and upload it.
            result_pattern_image_rgba = rgb_to_rgba(result['pattern_image'], result['mask'])
            result_pattern_image_pil = Image.fromarray(cvtColor(result_pattern_image_rgba, COLOR_BGR2RGBA))
            result['pattern_image'], result['pattern_image_url'], _ = upload_png_mask(self.minio_client, result_pattern_image_pil, f'{generate_uuid()}')
            return result
        except Exception as e:
            logging.warning(f"split runtime exception : {e} image_id : {result['image_id']}")
||||
11
app/service/design_batch/service.py
Normal file
11
app/service/design_batch/service.py
Normal file
@@ -0,0 +1,11 @@
|
||||
import json
|
||||
|
||||
from app.service.design_batch.design_batch_celery import batch_design
|
||||
from app.service.design_batch.utils.MQ import publish_status
|
||||
|
||||
|
||||
async def start_design_batch_generate(data, file):
    """Enqueue a batch design Celery task and publish its initial status.

    Args:
        data: request model carrying ``total`` and ``tasks_id``.
        file: raw JSON bytes whose 'objects' list describes the batch.

    Returns:
        dict with the task id so the caller can poll progress over MQ.
    """
    # Fixed: removed a stray debug ``print`` of the Celery AsyncResult;
    # the return value is intentionally unused.
    batch_design.delay(json.loads(file.decode())['objects'], data.total, data.tasks_id)
    # Announce the job to the status queue with 0% progress.
    publish_status(data.tasks_id, "0/100", "")
    return {"task_id": data.tasks_id}
|
||||
162
app/service/design_batch/test.py
Normal file
162
app/service/design_batch/test.py
Normal file
@@ -0,0 +1,162 @@
|
||||
from app.service.design_batch.design_batch_celery import batch_design

# Manual smoke test: enqueue one batch-design job with a hand-written payload
# mirroring what the API would post (trousers + blouse + outwear + body).
if __name__ == '__main__':
    data = {
        "objects": [
            {
                # Shared per-outfit settings: body keypoints and scale presets.
                "basic": {
                    "body_point_test": {
                        "waistband_right": [200, 241],
                        "hand_point_right": [223, 297],
                        "waistband_left": [112, 241],
                        "hand_point_left": [92, 305],
                        "shoulder_left": [99, 116],
                        "shoulder_right": [215, 116]
                    },
                    "layer_order": True,
                    "scale_bag": 0.7,
                    "scale_earrings": 0.16,
                    "self_template": True,
                    "single_overall": "overall",
                    "switch_category": ""
                },
                # One entry per garment layer, plus the body model itself.
                "items": [
                    {
                        "businessId": 270372,
                        "color": "30 28 28",
                        "image_id": 69780,
                        "offset": [0, 0],
                        "path": "aida-sys-image/images/female/trousers/0825000630.jpg",
                        "print": {
                            "element": {
                                "element_angle_list": [],
                                "element_path_list": [],
                                "element_scale_list": [],
                                "location": []
                            },
                            "overall": {
                                "location": [],
                                "print_angle_list": [],
                                "print_path_list": [],
                                "print_scale_list": []
                            },
                            "single": {
                                "location": [],
                                "print_angle_list": [],
                                "print_path_list": [],
                                "print_scale_list": []
                            }
                        },
                        "priority": 10,
                        "resize_scale": [1.0, 1.0],
                        "type": "Trousers"
                    },
                    {
                        "businessId": 270373,
                        "color": "30 28 28",
                        "image_id": 98243,
                        "offset": [0, 0],
                        "path": "aida-sys-image/images/female/blouse/0902003811.jpg",
                        "print": {
                            "element": {
                                "element_angle_list": [],
                                "element_path_list": [],
                                "element_scale_list": [],
                                "location": []
                            },
                            "overall": {
                                "location": [],
                                "print_angle_list": [],
                                "print_path_list": [],
                                "print_scale_list": []
                            },
                            "single": {
                                "location": [],
                                "print_angle_list": [],
                                "print_path_list": [],
                                "print_scale_list": []
                            }
                        },
                        "priority": 11,
                        "resize_scale": [1.0, 1.0],
                        "type": "Blouse"
                    },
                    {
                        "businessId": 270374,
                        "color": "172 68 68",
                        "image_id": 98244,
                        "offset": [0, 0],
                        "path": "aida-sys-image/images/female/outwear/0825000410.jpg",
                        "print": {
                            "element": {
                                "element_angle_list": [],
                                "element_path_list": [],
                                "element_scale_list": [],
                                "location": []
                            },
                            "overall": {
                                "location": [],
                                "print_angle_list": [],
                                "print_path_list": [],
                                "print_scale_list": []
                            },
                            "single": {
                                "location": [],
                                "print_angle_list": [],
                                "print_path_list": [],
                                "print_scale_list": []
                            }
                        },
                        "priority": 12,
                        "resize_scale": [1.0, 1.0],
                        "type": "Outwear"
                    },
                    {
                        "body_path": "aida-sys-image/models/female/5bdfe7ca-64eb-44e4-b03d-8e517520c795.png",
                        "image_id": 96090,
                        "type": "Body"
                    }
                ]
            }
        ],
        "process_id": "83"
    }
    task_id = 1
    json_name = "test.json"
    # NOTE(review): service.py calls batch_design.delay(objects, total,
    # tasks_id) while this passes a file name as the third argument —
    # confirm which signature is current before relying on this script.
    batch_design.delay(data['objects'], task_id, json_name)
|
||||
17
app/service/design_batch/utils/MQ.py
Normal file
17
app/service/design_batch/utils/MQ.py
Normal file
@@ -0,0 +1,17 @@
|
||||
import json
|
||||
|
||||
import pika
|
||||
|
||||
|
||||
def publish_status(task_id, progress, result):
    """Publish a task progress message to the 'DesignBatch' RabbitMQ queue.

    Args:
        task_id: batch task identifier.
        progress: progress string, e.g. "0/100".
        result: payload describing the task's current result (may be "").
    """
    # NOTE(review): broker host is hard-coded; should come from configuration.
    connection = pika.BlockingConnection(pika.ConnectionParameters('10.1.2.213'))
    # Fixed: close the connection even when declaring/publishing raises,
    # instead of leaking it.
    try:
        channel = connection.channel()
        channel.queue_declare(queue='DesignBatch', durable=True)
        message = {'task_id': task_id, 'progress': progress, "result": result}
        channel.basic_publish(exchange='',
                              routing_key='DesignBatch',
                              body=json.dumps(message),
                              properties=pika.BasicProperties(
                                  delivery_mode=2,  # persistent message
                              ))
    finally:
        connection.close()
|
||||
31
app/service/design_batch/utils/conversion_image.py
Normal file
31
app/service/design_batch/utils/conversion_image.py
Normal file
@@ -0,0 +1,31 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: UTF-8 -*-
|
||||
"""
|
||||
@Project :trinity_client
|
||||
@File :conversion_image.py
|
||||
@Author :周成融
|
||||
@Date :2023/8/21 10:40:29
|
||||
@detail :
|
||||
"""
|
||||
import numpy as np
|
||||
|
||||
|
||||
# def rgb_to_rgba(rgb_size, rgb_image, mask):
|
||||
# alpha_channel = np.full(rgb_size, 255, dtype=np.uint8)
|
||||
# # 创建四通道的结果图像
|
||||
# rgba_image = np.dstack((rgb_image, alpha_channel))
|
||||
# alpha_channel = np.where(mask > 0, 255, 0)
|
||||
# # 更新RGBA图像的透明度通道
|
||||
# rgba_image[:, :, 3] = alpha_channel
|
||||
# return rgba_image
|
||||
|
||||
def rgb_to_rgba(rgb_image, mask):
    """Append an alpha channel to ``rgb_image``: fully opaque (255) where
    ``mask`` is positive, fully transparent (0) elsewhere."""
    opacity = (255 * (mask > 0)).astype(np.uint8)
    return np.dstack((rgb_image, opacity))
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # NOTE(review): ``open("")`` always raises (FileNotFoundError); this
    # manual-test stub needs a real image path — and ``open`` returns a file
    # object, not an image — before it can do anything useful.
    image = open("")
|
||||
143
app/service/design_batch/utils/design_ensemble.py
Normal file
143
app/service/design_batch/utils/design_ensemble.py
Normal file
@@ -0,0 +1,143 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: UTF-8 -*-
|
||||
"""
|
||||
@Project :trinity_client
|
||||
@File :design_ensemble.py
|
||||
@Author :周成融
|
||||
@Date :2023/8/16 19:36:21
|
||||
@detail :发起请求 获取推理结果
|
||||
"""
|
||||
import logging
|
||||
|
||||
import cv2
|
||||
import mmcv
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import tritonclient.http as httpclient
|
||||
|
||||
from app.core.config import *
|
||||
|
||||
"""
|
||||
keypoint
|
||||
预处理 推理 后处理
|
||||
"""
|
||||
|
||||
|
||||
def keypoint_preprocess(img_path):
    """Load an image and prepare it for the keypoint model.

    Resizes to 256x256, applies ImageNet-style mean/std normalisation and
    returns the NCHW float batch together with the (w, h) scale factors
    needed to map predictions back to the original resolution.
    """
    img = mmcv.imread(img_path)
    target = (256, 256)
    orig_h, orig_w = img.shape[:2]
    img = cv2.resize(img, target)
    scale_factors = (target[0] / orig_w, target[1] / orig_h)
    img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
    batch = np.expand_dims(img.transpose(2, 0, 1), axis=0)
    return batch, scale_factors
|
||||
|
||||
|
||||
# @ RunTime
|
||||
# 推理
|
||||
def get_keypoint_result(image, site):
    """Run the ``keypoint_<site>_ocrnet_hr18`` Triton model on ``image``.

    Best-effort: any failure (preprocessing, network, inference) is logged
    and ``None`` is returned instead of raising.
    """
    keypoint_result = None
    try:
        tensor, scale_factor = keypoint_preprocess(image)
        tensor = tensor.astype(np.float32)
        client = httpclient.InferenceServerClient(url=DESIGN_MODEL_URL)
        infer_input = httpclient.InferInput(f"input", tensor.shape, datatype="FP32")
        infer_input.set_data_from_numpy(tensor, binary_data=True)
        requested = httpclient.InferRequestedOutput(f"output", binary_data=True)
        results = client.infer(model_name=f"keypoint_{site}_ocrnet_hr18", inputs=[infer_input], outputs=[requested])
        inference_output = torch.from_numpy(results.as_numpy(f'output'))
        keypoint_result = keypoint_postprocess(inference_output, scale_factor)
    except Exception as e:
        logging.warning(f"get_keypoint_result : {e}")
    return keypoint_result
|
||||
|
||||
|
||||
def keypoint_postprocess(output, scale_factor):
    """Turn model heatmaps into coordinates in the original image.

    ``output`` is an NCHW tensor; the arg-max of each channel's flattened
    heatmap is converted back to 2-D (index / W, index % W) coordinates,
    rescaled by the inverse preprocessing scale, and multiplied by 4 to
    undo the model's output stride.
    """
    width = output.size(3)
    flat = output.view(output.size(0), output.size(1), -1)
    peak = torch.argmax(flat, dim=2).unsqueeze(dim=2)
    # True division keeps the row coordinate as a float, as before.
    coords = torch.cat((peak / width, peak % width), dim=2).numpy()
    inv_scale = [1 / s for s in scale_factor[::-1]]
    scale_matrix = np.diag(inv_scale)
    # Zero out any inf produced by inverting a zero scale factor.
    scale_matrix[np.isinf(scale_matrix)] = 0
    return np.ceil(np.dot(coords, scale_matrix) * 4)
|
||||
|
||||
|
||||
"""
|
||||
seg
|
||||
预处理 推理 后处理
|
||||
"""
|
||||
|
||||
|
||||
# KNet
|
||||
def seg_preprocess(img_path):
    """Prepare an image for the segmentation model.

    Caps either side at 1024 px (the other side is kept as-is), normalises
    with ImageNet mean/std and returns the NCHW float batch plus the
    original (h, w) so the logits can be resized back.

    NOTE(review): ``img_scale_w`` is taken from ``ori_shape[0]``, which is
    the image HEIGHT (shape[:2] is (rows, cols)) — the names look swapped,
    but the ``cv2.resize(img, (img_scale_h, img_scale_w))`` call is
    consistent with that convention; confirm before renaming anything.
    """
    img = mmcv.imread(img_path)
    ori_shape = img.shape[:2]
    img_scale_w, img_scale_h = ori_shape
    if ori_shape[0] > 1024:
        img_scale_w = 1024
    if ori_shape[1] > 1024:
        img_scale_h = 1024
    # If either side exceeds 1024, the image is resized down to the cap.
    if ori_shape != (img_scale_w, img_scale_h):
        # mmcv.imresize(img, img_scale_h, img_scale_w)  # old code kept as a cautionary tale: h and w were swapped
        img = cv2.resize(img, (img_scale_h, img_scale_w))
    img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
    preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
    return preprocessed_img, ori_shape
|
||||
|
||||
|
||||
# @ RunTime
|
||||
def get_seg_result(image_id, image):
    """Run the segmentation model on ``image`` via Triton.

    Returns the per-pixel class logits resized back to the original image
    resolution (see ``seg_postprocess``).
    """
    batch, ori_shape = seg_preprocess(image)
    batch = batch.astype(np.float32)
    client = httpclient.InferenceServerClient(url=f"{DESIGN_MODEL_URL}")
    # Input set.
    infer_input = httpclient.InferInput(SEGMENTATION['input'], batch.shape, datatype="FP32")
    infer_input.set_data_from_numpy(batch, binary_data=True)
    # Output set.
    requested = httpclient.InferRequestedOutput(SEGMENTATION['output'], binary_data=True)
    # Inference call and result extraction.
    results = client.infer(model_name=SEGMENTATION['new_model_name'], inputs=[infer_input], outputs=[requested])
    raw_output = results.as_numpy(SEGMENTATION['output'])
    return seg_postprocess(int(image_id), raw_output, ori_shape)
|
||||
|
||||
|
||||
# no cache
|
||||
def seg_postprocess(image_id, output, ori_shape):
    """Resize segmentation logits back to the original (h, w).

    ``image_id`` is unused here but kept for interface compatibility (the
    function used to participate in a cache keyed by it).
    Returns the first batch element as a numpy array.
    """
    logits = torch.tensor(output).float()
    resized = F.interpolate(logits, size=ori_shape, scale_factor=None, mode='bilinear', align_corners=False)
    return resized.cpu().numpy()[0]
|
||||
|
||||
|
||||
def key_point_show(image_path, key_point_result=None):
    """Debug helper: draw predicted keypoints on the image and display it.

    Blocks until a key is pressed (cv2.imshow / cv2.waitKey); GUI-only and
    not used by the service pipeline.
    """
    img = cv2.imread(image_path)
    points_list = key_point_result
    point_size = 1
    point_color = (0, 0, 255)  # BGR: red
    thickness = 4  # may be 0, 4 or 8
    # Points arrive as (row, col); cv2.circle expects (x, y), hence [::-1].
    for point in points_list:
        cv2.circle(img, point[::-1], point_size, point_color, thickness)
    cv2.imshow("0", img)
    cv2.waitKey(0)
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Manual smoke test for the Triton keypoint model (needs a local image
    # file and a reachable inference server).
    image = cv2.imread("9070101c-e5be-49b5-9602-4113a968969b.png")
    a = get_keypoint_result(image, "up")
    new_list = []
    # NOTE(review): this prints the builtin ``list`` type, not ``new_list`` —
    # almost certainly a typo in this debug script.
    print(list)
    for i in a[0]:
        new_list.append((int(i[0]), int(i[1])))
    key_point_show("9070101c-e5be-49b5-9602-4113a968969b.png", new_list)
    # a = get_seg_result(1, image)
    print(a)
|
||||
77
app/service/design_batch/utils/organize.py
Normal file
77
app/service/design_batch/utils/organize.py
Normal file
@@ -0,0 +1,77 @@
|
||||
import cv2
|
||||
|
||||
from app.core.config import PRIORITY_DICT
|
||||
|
||||
|
||||
def organize_body(layer):
    """Build the layer descriptor for the body model.

    The body always sits below every garment layer (priority 0) and is never
    rescaled. The misspelled key 'sacle' is consumed downstream and must be
    kept as-is for compatibility.
    """
    return dict(
        priority=0,  # bottom of the layer stack
        name=layer["name"].lower(),
        image=layer['body_image'],
        image_url=layer['body_path'],
        mask_image=None,
        mask_url=None,
        sacle=1,  # (sic) downstream readers expect this exact key
        # mask=layer['body_mask'],
        position=(0, 0),
    )
|
||||
|
||||
|
||||
def organize_clothing(layer):
    """Build the front/back composition entries for one clothing layer.

    Returns:
        (front_layer, back_layer): dicts describing how the garment halves
        are pasted onto the body canvas. The back piece gets a negated
        priority so it renders behind the body.
    """
    # Top-left paste coordinate derived from body/clothes keypoints.
    start_point = calculate_start_point(layer['keypoint'], layer['scale'], layer['clothes_keypoint'], layer['body_point_test'], layer["offset"], layer["resize_scale"])

    lname = layer["name"].lower()
    use_explicit_order = layer.get("layer_order", False)
    gradient = layer['gradient_string'] if 'gradient_string' in layer else ""

    # Front-piece data.
    front_layer = dict(
        priority=layer['priority'] if use_explicit_order else PRIORITY_DICT.get(f'{lname}_front', None),
        name=f'{lname}_front',
        image=layer["front_image"],
        # mask_image=layer['front_mask_image'],
        image_url=layer['front_image_url'],
        mask_url=layer['mask_url'],
        sacle=layer['scale'],
        clothes_keypoint=layer['clothes_keypoint'],
        position=start_point,
        resize_scale=layer["resize_scale"],
        # NOTE(review): both halves resize the mask to the FRONT image size —
        # confirm front and back images always share dimensions.
        mask=cv2.resize(layer['mask'], layer["front_image"].size),
        gradient_string=gradient,
        pattern_image_url=layer['pattern_image_url'],
        pattern_image=layer['pattern_image'],
    )
    # Back-piece data.
    back_layer = dict(
        priority=-layer.get("priority", 0) if use_explicit_order else PRIORITY_DICT.get(f'{lname}_back', None),
        name=f'{lname}_back',
        image=layer["back_image"],
        # mask_image=layer['back_mask_image'],
        image_url=layer['back_image_url'],
        mask_url=layer['mask_url'],
        sacle=layer['scale'],
        clothes_keypoint=layer['clothes_keypoint'],
        position=start_point,
        resize_scale=layer["resize_scale"],
        mask=cv2.resize(layer['mask'], layer["front_image"].size),
        gradient_string=gradient,
        pattern_image_url=layer['pattern_image_url'],
    )
    return front_layer, back_layer
|
||||
|
||||
|
||||
def calculate_start_point(keypoint_type, scale, clothes_point, body_point, offset, resize_scale):
    """Compute the (y, x) paste origin aligning a garment's LEFT keypoint
    with the corresponding body keypoint.

    Args:
        keypoint_type: string, "waistband" | "shoulder" | "ear_point" ...
        scale: float scale factor applied to the clothes keypoint.
        clothes_point: dict like {'<type>_left': [x, y, ...], ...}.
        body_point: dict of body keypoints; values look like [y, x] pairs.
        offset: (dy, dx) extra shift applied to the body keypoint.
        resize_scale: unused here; kept for interface compatibility.

    Returns:
        (y', x') tuple where
            y' = body[1] + offset[1] - clothes[0] * scale
            x' = body[0] + offset[0] - clothes[1] * scale
    """
    key = f'{keypoint_type}_left'
    body_first, body_second = body_point[key][0], body_point[key][1]
    clothes_first = int(clothes_point[key][0])
    clothes_second = int(clothes_point[key][1])
    y = int(body_second + offset[1] - clothes_first * scale)
    x = int(body_first + offset[0] - clothes_second * scale)
    return (y, x)
|
||||
30
app/service/design_batch/utils/progress.py
Normal file
30
app/service/design_batch/utils/progress.py
Normal file
@@ -0,0 +1,30 @@
|
||||
import logging
|
||||
|
||||
from app.service.design_fast.utils.redis_utils import Redis
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def update_progress(process_id, total):
    """Advance the stored progress for *process_id* by one step out of *total*.

    Returns:
        The progress value as it was BEFORE this update (None when no
        progress was stored yet).
    """
    redis_client = Redis()
    previous = redis_client.read(key=process_id)
    if total == 1:
        # Single-step job: jump straight to 100%.
        redis_client.write(key=process_id, value=100)
    elif not previous:
        # First update for this process id.
        redis_client.write(key=process_id, value=int(100 / total))
    elif int(previous) <= 100:
        redis_client.write(key=process_id, value=int(previous) + int(100 / total))
    else:
        # Stored value overshot 100; clamp just below completion.
        redis_client.write(key=process_id, value=99)
    return previous
|
||||
|
||||
|
||||
def final_progress(process_id):
    """Mark *process_id* as complete (100%) and return the prior value."""
    redis_client = Redis()
    previous = redis_client.read(key=process_id)
    redis_client.write(key=process_id, value=100)
    return previous
|
||||
99
app/service/design_batch/utils/redis_utils.py
Normal file
99
app/service/design_batch/utils/redis_utils.py
Normal file
@@ -0,0 +1,99 @@
|
||||
import redis
|
||||
|
||||
from app.core.config import REDIS_HOST, REDIS_PORT
|
||||
|
||||
|
||||
class Redis(object):
    """
    Convenience wrapper around redis: each call opens a StrictRedis
    connection to REDIS_HOST:REDIS_PORT, database 0.
    """

    @staticmethod
    def _get_r():
        # Fresh connection object per call; StrictRedis pools internally.
        return redis.StrictRedis(REDIS_HOST, REDIS_PORT, 0)

    @classmethod
    def write(cls, key, value, expire=None):
        """Store a key/value pair with a TTL (default 100 seconds)."""
        cls._get_r().set(key, value, ex=expire or 100)

    @classmethod
    def read(cls, key):
        """Return the value for *key* decoded as utf-8, or None when absent."""
        raw = cls._get_r().get(key)
        return raw.decode('utf-8') if raw else raw

    @classmethod
    def hset(cls, name, key, value):
        """Set *key* in hash table *name*."""
        cls._get_r().hset(name, key, value)

    @classmethod
    def hget(cls, name, key):
        """Return *key* from hash *name* decoded as utf-8, or None when absent."""
        raw = cls._get_r().hget(name, key)
        return raw.decode('utf-8') if raw else raw

    @classmethod
    def hgetall(cls, name):
        """Return the whole hash *name* (raw bytes mapping)."""
        return cls._get_r().hgetall(name)

    @classmethod
    def delete(cls, *names):
        """Delete one or more keys."""
        cls._get_r().delete(*names)

    @classmethod
    def hdel(cls, name, key):
        """Remove *key* from hash table *name*."""
        cls._get_r().hdel(name, key)

    @classmethod
    def expire(cls, name, expire=None):
        """Set the TTL of *name* (default 100 seconds)."""
        cls._get_r().expire(name, expire or 100)
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Manual smoke test: write a sample progress value.
    redis_client = Redis()
    # print(redis_client.write(key="1230", value=0))
    redis_client.write(key="1230", value=10)
    # print(redis_client.read(key="1230"))
|
||||
13
app/service/design_batch/utils/save_json.py
Normal file
13
app/service/design_batch/utils/save_json.py
Normal file
@@ -0,0 +1,13 @@
|
||||
import json
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
def oss_upload_json(oss_client, json_data, object_name):
    """Write *json_data* to a local staging file and upload it to the
    'test' bucket under *object_name*.

    Failures (missing directory, network error, ...) are logged as a
    warning and swallowed: the upload is best-effort.
    """
    local_path = f"app/service/design_batch/response_json/{object_name}"
    try:
        with open(local_path, 'w') as handle:
            json.dump(json_data, handle, indent=4)
        oss_client.fput_object("test", object_name, local_path)
    except Exception as exc:
        logger.warning(str(exc))
|
||||
197
app/service/design_batch/utils/synthesis_item.py
Normal file
197
app/service/design_batch/utils/synthesis_item.py
Normal file
@@ -0,0 +1,197 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: UTF-8 -*-
|
||||
"""
|
||||
@Project :trinity_client
|
||||
@File :synthesis_item.py
|
||||
@Author :周成融
|
||||
@Date :2023/8/26 14:13:04
|
||||
@detail :
|
||||
"""
|
||||
import io
|
||||
import logging
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from app.service.utils.generate_uuid import generate_uuid
|
||||
from app.service.utils.oss_client import oss_upload_image
|
||||
|
||||
|
||||
def positioning(all_mask_shape, mask_shape, offset):
    """Compute overlap slice bounds for pasting a mask onto a canvas
    along one axis.

    Args:
        all_mask_shape: canvas extent along this axis.
        mask_shape: mask extent along this axis.
        offset: mask origin relative to the canvas origin (may be negative).

    Returns:
        (all_start, all_end, mask_start, mask_end): index bounds such that
        canvas[all_start:all_end] aligns with mask[mask_start:mask_end].
    """
    if offset == 0:
        overlap = min(all_mask_shape, mask_shape)
        return 0, overlap, 0, overlap
    if offset > 0:
        all_start = min(offset, all_mask_shape)
        all_end = min(offset + mask_shape, all_mask_shape)
        mask_end = 0 if offset > all_mask_shape else min(all_mask_shape - offset, mask_shape)
        return all_start, all_end, 0, mask_end
    # offset < 0: the mask starts before the canvas origin.
    shift = abs(offset)
    if shift > mask_shape:
        # Mask lies entirely before the canvas: empty overlap.
        return 0, 0, mask_shape, mask_shape
    visible = mask_shape - shift
    all_end = min(visible, all_mask_shape)
    mask_end = all_mask_shape + shift if visible >= all_mask_shape else mask_shape
    return 0, all_end, shift, mask_end
|
||||
|
||||
|
||||
# @RunTime
|
||||
def synthesis(data, size, basic_info):
    """Composite all body/garment layers into one RGBA image, upload the
    PNG to OSS and return its '<bucket>/<object>' path.

    Args:
        data: list of layer dicts (keys seen here: 'name', 'image',
            'mask', 'adaptive_position' as (y, x)). The layers are
            iterated back-to-front when building the outer masks —
            presumably ordered by priority; confirm with the caller.
        size: (width, height) of the output canvas.
        basic_info: dict carrying 'body_point_test' shoulder keypoints.

    Returns:
        str: bucket/object path of the uploaded result, or None (implicit)
        when any step fails — the exception is only logged as a warning.
    """
    # Create the transparent base canvas.
    base_image = Image.new('RGBA', size, (0, 0, 0, 0))
    try:
        # Mask arrays are indexed (row, col) = (y, x), hence the swap.
        all_mask_shape = (size[1], size[0])
        body_mask = None
        for d in data:
            if d['name'] == 'body' or d['name'] == 'mannequin':
                # Paste the mannequin onto a fresh transparent canvas and take
                # its alpha channel as the body mask.
                transparent_image = Image.new("RGBA", size, (0, 0, 0, 0))
                transparent_image.paste(d['image'], (d['adaptive_position'][1], d['adaptive_position'][0]), d['image'])  # paste() can mutate a mutable position sequence, so index into it instead of passing it directly
                body_mask = np.array(transparent_image.split()[3])

                # Translate the shoulder keypoints into canvas coordinates.
                left_shoulder = [x + y for x, y in zip(basic_info['body_point_test']['shoulder_left'], [d['adaptive_position'][1], d['adaptive_position'][0]])]
                right_shoulder = [x + y for x, y in zip(basic_info['body_point_test']['shoulder_right'], [d['adaptive_position'][1], d['adaptive_position'][0]])]
                # Fill the region above the shoulders so collars are not clipped.
                body_mask[:min(left_shoulder[1], right_shoulder[1]), left_shoulder[0]:right_shoulder[0]] = 255
        # NOTE(review): raises if no body/mannequin layer was found
        # (body_mask stays None) — confirm callers always provide one.
        _, binary_body_mask = cv2.threshold(body_mask, 127, 255, cv2.THRESH_BINARY)
        top_outer_mask = np.array(binary_body_mask)
        bottom_outer_mask = np.array(binary_body_mask)

        # Walk the layers back-to-front and extend the body mask with the
        # outermost top garment and the outermost bottom garment.
        top = True
        bottom = True
        i = len(data)
        while i:
            i -= 1
            if top and data[i]['name'] in ["blouse_front", "outwear_front", "dress_front", "tops_front"]:
                top = False
                mask_shape = data[i]['mask'].shape
                y_offset, x_offset = data[i]['adaptive_position']
                # Compute start/end bounds of the overlay region on both axes.
                all_y_start, all_y_end, mask_y_start, mask_y_end = positioning(all_mask_shape=all_mask_shape[0], mask_shape=mask_shape[0], offset=y_offset)
                all_x_start, all_x_end, mask_x_start, mask_x_end = positioning(all_mask_shape=all_mask_shape[1], mask_shape=mask_shape[1], offset=x_offset)
                # Copy the garment mask pixels into the overlay region.
                _, sketch_mask = cv2.threshold(data[i]['mask'], 127, 255, cv2.THRESH_BINARY)
                background = np.zeros_like(top_outer_mask)
                background[all_y_start:all_y_end, all_x_start:all_x_end] = sketch_mask[mask_y_start:mask_y_end, mask_x_start:mask_x_end]
                top_outer_mask = background + top_outer_mask
            elif bottom and data[i]['name'] in ["trousers_front", "skirt_front", "bottoms_front", "dress_front"]:
                bottom = False
                mask_shape = data[i]['mask'].shape
                y_offset, x_offset = data[i]['adaptive_position']
                # Compute start/end bounds of the overlay region on both axes.
                all_y_start, all_y_end, mask_y_start, mask_y_end = positioning(all_mask_shape=all_mask_shape[0], mask_shape=mask_shape[0], offset=y_offset)
                all_x_start, all_x_end, mask_x_start, mask_x_end = positioning(all_mask_shape=all_mask_shape[1], mask_shape=mask_shape[1], offset=x_offset)
                # Copy the garment mask pixels into the overlay region.
                _, sketch_mask = cv2.threshold(data[i]['mask'], 127, 255, cv2.THRESH_BINARY)
                background = np.zeros_like(top_outer_mask)
                background[all_y_start:all_y_end, all_x_start:all_x_end] = sketch_mask[mask_y_start:mask_y_end, mask_x_start:mask_x_end]
                bottom_outer_mask = background + bottom_outer_mask
            elif bottom is False and top is False:
                break

        # Union of body + outermost top + outermost bottom: everything
        # outside this silhouette is cropped away below.
        all_mask = cv2.bitwise_or(top_outer_mask, bottom_outer_mask)

        for layer in data:
            if layer['image'] is not None:
                if layer['name'] != "body":
                    # Position the layer on a full-size canvas, then crop it
                    # to the overall silhouette before compositing.
                    test_image = Image.new('RGBA', size, (0, 0, 0, 0))
                    test_image.paste(layer['image'], (layer['adaptive_position'][1], layer['adaptive_position'][0]), layer['image'])
                    mask_data = np.where(all_mask > 0, 255, 0).astype(np.uint8)
                    mask_alpha = Image.fromarray(mask_data)
                    cropped_image = Image.composite(test_image, Image.new("RGBA", test_image.size, (255, 255, 255, 0)), mask_alpha)
                    base_image.paste(test_image, (0, 0), cropped_image)  # test_image is already positioned on a full-size canvas, so paste at (0, 0)
                else:
                    base_image.paste(layer['image'], (layer['adaptive_position'][1], layer['adaptive_position'][0]), layer['image'])

        result_image = base_image

        # Encode the composited canvas as PNG.
        image_data = io.BytesIO()
        result_image.save(image_data, format='PNG')
        image_data.seek(0)

        # oss upload
        image_bytes = image_data.read()
        bucket_name = "aida-results"
        object_name = f'result_{generate_uuid()}.png'
        req = oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes)
        return f"{bucket_name}/{object_name}"
        # return f"aida-results/{minio_client.put_object('aida-results', f'result_{generate_uuid()}.png', io.BytesIO(image_bytes), len(image_bytes), content_type='image/png').object_name}"

        # object_name = f'result_{generate_uuid()}.png'
        # response = s3.put_object(Bucket="aida-results", Key=object_name, Body=data, ContentType='image/png')
        # object_url = f"aida-results/{object_name}"
        # if response['ResponseMetadata']['HTTPStatusCode'] == 200:
        #     return object_url
        # else:
        #     return ""

    except Exception as e:
        # Best-effort: any failure is logged and the function returns None.
        logging.warning(f"synthesis runtime exception : {e}")
|
||||
|
||||
|
||||
def synthesis_single(front_image, back_image):
    """Merge a garment's front and back PIL images, upload the PNG to OSS
    and return its '<bucket>/<object>' path.

    NOTE(review): assumes at least front_image is provided; with only a
    back image this raises AttributeError — confirm callers guarantee it.
    """
    merged = None
    if front_image:
        merged = front_image
    if back_image:
        merged.paste(back_image, (0, 0), back_image)

    # Encode the merged image as PNG bytes.
    buffer = io.BytesIO()
    merged.save(buffer, format='PNG')
    buffer.seek(0)
    payload = buffer.read()

    # oss upload
    bucket_name = 'aida-results'
    object_name = f'result_{generate_uuid()}.png'
    req = oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=payload)
    return f"{bucket_name}/{object_name}"
|
||||
|
||||
|
||||
def update_base_size_priority(layers, size):
    """Shift all layers so the leftmost one starts at x == 0 and compute
    the resulting canvas size.

    Args:
        layers: list of dicts with 'position' (y, x) and 'image' (an
            object exposing .width, or None).
        size: unused here; kept for interface compatibility.

    Returns:
        (layers, (new_width, new_height)): the same list, each layer given
        an 'adaptive_position' key, and the canvas size (height fixed 700).
    """
    min_x = min(layer['position'][1] for layer in layers)
    # Rightmost extent over layers that actually carry an image.
    max_x = max(layer['position'][1] + layer['image'].width
                for layer in layers if layer['image'] is not None)
    canvas_size = (max_x - min_x, 700)
    for layer in layers:
        layer['adaptive_position'] = (layer['position'][0], layer['position'][1] - min_x)
    return layers, canvas_size
|
||||
@@ -13,11 +13,11 @@ import logging
|
||||
import cv2
|
||||
|
||||
from app.core.config import *
|
||||
from app.service.utils.oss_client import oss_upload_image
|
||||
from app.service.utils.new_oss_client import oss_upload_image
|
||||
|
||||
|
||||
# @RunTime
|
||||
def upload_png_mask(front_image, object_name, mask=None):
|
||||
def upload_png_mask(minio_client, front_image, object_name, mask=None):
|
||||
try:
|
||||
mask_url = None
|
||||
if mask is not None:
|
||||
@@ -25,20 +25,14 @@ def upload_png_mask(front_image, object_name, mask=None):
|
||||
# 将掩模的3通道转换为4通道,白色部分不透明,黑色部分透明
|
||||
rgba_image = cv2.cvtColor(mask_inverted, cv2.COLOR_BGR2BGRA)
|
||||
rgba_image[rgba_image[:, :, 0] == 0] = [0, 0, 0, 0]
|
||||
# image_bytes = io.BytesIO()
|
||||
# image_bytes.write(cv2.imencode('.png', rgba_image)[1].tobytes())
|
||||
# image_bytes.seek(0)
|
||||
# mask_url = f"{AIDA_CLOTHING}/{minio_client.put_object('aida-clothing', f'mask/mask_{object_name}.png', image_bytes, len(image_bytes.getvalue()), content_type='image/png').object_name}"
|
||||
# oss upload ####################
|
||||
req = oss_upload_image(bucket=AIDA_CLOTHING, object_name=f"mask/mask_{object_name}.png", image_bytes=cv2.imencode('.png', rgba_image)[1])
|
||||
req = oss_upload_image(oss_client=minio_client, bucket=AIDA_CLOTHING, object_name=f"mask/mask_{object_name}.png", image_bytes=cv2.imencode('.png', rgba_image)[1])
|
||||
mask_url = f"{AIDA_CLOTHING}/mask/mask_{object_name}.png"
|
||||
|
||||
image_data = io.BytesIO()
|
||||
front_image.save(image_data, format='PNG')
|
||||
image_data.seek(0)
|
||||
image_bytes = image_data.read()
|
||||
# image_url = f"{AIDA_CLOTHING}/{minio_client.put_object('aida-clothing', f'image/image_{object_name}.png', io.BytesIO(image_bytes), len(image_bytes), content_type='image/png').object_name}"
|
||||
req = oss_upload_image(bucket=AIDA_CLOTHING, object_name=f"image/image_{object_name}.png", image_bytes=image_bytes)
|
||||
req = oss_upload_image(oss_client=minio_client, bucket=AIDA_CLOTHING, object_name=f"image/image_{object_name}.png", image_bytes=image_bytes)
|
||||
image_url = f"{AIDA_CLOTHING}/image/image_{object_name}.png"
|
||||
return front_image, image_url, mask_url
|
||||
except Exception as e:
|
||||
1465
app/service/design_fast/design_generate.py
Normal file
1465
app/service/design_fast/design_generate.py
Normal file
File diff suppressed because it is too large
Load Diff
61
app/service/design_fast/item.py
Normal file
61
app/service/design_fast/item.py
Normal file
@@ -0,0 +1,61 @@
|
||||
from app.service.design_fast.pipeline import LoadImage, KeyPoint, Segmentation, Color, PrintPainting, Scaling, Split, LoadBodyImage, ContourDetection
|
||||
|
||||
|
||||
class BaseItem:
    """Common pipeline item: normalises the request payload into a result
    dict (the 'type' key becomes a lower-cased 'name') merged with the
    shared basic info."""

    def __init__(self, data, basic):
        self.result = dict(data)
        self.result['name'] = data['type'].lower()
        del self.result["type"]
        self.result.update(basic)
|
||||
|
||||
|
||||
class TopItem(BaseItem):
    """Upper-garment item: runs the full pipeline including segmentation."""

    def __init__(self, data, basic, minio_client):
        super().__init__(data, basic)
        # Stages are applied in this exact order.
        self.top_pipeline = [
            LoadImage(minio_client),
            KeyPoint(),
            Segmentation(minio_client),
            Color(minio_client),
            PrintPainting(minio_client),
            Scaling(),
            Split(minio_client),
        ]

    def process(self):
        """Run every stage in order, threading the result dict through."""
        for stage in self.top_pipeline:
            self.result = stage(self.result)
        return self.result
|
||||
|
||||
|
||||
class BottomItem(BaseItem):
    """Lower-garment item: contour detection stands in for segmentation."""

    def __init__(self, data, basic, minio_client):
        super().__init__(data, basic)
        # Stages are applied in this exact order.
        self.bottom_pipeline = [
            LoadImage(minio_client),
            KeyPoint(),
            ContourDetection(),
            # Segmentation(),
            Color(minio_client),
            PrintPainting(minio_client),
            Scaling(),
            Split(minio_client),
        ]

    def process(self):
        """Run every stage in order, threading the result dict through."""
        for stage in self.bottom_pipeline:
            self.result = stage(self.result)
        return self.result
|
||||
|
||||
|
||||
class BodyItem(BaseItem):
    """Body/mannequin item: only needs its image loaded from OSS."""

    def __init__(self, data, basic, minio_client):
        super().__init__(data, basic)
        self.top_pipeline = [
            LoadBodyImage(minio_client),
        ]

    def process(self):
        """Run every stage in order, threading the result dict through."""
        for stage in self.top_pipeline:
            self.result = stage(self.result)
        return self.result
|
||||
20
app/service/design_fast/pipeline/__init__.py
Normal file
20
app/service/design_fast/pipeline/__init__.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from .color import Color
from .contour_detection import ContourDetection
# BUG FIX: `from .keypoint import KeyPoint` was duplicated.
from .keypoint import KeyPoint
from .loading import LoadImage, LoadBodyImage
from .print_painting import PrintPainting
from .scale import Scaling
from .segmentation import Segmentation
from .split import Split

# Public API of the pipeline package.
__all__ = [
    'LoadBodyImage', 'LoadImage',
    'KeyPoint',
    'ContourDetection',
    'Segmentation',
    'Color',
    'PrintPainting',
    'Scaling',
    'Split'
]
|
||||
62
app/service/design_fast/pipeline/color.py
Normal file
62
app/service/design_fast/pipeline/color.py
Normal file
@@ -0,0 +1,62 @@
|
||||
import logging
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from app.service.utils.new_oss_client import oss_get_image
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
class Color:
    """Pipeline stage: fill the garment mask with a solid colour or a
    gradient pattern, shaded by the grayscale image."""

    def __init__(self, minio_client):
        # OSS/minio client used to fetch gradient pattern images.
        self.minio_client = minio_client

    def __call__(self, result):
        """Populate result['pattern_image'], 'final_image', 'single_image'
        and 'alpha' from result['image'], 'mask', 'gray' and either
        'gradient' (an OSS path) or 'color' (an 'R G B' string)."""
        dim_image_h, dim_image_w = result['image'].shape[0:2]
        if "gradient" in result.keys() and result['gradient'] != "":
            # 'gradient' is a '<bucket>/<object>' OSS path.
            bucket_name = result['gradient'].split('/')[0]
            object_name = result['gradient'][result['gradient'].find('/') + 1:]
            pattern = self.get_gradient(bucket_name=bucket_name, object_name=object_name)
        else:
            pattern = self.get_pattern(result['color'])
        resize_pattern = cv2.resize(pattern, (dim_image_w, dim_image_h), interpolation=cv2.INTER_AREA)
        # Broadcast the 1-channel mask/gray maps to 3 channels and use them
        # as multiplicative weights: mask gates the region, gray shades it.
        closed_mo = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
        gray_mo = np.expand_dims(result['gray'], axis=2).repeat(3, axis=2)
        get_image_fir = resize_pattern * (closed_mo / 255) * (gray_mo / 255)
        result['pattern_image'] = get_image_fir.astype(np.uint8)
        result['final_image'] = result['pattern_image']
        # Compose onto a white canvas: background where the mask is off,
        # coloured garment where it is on.
        canvas = np.full_like(result['final_image'], 255)
        temp_bg = np.expand_dims(cv2.bitwise_not(result['mask']), axis=2).repeat(3, axis=2)
        tmp1 = (canvas * (temp_bg / 255)).astype(np.uint8)
        temp_fg = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
        tmp2 = (result['final_image'] * (temp_fg / 255)).astype(np.uint8)
        result['single_image'] = cv2.add(tmp1, tmp2)
        result['alpha'] = 100 / 255.0
        return result

    def get_gradient(self, bucket_name, object_name):
        """Fetch the gradient pattern image from OSS as a 3-channel BGR image."""
        image = oss_get_image(oss_client=self.minio_client, bucket=bucket_name, object_name=object_name, data_type="cv2")
        if image.shape[2] == 4:
            # Drop the alpha channel so arithmetic below stays 3-channel.
            image = cv2.cvtColor(image, cv2.COLOR_BGRA2BGR)
        return image

    @staticmethod
    def crop_image(image, image_size_h, image_size_w):
        """Take a random crop of size (h, w) from *image* (top-left jittered)."""
        x_offset = np.random.randint(low=0, high=int(image_size_h / 5) - 6)
        y_offset = np.random.randint(low=0, high=int(image_size_w / 5) - 6)
        image = image[x_offset: x_offset + image_size_h, y_offset: y_offset + image_size_w, :]
        return image

    @staticmethod
    def get_pattern(single_color):
        """Build a 1x1 BGR pattern from an 'R G B' string.

        Raises:
            ValueError: when *single_color* is None.
        """
        if single_color is None:
            # BUG FIX: the original `raise False` raised a TypeError
            # ("exceptions must derive from BaseException"); raise a
            # meaningful error instead.
            raise ValueError("single_color must be an 'R G B' string, got None")
        R, G, B = single_color.split(' ')
        pattern = np.zeros([1, 1, 3], np.uint8)
        pattern[0, 0, 0] = int(B)
        pattern[0, 0, 1] = int(G)
        pattern[0, 0, 2] = int(R)
        return pattern
|
||||
37
app/service/design_fast/pipeline/contour_detection.py
Normal file
37
app/service/design_fast/pipeline/contour_detection.py
Normal file
@@ -0,0 +1,37 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
|
||||
class ContourDetection:
    """Pipeline stage: derive a garment mask from the largest external
    contour of the item image (falling back to a full-frame mask)."""

    def __call__(self, result):
        contours = self.get_contours(result['image'])
        mask = np.zeros(result['image'].shape[:2], np.uint8)
        if len(contours):
            # Fill a slightly simplified polygon of the largest contour.
            largest = contours[0]
            epsilon = 0.001 * cv2.arcLength(largest, True)
            approx = cv2.approxPolyDP(largest, epsilon, True)
            cv2.drawContours(mask, [approx], -1, 255, -1)
        else:
            # No contour found: treat the whole image as foreground.
            mask = np.ones(result['image'].shape[:2], np.uint8) * 255
        # TODO fix transparency artefacts on some images (next release):
        # img2gray = cv2.cvtColor(result['image'], cv2.COLOR_BGR2GRAY)
        # ret, mask = cv2.threshold(img2gray, 126, 255, cv2.THRESH_BINARY)
        # mask = cv2.bitwise_not(mask)
        if result['pre_mask'] is None:
            result['mask'] = mask
        else:
            # Combine with the alpha-channel mask extracted at load time.
            result['mask'] = cv2.bitwise_and(mask, result['pre_mask'])
        result['front_mask'] = result['mask']
        result['back_mask'] = result['mask']
        return result

    @staticmethod
    def get_contours(image):
        """Return the external contours of *image*, largest area first."""
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        edges = cv2.Canny(gray, 10, 150)
        # Morphological close to join broken edge fragments.
        kernel = np.ones((5, 5), np.uint8)
        edges = cv2.dilate(edges, kernel=kernel, iterations=1)
        edges = cv2.erode(edges, kernel=kernel, iterations=1)
        contours, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        return sorted(contours, key=cv2.contourArea, reverse=True)
|
||||
116
app/service/design_fast/pipeline/keypoint.py
Normal file
116
app/service/design_fast/pipeline/keypoint.py
Normal file
@@ -0,0 +1,116 @@
|
||||
import logging
|
||||
|
||||
import numpy as np
|
||||
from pymilvus import MilvusClient
|
||||
|
||||
from app.core.config import *
|
||||
from app.service.design_fast.utils.design_ensemble import get_keypoint_result
|
||||
from app.service.utils.decorator import ClassCallRunTime, RunTime
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class KeyPoint:
    """Pipeline stage: obtain garment keypoints for supported categories.

    The Milvus vector-cache lookup is currently disabled in __call__:
    every request runs model inference and then upserts the result into
    the Milvus cache table.
    """

    name = "KeyPoint"

    @classmethod
    def get_name(cls):
        """Return the stage name used for logging/identification."""
        return cls.name

    @ClassCallRunTime
    def __call__(self, result):
        """Attach result['clothes_keypoint'] for supported garment names.

        Non-garment items pass through unchanged.
        """
        if result['name'] in ['blouse', 'skirt', 'dress', 'outwear', 'trousers', 'tops', 'bottoms']:  # check whether cached data of the same category exists: read it directly if so, otherwise infer and update the cache
            # result['clothes_keypoint'] = self.infer_keypoint_result(result)
            site = 'up' if result['name'] in ['blouse', 'outwear', 'dress', 'tops'] else 'down'
            # keypoint_cache = search_keypoint_cache(result["image_id"], site)
            # keypoint_cache = self.keypoint_cache(result, site)
            keypoint_cache = False
            # Vector-cache lookup disabled: always run model inference.
            if keypoint_cache is False:
                keypoint_infer_result, site = self.infer_keypoint_result(result)
                result['clothes_keypoint'] = self.save_keypoint_cache(result["image_id"], keypoint_infer_result, site)
            else:
                result['clothes_keypoint'] = keypoint_cache
        return result

    @staticmethod
    def infer_keypoint_result(result):
        """Run keypoint model inference on result['image'].

        Returns:
            (keypoint_infer_result, site): raw inference output and the
            'up'/'down' garment site it was inferred for.
        """
        site = 'up' if result['name'] in ['blouse', 'outwear', 'dress', 'tops'] else 'down'
        keypoint_infer_result = get_keypoint_result(result["image"], site)  # inference result
        return keypoint_infer_result, site

    @staticmethod
    def save_keypoint_cache(keypoint_id, cache, site):
        """Pad the inferred keypoints into the fixed 24-value cache vector,
        upsert it into Milvus, and return the keypoint dict.

        The vector layout is [20 'up' values, 4 'down' values]; the half
        that was not inferred is zero-filled.
        # NOTE(review): assumes 'up' inference flattens to 20 values and
        # 'down' to 4 — confirm against get_keypoint_result's output shape.
        """
        if site == "down":
            zeros = np.zeros(20, dtype=int)
            result = np.concatenate([zeros, cache.flatten()])
        else:
            zeros = np.zeros(4, dtype=int)
            result = np.concatenate([cache.flatten(), zeros])
        # Vector caching of queries is disabled; just persist the result.
        data = [
            {"keypoint_id": keypoint_id,
             "keypoint_site": site,
             "keypoint_vector": result.tolist()
             }
        ]
        try:
            client = MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS)
            res = client.upsert(collection_name=MILVUS_TABLE_KEYPOINT, data=data)
            client.close()
            return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
        except Exception as e:
            # Cache write is best-effort: still return the inferred keypoints.
            logger.info(f"save keypoint cache milvus error : {e}")
            return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))

    @staticmethod
    def update_keypoint_cache(keypoint_id, infer_result, search_result, site):
        """Merge a fresh inference with the cached opposite half and upsert
        the combined vector with site 'all'."""
        if site == "up":
            # We need 'up', so the inference is 'up' and the cached part is 'down'.
            result = np.concatenate([infer_result.flatten(), search_result[-4:]])
        else:
            # We need 'down', so the inference is 'down' and the cached part is 'up'.
            result = np.concatenate([search_result[:20], infer_result.flatten()])
        data = [
            {"keypoint_id": keypoint_id,
             "keypoint_site": "all",
             "keypoint_vector": result.tolist()
             }
        ]

        try:
            client = MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS)
            client.upsert(
                collection_name=MILVUS_TABLE_KEYPOINT,
                data=data
            )
            return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
        except Exception as e:
            # Cache write is best-effort: still return the merged keypoints.
            logger.info(f"save keypoint cache milvus error : {e}")
            return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))

    @RunTime
    def keypoint_cache(self, result, site):
        """Look up cached keypoints in Milvus for result['image_id'].

        Falls back to inference (and cache save/update) on a miss or a
        site mismatch; returns False on any Milvus error.
        (Currently unused — the lookup is disabled in __call__.)
        """
        try:
            client = MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS)
            keypoint_id = result['image_id']
            res = client.query(
                collection_name=MILVUS_TABLE_KEYPOINT,
                # ids=[keypoint_id],
                filter=f"keypoint_id == {keypoint_id}",
                output_fields=['keypoint_vector', 'keypoint_site']
            )
            if len(res) == 0:
                # No cached entry: infer the result and persist it.
                keypoint_infer_result, site = self.infer_keypoint_result(result)
                return self.save_keypoint_cache(result['image_id'], keypoint_infer_result, site)
            elif res[0]["keypoint_site"] == "all" or res[0]["keypoint_site"] == site:
                # Requested site matches (or the cache holds both halves):
                # return the cached vector directly.
                return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, np.array(res[0]['keypoint_vector']).astype(int).reshape(12, 2).tolist()))
            elif res[0]["keypoint_site"] != site:
                # Cached site differs from the requested one: infer the
                # missing half and promote the cache entry to 'all'.
                keypoint_infer_result, site = self.infer_keypoint_result(result)
                return self.update_keypoint_cache(result["image_id"], keypoint_infer_result, res[0]['keypoint_vector'], site)
        except Exception as e:
            logger.info(f"search keypoint cache milvus error {e}")
            return False
|
||||
80
app/service/design_fast/pipeline/loading.py
Normal file
80
app/service/design_fast/pipeline/loading.py
Normal file
@@ -0,0 +1,80 @@
|
||||
import io
|
||||
import logging
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from app.service.utils.new_oss_client import oss_get_image
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
class LoadBodyImage:
    """Pipeline stage: download the body/mannequin image from OSS."""

    name = "LoadBodyImage"

    def __init__(self, minio_client):
        # OSS/minio client used to fetch the body image.
        self.minio_client = minio_client

    @classmethod
    def get_name(cls):
        """Return the stage name used for logging/identification."""
        return cls.name

    def __call__(self, result):
        """Load result['body_path'] ('<bucket>/<object>') as a PIL image."""
        result["name"] = "mannequin"
        path_parts = result['body_path'].split("/", 1)
        result['body_image'] = oss_get_image(oss_client=self.minio_client, bucket=path_parts[0], object_name=path_parts[1], data_type="PIL")
        return result
|
||||
|
||||
|
||||
class LoadImage:
    """Pipeline stage: download the garment image from OSS and attach the
    derived fields (grayscale, alpha mask, keypoint type, shapes)."""

    name = "LoadImage"

    def __init__(self, minio_client):
        # OSS/minio client used to fetch images.
        self.minio_client = minio_client

    @classmethod
    def get_name(cls):
        """Return the stage name used for logging/identification."""
        return cls.name

    def __call__(self, result):
        """Populate *result* with image data loaded from result['path']."""
        result['image'], result['pre_mask'] = self.read_image(result['path'])
        result['gray'] = cv2.cvtColor(result['image'], cv2.COLOR_BGR2GRAY)
        result['keypoint'] = self.get_keypoint(result['name'])
        result['img_shape'] = result['image'].shape
        result['ori_shape'] = result['image'].shape
        return result

    def read_image(self, image_path):
        """Fetch '<bucket>/<object>' from OSS as a 3-channel cv2 image.

        Returns:
            (image, image_mask): the image and its alpha channel as a mask
            (None when the source had no alpha channel).
        """
        image_mask = None
        image = oss_get_image(oss_client=self.minio_client, bucket=image_path.split("/", 1)[0], object_name=image_path.split("/", 1)[1], data_type="cv2")
        if len(image.shape) == 2:
            # Grayscale source: promote to 3 channels.
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        if image.shape[2] == 4:  # 4-channel source: split off the alpha mask
            image_mask = image[:, :, 3]
            image = image[:, :, :3]

        # BUG FIX: the original compared tuples (`image.shape[:2] <= (50, 50)`),
        # which is lexicographic and would also upscale e.g. a 30x2000 image.
        # Require BOTH dimensions to be small before doubling.
        if image.shape[0] <= 50 and image.shape[1] <= 50:
            # Tiny images are doubled so downstream stages have enough pixels.
            # NOTE(review): image_mask is not resized along with the image —
            # confirm downstream consumers tolerate the size mismatch.
            new_size = (image.shape[1] * 2, image.shape[0] * 2)
            image = cv2.resize(image, new_size, interpolation=cv2.INTER_LINEAR)
        return image, image_mask

    @staticmethod
    def get_keypoint(name):
        """Map an item category to the body keypoint used for alignment.

        Raises:
            KeyError: when *name* is not a known category.
        """
        keypoint_by_name = {
            'blouse': 'shoulder',
            'outwear': 'shoulder',
            'dress': 'shoulder',
            'tops': 'shoulder',
            'trousers': 'waistband',
            'skirt': 'waistband',
            'bottoms': 'waistband',
            'bag': 'hand_point',
            'shoes': 'toe',
            'hairstyle': 'head_point',
            'earring': 'ear_point',
        }
        try:
            return keypoint_by_name[name]
        except KeyError:
            raise KeyError(f"{name} does not belong to item category list: blouse, outwear, dress, trousers, skirt, "
                           f"bag, shoes, hairstyle, earring.")
|
||||
524
app/service/design_fast/pipeline/print_painting.py
Normal file
524
app/service/design_fast/pipeline/print_painting.py
Normal file
@@ -0,0 +1,524 @@
|
||||
import random
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from app.service.utils.new_oss_client import oss_get_image
|
||||
|
||||
|
||||
class PrintPainting:
    """Pipeline stage that composites print/pattern artwork onto a garment sketch.

    Three print modes are handled in sequence by ``__call__``:
      * ``overall`` — a single artwork tiled across the whole garment;
      * ``single``  — individual artworks pasted at explicit locations,
        modulated by the garment's gray shading;
      * ``element`` — like ``single`` but pasted on top of the current
        ``final_image`` without gray modulation.

    All images are OpenCV BGR uint8 arrays unless noted; masks are single
    channel uint8 (255 = print area). NOTE(review): the single/element loops
    are near-duplicates with deliberate-looking differences (the 124-threshold
    and the gray multiply exist only in the single branch) — verify before any
    future deduplication.
    """

    def __init__(self, minio_client):
        # OSS/MinIO client used to fetch print artwork.
        self.minio_client = minio_client

    def __call__(self, result):
        """Apply overall, then single, then element prints; return the updated result dict.

        Reads: result['print'] (sub-dicts 'single'/'overall'/'element'),
        'pattern_image', 'mask', 'gray', and (for element) 'final_image'.
        Writes: 'print_image', 'single_image', 'final_image', 'pattern_image'.
        """
        single_print = result['print']['single']
        overall_print = result['print']['overall']
        element_print = result['print']['element']
        result['single_image'] = None
        result['print_image'] = None
        if overall_print['print_path_list']:
            painting_dict = {'dim_image_h': result['pattern_image'].shape[0], 'dim_image_w': result['pattern_image'].shape[1]}
            result['print_image'] = result['pattern_image']
            if "print_angle_list" in overall_print.keys() and overall_print['print_angle_list'][0] != 0:
                # Rotated overall print: tile on a square canvas first so the
                # rotation can be cropped without losing coverage.
                painting_dict = self.painting_collection(painting_dict, overall_print, print_trigger=True)
                painting_dict['tile_print'] = self.rotate_crop_image(img=painting_dict['tile_print'], angle=-overall_print['print_angle_list'][0], crop=True)
                painting_dict['mask_inv_print'] = self.rotate_crop_image(img=painting_dict['mask_inv_print'], angle=-overall_print['print_angle_list'][0], crop=True)

                # Resize back to the sketch's size.
                painting_dict['tile_print'] = self.resize_and_crop(img=painting_dict['tile_print'], target_width=painting_dict['dim_image_w'], target_height=painting_dict['dim_image_h'])
                painting_dict['mask_inv_print'] = self.resize_and_crop(img=painting_dict['mask_inv_print'], target_width=painting_dict['dim_image_w'], target_height=painting_dict['dim_image_h'])
            else:
                painting_dict = self.painting_collection(painting_dict, overall_print, print_trigger=True, is_single=False)
            result['print_image'] = self.printpaint(result, painting_dict, print_=True)
            # The overall print becomes the new base for subsequent modes.
            result['single_image'] = result['final_image'] = result['pattern_image'] = result['print_image']

        if single_print['print_path_list']:
            # Accumulators the size of the sketch: pasted artwork and its mask.
            print_background = np.zeros((result['pattern_image'].shape[0], result['pattern_image'].shape[1], 3), dtype=np.uint8)
            mask_background = np.zeros((result['pattern_image'].shape[0], result['pattern_image'].shape[1], 3), dtype=np.uint8)
            for i in range(len(single_print['print_path_list'])):
                image, image_mode = self.read_image(single_print['print_path_list'][i])
                if image_mode == "RGBA":
                    # Alpha-carrying artwork: scale/rotate via PIL and paste
                    # with its own alpha as the mask.
                    new_size = (int(image.width * single_print['print_scale_list'][i]), int(image.height * single_print['print_scale_list'][i]))

                    mask = image.split()[3]
                    resized_source = image.resize(new_size)
                    resized_source_mask = mask.resize(new_size)

                    rotated_resized_source = resized_source.rotate(-single_print['print_angle_list'][i])
                    rotated_resized_source_mask = resized_source_mask.rotate(-single_print['print_angle_list'][i])

                    source_image_pil = Image.fromarray(cv2.cvtColor(print_background, cv2.COLOR_BGR2RGB))
                    source_image_pil_mask = Image.fromarray(cv2.cvtColor(mask_background, cv2.COLOR_BGR2RGB))

                    source_image_pil.paste(rotated_resized_source, (int(single_print['location'][i][0]), int(single_print['location'][i][1])), rotated_resized_source)
                    source_image_pil_mask.paste(rotated_resized_source_mask, (int(single_print['location'][i][0]), int(single_print['location'][i][1])), rotated_resized_source_mask)

                    print_background = cv2.cvtColor(np.array(source_image_pil), cv2.COLOR_RGBA2BGR)
                    mask_background = cv2.cvtColor(np.array(source_image_pil_mask), cv2.COLOR_RGBA2BGR)
                    # Binarize semi-transparent alpha at 124 so the final
                    # bitwise compositing sees a hard mask.
                    ret, mask_background = cv2.threshold(mask_background, 124, 255, cv2.THRESH_BINARY)
                else:
                    # Opaque artwork: derive a mask from the (near-)white
                    # background, then scale/rotate via cv2.
                    mask = self.get_mask_inv(image)
                    mask = np.expand_dims(mask, axis=2)
                    mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR)
                    mask = cv2.bitwise_not(mask)
                    # Paste coordinates must be recomputed after rotation,
                    # since rotation grows the canvas.
                    rotate_mask, _ = self.img_rotate(mask, single_print['print_angle_list'][i], single_print['print_scale_list'][i])
                    rotate_image, rotated_new_size = self.img_rotate(image, single_print['print_angle_list'][i], single_print['print_scale_list'][i])
                    x, y = int(single_print['location'][i][0] - rotated_new_size[0]), int(single_print['location'][i][1] - rotated_new_size[1])

                    image_x = print_background.shape[1]
                    image_y = print_background.shape[0]
                    print_x = rotate_image.shape[1]
                    print_y = rotate_image.shape[0]

                    # Clamp order matters: first shift/trim at the left/top
                    # edges, then trim at the right/bottom edges — clipping in
                    # the opposite order would shift the pasted region.
                    if x <= 0:
                        rotate_image = rotate_image[:, -x:]
                        rotate_mask = rotate_mask[:, -x:]
                        start_x = x = 0
                    else:
                        start_x = x

                    if y <= 0:
                        rotate_image = rotate_image[-y:, :]
                        rotate_mask = rotate_mask[-y:, :]
                        start_y = y = 0
                    else:
                        start_y = y

                    # If the print extends past the sketch, trim the overflow.
                    if x + print_x > image_x:
                        rotate_image = rotate_image[:, :image_x - x]
                        rotate_mask = rotate_mask[:, :image_x - x]

                    if y + print_y > image_y:
                        rotate_image = rotate_image[:image_y - y, :]
                        rotate_mask = rotate_mask[:image_y - y, :]

                    mask_background = self.stack_prin(mask_background, result['pattern_image'], rotate_mask, start_y, y, start_x, x)
                    print_background = self.stack_prin(print_background, result['pattern_image'], rotate_image, start_y, y, start_x, x)

            # Composite all pasted prints onto the garment, restricted to the
            # garment mask and shaded by the gray sketch.
            print_mask = cv2.bitwise_and(result['mask'], cv2.cvtColor(mask_background, cv2.COLOR_BGR2GRAY))
            img_fg = cv2.bitwise_or(print_background, print_background, mask=print_mask)
            img_bg = cv2.bitwise_and(result['pattern_image'], result['pattern_image'], mask=cv2.bitwise_not(print_mask))
            mask_mo = np.expand_dims(print_mask, axis=2).repeat(3, axis=2)
            gray_mo = np.expand_dims(result['gray'], axis=2).repeat(3, axis=2)
            img_fg = (img_fg * (mask_mo / 255) * (gray_mo / 255)).astype(np.uint8)
            result['final_image'] = cv2.add(img_bg, img_fg)
            # single_image = final image on a white canvas outside the garment.
            canvas = np.full_like(result['final_image'], 255)
            temp_bg = np.expand_dims(cv2.bitwise_not(result['mask']), axis=2).repeat(3, axis=2)
            tmp1 = (canvas * (temp_bg / 255)).astype(np.uint8)
            temp_fg = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
            tmp2 = (result['final_image'] * (temp_fg / 255)).astype(np.uint8)
            result['single_image'] = cv2.add(tmp1, tmp2)

        if element_print['element_path_list']:
            # Same paste loop as the single branch, but applied on top of
            # 'final_image' and — unlike single — without the alpha threshold
            # and without gray shading of the foreground.
            print_background = np.zeros((result['final_image'].shape[0], result['final_image'].shape[1], 3), dtype=np.uint8)
            mask_background = np.zeros((result['final_image'].shape[0], result['final_image'].shape[1], 3), dtype=np.uint8)
            for i in range(len(element_print['element_path_list'])):
                image, image_mode = self.read_image(element_print['element_path_list'][i])
                if image_mode == "RGBA":
                    new_size = (int(image.width * element_print['element_scale_list'][i]), int(image.height * element_print['element_scale_list'][i]))

                    mask = image.split()[3]
                    resized_source = image.resize(new_size)
                    resized_source_mask = mask.resize(new_size)

                    rotated_resized_source = resized_source.rotate(-element_print['element_angle_list'][i])
                    rotated_resized_source_mask = resized_source_mask.rotate(-element_print['element_angle_list'][i])

                    source_image_pil = Image.fromarray(cv2.cvtColor(print_background, cv2.COLOR_BGR2RGB))
                    source_image_pil_mask = Image.fromarray(cv2.cvtColor(mask_background, cv2.COLOR_BGR2RGB))

                    source_image_pil.paste(rotated_resized_source, (int(element_print['location'][i][0]), int(element_print['location'][i][1])), rotated_resized_source)
                    source_image_pil_mask.paste(rotated_resized_source_mask, (int(element_print['location'][i][0]), int(element_print['location'][i][1])), rotated_resized_source_mask)

                    print_background = cv2.cvtColor(np.array(source_image_pil), cv2.COLOR_RGBA2BGR)
                    mask_background = cv2.cvtColor(np.array(source_image_pil_mask), cv2.COLOR_RGBA2BGR)
                else:
                    mask = self.get_mask_inv(image)
                    mask = np.expand_dims(mask, axis=2)
                    mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR)
                    mask = cv2.bitwise_not(mask)
                    # Paste coordinates must be recomputed after rotation.
                    rotate_mask, _ = self.img_rotate(mask, element_print['element_angle_list'][i], element_print['element_scale_list'][i])
                    rotate_image, rotated_new_size = self.img_rotate(image, element_print['element_angle_list'][i], element_print['element_scale_list'][i])
                    x, y = int(element_print['location'][i][0] - rotated_new_size[0]), int(element_print['location'][i][1] - rotated_new_size[1])

                    image_x = print_background.shape[1]
                    image_y = print_background.shape[0]
                    print_x = rotate_image.shape[1]
                    print_y = rotate_image.shape[0]

                    # Same edge-clamp ordering as the single branch above.
                    if x <= 0:
                        rotate_image = rotate_image[:, -x:]
                        rotate_mask = rotate_mask[:, -x:]
                        start_x = x = 0
                    else:
                        start_x = x

                    if y <= 0:
                        rotate_image = rotate_image[-y:, :]
                        rotate_mask = rotate_mask[-y:, :]
                        start_y = y = 0
                    else:
                        start_y = y

                    if x + print_x > image_x:
                        rotate_image = rotate_image[:, :image_x - x]
                        rotate_mask = rotate_mask[:, :image_x - x]

                    if y + print_y > image_y:
                        rotate_image = rotate_image[:image_y - y, :]
                        rotate_mask = rotate_mask[:image_y - y, :]

                    mask_background = self.stack_prin(mask_background, result['pattern_image'], rotate_mask, start_y, y, start_x, x)
                    print_background = self.stack_prin(print_background, result['pattern_image'], rotate_image, start_y, y, start_x, x)

            print_mask = cv2.bitwise_and(result['mask'], cv2.cvtColor(mask_background, cv2.COLOR_BGR2GRAY))
            img_fg = cv2.bitwise_or(print_background, print_background, mask=print_mask)
            # TODO(original): elements lose information here — the foreground
            # is not gray-shaded like the single branch.
            three_channel_image = cv2.merge([cv2.bitwise_not(print_mask), cv2.bitwise_not(print_mask), cv2.bitwise_not(print_mask)])
            img_bg = cv2.bitwise_and(result['final_image'], three_channel_image)
            result['final_image'] = cv2.add(img_bg, img_fg)
            canvas = np.full_like(result['final_image'], 255)
            temp_bg = np.expand_dims(cv2.bitwise_not(result['mask']), axis=2).repeat(3, axis=2)
            tmp1 = (canvas * (temp_bg / 255)).astype(np.uint8)
            temp_fg = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
            tmp2 = (result['final_image'] * (temp_fg / 255)).astype(np.uint8)
            result['single_image'] = cv2.add(tmp1, tmp2)
        return result

    @staticmethod
    def stack_prin(print_background, pattern_image, rotate_image, start_y, y, start_x, x):
        """Paste *rotate_image* into *print_background* at (start_x, start_y).

        The paste goes through a temporary canvas sized like *pattern_image*;
        non-black pixels of the pasted region overwrite the accumulator,
        black pixels keep whatever was already there.
        """
        temp_print = np.zeros((pattern_image.shape[0], pattern_image.shape[1], 3), dtype=np.uint8)
        temp_print[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = rotate_image
        img2gray = cv2.cvtColor(temp_print, cv2.COLOR_BGR2GRAY)
        ret, mask_ = cv2.threshold(img2gray, 1, 255, cv2.THRESH_BINARY)
        mask_inv = cv2.bitwise_not(mask_)
        img1_bg = cv2.bitwise_and(print_background, print_background, mask=mask_inv)
        img2_fg = cv2.bitwise_and(temp_print, temp_print, mask=mask_)
        print_background = img1_bg + img2_fg
        return print_background

    def painting_collection(self, painting_dict, print_dict, print_trigger=False, is_single=False):
        """Build the tiled print and its inverse mask for the overall-print path.

        When the print will be rotated afterwards, tiling is done on a square
        ``dim_max x dim_max`` canvas so the rotation crop keeps full coverage.
        Returns *painting_dict* extended with 'Trigger', 'location',
        'tile_print', 'mask_inv_print', 'dim_print_h'/'dim_print_w'.
        """
        if print_trigger:
            print_ = self.get_print(print_dict)
            painting_dict['Trigger'] = not is_single
            painting_dict['location'] = print_['location']
            single_mask_inv_print = self.get_mask_inv(print_['image'])
            dim_max = max(painting_dict['dim_image_h'], painting_dict['dim_image_w'])
            # One tile is scale/5 of the sketch's larger side, kept square.
            dim_pattern = (int(dim_max * print_['scale'] / 5), int(dim_max * print_['scale'] / 5))
            if not is_single:
                self.random_seed = random.randint(0, 1000)
                # Rotated overall print: tile a square canvas for later cropping.
                if "print_angle_list" in print_dict.keys() and print_dict['print_angle_list'][0] != 0:
                    painting_dict['mask_inv_print'] = self.tile_image(single_mask_inv_print, dim_pattern, print_['scale'], dim_max, dim_max, painting_dict['location'], trigger=True)
                    painting_dict['tile_print'] = self.tile_image(print_['image'], dim_pattern, print_['scale'], dim_max, dim_max, painting_dict['location'], trigger=True)
                else:
                    painting_dict['mask_inv_print'] = self.tile_image(single_mask_inv_print, dim_pattern, print_['scale'], painting_dict['dim_image_h'], painting_dict['dim_image_w'], painting_dict['location'], trigger=True)
                    painting_dict['tile_print'] = self.tile_image(print_['image'], dim_pattern, print_['scale'], painting_dict['dim_image_h'], painting_dict['dim_image_w'], painting_dict['location'], trigger=True)
            else:
                painting_dict['mask_inv_print'] = self.tile_image(single_mask_inv_print, dim_pattern, print_['scale'], painting_dict['dim_image_h'], painting_dict['dim_image_w'], painting_dict['location'])
                painting_dict['tile_print'] = self.tile_image(print_['image'], dim_pattern, print_['scale'], painting_dict['dim_image_h'], painting_dict['dim_image_w'], painting_dict['location'])
            painting_dict['dim_print_h'], painting_dict['dim_print_w'] = dim_pattern
        return painting_dict

    def tile_image(self, pattern, dim, scale, dim_image_h, dim_image_w, location, trigger=False):
        """Resize *pattern* to *dim*; when *trigger*, tile it and crop to the sketch size.

        Handles both 2-D (mask) and 3-D (BGR) patterns. The tile-repeat count
        is over-provisioned, then ``crop_image`` cuts the exact window.
        """
        tile = None
        if not trigger:
            tile = cv2.resize(pattern, dim, interpolation=cv2.INTER_AREA)
        else:
            resize_pattern = cv2.resize(pattern, dim, interpolation=cv2.INTER_AREA)
            if len(pattern.shape) == 2:
                tile = np.tile(resize_pattern, (int((5 + 1) / scale) + 4, int((5 + 1) / scale) + 4))
            if len(pattern.shape) == 3:
                tile = np.tile(resize_pattern, (int((5 + 1) / scale) + 4, int((5 + 1) / scale) + 4, 1))
            tile = self.crop_image(tile, dim_image_h, dim_image_w, location, resize_pattern.shape)
        return tile

    def get_mask_inv(self, print_):
        """Return a mask of the artwork's background (255 = background).

        Only a pure-white top-left pixel triggers background detection (LAB
        inRange around that pixel, ±30 per channel); otherwise the artwork is
        assumed fully opaque and an all-zero mask is returned.
        """
        if print_[0][0][0] == 255 and print_[0][0][1] == 255 and print_[0][0][2] == 255:
            bg_color = cv2.cvtColor(print_, cv2.COLOR_BGR2LAB)[0][0]
            print_tile = cv2.cvtColor(print_, cv2.COLOR_BGR2LAB)
            bg_l, bg_a, bg_b = bg_color[0], bg_color[1], bg_color[2]
            bg_L_high, bg_L_low = self.get_low_high_lab(bg_l, L=True)
            bg_a_high, bg_a_low = self.get_low_high_lab(bg_a)
            bg_b_high, bg_b_low = self.get_low_high_lab(bg_b)
            lower = np.array([bg_L_low, bg_a_low, bg_b_low])
            upper = np.array([bg_L_high, bg_a_high, bg_b_high])
            mask_inv = cv2.inRange(print_tile, lower, upper)
            return mask_inv
        else:
            # Non-white background: treat the whole artwork as foreground.
            mask_inv = np.zeros(print_.shape[:2], dtype=np.uint8)
            return mask_inv

    @staticmethod
    def printpaint(result, painting_dict, print_=False):
        """Composite the tiled print (or per-location copies) onto the pattern image.

        With ``painting_dict['Trigger']`` set, the whole garment mask receives
        the tiled print; otherwise copies are placed at each 'location'. The
        foreground is always modulated by the gray sketch for shading.
        """

        if print_ and painting_dict['Trigger']:
            print_mask = cv2.bitwise_and(result['mask'], cv2.bitwise_not(painting_dict['mask_inv_print']))
            img_fg = cv2.bitwise_and(painting_dict['tile_print'], painting_dict['tile_print'], mask=print_mask)
        else:
            print_mask = result['mask']
            img_fg = result['final_image']
        if print_ and not painting_dict['Trigger']:
            index_ = None
            try:
                index_ = len(painting_dict['location'])
            except:
                # NOTE(review): asserting a non-empty string is always True —
                # this never raises; 'location' is effectively required.
                assert f'there must be parameter of location if choose IfSingle'

            for i in range(index_):
                # location is (x, y); slicing is (row, col) = (h, w).
                start_h, start_w = int(painting_dict['location'][i][1]), int(painting_dict['location'][i][0])

                length_h = min(start_h + painting_dict['dim_print_h'], img_fg.shape[0])
                length_w = min(start_w + painting_dict['dim_print_w'], img_fg.shape[1])

                change_region = img_fg[start_h: length_h, start_w: length_w, :]
                # NOTE(review): change_mask and mask are computed but unused
                # below — kept as in the original ("problem in change_mask").
                change_mask = print_mask[start_h: length_h, start_w: length_w]
                _, change_mask = cv2.threshold(change_mask, 220, 255, cv2.THRESH_BINARY)
                mask = cv2.bitwise_not(painting_dict['mask_inv_print'])
                img_fg[start_h:start_h + painting_dict['dim_print_h'], start_w:start_w + painting_dict['dim_print_w'], :] = change_region

        clothes_mask_print = cv2.bitwise_not(print_mask)

        img_bg = cv2.bitwise_and(result['pattern_image'], result['pattern_image'], mask=clothes_mask_print)
        mask_mo = np.expand_dims(print_mask, axis=2).repeat(3, axis=2)
        gray_mo = np.expand_dims(result['gray'], axis=2).repeat(3, axis=2)
        img_fg = (img_fg * (mask_mo / 255) * (gray_mo / 255)).astype(np.uint8)
        print_image = cv2.add(img_bg, img_fg)
        return print_image

    def get_print(self, print_dict):
        """Fetch the first print artwork from OSS and normalize it to BGR.

        Clamps the print scale to a minimum of 0.3. RGBA artwork is flattened
        onto a white background so transparent areas don't turn black.
        Adds 'scale' and 'image' to *print_dict* and returns it.
        """
        if 'print_scale_list' not in print_dict.keys() or print_dict['print_scale_list'][0] < 0.3:
            print_dict['scale'] = 0.3
        else:
            print_dict['scale'] = print_dict['print_scale_list'][0]

        bucket_name = print_dict['print_path_list'][0].split("/", 1)[0]
        object_name = print_dict['print_path_list'][0].split("/", 1)[1]
        image = oss_get_image(oss_client=self.minio_client, bucket=bucket_name, object_name=object_name, data_type="PIL")
        # RGBA: paste onto pure white so transparency doesn't become black.
        if image.mode == "RGBA":
            new_background = Image.new('RGB', image.size, (255, 255, 255))
            new_background.paste(image, mask=image.split()[3])
            image = new_background
        print_dict['image'] = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
        return print_dict

    def crop_image(self, image, image_size_h, image_size_w, location, print_shape):
        """Cut an ``image_size_h x image_size_w`` window from the tiled *image*.

        Offsets are derived from the first print location modulo one tile so
        the visible tiling lines up with the requested anchor point.
        """
        print_w = print_shape[1]
        print_h = print_shape[0]

        random.seed(self.random_seed)
        # Offset within one tile; location is (x, y) while slicing is (h, w).
        # NOTE(review): y_offset subtracts from print_w but wraps by print_h —
        # looks like a w/h mix-up; harmless while tiles are square. Confirm.
        x_offset = print_w - int(location[0][1] % print_w)
        y_offset = print_w - int(location[0][0] % print_h)

        if len(image.shape) == 2:
            image = image[x_offset: x_offset + image_size_h, y_offset: y_offset + image_size_w]
        elif len(image.shape) == 3:
            image = image[x_offset: x_offset + image_size_h, y_offset: y_offset + image_size_w, :]
        return image

    @staticmethod
    def get_low_high_lab(Lab_value, L=False):
        """Return the (high, low) bounds of a ±30 window around *Lab_value*, clamped to [0, 255].

        NOTE(review): both branches are currently identical; the *L* flag has
        no effect and is kept only for interface stability.
        """
        if L:
            high = Lab_value + 30 if Lab_value + 30 < 255 else 255
            low = Lab_value - 30 if Lab_value - 30 > 0 else 0
        else:
            high = Lab_value + 30 if Lab_value + 30 < 255 else 255
            low = Lab_value - 30 if Lab_value - 30 > 0 else 0
        return high, low

    @staticmethod
    def img_rotate(image, angel, scale):
        """Rotate *image* clockwise by *angel* degrees with scaling, expanding the canvas.

        Args:
            image (np.array): source image.
            angel (float): rotation angle in degrees (applied clockwise here
                via the negated cv2 angle).
            scale: uniform scale factor applied by the rotation matrix.

        Returns:
            tuple: (rotated image, (dx, dy)) where (dx, dy) is the offset of
            the original (scaled) image inside the expanded canvas.
        """

        h, w = image.shape[:2]
        center = (w // 2, h // 2)
        M = cv2.getRotationMatrix2D(center, -angel, scale)
        # Grow the output canvas to fit the rotated bounding box.
        rotated_h = int((w * np.abs(M[0, 1]) + (h * np.abs(M[0, 0]))))
        rotated_w = int((h * np.abs(M[0, 1]) + (w * np.abs(M[0, 0]))))
        M[0, 2] += (rotated_w - w) // 2
        M[1, 2] += (rotated_h - h) // 2
        rotated_img = cv2.warpAffine(image, M, (rotated_w, rotated_h))

        return rotated_img, ((rotated_img.shape[1] - image.shape[1] * scale) // 2, (rotated_img.shape[0] - image.shape[0] * scale) // 2)

    @staticmethod
    def rotate_crop_image(img, angle, crop):
        """Rotate *img* by *angle* degrees in place (same canvas size).

        Args:
            angle: rotation angle in degrees.
            crop: when True, crop away the black corners introduced by the
                rotation using the largest inscribed same-aspect rectangle.
        """
        crop_image = lambda img, x0, y0, w, h: img[y0:y0 + h, x0:x0 + w]
        # NOTE(review): shape[:2] is (h, w) but is unpacked as (w, h) here;
        # self-consistent below, exact only for square images. Confirm.
        w, h = img.shape[:2]
        # Rotation is periodic in 360 degrees.
        angle %= 360
        M_rotation = cv2.getRotationMatrix2D((w / 2, h / 2), angle, 1)
        img_rotated = cv2.warpAffine(img, M_rotation, (w, h))

        # Optionally remove the black borders.
        if crop:
            # The crop geometry repeats every 180 degrees.
            angle_crop = angle % 180
            if angle > 90:
                angle_crop = 180 - angle_crop
            # Degrees -> radians.
            theta = angle_crop * np.pi / 180
            hw_ratio = float(h) / float(w)
            # Numerator of the inscribed-rectangle side coefficient.
            tan_theta = np.tan(theta)
            numerator = np.cos(theta) + np.sin(theta) * np.tan(theta)

            # Denominator term depending on the aspect ratio.
            r = hw_ratio if h > w else 1 / hw_ratio
            denominator = r * tan_theta + 1
            # Final side-length coefficient.
            crop_mult = numerator / denominator

            # Centered crop window.
            w_crop = int(crop_mult * w)
            h_crop = int(crop_mult * h)
            x0 = int((w - w_crop) / 2)
            y0 = int((h - h_crop) / 2)

            img_rotated = crop_image(img_rotated, x0, y0, w_crop, h_crop)

        return img_rotated

    def read_image(self, image_url):
        """Fetch ``bucket/object`` *image_url* from OSS.

        Returns ``(image, mode)``: a PIL RGBA image when the source has an
        alpha channel, otherwise the raw cv2 BGR array with mode "RGB".
        """
        image = oss_get_image(oss_client=self.minio_client, bucket=image_url.split("/", 1)[0], object_name=image_url.split("/", 1)[1], data_type="cv2")
        if image.shape[2] == 4:
            image_rgb = cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA)
            image = Image.fromarray(image_rgb)
            image_mode = "RGBA"
        else:
            image_mode = "RGB"
        return image, image_mode

    @staticmethod
    def resize_and_crop(img, target_width, target_height):
        """Resize *img* to cover the target size, then center-crop the excess.

        Preserves aspect ratio: scales so the smaller relative dimension fits
        exactly, then crops the other dimension symmetrically.
        """
        original_height, original_width = img.shape[:2]

        # Aspect ratios of the target and of the source.
        target_ratio = target_width / target_height
        original_ratio = original_width / original_height

        if original_ratio > target_ratio:
            # Source is relatively wider: fit height, crop width.
            new_height = target_height
            new_width = int(original_width * (target_height / original_height))
            resized_img = cv2.resize(img, (new_width, new_height))
            start_x = (new_width - target_width) // 2
            cropped_img = resized_img[:, start_x:start_x + target_width]
        else:
            # Source is relatively taller: fit width, crop height.
            new_width = target_width
            new_height = int(original_height * (target_width / original_width))
            resized_img = cv2.resize(img, (new_width, new_height))
            start_y = (new_height - target_height) // 2
            cropped_img = resized_img[start_y:start_y + target_height, :]

        return cropped_img
|
||||
49
app/service/design_fast/pipeline/scale.py
Normal file
49
app/service/design_fast/pipeline/scale.py
Normal file
@@ -0,0 +1,49 @@
|
||||
import math
|
||||
|
||||
import cv2
|
||||
|
||||
|
||||
class Scaling:
    """Pipeline stage: compute the clothes-to-body scale factor for the item.

    The computation depends on the anchor keypoint chosen for the category:
    a left/right keypoint span for shoulder/waistband/head items, a contour
    width for shoes, and precomputed scales for bags and earrings. The
    factor is written to ``result['scale']``.
    """

    def __call__(self, result):
        anchor = result['keypoint']
        if anchor in ('waistband', 'shoulder', 'head_point'):
            clothes_left = result['clothes_keypoint'][anchor + '_left']
            clothes_right = result['clothes_keypoint'][anchor + '_right']
            body_left = result['body_point_test'][anchor + '_left']
            body_right = result['body_point_test'][anchor + '_right']

            # Euclidean span between the garment's left/right anchor points.
            distance_clo = math.sqrt(
                (int(clothes_left[0]) - int(clothes_right[0])) ** 2
                +
                (int(clothes_left[1]) - int(clothes_right[1])) ** 2
            )

            # Body span uses only the horizontal component (+1 avoids a zero span).
            distance_bdy = math.sqrt(
                (int(body_left[0]) - int(body_right[0])) ** 2 + 1
            )

            # Degenerate garment span falls back to a neutral scale of 1.
            result['scale'] = 1 if distance_clo == 0 else distance_bdy / distance_clo
        elif anchor == 'toe':
            foot = result['body_point_test']['foot_length']
            # Foot length from its two endpoint coordinates (x0, y0, x1, y1).
            distance_bdy = math.sqrt(
                (int(foot[0]) - int(foot[2])) ** 2
                +
                (int(foot[1]) - int(foot[3])) ** 2
            )

            # Garment width via the widest external contour of the gray sketch.
            blurred = cv2.GaussianBlur(result['gray'], (3, 3), 0)
            edge_map = cv2.Canny(blurred, 10, 200)
            edge_map = cv2.dilate(edge_map, None)
            edge_map = cv2.erode(edge_map, None)
            found_contours, _ = cv2.findContours(edge_map, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            largest_contour = max(found_contours, key=cv2.contourArea)

            _, _, bbox_w, _ = cv2.boundingRect(largest_contour)
            result['scale'] = distance_bdy / bbox_w
        elif anchor == 'hand_point':
            # Bags: scale was precomputed upstream.
            result['scale'] = result['scale_bag']
        elif anchor == 'ear_point':
            # Earrings: scale was precomputed upstream.
            result['scale'] = result['scale_earrings']
        return result
|
||||
85
app/service/design_fast/pipeline/segmentation.py
Normal file
85
app/service/design_fast/pipeline/segmentation.py
Normal file
@@ -0,0 +1,85 @@
|
||||
import logging
|
||||
import os
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from app.core.config import SEG_CACHE_PATH
|
||||
from app.service.design_fast.utils.design_ensemble import get_seg_result
|
||||
from app.service.utils.decorator import ClassCallRunTime
|
||||
from app.service.utils.new_oss_client import oss_get_image
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
class Segmentation:
    """Pipeline stage: obtain the front/back garment segmentation masks.

    Two paths: (a) a user-supplied red/green mask image referenced by
    ``result['seg_mask_url']``; (b) model inference via ``get_seg_result``,
    with an on-disk ``.npy`` cache under ``SEG_CACHE_PATH`` keyed by
    ``image_id``. Writes 'front_mask', 'back_mask' and 'mask' (all uint8,
    255 = covered) into the result dict.
    """

    def __init__(self, minio_client):
        # OSS/MinIO client used to fetch a user-supplied segmentation mask.
        self.minio_client = minio_client

    @ClassCallRunTime
    def __call__(self, result):
        """Fill segmentation masks in *result* and return it."""
        if "seg_mask_url" in result.keys() and result['seg_mask_url'] != "":
            # Path (a): user-provided mask image where red = front, green = back.
            seg_mask = oss_get_image(oss_client=self.minio_client, bucket=result['seg_mask_url'].split('/')[0], object_name=result['seg_mask_url'][result['seg_mask_url'].find('/') + 1:], data_type="cv2")
            # Nearest-neighbour keeps the mask labels crisp at the image size.
            seg_mask = cv2.resize(seg_mask, (result['img_shape'][1], result['img_shape'][0]), interpolation=cv2.INTER_NEAREST)
            # Convert to RGB (OpenCV loads BGR) before splitting channels.
            image_rgb = cv2.cvtColor(seg_mask, cv2.COLOR_BGR2RGB)

            r, g, b = cv2.split(image_rgb)
            # Red-dominant pixels -> front; green-dominant -> back.
            red_mask = r > g
            green_mask = g > r

            result['front_mask'] = np.array(red_mask, dtype=np.uint8) * 255
            result['back_mask'] = np.array(green_mask, dtype=np.uint8) * 255
            # Masks are disjoint by construction, so plain addition is safe.
            result['mask'] = result['front_mask'] + result['back_mask']
        else:
            # Path (b): run or reuse the segmentation model.
            # preview: run the model, do NOT cache the result.
            if "preview_submit" in result.keys() and result['preview_submit'] == "preview":
                seg_result = get_seg_result(result["image_id"], result['image'])[0]
            # submit: run the model AND cache the result.
            elif "preview_submit" in result.keys() and result['preview_submit'] == "submit":
                seg_result = get_seg_result(result["image_id"], result['image'])[0]
                self.save_seg_result(seg_result, result['image_id'])
            # default: try the local cache, fall back to the model on miss.
            else:
                # Check whether a cached segmentation exists for this image.
                _, seg_result = self.load_seg_result(result["image_id"])
                # Reject the cache if absent or its size no longer matches the image.
                if not _ or result["image"].shape[:2] != seg_result.shape:
                    seg_result = get_seg_result(result["image_id"], result['image'])[0]
                    self.save_seg_result(seg_result, result['image_id'])
            result['seg_result'] = seg_result

            # Split the label map: label 1.0 = front piece, 2.0 = back piece.
            temp_front = seg_result == 1.0
            result['front_mask'] = (255 * (temp_front + 0).astype(np.uint8))
            temp_back = seg_result == 2.0
            result['back_mask'] = (255 * (temp_back + 0).astype(np.uint8))
            result['mask'] = result['front_mask'] + result['back_mask']
        return result

    @staticmethod
    def save_seg_result(seg_result, image_id):
        """Persist *seg_result* to the local cache as ``<SEG_CACHE_PATH><image_id>.npy``.

        Failures are logged and swallowed — caching is best-effort.
        """
        file_path = f"{SEG_CACHE_PATH}{image_id}.npy"
        try:
            np.save(file_path, seg_result)
            logger.info(f"保存成功 :{os.path.abspath(file_path)}")
        except Exception as e:
            logger.error(f"保存失败: {e}")

    @staticmethod
    def load_seg_result(image_id):
        """Load a cached segmentation for *image_id*.

        Returns:
            tuple: ``(True, array)`` on a cache hit, ``(False, None)`` on a
            miss or any load error (logged, never raised).
        """
        file_path = f"{SEG_CACHE_PATH}{image_id}.npy"
        logger.info(f"load seg file name is :{SEG_CACHE_PATH}{image_id}.npy")
        try:
            seg_result = np.load(file_path)
            return True, seg_result
        except FileNotFoundError:
            logger.warning("文件不存在")
            return False, None
        except Exception as e:
            logger.error(f"加载失败: {e}")
            return False, None
|
||||
74
app/service/design_fast/pipeline/split.py
Normal file
74
app/service/design_fast/pipeline/split.py
Normal file
@@ -0,0 +1,74 @@
|
||||
import io
|
||||
import logging
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from cv2 import cvtColor, COLOR_BGR2RGBA
|
||||
|
||||
from app.core.config import AIDA_CLOTHING
|
||||
from app.service.design_fast.utils.conversion_image import rgb_to_rgba
|
||||
from app.service.design_fast.utils.upload_image import upload_png_mask
|
||||
from app.service.utils.generate_uuid import generate_uuid
|
||||
from app.service.utils.new_oss_client import oss_upload_image
|
||||
|
||||
|
||||
class Split(object):
    """Split a garment layer into front/back pieces plus a combined mask.

    Called with the pipeline `result` dict: cuts the final image by the
    front/back masks, uploads the pieces and a red/green mask visualisation
    to OSS, and writes the resulting URLs back into `result`.
    """

    def __init__(self, minio_client):
        # OSS/minio client used for all uploads in __call__
        self.minio_client = minio_client

    def __call__(self, result):
        try:

            if result['name'] in ('outwear', 'dress', 'blouse', 'skirt', 'trousers', 'tops', 'bottoms'):
                front_mask = result['front_mask']
                back_mask = result['back_mask']
                # Make everything outside the garment transparent, then scale
                rgba_image = rgb_to_rgba(result['final_image'], front_mask + back_mask)
                new_size = (int(rgba_image.shape[1] * result["scale"] * result["resize_scale"][0]), int(rgba_image.shape[0] * result["scale"] * result["resize_scale"][1]))
                rgba_image = cv2.resize(rgba_image, new_size)
                # Front piece: copy only the front-mask pixels
                result_front_image = np.zeros_like(rgba_image)
                front_mask = cv2.resize(front_mask, new_size)
                result_front_image[front_mask != 0] = rgba_image[front_mask != 0]
                result_front_image_pil = Image.fromarray(cvtColor(result_front_image, COLOR_BGR2RGBA))
                result['front_image'], result["front_image_url"], _ = upload_png_mask(self.minio_client, result_front_image_pil, f'{generate_uuid()}', mask=None)

                # Visualisation mask: front pixels red, back pixels green (BGR)
                height, width = front_mask.shape
                mask_image = np.zeros((height, width, 3))
                mask_image[front_mask != 0] = [0, 0, 255]

                # Only upper-body garments carry a distinct back piece
                if result["name"] in ('blouse', 'dress', 'outwear', 'tops'):
                    result_back_image = np.zeros_like(rgba_image)
                    back_mask = cv2.resize(back_mask, new_size)
                    result_back_image[back_mask != 0] = rgba_image[back_mask != 0]
                    result_back_image_pil = Image.fromarray(cvtColor(result_back_image, COLOR_BGR2RGBA))
                    result['back_image'], result["back_image_url"], _ = upload_png_mask(self.minio_client, result_back_image_pil, f'{generate_uuid()}', mask=None)
                    mask_image[back_mask != 0] = [0, 255, 0]

                    rbga_mask = rgb_to_rgba(mask_image, front_mask + back_mask)
                    mask_pil = Image.fromarray(cvtColor(rbga_mask.astype(np.uint8), COLOR_BGR2RGBA))
                    image_data = io.BytesIO()
                    mask_pil.save(image_data, format='PNG')
                    image_data.seek(0)
                    image_bytes = image_data.read()
                    req = oss_upload_image(oss_client=self.minio_client, bucket=AIDA_CLOTHING, object_name=f"mask/mask_{generate_uuid()}.png", image_bytes=image_bytes)
                    result['mask_url'] = req.bucket_name + "/" + req.object_name
                else:
                    # Lower-body garments: mask is front-only, no back piece
                    rbga_mask = rgb_to_rgba(mask_image, front_mask)
                    mask_pil = Image.fromarray(cvtColor(rbga_mask.astype(np.uint8), COLOR_BGR2RGBA))
                    image_data = io.BytesIO()
                    mask_pil.save(image_data, format='PNG')
                    image_data.seek(0)
                    image_bytes = image_data.read()
                    req = oss_upload_image(oss_client=self.minio_client, bucket=AIDA_CLOTHING, object_name=f"mask/mask_{generate_uuid()}.png", image_bytes=image_bytes)
                    result['mask_url'] = req.bucket_name + "/" + req.object_name
                    result['back_image'] = None
                    result["back_image_url"] = None
                    # result["back_mask_url"] = None
                    # result['back_mask_image'] = None
                # Create the intermediate (pattern) layer
                result_pattern_image_rgba = rgb_to_rgba(result['pattern_image'], result['mask'])
                result_pattern_image_pil = Image.fromarray(cvtColor(result_pattern_image_rgba, COLOR_BGR2RGBA))
                result['pattern_image'], result['pattern_image_url'], _ = upload_png_mask(self.minio_client, result_pattern_image_pil, f'{generate_uuid()}')
            return result
        except Exception as e:
            # NOTE(review): on failure this logs and implicitly returns None;
            # downstream stages presumably tolerate a missing result — confirm.
            logging.warning(f"split runtime exception : {e} image_id : {result['image_id']}")
||||
31
app/service/design_fast/utils/conversion_image.py
Normal file
31
app/service/design_fast/utils/conversion_image.py
Normal file
@@ -0,0 +1,31 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: UTF-8 -*-
|
||||
"""
|
||||
@Project :trinity_client
|
||||
@File :conversion_image.py
|
||||
@Author :周成融
|
||||
@Date :2023/8/21 10:40:29
|
||||
@detail :
|
||||
"""
|
||||
import numpy as np
|
||||
|
||||
|
||||
# def rgb_to_rgba(rgb_size, rgb_image, mask):
|
||||
# alpha_channel = np.full(rgb_size, 255, dtype=np.uint8)
|
||||
# # 创建四通道的结果图像
|
||||
# rgba_image = np.dstack((rgb_image, alpha_channel))
|
||||
# alpha_channel = np.where(mask > 0, 255, 0)
|
||||
# # 更新RGBA图像的透明度通道
|
||||
# rgba_image[:, :, 3] = alpha_channel
|
||||
# return rgba_image
|
||||
|
||||
def rgb_to_rgba(rgb_image, mask):
    """Append an alpha channel derived from *mask*.

    Pixels where mask > 0 become fully opaque (255); all others fully
    transparent (0). Returns the stacked H x W x 4 array.
    """
    alpha = (mask > 0).astype(np.uint8) * 255
    return np.dstack((rgb_image, alpha))
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # NOTE(review): unfinished manual-test stub — open("") raises
    # FileNotFoundError if this module is ever run as a script.
    image = open("")
|
||||
143
app/service/design_fast/utils/design_ensemble.py
Normal file
143
app/service/design_fast/utils/design_ensemble.py
Normal file
@@ -0,0 +1,143 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: UTF-8 -*-
|
||||
"""
|
||||
@Project :trinity_client
|
||||
@File :design_ensemble.py
|
||||
@Author :周成融
|
||||
@Date :2023/8/16 19:36:21
|
||||
@detail :发起请求 获取推理结果
|
||||
"""
|
||||
import logging
|
||||
|
||||
import cv2
|
||||
import mmcv
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import tritonclient.http as httpclient
|
||||
|
||||
from app.core.config import *
|
||||
|
||||
"""
|
||||
keypoint
|
||||
预处理 推理 后处理
|
||||
"""
|
||||
|
||||
|
||||
def keypoint_preprocess(img_path):
    """Resize to 256x256, normalise, and lay out as NCHW for the keypoint model.

    Returns:
        (batch, (w_scale, h_scale)) — the 1xCxHxW input array and the scale
        factors mapping original image space into model space.
    """
    target_size = (256, 256)
    img = mmcv.imread(img_path)
    orig_h, orig_w = img.shape[:2]
    resized = cv2.resize(img, target_size)
    scale = (target_size[0] / orig_w, target_size[1] / orig_h)
    normalised = mmcv.imnormalize(
        resized,
        mean=np.array([123.675, 116.28, 103.53]),
        std=np.array([58.395, 57.12, 57.375]),
        to_rgb=True,
    )
    # HWC -> NCHW with a singleton batch dimension
    batch = normalised.transpose(2, 0, 1)[np.newaxis, ...]
    return batch, scale
|
||||
|
||||
|
||||
# @ RunTime
|
||||
# 推理
|
||||
def get_keypoint_result(image, site):
    """Run keypoint inference on *image* via the Triton HTTP endpoint.

    Args:
        image: anything accepted by keypoint_preprocess (mmcv.imread input).
        site: "up" or "down" — selects model keypoint_{site}_ocrnet_hr18.

    Returns:
        Post-processed keypoint coordinates, or None when any step failed
        (failures are logged, not raised).
    """
    keypoint_result = None
    try:
        image, scale_factor = keypoint_preprocess(image)
        client = httpclient.InferenceServerClient(url=DESIGN_MODEL_URL)
        transformed_img = image.astype(np.float32)
        # Tensor names are fixed strings (the original wrapped them in
        # placeholder-free f-strings).
        inputs = [httpclient.InferInput("input", transformed_img.shape, datatype="FP32")]
        inputs[0].set_data_from_numpy(transformed_img, binary_data=True)
        outputs = [httpclient.InferRequestedOutput("output", binary_data=True)]
        results = client.infer(model_name=f"keypoint_{site}_ocrnet_hr18", inputs=inputs, outputs=outputs)
        inference_output = torch.from_numpy(results.as_numpy('output'))
        keypoint_result = keypoint_postprocess(inference_output, scale_factor)
    except Exception as e:
        logging.warning(f"get_keypoint_result : {e}")
    return keypoint_result
|
||||
|
||||
|
||||
def keypoint_postprocess(output, scale_factor):
    """Convert per-keypoint heatmaps into coordinates in original-image space.

    Args:
        output: tensor of shape (batch, num_keypoints, H, W).
        scale_factor: (w_scale, h_scale) as returned by keypoint_preprocess.

    Returns:
        np.ndarray of shape (batch, num_keypoints, 2): argmax coordinates,
        rescaled by the inverse preprocessing scales, multiplied by 4
        (output stride) and rounded up.
    """
    flat = output.view(output.size(0), output.size(1), -1)
    max_indices = torch.argmax(flat, dim=2).unsqueeze(dim=2)
    # Row via division, column via modulo over the heatmap width.
    # NOTE(review): true division (not //) is kept from the original, so the
    # row coordinate can be fractional before ceil — confirm intended.
    max_coords = torch.cat((max_indices / output.size(3), max_indices % output.size(3)), dim=2)
    segment_result = max_coords.numpy()
    # Invert the preprocessing scales; zero out infinities from degenerate
    # scales (the original stored this mask under the misleading name `nan`).
    scale_matrix = np.diag([1 / x for x in scale_factor[::-1]])
    inf_mask = np.isinf(scale_matrix)
    scale_matrix[inf_mask] = 0
    return np.ceil(np.dot(segment_result, scale_matrix) * 4)
|
||||
|
||||
|
||||
"""
|
||||
seg
|
||||
预处理 推理 后处理
|
||||
"""
|
||||
|
||||
|
||||
# KNet
|
||||
# KNet
def seg_preprocess(img_path):
    """Clamp each image side to 1024, normalise, and lay out as NCHW.

    Returns:
        (batch, ori_shape) — the 1xCxHxW model input and the (h, w) of the
        image BEFORE resizing, so predictions can be mapped back.
    """
    img = mmcv.imread(img_path)
    ori_shape = img.shape[:2]
    # Cap height and width independently at 1024
    target_h = min(ori_shape[0], 1024)
    target_w = min(ori_shape[1], 1024)
    if ori_shape != (target_h, target_w):
        # cv2.resize expects (width, height) — the old mmcv call had them
        # swapped, which is why cv2 is used here
        img = cv2.resize(img, (target_w, target_h))
    img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
    batch = img.transpose(2, 0, 1)[np.newaxis, ...]
    return batch, ori_shape
|
||||
|
||||
|
||||
# @ RunTime
|
||||
# @ RunTime
def get_seg_result(image_id, image):
    """Run semantic segmentation via Triton and return resized logits.

    Args:
        image_id: identifier forwarded (as int) to seg_postprocess.
        image: input accepted by seg_preprocess.

    Returns:
        np.ndarray of per-class logits resized to the original image shape.
    """
    batch, ori_shape = seg_preprocess(image)
    client = httpclient.InferenceServerClient(url=f"{DESIGN_MODEL_URL}")
    transformed_img = batch.astype(np.float32)
    # Single model input
    infer_input = httpclient.InferInput(SEGMENTATION['input'], transformed_img.shape, datatype="FP32")
    infer_input.set_data_from_numpy(transformed_img, binary_data=True)
    # Single requested output
    infer_output = httpclient.InferRequestedOutput(SEGMENTATION['output'], binary_data=True)
    # Inference + result extraction
    results = client.infer(model_name=SEGMENTATION['new_model_name'], inputs=[infer_input], outputs=[infer_output])
    raw_output = results.as_numpy(SEGMENTATION['output'])
    return seg_postprocess(int(image_id), raw_output, ori_shape)
|
||||
|
||||
|
||||
# no cache
|
||||
# no cache
def seg_postprocess(image_id, output, ori_shape):
    """Bilinearly resize raw logits back to the original (h, w).

    *image_id* is accepted for interface compatibility but is unused here.
    Returns a numpy array of shape (num_classes, h, w).
    """
    logits = torch.tensor(output).float()
    resized = F.interpolate(logits, size=ori_shape, mode='bilinear', align_corners=False)
    return resized.cpu().numpy()[0]
|
||||
|
||||
|
||||
def key_point_show(image_path, key_point_result=None):
    """Debug helper: draw keypoints on an image and show it in a window.

    Args:
        image_path: path readable by cv2.imread.
        key_point_result: iterable of (y, x) points drawn as red dots;
            omitted/None means "draw nothing".
    """
    img = cv2.imread(image_path)
    # Guard: the original iterated None directly when no keypoints were
    # supplied, raising TypeError; treat that case as an empty list.
    points_list = key_point_result or []
    point_size = 1
    point_color = (0, 0, 255)  # BGR red
    thickness = 4  # may be 0, 4 or 8
    for point in points_list:
        # points are (y, x); cv2.circle expects (x, y), hence the reversal
        cv2.circle(img, point[::-1], point_size, point_color, thickness)
    cv2.imshow("0", img)
    cv2.waitKey(0)
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Manual smoke test: run keypoint inference on a sample sketch and
    # visualise the detected points.
    image = cv2.imread("9070101c-e5be-49b5-9602-4113a968969b.png")
    a = get_keypoint_result(image, "up")
    new_list = []
    # NOTE(review): this prints the builtin `list` type — presumably meant
    # print(new_list) or print(a); left unchanged.
    print(list)
    for i in a[0]:
        new_list.append((int(i[0]), int(i[1])))
    key_point_show("9070101c-e5be-49b5-9602-4113a968969b.png", new_list)
    # a = get_seg_result(1, image)
    print(a)
|
||||
77
app/service/design_fast/utils/organize.py
Normal file
77
app/service/design_fast/utils/organize.py
Normal file
@@ -0,0 +1,77 @@
|
||||
import cv2
|
||||
|
||||
from app.core.config import PRIORITY_DICT
|
||||
|
||||
|
||||
def organize_body(layer):
    """Build the compositing-layer record for the body/mannequin image.

    Returns:
        dict with the keys downstream compositing expects; priority 0 keeps
        the body at the bottom of the stack.
    """
    return {
        'priority': 0,
        'name': layer["name"].lower(),
        'image': layer['body_image'],
        'image_url': layer['body_path'],
        'mask_image': None,
        'mask_url': None,
        'sacle': 1,  # sic — key name kept for compatibility with consumers
        'position': (0, 0),
    }
|
||||
|
||||
|
||||
def organize_clothing(layer):
    """Build the front- and back-piece compositing records for a garment.

    Returns:
        (front_layer, back_layer) dicts sharing the same paste origin,
        mask URL and resize scale; priorities come from the layer itself
        when layer_order is set, otherwise from PRIORITY_DICT.
    """
    # Paste origin that aligns the garment keypoint with the body keypoint
    start_point = calculate_start_point(layer['keypoint'], layer['scale'], layer['clothes_keypoint'], layer['body_point_test'], layer["offset"], layer["resize_scale"])
    # Front-piece record ('sacle' key name kept for downstream compatibility)
    front_layer = dict(priority=layer['priority'] if layer.get("layer_order", False) else PRIORITY_DICT.get(f'{layer["name"].lower()}_front', None),
                       name=f'{layer["name"].lower()}_front',
                       image=layer["front_image"],
                       # mask_image=layer['front_mask_image'],
                       image_url=layer['front_image_url'],
                       mask_url=layer['mask_url'],
                       sacle=layer['scale'],
                       clothes_keypoint=layer['clothes_keypoint'],
                       position=start_point,
                       resize_scale=layer["resize_scale"],
                       mask=cv2.resize(layer['mask'], layer["front_image"].size),
                       gradient_string=layer['gradient_string'] if 'gradient_string' in layer.keys() else "",
                       pattern_image_url=layer['pattern_image_url'],
                       pattern_image=layer['pattern_image']

                       )
    # Back-piece record: priority is negated so the back sorts behind the
    # body when an explicit layer_order is in effect.
    # NOTE(review): front reads layer['priority'] directly while back uses
    # layer.get("priority", 0) — presumably intentional; confirm.
    back_layer = dict(priority=-layer.get("priority", 0) if layer.get("layer_order", False) else PRIORITY_DICT.get(f'{layer["name"].lower()}_back', None),
                      name=f'{layer["name"].lower()}_back',
                      image=layer["back_image"],
                      # mask_image=layer['back_mask_image'],
                      image_url=layer['back_image_url'],
                      mask_url=layer['mask_url'],
                      sacle=layer['scale'],
                      clothes_keypoint=layer['clothes_keypoint'],
                      position=start_point,
                      resize_scale=layer["resize_scale"],
                      mask=cv2.resize(layer['mask'], layer["front_image"].size),
                      gradient_string=layer['gradient_string'] if 'gradient_string' in layer.keys() else "",
                      pattern_image_url=layer['pattern_image_url'],
                      )
    return front_layer, back_layer
|
||||
|
||||
|
||||
def calculate_start_point(keypoint_type, scale, clothes_point, body_point, offset, resize_scale):
    """Compute the (y, x) paste origin aligning a garment's left keypoint
    with the matching keypoint on the body figure.

    Args:
        keypoint_type: "waistband" | "shoulder" | "ear_point".
        scale: garment scaling factor.
        clothes_point: maps '<type>_left' to an (x, y, ...) sequence.
        body_point: maps '<type>_left' to an (x, y) sequence.
        offset: (x_offset, y_offset) extra shift.
        resize_scale: accepted for interface compatibility; unused here.

    Returns:
        (y, x) tuple of ints.
    """
    key = f'{keypoint_type}_left'
    body_x, body_y = body_point[key][0], body_point[key][1]
    cloth_x, cloth_y = int(clothes_point[key][0]), int(clothes_point[key][1])
    # Note the axis swap: the garment's x feeds the y origin and vice versa
    y_start = int(body_y + offset[1] - cloth_x * scale)
    x_start = int(body_x + offset[0] - cloth_y * scale)
    return y_start, x_start
|
||||
30
app/service/design_fast/utils/progress.py
Normal file
30
app/service/design_fast/utils/progress.py
Normal file
@@ -0,0 +1,30 @@
|
||||
import logging
|
||||
|
||||
from app.service.design_fast.utils.redis_utils import Redis
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def update_progress(process_id, total):
    """Advance the redis-backed progress counter for *process_id*.

    Each call adds roughly 100/total percent; a single-step task
    (total == 1) jumps straight to 100.

    Returns:
        The progress value as it was BEFORE this update (str or None).
    """
    # logger.info(f"{process_id} , {total}")
    r = Redis()
    progress = r.read(key=process_id)
    if progress and total != 1:
        if int(progress) <= 100:
            # NOTE(review): this can push the stored value past 100 when
            # progress is already near the cap — confirm acceptable.
            r.write(key=process_id, value=int(progress) + int(100 / total))
        else:
            # Clamp runaway values just below completion
            r.write(key=process_id, value=99)
        return progress
    elif total == 1:
        r.write(key=process_id, value=100)
        return progress
    else:
        # First update for this process_id: seed the counter.
        # NOTE(review): total == 0 would raise ZeroDivisionError here.
        r.write(key=process_id, value=int(100 / total))
        return progress
|
||||
|
||||
|
||||
def final_progress(process_id):
    """Force the progress for *process_id* to 100.

    Returns:
        The progress value that was stored before the write (str or None).
    """
    redis_client = Redis()
    previous = redis_client.read(key=process_id)
    redis_client.write(key=process_id, value=100)
    return previous
|
||||
99
app/service/design_fast/utils/redis_utils.py
Normal file
99
app/service/design_fast/utils/redis_utils.py
Normal file
@@ -0,0 +1,99 @@
|
||||
import redis
|
||||
|
||||
from app.core.config import REDIS_HOST, REDIS_PORT
|
||||
|
||||
|
||||
class Redis(object):
    """Thin class-method wrapper around a StrictRedis connection.

    Values written via `write` default to a 100-second TTL; reads decode
    bytes to UTF-8 strings (hgetall returns the raw bytes mapping).
    """

    @staticmethod
    def _get_r():
        # A fresh connection object per call; pooling is left to redis-py
        host = REDIS_HOST
        port = REDIS_PORT
        db = 0
        r = redis.StrictRedis(host, port, db)
        return r

    @classmethod
    def write(cls, key, value, expire=None):
        """Set *key* to *value* with a TTL (default 100 seconds)."""
        # Fall back to the default TTL when none is supplied
        if expire:
            expire_in_seconds = expire
        else:
            expire_in_seconds = 100
        r = cls._get_r()
        r.set(key, value, ex=expire_in_seconds)

    @classmethod
    def read(cls, key):
        """Return the value of *key* decoded to str, or None if absent."""
        r = cls._get_r()
        value = r.get(key)
        return value.decode('utf-8') if value else value

    @classmethod
    def hset(cls, name, key, value):
        """Set *key* in hash *name*."""
        r = cls._get_r()
        r.hset(name, key, value)

    @classmethod
    def hget(cls, name, key):
        """Return *key* from hash *name* decoded to str, or None if absent."""
        r = cls._get_r()
        value = r.hget(name, key)
        return value.decode('utf-8') if value else value

    @classmethod
    def hgetall(cls, name):
        """Return the entire hash *name* (raw bytes keys/values)."""
        r = cls._get_r()
        return r.hgetall(name)

    @classmethod
    def delete(cls, *names):
        """Delete one or more keys."""
        r = cls._get_r()
        r.delete(*names)

    @classmethod
    def hdel(cls, name, key):
        """Delete *key* from hash *name*."""
        r = cls._get_r()
        r.hdel(name, key)

    @classmethod
    def expire(cls, name, expire=None):
        """Set the TTL of key *name* (default 100 seconds)."""
        if expire:
            expire_in_seconds = expire
        else:
            expire_in_seconds = 100
        r = cls._get_r()
        r.expire(name, expire_in_seconds)
||||
|
||||
|
||||
if __name__ == '__main__':
    # Manual smoke test against a live redis instance
    redis_client = Redis()
    # print(redis_client.write(key="1230", value=0))
    redis_client.write(key="1230", value=10)
    # print(redis_client.read(key="1230"))
|
||||
199
app/service/design_fast/utils/synthesis_item.py
Normal file
199
app/service/design_fast/utils/synthesis_item.py
Normal file
@@ -0,0 +1,199 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: UTF-8 -*-
|
||||
"""
|
||||
@Project :trinity_client
|
||||
@File :synthesis_item.py
|
||||
@Author :周成融
|
||||
@Date :2023/8/26 14:13:04
|
||||
@detail :
|
||||
"""
|
||||
import io
|
||||
import logging
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from app.service.utils.generate_uuid import generate_uuid
|
||||
from app.service.utils.oss_client import oss_upload_image
|
||||
|
||||
|
||||
def positioning(all_mask_shape, mask_shape, offset):
    """Clip one axis of a pasted mask against the canvas.

    Given the canvas extent *all_mask_shape*, the mask extent *mask_shape*
    and the paste *offset* along that axis, return
    (all_start, all_end, mask_start, mask_end) such that
    canvas[all_start:all_end] lines up with mask[mask_start:mask_end].
    """
    if offset == 0:
        # Fully aligned: the overlap is simply the smaller extent
        overlap = min(all_mask_shape, mask_shape)
        return 0, overlap, 0, overlap
    if offset > 0:
        # Mask shifted right/down: clamp both ends to the canvas
        all_start = min(offset, all_mask_shape)
        all_end = min(offset + mask_shape, all_mask_shape)
        mask_end = 0 if offset > all_mask_shape else min(all_mask_shape - offset, mask_shape)
        return all_start, all_end, 0, mask_end
    # offset < 0: the mask starts above/left of the canvas edge
    shift = abs(offset)
    if shift > mask_shape:
        # Mask lies entirely outside the canvas: every slice is empty
        return 0, 0, mask_shape, mask_shape
    visible = mask_shape - shift
    all_end = min(visible, all_mask_shape)
    mask_end = all_mask_shape + shift if visible >= all_mask_shape else mask_shape
    return 0, all_end, shift, mask_end
|
||||
|
||||
|
||||
# @RunTime
|
||||
# @RunTime
def synthesis(data, size, basic_info):
    """Composite all layer images onto one RGBA canvas and upload the PNG.

    Args:
        data: list of layer dicts (body + garment pieces) carrying 'name',
            'image', 'mask' and 'adaptive_position' (y, x).
        size: (width, height) of the output canvas.
        basic_info: dict with body keypoints under 'body_point_test'.

    Returns:
        "bucket/object" OSS path of the uploaded PNG, or None when any step
        fails (the exception is only logged).
    """
    # Transparent base canvas the layers are composited onto
    base_image = Image.new('RGBA', size, (0, 0, 0, 0))
    try:
        all_mask_shape = (size[1], size[0])
        body_mask = None
        for d in data:
            if d['name'] == 'body' or d['name'] == 'mannequin':
                # Paste the model onto a fresh transparent canvas to derive
                # its alpha mask
                transparent_image = Image.new("RGBA", size, (0, 0, 0, 0))
                transparent_image.paste(d['image'], (d['adaptive_position'][1], d['adaptive_position'][0]), d['image'])  # paste mutates a mutable position arg, so index the tuple directly
                body_mask = np.array(transparent_image.split()[3])

                # Translate shoulder keypoints into canvas coordinates and
                # fill everything above the shoulders into the mask
                left_shoulder = [x + y for x, y in zip(basic_info['body_point_test']['shoulder_left'], [d['adaptive_position'][1], d['adaptive_position'][0]])]
                right_shoulder = [x + y for x, y in zip(basic_info['body_point_test']['shoulder_right'], [d['adaptive_position'][1], d['adaptive_position'][0]])]
                body_mask[:min(left_shoulder[1], right_shoulder[1]), left_shoulder[0]:right_shoulder[0]] = 255
        _, binary_body_mask = cv2.threshold(body_mask, 127, 255, cv2.THRESH_BINARY)
        top_outer_mask = np.array(binary_body_mask)
        bottom_outer_mask = np.array(binary_body_mask)

        # Scan from the top of the stack for the outermost top- and
        # bottom-garment front pieces and stamp their masks into the outline
        top = True
        bottom = True
        i = len(data)
        while i:
            i -= 1
            if top and data[i]['name'] in ["blouse_front", "outwear_front", "dress_front", "tops_front"]:
                top = False
                mask_shape = data[i]['mask'].shape
                y_offset, x_offset = data[i]['adaptive_position']
                # Clip the garment mask against the canvas on each axis
                all_y_start, all_y_end, mask_y_start, mask_y_end = positioning(all_mask_shape=all_mask_shape[0], mask_shape=mask_shape[0], offset=y_offset)
                all_x_start, all_x_end, mask_x_start, mask_x_end = positioning(all_mask_shape=all_mask_shape[1], mask_shape=mask_shape[1], offset=x_offset)
                # Stamp the binarised garment mask into a canvas-sized buffer
                _, sketch_mask = cv2.threshold(data[i]['mask'], 127, 255, cv2.THRESH_BINARY)
                background = np.zeros_like(top_outer_mask)
                background[all_y_start:all_y_end, all_x_start:all_x_end] = sketch_mask[mask_y_start:mask_y_end, mask_x_start:mask_x_end]
                top_outer_mask = background + top_outer_mask
            elif bottom and data[i]['name'] in ["trousers_front", "skirt_front", "bottoms_front", "dress_front"]:
                bottom = False
                mask_shape = data[i]['mask'].shape
                y_offset, x_offset = data[i]['adaptive_position']
                # Same clipping/stamping for the lower-body garment
                all_y_start, all_y_end, mask_y_start, mask_y_end = positioning(all_mask_shape=all_mask_shape[0], mask_shape=mask_shape[0], offset=y_offset)
                all_x_start, all_x_end, mask_x_start, mask_x_end = positioning(all_mask_shape=all_mask_shape[1], mask_shape=mask_shape[1], offset=x_offset)
                _, sketch_mask = cv2.threshold(data[i]['mask'], 127, 255, cv2.THRESH_BINARY)
                background = np.zeros_like(top_outer_mask)
                background[all_y_start:all_y_end, all_x_start:all_x_end] = sketch_mask[mask_y_start:mask_y_end, mask_x_start:mask_x_end]
                bottom_outer_mask = background + bottom_outer_mask
            elif bottom is False and top is False:
                break

        all_mask = cv2.bitwise_or(top_outer_mask, bottom_outer_mask)

        for layer in data:
            if layer['image'] is not None:
                if layer['name'] != "body":
                    # Crop each garment layer to the combined outline before
                    # compositing it onto the base
                    test_image = Image.new('RGBA', size, (0, 0, 0, 0))
                    test_image.paste(layer['image'], (layer['adaptive_position'][1], layer['adaptive_position'][0]), layer['image'])
                    mask_data = np.where(all_mask > 0, 255, 0).astype(np.uint8)
                    mask_alpha = Image.fromarray(mask_data)
                    cropped_image = Image.composite(test_image, Image.new("RGBA", test_image.size, (255, 255, 255, 0)), mask_alpha)
                    base_image.paste(test_image, (0, 0), cropped_image)  # test_image is already canvas-sized, so paste at (0, 0)
                else:
                    base_image.paste(layer['image'], (layer['adaptive_position'][1], layer['adaptive_position'][0]), layer['image'])

        result_image = base_image

        image_data = io.BytesIO()
        result_image.save(image_data, format='PNG')
        image_data.seek(0)

        # oss upload
        image_bytes = image_data.read()
        bucket_name = "aida-results"
        object_name = f'result_{generate_uuid()}.png'
        req = oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes)
        return f"{bucket_name}/{object_name}"

    except Exception as e:
        # NOTE(review): logs and implicitly returns None on any failure
        logging.warning(f"synthesis runtime exception : {e}")
|
||||
|
||||
|
||||
def synthesis_single(front_image, back_image):
    """Stack *back_image* over *front_image* and upload the PNG to OSS.

    NOTE(review): if front_image is falsy, result_image stays None and the
    subsequent .paste/.save calls raise AttributeError — callers presumably
    always supply a front image; confirm.

    Returns:
        "bucket/object" OSS path of the uploaded composite PNG.
    """
    result_image = None
    if front_image:
        result_image = front_image
    if back_image:
        # Paste in place, using the back image's own alpha as the mask
        result_image.paste(back_image, (0, 0), back_image)

    image_data = io.BytesIO()
    result_image.save(image_data, format='PNG')
    image_data.seek(0)
    image_bytes = image_data.read()
    # oss upload
    bucket_name = 'aida-results'
    object_name = f'result_{generate_uuid()}.png'
    req = oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes)
    return f"{bucket_name}/{object_name}"
|
||||
|
||||
|
||||
def update_base_size_priority(layers, size):
    """Shift all layers so the leftmost x becomes 0 and compute the canvas size.

    *size* is accepted for interface compatibility but unused. The height
    defaults to 700 unless a 'mannequin' layer supplies its own image height.
    Mutates each layer in place by adding an 'adaptive_position' key.

    Returns:
        (layers, (new_width, new_height))
    """
    # Leftmost x across ALL layers, including those without an image
    min_x = min(info['position'][1] for info in layers)
    new_height = 700
    right_edges = []
    for info in layers:
        if info['image'] is None:
            continue
        right_edges.append(info['position'][1] + info['image'].width)
        if info['name'] == 'mannequin':
            new_height = info['image'].height
    new_width = max(right_edges) - min_x
    # Re-anchor every layer relative to the new left edge
    for info in layers:
        y, x = info['position'][0], info['position'][1]
        info['adaptive_position'] = (y, x - min_x)
    return layers, (new_width, new_height)
|
||||
39
app/service/design_fast/utils/upload_image.py
Normal file
39
app/service/design_fast/utils/upload_image.py
Normal file
@@ -0,0 +1,39 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: UTF-8 -*-
|
||||
"""
|
||||
@Project :trinity_client
|
||||
@File :upload_image.py
|
||||
@Author :周成融
|
||||
@Date :2023/8/28 13:49:20
|
||||
@detail :
|
||||
"""
|
||||
import io
|
||||
import logging
|
||||
|
||||
import cv2
|
||||
|
||||
from app.core.config import *
|
||||
from app.service.utils.new_oss_client import oss_upload_image
|
||||
|
||||
|
||||
# @RunTime
|
||||
def upload_png_mask(minio_client, front_image, object_name, mask=None):
    """Upload a PIL image (and optionally its inverted mask) to OSS as PNGs.

    Args:
        minio_client: OSS/minio client handle.
        front_image: PIL.Image uploaded under image/image_<object_name>.png.
        object_name: unique suffix used for both object keys.
        mask: optional cv2 mask; inverted, made transparent where black, and
            uploaded under mask/mask_<object_name>.png.

    Returns:
        (front_image, image_url, mask_url) on success; None when an
        exception was caught (it is only logged).
    """
    try:
        mask_url = None
        if mask is not None:
            mask_inverted = cv2.bitwise_not(mask)
            # Convert the 3-channel mask to 4 channels: white stays opaque,
            # black becomes fully transparent
            rgba_image = cv2.cvtColor(mask_inverted, cv2.COLOR_BGR2BGRA)
            rgba_image[rgba_image[:, :, 0] == 0] = [0, 0, 0, 0]
            req = oss_upload_image(oss_client=minio_client, bucket=AIDA_CLOTHING, object_name=f"mask/mask_{object_name}.png", image_bytes=cv2.imencode('.png', rgba_image)[1])
            mask_url = f"{AIDA_CLOTHING}/mask/mask_{object_name}.png"

        image_data = io.BytesIO()
        front_image.save(image_data, format='PNG')
        image_data.seek(0)
        image_bytes = image_data.read()
        req = oss_upload_image(oss_client=minio_client, bucket=AIDA_CLOTHING, object_name=f"image/image_{object_name}.png", image_bytes=image_bytes)
        image_url = f"{AIDA_CLOTHING}/image/image_{object_name}.png"
        return front_image, image_url, mask_url
    except Exception as e:
        # NOTE(review): swallows the error and implicitly returns None
        logging.warning(f"upload_png_mask runtime exception : {e}")
|
||||
@@ -5,13 +5,16 @@ import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
import tritonclient.grpc as grpcclient
|
||||
from pymilvus import MilvusClient
|
||||
from urllib3.exceptions import ResponseError
|
||||
|
||||
from app.core.config import *
|
||||
from app.schemas.pre_processing import DesignPreProcessingModel
|
||||
from app.service.design.utils.design_ensemble import get_keypoint_result
|
||||
from app.service.design_fast.utils.design_ensemble import get_seg_result, get_keypoint_result
|
||||
from app.service.utils.oss_client import oss_get_image, oss_upload_image
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
class DesignPreprocessing:
|
||||
# def __init__(self):
|
||||
@@ -20,19 +23,19 @@ class DesignPreprocessing:
|
||||
# @ RunTime
|
||||
def pipeline(self, image_list):
|
||||
sketches_list = self.read_image(image_list)
|
||||
logging.info("read image success")
|
||||
# logging.info("read image success")
|
||||
|
||||
bounding_box_sketches_list = self.bounding_box(sketches_list)
|
||||
logging.info("bounding box image success")
|
||||
# logging.info("bounding box image success")
|
||||
|
||||
# super_resolution_list = self.super_resolution(bounding_box_sketches_list)
|
||||
# logging.info("super_resolution_list image success")
|
||||
|
||||
infer_sketches_list = self.infer_image(bounding_box_sketches_list)
|
||||
logging.info("infer image success")
|
||||
# logging.info("infer image success")
|
||||
|
||||
result = self.composing_image(infer_sketches_list)
|
||||
logging.info("Replenish white edge image success")
|
||||
# logging.info("Replenish white edge image success")
|
||||
|
||||
for d in result:
|
||||
if 'image_obj' in d:
|
||||
@@ -59,6 +62,7 @@ class DesignPreprocessing:
|
||||
def bounding_box(self, image_list):
|
||||
for item in image_list:
|
||||
image = item['image_obj']
|
||||
height, width = image.shape[:2]
|
||||
# 使用Canny边缘检测来检测物体的轮廓
|
||||
edges = cv2.Canny(image, 50, 150)
|
||||
# 查找轮廓
|
||||
@@ -82,16 +86,25 @@ class DesignPreprocessing:
|
||||
if len(contours) > 0:
|
||||
cropped_image = image[y_min:y_max, x_min:x_max]
|
||||
item['obj'] = cropped_image # 新shape图像
|
||||
# 取消直接覆盖,新增size判断
|
||||
# try:
|
||||
# # 覆盖到minio
|
||||
# image_bytes = cv2.imencode(".jpg", cropped_image)[1].tobytes()
|
||||
# self.minio_client.put_object(item['image_url'].split("/", 1)[0], item['image_url'].split("/", 1)[1], io.BytesIO(image_bytes), len(image_bytes), content_type="image/jpeg", )
|
||||
# print(f"Object '{item['image_url'].split('/', 1)[1]}' overwritten successfully.")
|
||||
# except ResponseError as err:
|
||||
# print(f"Error: {err}")
|
||||
else:
|
||||
item['obj'] = image
|
||||
|
||||
padding_top = max(20 - y_min, 0)
|
||||
padding_bottom = max(20 - (height - y_max), 0)
|
||||
padding_left = max(20 - x_min, 0)
|
||||
padding_right = max(20 - (width - x_max), 0)
|
||||
|
||||
# 添加padding
|
||||
padded_image = cv2.copyMakeBorder(
|
||||
image,
|
||||
padding_top,
|
||||
padding_bottom,
|
||||
padding_left,
|
||||
padding_right,
|
||||
cv2.BORDER_CONSTANT,
|
||||
value=(255, 255, 255)
|
||||
)
|
||||
item['obj'] = padded_image
|
||||
return image_list
|
||||
|
||||
def super_resolution(self, image_list):
|
||||
@@ -99,7 +112,7 @@ class DesignPreprocessing:
|
||||
# 判断 两边是否同时都小于512 因为此处做四倍超分
|
||||
if item['obj'].shape[0] <= 512 and item['obj'].shape[1] <= 512:
|
||||
# 如果任意一边小于256则超分
|
||||
if item['obj'].shape[0] <= 256 or item['obj'].shape[1] <= 256:
|
||||
if item['obj'].shape[0] <= 200 or item['obj'].shape[1] <= 200:
|
||||
# 超分
|
||||
img = item['obj'].astype(np.float32) / 255.
|
||||
sample = np.transpose(img if img.shape[2] == 1 else img[:, :, [2, 1, 0]], (2, 0, 1))
|
||||
@@ -124,13 +137,14 @@ class DesignPreprocessing:
|
||||
bucket_name = item['image_url'].split("/", 1)[0]
|
||||
object_name = item['image_url'].split("/", 1)[1]
|
||||
oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes)
|
||||
print(f"Object '{item['image_url'].split('/', 1)[1]}' overwritten successfully.")
|
||||
logging.info(f"Object '{item['image_url'].split('/', 1)[1]}' overwritten successfully.")
|
||||
except ResponseError as err:
|
||||
print(f"Error: {err}")
|
||||
logging.warning(f"Error: {err}")
|
||||
return image_list
|
||||
|
||||
# @ RunTime
|
||||
def infer_image(self, image_list):
|
||||
seg_result = None
|
||||
for sketch in image_list:
|
||||
# 小写
|
||||
image_category = sketch['image_category'].lower()
|
||||
@@ -138,6 +152,15 @@ class DesignPreprocessing:
|
||||
sketch['site'] = 'up' if image_category in ['blouse', 'outwear', 'dress', 'tops'] else 'down'
|
||||
# 推理得到keypoint
|
||||
sketch['keypoint_result'] = self.keypoint_cache(sketch)
|
||||
if sketch['site'] == 'up':
|
||||
_, seg_cache = self.load_seg_result(sketch['image_id'])
|
||||
if not _:
|
||||
# 推理获得seg 结果
|
||||
seg_result = get_seg_result(sketch["image_id"], sketch['obj'])[0]
|
||||
self.save_seg_result(seg_result, sketch['image_id'])
|
||||
logger.info(f"{sketch['image_id']} image size is :{sketch['obj'].shape} , seg cache size is :{seg_result.shape}")
|
||||
else:
|
||||
logger.info(f"{sketch['image_id']} image size is :{sketch['obj'].shape} , seg cache size is :{seg_cache.shape}")
|
||||
|
||||
if IF_DEBUG_SHOW:
|
||||
debug_show_image = sketch['obj'].copy()
|
||||
@@ -149,6 +172,7 @@ class DesignPreprocessing:
|
||||
points_list.append((int(i[1]), int(i[0])))
|
||||
for point in points_list:
|
||||
cv2.circle(debug_show_image, point, point_size, point_color, thickness)
|
||||
cv2.imshow("seg_result", seg_result)
|
||||
cv2.imshow("", debug_show_image)
|
||||
cv2.waitKey(0)
|
||||
# # 关键点在上部则推理seg
|
||||
@@ -236,58 +260,37 @@ class DesignPreprocessing:
|
||||
return image_list
|
||||
|
||||
@staticmethod
|
||||
def select_seg_result(image_id, image_obj):
|
||||
def load_seg_result(image_id):
|
||||
file_path = f"{SEG_CACHE_PATH}{image_id}.npy"
|
||||
try:
|
||||
# 如果shape不匹配 返回false
|
||||
result = np.load(f"seg_result/{image_id}.npy").astype(np.int64)
|
||||
if result.shape[1] == image_obj.shape[0] and result.shape[2] == image_obj.shape[1]:
|
||||
return result
|
||||
else:
|
||||
return False
|
||||
except FileNotFoundError as e:
|
||||
logging.warning(f"{image_id} Image segmentation results cache file does not exist : {e}")
|
||||
return False
|
||||
seg_result = np.load(file_path)
|
||||
return True, seg_result
|
||||
except FileNotFoundError:
|
||||
logging.info("文件不存在")
|
||||
return False, None
|
||||
except Exception as e:
|
||||
logging.warning(f"加载失败: {e}")
|
||||
return False, None
|
||||
|
||||
@staticmethod
|
||||
def search_seg_result(image_id, ori_shape):
|
||||
def save_seg_result(seg_result, image_id):
|
||||
file_path = f"{SEG_CACHE_PATH}{image_id}.npy"
|
||||
try:
|
||||
# connections.connect(alias=MILVUS_ALIAS, host=MILVUS_DB_HOST, port=MILVUS_PORT)
|
||||
# collection = Collection(MILVUS_TABLE_SEG) # Get an existing collection.
|
||||
# collection.load()
|
||||
# start_time = time.time()
|
||||
# res = collection.query(
|
||||
# expr=f"seg_id == {image_id}",
|
||||
# offset=0,
|
||||
# limit=10,
|
||||
# output_fields=["seg_cache"],
|
||||
# )
|
||||
# logging.info(f"search seg cache time : {time.time() - start_time}")
|
||||
|
||||
# if len(res):
|
||||
# vector = np.reshape(res[0]['seg_cache'] + res[1]['seg_cache'], (224, 224))
|
||||
# array_2d_exact = F.interpolate(torch.tensor(vector).unsqueeze(0).unsqueeze(0), size=ori_shape, mode='bilinear', align_corners=False)
|
||||
# array_2d_exact = array_2d_exact.squeeze().numpy()
|
||||
# return array_2d_exact
|
||||
# else:
|
||||
return False
|
||||
np.save(file_path, seg_result)
|
||||
logging.info(f"保存成功,{os.path.abspath(file_path)}")
|
||||
except Exception as e:
|
||||
logging.warning(f"{image_id} Image segmentation results cache file does not exist : {e}")
|
||||
return False
|
||||
logging.warning(f"保存失败: {e}")
|
||||
|
||||
def keypoint_cache(self, sketch):
|
||||
try:
|
||||
# connections.connect(alias=MILVUS_ALIAS, host=MILVUS_DB_HOST, port=MILVUS_PORT)
|
||||
# collection = Collection(MILVUS_TABLE_KEYPOINT) # Get an existing collection.
|
||||
# collection.load()
|
||||
start_time = time.time()
|
||||
# res = collection.query(
|
||||
# expr=f"keypoint_id == {sketch['image_id']}",
|
||||
# offset=0,
|
||||
# limit=1,
|
||||
# output_fields=["keypoint_cache", "keypoint_site"],
|
||||
# )
|
||||
res = []
|
||||
logging.info(f"search keypoint time : {time.time() - start_time}")
|
||||
client = MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS)
|
||||
keypoint_id = sketch['image_id']
|
||||
res = client.query(
|
||||
collection_name=MILVUS_TABLE_KEYPOINT,
|
||||
# ids=[keypoint_id],
|
||||
filter=f"keypoint_id == {keypoint_id}",
|
||||
output_fields=['keypoint_vector', 'keypoint_site']
|
||||
)
|
||||
if len(res) == 0:
|
||||
# 没有结果 直接推理拿结果 并保存
|
||||
keypoint_infer_result = self.infer_keypoint_result(sketch)
|
||||
@@ -348,7 +351,7 @@ class DesignPreprocessing:
|
||||
]
|
||||
try:
|
||||
# connections.connect(alias=MILVUS_ALIAS, host=MILVUS_DB_HOST, port=MILVUS_PORT)
|
||||
start_time = time.time()
|
||||
# start_time = time.time()
|
||||
# collection = Collection(MILVUS_TABLE_KEYPOINT) # Get an existing collection.
|
||||
# mr = collection.upsert(data)
|
||||
# logging.info(f"save keypoint time : {time.time() - start_time}")
|
||||
@@ -362,9 +365,9 @@ if __name__ == '__main__':
|
||||
data = {
|
||||
"sketches": [
|
||||
{
|
||||
"image_category": "dress",
|
||||
"image_id": "107903",
|
||||
"image_url": "aida-sys-image/images/female/dress/0628000000.jpg"
|
||||
"image_category": "blouse",
|
||||
"image_id": "123123123",
|
||||
"image_url": "test/0628000198.jpg"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
45
app/service/image2sketch/checkpoints/download_checkpoints.py
Normal file
45
app/service/image2sketch/checkpoints/download_checkpoints.py
Normal file
@@ -0,0 +1,45 @@
|
||||
import os
|
||||
|
||||
from minio import Minio
|
||||
from minio.error import S3Error
|
||||
|
||||
# MinIO connection settings used to fetch the model checkpoints.
# SECURITY NOTE(review): access/secret keys are hard-coded in source control —
# rotate these credentials and load them from the environment or a secrets
# manager before this file is distributed.
MINIO_URL = "www.minio.aida.com.hk:12024"
MINIO_ACCESS = 'vXKFLSJkYeEq2DrSZvkB'
MINIO_SECRET = 'uKTZT3x7C43WvPN9QTc99DiRkwddWZrG9Uh3JVlR'
MINIO_SECURE = True
# Configure the MinIO client shared by the download helper below.
minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||
|
||||
|
||||
# Download helper
def download_folder(bucket_name, folder_name, local_dir):
    """Download every object under ``folder_name`` in ``bucket_name`` into ``local_dir``.

    The key structure below ``folder_name`` is mirrored on disk. S3 errors are
    caught and reported to stdout; the function does not raise.

    Parameters:
        bucket_name (str) -- MinIO bucket to read from
        folder_name (str) -- key prefix to download (e.g. "checkpoints/")
        local_dir (str)   -- local destination directory (created if missing)
    """
    try:
        # exist_ok=True replaces the racy "check, then create" pattern used before.
        os.makedirs(local_dir, exist_ok=True)

        # Walk every object under the prefix in MinIO.
        objects = minio_client.list_objects(bucket_name, prefix=folder_name, recursive=True)
        for obj in objects:
            # Strip the remote prefix so the local tree mirrors the remote one.
            local_file_path = os.path.join(local_dir, obj.object_name[len(folder_name):])
            os.makedirs(os.path.dirname(local_file_path), exist_ok=True)

            # Fetch the object to disk.
            minio_client.fget_object(bucket_name, obj.object_name, local_file_path)
            print(f"Downloaded {obj.object_name} to {local_file_path}")

    except S3Error as e:
        print(f"Error occurred: {e}")
|
||||
|
||||
|
||||
# Usage example.
# NOTE(review): this call runs at import time — merely importing this module
# triggers a full checkpoint download; consider guarding it with
# `if __name__ == "__main__":`.
bucket_name = "test"  # replace with your bucket name
folder_name = "checkpoints/"  # key prefix of the weight files
local_dir = "app/service/image2sketch/checkpoints"  # local directory to save into

download_folder(bucket_name, folder_name, local_dir)
|
||||
BIN
app/service/image2sketch/datasets/ref_unpair/testC/style_1.jpg
Normal file
BIN
app/service/image2sketch/datasets/ref_unpair/testC/style_1.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 101 KiB |
BIN
app/service/image2sketch/datasets/ref_unpair/testC/style_2.jpeg
Normal file
BIN
app/service/image2sketch/datasets/ref_unpair/testC/style_2.jpeg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 376 KiB |
BIN
app/service/image2sketch/datasets/ref_unpair/testC/style_3.png
Normal file
BIN
app/service/image2sketch/datasets/ref_unpair/testC/style_3.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 57 KiB |
89
app/service/image2sketch/infer.py
Normal file
89
app/service/image2sketch/infer.py
Normal file
@@ -0,0 +1,89 @@
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchvision.transforms as transforms
|
||||
from PIL import Image
|
||||
|
||||
from .models import create_model
|
||||
|
||||
|
||||
def tensor2im(input_image, imtype=np.uint8):
    """Convert a batched (1, C, H, W) image tensor in [-1, 1] to an (H, W, C) numpy array.

    Numpy arrays are passed through with only a cast to *imtype*; values that
    are neither tensors nor ndarrays are returned unchanged.
    """
    if isinstance(input_image, np.ndarray):
        return input_image.astype(imtype)
    if not isinstance(input_image, torch.Tensor):
        return input_image  # unknown type: hand it back untouched
    # Take the first image of the batch and move it to host memory.
    arr = input_image.data[0].cpu().float().numpy()
    if arr.shape[0] == 1:  # grayscale -> RGB by replicating the channel
        arr = np.tile(arr, (3, 1, 1))
    # CHW -> HWC, then rescale [-1, 1] -> [0, 255]
    arr = (np.transpose(arr, (1, 2, 0)) + 1) / 2.0 * 255.0
    return arr.astype(imtype)
|
||||
|
||||
|
||||
def save_image(image_numpy, image_path, w, h, aspect_ratio=1.0):
    """Write *image_numpy* to *image_path*, resized to (w, h).

    Parameters:
        image_numpy (numpy array) -- input image array
        image_path (str)          -- destination file path
        w, h (int)                -- output width / height in pixels
        aspect_ratio (float)      -- accepted for API compatibility; currently unused
    """
    pil_image = Image.fromarray(image_numpy).resize((w, h))
    pil_image.save(image_path)
|
||||
|
||||
|
||||
def save_img(image_tensor, w, h, filename):
    """Convert *image_tensor* to a numpy image and write it to *filename* at (w, h)."""
    numpy_image = tensor2im(image_tensor)
    save_image(numpy_image, filename, w, h, aspect_ratio=1.0)
    print("Image saved as {}".format(filename))
|
||||
|
||||
|
||||
def load_img(filepath):
    """Open *filepath* as a grayscale ('L') PIL image; return (image, width, height)."""
    img = Image.open(filepath).convert('L')
    w, h = img.size
    return img, w, h
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Demo / smoke-test entry point for the ref2sketch model.
    # NOTE(review): the paths below are hard-coded to two different machines
    # (a Linux /workspace path and a Windows E:\ path) — this only runs as-is
    # on the author's environment.
    img_A = "/workspace/Semi_ref2sketch_code/datasets/ref_unpair/testA/real_Dress_732caedc416a0cbfedd0e6528040eac7.jpg_Img.jpg"
    img_B = "/workspace/Semi_ref2sketch_code/datasets/ref_unpair/testC/style_3.png"
    from opt import Config

    opt = Config()  # get test options
    # hard-code some parameters for test
    opt.num_threads = 0  # test code only supports num_threads = 0
    opt.batch_size = 1  # test code only supports batch_size = 1
    opt.serial_batches = True  # disable data shuffling; comment this line if results on randomly chosen images are needed.
    opt.no_flip = True  # no flip; comment this line if results on flipped images are needed.
    opt.display_id = -1  # no visdom display; the test code saves the results to a HTML file.
    device = torch.device("cuda:0")
    model = create_model(opt)  # create a model given opt.model and other options
    model.setup(opt)  # load weights and (if training) build schedulers
    transform_list = [transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
    transform = transforms.Compose(transform_list)
    if opt.eval:
        model.eval()
    data = {}
    print(os.getcwd())
    # NOTE(review): chained assignment — `B` receives the whole (img, w, h)
    # tuple while `reference` gets the image; `B` is never used afterwards.
    B = reference, _, _ = load_img(r"/app/service/image2sketch/datasets/ref_unpair/testC/style_3.png")
    style_img = transform(reference)
    data['B'] = style_img
    data['B'] = data['B'].unsqueeze(0).to(device)
    # Content image; width/height are kept so the output can be restored to the original size.
    A = Image.open(r"E:\workspace\trinity_client_aida\app\service\image2sketch\datasets\ref_unpair\testA\real_Dress_3200fecdc83d0c556c2bd96aedbd7fbf.jpg_Img.jpg")
    width = A.size[0]
    height = A.size[1]
    # data['A'] = A.resize((512, 512))
    data['A'] = transform(A)
    data['A'] = data['A'].unsqueeze(0).to(device)
    model.set_input(data)
    model.test()  # run inference
    visuals = model.get_current_visuals()  # get image results
    save_img(visuals['content_output'].cpu(), width, height, "result/result.jpg")
|
||||
49
app/service/image2sketch/models/__init__.py
Normal file
49
app/service/image2sketch/models/__init__.py
Normal file
@@ -0,0 +1,49 @@
|
||||
import importlib
|
||||
|
||||
from app.service.image2sketch.models import unpaired_model as modellib
|
||||
from .base_model import BaseModel
|
||||
|
||||
|
||||
def find_model_using_name(model_name):
    """Look up the model class for *model_name* inside ``unpaired_model``.

    The class named ``<model_name>model`` (underscores removed,
    case-insensitive) is returned; it must be a subclass of BaseModel.
    NOTE: the dynamic ``importlib`` lookup from the original project is
    disabled — only ``unpaired_model`` (imported as ``modellib``) is searched.
    """
    # model_filename = "." + model_name + "_model"
    # modellib = importlib.import_module(model_filename)
    model = None
    target_model_name = model_name.replace('_', '') + 'model'
    for name, cls in modellib.__dict__.items():
        # isinstance guard: module dicts contain non-classes, and issubclass
        # raises TypeError on those.
        if name.lower() == target_model_name.lower() \
           and isinstance(cls, type) and issubclass(cls, BaseModel):
            model = cls

    if model is None:
        # BUGFIX: the original message referenced `model_filename`, which is
        # commented out above — the failure path raised NameError instead of
        # printing a diagnostic.
        print("In unpaired_model.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % target_model_name)
        exit(0)

    return model
|
||||
|
||||
|
||||
def get_option_setter(model_name):
    """Return the <modify_commandline_options> static method of the class for *model_name*."""
    return find_model_using_name(model_name).modify_commandline_options
|
||||
|
||||
|
||||
def create_model(opt):
    """Instantiate the model selected by ``opt.model``.

    This is the main interface between this package and 'train.py'/'test.py'.

    Example:
        >>> from .models import create_model
        >>> model = create_model(opt)
    """
    model_class = find_model_using_name(opt.model)
    instance = model_class(opt)
    print("model [%s] was created" % type(instance).__name__)
    return instance
|
||||
230
app/service/image2sketch/models/base_model.py
Normal file
230
app/service/image2sketch/models/base_model.py
Normal file
@@ -0,0 +1,230 @@
|
||||
import os
|
||||
import torch
|
||||
from collections import OrderedDict
|
||||
from abc import ABC, abstractmethod
|
||||
from . import networks
|
||||
|
||||
|
||||
class BaseModel(ABC):
    """This class is an abstract base class (ABC) for models.
    To create a subclass, you need to implement the following five functions:
        -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
        -- <set_input>: unpack data from dataset and apply preprocessing.
        -- <forward>: produce intermediate results.
        -- <optimize_parameters>: calculate losses, gradients, and update network weights.
        -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
    """

    def __init__(self, opt):
        """Initialize the BaseModel class.

        Parameters:
            opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions

        When creating your custom class, you need to implement your own initialization.
        In this function, you should first call <BaseModel.__init__(self, opt)>
        Then, you need to define four lists:
            -- self.loss_names (str list): specify the training losses that you want to plot and save.
            -- self.model_names (str list): define networks used in our training.
            -- self.visual_names (str list): specify the images that you want to display and save.
            -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
        """
        self.opt = opt
        self.gpu_ids = opt.gpu_ids
        self.isTrain = opt.isTrain
        self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')  # get device name: CPU or GPU
        self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)  # save all the checkpoints to save_dir
        if opt.preprocess != 'scale_width':  # with [scale_width], input images might have different sizes, which hurts the performance of cudnn.benchmark.
            torch.backends.cudnn.benchmark = True
        self.loss_names = []
        self.model_names = []
        self.visual_names = []
        self.optimizers = []
        self.image_paths = []
        self.metric = 0  # used for learning rate policy 'plateau'

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Add new model-specific options, and rewrite default values for existing options.

        Parameters:
            parser -- original option parser
            is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.

        Returns:
            the modified parser.
        """
        return parser

    @abstractmethod
    def set_input(self, input):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.

        Parameters:
            input (dict): includes the data itself and its metadata information.
        """
        pass

    @abstractmethod
    def forward(self):
        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
        pass

    @abstractmethod
    def optimize_parameters(self):
        """Calculate losses, gradients, and update network weights; called in every training iteration"""
        pass

    def setup(self, opt):
        """Load and print networks; create schedulers

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        # NOTE(review): self.schedulers is only created here when isTrain, so
        # update_learning_rate() must not be called on a test-mode model.
        if self.isTrain:
            self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
        if not self.isTrain or opt.continue_train:
            load_suffix = 'iter_%d' % opt.load_iter if opt.load_iter > 0 else opt.epoch
            self.load_networks(load_suffix)
        self.print_networks(opt.verbose)

    def eval(self):
        """Make models eval mode during test time"""
        for name in self.model_names:
            if isinstance(name, str):
                net = getattr(self, 'net' + name)
                net.eval()

    def test(self):
        """Forward function used in test time.

        This function wraps <forward> function in no_grad() so we don't save intermediate steps for backprop
        It also calls <compute_visuals> to produce additional visualization results
        """
        with torch.no_grad():
            self.forward()
            self.compute_visuals()

    def compute_visuals(self):
        """Calculate additional output images for visdom and HTML visualization"""
        pass

    def get_image_paths(self):
        """ Return image paths that are used to load current data"""
        return self.image_paths

    def update_learning_rate(self):
        """Update learning rates for all the networks; called at the end of every epoch"""
        old_lr = self.optimizers[0].param_groups[0]['lr']
        for scheduler in self.schedulers:
            if self.opt.lr_policy == 'plateau':
                # plateau schedulers step on a monitored metric, not on epochs
                scheduler.step(self.metric)
            else:
                scheduler.step()

        lr = self.optimizers[0].param_groups[0]['lr']
        print('learning rate %.7f -> %.7f' % (old_lr, lr))

    def get_current_visuals(self):
        """Return visualization images. train.py will display these images with visdom, and save the images to a HTML"""
        visual_ret = OrderedDict()
        for name in self.visual_names:
            if isinstance(name, str):
                visual_ret[name] = getattr(self, name)
        return visual_ret

    def get_current_losses(self):
        """Return traning losses / errors. train.py will print out these errors on console, and save them to a file"""
        errors_ret = OrderedDict()
        for name in self.loss_names:
            if isinstance(name, str):
                errors_ret[name] = float(getattr(self, 'loss_' + name))  # float(...) works for both scalar tensor and float number
        return errors_ret

    def save_networks(self, epoch):
        """Save all the networks to the disk.

        Parameters:
            epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
        """
        for name in self.model_names:
            if isinstance(name, str):
                save_filename = '%s_net_%s.pth' % (epoch, name)
                save_path = os.path.join(self.save_dir, save_filename)
                net = getattr(self, 'net' + name)

                if len(self.gpu_ids) > 0 and torch.cuda.is_available():
                    # save CPU-side weights, then move the net back to its GPU
                    torch.save(net.module.cpu().state_dict(), save_path)
                    net.cuda(self.gpu_ids[0])
                else:
                    torch.save(net.cpu().state_dict(), save_path)

    def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
        """Fix InstanceNorm checkpoints incompatibility (prior to 0.4)"""
        key = keys[i]
        if i + 1 == len(keys):  # at the end, pointing to a parameter/buffer
            if module.__class__.__name__.startswith('InstanceNorm') and \
                    (key == 'running_mean' or key == 'running_var'):
                if getattr(module, key) is None:
                    state_dict.pop('.'.join(keys))
            if module.__class__.__name__.startswith('InstanceNorm') and \
                    (key == 'num_batches_tracked'):
                state_dict.pop('.'.join(keys))
        else:
            # recurse one attribute deeper along the dotted key path
            self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)

    def load_networks(self, epoch):
        """Load all the networks from the disk.

        Parameters:
            epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
        """
        for name in self.model_names:
            if isinstance(name, str):
                load_filename = '%s_net_%s.pth' % (epoch, name)
                load_path = os.path.join(self.save_dir, load_filename)
                net = getattr(self, 'net' + name)
                if isinstance(net, torch.nn.DataParallel):
                    # state dicts are saved without the DataParallel wrapper
                    net = net.module
                print('loading the model from %s' % load_path)
                # if you are using PyTorch newer than 0.4 (e.g., built from
                # GitHub source), you can remove str() on self.device
                state_dict = torch.load(load_path, map_location=str(self.device))
                if hasattr(state_dict, '_metadata'):
                    del state_dict._metadata

                # patch InstanceNorm checkpoints prior to 0.4
                for key in list(state_dict.keys()):  # need to copy keys here because we mutate in loop
                    self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
                net.load_state_dict(state_dict)

    def print_networks(self, verbose):
        """Print the total number of parameters in the network and (if verbose) network architecture

        Parameters:
            verbose (bool) -- if verbose: print the network architecture
        """
        print('---------- Networks initialized -------------')
        for name in self.model_names:
            if isinstance(name, str):
                net = getattr(self, 'net' + name)
                num_params = 0
                for param in net.parameters():
                    num_params += param.numel()
                if verbose:
                    print(net)
                print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
        print('-----------------------------------------------')

    def set_requires_grad(self, nets, requires_grad=False):
        """Set requies_grad=Fasle for all the networks to avoid unnecessary computations
        Parameters:
            nets (network list) -- a list of networks
            requires_grad (bool) -- whether the networks require gradients or not
        """
        if not isinstance(nets, list):
            nets = [nets]
        for net in nets:
            if net is not None:
                for param in net.parameters():
                    param.requires_grad = requires_grad
|
||||
354
app/service/image2sketch/models/layer.py
Normal file
354
app/service/image2sketch/models/layer.py
Normal file
@@ -0,0 +1,354 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class CNR2d(nn.Module):
    """Conv2d -> optional Norm2d -> optional ReLU -> optional Dropout2d block.

    ``[]`` is the sentinel meaning "omit this stage" for norm/relu/drop, and
    "derive from norm" for bias (the conv bias is dropped when batch-norm follows).
    """

    def __init__(self, nch_in, nch_out, kernel_size=4, stride=1, padding=1, norm='bnorm', relu=0.0, drop=[], bias=[]):
        super().__init__()

        if bias == []:
            # BatchNorm supplies its own shift term, so the conv bias is redundant.
            bias = norm != 'bnorm'

        stages = [Conv2d(nch_in, nch_out, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)]
        if norm != []:
            stages.append(Norm2d(nch_out, norm))
        if relu != []:
            stages.append(ReLU(relu))
        if drop != []:
            stages.append(nn.Dropout2d(drop))

        self.cbr = nn.Sequential(*stages)

    def forward(self, x):
        return self.cbr(x)
|
||||
|
||||
|
||||
class DECNR2d(nn.Module):
    """Deconv2d -> optional Norm2d -> optional ReLU -> optional Dropout2d block.

    ``[]`` is the sentinel meaning "omit this stage" for norm/relu/drop, and
    "derive from norm" for bias (no bias when batch-norm follows).
    """

    def __init__(self, nch_in, nch_out, kernel_size=4, stride=1, padding=1, output_padding=0, norm='bnorm', relu=0.0, drop=[], bias=[]):
        super().__init__()

        if bias == []:
            # BatchNorm supplies its own shift term, so the deconv bias is redundant.
            bias = norm != 'bnorm'

        stages = [Deconv2d(nch_in, nch_out, kernel_size=kernel_size, stride=stride, padding=padding, output_padding=output_padding, bias=bias)]
        if norm != []:
            stages.append(Norm2d(nch_out, norm))
        if relu != []:
            stages.append(ReLU(relu))
        if drop != []:
            stages.append(nn.Dropout2d(drop))

        self.decbr = nn.Sequential(*stages)

    def forward(self, x):
        return self.decbr(x)
|
||||
|
||||
|
||||
class ResBlock(nn.Module):
    """Two-conv residual block computing x + F(x), with reflection padding by default."""

    def __init__(self, nch_in, nch_out, kernel_size=3, stride=1, padding=1, padding_mode='reflection', norm='inorm', relu=0.0, drop=[], bias=[]):
        super().__init__()

        if bias == []:
            bias = norm != 'bnorm'  # NOTE: computed but never forwarded to CNR2d; kept for parity with the original

        # first conv (with activation)
        blocks = [
            Padding(padding, padding_mode=padding_mode),
            CNR2d(nch_in, nch_out, kernel_size=kernel_size, stride=stride, padding=0, norm=norm, relu=relu),
        ]
        if drop != []:
            blocks.append(nn.Dropout2d(drop))
        # second conv has no activation: the residual sum comes first
        blocks += [
            Padding(padding, padding_mode=padding_mode),
            CNR2d(nch_in, nch_out, kernel_size=kernel_size, stride=stride, padding=0, norm=norm, relu=[]),
        ]

        self.resblk = nn.Sequential(*blocks)

    def forward(self, x):
        return x + self.resblk(x)
|
||||
|
||||
|
||||
class ResBlock_cat(nn.Module):
    """Residual block whose first conv consumes cat([x, y]) along the channel axis."""

    def __init__(self, nch_in, nch_out, kernel_size=3, stride=1, padding=1, padding_mode='reflection', norm='inorm', relu=0.0, drop=[], bias=[]):
        super().__init__()

        if bias == []:
            bias = norm != 'bnorm'  # NOTE: computed but never forwarded to CNR2d; kept for parity with the original

        # first conv sees twice the channels because of the concatenation in forward()
        blocks = [
            Padding(padding, padding_mode=padding_mode),
            CNR2d(nch_in * 2, nch_out, kernel_size=kernel_size, stride=stride, padding=0, norm=norm, relu=relu),
        ]
        if drop != []:
            blocks.append(nn.Dropout2d(drop))
        # second conv: no activation before the residual sum
        blocks += [
            Padding(padding, padding_mode=padding_mode),
            CNR2d(nch_in, nch_out, kernel_size=kernel_size, stride=stride, padding=0, norm=norm, relu=[]),
        ]

        self.resblk = nn.Sequential(*blocks)

    def forward(self, x, y):
        return x + self.resblk(torch.cat([x, y], dim=1))
|
||||
|
||||
class LinearBlock(nn.Module):
    """Fully-connected layer followed by optional normalization and activation.

    norm in {'bn', 'in', 'ln', 'sn', 'none'}; 'sn' wraps the linear layer in
    spectral norm instead of appending a norm module.
    activation in {'relu', 'lrelu', 'prelu', 'selu', 'tanh', 'none'}.
    """

    def __init__(self, input_dim, output_dim, norm='none', activation='relu'):
        super(LinearBlock, self).__init__()
        use_bias = True
        # fully connected layer (spectral-norm wrapped when requested)
        if norm == 'sn':
            self.fc = SpectralNorm(nn.Linear(input_dim, output_dim, bias=use_bias))
        else:
            self.fc = nn.Linear(input_dim, output_dim, bias=use_bias)

        # normalization
        norm_dim = output_dim
        if norm == 'bn':
            self.norm = nn.BatchNorm1d(norm_dim)
        elif norm == 'in':
            self.norm = nn.InstanceNorm1d(norm_dim)
        elif norm == 'ln':
            self.norm = LayerNorm(norm_dim)
        elif norm in ('none', 'sn'):
            self.norm = None
        else:
            assert 0, "Unsupported normalization: {}".format(norm)

        # activation, via a small dispatch table instead of an if/elif chain
        activations = {
            'relu': lambda: nn.ReLU(inplace=True),
            'lrelu': lambda: nn.LeakyReLU(0.2, inplace=True),
            'prelu': lambda: nn.PReLU(),
            'selu': lambda: nn.SELU(inplace=True),
            'tanh': lambda: nn.Tanh(),
            'none': lambda: None,
        }
        assert activation in activations, "Unsupported activation: {}".format(activation)
        self.activation = activations[activation]()

    def forward(self, x):
        out = self.fc(x)
        if self.norm:
            out = self.norm(out)
        if self.activation:
            out = self.activation(out)
        return out
|
||||
|
||||
class MLP(nn.Module):
    """Multi-layer perceptron of n_blk LinearBlocks; flattens its input first.

    Layout: input_dim -> dim -> ... (n_blk - 2 hidden blocks) ... -> output_dim,
    with no normalization/activation on the output block.
    """

    def __init__(self, input_dim, output_dim, dim, n_blk, norm='none', activ='relu'):
        super(MLP, self).__init__()
        blocks = [LinearBlock(input_dim, dim, norm=norm, activation=activ)]
        blocks += [LinearBlock(dim, dim, norm=norm, activation=activ) for _ in range(n_blk - 2)]
        blocks += [LinearBlock(dim, output_dim, norm='none', activation='none')]  # no output activations
        self.model = nn.Sequential(*blocks)

    def forward(self, x):
        # flatten everything but the batch dimension
        return self.model(x.view(x.size(0), -1))
|
||||
|
||||
class CNR1d(nn.Module):
    """Linear -> optional Norm2d -> optional ReLU -> optional Dropout2d block.

    ``[]`` is the sentinel meaning "omit this stage" for norm/relu/drop; the
    linear bias is dropped when batch-norm follows.
    """

    def __init__(self, nch_in, nch_out, norm='bnorm', relu=0.0, drop=[]):
        super().__init__()

        # BatchNorm supplies its own shift term, so the linear bias is redundant.
        bias = norm != 'bnorm'

        stages = [nn.Linear(nch_in, nch_out, bias=bias)]
        if norm != []:
            stages.append(Norm2d(nch_out, norm))
        if relu != []:
            stages.append(ReLU(relu))
        if drop != []:
            stages.append(nn.Dropout2d(drop))

        self.cbr = nn.Sequential(*stages)

    def forward(self, x):
        return self.cbr(x)
|
||||
|
||||
|
||||
class Conv2d(nn.Module):
    """Thin wrapper around nn.Conv2d, kept so every layer type in this module
    shares the same construction interface."""

    def __init__(self, nch_in, nch_out, kernel_size=4, stride=1, padding=1, bias=True):
        super(Conv2d, self).__init__()
        # attribute name `conv` is part of the checkpoint state-dict layout — do not rename
        self.conv = nn.Conv2d(
            nch_in,
            nch_out,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=bias,
        )

    def forward(self, x):
        return self.conv(x)
|
||||
|
||||
|
||||
class Deconv2d(nn.Module):
    """Thin wrapper around nn.ConvTranspose2d (learned upsampling).

    An upsample+conv alternative existed in the original source but was
    disabled; only the transposed convolution path is active.
    """

    def __init__(self, nch_in, nch_out, kernel_size=4, stride=1, padding=1, output_padding=0, bias=True):
        super(Deconv2d, self).__init__()
        # attribute name `deconv` is part of the checkpoint state-dict layout — do not rename
        self.deconv = nn.ConvTranspose2d(
            nch_in,
            nch_out,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
            bias=bias,
        )

    def forward(self, x):
        return self.deconv(x)
|
||||
|
||||
|
||||
class Linear(nn.Module):
    """Wrapper around nn.Linear, matching the module naming style of this file."""

    def __init__(self, nch_in, nch_out):
        super(Linear, self).__init__()
        # Plain affine map; no normalization or activation.
        self.linear = nn.Linear(nch_in, nch_out)

    def forward(self, x):
        return self.linear(x)
|
||||
|
||||
|
||||
class Norm2d(nn.Module):
    """2-D normalization layer selected by name.

    Parameters:
        nch (int)       -- number of feature channels to normalize
        norm_mode (str) -- 'bnorm' for BatchNorm2d, 'inorm' for InstanceNorm2d

    Raises:
        NotImplementedError -- for any other norm_mode. (Previously an unknown
        mode left `self.norm` unset and forward() died later with an opaque
        AttributeError; failing fast here gives a clear message.)
    """

    def __init__(self, nch, norm_mode):
        super(Norm2d, self).__init__()
        if norm_mode == 'bnorm':
            self.norm = nn.BatchNorm2d(nch)
        elif norm_mode == 'inorm':
            self.norm = nn.InstanceNorm2d(nch)
        else:
            raise NotImplementedError('norm mode [%s] is not implemented' % norm_mode)

    def forward(self, x):
        return self.norm(x)
|
||||
|
||||
|
||||
class ReLU(nn.Module):
    """In-place rectifier selected by slope.

    Parameters:
        relu (float) -- 0 for plain ReLU, > 0 for LeakyReLU with that
                        negative slope.

    Raises:
        ValueError -- for a negative slope. (Previously a negative value left
        `self.relu` unset and forward() died later with an AttributeError.)
    """

    def __init__(self, relu):
        super(ReLU, self).__init__()
        if relu > 0:
            self.relu = nn.LeakyReLU(relu, True)
        elif relu == 0:
            self.relu = nn.ReLU(True)
        else:
            raise ValueError('negative relu slope [%s] is not supported' % relu)

    def forward(self, x):
        return self.relu(x)
|
||||
|
||||
|
||||
class Padding(nn.Module):
    """2-D padding with a selectable mode: reflection | replication | constant | zeros.

    `value` is used only by the 'constant' mode.
    """

    def __init__(self, padding, padding_mode='zeros', value=0):
        super(Padding, self).__init__()
        builders = {
            'reflection': lambda: nn.ReflectionPad2d(padding),
            'replication': lambda: nn.ReplicationPad2d(padding),
            'constant': lambda: nn.ConstantPad2d(padding, value),
            'zeros': lambda: nn.ZeroPad2d(padding),
        }
        if padding_mode in builders:
            self.padding = builders[padding_mode]()

    def forward(self, x):
        return self.padding(x)
|
||||
|
||||
|
||||
class Pooling2d(nn.Module):
    """Downsampling by `pool`: 'avg'/'max' pooling, or a learned strided conv ('conv').

    `nch` is only required for the 'conv' variant.
    """

    def __init__(self, nch=[], pool=2, type='avg'):
        super().__init__()
        factories = {
            'avg': lambda: nn.AvgPool2d(pool),
            'max': lambda: nn.MaxPool2d(pool),
            'conv': lambda: nn.Conv2d(nch, nch, kernel_size=pool, stride=pool),
        }
        if type in factories:
            self.pooling = factories[type]()

    def forward(self, x):
        return self.pooling(x)
|
||||
|
||||
|
||||
class UnPooling2d(nn.Module):
    """Upsampling by `pool`: nearest/bilinear interpolation or a transposed conv.

    Parameters:
        nch (int)  -- channel count; only required for the 'conv' variant
        pool (int) -- upsampling factor
        type (str) -- 'nearest' | 'bilinear' | 'conv'
    """

    def __init__(self, nch=[], pool=2, type='nearest'):
        super().__init__()
        if type == 'nearest':
            # BUG FIX: align_corners is only valid for interpolating modes
            # (linear/bilinear/...); passing it with mode='nearest' made
            # forward() raise ValueError inside F.interpolate.
            self.unpooling = nn.Upsample(scale_factor=pool, mode='nearest')
        elif type == 'bilinear':
            self.unpooling = nn.Upsample(scale_factor=pool, mode='bilinear', align_corners=True)
        elif type == 'conv':
            self.unpooling = nn.ConvTranspose2d(nch, nch, kernel_size=pool, stride=pool)

    def forward(self, x):
        return self.unpooling(x)
|
||||
|
||||
|
||||
class Concat(nn.Module):
    """U-Net skip join: pad x1 to x2's spatial size, then concat [x2, x1] on channels."""

    def __init__(self):
        super().__init__()

    def forward(self, x1, x2):
        dh = x2.size()[2] - x1.size()[2]
        dw = x2.size()[3] - x1.size()[3]
        # Split the size difference evenly between the two sides of each axis.
        x1 = F.pad(x1, [dw // 2, dw - dw // 2,
                        dh // 2, dh - dh // 2])
        return torch.cat([x2, x1], dim=1)
|
||||
|
||||
|
||||
class TV1dLoss(nn.Module):
    """Total-variation loss along the last axis of a (batch, length) tensor."""

    def __init__(self):
        super(TV1dLoss, self).__init__()

    def forward(self, input):
        # Mean absolute difference between neighbouring elements.
        return torch.mean(torch.abs(input[:, :-1] - input[:, 1:]))
|
||||
|
||||
|
||||
class TV2dLoss(nn.Module):
    """Total-variation loss over the spatial axes of an NCHW tensor."""

    def __init__(self):
        super(TV2dLoss, self).__init__()

    def forward(self, input):
        # Mean absolute neighbour differences along width and height, summed.
        tv_w = torch.mean(torch.abs(input[:, :, :, :-1] - input[:, :, :, 1:]))
        tv_h = torch.mean(torch.abs(input[:, :, :-1, :] - input[:, :, 1:, :]))
        return tv_w + tv_h
|
||||
|
||||
|
||||
class SSIM2dLoss(nn.Module):
    """Placeholder for a 2-D SSIM loss.

    NOTE(review): forward ignores both inputs and always returns 0 — this is
    an unimplemented stub; confirm it is not used in any active training path.
    """

    def __init__(self):
        super(SSIM2dLoss, self).__init__()

    def forward(self, input, targer):
        # `targer` is presumably a typo for `target`; kept unchanged because
        # renaming a parameter would break keyword callers.
        loss = 0
        return loss
|
||||
|
||||
734
app/service/image2sketch/models/networks.py
Normal file
734
app/service/image2sketch/models/networks.py
Normal file
@@ -0,0 +1,734 @@
|
||||
import functools
|
||||
|
||||
from torch.nn import init
|
||||
from torch.optim import lr_scheduler
|
||||
|
||||
from .layer import *
|
||||
|
||||
|
||||
###############################################################################
|
||||
# Helper Functions
|
||||
###############################################################################
|
||||
|
||||
|
||||
class Identity(nn.Module):
    """No-op module: returns its input unchanged (used as the 'none' norm layer)."""

    def forward(self, x):
        return x
|
||||
|
||||
|
||||
def get_norm_layer(norm_type='instance'):
    """Return a factory that builds a normalization layer for a channel count.

    Parameters:
        norm_type (str) -- batch | instance | none

    BatchNorm uses learnable affine parameters and tracks running statistics;
    InstanceNorm uses neither. 'none' yields an Identity regardless of the
    channel argument.
    """
    if norm_type == 'batch':
        return functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
    if norm_type == 'instance':
        return functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
    if norm_type == 'none':
        # The channel argument is accepted and ignored.
        return lambda x: Identity()
    raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
|
||||
|
||||
|
||||
def get_scheduler(optimizer, opt):
    """Return a learning-rate scheduler.

    Parameters:
        optimizer          -- the optimizer of the network
        opt (option class) -- experiment flags; opt.lr_policy selects the
                              policy: linear | step | plateau | cosine

    'linear' keeps the rate constant for opt.n_epochs epochs, then decays it
    linearly to zero over opt.n_epochs_decay epochs. The other policies use
    the stock PyTorch schedulers.

    Raises:
        NotImplementedError -- for an unknown opt.lr_policy. (BUG FIX: the
        original *returned* an un-raised NotImplementedError — with a stray
        ',' instead of '%' in the format — so callers silently received an
        exception object instead of a scheduler.)
    """
    if opt.lr_policy == 'linear':
        def lambda_rule(epoch):
            # Constant for n_epochs, then linear decay to 0 over n_epochs_decay.
            return 1.0 - max(0, epoch + opt.epoch_count - opt.n_epochs) / float(opt.n_epochs_decay + 1)

        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
    elif opt.lr_policy == 'step':
        scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
    elif opt.lr_policy == 'plateau':
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
    elif opt.lr_policy == 'cosine':
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.n_epochs, eta_min=0)
    else:
        raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
    return scheduler
|
||||
|
||||
|
||||
def init_weights(net, init_type='normal', init_gain=0.02):
    """Initialize network weights in place.

    Parameters:
        net (network)     -- network to be initialized
        init_type (str)   -- normal | xavier | kaiming | orthogonal
        init_gain (float) -- scaling factor for normal, xavier and orthogonal

    'normal' is the original pix2pix/CycleGAN choice; xavier or kaiming may
    work better for some applications.
    """

    def init_func(m):
        cls_name = m.__class__.__name__
        if hasattr(m, 'weight') and ('Conv' in cls_name or 'Linear' in cls_name):
            if init_type == 'normal':
                init.normal_(m.weight.data, 0.0, init_gain)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain=init_gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain=init_gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif 'BatchNorm2d' in cls_name:
            # BatchNorm weight is a vector, not a matrix; only normal init applies.
            init.normal_(m.weight.data, 1.0, init_gain)
            init.constant_(m.bias.data, 0.0)

    print('initialize network with %s' % init_type)
    net.apply(init_func)
|
||||
|
||||
|
||||
def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
    """Register a network on the requested GPUs, then initialize its weights.

    Parameters:
        net (network)      -- the network to be initialized
        init_type (str)    -- normal | xavier | kaiming | orthogonal
        init_gain (float)  -- scaling factor for normal, xavier and orthogonal
        gpu_ids (int list) -- GPUs to run on, e.g. [0, 1, 2]; empty = CPU

    Returns the (possibly DataParallel-wrapped) initialized network.
    """
    if gpu_ids:
        assert torch.cuda.is_available()
        net.to(gpu_ids[0])
        net = torch.nn.DataParallel(net, gpu_ids)  # multi-GPU wrapper
    init_weights(net, init_type, init_gain=init_gain)
    return net
|
||||
|
||||
|
||||
def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02, gpu_ids=[]):
    """Create and initialize a generator.

    Parameters:
        input_nc / output_nc -- channel counts of the content input and output
        ngf                  -- base number of generator filters
        netG (str)           -- ref_unpair_cbam_cat | ref_unpair_recon | triplet
        norm (str)           -- norm name, validated via get_norm_layer
        use_dropout          -- accepted for interface compatibility; unused here

    Returns the generator wrapped by init_net (device placement + weight init).
    """
    # Validates the norm name early; the generators below configure their
    # own normalization internally.
    get_norm_layer(norm_type=norm)

    if netG == 'ref_unpair_cbam_cat':
        net = ref_unpair(input_nc, output_nc, ngf, norm='inorm', status='ref_unpair_cbam_cat')
    elif netG == 'ref_unpair_recon':
        net = ref_unpair(input_nc, output_nc, ngf, norm='inorm', status='ref_unpair_recon')
    elif netG == 'triplet':
        # BUG FIX: triplet.__init__ takes no arguments; the original call
        # triplet(input_nc, output_nc, ngf, norm='inorm') raised TypeError.
        net = triplet()
    else:
        raise NotImplementedError('Generator model name [%s] is not recognized' % netG)
    return init_net(net, init_type, init_gain, gpu_ids)
|
||||
|
||||
|
||||
class AdaIN(nn.Module):
    """Adaptive Instance Normalization: match x's per-channel statistics to y's."""

    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        eps = 1e-5
        # Per-sample, per-channel statistics over the spatial dimensions,
        # broadcast back to NCHW shape.
        mu_x = torch.mean(x, dim=[2, 3]).unsqueeze(-1).unsqueeze(-1)
        mu_y = torch.mean(y, dim=[2, 3]).unsqueeze(-1).unsqueeze(-1)
        sigma_x = torch.std(x, dim=[2, 3]).unsqueeze(-1).unsqueeze(-1) + eps
        sigma_y = torch.std(y, dim=[2, 3]).unsqueeze(-1).unsqueeze(-1) + eps
        # Whiten x, then re-color with y's statistics.
        return (x - mu_x) / sigma_x * sigma_y + mu_y
|
||||
|
||||
|
||||
class HED(nn.Module):
    """Holistically-nested edge detector (VGG backbone with five side outputs).

    Input is an RGB image in [0, 1]; output is a single-channel edge
    probability map at the input resolution. Attribute names and Sequential
    layouts match the pretrained checkpoint's state_dict keys.
    """

    def __init__(self):
        super(HED, self).__init__()

        def vgg_stage(nch_in, nch_out, n_convs, pool):
            # One VGG stage: optional 2x2 max-pool, then n_convs (Conv3x3 + ReLU) pairs.
            layers = [nn.MaxPool2d(kernel_size=2, stride=2)] if pool else []
            for i in range(n_convs):
                layers.append(nn.Conv2d(in_channels=nch_in if i == 0 else nch_out,
                                        out_channels=nch_out,
                                        kernel_size=3, stride=1, padding=1))
                layers.append(nn.ReLU(inplace=False))
            return nn.Sequential(*layers)

        def score_head(nch_in):
            # 1x1 conv producing a single-channel side output.
            return nn.Conv2d(in_channels=nch_in, out_channels=1, kernel_size=1, stride=1, padding=0)

        self.moduleVggOne = vgg_stage(3, 64, 2, pool=False)
        self.moduleVggTwo = vgg_stage(64, 128, 2, pool=True)
        self.moduleVggThr = vgg_stage(128, 256, 3, pool=True)
        self.moduleVggFou = vgg_stage(256, 512, 3, pool=True)
        self.moduleVggFiv = vgg_stage(512, 512, 3, pool=True)

        self.moduleScoreOne = score_head(64)
        self.moduleScoreTwo = score_head(128)
        self.moduleScoreThr = score_head(256)
        self.moduleScoreFou = score_head(512)
        self.moduleScoreFiv = score_head(512)

        # Learned fusion of the five side outputs into one probability map.
        self.moduleCombine = nn.Sequential(
            nn.Conv2d(in_channels=5, out_channels=1, kernel_size=1, stride=1, padding=0),
            nn.Sigmoid()
        )

    def forward(self, tensorInput):
        # Convert the [0, 1] RGB input to mean-subtracted BGR (Caffe VGG convention).
        blue = (tensorInput[:, 2:3, :, :] * 255.0) - 104.00698793
        green = (tensorInput[:, 1:2, :, :] * 255.0) - 116.66876762
        red = (tensorInput[:, 0:1, :, :] * 255.0) - 122.67891434
        tensorInput = torch.cat([blue, green, red], 1)

        feats = [self.moduleVggOne(tensorInput)]
        for stage in (self.moduleVggTwo, self.moduleVggThr, self.moduleVggFou, self.moduleVggFiv):
            feats.append(stage(feats[-1]))

        size = (tensorInput.size(2), tensorInput.size(3))
        heads = (self.moduleScoreOne, self.moduleScoreTwo, self.moduleScoreThr,
                 self.moduleScoreFou, self.moduleScoreFiv)
        # Each side output is upsampled back to the input resolution.
        scores = [
            nn.functional.interpolate(input=head(feat), size=size, mode='bilinear', align_corners=False)
            for head, feat in zip(heads, feats)
        ]
        return self.moduleCombine(torch.cat(scores, 1))
|
||||
|
||||
|
||||
def define_HED(init_weights_, gpu_ids_=[]):
    """Build an HED edge detector and optionally load pretrained weights.

    Parameters:
        init_weights_ -- path to a state_dict checkpoint, or None to skip loading
        gpu_ids_      -- GPUs to use; wraps the model in DataParallel when non-empty
    """
    net = HED()

    if gpu_ids_:
        assert torch.cuda.is_available()
        net.to(gpu_ids_[0])
        net = torch.nn.DataParallel(net, gpu_ids_)  # multi-GPUs

    if init_weights_ is not None:
        device = torch.device('cuda:{}'.format(gpu_ids_[0])) if gpu_ids_ else torch.device('cpu')
        print('Loading model from: %s' % init_weights_)
        state_dict = torch.load(init_weights_, map_location=str(device))
        # Load into the underlying module when wrapped in DataParallel.
        target = net.module if isinstance(net, torch.nn.DataParallel) else net
        target.load_state_dict(state_dict)
        print('load the weights successfully')

    return net
|
||||
|
||||
|
||||
def define_styletps(init_weights_, gpu_ids_=[], shape=False):
    """Build the triplet style-embedding network and optionally load weights.

    NOTE(review): when shape is True no network is constructed and the
    function returns None — confirm whether a shape variant was intended.
    """
    net = None
    if not shape:
        net = triplet()
        if gpu_ids_:
            assert torch.cuda.is_available()
            net.to(gpu_ids_[0])
            net = torch.nn.DataParallel(net, gpu_ids_)  # multi-GPUs

        if init_weights_ is not None:
            device = torch.device('cuda:{}'.format(gpu_ids_[0])) if gpu_ids_ else torch.device('cpu')
            print('Loading model from: %s' % init_weights_)
            state_dict = torch.load(init_weights_, map_location=str(device))
            # Load into the underlying module when wrapped in DataParallel.
            target = net.module if isinstance(net, torch.nn.DataParallel) else net
            target.load_state_dict(state_dict)
            print('load the weights successfully')

    return net
|
||||
|
||||
|
||||
class triplet(nn.Module):
    """Shared-weight embedding network for triplet training.

    Each of the three inputs is pushed through the same conv stack and
    projected to a 128-d embedding; the three embeddings are returned in
    the input order.
    """

    def __init__(self):  # mnblk=4
        super(triplet, self).__init__()

        self.nch_in = 1
        self.nch_out = 1
        self.nch_ker = 64
        self.norm = 'bnorm'

        # BatchNorm supplies its own bias term, so conv bias is disabled for it.
        self.bias = self.norm != 'bnorm'

        self.conv0 = CNR2d(self.nch_in, self.nch_ker, kernel_size=7, stride=1, padding=3, norm=self.norm, relu=0.0)
        self.conv1 = CNR2d(self.nch_ker, 2 * self.nch_ker, kernel_size=4, stride=2, padding=1, norm=self.norm, relu=0.0)
        self.conv2 = CNR2d(2 * self.nch_ker, 4 * self.nch_ker, kernel_size=4, stride=2, padding=1, norm=self.norm, relu=0.0)

        self.final_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.linear = nn.Linear(256, 128)

    def _embed(self, t):
        # Shared branch: conv stack -> global average pool -> 128-d projection.
        t = self.conv2(self.conv1(self.conv0(t)))
        t = torch.flatten(self.final_pool(t), 1)
        return self.linear(t)

    def forward(self, x, y, z):
        return self._embed(x), self._embed(y), self._embed(z)
|
||||
|
||||
|
||||
class MLP(nn.Module):
    """Multi-layer perceptron of LinearBlock units with a plain linear output layer."""

    def __init__(self, input_dim, output_dim, dim, n_blk, norm='none', activ='relu'):
        super(MLP, self).__init__()
        layers = [LinearBlock(input_dim, dim, norm=norm, activation=activ)]
        for _ in range(n_blk - 2):
            layers.append(LinearBlock(dim, dim, norm=norm, activation=activ))
        # Output layer carries neither normalization nor activation.
        layers.append(LinearBlock(dim, output_dim, norm='none', activation='none'))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        # Flatten each sample before the first fully-connected layer.
        return self.model(x.view(x.size(0), -1))
|
||||
|
||||
|
||||
class ref_unpair(nn.Module):
    """Reference-based unpaired image-to-sketch generator.

    Encodes a content image and (in the 'ref_unpair_cbam_cat' variant) a
    1-channel style reference, aligns content features to the style's
    statistics with AdaIN, and decodes back to an image with tanh output.

    Parameters
    ----------
    nch_in : int
        Channels of the content input (overridden to 1 for 'ref_unpair_recon').
    nch_out : int
        Nominal output channels; overridden internally (3 for
        'ref_unpair_recon', otherwise 1).
    nch_ker : int
        Base kernel count. NOTE(review): immediately overwritten to 64 in
        __init__, so the argument is effectively ignored — confirm intent.
    norm : str
        Normalization mode passed to the CNR2d/DECNR2d blocks.
    nblk : int
        Number of residual blocks used by the non-CBAM variants.
    status : str
        Variant: 'ref_unpair' | 'ref_unpair_recon' | 'ref_unpair_cbam_cat'.
    """

    def __init__(self, nch_in, nch_out, nch_ker=64, norm='bnorm', nblk=4, status='ref_unpair'):
        super(ref_unpair, self).__init__()

        nch_ker = 64  # NOTE(review): unconditionally overrides the nch_ker argument.
        self.nch_in = nch_in
        self.nchs_in = 1  # the style reference is always single-channel
        self.status = status

        if self.status == 'ref_unpair_recon':
            # Reconstruction variant: 1-channel sketch in, 3-channel image out.
            self.nch_out = 3
            self.nch_in = 1
        else:
            self.nch_out = 1

        self.nch_ker = nch_ker
        self.norm = norm
        self.nblk = nblk
        self.dec0 = []  # decoder layers are accumulated here, then wrapped below

        if status == 'ref_unpair_cbam_cat':
            # Channel attention applied to the style path, spatial to content.
            self.cbam_c = CBAM(nch_ker * 8, 16, 3, cbam_status="channel")
            self.cbam_s = CBAM(nch_ker * 8, 16, 3, cbam_status="spatial")

        # Style encoder; only exercised by the 'ref_unpair_cbam_cat' forward path.
        # NOTE(review): source indentation was ambiguous — confirm these four
        # layers are meant to exist for all variants, not just the CBAM one.
        self.enc1_s = CNR2d(self.nchs_in, self.nch_ker, kernel_size=7, stride=1, padding=3, norm=self.norm, relu=0.0)
        self.enc2_s = CNR2d(self.nch_ker, 2 * self.nch_ker, kernel_size=4, stride=2, padding=1, norm=self.norm, relu=0.0)
        self.enc3_s = CNR2d(2 * self.nch_ker, 4 * self.nch_ker, kernel_size=4, stride=2, padding=1, norm=self.norm, relu=0.0)
        self.enc4_s = CNR2d(4 * self.nch_ker, 8 * self.nch_ker, kernel_size=4, stride=2, padding=1, norm=self.norm, relu=0.0)

        # BatchNorm carries its own bias; conv bias would be redundant.
        if norm == 'bnorm':
            self.bias = False
        else:
            self.bias = True

        # Content encoder: stride-2 stages down to 8x the base kernel count.
        self.enc1_c = CNR2d(self.nch_in, self.nch_ker, kernel_size=7, stride=1, padding=3, norm=self.norm, relu=0.0)
        self.enc2_c = CNR2d(self.nch_ker, 2 * self.nch_ker, kernel_size=4, stride=2, padding=1, norm=self.norm, relu=0.0)
        self.enc3_c = CNR2d(2 * self.nch_ker, 4 * self.nch_ker, kernel_size=4, stride=2, padding=1, norm=self.norm, relu=0.0)
        self.enc4_c = CNR2d(4 * self.nch_ker, 8 * self.nch_ker, kernel_size=4, stride=2, padding=1, norm=self.norm, relu=0.0)

        if status == 'ref_unpair_cbam_cat':
            # Residual blocks that fuse the plain and attention-weighted features.
            self.res_cat1 = ResBlock_cat(8 * self.nch_ker, 8 * self.nch_ker, kernel_size=3, stride=1, padding=1, norm=self.norm, relu=0.0, padding_mode='reflection')
            self.res_cat2 = ResBlock_cat(8 * self.nch_ker, 8 * self.nch_ker, kernel_size=3, stride=1, padding=1, norm=self.norm, relu=0.0, padding_mode='reflection')
            self.res_cat3 = ResBlock_cat(8 * self.nch_ker, 8 * self.nch_ker, kernel_size=3, stride=1, padding=1, norm=self.norm, relu=0.0, padding_mode='reflection')
            self.res_cat4 = ResBlock_cat(8 * self.nch_ker, 8 * self.nch_ker, kernel_size=3, stride=1, padding=1, norm=self.norm, relu=0.0, padding_mode='reflection')

        if self.nblk and status != 'ref_unpair_cbam_cat':
            # Plain residual bottleneck for the non-CBAM variants.
            res = []
            for i in range(self.nblk):
                res += [ResBlock(8 * self.nch_ker, 8 * self.nch_ker, kernel_size=3, stride=1, padding=1, norm=self.norm, relu=0.0, padding_mode='reflection')]
            self.res1 = nn.Sequential(*res)

        # Decoder: mirror of the content encoder, ending in a conv to nch_out.
        self.dec0 += [DECNR2d(8 * self.nch_ker, 4 * self.nch_ker, kernel_size=4, stride=2, padding=1, norm=self.norm, relu=0.0)]
        self.dec0 += [DECNR2d(4 * self.nch_ker, 2 * self.nch_ker, kernel_size=4, stride=2, padding=1, norm=self.norm, relu=0.0)]
        self.dec0 += [DECNR2d(2 * self.nch_ker, 1 * self.nch_ker, kernel_size=4, stride=2, padding=1, norm=self.norm, relu=0.0)]
        self.dec0 += [DECNR2d(1 * self.nch_ker, 1 * self.nch_ker, kernel_size=7, stride=1, padding=3, norm=self.norm, relu=0.0)]
        self.dec0 += [nn.Conv2d(1 * self.nch_ker, self.nch_out, kernel_size=3, stride=1, padding=1)]

        self.dec = nn.Sequential(*self.dec0)

    def forward(self, content, style):
        # Encode the content image down to the 8x-kernel bottleneck.
        content_cs = self.enc1_c(content)
        content_cs = self.enc2_c(content_cs)
        content_cs = self.enc3_c(content_cs)
        content_cs = self.enc4_c(content_cs)

        if self.status == 'ref_unpair_cbam_cat':
            # Spatial attention residual on the content features.
            cbam_content_cs = self.cbam_s(content_cs)
            sp_content_cs = content_cs + cbam_content_cs

            # Encode the style reference and apply channel attention.
            style_cs = self.enc1_s(style)
            style_cs = self.enc2_s(style_cs)
            style_cs = self.enc3_s(style_cs)
            style_cs = self.enc4_s(style_cs)

            cbam_style_cs = self.cbam_c(style_cs)
            ch_style_cs = style_cs + cbam_style_cs

            # AdaIN aligns content statistics to the style's, for both the
            # plain and attention-weighted feature pairs.
            content_output = self.adaptive_instance_normalization(content_cs, style_cs)
            cbam_content_output = self.adaptive_instance_normalization(sp_content_cs, ch_style_cs)

            content_output = self.res_cat1(content_output, cbam_content_output)
            content_output = self.res_cat2(content_output, cbam_content_output)
            content_output = self.res_cat3(content_output, cbam_content_output)
            content_output = self.res_cat4(content_output, cbam_content_output)

        else:
            content_output = content_cs

        if self.nblk and self.status != 'ref_unpair_cbam_cat':
            # NOTE(review): the residual-block result is assigned to content_cs
            # and never used — the decoder below consumes content_output, so
            # self.res1 is effectively a no-op. Confirm whether this should be
            # `content_output = self.res1(content_output)`.
            content_cs = self.res1(content_output)

        content_output = self.dec(content_output)

        # Output range [-1, 1].
        content_output = torch.tanh(content_output)

        return content_output

    def calc_mean_std(self, feat, eps=1e-5):
        """Per-sample, per-channel mean/std of an NCHW feature map.

        eps is a small value added to the variance to avoid divide-by-zero.
        """
        size = feat.size()
        assert (len(size) == 4)
        N, C = size[:2]
        feat_var = feat.view(N, C, -1).var(dim=2) + eps
        feat_std = feat_var.sqrt().view(N, C, 1, 1)
        feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1)
        return feat_mean, feat_std

    def adaptive_instance_normalization(self, content_feat, style_feat):
        """AdaIN: whiten content features, then re-color with style statistics."""
        assert (content_feat.size()[:2] == style_feat.size()[:2])
        size = content_feat.size()
        style_mean, style_std = self.calc_mean_std(style_feat)
        content_mean, content_std = self.calc_mean_std(content_feat)

        normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
        return normalized_feat * style_std.expand(size) + style_mean.expand(size)
|
||||
|
||||
|
||||
def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal', init_gain=0.02, gpu_ids=[]):
    """Create and initialize a discriminator.

    Parameters:
        input_nc (int) -- channels of the discriminator input
        ndf (int)      -- base number of discriminator filters
        netD (str)     -- basic | n_layers | pixel
        n_layers_D     -- depth used by the 'n_layers' variant
        norm (str)     -- normalization layer name

    Returns the discriminator wrapped by init_net (device placement + init).
    """
    norm_layer = get_norm_layer(norm_type=norm)

    if netD == 'basic':
        # Default PatchGAN classifier (fixed 3-layer depth).
        net = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer)
    elif netD == 'n_layers':
        # PatchGAN with a configurable depth.
        net = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer)
    elif netD == 'pixel':
        # 1x1 PatchGAN: classifies each pixel independently.
        net = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer)
    else:
        raise NotImplementedError('Discriminator model name [%s] is not recognized' % netD)
    return init_net(net, init_type, init_gain, gpu_ids)
|
||||
|
||||
|
||||
##############################################################################
|
||||
# Classes
|
||||
##############################################################################
|
||||
class GANLoss(nn.Module):
    """GAN objective that builds real/fake target tensors matching the prediction size.

    Supported modes: 'lsgan' (MSE), 'vanilla' (BCE-with-logits), 'wgangp'
    (mean critic score). Do not put a sigmoid at the end of the
    discriminator: lsgan needs none and vanilla applies it internally.
    """

    def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
        """Store the label values as buffers and pick the base criterion.

        Parameters:
            gan_mode (str)      -- vanilla | lsgan | wgangp
            target_real_label   -- label value for real images
            target_fake_label   -- label value for fake images
        """
        super(GANLoss, self).__init__()
        self.register_buffer('real_label', torch.tensor(target_real_label))
        self.register_buffer('fake_label', torch.tensor(target_fake_label))
        self.gan_mode = gan_mode
        criteria = {
            'lsgan': nn.MSELoss,
            'vanilla': nn.BCEWithLogitsLoss,
            'wgangp': lambda: None,  # wgangp computes its loss directly in __call__
        }
        if gan_mode not in criteria:
            raise NotImplementedError('gan mode %s not implemented' % gan_mode)
        self.loss = criteria[gan_mode]()

    def get_target_tensor(self, prediction, target_is_real):
        """Expand the stored real/fake label to the prediction's shape."""
        label = self.real_label if target_is_real else self.fake_label
        return label.expand_as(prediction)

    def __call__(self, prediction, target_is_real):
        """Return the loss for a discriminator prediction and ground-truth realness."""
        if self.gan_mode == 'wgangp':
            # Critic score: maximized on real samples, minimized on fake ones.
            return -prediction.mean() if target_is_real else prediction.mean()
        target_tensor = self.get_target_tensor(prediction, target_is_real)
        return self.loss(prediction, target_tensor)
|
||||
|
||||
|
||||
def cal_gradient_penalty(netD, real_data, fake_data, device, type='mixed', constant=1.0, lambda_gp=10.0):
    """Compute the WGAN-GP gradient penalty.

    Parameters:
        netD        -- discriminator/critic network
        real_data   -- batch of real samples
        fake_data   -- batch of generated samples
        device      -- device the tensors live on
        type (str)  -- 'real' | 'fake' | 'mixed' (linear interpolation of both)
        constant    -- target gradient norm
        lambda_gp   -- penalty weight; <= 0 disables the penalty

    Returns:
        (penalty, flattened_gradients), or (0.0, None) when disabled.
    """
    if not lambda_gp > 0.0:
        return 0.0, None

    if type == 'real':
        interpolatesv = real_data
    elif type == 'fake':
        interpolatesv = fake_data
    elif type == 'mixed':
        # Per-sample random interpolation between real and fake.
        alpha = torch.rand(real_data.shape[0], 1, device=device)
        alpha = alpha.expand(real_data.shape[0], real_data.nelement() // real_data.shape[0]).contiguous().view(*real_data.shape)
        interpolatesv = alpha * real_data + ((1 - alpha) * fake_data)
    else:
        raise NotImplementedError('{} not implemented'.format(type))

    interpolatesv.requires_grad_(True)
    disc_interpolates = netD(interpolatesv)
    gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=interpolatesv,
                                    grad_outputs=torch.ones(disc_interpolates.size()).to(device),
                                    create_graph=True, retain_graph=True, only_inputs=True)
    flat = gradients[0].view(real_data.size(0), -1)
    # Small eps keeps the norm differentiable at exactly-zero gradients.
    penalty = (((flat + 1e-16).norm(2, dim=1) - constant) ** 2).mean() * lambda_gp
    return penalty, flat
|
||||
|
||||
|
||||
class NLayerDiscriminator(nn.Module):
    """PatchGAN discriminator: emits a map of per-patch real/fake scores."""

    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
        """Construct a PatchGAN discriminator.

        Parameters:
            input_nc (int) -- channels of the input images
            ndf (int)      -- filters in the first conv layer
            n_layers (int) -- number of stride-2 conv stages after the first
            norm_layer     -- normalization layer (class or functools.partial)
        """
        super(NLayerDiscriminator, self).__init__()
        # BatchNorm has affine parameters, so conv bias is redundant with it;
        # InstanceNorm does not, so bias is kept.
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        kernel, pad = 4, 1
        layers = [nn.Conv2d(input_nc, ndf, kernel_size=kernel, stride=2, padding=pad),
                  nn.LeakyReLU(0.2, True)]

        # Gradually widen the filter count, capped at 8x ndf.
        mult = 1
        for n in range(1, n_layers):
            prev, mult = mult, min(2 ** n, 8)
            layers += [
                nn.Conv2d(ndf * prev, ndf * mult, kernel_size=kernel, stride=2, padding=pad, bias=use_bias),
                norm_layer(ndf * mult),
                nn.LeakyReLU(0.2, True),
            ]

        prev, mult = mult, min(2 ** n_layers, 8)
        layers += [
            nn.Conv2d(ndf * prev, ndf * mult, kernel_size=kernel, stride=1, padding=pad, bias=use_bias),
            norm_layer(ndf * mult),
            nn.LeakyReLU(0.2, True),
            # Single-channel patch-wise prediction map.
            nn.Conv2d(ndf * mult, 1, kernel_size=kernel, stride=1, padding=pad),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, input):
        """Standard forward."""
        return self.model(input)
|
||||
|
||||
|
||||
class PixelDiscriminator(nn.Module):
    """1x1 PatchGAN discriminator (pixelGAN): scores every pixel independently."""

    def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d):
        """Construct a 1x1 PatchGAN discriminator.

        Parameters:
            input_nc (int) -- the number of channels in input images
            ndf (int)      -- the number of filters in the first conv layer
            norm_layer     -- normalization layer (class or functools.partial)
        """
        super(PixelDiscriminator, self).__init__()
        # A conv followed by affine BatchNorm2d does not need its own bias;
        # InstanceNorm2d has no affine parameters, so the bias stays.
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        # All convs are 1x1, so the receptive field of each output score is one pixel.
        self.net = nn.Sequential(
            nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias),
            norm_layer(ndf * 2),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias),
        )

    def forward(self, input):
        """Standard forward."""
        return self.net(input)
|
||||
|
||||
|
||||
class CBAM(nn.Module):
    """Convolutional Block Attention Module.

    `cbam_status` selects which attention branch(es) to apply:
    "cbam" (channel attention, then spatial attention on the gated features),
    "spatial" (spatial only), or "channel" (channel only).
    """

    def __init__(self, n_channels_in, reduction_ratio, kernel_size, cbam_status):
        """Build both attention branches.

        Parameters:
            n_channels_in (int)   -- channels of the input feature map
            reduction_ratio (int) -- bottleneck compression of the channel MLP
            kernel_size (int)     -- odd kernel size for the spatial attention conv
            cbam_status (str)     -- one of "cbam", "spatial", "channel"
        """
        super(CBAM, self).__init__()
        self.n_channels_in = n_channels_in
        self.reduction_ratio = reduction_ratio
        self.kernel_size = kernel_size
        self.channel_attention = ChannelAttention_nopara(n_channels_in, reduction_ratio)
        self.spatial_attention = SpatialAttention_nopara(kernel_size)
        self.status = cbam_status

    def forward(self, x):
        """Apply the selected attention branch(es) and return the gated features."""
        ## We don't use cbam in this version
        if self.status == "cbam":
            chan_att = self.channel_attention(x)
            fp = chan_att * x
            spat_att = self.spatial_attention(fp)
            fpp = spat_att * fp
        elif self.status == "spatial":
            spat_att = self.spatial_attention(x)  # * s_para_1d
            fpp = spat_att * x
        elif self.status == "channel":
            chan_att = self.channel_attention(x)  # * c_para_1d
            fpp = chan_att * x
        else:
            # BUG FIX: previously an unrecognized status fell through every
            # branch and crashed with UnboundLocalError on `fpp`; fail fast
            # with an explicit, diagnosable error instead.
            raise ValueError(f"Unknown cbam_status: {self.status!r}")

        return fpp  # ,c_wgt,s_wgt
|
||||
|
||||
|
||||
class SpatialAttention_nopara(nn.Module):
    """Spatial attention: a sigmoid gate map built from per-pixel channel statistics."""

    def __init__(self, kernel_size):
        """kernel_size must be odd so that padding preserves spatial size."""
        super(SpatialAttention_nopara, self).__init__()
        self.kernel_size = kernel_size
        assert kernel_size % 2 == 1, "Odd kernel size required"
        # 2-in / 1-out conv mixes the max- and avg-pooled channel descriptors.
        self.conv = nn.Conv2d(in_channels=2, out_channels=1, kernel_size=kernel_size, padding=int((kernel_size - 1) / 2))

    def forward(self, x):
        """Return a (b, c, h, w) attention map in (0, 1)."""
        pooled = torch.cat([self.agg_channel(x, "max"), self.agg_channel(x, "avg")], dim=1)
        logits = self.conv(pooled)
        # Broadcast the single-channel map across every input channel.
        logits = logits.repeat(1, x.size()[1], 1, 1)
        return torch.sigmoid(logits)

    def agg_channel(self, x, pool="max"):
        """Collapse the channel dimension with max or average pooling -> (b, 1, h, w)."""
        b, c, h, w = x.size()
        # Move channels last so 1-d pooling over the channel axis is possible.
        flat = x.view(b, c, h * w).permute(0, 2, 1)
        if pool == "max":
            flat = F.max_pool1d(flat, c)
        elif pool == "avg":
            flat = F.avg_pool1d(flat, c)
        return flat.permute(0, 2, 1).view(b, 1, h, w)
|
||||
|
||||
|
||||
class ChannelAttention_nopara(nn.Module):
    """Channel attention: global pooling, shared bottleneck MLP, sigmoid channel gates."""

    def __init__(self, n_channels_in, reduction_ratio):
        """reduction_ratio controls the compression of the bottleneck MLP."""
        super(ChannelAttention_nopara, self).__init__()
        self.n_channels_in = n_channels_in
        self.reduction_ratio = reduction_ratio
        self.middle_layer_size = int(self.n_channels_in / float(self.reduction_ratio))
        # Shared two-layer MLP applied to both pooled descriptors.
        self.bottleneck = nn.Sequential(
            nn.Linear(self.n_channels_in, self.middle_layer_size),
            nn.ReLU(),
            nn.Linear(self.middle_layer_size, self.n_channels_in)
        )

    def forward(self, x):
        """Return per-channel gates of shape (b, c, 1, 1) in (0, 1)."""
        # Global avg/max pooling over the full spatial extent.
        spatial = (x.size()[2], x.size()[3])
        avg_desc = F.avg_pool2d(x, spatial)
        max_desc = F.max_pool2d(x, spatial)
        avg_desc = avg_desc.view(avg_desc.size()[0], -1)
        max_desc = max_desc.view(max_desc.size()[0], -1)
        # Same bottleneck for both descriptors, summed before the sigmoid.
        gates = torch.sigmoid(self.bottleneck(avg_desc) + self.bottleneck(max_desc))
        # Trailing singleton dims let the gates broadcast over spatial positions.
        return gates.unsqueeze(2).unsqueeze(3)
|
||||
86
app/service/image2sketch/models/perceptual.py
Normal file
86
app/service/image2sketch/models/perceptual.py
Normal file
@@ -0,0 +1,86 @@
|
||||
import torch
|
||||
import torchvision
|
||||
|
||||
class VGGPerceptualLoss(torch.nn.Module):
    """Perceptual (feature-reconstruction) loss on frozen VGG16 features.

    Compares input and target activations from four VGG16 feature slices with
    L1 distance; `style_layers` adds Gram-matrix (style) terms.
    """

    def __init__(self, resize=True):
        """Build the frozen VGG16 feature extractor.

        Parameters:
            resize (bool) -- resize inputs to 224x224 before feature extraction
        """
        super(VGGPerceptualLoss, self).__init__()
        # Four consecutive slices of VGG16's feature stack.
        blocks = []
        blocks.append(torchvision.models.vgg16(pretrained=True).features[:4].eval())
        blocks.append(torchvision.models.vgg16(pretrained=True).features[4:9].eval())
        blocks.append(torchvision.models.vgg16(pretrained=True).features[9:16].eval())
        blocks.append(torchvision.models.vgg16(pretrained=True).features[16:23].eval())
        for bl in blocks:
            # BUG FIX: the original iterated `bl` directly, which yields
            # sub-Modules; assigning `.requires_grad` on a Module is a no-op
            # attribute write, so the VGG weights were never actually frozen.
            # Freeze the parameters themselves.
            for p in bl.parameters():
                p.requires_grad = False
        self.blocks = torch.nn.ModuleList(blocks)
        self.transform = torch.nn.functional.interpolate
        # ImageNet normalization statistics.
        # NOTE(review): stored as nn.Parameter (requires_grad defaults to True),
        # so an optimizer over this module's parameters would update them;
        # buffers look intended — kept as-is to preserve the state_dict layout.
        self.mean = torch.nn.Parameter(torch.tensor([0.485, 0.456, 0.406]).view(1,3,1,1))
        self.std = torch.nn.Parameter(torch.tensor([0.229, 0.224, 0.225]).view(1,3,1,1))
        self.resize = resize

    def forward(self, input, target, feature_layers=[0, 1, 2, 3], style_layers=[]):
        """Return the summed L1 feature loss (plus Gram-matrix loss for `style_layers`)."""
        # Replicate single-channel images to the 3 channels VGG expects.
        if input.shape[1] != 3:
            input = input.repeat(1, 3, 1, 1)
            target = target.repeat(1, 3, 1, 1)
        input = (input-self.mean) / self.std
        target = (target-self.mean) / self.std
        if self.resize:
            input = self.transform(input, mode='bilinear', size=(224, 224), align_corners=False)
            target = self.transform(target, mode='bilinear', size=(224, 224), align_corners=False)
        loss = 0.0
        x = input
        y = target
        for i, block in enumerate(self.blocks):
            x = block(x)
            y = block(y)
            if i in feature_layers:
                loss += torch.nn.functional.l1_loss(x, y)
            if i in style_layers:
                # Gram matrices capture channel co-occurrence statistics (style).
                act_x = x.reshape(x.shape[0], x.shape[1], -1)
                act_y = y.reshape(y.shape[0], y.shape[1], -1)
                gram_x = act_x @ act_x.permute(0, 2, 1)
                gram_y = act_y @ act_y.permute(0, 2, 1)
                loss += torch.nn.functional.l1_loss(gram_x, gram_y)
        return loss
|
||||
|
||||
class VGGstyleLoss(torch.nn.Module):
    """VGG16-based loss; NOTE(review): this class is an exact duplicate of
    VGGPerceptualLoss (same slices, same forward) — consider delegating to it.
    Kept as a separate class to preserve the public interface.
    """

    def __init__(self, resize=True):
        """Build the frozen VGG16 feature extractor.

        Parameters:
            resize (bool) -- resize inputs to 224x224 before feature extraction
        """
        super(VGGstyleLoss, self).__init__()
        blocks = []
        blocks.append(torchvision.models.vgg16(pretrained=True).features[:4].eval())
        blocks.append(torchvision.models.vgg16(pretrained=True).features[4:9].eval())
        blocks.append(torchvision.models.vgg16(pretrained=True).features[9:16].eval())
        blocks.append(torchvision.models.vgg16(pretrained=True).features[16:23].eval())
        for bl in blocks:
            # BUG FIX: same no-op freeze as VGGPerceptualLoss — iterate the
            # parameters, not the sub-Modules, so the weights are really frozen.
            for p in bl.parameters():
                p.requires_grad = False
        self.blocks = torch.nn.ModuleList(blocks)
        self.transform = torch.nn.functional.interpolate
        # ImageNet normalization statistics (see note in VGGPerceptualLoss re: Parameter vs buffer).
        self.mean = torch.nn.Parameter(torch.tensor([0.485, 0.456, 0.406]).view(1,3,1,1))
        self.std = torch.nn.Parameter(torch.tensor([0.229, 0.224, 0.225]).view(1,3,1,1))
        self.resize = resize

    def forward(self, input, target, feature_layers=[0,1,2,3], style_layers=[]):
        """Return the summed L1 feature loss (plus Gram-matrix loss for `style_layers`)."""
        if input.shape[1] != 3:
            input = input.repeat(1, 3, 1, 1)
            target = target.repeat(1, 3, 1, 1)
        input = (input-self.mean) / self.std
        target = (target-self.mean) / self.std
        if self.resize:
            input = self.transform(input, mode='bilinear', size=(224, 224), align_corners=False)
            target = self.transform(target, mode='bilinear', size=(224, 224), align_corners=False)
        loss = 0.0
        x = input
        y = target
        for i, block in enumerate(self.blocks):
            x = block(x)
            y = block(y)
            if i in feature_layers:
                loss += torch.nn.functional.l1_loss(x, y)
            if i in style_layers:
                act_x = x.reshape(x.shape[0], x.shape[1], -1)
                act_y = y.reshape(y.shape[0], y.shape[1], -1)
                gram_x = act_x @ act_x.permute(0, 2, 1)
                gram_y = act_y @ act_y.permute(0, 2, 1)
                loss += torch.nn.functional.l1_loss(gram_x, gram_y)
        return loss
|
||||
82
app/service/image2sketch/models/template_model.py
Normal file
82
app/service/image2sketch/models/template_model.py
Normal file
@@ -0,0 +1,82 @@
|
||||
import torch
|
||||
from .base_model import BaseModel
|
||||
from . import networks
|
||||
|
||||
|
||||
class TemplateModel(BaseModel):
    """Template model: L1 regression from data_A to data_B through netG.

    Illustrates the BaseModel interface (option handling, set_input, forward,
    backward, optimize_parameters) for authors of new model classes.
    """

    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        """Add new model-specific options and rewrite default values for existing options.

        Parameters:
            parser   -- the option parser
            is_train -- if it is training phase or test phase. You can use this flag to add training-specific or test-specific options.

        Returns:
            the modified parser.
        """
        parser.set_defaults(dataset_mode='aligned')  # You can rewrite default values for this model. For example, this model usually uses aligned dataset as its dataset.
        if is_train:
            parser.add_argument('--lambda_regression', type=float, default=1.0, help='weight for the regression loss')  # You can define new arguments for this model.

        return parser

    def __init__(self, opt):
        """Initialize this model class.

        Parameters:
            opt -- training/test options

        A few things can be done here.
        - (required) call the initialization function of BaseModel
        - define loss function, visualization images, model names, and optimizers
        """
        BaseModel.__init__(self, opt)  # call the initialization method of BaseModel
        # specify the training losses you want to print out. The program will call base_model.get_current_losses to plot the losses to the console and save them to the disk.
        # NOTE(review): BaseModel conventionally reads attributes named 'loss_' + name,
        # which would resolve to 'loss_loss_G' here — confirm whether this should be ['G'].
        self.loss_names = ['loss_G']
        # specify the images you want to save and display. The program will call base_model.get_current_visuals to save and display these images.
        self.visual_names = ['data_A', 'data_B', 'output']
        # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks to save and load networks.
        # you can use opt.isTrain to specify different behaviors for training and test. For example, some networks will not be used during test, and you don't need to load them.
        self.model_names = ['G']
        # define networks; you can use opt.isTrain to specify different behaviors for training and test.
        self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, gpu_ids=self.gpu_ids)
        if self.isTrain:  # only defined during training time
            # define your loss functions. You can use losses provided by torch.nn such as torch.nn.L1Loss.
            # We also provide a GANLoss class "networks.GANLoss". self.criterionGAN = networks.GANLoss().to(self.device)
            self.criterionLoss = torch.nn.L1Loss()
            # define and initialize optimizers. You can define one optimizer for each network.
            # If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
            self.optimizer = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizers = [self.optimizer]

        # Our program will automatically call <model.setup> to define schedulers, load networks, and print networks

    def set_input(self, input):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.

        Parameters:
            input: a dictionary that contains the data itself and its metadata information.
        """
        AtoB = self.opt.direction == 'AtoB'  # use <direction> to swap data_A and data_B
        self.data_A = input['A' if AtoB else 'B'].to(self.device)  # get image data A
        self.data_B = input['B' if AtoB else 'A'].to(self.device)  # get image data B
        self.image_paths = input['A_paths' if AtoB else 'B_paths']  # get image paths

    def forward(self):
        """Run forward pass. This will be called by both functions <optimize_parameters> and <test>."""
        self.output = self.netG(self.data_A)  # generate output image given the input data_A

    def backward(self):
        """Calculate losses, gradients, and update network weights; called in every training iteration"""
        # caculate the intermediate results if necessary; here self.output has been computed during function <forward>
        # calculate loss given the input and intermediate results
        self.loss_G = self.criterionLoss(self.output, self.data_B) * self.opt.lambda_regression
        self.loss_G.backward()  # calculate gradients of network G w.r.t. loss_G

    def optimize_parameters(self):
        """Update network weights; it will be called in every training iteration."""
        self.forward()  # first call forward to calculate intermediate results
        self.optimizer.zero_grad()  # clear network G's existing gradients
        self.backward()  # calculate gradients for network G
        self.optimizer.step()  # update gradients for network G
|
||||
45
app/service/image2sketch/models/test_model.py
Normal file
45
app/service/image2sketch/models/test_model.py
Normal file
@@ -0,0 +1,45 @@
|
||||
from .base_model import BaseModel
|
||||
from . import networks
|
||||
|
||||
|
||||
class TestModel(BaseModel):
    """ This TestModel can be used to generate CycleGAN results for only one direction.
    This model will automatically set '--dataset_mode single', which only loads the images from one collection.

    See the test instruction for more details.
    """
    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        """Set test-only defaults; adds --model_suffix for selecting the generator checkpoint."""
        assert not is_train, 'TestModel cannot be used during training time'
        parser.set_defaults(dataset_mode='single')
        parser.add_argument('--model_suffix', type=str, default='', help='In checkpoints_dir, [epoch]_net_G[model_suffix].pth will be loaded as the generator.')

        return parser

    def __init__(self, opt):
        """Initialize the test-time model; only the generator is created and loaded."""
        assert(not opt.isTrain)
        BaseModel.__init__(self, opt)
        # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
        self.loss_names = []
        # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
        self.visual_names = ['real', 'fake']
        # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>
        self.model_names = ['G' + opt.model_suffix]  # only generator is needed.
        self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG,
                                      opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)

        # assigns the model to self.netG_[suffix] so that it can be loaded
        # please see <BaseModel.load_networks>
        setattr(self, 'netG' + opt.model_suffix, self.netG)  # store netG in self.

    def set_input(self, input):
        """Store the single-direction input image and its path."""
        self.real = input['A'].to(self.device)
        self.image_paths = input['A_paths']

    def forward(self):
        """Run forward pass."""
        self.fake = self.netG(self.real)  # G(real)

    def optimize_parameters(self):
        """No optimization for test model."""
        pass
|
||||
68
app/service/image2sketch/models/triplet_model.py
Normal file
68
app/service/image2sketch/models/triplet_model.py
Normal file
@@ -0,0 +1,68 @@
|
||||
import torch
|
||||
from .base_model import BaseModel
|
||||
from . import networks
|
||||
from util.image_pool import ImagePool
|
||||
|
||||
|
||||
class TripletModel(BaseModel):
    """Trains a generator with a triplet margin loss on (anchor, positive, negative) embeddings."""

    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        """Set triplet-specific defaults; add --lambda_L1 when training."""
        parser.set_defaults(norm='batch', netG='triplet', dataset_mode='triplet')
        if is_train:
            parser.set_defaults(pool_size=0, gan_mode='vanilla')
            parser.add_argument('--lambda_L1', type=float, default=100.0, help='weight for L1 loss')

        return parser

    def __init__(self, opt):

        BaseModel.__init__(self, opt)

        self.loss_names = ['G_triplet']
        self.visual_names = ['x','y']

        # NOTE(review): both branches assign the same value — the if/else is redundant.
        if self.isTrain:
            self.model_names = ['G']
        else:
            self.model_names = ['G']
        # Generator maps 1-channel inputs to 1-channel embeddings.
        self.netG = networks.define_G(1, 1, opt.ngf, opt.netG, opt.norm,
                                      not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)

        if self.isTrain:
            # NOTE(review): these pools and criterionGAN/criterionL1 are created
            # but never referenced below — possibly dead code; confirm.
            self.fake_A_pool = ImagePool(opt.pool_size)  # create image buffer to store previously generated images
            self.fake_B_pool = ImagePool(opt.pool_size)  # create image buffer to store previously generated images

            self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)
            self.criterionL1 = torch.nn.L1Loss()

            # Triplet margin loss over the three embeddings produced by netG.
            self.triplet = torch.nn.TripletMarginLoss(margin=3.0)
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizers.append(self.optimizer_G)

    def set_input(self, input):
        """Unpack anchor (A), positive (B) and negative (C) images from the dataloader."""
        AtoB = self.opt.direction == 'AtoB'
        self.real_A = input['A' if AtoB else 'B'].to(self.device)
        self.real_B = input['B' if AtoB else 'A'].to(self.device)
        self.real_C = input['C'].to(self.device)

        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def forward(self):
        # netG embeds the three inputs jointly into (x, y, z).
        self.x,self.y,self.z = self.netG(self.real_A,self.real_B,self.real_C)

    def backward_G(self):
        """Compute the triplet loss and backpropagate through G."""
        self.loss_G_triplet_1 = self.triplet(self.x,self.y,self.z)
        self.loss_G_triplet = self.loss_G_triplet_1

        self.loss_G = self.loss_G_triplet
        self.loss_G.backward()

    def optimize_parameters(self):
        """One optimizer step: zero grads, backward, step.

        NOTE(review): unlike sibling models, forward() is not called here —
        assumes the caller ran forward first; confirm against the training loop.
        """
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
|
||||
144
app/service/image2sketch/models/unpaired_model.py
Normal file
144
app/service/image2sketch/models/unpaired_model.py
Normal file
@@ -0,0 +1,144 @@
|
||||
import torch
|
||||
|
||||
from . import networks
|
||||
from .base_model import BaseModel
|
||||
from .perceptual import VGGPerceptualLoss
|
||||
from ..util.image_pool import ImagePool
|
||||
|
||||
|
||||
class UnpairedModel(BaseModel):
    """Unpaired image->sketch model.

    netG_A stylizes a content image using a style reference; netG_B
    reconstructs the content from the stylized output. Training combines a
    patch-discriminator GAN loss, a contrastive style term (styletps), a
    perceptual reconstruction term, and an HED edge-consistency term.
    """

    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        """Set unpaired-training defaults; add --lambda_L1 when training."""
        parser.set_defaults(norm='batch', netG='ref_unpair_cbam_cat', netG2='ref_unpair_recon', dataset_mode='unaligned')
        if is_train:
            parser.set_defaults(pool_size=0, gan_mode='vanilla')
            parser.add_argument('--lambda_L1', type=float, default=100.0, help='weight for L1 loss')

        return parser

    def __init__(self, opt):
        BaseModel.__init__(self, opt)
        # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
        self.loss_names = ['G_GAN', 'G_L1_1', 'G_Rec', 'G_line', 'D_real', 'D_fake']
        self.visual_names = ['real_A', 'content_output', 'real_B']

        if self.isTrain:
            self.model_names = ['G_A', 'G_B', 'D']
        else:  # during test time, only load G
            self.model_names = ['G_A', 'G_B']
        # define networks (both generator and discriminator)
        self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm,
                                        not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
        self.netG_B = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG2, opt.norm,
                                        not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)

        if self.isTrain:  # define a discriminator; conditional GANs need to take both input and output images; Therefore, #channels for D is input_nc + output_nc
            # netD scores the 1-channel sketch output.
            self.netD = networks.define_D(1, opt.ndf, opt.netD,
                                          opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
            # Pretrained contrastive style encoder and HED edge detector used only as loss networks.
            self.styletps = networks.define_styletps(init_weights_='./checkpoints/contrastive_pretrained.pth', gpu_ids_=self.gpu_ids, shape=False)
            self.HED = networks.define_HED(init_weights_='./checkpoints/network-bsds500.pytorch', gpu_ids_=self.gpu_ids)

        if self.isTrain:  # define discriminators
            # NOTE(review): netD_A/netD_B are created but have no optimizer and are
            # not used by backward_D/backward_G below — possibly dead code; confirm.
            self.netD_A = networks.define_D(opt.output_nc, opt.ndf, opt.netD,
                                            opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
            self.netD_B = networks.define_D(opt.input_nc, opt.ndf, opt.netD,
                                            opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)

        if self.isTrain:
            self.fake_A_pool = ImagePool(opt.pool_size)  # create image buffer to store previously generated images
            self.fake_B_pool = ImagePool(opt.pool_size)  # create image buffer to store previously generated images
            # define loss functions
            self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)
            self.criterionL1_1 = torch.nn.L1Loss()
            self.criterionL1_2 = torch.nn.L1Loss()
            self.criterionL1_3 = torch.nn.L1Loss()
            self.per_loss_1 = VGGPerceptualLoss().to(self.device)
            self.per_loss_2 = VGGPerceptualLoss().to(self.device)
            self.per_loss_3 = VGGPerceptualLoss().to(self.device)

            self.optimizer_GA = torch.optim.Adam(self.netG_A.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_GB = torch.optim.Adam(self.netG_B.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))

            self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizers.append(self.optimizer_GA)
            self.optimizers.append(self.optimizer_GB)

            self.optimizers.append(self.optimizer_D)

    def set_input(self, input):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.

        Parameters:
            input (dict): include the data itself and its metadata information.

        The option 'direction' can be used to swap images in domain A and domain B.
        """
        AtoB = self.opt.direction == 'AtoB'
        self.real_A = input['A' if AtoB else 'B'].to(self.device)
        self.real_B = input['B' if AtoB else 'A'].to(self.device)
        # self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def forward(self):
        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
        # Stylize A with reference B, then reconstruct from the stylized output.
        self.content_output = self.netG_A(self.real_A, self.real_B)
        self.rec_output = self.netG_B(self.content_output, self.content_output)

    def update_process(self, epoch, total_epoch):
        """Record training progress; used to decay loss weights in backward_G."""
        self.epoch_count = epoch
        self.epoch_count_total = total_epoch

    def backward_D(self):
        """Calculate GAN loss for the discriminator

        Parameters:
            netD (network)     -- the discriminator D
            real (tensor array) -- real images
            fake (tensor array) -- images generated by a generator

        Return the discriminator loss.
        We also call loss_D.backward() to calculate the gradients.
        """
        # Real
        pred_real = self.netD(self.real_B)
        self.loss_D_real = self.criterionGAN(pred_real, True)
        # Fake
        # detach() stops D's gradients from flowing into the generator.
        pred_fake = self.netD(self.content_output.detach())
        self.loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss and calculate gradients
        loss_D = (self.loss_D_real + self.loss_D_fake) * 0.5
        loss_D.backward()
        return loss_D

    def backward_G(self):
        """Calculate GAN and L1 loss for the generator"""

        pred_fake = self.netD(self.content_output)
        self.loss_G_GAN = self.criterionGAN(pred_fake, True)

        # Edge maps of the input and the reconstruction, for line consistency.
        self.content_output_line = self.HED(self.real_A)
        self.rec_output_line = self.HED(self.rec_output)
        # Contrastive style embeddings of the output and the style reference.
        self.t1, self.t2, _ = self.styletps(self.content_output, self.real_B, self.real_B)

        # Reconstruction/line weights decay linearly from 5 to 0.5 over training.
        decay_lambda = 5 - ((self.epoch_count * 4.5) / self.epoch_count_total)
        self.loss_G_L1_1 = self.criterionL1_1(self.t1, self.t2) * 10
        self.loss_G_Rec = self.per_loss_2(self.real_A, self.rec_output) * decay_lambda
        self.loss_G_line = self.per_loss_3(self.content_output_line, self.rec_output_line) * decay_lambda

        self.loss_G = self.loss_G_GAN + self.loss_G_L1_1 + self.loss_G_Rec + self.loss_G_line
        self.loss_G.backward()

    def optimize_parameters(self):
        """One training iteration: forward, update D, then update both generators."""
        self.forward()  # compute fake images: G(A)
        # update D
        self.set_requires_grad(self.netD, True)  # enable backprop for D
        self.optimizer_D.zero_grad()  # set D's gradients to zero
        self.backward_D()  # calculate gradients for backward_D_unsuper
        self.optimizer_D.step()  # update D's weights
        # update G
        self.set_requires_grad(self.netD, False)  # D requires no gradients when optimizing G
        self.optimizer_GA.zero_grad()  # set G's gradients to zero
        self.optimizer_GB.zero_grad()  # set G's gradients to zero
        self.backward_G()  # calculate graidents for G
        self.optimizer_GA.step()  # udpate G's weights
        self.optimizer_GB.step()  # udpate G's weights
|
||||
57
app/service/image2sketch/opt.py
Normal file
57
app/service/image2sketch/opt.py
Normal file
@@ -0,0 +1,57 @@
|
||||
from app.core.config import DEBUG
|
||||
|
||||
|
||||
class Config:
    """Inference options for the image2sketch model (mirrors the CycleGAN/pix2pix CLI options)."""

    def __init__(self):
        # Basic options
        self.dataroot = "app/service/image2sketch/datasets/ref_unpair"
        self.name = 'semi_unpair'
        self.gpu_ids = [0]
        # Model options
        self.model = 'unpaired'
        self.input_nc = 3
        self.output_nc = 3
        self.ngf = 64
        self.ndf = 64
        self.netD = 'basic'
        self.netG = 'ref_unpair_cbam_cat'
        self.netG2 = 'ref_unpair_recon'
        self.n_layers_D = 3
        self.norm = 'instance'
        self.init_type = 'normal'
        self.init_gain = 0.02
        self.no_dropout = False  # corresponds to `--no_dropout`
        # Dataset options
        self.dataset_mode = 'single'
        self.direction = 'AtoB'
        self.serial_batches = True  # corresponds to `--serial_batches`
        self.num_threads = 4
        self.batch_size = 4
        self.load_size = 512
        self.crop_size = 512
        self.max_dataset_size = float("inf")
        self.preprocess = 'resize_and_crop'
        self.no_flip = False  # corresponds to `--no_flip`
        self.display_winsize = 256
        # Extra options
        self.epoch = '100'
        self.load_iter = 0
        self.verbose = False  # corresponds to `--verbose`
        self.suffix = ''
        self.isTrain = False
        self.results_dir = 'service/image2sketch/results'
        self.aspect_ratio = 1.0
        self.phase = 'test'
        self.eval = False
        self.num_test = 1000
        # NOTE(review): 'morm' looks like a typo for 'norm'; as written this attribute
        # is never read and self.norm stays 'instance' — confirm which normalization
        # was intended before "fixing" it, since changing norm changes the model.
        self.morm = 'batch'
        # Paths differ between local debug runs and the deployed app layout.
        if DEBUG:
            self.style_image1 = "service/image2sketch/datasets/ref_unpair/testC/style_1.jpg"
            self.style_image2 = "service/image2sketch/datasets/ref_unpair/testC/style_2.jpeg"
            self.style_image3 = "service/image2sketch/datasets/ref_unpair/testC/style_3.png"
            self.checkpoints_dir = 'service/image2sketch/checkpoints/'
        else:
            self.checkpoints_dir = 'app/service/image2sketch/checkpoints/'
            self.style_image1 = "app/service/image2sketch/datasets/ref_unpair/testC/style_1.jpg"
            self.style_image2 = "app/service/image2sketch/datasets/ref_unpair/testC/style_2.jpeg"
            self.style_image3 = "app/service/image2sketch/datasets/ref_unpair/testC/style_3.png"
|
||||
88
app/service/image2sketch/server.py
Normal file
88
app/service/image2sketch/server.py
Normal file
@@ -0,0 +1,88 @@
|
||||
import logging
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchvision.transforms as transforms
|
||||
from PIL import Image
|
||||
|
||||
from app.schemas.image2sketch import Image2SketchModel
|
||||
from app.service.image2sketch.infer import tensor2im
|
||||
from app.service.image2sketch.models import create_model
|
||||
from app.service.image2sketch.opt import Config
|
||||
from app.service.utils.oss_client import oss_get_image, oss_upload_image
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
def tensor2im(input_image, imtype=np.uint8):
    """Convert a batched image Tensor in [-1, 1] to a numpy HWC image array.

    Parameters:
        input_image -- a torch.Tensor (batched, CHW) or a numpy ndarray
        imtype      -- desired numpy dtype of the result

    Takes the first item of the batch; grayscale tensors are replicated to 3
    channels. An ndarray is only cast to `imtype`; anything else is returned
    unchanged.
    """
    if isinstance(input_image, np.ndarray):
        # Already a numpy array: nothing to do but cast.
        return input_image.astype(imtype)
    if not isinstance(input_image, torch.Tensor):
        # Not an image at all — pass through untouched.
        return input_image
    arr = input_image.data[0].cpu().float().numpy()
    if arr.shape[0] == 1:  # grayscale to RGB
        arr = np.tile(arr, (3, 1, 1))
    # CHW -> HWC and rescale from [-1, 1] to [0, 255].
    arr = (np.transpose(arr, (1, 2, 0)) + 1) / 2.0 * 255.0
    return arr.astype(imtype)
|
||||
|
||||
|
||||
class Image2SketchServer:
    """Runs the image->sketch model on an OSS-hosted image and uploads the result.

    Loads a style reference (one of three bundled defaults or a user-supplied
    OSS image), fetches the content image, and performs single-image inference.
    """

    def __init__(self, request_data):
        self.image_url = request_data.image_url
        self.style_image_url = request_data.style_image_url
        self.sketch_bucket = request_data.sketch_bucket
        self.sketch_name = request_data.sketch_name
        self.opt = Config()
        self.opt.num_threads = 0  # test code only supports num_threads = 0
        self.opt.batch_size = 1  # test code only supports batch_size = 1
        self.opt.serial_batches = True  # disable data shuffling; comment this line if results on randomly chosen images are needed.
        self.opt.no_flip = True  # no flip; comment this line if results on flipped images are needed.
        self.opt.display_id = -1  # no visdom display; the test code saves the results to a HTML file.
        self.data = {}
        # NOTE(review): device is hard-coded to cuda:0, so this fails on
        # CPU-only hosts — confirm the deployment always provides a GPU.
        device = torch.device("cuda:0")
        self.model = create_model(self.opt)
        self.model.setup(self.opt)
        # Normalize images to [-1, 1] as the generator expects.
        transform_list = [transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
        transform = transforms.Compose(transform_list)
        # Pick a bundled style reference, or download the user-supplied one from OSS.
        if request_data.default_style == "1":
            style_img = Image.open(self.opt.style_image1).convert('L')
        elif request_data.default_style == "2":
            style_img = Image.open(self.opt.style_image2).convert('L')
        elif request_data.default_style == "3":
            style_img = Image.open(self.opt.style_image3).convert('L')
        else:
            # style_image_url is "bucket/key"; split at the first '/'.
            style_img = oss_get_image(bucket=self.style_image_url.split('/')[0], object_name=self.style_image_url[self.style_image_url.find('/') + 1:], data_type="PIL")
            style_img = style_img.convert('L')
        style_img = transform(style_img)
        self.data['B'] = style_img
        self.data['B'] = self.data['B'].unsqueeze(0).to(device)
        A, self.width, self.height = self.get_image(self.image_url)
        self.data['A'] = transform(A)
        self.data['A'] = self.data['A'].unsqueeze(0).to(device)

    def get_result(self):
        """Run inference and upload the sketch; returns 'bucket/object' of the upload."""
        self.model.set_input(self.data)
        self.model.test()  # run inference
        visuals = self.model.get_current_visuals()  # get image results
        image_numpy = tensor2im(visuals['content_output'].cpu())
        # Encode as JPEG and push to OSS.
        image_bytes = cv2.imencode(".jpg", image_numpy)[1].tobytes()
        req = oss_upload_image(bucket=self.sketch_bucket, object_name=self.sketch_name, image_bytes=image_bytes)
        return f"{req.bucket_name}/{req.object_name}"

    def get_image(self, image_url):
        """Fetch an OSS image given "bucket/key", convert to RGB; returns (image, width, height)."""
        image = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url[image_url.find('/') + 1:], data_type="PIL")
        image = image.convert('RGB')
        width = image.size[0]
        height = image.size[1]
        return image, width, height
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
data = Image2SketchModel(image_url="test/real_Dress_790b2c6e370644e134df7abdfe7e54d9.jpg_Img.jpg", sketch_bucket="test", sketch_name="test123.jpg")
|
||||
server = Image2SketchServer(data)
|
||||
sketch_url = server.get_result()
|
||||
print(sketch_url)
|
||||
1
app/service/image2sketch/util/__init__.py
Normal file
1
app/service/image2sketch/util/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""This package includes a miscellaneous collection of useful helper functions."""
|
||||
110
app/service/image2sketch/util/get_data.py
Normal file
110
app/service/image2sketch/util/get_data.py
Normal file
@@ -0,0 +1,110 @@
|
||||
from __future__ import print_function
|
||||
import os
|
||||
import tarfile
|
||||
import requests
|
||||
from warnings import warn
|
||||
from zipfile import ZipFile
|
||||
from bs4 import BeautifulSoup
|
||||
from os.path import abspath, isdir, join, basename
|
||||
|
||||
|
||||
class GetData(object):
|
||||
"""A Python script for downloading CycleGAN or pix2pix datasets.
|
||||
|
||||
Parameters:
|
||||
technique (str) -- One of: 'cyclegan' or 'pix2pix'.
|
||||
verbose (bool) -- If True, print additional information.
|
||||
|
||||
Examples:
|
||||
>>> from util.get_data import GetData
|
||||
>>> gd = GetData(technique='cyclegan')
|
||||
>>> new_data_path = gd.get(save_path='./datasets') # options will be displayed.
|
||||
|
||||
Alternatively, You can use bash scripts: 'scripts/download_pix2pix_model.sh'
|
||||
and 'scripts/download_cyclegan_model.sh'.
|
||||
"""
|
||||
|
||||
def __init__(self, technique='cyclegan', verbose=True):
|
||||
url_dict = {
|
||||
'pix2pix': 'http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/',
|
||||
'cyclegan': 'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets'
|
||||
}
|
||||
self.url = url_dict.get(technique.lower())
|
||||
self._verbose = verbose
|
||||
|
||||
def _print(self, text):
|
||||
if self._verbose:
|
||||
print(text)
|
||||
|
||||
@staticmethod
|
||||
def _get_options(r):
|
||||
soup = BeautifulSoup(r.text, 'lxml')
|
||||
options = [h.text for h in soup.find_all('a', href=True)
|
||||
if h.text.endswith(('.zip', 'tar.gz'))]
|
||||
return options
|
||||
|
||||
def _present_options(self):
|
||||
r = requests.get(self.url)
|
||||
options = self._get_options(r)
|
||||
print('Options:\n')
|
||||
for i, o in enumerate(options):
|
||||
print("{0}: {1}".format(i, o))
|
||||
choice = input("\nPlease enter the number of the "
|
||||
"dataset above you wish to download:")
|
||||
return options[int(choice)]
|
||||
|
||||
def _download_data(self, dataset_url, save_path):
|
||||
if not isdir(save_path):
|
||||
os.makedirs(save_path)
|
||||
|
||||
base = basename(dataset_url)
|
||||
temp_save_path = join(save_path, base)
|
||||
|
||||
with open(temp_save_path, "wb") as f:
|
||||
r = requests.get(dataset_url)
|
||||
f.write(r.content)
|
||||
|
||||
if base.endswith('.tar.gz'):
|
||||
obj = tarfile.open(temp_save_path)
|
||||
elif base.endswith('.zip'):
|
||||
obj = ZipFile(temp_save_path, 'r')
|
||||
else:
|
||||
raise ValueError("Unknown File Type: {0}.".format(base))
|
||||
|
||||
self._print("Unpacking Data...")
|
||||
obj.extractall(save_path)
|
||||
obj.close()
|
||||
os.remove(temp_save_path)
|
||||
|
||||
def get(self, save_path, dataset=None):
|
||||
"""
|
||||
|
||||
Download a dataset.
|
||||
|
||||
Parameters:
|
||||
save_path (str) -- A directory to save the data to.
|
||||
dataset (str) -- (optional). A specific dataset to download.
|
||||
Note: this must include the file extension.
|
||||
If None, options will be presented for you
|
||||
to choose from.
|
||||
|
||||
Returns:
|
||||
save_path_full (str) -- the absolute path to the downloaded data.
|
||||
|
||||
"""
|
||||
if dataset is None:
|
||||
selected_dataset = self._present_options()
|
||||
else:
|
||||
selected_dataset = dataset
|
||||
|
||||
save_path_full = join(save_path, selected_dataset.split('.')[0])
|
||||
|
||||
if isdir(save_path_full):
|
||||
warn("\n'{0}' already exists. Voiding Download.".format(
|
||||
save_path_full))
|
||||
else:
|
||||
self._print('Downloading Data...')
|
||||
url = "{0}/{1}".format(self.url, selected_dataset)
|
||||
self._download_data(url, save_path=save_path)
|
||||
|
||||
return abspath(save_path_full)
|
||||
86
app/service/image2sketch/util/html.py
Normal file
86
app/service/image2sketch/util/html.py
Normal file
@@ -0,0 +1,86 @@
|
||||
import dominate
|
||||
from dominate.tags import meta, h3, table, tr, td, p, a, img, br
|
||||
import os
|
||||
|
||||
|
||||
class HTML:
|
||||
"""This HTML class allows us to save images and write texts into a single HTML file.
|
||||
|
||||
It consists of functions such as <add_header> (add a text header to the HTML file),
|
||||
<add_images> (add a row of images to the HTML file), and <save> (save the HTML to the disk).
|
||||
It is based on Python library 'dominate', a Python library for creating and manipulating HTML documents using a DOM API.
|
||||
"""
|
||||
|
||||
def __init__(self, web_dir, title, refresh=0):
|
||||
"""Initialize the HTML classes
|
||||
|
||||
Parameters:
|
||||
web_dir (str) -- a directory that stores the webpage. HTML file will be created at <web_dir>/index.html; images will be saved at <web_dir/images/
|
||||
title (str) -- the webpage name
|
||||
refresh (int) -- how often the website refresh itself; if 0; no refreshing
|
||||
"""
|
||||
self.title = title
|
||||
self.web_dir = web_dir
|
||||
self.img_dir = os.path.join(self.web_dir, 'images')
|
||||
if not os.path.exists(self.web_dir):
|
||||
os.makedirs(self.web_dir)
|
||||
if not os.path.exists(self.img_dir):
|
||||
os.makedirs(self.img_dir)
|
||||
|
||||
self.doc = dominate.document(title=title)
|
||||
if refresh > 0:
|
||||
with self.doc.head:
|
||||
meta(http_equiv="refresh", content=str(refresh))
|
||||
|
||||
def get_image_dir(self):
|
||||
"""Return the directory that stores images"""
|
||||
return self.img_dir
|
||||
|
||||
def add_header(self, text):
|
||||
"""Insert a header to the HTML file
|
||||
|
||||
Parameters:
|
||||
text (str) -- the header text
|
||||
"""
|
||||
with self.doc:
|
||||
h3(text)
|
||||
|
||||
def add_images(self, ims, txts, links, width=400):
|
||||
"""add images to the HTML file
|
||||
|
||||
Parameters:
|
||||
ims (str list) -- a list of image paths
|
||||
txts (str list) -- a list of image names shown on the website
|
||||
links (str list) -- a list of hyperref links; when you click an image, it will redirect you to a new page
|
||||
"""
|
||||
self.t = table(border=1, style="table-layout: fixed;") # Insert a table
|
||||
self.doc.add(self.t)
|
||||
with self.t:
|
||||
with tr():
|
||||
for im, txt, link in zip(ims, txts, links):
|
||||
with td(style="word-wrap: break-word;", halign="center", valign="top"):
|
||||
with p():
|
||||
with a(href=os.path.join('images', link)):
|
||||
img(style="width:%dpx" % width, src=os.path.join('images', im))
|
||||
br()
|
||||
p(txt)
|
||||
|
||||
def save(self):
|
||||
"""save the current content to the HMTL file"""
|
||||
html_file = '%s/index.html' % self.web_dir
|
||||
f = open(html_file, 'wt')
|
||||
f.write(self.doc.render())
|
||||
f.close()
|
||||
|
||||
|
||||
if __name__ == '__main__': # we show an example usage here.
|
||||
html = HTML('web/', 'test_html')
|
||||
html.add_header('hello world')
|
||||
|
||||
ims, txts, links = [], [], []
|
||||
for n in range(4):
|
||||
ims.append('image_%d.png' % n)
|
||||
txts.append('text_%d' % n)
|
||||
links.append('image_%d.png' % n)
|
||||
html.add_images(ims, txts, links)
|
||||
html.save()
|
||||
54
app/service/image2sketch/util/image_pool.py
Normal file
54
app/service/image2sketch/util/image_pool.py
Normal file
@@ -0,0 +1,54 @@
|
||||
import random
|
||||
import torch
|
||||
|
||||
|
||||
class ImagePool():
|
||||
"""This class implements an image buffer that stores previously generated images.
|
||||
|
||||
This buffer enables us to update discriminators using a history of generated images
|
||||
rather than the ones produced by the latest generators.
|
||||
"""
|
||||
|
||||
def __init__(self, pool_size):
|
||||
"""Initialize the ImagePool class
|
||||
|
||||
Parameters:
|
||||
pool_size (int) -- the size of image buffer, if pool_size=0, no buffer will be created
|
||||
"""
|
||||
self.pool_size = pool_size
|
||||
if self.pool_size > 0: # create an empty pool
|
||||
self.num_imgs = 0
|
||||
self.images = []
|
||||
|
||||
def query(self, images):
|
||||
"""Return an image from the pool.
|
||||
|
||||
Parameters:
|
||||
images: the latest generated images from the generator
|
||||
|
||||
Returns images from the buffer.
|
||||
|
||||
By 50/100, the buffer will return input images.
|
||||
By 50/100, the buffer will return images previously stored in the buffer,
|
||||
and insert the current images to the buffer.
|
||||
"""
|
||||
if self.pool_size == 0: # if the buffer size is 0, do nothing
|
||||
return images
|
||||
return_images = []
|
||||
for image in images:
|
||||
image = torch.unsqueeze(image.data, 0)
|
||||
if self.num_imgs < self.pool_size: # if the buffer is not full; keep inserting current images to the buffer
|
||||
self.num_imgs = self.num_imgs + 1
|
||||
self.images.append(image)
|
||||
return_images.append(image)
|
||||
else:
|
||||
p = random.uniform(0, 1)
|
||||
if p > 0.5: # by 50% chance, the buffer will return a previously stored image, and insert the current image into the buffer
|
||||
random_id = random.randint(0, self.pool_size - 1) # randint is inclusive
|
||||
tmp = self.images[random_id].clone()
|
||||
self.images[random_id] = image
|
||||
return_images.append(tmp)
|
||||
else: # by another 50% chance, the buffer will return the current image
|
||||
return_images.append(image)
|
||||
return_images = torch.cat(return_images, 0) # collect all the images and return
|
||||
return return_images
|
||||
103
app/service/image2sketch/util/util.py
Normal file
103
app/service/image2sketch/util/util.py
Normal file
@@ -0,0 +1,103 @@
|
||||
"""This module contains simple helper functions """
|
||||
from __future__ import print_function
|
||||
import torch
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import os
|
||||
|
||||
|
||||
def tensor2im(input_image, imtype=np.uint8):
|
||||
""""Converts a Tensor array into a numpy image array.
|
||||
|
||||
Parameters:
|
||||
input_image (tensor) -- the input image tensor array
|
||||
imtype (type) -- the desired type of the converted numpy array
|
||||
"""
|
||||
if not isinstance(input_image, np.ndarray):
|
||||
if isinstance(input_image, torch.Tensor): # get the data from a variable
|
||||
image_tensor = input_image.data
|
||||
else:
|
||||
return input_image
|
||||
image_numpy = image_tensor[0].cpu().float().numpy() # convert it into a numpy array
|
||||
if image_numpy.shape[0] == 1: # grayscale to RGB
|
||||
image_numpy = np.tile(image_numpy, (3, 1, 1))
|
||||
image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 # post-processing: tranpose and scaling
|
||||
else: # if it is a numpy array, do nothing
|
||||
image_numpy = input_image
|
||||
return image_numpy.astype(imtype)
|
||||
|
||||
|
||||
def diagnose_network(net, name='network'):
|
||||
"""Calculate and print the mean of average absolute(gradients)
|
||||
|
||||
Parameters:
|
||||
net (torch network) -- Torch network
|
||||
name (str) -- the name of the network
|
||||
"""
|
||||
mean = 0.0
|
||||
count = 0
|
||||
for param in net.parameters():
|
||||
if param.grad is not None:
|
||||
mean += torch.mean(torch.abs(param.grad.data))
|
||||
count += 1
|
||||
if count > 0:
|
||||
mean = mean / count
|
||||
print(name)
|
||||
print(mean)
|
||||
|
||||
|
||||
def save_image(image_numpy, image_path, aspect_ratio=1.0):
|
||||
"""Save a numpy image to the disk
|
||||
|
||||
Parameters:
|
||||
image_numpy (numpy array) -- input numpy array
|
||||
image_path (str) -- the path of the image
|
||||
"""
|
||||
|
||||
image_pil = Image.fromarray(image_numpy)
|
||||
h, w, _ = image_numpy.shape
|
||||
|
||||
if aspect_ratio > 1.0:
|
||||
image_pil = image_pil.resize((h, int(w * aspect_ratio)), Image.BICUBIC)
|
||||
if aspect_ratio < 1.0:
|
||||
image_pil = image_pil.resize((int(h / aspect_ratio), w), Image.BICUBIC)
|
||||
image_pil.save(image_path)
|
||||
|
||||
|
||||
def print_numpy(x, val=True, shp=False):
|
||||
"""Print the mean, min, max, median, std, and size of a numpy array
|
||||
|
||||
Parameters:
|
||||
val (bool) -- if print the values of the numpy array
|
||||
shp (bool) -- if print the shape of the numpy array
|
||||
"""
|
||||
x = x.astype(np.float64)
|
||||
if shp:
|
||||
print('shape,', x.shape)
|
||||
if val:
|
||||
x = x.flatten()
|
||||
print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % (
|
||||
np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x)))
|
||||
|
||||
|
||||
def mkdirs(paths):
|
||||
"""create empty directories if they don't exist
|
||||
|
||||
Parameters:
|
||||
paths (str list) -- a list of directory paths
|
||||
"""
|
||||
if isinstance(paths, list) and not isinstance(paths, str):
|
||||
for path in paths:
|
||||
mkdir(path)
|
||||
else:
|
||||
mkdir(paths)
|
||||
|
||||
|
||||
def mkdir(path):
|
||||
"""create a single empty directory if it didn't exist
|
||||
|
||||
Parameters:
|
||||
path (str) -- a single directory path
|
||||
"""
|
||||
if not os.path.exists(path):
|
||||
os.makedirs(path)
|
||||
223
app/service/image2sketch/util/visualizer.py
Normal file
223
app/service/image2sketch/util/visualizer.py
Normal file
@@ -0,0 +1,223 @@
|
||||
import numpy as np
|
||||
import os
|
||||
import sys
|
||||
import ntpath
|
||||
import time
|
||||
from . import util, html
|
||||
from subprocess import Popen, PIPE
|
||||
|
||||
|
||||
if sys.version_info[0] == 2:
|
||||
VisdomExceptionBase = Exception
|
||||
else:
|
||||
VisdomExceptionBase = ConnectionError
|
||||
|
||||
|
||||
def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256):
|
||||
"""Save images to the disk.
|
||||
|
||||
Parameters:
|
||||
webpage (the HTML class) -- the HTML webpage class that stores these imaegs (see html.py for more details)
|
||||
visuals (OrderedDict) -- an ordered dictionary that stores (name, images (either tensor or numpy) ) pairs
|
||||
image_path (str) -- the string is used to create image paths
|
||||
aspect_ratio (float) -- the aspect ratio of saved images
|
||||
width (int) -- the images will be resized to width x width
|
||||
|
||||
This function will save images stored in 'visuals' to the HTML file specified by 'webpage'.
|
||||
"""
|
||||
image_dir = webpage.get_image_dir()
|
||||
short_path = ntpath.basename(image_path[0])
|
||||
name = os.path.splitext(short_path)[0]
|
||||
|
||||
webpage.add_header(name)
|
||||
ims, txts, links = [], [], []
|
||||
|
||||
for label, im_data in visuals.items():
|
||||
im = util.tensor2im(im_data)
|
||||
image_name = '%s_%s.png' % (name, label)
|
||||
save_path = os.path.join(image_dir, image_name)
|
||||
util.save_image(im, save_path, aspect_ratio=aspect_ratio)
|
||||
ims.append(image_name)
|
||||
txts.append(label)
|
||||
links.append(image_name)
|
||||
webpage.add_images(ims, txts, links, width=width)
|
||||
|
||||
|
||||
class Visualizer():
|
||||
"""This class includes several functions that can display/save images and print/save logging information.
|
||||
|
||||
It uses a Python library 'visdom' for display, and a Python library 'dominate' (wrapped in 'HTML') for creating HTML files with images.
|
||||
"""
|
||||
|
||||
def __init__(self, opt):
|
||||
"""Initialize the Visualizer class
|
||||
|
||||
Parameters:
|
||||
opt -- stores all the experiment flags; needs to be a subclass of BaseOptions
|
||||
Step 1: Cache the training/test options
|
||||
Step 2: connect to a visdom server
|
||||
Step 3: create an HTML object for saveing HTML filters
|
||||
Step 4: create a logging file to store training losses
|
||||
"""
|
||||
self.opt = opt # cache the option
|
||||
self.display_id = opt.display_id
|
||||
self.use_html = opt.isTrain and not opt.no_html
|
||||
self.win_size = opt.display_winsize
|
||||
self.name = opt.name
|
||||
self.port = opt.display_port
|
||||
self.saved = False
|
||||
'''
|
||||
if self.display_id > 0: # connect to a visdom server given <display_port> and <display_server>
|
||||
import visdom
|
||||
self.ncols = opt.display_ncols
|
||||
self.vis = visdom.Visdom(server=opt.display_server, port=opt.display_port, env=opt.display_env)
|
||||
if not self.vis.check_connection():
|
||||
self.create_visdom_connections()
|
||||
'''
|
||||
if self.use_html: # create an HTML object at <checkpoints_dir>/web/; images will be saved under <checkpoints_dir>/web/images/
|
||||
self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
|
||||
self.img_dir = os.path.join(self.web_dir, 'images')
|
||||
print('create web directory %s...' % self.web_dir)
|
||||
util.mkdirs([self.web_dir, self.img_dir])
|
||||
# create a logging file to store training losses
|
||||
self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
|
||||
with open(self.log_name, "a") as log_file:
|
||||
now = time.strftime("%c")
|
||||
log_file.write('================ Training Loss (%s) ================\n' % now)
|
||||
|
||||
def reset(self):
|
||||
"""Reset the self.saved status"""
|
||||
self.saved = False
|
||||
'''
|
||||
def create_visdom_connections(self):
|
||||
"""If the program could not connect to Visdom server, this function will start a new server at port < self.port > """
|
||||
cmd = sys.executable + ' -m visdom.server -p %d &>/dev/null &' % self.port
|
||||
print('\n\nCould not connect to Visdom server. \n Trying to start a server....')
|
||||
print('Command: %s' % cmd)
|
||||
Popen(cmd, shell=True, stdout=PIPE, stderr=PIPE)
|
||||
|
||||
def display_current_results(self, visuals, epoch, save_result):
|
||||
"""Display current results on visdom; save current results to an HTML file.
|
||||
|
||||
Parameters:
|
||||
visuals (OrderedDict) - - dictionary of images to display or save
|
||||
epoch (int) - - the current epoch
|
||||
save_result (bool) - - if save the current results to an HTML file
|
||||
"""
|
||||
if self.display_id > 0: # show images in the browser using visdom
|
||||
ncols = self.ncols
|
||||
if ncols > 0: # show all the images in one visdom panel
|
||||
ncols = min(ncols, len(visuals))
|
||||
h, w = next(iter(visuals.values())).shape[:2]
|
||||
table_css = """<style>
|
||||
table {border-collapse: separate; border-spacing: 4px; white-space: nowrap; text-align: center}
|
||||
table td {width: % dpx; height: % dpx; padding: 4px; outline: 4px solid black}
|
||||
</style>""" % (w, h) # create a table css
|
||||
# create a table of images.
|
||||
title = self.name
|
||||
label_html = ''
|
||||
label_html_row = ''
|
||||
images = []
|
||||
idx = 0
|
||||
for label, image in visuals.items():
|
||||
image_numpy = util.tensor2im(image)
|
||||
label_html_row += '<td>%s</td>' % label
|
||||
images.append(image_numpy.transpose([2, 0, 1]))
|
||||
idx += 1
|
||||
if idx % ncols == 0:
|
||||
label_html += '<tr>%s</tr>' % label_html_row
|
||||
label_html_row = ''
|
||||
white_image = np.ones_like(image_numpy.transpose([2, 0, 1])) * 255
|
||||
while idx % ncols != 0:
|
||||
images.append(white_image)
|
||||
label_html_row += '<td></td>'
|
||||
idx += 1
|
||||
if label_html_row != '':
|
||||
label_html += '<tr>%s</tr>' % label_html_row
|
||||
try:
|
||||
self.vis.images(images, nrow=ncols, win=self.display_id + 1,
|
||||
padding=2, opts=dict(title=title + ' images'))
|
||||
label_html = '<table>%s</table>' % label_html
|
||||
self.vis.text(table_css + label_html, win=self.display_id + 2,
|
||||
opts=dict(title=title + ' labels'))
|
||||
except VisdomExceptionBase:
|
||||
self.create_visdom_connections()
|
||||
|
||||
else: # show each image in a separate visdom panel;
|
||||
idx = 1
|
||||
try:
|
||||
for label, image in visuals.items():
|
||||
image_numpy = util.tensor2im(image)
|
||||
self.vis.image(image_numpy.transpose([2, 0, 1]), opts=dict(title=label),
|
||||
win=self.display_id + idx)
|
||||
idx += 1
|
||||
except VisdomExceptionBase:
|
||||
self.create_visdom_connections()
|
||||
|
||||
if self.use_html and (save_result or not self.saved): # save images to an HTML file if they haven't been saved.
|
||||
self.saved = True
|
||||
# save images to the disk
|
||||
for label, image in visuals.items():
|
||||
image_numpy = util.tensor2im(image)
|
||||
img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label))
|
||||
util.save_image(image_numpy, img_path)
|
||||
|
||||
# update website
|
||||
webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=1)
|
||||
for n in range(epoch, 0, -1):
|
||||
webpage.add_header('epoch [%d]' % n)
|
||||
ims, txts, links = [], [], []
|
||||
|
||||
for label, image_numpy in visuals.items():
|
||||
image_numpy = util.tensor2im(image)
|
||||
img_path = 'epoch%.3d_%s.png' % (n, label)
|
||||
ims.append(img_path)
|
||||
txts.append(label)
|
||||
links.append(img_path)
|
||||
webpage.add_images(ims, txts, links, width=self.win_size)
|
||||
webpage.save()
|
||||
'''
|
||||
def plot_current_losses(self, epoch, counter_ratio, losses):
|
||||
"""display the current losses on visdom display: dictionary of error labels and values
|
||||
|
||||
Parameters:
|
||||
epoch (int) -- current epoch
|
||||
counter_ratio (float) -- progress (percentage) in the current epoch, between 0 to 1
|
||||
losses (OrderedDict) -- training losses stored in the format of (name, float) pairs
|
||||
"""
|
||||
if not hasattr(self, 'plot_data'):
|
||||
self.plot_data = {'X': [], 'Y': [], 'legend': list(losses.keys())}
|
||||
self.plot_data['X'].append(epoch + counter_ratio)
|
||||
self.plot_data['Y'].append([losses[k] for k in self.plot_data['legend']])
|
||||
'''
|
||||
try:
|
||||
self.vis.line(
|
||||
X=np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1),
|
||||
Y=np.array(self.plot_data['Y']),
|
||||
opts={
|
||||
'title': self.name + ' loss over time',
|
||||
'legend': self.plot_data['legend'],
|
||||
'xlabel': 'epoch',
|
||||
'ylabel': 'loss'},
|
||||
win=self.display_id)
|
||||
except VisdomExceptionBase:
|
||||
self.create_visdom_connections()
|
||||
'''
|
||||
# losses: same format as |losses| of plot_current_losses
|
||||
def print_current_losses(self, epoch, iters, losses, t_comp, t_data):
|
||||
"""print current losses on console; also save the losses to the disk
|
||||
|
||||
Parameters:
|
||||
epoch (int) -- current epoch
|
||||
iters (int) -- current training iteration during this epoch (reset to 0 at the end of every epoch)
|
||||
losses (OrderedDict) -- training losses stored in the format of (name, float) pairs
|
||||
t_comp (float) -- computational time per data point (normalized by batch_size)
|
||||
t_data (float) -- data loading time per data point (normalized by batch_size)
|
||||
"""
|
||||
message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, iters, t_comp, t_data)
|
||||
for k, v in losses.items():
|
||||
message += '%s: %.3f ' % (k, v)
|
||||
|
||||
print(message) # print the message
|
||||
with open(self.log_name, "a") as log_file:
|
||||
log_file.write('%s\n' % message) # save the message
|
||||
45
app/service/image2sketch_2/download_checkpoints.py
Normal file
45
app/service/image2sketch_2/download_checkpoints.py
Normal file
@@ -0,0 +1,45 @@
|
||||
import os
|
||||
|
||||
from minio import Minio
|
||||
from minio.error import S3Error
|
||||
|
||||
MINIO_URL = "www.minio.aida.com.hk:12024"
|
||||
MINIO_ACCESS = 'vXKFLSJkYeEq2DrSZvkB'
|
||||
MINIO_SECRET = 'uKTZT3x7C43WvPN9QTc99DiRkwddWZrG9Uh3JVlR'
|
||||
MINIO_SECURE = True
|
||||
# 配置MinIO客户端
|
||||
minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||
|
||||
|
||||
# 下载函数
|
||||
def download_folder(bucket_name, folder_name, local_dir):
|
||||
try:
|
||||
# 确保本地目录存在
|
||||
if not os.path.exists(local_dir):
|
||||
os.makedirs(local_dir)
|
||||
|
||||
# 遍历MinIO中的文件
|
||||
objects = minio_client.list_objects(bucket_name, prefix=folder_name, recursive=True)
|
||||
for obj in objects:
|
||||
# 构造本地文件路径
|
||||
local_file_path = os.path.join(local_dir, obj.object_name[len(folder_name):])
|
||||
local_file_dir = os.path.dirname(local_file_path)
|
||||
|
||||
# 确保本地目录存在
|
||||
if not os.path.exists(local_file_dir):
|
||||
os.makedirs(local_file_dir)
|
||||
|
||||
# 下载文件
|
||||
minio_client.fget_object(bucket_name, obj.object_name, local_file_path)
|
||||
print(f"Downloaded {obj.object_name} to {local_file_path}")
|
||||
|
||||
except S3Error as e:
|
||||
print(f"Error occurred: {e}")
|
||||
|
||||
|
||||
# 使用示例
|
||||
bucket_name = "test" # 替换成你的bucket名称
|
||||
folder_name = "checkpoints/lineart/" # 权重文件夹的路径
|
||||
local_dir = "app/service/image2sketch_2" # 替换成你希望保存到的本地目录
|
||||
|
||||
download_folder(bucket_name, folder_name, local_dir)
|
||||
142
app/service/image2sketch_2/server.py
Normal file
142
app/service/image2sketch_2/server.py
Normal file
@@ -0,0 +1,142 @@
|
||||
import cv2
|
||||
import numpy
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision.transforms as transforms
|
||||
from PIL import Image
|
||||
|
||||
from app.service.utils.oss_client import oss_get_image, oss_upload_image
|
||||
|
||||
norm_layer = nn.InstanceNorm2d
|
||||
|
||||
weights = [(0.7, 0.3), (0.5, 0.5), (0.3, 0.7), (0.1, 0.9), (0, 1)]
|
||||
kernel = np.ones((3, 3), np.uint8)
|
||||
|
||||
|
||||
class ResidualBlock(nn.Module):
|
||||
def __init__(self, in_features):
|
||||
super(ResidualBlock, self).__init__()
|
||||
|
||||
conv_block = [nn.ReflectionPad2d(1),
|
||||
nn.Conv2d(in_features, in_features, 3),
|
||||
norm_layer(in_features),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.ReflectionPad2d(1),
|
||||
nn.Conv2d(in_features, in_features, 3),
|
||||
norm_layer(in_features)
|
||||
]
|
||||
|
||||
self.conv_block = nn.Sequential(*conv_block)
|
||||
|
||||
def forward(self, x):
|
||||
return x + self.conv_block(x)
|
||||
|
||||
|
||||
class Generator(nn.Module):
|
||||
def __init__(self, input_nc, output_nc, n_residual_blocks=9, sigmoid=True):
|
||||
super(Generator, self).__init__()
|
||||
|
||||
# Initial convolution block
|
||||
model0 = [nn.ReflectionPad2d(3),
|
||||
nn.Conv2d(input_nc, 64, 7),
|
||||
norm_layer(64),
|
||||
nn.ReLU(inplace=True)]
|
||||
self.model0 = nn.Sequential(*model0)
|
||||
|
||||
# Downsampling
|
||||
model1 = []
|
||||
in_features = 64
|
||||
out_features = in_features * 2
|
||||
for _ in range(2):
|
||||
model1 += [nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
|
||||
norm_layer(out_features),
|
||||
nn.ReLU(inplace=True)]
|
||||
in_features = out_features
|
||||
out_features = in_features * 2
|
||||
self.model1 = nn.Sequential(*model1)
|
||||
|
||||
model2 = []
|
||||
# Residual blocks
|
||||
for _ in range(n_residual_blocks):
|
||||
model2 += [ResidualBlock(in_features)]
|
||||
self.model2 = nn.Sequential(*model2)
|
||||
|
||||
# Upsampling
|
||||
model3 = []
|
||||
out_features = in_features // 2
|
||||
for _ in range(2):
|
||||
model3 += [nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),
|
||||
norm_layer(out_features),
|
||||
nn.ReLU(inplace=True)]
|
||||
in_features = out_features
|
||||
out_features = in_features // 2
|
||||
self.model3 = nn.Sequential(*model3)
|
||||
|
||||
# Output layer
|
||||
model4 = [nn.ReflectionPad2d(3),
|
||||
nn.Conv2d(64, output_nc, 7)]
|
||||
if sigmoid:
|
||||
model4 += [nn.Sigmoid()]
|
||||
|
||||
self.model4 = nn.Sequential(*model4)
|
||||
|
||||
def forward(self, x, cond=None):
|
||||
out = self.model0(x)
|
||||
out = self.model1(out)
|
||||
out = self.model2(out)
|
||||
out = self.model3(out)
|
||||
out = self.model4(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
model1 = Generator(3, 1, 3)
|
||||
model1.load_state_dict(torch.load('app/service/image2sketch_2/model.pth', map_location=torch.device('cpu')))
|
||||
model1.eval()
|
||||
|
||||
|
||||
def predict(input_img, width):
|
||||
transform = transforms.Compose([transforms.Resize(width, Image.BICUBIC), transforms.ToTensor()])
|
||||
input_img = transform(input_img)
|
||||
input_img = torch.unsqueeze(input_img, 0)
|
||||
|
||||
with torch.no_grad():
|
||||
drawing = model1(input_img)[0].detach()
|
||||
|
||||
drawing = transforms.ToPILImage()(drawing)
|
||||
|
||||
# 转ndarray
|
||||
drawing = numpy.array(drawing)
|
||||
return drawing
|
||||
|
||||
|
||||
def get_image(image_url):
|
||||
image = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url[image_url.find('/') + 1:], data_type="PIL")
|
||||
image = image.convert('RGB')
|
||||
width = image.size[0]
|
||||
height = image.size[1]
|
||||
return image, width, height
|
||||
|
||||
|
||||
def processing_pipeline(image_url, thickness, sketch_bucket, sketch_name):
|
||||
thickness = int(thickness)
|
||||
# 提取sketch
|
||||
image, width, height = get_image(image_url)
|
||||
sketch_image = predict(image, width)
|
||||
|
||||
# 设定线条粗细
|
||||
if thickness != 0:
|
||||
dilated = cv2.erode(sketch_image, kernel, iterations=1)
|
||||
# 将原图与膨胀后的图像进行混合,使用不同的权重
|
||||
sketch_image = cv2.addWeighted(sketch_image, weights[thickness][0], dilated, weights[thickness][1], 0)
|
||||
|
||||
# 上传minio
|
||||
image_bytes = cv2.imencode(".jpg", sketch_image)[1].tobytes()
|
||||
req = oss_upload_image(bucket=sketch_bucket, object_name=sketch_name, image_bytes=image_bytes)
|
||||
return f"{req.bucket_name}/{req.object_name}"
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
result_url = processing_pipeline("aida-users/89/relight_image/d5f0d967-f8e8-424d-98f9-a8ad8313deec-0-89.png", 1, "test", "test123.jpg")
|
||||
print(result_url)
|
||||
99
app/service/lineart/service.py
Normal file
99
app/service/lineart/service.py
Normal file
@@ -0,0 +1,99 @@
|
||||
import logging
|
||||
|
||||
import cv2
|
||||
import mmcv
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import tritonclient.http as httpclient
|
||||
|
||||
from app.core.config import DESIGN_MODEL_URL
|
||||
from app.schemas.image2sketch import Image2SketchModel
|
||||
from app.service.utils.oss_client import oss_get_image, oss_upload_image
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
class LineArtService:
    """Turn an OSS-hosted image into a line-art sketch via a Triton-served model.

    Workflow: download the source image, preprocess it (resize/normalize,
    NCHW pack), run the "lineart" model on the Triton inference server,
    optionally thicken the strokes, then upload the JPEG sketch back to OSS.
    """

    def __init__(self, request_item):
        # request_item carries image_url / default_style / sketch_bucket / sketch_name.
        self.line_style = int(request_item.default_style)
        self.image_url = request_item.image_url
        self.sketch_bucket = request_item.sketch_bucket
        self.sketch_name = request_item.sketch_name
        # (original, eroded) blend weights indexed by line_style; a larger index
        # gives more weight to the eroded image, i.e. thicker strokes.
        self.weights = [(0.7, 0.3), (0.5, 0.5), (0.3, 0.7), (0.1, 0.9), (0, 1)]

    def get_result(self):
        """Run the full pipeline and return the uploaded sketch's OSS path.

        Returns "<sketch_bucket>/<sketch_name>.jpg", or None if the upload failed.
        """
        client = httpclient.InferenceServerClient(url=DESIGN_MODEL_URL)
        input_image = self.get_image()
        input_img, ori_shape = self.line_art_preprocess(input_image)
        transformed_img = input_img.astype(np.float32)

        inputs = [httpclient.InferInput("input__0", transformed_img.shape, datatype="FP32")]
        inputs[0].set_data_from_numpy(transformed_img, binary_data=True)
        outputs = [httpclient.InferRequestedOutput("output__0", binary_data=True)]
        results = client.infer(model_name="lineart", inputs=inputs, outputs=outputs)
        inference_output1 = results.as_numpy("output__0")
        line_art_result = self.line_art_postprocess(inference_output1, ori_shape)

        # Model output is float in [0, 1]; convert to an 8-bit grayscale image.
        line_art_result = (line_art_result[0] * 255.0).round().astype(np.uint8)
        if self.line_style != 0:
            logger.info(self.line_style)
            kernel = np.ones((3, 3), np.uint8)
            # Erosion expands the dark strokes, thickening the lines.
            eroded = cv2.erode(line_art_result, kernel, iterations=1)
            # Blend the original with the eroded image; the chosen weight pair
            # controls how pronounced the thickening is.
            line_art_result = cv2.addWeighted(line_art_result, self.weights[self.line_style][0], eroded, self.weights[self.line_style][1], 0)
        return self.put_image(line_art_result)

    def get_image(self):
        """Download the source image from OSS and return it as a 3-channel BGR array."""
        image = oss_get_image(bucket=self.image_url.split('/')[0], object_name=self.image_url[self.image_url.find('/') + 1:], data_type="cv2")
        # Normalize BGRA / grayscale inputs to 3-channel BGR.
        if len(image.shape) == 3 and image.shape[2] == 4:
            image = cv2.cvtColor(image, cv2.COLOR_BGRA2BGR)
        elif len(image.shape) == 2:
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
        return image

    def put_image(self, image):
        """JPEG-encode and upload the sketch; return its OSS path, or None on failure.

        Upload errors are logged and swallowed (best-effort semantics).
        """
        try:
            image_bytes = cv2.imencode('.jpg', image)[1].tobytes()
            oss_upload_image(bucket=self.sketch_bucket, object_name=f"{self.sketch_name}.jpg", image_bytes=image_bytes)
            return f"{self.sketch_bucket}/{self.sketch_name}.jpg"
        except Exception as e:
            logger.warning(e)
            return None

    @staticmethod
    def line_art_preprocess(image):
        """Resize (cap each side at 1024 px), normalize, and pack the image as NCHW.

        Returns (preprocessed float array of shape (1, 3, H, W), original (h, w)).
        """
        img = mmcv.imread(image)
        ori_shape = img.shape[:2]  # (height, width)
        target_h = min(ori_shape[0], 1024)
        target_w = min(ori_shape[1], 1024)
        # Only resize when one of the sides exceeded the 1024 px cap.
        if ori_shape != (target_h, target_w):
            # cv2.resize expects (width, height).
            img = cv2.resize(img, (target_w, target_h))
        img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
        preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
        return preprocessed_img, ori_shape

    @staticmethod
    def line_art_postprocess(output, ori_shape):
        """Bilinearly resize the raw model output back to the original (h, w)."""
        seg_logit = F.interpolate(torch.tensor(output).float(), size=ori_shape, scale_factor=None, mode='bilinear', align_corners=False)
        seg_pred = seg_logit.cpu().numpy()
        return seg_pred[0]
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Manual smoke test: run the sketch pipeline against a known sample image.
    demo_request = Image2SketchModel(
        image_url="aida-collection-element/87/Sketchboard/555a443f-fd6b-4cd7-8147-b92d55513af0.png",
        default_style="4",
        sketch_bucket="test",
        sketch_name="test123",
    )
    print(LineArtService(demo_request).get_result())
|
||||
@@ -1,5 +1,5 @@
|
||||
import time
|
||||
import logging
|
||||
import time
|
||||
|
||||
|
||||
def RunTime(func):
    """Decorator that logs the wall-clock runtime of every call to *func*.

    The merge left both the old threshold-gated log (commented out) and the
    new unconditional log; only the unconditional behavior is kept here.
    """
    import functools

    @functools.wraps(func)  # preserve the wrapped function's name/docstring
    def wrapper(*args, **kwargs):
        t1 = time.time()
        res = func(*args, **kwargs)
        t2 = time.time()
        logging.info(f"function:【{func.__name__}】,runtime:【{str(t2 - t1)}】s")
        return res

    return wrapper
|
||||
|
||||
|
||||
def ClassCallRunTime(func):
    """Decorator for instance methods: prints the owning class name and call duration."""

    def wrapper(*args, **kwargs):
        started = time.time()
        outcome = func(*args, **kwargs)
        elapsed = time.time() - started
        # args[0] is the bound instance, so its class identifies the caller.
        class_name = args[0].__class__.__name__
        print(f"class name: {class_name} , run time is : {elapsed} s")
        return outcome

    return wrapper
|
||||
|
||||
94
app/service/utils/new_oss_client.py
Normal file
94
app/service/utils/new_oss_client.py
Normal file
@@ -0,0 +1,94 @@
|
||||
import io
|
||||
import logging
|
||||
from io import BytesIO
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import urllib3
|
||||
from PIL import Image
|
||||
from minio import Minio
|
||||
|
||||
from app.core.config import *
|
||||
from app.service.utils.decorator import RunTime
|
||||
|
||||
minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||
|
||||
|
||||
# Retry policy that logs every retry attempt before urllib3 re-issues the request.
class CustomRetry(urllib3.Retry):
    def increment(self, method=None, url=None, response=None, error=None, **kwargs):
        # Let the parent implementation compute the next retry state.
        new_retry = super().increment(method, url, response, error, **kwargs)
        # Log which request is retried, the error, and how many attempts were consumed.
        logger.info(f"重试连接: {method} {url},错误: {error},重试次数: {self.total - new_retry.total}")
        return new_retry
|
||||
|
||||
|
||||
logger = logging.getLogger()
|
||||
timeout = urllib3.Timeout(connect=1, read=10.0) # 连接超时 5 秒,读取超时 10 秒
|
||||
http_client = urllib3.PoolManager(
|
||||
num_pools=10, # 设置连接池大小
|
||||
maxsize=10,
|
||||
timeout=timeout,
|
||||
cert_reqs='CERT_REQUIRED', # 需要证书验证
|
||||
retries=CustomRetry(
|
||||
total=5,
|
||||
backoff_factor=0.2,
|
||||
status_forcelist=[500, 502, 503, 504],
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# Fetch an object from OSS and decode it as either an OpenCV array or a PIL image.
@RunTime
def oss_get_image(oss_client, bucket, object_name, data_type):
    """Return the object decoded per data_type ("cv2" -> ndarray, else PIL.Image).

    Returns None when the fetch or decode fails (the error is logged).
    """
    image_object = None
    try:
        image_data = oss_client.get_object(bucket_name=bucket, object_name=object_name)
        raw = image_data.read()
        if data_type == "cv2":
            # Decode all channels as-is (alpha preserved, 16-bit allowed).
            buffer = np.frombuffer(raw, np.uint8)  # reinterpret as unsigned 8-bit
            image_object = cv2.imdecode(buffer, cv2.IMREAD_UNCHANGED)
            if image_object.dtype == np.uint16:
                # Scale 16-bit images down to 8-bit.
                image_object = (image_object / 256).astype('uint8')
        else:
            image_object = Image.open(BytesIO(raw))
    except Exception as e:
        logger.warning(f"{OSS} | 获取图片出现异常 ######: {e}")
    return image_object
|
||||
|
||||
|
||||
@RunTime
|
||||
def oss_upload_image(oss_client, bucket, object_name, image_bytes):
|
||||
req = None
|
||||
try:
|
||||
req = oss_client.put_object(bucket_name=bucket, object_name=object_name, data=io.BytesIO(image_bytes), length=len(image_bytes), content_type='image/png')
|
||||
except Exception as e:
|
||||
logger.warning(f"{OSS} | 上传图片出现异常 ######: {e}")
|
||||
return req
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Manual smoke test: download one object and display it.
    # (The pile of alternative sample keys lives in VCS history.)
    url = "aida-users/31/sketchboard/female/dress/6edcbf92-7da9-4809-a0a8-a4b4f06dec1e0628000041.jpg"
    read_type = "cv2"
    # The fetch is identical for both display paths, so do it once.
    img = oss_get_image(oss_client=minio_client, bucket=url.split('/')[0], object_name=url[url.find('/') + 1:], data_type=read_type)
    if read_type == "cv2":
        cv2.imshow("", img)
        cv2.waitKey(0)
    else:
        img.show()
|
||||
@@ -1,16 +1,38 @@
|
||||
import io
|
||||
import logging
|
||||
from io import BytesIO
|
||||
|
||||
import boto3
|
||||
import cv2
|
||||
import numpy as np
|
||||
import urllib3
|
||||
from PIL import Image
|
||||
from minio import Minio
|
||||
|
||||
from app.core.config import *
|
||||
|
||||
|
||||
# Custom Retry subclass: logs every retry attempt made by urllib3.
class CustomRetry(urllib3.Retry):
    def increment(self, method=None, url=None, response=None, error=None, **kwargs):
        # Delegate to the parent implementation to build the next retry state.
        new_retry = super(CustomRetry, self).increment(method, url, response, error, **kwargs)
        # Log which request is retried, the error, and attempts consumed so far.
        logger.info(f"重试连接: {method} {url},错误: {error},重试次数: {self.total - new_retry.total}")
        return new_retry
|
||||
|
||||
|
||||
logger = logging.getLogger()
|
||||
timeout = urllib3.Timeout(connect=1, read=10.0) # 连接超时 5 秒,读取超时 10 秒
|
||||
http_client = urllib3.PoolManager(
|
||||
num_pools=10, # 设置连接池大小
|
||||
maxsize=10,
|
||||
timeout=timeout,
|
||||
cert_reqs='CERT_REQUIRED', # 需要证书验证
|
||||
retries=CustomRetry(
|
||||
total=5,
|
||||
backoff_factor=0.2,
|
||||
status_forcelist=[500, 502, 503, 504],
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# 获取图片
|
||||
@@ -18,12 +40,8 @@ def oss_get_image(bucket, object_name, data_type):
|
||||
# cv2 默认全通道读取
|
||||
image_object = None
|
||||
try:
|
||||
if OSS == "minio":
|
||||
oss_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||
image_data = oss_client.get_object(bucket_name=bucket, object_name=object_name)
|
||||
else:
|
||||
oss_client = boto3.client('s3', aws_access_key_id=S3_ACCESS_KEY, aws_secret_access_key=S3_AWS_SECRET_ACCESS_KEY, region_name=S3_REGION_NAME)
|
||||
image_data = oss_client.get_object(Bucket=bucket, Key=object_name)['Body']
|
||||
oss_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE, http_client=http_client)
|
||||
image_data = oss_client.get_object(bucket_name=bucket, object_name=object_name)
|
||||
if data_type == "cv2":
|
||||
image_bytes = image_data.read()
|
||||
image_array = np.frombuffer(image_bytes, np.uint8) # 转成8位无符号整型
|
||||
@@ -41,12 +59,8 @@ def oss_get_image(bucket, object_name, data_type):
|
||||
def oss_upload_image(bucket, object_name, image_bytes):
    """Upload raw image bytes to MinIO as image/png.

    Returns the put_object result, or None when the upload failed
    (the error is logged and swallowed — best-effort semantics).
    """
    # NOTE(review): merge residue removed — the boto3/S3 branch was dropped in
    # favor of the unconditional MinIO path, matching the merged develop branch.
    req = None
    try:
        oss_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
        req = oss_client.put_object(bucket_name=bucket, object_name=object_name, data=io.BytesIO(image_bytes), length=len(image_bytes), content_type='image/png')
    except Exception as e:
        logger.warning(f"{OSS} | 上传图片出现异常 ######: {e}")
    return req
||||
@@ -64,8 +78,8 @@ if __name__ == '__main__':
|
||||
# url = "aida-users/89/sketchboard/female/Dress/e6724ab7-8d3f-4677-abe0-c3e42ab7af85.jpeg"
|
||||
# url = "aida-users/87/print/956614a2-7e75-4fbe-9ed0-c1831e37a2c9-4-87.png"
|
||||
# url = "aida-users/89/single_logo/123-89.png"
|
||||
# url = "aida-users/89/product_image/string-89.png"
|
||||
url = "aida-collection-element/12148/Sketchboard/95ea577b-305b-4a62-b30a-39c0dd3ddb3f.png"
|
||||
url = "aida-results/result_e2673d92-8d25-11ef-be24-0826ae3ad6b3.png"
|
||||
# url = "aida-collection-element/12148/Sketchboard/95ea577b-305b-4a62-b30a-39c0dd3ddb3f.png"
|
||||
read_type = "cv2"
|
||||
if read_type == "cv2":
|
||||
img = oss_get_image(bucket=url.split('/')[0], object_name=url[url.find('/') + 1:], data_type=read_type)
|
||||
|
||||
BIN
requirements.txt
BIN
requirements.txt
Binary file not shown.
Reference in New Issue
Block a user