1.新增是否推理获取特征判断

2.取消搭配不足异常逻辑
This commit is contained in:
zhouchengrong
2024-04-05 17:45:25 +08:00
parent 010d1536eb
commit 726eee86ab
7 changed files with 332 additions and 110 deletions

View File

@@ -1,57 +1,84 @@
import json
import os
from pprint import pprint
import numpy as np
import time
from copy import deepcopy
from pymilvus import MilvusClient
from app.core.config import *
from app.service.outfit_matcher.dataset import FashionDataset
from app.service.outfit_matcher.outfit_evaluator import OutfitMaterTypeAware, Backbone
logger = logging.getLogger()
if __name__ == '__main__':
with open("./test_param/recommendation_test.json", "r") as f:
param = json.load(f)
fashion_dataset = FashionDataset(param["database"])
with open("./test_param/test.json", "r") as f:
request_item = json.load(f)
request_item = dict(request_item)
for i in range(len(request_item['query'])):
request_item['query'][i] = dict(request_item['query'][i])
for i in range(len(request_item['database'])):
request_item['database'][i] = dict(request_item['database'][i])
fashion_dataset = FashionDataset(request_item['database'])
backbone_service = Backbone()
service = OutfitMaterTypeAware()
# read feature from vector database
all_items = param["query"] + param["database"]
unextracted_item = []
all_items = request_item["query"] + request_item["database"]
prepared_feature = {}
# 拿到所有需要提取特征的图片
for item in all_items:
if f'{item["item_name"]}.npy' not in os.listdir("feature"):
unextracted_item.append(item)
if len(unextracted_item) > 0:
# 通过backbone模型提取图片特征
extracted_features = backbone_service.get_result(unextracted_item)
for i, item in enumerate(unextracted_item):
# save features
# 链接milvus
# TODO
np.save(f'feature/{item["item_name"]}.npy', extracted_features[i])
# 存入数据库
# 关闭链接
# 连接milvus
client = MilvusClient(uri=MILVUS_URL, token="root:Milvus", db_name="mixi")
search_data = client.get(collection_name='mixi_outfit', ids=[item['item_name'] for item in all_items])
# TODO 读取本次任务需要的图片特征
for item in all_items:
if item["item_name"] not in prepared_feature.keys():
prepared_feature[item["item_name"]] = np.load(f'feature/{item["item_name"]}.npy')
# 查询数据库,分成两批 需要过模型推理的和不需要的
have_features_data = []
no_have_features_data = []
temp_data = deepcopy(all_items)
for td in temp_data:
for sd in search_data:
if td['item_name'] == sd['item_name']:
td['features'] = sd['features']
if "features" not in td.keys():
no_have_features_data.append(td)
else:
have_features_data.append(td)
if len(no_have_features_data) > 0:
extracted_features = backbone_service.get_result(no_have_features_data)
# 准备数据
data = deepcopy(no_have_features_data) # 做深拷贝 , all_items 是list 可变数组
for i, feature in enumerate(extracted_features):
data[i]['features'] = feature
if 'mapped_cate' in data[i].keys():
del data[i]['mapped_cate']
# 存入数据
res = client.insert(collection_name="mixi_outfit", data=data)
for d in data:
prepared_feature[d['item_name']] = d['features']
for hfd in have_features_data:
prepared_feature[hfd['item_name']] = hfd['features']
# 断开连接
client.close()
result = []
start_time = time.time()
for item in request_item['query']:
# try:
outfits = fashion_dataset.generate_outfit(item, request_item["topk"], request_item["max_outfits"])
# except ValueError as e:
# logger.warning(e)
# return {"code": 500, "message": f"valueError : {e}", "data": e}
# 开始服装搭配任务
for item in param["query"]:
# 根据一定规则生成outfit
outfits = fashion_dataset.generate_outfit(item, param["topk"], param["max_outfits"])
# 根据模型对生成的outfit打分
scores = service.get_result(outfits, prepared_feature)
# 对评分排序拿到最好的topk个outfit输出
sorted_indices = np.argsort(scores)[:param["topk"]] # type-aware
best_outfits = [outfits[i] for i in sorted_indices] # 最好的五个
# 结果可视化
# service.visualize(outfits, scores, param["topk"], best=True,
# output_path=os.path.join(r"D:\PhD_Study\MIXI\mitu\image\123",
# f"{item['item_name']}_best_{param['topk']}.png"))
# service.visualize(outfits, scores, param["topk"], best=False,
# output_path=os.path.join(r"D:\PhD_Study\MIXI\mitu\image\123",
# f"{item['item_name']}_worst_{param['topk']}.png"))
if request_item['is_best']:
best_outfits, best_scores = service.visualize(outfits, scores, request_item["topk"], best=True,
output_path=rf"E:\workspace\trinity_client_mixi\app\service\outfit_matcher\output_outfit\{item['item_name']}_best_{request_item['topk']}.png"
)
result.append({"outfits": best_outfits, "scores": best_scores})
else:
bad_outfits, bad_scores = service.visualize(outfits, scores, request_item["topk"], best=False,
# output_path=os.path.join(r"E:\workspace\outfit_matcher\2024 SS Outfit", f"{item['item_name']}_worst_{param['topk']}.png")
)
result.append({"outfits": bad_outfits, "scores": bad_scores})
logger.info(f"run time is : {time.time() - start_time}")