feat 接入report
This commit is contained in:
@@ -7,7 +7,7 @@ from fastapi import APIRouter
|
||||
from fastapi.responses import StreamingResponse
|
||||
from src.schemas.chat import ChatRequest, HistoryResponse, HistoryItem
|
||||
from src.server.agent.graph import app # 导入已经 compile 好的 graph
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
from langchain_core.messages import HumanMessage, SystemMessage, AIMessageChunk, ToolMessage, AIMessage
|
||||
|
||||
router = APIRouter(prefix="/chat", tags=["Furniture Design Chat"])
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -40,6 +40,8 @@ async def chat_stream(request: ChatRequest):
|
||||
- **Node Message**: `{"node": "Designer", "content": "...", "checkpoint_id": "..."}`
|
||||
- **Session End**: `{"status": "end"}`
|
||||
|
||||
- **is_delta**: False/True,表示这个消息不是完整内容,只是 AI 正在生成的一小段内容(一个字、一个词、一句话),需要前端把这些片段拼接起来才能得到完整的回答。
|
||||
|
||||
#### 4. 请求示例
|
||||
```
|
||||
{
|
||||
@@ -57,16 +59,77 @@ async def chat_stream(request: ChatRequest):
|
||||
"checkpoint_id": "1f101aa2-8f24-6e2a-8001-2952c3a7447a"
|
||||
}
|
||||
```
|
||||
|
||||
### 5. 响应流说明
|
||||
所有响应均以 data: 开头,JSON 字符串格式,末尾以 \n\n 结束
|
||||
响应流包含三种类型的事件:会话开始、节点消息、会话结束
|
||||
会话开始:
|
||||
```
|
||||
{
|
||||
"thread_id": "str",
|
||||
"is_branch": "boolean",
|
||||
"status": "start"
|
||||
}
|
||||
```
|
||||
节点消息:
|
||||
```
|
||||
{
|
||||
"node": "节点名称(如Designer/Researcher/Main)",
|
||||
"content": "消息内容",
|
||||
"checkpoint_id": "快照ID",
|
||||
"is_delta": "boolean",
|
||||
"type": "消息类型",
|
||||
"suggestions": "建议列表(可选)",
|
||||
"tool_name": "工具名称(可选)",
|
||||
"tool_call_chunk": "工具调用片段(可选)",
|
||||
"tool_call_id": "工具调用ID(可选)"
|
||||
}
|
||||
|
||||
```
|
||||
报告增量消息:
|
||||
```
|
||||
{
|
||||
"node": "Researcher",
|
||||
"type": "report_delta",
|
||||
"content": "报告内容增量",
|
||||
"is_delta": true,
|
||||
"checkpoint_id": "xxx"
|
||||
}
|
||||
```
|
||||
AI 消息片段:
|
||||
```
|
||||
{
|
||||
"node": "Designer",
|
||||
"content": "设计建议内容",
|
||||
"checkpoint_id": "xxx",
|
||||
"is_delta": true,
|
||||
"type": "delta",
|
||||
"tool_call_chunk": {...}
|
||||
}
|
||||
```
|
||||
工具执行结果:
|
||||
```
|
||||
{
|
||||
"node": "ToolExecutor",
|
||||
"content": "工具执行结果",
|
||||
"checkpoint_id": "xxx",
|
||||
"is_delta": false,
|
||||
"type": "tool_result",
|
||||
"tool_name": "ImageGenerator",
|
||||
"tool_call_id": "yyy"
|
||||
}
|
||||
```
|
||||
|
||||
"""
|
||||
logger.debug(f"chat request data: {request}")
|
||||
source_thread_id = request.thread_id
|
||||
checkpoint_id = request.checkpoint_id
|
||||
|
||||
# 1. 确定目标 thread_id
|
||||
# 1. 確定目標 thread_id
|
||||
is_branching = source_thread_id and checkpoint_id
|
||||
target_thread_id = str(uuid.uuid4())[:8] if is_branching else (source_thread_id or str(uuid.uuid4())[:8])
|
||||
|
||||
# 2. 配置参数
|
||||
# 2. 配置參數
|
||||
temp = request.config_params.temperature if request.config_params else 0.7
|
||||
current_config = {
|
||||
"recursion_limit": 100,
|
||||
@@ -77,7 +140,7 @@ async def chat_stream(request: ChatRequest):
|
||||
}
|
||||
}
|
||||
|
||||
# 3. 初始化消息 + 系统提示
|
||||
# 3. 初始化消息 + 系統提示
|
||||
initial_messages = []
|
||||
if not source_thread_id or is_branching:
|
||||
if request.config_params:
|
||||
@@ -91,7 +154,7 @@ async def chat_stream(request: ChatRequest):
|
||||
)
|
||||
initial_messages.append(SystemMessage(content=system_prompt))
|
||||
|
||||
# 4. 处理分支(从历史 checkpoint 复制状态)
|
||||
# 4. 處理分支(從歷史 checkpoint 複製狀態)
|
||||
if is_branching:
|
||||
source_config = {
|
||||
"configurable": {
|
||||
@@ -109,7 +172,7 @@ async def chat_stream(request: ChatRequest):
|
||||
# 初始事件
|
||||
yield f"data: {json.dumps({'thread_id': target_thread_id, 'is_branch': is_branching, 'status': 'start'}, ensure_ascii=False)}\n\n"
|
||||
|
||||
# 构造输入
|
||||
# 構造輸入(保持不變)
|
||||
new_messages = initial_messages[:] if not source_thread_id else []
|
||||
new_messages.append(HumanMessage(content=request.message))
|
||||
|
||||
@@ -119,130 +182,90 @@ async def chat_stream(request: ChatRequest):
|
||||
"use_report": request.use_report,
|
||||
}
|
||||
|
||||
# 使用 astream_events v2 + stream_subgraphs=True 来捕获 DeepAgents 内部流式事件
|
||||
async for event in app.astream_events(
|
||||
# ─── 重點改這裡 ───────────────────────────────────────
|
||||
async for event in app.astream(
|
||||
input_data,
|
||||
version="v2",
|
||||
config=current_config,
|
||||
stream_subgraphs=True,
|
||||
stream_mode=["custom", "updates", "messages"], # 推薦組合
|
||||
subgraphs=True
|
||||
# 不再需要,行為已包含
|
||||
):
|
||||
event_kind = event["event"]
|
||||
|
||||
# 获取当前 checkpoint_id(安全方式,避免 KeyError)
|
||||
# 取得 checkpoint_id(可選,視前端是否真的需要)
|
||||
latest_state = await app.aget_state(current_config)
|
||||
configurable = latest_state.config.get("configurable", {})
|
||||
current_cp_id = configurable.get("checkpoint_id", "") # 如果没有,返回空字符串
|
||||
current_cp_id = configurable.get("checkpoint_id", "")
|
||||
|
||||
# ────────────────────────────────────────────────
|
||||
# 1. LLM token 流式输出(主图或子图的逐 token)
|
||||
# ────────────────────────────────────────────────
|
||||
if event_kind == "on_chat_model_stream":
|
||||
chunk = event["data"].get("chunk")
|
||||
if chunk and chunk.content:
|
||||
node_name = event.get("name", "Unknown")
|
||||
# 判断是否来自 Researcher 子图
|
||||
namespace = event.get("parent_ids", []) or event.get("namespace", [])
|
||||
if any("Researcher" in str(ns) for ns in namespace):
|
||||
node_name = "Researcher"
|
||||
|
||||
payload = {
|
||||
"node": node_name,
|
||||
"content": chunk.content,
|
||||
"is_delta": True,
|
||||
"checkpoint_id": current_cp_id,
|
||||
"image_url": None,
|
||||
"suggestions": []
|
||||
}
|
||||
yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
|
||||
|
||||
# ────────────────────────────────────────────────
|
||||
# 2. 自定义事件(report_delta 等)
|
||||
# ────────────────────────────────────────────────
|
||||
elif event_kind == "on_custom_event":
|
||||
custom_data = event["data"]
|
||||
if isinstance(custom_data, dict):
|
||||
if custom_data.get("type") == "report_delta":
|
||||
payload = {
|
||||
"node": "Researcher",
|
||||
"content": custom_data.get("delta", ""),
|
||||
"is_delta": True,
|
||||
"checkpoint_id": current_cp_id,
|
||||
"image_url": None,
|
||||
"suggestions": []
|
||||
}
|
||||
yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
|
||||
|
||||
# 可选:报告开始/完成/错误等状态提示
|
||||
elif custom_data.get("type") in ("report_start", "report_complete", "report_error"):
|
||||
status_msg = {
|
||||
"report_start": "Start generating reports...",
|
||||
"report_complete": "Report generation completed",
|
||||
"report_error": f"Report generation failed: {custom_data.get('message', '')}"
|
||||
}.get(custom_data["type"], "")
|
||||
payload = {
|
||||
"node": "Researcher",
|
||||
"content": status_msg,
|
||||
"is_delta": False,
|
||||
"checkpoint_id": current_cp_id,
|
||||
"image_url": None,
|
||||
"suggestions": []
|
||||
}
|
||||
yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
|
||||
|
||||
# ────────────────────────────────────────────────
|
||||
# 3. 节点启动 / 工具启动(进度提示)
|
||||
# ────────────────────────────────────────────────
|
||||
elif event_kind in {"on_tool_start", "on_tool_end"}:
|
||||
tool_name = event.get("name", "unknown_tool")
|
||||
tool_data = event.get("data", {})
|
||||
tool_input = tool_data.get("input", "")
|
||||
tool_output = tool_data.get("output", "")
|
||||
|
||||
if event_kind == "on_tool_start":
|
||||
payload = {
|
||||
"node": tool_name,
|
||||
"content": tool_input,
|
||||
"is_delta": False,
|
||||
"checkpoint_id": current_cp_id,
|
||||
"image_url": None,
|
||||
"suggestions": []
|
||||
}
|
||||
yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
|
||||
else:
|
||||
if tool_name == "generate_furniture" and isinstance(tool_output, str):
|
||||
payload = {
|
||||
"node": tool_name,
|
||||
"content": "Design sketch has been generated for you.", # 给用户友好的文字提示
|
||||
"image_url": tool_output, # 直接传 URL 给前端显示
|
||||
"is_delta": False, # 这是一个完整事件,不是增量
|
||||
"checkpoint_id": current_cp_id,
|
||||
"suggestions": []
|
||||
}
|
||||
yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
|
||||
elif tool_name == "topic_research":
|
||||
payload = {
|
||||
"node": tool_name,
|
||||
"content": "Visiting...", # 给用户友好的文字提示
|
||||
"image_url": None, # 直接传 URL 给前端显示
|
||||
"search_list": tool_output.content,
|
||||
"is_delta": False, # 这是一个完整事件,不是增量
|
||||
"checkpoint_id": current_cp_id,
|
||||
"suggestions": []
|
||||
}
|
||||
yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
|
||||
if len(event) == 3:
|
||||
namespace, channel, payload = event
|
||||
if event[1] == "custom":
|
||||
if isinstance(payload, dict) and payload.get("type") in ("report_delta", "report_start", "report_error", "report_save_warning", "report_complete"):
|
||||
delta = payload.get("delta", "").strip()
|
||||
if delta:
|
||||
yield f"data: {json.dumps({
|
||||
'node': 'Researcher',
|
||||
'type': 'report_delta',
|
||||
'content': delta,
|
||||
'is_delta': True,
|
||||
'checkpoint_id': current_cp_id,
|
||||
}, ensure_ascii=False)}\n\n"
|
||||
if event[1] == "messages":
|
||||
if namespace:
|
||||
node_name = namespace[-1] if isinstance(namespace, tuple) else namespace
|
||||
if ':' in node_name:
|
||||
node_name = node_name.split(':')[0]
|
||||
else:
|
||||
# 可选:其他工具的通用处理(debug 或显示结果)
|
||||
if tool_output:
|
||||
payload = {
|
||||
"node": tool_name,
|
||||
"content": f"tool {tool_name} Execution completed:{str(tool_output)[:200]}...", # 截断避免过长
|
||||
"is_delta": False,
|
||||
"checkpoint_id": current_cp_id,
|
||||
"image_url": None,
|
||||
"suggestions": []
|
||||
}
|
||||
yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
|
||||
# 流结束
|
||||
node_name = "Main"
|
||||
message, metadata = payload
|
||||
node_name = metadata.get("langgraph_node", node_name)
|
||||
# 3. 处理不同类型的 message
|
||||
payload_out = {
|
||||
"node": node_name,
|
||||
"checkpoint_id": current_cp_id, # 你之前已经获取了
|
||||
"is_delta": False,
|
||||
"content": "",
|
||||
"suggestions": [],
|
||||
"type": "unknown"
|
||||
}
|
||||
|
||||
if isinstance(message, AIMessageChunk):
|
||||
if message.tool_call_chunks:
|
||||
payload_out.update({
|
||||
"type": "delta",
|
||||
"is_delta": True,
|
||||
"content": message.content,
|
||||
# 如果有 tool call chunk,也可以在这里处理
|
||||
"tool_call_chunk": message.tool_call_chunks[0] if message.tool_call_chunks else None
|
||||
})
|
||||
yield f"data: {json.dumps(payload_out, ensure_ascii=False)}\n\n"
|
||||
elif isinstance(message, ToolMessage):
|
||||
# 工具执行结果(完整的一次性输出)
|
||||
payload_out.update({
|
||||
"type": "tool_result",
|
||||
"is_delta": False,
|
||||
"content": message.content,
|
||||
"tool_name": message.name,
|
||||
"tool_call_id": message.tool_call_id
|
||||
})
|
||||
# 特殊处理:如果内容看起来是用户画像或特定格式
|
||||
if "实时用户画像" in message.content:
|
||||
payload_out["type"] = "user_persona"
|
||||
yield f"data: {json.dumps(payload_out, ensure_ascii=False)}\n\n"
|
||||
|
||||
elif isinstance(message, AIMessage):
|
||||
# 完整 AIMessage(不常见在 messages 模式下,但以防万一)
|
||||
payload_out.update({
|
||||
"type": "complete_message",
|
||||
"is_delta": False,
|
||||
"content": message.content
|
||||
})
|
||||
yield f"data: {json.dumps(payload_out, ensure_ascii=False)}\n\n"
|
||||
|
||||
else:
|
||||
# 其他未知类型,记录日志
|
||||
print(f"未知消息类型: {type(message)}", message)
|
||||
continue
|
||||
|
||||
# 流結束
|
||||
yield f"data: {json.dumps({'status': 'end'}, ensure_ascii=False)}\n\n"
|
||||
|
||||
return StreamingResponse(event_generator(), media_type="text/event-stream")
|
||||
|
||||
Reference in New Issue
Block a user