first commit
This commit is contained in:
30
main.py
Normal file
30
main.py
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
import uvicorn
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from src.routers import chat
|
||||||
|
|
||||||
|
# FastAPI entry point for the furniture-design agent service.
app_server = FastAPI(
    title="Gemini Furniture Designer API",
    description="基于 LangGraph + Gemini 2.0 Flash 的家具设计 Agent 接口",
    version="1.0.0",
)

# Wide-open CORS so browser front-ends can call the API during development.
_CORS_ALLOW_ALL = {
    "allow_origins": ["*"],
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app_server.add_middleware(CORSMiddleware, **_CORS_ALLOW_ALL)

# Mount the chat endpoints.
app_server.include_router(chat.router)


@app_server.get("/")
async def root():
    """Liveness probe: confirms the service is up."""
    return {"message": "Furniture Design Agent API is running."}


if __name__ == "__main__":
    # Autoreload for local development; the import string must match this
    # module's name and the FastAPI instance above.
    uvicorn.run("main:app_server", host="0.0.0.0", port=7777, reload=True)
|
||||||
18
pyproject.toml
Normal file
18
pyproject.toml
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
[project]
|
||||||
|
name = "FiDA"
|
||||||
|
version = "0.1.0"
|
||||||
|
description = "FastAPI service for a LangGraph + Gemini furniture-design agent"
|
||||||
|
requires-python = ">=3.12"
|
||||||
|
dependencies = [
|
||||||
|
"fastapi>=0.128.0",
|
||||||
|
"langchain-core>=1.2.8",
|
||||||
|
"langchain-google-genai>=4.2.0",
|
||||||
|
"langgraph>=1.0.7",
|
||||||
|
"langgraph-checkpoint-mongodb>=0.3.1",
|
||||||
|
"motor>=3.7.1",
|
||||||
|
"pydantic>=2.12.5",
|
||||||
|
"pydantic-settings>=2.12.0",
|
||||||
|
"pymongo[srv]>=4.15.5",
|
||||||
|
"python-dotenv>=1.2.1",
|
||||||
|
"uvicorn>=0.40.0",
|
||||||
|
]
|
||||||
2
src/__init__.py
Normal file
2
src/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
"""Source package for furniture_design_agent."""

# Package version; matches the version declared in pyproject.toml.
__version__ = "0.1.0"
|
||||||
0
src/core/__init__.py
Normal file
0
src/core/__init__.py
Normal file
26
src/core/config.py
Normal file
26
src/core/config.py
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||||
|
from pydantic import Field
|
||||||
|
|
||||||
|
|
||||||
|
class Settings(BaseSettings):
    """
    Application settings. pydantic-settings loads every field automatically
    from environment variables and the project's .env file.
    """
    model_config = SettingsConfigDict(
        env_file='.env',
        env_file_encoding='utf-8',
        extra='ignore'  # silently skip unrelated keys in the environment
    )
    # --- Google API settings ---
    # NOTE(review): despite the flag-like name, agents.py and graph.py pass
    # this value to Credentials.from_service_account_file(), i.e. as a
    # service-account JSON *file path* — confirm the intended semantics.
    GOOGLE_GENAI_USE_VERTEXAI: str = Field(default="", description="")
    GOOGLE_API_KEY: str = Field(default="", description="")

    # --- MongoDB settings (assembled into MONGO_URI at module level) ---
    MONGODB_USERNAME: str = Field(default="", description="")
    MONGODB_PASSWORD: str = Field(default="", description="")
    MONGODB_HOST: str = Field(default="localhost", description="")
    MONGODB_PORT: int = Field(default=27017, description="")
|
||||||
|
|
||||||
|
|
||||||
|
from urllib.parse import quote_plus

# Singleton settings instance, populated from the environment at import time.
settings = Settings()

# Percent-encode the credentials: a password containing ':', '@' or '/' would
# otherwise corrupt the connection string (MongoDB URIs require RFC 3986
# escaping of the userinfo part). Plain alphanumeric credentials — and the
# empty defaults — produce exactly the same URI as before.
MONGO_URI = (
    f"mongodb://{quote_plus(settings.MONGODB_USERNAME)}:"
    f"{quote_plus(settings.MONGODB_PASSWORD)}@"
    f"{settings.MONGODB_HOST}:{settings.MONGODB_PORT}"
)
|
||||||
0
src/routers/__init__.py
Normal file
0
src/routers/__init__.py
Normal file
163
src/routers/chat.py
Normal file
163
src/routers/chat.py
Normal file
@@ -0,0 +1,163 @@
|
|||||||
|
import uuid
|
||||||
|
import json
|
||||||
|
from fastapi import APIRouter
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
|
from src.schemas.chat import ChatRequest, HistoryResponse, HistoryItem
|
||||||
|
from src.server.agent.graph import app # 导入已经 compile 好的 graph
|
||||||
|
from langchain_core.messages import HumanMessage
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/chat", tags=["Furniture Design Chat"])
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/stream")
async def chat_stream(request: ChatRequest):
    """
    ### Furniture-design streaming chat endpoint (SSE)

    Talk to the AI furniture-design expert team in real time. Supports
    **persistent memory** and **history branching**.

    #### 1. Core features
    * **Real-time feedback**: Server-Sent Events (SSE) push the output of the
      Supervisor, Designer, Visualizer, etc. as each node finishes.
    * **Context memory**: pass `thread_id` to resume an earlier conversation.
    * **Version branching**: pass `checkpoint_id` to pinpoint a past turn and
      start a new design branch from that point.

    #### 2. Request parameters
    * `message`: the user's design intent (e.g. "a minimalist oak desk").
    * `thread_id`: (optional) id of an existing project. If omitted, one is
      generated and returned in the first event.
    * `checkpoint_id`: (optional) history snapshot id to branch from.

    #### 3. Response stream format
    The response is a stream of `data: `-prefixed JSON strings:
    - **First event**: `{"thread_id": "...", "is_branch": ...}` — the thread
      the server is working on (a new id when branching).
    - **Node events**: `{"node": "Designer", "content": "...",
      "checkpoint_id": "..."}` — one per node update that carries messages.
    The stream simply ends when the graph run completes; no explicit
    end-marker event is emitted.

    #### 4. Request examples
    ```
    { "message": "Design a Nordic-style lounge chair." }

    { "message": "Generate the sketch from the info above.",
      "thread_id": "187e58af" }

    { "message": "Not a lounge chair — a table.",
      "thread_id": "187e58af",
      "checkpoint_id": "1f101aa2-8f24-6e2a-8001-2952c3a7447a" }
    ```
    """
    source_thread_id = request.thread_id
    checkpoint_id = request.checkpoint_id

    # 1. Decide which thread_id to work on.
    # For a rollback we mint a fresh id (the front-end could also supply one).
    # NOTE(review): is_branching is a str/None (truthy-and result), not a
    # bool — so the 'is_branch' field below serializes as the source thread
    # id or null rather than true/false. Confirm the front-end expects that.
    is_branching = source_thread_id and checkpoint_id
    target_thread_id = str(uuid.uuid4())[:8] if is_branching else (source_thread_id or str(uuid.uuid4())[:8])

    # 2. For a branch request, first "move house": copy the old state over.
    if is_branching:
        # Load the historical state addressed by (thread_id, checkpoint_id).
        source_config = {"configurable": {"thread_id": source_thread_id, "checkpoint_id": checkpoint_id}}
        older_state = await app.aget_state(source_config)

        # Seed the new thread with the old messages as its initial values.
        # Note: we push the old messages into the new thread by hand.
        new_config = {"configurable": {"thread_id": target_thread_id}}
        await app.aupdate_state(new_config, older_state.values)

        # From here on, the config points at the new thread.
        current_config = new_config
    else:
        current_config = {"configurable": {"thread_id": target_thread_id}}

    async def event_generator():
        # Tell the front-end which thread we are on (after a branch it must
        # update its locally stored id).
        yield f"data: {json.dumps({'thread_id': target_thread_id, 'is_branch': is_branching}, ensure_ascii=False)}\n\n"

        async for event in app.astream(
            {"messages": [HumanMessage(content=request.message)]},
            current_config,
            stream_mode="updates"
        ):
            # Forward each node update that carries messages.
            for node_name, output in event.items():
                if "messages" in output:
                    msg = output["messages"][-1]
                    # Re-read state to attach the latest checkpoint id so the
                    # client can branch from this exact point later.
                    state = await app.aget_state(current_config)
                    yield f"data: {json.dumps({'node': node_name, 'content': msg.content, 'checkpoint_id': state.config['configurable']['checkpoint_id']}, ensure_ascii=False)}\n\n"

    return StreamingResponse(event_generator(), media_type="text/event-stream")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/history/{thread_id}", response_model=HistoryResponse)
async def get_chat_history(thread_id: str):
    """
    ### Fetch a project's design history

    Returns every stored state snapshot for the given `thread_id`. This is
    the data source for **version rollback** and **design comparison**.

    #### 1. Behaviour
    * **Snapshot list**: every checkpoint recorded since the project started.
    * **Version handle**: each entry carries a unique `checkpoint_id`.
    * **Rollback**: the client can pass any entry's `checkpoint_id` back to
      `/chat/stream` to open a new design branch from that point.

    #### 2. Path parameters
    * `thread_id`: unique project identifier (generated or supplied on the
      first `/chat/stream` call).

    #### 3. Response fields
    * `thread_id`: the queried project id.
    * `history`: array of snapshots, each containing:
        - `checkpoint_id`: required token for rollback.
        - `last_message`: preview of the last message at that stage.
        - `node`: node that produced the snapshot (e.g. Designer, Visualizer).
        - `timestamp`: logical step number.

    #### 4. Example response
    ```json
    {
      "thread_id": "proj_001",
      "history": [
        {"checkpoint_id": "d82f3a12", "last_message": "我想设计一款北欧风书架",
         "node": "Supervisor", "timestamp": 1},
        {"checkpoint_id": "f4k92m1a", "last_message": "建议使用浅色橡木材质,增加简约感...",
         "node": "Designer", "timestamp": 2}
      ]
    }
    ```
    """
    config = {"configurable": {"thread_id": thread_id}}
    history_data = []
    async for state in app.aget_state_history(config):
        msg_content = "Initial"
        if state.values and "messages" in state.values:
            msgs = state.values["messages"]
            if msgs:
                last_msg = msgs[-1]
                content = getattr(last_msg, "content", str(last_msg))
                # Multimodal messages can carry a non-string content payload;
                # coerce so the slicing/concatenation below cannot raise.
                if not isinstance(content, str):
                    content = str(content)
                # Truncate exactly once. The previous code re-sliced the
                # result to 50 chars afterwards, which chopped the appended
                # "..." back off and made the ellipsis unreachable.
                msg_content = content[:50] + ("..." if len(content) > 50 else "")

        history_data.append(HistoryItem(
            checkpoint_id=state.config["configurable"]["checkpoint_id"],
            last_message=msg_content,
            # metadata may be None for synthetic states; guard the lookups.
            node=(state.metadata or {}).get("source"),
            timestamp=(state.metadata or {}).get("step")
        ))

    return HistoryResponse(thread_id=thread_id, history=history_data)
|
||||||
0
src/schemas/__init__.py
Normal file
0
src/schemas/__init__.py
Normal file
26
src/schemas/chat.py
Normal file
26
src/schemas/chat.py
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from typing import Optional, List, Dict, Any
|
||||||
|
|
||||||
|
|
||||||
|
class ChatRequest(BaseModel):
    """Request body for the /chat/stream endpoint."""
    # Field descriptions are runtime metadata (surfaced in the OpenAPI
    # schema) and are deliberately left in their original language.
    message: str = Field(..., description="用户的输入指令")
    thread_id: Optional[str] = Field(None, description="会话线程ID,不传则开启新会话")
    checkpoint_id: Optional[str] = Field(None, description="回溯点的ID,用于从历史点开启新对话")
|
||||||
|
|
||||||
|
|
||||||
|
class HistoryItem(BaseModel):
    """One checkpoint snapshot within a thread's history."""
    # Token the client passes back to /chat/stream to branch from this point.
    checkpoint_id: str
    # Truncated preview of the last message at this checkpoint.
    last_message: str
    # Graph node that produced the snapshot; None for synthetic states.
    node: Optional[str]
    # Logical step number (Any because it comes straight from checkpointer
    # metadata).
    timestamp: Any
|
||||||
|
|
||||||
|
|
||||||
|
class HistoryResponse(BaseModel):
    """Full checkpoint history for one conversation thread."""
    thread_id: str
    history: List[HistoryItem]
|
||||||
|
|
||||||
|
|
||||||
|
class StreamChunk(BaseModel):
    """Shape of a single SSE payload emitted by /chat/stream (documentation
    model; the endpoint builds the dict by hand)."""
    node: str
    content: str
    checkpoint_id: str
|
||||||
0
src/server/__init__.py
Normal file
0
src/server/__init__.py
Normal file
0
src/server/agent/__init__.py
Normal file
0
src/server/agent/__init__.py
Normal file
86
src/server/agent/agents.py
Normal file
86
src/server/agent/agents.py
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
from google.oauth2 import service_account
|
||||||
|
from langchain_core.messages import HumanMessage, SystemMessage
|
||||||
|
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||||
|
from src.server.agent.state import AgentState
|
||||||
|
from src.server.agent.tools import generate_2025_report_tool, generate_furniture_sketch
|
||||||
|
from src.server.agent.config_loader import get_agent_prompt
|
||||||
|
from src.core.config import settings
|
||||||
|
|
||||||
|
# Service-account credentials for Vertex AI.
# NOTE(review): GOOGLE_GENAI_USE_VERTEXAI is consumed here as a
# service-account JSON *file path*, despite its flag-like name — confirm.
creds = service_account.Credentials.from_service_account_file(
    settings.GOOGLE_GENAI_USE_VERTEXAI,
    scopes=["https://www.googleapis.com/auth/cloud-platform"],
)
# Shared Gemini model for the expert nodes (Flash keeps latency low).
# NOTE(review): both `credentials` and `api_key` are passed together —
# verify which one the client actually uses in vertexai mode.
llm = ChatGoogleGenerativeAI(
    model="gemini-2.0-flash", temperature=0.5, credentials=creds,
    project="aida-461108", location='us-central1', vertexai=True, api_key=settings.GOOGLE_API_KEY
)
|
||||||
|
|
||||||
|
|
||||||
|
# --- 1. Designer Agent (设计顾问) ---
|
||||||
|
def designer_node(state: AgentState):
    """Refine the user's design requirements and supply professional
    parameters (dimensions, materials, ergonomics)."""
    messages = state["messages"]
    # Prompt comes from config.yaml when configured; the inline Chinese text
    # is the runtime fallback sent to the model, so it is kept verbatim.
    system_text = get_agent_prompt("designer") or """
    你是一位资深的家具设计师。你的职责是:
    1. 从用户的模糊描述中提取或补充具体的设计参数(尺寸、材质、人体工学数据)。
    2. 如果用户想画图,不要直接画,而是先描述清楚细节,然后让 Visualizer 去画。
    请以专业的口吻回复。
    """
    system_prompt = SystemMessage(content=system_text)
    response = llm.invoke([system_prompt] + messages)
    # Returned messages are appended to state (operator.add reducer).
    return {"messages": [response]}
|
||||||
|
|
||||||
|
|
||||||
|
# --- 2. Researcher Agent (情报专家) ---
|
||||||
|
def researcher_node(state: AgentState):
    """Handle research requests by invoking the report-generation tool."""
    # Local import keeps this fix self-contained; langchain_core is already
    # a project dependency (see the HumanMessage/SystemMessage imports above).
    from langchain_core.messages import ToolMessage

    # Expose the reporting tool to the LLM.
    tools = [generate_2025_report_tool]
    llm_with_tools = llm.bind_tools(tools)

    messages = state["messages"]
    system_text = get_agent_prompt("researcher") or "你是情报专家,负责检索与整理参考资料并生成报告。"
    system_prompt = SystemMessage(content=system_text)
    response = llm_with_tools.invoke([system_prompt] + messages)

    # If the model decided to call the tool, execute it inline for simplicity
    # (a LangGraph ToolNode would also work here).
    if response.tool_calls:
        tool_call = response.tool_calls[0]
        if tool_call["name"] == "generate_2025_report_tool":
            result = generate_2025_report_tool.invoke(tool_call["args"])
            # Fix: answer the pending tool call with a ToolMessage bound to
            # its call id. The previous code wrapped the result in a
            # HumanMessage, leaving the function call unanswered — Gemini
            # requires a function *response* to follow a function call, and
            # an unanswered tool call can be rejected on the next turn.
            return {"messages": [
                response,
                ToolMessage(content=str(result), tool_call_id=tool_call["id"]),
            ]}

    return {"messages": [response]}
|
||||||
|
|
||||||
|
|
||||||
|
# --- 3. Visualizer Agent (视觉专家) ---
|
||||||
|
def visualizer_node(state: AgentState):
    """Turn natural-language requirements into a drawing prompt and call the
    sketch-generation tool."""
    tools = [generate_furniture_sketch]
    llm_with_tools = llm.bind_tools(tools)

    messages = state["messages"]
    # config.yaml prompt wins when present; the inline Chinese text is the
    # runtime fallback sent to the model and is kept verbatim.
    system_text = get_agent_prompt("visualizer") or """
    你是视觉专家。你的目标是生成高质量的家具草图。
    步骤:
    1. 根据上下文,编写一个详细的 Stable Diffusion 风格的英文 Prompt。
    2. 必须调用 generate_furniture_sketch 工具来生成图片。
    """

    # Nudge the model toward actually calling the tool.
    system_prompt = SystemMessage(content=system_text)
    response = llm_with_tools.invoke([system_prompt] + messages)

    if response.tool_calls:
        tool_call = response.tool_calls[0]
        if tool_call["name"] == "generate_furniture_sketch":
            result = generate_furniture_sketch.invoke(tool_call["args"])
            # Feed the tool result back into the conversation so the user
            # sees the sketch link.
            # NOTE(review): the result is wrapped in a HumanMessage, so the
            # pending tool call on `response` never receives a ToolMessage
            # reply — some Gemini endpoints reject such histories. Also, the
            # Supervisor treats a trailing HumanMessage as "user just spoke"
            # and will dispatch again. Confirm both are intended.
            final_msg = f"已为您生成草图,链接如下:{result}"
            return {"messages": [response, HumanMessage(content=final_msg)]}

    return {"messages": [response]}
|
||||||
32
src/server/agent/config_loader.py
Normal file
32
src/server/agent/config_loader.py
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
"""加载项目根目录下的 config.yaml 并提供 agent prompt 访问接口。"""
|
||||||
|
import os
|
||||||
|
from functools import lru_cache
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
|
||||||
|
def _project_root() -> str:
|
||||||
|
return os.path.abspath(os.path.join(os.path.dirname(__file__), "../..", ".."))
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=1)
def load_config() -> Dict[str, Any]:
    """Parse config.yaml at the project root; empty dict when absent.

    Cached for the process lifetime (maxsize=1 — there is only one config).
    """
    cfg_path = os.path.join(_project_root(), "config.yaml")
    if not os.path.exists(cfg_path):
        return {}
    with open(cfg_path, "r", encoding="utf-8") as fh:
        data = yaml.safe_load(fh)
    # safe_load returns None for an empty file; normalize to a dict.
    return data or {}
|
||||||
|
|
||||||
|
|
||||||
|
def get_agent_prompt(agent_name: str) -> Optional[str]:
    """Return the configured prompt for *agent_name*, or None if missing.

    Accepts either a `prompt_template` or a `prompt` key, preferring the
    former.
    """
    agent_entry = load_config().get("agents", {}).get(agent_name, {})
    return agent_entry.get("prompt_template") or agent_entry.get("prompt")
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_config() -> Dict[str, Any]:
    """Return the `model` section of config.yaml (empty dict if missing)."""
    return load_config().get("model", {})
|
||||||
98
src/server/agent/graph.py
Normal file
98
src/server/agent/graph.py
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
import os
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
from google.oauth2 import service_account
|
||||||
|
from langchain_core.messages import AIMessage
|
||||||
|
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||||
|
from langgraph.graph import StateGraph, END, START
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from pymongo import MongoClient
|
||||||
|
|
||||||
|
from src.core.config import settings, MONGO_URI
|
||||||
|
from src.server.agent.state import AgentState
|
||||||
|
from src.server.agent.agents import designer_node, researcher_node, visualizer_node
|
||||||
|
from langgraph.checkpoint.mongodb import MongoDBSaver
|
||||||
|
|
||||||
|
|
||||||
|
# --- Supervisor (路由逻辑) ---
|
||||||
|
# 定义路由的输出结构,强制 LLM 选择一个
|
||||||
|
class RouteResponse(BaseModel):
    """Structured routing decision: forces the LLM to pick exactly one
    target."""
    # "FINISH" ends the graph run; any other value names the next agent node.
    next: Literal["Designer", "Researcher", "Visualizer", "FINISH"]
|
||||||
|
|
||||||
|
|
||||||
|
# Service-account credentials for Vertex AI.
# NOTE(review): GOOGLE_GENAI_USE_VERTEXAI is used here as a service-account
# JSON *file path* despite its flag-like name — confirm the setting's intent.
creds = service_account.Credentials.from_service_account_file(
    settings.GOOGLE_GENAI_USE_VERTEXAI,
    scopes=["https://www.googleapis.com/auth/cloud-platform"],
)

# Deterministic (temperature=0) model instance reserved for routing decisions.
llm_supervisor = ChatGoogleGenerativeAI(
    model="gemini-2.0-flash", temperature=0, credentials=creds,
    project="aida-461108", location='us-central1', vertexai=True, api_key=settings.GOOGLE_API_KEY
)
|
||||||
|
|
||||||
|
|
||||||
|
def supervisor_node(state: AgentState):
    """Route the conversation: pick the next expert node, or finish."""
    messages = state["messages"]
    if not messages:
        # Nothing to route yet.
        return {"next": "FINISH"}

    last_message = messages[-1]

    # --- Interception logic ---
    # If the last message was produced by the AI (and carries no tool call),
    # an expert has already answered the user, so we intercept and finish
    # here; ending any earlier would cut the expert off before speaking.
    if isinstance(last_message, AIMessage) and not last_message.tool_calls:
        return {"next": "FINISH"}

    # Otherwise the user just spoke (HumanMessage) and the Supervisor must
    # dispatch a task. The prompt below is a runtime string sent to the
    # model; kept verbatim in its original language.
    system_prompt = """
    你是家具设计团队的主管(Supervisor)。
    请根据用户的意图,选择最合适的专家:
    - Designer: 设计建议、参数细化、闲聊、问候。
    - Visualizer: 绘图、看草图。
    - Researcher: 市场报告、趋势。

    只需输出专家名称。
    """

    # Structured output guarantees `decision.next` is one of the four
    # RouteResponse literals.
    chain = llm_supervisor.with_structured_output(RouteResponse)
    decision = chain.invoke([{"role": "system", "content": system_prompt}] + messages)

    return {"next": decision.next}
|
||||||
|
|
||||||
|
|
||||||
|
# --- 构建 Graph ---
|
||||||
|
# --- Build the graph ---
workflow = StateGraph(AgentState)

workflow.add_node("Supervisor", supervisor_node)
workflow.add_node("Designer", designer_node)
workflow.add_node("Researcher", researcher_node)
workflow.add_node("Visualizer", visualizer_node)

# Every run starts at the Supervisor.
workflow.add_edge(START, "Supervisor")

# Key routing: the Supervisor's `next` value selects the outgoing edge.
workflow.add_conditional_edges(
    "Supervisor",
    lambda state: state["next"],
    {
        "Designer": "Designer",
        "Researcher": "Researcher",
        "Visualizer": "Visualizer",
        "FINISH": END
    }
)

# Every expert loops back to the Supervisor for a state check; if the expert
# just produced a plain AI reply, the Supervisor returns FINISH (see
# supervisor_node's interception logic).
workflow.add_edge("Designer", "Supervisor")
workflow.add_edge("Researcher", "Supervisor")
workflow.add_edge("Visualizer", "Supervisor")

# Persist checkpoints to MongoDB so threads survive restarts.
client = MongoClient(MONGO_URI)
# NOTE(review): `client["furniture_agent_db"]` is a Database object, but the
# parameter is named `client`, and `db_name="langgraph"` is also supplied —
# one of the two database names is likely unintended. Confirm against the
# langgraph-checkpoint-mongodb MongoDBSaver signature.
checkpointer = MongoDBSaver(
    client=client["furniture_agent_db"],
    db_name="langgraph",
    collection_name="checkpoints"
)
# Compiled graph; imported by the routers as the runnable `app`.
app = workflow.compile(checkpointer=checkpointer)
|
||||||
49
src/server/agent/run_test.py
Normal file
49
src/server/agent/run_test.py
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
from langchain_core.messages import HumanMessage, AIMessage
|
||||||
|
from src.server.agent.graph import app
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Interactive console loop for exercising the graph, including
    checkpoint listing and branching. User-facing strings are runtime
    output and are kept in their original language."""
    # A thread_id distinguishes users/projects in the checkpointer.
    config = {"configurable": {"thread_id": "project_alpha"}}

    while True:
        user_input = input("\n👤 设计师 (输入 'history' 定位轮次): ")

        # --- Checkpoint rollback flow ---
        # NOTE(review): the original comments called this "async", but the
        # calls below are the synchronous get_state_history/stream/get_state
        # APIs — the comment, not the code, appears to be wrong.
        if user_input.lower() == "history":
            print("\n--- 历史记录 ---")
            for state in app.get_state_history(config):
                # Each state is a checkpoint tuple; show id + a preview.
                cp_id = state.config["configurable"]["checkpoint_id"]
                msg = state.values["messages"][-1].content[:30] if state.values.get("messages") else "Initial"
                print(f"ID: {cp_id} | 内容: {msg}...")

            target_id = input("\n请输入想要回溯的 Checkpoint ID (直接回车取消): ")
            if target_id:
                # Point config at a specific checkpoint to branch from it.
                config = {"configurable": {"thread_id": "project_alpha", "checkpoint_id": target_id}}
                print(f"✅ 已定位到节点 {target_id},后续对话将从此分叉。")
            continue

        # --- Streaming invocation ---
        print("🤖 Agent 思考中...")
        for event in app.stream(
            {"messages": [HumanMessage(content=user_input)]},
            config,
            stream_mode="values"  # "values" yields the full message list each step
        ):
            # Latest message after the current node ran.
            if "messages" in event:
                last_msg = event["messages"][-1]
                if isinstance(last_msg, AIMessage):
                    # Intentionally a no-op: hook for printing incremental
                    # output if finer-grained streaming is wanted.
                    pass

        # After the run, the final state is already persisted to MongoDB;
        # read it back to show the final reply.
        final_state = app.get_state(config)
        print(f"\n✅ 最终回复: {final_state.values['messages'][-1].content}")


if __name__ == "__main__":
    main()
|
||||||
9
src/server/agent/state.py
Normal file
9
src/server/agent/state.py
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
import operator
|
||||||
|
from typing import Annotated, Sequence, TypedDict, Union
|
||||||
|
from langchain_core.messages import BaseMessage
|
||||||
|
|
||||||
|
class AgentState(TypedDict):
    """Shared graph state passed between all agent nodes."""
    # Full conversation history; operator.add makes node return values
    # *append* to the list instead of overwriting it.
    messages: Annotated[Sequence[BaseMessage], operator.add]
    # Next node chosen by the Supervisor ("FINISH" ends the run).
    next: str
|
||||||
25
src/server/agent/tools.py
Normal file
25
src/server/agent/tools.py
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
|
|
||||||
|
# --- 模拟你已经开发好的报告生成功能 ---
|
||||||
|
# --- Stand-in for an already-developed report-generation feature ---
# The Chinese docstring below is the runtime tool description that @tool
# sends to the LLM, so it is deliberately left untranslated.
@tool
def generate_2025_report_tool(topic: str) -> str:
    """
    专门用于收集信息并生成报告。
    当用户询问关于趋势、市场分析、年度报告(如2025家具报告)时调用此工具。
    """
    print(f"\n[系统日志] 正在调用外部模块生成关于 '{topic}' 的报告...")
    # Hook the real implementation up here, e.g. my_existing_module.run(topic).
    return f"【报告生成成功】已生成关于 {topic} 的 PDF 报告。核心洞察:2025年趋势倾向于生物嗜好设计(Biophilic Design)和可持续软木材质。"
|
||||||
|
|
||||||
|
|
||||||
|
# --- 绘图工具 ---
|
||||||
|
# --- Sketch-generation tool ---
# The Chinese docstring below is the runtime tool description that @tool
# sends to the LLM, so it is deliberately left untranslated.
@tool
def generate_furniture_sketch(prompt: str) -> str:
    """
    用于生成家具草图。输入必须是详细的英文绘画提示词(Prompt)。
    """
    print(f"\n[系统日志] 正在调用 Gemini/Imagen 绘图 API,Prompt: {prompt}...")
    # In production this would call the Google Imagen (or similar) API;
    # for now it returns a mock image URL.
    return "https://furniture-design-db.com/generated_sketch_v1.jpg"
|
||||||
Reference in New Issue
Block a user