feat 1.增加 建议词 机制 2.对话生图实现

This commit is contained in:
zcr
2026-02-06 11:55:11 +08:00
parent ec195d17e1
commit 3248c45cd4
12 changed files with 655 additions and 85 deletions

View File

@@ -4,7 +4,7 @@ from fastapi import APIRouter
from fastapi.responses import StreamingResponse
from src.schemas.chat import ChatRequest, HistoryResponse, HistoryItem
from src.server.agent.graph import app # 导入已经 compile 好的 graph
from langchain_core.messages import HumanMessage
from langchain_core.messages import HumanMessage, SystemMessage
router = APIRouter(prefix="/chat", tags=["Furniture Design Chat"])
@@ -57,38 +57,114 @@ async def chat_stream(request: ChatRequest):
# 如果是回溯操作,我们生成一个新的 ID或者由前端传入一个新的 target_thread_id
is_branching = source_thread_id and checkpoint_id
target_thread_id = str(uuid.uuid4())[:8] if is_branching else (source_thread_id or str(uuid.uuid4())[:8])
# 2. 获取配置参数
temp = request.config_params.temperature if request.config_params else 0.7
# 2. 如果是分叉请求,我们需要先“搬家”状态
# 构建基础 Config
current_config = {
"configurable": {
"thread_id": target_thread_id,
"llm_temperature": temp
}
}
# 3. 处理状态初始化与分支
initial_messages = []
# 如果是全新的对话(没有 source_thread_id或者明确要求分叉
if not source_thread_id or is_branching:
# 如果用户传了标签,构造 SystemMessage 注入上下文
if request.config_params:
cp = request.config_params
system_prompt = (
f"Current furniture design background settings\n"
f"- type: {cp.type}\n"
f"- space/region: {cp.region}\n"
f"- style tendency: {cp.style}\n"
f"Please strictly follow the above settings in subsequent conversations。"
)
initial_messages.append(SystemMessage(content=system_prompt))
# 4. 执行分叉逻辑(搬运旧数据)
if is_branching:
# 获取旧状态
source_config = {"configurable": {"thread_id": source_thread_id, "checkpoint_id": checkpoint_id}}
source_config = {
"configurable": {
"thread_id": source_thread_id,
"checkpoint_id": checkpoint_id
}
}
older_state = await app.aget_state(source_config)
# 将旧状态的消息,作为新 thread 的初始值注入
# 注意:这里我们手动把旧消息塞给新 thread
new_config = {"configurable": {"thread_id": target_thread_id}}
await app.aupdate_state(new_config, older_state.values)
# 将旧消息和我们新定义的 SystemMessage 合并
# update_state 会将这些消息推送到新 thread 的存储中
combined_values = older_state.values.copy()
if initial_messages:
combined_values["messages"] = list(combined_values["messages"]) + initial_messages
# 现在的 config 指向新 Thread
current_config = new_config
else:
current_config = {"configurable": {"thread_id": target_thread_id}}
await app.aupdate_state(current_config, combined_values)
async def event_generator():
# 告诉前端:现在是在哪个 Thread 上工作(如果是分叉,前端需要更新本地存储的 ID
yield f"data: {json.dumps({'thread_id': target_thread_id, 'is_branch': is_branching}, ensure_ascii=False)}\n\n"
# 初始推送状态信息
yield f"data: {json.dumps({'thread_id': target_thread_id, 'is_branch': is_branching, 'status': 'start'}, ensure_ascii=False)}\n\n"
# 构造本次请求的输入
# 如果是第一次开始,且有 initial_messages则连同 user message 一起发送
# --- 核心逻辑:构造本次请求的消息列表 ---
new_messages = []
if not source_thread_id and initial_messages:
new_messages.extend(initial_messages)
# 添加用户消息
new_messages.append(HumanMessage(content=request.message))
# --- 新增:强制绘图指令注入 ---
# if request.force_sketch:
# force_instruction = HumanMessage(
# content="[SYSTEM_DIRECTIVE]: 用户点击了强制生成按钮。请立即根据当前上下文调用 generate_furniture_sketch 工具生成草图,无需确认。"
# )
# new_messages.append(force_instruction)
input_data = {"messages": new_messages}
async for event in app.astream(
{"messages": [HumanMessage(content=request.message)]},
input_data,
current_config,
stream_mode="updates"
):
# ... 发送流式内容的逻辑保持不变 ...
for node_name, output in event.items():
if "messages" in output:
msg = output["messages"][-1]
# 获取最新 state 以获取 checkpoint_id
state = await app.aget_state(current_config)
yield f"data: {json.dumps({'node': node_name, 'content': msg.content, 'checkpoint_id': state.config['configurable']['checkpoint_id']}, ensure_ascii=False)}\n\n"
current_cp_id = state.config["configurable"].get("checkpoint_id")
# 遍历本次 update 产生的所有消息
for msg in output["messages"]:
payload = {
"node": node_name,
"content": "",
"image_url": None,
"checkpoint_id": current_cp_id,
"suggestions": []
}
# --- 核心改动:提取建议按钮 ---
# 无论是不是 Suggester 节点,只要消息里带了建议就提取
if hasattr(msg, "additional_kwargs") and "suggestions" in msg.additional_kwargs:
payload["suggestions"] = msg.additional_kwargs["suggestions"]
content = msg.content
# 逻辑判断MinIO 图片处理
if node_name == "Visualizer" and str(content).endswith("png") and "furniture/sketches" in str(content):
payload["image_url"] = content
payload["content"] = "已为您生成设计草图"
else:
payload["content"] = content
# 如果消息既没有文本、也没有图片、也没有建议(比如中间的 ToolCall 消息),则跳过
if not payload["content"] and not payload["image_url"] and not payload["suggestions"]:
continue
yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
yield f"data: {json.dumps({'status': 'end'}, ensure_ascii=False)}\n\n"
return StreamingResponse(event_generator(), media_type="text/event-stream")
@@ -147,11 +223,11 @@ async def get_chat_history(thread_id: str):
last_msg = msgs[-1]
# 获取内容并做摘要截断
content = getattr(last_msg, "content", str(last_msg))
msg_content = content[:50] + ("..." if len(content) > 50 else "")
msg_content = content
history_data.append(HistoryItem(
checkpoint_id=state.config["configurable"]["checkpoint_id"],
last_message=msg_content[:50],
last_message=msg_content,
node=state.metadata.get("source"),
timestamp=state.metadata.get("step")
))