add file
This commit is contained in:
140
app/service/outfit_matcher/dataset.py
Normal file
140
app/service/outfit_matcher/dataset.py
Normal file
@@ -0,0 +1,140 @@
|
||||
import json
|
||||
import random
|
||||
import itertools
|
||||
import difflib
|
||||
from copy import deepcopy
|
||||
|
||||
|
||||
# Helper function to calculate the edit distance similarity between two file names
|
||||
def sim_score(a, b):
|
||||
return difflib.SequenceMatcher(None, a, b).ratio()
|
||||
|
||||
|
||||
class FashionDataset(object):
|
||||
fashion_categories = json.load(open(r"config/fashion_categories.json", "r"))
|
||||
fashion_categories_mapping = json.load(open(r"config/fashion_category_mapping.json", "r"))
|
||||
|
||||
def __init__(self, item_metadata):
|
||||
self.item_metadat = item_metadata
|
||||
self.item2cate = self.get_item2cate(item_metadata)
|
||||
self.cate2item = self.get_cate2item(item_metadata)
|
||||
self.cate2num = {k: len(v) for k, v in self.cate2item.items()}
|
||||
|
||||
# def generate_outfit(self, query_item, topk, max_outfits):
|
||||
# query_item["mapped_cate"] = self.fashion_categories_mapping[query_item["semantic_category"]]
|
||||
# possible_outfits = self.count_possible(query_item["mapped_cate"])
|
||||
#
|
||||
# if possible_outfits < topk:
|
||||
# raise ValueError(f"Cannot generate more than {topk} outfits!")
|
||||
#
|
||||
# outfit_list = self.get_possible_outfit(query_item)
|
||||
# if possible_outfits > max_outfits:
|
||||
# outfit_list = random.sample(outfit_list, max_outfits)
|
||||
# return outfit_list
|
||||
|
||||
def count_possible(self, given_cate):
|
||||
possible = 0
|
||||
if given_cate == 'tops':
|
||||
if self.cate2num['outerwear'] == 0:
|
||||
possible = self.cate2num['bottoms']
|
||||
else:
|
||||
possible = self.cate2num['bottoms'] * self.cate2num['outerwear']
|
||||
elif given_cate == 'bottoms':
|
||||
if self.cate2num['outerwear'] == 0:
|
||||
possible = self.cate2num['tops']
|
||||
else:
|
||||
possible = self.cate2num['tops'] * self.cate2num['outerwear']
|
||||
elif given_cate == 'outerwear':
|
||||
possible = self.cate2num['tops'] * self.cate2num['bottoms'] + self.cate2num['all-body']
|
||||
elif given_cate == 'all-body':
|
||||
possible = self.cate2num['outerwear']
|
||||
return possible
|
||||
|
||||
def generate_outfit(self, query_item, topk, max_outfits):
|
||||
query_item["mapped_cate"] = self.fashion_categories_mapping[query_item["semantic_category"]]
|
||||
given_cate = query_item["mapped_cate"]
|
||||
|
||||
if given_cate == 'tops' or given_cate == "bottoms":
|
||||
complementary_cate = "bottoms" if given_cate == "tops" else "tops"
|
||||
# check bottom num
|
||||
if not self.cate2num[complementary_cate]:
|
||||
raise ValueError(f"Not enough {complementary_cate} available to generate outfits.")
|
||||
|
||||
complementary_items = deepcopy(self.cate2item[complementary_cate])
|
||||
sim_scores = [sim_score(item["item_name"], query_item["item_name"]) for item in complementary_items]
|
||||
outfit_list = []
|
||||
used_items = set()
|
||||
|
||||
while len(outfit_list) < max_outfits:
|
||||
if not complementary_items:
|
||||
break
|
||||
|
||||
# 根据权重从bottoms中选择一个元素
|
||||
item_index = random.choices(range(len(complementary_items)), weights=sim_scores)[0]
|
||||
item = complementary_items.pop(item_index)
|
||||
sim_scores.pop(item_index)
|
||||
|
||||
if item["item_name"] not in used_items:
|
||||
used_items.add(item["item_name"])
|
||||
outfit = [query_item, item]
|
||||
outfit_list.append(tuple(outfit))
|
||||
|
||||
# 20% chance to include an outerwear
|
||||
if self.cate2num['outerwear'] > 0 and random.random() < 0.2:
|
||||
outerwear = random.choice(self.cate2item['outerwear'])
|
||||
outfit.append(outerwear)
|
||||
outfit_list.append(tuple(outfit))
|
||||
if len(outfit_list) < topk:
|
||||
raise ValueError(f"Cannot generate more than {topk} outfits!")
|
||||
|
||||
return outfit_list
|
||||
elif given_cate == 'outerwear':
|
||||
top_bottom_combination = [(x[0], x[1]) for x in
|
||||
itertools.product(self.cate2item['tops'], self.cate2item['bottoms'])]
|
||||
sim_scores = [sim_score(x[0]["item_name"], x[1]["item_name"]) for x in top_bottom_combination]
|
||||
outfit_list = []
|
||||
used_items = set()
|
||||
|
||||
while len(outfit_list) < max_outfits:
|
||||
if not top_bottom_combination:
|
||||
break
|
||||
|
||||
# 根据权重从top bottom的组合中选择一个,根据top和bottom之间的文件名相似度选择
|
||||
top_bottom_index = random.choices(range(len(top_bottom_combination)), weights=sim_scores)[0]
|
||||
top_bottom = top_bottom_combination.pop(top_bottom_index)
|
||||
sim_scores.pop(top_bottom_index)
|
||||
|
||||
top_name, bottom_name = top_bottom[0]["item_name"], top_bottom[1]["item_name"]
|
||||
if top_name + bottom_name not in used_items:
|
||||
used_items.add(top_name + bottom_name)
|
||||
outfit = [query_item] + list(top_bottom)
|
||||
outfit_list.append(tuple(outfit))
|
||||
|
||||
if len(outfit_list) < topk:
|
||||
raise ValueError(f"Cannot generate more than {topk} outfits!")
|
||||
|
||||
return outfit_list
|
||||
elif given_cate == 'all-body':
|
||||
outfit_list = [(query_item, x) for x in self.cate2item['outerwear']]
|
||||
return outfit_list
|
||||
|
||||
|
||||
def get_item2cate(self, item_metadata):
|
||||
item2cate = {}
|
||||
for metadata in item_metadata:
|
||||
cate = metadata["semantic_category"]
|
||||
mapped_cate = self.fashion_categories_mapping[cate]
|
||||
item2cate[metadata["item_name"]] = mapped_cate
|
||||
return item2cate
|
||||
|
||||
def get_cate2item(self, item_metadata):
|
||||
cate2item = {}
|
||||
for metadata in item_metadata:
|
||||
cate = metadata["semantic_category"]
|
||||
mapped_cate = self.fashion_categories_mapping[cate]
|
||||
metadata["mapped_cate"] = mapped_cate
|
||||
if mapped_cate not in cate2item.keys():
|
||||
cate2item[mapped_cate] = [metadata]
|
||||
else:
|
||||
cate2item[mapped_cate].append(metadata)
|
||||
return cate2item
|
||||
Reference in New Issue
Block a user