新增 取消agent配饰推荐,改为默认随机配饰搭配

This commit is contained in:
zhh
2025-11-19 15:10:57 +08:00
parent 36fc937f0c
commit 3b685c34f0
4 changed files with 315 additions and 7 deletions

View File

@@ -19,7 +19,16 @@ logger = logging.getLogger(__name__)
class AsyncStylistAgent: class AsyncStylistAgent:
CATEGORY_SET = {'Activewear', 'Watches', 'Shopping Totes', 'Underwear', 'Sunglasses', 'Dresses', 'Outerwear', 'Handbags', 'Backpacks', 'Belts', 'Hats', 'Skirts', 'Swimwear', 'Jewelry', 'Briefcases', 'Socks', 'Neckties', 'Pants', 'Suits', 'Shoes', 'Shirts & Tops', 'Scarves & Shawls'} CATEGORY_SET = {
'Activewear', 'Dresses', 'Outerwear', 'Pants', 'Shirts & Tops', 'Skirts', 'Suits', 'Swimwear', 'Underwear'
# 取消推荐配饰
# , 'Watches', 'Shopping Totes', 'Sunglasses', 'Handbags', 'Backpacks', 'Belts', 'Hats', 'Jewelry', 'Briefcases', 'Socks', 'Neckties', 'Shoes', 'Scarves & Shawls'
}
CATEGORY_SET_ALL = {
'Activewear', 'Dresses', 'Outerwear', 'Pants', 'Shirts & Tops', 'Skirts', 'Suits', 'Swimwear', 'Underwear',
'Watches', 'Shopping Totes', 'Sunglasses', 'Handbags', 'Backpacks', 'Belts', 'Hats', 'Jewelry',
'Briefcases', 'Socks', 'Neckties', 'Shoes', 'Scarves & Shawls'
}
def __init__(self, local_db, max_len: int, gemini_model_name: str, outfit_id=str): def __init__(self, local_db, max_len: int, gemini_model_name: str, outfit_id=str):
# self.outfit_items: List[Dict[str, str]] = [] # self.outfit_items: List[Dict[str, str]] = []
@@ -99,8 +108,8 @@ class AsyncStylistAgent:
## Your Workflow and Constraints ## Your Workflow and Constraints
1. **Style Adherence**: You must strictly observe all rules in the Style Guide concerning **color palette, fit, layering principles, pattern restrictions, accessory stacking, and shoe/bag coordination**. 1. **Style Adherence**: You must strictly observe all rules in the Style Guide concerning **color palette, fit, layering principles, pattern restrictions coordination**.
2. **Step Planning**: The styling sequence must follow a **top-down, inside-out** approach: First major garments (tops/outerwear/bottoms/dresses), then shoes and bags, and finally accessories. 2. **Step Planning**: The styling sequence must follow a **top-down, inside-out** approach: major garments (tops/outerwear/bottoms/dresses).
3. **Structured Output**: Every response must recommend the **next single item**. You must strictly use the **JSON format** for your output, as follows: 3. **Structured Output**: Every response must recommend the **next single item**. You must strictly use the **JSON format** for your output, as follows:
```json ```json
@@ -118,7 +127,6 @@ class AsyncStylistAgent:
* **Fit/Silhouette** (e.g., Oversize, loose, slim-fit) * **Fit/Silhouette** (e.g., Oversize, loose, slim-fit)
* **Material/Detail** (e.g., 100% cotton, linen, gold clasp, thin stripe, checkered pattern) * **Material/Detail** (e.g., 100% cotton, linen, gold clasp, thin stripe, checkered pattern)
* **Role in the Outfit** (e.g., serves as the innermost base layer for layering; acts as the crucial tie accent for the smart casual look) * **Role in the Outfit** (e.g., serves as the innermost base layer for layering; acts as the crucial tie accent for the smart casual look)
* **[CRITICAL FOR JEWELRY] If recommending 'Jewelry' (especially Necklaces), the description must specify its distinction (length, thickness, pendant style) from all previously selected necklaces to ensure layered variety.**
4. **Termination Condition**: Only when you deem the entire outfit complete and **all mandatory elements stipulated in the Style Guide are met**, you must output the following JSON format to terminate the process: 4. **Termination Condition**: Only when you deem the entire outfit complete and **all mandatory elements stipulated in the Style Guide are met**, you must output the following JSON format to terminate the process:
@@ -156,7 +164,7 @@ class AsyncStylistAgent:
# self._clear_uploaded_files() # self._clear_uploaded_files()
# 1. 添加图片内容 # 1. 添加图片内容
if self.outfit_items: if self.outfit_items:
merged_image = merge_images_to_square(self.outfit_items, max_len=self.max_len, add_text=False) merged_image = merge_images_to_square(self.outfit_items, max_len=self.max_len + 1, add_text=False)
image_bytes_io = io.BytesIO() image_bytes_io = io.BytesIO()
image_format = 'JPEG' image_format = 'JPEG'
mime_type = 'image/jpeg' mime_type = 'image/jpeg'
@@ -252,6 +260,21 @@ class AsyncStylistAgent:
print(f"An error occurred during item retrieval: {e}") print(f"An error occurred during item retrieval: {e}")
return None return None
def _get_random_accessories(self):
results = self.local_db.random_get_accessories()
# 3. 模拟 Agent 审核(实际应用中,你需要将图片发回给 Agent进行审核)
best_meta = results['metadatas'][0]
return {
"item_id": best_meta['item_id'], # 从 metadata 字典中安全获取
"category": best_meta['category'],
"gpt_description": best_meta['description'],
'description': best_meta['description'],
# 假设 'item_path' 存储在 metadata 中,或从 'item_id' 推导
# 这里假设 item_id 就是文件名的一部分
"image_path": os.path.join(f"{best_meta['item_id']}.jpg")
}
def _build_user_input(self) -> str: def _build_user_input(self) -> str:
"""构建发送给 Gemini 的用户输入,包含已选单品信息。""" """构建发送给 Gemini 的用户输入,包含已选单品信息。"""
if not self.outfit_items: if not self.outfit_items:
@@ -313,7 +336,20 @@ class AsyncStylistAgent:
# 3. 检查终止条件 # 3. 检查终止条件
if gemini_data.get('action') == 'stop': if gemini_data.get('action') == 'stop':
print(f"🛑 搭配完成,终止原因: {gemini_data.get('reason')}") response_data['path'] = minio_path
response_data['items'].append({"item_id": item_id, "category": item_category})
response_data['status'] = "ok"
response = post_request(url=url, data=json.dumps(response_data), headers=headers)
logger.info(f"request data {response_data} | JAVA callback info -> status:{response.status_code} | message:{response.text}")
# 新增配饰
new_item = self._get_random_accessories()
self.outfit_items.append(new_item)
user_input = self._build_user_input()
gemini_response_text, response_data['path'] = await self._call_gemini(user_input, user_id)
response_data['items'].append({"item_id": new_item.get('item_id'), "category": new_item.get('category')})
logger.info(f"🛑 搭配完成,终止原因: {gemini_data.get('reason')}")
self.stop_reason = "Finish reason: " + gemini_data.get('reason', 'No reason provided') self.stop_reason = "Finish reason: " + gemini_data.get('reason', 'No reason provided')
response_data['status'] = "stop" response_data['status'] = "stop"
response_data['message'] = self.stop_reason response_data['message'] = self.stop_reason
@@ -327,7 +363,7 @@ class AsyncStylistAgent:
description = gemini_data.get('description') description = gemini_data.get('description')
# 4a. 检查类别是否有效 (重要步骤) # 4a. 检查类别是否有效 (重要步骤)
if category not in self.CATEGORY_SET: if category not in self.CATEGORY_SET_ALL:
print(f"❌ Agent 推荐了无效类别: {category}。要求 Agent 重新输出。") print(f"❌ Agent 推荐了无效类别: {category}。要求 Agent 重新输出。")
# 在实际应用中,这里需要将错误信息发回给 Agent,要求它更正 # 在实际应用中,这里需要将错误信息发回给 Agent,要求它更正
# 这里简化为跳过本次循环 # 这里简化为跳过本次循环
@@ -382,6 +418,18 @@ class AsyncStylistAgent:
break break
if len(self.outfit_items) >= self.max_len: # 设置一个最大循环限制,防止无限循环 if len(self.outfit_items) >= self.max_len: # 设置一个最大循环限制,防止无限循环
gemini_response_text, response_data['path'] = await self._call_gemini(user_input, user_id)
response_data['items'].append({"item_id": self.outfit_items[-1]['item_id'], "category": self.outfit_items[-1]['category']})
response_data['status'] = "ok"
response = post_request(url=url, data=json.dumps(response_data), headers=headers)
logger.info(f"request data {response_data} | JAVA callback info -> status:{response.status_code} | message:{response.text}")
# 新增配饰
new_item = self._get_random_accessories()
self.outfit_items.append(new_item)
user_input = self._build_user_input()
gemini_response_text, response_data['path'] = await self._call_gemini(user_input, user_id)
response_data['items'].append({"item_id": new_item.get('item_id'), "category": new_item.get('category')})
logger.info("🚨 达到最大搭配数量限制,强制终止。") logger.info("🚨 达到最大搭配数量限制,强制终止。")
self.stop_reason = "Finish reason: Reached max outfit length." self.stop_reason = "Finish reason: Reached max outfit length."
response_data['status'] = "stop" response_data['status'] = "stop"

View File

@@ -1,3 +1,5 @@
import random
import torch import torch
import chromadb import chromadb
from PIL import Image from PIL import Image
@@ -15,6 +17,8 @@ class VectorDatabase():
self.model = CLIPModel.from_pretrained(embedding_model_name).to(self.device) self.model = CLIPModel.from_pretrained(embedding_model_name).to(self.device)
self.processor = CLIPProcessor.from_pretrained(embedding_model_name) self.processor = CLIPProcessor.from_pretrained(embedding_model_name)
self.cache_filtered_ids = self.load_filtered_ids()
self.total_count = len(self.cache_filtered_ids)
def get_clip_embedding(self, data: str | Image.Image, is_image: bool) -> List[float]: def get_clip_embedding(self, data: str | Image.Image, is_image: bool) -> List[float]:
"""生成图像或文本的 CLIP 嵌入,并进行 L2 归一化。""" """生成图像或文本的 CLIP 嵌入,并进行 L2 归一化。"""
@@ -58,3 +62,52 @@ class VectorDatabase():
include=['documents', 'metadatas', 'distances'] include=['documents', 'metadatas', 'distances']
) )
return results return results
def load_filtered_ids(self):
print("\n--- 初始化阶段:加载所有符合条件的 ID ---")
FILTER_CRITERIA = {
"$and": [
{"item_group_id": {"$ne": "Clothing"}},
{"modality": "image"},
]
}
MAX_LIMIT = 100000
try:
# 获取所有符合条件的 ID
all_ids_results = self.collection.get(
where=FILTER_CRITERIA,
limit=MAX_LIMIT,
include=[]
)
all_matched_ids = all_ids_results['ids']
print(f"🎉 成功加载 {len(all_matched_ids)} 个 ID 到缓存。")
return all_matched_ids
except Exception as e:
print(f"❌ 初始化失败:获取 ID 列表时发生错误: {e}")
return []
def random_get_accessories(self):
random_single_id = random.choice(self.cache_filtered_ids)
# 2. 调用 ChromaDB只查询这一个 ID 的详细信息
try:
final_results = self.collection.get(
ids=[random_single_id],
include=["metadatas"] # 你只需要元数据
)
# 提取结果
if final_results['ids']:
return final_results
else:
return None
except Exception as e:
print(f"❌ 获取最终记录时发生错误: {e}")
return None
if __name__ == '__main__':
db = VectorDatabase(vector_db_dir="/workspace/lc_stylist_agent/db", collection_name="lc_clothing_embedding", embedding_model_name="openai/clip-vit-base-patch32")
print(db.random_get_accessories())

View File

@@ -0,0 +1,106 @@
import time
import chromadb
import random
from typing import Dict, Any, List
# --- 你的配置 ---
DB_PATH = "/workspace/lc_stylist_agent/db"
COLLECTION_NAME = 'lc_clothing_embedding'
FILTER_CRITERIA = {
"$and": [
{"item_group_id": {"$ne": "Clothing"}},
{"modality": "image"},
]
}
MAX_LIMIT = 1000000 # 用于第一次获取所有ID的限制
client = chromadb.PersistentClient(path=DB_PATH)
try:
collection = client.get_collection(name=COLLECTION_NAME)
print(f"✅ 连接到 Collection: {COLLECTION_NAME}")
except ValueError:
print(f"⚠️ Collection '{COLLECTION_NAME}' 不存在。")
exit()
# -----------------------------------------------
# 步骤 1: 应用程序启动时/初始化时执行(只执行一次)
# -----------------------------------------------
def load_filtered_ids(coll: chromadb.api.models.Collection.Collection, filter_criteria: Dict[str, Any]) -> List[str]:
"""
加载并缓存所有符合条件的记录ID。
"""
print("\n--- 初始化阶段:加载所有符合条件的 ID ---")
try:
# 获取所有符合条件的 ID
all_ids_results = coll.get(
where=filter_criteria,
limit=MAX_LIMIT,
include=[]
)
all_matched_ids = all_ids_results['ids']
print(f"🎉 成功加载 {len(all_matched_ids)} 个 ID 到缓存。")
return all_matched_ids
except Exception as e:
print(f"❌ 初始化失败:获取 ID 列表时发生错误: {e}")
return []
# 存储所有符合条件的 ID 的全局变量 (缓存)
start_time = time.time()
CACHED_FILTERED_IDS = load_filtered_ids(collection, FILTER_CRITERIA)
print(time.time() - start_time)
# -----------------------------------------------
# 步骤 2: 每次需要随机记录时调用 (高效重复执行)
# -----------------------------------------------
def get_random_record_from_cache(coll: chromadb.api.models.Collection.Collection, cached_ids: List[str]) -> Dict[str, Any] | None:
"""
从缓存的 ID 列表中随机选择一个 ID然后查询其详细信息。
"""
total_count = len(cached_ids)
if total_count == 0:
return None
# 1. 纯 Python 内存操作:从缓存中随机选择一个 ID
random_single_id = random.choice(cached_ids)
# 2. 调用 ChromaDB只查询这一个 ID 的详细信息
try:
final_results = coll.get(
ids=[random_single_id],
)
# 提取结果
if final_results['ids']:
return {
"id": final_results['ids'][0],
"metadata": final_results['metadatas'][0]
}
else:
return None
except Exception as e:
print(f"❌ 获取最终记录时发生错误: {e}")
return None
# --- 执行并打印结果 (可以多次调用,每次都很快) ---
print("\n--- 随机获取 1 ---")
start_time = time.time()
random_data_1 = get_random_record_from_cache(collection, CACHED_FILTERED_IDS)
print(time.time() - start_time)
if random_data_1:
print(f" ID: {random_data_1}")
print("\n--- 随机获取 2 ---")
start_time = time.time()
random_data_2 = get_random_record_from_cache(collection, CACHED_FILTERED_IDS)
print(time.time() - start_time)
if random_data_2:
print(f" ID: {random_data_2}")

View File

@@ -0,0 +1,101 @@
import chromadb
from typing import Set, List, Dict, Union, Any
# --- 你的配置 ---
DB_PATH = "/workspace/lc_stylist_agent/db"
COLLECTION_NAME = 'lc_clothing_embedding'
# 设置一个足够大的限制来获取所有记录,或者使用分页(如果记录数非常庞大)
MAX_LIMIT = 1000000
client = chromadb.PersistentClient(path=DB_PATH)
try:
collection = client.get_collection(name=COLLECTION_NAME)
print(f"✅ 连接到 Collection: {COLLECTION_NAME}")
except ValueError:
print(f"⚠️ Collection '{COLLECTION_NAME}' 不存在。")
# 如果 collection 不存在,我们将跳过后续操作
collection = None
def get_category_item_group_map(
coll: chromadb.api.models.Collection.Collection
) -> Dict[str, Set[str]]:
"""
获取 Collection 中所有记录的 'category''item_group_id' 字段,
并返回一个 Category 到其所有唯一 Item Group ID 集合的映射。
"""
print("\n--- 正在获取所有记录的元数据 (category 和 item_group_id)... ---")
# 1. 获取所有记录的元数据
try:
# 使用 .get() 方法获取所有 metadatas不包含 embeddings 和 documents
results = coll.get(
limit=MAX_LIMIT,
include=["metadatas"]
)
all_metadatas: List[Dict[str, Any]] = results.get('metadatas', [])
except Exception as e:
print(f"❌ 获取元数据时发生错误: {e}")
return {}
if not all_metadatas:
print("❌ 集合中没有元数据记录。")
return {}
# 2. 构建 Category 到 Item Group ID 集合的映射
# 结构: { 'CategoryA': {'group_id_1', 'group_id_2'}, 'CategoryB': {'group_id_3'} }
category_item_group_map: Dict[str, Set[str]] = {}
for metadata in all_metadatas:
category_value: Union[str, None] = metadata.get('category')
item_group_id_value: Union[str, None] = metadata.get('item_group_id')
# 确保两个元数据字段都存在
if category_value is not None and item_group_id_value is not None:
category = str(category_value)
item_group_id = str(item_group_id_value)
# 使用 setdefault 来确保 category 键存在,并初始化为一个空的 set
# 然后将 item_group_id 添加到对应的 set 中
category_item_group_map.setdefault(category, set()).add(item_group_id)
return category_item_group_map
# --- 执行并打印结果 ---
if collection:
category_item_group_mapping = get_category_item_group_map(collection)
if category_item_group_mapping:
# 统计总共发现多少种唯一 Category
print(f"\n🎉 发现 {len(category_item_group_mapping)} 种唯一 Category 及其对应的 Item Group IDs:")
# 对类别名称进行排序,便于查看
sorted_categories = sorted(category_item_group_mapping.keys())
# 打印详细结果
for category in sorted_categories:
item_group_ids = category_item_group_mapping[category]
# 打印类别名称和唯一的 Item Group ID 数量
print(f"\n--- 👕 Category: **{category}** ---")
print(f"**Item Group 总数:** {len(item_group_ids)} 个唯一 Item Group ID")
# 将 Item Group ID 转换为列表并排序,以便展示
sorted_item_group_ids = sorted(list(item_group_ids))
# 打印前 10 个 Item Group ID 作为示例
print(f"**部分 Item Group IDs (示例):**")
# 使用列表展示
for i, item_group_id in enumerate(sorted_item_group_ids[:10]):
print(f"* {item_group_id}")
if len(sorted_item_group_ids) > 10:
print(f"* ... (还有 {len(sorted_item_group_ids) - 10} 个 Item Group ID 未列出)")
else:
print("没有找到任何 Category 或 Item Group ID 数据。请检查元数据字段名称是否正确。")