新增 随机池种类递减(把新增的类型从随机池中剔除)

This commit is contained in:
zhh
2025-11-21 17:55:47 +08:00
parent 266b23c97d
commit 78670b4210

View File

@@ -20,8 +20,9 @@ logger = logging.getLogger(__name__)
class AsyncStylistAgent:
CATEGORY_SET = {
'Activewear', 'Dresses', 'Outerwear', 'Pants', 'Shirts & Tops', 'Skirts', 'Suits', 'Swimwear', 'Underwear', 'Shoes',
'Activewear', 'Dresses', 'Outerwear', 'Pants', 'Shirts & Tops', 'Skirts', 'Suits', 'Shoes',
# 取消推荐配饰
# 'Swimwear', 'Underwear',
# , 'Watches', 'Shopping Totes', 'Sunglasses', 'Handbags', 'Backpacks', 'Belts', 'Hats', 'Jewelry', 'Briefcases', 'Socks', 'Neckties', 'Scarves & Shawls'
}
CATEGORY_SET_ALL = {
@@ -310,7 +311,7 @@ class AsyncStylistAgent:
stylist_results = self.local_db.query_local_db(query_embedding, i['category'], n_results=10)
stylist_item += random.choices(stylist_results['metadatas'][0], k=i['count'])
stylist_item_ids += [item_id['item_id'] for item_id in stylist_item]
filter_items.append({"item_group_id": {"$ne": i['category']}})
filter_items.append({"category": {"$ne": i["category"]}})
accessories_count = 9 - item_count - len(stylist_item)
@@ -318,10 +319,19 @@ class AsyncStylistAgent:
if accessories_count > 4:
accessories_count = 4
for i in range(accessories_count):
random_poll = self.local_db.load_filtered_ids(filter_items)
# 2. 在配饰池中过滤掉已经选中的item 然后抽两个item
random_poll = self.local_db.load_filtered_ids(filter_items)
logger.info(f"random_poll 数量: {len(random_poll)}")
item = self.local_db.random_get_accessories(random.choice(random_poll))
filter_items.append({"item_group_id": {"$ne": item['metadatas'][0]['category']}})
if item['metadatas'][0]['category'] in ['Shopping Totes', 'Handbags', 'Backpacks', 'Briefcases']:
filter_items.append({"category": {"$ne": "Shopping Totes"}})
filter_items.append({"category": {"$ne": "Handbags"}})
filter_items.append({"category": {"$ne": "Backpacks"}})
filter_items.append({"category": {"$ne": "Briefcases"}})
else:
filter_items.append({"category": {"$ne": item['metadatas'][0]['category']}})
random_items.append(item['metadatas'][0])
all_items = stylist_item + random_items