新增对话接口

This commit is contained in:
zcr
2026-03-12 13:13:52 +08:00
parent 7042d428fa
commit a6393df0e3
35 changed files with 843 additions and 1163 deletions

View File

@@ -1,4 +1,5 @@
import logging
import random
import uuid
import json
from typing import AsyncGenerator
@@ -9,6 +10,7 @@ from src.schemas.chat import ChatRequest, HistoryResponse, HistoryItem
from langchain_core.messages import HumanMessage, SystemMessage, AIMessageChunk, ToolMessage, AIMessage, ToolMessageChunk
from src.server.deep_agent.agents.main_agent import build_main_agent
from src.server.deep_agent.tools.extract_suggested_questions import format_messages, generate_suggested_questions
router = APIRouter(prefix="/chat", tags=["Furniture Design Chat"])
logger = logging.getLogger(__name__)
@@ -80,7 +82,6 @@ async def chat_stream(request: ChatRequest):
"checkpoint_id": "快照ID",
"is_delta": "boolean",
"type": "消息类型",
"suggestions": "建议列表(可选)",
"tool_name": "工具名称(可选)",
"tool_call_chunk": "工具调用片段(可选)",
"tool_call_id": "工具调用ID可选"
@@ -136,8 +137,6 @@ async def chat_stream(request: ChatRequest):
# 2. 配置參數
temp = request.config_params.temperature if request.config_params else 0.7
need_suggestion = request.need_suggestion,
current_config = {
"recursion_limit": 120,
"configurable": {
@@ -184,28 +183,55 @@ async def chat_stream(request: ChatRequest):
input_data = {
"messages": new_messages,
}
current_cp_id = None
async for stream in main_agent.astream(
input_data,
config=current_config,
stream_mode=["updates", "messages", "custom"], # 确保包含 "values"
stream_mode=["updates", "messages", "custom"],
subgraphs=True
):
# logger.info(f"Received event: {event}")
_, mode, chunks = stream
if mode == "updates":
# TODO 补充
if mode == "updates": # 只做记录 不做事件返回
print(f"[updates] {chunks}")
update_model_messages = chunks.get("model", None)
update_tools_messages = chunks.get("tools", None)
payload_out = {
"node": "",
# "checkpoint_id": current_cp_id or "unknown", TODO 替换为checkpoint_idns
"is_delta": False,
"content": "",
"type": "updates"
}
if update_model_messages:
model_messages = update_model_messages.get("messages", [])
for model_token in model_messages:
if isinstance(model_token, AIMessage):
model_content_blocks = model_token.content_blocks[0]
model_name = model_token.name
payload_out.update({
"node": model_name if model_name else "main",
"tool_calls": model_token.tool_calls
})
logger.info(f"[updates] {model_name} -- {model_content_blocks} -- {model_token.tool_calls}")
yield f"data: {json.dumps(payload_out, ensure_ascii=False)}\n\n"
elif update_tools_messages:
tools_messages = update_tools_messages.get("messages", [])
for tools_token in tools_messages:
if isinstance(tools_token, ToolMessage):
tool_content_blocks = tools_token.content_blocks[0]
tool_name = tools_token.name
logger.info(f"[updates] {tool_name} -- {tool_content_blocks}")
else:
logger.info(f"[updates] -- {chunks}")
elif mode == "messages":
token, metadata = chunks
subagent_name = metadata.get('lc_agent_name', None)
subagent_name = metadata.get('lc_agent_name', "main")
payload_out = {
"node": subagent_name,
# "checkpoint_id": current_cp_id or "unknown", TODO 替换为checkpoint_idns
"is_delta": False,
"content": "",
"suggestions": [],
"type": ""
}
@@ -213,29 +239,36 @@ async def chat_stream(request: ChatRequest):
reasoning = [b for b in token.content_blocks if b["type"] == "reasoning"]
text = [b for b in token.content_blocks if b["type"] == "text"]
if reasoning:
payload_out.update({
"type": "reasoning",
"is_delta": True,
"content": text,
"tool_call_chunk": token.tool_call_chunks[0] if token.tool_call_chunks else None
})
if len(reasoning) == 1:
payload_out.update({
"type": "reasoning",
"is_delta": True,
"content": reasoning[0].get("reasoning", ""),
# "tool_call_chunk": token.tool_call_chunks[0] if token.tool_call_chunks else None
})
else:
print(f"[reasoning] {reasoning}*************************************************************************************")
elif text:
payload_out.update({
"type": "text",
"is_delta": True,
"content": text,
"tool_call_chunk": token.tool_call_chunks[0] if token.tool_call_chunks else None
})
if len(text) == 1:
payload_out.update({
"type": "text",
"is_delta": True,
"content": text[0].get("text", ""),
# "tool_call_chunk": token.tool_call_chunks[0] if token.tool_call_chunks else None
})
else:
print(f"[text] {text}*************************************************************************************")
else:
payload_out.update({
"type": "tool_call",
"is_delta": True,
"tool_call_chunk": token.tool_call_chunks[0] if token.tool_call_chunks else None
# "tool_call_chunk": token.tool_call_chunks[0] if token.tool_call_chunks else None
})
yield f"data: {json.dumps(payload_out, ensure_ascii=False)}\n\n"
elif isinstance(token, ToolMessageChunk): # 工具返回
text = [b for b in token.content_blocks if b["type"] == "text"]
payload_out.update({
"type": "tool_text",
"type": "tool_result",
"is_delta": False,
"content": text,
"tool_name": token.name,
@@ -244,7 +277,7 @@ async def chat_stream(request: ChatRequest):
elif isinstance(token, ToolMessage): # 工具返回
text = [b for b in token.content_blocks if b["type"] == "text"]
payload_out.update({
"type": "tool_text",
"type": "tool_result",
"is_delta": False,
"content": text,
"tool_name": token.name,
@@ -254,14 +287,11 @@ async def chat_stream(request: ChatRequest):
continue
elif mode == "custom":
token, metadata = chunks
subagent_name = metadata.get('lc_agent_name', None)
payload_out = {
"node": subagent_name,
"node": "research-agent",
# "checkpoint_id": current_cp_id or "unknown", TODO 替换为checkpoint_idns
"is_delta": False,
"content": "",
"suggestions": [],
"type": ""
}
delta = chunks.get("delta", "").strip()
@@ -273,40 +303,10 @@ async def chat_stream(request: ChatRequest):
})
yield f"data: {json.dumps(payload_out, ensure_ascii=False)}\n\n"
# elif channel == "updates":
# # 处理 updates非 interrupt 的部分)
# if isinstance(payload, dict):
# for update_node, update_content in payload.items():
# # 处理 reducer 包裹的值
# if isinstance(update_content, dict):
# for k, v in update_content.items():
# if hasattr(v, "value"):
# update_content[k] = v.value
#
# # 序列化 messages
# if isinstance(update_content, dict) and "messages" in update_content:
# msgs = []
# for m in update_content["messages"]:
# msgs.append({
# "type": m.__class__.__name__,
# "content": getattr(m, "content", ""),
# "name": getattr(m, "name", None),
# "tool_calls": getattr(m, "tool_calls", None),
# })
# update_content["messages"] = msgs
#
# yield f"data: {json.dumps({
# "node": "Supervisor", # 或 update_node
# "type": "updates",
# "content": update_content,
# "is_delta": False,
# "checkpoint_id": current_cp_id,
# }, ensure_ascii=False)}\n\n"
#
# elif channel == "custom":
else:
yield f"data: {json.dumps({'status': 'end'}, ensure_ascii=False)}\n\n"
if request.need_suggestion > 0 and random.random() < request.need_suggestion:
suggested_questions = generate_suggested_questions(main_agent, target_thread_id)
yield f"data: {json.dumps({'suggested_questions': suggested_questions}, ensure_ascii=False)}\n\n"
yield f"data: {json.dumps({'status': 'end'}, ensure_ascii=False)}\n\n"
return StreamingResponse(event_generator(), media_type="text/event-stream")
@@ -357,6 +357,7 @@ async def get_chat_history(thread_id: str):
"""
config = {"configurable": {"thread_id": thread_id}, }
history_data = []
main_agent = build_main_agent(False)
async for state in main_agent.aget_state_history(config):
msg_content = "Initial"
if state.values and "messages" in state.values: