新增对话接口

This commit is contained in:
zcr
2026-03-12 13:13:52 +08:00
parent 7042d428fa
commit a6393df0e3
35 changed files with 843 additions and 1163 deletions

View File

@@ -6,7 +6,7 @@ from typing import AsyncGenerator
from fastapi import APIRouter
from fastapi.responses import StreamingResponse
from src.schemas.chat import ChatRequest, HistoryResponse, HistoryItem
from src.server.agent.graph import app
from src.server.agent.graph import app # 导入已经 compile 好的 graph
from langchain_core.messages import HumanMessage, SystemMessage, AIMessageChunk, ToolMessage, AIMessage
router = APIRouter(prefix="/chat", tags=["Furniture Design Chat"])
@@ -169,8 +169,10 @@ async def chat_stream(request: ChatRequest):
await app.aupdate_state(current_config, combined_values)
async def event_generator() -> AsyncGenerator[str, None]:
# 初始事件
yield f"data: {json.dumps({'thread_id': target_thread_id, 'is_branch': is_branching, 'status': 'start'}, ensure_ascii=False)}\n\n"
# 構造輸入(保持不變)
new_messages = initial_messages[:] if not source_thread_id else []
new_messages.append(HumanMessage(content=request.message))
@@ -180,143 +182,124 @@ async def chat_stream(request: ChatRequest):
"use_report": request.use_report,
}
interrupted = False
current_cp_id = None
# try:
# ─── 重點改這裡 ───────────────────────────────────────
async for event in app.astream(
input_data,
config=current_config,
stream_mode=["updates", "messages", "custom"], # 确保包含 "values"
stream_mode=["custom", "updates", "messages"], # 推薦組合
subgraphs=True
# 不再需要,行為已包含
):
if interrupted:
break
logger.info(event)
# 取得 checkpoint_id可選視前端是否真的需要
latest_state = await app.aget_state(current_config)
configurable = latest_state.config.get("configurable", {})
current_cp_id = configurable.get("checkpoint_id", "")
if len(event) == 3:
namespace, channel, payload = event
# 路由更新
if event[1] == "updates":
namespace, _, payload = event
logger.info(f"Received event: {event}")
if isinstance(payload, dict):
for update_node, update_content in payload.items():
if not isinstance(event, tuple) or len(event) != 3:
continue
# 处理 reducerOverwrite / Append
if isinstance(update_content, dict):
for k, v in update_content.items():
if hasattr(v, "value"): # Overwrite(...)
update_content[k] = v.value
run_id, channel, payload = event
if isinstance(update_content, dict) and "messages" in update_content:
msgs = []
for m in update_content["messages"]:
msgs.append({
"type": m.__class__.__name__,
"content": getattr(m, "content", ""),
"name": getattr(m, "name", None),
"tool_calls": getattr(m, "tool_calls", None),
})
update_content["messages"] = msgs
# ─────────────── 检测 interrupt ───────────────
# __interrupt__ 最常出现在 "values" 或 "updates" channel 的 payload 中
if channel in ("values", "updates") and isinstance(payload, dict) and "__interrupt__" in payload:
interrupt_data = payload["__interrupt__"][0].value['__interrupt__']
interrupted = True
yield f"data: {json.dumps({
"type": "interrupt",
"node": interrupt_data.get("node", "Persona"),
"question": interrupt_data.get('question', "异常|||||||||||||"),
"current_persona": interrupt_data.get("current_persona_snapshot", {}),
"status": "waiting_for_input"
}, ensure_ascii=False)}\n\n"
yield f"data: {json.dumps({
"node": "Supervisor",
"type": "updates",
"content": update_content,
"is_delta": False,
"checkpoint_id": current_cp_id,
}, ensure_ascii=False)}\n\n"
# 立即停止后续发送,等待用户回复后 resume
break
# 自定义事件
elif event[1] == "custom":
if isinstance(payload, dict) and payload.get("type") in ("report_delta", "report_start", "report_error", "report_save_warning", "report_complete"):
delta = payload.get("delta", "").strip()
if delta:
yield f"data: {json.dumps({
'node': 'Researcher',
'type': 'report_delta',
'content': delta,
'is_delta': True,
'checkpoint_id': current_cp_id,
}, ensure_ascii=False)}\n\n"
# 基础消息
elif event[1] == "messages":
if namespace:
node_name = namespace[-1] if isinstance(namespace, tuple) else namespace
if ':' in node_name:
node_name = node_name.split(':')[0]
else:
node_name = "Main"
message, metadata = payload
is_not_research = node_name != 'Researcher'
node_name = metadata.get("langgraph_node", node_name)
# 3. 处理不同类型的 message
payload_out = {
"node": node_name,
"checkpoint_id": current_cp_id, # 你之前已经获取了
"is_delta": False,
"content": "",
"suggestions": [],
"type": "unknown"
}
# ─────────────── 正常消息处理 ───────────────
if channel == "messages":
if run_id:
node_name = run_id[-1] if isinstance(run_id, tuple) else run_id
if ':' in node_name:
node_name = node_name.split(':')[0]
else:
node_name = "Main"
message, metadata = payload
node_name = metadata.get("langgraph_node", node_name)
payload_out = {
"node": node_name,
"checkpoint_id": current_cp_id or "unknown",
"is_delta": False,
"content": "",
"suggestions": [],
"type": "unknown"
}
if isinstance(message, AIMessageChunk):
if node_name != 'Researcher' and message.tool_call_chunks:
payload_out.update({
"type": "delta",
"is_delta": True,
"content": message.content,
"tool_call_chunk": message.tool_call_chunks[0] if message.tool_call_chunks else None
})
yield f"data: {json.dumps(payload_out, ensure_ascii=False)}\n\n"
elif isinstance(message, ToolMessage):
try:
tools_data = json.loads(message.content)
if isinstance(message, AIMessageChunk):
# 节点不是research 并且 tool_call_chunks不为空的情况下避免research的report工具使用custom发出的消息和message的消息重复了
if is_not_research and node_name != 'Researcher' and message.tool_call_chunks:
payload_out.update({
"type": "delta",
"is_delta": True,
"content": message.content,
# 如果有 tool call chunk也可以在这里处理
"tool_call_chunk": message.tool_call_chunks[0] if message.tool_call_chunks else None
})
yield f"data: {json.dumps(payload_out, ensure_ascii=False)}\n\n"
elif isinstance(message, ToolMessage):
# 工具执行结果(完整的一次性输出)
payload_out.update({
"type": "tool_result",
"is_delta": False,
"content": tools_data.get("data", ""),
"tool_name": tools_data.get("tool_name", ""),
"content": message.content,
"tool_name": message.name,
"tool_call_id": message.tool_call_id
})
yield f"data: {json.dumps(payload_out, ensure_ascii=False)}\n\n"
except:
pass
elif isinstance(message, AIMessage):
payload_out.update({
"type": "complete_message",
"is_delta": False,
"content": message.content
})
yield f"data: {json.dumps(payload_out, ensure_ascii=False)}\n\n"
elif channel == "updates":
# 处理 updates非 interrupt 的部分)
if isinstance(payload, dict):
for update_node, update_content in payload.items():
# 处理 reducer 包裹的值
if isinstance(update_content, dict):
for k, v in update_content.items():
if hasattr(v, "value"):
update_content[k] = v.value
# 序列化 messages
if isinstance(update_content, dict) and "messages" in update_content:
msgs = []
for m in update_content["messages"]:
msgs.append({
"type": m.__class__.__name__,
"content": getattr(m, "content", ""),
"name": getattr(m, "name", None),
"tool_calls": getattr(m, "tool_calls", None),
})
update_content["messages"] = msgs
yield f"data: {json.dumps({
"node": "Supervisor", # update_node
"type": "updates",
"content": update_content,
elif isinstance(message, AIMessage):
# 完整 AIMessage不常见在 messages 模式下,但以防万一)
payload_out.update({
"type": "complete_message",
"is_delta": False,
"checkpoint_id": current_cp_id,
}, ensure_ascii=False)}\n\n"
"content": message.content
})
yield f"data: {json.dumps(payload_out, ensure_ascii=False)}\n\n"
elif channel == "custom":
if isinstance(payload, dict) and payload.get("type") in ("report_delta", "report_start", ...):
delta = payload.get("delta", "").strip()
if delta:
yield f"data: {json.dumps({
'node': 'Researcher',
'type': 'report_delta',
'content': delta,
'is_delta': True,
'checkpoint_id': current_cp_id,
}, ensure_ascii=False)}\n\n"
# except Exception as e:
# print("error")
else:
# 其他未知类型,记录日志
print(f"未知消息类型: {type(message)}", message)
continue
# 结束标记
if interrupted:
yield f"data: {json.dumps({'status': 'interrupted', 'reason': 'waiting_for_user_input'})}\n\n"
else:
yield f"data: {json.dumps({'status': 'end'}, ensure_ascii=False)}\n\n"
# 流結束
yield f"data: {json.dumps({'status': 'end'}, ensure_ascii=False)}\n\n"
return StreamingResponse(event_generator(), media_type="text/event-stream")

View File

@@ -1,4 +1,5 @@
import logging
import random
import uuid
import json
from typing import AsyncGenerator
@@ -9,6 +10,7 @@ from src.schemas.chat import ChatRequest, HistoryResponse, HistoryItem
from langchain_core.messages import HumanMessage, SystemMessage, AIMessageChunk, ToolMessage, AIMessage, ToolMessageChunk
from src.server.deep_agent.agents.main_agent import build_main_agent
from src.server.deep_agent.tools.extract_suggested_questions import format_messages, generate_suggested_questions
router = APIRouter(prefix="/chat", tags=["Furniture Design Chat"])
logger = logging.getLogger(__name__)
@@ -80,7 +82,6 @@ async def chat_stream(request: ChatRequest):
"checkpoint_id": "快照ID",
"is_delta": "boolean",
"type": "消息类型",
"suggestions": "建议列表(可选)",
"tool_name": "工具名称(可选)",
"tool_call_chunk": "工具调用片段(可选)",
"tool_call_id": "工具调用ID可选"
@@ -136,8 +137,6 @@ async def chat_stream(request: ChatRequest):
# 2. 配置參數
temp = request.config_params.temperature if request.config_params else 0.7
need_suggestion = request.need_suggestion,
current_config = {
"recursion_limit": 120,
"configurable": {
@@ -184,28 +183,55 @@ async def chat_stream(request: ChatRequest):
input_data = {
"messages": new_messages,
}
current_cp_id = None
async for stream in main_agent.astream(
input_data,
config=current_config,
stream_mode=["updates", "messages", "custom"], # 确保包含 "values"
stream_mode=["updates", "messages", "custom"],
subgraphs=True
):
# logger.info(f"Received event: {event}")
_, mode, chunks = stream
if mode == "updates":
# TODO 补充
if mode == "updates": # 只做记录 不做事件返回
print(f"[updates] {chunks}")
update_model_messages = chunks.get("model", None)
update_tools_messages = chunks.get("tools", None)
payload_out = {
"node": "",
# "checkpoint_id": current_cp_id or "unknown", TODO 替换为checkpoint_idns
"is_delta": False,
"content": "",
"type": "updates"
}
if update_model_messages:
model_messages = update_model_messages.get("messages", [])
for model_token in model_messages:
if isinstance(model_token, AIMessage):
model_content_blocks = model_token.content_blocks[0]
model_name = model_token.name
payload_out.update({
"node": model_name if model_name else "main",
"tool_calls": model_token.tool_calls
})
logger.info(f"[updates] {model_name} -- {model_content_blocks} -- {model_token.tool_calls}")
yield f"data: {json.dumps(payload_out, ensure_ascii=False)}\n\n"
elif update_tools_messages:
tools_messages = update_tools_messages.get("messages", [])
for tools_token in tools_messages:
if isinstance(tools_token, ToolMessage):
tool_content_blocks = tools_token.content_blocks[0]
tool_name = tools_token.name
logger.info(f"[updates] {tool_name} -- {tool_content_blocks}")
else:
logger.info(f"[updates] -- {chunks}")
elif mode == "messages":
token, metadata = chunks
subagent_name = metadata.get('lc_agent_name', None)
subagent_name = metadata.get('lc_agent_name', "main")
payload_out = {
"node": subagent_name,
# "checkpoint_id": current_cp_id or "unknown", TODO 替换为checkpoint_idns
"is_delta": False,
"content": "",
"suggestions": [],
"type": ""
}
@@ -213,29 +239,36 @@ async def chat_stream(request: ChatRequest):
reasoning = [b for b in token.content_blocks if b["type"] == "reasoning"]
text = [b for b in token.content_blocks if b["type"] == "text"]
if reasoning:
payload_out.update({
"type": "reasoning",
"is_delta": True,
"content": text,
"tool_call_chunk": token.tool_call_chunks[0] if token.tool_call_chunks else None
})
if len(reasoning) == 1:
payload_out.update({
"type": "reasoning",
"is_delta": True,
"content": reasoning[0].get("reasoning", ""),
# "tool_call_chunk": token.tool_call_chunks[0] if token.tool_call_chunks else None
})
else:
print(f"[reasoning] {reasoning}*************************************************************************************")
elif text:
payload_out.update({
"type": "text",
"is_delta": True,
"content": text,
"tool_call_chunk": token.tool_call_chunks[0] if token.tool_call_chunks else None
})
if len(text) == 1:
payload_out.update({
"type": "text",
"is_delta": True,
"content": text[0].get("text", ""),
# "tool_call_chunk": token.tool_call_chunks[0] if token.tool_call_chunks else None
})
else:
print(f"[text] {text}*************************************************************************************")
else:
payload_out.update({
"type": "tool_call",
"is_delta": True,
"tool_call_chunk": token.tool_call_chunks[0] if token.tool_call_chunks else None
# "tool_call_chunk": token.tool_call_chunks[0] if token.tool_call_chunks else None
})
yield f"data: {json.dumps(payload_out, ensure_ascii=False)}\n\n"
elif isinstance(token, ToolMessageChunk): # 工具返回
text = [b for b in token.content_blocks if b["type"] == "text"]
payload_out.update({
"type": "tool_text",
"type": "tool_result",
"is_delta": False,
"content": text,
"tool_name": token.name,
@@ -244,7 +277,7 @@ async def chat_stream(request: ChatRequest):
elif isinstance(token, ToolMessage): # 工具返回
text = [b for b in token.content_blocks if b["type"] == "text"]
payload_out.update({
"type": "tool_text",
"type": "tool_result",
"is_delta": False,
"content": text,
"tool_name": token.name,
@@ -254,14 +287,11 @@ async def chat_stream(request: ChatRequest):
continue
elif mode == "custom":
token, metadata = chunks
subagent_name = metadata.get('lc_agent_name', None)
payload_out = {
"node": subagent_name,
"node": "research-agent",
# "checkpoint_id": current_cp_id or "unknown", TODO 替换为checkpoint_idns
"is_delta": False,
"content": "",
"suggestions": [],
"type": ""
}
delta = chunks.get("delta", "").strip()
@@ -273,40 +303,10 @@ async def chat_stream(request: ChatRequest):
})
yield f"data: {json.dumps(payload_out, ensure_ascii=False)}\n\n"
# elif channel == "updates":
# # 处理 updates非 interrupt 的部分)
# if isinstance(payload, dict):
# for update_node, update_content in payload.items():
# # 处理 reducer 包裹的值
# if isinstance(update_content, dict):
# for k, v in update_content.items():
# if hasattr(v, "value"):
# update_content[k] = v.value
#
# # 序列化 messages
# if isinstance(update_content, dict) and "messages" in update_content:
# msgs = []
# for m in update_content["messages"]:
# msgs.append({
# "type": m.__class__.__name__,
# "content": getattr(m, "content", ""),
# "name": getattr(m, "name", None),
# "tool_calls": getattr(m, "tool_calls", None),
# })
# update_content["messages"] = msgs
#
# yield f"data: {json.dumps({
# "node": "Supervisor", # 或 update_node
# "type": "updates",
# "content": update_content,
# "is_delta": False,
# "checkpoint_id": current_cp_id,
# }, ensure_ascii=False)}\n\n"
#
# elif channel == "custom":
else:
yield f"data: {json.dumps({'status': 'end'}, ensure_ascii=False)}\n\n"
if request.need_suggestion > 0 and random.random() < request.need_suggestion:
suggested_questions = generate_suggested_questions(main_agent, target_thread_id)
yield f"data: {json.dumps({'suggested_questions': suggested_questions}, ensure_ascii=False)}\n\n"
yield f"data: {json.dumps({'status': 'end'}, ensure_ascii=False)}\n\n"
return StreamingResponse(event_generator(), media_type="text/event-stream")
@@ -357,6 +357,7 @@ async def get_chat_history(thread_id: str):
"""
config = {"configurable": {"thread_id": thread_id}, }
history_data = []
main_agent = build_main_agent(False)
async for state in main_agent.aget_state_history(config):
msg_content = "Initial"
if state.values and "messages" in state.values: