diff --git a/app/server/ChatbotAgent/core/stylist_agent_server.py b/app/server/ChatbotAgent/core/stylist_agent_server.py index 38a6f4e..de298ac 100644 --- a/app/server/ChatbotAgent/core/stylist_agent_server.py +++ b/app/server/ChatbotAgent/core/stylist_agent_server.py @@ -296,21 +296,36 @@ class AsyncStylistAgent: async def _get_random_accessories(self, stylist, item_count): stylist_item = [] stylist_item_ids = [] + + filter_items = [ + {"item_group_id": {"$ne": "Clothing"}}, + {"item_group_id": {"$ne": "Shoes"}}, + {"modality": "image"} + ] + random_items = [] + for i in stylist: # 1. 根据stylist要求抽取item query_embedding = self.local_db.get_clip_embedding(i['text'], is_image=False) stylist_results = self.local_db.query_local_db(query_embedding, i['category'], n_results=10) stylist_item += random.choices(stylist_results['metadatas'][0], k=i['count']) stylist_item_ids += [item_id['item_id'] for item_id in stylist_item] + filter_items.append({"item_group_id": {"$ne": i['category']}}) accessories_count = 9 - item_count - len(stylist_item) if accessories_count > 0: - accessories_count = random.randint(1, accessories_count) - # 2. 在配饰池中过滤掉已经选中的item ,然后抽两个item - random_single_ids = random.choices(list(set(self.local_db.cache_filtered_ids) - set([f"{i}_img" for i in stylist_item_ids])), k=accessories_count) - random_items = self.local_db.random_get_accessories(random_single_ids)['metadatas'] + if accessories_count > 4: + accessories_count = 4 + for i in range(accessories_count): + random_poll = self.local_db.load_filtered_ids(filter_items) + # 2. 在配饰池中过滤掉已经选中的item ,然后抽两个item + item = self.local_db.random_get_accessories(random.choice(random_poll)) + filter_items.append({"item_group_id": {"$ne": item['metadatas'][0]['category']}}) + random_items.append(item['metadatas'][0]) + all_items = stylist_item + random_items + else: all_items = stylist_item diff --git a/app/server/ChatbotAgent/core/vector_database.py b/app/server/ChatbotAgent/core/vector_database.py index bf74a23..a7b141b 100644 --- a/app/server/ChatbotAgent/core/vector_database.py +++ b/app/server/ChatbotAgent/core/vector_database.py @@ -1,4 +1,5 @@ import random +import time import torch import chromadb @@ -17,8 +18,12 @@ class VectorDatabase(): self.model = CLIPModel.from_pretrained(embedding_model_name).to(self.device) self.processor = CLIPProcessor.from_pretrained(embedding_model_name) - self.cache_filtered_ids = self.load_filtered_ids() - self.total_count = len(self.cache_filtered_ids) + # self.cache_filtered_ids = self.load_filtered_ids([ + # {"item_group_id": {"$ne": "Clothing"}}, + # {"item_group_id": {"$ne": "Shoes"}}, + # {"modality": "image"} + # ]) + # self.total_count = len(self.cache_filtered_ids) def get_clip_embedding(self, data: str | Image.Image, is_image: bool) -> List[float]: """生成图像或文本的 CLIP 嵌入,并进行 L2 归一化。""" @@ -63,14 +68,11 @@ class VectorDatabase(): ) return results - def load_filtered_ids(self): - print("\n--- 初始化阶段:加载所有符合条件的 ID ---") + def load_filtered_ids(self, filter_item): + # print("\n--- 初始化阶段:加载所有符合条件的 ID ---") + start_time = time.time() FILTER_CRITERIA = { - "$and": [ - {"item_group_id": {"$ne": "Clothing"}}, - {"item_group_id": {"$ne": "Shoes"}}, # 新增:过滤 Shoes - {"modality": "image"} - ] + "$and": filter_item } MAX_LIMIT = 100000 @@ -82,7 +84,8 @@ class VectorDatabase(): include=[] ) all_matched_ids = all_ids_results['ids'] - print(f"🎉 成功加载 {len(all_matched_ids)} 个 ID 到缓存。") + # print(f"🎉 成功加载 {len(all_matched_ids)} 个 ID 到缓存。") + print(time.time() - start_time) return all_matched_ids except Exception as e: @@ -114,20 +117,29 @@ if __name__ == '__main__': 'count': 2, 'category': "Jewelry" } + max_len = 5 local_db = VectorDatabase(vector_db_dir="/workspace/lc_stylist_agent/db", collection_name="lc_clothing_embedding", embedding_model_name="openai/clip-vit-base-patch32") + A = local_db.load_filtered_ids([ + {"item_group_id": {"$ne": "Clothing"}}, + {"item_group_id": {"$ne": "Shoes"}}, + {"modality": "image"} + ]) # print(db.random_get_accessories()) - - query_embedding = local_db.get_clip_embedding(stylist['text'], is_image=False) - - results = local_db.query_local_db(query_embedding, stylist['category'], n_results=10) - # 2. 从结果集中抽 stylist['count'] 个item - stylist_item = random.choices(results['metadatas'][0], k=stylist['count']) - stylist_item_ids = [item_id['item_id'] for item_id in stylist_item] - - # 3. 从随机库中抽取配饰,总数达到9件 ,需过滤掉已经抽中的item - accessories_count = 9 - max_len - stylist['count'] - - random_single_ids = random.choices(list(set(local_db.cache_filtered_ids) - set([f"{i}_img" for i in stylist_item_ids])), k=accessories_count) - random_items = local_db.random_get_accessories(random_single_ids)['metadatas'] - all_items = stylist_item + random_items + start_time = time.time() + X = local_db.random_get_accessories(['ELI699_img']) + print(X) + print(time.time() - start_time) + # query_embedding = local_db.get_clip_embedding(stylist['text'], is_image=False) + # + # results = local_db.query_local_db(query_embedding, stylist['category'], n_results=10) + # # 2. 从结果集中抽 stylist['count'] 个item + # stylist_item = random.choices(results['metadatas'][0], k=stylist['count']) + # stylist_item_ids = [item_id['item_id'] for item_id in stylist_item] + # + # # 3. 从随机库中抽取配饰,总数达到9件 ,需过滤掉已经抽中的item + # accessories_count = 9 - max_len - stylist['count'] + # + # random_single_ids = random.choices(list(set(local_db.cache_filtered_ids) - set([f"{i}_img" for i in stylist_item_ids])), k=accessories_count) + # random_items = local_db.random_get_accessories(random_single_ids)['metadatas'] + # all_items = stylist_item + random_items