新增标题提取
This commit is contained in:
@@ -6,18 +6,19 @@ from typing import AsyncGenerator
|
|||||||
|
|
||||||
from fastapi import APIRouter
|
from fastapi import APIRouter
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
from src.schemas.chat import ChatRequest, HistoryResponse, HistoryItem
|
from src.schemas.deep_agent_chat import DeepAgentChatRequest, HistoryResponse, HistoryItem
|
||||||
from langchain_core.messages import HumanMessage, SystemMessage, AIMessageChunk, ToolMessage, AIMessage, ToolMessageChunk
|
from langchain_core.messages import HumanMessage, SystemMessage, AIMessageChunk, ToolMessage, AIMessage, ToolMessageChunk
|
||||||
|
|
||||||
from src.server.deep_agent.agents.main_agent import build_main_agent
|
from src.server.deep_agent.agents.main_agent import build_main_agent
|
||||||
from src.server.deep_agent.tools.extract_suggested_questions import format_messages, generate_suggested_questions
|
from src.server.deep_agent.tools.conversation_title_tool import conversation_title
|
||||||
|
from src.server.deep_agent.tools.extract_suggested_questions import generate_suggested_questions
|
||||||
|
|
||||||
router = APIRouter(prefix="/chat", tags=["Furniture Design Chat"])
|
router = APIRouter(prefix="/chat", tags=["Furniture Design Chat"])
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/deep_agent_stream")
|
@router.post("/deep_agent_stream")
|
||||||
async def chat_stream(request: ChatRequest):
|
async def chat_stream(request: DeepAgentChatRequest):
|
||||||
"""
|
"""
|
||||||
### 家具设计流式对话接口 (SSE)
|
### 家具设计流式对话接口 (SSE)
|
||||||
|
|
||||||
@@ -34,6 +35,7 @@ async def chat_stream(request: ChatRequest):
|
|||||||
* `checkpoint_id`: (可选) 历史快照 ID。
|
* `checkpoint_id`: (可选) 历史快照 ID。
|
||||||
* `config_params`: (可选) 对话配置参数
|
* `config_params`: (可选) 对话配置参数
|
||||||
* `need_suggestion`: (可选) 是否需要建议按钮,需要建议的频率,0-1的浮点数
|
* `need_suggestion`: (可选) 是否需要建议按钮,需要建议的频率,0-1的浮点数
|
||||||
|
* `need_title`: (可选) 是否需要标题
|
||||||
* `use_report`: (可选) 是否需要使用report功能 true/false
|
* `use_report`: (可选) 是否需要使用report功能 true/false
|
||||||
|
|
||||||
|
|
||||||
@@ -66,61 +68,6 @@ async def chat_stream(request: ChatRequest):
|
|||||||
### 5. 响应流说明
|
### 5. 响应流说明
|
||||||
所有响应均以 data: 开头,JSON 字符串格式,末尾以 \n\n 结束
|
所有响应均以 data: 开头,JSON 字符串格式,末尾以 \n\n 结束
|
||||||
响应流包含三种类型的事件:会话开始、节点消息、会话结束
|
响应流包含三种类型的事件:会话开始、节点消息、会话结束
|
||||||
会话开始:
|
|
||||||
```
|
|
||||||
{
|
|
||||||
"thread_id": "str",
|
|
||||||
"is_branch": "boolean",
|
|
||||||
"status": "start"
|
|
||||||
}
|
|
||||||
```
|
|
||||||
节点消息:
|
|
||||||
```
|
|
||||||
{
|
|
||||||
"node": "节点名称(如Designer/Researcher/Main)",
|
|
||||||
"content": "消息内容",
|
|
||||||
"checkpoint_id": "快照ID",
|
|
||||||
"is_delta": "boolean",
|
|
||||||
"type": "消息类型",
|
|
||||||
"tool_name": "工具名称(可选)",
|
|
||||||
"tool_call_chunk": "工具调用片段(可选)",
|
|
||||||
"tool_call_id": "工具调用ID(可选)"
|
|
||||||
}
|
|
||||||
|
|
||||||
```
|
|
||||||
报告增量消息:
|
|
||||||
```
|
|
||||||
{
|
|
||||||
"node": "Researcher",
|
|
||||||
"type": "report_delta",
|
|
||||||
"content": "报告内容增量",
|
|
||||||
"is_delta": true,
|
|
||||||
"checkpoint_id": "xxx"
|
|
||||||
}
|
|
||||||
```
|
|
||||||
AI 消息片段:
|
|
||||||
```
|
|
||||||
{
|
|
||||||
"node": "Designer",
|
|
||||||
"content": "设计建议内容",
|
|
||||||
"checkpoint_id": "xxx",
|
|
||||||
"is_delta": true,
|
|
||||||
"type": "delta",
|
|
||||||
"tool_call_chunk": {...}
|
|
||||||
}
|
|
||||||
```
|
|
||||||
工具执行结果:
|
|
||||||
```
|
|
||||||
{
|
|
||||||
"node": "ToolExecutor",
|
|
||||||
"content": "工具执行结果",
|
|
||||||
"checkpoint_id": "xxx",
|
|
||||||
"is_delta": false,
|
|
||||||
"type": "tool_result",
|
|
||||||
"tool_name": "ImageGenerator",
|
|
||||||
"tool_call_id": "yyy"
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
logger.info(f"chat request data: {request}")
|
logger.info(f"chat request data: {request}")
|
||||||
@@ -306,9 +253,16 @@ async def chat_stream(request: ChatRequest):
|
|||||||
})
|
})
|
||||||
yield f"data: {json.dumps(payload_out, ensure_ascii=False)}\n\n"
|
yield f"data: {json.dumps(payload_out, ensure_ascii=False)}\n\n"
|
||||||
|
|
||||||
|
# 获取建议消息
|
||||||
if request.need_suggestion > 0 and random.random() < request.need_suggestion:
|
if request.need_suggestion > 0 and random.random() < request.need_suggestion:
|
||||||
suggested_questions = await generate_suggested_questions(main_agent, target_thread_id)
|
suggested_questions = await generate_suggested_questions(main_agent, target_thread_id)
|
||||||
yield f"data: {json.dumps({'suggested_questions': suggested_questions}, ensure_ascii=False)}\n\n"
|
yield f"data: {json.dumps({'suggested_questions': suggested_questions}, ensure_ascii=False)}\n\n"
|
||||||
|
|
||||||
|
# 获取标题
|
||||||
|
if request.need_title:
|
||||||
|
title = await conversation_title(agent=main_agent, config=current_config)
|
||||||
|
yield f"data: {json.dumps({'title': title}, ensure_ascii=False)}\n\n"
|
||||||
|
|
||||||
yield f"data: {json.dumps({'status': 'end'}, ensure_ascii=False)}\n\n"
|
yield f"data: {json.dumps({'status': 'end'}, ensure_ascii=False)}\n\n"
|
||||||
|
|
||||||
return StreamingResponse(event_generator(), media_type="text/event-stream")
|
return StreamingResponse(event_generator(), media_type="text/event-stream")
|
||||||
|
|||||||
37
src/schemas/deep_agent_chat.py
Normal file
37
src/schemas/deep_agent_chat.py
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
from pydantic import BaseModel, Field, confloat
|
||||||
|
from typing import Optional, List, Dict, Any
|
||||||
|
|
||||||
|
|
||||||
|
class AgentConfig(BaseModel):
|
||||||
|
type: str = Field(..., description="家具类型,如:沙发、餐桌")
|
||||||
|
region: str = Field(..., description="地区/空间,如:客厅、卧室、户外")
|
||||||
|
style: str = Field(..., description="设计风格,如:极简、工业风、中式")
|
||||||
|
temperature: confloat(ge=0, le=2.0) = Field(default=0.7, description="模型温度")
|
||||||
|
|
||||||
|
|
||||||
|
class DeepAgentChatRequest(BaseModel):
|
||||||
|
message: str = Field(..., description="用户的输入指令")
|
||||||
|
thread_id: Optional[str] = Field(None, description="会话线程ID,不传则开启新会话")
|
||||||
|
checkpoint_id: Optional[str] = Field(None, description="回溯点的ID,用于从历史点开启新对话")
|
||||||
|
config_params: Optional[AgentConfig] = None
|
||||||
|
need_suggestion: float = 0
|
||||||
|
need_title: bool = False
|
||||||
|
use_report: bool = False # ← 新增:是否使用深度报告
|
||||||
|
|
||||||
|
|
||||||
|
class HistoryItem(BaseModel):
|
||||||
|
checkpoint_id: str
|
||||||
|
last_message: str
|
||||||
|
node: Optional[str]
|
||||||
|
timestamp: Any
|
||||||
|
|
||||||
|
|
||||||
|
class HistoryResponse(BaseModel):
|
||||||
|
thread_id: str
|
||||||
|
history: List[HistoryItem]
|
||||||
|
|
||||||
|
|
||||||
|
class StreamChunk(BaseModel):
|
||||||
|
node: str
|
||||||
|
content: str
|
||||||
|
checkpoint_id: str
|
||||||
@@ -1,27 +1,44 @@
|
|||||||
|
from langchain_core.messages import HumanMessage, AIMessage
|
||||||
from langchain_core.prompts import PromptTemplate
|
from langchain_core.prompts import PromptTemplate
|
||||||
|
|
||||||
from src.server.deep_agent.init_llm import title_llm
|
from src.server.deep_agent.init_llm import title_llm
|
||||||
|
|
||||||
|
|
||||||
def conversation_title(full_conversation):
|
async def conversation_title(agent, config):
|
||||||
title_prompt = PromptTemplate(
|
state = agent.get_state(config)
|
||||||
input_variables=["full_conversation"],
|
messages = state.values.get("messages", [])
|
||||||
template="""
|
if len(messages) < 2:
|
||||||
请严格按照以下要求生成对话标题:
|
return None
|
||||||
1. 标题长度:8-15个字,纯中文,无标点、无特殊符号、无换行
|
|
||||||
2. 标题内容:基于完整对话,精准概括核心主题(兼顾用户需求和助手回复)
|
|
||||||
3. 标题风格:自然口语化,符合中文表达习惯,不冗余
|
|
||||||
|
|
||||||
完整对话内容:
|
|
||||||
{full_conversation}
|
|
||||||
|
|
||||||
仅输出标题,不要输出任何额外解释、说明或标点符号。
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
title_chain = title_prompt | title_llm
|
|
||||||
response = title_chain.invoke({"full_conversation": full_conversation})
|
|
||||||
return response
|
|
||||||
|
|
||||||
|
user_msg = None
|
||||||
|
ai_msg = None
|
||||||
|
|
||||||
if __name__ == '__main__':
|
for m in messages:
|
||||||
print(conversation_title("你好"))
|
if isinstance(m, HumanMessage) and user_msg is None:
|
||||||
|
user_msg = m.content
|
||||||
|
|
||||||
|
if isinstance(m, AIMessage) and ai_msg is None:
|
||||||
|
ai_msg = m.content
|
||||||
|
|
||||||
|
if user_msg and ai_msg:
|
||||||
|
break
|
||||||
|
if not user_msg or not ai_msg:
|
||||||
|
return None
|
||||||
|
|
||||||
|
prompt = f"""
|
||||||
|
请根据以下首次对话内容,生成一个简洁、精准的标题(2-15个字):
|
||||||
|
|
||||||
|
用户:{user_msg}
|
||||||
|
助手:{ai_msg}
|
||||||
|
|
||||||
|
要求:
|
||||||
|
1. 标题需概括对话核心内容
|
||||||
|
2. 语言简洁、符合中文表达习惯
|
||||||
|
3. 仅返回标题,不要额外解释
|
||||||
|
"""
|
||||||
|
response = await title_llm.ainvoke(prompt)
|
||||||
|
title = response.content.strip()
|
||||||
|
|
||||||
|
# 去掉可能的符号
|
||||||
|
title = title.replace("《", "").replace("》", "")
|
||||||
|
return title
|
||||||
|
|||||||
Reference in New Issue
Block a user