新增 取消agent配饰推荐,改为默认随机配饰搭配
This commit is contained in:
106
app/test/chromadb/random_accessories.py
Normal file
106
app/test/chromadb/random_accessories.py
Normal file
@@ -0,0 +1,106 @@
|
||||
import time
|
||||
|
||||
import chromadb
|
||||
import random
|
||||
from typing import Dict, Any, List
|
||||
|
||||
# --- Configuration ---
DB_PATH = "/workspace/lc_stylist_agent/db"
COLLECTION_NAME = 'lc_clothing_embedding'
# Metadata filter: records whose modality is "image" and whose item_group_id
# is anything other than "Clothing".
FILTER_CRITERIA = {
    "$and": [
        {"item_group_id": {"$ne": "Clothing"}},
        {"modality": "image"},
    ]
}
MAX_LIMIT = 1000000  # Row cap for the one-off "fetch every matching ID" query.

client = chromadb.PersistentClient(path=DB_PATH)
try:
    # Keep only the lookup inside the try: the status print below can itself
    # raise UnicodeEncodeError (a ValueError subclass) on a non-UTF-8 stdout,
    # which must not be misreported as "collection does not exist".
    collection = client.get_collection(name=COLLECTION_NAME)
except ValueError:
    # NOTE(review): the exception type raised for a missing collection may
    # differ between chromadb releases -- confirm against the pinned version.
    print(f"⚠️ Collection '{COLLECTION_NAME}' 不存在。")
    # Nothing below can run without the collection. `exit()` is an
    # interactive-shell helper injected by `site` and returns status 0;
    # raise SystemExit with a non-zero status instead.
    raise SystemExit(1)
else:
    print(f"✅ 连接到 Collection: {COLLECTION_NAME}")
|
||||
|
||||
|
||||
# -----------------------------------------------
# Step 1: run once at application start-up / initialisation
# -----------------------------------------------
def load_filtered_ids(
    coll: "chromadb.api.models.Collection.Collection",
    filter_criteria: Dict[str, Any],
    limit: "int | None" = None,
) -> List[str]:
    """Fetch the IDs of every record matching ``filter_criteria``.

    Only IDs are requested (``include=[]``), so the result is suitable for
    keeping in memory as a cache.

    Args:
        coll: Collection to query; only its ``get`` method is used.
        filter_criteria: Metadata ``where`` filter passed through to ``get``.
        limit: Maximum number of IDs to fetch. ``None`` (the default) falls
            back to the module-level ``MAX_LIMIT``, matching the previous
            hard-coded behaviour.

    Returns:
        The list of matching IDs, or an empty list if the query fails (the
        error is printed, not raised).
    """
    print("\n--- 初始化阶段:加载所有符合条件的 ID ---")

    if limit is None:
        limit = MAX_LIMIT

    try:
        all_ids_results = coll.get(
            where=filter_criteria,
            limit=limit,
            include=[],
        )
        all_matched_ids = all_ids_results['ids']
    except Exception as e:
        # Deliberate best-effort: callers treat an empty cache as
        # "nothing to pick from" rather than crashing at start-up.
        print(f"❌ 初始化失败:获取 ID 列表时发生错误: {e}")
        return []

    print(f"🎉 成功加载 {len(all_matched_ids)} 个 ID 到缓存。")
    return all_matched_ids
|
||||
|
||||
|
||||
# Module-level cache holding every ID that matches FILTER_CRITERIA.
# perf_counter() is the clock intended for measuring short elapsed intervals;
# time.time() follows the wall clock and can jump if the system time changes.
start_time = time.perf_counter()
CACHED_FILTERED_IDS = load_filtered_ids(collection, FILTER_CRITERIA)
print(time.perf_counter() - start_time)
|
||||
|
||||
|
||||
# -----------------------------------------------
# Step 2: call whenever a random record is needed (cheap to repeat)
# -----------------------------------------------
def get_random_record_from_cache(
    coll: "chromadb.api.models.Collection.Collection",
    cached_ids: List[str],
) -> "Dict[str, Any] | None":
    """Pick one ID at random from ``cached_ids`` and fetch its metadata.

    Args:
        coll: Collection to query; only its ``get`` method is used.
        cached_ids: Candidate IDs to choose from.

    Returns:
        ``{"id": <id>, "metadata": <metadata>}`` for the chosen record, or
        ``None`` when ``cached_ids`` is empty, the chosen ID is no longer in
        the collection, or the lookup fails (the error is printed, not
        raised).
    """
    if not cached_ids:
        return None

    # 1. Pure in-memory step: choose one ID from the cache.
    random_single_id = random.choice(cached_ids)

    # 2. Single-ID lookup. Only metadata is returned to the caller, so ask
    #    for metadata only instead of the default payload.
    try:
        final_results = coll.get(
            ids=[random_single_id],
            include=["metadatas"],
        )

        if not final_results['ids']:
            # The cached ID has no matching record any more.
            return None

        return {
            "id": final_results['ids'][0],
            "metadata": final_results['metadatas'][0],
        }
    except Exception as e:
        print(f"❌ 获取最终记录时发生错误: {e}")
        return None
|
||||
|
||||
|
||||
# --- Demo: fetch a random record twice and print how long each lookup took ---
# perf_counter() is used because these lookups are short; time.time() is a
# wall clock and is not meant for timing small intervals.
print("\n--- 随机获取 1 ---")
start_time = time.perf_counter()
random_data_1 = get_random_record_from_cache(collection, CACHED_FILTERED_IDS)
print(time.perf_counter() - start_time)
if random_data_1:
    print(f" ID: {random_data_1}")

print("\n--- 随机获取 2 ---")
start_time = time.perf_counter()
random_data_2 = get_random_record_from_cache(collection, CACHED_FILTERED_IDS)
print(time.perf_counter() - start_time)
if random_data_2:
    print(f" ID: {random_data_2}")
|
||||
101
app/test/chromadb/type_list.py
Normal file
101
app/test/chromadb/type_list.py
Normal file
@@ -0,0 +1,101 @@
|
||||
import chromadb
|
||||
from typing import Set, List, Dict, Union, Any
|
||||
|
||||
# --- Configuration ---
DB_PATH = "/workspace/lc_stylist_agent/db"
COLLECTION_NAME = 'lc_clothing_embedding'
# Row cap large enough to fetch every record in one call; switch to paging
# if the collection grows very large.
MAX_LIMIT = 1000000

client = chromadb.PersistentClient(path=DB_PATH)
try:
    # Keep only the lookup inside the try: the status print below can itself
    # raise UnicodeEncodeError (a ValueError subclass) on a non-UTF-8 stdout,
    # which must not be misreported as "collection does not exist".
    collection = client.get_collection(name=COLLECTION_NAME)
except ValueError:
    # NOTE(review): the exception type raised for a missing collection may
    # differ between chromadb releases -- confirm against the pinned version.
    print(f"⚠️ Collection '{COLLECTION_NAME}' 不存在。")
    # Without the collection the rest of the script is skipped.
    collection = None
else:
    print(f"✅ 连接到 Collection: {COLLECTION_NAME}")
|
||||
|
||||
|
||||
def get_category_item_group_map(
    coll: "chromadb.api.models.Collection.Collection",
    limit: Union[int, None] = None,
) -> Dict[str, Set[str]]:
    """Map each ``category`` to the set of ``item_group_id`` values seen with it.

    Reads the metadata of up to ``limit`` records from ``coll`` and groups the
    ``item_group_id`` values by ``category``. Both values are converted with
    ``str()``; records missing either field are ignored.

    Args:
        coll: Collection to query; only its ``get`` method is used.
        limit: Maximum number of records to read. ``None`` (the default) falls
            back to the module-level ``MAX_LIMIT``, matching the previous
            hard-coded behaviour.

    Returns:
        ``{category: {item_group_id, ...}}``. An empty dict is returned when
        the query fails (the error is printed, not raised) or when no
        metadata comes back.
    """
    print("\n--- 正在获取所有记录的元数据 (category 和 item_group_id)... ---")

    if limit is None:
        limit = MAX_LIMIT

    # 1. Fetch metadata only -- no embeddings or documents.
    try:
        results = coll.get(
            limit=limit,
            include=["metadatas"],
        )
        # `or []` also covers the key being present with a None value.
        all_metadatas: List[Dict[str, Any]] = results.get('metadatas') or []
    except Exception as e:
        print(f"❌ 获取元数据时发生错误: {e}")
        return {}

    if not all_metadatas:
        print("❌ 集合中没有元数据记录。")
        return {}

    # 2. Build the mapping, e.g.
    #    {'CategoryA': {'group_id_1', 'group_id_2'}, 'CategoryB': {'group_id_3'}}
    category_item_group_map: Dict[str, Set[str]] = {}

    for metadata in all_metadatas:
        # An individual entry may be None/empty (presumably a record stored
        # without metadata -- confirm); calling .get() on it would crash.
        if not metadata:
            continue

        category_value = metadata.get('category')
        item_group_id_value = metadata.get('item_group_id')

        # Both fields must be present for the record to contribute.
        if category_value is None or item_group_id_value is None:
            continue

        category_item_group_map.setdefault(str(category_value), set()).add(
            str(item_group_id_value)
        )

    return category_item_group_map
|
||||
|
||||
|
||||
# --- Run the scan and print a per-category report ---
# Compare against None explicitly: the connect step sets `collection = None`
# when the lookup fails, and a collection object's truthiness is not a
# documented signal.
if collection is not None:
    category_item_group_mapping = get_category_item_group_map(collection)

    if category_item_group_mapping:
        # Number of distinct categories found.
        print(f"\n🎉 发现 {len(category_item_group_mapping)} 种唯一 Category 及其对应的 Item Group IDs:")

        # Sort category names so the report is stable and easy to scan.
        sorted_categories = sorted(category_item_group_mapping)

        for category in sorted_categories:
            item_group_ids = category_item_group_mapping[category]

            # Category header and its count of distinct item group IDs.
            print(f"\n--- 👕 Category: **{category}** ---")
            print(f"**Item Group 总数:** {len(item_group_ids)} 个唯一 Item Group ID")

            # sorted() accepts the set directly and returns a list.
            sorted_item_group_ids = sorted(item_group_ids)

            # Show at most the first 10 IDs as a sample.
            print("**部分 Item Group IDs (示例):**")
            for item_group_id in sorted_item_group_ids[:10]:
                print(f"* {item_group_id}")

            if len(sorted_item_group_ids) > 10:
                print(f"* ... (还有 {len(sorted_item_group_ids) - 10} 个 Item Group ID 未列出)")
    else:
        print("没有找到任何 Category 或 Item Group ID 数据。请检查元数据字段名称是否正确。")
|
||||
Reference in New Issue
Block a user