Files
sora_python/app/api/api_outfit_matcher.py
2024-10-22 15:11:08 +08:00

97 lines
4.1 KiB
Python

import logging
import os
import time
from copy import deepcopy
from fastapi import APIRouter
from pymilvus import MilvusClient
from app.core.config import MILVUS_URL
from app.schemas.outfit_matcher import OutfitMatcher
from app.service.outfit_matcher.dataset import FashionDataset
from app.service.outfit_matcher.outfit_evaluator import OutfitMaterTypeAware, Backbone
from app.service.utils.decorator import RunTime
# NOTE(review): this grabs the ROOT logger; the usual convention is
# logging.getLogger(__name__) so records carry the module name — changing it
# would alter which handlers fire, so it is only flagged here, not changed.
logger = logging.getLogger()
# Router is mounted by the application elsewhere; this module only registers
# the /outfit_matcher POST endpoint on it.
router = APIRouter()
@router.post("/outfit_matcher")
def outfit_matcher(request_item: OutfitMatcher):
    """Generate outfit recommendations for every query item.

    Flow:
      1. Fetch cached feature vectors from Milvus for all query + database
         items (keyed by ``item_name``).
      2. Run the backbone model only on items whose features are not cached,
         and insert the newly extracted features back into Milvus.
      3. For each query item, generate candidate outfits, score them, and
         return the top-k best (or worst, when ``is_best`` is false) outfits.

    Args:
        request_item: OutfitMatcher payload with ``query``, ``database``,
            ``topk``, ``max_outfits`` and ``is_best`` fields.

    Returns:
        ``{"code": 200, "message": "ok", "data": [...]}`` where ``data`` has
        one ``{"outfits": ..., "scores": ...}`` entry per query item.
    """
    start_time = time.time()

    # Convert the pydantic model (and its nested item models) to plain dicts.
    request_item = dict(request_item)
    request_item['query'] = [dict(item) for item in request_item['query']]
    request_item['database'] = [dict(item) for item in request_item['database']]

    fashion_dataset = FashionDataset(request_item['database'])
    backbone_service = Backbone()
    service = OutfitMaterTypeAware()

    all_items = request_item["query"] + request_item["database"]
    prepared_feature = {}
    have_features_data = []
    no_have_features_data = []
    # Deep copy so feature assignment below never mutates the request dicts.
    temp_data = deepcopy(all_items)

    # Connect to Milvus.
    # NOTE(review): credentials are hard-coded; move the token into
    # configuration/environment alongside MILVUS_URL.
    client = MilvusClient(uri=MILVUS_URL, token="root:Milvus", db_name="mixi")
    try:
        search_data = client.get(
            collection_name='mixi_outfit',
            ids=[item['item_name'] for item in all_items],
        )
        # Split items into two batches: those with cached features and those
        # that must go through model inference. A dict lookup replaces the
        # original O(n*m) nested scan; building it in order preserves the
        # original "last matching record wins" semantics.
        features_by_name = {sd['item_name']: sd['features'] for sd in search_data}
        for td in temp_data:
            if td['item_name'] in features_by_name:
                td['features'] = features_by_name[td['item_name']]
                have_features_data.append(td)
            else:
                no_have_features_data.append(td)

        if no_have_features_data:
            extracted_features = backbone_service.get_result(no_have_features_data)
            # Deep copy before mutating: the entries are shared with temp_data.
            data = deepcopy(no_have_features_data)
            for i, feature in enumerate(extracted_features):
                data[i]['features'] = feature
                # Milvus schema has no 'mapped_cate' field; drop it if present.
                data[i].pop('mapped_cate', None)
            # Persist the freshly extracted features for future requests.
            client.insert(collection_name="mixi_outfit", data=data)
            for d in data:
                prepared_feature[d['item_name']] = d['features']
    finally:
        client.close()

    for hfd in have_features_data:
        prepared_feature[hfd['item_name']] = hfd['features']

    result = []
    for item in request_item['query']:
        # NOTE(review): generate_outfit may raise ValueError (a handler was
        # previously commented out here); it currently propagates as an
        # HTTP 500, which matches the existing behavior.
        outfits = fashion_dataset.generate_outfit(
            item, request_item["topk"], request_item["max_outfits"]
        )
        scores = service.get_result(outfits, prepared_feature)
        # 'is_best' selects whether the top-k best or worst outfits are kept;
        # both branches of the original code were otherwise identical.
        picked_outfits, picked_scores = service.visualize(
            outfits, scores, request_item["topk"], best=request_item['is_best']
        )
        result.append({"outfits": picked_outfits, "scores": picked_scores})

    # Lazy %-style args avoid formatting work when the level is disabled.
    logger.info("run time is : %s", time.time() - start_time)
    logger.info({"code": 200, "message": "ok", "data": result})
    return {"code": 200, "message": "ok", "data": result}