design design batch

This commit is contained in:
zhouchengrong
2024-12-11 11:06:20 +08:00
parent 2c0a7729b8
commit e6f0ee7f3a
17 changed files with 281 additions and 26 deletions

View File

@@ -1,3 +1,4 @@
from .back_perspective import BackPerspective
from .color import Color
from .contour_detection import ContourDetection
from .keypoint import KeyPoint
@@ -13,6 +14,7 @@ __all__ = [
'KeyPoint',
'ContourDetection',
'Segmentation',
'BackPerspective',
'Color',
'PrintPainting',
'Scaling',

View File

@@ -0,0 +1,79 @@
import cv2
import numpy as np
from app.service.design_fast.utils.design_ensemble import get_seg_result
from app.service.utils.new_oss_client import oss_upload_image
class BackPerspective:
def __init__(self, minio_client):
self.minio_client = minio_client
def __call__(self, result):
# 如果sketch为系统图 查看是否有对应的 背后视角图
if result['path'].split('/')[0] == 'aida-sys-image':
file_path = result['path'].replace("images", 'images_back', 1)
if self.is_file_exists(bucket_name='aida-sys-image', file_name=file_path[file_path.find('/') + 1:]):
result['back_perspective_url'] = file_path
return result
else:
seg_result = get_seg_result("1", result['image'])[0]
elif result['name'] in ['blouse', 'outwear', 'dress', 'tops']:
seg_result = result['seg_result']
else:
seg_result = get_seg_result("1", result['image'])[0]
m = self.thicken_contours_and_display(seg_result, thickness=10, color=(0, 0, 0))
back_sketch = result['image'].copy()
back_sketch[m > 100] = 255
# 上传背后视角图
_, img_encoded = cv2.imencode(".jpg", back_sketch)
resp = oss_upload_image(self.minio_client, bucket='test', object_name=result['path'], image_bytes=img_encoded.tobytes())
result['back_perspective_url'] = f"{resp.bucket_name}/{resp.object_name}"
return result
def thicken_contours_and_display(self, mask, thickness=10, color=(0, 0, 0)):
mask = mask.astype(np.uint8) * 255
# 查找轮廓
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
# 创建一个彩色副本用于绘制轮廓
mask_color = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR)
def thicken_contour_inward(contour, thick):
# 创建一个空白的黑色图像与原始掩码大小相同
blank = np.zeros_like(mask)
# 在空白图像上绘制白色的轮廓
cv2.drawContours(blank, [contour], -1, 255, thickness=thick)
# 找到轮廓的中心(可以用重心等方法近似)
M = cv2.moments(contour)
cx = int(M['m10'] / M['m00'])
cy = int(M['m01'] / M['m00'])
# 进行距离变换,离中心越近的值越小
dist_transform = cv2.distanceTransform(255 - blank, cv2.DIST_L2, 5)
# 根据距离变换的值来决定是否保留像素,离中心近的像素更容易被保留
result = np.zeros_like(mask)
for i in range(dist_transform.shape[0]):
for j in range(dist_transform.shape[1]):
if dist_transform[i, j] < thick:
result[i, j] = 255
return result
for contour in contours:
thickened_contour = thicken_contour_inward(contour, thickness)
mask_color[thickened_contour > 0] = color
_, binary_result = cv2.threshold(mask_color, 127, 255, cv2.THRESH_BINARY)
# 转换为掩码形式
mask_result = cv2.cvtColor(binary_result, cv2.COLOR_BGR2GRAY)
return mask_result
def is_file_exists(self, bucket_name, file_name):
try:
self.minio_client.stat_object(bucket_name, file_name)
return True
except Exception:
return False

View File

@@ -14,11 +14,18 @@ class Color:
def __call__(self, result):
dim_image_h, dim_image_w = result['image'].shape[0:2]
# 渐变色
if "gradient" in result.keys() and result['gradient'] != "":
bucket_name = result['gradient'].split('/')[0]
object_name = result['gradient'][result['gradient'].find('/') + 1:]
pattern = self.get_gradient(bucket_name=bucket_name, object_name=object_name)
resize_pattern = cv2.resize(pattern, (dim_image_w, dim_image_h), interpolation=cv2.INTER_AREA)
# 无色
elif "color" not in result.keys() or result['color'] == "":
result['final_image'] = result['pattern_image'] = result['single_image'] = result['image']
result['alpha'] = 100 / 255.0
return result
# 正常颜色
else:
pattern = self.get_pattern(result['color'])
resize_pattern = cv2.resize(pattern, (dim_image_w, dim_image_h), interpolation=cv2.INTER_AREA)

View File

@@ -4,7 +4,8 @@ import numpy as np
from pymilvus import MilvusClient
from app.core.config import *
from app.service.design_batch.utils.design_ensemble import get_keypoint_result
from app.service.design_fast.utils.design_ensemble import get_keypoint_result
from app.service.utils.decorator import ClassCallRunTime, RunTime
logger = logging.getLogger(__name__)
@@ -16,14 +17,15 @@ class KeyPoint:
def get_name(cls):
return cls.name
@ClassCallRunTime
def __call__(self, result):
if result['name'] in ['blouse', 'skirt', 'dress', 'outwear', 'trousers', 'tops', 'bottoms']: # 查询是否有数据 且类别相同 相同则直接读 不同则推理后更新
# result['clothes_keypoint'] = self.infer_keypoint_result(result)
site = 'up' if result['name'] in ['blouse', 'outwear', 'dress', 'tops'] else 'down'
# keypoint_cache = search_keypoint_cache(result["image_id"], site)
keypoint_cache = self.keypoint_cache(result, site)
# keypoint_cache = self.keypoint_cache(result, site)
keypoint_cache = False
# 取消向量查询 直接过模型推理
# keypoint_cache = False
if keypoint_cache is False:
keypoint_infer_result, site = self.infer_keypoint_result(result)
result['clothes_keypoint'] = self.save_keypoint_cache(result["image_id"], keypoint_infer_result, site)
@@ -87,7 +89,7 @@ class KeyPoint:
logger.info(f"save keypoint cache milvus error : {e}")
return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
# @ RunTime
@RunTime
def keypoint_cache(self, result, site):
try:
client = MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS)

View File

@@ -1,6 +1,9 @@
import io
import logging
import cv2
import numpy as np
from PIL import Image
from app.service.utils.new_oss_client import oss_get_image
@@ -71,6 +74,8 @@ class LoadImage:
keypoint = 'head_point'
elif name == 'earring':
keypoint = 'ear_point'
elif name == 'accessories':
keypoint = "accessories"
else:
raise KeyError(f"{name} does not belong to item category list: blouse, outwear, dress, trousers, skirt, "
f"bag, shoes, hairstyle, earring.")

View File

@@ -46,4 +46,16 @@ class Scaling:
result['scale'] = result['scale_bag']
elif result['keypoint'] == 'ear_point':
result['scale'] = result['scale_earrings']
elif result['keypoint'] == 'accessories':
# 由于没有识别配饰keypoint的模型 所以统一将配饰的两个关键点设定为 (0,0) (0,img.width)
# 模特的关键点设定为(0,0) (0,320/2) 距离比例简写为 160 / img.width
distance_clo = result['img_shape'][1]
distance_bdy = 320 / 2
if distance_clo == 0:
result['scale'] = 1
else:
result['scale'] = distance_bdy / distance_clo
else:
result['scale'] = 1
return result

View File

@@ -5,7 +5,8 @@ import cv2
import numpy as np
from app.core.config import SEG_CACHE_PATH
from app.service.design_batch.utils.design_ensemble import get_seg_result
from app.service.design_fast.utils.design_ensemble import get_seg_result
from app.service.utils.decorator import ClassCallRunTime
from app.service.utils.new_oss_client import oss_get_image
logger = logging.getLogger()
@@ -15,6 +16,7 @@ class Segmentation:
def __init__(self, minio_client):
self.minio_client = minio_client
@ClassCallRunTime
def __call__(self, result):
if "seg_mask_url" in result.keys() and result['seg_mask_url'] != "":
seg_mask = oss_get_image(oss_client=self.minio_client, bucket=result['seg_mask_url'].split('/')[0], object_name=result['seg_mask_url'][result['seg_mask_url'].find('/') + 1:], data_type="cv2")
@@ -31,13 +33,26 @@ class Segmentation:
result['back_mask'] = np.array(green_mask, dtype=np.uint8) * 255
result['mask'] = result['front_mask'] + result['back_mask']
else:
# 本地查询seg 缓存是否存在
_, seg_result = self.load_seg_result(result["image_id"])
result['seg_result'] = seg_result
if not _:
# preview 过模型 不缓存
if "preview_submit" in result.keys() and result['preview_submit'] == "preview":
# 推理获得seg 结果
seg_result = get_seg_result(result["image_id"], result['image'])[0]
# submit 过模型 缓存
elif "preview_submit" in result.keys() and result['preview_submit'] == "submit":
# 推理获得seg 结果
seg_result = get_seg_result(result["image_id"], result['image'])[0]
self.save_seg_result(seg_result, result['image_id'])
# null 正常流程 加载本地缓存 无缓存则过模型
else:
# 本地查询seg 缓存是否存在
_, seg_result = self.load_seg_result(result["image_id"])
# 判断缓存和实际图片size是否相同
if not _ or result["image"].shape[:2] != seg_result.shape:
# 推理获得seg 结果
seg_result = get_seg_result(result["image_id"], result['image'])[0]
self.save_seg_result(seg_result, result['image_id'])
result['seg_result'] = seg_result
# 处理前片后片
temp_front = seg_result == 1.0
result['front_mask'] = (255 * (temp_front + 0).astype(np.uint8))
@@ -48,7 +63,7 @@ class Segmentation:
@staticmethod
def save_seg_result(seg_result, image_id):
file_path = f"seg_cache/{image_id}.npy"
file_path = f"{SEG_CACHE_PATH}{image_id}.npy"
try:
np.save(file_path, seg_result)
logger.info(f"保存成功 {os.path.abspath(file_path)}")
@@ -57,7 +72,7 @@ class Segmentation:
@staticmethod
def load_seg_result(image_id):
file_path = f"seg_cache/{image_id}.npy"
file_path = f"{SEG_CACHE_PATH}{image_id}.npy"
logger.info(f"load seg file name is :{SEG_CACHE_PATH}{image_id}.npy")
try:
seg_result = np.load(file_path)

View File

@@ -7,10 +7,11 @@ from PIL import Image
from cv2 import cvtColor, COLOR_BGR2RGBA
from app.core.config import AIDA_CLOTHING
from app.service.design_batch.utils.conversion_image import rgb_to_rgba
from app.service.design_batch.utils.upload_image import upload_png_mask
from app.service.design_fast.utils.conversion_image import rgb_to_rgba
from app.service.design_fast.utils.transparent import sketch_to_transparent
from app.service.design_fast.utils.upload_image import upload_png_mask
from app.service.utils.generate_uuid import generate_uuid
from app.service.utils.new_oss_client import oss_upload_image
from app.service.utils.new_oss_client import oss_upload_image, oss_get_image
class Split(object):
@@ -20,7 +21,7 @@ class Split(object):
def __call__(self, result):
try:
if result['name'] in ('outwear', 'dress', 'blouse', 'skirt', 'trousers', 'tops', 'bottoms'):
if result['name'] in ('outwear', 'dress', 'blouse', 'skirt', 'trousers', 'tops', 'bottoms','accessories'):
front_mask = result['front_mask']
back_mask = result['back_mask']
rgba_image = rgb_to_rgba(result['final_image'], front_mask + back_mask)
@@ -30,6 +31,24 @@ class Split(object):
front_mask = cv2.resize(front_mask, new_size)
result_front_image[front_mask != 0] = rgba_image[front_mask != 0]
result_front_image_pil = Image.fromarray(cvtColor(result_front_image, COLOR_BGR2RGBA))
if 'transparent' in result.keys():
# 用户自选区域transparent
transparent = result['transparent']
if transparent['mask_url'] is not None and transparent['mask_url'] != "":
# 预处理用户自选区mask
seg_mask = oss_get_image(oss_client=self.minio_client, bucket=transparent['mask_url'].split('/')[0], object_name=transparent['mask_url'][transparent['mask_url'].find('/') + 1:], data_type="cv2")
seg_mask = cv2.resize(seg_mask, new_size, interpolation=cv2.INTER_NEAREST)
# 转换颜色空间为 RGBOpenCV 默认是 BGR
image_rgb = cv2.cvtColor(seg_mask, cv2.COLOR_BGR2RGB)
r, g, b = cv2.split(image_rgb)
blue_mask = b > r
# 创建红色和绿色掩码
transparent_mask = np.array(blue_mask, dtype=np.uint8) * 255
result_front_image_pil = sketch_to_transparent(result_front_image_pil, transparent_mask, transparent["scale"])
else:
result_front_image_pil = sketch_to_transparent(result_front_image_pil, front_mask, transparent["scale"])
result['front_image'], result["front_image_url"], _ = upload_png_mask(self.minio_client, result_front_image_pil, f'{generate_uuid()}', mask=None)
height, width = front_mask.shape