164 lines
6.9 KiB
Python
164 lines
6.9 KiB
Python
|
|
import uuid
|
|||
|
|
import json
|
|||
|
|
from fastapi import APIRouter
|
|||
|
|
from fastapi.responses import StreamingResponse
|
|||
|
|
from src.schemas.chat import ChatRequest, HistoryResponse, HistoryItem
|
|||
|
|
from src.server.agent.graph import app # 导入已经 compile 好的 graph
|
|||
|
|
from langchain_core.messages import HumanMessage
|
|||
|
|
|
|||
|
|
router = APIRouter(prefix="/chat", tags=["Furniture Design Chat"])
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.post("/stream")
|
|||
|
|
async def chat_stream(request: ChatRequest):
|
|||
|
|
"""
|
|||
|
|
### 家具设计流式对话接口 (SSE)
|
|||
|
|
|
|||
|
|
通过此接口与 AI 家具设计专家团队进行实时沟通。支持 **记忆持久化** 和 **历史回溯分叉**。
|
|||
|
|
|
|||
|
|
#### 1. 核心功能
|
|||
|
|
* **实时反馈**: 采用 Server-Sent Events (SSE) 技术,实时推送主管、设计师、视觉专家等节点的思考过程。
|
|||
|
|
* **上下文记忆**: 传入 `thread_id` 即可恢复之前的对话进度。
|
|||
|
|
* **版本分溯**: 传入 `checkpoint_id` 可准确定位到历史中的某一轮,并从该点开启新的设计分支。
|
|||
|
|
|
|||
|
|
#### 2. 请求参数
|
|||
|
|
* `message`: 用户的设计意图(如:'我想设计一个极简风格的橡木办公桌')。
|
|||
|
|
* `thread_id`: (可选) 现有项目的唯一标识。若不传,系统将自动分配并返回。
|
|||
|
|
* `checkpoint_id`: (可选) 历史快照 ID。
|
|||
|
|
|
|||
|
|
#### 3. 响应流说明 (Data Format)
|
|||
|
|
响应以 `data: ` 开头的 JSON 字符串流形式发送:
|
|||
|
|
- **Session Start**: `{"thread_id": "...", "status": "start"}`
|
|||
|
|
- **Node Message**: `{"node": "Designer", "content": "...", "checkpoint_id": "..."}`
|
|||
|
|
- **Session End**: `{"status": "end"}`
|
|||
|
|
|
|||
|
|
#### 4. 请求示例
|
|||
|
|
```
|
|||
|
|
{
|
|||
|
|
"message": "设计一款北欧风格的躺椅."
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
{
|
|||
|
|
"message": "就以上信息直接生成sketch.",
|
|||
|
|
"thread_id": "187e58af"
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
{
|
|||
|
|
"message": "不要躺椅,要桌子",
|
|||
|
|
"thread_id": "187e58af",
|
|||
|
|
"checkpoint_id": "1f101aa2-8f24-6e2a-8001-2952c3a7447a"
|
|||
|
|
}
|
|||
|
|
```
|
|||
|
|
"""
|
|||
|
|
source_thread_id = request.thread_id
|
|||
|
|
checkpoint_id = request.checkpoint_id
|
|||
|
|
|
|||
|
|
# 1. 确定目标 thread_id
|
|||
|
|
# 如果是回溯操作,我们生成一个新的 ID,或者由前端传入一个新的 target_thread_id
|
|||
|
|
is_branching = source_thread_id and checkpoint_id
|
|||
|
|
target_thread_id = str(uuid.uuid4())[:8] if is_branching else (source_thread_id or str(uuid.uuid4())[:8])
|
|||
|
|
|
|||
|
|
# 2. 如果是分叉请求,我们需要先“搬家”状态
|
|||
|
|
if is_branching:
|
|||
|
|
# 获取旧状态
|
|||
|
|
source_config = {"configurable": {"thread_id": source_thread_id, "checkpoint_id": checkpoint_id}}
|
|||
|
|
older_state = await app.aget_state(source_config)
|
|||
|
|
|
|||
|
|
# 将旧状态的消息,作为新 thread 的初始值注入
|
|||
|
|
# 注意:这里我们手动把旧消息塞给新 thread
|
|||
|
|
new_config = {"configurable": {"thread_id": target_thread_id}}
|
|||
|
|
await app.aupdate_state(new_config, older_state.values)
|
|||
|
|
|
|||
|
|
# 现在的 config 指向新 Thread
|
|||
|
|
current_config = new_config
|
|||
|
|
else:
|
|||
|
|
current_config = {"configurable": {"thread_id": target_thread_id}}
|
|||
|
|
|
|||
|
|
async def event_generator():
|
|||
|
|
# 告诉前端:现在是在哪个 Thread 上工作(如果是分叉,前端需要更新本地存储的 ID)
|
|||
|
|
yield f"data: {json.dumps({'thread_id': target_thread_id, 'is_branch': is_branching}, ensure_ascii=False)}\n\n"
|
|||
|
|
|
|||
|
|
async for event in app.astream(
|
|||
|
|
{"messages": [HumanMessage(content=request.message)]},
|
|||
|
|
current_config,
|
|||
|
|
stream_mode="updates"
|
|||
|
|
):
|
|||
|
|
# ... 发送流式内容的逻辑保持不变 ...
|
|||
|
|
for node_name, output in event.items():
|
|||
|
|
if "messages" in output:
|
|||
|
|
msg = output["messages"][-1]
|
|||
|
|
state = await app.aget_state(current_config)
|
|||
|
|
yield f"data: {json.dumps({'node': node_name, 'content': msg.content, 'checkpoint_id': state.config['configurable']['checkpoint_id']}, ensure_ascii=False)}\n\n"
|
|||
|
|
|
|||
|
|
return StreamingResponse(event_generator(), media_type="text/event-stream")
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.get("/history/{thread_id}", response_model=HistoryResponse)
|
|||
|
|
async def get_chat_history(thread_id: str):
|
|||
|
|
"""
|
|||
|
|
### 获取项目设计历史记录
|
|||
|
|
|
|||
|
|
此接口用于拉取指定 `thread_id` 下的所有历史状态快照。它是实现 **“版本回溯”** 和 **“方案对比”** 的核心数据来源。
|
|||
|
|
|
|||
|
|
#### 1. 功能说明
|
|||
|
|
* **快照列表**: 返回该项目从启动至今的所有关键节点(Checkpoints)。
|
|||
|
|
* **版本定位**: 每个历史点都包含一个唯一的 `checkpoint_id`。
|
|||
|
|
* **数据回溯**: 客户端获取此列表后,可以引导用户选择任意一个版本,并将其 `checkpoint_id` 传回 `/chat/stream` 接口以开启新的设计分支。
|
|||
|
|
|
|||
|
|
#### 2. 路径参数
|
|||
|
|
* `thread_id`: 设计项目的唯一标识符(由 `/chat/stream` 首次调用时生成或指定)。
|
|||
|
|
|
|||
|
|
#### 3. 返回字段定义
|
|||
|
|
* `thread_id`: 当前查询的项目ID。
|
|||
|
|
* `history`: 历史记录数组,包含:
|
|||
|
|
- `checkpoint_id`: 必填,回溯时使用的关键凭证。
|
|||
|
|
- `last_message`: 该阶段的最后一条消息摘要(方便前端预览)。
|
|||
|
|
- `node`: 产生该快照的节点名称(如 Designer, Visualizer)。
|
|||
|
|
- `timestamp`: 逻辑步骤序号。
|
|||
|
|
|
|||
|
|
#### 4. 响应示例
|
|||
|
|
```json
|
|||
|
|
{
|
|||
|
|
"thread_id": "proj_001",
|
|||
|
|
"history": [
|
|||
|
|
{
|
|||
|
|
"checkpoint_id": "d82f3a12",
|
|||
|
|
"last_message": "我想设计一款北欧风书架",
|
|||
|
|
"node": "Supervisor",
|
|||
|
|
"timestamp": 1
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"checkpoint_id": "f4k92m1a",
|
|||
|
|
"last_message": "建议使用浅色橡木材质,增加简约感...",
|
|||
|
|
"node": "Designer",
|
|||
|
|
"timestamp": 2
|
|||
|
|
}
|
|||
|
|
]
|
|||
|
|
}
|
|||
|
|
```
|
|||
|
|
"""
|
|||
|
|
config = {"configurable": {"thread_id": thread_id}}
|
|||
|
|
history_data = []
|
|||
|
|
async for state in app.aget_state_history(config):
|
|||
|
|
msg_content = "Initial"
|
|||
|
|
if state.values and "messages" in state.values:
|
|||
|
|
msgs = state.values["messages"]
|
|||
|
|
if msgs and len(msgs) > 0:
|
|||
|
|
last_msg = msgs[-1]
|
|||
|
|
# 获取内容并做摘要截断
|
|||
|
|
content = getattr(last_msg, "content", str(last_msg))
|
|||
|
|
msg_content = content[:50] + ("..." if len(content) > 50 else "")
|
|||
|
|
|
|||
|
|
history_data.append(HistoryItem(
|
|||
|
|
checkpoint_id=state.config["configurable"]["checkpoint_id"],
|
|||
|
|
last_message=msg_content[:50],
|
|||
|
|
node=state.metadata.get("source"),
|
|||
|
|
timestamp=state.metadata.get("step")
|
|||
|
|
))
|
|||
|
|
|
|||
|
|
return HistoryResponse(thread_id=thread_id, history=history_data)
|
|||
|
|
# try:
|
|||
|
|
|
|||
|
|
# except Exception as e:
|
|||
|
|
# raise HTTPException(status_code=404, detail=f"History not found: {str(e)}")
|