From 5951205ac9a7d9b010f4f33a5a65d24818edc47d Mon Sep 17 00:00:00 2001 From: zcr Date: Fri, 6 Mar 2026 11:08:28 +0800 Subject: [PATCH] =?UTF-8?q?=E5=93=8D=E5=BA=94=E6=B6=88=E6=81=AF=E4=B8=AD?= =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E2=80=9C=E4=BA=8B=E4=BB=B6=E6=9B=B4=E6=96=B0?= =?UTF-8?q?=E2=80=9D=E6=95=B0=E6=8D=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- main.py | 4 ++-- src/routers/chat.py | 34 +++++++++++++++++++++++++++++++--- 2 files changed, 33 insertions(+), 5 deletions(-) diff --git a/main.py b/main.py index 0b87742..d7104e4 100644 --- a/main.py +++ b/main.py @@ -1,4 +1,4 @@ -import logging +import logging.config import uvicorn from fastapi import FastAPI @@ -33,4 +33,4 @@ async def root(): if __name__ == "__main__": - uvicorn.run("main:app_server", host="0.0.0.0", port=7777, reload=True) + uvicorn.run("main:app_server", host="0.0.0.0", port=7777, reload=False) diff --git a/src/routers/chat.py b/src/routers/chat.py index b3a6854..c870c27 100644 --- a/src/routers/chat.py +++ b/src/routers/chat.py @@ -195,10 +195,37 @@ async def chat_stream(request: ChatRequest): latest_state = await app.aget_state(current_config) configurable = latest_state.config.get("configurable", {}) current_cp_id = configurable.get("checkpoint_id", "") - if len(event) == 3: namespace, channel, payload = event - if event[1] == "custom": + # 路由更新 + if event[1] == "updates": + if isinstance(payload, dict): + for node_name, update_content in payload.items(): + + # 将 LangChain Message 转为可 JSON 序列化 + if isinstance(update_content, dict) and "messages" in update_content: + msgs = [] + for m in update_content["messages"]: + msgs.append({ + "type": m.__class__.__name__, + "content": getattr(m, "content", ""), + "tool_calls": getattr(m, "tool_calls", []), + }) + update_content = { + **update_content, + "messages": msgs + } + + yield f"data: {json.dumps({ + "node": node_name, + "type": "updates", + "content": update_content, + "is_delta": False, + "checkpoint_id": current_cp_id, + }, ensure_ascii=False)}\n\n" + + # 自定义事件 + elif event[1] == "custom": if isinstance(payload, dict) and payload.get("type") in ("report_delta", "report_start", "report_error", "report_save_warning", "report_complete"): delta = payload.get("delta", "").strip() if delta: @@ -209,7 +236,8 @@ async def chat_stream(request: ChatRequest): 'is_delta': True, 'checkpoint_id': current_cp_id, }, ensure_ascii=False)}\n\n" - if event[1] == "messages": + # 基础消息 + elif event[1] == "messages": if namespace: node_name = namespace[-1] if isinstance(namespace, tuple) else namespace if ':' in node_name: