Files
FiDA_Python/src/routers/chat.py
2026-02-04 17:57:49 +08:00

164 lines
6.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import uuid
import json
from fastapi import APIRouter
from fastapi.responses import StreamingResponse
from src.schemas.chat import ChatRequest, HistoryResponse, HistoryItem
from src.server.agent.graph import app # 导入已经 compile 好的 graph
from langchain_core.messages import HumanMessage
router = APIRouter(prefix="/chat", tags=["Furniture Design Chat"])
@router.post("/stream")
async def chat_stream(request: ChatRequest):
"""
### 家具设计流式对话接口 (SSE)
通过此接口与 AI 家具设计专家团队进行实时沟通。支持 **记忆持久化** 和 **历史回溯分叉**。
#### 1. 核心功能
* **实时反馈**: 采用 Server-Sent Events (SSE) 技术,实时推送主管、设计师、视觉专家等节点的思考过程。
* **上下文记忆**: 传入 `thread_id` 即可恢复之前的对话进度。
* **版本分溯**: 传入 `checkpoint_id` 可准确定位到历史中的某一轮,并从该点开启新的设计分支。
#### 2. 请求参数
* `message`: 用户的设计意图(如:'我想设计一个极简风格的橡木办公桌')。
* `thread_id`: (可选) 现有项目的唯一标识。若不传,系统将自动分配并返回。
* `checkpoint_id`: (可选) 历史快照 ID。
#### 3. 响应流说明 (Data Format)
响应以 `data: ` 开头的 JSON 字符串流形式发送:
- **Session Start**: `{"thread_id": "...", "status": "start"}`
- **Node Message**: `{"node": "Designer", "content": "...", "checkpoint_id": "..."}`
- **Session End**: `{"status": "end"}`
#### 4. 请求示例
```
{
"message": "设计一款北欧风格的躺椅."
}
{
"message": "就以上信息直接生成sketch.",
"thread_id": "187e58af"
}
{
"message": "不要躺椅,要桌子",
"thread_id": "187e58af",
"checkpoint_id": "1f101aa2-8f24-6e2a-8001-2952c3a7447a"
}
```
"""
source_thread_id = request.thread_id
checkpoint_id = request.checkpoint_id
# 1. 确定目标 thread_id
# 如果是回溯操作,我们生成一个新的 ID或者由前端传入一个新的 target_thread_id
is_branching = source_thread_id and checkpoint_id
target_thread_id = str(uuid.uuid4())[:8] if is_branching else (source_thread_id or str(uuid.uuid4())[:8])
# 2. 如果是分叉请求,我们需要先“搬家”状态
if is_branching:
# 获取旧状态
source_config = {"configurable": {"thread_id": source_thread_id, "checkpoint_id": checkpoint_id}}
older_state = await app.aget_state(source_config)
# 将旧状态的消息,作为新 thread 的初始值注入
# 注意:这里我们手动把旧消息塞给新 thread
new_config = {"configurable": {"thread_id": target_thread_id}}
await app.aupdate_state(new_config, older_state.values)
# 现在的 config 指向新 Thread
current_config = new_config
else:
current_config = {"configurable": {"thread_id": target_thread_id}}
async def event_generator():
# 告诉前端:现在是在哪个 Thread 上工作(如果是分叉,前端需要更新本地存储的 ID
yield f"data: {json.dumps({'thread_id': target_thread_id, 'is_branch': is_branching}, ensure_ascii=False)}\n\n"
async for event in app.astream(
{"messages": [HumanMessage(content=request.message)]},
current_config,
stream_mode="updates"
):
# ... 发送流式内容的逻辑保持不变 ...
for node_name, output in event.items():
if "messages" in output:
msg = output["messages"][-1]
state = await app.aget_state(current_config)
yield f"data: {json.dumps({'node': node_name, 'content': msg.content, 'checkpoint_id': state.config['configurable']['checkpoint_id']}, ensure_ascii=False)}\n\n"
return StreamingResponse(event_generator(), media_type="text/event-stream")
@router.get("/history/{thread_id}", response_model=HistoryResponse)
async def get_chat_history(thread_id: str):
"""
### 获取项目设计历史记录
此接口用于拉取指定 `thread_id` 下的所有历史状态快照。它是实现 **“版本回溯”** 和 **“方案对比”** 的核心数据来源。
#### 1. 功能说明
* **快照列表**: 返回该项目从启动至今的所有关键节点Checkpoints
* **版本定位**: 每个历史点都包含一个唯一的 `checkpoint_id`。
* **数据回溯**: 客户端获取此列表后,可以引导用户选择任意一个版本,并将其 `checkpoint_id` 传回 `/chat/stream` 接口以开启新的设计分支。
#### 2. 路径参数
* `thread_id`: 设计项目的唯一标识符(由 `/chat/stream` 首次调用时生成或指定)。
#### 3. 返回字段定义
* `thread_id`: 当前查询的项目ID。
* `history`: 历史记录数组,包含:
- `checkpoint_id`: 必填,回溯时使用的关键凭证。
- `last_message`: 该阶段的最后一条消息摘要(方便前端预览)。
- `node`: 产生该快照的节点名称(如 Designer, Visualizer
- `timestamp`: 逻辑步骤序号。
#### 4. 响应示例
```json
{
"thread_id": "proj_001",
"history": [
{
"checkpoint_id": "d82f3a12",
"last_message": "我想设计一款北欧风书架",
"node": "Supervisor",
"timestamp": 1
},
{
"checkpoint_id": "f4k92m1a",
"last_message": "建议使用浅色橡木材质,增加简约感...",
"node": "Designer",
"timestamp": 2
}
]
}
```
"""
config = {"configurable": {"thread_id": thread_id}}
history_data = []
async for state in app.aget_state_history(config):
msg_content = "Initial"
if state.values and "messages" in state.values:
msgs = state.values["messages"]
if msgs and len(msgs) > 0:
last_msg = msgs[-1]
# 获取内容并做摘要截断
content = getattr(last_msg, "content", str(last_msg))
msg_content = content[:50] + ("..." if len(content) > 50 else "")
history_data.append(HistoryItem(
checkpoint_id=state.config["configurable"]["checkpoint_id"],
last_message=msg_content[:50],
node=state.metadata.get("source"),
timestamp=state.metadata.get("step")
))
return HistoryResponse(thread_id=thread_id, history=history_data)
# try:
# except Exception as e:
# raise HTTPException(status_code=404, detail=f"History not found: {str(e)}")