diff --git a/app/service/outfit_matcher/dataset.py b/app/service/outfit_matcher/dataset.py index a363c00..6e11983 100644 --- a/app/service/outfit_matcher/dataset.py +++ b/app/service/outfit_matcher/dataset.py @@ -80,12 +80,12 @@ class FashionDataset(object): used_items.add(item["item_name"]) outfit = [query_item, item] outfit_list.append(tuple(outfit)) - - # 20% chance to include an outerwear - if self.cate2num['outerwear'] > 0 and random.random() < 0.2: - outerwear = random.choice(self.cate2item['outerwear']) - outfit.append(outerwear) - outfit_list.append(tuple(outfit)) + if "outerwear" in self.cate2item.keys(): + # 20% chance to include an outerwear + if self.cate2num['outerwear'] > 0 and random.random() < 0.2: + outerwear = random.choice(self.cate2item['outerwear']) + outfit.append(outerwear) + outfit_list.append(tuple(outfit)) if len(outfit_list) < topk: raise ValueError(f"Cannot generate more than {topk} outfits!") diff --git a/app/service/outfit_matcher/test_param/recommendation_test_zcr.json b/app/service/outfit_matcher/test_param/recommendation_test_zcr.json index 616dfc7..8a29976 100644 --- a/app/service/outfit_matcher/test_param/recommendation_test_zcr.json +++ b/app/service/outfit_matcher/test_param/recommendation_test_zcr.json @@ -1,6 +1,6 @@ { - "topk": 5, - "max_outfits": 100, + "topk": 1, + "max_outfits": 10, "is_best": true, "query": [ {