2024-09-19 14:20:56 +08:00
|
|
|
import logging
|
|
|
|
|
|
|
|
|
|
import cv2
|
|
|
|
|
|
2024-09-19 15:10:50 +08:00
|
|
|
from app.service.utils.new_oss_client import oss_get_image
|
2024-09-19 14:20:56 +08:00
|
|
|
|
|
|
|
|
# Module-level logger. getLogger() with no name returns the ROOT logger, so any
# level/handler changes made through this object affect all logging in the process.
# NOTE(review): not used in this section; consider logging.getLogger(__name__) if a
# module-scoped logger was intended — confirm no other module relies on this being root.
logger = logging.getLogger()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LoadBodyImage:
    """Pipeline step that fetches the body (mannequin) image named by ``result['body_path']``.

    Instances are called with a ``result`` dict, which is updated in place and returned.
    """

    name = "LoadBodyImage"

    def __init__(self, minio_client):
        # Client object handed to oss_get_image for every fetch.
        self.minio_client = minio_client

    @classmethod
    def get_name(cls):
        """Return the step's name (the ``name`` class attribute)."""
        return cls.name

    def __call__(self, result):
        """Tag ``result`` as the mannequin and attach its image.

        ``result['body_path']`` is expected as ``"<bucket>/<object key>"``; the part
        before the first ``/`` is the bucket and the remainder is the object name.
        Sets ``result['name']`` and ``result['body_image']`` (fetched with
        ``data_type="PIL"``) and returns the same dict.
        """
        result["name"] = "mannequin"

        path_parts = result['body_path'].split("/", 1)
        bucket_name = path_parts[0]
        object_key = path_parts[1]

        result['body_image'] = oss_get_image(
            oss_client=self.minio_client,
            bucket=bucket_name,
            object_name=object_key,
            data_type="PIL",
        )
        return result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LoadImage:
    """Pipeline step that loads an item image and derives the fields later steps read.

    Instances are called with a ``result`` dict, which is updated in place and
    returned (see ``__call__``).
    """

    name = "LoadImage"

    # read_image upscales an image 2x when its height AND width are both at most
    # this many pixels.
    small_image_max_side = 50

    # Item category -> keypoint name stored in result['keypoint'].
    _CATEGORY_KEYPOINTS = {
        'blouse': 'shoulder',
        'outwear': 'shoulder',
        'dress': 'shoulder',
        'tops': 'shoulder',
        'trousers': 'waistband',
        'skirt': 'waistband',
        'bottoms': 'waistband',
        'bag': 'hand_point',
        'shoes': 'toe',
        'hairstyle': 'head_point',
        'earring': 'ear_point',
        'accessories': 'accessories',
    }

    def __init__(self, minio_client):
        # Client object handed to oss_get_image for every fetch.
        self.minio_client = minio_client

    @classmethod
    def get_name(cls):
        """Return the step's name (the ``name`` class attribute)."""
        return cls.name

    def __call__(self, result):
        """Load ``result['path']`` and populate the image-derived fields of ``result``.

        Reads:
            result['path']: ``"<bucket>/<object key>"`` of the image.
            result['resize_scale']: indexable pair; element 0 scales the width and
                element 1 the height. Resizing happens only when both are non-zero.
            result['name']: item category passed to ``get_keypoint``.

        Writes 'image', 'pre_mask' (alpha mask or None), 'ori_image' (image before
        the resize_scale step), 'gray', 'keypoint', 'img_shape' and 'ori_shape'.

        Returns:
            The same ``result`` dict.

        Raises:
            KeyError: if ``result['name']`` is not a known category.
        """
        result['image'], result['pre_mask'] = self.read_image(result['path'])

        # Keep the image as loaded (before the optional sketch resize below);
        # it is used as the model input.
        result['ori_image'] = result['image']

        scale = result['resize_scale']
        if scale[0] != 0 and scale[1] != 0:
            height, width = result['image'].shape[:2]
            # cv2.resize expects dsize as (width, height).
            new_size = (int(width * scale[0]), int(height * scale[1]))
            result['image'] = cv2.resize(result['image'], new_size)
            if result['pre_mask'] is not None:
                result['pre_mask'] = cv2.resize(result['pre_mask'], new_size)

        result['gray'] = cv2.cvtColor(result['image'], cv2.COLOR_BGR2GRAY)
        result['keypoint'] = self.get_keypoint(result['name'])
        result['img_shape'] = result['image'].shape
        # NOTE(review): ori_shape is taken after the optional resize, so it equals
        # img_shape rather than ori_image's shape — confirm downstream expects this.
        result['ori_shape'] = result['image'].shape
        return result

    @classmethod
    def _is_small_image(cls, shape):
        """Return True when both height and width in ``shape`` are <= small_image_max_side."""
        # Compare each dimension explicitly: tuple comparison such as
        # ``shape[:2] <= (50, 50)`` is lexicographic and would treat e.g. a
        # 40x2000 image as "small".
        height, width = shape[:2]
        return height <= cls.small_image_max_side and width <= cls.small_image_max_side

    def read_image(self, image_path):
        """Fetch ``image_path`` from OSS as a 3-channel image plus optional alpha mask.

        Args:
            image_path: ``"<bucket>/<object key>"``; split on the first ``/``.

        Returns:
            ``(image, mask)``. Grayscale input is expanded to 3 channels. For
            4-channel input the 4th channel is returned as ``mask`` and the first
            three as ``image``; otherwise ``mask`` is None. Images whose sides are
            both <= ``small_image_max_side`` are upscaled 2x, and the mask with them
            so the two stay the same size.
        """
        path_parts = image_path.split("/", 1)
        image = oss_get_image(oss_client=self.minio_client, bucket=path_parts[0],
                              object_name=path_parts[1], data_type="cv2")
        image_mask = None

        if len(image.shape) == 2:
            # Single-channel input: replicate to 3 channels (GRAY2RGB and GRAY2BGR
            # produce identical output for grayscale).
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)

        if image.shape[2] == 4:
            # Four channels: split off the 4th channel as the mask.
            image_mask = image[:, :, 3]
            image = image[:, :, :3]

        if self._is_small_image(image.shape):
            height, width = image.shape[:2]
            new_size = (width * 2, height * 2)  # cv2.resize dsize is (width, height)
            image = cv2.resize(image, new_size, interpolation=cv2.INTER_LINEAR)
            if image_mask is not None:
                image_mask = cv2.resize(image_mask, new_size, interpolation=cv2.INTER_LINEAR)

        return image, image_mask

    @staticmethod
    def get_keypoint(name):
        """Map an item category to the keypoint name used for it.

        Raises:
            KeyError: if ``name`` is not one of the categories in ``_CATEGORY_KEYPOINTS``;
                the message lists every accepted category.
        """
        categories = LoadImage._CATEGORY_KEYPOINTS
        try:
            return categories[name]
        except KeyError:
            raise KeyError(f"{name} does not belong to item category list: "
                           f"{', '.join(categories)}.") from None
|