新增:取消 agent 的配饰推荐(保留鞋子),改为默认随机搭配配饰;使用 JSON 文件补充被 stylist 删除掉的必要配饰

This commit is contained in:
zhh
2025-11-21 16:05:57 +08:00
parent affef482e6
commit 266b23c97d
2 changed files with 55 additions and 28 deletions

View File

@@ -296,21 +296,36 @@ class AsyncStylistAgent:
async def _get_random_accessories(self, stylist, item_count):
stylist_item = []
stylist_item_ids = []
filter_items = [
{"item_group_id": {"$ne": "Clothing"}},
{"item_group_id": {"$ne": "Shoes"}},
{"modality": "image"}
]
random_items = []
for i in stylist:
# 1. 根据stylist要求抽取item
query_embedding = self.local_db.get_clip_embedding(i['text'], is_image=False)
stylist_results = self.local_db.query_local_db(query_embedding, i['category'], n_results=10)
stylist_item += random.choices(stylist_results['metadatas'][0], k=i['count'])
stylist_item_ids += [item_id['item_id'] for item_id in stylist_item]
filter_items.append({"item_group_id": {"$ne": i['category']}})
accessories_count = 9 - item_count - len(stylist_item)
if accessories_count > 0:
accessories_count = random.randint(1, accessories_count)
if accessories_count > 4:
accessories_count = 4
for i in range(accessories_count):
random_poll = self.local_db.load_filtered_ids(filter_items)
# 2. 在配饰池中过滤掉已经选中的item 然后抽两个item
random_single_ids = random.choices(list(set(self.local_db.cache_filtered_ids) - set([f"{i}_img" for i in stylist_item_ids])), k=accessories_count)
random_items = self.local_db.random_get_accessories(random_single_ids)['metadatas']
item = self.local_db.random_get_accessories(random.choice(random_poll))
filter_items.append({"item_group_id": {"$ne": item['metadatas'][0]['category']}})
random_items.append(item['metadatas'][0])
all_items = stylist_item + random_items
else:
all_items = stylist_item

View File

@@ -1,4 +1,5 @@
import random
import time
import torch
import chromadb
@@ -17,8 +18,12 @@ class VectorDatabase():
self.model = CLIPModel.from_pretrained(embedding_model_name).to(self.device)
self.processor = CLIPProcessor.from_pretrained(embedding_model_name)
self.cache_filtered_ids = self.load_filtered_ids()
self.total_count = len(self.cache_filtered_ids)
# self.cache_filtered_ids = self.load_filtered_ids([
# {"item_group_id": {"$ne": "Clothing"}},
# {"item_group_id": {"$ne": "Shoes"}},
# {"modality": "image"}
# ])
# self.total_count = len(self.cache_filtered_ids)
def get_clip_embedding(self, data: str | Image.Image, is_image: bool) -> List[float]:
"""生成图像或文本的 CLIP 嵌入,并进行 L2 归一化。"""
@@ -63,14 +68,11 @@ class VectorDatabase():
)
return results
def load_filtered_ids(self):
print("\n--- 初始化阶段:加载所有符合条件的 ID ---")
def load_filtered_ids(self, filter_item):
# print("\n--- 初始化阶段:加载所有符合条件的 ID ---")
start_time = time.time()
FILTER_CRITERIA = {
"$and": [
{"item_group_id": {"$ne": "Clothing"}},
{"item_group_id": {"$ne": "Shoes"}}, # 新增:过滤 Shoes
{"modality": "image"}
]
"$and": filter_item
}
MAX_LIMIT = 100000
@@ -82,7 +84,8 @@ class VectorDatabase():
include=[]
)
all_matched_ids = all_ids_results['ids']
print(f"🎉 成功加载 {len(all_matched_ids)} 个 ID 到缓存。")
# print(f"🎉 成功加载 {len(all_matched_ids)} 个 ID 到缓存。")
print(time.time() - start_time)
return all_matched_ids
except Exception as e:
@@ -114,20 +117,29 @@ if __name__ == '__main__':
'count': 2,
'category': "Jewelry"
}
max_len = 5
local_db = VectorDatabase(vector_db_dir="/workspace/lc_stylist_agent/db", collection_name="lc_clothing_embedding", embedding_model_name="openai/clip-vit-base-patch32")
A = local_db.load_filtered_ids([
{"item_group_id": {"$ne": "Clothing"}},
{"item_group_id": {"$ne": "Shoes"}},
{"modality": "image"}
])
# print(db.random_get_accessories())
query_embedding = local_db.get_clip_embedding(stylist['text'], is_image=False)
results = local_db.query_local_db(query_embedding, stylist['category'], n_results=10)
# 2. 从结果集中抽 stylist['count'] 个item
stylist_item = random.choices(results['metadatas'][0], k=stylist['count'])
stylist_item_ids = [item_id['item_id'] for item_id in stylist_item]
# 3. 从随机库中抽取配饰总数达到9件 需过滤掉已经抽中的item
accessories_count = 9 - max_len - stylist['count']
random_single_ids = random.choices(list(set(local_db.cache_filtered_ids) - set([f"{i}_img" for i in stylist_item_ids])), k=accessories_count)
random_items = local_db.random_get_accessories(random_single_ids)['metadatas']
all_items = stylist_item + random_items
start_time = time.time()
X = local_db.random_get_accessories(['ELI699_img'])
print(X)
print(time.time() - start_time)
# query_embedding = local_db.get_clip_embedding(stylist['text'], is_image=False)
#
# results = local_db.query_local_db(query_embedding, stylist['category'], n_results=10)
# # 2. 从结果集中抽 stylist['count'] 个item
# stylist_item = random.choices(results['metadatas'][0], k=stylist['count'])
# stylist_item_ids = [item_id['item_id'] for item_id in stylist_item]
#
# # 3. 从随机库中抽取配饰总数达到9件 需过滤掉已经抽中的item
# accessories_count = 9 - max_len - stylist['count']
#
# random_single_ids = random.choices(list(set(local_db.cache_filtered_ids) - set([f"{i}_img" for i in stylist_item_ids])), k=accessories_count)
# random_items = local_db.random_get_accessories(random_single_ids)['metadatas']
# all_items = stylist_item + random_items