diff --git a/app/api/api_attribute.py b/app/api/api_attribute.py index 972ccd3..213a1a9 100644 --- a/app/api/api_attribute.py +++ b/app/api/api_attribute.py @@ -15,5 +15,4 @@ router = APIRouter() def attribute(request_data: AttributeModel): service = AttributeRecognition() response = service.attribute(const, request_data) - logger.info("test") return {"code": 200, "message": "ok", "data": response} diff --git a/app/api/api_outfit_matcher.py b/app/api/api_outfit_matcher.py index 24a6337..4532b0c 100644 --- a/app/api/api_outfit_matcher.py +++ b/app/api/api_outfit_matcher.py @@ -1,10 +1,13 @@ import logging import time +from copy import deepcopy from fastapi import APIRouter +from pymilvus import MilvusClient + from app.schemas.outfit_matcher import OutfitMatcher from app.service.outfit_matcher.dataset import FashionDataset -from app.service.outfit_matcher.outfit_evaluator import OutfitMaterTypeAware +from app.service.outfit_matcher.outfit_evaluator import OutfitMaterTypeAware, Backbone from app.service.utils.decorator import RunTime logger = logging.getLogger() @@ -19,16 +22,28 @@ def outfit_matcher(request_item: OutfitMatcher): request_item['query'][i] = dict(request_item['query'][i]) for i in range(len(request_item['database'])): request_item['database'][i] = dict(request_item['database'][i]) - - # try: fashion_dataset = FashionDataset(request_item['database']) + backbone_service = Backbone() service = OutfitMaterTypeAware() + all_items = request_item["query"] + request_item["database"] + prepared_feature = {} + extracted_features = backbone_service.get_result(all_items) + data = deepcopy(all_items) # 做深拷贝 , all_items 是list 可变数组 + for i, feature in enumerate(extracted_features): + data[i]['features'] = feature + if 'mapped_cate' in data[i].keys(): + del data[i]['mapped_cate'] + + client = MilvusClient(uri="http://10.1.1.240:19530", token="root:Milvus", db_name="mixi") + res = client.insert(collection_name="mixi_outfit", data=data) + client.close() + for d in data: + prepared_feature[d['item_name']] = d['features'] result = [] start_time = time.time() for item in request_item['query']: outfits = fashion_dataset.generate_outfit(item, request_item["topk"], request_item["max_outfits"]) - scores, features = service.get_result(outfits) - # save features in databases + scores = service.get_result(outfits, prepared_feature) if request_item['is_best']: best_outfits, best_scores = service.visualize(outfits, scores, request_item["topk"], best=True, @@ -44,4 +59,4 @@ def outfit_matcher(request_item: OutfitMatcher): return {"message": "ok", "data": result} # except Exception as e: # logger.warning(e) - # return {"message": f"{e}", "data": e} \ No newline at end of file + # return {"message": f"{e}", "data": e} diff --git a/app/core/config.py b/app/core/config.py index f4b7583..8338d65 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -33,17 +33,19 @@ OM_TRITON_PORT = "10010" ATT_TRITON_IP = "10.1.1.240" ATT_TRITON_PORT = "10020" -# service env -LOGS_PATH = "app/logs/errors.log" -FASHION_CATEGORIES = "app/service/outfit_matcher/config/fashion_categories.json" -FASHION_CATEGORIES_MAPPING = "app/service/outfit_matcher/config/fashion_category_mapping.json" +DEBUG = 1 +# service env : 1 +# pycharm debug : 2 -# pycharm debug -# LOGS_PATH = "logs/errors.log" -# FASHION_CATEGORIES = "service/outfit_matcher/config/fashion_categories.json" -# FASHION_CATEGORIES_MAPPING = "service/outfit_matcher/config/fashion_category_mapping.json" - - -# LOGS_PATH = "app/logs/errors.log" -# FASHION_CATEGORIES = "./config/fashion_categories.json" -# FASHION_CATEGORIES_MAPPING = "./config/fashion_category_mapping.json" \ No newline at end of file +if DEBUG == 1: + LOGS_PATH = "app/logs/errors.log" + FASHION_CATEGORIES = "app/service/outfit_matcher/config/fashion_categories.json" + FASHION_CATEGORIES_MAPPING = "app/service/outfit_matcher/config/fashion_category_mapping.json" +elif DEBUG == 2: + LOGS_PATH = "logs/errors.log" + FASHION_CATEGORIES = "service/outfit_matcher/config/fashion_categories.json" + FASHION_CATEGORIES_MAPPING = "service/outfit_matcher/config/fashion_category_mapping.json" +elif DEBUG == 3: + LOGS_PATH = "app/logs/errors.log" + FASHION_CATEGORIES = "./config/fashion_categories.json" + FASHION_CATEGORIES_MAPPING = "./config/fashion_category_mapping.json" diff --git a/app/service/outfit_matcher/outfit_evaluator.py b/app/service/outfit_matcher/outfit_evaluator.py index 564c350..a35155e 100644 --- a/app/service/outfit_matcher/outfit_evaluator.py +++ b/app/service/outfit_matcher/outfit_evaluator.py @@ -197,43 +197,45 @@ class OutfitMatcher(object): outfits = [outfits[i] for i in sorted_indices] # 最好或最差的五个 scores = scores[sorted_indices] # 这五个的分数 - # 设置子图的行列数 - num_rows = len(outfits) - num_cols = max([len(x) for x in outfits]) + 1 # 一个是图片,一个是分数 + return outfits, scores.tolist() - # 创建一个新的图像,并指定子图的行列数 - fig, axes = plt.subplots(num_rows, num_cols, figsize=(8, 15)) - - title = f"Best {topk} Outfits" if best else f"Worst {topk} Outfits" - fig.suptitle(title, fontsize=16) - - # 遍历每套outfit并将其显示在对应的子图中 - for i, (outfit, score) in enumerate(zip(outfits, scores)): - # 显示分数 - axes[i, 0].text(0.1, 0.5, f"Score: {score:.4f}", fontsize=12) - axes[i, 0].axis("off") - # 显示图片 - for j, item in enumerate(outfit): - img = self.load_image(item['image_path']) # 读取图片 - axes[i, j + 1].imshow(img) # 在对应的子图中显示图片 - axes[i, j + 1].axis('off') # 关闭坐标轴 - axes[i, j + 1].set_title(item["semantic_category"], fontsize=10) - for j in range(len(outfit), num_cols): - axes[i, j].axis("off") - - # 在每一行的底部添加一条横线 - axes[i, 0].axhline(y=0, color='black', linewidth=1) - # 隐藏最后一行的横线 - axes[-1, 0].axhline(y=0, color='white', linewidth=1) - - # 调整布局 - plt.subplots_adjust(wspace=0.1, hspace=0.1) - plt.tight_layout() - - if output_path: - plt.savefig(output_path) - else: - plt.show() + # # 设置子图的行列数 + # num_rows = len(outfits) + # num_cols = max([len(x) for x in outfits]) + 1 # 一个是图片,一个是分数 + # + # # 创建一个新的图像,并指定子图的行列数 + # fig, axes = plt.subplots(num_rows, num_cols, figsize=(8, 15)) + # + # title = f"Best {topk} Outfits" if best else f"Worst {topk} Outfits" + # fig.suptitle(title, fontsize=16) + # + # # 遍历每套outfit并将其显示在对应的子图中 + # for i, (outfit, score) in enumerate(zip(outfits, scores)): + # # 显示分数 + # axes[i, 0].text(0.1, 0.5, f"Score: {score:.4f}", fontsize=12) + # axes[i, 0].axis("off") + # # 显示图片 + # for j, item in enumerate(outfit): + # img = self.load_image(item['image_path']) # 读取图片 + # axes[i, j + 1].imshow(img) # 在对应的子图中显示图片 + # axes[i, j + 1].axis('off') # 关闭坐标轴 + # axes[i, j + 1].set_title(item["semantic_category"], fontsize=10) + # for j in range(len(outfit), num_cols): + # axes[i, j].axis("off") + # + # # 在每一行的底部添加一条横线 + # axes[i, 0].axhline(y=0, color='black', linewidth=1) + # # 隐藏最后一行的横线 + # axes[-1, 0].axhline(y=0, color='white', linewidth=1) + # + # # 调整布局 + # plt.subplots_adjust(wspace=0.1, hspace=0.1) + # plt.tight_layout() + # + # if output_path: + # plt.savefig(output_path) + # else: + # plt.show() class OutfitMatcherHon(OutfitMatcher): diff --git a/app/service/outfit_matcher/service.py b/app/service/outfit_matcher/service.py index 59961cb..b99af95 100644 --- a/app/service/outfit_matcher/service.py +++ b/app/service/outfit_matcher/service.py @@ -1,37 +1,57 @@ import json import os from pprint import pprint +import numpy as np from app.service.outfit_matcher.dataset import FashionDataset -from app.service.outfit_matcher.outfit_evaluator import OutfitMaterTypeAware +from app.service.outfit_matcher.outfit_evaluator import OutfitMaterTypeAware, Backbone if __name__ == '__main__': with open("./test_param/recommendation_test.json", "r") as f: param = json.load(f) fashion_dataset = FashionDataset(param["database"]) + backbone_service = Backbone() service = OutfitMaterTypeAware() - best_list = [] - bad_list = [] + + # read feature from vector database + all_items = param["query"] + param["database"] + unextracted_item = [] + prepared_feature = {} + + # 拿到所有需要提取特征的图片 + for item in all_items: + if f'{item["item_name"]}.npy' not in os.listdir("feature"): + unextracted_item.append(item) + if len(unextracted_item) > 0: + # 通过backbone模型提取图片特征 + extracted_features = backbone_service.get_result(unextracted_item) + for i, item in enumerate(unextracted_item): + # save features + # 链接milvus + # TODO + np.save(f'feature/{item["item_name"]}.npy', extracted_features[i]) + # 存入数据库 + # 关闭链接 + + # TODO 读取本次任务需要的图片特征 + for item in all_items: + if item["item_name"] not in prepared_feature.keys(): + prepared_feature[item["item_name"]] = np.load(f'feature/{item["item_name"]}.npy') + + # 开始服装搭配任务 for item in param["query"]: + # 根据一定规则生成outfit outfits = fashion_dataset.generate_outfit(item, param["topk"], param["max_outfits"]) - scores, features = service.get_result(outfits) - # save features + # 根据模型对生成的outfit打分 + scores = service.get_result(outfits, prepared_feature) + # 对评分排序,拿到最好的topk个outfit输出 + sorted_indices = np.argsort(scores)[:param["topk"]] # type-aware + best_outfits = [outfits[i] for i in sorted_indices] # 最好的五个 - # 链接milvus - - # 存入数据库 - # 关闭链接 - - # print(scores) - # print(len(scores)) - best_outfits, best_scores = service.visualize(outfits, scores, param["topk"], best=True, - # output_path=os.path.join(r"E:\workspace\outfit_matcher\2024 SS Outfit", f"{item['item_name']}_best_{param['topk']}.png") - ) - bad_outfits, bad_scores = service.visualize(outfits, scores, param["topk"], best=False, - # output_path=os.path.join(r"E:\workspace\outfit_matcher\2024 SS Outfit", f"{item['item_name']}_worst_{param['topk']}.png") - ) - best_list.append({"best_outfits": best_outfits, "best_scores": best_scores}) - bad_list.append({"bad_outfits": bad_outfits, "bad_scores": bad_scores}) - - pprint(best_list) - pprint(bad_list) + # 结果可视化 + # service.visualize(outfits, scores, param["topk"], best=True, + # output_path=os.path.join(r"D:\PhD_Study\MIXI\mitu\image\123", + # f"{item['item_name']}_best_{param['topk']}.png")) + # service.visualize(outfits, scores, param["topk"], best=False, + # output_path=os.path.join(r"D:\PhD_Study\MIXI\mitu\image\123", + # f"{item['item_name']}_worst_{param['topk']}.png"))