diff --git a/pyproject.toml b/pyproject.toml index a35aa3a..71a2029 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,4 +37,5 @@ dependencies = [ "dashscope>=1.25.13", "prompt>=0.4.1", "langchain-qwq>=0.3.4", + "asyncio>=4.0.0", ] diff --git a/src/routers/chat.py b/src/routers/chat.py index 7b1b5b0..0f79fee 100644 --- a/src/routers/chat.py +++ b/src/routers/chat.py @@ -7,7 +7,7 @@ from fastapi import APIRouter from fastapi.responses import StreamingResponse from src.schemas.chat import ChatRequest, HistoryResponse, HistoryItem from src.server.agent.graph import app # 导入已经 compile 好的 graph -from langchain_core.messages import HumanMessage, SystemMessage +from langchain_core.messages import HumanMessage, SystemMessage, AIMessageChunk, ToolMessage, AIMessage router = APIRouter(prefix="/chat", tags=["Furniture Design Chat"]) logger = logging.getLogger(__name__) @@ -40,6 +40,8 @@ async def chat_stream(request: ChatRequest): - **Node Message**: `{"node": "Designer", "content": "...", "checkpoint_id": "..."}` - **Session End**: `{"status": "end"}` + - **is_delta**: False/True,表示这个消息不是完整内容,只是 AI 正在生成的一小段内容(一个字、一个词、一句话),需要前端把这些片段拼接起来才能得到完整的回答。 + #### 4. 请求示例 ``` { @@ -57,16 +59,77 @@ async def chat_stream(request: ChatRequest): "checkpoint_id": "1f101aa2-8f24-6e2a-8001-2952c3a7447a" } ``` + + ### 5. 
响应流说明 + 所有响应均以 data: 开头,JSON 字符串格式,末尾以 \n\n 结束 + 响应流包含三种类型的事件:会话开始、节点消息、会话结束 + 会话开始: + ``` + { + "thread_id": "str", + "is_branch": "boolean", + "status": "start" + } + ``` + 节点消息: + ``` + { + "node": "节点名称(如Designer/Researcher/Main)", + "content": "消息内容", + "checkpoint_id": "快照ID", + "is_delta": "boolean", + "type": "消息类型", + "suggestions": "建议列表(可选)", + "tool_name": "工具名称(可选)", + "tool_call_chunk": "工具调用片段(可选)", + "tool_call_id": "工具调用ID(可选)" + } + + ``` + 报告增量消息: + ``` + { + "node": "Researcher", + "type": "report_delta", + "content": "报告内容增量", + "is_delta": true, + "checkpoint_id": "xxx" + } + ``` + AI 消息片段: + ``` + { + "node": "Designer", + "content": "设计建议内容", + "checkpoint_id": "xxx", + "is_delta": true, + "type": "delta", + "tool_call_chunk": {...} + } + ``` + 工具执行结果: + ``` + { + "node": "ToolExecutor", + "content": "工具执行结果", + "checkpoint_id": "xxx", + "is_delta": false, + "type": "tool_result", + "tool_name": "ImageGenerator", + "tool_call_id": "yyy" + } + ``` + """ logger.debug(f"chat request data: {request}") source_thread_id = request.thread_id checkpoint_id = request.checkpoint_id - # 1. 确定目标 thread_id + # 1. 確定目標 thread_id is_branching = source_thread_id and checkpoint_id target_thread_id = str(uuid.uuid4())[:8] if is_branching else (source_thread_id or str(uuid.uuid4())[:8]) - # 2. 配置参数 + # 2. 配置參數 temp = request.config_params.temperature if request.config_params else 0.7 current_config = { "recursion_limit": 100, @@ -77,7 +140,7 @@ async def chat_stream(request: ChatRequest): } } - # 3. 初始化消息 + 系统提示 + # 3. 初始化消息 + 系統提示 initial_messages = [] if not source_thread_id or is_branching: if request.config_params: @@ -91,7 +154,7 @@ async def chat_stream(request: ChatRequest): ) initial_messages.append(SystemMessage(content=system_prompt)) - # 4. 处理分支(从历史 checkpoint 复制状态) + # 4. 
處理分支(從歷史 checkpoint 複製狀態) if is_branching: source_config = { "configurable": { @@ -109,7 +172,7 @@ async def chat_stream(request: ChatRequest): # 初始事件 yield f"data: {json.dumps({'thread_id': target_thread_id, 'is_branch': is_branching, 'status': 'start'}, ensure_ascii=False)}\n\n" - # 构造输入 + # 構造輸入(保持不變) new_messages = initial_messages[:] if not source_thread_id else [] new_messages.append(HumanMessage(content=request.message)) @@ -119,130 +182,90 @@ async def chat_stream(request: ChatRequest): "use_report": request.use_report, } - # 使用 astream_events v2 + stream_subgraphs=True 来捕获 DeepAgents 内部流式事件 - async for event in app.astream_events( + # ─── 重點改這裡 ─────────────────────────────────────── + async for event in app.astream( input_data, - version="v2", config=current_config, - stream_subgraphs=True, + stream_mode=["custom", "updates", "messages"], # 推薦組合 + subgraphs=True + # 不再需要,行為已包含 ): - event_kind = event["event"] - - # 获取当前 checkpoint_id(安全方式,避免 KeyError) + # 取得 checkpoint_id(可選,視前端是否真的需要) latest_state = await app.aget_state(current_config) configurable = latest_state.config.get("configurable", {}) - current_cp_id = configurable.get("checkpoint_id", "") # 如果没有,返回空字符串 + current_cp_id = configurable.get("checkpoint_id", "") - # ──────────────────────────────────────────────── - # 1. 
LLM token 流式输出(主图或子图的逐 token) - # ──────────────────────────────────────────────── - if event_kind == "on_chat_model_stream": - chunk = event["data"].get("chunk") - if chunk and chunk.content: - node_name = event.get("name", "Unknown") - # 判断是否来自 Researcher 子图 - namespace = event.get("parent_ids", []) or event.get("namespace", []) - if any("Researcher" in str(ns) for ns in namespace): - node_name = "Researcher" - - payload = { - "node": node_name, - "content": chunk.content, - "is_delta": True, - "checkpoint_id": current_cp_id, - "image_url": None, - "suggestions": [] - } - yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n" - - # ──────────────────────────────────────────────── - # 2. 自定义事件(report_delta 等) - # ──────────────────────────────────────────────── - elif event_kind == "on_custom_event": - custom_data = event["data"] - if isinstance(custom_data, dict): - if custom_data.get("type") == "report_delta": - payload = { - "node": "Researcher", - "content": custom_data.get("delta", ""), - "is_delta": True, - "checkpoint_id": current_cp_id, - "image_url": None, - "suggestions": [] - } - yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n" - - # 可选:报告开始/完成/错误等状态提示 - elif custom_data.get("type") in ("report_start", "report_complete", "report_error"): - status_msg = { - "report_start": "Start generating reports...", - "report_complete": "Report generation completed", - "report_error": f"Report generation failed: {custom_data.get('message', '')}" - }.get(custom_data["type"], "") - payload = { - "node": "Researcher", - "content": status_msg, - "is_delta": False, - "checkpoint_id": current_cp_id, - "image_url": None, - "suggestions": [] - } - yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n" - - # ──────────────────────────────────────────────── - # 3. 
节点启动 / 工具启动(进度提示) - # ──────────────────────────────────────────────── - elif event_kind in {"on_tool_start", "on_tool_end"}: - tool_name = event.get("name", "unknown_tool") - tool_data = event.get("data", {}) - tool_input = tool_data.get("input", "") - tool_output = tool_data.get("output", "") - - if event_kind == "on_tool_start": - payload = { - "node": tool_name, - "content": tool_input, - "is_delta": False, - "checkpoint_id": current_cp_id, - "image_url": None, - "suggestions": [] - } - yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n" - else: - if tool_name == "generate_furniture" and isinstance(tool_output, str): - payload = { - "node": tool_name, - "content": "Design sketch has been generated for you.", # 给用户友好的文字提示 - "image_url": tool_output, # 直接传 URL 给前端显示 - "is_delta": False, # 这是一个完整事件,不是增量 - "checkpoint_id": current_cp_id, - "suggestions": [] - } - yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n" - elif tool_name == "topic_research": - payload = { - "node": tool_name, - "content": "Visiting...", # 给用户友好的文字提示 - "image_url": None, # 直接传 URL 给前端显示 - "search_list": tool_output.content, - "is_delta": False, # 这是一个完整事件,不是增量 - "checkpoint_id": current_cp_id, - "suggestions": [] - } - yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n" + if len(event) == 3: + namespace, channel, payload = event + if event[1] == "custom": + if isinstance(payload, dict) and payload.get("type") in ("report_delta", "report_start", "report_error", "report_save_warning", "report_complete"): + delta = payload.get("delta", "").strip() + if delta: + yield f"data: {json.dumps({ + 'node': 'Researcher', + 'type': 'report_delta', + 'content': delta, + 'is_delta': True, + 'checkpoint_id': current_cp_id, + }, ensure_ascii=False)}\n\n" + if event[1] == "messages": + if namespace: + node_name = namespace[-1] if isinstance(namespace, tuple) else namespace + if ':' in node_name: + node_name = node_name.split(':')[0] else: - # 可选:其他工具的通用处理(debug 或显示结果) - if 
tool_output: - payload = { - "node": tool_name, - "content": f"tool {tool_name} Execution completed:{str(tool_output)[:200]}...", # 截断避免过长 - "is_delta": False, - "checkpoint_id": current_cp_id, - "image_url": None, - "suggestions": [] - } - yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n" - # 流结束 + node_name = "Main" + message, metadata = payload + node_name = metadata.get("langgraph_node", node_name) + # 3. 处理不同类型的 message + payload_out = { + "node": node_name, + "checkpoint_id": current_cp_id, # 你之前已经获取了 + "is_delta": False, + "content": "", + "suggestions": [], + "type": "unknown" + } + + if isinstance(message, AIMessageChunk): + if message.tool_call_chunks: + payload_out.update({ + "type": "delta", + "is_delta": True, + "content": message.content, + # 如果有 tool call chunk,也可以在这里处理 + "tool_call_chunk": message.tool_call_chunks[0] if message.tool_call_chunks else None + }) + yield f"data: {json.dumps(payload_out, ensure_ascii=False)}\n\n" + elif isinstance(message, ToolMessage): + # 工具执行结果(完整的一次性输出) + payload_out.update({ + "type": "tool_result", + "is_delta": False, + "content": message.content, + "tool_name": message.name, + "tool_call_id": message.tool_call_id + }) + # 特殊处理:如果内容看起来是用户画像或特定格式 + if "实时用户画像" in message.content: + payload_out["type"] = "user_persona" + yield f"data: {json.dumps(payload_out, ensure_ascii=False)}\n\n" + + elif isinstance(message, AIMessage): + # 完整 AIMessage(不常见在 messages 模式下,但以防万一) + payload_out.update({ + "type": "complete_message", + "is_delta": False, + "content": message.content + }) + yield f"data: {json.dumps(payload_out, ensure_ascii=False)}\n\n" + + else: + # 其他未知类型,记录日志 + print(f"未知消息类型: {type(message)}", message) + continue + + # 流結束 yield f"data: {json.dumps({'status': 'end'}, ensure_ascii=False)}\n\n" return StreamingResponse(event_generator(), media_type="text/event-stream") diff --git a/src/server/agent/agents.py b/src/server/agent/agents.py index b49ab5f..dfd85e5 100644 --- a/src/server/agent/agents.py +++ 
b/src/server/agent/agents.py @@ -2,7 +2,7 @@ from pathlib import Path from typing import AsyncGenerator, Dict, Any from deepagents import create_deep_agent from deepagents.backends import FilesystemBackend -from langchain_core.messages import HumanMessage, SystemMessage, ToolMessage, AIMessage +from langchain_core.messages import HumanMessage, SystemMessage, ToolMessage, AIMessage, AIMessageChunk from langchain_core.runnables import RunnableConfig from langchain_qwq import ChatQwen @@ -26,6 +26,7 @@ MAIN_DIR = Path(__file__).resolve().parent PROJECT_ROOT = MAIN_DIR model = ChatQwen( + enable_thinking=False, model="qwen3.5-flash", max_tokens=3_000, timeout=None, @@ -48,6 +49,7 @@ research_agent = create_deep_agent( def get_model(config: RunnableConfig): temp = config["configurable"].get("llm_temperature", 0.5) return ChatQwen( + enable_thinking=False, model="qwen3.5-flash", max_tokens=3_000, timeout=None, @@ -114,110 +116,69 @@ async def researcher_node( )] } - async for event in research_agent.astream_events( + async for chunk in research_agent.astream( {"messages": messages[-12:]}, - version="v2", - config=config, - stream_subgraphs=True + config=config ): - event_type = event["event"] - name = event.get("name", "未知") - - if event["event"] == "on_custom_event": - custom_data = event["data"] - # 你的 writer 发的是 dict,所以这里 custom_data 就是你写的 {"type": "report_delta", "delta": "..."} - if isinstance(custom_data, dict) and custom_data.get("type") == "report_delta": - delta = custom_data.get("delta", "") - print(delta, end="", flush=True) # 实时打印,不换行 - - # ────────────── 工具结束事件:重点处理并 yield 输出 ────────────── - if event["event"] in {"on_tool_start", "on_tool_end"}: - tool_name = event.get("name", "未知") - is_start = event["event"] == "on_tool_start" - - if is_start: - tool_input = event["data"].get("input", {}) - current_step = f"正在執行工具:{tool_name}" - print(f"| {current_step} | {tool_input}") - yield { - "messages": [AIMessage( - content=full_content, - name="Researcher", - 
additional_kwargs={ - "current_step": current_step, - "tool_name": tool_name, - "tool_input": tool_input, - "tool_status": "start", - "streaming": True - } - )] - } - else: # on_tool_end - tool_output = event["data"].get("output", "") - current_step = f"工具 {tool_name} 已完成" - print(f"| {current_step} | {tool_output}") - yield { - "messages": [AIMessage( - content=full_content, - name="Researcher", - additional_kwargs={ - "current_step": current_step, - "tool_name": tool_name, - "tool_output": tool_output, - "tool_status": "end", - "streaming": True - } - )] - } - - - # ────────────── LLM 内容生成(保持原有逻辑) ────────────── - elif event_type == "on_chat_model_stream": - chunk = event["data"]["chunk"].content or "" - if chunk: - print(chunk, end="", flush=True) - full_content += chunk - if "\n" in chunk or len(full_content) % 4 == 0: - yield { - "messages": [AIMessage( - content=full_content, - name="Researcher", - additional_kwargs={ - "current_step": current_step, - "streaming": True - } - )] - } - - # ────────────── 其他链路事件(可选补充) ────────────── - elif event_type in ("on_chain_start", "on_chain_end"): - status = "开始" if event_type == "on_chain_start" else "完成" - current_step = f"[{status}] {name.upper()}" + if "messages" in chunk and isinstance(chunk["messages"], AIMessageChunk): yield { - "messages": [AIMessage( - content=full_content, - name="Researcher", - additional_kwargs={ - "current_step": current_step, - "streaming": True - } - )] + "messages": chunk["messages"], # 逐 token 追加 + # 可以額外 yield 一些 metadata,例如 + # "node": "Researcher", + # "status": "thinking" } - - # 最终输出 - yield { - "messages": [AIMessage( - content=full_content.strip() or "报告生成完成", - name="Researcher", - additional_kwargs={ - "current_step": "报告已完成", - "streaming": False - } - )], - "next": "Suggester" - } + else: + # 其他類型的 chunk + yield chunk +# +# async def researcher_node( +# state: AgentState, +# config: RunnableConfig +# ) -> Dict[str, Any]: +# """ +# 薄節點:只判斷是否要跑深度報告,並準備初始訊息 +# 真正的 report 生成與 
streaming 交給外層或子圖處理 +# """ +# use_report = config["configurable"].get("use_report", False) +# +# if not use_report: +# return { +# "messages": [AIMessage( +# content="深度報告功能未啟用,請通過前端按鈕觸發。", +# name="Researcher" +# )], +# "next": "Supervisor" +# } +# +# # 發送初始訊息,讓前端馬上看到「正在啟動」 +# # initial_msg = AIMessage( +# # content="正在啟動深度報告生成...", +# # name="Researcher", +# # additional_kwargs={ +# # "current_step": "正在啟動深度報告生成...", +# # "streaming": True +# # } +# # ) +# +# # 方式一:最簡單,直接把 research_agent 當作下一個要執行的東西 +# # (假設 research_agent 已 compile 好,且支援 astream) +# # return { +# # "messages": state["messages"] + [initial_msg], +# # # 可以選擇加一個自訂 key 標記 +# # "report_in_progress": True, +# # # next 留空或回 Supervisor,由 conditional edges 決定 +# # } +# +# # 方式二:如果你想更明確(推薦用 Send,未來好擴充) +# return Send( +# "research_sub_agent", # 你要在 graph.add_node("research_sub_agent", research_agent) +# { +# "messages": state["messages"][-12:], +# "configurable": config["configurable"] +# } +# ) # --- 3. Visualizer Agent (视觉专家) --- async def visualizer_node(state: AgentState, config: RunnableConfig): """负责将自然语言转化为绘图 Prompt 并调用绘图工具""" diff --git a/src/server/agent/graph.py b/src/server/agent/graph.py index da86583..e18bcb2 100644 --- a/src/server/agent/graph.py +++ b/src/server/agent/graph.py @@ -2,6 +2,7 @@ from typing import Literal from langchain_core.messages import AIMessage from langchain_core.runnables import RunnableConfig from langchain_qwq import ChatQwen +from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer from langgraph.graph import StateGraph, END, START from pydantic import BaseModel from pymongo import MongoClient @@ -91,6 +92,7 @@ client = MongoClient(MONGO_URI) checkpointer = MongoDBSaver( client=client["furniture_agent_db"], db_name="langgraph", - collection_name="checkpoints" + collection_name="checkpoints", + serde=JsonPlusSerializer(pickle_fallback=True), # ← 關鍵這一行 ) app = workflow.compile(checkpointer=checkpointer) diff --git 
a/src/server/agent/tools/report_generator_tool.py b/src/server/agent/tools/report_generator_tool.py index 11effa0..e9fda29 100644 --- a/src/server/agent/tools/report_generator_tool.py +++ b/src/server/agent/tools/report_generator_tool.py @@ -3,6 +3,7 @@ import json import re from typing import Optional, List, Dict from langchain_qwq import ChatQwen +from langgraph.config import get_stream_writer from pydantic import BaseModel, Field from langchain_core.tools import tool from langchain_core.messages import SystemMessage, HumanMessage @@ -15,6 +16,7 @@ from src.core.config import settings llm = ChatQwen( + enable_thinking=False, model="qwen3.5-flash", temperature=0.2, max_tokens=3_000, @@ -47,7 +49,7 @@ class ReportInput(BaseModel): # ========================= @tool("report_generator", args_schema=ReportInput) -async def report_generator( +def report_generator( report_topic: str, structured_data: List[Dict], language: str = "English" @@ -57,11 +59,11 @@ async def report_generator( directly from structured retrieval results. """ + writer = get_stream_writer() if not structured_data: - return { - "status": "error", - "message": "No structured data provided." - } + error_msg = "Error: No structured data provided." 
+ writer({"type": "report_error", "message": error_msg}) + return error_msg collected_data_str = json.dumps( structured_data, @@ -103,55 +105,40 @@ Input Data: # ========================= # 调用 LLM # ========================= + writer({"type": "report_start", "topic": report_topic, "language": language}) + full_report = "" try: - response = await llm.ainvoke([ + for chunk in llm.stream([ SystemMessage(content=system_prompt), HumanMessage(content=user_prompt) - ]) - - report_content = response.content.strip() - - # 清理 markdown block 包裹 - report_content = ( - report_content - .replace("```markdown", "") - .replace("```", "") - .strip() - ) - + ]): + if chunk.content: # Gemini 返回的 chunk.content + delta = chunk.content + full_report += delta + writer({"type": "report_delta", "delta": delta}) # ← 实时推送给前端 except Exception as e: - return { - "status": "error", - "message": f"LLM generation failed: {str(e)}" - } + error_msg = f"LLM generation failed: {str(e)}" + writer({"type": "report_error", "message": error_msg}) + return error_msg + + report_content = full_report.strip() # ========================= # 保存报告 # ========================= - output_dir = "workspace/reports" os.makedirs(output_dir, exist_ok=True) - safe_topic = re.sub( - r'[\\/*?:"<>|]', - "", - report_topic.replace(" ", "_") - ) - + safe_topic = re.sub(r'[\\/*?:"<>|]', "", report_topic.replace(" ", "_")) filename = f"{output_dir}/{safe_topic}.md" try: with open(filename, "w", encoding="utf-8") as f: f.write(report_content) + writer({"type": "report_complete", "file_path": filename}) except Exception as e: - return { - "status": "error", - "message": f"Failed to save report: {str(e)}" - } + writer({"type": "report_save_warning", "message": str(e)}) - return { - "status": "success", - "file_path": filename, - "message": "Report generated successfully." 
- }	+ # 返回完整内容(作为 tool result),同时正文已通过 delta 流式输出 + return report_content + f"\n\n✅ Report saved to: {filename}" diff --git a/uv.lock b/uv.lock index 4ed0f8f..fc63b27 100644 --- a/uv.lock +++ b/uv.lock @@ -253,6 +253,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/5c/0a/a72d10ed65068e115044937873362e6e32fab1b7dce0046aeb224682c989/asgiref-3.11.1-py3-none-any.whl", hash = "sha256:e8667a091e69529631969fd45dc268fa79b99c92c5fcdda727757e52146ec133", size = 24345, upload-time = "2026-02-03T13:30:13.039Z" }, ] +[[package]] +name = "asyncio" +version = "4.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/71/ea/26c489a11f7ca862d5705db67683a7361ce11c23a7b98fc6c2deaeccede2/asyncio-4.0.0.tar.gz", hash = "sha256:570cd9e50db83bc1629152d4d0b7558d6451bb1bfd5dfc2e935d96fc2f40329b", size = 5371, upload-time = "2025-08-05T02:51:46.605Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/57/64/eff2564783bd650ca25e15938d1c5b459cda997574a510f7de69688cb0b4/asyncio-4.0.0-py3-none-any.whl", hash = "sha256:c1eddb0659231837046809e68103969b2bef8b0400d59cfa6363f6b5ed8cc88b", size = 5555, upload-time = "2025-08-05T02:51:45.767Z" }, +] + [[package]] name = "attrs" version = "25.4.0" @@ -935,6 +944,7 @@ name = "fida" version = "0.1.0" source = { virtual = "." } dependencies = [ + { name = "asyncio" }, { name = "crawl4ai" }, { name = "dashscope" }, { name = "deepagents" }, @@ -972,6 +982,7 @@ dependencies = [ [package.metadata] requires-dist = [ + { name = "asyncio", specifier = ">=4.0.0" }, { name = "crawl4ai", specifier = ">=0.8.0" }, { name = "dashscope", specifier = ">=1.25.13" }, { name = "deepagents", specifier = ">=0.4.3" },