Files
AiDA_Python/app/service/design/items/pipelines/loading.py
zhouchengrong 8dd6fc924c feat
fix: design load image — 判断图片 size，如果小于 50 则 resize 一倍，否则不能推理 (if the image is smaller than 50 px, upscale it 2x; otherwise inference fails)
2024-07-25 10:33:25 +08:00

135 lines
5.5 KiB
Python

import cv2
from app.service.utils.oss_client import oss_get_image
from ..builder import PIPELINES
@PIPELINES.register_module()
class LoadImageFromFile(object):
    """Pipeline step that loads an item image from OSS into the ``result`` dict.

    Args:
        path: Object path of the form ``"<bucket>/<object_name>"``.
        color: Optional value copied verbatim into ``result['color']``.
        print_dict: Optional value copied verbatim into ``result['print_dict']``.
    """

    # Item category -> keypoint name returned by ``get_keypoint``.
    _KEYPOINTS = {
        'blouse': 'shoulder',
        'outwear': 'shoulder',
        'dress': 'shoulder',
        'tops': 'shoulder',
        'trousers': 'waistband',
        'skirt': 'waistband',
        'bottoms': 'waistband',
        'bag': 'hand_point',
        'shoes': 'toe',
        'hairstyle': 'head_point',
        'earring': 'ear_point',
    }

    def __init__(self, path, color=None, print_dict=None):
        self.path = path
        self.color = color
        self.print_dict = print_dict

    def __call__(self, result):
        """Load the image and populate ``result``; returns the same dict.

        Reads ``result['name']`` (the item category). Writes ``image``,
        ``pre_mask``, ``gray``, ``keypoint``, ``path``, ``img_shape``,
        ``ori_shape``, ``color`` and ``print_dict``.
        """
        image, pre_mask = self.read_image(self.path)
        result['image'] = image
        result['pre_mask'] = pre_mask
        result['gray'] = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        result['keypoint'] = self.get_keypoint(result['name'])
        result['path'] = self.path
        # NOTE(review): both shapes are taken after any upscaling done in
        # read_image, so ori_shape is not the stored file's size — confirm
        # this is what downstream consumers expect.
        result['img_shape'] = image.shape
        result['ori_shape'] = image.shape
        result['color'] = self.color
        result['print_dict'] = self.print_dict
        return result

    @classmethod
    def get_keypoint(cls, name):
        """Return the keypoint name for item category ``name``.

        Raises:
            KeyError: if ``name`` is not a known item category.
        """
        try:
            return cls._KEYPOINTS[name]
        except (KeyError, TypeError):
            # Build the message from the mapping so it always lists every
            # accepted category.
            raise KeyError(f"{name} does not belong to item category list: "
                           f"{', '.join(cls._KEYPOINTS)}.") from None

    @staticmethod
    def read_image(image_path, min_side=50, scale=2):
        """Download ``image_path`` from OSS and normalise it for the pipeline.

        Args:
            image_path: ``"<bucket>/<object_name>"`` path.
            min_side: If the shorter side is ``<= min_side`` pixels, the image
                is upscaled once by ``scale`` (defaults reproduce the 50 px / 2x
                rule used before).
            scale: Integer upscale factor applied to both sides.

        Returns:
            ``(image, image_mask)``. ``image`` has three channels. ``image_mask``
            is the fourth channel when the source has four channels, otherwise
            ``None``. When upscaling happens, both are resized together.

        Raises:
            ValueError: if ``image_path`` has no ``/`` or the image cannot be loaded.
        """
        bucket, sep, object_name = image_path.partition("/")
        if not sep:
            raise ValueError(f"image path must look like '<bucket>/<object>': {image_path}")
        image = oss_get_image(bucket=bucket, object_name=object_name, data_type="cv2")
        # NOTE(review): assumes oss_get_image may return None on a decode
        # failure (as cv2.imdecode does) — confirm; fail with a clear message.
        if image is None:
            raise ValueError(f"failed to load image: {image_path}")

        image_mask = None
        if image.ndim == 2:
            # Single-channel image: replicate it into three channels.
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        if image.shape[2] == 4:
            # Four channels: keep the last one as the mask, the rest as the image.
            image_mask = image[:, :, 3]
            image = image[:, :, :3]

        # Very small images cannot go through inference, so upscale them.
        # Compare the shorter side explicitly: a tuple comparison such as
        # ``shape[:2] <= (50, 50)`` is lexicographic and mostly checks height only.
        height, width = image.shape[:2]
        if min(height, width) <= min_side:
            new_size = (width * scale, height * scale)  # cv2 expects (width, height)
            image = cv2.resize(image, new_size, interpolation=cv2.INTER_LINEAR)
            if image_mask is not None:
                # Keep the mask pixel-aligned with the resized image; nearest
                # neighbour avoids introducing new mask values.
                image_mask = cv2.resize(image_mask, new_size, interpolation=cv2.INTER_NEAREST)
        return image, image_mask
@PIPELINES.register_module()
class LoadBodyImageFromFile(object):
    """Pipeline step that attaches the mannequin (body) image to ``result``.

    Args:
        body_path: Object path of the form ``"<bucket>/<object_name>"``.
    """

    def __init__(self, body_path):
        self.body_path = body_path

    def __call__(self, result):
        """Record the body path and load the body image as a PIL image.

        Writes ``image_url`` and ``body_path`` (both set to ``self.body_path``),
        ``name`` (always ``"mannequin"``) and ``body_image`` into ``result``,
        then returns the same dict.
        """
        location = self.body_path
        result["image_url"] = location
        result['body_path'] = location
        result["name"] = "mannequin"

        pieces = location.split("/", 1)
        bucket_name, key = pieces[0], pieces[1]
        result['body_image'] = oss_get_image(bucket=bucket_name, object_name=key, data_type="PIL")
        return result
@PIPELINES.register_module()
class ImageShow(object):
    """Debug pipeline step that displays one or more entries of ``result``.

    Args:
        key: Either a single key, shown with ``cv2.imshow`` after shrinking to
            fit a 400x400 box, or a list of keys, each shown with matplotlib.
    """

    def __init__(self, key):
        self.key = key

    def __call__(self, result):
        """Display ``result[key]`` (or each listed key) and return ``result`` unchanged.

        Raises:
            TypeError: if ``key`` is neither a string nor a list.
        """
        if isinstance(self.key, list):
            # Imported lazily: matplotlib is only needed for the list form.
            import matplotlib.pyplot as plt
            for key in self.key:
                # NOTE(review): matplotlib expects RGB; arrays produced by cv2
                # are typically BGR, so colours may appear swapped — confirm.
                plt.imshow(result[key])
                plt.title(key)
                plt.show()
        elif isinstance(self.key, str):
            img = self._resize_img(result[self.key])
            cv2.imshow(self.key, img)
            cv2.waitKey(0)  # block until a key is pressed
        else:
            raise TypeError(f'key should be a string or a list of strings, '
                            f'but got type {type(self.key)}.')
        return result

    @staticmethod
    def _resize_img(img, max_side=400):
        """Shrink ``img`` to fit a ``max_side`` x ``max_side`` box, keeping aspect ratio.

        Images that already fit are returned unchanged.
        """
        height, width = img.shape[:2]
        if height > max_side or width > max_side:
            ratio = min(max_side / height, max_side / width)
            # Clamp to 1 px so very elongated images never yield a 0-sized target,
            # which cv2.resize rejects.
            new_size = (max(1, int(ratio * width)), max(1, int(ratio * height)))
            img = cv2.resize(img, new_size)
        return img