feat 接入report

This commit is contained in:
zcr
2026-03-03 17:33:51 +08:00
parent 1ecb02d706
commit 1ade907828
23 changed files with 4079 additions and 516 deletions

1
.gitignore vendored
View File

@@ -147,3 +147,4 @@ app/logs/*
*.json *.json
*.env* *.env*
config.backup.py config.backup.py
*.md

View File

@@ -16,20 +16,41 @@ agents:
designer: designer:
prompt_template: | prompt_template: |
你是一位资深的家具设计师。你的职责是: 你是一位资深的家具设计师,经验丰富、审美一流、沟通温暖且高效。
1. 从用户的模糊描述中提取或补充具体的设计参数(尺寸、材质、人体工学数据)。
2. 如果用户想画图,不要直接画,而是先描述清楚细节,然后让 Visualizer 去画 你的核心目标:快速理解用户想法,并用最合适的方式推进设计
请以专业的口吻回复。
你可以:
1. 用户描述模糊时,可以自然地询问或给出建议,但**绝不强迫补充**尺寸、材质、人体工学等细节,除非用户自己关心或需要明确这些参数。
2. 如果用户提到想看图、想出效果图、想画草图、想渲染等,**直接同意并推动**
- 用一句话确认或赞美用户的想法
- 主动说“我这就帮你把当前设计转给视觉专家生成效果图/草图”
- 然后让 Visualizer 节点去处理(不需要你先写一大段细节描述)
3. 回复时像和懂设计的客户聊天一样:专业、亲切、有创意,偶尔带点热情或幽默,但始终围绕家具设计。
永远不要用“必须”“请先描述清楚”“按照流程”等强硬的流程化语言。
visualizer: visualizer:
prompt_template: | prompt_template: |
你是视觉专家。你的目标是生成高质量的家具草图。 你是专业的家具工业设计视觉专家,擅长将文字描述转化为高质量、清晰、专业的家具设计草图。
步骤:
1. 根据上下文,编写一个详细的 Stable Diffusion 风格的英文 Prompt。
2. 必须调用 generate_furniture_sketch 工具来生成图片。
注意:如果对话中出现 [SYSTEM_DIRECTIVE] 要求直接绘图,请立即根据已知信息编写 Prompt 并调用 generate_furniture_sketch 工具,不要进行多余的询问。 你的唯一任务:
- 基于当前全部对话上下文(包括用户描述、已有设计要点、风格要求、材质、尺寸、功能等),在内部生成一个详细的英文 Stable Diffusion Prompt。
- 然后**立即调用 generate_furniture 工具**来生成图片。
- **绝对不要**把生成的 Prompt 文本、任何代码块、任何解释、任何思考过程输出给用户。
- 只通过工具调用返回结果,工具执行完成后自然结束。
researcher: Prompt 内部生成要求(仅供你自己参考,不要输出):
prompt_template: | 1. 语言:全程英文
你是情报专家,负责检索与整理参考资料并生成报告。 2. 结构:主体描述 + 风格 + 视角 + 细节 + 材质 + 照明 + 质量修饰词
3. 必须包含高质量关键词highly detailed, sharp focus, professional industrial design sketch, clean line art 或 photorealistic product render
4. 背景pure white background, studio lighting, no people, no text, no watermark
5. 避免blurry, deformed, low quality, cartoon, extra limbs, bad anatomy
6. 如果上下文有明确风格、视角,必须强烈体现
7. 长度80-150 个词左右
现在开始工作:
- 直接分析上下文
- 内部构建 Prompt
- 立即调用 generate_furniture 工具
- 不要输出任何文字

56
logging_env.py Normal file
View File

@@ -0,0 +1,56 @@
import os
from src.core.config import settings
LOGGER_CONFIG_DICT = {
'version': 1,
'disable_existing_loggers': False,
'formatters': {
'simple': {
'format': '%(asctime)s %(filename)s [line:%(lineno)d] %(levelname)s %(message)s',
'datefmt': '%Y-%m-%d %H:%M:%S' # 补充日期格式,日志更易读
}
},
'handlers': {
'console': {
'class': 'logging.StreamHandler',
'level': 'INFO',
'formatter': 'simple',
'stream': 'ext://sys.stdout',
},
'info_file_handler': {
'class': 'logging.handlers.RotatingFileHandler',
'level': 'INFO',
'formatter': 'simple',
'filename': os.path.join(settings.LOGS_PATH, 'info.log'),
'maxBytes': 10485760,
'backupCount': 50,
'encoding': 'utf8',
},
'error_file_handler': {
'class': 'logging.handlers.RotatingFileHandler',
'level': 'ERROR',
'formatter': 'simple',
'filename': os.path.join(settings.LOGS_PATH, 'error.log'),
'maxBytes': 10485760,
'backupCount': 20,
'encoding': 'utf8',
},
'debug_file_handler': {
'class': 'logging.handlers.RotatingFileHandler',
'level': 'DEBUG',
'formatter': 'simple',
'filename': os.path.join(settings.LOGS_PATH, 'debug.log'),
'maxBytes': 10485760,
'backupCount': 50,
'encoding': 'utf8',
},
},
'loggers': {
'my_module': {'level': 'INFO', 'handlers': ['console'], 'propagate': 'no'}
},
'root': {
'level': 'DEBUG',
'handlers': ['error_file_handler', 'info_file_handler', 'debug_file_handler', 'console'],
},
}

View File

@@ -1,8 +1,14 @@
import logging
import uvicorn import uvicorn
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from logging_env import LOGGER_CONFIG_DICT
from src.routers import chat from src.routers import chat
logging.config.dictConfig(LOGGER_CONFIG_DICT)
app_server = FastAPI( app_server = FastAPI(
title="Gemini Furniture Designer API", title="Gemini Furniture Designer API",
description="基于 LangGraph + Gemini 2.0 Flash 的家具设计 Agent 接口", description="基于 LangGraph + Gemini 2.0 Flash 的家具设计 Agent 接口",

View File

@@ -4,20 +4,36 @@ version = "0.1.0"
description = "Add your description here" description = "Add your description here"
requires-python = ">=3.12" requires-python = ">=3.12"
dependencies = [ dependencies = [
"crawl4ai>=0.8.0",
"deepagents>=0.4.3",
"fastapi[standard]>=0.128.0", "fastapi[standard]>=0.128.0",
"gunicorn>=25.0.1", "gunicorn>=25.0.1",
"image>=1.5.33", "image>=1.5.33",
"langchain-community>=0.4.1",
"langchain-core>=1.2.8", "langchain-core>=1.2.8",
"langchain-google-genai>=4.2.0", "langchain-google-genai>=4.2.0",
"langgraph>=1.0.7", "langgraph[postgres]>=1.0.7",
"langgraph-checkpoint-mongodb>=0.3.1", "langgraph-checkpoint-mongodb>=0.3.1",
"minio>=7.2.20", "minio>=7.2.20",
"modality>=0.1.0", "modality>=0.1.0",
"motor>=3.7.1", "motor>=3.7.1",
"playwright>=1.58.0",
"pydantic>=2.12.5", "pydantic>=2.12.5",
"pydantic-settings>=2.12.0", "pydantic-settings>=2.12.0",
"pymongo[srv]>=4.15.5", "pymongo[srv]>=4.15.5",
"python-dotenv>=1.2.1", "python-dotenv>=1.2.1",
"tavily-python>=0.7.21",
"uuid>=1.30", "uuid>=1.30",
"uvicorn>=0.40.0", "uvicorn>=0.40.0",
"psycopg[binary]>=3.3.3",
"postgres>=4.0",
"langchain-huggingface>=1.2.0",
"sentence-transformers>=5.2.3",
"rank-bm25>=0.2.2",
"torch>=2.10.0",
"faiss-cpu>=1.13.2",
"terminate>=0.0.9",
"report-generator>=0.1.10",
"dashscope>=1.25.13",
"prompt>=0.4.1",
] ]

View File

@@ -17,6 +17,9 @@ class Settings(BaseSettings):
GOOGLE_CLOUD_PROJECT: str = Field(default="", description="") GOOGLE_CLOUD_PROJECT: str = Field(default="", description="")
GOOGLE_CLOUD_LOCATION: str = Field(default="", description="") GOOGLE_CLOUD_LOCATION: str = Field(default="", description="")
# --- google api 配置信息 ---
QWEN_API_KEY: str = Field(default="", description="")
# --- minio 配置信息 --- # --- minio 配置信息 ---
MINIO_URL: str = Field(default='', description="") MINIO_URL: str = Field(default='', description="")
MINIO_ACCESS: str = Field(default='', description="") MINIO_ACCESS: str = Field(default='', description="")
@@ -29,6 +32,11 @@ class Settings(BaseSettings):
MONGODB_HOST: str = Field(default="localhost", description="") MONGODB_HOST: str = Field(default="localhost", description="")
MONGODB_PORT: int = Field(default=27017, description="") MONGODB_PORT: int = Field(default=27017, description="")
# --- 外部工具api配置信息 ---
TAVILY_API_KEY: str = Field(default="", description="")
LOGS_PATH: str = Field(default="/mnt/data/FiDA/logs", description="")
settings = Settings() settings = Settings()
MONGO_URI = f"mongodb://{settings.MONGODB_USERNAME}:{settings.MONGODB_PASSWORD}@{settings.MONGODB_HOST}:{settings.MONGODB_PORT}" MONGO_URI = f"mongodb://{settings.MONGODB_USERNAME}:{settings.MONGODB_PASSWORD}@{settings.MONGODB_HOST}:{settings.MONGODB_PORT}"

View File

@@ -1,5 +1,8 @@
import logging
import uuid import uuid
import json import json
from typing import AsyncGenerator
from fastapi import APIRouter from fastapi import APIRouter
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from src.schemas.chat import ChatRequest, HistoryResponse, HistoryItem from src.schemas.chat import ChatRequest, HistoryResponse, HistoryItem
@@ -7,6 +10,7 @@ from src.server.agent.graph import app # 导入已经 compile 好的 graph
from langchain_core.messages import HumanMessage, SystemMessage from langchain_core.messages import HumanMessage, SystemMessage
router = APIRouter(prefix="/chat", tags=["Furniture Design Chat"]) router = APIRouter(prefix="/chat", tags=["Furniture Design Chat"])
logger = logging.getLogger(__name__)
@router.post("/stream") @router.post("/stream")
@@ -52,29 +56,28 @@ async def chat_stream(request: ChatRequest):
} }
``` ```
""" """
logger.debug(f"chat request data: {request}")
source_thread_id = request.thread_id source_thread_id = request.thread_id
checkpoint_id = request.checkpoint_id checkpoint_id = request.checkpoint_id
# 1. 确定目标 thread_id # 1. 确定目标 thread_id
# 如果是回溯操作,我们生成一个新的 ID或者由前端传入一个新的 target_thread_id
is_branching = source_thread_id and checkpoint_id is_branching = source_thread_id and checkpoint_id
target_thread_id = str(uuid.uuid4())[:8] if is_branching else (source_thread_id or str(uuid.uuid4())[:8]) target_thread_id = str(uuid.uuid4())[:8] if is_branching else (source_thread_id or str(uuid.uuid4())[:8])
# 2. 获取配置参数
temp = request.config_params.temperature if request.config_params else 0.7
# 构建基础 Config # 2. 配置参数
temp = request.config_params.temperature if request.config_params else 0.7
current_config = { current_config = {
"recursion_limit": 100,
"configurable": { "configurable": {
"thread_id": target_thread_id, "thread_id": target_thread_id,
"llm_temperature": temp "llm_temperature": temp,
"use_report": request.use_report,
} }
} }
# 3. 处理状态初始化与分支
initial_messages = []
# 如果是全新的对话(没有 source_thread_id或者明确要求分叉 # 3. 初始化消息 + 系统提示
initial_messages = []
if not source_thread_id or is_branching: if not source_thread_id or is_branching:
# 如果用户传了标签,构造 SystemMessage 注入上下文
if request.config_params: if request.config_params:
cp = request.config_params cp = request.config_params
system_prompt = ( system_prompt = (
@@ -86,7 +89,7 @@ async def chat_stream(request: ChatRequest):
) )
initial_messages.append(SystemMessage(content=system_prompt)) initial_messages.append(SystemMessage(content=system_prompt))
# 4. 执行分叉逻辑(搬运旧数据 # 4. 处理分支(从历史 checkpoint 复制状态
if is_branching: if is_branching:
source_config = { source_config = {
"configurable": { "configurable": {
@@ -95,80 +98,149 @@ async def chat_stream(request: ChatRequest):
} }
} }
older_state = await app.aget_state(source_config) older_state = await app.aget_state(source_config)
# 将旧消息和我们新定义的 SystemMessage 合并
# update_state 会将这些消息推送到新 thread 的存储中
combined_values = older_state.values.copy() combined_values = older_state.values.copy()
if initial_messages: if initial_messages:
combined_values["messages"] = list(combined_values["messages"]) + initial_messages combined_values["messages"] = list(combined_values.get("messages", [])) + initial_messages
await app.aupdate_state(current_config, combined_values) await app.aupdate_state(current_config, combined_values)
async def event_generator(): async def event_generator() -> AsyncGenerator[str, None]:
# 初始推送状态信息 # 初始事件
yield f"data: {json.dumps({'thread_id': target_thread_id, 'is_branch': is_branching, 'status': 'start'}, ensure_ascii=False)}\n\n" yield f"data: {json.dumps({'thread_id': target_thread_id, 'is_branch': is_branching, 'status': 'start'}, ensure_ascii=False)}\n\n"
# 构造本次请求的输入 # 构造输入
# 如果是第一次开始,且有 initial_messages则连同 user message 一起发送 new_messages = initial_messages[:] if not source_thread_id else []
# --- 核心逻辑:构造本次请求的消息列表 ---
new_messages = []
if not source_thread_id and initial_messages:
new_messages.extend(initial_messages)
# 添加用户消息
new_messages.append(HumanMessage(content=request.message)) new_messages.append(HumanMessage(content=request.message))
# --- 新增:强制绘图指令注入 ---
# if request.force_sketch:
# force_instruction = HumanMessage(
# content="[SYSTEM_DIRECTIVE]: 用户点击了强制生成按钮。请立即根据当前上下文调用 generate_furniture_sketch 工具生成草图,无需确认。"
# )
# new_messages.append(force_instruction)
input_data = { input_data = {
"messages": new_messages, "messages": new_messages,
"require_suggestion": request.need_suggestion # 初始由前端决定 "require_suggestion": request.need_suggestion,
"use_report": request.use_report,
} }
async for event in app.astream( # 使用 astream_events v2 + stream_subgraphs=True 来捕获 DeepAgents 内部流式事件
async for event in app.astream_events(
input_data, input_data,
current_config, version="v2",
stream_mode="updates" config=current_config,
stream_subgraphs=True,
): ):
for node_name, output in event.items(): event_kind = event["event"]
if "messages" in output:
# 获取最新 state 以获取 checkpoint_id
state = await app.aget_state(current_config)
current_cp_id = state.config["configurable"].get("checkpoint_id")
# 遍历本次 update 产生的所有消息 # 获取当前 checkpoint_id安全方式避免 KeyError
for msg in output["messages"]: latest_state = await app.aget_state(current_config)
configurable = latest_state.config.get("configurable", {})
current_cp_id = configurable.get("checkpoint_id", "") # 如果没有,返回空字符串
# ────────────────────────────────────────────────
# 1. LLM token 流式输出(主图或子图的逐 token
# ────────────────────────────────────────────────
if event_kind == "on_chat_model_stream":
chunk = event["data"].get("chunk")
if chunk and chunk.content:
node_name = event.get("name", "Unknown")
# 判断是否来自 Researcher 子图
namespace = event.get("parent_ids", []) or event.get("namespace", [])
if any("Researcher" in str(ns) for ns in namespace):
node_name = "Researcher"
payload = {
"node": node_name,
"content": chunk.content,
"is_delta": True,
"checkpoint_id": current_cp_id,
"image_url": None,
"suggestions": []
}
yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
# ────────────────────────────────────────────────
# 2. 自定义事件report_delta 等)
# ────────────────────────────────────────────────
elif event_kind == "on_custom_event":
custom_data = event["data"]
if isinstance(custom_data, dict):
if custom_data.get("type") == "report_delta":
payload = { payload = {
"node": node_name, "node": "Researcher",
"content": "", "content": custom_data.get("delta", ""),
"is_delta": True,
"checkpoint_id": current_cp_id,
"image_url": None, "image_url": None,
"suggestions": []
}
yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
# 可选:报告开始/完成/错误等状态提示
elif custom_data.get("type") in ("report_start", "report_complete", "report_error"):
status_msg = {
"report_start": "Start generating reports...",
"report_complete": "Report generation completed",
"report_error": f"Report generation failed: {custom_data.get('message', '')}"
}.get(custom_data["type"], "")
payload = {
"node": "Researcher",
"content": status_msg,
"is_delta": False,
"checkpoint_id": current_cp_id,
"image_url": None,
"suggestions": []
}
yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
# ────────────────────────────────────────────────
# 3. 节点启动 / 工具启动(进度提示)
# ────────────────────────────────────────────────
elif event_kind in {"on_tool_start", "on_tool_end"}:
tool_name = event.get("name", "unknown_tool")
tool_data = event.get("data", {})
tool_input = tool_data.get("input", "")
tool_output = tool_data.get("output", "")
if event_kind == "on_tool_start":
payload = {
"node": tool_name,
"content": tool_input,
"is_delta": False,
"checkpoint_id": current_cp_id,
"image_url": None,
"suggestions": []
}
yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
else:
if tool_name == "generate_furniture" and isinstance(tool_output, str):
payload = {
"node": tool_name,
"content": "Design sketch has been generated for you.", # 给用户友好的文字提示
"image_url": tool_output, # 直接传 URL 给前端显示
"is_delta": False, # 这是一个完整事件,不是增量
"checkpoint_id": current_cp_id, "checkpoint_id": current_cp_id,
"suggestions": [] "suggestions": []
} }
# --- 核心改动:提取建议按钮 ---
# 无论是不是 Suggester 节点,只要消息里带了建议就提取
if hasattr(msg, "additional_kwargs") and "suggestions" in msg.additional_kwargs:
payload["suggestions"] = msg.additional_kwargs["suggestions"]
content = msg.content
# 逻辑判断MinIO 图片处理
if node_name == "Visualizer" and str(content).endswith("png") and "furniture/sketches" in str(content):
payload["image_url"] = content
payload["content"] = "已为您生成设计草图"
else:
payload["content"] = content
# 如果消息既没有文本、也没有图片、也没有建议(比如中间的 ToolCall 消息),则跳过
if not payload["content"] and not payload["image_url"] and not payload["suggestions"]:
continue
yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n" yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
elif tool_name == "topic_research":
payload = {
"node": tool_name,
"content": "Visiting...", # 给用户友好的文字提示
"image_url": None, # 直接传 URL 给前端显示
"search_list": tool_output.content,
"is_delta": False, # 这是一个完整事件,不是增量
"checkpoint_id": current_cp_id,
"suggestions": []
}
yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
else:
# 可选其他工具的通用处理debug 或显示结果)
if tool_output:
payload = {
"node": tool_name,
"content": f"tool {tool_name} Execution completed{str(tool_output)[:200]}...", # 截断避免过长
"is_delta": False,
"checkpoint_id": current_cp_id,
"image_url": None,
"suggestions": []
}
yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
# 流结束
yield f"data: {json.dumps({'status': 'end'}, ensure_ascii=False)}\n\n" yield f"data: {json.dumps({'status': 'end'}, ensure_ascii=False)}\n\n"
return StreamingResponse(event_generator(), media_type="text/event-stream") return StreamingResponse(event_generator(), media_type="text/event-stream")
@@ -218,7 +290,7 @@ async def get_chat_history(thread_id: str):
} }
``` ```
""" """
config = {"configurable": {"thread_id": thread_id}} config = {"configurable": {"thread_id": thread_id}, }
history_data = [] history_data = []
async for state in app.aget_state_history(config): async for state in app.aget_state_history(config):
msg_content = "Initial" msg_content = "Initial"

View File

@@ -15,6 +15,7 @@ class ChatRequest(BaseModel):
checkpoint_id: Optional[str] = Field(None, description="回溯点的ID用于从历史点开启新对话") checkpoint_id: Optional[str] = Field(None, description="回溯点的ID用于从历史点开启新对话")
config_params: Optional[AgentConfig] = None config_params: Optional[AgentConfig] = None
need_suggestion: bool = False need_suggestion: bool = False
use_report: bool = False # ← 新增:是否使用深度报告
class HistoryItem(BaseModel): class HistoryItem(BaseModel):

View File

@@ -1,36 +1,59 @@
import os from pathlib import Path
from typing import AsyncGenerator, Dict, Any
from google.oauth2 import service_account from deepagents import create_deep_agent
from deepagents.backends import FilesystemBackend
from langchain_core.messages import HumanMessage, SystemMessage, ToolMessage, AIMessage from langchain_core.messages import HumanMessage, SystemMessage, ToolMessage, AIMessage
from langchain_core.runnables import RunnableConfig from langchain_core.runnables import RunnableConfig
from langchain_google_genai import ChatGoogleGenerativeAI from langchain_qwq import ChatQwen
from src.server.agent.state import AgentState
from src.server.agent.tools import generate_2025_report_tool, generate_furniture_sketch
from src.server.agent.config_loader import get_agent_prompt
from src.core.config import settings
from src.server.utils.generate_suggestion import generate_chat_suggestions
creds = service_account.Credentials.from_service_account_file( from src.core.config import settings
settings.GOOGLE_GENAI_USE_VERTEXAI, from src.server.agent.prompt import SYSTEM_PROMPT
scopes=["https://www.googleapis.com/auth/cloud-platform"], from src.server.agent.state import AgentState
from src.server.agent.tools.generate_furniture_sketch import generate_furniture
from src.server.agent.config_loader import get_agent_prompt
from src.server.agent.tools.crawl_tool import crawl4ai_batch
from src.server.agent.tools.research_tool import topic_research
from src.server.agent.tools.structured_retrieval_tool import structured_retrieval
from src.server.agent.tools.terminate_tool import terminate
from src.server.agent.tools.user_persona_tool import manage_user_persona
from src.server.utils.generate_suggestion import generate_chat_suggestions
from test.report.tools.report_generator_tool import report_generator
# 目前這個主程式檔案所在的目錄
MAIN_DIR = Path(__file__).resolve().parent
# 專案根目錄(因為 main.py 跟 tools/ 同級,所以 parent 就是根)
PROJECT_ROOT = MAIN_DIR
model = ChatQwen(
model="qwen3.5-flash",
max_tokens=3_000,
timeout=None,
max_retries=2,
api_key=settings.QWEN_API_KEY)
tools = [manage_user_persona, topic_research, crawl4ai_batch, structured_retrieval, report_generator, terminate]
research_agent = create_deep_agent(
model=model,
tools=tools,
system_prompt=SYSTEM_PROMPT,
backend=FilesystemBackend(
root_dir=str(PROJECT_ROOT / "agent_workspace"),
virtual_mode=False, # 重要:關掉虛擬模式 → 真的寫硬碟
)
) )
# 辅助函数:根据配置动态获取 LLM # 辅助函数:根据配置动态获取 LLM
def get_model(config: RunnableConfig): def get_model(config: RunnableConfig):
# 从 configurable 中获取温度,默认为 0.5 (对应你之前的设置)
# 这个 key 必须与你在 chat_stream 路由里定义的 "llm_temperature" 一致
temp = config["configurable"].get("llm_temperature", 0.5) temp = config["configurable"].get("llm_temperature", 0.5)
return ChatQwen(
return ChatGoogleGenerativeAI( model="qwen3.5-flash",
model="gemini-2.0-flash", max_tokens=3_000,
timeout=None,
max_retries=2,
temperature=temp, temperature=temp,
credentials=creds, api_key=settings.QWEN_API_KEY)
project=settings.GOOGLE_CLOUD_PROJECT,
location=settings.GOOGLE_CLOUD_LOCATION,
vertexai=True,
api_key=settings.GOOGLE_API_KEY
)
# --- 1. Designer Agent (设计顾问) --- # --- 1. Designer Agent (设计顾问) ---
@@ -48,33 +71,158 @@ async def designer_node(state: AgentState, config: RunnableConfig):
return {"messages": [response], "require_suggestion": should_suggest} return {"messages": [response], "require_suggestion": should_suggest}
# --- 2. Researcher Agent (情报专家) --- async def researcher_node(
async def researcher_node(state: AgentState, config: RunnableConfig): state: AgentState,
"""负责调用报告生成工具""" config: RunnableConfig
model = get_model(config) ) -> AsyncGenerator[Dict[str, Any], None]:
tools = [generate_2025_report_tool] use_report = config["configurable"].get("use_report", False)
llm_with_tools = model.bind_tools(tools) if not use_report:
yield {
"messages": [AIMessage(
content="深度报告功能未启用,请通过前端按钮触发。",
name="Researcher"
)],
"next": "Supervisor"
}
return
messages = state["messages"] messages = state["messages"]
system_text = get_agent_prompt("researcher") last_human = next((m for m in reversed(messages) if isinstance(m, HumanMessage)), None)
system_prompt = SystemMessage(content=system_text)
response = await llm_with_tools.ainvoke([system_prompt] + messages)
if response.tool_calls: if not last_human:
tool_call = response.tool_calls[0] yield {
if tool_call["name"] == "generate_2025_report_tool": "messages": [AIMessage(
# 这里的工具调用如果也是异步的,建议加 await content="深度研究节点:未找到有效的用户问题",
result = await generate_2025_report_tool.ainvoke(tool_call["args"]) name="Researcher"
return {"messages": [response, HumanMessage(content=str(result))]} )],
"next": "Supervisor"
}
return
return {"messages": [response]} full_content = ""
current_step = "正在启动深度报告生成..."
# 初始提示
yield {
"messages": [AIMessage(
content="正在启动深度报告生成...",
name="Researcher",
additional_kwargs={
"current_step": current_step,
"streaming": True
}
)]
}
async for event in research_agent.astream_events(
{"messages": messages[-12:]},
version="v2",
config=config,
stream_subgraphs=True
):
event_type = event["event"]
name = event.get("name", "未知")
if event["event"] == "on_custom_event":
custom_data = event["data"]
# 你的 writer 发的是 dict所以这里 custom_data 就是你写的 {"type": "report_delta", "delta": "..."}
if isinstance(custom_data, dict) and custom_data.get("type") == "report_delta":
delta = custom_data.get("delta", "")
print(delta, end="", flush=True) # 实时打印,不换行
# ────────────── 工具结束事件:重点处理并 yield 输出 ──────────────
if event["event"] in {"on_tool_start", "on_tool_end"}:
tool_name = event.get("name", "未知")
is_start = event["event"] == "on_tool_start"
if is_start:
tool_input = event["data"].get("input", {})
current_step = f"正在執行工具:{tool_name}"
print(f"| {current_step} | {tool_input}")
yield {
"messages": [AIMessage(
content=full_content,
name="Researcher",
additional_kwargs={
"current_step": current_step,
"tool_name": tool_name,
"tool_input": tool_input,
"tool_status": "start",
"streaming": True
}
)]
}
else: # on_tool_end
tool_output = event["data"].get("output", "")
current_step = f"工具 {tool_name} 已完成"
print(f"| {current_step} | {tool_output}")
yield {
"messages": [AIMessage(
content=full_content,
name="Researcher",
additional_kwargs={
"current_step": current_step,
"tool_name": tool_name,
"tool_output": tool_output,
"tool_status": "end",
"streaming": True
}
)]
}
# ────────────── LLM 内容生成(保持原有逻辑) ──────────────
elif event_type == "on_chat_model_stream":
chunk = event["data"]["chunk"].content or ""
if chunk:
print(chunk, end="", flush=True)
full_content += chunk
if "\n" in chunk or len(full_content) % 4 == 0:
yield {
"messages": [AIMessage(
content=full_content,
name="Researcher",
additional_kwargs={
"current_step": current_step,
"streaming": True
}
)]
}
# ────────────── 其他链路事件(可选补充) ──────────────
elif event_type in ("on_chain_start", "on_chain_end"):
status = "开始" if event_type == "on_chain_start" else "完成"
current_step = f"[{status}] {name.upper()}"
yield {
"messages": [AIMessage(
content=full_content,
name="Researcher",
additional_kwargs={
"current_step": current_step,
"streaming": True
}
)]
}
# 最终输出
yield {
"messages": [AIMessage(
content=full_content.strip() or "报告生成完成",
name="Researcher",
additional_kwargs={
"current_step": "报告已完成",
"streaming": False
}
)],
"next": "Suggester"
}
# --- 3. Visualizer Agent (视觉专家) --- # --- 3. Visualizer Agent (视觉专家) ---
async def visualizer_node(state: AgentState, config: RunnableConfig): async def visualizer_node(state: AgentState, config: RunnableConfig):
"""负责将自然语言转化为绘图 Prompt 并调用绘图工具""" """负责将自然语言转化为绘图 Prompt 并调用绘图工具"""
model = get_model(config) model = get_model(config)
tools = [generate_furniture_sketch] tools = [generate_furniture]
llm_with_tools = model.bind_tools(tools) llm_with_tools = model.bind_tools(tools)
messages = state["messages"] messages = state["messages"]
@@ -85,8 +233,8 @@ async def visualizer_node(state: AgentState, config: RunnableConfig):
if response.tool_calls: if response.tool_calls:
tool_call = response.tool_calls[0] tool_call = response.tool_calls[0]
if tool_call["name"] == "generate_furniture_sketch": if tool_call["name"] == "generate_furniture":
img_url = await generate_furniture_sketch.ainvoke(tool_call["args"]) img_url = await generate_furniture.ainvoke(tool_call["args"])
return { return {
"messages": [ "messages": [
response, response,

View File

@@ -1,14 +1,12 @@
import os
from typing import Literal from typing import Literal
from google.oauth2 import service_account
from langchain_core.messages import AIMessage from langchain_core.messages import AIMessage
from langchain_google_genai import ChatGoogleGenerativeAI from langchain_core.runnables import RunnableConfig
from langchain_qwq import ChatQwen
from langgraph.graph import StateGraph, END, START from langgraph.graph import StateGraph, END, START
from pydantic import BaseModel from pydantic import BaseModel
from pymongo import MongoClient from pymongo import MongoClient
from src.core.config import settings, MONGO_URI from src.core.config import MONGO_URI, settings
from src.server.agent.state import AgentState from src.server.agent.state import AgentState
from src.server.agent.agents import designer_node, researcher_node, visualizer_node, suggester_node from src.server.agent.agents import designer_node, researcher_node, visualizer_node, suggester_node
from langgraph.checkpoint.mongodb import MongoDBSaver from langgraph.checkpoint.mongodb import MongoDBSaver
@@ -21,18 +19,16 @@ class RouteResponse(BaseModel):
next: Literal["Designer", "Researcher", "Visualizer", "Suggester", "FINISH"] next: Literal["Designer", "Researcher", "Visualizer", "Suggester", "FINISH"]
creds = service_account.Credentials.from_service_account_file( llm_supervisor = ChatQwen(
settings.GOOGLE_GENAI_USE_VERTEXAI, model="qwen3.5-flash",
scopes=["https://www.googleapis.com/auth/cloud-platform"], max_tokens=3_000,
) timeout=None,
max_retries=2,
llm_supervisor = ChatGoogleGenerativeAI( api_key=settings.QWEN_API_KEY)
model="gemini-2.0-flash", credentials=creds,
project="aida-461108", location='us-central1', vertexai=True, api_key=settings.GOOGLE_API_KEY
)
def supervisor_node(state: AgentState): def supervisor_node(state: AgentState, config: RunnableConfig):
use_report = config["configurable"].get("use_report", False)
messages = state["messages"] messages = state["messages"]
if not messages: if not messages:
return {"next": "Suggester"} return {"next": "Suggester"}
@@ -69,7 +65,6 @@ workflow.add_node("Designer", designer_node)
workflow.add_node("Researcher", researcher_node) workflow.add_node("Researcher", researcher_node)
workflow.add_node("Visualizer", visualizer_node) workflow.add_node("Visualizer", visualizer_node)
workflow.add_node("Suggester", suggester_node) # 新增节点 workflow.add_node("Suggester", suggester_node) # 新增节点
workflow.add_edge(START, "Supervisor") workflow.add_edge(START, "Supervisor")
# 修改条件边映射 # 修改条件边映射

View File

@@ -0,0 +1,66 @@
SYSTEM_PROMPT = """
You are "TrendAgent" - a focused, efficient design trend analysis agent.
Your ONLY goal: produce one high-quality Markdown trend report per user request.
TOOL ORDER & DISCIPLINE IS MANDATORY - DO NOT INVENT STEPS
┌───────────────────────────────────────────────────────┐
│ Phase 0 - Context & Persona (必须先完成) │
└───────────────────────────────────────────────────────┘
Rules for Phase 0:
1. ALWAYS start with manage_user_persona(command="get")
2. If STATUS == "INCOMPLETE" or persona missing critical fields (Design Type, Style, Target Audience, Color Preference, etc.):
→ MUST call manage_user_persona(command="ask") to collect missing info
→ After user answers → call manage_user_persona(command="set", ...)
→ Loop until STATUS == "READY"
3. Only when STATUS == "READY" → proceed to Phase 1
4. Never assume or fabricate persona details
┌───────────────────────────────────────────────────────┐
│ Phase 1 - Planning (必须执行一次且只能一次) │
└───────────────────────────────────────────────────────┘
When persona READY and user gave a clear trend request:
1. Call write_todos EXACTLY ONCE with a strict plan containing:
- 36 concrete steps (numbered)
- Which URLs/topics to research
- Expected output of each major tool
- Final deliverable: one Markdown report
2. After receiving todos, you MUST follow this exact sequence unless impossible
3. Do NOT call any other tool until write_todos is done
┌───────────────────────────────────────────────────────┐
│ Phase 2 - Research & Collection │
└───────────────────────────────────────────────────────┘
Follow todos order:
- Use topic_research → get 38 high-quality URLs (add persona [Style] [Type] in query)
- Select best 36 URLs → call crawl4ai_batch ONCE with list
- Get file paths → call structured_retrieval ONCE with file_paths list
┌───────────────────────────────────────────────────────┐
│ Phase 3 - Synthesis & Delivery │
└───────────────────────────────────────────────────────┘
After structured_retrieval summary received:
- If extracted item count ≥ 812 AND covers main aspects in todos → ready to report
- Call report_generator ONCE (it reads local JSON/DB)
- After report_generator success → call terminate
- If data obviously insufficient → call topic_research again (max 1 extra round)
┌───────────────────────────────────────────────────────┐
│ HARD RULES - MUST OBEY │
└───────────────────────────────────────────────────────┘
• Never load full JSON/markdown into context - trust local storage
• Batch everything possible (crawl4ai_batch + structured_retrieval)
• Call tools in PHASE ORDER - no jumping, no repetition
• After report_generator → next action MUST be terminate
• If stuck > 4 steps without progress → call terminate with note "Incomplete - insufficient data"
• Never hallucinate trend data - base everything on retrieved content
• Report must start each section with **Conclusion First** insight
• Include [IMAGE_REF_xx] placeholders where visuals were extracted
Current status: Phase 0
"""

View File

@@ -1,49 +1,74 @@
from langchain_core.messages import HumanMessage, AIMessage import asyncio
from langchain_core.messages import HumanMessage, AIMessage, ToolMessage
from src.server.agent.graph import app from src.server.agent.graph import app
def main(): async def async_main():
# 模拟 thread_id 区分不同用户或项目
config = {"configurable": {"thread_id": "project_alpha"}} config = {"configurable": {"thread_id": "project_alpha"}}
print("測試模式已啟動 (輸入 'exit' 離開,'history' 查看歷史並回溯)")
use_report = input("是否启用深度报告?(y/n): ").lower() == 'y'
while True: while True:
user_input = input("\n👤 设计师 (输入 'history' 定位轮次): ") user_input = input("\n👤 輸入訊息: ").strip()
if user_input.lower() in ["exit", "quit", "結束"]:
print("測試結束")
break
# --- 官方推荐的异步回溯逻辑 ---
if user_input.lower() == "history": if user_input.lower() == "history":
print("\n--- 历史记录 ---") # 你的 history 邏輯(這裡不變)
for state in app.get_state_history(config): print("\n=== 歷史檢查點 ===")
# 每一个 state 都是一个 CheckpointTuple states = [s async for s in app.aget_state_history(config)]
cp_id = state.config["configurable"]["checkpoint_id"] for idx, state_tuple in enumerate(states):
msg = state.values["messages"][-1].content[:30] if state.values.get("messages") else "Initial" cp_id = state_tuple.config["configurable"].get("checkpoint_id", "N/A")
print(f"ID: {cp_id} | 内容: {msg}...") messages = state_tuple.values.get("messages", [])
if messages:
target_id = input("\n请输入想要回溯的 Checkpoint ID (直接回车取消): ") last_msg = messages[-1]
if target_id: msg_type = type(last_msg).__name__
# 重新配置 config指向特定的 checkpoint_id 实现分支 content_preview = str(last_msg.content)[:60].replace("\n", " ")
config = {"configurable": {"thread_id": "project_alpha", "checkpoint_id": target_id}} node = getattr(last_msg, "name", "無節點名")
print(f"✅ 已定位到节点 {target_id},后续对话将从此分叉。") print(f"[{idx}] {cp_id[:12]}... | {node} | {msg_type} | {content_preview}...")
target = input("\n輸入要回溯的 checkpoint ID (或 Enter 取消): ").strip()
if target:
config["configurable"]["checkpoint_id"] = target
print(f"已切換到 checkpoint {target}")
continue continue
# --- 官方推荐的 astream 异步流式调用 --- if not user_input:
print("🤖 Agent 思考中...") continue
for event in app.stream(
{"messages": [HumanMessage(content=user_input)]},
config,
stream_mode="values" # 这里设为 values 可以直接获取当前状态的消息列表
):
# 获取当前节点处理后的最新消息
if "messages" in event:
last_msg = event["messages"][-1]
if isinstance(last_msg, AIMessage):
# 为了极致流式体验,可以在此处对 content 进行打印
pass
# 运行结束后,最新的状态已经自动持久化到 MongoDB print("\n🤖 開始處理...")
# 我们可以通过 app.get_state(config) 验证
final_state = app.get_state(config) try:
print(f"\n✅ 最终回复: {final_state.values['messages'][-1].content}") last_output = ""
async for event in app.astream(
{"messages": [HumanMessage(content=user_input)]},
config,
stream_mode="updates"
):
for node_name, update in event.items():
if "messages" in update:
for msg in update["messages"]:
if isinstance(msg, AIMessage):
content = msg.content.strip()
if content and content != last_output:
print(f"\n[{node_name}] {msg.name or 'AI'}: {content}")
last_output = content
elif isinstance(msg, ToolMessage):
print(f" → 工具 {msg.name}: {msg.content[:120]}{'...' if len(msg.content) > 120 else ''}")
else:
print(f" ({node_name}) {type(msg).__name__}")
final_state = await app.aget_state(config)
final_msg = final_state.values["messages"][-1]
print(f"\n=== 完成 ===\n最終訊息: {final_msg.content[:300]}{'...' if len(final_msg.content) > 300 else ''}")
except Exception as e:
print(f"錯誤:{str(e)}")
import traceback
traceback.print_exc()
if __name__ == "__main__": if __name__ == "__main__":
main() asyncio.run(async_main())

View File

@@ -1,7 +1,8 @@
import operator import operator
from typing import Annotated, Sequence, TypedDict, Union from typing import Annotated, Sequence, TypedDict, Union, Optional
from langchain_core.messages import BaseMessage from langchain_core.messages import BaseMessage
class AgentState(TypedDict): class AgentState(TypedDict):
# messages 存储完整的对话历史operator.add 表示新消息是追加而不是覆盖 # messages 存储完整的对话历史operator.add 表示新消息是追加而不是覆盖
messages: Annotated[Sequence[BaseMessage], operator.add] messages: Annotated[Sequence[BaseMessage], operator.add]

View File

@@ -0,0 +1,121 @@
import time
import asyncio
from typing import List
from urllib.parse import urlparse
from pathlib import Path
from langchain_core.tools import tool
# ─────────────── 重要:計算路徑 ───────────────
# 目前這個檔案 (crawl4ai_batch.py) 所在的目錄
TOOL_DIR = Path(__file__).resolve().parent
# 專案根目錄(假設 tools 資料夾與主程式同級)
PROJECT_ROOT = TOOL_DIR.parent
# 儲存爬取結果的目錄(你可以自由決定放在哪裡)
# 建議選項 A放在專案根目錄下的 workspace/raw_data
SAVE_DIR = PROJECT_ROOT / "workspace" / "raw_data"
# 建議選項 B如果你打算讓 deep agent 直接讀取,建議放在 agent_workspace 底下
# SAVE_DIR = PROJECT_ROOT / "agent_workspace" / "raw_data"
# 確保目錄存在
SAVE_DIR.mkdir(parents=True, exist_ok=True)
# ────────────────────────────────────────────────
@tool
async def crawl4ai_batch(urls: List[str]) -> str:
"""
高性能网页爬虫,支持并行处理多个 URL。
爬取后的 Markdown 内容将保存到本地 workspace/raw_data 目录中。
返回执行结果摘要和保存的文件路径列表。
"""
if not urls:
return "❌ 错误: 未提供任何 URL。"
# print(f"🕷️ 正在并行爬取 {len(urls)} 个 URL...")
# print(f"儲存目錄: {SAVE_DIR}")
# Crawl4AI 配置(保持原樣)
from crawl4ai import AsyncWebCrawler, BrowserConfig, CrawlerRunConfig, CacheMode
browser_config = BrowserConfig(
headless=True,
verbose=False,
java_script_enabled=True,
user_agent="Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
"AppleWebKit/537.36 (KHTML, like Gecko) "
"Chrome/118.0.5993.118 Safari/537.36",
proxy=None, # 可选,如果需要代理填 "http://user:pass@ip:port"
)
run_config = CrawlerRunConfig(
cache_mode=CacheMode.BYPASS,
word_count_threshold=5,
excluded_tags=["script", "style", "nav", "footer"],
remove_overlay_elements=True,
process_iframes=True,
)
results_summary = []
saved_files = []
try:
async with AsyncWebCrawler(config=browser_config) as crawler:
tasks = [crawler.arun(url=url, config=run_config) for url in urls]
crawl_results = await asyncio.gather(*tasks, return_exceptions=True)
for i, result in enumerate(crawl_results):
url = urls[i]
if isinstance(result, Exception):
results_summary.append(f"❌ 抓取失败 {url}: {str(result)}")
continue
if result.success:
markdown_content = result.markdown or ""
if len(markdown_content) < 500:
results_summary.append(f"⏩ 跳过 {url} (内容过短)")
continue
# 生成檔名
parsed = urlparse(url)
domain = parsed.netloc.replace("www.", "").replace(".", "_")
path_part = parsed.path.strip("/").replace("/", "_")[:50] or "index"
filename = f"{int(time.time())}_{domain}_{path_part}.md"
# 完整檔案路徑
filepath = SAVE_DIR / filename
# 寫入檔案
with open(filepath, "w", encoding="utf-8") as f:
header = f"<!-- Source: {url} -->\n<!-- Saved: {time.strftime('%Y-%m-%d %H:%M:%S')} -->\n\n"
f.write(header + markdown_content)
saved_files.append(str(filepath)) # 建議轉成字串
results_summary.append(f"✅ 成功: {url}{filepath}")
else:
status = getattr(result, 'status_code', '未知错误')
results_summary.append(f"❌ 失败: {url} (状态码: {status})")
except Exception as e:
return f"🚨 爬虫系统崩溃: {str(e)}"
# 回傳給 agent 的結果
final_output = (
f"### 批量抓取完成 ###\n"
f"已成功保存 {len(saved_files)} 个文件。\n"
f"儲存目錄: {SAVE_DIR}\n"
f"详情:\n" + "\n".join(results_summary)
)
if saved_files:
final_output += "\n\n已保存的文件列表(可供後續讀取):\n" + "\n".join(saved_files)
return final_output

View File

@@ -1,11 +1,8 @@
import base64
import uuid import uuid
from google.oauth2 import service_account from google.oauth2 import service_account
from langchain_core.tools import tool from langchain_core.tools import tool
from google import genai from google import genai
from google.genai.types import GenerateContentConfig, Modality from google.genai.types import GenerateContentConfig, Modality
from PIL import Image
from io import BytesIO
from minio import Minio from minio import Minio
@@ -27,21 +24,8 @@ client = genai.Client(
) )
# --- 模拟你已经开发好的报告生成功能 ---
@tool @tool
def generate_2025_report_tool(topic: str) -> str: def generate_furniture(prompt: str) -> str:
"""
专门用于收集信息并生成报告
当用户询问关于趋势市场分析年度报告如2025家具报告时调用此工具
"""
print(f"\n[系统日志] 正在调用外部模块生成关于 '{topic}' 的报告...")
# 这里对接你实际的代码比如return my_existing_module.run(topic)
return f"【报告生成成功】已生成关于 {topic} 的 PDF 报告。核心洞察2025年趋势倾向于生物嗜好设计(Biophilic Design)和可持续软木材质。"
# --- 2. 绘图工具 (接入 Nano Banana 逻辑) ---
@tool
def generate_furniture_sketch(prompt: str) -> str:
""" """
使用 Gemini 图像生成模型根据详细的英文提示词生成家具设计草图 使用 Gemini 图像生成模型根据详细的英文提示词生成家具设计草图
""" """
@@ -83,32 +67,3 @@ def generate_furniture_sketch(prompt: str) -> str:
return "图片生成成功,但上传至存储服务器失败。" return "图片生成成功,但上传至存储服务器失败。"
except Exception as e: except Exception as e:
return f"绘图流程异常: {str(e)}" return f"绘图流程异常: {str(e)}"
if __name__ == '__main__':
print(generate_furniture_sketch("椅子"))
# creds = service_account.Credentials.from_service_account_file(
# settings.GOOGLE_GENAI_USE_VERTEXAI,
# scopes=["https://www.googleapis.com/auth/cloud-platform"],
# )
# client = genai.Client(
# credentials=creds,
# project=settings.GOOGLE_CLOUD_PROJECT,
# location=settings.GOOGLE_CLOUD_LOCATION,
# vertexai=True
# )
#
# response = client.models.generate_content(
# model="gemini-2.5-flash-image",
# contents=("Generate an image of the Eiffel tower with fireworks in the background."),
# config=GenerateContentConfig(
# response_modalities=[Modality.TEXT, Modality.IMAGE],
# ),
# )
#
# for part in response.candidates[0].content.parts:
# if part.text:
# print(part.text)
# elif part.inline_data:
# image = Image.open(BytesIO((part.inline_data.data)))
# image.save("example-image-eiffel-tower.png")

View File

@@ -0,0 +1,36 @@
import os
from langchain_core.tools import tool
@tool
def read_file(file_path: str) -> str:
"""
读取本地文件的万能工具。支持绝对路径和相对路径。
"""
# 1. 极端清洗:去掉 Agent 可能误加的引号、空格或转义符
path = file_path.strip().strip("'").strip('"').replace("\\", "/")
# 2. 打印当前环境真相(在你的 Python 控制台可见)
print(f"\n--- 🛠️ READ_FILE 调试信息 ---")
print(f"待读路径: {path}")
print(f"当前工作目录 (CWD): {os.getcwd()}")
print(f"是否存在: {os.path.exists(path)}")
# 3. 尝试直接读取(跳过任何沙箱逻辑)
try:
# 如果是相对路径,尝试转为绝对路径再读
abs_path = os.path.abspath(path)
if os.path.exists(abs_path):
with open(abs_path, 'r', encoding='utf-8') as f:
content = f.read()
return content
else:
# 如果读不到,列出父目录内容作为线索
parent = os.path.dirname(abs_path)
if os.path.exists(parent):
files = os.listdir(parent)
return f"错误:文件不存在。该目录下现有的文件有: {files[:5]}..."
return f"错误:路径不存在,且连父目录 {parent} 都找不到。"
except Exception as e:
return f"读取失败,系统异常: {str(e)}"

View File

@@ -0,0 +1,157 @@
import os
import json
import re
from typing import Optional, List, Dict
from langchain_qwq import ChatQwen
from pydantic import BaseModel, Field
from langchain_core.tools import tool
from langchain_core.messages import SystemMessage, HumanMessage
from src.core.config import settings
# =========================
# LLM 初始化
# =========================
llm = ChatQwen(
model="qwen3.5-flash",
temperature=0.2,
max_tokens=3_000,
timeout=None,
max_retries=2,
api_key=settings.QWEN_API_KEY)
# =========================
# Tool 输入 Schema
# =========================
class ReportInput(BaseModel):
report_topic: str = Field(
...,
description="Main topic of the report, e.g. '2026 Sofa Design Trends'"
)
structured_data: List[Dict] = Field(
...,
description="Structured retrieval result items"
)
language: Optional[str] = Field(
default="English",
description="Output language"
)
# =========================
# LangGraph Tool
# =========================
@tool("report_generator", args_schema=ReportInput)
async def report_generator(
report_topic: str,
structured_data: List[Dict],
language: str = "English"
) -> dict:
"""
Generate a professional design/market report
directly from structured retrieval results.
"""
if not structured_data:
return {
"status": "error",
"message": "No structured data provided."
}
collected_data_str = json.dumps(
structured_data,
ensure_ascii=False,
indent=2
)
# =========================
# Prompt
# =========================
system_prompt = f"""
You are a professional design trend analyst.
Generate a long, structured Markdown report.
REQUIREMENTS:
1. Follow MECE principle.
2. Embed images ONLY if they start with https://
using: ![alt](url)
3. Insert images inline.
4. Every key insight must cite source:
[Website Name](url)
5. Use Markdown headings.
6. Start directly with title.
7. Be detailed and analytical.
Output Language: {language}
"""
user_prompt = f"""
Topic: {report_topic}
Input Data:
{collected_data_str}
"""
# =========================
# 调用 LLM
# =========================
try:
response = await llm.ainvoke([
SystemMessage(content=system_prompt),
HumanMessage(content=user_prompt)
])
report_content = response.content.strip()
# 清理 markdown block 包裹
report_content = (
report_content
.replace("```markdown", "")
.replace("```", "")
.strip()
)
except Exception as e:
return {
"status": "error",
"message": f"LLM generation failed: {str(e)}"
}
# =========================
# 保存报告
# =========================
output_dir = "workspace/reports"
os.makedirs(output_dir, exist_ok=True)
safe_topic = re.sub(
r'[\\/*?:"<>|]',
"",
report_topic.replace(" ", "_")
)
filename = f"{output_dir}/{safe_topic}.md"
try:
with open(filename, "w", encoding="utf-8") as f:
f.write(report_content)
except Exception as e:
return {
"status": "error",
"message": f"Failed to save report: {str(e)}"
}
return {
"status": "success",
"file_path": filename,
"message": "Report generated successfully."
}

View File

@@ -0,0 +1,74 @@
import asyncio
import json
from datetime import datetime
from typing import List, Set, Optional
from langchain_core.tools import tool
from tavily import TavilyClient
from src.core.config import settings
# 模拟配置加载
TAVILY_API_KEY = settings.TAVILY_API_KEY
@tool
async def topic_research(topic: str, max_urls: int = 15) -> str:
"""
深度调研工具。该工具会利用 Tavily 搜索引擎针对特定主题进行多维度搜索。
它会自动生成针对性的搜索词(包含年份和趋势),并返回去重后的高质量 URL 列表。
"""
if not TAVILY_API_KEY:
return "❌ 错误: 未配置 TAVILY_API_KEY。"
client = TavilyClient(api_key=TAVILY_API_KEY)
# 1. 自动生成多维度搜索词 (在工具内部快速生成)
current_year = datetime.now().strftime("%Y")
queries = [
f"{topic} trends {current_year}",
f"{topic} market analysis {current_year}",
f"top selling {topic} styles {current_year}",
f"best {topic} materials and colors {current_year}"
]
# 2. 并行执行搜索
async def perform_search(q: str):
# 使用 asyncio.to_thread 运行同步的 Tavily SDK
def sync_search():
try:
response = client.search(
query=q,
search_depth="advanced",
max_results=5,
include_answer=False
)
return response.get('results', [])
except Exception as e:
print(f"Search error: {e}")
return []
return await asyncio.to_thread(sync_search)
search_tasks = [perform_search(q) for q in queries]
search_results_list = await asyncio.gather(*search_tasks)
# 3. 结果去重与过滤
seen_urls: Set[str] = set()
final_urls = []
# 常见的非内容页面过滤
skip_extensions = ('.pdf', '.jpg', '.png', '.zip', '.exe')
for results in search_results_list:
for item in results:
url = item.get('url')
if url and url not in seen_urls:
if not url.lower().endswith(skip_extensions):
seen_urls.add(url)
final_urls.append(url)
# 4. 结果截断
selected_urls = final_urls[:max_urls]
# 返回 JSON 字符串,便于 Agent 下一步调用批量爬虫 (Crawl4ai)
return json.dumps(selected_urls, ensure_ascii=False)

View File

@@ -0,0 +1,27 @@
import os
from langchain_core.tools import tool
# 定义本地保存路径
OUTPUT_DIR = "./research_reports"
if not os.path.exists(OUTPUT_DIR):
os.makedirs(OUTPUT_DIR)
@tool
def save_to_local_disk(filename: str, content: str) -> str:
"""
将内容保存到本地物理磁盘。
filename: 文件名(例如 'sofa_report.md'
content: 调研报告或数据的文本内容
"""
try:
# 移除非法路径字符,确保安全
safe_filename = os.path.basename(filename)
file_path = os.path.join(OUTPUT_DIR, safe_filename)
with open(file_path, "w", encoding="utf-8") as f:
f.write(content)
return f"✅ 成功!文件已保存至本地物理路径: {os.path.abspath(file_path)}"
except Exception as e:
return f"❌ 保存失败,错误原因: {str(e)}"

View File

@@ -0,0 +1,225 @@
import os
import re
import json
from datetime import datetime
from typing import List, Dict, Optional
from pydantic import BaseModel, Field
from langchain_core.tools import tool
from langchain_core.documents import Document
# RAG
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from sentence_transformers import CrossEncoder
# =========================
# 全局模型(单例)
# =========================
_EMBEDDING_MODEL = HuggingFaceEmbeddings(
model_name="sentence-transformers/all-MiniLM-L6-v2"
)
_RERANK_MODEL = CrossEncoder(
"cross-encoder/ms-marco-MiniLM-L-6-v2"
)
class StructuredRetrievalInput(BaseModel):
file_paths: List[str] = Field(..., description="List of local markdown file paths.")
query: str = Field(..., description="Extraction query")
source_url: Optional[str] = Field(None, description="Optional global source URL")
@tool("structured_retrieval", args_schema=StructuredRetrievalInput)
def structured_retrieval(
file_paths: List[str],
query: str,
source_url: Optional[str] = None
) -> Dict:
"""
Batch structured extraction from markdown files.
- Performs vector search + re-ranking
- Saves extracted structured data as JSON file to disk
- Returns ONLY summary (status, count, file path)
"""
# ── 1. 收集所有文件內容 ──────────────────────────────────────
all_docs_pool: List[Document] = []
for path in file_paths:
if not os.path.exists(path) or not path.endswith((".md", ".markdown")):
continue
file_name = os.path.basename(path)
with open(path, "r", encoding="utf-8") as f:
content = f.read()
current_source = source_url or _extract_source_from_md(content) or "unknown"
sections = _split_markdown_by_headers(content)
for sec in sections:
all_docs_pool.append(
Document(
page_content=sec,
metadata={"source_url": current_source, "file_name": file_name}
)
)
if not all_docs_pool:
return {"status": "no_documents_found", "items_count": 0, "json_path": None}
# ── 2. Vector search ────────────────────────────────────────────
vector_store = FAISS.from_documents(all_docs_pool, _EMBEDDING_MODEL)
retrieved = vector_store.similarity_search(query, k=200)
# ── 3. 提取結構化片段 ──────────────────────────────────────────
structured_items = []
for doc in retrieved:
text = doc.page_content.strip()
if len(text) < 30:
continue
images = list(set(re.findall(r"!\[.*?\]\((.*?)\)", text)))
structured_items.append(
{
"text": text,
"images": images,
"source_url": doc.metadata.get("source_url"),
"file_name": doc.metadata.get("file_name")
}
)
# ── 4. Re-rank ──────────────────────────────────────────────────
if structured_items:
unique_items = {item["text"]: item for item in structured_items}.values()
pairs = [[query, item["text"]] for item in unique_items]
scores = _RERANK_MODEL.predict(pairs)
sorted_items = sorted(
zip(scores, unique_items),
key=lambda x: x[0],
reverse=True
)
top_items = [item for _, item in sorted_items[:50]]
else:
top_items = []
# ── 5. 寫入 JSON 文件 ──────────────────────────────────────────
if not top_items:
return {"status": "no_relevant_content", "items_count": 0, "json_path": None}
# 產生有意義的檔名
safe_query = re.sub(r'[^a-zA-Z0-9\u4e00-\u9fa5]', '_', query)[:40]
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
json_filename = f"extracted_{safe_query}_{timestamp}.json"
# 建議的儲存目錄(與 crawl4ai_batch 對齊)
output_dir = os.path.join(os.path.dirname(file_paths[0]), "..", "extracted")
os.makedirs(output_dir, exist_ok=True)
json_path = os.path.join(output_dir, json_filename)
with open(json_path, "w", encoding="utf-8") as f:
json.dump(
{
"query": query,
"extracted_at": timestamp,
"item_count": len(top_items),
"items": top_items
},
f,
ensure_ascii=False,
indent=2
)
# ── 6. 只回傳摘要 ──────────────────────────────────────────────
return {
"status": "success",
"items_count": len(top_items),
"json_path": json_path,
"summary": f"已提取 {len(top_items)} 個高相關片段,儲存於 {json_path}"
}
def _extract_source_from_md(content: str) -> Optional[str]:
match = re.search(r"<!--\s*Source:\s*(.*?)\s*-->", content)
return match.group(1).strip() if match else None
# =========================
# Markdown Header Split
# =========================
def _split_markdown_by_headers(
content: str,
max_chars: int = 2000,
overlap: int = 150,
):
header_re = re.compile(
r'^(#{1,6})\s+(.+?)\s*$',
re.MULTILINE
)
matches = list(header_re.finditer(content))
if not matches:
return _chunk_text(content, max_chars, overlap)
sections = []
for i, m in enumerate(matches):
start = m.start()
end = (
matches[i + 1].start()
if i + 1 < len(matches)
else len(content)
)
block = content[start:end].strip()
if block:
sections.append(block)
final_sections = []
for s in sections:
if len(s) > max_chars:
final_sections.extend(
_chunk_text(s, max_chars, overlap)
)
else:
final_sections.append(s)
return final_sections
def _chunk_text(
text: str,
max_chars: int = 2000,
overlap: int = 150
):
text = text.strip()
if len(text) <= max_chars:
return [text]
chunks = []
start = 0
while start < len(text):
end = min(len(text), start + max_chars)
chunk = text[start:end].strip()
if chunk:
chunks.append(chunk)
if end == len(text):
break
start = max(0, end - overlap)
return chunks

View File

@@ -0,0 +1,38 @@
from typing import Literal
from langchain_core.tools import tool
from pydantic import BaseModel, Field
class TerminateInput(BaseModel):
"""終止對話的輸入參數"""
status: Literal["success", "failure"] = Field(
description="互動結束的狀態:'success' 表示任務完成,'failure' 表示無法繼續",
examples=["success", "failure"]
)
reason: str = Field(
default="",
description="可選:簡單說明為什麼結束(例如 '報告已生成''缺少關鍵資訊'",
examples=["報告已成功生成", "無法取得足夠資料"]
)
@tool(args_schema=TerminateInput)
def terminate(status: str, reason: str = "") -> str:
"""
當任務完成、報告已生成,或無法繼續進行時,呼叫此工具來結束本次互動。
使用時機:
- 已經成功產生最終報告report_generator 已完成)
- 遇到無法解決的錯誤或缺少關鍵資訊
- 用戶需求已完全滿足
請在呼叫前確保所有必要步驟已完成,並在 reason 中簡單說明結束原因。
"""
if status not in ("success", "failure"):
status = "failure" # 防呆
msg = f"互動已終止,狀態:{status.upper()}"
if reason:
msg += f"\n原因:{reason}"
return msg

View File

@@ -0,0 +1,96 @@
import json
import os
from typing import List, Literal, Optional, Dict, Any
from langchain_core.tools import tool
# 定义存储路径
DB_PATH = os.path.join("workspace", "user_persona.json")
def _load_store() -> Dict[str, Any]:
"""从本地文件加载画像数据"""
if os.path.exists(DB_PATH):
try:
with open(DB_PATH, "r", encoding="utf-8") as f:
return json.load(f)
except Exception:
return {}
return {}
def _save_store(data: Dict[str, Any]):
"""将画像数据保存到本地文件"""
os.makedirs(os.path.dirname(DB_PATH), exist_ok=True)
with open(DB_PATH, "w", encoding="utf-8") as f:
json.dump(data, f, ensure_ascii=False, indent=2)
@tool
def manage_user_persona(
command: Literal["set", "update", "get", "clear"],
design_type: Optional[str] = None,
style_preference: Optional[str] = None,
budget_range: Optional[str] = None,
color_palette: Optional[List[str]] = None,
target_audience: Optional[str] = None,
extra_requirements: Optional[str] = None
) -> str:
"""
用户画像与设计偏好管理工具。
用于设定、更新、获取或重置用户的设计上下文(如风格、预算、颜色)。
Agent 在开始调研前必须先调用 get 获取画像,若关键信息缺失需引导用户补充。
"""
# 每次调用都重新读取,确保多进程或重启后数据一致
store = _load_store()
if command == "clear":
if os.path.exists(DB_PATH):
os.remove(DB_PATH)
return "✅ 用户个性化模板已从本地文件清空。"
if command == "get":
if not store:
return "⚠️ [缺失信息] 当前尚未配置画像。请询问用户:设计类型(如沙发)、风格偏好(如极简)等。"
# 格式化输出供 Agent 阅读
res = [
"--- 👤 实时用户画像 (本地存储) ---",
f"🎯 类型: {store.get('design_type', '未设定')}",
f"🎨 风格: {store.get('style_preference', '未设定')}",
f"💰 预算: {store.get('budget_range', '未设定')}",
f"🌈 色系: {', '.join(store.get('color_palette', [])) or '未设定'}",
f"👥 受众: {store.get('target_audience', '未设定')}",
f"📝 需求: {store.get('extra_requirements', '未设定')}",
"-----------------------"
]
# 逻辑检查
if not store.get('design_type') or not store.get('style_preference'):
res.append("\n⚠️ 关键信息缺失,建议补充 '设计类型''风格偏好'")
return "\n".join(res)
if command in ["set", "update"]:
if command == "set":
store = {} # 重置内存中的字典
# 提取传入的非空参数
update_data = {
"design_type": design_type,
"style_preference": style_preference,
"budget_range": budget_range,
"color_palette": color_palette,
"target_audience": target_audience,
"extra_requirements": extra_requirements
}
# 更新有效字段
for k, v in update_data.items():
if v is not None:
store[k] = v
# 保存到文件
_save_store(store)
return f"✅ 本地画像已同步。当前配置:\n{json.dumps(store, ensure_ascii=False, indent=2)}"
return "❌ 错误:未知命令。"

3010
uv.lock generated

File diff suppressed because it is too large Load Diff