# File: FiDA_Python/src/routers/deep_agent_chat.py
# Last modified: 2026-03-23 17:40:47 +08:00 (401 lines, 18 KiB, Python)
# Note: the repository viewer flagged ambiguous Unicode characters in this file
# (likely the full-width punctuation used in Chinese text and strings).
import json
import logging
import os
import random
import uuid
from typing import AsyncGenerator

from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse
from langchain_core.messages import AIMessage, AIMessageChunk, SystemMessage, ToolMessage, ToolMessageChunk
from minio import Minio

from src.core.config import MONGO_URI, PROJECT_ROOT, settings
from src.schemas.deep_agent_chat import DeepAgentChatRequest, HistoryItem, HistoryResponse
from src.server.deep_agent.agents.main_agent import build_main_agent
from src.server.deep_agent.tools.conversation_title_tool import conversation_title
from src.server.deep_agent.tools.extract_suggested_questions import generate_suggested_questions
from src.server.deep_agent.tools.generate_furniture_sketch import is_image_path_exist
from src.server.deep_agent.utils.mongodb_util import ThreadImageMinIOStore
from src.server.utils.new_oss_client import (
    get_presigned_url,
    is_minio_file_exist,
    oss_get_image,
    oss_upload_image_file,
)
logger = logging.getLogger(__name__)

# Every endpoint in this module is mounted under /chat.
router = APIRouter(
    prefix="/chat",
    tags=["Furniture Design Chat"],
)

# Per-thread image record; queried by thread_id for "current_image_path"
# (the thread's most recently generated image). Built from MONGO_URI and
# "agent_tool_generate_db" (presumably the database name - confirm).
image_store = ThreadImageMinIOStore(MONGO_URI, "agent_tool_generate_db")

# MinIO client used to presign the image URLs that are handed to the model.
minio_client = Minio(
    settings.MINIO_URL,
    access_key=settings.MINIO_ACCESS,
    secret_key=settings.MINIO_SECRET,
    secure=settings.MINIO_SECURE,
)
def _sse(payload: dict) -> str:
    """Serialize *payload* as one Server-Sent Events ``data:`` frame."""
    return f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"


def _image_block(object_path: str) -> dict:
    """Return an ``image_url`` content block with a presigned URL for ``"<bucket>/<object>"``.

    Raises:
        ValueError: if *object_path* is not of the form ``"<bucket>/<object>"``.
    """
    bucket, sep, object_name = object_path.partition("/")
    if not (sep and bucket and object_name):
        raise ValueError(f"invalid image path, expected '<bucket>/<object>': {object_path!r}")
    image_url = get_presigned_url(oss_client=minio_client, bucket=bucket, object_name=object_name)
    return {"type": "image_url", "image_url": {"url": image_url}}


def _build_user_turn(request: DeepAgentChatRequest, thread_id: str) -> tuple[list, dict]:
    """Build the multimodal user content and the ``files`` state entry for one chat turn.

    Images are attached in this order: uploaded images, the quoted image, then the
    most recently generated image recorded for *thread_id* in ``image_store``.

    Raises:
        ValueError: if a user-supplied image path is malformed.
    """
    content = [{"type": "text", "text": request.message}]
    files = {"input_image": [], "quote_image": "", "current_image": ""}
    # Images uploaded by the user.
    for path in request.input_image_paths or []:
        content.append(_image_block(path))
        files["input_image"].append(path)
    # Image quoted by the user.
    if request.quote_image_path:
        content.append(_image_block(request.quote_image_path))
        files["quote_image"] = request.quote_image_path
    # Most recently generated image of this thread (single store lookup).
    stored_images = image_store.get_image_path(thread_id)
    current_image_path = stored_images.get("current_image_path") if stored_images else None
    if current_image_path:
        try:
            content.append(_image_block(current_image_path))
        except ValueError:
            # Server-side record rather than user input: skip it instead of failing the turn.
            logger.warning("skipping malformed current image path %r for thread %s", current_image_path, thread_id)
        # NOTE(review): files["current_image"] is never filled even though the image is
        # attached to the message; confirm whether downstream tools expect the path here.
    return content, files


def _update_payloads(chunk) -> list[dict]:
    """Translate one ``updates`` stream chunk into SSE payloads.

    AIMessages inside a ``model`` node update are forwarded as ``type="updates"``
    events carrying their ``tool_calls``; ``tools`` node updates are only logged.
    """
    logger.info("[updates] -- %s", chunk)
    if not isinstance(chunk, dict):
        return []
    model_update = chunk.get("model")
    if model_update:
        if not isinstance(model_update, dict):
            return []
        return [
            {
                "node": message.name or "main",
                "is_delta": False,
                "content": "",
                "type": "updates",
                "tool_calls": message.tool_calls,
            }
            for message in model_update.get("messages", [])
            if isinstance(message, AIMessage)
        ]
    tools_update = chunk.get("tools")
    if isinstance(tools_update, dict):
        for message in tools_update.get("messages", []):
            if isinstance(message, ToolMessage):
                # Log every block: indexing [0] would raise on empty content.
                logger.info("[updates] %s -- %s", message.name, message.content_blocks)
    return []


def _message_payloads(token, metadata: dict) -> list[dict]:
    """Translate one ``messages`` stream item (token, metadata) into SSE payloads."""
    base = {
        "node": (metadata or {}).get("lc_agent_name") or "main",
        "is_delta": False,
        "content": "",
        "type": "",
    }
    if isinstance(token, AIMessageChunk):
        blocks = token.content_blocks
        reasoning = [b for b in blocks if b.get("type") == "reasoning"]
        text = [b for b in blocks if b.get("type") == "text"]
        if not reasoning and not text:
            # No reasoning/text blocks: reported as a tool-call delta (e.g. argument fragments).
            return [{**base, "type": "tool_call", "is_delta": True}]
        payloads = []
        # A chunk may carry several blocks, or both kinds: emit each kind, joined, instead of dropping any.
        if reasoning:
            payloads.append({
                **base,
                "type": "reasoning",
                "is_delta": True,
                "content": "".join(b.get("reasoning") or "" for b in reasoning),
            })
        if text:
            payloads.append({
                **base,
                "type": "text",
                "is_delta": True,
                "content": "".join(b.get("text") or "" for b in text),
            })
        return payloads
    if isinstance(token, (ToolMessageChunk, ToolMessage)):
        # Tool result: complete (not a delta); content is the list of text blocks.
        return [{
            **base,
            "type": "tool_result",
            "is_delta": False,
            "content": [b for b in token.content_blocks if b.get("type") == "text"],
            "tool_name": token.name,
        }]
    return []


def _custom_payload(chunk) -> dict | None:
    """Translate one ``custom`` stream event (expected: a dict with ``type``/``delta``) into an SSE payload."""
    logger.info("[custom] -- %s", chunk)
    if not isinstance(chunk, dict):
        logger.warning("ignoring non-dict custom stream event: %r", chunk)
        return None
    return {
        "node": "research-agent",
        "is_delta": True,
        "content": chunk.get("delta", ""),
        "type": chunk.get("type", ""),
    }


@router.post("/deep_agent_stream")
async def chat_stream(request: DeepAgentChatRequest):
    """
    ### Furniture-design streaming chat endpoint (SSE)
    Talk to the AI furniture-design agent in real time. Supports **persistent memory**
    and **branching from a historical checkpoint**.
    #### 1. Features
    * **Real-time feedback**: Server-Sent Events (SSE) push reasoning, text, tool calls
      and tool results from the main agent and its sub-agents as they are produced.
    * **Conversation memory**: pass `thread_id` to continue an existing conversation.
    * **Version branching**: pass `thread_id` together with `checkpoint_id` to copy the
      state at that checkpoint into a NEW thread and continue the design from there.
    #### 2. Request fields
    * `message`: the user's design intent (e.g. 'I want a minimalist oak office desk').
    * `enable_thinking`: whether to enable thinking mode.
    * `quote_image_path`: image quoted by the user, as "<bucket>/<object>", e.g.
      "fida-test/furniture/sketches/8a1804d1-5ac9-4d02-bf17-e65fa7272f65.png"
    * `input_image_paths`: images uploaded by the user, same format, e.g.
      ["fida-test/furniture/sketches/8a1804d1-5ac9-4d02-bf17-e65fa7272f65.png"].
    * `thread_id`: (optional) id of an existing conversation. When omitted, a new id is
      generated and returned in the start event, and a title event is sent at the end.
    * `checkpoint_id`: (optional) history checkpoint id to branch from.
    * `config_params`: (optional) conversation settings: `temperature`, plus `type`,
      `region` and `style`, which are sent as a system prompt for new or branched threads.
    * `need_suggestion`: (optional) float in [0, 1]: probability of sending suggested
      follow-up questions after the reply.
    * `use_report`: (optional) true/false: whether to enable the report feature.

    Invalid `thread_id` values or malformed image paths are rejected with HTTP 400.
    #### 3. Response stream
    Every event is sent as `data: <JSON>` followed by a blank line:
    - **Session start**: `{"thread_id": "...", "is_branch": false, "status": "start", "checkpoint_id": "..."}`
    - **Message**: `{"node": "main", "is_delta": true, "content": "...", "type": "text"}`.
      `type` is `reasoning`, `text` or `tool_call` (deltas); `tool_result` (complete, with
      `tool_name`; `content` is a list of text blocks); `updates` (complete, with
      `tool_calls`); or the type reported by a custom stream event.
    - **Suggestions** (optional): `{"suggested_questions": ...}`
    - **Title** (new threads only): `{"title": "..."}`
    - **Error**: `{"status": "error", "thread_id": "..."}` if the agent run fails.
    - **Session end**: `{"status": "end"}`, always the last event.
    - **is_delta**: true means the event carries only a fragment (a character, word or
      sentence) of the answer; the client concatenates fragments to rebuild the full text.
    #### 4. Request examples
    ```
    {
    "message": "Design a Nordic-style lounge chair."
    }
    {
    "message": "Generate the sketch directly from the information above.",
    "thread_id": "187e58af"
    }
    {
    "message": "No lounge chair, I want a table instead.",
    "thread_id": "187e58af",
    "checkpoint_id": "1f101aa2-8f24-6e2a-8001-2952c3a7447a"
    }
    User upload:
    {
    "message": "Merge the two images side by side, one half each.",
    "input_image_paths": [
    "fida-test/furniture/sketches/218adbd2-c312-4298-9a82-5a92601ac9e2.png",
    "fida-test/furniture/sketches/8a1804d1-5ac9-4d02-bf17-e65fa7272f65.png"
    ]
    }
    User quote:
    {
    "message": "Describe this image.",
    "quote_image_path": "fida-test/furniture/sketches/218adbd2-c312-4298-9a82-5a92601ac9e2.png"
    }
    ```
    """
    source_thread_id = request.thread_id
    checkpoint_id = request.checkpoint_id
    # A client-supplied thread_id becomes a directory name below: reject anything that
    # could escape the agent_workspace directory.
    if source_thread_id and (
        source_thread_id in (".", "..")
        or "\x00" in source_thread_id
        or os.path.basename(source_thread_id) != source_thread_id
    ):
        raise HTTPException(status_code=400, detail="invalid thread_id")
    # Only brand-new conversations get a generated title.
    need_title = not source_thread_id

    # 1. Target thread_id: branching copies the checkpoint into a fresh thread.
    is_branching = bool(source_thread_id and checkpoint_id)
    # NOTE(review): 8 hex chars (32 bits) of a UUID4 can collide at scale, which would merge
    # two conversations; consider a longer id if clients allow it.
    if is_branching or not source_thread_id:
        target_thread_id = str(uuid.uuid4())[:8]
    else:
        target_thread_id = source_thread_id

    # Validate image paths and build the user turn before any state is written, so a bad
    # request fails with HTTP 400 instead of breaking the stream after it has started.
    try:
        user_content, files = _build_user_turn(request, target_thread_id)
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc

    # Build the main agent with a per-thread workspace directory.
    workspace_dir = os.path.join(PROJECT_ROOT, f"agent_workspace/{target_thread_id}")
    logger.info(
        "chat request data: %s | target_thread_id: %s | workspace_dir: %s",
        request, target_thread_id, workspace_dir,
    )
    main_agent = build_main_agent(request.use_report, workspace_dir, request.enable_thinking)

    # 2. Run configuration.
    temperature = request.config_params.temperature if request.config_params else 0.7
    current_config = {
        "recursion_limit": 120,
        "configurable": {
            "thread_id": target_thread_id,
            "llm_temperature": temperature,
            "use_report": request.use_report,
        },
    }

    # 3. Settings system prompt for new and branched threads. TODO: persist to the database.
    initial_messages = []
    if (not source_thread_id or is_branching) and request.config_params:
        cp = request.config_params
        system_prompt = (
            "Current furniture design background settings\n"
            f"- type: {cp.type}\n"
            f"- space/region: {cp.region}\n"
            f"- style tendency: {cp.style}\n"
            "Please strictly follow the above settings in subsequent conversations。"
        )
        initial_messages.append(SystemMessage(content=system_prompt))

    # 4. Branching: copy the source checkpoint's state (plus the settings prompt) into the new thread.
    if is_branching:
        source_config = {
            "configurable": {
                "thread_id": source_thread_id,
                "checkpoint_id": checkpoint_id,
            }
        }
        older_state = await main_agent.aget_state(source_config)
        combined_values = older_state.values.copy()
        if initial_messages:
            combined_values["messages"] = list(combined_values.get("messages", [])) + initial_messages
        await main_agent.aupdate_state(current_config, combined_values)

    # Branched threads already received the settings prompt through aupdate_state above;
    # brand-new threads get it ahead of their first user message (previously it was dropped).
    leading_messages = [] if is_branching else initial_messages
    stream_input = {
        "messages": [*leading_messages, {"role": "user", "content": user_content}],
        "files": files,
    }

    async def event_generator() -> AsyncGenerator[str, None]:
        """Run the agent and yield SSE frames: start, messages, extras, then end."""
        started = False
        try:
            async for _namespace, mode, chunk in main_agent.astream(
                stream_input,
                config=current_config,
                stream_mode=["updates", "messages", "custom"],
                subgraphs=True,
            ):
                if not started:
                    # Async state read: a sync get_state here would block the event loop.
                    snapshot = await main_agent.aget_state(current_config)
                    start_checkpoint_id = snapshot.config.get("configurable", {}).get("checkpoint_id")
                    yield _sse({
                        "thread_id": target_thread_id,
                        "is_branch": is_branching,
                        "status": "start",
                        "checkpoint_id": start_checkpoint_id,
                    })
                    started = True
                if mode == "updates":
                    payloads = _update_payloads(chunk)
                elif mode == "messages":
                    token, metadata = chunk
                    payloads = _message_payloads(token, metadata)
                elif mode == "custom":
                    payload = _custom_payload(chunk)
                    payloads = [payload] if payload is not None else []
                else:
                    payloads = []
                for payload in payloads:
                    yield _sse(payload)

            # Suggested follow-up questions, sent with probability need_suggestion.
            suggestion_rate = request.need_suggestion or 0
            if suggestion_rate > 0 and random.random() < suggestion_rate:
                suggested_questions = await generate_suggested_questions(main_agent, target_thread_id)
                yield _sse({"suggested_questions": suggested_questions})

            # Conversation title for new conversations.
            if need_title:
                title = await conversation_title(agent=main_agent, config=current_config)
                logger.info("[title] %s", title)
                yield _sse({"title": title})
        except Exception:
            # HTTP 200 has already been sent: report the failure in-band rather than cutting
            # the stream without an end marker. Cancellation (BaseException) is not caught.
            logger.exception("deep agent stream failed | thread_id: %s", target_thread_id)
            yield _sse({"status": "error", "thread_id": target_thread_id})
        yield _sse({"status": "end"})

    return StreamingResponse(event_generator(), media_type="text/event-stream")
def _message_preview(message) -> str:
    """Return the plain text of a checkpointed message for the history preview.

    String content is returned as-is; list (multimodal) content is reduced to its
    text parts, dropping images; objects without ``content`` fall back to ``str()``.
    """
    content = getattr(message, "content", None)
    if content is None:
        return str(message)
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        return "".join(
            part if isinstance(part, str) else (part.get("text") or "")
            for part in content
            if isinstance(part, str) or (isinstance(part, dict) and part.get("type") == "text")
        )
    return str(content)


@router.get("/history/{thread_id}", response_model=HistoryResponse)
async def get_chat_history(thread_id: str):
    """
    ### Get a conversation's checkpoint history
    Returns every state snapshot (checkpoint) stored for `thread_id`. This is the data
    source for **version rollback** and **design comparison**.
    #### 1. What it does
    * **Snapshot list**: all checkpoints of the thread, in the order yielded by the agent's
      `aget_state_history` (newest first).
    * **Version anchor**: every entry carries a unique `checkpoint_id`.
    * **Branching**: let the user pick an entry, then send its `checkpoint_id` (with the same
      `thread_id`) to `POST /chat/deep_agent_stream` to start a new design branch from it.
    #### 2. Path parameter
    * `thread_id`: conversation id (returned by `/chat/deep_agent_stream` on the first call).
      Values that are not a single path segment are rejected with HTTP 400.
    #### 3. Response fields
    * `thread_id`: the queried conversation id.
    * `history`: list of entries, each with
      - `checkpoint_id`: the id to pass back when branching.
      - `last_message`: text of the last message in that snapshot ("Initial" when it has no
        messages), for preview only.
      - `node`: the `source` recorded in the checkpoint metadata (e.g. `input`, `loop`).
      - `timestamp`: the checkpoint's `step` number (a logical step, not wall-clock time).
    #### 4. Response example
    ```json
    {
    "thread_id": "187e58af",
    "history": [
    {
    "checkpoint_id": "1f101aa2-8f24-6e2a-8002-8c1d4e6b0f31",
    "last_message": "Light oak is recommended for a clean, minimal look...",
    "node": "loop",
    "timestamp": 2
    },
    {
    "checkpoint_id": "1f101aa2-8f24-6e2a-8001-2952c3a7447a",
    "last_message": "Design a Nordic-style bookshelf",
    "node": "input",
    "timestamp": 1
    }
    ]
    }
    ```
    """
    # thread_id becomes a directory name below: reject values that could escape agent_workspace.
    if thread_id in (".", "..") or "\x00" in thread_id or os.path.basename(thread_id) != thread_id:
        raise HTTPException(status_code=400, detail="invalid thread_id")
    config = {"configurable": {"thread_id": thread_id}}
    workspace_dir = os.path.join(PROJECT_ROOT, f"agent_workspace/{thread_id}")
    # Built only to read the thread's checkpointed state history (report/thinking disabled).
    main_agent = build_main_agent(False, workspace_dir, enable_thinking=False)
    history_data = []
    async for state in main_agent.aget_state_history(config):
        messages = state.values.get("messages") if state.values else None
        last_message = _message_preview(messages[-1]) if messages else "Initial"
        # Snapshot metadata is optional; treat a missing one as empty.
        metadata = state.metadata or {}
        history_data.append(HistoryItem(
            checkpoint_id=state.config["configurable"]["checkpoint_id"],
            last_message=last_message,
            node=metadata.get("source"),
            timestamp=metadata.get("step"),
        ))
    return HistoryResponse(thread_id=thread_id, history=history_data)