fix 代码整理

This commit is contained in:
zhouchengrong
2024-04-10 14:33:06 +08:00
parent 49f5e0a4b5
commit f06f22f092
5 changed files with 23 additions and 17 deletions

View File

@@ -16,9 +16,9 @@ logger = logging.getLogger()
router = APIRouter() router = APIRouter()
@RunTime
@router.post("outfit_matcher") @router.post("outfit_matcher")
def outfit_matcher(request_item: OutfitMatcher): def outfit_matcher(request_item: OutfitMatcher):
start_time = time.time()
request_item = dict(request_item) request_item = dict(request_item)
for i in range(len(request_item['query'])): for i in range(len(request_item['query'])):
request_item['query'][i] = dict(request_item['query'][i]) request_item['query'][i] = dict(request_item['query'][i])
@@ -69,7 +69,6 @@ def outfit_matcher(request_item: OutfitMatcher):
prepared_feature[hfd['item_name']] = hfd['features'] prepared_feature[hfd['item_name']] = hfd['features']
result = [] result = []
start_time = time.time()
for item in request_item['query']: for item in request_item['query']:
# try: # try:
outfits = fashion_dataset.generate_outfit(item, request_item["topk"], request_item["max_outfits"]) outfits = fashion_dataset.generate_outfit(item, request_item["topk"], request_item["max_outfits"])

View File

@@ -1,8 +1,12 @@
import io
import logging import logging
import time import time
from PIL import Image
from fastapi import APIRouter from fastapi import APIRouter
from matplotlib import pyplot as plt
from app.core.config import SIMILAR_MATCH_DRAW
from app.schemas.similar_match import SimilarMatchMItem from app.schemas.similar_match import SimilarMatchMItem
from app.service.similar_match.service import SimilarMatch from app.service.similar_match.service import SimilarMatch
from app.service.utils.decorator import RunTime from app.service.utils.decorator import RunTime
@@ -22,6 +26,13 @@ def similar_match(request_item: SimilarMatchMItem):
response_data = [] response_data = []
for response in search_response[0]: for response in search_response[0]:
response_data.append(response['entity']) response_data.append(response['entity'])
if SIMILAR_MATCH_DRAW:
resource_image = service.load_image(request_item.image_path)
similar_diagram_list = [service.load_image(image_url['image_path']) for image_url in response_data]
resource_image.save("similar_match/3/resource.png")
for i, image in enumerate(similar_diagram_list):
image.save(f"similar_match/3/{i}.png")
return {"message": "ok", "data": response_data} return {"message": "ok", "data": response_data}
except KeyError as e: except KeyError as e:
logger.warning(str(e)) logger.warning(str(e))

View File

@@ -39,7 +39,7 @@ DEBUG = 1
SHOW_OR_SAVE_result_image = False SHOW_OR_SAVE_result_image = False
# service env : 1 # service env : 1
# pycharm debug : 2 # pycharm debug : 2
SIMILAR_MATCH_DRAW = False
if DEBUG == 1: if DEBUG == 1:
LOGS_PATH = "app/logs/errors.log" LOGS_PATH = "app/logs/errors.log"
FASHION_CATEGORIES = "app/service/outfit_matcher/config/fashion_categories.json" FASHION_CATEGORIES = "app/service/outfit_matcher/config/fashion_categories.json"

View File

@@ -23,7 +23,6 @@ class Backbone(object):
secure=MINIO_SECURE) secure=MINIO_SECURE)
@RunTime @RunTime
# TODO 用多线程读图片
def load_image(self, img_path): def load_image(self, img_path):
try: try:
# 从 MinIO 中获取对象(图像文件) # 从 MinIO 中获取对象(图像文件)

View File

@@ -55,33 +55,30 @@ class SimilarMatch:
def preprocess(self, img_path): def preprocess(self, img_path):
image = self.load_image(img_path) image = self.load_image(img_path)
image = self.resize_image(image) image = self.resize_image(image)
image = np.stack([[image]], axis=0) image = np.stack([image], axis=0)
category = np.stack([[1, 6]], axis=0) # category = np.stack([[1, 6]], axis=0)
mask = np.zeros((1, 1), dtype=np.float32) # mask = np.zeros((1, 1), dtype=np.float32)
return image, category, mask return image
# , category, mask)
def get_features(self): def get_features(self):
image, category, mask = self.preprocess(self.image_path) image = self.preprocess(self.image_path)
# image, category, mask = self.preprocess(self.image_path)
# 输入集 # 输入集
inputs = [ inputs = [
httpclient.InferInput("input__0", image.shape, datatype="FP32"), httpclient.InferInput("input__0", image.shape, datatype="FP32"),
httpclient.InferInput("input__1", category.shape, datatype="INT16"),
httpclient.InferInput("input__2", mask.shape, datatype="FP32"),
] ]
inputs[0].set_data_from_numpy(image.astype(np.float32), binary_data=True) inputs[0].set_data_from_numpy(image.astype(np.float32), binary_data=True)
inputs[1].set_data_from_numpy(category.astype(np.int16), binary_data=True)
inputs[2].set_data_from_numpy(mask.astype(np.float32), binary_data=True)
# 输出集 # 输出集
outputs = [ outputs = [
httpclient.InferRequestedOutput("output__0", binary_data=True), httpclient.InferRequestedOutput("output__0", binary_data=True),
httpclient.InferRequestedOutput("output__1", binary_data=True)
] ]
results = self.triton_client.infer(model_name="outfit_matcher_type_aware", inputs=inputs, outputs=outputs) results = self.triton_client.infer(model_name="outfit_matcher_backbone", inputs=inputs, outputs=outputs)
# 推理 # 推理
# 取结果 # 取结果
features = results.as_numpy("output__1") # Shape (N, 64) features = results.as_numpy("output__0") # Shape (N, 64)
return features return features
@RunTime @RunTime
@@ -94,7 +91,7 @@ class SimilarMatch:
# Replace with your query vector # Replace with your query vector
data=[self.features[0]], data=[self.features[0]],
limit=self.result_number, # Max. number of search results to return limit=self.result_number, # Max. number of search results to return
output_fields=["id", "image_path"], # Search parameters output_fields=["item_name", "image_path"], # Search parameters
) )
return search_response return search_response
finally: finally: