Files
sora_python/app/service/outfit_matcher/outfit_evaluator.py

406 lines
16 KiB
Python
Raw Normal View History

2024-03-28 10:30:18 +08:00
import io
from PIL import Image
import cv2
import numpy as np
import tritonclient.http as httpclient
import torch
from matplotlib import pyplot as plt, image as mpimg
from minio import Minio
from torchvision import transforms
from app.core.config import *
from app.service.outfit_matcher.foco import extract_main_colors
from app.service.utils.decorator import RunTime
2024-03-28 14:14:35 +08:00
class Backbone(object):
    """Image-feature extractor backed by a Triton inference server.

    Loads fashion-item images from MinIO, preprocesses them with a fixed
    torchvision pipeline and runs them through the
    ``outfit_matcher_backbone`` Triton model to obtain embedding vectors.
    """

    # Built once at class-definition time: the transform pipeline is
    # stateless, and the original rebuilt it on every resize_image() call.
    _IMAGE_TRANSFORMS = transforms.Compose([
        transforms.Resize(112),
        transforms.CenterCrop(112),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    def __init__(self):
        self.tritonclient = httpclient.InferenceServerClient(url=f"{OM_TRITON_IP}:{OM_TRITON_PORT}")
        self.minio_client = Minio(
            f"{MINIO_IP}:{MINIO_PORT}",
            access_key=MINIO_ACCESS,
            secret_key=MINIO_SECRET,
            secure=MINIO_SECURE)

    @RunTime
    def load_image(self, img_path):
        """Fetch ``img_path`` ("bucket/object") from MinIO as an RGB PIL image.

        Args:
            img_path: "<bucket>/<object-key>" style path inside MinIO.
        Returns:
            PIL.Image.Image in RGB mode, or None if the fetch/decode failed.
        """
        response = None
        try:
            bucket, object_name = img_path.split("/", 1)
            # Fetch the object (image file) from MinIO.
            response = self.minio_client.get_object(bucket, object_name)
            # Decode the bytes into a PIL image and force RGB mode.
            return Image.open(io.BytesIO(response.data)).convert("RGB")
        except Exception as e:
            print(f"An error occurred: {e}")
            return None
        finally:
            # The MinIO response owns an HTTP connection; release it so the
            # urllib3 connection pool is not exhausted under repeated calls.
            if response is not None:
                response.close()
                response.release_conn()

    @staticmethod
    def resize_image(img):
        """Resize/center-crop to 112x112 and normalize with ImageNet stats.

        Args:
            img: PIL.Image (the Resize/CenterCrop transforms expect a PIL
                image, not an ndarray as the old docstring claimed).
        Returns:
            float32 ndarray of shape (3, 112, 112).
        """
        return Backbone._IMAGE_TRANSFORMS(img).numpy()

    def preprocess(self, items):
        """Load and transform each item image into one batched array.

        Args:
            items: list of item dicts carrying an ``image_path`` key.
        Returns:
            float ndarray of shape (N, 3, 112, 112).
        Raises:
            ValueError: when an image cannot be loaded — instead of failing
                later with an opaque TypeError inside torchvision.
        """
        images = []
        for item in items:
            image = self.load_image(item["image_path"])
            if image is None:
                raise ValueError(f"could not load image: {item['image_path']}")
            images.append(self.resize_image(image))
        return np.stack(images, axis=0)

    @RunTime
    def get_result(self, items):
        """Input items and output features for similarity.

        Args:
            items: images of fashion items.
            Example:
                [
                    {
                        "item_name": "MSE_57987",
                        "semantic_category": "BOTTOM/PANTS",
                        "image_path": "D:\\PhD_Study\\MIXI\\mitu\\image\\2024 SS\\MSE_57987.jpg",
                        "mapped_cate": "bottoms"
                    },
                    ...
                ]
        Returns:
            ndarray of image features, shape (N, 64).
        """
        image = self.preprocess(items)
        # Inputs.
        inputs = [
            httpclient.InferInput("input__0", image.shape, datatype="FP32"),
        ]
        inputs[0].set_data_from_numpy(image.astype(np.float32), binary_data=True)
        # Requested outputs.
        outputs = [
            httpclient.InferRequestedOutput("output__0", binary_data=True),
        ]
        # Reuse the client created in __init__ — the original opened a brand
        # new HTTP connection on every call.
        results = self.tritonclient.infer(model_name="outfit_matcher_backbone",
                                          inputs=inputs, outputs=outputs)
        return results.as_numpy("output__0")  # Shape (N, 64)
class OutfitMatcher(object):
    """Base class for outfit-scoring models.

    Provides shared MinIO image loading, batch padding, image normalization
    and result visualization for the concrete matcher subclasses.
    """

    def __init__(self):
        self.tritonclient = httpclient.InferenceServerClient(url=f"{OM_TRITON_IP}:{OM_TRITON_PORT}")
        self.minio_client = Minio(
            f"{MINIO_IP}:{MINIO_PORT}",
            access_key=MINIO_ACCESS,
            secret_key=MINIO_SECRET,
            secure=MINIO_SECURE)

    def load_image(self, img_path):
        """Fetch ``img_path`` ("bucket/object") from MinIO as an RGB PIL image.

        Returns:
            PIL.Image.Image in RGB mode, or None if the fetch/decode failed.
        """
        response = None
        try:
            bucket, object_name = img_path.split("/", 1)
            # Fetch the object (image file) from MinIO.
            response = self.minio_client.get_object(bucket, object_name)
            # Decode the bytes into a PIL image and force RGB mode.
            return Image.open(io.BytesIO(response.data)).convert("RGB")
        except Exception as e:
            print(f"An error occurred: {e}")
            return None
        finally:
            # Release the HTTP connection held by the MinIO response so the
            # urllib3 pool is not exhausted under repeated calls.
            if response is not None:
                response.close()
                response.release_conn()

    @staticmethod
    def pad_array(input_value, value=0):
        """Pad a list of arrays to a common first dimension.

        Args:
            input_value: list of ndarrays whose trailing dims already agree.
            value: constant used for the padded entries (default 0).
        Returns:
            tuple:
                - ndarray of shape (batch, max_dim, *trailing_dims)
                  (the old docstring wrongly claimed a single Tensor return),
                - float32 mask of shape (batch, max_dim): 0 for real entries,
                  -inf for padded slots (usable as an additive attention mask).
        """
        max_dim = max(len(x) for x in input_value)
        mask = np.zeros((len(input_value), max_dim), dtype=np.float32)
        padded_arrays = []
        for i, array in enumerate(input_value):
            # Amount of padding needed along the first (item) dimension.
            pad_dim = max_dim - array.shape[0]
            pad_widths = [(0, pad_dim)] + [(0, 0)] * (array.ndim - 1)
            padded_arrays.append(
                np.pad(array, pad_widths, mode='constant', constant_values=value))
            mask[i, array.shape[0]:] = float("-inf")
        return np.stack(padded_arrays, axis=0), mask

    @staticmethod
    def imnormalize(img, mean, std, to_rgb=True):
        """Normalize an image with mean and std.

        Args:
            img (ndarray): Image to be normalized (HWC).
            mean (ndarray): Per-channel mean to subtract.
            std (ndarray): Per-channel std to divide by.
            to_rgb (bool): Whether to convert BGR -> RGB first.
        Returns:
            ndarray: The normalized float32 image (a new array; the caller's
            buffer is never mutated).
        """
        # astype() already returns a copy, so the original's extra .copy()
        # (and its vacuous uint8 assert after the cast) are dropped.
        img = img.astype(np.float32)
        mean = np.float64(mean.reshape(1, -1))
        stdinv = 1 / np.float64(std.reshape(1, -1))
        if to_rgb:
            cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img)  # inplace
        cv2.subtract(img, mean, img)  # inplace
        cv2.multiply(img, stdinv, img)  # inplace
        return img

    @RunTime
    def visualize(self, outfits, scores, topk=5, best=True, output_path=None):
        """Sort outfits by score and optionally render the top/bottom ``topk``.

        Args:
            outfits: list of outfits (each a list of item dicts).
            scores: 1-D ndarray of outfit scores. Ascending sort is used, i.e.
                lower score = better outfit (type-aware model semantics; the
                HON model would need the sign flipped).
            topk: number of outfits to keep.
            best: True keeps the best outfits, False the worst.
            output_path: when set, save the figure there instead of showing it.
        Returns:
            (selected outfits, their scores as a plain list)
        """
        # Sort outfits and scores together by score value.
        # sorted_indices = np.argsort(-scores if best else scores)[:topk]  # for HON
        sorted_indices = np.argsort(scores if best else -scores)[:topk]  # type-aware
        outfits = [outfits[i] for i in sorted_indices]  # best or worst topk
        scores = scores[sorted_indices]  # their scores
        if SHOW_OR_SAVE_result_image:
            # Grid layout: one row per outfit, one extra column for the score.
            num_rows = len(outfits)
            num_cols = max(len(x) for x in outfits) + 1
            # squeeze=False keeps `axes` 2-D even for a single row/column, so
            # the axes[i, j] indexing below cannot raise on topk == 1.
            fig, axes = plt.subplots(num_rows, num_cols, figsize=(8, 15), squeeze=False)
            title = f"Best {topk} Outfits" if best else f"Worst {topk} Outfits"
            fig.suptitle(title, fontsize=16)
            # Render each outfit into its row.
            for i, (outfit, score) in enumerate(zip(outfits, scores)):
                # Score cell.
                axes[i, 0].text(0.1, 0.5, f"Score: {score:.4f}", fontsize=12)
                axes[i, 0].axis("off")
                # Item images.
                for j, item in enumerate(outfit):
                    img = self.load_image(item['image_path'])
                    axes[i, j + 1].imshow(img)
                    axes[i, j + 1].axis('off')
                    axes[i, j + 1].set_title(item["semantic_category"], fontsize=10)
                # Hide unused cells in ragged rows.
                for j in range(len(outfit), num_cols):
                    axes[i, j].axis("off")
                # Horizontal separator under each row.
                axes[i, 0].axhline(y=0, color='black', linewidth=1)
            # Hide the separator on the last row.
            axes[-1, 0].axhline(y=0, color='white', linewidth=1)
            plt.subplots_adjust(wspace=0.1, hspace=0.1)
            plt.tight_layout()
            if output_path:
                plt.savefig(output_path)
            else:
                plt.show()
        return outfits, scores.tolist()
class OutfitMatcherHon(OutfitMatcher):
    """HON outfit scorer: item images + dominant colors -> compatibility score.

    The redundant ``__init__`` that only called ``super().__init__()`` has
    been removed; the inherited constructor is used directly.
    """

    def load_image(self, img_path):
        """Fetch ``img_path`` ("bucket/object") from MinIO as an ndarray.

        Unlike the base class this returns a numpy array (H, W, C) and does
        NOT force RGB mode.
        NOTE(review): a palette or RGBA source image would produce a channel
        count that imnormalize() cannot handle — confirm the stored images
        are always 3-channel.

        Returns:
            ndarray image, or None if the fetch/decode failed.
        """
        response = None
        try:
            bucket, object_name = img_path.split("/", 1)
            # Fetch the object (image file) from MinIO.
            response = self.minio_client.get_object(bucket, object_name)
            # Decode and convert the PIL image into a numpy array.
            return np.array(Image.open(io.BytesIO(response.read())))
        except Exception as e:
            print(f"An error occurred: {e}")
            return None
        finally:
            # Release the HTTP connection held by the MinIO response.
            if response is not None:
                response.close()
                response.release_conn()

    @staticmethod
    def resize_image(img):
        """Resize to the 224x224 input expected by the HON model.

        Args:
            img: ndarray (height, width, channel)
        """
        # cv2.INTER_LINEAR (== 1): the original passed the bare magic number.
        return cv2.resize(img, (224, 224), interpolation=cv2.INTER_LINEAR)

    def preprocess(self, outfits):
        """Turn outfits into padded (images, colors, mask) model inputs.

        Args:
            outfits: list of outfits, each a list of item dicts with an
                ``image_path`` key.
        Returns:
            tuple:
                - outfit_images: (batch, max_items, 3, 224, 224) array,
                - outfit_colors: padded dominant-color features,
                - mask: (batch, max_items) additive mask, -inf on padded slots.
        Raises:
            ValueError: when an item image cannot be loaded.
        """
        outfit_images = []
        outfit_colors = []
        for outfit in outfits:
            images = []
            colors = []
            for item in outfit:
                image = self.load_image(item["image_path"])
                if image is None:
                    raise ValueError(f"could not load image: {item['image_path']}")
                image = self.resize_image(image)
                # Dataset-specific per-channel statistics on the 0-255 scale.
                normalized_image = self.imnormalize(
                    image,
                    mean=np.array([208.32996145, 201.28227452, 198.47047691],
                                  dtype=np.float32),
                    std=np.array([75.48939648, 80.47423057, 82.21144189],
                                 dtype=np.float32))
                images.append(normalized_image.transpose(2, 0, 1))  # HWC -> CHW
                colors.append(extract_main_colors(image))
            outfit_images.append(np.stack(images, axis=0))  # (items, 3, 224, 224)
            outfit_colors.append(np.stack(colors, axis=0))
        outfit_images, mask = self.pad_array(outfit_images)
        outfit_colors, _ = self.pad_array(outfit_colors)
        return outfit_images, outfit_colors, mask

    def get_result(self, outfits):
        """Score each outfit with the ``outfit_matcher_hon`` Triton model.

        Args:
            outfits: list of outfits to evaluate.
        Returns:
            torch.Tensor of shape (N, 1), one score per outfit.
        """
        image, color, mask = self.preprocess(outfits)
        # Inputs.
        inputs = [
            httpclient.InferInput("input__0", image.shape, datatype="FP32"),
            httpclient.InferInput("input__1", color.shape, datatype="FP32"),
            httpclient.InferInput("input__2", mask.shape, datatype="FP32"),
        ]
        inputs[0].set_data_from_numpy(image.astype(np.float32), binary_data=True)
        inputs[1].set_data_from_numpy(color.astype(np.float32), binary_data=True)
        inputs[2].set_data_from_numpy(mask.astype(np.float32), binary_data=True)
        # Requested outputs.
        outputs = [
            httpclient.InferRequestedOutput("output__0", binary_data=True),
        ]
        results = self.tritonclient.infer(model_name="outfit_matcher_hon",
                                          inputs=inputs, outputs=outputs)
        return torch.from_numpy(results.as_numpy("output__0"))  # Shape (N, 1)
class OutfitMaterTypeAware(OutfitMatcher):
    """Type-aware outfit scorer operating on precomputed item embeddings.

    Consumes Backbone features plus a category id per item and queries the
    ``outfit_matcher_type_aware`` Triton model once per outfit. (The
    redundant ``__init__`` that only called ``super().__init__()`` has been
    removed; the inherited constructor is used directly.)
    """

    # NOTE: order is significant — the category id sent to the model is the
    # item's index in this list.
    base_fashion_categories = [
        'accessories', 'all-body', 'bags', 'bottoms', 'hats', 'jewellery',
        'outerwear', 'scarves', 'shoes', 'sunglasses', 'tops'
    ]

    def preprocess(self, outfits, features):
        """Gather per-item embeddings and category ids for each outfit.

        Args:
            outfits: list of outfits, each a list of item dicts.
            features: dict mapping item_name -> embedding ndarray
                (as produced by Backbone.get_result).
        Returns:
            tuple:
                - list of (items, feat_dim) embedding arrays
                  (the original's "(items, 3, 224, 224)" comment was stale —
                  these are embeddings, not images),
                - list of (items,) int category-id arrays.
        """
        outfit_images = []
        outfit_categories = []
        for outfit in outfits:
            images = [features[item["item_name"]] for item in outfit]
            categories = [self.base_fashion_categories.index(item["mapped_cate"])
                          for item in outfit]
            outfit_images.append(np.stack(images, axis=0))  # (items, feat_dim)
            outfit_categories.append(np.array(categories))  # (items,)
        return outfit_images, outfit_categories

    @RunTime
    def get_result(self, outfits, features):
        """Input outfits structure and output scores.

        Args:
            outfits: outfits to be evaluated.
            Example:
                [
                    [
                        {
                            "item_name": "MSE_57987",
                            "semantic_category": "BOTTOM/PANTS",
                            "image_path": "D:\\PhD_Study\\MIXI\\mitu\\image\\2024 SS\\MSE_57987.jpg",
                            "mapped_cate": "bottoms"
                        },
                        ...
                    ],
                    ...
                ]
            features: dict(item_name = np.Array) image features of those items.
        Returns:
            scores: flat ndarray of floats, one score per outfit.
        """
        outfit_images, outfit_categories = self.preprocess(outfits, features)
        scores = []
        for images, categories in zip(outfit_images, outfit_categories):
            # Inputs.
            inputs = [
                httpclient.InferInput("input__0", images.shape, datatype="FP32"),
                httpclient.InferInput("input__1", categories.shape, datatype="INT16"),
            ]
            inputs[0].set_data_from_numpy(images.astype(np.float32), binary_data=True)
            inputs[1].set_data_from_numpy(categories.astype(np.int16), binary_data=True)
            # Requested outputs.
            outputs = [
                httpclient.InferRequestedOutput("output__0", binary_data=True),
            ]
            # Reuse the client from __init__ — the original opened a brand new
            # HTTP connection on EVERY loop iteration.
            results = self.tritonclient.infer(model_name="outfit_matcher_type_aware",
                                              inputs=inputs, outputs=outputs)
            scores.append(results.as_numpy("output__0"))  # Shape (N, 1)
        if not scores:
            # np.stack([]) would raise; an empty input yields an empty result.
            return np.array([], dtype=np.float32)
        return np.stack(scores, axis=0).flatten()