From a2abe69b60eadace22befa9df68b966772813b6d Mon Sep 17 00:00:00 2001 From: zcr Date: Fri, 6 Feb 2026 14:51:25 +0800 Subject: [PATCH] =?UTF-8?q?feat=20=E5=A2=9E=E5=8A=A0=E6=8E=A8=E8=8D=90?= =?UTF-8?q?=E5=AF=B9=E8=AF=9D=E5=8F=AF=E6=8E=A7=E5=8F=82=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/routers/chat.py | 7 ++++++- src/schemas/chat.py | 2 +- src/server/agent/agents.py | 3 ++- src/server/agent/graph.py | 16 +++++++++++----- src/server/agent/state.py | 3 ++- 5 files changed, 22 insertions(+), 9 deletions(-) diff --git a/src/routers/chat.py b/src/routers/chat.py index f664ea1..9be16b0 100644 --- a/src/routers/chat.py +++ b/src/routers/chat.py @@ -25,6 +25,8 @@ async def chat_stream(request: ChatRequest): * `message`: 用户的设计意图(如:'我想设计一个极简风格的橡木办公桌')。 * `thread_id`: (可选) 现有项目的唯一标识。若不传,系统将自动分配并返回。 * `checkpoint_id`: (可选) 历史快照 ID。 + * `config_params`: (可选) 对话配置参数 + * `require_suggestion`: (可选) 是否需要建议按钮 #### 3. 响应流说明 (Data Format) 响应以 `data: ` 开头的 JSON 字符串流形式发送: @@ -122,7 +124,10 @@ async def chat_stream(request: ChatRequest): # ) # new_messages.append(force_instruction) - input_data = {"messages": new_messages} + input_data = { + "messages": new_messages, + "require_suggestion": request.need_suggestion # 初始由前端决定 + } async for event in app.astream( input_data, diff --git a/src/schemas/chat.py b/src/schemas/chat.py index a92e3d0..c4c01c8 100644 --- a/src/schemas/chat.py +++ b/src/schemas/chat.py @@ -14,7 +14,7 @@ class ChatRequest(BaseModel): thread_id: Optional[str] = Field(None, description="会话线程ID,不传则开启新会话") checkpoint_id: Optional[str] = Field(None, description="回溯点的ID,用于从历史点开启新对话") config_params: Optional[AgentConfig] = None - # force_sketch: bool = False # 新增:是否强制绘图 + need_suggestion: bool = False class HistoryItem(BaseModel): diff --git a/src/server/agent/agents.py b/src/server/agent/agents.py index 3c7ca98..1b6bbec 100644 --- a/src/server/agent/agents.py +++ b/src/server/agent/agents.py @@ -42,9 +42,10 @@ async def designer_node(state: AgentState, config: RunnableConfig): system_text = get_agent_prompt("designer") system_prompt = SystemMessage(content=system_text) + should_suggest = len(state["messages"]) % 5 == 0 # 改为异步调用 ainvoke response = await model.ainvoke([system_prompt] + messages) - return {"messages": [response]} + return {"messages": [response], "require_suggestion": should_suggest} # --- 2. Researcher Agent (情报专家) --- diff --git a/src/server/agent/graph.py b/src/server/agent/graph.py index feb3de7..5ede929 100644 --- a/src/server/agent/graph.py +++ b/src/server/agent/graph.py @@ -18,7 +18,7 @@ from langgraph.checkpoint.mongodb import MongoDBSaver # 定义路由的输出结构,强制 LLM 选择一个 class RouteResponse(BaseModel): # 将 FINISH 替换或增加 Suggester - next: Literal["Designer", "Researcher", "Visualizer", "Suggester"] + next: Literal["Designer", "Researcher", "Visualizer", "Suggester", "FINISH"] creds = service_account.Credentials.from_service_account_file( @@ -42,7 +42,13 @@ def supervisor_node(state: AgentState): # --- 拦截逻辑修改 --- # 如果专家已经回复完了(AIMessage 且无工具调用),则交给 Suggester 生成按钮 if isinstance(last_message, AIMessage) and not last_message.tool_calls: - return {"next": "Suggester"} + should_go_to_suggester = state.get("require_suggestion", False) + + # 如果符合建议条件 + if should_go_to_suggester: + return {"next": "Suggester"} + else: + return {"next": "FINISH"} system_prompt = """你是家具设计主管。分配任务给专家: - Designer: 设计建议、参数细化。 @@ -74,7 +80,8 @@ workflow.add_conditional_edges( "Designer": "Designer", "Researcher": "Researcher", "Visualizer": "Visualizer", - "Suggester": "Suggester" # 原本的 FINISH 现在指向 Suggester + "Suggester": "Suggester", # 原本的 FINISH 现在指向 Suggester + "FINISH": END # 直接结束,不给建议 } ) @@ -82,8 +89,7 @@ workflow.add_conditional_edges( workflow.add_edge("Designer", "Supervisor") workflow.add_edge("Researcher", "Supervisor") workflow.add_edge("Visualizer", "Supervisor") - -# 重点:Suggester 是整个流程的终点 +# 重点:Suggester 可以是整个流程的终点 workflow.add_edge("Suggester", END) client = MongoClient(MONGO_URI) diff --git a/src/server/agent/state.py b/src/server/agent/state.py index eb0a43b..b22e199 100644 --- a/src/server/agent/state.py +++ b/src/server/agent/state.py @@ -6,4 +6,5 @@ class AgentState(TypedDict): # messages 存储完整的对话历史,operator.add 表示新消息是追加而不是覆盖 messages: Annotated[Sequence[BaseMessage], operator.add] # next 存储 Supervisor 决定的下一步是谁 - next: str \ No newline at end of file + next: str + require_suggestion: bool # 是否需要建议按钮 \ No newline at end of file