Files
FiDA_Python/src/server/agent/graph.py

115 lines
4.2 KiB
Python
Raw Normal View History

2026-03-06 16:15:25 +08:00
import random
2026-02-04 17:57:49 +08:00
from typing import Literal
from langchain_core.messages import AIMessage
2026-03-03 17:33:51 +08:00
from langchain_core.runnables import RunnableConfig
from langchain_qwq import ChatQwen
2026-03-04 19:03:12 +08:00
from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer
2026-02-04 17:57:49 +08:00
from langgraph.graph import StateGraph, END, START
from pydantic import BaseModel
from pymongo import MongoClient
2026-03-03 17:33:51 +08:00
from src.core.config import MONGO_URI, settings
2026-02-04 17:57:49 +08:00
from src.server.agent.state import AgentState
from src.server.agent.agents import designer_node, researcher_node, visualizer_node, suggester_node
2026-02-04 17:57:49 +08:00
from langgraph.checkpoint.mongodb import MongoDBSaver
# --- Supervisor (路由逻辑) ---
# 定义路由的输出结构,强制 LLM 选择一个
class RouteResponse(BaseModel):
# 将 FINISH 替换或增加 Suggester
2026-02-06 14:51:25 +08:00
next: Literal["Designer", "Researcher", "Visualizer", "Suggester", "FINISH"]
2026-02-04 17:57:49 +08:00
2026-03-03 17:33:51 +08:00
llm_supervisor = ChatQwen(
model="qwen3.5-flash",
max_tokens=3_000,
timeout=None,
max_retries=2,
api_key=settings.QWEN_API_KEY)
2026-02-04 17:57:49 +08:00
2026-03-03 17:33:51 +08:00
def supervisor_node(state: AgentState, config: RunnableConfig):
2026-03-06 16:15:25 +08:00
configurable = config["configurable"]
use_report = configurable.get("use_report", False)
suggest_frequency = configurable.get("require_suggestion", 0.6) # 0.0~1.0
2026-02-04 17:57:49 +08:00
messages = state["messages"]
if not messages:
return {"next": "Suggester"}
2026-02-04 17:57:49 +08:00
2026-03-06 16:15:25 +08:00
# ── system prompt 保持不变 ──
system_prompt = f"""你是家具设计主管,负责分配任务。
2026-03-06 16:15:25 +08:00
当前设定
- 是否需要市场研究报告{'' if use_report else ''}
严格遵守以下规则
- 如果 **不需要** 市场研究报告use_report = False**绝对不能** 选择 Researcher
- 只有在 **明确需要** 市场报告竞争分析材质趋势价格区间等外部资讯时才选择 Researcher **必须** use_report = True
- 常见分配
- 纯设计风格尺寸材质建议 Designer
- 需要生成图片渲染 Visualizer
- 需要产生建议按钮 Suggester
- 需要市场报告 Researcher但只有 use_report=True 时才允许
- 对话已完整无需继续 FINISH
用户最后说了什么请根据实际需求决定下一步
"""
2026-02-04 17:57:49 +08:00
chain = llm_supervisor.with_structured_output(RouteResponse)
decision = chain.invoke([{"role": "system", "content": system_prompt}] + messages)
2026-03-06 16:15:25 +08:00
next_node = decision.next # 防空默认 FINISH
2026-03-06 16:15:25 +08:00
# 安全阀:禁止非法选择 Researcher
if next_node == "Researcher" and not use_report:
2026-03-06 16:15:25 +08:00
print("警告LLM 违规选择了 Researcher已强制改为 Suggester 或 FINISH")
next_node = "Suggester" if state.get("require_suggestion", False) else "FINISH"
2026-03-06 16:15:25 +08:00
# 核心改动:只有 LLM 决定 FINISH 时,才掷骰子看是否插入 Suggester
if next_node == "FINISH":
# 满足概率条件 → 插入 Suggester
if suggest_frequency > 0 and random.random() < suggest_frequency:
next_node = "Suggester"
return {"next": next_node}
2026-02-04 17:57:49 +08:00
# --- 构建 Graph ---
workflow = StateGraph(AgentState)
workflow.add_node("Supervisor", supervisor_node)
workflow.add_node("Designer", designer_node)
workflow.add_node("Researcher", researcher_node)
workflow.add_node("Visualizer", visualizer_node)
2026-03-06 16:15:25 +08:00
workflow.add_node("Suggester", suggester_node)
2026-02-04 17:57:49 +08:00
workflow.add_edge(START, "Supervisor")
# 修改条件边映射
2026-02-04 17:57:49 +08:00
workflow.add_conditional_edges(
"Supervisor",
lambda state: state["next"],
{
"Designer": "Designer",
"Researcher": "Researcher",
"Visualizer": "Visualizer",
2026-02-06 14:51:25 +08:00
"Suggester": "Suggester", # 原本的 FINISH 现在指向 Suggester
"FINISH": END # 直接结束,不给建议
2026-02-04 17:57:49 +08:00
}
)
# 专家执行完依然回到 Supervisor
2026-02-04 17:57:49 +08:00
workflow.add_edge("Designer", "Supervisor")
workflow.add_edge("Researcher", "Supervisor")
workflow.add_edge("Visualizer", "Supervisor")
2026-02-06 14:51:25 +08:00
# 重点Suggester 可以是整个流程的终点
workflow.add_edge("Suggester", END)
2026-02-04 17:57:49 +08:00
client = MongoClient(MONGO_URI)
checkpointer = MongoDBSaver(
client=client["furniture_agent_db"],
db_name="langgraph",
2026-03-04 19:03:12 +08:00
collection_name="checkpoints",
serde=JsonPlusSerializer(pickle_fallback=True), # ← 關鍵這一行
2026-02-04 17:57:49 +08:00
)
app = workflow.compile(checkpointer=checkpointer)