From 3b685c34f0e252d218d3d2f802c1edbacb499838 Mon Sep 17 00:00:00 2001 From: zhh Date: Wed, 19 Nov 2025 15:10:57 +0800 Subject: [PATCH] =?UTF-8?q?=E6=96=B0=E5=A2=9E=20=E5=8F=96=E6=B6=88agent?= =?UTF-8?q?=E9=85=8D=E9=A5=B0=E6=8E=A8=E8=8D=90=EF=BC=8C=E6=94=B9=E4=B8=BA?= =?UTF-8?q?=E9=BB=98=E8=AE=A4=E9=9A=8F=E6=9C=BA=E9=85=8D=E9=A5=B0=E6=90=AD?= =?UTF-8?q?=E9=85=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../ChatbotAgent/core/stylist_agent_server.py | 62 ++++++++-- .../ChatbotAgent/core/vector_database.py | 53 +++++++++ app/test/chromadb/random_accessories.py | 106 ++++++++++++++++++ app/test/chromadb/type_list.py | 101 +++++++++++++++++ 4 files changed, 315 insertions(+), 7 deletions(-) create mode 100644 app/test/chromadb/random_accessories.py create mode 100644 app/test/chromadb/type_list.py diff --git a/app/server/ChatbotAgent/core/stylist_agent_server.py b/app/server/ChatbotAgent/core/stylist_agent_server.py index f20d7ee..aacf70d 100644 --- a/app/server/ChatbotAgent/core/stylist_agent_server.py +++ b/app/server/ChatbotAgent/core/stylist_agent_server.py @@ -19,7 +19,16 @@ logger = logging.getLogger(__name__) class AsyncStylistAgent: - CATEGORY_SET = {'Activewear', 'Watches', 'Shopping Totes', 'Underwear', 'Sunglasses', 'Dresses', 'Outerwear', 'Handbags', 'Backpacks', 'Belts', 'Hats', 'Skirts', 'Swimwear', 'Jewelry', 'Briefcases', 'Socks', 'Neckties', 'Pants', 'Suits', 'Shoes', 'Shirts & Tops', 'Scarves & Shawls'} + CATEGORY_SET = { + 'Activewear', 'Dresses', 'Outerwear', 'Pants', 'Shirts & Tops', 'Skirts', 'Suits', 'Swimwear', 'Underwear' + # 取消推荐配饰 + # , 'Watches', 'Shopping Totes', 'Sunglasses', 'Handbags', 'Backpacks', 'Belts', 'Hats', 'Jewelry', 'Briefcases', 'Socks', 'Neckties', 'Shoes', 'Scarves & Shawls' + } + CATEGORY_SET_ALL = { + 'Activewear', 'Dresses', 'Outerwear', 'Pants', 'Shirts & Tops', 'Skirts', 'Suits', 'Swimwear', 'Underwear', + 'Watches', 'Shopping Totes', 'Sunglasses', 'Handbags', 'Backpacks', 'Belts', 'Hats', 'Jewelry', + 'Briefcases', 'Socks', 'Neckties', 'Shoes', 'Scarves & Shawls' + } def __init__(self, local_db, max_len: int, gemini_model_name: str, outfit_id=str): # self.outfit_items: List[Dict[str, str]] = [] @@ -99,8 +108,8 @@ class AsyncStylistAgent: ## Your Workflow and Constraints - 1. **Style Adherence**: You must strictly observe all rules in the Style Guide concerning **color palette, fit, layering principles, pattern restrictions, accessory stacking, and shoe/bag coordination**. - 2. **Step Planning**: The styling sequence must follow a **top-down, inside-out** approach: First major garments (tops/outerwear/bottoms/dresses), then shoes and bags, and finally accessories. + 1. **Style Adherence**: You must strictly observe all rules in the Style Guide concerning **color palette, fit, layering principles, pattern restrictions coordination**. + 2. **Step Planning**: The styling sequence must follow a **top-down, inside-out** approach: major garments (tops/outerwear/bottoms/dresses). 3. **Structured Output**: Every response must recommend the **next single item**. You must strictly use the **JSON format** for your output, as follows: ```json @@ -118,7 +127,6 @@ class AsyncStylistAgent: * **Fit/Silhouette** (e.g., Oversize, loose, slim-fit) * **Material/Detail** (e.g., 100% cotton, linen, gold clasp, thin stripe, checkered pattern) * **Role in the Outfit** (e.g., serves as the innermost base layer for layering; acts as the crucial tie accent for the smart casual look) - * **[CRITICAL FOR JEWELRY] If recommending 'Jewelry' (especially Necklaces), the description must specify its distinction (length, thickness, pendant style) from all previously selected necklaces to ensure layered variety.** 4. **Termination Condition**: Only when you deem the entire outfit complete and **all mandatory elements stipulated in the Style Guide are met**, you must output the following JSON format to terminate the process: @@ -156,7 +164,7 @@ class AsyncStylistAgent: # self._clear_uploaded_files() # 1. 添加图片内容 if self.outfit_items: - merged_image = merge_images_to_square(self.outfit_items, max_len=self.max_len, add_text=False) + merged_image = merge_images_to_square(self.outfit_items, max_len=self.max_len + 1, add_text=False) image_bytes_io = io.BytesIO() image_format = 'JPEG' mime_type = 'image/jpeg' @@ -252,6 +260,21 @@ class AsyncStylistAgent: print(f"An error occurred during item retrieval: {e}") return None + def _get_random_accessories(self): + results = self.local_db.random_get_accessories() + + # 3. 模拟 Agent 审核(实际应用中,你需要将图片发回给 Agent进行审核) + best_meta = results['metadatas'][0] + return { + "item_id": best_meta['item_id'], # 从 metadata 字典中安全获取 + "category": best_meta['category'], + "gpt_description": best_meta['description'], + 'description': best_meta['description'], + # 假设 'item_path' 存储在 metadata 中,或从 'item_id' 推导 + # 这里假设 item_id 就是文件名的一部分 + "image_path": os.path.join(f"{best_meta['item_id']}.jpg") + } + def _build_user_input(self) -> str: """构建发送给 Gemini 的用户输入,包含已选单品信息。""" if not self.outfit_items: @@ -313,7 +336,20 @@ class AsyncStylistAgent: # 3. 检查终止条件 if gemini_data.get('action') == 'stop': - print(f"🛑 搭配完成,终止原因: {gemini_data.get('reason')}") + response_data['path'] = minio_path + response_data['items'].append({"item_id": item_id, "category": item_category}) + response_data['status'] = "ok" + response = post_request(url=url, data=json.dumps(response_data), headers=headers) + logger.info(f"request data :{response_data} | JAVA callback info -> status:{response.status_code} | message:{response.text}") + + # 新增配饰 + new_item = self._get_random_accessories() + self.outfit_items.append(new_item) + user_input = self._build_user_input() + gemini_response_text, response_data['path'] = await self._call_gemini(user_input, user_id) + response_data['items'].append({"item_id": new_item.get('item_id'), "category": new_item.get('category')}) + + logger.info(f"🛑 搭配完成,终止原因: {gemini_data.get('reason')}") self.stop_reason = "Finish reason: " + gemini_data.get('reason', 'No reason provided') response_data['status'] = "stop" response_data['message'] = self.stop_reason @@ -327,7 +363,7 @@ class AsyncStylistAgent: description = gemini_data.get('description') # 4a. 检查类别是否有效 (重要步骤) - if category not in self.CATEGORY_SET: + if category not in self.CATEGORY_SET_ALL: print(f"❌ Agent 推荐了无效类别: {category}。要求 Agent 重新输出。") # 在实际应用中,这里需要将错误信息发回给 Agent,要求它更正 # 这里简化为跳过本次循环 @@ -382,6 +418,18 @@ class AsyncStylistAgent: break if len(self.outfit_items) >= self.max_len: # 设置一个最大循环限制,防止无限循环 + gemini_response_text, response_data['path'] = await self._call_gemini(user_input, user_id) + response_data['items'].append({"item_id": self.outfit_items[-1]['item_id'], "category": self.outfit_items[-1]['category']}) + response_data['status'] = "ok" + response = post_request(url=url, data=json.dumps(response_data), headers=headers) + logger.info(f"request data :{response_data} | JAVA callback info -> status:{response.status_code} | message:{response.text}") + + # 新增配饰 + new_item = self._get_random_accessories() + self.outfit_items.append(new_item) + user_input = self._build_user_input() + gemini_response_text, response_data['path'] = await self._call_gemini(user_input, user_id) + response_data['items'].append({"item_id": new_item.get('item_id'), "category": new_item.get('category')}) logger.info("🚨 达到最大搭配数量限制,强制终止。") self.stop_reason = "Finish reason: Reached max outfit length." response_data['status'] = "stop" diff --git a/app/server/ChatbotAgent/core/vector_database.py b/app/server/ChatbotAgent/core/vector_database.py index 10c4f4a..1baee3f 100644 --- a/app/server/ChatbotAgent/core/vector_database.py +++ b/app/server/ChatbotAgent/core/vector_database.py @@ -1,3 +1,5 @@ +import random + import torch import chromadb from PIL import Image @@ -15,6 +17,8 @@ class VectorDatabase(): self.model = CLIPModel.from_pretrained(embedding_model_name).to(self.device) self.processor = CLIPProcessor.from_pretrained(embedding_model_name) + self.cache_filtered_ids = self.load_filtered_ids() + self.total_count = len(self.cache_filtered_ids) def get_clip_embedding(self, data: str | Image.Image, is_image: bool) -> List[float]: """生成图像或文本的 CLIP 嵌入,并进行 L2 归一化。""" @@ -58,3 +62,52 @@ class VectorDatabase(): include=['documents', 'metadatas', 'distances'] ) return results + + def load_filtered_ids(self): + print("\n--- 初始化阶段:加载所有符合条件的 ID ---") + FILTER_CRITERIA = { + "$and": [ + {"item_group_id": {"$ne": "Clothing"}}, + {"modality": "image"}, + ] + } + MAX_LIMIT = 100000 + + try: + # 获取所有符合条件的 ID + all_ids_results = self.collection.get( + where=FILTER_CRITERIA, + limit=MAX_LIMIT, + include=[] + ) + all_matched_ids = all_ids_results['ids'] + print(f"🎉 成功加载 {len(all_matched_ids)} 个 ID 到缓存。") + return all_matched_ids + + except Exception as e: + print(f"❌ 初始化失败:获取 ID 列表时发生错误: {e}") + return [] + + def random_get_accessories(self): + random_single_id = random.choice(self.cache_filtered_ids) + # 2. 调用 ChromaDB:只查询这一个 ID 的详细信息 + try: + final_results = self.collection.get( + ids=[random_single_id], + include=["metadatas"] # 你只需要元数据 + ) + + # 提取结果 + if final_results['ids']: + return final_results + else: + return None + + except Exception as e: + print(f"❌ 获取最终记录时发生错误: {e}") + return None + + +if __name__ == '__main__': + db = VectorDatabase(vector_db_dir="/workspace/lc_stylist_agent/db", collection_name="lc_clothing_embedding", embedding_model_name="openai/clip-vit-base-patch32") + print(db.random_get_accessories()) diff --git a/app/test/chromadb/random_accessories.py b/app/test/chromadb/random_accessories.py new file mode 100644 index 0000000..4c494d8 --- /dev/null +++ b/app/test/chromadb/random_accessories.py @@ -0,0 +1,106 @@ +import time + +import chromadb +import random +from typing import Dict, Any, List + +# --- 你的配置 --- +DB_PATH = "/workspace/lc_stylist_agent/db" +COLLECTION_NAME = 'lc_clothing_embedding' +FILTER_CRITERIA = { + "$and": [ + {"item_group_id": {"$ne": "Clothing"}}, + {"modality": "image"}, + ] +} +MAX_LIMIT = 1000000 # 用于第一次获取所有ID的限制 + +client = chromadb.PersistentClient(path=DB_PATH) +try: + collection = client.get_collection(name=COLLECTION_NAME) + print(f"✅ 连接到 Collection: {COLLECTION_NAME}") +except ValueError: + print(f"⚠️ Collection '{COLLECTION_NAME}' 不存在。") + exit() + + +# ----------------------------------------------- +# 步骤 1: 应用程序启动时/初始化时执行(只执行一次) +# ----------------------------------------------- +def load_filtered_ids(coll: chromadb.api.models.Collection.Collection, filter_criteria: Dict[str, Any]) -> List[str]: + """ + 加载并缓存所有符合条件的记录ID。 + """ + print("\n--- 初始化阶段:加载所有符合条件的 ID ---") + + try: + # 获取所有符合条件的 ID + all_ids_results = coll.get( + where=filter_criteria, + limit=MAX_LIMIT, + include=[] + ) + all_matched_ids = all_ids_results['ids'] + print(f"🎉 成功加载 {len(all_matched_ids)} 个 ID 到缓存。") + return all_matched_ids + + except Exception as e: + print(f"❌ 初始化失败:获取 ID 列表时发生错误: {e}") + return [] + + +# 存储所有符合条件的 ID 的全局变量 (缓存) +start_time = time.time() +CACHED_FILTERED_IDS = load_filtered_ids(collection, FILTER_CRITERIA) +print(time.time() - start_time) + + +# ----------------------------------------------- +# 步骤 2: 每次需要随机记录时调用 (高效重复执行) +# ----------------------------------------------- +def get_random_record_from_cache(coll: chromadb.api.models.Collection.Collection, cached_ids: List[str]) -> Dict[str, Any] | None: + """ + 从缓存的 ID 列表中随机选择一个 ID,然后查询其详细信息。 + """ + total_count = len(cached_ids) + + if total_count == 0: + return None + + # 1. 纯 Python 内存操作:从缓存中随机选择一个 ID + random_single_id = random.choice(cached_ids) + + # 2. 调用 ChromaDB:只查询这一个 ID 的详细信息 + try: + final_results = coll.get( + ids=[random_single_id], + ) + + # 提取结果 + if final_results['ids']: + return { + "id": final_results['ids'][0], + "metadata": final_results['metadatas'][0] + } + else: + return None + + except Exception as e: + print(f"❌ 获取最终记录时发生错误: {e}") + return None + + +# --- 执行并打印结果 (可以多次调用,每次都很快) --- +print("\n--- 随机获取 1 ---") +start_time = time.time() +random_data_1 = get_random_record_from_cache(collection, CACHED_FILTERED_IDS) +print(time.time() - start_time) +if random_data_1: + print(f" ID: {random_data_1}") + +print("\n--- 随机获取 2 ---") +start_time = time.time() +random_data_2 = get_random_record_from_cache(collection, CACHED_FILTERED_IDS) +print(time.time() - start_time) +if random_data_2: + print(f" ID: {random_data_2}") diff --git a/app/test/chromadb/type_list.py b/app/test/chromadb/type_list.py new file mode 100644 index 0000000..cd6e944 --- /dev/null +++ b/app/test/chromadb/type_list.py @@ -0,0 +1,101 @@ +import chromadb +from typing import Set, List, Dict, Union, Any + +# --- 你的配置 --- +DB_PATH = "/workspace/lc_stylist_agent/db" +COLLECTION_NAME = 'lc_clothing_embedding' +# 设置一个足够大的限制来获取所有记录,或者使用分页(如果记录数非常庞大) +MAX_LIMIT = 1000000 + +client = chromadb.PersistentClient(path=DB_PATH) +try: + collection = client.get_collection(name=COLLECTION_NAME) + print(f"✅ 连接到 Collection: {COLLECTION_NAME}") +except ValueError: + print(f"⚠️ Collection '{COLLECTION_NAME}' 不存在。") + # 如果 collection 不存在,我们将跳过后续操作 + collection = None + + +def get_category_item_group_map( + coll: chromadb.api.models.Collection.Collection +) -> Dict[str, Set[str]]: + """ + 获取 Collection 中所有记录的 'category' 和 'item_group_id' 字段, + 并返回一个 Category 到其所有唯一 Item Group ID 集合的映射。 + """ + print("\n--- 正在获取所有记录的元数据 (category 和 item_group_id)... ---") + + # 1. 获取所有记录的元数据 + try: + # 使用 .get() 方法获取所有 metadatas,不包含 embeddings 和 documents + results = coll.get( + limit=MAX_LIMIT, + include=["metadatas"] + ) + + all_metadatas: List[Dict[str, Any]] = results.get('metadatas', []) + + except Exception as e: + print(f"❌ 获取元数据时发生错误: {e}") + return {} + + if not all_metadatas: + print("❌ 集合中没有元数据记录。") + return {} + + # 2. 构建 Category 到 Item Group ID 集合的映射 + # 结构: { 'CategoryA': {'group_id_1', 'group_id_2'}, 'CategoryB': {'group_id_3'} } + category_item_group_map: Dict[str, Set[str]] = {} + + for metadata in all_metadatas: + category_value: Union[str, None] = metadata.get('category') + item_group_id_value: Union[str, None] = metadata.get('item_group_id') + + # 确保两个元数据字段都存在 + if category_value is not None and item_group_id_value is not None: + category = str(category_value) + item_group_id = str(item_group_id_value) + + # 使用 setdefault 来确保 category 键存在,并初始化为一个空的 set + # 然后将 item_group_id 添加到对应的 set 中 + category_item_group_map.setdefault(category, set()).add(item_group_id) + + return category_item_group_map + + +# --- 执行并打印结果 --- +if collection: + category_item_group_mapping = get_category_item_group_map(collection) + + if category_item_group_mapping: + + # 统计总共发现多少种唯一 Category + print(f"\n🎉 发现 {len(category_item_group_mapping)} 种唯一 Category 及其对应的 Item Group IDs:") + + # 对类别名称进行排序,便于查看 + sorted_categories = sorted(category_item_group_mapping.keys()) + + # 打印详细结果 + for category in sorted_categories: + item_group_ids = category_item_group_mapping[category] + + # 打印类别名称和唯一的 Item Group ID 数量 + print(f"\n--- 👕 Category: **{category}** ---") + print(f"**Item Group 总数:** {len(item_group_ids)} 个唯一 Item Group ID") + + # 将 Item Group ID 转换为列表并排序,以便展示 + sorted_item_group_ids = sorted(list(item_group_ids)) + + # 打印前 10 个 Item Group ID 作为示例 + print(f"**部分 Item Group IDs (示例):**") + + # 使用列表展示 + for i, item_group_id in enumerate(sorted_item_group_ids[:10]): + print(f"* {item_group_id}") + + if len(sorted_item_group_ids) > 10: + print(f"* ... (还有 {len(sorted_item_group_ids) - 10} 个 Item Group ID 未列出)") + + else: + print("没有找到任何 Category 或 Item Group ID 数据。请检查元数据字段名称是否正确。") \ No newline at end of file