From 7dce70027a43149f30740dc431bebdd9ca2cce46 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Wed, 20 Mar 2024 17:21:59 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=20dataset=20outerwear?= =?UTF-8?q?=E5=87=BA=E7=8E=B0=E7=9A=84key=20error=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/service/outfit_matcher/dataset.py | 12 ++++++------ .../test_param/recommendation_test_zcr.json | 4 ++-- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/app/service/outfit_matcher/dataset.py b/app/service/outfit_matcher/dataset.py index a363c00..6e11983 100644 --- a/app/service/outfit_matcher/dataset.py +++ b/app/service/outfit_matcher/dataset.py @@ -80,12 +80,12 @@ class FashionDataset(object): used_items.add(item["item_name"]) outfit = [query_item, item] outfit_list.append(tuple(outfit)) - - # 20% chance to include an outerwear - if self.cate2num['outerwear'] > 0 and random.random() < 0.2: - outerwear = random.choice(self.cate2item['outerwear']) - outfit.append(outerwear) - outfit_list.append(tuple(outfit)) + if "outerwear" in self.cate2item.keys(): + # 20% chance to include an outerwear + if self.cate2num['outerwear'] > 0 and random.random() < 0.2: + outerwear = random.choice(self.cate2item['outerwear']) + outfit.append(outerwear) + outfit_list.append(tuple(outfit)) if len(outfit_list) < topk: raise ValueError(f"Cannot generate more than {topk} outfits!") diff --git a/app/service/outfit_matcher/test_param/recommendation_test_zcr.json b/app/service/outfit_matcher/test_param/recommendation_test_zcr.json index 616dfc7..8a29976 100644 --- a/app/service/outfit_matcher/test_param/recommendation_test_zcr.json +++ b/app/service/outfit_matcher/test_param/recommendation_test_zcr.json @@ -1,6 +1,6 @@ { - "topk": 5, - "max_outfits": 100, + "topk": 1, + "max_outfits": 10, "is_best": true, "query": [ {