关于保存特征的一些代码
This commit is contained in:
@@ -327,9 +327,12 @@ class OutfitMaterTypeAware(OutfitMatcher):
|
||||
# 输出集
|
||||
outputs = [
|
||||
httpclient.InferRequestedOutput("output__0", binary_data=True),
|
||||
httpclient.InferRequestedOutput("output__1", binary_data=True)
|
||||
]
|
||||
results = client.infer(model_name="outfit_matcher_type_aware", inputs=inputs, outputs=outputs)
|
||||
# 推理
|
||||
# 取结果
|
||||
scores = torch.from_numpy(results.as_numpy("output__0"))
|
||||
return scores # Shape (N, 1)
|
||||
scores = torch.from_numpy(results.as_numpy("output__0")) # Shape (N, 1)
|
||||
features = torch.from_numpy(results.as_numpy("output__1")) # Shape (N, 64)
|
||||
|
||||
return scores, features
|
||||
|
||||
@@ -14,7 +14,14 @@ if __name__ == '__main__':
|
||||
bad_list = []
|
||||
for item in param["query"]:
|
||||
outfits = fashion_dataset.generate_outfit(item, param["topk"], param["max_outfits"])
|
||||
scores = service.get_result(outfits)
|
||||
scores, features = service.get_result(outfits)
|
||||
# save features
|
||||
|
||||
# 链接milvus
|
||||
|
||||
# 存入数据库
|
||||
# 关闭链接
|
||||
|
||||
# print(scores)
|
||||
# print(len(scores))
|
||||
best_outfits, best_scores = service.visualize(outfits, scores, param["topk"], best=True,
|
||||
|
||||
Reference in New Issue
Block a user