import logging
import time
from copy import deepcopy

from fastapi import APIRouter
from pymilvus import MilvusClient

from app.core.config import MILVUS_URL
from app.schemas.outfit_matcher import OutfitMatcher
from app.service.outfit_matcher.dataset import FashionDataset
from app.service.outfit_matcher.outfit_evaluator import OutfitMaterTypeAware, Backbone
from app.service.utils.decorator import RunTime

logger = logging.getLogger()
router = APIRouter()


# NOTE(review): ``@router.post`` is applied before ``@RunTime``, so the router
# registers the undecorated function and ``RunTime`` only wraps the
# module-level name. Presumably HTTP requests are therefore not timed by
# ``RunTime`` -- confirm how ``RunTime`` handles signatures before reordering.
@RunTime
@router.post("outfit_matcher")
def outfit_matcher(request_item: OutfitMatcher):
    """Build and score outfits for every query item of the request.

    Steps, in order:

    1. Convert the request and each entry of its ``query`` / ``database``
       lists to plain dicts.
    2. Extract features for all items (query items first, then database
       items) with ``Backbone.get_result``.
    3. Insert a deep copy of the items, with a ``features`` field added and
       ``mapped_cate`` removed, into the ``mixi_outfit`` Milvus collection.
    4. For each query item, generate candidate outfits, score them and keep
       the ``topk`` best (``is_best`` truthy) or worst (``is_best`` falsy)
       ones as returned by ``OutfitMaterTypeAware.visualize``.

    Args:
        request_item: Request payload. The fields read here are ``query``,
            ``database``, ``topk``, ``max_outfits`` and ``is_best``.

    Returns:
        ``{"code": 200, "message": "ok", "data": [...]}`` with one
        ``{"outfits": ..., "scores": ...}`` entry per query item, or
        ``{"code": 500, "message": ..., "data": <error text>}`` as soon as
        outfit generation raises ``ValueError`` for a query item.
    """
    request_item = dict(request_item)
    request_item['query'] = [dict(entry) for entry in request_item['query']]
    request_item['database'] = [dict(entry) for entry in request_item['database']]

    fashion_dataset = FashionDataset(request_item['database'])
    backbone_service = Backbone()
    service = OutfitMaterTypeAware()

    all_items = request_item["query"] + request_item["database"]
    extracted_features = backbone_service.get_result(all_items)

    # Deep copy: all_items is a list of mutable dicts, and the request items
    # themselves must not gain a 'features' key or lose 'mapped_cate'.
    data = deepcopy(all_items)
    for i, feature in enumerate(extracted_features):
        data[i]['features'] = feature
        data[i].pop('mapped_cate', None)

    # NOTE(review): the Milvus token is hard-coded here; consider loading it
    # from configuration alongside MILVUS_URL.
    client = MilvusClient(uri=MILVUS_URL, token="root:Milvus", db_name="mixi")
    try:
        client.insert(collection_name="mixi_outfit", data=data)
    finally:
        # Always release the connection, even when the insert fails.
        client.close()

    prepared_feature = {d['item_name']: d['features'] for d in data}

    result = []
    start_time = time.time()
    for item in request_item['query']:
        try:
            outfits = fashion_dataset.generate_outfit(
                item, request_item["topk"], request_item["max_outfits"]
            )
        except ValueError as e:
            logger.warning(e)
            # Return the error text rather than the exception object, which
            # is not a meaningful JSON value.
            return {"code": 500, "message": f"valueError : {e}", "data": str(e)}

        scores = service.get_result(outfits, prepared_feature)
        # is_best selects between the top-scoring and the lowest-scoring outfits.
        picked_outfits, picked_scores = service.visualize(
            outfits,
            scores,
            request_item["topk"],
            best=bool(request_item['is_best']),
        )
        result.append({"outfits": picked_outfits, "scores": picked_scores})

    logger.info(f"run time is : {time.time() - start_time}")
    return {"code": 200, "message": "ok", "data": result}