144 lines
5.2 KiB
Python
144 lines
5.2 KiB
Python
|
|
import io
|
||
|
|
import logging
|
||
|
|
import time
|
||
|
|
|
||
|
|
import cv2
|
||
|
|
import numpy as np
|
||
|
|
from PIL import Image
|
||
|
|
from minio import Minio
|
||
|
|
|
||
|
|
from app.core.config import *
|
||
|
|
from ..builder import PIPELINES
|
||
|
|
|
||
|
|
|
||
|
|
@PIPELINES.register_module()
class LoadImageFromFile(object):
    """Pipeline step that loads an item image from MinIO into ``result``.

    The image location is a string of the form ``"bucket/object_name"``.
    Reads ``result['name']`` (the item category) and writes ``image``,
    ``pre_mask``, ``gray``, ``keypoint``, ``path``, ``img_shape``,
    ``ori_shape``, ``color`` and ``print_dict``.
    """

    # Item category -> name of the keypoint used for that category.
    _KEYPOINTS = {
        'blouse': 'shoulder',
        'outwear': 'shoulder',
        'dress': 'shoulder',
        'tops': 'shoulder',
        'trousers': 'waistband',
        'skirt': 'waistband',
        'bottoms': 'waistband',
        'bag': 'hand_point',
        'shoes': 'toe',
        'hairstyle': 'head_point',
        'earring': 'ear_point',
    }

    def __init__(self, path, color=None, print_dict=None):
        """
        Args:
            path: MinIO location of the image, ``"bucket/object_name"``.
            color: Copied verbatim into ``result['color']``.
            print_dict: Copied verbatim into ``result['print_dict']``.
        """
        self.path = path
        self.color = color
        self.print_dict = print_dict
        self.minio_client = Minio(
            f"{MINIO_URL}",
            access_key=MINIO_ACCESS,
            secret_key=MINIO_SECRET,
            secure=MINIO_SECURE)

    def __call__(self, result):
        """Load ``self.path`` and populate ``result``; returns ``result``.

        Raises:
            KeyError: if ``result['name']`` is missing or not a known category.
            ValueError: if the downloaded bytes cannot be decoded as an image.
        """
        image, pre_mask = self.read_image(self.path)
        result['image'] = image
        result['pre_mask'] = pre_mask
        result['gray'] = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        result['keypoint'] = self.get_keypoint(result['name'])
        result['path'] = self.path
        result['img_shape'] = image.shape
        result['ori_shape'] = image.shape
        result['color'] = self.color
        result['print_dict'] = self.print_dict
        return result

    @classmethod
    def get_keypoint(cls, name):
        """Return the keypoint name for item category ``name``.

        Raises:
            KeyError: if ``name`` is not one of the known categories.
        """
        try:
            return cls._KEYPOINTS[name]
        except (KeyError, TypeError):
            # TypeError covers unhashable names, which the old if/elif chain
            # also rejected with a KeyError.
            raise KeyError(f"{name} does not belong to item category list: "
                           f"{', '.join(cls._KEYPOINTS)}.") from None

    def read_image(self, image_path):
        """Download ``"bucket/object_name"`` from MinIO and decode it.

        Returns:
            ``(image, image_mask)`` where ``image`` is the decoded 3-channel
            array and ``image_mask`` is the alpha channel when the decoded
            image had four channels, otherwise None.

        Raises:
            ValueError: if the bytes cannot be decoded as an image.
        """
        parts = image_path.split("/", 1)
        bucket, object_name = parts[0], parts[1]

        response = self.minio_client.get_object(bucket, object_name)
        try:
            data = response.data
        finally:
            # Return the HTTP connection to the pool; get_object does not.
            response.close()
            response.release_conn()

        # NOTE(review): cv2.IMREAD_COLOR (the same value as the literal 1 used
        # before) always decodes to 3-channel BGR, so the 2-D and 4-channel
        # branches below never run and the mask is always None. Loading with
        # cv2.IMREAD_UNCHANGED would keep alpha, but per the OpenCV flag docs it
        # also ignores EXIF orientation and may return non-uint8 data. Confirm
        # the intended behaviour before changing the flag.
        image = cv2.imdecode(np.frombuffer(data, np.uint8), cv2.IMREAD_COLOR)
        if image is None:
            raise ValueError(f"could not decode image at {image_path}")

        image_mask = None
        if len(image.shape) == 2:
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        if image.shape[2] == 4:
            # Four-channel image: split the alpha channel off as the mask.
            image_mask = image[:, :, 3]
            image = image[:, :, :3]
        return image, image_mask
|
||
|
|
|
||
|
|
|
||
|
|
@PIPELINES.register_module()
class LoadBodyImageFromFile(object):
    """Pipeline step that loads the mannequin/body image from MinIO.

    ``body_path`` is a MinIO location ``"bucket/object_name"``. A non-PNG
    source is converted once to an RGBA PNG in which near-white pixels become
    fully transparent. The PNG is uploaded next to the source object, and
    ``self.body_path`` is switched to it so later calls load the PNG directly.
    """

    # A pixel whose R, G and B are all >= this value is treated as background.
    _WHITE_THRESHOLD = 230

    def __init__(self, body_path):
        """
        Args:
            body_path: MinIO location of the body image, ``"bucket/object_name"``.
        """
        self.body_path = body_path
        self.minioClient = Minio(
            f"{MINIO_URL}",
            access_key=MINIO_ACCESS,
            secret_key=MINIO_SECRET,
            secure=MINIO_SECURE)

    def __call__(self, result):
        """Populate ``image_url``, ``body_path``, ``name`` and ``body_image``.

        ``body_image`` is a PIL image opened from the PNG bytes.
        Returns ``result``.
        """
        result["image_url"] = result['body_path'] = self.body_path
        result["name"] = "mannequin"

        bucket, object_name = self._split_location(self.body_path)
        if self.body_path.lower().endswith(".png"):
            png_bytes = self._download(bucket, object_name)
        else:
            logging.info("Converting body image %s to a transparent PNG", self.body_path)
            png_bytes = self._to_transparent_png(self._download(bucket, object_name))
            written = self.minioClient.put_object(
                bucket, self._png_object_name(object_name), io.BytesIO(png_bytes),
                len(png_bytes), content_type='image/png')
            # Remember the converted object so later calls skip the conversion.
            self.body_path = f"{bucket}/{written.object_name}"
            result["image_url"] = result['body_path'] = self.body_path

        # The bytes just uploaded are reused instead of downloading them again.
        result['body_image'] = Image.open(io.BytesIO(png_bytes))
        return result

    @staticmethod
    def _split_location(location):
        """Split ``"bucket/object_name"`` into ``(bucket, object_name)``."""
        parts = location.split("/", 1)
        return parts[0], parts[1]

    @staticmethod
    def _png_object_name(object_name):
        """Replace the extension of ``object_name`` (if any) with ``.png``."""
        import posixpath
        # splitext only considers the final path component, so dots in
        # "directory" parts and extension-less names are handled correctly.
        return f"{posixpath.splitext(object_name)[0]}.png"

    def _download(self, bucket, object_name):
        """Return the full contents of a MinIO object as bytes."""
        response = self.minioClient.get_object(bucket, object_name)
        try:
            return response.data
        finally:
            # Return the HTTP connection to the pool; get_object does not.
            response.close()
            response.release_conn()

    @classmethod
    def _to_transparent_png(cls, raw):
        """Encode ``raw`` image bytes as an RGBA PNG.

        Pixels whose R, G and B are all at or above ``_WHITE_THRESHOLD``
        become ``(255, 255, 255, 0)``.
        """
        image = Image.open(io.BytesIO(raw)).convert("RGBA")
        pixels = np.array(image)  # (H, W, 4) uint8 for an RGBA image
        background = (pixels[..., :3] >= cls._WHITE_THRESHOLD).all(axis=-1)
        pixels[background] = (255, 255, 255, 0)
        # Write the pixels back into `image` rather than saving a fresh Image,
        # so whatever metadata `image` carries is kept, as putdata did before.
        image.paste(Image.fromarray(pixels))
        buffer = io.BytesIO()
        image.save(buffer, format='PNG')
        return buffer.getvalue()
|
||
|
|
|
||
|
|
|
||
|
|
@PIPELINES.register_module()
class ImageShow(object):
    """Debug pipeline step that displays entries of ``result``; returns it unchanged.

    ``key`` is either a single key or a list of keys:

    - A single key is shown in an OpenCV window, downscaled so neither side
      exceeds 400 px. The call blocks until a key is pressed.
    - A list of keys is shown one at a time with matplotlib.
    """

    def __init__(self, key):
        """
        Args:
            key: A ``str`` key or a ``list`` of keys into ``result``.
        """
        self.key = key

    def __call__(self, result):
        """Display ``result[key]`` (or each listed key) and return ``result``.

        Raises:
            TypeError: if ``key`` is neither a list nor a string.
        """
        if isinstance(self.key, list):
            # Imported lazily: matplotlib is only needed on this debug path.
            import matplotlib.pyplot as plt
            # NOTE(review): arrays produced by the cv2 loaders are BGR, so
            # matplotlib shows them with red/blue swapped. Confirm this is
            # acceptable for debugging.
            for key in self.key:
                plt.imshow(result[key])
                plt.title(key)
                plt.show()
        elif isinstance(self.key, str):
            img = self._resize_img(result[self.key])
            cv2.imshow(self.key, img)
            cv2.waitKey(0)
            # Close the window once a key is pressed so windows do not pile up.
            cv2.destroyWindow(self.key)
        else:
            raise TypeError(f'key should be a string or a list of strings but got type {type(self.key)}.')
        return result

    @staticmethod
    def _resize_img(img, max_size=400):
        """Downscale ``img`` so neither side exceeds ``max_size``, keeping aspect ratio.

        Images already within the limit are returned unchanged.
        """
        height, width = img.shape[:2]
        if height <= max_size and width <= max_size:
            return img
        ratio = min(max_size / height, max_size / width)
        # Clamp to 1 px so extreme aspect ratios never request an empty image.
        new_size = (max(1, int(ratio * width)), max(1, int(ratio * height)))
        return cv2.resize(img, new_size)
|