diff --git a/app/service/outfit_matcher/outfit_evaluator.py b/app/service/outfit_matcher/outfit_evaluator.py index ff8f9af..311d597 100644 --- a/app/service/outfit_matcher/outfit_evaluator.py +++ b/app/service/outfit_matcher/outfit_evaluator.py @@ -327,9 +327,11 @@ class OutfitMaterTypeAware(OutfitMatcher): # 输出集 outputs = [ httpclient.InferRequestedOutput("output__0", binary_data=True), + httpclient.InferRequestedOutput("output__1", binary_data=True), ] results = client.infer(model_name="outfit_matcher_type_aware", inputs=inputs, outputs=outputs) # 推理 # 取结果 - scores = torch.from_numpy(results.as_numpy("output__0")) - return scores # Shape (N, 1) + scores = torch.from_numpy(results.as_numpy("output__0")) # Shape (N, 1) + features = torch.from_numpy(results.as_numpy("output__1")) # Shape (N, -1, 64) + return scores, features