Files

96 lines
3.5 KiB
Python

import logging
from skimage.morphology import skeletonize
import cv2
import numpy as np
from app.service.utils.new_oss_client import oss_get_image
# Module-level logger. Calling getLogger() with no name returns the root logger,
# so records logged through it follow the root logger's configuration.
logger = logging.getLogger()
class LoadBodyImage:
    """Callable step that loads the body image into a result dict.

    The instance keeps the storage client it was constructed with and hands
    it to ``oss_get_image`` each time the step is called.
    """

    name = "LoadBodyImage"

    def __init__(self, minio_client):
        # Storage client forwarded as ``oss_client`` to oss_get_image.
        self.minio_client = minio_client

    @classmethod
    def get_name(cls):
        """Return the class-level ``name`` attribute."""
        return cls.name

    def __call__(self, result):
        """Tag ``result`` as the mannequin entry and attach its body image.

        ``result['body_path']`` is split on the first "/" into a bucket and
        an object name; a path without "/" raises ``IndexError``. The image
        is requested with ``data_type="PIL"`` and stored under
        ``result['body_image']``. The same dict is returned.
        """
        result["name"] = "mannequin"

        client = self.minio_client
        path_parts = result['body_path'].split("/", 1)
        body_image = oss_get_image(
            oss_client=client,
            bucket=path_parts[0],
            object_name=path_parts[1],
            data_type="PIL",
        )
        result['body_image'] = body_image
        return result
class LoadImage:
    """Callable step that loads an item image and derives per-item fields.

    Calling the instance with a result dict fills in ``image``, ``pre_mask``,
    ``gray``, ``keypoint``, ``img_shape`` and ``ori_shape`` (plus
    ``merge_image`` when a merge path is present) and returns the same dict.
    """

    name = "LoadImage"

    # Images whose height AND width are both at or below this many pixels
    # are doubled in size by read_image.
    _SMALL_SIDE = 50

    # Item category -> keypoint name. Every category is accepted both as-is
    # and with a "_merge" suffix.
    _KEYPOINTS = {
        'blouse': 'shoulder', 'blouse_merge': 'shoulder',
        'outwear': 'shoulder', 'outwear_merge': 'shoulder',
        'dress': 'shoulder', 'dress_merge': 'shoulder',
        'tops': 'shoulder', 'tops_merge': 'shoulder',
        'trousers': 'waistband', 'trousers_merge': 'waistband',
        'skirt': 'waistband', 'skirt_merge': 'waistband',
        'bottoms': 'waistband', 'bottoms_merge': 'waistband',
        'bag': 'hand_point', 'bag_merge': 'hand_point',
        'shoes': 'toe', 'shoes_merge': 'toe',
        'hairstyle': 'head_point', 'hairstyle_merge': 'head_point',
        'earring': 'ear_point', 'earring_merge': 'ear_point',
        'others': 'others', 'others_merge': 'others',
    }

    def __init__(self, minio_client):
        # Storage client forwarded as ``oss_client`` to oss_get_image.
        self.minio_client = minio_client

    @classmethod
    def get_name(cls):
        """Return the class-level ``name`` attribute."""
        return cls.name

    def __call__(self, result):
        """Load the image(s) referenced by ``result`` and attach derived data.

        Reads ``result['path']`` and ``result['name']`` (both required) and,
        when truthy, ``result['merge_image_path']``.

        Returns:
            The same ``result`` dict, updated in place.

        Raises:
            KeyError: if ``result['name']`` is not a known item category.
        """
        if result.get("merge_image_path"):
            # Only the pixels of the merge image are kept; its mask is dropped.
            result['merge_image'], _ = self.read_image(result['merge_image_path'])

        result['image'], result['pre_mask'] = self.read_image(result['path'])

        # NOTE(review): BGR2GRAY assumes oss_get_image(data_type="cv2")
        # returns BGR channel order - confirm against the helper.
        gray = cv2.cvtColor(result['image'], cv2.COLOR_BGR2GRAY)
        result['gray'] = self.get_lines(gray)

        result['keypoint'] = self.get_keypoint(result['name'])

        # Both shape fields start out as the shape of the loaded image.
        shape = result['image'].shape
        result['img_shape'] = shape
        result['ori_shape'] = shape
        return result

    @staticmethod
    def get_lines(img):
        """Return a copy of ``img`` that keeps only thin line pixels.

        The image is binarised with an inverted Gaussian adaptive threshold
        (block size 25, constant 10), the foreground is skeletonized with the
        'zhang' method, and the output holds the original grey values on the
        skeleton and 255 everywhere else.
        """
        binary = cv2.adaptiveThreshold(
            img,
            255,
            cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
            cv2.THRESH_BINARY_INV,
            25,
            10,
        )
        skeleton = skeletonize(binary > 0, method='zhang')

        # Start from an all-255 canvas and copy back only the skeleton pixels.
        lines = np.ones_like(img) * 255
        lines[skeleton] = img[skeleton]
        return lines

    @staticmethod
    def _is_small(shape, limit=_SMALL_SIDE):
        """Return True when both height and width of ``shape`` are <= ``limit``."""
        height, width = shape[:2]
        return height <= limit and width <= limit

    def read_image(self, image_path):
        """Fetch an image from object storage and normalise it.

        ``image_path`` is split on the first "/" into bucket and object name;
        a path without "/" raises ``IndexError``.

        Normalisation steps:
          * a 2-D (single channel) image is expanded to three channels;
          * a four-channel image has its fourth channel split off as the mask;
          * an image whose height and width are both at most 50 pixels is
            doubled in size, and its mask (if any) is doubled with it so the
            two stay aligned.

        Returns:
            ``(image, image_mask)`` where ``image_mask`` is ``None`` unless
            the stored image had four channels.
        """
        path_parts = image_path.split("/", 1)
        image = oss_get_image(
            oss_client=self.minio_client,
            bucket=path_parts[0],
            object_name=path_parts[1],
            data_type="cv2",
        )

        image_mask = None
        if len(image.shape) == 2:
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        if image.shape[2] == 4:  # four channels: the last one is the mask
            image_mask = image[:, :, 3]
            image = image[:, :, :3]

        # Compare each dimension separately: comparing the shape tuple against
        # (50, 50) directly is lexicographic and misjudges e.g. 40x1000 images.
        if self._is_small(image.shape):
            # cv2.resize takes the target size as (width, height).
            new_size = (image.shape[1] * 2, image.shape[0] * 2)
            image = cv2.resize(image, new_size, interpolation=cv2.INTER_LINEAR)
            if image_mask is not None:
                # Keep the mask the same size as the resized image.
                image_mask = cv2.resize(image_mask, new_size, interpolation=cv2.INTER_LINEAR)
        return image, image_mask

    @staticmethod
    def get_keypoint(name):
        """Map an item category name to the keypoint name used for it.

        Raises:
            KeyError: if ``name`` is not one of the known categories.
        """
        try:
            return LoadImage._KEYPOINTS[name]
        except (KeyError, TypeError):
            # TypeError covers unhashable names, which are simply "unknown".
            # The message text is kept as-is; note that it does not list
            # every accepted category.
            raise KeyError(f"{name} does not belong to item category list: blouse, outwear, dress, trousers, skirt, "
                           f"bag, shoes, hairstyle, earring.") from None