Files
FiDA_Python/src/routers/deep_agent_chat.py
2026-05-19 11:30:29 +08:00

482 lines
22 KiB
Python
Executable File
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import asyncio
import json
import logging
import os
import random
import uuid
from typing import AsyncGenerator

from fastapi import APIRouter
from fastapi import HTTPException
from fastapi.responses import StreamingResponse
from langchain_core.messages import SystemMessage, AIMessageChunk, ToolMessage, AIMessage, ToolMessageChunk
from minio import Minio
from minio.commonconfig import CopySource

from src.core.config import PROJECT_ROOT, settings
from src.server.deep_agent.agents.main_agent import build_main_agent, Context
from src.server.deep_agent.tools.conversation_title_tool import conversation_title
from src.schemas.deep_agent_chat import DeepAgentChatRequest, HistoryResponse, HistoryItem
from src.server.deep_agent.tools.extract_suggested_questions import generate_suggested_questions
from src.server.utils.new_oss_client import get_presigned_url
# All endpoints in this module are mounted under /chat.
router = APIRouter(prefix="/chat", tags=["Furniture Design Chat"])
logger = logging.getLogger(__name__)

# Module-wide MinIO client; the streaming endpoint uses it to copy user images
# into the public bucket before handing their URLs to the model.
minio_client = Minio(
    settings.MINIO_URL,
    access_key=settings.MINIO_ACCESS,
    secret_key=settings.MINIO_SECRET,
    secure=settings.MINIO_SECURE,
)
def _sse(payload: dict) -> str:
    """Serialize *payload* as a single Server-Sent Events ``data:`` frame."""
    return f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"


async def _publish_image(path: str) -> str:
    """Copy the ``bucket/object`` image at *path* into ``fida-public-bucket``.

    ``minio_client.copy_object`` is a blocking call, so it runs in a worker
    thread instead of stalling the event loop that serves the SSE stream.

    Returns:
        The public URL of the copied object.

    Raises:
        ValueError: if *path* has no ``/`` separating bucket and object name.
    """
    bucket, obj = path.split("/", 1)
    await asyncio.to_thread(
        minio_client.copy_object, "fida-public-bucket", path, CopySource(bucket, obj)
    )
    return f"https://www.minio-api.aida.com.hk/fida-public-bucket/{path}"


# Streaming (SSE) chat endpoint for the furniture-design agent.
# The docstring below is the endpoint's OpenAPI description (rendered by FastAPI),
# so it is kept in the language the API consumers read.
@router.post("/deep_agent_stream")
async def chat_stream(request: DeepAgentChatRequest):
    """
    ### 家具设计流式对话接口 (SSE)
    通过此接口与 AI 家具设计专家团队进行实时沟通。支持 **记忆持久化** 和 **历史回溯分叉**。

    #### 1. 核心功能
    * **实时反馈**: 采用 Server-Sent Events (SSE) 技术,实时推送主管、设计师、视觉专家等节点的思考过程。
    * **上下文记忆**: 传入 `thread_id` 即可恢复之前的对话进度。
    * **版本回溯**: 传入 `checkpoint_id` 可准确定位到历史中的某一轮,并从该点开启新的设计分支。

    #### 2. 请求参数
    * `message`: 用户的设计意图(如:'我想设计一个极简风格的橡木办公桌')。
    * `enable_thinking`: 是否开启思考模式。
    * `quote_image_path`: 用户引用图片地址 如:"fida-test/furniture/sketches/8a1804d1-5ac9-4d02-bf17-e65fa7272f65.png"
    * `input_image_paths`: 用户上传图片地址集合如:["fida-test/furniture/sketches/8a1804d1-5ac9-4d02-bf17-e65fa7272f65.png"]。
    * `thread_id`: (可选) 现有项目的唯一标识。若不传,系统将自动分配并返回。
    * `checkpoint_id`: (可选) 历史快照 ID。
    * `config_params`: (可选) 对话配置参数
    ```json
    {
        "message": "你好",
        "config_params": {
            "type": "test",
            "region": "test",
            "style": "test"
        }
    }
    ```
    * `need_suggestion`: (可选) 是否需要建议按钮,需要建议的频率0-1的浮点数
    * `use_report`: (可选) 是否需要使用report功能 true/false

    #### 3. 响应流说明 (Data Format)
    响应以 `data: ` 开头的 JSON 字符串流形式发送:
    - **Session Start**: `{"thread_id": "...", "status": "start"}`
    - **Node Message**: `{"node": "Designer", "content": "...", "checkpoint_id": "..."}`
    - **Session End**: `{"status": "end"}`
    - **is_delta**: False/True表示这个消息不是完整内容只是 AI 正在生成的一小段内容(一个字、一个词、一句话),需要前端把这些片段拼接起来才能得到完整的回答。

    #### 4. 请求示例
    ```
    {
        "message": "设计一款北欧风格的躺椅."
    }
    {
        "message": "就以上信息直接生成sketch.",
        "thread_id": "187e58af"
    }
    {
        "message": "不要躺椅,要桌子",
        "thread_id": "187e58af",
        "checkpoint_id": "1f101aa2-8f24-6e2a-8001-2952c3a7447a"
    }
    用户上传:
    {
        "message": "合并两张图一边一半,左右拼",
        "input_image_paths": [
            "fida-test/furniture/sketches/218adbd2-c312-4298-9a82-5a92601ac9e2.png",
            "fida-test/furniture/sketches/8a1804d1-5ac9-4d02-bf17-e65fa7272f65.png"
        ]
    }
    用户引用:
    {
        "message": "描述这张图",
        "quote_image_path":"fida-test/furniture/sketches/218adbd2-c312-4298-9a82-5a92601ac9e2.png"
    }
    ```

    #### 5. 响应流说明
    所有响应均以 data: 开头JSON 字符串格式,末尾以 \\n\\n 结束
    响应流包含三种类型的事件:会话开始、节点消息、会话结束
    """
    # 1. Thread bookkeeping: a title is generated only for brand-new conversations.
    need_title = not request.thread_id
    source_thread_id = request.thread_id
    checkpoint_id = request.checkpoint_id
    cp = request.config_params  # optional; every access below is None-safe

    # 2. Target thread: branching from a checkpoint always forks into a fresh thread id.
    is_branching = all([source_thread_id, checkpoint_id])
    target_thread_id = str(uuid.uuid4())[:8] if is_branching else (source_thread_id or str(uuid.uuid4())[:8])

    # 3. Per-thread agent workspace. thread_id can come from the client, so insist that
    #    it names exactly one directory directly under agent_workspace (rejects "..",
    #    "a/b", absolute paths) before handing the path to the agent.
    workspace_root = os.path.join(PROJECT_ROOT, "agent_workspace")
    workspace_dir = os.path.join(workspace_root, target_thread_id)
    if os.path.dirname(os.path.normpath(workspace_dir)) != os.path.normpath(workspace_root):
        raise HTTPException(status_code=400, detail="invalid thread_id")
    logger.info(f"chat request data: {request} | target_thread_id: {target_thread_id}, workspace_dir: {workspace_dir}")
    main_agent = build_main_agent(workspace_dir, request.enable_thinking)

    # 4. Run configuration.
    temp = cp.temperature if cp else 0.7
    current_config = {
        "recursion_limit": 120,
        "configurable": {
            "thread_id": target_thread_id,
            "llm_temperature": temp,
            "use_report": request.use_report,
        }
    }

    # 5. Design settings, only injected when a thread starts (new or branched).
    initial_messages = []
    # Constraint block appended to the user's text. It stays empty when continuing an
    # existing thread or when no config_params were sent; previously it was left
    # undefined in those cases and the stream crashed with NameError.
    design_backend = ""
    if (not source_thread_id or is_branching) and cp:
        config_items = [
            ("type", cp.type),
            ("space/region", cp.region),
            ("style tendency", cp.style)
        ]
        valid_lines = [f"- {k}: {v}" for k, v in config_items if v]
        if valid_lines:
            system_prompt = (
                "Current furniture design background settings\n"
                + "\n".join(valid_lines) + "\n"
                "Please strictly follow the above settings in subsequent conversations。"
            )
            initial_messages.append(SystemMessage(content=system_prompt))
        design_backend = (
            "\n<design_constraints>\n"
            f"Category: {cp.type or 'unspecified'}\n"
            f"region: {cp.region or 'unspecified'}\n"
            f"style: {cp.style or 'unspecified'}\n"
            "</design_constraints>\n"
        )

    # 6. Branching: copy the source checkpoint's state into the new thread.
    # NOTE(review): initial_messages is only applied here; for a brand-new (non-branch)
    # thread the same settings reach the model solely through design_backend.
    if is_branching:
        source_config = {"configurable": {"thread_id": source_thread_id, "checkpoint_id": checkpoint_id}}
        last_checkpoint_id = await get_branch_checkpoint_id(main_agent, source_config)
        older_state = await main_agent.aget_state(source_config)
        combined_values = older_state.values.copy()
        if initial_messages:
            combined_values["messages"] = combined_values.get("messages", []) + initial_messages
        await main_agent.aupdate_state(current_config, combined_values)
    else:
        last_checkpoint_id = await get_checkpoint_id(main_agent, current_config)

    # 7. SSE event stream.
    async def event_generator() -> AsyncGenerator[str, None]:
        is_first = True
        content = [{"type": "text", "text": request.message}]
        files = {
            "input_image": [],
            "quote_image": "",
            "current_image": ""
        }
        input_image_content = ""

        # Uploaded images: publish each one and attach it to the multimodal user message.
        if request.input_image_paths:
            input_image_content += "\n【附件上传图片路径】\n"
            for i, path in enumerate(request.input_image_paths):
                input_image_content += f"- 上传图片{i}: {path}\n"
                image_url = await _publish_image(path)
                content.append({"type": "image_url", "image_url": {"url": image_url}})
                files["input_image"].append(path)

        # Quoted (referenced) image.
        if request.quote_image_path:
            input_image_content += "\n【附件引用图片路径】\n"
            input_image_content += f"- 引用图片: {request.quote_image_path}\n"
            image_url = await _publish_image(request.quote_image_path)
            content.append({"type": "image_url", "image_url": {"url": image_url}})
            files["quote_image"] = request.quote_image_path

        # Append image paths and design constraints to the text part of the message.
        if input_image_content:
            content[0]["text"] += input_image_content
        if design_backend:
            content[0]["text"] += design_backend

        final_messages = {"messages": [{"role": "user", "content": content}], "files": files}
        logger.info(final_messages)

        agent_context = Context(
            use_report=request.use_report,
            language=request.language,
            type=cp.type if cp else None,
            region=cp.region if cp else None,
            style=cp.style if cp else None,
        )

        async for stream in main_agent.astream(
            final_messages,
            config=current_config,
            stream_mode=["updates", "messages", "custom"],
            subgraphs=True,
            context=agent_context,
        ):
            # With subgraphs=True and several stream modes each item is (namespace, mode, data).
            _, mode, chunks = stream

            if is_first:
                is_first = False
                # get_latest_checkpoint_id uses the synchronous graph API; keep it off the event loop.
                current_checkpoint_id = await asyncio.to_thread(
                    get_latest_checkpoint_id, main_agent, current_config
                )
                if not current_checkpoint_id:
                    logger.warning("no checkpoint_id found for thread %s at stream start", target_thread_id)
                await main_agent.store.aput(
                    ("image_history",),
                    "checkpoint_id",
                    {
                        "current_checkpoint_id": current_checkpoint_id,
                        "last_checkpoint_id": last_checkpoint_id,
                    }
                )
                logger.info("stream start checkpoint_id=%s", current_checkpoint_id)
                yield _sse({
                    'thread_id': target_thread_id,
                    'is_branch': is_branching,
                    'status': 'start',
                    'checkpoint_id': current_checkpoint_id,
                })

            if mode == "updates":
                # Node state updates: logged for diagnostics; AI messages from the "model"
                # node are also forwarded so the client can see the issued tool calls.
                logger.info(f"[updates] -- {chunks}")
                update_model_messages = chunks.get("model", None)
                update_tools_messages = chunks.get("tools", None)
                if update_model_messages:
                    for model_token in update_model_messages.get("messages", []):
                        if isinstance(model_token, AIMessage):
                            yield _sse({
                                "node": model_token.name or "main",
                                "is_delta": False,
                                "content": "",
                                "type": "updates",
                                "tool_calls": model_token.tool_calls,
                            })
                elif update_tools_messages:
                    for tools_token in update_tools_messages.get("messages", []):
                        if isinstance(tools_token, ToolMessage):
                            blocks = tools_token.content_blocks
                            # Guard against empty content instead of indexing [0] blindly.
                            logger.info(f"[updates] {tools_token.name} -- {blocks[0] if blocks else None}")

            elif mode == "messages":
                token, metadata = chunks
                payload_out = {
                    "node": metadata.get('lc_agent_name', "main"),
                    "is_delta": False,
                    "content": "",
                    "type": ""
                }
                if isinstance(token, AIMessageChunk):
                    # Streamed model output: a reasoning delta, a text delta, or a tool-call chunk.
                    blocks = token.content_blocks
                    reasoning = [b for b in blocks if b["type"] == "reasoning"]
                    text = [b for b in blocks if b["type"] == "text"]
                    if reasoning:
                        if len(reasoning) == 1:
                            payload_out.update({
                                "type": "reasoning",
                                "is_delta": True,
                                "content": reasoning[0].get("reasoning", ""),
                            })
                        else:
                            logger.info("[reasoning] multiple blocks in one chunk: %s", reasoning)
                    elif text:
                        if len(text) == 1:
                            payload_out.update({
                                "type": "text",
                                "is_delta": True,
                                "content": text[0].get("text", ""),
                            })
                        else:
                            logger.info("[text] multiple blocks in one chunk: %s", text)
                    else:
                        payload_out.update({
                            "type": "tool_call",
                            "is_delta": True,
                        })
                    yield _sse(payload_out)
                elif isinstance(token, (ToolMessageChunk, ToolMessage)):
                    # Tool results are forwarded whole (not as deltas).
                    text = [b for b in token.content_blocks if b["type"] == "text"]
                    payload_out.update({
                        "type": "tool_result",
                        "is_delta": False,
                        "content": text,
                        "tool_name": token.name,
                    })
                    yield _sse(payload_out)

            elif mode == "custom":
                # Custom events emitted by tools/subagents; forwarded as deltas.
                logger.info(f"[custom] -- {chunks}")
                yield _sse({
                    "node": "research-agent",
                    "is_delta": True,
                    "content": chunks.get("delta", ""),
                    "type": chunks.get("type", ""),
                })

        # Suggested follow-up questions, emitted with probability need_suggestion (None -> never).
        need_suggestion = request.need_suggestion or 0
        if need_suggestion > 0 and random.random() < need_suggestion:
            suggested_questions = await generate_suggested_questions(main_agent, target_thread_id)
            yield _sse({'suggested_questions': suggested_questions})

        # Conversation title for brand-new threads only.
        if need_title:
            title = await conversation_title(agent=main_agent, config=current_config)
            logger.info(f"[title] {title}")
            yield _sse({'title': title})

        yield _sse({'status': 'end'})

    return StreamingResponse(event_generator(), media_type="text/event-stream")
# History endpoint: lists every checkpoint of a thread so the client can branch from one.
# The docstring below is the endpoint's OpenAPI description (rendered by FastAPI).
@router.get("/history/{thread_id}", response_model=HistoryResponse)
async def get_chat_history(thread_id: str):
    """
    ### 获取项目设计历史记录
    此接口用于拉取指定 `thread_id` 下的所有历史状态快照。它是实现 **“版本回溯”** 和 **“方案对比”** 的核心数据来源。

    #### 1. 功能说明
    * **快照列表**: 返回该项目从启动至今的所有关键节点Checkpoints
    * **版本定位**: 每个历史点都包含一个唯一的 `checkpoint_id`。
    * **数据回溯**: 客户端获取此列表后,可以引导用户选择任意一个版本,并将其 `checkpoint_id` 传回 `/chat/deep_agent_stream` 接口以开启新的设计分支。

    #### 2. 路径参数
    * `thread_id`: 设计项目的唯一标识符(由 `/chat/deep_agent_stream` 首次调用时生成或指定)。

    #### 3. 返回字段定义
    * `thread_id`: 当前查询的项目ID。
    * `history`: 历史记录数组,包含:
        - `checkpoint_id`: 必填,回溯时使用的关键凭证。
        - `last_message`: 该阶段的最后一条消息摘要(方便前端预览)。
        - `node`: 产生该快照的节点名称(如 Designer, Visualizer
        - `timestamp`: 逻辑步骤序号。

    #### 4. 响应示例
    ```json
    {
        "thread_id": "proj_001",
        "history": [
            {
                "checkpoint_id": "d82f3a12",
                "last_message": "我想设计一款北欧风书架",
                "node": "Supervisor",
                "timestamp": 1
            },
            {
                "checkpoint_id": "f4k92m1a",
                "last_message": "建议使用浅色橡木材质,增加简约感...",
                "node": "Designer",
                "timestamp": 2
            }
        ]
    }
    ```
    """
    # Same per-thread workspace layout as the streaming endpoint. thread_id is
    # client-supplied, so refuse values (".", "..") that resolve outside agent_workspace.
    workspace_root = os.path.join(PROJECT_ROOT, "agent_workspace")
    workspace_dir = os.path.join(workspace_root, thread_id)
    if os.path.dirname(os.path.normpath(workspace_dir)) != os.path.normpath(workspace_root):
        raise HTTPException(status_code=400, detail="invalid thread_id")
    main_agent = build_main_agent(workspace_dir, enable_thinking=False)

    config = {"configurable": {"thread_id": thread_id}}
    history_data = []
    async for state in main_agent.aget_state_history(config):
        msg_content = "Initial"
        values = state.values
        if values and "messages" in values and values["messages"]:
            last_msg = values["messages"][-1]
            # The full content is returned (no truncation). It may be a plain string or a
            # list of content parts, since user messages are sent as multimodal lists.
            msg_content = getattr(last_msg, "content", str(last_msg))
        # Snapshot metadata can be missing; treat it as empty rather than crashing.
        metadata = state.metadata or {}
        history_data.append(HistoryItem(
            checkpoint_id=state.config["configurable"]["checkpoint_id"],
            last_message=msg_content,
            node=metadata.get("source"),
            timestamp=metadata.get("step"),
        ))
    return HistoryResponse(thread_id=thread_id, history=history_data)
async def get_checkpoint_id(main_agent, current_config):
    """Return the id of the first history snapshot whose next step is ``__start__``.

    Snapshots are consumed lazily from ``main_agent.aget_state_history`` and the
    scan stops at the first match, so the rest of the history is never fetched.

    Returns:
        The matching snapshot's ``checkpoint_id``, or ``None`` if nothing matches.
    """
    async for snapshot in main_agent.aget_state_history(config=current_config):
        if snapshot.next != ("__start__",):
            continue
        configurable = snapshot.config.get('configurable', {})
        return configurable.get('checkpoint_id')
    return None
async def get_branch_checkpoint_id(main_agent, current_config):
    """Find the ``__start__`` snapshot to record as the branch's previous checkpoint.

    Walks the history reachable from ``current_config`` and returns the id of the
    first snapshot that:

    * has ``next == ("__start__",)``,
    * is not the checkpoint named in ``current_config`` itself, and
    * is not a direct child of that checkpoint (its parent id differs, or it has
      no parent at all).

    Returns ``None`` if no snapshot qualifies.
    """
    branch_from = current_config.get('configurable', {}).get('checkpoint_id')
    async for snapshot in main_agent.aget_state_history(config=current_config):
        if snapshot.next != ("__start__",):
            continue
        snapshot_id = snapshot.config.get('configurable', {}).get('checkpoint_id')
        if snapshot_id == branch_from:
            continue
        parent = snapshot.parent_config
        if parent and parent.get('configurable', {}).get('checkpoint_id') == branch_from:
            continue
        return snapshot_id
    return None
def get_latest_checkpoint_id(agent, config):
    """Return the checkpoint id of the current state of ``config``'s thread.

    Tries ``agent.get_state(config)`` first. If that snapshot carries no
    ``checkpoint_id``, falls back to the first entry of
    ``agent.get_state_history(config)`` (the newest one). Only that first entry
    is pulled, instead of materializing the whole history.

    This uses the synchronous graph API.

    Returns:
        The checkpoint id, or ``None`` when the thread has no history.
    """
    state = agent.get_state(config)
    checkpoint_id = state.config.get("configurable", {}).get("checkpoint_id")
    if checkpoint_id:
        return checkpoint_id

    logger.info("checkpoint_id is None; falling back to get_state_history")
    latest = next(iter(agent.get_state_history(config)), None)
    if latest is None:
        return None
    checkpoint_id = latest.config["configurable"]["checkpoint_id"]
    logger.info("latest checkpoint_id from history: %s", checkpoint_id)
    return checkpoint_id