新增 取消agent配饰推荐,改为默认随机配饰搭配
This commit is contained in:
@@ -19,7 +19,16 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AsyncStylistAgent:
|
||||
CATEGORY_SET = {'Activewear', 'Watches', 'Shopping Totes', 'Underwear', 'Sunglasses', 'Dresses', 'Outerwear', 'Handbags', 'Backpacks', 'Belts', 'Hats', 'Skirts', 'Swimwear', 'Jewelry', 'Briefcases', 'Socks', 'Neckties', 'Pants', 'Suits', 'Shoes', 'Shirts & Tops', 'Scarves & Shawls'}
|
||||
CATEGORY_SET = {
|
||||
'Activewear', 'Dresses', 'Outerwear', 'Pants', 'Shirts & Tops', 'Skirts', 'Suits', 'Swimwear', 'Underwear'
|
||||
# 取消推荐配饰
|
||||
# , 'Watches', 'Shopping Totes', 'Sunglasses', 'Handbags', 'Backpacks', 'Belts', 'Hats', 'Jewelry', 'Briefcases', 'Socks', 'Neckties', 'Shoes', 'Scarves & Shawls'
|
||||
}
|
||||
CATEGORY_SET_ALL = {
|
||||
'Activewear', 'Dresses', 'Outerwear', 'Pants', 'Shirts & Tops', 'Skirts', 'Suits', 'Swimwear', 'Underwear',
|
||||
'Watches', 'Shopping Totes', 'Sunglasses', 'Handbags', 'Backpacks', 'Belts', 'Hats', 'Jewelry',
|
||||
'Briefcases', 'Socks', 'Neckties', 'Shoes', 'Scarves & Shawls'
|
||||
}
|
||||
|
||||
def __init__(self, local_db, max_len: int, gemini_model_name: str, outfit_id=str):
|
||||
# self.outfit_items: List[Dict[str, str]] = []
|
||||
@@ -99,8 +108,8 @@ class AsyncStylistAgent:
|
||||
|
||||
## Your Workflow and Constraints
|
||||
|
||||
1. **Style Adherence**: You must strictly observe all rules in the Style Guide concerning **color palette, fit, layering principles, pattern restrictions, accessory stacking, and shoe/bag coordination**.
|
||||
2. **Step Planning**: The styling sequence must follow a **top-down, inside-out** approach: First major garments (tops/outerwear/bottoms/dresses), then shoes and bags, and finally accessories.
|
||||
1. **Style Adherence**: You must strictly observe all rules in the Style Guide concerning **color palette, fit, layering principles, pattern restrictions coordination**.
|
||||
2. **Step Planning**: The styling sequence must follow a **top-down, inside-out** approach: major garments (tops/outerwear/bottoms/dresses).
|
||||
3. **Structured Output**: Every response must recommend the **next single item**. You must strictly use the **JSON format** for your output, as follows:
|
||||
|
||||
```json
|
||||
@@ -118,7 +127,6 @@ class AsyncStylistAgent:
|
||||
* **Fit/Silhouette** (e.g., Oversize, loose, slim-fit)
|
||||
* **Material/Detail** (e.g., 100% cotton, linen, gold clasp, thin stripe, checkered pattern)
|
||||
* **Role in the Outfit** (e.g., serves as the innermost base layer for layering; acts as the crucial tie accent for the smart casual look)
|
||||
* **[CRITICAL FOR JEWELRY] If recommending 'Jewelry' (especially Necklaces), the description must specify its distinction (length, thickness, pendant style) from all previously selected necklaces to ensure layered variety.**
|
||||
|
||||
4. **Termination Condition**: Only when you deem the entire outfit complete and **all mandatory elements stipulated in the Style Guide are met**, you must output the following JSON format to terminate the process:
|
||||
|
||||
@@ -156,7 +164,7 @@ class AsyncStylistAgent:
|
||||
# self._clear_uploaded_files()
|
||||
# 1. 添加图片内容
|
||||
if self.outfit_items:
|
||||
merged_image = merge_images_to_square(self.outfit_items, max_len=self.max_len, add_text=False)
|
||||
merged_image = merge_images_to_square(self.outfit_items, max_len=self.max_len + 1, add_text=False)
|
||||
image_bytes_io = io.BytesIO()
|
||||
image_format = 'JPEG'
|
||||
mime_type = 'image/jpeg'
|
||||
@@ -252,6 +260,21 @@ class AsyncStylistAgent:
|
||||
print(f"An error occurred during item retrieval: {e}")
|
||||
return None
|
||||
|
||||
def _get_random_accessories(self):
|
||||
results = self.local_db.random_get_accessories()
|
||||
|
||||
# 3. 模拟 Agent 审核(实际应用中,你需要将图片发回给 Agent进行审核)
|
||||
best_meta = results['metadatas'][0]
|
||||
return {
|
||||
"item_id": best_meta['item_id'], # 从 metadata 字典中安全获取
|
||||
"category": best_meta['category'],
|
||||
"gpt_description": best_meta['description'],
|
||||
'description': best_meta['description'],
|
||||
# 假设 'item_path' 存储在 metadata 中,或从 'item_id' 推导
|
||||
# 这里假设 item_id 就是文件名的一部分
|
||||
"image_path": os.path.join(f"{best_meta['item_id']}.jpg")
|
||||
}
|
||||
|
||||
def _build_user_input(self) -> str:
|
||||
"""构建发送给 Gemini 的用户输入,包含已选单品信息。"""
|
||||
if not self.outfit_items:
|
||||
@@ -313,7 +336,20 @@ class AsyncStylistAgent:
|
||||
|
||||
# 3. 检查终止条件
|
||||
if gemini_data.get('action') == 'stop':
|
||||
print(f"🛑 搭配完成,终止原因: {gemini_data.get('reason')}")
|
||||
response_data['path'] = minio_path
|
||||
response_data['items'].append({"item_id": item_id, "category": item_category})
|
||||
response_data['status'] = "ok"
|
||||
response = post_request(url=url, data=json.dumps(response_data), headers=headers)
|
||||
logger.info(f"request data :{response_data} | JAVA callback info -> status:{response.status_code} | message:{response.text}")
|
||||
|
||||
# 新增配饰
|
||||
new_item = self._get_random_accessories()
|
||||
self.outfit_items.append(new_item)
|
||||
user_input = self._build_user_input()
|
||||
gemini_response_text, response_data['path'] = await self._call_gemini(user_input, user_id)
|
||||
response_data['items'].append({"item_id": new_item.get('item_id'), "category": new_item.get('category')})
|
||||
|
||||
logger.info(f"🛑 搭配完成,终止原因: {gemini_data.get('reason')}")
|
||||
self.stop_reason = "Finish reason: " + gemini_data.get('reason', 'No reason provided')
|
||||
response_data['status'] = "stop"
|
||||
response_data['message'] = self.stop_reason
|
||||
@@ -327,7 +363,7 @@ class AsyncStylistAgent:
|
||||
description = gemini_data.get('description')
|
||||
|
||||
# 4a. 检查类别是否有效 (重要步骤)
|
||||
if category not in self.CATEGORY_SET:
|
||||
if category not in self.CATEGORY_SET_ALL:
|
||||
print(f"❌ Agent 推荐了无效类别: {category}。要求 Agent 重新输出。")
|
||||
# 在实际应用中,这里需要将错误信息发回给 Agent,要求它更正
|
||||
# 这里简化为跳过本次循环
|
||||
@@ -382,6 +418,18 @@ class AsyncStylistAgent:
|
||||
break
|
||||
|
||||
if len(self.outfit_items) >= self.max_len: # 设置一个最大循环限制,防止无限循环
|
||||
gemini_response_text, response_data['path'] = await self._call_gemini(user_input, user_id)
|
||||
response_data['items'].append({"item_id": self.outfit_items[-1]['item_id'], "category": self.outfit_items[-1]['category']})
|
||||
response_data['status'] = "ok"
|
||||
response = post_request(url=url, data=json.dumps(response_data), headers=headers)
|
||||
logger.info(f"request data :{response_data} | JAVA callback info -> status:{response.status_code} | message:{response.text}")
|
||||
|
||||
# 新增配饰
|
||||
new_item = self._get_random_accessories()
|
||||
self.outfit_items.append(new_item)
|
||||
user_input = self._build_user_input()
|
||||
gemini_response_text, response_data['path'] = await self._call_gemini(user_input, user_id)
|
||||
response_data['items'].append({"item_id": new_item.get('item_id'), "category": new_item.get('category')})
|
||||
logger.info("🚨 达到最大搭配数量限制,强制终止。")
|
||||
self.stop_reason = "Finish reason: Reached max outfit length."
|
||||
response_data['status'] = "stop"
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import random
|
||||
|
||||
import torch
|
||||
import chromadb
|
||||
from PIL import Image
|
||||
@@ -15,6 +17,8 @@ class VectorDatabase():
|
||||
|
||||
self.model = CLIPModel.from_pretrained(embedding_model_name).to(self.device)
|
||||
self.processor = CLIPProcessor.from_pretrained(embedding_model_name)
|
||||
self.cache_filtered_ids = self.load_filtered_ids()
|
||||
self.total_count = len(self.cache_filtered_ids)
|
||||
|
||||
def get_clip_embedding(self, data: str | Image.Image, is_image: bool) -> List[float]:
|
||||
"""生成图像或文本的 CLIP 嵌入,并进行 L2 归一化。"""
|
||||
@@ -58,3 +62,52 @@ class VectorDatabase():
|
||||
include=['documents', 'metadatas', 'distances']
|
||||
)
|
||||
return results
|
||||
|
||||
def load_filtered_ids(self):
|
||||
print("\n--- 初始化阶段:加载所有符合条件的 ID ---")
|
||||
FILTER_CRITERIA = {
|
||||
"$and": [
|
||||
{"item_group_id": {"$ne": "Clothing"}},
|
||||
{"modality": "image"},
|
||||
]
|
||||
}
|
||||
MAX_LIMIT = 100000
|
||||
|
||||
try:
|
||||
# 获取所有符合条件的 ID
|
||||
all_ids_results = self.collection.get(
|
||||
where=FILTER_CRITERIA,
|
||||
limit=MAX_LIMIT,
|
||||
include=[]
|
||||
)
|
||||
all_matched_ids = all_ids_results['ids']
|
||||
print(f"🎉 成功加载 {len(all_matched_ids)} 个 ID 到缓存。")
|
||||
return all_matched_ids
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ 初始化失败:获取 ID 列表时发生错误: {e}")
|
||||
return []
|
||||
|
||||
def random_get_accessories(self):
|
||||
random_single_id = random.choice(self.cache_filtered_ids)
|
||||
# 2. 调用 ChromaDB:只查询这一个 ID 的详细信息
|
||||
try:
|
||||
final_results = self.collection.get(
|
||||
ids=[random_single_id],
|
||||
include=["metadatas"] # 你只需要元数据
|
||||
)
|
||||
|
||||
# 提取结果
|
||||
if final_results['ids']:
|
||||
return final_results
|
||||
else:
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ 获取最终记录时发生错误: {e}")
|
||||
return None
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
db = VectorDatabase(vector_db_dir="/workspace/lc_stylist_agent/db", collection_name="lc_clothing_embedding", embedding_model_name="openai/clip-vit-base-patch32")
|
||||
print(db.random_get_accessories())
|
||||
|
||||
Reference in New Issue
Block a user