关于保存特征的一些代码

This commit is contained in:
zchen
2024-03-26 18:00:35 +08:00
parent 7dce70027a
commit 88c815dd9d
16 changed files with 378 additions and 10 deletions

View File

@@ -327,9 +327,12 @@ class OutfitMaterTypeAware(OutfitMatcher):
# 输出集
outputs = [
httpclient.InferRequestedOutput("output__0", binary_data=True),
httpclient.InferRequestedOutput("output__1", binary_data=True)
]
results = client.infer(model_name="outfit_matcher_type_aware", inputs=inputs, outputs=outputs)
# 推理
# 取结果
scores = torch.from_numpy(results.as_numpy("output__0"))
return scores # Shape (N, 1)
scores = torch.from_numpy(results.as_numpy("output__0")) # Shape (N, 1)
features = torch.from_numpy(results.as_numpy("output__1")) # Shape (N, 64)
return scores, features

View File

@@ -14,7 +14,14 @@ if __name__ == '__main__':
bad_list = []
for item in param["query"]:
outfits = fashion_dataset.generate_outfit(item, param["topk"], param["max_outfits"])
scores = service.get_result(outfits)
scores, features = service.get_result(outfits)
# save features
# 链接milvus
# 存入数据库
# 关闭链接
# print(scores)
# print(len(scores))
best_outfits, best_scores = service.visualize(outfits, scores, param["topk"], best=True,