Files
sora_python/app/service/outfit_matcher/service.py
2024-03-28 17:22:51 +08:00

58 lines
2.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import json
import os
from pprint import pprint
import numpy as np
from app.service.outfit_matcher.dataset import FashionDataset
from app.service.outfit_matcher.outfit_evaluator import OutfitMaterTypeAware, Backbone
if __name__ == '__main__':
    # Offline driver for the outfit-recommendation pipeline:
    #   1. load the test parameters (query items, database items, topk, ...),
    #   2. make sure every item has a cached feature file under `feature_dir`,
    #      extracting the missing ones with the backbone model,
    #   3. load all cached features into memory,
    #   4. for each query item, generate candidate outfits, score them and
    #      keep the `topk` entries selected by the ranking below.
    feature_dir = "feature"

    # JSON is UTF-8 by specification; do not depend on the platform's default
    # text encoding when reading the parameter file.
    with open("./test_param/recommendation_test.json", "r", encoding="utf-8") as f:
        param = json.load(f)

    fashion_dataset = FashionDataset(param["database"])
    backbone_service = Backbone()
    service = OutfitMaterTypeAware()

    # Read features from the vector database (currently a directory of .npy files).
    all_items = param["query"] + param["database"]
    prepared_feature = {}

    # Collect every item whose feature has not been extracted yet.
    # The cache directory is created on first run and listed exactly once;
    # nothing is written to it while this list is being built, so a single
    # snapshot is equivalent to re-listing the directory for every item.
    os.makedirs(feature_dir, exist_ok=True)
    cached_files = set(os.listdir(feature_dir))
    unextracted_item = [
        item for item in all_items
        if f'{item["item_name"]}.npy' not in cached_files
    ]

    if unextracted_item:
        # Extract image features with the backbone model.
        extracted_features = backbone_service.get_result(unextracted_item)
        for i, item in enumerate(unextracted_item):
            # Save features.
            # TODO: connect to Milvus, store the feature in the database,
            #       then close the connection.
            # Relies on extracted_features[i] belonging to unextracted_item[i].
            np.save(
                os.path.join(feature_dir, f'{item["item_name"]}.npy'),
                extracted_features[i],
            )

    # TODO: load only the image features required by the current task.
    for item in all_items:
        item_name = item["item_name"]
        if item_name not in prepared_feature:
            prepared_feature[item_name] = np.load(
                os.path.join(feature_dir, f'{item_name}.npy')
            )

    # Start the outfit-matching task.
    for item in param["query"]:
        # Generate candidate outfits according to the dataset's rules.
        outfits = fashion_dataset.generate_outfit(item, param["topk"], param["max_outfits"])
        # Score the generated outfits with the model.
        scores = service.get_result(outfits, prepared_feature)
        # Rank the scores and keep the first `topk` outfits.
        # np.argsort sorts ascending, so these are the LOWEST-scoring outfits.
        # NOTE(review): presumably a lower type-aware score means a better
        # match (the original comment calls these "the best") - confirm.
        sorted_indices = np.argsort(scores)[:param["topk"]]  # type-aware
        best_outfits = [outfits[i] for i in sorted_indices]
        # Result visualisation (disabled):
        # service.visualize(outfits, scores, param["topk"], best=True,
        #                   output_path=os.path.join(r"D:\PhD_Study\MIXI\mitu\image\123",
        #                                            f"{item['item_name']}_best_{param['topk']}.png"))
        # service.visualize(outfits, scores, param["topk"], best=False,
        #                   output_path=os.path.join(r"D:\PhD_Study\MIXI\mitu\image\123",
        #                                            f"{item['item_name']}_worst_{param['topk']}.png"))