新增对话接口
This commit is contained in:
@@ -6,7 +6,7 @@ from typing import AsyncGenerator
|
||||
from fastapi import APIRouter
|
||||
from fastapi.responses import StreamingResponse
|
||||
from src.schemas.chat import ChatRequest, HistoryResponse, HistoryItem
|
||||
from src.server.agent.graph import app
|
||||
from src.server.agent.graph import app # 导入已经 compile 好的 graph
|
||||
from langchain_core.messages import HumanMessage, SystemMessage, AIMessageChunk, ToolMessage, AIMessage
|
||||
|
||||
router = APIRouter(prefix="/chat", tags=["Furniture Design Chat"])
|
||||
@@ -169,8 +169,10 @@ async def chat_stream(request: ChatRequest):
|
||||
await app.aupdate_state(current_config, combined_values)
|
||||
|
||||
async def event_generator() -> AsyncGenerator[str, None]:
|
||||
# 初始事件
|
||||
yield f"data: {json.dumps({'thread_id': target_thread_id, 'is_branch': is_branching, 'status': 'start'}, ensure_ascii=False)}\n\n"
|
||||
|
||||
# 構造輸入(保持不變)
|
||||
new_messages = initial_messages[:] if not source_thread_id else []
|
||||
new_messages.append(HumanMessage(content=request.message))
|
||||
|
||||
@@ -180,143 +182,124 @@ async def chat_stream(request: ChatRequest):
|
||||
"use_report": request.use_report,
|
||||
}
|
||||
|
||||
interrupted = False
|
||||
|
||||
current_cp_id = None
|
||||
# try:
|
||||
# ─── 重點改這裡 ───────────────────────────────────────
|
||||
async for event in app.astream(
|
||||
input_data,
|
||||
config=current_config,
|
||||
stream_mode=["updates", "messages", "custom"], # 确保包含 "values"
|
||||
stream_mode=["custom", "updates", "messages"], # 推薦組合
|
||||
subgraphs=True
|
||||
# 不再需要,行為已包含
|
||||
):
|
||||
if interrupted:
|
||||
break
|
||||
logger.info(event)
|
||||
# 取得 checkpoint_id(可選,視前端是否真的需要)
|
||||
latest_state = await app.aget_state(current_config)
|
||||
configurable = latest_state.config.get("configurable", {})
|
||||
current_cp_id = configurable.get("checkpoint_id", "")
|
||||
if len(event) == 3:
|
||||
namespace, channel, payload = event
|
||||
# 路由更新
|
||||
if event[1] == "updates":
|
||||
namespace, _, payload = event
|
||||
|
||||
logger.info(f"Received event: {event}")
|
||||
if isinstance(payload, dict):
|
||||
for update_node, update_content in payload.items():
|
||||
|
||||
if not isinstance(event, tuple) or len(event) != 3:
|
||||
continue
|
||||
# 处理 reducer(Overwrite / Append)
|
||||
if isinstance(update_content, dict):
|
||||
for k, v in update_content.items():
|
||||
if hasattr(v, "value"): # Overwrite(...)
|
||||
update_content[k] = v.value
|
||||
|
||||
run_id, channel, payload = event
|
||||
if isinstance(update_content, dict) and "messages" in update_content:
|
||||
msgs = []
|
||||
for m in update_content["messages"]:
|
||||
msgs.append({
|
||||
"type": m.__class__.__name__,
|
||||
"content": getattr(m, "content", ""),
|
||||
"name": getattr(m, "name", None),
|
||||
"tool_calls": getattr(m, "tool_calls", None),
|
||||
})
|
||||
update_content["messages"] = msgs
|
||||
|
||||
# ─────────────── 检测 interrupt ───────────────
|
||||
# __interrupt__ 最常出现在 "values" 或 "updates" channel 的 payload 中
|
||||
if channel in ("values", "updates") and isinstance(payload, dict) and "__interrupt__" in payload:
|
||||
interrupt_data = payload["__interrupt__"][0].value['__interrupt__']
|
||||
interrupted = True
|
||||
yield f"data: {json.dumps({
|
||||
"type": "interrupt",
|
||||
"node": interrupt_data.get("node", "Persona"),
|
||||
"question": interrupt_data.get('question', "异常|||||||||||||"),
|
||||
"current_persona": interrupt_data.get("current_persona_snapshot", {}),
|
||||
"status": "waiting_for_input"
|
||||
}, ensure_ascii=False)}\n\n"
|
||||
yield f"data: {json.dumps({
|
||||
"node": "Supervisor",
|
||||
"type": "updates",
|
||||
"content": update_content,
|
||||
"is_delta": False,
|
||||
"checkpoint_id": current_cp_id,
|
||||
}, ensure_ascii=False)}\n\n"
|
||||
|
||||
# 立即停止后续发送,等待用户回复后 resume
|
||||
break
|
||||
# 自定义事件
|
||||
elif event[1] == "custom":
|
||||
if isinstance(payload, dict) and payload.get("type") in ("report_delta", "report_start", "report_error", "report_save_warning", "report_complete"):
|
||||
delta = payload.get("delta", "").strip()
|
||||
if delta:
|
||||
yield f"data: {json.dumps({
|
||||
'node': 'Researcher',
|
||||
'type': 'report_delta',
|
||||
'content': delta,
|
||||
'is_delta': True,
|
||||
'checkpoint_id': current_cp_id,
|
||||
}, ensure_ascii=False)}\n\n"
|
||||
# 基础消息
|
||||
elif event[1] == "messages":
|
||||
if namespace:
|
||||
node_name = namespace[-1] if isinstance(namespace, tuple) else namespace
|
||||
if ':' in node_name:
|
||||
node_name = node_name.split(':')[0]
|
||||
else:
|
||||
node_name = "Main"
|
||||
message, metadata = payload
|
||||
is_not_research = node_name != 'Researcher'
|
||||
node_name = metadata.get("langgraph_node", node_name)
|
||||
# 3. 处理不同类型的 message
|
||||
payload_out = {
|
||||
"node": node_name,
|
||||
"checkpoint_id": current_cp_id, # 你之前已经获取了
|
||||
"is_delta": False,
|
||||
"content": "",
|
||||
"suggestions": [],
|
||||
"type": "unknown"
|
||||
}
|
||||
|
||||
# ─────────────── 正常消息处理 ───────────────
|
||||
if channel == "messages":
|
||||
if run_id:
|
||||
node_name = run_id[-1] if isinstance(run_id, tuple) else run_id
|
||||
if ':' in node_name:
|
||||
node_name = node_name.split(':')[0]
|
||||
else:
|
||||
node_name = "Main"
|
||||
|
||||
message, metadata = payload
|
||||
node_name = metadata.get("langgraph_node", node_name)
|
||||
|
||||
payload_out = {
|
||||
"node": node_name,
|
||||
"checkpoint_id": current_cp_id or "unknown",
|
||||
"is_delta": False,
|
||||
"content": "",
|
||||
"suggestions": [],
|
||||
"type": "unknown"
|
||||
}
|
||||
|
||||
if isinstance(message, AIMessageChunk):
|
||||
if node_name != 'Researcher' and message.tool_call_chunks:
|
||||
payload_out.update({
|
||||
"type": "delta",
|
||||
"is_delta": True,
|
||||
"content": message.content,
|
||||
"tool_call_chunk": message.tool_call_chunks[0] if message.tool_call_chunks else None
|
||||
})
|
||||
yield f"data: {json.dumps(payload_out, ensure_ascii=False)}\n\n"
|
||||
|
||||
elif isinstance(message, ToolMessage):
|
||||
try:
|
||||
tools_data = json.loads(message.content)
|
||||
if isinstance(message, AIMessageChunk):
|
||||
# 节点不是research 并且 tool_call_chunks不为空的情况下,避免research的report工具使用custom发出的消息和message的消息重复了
|
||||
if is_not_research and node_name != 'Researcher' and message.tool_call_chunks:
|
||||
payload_out.update({
|
||||
"type": "delta",
|
||||
"is_delta": True,
|
||||
"content": message.content,
|
||||
# 如果有 tool call chunk,也可以在这里处理
|
||||
"tool_call_chunk": message.tool_call_chunks[0] if message.tool_call_chunks else None
|
||||
})
|
||||
yield f"data: {json.dumps(payload_out, ensure_ascii=False)}\n\n"
|
||||
elif isinstance(message, ToolMessage):
|
||||
# 工具执行结果(完整的一次性输出)
|
||||
payload_out.update({
|
||||
"type": "tool_result",
|
||||
"is_delta": False,
|
||||
"content": tools_data.get("data", ""),
|
||||
"tool_name": tools_data.get("tool_name", ""),
|
||||
"content": message.content,
|
||||
"tool_name": message.name,
|
||||
"tool_call_id": message.tool_call_id
|
||||
})
|
||||
yield f"data: {json.dumps(payload_out, ensure_ascii=False)}\n\n"
|
||||
except:
|
||||
pass
|
||||
|
||||
elif isinstance(message, AIMessage):
|
||||
payload_out.update({
|
||||
"type": "complete_message",
|
||||
"is_delta": False,
|
||||
"content": message.content
|
||||
})
|
||||
yield f"data: {json.dumps(payload_out, ensure_ascii=False)}\n\n"
|
||||
|
||||
elif channel == "updates":
|
||||
# 处理 updates(非 interrupt 的部分)
|
||||
if isinstance(payload, dict):
|
||||
for update_node, update_content in payload.items():
|
||||
# 处理 reducer 包裹的值
|
||||
if isinstance(update_content, dict):
|
||||
for k, v in update_content.items():
|
||||
if hasattr(v, "value"):
|
||||
update_content[k] = v.value
|
||||
|
||||
# 序列化 messages
|
||||
if isinstance(update_content, dict) and "messages" in update_content:
|
||||
msgs = []
|
||||
for m in update_content["messages"]:
|
||||
msgs.append({
|
||||
"type": m.__class__.__name__,
|
||||
"content": getattr(m, "content", ""),
|
||||
"name": getattr(m, "name", None),
|
||||
"tool_calls": getattr(m, "tool_calls", None),
|
||||
})
|
||||
update_content["messages"] = msgs
|
||||
|
||||
yield f"data: {json.dumps({
|
||||
"node": "Supervisor", # 或 update_node
|
||||
"type": "updates",
|
||||
"content": update_content,
|
||||
elif isinstance(message, AIMessage):
|
||||
# 完整 AIMessage(不常见在 messages 模式下,但以防万一)
|
||||
payload_out.update({
|
||||
"type": "complete_message",
|
||||
"is_delta": False,
|
||||
"checkpoint_id": current_cp_id,
|
||||
}, ensure_ascii=False)}\n\n"
|
||||
"content": message.content
|
||||
})
|
||||
yield f"data: {json.dumps(payload_out, ensure_ascii=False)}\n\n"
|
||||
|
||||
elif channel == "custom":
|
||||
if isinstance(payload, dict) and payload.get("type") in ("report_delta", "report_start", ...):
|
||||
delta = payload.get("delta", "").strip()
|
||||
if delta:
|
||||
yield f"data: {json.dumps({
|
||||
'node': 'Researcher',
|
||||
'type': 'report_delta',
|
||||
'content': delta,
|
||||
'is_delta': True,
|
||||
'checkpoint_id': current_cp_id,
|
||||
}, ensure_ascii=False)}\n\n"
|
||||
# except Exception as e:
|
||||
# print("error")
|
||||
else:
|
||||
# 其他未知类型,记录日志
|
||||
print(f"未知消息类型: {type(message)}", message)
|
||||
continue
|
||||
|
||||
# 结束标记
|
||||
if interrupted:
|
||||
yield f"data: {json.dumps({'status': 'interrupted', 'reason': 'waiting_for_user_input'})}\n\n"
|
||||
else:
|
||||
yield f"data: {json.dumps({'status': 'end'}, ensure_ascii=False)}\n\n"
|
||||
# 流結束
|
||||
yield f"data: {json.dumps({'status': 'end'}, ensure_ascii=False)}\n\n"
|
||||
|
||||
return StreamingResponse(event_generator(), media_type="text/event-stream")
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import logging
|
||||
import random
|
||||
import uuid
|
||||
import json
|
||||
from typing import AsyncGenerator
|
||||
@@ -9,6 +10,7 @@ from src.schemas.chat import ChatRequest, HistoryResponse, HistoryItem
|
||||
from langchain_core.messages import HumanMessage, SystemMessage, AIMessageChunk, ToolMessage, AIMessage, ToolMessageChunk
|
||||
|
||||
from src.server.deep_agent.agents.main_agent import build_main_agent
|
||||
from src.server.deep_agent.tools.extract_suggested_questions import format_messages, generate_suggested_questions
|
||||
|
||||
router = APIRouter(prefix="/chat", tags=["Furniture Design Chat"])
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -80,7 +82,6 @@ async def chat_stream(request: ChatRequest):
|
||||
"checkpoint_id": "快照ID",
|
||||
"is_delta": "boolean",
|
||||
"type": "消息类型",
|
||||
"suggestions": "建议列表(可选)",
|
||||
"tool_name": "工具名称(可选)",
|
||||
"tool_call_chunk": "工具调用片段(可选)",
|
||||
"tool_call_id": "工具调用ID(可选)"
|
||||
@@ -136,8 +137,6 @@ async def chat_stream(request: ChatRequest):
|
||||
# 2. 配置參數
|
||||
temp = request.config_params.temperature if request.config_params else 0.7
|
||||
|
||||
need_suggestion = request.need_suggestion,
|
||||
|
||||
current_config = {
|
||||
"recursion_limit": 120,
|
||||
"configurable": {
|
||||
@@ -184,28 +183,55 @@ async def chat_stream(request: ChatRequest):
|
||||
input_data = {
|
||||
"messages": new_messages,
|
||||
}
|
||||
current_cp_id = None
|
||||
async for stream in main_agent.astream(
|
||||
input_data,
|
||||
config=current_config,
|
||||
stream_mode=["updates", "messages", "custom"], # 确保包含 "values"
|
||||
stream_mode=["updates", "messages", "custom"],
|
||||
subgraphs=True
|
||||
):
|
||||
# logger.info(f"Received event: {event}")
|
||||
_, mode, chunks = stream
|
||||
if mode == "updates":
|
||||
# TODO 补充
|
||||
if mode == "updates": # 只做记录 不做事件返回
|
||||
print(f"[updates] {chunks}")
|
||||
update_model_messages = chunks.get("model", None)
|
||||
update_tools_messages = chunks.get("tools", None)
|
||||
payload_out = {
|
||||
"node": "",
|
||||
# "checkpoint_id": current_cp_id or "unknown", TODO 替换为checkpoint_idns
|
||||
"is_delta": False,
|
||||
"content": "",
|
||||
"type": "updates"
|
||||
}
|
||||
|
||||
if update_model_messages:
|
||||
model_messages = update_model_messages.get("messages", [])
|
||||
for model_token in model_messages:
|
||||
if isinstance(model_token, AIMessage):
|
||||
model_content_blocks = model_token.content_blocks[0]
|
||||
model_name = model_token.name
|
||||
payload_out.update({
|
||||
"node": model_name if model_name else "main",
|
||||
"tool_calls": model_token.tool_calls
|
||||
})
|
||||
logger.info(f"[updates] {model_name} -- {model_content_blocks} -- {model_token.tool_calls}")
|
||||
yield f"data: {json.dumps(payload_out, ensure_ascii=False)}\n\n"
|
||||
elif update_tools_messages:
|
||||
tools_messages = update_tools_messages.get("messages", [])
|
||||
for tools_token in tools_messages:
|
||||
if isinstance(tools_token, ToolMessage):
|
||||
tool_content_blocks = tools_token.content_blocks[0]
|
||||
tool_name = tools_token.name
|
||||
logger.info(f"[updates] {tool_name} -- {tool_content_blocks}")
|
||||
else:
|
||||
logger.info(f"[updates] -- {chunks}")
|
||||
|
||||
elif mode == "messages":
|
||||
token, metadata = chunks
|
||||
subagent_name = metadata.get('lc_agent_name', None)
|
||||
subagent_name = metadata.get('lc_agent_name', "main")
|
||||
payload_out = {
|
||||
"node": subagent_name,
|
||||
# "checkpoint_id": current_cp_id or "unknown", TODO 替换为checkpoint_idns
|
||||
"is_delta": False,
|
||||
"content": "",
|
||||
"suggestions": [],
|
||||
"type": ""
|
||||
}
|
||||
|
||||
@@ -213,29 +239,36 @@ async def chat_stream(request: ChatRequest):
|
||||
reasoning = [b for b in token.content_blocks if b["type"] == "reasoning"]
|
||||
text = [b for b in token.content_blocks if b["type"] == "text"]
|
||||
if reasoning:
|
||||
payload_out.update({
|
||||
"type": "reasoning",
|
||||
"is_delta": True,
|
||||
"content": text,
|
||||
"tool_call_chunk": token.tool_call_chunks[0] if token.tool_call_chunks else None
|
||||
})
|
||||
if len(reasoning) == 1:
|
||||
payload_out.update({
|
||||
"type": "reasoning",
|
||||
"is_delta": True,
|
||||
"content": reasoning[0].get("reasoning", ""),
|
||||
# "tool_call_chunk": token.tool_call_chunks[0] if token.tool_call_chunks else None
|
||||
})
|
||||
else:
|
||||
print(f"[reasoning] {reasoning}*************************************************************************************")
|
||||
elif text:
|
||||
payload_out.update({
|
||||
"type": "text",
|
||||
"is_delta": True,
|
||||
"content": text,
|
||||
"tool_call_chunk": token.tool_call_chunks[0] if token.tool_call_chunks else None
|
||||
})
|
||||
if len(text) == 1:
|
||||
payload_out.update({
|
||||
"type": "text",
|
||||
"is_delta": True,
|
||||
"content": text[0].get("text", ""),
|
||||
# "tool_call_chunk": token.tool_call_chunks[0] if token.tool_call_chunks else None
|
||||
})
|
||||
else:
|
||||
print(f"[text] {text}*************************************************************************************")
|
||||
else:
|
||||
payload_out.update({
|
||||
"type": "tool_call",
|
||||
"is_delta": True,
|
||||
"tool_call_chunk": token.tool_call_chunks[0] if token.tool_call_chunks else None
|
||||
# "tool_call_chunk": token.tool_call_chunks[0] if token.tool_call_chunks else None
|
||||
})
|
||||
yield f"data: {json.dumps(payload_out, ensure_ascii=False)}\n\n"
|
||||
elif isinstance(token, ToolMessageChunk): # 工具返回
|
||||
text = [b for b in token.content_blocks if b["type"] == "text"]
|
||||
payload_out.update({
|
||||
"type": "tool_text",
|
||||
"type": "tool_result",
|
||||
"is_delta": False,
|
||||
"content": text,
|
||||
"tool_name": token.name,
|
||||
@@ -244,7 +277,7 @@ async def chat_stream(request: ChatRequest):
|
||||
elif isinstance(token, ToolMessage): # 工具返回
|
||||
text = [b for b in token.content_blocks if b["type"] == "text"]
|
||||
payload_out.update({
|
||||
"type": "tool_text",
|
||||
"type": "tool_result",
|
||||
"is_delta": False,
|
||||
"content": text,
|
||||
"tool_name": token.name,
|
||||
@@ -254,14 +287,11 @@ async def chat_stream(request: ChatRequest):
|
||||
continue
|
||||
|
||||
elif mode == "custom":
|
||||
token, metadata = chunks
|
||||
subagent_name = metadata.get('lc_agent_name', None)
|
||||
payload_out = {
|
||||
"node": subagent_name,
|
||||
"node": "research-agent",
|
||||
# "checkpoint_id": current_cp_id or "unknown", TODO 替换为checkpoint_idns
|
||||
"is_delta": False,
|
||||
"content": "",
|
||||
"suggestions": [],
|
||||
"type": ""
|
||||
}
|
||||
delta = chunks.get("delta", "").strip()
|
||||
@@ -273,40 +303,10 @@ async def chat_stream(request: ChatRequest):
|
||||
})
|
||||
yield f"data: {json.dumps(payload_out, ensure_ascii=False)}\n\n"
|
||||
|
||||
|
||||
# elif channel == "updates":
|
||||
# # 处理 updates(非 interrupt 的部分)
|
||||
# if isinstance(payload, dict):
|
||||
# for update_node, update_content in payload.items():
|
||||
# # 处理 reducer 包裹的值
|
||||
# if isinstance(update_content, dict):
|
||||
# for k, v in update_content.items():
|
||||
# if hasattr(v, "value"):
|
||||
# update_content[k] = v.value
|
||||
#
|
||||
# # 序列化 messages
|
||||
# if isinstance(update_content, dict) and "messages" in update_content:
|
||||
# msgs = []
|
||||
# for m in update_content["messages"]:
|
||||
# msgs.append({
|
||||
# "type": m.__class__.__name__,
|
||||
# "content": getattr(m, "content", ""),
|
||||
# "name": getattr(m, "name", None),
|
||||
# "tool_calls": getattr(m, "tool_calls", None),
|
||||
# })
|
||||
# update_content["messages"] = msgs
|
||||
#
|
||||
# yield f"data: {json.dumps({
|
||||
# "node": "Supervisor", # 或 update_node
|
||||
# "type": "updates",
|
||||
# "content": update_content,
|
||||
# "is_delta": False,
|
||||
# "checkpoint_id": current_cp_id,
|
||||
# }, ensure_ascii=False)}\n\n"
|
||||
#
|
||||
# elif channel == "custom":
|
||||
else:
|
||||
yield f"data: {json.dumps({'status': 'end'}, ensure_ascii=False)}\n\n"
|
||||
if request.need_suggestion > 0 and random.random() < request.need_suggestion:
|
||||
suggested_questions = generate_suggested_questions(main_agent, target_thread_id)
|
||||
yield f"data: {json.dumps({'suggested_questions': suggested_questions}, ensure_ascii=False)}\n\n"
|
||||
yield f"data: {json.dumps({'status': 'end'}, ensure_ascii=False)}\n\n"
|
||||
|
||||
return StreamingResponse(event_generator(), media_type="text/event-stream")
|
||||
|
||||
@@ -357,6 +357,7 @@ async def get_chat_history(thread_id: str):
|
||||
"""
|
||||
config = {"configurable": {"thread_id": thread_id}, }
|
||||
history_data = []
|
||||
main_agent = build_main_agent(False)
|
||||
async for state in main_agent.aget_state_history(config):
|
||||
msg_content = "Initial"
|
||||
if state.values and "messages" in state.values:
|
||||
|
||||
169
src/server/agent/agents.py
Normal file
169
src/server/agent/agents.py
Normal file
@@ -0,0 +1,169 @@
|
||||
from pathlib import Path
|
||||
from typing import AsyncGenerator, Dict, Any
|
||||
from deepagents import create_deep_agent
|
||||
from deepagents.backends import FilesystemBackend
|
||||
from langchain_core.messages import HumanMessage, SystemMessage, ToolMessage, AIMessage, AIMessageChunk
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langchain_qwq import ChatQwen
|
||||
|
||||
from src.core.config import settings
|
||||
from src.server.agent.prompt import SYSTEM_PROMPT, visualizer_prompt, designer_prompt
|
||||
from src.server.agent.state import AgentState
|
||||
from src.server.agent.tools.generate_furniture_sketch import generate_furniture
|
||||
from src.server.agent.tools.crawl_tool import crawl4ai_batch
|
||||
from src.server.agent.tools.report_generator_tool import report_generator
|
||||
from src.server.agent.tools.research_tool import topic_research
|
||||
from src.server.agent.tools.structured_retrieval_tool import structured_retrieval
|
||||
from src.server.agent.tools.terminate_tool import terminate
|
||||
from src.server.agent.tools.user_persona_tool import manage_user_persona
|
||||
from src.server.utils.generate_suggestion import generate_chat_suggestions
|
||||
|
||||
# 目前這個主程式檔案所在的目錄
|
||||
MAIN_DIR = Path(__file__).resolve().parent
|
||||
|
||||
# 專案根目錄(因為 main.py 跟 tools/ 同級,所以 parent 就是根)
|
||||
PROJECT_ROOT = MAIN_DIR
|
||||
|
||||
model = ChatQwen(
|
||||
enable_thinking=False,
|
||||
model="qwen3.5-flash",
|
||||
max_tokens=3_000,
|
||||
timeout=None,
|
||||
max_retries=2,
|
||||
api_key=settings.QWEN_API_KEY)
|
||||
|
||||
tools = [manage_user_persona, topic_research, crawl4ai_batch, structured_retrieval, report_generator, terminate]
|
||||
research_agent = create_deep_agent(
|
||||
model=model,
|
||||
tools=tools,
|
||||
system_prompt=SYSTEM_PROMPT,
|
||||
backend=FilesystemBackend(
|
||||
root_dir=str(PROJECT_ROOT / "agent_workspace"),
|
||||
virtual_mode=False, # 重要:關掉虛擬模式 → 真的寫硬碟
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# 辅助函数:根据配置动态获取 LLM
|
||||
def get_model(config: RunnableConfig, streaming=False):
|
||||
temp = config["configurable"].get("llm_temperature", 0.5)
|
||||
return ChatQwen(
|
||||
enable_thinking=False,
|
||||
model="qwen3.5-flash",
|
||||
max_tokens=3_000,
|
||||
timeout=None,
|
||||
max_retries=2,
|
||||
temperature=temp,
|
||||
api_key=settings.QWEN_API_KEY,
|
||||
streaming=streaming
|
||||
)
|
||||
|
||||
|
||||
# --- 1. Designer Agent (设计顾问) ---
|
||||
async def designer_node(state: AgentState, config: RunnableConfig):
|
||||
"""负责细化设计需求,提供专业参数"""
|
||||
model = get_model(config) # 获取带动态温度的模型
|
||||
|
||||
messages = state["messages"]
|
||||
system_prompt = SystemMessage(content=designer_prompt)
|
||||
should_suggest = len(state["messages"]) % 5 == 0
|
||||
response = await model.ainvoke([system_prompt] + messages)
|
||||
return {"messages": [response], "require_suggestion": should_suggest}
|
||||
|
||||
|
||||
async def researcher_node(
|
||||
state: AgentState,
|
||||
config: RunnableConfig
|
||||
) -> AsyncGenerator[Dict[str, Any], None]:
|
||||
use_report = config["configurable"].get("use_report", False)
|
||||
if not use_report:
|
||||
yield {
|
||||
"messages": [AIMessage(
|
||||
content="深度报告功能未启用,请通过前端按钮触发。",
|
||||
name="Researcher"
|
||||
)],
|
||||
"next": "Supervisor"
|
||||
}
|
||||
return
|
||||
|
||||
messages = state["messages"]
|
||||
last_human = next((m for m in reversed(messages) if isinstance(m, HumanMessage)), None)
|
||||
|
||||
if not last_human:
|
||||
yield {
|
||||
"messages": [AIMessage(
|
||||
content="深度研究节点:未找到有效的用户问题",
|
||||
name="Researcher"
|
||||
)],
|
||||
"next": "Supervisor"
|
||||
}
|
||||
return
|
||||
current_step = "正在启动深度报告生成..."
|
||||
yield {
|
||||
"messages": [AIMessage(
|
||||
content="正在启动深度报告生成...",
|
||||
name="Researcher",
|
||||
additional_kwargs={
|
||||
"current_step": current_step,
|
||||
"streaming": True
|
||||
}
|
||||
)]
|
||||
}
|
||||
async for chunk in research_agent.astream(
|
||||
{"messages": messages[-12:]},
|
||||
config=config
|
||||
):
|
||||
if "messages" in chunk and isinstance(chunk["messages"], AIMessageChunk):
|
||||
yield {
|
||||
"messages": chunk["messages"], # 逐 token 追加
|
||||
# 可以額外 yield 一些 metadata,例如
|
||||
# "node": "Researcher",
|
||||
# "status": "thinking"
|
||||
}
|
||||
else:
|
||||
yield chunk
|
||||
|
||||
|
||||
# --- 3. Visualizer Agent (视觉专家) ---
|
||||
async def visualizer_node(state: AgentState, config: RunnableConfig):
|
||||
"""负责将自然语言转化为绘图 Prompt 并调用绘图工具"""
|
||||
model = get_model(config, streaming=False)
|
||||
tools = [generate_furniture]
|
||||
llm_with_tools = model.bind_tools(tools)
|
||||
messages = state["messages"]
|
||||
system_prompt = SystemMessage(content=visualizer_prompt)
|
||||
response = await llm_with_tools.ainvoke([system_prompt] + messages)
|
||||
|
||||
if response.tool_calls:
|
||||
tool_call = response.tool_calls[0]
|
||||
if tool_call["name"] == "generate_furniture":
|
||||
img_url = await generate_furniture.ainvoke(tool_call["args"])
|
||||
return {
|
||||
"messages": [
|
||||
response,
|
||||
ToolMessage(content=img_url, tool_call_id=tool_call["id"]) # 标记这是一个图片结果
|
||||
]
|
||||
}
|
||||
|
||||
return {"messages": [response]}
|
||||
|
||||
|
||||
# --- 4. Suggester Agent (推荐对话专家) ---
|
||||
async def suggester_node(state: AgentState, config: RunnableConfig):
|
||||
"""专门生成追问建议的节点,作为流程终点"""
|
||||
model = get_model(config)
|
||||
messages = state["messages"]
|
||||
|
||||
# 只需要分析最近的对话
|
||||
suggestions = await generate_chat_suggestions(messages, model)
|
||||
|
||||
# 返回一个特殊消息,前端通过解析 additional_kwargs 获取按钮内容
|
||||
return {
|
||||
"messages": [
|
||||
AIMessage(
|
||||
content="",
|
||||
additional_kwargs={"suggestions": suggestions},
|
||||
name="Suggester"
|
||||
)
|
||||
]
|
||||
}
|
||||
@@ -1,19 +0,0 @@
|
||||
from typing import AsyncGenerator, Dict, Any
|
||||
|
||||
from langchain_core.messages import SystemMessage, AIMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
from src.server.agent.agents.init_llm import get_model
|
||||
from src.server.agent.memory.memory_manager import MemoryManager
|
||||
|
||||
from src.server.agent.prompt import designer_prompt
|
||||
from src.server.agent.state import AgentState
|
||||
|
||||
|
||||
async def designer_node(state: AgentState, config: RunnableConfig):
|
||||
"""负责细化设计需求,提供专业参数"""
|
||||
model = get_model(config) # 获取带动态温度的模型
|
||||
messages = MemoryManager.build_llm_context(state)
|
||||
system_prompt = SystemMessage(content=designer_prompt)
|
||||
response = await model.ainvoke([system_prompt] + messages)
|
||||
return {"messages": [response]}
|
||||
@@ -1,93 +0,0 @@
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any
|
||||
|
||||
from deepagents import create_deep_agent
|
||||
from deepagents.backends import FilesystemBackend
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langchain_qwq import ChatQwen
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from src.core.config import settings
|
||||
from src.server.agent.prompt import SYSTEM_PROMPT
|
||||
from src.server.agent.tools.crawl_tool import crawl4ai_batch
|
||||
from src.server.agent.tools.report_generator_tool import report_generator
|
||||
from src.server.agent.tools.research_tool import topic_research
|
||||
from src.server.agent.tools.structured_retrieval_tool import structured_retrieval
|
||||
from src.server.agent.tools.terminate_tool import terminate
|
||||
from src.server.agent.tools.user_persona_tool import get_user_persona
|
||||
|
||||
MAIN_DIR = Path(__file__).resolve().parent
|
||||
PROJECT_ROOT = MAIN_DIR
|
||||
|
||||
|
||||
class PersonaUpdate(BaseModel):
|
||||
"""从对话中提取/更新的用户画像结构"""
|
||||
persona: Dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="键值对形式的用户画像,例如 {'风格偏好': '北欧简约', '预算范围': '8000-15000'}"
|
||||
)
|
||||
complete: bool = Field(
|
||||
...,
|
||||
description="当前画像是否足够完整,支持市场研究和设计"
|
||||
)
|
||||
question: str = Field(
|
||||
default="",
|
||||
description="如果不完整,这里是需要问用户的自然语言问题;否则为空字符串"
|
||||
)
|
||||
|
||||
|
||||
llm_supervisor = ChatQwen(
|
||||
model="qwen3.5-flash",
|
||||
max_tokens=3_000,
|
||||
timeout=None,
|
||||
max_retries=2,
|
||||
api_key=settings.QWEN_API_KEY
|
||||
)
|
||||
|
||||
model = ChatQwen(
|
||||
enable_thinking=False,
|
||||
model="qwen3.5-flash",
|
||||
max_tokens=3_000,
|
||||
timeout=None,
|
||||
max_retries=2,
|
||||
api_key=settings.QWEN_API_KEY)
|
||||
|
||||
tools = [get_user_persona, topic_research, crawl4ai_batch, structured_retrieval, report_generator, terminate]
|
||||
|
||||
research_agent = create_deep_agent(
|
||||
model=model,
|
||||
tools=tools,
|
||||
system_prompt=SYSTEM_PROMPT,
|
||||
backend=FilesystemBackend(
|
||||
root_dir=str(PROJECT_ROOT / "agent_workspace"),
|
||||
virtual_mode=False, # 重要:關掉虛擬模式 → 真的寫硬碟
|
||||
)
|
||||
)
|
||||
|
||||
persona_agent = ChatQwen(
|
||||
model="qwen3.5-flash",
|
||||
max_tokens=3_000,
|
||||
timeout=None,
|
||||
max_retries=2,
|
||||
api_key="sk-799944b821bd4bfdb2f6188ebb52a76b").with_structured_output(PersonaUpdate) # 或用 .bind_tools + parser
|
||||
|
||||
summary_llm = ChatQwen(
|
||||
model="qwen3.5-flash",
|
||||
max_tokens=3_000,
|
||||
timeout=None,
|
||||
max_retries=2,
|
||||
api_key="sk-799944b821bd4bfdb2f6188ebb52a76b")
|
||||
|
||||
|
||||
def get_model(config: RunnableConfig, streaming=False):
|
||||
temp = config["configurable"].get("llm_temperature", 0.5)
|
||||
return ChatQwen(
|
||||
enable_thinking=False,
|
||||
model="qwen3.5-flash",
|
||||
max_tokens=3_000,
|
||||
timeout=None,
|
||||
max_retries=2,
|
||||
temperature=temp,
|
||||
api_key=settings.QWEN_API_KEY,
|
||||
streaming=streaming
|
||||
)
|
||||
@@ -1,154 +0,0 @@
|
||||
import json
|
||||
from typing import Dict, Any
|
||||
from datetime import datetime
|
||||
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
from langgraph.types import interrupt
|
||||
|
||||
from pymongo import MongoClient
|
||||
from src.core.config import MONGO_URI
|
||||
from src.server.agent.agents.init_llm import persona_agent
|
||||
from src.server.agent.memory.memory_manager import MemoryManager
|
||||
from src.server.agent.state import AgentState
|
||||
|
||||
client = MongoClient(MONGO_URI)
|
||||
db = client["furniture_agent_db"]
|
||||
persona_collection = db["user_persona"]
|
||||
|
||||
EXTRACTION_PROMPT = ChatPromptTemplate.from_messages([
|
||||
("system", """你是一个专为家具设计师服务的用户画像提取专家。
|
||||
|
||||
当前已知用户画像(JSON格式):
|
||||
{current_persona_json}
|
||||
|
||||
任务:
|
||||
1. 从下面所有用户消息中,提取或更新用户对家具设计的偏好信息。
|
||||
只提取明确提到或强烈暗示的内容,不要臆想或添加默认值。
|
||||
2. 输出更新后的完整 persona JSON(只包含有值字段)。
|
||||
推荐使用的键名(优先使用这些):
|
||||
- "风格偏好" (如 "北欧"、"极简"、"工业风")
|
||||
- "家具类型" (如 "沙发"、"餐桌"、"书柜"、"办公椅")
|
||||
- "颜色偏好" (如 "原木色"、"白色"、"深灰"、"大地色系")
|
||||
- 其他可选键:预算范围、空间大小、材质偏好、使用场景等
|
||||
|
||||
3. 判断当前画像是否“足够完整”来支持针对家具设计师的市场趋势报告:
|
||||
- **必须条件**:至少包含 "风格偏好"、"家具类型"、"颜色偏好" 中的 2 项以上
|
||||
- 如果缺少 2 项或以上核心信息,返回 complete: false,并生成一个自然、礼貌、具体的追问(中文),优先询问缺失的核心项
|
||||
- 如果核心三项中至少有 2 项已明确,返回 complete: true,不需要追问
|
||||
|
||||
输出必须是严格的 JSON 对象,包含三个字段:
|
||||
- "persona": 一个对象,键是画像属性,值是对应的字符串(或数组,如果有多个偏好)
|
||||
- "complete": 布尔值 true 或 false
|
||||
- "question": 字符串,如果 complete 为 true 则为空字符串,否则是具体的追问句子
|
||||
|
||||
示例输出结构(仅供参考,不要直接复制):
|
||||
{{
|
||||
"persona": {{
|
||||
"风格偏好": "北欧简约",
|
||||
"家具类型": "沙发",
|
||||
"颜色偏好": "原木色 + 浅灰"
|
||||
}},
|
||||
"complete": true,
|
||||
"question": ""
|
||||
}}
|
||||
"""),
|
||||
("placeholder", "{messages}"),
|
||||
])
|
||||
|
||||
|
||||
def get_persona_from_mongo(thread_id: str) -> Dict[str, Any]:
|
||||
doc = persona_collection.find_one({"thread_id": thread_id}, sort=[("_id", -1)])
|
||||
if doc and "persona" in doc:
|
||||
return doc["persona"]
|
||||
return {}
|
||||
|
||||
|
||||
def save_persona_to_mongo(thread_id: str, persona: Dict[str, Any], is_complete: bool):
|
||||
try:
|
||||
result = persona_collection.update_one(
|
||||
{"thread_id": thread_id},
|
||||
{
|
||||
"$set": {
|
||||
"persona": persona,
|
||||
"persona_complete": is_complete,
|
||||
"updated_at": datetime.utcnow(),
|
||||
"last_update_reason": "persona_node_update"
|
||||
}
|
||||
},
|
||||
upsert=True
|
||||
)
|
||||
print(f"[Persona Save] thread_id: {thread_id} | matched: {result.matched_count} | modified: {result.modified_count} | upserted: {result.upserted_id}")
|
||||
except Exception as e:
|
||||
print(f"[Persona Save Error] {e}")
|
||||
|
||||
|
||||
def persona_node(state: AgentState, config: RunnableConfig):
|
||||
thread_id = config["configurable"]["thread_id"]
|
||||
|
||||
# 读取已有画像(MongoDB 优先)
|
||||
persisted_persona = get_persona_from_mongo(thread_id)
|
||||
current_persona = state.get("persona", persisted_persona)
|
||||
|
||||
# messages = state["messages"]
|
||||
messages = MemoryManager.build_llm_context(state)
|
||||
current_persona_json = json.dumps(current_persona, ensure_ascii=False, indent=None)
|
||||
|
||||
chain = EXTRACTION_PROMPT | persona_agent
|
||||
result = chain.invoke({
|
||||
"current_persona_json": current_persona_json,
|
||||
"messages": messages,
|
||||
})
|
||||
|
||||
updated_persona = result.persona
|
||||
is_complete = result.complete
|
||||
question = (result.question or "").strip()
|
||||
|
||||
updates = {
|
||||
"persona": updated_persona,
|
||||
"persona_complete": is_complete,
|
||||
"persona_summary": json.dumps(updated_persona, ensure_ascii=False, indent=2),
|
||||
}
|
||||
|
||||
# 持久化到 MongoDB
|
||||
save_persona_to_mongo(thread_id, updated_persona, is_complete)
|
||||
|
||||
if is_complete:
|
||||
updates["messages"] = messages + [AIMessage(
|
||||
content=(
|
||||
"用户画像已足够完整(风格、家具类型、颜色偏好已明确),并已保存到项目记录。\n\n"
|
||||
"接下来是否需要我为您生成一份针对当前风格与家具类型的市场趋势报告?\n"
|
||||
"回复“是”或“需要”即可开始生成;回复“不需要”或“先不用”则直接进入家具设计阶段。"
|
||||
)
|
||||
)]
|
||||
return updates
|
||||
|
||||
# 不完整 → 询问(优先问核心三项)
|
||||
if not question:
|
||||
missing = []
|
||||
core_keys = ["风格偏好", "家具类型", "颜色偏好"]
|
||||
for key in core_keys:
|
||||
if key not in updated_persona or not updated_persona[key]:
|
||||
missing.append(key)
|
||||
if missing:
|
||||
question = f"为了更好地为您生成趋势报告,能否补充一下您对{'、'.join(missing)}的偏好呢?"
|
||||
else:
|
||||
question = "您对家具的风格、类型或颜色有什么特别的偏好吗?可以多说一些~"
|
||||
|
||||
updated_messages = messages + [AIMessage(content=question)]
|
||||
|
||||
approved = interrupt({
|
||||
**updates,
|
||||
"messages": updated_messages,
|
||||
"persona_complete": False,
|
||||
"__interrupt__": {
|
||||
"type": "persona_question",
|
||||
"question": question,
|
||||
"node": "Persona",
|
||||
"wait_for": "human_response",
|
||||
# 可选:当前画像快照,便于前端显示或调试
|
||||
"current_persona_snapshot": updated_persona
|
||||
}
|
||||
})
|
||||
return approved
|
||||
@@ -1,88 +0,0 @@
|
||||
from typing import AsyncGenerator, Dict, Any
|
||||
|
||||
from langchain_core.messages import AIMessage, HumanMessage, AIMessageChunk, ToolMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
from src.server.agent.agents.init_llm import research_agent
|
||||
from src.server.agent.memory.memory_manager import MemoryManager
|
||||
from src.server.agent.state import AgentState
|
||||
|
||||
|
||||
async def researcher_node(state: AgentState, config: RunnableConfig) -> AsyncGenerator[Dict[str, Any], None]:
|
||||
use_report = config["configurable"].get("use_report", False)
|
||||
if not use_report:
|
||||
yield {
|
||||
"messages": [
|
||||
AIMessage(
|
||||
content="深度报告功能未启用,请通过前端按钮触发。",
|
||||
name="Researcher"
|
||||
)
|
||||
],
|
||||
"next": "Supervisor"
|
||||
}
|
||||
return
|
||||
|
||||
# messages = state["messages"]
|
||||
messages = MemoryManager.build_llm_context(state)
|
||||
safe_messages = [
|
||||
m for m in messages
|
||||
if isinstance(m, (HumanMessage, AIMessage))
|
||||
]
|
||||
last_human = next((m for m in reversed(messages) if isinstance(m, HumanMessage)), None)
|
||||
|
||||
if not last_human:
|
||||
yield {
|
||||
"messages": [
|
||||
AIMessage(
|
||||
content="深度研究节点:未找到有效的用户问题",
|
||||
name="Researcher"
|
||||
)
|
||||
],
|
||||
"next": "Supervisor"
|
||||
}
|
||||
return
|
||||
current_step = "正在启动深度报告生成..."
|
||||
yield {
|
||||
"messages": [
|
||||
AIMessage(
|
||||
content="正在启动深度报告生成...",
|
||||
name="Researcher",
|
||||
additional_kwargs={
|
||||
"current_step": current_step,
|
||||
"streaming": True
|
||||
}
|
||||
)
|
||||
]
|
||||
}
|
||||
async for chunk in research_agent.astream(
|
||||
{"messages": safe_messages[-12:]},
|
||||
config=config
|
||||
):
|
||||
|
||||
if "messages" not in chunk:
|
||||
continue
|
||||
|
||||
msgs = chunk["messages"]
|
||||
|
||||
if not isinstance(msgs, list):
|
||||
continue
|
||||
|
||||
for m in msgs:
|
||||
|
||||
# 1️⃣ token streaming
|
||||
if isinstance(m, AIMessageChunk):
|
||||
yield {"messages": [m]}
|
||||
|
||||
# 2️⃣ tool result → 只 stream
|
||||
elif isinstance(m, ToolMessage):
|
||||
yield {
|
||||
"custom": {
|
||||
"type": "tool_result",
|
||||
"tool_name": m.name,
|
||||
"content": m.content
|
||||
}
|
||||
}
|
||||
|
||||
# 3️⃣ 最终 AI message → 写入 state
|
||||
elif isinstance(m, AIMessage):
|
||||
yield {"messages": [m]}
|
||||
@@ -1,29 +0,0 @@
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from src.server.agent.agents.init_llm import get_model
|
||||
from src.server.agent.memory.memory_manager import MemoryManager
|
||||
|
||||
from src.server.agent.state import AgentState
|
||||
from src.server.utils.generate_suggestion import generate_chat_suggestions
|
||||
|
||||
|
||||
async def suggester_node(state: AgentState, config: RunnableConfig):
|
||||
"""专门生成追问建议的节点,作为流程终点"""
|
||||
model = get_model(config)
|
||||
# messages = state["messages"]
|
||||
messages = MemoryManager.build_llm_context(state)
|
||||
|
||||
|
||||
# 只需要分析最近的对话
|
||||
suggestions = await generate_chat_suggestions(messages, model)
|
||||
|
||||
# 返回一个特殊消息,前端通过解析 additional_kwargs 获取按钮内容
|
||||
return {
|
||||
"messages": [
|
||||
AIMessage(
|
||||
content="",
|
||||
additional_kwargs={"suggestions": suggestions},
|
||||
name="Suggester"
|
||||
)
|
||||
]
|
||||
}
|
||||
@@ -1,30 +0,0 @@
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
from src.server.agent.agents.init_llm import summary_llm
|
||||
from src.server.agent.memory.memory_manager import MemoryManager
|
||||
from src.server.agent.state import AgentState
|
||||
|
||||
|
||||
async def summary_node(state: AgentState):
|
||||
messages = state["messages"]
|
||||
|
||||
if not MemoryManager.should_summarize(messages):
|
||||
return {}
|
||||
|
||||
# 只总结旧消息
|
||||
old_messages = messages[:-30]
|
||||
|
||||
text = "\n".join([m.content for m in old_messages if hasattr(m, "content")])
|
||||
|
||||
prompt = f"""
|
||||
Summarize the following conversation briefly.
|
||||
Focus on user goals, preferences and decisions.
|
||||
|
||||
{text}
|
||||
"""
|
||||
|
||||
summary = await summary_llm.ainvoke([HumanMessage(content=prompt)])
|
||||
|
||||
return {
|
||||
"conversation_summary": summary.content
|
||||
}
|
||||
@@ -1,87 +0,0 @@
|
||||
import random
|
||||
from typing import Literal
|
||||
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from pydantic import BaseModel
|
||||
|
||||
from src.server.agent.agents.init_llm import llm_supervisor
|
||||
from src.server.agent.memory.memory_manager import MemoryManager
|
||||
from src.server.agent.state import AgentState
|
||||
|
||||
|
||||
class RouteResponse(BaseModel):
|
||||
next: Literal["Designer", "Persona", "Researcher", "Visualizer", "Suggester", "FINISH"]
|
||||
|
||||
|
||||
def supervisor_node(state: AgentState, config: RunnableConfig):
|
||||
if state.get("__end__", False):
|
||||
return {"next": "FINISH"}
|
||||
|
||||
# 如果最后一条是 terminate 的结果,也结束
|
||||
last_msg = state["messages"][-1] if state["messages"] else None
|
||||
if isinstance(last_msg, ToolMessage) and last_msg.name == "terminate":
|
||||
return {"next": "FINISH"}
|
||||
|
||||
configurable = config["configurable"]
|
||||
use_report = configurable.get("use_report", False)
|
||||
suggest_frequency = configurable.get("require_suggestion", 0.6)
|
||||
|
||||
messages = MemoryManager.build_llm_context(state)
|
||||
# 读取关键状态(必须有!)
|
||||
persona_complete = state.get("persona_complete", False)
|
||||
print(f"persona_complete : {persona_complete}")
|
||||
|
||||
# 第一次对话或无消息 → 强制去 Persona 收集画像
|
||||
if not messages:
|
||||
return {"next": "Persona"}
|
||||
|
||||
# ── system prompt ── 加强语气 + 明确优先级
|
||||
system_prompt = f"""
|
||||
你是家具设计主管,**必须严格按以下优先级顺序**决定下一个节点,**不允许有任何例外**。
|
||||
|
||||
当前状态(必须严格遵守):
|
||||
- persona_complete: {state.get("persona_complete", False)}
|
||||
- 是否需要市场研究报告 (use_report): {'是' if use_report else '否'}
|
||||
- 用户最新消息:请仔细阅读
|
||||
|
||||
绝对优先级规则(从高到低,必须逐条检查):
|
||||
1. 如果 persona_complete == False,**必须** 且 **只能** 选择 "Persona",其他节点一律不允许
|
||||
2. 当 persona_complete == True 时,且用户确认需要报告(或未明确拒绝),优先选择 Researcher。报告的重点是:当前风格、家具类型、颜色在市场上的流行趋势、材质搭配建议、设计师案例
|
||||
3. 如果 use_report == False,**绝对禁止** 选择 "Researcher"
|
||||
4. 其他常见情况:
|
||||
- 纯风格、尺寸、材质、功能、灵感讨论 → "Designer"
|
||||
- 用户说想看图、效果图、草图 → "Visualizer"
|
||||
- 对话自然结束或用户满意 → "FINISH"
|
||||
- 需要给用户选项/建议按钮 → "Suggester"
|
||||
|
||||
当前用户需求是否明显需要市场报告?请严格判断,不要主观臆断。
|
||||
如果条件 2 满足,**必须** 选 Researcher,不要选 Designer。
|
||||
|
||||
输出时只选择一个 next,不要多选或发明节点。
|
||||
"""
|
||||
|
||||
chain = llm_supervisor.with_structured_output(RouteResponse)
|
||||
try:
|
||||
decision = chain.invoke([{"role": "system", "content": system_prompt}] + messages)
|
||||
next_node = decision.next
|
||||
except Exception as e:
|
||||
# LLM 输出格式错误时的兜底
|
||||
print(f"Supervisor LLM 决策失败: {e}")
|
||||
next_node = "Persona" if not persona_complete else "Suggester"
|
||||
|
||||
# 强制安全阀:双重检查,防止 LLM 违规
|
||||
if not persona_complete:
|
||||
if next_node != "Persona":
|
||||
print(f"强制修正:persona 不完整,LLM 选了 {next_node},改为 Persona")
|
||||
next_node = "Persona"
|
||||
|
||||
elif next_node == "Researcher" and (not use_report or not persona_complete):
|
||||
print(f"警告:非法选择 Researcher,已强制改为 Suggester")
|
||||
next_node = "Suggester"
|
||||
|
||||
# FINISH 时随机插入 Suggester(保持原逻辑)
|
||||
if next_node == "FINISH" and suggest_frequency > 0 and random.random() < suggest_frequency:
|
||||
next_node = "Suggester"
|
||||
|
||||
return {"next": next_node}
|
||||
@@ -1,33 +0,0 @@
|
||||
# --- 3. Visualizer Agent (视觉专家) ---
|
||||
from langchain_core.messages import SystemMessage, ToolMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
from src.server.agent.agents.init_llm import get_model
|
||||
from src.server.agent.memory.memory_manager import MemoryManager
|
||||
from src.server.agent.prompt import visualizer_prompt
|
||||
from src.server.agent.state import AgentState
|
||||
from src.server.agent.tools.generate_furniture_sketch import generate_furniture
|
||||
|
||||
|
||||
async def visualizer_node(state: AgentState, config: RunnableConfig):
|
||||
"""负责将自然语言转化为绘图 Prompt 并调用绘图工具"""
|
||||
model = get_model(config, streaming=False)
|
||||
tools = [generate_furniture]
|
||||
llm_with_tools = model.bind_tools(tools)
|
||||
# messages = state["messages"]
|
||||
messages = MemoryManager.build_llm_context(state)
|
||||
system_prompt = SystemMessage(content=visualizer_prompt)
|
||||
response = await llm_with_tools.ainvoke([system_prompt] + messages)
|
||||
|
||||
if response.tool_calls:
|
||||
tool_call = response.tool_calls[0]
|
||||
if tool_call["name"] == "generate_furniture":
|
||||
img_url = await generate_furniture.ainvoke(tool_call["args"])
|
||||
return {
|
||||
"messages": [
|
||||
response,
|
||||
ToolMessage(content=img_url, tool_call_id=tool_call["id"]) # 标记这是一个图片结果
|
||||
]
|
||||
}
|
||||
|
||||
return {"messages": [response]}
|
||||
@@ -1,56 +1,109 @@
|
||||
import random
|
||||
from typing import Literal
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langchain_qwq import ChatQwen
|
||||
from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer
|
||||
from langgraph.graph import StateGraph, END, START
|
||||
from pydantic import BaseModel
|
||||
from pymongo import MongoClient
|
||||
|
||||
from src.core.config import MONGO_URI, settings
|
||||
from src.server.agent.agents.designer import designer_node
|
||||
from src.server.agent.agents.persona import persona_node
|
||||
from src.server.agent.agents.researcher import researcher_node
|
||||
from src.server.agent.agents.suggester import suggester_node
|
||||
from src.server.agent.agents.summary import summary_node
|
||||
from src.server.agent.agents.supervisor import supervisor_node
|
||||
from src.server.agent.agents.visualizer import visualizer_node
|
||||
from src.server.agent.state import AgentState
|
||||
from src.server.agent.agents import designer_node, researcher_node, visualizer_node, suggester_node
|
||||
from langgraph.checkpoint.mongodb import MongoDBSaver
|
||||
|
||||
|
||||
# --- Supervisor (路由逻辑) ---
|
||||
# 定义路由的输出结构,强制 LLM 选择一个
|
||||
class RouteResponse(BaseModel):
|
||||
# 将 FINISH 替换或增加 Suggester
|
||||
next: Literal["Designer", "Researcher", "Visualizer", "Suggester", "FINISH"]
|
||||
|
||||
|
||||
llm_supervisor = ChatQwen(
|
||||
model="qwen3.5-flash",
|
||||
max_tokens=3_000,
|
||||
timeout=None,
|
||||
max_retries=2,
|
||||
api_key=settings.QWEN_API_KEY)
|
||||
|
||||
|
||||
def supervisor_node(state: AgentState, config: RunnableConfig):
|
||||
configurable = config["configurable"]
|
||||
use_report = configurable.get("use_report", False)
|
||||
suggest_frequency = configurable.get("require_suggestion", 0.6) # 0.0~1.0
|
||||
|
||||
messages = state["messages"]
|
||||
if not messages:
|
||||
return {"next": "Suggester"}
|
||||
|
||||
# ── system prompt 保持不变 ──
|
||||
system_prompt = f"""你是家具设计主管,负责分配任务。
|
||||
当前设定:
|
||||
- 是否需要市场研究报告:{'是' if use_report else '否'}
|
||||
|
||||
严格遵守以下规则:
|
||||
- 如果 **不需要** 市场研究报告(use_report = False),**绝对不能** 选择 Researcher
|
||||
- 只有在 **明确需要** 市场报告、竞争分析、材质趋势、价格区间等外部资讯时,才选择 Researcher,且 **必须** use_report = True
|
||||
- 常见分配:
|
||||
- 纯设计、风格、尺寸、材质建议 → Designer
|
||||
- 需要生成图片、渲染 → Visualizer
|
||||
- 需要产生建议按钮 → Suggester
|
||||
- 需要市场报告 → Researcher(但只有 use_report=True 时才允许)
|
||||
- 对话已完整、无需继续 → FINISH
|
||||
|
||||
用户最后说了什么?请根据实际需求决定下一步。
|
||||
"""
|
||||
|
||||
chain = llm_supervisor.with_structured_output(RouteResponse)
|
||||
decision = chain.invoke([{"role": "system", "content": system_prompt}] + messages)
|
||||
next_node = decision.next # 防空默认 FINISH
|
||||
|
||||
# 安全阀:禁止非法选择 Researcher
|
||||
if next_node == "Researcher" and not use_report:
|
||||
print("警告:LLM 违规选择了 Researcher,已强制改为 Suggester 或 FINISH")
|
||||
next_node = "Suggester" if state.get("require_suggestion", False) else "FINISH"
|
||||
|
||||
# 核心改动:只有 LLM 决定 FINISH 时,才掷骰子看是否插入 Suggester
|
||||
if next_node == "FINISH":
|
||||
# 满足概率条件 → 插入 Suggester
|
||||
if suggest_frequency > 0 and random.random() < suggest_frequency:
|
||||
next_node = "Suggester"
|
||||
|
||||
return {"next": next_node}
|
||||
|
||||
|
||||
# --- 构建 Graph ---
|
||||
workflow = StateGraph(AgentState)
|
||||
|
||||
workflow.add_node("Supervisor", supervisor_node)
|
||||
workflow.add_node("Designer", designer_node)
|
||||
workflow.add_node("Persona", persona_node)
|
||||
workflow.add_node("Researcher", researcher_node)
|
||||
workflow.add_node("Visualizer", visualizer_node)
|
||||
workflow.add_node("Suggester", suggester_node)
|
||||
workflow.add_node("Summary", summary_node)
|
||||
|
||||
# workflow.add_edge(START, "Supervisor")
|
||||
workflow.add_edge(START, "Summary")
|
||||
workflow.add_edge("Summary", "Supervisor")
|
||||
workflow.add_edge("Designer", "Supervisor")
|
||||
workflow.add_edge("Persona", "Supervisor")
|
||||
workflow.add_edge("Researcher", "Supervisor")
|
||||
workflow.add_edge("Visualizer", "Supervisor")
|
||||
# 重点:Suggester 可以是整个流程的终点
|
||||
workflow.add_edge("Suggester", END)
|
||||
|
||||
workflow.add_edge(START, "Supervisor")
|
||||
|
||||
# 修改条件边映射
|
||||
workflow.add_conditional_edges(
|
||||
"Supervisor",
|
||||
lambda state: "FINISH" if state.get("__end__", False) or state["next"] == "FINISH" else state["next"],
|
||||
lambda state: state["next"],
|
||||
{
|
||||
"Designer": "Designer",
|
||||
"Persona": "Persona",
|
||||
"Researcher": "Researcher",
|
||||
"Visualizer": "Visualizer",
|
||||
"Suggester": "Suggester",
|
||||
"FINISH": END
|
||||
"Suggester": "Suggester", # 原本的 FINISH 现在指向 Suggester
|
||||
"FINISH": END # 直接结束,不给建议
|
||||
}
|
||||
)
|
||||
|
||||
# 专家执行完依然回到 Supervisor
|
||||
workflow.add_edge("Designer", "Supervisor")
|
||||
workflow.add_edge("Researcher", "Supervisor")
|
||||
workflow.add_edge("Visualizer", "Supervisor")
|
||||
# 重点:Suggester 可以是整个流程的终点
|
||||
workflow.add_edge("Suggester", END)
|
||||
|
||||
client = MongoClient(MONGO_URI)
|
||||
checkpointer = MongoDBSaver(
|
||||
client=client["furniture_agent_db"],
|
||||
|
||||
@@ -1,60 +0,0 @@
|
||||
from typing import List
|
||||
from langchain_core.messages import (
|
||||
BaseMessage,
|
||||
SystemMessage,
|
||||
HumanMessage,
|
||||
AIMessage
|
||||
)
|
||||
|
||||
MAX_RECENT_MESSAGES = 25
|
||||
SUMMARY_TRIGGER = 120
|
||||
|
||||
|
||||
class MemoryManager:
|
||||
|
||||
@staticmethod
|
||||
def split_messages(messages: List[BaseMessage]):
|
||||
"""
|
||||
分离 system / conversation
|
||||
"""
|
||||
system_msgs = []
|
||||
conversation_msgs = []
|
||||
|
||||
for m in messages:
|
||||
if isinstance(m, SystemMessage):
|
||||
system_msgs.append(m)
|
||||
else:
|
||||
conversation_msgs.append(m)
|
||||
|
||||
return system_msgs, conversation_msgs
|
||||
|
||||
@staticmethod
|
||||
def build_llm_context(state) -> List[BaseMessage]:
|
||||
"""
|
||||
构建发送给 LLM 的 context
|
||||
"""
|
||||
messages = state["messages"]
|
||||
|
||||
system_msgs, conversation_msgs = MemoryManager.split_messages(messages)
|
||||
|
||||
summary = state.get("conversation_summary")
|
||||
|
||||
recent_msgs = conversation_msgs[-MAX_RECENT_MESSAGES:]
|
||||
|
||||
context = []
|
||||
|
||||
context.extend(system_msgs)
|
||||
|
||||
if summary:
|
||||
context.append(SystemMessage(content=f"Conversation Summary:\n{summary}"))
|
||||
|
||||
context.extend(recent_msgs)
|
||||
|
||||
return context
|
||||
|
||||
@staticmethod
|
||||
def should_summarize(messages: List[BaseMessage]) -> bool:
|
||||
"""
|
||||
判断是否需要 summary
|
||||
"""
|
||||
return len(messages) > SUMMARY_TRIGGER
|
||||
@@ -1,57 +1,78 @@
|
||||
SYSTEM_PROMPT = """
|
||||
你现在是 "TrendAgent",一个极度专注、高效的家具设计趋势分析代理。
|
||||
你的**唯一目标**:针对用户一次请求,产出一份高质量的 Markdown 趋势报告。
|
||||
你**只能**按照下面严格的执行流程工作,不允许有任何偏差。
|
||||
You are "TrendAgent" - a focused, efficient design trend analysis agent.
|
||||
Your ONLY goal: produce one high-quality Markdown trend report per user request.
|
||||
|
||||
核心铁律(违反即失效,必须严格执行):
|
||||
- 永远保持自然、亲切、像资深家具设计师的对话语气
|
||||
- 画像收集阶段**只靠自然对话**,**绝不调用任何工具**
|
||||
- 一旦你判断用户画像已足够(风格、家具类型、颜色、材质至少 3 项明确),**立即停止所有追问**,说一句类似“好的,我已经掌握您的核心需求,现在开始规划并生成报告~”,然后**必须立刻进入规划阶段**
|
||||
- **绝对禁止** 在画像完整后继续问任何问题(包括人群、预算、场景、尺寸、材质细节等)
|
||||
- **绝对禁止** 使用任何流程化、机械、状态相关的词语,如 STATUS、Phase、请先完成、按照流程、现在进入下一步等
|
||||
TOOL ORDER & DISCIPLINE IS MANDATORY - DO NOT INVENT STEPS
|
||||
|
||||
画像完整判断铁律(内心执行,永不告诉用户):
|
||||
- 必须同时满足以下至少 3 项:
|
||||
- 风格偏好(北欧、极简、现代、日式等)
|
||||
- 家具类型(沙发、餐桌、床、书柜、灯具等)
|
||||
- 颜色偏好(白色、原木色、深灰、莫兰迪色等)
|
||||
- 材质偏好(棉麻、实木、皮革、金属等)
|
||||
- 一旦达到 3 项,**立即**视为完整,**禁止**再问任何问题
|
||||
┌───────────────────────────────────────────────────────┐
|
||||
│ Phase 0 - Context & Persona (必须先完成) │
|
||||
└───────────────────────────────────────────────────────┘
|
||||
|
||||
执行流程(**必须严格按此顺序,不可跳跃、不可重复、不可插入任何额外对话**):
|
||||
Rules for Phase 0:
|
||||
1. ALWAYS start with manage_user_persona(command="get")
|
||||
2. If STATUS == "INCOMPLETE" or persona missing critical fields (Design Type, Style, Target Audience, Color Preference, etc.):
|
||||
→ MUST call manage_user_persona(command="ask") to collect missing info
|
||||
→ After user answers → call manage_user_persona(command="set", ...)
|
||||
→ Loop until STATUS == "READY"
|
||||
3. Only when STATUS == "READY" → proceed to Phase 1
|
||||
4. Never assume or fabricate persona details
|
||||
|
||||
1. 画像收集阶段:
|
||||
- 只用自然对话补全信息
|
||||
- 如果不足 3 项核心信息,顺势自然问 1-2 个最关键的问题
|
||||
- 一旦足够,立即说“好的,我已经掌握核心需求,现在开始规划并生成报告~”,然后**必须立刻**进入第 2 步
|
||||
┌───────────────────────────────────────────────────────┐
|
||||
│ Phase 1 - Planning (必须执行一次且只能一次) │
|
||||
└───────────────────────────────────────────────────────┘
|
||||
|
||||
2. 规划阶段(**必须且只能执行一次**):
|
||||
- **立即调用 get_user_persona 工具**,获取最新画像 JSON
|
||||
- 根据返回的 persona 字段构造关键词(例如:"2025-2026 北欧 沙发 白色 棉麻 趋势")
|
||||
- **立即调用 write_todos 工具一次**,生成严格的 ToDo 列表,内容**必须且只能**是以下顺序:
|
||||
1. topic_research:搜索上面构造的关键词,返回 3-5 个高质量网址
|
||||
2. crawl4ai_batch:批量爬取上面选定的网址
|
||||
3. structured_retrieval:对爬取内容进行结构化提取(重点:设计趋势、材质创新、颜色应用、代表案例、品牌参考)
|
||||
4. report_generator:基于提取内容生成完整 Markdown 报告
|
||||
5. terminate
|
||||
- **严禁** 添加任何其他步骤、询问用户、生成中间总结或额外对话
|
||||
When persona READY and user gave a clear trend request:
|
||||
1. Call write_todos EXACTLY ONCE with a strict plan containing:
|
||||
- 3–6 concrete steps (numbered)
|
||||
- Which URLs/topics to research
|
||||
- Expected output of each major tool
|
||||
- Final deliverable: one Markdown report
|
||||
2. After receiving todos, you MUST follow this exact sequence unless impossible
|
||||
3. Do NOT call any other tool until write_todos is done
|
||||
|
||||
3. 执行阶段:
|
||||
- **严格按 write_todos 返回的顺序逐一调用工具**
|
||||
- **不允许** 跳过任何步骤、重复调用、插入对话
|
||||
┌───────────────────────────────────────────────────────┐
|
||||
│ Phase 2 - Research & Collection │
|
||||
└───────────────────────────────────────────────────────┘
|
||||
|
||||
4. 报告生成后:
|
||||
- **直接调用 terminate** 结束流程
|
||||
Follow todos order:
|
||||
- Use topic_research → get 3–8 high-quality URLs (add persona [Style] [Type] in query)
|
||||
- Select best 3–6 URLs → call crawl4ai_batch ONCE with list
|
||||
- Get file paths → call structured_retrieval ONCE with file_paths list
|
||||
|
||||
报告要求(必须遵守):
|
||||
- 每部分先写 **Conclusion First** 的核心洞察
|
||||
- 在合适位置插入 [IMAGE_REF_xx] 占位符
|
||||
- 所有内容基于真实检索内容,**绝不虚构**
|
||||
┌───────────────────────────────────────────────────────┐
|
||||
│ Phase 3 - Synthesis & Delivery │
|
||||
└───────────────────────────────────────────────────────┘
|
||||
|
||||
现在开始:
|
||||
- 用自然亲切的语气直接回应用户消息
|
||||
- 如果画像不足 3 项核心信息,顺势自然问最关键的问题
|
||||
- 一旦足够,**立即**说一句过渡语,然后**必须**调用 write_todos
|
||||
After structured_retrieval summary received:
|
||||
- If extracted item count ≥ 8–12 AND covers main aspects in todos → ready to report
|
||||
- Call report_generator ONCE (it reads local JSON/DB)
|
||||
- After report_generator success → call terminate
|
||||
- If data obviously insufficient → call topic_research again (max 1 extra round)
|
||||
|
||||
┌───────────────────────────────────────────────────────┐
|
||||
│ HARD RULES - MUST OBEY │
|
||||
└───────────────────────────────────────────────────────┘
|
||||
|
||||
• Never load full JSON/markdown into context - trust local storage
|
||||
• Batch everything possible (crawl4ai_batch + structured_retrieval)
|
||||
• Call tools in PHASE ORDER - no jumping, no repetition
|
||||
• After report_generator → next action MUST be terminate
|
||||
• If stuck > 4 steps without progress → call terminate with note "Incomplete - insufficient data"
|
||||
• Never hallucinate trend data - base everything on retrieved content
|
||||
• Report must start each section with **Conclusion First** insight
|
||||
• Include [IMAGE_REF_xx] placeholders where visuals were extracted
|
||||
|
||||
Current status: Phase 0
|
||||
"""
|
||||
|
||||
designer_prompt = """
|
||||
你是家具设计团队的主管(Supervisor)。
|
||||
请根据用户的意图,选择最合适的专家:
|
||||
- Designer: 设计建议、参数细化、闲聊、问候。
|
||||
- Visualizer: 绘图、看草图。
|
||||
- Researcher: 市场报告、趋势。
|
||||
|
||||
只需输出专家名称。
|
||||
"""
|
||||
|
||||
visualizer_prompt = """
|
||||
@@ -83,7 +104,6 @@ Prompt 生成要求(仅供内部参考,必须全部做到):
|
||||
- 最后立即调用工具:generate_furniture,参数 prompt = 你刚才输出的完整内容
|
||||
- 不要做其他任何说明或聊天
|
||||
"""
|
||||
|
||||
designer_prompt = """
|
||||
你是一位资深的家具设计师,经验丰富、审美一流、沟通温暖且高效。
|
||||
你的核心目标:快速理解用户想法,并用最合适的方式推进设计。
|
||||
|
||||
@@ -5,7 +5,10 @@ from src.server.agent.graph import app
|
||||
|
||||
|
||||
async def async_main():
|
||||
config = {"configurable": {"thread_id": "project_alpha12345", "use_report": True}}
|
||||
config = {"configurable": {"thread_id": "project_alpha"}}
|
||||
|
||||
print("測試模式已啟動 (輸入 'exit' 離開,'history' 查看歷史並回溯)")
|
||||
use_report = input("是否启用深度报告?(y/n): ").lower() == 'y'
|
||||
while True:
|
||||
user_input = input("\n👤 輸入訊息: ").strip()
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import operator
|
||||
from typing import Annotated, Sequence, TypedDict, Union, Optional, Dict, Any
|
||||
from typing import Annotated, Sequence, TypedDict, Union, Optional
|
||||
from langchain_core.messages import BaseMessage
|
||||
|
||||
|
||||
@@ -8,8 +8,4 @@ class AgentState(TypedDict):
|
||||
messages: Annotated[Sequence[BaseMessage], operator.add]
|
||||
# next 存储 Supervisor 决定的下一步是谁
|
||||
next: str
|
||||
persona: Dict[str, Any] # 存储提取出的结构化画像,例如 {"风格偏好": "北欧", "预算": "8000-12000", ...}
|
||||
persona_summary: str # 可选:LLM 对当前画像的自然语言总结,便于 prompt 使用
|
||||
persona_complete: bool # Supervisor 用这个判断是否能去 Researcher
|
||||
require_suggestion: bool # 是否需要建议按钮
|
||||
__end__: bool # 新增这个字段,默认 False
|
||||
|
||||
@@ -1,20 +0,0 @@
|
||||
from typing import Optional, Dict, Any
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ToolResult(BaseModel):
|
||||
"""
|
||||
DeepAgents Tool 返回标准结构
|
||||
"""
|
||||
|
||||
# 返回给 LLM / 用户的内容
|
||||
content: Optional[str] = None
|
||||
|
||||
# 是否出错
|
||||
success: bool = True
|
||||
|
||||
# 工具元信息(推荐放 tool_name / path / cost 等)
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
|
||||
# 是否终止 agent
|
||||
terminate: bool = False
|
||||
@@ -1,189 +1,118 @@
|
||||
import time
|
||||
import asyncio
|
||||
from typing import List, Dict, Any
|
||||
from typing import List
|
||||
from urllib.parse import urlparse
|
||||
from pathlib import Path
|
||||
|
||||
import uuid
|
||||
from crawl4ai import AsyncWebCrawler, BrowserConfig, CrawlerRunConfig, CacheMode
|
||||
from langchain_core.tools import tool
|
||||
|
||||
# ─────────────────────────────────────
|
||||
# 路径配置
|
||||
# ─────────────────────────────────────
|
||||
|
||||
# ─────────────── 重要:計算路徑 ───────────────
|
||||
# 目前這個檔案 (crawl4ai_batch.py) 所在的目錄
|
||||
TOOL_DIR = Path(__file__).resolve().parent
|
||||
|
||||
# 專案根目錄(假設 tools 資料夾與主程式同級)
|
||||
PROJECT_ROOT = TOOL_DIR.parent
|
||||
|
||||
# DeepAgents 推荐目录
|
||||
SAVE_DIR = PROJECT_ROOT / "agent_workspace" / "raw_data"
|
||||
# 儲存爬取結果的目錄(你可以自由決定放在哪裡)
|
||||
# 建議選項 A:放在專案根目錄下的 workspace/raw_data
|
||||
SAVE_DIR = PROJECT_ROOT / "workspace" / "raw_data"
|
||||
|
||||
# 建議選項 B:如果你打算讓 deep agent 直接讀取,建議放在 agent_workspace 底下
|
||||
# SAVE_DIR = PROJECT_ROOT / "agent_workspace" / "raw_data"
|
||||
|
||||
# 確保目錄存在
|
||||
SAVE_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# ─────────────────────────────────────
|
||||
# Browser 配置
|
||||
# ─────────────────────────────────────
|
||||
|
||||
browser_config = BrowserConfig(
|
||||
headless=True,
|
||||
verbose=False,
|
||||
java_script_enabled=True,
|
||||
user_agent=(
|
||||
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
|
||||
"AppleWebKit/537.36 (KHTML, like Gecko) "
|
||||
"Chrome/118.0 Safari/537.36"
|
||||
),
|
||||
)
|
||||
|
||||
run_config = CrawlerRunConfig(
|
||||
cache_mode=CacheMode.BYPASS,
|
||||
word_count_threshold=5,
|
||||
excluded_tags=["script", "style", "nav", "footer"],
|
||||
remove_overlay_elements=True,
|
||||
process_iframes=True,
|
||||
)
|
||||
# ────────────────────────────────────────────────
|
||||
|
||||
|
||||
# ─────────────────────────────────────
|
||||
# URL → 文件名
|
||||
# ─────────────────────────────────────
|
||||
|
||||
def build_filename(url: str) -> str:
|
||||
parsed = urlparse(url)
|
||||
|
||||
domain = parsed.netloc.replace("www.", "").replace(".", "_")
|
||||
path_part = parsed.path.strip("/").replace("/", "_")[:50] or "index"
|
||||
|
||||
ts = int(time.time())
|
||||
rand = uuid.uuid4().hex[:6]
|
||||
|
||||
return f"{ts}_{rand}_{domain}_{path_part}.md"
|
||||
|
||||
|
||||
# ─────────────────────────────────────
|
||||
# 单个 URL 抓取
|
||||
# ─────────────────────────────────────
|
||||
|
||||
async def crawl_one(crawler, url: str, sem: asyncio.Semaphore) -> Dict[str, Any]:
|
||||
async with sem:
|
||||
try:
|
||||
result = await crawler.arun(url=url, config=run_config)
|
||||
|
||||
if not result.success:
|
||||
return {
|
||||
"url": url,
|
||||
"success": False,
|
||||
"error": f"status={getattr(result, 'status_code', 'unknown')}"
|
||||
}
|
||||
|
||||
markdown = result.markdown or ""
|
||||
|
||||
if len(markdown) < 500:
|
||||
return {
|
||||
"url": url,
|
||||
"success": False,
|
||||
"error": "content too short"
|
||||
}
|
||||
|
||||
filename = build_filename(url)
|
||||
filepath = SAVE_DIR / filename
|
||||
|
||||
header = (
|
||||
f"<!-- Source: {url} -->\n"
|
||||
f"<!-- Saved: {time.strftime('%Y-%m-%d %H:%M:%S')} -->\n\n"
|
||||
)
|
||||
|
||||
with open(filepath, "w", encoding="utf-8") as f:
|
||||
f.write(header + markdown)
|
||||
|
||||
return {
|
||||
"url": url,
|
||||
"success": True,
|
||||
"file": str(filepath)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"url": url,
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
|
||||
# ─────────────────────────────────────
|
||||
# Async 主逻辑
|
||||
# ─────────────────────────────────────
|
||||
|
||||
async def _crawl4ai_batch(urls: List[str]) -> Dict[str, Any]:
|
||||
urls = list(set(urls)) # 去重
|
||||
|
||||
if not urls:
|
||||
return {"error": "no urls"}
|
||||
|
||||
sem = asyncio.Semaphore(5) # 并发限制
|
||||
|
||||
async with AsyncWebCrawler(config=browser_config) as crawler:
|
||||
|
||||
tasks = [
|
||||
crawl_one(crawler, url, sem)
|
||||
for url in urls
|
||||
]
|
||||
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
success_files = []
|
||||
summary = []
|
||||
|
||||
for r in results:
|
||||
|
||||
if r["success"]:
|
||||
success_files.append(r["file"])
|
||||
summary.append(f"✅ {r['url']}")
|
||||
else:
|
||||
summary.append(f"❌ {r['url']} ({r['error']})")
|
||||
|
||||
return {
|
||||
"saved_files": success_files,
|
||||
"count": len(success_files),
|
||||
"summary": summary,
|
||||
}
|
||||
|
||||
|
||||
# ─────────────────────────────────────
|
||||
# Tool(同步)
|
||||
# ─────────────────────────────────────
|
||||
@tool
|
||||
def crawl4ai_batch(urls: List[str]) -> str:
|
||||
async def crawl4ai_batch(urls: List[str]) -> str:
|
||||
"""
|
||||
Batch crawl webpages and save their content as markdown files.
|
||||
|
||||
Args:
|
||||
urls: List of webpage URLs to crawl.
|
||||
|
||||
Returns:
|
||||
A summary of crawling results and saved file paths.
|
||||
高性能网页爬虫,支持并行处理多个 URL。
|
||||
爬取后的 Markdown 内容将保存到本地 workspace/raw_data 目录中。
|
||||
返回执行结果摘要和保存的文件路径列表。
|
||||
"""
|
||||
if not urls:
|
||||
return "❌ 错误: 未提供任何 URL。"
|
||||
|
||||
# print(f"🕷️ 正在并行爬取 {len(urls)} 个 URL...")
|
||||
# print(f"儲存目錄: {SAVE_DIR}")
|
||||
|
||||
browser_config = BrowserConfig(
|
||||
headless=True,
|
||||
verbose=False,
|
||||
java_script_enabled=True,
|
||||
user_agent="Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
|
||||
"AppleWebKit/537.36 (KHTML, like Gecko) "
|
||||
"Chrome/118.0.5993.118 Safari/537.36",
|
||||
proxy=None, # 可选,如果需要代理填 "http://user:pass@ip:port"
|
||||
)
|
||||
|
||||
run_config = CrawlerRunConfig(
|
||||
cache_mode=CacheMode.BYPASS,
|
||||
word_count_threshold=5,
|
||||
excluded_tags=["script", "style", "nav", "footer"],
|
||||
remove_overlay_elements=True,
|
||||
process_iframes=True,
|
||||
)
|
||||
|
||||
results_summary = []
|
||||
saved_files = []
|
||||
|
||||
try:
|
||||
result = asyncio.run(_crawl4ai_batch(urls))
|
||||
async with AsyncWebCrawler(config=browser_config) as crawler:
|
||||
tasks = [crawler.arun(url=url, config=run_config) for url in urls]
|
||||
crawl_results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
if "error" in result:
|
||||
return f"❌ Error: {result['error']}"
|
||||
for i, result in enumerate(crawl_results):
|
||||
url = urls[i]
|
||||
|
||||
output = [
|
||||
"### 批量抓取完成 ###",
|
||||
f"成功保存文件: {result['count']}",
|
||||
f"保存目录: {SAVE_DIR}",
|
||||
"",
|
||||
"抓取详情:"
|
||||
]
|
||||
if isinstance(result, Exception):
|
||||
results_summary.append(f"❌ 抓取失败 {url}: {str(result)}")
|
||||
continue
|
||||
|
||||
output.extend(result["summary"])
|
||||
if result.success:
|
||||
markdown_content = result.markdown or ""
|
||||
|
||||
if result["saved_files"]:
|
||||
output.append("\n可读取文件:")
|
||||
output.extend(result["saved_files"])
|
||||
if len(markdown_content) < 500:
|
||||
results_summary.append(f"⏩ 跳过 {url} (内容过短)")
|
||||
continue
|
||||
|
||||
return "\n".join(output)
|
||||
# 生成檔名
|
||||
parsed = urlparse(url)
|
||||
domain = parsed.netloc.replace("www.", "").replace(".", "_")
|
||||
path_part = parsed.path.strip("/").replace("/", "_")[:50] or "index"
|
||||
filename = f"{int(time.time())}_{domain}_{path_part}.md"
|
||||
|
||||
# 完整檔案路徑
|
||||
filepath = SAVE_DIR / filename
|
||||
|
||||
# 寫入檔案
|
||||
with open(filepath, "w", encoding="utf-8") as f:
|
||||
header = f"<!-- Source: {url} -->\n<!-- Saved: {time.strftime('%Y-%m-%d %H:%M:%S')} -->\n\n"
|
||||
f.write(header + markdown_content)
|
||||
|
||||
saved_files.append(str(filepath)) # 建議轉成字串
|
||||
results_summary.append(f"✅ 成功: {url} → {filepath}")
|
||||
|
||||
else:
|
||||
status = getattr(result, 'status_code', '未知错误')
|
||||
results_summary.append(f"❌ 失败: {url} (状态码: {status})")
|
||||
|
||||
except Exception as e:
|
||||
return f"🚨 爬虫系统异常: {str(e)}"
|
||||
return f"🚨 爬虫系统崩溃: {str(e)}"
|
||||
|
||||
# 回傳給 agent 的結果
|
||||
final_output = (
|
||||
f"### 批量抓取完成 ###\n"
|
||||
f"已成功保存 {len(saved_files)} 个文件。\n"
|
||||
f"儲存目錄: {SAVE_DIR}\n"
|
||||
f"详情:\n" + "\n".join(results_summary)
|
||||
)
|
||||
|
||||
if saved_files:
|
||||
final_output += "\n\n已保存的文件列表(可供後續讀取):\n" + "\n".join(saved_files)
|
||||
|
||||
return final_output
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from google.oauth2 import service_account
|
||||
from langchain_core.tools import tool
|
||||
@@ -11,7 +9,6 @@ from minio import Minio
|
||||
from src.core.config import settings
|
||||
from src.server.utils.new_oss_client import oss_upload_image
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
# 初始化全局凭证和客户端
|
||||
creds = service_account.Credentials.from_service_account_file(
|
||||
settings.GOOGLE_GENAI_USE_VERTEXAI,
|
||||
@@ -65,30 +62,8 @@ def generate_furniture(prompt: str) -> str:
|
||||
# 4. 构造访问链接 (如果是私有 bucket,需使用 presigned_get_object)
|
||||
# 这里简单示例为直接访问地址
|
||||
image_url = f"{bucket}/{object_name}"
|
||||
return json.dumps(
|
||||
{
|
||||
"tool_name": "generate_furniture",
|
||||
"data": image_url,
|
||||
"tool_status": "success"
|
||||
},
|
||||
ensure_ascii=False
|
||||
)
|
||||
return image_url
|
||||
else:
|
||||
return json.dumps(
|
||||
{
|
||||
"tool_name": "generate_furniture",
|
||||
"data": "图片生成成功,但上传至存储服务器失败。",
|
||||
"tool_status": "error"
|
||||
},
|
||||
ensure_ascii=False
|
||||
)
|
||||
return "图片生成成功,但上传至存储服务器失败。"
|
||||
except Exception as e:
|
||||
logger.warning(e)
|
||||
return json.dumps(
|
||||
{
|
||||
"tool_name": "generate_furniture",
|
||||
"data": f"绘图流程异常",
|
||||
"tool_status": "error"
|
||||
},
|
||||
ensure_ascii=False
|
||||
)
|
||||
return f"绘图流程异常: {str(e)}"
|
||||
|
||||
@@ -105,7 +105,7 @@ Input Data:
|
||||
# =========================
|
||||
# 调用 LLM
|
||||
# =========================
|
||||
# writer({"type": "report_start", "topic": report_topic, "language": language})
|
||||
writer({"type": "report_start", "topic": report_topic, "language": language})
|
||||
|
||||
full_report = ""
|
||||
try:
|
||||
@@ -116,7 +116,7 @@ Input Data:
|
||||
if chunk.content: # Gemini 返回的 chunk.content
|
||||
delta = chunk.content
|
||||
full_report += delta
|
||||
# writer({"type": "report_delta", "delta": delta}) # ← 实时推送给前端
|
||||
writer({"type": "report_delta", "delta": delta}) # ← 实时推送给前端
|
||||
except Exception as e:
|
||||
error_msg = f"LLM generation failed: {str(e)}"
|
||||
writer({"type": "report_error", "message": error_msg})
|
||||
|
||||
@@ -12,7 +12,7 @@ TAVILY_API_KEY = settings.TAVILY_API_KEY
|
||||
|
||||
|
||||
@tool
|
||||
async def topic_research(topic: str, max_urls: int = 5) -> str:
|
||||
async def topic_research(topic: str, max_urls: int = 15) -> str:
|
||||
"""
|
||||
深度调研工具。该工具会利用 Tavily 搜索引擎针对特定主题进行多维度搜索。
|
||||
它会自动生成针对性的搜索词(包含年份和趋势),并返回去重后的高质量 URL 列表。
|
||||
|
||||
@@ -17,13 +17,22 @@ class TerminateInput(BaseModel):
|
||||
|
||||
|
||||
@tool(args_schema=TerminateInput)
|
||||
def terminate(status: str, reason: str = "") -> dict:
|
||||
def terminate(status: str, reason: str = "") -> str:
|
||||
"""
|
||||
终止本次互动。
|
||||
當任務完成、報告已生成,或無法繼續進行時,呼叫此工具來結束本次互動。
|
||||
|
||||
使用時機:
|
||||
- 已經成功產生最終報告(report_generator 已完成)
|
||||
- 遇到無法解決的錯誤或缺少關鍵資訊
|
||||
- 用戶需求已完全滿足
|
||||
|
||||
請在呼叫前確保所有必要步驟已完成,並在 reason 中簡單說明結束原因。
|
||||
"""
|
||||
return {
|
||||
"messages": [], # 清空追加消息
|
||||
"__end__": True, # 结束标记
|
||||
"status": status,
|
||||
"reason": reason
|
||||
}
|
||||
if status not in ("success", "failure"):
|
||||
status = "failure" # 防呆
|
||||
|
||||
msg = f"互動已終止,狀態:{status.upper()}"
|
||||
if reason:
|
||||
msg += f"\n原因:{reason}"
|
||||
|
||||
return msg
|
||||
|
||||
@@ -1,54 +1,96 @@
|
||||
from datetime import datetime
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langchain_core.tools import tool
|
||||
import json
|
||||
import os
|
||||
from typing import List, Literal, Optional, Dict, Any
|
||||
from langchain_core.tools import tool
|
||||
|
||||
from pymongo import MongoClient
|
||||
# 定义存储路径
|
||||
DB_PATH = os.path.join("workspace", "user_persona.json")
|
||||
|
||||
from src.core.config import MONGO_URI
|
||||
|
||||
client = MongoClient(MONGO_URI)
|
||||
db = client["furniture_agent_db"]
|
||||
persona_collection = db["user_persona"]
|
||||
def _load_store() -> Dict[str, Any]:
|
||||
"""从本地文件加载画像数据"""
|
||||
if os.path.exists(DB_PATH):
|
||||
try:
|
||||
with open(DB_PATH, "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
except Exception:
|
||||
return {}
|
||||
return {}
|
||||
|
||||
|
||||
def _save_store(data: Dict[str, Any]):
|
||||
"""将画像数据保存到本地文件"""
|
||||
os.makedirs(os.path.dirname(DB_PATH), exist_ok=True)
|
||||
with open(DB_PATH, "w", encoding="utf-8") as f:
|
||||
json.dump(data, f, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
@tool
|
||||
def get_user_persona(config: RunnableConfig) -> str:
|
||||
def manage_user_persona(
|
||||
command: Literal["set", "update", "get", "clear"],
|
||||
design_type: Optional[str] = None,
|
||||
style_preference: Optional[str] = None,
|
||||
budget_range: Optional[str] = None,
|
||||
color_palette: Optional[List[str]] = None,
|
||||
target_audience: Optional[str] = None,
|
||||
extra_requirements: Optional[str] = None
|
||||
) -> str:
|
||||
"""
|
||||
获取当前对话线程的用户画像信息。
|
||||
|
||||
参数:
|
||||
- thread_id: 可选,当前线程ID。如果不传,默认使用当前会话的 thread_id
|
||||
|
||||
返回:JSON 字符串,包含以下字段:
|
||||
- persona: Dict,用户画像(风格偏好、家具类型、颜色偏好等)
|
||||
- persona_complete: bool,画像是否已足够完整
|
||||
- last_updated: str,最后更新时间
|
||||
用户画像与设计偏好管理工具。
|
||||
用于设定、更新、获取或重置用户的设计上下文(如风格、预算、颜色)。
|
||||
Agent 在开始调研前必须先调用 get 获取画像,若关键信息缺失需引导用户补充。
|
||||
"""
|
||||
thread_id = config["configurable"]["thread_id"]
|
||||
if thread_id is None:
|
||||
thread_id = "current_thread_id_placeholder"
|
||||
# 每次调用都重新读取,确保多进程或重启后数据一致
|
||||
store = _load_store()
|
||||
|
||||
doc = persona_collection.find_one(
|
||||
{"thread_id": thread_id},
|
||||
sort=[("_id", -1)] # 最新一条
|
||||
)
|
||||
if command == "clear":
|
||||
if os.path.exists(DB_PATH):
|
||||
os.remove(DB_PATH)
|
||||
return "✅ 用户个性化模板已从本地文件清空。"
|
||||
|
||||
if not doc or "persona" not in doc:
|
||||
return json.dumps({
|
||||
"persona": {},
|
||||
"persona_complete": False,
|
||||
"last_updated": None,
|
||||
"message": "当前线程暂无用户画像信息"
|
||||
}, ensure_ascii=False, indent=2)
|
||||
if command == "get":
|
||||
if not store:
|
||||
return "⚠️ [缺失信息] 当前尚未配置画像。请询问用户:设计类型(如沙发)、风格偏好(如极简)等。"
|
||||
|
||||
last_updated = doc.get("updated_at")
|
||||
if isinstance(last_updated, datetime):
|
||||
last_updated = last_updated.strftime('%Y-%m-%d %H:%M:%S')
|
||||
# 格式化输出供 Agent 阅读
|
||||
res = [
|
||||
"--- 👤 实时用户画像 (本地存储) ---",
|
||||
f"🎯 类型: {store.get('design_type', '未设定')}",
|
||||
f"🎨 风格: {store.get('style_preference', '未设定')}",
|
||||
f"💰 预算: {store.get('budget_range', '未设定')}",
|
||||
f"🌈 色系: {', '.join(store.get('color_palette', [])) or '未设定'}",
|
||||
f"👥 受众: {store.get('target_audience', '未设定')}",
|
||||
f"📝 需求: {store.get('extra_requirements', '未设定')}",
|
||||
"-----------------------"
|
||||
]
|
||||
|
||||
return json.dumps({
|
||||
"persona": doc["persona"],
|
||||
"persona_complete": doc.get("persona_complete", False),
|
||||
"last_updated": last_updated,
|
||||
}, ensure_ascii=False, indent=2)
|
||||
# 逻辑检查
|
||||
if not store.get('design_type') or not store.get('style_preference'):
|
||||
res.append("\n⚠️ 关键信息缺失,建议补充 '设计类型' 和 '风格偏好'。")
|
||||
return "\n".join(res)
|
||||
|
||||
if command in ["set", "update"]:
|
||||
if command == "set":
|
||||
store = {} # 重置内存中的字典
|
||||
|
||||
# 提取传入的非空参数
|
||||
update_data = {
|
||||
"design_type": design_type,
|
||||
"style_preference": style_preference,
|
||||
"budget_range": budget_range,
|
||||
"color_palette": color_palette,
|
||||
"target_audience": target_audience,
|
||||
"extra_requirements": extra_requirements
|
||||
}
|
||||
|
||||
# 更新有效字段
|
||||
for k, v in update_data.items():
|
||||
if v is not None:
|
||||
store[k] = v
|
||||
|
||||
# 保存到文件
|
||||
_save_store(store)
|
||||
|
||||
return f"✅ 本地画像已同步。当前配置:\n{json.dumps(store, ensure_ascii=False, indent=2)}"
|
||||
|
||||
return "❌ 错误:未知命令。"
|
||||
|
||||
@@ -1,23 +0,0 @@
|
||||
from langchain_qwq import ChatQwen
|
||||
|
||||
from src.core.config import settings
|
||||
|
||||
llm = ChatQwen(
|
||||
model="qwen3.5-flash",
|
||||
max_tokens=3_000,
|
||||
timeout=None,
|
||||
max_retries=2,
|
||||
enable_thinking=False,
|
||||
api_key=settings.QWEN_API_KEY
|
||||
)
|
||||
|
||||
title_llm = ChatQwen(
|
||||
model="qwen-plus",
|
||||
max_tokens=3_000,
|
||||
timeout=None,
|
||||
max_retries=2,
|
||||
streaming=False,
|
||||
temperature=0.1,
|
||||
top_p=0.8,
|
||||
api_key=settings.QWEN_API_KEY
|
||||
)
|
||||
@@ -11,8 +11,8 @@ from src.core.config import MONGO_URI
|
||||
from src.server.deep_agent.agents.painter import painter_subagent
|
||||
from src.server.deep_agent.agents.researcher import research_subagent
|
||||
from src.server.deep_agent.agents.user_profile import user_profile_subagent
|
||||
from src.server.deep_agent.init_llm import main_llm
|
||||
from src.server.deep_agent.init_prompt import build_system_prompt
|
||||
from src.server.deep_agent.tools.report_generator_tool import llm
|
||||
|
||||
TOOL_DIR = Path(__file__).resolve().parent
|
||||
PROJECT_ROOT = TOOL_DIR.parent
|
||||
@@ -32,7 +32,7 @@ subagents = [
|
||||
|
||||
def build_main_agent(use_report):
|
||||
main_agent = create_deep_agent(
|
||||
model=llm,
|
||||
model=main_llm,
|
||||
system_prompt=build_system_prompt(use_report=use_report),
|
||||
subagents=subagents,
|
||||
checkpointer=checkpointer,
|
||||
@@ -42,7 +42,7 @@ def build_main_agent(use_report):
|
||||
),
|
||||
middleware=[
|
||||
SummarizationMiddleware(
|
||||
model=llm,
|
||||
model=main_llm,
|
||||
trigger=("tokens", 3000),
|
||||
keep=("messages", 100),
|
||||
),
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from langchain.agents.middleware import wrap_tool_call
|
||||
|
||||
from src.server.deep_agent.agents.init_llm import llm
|
||||
from src.server.deep_agent.init_llm import llm
|
||||
from src.server.deep_agent.init_prompt import build_painter_prompt
|
||||
from src.server.deep_agent.tools.generate_furniture_sketch import generate_furniture
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from src.server.deep_agent.agents.init_llm import llm
|
||||
from src.server.deep_agent.init_llm import llm
|
||||
from src.server.deep_agent.init_prompt import build_researcher_prompt
|
||||
from src.server.deep_agent.tools.crawl_tool import crawl4ai_batch
|
||||
from src.server.deep_agent.tools.report_generator_tool import report_generator
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from src.server.deep_agent.agents.init_llm import llm
|
||||
from src.server.deep_agent.init_llm import llm
|
||||
from src.server.deep_agent.init_prompt import build_user_persona_prompt
|
||||
from src.server.deep_agent.tools.user_persona_tool import query_report_profile, update_report_profile, check_profile_complete
|
||||
|
||||
|
||||
51
src/server/deep_agent/init_llm.py
Normal file
51
src/server/deep_agent/init_llm.py
Normal file
@@ -0,0 +1,51 @@
|
||||
from langchain_qwq import ChatQwen
|
||||
|
||||
from src.core.config import settings
|
||||
|
||||
llm = ChatQwen(
|
||||
model="qwen3.5-flash",
|
||||
max_tokens=3_000,
|
||||
timeout=None,
|
||||
max_retries=2,
|
||||
enable_thinking=False,
|
||||
api_key=settings.QWEN_API_KEY
|
||||
)
|
||||
|
||||
title_llm = ChatQwen(
|
||||
model="qwen-plus",
|
||||
max_tokens=3_000,
|
||||
timeout=None,
|
||||
max_retries=2,
|
||||
streaming=False,
|
||||
temperature=0.1,
|
||||
top_p=0.8,
|
||||
api_key=settings.QWEN_API_KEY
|
||||
)
|
||||
|
||||
main_llm = ChatQwen(
|
||||
model="qwen3.5-flash",
|
||||
temperature=0.2,
|
||||
max_tokens=3_000,
|
||||
timeout=None,
|
||||
max_retries=2,
|
||||
api_key=settings.QWEN_API_KEY)
|
||||
|
||||
suggested_llm = ChatQwen(
|
||||
model="qwen-plus",
|
||||
max_tokens=3_000,
|
||||
timeout=None,
|
||||
max_retries=2,
|
||||
streaming=False,
|
||||
temperature=0.1,
|
||||
top_p=0.8,
|
||||
api_key=settings.QWEN_API_KEY
|
||||
)
|
||||
|
||||
repoer_llm = ChatQwen(
|
||||
enable_thinking=False,
|
||||
model="qwen3.5-flash",
|
||||
temperature=0.2,
|
||||
max_tokens=3_000,
|
||||
timeout=None,
|
||||
max_retries=2,
|
||||
api_key=settings.QWEN_API_KEY)
|
||||
@@ -8,7 +8,7 @@ agent = build_main_agent(use_report=True)
|
||||
|
||||
|
||||
async def continuous_chat():
|
||||
thread_id = str(uuid.uuid4())
|
||||
thread_id = "c8e327fb-e208-4fab-83fd-b7b9c4d5fdd0"
|
||||
print("===== 家具设计助手(支持持续对话+记忆)=====")
|
||||
print("输入 'exit' 或 '退出' 结束对话\n")
|
||||
|
||||
@@ -25,13 +25,38 @@ async def continuous_chat():
|
||||
|
||||
print("\n助手:正在处理你的需求...\n")
|
||||
|
||||
current_config = {
|
||||
"recursion_limit": 120,
|
||||
"configurable": {
|
||||
"thread_id": thread_id,
|
||||
}
|
||||
}
|
||||
source_config = {
|
||||
"configurable": {
|
||||
"thread_id": thread_id,
|
||||
"checkpoint_id": '1f11dc17-be49-65a1-8000-96139f7c89cb'
|
||||
}
|
||||
}
|
||||
initial_messages = []
|
||||
older_state = await agent.aget_state(source_config)
|
||||
combined_values = older_state.values.copy()
|
||||
if initial_messages:
|
||||
combined_values["messages"] = list(combined_values.get("messages", [])) + initial_messages
|
||||
await agent.aupdate_state(current_config, combined_values)
|
||||
|
||||
# 现在可以安全使用 async for
|
||||
async for stream in agent.astream(
|
||||
{"messages": user_input},
|
||||
stream_mode=["updates", "messages", "custom"],
|
||||
subgraphs=True,
|
||||
version="v2",
|
||||
config={"configurable": {"thread_id": thread_id}}
|
||||
config={
|
||||
"configurable": {
|
||||
"thread_id": thread_id,
|
||||
'checkpoint_id': '1f11dc17-be49-65a1-8000-96139f7c89cb'
|
||||
}
|
||||
|
||||
}
|
||||
):
|
||||
|
||||
print(stream)
|
||||
@@ -61,7 +86,7 @@ async def continuous_chat():
|
||||
|
||||
elif mode == "custom":
|
||||
print(f"[report] {chunks.get('delta', '')}", end="")
|
||||
|
||||
print("end")
|
||||
# if chunk["type"] == "messages":
|
||||
# token, metadata = chunk["data"]
|
||||
# if not isinstance(token, AIMessageChunk):
|
||||
|
||||
0
src/server/deep_agent/tools/__init__.py
Normal file
0
src/server/deep_agent/tools/__init__.py
Normal file
@@ -1,6 +1,6 @@
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
|
||||
from src.server.deep_agent.agents.init_llm import title_llm
|
||||
from src.server.deep_agent.init_llm import title_llm
|
||||
|
||||
|
||||
def conversation_title(full_conversation):
|
||||
|
||||
75
src/server/deep_agent/tools/extract_suggested_questions.py
Normal file
75
src/server/deep_agent/tools/extract_suggested_questions.py
Normal file
@@ -0,0 +1,75 @@
|
||||
import json
|
||||
from typing import List
|
||||
|
||||
from langchain_core.messages import (
|
||||
HumanMessage,
|
||||
AIMessage,
|
||||
ToolMessage,
|
||||
)
|
||||
|
||||
from src.server.deep_agent.init_llm import suggested_llm
|
||||
|
||||
|
||||
def format_messages(messages, max_messages: int = 6) -> str:
|
||||
"""
|
||||
将 LangGraph messages 转换为 LLM prompt 文本
|
||||
"""
|
||||
messages = messages[-max_messages:]
|
||||
lines: List[str] = []
|
||||
for m in messages:
|
||||
if isinstance(m, HumanMessage):
|
||||
lines.append(f"User: {m.content}")
|
||||
elif isinstance(m, AIMessage):
|
||||
if m.content:
|
||||
lines.append(f"Assistant: {m.content}")
|
||||
elif isinstance(m, ToolMessage):
|
||||
# Tool结果建议简单化
|
||||
tool_output = str(m.content)
|
||||
if len(tool_output) > 200:
|
||||
tool_output = tool_output[:200] + "..."
|
||||
lines.append(f"Tool Result: {tool_output}")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
async def generate_suggested_questions(
|
||||
agent,
|
||||
thread_id: str,
|
||||
max_messages: int = 6,
|
||||
) -> List[str]:
|
||||
"""
|
||||
根据当前对话生成3条用户可能继续提问的问题
|
||||
"""
|
||||
# 获取当前对话state
|
||||
state = agent.get_state(
|
||||
{"configurable": {"thread_id": thread_id}}
|
||||
)
|
||||
messages = state.values.get("messages", [])
|
||||
if not messages:
|
||||
return []
|
||||
conversation = format_messages(messages, max_messages)
|
||||
|
||||
prompt = f"""
|
||||
以下是用户与AI助手的对话:
|
||||
{conversation}
|
||||
请根据对话内容,生成3条用户可能继续提出的问题。
|
||||
要求:
|
||||
- 每条一句话
|
||||
- 语言自然
|
||||
- 不要解释
|
||||
- 返回JSON数组
|
||||
- 尽量与家具设计相关
|
||||
示例:
|
||||
["问题1", "问题2", "问题3"]
|
||||
"""
|
||||
result = await suggested_llm.ainvoke(prompt)
|
||||
|
||||
text = result.content.strip()
|
||||
|
||||
try:
|
||||
questions = json.loads(text)
|
||||
|
||||
if isinstance(questions, list):
|
||||
return questions[:3]
|
||||
except Exception:
|
||||
pass
|
||||
return []
|
||||
@@ -2,27 +2,11 @@ import os
|
||||
import json
|
||||
import re
|
||||
from typing import Optional, List, Dict
|
||||
from langchain_qwq import ChatQwen
|
||||
from langgraph.config import get_stream_writer
|
||||
from pydantic import BaseModel, Field
|
||||
from langchain_core.tools import tool
|
||||
from langchain_core.messages import SystemMessage, HumanMessage
|
||||
|
||||
from src.core.config import settings
|
||||
|
||||
# =========================
|
||||
# LLM 初始化
|
||||
# =========================
|
||||
|
||||
|
||||
llm = ChatQwen(
|
||||
enable_thinking=False,
|
||||
model="qwen3.5-flash",
|
||||
temperature=0.2,
|
||||
max_tokens=3_000,
|
||||
timeout=None,
|
||||
max_retries=2,
|
||||
api_key=settings.QWEN_API_KEY)
|
||||
from src.server.deep_agent.init_llm import repoer_llm
|
||||
|
||||
|
||||
# =========================
|
||||
@@ -109,7 +93,7 @@ async def report_generator(
|
||||
|
||||
full_report = ""
|
||||
try:
|
||||
report_llm = llm.with_config(
|
||||
report_llm = repoer_llm.with_config(
|
||||
callbacks=[]
|
||||
)
|
||||
async for chunk in report_llm.astream(
|
||||
|
||||
Reference in New Issue
Block a user