新增家具分割接口

This commit is contained in:
zcr
2026-03-27 14:41:13 +08:00
parent 1c672afd2d
commit d9acdf593d
6 changed files with 84 additions and 374 deletions

View File

@@ -1,374 +0,0 @@
import logging
import uuid
import json
from typing import AsyncGenerator
from fastapi import APIRouter
from fastapi.responses import StreamingResponse
from src.schemas.chat import ChatRequest, HistoryResponse, HistoryItem
from src.server.agent.graph import app # 导入已经 compile 好的 graph
from langchain_core.messages import HumanMessage, SystemMessage, AIMessageChunk, ToolMessage, AIMessage
router = APIRouter(prefix="/chat", tags=["Furniture Design Chat"])
logger = logging.getLogger(__name__)
@router.post("/stream")
async def chat_stream(request: ChatRequest):
"""
### 家具设计流式对话接口 (SSE)
通过此接口与 AI 家具设计专家团队进行实时沟通。支持 **记忆持久化** 和 **历史回溯分叉**。
#### 1. 核心功能
* **实时反馈**: 采用 Server-Sent Events (SSE) 技术,实时推送主管、设计师、视觉专家等节点的思考过程。
* **上下文记忆**: 传入 `thread_id` 即可恢复之前的对话进度。
* **版本分溯**: 传入 `checkpoint_id` 可准确定位到历史中的某一轮,并从该点开启新的设计分支。
#### 2. 请求参数
* `message`: 用户的设计意图(如:'我想设计一个极简风格的橡木办公桌')。
* `thread_id`: (可选) 现有项目的唯一标识。若不传,系统将自动分配并返回。
* `checkpoint_id`: (可选) 历史快照 ID。
* `config_params`: (可选) 对话配置参数
* `need_suggestion`: (可选) 是否需要建议按钮,需要建议的频率0-1的浮点数
* `use_report`: (可选) 是否需要使用report功能 true/false
#### 3. 响应流说明 (Data Format)
响应以 `data: ` 开头的 JSON 字符串流形式发送:
- **Session Start**: `{"thread_id": "...", "status": "start"}`
- **Node Message**: `{"node": "Designer", "content": "...", "checkpoint_id": "..."}`
- **Session End**: `{"status": "end"}`
- **is_delta**: False/True表示这个消息不是完整内容只是 AI 正在生成的一小段内容(一个字、一个词、一句话),需要前端把这些片段拼接起来才能得到完整的回答。
#### 4. 请求示例
```
{
"message": "设计一款北欧风格的躺椅."
}
{
"message": "就以上信息直接生成sketch.",
"thread_id": "187e58af"
}
{
"message": "不要躺椅,要桌子",
"thread_id": "187e58af",
"checkpoint_id": "1f101aa2-8f24-6e2a-8001-2952c3a7447a"
}
```
### 5. 响应流说明
所有响应均以 data: 开头JSON 字符串格式,末尾以 \n\n 结束
响应流包含三种类型的事件:会话开始、节点消息、会话结束
会话开始:
```
{
"thread_id": "str",
"is_branch": "boolean",
"status": "start"
}
```
节点消息:
```
{
"node": "节点名称如Designer/Researcher/Main",
"content": "消息内容",
"checkpoint_id": "快照ID",
"is_delta": "boolean",
"type": "消息类型",
"suggestions": "建议列表(可选)",
"tool_name": "工具名称(可选)",
"tool_call_chunk": "工具调用片段(可选)",
"tool_call_id": "工具调用ID可选"
}
```
报告增量消息:
```
{
"node": "Researcher",
"type": "report_delta",
"content": "报告内容增量",
"is_delta": true,
"checkpoint_id": "xxx"
}
```
AI 消息片段:
```
{
"node": "Designer",
"content": "设计建议内容",
"checkpoint_id": "xxx",
"is_delta": true,
"type": "delta",
"tool_call_chunk": {...}
}
```
工具执行结果:
```
{
"node": "ToolExecutor",
"content": "工具执行结果",
"checkpoint_id": "xxx",
"is_delta": false,
"type": "tool_result",
"tool_name": "ImageGenerator",
"tool_call_id": "yyy"
}
```
"""
logger.debug(f"chat request data: {request}")
source_thread_id = request.thread_id
checkpoint_id = request.checkpoint_id
# 1. 確定目標 thread_id
is_branching = source_thread_id and checkpoint_id
target_thread_id = str(uuid.uuid4())[:8] if is_branching else (source_thread_id or str(uuid.uuid4())[:8])
# 2. 配置參數
temp = request.config_params.temperature if request.config_params else 0.7
current_config = {
"recursion_limit": 100,
"configurable": {
"thread_id": target_thread_id,
"llm_temperature": temp,
"use_report": request.use_report,
}
}
# 3. 初始化消息 + 系統提示
initial_messages = []
if not source_thread_id or is_branching:
if request.config_params:
cp = request.config_params
system_prompt = (
f"Current furniture design background settings\n"
f"- type: {cp.type}\n"
f"- space/region: {cp.region}\n"
f"- style tendency: {cp.style}\n"
f"Please strictly follow the above settings in subsequent conversations。"
)
initial_messages.append(SystemMessage(content=system_prompt))
# 4. 處理分支(從歷史 checkpoint 複製狀態)
if is_branching:
source_config = {
"configurable": {
"thread_id": source_thread_id,
"checkpoint_id": checkpoint_id
}
}
older_state = await app.aget_state(source_config)
combined_values = older_state.values.copy()
if initial_messages:
combined_values["messages"] = list(combined_values.get("messages", [])) + initial_messages
await app.aupdate_state(current_config, combined_values)
async def event_generator() -> AsyncGenerator[str, None]:
# 初始事件
yield f"data: {json.dumps({'thread_id': target_thread_id, 'is_branch': is_branching, 'status': 'start'}, ensure_ascii=False)}\n\n"
# 構造輸入(保持不變)
new_messages = initial_messages[:] if not source_thread_id else []
new_messages.append(HumanMessage(content=request.message))
input_data = {
"messages": new_messages,
"require_suggestion": request.need_suggestion,
"use_report": request.use_report,
}
# ─── 重點改這裡 ───────────────────────────────────────
async for event in app.astream(
input_data,
config=current_config,
stream_mode=["custom", "updates", "messages"], # 推薦組合
subgraphs=True
# 不再需要,行為已包含
):
logger.info(event)
# 取得 checkpoint_id可選視前端是否真的需要
latest_state = await app.aget_state(current_config)
configurable = latest_state.config.get("configurable", {})
current_cp_id = configurable.get("checkpoint_id", "")
if len(event) == 3:
namespace, channel, payload = event
# 路由更新
if event[1] == "updates":
namespace, _, payload = event
if isinstance(payload, dict):
for update_node, update_content in payload.items():
# 处理 reducerOverwrite / Append
if isinstance(update_content, dict):
for k, v in update_content.items():
if hasattr(v, "value"): # Overwrite(...)
update_content[k] = v.value
if isinstance(update_content, dict) and "messages" in update_content:
msgs = []
for m in update_content["messages"]:
msgs.append({
"type": m.__class__.__name__,
"content": getattr(m, "content", ""),
"name": getattr(m, "name", None),
"tool_calls": getattr(m, "tool_calls", None),
})
update_content["messages"] = msgs
yield f"data: {json.dumps({
"node": "Supervisor",
"type": "updates",
"content": update_content,
"is_delta": False,
"checkpoint_id": current_cp_id,
}, ensure_ascii=False)}\n\n"
# 自定义事件
elif event[1] == "custom":
if isinstance(payload, dict) and payload.get("type") in ("report_delta", "report_start", "report_error", "report_save_warning", "report_complete"):
delta = payload.get("delta", "").strip()
if delta:
yield f"data: {json.dumps({
'node': 'Researcher',
'type': 'report_delta',
'content': delta,
'is_delta': True,
'checkpoint_id': current_cp_id,
}, ensure_ascii=False)}\n\n"
# 基础消息
elif event[1] == "messages":
if namespace:
node_name = namespace[-1] if isinstance(namespace, tuple) else namespace
if ':' in node_name:
node_name = node_name.split(':')[0]
else:
node_name = "Main"
message, metadata = payload
is_not_research = node_name != 'Researcher'
node_name = metadata.get("langgraph_node", node_name)
# 3. 处理不同类型的 message
payload_out = {
"node": node_name,
"checkpoint_id": current_cp_id, # 你之前已经获取了
"is_delta": False,
"content": "",
"suggestions": [],
"type": "unknown"
}
if isinstance(message, AIMessageChunk):
# 节点不是research 并且 tool_call_chunks不为空的情况下避免research的report工具使用custom发出的消息和message的消息重复了
if is_not_research and node_name != 'Researcher' and message.tool_call_chunks:
payload_out.update({
"type": "delta",
"is_delta": True,
"content": message.content,
# 如果有 tool call chunk也可以在这里处理
"tool_call_chunk": message.tool_call_chunks[0] if message.tool_call_chunks else None
})
yield f"data: {json.dumps(payload_out, ensure_ascii=False)}\n\n"
elif isinstance(message, ToolMessage):
# 工具执行结果(完整的一次性输出)
payload_out.update({
"type": "tool_result",
"is_delta": False,
"content": message.content,
"tool_name": message.name,
"tool_call_id": message.tool_call_id
})
yield f"data: {json.dumps(payload_out, ensure_ascii=False)}\n\n"
elif isinstance(message, AIMessage):
# 完整 AIMessage不常见在 messages 模式下,但以防万一)
payload_out.update({
"type": "complete_message",
"is_delta": False,
"content": message.content
})
yield f"data: {json.dumps(payload_out, ensure_ascii=False)}\n\n"
else:
# 其他未知类型,记录日志
print(f"未知消息类型: {type(message)}", message)
continue
# 流結束
yield f"data: {json.dumps({'status': 'end'}, ensure_ascii=False)}\n\n"
return StreamingResponse(event_generator(), media_type="text/event-stream")
@router.get("/history/{thread_id}", response_model=HistoryResponse)
async def get_chat_history(thread_id: str):
"""
### 获取项目设计历史记录
此接口用于拉取指定 `thread_id` 下的所有历史状态快照。它是实现 **“版本回溯”** 和 **“方案对比”** 的核心数据来源。
#### 1. 功能说明
* **快照列表**: 返回该项目从启动至今的所有关键节点Checkpoints
* **版本定位**: 每个历史点都包含一个唯一的 `checkpoint_id`。
* **数据回溯**: 客户端获取此列表后,可以引导用户选择任意一个版本,并将其 `checkpoint_id` 传回 `/chat/stream` 接口以开启新的设计分支。
#### 2. 路径参数
* `thread_id`: 设计项目的唯一标识符(由 `/chat/stream` 首次调用时生成或指定)。
#### 3. 返回字段定义
* `thread_id`: 当前查询的项目ID。
* `history`: 历史记录数组,包含:
- `checkpoint_id`: 必填,回溯时使用的关键凭证。
- `last_message`: 该阶段的最后一条消息摘要(方便前端预览)。
- `node`: 产生该快照的节点名称(如 Designer, Visualizer
- `timestamp`: 逻辑步骤序号。
#### 4. 响应示例
```json
{
"thread_id": "proj_001",
"history": [
{
"checkpoint_id": "d82f3a12",
"last_message": "我想设计一款北欧风书架",
"node": "Supervisor",
"timestamp": 1
},
{
"checkpoint_id": "f4k92m1a",
"last_message": "建议使用浅色橡木材质,增加简约感...",
"node": "Designer",
"timestamp": 2
}
]
}
```
"""
config = {"configurable": {"thread_id": thread_id}, }
history_data = []
async for state in app.aget_state_history(config):
msg_content = "Initial"
if state.values and "messages" in state.values:
msgs = state.values["messages"]
if msgs and len(msgs) > 0:
last_msg = msgs[-1]
# 获取内容并做摘要截断
content = getattr(last_msg, "content", str(last_msg))
msg_content = content
history_data.append(HistoryItem(
checkpoint_id=state.config["configurable"]["checkpoint_id"],
last_message=msg_content,
node=state.metadata.get("source"),
timestamp=state.metadata.get("step")
))
return HistoryResponse(thread_id=thread_id, history=history_data)
# try:
# except Exception as e:
# raise HTTPException(status_code=404, detail=f"History not found: {str(e)}")

View File

@@ -0,0 +1,58 @@
import json
import logging
import requests
from fastapi import APIRouter
from src.core.config import settings
from src.schemas.response_template import ResponseModel
from src.schemas.san_furniture import SAMRequestModel
router = APIRouter(prefix="/canvas", tags=["Furniture Canvas"])
logger = logging.getLogger(__name__)
@router.post("/seg_anything")
async def seg_anything(request_data: SAMRequestModel):
"""
**Segment Anything 交互式分割接口**
通过传入图片路径和点击的点坐标,返回分割后的掩码数据。
### 参数说明:
- **user_id**:用户id 用于存储分割图
- **image_path**: 图片在服务器或云端的相对路径。
- **type**: 推理类型
- **box**: 框选矩形点位信息
- **points**: 交互点的坐标列表。每个点为 [x, y] 像素格式。
- **labels**: 坐标点的属性标签,必须与 points 长度一致:
- 1: **前景点** (代表想要分割出的区域)
- 0: **背景点** (代表想要排除的区域)
### 请求体示例:
```json
point
{
"user_id": 1,
"image_path": "aida-users/89/sketch/4e8fe37d-7068-400a-ac94-c01647fa5f6f.png",
"type":"point",
"points": [[310, 403], [493, 375], [261, 266], [404, 484]],
"labels": [1, 1, 0, 1]
}
box
{
"user_id": 1,
"image_path": "aida-users/89/sketch/4e8fe37d-7068-400a-ac94-c01647fa5f6f.png",
"type":"box",
"box": [350, 286, 544, 520]
}
```
"""
try:
logger.info(f"seg_anything request item is : @@@@@@:{json.dumps(request_data.dict(), indent=4)}")
data = requests.post(f"http://{settings.SEG_ANYTHING}/predict", json=request_data.dict())
logger.info(f"seg_anything response @@@@@@:{json.dumps(json.loads(data.content), indent=4)}")
return ResponseModel(data=json.loads(data.content))
except Exception as e:
logger.warning(f"seg_anything Run Exception @@@@@@:{e}")