新增 取消agent配饰推荐,改为默认随机配饰搭配

This commit is contained in:
zhh
2025-11-19 15:10:57 +08:00
parent 36fc937f0c
commit 3b685c34f0
4 changed files with 315 additions and 7 deletions

View File

@@ -19,7 +19,16 @@ logger = logging.getLogger(__name__)
class AsyncStylistAgent:
CATEGORY_SET = {'Activewear', 'Watches', 'Shopping Totes', 'Underwear', 'Sunglasses', 'Dresses', 'Outerwear', 'Handbags', 'Backpacks', 'Belts', 'Hats', 'Skirts', 'Swimwear', 'Jewelry', 'Briefcases', 'Socks', 'Neckties', 'Pants', 'Suits', 'Shoes', 'Shirts & Tops', 'Scarves & Shawls'}
CATEGORY_SET = {
'Activewear', 'Dresses', 'Outerwear', 'Pants', 'Shirts & Tops', 'Skirts', 'Suits', 'Swimwear', 'Underwear'
# 取消推荐配饰
# , 'Watches', 'Shopping Totes', 'Sunglasses', 'Handbags', 'Backpacks', 'Belts', 'Hats', 'Jewelry', 'Briefcases', 'Socks', 'Neckties', 'Shoes', 'Scarves & Shawls'
}
CATEGORY_SET_ALL = {
'Activewear', 'Dresses', 'Outerwear', 'Pants', 'Shirts & Tops', 'Skirts', 'Suits', 'Swimwear', 'Underwear',
'Watches', 'Shopping Totes', 'Sunglasses', 'Handbags', 'Backpacks', 'Belts', 'Hats', 'Jewelry',
'Briefcases', 'Socks', 'Neckties', 'Shoes', 'Scarves & Shawls'
}
def __init__(self, local_db, max_len: int, gemini_model_name: str, outfit_id=str):
# self.outfit_items: List[Dict[str, str]] = []
@@ -99,8 +108,8 @@ class AsyncStylistAgent:
## Your Workflow and Constraints
1. **Style Adherence**: You must strictly observe all rules in the Style Guide concerning **color palette, fit, layering principles, pattern restrictions, accessory stacking, and shoe/bag coordination**.
2. **Step Planning**: The styling sequence must follow a **top-down, inside-out** approach: First major garments (tops/outerwear/bottoms/dresses), then shoes and bags, and finally accessories.
1. **Style Adherence**: You must strictly observe all rules in the Style Guide concerning **color palette, fit, layering principles, pattern restrictions coordination**.
2. **Step Planning**: The styling sequence must follow a **top-down, inside-out** approach: major garments (tops/outerwear/bottoms/dresses).
3. **Structured Output**: Every response must recommend the **next single item**. You must strictly use the **JSON format** for your output, as follows:
```json
@@ -118,7 +127,6 @@ class AsyncStylistAgent:
* **Fit/Silhouette** (e.g., Oversize, loose, slim-fit)
* **Material/Detail** (e.g., 100% cotton, linen, gold clasp, thin stripe, checkered pattern)
* **Role in the Outfit** (e.g., serves as the innermost base layer for layering; acts as the crucial tie accent for the smart casual look)
* **[CRITICAL FOR JEWELRY] If recommending 'Jewelry' (especially Necklaces), the description must specify its distinction (length, thickness, pendant style) from all previously selected necklaces to ensure layered variety.**
4. **Termination Condition**: Only when you deem the entire outfit complete and **all mandatory elements stipulated in the Style Guide are met**, you must output the following JSON format to terminate the process:
@@ -156,7 +164,7 @@ class AsyncStylistAgent:
# self._clear_uploaded_files()
# 1. 添加图片内容
if self.outfit_items:
merged_image = merge_images_to_square(self.outfit_items, max_len=self.max_len, add_text=False)
merged_image = merge_images_to_square(self.outfit_items, max_len=self.max_len + 1, add_text=False)
image_bytes_io = io.BytesIO()
image_format = 'JPEG'
mime_type = 'image/jpeg'
@@ -252,6 +260,21 @@ class AsyncStylistAgent:
print(f"An error occurred during item retrieval: {e}")
return None
def _get_random_accessories(self):
results = self.local_db.random_get_accessories()
# 3. 模拟 Agent 审核(实际应用中,你需要将图片发回给 Agent进行审核)
best_meta = results['metadatas'][0]
return {
"item_id": best_meta['item_id'], # 从 metadata 字典中安全获取
"category": best_meta['category'],
"gpt_description": best_meta['description'],
'description': best_meta['description'],
# 假设 'item_path' 存储在 metadata 中,或从 'item_id' 推导
# 这里假设 item_id 就是文件名的一部分
"image_path": os.path.join(f"{best_meta['item_id']}.jpg")
}
def _build_user_input(self) -> str:
"""构建发送给 Gemini 的用户输入,包含已选单品信息。"""
if not self.outfit_items:
@@ -313,7 +336,20 @@ class AsyncStylistAgent:
# 3. 检查终止条件
if gemini_data.get('action') == 'stop':
print(f"🛑 搭配完成,终止原因: {gemini_data.get('reason')}")
response_data['path'] = minio_path
response_data['items'].append({"item_id": item_id, "category": item_category})
response_data['status'] = "ok"
response = post_request(url=url, data=json.dumps(response_data), headers=headers)
logger.info(f"request data {response_data} | JAVA callback info -> status:{response.status_code} | message:{response.text}")
# 新增配饰
new_item = self._get_random_accessories()
self.outfit_items.append(new_item)
user_input = self._build_user_input()
gemini_response_text, response_data['path'] = await self._call_gemini(user_input, user_id)
response_data['items'].append({"item_id": new_item.get('item_id'), "category": new_item.get('category')})
logger.info(f"🛑 搭配完成,终止原因: {gemini_data.get('reason')}")
self.stop_reason = "Finish reason: " + gemini_data.get('reason', 'No reason provided')
response_data['status'] = "stop"
response_data['message'] = self.stop_reason
@@ -327,7 +363,7 @@ class AsyncStylistAgent:
description = gemini_data.get('description')
# 4a. 检查类别是否有效 (重要步骤)
if category not in self.CATEGORY_SET:
if category not in self.CATEGORY_SET_ALL:
print(f"❌ Agent 推荐了无效类别: {category}。要求 Agent 重新输出。")
# 在实际应用中,这里需要将错误信息发回给 Agent,要求它更正
# 这里简化为跳过本次循环
@@ -382,6 +418,18 @@ class AsyncStylistAgent:
break
if len(self.outfit_items) >= self.max_len: # 设置一个最大循环限制,防止无限循环
gemini_response_text, response_data['path'] = await self._call_gemini(user_input, user_id)
response_data['items'].append({"item_id": self.outfit_items[-1]['item_id'], "category": self.outfit_items[-1]['category']})
response_data['status'] = "ok"
response = post_request(url=url, data=json.dumps(response_data), headers=headers)
logger.info(f"request data {response_data} | JAVA callback info -> status:{response.status_code} | message:{response.text}")
# 新增配饰
new_item = self._get_random_accessories()
self.outfit_items.append(new_item)
user_input = self._build_user_input()
gemini_response_text, response_data['path'] = await self._call_gemini(user_input, user_id)
response_data['items'].append({"item_id": new_item.get('item_id'), "category": new_item.get('category')})
logger.info("🚨 达到最大搭配数量限制,强制终止。")
self.stop_reason = "Finish reason: Reached max outfit length."
response_data['status'] = "stop"

View File

@@ -1,3 +1,5 @@
import random
import torch
import chromadb
from PIL import Image
@@ -15,6 +17,8 @@ class VectorDatabase():
self.model = CLIPModel.from_pretrained(embedding_model_name).to(self.device)
self.processor = CLIPProcessor.from_pretrained(embedding_model_name)
self.cache_filtered_ids = self.load_filtered_ids()
self.total_count = len(self.cache_filtered_ids)
def get_clip_embedding(self, data: str | Image.Image, is_image: bool) -> List[float]:
"""生成图像或文本的 CLIP 嵌入,并进行 L2 归一化。"""
@@ -58,3 +62,52 @@ class VectorDatabase():
include=['documents', 'metadatas', 'distances']
)
return results
def load_filtered_ids(self):
    """Load the IDs of all image records whose group is not "Clothing".

    Queries ``self.collection`` once with a metadata filter and returns the
    matching IDs.

    Returns:
        list: the ``'ids'`` entry of the query result, or an empty list if
        the query (or reading its result) raises.
    """
    print("\n--- 初始化阶段:加载所有符合条件的 ID ---")

    # Keep only image-modality records outside the "Clothing" item group.
    where_clause = {
        "$and": [
            {"item_group_id": {"$ne": "Clothing"}},
            {"modality": "image"},
        ]
    }
    # Upper bound on how many records a single query may return.
    fetch_cap = 100000

    try:
        # include=[] requests no payload fields; only the 'ids' entry is used.
        query_result = self.collection.get(
            where=where_clause,
            limit=fetch_cap,
            include=[],
        )
        matched_ids = query_result['ids']
        print(f"🎉 成功加载 {len(matched_ids)} 个 ID 到缓存。")
    except Exception as e:
        # Best effort: report the failure and fall back to an empty cache.
        print(f"❌ 初始化失败:获取 ID 列表时发生错误: {e}")
        return []
    return matched_ids
def random_get_accessories(self):
    """Return the stored record for one randomly chosen cached ID.

    Picks a random ID from ``self.cache_filtered_ids`` and fetches its
    metadata from ``self.collection``.

    Returns:
        The result of ``self.collection.get`` for the chosen ID (metadata
        included), or None when the ID cache is empty, the lookup returns
        no IDs, or the lookup raises.
    """
    # random.choice raises IndexError on an empty sequence, and the cache is
    # empty whenever load_filtered_ids() failed or matched nothing. Report
    # that the same way as every other failure in this method: return None.
    if not self.cache_filtered_ids:
        print("❌ 没有可用的配饰 ID 缓存,无法随机获取配饰。")
        return None

    random_single_id = random.choice(self.cache_filtered_ids)
    try:
        # Query only the chosen ID; metadata is the only payload needed.
        final_results = self.collection.get(
            ids=[random_single_id],
            include=["metadatas"],
        )
    except Exception as e:
        print(f"❌ 获取最终记录时发生错误: {e}")
        return None

    # An empty 'ids' list means the cached ID is no longer in the collection.
    if final_results['ids']:
        return final_results
    return None
# Manual smoke test: open the vector database and print one random accessory
# record (or None if none could be fetched).
if __name__ == '__main__':
    vector_db = VectorDatabase(
        vector_db_dir="/workspace/lc_stylist_agent/db",
        collection_name="lc_clothing_embedding",
        embedding_model_name="openai/clip-vit-base-patch32",
    )
    print(vector_db.random_get_accessories())