import json
import random
import itertools
import difflib
from copy import deepcopy

from app.core.config import FASHION_CATEGORIES, FASHION_CATEGORIES_MAPPING


def sim_score(a, b):
    """Return the similarity ratio of two item names as a float in [0, 1].

    Thin wrapper around ``difflib.SequenceMatcher(None, a, b).ratio()``:
    1.0 means the strings are identical, 0.0 means they have nothing in
    common.
    """
    return difflib.SequenceMatcher(None, a, b).ratio()


def _load_json(path):
    """Parse the JSON file at *path*, closing the file handle afterwards."""
    with open(path, "r") as handle:
        return json.load(handle)


def _pop_weighted(items, weights):
    """Remove and return one element of *items*, chosen at random by weight.

    ``items`` and ``weights`` are parallel lists; the chosen position is
    removed from both so that repeated calls sample without replacement.
    ``random.choices`` rejects weights whose total is zero, which happens
    when no candidate name shares anything with the reference name, so in
    that case the pick falls back to a uniform choice.
    """
    if sum(weights) > 0:
        index = random.choices(range(len(items)), weights=weights)[0]
    else:
        index = random.randrange(len(items))
    weights.pop(index)
    return items.pop(index)


class FashionDataset(object):
    """Index of fashion items grouped by mapped category, used to build outfits.

    Each metadata entry is a dict that provides at least the keys
    ``"item_name"`` and ``"semantic_category"``. Semantic categories are
    translated through ``fashion_categories_mapping``; the mapped values
    handled explicitly by this class are ``'tops'``, ``'bottoms'``,
    ``'outerwear'`` and ``'all-body'``.
    """

    # Both tables are read once, when the class body is executed.
    fashion_categories = _load_json(FASHION_CATEGORIES)
    fashion_categories_mapping = _load_json(FASHION_CATEGORIES_MAPPING)

    def __init__(self, item_metadata):
        """Build the lookup tables from *item_metadata*.

        Args:
            item_metadata: iterable of item dicts. Each dict gains a
                ``"mapped_cate"`` key as a side effect (see
                ``get_cate2item``).
        """
        self.item_metadata = item_metadata
        # Historical misspelling, kept so existing readers of it keep working.
        self.item_metadat = item_metadata
        self.item2cate = self.get_item2cate(item_metadata)
        self.cate2item = self.get_cate2item(item_metadata)
        self.cate2num = {k: len(v) for k, v in self.cate2item.items()}

    def count_possible(self, given_cate):
        """Return how many outfits could be formed around *given_cate*.

        Categories that do not occur in the metadata count as zero instead
        of raising ``KeyError``. Any category other than the four handled
        ones yields 0.
        """
        tops = self.cate2num.get('tops', 0)
        bottoms = self.cate2num.get('bottoms', 0)
        outerwear = self.cate2num.get('outerwear', 0)

        if given_cate == 'tops':
            # Outerwear multiplies the options only when some is available.
            return bottoms * outerwear if outerwear else bottoms
        if given_cate == 'bottoms':
            return tops * outerwear if outerwear else tops
        if given_cate == 'outerwear':
            return tops * bottoms + self.cate2num.get('all-body', 0)
        if given_cate == 'all-body':
            return outerwear
        return 0

    def generate_outfit(self, query_item, topk, max_outfits):
        """Generate outfits (tuples of item dicts) that include *query_item*.

        Args:
            query_item: item dict with ``"item_name"`` and
                ``"semantic_category"``. Its ``"mapped_cate"`` key is set
                in place.
            topk: currently unused; retained for interface compatibility.
            max_outfits: upper bound on the number of outfits returned by
                the ``'tops'``, ``'bottoms'`` and ``'outerwear'`` branches.
                The ``'all-body'`` branch is not capped and returns one
                outfit per outerwear item.

        Returns:
            A list of tuples whose first element is *query_item*, or
            ``None`` when the mapped category is not one of the four
            handled categories.

        Raises:
            KeyError: if the semantic category has no entry in the mapping.
            ValueError: if a category needed to complete the outfit has no
                items.
        """
        given_cate = self.fashion_categories_mapping[query_item["semantic_category"]]
        query_item["mapped_cate"] = given_cate

        if given_cate in ('tops', 'bottoms'):
            return self._outfits_for_top_or_bottom(query_item, given_cate, max_outfits)
        if given_cate == 'outerwear':
            return self._outfits_for_outerwear(query_item, max_outfits)
        if given_cate == 'all-body':
            outerwear_items = self._items_or_raise('outerwear')
            return [(query_item, x) for x in outerwear_items]
        # Unhandled categories produce no result, as before.
        return None

    def _items_or_raise(self, cate):
        """Return the items of category *cate*, or raise ValueError if none."""
        items = self.cate2item.get(cate)
        if not items:
            raise ValueError(f"Not enough {cate} available to generate outfits.")
        return items

    def _outfits_for_top_or_bottom(self, query_item, given_cate, max_outfits):
        """Pair a top with bottoms (or a bottom with tops), weighted by name similarity."""
        complementary_cate = "bottoms" if given_cate == "tops" else "tops"
        # Deep copy: the list is consumed below and callers receive copies
        # that are independent of the stored metadata.
        candidates = deepcopy(self._items_or_raise(complementary_cate))
        weights = [sim_score(item["item_name"], query_item["item_name"]) for item in candidates]
        outerwear_items = self.cate2item.get('outerwear')

        outfit_list = []
        used_items = set()
        while candidates and len(outfit_list) < max_outfits:
            # Select one complementary item according to the weights.
            item = _pop_weighted(candidates, weights)
            if item["item_name"] in used_items:
                continue
            used_items.add(item["item_name"])
            outfit_list.append((query_item, item))
            # 20% chance to also offer the same pairing with an outerwear,
            # but only while the cap still leaves room for it.
            if outerwear_items and len(outfit_list) < max_outfits and random.random() < 0.2:
                outerwear = random.choice(outerwear_items)
                outfit_list.append((query_item, item, outerwear))
        return outfit_list

    def _outfits_for_outerwear(self, query_item, max_outfits):
        """Combine an outerwear with top/bottom pairs, weighted by top-bottom name similarity."""
        tops = self._items_or_raise('tops')
        bottoms = self._items_or_raise('bottoms')
        pairs = list(itertools.product(tops, bottoms))
        weights = [sim_score(top["item_name"], bottom["item_name"]) for top, bottom in pairs]

        outfit_list = []
        used_items = set()
        while pairs and len(outfit_list) < max_outfits:
            # Choose one top/bottom combination, weighted by the similarity
            # between the top's name and the bottom's name.
            top, bottom = _pop_weighted(pairs, weights)
            # A tuple key avoids the collisions that concatenating the two
            # names would allow ("ab" + "c" == "a" + "bc").
            key = (top["item_name"], bottom["item_name"])
            if key in used_items:
                continue
            used_items.add(key)
            outfit_list.append((query_item, top, bottom))
        return outfit_list

    def get_item2cate(self, item_metadata):
        """Return a dict mapping each item name to its mapped category."""
        return {
            metadata["item_name"]: self.fashion_categories_mapping[metadata["semantic_category"]]
            for metadata in item_metadata
        }

    def get_cate2item(self, item_metadata):
        """Group the metadata dicts by mapped category.

        Each metadata dict is annotated in place with a ``"mapped_cate"``
        key holding its mapped category.
        """
        cate2item = {}
        for metadata in item_metadata:
            mapped_cate = self.fashion_categories_mapping[metadata["semantic_category"]]
            metadata["mapped_cate"] = mapped_cate
            cate2item.setdefault(mapped_cate, []).append(metadata)
        return cate2item