Files
FiDA_Python/src/server/agent/graph.py

115 lines
4.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import random
from typing import Literal
from langchain_core.messages import AIMessage
from langchain_core.runnables import RunnableConfig
from langchain_qwq import ChatQwen
from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer
from langgraph.graph import StateGraph, END, START
from pydantic import BaseModel
from pymongo import MongoClient
from src.core.config import MONGO_URI, settings
from src.server.agent.state import AgentState
from src.server.agent.agents import designer_node, researcher_node, visualizer_node, suggester_node
from langgraph.checkpoint.mongodb import MongoDBSaver
# --- Supervisor (路由逻辑) ---
# 定义路由的输出结构,强制 LLM 选择一个
class RouteResponse(BaseModel):
# 将 FINISH 替换或增加 Suggester
next: Literal["Designer", "Researcher", "Visualizer", "Suggester", "FINISH"]
llm_supervisor = ChatQwen(
model="qwen3.5-flash",
max_tokens=3_000,
timeout=None,
max_retries=2,
api_key=settings.QWEN_API_KEY)
def supervisor_node(state: AgentState, config: RunnableConfig):
configurable = config["configurable"]
use_report = configurable.get("use_report", False)
suggest_frequency = configurable.get("require_suggestion", 0.6) # 0.0~1.0
messages = state["messages"]
if not messages:
return {"next": "Suggester"}
# ── system prompt 保持不变 ──
system_prompt = f"""你是家具设计主管,负责分配任务。
当前设定:
- 是否需要市场研究报告:{'' if use_report else ''}
严格遵守以下规则:
- 如果 **不需要** 市场研究报告use_report = False**绝对不能** 选择 Researcher
- 只有在 **明确需要** 市场报告、竞争分析、材质趋势、价格区间等外部资讯时,才选择 Researcher且 **必须** use_report = True
- 常见分配:
- 纯设计、风格、尺寸、材质建议 → Designer
- 需要生成图片、渲染 → Visualizer
- 需要产生建议按钮 → Suggester
- 需要市场报告 → Researcher但只有 use_report=True 时才允许)
- 对话已完整、无需继续 → FINISH
用户最后说了什么?请根据实际需求决定下一步。
"""
chain = llm_supervisor.with_structured_output(RouteResponse)
decision = chain.invoke([{"role": "system", "content": system_prompt}] + messages)
next_node = decision.next # 防空默认 FINISH
# 安全阀:禁止非法选择 Researcher
if next_node == "Researcher" and not use_report:
print("警告LLM 违规选择了 Researcher已强制改为 Suggester 或 FINISH")
next_node = "Suggester" if state.get("require_suggestion", False) else "FINISH"
# 核心改动:只有 LLM 决定 FINISH 时,才掷骰子看是否插入 Suggester
if next_node == "FINISH":
# 满足概率条件 → 插入 Suggester
if suggest_frequency > 0 and random.random() < suggest_frequency:
next_node = "Suggester"
return {"next": next_node}
# --- 构建 Graph ---
workflow = StateGraph(AgentState)
workflow.add_node("Supervisor", supervisor_node)
workflow.add_node("Designer", designer_node)
workflow.add_node("Researcher", researcher_node)
workflow.add_node("Visualizer", visualizer_node)
workflow.add_node("Suggester", suggester_node)
workflow.add_edge(START, "Supervisor")
# 修改条件边映射
workflow.add_conditional_edges(
"Supervisor",
lambda state: state["next"],
{
"Designer": "Designer",
"Researcher": "Researcher",
"Visualizer": "Visualizer",
"Suggester": "Suggester", # 原本的 FINISH 现在指向 Suggester
"FINISH": END # 直接结束,不给建议
}
)
# 专家执行完依然回到 Supervisor
workflow.add_edge("Designer", "Supervisor")
workflow.add_edge("Researcher", "Supervisor")
workflow.add_edge("Visualizer", "Supervisor")
# 重点Suggester 可以是整个流程的终点
workflow.add_edge("Suggester", END)
client = MongoClient(MONGO_URI)
checkpointer = MongoDBSaver(
client=client["furniture_agent_db"],
db_name="langgraph",
collection_name="checkpoints",
serde=JsonPlusSerializer(pickle_fallback=True), # ← 關鍵這一行
)
app = workflow.compile(checkpointer=checkpointer)