first commit
This commit is contained in:
98
src/server/agent/graph.py
Normal file
98
src/server/agent/graph.py
Normal file
@@ -0,0 +1,98 @@
|
||||
import os
|
||||
from typing import Literal
|
||||
|
||||
from google.oauth2 import service_account
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||
from langgraph.graph import StateGraph, END, START
|
||||
from pydantic import BaseModel
|
||||
from pymongo import MongoClient
|
||||
|
||||
from src.core.config import settings, MONGO_URI
|
||||
from src.server.agent.state import AgentState
|
||||
from src.server.agent.agents import designer_node, researcher_node, visualizer_node
|
||||
from langgraph.checkpoint.mongodb import MongoDBSaver
|
||||
|
||||
|
||||
# --- Supervisor (路由逻辑) ---
|
||||
# 定义路由的输出结构,强制 LLM 选择一个
|
||||
class RouteResponse(BaseModel):
|
||||
next: Literal["Designer", "Researcher", "Visualizer", "FINISH"]
|
||||
|
||||
|
||||
creds = service_account.Credentials.from_service_account_file(
|
||||
settings.GOOGLE_GENAI_USE_VERTEXAI,
|
||||
scopes=["https://www.googleapis.com/auth/cloud-platform"],
|
||||
)
|
||||
|
||||
llm_supervisor = ChatGoogleGenerativeAI(
|
||||
model="gemini-2.0-flash", temperature=0, credentials=creds,
|
||||
project="aida-461108", location='us-central1', vertexai=True, api_key=settings.GOOGLE_API_KEY
|
||||
)
|
||||
|
||||
|
||||
def supervisor_node(state: AgentState):
|
||||
messages = state["messages"]
|
||||
if not messages:
|
||||
return {"next": "FINISH"}
|
||||
|
||||
last_message = messages[-1]
|
||||
|
||||
# --- 改进的拦截逻辑 ---
|
||||
# 如果最后一条消息是 AI 产生的(且没有调用工具),说明专家已经回复完了用户
|
||||
# 此时我们才拦截并结束,否则会导致专家没机会说话
|
||||
if isinstance(last_message, AIMessage) and not last_message.tool_calls:
|
||||
return {"next": "FINISH"}
|
||||
|
||||
# 如果最后一条是 HumanMessage,说明用户刚说完,Supervisor 必须派发任务
|
||||
system_prompt = """
|
||||
你是家具设计团队的主管(Supervisor)。
|
||||
请根据用户的意图,选择最合适的专家:
|
||||
- Designer: 设计建议、参数细化、闲聊、问候。
|
||||
- Visualizer: 绘图、看草图。
|
||||
- Researcher: 市场报告、趋势。
|
||||
|
||||
只需输出专家名称。
|
||||
"""
|
||||
|
||||
chain = llm_supervisor.with_structured_output(RouteResponse)
|
||||
decision = chain.invoke([{"role": "system", "content": system_prompt}] + messages)
|
||||
|
||||
return {"next": decision.next}
|
||||
|
||||
|
||||
# --- 构建 Graph ---
|
||||
workflow = StateGraph(AgentState)
|
||||
|
||||
workflow.add_node("Supervisor", supervisor_node)
|
||||
workflow.add_node("Designer", designer_node)
|
||||
workflow.add_node("Researcher", researcher_node)
|
||||
workflow.add_node("Visualizer", visualizer_node)
|
||||
|
||||
workflow.add_edge(START, "Supervisor")
|
||||
|
||||
# 这里的逻辑是关键:Supervisor 决定去向
|
||||
workflow.add_conditional_edges(
|
||||
"Supervisor",
|
||||
lambda state: state["next"],
|
||||
{
|
||||
"Designer": "Designer",
|
||||
"Researcher": "Researcher",
|
||||
"Visualizer": "Visualizer",
|
||||
"FINISH": END
|
||||
}
|
||||
)
|
||||
|
||||
# 重点修改:专家执行完后,必须回到 Supervisor 进行状态检查
|
||||
# 如果 Supervisor 发现专家刚说完话,它会触发上面的逻辑返回 FINISH
|
||||
workflow.add_edge("Designer", "Supervisor")
|
||||
workflow.add_edge("Researcher", "Supervisor")
|
||||
workflow.add_edge("Visualizer", "Supervisor")
|
||||
|
||||
client = MongoClient(MONGO_URI)
|
||||
checkpointer = MongoDBSaver(
|
||||
client=client["furniture_agent_db"],
|
||||
db_name="langgraph",
|
||||
collection_name="checkpoints"
|
||||
)
|
||||
app = workflow.compile(checkpointer=checkpointer)
|
||||
Reference in New Issue
Block a user