Files
sora_python/app/service/outfit_matcher/dataset.py
2024-03-18 14:17:43 +08:00

144 lines
6.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import difflib
import itertools
import json
import random
from copy import deepcopy
from pathlib import Path
def sim_score(a, b):
    """Return how similar two strings (typically item file names) are, in [0.0, 1.0].

    This is ``difflib.SequenceMatcher.ratio()`` -- 2*M/T, where M is the number of
    matched characters and T the combined length -- not an edit distance.
    Identical strings (including two empty strings) score 1.0; strings with no
    characters in common score 0.0.
    """
    matcher = difflib.SequenceMatcher(isjunk=None, a=a, b=b)
    return matcher.ratio()
def _load_config_json(filename):
    """Load ``filename`` from the outfit-matcher ``config`` directory.

    The directory next to this module is tried first, so loading no longer depends
    on the current working directory. The original CWD-relative location
    ``service/outfit_matcher/config/`` is kept as a fallback. If neither exists,
    ``FileNotFoundError`` is raised, as before.
    """
    module_relative = Path(__file__).resolve().parent / "config" / filename
    if module_relative.is_file():
        path = module_relative
    else:
        path = Path("service/outfit_matcher/config") / filename
    # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
    with path.open("r", encoding="utf-8") as fh:
        return json.load(fh)


class FashionDataset:
    """Index clothing items by outfit category and assemble outfits around a query item.

    ``item_metadata`` is an iterable of dicts, each with at least ``"item_name"`` and
    ``"semantic_category"``. It is traversed twice during construction, so pass a
    re-iterable sequence such as a list, not a generator. Each dict is mutated in
    place to add a ``"mapped_cate"`` key.

    Outfit assembly handles the mapped categories ``"tops"``, ``"bottoms"``,
    ``"outerwear"`` and ``"all-body"``.
    """

    # Loaded once at import time and shared by all instances.
    fashion_categories = _load_config_json("fashion_categories.json")
    fashion_categories_mapping = _load_config_json("fashion_category_mapping.json")

    # Probability that a top/bottom pairing also yields a variant with an outerwear piece.
    OUTERWEAR_PROBABILITY = 0.2

    def __init__(self, item_metadata):
        self.item_metadata = item_metadata
        self.item2cate = self.get_item2cate(item_metadata)
        self.cate2item = self.get_cate2item(item_metadata)
        self.cate2num = {k: len(v) for k, v in self.cate2item.items()}

    @property
    def item_metadat(self):
        """Deprecated misspelled alias of ``item_metadata``, kept for backward compatibility."""
        return self.item_metadata

    @item_metadat.setter
    def item_metadat(self, value):
        self.item_metadata = value

    def _num(self, cate):
        """Return the number of items in mapped category ``cate`` (0 if there are none)."""
        return self.cate2num.get(cate, 0)

    def _items(self, cate):
        """Return the item list of mapped category ``cate`` (empty if there are none)."""
        return self.cate2item.get(cate, [])

    def count_possible(self, given_cate):
        """Estimate how many outfits could be built around an item of mapped category ``given_cate``.

        Categories with no items count as 0, and unknown categories return 0.
        NOTE(review): ``generate_outfit`` does not call this method. The "outerwear"
        count includes pairings with all-body items, which ``generate_outfit`` never
        builds -- confirm the intended formula.
        """
        n_tops, n_bottoms = self._num("tops"), self._num("bottoms")
        n_outer, n_all_body = self._num("outerwear"), self._num("all-body")
        if given_cate == "tops":
            return n_bottoms * n_outer if n_outer else n_bottoms
        if given_cate == "bottoms":
            return n_tops * n_outer if n_outer else n_tops
        if given_cate == "outerwear":
            return n_tops * n_bottoms + n_all_body
        if given_cate == "all-body":
            return n_outer
        return 0

    def generate_outfit(self, query_item, topk, max_outfits):
        """Build up to ``max_outfits`` outfits (tuples of item dicts) containing ``query_item``.

        ``query_item`` is mutated in place: its ``"mapped_cate"`` key is set.

        Raises:
            ValueError: a top/bottom query has no complementary items, or fewer than
                ``topk`` outfits could be assembled for a top, bottom or outerwear query.

        Returns None for mapped categories other than tops, bottoms, outerwear and all-body.
        """
        query_item["mapped_cate"] = self.fashion_categories_mapping[query_item["semantic_category"]]
        given_cate = query_item["mapped_cate"]
        if given_cate in ("tops", "bottoms"):
            return self._outfits_for_top_or_bottom(query_item, given_cate, topk, max_outfits)
        if given_cate == "outerwear":
            return self._outfits_for_outerwear(query_item, topk, max_outfits)
        if given_cate == "all-body":
            return self._outfits_for_all_body(query_item, max_outfits)
        # NOTE(review): other categories have never produced outfits; None is kept
        # explicitly for backward compatibility.
        return None

    @staticmethod
    def _weighted_pop(candidates, weights):
        """Remove and return a random element of ``candidates``, drawn proportionally to ``weights``.

        The matching weight is removed too, keeping both lists aligned. If every
        weight is zero, a uniform draw is used instead, because ``random.choices``
        rejects an all-zero weight vector.
        """
        if any(weights):
            index = random.choices(range(len(candidates)), weights=weights)[0]
        else:
            index = random.randrange(len(candidates))
        weights.pop(index)
        return candidates.pop(index)

    @staticmethod
    def _check_topk(outfit_list, topk):
        """Raise ValueError if fewer than ``topk`` outfits were assembled."""
        if len(outfit_list) < topk:
            raise ValueError(f"Cannot generate {topk} outfits, only {len(outfit_list)} available!")

    def _outfits_for_top_or_bottom(self, query_item, given_cate, topk, max_outfits):
        """Pair a top with bottoms (or a bottom with tops), sometimes adding an outerwear variant."""
        complementary_cate = "bottoms" if given_cate == "tops" else "tops"
        if not self._num(complementary_cate):
            raise ValueError(f"Not enough {complementary_cate} available to generate outfits.")
        # Deep copy so popping candidates never touches the dataset's own lists, and the
        # returned complementary dicts are independent of the stored ones.
        candidates = deepcopy(self.cate2item[complementary_cate])
        # Prefer complementary items whose file names resemble the query item's name.
        weights = [sim_score(item["item_name"], query_item["item_name"]) for item in candidates]
        outerwear_items = self._items("outerwear")
        outfit_list = []
        used_names = set()
        while candidates and len(outfit_list) < max_outfits:
            item = self._weighted_pop(candidates, weights)
            if item["item_name"] in used_names:
                continue  # skip items whose name was already used
            used_names.add(item["item_name"])
            outfit_list.append((query_item, item))
            # Occasionally also emit the same pairing plus a random outerwear piece,
            # without exceeding max_outfits.
            if (outerwear_items and len(outfit_list) < max_outfits
                    and random.random() < self.OUTERWEAR_PROBABILITY):
                outfit_list.append((query_item, item, random.choice(outerwear_items)))
        self._check_topk(outfit_list, topk)
        return outfit_list

    def _outfits_for_outerwear(self, query_item, topk, max_outfits):
        """Pair an outerwear item with top/bottom combinations, weighted by top-bottom name similarity."""
        combinations = list(itertools.product(self._items("tops"), self._items("bottoms")))
        weights = [sim_score(top["item_name"], bottom["item_name"]) for top, bottom in combinations]
        outfit_list = []
        used_pairs = set()
        while combinations and len(outfit_list) < max_outfits:
            top, bottom = self._weighted_pop(combinations, weights)
            # Key on the name pair: concatenating names could collide ("ab"+"c" == "a"+"bc").
            pair_key = (top["item_name"], bottom["item_name"])
            if pair_key in used_pairs:
                continue
            used_pairs.add(pair_key)
            outfit_list.append((query_item, top, bottom))
        self._check_topk(outfit_list, topk)
        return outfit_list

    def _outfits_for_all_body(self, query_item, max_outfits):
        """Pair an all-body item with each outerwear item, randomly sampled down to ``max_outfits``.

        NOTE(review): unlike the other categories, ``topk`` is not enforced here
        (unchanged behavior).
        """
        outfit_list = [(query_item, outerwear) for outerwear in self._items("outerwear")]
        if len(outfit_list) > max_outfits:
            outfit_list = random.sample(outfit_list, max(max_outfits, 0))
        return outfit_list

    def get_item2cate(self, item_metadata):
        """Map each item name to its mapped outfit category (later duplicates overwrite earlier ones)."""
        mapping = self.fashion_categories_mapping
        return {m["item_name"]: mapping[m["semantic_category"]] for m in item_metadata}

    def get_cate2item(self, item_metadata):
        """Group item dicts by mapped category, storing the category on each dict as ``"mapped_cate"``."""
        cate2item = {}
        for metadata in item_metadata:
            mapped_cate = self.fashion_categories_mapping[metadata["semantic_category"]]
            metadata["mapped_cate"] = mapped_cate
            cate2item.setdefault(mapped_cate, []).append(metadata)
        return cate2item