"""Local driver for the outfit matcher.

Loads a request from ``./test_param/test.json``, makes sure every query and
database item has a feature vector (reusing vectors stored in the Milvus
``mixi_outfit`` collection and running the backbone model for the rest, whose
vectors are then inserted into that collection), scores the generated outfits
for each query item and renders either the best or the worst ones.
"""
import json
import logging
import os
import time
from copy import deepcopy

from pymilvus import MilvusClient

from app.core.config import *
from app.service.outfit_matcher.dataset import FashionDataset
from app.service.outfit_matcher.outfit_evaluator import OutfitMaterTypeAware, Backbone

logger = logging.getLogger(__name__)

REQUEST_PATH = "./test_param/test.json"
MILVUS_DB_NAME = "mixi"
COLLECTION_NAME = "mixi_outfit"
# Credentials and output location can be overridden from the environment;
# the defaults are the values this script previously hard-coded.
MILVUS_TOKEN = os.environ.get("MILVUS_TOKEN", "root:Milvus")
OUTPUT_DIR = os.environ.get(
    "OUTFIT_OUTPUT_DIR",
    r"E:\workspace\trinity_client_mixi\app\service\outfit_matcher\output_outfit",
)


def _load_request(path=REQUEST_PATH):
    """Load the request JSON and normalise it and its item lists to dicts."""
    # JSON is UTF-8 by spec; don't depend on the platform's default encoding.
    with open(path, "r", encoding="utf-8") as f:
        request_item = dict(json.load(f))
    request_item["query"] = [dict(item) for item in request_item["query"]]
    request_item["database"] = [dict(item) for item in request_item["database"]]
    return request_item


def _prepare_features(all_items, backbone_service):
    """Return ``{item_name: features}`` for every item in *all_items*.

    Features already stored in Milvus are reused. Items with no stored
    features go through ``backbone_service.get_result`` and the resulting rows
    are inserted into the collection.

    Raises:
        ValueError: if the backbone returns a different number of feature
            vectors than the number of items it was given.
    """
    prepared_feature = {}
    client = MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_DB_NAME)
    try:
        search_data = client.get(
            collection_name=COLLECTION_NAME,
            ids=[item["item_name"] for item in all_items],
        )
        # Index stored rows by name once; a later row with the same name wins,
        # as it did with the previous nested-loop scan.
        stored = {row["item_name"]: row["features"] for row in search_data}

        # Split items into those that already have features and those that
        # need model inference. Deep copy so the caller's items stay untouched.
        have_features_data = []
        no_have_features_data = []
        for item in deepcopy(all_items):
            name = item["item_name"]
            if name in stored:
                item["features"] = stored[name]
            if "features" in item:
                have_features_data.append(item)
            else:
                no_have_features_data.append(item)

        if no_have_features_data:
            # Materialise before mutating the items handed to the backbone.
            extracted_features = list(backbone_service.get_result(no_have_features_data))
            if len(extracted_features) != len(no_have_features_data):
                raise ValueError(
                    f"backbone returned {len(extracted_features)} features for "
                    f"{len(no_have_features_data)} items"
                )
            for row, feature in zip(no_have_features_data, extracted_features):
                row["features"] = feature
                # Not a column of the collection; strip before inserting.
                row.pop("mapped_cate", None)
            client.insert(collection_name=COLLECTION_NAME, data=no_have_features_data)
            for row in no_have_features_data:
                prepared_feature[row["item_name"]] = row["features"]

        for item in have_features_data:
            prepared_feature[item["item_name"]] = item["features"]
    finally:
        # Always release the connection, even if lookup/inference/insert fails.
        client.close()
    return prepared_feature


def _run_matching(request_item, fashion_dataset, service, prepared_feature):
    """Score and render outfits for every query item.

    Returns a list with one ``{"outfits": ..., "scores": ...}`` dict per query
    item, holding the best outfits when ``request_item["is_best"]`` is truthy
    and the worst outfits otherwise. Only the best outfits are written to an
    image under ``OUTPUT_DIR``.
    """
    topk = request_item["topk"]
    result = []
    for item in request_item["query"]:
        outfits = fashion_dataset.generate_outfit(item, topk, request_item["max_outfits"])
        scores = service.get_result(outfits, prepared_feature)
        if request_item["is_best"]:
            output_path = os.path.join(OUTPUT_DIR, f"{item['item_name']}_best_{topk}.png")
            best_outfits, best_scores = service.visualize(
                outfits, scores, topk, best=True, output_path=output_path
            )
            result.append({"outfits": best_outfits, "scores": best_scores})
        else:
            bad_outfits, bad_scores = service.visualize(outfits, scores, topk, best=False)
            result.append({"outfits": bad_outfits, "scores": bad_scores})
    return result


def main():
    """Run the full pipeline on the local test request and return the results."""
    request_item = _load_request()
    fashion_dataset = FashionDataset(request_item["database"])
    backbone_service = Backbone()
    service = OutfitMaterTypeAware()

    all_items = request_item["query"] + request_item["database"]
    prepared_feature = _prepare_features(all_items, backbone_service)

    # Time only the matching stage, as before.
    start_time = time.perf_counter()
    result = _run_matching(request_item, fashion_dataset, service, prepared_feature)
    logger.info("run time is : %s", time.perf_counter() - start_time)
    return result


if __name__ == "__main__":
    main()