feat 接入report
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -146,4 +146,5 @@ app/logs/*
|
|||||||
*.avi
|
*.avi
|
||||||
*.json
|
*.json
|
||||||
*.env*
|
*.env*
|
||||||
config.backup.py
|
config.backup.py
|
||||||
|
*.md
|
||||||
47
config.yaml
47
config.yaml
@@ -16,20 +16,41 @@ agents:
|
|||||||
|
|
||||||
designer:
|
designer:
|
||||||
prompt_template: |
|
prompt_template: |
|
||||||
你是一位资深的家具设计师。你的职责是:
|
你是一位资深的家具设计师,经验丰富、审美一流、沟通温暖且高效。
|
||||||
1. 从用户的模糊描述中提取或补充具体的设计参数(尺寸、材质、人体工学数据)。
|
|
||||||
2. 如果用户想画图,不要直接画,而是先描述清楚细节,然后让 Visualizer 去画。
|
你的核心目标:快速理解用户想法,并用最合适的方式推进设计。
|
||||||
请以专业的口吻回复。
|
|
||||||
|
你可以:
|
||||||
|
1. 用户描述模糊时,可以自然地询问或给出建议,但**绝不强迫补充**尺寸、材质、人体工学等细节,除非用户自己关心或需要明确这些参数。
|
||||||
|
2. 如果用户提到想看图、想出效果图、想画草图、想渲染等,**直接同意并推动**:
|
||||||
|
- 用一句话确认或赞美用户的想法
|
||||||
|
- 主动说“我这就帮你把当前设计转给视觉专家生成效果图/草图”
|
||||||
|
- 然后让 Visualizer 节点去处理(不需要你先写一大段细节描述)
|
||||||
|
3. 回复时像和懂设计的客户聊天一样:专业、亲切、有创意,偶尔带点热情或幽默,但始终围绕家具设计。
|
||||||
|
|
||||||
|
永远不要用“必须”“请先描述清楚”“按照流程”等强硬的流程化语言。
|
||||||
|
|
||||||
visualizer:
|
visualizer:
|
||||||
prompt_template: |
|
prompt_template: |
|
||||||
你是视觉专家。你的目标是生成高质量的家具草图。
|
你是专业的家具工业设计视觉专家,擅长将文字描述转化为高质量、清晰、专业的家具设计草图。
|
||||||
步骤:
|
|
||||||
1. 根据上下文,编写一个详细的 Stable Diffusion 风格的英文 Prompt。
|
|
||||||
2. 必须调用 generate_furniture_sketch 工具来生成图片。
|
|
||||||
|
|
||||||
注意:如果对话中出现 [SYSTEM_DIRECTIVE] 要求直接绘图,请立即根据已知信息编写 Prompt 并调用 generate_furniture_sketch 工具,不要进行多余的询问。
|
你的唯一任务:
|
||||||
|
- 基于当前全部对话上下文(包括用户描述、已有设计要点、风格要求、材质、尺寸、功能等),在内部生成一个详细的英文 Stable Diffusion Prompt。
|
||||||
researcher:
|
- 然后**立即调用 generate_furniture 工具**来生成图片。
|
||||||
prompt_template: |
|
- **绝对不要**把生成的 Prompt 文本、任何代码块、任何解释、任何思考过程输出给用户。
|
||||||
你是情报专家,负责检索与整理参考资料并生成报告。
|
- 只通过工具调用返回结果,工具执行完成后自然结束。
|
||||||
|
|
||||||
|
Prompt 内部生成要求(仅供你自己参考,不要输出):
|
||||||
|
1. 语言:全程英文
|
||||||
|
2. 结构:主体描述 + 风格 + 视角 + 细节 + 材质 + 照明 + 质量修饰词
|
||||||
|
3. 必须包含高质量关键词:highly detailed, sharp focus, professional industrial design sketch, clean line art 或 photorealistic product render
|
||||||
|
4. 背景:pure white background, studio lighting, no people, no text, no watermark
|
||||||
|
5. 避免:blurry, deformed, low quality, cartoon, extra limbs, bad anatomy
|
||||||
|
6. 如果上下文有明确风格、视角,必须强烈体现
|
||||||
|
7. 长度:80-150 个词左右
|
||||||
|
|
||||||
|
现在开始工作:
|
||||||
|
- 直接分析上下文
|
||||||
|
- 内部构建 Prompt
|
||||||
|
- 立即调用 generate_furniture 工具
|
||||||
|
- 不要输出任何文字
|
||||||
56
logging_env.py
Normal file
56
logging_env.py
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
from src.core.config import settings
|
||||||
|
|
||||||
|
LOGGER_CONFIG_DICT = {
|
||||||
|
'version': 1,
|
||||||
|
'disable_existing_loggers': False,
|
||||||
|
'formatters': {
|
||||||
|
'simple': {
|
||||||
|
'format': '%(asctime)s %(filename)s [line:%(lineno)d] %(levelname)s %(message)s',
|
||||||
|
'datefmt': '%Y-%m-%d %H:%M:%S' # 补充日期格式,日志更易读
|
||||||
|
}
|
||||||
|
},
|
||||||
|
'handlers': {
|
||||||
|
'console': {
|
||||||
|
'class': 'logging.StreamHandler',
|
||||||
|
'level': 'INFO',
|
||||||
|
'formatter': 'simple',
|
||||||
|
'stream': 'ext://sys.stdout',
|
||||||
|
},
|
||||||
|
'info_file_handler': {
|
||||||
|
'class': 'logging.handlers.RotatingFileHandler',
|
||||||
|
'level': 'INFO',
|
||||||
|
'formatter': 'simple',
|
||||||
|
'filename': os.path.join(settings.LOGS_PATH, 'info.log'),
|
||||||
|
'maxBytes': 10485760,
|
||||||
|
'backupCount': 50,
|
||||||
|
'encoding': 'utf8',
|
||||||
|
},
|
||||||
|
'error_file_handler': {
|
||||||
|
'class': 'logging.handlers.RotatingFileHandler',
|
||||||
|
'level': 'ERROR',
|
||||||
|
'formatter': 'simple',
|
||||||
|
'filename': os.path.join(settings.LOGS_PATH, 'error.log'),
|
||||||
|
'maxBytes': 10485760,
|
||||||
|
'backupCount': 20,
|
||||||
|
'encoding': 'utf8',
|
||||||
|
},
|
||||||
|
'debug_file_handler': {
|
||||||
|
'class': 'logging.handlers.RotatingFileHandler',
|
||||||
|
'level': 'DEBUG',
|
||||||
|
'formatter': 'simple',
|
||||||
|
'filename': os.path.join(settings.LOGS_PATH, 'debug.log'),
|
||||||
|
'maxBytes': 10485760,
|
||||||
|
'backupCount': 50,
|
||||||
|
'encoding': 'utf8',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
'loggers': {
|
||||||
|
'my_module': {'level': 'INFO', 'handlers': ['console'], 'propagate': 'no'}
|
||||||
|
},
|
||||||
|
'root': {
|
||||||
|
'level': 'DEBUG',
|
||||||
|
'handlers': ['error_file_handler', 'info_file_handler', 'debug_file_handler', 'console'],
|
||||||
|
},
|
||||||
|
}
|
||||||
6
main.py
6
main.py
@@ -1,8 +1,14 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
import uvicorn
|
import uvicorn
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
|
||||||
|
from logging_env import LOGGER_CONFIG_DICT
|
||||||
from src.routers import chat
|
from src.routers import chat
|
||||||
|
|
||||||
|
logging.config.dictConfig(LOGGER_CONFIG_DICT)
|
||||||
|
|
||||||
app_server = FastAPI(
|
app_server = FastAPI(
|
||||||
title="Gemini Furniture Designer API",
|
title="Gemini Furniture Designer API",
|
||||||
description="基于 LangGraph + Gemini 2.0 Flash 的家具设计 Agent 接口",
|
description="基于 LangGraph + Gemini 2.0 Flash 的家具设计 Agent 接口",
|
||||||
|
|||||||
@@ -4,20 +4,36 @@ version = "0.1.0"
|
|||||||
description = "Add your description here"
|
description = "Add your description here"
|
||||||
requires-python = ">=3.12"
|
requires-python = ">=3.12"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
"crawl4ai>=0.8.0",
|
||||||
|
"deepagents>=0.4.3",
|
||||||
"fastapi[standard]>=0.128.0",
|
"fastapi[standard]>=0.128.0",
|
||||||
"gunicorn>=25.0.1",
|
"gunicorn>=25.0.1",
|
||||||
"image>=1.5.33",
|
"image>=1.5.33",
|
||||||
|
"langchain-community>=0.4.1",
|
||||||
"langchain-core>=1.2.8",
|
"langchain-core>=1.2.8",
|
||||||
"langchain-google-genai>=4.2.0",
|
"langchain-google-genai>=4.2.0",
|
||||||
"langgraph>=1.0.7",
|
"langgraph[postgres]>=1.0.7",
|
||||||
"langgraph-checkpoint-mongodb>=0.3.1",
|
"langgraph-checkpoint-mongodb>=0.3.1",
|
||||||
"minio>=7.2.20",
|
"minio>=7.2.20",
|
||||||
"modality>=0.1.0",
|
"modality>=0.1.0",
|
||||||
"motor>=3.7.1",
|
"motor>=3.7.1",
|
||||||
|
"playwright>=1.58.0",
|
||||||
"pydantic>=2.12.5",
|
"pydantic>=2.12.5",
|
||||||
"pydantic-settings>=2.12.0",
|
"pydantic-settings>=2.12.0",
|
||||||
"pymongo[srv]>=4.15.5",
|
"pymongo[srv]>=4.15.5",
|
||||||
"python-dotenv>=1.2.1",
|
"python-dotenv>=1.2.1",
|
||||||
|
"tavily-python>=0.7.21",
|
||||||
"uuid>=1.30",
|
"uuid>=1.30",
|
||||||
"uvicorn>=0.40.0",
|
"uvicorn>=0.40.0",
|
||||||
|
"psycopg[binary]>=3.3.3",
|
||||||
|
"postgres>=4.0",
|
||||||
|
"langchain-huggingface>=1.2.0",
|
||||||
|
"sentence-transformers>=5.2.3",
|
||||||
|
"rank-bm25>=0.2.2",
|
||||||
|
"torch>=2.10.0",
|
||||||
|
"faiss-cpu>=1.13.2",
|
||||||
|
"terminate>=0.0.9",
|
||||||
|
"report-generator>=0.1.10",
|
||||||
|
"dashscope>=1.25.13",
|
||||||
|
"prompt>=0.4.1",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -17,6 +17,9 @@ class Settings(BaseSettings):
|
|||||||
GOOGLE_CLOUD_PROJECT: str = Field(default="", description="")
|
GOOGLE_CLOUD_PROJECT: str = Field(default="", description="")
|
||||||
GOOGLE_CLOUD_LOCATION: str = Field(default="", description="")
|
GOOGLE_CLOUD_LOCATION: str = Field(default="", description="")
|
||||||
|
|
||||||
|
# --- google api 配置信息 ---
|
||||||
|
QWEN_API_KEY: str = Field(default="", description="")
|
||||||
|
|
||||||
# --- minio 配置信息 ---
|
# --- minio 配置信息 ---
|
||||||
MINIO_URL: str = Field(default='', description="")
|
MINIO_URL: str = Field(default='', description="")
|
||||||
MINIO_ACCESS: str = Field(default='', description="")
|
MINIO_ACCESS: str = Field(default='', description="")
|
||||||
@@ -29,6 +32,11 @@ class Settings(BaseSettings):
|
|||||||
MONGODB_HOST: str = Field(default="localhost", description="")
|
MONGODB_HOST: str = Field(default="localhost", description="")
|
||||||
MONGODB_PORT: int = Field(default=27017, description="")
|
MONGODB_PORT: int = Field(default=27017, description="")
|
||||||
|
|
||||||
|
# --- 外部工具api配置信息 ---
|
||||||
|
TAVILY_API_KEY: str = Field(default="", description="")
|
||||||
|
|
||||||
|
LOGS_PATH: str = Field(default="/mnt/data/FiDA/logs", description="")
|
||||||
|
|
||||||
|
|
||||||
settings = Settings()
|
settings = Settings()
|
||||||
MONGO_URI = f"mongodb://{settings.MONGODB_USERNAME}:{settings.MONGODB_PASSWORD}@{settings.MONGODB_HOST}:{settings.MONGODB_PORT}"
|
MONGO_URI = f"mongodb://{settings.MONGODB_USERNAME}:{settings.MONGODB_PASSWORD}@{settings.MONGODB_HOST}:{settings.MONGODB_PORT}"
|
||||||
|
|||||||
@@ -1,5 +1,8 @@
|
|||||||
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
import json
|
import json
|
||||||
|
from typing import AsyncGenerator
|
||||||
|
|
||||||
from fastapi import APIRouter
|
from fastapi import APIRouter
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
from src.schemas.chat import ChatRequest, HistoryResponse, HistoryItem
|
from src.schemas.chat import ChatRequest, HistoryResponse, HistoryItem
|
||||||
@@ -7,6 +10,7 @@ from src.server.agent.graph import app # 导入已经 compile 好的 graph
|
|||||||
from langchain_core.messages import HumanMessage, SystemMessage
|
from langchain_core.messages import HumanMessage, SystemMessage
|
||||||
|
|
||||||
router = APIRouter(prefix="/chat", tags=["Furniture Design Chat"])
|
router = APIRouter(prefix="/chat", tags=["Furniture Design Chat"])
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/stream")
|
@router.post("/stream")
|
||||||
@@ -52,29 +56,28 @@ async def chat_stream(request: ChatRequest):
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
logger.debug(f"chat request data: {request}")
|
||||||
source_thread_id = request.thread_id
|
source_thread_id = request.thread_id
|
||||||
checkpoint_id = request.checkpoint_id
|
checkpoint_id = request.checkpoint_id
|
||||||
|
|
||||||
# 1. 确定目标 thread_id
|
# 1. 确定目标 thread_id
|
||||||
# 如果是回溯操作,我们生成一个新的 ID,或者由前端传入一个新的 target_thread_id
|
|
||||||
is_branching = source_thread_id and checkpoint_id
|
is_branching = source_thread_id and checkpoint_id
|
||||||
target_thread_id = str(uuid.uuid4())[:8] if is_branching else (source_thread_id or str(uuid.uuid4())[:8])
|
target_thread_id = str(uuid.uuid4())[:8] if is_branching else (source_thread_id or str(uuid.uuid4())[:8])
|
||||||
# 2. 获取配置参数
|
|
||||||
temp = request.config_params.temperature if request.config_params else 0.7
|
|
||||||
|
|
||||||
# 构建基础 Config
|
# 2. 配置参数
|
||||||
|
temp = request.config_params.temperature if request.config_params else 0.7
|
||||||
current_config = {
|
current_config = {
|
||||||
|
"recursion_limit": 100,
|
||||||
"configurable": {
|
"configurable": {
|
||||||
"thread_id": target_thread_id,
|
"thread_id": target_thread_id,
|
||||||
"llm_temperature": temp
|
"llm_temperature": temp,
|
||||||
|
"use_report": request.use_report,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
# 3. 处理状态初始化与分支
|
|
||||||
initial_messages = []
|
|
||||||
|
|
||||||
# 如果是全新的对话(没有 source_thread_id),或者明确要求分叉
|
# 3. 初始化消息 + 系统提示
|
||||||
|
initial_messages = []
|
||||||
if not source_thread_id or is_branching:
|
if not source_thread_id or is_branching:
|
||||||
# 如果用户传了标签,构造 SystemMessage 注入上下文
|
|
||||||
if request.config_params:
|
if request.config_params:
|
||||||
cp = request.config_params
|
cp = request.config_params
|
||||||
system_prompt = (
|
system_prompt = (
|
||||||
@@ -86,7 +89,7 @@ async def chat_stream(request: ChatRequest):
|
|||||||
)
|
)
|
||||||
initial_messages.append(SystemMessage(content=system_prompt))
|
initial_messages.append(SystemMessage(content=system_prompt))
|
||||||
|
|
||||||
# 4. 执行分叉逻辑(搬运旧数据)
|
# 4. 处理分支(从历史 checkpoint 复制状态)
|
||||||
if is_branching:
|
if is_branching:
|
||||||
source_config = {
|
source_config = {
|
||||||
"configurable": {
|
"configurable": {
|
||||||
@@ -95,80 +98,149 @@ async def chat_stream(request: ChatRequest):
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
older_state = await app.aget_state(source_config)
|
older_state = await app.aget_state(source_config)
|
||||||
|
|
||||||
# 将旧消息和我们新定义的 SystemMessage 合并
|
|
||||||
# update_state 会将这些消息推送到新 thread 的存储中
|
|
||||||
combined_values = older_state.values.copy()
|
combined_values = older_state.values.copy()
|
||||||
if initial_messages:
|
if initial_messages:
|
||||||
combined_values["messages"] = list(combined_values["messages"]) + initial_messages
|
combined_values["messages"] = list(combined_values.get("messages", [])) + initial_messages
|
||||||
|
|
||||||
await app.aupdate_state(current_config, combined_values)
|
await app.aupdate_state(current_config, combined_values)
|
||||||
|
|
||||||
async def event_generator():
|
async def event_generator() -> AsyncGenerator[str, None]:
|
||||||
# 初始推送状态信息
|
# 初始事件
|
||||||
yield f"data: {json.dumps({'thread_id': target_thread_id, 'is_branch': is_branching, 'status': 'start'}, ensure_ascii=False)}\n\n"
|
yield f"data: {json.dumps({'thread_id': target_thread_id, 'is_branch': is_branching, 'status': 'start'}, ensure_ascii=False)}\n\n"
|
||||||
|
|
||||||
# 构造本次请求的输入
|
# 构造输入
|
||||||
# 如果是第一次开始,且有 initial_messages,则连同 user message 一起发送
|
new_messages = initial_messages[:] if not source_thread_id else []
|
||||||
# --- 核心逻辑:构造本次请求的消息列表 ---
|
|
||||||
new_messages = []
|
|
||||||
if not source_thread_id and initial_messages:
|
|
||||||
new_messages.extend(initial_messages)
|
|
||||||
# 添加用户消息
|
|
||||||
new_messages.append(HumanMessage(content=request.message))
|
new_messages.append(HumanMessage(content=request.message))
|
||||||
|
|
||||||
# --- 新增:强制绘图指令注入 ---
|
|
||||||
# if request.force_sketch:
|
|
||||||
# force_instruction = HumanMessage(
|
|
||||||
# content="[SYSTEM_DIRECTIVE]: 用户点击了强制生成按钮。请立即根据当前上下文调用 generate_furniture_sketch 工具生成草图,无需确认。"
|
|
||||||
# )
|
|
||||||
# new_messages.append(force_instruction)
|
|
||||||
|
|
||||||
input_data = {
|
input_data = {
|
||||||
"messages": new_messages,
|
"messages": new_messages,
|
||||||
"require_suggestion": request.need_suggestion # 初始由前端决定
|
"require_suggestion": request.need_suggestion,
|
||||||
|
"use_report": request.use_report,
|
||||||
}
|
}
|
||||||
|
|
||||||
async for event in app.astream(
|
# 使用 astream_events v2 + stream_subgraphs=True 来捕获 DeepAgents 内部流式事件
|
||||||
|
async for event in app.astream_events(
|
||||||
input_data,
|
input_data,
|
||||||
current_config,
|
version="v2",
|
||||||
stream_mode="updates"
|
config=current_config,
|
||||||
|
stream_subgraphs=True,
|
||||||
):
|
):
|
||||||
for node_name, output in event.items():
|
event_kind = event["event"]
|
||||||
if "messages" in output:
|
|
||||||
# 获取最新 state 以获取 checkpoint_id
|
|
||||||
state = await app.aget_state(current_config)
|
|
||||||
current_cp_id = state.config["configurable"].get("checkpoint_id")
|
|
||||||
|
|
||||||
# 遍历本次 update 产生的所有消息
|
# 获取当前 checkpoint_id(安全方式,避免 KeyError)
|
||||||
for msg in output["messages"]:
|
latest_state = await app.aget_state(current_config)
|
||||||
|
configurable = latest_state.config.get("configurable", {})
|
||||||
|
current_cp_id = configurable.get("checkpoint_id", "") # 如果没有,返回空字符串
|
||||||
|
|
||||||
|
# ────────────────────────────────────────────────
|
||||||
|
# 1. LLM token 流式输出(主图或子图的逐 token)
|
||||||
|
# ────────────────────────────────────────────────
|
||||||
|
if event_kind == "on_chat_model_stream":
|
||||||
|
chunk = event["data"].get("chunk")
|
||||||
|
if chunk and chunk.content:
|
||||||
|
node_name = event.get("name", "Unknown")
|
||||||
|
# 判断是否来自 Researcher 子图
|
||||||
|
namespace = event.get("parent_ids", []) or event.get("namespace", [])
|
||||||
|
if any("Researcher" in str(ns) for ns in namespace):
|
||||||
|
node_name = "Researcher"
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"node": node_name,
|
||||||
|
"content": chunk.content,
|
||||||
|
"is_delta": True,
|
||||||
|
"checkpoint_id": current_cp_id,
|
||||||
|
"image_url": None,
|
||||||
|
"suggestions": []
|
||||||
|
}
|
||||||
|
yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
|
||||||
|
|
||||||
|
# ────────────────────────────────────────────────
|
||||||
|
# 2. 自定义事件(report_delta 等)
|
||||||
|
# ────────────────────────────────────────────────
|
||||||
|
elif event_kind == "on_custom_event":
|
||||||
|
custom_data = event["data"]
|
||||||
|
if isinstance(custom_data, dict):
|
||||||
|
if custom_data.get("type") == "report_delta":
|
||||||
payload = {
|
payload = {
|
||||||
"node": node_name,
|
"node": "Researcher",
|
||||||
"content": "",
|
"content": custom_data.get("delta", ""),
|
||||||
|
"is_delta": True,
|
||||||
|
"checkpoint_id": current_cp_id,
|
||||||
"image_url": None,
|
"image_url": None,
|
||||||
|
"suggestions": []
|
||||||
|
}
|
||||||
|
yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
|
||||||
|
|
||||||
|
# 可选:报告开始/完成/错误等状态提示
|
||||||
|
elif custom_data.get("type") in ("report_start", "report_complete", "report_error"):
|
||||||
|
status_msg = {
|
||||||
|
"report_start": "Start generating reports...",
|
||||||
|
"report_complete": "Report generation completed",
|
||||||
|
"report_error": f"Report generation failed: {custom_data.get('message', '')}"
|
||||||
|
}.get(custom_data["type"], "")
|
||||||
|
payload = {
|
||||||
|
"node": "Researcher",
|
||||||
|
"content": status_msg,
|
||||||
|
"is_delta": False,
|
||||||
|
"checkpoint_id": current_cp_id,
|
||||||
|
"image_url": None,
|
||||||
|
"suggestions": []
|
||||||
|
}
|
||||||
|
yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
|
||||||
|
|
||||||
|
# ────────────────────────────────────────────────
|
||||||
|
# 3. 节点启动 / 工具启动(进度提示)
|
||||||
|
# ────────────────────────────────────────────────
|
||||||
|
elif event_kind in {"on_tool_start", "on_tool_end"}:
|
||||||
|
tool_name = event.get("name", "unknown_tool")
|
||||||
|
tool_data = event.get("data", {})
|
||||||
|
tool_input = tool_data.get("input", "")
|
||||||
|
tool_output = tool_data.get("output", "")
|
||||||
|
|
||||||
|
if event_kind == "on_tool_start":
|
||||||
|
payload = {
|
||||||
|
"node": tool_name,
|
||||||
|
"content": tool_input,
|
||||||
|
"is_delta": False,
|
||||||
|
"checkpoint_id": current_cp_id,
|
||||||
|
"image_url": None,
|
||||||
|
"suggestions": []
|
||||||
|
}
|
||||||
|
yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
|
||||||
|
else:
|
||||||
|
if tool_name == "generate_furniture" and isinstance(tool_output, str):
|
||||||
|
payload = {
|
||||||
|
"node": tool_name,
|
||||||
|
"content": "Design sketch has been generated for you.", # 给用户友好的文字提示
|
||||||
|
"image_url": tool_output, # 直接传 URL 给前端显示
|
||||||
|
"is_delta": False, # 这是一个完整事件,不是增量
|
||||||
"checkpoint_id": current_cp_id,
|
"checkpoint_id": current_cp_id,
|
||||||
"suggestions": []
|
"suggestions": []
|
||||||
}
|
}
|
||||||
|
|
||||||
# --- 核心改动:提取建议按钮 ---
|
|
||||||
# 无论是不是 Suggester 节点,只要消息里带了建议就提取
|
|
||||||
if hasattr(msg, "additional_kwargs") and "suggestions" in msg.additional_kwargs:
|
|
||||||
payload["suggestions"] = msg.additional_kwargs["suggestions"]
|
|
||||||
|
|
||||||
content = msg.content
|
|
||||||
# 逻辑判断:MinIO 图片处理
|
|
||||||
if node_name == "Visualizer" and str(content).endswith("png") and "furniture/sketches" in str(content):
|
|
||||||
payload["image_url"] = content
|
|
||||||
payload["content"] = "已为您生成设计草图"
|
|
||||||
else:
|
|
||||||
payload["content"] = content
|
|
||||||
|
|
||||||
# 如果消息既没有文本、也没有图片、也没有建议(比如中间的 ToolCall 消息),则跳过
|
|
||||||
if not payload["content"] and not payload["image_url"] and not payload["suggestions"]:
|
|
||||||
continue
|
|
||||||
|
|
||||||
yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
|
yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
|
||||||
|
elif tool_name == "topic_research":
|
||||||
|
payload = {
|
||||||
|
"node": tool_name,
|
||||||
|
"content": "Visiting...", # 给用户友好的文字提示
|
||||||
|
"image_url": None, # 直接传 URL 给前端显示
|
||||||
|
"search_list": tool_output.content,
|
||||||
|
"is_delta": False, # 这是一个完整事件,不是增量
|
||||||
|
"checkpoint_id": current_cp_id,
|
||||||
|
"suggestions": []
|
||||||
|
}
|
||||||
|
yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
|
||||||
|
else:
|
||||||
|
# 可选:其他工具的通用处理(debug 或显示结果)
|
||||||
|
if tool_output:
|
||||||
|
payload = {
|
||||||
|
"node": tool_name,
|
||||||
|
"content": f"tool {tool_name} Execution completed:{str(tool_output)[:200]}...", # 截断避免过长
|
||||||
|
"is_delta": False,
|
||||||
|
"checkpoint_id": current_cp_id,
|
||||||
|
"image_url": None,
|
||||||
|
"suggestions": []
|
||||||
|
}
|
||||||
|
yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
|
||||||
|
# 流结束
|
||||||
yield f"data: {json.dumps({'status': 'end'}, ensure_ascii=False)}\n\n"
|
yield f"data: {json.dumps({'status': 'end'}, ensure_ascii=False)}\n\n"
|
||||||
|
|
||||||
return StreamingResponse(event_generator(), media_type="text/event-stream")
|
return StreamingResponse(event_generator(), media_type="text/event-stream")
|
||||||
@@ -218,7 +290,7 @@ async def get_chat_history(thread_id: str):
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
config = {"configurable": {"thread_id": thread_id}}
|
config = {"configurable": {"thread_id": thread_id}, }
|
||||||
history_data = []
|
history_data = []
|
||||||
async for state in app.aget_state_history(config):
|
async for state in app.aget_state_history(config):
|
||||||
msg_content = "Initial"
|
msg_content = "Initial"
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ class ChatRequest(BaseModel):
|
|||||||
checkpoint_id: Optional[str] = Field(None, description="回溯点的ID,用于从历史点开启新对话")
|
checkpoint_id: Optional[str] = Field(None, description="回溯点的ID,用于从历史点开启新对话")
|
||||||
config_params: Optional[AgentConfig] = None
|
config_params: Optional[AgentConfig] = None
|
||||||
need_suggestion: bool = False
|
need_suggestion: bool = False
|
||||||
|
use_report: bool = False # ← 新增:是否使用深度报告
|
||||||
|
|
||||||
|
|
||||||
class HistoryItem(BaseModel):
|
class HistoryItem(BaseModel):
|
||||||
|
|||||||
@@ -1,36 +1,59 @@
|
|||||||
import os
|
from pathlib import Path
|
||||||
|
from typing import AsyncGenerator, Dict, Any
|
||||||
from google.oauth2 import service_account
|
from deepagents import create_deep_agent
|
||||||
|
from deepagents.backends import FilesystemBackend
|
||||||
from langchain_core.messages import HumanMessage, SystemMessage, ToolMessage, AIMessage
|
from langchain_core.messages import HumanMessage, SystemMessage, ToolMessage, AIMessage
|
||||||
from langchain_core.runnables import RunnableConfig
|
from langchain_core.runnables import RunnableConfig
|
||||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
from langchain_qwq import ChatQwen
|
||||||
from src.server.agent.state import AgentState
|
|
||||||
from src.server.agent.tools import generate_2025_report_tool, generate_furniture_sketch
|
|
||||||
from src.server.agent.config_loader import get_agent_prompt
|
|
||||||
from src.core.config import settings
|
|
||||||
from src.server.utils.generate_suggestion import generate_chat_suggestions
|
|
||||||
|
|
||||||
creds = service_account.Credentials.from_service_account_file(
|
from src.core.config import settings
|
||||||
settings.GOOGLE_GENAI_USE_VERTEXAI,
|
from src.server.agent.prompt import SYSTEM_PROMPT
|
||||||
scopes=["https://www.googleapis.com/auth/cloud-platform"],
|
from src.server.agent.state import AgentState
|
||||||
|
from src.server.agent.tools.generate_furniture_sketch import generate_furniture
|
||||||
|
from src.server.agent.config_loader import get_agent_prompt
|
||||||
|
from src.server.agent.tools.crawl_tool import crawl4ai_batch
|
||||||
|
from src.server.agent.tools.research_tool import topic_research
|
||||||
|
from src.server.agent.tools.structured_retrieval_tool import structured_retrieval
|
||||||
|
from src.server.agent.tools.terminate_tool import terminate
|
||||||
|
from src.server.agent.tools.user_persona_tool import manage_user_persona
|
||||||
|
from src.server.utils.generate_suggestion import generate_chat_suggestions
|
||||||
|
from test.report.tools.report_generator_tool import report_generator
|
||||||
|
|
||||||
|
# 目前這個主程式檔案所在的目錄
|
||||||
|
MAIN_DIR = Path(__file__).resolve().parent
|
||||||
|
|
||||||
|
# 專案根目錄(因為 main.py 跟 tools/ 同級,所以 parent 就是根)
|
||||||
|
PROJECT_ROOT = MAIN_DIR
|
||||||
|
|
||||||
|
model = ChatQwen(
|
||||||
|
model="qwen3.5-flash",
|
||||||
|
max_tokens=3_000,
|
||||||
|
timeout=None,
|
||||||
|
max_retries=2,
|
||||||
|
api_key=settings.QWEN_API_KEY)
|
||||||
|
|
||||||
|
tools = [manage_user_persona, topic_research, crawl4ai_batch, structured_retrieval, report_generator, terminate]
|
||||||
|
research_agent = create_deep_agent(
|
||||||
|
model=model,
|
||||||
|
tools=tools,
|
||||||
|
system_prompt=SYSTEM_PROMPT,
|
||||||
|
backend=FilesystemBackend(
|
||||||
|
root_dir=str(PROJECT_ROOT / "agent_workspace"),
|
||||||
|
virtual_mode=False, # 重要:關掉虛擬模式 → 真的寫硬碟
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# 辅助函数:根据配置动态获取 LLM
|
# 辅助函数:根据配置动态获取 LLM
|
||||||
def get_model(config: RunnableConfig):
|
def get_model(config: RunnableConfig):
|
||||||
# 从 configurable 中获取温度,默认为 0.5 (对应你之前的设置)
|
|
||||||
# 这个 key 必须与你在 chat_stream 路由里定义的 "llm_temperature" 一致
|
|
||||||
temp = config["configurable"].get("llm_temperature", 0.5)
|
temp = config["configurable"].get("llm_temperature", 0.5)
|
||||||
|
return ChatQwen(
|
||||||
return ChatGoogleGenerativeAI(
|
model="qwen3.5-flash",
|
||||||
model="gemini-2.0-flash",
|
max_tokens=3_000,
|
||||||
|
timeout=None,
|
||||||
|
max_retries=2,
|
||||||
temperature=temp,
|
temperature=temp,
|
||||||
credentials=creds,
|
api_key=settings.QWEN_API_KEY)
|
||||||
project=settings.GOOGLE_CLOUD_PROJECT,
|
|
||||||
location=settings.GOOGLE_CLOUD_LOCATION,
|
|
||||||
vertexai=True,
|
|
||||||
api_key=settings.GOOGLE_API_KEY
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# --- 1. Designer Agent (设计顾问) ---
|
# --- 1. Designer Agent (设计顾问) ---
|
||||||
@@ -48,33 +71,158 @@ async def designer_node(state: AgentState, config: RunnableConfig):
|
|||||||
return {"messages": [response], "require_suggestion": should_suggest}
|
return {"messages": [response], "require_suggestion": should_suggest}
|
||||||
|
|
||||||
|
|
||||||
# --- 2. Researcher Agent (情报专家) ---
|
async def researcher_node(
|
||||||
async def researcher_node(state: AgentState, config: RunnableConfig):
|
state: AgentState,
|
||||||
"""负责调用报告生成工具"""
|
config: RunnableConfig
|
||||||
model = get_model(config)
|
) -> AsyncGenerator[Dict[str, Any], None]:
|
||||||
tools = [generate_2025_report_tool]
|
use_report = config["configurable"].get("use_report", False)
|
||||||
llm_with_tools = model.bind_tools(tools)
|
if not use_report:
|
||||||
|
yield {
|
||||||
|
"messages": [AIMessage(
|
||||||
|
content="深度报告功能未启用,请通过前端按钮触发。",
|
||||||
|
name="Researcher"
|
||||||
|
)],
|
||||||
|
"next": "Supervisor"
|
||||||
|
}
|
||||||
|
return
|
||||||
|
|
||||||
messages = state["messages"]
|
messages = state["messages"]
|
||||||
system_text = get_agent_prompt("researcher")
|
last_human = next((m for m in reversed(messages) if isinstance(m, HumanMessage)), None)
|
||||||
system_prompt = SystemMessage(content=system_text)
|
|
||||||
response = await llm_with_tools.ainvoke([system_prompt] + messages)
|
|
||||||
|
|
||||||
if response.tool_calls:
|
if not last_human:
|
||||||
tool_call = response.tool_calls[0]
|
yield {
|
||||||
if tool_call["name"] == "generate_2025_report_tool":
|
"messages": [AIMessage(
|
||||||
# 这里的工具调用如果也是异步的,建议加 await
|
content="深度研究节点:未找到有效的用户问题",
|
||||||
result = await generate_2025_report_tool.ainvoke(tool_call["args"])
|
name="Researcher"
|
||||||
return {"messages": [response, HumanMessage(content=str(result))]}
|
)],
|
||||||
|
"next": "Supervisor"
|
||||||
|
}
|
||||||
|
return
|
||||||
|
|
||||||
return {"messages": [response]}
|
full_content = ""
|
||||||
|
current_step = "正在启动深度报告生成..."
|
||||||
|
|
||||||
|
# 初始提示
|
||||||
|
yield {
|
||||||
|
"messages": [AIMessage(
|
||||||
|
content="正在启动深度报告生成...",
|
||||||
|
name="Researcher",
|
||||||
|
additional_kwargs={
|
||||||
|
"current_step": current_step,
|
||||||
|
"streaming": True
|
||||||
|
}
|
||||||
|
)]
|
||||||
|
}
|
||||||
|
|
||||||
|
async for event in research_agent.astream_events(
|
||||||
|
{"messages": messages[-12:]},
|
||||||
|
version="v2",
|
||||||
|
config=config,
|
||||||
|
stream_subgraphs=True
|
||||||
|
):
|
||||||
|
event_type = event["event"]
|
||||||
|
name = event.get("name", "未知")
|
||||||
|
|
||||||
|
if event["event"] == "on_custom_event":
|
||||||
|
custom_data = event["data"]
|
||||||
|
# 你的 writer 发的是 dict,所以这里 custom_data 就是你写的 {"type": "report_delta", "delta": "..."}
|
||||||
|
if isinstance(custom_data, dict) and custom_data.get("type") == "report_delta":
|
||||||
|
delta = custom_data.get("delta", "")
|
||||||
|
print(delta, end="", flush=True) # 实时打印,不换行
|
||||||
|
|
||||||
|
# ────────────── 工具结束事件:重点处理并 yield 输出 ──────────────
|
||||||
|
if event["event"] in {"on_tool_start", "on_tool_end"}:
|
||||||
|
tool_name = event.get("name", "未知")
|
||||||
|
is_start = event["event"] == "on_tool_start"
|
||||||
|
|
||||||
|
if is_start:
|
||||||
|
tool_input = event["data"].get("input", {})
|
||||||
|
current_step = f"正在執行工具:{tool_name}"
|
||||||
|
print(f"| {current_step} | {tool_input}")
|
||||||
|
yield {
|
||||||
|
"messages": [AIMessage(
|
||||||
|
content=full_content,
|
||||||
|
name="Researcher",
|
||||||
|
additional_kwargs={
|
||||||
|
"current_step": current_step,
|
||||||
|
"tool_name": tool_name,
|
||||||
|
"tool_input": tool_input,
|
||||||
|
"tool_status": "start",
|
||||||
|
"streaming": True
|
||||||
|
}
|
||||||
|
)]
|
||||||
|
}
|
||||||
|
else: # on_tool_end
|
||||||
|
tool_output = event["data"].get("output", "")
|
||||||
|
current_step = f"工具 {tool_name} 已完成"
|
||||||
|
print(f"| {current_step} | {tool_output}")
|
||||||
|
yield {
|
||||||
|
"messages": [AIMessage(
|
||||||
|
content=full_content,
|
||||||
|
name="Researcher",
|
||||||
|
additional_kwargs={
|
||||||
|
"current_step": current_step,
|
||||||
|
"tool_name": tool_name,
|
||||||
|
"tool_output": tool_output,
|
||||||
|
"tool_status": "end",
|
||||||
|
"streaming": True
|
||||||
|
}
|
||||||
|
)]
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ────────────── LLM 内容生成(保持原有逻辑) ──────────────
|
||||||
|
elif event_type == "on_chat_model_stream":
|
||||||
|
chunk = event["data"]["chunk"].content or ""
|
||||||
|
if chunk:
|
||||||
|
print(chunk, end="", flush=True)
|
||||||
|
full_content += chunk
|
||||||
|
if "\n" in chunk or len(full_content) % 4 == 0:
|
||||||
|
yield {
|
||||||
|
"messages": [AIMessage(
|
||||||
|
content=full_content,
|
||||||
|
name="Researcher",
|
||||||
|
additional_kwargs={
|
||||||
|
"current_step": current_step,
|
||||||
|
"streaming": True
|
||||||
|
}
|
||||||
|
)]
|
||||||
|
}
|
||||||
|
|
||||||
|
# ────────────── 其他链路事件(可选补充) ──────────────
|
||||||
|
elif event_type in ("on_chain_start", "on_chain_end"):
|
||||||
|
status = "开始" if event_type == "on_chain_start" else "完成"
|
||||||
|
current_step = f"[{status}] {name.upper()}"
|
||||||
|
yield {
|
||||||
|
"messages": [AIMessage(
|
||||||
|
content=full_content,
|
||||||
|
name="Researcher",
|
||||||
|
additional_kwargs={
|
||||||
|
"current_step": current_step,
|
||||||
|
"streaming": True
|
||||||
|
}
|
||||||
|
)]
|
||||||
|
}
|
||||||
|
|
||||||
|
# 最终输出
|
||||||
|
yield {
|
||||||
|
"messages": [AIMessage(
|
||||||
|
content=full_content.strip() or "报告生成完成",
|
||||||
|
name="Researcher",
|
||||||
|
additional_kwargs={
|
||||||
|
"current_step": "报告已完成",
|
||||||
|
"streaming": False
|
||||||
|
}
|
||||||
|
)],
|
||||||
|
"next": "Suggester"
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
# --- 3. Visualizer Agent (视觉专家) ---
|
# --- 3. Visualizer Agent (视觉专家) ---
|
||||||
async def visualizer_node(state: AgentState, config: RunnableConfig):
|
async def visualizer_node(state: AgentState, config: RunnableConfig):
|
||||||
"""负责将自然语言转化为绘图 Prompt 并调用绘图工具"""
|
"""负责将自然语言转化为绘图 Prompt 并调用绘图工具"""
|
||||||
model = get_model(config)
|
model = get_model(config)
|
||||||
tools = [generate_furniture_sketch]
|
tools = [generate_furniture]
|
||||||
llm_with_tools = model.bind_tools(tools)
|
llm_with_tools = model.bind_tools(tools)
|
||||||
|
|
||||||
messages = state["messages"]
|
messages = state["messages"]
|
||||||
@@ -85,8 +233,8 @@ async def visualizer_node(state: AgentState, config: RunnableConfig):
|
|||||||
|
|
||||||
if response.tool_calls:
|
if response.tool_calls:
|
||||||
tool_call = response.tool_calls[0]
|
tool_call = response.tool_calls[0]
|
||||||
if tool_call["name"] == "generate_furniture_sketch":
|
if tool_call["name"] == "generate_furniture":
|
||||||
img_url = await generate_furniture_sketch.ainvoke(tool_call["args"])
|
img_url = await generate_furniture.ainvoke(tool_call["args"])
|
||||||
return {
|
return {
|
||||||
"messages": [
|
"messages": [
|
||||||
response,
|
response,
|
||||||
|
|||||||
@@ -1,14 +1,12 @@
|
|||||||
import os
|
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
from google.oauth2 import service_account
|
|
||||||
from langchain_core.messages import AIMessage
|
from langchain_core.messages import AIMessage
|
||||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
from langchain_core.runnables import RunnableConfig
|
||||||
|
from langchain_qwq import ChatQwen
|
||||||
from langgraph.graph import StateGraph, END, START
|
from langgraph.graph import StateGraph, END, START
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from pymongo import MongoClient
|
from pymongo import MongoClient
|
||||||
|
|
||||||
from src.core.config import settings, MONGO_URI
|
from src.core.config import MONGO_URI, settings
|
||||||
from src.server.agent.state import AgentState
|
from src.server.agent.state import AgentState
|
||||||
from src.server.agent.agents import designer_node, researcher_node, visualizer_node, suggester_node
|
from src.server.agent.agents import designer_node, researcher_node, visualizer_node, suggester_node
|
||||||
from langgraph.checkpoint.mongodb import MongoDBSaver
|
from langgraph.checkpoint.mongodb import MongoDBSaver
|
||||||
@@ -21,18 +19,16 @@ class RouteResponse(BaseModel):
|
|||||||
next: Literal["Designer", "Researcher", "Visualizer", "Suggester", "FINISH"]
|
next: Literal["Designer", "Researcher", "Visualizer", "Suggester", "FINISH"]
|
||||||
|
|
||||||
|
|
||||||
creds = service_account.Credentials.from_service_account_file(
|
llm_supervisor = ChatQwen(
|
||||||
settings.GOOGLE_GENAI_USE_VERTEXAI,
|
model="qwen3.5-flash",
|
||||||
scopes=["https://www.googleapis.com/auth/cloud-platform"],
|
max_tokens=3_000,
|
||||||
)
|
timeout=None,
|
||||||
|
max_retries=2,
|
||||||
llm_supervisor = ChatGoogleGenerativeAI(
|
api_key=settings.QWEN_API_KEY)
|
||||||
model="gemini-2.0-flash", credentials=creds,
|
|
||||||
project="aida-461108", location='us-central1', vertexai=True, api_key=settings.GOOGLE_API_KEY
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def supervisor_node(state: AgentState):
|
def supervisor_node(state: AgentState, config: RunnableConfig):
|
||||||
|
use_report = config["configurable"].get("use_report", False)
|
||||||
messages = state["messages"]
|
messages = state["messages"]
|
||||||
if not messages:
|
if not messages:
|
||||||
return {"next": "Suggester"}
|
return {"next": "Suggester"}
|
||||||
@@ -69,7 +65,6 @@ workflow.add_node("Designer", designer_node)
|
|||||||
workflow.add_node("Researcher", researcher_node)
|
workflow.add_node("Researcher", researcher_node)
|
||||||
workflow.add_node("Visualizer", visualizer_node)
|
workflow.add_node("Visualizer", visualizer_node)
|
||||||
workflow.add_node("Suggester", suggester_node) # 新增节点
|
workflow.add_node("Suggester", suggester_node) # 新增节点
|
||||||
|
|
||||||
workflow.add_edge(START, "Supervisor")
|
workflow.add_edge(START, "Supervisor")
|
||||||
|
|
||||||
# 修改条件边映射
|
# 修改条件边映射
|
||||||
|
|||||||
66
src/server/agent/prompt.py
Normal file
66
src/server/agent/prompt.py
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
SYSTEM_PROMPT = """
|
||||||
|
You are "TrendAgent" - a focused, efficient design trend analysis agent.
|
||||||
|
Your ONLY goal: produce one high-quality Markdown trend report per user request.
|
||||||
|
|
||||||
|
TOOL ORDER & DISCIPLINE IS MANDATORY - DO NOT INVENT STEPS
|
||||||
|
|
||||||
|
┌───────────────────────────────────────────────────────┐
|
||||||
|
│ Phase 0 - Context & Persona (必须先完成) │
|
||||||
|
└───────────────────────────────────────────────────────┘
|
||||||
|
|
||||||
|
Rules for Phase 0:
|
||||||
|
1. ALWAYS start with manage_user_persona(command="get")
|
||||||
|
2. If STATUS == "INCOMPLETE" or persona missing critical fields (Design Type, Style, Target Audience, Color Preference, etc.):
|
||||||
|
→ MUST call manage_user_persona(command="ask") to collect missing info
|
||||||
|
→ After user answers → call manage_user_persona(command="set", ...)
|
||||||
|
→ Loop until STATUS == "READY"
|
||||||
|
3. Only when STATUS == "READY" → proceed to Phase 1
|
||||||
|
4. Never assume or fabricate persona details
|
||||||
|
|
||||||
|
┌───────────────────────────────────────────────────────┐
|
||||||
|
│ Phase 1 - Planning (必须执行一次且只能一次) │
|
||||||
|
└───────────────────────────────────────────────────────┘
|
||||||
|
|
||||||
|
When persona READY and user gave a clear trend request:
|
||||||
|
1. Call write_todos EXACTLY ONCE with a strict plan containing:
|
||||||
|
- 3–6 concrete steps (numbered)
|
||||||
|
- Which URLs/topics to research
|
||||||
|
- Expected output of each major tool
|
||||||
|
- Final deliverable: one Markdown report
|
||||||
|
2. After receiving todos, you MUST follow this exact sequence unless impossible
|
||||||
|
3. Do NOT call any other tool until write_todos is done
|
||||||
|
|
||||||
|
┌───────────────────────────────────────────────────────┐
|
||||||
|
│ Phase 2 - Research & Collection │
|
||||||
|
└───────────────────────────────────────────────────────┘
|
||||||
|
|
||||||
|
Follow todos order:
|
||||||
|
- Use topic_research → get 3–8 high-quality URLs (add persona [Style] [Type] in query)
|
||||||
|
- Select best 3–6 URLs → call crawl4ai_batch ONCE with list
|
||||||
|
- Get file paths → call structured_retrieval ONCE with file_paths list
|
||||||
|
|
||||||
|
┌───────────────────────────────────────────────────────┐
|
||||||
|
│ Phase 3 - Synthesis & Delivery │
|
||||||
|
└───────────────────────────────────────────────────────┘
|
||||||
|
|
||||||
|
After structured_retrieval summary received:
|
||||||
|
- If extracted item count ≥ 8–12 AND covers main aspects in todos → ready to report
|
||||||
|
- Call report_generator ONCE (it reads local JSON/DB)
|
||||||
|
- After report_generator success → call terminate
|
||||||
|
- If data obviously insufficient → call topic_research again (max 1 extra round)
|
||||||
|
|
||||||
|
┌───────────────────────────────────────────────────────┐
|
||||||
|
│ HARD RULES - MUST OBEY │
|
||||||
|
└───────────────────────────────────────────────────────┘
|
||||||
|
|
||||||
|
• Never load full JSON/markdown into context - trust local storage
|
||||||
|
• Batch everything possible (crawl4ai_batch + structured_retrieval)
|
||||||
|
• Call tools in PHASE ORDER - no jumping, no repetition
|
||||||
|
• After report_generator → next action MUST be terminate
|
||||||
|
• If stuck > 4 steps without progress → call terminate with note "Incomplete - insufficient data"
|
||||||
|
• Never hallucinate trend data - base everything on retrieved content
|
||||||
|
• Report must start each section with **Conclusion First** insight
|
||||||
|
• Include [IMAGE_REF_xx] placeholders where visuals were extracted
|
||||||
|
|
||||||
|
Current status: Phase 0
|
||||||
|
"""
|
||||||
@@ -1,49 +1,74 @@
|
|||||||
from langchain_core.messages import HumanMessage, AIMessage
|
import asyncio
|
||||||
|
from langchain_core.messages import HumanMessage, AIMessage, ToolMessage
|
||||||
|
|
||||||
from src.server.agent.graph import app
|
from src.server.agent.graph import app
|
||||||
|
|
||||||
|
|
||||||
def main():
|
async def async_main():
|
||||||
# 模拟 thread_id 区分不同用户或项目
|
|
||||||
config = {"configurable": {"thread_id": "project_alpha"}}
|
config = {"configurable": {"thread_id": "project_alpha"}}
|
||||||
|
|
||||||
|
print("測試模式已啟動 (輸入 'exit' 離開,'history' 查看歷史並回溯)")
|
||||||
|
use_report = input("是否启用深度报告?(y/n): ").lower() == 'y'
|
||||||
while True:
|
while True:
|
||||||
user_input = input("\n👤 设计师 (输入 'history' 定位轮次): ")
|
user_input = input("\n👤 輸入訊息: ").strip()
|
||||||
|
|
||||||
|
if user_input.lower() in ["exit", "quit", "結束"]:
|
||||||
|
print("測試結束")
|
||||||
|
break
|
||||||
|
|
||||||
# --- 官方推荐的异步回溯逻辑 ---
|
|
||||||
if user_input.lower() == "history":
|
if user_input.lower() == "history":
|
||||||
print("\n--- 历史记录 ---")
|
# 你的 history 邏輯(這裡不變)
|
||||||
for state in app.get_state_history(config):
|
print("\n=== 歷史檢查點 ===")
|
||||||
# 每一个 state 都是一个 CheckpointTuple
|
states = [s async for s in app.aget_state_history(config)]
|
||||||
cp_id = state.config["configurable"]["checkpoint_id"]
|
for idx, state_tuple in enumerate(states):
|
||||||
msg = state.values["messages"][-1].content[:30] if state.values.get("messages") else "Initial"
|
cp_id = state_tuple.config["configurable"].get("checkpoint_id", "N/A")
|
||||||
print(f"ID: {cp_id} | 内容: {msg}...")
|
messages = state_tuple.values.get("messages", [])
|
||||||
|
if messages:
|
||||||
target_id = input("\n请输入想要回溯的 Checkpoint ID (直接回车取消): ")
|
last_msg = messages[-1]
|
||||||
if target_id:
|
msg_type = type(last_msg).__name__
|
||||||
# 重新配置 config,指向特定的 checkpoint_id 实现分支
|
content_preview = str(last_msg.content)[:60].replace("\n", " ")
|
||||||
config = {"configurable": {"thread_id": "project_alpha", "checkpoint_id": target_id}}
|
node = getattr(last_msg, "name", "無節點名")
|
||||||
print(f"✅ 已定位到节点 {target_id},后续对话将从此分叉。")
|
print(f"[{idx}] {cp_id[:12]}... | {node} | {msg_type} | {content_preview}...")
|
||||||
|
target = input("\n輸入要回溯的 checkpoint ID (或 Enter 取消): ").strip()
|
||||||
|
if target:
|
||||||
|
config["configurable"]["checkpoint_id"] = target
|
||||||
|
print(f"已切換到 checkpoint {target}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# --- 官方推荐的 astream 异步流式调用 ---
|
if not user_input:
|
||||||
print("🤖 Agent 思考中...")
|
continue
|
||||||
for event in app.stream(
|
|
||||||
{"messages": [HumanMessage(content=user_input)]},
|
|
||||||
config,
|
|
||||||
stream_mode="values" # 这里设为 values 可以直接获取当前状态的消息列表
|
|
||||||
):
|
|
||||||
# 获取当前节点处理后的最新消息
|
|
||||||
if "messages" in event:
|
|
||||||
last_msg = event["messages"][-1]
|
|
||||||
if isinstance(last_msg, AIMessage):
|
|
||||||
# 为了极致流式体验,可以在此处对 content 进行打印
|
|
||||||
pass
|
|
||||||
|
|
||||||
# 运行结束后,最新的状态已经自动持久化到 MongoDB
|
print("\n🤖 開始處理...")
|
||||||
# 我们可以通过 app.get_state(config) 验证
|
|
||||||
final_state = app.get_state(config)
|
try:
|
||||||
print(f"\n✅ 最终回复: {final_state.values['messages'][-1].content}")
|
last_output = ""
|
||||||
|
async for event in app.astream(
|
||||||
|
{"messages": [HumanMessage(content=user_input)]},
|
||||||
|
config,
|
||||||
|
stream_mode="updates"
|
||||||
|
):
|
||||||
|
for node_name, update in event.items():
|
||||||
|
if "messages" in update:
|
||||||
|
for msg in update["messages"]:
|
||||||
|
if isinstance(msg, AIMessage):
|
||||||
|
content = msg.content.strip()
|
||||||
|
if content and content != last_output:
|
||||||
|
print(f"\n[{node_name}] {msg.name or 'AI'}: {content}")
|
||||||
|
last_output = content
|
||||||
|
elif isinstance(msg, ToolMessage):
|
||||||
|
print(f" → 工具 {msg.name}: {msg.content[:120]}{'...' if len(msg.content) > 120 else ''}")
|
||||||
|
else:
|
||||||
|
print(f" ({node_name}) {type(msg).__name__}")
|
||||||
|
|
||||||
|
final_state = await app.aget_state(config)
|
||||||
|
final_msg = final_state.values["messages"][-1]
|
||||||
|
print(f"\n=== 完成 ===\n最終訊息: {final_msg.content[:300]}{'...' if len(final_msg.content) > 300 else ''}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"錯誤:{str(e)}")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
asyncio.run(async_main())
|
||||||
|
|||||||
@@ -1,10 +1,11 @@
|
|||||||
import operator
|
import operator
|
||||||
from typing import Annotated, Sequence, TypedDict, Union
|
from typing import Annotated, Sequence, TypedDict, Union, Optional
|
||||||
from langchain_core.messages import BaseMessage
|
from langchain_core.messages import BaseMessage
|
||||||
|
|
||||||
|
|
||||||
class AgentState(TypedDict):
|
class AgentState(TypedDict):
|
||||||
# messages 存储完整的对话历史,operator.add 表示新消息是追加而不是覆盖
|
# messages 存储完整的对话历史,operator.add 表示新消息是追加而不是覆盖
|
||||||
messages: Annotated[Sequence[BaseMessage], operator.add]
|
messages: Annotated[Sequence[BaseMessage], operator.add]
|
||||||
# next 存储 Supervisor 决定的下一步是谁
|
# next 存储 Supervisor 决定的下一步是谁
|
||||||
next: str
|
next: str
|
||||||
require_suggestion: bool # 是否需要建议按钮
|
require_suggestion: bool # 是否需要建议按钮
|
||||||
|
|||||||
121
src/server/agent/tools/crawl_tool.py
Normal file
121
src/server/agent/tools/crawl_tool.py
Normal file
@@ -0,0 +1,121 @@
|
|||||||
|
import time
|
||||||
|
import asyncio
|
||||||
|
from typing import List
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
|
# ─────────────── 重要:計算路徑 ───────────────
|
||||||
|
# 目前這個檔案 (crawl4ai_batch.py) 所在的目錄
|
||||||
|
TOOL_DIR = Path(__file__).resolve().parent
|
||||||
|
|
||||||
|
# 專案根目錄(假設 tools 資料夾與主程式同級)
|
||||||
|
PROJECT_ROOT = TOOL_DIR.parent
|
||||||
|
|
||||||
|
# 儲存爬取結果的目錄(你可以自由決定放在哪裡)
|
||||||
|
# 建議選項 A:放在專案根目錄下的 workspace/raw_data
|
||||||
|
SAVE_DIR = PROJECT_ROOT / "workspace" / "raw_data"
|
||||||
|
|
||||||
|
# 建議選項 B:如果你打算讓 deep agent 直接讀取,建議放在 agent_workspace 底下
|
||||||
|
# SAVE_DIR = PROJECT_ROOT / "agent_workspace" / "raw_data"
|
||||||
|
|
||||||
|
# 確保目錄存在
|
||||||
|
SAVE_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
|
||||||
|
# ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def crawl4ai_batch(urls: List[str]) -> str:
|
||||||
|
"""
|
||||||
|
高性能网页爬虫,支持并行处理多个 URL。
|
||||||
|
爬取后的 Markdown 内容将保存到本地 workspace/raw_data 目录中。
|
||||||
|
返回执行结果摘要和保存的文件路径列表。
|
||||||
|
"""
|
||||||
|
if not urls:
|
||||||
|
return "❌ 错误: 未提供任何 URL。"
|
||||||
|
|
||||||
|
# print(f"🕷️ 正在并行爬取 {len(urls)} 个 URL...")
|
||||||
|
# print(f"儲存目錄: {SAVE_DIR}")
|
||||||
|
|
||||||
|
# Crawl4AI 配置(保持原樣)
|
||||||
|
from crawl4ai import AsyncWebCrawler, BrowserConfig, CrawlerRunConfig, CacheMode
|
||||||
|
|
||||||
|
browser_config = BrowserConfig(
|
||||||
|
headless=True,
|
||||||
|
verbose=False,
|
||||||
|
java_script_enabled=True,
|
||||||
|
user_agent="Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
|
||||||
|
"AppleWebKit/537.36 (KHTML, like Gecko) "
|
||||||
|
"Chrome/118.0.5993.118 Safari/537.36",
|
||||||
|
proxy=None, # 可选,如果需要代理填 "http://user:pass@ip:port"
|
||||||
|
)
|
||||||
|
|
||||||
|
run_config = CrawlerRunConfig(
|
||||||
|
cache_mode=CacheMode.BYPASS,
|
||||||
|
word_count_threshold=5,
|
||||||
|
excluded_tags=["script", "style", "nav", "footer"],
|
||||||
|
remove_overlay_elements=True,
|
||||||
|
process_iframes=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
results_summary = []
|
||||||
|
saved_files = []
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with AsyncWebCrawler(config=browser_config) as crawler:
|
||||||
|
tasks = [crawler.arun(url=url, config=run_config) for url in urls]
|
||||||
|
crawl_results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
|
||||||
|
for i, result in enumerate(crawl_results):
|
||||||
|
url = urls[i]
|
||||||
|
|
||||||
|
if isinstance(result, Exception):
|
||||||
|
results_summary.append(f"❌ 抓取失败 {url}: {str(result)}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if result.success:
|
||||||
|
markdown_content = result.markdown or ""
|
||||||
|
|
||||||
|
if len(markdown_content) < 500:
|
||||||
|
results_summary.append(f"⏩ 跳过 {url} (内容过短)")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 生成檔名
|
||||||
|
parsed = urlparse(url)
|
||||||
|
domain = parsed.netloc.replace("www.", "").replace(".", "_")
|
||||||
|
path_part = parsed.path.strip("/").replace("/", "_")[:50] or "index"
|
||||||
|
filename = f"{int(time.time())}_{domain}_{path_part}.md"
|
||||||
|
|
||||||
|
# 完整檔案路徑
|
||||||
|
filepath = SAVE_DIR / filename
|
||||||
|
|
||||||
|
# 寫入檔案
|
||||||
|
with open(filepath, "w", encoding="utf-8") as f:
|
||||||
|
header = f"<!-- Source: {url} -->\n<!-- Saved: {time.strftime('%Y-%m-%d %H:%M:%S')} -->\n\n"
|
||||||
|
f.write(header + markdown_content)
|
||||||
|
|
||||||
|
saved_files.append(str(filepath)) # 建議轉成字串
|
||||||
|
results_summary.append(f"✅ 成功: {url} → {filepath}")
|
||||||
|
|
||||||
|
else:
|
||||||
|
status = getattr(result, 'status_code', '未知错误')
|
||||||
|
results_summary.append(f"❌ 失败: {url} (状态码: {status})")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return f"🚨 爬虫系统崩溃: {str(e)}"
|
||||||
|
|
||||||
|
# 回傳給 agent 的結果
|
||||||
|
final_output = (
|
||||||
|
f"### 批量抓取完成 ###\n"
|
||||||
|
f"已成功保存 {len(saved_files)} 个文件。\n"
|
||||||
|
f"儲存目錄: {SAVE_DIR}\n"
|
||||||
|
f"详情:\n" + "\n".join(results_summary)
|
||||||
|
)
|
||||||
|
|
||||||
|
if saved_files:
|
||||||
|
final_output += "\n\n已保存的文件列表(可供後續讀取):\n" + "\n".join(saved_files)
|
||||||
|
|
||||||
|
return final_output
|
||||||
@@ -1,11 +1,8 @@
|
|||||||
import base64
|
|
||||||
import uuid
|
import uuid
|
||||||
from google.oauth2 import service_account
|
from google.oauth2 import service_account
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
from google import genai
|
from google import genai
|
||||||
from google.genai.types import GenerateContentConfig, Modality
|
from google.genai.types import GenerateContentConfig, Modality
|
||||||
from PIL import Image
|
|
||||||
from io import BytesIO
|
|
||||||
|
|
||||||
from minio import Minio
|
from minio import Minio
|
||||||
|
|
||||||
@@ -27,21 +24,8 @@ client = genai.Client(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# --- 模拟你已经开发好的报告生成功能 ---
|
|
||||||
@tool
|
@tool
|
||||||
def generate_2025_report_tool(topic: str) -> str:
|
def generate_furniture(prompt: str) -> str:
|
||||||
"""
|
|
||||||
专门用于收集信息并生成报告。
|
|
||||||
当用户询问关于趋势、市场分析、年度报告(如2025家具报告)时调用此工具。
|
|
||||||
"""
|
|
||||||
print(f"\n[系统日志] 正在调用外部模块生成关于 '{topic}' 的报告...")
|
|
||||||
# 这里对接你实际的代码,比如:return my_existing_module.run(topic)
|
|
||||||
return f"【报告生成成功】已生成关于 {topic} 的 PDF 报告。核心洞察:2025年趋势倾向于生物嗜好设计(Biophilic Design)和可持续软木材质。"
|
|
||||||
|
|
||||||
|
|
||||||
# --- 2. 绘图工具 (接入 Nano Banana 逻辑) ---
|
|
||||||
@tool
|
|
||||||
def generate_furniture_sketch(prompt: str) -> str:
|
|
||||||
"""
|
"""
|
||||||
使用 Gemini 图像生成模型根据详细的英文提示词生成家具设计草图。
|
使用 Gemini 图像生成模型根据详细的英文提示词生成家具设计草图。
|
||||||
"""
|
"""
|
||||||
@@ -83,32 +67,3 @@ def generate_furniture_sketch(prompt: str) -> str:
|
|||||||
return "图片生成成功,但上传至存储服务器失败。"
|
return "图片生成成功,但上传至存储服务器失败。"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return f"绘图流程异常: {str(e)}"
|
return f"绘图流程异常: {str(e)}"
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
print(generate_furniture_sketch("椅子"))
|
|
||||||
# creds = service_account.Credentials.from_service_account_file(
|
|
||||||
# settings.GOOGLE_GENAI_USE_VERTEXAI,
|
|
||||||
# scopes=["https://www.googleapis.com/auth/cloud-platform"],
|
|
||||||
# )
|
|
||||||
# client = genai.Client(
|
|
||||||
# credentials=creds,
|
|
||||||
# project=settings.GOOGLE_CLOUD_PROJECT,
|
|
||||||
# location=settings.GOOGLE_CLOUD_LOCATION,
|
|
||||||
# vertexai=True
|
|
||||||
# )
|
|
||||||
#
|
|
||||||
# response = client.models.generate_content(
|
|
||||||
# model="gemini-2.5-flash-image",
|
|
||||||
# contents=("Generate an image of the Eiffel tower with fireworks in the background."),
|
|
||||||
# config=GenerateContentConfig(
|
|
||||||
# response_modalities=[Modality.TEXT, Modality.IMAGE],
|
|
||||||
# ),
|
|
||||||
# )
|
|
||||||
#
|
|
||||||
# for part in response.candidates[0].content.parts:
|
|
||||||
# if part.text:
|
|
||||||
# print(part.text)
|
|
||||||
# elif part.inline_data:
|
|
||||||
# image = Image.open(BytesIO((part.inline_data.data)))
|
|
||||||
# image.save("example-image-eiffel-tower.png")
|
|
||||||
36
src/server/agent/tools/read_file.py
Normal file
36
src/server/agent/tools/read_file.py
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
import os
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def read_file(file_path: str) -> str:
|
||||||
|
"""
|
||||||
|
读取本地文件的万能工具。支持绝对路径和相对路径。
|
||||||
|
"""
|
||||||
|
# 1. 极端清洗:去掉 Agent 可能误加的引号、空格或转义符
|
||||||
|
path = file_path.strip().strip("'").strip('"').replace("\\", "/")
|
||||||
|
|
||||||
|
# 2. 打印当前环境真相(在你的 Python 控制台可见)
|
||||||
|
print(f"\n--- 🛠️ READ_FILE 调试信息 ---")
|
||||||
|
print(f"待读路径: {path}")
|
||||||
|
print(f"当前工作目录 (CWD): {os.getcwd()}")
|
||||||
|
print(f"是否存在: {os.path.exists(path)}")
|
||||||
|
|
||||||
|
# 3. 尝试直接读取(跳过任何沙箱逻辑)
|
||||||
|
try:
|
||||||
|
# 如果是相对路径,尝试转为绝对路径再读
|
||||||
|
abs_path = os.path.abspath(path)
|
||||||
|
if os.path.exists(abs_path):
|
||||||
|
with open(abs_path, 'r', encoding='utf-8') as f:
|
||||||
|
content = f.read()
|
||||||
|
return content
|
||||||
|
else:
|
||||||
|
# 如果读不到,列出父目录内容作为线索
|
||||||
|
parent = os.path.dirname(abs_path)
|
||||||
|
if os.path.exists(parent):
|
||||||
|
files = os.listdir(parent)
|
||||||
|
return f"错误:文件不存在。该目录下现有的文件有: {files[:5]}..."
|
||||||
|
return f"错误:路径不存在,且连父目录 {parent} 都找不到。"
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return f"读取失败,系统异常: {str(e)}"
|
||||||
157
src/server/agent/tools/report_generator_tool.py
Normal file
157
src/server/agent/tools/report_generator_tool.py
Normal file
@@ -0,0 +1,157 @@
|
|||||||
|
import os
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
from typing import Optional, List, Dict
|
||||||
|
from langchain_qwq import ChatQwen
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
from langchain_core.messages import SystemMessage, HumanMessage
|
||||||
|
|
||||||
|
from src.core.config import settings
|
||||||
|
|
||||||
|
# =========================
|
||||||
|
# LLM 初始化
|
||||||
|
# =========================
|
||||||
|
|
||||||
|
|
||||||
|
llm = ChatQwen(
|
||||||
|
model="qwen3.5-flash",
|
||||||
|
temperature=0.2,
|
||||||
|
max_tokens=3_000,
|
||||||
|
timeout=None,
|
||||||
|
max_retries=2,
|
||||||
|
api_key=settings.QWEN_API_KEY)
|
||||||
|
|
||||||
|
|
||||||
|
# =========================
|
||||||
|
# Tool 输入 Schema
|
||||||
|
# =========================
|
||||||
|
|
||||||
|
class ReportInput(BaseModel):
|
||||||
|
report_topic: str = Field(
|
||||||
|
...,
|
||||||
|
description="Main topic of the report, e.g. '2026 Sofa Design Trends'"
|
||||||
|
)
|
||||||
|
structured_data: List[Dict] = Field(
|
||||||
|
...,
|
||||||
|
description="Structured retrieval result items"
|
||||||
|
)
|
||||||
|
language: Optional[str] = Field(
|
||||||
|
default="English",
|
||||||
|
description="Output language"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# =========================
|
||||||
|
# LangGraph Tool
|
||||||
|
# =========================
|
||||||
|
|
||||||
|
@tool("report_generator", args_schema=ReportInput)
|
||||||
|
async def report_generator(
|
||||||
|
report_topic: str,
|
||||||
|
structured_data: List[Dict],
|
||||||
|
language: str = "English"
|
||||||
|
) -> dict:
|
||||||
|
"""
|
||||||
|
Generate a professional design/market report
|
||||||
|
directly from structured retrieval results.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if not structured_data:
|
||||||
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"message": "No structured data provided."
|
||||||
|
}
|
||||||
|
|
||||||
|
collected_data_str = json.dumps(
|
||||||
|
structured_data,
|
||||||
|
ensure_ascii=False,
|
||||||
|
indent=2
|
||||||
|
)
|
||||||
|
|
||||||
|
# =========================
|
||||||
|
# Prompt
|
||||||
|
# =========================
|
||||||
|
|
||||||
|
system_prompt = f"""
|
||||||
|
You are a professional design trend analyst.
|
||||||
|
|
||||||
|
Generate a long, structured Markdown report.
|
||||||
|
|
||||||
|
REQUIREMENTS:
|
||||||
|
|
||||||
|
1. Follow MECE principle.
|
||||||
|
2. Embed images ONLY if they start with https://
|
||||||
|
using: 
|
||||||
|
3. Insert images inline.
|
||||||
|
4. Every key insight must cite source:
|
||||||
|
[Website Name](url)
|
||||||
|
5. Use Markdown headings.
|
||||||
|
6. Start directly with title.
|
||||||
|
7. Be detailed and analytical.
|
||||||
|
|
||||||
|
Output Language: {language}
|
||||||
|
"""
|
||||||
|
|
||||||
|
user_prompt = f"""
|
||||||
|
Topic: {report_topic}
|
||||||
|
|
||||||
|
Input Data:
|
||||||
|
{collected_data_str}
|
||||||
|
"""
|
||||||
|
|
||||||
|
# =========================
|
||||||
|
# 调用 LLM
|
||||||
|
# =========================
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await llm.ainvoke([
|
||||||
|
SystemMessage(content=system_prompt),
|
||||||
|
HumanMessage(content=user_prompt)
|
||||||
|
])
|
||||||
|
|
||||||
|
report_content = response.content.strip()
|
||||||
|
|
||||||
|
# 清理 markdown block 包裹
|
||||||
|
report_content = (
|
||||||
|
report_content
|
||||||
|
.replace("```markdown", "")
|
||||||
|
.replace("```", "")
|
||||||
|
.strip()
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"message": f"LLM generation failed: {str(e)}"
|
||||||
|
}
|
||||||
|
|
||||||
|
# =========================
|
||||||
|
# 保存报告
|
||||||
|
# =========================
|
||||||
|
|
||||||
|
output_dir = "workspace/reports"
|
||||||
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
|
||||||
|
safe_topic = re.sub(
|
||||||
|
r'[\\/*?:"<>|]',
|
||||||
|
"",
|
||||||
|
report_topic.replace(" ", "_")
|
||||||
|
)
|
||||||
|
|
||||||
|
filename = f"{output_dir}/{safe_topic}.md"
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(filename, "w", encoding="utf-8") as f:
|
||||||
|
f.write(report_content)
|
||||||
|
except Exception as e:
|
||||||
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"message": f"Failed to save report: {str(e)}"
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": "success",
|
||||||
|
"file_path": filename,
|
||||||
|
"message": "Report generated successfully."
|
||||||
|
}
|
||||||
74
src/server/agent/tools/research_tool.py
Normal file
74
src/server/agent/tools/research_tool.py
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import List, Set, Optional
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
from tavily import TavilyClient
|
||||||
|
|
||||||
|
from src.core.config import settings
|
||||||
|
|
||||||
|
# 模拟配置加载
|
||||||
|
TAVILY_API_KEY = settings.TAVILY_API_KEY
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def topic_research(topic: str, max_urls: int = 15) -> str:
|
||||||
|
"""
|
||||||
|
深度调研工具。该工具会利用 Tavily 搜索引擎针对特定主题进行多维度搜索。
|
||||||
|
它会自动生成针对性的搜索词(包含年份和趋势),并返回去重后的高质量 URL 列表。
|
||||||
|
"""
|
||||||
|
if not TAVILY_API_KEY:
|
||||||
|
return "❌ 错误: 未配置 TAVILY_API_KEY。"
|
||||||
|
|
||||||
|
client = TavilyClient(api_key=TAVILY_API_KEY)
|
||||||
|
|
||||||
|
# 1. 自动生成多维度搜索词 (在工具内部快速生成)
|
||||||
|
current_year = datetime.now().strftime("%Y")
|
||||||
|
queries = [
|
||||||
|
f"{topic} trends {current_year}",
|
||||||
|
f"{topic} market analysis {current_year}",
|
||||||
|
f"top selling {topic} styles {current_year}",
|
||||||
|
f"best {topic} materials and colors {current_year}"
|
||||||
|
]
|
||||||
|
|
||||||
|
# 2. 并行执行搜索
|
||||||
|
async def perform_search(q: str):
|
||||||
|
# 使用 asyncio.to_thread 运行同步的 Tavily SDK
|
||||||
|
def sync_search():
|
||||||
|
try:
|
||||||
|
response = client.search(
|
||||||
|
query=q,
|
||||||
|
search_depth="advanced",
|
||||||
|
max_results=5,
|
||||||
|
include_answer=False
|
||||||
|
)
|
||||||
|
return response.get('results', [])
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Search error: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
return await asyncio.to_thread(sync_search)
|
||||||
|
|
||||||
|
search_tasks = [perform_search(q) for q in queries]
|
||||||
|
search_results_list = await asyncio.gather(*search_tasks)
|
||||||
|
|
||||||
|
# 3. 结果去重与过滤
|
||||||
|
seen_urls: Set[str] = set()
|
||||||
|
final_urls = []
|
||||||
|
|
||||||
|
# 常见的非内容页面过滤
|
||||||
|
skip_extensions = ('.pdf', '.jpg', '.png', '.zip', '.exe')
|
||||||
|
|
||||||
|
for results in search_results_list:
|
||||||
|
for item in results:
|
||||||
|
url = item.get('url')
|
||||||
|
if url and url not in seen_urls:
|
||||||
|
if not url.lower().endswith(skip_extensions):
|
||||||
|
seen_urls.add(url)
|
||||||
|
final_urls.append(url)
|
||||||
|
|
||||||
|
# 4. 结果截断
|
||||||
|
selected_urls = final_urls[:max_urls]
|
||||||
|
|
||||||
|
# 返回 JSON 字符串,便于 Agent 下一步调用批量爬虫 (Crawl4ai)
|
||||||
|
return json.dumps(selected_urls, ensure_ascii=False)
|
||||||
27
src/server/agent/tools/save_to_local.py
Normal file
27
src/server/agent/tools/save_to_local.py
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
import os
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
|
# 定义本地保存路径
|
||||||
|
OUTPUT_DIR = "./research_reports"
|
||||||
|
if not os.path.exists(OUTPUT_DIR):
|
||||||
|
os.makedirs(OUTPUT_DIR)
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def save_to_local_disk(filename: str, content: str) -> str:
|
||||||
|
"""
|
||||||
|
将内容保存到本地物理磁盘。
|
||||||
|
filename: 文件名(例如 'sofa_report.md')
|
||||||
|
content: 调研报告或数据的文本内容
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 移除非法路径字符,确保安全
|
||||||
|
safe_filename = os.path.basename(filename)
|
||||||
|
file_path = os.path.join(OUTPUT_DIR, safe_filename)
|
||||||
|
|
||||||
|
with open(file_path, "w", encoding="utf-8") as f:
|
||||||
|
f.write(content)
|
||||||
|
|
||||||
|
return f"✅ 成功!文件已保存至本地物理路径: {os.path.abspath(file_path)}"
|
||||||
|
except Exception as e:
|
||||||
|
return f"❌ 保存失败,错误原因: {str(e)}"
|
||||||
225
src/server/agent/tools/structured_retrieval_tool.py
Normal file
225
src/server/agent/tools/structured_retrieval_tool.py
Normal file
@@ -0,0 +1,225 @@
|
|||||||
|
import os
|
||||||
|
import re
|
||||||
|
import json
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import List, Dict, Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
from langchain_core.documents import Document
|
||||||
|
|
||||||
|
# RAG
|
||||||
|
from langchain_community.vectorstores import FAISS
|
||||||
|
from langchain_huggingface import HuggingFaceEmbeddings
|
||||||
|
from sentence_transformers import CrossEncoder
|
||||||
|
|
||||||
|
# =========================
|
||||||
|
# 全局模型(单例)
|
||||||
|
# =========================
|
||||||
|
|
||||||
|
_EMBEDDING_MODEL = HuggingFaceEmbeddings(
|
||||||
|
model_name="sentence-transformers/all-MiniLM-L6-v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
_RERANK_MODEL = CrossEncoder(
|
||||||
|
"cross-encoder/ms-marco-MiniLM-L-6-v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class StructuredRetrievalInput(BaseModel):
|
||||||
|
file_paths: List[str] = Field(..., description="List of local markdown file paths.")
|
||||||
|
query: str = Field(..., description="Extraction query")
|
||||||
|
source_url: Optional[str] = Field(None, description="Optional global source URL")
|
||||||
|
|
||||||
|
|
||||||
|
@tool("structured_retrieval", args_schema=StructuredRetrievalInput)
|
||||||
|
def structured_retrieval(
|
||||||
|
file_paths: List[str],
|
||||||
|
query: str,
|
||||||
|
source_url: Optional[str] = None
|
||||||
|
) -> Dict:
|
||||||
|
"""
|
||||||
|
Batch structured extraction from markdown files.
|
||||||
|
- Performs vector search + re-ranking
|
||||||
|
- Saves extracted structured data as JSON file to disk
|
||||||
|
- Returns ONLY summary (status, count, file path)
|
||||||
|
"""
|
||||||
|
|
||||||
|
# ── 1. 收集所有文件內容 ──────────────────────────────────────
|
||||||
|
all_docs_pool: List[Document] = []
|
||||||
|
|
||||||
|
for path in file_paths:
|
||||||
|
if not os.path.exists(path) or not path.endswith((".md", ".markdown")):
|
||||||
|
continue
|
||||||
|
|
||||||
|
file_name = os.path.basename(path)
|
||||||
|
|
||||||
|
with open(path, "r", encoding="utf-8") as f:
|
||||||
|
content = f.read()
|
||||||
|
|
||||||
|
current_source = source_url or _extract_source_from_md(content) or "unknown"
|
||||||
|
|
||||||
|
sections = _split_markdown_by_headers(content)
|
||||||
|
|
||||||
|
for sec in sections:
|
||||||
|
all_docs_pool.append(
|
||||||
|
Document(
|
||||||
|
page_content=sec,
|
||||||
|
metadata={"source_url": current_source, "file_name": file_name}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if not all_docs_pool:
|
||||||
|
return {"status": "no_documents_found", "items_count": 0, "json_path": None}
|
||||||
|
|
||||||
|
# ── 2. Vector search ────────────────────────────────────────────
|
||||||
|
vector_store = FAISS.from_documents(all_docs_pool, _EMBEDDING_MODEL)
|
||||||
|
retrieved = vector_store.similarity_search(query, k=200)
|
||||||
|
|
||||||
|
# ── 3. 提取結構化片段 ──────────────────────────────────────────
|
||||||
|
structured_items = []
|
||||||
|
|
||||||
|
for doc in retrieved:
|
||||||
|
text = doc.page_content.strip()
|
||||||
|
if len(text) < 30:
|
||||||
|
continue
|
||||||
|
|
||||||
|
images = list(set(re.findall(r"!\[.*?\]\((.*?)\)", text)))
|
||||||
|
|
||||||
|
structured_items.append(
|
||||||
|
{
|
||||||
|
"text": text,
|
||||||
|
"images": images,
|
||||||
|
"source_url": doc.metadata.get("source_url"),
|
||||||
|
"file_name": doc.metadata.get("file_name")
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── 4. Re-rank ──────────────────────────────────────────────────
|
||||||
|
if structured_items:
|
||||||
|
unique_items = {item["text"]: item for item in structured_items}.values()
|
||||||
|
pairs = [[query, item["text"]] for item in unique_items]
|
||||||
|
scores = _RERANK_MODEL.predict(pairs)
|
||||||
|
|
||||||
|
sorted_items = sorted(
|
||||||
|
zip(scores, unique_items),
|
||||||
|
key=lambda x: x[0],
|
||||||
|
reverse=True
|
||||||
|
)
|
||||||
|
top_items = [item for _, item in sorted_items[:50]]
|
||||||
|
else:
|
||||||
|
top_items = []
|
||||||
|
|
||||||
|
# ── 5. 寫入 JSON 文件 ──────────────────────────────────────────
|
||||||
|
if not top_items:
|
||||||
|
return {"status": "no_relevant_content", "items_count": 0, "json_path": None}
|
||||||
|
|
||||||
|
# 產生有意義的檔名
|
||||||
|
safe_query = re.sub(r'[^a-zA-Z0-9\u4e00-\u9fa5]', '_', query)[:40]
|
||||||
|
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||||
|
json_filename = f"extracted_{safe_query}_{timestamp}.json"
|
||||||
|
|
||||||
|
# 建議的儲存目錄(與 crawl4ai_batch 對齊)
|
||||||
|
output_dir = os.path.join(os.path.dirname(file_paths[0]), "..", "extracted")
|
||||||
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
|
||||||
|
json_path = os.path.join(output_dir, json_filename)
|
||||||
|
|
||||||
|
with open(json_path, "w", encoding="utf-8") as f:
|
||||||
|
json.dump(
|
||||||
|
{
|
||||||
|
"query": query,
|
||||||
|
"extracted_at": timestamp,
|
||||||
|
"item_count": len(top_items),
|
||||||
|
"items": top_items
|
||||||
|
},
|
||||||
|
f,
|
||||||
|
ensure_ascii=False,
|
||||||
|
indent=2
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── 6. 只回傳摘要 ──────────────────────────────────────────────
|
||||||
|
return {
|
||||||
|
"status": "success",
|
||||||
|
"items_count": len(top_items),
|
||||||
|
"json_path": json_path,
|
||||||
|
"summary": f"已提取 {len(top_items)} 個高相關片段,儲存於 {json_path}"
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_source_from_md(content: str) -> Optional[str]:
|
||||||
|
match = re.search(r"<!--\s*Source:\s*(.*?)\s*-->", content)
|
||||||
|
return match.group(1).strip() if match else None
|
||||||
|
|
||||||
|
|
||||||
|
# =========================
|
||||||
|
# Markdown Header Split
|
||||||
|
# =========================
|
||||||
|
|
||||||
|
def _split_markdown_by_headers(
|
||||||
|
content: str,
|
||||||
|
max_chars: int = 2000,
|
||||||
|
overlap: int = 150,
|
||||||
|
):
|
||||||
|
header_re = re.compile(
|
||||||
|
r'^(#{1,6})\s+(.+?)\s*$',
|
||||||
|
re.MULTILINE
|
||||||
|
)
|
||||||
|
|
||||||
|
matches = list(header_re.finditer(content))
|
||||||
|
|
||||||
|
if not matches:
|
||||||
|
return _chunk_text(content, max_chars, overlap)
|
||||||
|
|
||||||
|
sections = []
|
||||||
|
|
||||||
|
for i, m in enumerate(matches):
|
||||||
|
start = m.start()
|
||||||
|
end = (
|
||||||
|
matches[i + 1].start()
|
||||||
|
if i + 1 < len(matches)
|
||||||
|
else len(content)
|
||||||
|
)
|
||||||
|
|
||||||
|
block = content[start:end].strip()
|
||||||
|
if block:
|
||||||
|
sections.append(block)
|
||||||
|
|
||||||
|
final_sections = []
|
||||||
|
|
||||||
|
for s in sections:
|
||||||
|
if len(s) > max_chars:
|
||||||
|
final_sections.extend(
|
||||||
|
_chunk_text(s, max_chars, overlap)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
final_sections.append(s)
|
||||||
|
|
||||||
|
return final_sections
|
||||||
|
|
||||||
|
|
||||||
|
def _chunk_text(
|
||||||
|
text: str,
|
||||||
|
max_chars: int = 2000,
|
||||||
|
overlap: int = 150
|
||||||
|
):
|
||||||
|
text = text.strip()
|
||||||
|
if len(text) <= max_chars:
|
||||||
|
return [text]
|
||||||
|
|
||||||
|
chunks = []
|
||||||
|
start = 0
|
||||||
|
|
||||||
|
while start < len(text):
|
||||||
|
end = min(len(text), start + max_chars)
|
||||||
|
chunk = text[start:end].strip()
|
||||||
|
|
||||||
|
if chunk:
|
||||||
|
chunks.append(chunk)
|
||||||
|
|
||||||
|
if end == len(text):
|
||||||
|
break
|
||||||
|
|
||||||
|
start = max(0, end - overlap)
|
||||||
|
|
||||||
|
return chunks
|
||||||
38
src/server/agent/tools/terminate_tool.py
Normal file
38
src/server/agent/tools/terminate_tool.py
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
from typing import Literal
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class TerminateInput(BaseModel):
|
||||||
|
"""終止對話的輸入參數"""
|
||||||
|
status: Literal["success", "failure"] = Field(
|
||||||
|
description="互動結束的狀態:'success' 表示任務完成,'failure' 表示無法繼續",
|
||||||
|
examples=["success", "failure"]
|
||||||
|
)
|
||||||
|
reason: str = Field(
|
||||||
|
default="",
|
||||||
|
description="可選:簡單說明為什麼結束(例如 '報告已生成' 或 '缺少關鍵資訊')",
|
||||||
|
examples=["報告已成功生成", "無法取得足夠資料"]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@tool(args_schema=TerminateInput)
|
||||||
|
def terminate(status: str, reason: str = "") -> str:
|
||||||
|
"""
|
||||||
|
當任務完成、報告已生成,或無法繼續進行時,呼叫此工具來結束本次互動。
|
||||||
|
|
||||||
|
使用時機:
|
||||||
|
- 已經成功產生最終報告(report_generator 已完成)
|
||||||
|
- 遇到無法解決的錯誤或缺少關鍵資訊
|
||||||
|
- 用戶需求已完全滿足
|
||||||
|
|
||||||
|
請在呼叫前確保所有必要步驟已完成,並在 reason 中簡單說明結束原因。
|
||||||
|
"""
|
||||||
|
if status not in ("success", "failure"):
|
||||||
|
status = "failure" # 防呆
|
||||||
|
|
||||||
|
msg = f"互動已終止,狀態:{status.upper()}"
|
||||||
|
if reason:
|
||||||
|
msg += f"\n原因:{reason}"
|
||||||
|
|
||||||
|
return msg
|
||||||
96
src/server/agent/tools/user_persona_tool.py
Normal file
96
src/server/agent/tools/user_persona_tool.py
Normal file
@@ -0,0 +1,96 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
from typing import List, Literal, Optional, Dict, Any
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
|
# 定义存储路径
|
||||||
|
DB_PATH = os.path.join("workspace", "user_persona.json")
|
||||||
|
|
||||||
|
|
||||||
|
def _load_store() -> Dict[str, Any]:
|
||||||
|
"""从本地文件加载画像数据"""
|
||||||
|
if os.path.exists(DB_PATH):
|
||||||
|
try:
|
||||||
|
with open(DB_PATH, "r", encoding="utf-8") as f:
|
||||||
|
return json.load(f)
|
||||||
|
except Exception:
|
||||||
|
return {}
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
def _save_store(data: Dict[str, Any]):
|
||||||
|
"""将画像数据保存到本地文件"""
|
||||||
|
os.makedirs(os.path.dirname(DB_PATH), exist_ok=True)
|
||||||
|
with open(DB_PATH, "w", encoding="utf-8") as f:
|
||||||
|
json.dump(data, f, ensure_ascii=False, indent=2)
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def manage_user_persona(
|
||||||
|
command: Literal["set", "update", "get", "clear"],
|
||||||
|
design_type: Optional[str] = None,
|
||||||
|
style_preference: Optional[str] = None,
|
||||||
|
budget_range: Optional[str] = None,
|
||||||
|
color_palette: Optional[List[str]] = None,
|
||||||
|
target_audience: Optional[str] = None,
|
||||||
|
extra_requirements: Optional[str] = None
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
用户画像与设计偏好管理工具。
|
||||||
|
用于设定、更新、获取或重置用户的设计上下文(如风格、预算、颜色)。
|
||||||
|
Agent 在开始调研前必须先调用 get 获取画像,若关键信息缺失需引导用户补充。
|
||||||
|
"""
|
||||||
|
# 每次调用都重新读取,确保多进程或重启后数据一致
|
||||||
|
store = _load_store()
|
||||||
|
|
||||||
|
if command == "clear":
|
||||||
|
if os.path.exists(DB_PATH):
|
||||||
|
os.remove(DB_PATH)
|
||||||
|
return "✅ 用户个性化模板已从本地文件清空。"
|
||||||
|
|
||||||
|
if command == "get":
|
||||||
|
if not store:
|
||||||
|
return "⚠️ [缺失信息] 当前尚未配置画像。请询问用户:设计类型(如沙发)、风格偏好(如极简)等。"
|
||||||
|
|
||||||
|
# 格式化输出供 Agent 阅读
|
||||||
|
res = [
|
||||||
|
"--- 👤 实时用户画像 (本地存储) ---",
|
||||||
|
f"🎯 类型: {store.get('design_type', '未设定')}",
|
||||||
|
f"🎨 风格: {store.get('style_preference', '未设定')}",
|
||||||
|
f"💰 预算: {store.get('budget_range', '未设定')}",
|
||||||
|
f"🌈 色系: {', '.join(store.get('color_palette', [])) or '未设定'}",
|
||||||
|
f"👥 受众: {store.get('target_audience', '未设定')}",
|
||||||
|
f"📝 需求: {store.get('extra_requirements', '未设定')}",
|
||||||
|
"-----------------------"
|
||||||
|
]
|
||||||
|
|
||||||
|
# 逻辑检查
|
||||||
|
if not store.get('design_type') or not store.get('style_preference'):
|
||||||
|
res.append("\n⚠️ 关键信息缺失,建议补充 '设计类型' 和 '风格偏好'。")
|
||||||
|
return "\n".join(res)
|
||||||
|
|
||||||
|
if command in ["set", "update"]:
|
||||||
|
if command == "set":
|
||||||
|
store = {} # 重置内存中的字典
|
||||||
|
|
||||||
|
# 提取传入的非空参数
|
||||||
|
update_data = {
|
||||||
|
"design_type": design_type,
|
||||||
|
"style_preference": style_preference,
|
||||||
|
"budget_range": budget_range,
|
||||||
|
"color_palette": color_palette,
|
||||||
|
"target_audience": target_audience,
|
||||||
|
"extra_requirements": extra_requirements
|
||||||
|
}
|
||||||
|
|
||||||
|
# 更新有效字段
|
||||||
|
for k, v in update_data.items():
|
||||||
|
if v is not None:
|
||||||
|
store[k] = v
|
||||||
|
|
||||||
|
# 保存到文件
|
||||||
|
_save_store(store)
|
||||||
|
|
||||||
|
return f"✅ 本地画像已同步。当前配置:\n{json.dumps(store, ensure_ascii=False, indent=2)}"
|
||||||
|
|
||||||
|
return "❌ 错误:未知命令。"
|
||||||
Reference in New Issue
Block a user