代码整理
This commit is contained in:
@@ -1,2 +1,2 @@
|
||||
# SECURITY: a live Gemini API key was committed here — revoke/rotate it immediately
# and inject the value from an environment variable or secret manager instead.
GEMINI_API_KEY=
|
||||
GOOGLE_APPLICATION_CREDENTIALS="/workspace/lc_stylist_agent/app/request.json"
|
||||
GOOGLE_APPLICATION_CREDENTIALS="/workspace/lc_stylist_agent/app/request.json"
|
||||
@@ -5,12 +5,12 @@ import uuid
|
||||
import litserve as ls
|
||||
from pydantic import BaseModel
|
||||
from app.core.config import settings
|
||||
from app.core.data_structure import Message, Role
|
||||
from app.core.llm_interface import AsyncGeminiLLM
|
||||
from app.core.redis_manager import RedisManager
|
||||
from app.core.stylist_agent_server import AsyncStylistAgent
|
||||
from app.core.system_prompt import SUMMARY_PROMPT
|
||||
from app.core.vector_database import VectorDatabase
|
||||
from app.server.ChatbotAgent.core.data_structure import Message, Role
|
||||
from app.server.ChatbotAgent.core.llm_interface import AsyncGeminiLLM
|
||||
from app.server.ChatbotAgent.core.redis_manager import RedisManager
|
||||
from app.server.ChatbotAgent.core.stylist_agent_server import AsyncStylistAgent
|
||||
from app.server.ChatbotAgent.core.system_prompt import SUMMARY_PROMPT
|
||||
from app.server.ChatbotAgent.core.vector_database import VectorDatabase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -1,21 +1,18 @@
|
||||
# server.py
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from google import genai
|
||||
import litserve as ls
|
||||
|
||||
from typing import AsyncGenerator
|
||||
from google import genai
|
||||
from pydantic import BaseModel
|
||||
from app.core.config import settings
|
||||
from app.core.data_structure import Role, Message
|
||||
from app.core.llm_interface import AsyncGeminiLLM
|
||||
from app.core.redis_manager import RedisManager
|
||||
from app.core.stylist_agent import AsyncStylistAgent
|
||||
from app.core.system_prompt import BASIC_PROMPT, SUMMARY_PROMPT, MEN_BASIC_PROMPT, WOMEN_BASIC_PROMPT
|
||||
from app.core.vector_database import VectorDatabase
|
||||
from google.genai import types
|
||||
|
||||
from app.server.ChatbotAgent.core.data_structure import Message, Role
|
||||
from app.server.ChatbotAgent.core.llm_interface import AsyncGeminiLLM
|
||||
from app.server.ChatbotAgent.core.redis_manager import RedisManager
|
||||
from app.server.ChatbotAgent.core.system_prompt import MEN_BASIC_PROMPT, WOMEN_BASIC_PROMPT
|
||||
from app.server.ChatbotAgent.core.vector_database import VectorDatabase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
||||
16
app/server/ChatbotAgent/core/data_structure.py
Normal file
16
app/server/ChatbotAgent/core/data_structure.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from enum import Enum
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
# 角色枚举,用于区分用户和系统的消息
|
||||
class Role(str, Enum):
|
||||
USER = "user"
|
||||
ASSISTANT = "assistant"
|
||||
SYSTEM = "system"
|
||||
|
||||
|
||||
# 单条消息的数据模型
|
||||
class Message(BaseModel):
|
||||
role: Role = Field(..., description="Role of message sender")
|
||||
content: str = Field(..., description="Content of the message")
|
||||
# timestamp: str = Field(default_factory=lambda: datetime.datetime.now().isoformat()) # 记录时间戳
|
||||
57
app/server/ChatbotAgent/core/llm_interface.py
Normal file
57
app/server/ChatbotAgent/core/llm_interface.py
Normal file
@@ -0,0 +1,57 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List
|
||||
|
||||
from google import genai
|
||||
from google.genai import types
|
||||
|
||||
from app.server.ChatbotAgent.core.data_structure import Message, Role
|
||||
|
||||
|
||||
class AsyncLLMInterface(ABC):
    """Abstract async interface for chat-LLM backends."""

    @abstractmethod
    async def generate_response(self, history: List[Message], system_prompt: str) -> str:
        """
        Generate a reply from the conversation history and a system instruction.

        Args:
            history: List of Message objects making up the conversation so far.
            system_prompt: System instruction guiding the model's behavior.

        Returns:
            The reply string produced by the LLM.

        Raises:
            NotImplementedError: always, on the abstract base.
        """
        raise NotImplementedError("Subclasses must implement this method")
|
||||
|
||||
|
||||
class AsyncGeminiLLM(AsyncLLMInterface):
    """Gemini-backed implementation of AsyncLLMInterface (via Vertex AI)."""

    def __init__(self, model_name: str = "gemini-2.5-flash"):
        """
        Args:
            model_name: Gemini model identifier to call.

        Raises:
            Exception: same type as the underlying failure, re-raised with
                context when client construction fails.
        """
        self.model_name = model_name
        try:
            # NOTE(review): project/location are hard-coded — presumably they
            # belong in settings so non-prod environments can override them.
            self.gemini_client = genai.Client(
                vertexai=True, project='aida-461108', location='us-central1'
            )
        except Exception as e:
            # FIX: chain with `from e` so the root cause survives in the
            # traceback. NOTE(review): type(e)(msg) can itself fail for
            # exception classes whose __init__ does not take a single string;
            # consider a dedicated error type.
            raise type(e)(f"Failed to initialize Gemini Client. Check if GEMINI_API_KEY is set. Original error: {e}") from e

    async def generate_response(self, history: List[Message], system_prompt: str) -> str:
        """
        Generate a reply for the given history using the Gemini API.

        Args:
            history: Conversation turns. USER maps to Gemini role "user";
                all other roles (ASSISTANT and SYSTEM alike) map to "model" —
                presumably intentional, TODO confirm SYSTEM handling.
            system_prompt: Passed as the Gemini system instruction.

        Returns:
            The generated reply text.

        Raises:
            Exception: same type as the underlying API failure, re-raised
                with context.
        """
        contents = []
        for msg in history:
            gemini_role = "user" if msg.role == Role.USER else "model"
            contents.append(
                types.Content(
                    role=gemini_role,
                    parts=[types.Part.from_text(text=msg.content)]
                )
            )

        try:
            response = await self.gemini_client.aio.models.generate_content(
                model=self.model_name,
                contents=contents,
                config=types.GenerateContentConfig(
                    system_instruction=system_prompt,
                    # temperature=0.3,
                )
            )
            return response.text
        except Exception as e:
            # FIX: preserve the cause chain for debugging.
            raise type(e)(f"Gemini API call failed: {e}") from e
|
||||
69
app/server/ChatbotAgent/core/redis_manager.py
Normal file
69
app/server/ChatbotAgent/core/redis_manager.py
Normal file
@@ -0,0 +1,69 @@
|
||||
import logging
|
||||
|
||||
import redis
|
||||
from typing import List, Optional
|
||||
|
||||
from app.server.ChatbotAgent.core.data_structure import Message, Role
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# 这是一个同步 Redis 客户端,用于演示如何替换内存存储。
|
||||
# TODO 在生产环境和异步 Web 框架中,应替换为 aioredis 等异步客户端。
|
||||
# Synchronous Redis client used in place of in-memory storage.
# TODO: in production with an async web framework, replace with an async
# client such as redis.asyncio (aioredis).
class RedisManager:
    """Synchronous manager that stores and retrieves chat history in Redis.

    If the initial connection fails, the manager degrades to a no-op:
    saves are dropped and reads return an empty history.
    """

    def __init__(self, host: str = 'localhost', port: int = 6379, db: int = 0, key_prefix: str = "chat:history:"):
        """Connect to Redis, falling back to no-op mode on failure.

        Args:
            host: Redis server host.
            port: Redis server port.
            db: Redis logical database index.
            key_prefix: Prefix prepended to per-user history keys.
        """
        self.r: Optional[redis.Redis] = None
        self.key_prefix = key_prefix
        try:
            # decode_responses=True makes the client return str instead of bytes.
            self.r = redis.Redis(host=host, port=port, db=db, decode_responses=True)
            self.r.ping()
            # FIX: use lazy %-style logging args instead of eager
            # .format()/f-strings, so the message is only built when emitted.
            logger.info("Successfully connected to Redis at %s:%s", host, port)
        except Exception as e:
            logger.error("⚠️ Failed to connect to Redis: %s. Falling back to No-Op.", e)
            self.r = None  # keep None so later operations degrade gracefully

    def _get_key(self, user_id: str) -> str:
        """Build the Redis key holding this user's history list."""
        return f"{self.key_prefix}{user_id}"

    def _message_to_json(self, message: Message) -> str:
        """Serialize a Message to a JSON string for storage."""
        return message.model_dump_json()

    def _json_to_message(self, data: str) -> Message:
        """Deserialize a JSON string back into a Message.

        Returns a placeholder assistant message instead of raising, so one
        corrupt entry cannot break reading the whole history.
        """
        try:
            return Message.model_validate_json(data)
        except Exception as e:
            logger.error("Error deserializing message data: %s... Error: %s", data[:50], e)
            return Message(role=Role.ASSISTANT, content="[Deserialization Error]")

    def save_message(self, user_id: str, message: Message):
        """Append a single message to the end of the user's history list."""
        if not self.r:
            return

        # RPUSH appends to the tail of the list.
        message_json = self._message_to_json(message)
        self.r.rpush(self._get_key(user_id), message_json)

    def get_history(self, user_id: str) -> List[Message]:
        """Return the user's full conversation history (empty when no Redis)."""
        if not self.r:
            return []

        # LRANGE 0..-1 fetches every element of the list.
        raw_history = self.r.lrange(self._get_key(user_id), 0, -1)

        # Convert the stored JSON strings back into Message objects.
        return [self._json_to_message(data) for data in raw_history]

    def clear_history(self, user_id: str):
        """Delete the user's entire history."""
        if self.r:
            self.r.delete(self._get_key(user_id))
            logger.info("History cleared for %s", user_id)
|
||||
435
app/server/ChatbotAgent/core/stylist_agent_server.py
Normal file
435
app/server/ChatbotAgent/core/stylist_agent_server.py
Normal file
@@ -0,0 +1,435 @@
|
||||
import asyncio
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import uuid
|
||||
from typing import List, Dict, Any, Optional
|
||||
|
||||
from google import genai
|
||||
from google.cloud import storage
|
||||
from google.oauth2 import service_account
|
||||
|
||||
from app.core.utils_litserve import merge_images_to_square
|
||||
from app.server.utils.minio_client import minio_client, oss_upload_image
|
||||
from app.server.utils.request_post import post_request
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AsyncStylistAgent:
    """Async agent that builds a complete outfit one item at a time.

    Each loop iteration asks Gemini for the next item (structured JSON),
    retrieves the best-matching catalog item from a local vector DB, and
    finally posts the result to a callback URL.
    """

    # Closed set of categories the agent is allowed to recommend.
    CATEGORY_SET = {'Activewear', 'Watches', 'Shopping Totes', 'Underwear', 'Sunglasses', 'Dresses', 'Outerwear', 'Handbags', 'Backpacks', 'Belts', 'Hats', 'Skirts', 'Swimwear', 'Jewelry', 'Briefcases', 'Socks', 'Neckties', 'Pants', 'Suits', 'Shoes', 'Shirts & Tops', 'Scarves & Shawls'}

    def __init__(self, local_db, max_len: int, gemini_model_name: str, outfit_id: str = ""):
        """
        Args:
            local_db: Vector DB exposing get_clip_embedding / query_local_db.
            max_len: Maximum number of items per outfit (loop safety limit).
            gemini_model_name: Gemini model identifier.
            outfit_id: Identifier echoed back in the callback payload.
                FIX: the default used to be ``outfit_id=str`` — the *type*
                object, almost certainly a typo for an annotation.

        Raises:
            RuntimeError: when the GCS service-account credentials cannot load.
        """
        # self.outfit_items: List[Dict[str, str]] = []
        self.outfit_id = outfit_id
        self.gemini_client = genai.Client(
            vertexai=True, project='aida-461108', location='us-central1'
        )
        self.local_db = local_db
        self.max_len = max_len
        self.gemini_model_name = gemini_model_name
        self.stop_reason = ""

        # Storage bucket configuration.
        try:
            # TODO: the credentials path is currently taken from a hard-coded
            # env var; switch for the production environment.
            self.credentials = service_account.Credentials.from_service_account_file(os.getenv("GOOGLE_APPLICATION_CREDENTIALS"))
        except Exception as e:
            # FIX: chain the cause so the original failure is visible.
            raise RuntimeError(f"Failed to load credentials from file {os.getenv('GOOGLE_APPLICATION_CREDENTIALS')}: {e}") from e

        self.gcs_client = storage.Client(
            project=self.credentials.project_id,
            credentials=self.credentials
        )
        self.gcs_bucket = "lc_stylist_agent_outfit_items"
        self.minio_bucket = "lanecarford"

    def _load_style_guide(self, path: str) -> str:
        """Load the markdown style-guide content from MinIO.

        Args:
            path: Locator in "bucket_name/object_name" form.

        Returns:
            The guide's UTF-8 text.

        Raises:
            ValueError: if ``path`` is not in bucket/object form.
            Exception: wrapping any MinIO retrieval failure.
        """
        parts = path.split('/', 1)
        if len(parts) != 2:
            raise ValueError("MinIO path must be in 'bucket_name/object_name' format.")

        bucket_name, object_name = parts
        try:
            # 1. fetch the object
            response = minio_client.get_object(bucket_name, object_name)

            # 2. read its content
            content_bytes = response.read()

            # 3. release the connection
            response.close()
            response.release_conn()

            # 4. decode and return
            return content_bytes.decode('utf-8')

        except Exception as e:
            raise Exception(f"Failed to load style guide from {path}: {e}") from e

    def _build_system_prompt(self, request_summary: str = "") -> str:
        """Constructs the complete System Prompt.

        Interpolates the user's request summary and the loaded style guide
        into the agent instruction template. Assumes self.style_guide has
        been set (done in run_styling_process) — TODO confirm call order.
        """
        template = f"""
You are a professional fashion stylist Agent, specialized in creating complete outfits for the user.

Your task is to **create a cohesive and complete outfit**, strictly adhering to **BOTH** the user's explicit **Request Summary** and the **Outfit Style Guide**. You must decide the next logical item to add to the outfit based on the currently selected items (if any).

---

## Request from the User:

{request_summary}

## Core Guidance Document: Outfit Style Guide

{self.style_guide}

---

## Your Workflow and Constraints

1. **Style Adherence**: You must strictly observe all rules in the Style Guide concerning **color palette, fit, layering principles, pattern restrictions, accessory stacking, and shoe/bag coordination**.
2. **Step Planning**: The styling sequence must follow a **top-down, inside-out** approach: First major garments (tops/outerwear/bottoms/dresses), then shoes and bags, and finally accessories.
3. **Structured Output**: Every response must recommend the **next single item**. You must strictly use the **JSON format** for your output, as follows:

```json
{{
"action": "recommend_item",
"category": "YOUR_ITEM_CATEGORY",
"description": "YOUR_DETAILED_DESCRIPTION"
}}
```

* `action`: Must always be `"recommend_item"` until the outfit is complete.
* `category`: Must be the category of the item you are recommending, strictly selected from the following list: {list(self.CATEGORY_SET)}.
* `description`: This must be an **extremely detailed and precise** description of the item. This description is used for **high-accuracy vector search** in the database and must include:
* **Color** (e.g., milk tea, pure white, dark gray)
* **Fit/Silhouette** (e.g., Oversize, loose, slim-fit)
* **Material/Detail** (e.g., 100% cotton, linen, gold clasp, thin stripe, checkered pattern)
* **Role in the Outfit** (e.g., serves as the innermost base layer for layering; acts as the crucial tie accent for the smart casual look)
* **[CRITICAL FOR JEWELRY] If recommending 'Jewelry' (especially Necklaces), the description must specify its distinction (length, thickness, pendant style) from all previously selected necklaces to ensure layered variety.**

4. **Termination Condition**: Only when you deem the entire outfit complete and **all mandatory elements stipulated in the Style Guide are met**, you must output the following JSON format to terminate the process:

```json
{{
"action": "stop",
"reason": "OUTFIT_COMPLETE_AND_MEETS_ALL_MINI_GUIDELINES"
}}
```
Normally, five or six items are totally enough for an outfit.

5. **Context Dependency**: The user's next input (if not `Start`) will contain the **image and description of the selected item**. When recommending the next item, you must consider the coordination between the **already selected items** and the Style Guide.

**Now, please start building an outfit and output the JSON for the first item.**
"""
        return template.strip()

    def _clear_uploaded_files(self):
        """Delete every file previously uploaded to the Gemini Files API."""
        for f in self.gemini_client.files.list():
            self.gemini_client.files.delete(name=f.name)

    async def _call_gemini(self, user_input: str, user_id: str):
        """Call the Gemini API with the current outfit image (if any) and text.

        Args:
            user_input: Main text content sent to the model.
            user_id: Used to namespace the uploaded merged-outfit image.

        Returns:
            Tuple of (response text — expected to be a JSON string,
            MinIO path of the uploaded merged image, or "" when no image).
        """
        minio_path = ""
        content_parts = []
        # self._clear_uploaded_files()
        # 1. attach the image content (merged picture of already-selected items)
        if self.outfit_items:
            merged_image = merge_images_to_square(self.outfit_items, max_len=self.max_len, add_text=False)
            image_bytes_io = io.BytesIO()
            image_format = 'JPEG'
            mime_type = 'image/jpeg'

            merged_image.save(image_bytes_io, format=image_format)
            image_bytes = image_bytes_io.getvalue()

            file_name = uuid.uuid4()
            blob_name = f"lc_stylist_agent_outfit_items/{user_id}/{file_name}.jpg"
            gcs_path = self._upload_to_gcs(bucket_name=self.gcs_bucket, blob_name=blob_name, mime_type=mime_type, image_bytes=image_bytes)
            responses = oss_upload_image(oss_client=minio_client, bucket=self.minio_bucket, object_name=blob_name, image_bytes=image_bytes)
            minio_path = f"{responses.bucket_name}/{responses.object_name}"
            content_parts.append(gcs_path)

        # 2. attach the text content
        content_parts.append(user_input)

        # print(f"\n--- Calling Gemini with {len(self.outfit_items) if self.outfit_items else 0} images and query:\n{user_input}")

        try:
            # 3. actual API call; response_schema forces structured JSON output
            response = await self.gemini_client.aio.models.generate_content(
                model=self.gemini_model_name,
                contents=content_parts,
                config={
                    "system_instruction": self.system_prompt,
                    # make sure the model returns JSON
                    "response_mime_type": "application/json",
                    "response_schema": {
                        "type": "object",
                        "properties": {
                            "action": {"type": "string", "enum": ["recommend_item", "stop"]},
                            "category": {"type": "string"},
                            "description": {"type": "string"},
                            "reason": {"type": "string"}
                        },
                        "required": ["action"]
                    }
                }
            )

            # response.text holds a JSON string
            return response.text, minio_path

        except Exception as e:
            print(f"Gemini API Call failed: {e}")
            # FIX: the error path used to return a single value while the
            # success path returns a 2-tuple, which broke tuple unpacking at
            # the caller; both paths now return (text, minio_path).
            return json.dumps({"action": "stop", "reason": f"API_ERROR: {str(e)}"}), minio_path

    def _parse_gemini_response(self, response_text: str) -> Optional[Dict[str, Any]]:
        """Safely parse Gemini's JSON response; returns None when malformed."""
        try:
            # Gemini sometimes wraps the JSON in markdown fences — strip them.
            response_text = response_text.strip().replace('```json', '').replace('```', '')
            data = json.loads(response_text)
            # print(f"The agent response is: {data}")
            return data
        except json.JSONDecodeError as e:
            print(f"Error parsing JSON from Gemini: {e}")
            print(f"Raw response: {response_text}")
            return None

    def _get_next_item(self, item_description: str, category: str) -> Optional[Dict[str, str]]:
        """Find the best-matching catalog item for a recommendation.

        1. Embed the description.
        2. Query the local DB filtered by category.
        3. Simulated agent review of the match (currently always accepted).

        Returns:
            Item metadata dict, or None when nothing matches or on error.
        """
        try:
            # 1. build the query embedding
            query_embedding = self.local_db.get_clip_embedding(item_description, is_image=False)

            # 2. run the query, filtered by category
            results = self.local_db.query_local_db(query_embedding, category, n_results=1)

            if not results:
                print(f"❌ 数据库中未找到符合 '{category}' 和描述的单品。")
                return None

            # 3. simulated agent review (a full flow would send the matched
            #    image back to the agent for confirmation)
            best_meta = results['metadatas'][0][0]  # first metadata of the first batch
            return {
                "item_id": best_meta['item_id'],  # safe lookup from the metadata dict
                "category": category,
                "gpt_description": item_description,
                'description': best_meta['description'],
                # assumes item_id doubles as the image file stem — TODO confirm
                "image_path": os.path.join(f"{best_meta['item_id']}.jpg")
            }

        except Exception as e:
            print(f"An error occurred during item retrieval: {e}")
            return None

    def _build_user_input(self) -> str:
        """Build the user turn sent to Gemini, embedding selected-item context."""
        if not self.outfit_items:
            return "Start"

        # Feed the already-selected items back to the agent as context.
        context = "Selected fashion items:\n"
        for ii, item in enumerate(self.outfit_items):
            context += f"{ii + 1}. Category: {item['category']}. Description: {item['description']}\n"
        context += "\nPlease recommend the next single item based on the selected items, user's request, and style guide."
        return context

    async def run_styling_process(self, request_summary, stylist_path, start_outfit=None, user_id="test", callback_url=""):
        """Main control loop: iterate Gemini recommendations until the outfit
        is complete, then post the result to the callback URL.

        FIX: the docstring used to sit mid-function as a no-op string
        statement; it is now a proper docstring.
        """
        if start_outfit is None:
            start_outfit = []
        self.outfit_items = start_outfit if start_outfit else []
        print(f"--- Starting Agent (Outfit ID: {self.outfit_id}) ---")

        self.style_guide = self._load_style_guide(stylist_path)
        self.system_prompt = self._build_system_prompt(request_summary)
        response_data = {"status": "",
                         "message": "",
                         "path": "",
                         "outfit_id": self.outfit_id,
                         "items": []
                         }
        logger.info(response_data)
        item_id = ""
        item_category = ""
        while True:
            # 1. prepare the user input (context of selected items)
            user_input = self._build_user_input()

            # 2. call the Gemini agent
            gemini_response_text, minio_path = await self._call_gemini(user_input, user_id)
            gemini_data = self._parse_gemini_response(gemini_response_text)
            response_data['path'] = minio_path
            # The item chosen in the previous iteration is recorded here.
            if item_id:
                response_data['items'].append({"item_id": item_id, "category": item_category})
            if not gemini_data:
                print("🚨 Agent 返回无效响应,终止流程。")
                self.stop_reason = "Agent failed to return response"
                response_data['status'] = "failed"
                response_data['message'] = self.stop_reason
                break

            # 3. check the termination condition
            if gemini_data.get('action') == 'stop':
                print(f"🛑 搭配完成,终止原因: {gemini_data.get('reason')}")
                self.stop_reason = "Finish reason: " + gemini_data.get('reason', 'No reason provided')
                response_data['status'] = "stop"
                response_data['message'] = self.stop_reason
                # FIX: without this break the loop kept querying Gemini after
                # it had already signalled completion. NOTE(review): confirm
                # no later step relied on the fall-through.
                break

            # 4. handle a recommended item
            if gemini_data.get('action') == 'recommend_item':
                category = gemini_data.get('category')
                description = gemini_data.get('description')

                # 4a. validate the category (important step)
                if category not in self.CATEGORY_SET:
                    print(f"❌ Agent 推荐了无效类别: {category}。要求 Agent 重新输出。")
                    # In a real flow we would send the error back to the agent
                    # for correction; simplified here to skipping this round.
                    response_data['status'] = "continue"
                    # FIX: a trailing comma here used to turn the message into a 1-tuple.
                    response_data['message'] = f"❌ Agent 推荐了无效类别: {category}。要求 Agent 重新输出。"
                    continue

                # 4b. look the item up in the local DB
                new_item = self._get_next_item(description, category)

                if new_item:
                    # FIX: these lookups used to run before the None check,
                    # crashing with AttributeError when no match was found.
                    item_id = new_item.get('item_id')
                    item_category = new_item.get('category')

                    # 4c. (real flow) send the chosen item back to the agent
                    # for a final review; skipped here and treated as passed.

                    if new_item['item_id'] in [x['item_id'] for x in self.outfit_items]:
                        print("This item exists. Stop here.")
                        self.stop_reason = "Finish reason: Duplicate item selected."
                        response_data['status'] = "stop"
                        response_data['message'] = self.stop_reason
                        break

                    # NOTE(review): ad-hoc special case for one item id —
                    # presumably a temporary mitigation; confirm and remove.
                    if new_item['item_id'] == "ELG383":
                        if random.random() < 0.70:
                            # FIX: typo "seleced repeatly" corrected.
                            self.stop_reason = "Finish reason: ELG383 is selected repeatedly."
                            response_data['status'] = "stop"
                            response_data['message'] = self.stop_reason
                            break

                    self.outfit_items.append(new_item)
                    response_data['status'] = "ok"
                    response_data['message'] = self.stop_reason
                else:
                    print("⚠️ 未找到匹配单品,无法继续搭配。终止。")
                    self.stop_reason = "Finish reason: No matching item found in local database."
                    response_data['status'] = "stop"
                    response_data['message'] = self.stop_reason
                    break

            # Safety limit against an infinite loop.
            if len(self.outfit_items) >= self.max_len:
                logger.info("🚨 达到最大搭配数量限制,强制终止。")
                self.stop_reason = "Finish reason: Reached max outfit length."
                response_data['status'] = "stop"
                response_data['message'] = self.stop_reason
                # FIX: the forced-termination branch set the stop status but
                # never actually exited the loop.
                break

        logger.info(f"request data :{response_data}")
        headers = {
            'Accept': "*/*",
            'Accept-Encoding': "gzip, deflate, br",
            'User-Agent': "PostmanRuntime-ApipostRuntime/1.1.0",
            'Connection': "keep-alive",
            'Content-Type': "application/json"
        }
        url = f'{callback_url}/api/style/callback'
        response = post_request(url=url, data=json.dumps(response_data), headers=headers)
        logger.info(response.text)
        return response_data

    # def _save_outfit_results(self, user_id):
    #     """Persist the final JSON list and images to storage."""
    #     if not self.outfit_items:
    #         raise ValueError("No outfit items to save.")
    #
    #     # 1. build the JSON payload
    #     results_list = [{'item_id': item['item_id'], 'category': item['category'], 'description': item['description'], 'gpt_description': item['gpt_description']} for item in self.outfit_items]
    #     results_list.append({'stop_reason': self.stop_reason})
    #
    #     return upload_json_to_minio_sync(
    #         minio_client=minio_client,
    #         bucket_name=f"lanecarford",
    #         object_name=f"lc_stylist_agent_outfit_items/{user_id}/{uuid.uuid4()}.json",
    #         data=results_list
    #     )

    def _upload_to_gcs(self, bucket_name: str, blob_name: str, mime_type, image_bytes) -> str:
        """Synchronously upload bytes to GCS and return the gs:// URI."""
        bucket = self.gcs_client.bucket(bucket_name)
        blob = bucket.blob(blob_name)
        blob.upload_from_string(
            data=image_bytes,
            content_type=mime_type
        )

        gcs_uri = f"gs://{bucket_name}/{blob_name}"
        return gcs_uri

    async def recommend_outfit(self, request_summary: str, stylist_name: str, start_outfit: Optional[List[Dict[str, str]]] = None, num_outfits: int = 1):
        """Concurrently generate one or more outfits for the summarized request.

        Args:
            request_summary: The user's distilled request.
            stylist_name: MinIO path of the style guide to apply.
            start_outfit: Optional pre-selected items ('item_id'/'category').
                FIX: was a mutable default ``[]`` shared across calls; now
                ``None`` with a per-call default.
            num_outfits: Number of outfits to generate concurrently.

        Returns:
            Dict with 'successful_outfits' and 'failed_outfits' lists, or
            {'error': ...} on an unexpected failure.
        """
        if start_outfit is None:
            start_outfit = []
        tasks = []
        for _ in range(num_outfits):
            # NOTE(review): self.stylist_agent_kwages is never assigned in this
            # class (looks like a typo for "kwargs") — this raises
            # AttributeError unless it is set externally. TODO confirm intent.
            agent = AsyncStylistAgent(**self.stylist_agent_kwages)
            task = agent.run_styling_process(request_summary, stylist_name, start_outfit)
            tasks.append(task)
        print(f"--- Starting {num_outfits} concurrent outfit generation tasks. ---")

        try:
            results = await asyncio.gather(*tasks, return_exceptions=True)

            successful_outfits = []
            failed_outfits = []
            for result in results:
                if isinstance(result, Exception):
                    # the task raised during execution
                    failed_outfits.append(f"Failed: {result}")
                else:
                    # success: result is run_styling_process's response dict
                    successful_outfits.append(result)

            return {
                "successful_outfits": successful_outfits,
                "failed_outfits": failed_outfits
            }

        except Exception as e:
            print(f"An unexpected error occurred during concurrent recommendation: {e}")
            return {"error": str(e)}
|
||||
56
app/server/ChatbotAgent/core/system_prompt.py
Normal file
56
app/server/ChatbotAgent/core/system_prompt.py
Normal file
@@ -0,0 +1,56 @@
|
||||
# System prompts for the chatbot styling agent.
#
# NOTE(review): MEN_BASIC_PROMPT is a near-verbatim copy of WOMEN_BASIC_PROMPT
# (only the first line differs) — goal #4 and the guidance still reference
# dresses, skirts and necklines; presumably it should use men's-wear
# examples. TODO confirm before changing the runtime string.

# Placeholder prompt, currently empty.
BASIC_PROMPT = """"""
# Multi-turn intent-gathering prompt for the women's styling assistant.
WOMEN_BASIC_PROMPT = """You are a professional, friendly, and insightful AI women's styling assistant.

Your primary mission is to engage in a multi-turn conversation with the user to fully understand their dressing intent. You must adopt a professional yet approachable tone.

CONVERSATION GOALS:
1. **Occasion:** Determine the specific event (e.g., romantic dinner, summer wedding, business meeting).
2. **Style:** Pinpoint the desired aesthetic (e.g., classic elegance, edgy, minimalist, bohemian).
3. **Vibe/Details:** Gather any mood or specific constraints (e.g., needs to be comfortable, requires light colors, no bare shoulders).
4. **Item Preference:** Ask the user if they have any specific preferences for an item type or silhouette (e.g., preference for a dress, skirt, tailored pants, or a particular neckline/length).

GUIDANCE FOR RESPONSE GENERATION:
- After the user's initial request (e.g., "I want a chic outfit for dinner."), immediately reply with a friendly, targeted follow-up question to elicit the most crucial missing information (usually a combination of **Occasion** and **Style**).
- Be concise. Ask only 1 to 2 essential questions per turn.
- You must gather sufficient, clear intent before proceeding to actual clothing recommendations.

OUTPUT FORMAT INSTRUCTION:
- **DO NOT** use any Markdown formatting whatsoever (e.g., do not use asterisks (*), bold text (**), lists, or code blocks).
- **ONLY** output the plain text response spoken by the AI Assistant.

Example Follow-up (mimicking a conversational flow):
User: I want a chic outfit for dinner.
Your Response: Hey there! A chic dinner outfit, I love that! To give you the perfect recommendations, tell me: is this a romantic date, business dinner, or celebration with friends? And what's your go-to style vibe: classic elegance or something with more edge?"""
# Multi-turn intent-gathering prompt for the men's styling assistant.
MEN_BASIC_PROMPT = """You are a professional, friendly, and insightful AI men's styling assistant.

Your primary mission is to engage in a multi-turn conversation with the user to fully understand their dressing intent. You must adopt a professional yet approachable tone.

CONVERSATION GOALS:
1. **Occasion:** Determine the specific event (e.g., romantic dinner, summer wedding, business meeting).
2. **Style:** Pinpoint the desired aesthetic (e.g., classic elegance, edgy, minimalist, bohemian).
3. **Vibe/Details:** Gather any mood or specific constraints (e.g., needs to be comfortable, requires light colors, no bare shoulders).
4. **Item Preference:** Ask the user if they have any specific preferences for an item type or silhouette (e.g., preference for a dress, skirt, tailored pants, or a particular neckline/length).

GUIDANCE FOR RESPONSE GENERATION:
- After the user's initial request (e.g., "I want a chic outfit for dinner."), immediately reply with a friendly, targeted follow-up question to elicit the most crucial missing information (usually a combination of **Occasion** and **Style**).
- Be concise. Ask only 1 to 2 essential questions per turn.
- You must gather sufficient, clear intent before proceeding to actual clothing recommendations.

OUTPUT FORMAT INSTRUCTION:
- **DO NOT** use any Markdown formatting whatsoever (e.g., do not use asterisks (*), bold text (**), lists, or code blocks).
- **ONLY** output the plain text response spoken by the AI Assistant.

Example Follow-up (mimicking a conversational flow):
User: I want a chic outfit for dinner.
Your Response: Hey there! A chic dinner outfit, I love that! To give you the perfect recommendations, tell me: is this a romantic date, business dinner, or celebration with friends? And what's your go-to style vibe: classic elegance or something with more edge?"""

# Prompt used to distill a chat history into a structured JSON request summary.
SUMMARY_PROMPT = """Analyze the following chat history. Your task is to extract all user intentions, scenarios, style preferences, and constraints expressed during the conversation, and distill them into a concise, structured JSON object.

**YOUR OUTPUT MUST BE A JSON OBJECT ONLY, WITH NO SURROUNDING TEXT, MARKDOWN, OR EXPLANATION.**

JSON FIELD REQUIREMENTS:
- **occasion (string):** The specific event and purpose (e.g., "Romantic date dinner", "Summer outdoor wedding", "Casual Friday at office").
- **style (string):** The overall aesthetic description (e.g., "Classic elegance", "Modern minimalist", "Bohemian vibe", "Edgy and contemporary").
- **color_preference (string or list):** User's preferred or excluded colors/tones (e.g., "Light colors only", "Avoid deep shades", "['Cream', 'Pale Blue']", "No preference").
- **clothing_type (string):** User's preference for specific garment types, material, or silhouette (e.g., "Lightweight maxi dress", "Skirt with silk blouse", "Tailored wide-leg pants", "Floral print").
- **vibe_or_details (string):** Any other details, mood requirements, or specific constraints (e.g., "Needs to be comfortable and breathable", "Accent on accessories", "Must cover shoulders")."""
|
||||
163
app/server/ChatbotAgent/core/utils_litserve.py.py
Normal file
163
app/server/ChatbotAgent/core/utils_litserve.py.py
Normal file
@@ -0,0 +1,163 @@
|
||||
import logging
|
||||
from typing import List, Dict
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
from app.server.utils.minio_client import oss_get_image, minio_client
|
||||
from app.server.utils.minio_config import MINIO_LC_DATA_PATH
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
# Nine cells of roughly 341x341 px tiling a 1024x1024 canvas (ALL_9_CELLS).
# Layout order: top-to-bottom, left-to-right (1 -> 9).
# The right column is 342 px wide and the bottom row 342 px tall so that the
# cell sizes sum exactly to 1024 (341 + 341 + 342).
ALL_9_CELLS = [
    # Top Row (Y=0, H=341)
    (0, 0, 341, 341),      # 1. Top-Left (341x341)
    (341, 0, 341, 341),    # 2. Top-Middle (341x341)
    (682, 0, 342, 341),    # 3. Top-Right (342x341)
    # Middle Row (Y=341, H=341)
    (0, 341, 341, 341),    # 4. Mid-Left (341x341)
    (341, 341, 341, 341),  # 5. Center (341x341)
    (682, 341, 342, 341),  # 6. Mid-Right (342x341)
    # Bottom Row (Y=682, H=342)
    (0, 682, 341, 342),    # 7. Bottom-Left (341x342)
    (341, 682, 341, 342),  # 8. Bottom-Middle (341x342)
    (682, 682, 342, 342)   # 9. Bottom-Right (342x342)
]
|
||||
|
||||
|
||||
def merge_images_to_square(outfit_items: List[Dict[str, str]], max_len=9, add_text=True):
    """
    Load up to ``max_len`` outfit-item images from object storage, resize each
    while maintaining aspect ratio, and merge them onto a 1024x1024 white
    RGB canvas suitable for saving as JPG.

    The layout depends on how many images load successfully:
        1:   the single image is centered on the full 1024x1024 canvas.
        2:   side-by-side, each fitted into a 512x1024 half.
        3:   top-left, top-right and bottom-left 512x512 quadrants.
        4:   all four 512x512 quadrants.
        5-9: the first N cells of the 3x3 grid defined by ALL_9_CELLS.

    Args:
        outfit_items: Item metadata dicts; each must provide the keys
            'image_path', 'item_id' and 'category'.
        max_len: Maximum number of successfully loaded images allowed
            (default 9 — the 3x3 grid is the largest supported layout).
        add_text: When True, draw an "ID: ..., Category: ..." label under
            each pasted image.

    Returns:
        The composed PIL.Image canvas (RGB, 1024x1024).

    Raises:
        ValueError: If no image could be loaded, or if more than ``max_len``
            images loaded successfully.
    """
    # Final canvas is always a 1024x1024 square; 'RGB' so a later JPG save works.
    CANVAS_SIZE = 1024
    canvas = Image.new('RGB', (CANVAS_SIZE, CANVAS_SIZE), 'white')
    draw = ImageDraw.Draw(canvas)
    font = ImageFont.load_default()

    # Target areas (x_start, y_start, width, height) keyed by image count.
    quadrants = {
        1: [(0, 0, CANVAS_SIZE, CANVAS_SIZE)],  # single full-size placement
        2: [(0, 0, 512, CANVAS_SIZE), (512, 0, 512, CANVAS_SIZE)],  # left, right
        3: [(0, 0, 512, 512), (512, 0, 512, 512), (0, 512, 512, 512)],  # TL, TR, BL
        4: [(0, 0, 512, 512), (512, 0, 512, 512), (0, 512, 512, 512), (512, 512, 512, 512)],  # all four
        5: ALL_9_CELLS[:5],  # first 5 cells of the 3x3 grid
        6: ALL_9_CELLS[:6],  # first 6 cells
        7: ALL_9_CELLS[:7],  # first 7 cells
        8: ALL_9_CELLS[:8],  # first 8 cells
        9: ALL_9_CELLS[:9],  # full 3x3 grid
    }

    # Load images from MinIO; skip (but log) any path that fails to load.
    # BUG FIX: keep each loaded image paired with ITS item so that a failed
    # load cannot shift the ID/category labels onto the wrong images (the old
    # code zipped the filtered image list against the unfiltered item list).
    loaded = []  # list of (PIL.Image, item-dict) pairs
    for item in outfit_items:
        path = item['image_path']
        try:
            # Convert to 'RGB' to flatten any alpha channel (RGBA) so pasting
            # onto the RGB canvas and a later JPG save cannot fail.
            img = oss_get_image(oss_client=minio_client, path=f"{MINIO_LC_DATA_PATH}/{path}", data_type="PIL").convert('RGB')
            loaded.append((img, item))
        except Exception as e:
            logger.error(f"Error loading image {path}. Skipping: {e}")

    num_images = len(loaded)

    if num_images == 0:
        raise ValueError("No valid images were loaded.")

    if num_images > max_len:
        # BUG FIX: the message used to embed a garbled hard-coded number
        # ("21,566"); report the actual configured limit instead.
        raise ValueError(f"Valid item number {num_images} exceeds max limit {max_len}")

    # Pick the list of target cells matching the number of valid images.
    target_areas = quadrants.get(num_images, [])

    # Resize and paste each image into its cell, then (optionally) label it.
    for i, (img, item) in enumerate(loaded):
        item_id = item['item_id']
        category = item['category']
        if i >= len(target_areas):
            # Defensive guard: unreachable while num_images <= max_len.
            break

        x_start, y_start, target_w, target_h = target_areas[i]

        # Scale to fit entirely inside the cell while keeping aspect ratio:
        # use the smaller of the two axis ratios.
        original_w, original_h = img.size
        ratio_w = target_w / original_w
        ratio_h = target_h / original_h
        resize_ratio = min(ratio_w, ratio_h)
        new_w = int(original_w * resize_ratio)
        new_h = int(original_h * resize_ratio)

        # LANCZOS is a high-quality resampling filter; in Pillow > 9.0.0 it
        # lives under Image.Resampling.
        resized_img = img.resize((new_w, new_h), Image.Resampling.LANCZOS)

        # Center horizontally within the cell.
        paste_x = (target_w - new_w) // 2 + x_start
        # Center vertically, but shift up to reserve a strip below the image
        # for the text label; never paste above the cell's top edge.
        TEXT_RESERVE_HEIGHT = 30
        paste_y = (target_h - new_h - TEXT_RESERVE_HEIGHT) // 2 + y_start
        paste_y = max(paste_y, y_start)

        canvas.paste(resized_img, (paste_x, paste_y))

        full_text = f"ID: {item_id}, Category: {category}"
        try:
            # Pillow >= 8: textbbox gives the rendered text's bounding box.
            bbox = draw.textbbox((0, 0), full_text, font=font)
            text_w = bbox[2] - bbox[0]
            text_h = bbox[3] - bbox[1]
        except AttributeError:
            # Fallback for older Pillow versions that only have textsize().
            text_w, text_h = draw.textsize(full_text, font=font)

        # Center the label horizontally within the cell...
        text_x_center = x_start + target_w // 2
        text_x_start = text_x_center - text_w // 2
        # ...and place it at the bottom of the cell with a 5 px margin.
        text_y_start = y_start + target_h - text_h - 5

        if add_text:
            draw.text((text_x_start, text_y_start),
                      full_text,
                      fill='black',
                      font=font)

    # Caller is responsible for saving, e.g. canvas.save(path, 'JPEG', quality=90).
    return canvas
|
||||
60
app/server/ChatbotAgent/core/vector_database.py
Normal file
60
app/server/ChatbotAgent/core/vector_database.py
Normal file
@@ -0,0 +1,60 @@
|
||||
import torch
|
||||
import chromadb
|
||||
from PIL import Image
|
||||
from typing import List, Dict, Any
|
||||
from transformers import CLIPProcessor, CLIPModel
|
||||
|
||||
|
||||
class VectorDatabase():
    """CLIP-backed vector store over a persistent ChromaDB collection.

    Wraps a ChromaDB persistent client plus a CLIP model/processor pair so
    that images and texts can be embedded into the same space and queried
    against the stored item vectors.
    """

    def __init__(self, vector_db_dir: str, collection_name: str, embedding_model_name: str):
        """Open (or create) the collection and load the CLIP model.

        Args:
            vector_db_dir: Directory for ChromaDB's persistent storage.
            collection_name: Collection to open; created if absent.
            embedding_model_name: HuggingFace model id/path for CLIP.
        """
        self.client = chromadb.PersistentClient(path=vector_db_dir)

        self.collection = self.client.get_or_create_collection(name=collection_name)

        # Prefer GPU when available; the model is moved to this device.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        self.model = CLIPModel.from_pretrained(embedding_model_name).to(self.device)
        self.processor = CLIPProcessor.from_pretrained(embedding_model_name)

    def get_clip_embedding(self, data: str | Image.Image, is_image: bool) -> List[float]:
        """Generate a CLIP embedding for an image or a text, L2-normalized.

        Args:
            data: A PIL image (when is_image=True) or a text string.
            is_image: Selects the image vs. text encoder branch.

        Returns:
            The embedding as a flat list of floats (unit L2 norm).
        """

        if is_image:
            inputs = self.processor(images=data, return_tensors="pt").to(self.device)
            with torch.no_grad():
                features = self.model.get_image_features(**inputs)
        else:
            # Force truncation so texts longer than CLIP's max sequence
            # length do not raise.
            inputs = self.processor(
                text=[data],
                return_tensors="pt",
                padding=True,
                truncation=True
            ).to(self.device)
            with torch.no_grad():
                features = self.model.get_text_features(**inputs)

        # L2-normalize so cosine similarity reduces to a dot product.
        features = features / features.norm(p=2, dim=-1, keepdim=True)

        return features.cpu().numpy().flatten().tolist()

    def query_local_db(self, embedding: List[float], category: str, n_results: int = 3) -> List[Dict[str, Any]]:
        """
        Query the local collection for items similar to the given embedding.

        Filters results by metadata: only entries whose 'category' matches
        and whose 'modality' is "image" are considered.

        Args:
            embedding: Query vector (same dimensionality as stored vectors).
            category: Category value the matched items must have.
            n_results: Maximum number of matches to return.

        Returns:
            The raw ChromaDB query result dict containing 'documents',
            'metadatas' and 'distances'.
        """
        results = self.collection.query(
            query_embeddings=[embedding],
            n_results=n_results,
            where={
                "$and": [
                    {"category": category},
                    {"modality": "image"},
                ]
            },
            include=['documents', 'metadatas', 'distances']
        )
        return results
|
||||
Reference in New Issue
Block a user