# Source: sora_python/app/service/outfit_matcher/outfit_evaluator.py
# Snapshot: 2024-03-12 12:11:26 +08:00 (original listing: 293 lines, 11 KiB, Python)

import requests
from PIL import Image
import cv2
import numpy as np
import tritonclient.http as httpclient
import torch
from matplotlib import pyplot as plt, image as mpimg
from torchvision import transforms
from foco import extract_main_colors
class OutfitMatcher(object):
    """Base class for outfit scorers backed by a Triton inference server.

    Holds the shared Triton HTTP client plus padding, image-normalization and
    plotting helpers used by the concrete matchers.
    """

    def __init__(self, url="10.1.1.240:10010"):
        """Create the shared Triton HTTP client.

        Args:
            url: ``host:port`` of the Triton server. Defaults to the address
                that used to be hard-coded here.
        """
        self.tritonclient = httpclient.InferenceServerClient(url=url)

    @staticmethod
    def pad_array(input_value, value=0):
        """Pad a list of arrays along their first axis and stack them.

        Args:
            input_value: Non-empty list of numpy arrays whose shapes agree on
                every axis except the first (e.g. one ``(items, ...)`` array
                per outfit).
            value: Constant written into the padded positions.

        Returns:
            tuple: ``(batched, mask)`` where ``batched`` has shape
            ``(batch, max_len, *rest)`` and ``mask`` is a float32 array of
            shape ``(batch, max_len)`` holding ``0`` at real positions and
            ``-inf`` at padded ones.

        Raises:
            ValueError: If ``input_value`` is empty.
        """
        if len(input_value) == 0:
            raise ValueError("pad_array() needs at least one array to pad")
        max_len = max(array.shape[0] for array in input_value)
        mask = np.zeros((len(input_value), max_len), dtype=np.float32)
        padded_arrays = []
        for i, array in enumerate(input_value):
            # Pad only the tail of the first axis; leave the other axes as-is.
            pad_widths = [(0, max_len - array.shape[0])] + [(0, 0)] * (array.ndim - 1)
            padded_arrays.append(
                np.pad(array, pad_widths, mode="constant", constant_values=value)
            )
            # -inf marks padded slots (presumably used as an additive
            # attention mask by the served models — confirm).
            mask[i, array.shape[0]:] = float("-inf")
        return np.stack(padded_arrays, axis=0), mask

    @staticmethod
    def imnormalize(img, mean, std, to_rgb=True):
        """Normalize an image with a per-channel mean and std.

        Works on a float32 copy, so the caller's ``img`` is left untouched.

        Args:
            img (ndarray): HWC image; must have 3 channels when ``to_rgb`` is
                True (required by the colour conversion).
            mean (ndarray): Per-channel mean, applied after the optional swap.
            std (ndarray): Per-channel std, applied after the optional swap.
            to_rgb (bool): If True, convert with ``cv2.COLOR_BGR2RGB`` (i.e.
                reverse the channel order) before normalizing.

        Returns:
            ndarray: ``(img - mean) / std`` as float32.
        """
        img = img.astype(np.float32)  # astype always returns a new array
        mean = np.float64(mean.reshape(1, -1))
        stdinv = 1 / np.float64(std.reshape(1, -1))
        if to_rgb:
            cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img)  # in place
        cv2.subtract(img, mean, img)  # in place
        cv2.multiply(img, stdinv, img)  # in place
        return img

    def visualize(self, outfits, scores, topk=5, best=True, output_path=None):
        """Plot the ``topk`` best (or worst) outfits next to their scores.

        Args:
            outfits: List of outfits; each outfit is a list of item dicts with
                at least ``"image_path"`` and ``"semantic_category"`` keys.
            scores: One score per outfit, e.g. the ``(N, 1)`` tensor returned
                by ``get_result``; anything ``np.asarray`` accepts.
            topk: Number of outfits to show.
            best: If True show the highest scores, otherwise the lowest.
            output_path: If given, save the figure there and close it;
                otherwise display it with ``plt.show()``.

        Raises:
            ValueError: If ``outfits`` is empty.
        """
        if not outfits:
            raise ValueError("visualize() needs at least one outfit")
        flat_scores = np.asarray(scores).flatten()
        # Negate so argsort yields descending order when best outfits are wanted.
        sorted_indices = np.argsort(-flat_scores if best else flat_scores)[:topk]
        outfits = [outfits[i] for i in sorted_indices]
        top_scores = flat_scores[sorted_indices]

        num_rows = len(outfits)
        # Column 0 holds the score text; columns 1.. hold one image per item.
        num_cols = max(len(outfit) for outfit in outfits) + 1
        # squeeze=False keeps ``axes`` 2-D even when a single outfit is shown.
        fig, axes = plt.subplots(num_rows, num_cols, figsize=(8, 15), squeeze=False)
        title = f"Best {topk} Outfits" if best else f"Worst {topk} Outfits"
        fig.suptitle(title, fontsize=16)

        for row, (outfit, score) in enumerate(zip(outfits, top_scores)):
            score_ax = axes[row, 0]
            score_ax.text(0.1, 0.5, f"Score: {score:.4f}", fontsize=12)
            score_ax.axis("off")
            for col, item in enumerate(outfit, start=1):
                ax = axes[row, col]
                ax.imshow(mpimg.imread(item['image_path']))
                ax.axis('off')
                ax.set_title(item["semantic_category"], fontsize=10)
            # Hide the unused cells to the right of shorter outfits.
            for col in range(len(outfit) + 1, num_cols):
                axes[row, col].axis("off")
            # Separator line under each row.
            score_ax.axhline(y=0, color='black', linewidth=1)
        # Paint over the last row's separator so it is not visible.
        axes[-1, 0].axhline(y=0, color='white', linewidth=1)

        plt.subplots_adjust(wspace=0.1, hspace=0.1)
        plt.tight_layout()
        if output_path:
            plt.savefig(output_path)
            # Release the figure; repeated calls would otherwise pile up open figures.
            plt.close(fig)
        else:
            plt.show()
class OutfitMatcherHon(OutfitMatcher):
    """Outfit scorer for the ``outfit_matcher_hon`` Triton model.

    Each item image is resized to 224x224, normalized with ``MEAN``/``STD``
    via ``imnormalize``, and paired with the output of
    ``extract_main_colors`` on the resized image.
    """

    # Per-channel statistics handed to ``imnormalize`` (applied after its
    # ``to_rgb`` channel swap). Built once instead of per item.
    MEAN = np.array([208.32996145, 201.28227452, 198.47047691], dtype=np.float32)
    STD = np.array([75.48939648, 80.47423057, 82.21144189], dtype=np.float32)

    def __init__(self):
        super().__init__()

    @staticmethod
    def load_image(img_path, timeout=30):
        """Load an image from an HTTP(S) URL or a local path as an ndarray.

        Args:
            img_path: ``http://``/``https://`` URL or local file path.
            timeout: Seconds to wait for the HTTP download.

        Returns:
            uint8 ndarray of shape (height, width, 3).

        Raises:
            requests.HTTPError: If the download returns an error status.
            ValueError: If the downloaded bytes cannot be decoded as an image.
        """
        if str(img_path).startswith(("http://", "https://")):
            response = requests.get(img_path, timeout=timeout)
            response.raise_for_status()
            image = cv2.imdecode(np.frombuffer(response.content, np.uint8), cv2.IMREAD_COLOR)
            if image is None:
                raise ValueError(f"Could not decode image downloaded from {img_path!r}")
            # NOTE(review): cv2.imdecode yields BGR channel order whereas the
            # local-file branch below yields RGB, so the two sources reach
            # imnormalize(to_rgb=True) and extract_main_colors in different
            # channel orders. Confirm which order the model was trained on.
            return image
        with Image.open(img_path) as img:
            return np.array(img.convert('RGB'))

    @staticmethod
    def resize_image(img):
        """Resize ``img`` (ndarray, height x width x channel) to 224x224 bilinearly."""
        return cv2.resize(img, (224, 224), interpolation=cv2.INTER_LINEAR)

    def preprocess(self, outfits):
        """Convert outfits into padded model inputs.

        Args:
            outfits: List of outfits; each outfit is a non-empty list of item
                dicts with an ``"image_path"`` key.

        Returns:
            tuple ``(images, colors, mask)``:
                - images: ``(batch, max_items, 3, 224, 224)`` normalized images.
                - colors: stacked ``extract_main_colors`` results, zero-padded
                  to ``(batch, max_items, ...)``.
                - mask: ``(batch, max_items)`` float32, ``-inf`` at padding.
        """
        outfit_images = []
        outfit_colors = []
        for outfit in outfits:
            images = []
            colors = []
            for item in outfit:
                image = self.resize_image(self.load_image(item["image_path"]))
                normalized_image = self.imnormalize(image, mean=self.MEAN, std=self.STD)
                images.append(normalized_image.transpose(2, 0, 1))  # HWC -> CHW
                # Colors come from the resized but un-normalized image.
                colors.append(extract_main_colors(image))
            outfit_images.append(np.stack(images, axis=0))  # (items, 3, 224, 224)
            outfit_colors.append(np.stack(colors, axis=0))
        outfit_images, mask = self.pad_array(outfit_images)
        outfit_colors, _ = self.pad_array(outfit_colors)
        return outfit_images, outfit_colors, mask

    def get_result(self, outfits):
        """Score outfits with the ``outfit_matcher_hon`` model.

        Args:
            outfits: Outfits in the format accepted by ``preprocess``.

        Returns:
            torch.Tensor: ``output__0`` from the server, shape (N, 1).
        """
        image, color, mask = self.preprocess(outfits)
        inputs = []
        for name, array in (("input__0", image), ("input__1", color), ("input__2", mask)):
            infer_input = httpclient.InferInput(name, array.shape, datatype="FP32")
            infer_input.set_data_from_numpy(array.astype(np.float32), binary_data=True)
            inputs.append(infer_input)
        outputs = [httpclient.InferRequestedOutput("output__0", binary_data=True)]
        results = self.tritonclient.infer(model_name="outfit_matcher_hon", inputs=inputs, outputs=outputs)
        return torch.from_numpy(results.as_numpy("output__0"))  # shape (N, 1)
class OutfitMaterTypeAware(OutfitMatcher):
    """Outfit scorer for the ``outfit_matcher_type_aware`` Triton model.

    Each item is encoded as a 112x112 ImageNet-normalized crop plus the index
    of its ``mapped_cate`` in ``base_fashion_categories``.
    """

    base_fashion_categories = [
        'accessories', 'all-body', 'bags', 'bottoms', 'hats', 'jewellery',
        'outerwear', 'scarves', 'shoes', 'sunglasses', 'tops'
    ]
    # Server hosting ``outfit_matcher_type_aware``; separate from the server
    # behind ``self.tritonclient``. Override on the instance/class if needed.
    type_aware_url = "localhost:8000"
    # Built once instead of on every resize_image call.
    _image_transforms = transforms.Compose([
        transforms.Resize(112),
        transforms.CenterCrop(112),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    def __init__(self):
        super().__init__()

    @staticmethod
    def load_image(img_path, timeout=30):
        """Load an image from an HTTP(S) URL or a local path as an RGB PIL image.

        Args:
            img_path: ``http://``/``https://`` URL or local file path.
            timeout: Seconds to wait for the HTTP download.

        Returns:
            PIL.Image.Image in RGB mode.

        Raises:
            requests.HTTPError: If the download returns an error status.
            ValueError: If the downloaded bytes cannot be decoded as an image.
        """
        if str(img_path).startswith(("http://", "https://")):
            response = requests.get(img_path, timeout=timeout)
            response.raise_for_status()
            image = cv2.imdecode(np.frombuffer(response.content, np.uint8), cv2.IMREAD_COLOR)
            if image is None:
                raise ValueError(f"Could not decode image downloaded from {img_path!r}")
            # cv2 decodes to BGR; convert so URL images match the RGB order of
            # the local-file branch and of the RGB ImageNet statistics used in
            # _image_transforms.
            return Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        with Image.open(img_path) as img:
            return img.convert('RGB')

    @classmethod
    def resize_image(cls, img):
        """Resize, center-crop to 112x112, convert to CHW and ImageNet-normalize.

        Args:
            img: PIL image, as returned by ``load_image``.

        Returns:
            ndarray of shape (3, 112, 112).
        """
        return cls._image_transforms(img).numpy()

    def _category_index(self, mapped_cate):
        """Return the index of ``mapped_cate`` in ``base_fashion_categories``.

        Raises:
            ValueError: If ``mapped_cate`` is not a known category.
        """
        try:
            return self.base_fashion_categories.index(mapped_cate)
        except ValueError:
            raise ValueError(
                f"Unknown mapped_cate {mapped_cate!r}; expected one of "
                f"{self.base_fashion_categories}"
            ) from None

    def preprocess(self, outfits):
        """Convert outfits into padded model inputs.

        Args:
            outfits: List of outfits; each outfit is a non-empty list of item
                dicts with ``"image_path"`` and ``"mapped_cate"`` keys.

        Returns:
            tuple ``(images, categories, mask)``:
                - images: ``(batch, max_items, 3, 112, 112)``, zero-padded.
                - categories: ``(batch, max_items)`` category indices; padded
                  slots hold ``len(base_fashion_categories)``.
                - mask: ``(batch, max_items)`` float32, ``-inf`` at padding.
        """
        outfit_images = []
        outfit_categories = []
        for outfit in outfits:
            images = []
            categories = []
            for item in outfit:
                images.append(self.resize_image(self.load_image(item["image_path"])))
                categories.append(self._category_index(item["mapped_cate"]))
            outfit_images.append(np.stack(images, axis=0))  # (items, 3, 112, 112)
            outfit_categories.append(np.array(categories))  # (items,)
        outfit_images, mask = self.pad_array(outfit_images, value=0)
        # Pad with one past the last real category index.
        outfit_categories, _ = self.pad_array(outfit_categories, value=len(self.base_fashion_categories))
        return outfit_images, outfit_categories, mask

    def get_result(self, outfits):
        """Score outfits with the ``outfit_matcher_type_aware`` model.

        Args:
            outfits: outfits to be evaluated, e.g.
                [
                    [
                        {
                            "item_name": "MSE_57987",
                            "semantic_category": "BOTTOM/PANTS",
                            "image_path": "D:\\PhD_Study\\MIXI\\mitu\\image\\2024 SS\\MSE_57987.jpg",
                            "mapped_cate": "bottoms"
                        },
                        {
                            "item_name": "MPO_SP7712",
                            "semantic_category": "TOP/TANK",
                            "image_path": "D:\\PhD_Study\\MIXI\\mitu\\image\\2024 SS\\MPO_SP7712.jpg",
                            "mapped_cate": "tops"
                        },
                        {
                            "item_name": "MWSS27195",
                            "semantic_category": "OUTERWEAR/GILET",
                            "image_path": "D:\\PhD_Study\\MIXI\\mitu\\image\\2024 SS\\MWSS27195.jpg",
                            "mapped_cate": "outerwear"
                        }
                    ],
                    ...
                ]

        Returns:
            torch.Tensor: ``output__0`` from the server, shape (N, 1).
        """
        image, category, mask = self.preprocess(outfits)
        inputs = [
            httpclient.InferInput("input__0", image.shape, datatype="FP32"),
            httpclient.InferInput("input__1", category.shape, datatype="INT16"),
            httpclient.InferInput("input__2", mask.shape, datatype="FP32"),
        ]
        inputs[0].set_data_from_numpy(image.astype(np.float32), binary_data=True)
        inputs[1].set_data_from_numpy(category.astype(np.int16), binary_data=True)
        inputs[2].set_data_from_numpy(mask.astype(np.float32), binary_data=True)
        outputs = [httpclient.InferRequestedOutput("output__0", binary_data=True)]
        client = httpclient.InferenceServerClient(url=self.type_aware_url)
        try:
            results = client.infer(model_name="outfit_matcher_type_aware", inputs=inputs, outputs=outputs)
        finally:
            # A client is created per call; close it so its connections are not leaked.
            client.close()
        return torch.from_numpy(results.as_numpy("output__0"))  # shape (N, 1)