feat 1.增加 建议词 机制 2.对话生图实现
This commit is contained in:
@@ -4,7 +4,7 @@ from fastapi import APIRouter
|
||||
from fastapi.responses import StreamingResponse
|
||||
from src.schemas.chat import ChatRequest, HistoryResponse, HistoryItem
|
||||
from src.server.agent.graph import app # 导入已经 compile 好的 graph
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
|
||||
router = APIRouter(prefix="/chat", tags=["Furniture Design Chat"])
|
||||
|
||||
@@ -57,38 +57,114 @@ async def chat_stream(request: ChatRequest):
|
||||
# 如果是回溯操作,我们生成一个新的 ID,或者由前端传入一个新的 target_thread_id
|
||||
is_branching = source_thread_id and checkpoint_id
|
||||
target_thread_id = str(uuid.uuid4())[:8] if is_branching else (source_thread_id or str(uuid.uuid4())[:8])
|
||||
# 2. 获取配置参数
|
||||
temp = request.config_params.temperature if request.config_params else 0.7
|
||||
|
||||
# 2. 如果是分叉请求,我们需要先“搬家”状态
|
||||
# 构建基础 Config
|
||||
current_config = {
|
||||
"configurable": {
|
||||
"thread_id": target_thread_id,
|
||||
"llm_temperature": temp
|
||||
}
|
||||
}
|
||||
# 3. 处理状态初始化与分支
|
||||
initial_messages = []
|
||||
|
||||
# 如果是全新的对话(没有 source_thread_id),或者明确要求分叉
|
||||
if not source_thread_id or is_branching:
|
||||
# 如果用户传了标签,构造 SystemMessage 注入上下文
|
||||
if request.config_params:
|
||||
cp = request.config_params
|
||||
system_prompt = (
|
||||
f"Current furniture design background settings:\n"
|
||||
f"- type: {cp.type}\n"
|
||||
f"- space/region: {cp.region}\n"
|
||||
f"- style tendency: {cp.style}\n"
|
||||
f"Please strictly follow the above settings in subsequent conversations。"
|
||||
)
|
||||
initial_messages.append(SystemMessage(content=system_prompt))
|
||||
|
||||
# 4. 执行分叉逻辑(搬运旧数据)
|
||||
if is_branching:
|
||||
# 获取旧状态
|
||||
source_config = {"configurable": {"thread_id": source_thread_id, "checkpoint_id": checkpoint_id}}
|
||||
source_config = {
|
||||
"configurable": {
|
||||
"thread_id": source_thread_id,
|
||||
"checkpoint_id": checkpoint_id
|
||||
}
|
||||
}
|
||||
older_state = await app.aget_state(source_config)
|
||||
|
||||
# 将旧状态的消息,作为新 thread 的初始值注入
|
||||
# 注意:这里我们手动把旧消息塞给新 thread
|
||||
new_config = {"configurable": {"thread_id": target_thread_id}}
|
||||
await app.aupdate_state(new_config, older_state.values)
|
||||
# 将旧消息和我们新定义的 SystemMessage 合并
|
||||
# update_state 会将这些消息推送到新 thread 的存储中
|
||||
combined_values = older_state.values.copy()
|
||||
if initial_messages:
|
||||
combined_values["messages"] = list(combined_values["messages"]) + initial_messages
|
||||
|
||||
# 现在的 config 指向新 Thread
|
||||
current_config = new_config
|
||||
else:
|
||||
current_config = {"configurable": {"thread_id": target_thread_id}}
|
||||
await app.aupdate_state(current_config, combined_values)
|
||||
|
||||
async def event_generator():
|
||||
# 告诉前端:现在是在哪个 Thread 上工作(如果是分叉,前端需要更新本地存储的 ID)
|
||||
yield f"data: {json.dumps({'thread_id': target_thread_id, 'is_branch': is_branching}, ensure_ascii=False)}\n\n"
|
||||
# 初始推送状态信息
|
||||
yield f"data: {json.dumps({'thread_id': target_thread_id, 'is_branch': is_branching, 'status': 'start'}, ensure_ascii=False)}\n\n"
|
||||
|
||||
# 构造本次请求的输入
|
||||
# 如果是第一次开始,且有 initial_messages,则连同 user message 一起发送
|
||||
# --- 核心逻辑:构造本次请求的消息列表 ---
|
||||
new_messages = []
|
||||
if not source_thread_id and initial_messages:
|
||||
new_messages.extend(initial_messages)
|
||||
# 添加用户消息
|
||||
new_messages.append(HumanMessage(content=request.message))
|
||||
|
||||
# --- 新增:强制绘图指令注入 ---
|
||||
# if request.force_sketch:
|
||||
# force_instruction = HumanMessage(
|
||||
# content="[SYSTEM_DIRECTIVE]: 用户点击了强制生成按钮。请立即根据当前上下文调用 generate_furniture_sketch 工具生成草图,无需确认。"
|
||||
# )
|
||||
# new_messages.append(force_instruction)
|
||||
|
||||
input_data = {"messages": new_messages}
|
||||
|
||||
async for event in app.astream(
|
||||
{"messages": [HumanMessage(content=request.message)]},
|
||||
input_data,
|
||||
current_config,
|
||||
stream_mode="updates"
|
||||
):
|
||||
# ... 发送流式内容的逻辑保持不变 ...
|
||||
for node_name, output in event.items():
|
||||
if "messages" in output:
|
||||
msg = output["messages"][-1]
|
||||
# 获取最新 state 以获取 checkpoint_id
|
||||
state = await app.aget_state(current_config)
|
||||
yield f"data: {json.dumps({'node': node_name, 'content': msg.content, 'checkpoint_id': state.config['configurable']['checkpoint_id']}, ensure_ascii=False)}\n\n"
|
||||
current_cp_id = state.config["configurable"].get("checkpoint_id")
|
||||
|
||||
# 遍历本次 update 产生的所有消息
|
||||
for msg in output["messages"]:
|
||||
payload = {
|
||||
"node": node_name,
|
||||
"content": "",
|
||||
"image_url": None,
|
||||
"checkpoint_id": current_cp_id,
|
||||
"suggestions": []
|
||||
}
|
||||
|
||||
# --- 核心改动:提取建议按钮 ---
|
||||
# 无论是不是 Suggester 节点,只要消息里带了建议就提取
|
||||
if hasattr(msg, "additional_kwargs") and "suggestions" in msg.additional_kwargs:
|
||||
payload["suggestions"] = msg.additional_kwargs["suggestions"]
|
||||
|
||||
content = msg.content
|
||||
# 逻辑判断:MinIO 图片处理
|
||||
if node_name == "Visualizer" and str(content).endswith("png") and "furniture/sketches" in str(content):
|
||||
payload["image_url"] = content
|
||||
payload["content"] = "已为您生成设计草图"
|
||||
else:
|
||||
payload["content"] = content
|
||||
|
||||
# 如果消息既没有文本、也没有图片、也没有建议(比如中间的 ToolCall 消息),则跳过
|
||||
if not payload["content"] and not payload["image_url"] and not payload["suggestions"]:
|
||||
continue
|
||||
|
||||
yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
|
||||
|
||||
yield f"data: {json.dumps({'status': 'end'}, ensure_ascii=False)}\n\n"
|
||||
|
||||
return StreamingResponse(event_generator(), media_type="text/event-stream")
|
||||
|
||||
@@ -147,11 +223,11 @@ async def get_chat_history(thread_id: str):
|
||||
last_msg = msgs[-1]
|
||||
# 获取内容并做摘要截断
|
||||
content = getattr(last_msg, "content", str(last_msg))
|
||||
msg_content = content[:50] + ("..." if len(content) > 50 else "")
|
||||
msg_content = content
|
||||
|
||||
history_data.append(HistoryItem(
|
||||
checkpoint_id=state.config["configurable"]["checkpoint_id"],
|
||||
last_message=msg_content[:50],
|
||||
last_message=msg_content,
|
||||
node=state.metadata.get("source"),
|
||||
timestamp=state.metadata.get("step")
|
||||
))
|
||||
|
||||
Reference in New Issue
Block a user