diff --git a/main.py b/main.py index f17207b..c938c96 100644 --- a/main.py +++ b/main.py @@ -8,6 +8,7 @@ from logging_env import LOGGER_CONFIG_DICT from src.routers import deep_agent_chat from src.routers import generate_3D from src.routers import flux2_gen_img +from src.routers import seg_furniture logging.config.dictConfig(LOGGER_CONFIG_DICT) @@ -29,6 +30,7 @@ app_server.add_middleware( app_server.include_router(deep_agent_chat.router) app_server.include_router(generate_3D.router) app_server.include_router(flux2_gen_img.router) +app_server.include_router(seg_furniture.router) @app_server.get("/") diff --git a/src/core/config.py b/src/core/config.py index 12af83d..2cb7a2d 100644 --- a/src/core/config.py +++ b/src/core/config.py @@ -37,6 +37,7 @@ class Settings(BaseSettings): # --- 本地服务器配置信息 --- IMAGE_TO_3D_MODEL_URL: str = Field(default='', description="") FLUX2_GEN_IMG_MODEL_URL: str = Field(default='', description="") + SEG_ANYTHING: str = Field(default='', description="") # --- 外部工具api配置信息 --- TAVILY_API_KEY: str = Field(default="", description="") diff --git a/src/routers/chat.py b/src/routers/chat.py deleted file mode 100644 index e4a72bf..0000000 --- a/src/routers/chat.py +++ /dev/null @@ -1,374 +0,0 @@ -import logging -import uuid -import json -from typing import AsyncGenerator - -from fastapi import APIRouter -from fastapi.responses import StreamingResponse -from src.schemas.chat import ChatRequest, HistoryResponse, HistoryItem -from src.server.agent.graph import app # 导入已经 compile 好的 graph -from langchain_core.messages import HumanMessage, SystemMessage, AIMessageChunk, ToolMessage, AIMessage - -router = APIRouter(prefix="/chat", tags=["Furniture Design Chat"]) -logger = logging.getLogger(__name__) - - -@router.post("/stream") -async def chat_stream(request: ChatRequest): - """ - ### 家具设计流式对话接口 (SSE) - - 通过此接口与 AI 家具设计专家团队进行实时沟通。支持 **记忆持久化** 和 **历史回溯分叉**。 - - #### 1. 
核心功能 - * **实时反馈**: 采用 Server-Sent Events (SSE) 技术,实时推送主管、设计师、视觉专家等节点的思考过程。 - * **上下文记忆**: 传入 `thread_id` 即可恢复之前的对话进度。 - * **版本分溯**: 传入 `checkpoint_id` 可准确定位到历史中的某一轮,并从该点开启新的设计分支。 - - #### 2. 请求参数 - * `message`: 用户的设计意图(如:'我想设计一个极简风格的橡木办公桌')。 - * `thread_id`: (可选) 现有项目的唯一标识。若不传,系统将自动分配并返回。 - * `checkpoint_id`: (可选) 历史快照 ID。 - * `config_params`: (可选) 对话配置参数 - * `need_suggestion`: (可选) 是否需要建议按钮,需要建议的频率,0-1的浮点数 - * `use_report`: (可选) 是否需要使用report功能 true/false - - - #### 3. 响应流说明 (Data Format) - 响应以 `data: ` 开头的 JSON 字符串流形式发送: - - **Session Start**: `{"thread_id": "...", "status": "start"}` - - **Node Message**: `{"node": "Designer", "content": "...", "checkpoint_id": "..."}` - - **Session End**: `{"status": "end"}` - - - **is_delta**: False/True,表示这个消息不是完整内容,只是 AI 正在生成的一小段内容(一个字、一个词、一句话),需要前端把这些片段拼接起来才能得到完整的回答。 - - #### 4. 请求示例 - ``` - { - "message": "设计一款北欧风格的躺椅." - } - - { - "message": "就以上信息直接生成sketch.", - "thread_id": "187e58af" - } - - { - "message": "不要躺椅,要桌子", - "thread_id": "187e58af", - "checkpoint_id": "1f101aa2-8f24-6e2a-8001-2952c3a7447a" - } - ``` - - ### 5. 
响应流说明 - 所有响应均以 data: 开头,JSON 字符串格式,末尾以 \n\n 结束 - 响应流包含三种类型的事件:会话开始、节点消息、会话结束 - 会话开始: - ``` - { - "thread_id": "str", - "is_branch": "boolean", - "status": "start" - } - ``` - 节点消息: - ``` - { - "node": "节点名称(如Designer/Researcher/Main)", - "content": "消息内容", - "checkpoint_id": "快照ID", - "is_delta": "boolean", - "type": "消息类型", - "suggestions": "建议列表(可选)", - "tool_name": "工具名称(可选)", - "tool_call_chunk": "工具调用片段(可选)", - "tool_call_id": "工具调用ID(可选)" - } - - ``` - 报告增量消息: - ``` - { - "node": "Researcher", - "type": "report_delta", - "content": "报告内容增量", - "is_delta": true, - "checkpoint_id": "xxx" - } - ``` - AI 消息片段: - ``` - { - "node": "Designer", - "content": "设计建议内容", - "checkpoint_id": "xxx", - "is_delta": true, - "type": "delta", - "tool_call_chunk": {...} - } - ``` - 工具执行结果: - ``` - { - "node": "ToolExecutor", - "content": "工具执行结果", - "checkpoint_id": "xxx", - "is_delta": false, - "type": "tool_result", - "tool_name": "ImageGenerator", - "tool_call_id": "yyy" - } - ``` - - """ - logger.debug(f"chat request data: {request}") - source_thread_id = request.thread_id - checkpoint_id = request.checkpoint_id - - # 1. 確定目標 thread_id - is_branching = source_thread_id and checkpoint_id - target_thread_id = str(uuid.uuid4())[:8] if is_branching else (source_thread_id or str(uuid.uuid4())[:8]) - - # 2. 配置參數 - temp = request.config_params.temperature if request.config_params else 0.7 - current_config = { - "recursion_limit": 100, - "configurable": { - "thread_id": target_thread_id, - "llm_temperature": temp, - "use_report": request.use_report, - } - } - - # 3. 
初始化消息 + 系統提示 - initial_messages = [] - if not source_thread_id or is_branching: - if request.config_params: - cp = request.config_params - system_prompt = ( - f"Current furniture design background settings:\n" - f"- type: {cp.type}\n" - f"- space/region: {cp.region}\n" - f"- style tendency: {cp.style}\n" - f"Please strictly follow the above settings in subsequent conversations。" - ) - initial_messages.append(SystemMessage(content=system_prompt)) - - # 4. 處理分支(從歷史 checkpoint 複製狀態) - if is_branching: - source_config = { - "configurable": { - "thread_id": source_thread_id, - "checkpoint_id": checkpoint_id - } - } - older_state = await app.aget_state(source_config) - combined_values = older_state.values.copy() - if initial_messages: - combined_values["messages"] = list(combined_values.get("messages", [])) + initial_messages - await app.aupdate_state(current_config, combined_values) - - async def event_generator() -> AsyncGenerator[str, None]: - # 初始事件 - yield f"data: {json.dumps({'thread_id': target_thread_id, 'is_branch': is_branching, 'status': 'start'}, ensure_ascii=False)}\n\n" - - # 構造輸入(保持不變) - new_messages = initial_messages[:] if not source_thread_id else [] - new_messages.append(HumanMessage(content=request.message)) - - input_data = { - "messages": new_messages, - "require_suggestion": request.need_suggestion, - "use_report": request.use_report, - } - - # ─── 重點改這裡 ─────────────────────────────────────── - async for event in app.astream( - input_data, - config=current_config, - stream_mode=["custom", "updates", "messages"], # 推薦組合 - subgraphs=True - # 不再需要,行為已包含 - ): - logger.info(event) - # 取得 checkpoint_id(可選,視前端是否真的需要) - latest_state = await app.aget_state(current_config) - configurable = latest_state.config.get("configurable", {}) - current_cp_id = configurable.get("checkpoint_id", "") - if len(event) == 3: - namespace, channel, payload = event - # 路由更新 - if event[1] == "updates": - namespace, _, payload = event - - if isinstance(payload, dict): - for 
update_node, update_content in payload.items(): - - # 处理 reducer(Overwrite / Append) - if isinstance(update_content, dict): - for k, v in update_content.items(): - if hasattr(v, "value"): # Overwrite(...) - update_content[k] = v.value - - if isinstance(update_content, dict) and "messages" in update_content: - msgs = [] - for m in update_content["messages"]: - msgs.append({ - "type": m.__class__.__name__, - "content": getattr(m, "content", ""), - "name": getattr(m, "name", None), - "tool_calls": getattr(m, "tool_calls", None), - }) - update_content["messages"] = msgs - - yield f"data: {json.dumps({ - "node": "Supervisor", - "type": "updates", - "content": update_content, - "is_delta": False, - "checkpoint_id": current_cp_id, - }, ensure_ascii=False)}\n\n" - - # 自定义事件 - elif event[1] == "custom": - if isinstance(payload, dict) and payload.get("type") in ("report_delta", "report_start", "report_error", "report_save_warning", "report_complete"): - delta = payload.get("delta", "").strip() - if delta: - yield f"data: {json.dumps({ - 'node': 'Researcher', - 'type': 'report_delta', - 'content': delta, - 'is_delta': True, - 'checkpoint_id': current_cp_id, - }, ensure_ascii=False)}\n\n" - # 基础消息 - elif event[1] == "messages": - if namespace: - node_name = namespace[-1] if isinstance(namespace, tuple) else namespace - if ':' in node_name: - node_name = node_name.split(':')[0] - else: - node_name = "Main" - message, metadata = payload - is_not_research = node_name != 'Researcher' - node_name = metadata.get("langgraph_node", node_name) - # 3. 
处理不同类型的 message - payload_out = { - "node": node_name, - "checkpoint_id": current_cp_id, # 你之前已经获取了 - "is_delta": False, - "content": "", - "suggestions": [], - "type": "unknown" - } - - if isinstance(message, AIMessageChunk): - # 节点不是research 并且 tool_call_chunks不为空的情况下,避免research的report工具使用custom发出的消息和message的消息重复了 - if is_not_research and node_name != 'Researcher' and message.tool_call_chunks: - payload_out.update({ - "type": "delta", - "is_delta": True, - "content": message.content, - # 如果有 tool call chunk,也可以在这里处理 - "tool_call_chunk": message.tool_call_chunks[0] if message.tool_call_chunks else None - }) - yield f"data: {json.dumps(payload_out, ensure_ascii=False)}\n\n" - elif isinstance(message, ToolMessage): - # 工具执行结果(完整的一次性输出) - payload_out.update({ - "type": "tool_result", - "is_delta": False, - "content": message.content, - "tool_name": message.name, - "tool_call_id": message.tool_call_id - }) - yield f"data: {json.dumps(payload_out, ensure_ascii=False)}\n\n" - - elif isinstance(message, AIMessage): - # 完整 AIMessage(不常见在 messages 模式下,但以防万一) - payload_out.update({ - "type": "complete_message", - "is_delta": False, - "content": message.content - }) - yield f"data: {json.dumps(payload_out, ensure_ascii=False)}\n\n" - - else: - # 其他未知类型,记录日志 - print(f"未知消息类型: {type(message)}", message) - continue - - # 流結束 - yield f"data: {json.dumps({'status': 'end'}, ensure_ascii=False)}\n\n" - - return StreamingResponse(event_generator(), media_type="text/event-stream") - - -@router.get("/history/{thread_id}", response_model=HistoryResponse) -async def get_chat_history(thread_id: str): - """ - ### 获取项目设计历史记录 - - 此接口用于拉取指定 `thread_id` 下的所有历史状态快照。它是实现 **“版本回溯”** 和 **“方案对比”** 的核心数据来源。 - - #### 1. 功能说明 - * **快照列表**: 返回该项目从启动至今的所有关键节点(Checkpoints)。 - * **版本定位**: 每个历史点都包含一个唯一的 `checkpoint_id`。 - * **数据回溯**: 客户端获取此列表后,可以引导用户选择任意一个版本,并将其 `checkpoint_id` 传回 `/chat/stream` 接口以开启新的设计分支。 - - #### 2. 路径参数 - * `thread_id`: 设计项目的唯一标识符(由 `/chat/stream` 首次调用时生成或指定)。 - - #### 3. 
# (connect, read) timeout in seconds for the segmentation backend. Inference can
# be slow, so the read budget is generous, while an unreachable host fails fast.
# NOTE(review): values are a conservative guess - tune to the real model latency.
_SEG_ANYTHING_TIMEOUT = (5, 120)


@router.post("/seg_anything")
async def seg_anything(request_data: SAMRequestModel):
    """
    **Segment Anything interactive segmentation endpoint**

    Forwards the request body to the segmentation backend configured in
    `settings.SEG_ANYTHING` (`POST http://<SEG_ANYTHING>/predict`) and returns
    the backend's JSON answer wrapped in `ResponseModel`.

    ### Parameters
    - **user_id**: user id, used to store the segmentation result.
    - **image_path**: relative path of the image on the server / object storage.
    - **type**: inference type (`"point"` or `"box"` in the examples below).
    - **box**: rectangle selection as a list of numbers.
    - **points**: list of interaction points, each one `[x, y]` in pixels.
    - **labels**: one label per point, must have the same length as `points`:
        - 1: **foreground point** (region that should be segmented)
        - 0: **background point** (region that should be excluded)

    ### Request body example (`type = "point"`)
    ```json
    {
        "user_id": 1,
        "image_path": "aida-users/89/sketch/4e8fe37d-7068-400a-ac94-c01647fa5f6f.png",
        "type": "point",
        "points": [[310, 403], [493, 375], [261, 266], [404, 484]],
        "labels": [1, 1, 0, 1]
    }
    ```

    ### Request body example (`type = "box"`)
    ```json
    {
        "user_id": 1,
        "image_path": "aida-users/89/sketch/4e8fe37d-7068-400a-ac94-c01647fa5f6f.png",
        "type": "box",
        "box": [350, 286, 544, 520]
    }
    ```

    ### Errors
    - **503**: the segmentation backend address is not configured.
    - **504**: the segmentation backend did not answer within the timeout.
    - **502**: the backend was unreachable, answered with an HTTP error status,
      or returned a body that is not valid JSON.
    """
    # Function-scope imports keep this block self-contained: the module's import
    # header only provides json, logging, requests and APIRouter.
    import asyncio

    from fastapi import HTTPException

    # An empty host would silently build the malformed URL "http:///predict".
    if not settings.SEG_ANYTHING:
        logger.error("seg_anything Run Exception @@@@@@:SEG_ANYTHING is not configured")
        raise HTTPException(status_code=503, detail="Segmentation backend is not configured")

    # Serialize the request once and reuse it for both the log line and the call.
    payload = request_data.dict()
    logger.info("seg_anything request item is : @@@@@@:%s", json.dumps(payload, indent=4))

    try:
        # `requests` is synchronous; running it in a worker thread keeps the
        # event loop free to serve other requests while the model is inferring.
        response = await asyncio.to_thread(
            requests.post,
            f"http://{settings.SEG_ANYTHING}/predict",
            json=payload,
            timeout=_SEG_ANYTHING_TIMEOUT,
        )
        # Surface 4xx/5xx answers instead of trying to parse an error page.
        response.raise_for_status()
        result = response.json()
    except requests.Timeout as exc:
        logger.exception("seg_anything Run Exception @@@@@@:%s", exc)
        raise HTTPException(status_code=504, detail="Segmentation backend timed out") from exc
    except (requests.RequestException, ValueError) as exc:
        # ValueError covers JSON decoding failures on older `requests` releases,
        # where the decode error is not a RequestException subclass.
        logger.exception("seg_anything Run Exception @@@@@@:%s", exc)
        raise HTTPException(status_code=502, detail="Segmentation backend request failed") from exc

    # NOTE(review): the full backend answer is logged at INFO as before; if it
    # carries mask data this line can get very large - consider truncating.
    logger.info("seg_anything response @@@@@@:%s", json.dumps(result, indent=4))
    return ResponseModel(data=result)
- 你的输出必须**简短**。 ======================== + + 图片生成工具只会返回 MinIO 内部路径(如 "bucket/folder/xxx.png")。 + 重要规则: + - 当工具返回包含 "MinIO path:" 或类似 "test/a/v/xxx.png" 的内容时,你必须理解这是一张图片。 + - 这张图片**不能直接用于多模态输入**,也不能直接发给用户查看。 + - 任何时候如果你(或用户)需要“看”这张图片、描述这张图片、或把图片作为 vision 模型的输入,你**必须先调用 get_presigned_image_url 工具**,传入该 MinIO path,获取 presigned http URL。 + - 之后用这个 presigned URL 进行多模态调用(例如把 URL 传给支持 image_url 的模型)。 + - 在对话历史中,优先记住 MinIO path,并在需要时主动转换。 + - 用户说“描述一下这张图”时,你应该先获取 presigned URL,再调用 vision 工具/模型。 + + 永远不要假设 MinIO path 是可直接访问的 http 地址。 """ return system_prompt