新增对话接口
This commit is contained in:
@@ -6,7 +6,7 @@ from typing import AsyncGenerator
|
||||
from fastapi import APIRouter
|
||||
from fastapi.responses import StreamingResponse
|
||||
from src.schemas.chat import ChatRequest, HistoryResponse, HistoryItem
|
||||
from src.server.agent.graph import app
|
||||
from src.server.agent.graph import app # 导入已经 compile 好的 graph
|
||||
from langchain_core.messages import HumanMessage, SystemMessage, AIMessageChunk, ToolMessage, AIMessage
|
||||
|
||||
router = APIRouter(prefix="/chat", tags=["Furniture Design Chat"])
|
||||
@@ -169,8 +169,10 @@ async def chat_stream(request: ChatRequest):
|
||||
await app.aupdate_state(current_config, combined_values)
|
||||
|
||||
async def event_generator() -> AsyncGenerator[str, None]:
|
||||
# 初始事件
|
||||
yield f"data: {json.dumps({'thread_id': target_thread_id, 'is_branch': is_branching, 'status': 'start'}, ensure_ascii=False)}\n\n"
|
||||
|
||||
# 構造輸入(保持不變)
|
||||
new_messages = initial_messages[:] if not source_thread_id else []
|
||||
new_messages.append(HumanMessage(content=request.message))
|
||||
|
||||
@@ -180,143 +182,124 @@ async def chat_stream(request: ChatRequest):
|
||||
"use_report": request.use_report,
|
||||
}
|
||||
|
||||
interrupted = False
|
||||
|
||||
current_cp_id = None
|
||||
# try:
|
||||
# ─── 重點改這裡 ───────────────────────────────────────
|
||||
async for event in app.astream(
|
||||
input_data,
|
||||
config=current_config,
|
||||
stream_mode=["updates", "messages", "custom"], # 确保包含 "values"
|
||||
stream_mode=["custom", "updates", "messages"], # 推薦組合
|
||||
subgraphs=True
|
||||
# 不再需要,行為已包含
|
||||
):
|
||||
if interrupted:
|
||||
break
|
||||
logger.info(event)
|
||||
# 取得 checkpoint_id(可選,視前端是否真的需要)
|
||||
latest_state = await app.aget_state(current_config)
|
||||
configurable = latest_state.config.get("configurable", {})
|
||||
current_cp_id = configurable.get("checkpoint_id", "")
|
||||
if len(event) == 3:
|
||||
namespace, channel, payload = event
|
||||
# 路由更新
|
||||
if event[1] == "updates":
|
||||
namespace, _, payload = event
|
||||
|
||||
logger.info(f"Received event: {event}")
|
||||
if isinstance(payload, dict):
|
||||
for update_node, update_content in payload.items():
|
||||
|
||||
if not isinstance(event, tuple) or len(event) != 3:
|
||||
continue
|
||||
# 处理 reducer(Overwrite / Append)
|
||||
if isinstance(update_content, dict):
|
||||
for k, v in update_content.items():
|
||||
if hasattr(v, "value"): # Overwrite(...)
|
||||
update_content[k] = v.value
|
||||
|
||||
run_id, channel, payload = event
|
||||
if isinstance(update_content, dict) and "messages" in update_content:
|
||||
msgs = []
|
||||
for m in update_content["messages"]:
|
||||
msgs.append({
|
||||
"type": m.__class__.__name__,
|
||||
"content": getattr(m, "content", ""),
|
||||
"name": getattr(m, "name", None),
|
||||
"tool_calls": getattr(m, "tool_calls", None),
|
||||
})
|
||||
update_content["messages"] = msgs
|
||||
|
||||
# ─────────────── 检测 interrupt ───────────────
|
||||
# __interrupt__ 最常出现在 "values" 或 "updates" channel 的 payload 中
|
||||
if channel in ("values", "updates") and isinstance(payload, dict) and "__interrupt__" in payload:
|
||||
interrupt_data = payload["__interrupt__"][0].value['__interrupt__']
|
||||
interrupted = True
|
||||
yield f"data: {json.dumps({
|
||||
"type": "interrupt",
|
||||
"node": interrupt_data.get("node", "Persona"),
|
||||
"question": interrupt_data.get('question', "异常|||||||||||||"),
|
||||
"current_persona": interrupt_data.get("current_persona_snapshot", {}),
|
||||
"status": "waiting_for_input"
|
||||
}, ensure_ascii=False)}\n\n"
|
||||
yield f"data: {json.dumps({
|
||||
"node": "Supervisor",
|
||||
"type": "updates",
|
||||
"content": update_content,
|
||||
"is_delta": False,
|
||||
"checkpoint_id": current_cp_id,
|
||||
}, ensure_ascii=False)}\n\n"
|
||||
|
||||
# 立即停止后续发送,等待用户回复后 resume
|
||||
break
|
||||
# 自定义事件
|
||||
elif event[1] == "custom":
|
||||
if isinstance(payload, dict) and payload.get("type") in ("report_delta", "report_start", "report_error", "report_save_warning", "report_complete"):
|
||||
delta = payload.get("delta", "").strip()
|
||||
if delta:
|
||||
yield f"data: {json.dumps({
|
||||
'node': 'Researcher',
|
||||
'type': 'report_delta',
|
||||
'content': delta,
|
||||
'is_delta': True,
|
||||
'checkpoint_id': current_cp_id,
|
||||
}, ensure_ascii=False)}\n\n"
|
||||
# 基础消息
|
||||
elif event[1] == "messages":
|
||||
if namespace:
|
||||
node_name = namespace[-1] if isinstance(namespace, tuple) else namespace
|
||||
if ':' in node_name:
|
||||
node_name = node_name.split(':')[0]
|
||||
else:
|
||||
node_name = "Main"
|
||||
message, metadata = payload
|
||||
is_not_research = node_name != 'Researcher'
|
||||
node_name = metadata.get("langgraph_node", node_name)
|
||||
# 3. 处理不同类型的 message
|
||||
payload_out = {
|
||||
"node": node_name,
|
||||
"checkpoint_id": current_cp_id, # 你之前已经获取了
|
||||
"is_delta": False,
|
||||
"content": "",
|
||||
"suggestions": [],
|
||||
"type": "unknown"
|
||||
}
|
||||
|
||||
# ─────────────── 正常消息处理 ───────────────
|
||||
if channel == "messages":
|
||||
if run_id:
|
||||
node_name = run_id[-1] if isinstance(run_id, tuple) else run_id
|
||||
if ':' in node_name:
|
||||
node_name = node_name.split(':')[0]
|
||||
else:
|
||||
node_name = "Main"
|
||||
|
||||
message, metadata = payload
|
||||
node_name = metadata.get("langgraph_node", node_name)
|
||||
|
||||
payload_out = {
|
||||
"node": node_name,
|
||||
"checkpoint_id": current_cp_id or "unknown",
|
||||
"is_delta": False,
|
||||
"content": "",
|
||||
"suggestions": [],
|
||||
"type": "unknown"
|
||||
}
|
||||
|
||||
if isinstance(message, AIMessageChunk):
|
||||
if node_name != 'Researcher' and message.tool_call_chunks:
|
||||
payload_out.update({
|
||||
"type": "delta",
|
||||
"is_delta": True,
|
||||
"content": message.content,
|
||||
"tool_call_chunk": message.tool_call_chunks[0] if message.tool_call_chunks else None
|
||||
})
|
||||
yield f"data: {json.dumps(payload_out, ensure_ascii=False)}\n\n"
|
||||
|
||||
elif isinstance(message, ToolMessage):
|
||||
try:
|
||||
tools_data = json.loads(message.content)
|
||||
if isinstance(message, AIMessageChunk):
|
||||
# 节点不是research 并且 tool_call_chunks不为空的情况下,避免research的report工具使用custom发出的消息和message的消息重复了
|
||||
if is_not_research and node_name != 'Researcher' and message.tool_call_chunks:
|
||||
payload_out.update({
|
||||
"type": "delta",
|
||||
"is_delta": True,
|
||||
"content": message.content,
|
||||
# 如果有 tool call chunk,也可以在这里处理
|
||||
"tool_call_chunk": message.tool_call_chunks[0] if message.tool_call_chunks else None
|
||||
})
|
||||
yield f"data: {json.dumps(payload_out, ensure_ascii=False)}\n\n"
|
||||
elif isinstance(message, ToolMessage):
|
||||
# 工具执行结果(完整的一次性输出)
|
||||
payload_out.update({
|
||||
"type": "tool_result",
|
||||
"is_delta": False,
|
||||
"content": tools_data.get("data", ""),
|
||||
"tool_name": tools_data.get("tool_name", ""),
|
||||
"content": message.content,
|
||||
"tool_name": message.name,
|
||||
"tool_call_id": message.tool_call_id
|
||||
})
|
||||
yield f"data: {json.dumps(payload_out, ensure_ascii=False)}\n\n"
|
||||
except:
|
||||
pass
|
||||
|
||||
elif isinstance(message, AIMessage):
|
||||
payload_out.update({
|
||||
"type": "complete_message",
|
||||
"is_delta": False,
|
||||
"content": message.content
|
||||
})
|
||||
yield f"data: {json.dumps(payload_out, ensure_ascii=False)}\n\n"
|
||||
|
||||
elif channel == "updates":
|
||||
# 处理 updates(非 interrupt 的部分)
|
||||
if isinstance(payload, dict):
|
||||
for update_node, update_content in payload.items():
|
||||
# 处理 reducer 包裹的值
|
||||
if isinstance(update_content, dict):
|
||||
for k, v in update_content.items():
|
||||
if hasattr(v, "value"):
|
||||
update_content[k] = v.value
|
||||
|
||||
# 序列化 messages
|
||||
if isinstance(update_content, dict) and "messages" in update_content:
|
||||
msgs = []
|
||||
for m in update_content["messages"]:
|
||||
msgs.append({
|
||||
"type": m.__class__.__name__,
|
||||
"content": getattr(m, "content", ""),
|
||||
"name": getattr(m, "name", None),
|
||||
"tool_calls": getattr(m, "tool_calls", None),
|
||||
})
|
||||
update_content["messages"] = msgs
|
||||
|
||||
yield f"data: {json.dumps({
|
||||
"node": "Supervisor", # 或 update_node
|
||||
"type": "updates",
|
||||
"content": update_content,
|
||||
elif isinstance(message, AIMessage):
|
||||
# 完整 AIMessage(不常见在 messages 模式下,但以防万一)
|
||||
payload_out.update({
|
||||
"type": "complete_message",
|
||||
"is_delta": False,
|
||||
"checkpoint_id": current_cp_id,
|
||||
}, ensure_ascii=False)}\n\n"
|
||||
"content": message.content
|
||||
})
|
||||
yield f"data: {json.dumps(payload_out, ensure_ascii=False)}\n\n"
|
||||
|
||||
elif channel == "custom":
|
||||
if isinstance(payload, dict) and payload.get("type") in ("report_delta", "report_start", ...):
|
||||
delta = payload.get("delta", "").strip()
|
||||
if delta:
|
||||
yield f"data: {json.dumps({
|
||||
'node': 'Researcher',
|
||||
'type': 'report_delta',
|
||||
'content': delta,
|
||||
'is_delta': True,
|
||||
'checkpoint_id': current_cp_id,
|
||||
}, ensure_ascii=False)}\n\n"
|
||||
# except Exception as e:
|
||||
# print("error")
|
||||
else:
|
||||
# 其他未知类型,记录日志
|
||||
print(f"未知消息类型: {type(message)}", message)
|
||||
continue
|
||||
|
||||
# 结束标记
|
||||
if interrupted:
|
||||
yield f"data: {json.dumps({'status': 'interrupted', 'reason': 'waiting_for_user_input'})}\n\n"
|
||||
else:
|
||||
yield f"data: {json.dumps({'status': 'end'}, ensure_ascii=False)}\n\n"
|
||||
# 流結束
|
||||
yield f"data: {json.dumps({'status': 'end'}, ensure_ascii=False)}\n\n"
|
||||
|
||||
return StreamingResponse(event_generator(), media_type="text/event-stream")
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import logging
|
||||
import random
|
||||
import uuid
|
||||
import json
|
||||
from typing import AsyncGenerator
|
||||
@@ -9,6 +10,7 @@ from src.schemas.chat import ChatRequest, HistoryResponse, HistoryItem
|
||||
from langchain_core.messages import HumanMessage, SystemMessage, AIMessageChunk, ToolMessage, AIMessage, ToolMessageChunk
|
||||
|
||||
from src.server.deep_agent.agents.main_agent import build_main_agent
|
||||
from src.server.deep_agent.tools.extract_suggested_questions import format_messages, generate_suggested_questions
|
||||
|
||||
router = APIRouter(prefix="/chat", tags=["Furniture Design Chat"])
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -80,7 +82,6 @@ async def chat_stream(request: ChatRequest):
|
||||
"checkpoint_id": "快照ID",
|
||||
"is_delta": "boolean",
|
||||
"type": "消息类型",
|
||||
"suggestions": "建议列表(可选)",
|
||||
"tool_name": "工具名称(可选)",
|
||||
"tool_call_chunk": "工具调用片段(可选)",
|
||||
"tool_call_id": "工具调用ID(可选)"
|
||||
@@ -136,8 +137,6 @@ async def chat_stream(request: ChatRequest):
|
||||
# 2. 配置參數
|
||||
temp = request.config_params.temperature if request.config_params else 0.7
|
||||
|
||||
need_suggestion = request.need_suggestion,
|
||||
|
||||
current_config = {
|
||||
"recursion_limit": 120,
|
||||
"configurable": {
|
||||
@@ -184,28 +183,55 @@ async def chat_stream(request: ChatRequest):
|
||||
input_data = {
|
||||
"messages": new_messages,
|
||||
}
|
||||
current_cp_id = None
|
||||
async for stream in main_agent.astream(
|
||||
input_data,
|
||||
config=current_config,
|
||||
stream_mode=["updates", "messages", "custom"], # 确保包含 "values"
|
||||
stream_mode=["updates", "messages", "custom"],
|
||||
subgraphs=True
|
||||
):
|
||||
# logger.info(f"Received event: {event}")
|
||||
_, mode, chunks = stream
|
||||
if mode == "updates":
|
||||
# TODO 补充
|
||||
if mode == "updates": # 只做记录 不做事件返回
|
||||
print(f"[updates] {chunks}")
|
||||
update_model_messages = chunks.get("model", None)
|
||||
update_tools_messages = chunks.get("tools", None)
|
||||
payload_out = {
|
||||
"node": "",
|
||||
# "checkpoint_id": current_cp_id or "unknown", TODO 替换为checkpoint_idns
|
||||
"is_delta": False,
|
||||
"content": "",
|
||||
"type": "updates"
|
||||
}
|
||||
|
||||
if update_model_messages:
|
||||
model_messages = update_model_messages.get("messages", [])
|
||||
for model_token in model_messages:
|
||||
if isinstance(model_token, AIMessage):
|
||||
model_content_blocks = model_token.content_blocks[0]
|
||||
model_name = model_token.name
|
||||
payload_out.update({
|
||||
"node": model_name if model_name else "main",
|
||||
"tool_calls": model_token.tool_calls
|
||||
})
|
||||
logger.info(f"[updates] {model_name} -- {model_content_blocks} -- {model_token.tool_calls}")
|
||||
yield f"data: {json.dumps(payload_out, ensure_ascii=False)}\n\n"
|
||||
elif update_tools_messages:
|
||||
tools_messages = update_tools_messages.get("messages", [])
|
||||
for tools_token in tools_messages:
|
||||
if isinstance(tools_token, ToolMessage):
|
||||
tool_content_blocks = tools_token.content_blocks[0]
|
||||
tool_name = tools_token.name
|
||||
logger.info(f"[updates] {tool_name} -- {tool_content_blocks}")
|
||||
else:
|
||||
logger.info(f"[updates] -- {chunks}")
|
||||
|
||||
elif mode == "messages":
|
||||
token, metadata = chunks
|
||||
subagent_name = metadata.get('lc_agent_name', None)
|
||||
subagent_name = metadata.get('lc_agent_name', "main")
|
||||
payload_out = {
|
||||
"node": subagent_name,
|
||||
# "checkpoint_id": current_cp_id or "unknown", TODO 替换为checkpoint_idns
|
||||
"is_delta": False,
|
||||
"content": "",
|
||||
"suggestions": [],
|
||||
"type": ""
|
||||
}
|
||||
|
||||
@@ -213,29 +239,36 @@ async def chat_stream(request: ChatRequest):
|
||||
reasoning = [b for b in token.content_blocks if b["type"] == "reasoning"]
|
||||
text = [b for b in token.content_blocks if b["type"] == "text"]
|
||||
if reasoning:
|
||||
payload_out.update({
|
||||
"type": "reasoning",
|
||||
"is_delta": True,
|
||||
"content": text,
|
||||
"tool_call_chunk": token.tool_call_chunks[0] if token.tool_call_chunks else None
|
||||
})
|
||||
if len(reasoning) == 1:
|
||||
payload_out.update({
|
||||
"type": "reasoning",
|
||||
"is_delta": True,
|
||||
"content": reasoning[0].get("reasoning", ""),
|
||||
# "tool_call_chunk": token.tool_call_chunks[0] if token.tool_call_chunks else None
|
||||
})
|
||||
else:
|
||||
print(f"[reasoning] {reasoning}*************************************************************************************")
|
||||
elif text:
|
||||
payload_out.update({
|
||||
"type": "text",
|
||||
"is_delta": True,
|
||||
"content": text,
|
||||
"tool_call_chunk": token.tool_call_chunks[0] if token.tool_call_chunks else None
|
||||
})
|
||||
if len(text) == 1:
|
||||
payload_out.update({
|
||||
"type": "text",
|
||||
"is_delta": True,
|
||||
"content": text[0].get("text", ""),
|
||||
# "tool_call_chunk": token.tool_call_chunks[0] if token.tool_call_chunks else None
|
||||
})
|
||||
else:
|
||||
print(f"[text] {text}*************************************************************************************")
|
||||
else:
|
||||
payload_out.update({
|
||||
"type": "tool_call",
|
||||
"is_delta": True,
|
||||
"tool_call_chunk": token.tool_call_chunks[0] if token.tool_call_chunks else None
|
||||
# "tool_call_chunk": token.tool_call_chunks[0] if token.tool_call_chunks else None
|
||||
})
|
||||
yield f"data: {json.dumps(payload_out, ensure_ascii=False)}\n\n"
|
||||
elif isinstance(token, ToolMessageChunk): # 工具返回
|
||||
text = [b for b in token.content_blocks if b["type"] == "text"]
|
||||
payload_out.update({
|
||||
"type": "tool_text",
|
||||
"type": "tool_result",
|
||||
"is_delta": False,
|
||||
"content": text,
|
||||
"tool_name": token.name,
|
||||
@@ -244,7 +277,7 @@ async def chat_stream(request: ChatRequest):
|
||||
elif isinstance(token, ToolMessage): # 工具返回
|
||||
text = [b for b in token.content_blocks if b["type"] == "text"]
|
||||
payload_out.update({
|
||||
"type": "tool_text",
|
||||
"type": "tool_result",
|
||||
"is_delta": False,
|
||||
"content": text,
|
||||
"tool_name": token.name,
|
||||
@@ -254,14 +287,11 @@ async def chat_stream(request: ChatRequest):
|
||||
continue
|
||||
|
||||
elif mode == "custom":
|
||||
token, metadata = chunks
|
||||
subagent_name = metadata.get('lc_agent_name', None)
|
||||
payload_out = {
|
||||
"node": subagent_name,
|
||||
"node": "research-agent",
|
||||
# "checkpoint_id": current_cp_id or "unknown", TODO 替换为checkpoint_idns
|
||||
"is_delta": False,
|
||||
"content": "",
|
||||
"suggestions": [],
|
||||
"type": ""
|
||||
}
|
||||
delta = chunks.get("delta", "").strip()
|
||||
@@ -273,40 +303,10 @@ async def chat_stream(request: ChatRequest):
|
||||
})
|
||||
yield f"data: {json.dumps(payload_out, ensure_ascii=False)}\n\n"
|
||||
|
||||
|
||||
# elif channel == "updates":
|
||||
# # 处理 updates(非 interrupt 的部分)
|
||||
# if isinstance(payload, dict):
|
||||
# for update_node, update_content in payload.items():
|
||||
# # 处理 reducer 包裹的值
|
||||
# if isinstance(update_content, dict):
|
||||
# for k, v in update_content.items():
|
||||
# if hasattr(v, "value"):
|
||||
# update_content[k] = v.value
|
||||
#
|
||||
# # 序列化 messages
|
||||
# if isinstance(update_content, dict) and "messages" in update_content:
|
||||
# msgs = []
|
||||
# for m in update_content["messages"]:
|
||||
# msgs.append({
|
||||
# "type": m.__class__.__name__,
|
||||
# "content": getattr(m, "content", ""),
|
||||
# "name": getattr(m, "name", None),
|
||||
# "tool_calls": getattr(m, "tool_calls", None),
|
||||
# })
|
||||
# update_content["messages"] = msgs
|
||||
#
|
||||
# yield f"data: {json.dumps({
|
||||
# "node": "Supervisor", # 或 update_node
|
||||
# "type": "updates",
|
||||
# "content": update_content,
|
||||
# "is_delta": False,
|
||||
# "checkpoint_id": current_cp_id,
|
||||
# }, ensure_ascii=False)}\n\n"
|
||||
#
|
||||
# elif channel == "custom":
|
||||
else:
|
||||
yield f"data: {json.dumps({'status': 'end'}, ensure_ascii=False)}\n\n"
|
||||
if request.need_suggestion > 0 and random.random() < request.need_suggestion:
|
||||
suggested_questions = generate_suggested_questions(main_agent, target_thread_id)
|
||||
yield f"data: {json.dumps({'suggested_questions': suggested_questions}, ensure_ascii=False)}\n\n"
|
||||
yield f"data: {json.dumps({'status': 'end'}, ensure_ascii=False)}\n\n"
|
||||
|
||||
return StreamingResponse(event_generator(), media_type="text/event-stream")
|
||||
|
||||
@@ -357,6 +357,7 @@ async def get_chat_history(thread_id: str):
|
||||
"""
|
||||
config = {"configurable": {"thread_id": thread_id}, }
|
||||
history_data = []
|
||||
main_agent = build_main_agent(False)
|
||||
async for state in main_agent.aget_state_history(config):
|
||||
msg_content = "Initial"
|
||||
if state.values and "messages" in state.values:
|
||||
|
||||
Reference in New Issue
Block a user