新增对话接口
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import logging
|
||||
import random
|
||||
import uuid
|
||||
import json
|
||||
from typing import AsyncGenerator
|
||||
@@ -9,6 +10,7 @@ from src.schemas.chat import ChatRequest, HistoryResponse, HistoryItem
|
||||
from langchain_core.messages import HumanMessage, SystemMessage, AIMessageChunk, ToolMessage, AIMessage, ToolMessageChunk
|
||||
|
||||
from src.server.deep_agent.agents.main_agent import build_main_agent
|
||||
from src.server.deep_agent.tools.extract_suggested_questions import format_messages, generate_suggested_questions
|
||||
|
||||
router = APIRouter(prefix="/chat", tags=["Furniture Design Chat"])
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -80,7 +82,6 @@ async def chat_stream(request: ChatRequest):
|
||||
"checkpoint_id": "快照ID",
|
||||
"is_delta": "boolean",
|
||||
"type": "消息类型",
|
||||
"suggestions": "建议列表(可选)",
|
||||
"tool_name": "工具名称(可选)",
|
||||
"tool_call_chunk": "工具调用片段(可选)",
|
||||
"tool_call_id": "工具调用ID(可选)"
|
||||
@@ -136,8 +137,6 @@ async def chat_stream(request: ChatRequest):
|
||||
# 2. 配置參數
|
||||
temp = request.config_params.temperature if request.config_params else 0.7
|
||||
|
||||
need_suggestion = request.need_suggestion,
|
||||
|
||||
current_config = {
|
||||
"recursion_limit": 120,
|
||||
"configurable": {
|
||||
@@ -184,28 +183,55 @@ async def chat_stream(request: ChatRequest):
|
||||
input_data = {
|
||||
"messages": new_messages,
|
||||
}
|
||||
current_cp_id = None
|
||||
async for stream in main_agent.astream(
|
||||
input_data,
|
||||
config=current_config,
|
||||
stream_mode=["updates", "messages", "custom"], # 确保包含 "values"
|
||||
stream_mode=["updates", "messages", "custom"],
|
||||
subgraphs=True
|
||||
):
|
||||
# logger.info(f"Received event: {event}")
|
||||
_, mode, chunks = stream
|
||||
if mode == "updates":
|
||||
# TODO 补充
|
||||
if mode == "updates": # 只做记录 不做事件返回
|
||||
print(f"[updates] {chunks}")
|
||||
update_model_messages = chunks.get("model", None)
|
||||
update_tools_messages = chunks.get("tools", None)
|
||||
payload_out = {
|
||||
"node": "",
|
||||
# "checkpoint_id": current_cp_id or "unknown", TODO 替换为checkpoint_idns
|
||||
"is_delta": False,
|
||||
"content": "",
|
||||
"type": "updates"
|
||||
}
|
||||
|
||||
if update_model_messages:
|
||||
model_messages = update_model_messages.get("messages", [])
|
||||
for model_token in model_messages:
|
||||
if isinstance(model_token, AIMessage):
|
||||
model_content_blocks = model_token.content_blocks[0]
|
||||
model_name = model_token.name
|
||||
payload_out.update({
|
||||
"node": model_name if model_name else "main",
|
||||
"tool_calls": model_token.tool_calls
|
||||
})
|
||||
logger.info(f"[updates] {model_name} -- {model_content_blocks} -- {model_token.tool_calls}")
|
||||
yield f"data: {json.dumps(payload_out, ensure_ascii=False)}\n\n"
|
||||
elif update_tools_messages:
|
||||
tools_messages = update_tools_messages.get("messages", [])
|
||||
for tools_token in tools_messages:
|
||||
if isinstance(tools_token, ToolMessage):
|
||||
tool_content_blocks = tools_token.content_blocks[0]
|
||||
tool_name = tools_token.name
|
||||
logger.info(f"[updates] {tool_name} -- {tool_content_blocks}")
|
||||
else:
|
||||
logger.info(f"[updates] -- {chunks}")
|
||||
|
||||
elif mode == "messages":
|
||||
token, metadata = chunks
|
||||
subagent_name = metadata.get('lc_agent_name', None)
|
||||
subagent_name = metadata.get('lc_agent_name', "main")
|
||||
payload_out = {
|
||||
"node": subagent_name,
|
||||
# "checkpoint_id": current_cp_id or "unknown", TODO 替换为checkpoint_idns
|
||||
"is_delta": False,
|
||||
"content": "",
|
||||
"suggestions": [],
|
||||
"type": ""
|
||||
}
|
||||
|
||||
@@ -213,29 +239,36 @@ async def chat_stream(request: ChatRequest):
|
||||
reasoning = [b for b in token.content_blocks if b["type"] == "reasoning"]
|
||||
text = [b for b in token.content_blocks if b["type"] == "text"]
|
||||
if reasoning:
|
||||
payload_out.update({
|
||||
"type": "reasoning",
|
||||
"is_delta": True,
|
||||
"content": text,
|
||||
"tool_call_chunk": token.tool_call_chunks[0] if token.tool_call_chunks else None
|
||||
})
|
||||
if len(reasoning) == 1:
|
||||
payload_out.update({
|
||||
"type": "reasoning",
|
||||
"is_delta": True,
|
||||
"content": reasoning[0].get("reasoning", ""),
|
||||
# "tool_call_chunk": token.tool_call_chunks[0] if token.tool_call_chunks else None
|
||||
})
|
||||
else:
|
||||
print(f"[reasoning] {reasoning}*************************************************************************************")
|
||||
elif text:
|
||||
payload_out.update({
|
||||
"type": "text",
|
||||
"is_delta": True,
|
||||
"content": text,
|
||||
"tool_call_chunk": token.tool_call_chunks[0] if token.tool_call_chunks else None
|
||||
})
|
||||
if len(text) == 1:
|
||||
payload_out.update({
|
||||
"type": "text",
|
||||
"is_delta": True,
|
||||
"content": text[0].get("text", ""),
|
||||
# "tool_call_chunk": token.tool_call_chunks[0] if token.tool_call_chunks else None
|
||||
})
|
||||
else:
|
||||
print(f"[text] {text}*************************************************************************************")
|
||||
else:
|
||||
payload_out.update({
|
||||
"type": "tool_call",
|
||||
"is_delta": True,
|
||||
"tool_call_chunk": token.tool_call_chunks[0] if token.tool_call_chunks else None
|
||||
# "tool_call_chunk": token.tool_call_chunks[0] if token.tool_call_chunks else None
|
||||
})
|
||||
yield f"data: {json.dumps(payload_out, ensure_ascii=False)}\n\n"
|
||||
elif isinstance(token, ToolMessageChunk): # 工具返回
|
||||
text = [b for b in token.content_blocks if b["type"] == "text"]
|
||||
payload_out.update({
|
||||
"type": "tool_text",
|
||||
"type": "tool_result",
|
||||
"is_delta": False,
|
||||
"content": text,
|
||||
"tool_name": token.name,
|
||||
@@ -244,7 +277,7 @@ async def chat_stream(request: ChatRequest):
|
||||
elif isinstance(token, ToolMessage): # 工具返回
|
||||
text = [b for b in token.content_blocks if b["type"] == "text"]
|
||||
payload_out.update({
|
||||
"type": "tool_text",
|
||||
"type": "tool_result",
|
||||
"is_delta": False,
|
||||
"content": text,
|
||||
"tool_name": token.name,
|
||||
@@ -254,14 +287,11 @@ async def chat_stream(request: ChatRequest):
|
||||
continue
|
||||
|
||||
elif mode == "custom":
|
||||
token, metadata = chunks
|
||||
subagent_name = metadata.get('lc_agent_name', None)
|
||||
payload_out = {
|
||||
"node": subagent_name,
|
||||
"node": "research-agent",
|
||||
# "checkpoint_id": current_cp_id or "unknown", TODO 替换为checkpoint_idns
|
||||
"is_delta": False,
|
||||
"content": "",
|
||||
"suggestions": [],
|
||||
"type": ""
|
||||
}
|
||||
delta = chunks.get("delta", "").strip()
|
||||
@@ -273,40 +303,10 @@ async def chat_stream(request: ChatRequest):
|
||||
})
|
||||
yield f"data: {json.dumps(payload_out, ensure_ascii=False)}\n\n"
|
||||
|
||||
|
||||
# elif channel == "updates":
|
||||
# # 处理 updates(非 interrupt 的部分)
|
||||
# if isinstance(payload, dict):
|
||||
# for update_node, update_content in payload.items():
|
||||
# # 处理 reducer 包裹的值
|
||||
# if isinstance(update_content, dict):
|
||||
# for k, v in update_content.items():
|
||||
# if hasattr(v, "value"):
|
||||
# update_content[k] = v.value
|
||||
#
|
||||
# # 序列化 messages
|
||||
# if isinstance(update_content, dict) and "messages" in update_content:
|
||||
# msgs = []
|
||||
# for m in update_content["messages"]:
|
||||
# msgs.append({
|
||||
# "type": m.__class__.__name__,
|
||||
# "content": getattr(m, "content", ""),
|
||||
# "name": getattr(m, "name", None),
|
||||
# "tool_calls": getattr(m, "tool_calls", None),
|
||||
# })
|
||||
# update_content["messages"] = msgs
|
||||
#
|
||||
# yield f"data: {json.dumps({
|
||||
# "node": "Supervisor", # 或 update_node
|
||||
# "type": "updates",
|
||||
# "content": update_content,
|
||||
# "is_delta": False,
|
||||
# "checkpoint_id": current_cp_id,
|
||||
# }, ensure_ascii=False)}\n\n"
|
||||
#
|
||||
# elif channel == "custom":
|
||||
else:
|
||||
yield f"data: {json.dumps({'status': 'end'}, ensure_ascii=False)}\n\n"
|
||||
if request.need_suggestion > 0 and random.random() < request.need_suggestion:
|
||||
suggested_questions = generate_suggested_questions(main_agent, target_thread_id)
|
||||
yield f"data: {json.dumps({'suggested_questions': suggested_questions}, ensure_ascii=False)}\n\n"
|
||||
yield f"data: {json.dumps({'status': 'end'}, ensure_ascii=False)}\n\n"
|
||||
|
||||
return StreamingResponse(event_generator(), media_type="text/event-stream")
|
||||
|
||||
@@ -357,6 +357,7 @@ async def get_chat_history(thread_id: str):
|
||||
"""
|
||||
config = {"configurable": {"thread_id": thread_id}, }
|
||||
history_data = []
|
||||
main_agent = build_main_agent(False)
|
||||
async for state in main_agent.aget_state_history(config):
|
||||
msg_content = "Initial"
|
||||
if state.values and "messages" in state.values:
|
||||
|
||||
Reference in New Issue
Block a user