搭配服务修改

This commit is contained in:
zhouchengrong
2024-03-28 17:22:51 +08:00
parent 39dae92ea0
commit eb9351dc87
5 changed files with 117 additions and 79 deletions

View File

@@ -15,5 +15,4 @@ router = APIRouter()
def attribute(request_data: AttributeModel):
service = AttributeRecognition()
response = service.attribute(const, request_data)
logger.info("test")
return {"code": 200, "message": "ok", "data": response}

View File

@@ -1,10 +1,13 @@
import logging
import time
from copy import deepcopy
from fastapi import APIRouter
from pymilvus import MilvusClient
from app.schemas.outfit_matcher import OutfitMatcher
from app.service.outfit_matcher.dataset import FashionDataset
from app.service.outfit_matcher.outfit_evaluator import OutfitMaterTypeAware
from app.service.outfit_matcher.outfit_evaluator import OutfitMaterTypeAware, Backbone
from app.service.utils.decorator import RunTime
logger = logging.getLogger()
@@ -19,16 +22,28 @@ def outfit_matcher(request_item: OutfitMatcher):
request_item['query'][i] = dict(request_item['query'][i])
for i in range(len(request_item['database'])):
request_item['database'][i] = dict(request_item['database'][i])
# try:
fashion_dataset = FashionDataset(request_item['database'])
backbone_service = Backbone()
service = OutfitMaterTypeAware()
all_items = request_item["query"] + request_item["database"]
prepared_feature = {}
extracted_features = backbone_service.get_result(all_items)
data = deepcopy(all_items) # 做深拷贝 , all_items 是list 可变数组
for i, feature in enumerate(extracted_features):
data[i]['features'] = feature
if 'mapped_cate' in data[i].keys():
del data[i]['mapped_cate']
client = MilvusClient(uri="http://10.1.1.240:19530", token="root:Milvus", db_name="mixi")
res = client.insert(collection_name="mixi_outfit", data=data)
client.close()
for d in data:
prepared_feature[d['item_name']] = d['features']
result = []
start_time = time.time()
for item in request_item['query']:
outfits = fashion_dataset.generate_outfit(item, request_item["topk"], request_item["max_outfits"])
scores, features = service.get_result(outfits)
# save features in databases
scores = service.get_result(outfits, prepared_feature)
if request_item['is_best']:
best_outfits, best_scores = service.visualize(outfits, scores, request_item["topk"], best=True,
@@ -44,4 +59,4 @@ def outfit_matcher(request_item: OutfitMatcher):
return {"message": "ok", "data": result}
# except Exception as e:
# logger.warning(e)
# return {"message": f"{e}", "data": e}
# return {"message": f"{e}", "data": e}