Files
sora_python/app/service/outfit_matcher/outfit_evaluator.py
zhouchengrong 117e569730 add file
2024-03-11 10:58:34 +08:00

185 lines
6.1 KiB
Python

import os
import requests
import json
from PIL import Image
import cv2
import numpy as np
import tritonclient.http as httpclient
import torch
from foco import extract_main_colors
# Default Triton server host.
# NOTE(review): nothing in the visible code reads this constant --
# evaluate_outfits connects to "localhost:8000" instead; confirm whether it
# is still needed or should feed that URL.
TRITON_IP_DEFAULT = "0.0.0.0"
def imnormalize(img, mean, std, to_rgb=True):
    """Normalize an image with per-channel mean and std.

    The input is never modified; a float32 copy is normalized and returned.

    Args:
        img (ndarray): Image to be normalized.
        mean (ndarray | Sequence[float]): Per-channel mean to subtract.
        std (ndarray | Sequence[float]): Per-channel std to divide by.
        to_rgb (bool): Whether to swap the channel order with
            ``cv2.COLOR_BGR2RGB`` before normalizing.

    Returns:
        ndarray: The normalized float32 image.
    """
    # A single C-contiguous float32 copy: leaves the caller's array untouched
    # and hands OpenCV a buffer it can safely write to in place.
    img = img.astype(np.float32, order="C")
    # float64 row vectors, as in the original implementation, so OpenCV
    # applies them per channel.
    mean = np.asarray(mean, dtype=np.float64).reshape(1, -1)
    stdinv = 1 / np.asarray(std, dtype=np.float64).reshape(1, -1)
    if to_rgb:
        cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img)  # inplace
    cv2.subtract(img, mean, img)  # inplace
    cv2.multiply(img, stdinv, img)  # inplace
    return img
def load_image(img_path, timeout=30):
    """Load an image from a local path or an HTTP(S) URL as an RGB array.

    Both branches return the same channel order (RGB). Previously the URL
    branch returned OpenCV's BGR data merely labelled as RGB, so remote and
    local images were fed downstream with opposite channel orders.

    Args:
        img_path (str): Local file path or ``http://`` / ``https://`` URL.
        timeout (float): Seconds to wait for the HTTP download.

    Returns:
        ndarray: uint8 image of shape (height, width, 3) in RGB order.

    Raises:
        requests.RequestException: If the download fails or the server
            answers with an error status.
        ValueError: If the downloaded bytes cannot be decoded as an image.
    """
    # Only treat real URLs as remote; a substring test would also match
    # local paths that merely contain "http".
    if img_path.lower().startswith(("http://", "https://")):
        response = requests.get(img_path, timeout=timeout)
        response.raise_for_status()
        # np.fromstring is deprecated for binary input; frombuffer is the
        # supported zero-copy replacement.
        encoded = np.frombuffer(response.content, dtype=np.uint8)
        bgr = cv2.imdecode(encoded, cv2.IMREAD_COLOR)
        if bgr is None:
            raise ValueError(f"Could not decode image downloaded from {img_path!r}")
        # OpenCV decodes to BGR; convert so this matches the local branch.
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    # Close the file handle deterministically instead of leaving it to GC.
    with Image.open(img_path) as pil_image:
        return np.array(pil_image.convert('RGB'))
def resize_image(img, size=(224, 224)):
    """Resize an image with bilinear interpolation.

    Args:
        img: ndarray (height, width, channel)
        size: Target size passed to ``cv2.resize`` as ``dsize``, i.e.
            ``(width, height)``. Defaults to ``(224, 224)``.

    Returns:
        ndarray: The resized image.
    """
    # cv2.INTER_LINEAR is the named constant for the interpolation code 1
    # used previously.
    return cv2.resize(img, size, interpolation=cv2.INTER_LINEAR)
def pad_array(input_value):
    """Zero-pad a list of arrays along axis 0 and stack them into one batch.

    Args:
        input_value: List of numpy arrays that may differ in their first
            dimension but agree on all remaining dimensions.

    Returns:
        tuple: ``(batched_arrays, mask)`` where ``batched_arrays`` has shape
        ``[batch_dim, max_dim, *original_trailing_shape]`` and ``mask`` is a
        float32 array of shape ``[batch_dim, max_dim]`` holding ``0`` at real
        positions and ``-inf`` at padded positions.
    """
    lengths = [len(entry) for entry in input_value]
    longest = max(lengths)

    # 0 marks a real element, -inf marks padding.
    mask = np.zeros((len(input_value), longest), dtype=np.float32)
    for row, length in enumerate(lengths):
        mask[row, length:] = float("-inf")

    # Pad only the leading axis; every trailing axis keeps its size.
    padded = [
        np.pad(
            entry,
            [(0, longest - entry.shape[0])] + [(0, 0)] * len(entry.shape[1:]),
            mode='constant',
            constant_values=0,
        )
        for entry in input_value
    ]
    return np.stack(padded, axis=0), mask
def extract_color(image, img_path,
                  cache_dir=r"D:\PhD_Study\trinity_client\application\outfit_matcher\color"):
    """Return the main colors of an image, using an on-disk ``.npy`` cache.

    The cache key is the image file name without its final extension, so
    ``a.b.jpg`` and ``a.c.jpg`` no longer collide on the same cache entry.

    Args:
        image: Image passed to ``extract_main_colors`` on a cache miss.
        img_path (str): Path or URL the image was loaded from; only its
            file name is used, to derive the cache file name.
        cache_dir (str): Directory holding the cached ``.npy`` files. It is
            created on the first cache miss if it does not exist.

    Returns:
        ndarray: The cached or freshly extracted colors.
    """
    # TODO: replace to vector database
    basename = os.path.splitext(os.path.basename(img_path))[0]
    color_file = os.path.join(cache_dir, basename + ".npy")
    if os.path.exists(color_file):
        return np.load(color_file)

    color = extract_main_colors(image)
    # np.save does not create missing directories; without this a fresh
    # machine fails on the very first image.
    os.makedirs(cache_dir, exist_ok=True)
    np.save(color_file, color)
    return color
# Per-channel normalization statistics, built once instead of per item.
_PIXEL_MEAN = np.array([208.32996145, 201.28227452, 198.47047691], dtype=np.float32)
_PIXEL_STD = np.array([75.48939648, 80.47423057, 82.21144189], dtype=np.float32)


def preprocess(outfits):
    """Turn outfit descriptions into padded, batched model inputs.

    Each item is loaded from ``item["image_path"]``, resized, normalized and
    transposed to channel-first; its colors are extracted from the resized
    (un-normalized) image. Outfits with fewer items are zero-padded to the
    longest outfit.

    Args:
        outfits: Non-empty list of outfits; each outfit is a non-empty list
            of dicts with an ``"image_path"`` key.

    Returns:
        tuple: ``(outfit_images, outfit_colors, mask)`` as produced by
        ``pad_array``; ``mask`` comes from the image batch.

    Raises:
        ValueError: If ``outfits`` is empty or any outfit has no items.
    """
    if len(outfits) == 0:
        raise ValueError("outfits must contain at least one outfit")

    outfit_images = []
    outfit_colors = []
    for index, outfit in enumerate(outfits):
        if len(outfit) == 0:
            raise ValueError(f"outfit at index {index} contains no items")
        images = []
        colors = []
        for item in outfit:
            image = resize_image(load_image(item["image_path"]))
            # NOTE(review): imnormalize's default to_rgb=True applies a
            # BGR->RGB channel swap, while load_image converts local files
            # to RGB via PIL -- confirm which channel order the model
            # expects before changing either side.
            normalized_image = imnormalize(image, mean=_PIXEL_MEAN, std=_PIXEL_STD)
            images.append(normalized_image.transpose(2, 0, 1))
            colors.append(extract_color(image, item["image_path"]))
        outfit_images.append(np.stack(images, axis=0))  # List[(items, 3, 224, 224)]
        outfit_colors.append(np.stack(colors, axis=0))

    outfit_images, mask = pad_array(outfit_images)
    outfit_colors, _ = pad_array(outfit_colors)
    return outfit_images, outfit_colors, mask
def evaluate_outfits(outfits, url="localhost:8000", model_name="outfit_matcher_hon"):
    """Score outfits with the outfit-matcher model served by Triton.

    Args:
        outfits: outfits to be evaluated.
            Example:
            [
                [
                    {
                        "item_name": "MSE_57987",
                        "semantic_category": "BOTTOM/PANTS",
                        "image_path": "D:\\PhD_Study\\MIXI\\mitu\\image\\2024 SS\\MSE_57987.jpg",
                        "mapped_cate": "bottoms"
                    },
                    {
                        "item_name": "MPO_SP7712",
                        "semantic_category": "TOP/TANK",
                        "image_path": "D:\\PhD_Study\\MIXI\\mitu\\image\\2024 SS\\MPO_SP7712.jpg",
                        "mapped_cate": "tops"
                    },
                    {
                        "item_name": "MWSS27195",
                        "semantic_category": "OUTERWEAR/GILET",
                        "image_path": "D:\\PhD_Study\\MIXI\\mitu\\image\\2024 SS\\MWSS27195.jpg",
                        "mapped_cate": "outerwear"
                    }
                ],
                ...
            ]
        url: ``host:port`` of the Triton HTTP endpoint.
        model_name: Name of the model to query on the Triton server.

    Returns:
        torch.Tensor: Scores read from the model's ``output__0``
        (shape (N, 1) according to the original author's note).
    """
    image, color, mask = preprocess(outfits)

    # Input set: images, colors and padding mask, all sent as FP32.
    payloads = [
        ("input__0", image.astype(np.float32)),
        ("input__1", color.astype(np.float32)),
        ("input__2", mask.astype(np.float32)),
    ]
    inputs = []
    for name, data in payloads:
        infer_input = httpclient.InferInput(name, list(data.shape), datatype="FP32")
        infer_input.set_data_from_numpy(data, binary_data=True)
        inputs.append(infer_input)

    # Output set.
    outputs = [
        httpclient.InferRequestedOutput("output__0", binary_data=True),
    ]

    client = httpclient.InferenceServerClient(url=url)
    try:
        # Inference.
        results = client.infer(model_name=model_name, inputs=inputs, outputs=outputs)
    finally:
        # Release the HTTP connection pool even when inference fails.
        client.close()

    # Fetch the result.
    scores = torch.from_numpy(results.as_numpy("output__0"))
    return scores  # Shape (N, 1)
if __name__ == '__main__':
    # Manual smoke test: score the outfits listed in test_input.json.
    # The encoding is explicit so parsing does not depend on the platform's
    # locale default.
    with open("test_input.json", "r", encoding="utf-8") as f:
        outfits = json.load(f)
    scores = evaluate_outfits(outfits)
    print(scores)