feat 接入report
This commit is contained in:
@@ -1,5 +1,8 @@
|
||||
import logging
|
||||
import uuid
|
||||
import json
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi.responses import StreamingResponse
|
||||
from src.schemas.chat import ChatRequest, HistoryResponse, HistoryItem
|
||||
@@ -7,6 +10,7 @@ from src.server.agent.graph import app # 导入已经 compile 好的 graph
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
|
||||
router = APIRouter(prefix="/chat", tags=["Furniture Design Chat"])
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@router.post("/stream")
|
||||
@@ -52,29 +56,28 @@ async def chat_stream(request: ChatRequest):
|
||||
}
|
||||
```
|
||||
"""
|
||||
logger.debug(f"chat request data: {request}")
|
||||
source_thread_id = request.thread_id
|
||||
checkpoint_id = request.checkpoint_id
|
||||
|
||||
# 1. 确定目标 thread_id
|
||||
# 如果是回溯操作,我们生成一个新的 ID,或者由前端传入一个新的 target_thread_id
|
||||
is_branching = source_thread_id and checkpoint_id
|
||||
target_thread_id = str(uuid.uuid4())[:8] if is_branching else (source_thread_id or str(uuid.uuid4())[:8])
|
||||
# 2. 获取配置参数
|
||||
temp = request.config_params.temperature if request.config_params else 0.7
|
||||
|
||||
# 构建基础 Config
|
||||
# 2. 配置参数
|
||||
temp = request.config_params.temperature if request.config_params else 0.7
|
||||
current_config = {
|
||||
"recursion_limit": 100,
|
||||
"configurable": {
|
||||
"thread_id": target_thread_id,
|
||||
"llm_temperature": temp
|
||||
"llm_temperature": temp,
|
||||
"use_report": request.use_report,
|
||||
}
|
||||
}
|
||||
# 3. 处理状态初始化与分支
|
||||
initial_messages = []
|
||||
|
||||
# 如果是全新的对话(没有 source_thread_id),或者明确要求分叉
|
||||
# 3. 初始化消息 + 系统提示
|
||||
initial_messages = []
|
||||
if not source_thread_id or is_branching:
|
||||
# 如果用户传了标签,构造 SystemMessage 注入上下文
|
||||
if request.config_params:
|
||||
cp = request.config_params
|
||||
system_prompt = (
|
||||
@@ -86,7 +89,7 @@ async def chat_stream(request: ChatRequest):
|
||||
)
|
||||
initial_messages.append(SystemMessage(content=system_prompt))
|
||||
|
||||
# 4. 执行分叉逻辑(搬运旧数据)
|
||||
# 4. 处理分支(从历史 checkpoint 复制状态)
|
||||
if is_branching:
|
||||
source_config = {
|
||||
"configurable": {
|
||||
@@ -95,80 +98,149 @@ async def chat_stream(request: ChatRequest):
|
||||
}
|
||||
}
|
||||
older_state = await app.aget_state(source_config)
|
||||
|
||||
# 将旧消息和我们新定义的 SystemMessage 合并
|
||||
# update_state 会将这些消息推送到新 thread 的存储中
|
||||
combined_values = older_state.values.copy()
|
||||
if initial_messages:
|
||||
combined_values["messages"] = list(combined_values["messages"]) + initial_messages
|
||||
|
||||
combined_values["messages"] = list(combined_values.get("messages", [])) + initial_messages
|
||||
await app.aupdate_state(current_config, combined_values)
|
||||
|
||||
async def event_generator():
|
||||
# 初始推送状态信息
|
||||
async def event_generator() -> AsyncGenerator[str, None]:
|
||||
# 初始事件
|
||||
yield f"data: {json.dumps({'thread_id': target_thread_id, 'is_branch': is_branching, 'status': 'start'}, ensure_ascii=False)}\n\n"
|
||||
|
||||
# 构造本次请求的输入
|
||||
# 如果是第一次开始,且有 initial_messages,则连同 user message 一起发送
|
||||
# --- 核心逻辑:构造本次请求的消息列表 ---
|
||||
new_messages = []
|
||||
if not source_thread_id and initial_messages:
|
||||
new_messages.extend(initial_messages)
|
||||
# 添加用户消息
|
||||
# 构造输入
|
||||
new_messages = initial_messages[:] if not source_thread_id else []
|
||||
new_messages.append(HumanMessage(content=request.message))
|
||||
|
||||
# --- 新增:强制绘图指令注入 ---
|
||||
# if request.force_sketch:
|
||||
# force_instruction = HumanMessage(
|
||||
# content="[SYSTEM_DIRECTIVE]: 用户点击了强制生成按钮。请立即根据当前上下文调用 generate_furniture_sketch 工具生成草图,无需确认。"
|
||||
# )
|
||||
# new_messages.append(force_instruction)
|
||||
|
||||
input_data = {
|
||||
"messages": new_messages,
|
||||
"require_suggestion": request.need_suggestion # 初始由前端决定
|
||||
"require_suggestion": request.need_suggestion,
|
||||
"use_report": request.use_report,
|
||||
}
|
||||
|
||||
async for event in app.astream(
|
||||
# 使用 astream_events v2 + stream_subgraphs=True 来捕获 DeepAgents 内部流式事件
|
||||
async for event in app.astream_events(
|
||||
input_data,
|
||||
current_config,
|
||||
stream_mode="updates"
|
||||
version="v2",
|
||||
config=current_config,
|
||||
stream_subgraphs=True,
|
||||
):
|
||||
for node_name, output in event.items():
|
||||
if "messages" in output:
|
||||
# 获取最新 state 以获取 checkpoint_id
|
||||
state = await app.aget_state(current_config)
|
||||
current_cp_id = state.config["configurable"].get("checkpoint_id")
|
||||
event_kind = event["event"]
|
||||
|
||||
# 遍历本次 update 产生的所有消息
|
||||
for msg in output["messages"]:
|
||||
# 获取当前 checkpoint_id(安全方式,避免 KeyError)
|
||||
latest_state = await app.aget_state(current_config)
|
||||
configurable = latest_state.config.get("configurable", {})
|
||||
current_cp_id = configurable.get("checkpoint_id", "") # 如果没有,返回空字符串
|
||||
|
||||
# ────────────────────────────────────────────────
|
||||
# 1. LLM token 流式输出(主图或子图的逐 token)
|
||||
# ────────────────────────────────────────────────
|
||||
if event_kind == "on_chat_model_stream":
|
||||
chunk = event["data"].get("chunk")
|
||||
if chunk and chunk.content:
|
||||
node_name = event.get("name", "Unknown")
|
||||
# 判断是否来自 Researcher 子图
|
||||
namespace = event.get("parent_ids", []) or event.get("namespace", [])
|
||||
if any("Researcher" in str(ns) for ns in namespace):
|
||||
node_name = "Researcher"
|
||||
|
||||
payload = {
|
||||
"node": node_name,
|
||||
"content": chunk.content,
|
||||
"is_delta": True,
|
||||
"checkpoint_id": current_cp_id,
|
||||
"image_url": None,
|
||||
"suggestions": []
|
||||
}
|
||||
yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
|
||||
|
||||
# ────────────────────────────────────────────────
|
||||
# 2. 自定义事件(report_delta 等)
|
||||
# ────────────────────────────────────────────────
|
||||
elif event_kind == "on_custom_event":
|
||||
custom_data = event["data"]
|
||||
if isinstance(custom_data, dict):
|
||||
if custom_data.get("type") == "report_delta":
|
||||
payload = {
|
||||
"node": node_name,
|
||||
"content": "",
|
||||
"node": "Researcher",
|
||||
"content": custom_data.get("delta", ""),
|
||||
"is_delta": True,
|
||||
"checkpoint_id": current_cp_id,
|
||||
"image_url": None,
|
||||
"suggestions": []
|
||||
}
|
||||
yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
|
||||
|
||||
# 可选:报告开始/完成/错误等状态提示
|
||||
elif custom_data.get("type") in ("report_start", "report_complete", "report_error"):
|
||||
status_msg = {
|
||||
"report_start": "Start generating reports...",
|
||||
"report_complete": "Report generation completed",
|
||||
"report_error": f"Report generation failed: {custom_data.get('message', '')}"
|
||||
}.get(custom_data["type"], "")
|
||||
payload = {
|
||||
"node": "Researcher",
|
||||
"content": status_msg,
|
||||
"is_delta": False,
|
||||
"checkpoint_id": current_cp_id,
|
||||
"image_url": None,
|
||||
"suggestions": []
|
||||
}
|
||||
yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
|
||||
|
||||
# ────────────────────────────────────────────────
|
||||
# 3. 节点启动 / 工具启动(进度提示)
|
||||
# ────────────────────────────────────────────────
|
||||
elif event_kind in {"on_tool_start", "on_tool_end"}:
|
||||
tool_name = event.get("name", "unknown_tool")
|
||||
tool_data = event.get("data", {})
|
||||
tool_input = tool_data.get("input", "")
|
||||
tool_output = tool_data.get("output", "")
|
||||
|
||||
if event_kind == "on_tool_start":
|
||||
payload = {
|
||||
"node": tool_name,
|
||||
"content": tool_input,
|
||||
"is_delta": False,
|
||||
"checkpoint_id": current_cp_id,
|
||||
"image_url": None,
|
||||
"suggestions": []
|
||||
}
|
||||
yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
|
||||
else:
|
||||
if tool_name == "generate_furniture" and isinstance(tool_output, str):
|
||||
payload = {
|
||||
"node": tool_name,
|
||||
"content": "Design sketch has been generated for you.", # 给用户友好的文字提示
|
||||
"image_url": tool_output, # 直接传 URL 给前端显示
|
||||
"is_delta": False, # 这是一个完整事件,不是增量
|
||||
"checkpoint_id": current_cp_id,
|
||||
"suggestions": []
|
||||
}
|
||||
|
||||
# --- 核心改动:提取建议按钮 ---
|
||||
# 无论是不是 Suggester 节点,只要消息里带了建议就提取
|
||||
if hasattr(msg, "additional_kwargs") and "suggestions" in msg.additional_kwargs:
|
||||
payload["suggestions"] = msg.additional_kwargs["suggestions"]
|
||||
|
||||
content = msg.content
|
||||
# 逻辑判断:MinIO 图片处理
|
||||
if node_name == "Visualizer" and str(content).endswith("png") and "furniture/sketches" in str(content):
|
||||
payload["image_url"] = content
|
||||
payload["content"] = "已为您生成设计草图"
|
||||
else:
|
||||
payload["content"] = content
|
||||
|
||||
# 如果消息既没有文本、也没有图片、也没有建议(比如中间的 ToolCall 消息),则跳过
|
||||
if not payload["content"] and not payload["image_url"] and not payload["suggestions"]:
|
||||
continue
|
||||
|
||||
yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
|
||||
|
||||
elif tool_name == "topic_research":
|
||||
payload = {
|
||||
"node": tool_name,
|
||||
"content": "Visiting...", # 给用户友好的文字提示
|
||||
"image_url": None, # 直接传 URL 给前端显示
|
||||
"search_list": tool_output.content,
|
||||
"is_delta": False, # 这是一个完整事件,不是增量
|
||||
"checkpoint_id": current_cp_id,
|
||||
"suggestions": []
|
||||
}
|
||||
yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
|
||||
else:
|
||||
# 可选:其他工具的通用处理(debug 或显示结果)
|
||||
if tool_output:
|
||||
payload = {
|
||||
"node": tool_name,
|
||||
"content": f"tool {tool_name} Execution completed:{str(tool_output)[:200]}...", # 截断避免过长
|
||||
"is_delta": False,
|
||||
"checkpoint_id": current_cp_id,
|
||||
"image_url": None,
|
||||
"suggestions": []
|
||||
}
|
||||
yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
|
||||
# 流结束
|
||||
yield f"data: {json.dumps({'status': 'end'}, ensure_ascii=False)}\n\n"
|
||||
|
||||
return StreamingResponse(event_generator(), media_type="text/event-stream")
|
||||
@@ -218,7 +290,7 @@ async def get_chat_history(thread_id: str):
|
||||
}
|
||||
```
|
||||
"""
|
||||
config = {"configurable": {"thread_id": thread_id}}
|
||||
config = {"configurable": {"thread_id": thread_id}, }
|
||||
history_data = []
|
||||
async for state in app.aget_state_history(config):
|
||||
msg_content = "Initial"
|
||||
|
||||
Reference in New Issue
Block a user