新增 取消agent配饰推荐,改为默认随机配饰搭配

This commit is contained in:
zhh
2025-11-19 15:10:57 +08:00
parent 36fc937f0c
commit 3b685c34f0
4 changed files with 315 additions and 7 deletions

View File

@@ -0,0 +1,106 @@
import time
import chromadb
import random
from typing import Dict, Any, List
# --- 你的配置 ---
DB_PATH = "/workspace/lc_stylist_agent/db"
COLLECTION_NAME = 'lc_clothing_embedding'
FILTER_CRITERIA = {
"$and": [
{"item_group_id": {"$ne": "Clothing"}},
{"modality": "image"},
]
}
MAX_LIMIT = 1000000 # 用于第一次获取所有ID的限制
client = chromadb.PersistentClient(path=DB_PATH)
try:
collection = client.get_collection(name=COLLECTION_NAME)
print(f"✅ 连接到 Collection: {COLLECTION_NAME}")
except ValueError:
print(f"⚠️ Collection '{COLLECTION_NAME}' 不存在。")
exit()
# -----------------------------------------------
# 步骤 1: 应用程序启动时/初始化时执行(只执行一次)
# -----------------------------------------------
def load_filtered_ids(coll: chromadb.api.models.Collection.Collection, filter_criteria: Dict[str, Any]) -> List[str]:
"""
加载并缓存所有符合条件的记录ID。
"""
print("\n--- 初始化阶段:加载所有符合条件的 ID ---")
try:
# 获取所有符合条件的 ID
all_ids_results = coll.get(
where=filter_criteria,
limit=MAX_LIMIT,
include=[]
)
all_matched_ids = all_ids_results['ids']
print(f"🎉 成功加载 {len(all_matched_ids)} 个 ID 到缓存。")
return all_matched_ids
except Exception as e:
print(f"❌ 初始化失败:获取 ID 列表时发生错误: {e}")
return []
# 存储所有符合条件的 ID 的全局变量 (缓存)
start_time = time.time()
CACHED_FILTERED_IDS = load_filtered_ids(collection, FILTER_CRITERIA)
print(time.time() - start_time)
# -----------------------------------------------
# 步骤 2: 每次需要随机记录时调用 (高效重复执行)
# -----------------------------------------------
def get_random_record_from_cache(coll: chromadb.api.models.Collection.Collection, cached_ids: List[str]) -> Dict[str, Any] | None:
"""
从缓存的 ID 列表中随机选择一个 ID然后查询其详细信息。
"""
total_count = len(cached_ids)
if total_count == 0:
return None
# 1. 纯 Python 内存操作:从缓存中随机选择一个 ID
random_single_id = random.choice(cached_ids)
# 2. 调用 ChromaDB只查询这一个 ID 的详细信息
try:
final_results = coll.get(
ids=[random_single_id],
)
# 提取结果
if final_results['ids']:
return {
"id": final_results['ids'][0],
"metadata": final_results['metadatas'][0]
}
else:
return None
except Exception as e:
print(f"❌ 获取最终记录时发生错误: {e}")
return None
# --- 执行并打印结果 (可以多次调用,每次都很快) ---
print("\n--- 随机获取 1 ---")
start_time = time.time()
random_data_1 = get_random_record_from_cache(collection, CACHED_FILTERED_IDS)
print(time.time() - start_time)
if random_data_1:
print(f" ID: {random_data_1}")
print("\n--- 随机获取 2 ---")
start_time = time.time()
random_data_2 = get_random_record_from_cache(collection, CACHED_FILTERED_IDS)
print(time.time() - start_time)
if random_data_2:
print(f" ID: {random_data_2}")

View File

@@ -0,0 +1,101 @@
import chromadb
from typing import Set, List, Dict, Union, Any
# --- 你的配置 ---
DB_PATH = "/workspace/lc_stylist_agent/db"
COLLECTION_NAME = 'lc_clothing_embedding'
# 设置一个足够大的限制来获取所有记录,或者使用分页(如果记录数非常庞大)
MAX_LIMIT = 1000000
client = chromadb.PersistentClient(path=DB_PATH)
try:
collection = client.get_collection(name=COLLECTION_NAME)
print(f"✅ 连接到 Collection: {COLLECTION_NAME}")
except ValueError:
print(f"⚠️ Collection '{COLLECTION_NAME}' 不存在。")
# 如果 collection 不存在,我们将跳过后续操作
collection = None
def get_category_item_group_map(
coll: chromadb.api.models.Collection.Collection
) -> Dict[str, Set[str]]:
"""
获取 Collection 中所有记录的 'category''item_group_id' 字段,
并返回一个 Category 到其所有唯一 Item Group ID 集合的映射。
"""
print("\n--- 正在获取所有记录的元数据 (category 和 item_group_id)... ---")
# 1. 获取所有记录的元数据
try:
# 使用 .get() 方法获取所有 metadatas不包含 embeddings 和 documents
results = coll.get(
limit=MAX_LIMIT,
include=["metadatas"]
)
all_metadatas: List[Dict[str, Any]] = results.get('metadatas', [])
except Exception as e:
print(f"❌ 获取元数据时发生错误: {e}")
return {}
if not all_metadatas:
print("❌ 集合中没有元数据记录。")
return {}
# 2. 构建 Category 到 Item Group ID 集合的映射
# 结构: { 'CategoryA': {'group_id_1', 'group_id_2'}, 'CategoryB': {'group_id_3'} }
category_item_group_map: Dict[str, Set[str]] = {}
for metadata in all_metadatas:
category_value: Union[str, None] = metadata.get('category')
item_group_id_value: Union[str, None] = metadata.get('item_group_id')
# 确保两个元数据字段都存在
if category_value is not None and item_group_id_value is not None:
category = str(category_value)
item_group_id = str(item_group_id_value)
# 使用 setdefault 来确保 category 键存在,并初始化为一个空的 set
# 然后将 item_group_id 添加到对应的 set 中
category_item_group_map.setdefault(category, set()).add(item_group_id)
return category_item_group_map
# --- 执行并打印结果 ---
if collection:
category_item_group_mapping = get_category_item_group_map(collection)
if category_item_group_mapping:
# 统计总共发现多少种唯一 Category
print(f"\n🎉 发现 {len(category_item_group_mapping)} 种唯一 Category 及其对应的 Item Group IDs:")
# 对类别名称进行排序,便于查看
sorted_categories = sorted(category_item_group_mapping.keys())
# 打印详细结果
for category in sorted_categories:
item_group_ids = category_item_group_mapping[category]
# 打印类别名称和唯一的 Item Group ID 数量
print(f"\n--- 👕 Category: **{category}** ---")
print(f"**Item Group 总数:** {len(item_group_ids)} 个唯一 Item Group ID")
# 将 Item Group ID 转换为列表并排序,以便展示
sorted_item_group_ids = sorted(list(item_group_ids))
# 打印前 10 个 Item Group ID 作为示例
print(f"**部分 Item Group IDs (示例):**")
# 使用列表展示
for i, item_group_id in enumerate(sorted_item_group_ids[:10]):
print(f"* {item_group_id}")
if len(sorted_item_group_ids) > 10:
print(f"* ... (还有 {len(sorted_item_group_ids) - 10} 个 Item Group ID 未列出)")
else:
print("没有找到任何 Category 或 Item Group ID 数据。请检查元数据字段名称是否正确。")