This commit is contained in:
pangkaicheng
2024-03-28 10:56:21 +08:00
parent 4e832bbadb
commit 7d5e2b28e9
3 changed files with 744 additions and 43 deletions

View File

@@ -1,7 +1,6 @@
import json
import os
from pprint import pprint
from tqdm import tqdm
import numpy as np
from app.service.outfit_matcher.dataset import FashionDataset
@@ -40,21 +39,21 @@ if __name__ == '__main__':
prepared_feature[item["item_name"]] = np.load(f'feature/{item["item_name"]}.npy')
# 开始服装搭配任务
for item in tqdm(param["query"] * 10):
for item in param["query"]:
# 根据一定规则生成outfit
outfits = fashion_dataset.generate_outfit(item, param["topk"], param["max_outfits"])
# 根据模型对生成的outfit打分
scores = service.get_result(outfits, prepared_feature)
# 对评分排序拿到最好的topk个outfit输出
sorted_indices = np.argsort(scores)[:param["topk"]] # type-aware
outfits = [outfits[i] for i in sorted_indices] # 最好的五个
best_outfits = [outfits[i] for i in sorted_indices] # 最好的五个
# 结果可视化
# service.visualize(outfits, scores, param["topk"], best=True,
# output_path=os.path.join(r"D:\PhD_Study\MIXI\mitu\image\123",
# f"{item['item_name']}_best_{param['topk']}.png"))
# service.visualize(outfits, scores, param["topk"], best=False,
# output_path=os.path.join(r"D:\PhD_Study\MIXI\mitu\image\123",
# f"{item['item_name']}_worst_{param['topk']}.png"))
service.visualize(outfits, scores, param["topk"], best=True,
output_path=os.path.join(r"D:\PhD_Study\MIXI\mitu\image\123",
f"{item['item_name']}_best_{param['topk']}.png"))
service.visualize(outfits, scores, param["topk"], best=False,
output_path=os.path.join(r"D:\PhD_Study\MIXI\mitu\image\123",
f"{item['item_name']}_worst_{param['topk']}.png"))