update config

This commit is contained in:
zhouchengrong
2024-03-11 14:49:01 +08:00
parent a13ae87e57
commit 077066607d
3 changed files with 923 additions and 851 deletions

View File

@@ -6,6 +6,7 @@ import cv2
import numpy as np import numpy as np
import tritonclient.http as httpclient import tritonclient.http as httpclient
import torch import torch
from matplotlib import pyplot as plt, image as mpimg
from foco import extract_main_colors from foco import extract_main_colors
@@ -178,6 +179,51 @@ def evaluate_outfits(outfits):
return scores # Shape (N, 1) return scores # Shape (N, 1)
def visualize(outfits, scores, topk=5, best=True, output_path=None):
# 将outfits和scores按照scores的值进行排序
sorted_indices = np.argsort(-scores.flatten() if best else scores.flatten())[:topk] # 使用负号进行降序排序
outfits = [outfits[i] for i in sorted_indices]
scores = scores[sorted_indices]
# 设置子图的行列数
num_rows = len(outfits)
num_cols = max([len(x) for x in outfits]) + 1 # 一个是图片,一个是分数
# 创建一个新的图像,并指定子图的行列数
fig, axes = plt.subplots(num_rows, num_cols, figsize=(8, 15))
title = f"Best {topk} Outfits" if best else f"Worst {topk} Outfits"
fig.suptitle(title, fontsize=16)
# 遍历每套outfit并将其显示在对应的子图中
for i, (outfit, score) in enumerate(zip(outfits, scores)):
# 显示分数
axes[i, 0].text(0.1, 0.5, f"Score: {score[0]:.4f}", fontsize=12)
axes[i, 0].axis("off")
# 显示图片
for j, item in enumerate(outfit):
img = mpimg.imread(item['image_path']) # 读取图片
axes[i, j + 1].imshow(img) # 在对应的子图中显示图片
axes[i, j + 1].axis('off') # 关闭坐标轴
axes[i, j + 1].set_title(item["semantic_category"], fontsize=10)
for j in range(len(outfit), num_cols):
axes[i, j].axis("off")
# 在每一行的底部添加一条横线
axes[i, 0].axhline(y=0, color='black', linewidth=1)
# 隐藏最后一行的横线
axes[-1, 0].axhline(y=0, color='white', linewidth=1)
# 调整布局
plt.subplots_adjust(wspace=0.1, hspace=0.1)
plt.tight_layout()
if output_path:
plt.savefig(output_path)
else:
plt.show()
if __name__ == '__main__': if __name__ == '__main__':
with open("test_input.json", "r") as f: with open("test_input.json", "r") as f:
outfits = json.load(f) outfits = json.load(f)

View File

@@ -1,3 +1,6 @@
import json
import os
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
import tritonclient.http as httpclient import tritonclient.http as httpclient
@@ -5,14 +8,17 @@ import requests
import cv2 import cv2
import numpy as np import numpy as np
from PIL import Image from PIL import Image
from tqdm import tqdm
from app.service.outfit_matcher.dataset import FashionDataset
from app.service.outfit_matcher.foco import extract_main_colors from app.service.outfit_matcher.foco import extract_main_colors
from app.service.outfit_matcher.outfit_evaluator import evaluate_outfits, visualize
class OutfitMatcherHon: class OutfitMatcherHon:
def __init__(self, outfits): def __init__(self, outfits):
self.outfits = outfits self.outfits = outfits
self.tritonclient = httpclient.InferenceServerClient(url="localhost:8000") self.tritonclient = httpclient.InferenceServerClient(url="10.1.1.240:10010")
@staticmethod @staticmethod
def imnormalize(img, mean, std, to_rgb=True): def imnormalize(img, mean, std, to_rgb=True):
@@ -91,7 +97,7 @@ class OutfitMatcherHon:
for outfit in self.outfits: for outfit in self.outfits:
images = [] images = []
colors = [] colors = []
for item in outfit["items"]: for item in outfit:
image = self.load_image(item["image_path"]) image = self.load_image(item["image_path"])
image = self.resize_image(image) image = self.resize_image(image)
normalized_image = self.imnormalize(image, normalized_image = self.imnormalize(image,
@@ -108,7 +114,7 @@ class OutfitMatcherHon:
outfit_colors, _ = self.pad_array(outfit_colors) outfit_colors, _ = self.pad_array(outfit_colors)
return outfit_images, outfit_colors, mask return outfit_images, outfit_colors, mask
def get_result(self, outfits): def get_result(self):
# start = time.time() # start = time.time()
image, color, mask = self.preprocess() image, color, mask = self.preprocess()
# print(start - time.time()) # print(start - time.time())
@@ -126,8 +132,29 @@ class OutfitMatcherHon:
outputs = [ outputs = [
httpclient.InferRequestedOutput("output__0", binary_data=True), httpclient.InferRequestedOutput("output__0", binary_data=True),
] ]
results = self.tritonclient.infer(model_name="outfit_matcher", inputs=inputs, outputs=outputs) results = self.tritonclient.infer(model_name="outfit_matcher_hon", inputs=inputs, outputs=outputs)
# 推理 # 推理
# 取结果 # 取结果
inference_output1 = torch.from_numpy(results.as_numpy("output__0")) inference_output1 = torch.from_numpy(results.as_numpy("output__0"))
return inference_output1 # Shape (N, 1) return inference_output1 # Shape (N, 1)
if __name__ == '__main__':
with open("./test_param/recommendation_test.json", "r") as f:
param = json.load(f)
fashion_dataset = FashionDataset(param["database"])
for item in tqdm(param["query"]):
outfits = fashion_dataset.generate_outfit(item, param["topk"], param["max_outfits"])
service = OutfitMatcherHon(outfits=outfits)
scores = service.get_result()
visualize(outfits, scores, param["topk"], best=True,
output_path=os.path.join(
r"E:\workspace\outfit_matcher\2024 SS Outfit",
f"{item['item_name']}_best_{param['topk']}.png"
))
visualize(outfits, scores, param["topk"], best=False,
output_path=os.path.join(
r"E:\workspace\outfit_matcher\2024 SS Outfit",
f"{item['item_name']}_worst_{param['topk']}.png"
))
a = 1

File diff suppressed because it is too large Load Diff