修复 dataset outerwear出现的key error问题
This commit is contained in:
@@ -80,12 +80,12 @@ class FashionDataset(object):
|
|||||||
used_items.add(item["item_name"])
|
used_items.add(item["item_name"])
|
||||||
outfit = [query_item, item]
|
outfit = [query_item, item]
|
||||||
outfit_list.append(tuple(outfit))
|
outfit_list.append(tuple(outfit))
|
||||||
|
if "outerwear" in self.cate2item.keys():
|
||||||
# 20% chance to include an outerwear
|
# 20% chance to include an outerwear
|
||||||
if self.cate2num['outerwear'] > 0 and random.random() < 0.2:
|
if self.cate2num['outerwear'] > 0 and random.random() < 0.2:
|
||||||
outerwear = random.choice(self.cate2item['outerwear'])
|
outerwear = random.choice(self.cate2item['outerwear'])
|
||||||
outfit.append(outerwear)
|
outfit.append(outerwear)
|
||||||
outfit_list.append(tuple(outfit))
|
outfit_list.append(tuple(outfit))
|
||||||
if len(outfit_list) < topk:
|
if len(outfit_list) < topk:
|
||||||
raise ValueError(f"Cannot generate more than {topk} outfits!")
|
raise ValueError(f"Cannot generate more than {topk} outfits!")
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
{
|
{
|
||||||
"topk": 5,
|
"topk": 1,
|
||||||
"max_outfits": 100,
|
"max_outfits": 10,
|
||||||
"is_best": true,
|
"is_best": true,
|
||||||
"query": [
|
"query": [
|
||||||
{
|
{
|
||||||
|
|||||||
Reference in New Issue
Block a user