141 lines
6.1 KiB
Python
141 lines
6.1 KiB
Python
import json
|
||
import random
|
||
import itertools
|
||
import difflib
|
||
from copy import deepcopy
|
||
|
||
|
||
# Helper function that scores how similar two file names are (difflib ratio in [0, 1], not an edit distance)
|
||
def sim_score(a, b):
    """Return the similarity of sequences *a* and *b* as a float in [0.0, 1.0].

    Uses ``difflib.SequenceMatcher`` with no junk filter; 1.0 means identical
    (including two empty inputs) and 0.0 means no characters in common.
    """
    matcher = difflib.SequenceMatcher(isjunk=None, a=a, b=b)
    return matcher.ratio()
|
||
|
||
|
||
class FashionDataset(object):
    """Index fashion items by mapped category and sample outfits around a query item.

    The category configuration is read from two JSON files (paths relative to
    the current working directory) once, when this class body is executed.
    """

    # Loaded at class-definition time, as before; ``with`` guarantees the file
    # handles are closed instead of being left open until garbage collection.
    # NOTE(review): not used inside this class as visible here; kept for callers.
    with open(r"config/fashion_categories.json", "r", encoding="utf-8") as _fp:
        fashion_categories = json.load(_fp)

    # Maps an item's raw "semantic_category" to a mapped category; the mapped
    # values compared below are "tops", "bottoms", "outerwear" and "all-body".
    with open(r"config/fashion_category_mapping.json", "r", encoding="utf-8") as _fp:
        fashion_categories_mapping = json.load(_fp)
    del _fp  # don't leave the (closed) file object behind as a class attribute

    def __init__(self, item_metadata):
        """Build the category indexes for *item_metadata*.

        Args:
            item_metadata: sequence of item dicts (it is iterated twice), each
                with "item_name" and "semantic_category" keys; the category must
                be a key of ``fashion_categories_mapping``. Every dict gains a
                "mapped_cate" key in place.
        """
        self.item_metadata = item_metadata
        # Backward-compatible alias for the original (misspelled) attribute name.
        self.item_metadat = item_metadata
        self.item2cate = self.get_item2cate(item_metadata)
        self.cate2item = self.get_cate2item(item_metadata)
        self.cate2num = {k: len(v) for k, v in self.cate2item.items()}

    def _num(self, cate):
        """Return how many items are indexed under *cate* (0 if the category is absent)."""
        return self.cate2num.get(cate, 0)

    def _items(self, cate):
        """Return the items indexed under *cate* (an empty list if the category is absent)."""
        return self.cate2item.get(cate, [])

    @staticmethod
    def _weighted_pick(weights):
        """Return a random index into non-empty *weights*, weighted by value.

        ``random.choices`` rejects an all-zero weight list, which happens when
        no candidate name shares a character with the reference name; in that
        case every candidate is equally likely.
        """
        if sum(weights) > 0:
            return random.choices(range(len(weights)), weights=weights)[0]
        return random.randrange(len(weights))

    def count_possible(self, given_cate):
        """Return the estimated number of outfits for an item of *given_cate*.

        Categories with no items count as zero, so this never raises
        ``KeyError``; unrecognised categories return 0. This is the same
        formula as before and is only an estimate of what ``generate_outfit``
        produces.
        """
        n_tops = self._num('tops')
        n_bottoms = self._num('bottoms')
        n_outer = self._num('outerwear')
        if given_cate == 'tops':
            return n_bottoms * n_outer if n_outer else n_bottoms
        if given_cate == 'bottoms':
            return n_tops * n_outer if n_outer else n_tops
        if given_cate == 'outerwear':
            return n_tops * n_bottoms + self._num('all-body')
        if given_cate == 'all-body':
            return n_outer
        return 0

    def generate_outfit(self, query_item, topk, max_outfits):
        """Sample up to *max_outfits* outfits (tuples of item dicts) that contain *query_item*.

        *query_item* is mutated in place: its "mapped_cate" key is set.

        - tops / bottoms: paired with complementary items chosen by file-name
          similarity, sometimes with an extra outerwear variant.
        - outerwear: combined with (top, bottom) pairs chosen by the similarity
          of the top and bottom names.
        - all-body: paired with each outerwear item (randomly sampled down to
          *max_outfits* when there are more).

        Returns:
            A list of outfit tuples, or ``None`` for any other mapped category.

        Raises:
            KeyError: the query's "semantic_category" is not in the mapping.
            ValueError: (tops/bottoms/outerwear) no complementary items exist,
                or fewer than *topk* outfits could be generated.
        """
        query_item["mapped_cate"] = self.fashion_categories_mapping[query_item["semantic_category"]]
        given_cate = query_item["mapped_cate"]

        if given_cate in ('tops', 'bottoms'):
            return self._outfits_for_top_or_bottom(query_item, given_cate, topk, max_outfits)
        if given_cate == 'outerwear':
            return self._outfits_for_outerwear(query_item, topk, max_outfits)
        if given_cate == 'all-body':
            return self._outfits_for_all_body(query_item, max_outfits)
        # Other mapped categories are unsupported (same as the original implicit return).
        return None

    def _outfits_for_top_or_bottom(self, query_item, given_cate, topk, max_outfits):
        """Pair a top (or bottom) with complementary items, occasionally adding outerwear."""
        complementary_cate = "bottoms" if given_cate == "tops" else "tops"
        if not self._num(complementary_cate):
            raise ValueError(f"Not enough {complementary_cate} available to generate outfits.")

        # Deep copy: candidates are popped below, and returned outfits must not alias the index.
        candidates = deepcopy(self._items(complementary_cate))
        weights = [sim_score(item["item_name"], query_item["item_name"]) for item in candidates]
        outerwear_items = self._items('outerwear')

        outfit_list = []
        used_names = set()
        while len(outfit_list) < max_outfits and candidates:
            # Weighted sampling without replacement by file-name similarity.
            idx = self._weighted_pick(weights)
            item = candidates.pop(idx)
            weights.pop(idx)

            # Skip distinct items that share a name with one already used.
            if item["item_name"] in used_names:
                continue
            used_names.add(item["item_name"])
            outfit = [query_item, item]
            outfit_list.append(tuple(outfit))

            # 20% chance to also emit a variant with a random outerwear, but only
            # while room remains, so the result never exceeds max_outfits.
            if outerwear_items and len(outfit_list) < max_outfits and random.random() < 0.2:
                outfit.append(random.choice(outerwear_items))
                outfit_list.append(tuple(outfit))

        if len(outfit_list) < topk:
            raise ValueError(f"Cannot generate more than {topk} outfits!")
        return outfit_list

    def _outfits_for_outerwear(self, query_item, topk, max_outfits):
        """Combine an outerwear item with (top, bottom) pairs weighted by their name similarity."""
        pairs = list(itertools.product(self._items('tops'), self._items('bottoms')))
        weights = [sim_score(top["item_name"], bottom["item_name"]) for top, bottom in pairs]

        outfit_list = []
        used_pairs = set()
        while len(outfit_list) < max_outfits and pairs:
            idx = self._weighted_pick(weights)
            top, bottom = pairs.pop(idx)
            weights.pop(idx)

            # Tuple key: concatenated names could collide ("ab" + "c" == "a" + "bc").
            key = (top["item_name"], bottom["item_name"])
            if key in used_pairs:
                continue
            used_pairs.add(key)
            outfit_list.append((query_item, top, bottom))

        if len(outfit_list) < topk:
            raise ValueError(f"Cannot generate more than {topk} outfits!")
        return outfit_list

    def _outfits_for_all_body(self, query_item, max_outfits):
        """Pair an all-body item with outerwear, randomly capped at *max_outfits* outfits."""
        outfit_list = [(query_item, outerwear) for outerwear in self._items('outerwear')]
        if len(outfit_list) > max_outfits:
            outfit_list = random.sample(outfit_list, max(max_outfits, 0))
        return outfit_list

    def get_item2cate(self, item_metadata):
        """Return a dict mapping each item name to its mapped category.

        If several items share a name, the last one wins.
        """
        item2cate = {}
        for metadata in item_metadata:
            mapped_cate = self.fashion_categories_mapping[metadata["semantic_category"]]
            item2cate[metadata["item_name"]] = mapped_cate
        return item2cate

    def get_cate2item(self, item_metadata):
        """Group items by mapped category, tagging each item dict with "mapped_cate" in place.

        Only categories that actually occur get a key in the returned dict.
        """
        cate2item = {}
        for metadata in item_metadata:
            mapped_cate = self.fashion_categories_mapping[metadata["semantic_category"]]
            metadata["mapped_cate"] = mapped_cate
            cate2item.setdefault(mapped_cate, []).append(metadata)
        return cate2item
|