From 266b23c97d026efa6d39d868f6dac86a46d7750f Mon Sep 17 00:00:00 2001 From: zhh Date: Fri, 21 Nov 2025 16:05:57 +0800 Subject: [PATCH] =?UTF-8?q?=E6=96=B0=E5=A2=9E=20=E5=8F=96=E6=B6=88agent?= =?UTF-8?q?=E9=85=8D=E9=A5=B0=EF=BC=88=E4=BF=9D=E7=95=99=E9=9E=8B=E5=AD=90?= =?UTF-8?q?=EF=BC=89=E6=8E=A8=E8=8D=90=EF=BC=8C=E6=94=B9=E4=B8=BA=E9=BB=98?= =?UTF-8?q?=E8=AE=A4=E9=9A=8F=E6=9C=BA=E9=85=8D=E9=A5=B0=E6=90=AD=E9=85=8D?= =?UTF-8?q?=20=E4=BD=BF=E7=94=A8json=E6=96=87=E4=BB=B6=E8=A1=A5=E5=85=85st?= =?UTF-8?q?ylist=E5=88=A0=E9=99=A4=E6=8E=89=E7=9A=84=E5=BF=85=E8=A6=81?= =?UTF-8?q?=E9=85=8D=E9=A5=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../ChatbotAgent/core/stylist_agent_server.py | 23 +++++-- .../ChatbotAgent/core/vector_database.py | 60 +++++++++++-------- 2 files changed, 55 insertions(+), 28 deletions(-) diff --git a/app/server/ChatbotAgent/core/stylist_agent_server.py b/app/server/ChatbotAgent/core/stylist_agent_server.py index 38a6f4e..de298ac 100644 --- a/app/server/ChatbotAgent/core/stylist_agent_server.py +++ b/app/server/ChatbotAgent/core/stylist_agent_server.py @@ -296,21 +296,36 @@ class AsyncStylistAgent: async def _get_random_accessories(self, stylist, item_count): stylist_item = [] stylist_item_ids = [] + + filter_items = [ + {"item_group_id": {"$ne": "Clothing"}}, + {"item_group_id": {"$ne": "Shoes"}}, + {"modality": "image"} + ] + random_items = [] + for i in stylist: # 1. 根据stylist要求抽取item query_embedding = self.local_db.get_clip_embedding(i['text'], is_image=False) stylist_results = self.local_db.query_local_db(query_embedding, i['category'], n_results=10) stylist_item += random.choices(stylist_results['metadatas'][0], k=i['count']) stylist_item_ids += [item_id['item_id'] for item_id in stylist_item] + filter_items.append({"item_group_id": {"$ne": i['category']}}) accessories_count = 9 - item_count - len(stylist_item) if accessories_count > 0: - accessories_count = random.randint(1, accessories_count) - # 2. 在配饰池中过滤掉已经选中的item ,然后抽两个item - random_single_ids = random.choices(list(set(self.local_db.cache_filtered_ids) - set([f"{i}_img" for i in stylist_item_ids])), k=accessories_count) - random_items = self.local_db.random_get_accessories(random_single_ids)['metadatas'] + if accessories_count > 4: + accessories_count = 4 + for i in range(accessories_count): + random_poll = self.local_db.load_filtered_ids(filter_items) + # 2. 在配饰池中过滤掉已经选中的item ,然后抽两个item + item = self.local_db.random_get_accessories(random.choice(random_poll)) + filter_items.append({"item_group_id": {"$ne": item['metadatas'][0]['category']}}) + random_items.append(item['metadatas'][0]) + all_items = stylist_item + random_items + else: all_items = stylist_item diff --git a/app/server/ChatbotAgent/core/vector_database.py b/app/server/ChatbotAgent/core/vector_database.py index bf74a23..a7b141b 100644 --- a/app/server/ChatbotAgent/core/vector_database.py +++ b/app/server/ChatbotAgent/core/vector_database.py @@ -1,4 +1,5 @@ import random +import time import torch import chromadb @@ -17,8 +18,12 @@ class VectorDatabase(): self.model = CLIPModel.from_pretrained(embedding_model_name).to(self.device) self.processor = CLIPProcessor.from_pretrained(embedding_model_name) - self.cache_filtered_ids = self.load_filtered_ids() - self.total_count = len(self.cache_filtered_ids) + # self.cache_filtered_ids = self.load_filtered_ids([ + # {"item_group_id": {"$ne": "Clothing"}}, + # {"item_group_id": {"$ne": "Shoes"}}, + # {"modality": "image"} + # ]) + # self.total_count = len(self.cache_filtered_ids) def get_clip_embedding(self, data: str | Image.Image, is_image: bool) -> List[float]: """生成图像或文本的 CLIP 嵌入,并进行 L2 归一化。""" @@ -63,14 +68,11 @@ class VectorDatabase(): ) return results - def load_filtered_ids(self): - print("\n--- 初始化阶段:加载所有符合条件的 ID ---") + def load_filtered_ids(self, filter_item): + # print("\n--- 初始化阶段:加载所有符合条件的 ID ---") + start_time = time.time() FILTER_CRITERIA = { - "$and": [ - {"item_group_id": {"$ne": "Clothing"}}, - {"item_group_id": {"$ne": "Shoes"}}, # 新增:过滤 Shoes - {"modality": "image"} - ] + "$and": filter_item } MAX_LIMIT = 100000 @@ -82,7 +84,8 @@ class VectorDatabase(): include=[] ) all_matched_ids = all_ids_results['ids'] - print(f"🎉 成功加载 {len(all_matched_ids)} 个 ID 到缓存。") + # print(f"🎉 成功加载 {len(all_matched_ids)} 个 ID 到缓存。") + print(time.time() - start_time) return all_matched_ids except Exception as e: @@ -114,20 +117,29 @@ if __name__ == '__main__': 'count': 2, 'category': "Jewelry" } + max_len = 5 local_db = VectorDatabase(vector_db_dir="/workspace/lc_stylist_agent/db", collection_name="lc_clothing_embedding", embedding_model_name="openai/clip-vit-base-patch32") + A = local_db.load_filtered_ids([ + {"item_group_id": {"$ne": "Clothing"}}, + {"item_group_id": {"$ne": "Shoes"}}, + {"modality": "image"} + ]) # print(db.random_get_accessories()) - - query_embedding = local_db.get_clip_embedding(stylist['text'], is_image=False) - - results = local_db.query_local_db(query_embedding, stylist['category'], n_results=10) - # 2. 从结果集中抽 stylist['count'] 个item - stylist_item = random.choices(results['metadatas'][0], k=stylist['count']) - stylist_item_ids = [item_id['item_id'] for item_id in stylist_item] - - # 3. 从随机库中抽取配饰,总数达到9件 ,需过滤掉已经抽中的item - accessories_count = 9 - max_len - stylist['count'] - - random_single_ids = random.choices(list(set(local_db.cache_filtered_ids) - set([f"{i}_img" for i in stylist_item_ids])), k=accessories_count) - random_items = local_db.random_get_accessories(random_single_ids)['metadatas'] - all_items = stylist_item + random_items + start_time = time.time() + X = local_db.random_get_accessories(['ELI699_img']) + print(X) + print(time.time() - start_time) + # query_embedding = local_db.get_clip_embedding(stylist['text'], is_image=False) + # + # results = local_db.query_local_db(query_embedding, stylist['category'], n_results=10) + # # 2. 从结果集中抽 stylist['count'] 个item + # stylist_item = random.choices(results['metadatas'][0], k=stylist['count']) + # stylist_item_ids = [item_id['item_id'] for item_id in stylist_item] + # + # # 3. 从随机库中抽取配饰,总数达到9件 ,需过滤掉已经抽中的item + # accessories_count = 9 - max_len - stylist['count'] + # + # random_single_ids = random.choices(list(set(local_db.cache_filtered_ids) - set([f"{i}_img" for i in stylist_item_ids])), k=accessories_count) + # random_items = local_db.random_get_accessories(random_single_ids)['metadatas'] + # all_items = stylist_item + random_items