import logging

import cv2

from app.service.utils.new_oss_client import oss_get_image

logger = logging.getLogger(__name__)

# LoadImage.read_image upscales an image by SMALL_IMAGE_UPSCALE when its
# height AND width are both at or below SMALL_IMAGE_MAX_SIDE pixels.
SMALL_IMAGE_MAX_SIDE = 50
SMALL_IMAGE_UPSCALE = 2

# Item category -> keypoint name stored in result['keypoint'].
_CATEGORY_KEYPOINTS = {
    'blouse': 'shoulder',
    'outwear': 'shoulder',
    'dress': 'shoulder',
    'tops': 'shoulder',
    'trousers': 'waistband',
    'skirt': 'waistband',
    'bottoms': 'waistband',
    'bag': 'hand_point',
    'shoes': 'toe',
    'hairstyle': 'head_point',
    'earring': 'ear_point',
    'accessories': 'accessories',
}


def _split_oss_path(path):
    """Split an ``"<bucket>/<object_name>"`` path at its first ``"/"``.

    Returns:
        A ``(bucket, object_name)`` tuple; ``object_name`` may itself contain "/".

    Raises:
        IndexError: if ``path`` contains no "/" (matches the previous inline
            ``path.split("/", 1)[1]`` behavior).
    """
    parts = path.split("/", 1)
    return parts[0], parts[1]


class LoadBodyImage:
    """Pipeline step that fetches the body (mannequin) image from object storage."""

    name = "LoadBodyImage"

    def __init__(self, minio_client):
        # Client passed to oss_get_image on every call.
        self.minio_client = minio_client

    @classmethod
    def get_name(cls):
        """Return the step's registered name."""
        return cls.name

    def __call__(self, result):
        """Load ``result['body_path']`` into ``result['body_image']``.

        Args:
            result: Mutable pipeline dict; must contain ``'body_path'`` in the
                form ``"<bucket>/<object_name>"``.

        Returns:
            The same dict, with ``'name'`` set to ``"mannequin"`` and
            ``'body_image'`` set to what ``oss_get_image`` returns for
            ``data_type="PIL"`` (presumably a PIL image).
        """
        result["name"] = "mannequin"
        bucket, object_name = _split_oss_path(result['body_path'])
        result['body_image'] = oss_get_image(oss_client=self.minio_client,
                                             bucket=bucket,
                                             object_name=object_name,
                                             data_type="PIL")
        return result


class LoadImage:
    """Pipeline step that loads an item image and derives resize/gray/keypoint data."""

    name = "LoadImage"

    def __init__(self, minio_client):
        # Client passed to oss_get_image on every call.
        self.minio_client = minio_client

    @classmethod
    def get_name(cls):
        """Return the step's registered name."""
        return cls.name

    def __call__(self, result):
        """Load ``result['path']`` and populate the image-related keys.

        Args:
            result: Mutable pipeline dict; must contain ``'path'``
                (``"<bucket>/<object_name>"``), ``'resize_scale'`` (indexable,
                ``[x_scale, y_scale]``; a zero in either disables resizing) and
                ``'name'`` (item category, see ``get_keypoint``).

        Returns:
            The same dict with ``'image'``, ``'pre_mask'``, ``'ori_image'``,
            ``'gray'``, ``'keypoint'``, ``'img_shape'`` and ``'ori_shape'`` set.

        Raises:
            KeyError: if ``result['name']`` is not a known item category.
        """
        result['image'], result['pre_mask'] = self.read_image(result['path'])
        # Keep the un-resized image: it is what the model takes as input.
        result['ori_image'] = result['image']

        scale_x, scale_y = result['resize_scale'][0], result['resize_scale'][1]
        if scale_x != 0 and scale_y != 0:
            height, width = result['image'].shape[:2]
            # cv2.resize takes dsize as (width, height).
            new_size = (int(width * scale_x), int(height * scale_y))
            result['image'] = cv2.resize(result['image'], new_size)
            if result['pre_mask'] is not None:
                # Keep the mask aligned with the resized image.
                result['pre_mask'] = cv2.resize(result['pre_mask'], new_size)

        result['gray'] = cv2.cvtColor(result['image'], cv2.COLOR_BGR2GRAY)
        result['keypoint'] = self.get_keypoint(result['name'])
        result['img_shape'] = result['image'].shape
        # NOTE(review): this is the shape *after* resize_scale was applied,
        # identical to img_shape; if "original" was meant, it should probably
        # be result['ori_image'].shape — confirm against downstream consumers.
        result['ori_shape'] = result['image'].shape
        return result

    def read_image(self, image_path):
        """Fetch an image and split off its alpha channel, if any.

        Grayscale images are expanded to three channels. Images whose height
        and width are both at most SMALL_IMAGE_MAX_SIDE are upscaled by
        SMALL_IMAGE_UPSCALE. The upscale happens before the alpha split, so
        the returned mask always matches the returned image's size.

        Args:
            image_path: ``"<bucket>/<object_name>"`` location of the image.

        Returns:
            ``(image, image_mask)``. ``image`` has three channels.
            ``image_mask`` is the fourth (alpha) channel for 4-channel inputs,
            otherwise None.
        """
        bucket, object_name = _split_oss_path(image_path)
        image = oss_get_image(oss_client=self.minio_client,
                              bucket=bucket,
                              object_name=object_name,
                              data_type="cv2")
        if len(image.shape) == 2:
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)

        height, width = image.shape[:2]
        # Compare each side explicitly; tuple comparison would be
        # lexicographic and would treat e.g. a 40x1000 strip as "small".
        if height <= SMALL_IMAGE_MAX_SIDE and width <= SMALL_IMAGE_MAX_SIDE:
            new_size = (width * SMALL_IMAGE_UPSCALE, height * SMALL_IMAGE_UPSCALE)
            image = cv2.resize(image, new_size, interpolation=cv2.INTER_LINEAR)

        image_mask = None
        if image.shape[2] == 4:
            # Four channels: treat the last one as the mask.
            image_mask = image[:, :, 3]
            image = image[:, :, :3]
        return image, image_mask

    @staticmethod
    def get_keypoint(name):
        """Map an item category to the keypoint name used for placement.

        Raises:
            KeyError: if ``name`` is not in the known category list.
        """
        try:
            return _CATEGORY_KEYPOINTS[name]
        except KeyError:
            raise KeyError(f"{name} does not belong to item category list: "
                           f"{', '.join(_CATEGORY_KEYPOINTS)}.") from None