From f06f22f0926c7a4f2fe1f840e60e4a81c95a62c6 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Wed, 10 Apr 2024 14:33:06 +0800 Subject: [PATCH] =?UTF-8?q?fix=20=E4=BB=A3=E7=A0=81=E6=95=B4=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/api_outfit_matcher.py | 3 +-- app/api/api_similar_match.py | 11 +++++++++ app/core/config.py | 2 +- .../outfit_matcher/outfit_evaluator.py | 1 - app/service/similar_match/service.py | 23 ++++++++----------- 5 files changed, 23 insertions(+), 17 deletions(-) diff --git a/app/api/api_outfit_matcher.py b/app/api/api_outfit_matcher.py index 9bef932..c9cfaf9 100644 --- a/app/api/api_outfit_matcher.py +++ b/app/api/api_outfit_matcher.py @@ -16,9 +16,9 @@ logger = logging.getLogger() router = APIRouter() -@RunTime @router.post("outfit_matcher") def outfit_matcher(request_item: OutfitMatcher): + start_time = time.time() request_item = dict(request_item) for i in range(len(request_item['query'])): request_item['query'][i] = dict(request_item['query'][i]) @@ -69,7 +69,6 @@ def outfit_matcher(request_item: OutfitMatcher): prepared_feature[hfd['item_name']] = hfd['features'] result = [] - start_time = time.time() for item in request_item['query']: # try: outfits = fashion_dataset.generate_outfit(item, request_item["topk"], request_item["max_outfits"]) diff --git a/app/api/api_similar_match.py b/app/api/api_similar_match.py index 28c0928..4f3ffe2 100644 --- a/app/api/api_similar_match.py +++ b/app/api/api_similar_match.py @@ -1,8 +1,12 @@ +import io import logging import time +from PIL import Image from fastapi import APIRouter +from matplotlib import pyplot as plt +from app.core.config import SIMILAR_MATCH_DRAW from app.schemas.similar_match import SimilarMatchMItem from app.service.similar_match.service import SimilarMatch from app.service.utils.decorator import RunTime @@ -22,6 +26,13 @@ def similar_match(request_item: SimilarMatchMItem): response_data = [] for response in search_response[0]: response_data.append(response['entity']) + + if SIMILAR_MATCH_DRAW: + resource_image = service.load_image(request_item.image_path) + similar_diagram_list = [service.load_image(image_url['image_path']) for image_url in response_data] + resource_image.save("similar_match/3/resource.png") + for i, image in enumerate(similar_diagram_list): + image.save(f"similar_match/3/{i}.png") return {"message": "ok", "data": response_data} except KeyError as e: logger.warning(str(e)) diff --git a/app/core/config.py b/app/core/config.py index 39edcce..bafbbd0 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -39,7 +39,7 @@ DEBUG = 1 SHOW_OR_SAVE_result_image = False # service env : 1 # pycharm debug : 2 - +SIMILAR_MATCH_DRAW = False if DEBUG == 1: LOGS_PATH = "app/logs/errors.log" FASHION_CATEGORIES = "app/service/outfit_matcher/config/fashion_categories.json" diff --git a/app/service/outfit_matcher/outfit_evaluator.py b/app/service/outfit_matcher/outfit_evaluator.py index 780cb69..00589c6 100644 --- a/app/service/outfit_matcher/outfit_evaluator.py +++ b/app/service/outfit_matcher/outfit_evaluator.py @@ -23,7 +23,6 @@ class Backbone(object): secure=MINIO_SECURE) @RunTime - # TODO 用多线程读图片 def load_image(self, img_path): try: # 从 MinIO 中获取对象(图像文件) diff --git a/app/service/similar_match/service.py b/app/service/similar_match/service.py index c5d3af7..367ff85 100644 --- a/app/service/similar_match/service.py +++ b/app/service/similar_match/service.py @@ -55,33 +55,30 @@ class SimilarMatch: def preprocess(self, img_path): image = self.load_image(img_path) image = self.resize_image(image) - image = np.stack([[image]], axis=0) + image = np.stack([image], axis=0) - category = np.stack([[1, 6]], axis=0) + # category = np.stack([[1, 6]], axis=0) - mask = np.zeros((1, 1), dtype=np.float32) - return image, category, mask + # mask = np.zeros((1, 1), dtype=np.float32) + return image + # , category, mask) def get_features(self): - image, category, mask = self.preprocess(self.image_path) + image = self.preprocess(self.image_path) + # image, category, mask = self.preprocess(self.image_path) # 输入集 inputs = [ httpclient.InferInput("input__0", image.shape, datatype="FP32"), - httpclient.InferInput("input__1", category.shape, datatype="INT16"), - httpclient.InferInput("input__2", mask.shape, datatype="FP32"), ] inputs[0].set_data_from_numpy(image.astype(np.float32), binary_data=True) - inputs[1].set_data_from_numpy(category.astype(np.int16), binary_data=True) - inputs[2].set_data_from_numpy(mask.astype(np.float32), binary_data=True) # 输出集 outputs = [ httpclient.InferRequestedOutput("output__0", binary_data=True), - httpclient.InferRequestedOutput("output__1", binary_data=True) ] - results = self.triton_client.infer(model_name="outfit_matcher_type_aware", inputs=inputs, outputs=outputs) + results = self.triton_client.infer(model_name="outfit_matcher_backbone", inputs=inputs, outputs=outputs) # 推理 # 取结果 - features = results.as_numpy("output__1") # Shape (N, 64) + features = results.as_numpy("output__0") # Shape (N, 64) return features @RunTime @@ -94,7 +91,7 @@ class SimilarMatch: # Replace with your query vector data=[self.features[0]], limit=self.result_number, # Max. number of search results to return - output_fields=["id", "image_path"], # Search parameters + output_fields=["item_name", "image_path"], # Search parameters ) return search_response finally: