fix 代码整理
This commit is contained in:
@@ -16,9 +16,9 @@ logger = logging.getLogger()
|
|||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
@RunTime
|
|
||||||
@router.post("outfit_matcher")
|
@router.post("outfit_matcher")
|
||||||
def outfit_matcher(request_item: OutfitMatcher):
|
def outfit_matcher(request_item: OutfitMatcher):
|
||||||
|
start_time = time.time()
|
||||||
request_item = dict(request_item)
|
request_item = dict(request_item)
|
||||||
for i in range(len(request_item['query'])):
|
for i in range(len(request_item['query'])):
|
||||||
request_item['query'][i] = dict(request_item['query'][i])
|
request_item['query'][i] = dict(request_item['query'][i])
|
||||||
@@ -69,7 +69,6 @@ def outfit_matcher(request_item: OutfitMatcher):
|
|||||||
prepared_feature[hfd['item_name']] = hfd['features']
|
prepared_feature[hfd['item_name']] = hfd['features']
|
||||||
|
|
||||||
result = []
|
result = []
|
||||||
start_time = time.time()
|
|
||||||
for item in request_item['query']:
|
for item in request_item['query']:
|
||||||
# try:
|
# try:
|
||||||
outfits = fashion_dataset.generate_outfit(item, request_item["topk"], request_item["max_outfits"])
|
outfits = fashion_dataset.generate_outfit(item, request_item["topk"], request_item["max_outfits"])
|
||||||
|
|||||||
@@ -1,8 +1,12 @@
|
|||||||
|
import io
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
|
|
||||||
|
from PIL import Image
|
||||||
from fastapi import APIRouter
|
from fastapi import APIRouter
|
||||||
|
from matplotlib import pyplot as plt
|
||||||
|
|
||||||
|
from app.core.config import SIMILAR_MATCH_DRAW
|
||||||
from app.schemas.similar_match import SimilarMatchMItem
|
from app.schemas.similar_match import SimilarMatchMItem
|
||||||
from app.service.similar_match.service import SimilarMatch
|
from app.service.similar_match.service import SimilarMatch
|
||||||
from app.service.utils.decorator import RunTime
|
from app.service.utils.decorator import RunTime
|
||||||
@@ -22,6 +26,13 @@ def similar_match(request_item: SimilarMatchMItem):
|
|||||||
response_data = []
|
response_data = []
|
||||||
for response in search_response[0]:
|
for response in search_response[0]:
|
||||||
response_data.append(response['entity'])
|
response_data.append(response['entity'])
|
||||||
|
|
||||||
|
if SIMILAR_MATCH_DRAW:
|
||||||
|
resource_image = service.load_image(request_item.image_path)
|
||||||
|
similar_diagram_list = [service.load_image(image_url['image_path']) for image_url in response_data]
|
||||||
|
resource_image.save("similar_match/3/resource.png")
|
||||||
|
for i, image in enumerate(similar_diagram_list):
|
||||||
|
image.save(f"similar_match/3/{i}.png")
|
||||||
return {"message": "ok", "data": response_data}
|
return {"message": "ok", "data": response_data}
|
||||||
except KeyError as e:
|
except KeyError as e:
|
||||||
logger.warning(str(e))
|
logger.warning(str(e))
|
||||||
|
|||||||
@@ -39,7 +39,7 @@ DEBUG = 1
|
|||||||
SHOW_OR_SAVE_result_image = False
|
SHOW_OR_SAVE_result_image = False
|
||||||
# service env : 1
|
# service env : 1
|
||||||
# pycharm debug : 2
|
# pycharm debug : 2
|
||||||
|
SIMILAR_MATCH_DRAW = False
|
||||||
if DEBUG == 1:
|
if DEBUG == 1:
|
||||||
LOGS_PATH = "app/logs/errors.log"
|
LOGS_PATH = "app/logs/errors.log"
|
||||||
FASHION_CATEGORIES = "app/service/outfit_matcher/config/fashion_categories.json"
|
FASHION_CATEGORIES = "app/service/outfit_matcher/config/fashion_categories.json"
|
||||||
|
|||||||
@@ -23,7 +23,6 @@ class Backbone(object):
|
|||||||
secure=MINIO_SECURE)
|
secure=MINIO_SECURE)
|
||||||
|
|
||||||
@RunTime
|
@RunTime
|
||||||
# TODO 用多线程读图片
|
|
||||||
def load_image(self, img_path):
|
def load_image(self, img_path):
|
||||||
try:
|
try:
|
||||||
# 从 MinIO 中获取对象(图像文件)
|
# 从 MinIO 中获取对象(图像文件)
|
||||||
|
|||||||
@@ -55,33 +55,30 @@ class SimilarMatch:
|
|||||||
def preprocess(self, img_path):
|
def preprocess(self, img_path):
|
||||||
image = self.load_image(img_path)
|
image = self.load_image(img_path)
|
||||||
image = self.resize_image(image)
|
image = self.resize_image(image)
|
||||||
image = np.stack([[image]], axis=0)
|
image = np.stack([image], axis=0)
|
||||||
|
|
||||||
category = np.stack([[1, 6]], axis=0)
|
# category = np.stack([[1, 6]], axis=0)
|
||||||
|
|
||||||
mask = np.zeros((1, 1), dtype=np.float32)
|
# mask = np.zeros((1, 1), dtype=np.float32)
|
||||||
return image, category, mask
|
return image
|
||||||
|
# , category, mask)
|
||||||
|
|
||||||
def get_features(self):
|
def get_features(self):
|
||||||
image, category, mask = self.preprocess(self.image_path)
|
image = self.preprocess(self.image_path)
|
||||||
|
# image, category, mask = self.preprocess(self.image_path)
|
||||||
# 输入集
|
# 输入集
|
||||||
inputs = [
|
inputs = [
|
||||||
httpclient.InferInput("input__0", image.shape, datatype="FP32"),
|
httpclient.InferInput("input__0", image.shape, datatype="FP32"),
|
||||||
httpclient.InferInput("input__1", category.shape, datatype="INT16"),
|
|
||||||
httpclient.InferInput("input__2", mask.shape, datatype="FP32"),
|
|
||||||
]
|
]
|
||||||
inputs[0].set_data_from_numpy(image.astype(np.float32), binary_data=True)
|
inputs[0].set_data_from_numpy(image.astype(np.float32), binary_data=True)
|
||||||
inputs[1].set_data_from_numpy(category.astype(np.int16), binary_data=True)
|
|
||||||
inputs[2].set_data_from_numpy(mask.astype(np.float32), binary_data=True)
|
|
||||||
# 输出集
|
# 输出集
|
||||||
outputs = [
|
outputs = [
|
||||||
httpclient.InferRequestedOutput("output__0", binary_data=True),
|
httpclient.InferRequestedOutput("output__0", binary_data=True),
|
||||||
httpclient.InferRequestedOutput("output__1", binary_data=True)
|
|
||||||
]
|
]
|
||||||
results = self.triton_client.infer(model_name="outfit_matcher_type_aware", inputs=inputs, outputs=outputs)
|
results = self.triton_client.infer(model_name="outfit_matcher_backbone", inputs=inputs, outputs=outputs)
|
||||||
# 推理
|
# 推理
|
||||||
# 取结果
|
# 取结果
|
||||||
features = results.as_numpy("output__1") # Shape (N, 64)
|
features = results.as_numpy("output__0") # Shape (N, 64)
|
||||||
return features
|
return features
|
||||||
|
|
||||||
@RunTime
|
@RunTime
|
||||||
@@ -94,7 +91,7 @@ class SimilarMatch:
|
|||||||
# Replace with your query vector
|
# Replace with your query vector
|
||||||
data=[self.features[0]],
|
data=[self.features[0]],
|
||||||
limit=self.result_number, # Max. number of search results to return
|
limit=self.result_number, # Max. number of search results to return
|
||||||
output_fields=["id", "image_path"], # Search parameters
|
output_fields=["item_name", "image_path"], # Search parameters
|
||||||
)
|
)
|
||||||
return search_response
|
return search_response
|
||||||
finally:
|
finally:
|
||||||
|
|||||||
Reference in New Issue
Block a user