新增对话接口
This commit is contained in:
@@ -1,23 +0,0 @@
|
||||
from langchain_qwq import ChatQwen
|
||||
|
||||
from src.core.config import settings
|
||||
|
||||
# Constructor arguments that both clients below are built with.
_SHARED_KWARGS = dict(
    max_tokens=3_000,
    timeout=None,
    max_retries=2,
    api_key=settings.QWEN_API_KEY,
)

# General-purpose client: qwen3.5-flash with thinking mode switched off.
llm = ChatQwen(
    model="qwen3.5-flash",
    enable_thinking=False,
    **_SHARED_KWARGS,
)

# Title client: qwen-plus, non-streaming, with low temperature / top_p.
title_llm = ChatQwen(
    model="qwen-plus",
    streaming=False,
    temperature=0.1,
    top_p=0.8,
    **_SHARED_KWARGS,
)
|
||||
@@ -11,8 +11,8 @@ from src.core.config import MONGO_URI
|
||||
from src.server.deep_agent.agents.painter import painter_subagent
|
||||
from src.server.deep_agent.agents.researcher import research_subagent
|
||||
from src.server.deep_agent.agents.user_profile import user_profile_subagent
|
||||
from src.server.deep_agent.init_llm import main_llm
|
||||
from src.server.deep_agent.init_prompt import build_system_prompt
|
||||
from src.server.deep_agent.tools.report_generator_tool import llm
|
||||
|
||||
TOOL_DIR = Path(__file__).resolve().parent
|
||||
PROJECT_ROOT = TOOL_DIR.parent
|
||||
@@ -32,7 +32,7 @@ subagents = [
|
||||
|
||||
def build_main_agent(use_report):
|
||||
main_agent = create_deep_agent(
|
||||
model=llm,
|
||||
model=main_llm,
|
||||
system_prompt=build_system_prompt(use_report=use_report),
|
||||
subagents=subagents,
|
||||
checkpointer=checkpointer,
|
||||
@@ -42,7 +42,7 @@ def build_main_agent(use_report):
|
||||
),
|
||||
middleware=[
|
||||
SummarizationMiddleware(
|
||||
model=llm,
|
||||
model=main_llm,
|
||||
trigger=("tokens", 3000),
|
||||
keep=("messages", 100),
|
||||
),
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from langchain.agents.middleware import wrap_tool_call
|
||||
|
||||
from src.server.deep_agent.agents.init_llm import llm
|
||||
from src.server.deep_agent.init_llm import llm
|
||||
from src.server.deep_agent.init_prompt import build_painter_prompt
|
||||
from src.server.deep_agent.tools.generate_furniture_sketch import generate_furniture
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from src.server.deep_agent.agents.init_llm import llm
|
||||
from src.server.deep_agent.init_llm import llm
|
||||
from src.server.deep_agent.init_prompt import build_researcher_prompt
|
||||
from src.server.deep_agent.tools.crawl_tool import crawl4ai_batch
|
||||
from src.server.deep_agent.tools.report_generator_tool import report_generator
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from src.server.deep_agent.agents.init_llm import llm
|
||||
from src.server.deep_agent.init_llm import llm
|
||||
from src.server.deep_agent.init_prompt import build_user_persona_prompt
|
||||
from src.server.deep_agent.tools.user_persona_tool import query_report_profile, update_report_profile, check_profile_complete
|
||||
|
||||
|
||||
51
src/server/deep_agent/init_llm.py
Normal file
51
src/server/deep_agent/init_llm.py
Normal file
@@ -0,0 +1,51 @@
|
||||
from langchain_qwq import ChatQwen
|
||||
|
||||
from src.core.config import settings
|
||||
|
||||
def _make_qwen(model: str, **overrides) -> ChatQwen:
    """Build a ChatQwen client with the settings every client in this module uses.

    Args:
        model: Qwen model identifier passed to ``ChatQwen``.
        **overrides: Extra keyword arguments forwarded unchanged to ``ChatQwen``
            (e.g. ``temperature``, ``streaming``, ``enable_thinking``).

    Returns:
        A ``ChatQwen`` instance with a 3000-token output cap, no timeout,
        two retries and the API key taken from ``settings.QWEN_API_KEY``.
    """
    return ChatQwen(
        model=model,
        max_tokens=3_000,
        timeout=None,
        max_retries=2,
        api_key=settings.QWEN_API_KEY,
        **overrides,
    )


# Flash model with thinking mode disabled.
llm = _make_qwen("qwen3.5-flash", enable_thinking=False)

# qwen-plus, non-streaming, low temperature / top_p.
title_llm = _make_qwen(
    "qwen-plus",
    streaming=False,
    temperature=0.1,
    top_p=0.8,
)

# Flash model at temperature 0.2; thinking mode is left at the library default.
main_llm = _make_qwen("qwen3.5-flash", temperature=0.2)

# Same configuration as ``title_llm``, kept as a separate instance.
suggested_llm = _make_qwen(
    "qwen-plus",
    streaming=False,
    temperature=0.1,
    top_p=0.8,
)

# NOTE(review): "repoer" looks like a typo for "report"; the name is kept
# as-is here because other modules import it under this spelling.
repoer_llm = _make_qwen(
    "qwen3.5-flash",
    enable_thinking=False,
    temperature=0.2,
)
|
||||
@@ -8,7 +8,7 @@ agent = build_main_agent(use_report=True)
|
||||
|
||||
|
||||
async def continuous_chat():
|
||||
thread_id = str(uuid.uuid4())
|
||||
thread_id = "c8e327fb-e208-4fab-83fd-b7b9c4d5fdd0"
|
||||
print("===== 家具设计助手(支持持续对话+记忆)=====")
|
||||
print("输入 'exit' 或 '退出' 结束对话\n")
|
||||
|
||||
@@ -25,13 +25,38 @@ async def continuous_chat():
|
||||
|
||||
print("\n助手:正在处理你的需求...\n")
|
||||
|
||||
current_config = {
|
||||
"recursion_limit": 120,
|
||||
"configurable": {
|
||||
"thread_id": thread_id,
|
||||
}
|
||||
}
|
||||
source_config = {
|
||||
"configurable": {
|
||||
"thread_id": thread_id,
|
||||
"checkpoint_id": '1f11dc17-be49-65a1-8000-96139f7c89cb'
|
||||
}
|
||||
}
|
||||
initial_messages = []
|
||||
older_state = await agent.aget_state(source_config)
|
||||
combined_values = older_state.values.copy()
|
||||
if initial_messages:
|
||||
combined_values["messages"] = list(combined_values.get("messages", [])) + initial_messages
|
||||
await agent.aupdate_state(current_config, combined_values)
|
||||
|
||||
# 现在可以安全使用 async for
|
||||
async for stream in agent.astream(
|
||||
{"messages": user_input},
|
||||
stream_mode=["updates", "messages", "custom"],
|
||||
subgraphs=True,
|
||||
version="v2",
|
||||
config={"configurable": {"thread_id": thread_id}}
|
||||
config={
|
||||
"configurable": {
|
||||
"thread_id": thread_id,
|
||||
'checkpoint_id': '1f11dc17-be49-65a1-8000-96139f7c89cb'
|
||||
}
|
||||
|
||||
}
|
||||
):
|
||||
|
||||
print(stream)
|
||||
@@ -61,7 +86,7 @@ async def continuous_chat():
|
||||
|
||||
elif mode == "custom":
|
||||
print(f"[report] {chunks.get('delta', '')}", end="")
|
||||
|
||||
print("end")
|
||||
# if chunk["type"] == "messages":
|
||||
# token, metadata = chunk["data"]
|
||||
# if not isinstance(token, AIMessageChunk):
|
||||
|
||||
0
src/server/deep_agent/tools/__init__.py
Normal file
0
src/server/deep_agent/tools/__init__.py
Normal file
@@ -1,6 +1,6 @@
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
|
||||
from src.server.deep_agent.agents.init_llm import title_llm
|
||||
from src.server.deep_agent.init_llm import title_llm
|
||||
|
||||
|
||||
def conversation_title(full_conversation):
|
||||
|
||||
75
src/server/deep_agent/tools/extract_suggested_questions.py
Normal file
75
src/server/deep_agent/tools/extract_suggested_questions.py
Normal file
@@ -0,0 +1,75 @@
|
||||
import json
|
||||
from typing import List
|
||||
|
||||
from langchain_core.messages import (
|
||||
HumanMessage,
|
||||
AIMessage,
|
||||
ToolMessage,
|
||||
)
|
||||
|
||||
from src.server.deep_agent.init_llm import suggested_llm
|
||||
|
||||
|
||||
def format_messages(messages, max_messages: int = 6) -> str:
    """Render the most recent LangGraph messages as plain prompt text.

    Only the last ``max_messages`` entries are considered. Each one becomes a
    single line, prefixed according to its type:

    * ``HumanMessage`` -> ``"User: ..."``
    * ``AIMessage``    -> ``"Assistant: ..."`` (skipped when the content is empty)
    * ``ToolMessage``  -> ``"Tool Result: ..."`` (content cut to 200 characters)

    Messages of any other type are ignored.

    Args:
        messages: Sequence of message objects, oldest first.
        max_messages: How many trailing messages to include. Zero or a
            negative value yields an empty string.

    Returns:
        The formatted lines joined with newlines, or ``""`` if nothing matched.
    """
    # ``messages[-0:]`` is the *whole* list and a negative count would drop the
    # head instead of keeping the tail, so handle non-positive limits up front.
    if max_messages <= 0:
        return ""

    lines: List[str] = []
    for m in messages[-max_messages:]:
        if isinstance(m, HumanMessage):
            lines.append(f"User: {m.content}")
        elif isinstance(m, AIMessage):
            # AI messages may have empty content; those add nothing to the prompt.
            if m.content:
                lines.append(f"Assistant: {m.content}")
        elif isinstance(m, ToolMessage):
            # Keep tool output short so it does not dominate the prompt.
            tool_output = str(m.content)
            if len(tool_output) > 200:
                tool_output = tool_output[:200] + "..."
            lines.append(f"Tool Result: {tool_output}")
    return "\n".join(lines)
|
||||
|
||||
|
||||
def _parse_question_list(text: str, limit: int = 3) -> List[str]:
    """Extract up to ``limit`` question strings from a model reply expected to hold a JSON array."""
    # Models often wrap the array in a Markdown fence or add surrounding text,
    # so parse only the outermost ``[...]`` span instead of the whole reply.
    start = text.find("[")
    end = text.rfind("]")
    if start == -1 or end <= start:
        return []
    try:
        parsed = json.loads(text[start:end + 1])
    except json.JSONDecodeError:
        return []
    if not isinstance(parsed, list):
        return []
    # Keep only non-empty strings so the declared List[str] return type holds.
    questions = [q.strip() for q in parsed if isinstance(q, str) and q.strip()]
    return questions[:limit]


async def generate_suggested_questions(
    agent,
    thread_id: str,
    max_messages: int = 6,
) -> List[str]:
    """Generate up to three follow-up questions the user might ask next.

    Args:
        agent: Graph/agent object exposing ``aget_state(config)``.
        thread_id: Conversation thread whose state is read.
        max_messages: Number of trailing messages passed to the prompt.

    Returns:
        At most three question strings. An empty list is returned when the
        thread has no messages or the model reply cannot be parsed.
    """
    # Read the current conversation state. This is a coroutine, so use the
    # async accessor rather than the blocking ``get_state``.
    state = await agent.aget_state(
        {"configurable": {"thread_id": thread_id}}
    )
    messages = state.values.get("messages", [])
    if not messages:
        return []
    conversation = format_messages(messages, max_messages)

    prompt = f"""
以下是用户与AI助手的对话:
{conversation}
请根据对话内容,生成3条用户可能继续提出的问题。
要求:
- 每条一句话
- 语言自然
- 不要解释
- 返回JSON数组
- 尽量与家具设计相关
示例:
["问题1", "问题2", "问题3"]
"""
    result = await suggested_llm.ainvoke(prompt)

    raw = result.content
    if not isinstance(raw, str):
        # Non-text content (e.g. a list of content blocks) cannot be parsed here.
        return []
    return _parse_question_list(raw.strip())
|
||||
@@ -2,27 +2,11 @@ import os
|
||||
import json
|
||||
import re
|
||||
from typing import Optional, List, Dict
|
||||
from langchain_qwq import ChatQwen
|
||||
from langgraph.config import get_stream_writer
|
||||
from pydantic import BaseModel, Field
|
||||
from langchain_core.tools import tool
|
||||
from langchain_core.messages import SystemMessage, HumanMessage
|
||||
|
||||
from src.core.config import settings
|
||||
|
||||
# =========================
|
||||
# LLM 初始化
|
||||
# =========================
|
||||
|
||||
|
||||
llm = ChatQwen(
|
||||
enable_thinking=False,
|
||||
model="qwen3.5-flash",
|
||||
temperature=0.2,
|
||||
max_tokens=3_000,
|
||||
timeout=None,
|
||||
max_retries=2,
|
||||
api_key=settings.QWEN_API_KEY)
|
||||
from src.server.deep_agent.init_llm import repoer_llm
|
||||
|
||||
|
||||
# =========================
|
||||
@@ -109,7 +93,7 @@ async def report_generator(
|
||||
|
||||
full_report = ""
|
||||
try:
|
||||
report_llm = llm.with_config(
|
||||
report_llm = repoer_llm.with_config(
|
||||
callbacks=[]
|
||||
)
|
||||
async for chunk in report_llm.astream(
|
||||
|
||||
Reference in New Issue
Block a user