Files
sora_python/app/service/outfit_matcher/dataset.py

142 lines
6.2 KiB
Python
Raw Normal View History

2024-03-11 10:58:34 +08:00
import json
import random
import itertools
import difflib
from copy import deepcopy
from app.core.config import FASHION_CATEGORIES, FASHION_CATEGORIES_MAPPING
2024-03-11 10:58:34 +08:00
def sim_score(a, b):
    """Return the similarity ratio, in [0, 1], between the strings ``a`` and ``b``.

    The score comes from ``difflib.SequenceMatcher.ratio`` (Ratcliff/Obershelp
    matching blocks, not a true edit distance). Identical inputs, including two
    empty strings, score 1.0. Within this module it is used to compare item names.
    """
    matcher = difflib.SequenceMatcher(isjunk=None, a=a, b=b)
    return matcher.ratio()
def _load_json(path):
    """Load and return the JSON document stored at ``path``, closing the file afterwards."""
    with open(path, "r", encoding="utf-8") as fp:
        return json.load(fp)


class FashionDataset(object):
    """Index item metadata by coarse category and build outfits from it.

    Each metadata dict must provide "item_name" and "semantic_category". The
    semantic category is translated through ``fashion_categories_mapping`` into a
    coarse category; the ones this class acts on are "tops", "bottoms",
    "outerwear" and "all-body".
    """

    # Loaded once, when the class is defined. fashion_categories_mapping maps a
    # semantic_category to its coarse category.
    fashion_categories = _load_json(FASHION_CATEGORIES)
    fashion_categories_mapping = _load_json(FASHION_CATEGORIES_MAPPING)

    def __init__(self, item_metadata):
        """Build the item/category lookup tables for ``item_metadata``.

        Args:
            item_metadata: metadata dicts. The collection is iterated twice, so it
                must be re-iterable (e.g. a list). Each dict gains a "mapped_cate"
                key as a side effect.

        Raises:
            KeyError: an item's semantic_category is missing from the mapping.
        """
        self.item_metadata = item_metadata
        # Backward-compatible alias for the original, misspelled attribute name.
        self.item_metadat = item_metadata
        self.item2cate = self.get_item2cate(item_metadata)
        self.cate2item = self.get_cate2item(item_metadata)
        self.cate2num = {cate: len(items) for cate, items in self.cate2item.items()}

    def count_possible(self, given_cate):
        """Return how many outfit combinations exist for an item of coarse category ``given_cate``.

        Categories with no items count as 0 instead of raising KeyError.
        Unknown categories return 0.
        """
        tops = self.cate2num.get("tops", 0)
        bottoms = self.cate2num.get("bottoms", 0)
        outerwear = self.cate2num.get("outerwear", 0)
        all_body = self.cate2num.get("all-body", 0)
        if given_cate == "tops":
            return bottoms * outerwear if outerwear else bottoms
        if given_cate == "bottoms":
            return tops * outerwear if outerwear else tops
        if given_cate == "outerwear":
            return tops * bottoms + all_body
        if given_cate == "all-body":
            return outerwear
        return 0

    def generate_outfit(self, query_item, topk, max_outfits):
        """Build candidate outfits around ``query_item``.

        The query item's coarse category is written to ``query_item["mapped_cate"]``
        and selects the strategy:

        * "tops"/"bottoms": pair with complementary items sampled by name
          similarity to the query. Each outfit has a 20% chance of also getting
          a random outerwear item.
        * "outerwear": add a (top, bottom) pair sampled by the similarity between
          the top's and the bottom's names.
        * "all-body": pair with each outerwear item (at most ``max_outfits``).

        Args:
            query_item: metadata dict with "semantic_category" and "item_name".
            topk: minimum number of outfits required (tops/bottoms/outerwear only).
            max_outfits: upper bound on the number of outfits returned.

        Returns:
            A list of tuples of item dicts with ``query_item`` first, or None when
            the mapped category is none of the four handled above.

        Raises:
            KeyError: ``query_item``'s semantic_category is not in the mapping.
            ValueError: not enough items to build the required outfits.
        """
        query_item["mapped_cate"] = self.fashion_categories_mapping[query_item["semantic_category"]]
        given_cate = query_item["mapped_cate"]
        if given_cate in ("tops", "bottoms"):
            return self._outfits_with_complement(query_item, given_cate, topk, max_outfits)
        if given_cate == "outerwear":
            return self._outfits_with_top_bottom(query_item, topk, max_outfits)
        if given_cate == "all-body":
            return self._outfits_with_outerwear(query_item, max_outfits)
        return None

    @staticmethod
    def _weighted_pop(candidates, weights):
        """Remove and return one element of ``candidates``, drawn with probability proportional to ``weights``.

        ``weights`` is kept aligned with ``candidates`` by popping the same index.
        When every remaining weight is zero, ``random.choices`` would raise
        ValueError, so the draw falls back to uniform.
        """
        if sum(weights) > 0:
            index = random.choices(range(len(candidates)), weights=weights)[0]
        else:
            index = random.randrange(len(candidates))
        weights.pop(index)
        return candidates.pop(index)

    @staticmethod
    def _require_topk(outfit_list, topk):
        """Raise ValueError when fewer than ``topk`` outfits were produced."""
        if len(outfit_list) < topk:
            raise ValueError(f"Cannot generate {topk} outfits, only {len(outfit_list)} available!")

    def _outfits_with_complement(self, query_item, given_cate, topk, max_outfits):
        """Pair a top with bottoms (or a bottom with tops), optionally adding one outerwear."""
        complementary_cate = "bottoms" if given_cate == "tops" else "tops"
        if not self.cate2num.get(complementary_cate):
            raise ValueError(f"Not enough {complementary_cate} available to generate outfits.")
        # Work on copies so popping does not touch the index and returned items
        # do not alias the dataset's own metadata dicts.
        candidates = deepcopy(self.cate2item[complementary_cate])
        weights = [sim_score(item["item_name"], query_item["item_name"]) for item in candidates]
        outerwear_pool = self.cate2item.get("outerwear", [])
        outfit_list = []
        used_names = set()
        while candidates and len(outfit_list) < max_outfits:
            item = self._weighted_pop(candidates, weights)
            if item["item_name"] in used_names:
                continue
            used_names.add(item["item_name"])
            outfit = [query_item, item]
            # 20% chance to include an outerwear piece in this outfit.
            if outerwear_pool and random.random() < 0.2:
                outfit.append(random.choice(outerwear_pool))
            outfit_list.append(tuple(outfit))
        self._require_topk(outfit_list, topk)
        return outfit_list

    def _outfits_with_top_bottom(self, query_item, topk, max_outfits):
        """Combine an outerwear query item with (top, bottom) pairs sampled by top/bottom name similarity."""
        pairs = list(itertools.product(self.cate2item.get("tops", []), self.cate2item.get("bottoms", [])))
        weights = [sim_score(top["item_name"], bottom["item_name"]) for top, bottom in pairs]
        outfit_list = []
        # A tuple key avoids collisions that plain string concatenation allows
        # (e.g. "ab" + "c" == "a" + "bc").
        used_pairs = set()
        while pairs and len(outfit_list) < max_outfits:
            top, bottom = self._weighted_pop(pairs, weights)
            key = (top["item_name"], bottom["item_name"])
            if key in used_pairs:
                continue
            used_pairs.add(key)
            outfit_list.append((query_item, top, bottom))
        self._require_topk(outfit_list, topk)
        return outfit_list

    def _outfits_with_outerwear(self, query_item, max_outfits):
        """Pair an all-body query item with outerwear items, in index order, up to ``max_outfits``."""
        outerwear = self.cate2item.get("outerwear", [])
        return [(query_item, item) for item in outerwear[:max(max_outfits, 0)]]

    def get_item2cate(self, item_metadata):
        """Return a dict mapping each item_name to its coarse category (later duplicates win)."""
        return {
            metadata["item_name"]: self.fashion_categories_mapping[metadata["semantic_category"]]
            for metadata in item_metadata
        }

    def get_cate2item(self, item_metadata):
        """Group metadata dicts by coarse category, recording it in each dict's "mapped_cate" key."""
        cate2item = {}
        for metadata in item_metadata:
            mapped_cate = self.fashion_categories_mapping[metadata["semantic_category"]]
            metadata["mapped_cate"] = mapped_cate
            cate2item.setdefault(mapped_cate, []).append(metadata)
        return cate2item