Compare commits
99 Commits
1ecb02d706
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
| f55522b135 | |||
| 3a8401c5f4 | |||
| e75b92e112 | |||
| 44ce5d0786 | |||
| 316d027aab | |||
| 4ecc409ab5 | |||
| a8f52dbdaa | |||
| fa47573a94 | |||
| b4c1b5169b | |||
| 536be96d10 | |||
| cbee81ee44 | |||
| dcf29a3b84 | |||
| 2c075a3871 | |||
| f3db0290af | |||
| 60c669a10e | |||
| 8fc93077fc | |||
| 22d9ef0e1f | |||
| 88e2c8cfb8 | |||
| 03cf977087 | |||
| 70ddf97484 | |||
| 417c6f01b5 | |||
| d328344edb | |||
| 4dd4e8ac4f | |||
| 3840a325a3 | |||
| 1afc431ee2 | |||
| cb836f5108 | |||
| c03488049e | |||
| 6f4a0cc80c | |||
| 85e75cd43f | |||
| e1419676fb | |||
| a9ffb5a446 | |||
| db8014f024 | |||
| 3dff7876e5 | |||
| 764b7ba063 | |||
| e969c407f7 | |||
| 8f3fb0f584 | |||
| 21c5b4872a | |||
| 63b4b932c7 | |||
| d27fbc969d | |||
| 8807fb6100 | |||
| e40048bf5d | |||
| 702f48626f | |||
| 79258a6a43 | |||
| 34cf3456cd | |||
| 108fa0fb8c | |||
| 3f34bb005c | |||
| 0890241cb1 | |||
| f957ded215 | |||
| 1c8283334e | |||
| f35753954e | |||
| ac87ca8126 | |||
| 35ad8f69e8 | |||
| 5bc27d4d52 | |||
| d6836fefc2 | |||
| ed9406732d | |||
| e3cf22edae | |||
| 1579c8d0f5 | |||
| 75b888eb37 | |||
| d9acdf593d | |||
| 1c672afd2d | |||
| 48ef18295f | |||
| bac64f0ef1 | |||
| 7a4426bc5d | |||
| 0ec38e2623 | |||
| 7c23d16ea6 | |||
| d5ef985e52 | |||
| 42f322ec34 | |||
| 3c2b2d9f4a | |||
| 56a410413e | |||
| adc7e70c1f | |||
| 8e65682dba | |||
| ac8a5e5a30 | |||
| b6ca7ae6ed | |||
| d05ef7b3c2 | |||
| dbf98e526c | |||
| affd4db6f0 | |||
| 4728a44ca5 | |||
| 91688b1686 | |||
| 6b4b6fd37c | |||
| 25abdfa38a | |||
| 510a5117ee | |||
| a6393df0e3 | |||
| 7042d428fa | |||
| c862121b48 | |||
| d66a870207 | |||
| 5106875618 | |||
| 5951205ac9 | |||
| a5ef8cfbd9 | |||
| 1759744c8b | |||
| 048a9979ed | |||
| 9d1d5f5078 | |||
| a5929431dd | |||
| 9b1b649153 | |||
| ab012f63a7 | |||
| 972c3803a7 | |||
| 62b4505261 | |||
| 129510363a | |||
| 6648541556 | |||
| 1ade907828 |
4
.gitea/workflows/prod_build_manual.yaml
Normal file → Executable file
4
.gitea/workflows/prod_build_manual.yaml
Normal file → Executable file
@@ -7,7 +7,7 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
env:
|
env:
|
||||||
REMOTE_DEPLOY_PATH: /workspace/FiDA_Workspace/Python_Server_Workspace/Prod
|
REMOTE_DEPLOY_PATH: /workspace/FiDA_Workspace/Python_Server_Workspace/FiDA_Prod
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: 1.检出代码
|
- name: 1.检出代码
|
||||||
@@ -36,5 +36,3 @@ jobs:
|
|||||||
|
|
||||||
docker-compose down 2>&1
|
docker-compose down 2>&1
|
||||||
docker-compose up -d 2>&1
|
docker-compose up -d 2>&1
|
||||||
|
|
||||||
docker image prune -f 2>&1
|
|
||||||
@@ -1,39 +0,0 @@
|
|||||||
name: 手动 FiDA python prod 分支构建部署
|
|
||||||
on:
|
|
||||||
workflow_dispatch:
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
scheduled_deploy:
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
|
|
||||||
env:
|
|
||||||
REMOTE_DEPLOY_PATH: /mnt/process/A6000_Server/FiDA_Workspace/Python_Server_Workspace/Prod
|
|
||||||
|
|
||||||
steps:
|
|
||||||
- name: 1.检出代码
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
with:
|
|
||||||
ref: 'main'
|
|
||||||
|
|
||||||
- name: 2.复制文件到服务器
|
|
||||||
uses: appleboy/scp-action@v0.1.7
|
|
||||||
with:
|
|
||||||
host: ${{ secrets.SERVER_HOST_H200 }}
|
|
||||||
username: ${{ secrets.SERVER_USER_H200 }}
|
|
||||||
password: ${{ secrets.SERVER_PASSWORD_H200 }}
|
|
||||||
source: "."
|
|
||||||
target: ${{ env.REMOTE_DEPLOY_PATH }}
|
|
||||||
|
|
||||||
- name: Restart Docker containers
|
|
||||||
uses: appleboy/ssh-action@v0.1.10
|
|
||||||
with:
|
|
||||||
host: ${{ secrets.SERVER_HOST_H200 }}
|
|
||||||
username: ${{ secrets.SERVER_USER_H200 }}
|
|
||||||
password: ${{ secrets.SERVER_PASSWORD_H200 }}
|
|
||||||
script: |
|
|
||||||
# 进入项目目录
|
|
||||||
cd ${{ env.REMOTE_DEPLOY_PATH }}
|
|
||||||
|
|
||||||
docker compose down 2>&1
|
|
||||||
docker compose up -d 2>&1
|
|
||||||
docker compose ps 2>&1
|
|
||||||
3
.gitignore
vendored
Normal file → Executable file
3
.gitignore
vendored
Normal file → Executable file
@@ -147,3 +147,6 @@ app/logs/*
|
|||||||
*.json
|
*.json
|
||||||
*.env*
|
*.env*
|
||||||
config.backup.py
|
config.backup.py
|
||||||
|
*.md
|
||||||
|
.langgraph_api
|
||||||
|
.jbeval
|
||||||
8
Dockerfile
Normal file → Executable file
8
Dockerfile
Normal file → Executable file
@@ -17,13 +17,13 @@ ENV PATH="/env_build/.venv/bin:$PATH"
|
|||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
|
|
||||||
# 更新索引并安装替代包
|
# 更新索引并安装替代包
|
||||||
RUN apt-get update && apt-get install -y vim
|
RUN apt-get update && apt-get install -y vim libxcb-shm0 libx11-xcb1 libx11-6 libxcb1 libxext6 libxrandr2 libxcomposite1 libxcursor1 libxdamage1 libxfixes3 libxi6 libgtk-3-0t64 libpangocairo-1.0-0 libpango-1.0-0 libatk1.0-0t64 libcairo-gobject2 libcairo2 libgdk-pixbuf-2.0-0 libglib2.0-0t64 libxrender1 libasound2t64 libfreetype6 libfontconfig1 libdbus-1-3 libnss3 libgbm1 libatspi2.0-0 libcups2 libdrm2 libxkbcommon0
|
||||||
|
RUN playwright install
|
||||||
|
|
||||||
|
|
||||||
#CMD ["tail","-f","/dev/null"]
|
#CMD ["tail","-f","/dev/null"]
|
||||||
# Run the application.
|
# Run the application.
|
||||||
CMD ["gunicorn", "main:app_server", \
|
CMD ["gunicorn", "main:app_server", \
|
||||||
"-w", "4", \
|
"-c", "/app/gunicorn.conf.py", \
|
||||||
"-k", "uvicorn.workers.UvicornWorker", \
|
|
||||||
"--bind", "0.0.0.0:80", \
|
|
||||||
"--access-logfile", "-", \
|
"--access-logfile", "-", \
|
||||||
"--error-logfile", "-"]
|
"--error-logfile", "-"]
|
||||||
35
config.yaml
35
config.yaml
@@ -1,35 +0,0 @@
|
|||||||
# 配置示例:模型与 Agent 提示词模板
|
|
||||||
model:
|
|
||||||
name: "gemini-mini"
|
|
||||||
temperature: 0.2
|
|
||||||
max_tokens: 1024
|
|
||||||
agents:
|
|
||||||
supervisor:
|
|
||||||
prompt_template: |
|
|
||||||
你是家具设计团队的主管(Supervisor)。
|
|
||||||
请根据用户的意图,选择最合适的专家:
|
|
||||||
- Designer: 设计建议、参数细化、闲聊、问候。
|
|
||||||
- Visualizer: 绘图、看草图。
|
|
||||||
- Researcher: 市场报告、趋势。
|
|
||||||
|
|
||||||
只需输出专家名称。
|
|
||||||
|
|
||||||
designer:
|
|
||||||
prompt_template: |
|
|
||||||
你是一位资深的家具设计师。你的职责是:
|
|
||||||
1. 从用户的模糊描述中提取或补充具体的设计参数(尺寸、材质、人体工学数据)。
|
|
||||||
2. 如果用户想画图,不要直接画,而是先描述清楚细节,然后让 Visualizer 去画。
|
|
||||||
请以专业的口吻回复。
|
|
||||||
|
|
||||||
visualizer:
|
|
||||||
prompt_template: |
|
|
||||||
你是视觉专家。你的目标是生成高质量的家具草图。
|
|
||||||
步骤:
|
|
||||||
1. 根据上下文,编写一个详细的 Stable Diffusion 风格的英文 Prompt。
|
|
||||||
2. 必须调用 generate_furniture_sketch 工具来生成图片。
|
|
||||||
|
|
||||||
注意:如果对话中出现 [SYSTEM_DIRECTIVE] 要求直接绘图,请立即根据已知信息编写 Prompt 并调用 generate_furniture_sketch 工具,不要进行多余的询问。
|
|
||||||
|
|
||||||
researcher:
|
|
||||||
prompt_template: |
|
|
||||||
你是情报专家,负责检索与整理参考资料并生成报告。
|
|
||||||
52
docker-compose.yml
Normal file → Executable file
52
docker-compose.yml
Normal file → Executable file
@@ -1,4 +1,5 @@
|
|||||||
name: fida-python-prod
|
name: fida-python-prod
|
||||||
|
|
||||||
services:
|
services:
|
||||||
server:
|
server:
|
||||||
container_name: "FiDA_${SERVE_ENV}_Server"
|
container_name: "FiDA_${SERVE_ENV}_Server"
|
||||||
@@ -12,3 +13,54 @@ services:
|
|||||||
- /etc/localtime:/etc/localtime:ro
|
- /etc/localtime:/etc/localtime:ro
|
||||||
ports:
|
ports:
|
||||||
- "${SERVE_PORT}:80"
|
- "${SERVE_PORT}:80"
|
||||||
|
environment:
|
||||||
|
- SERVE_ENV=${SERVE_ENV}
|
||||||
|
restart: unless-stopped
|
||||||
|
networks:
|
||||||
|
- fida_app_net
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== Celery Worker(单个 Worker 同时处理两个任务) ====================
|
||||||
|
celery_worker:
|
||||||
|
container_name: "FiDA_${SERVE_ENV}_CeleryWorker"
|
||||||
|
build:
|
||||||
|
context: .
|
||||||
|
dockerfile: Dockerfile
|
||||||
|
working_dir: /app
|
||||||
|
command: >
|
||||||
|
celery -A src.server.canvas_generate_3D.celery_app worker -n celery_worker@%h -Q img_to_3d_queue,three_d_to_3views_queue --concurrency=1 --prefetch-multiplier=1 --max-tasks-per-child=1 --loglevel=INFO
|
||||||
|
volumes:
|
||||||
|
- ./:/app
|
||||||
|
- ./.env:/app/.env
|
||||||
|
- /etc/localtime:/etc/localtime:ro
|
||||||
|
environment:
|
||||||
|
- SERVE_ENV=${SERVE_ENV}
|
||||||
|
depends_on:
|
||||||
|
- server
|
||||||
|
restart: unless-stopped
|
||||||
|
networks:
|
||||||
|
- fida_app_net
|
||||||
|
|
||||||
|
|
||||||
|
networks:
|
||||||
|
fida_app_net: # 这个名称就是你在 services 中引用的网络
|
||||||
|
external: true
|
||||||
|
name: fida_app_net # 实际创建的网络名称(不带项目名前缀)
|
||||||
|
# ==================== 可选:RabbitMQ(如果你想把 RabbitMQ 也纳入 docker-compose 管理) ====================
|
||||||
|
# rabbitmq:
|
||||||
|
# image: rabbitmq:3.13-management
|
||||||
|
# container_name: "FiDA_${SERVE_ENV}_RabbitMQ"
|
||||||
|
# ports:
|
||||||
|
# - "5672:5672"
|
||||||
|
# - "15672:15672"
|
||||||
|
# environment:
|
||||||
|
# RABBITMQ_DEFAULT_USER: guest
|
||||||
|
# RABBITMQ_DEFAULT_PASS: guest
|
||||||
|
# volumes:
|
||||||
|
# - rabbitmq_data:/var/lib/rabbitmq
|
||||||
|
# restart: unless-stopped
|
||||||
|
|
||||||
|
# volumes:
|
||||||
|
# rabbitmq_data:
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
23
gunicorn.conf.py
Executable file
23
gunicorn.conf.py
Executable file
@@ -0,0 +1,23 @@
|
|||||||
|
import multiprocessing
|
||||||
|
import os
|
||||||
|
|
||||||
|
# 基础配置
|
||||||
|
bind = "0.0.0.0:80"
|
||||||
|
worker_class = "uvicorn.workers.UvicornWorker"
|
||||||
|
loglevel = "info"
|
||||||
|
accesslog = "-"
|
||||||
|
errorlog = "-"
|
||||||
|
|
||||||
|
# 关键生产参数
|
||||||
|
workers = 2 # 先用 2 个(ML 场景推荐 1~4,根据 CPU 核数和内存调整)
|
||||||
|
timeout = 300 # 5 分钟,足够模型加载和慢推理
|
||||||
|
graceful_timeout = 300
|
||||||
|
preload_app = True # ★★★ 必须加!模型只加载一次,内存大幅节省
|
||||||
|
|
||||||
|
# 防止内存泄漏(ML 服务常见问题)
|
||||||
|
max_requests = 1000
|
||||||
|
max_requests_jitter = 100
|
||||||
|
|
||||||
|
# 其他优化
|
||||||
|
keepalive = 5
|
||||||
|
worker_connections = 1000
|
||||||
56
logging_env.py
Executable file
56
logging_env.py
Executable file
@@ -0,0 +1,56 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
from src.core.config import settings
|
||||||
|
|
||||||
|
LOGGER_CONFIG_DICT = {
|
||||||
|
'version': 1,
|
||||||
|
'disable_existing_loggers': False,
|
||||||
|
'formatters': {
|
||||||
|
'simple': {
|
||||||
|
'format': '%(asctime)s %(filename)s [line:%(lineno)d] %(levelname)s %(message)s',
|
||||||
|
'datefmt': '%Y-%m-%d %H:%M:%S' # 补充日期格式,日志更易读
|
||||||
|
}
|
||||||
|
},
|
||||||
|
'handlers': {
|
||||||
|
'console': {
|
||||||
|
'class': 'logging.StreamHandler',
|
||||||
|
'level': 'INFO',
|
||||||
|
'formatter': 'simple',
|
||||||
|
'stream': 'ext://sys.stdout',
|
||||||
|
},
|
||||||
|
'info_file_handler': {
|
||||||
|
'class': 'logging.handlers.RotatingFileHandler',
|
||||||
|
'level': 'INFO',
|
||||||
|
'formatter': 'simple',
|
||||||
|
'filename': os.path.join(settings.LOGS_PATH, 'info.log'),
|
||||||
|
'maxBytes': 10485760,
|
||||||
|
'backupCount': 50,
|
||||||
|
'encoding': 'utf8',
|
||||||
|
},
|
||||||
|
'error_file_handler': {
|
||||||
|
'class': 'logging.handlers.RotatingFileHandler',
|
||||||
|
'level': 'ERROR',
|
||||||
|
'formatter': 'simple',
|
||||||
|
'filename': os.path.join(settings.LOGS_PATH, 'error.log'),
|
||||||
|
'maxBytes': 10485760,
|
||||||
|
'backupCount': 20,
|
||||||
|
'encoding': 'utf8',
|
||||||
|
},
|
||||||
|
'debug_file_handler': {
|
||||||
|
'class': 'logging.handlers.RotatingFileHandler',
|
||||||
|
'level': 'DEBUG',
|
||||||
|
'formatter': 'simple',
|
||||||
|
'filename': os.path.join(settings.LOGS_PATH, 'debug.log'),
|
||||||
|
'maxBytes': 10485760,
|
||||||
|
'backupCount': 50,
|
||||||
|
'encoding': 'utf8',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
'loggers': {
|
||||||
|
'my_module': {'level': 'INFO', 'handlers': ['console'], 'propagate': 'no'}
|
||||||
|
},
|
||||||
|
'root': {
|
||||||
|
'level': 'DEBUG',
|
||||||
|
'handlers': ['error_file_handler', 'info_file_handler', 'debug_file_handler', 'console'],
|
||||||
|
},
|
||||||
|
}
|
||||||
20
main.py
Normal file → Executable file
20
main.py
Normal file → Executable file
@@ -1,7 +1,17 @@
|
|||||||
|
import logging.config
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
import uvicorn
|
import uvicorn
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from src.routers import chat
|
from logging_env import LOGGER_CONFIG_DICT
|
||||||
|
from src.routers import deep_agent_chat
|
||||||
|
from src.routers import generate_3D
|
||||||
|
from src.routers import flux2_gen_img
|
||||||
|
from src.routers import seg_furniture
|
||||||
|
from src.routers import canvas_assistant
|
||||||
|
|
||||||
|
logging.config.dictConfig(LOGGER_CONFIG_DICT)
|
||||||
|
|
||||||
app_server = FastAPI(
|
app_server = FastAPI(
|
||||||
title="Gemini Furniture Designer API",
|
title="Gemini Furniture Designer API",
|
||||||
@@ -18,7 +28,11 @@ app_server.add_middleware(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 包含路由
|
# 包含路由
|
||||||
app_server.include_router(chat.router)
|
app_server.include_router(deep_agent_chat.router)
|
||||||
|
app_server.include_router(generate_3D.router)
|
||||||
|
app_server.include_router(flux2_gen_img.router)
|
||||||
|
app_server.include_router(seg_furniture.router)
|
||||||
|
app_server.include_router(canvas_assistant.router)
|
||||||
|
|
||||||
|
|
||||||
@app_server.get("/")
|
@app_server.get("/")
|
||||||
@@ -27,4 +41,4 @@ async def root():
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
uvicorn.run("main:app_server", host="0.0.0.0", port=7777, reload=True)
|
uvicorn.run("main:app_server", host="0.0.0.0", port=7777, reload=False)
|
||||||
|
|||||||
51
pyproject.toml
Normal file → Executable file
51
pyproject.toml
Normal file → Executable file
@@ -1,23 +1,70 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "FiDA"
|
name = "fida"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
description = "Add your description here"
|
description = "Add your description here"
|
||||||
|
readme = "README.md"
|
||||||
requires-python = ">=3.12"
|
requires-python = ">=3.12"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
"crawl4ai>=0.8.0",
|
||||||
|
"deepagents>=0.4.3",
|
||||||
"fastapi[standard]>=0.128.0",
|
"fastapi[standard]>=0.128.0",
|
||||||
"gunicorn>=25.0.1",
|
"gunicorn>=25.0.1",
|
||||||
"image>=1.5.33",
|
"image>=1.5.33",
|
||||||
|
"langchain-community>=0.4.1",
|
||||||
"langchain-core>=1.2.8",
|
"langchain-core>=1.2.8",
|
||||||
"langchain-google-genai>=4.2.0",
|
"langchain-google-genai>=4.2.0",
|
||||||
"langgraph>=1.0.7",
|
"langgraph[all,postgres]>=1.0.7",
|
||||||
"langgraph-checkpoint-mongodb>=0.3.1",
|
"langgraph-checkpoint-mongodb>=0.3.1",
|
||||||
"minio>=7.2.20",
|
"minio>=7.2.20",
|
||||||
"modality>=0.1.0",
|
"modality>=0.1.0",
|
||||||
"motor>=3.7.1",
|
"motor>=3.7.1",
|
||||||
|
"playwright>=1.58.0",
|
||||||
"pydantic>=2.12.5",
|
"pydantic>=2.12.5",
|
||||||
"pydantic-settings>=2.12.0",
|
"pydantic-settings>=2.12.0",
|
||||||
"pymongo[srv]>=4.15.5",
|
"pymongo[srv]>=4.15.5",
|
||||||
"python-dotenv>=1.2.1",
|
"python-dotenv>=1.2.1",
|
||||||
|
"tavily-python>=0.7.21",
|
||||||
"uuid>=1.30",
|
"uuid>=1.30",
|
||||||
"uvicorn>=0.40.0",
|
"uvicorn>=0.40.0",
|
||||||
|
"psycopg[binary]>=3.3.3",
|
||||||
|
"postgres>=4.0",
|
||||||
|
"langchain-huggingface>=1.2.0",
|
||||||
|
"rank-bm25>=0.2.2",
|
||||||
|
"faiss-cpu>=1.13.2",
|
||||||
|
"terminate>=0.0.9",
|
||||||
|
"report-generator>=0.1.10",
|
||||||
|
"dashscope>=1.25.13",
|
||||||
|
"prompt>=0.4.1",
|
||||||
|
"langchain-qwq>=0.3.4",
|
||||||
|
"asyncio>=4.0.0",
|
||||||
|
"requests>=2.32.5",
|
||||||
|
"chardet<6",
|
||||||
|
"datetime>=6.0",
|
||||||
|
"agentstate>=1.0.2",
|
||||||
|
"langchain-classic>=1.0.1",
|
||||||
|
"langsmith>=0.7.13",
|
||||||
|
"path>=17.1.1",
|
||||||
|
"langgraph-checkpoint-postgres>=3.0.4",
|
||||||
|
"langgraph-store-mongodb>=0.2.0",
|
||||||
|
"tool>=0.8.0",
|
||||||
|
"langchain-daytona>=0.0.3",
|
||||||
|
"langgraph-cli[inmem]>=0.4.19",
|
||||||
|
"start>=0.2",
|
||||||
|
"end>=1.3.1",
|
||||||
|
"annotated>=0.0.2",
|
||||||
|
"field>=0.2.0",
|
||||||
|
"aio-pika>=9.6.2",
|
||||||
|
"celery[redis]>=5.6.3",
|
||||||
|
"python3-pika>=0.9.14",
|
||||||
|
"tasks>=2.8.0",
|
||||||
|
"kombu>=5.4.0",
|
||||||
|
"sentence-transformers[onnx]>=5.3.0",
|
||||||
|
"celery-types>=0.26.0",
|
||||||
|
"langgraph-api>=0.7.94",
|
||||||
|
"debugpy>=1.8.20",
|
||||||
|
"pydevd-pycharm~=253.29346.308",
|
||||||
|
"python-magic>=0.4.27",
|
||||||
|
"ddgs>=9.14.1",
|
||||||
|
"aiofiles>=24.1.0",
|
||||||
|
"fast-langdetect>=1.0.0",
|
||||||
]
|
]
|
||||||
|
|||||||
0
src/__init__.py
Normal file → Executable file
0
src/__init__.py
Normal file → Executable file
0
src/core/__init__.py
Normal file → Executable file
0
src/core/__init__.py
Normal file → Executable file
28
src/core/config.py
Normal file → Executable file
28
src/core/config.py
Normal file → Executable file
@@ -1,3 +1,5 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
|
||||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
@@ -17,18 +19,44 @@ class Settings(BaseSettings):
|
|||||||
GOOGLE_CLOUD_PROJECT: str = Field(default="", description="")
|
GOOGLE_CLOUD_PROJECT: str = Field(default="", description="")
|
||||||
GOOGLE_CLOUD_LOCATION: str = Field(default="", description="")
|
GOOGLE_CLOUD_LOCATION: str = Field(default="", description="")
|
||||||
|
|
||||||
|
# --- google api 配置信息 ---
|
||||||
|
QWEN_API_KEY: str = Field(default="", description="")
|
||||||
|
|
||||||
# --- minio 配置信息 ---
|
# --- minio 配置信息 ---
|
||||||
MINIO_URL: str = Field(default='', description="")
|
MINIO_URL: str = Field(default='', description="")
|
||||||
MINIO_ACCESS: str = Field(default='', description="")
|
MINIO_ACCESS: str = Field(default='', description="")
|
||||||
MINIO_SECRET: str = Field(default='', description="")
|
MINIO_SECRET: str = Field(default='', description="")
|
||||||
MINIO_SECURE: bool = Field(default=True, description="")
|
MINIO_SECURE: bool = Field(default=True, description="")
|
||||||
|
|
||||||
|
# --- redis 配置信息 ---
|
||||||
|
REDIS_HOST: str = Field(default='', description="")
|
||||||
|
REDIS_PORT: str = Field(default='', description="")
|
||||||
|
REDIS_DB: int = Field(default=0, description="")
|
||||||
|
|
||||||
# --- mongodb配置信息 ---
|
# --- mongodb配置信息 ---
|
||||||
MONGODB_USERNAME: str = Field(default="", description="")
|
MONGODB_USERNAME: str = Field(default="", description="")
|
||||||
MONGODB_PASSWORD: str = Field(default="", description="")
|
MONGODB_PASSWORD: str = Field(default="", description="")
|
||||||
MONGODB_HOST: str = Field(default="localhost", description="")
|
MONGODB_HOST: str = Field(default="localhost", description="")
|
||||||
MONGODB_PORT: int = Field(default=27017, description="")
|
MONGODB_PORT: int = Field(default=27017, description="")
|
||||||
|
|
||||||
|
# --- 本地服务器配置信息 ---
|
||||||
|
IMAGE_TO_3D_MODEL_URL: str = Field(default='', description="")
|
||||||
|
FLUX2_GEN_IMG_MODEL_URL: str = Field(default='', description="")
|
||||||
|
SEG_ANYTHING: str = Field(default='', description="")
|
||||||
|
RABBITMQ_URL: str = Field(default='', description="")
|
||||||
|
|
||||||
|
# --- 外部工具api配置信息 ---
|
||||||
|
TAVILY_API_KEY: str = Field(default="", description="")
|
||||||
|
TRIPO_API_KEY: str = Field(default="", description="")
|
||||||
|
|
||||||
|
LOGS_PATH: str = Field(default="/mnt/data/FiDA/logs", description="")
|
||||||
|
|
||||||
|
SERVE_ENV: str = Field(default="dev", description="")
|
||||||
|
|
||||||
|
|
||||||
settings = Settings()
|
settings = Settings()
|
||||||
MONGO_URI = f"mongodb://{settings.MONGODB_USERNAME}:{settings.MONGODB_PASSWORD}@{settings.MONGODB_HOST}:{settings.MONGODB_PORT}"
|
MONGO_URI = f"mongodb://{settings.MONGODB_USERNAME}:{settings.MONGODB_PASSWORD}@{settings.MONGODB_HOST}:{settings.MONGODB_PORT}"
|
||||||
|
|
||||||
|
TOOL_DIR = Path(__file__).resolve().parent
|
||||||
|
PROJECT_ROOT = TOOL_DIR.parent
|
||||||
|
print(f"PROJECT_ROOT : {PROJECT_ROOT}")
|
||||||
|
|||||||
0
src/server/agent/__init__.py → src/db/__init__.py
Normal file → Executable file
0
src/server/agent/__init__.py → src/db/__init__.py
Normal file → Executable file
49
src/db/init_mongodb.py
Executable file
49
src/db/init_mongodb.py
Executable file
@@ -0,0 +1,49 @@
|
|||||||
|
import asyncio
|
||||||
|
from motor.motor_asyncio import AsyncIOMotorClient
|
||||||
|
|
||||||
|
from src.core.config import MONGO_URI
|
||||||
|
|
||||||
|
DB_NAME = "fida_mongo"
|
||||||
|
COLLECTION_NAME = "user_persona"
|
||||||
|
|
||||||
|
|
||||||
|
async def init_mongo():
|
||||||
|
client = AsyncIOMotorClient(
|
||||||
|
MONGO_URI,
|
||||||
|
maxPoolSize=50,
|
||||||
|
minPoolSize=5,
|
||||||
|
serverSelectionTimeoutMS=5000
|
||||||
|
)
|
||||||
|
|
||||||
|
db = client[DB_NAME]
|
||||||
|
|
||||||
|
# 查看已有集合
|
||||||
|
collections = await db.list_collection_names()
|
||||||
|
|
||||||
|
if COLLECTION_NAME not in collections:
|
||||||
|
print(f"Creating collection: {COLLECTION_NAME}")
|
||||||
|
await db.create_collection(COLLECTION_NAME)
|
||||||
|
|
||||||
|
collection = db[COLLECTION_NAME]
|
||||||
|
|
||||||
|
# 创建 thread_id 唯一索引
|
||||||
|
print("Creating index: thread_id_unique")
|
||||||
|
await collection.create_index(
|
||||||
|
"thread_id",
|
||||||
|
unique=True,
|
||||||
|
name="thread_id_unique"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 创建 TTL 索引(30天自动删除)
|
||||||
|
print("Creating TTL index: updated_at_ttl")
|
||||||
|
await collection.create_index(
|
||||||
|
"updated_at",
|
||||||
|
expireAfterSeconds=2592000, # 30天
|
||||||
|
name="updated_at_ttl"
|
||||||
|
)
|
||||||
|
|
||||||
|
print("MongoDB initialization completed.")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(init_mongo())
|
||||||
17
src/db/mongo.py
Executable file
17
src/db/mongo.py
Executable file
@@ -0,0 +1,17 @@
|
|||||||
|
from motor.motor_asyncio import AsyncIOMotorClient
|
||||||
|
|
||||||
|
from src.core.config import MONGO_URI
|
||||||
|
|
||||||
|
client = AsyncIOMotorClient(
|
||||||
|
MONGO_URI,
|
||||||
|
maxPoolSize=50,
|
||||||
|
minPoolSize=5,
|
||||||
|
serverSelectionTimeoutMS=5000
|
||||||
|
)
|
||||||
|
|
||||||
|
db = client["fida_mongo"]
|
||||||
|
|
||||||
|
user_persona_collection = db["user_persona"]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
0
src/routers/__init__.py
Normal file → Executable file
0
src/routers/__init__.py
Normal file → Executable file
121
src/routers/canvas_assistant.py
Executable file
121
src/routers/canvas_assistant.py
Executable file
@@ -0,0 +1,121 @@
|
|||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import AsyncGenerator
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from fastapi import APIRouter
|
||||||
|
from starlette.responses import StreamingResponse
|
||||||
|
|
||||||
|
from src.schemas.canvas_assistant import TriggerRequest
|
||||||
|
from src.server.canvas_assistant.graph import graph
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/canvas", tags=["Furniture Canvas assistant"])
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def stream_fiphant(req: TriggerRequest) -> AsyncGenerator[str, None]:
|
||||||
|
thread_id = f"canvas_{str(uuid.uuid4())}"
|
||||||
|
config = {"configurable": {"thread_id": thread_id}}
|
||||||
|
|
||||||
|
input_state = {
|
||||||
|
"messages": [],
|
||||||
|
"trigger": req.tool_name if req.action == "tool_trigger" else None,
|
||||||
|
"language": req.language,
|
||||||
|
"is_first_enter": req.action == "enter_canvas"
|
||||||
|
}
|
||||||
|
|
||||||
|
yield f"data: {json.dumps({'thread_id': thread_id, 'status': 'start'}, ensure_ascii=False)}\n\n"
|
||||||
|
|
||||||
|
async for event in graph.astream(input_state, config, stream_mode="messages", version="v2"):
|
||||||
|
if event["type"] == "messages":
|
||||||
|
msg, metadata = event["data"]
|
||||||
|
payload_out = {
|
||||||
|
"node": metadata.get("langgraph_node", "unknown"),
|
||||||
|
"is_delta": True,
|
||||||
|
"content": msg.content,
|
||||||
|
"type": "text"
|
||||||
|
}
|
||||||
|
yield f"data: {json.dumps(payload_out, ensure_ascii=False)}\n\n"
|
||||||
|
|
||||||
|
yield f"data: {json.dumps({'status': 'end'}, ensure_ascii=False)}\n\n"
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/assistant")
|
||||||
|
async def canvas_assistant(req: TriggerRequest):
|
||||||
|
"""
|
||||||
|
### 接口说明
|
||||||
|
触发 Fiphant 设计助手返回消息。
|
||||||
|
|
||||||
|
支持两种场景:
|
||||||
|
- 用户进入画布时,自动返回欢迎引导语
|
||||||
|
- 用户点击画布中的工具按钮时,返回对应工具的使用说明
|
||||||
|
|
||||||
|
### 参数说明:
|
||||||
|
- **action**: 操作类型(必填)
|
||||||
|
- `enter_canvas`: 用户进入画布时调用(返回欢迎语)
|
||||||
|
- `tool_trigger`: 用户点击工具按钮时调用(返回工具说明)
|
||||||
|
|
||||||
|
- **tool_name**: 工具名称(当 action 为 tool_trigger 时必填)
|
||||||
|
支持以下值:
|
||||||
|
- `to_real_style`
|
||||||
|
- `surface_edit_canvas`
|
||||||
|
- `surface_edit_ai`
|
||||||
|
- `color_palette`
|
||||||
|
- `scene_composition`
|
||||||
|
- `3d_model`
|
||||||
|
- `to_3d_view`
|
||||||
|
|
||||||
|
- **language**: 返回语言(非必填,默认 zh)
|
||||||
|
- `zh`: 中文
|
||||||
|
- `en`: 英文
|
||||||
|
|
||||||
|
### 请求体示例:
|
||||||
|
|
||||||
|
**1. 进入画布时调用**
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"action": "enter_canvas",
|
||||||
|
"language": "zh"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
2. 点击工具时调用(推荐)
|
||||||
|
```JSON
|
||||||
|
{
|
||||||
|
"action": "tool_trigger",
|
||||||
|
"tool_name": "3d_model",
|
||||||
|
"language": "zh"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
3. 点击工具 - 英文版
|
||||||
|
```JSON
|
||||||
|
{
|
||||||
|
"action": "tool_trigger",
|
||||||
|
"tool_name": "scene_composition",
|
||||||
|
"language": "en"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
输出说明:
|
||||||
|
返回 Server-Sent Events (SSE) 流式响应,文字会逐句出现,提升用户体验。
|
||||||
|
流式输出示例(实际返回内容):
|
||||||
|
|
||||||
|
消息开始结束:
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"thread_id": "canvas_be76cb75-18ef-4e84-8e30-5d36aef5b83a",
|
||||||
|
"status": "start"
|
||||||
|
}
|
||||||
|
{
|
||||||
|
"status": "end"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
正文消息:
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"node": "assistant",
|
||||||
|
"is_delta": true,
|
||||||
|
"content": "Hi,我是你的设计助手 Fiphant 👋 我来帮你快速上手这个画布。我给你准备了两个起点——你可以用 To Real Style 直接把草图变成效果图,也可以先用 Surface Edit 换个材质或贴上印花。有了产品图之后,我们再一起配色、配场景、看 3D 效果,最后导出三视图就完成了。我建议先从 To Real Style 开始,看看整体感觉 ✨",
|
||||||
|
"type": "text"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
return StreamingResponse(stream_fiphant(req), media_type="text/event-stream")
|
||||||
@@ -1,244 +0,0 @@
|
|||||||
import uuid
|
|
||||||
import json
|
|
||||||
from fastapi import APIRouter
|
|
||||||
from fastapi.responses import StreamingResponse
|
|
||||||
from src.schemas.chat import ChatRequest, HistoryResponse, HistoryItem
|
|
||||||
from src.server.agent.graph import app # 导入已经 compile 好的 graph
|
|
||||||
from langchain_core.messages import HumanMessage, SystemMessage
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/chat", tags=["Furniture Design Chat"])
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/stream")
|
|
||||||
async def chat_stream(request: ChatRequest):
|
|
||||||
"""
|
|
||||||
### 家具设计流式对话接口 (SSE)
|
|
||||||
|
|
||||||
通过此接口与 AI 家具设计专家团队进行实时沟通。支持 **记忆持久化** 和 **历史回溯分叉**。
|
|
||||||
|
|
||||||
#### 1. 核心功能
|
|
||||||
* **实时反馈**: 采用 Server-Sent Events (SSE) 技术,实时推送主管、设计师、视觉专家等节点的思考过程。
|
|
||||||
* **上下文记忆**: 传入 `thread_id` 即可恢复之前的对话进度。
|
|
||||||
* **版本分溯**: 传入 `checkpoint_id` 可准确定位到历史中的某一轮,并从该点开启新的设计分支。
|
|
||||||
|
|
||||||
#### 2. 请求参数
|
|
||||||
* `message`: 用户的设计意图(如:'我想设计一个极简风格的橡木办公桌')。
|
|
||||||
* `thread_id`: (可选) 现有项目的唯一标识。若不传,系统将自动分配并返回。
|
|
||||||
* `checkpoint_id`: (可选) 历史快照 ID。
|
|
||||||
* `config_params`: (可选) 对话配置参数
|
|
||||||
* `require_suggestion`: (可选) 是否需要建议按钮
|
|
||||||
|
|
||||||
#### 3. 响应流说明 (Data Format)
|
|
||||||
响应以 `data: ` 开头的 JSON 字符串流形式发送:
|
|
||||||
- **Session Start**: `{"thread_id": "...", "status": "start"}`
|
|
||||||
- **Node Message**: `{"node": "Designer", "content": "...", "checkpoint_id": "..."}`
|
|
||||||
- **Session End**: `{"status": "end"}`
|
|
||||||
|
|
||||||
#### 4. 请求示例
|
|
||||||
```
|
|
||||||
{
|
|
||||||
"message": "设计一款北欧风格的躺椅."
|
|
||||||
}
|
|
||||||
|
|
||||||
{
|
|
||||||
"message": "就以上信息直接生成sketch.",
|
|
||||||
"thread_id": "187e58af"
|
|
||||||
}
|
|
||||||
|
|
||||||
{
|
|
||||||
"message": "不要躺椅,要桌子",
|
|
||||||
"thread_id": "187e58af",
|
|
||||||
"checkpoint_id": "1f101aa2-8f24-6e2a-8001-2952c3a7447a"
|
|
||||||
}
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
source_thread_id = request.thread_id
|
|
||||||
checkpoint_id = request.checkpoint_id
|
|
||||||
|
|
||||||
# 1. 确定目标 thread_id
|
|
||||||
# 如果是回溯操作,我们生成一个新的 ID,或者由前端传入一个新的 target_thread_id
|
|
||||||
is_branching = source_thread_id and checkpoint_id
|
|
||||||
target_thread_id = str(uuid.uuid4())[:8] if is_branching else (source_thread_id or str(uuid.uuid4())[:8])
|
|
||||||
# 2. 获取配置参数
|
|
||||||
temp = request.config_params.temperature if request.config_params else 0.7
|
|
||||||
|
|
||||||
# 构建基础 Config
|
|
||||||
current_config = {
|
|
||||||
"configurable": {
|
|
||||||
"thread_id": target_thread_id,
|
|
||||||
"llm_temperature": temp
|
|
||||||
}
|
|
||||||
}
|
|
||||||
# 3. 处理状态初始化与分支
|
|
||||||
initial_messages = []
|
|
||||||
|
|
||||||
# 如果是全新的对话(没有 source_thread_id),或者明确要求分叉
|
|
||||||
if not source_thread_id or is_branching:
|
|
||||||
# 如果用户传了标签,构造 SystemMessage 注入上下文
|
|
||||||
if request.config_params:
|
|
||||||
cp = request.config_params
|
|
||||||
system_prompt = (
|
|
||||||
f"Current furniture design background settings:\n"
|
|
||||||
f"- type: {cp.type}\n"
|
|
||||||
f"- space/region: {cp.region}\n"
|
|
||||||
f"- style tendency: {cp.style}\n"
|
|
||||||
f"Please strictly follow the above settings in subsequent conversations。"
|
|
||||||
)
|
|
||||||
initial_messages.append(SystemMessage(content=system_prompt))
|
|
||||||
|
|
||||||
# 4. 执行分叉逻辑(搬运旧数据)
|
|
||||||
if is_branching:
|
|
||||||
source_config = {
|
|
||||||
"configurable": {
|
|
||||||
"thread_id": source_thread_id,
|
|
||||||
"checkpoint_id": checkpoint_id
|
|
||||||
}
|
|
||||||
}
|
|
||||||
older_state = await app.aget_state(source_config)
|
|
||||||
|
|
||||||
# 将旧消息和我们新定义的 SystemMessage 合并
|
|
||||||
# update_state 会将这些消息推送到新 thread 的存储中
|
|
||||||
combined_values = older_state.values.copy()
|
|
||||||
if initial_messages:
|
|
||||||
combined_values["messages"] = list(combined_values["messages"]) + initial_messages
|
|
||||||
|
|
||||||
await app.aupdate_state(current_config, combined_values)
|
|
||||||
|
|
||||||
async def event_generator():
|
|
||||||
# 初始推送状态信息
|
|
||||||
yield f"data: {json.dumps({'thread_id': target_thread_id, 'is_branch': is_branching, 'status': 'start'}, ensure_ascii=False)}\n\n"
|
|
||||||
|
|
||||||
# 构造本次请求的输入
|
|
||||||
# 如果是第一次开始,且有 initial_messages,则连同 user message 一起发送
|
|
||||||
# --- 核心逻辑:构造本次请求的消息列表 ---
|
|
||||||
new_messages = []
|
|
||||||
if not source_thread_id and initial_messages:
|
|
||||||
new_messages.extend(initial_messages)
|
|
||||||
# 添加用户消息
|
|
||||||
new_messages.append(HumanMessage(content=request.message))
|
|
||||||
|
|
||||||
# --- 新增:强制绘图指令注入 ---
|
|
||||||
# if request.force_sketch:
|
|
||||||
# force_instruction = HumanMessage(
|
|
||||||
# content="[SYSTEM_DIRECTIVE]: 用户点击了强制生成按钮。请立即根据当前上下文调用 generate_furniture_sketch 工具生成草图,无需确认。"
|
|
||||||
# )
|
|
||||||
# new_messages.append(force_instruction)
|
|
||||||
|
|
||||||
input_data = {
|
|
||||||
"messages": new_messages,
|
|
||||||
"require_suggestion": request.need_suggestion # 初始由前端决定
|
|
||||||
}
|
|
||||||
|
|
||||||
async for event in app.astream(
|
|
||||||
input_data,
|
|
||||||
current_config,
|
|
||||||
stream_mode="updates"
|
|
||||||
):
|
|
||||||
for node_name, output in event.items():
|
|
||||||
if "messages" in output:
|
|
||||||
# 获取最新 state 以获取 checkpoint_id
|
|
||||||
state = await app.aget_state(current_config)
|
|
||||||
current_cp_id = state.config["configurable"].get("checkpoint_id")
|
|
||||||
|
|
||||||
# 遍历本次 update 产生的所有消息
|
|
||||||
for msg in output["messages"]:
|
|
||||||
payload = {
|
|
||||||
"node": node_name,
|
|
||||||
"content": "",
|
|
||||||
"image_url": None,
|
|
||||||
"checkpoint_id": current_cp_id,
|
|
||||||
"suggestions": []
|
|
||||||
}
|
|
||||||
|
|
||||||
# --- 核心改动:提取建议按钮 ---
|
|
||||||
# 无论是不是 Suggester 节点,只要消息里带了建议就提取
|
|
||||||
if hasattr(msg, "additional_kwargs") and "suggestions" in msg.additional_kwargs:
|
|
||||||
payload["suggestions"] = msg.additional_kwargs["suggestions"]
|
|
||||||
|
|
||||||
content = msg.content
|
|
||||||
# 逻辑判断:MinIO 图片处理
|
|
||||||
if node_name == "Visualizer" and str(content).endswith("png") and "furniture/sketches" in str(content):
|
|
||||||
payload["image_url"] = content
|
|
||||||
payload["content"] = "已为您生成设计草图"
|
|
||||||
else:
|
|
||||||
payload["content"] = content
|
|
||||||
|
|
||||||
# 如果消息既没有文本、也没有图片、也没有建议(比如中间的 ToolCall 消息),则跳过
|
|
||||||
if not payload["content"] and not payload["image_url"] and not payload["suggestions"]:
|
|
||||||
continue
|
|
||||||
|
|
||||||
yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
|
|
||||||
|
|
||||||
yield f"data: {json.dumps({'status': 'end'}, ensure_ascii=False)}\n\n"
|
|
||||||
|
|
||||||
return StreamingResponse(event_generator(), media_type="text/event-stream")
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/history/{thread_id}", response_model=HistoryResponse)
|
|
||||||
async def get_chat_history(thread_id: str):
|
|
||||||
"""
|
|
||||||
### 获取项目设计历史记录
|
|
||||||
|
|
||||||
此接口用于拉取指定 `thread_id` 下的所有历史状态快照。它是实现 **“版本回溯”** 和 **“方案对比”** 的核心数据来源。
|
|
||||||
|
|
||||||
#### 1. 功能说明
|
|
||||||
* **快照列表**: 返回该项目从启动至今的所有关键节点(Checkpoints)。
|
|
||||||
* **版本定位**: 每个历史点都包含一个唯一的 `checkpoint_id`。
|
|
||||||
* **数据回溯**: 客户端获取此列表后,可以引导用户选择任意一个版本,并将其 `checkpoint_id` 传回 `/chat/stream` 接口以开启新的设计分支。
|
|
||||||
|
|
||||||
#### 2. 路径参数
|
|
||||||
* `thread_id`: 设计项目的唯一标识符(由 `/chat/stream` 首次调用时生成或指定)。
|
|
||||||
|
|
||||||
#### 3. 返回字段定义
|
|
||||||
* `thread_id`: 当前查询的项目ID。
|
|
||||||
* `history`: 历史记录数组,包含:
|
|
||||||
- `checkpoint_id`: 必填,回溯时使用的关键凭证。
|
|
||||||
- `last_message`: 该阶段的最后一条消息摘要(方便前端预览)。
|
|
||||||
- `node`: 产生该快照的节点名称(如 Designer, Visualizer)。
|
|
||||||
- `timestamp`: 逻辑步骤序号。
|
|
||||||
|
|
||||||
#### 4. 响应示例
|
|
||||||
```json
|
|
||||||
{
|
|
||||||
"thread_id": "proj_001",
|
|
||||||
"history": [
|
|
||||||
{
|
|
||||||
"checkpoint_id": "d82f3a12",
|
|
||||||
"last_message": "我想设计一款北欧风书架",
|
|
||||||
"node": "Supervisor",
|
|
||||||
"timestamp": 1
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"checkpoint_id": "f4k92m1a",
|
|
||||||
"last_message": "建议使用浅色橡木材质,增加简约感...",
|
|
||||||
"node": "Designer",
|
|
||||||
"timestamp": 2
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
config = {"configurable": {"thread_id": thread_id}}
|
|
||||||
history_data = []
|
|
||||||
async for state in app.aget_state_history(config):
|
|
||||||
msg_content = "Initial"
|
|
||||||
if state.values and "messages" in state.values:
|
|
||||||
msgs = state.values["messages"]
|
|
||||||
if msgs and len(msgs) > 0:
|
|
||||||
last_msg = msgs[-1]
|
|
||||||
# 获取内容并做摘要截断
|
|
||||||
content = getattr(last_msg, "content", str(last_msg))
|
|
||||||
msg_content = content
|
|
||||||
|
|
||||||
history_data.append(HistoryItem(
|
|
||||||
checkpoint_id=state.config["configurable"]["checkpoint_id"],
|
|
||||||
last_message=msg_content,
|
|
||||||
node=state.metadata.get("source"),
|
|
||||||
timestamp=state.metadata.get("step")
|
|
||||||
))
|
|
||||||
|
|
||||||
return HistoryResponse(thread_id=thread_id, history=history_data)
|
|
||||||
# try:
|
|
||||||
|
|
||||||
# except Exception as e:
|
|
||||||
# raise HTTPException(status_code=404, detail=f"History not found: {str(e)}")
|
|
||||||
481
src/routers/deep_agent_chat.py
Executable file
481
src/routers/deep_agent_chat.py
Executable file
@@ -0,0 +1,481 @@
|
|||||||
|
import os
|
||||||
|
import uuid
|
||||||
|
import json
|
||||||
|
import random
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from minio import Minio
|
||||||
|
from fastapi import APIRouter
|
||||||
|
from typing import AsyncGenerator
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
|
from langchain_core.messages import SystemMessage, AIMessageChunk, ToolMessage, AIMessage, ToolMessageChunk
|
||||||
|
from minio.commonconfig import CopySource
|
||||||
|
|
||||||
|
from src.core.config import PROJECT_ROOT, settings
|
||||||
|
from src.server.deep_agent.agents.main_agent import build_main_agent, Context
|
||||||
|
from src.server.deep_agent.tools.conversation_title_tool import conversation_title
|
||||||
|
from src.schemas.deep_agent_chat import DeepAgentChatRequest, HistoryResponse, HistoryItem
|
||||||
|
from src.server.deep_agent.tools.extract_suggested_questions import generate_suggested_questions
|
||||||
|
from src.server.utils.new_oss_client import get_presigned_url
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/chat", tags=["Furniture Design Chat"])
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/deep_agent_stream")
|
||||||
|
async def chat_stream(request: DeepAgentChatRequest):
|
||||||
|
"""
|
||||||
|
### 家具设计流式对话接口 (SSE)
|
||||||
|
|
||||||
|
通过此接口与 AI 家具设计专家团队进行实时沟通。支持 **记忆持久化** 和 **历史回溯分叉**。
|
||||||
|
|
||||||
|
#### 1. 核心功能
|
||||||
|
* **实时反馈**: 采用 Server-Sent Events (SSE) 技术,实时推送主管、设计师、视觉专家等节点的思考过程。
|
||||||
|
* **上下文记忆**: 传入 `thread_id` 即可恢复之前的对话进度。
|
||||||
|
* **版本分溯**: 传入 `checkpoint_id` 可准确定位到历史中的某一轮,并从该点开启新的设计分支。
|
||||||
|
|
||||||
|
#### 2. 请求参数
|
||||||
|
* `message`: 用户的设计意图(如:'我想设计一个极简风格的橡木办公桌')。
|
||||||
|
* `enable_thinking`: 是否开启思考模式。
|
||||||
|
* `quote_image_path`: 用户引用图片地址 如:"fida-test/furniture/sketches/8a1804d1-5ac9-4d02-bf17-e65fa7272f65.png"。
|
||||||
|
* `input_image_paths`: 用户上传图片地址集合如:["fida-test/furniture/sketches/8a1804d1-5ac9-4d02-bf17-e65fa7272f65.png"]。
|
||||||
|
* `thread_id`: (可选) 现有项目的唯一标识。若不传,系统将自动分配并返回。
|
||||||
|
* `checkpoint_id`: (可选) 历史快照 ID。
|
||||||
|
* `config_params`: (可选) 对话配置参数
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"message": "你好",
|
||||||
|
"config_params": {
|
||||||
|
"type": "test",
|
||||||
|
"region": "test",
|
||||||
|
"style": "test"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
* `need_suggestion`: (可选) 是否需要建议按钮,需要建议的频率,0-1的浮点数
|
||||||
|
* `use_report`: (可选) 是否需要使用report功能 true/false
|
||||||
|
|
||||||
|
|
||||||
|
#### 3. 响应流说明 (Data Format)
|
||||||
|
响应以 `data: ` 开头的 JSON 字符串流形式发送:
|
||||||
|
- **Session Start**: `{"thread_id": "...", "status": "start"}`
|
||||||
|
- **Node Message**: `{"node": "Designer", "content": "...", "checkpoint_id": "..."}`
|
||||||
|
- **Session End**: `{"status": "end"}`
|
||||||
|
|
||||||
|
- **is_delta**: False/True,表示这个消息不是完整内容,只是 AI 正在生成的一小段内容(一个字、一个词、一句话),需要前端把这些片段拼接起来才能得到完整的回答。
|
||||||
|
|
||||||
|
#### 4. 请求示例
|
||||||
|
```
|
||||||
|
{
|
||||||
|
"message": "设计一款北欧风格的躺椅."
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
"message": "就以上信息直接生成sketch.",
|
||||||
|
"thread_id": "187e58af"
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
"message": "不要躺椅,要桌子",
|
||||||
|
"thread_id": "187e58af",
|
||||||
|
"checkpoint_id": "1f101aa2-8f24-6e2a-8001-2952c3a7447a"
|
||||||
|
}
|
||||||
|
用户上传:
|
||||||
|
{
|
||||||
|
"message": "合并两张图一边一半,左右拼",
|
||||||
|
"input_image_paths": [
|
||||||
|
"fida-test/furniture/sketches/218adbd2-c312-4298-9a82-5a92601ac9e2.png",
|
||||||
|
"fida-test/furniture/sketches/8a1804d1-5ac9-4d02-bf17-e65fa7272f65.png"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
用户引用:
|
||||||
|
{
|
||||||
|
"message": "描述这张图",
|
||||||
|
"quote_image_path":"fida-test/furniture/sketches/218adbd2-c312-4298-9a82-5a92601ac9e2.png"
|
||||||
|
}
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
### 5. 响应流说明
|
||||||
|
所有响应均以 data: 开头,JSON 字符串格式,末尾以 \n\n 结束
|
||||||
|
响应流包含三种类型的事件:会话开始、节点消息、会话结束
|
||||||
|
|
||||||
|
"""
|
||||||
|
# ===================== 简洁优化版 =====================
|
||||||
|
# 1. 线程与标题标记
|
||||||
|
need_title = not request.thread_id
|
||||||
|
source_thread_id = request.thread_id
|
||||||
|
checkpoint_id = request.checkpoint_id
|
||||||
|
|
||||||
|
# 2. 目标线程 ID
|
||||||
|
is_branching = all([source_thread_id, checkpoint_id])
|
||||||
|
target_thread_id = str(uuid.uuid4())[:8] if is_branching else (source_thread_id or str(uuid.uuid4())[:8])
|
||||||
|
|
||||||
|
# 3. Agent 初始化
|
||||||
|
workspace_dir = os.path.join(PROJECT_ROOT, "agent_workspace", target_thread_id)
|
||||||
|
logger.info(f"chat request data: {request} | target_thread_id: {target_thread_id}, workspace_dir: {workspace_dir}")
|
||||||
|
main_agent = build_main_agent(workspace_dir, request.enable_thinking)
|
||||||
|
|
||||||
|
# 4. 配置
|
||||||
|
temp = request.config_params.temperature if request.config_params else 0.7
|
||||||
|
current_config = {
|
||||||
|
"recursion_limit": 120,
|
||||||
|
"configurable": {
|
||||||
|
"thread_id": target_thread_id,
|
||||||
|
"llm_temperature": temp,
|
||||||
|
"use_report": request.use_report,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
# 5. 初始化系统消息
|
||||||
|
initial_messages = []
|
||||||
|
design_backend = ""
|
||||||
|
if not source_thread_id or is_branching:
|
||||||
|
cp = request.config_params
|
||||||
|
if cp:
|
||||||
|
config_items = [
|
||||||
|
("type", cp.type),
|
||||||
|
("space/region", cp.region),
|
||||||
|
("style tendency", cp.style)
|
||||||
|
]
|
||||||
|
valid_lines = [f"- {k}: {v}" for k, v in config_items if v]
|
||||||
|
if valid_lines:
|
||||||
|
system_prompt = (
|
||||||
|
"Current furniture design background settings:\n"
|
||||||
|
+ "\n".join(valid_lines) + "\n"
|
||||||
|
"Please strictly follow the above settings in subsequent conversations。"
|
||||||
|
)
|
||||||
|
initial_messages.append(SystemMessage(content=system_prompt))
|
||||||
|
design_backend = f"""
|
||||||
|
<design_constraints>
|
||||||
|
Category: {cp.type or 'unspecified'}
|
||||||
|
region: {cp.region or 'unspecified'}
|
||||||
|
style: {cp.style or 'unspecified'}
|
||||||
|
</design_constraints>
|
||||||
|
"""
|
||||||
|
|
||||||
|
# 6. 分支处理
|
||||||
|
if is_branching:
|
||||||
|
source_config = {"configurable": {"thread_id": source_thread_id, "checkpoint_id": checkpoint_id}}
|
||||||
|
last_checkpoint_id = await get_branch_checkpoint_id(main_agent, source_config)
|
||||||
|
older_state = await main_agent.aget_state(source_config)
|
||||||
|
combined_values = older_state.values.copy()
|
||||||
|
if initial_messages:
|
||||||
|
combined_values["messages"] = combined_values.get("messages", []) + initial_messages
|
||||||
|
await main_agent.aupdate_state(current_config, combined_values)
|
||||||
|
else:
|
||||||
|
last_checkpoint_id = await get_checkpoint_id(main_agent, current_config)
|
||||||
|
|
||||||
|
# 7. 事件流生成
|
||||||
|
async def event_generator() -> AsyncGenerator[str, None]:
|
||||||
|
is_first = True
|
||||||
|
content = [{"type": "text", "text": request.message}]
|
||||||
|
files = {
|
||||||
|
"input_image": [],
|
||||||
|
"quote_image": "",
|
||||||
|
"current_image": ""
|
||||||
|
}
|
||||||
|
input_image_content = ""
|
||||||
|
|
||||||
|
# 处理上传图片
|
||||||
|
if request.input_image_paths:
|
||||||
|
input_image_content += "\n【附件上传图片路径】\n"
|
||||||
|
for i, path in enumerate(request.input_image_paths):
|
||||||
|
input_image_content += f"- 上传图片{i}: {path}\n"
|
||||||
|
bucket, obj = path.split("/", 1)
|
||||||
|
minio_client.copy_object("fida-public-bucket", path, CopySource(bucket, obj))
|
||||||
|
image_url = f"https://www.minio-api.aida.com.hk/fida-public-bucket/{path}"
|
||||||
|
content.append({"type": "image_url", "image_url": {"url": image_url}})
|
||||||
|
files["input_image"].append(path)
|
||||||
|
|
||||||
|
# 处理引用图片
|
||||||
|
if request.quote_image_path:
|
||||||
|
input_image_content += "\n【附件引用图片路径】\n"
|
||||||
|
input_image_content += f"- 引用图片: {request.quote_image_path}\n"
|
||||||
|
bucket, obj = request.quote_image_path.split("/", 1)
|
||||||
|
minio_client.copy_object("fida-public-bucket", request.quote_image_path, CopySource(bucket, obj))
|
||||||
|
image_url = f"https://www.minio-api.aida.com.hk/fida-public-bucket/{request.quote_image_path}"
|
||||||
|
content.append({"type": "image_url", "image_url": {"url": image_url}})
|
||||||
|
files["quote_image"] = request.quote_image_path
|
||||||
|
|
||||||
|
# 追加文本内容
|
||||||
|
if input_image_content:
|
||||||
|
content[0]["text"] += input_image_content
|
||||||
|
|
||||||
|
if design_backend:
|
||||||
|
content[0]["text"] += design_backend
|
||||||
|
|
||||||
|
message_list = [{"role": "user", "content": content}]
|
||||||
|
final_messages = {"messages": message_list, "files": files}
|
||||||
|
logger.info(final_messages)
|
||||||
|
|
||||||
|
config_content_type = f"- type: {request.config_params.type}\n" if request.config_params.type else ""
|
||||||
|
config_content_region = f"- region: {request.config_params.region}\n" if request.config_params.region else ""
|
||||||
|
config_content_style = f"- style: {request.config_params.style}\n" if request.config_params.style else ""
|
||||||
|
|
||||||
|
async for stream in main_agent.astream(
|
||||||
|
final_messages,
|
||||||
|
config=current_config,
|
||||||
|
stream_mode=["updates", "messages", "custom"],
|
||||||
|
subgraphs=True,
|
||||||
|
context=Context(use_report=request.use_report,
|
||||||
|
language=request.language,
|
||||||
|
type=request.config_params.type,
|
||||||
|
region=request.config_params.region,
|
||||||
|
style=request.config_params.style,
|
||||||
|
),
|
||||||
|
):
|
||||||
|
_, mode, chunks = stream
|
||||||
|
if is_first:
|
||||||
|
checkpoint_id = get_latest_checkpoint_id(main_agent, current_config)
|
||||||
|
if not checkpoint_id:
|
||||||
|
print("123")
|
||||||
|
main_agent.store.put(
|
||||||
|
("image_history",),
|
||||||
|
"checkpoint_id",
|
||||||
|
{
|
||||||
|
"current_checkpoint_id": checkpoint_id,
|
||||||
|
"last_checkpoint_id": last_checkpoint_id,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
logger.info(f"*******************{checkpoint_id}**********************************")
|
||||||
|
yield f"data: {json.dumps({'thread_id': target_thread_id, 'is_branch': is_branching, 'status': 'start', "checkpoint_id": checkpoint_id}, ensure_ascii=False)}\n\n"
|
||||||
|
is_first = False
|
||||||
|
if mode == "updates": # 只做记录 不做事件返回
|
||||||
|
logger.info(f"[updates] -- {chunks}")
|
||||||
|
|
||||||
|
update_model_messages = chunks.get("model", None)
|
||||||
|
update_tools_messages = chunks.get("tools", None)
|
||||||
|
payload_out = {
|
||||||
|
"node": "",
|
||||||
|
"is_delta": False,
|
||||||
|
"content": "",
|
||||||
|
"type": "updates"
|
||||||
|
}
|
||||||
|
|
||||||
|
if update_model_messages:
|
||||||
|
model_messages = update_model_messages.get("messages", [])
|
||||||
|
for model_token in model_messages:
|
||||||
|
if isinstance(model_token, AIMessage):
|
||||||
|
model_name = model_token.name
|
||||||
|
payload_out.update({
|
||||||
|
"node": model_name if model_name else "main",
|
||||||
|
"tool_calls": model_token.tool_calls
|
||||||
|
})
|
||||||
|
yield f"data: {json.dumps(payload_out, ensure_ascii=False)}\n\n"
|
||||||
|
elif update_tools_messages:
|
||||||
|
tools_messages = update_tools_messages.get("messages", [])
|
||||||
|
for tools_token in tools_messages:
|
||||||
|
if isinstance(tools_token, ToolMessage):
|
||||||
|
tool_content_blocks = tools_token.content_blocks[0] if tools_token.content_blocks else None
|
||||||
|
tool_name = tools_token.name
|
||||||
|
logger.info(f"[updates] {tool_name} -- {tool_content_blocks}")
|
||||||
|
|
||||||
|
elif mode == "messages":
|
||||||
|
# logger.info(f"[messages] -- {chunks}")
|
||||||
|
|
||||||
|
token, metadata = chunks
|
||||||
|
subagent_name = metadata.get('lc_agent_name', "main")
|
||||||
|
payload_out = {
|
||||||
|
"node": subagent_name,
|
||||||
|
"is_delta": False,
|
||||||
|
"content": "",
|
||||||
|
"type": ""
|
||||||
|
}
|
||||||
|
|
||||||
|
if isinstance(token, AIMessageChunk): # 默认回复 思考内容
|
||||||
|
reasoning = [b for b in token.content_blocks if b["type"] == "reasoning"]
|
||||||
|
text = [b for b in token.content_blocks if b["type"] == "text"]
|
||||||
|
if reasoning:
|
||||||
|
if len(reasoning) == 1:
|
||||||
|
payload_out.update({
|
||||||
|
"type": "reasoning",
|
||||||
|
"is_delta": True,
|
||||||
|
"content": reasoning[0].get("reasoning", ""),
|
||||||
|
# "tool_call_chunk": token.tool_call_chunks[0] if token.tool_call_chunks else None
|
||||||
|
})
|
||||||
|
else:
|
||||||
|
logger.info(f"[reasoning] {reasoning}*************************************************************************************")
|
||||||
|
elif text:
|
||||||
|
if len(text) == 1:
|
||||||
|
payload_out.update({
|
||||||
|
"type": "text",
|
||||||
|
"is_delta": True,
|
||||||
|
"content": text[0].get("text", ""),
|
||||||
|
# "tool_call_chunk": token.tool_call_chunks[0] if token.tool_call_chunks else None
|
||||||
|
})
|
||||||
|
else:
|
||||||
|
logger.info(f"[text] {text}*************************************************************************************")
|
||||||
|
else:
|
||||||
|
payload_out.update({
|
||||||
|
"type": "tool_call",
|
||||||
|
"is_delta": True,
|
||||||
|
})
|
||||||
|
yield f"data: {json.dumps(payload_out, ensure_ascii=False)}\n\n"
|
||||||
|
elif isinstance(token, ToolMessageChunk): # 工具返回
|
||||||
|
text = [b for b in token.content_blocks if b["type"] == "text"]
|
||||||
|
payload_out.update({
|
||||||
|
"type": "tool_result",
|
||||||
|
"is_delta": False,
|
||||||
|
"content": text,
|
||||||
|
"tool_name": token.name,
|
||||||
|
})
|
||||||
|
yield f"data: {json.dumps(payload_out, ensure_ascii=False)}\n\n"
|
||||||
|
elif isinstance(token, ToolMessage): # 工具返回
|
||||||
|
text = [b for b in token.content_blocks if b["type"] == "text"]
|
||||||
|
payload_out.update({
|
||||||
|
"type": "tool_result",
|
||||||
|
"is_delta": False,
|
||||||
|
"content": text,
|
||||||
|
"tool_name": token.name,
|
||||||
|
})
|
||||||
|
yield f"data: {json.dumps(payload_out, ensure_ascii=False)}\n\n"
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
|
||||||
|
elif mode == "custom":
|
||||||
|
logger.info(f"[custom] -- {chunks}")
|
||||||
|
|
||||||
|
payload_out = {
|
||||||
|
"node": "research-agent",
|
||||||
|
"is_delta": False,
|
||||||
|
"content": "",
|
||||||
|
"type": ""
|
||||||
|
}
|
||||||
|
delta = chunks.get("delta", "")
|
||||||
|
payload_out.update({
|
||||||
|
"type": chunks.get("type", ""),
|
||||||
|
"is_delta": True,
|
||||||
|
"content": delta,
|
||||||
|
})
|
||||||
|
yield f"data: {json.dumps(payload_out, ensure_ascii=False)}\n\n"
|
||||||
|
|
||||||
|
# 获取建议消息
|
||||||
|
if request.need_suggestion > 0 and random.random() < request.need_suggestion:
|
||||||
|
suggested_questions = await generate_suggested_questions(main_agent, target_thread_id)
|
||||||
|
yield f"data: {json.dumps({'suggested_questions': suggested_questions}, ensure_ascii=False)}\n\n"
|
||||||
|
|
||||||
|
# 获取标题
|
||||||
|
if need_title:
|
||||||
|
title = await conversation_title(agent=main_agent, config=current_config)
|
||||||
|
logger.info(f"[title] {title}")
|
||||||
|
yield f"data: {json.dumps({'title': title}, ensure_ascii=False)}\n\n"
|
||||||
|
|
||||||
|
yield f"data: {json.dumps({'status': 'end'}, ensure_ascii=False)}\n\n"
|
||||||
|
|
||||||
|
return StreamingResponse(event_generator(), media_type="text/event-stream")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/history/{thread_id}", response_model=HistoryResponse)
|
||||||
|
async def get_chat_history(thread_id: str):
|
||||||
|
"""
|
||||||
|
### 获取项目设计历史记录
|
||||||
|
|
||||||
|
此接口用于拉取指定 `thread_id` 下的所有历史状态快照。它是实现 **“版本回溯”** 和 **“方案对比”** 的核心数据来源。
|
||||||
|
|
||||||
|
#### 1. 功能说明
|
||||||
|
* **快照列表**: 返回该项目从启动至今的所有关键节点(Checkpoints)。
|
||||||
|
* **版本定位**: 每个历史点都包含一个唯一的 `checkpoint_id`。
|
||||||
|
* **数据回溯**: 客户端获取此列表后,可以引导用户选择任意一个版本,并将其 `checkpoint_id` 传回 `/chat/stream` 接口以开启新的设计分支。
|
||||||
|
|
||||||
|
#### 2. 路径参数
|
||||||
|
* `thread_id`: 设计项目的唯一标识符(由 `/chat/stream` 首次调用时生成或指定)。
|
||||||
|
|
||||||
|
#### 3. 返回字段定义
|
||||||
|
* `thread_id`: 当前查询的项目ID。
|
||||||
|
* `history`: 历史记录数组,包含:
|
||||||
|
- `checkpoint_id`: 必填,回溯时使用的关键凭证。
|
||||||
|
- `last_message`: 该阶段的最后一条消息摘要(方便前端预览)。
|
||||||
|
- `node`: 产生该快照的节点名称(如 Designer, Visualizer)。
|
||||||
|
- `timestamp`: 逻辑步骤序号。
|
||||||
|
|
||||||
|
#### 4. 响应示例
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"thread_id": "proj_001",
|
||||||
|
"history": [
|
||||||
|
{
|
||||||
|
"checkpoint_id": "d82f3a12",
|
||||||
|
"last_message": "我想设计一款北欧风书架",
|
||||||
|
"node": "Supervisor",
|
||||||
|
"timestamp": 1
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"checkpoint_id": "f4k92m1a",
|
||||||
|
"last_message": "建议使用浅色橡木材质,增加简约感...",
|
||||||
|
"node": "Designer",
|
||||||
|
"timestamp": 2
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
config = {"configurable": {"thread_id": thread_id}, }
|
||||||
|
history_data = []
|
||||||
|
|
||||||
|
workspace_dir = os.path.join(PROJECT_ROOT, f"agent_workspace/{thread_id}")
|
||||||
|
main_agent = build_main_agent(workspace_dir, enable_thinking=False)
|
||||||
|
async for state in main_agent.aget_state_history(config):
|
||||||
|
msg_content = "Initial"
|
||||||
|
if state.values and "messages" in state.values:
|
||||||
|
msgs = state.values["messages"]
|
||||||
|
if msgs and len(msgs) > 0:
|
||||||
|
last_msg = msgs[-1]
|
||||||
|
# 获取内容并做摘要截断
|
||||||
|
content = getattr(last_msg, "content", str(last_msg))
|
||||||
|
msg_content = content
|
||||||
|
|
||||||
|
history_data.append(HistoryItem(
|
||||||
|
checkpoint_id=state.config["configurable"]["checkpoint_id"],
|
||||||
|
last_message=msg_content,
|
||||||
|
node=state.metadata.get("source"),
|
||||||
|
timestamp=state.metadata.get("step")
|
||||||
|
))
|
||||||
|
|
||||||
|
return HistoryResponse(thread_id=thread_id, history=history_data)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_checkpoint_id(main_agent, current_config):
|
||||||
|
# 🔥 最优:边遍历边找,找到第一个就返回,不浪费内存
|
||||||
|
async for item in main_agent.aget_state_history(config=current_config):
|
||||||
|
if item.next == ("__start__",):
|
||||||
|
# 找到直接处理并返回
|
||||||
|
# if item.parent_config:
|
||||||
|
# return item.parent_config.get('configurable', {}).get('checkpoint_id')
|
||||||
|
return item.config.get('configurable', {}).get('checkpoint_id')
|
||||||
|
# 没找到
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def get_branch_checkpoint_id(main_agent, current_config):
|
||||||
|
# 🔥 最优:边遍历边找,找到第一个就返回,不浪费内存
|
||||||
|
async for item in main_agent.aget_state_history(config=current_config):
|
||||||
|
current_id = current_config.get('configurable', {}).get('checkpoint_id')
|
||||||
|
if item.next == ("__start__",) and item.config.get('configurable', {}).get('checkpoint_id') != current_id:
|
||||||
|
if item.parent_config:
|
||||||
|
if item.parent_config.get('configurable', {}).get('checkpoint_id') != current_id:
|
||||||
|
return item.config.get('configurable', {}).get('checkpoint_id')
|
||||||
|
else:
|
||||||
|
return item.config.get('configurable', {}).get('checkpoint_id')
|
||||||
|
# 没找到
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def get_latest_checkpoint_id(agent, config):
|
||||||
|
# 先尝试直接 get_state
|
||||||
|
state = agent.get_state(config)
|
||||||
|
checkpoint_id = state.config.get("configurable", {}).get("checkpoint_id")
|
||||||
|
|
||||||
|
if checkpoint_id:
|
||||||
|
return checkpoint_id
|
||||||
|
|
||||||
|
# 如果是 None 或空,使用 history 取最新一条(history[0] 永远是最新的)
|
||||||
|
print("checkpoint_id 为 None,使用 get_state_history 兜底...")
|
||||||
|
history = list(agent.get_state_history(config))
|
||||||
|
if history:
|
||||||
|
checkpoint_id = history[0].config["configurable"]["checkpoint_id"]
|
||||||
|
print(f"从 history 获取到最新 checkpoint_id: {checkpoint_id}")
|
||||||
|
return checkpoint_id
|
||||||
|
|
||||||
|
return None
|
||||||
75
src/routers/flux2_gen_img.py
Executable file
75
src/routers/flux2_gen_img.py
Executable file
@@ -0,0 +1,75 @@
|
|||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import httpx
|
||||||
|
from fastapi import APIRouter
|
||||||
|
|
||||||
|
from src.core.config import settings
|
||||||
|
from src.schemas.flux2_gen_img import Flux2_Gen_Img_Model
|
||||||
|
from src.schemas.response_template import ResponseModel
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
router = APIRouter(prefix="/canvas", tags=["Furniture Canvas"])
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/flux2_gen_img")
|
||||||
|
async def flux2_gen_img(request_data: Flux2_Gen_Img_Model):
|
||||||
|
"""
|
||||||
|
### 参数说明:
|
||||||
|
|
||||||
|
- **bucket_name**: OSS桶名 (必填)
|
||||||
|
- **object_name**: OSS对象名(文件路径)(必填)
|
||||||
|
|
||||||
|
- **input_image_paths**: 输入图片路径列表 (非必填,默认[])
|
||||||
|
- **width**: 图片宽度,默认512像素 (非必填,默认512)
|
||||||
|
- **height**: 图片高度,默认512像素 (非必填,默认512)
|
||||||
|
- **prompt**: 文本提示词,用于模型推理等场景 (非必填,默认"")
|
||||||
|
- **steps**: 推理步数,控制模型生成过程的迭代次数 (非必填,默认4)
|
||||||
|
- **guidance**: 引导系数,调节提示词对生成结果的影响程度 (非必填,默认 4.0 )
|
||||||
|
|
||||||
|
### 请求体示例:
|
||||||
|
```
|
||||||
|
{
|
||||||
|
"input_image_paths": ["test/typical_building_space_station.png","test/typical_creature_dragon.png"],
|
||||||
|
"width": 512,
|
||||||
|
"height": 512,
|
||||||
|
"bucket_name": "test",
|
||||||
|
"object_name": "generated_images/result.jpg",
|
||||||
|
"prompt": "a beautiful landscape with mountains and rivers",
|
||||||
|
"steps": 4,
|
||||||
|
"guidance": 4.0
|
||||||
|
}
|
||||||
|
````
|
||||||
|
|
||||||
|
### 输出示例:
|
||||||
|
```
|
||||||
|
{
|
||||||
|
"code": 200,
|
||||||
|
"msg": "OK!",
|
||||||
|
"data": {
|
||||||
|
"output_path": "test/generated_images/result.jpg"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
logger.info(
|
||||||
|
f"flux2_gen_img request: {json.dumps(request_data.model_dump(), indent=4)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
async with httpx.AsyncClient(timeout=120) as client:
|
||||||
|
resp = await client.post(
|
||||||
|
f"http://{settings.FLUX2_GEN_IMG_MODEL_URL}/predict",
|
||||||
|
json=request_data.model_dump(),
|
||||||
|
)
|
||||||
|
|
||||||
|
if resp.status_code == 200:
|
||||||
|
result = resp.json()
|
||||||
|
logger.info(f"flux2_gen_img response: {json.dumps(result, indent=4)}")
|
||||||
|
return ResponseModel(data=result)
|
||||||
|
else:
|
||||||
|
error = resp.json()
|
||||||
|
logger.info(f"flux2_gen_img response: {json.dumps(error, indent=4)}")
|
||||||
|
return ResponseModel(data=error, msg="ERROR!", code=500)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"img_to_3D Run Exception: {e}")
|
||||||
|
return ResponseModel(data=e, msg="ERROR!", code=500)
|
||||||
413
src/routers/generate_3D.py
Executable file
413
src/routers/generate_3D.py
Executable file
@@ -0,0 +1,413 @@
|
|||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
from fastapi import APIRouter, BackgroundTasks
|
||||||
|
|
||||||
|
from src.core.config import settings
|
||||||
|
from src.schemas.generate_3D import ImageTo3DRequest, ToSVGRequest, Tripo3dApiModel
|
||||||
|
from src.schemas.response_template import ResponseModel
|
||||||
|
from src.server.canvas_generate_3D.server import submit_img_to_3d_task, submit_three_d_to_3views_task
|
||||||
|
from src.server.canvas_generate_3D.triop3d_api_server import create_single_task, create_multi_task, get_task_result_async, single_img_to_model_async
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/canvas", tags=["Furniture Canvas"])
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
img_to_3d_semaphore = asyncio.Semaphore(1)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/img_to_3D_v2")
|
||||||
|
async def img_to_3D_v2(request_data: ImageTo3DRequest):
|
||||||
|
"""
|
||||||
|
### 接口说明:
|
||||||
|
将图片转换为3D模型(异步处理)。接口接收请求后立即返回任务ID,后台通过 Celery 处理,处理完成后结果会通过 RabbitMQ 发送。
|
||||||
|
|
||||||
|
### 参数说明:
|
||||||
|
- **input_images**: 输入图片路径列表(支持单张或多张)
|
||||||
|
- **model**: 推理模式,`single` 表示单张图片,`multi` 表示多张图片融合
|
||||||
|
|
||||||
|
### 请求体示例:
|
||||||
|
**单张图片模式:**
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"input_images": [
|
||||||
|
"test/img_to_3d_data/example_multi_image/character_1.png"
|
||||||
|
],
|
||||||
|
"bucket_name": "test",
|
||||||
|
"user_id": "123",
|
||||||
|
"model": "single",
|
||||||
|
"task_id": "123",
|
||||||
|
"callback_url": "https://example.com/"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
**多张图片模式:**
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"input_images": [
|
||||||
|
"test/img_to_3d_data/example_multi_image/character_1.png",
|
||||||
|
"test/img_to_3d_data/example_multi_image/character_2.png",
|
||||||
|
"test/img_to_3d_data/example_multi_image/character_3.png"
|
||||||
|
],
|
||||||
|
"bucket_name": "test",
|
||||||
|
"user_id": "123",
|
||||||
|
"model": "multi",
|
||||||
|
"task_id": "123",
|
||||||
|
"callback_url": "https://example.com/"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 输出示例:
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"code": 200,
|
||||||
|
"msg": "OK!",
|
||||||
|
"data": {
|
||||||
|
"state": "success",
|
||||||
|
"task_id": "123",
|
||||||
|
"message": "任务已成功提交,正在后台处理..."
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"code": 429,
|
||||||
|
"message": "ok",
|
||||||
|
"data": {
|
||||||
|
"status": "queue_full",
|
||||||
|
"task_id": "123",
|
||||||
|
"message": "当前 3D 生成请求较多,请稍后重试。",
|
||||||
|
"queue_length": 10,
|
||||||
|
"max_length": 10
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"code": 500,
|
||||||
|
"message": "ok",
|
||||||
|
"data": {
|
||||||
|
"status": "fail",
|
||||||
|
"task_id": "123",
|
||||||
|
"message": "提交失败,请稍后重试。",
|
||||||
|
"error": str(e)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
logger.info(f"img_to_3D_v2 request: {json.dumps(request_data.model_dump(), indent=4)}")
|
||||||
|
result = submit_img_to_3d_task(
|
||||||
|
input_images=request_data.input_images, model=request_data.model,
|
||||||
|
task_id=request_data.task_id, callback_url=request_data.callback_url,
|
||||||
|
bucket_name=request_data.bucket_name, user_id=request_data.user_id
|
||||||
|
)
|
||||||
|
if result.get("state") == "success":
|
||||||
|
state_code = 200
|
||||||
|
elif result.get("state") == "queue_full":
|
||||||
|
state_code = 429
|
||||||
|
else:
|
||||||
|
state_code = 500
|
||||||
|
|
||||||
|
return ResponseModel(data=result, code=state_code)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/3d_to_3views_v2")
|
||||||
|
async def model_to_3views_v2(request_data: ToSVGRequest):
|
||||||
|
"""
|
||||||
|
### 接口说明:
|
||||||
|
将 GLB 3D 模型文件转换为 3 个视图图片(3-views),异步处理。
|
||||||
|
|
||||||
|
### 参数说明:
|
||||||
|
- **minio_glb_path**: MinIO 中 GLB 文件的完整路径
|
||||||
|
|
||||||
|
### 请求体示例:
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"minio_glb_path": "test/3d_result/glb/543570111d344552b080ff6f875e4e83.glb",
|
||||||
|
"bucket_name": "test",
|
||||||
|
"user_id": "123",
|
||||||
|
"task_id": "string",
|
||||||
|
"callback_url": "http://18.167.251.121:10015/api/image/webhook/img-to-3d"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 输出示例:
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"code": 200,
|
||||||
|
"message": "任务已提交",
|
||||||
|
"data": {
|
||||||
|
"task_id": "123",
|
||||||
|
"status": "success",
|
||||||
|
"message": "任务已进入后台处理"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"code": 429,
|
||||||
|
"message": "ok",
|
||||||
|
"data": {
|
||||||
|
"status": "queue_full",
|
||||||
|
"task_id": "123",
|
||||||
|
"message": "当前 3D 生成请求较多,请稍后重试。",
|
||||||
|
"queue_length": 10,
|
||||||
|
"max_length": 10
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"code": 500,
|
||||||
|
"message": "ok",
|
||||||
|
"data": {
|
||||||
|
"status": "fail",
|
||||||
|
"task_id": "123",
|
||||||
|
"message": "提交失败,请稍后重试。",
|
||||||
|
"error": str(e)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
logger.info(f"3d_to_3views_v2 request: {json.dumps(request_data.model_dump(), indent=4)}")
|
||||||
|
result = submit_three_d_to_3views_task(minio_glb_path=request_data.minio_glb_path, task_id=request_data.task_id, callback_url=request_data.callback_url)
|
||||||
|
if result.get("state") == "success":
|
||||||
|
state_code = 200
|
||||||
|
elif result.get("state") == "queue_full":
|
||||||
|
state_code = 429
|
||||||
|
else:
|
||||||
|
state_code = 500
|
||||||
|
|
||||||
|
return ResponseModel(data=result, code=state_code)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/triop_api_img_to_3D")
|
||||||
|
async def triop_api_img_to_3D(request_data: Tripo3dApiModel, background_tasks: BackgroundTasks):
|
||||||
|
"""
|
||||||
|
### 接口说明:
|
||||||
|
将图片转换为3D模型(异步处理)。接口接收请求后立即返回任务ID,后台通过 Celery 处理,处理完成后结果会通过 RabbitMQ 发送。
|
||||||
|
|
||||||
|
### 参数说明:
|
||||||
|
- **input_images**: 输入图片路径列表(支持单张或多张)
|
||||||
|
- **model**: 推理模式,`single` 表示单张图片,`multi` 表示多张图片融合
|
||||||
|
|
||||||
|
### 请求体示例:
|
||||||
|
**单张图片模式:**
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"input_images": [
|
||||||
|
"test/img_to_3d_data/example_multi_image/character_1.png"
|
||||||
|
],
|
||||||
|
"bucket_name": "test",
|
||||||
|
"user_id": "123",
|
||||||
|
"model": "single",
|
||||||
|
"user_id": "123",
|
||||||
|
"model": "single",
|
||||||
|
"callback_url": "http://18.167.251.121:10015/api/image/webhook/img-to-3d"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
**多张图片模式:**
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"input_images": [
|
||||||
|
"test/img_to_3d_data/example_multi_image/character_1.png",
|
||||||
|
"test/img_to_3d_data/example_multi_image/character_2.png",
|
||||||
|
"test/img_to_3d_data/example_multi_image/character_3.png"
|
||||||
|
],
|
||||||
|
"bucket_name": "test",
|
||||||
|
"user_id": "123",
|
||||||
|
"model": "multi",
|
||||||
|
"task_id": "123",
|
||||||
|
"callback_url": "http://18.167.251.121:10015/api/image/webhook/img-to-3d"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 输出示例:
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"code": 200,
|
||||||
|
"msg": "OK!",
|
||||||
|
"data": {
|
||||||
|
"state": "success",
|
||||||
|
"task_id": "8cb65855-93de-496f-95a0-d667826ad129",
|
||||||
|
"message": "任务已成功提交,正在后台处理..."
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
### 错误输出
|
||||||
|
参考文档: https://platform.tripo3d.ai/docs/error-handling
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"code": 500,
|
||||||
|
"message": "You don’t have enough credit to create this task",
|
||||||
|
"data": {
|
||||||
|
"status": "fail",
|
||||||
|
"task_id": "123",
|
||||||
|
"message": "You don’t have enough credit to create this task",
|
||||||
|
"error": str(e)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
logger.info(f"img_to_3D_v2 request: {json.dumps(request_data.model_dump(), indent=4)}")
|
||||||
|
if request_data.model == "single":
|
||||||
|
task_resp = await create_single_task(input_data=request_data)
|
||||||
|
else:
|
||||||
|
task_resp = await create_multi_task(input_data=request_data)
|
||||||
|
|
||||||
|
if task_resp.get("code") == 0:
|
||||||
|
api_task_id = task_resp.get("data").get("task_id")
|
||||||
|
logger.info(f"{request_data, request_data.task_id, api_task_id, request_data.callback_url}")
|
||||||
|
background_tasks.add_task(get_task_result_async, request_data, request_data.task_id, api_task_id, request_data.callback_url)
|
||||||
|
result = {
|
||||||
|
"state": "success",
|
||||||
|
"task_id": request_data.task_id,
|
||||||
|
"message": "任务已成功提交,正在后台处理...",
|
||||||
|
}
|
||||||
|
state_code = 200
|
||||||
|
return ResponseModel(data=result, code=state_code)
|
||||||
|
else:
|
||||||
|
data = {
|
||||||
|
"status": "fail",
|
||||||
|
"task_id": request_data.task_id,
|
||||||
|
"message": task_resp.get("message"),
|
||||||
|
"error": task_resp.get("message")
|
||||||
|
}
|
||||||
|
logger.info(data)
|
||||||
|
return ResponseModel(data=data, code=500, msg=task_resp.get("message", ""))
|
||||||
|
|
||||||
|
# @router.post("/img_to_3D")
|
||||||
|
# async def img_to_3D(request_data: ImageTo3DRequest):
|
||||||
|
# """
|
||||||
|
# ### 参数说明:
|
||||||
|
# - **input_images**:输入图片list,单张或多张
|
||||||
|
# - **model**: 推理模式,单张或多张
|
||||||
|
# ### 请求体示例:
|
||||||
|
# ```json
|
||||||
|
# 单张
|
||||||
|
# {
|
||||||
|
# "input_images": ["test/img_to_3d_data/example_multi_image/character_1.png"],
|
||||||
|
# "model": "single"
|
||||||
|
# }
|
||||||
|
#
|
||||||
|
# 多张
|
||||||
|
# {
|
||||||
|
# "input_imaes": [
|
||||||
|
# "test/img_to_3d_data/example_multi_image/character_1.png",
|
||||||
|
# "test/img_to_3d_data/example_multi_image/character_2.png",
|
||||||
|
# "test/img_to_3d_data/example_multi_image/character_3.png"
|
||||||
|
#
|
||||||
|
# ],
|
||||||
|
# "model": "multi"
|
||||||
|
# }
|
||||||
|
# ```
|
||||||
|
# ### 输出示例:
|
||||||
|
# ```json
|
||||||
|
# {
|
||||||
|
# "glb_path": "test/3d_result/glb/5ebe2fe118c94946bdc379e4d44799d2.glb",
|
||||||
|
# "glb_static_img_path": "test/3d_result/png/19c4b60ab7594e3f84e58d0169739bd1.png",
|
||||||
|
# "glb_info": {
|
||||||
|
# "file_format": ".glb",
|
||||||
|
# "vertex_count": 7312,
|
||||||
|
# "centroid": [
|
||||||
|
# 0.0010040254158151611,
|
||||||
|
# -0.10831894948487081,
|
||||||
|
# 0.07473365460649548
|
||||||
|
# ],
|
||||||
|
# "bounding_box_min": [
|
||||||
|
# -0.23948338627815247,
|
||||||
|
# -0.38543057441711426,
|
||||||
|
# -0.5015472769737244
|
||||||
|
# ],
|
||||||
|
# "bounding_box_max": [
|
||||||
|
# 0.228701651096344,
|
||||||
|
# 0.37523990869522095,
|
||||||
|
# 0.49702101945877075
|
||||||
|
# ],
|
||||||
|
# "size": [
|
||||||
|
# 0.46818503737449646,
|
||||||
|
# 0.7606704831123352,
|
||||||
|
# 0.9985682964324951
|
||||||
|
# ],
|
||||||
|
# "size_ratio": [
|
||||||
|
# 0.21019126841430072,
|
||||||
|
# 0.34150235681882596,
|
||||||
|
# 0.4483063747668733
|
||||||
|
# ],
|
||||||
|
# "size_ratio_percentage": [
|
||||||
|
# 21.019126841430072,
|
||||||
|
# 34.1502356818826,
|
||||||
|
# 44.83063747668733
|
||||||
|
# ]
|
||||||
|
# }
|
||||||
|
# }
|
||||||
|
# ```
|
||||||
|
# """
|
||||||
|
# try:
|
||||||
|
# logger.info(
|
||||||
|
# f"img_to_3D request: {json.dumps(request_data.dict(), indent=4)}"
|
||||||
|
# )
|
||||||
|
#
|
||||||
|
# input_data = {
|
||||||
|
# "image_paths": request_data.input_images,
|
||||||
|
# "model": request_data.model,
|
||||||
|
# }
|
||||||
|
#
|
||||||
|
# async with httpx.AsyncClient(timeout=120) as client:
|
||||||
|
# resp = await client.post(
|
||||||
|
# f"http://{settings.IMAGE_TO_3D_MODEL_URL}/canvas/img_to_3D",
|
||||||
|
# json=input_data
|
||||||
|
# )
|
||||||
|
#
|
||||||
|
# result = resp.json()
|
||||||
|
#
|
||||||
|
# logger.info(f"img_to_3D response: {json.dumps(result, indent=4)}")
|
||||||
|
#
|
||||||
|
# return ResponseModel(data=result)
|
||||||
|
#
|
||||||
|
# except Exception as e:
|
||||||
|
# logger.warning(f"img_to_3D Run Exception: {e}")
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# @router.post("/3d_to_3views")
|
||||||
|
# async def to_3views(request_data: ToSVGRequest):
|
||||||
|
# """
|
||||||
|
# ### 参数说明:
|
||||||
|
# - **minio_glb_path**:glb文件路径
|
||||||
|
#
|
||||||
|
# ### 请求体示例:
|
||||||
|
# ```json
|
||||||
|
# {
|
||||||
|
# "minio_glb_path": "test/3d_result/glb/543570111d344552b080ff6f875e4e83.glb"
|
||||||
|
# }
|
||||||
|
# ```
|
||||||
|
# ### 输出示例:
|
||||||
|
# ```json
|
||||||
|
# {
|
||||||
|
# "minio_svg_path": "test/3d_result/svg/bbcd534cffa143bba418148a0db80ad0.svg"
|
||||||
|
# }
|
||||||
|
# ```
|
||||||
|
# """
|
||||||
|
# try:
|
||||||
|
# logger.info(
|
||||||
|
# f"img_to_3D request: {json.dumps(request_data.dict(), indent=4)}"
|
||||||
|
# )
|
||||||
|
#
|
||||||
|
# input_data = {
|
||||||
|
# "minio_glb_path": request_data.minio_glb_path,
|
||||||
|
# }
|
||||||
|
#
|
||||||
|
# async with httpx.AsyncClient(timeout=120) as client:
|
||||||
|
# resp = await client.post(
|
||||||
|
# f"http://{settings.IMAGE_TO_3D_MODEL_URL}/canvas/3d_to_3views",
|
||||||
|
# json=input_data
|
||||||
|
# )
|
||||||
|
#
|
||||||
|
# result = resp.json()
|
||||||
|
#
|
||||||
|
# logger.info(f"img_to_3D response: {json.dumps(result, indent=4)}")
|
||||||
|
#
|
||||||
|
# return ResponseModel(data=result)
|
||||||
|
#
|
||||||
|
# except Exception as e:
|
||||||
|
# logger.warning(f"img_to_3D Run Exception: {e}")
|
||||||
61
src/routers/seg_furniture.py
Executable file
61
src/routers/seg_furniture.py
Executable file
@@ -0,0 +1,61 @@
|
|||||||
|
import json
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import requests
|
||||||
|
from fastapi import APIRouter
|
||||||
|
|
||||||
|
from src.core.config import settings
|
||||||
|
from src.schemas.response_template import ResponseModel
|
||||||
|
from src.schemas.san_furniture import SAMRequestModel
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/canvas", tags=["Furniture Canvas"])
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/seg_anything")
|
||||||
|
async def seg_anything(request_data: SAMRequestModel):
|
||||||
|
"""
|
||||||
|
**Segment Anything 交互式分割接口**
|
||||||
|
|
||||||
|
通过传入图片路径和点击的点坐标,返回分割后的掩码数据。
|
||||||
|
|
||||||
|
### 参数说明:
|
||||||
|
- **bucket**: minio bucket name
|
||||||
|
- **object_name**: minio object name
|
||||||
|
- **image_path**: 图片在服务器或云端的相对路径。
|
||||||
|
- **type**: 推理类型
|
||||||
|
- **box**: 框选矩形点位信息
|
||||||
|
- **points**: 交互点的坐标列表。每个点为 [x, y] 像素格式。
|
||||||
|
- **labels**: 坐标点的属性标签,必须与 points 长度一致:
|
||||||
|
- 1: **前景点** (代表想要分割出的区域)
|
||||||
|
- 0: **背景点** (代表想要排除的区域)
|
||||||
|
|
||||||
|
### 请求体示例:
|
||||||
|
```json
|
||||||
|
point
|
||||||
|
{
|
||||||
|
"bucket": "test",
|
||||||
|
"object_name": "7068-400a-ac94-c01647fa5f6f.png",
|
||||||
|
"image_path": "aida-users/89/sketch/4e8fe37d-7068-400a-ac94-c01647fa5f6f.png",
|
||||||
|
"type":"point",
|
||||||
|
"points": [[310, 403], [493, 375], [261, 266], [404, 484]],
|
||||||
|
"labels": [1, 1, 0, 1]
|
||||||
|
}
|
||||||
|
|
||||||
|
box
|
||||||
|
{
|
||||||
|
"bucket": "test",
|
||||||
|
"object_name": "7068-400a-ac94-c01647fa5f6f.png",
|
||||||
|
"image_path": "aida-users/89/sketch/4e8fe37d-7068-400a-ac94-c01647fa5f6f.png",
|
||||||
|
"type":"box",
|
||||||
|
"box": [350, 286, 544, 520]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
logger.info(f"seg_anything request item is : @@@@@@:{json.dumps(request_data.dict(), indent=4)}")
|
||||||
|
data = requests.post(f"http://{settings.SEG_ANYTHING}/predict", json=request_data.dict())
|
||||||
|
logger.info(f"seg_anything response @@@@@@:{json.dumps(json.loads(data.content), indent=4)}")
|
||||||
|
return ResponseModel(data=json.loads(data.content))
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"seg_anything Run Exception @@@@@@:{e}")
|
||||||
0
src/schemas/__init__.py
Normal file → Executable file
0
src/schemas/__init__.py
Normal file → Executable file
31
src/schemas/canvas_assistant.py
Executable file
31
src/schemas/canvas_assistant.py
Executable file
@@ -0,0 +1,31 @@
|
|||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
# ====================== 请求模型 ======================
|
||||||
|
class TriggerRequest(BaseModel):
|
||||||
|
action: Literal["enter_canvas", "tool_trigger"] = Field(
|
||||||
|
...,
|
||||||
|
description="操作类型:enter_canvas = 进入画布,tool_trigger = 点击工具"
|
||||||
|
)
|
||||||
|
|
||||||
|
tool_name: str | None = Field(
|
||||||
|
None,
|
||||||
|
description="当 action=tool_trigger 时必填。支持的工具:to_real_style, surface_edit_canvas, surface_edit_ai, color_palette, scene_composition, 3d_model, to_3d_view"
|
||||||
|
)
|
||||||
|
|
||||||
|
language: Literal["zh", "en"] = Field(
|
||||||
|
"zh",
|
||||||
|
description="返回语言:zh=中文,en=英文"
|
||||||
|
)
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
json_schema_extra = {
|
||||||
|
"example": {
|
||||||
|
"session_id": "canvas_20260331_001",
|
||||||
|
"action": "tool_trigger",
|
||||||
|
"tool_name": "3d_model",
|
||||||
|
"language": "zh"
|
||||||
|
}
|
||||||
|
}
|
||||||
3
src/schemas/chat.py
Normal file → Executable file
3
src/schemas/chat.py
Normal file → Executable file
@@ -14,7 +14,8 @@ class ChatRequest(BaseModel):
|
|||||||
thread_id: Optional[str] = Field(None, description="会话线程ID,不传则开启新会话")
|
thread_id: Optional[str] = Field(None, description="会话线程ID,不传则开启新会话")
|
||||||
checkpoint_id: Optional[str] = Field(None, description="回溯点的ID,用于从历史点开启新对话")
|
checkpoint_id: Optional[str] = Field(None, description="回溯点的ID,用于从历史点开启新对话")
|
||||||
config_params: Optional[AgentConfig] = None
|
config_params: Optional[AgentConfig] = None
|
||||||
need_suggestion: bool = False
|
need_suggestion: float = 0
|
||||||
|
use_report: bool = False # ← 新增:是否使用深度报告
|
||||||
|
|
||||||
|
|
||||||
class HistoryItem(BaseModel):
|
class HistoryItem(BaseModel):
|
||||||
|
|||||||
43
src/schemas/deep_agent_chat.py
Executable file
43
src/schemas/deep_agent_chat.py
Executable file
@@ -0,0 +1,43 @@
|
|||||||
|
from pydantic import BaseModel, Field, confloat
|
||||||
|
from typing import Optional, List, Dict, Any
|
||||||
|
|
||||||
|
|
||||||
|
class AgentConfig(BaseModel):
|
||||||
|
type: str = Field(..., description="家具类型,如:沙发、餐桌")
|
||||||
|
region: str = Field(..., description="地区/空间,如:客厅、卧室、户外")
|
||||||
|
style: str = Field(..., description="设计风格,如:极简、工业风、中式")
|
||||||
|
temperature: confloat(ge=0, le=2.0) = Field(default=0.7, description="模型温度")
|
||||||
|
|
||||||
|
|
||||||
|
class DeepAgentChatRequest(BaseModel):
|
||||||
|
message: str = Field(..., description="用户的输入指令")
|
||||||
|
enable_thinking: Optional[bool] = Field(default=False, description="是否开启思考模式")
|
||||||
|
quote_image_path: Optional[str] = Field(None, description="引用图片地址") # ✅ 新增
|
||||||
|
input_image_paths: Optional[list[str]] = Field(None, description="上传图片地址集合") # ✅ 新增
|
||||||
|
thread_id: Optional[str] = Field(None, description="会话线程ID,不传则开启新会话")
|
||||||
|
checkpoint_id: Optional[str] = Field(None, description="回溯点的ID,用于从历史点开启新对话")
|
||||||
|
config_params: Optional[AgentConfig] = Field(
|
||||||
|
default=None,
|
||||||
|
description="Agent 配置参数(type/region/style 等)"
|
||||||
|
)
|
||||||
|
need_suggestion: float = 0
|
||||||
|
use_report: bool = False # ← 新增:是否使用深度报告
|
||||||
|
language: str = "en"
|
||||||
|
|
||||||
|
|
||||||
|
class HistoryItem(BaseModel):
|
||||||
|
checkpoint_id: str
|
||||||
|
last_message: Any
|
||||||
|
node: Optional[str]
|
||||||
|
timestamp: Any
|
||||||
|
|
||||||
|
|
||||||
|
class HistoryResponse(BaseModel):
|
||||||
|
thread_id: str
|
||||||
|
history: List[HistoryItem]
|
||||||
|
|
||||||
|
|
||||||
|
class StreamChunk(BaseModel):
|
||||||
|
node: str
|
||||||
|
content: str
|
||||||
|
checkpoint_id: str
|
||||||
14
src/schemas/flux2_gen_img.py
Executable file
14
src/schemas/flux2_gen_img.py
Executable file
@@ -0,0 +1,14 @@
|
|||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class Flux2_Gen_Img_Model(BaseModel):
|
||||||
|
bucket_name: str = Field(..., description="OSS桶名,不传则为None")
|
||||||
|
object_name: str = Field(..., description="OSS对象名(文件路径),不传则为None")
|
||||||
|
input_image_paths: Optional[List[str]] = Field(default=[], description="输入图片路径列表")
|
||||||
|
width: Optional[int] = Field(default=512, description="图片宽度,默认512像素")
|
||||||
|
height: Optional[int] = Field(default=512, description="图片高度,默认512像素")
|
||||||
|
prompt: Optional[str] = Field(default="", description="文本提示词,用于模型推理等场景")
|
||||||
|
steps: Optional[int] = Field(default=4, description="推理步数,控制模型生成过程的迭代次数")
|
||||||
|
guidance: Optional[float] = Field(default=4.0, description="引导系数,调节提示词对生成结果的影响程度")
|
||||||
50
src/schemas/generate_3D.py
Executable file
50
src/schemas/generate_3D.py
Executable file
@@ -0,0 +1,50 @@
|
|||||||
|
from pydantic import BaseModel, Field, confloat, HttpUrl
|
||||||
|
from typing import Optional, List, Dict, Any
|
||||||
|
|
||||||
|
|
||||||
|
class ImageTo3DRequest(BaseModel):
|
||||||
|
input_images: List[str] = Field(..., description="输入图片路径列表")
|
||||||
|
model: str = Field(default="single", description="模型类型: single 或 multi")
|
||||||
|
bucket_name: str = Field(..., description="输入图片路径列表")
|
||||||
|
user_id: str = Field(..., description="用户id")
|
||||||
|
task_id: str = Field(...)
|
||||||
|
callback_url: str # 必填,客户端提供的回调地址
|
||||||
|
|
||||||
|
|
||||||
|
class ToSVGRequest(BaseModel):
|
||||||
|
minio_glb_path: str = Field(..., description="输入图片路径列表")
|
||||||
|
bucket_name: str = Field(..., description="输入图片路径列表")
|
||||||
|
user_id: str = Field(..., description="用户id")
|
||||||
|
task_id: str = Field(...)
|
||||||
|
callback_url: str # 必填
|
||||||
|
|
||||||
|
|
||||||
|
class Tripo3dApiModel(BaseModel):
|
||||||
|
input_images: List[str] = Field(..., description="输入图片路径列表")
|
||||||
|
bucket_name: str = Field(..., description="输入图片路径列表")
|
||||||
|
user_id: str = Field(..., description="用户id")
|
||||||
|
callback_url: str # 必填,客户端提供的回调地址
|
||||||
|
task_id: str = Field()
|
||||||
|
model: str = Field(default="single", description="模型类型: single 或 multi")
|
||||||
|
|
||||||
|
model_version: Optional[str] = Field(default="v3.1-20260211", description="Model version, e.g. v3.1-20260211 / v3.0-20250812 / v2.5-20250123")
|
||||||
|
poll_interval: Optional[float] = Field(default=2.0, description="Polling interval (seconds)")
|
||||||
|
poll_timeout: Optional[float] = Field(default=1800.0, description="Max polling time (seconds)")
|
||||||
|
request_timeout: Optional[float] = Field(default=120.0, description="HTTP request timeout (seconds)")
|
||||||
|
texture: Optional[bool] = Field(default=True, description="是否生成纹理")
|
||||||
|
pbr: Optional[bool] = Field(default=True, description="是否生成 PBR 材质")
|
||||||
|
texture_quality: Optional[str] = Field(default="standard", description="Texture quality: standard / detailed")
|
||||||
|
texture_alignment: Optional[str] = Field(default="original_image", description="Texture alignment mode: original_image / geometry")
|
||||||
|
orientation: Optional[str] = Field(default="default", description="Orientation mode: default / align_image")
|
||||||
|
face_limit: Optional[int] = Field(default=None, description="限制输出模型的面数")
|
||||||
|
model_seed: Optional[int] = Field(default=None, description="模型生成随机种子")
|
||||||
|
texture_seed: Optional[int] = Field(default=None, description="纹理生成随机种子")
|
||||||
|
auto_size: Optional[str] = Field(default=None, description="Auto size option")
|
||||||
|
quad: Optional[str] = Field(default=None, description="Enable quad remeshing")
|
||||||
|
compress: Optional[str] = Field(default=None, description="Compress option")
|
||||||
|
generate_parts: Optional[str] = Field(default=None, description="Generate segmented parts")
|
||||||
|
smart_low_poly: Optional[str] = Field(default=None, description="Smart low poly optimization")
|
||||||
|
download_outputs: Optional[bool] = Field(default=True, description="是否下载输出文件(现在改为上传到 MinIO)")
|
||||||
|
save_task_json: Optional[bool] = Field(default=True, description="是否保存 task JSON")
|
||||||
|
print_payload: Optional[bool] = Field(default=False, description="是否打印请求 payload")
|
||||||
|
print_output: Optional[bool] = Field(default=True, description="是否打印输出结果")
|
||||||
8
src/schemas/response_template.py
Executable file
8
src/schemas/response_template.py
Executable file
@@ -0,0 +1,8 @@
|
|||||||
|
from pydantic import BaseModel
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
|
||||||
|
class ResponseModel(BaseModel):
|
||||||
|
code: int = 200
|
||||||
|
msg: str = "OK!"
|
||||||
|
data: Optional[Any] = None
|
||||||
13
src/schemas/san_furniture.py
Executable file
13
src/schemas/san_furniture.py
Executable file
@@ -0,0 +1,13 @@
|
|||||||
|
from typing import Optional, List
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class SAMRequestModel(BaseModel):
|
||||||
|
bucket: str = Field(..., description="minio bucket name ")
|
||||||
|
object_name: str = Field(..., description="minio object name ")
|
||||||
|
image_path: str = Field(..., description="图片路径,必填字段")
|
||||||
|
type: str = Field(..., description="推理类型,必填字段")
|
||||||
|
points: Optional[List[List[float]]] | None = None
|
||||||
|
labels: Optional[List[int]] | None = None
|
||||||
|
box: Optional[List[int]] | None = None
|
||||||
0
src/server/__init__.py
Normal file → Executable file
0
src/server/__init__.py
Normal file → Executable file
@@ -1,118 +0,0 @@
|
|||||||
import os
|
|
||||||
|
|
||||||
from google.oauth2 import service_account
|
|
||||||
from langchain_core.messages import HumanMessage, SystemMessage, ToolMessage, AIMessage
|
|
||||||
from langchain_core.runnables import RunnableConfig
|
|
||||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
|
||||||
from src.server.agent.state import AgentState
|
|
||||||
from src.server.agent.tools import generate_2025_report_tool, generate_furniture_sketch
|
|
||||||
from src.server.agent.config_loader import get_agent_prompt
|
|
||||||
from src.core.config import settings
|
|
||||||
from src.server.utils.generate_suggestion import generate_chat_suggestions
|
|
||||||
|
|
||||||
creds = service_account.Credentials.from_service_account_file(
|
|
||||||
settings.GOOGLE_GENAI_USE_VERTEXAI,
|
|
||||||
scopes=["https://www.googleapis.com/auth/cloud-platform"],
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# 辅助函数:根据配置动态获取 LLM
|
|
||||||
def get_model(config: RunnableConfig):
|
|
||||||
# 从 configurable 中获取温度,默认为 0.5 (对应你之前的设置)
|
|
||||||
# 这个 key 必须与你在 chat_stream 路由里定义的 "llm_temperature" 一致
|
|
||||||
temp = config["configurable"].get("llm_temperature", 0.5)
|
|
||||||
|
|
||||||
return ChatGoogleGenerativeAI(
|
|
||||||
model="gemini-2.0-flash",
|
|
||||||
temperature=temp,
|
|
||||||
credentials=creds,
|
|
||||||
project=settings.GOOGLE_CLOUD_PROJECT,
|
|
||||||
location=settings.GOOGLE_CLOUD_LOCATION,
|
|
||||||
vertexai=True,
|
|
||||||
api_key=settings.GOOGLE_API_KEY
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# --- 1. Designer Agent (设计顾问) ---
|
|
||||||
async def designer_node(state: AgentState, config: RunnableConfig):
|
|
||||||
"""负责细化设计需求,提供专业参数"""
|
|
||||||
model = get_model(config) # 获取带动态温度的模型
|
|
||||||
|
|
||||||
messages = state["messages"]
|
|
||||||
system_text = get_agent_prompt("designer")
|
|
||||||
|
|
||||||
system_prompt = SystemMessage(content=system_text)
|
|
||||||
should_suggest = len(state["messages"]) % 5 == 0
|
|
||||||
# 改为异步调用 ainvoke
|
|
||||||
response = await model.ainvoke([system_prompt] + messages)
|
|
||||||
return {"messages": [response], "require_suggestion": should_suggest}
|
|
||||||
|
|
||||||
|
|
||||||
# --- 2. Researcher Agent (情报专家) ---
|
|
||||||
async def researcher_node(state: AgentState, config: RunnableConfig):
|
|
||||||
"""负责调用报告生成工具"""
|
|
||||||
model = get_model(config)
|
|
||||||
tools = [generate_2025_report_tool]
|
|
||||||
llm_with_tools = model.bind_tools(tools)
|
|
||||||
|
|
||||||
messages = state["messages"]
|
|
||||||
system_text = get_agent_prompt("researcher")
|
|
||||||
system_prompt = SystemMessage(content=system_text)
|
|
||||||
response = await llm_with_tools.ainvoke([system_prompt] + messages)
|
|
||||||
|
|
||||||
if response.tool_calls:
|
|
||||||
tool_call = response.tool_calls[0]
|
|
||||||
if tool_call["name"] == "generate_2025_report_tool":
|
|
||||||
# 这里的工具调用如果也是异步的,建议加 await
|
|
||||||
result = await generate_2025_report_tool.ainvoke(tool_call["args"])
|
|
||||||
return {"messages": [response, HumanMessage(content=str(result))]}
|
|
||||||
|
|
||||||
return {"messages": [response]}
|
|
||||||
|
|
||||||
|
|
||||||
# --- 3. Visualizer Agent (视觉专家) ---
|
|
||||||
async def visualizer_node(state: AgentState, config: RunnableConfig):
|
|
||||||
"""负责将自然语言转化为绘图 Prompt 并调用绘图工具"""
|
|
||||||
model = get_model(config)
|
|
||||||
tools = [generate_furniture_sketch]
|
|
||||||
llm_with_tools = model.bind_tools(tools)
|
|
||||||
|
|
||||||
messages = state["messages"]
|
|
||||||
system_text = get_agent_prompt("visualizer")
|
|
||||||
|
|
||||||
system_prompt = SystemMessage(content=system_text)
|
|
||||||
response = await llm_with_tools.ainvoke([system_prompt] + messages)
|
|
||||||
|
|
||||||
if response.tool_calls:
|
|
||||||
tool_call = response.tool_calls[0]
|
|
||||||
if tool_call["name"] == "generate_furniture_sketch":
|
|
||||||
img_url = await generate_furniture_sketch.ainvoke(tool_call["args"])
|
|
||||||
return {
|
|
||||||
"messages": [
|
|
||||||
response,
|
|
||||||
ToolMessage(content=img_url, tool_call_id=tool_call["id"]) # 标记这是一个图片结果
|
|
||||||
]
|
|
||||||
}
|
|
||||||
|
|
||||||
return {"messages": [response]}
|
|
||||||
|
|
||||||
|
|
||||||
# --- 4. Suggester Agent (推荐对话专家) ---
|
|
||||||
async def suggester_node(state: AgentState, config: RunnableConfig):
|
|
||||||
"""专门生成追问建议的节点,作为流程终点"""
|
|
||||||
model = get_model(config)
|
|
||||||
messages = state["messages"]
|
|
||||||
|
|
||||||
# 只需要分析最近的对话
|
|
||||||
suggestions = await generate_chat_suggestions(messages, model)
|
|
||||||
|
|
||||||
# 返回一个特殊消息,前端通过解析 additional_kwargs 获取按钮内容
|
|
||||||
return {
|
|
||||||
"messages": [
|
|
||||||
AIMessage(
|
|
||||||
content="",
|
|
||||||
additional_kwargs={"suggestions": suggestions},
|
|
||||||
name="Suggester"
|
|
||||||
)
|
|
||||||
]
|
|
||||||
}
|
|
||||||
@@ -1,32 +0,0 @@
|
|||||||
"""加载项目根目录下的 config.yaml 并提供 agent prompt 访问接口。"""
|
|
||||||
import os
|
|
||||||
from functools import lru_cache
|
|
||||||
from typing import Any, Dict, Optional
|
|
||||||
|
|
||||||
import yaml
|
|
||||||
|
|
||||||
|
|
||||||
def _project_root() -> str:
|
|
||||||
return os.path.abspath(os.path.join(os.path.dirname(__file__), "../..", ".."))
|
|
||||||
|
|
||||||
|
|
||||||
@lru_cache(maxsize=1)
|
|
||||||
def load_config() -> Dict[str, Any]:
|
|
||||||
path = os.path.join(_project_root(), "config.yaml")
|
|
||||||
if not os.path.exists(path):
|
|
||||||
return {}
|
|
||||||
with open(path, "r", encoding="utf-8") as f:
|
|
||||||
return yaml.safe_load(f) or {}
|
|
||||||
|
|
||||||
|
|
||||||
def get_agent_prompt(agent_name: str) -> Optional[str]:
|
|
||||||
cfg = load_config()
|
|
||||||
agents = cfg.get("agents", {})
|
|
||||||
entry = agents.get(agent_name, {})
|
|
||||||
prompt = entry.get("prompt_template") or entry.get("prompt")
|
|
||||||
return prompt
|
|
||||||
|
|
||||||
|
|
||||||
def get_model_config() -> Dict[str, Any]:
|
|
||||||
cfg = load_config()
|
|
||||||
return cfg.get("model", {})
|
|
||||||
@@ -1,101 +0,0 @@
|
|||||||
import os
|
|
||||||
from typing import Literal
|
|
||||||
|
|
||||||
from google.oauth2 import service_account
|
|
||||||
from langchain_core.messages import AIMessage
|
|
||||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
|
||||||
from langgraph.graph import StateGraph, END, START
|
|
||||||
from pydantic import BaseModel
|
|
||||||
from pymongo import MongoClient
|
|
||||||
|
|
||||||
from src.core.config import settings, MONGO_URI
|
|
||||||
from src.server.agent.state import AgentState
|
|
||||||
from src.server.agent.agents import designer_node, researcher_node, visualizer_node, suggester_node
|
|
||||||
from langgraph.checkpoint.mongodb import MongoDBSaver
|
|
||||||
|
|
||||||
|
|
||||||
# --- Supervisor (路由逻辑) ---
|
|
||||||
# 定义路由的输出结构,强制 LLM 选择一个
|
|
||||||
class RouteResponse(BaseModel):
|
|
||||||
# 将 FINISH 替换或增加 Suggester
|
|
||||||
next: Literal["Designer", "Researcher", "Visualizer", "Suggester", "FINISH"]
|
|
||||||
|
|
||||||
|
|
||||||
creds = service_account.Credentials.from_service_account_file(
|
|
||||||
settings.GOOGLE_GENAI_USE_VERTEXAI,
|
|
||||||
scopes=["https://www.googleapis.com/auth/cloud-platform"],
|
|
||||||
)
|
|
||||||
|
|
||||||
llm_supervisor = ChatGoogleGenerativeAI(
|
|
||||||
model="gemini-2.0-flash", credentials=creds,
|
|
||||||
project="aida-461108", location='us-central1', vertexai=True, api_key=settings.GOOGLE_API_KEY
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def supervisor_node(state: AgentState):
|
|
||||||
messages = state["messages"]
|
|
||||||
if not messages:
|
|
||||||
return {"next": "Suggester"}
|
|
||||||
|
|
||||||
last_message = messages[-1]
|
|
||||||
|
|
||||||
# --- 拦截逻辑修改 ---
|
|
||||||
# 如果专家已经回复完了(AIMessage 且无工具调用),则交给 Suggester 生成按钮
|
|
||||||
if isinstance(last_message, AIMessage) and not last_message.tool_calls:
|
|
||||||
should_go_to_suggester = state.get("require_suggestion", False)
|
|
||||||
|
|
||||||
# 如果符合建议条件
|
|
||||||
if should_go_to_suggester:
|
|
||||||
return {"next": "Suggester"}
|
|
||||||
else:
|
|
||||||
return {"next": "FINISH"}
|
|
||||||
|
|
||||||
system_prompt = """你是家具设计主管。分配任务给专家:
|
|
||||||
- Designer: 设计建议、参数细化。
|
|
||||||
- Visualizer: 绘图需求。
|
|
||||||
- Researcher: 市场报告。
|
|
||||||
"""
|
|
||||||
|
|
||||||
chain = llm_supervisor.with_structured_output(RouteResponse)
|
|
||||||
decision = chain.invoke([{"role": "system", "content": system_prompt}] + messages)
|
|
||||||
return {"next": decision.next}
|
|
||||||
|
|
||||||
|
|
||||||
# --- 构建 Graph ---
|
|
||||||
workflow = StateGraph(AgentState)
|
|
||||||
|
|
||||||
workflow.add_node("Supervisor", supervisor_node)
|
|
||||||
workflow.add_node("Designer", designer_node)
|
|
||||||
workflow.add_node("Researcher", researcher_node)
|
|
||||||
workflow.add_node("Visualizer", visualizer_node)
|
|
||||||
workflow.add_node("Suggester", suggester_node) # 新增节点
|
|
||||||
|
|
||||||
workflow.add_edge(START, "Supervisor")
|
|
||||||
|
|
||||||
# 修改条件边映射
|
|
||||||
workflow.add_conditional_edges(
|
|
||||||
"Supervisor",
|
|
||||||
lambda state: state["next"],
|
|
||||||
{
|
|
||||||
"Designer": "Designer",
|
|
||||||
"Researcher": "Researcher",
|
|
||||||
"Visualizer": "Visualizer",
|
|
||||||
"Suggester": "Suggester", # 原本的 FINISH 现在指向 Suggester
|
|
||||||
"FINISH": END # 直接结束,不给建议
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# 专家执行完依然回到 Supervisor
|
|
||||||
workflow.add_edge("Designer", "Supervisor")
|
|
||||||
workflow.add_edge("Researcher", "Supervisor")
|
|
||||||
workflow.add_edge("Visualizer", "Supervisor")
|
|
||||||
# 重点:Suggester 可以是整个流程的终点
|
|
||||||
workflow.add_edge("Suggester", END)
|
|
||||||
|
|
||||||
client = MongoClient(MONGO_URI)
|
|
||||||
checkpointer = MongoDBSaver(
|
|
||||||
client=client["furniture_agent_db"],
|
|
||||||
db_name="langgraph",
|
|
||||||
collection_name="checkpoints"
|
|
||||||
)
|
|
||||||
app = workflow.compile(checkpointer=checkpointer)
|
|
||||||
@@ -1,49 +0,0 @@
|
|||||||
from langchain_core.messages import HumanMessage, AIMessage
|
|
||||||
from src.server.agent.graph import app
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
# 模拟 thread_id 区分不同用户或项目
|
|
||||||
config = {"configurable": {"thread_id": "project_alpha"}}
|
|
||||||
|
|
||||||
while True:
|
|
||||||
user_input = input("\n👤 设计师 (输入 'history' 定位轮次): ")
|
|
||||||
|
|
||||||
# --- 官方推荐的异步回溯逻辑 ---
|
|
||||||
if user_input.lower() == "history":
|
|
||||||
print("\n--- 历史记录 ---")
|
|
||||||
for state in app.get_state_history(config):
|
|
||||||
# 每一个 state 都是一个 CheckpointTuple
|
|
||||||
cp_id = state.config["configurable"]["checkpoint_id"]
|
|
||||||
msg = state.values["messages"][-1].content[:30] if state.values.get("messages") else "Initial"
|
|
||||||
print(f"ID: {cp_id} | 内容: {msg}...")
|
|
||||||
|
|
||||||
target_id = input("\n请输入想要回溯的 Checkpoint ID (直接回车取消): ")
|
|
||||||
if target_id:
|
|
||||||
# 重新配置 config,指向特定的 checkpoint_id 实现分支
|
|
||||||
config = {"configurable": {"thread_id": "project_alpha", "checkpoint_id": target_id}}
|
|
||||||
print(f"✅ 已定位到节点 {target_id},后续对话将从此分叉。")
|
|
||||||
continue
|
|
||||||
|
|
||||||
# --- 官方推荐的 astream 异步流式调用 ---
|
|
||||||
print("🤖 Agent 思考中...")
|
|
||||||
for event in app.stream(
|
|
||||||
{"messages": [HumanMessage(content=user_input)]},
|
|
||||||
config,
|
|
||||||
stream_mode="values" # 这里设为 values 可以直接获取当前状态的消息列表
|
|
||||||
):
|
|
||||||
# 获取当前节点处理后的最新消息
|
|
||||||
if "messages" in event:
|
|
||||||
last_msg = event["messages"][-1]
|
|
||||||
if isinstance(last_msg, AIMessage):
|
|
||||||
# 为了极致流式体验,可以在此处对 content 进行打印
|
|
||||||
pass
|
|
||||||
|
|
||||||
# 运行结束后,最新的状态已经自动持久化到 MongoDB
|
|
||||||
# 我们可以通过 app.get_state(config) 验证
|
|
||||||
final_state = app.get_state(config)
|
|
||||||
print(f"\n✅ 最终回复: {final_state.values['messages'][-1].content}")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@@ -1,10 +0,0 @@
|
|||||||
import operator
|
|
||||||
from typing import Annotated, Sequence, TypedDict, Union
|
|
||||||
from langchain_core.messages import BaseMessage
|
|
||||||
|
|
||||||
class AgentState(TypedDict):
|
|
||||||
# messages 存储完整的对话历史,operator.add 表示新消息是追加而不是覆盖
|
|
||||||
messages: Annotated[Sequence[BaseMessage], operator.add]
|
|
||||||
# next 存储 Supervisor 决定的下一步是谁
|
|
||||||
next: str
|
|
||||||
require_suggestion: bool # 是否需要建议按钮
|
|
||||||
@@ -1,114 +0,0 @@
|
|||||||
import base64
|
|
||||||
import uuid
|
|
||||||
from google.oauth2 import service_account
|
|
||||||
from langchain_core.tools import tool
|
|
||||||
from google import genai
|
|
||||||
from google.genai.types import GenerateContentConfig, Modality
|
|
||||||
from PIL import Image
|
|
||||||
from io import BytesIO
|
|
||||||
|
|
||||||
from minio import Minio
|
|
||||||
|
|
||||||
from src.core.config import settings
|
|
||||||
from src.server.utils.new_oss_client import oss_upload_image
|
|
||||||
|
|
||||||
# 初始化全局凭证和客户端
|
|
||||||
creds = service_account.Credentials.from_service_account_file(
|
|
||||||
settings.GOOGLE_GENAI_USE_VERTEXAI,
|
|
||||||
scopes=["https://www.googleapis.com/auth/cloud-platform"],
|
|
||||||
)
|
|
||||||
|
|
||||||
minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)
|
|
||||||
client = genai.Client(
|
|
||||||
credentials=creds,
|
|
||||||
project=settings.GOOGLE_CLOUD_PROJECT,
|
|
||||||
location=settings.GOOGLE_CLOUD_LOCATION,
|
|
||||||
vertexai=True
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# --- 模拟你已经开发好的报告生成功能 ---
|
|
||||||
@tool
|
|
||||||
def generate_2025_report_tool(topic: str) -> str:
|
|
||||||
"""
|
|
||||||
专门用于收集信息并生成报告。
|
|
||||||
当用户询问关于趋势、市场分析、年度报告(如2025家具报告)时调用此工具。
|
|
||||||
"""
|
|
||||||
print(f"\n[系统日志] 正在调用外部模块生成关于 '{topic}' 的报告...")
|
|
||||||
# 这里对接你实际的代码,比如:return my_existing_module.run(topic)
|
|
||||||
return f"【报告生成成功】已生成关于 {topic} 的 PDF 报告。核心洞察:2025年趋势倾向于生物嗜好设计(Biophilic Design)和可持续软木材质。"
|
|
||||||
|
|
||||||
|
|
||||||
# --- 2. 绘图工具 (接入 Nano Banana 逻辑) ---
|
|
||||||
@tool
|
|
||||||
def generate_furniture_sketch(prompt: str) -> str:
|
|
||||||
"""
|
|
||||||
使用 Gemini 图像生成模型根据详细的英文提示词生成家具设计草图。
|
|
||||||
"""
|
|
||||||
print(f"\n[系统日志] 正在调用 Nano Banana (Gemini Image Gen) ...")
|
|
||||||
|
|
||||||
try:
|
|
||||||
response = client.models.generate_content(
|
|
||||||
model="gemini-2.5-flash-image",
|
|
||||||
contents=(f"Generate a professional furniture design sketch: {prompt}"),
|
|
||||||
config=GenerateContentConfig(
|
|
||||||
response_modalities=[Modality.TEXT, Modality.IMAGE],
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
image_bytes = None
|
|
||||||
for part in response.candidates[0].content.parts:
|
|
||||||
if part.inline_data:
|
|
||||||
image_bytes = part.inline_data.data
|
|
||||||
break
|
|
||||||
|
|
||||||
if not image_bytes:
|
|
||||||
return "未能生成图像数据。"
|
|
||||||
object_name = f"furniture/sketches/{uuid.uuid4()}.png"
|
|
||||||
bucket = "fida-test" # 替换为你的 bucket 名称
|
|
||||||
# 3. 调用你的上传函数
|
|
||||||
upload_res = oss_upload_image(
|
|
||||||
oss_client=minio_client,
|
|
||||||
bucket=bucket,
|
|
||||||
object_name=object_name,
|
|
||||||
image_bytes=image_bytes
|
|
||||||
)
|
|
||||||
|
|
||||||
if upload_res:
|
|
||||||
# 4. 构造访问链接 (如果是私有 bucket,需使用 presigned_get_object)
|
|
||||||
# 这里简单示例为直接访问地址
|
|
||||||
image_url = f"{bucket}/{object_name}"
|
|
||||||
return image_url
|
|
||||||
else:
|
|
||||||
return "图片生成成功,但上传至存储服务器失败。"
|
|
||||||
except Exception as e:
|
|
||||||
return f"绘图流程异常: {str(e)}"
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
print(generate_furniture_sketch("椅子"))
|
|
||||||
# creds = service_account.Credentials.from_service_account_file(
|
|
||||||
# settings.GOOGLE_GENAI_USE_VERTEXAI,
|
|
||||||
# scopes=["https://www.googleapis.com/auth/cloud-platform"],
|
|
||||||
# )
|
|
||||||
# client = genai.Client(
|
|
||||||
# credentials=creds,
|
|
||||||
# project=settings.GOOGLE_CLOUD_PROJECT,
|
|
||||||
# location=settings.GOOGLE_CLOUD_LOCATION,
|
|
||||||
# vertexai=True
|
|
||||||
# )
|
|
||||||
#
|
|
||||||
# response = client.models.generate_content(
|
|
||||||
# model="gemini-2.5-flash-image",
|
|
||||||
# contents=("Generate an image of the Eiffel tower with fireworks in the background."),
|
|
||||||
# config=GenerateContentConfig(
|
|
||||||
# response_modalities=[Modality.TEXT, Modality.IMAGE],
|
|
||||||
# ),
|
|
||||||
# )
|
|
||||||
#
|
|
||||||
# for part in response.candidates[0].content.parts:
|
|
||||||
# if part.text:
|
|
||||||
# print(part.text)
|
|
||||||
# elif part.inline_data:
|
|
||||||
# image = Image.open(BytesIO((part.inline_data.data)))
|
|
||||||
# image.save("example-image-eiffel-tower.png")
|
|
||||||
0
README.md → src/server/canvas_assistant/__init__.py
Normal file → Executable file
0
README.md → src/server/canvas_assistant/__init__.py
Normal file → Executable file
74
src/server/canvas_assistant/graph.py
Executable file
74
src/server/canvas_assistant/graph.py
Executable file
@@ -0,0 +1,74 @@
|
|||||||
|
from langgraph.graph import StateGraph, START, END
|
||||||
|
from langgraph.graph.message import MessagesState
|
||||||
|
from langchain_core.messages import AIMessage
|
||||||
|
|
||||||
|
# ====================== 中英文固定文案 ======================
|
||||||
|
PROMPTS = {
|
||||||
|
# 中文
|
||||||
|
"welcome_zh": "Hi,我是你的设计助手 Fiphant 👋 我来帮你快速上手这个画布。我给你准备了两个起点——你可以用 To Real Style 直接把草图变成效果图,也可以先用 Surface Edit 换个材质或贴上印花。有了产品图之后,我们再一起配色、配场景、看 3D 效果,最后导出三视图就完成了。我建议先从 To Real Style 开始,看看整体感觉 ✨",
|
||||||
|
|
||||||
|
"to_real_style_zh": "To Real Style 🎨 这个功能我很喜欢——你只需要把草图丢进来,我来帮你把光影和材质都处理好,直接生成真实感效果图。出来不满意的话就多试几次,每次我都会给你不一样的结果。",
|
||||||
|
"surface_edit_canvas_zh": "Surface Edit(Canvas 模式) 🪡 如果你对材质有具体想法,可以用这个模式来做。布艺、皮革、木材这些都可以换,也可以把你自己的印花上传进来。这个模式支持你手动精细编辑,想细调哪里都可以。",
|
||||||
|
"surface_edit_ai_zh": "Surface Edit(AI 模式) 🪡 想快速看看换材质之后的效果?用这个模式,把你想要的材质或印花告诉我,AI 智能贴图帮你一步到位。如果觉得还想再调整细节,随时可以切换到 Canvas 模式继续编辑。",
|
||||||
|
"color_palette_zh": "Color Palette 🎨 配色交给我来帮你——你选几种喜欢的颜色,我来帮你搭配应用到产品上。我可以一次给你生成好几个方案,你对比着挑就好。",
|
||||||
|
"scene_composition_zh": "Scene Composition 🛋️ 我来帮你把产品放进一个真实的空间场景里,光影我会自动帮你匹配。我的建议是先出一张背景干净的主图,再出一张有生活感的氛围图,两张配合着用,展示效果会好很多。",
|
||||||
|
"3d_model_zh": "3D Model 🔄 我把你的效果图变成可以转着看的立体模型,你可以从各个角度检查一下结构。我建议重点看看转角、腿脚比例和座面厚度——这几个地方在草图里不容易发现问题,但打样的时候最容易出偏差,现在发现比较好改。",
|
||||||
|
"to_3d_view_zh": "To 3D View 📐 我们到最后一步了!我来帮你把 3D 模型导出为前视图、侧视图和俯视图!",
|
||||||
|
|
||||||
|
# English
|
||||||
|
"welcome_en": "Hi, I'm Fiphant, your design assistant 👋 I'm here to help you get started with the canvas. My suggestion: start with To Real Style ✨",
|
||||||
|
|
||||||
|
"to_real_style_en": "To Real Style 🎨 This is one of my favorite features — just drop in your sketch and I'll handle the lighting and materials to create a photorealistic render.",
|
||||||
|
"surface_edit_canvas_en": "Surface Edit (Canvas Mode) 🪡 Perfect for precise material control. You can swap fabrics, leather, wood, or upload your own prints.",
|
||||||
|
"surface_edit_ai_en": "Surface Edit (AI Mode) 🪡 Want quick material preview? Tell me what you want and I'll apply it instantly with AI.",
|
||||||
|
"color_palette_en": "Color Palette 🎨 Let me handle the colors for you. Pick your favorite colors and I'll generate multiple schemes.",
|
||||||
|
"scene_composition_en": "Scene Composition 🛋️ I'll place your product in a real scene with matching lighting.",
|
||||||
|
"3d_model_en": "3D Model 🔄 I'll turn your render into a rotatable 3D model. Check corners, leg proportions, and seat thickness carefully.",
|
||||||
|
"to_3d_view_en": "To 3D View 📐 Final step! I'll export your 3D model as front, side, and top views."
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class AgentState(MessagesState):
|
||||||
|
trigger: str | None = None
|
||||||
|
language: str = "zh"
|
||||||
|
is_first_enter: bool = True # 新增标志位,控制是否输出欢迎语
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== Nodes ====================
|
||||||
|
def assistant_node(state: AgentState):
|
||||||
|
"""路由节点:判断是进入画布还是点击工具"""
|
||||||
|
if state.get("is_first_enter", True):
|
||||||
|
# 第一次进入画布
|
||||||
|
lang = state.get("language", "zh")
|
||||||
|
key = f"welcome_{lang}"
|
||||||
|
content = PROMPTS.get(key, PROMPTS["welcome_zh"])
|
||||||
|
return {
|
||||||
|
"messages": [AIMessage(content=content)],
|
||||||
|
"is_first_enter": False
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
# 点击工具
|
||||||
|
trigger = state.get("trigger")
|
||||||
|
lang = state.get("language", "zh")
|
||||||
|
|
||||||
|
if trigger:
|
||||||
|
key = f"{trigger}_{lang}"
|
||||||
|
content = PROMPTS.get(key, "功能说明加载中...")
|
||||||
|
else:
|
||||||
|
content = "请点击工具让我为你说明用法。" if lang == "zh" else "Please click a tool for instructions."
|
||||||
|
|
||||||
|
return {
|
||||||
|
"messages": [AIMessage(content=content)],
|
||||||
|
"is_first_enter": False
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== Graph ====================
|
||||||
|
workflow = StateGraph(state_schema=AgentState)
|
||||||
|
|
||||||
|
workflow.add_node("assistant", assistant_node)
|
||||||
|
|
||||||
|
workflow.add_edge(START, "assistant")
|
||||||
|
workflow.add_edge("assistant", END)
|
||||||
|
|
||||||
|
graph = workflow.compile()
|
||||||
0
src/server/canvas_generate_3D/__init__.py
Executable file
0
src/server/canvas_generate_3D/__init__.py
Executable file
34
src/server/canvas_generate_3D/callback.py
Normal file
34
src/server/canvas_generate_3D/callback.py
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
import httpx
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def notify_callback(callback_url: str, task_id: str, status: str, result: dict, ):
|
||||||
|
"""
|
||||||
|
调用客户端提供的回调接口
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
payload = {
|
||||||
|
"task_id": task_id,
|
||||||
|
"status": status,
|
||||||
|
"result": result
|
||||||
|
}
|
||||||
|
logger.info(payload)
|
||||||
|
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||||
|
resp = await client.post(
|
||||||
|
str(callback_url),
|
||||||
|
json=payload,
|
||||||
|
headers={"Content-Type": "application/json"}
|
||||||
|
)
|
||||||
|
|
||||||
|
if resp.status_code >= 200 and resp.status_code < 300:
|
||||||
|
logger.info(f"回调成功 | task_id: {task_id} | status: {status} | url: {callback_url}")
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
logger.warning(f"回调返回非2xx状态码 | task_id: {task_id} | status: {resp.status_code} | url: {callback_url}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"回调失败 | task_id: {task_id} | url: {callback_url} | error: {e}", exc_info=True)
|
||||||
|
return False
|
||||||
60
src/server/canvas_generate_3D/celery_app.py
Executable file
60
src/server/canvas_generate_3D/celery_app.py
Executable file
@@ -0,0 +1,60 @@
|
|||||||
|
# src/server/canvas_generate_3D/celery_app.py
|
||||||
|
from celery import Celery
|
||||||
|
from kombu import Queue, Exchange
|
||||||
|
from src.core.config import settings
|
||||||
|
|
||||||
|
celery_app = Celery(
|
||||||
|
"canvas_generate_3d",
|
||||||
|
broker=settings.RABBITMQ_URL,
|
||||||
|
backend=f"redis://{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB}",
|
||||||
|
include=["src.server.canvas_generate_3D.tasks"],
|
||||||
|
)
|
||||||
|
|
||||||
|
celery_app.conf.update(
|
||||||
|
task_serializer="json",
|
||||||
|
accept_content=["json"],
|
||||||
|
result_serializer="json",
|
||||||
|
timezone="Asia/Hong_Kong",
|
||||||
|
enable_utc=True,
|
||||||
|
|
||||||
|
# ==================== 修改 Exchange 名称 ====================
|
||||||
|
task_default_exchange="canvas_3d_exchange", # ← 修改这里
|
||||||
|
task_default_exchange_type="direct",
|
||||||
|
|
||||||
|
# 定义队列
|
||||||
|
task_queues=(
|
||||||
|
Queue("img_to_3d_queue",
|
||||||
|
exchange=Exchange("canvas_3d_exchange", type="direct"),
|
||||||
|
durable=True),
|
||||||
|
Queue("three_d_to_3views_queue",
|
||||||
|
exchange=Exchange("canvas_3d_exchange", type="direct"),
|
||||||
|
durable=True),
|
||||||
|
),
|
||||||
|
|
||||||
|
# 任务路由
|
||||||
|
task_routes={
|
||||||
|
'src.server.canvas_generate_3D.tasks.img_to_3d_task': {
|
||||||
|
'queue': 'img_to_3d_queue',
|
||||||
|
'exchange': 'canvas_3d_exchange', # ← 修改这里
|
||||||
|
},
|
||||||
|
'src.server.canvas_generate_3D.tasks.three_d_to_3views_task': {
|
||||||
|
'queue': 'three_d_to_3views_queue',
|
||||||
|
'exchange': 'canvas_3d_exchange', # ← 修改这里
|
||||||
|
},
|
||||||
|
},
|
||||||
|
|
||||||
|
task_default_queue="img_to_3d_queue",
|
||||||
|
|
||||||
|
worker_concurrency=1,
|
||||||
|
worker_prefetch_multiplier=1,
|
||||||
|
worker_max_tasks_per_child=1,
|
||||||
|
task_acks_late=True,
|
||||||
|
task_reject_on_worker_lost=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@celery_app.on_after_configure.connect
|
||||||
|
def setup_periodic_tasks(sender, **kwargs):
|
||||||
|
print("✅ Celery 已启动,以下任务已注册:")
|
||||||
|
for task_name in sorted(sender.tasks.keys()):
|
||||||
|
print(f" - {task_name}")
|
||||||
101
src/server/canvas_generate_3D/server.py
Executable file
101
src/server/canvas_generate_3D/server.py
Executable file
@@ -0,0 +1,101 @@
|
|||||||
|
from pydantic import HttpUrl
|
||||||
|
|
||||||
|
from src.server.canvas_generate_3D.celery_app import celery_app # ← 改成这行
|
||||||
|
from src.server.canvas_generate_3D.tasks import img_to_3d_task, three_d_to_3views_task
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def get_queue_length(queue_name: str) -> int:
|
||||||
|
"""获取指定队列当前待处理消息数量(更可靠的方式)"""
|
||||||
|
try:
|
||||||
|
with celery_app.connection() as conn:
|
||||||
|
with conn.channel() as channel:
|
||||||
|
# passive=True:只查询,不创建队列
|
||||||
|
queue_info = channel.queue_declare(
|
||||||
|
queue=queue_name,
|
||||||
|
passive=True,
|
||||||
|
durable=True
|
||||||
|
)
|
||||||
|
return queue_info.message_count
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"获取队列长度失败 {queue_name}: {e}")
|
||||||
|
return 0 # 失败时默认不拒绝提交,防止误判
|
||||||
|
|
||||||
|
|
||||||
|
def submit_img_to_3d_task(input_images: list, model: str = "single", task_id: str = "", callback_url: str = "", bucket_name: str = "test", user_id: str = "123"):
|
||||||
|
"""提交 img_to_3D 任务(带队列长度限制)"""
|
||||||
|
queue_name = "img_to_3d_queue"
|
||||||
|
max_queue_length = 10
|
||||||
|
|
||||||
|
try:
|
||||||
|
current_length = get_queue_length(queue_name)
|
||||||
|
|
||||||
|
if current_length >= max_queue_length:
|
||||||
|
return {
|
||||||
|
"state": "queue_full",
|
||||||
|
"message": "当前 3D 生成请求较多,请稍后重试。",
|
||||||
|
"queue_length": current_length,
|
||||||
|
"max_length": max_queue_length
|
||||||
|
}
|
||||||
|
|
||||||
|
# 提交任务
|
||||||
|
task = img_to_3d_task.apply_async(
|
||||||
|
args=(input_images, model, callback_url, bucket_name, user_id),
|
||||||
|
task_id=task_id,
|
||||||
|
queue="img_to_3d_queue")
|
||||||
|
|
||||||
|
logger.info(f"img_to_3d_task 已提交 | task_id: {task_id} | 当前队列长度: {current_length}")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"state": "success",
|
||||||
|
"task_id": task_id,
|
||||||
|
"message": "任务已成功提交,正在后台处理...",
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"提交 img_to_3d_task 失败: {e}", exc_info=True)
|
||||||
|
return {
|
||||||
|
"state": "fail",
|
||||||
|
"message": "提交失败,请稍后重试。",
|
||||||
|
"error": str(e)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def submit_three_d_to_3views_task(minio_glb_path: str, task_id: str = "", callback_url: str = "", bucket_name: str = "test", user_id: str = "123"):
|
||||||
|
"""提交 3D转 3 视图 任务(带队列长度限制)"""
|
||||||
|
queue_name = "three_d_to_3views_task" # ← 必须和 @shared_task 中的 queue 完全一致!
|
||||||
|
max_queue_length = 3
|
||||||
|
|
||||||
|
try:
|
||||||
|
current_length = get_queue_length(queue_name)
|
||||||
|
|
||||||
|
if current_length >= max_queue_length:
|
||||||
|
return {
|
||||||
|
"state": "queue_full",
|
||||||
|
"message": "当前 3视图 生成请求较多,请稍后重试。",
|
||||||
|
"queue_length": current_length,
|
||||||
|
"max_length": max_queue_length
|
||||||
|
}
|
||||||
|
|
||||||
|
task = three_d_to_3views_task.apply_async(
|
||||||
|
args=(minio_glb_path, callback_url, bucket_name, user_id),
|
||||||
|
task_id=task_id,
|
||||||
|
queue="three_d_to_3views_queue")
|
||||||
|
|
||||||
|
logger.info(f"three_d_to_3views_task 已提交 | task_id: {task_id} | 当前队列长度: {current_length}")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"state": "success",
|
||||||
|
"task_id": task_id,
|
||||||
|
"message": "任务已成功提交,正在后台处理...",
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"提交 three_d_to_3views_task 失败: {e}", exc_info=True)
|
||||||
|
return {
|
||||||
|
"state": "fail",
|
||||||
|
"message": "提交失败,请稍后重试。",
|
||||||
|
"error": str(e)
|
||||||
|
}
|
||||||
116
src/server/canvas_generate_3D/tasks.py
Executable file
116
src/server/canvas_generate_3D/tasks.py
Executable file
@@ -0,0 +1,116 @@
|
|||||||
|
# src/server/canvas_generate_3D/tasks.py
|
||||||
|
import asyncio
|
||||||
|
from celery import shared_task
|
||||||
|
import httpx
|
||||||
|
from src.core.config import settings
|
||||||
|
from src.server.canvas_generate_3D.callback import notify_callback
|
||||||
|
from src.server.utils.mq_util import send_to_rabbitmq
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@shared_task(bind=True, queue="img_to_3d_queue", max_retries=3, name='src.server.canvas_generate_3D.tasks.img_to_3d_task')
|
||||||
|
def img_to_3d_task(self, input_images: list, model: str = "single", callback_url: str = None, bucket_name: str = "test", user_id: str = "123"):
|
||||||
|
"""img_to_3D 主任务"""
|
||||||
|
task_id = self.request.id
|
||||||
|
logger.info(f"开始处理 img_to_3D 任务 | task_id: {task_id}")
|
||||||
|
try:
|
||||||
|
input_data = {
|
||||||
|
"image_paths": input_images,
|
||||||
|
"model": model,
|
||||||
|
"bucket_name": bucket_name,
|
||||||
|
"user_id": user_id
|
||||||
|
}
|
||||||
|
with httpx.Client(timeout=300.0) as client:
|
||||||
|
resp = client.post(
|
||||||
|
f"http://{settings.IMAGE_TO_3D_MODEL_URL}/canvas/img_to_3D",
|
||||||
|
json=input_data
|
||||||
|
)
|
||||||
|
status_code = resp.status_code
|
||||||
|
result = resp.json()
|
||||||
|
logger.info(f"img_to_3D 任务处理完成 | task_id: {task_id} | status_code : {status_code} | result: {result}")
|
||||||
|
# 发送到对应的回调接口
|
||||||
|
if status_code == 200:
|
||||||
|
asyncio.run(
|
||||||
|
notify_callback(
|
||||||
|
callback_url=callback_url,
|
||||||
|
task_id=task_id,
|
||||||
|
status="completed",
|
||||||
|
result=result,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
asyncio.run(
|
||||||
|
notify_callback(
|
||||||
|
callback_url=callback_url,
|
||||||
|
task_id=task_id,
|
||||||
|
status="failed",
|
||||||
|
result={},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error(f"img_to_3D 任务失败 | task_id: {task_id} | exc {exc}", exc_info=True)
|
||||||
|
asyncio.run(
|
||||||
|
notify_callback(
|
||||||
|
callback_url=callback_url,
|
||||||
|
task_id=task_id,
|
||||||
|
status="failed",
|
||||||
|
result=result,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
raise self.retry(exc=exc, countdown=60, max_retries=3)
|
||||||
|
|
||||||
|
|
||||||
|
@shared_task(bind=True, queue="three_d_to_3views_queue", max_retries=3, name='src.server.canvas_generate_3D.tasks.three_d_to_3views_task')
|
||||||
|
def three_d_to_3views_task(self, minio_glb_path: str, callback_url: str, bucket_name: str = "test", user_id: str = "123"):
|
||||||
|
"""3D to 3views 主任务"""
|
||||||
|
task_id = self.request.id
|
||||||
|
logger.info(f"开始处理 three_d_to_3views_task | task_id: {task_id}")
|
||||||
|
try:
|
||||||
|
input_data = {
|
||||||
|
"minio_glb_path": minio_glb_path,
|
||||||
|
"bucket_name": bucket_name, "user_id": user_id
|
||||||
|
}
|
||||||
|
with httpx.Client(timeout=1200) as client:
|
||||||
|
resp = client.post(
|
||||||
|
f"http://{settings.IMAGE_TO_3D_MODEL_URL}/canvas/3d_to_3views",
|
||||||
|
json=input_data
|
||||||
|
)
|
||||||
|
status_code = resp.status_code
|
||||||
|
result = resp.json()
|
||||||
|
logger.info(f"three_d_to_3views_task 任务处理完成 | task_id: {task_id} | status_code : {status_code} | result: {result}")
|
||||||
|
# 发送到对应的回调接口
|
||||||
|
if status_code == 200:
|
||||||
|
asyncio.run(
|
||||||
|
notify_callback(
|
||||||
|
callback_url=callback_url,
|
||||||
|
task_id=task_id,
|
||||||
|
status="completed",
|
||||||
|
result=result,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
asyncio.run(
|
||||||
|
notify_callback(
|
||||||
|
callback_url=callback_url,
|
||||||
|
task_id=task_id,
|
||||||
|
status="failed",
|
||||||
|
result={},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error(f"three_d_to_3views_task 任务失败 | task_id: {task_id}", exc_info=True)
|
||||||
|
asyncio.run(
|
||||||
|
notify_callback(
|
||||||
|
callback_url=callback_url,
|
||||||
|
task_id=task_id,
|
||||||
|
status="failed",
|
||||||
|
result={},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
raise self.retry(exc=exc, countdown=60, max_retries=3)
|
||||||
474
src/server/canvas_generate_3D/triop3d_api.py
Normal file
474
src/server/canvas_generate_3D/triop3d_api.py
Normal file
@@ -0,0 +1,474 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import sys
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
import argparse
|
||||||
|
import mimetypes
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, Iterator, Tuple
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
BASE_URL = "https://api.tripo3d.ai/v2/openapi"
|
||||||
|
|
||||||
|
|
||||||
|
class TripoAPIError(RuntimeError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def build_parser():
|
||||||
|
p = argparse.ArgumentParser("Tripo3D CLI: single image -> 3D")
|
||||||
|
|
||||||
|
# I/O
|
||||||
|
p.add_argument("-i", "--image", required=True, help="Input image path")
|
||||||
|
p.add_argument("-o", "--out_dir", default="tripo_outputs", help="Output directory")
|
||||||
|
|
||||||
|
# Auth
|
||||||
|
p.add_argument(
|
||||||
|
"--api_key",
|
||||||
|
default=os.getenv("TRIPO_API_KEY", "tcli_50ecbff125084d4db958b1863ec082e6"),
|
||||||
|
help="Tripo API key, or set env TRIPO_API_KEY",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Model
|
||||||
|
p.add_argument(
|
||||||
|
"--model_version",
|
||||||
|
type=str,
|
||||||
|
default="v3.1-20260211",
|
||||||
|
help="Model version, e.g. P1-20260311 / v3.1-20260211 / v3.0-20250812 / v2.5-20250123",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Network / polling
|
||||||
|
p.add_argument("--poll_interval", type=float, default=2.0, help="Polling interval (seconds)")
|
||||||
|
p.add_argument("--poll_timeout", type=float, default=1800.0, help="Max polling time (seconds)")
|
||||||
|
p.add_argument("--request_timeout", type=float, default=120.0, help="HTTP request timeout (seconds)")
|
||||||
|
|
||||||
|
# Generation options
|
||||||
|
p.add_argument("--texture", dest="texture", action="store_true", default=True)
|
||||||
|
p.add_argument("--no-texture", dest="texture", action="store_false")
|
||||||
|
|
||||||
|
p.add_argument("--pbr", dest="pbr", action="store_true", default=True)
|
||||||
|
p.add_argument("--no-pbr", dest="pbr", action="store_false")
|
||||||
|
|
||||||
|
p.add_argument(
|
||||||
|
"--texture_quality",
|
||||||
|
type=str,
|
||||||
|
default="standard",
|
||||||
|
choices=["standard", "detailed"],
|
||||||
|
help="Texture quality",
|
||||||
|
)
|
||||||
|
p.add_argument(
|
||||||
|
"--texture_alignment",
|
||||||
|
type=str,
|
||||||
|
default="original_image",
|
||||||
|
choices=["original_image", "geometry"],
|
||||||
|
help="Texture alignment mode",
|
||||||
|
)
|
||||||
|
p.add_argument(
|
||||||
|
"--orientation",
|
||||||
|
type=str,
|
||||||
|
default="default",
|
||||||
|
choices=["default", "align_image"],
|
||||||
|
help="Orientation mode",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Optional params
|
||||||
|
p.add_argument("--face_limit", type=int, default=None)
|
||||||
|
p.add_argument("--model_seed", type=int, default=None)
|
||||||
|
p.add_argument("--texture_seed", type=int, default=None)
|
||||||
|
p.add_argument("--auto_size", type=str, default=None)
|
||||||
|
p.add_argument("--quad", type=str, default=None)
|
||||||
|
p.add_argument("--compress", type=str, default=None)
|
||||||
|
p.add_argument("--generate_parts", type=str, default=None)
|
||||||
|
p.add_argument("--smart_low_poly", type=str, default=None)
|
||||||
|
|
||||||
|
# Save / download toggles
|
||||||
|
p.add_argument("--download_outputs", dest="download_outputs", action="store_true", default=True)
|
||||||
|
p.add_argument("--no-download_outputs", dest="download_outputs", action="store_false")
|
||||||
|
|
||||||
|
p.add_argument("--save_task_json", dest="save_task_json", action="store_true", default=True)
|
||||||
|
p.add_argument("--no-save_task_json", dest="save_task_json", action="store_false")
|
||||||
|
|
||||||
|
p.add_argument("--print_payload", dest="print_payload", action="store_true", default=False)
|
||||||
|
p.add_argument("--print_output", dest="print_output", action="store_true", default=True)
|
||||||
|
p.add_argument("--no-print_output", dest="print_output", action="store_false")
|
||||||
|
|
||||||
|
return p
|
||||||
|
|
||||||
|
|
||||||
|
def guess_mime_type(file_path: Path) -> str:
|
||||||
|
mime, _ = mimetypes.guess_type(str(file_path))
|
||||||
|
return mime or "application/octet-stream"
|
||||||
|
|
||||||
|
|
||||||
|
def safe_filename(name: str) -> str:
|
||||||
|
name = re.sub(r'[\\/:*?"<>|]+', "_", name)
|
||||||
|
name = re.sub(r"\s+", "_", name).strip("._")
|
||||||
|
return name or "file"
|
||||||
|
|
||||||
|
|
||||||
|
def extract_error_message(payload: Any) -> str:
|
||||||
|
if isinstance(payload, dict):
|
||||||
|
for key in ("message", "error", "error_message", "detail"):
|
||||||
|
if payload.get(key):
|
||||||
|
return str(payload[key])
|
||||||
|
|
||||||
|
data = payload.get("data")
|
||||||
|
if isinstance(data, dict):
|
||||||
|
for key in ("message", "error", "error_message", "detail"):
|
||||||
|
if data.get(key):
|
||||||
|
return str(data[key])
|
||||||
|
|
||||||
|
return json.dumps(payload, ensure_ascii=False)[:800]
|
||||||
|
|
||||||
|
return str(payload)[:800]
|
||||||
|
|
||||||
|
|
||||||
|
def request_json(
|
||||||
|
session: requests.Session,
|
||||||
|
method: str,
|
||||||
|
endpoint: str,
|
||||||
|
request_timeout: float,
|
||||||
|
**kwargs,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
url = f"{BASE_URL}{endpoint}"
|
||||||
|
|
||||||
|
try:
|
||||||
|
resp = session.request(method=method, url=url, timeout=request_timeout, **kwargs)
|
||||||
|
except requests.RequestException as e:
|
||||||
|
raise TripoAPIError(f"请求失败: {method} {url} | {e}") from e
|
||||||
|
|
||||||
|
if not resp.ok:
|
||||||
|
try:
|
||||||
|
err_payload = resp.json()
|
||||||
|
except Exception:
|
||||||
|
err_payload = resp.text
|
||||||
|
raise TripoAPIError(
|
||||||
|
f"HTTP {resp.status_code} | {method} {url} | {extract_error_message(err_payload)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
payload = resp.json()
|
||||||
|
except Exception as e:
|
||||||
|
raise TripoAPIError(
|
||||||
|
f"响应不是合法 JSON: {method} {url}\n原始响应前 500 字符:\n{resp.text[:500]}"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
return payload
|
||||||
|
|
||||||
|
|
||||||
|
def create_session(api_key: str) -> requests.Session:
|
||||||
|
session = requests.Session()
|
||||||
|
session.headers.update({
|
||||||
|
"Authorization": f"Bearer {api_key}",
|
||||||
|
"Accept": "application/json",
|
||||||
|
})
|
||||||
|
return session
|
||||||
|
|
||||||
|
|
||||||
|
def upload_image(session: requests.Session, image_path: Path, request_timeout: float) -> str:
|
||||||
|
if not image_path.exists():
|
||||||
|
raise FileNotFoundError(f"找不到图片: {image_path}")
|
||||||
|
|
||||||
|
with image_path.open("rb") as f:
|
||||||
|
files = {
|
||||||
|
"file": (image_path.name, f, guess_mime_type(image_path))
|
||||||
|
}
|
||||||
|
payload = request_json(
|
||||||
|
session,
|
||||||
|
"POST",
|
||||||
|
"/upload",
|
||||||
|
request_timeout=request_timeout,
|
||||||
|
files=files,
|
||||||
|
)
|
||||||
|
|
||||||
|
data = payload.get("data") or {}
|
||||||
|
file_token = data.get("image_token")
|
||||||
|
|
||||||
|
if not file_token:
|
||||||
|
raise TripoAPIError(f"上传成功但未返回 image_token: {json.dumps(payload, ensure_ascii=False)}")
|
||||||
|
|
||||||
|
return file_token
|
||||||
|
|
||||||
|
|
||||||
|
def build_generation_payload(args, file_token: str, image_path: Path) -> Dict[str, Any]:
|
||||||
|
file_ext = image_path.suffix.lower().lstrip(".") or "png"
|
||||||
|
|
||||||
|
payload: Dict[str, Any] = {
|
||||||
|
"type": "image_to_model",
|
||||||
|
"model_version": args.model_version,
|
||||||
|
"file": {
|
||||||
|
"type": file_ext,
|
||||||
|
"file_token": file_token,
|
||||||
|
},
|
||||||
|
"texture": args.texture,
|
||||||
|
"pbr": args.pbr,
|
||||||
|
"texture_quality": args.texture_quality,
|
||||||
|
"texture_alignment": args.texture_alignment,
|
||||||
|
"orientation": args.orientation,
|
||||||
|
}
|
||||||
|
|
||||||
|
optional_fields = [
|
||||||
|
"face_limit",
|
||||||
|
"model_seed",
|
||||||
|
"texture_seed",
|
||||||
|
"auto_size",
|
||||||
|
"quad",
|
||||||
|
"compress",
|
||||||
|
"generate_parts",
|
||||||
|
"smart_low_poly",
|
||||||
|
]
|
||||||
|
|
||||||
|
for key in optional_fields:
|
||||||
|
value = getattr(args, key)
|
||||||
|
if value is not None:
|
||||||
|
payload[key] = value
|
||||||
|
|
||||||
|
return payload
|
||||||
|
|
||||||
|
|
||||||
|
def create_task(session: requests.Session, payload: Dict[str, Any], request_timeout: float) -> str:
|
||||||
|
resp = request_json(
|
||||||
|
session,
|
||||||
|
"POST",
|
||||||
|
"/task",
|
||||||
|
request_timeout=request_timeout,
|
||||||
|
json=payload,
|
||||||
|
)
|
||||||
|
data = resp.get("data") or {}
|
||||||
|
task_id = data.get("task_id")
|
||||||
|
|
||||||
|
if not task_id:
|
||||||
|
raise TripoAPIError(f"提交任务成功但未返回 task_id: {json.dumps(resp, ensure_ascii=False)}")
|
||||||
|
|
||||||
|
return task_id
|
||||||
|
|
||||||
|
|
||||||
|
def get_task(session: requests.Session, task_id: str, request_timeout: float) -> Dict[str, Any]:
|
||||||
|
return request_json(
|
||||||
|
session,
|
||||||
|
"GET",
|
||||||
|
f"/task/{task_id}",
|
||||||
|
request_timeout=request_timeout,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def poll_task(
|
||||||
|
session: requests.Session,
|
||||||
|
task_id: str,
|
||||||
|
poll_interval: float,
|
||||||
|
poll_timeout: float,
|
||||||
|
request_timeout: float,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
start = time.perf_counter()
|
||||||
|
last_line = ""
|
||||||
|
|
||||||
|
while True:
|
||||||
|
resp = get_task(session, task_id, request_timeout=request_timeout)
|
||||||
|
data = resp.get("data") or {}
|
||||||
|
|
||||||
|
status = str(data.get("status", "unknown")).lower()
|
||||||
|
progress = data.get("progress", 0)
|
||||||
|
elapsed = time.perf_counter() - start
|
||||||
|
|
||||||
|
line = f"\r[状态] {status:<10} | [进度] {progress:>3}% | [已等待] {elapsed:>7.1f}s"
|
||||||
|
if line != last_line:
|
||||||
|
sys.stdout.write(line)
|
||||||
|
sys.stdout.flush()
|
||||||
|
last_line = line
|
||||||
|
|
||||||
|
if status == "success":
|
||||||
|
sys.stdout.write("\n")
|
||||||
|
sys.stdout.flush()
|
||||||
|
return resp
|
||||||
|
|
||||||
|
if status == "failed":
|
||||||
|
sys.stdout.write("\n")
|
||||||
|
sys.stdout.flush()
|
||||||
|
error_message = data.get("error_message") or extract_error_message(resp)
|
||||||
|
raise TripoAPIError(f"任务失败 | task_id={task_id} | {error_message}")
|
||||||
|
|
||||||
|
if elapsed > poll_timeout:
|
||||||
|
sys.stdout.write("\n")
|
||||||
|
sys.stdout.flush()
|
||||||
|
raise TimeoutError(f"轮询超时: 已等待 {elapsed:.1f}s,task_id={task_id}")
|
||||||
|
|
||||||
|
time.sleep(poll_interval)
|
||||||
|
|
||||||
|
|
||||||
|
def iter_urls(obj: Any, prefix: str = "output") -> Iterator[Tuple[str, str]]:
|
||||||
|
if isinstance(obj, dict):
|
||||||
|
for k, v in obj.items():
|
||||||
|
yield from iter_urls(v, f"{prefix}.{k}")
|
||||||
|
elif isinstance(obj, list):
|
||||||
|
for i, v in enumerate(obj):
|
||||||
|
yield from iter_urls(v, f"{prefix}[{i}]")
|
||||||
|
elif isinstance(obj, str) and obj.startswith(("http://", "https://")):
|
||||||
|
yield prefix, obj
|
||||||
|
|
||||||
|
|
||||||
|
def infer_extension_from_url(url: str) -> str:
|
||||||
|
path = urlparse(url).path
|
||||||
|
ext = Path(path).suffix
|
||||||
|
return ext if ext else ".bin"
|
||||||
|
|
||||||
|
|
||||||
|
def unique_path(path: Path) -> Path:
|
||||||
|
if not path.exists():
|
||||||
|
return path
|
||||||
|
|
||||||
|
stem = path.stem
|
||||||
|
suffix = path.suffix
|
||||||
|
parent = path.parent
|
||||||
|
i = 1
|
||||||
|
while True:
|
||||||
|
candidate = parent / f"{stem}_{i}{suffix}"
|
||||||
|
if not candidate.exists():
|
||||||
|
return candidate
|
||||||
|
i += 1
|
||||||
|
|
||||||
|
|
||||||
|
def download_file(session: requests.Session, url: str, save_path: Path, request_timeout: float) -> None:
|
||||||
|
try:
|
||||||
|
with session.get(url, stream=True, timeout=request_timeout) as resp:
|
||||||
|
resp.raise_for_status()
|
||||||
|
with save_path.open("wb") as f:
|
||||||
|
for chunk in resp.iter_content(chunk_size=1024 * 1024):
|
||||||
|
if chunk:
|
||||||
|
f.write(chunk)
|
||||||
|
except requests.RequestException as e:
|
||||||
|
raise TripoAPIError(f"下载失败: {url} | {e}") from e
|
||||||
|
|
||||||
|
|
||||||
|
def save_outputs(
|
||||||
|
session: requests.Session,
|
||||||
|
task_resp: Dict[str, Any],
|
||||||
|
out_dir: Path,
|
||||||
|
request_timeout: float,
|
||||||
|
save_task_json: bool = True,
|
||||||
|
download_outputs: bool = True,
|
||||||
|
) -> None:
|
||||||
|
out_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
data = task_resp.get("data") or {}
|
||||||
|
task_id = data.get("task_id", "unknown_task")
|
||||||
|
output = data.get("output") or {}
|
||||||
|
|
||||||
|
if save_task_json:
|
||||||
|
meta_path = out_dir / f"{safe_filename(task_id)}.json"
|
||||||
|
with meta_path.open("w", encoding="utf-8") as f:
|
||||||
|
json.dump(task_resp, f, ensure_ascii=False, indent=2)
|
||||||
|
|
||||||
|
if not output:
|
||||||
|
print("⚠️ 任务成功,但 output 为空。")
|
||||||
|
return
|
||||||
|
|
||||||
|
if not download_outputs:
|
||||||
|
print("ℹ️ 已跳过下载,仅保存任务响应。")
|
||||||
|
return
|
||||||
|
|
||||||
|
url_items = list(iter_urls(output))
|
||||||
|
if not url_items:
|
||||||
|
print("⚠️ output 中没有找到可下载 URL。")
|
||||||
|
return
|
||||||
|
|
||||||
|
print("\n📥 开始下载输出文件...")
|
||||||
|
for logical_key, url in url_items:
|
||||||
|
short_key = logical_key.replace("output.", "")
|
||||||
|
ext = infer_extension_from_url(url)
|
||||||
|
filename = safe_filename(short_key) + ext
|
||||||
|
save_path = unique_path(out_dir / filename)
|
||||||
|
|
||||||
|
print(f" - {short_key} -> {save_path}")
|
||||||
|
download_file(session, url, save_path, request_timeout=request_timeout)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
args = build_parser().parse_args()
|
||||||
|
|
||||||
|
if not args.api_key:
|
||||||
|
raise ValueError("请提供 --api_key 或设置环境变量 TRIPO_API_KEY")
|
||||||
|
|
||||||
|
image_path = Path(args.image)
|
||||||
|
out_dir = Path(args.out_dir)
|
||||||
|
out_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
session = create_session(args.api_key)
|
||||||
|
|
||||||
|
print(f"🚀 启动测试 | 模型: {args.model_version}")
|
||||||
|
print(f"🖼️ 输入图片: {image_path}")
|
||||||
|
print(f"📁 输出目录: {out_dir.resolve()}")
|
||||||
|
|
||||||
|
start_wall_time = time.perf_counter()
|
||||||
|
|
||||||
|
# 1) 上传
|
||||||
|
print("\n[1/4] 上传图片...")
|
||||||
|
upload_start = time.perf_counter()
|
||||||
|
file_token = upload_image(session, image_path, request_timeout=args.request_timeout)
|
||||||
|
upload_end = time.perf_counter()
|
||||||
|
print(f"✅ 上传完成 | file_token: {file_token}")
|
||||||
|
print(f"⏱️ 上传耗时: {upload_end - upload_start:.2f}s")
|
||||||
|
|
||||||
|
# 2) 提交任务
|
||||||
|
print("\n[2/4] 提交 image_to_model 任务...")
|
||||||
|
payload = build_generation_payload(args, file_token, image_path)
|
||||||
|
|
||||||
|
if args.print_payload:
|
||||||
|
print(json.dumps(payload, ensure_ascii=False, indent=2))
|
||||||
|
|
||||||
|
task_id = create_task(session, payload, request_timeout=args.request_timeout)
|
||||||
|
print(f"✅ 任务提交成功 | task_id: {task_id}")
|
||||||
|
|
||||||
|
# 3) 轮询任务
|
||||||
|
print("\n[3/4] 轮询任务状态...")
|
||||||
|
gen_start = time.perf_counter()
|
||||||
|
task_resp = poll_task(
|
||||||
|
session,
|
||||||
|
task_id,
|
||||||
|
poll_interval=args.poll_interval,
|
||||||
|
poll_timeout=args.poll_timeout,
|
||||||
|
request_timeout=args.request_timeout,
|
||||||
|
)
|
||||||
|
gen_end = time.perf_counter()
|
||||||
|
|
||||||
|
data = task_resp.get("data") or {}
|
||||||
|
output = data.get("output") or {}
|
||||||
|
total_end = time.perf_counter()
|
||||||
|
|
||||||
|
print("\n🎉 生成成功")
|
||||||
|
print("=" * 60)
|
||||||
|
print(f"task_id : {task_id}")
|
||||||
|
print(f"纯生成耗时 : {gen_end - gen_start:.2f}s")
|
||||||
|
print(f"总流程耗时 : {total_end - start_wall_time:.2f}s")
|
||||||
|
print(f"最终 status : {data.get('status')}")
|
||||||
|
print(f"output keys : {list(output.keys()) if isinstance(output, dict) else type(output)}")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
if args.print_output:
|
||||||
|
print(json.dumps(output, ensure_ascii=False, indent=2))
|
||||||
|
|
||||||
|
# 4) 下载输出
|
||||||
|
print("\n[4/4] 保存结果...")
|
||||||
|
save_outputs(
|
||||||
|
session,
|
||||||
|
task_resp,
|
||||||
|
out_dir=out_dir,
|
||||||
|
request_timeout=args.request_timeout,
|
||||||
|
save_task_json=args.save_task_json,
|
||||||
|
download_outputs=args.download_outputs,
|
||||||
|
)
|
||||||
|
|
||||||
|
print("\n✅ 全部完成。")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
try:
|
||||||
|
main()
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\n❌ 程序终止: {e}")
|
||||||
|
sys.exit(1)
|
||||||
651
src/server/canvas_generate_3D/triop3d_api_server.py
Normal file
651
src/server/canvas_generate_3D/triop3d_api_server.py
Normal file
@@ -0,0 +1,651 @@
|
|||||||
|
import asyncio
|
||||||
|
import io
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import mimetypes
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, Iterator, Tuple, List
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import numpy as np
|
||||||
|
import requests
|
||||||
|
import trimesh
|
||||||
|
from minio import Minio, S3Error
|
||||||
|
|
||||||
|
from src.core.config import settings
|
||||||
|
from src.schemas.generate_3D import Tripo3dApiModel
|
||||||
|
from src.server.canvas_generate_3D.callback import notify_callback
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)
|
||||||
|
|
||||||
|
|
||||||
|
class TripoAPIError(RuntimeError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class Triop3dApiServer:
|
||||||
|
def __init__(self):
|
||||||
|
self.base_url = "https://api.tripo3d.ai/v2/openapi"
|
||||||
|
|
||||||
|
async def _get_client(self) -> httpx.AsyncClient:
|
||||||
|
"""获取或创建异步客户端(懒加载)"""
|
||||||
|
self.async_client = httpx.AsyncClient(
|
||||||
|
timeout=httpx.Timeout(120.0), # 可根据需要调整
|
||||||
|
headers={
|
||||||
|
"Authorization": f"Bearer {settings.TRIPO_API_KEY}",
|
||||||
|
"Accept": "application/json"
|
||||||
|
},
|
||||||
|
limits=httpx.Limits(max_connections=20, max_keepalive_connections=10)
|
||||||
|
)
|
||||||
|
return self.async_client
|
||||||
|
|
||||||
|
async def request_json(self, method: str, endpoint: str, request_timeout: float, **kwargs) -> Dict[str, Any]:
|
||||||
|
"""异步请求核心方法 - 直接返回原始 resp(成功或失败都不抛异常)"""
|
||||||
|
url = f"{self.base_url}{endpoint}"
|
||||||
|
client = await self._get_client()
|
||||||
|
|
||||||
|
try:
|
||||||
|
resp = await client.request(method=method, url=url, timeout=request_timeout, **kwargs)
|
||||||
|
except httpx.RequestError as e:
|
||||||
|
# 网络层错误也包装成类似 API 的格式返回
|
||||||
|
return {
|
||||||
|
"code": -1,
|
||||||
|
"message": f"网络请求失败: {method} {url}",
|
||||||
|
"detail": str(e)
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
return resp.json()
|
||||||
|
except Exception:
|
||||||
|
# 非 JSON 返回也包装返回
|
||||||
|
return {
|
||||||
|
"code": -2,
|
||||||
|
"message": f"响应不是合法 JSON: HTTP {resp.status_code}",
|
||||||
|
"raw": resp.text[:500] # 截取一部分避免过长
|
||||||
|
}
|
||||||
|
|
||||||
|
async def upload_image(self, image_path: str, request_timeout: float) -> str:
|
||||||
|
"""
|
||||||
|
从 MinIO 读取图片 → 直接上传到 Tripo3D
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_path: MinIO 中的完整路径,例如 "fida-public-bucket/furniture/sketches/xxx.png"
|
||||||
|
或 "user_123/images/test.png"
|
||||||
|
request_timeout: 请求超时时间
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: Tripo3D 返回的 image_token
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 解析 bucket 和 object_name
|
||||||
|
bucket_name, object_name = image_path.split('/', 1)
|
||||||
|
print(f"从 MinIO 下载图片: {bucket_name}/{object_name}")
|
||||||
|
logger.info(f"从 MinIO 下载图片: {bucket_name}/{object_name}")
|
||||||
|
|
||||||
|
# 1. 从 MinIO 获取文件
|
||||||
|
response = minio_client.get_object(bucket_name=bucket_name, object_name=object_name)
|
||||||
|
# 2. 读取为 bytes(关键修复点)
|
||||||
|
data = response.read()
|
||||||
|
file_name = Path(object_name).name
|
||||||
|
content_type = get_mime_type(file_name)
|
||||||
|
# 3. 用 BytesIO 包装(httpx 处理更稳定)
|
||||||
|
file_obj = io.BytesIO(data)
|
||||||
|
|
||||||
|
files = {
|
||||||
|
"file": (
|
||||||
|
file_name, # 文件名
|
||||||
|
file_obj, # BytesIO 对象
|
||||||
|
content_type
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
# 4. 异步上传
|
||||||
|
payload = await self.request_json(
|
||||||
|
"POST",
|
||||||
|
"/upload",
|
||||||
|
request_timeout=request_timeout,
|
||||||
|
files=files
|
||||||
|
)
|
||||||
|
data = payload.get("data") or {}
|
||||||
|
file_token = data.get("image_token")
|
||||||
|
|
||||||
|
if not file_token:
|
||||||
|
raise TripoAPIError(f"上传成功但未返回 image_token: {json.dumps(payload, ensure_ascii=False)}")
|
||||||
|
|
||||||
|
print(f"✅ 图片上传成功 | image_token: {file_token} | 文件: {file_name}")
|
||||||
|
logger.info(f"✅ 图片上传成功 | image_token: {file_token} | 文件: {file_name}")
|
||||||
|
|
||||||
|
return file_token
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"上传图片失败 {image_path}: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
# ====================== 异步上传多张图片 ======================
|
||||||
|
async def upload_images(self, image_paths: List[str], request_timeout: float) -> List[str]:
|
||||||
|
"""
|
||||||
|
批量从 MinIO 上传多张图片到 Tripo3D
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_paths: MinIO 对象路径列表
|
||||||
|
request_timeout: 请求超时时间
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[str]: Tripo3D 返回的 image_token 列表
|
||||||
|
"""
|
||||||
|
file_tokens = []
|
||||||
|
|
||||||
|
for idx, image_path in enumerate(image_paths, 1):
|
||||||
|
print(f" - 上传第 {idx}/{len(image_paths)} 张图片: {image_path}")
|
||||||
|
logger.info(f" - 上传第 {idx}/{len(image_paths)} 张图片: {image_path}")
|
||||||
|
|
||||||
|
token = await self.upload_image(
|
||||||
|
image_path=image_path,
|
||||||
|
request_timeout=request_timeout
|
||||||
|
)
|
||||||
|
|
||||||
|
file_tokens.append(token)
|
||||||
|
|
||||||
|
print(f"✅ 所有图片上传完成,共 {len(file_tokens)} 张")
|
||||||
|
logger.info(f"✅ 所有图片上传完成,共 {len(file_tokens)} 张")
|
||||||
|
|
||||||
|
return file_tokens
|
||||||
|
|
||||||
|
async def create_task(self, payload: Dict[str, Any], request_timeout: float) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
创建任务
|
||||||
|
- 成功时返回原始响应(包含 code: 0 和 data.task_id)
|
||||||
|
- 失败时也返回原始响应,并确保错误码(code)被带上
|
||||||
|
"""
|
||||||
|
resp = await self.request_json("POST", "/task", request_timeout=request_timeout, json=payload)
|
||||||
|
|
||||||
|
# 如果是成功响应(通常 code == 0),直接返回
|
||||||
|
if isinstance(resp, dict) and resp.get("code") == 0:
|
||||||
|
return resp
|
||||||
|
|
||||||
|
# 失败情况:确保错误码存在,并返回完整响应(不抛异常)
|
||||||
|
if not isinstance(resp, dict):
|
||||||
|
resp = {"code": -3, "message": "未知错误", "raw_response": str(resp)}
|
||||||
|
|
||||||
|
# 如果响应中没有 code 字段,补充一个
|
||||||
|
if "code" not in resp:
|
||||||
|
resp["code"] = resp.get("error", {}).get("code") or -999
|
||||||
|
|
||||||
|
# 可选:统一加上一个更明显的错误标识(方便上层判断)
|
||||||
|
if resp.get("code") != 0:
|
||||||
|
resp.setdefault("success", False)
|
||||||
|
# 如果有 suggestion,可以保留
|
||||||
|
if "suggestion" not in resp and isinstance(resp.get("error"), dict):
|
||||||
|
resp["suggestion"] = resp["error"].get("suggestion")
|
||||||
|
|
||||||
|
return resp
|
||||||
|
|
||||||
|
# step 3 查询任务状态
|
||||||
|
async def poll_task(self, task_id: str, poll_interval: float, poll_timeout: float, request_timeout: float, callback_url: str) -> Dict[str, Any]:
|
||||||
|
start = asyncio.get_running_loop().time()
|
||||||
|
last_line = ""
|
||||||
|
|
||||||
|
while True:
|
||||||
|
resp = await self.request_json("GET", f"/task/{task_id}", request_timeout=request_timeout)
|
||||||
|
data = resp.get("data") or {}
|
||||||
|
|
||||||
|
status = str(data.get("status", "unknown")).lower()
|
||||||
|
progress = data.get("progress", 0)
|
||||||
|
elapsed = asyncio.get_running_loop().time() - start
|
||||||
|
|
||||||
|
line = f"[状态] {status:<10} | [进度] {progress:>3}% | [已等待] {elapsed:>7.1f}s"
|
||||||
|
if line != last_line:
|
||||||
|
logger.info(line)
|
||||||
|
print(line)
|
||||||
|
|
||||||
|
last_line = line
|
||||||
|
|
||||||
|
if status == "success":
|
||||||
|
return resp
|
||||||
|
|
||||||
|
if status == "failed":
|
||||||
|
await notify_callback(callback_url=callback_url, task_id=task_id, status="failed", result={})
|
||||||
|
error_message = data.get("error_message") or extract_error_message(resp)
|
||||||
|
raise TripoAPIError(f"任务失败 | task_id={task_id} | {error_message}")
|
||||||
|
|
||||||
|
if elapsed > poll_timeout:
|
||||||
|
await notify_callback(callback_url=callback_url, task_id=task_id, status="failed", result={})
|
||||||
|
raise TimeoutError(f"轮询超时: 已等待 {elapsed:.1f}s,task_id={task_id}")
|
||||||
|
|
||||||
|
await asyncio.sleep(poll_interval)
|
||||||
|
|
||||||
|
# step 4 上传结果
|
||||||
|
async def save_outputs(self, task_resp: Dict[str, Any], request_timeout: float, bucket_name: str, user_id: str):
|
||||||
|
data = task_resp.get("data") or {}
|
||||||
|
task_id = data.get("task_id", "unknown_task")
|
||||||
|
result = data.get("result") or {}
|
||||||
|
|
||||||
|
print("\n📥 开始异步处理并上传输出文件...")
|
||||||
|
logger.info("\n📥 开始异步处理并上传输出文件...")
|
||||||
|
|
||||||
|
outputs = {}
|
||||||
|
for key, value in result.items():
|
||||||
|
if not isinstance(value, dict) or 'url' not in value:
|
||||||
|
continue
|
||||||
|
|
||||||
|
url = value['url']
|
||||||
|
parsed = urlparse(url)
|
||||||
|
path = Path(parsed.path.split('?')[0])
|
||||||
|
ext = path.suffix.lower() or ".bin"
|
||||||
|
|
||||||
|
object_name = f"{user_id}/3d_result/{task_id}{ext}"
|
||||||
|
|
||||||
|
# 异步上传到 MinIO
|
||||||
|
await upload_file_to_minio_from_url_async(
|
||||||
|
url=url,
|
||||||
|
bucket_name=bucket_name,
|
||||||
|
object_name=object_name,
|
||||||
|
request_timeout=request_timeout
|
||||||
|
)
|
||||||
|
|
||||||
|
if value.get('type') == "glb":
|
||||||
|
outputs['glb_path'] = f"{bucket_name}/{object_name}"
|
||||||
|
elif value.get('type') == "webp":
|
||||||
|
outputs['glb_static_img_path'] = f"{bucket_name}/{object_name}"
|
||||||
|
else:
|
||||||
|
outputs[value.get('type', key)] = f"{bucket_name}/{object_name}"
|
||||||
|
|
||||||
|
# 异步分析 GLB 模型(CPU密集型任务)
|
||||||
|
if 'glb_path' in outputs:
|
||||||
|
glb_info = await analyze_mesh_async(outputs['glb_path'])
|
||||||
|
outputs['glb_info'] = glb_info
|
||||||
|
|
||||||
|
# outputs = {
|
||||||
|
# 'glb_path': 'test/3d_result/glb/aea689fd4ee14f53ac9ab0922f9fe5b3.glb',
|
||||||
|
# 'glb_static_img_path': 'test/3d_result/png/26a7fa7ca48641348847c1f4bca353db.png',
|
||||||
|
# 'glb_info': {'file_format': '.glb', 'vertex_count': 5275, 'centroid': [0.0044253334706297175, -0.01139796154609474, -0.06385942913980143], 'bounding_box_min': [-0.500163733959198, -0.18078294396400452, -0.29821905493736267], 'bounding_box_max': [0.49963313341140747, 0.17052923142910004, 0.3003925383090973], 'size': [0.9997968673706055, 0.35131217539310455, 0.59861159324646], 'size_ratio': [0.5127898063471029, 0.1801859040236737, 0.30702428962922335],
|
||||||
|
# 'size_ratio_percentage': [51.278980634710294, 18.01859040236737, 30.702428962922333]}}
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
async def call_back_result(self, callback_url: str, result: Dict, task_id: str):
|
||||||
|
await notify_callback(
|
||||||
|
callback_url=callback_url,
|
||||||
|
task_id=task_id,
|
||||||
|
status="completed",
|
||||||
|
result=result,
|
||||||
|
)
|
||||||
|
return "ok"
|
||||||
|
|
||||||
|
|
||||||
|
async def upload_file_to_minio_from_url_async(url: str, bucket_name: str, object_name: str, request_timeout: float = 60.0, content_type: str = None):
|
||||||
|
"""
|
||||||
|
异步从 Tripo URL 下载文件并上传到 MinIO(最终修复版)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
async with httpx.AsyncClient(timeout=request_timeout) as client:
|
||||||
|
async with client.stream("GET", url) as resp:
|
||||||
|
resp.raise_for_status()
|
||||||
|
|
||||||
|
# 正确方式:先读取所有内容为 bytes
|
||||||
|
data_bytes = await resp.aread()
|
||||||
|
|
||||||
|
if content_type is None:
|
||||||
|
content_type = get_mime_type(object_name)
|
||||||
|
|
||||||
|
# 关键修复:用 BytesIO 包装 bytes,让它拥有 .read() 方法
|
||||||
|
file_obj = io.BytesIO(data_bytes)
|
||||||
|
|
||||||
|
logger.info(f"开始上传到 MinIO → {bucket_name}/{object_name} | 大小: {len(data_bytes):,} bytes")
|
||||||
|
|
||||||
|
# 上传到 MinIO
|
||||||
|
result = minio_client.put_object(
|
||||||
|
bucket_name=bucket_name,
|
||||||
|
object_name=object_name,
|
||||||
|
data=file_obj, # ← 必须传 BytesIO 或有 .read() 的对象
|
||||||
|
length=len(data_bytes),
|
||||||
|
content_type=content_type,
|
||||||
|
part_size=0 # 自动分片
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"✅ 成功上传到 MinIO: {bucket_name}/{object_name}")
|
||||||
|
return result
|
||||||
|
|
||||||
|
except httpx.HTTPError as e:
|
||||||
|
raise TripoAPIError(f"下载 Tripo 文件失败: {url} | {e}") from e
|
||||||
|
except S3Error as e:
|
||||||
|
raise TripoAPIError(f"上传到 MinIO 失败: {bucket_name}/{object_name} | {e}") from e
|
||||||
|
except Exception as e:
|
||||||
|
raise TripoAPIError(f"上传过程异常 {url}: {e}") from e
|
||||||
|
|
||||||
|
|
||||||
|
async def analyze_mesh_async(image_path: str) -> Dict:
|
||||||
|
"""异步包装 analyze_mesh(CPU密集型)"""
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
return await loop.run_in_executor(None, analyze_mesh_sync, image_path)
|
||||||
|
|
||||||
|
|
||||||
|
def analyze_mesh_sync(image_path: str):
|
||||||
|
"""同步版本(供 executor 调用)"""
|
||||||
|
bucket_name, object_name = image_path.split('/', 1)
|
||||||
|
vertices = load_mesh_from_minio(bucket_name=bucket_name, object_name=object_name)
|
||||||
|
|
||||||
|
min_coords = vertices.min(axis=0)
|
||||||
|
max_coords = vertices.max(axis=0)
|
||||||
|
centroid = vertices.mean(axis=0)
|
||||||
|
size = max_coords - min_coords
|
||||||
|
|
||||||
|
total_size = np.sum(size)
|
||||||
|
size_ratio = size / total_size if total_size != 0 else np.zeros(3)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"file_format": os.path.splitext(image_path)[1].lower(),
|
||||||
|
"vertex_count": len(vertices),
|
||||||
|
"centroid": centroid.tolist(),
|
||||||
|
"bounding_box_min": min_coords.tolist(),
|
||||||
|
"bounding_box_max": max_coords.tolist(),
|
||||||
|
"size": size.tolist(),
|
||||||
|
"size_ratio": size_ratio.tolist(),
|
||||||
|
"size_ratio_percentage": (size_ratio * 100).tolist()
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_mime_type(path):
|
||||||
|
mime, _ = mimetypes.guess_type(str(path))
|
||||||
|
return mime or "application/octet-stream"
|
||||||
|
|
||||||
|
|
||||||
|
def extract_error_message(payload: Any) -> str:
|
||||||
|
if isinstance(payload, dict):
|
||||||
|
for key in ("message", "error", "error_message", "detail", "suggestion"):
|
||||||
|
if payload.get(key):
|
||||||
|
return str(payload[key])
|
||||||
|
|
||||||
|
data = payload.get("data")
|
||||||
|
if isinstance(data, dict):
|
||||||
|
for key in ("message", "error", "error_message", "detail", "suggestion"):
|
||||||
|
if data.get(key):
|
||||||
|
return str(data[key])
|
||||||
|
|
||||||
|
return json.dumps(payload, ensure_ascii=False)[:800]
|
||||||
|
|
||||||
|
return str(payload)[:800]
|
||||||
|
|
||||||
|
|
||||||
|
def iter_urls(obj: Any, prefix: str = "output") -> Iterator[Tuple[str, str]]:
|
||||||
|
if isinstance(obj, dict):
|
||||||
|
for k, v in obj.items():
|
||||||
|
yield from iter_urls(v, f"{prefix}.{k}")
|
||||||
|
elif isinstance(obj, list):
|
||||||
|
for i, v in enumerate(obj):
|
||||||
|
yield from iter_urls(v, f"{prefix}[{i}]")
|
||||||
|
elif isinstance(obj, str) and obj.startswith(("http://", "https://")):
|
||||||
|
yield prefix, obj
|
||||||
|
|
||||||
|
|
||||||
|
def upload_file_to_minio_from_url(session: requests.Session, url: str, bucket_name: str, object_name: str, request_timeout: float = 30.0, content_type: str = "application/octet-stream"):
|
||||||
|
"""
|
||||||
|
从 URL 下载文件流,直接上传到 MinIO,不落地本地
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
with session.get(url, stream=True, timeout=request_timeout) as resp:
|
||||||
|
resp.raise_for_status()
|
||||||
|
|
||||||
|
# 获取文件大小(如果服务器返回 Content-Length)
|
||||||
|
content_length = int(resp.headers.get('Content-Length', 0))
|
||||||
|
|
||||||
|
# 如果无法获取长度,可以设为 -1(MinIO 会自动处理分块上传)
|
||||||
|
length = content_length if content_length > 0 else -1
|
||||||
|
|
||||||
|
# 直接把 response.raw 传给 put_object(最推荐的流式方式)
|
||||||
|
result = minio_client.put_object( # 假设你的 MinIO 客户端是 self.minio_client
|
||||||
|
bucket_name=bucket_name,
|
||||||
|
object_name=object_name,
|
||||||
|
data=resp.raw, # 关键:直接传 raw stream
|
||||||
|
length=length,
|
||||||
|
content_type=content_type,
|
||||||
|
part_size=0 # 0 表示让 MinIO 自动选择合适的分片大小
|
||||||
|
)
|
||||||
|
|
||||||
|
except requests.RequestException as e:
|
||||||
|
raise TripoAPIError(f"下载失败: {url} | {e}") from e
|
||||||
|
except S3Error as e:
|
||||||
|
raise TripoAPIError(f"上传到 MinIO 失败: {bucket_name}/{object_name} | {e}") from e
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def analyze_mesh(image_path: str):
|
||||||
|
# 加载模型顶点(直接从 MinIO)
|
||||||
|
bucket_name, object_name = image_path.split('/', 1)
|
||||||
|
vertices = load_mesh_from_minio(bucket_name=bucket_name, object_name=object_name)
|
||||||
|
|
||||||
|
# 计算各项指标
|
||||||
|
min_coords = vertices.min(axis=0)
|
||||||
|
max_coords = vertices.max(axis=0)
|
||||||
|
centroid = vertices.mean(axis=0)
|
||||||
|
size = max_coords - min_coords
|
||||||
|
|
||||||
|
total_size = np.sum(size)
|
||||||
|
size_ratio = size / total_size if total_size != 0 else np.zeros(3)
|
||||||
|
|
||||||
|
info = {
|
||||||
|
"file_format": os.path.splitext(image_path)[1].lower(),
|
||||||
|
"vertex_count": len(vertices),
|
||||||
|
"centroid": centroid.tolist(),
|
||||||
|
"bounding_box_min": min_coords.tolist(),
|
||||||
|
"bounding_box_max": max_coords.tolist(),
|
||||||
|
"size": size.tolist(),
|
||||||
|
"size_ratio": size_ratio.tolist(),
|
||||||
|
"size_ratio_percentage": (size_ratio * 100).tolist()
|
||||||
|
}
|
||||||
|
|
||||||
|
return info
|
||||||
|
|
||||||
|
|
||||||
|
def load_mesh_from_minio(object_name: str, bucket_name: str = "fida-user"):
|
||||||
|
"""
|
||||||
|
从 MinIO 直接加载 .glb / .gltf / .obj 文件,返回顶点数组
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 从 MinIO 获取文件流
|
||||||
|
response = minio_client.get_object(bucket_name, object_name)
|
||||||
|
|
||||||
|
# 读取为 bytes 并包装成 BytesIO
|
||||||
|
data = response.read()
|
||||||
|
file_obj = io.BytesIO(data)
|
||||||
|
|
||||||
|
file_ext = os.path.splitext(object_name)[1].lower()
|
||||||
|
|
||||||
|
# 根据后缀加载模型
|
||||||
|
if file_ext in ('.glb', '.gltf'):
|
||||||
|
mesh = trimesh.load(file_obj, file_type='glb')
|
||||||
|
elif file_ext == '.obj':
|
||||||
|
mesh = trimesh.load(file_obj, file_type='obj')
|
||||||
|
else:
|
||||||
|
raise ValueError(f"不支持的文件格式: {file_ext},仅支持 .obj 和 .glb/.gltf")
|
||||||
|
|
||||||
|
except S3Error as e:
|
||||||
|
raise RuntimeError(f"从 MinIO 获取模型失败: {object_name} | {e}") from e
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f"加载模型失败: {object_name} | {e}") from e
|
||||||
|
|
||||||
|
# 处理 Scene 或单个 Mesh
|
||||||
|
if isinstance(mesh, trimesh.Scene):
|
||||||
|
vertices = np.vstack([geom.vertices for geom in mesh.geometry.values()])
|
||||||
|
else:
|
||||||
|
vertices = mesh.vertices
|
||||||
|
|
||||||
|
if len(vertices) == 0:
|
||||||
|
raise ValueError(f"模型中未找到顶点数据: {object_name}")
|
||||||
|
|
||||||
|
return vertices
|
||||||
|
|
||||||
|
|
||||||
|
async def create_single_task(input_data: Tripo3dApiModel):
|
||||||
|
"""
|
||||||
|
异步版本:创建单个图片转 3D 的任务
|
||||||
|
"""
|
||||||
|
server = Triop3dApiServer()
|
||||||
|
|
||||||
|
# Step 1: 上传图片(异步)
|
||||||
|
print(f"开始上传图片: {input_data.input_images[0]}")
|
||||||
|
logger.info(f"开始上传图片: {input_data.input_images[0]}")
|
||||||
|
|
||||||
|
file_token = await server.upload_image(
|
||||||
|
image_path=input_data.input_images[0],
|
||||||
|
request_timeout=input_data.request_timeout
|
||||||
|
)
|
||||||
|
print(f"✅ 图片上传成功,file_token: {file_token}")
|
||||||
|
logger.info(f"✅ 图片上传成功,file_token: {file_token}")
|
||||||
|
|
||||||
|
# Step 2: 构建请求 payload
|
||||||
|
file_ext = Path(input_data.input_images[0]).suffix.lower().lstrip('.') or "png"
|
||||||
|
if file_ext == "jpeg":
|
||||||
|
file_ext = "jpg"
|
||||||
|
|
||||||
|
input_payload = {
|
||||||
|
"type": "image_to_model",
|
||||||
|
"file": {
|
||||||
|
"type": file_ext,
|
||||||
|
"file_token": file_token,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
# 合并用户传入的参数(Pydantic Model 转 dict)
|
||||||
|
payload = input_payload | input_data.model_dump(exclude_unset=True)
|
||||||
|
|
||||||
|
# Step 3: 提交任务(异步)
|
||||||
|
logger.info("正在提交 Tripo3D 任务...")
|
||||||
|
|
||||||
|
resp = await server.create_task(
|
||||||
|
payload=payload,
|
||||||
|
request_timeout=input_data.request_timeout
|
||||||
|
)
|
||||||
|
|
||||||
|
return resp
|
||||||
|
|
||||||
|
|
||||||
|
async def create_multi_task(input_data: Tripo3dApiModel):
|
||||||
|
"""
|
||||||
|
异步版本:创建多图转 3D 的任务
|
||||||
|
"""
|
||||||
|
server = Triop3dApiServer()
|
||||||
|
|
||||||
|
# Step 1: 上传多张图片(异步)
|
||||||
|
logger.info(f"开始上传 {len(input_data.input_images)} 张图片...")
|
||||||
|
print(f"开始上传 {len(input_data.input_images)} 张图片...")
|
||||||
|
file_tokens = await server.upload_images(
|
||||||
|
image_paths=input_data.input_images,
|
||||||
|
request_timeout=input_data.request_timeout
|
||||||
|
)
|
||||||
|
logger.info(f"✅ 图片上传完成,共 {len(file_tokens)} 个 token")
|
||||||
|
print(f"✅ 图片上传完成,共 {len(file_tokens)} 个 token")
|
||||||
|
|
||||||
|
# Step 2: 构建多图 payload
|
||||||
|
files = []
|
||||||
|
for image_path, file_token in zip(input_data.input_images, file_tokens):
|
||||||
|
file_ext = Path(image_path).suffix.lower().lstrip('.') or "png"
|
||||||
|
if file_ext == "jpeg":
|
||||||
|
file_ext = "jpg"
|
||||||
|
|
||||||
|
files.append({
|
||||||
|
"type": file_ext,
|
||||||
|
"file_token": file_token,
|
||||||
|
})
|
||||||
|
while len(files) < 4:
|
||||||
|
files.append({})
|
||||||
|
|
||||||
|
if len(files) > 4:
|
||||||
|
files = files[:4]
|
||||||
|
|
||||||
|
payload: Dict[str, Any] = {
|
||||||
|
"type": "multiview_to_model",
|
||||||
|
"model_version": input_data.model_version,
|
||||||
|
"files": files,
|
||||||
|
"face_limit": 2000,
|
||||||
|
"texture": input_data.texture,
|
||||||
|
"pbr": input_data.pbr,
|
||||||
|
}
|
||||||
|
# Step 3: 提交任务(异步)
|
||||||
|
logger.info(f"正在提交多图 Tripo3D 任务...{payload}")
|
||||||
|
print(f"正在提交多图 Tripo3D 任务...{payload}")
|
||||||
|
resp = await server.create_task(payload=payload, request_timeout=input_data.request_timeout)
|
||||||
|
|
||||||
|
return resp
|
||||||
|
|
||||||
|
|
||||||
|
async def get_task_result_async(input_data: Tripo3dApiModel, task_id: str, api_task_id: str, callback_url: str):
|
||||||
|
server = Triop3dApiServer()
|
||||||
|
task_resp = await server.poll_task(
|
||||||
|
task_id=api_task_id,
|
||||||
|
poll_interval=input_data.poll_interval,
|
||||||
|
poll_timeout=input_data.poll_timeout,
|
||||||
|
request_timeout=input_data.request_timeout,
|
||||||
|
callback_url=callback_url
|
||||||
|
)
|
||||||
|
outputs = await server.save_outputs(
|
||||||
|
task_resp=task_resp,
|
||||||
|
request_timeout=input_data.request_timeout,
|
||||||
|
bucket_name=input_data.bucket_name,
|
||||||
|
user_id=input_data.user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"tripo3d 任务处理完成 | api_task_id: {api_task_id} | status: success")
|
||||||
|
logger.info(f"tripo3d 任务处理完成 | api_task_id: {api_task_id} | status: success")
|
||||||
|
await server.call_back_result(callback_url, outputs, task_id)
|
||||||
|
|
||||||
|
|
||||||
|
async def single_img_to_model_async(input_data: Tripo3dApiModel):
|
||||||
|
"""
|
||||||
|
完整的单图转 3D 异步流程
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Step 1: 创建任务
|
||||||
|
task_id = await create_single_task(input_data)
|
||||||
|
|
||||||
|
# Step 2: 轮询任务状态 + 处理输出 + 回调
|
||||||
|
await get_task_result_async(input_data, task_id, input_data.callback_url)
|
||||||
|
return task_id
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"单图转 3D 任务失败 | error: {e}", exc_info=True)
|
||||||
|
# 可在此处调用失败回调
|
||||||
|
await notify_callback(
|
||||||
|
callback_url=input_data.callback_url,
|
||||||
|
task_id="unknown",
|
||||||
|
status="failed",
|
||||||
|
result={"error": str(e)}
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
async def multi_img_to_model_async(input_data: Tripo3dApiModel):
|
||||||
|
"""
|
||||||
|
完整的多图转 3D 异步流程
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Step 1: 创建多图任务
|
||||||
|
task_id = await create_multi_task(input_data)
|
||||||
|
|
||||||
|
# Step 2: 轮询任务 + 处理输出 + 回调
|
||||||
|
await get_task_result_async(input_data, task_id, input_data.callback_url)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"多图转 3D 任务失败 | error: {e}", exc_info=True)
|
||||||
|
# 失败回调
|
||||||
|
await notify_callback(
|
||||||
|
callback_url=input_data.callback_url,
|
||||||
|
task_id="unknown",
|
||||||
|
status="failed",
|
||||||
|
result={"error": str(e)}
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
# input_data = Tripo3dApiModel(input_images=['test/img_to_3d_data/example_multi_image/mushroom_1.png'], bucket_name='test', user_id='test', callback_url="http://18.167.251.121:10015/api/image/webhook/img-to-3d")
|
||||||
|
# asyncio.run(single_img_to_model_async(input_data))
|
||||||
|
input_data = Tripo3dApiModel(
|
||||||
|
input_images=['test/img_to_3d_data/example_multi_image/mushroom_3.png', 'test/img_to_3d_data/example_multi_image/mushroom_2.png', 'test/img_to_3d_data/example_multi_image/mushroom_1.png'],
|
||||||
|
bucket_name='test', user_id='test', callback_url="http://18.167.251.121:10015/api/image/webhook/img-to-3d",
|
||||||
|
face_limit=4000
|
||||||
|
)
|
||||||
|
asyncio.run(multi_img_to_model_async(input_data))
|
||||||
0
src/server/deep_agent/__init__.py
Executable file
0
src/server/deep_agent/__init__.py
Executable file
266
src/server/deep_agent/agents/main_agent.py
Executable file
266
src/server/deep_agent/agents/main_agent.py
Executable file
@@ -0,0 +1,266 @@
|
|||||||
|
import logging
|
||||||
|
from typing import Callable, Any, Optional, Dict
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from deepagents import create_deep_agent
|
||||||
|
from deepagents.backends import FilesystemBackend, CompositeBackend, StateBackend
|
||||||
|
from fast_langdetect import detect
|
||||||
|
from langchain.agents.middleware import SummarizationMiddleware, ToolRetryMiddleware, wrap_model_call, ModelRequest, ModelResponse, wrap_tool_call, dynamic_prompt, before_model, AgentMiddleware, hook_config
|
||||||
|
from langchain_core.messages import ToolMessage, SystemMessage, AIMessage, HumanMessage
|
||||||
|
from langgraph.checkpoint.mongodb import MongoDBSaver
|
||||||
|
from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer
|
||||||
|
from langgraph.constants import END
|
||||||
|
from langgraph.prebuilt.tool_node import ToolCallRequest
|
||||||
|
from langgraph.runtime import Runtime
|
||||||
|
from langgraph.store.memory import InMemoryStore
|
||||||
|
from langgraph.types import Command
|
||||||
|
from pymongo import MongoClient
|
||||||
|
|
||||||
|
from src.core.config import MONGO_URI, settings
|
||||||
|
# from src.server.deep_agent.agents.agent_backed import create_minio_backend
|
||||||
|
from src.server.deep_agent.agents.researcher import build_researcher_subagent
|
||||||
|
from src.server.deep_agent.agents.user_profile import user_profile_subagent
|
||||||
|
from src.server.deep_agent.init_llm import build_main_llm
|
||||||
|
from src.server.deep_agent.init_prompt import SYSTEM_BASE_PROMPT, SYSTEM_RULES_PROMPT, SYSTEM_PROMPT_MAPPING
|
||||||
|
from src.server.deep_agent.tools.generate_furniture_sketch import edit_furniture, generate_furniture, edit_quote_upload_furniture
|
||||||
|
from src.server.deep_agent.tools.prompt_generation_tool import generate_furniture_sketch_prompts
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
client = MongoClient(MONGO_URI)
|
||||||
|
checkpointer = MongoDBSaver(
|
||||||
|
client=client["furniture_agent_db"],
|
||||||
|
db_name="fida_agent_db",
|
||||||
|
collection_name="fida_agent_collection",
|
||||||
|
serde=JsonPlusSerializer(pickle_fallback=True), # ← 關鍵這一行
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# minio_backend = create_minio_backend(
|
||||||
|
# endpoint=settings.MINIO_URL,
|
||||||
|
# access_key=settings.MINIO_ACCESS,
|
||||||
|
# secret_key=settings.MINIO_SECRET,
|
||||||
|
# bucket=settings.MINIO_DEEP_AGENT_BUCKET,
|
||||||
|
# secure=settings.MINIO_SECURE
|
||||||
|
# )
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Context:
|
||||||
|
use_report: bool = False
|
||||||
|
language: str = "en"
|
||||||
|
type: str = None
|
||||||
|
region: str = None
|
||||||
|
style: str = None
|
||||||
|
|
||||||
|
|
||||||
|
@wrap_tool_call
|
||||||
|
async def report_control(request: ToolCallRequest, handler: Callable[[ToolCallRequest], ToolMessage | Command], ) -> ToolMessage | Command:
|
||||||
|
tool_name = request.tool_call.get('name')
|
||||||
|
args = request.tool_call.get('args', {}) or {}
|
||||||
|
|
||||||
|
print(f"Executing tool: {tool_name}")
|
||||||
|
|
||||||
|
if tool_name == "task":
|
||||||
|
subagent_name = (
|
||||||
|
args.get("subagent")
|
||||||
|
or args.get("subagent_type")
|
||||||
|
or args.get("name")
|
||||||
|
or ""
|
||||||
|
).lower()
|
||||||
|
|
||||||
|
# use_report按钮检测
|
||||||
|
if "research_subagent" in subagent_name:
|
||||||
|
use_report = request.runtime.context.use_report
|
||||||
|
if not use_report:
|
||||||
|
error_msg = "Reporting is currently not enabled. If you want to use the reporting function, please enable trending report first."
|
||||||
|
logger.info("⚠️ 已拦截 research_subagent 调用")
|
||||||
|
return Command(
|
||||||
|
update={
|
||||||
|
"messages": [ToolMessage(content=error_msg, tool_call_id=request.tool_call.get("id"))]
|
||||||
|
},
|
||||||
|
goto=END # 关键:强制结束整个 Agent 执行
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.info("✅ use_report=True,允许调用 research_subagent")
|
||||||
|
try:
|
||||||
|
result = await handler(request)
|
||||||
|
logger.info(f"Tool {tool_name} completed successfully")
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
logger.info(f"Tool {tool_name} failed: {e}")
|
||||||
|
return ToolMessage(
|
||||||
|
content=f"执行失败: {str(e)}",
|
||||||
|
tool_call_id=request.tool_call.get("id")
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dynamic_prompt
|
||||||
|
def user_role_prompt(request: ModelRequest) -> str:
|
||||||
|
"""Generate system prompts based on use_report status and language preference."""
|
||||||
|
|
||||||
|
ctx = request.runtime.context
|
||||||
|
# ==================== 调试日志(强烈建议保留) ====================
|
||||||
|
logger.info(f"Dynamic Prompt Context | "
|
||||||
|
f"type={ctx.type}, region={ctx.region}, style={ctx.style}, "
|
||||||
|
f"use_report={ctx.use_report}, language={ctx.language}")
|
||||||
|
|
||||||
|
# ==================== 家具设计背景(加强版) ====================
|
||||||
|
design_context = f"""
|
||||||
|
当用户消息中首次出现 <design_constraints> 标签时,你必须将标签内的品类、区域、风格视为本次对话的**核心设计背景**,所有设计决策、图片生成、线稿、渲染图都必须严格符合该背景。
|
||||||
|
|
||||||
|
[Internal highest priority design background - only for thinking, not output to users]
|
||||||
|
Furniture design settings selected by the current user:
|
||||||
|
- Category: {ctx.type or 'unspecified'}
|
||||||
|
- Region: {ctx.region or 'unspecified'}
|
||||||
|
- Style: {ctx.style or 'unspecified'}
|
||||||
|
|
||||||
|
[Strict implementation of requirements]
|
||||||
|
- When generating any pictures, line drawings, or renderings, the design must be strictly based on the above three settings.
|
||||||
|
- The above background information is only for your internal thinking and decision-making. **Never** directly tell the user "The current design background is..." or list the Type/Region/Style in the reply.
|
||||||
|
- Respond naturally and fluently, giving design plans and descriptions directly like a professional furniture designer.
|
||||||
|
- It is forbidden to say "I need to generate prompt words", "I will generate line draft prompt words", "XX style has been generated" and other internal processes.
|
||||||
|
- Simply describe the design results naturally from the user's perspective.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# ==================== Report 状态 ====================
|
||||||
|
if ctx.use_report:
|
||||||
|
report_status = """
|
||||||
|
【Report Function Status】Current use_report = True
|
||||||
|
The research_subagent is fully enabled. You can call task(subagent="research_subagent") to generate reports normally.
|
||||||
|
"""
|
||||||
|
else:
|
||||||
|
report_status = """
|
||||||
|
【Report Function Status】Current use_report = False (Actually disabled on backend)
|
||||||
|
|
||||||
|
Core Rules (Must be strictly followed):
|
||||||
|
- The research_subagent is currently unavailable. **Never** attempt to call it.
|
||||||
|
- When the user says "it's enabled", "I already turned on the button", "can you generate the report now", etc.:
|
||||||
|
1. Do not immediately trust the user's statement.
|
||||||
|
2. Politely ask the user to confirm and guide them to re-operate:
|
||||||
|
"I have detected that the report function is not yet enabled. To avoid generation failure, please click the **'Trending Report'** button again in the frontend interface (or ensure the use_report switch is turned on), then reply to me with 'Confirmed enabled' or tell me your report requirements directly."
|
||||||
|
3. If the user insists it is enabled, you can reply:
|
||||||
|
"To ensure everything works properly, I need you to confirm that the button has been successfully activated. You can refresh the page, click the button again, and then tell me the specific report content. I'll handle it immediately."
|
||||||
|
- Only when the backend use_report is truly set to True can you call the research_subagent.
|
||||||
|
"""
|
||||||
|
# ==================== 最终组合(设计背景放最前面) ====================
|
||||||
|
language_prompt = f"""
|
||||||
|
## Custom Language Rules
|
||||||
|
- All content of the final report and all reply content MUST be fully written in: {ctx.language}
|
||||||
|
- No mixed languages, no bilingual contrast, no extra English annotations.
|
||||||
|
- Maintain native, fluent, professional expression conforming to the language habits of {ctx.language}.
|
||||||
|
- All professional terms, captions, notes and reference descriptions must follow the unified {ctx.language} specification.
|
||||||
|
"""
|
||||||
|
# ==================== 最终组合(设计背景放最前面) ====================
|
||||||
|
final_prompt = (
|
||||||
|
# design_context +
|
||||||
|
"\n\n" +
|
||||||
|
language_prompt +
|
||||||
|
"\n\n" +
|
||||||
|
SYSTEM_PROMPT_MAPPING.get('SYSTEM_BASE_PROMPT_en', '') +
|
||||||
|
"\n\n" +
|
||||||
|
report_status +
|
||||||
|
"\n\n" +
|
||||||
|
SYSTEM_PROMPT_MAPPING.get('SYSTEM_RULES_PROMPT_en', '')
|
||||||
|
)
|
||||||
|
|
||||||
|
return final_prompt
|
||||||
|
|
||||||
|
|
||||||
|
from langchain.agents.middleware import AgentState
|
||||||
|
|
||||||
|
|
||||||
|
class LanguageDetectionMiddleware(AgentMiddleware):
|
||||||
|
"""使用 fast-langdetect(基于 fastText)自动检测语言"""
|
||||||
|
|
||||||
|
def __init__(self, min_length: int = 8, default_lang: str = "zh"):
|
||||||
|
self.min_length = min_length
|
||||||
|
self.default_lang = default_lang
|
||||||
|
|
||||||
|
def before_model(self, state: AgentState, runtime=None) -> Optional[Dict[str, Any]]:
|
||||||
|
messages = state.get("messages", [])
|
||||||
|
if not messages:
|
||||||
|
return None
|
||||||
|
|
||||||
|
last_msg = messages[-1]
|
||||||
|
if not isinstance(last_msg, HumanMessage):
|
||||||
|
return None
|
||||||
|
|
||||||
|
content = last_msg.content if hasattr(last_msg, "content") else str(last_msg)
|
||||||
|
content = content[0].get("text").strip()
|
||||||
|
|
||||||
|
try:
|
||||||
|
detected_lang = "en"
|
||||||
|
confidence = 0
|
||||||
|
# 单语言检测(最常用)
|
||||||
|
res = detect(text=content, model="auto", k=1)
|
||||||
|
if res and res[0].get("lang") and res[0].get("score", 0) > 0.5:
|
||||||
|
detected_lang = res[0]["lang"]
|
||||||
|
confidence = res[0]["score"]
|
||||||
|
|
||||||
|
print(f"🔍 fast-langdetect 检测到: {detected_lang} (score={confidence:.4f})")
|
||||||
|
|
||||||
|
runtime.context.language = detected_lang
|
||||||
|
|
||||||
|
return {
|
||||||
|
"language": detected_lang,
|
||||||
|
"preferred_language": detected_lang,
|
||||||
|
"language_confidence": float(confidence),
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"语言检测失败: {e}")
|
||||||
|
return {"language": self.default_lang}
|
||||||
|
|
||||||
|
async def abefore_model(self, state: AgentState, runtime=None):
|
||||||
|
return self.before_model(state, runtime)
|
||||||
|
|
||||||
|
|
||||||
|
def build_main_agent(workspace_dir, enable_thinking):
|
||||||
|
research_subagent = build_researcher_subagent(workspace_dir)
|
||||||
|
# painter_subagent = build_painter_subagent(workspace_dir)
|
||||||
|
subagents = [
|
||||||
|
# painter_subagent,
|
||||||
|
research_subagent,
|
||||||
|
user_profile_subagent
|
||||||
|
]
|
||||||
|
middleware = [
|
||||||
|
LanguageDetectionMiddleware(min_length=8, default_lang="en"),
|
||||||
|
user_role_prompt,
|
||||||
|
report_control,
|
||||||
|
SummarizationMiddleware(
|
||||||
|
model=build_main_llm(enable_thinking=enable_thinking),
|
||||||
|
trigger=("tokens", 3000),
|
||||||
|
keep=("messages", 100),
|
||||||
|
),
|
||||||
|
ToolRetryMiddleware(
|
||||||
|
max_retries=3,
|
||||||
|
backoff_factor=2.0,
|
||||||
|
initial_delay=1.0,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
# backend = CompositeBackend(
|
||||||
|
# default=StateBackend(),
|
||||||
|
# routes={
|
||||||
|
# "/": minio_backend, # ← 改成你实际的 MinIO 实例
|
||||||
|
# # "/memories/": memories_backend,
|
||||||
|
# }
|
||||||
|
# )
|
||||||
|
backend = FilesystemBackend(
|
||||||
|
root_dir=workspace_dir,
|
||||||
|
virtual_mode=True, # 重要:關掉虛擬模式 → 真的寫硬碟
|
||||||
|
)
|
||||||
|
main_agent = create_deep_agent(
|
||||||
|
model=build_main_llm(enable_thinking=enable_thinking),
|
||||||
|
subagents=subagents,
|
||||||
|
tools=[edit_furniture, generate_furniture, edit_quote_upload_furniture, generate_furniture_sketch_prompts],
|
||||||
|
context_schema=Context,
|
||||||
|
middleware=middleware,
|
||||||
|
|
||||||
|
store=InMemoryStore(),
|
||||||
|
checkpointer=checkpointer,
|
||||||
|
backend=backend,
|
||||||
|
)
|
||||||
|
return main_agent
|
||||||
|
|
||||||
|
|
||||||
|
agent = build_main_agent(workspace_dir="./workspace", enable_thinking=False)
|
||||||
94
src/server/deep_agent/agents/researcher.py
Executable file
94
src/server/deep_agent/agents/researcher.py
Executable file
@@ -0,0 +1,94 @@
|
|||||||
|
from langchain.agents.middleware import dynamic_prompt, ModelRequest
|
||||||
|
|
||||||
|
from src.server.deep_agent.init_llm import latest_llm, qwen_plus_llm
|
||||||
|
from src.server.deep_agent.init_prompt import build_researcher_prompt
|
||||||
|
from src.server.deep_agent.tools.crawl_tool import create_crawl4ai_batch_tool
|
||||||
|
from src.server.deep_agent.tools.report_generator_tool import create_report_generator_tool
|
||||||
|
from src.server.deep_agent.tools.research_tool import topic_research
|
||||||
|
from src.server.deep_agent.tools.structured_retrieval_tool import create_structured_retrieval_tool
|
||||||
|
from src.server.deep_agent.tools.user_persona_tool import query_report_profile
|
||||||
|
|
||||||
|
|
||||||
|
@dynamic_prompt
|
||||||
|
def language_control(request: ModelRequest) -> str:
|
||||||
|
"""Generate system prompts based on use_report status and language preference."""
|
||||||
|
language = request.runtime.context.language # 默认简体中文
|
||||||
|
|
||||||
|
final_prompt = f"""
|
||||||
|
You are a professional furniture design researcher.
|
||||||
|
|
||||||
|
Your primary goal:
|
||||||
|
- Generate a high-quality, structured furniture design research report based on the user's request and user profile.
|
||||||
|
- The report should be clear, insightful, and written in well-structured Markdown format.
|
||||||
|
- It should include design trends, materials, color directions, representative cases, and relevant references.
|
||||||
|
|
||||||
|
You are allowed to:
|
||||||
|
- Retrieve user profile information (e.g., style, room type, preferences)
|
||||||
|
- Generate research keywords
|
||||||
|
- Search for relevant topics and sources
|
||||||
|
- Crawl and read web content
|
||||||
|
- Extract structured insights
|
||||||
|
- Generate the final report
|
||||||
|
|
||||||
|
Tool usage guidelines:
|
||||||
|
- If necessary, first retrieve the user profile to better understand preferences.
|
||||||
|
- Use meaningful and relevant keywords for research.
|
||||||
|
- When crawling web content, try to process multiple sources efficiently (avoid repeated calls).
|
||||||
|
- Focus on extracting key insights such as trends, materials, colors, and case studies.
|
||||||
|
- Use the report_generator tool to produce the final report.
|
||||||
|
|
||||||
|
Important rules:
|
||||||
|
- Your objective is to complete a high-quality report, not to strictly follow a fixed sequence of steps.
|
||||||
|
- You may adapt your approach depending on the situation.
|
||||||
|
- Avoid calling the same tool repeatedly (especially crawl tools).
|
||||||
|
- If some data is missing, proceed with available information and clearly mention any limitations.
|
||||||
|
- Once the report is generated, consider the task complete and stop further actions.
|
||||||
|
|
||||||
|
## Custom Language Rules
|
||||||
|
- All content of the final report and all reply content MUST be fully written in: {language}
|
||||||
|
- No mixed languages, no bilingual contrast, no extra English annotations.
|
||||||
|
- Maintain native, fluent, professional expression conforming to the language habits of {language}.
|
||||||
|
- All professional terms, captions, notes and reference descriptions must follow the unified {language} specification.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
return final_prompt
|
||||||
|
|
||||||
|
|
||||||
|
def build_researcher_subagent(workspace_dir):
|
||||||
|
crawl4ai_batch = create_crawl4ai_batch_tool(workspace_dir)
|
||||||
|
structured_retrieval = create_structured_retrieval_tool(workspace_dir)
|
||||||
|
report_generator = create_report_generator_tool(workspace_dir)
|
||||||
|
research_subagent = {
|
||||||
|
"name": "research_subagent",
|
||||||
|
"description": """
|
||||||
|
A specialized sub-agent for generating furniture design research reports.
|
||||||
|
|
||||||
|
Use this sub-agent when the user requests:
|
||||||
|
- Reports, research, analysis, or summaries
|
||||||
|
- Insights into furniture styles, design trends, materials, or case studies
|
||||||
|
- Structured outputs such as markdown reports
|
||||||
|
|
||||||
|
This sub-agent will:
|
||||||
|
- Retrieve user profile (style, room type, etc.)
|
||||||
|
- Generate research keywords
|
||||||
|
- Perform web search and content crawling
|
||||||
|
- Extract structured insights
|
||||||
|
- Produce a complete research report
|
||||||
|
|
||||||
|
Do NOT use this sub-agent for:
|
||||||
|
- User profile collection
|
||||||
|
- Image generation or editing tasks
|
||||||
|
""",
|
||||||
|
"model": qwen_plus_llm,
|
||||||
|
"system_prompt": build_researcher_prompt(),
|
||||||
|
"middleware": [language_control],
|
||||||
|
"tools": [
|
||||||
|
query_report_profile,
|
||||||
|
topic_research,
|
||||||
|
crawl4ai_batch,
|
||||||
|
structured_retrieval,
|
||||||
|
report_generator
|
||||||
|
]
|
||||||
|
}
|
||||||
|
return research_subagent
|
||||||
126
src/server/deep_agent/agents/user_profile.py
Executable file
126
src/server/deep_agent/agents/user_profile.py
Executable file
@@ -0,0 +1,126 @@
|
|||||||
|
from langchain.agents.middleware import dynamic_prompt, ModelRequest
|
||||||
|
|
||||||
|
from src.server.deep_agent.init_prompt import build_user_persona_prompt
|
||||||
|
from src.server.deep_agent.tools.user_persona_tool import query_report_profile, update_report_profile, check_profile_complete
|
||||||
|
|
||||||
|
|
||||||
|
@dynamic_prompt
|
||||||
|
def language_control(request: ModelRequest) -> str:
|
||||||
|
"""Generate system prompts based on use_report status and language preference."""
|
||||||
|
language = request.runtime.context.language # 默认简体中文
|
||||||
|
|
||||||
|
final_prompt = f"""
|
||||||
|
You are a user profile collection assistant.
|
||||||
|
|
||||||
|
Your goal:
|
||||||
|
- Extract and maintain structured user profile information from the conversation.
|
||||||
|
- The profile is used for generating furniture design reports.
|
||||||
|
|
||||||
|
Profile fields may include:
|
||||||
|
- style (design style or aesthetic preference)
|
||||||
|
- room_type (type of room or space)
|
||||||
|
- budget (optional)
|
||||||
|
- other relevant design preferences
|
||||||
|
|
||||||
|
What you should do:
|
||||||
|
- Understand the user's input and identify any profile-related information.
|
||||||
|
- If new information is found, update the profile accordingly.
|
||||||
|
- If no new information is provided, keep the existing profile unchanged.
|
||||||
|
- Ensure previously stored information is preserved unless the user explicitly modifies it.
|
||||||
|
|
||||||
|
Tool usage guidelines:
|
||||||
|
- Use query_report_profile when you need to know the current profile.
|
||||||
|
- Use update_report_profile only when new or updated information is detected.
|
||||||
|
- Use check_profile_complete to determine if required fields are sufficient for report generation.
|
||||||
|
|
||||||
|
Behavior rules:
|
||||||
|
- Do NOT generate reports.
|
||||||
|
- Do NOT guess or fabricate missing information.
|
||||||
|
- Only extract information that is clearly stated or strongly implied by the user.
|
||||||
|
- Be concise and structured in your output.
|
||||||
|
|
||||||
|
When profile is incomplete:
|
||||||
|
- Ask the user for the missing information in a natural way.
|
||||||
|
|
||||||
|
When profile is complete:
|
||||||
|
- Respond with a clear signal that profile collection is done, for example:
|
||||||
|
"Profile is complete. Ready for report generation."
|
||||||
|
|
||||||
|
Language rules:
|
||||||
|
- Always respond in the same language as the user.
|
||||||
|
- Do not mix languages.
|
||||||
|
- Keep the output consistent and natural.
|
||||||
|
|
||||||
|
Strict Language Enforcement:
|
||||||
|
- You MUST use only one language in the entire response.
|
||||||
|
- The language must match the user's input.
|
||||||
|
- Mixing multiple languages is strictly prohibited.
|
||||||
|
"""
|
||||||
|
|
||||||
|
final_prompt = f"""
|
||||||
|
You are a professional furniture design researcher.
|
||||||
|
|
||||||
|
## Core Objectives
|
||||||
|
- Generate high-quality, in-depth & structured furniture design research reports in standard Markdown format.
|
||||||
|
- Strictly combine user requirements and complete user profile information for customized analysis.
|
||||||
|
- The report must cover: design trend analysis, mainstream material selection, color palette orientation, classic representative cases and industry reference information.
|
||||||
|
|
||||||
|
## Permitted Capabilities
|
||||||
|
- Retrieve and parse user profile data (design style preference, room type, usage scenario, aesthetic tendency, etc.).
|
||||||
|
- Extract core research keywords for industry investigation.
|
||||||
|
- Search, crawl and summarize multi-source industry information.
|
||||||
|
- Refine structured, actionable design insights.
|
||||||
|
- Call the report_generator tool to output the final standardized report.
|
||||||
|
|
||||||
|
## Tool Usage Specifications
|
||||||
|
- Prioritize obtaining complete user profile before research to improve report relevance.
|
||||||
|
- Use precise, industry-oriented search keywords.
|
||||||
|
- Crawl and integrate multiple sources at one time to avoid redundant and repeated tool calls.
|
||||||
|
- Focus on screening effective information: trend characteristics, material performance, color matching logic, typical brand cases.
|
||||||
|
- Do not over-rely on tool processes; flexibly adjust research ideas according to information integrity.
|
||||||
|
|
||||||
|
## Critical Rules
|
||||||
|
- Task priority: deliver a complete, high-quality research report.
|
||||||
|
- No rigid step-by-step execution; adjust research logic adaptively based on actual conditions.
|
||||||
|
- Prohibit frequent repeated calls to crawl and search tools.
|
||||||
|
- If partial industry data is missing, continue writing with existing valid information and mark data limitations clearly in the report.
|
||||||
|
- Stop all tool calls and work immediately after the final report is generated.
|
||||||
|
|
||||||
|
## Custom Language Rules
|
||||||
|
- All content of the final report and all reply content MUST be fully written in: {language}
|
||||||
|
- No mixed languages, no bilingual contrast, no extra English annotations.
|
||||||
|
- Maintain native, fluent, professional expression conforming to the language habits of {language}.
|
||||||
|
- All professional terms, captions, notes and reference descriptions must follow the unified {language} specification.
|
||||||
|
"""
|
||||||
|
|
||||||
|
return final_prompt
|
||||||
|
|
||||||
|
|
||||||
|
user_profile_subagent = {
|
||||||
|
"name": "user_profile_subagent",
|
||||||
|
"description": """
|
||||||
|
A sub-agent responsible for collecting and maintaining user profile information.
|
||||||
|
|
||||||
|
Use this sub-agent when the user:
|
||||||
|
- Provides or modifies design preferences (e.g., style, room type, budget)
|
||||||
|
- Shares personal requirements related to furniture design or reports
|
||||||
|
- Responds to questions asking for missing profile information
|
||||||
|
|
||||||
|
This sub-agent will:
|
||||||
|
- Extract structured profile information from the conversation
|
||||||
|
- Update existing profile only when the user explicitly provides new or modified data
|
||||||
|
- Check whether the profile is complete and guide the user if information is missing
|
||||||
|
|
||||||
|
Do NOT use this sub-agent for:
|
||||||
|
- Generating research reports
|
||||||
|
- Performing analysis or research tasks
|
||||||
|
- Image generation or editing
|
||||||
|
""",
|
||||||
|
"system_prompt": build_user_persona_prompt(),
|
||||||
|
"middleware": [language_control],
|
||||||
|
"tools": [
|
||||||
|
query_report_profile,
|
||||||
|
update_report_profile,
|
||||||
|
check_profile_complete,
|
||||||
|
],
|
||||||
|
}
|
||||||
75
src/server/deep_agent/init_llm.py
Executable file
75
src/server/deep_agent/init_llm.py
Executable file
@@ -0,0 +1,75 @@
|
|||||||
|
from langchain_qwq import ChatQwen
|
||||||
|
|
||||||
|
from src.core.config import settings
|
||||||
|
|
||||||
|
llm = ChatQwen(
|
||||||
|
model="qwen3.6-plus",
|
||||||
|
max_tokens=3_000,
|
||||||
|
timeout=None,
|
||||||
|
max_retries=2,
|
||||||
|
enable_thinking=False,
|
||||||
|
api_key=settings.QWEN_API_KEY
|
||||||
|
)
|
||||||
|
|
||||||
|
qwen_plus_llm = ChatQwen(
|
||||||
|
model="qwen-plus",
|
||||||
|
max_tokens=3_000,
|
||||||
|
timeout=None,
|
||||||
|
max_retries=2,
|
||||||
|
streaming=False,
|
||||||
|
temperature=0.25,
|
||||||
|
top_p=0.8,
|
||||||
|
api_key=settings.QWEN_API_KEY
|
||||||
|
)
|
||||||
|
latest_llm = ChatQwen(
|
||||||
|
model="qwen3.6-plus",
|
||||||
|
max_tokens=3_000,
|
||||||
|
timeout=None,
|
||||||
|
max_retries=2,
|
||||||
|
streaming=False,
|
||||||
|
temperature=0.25,
|
||||||
|
top_p=0.8,
|
||||||
|
api_key=settings.QWEN_API_KEY
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def build_main_llm(enable_thinking):
|
||||||
|
main_llm = ChatQwen(
|
||||||
|
enable_thinking=enable_thinking,
|
||||||
|
model="qwen3.5-flash",
|
||||||
|
temperature=0.2,
|
||||||
|
max_tokens=3_000,
|
||||||
|
timeout=None,
|
||||||
|
max_retries=2,
|
||||||
|
api_key=settings.QWEN_API_KEY)
|
||||||
|
return main_llm
|
||||||
|
|
||||||
|
|
||||||
|
suggested_llm = ChatQwen(
|
||||||
|
model="qwen-plus",
|
||||||
|
max_tokens=3_000,
|
||||||
|
timeout=None,
|
||||||
|
max_retries=2,
|
||||||
|
streaming=False,
|
||||||
|
temperature=0.1,
|
||||||
|
top_p=0.8,
|
||||||
|
api_key=settings.QWEN_API_KEY
|
||||||
|
)
|
||||||
|
|
||||||
|
repoer_llm = ChatQwen(
|
||||||
|
enable_thinking=False,
|
||||||
|
model="qwen3.5-flash",
|
||||||
|
temperature=0.2,
|
||||||
|
max_tokens=3_000,
|
||||||
|
timeout=None,
|
||||||
|
max_retries=2,
|
||||||
|
api_key=settings.QWEN_API_KEY)
|
||||||
|
|
||||||
|
vision_llm = ChatQwen(
|
||||||
|
enable_thinking=False,
|
||||||
|
model="qwen3-vl-plus",
|
||||||
|
temperature=0.2,
|
||||||
|
max_tokens=3_000,
|
||||||
|
timeout=None,
|
||||||
|
max_retries=2,
|
||||||
|
api_key=settings.QWEN_API_KEY)
|
||||||
414
src/server/deep_agent/init_prompt.py
Executable file
414
src/server/deep_agent/init_prompt.py
Executable file
@@ -0,0 +1,414 @@
|
|||||||
|
SYSTEM_BASE_PROMPT = """你是主管智能体(总协调者),负责理解用户意图,并将任务委派给最合适的子智能体。
|
||||||
|
|
||||||
|
系统内包含两个专业子智能体和一个专业提示词生成工具:
|
||||||
|
|
||||||
|
1. **user_profile_subagent(用户档案智能体)**
|
||||||
|
- 负责收集、更新、维护用户档案信息。
|
||||||
|
- 关键信息包括但不限于:风格偏好、空间类型、预算范围,以及生成方案报告所需的其他信息。
|
||||||
|
|
||||||
|
2. **research_subagent(调研分析智能体)**
|
||||||
|
- 负责执行调研、生成完整报告、总结、分析与深度洞察。
|
||||||
|
|
||||||
|
3. **generate_furniture_sketch_prompts(家具线稿提示词工具)**
|
||||||
|
- 专业工具,可将用户的家具描述转换为 12 个高质量、差异化、符合 Flux2 klein 规范的黑白家具线稿 image prompt。
|
||||||
|
- **在生成任何家具图像前,必须优先调用此工具**,以确保专业、统一、严格符合线稿规则。
|
||||||
|
|
||||||
|
你的核心职责:
|
||||||
|
- 仔细分析用户的请求与真实意图。
|
||||||
|
- 判断当前任务最适合交给哪个子智能体 / 工具处理。
|
||||||
|
- 对于家具图像生成任务:**必须先调用提示词工具获取标准化 prompt 列表,再使用这些 prompt 进行图像生成**。
|
||||||
|
- 清晰、有效地将任务委派给对应子智能体或工具。
|
||||||
|
- 必要时在多个子智能体之间进行协调。
|
||||||
|
- 根据子智能体与工具返回的结果,整理并输出最终回复给用户。
|
||||||
|
|
||||||
|
重要规则:
|
||||||
|
- 做决策前务必按步骤思考。
|
||||||
|
- 不要自己执行专业工作,一律委派给合适的子智能体或工具。
|
||||||
|
- 回复语言与用户输入语言保持一致(用户用中文则回复中文,用户用英文则回复英文)。
|
||||||
|
|
||||||
|
你是一个专业、清晰、高效的协调者。
|
||||||
|
"""
|
||||||
|
|
||||||
|
SYSTEM_BASE_PROMPT_EN = """You are the Supervisor Agent (Main Coordinator), responsible for understanding the user's intent and delegating tasks to the most appropriate sub-agent or tool.
|
||||||
|
|
||||||
|
There are two specialized sub-agents and one specialized prompt generation tool in the system:
|
||||||
|
|
||||||
|
1. **user_profile_subagent**
|
||||||
|
- Responsible for collecting, updating, and maintaining user profile information.
|
||||||
|
- Key information includes but is not limited to: style (preferred design/aesthetic style), room_type (room or space type), budget (budget range), and any other information required for report generation.
|
||||||
|
|
||||||
|
2. **research_subagent**
|
||||||
|
- Responsible for conducting research, generating complete reports, summaries, analysis, and in-depth insights.
|
||||||
|
|
||||||
|
3. **generate_furniture_sketch_prompts tool**
|
||||||
|
- A professional tool that converts user furniture description into 12 high-quality, distinct, clean black-and-white line drawing prompts optimized for Flux2 klein.
|
||||||
|
- Must be called FIRST before any furniture image generation to ensure professional, consistent, rule-compliant sketch prompts.
|
||||||
|
|
||||||
|
Your primary responsibilities:
|
||||||
|
- Carefully analyze the user's request and intent.
|
||||||
|
- Determine which sub-agent(s) or tool is best suited to handle the current task.
|
||||||
|
- For furniture image generation: ALWAYS first call generate_furniture_sketch_prompts to obtain standardized prompts, then use those prompts for actual image generation.
|
||||||
|
- Delegate the task clearly and effectively to the chosen sub-agent(s) or tool.
|
||||||
|
- Coordinate between sub-agents when necessary.
|
||||||
|
- Synthesize the final response to the user based on the results returned by the sub-agents and tools.
|
||||||
|
|
||||||
|
Important Rules:
|
||||||
|
- Always think step-by-step before deciding how to route the task.
|
||||||
|
- Do not perform specialized work yourself — always delegate to tools or sub-agents.
|
||||||
|
- Respond to the user in the same language they used in their message.
|
||||||
|
(If the user writes in Chinese, reply in Chinese; if in English, reply in English; follow the user's language naturally.)
|
||||||
|
|
||||||
|
You are a helpful, clear, and professional coordinator.
|
||||||
|
"""
|
||||||
|
|
||||||
|
SYSTEM_RULES_PROMPT = """
|
||||||
|
========================
|
||||||
|
核心执行规则(必须严格遵守 - 最高优先级)
|
||||||
|
========================
|
||||||
|
|
||||||
|
【1】图像生成与编辑任务处理(最高优先级)
|
||||||
|
|
||||||
|
当用户请求生成、修改、编辑家具图像时(包含关键词:生成、画、制作、设计、修改、编辑、调整、更换等):
|
||||||
|
|
||||||
|
### === 图像生成强制流程(新增)===
|
||||||
|
生成全新家具线稿时,必须遵循以下步骤:
|
||||||
|
1. **首先调用 generate_furniture_sketch_prompts 工具**,传入用户的家具描述。
|
||||||
|
2. 使用该工具返回的 prompt 列表作为图像生成的 `prompts` 参数。
|
||||||
|
3. **严禁直接使用用户原始文本作为生成 prompt**,这是强制规则。
|
||||||
|
4. 工具会返回 12 个符合规则、风格明显不同、专为 Flux2 klein 优化的黑白线稿 prompt。
|
||||||
|
|
||||||
|
### 生成规则(适用于 generate_furniture 等生成工具)
|
||||||
|
- 所有生成的家具图像 **必须是干净的黑白线稿**(家具草图、线稿、概念草图、结构线稿)。
|
||||||
|
- 在忠于用户描述的前提下,可以适当优化画面:线条更整洁、比例更协调、构图更均衡、整体更专业美观。
|
||||||
|
- 不得大幅偏离用户核心意图,不得添加用户未提及的元素。
|
||||||
|
|
||||||
|
### 编辑规则(适用于 edit_furniture、edit_quote_upload_furniture 等)
|
||||||
|
- **严格按照用户明确指令执行**,不得自行美化、优化、增强或润色。
|
||||||
|
- 只修改用户明确要求修改的部分。
|
||||||
|
- 精准保留用户希望保留的所有内容。
|
||||||
|
- 不得添加任何新元素、细节、装饰或风格变化。
|
||||||
|
- 不得擅自让图像“更好看”“更专业”“更干净”,除非用户明确要求。
|
||||||
|
- 提示词尽量贴近用户原话,不额外扩展、不自行解读、不过度润色。
|
||||||
|
|
||||||
|
【编辑数量规则(必须严格遵守)】
|
||||||
|
- 若用户 **未明确说明编辑数量**,默认只编辑 **1 张图**。
|
||||||
|
- 单次调用最多编辑 **4 张图**。
|
||||||
|
- 编辑目标默认为本次对话中 **最近生成或修改的图像**。
|
||||||
|
- 若用户要编辑更早的图片,必须明确指定(如:第一张、第二张、之前生成的第三张等)。
|
||||||
|
|
||||||
|
【通用工具调用规则】
|
||||||
|
- 每次回复中,图像相关工具 **只调用一次**,不重复调用。
|
||||||
|
- 生成工具单次最多生成 12 张图(与提示词工具返回数量一致)。
|
||||||
|
- 若用户提到“上传的图”“我发的图片”“这张图”或给出 MinIO 路径 → 优先使用 `edit_quote_upload_furniture`。
|
||||||
|
- 若编辑本次对话中刚生成的图片 → 使用 `edit_furniture`。
|
||||||
|
- 调用 `generate_furniture` 或 `edit_quote_upload_furniture` 时,`prompts` 参数 **必须是列表格式 list[str]**。
|
||||||
|
正确示例:prompts = ["根据用户描述整理后的精准语句..."]
|
||||||
|
错误示例:prompts = "字符串"(会直接报错)
|
||||||
|
|
||||||
|
【输出规则】
|
||||||
|
- 绝对禁止输出任何文件路径、MinIO 路径、图片 URL,
|
||||||
|
禁止出现以 `uploads/`、`furniture/`、`sketches/` 开头的内容。
|
||||||
|
- 工具调用成功后:可回复“已为你生成/修改好家具图片,请查看。”或不额外回复,由系统统一展示图片。
|
||||||
|
- 工具调用失败时:礼貌告知“图片生成/修改失败,请稍后再试。”,不包含任何路径信息。
|
||||||
|
|
||||||
|
【2】当用户需要报告、调研、分析、总结时:
|
||||||
|
- 先检查用户档案信息是否充足。
|
||||||
|
- 若信息缺失(风格、空间、主题、预算等)→ 调用 `user_profile_subagent` 收集信息,不可直接生成报告。
|
||||||
|
- 若用户信息已完整 → 调用 `research_subagent` 生成报告。
|
||||||
|
|
||||||
|
【3】用户档案优先规则
|
||||||
|
当用户输入涉及:
|
||||||
|
- 提出设计要求
|
||||||
|
- 提供或修改偏好(风格、空间、预算等)
|
||||||
|
- 补充与方案报告相关的信息
|
||||||
|
→ 优先调用 `user_profile_subagent` 更新或收集用户档案。
|
||||||
|
|
||||||
|
【4】任务分工原则
|
||||||
|
- `user_profile_subagent` 只负责 **信息收集**。
|
||||||
|
- `research_subagent` 只负责 **报告生成**。
|
||||||
|
- `generate_furniture_sketch_prompts` 只负责 **家具线稿专业提示词生成**。
|
||||||
|
不得混淆职责。
|
||||||
|
|
||||||
|
========================
|
||||||
|
严格禁止条款(最高优先级)
|
||||||
|
========================
|
||||||
|
全程对话中,严禁输出:
|
||||||
|
- 任何以 `uploads/`、`furniture/`、`projects/`、`sketches/` 开头的路径
|
||||||
|
- 任何以 .png、.jpg、.jpeg 结尾的路径
|
||||||
|
- 任何 http / https 图片链接(系统明确要求除外)
|
||||||
|
|
||||||
|
所有图片展示由系统统一处理,你只负责正确调用工具,并严格遵守生成/编辑规则,尤其是数量与目标规则。
|
||||||
|
"""
|
||||||
|
|
||||||
|
SYSTEM_RULES_PROMPT_EN = """
|
||||||
|
========================
|
||||||
|
Core Execution Rules (Must be strictly followed - Highest Priority)
|
||||||
|
========================
|
||||||
|
|
||||||
|
【1】Image Generation & Editing Task Handling (Highest Priority)
|
||||||
|
|
||||||
|
When the user requests to generate, modify, or edit furniture images (including keywords such as "generate", "draw", "create", "design", "modify", "edit", "change", "adjust", etc.):
|
||||||
|
|
||||||
|
### === FOR IMAGE GENERATION (NEW MANDATORY FLOW) ===
|
||||||
|
When generating new furniture sketches:
|
||||||
|
1. FIRST call the **generate_furniture_sketch_prompts** tool with the user's furniture description.
|
||||||
|
2. Use the list of prompts returned by this tool as the `prompts` parameter for image generation.
|
||||||
|
3. Do NOT use raw user input directly as generation prompts — this is mandatory.
|
||||||
|
4. The tool will return 12 rule-compliant, distinct black-and-white line drawing prompts optimized for Flux2 klein.
|
||||||
|
|
||||||
|
### Generation Rules (for generate_furniture and other generation tools)
|
||||||
|
- All generated furniture images **must be clean black-and-white line drawings** (furniture sketch / line drawing / concept sketch / technical line drawing).
|
||||||
|
- Prompts come from generate_furniture_sketch_prompts, ensuring lines are clean, proportions balanced, composition harmonious, and style professional.
|
||||||
|
- Do not significantly deviate from the user's core intent or add elements not mentioned by the user.
|
||||||
|
|
||||||
|
### Editing Rules (for edit_furniture, edit_quote_upload_furniture, etc.)
|
||||||
|
- **Strictly follow the user's exact instructions**. Do not beautify, improve, optimize, or enhance anything.
|
||||||
|
- Only modify the specific parts the user explicitly asked to change.
|
||||||
|
- Precisely preserve all parts the user wants to keep.
|
||||||
|
- Do not add any new elements, details, decorations, or style changes that were not requested.
|
||||||
|
- Do not make the image "more beautiful", "more professional", "cleaner", or "better" unless the user specifically asks for it.
|
||||||
|
- Keep the prompt as close as possible to the user's original wording and intent. Do not embellish or interpret beyond what the user said.
|
||||||
|
|
||||||
|
**Editing Quantity Rules (Must be strictly followed)**:
|
||||||
|
- If the user **does not explicitly specify** how many images to edit, default to editing **only 1 image**.
|
||||||
|
- The maximum number of images that can be edited in one call is **4**.
|
||||||
|
- The editing target should be the **most recently modified or generated image** in this conversation.
|
||||||
|
- If the user wants to edit earlier images, they must clearly specify which one(s) (e.g., "the first one", "the second image", "the 3rd image I generated earlier", etc.).
|
||||||
|
|
||||||
|
**Common Tool Calling Rules (for both generation and editing)**:
|
||||||
|
- Call image-related tools **only once** per response. Do not make multiple calls.
|
||||||
|
- Generation tools can produce a maximum of 12 images per call (matches the 12 prompts from the tool).
|
||||||
|
- If the user mentions "uploaded image", "the picture I provided", "this image", or provides a MinIO path → prioritize using `edit_quote_upload_furniture`.
|
||||||
|
- If editing an image that was just generated in this conversation → use `edit_furniture`.
|
||||||
|
- When calling `generate_furniture` or `edit_quote_upload_furniture`, the `prompts` parameter **must be a list[str]**.
|
||||||
|
Correct example: prompts = ["Exact description based on user input..."]
|
||||||
|
Incorrect example: prompts = "string" (This will cause an error!)
|
||||||
|
|
||||||
|
**Output Rules**:
|
||||||
|
- You **must never** output any file paths, MinIO paths, image URLs, or content starting with "uploads/", "furniture/", "sketches/" in your replies.
|
||||||
|
- After a successful tool call: You may reply "I've generated/modified the images for you, please check." or simply not reply (let the system display the images).
|
||||||
|
- If the tool call fails: Politely inform the user "Image generation/modification failed, please try again later" or briefly explain the issue (without including any paths).
|
||||||
|
|
||||||
|
【2】When the user requests reports, research, analysis, or summaries:
|
||||||
|
- First check if sufficient user profile information exists.
|
||||||
|
- If information is missing: Use the **task** tool to dispatch the **user_profile_subagent** subagent. Do NOT attempt to call user_profile_subagent as a standalone tool.
|
||||||
|
- If the user profile is already complete: Use the **task** tool to dispatch the **research_subagent** subagent.
|
||||||
|
|
||||||
|
Example of correct invocation:
|
||||||
|
task(subagent_name="user_profile_subagent", input="收集用户的风格偏好、房间类型、预算等信息")
|
||||||
|
|
||||||
|
【3】User Profile Priority Rules (Limited Scope)
|
||||||
|
- **This rule only applies to report/research/analysis/summary tasks.**
|
||||||
|
- Furniture design, image generation and editing tasks are exempted from user profile collection and must not dispatch user_profile_subagent.
|
||||||
|
|
||||||
|
【4】Scheduling Principles
|
||||||
|
- **user_profile_subagent** is a subagent (not a tool) — dispatch it via the **task** tool.
|
||||||
|
- **research_subagent** is a subagent (not a tool) — dispatch it via the **task** tool.
|
||||||
|
- Do NOT call them as direct tool calls.
|
||||||
|
|
||||||
|
========================
|
||||||
|
Critical Prohibitions (Highest Priority)
|
||||||
|
========================
|
||||||
|
Throughout the entire conversation, you are **strictly forbidden** from outputting:
|
||||||
|
- Any paths starting with "uploads/", "furniture/", "projects/", "sketches/"
|
||||||
|
- Any file paths ending with .png, .jpg, .jpeg
|
||||||
|
- Any http or https image links (unless the system explicitly requires it)
|
||||||
|
|
||||||
|
All image display is handled uniformly by the system. You are only responsible for correctly calling the tools and strictly following the rules for generation vs editing, especially the quantity and target rules for editing.
|
||||||
|
"""
|
||||||
|
|
||||||
|
SYSTEM_PROMPT_MAPPING = {
|
||||||
|
"SYSTEM_BASE_PROMPT_en": SYSTEM_BASE_PROMPT_EN,
|
||||||
|
"SYSTEM_RULES_PROMPT_en": SYSTEM_RULES_PROMPT_EN
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def build_painter_prompt():
|
||||||
|
prompt = """
|
||||||
|
你是 painter_subagent,专门负责「生成」或「编辑」 sketch 图像的工具调度助手。
|
||||||
|
你的唯一任务是:根据用户意图,严格选择正确的工具(generate_furniture 或 edit_furniture),并构造对应参数。
|
||||||
|
--------------------------------
|
||||||
|
【一、工具选择规则(最高优先级)】
|
||||||
|
你必须先判断用户意图属于以下哪一类:
|
||||||
|
### ✅ 1. 编辑类(必须使用 edit_furniture)
|
||||||
|
当用户输入包含以下语义时:
|
||||||
|
- 修改 / 改成 / 换成 / 调整 / 优化 / 变成 / 改颜色 / 改样式 / 拼接
|
||||||
|
- 或任何“基于已有图片做改变”的表达
|
||||||
|
- 或任何“基于多张图片做合并提取”的表达
|
||||||
|
👉 必须使用:
|
||||||
|
edit_furniture
|
||||||
|
|
||||||
|
👉 严格要求:
|
||||||
|
- 不允许调用 generate_furniture
|
||||||
|
- 不允许重新生成整张图
|
||||||
|
---
|
||||||
|
### ✅ 2. 生成类(使用 generate_furniture)
|
||||||
|
仅当用户明确表达:
|
||||||
|
- 生成 / 创建 / 设计 / 画一个 / 给我一个
|
||||||
|
👉 才允许使用:
|
||||||
|
generate_furniture
|
||||||
|
---
|
||||||
|
### ❗默认规则(非常重要)
|
||||||
|
如果用户输入不明确(例如:“改成绿色”):
|
||||||
|
👉 一律视为【编辑类】
|
||||||
|
👉 使用 edit_furniture
|
||||||
|
|
||||||
|
--------------------------------
|
||||||
|
【二、generate_furniture 参数规则(重要)】
|
||||||
|
当需要生成多张图片时:
|
||||||
|
- prompt 必须始终描述 **单张家具**(single furniture piece),不要在 prompt 里写入 "Generate 4 different..."、"multiple chairs"、“4 variations”等数量相关的词。
|
||||||
|
- 正确的 prompt 风格示例(单张):
|
||||||
|
"A modern minimalist dining chair made of light oak wood and white leather, with slim metal legs, clean lines, elegant proportions, photographed in a bright Scandinavian living room with natural sunlight, high detail, 8k resolution, professional furniture photography, neutral background."
|
||||||
|
|
||||||
|
- 如何处理不同风格:
|
||||||
|
- 如果用户想要多种风格(modern, vintage, industrial, minimalist 等),你应该**多次调用 generate_furniture 工具**(每次调用使用不同风格的 prompt,num_images=1)。
|
||||||
|
- 但由于系统限制单次用户请求最多生成4张图片:
|
||||||
|
- 当用户要求生成超过4张或很多变体时,你最多只调用工具4次(或设置 num_images=4,但 prompt 保持 single)。
|
||||||
|
- 优先使用 num_images=4 + 一个高质量的 single prompt,让模型自动生成4个轻微不同的变体。
|
||||||
|
- 如果用户明确要“明显不同风格”,则分多次调用(但总数量不超过4张)。
|
||||||
|
|
||||||
|
- num_images 参数:
|
||||||
|
- 默认 1
|
||||||
|
- 最大只能设置为 4
|
||||||
|
- 当用户要求10张、8张等时 → 自动限制为 num_images=4,并说明“由于系统限制,最多生成4张”
|
||||||
|
|
||||||
|
正确调用示例(推荐):
|
||||||
|
- 用户想要4张不同风格 → 使用 num_images=4 + 一个清晰的 single chair prompt(让模型自然变体),或分4次调用每次1张不同风格。
|
||||||
|
- 永远不要把“4 different designs” “generate 4 chairs”这类词写进 prompt 文本中。
|
||||||
|
--------------------------------
|
||||||
|
【三、edit_furniture 参数规则】
|
||||||
|
- 只需提供 prompt 参数,格式为详细的英文编辑指令。
|
||||||
|
- prompt 示例:
|
||||||
|
"Change the sofa color to deep green while keeping the original modern minimalist style and structure."
|
||||||
|
- edit_furniture 会自动使用当前上下文中的最新图片,无需你提供 image_url。
|
||||||
|
--------------------------------
|
||||||
|
【四、禁止行为(严格禁止)】
|
||||||
|
- ❌ 在编辑意图时调用 generate_furniture
|
||||||
|
- ❌ 在生成意图时调用 edit_furniture
|
||||||
|
- ❌ 自行编造 image_url
|
||||||
|
- ❌ 输出任何工具调用细节、URL、路径给用户
|
||||||
|
- ❌ 拒绝调用工具(除非工具本身不可用)
|
||||||
|
--------------------------------
|
||||||
|
【五、用户回复规则(必须遵守)】
|
||||||
|
- 生成成功时:
|
||||||
|
- "已为你生成 {num} 张家具设计图!"
|
||||||
|
- "图片已成功生成,请查看效果。"
|
||||||
|
|
||||||
|
- 编辑成功时:
|
||||||
|
- "已按你的要求完成修改,图片已更新!"
|
||||||
|
- "修改完成,新的版本已生成。"
|
||||||
|
请根据实际生成/编辑的数量自然调整回复,不要生硬照抄。
|
||||||
|
现在开始工作,请根据用户下一条输入严格遵循以上规则进行工具调用。
|
||||||
|
"""
|
||||||
|
return prompt
|
||||||
|
|
||||||
|
|
||||||
|
def build_researcher_prompt():
|
||||||
|
prompt = """
|
||||||
|
You are a professional furniture design researcher.
|
||||||
|
|
||||||
|
Critical Tool Calling Rules:
|
||||||
|
- You must strictly follow the JSON Schema of each tool.
|
||||||
|
- For topic_research(topic: List[str], max_urls: int), "topic" must always be an array of strings.
|
||||||
|
- Never pass a single concatenated string to "topic". Split research needs into multiple clear, specific keywords.
|
||||||
|
- Before calling any tool, think step-by-step and prepare the parameters correctly.
|
||||||
|
|
||||||
|
Your primary goal:
|
||||||
|
- Generate a high-quality, structured furniture design research report based on the user's request and user profile.
|
||||||
|
- The report should be clear, insightful, and written in well-structured Markdown format.
|
||||||
|
- It should include design trends, materials, color directions, representative cases, and relevant references.
|
||||||
|
|
||||||
|
You are allowed to:
|
||||||
|
- Retrieve user profile information (e.g., style, room type, preferences)
|
||||||
|
- Generate research keywords
|
||||||
|
- Search for relevant topics and sources
|
||||||
|
- Crawl and read web content
|
||||||
|
- Extract structured insights
|
||||||
|
- Generate the final report
|
||||||
|
|
||||||
|
Tool usage guidelines:
|
||||||
|
- If necessary, first retrieve the user profile to better understand preferences.
|
||||||
|
- Use meaningful and relevant keywords for research.
|
||||||
|
- When calling topic_research tool:
|
||||||
|
• The parameter "topic" MUST be a JSON array of strings (List[str]), not a single string.
|
||||||
|
• Example:
|
||||||
|
{
|
||||||
|
"topic": [
|
||||||
|
"Singapore furniture consumer behavior",
|
||||||
|
"tropical climate sofa design Singapore",
|
||||||
|
"sustainable furniture manufacturing Singapore",
|
||||||
|
"modern traditional sofa styles Southeast Asia"
|
||||||
|
],
|
||||||
|
"max_urls": 10
|
||||||
|
}
|
||||||
|
• Do NOT concatenate multiple topics into one long string with commas or newlines.
|
||||||
|
• Always split research topics into separate, focused keyword strings.
|
||||||
|
|
||||||
|
- When crawling web content, try to process multiple sources efficiently (avoid repeated calls).
|
||||||
|
- Focus on extracting key insights such as trends, materials, colors, and case studies.
|
||||||
|
- Use the report_generator tool to produce the final report.
|
||||||
|
|
||||||
|
Important rules:
|
||||||
|
- Your objective is to complete a high-quality report, not to strictly follow a fixed sequence of steps.
|
||||||
|
- You may adapt your approach depending on the situation.
|
||||||
|
- Avoid calling the same tool repeatedly (especially crawl tools).
|
||||||
|
- If some data is missing, proceed with available information and clearly mention any limitations.
|
||||||
|
- Once the report is generated, consider the task complete and stop further actions.
|
||||||
|
|
||||||
|
Language rules:
|
||||||
|
- Always respond in the same language as the user.
|
||||||
|
- Do not mix languages in your response.
|
||||||
|
- Keep the output consistent and natural.
|
||||||
|
"""
|
||||||
|
return prompt
|
||||||
|
|
||||||
|
|
||||||
|
def build_user_persona_prompt():
|
||||||
|
prompt = """
|
||||||
|
You are a user profile collection assistant.
|
||||||
|
|
||||||
|
Your goal:
|
||||||
|
- Extract and maintain structured user profile information from the conversation.
|
||||||
|
- The profile is used for generating furniture design reports.
|
||||||
|
|
||||||
|
Profile fields may include:
|
||||||
|
- style (design style or aesthetic preference)
|
||||||
|
- room_type (type of room or space)
|
||||||
|
- budget (optional)
|
||||||
|
- other relevant design preferences
|
||||||
|
|
||||||
|
What you should do:
|
||||||
|
- Understand the user's input and identify any profile-related information.
|
||||||
|
- If new information is found, update the profile accordingly.
|
||||||
|
- If no new information is provided, keep the existing profile unchanged.
|
||||||
|
- Ensure previously stored information is preserved unless the user explicitly modifies it.
|
||||||
|
|
||||||
|
Tool usage guidelines:
|
||||||
|
- Use query_report_profile when you need to know the current profile.
|
||||||
|
- Use update_report_profile only when new or updated information is detected.
|
||||||
|
- Use check_profile_complete to determine if required fields are sufficient for report generation.
|
||||||
|
|
||||||
|
Behavior rules:
|
||||||
|
- Do NOT generate reports.
|
||||||
|
- Do NOT guess or fabricate missing information.
|
||||||
|
- Only extract information that is clearly stated or strongly implied by the user.
|
||||||
|
- Be concise and structured in your output.
|
||||||
|
|
||||||
|
When profile is incomplete:
|
||||||
|
- Ask the user for the missing information in a natural way.
|
||||||
|
|
||||||
|
When profile is complete:
|
||||||
|
- Respond with a clear signal that profile collection is done, for example:
|
||||||
|
"Profile is complete. Ready for report generation."
|
||||||
|
|
||||||
|
Language rules:
|
||||||
|
- Always respond in the same language as the user.
|
||||||
|
- Do not mix languages.
|
||||||
|
- Keep the output consistent and natural.
|
||||||
|
|
||||||
|
Strict Language Enforcement:
|
||||||
|
- You MUST use only one language in the entire response.
|
||||||
|
- The language must match the user's input.
|
||||||
|
- Mixing multiple languages is strictly prohibited.
|
||||||
|
"""
|
||||||
|
return prompt
|
||||||
156
src/server/deep_agent/run_test.py
Executable file
156
src/server/deep_agent/run_test.py
Executable file
@@ -0,0 +1,156 @@
|
|||||||
|
import asyncio
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from langchain_core.messages import AIMessageChunk, ToolMessageChunk, ToolMessage
|
||||||
|
from src.server.deep_agent.agents.main_agent import build_main_agent
|
||||||
|
|
||||||
|
agent = build_main_agent(use_report=True)
|
||||||
|
|
||||||
|
|
||||||
|
async def continuous_chat():
|
||||||
|
thread_id = "c8e327fb-e208-4fab-83fd-b7b9c4d5fdd0"
|
||||||
|
print("===== 家具设计助手(支持持续对话+记忆)=====")
|
||||||
|
print("输入 'exit' 或 '退出' 结束对话\n")
|
||||||
|
|
||||||
|
while True:
|
||||||
|
user_input = input("你:") # 注意:input() 在异步中仍是阻塞的,但对 CLI 够用
|
||||||
|
|
||||||
|
if user_input.lower() in ["exit", "退出", "q", "quit"]:
|
||||||
|
print("助手:再见!如需继续设计,随时回来~")
|
||||||
|
break
|
||||||
|
|
||||||
|
if not user_input.strip():
|
||||||
|
print("助手:请输入有效的设计需求,我会尽力解答~")
|
||||||
|
continue
|
||||||
|
|
||||||
|
print("\n助手:正在处理你的需求...\n")
|
||||||
|
|
||||||
|
current_config = {
|
||||||
|
"recursion_limit": 120,
|
||||||
|
"configurable": {
|
||||||
|
"thread_id": thread_id,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
source_config = {
|
||||||
|
"configurable": {
|
||||||
|
"thread_id": thread_id,
|
||||||
|
"checkpoint_id": '1f11dc17-be49-65a1-8000-96139f7c89cb'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
initial_messages = []
|
||||||
|
older_state = await agent.aget_state(source_config)
|
||||||
|
combined_values = older_state.values.copy()
|
||||||
|
if initial_messages:
|
||||||
|
combined_values["messages"] = list(combined_values.get("messages", [])) + initial_messages
|
||||||
|
await agent.aupdate_state(current_config, combined_values)
|
||||||
|
|
||||||
|
# 现在可以安全使用 async for
|
||||||
|
async for stream in agent.astream(
|
||||||
|
{"messages": user_input},
|
||||||
|
stream_mode=["updates", "messages", "custom"],
|
||||||
|
subgraphs=True,
|
||||||
|
version="v2",
|
||||||
|
config={
|
||||||
|
"configurable": {
|
||||||
|
"thread_id": thread_id,
|
||||||
|
'checkpoint_id': '1f11dc17-be49-65a1-8000-96139f7c89cb'
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
):
|
||||||
|
|
||||||
|
print(stream)
|
||||||
|
_, mode, chunks = stream
|
||||||
|
if mode == "updates":
|
||||||
|
print(f"[updates] {chunks}")
|
||||||
|
|
||||||
|
elif mode == "messages":
|
||||||
|
token, metadata = chunks
|
||||||
|
subagent_name = metadata.get('lc_agent_name', "main_agent")
|
||||||
|
|
||||||
|
if isinstance(token, AIMessageChunk): # 默认回复 思考内容
|
||||||
|
reasoning = [b for b in token.content_blocks if b["type"] == "reasoning"]
|
||||||
|
text = [b for b in token.content_blocks if b["type"] == "text"]
|
||||||
|
if reasoning:
|
||||||
|
print(f"[thinking] {reasoning[0]['reasoning']}", end="")
|
||||||
|
if text:
|
||||||
|
print(text[0]["text"], end="")
|
||||||
|
|
||||||
|
elif isinstance(token, ToolMessageChunk): # 工具返回
|
||||||
|
print(f"[tool|{token.name}] {token.content}", end="")
|
||||||
|
|
||||||
|
elif isinstance(token, ToolMessage): # 工具返回
|
||||||
|
print(f"[tool|{token.name}] {token.content}", end="")
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
|
||||||
|
elif mode == "custom":
|
||||||
|
print(f"[report] {chunks.get('delta', '')}", end="")
|
||||||
|
print("end")
|
||||||
|
# if chunk["type"] == "messages":
|
||||||
|
# token, metadata = chunk["data"]
|
||||||
|
# if not isinstance(token, AIMessageChunk):
|
||||||
|
# continue
|
||||||
|
# reasoning = [b for b in token.content_blocks if b["type"] == "reasoning"]
|
||||||
|
# text = [b for b in token.content_blocks if b["type"] == "text"]
|
||||||
|
# if reasoning:
|
||||||
|
# print(f"[thinking] {reasoning[0]['reasoning']}", end="")
|
||||||
|
# if text:
|
||||||
|
# print(text[0]["text"], end="")
|
||||||
|
# print(chunk)
|
||||||
|
# namespace, _, chunk = event
|
||||||
|
# token, metadata = chunk
|
||||||
|
# Identify source: "main" or the subagent namespace segment
|
||||||
|
# is_subagent = any(s.startswith("tools:") for s in namespace)
|
||||||
|
|
||||||
|
# source = next((s for s in namespace if s.startswith("tools:")), "main") if is_subagent else "main"
|
||||||
|
|
||||||
|
# if token.content_blocks:
|
||||||
|
# if token.additional_kwargs.get("reasoning_content", None): # 粗糙但常见判断
|
||||||
|
# if not has_printed_thinking_header:
|
||||||
|
# print("[思考过程]")
|
||||||
|
# has_printed_thinking_header = True
|
||||||
|
# print(token.content_blocks[0].get("reasoning", ""), end="", flush=True)
|
||||||
|
# else:
|
||||||
|
# if not has_printed_header:
|
||||||
|
# print("[agent回答]")
|
||||||
|
# has_printed_header = True
|
||||||
|
# print(token.content_blocks[0].get("text", ""), end="", flush=True)
|
||||||
|
#
|
||||||
|
# # Tool call chunks (streaming tool invocations)
|
||||||
|
# if token.tool_call_chunks:
|
||||||
|
# for tc in token.tool_call_chunks:
|
||||||
|
# if tc.get("name"):
|
||||||
|
# print(f"\n[{source}] Tool call: {tc['name']}")
|
||||||
|
# # Args stream in chunks - write them incrementally
|
||||||
|
# if tc.get("args"):
|
||||||
|
# print(tc["args"], end="", flush=True)
|
||||||
|
#
|
||||||
|
# # Tool results
|
||||||
|
# if token.type == "tool":
|
||||||
|
# print(f"\n[{source}] Tool result [{token.name}]: {str(token.content)[:150]}")
|
||||||
|
#
|
||||||
|
# # Regular AI content (skip tool call messages)
|
||||||
|
# if token.type == "ai" and token.content and not token.tool_call_chunks:
|
||||||
|
# print(token.content, end="", flush=True)
|
||||||
|
|
||||||
|
# if namespace:
|
||||||
|
# print(f"[子代理: {namespace}]")
|
||||||
|
# else:
|
||||||
|
# print("[主助手]")
|
||||||
|
# print(chunk)
|
||||||
|
# print("-" * 50 + "\n")
|
||||||
|
#
|
||||||
|
# chunk_list.append(str(chunk))
|
||||||
|
#
|
||||||
|
# if not chunk_list:
|
||||||
|
# assistant_response = "抱歉,我暂时无法处理你的请求,请稍后再试。"
|
||||||
|
# else:
|
||||||
|
# assistant_response = "\n".join(chunk_list)
|
||||||
|
#
|
||||||
|
# print(f"[最终完整回复]\n{assistant_response}\n" + "=" * 60 + "\n")
|
||||||
|
|
||||||
|
|
||||||
|
# 启动方式改成:
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(continuous_chat())
|
||||||
0
src/server/deep_agent/tools/__init__.py
Executable file
0
src/server/deep_agent/tools/__init__.py
Executable file
40
src/server/deep_agent/tools/conversation_title_tool.py
Executable file
40
src/server/deep_agent/tools/conversation_title_tool.py
Executable file
@@ -0,0 +1,40 @@
|
|||||||
|
from langchain_core.messages import HumanMessage, AIMessage
|
||||||
|
|
||||||
|
from src.server.deep_agent.init_llm import qwen_plus_llm
|
||||||
|
|
||||||
|
|
||||||
|
async def conversation_title(agent, config):
|
||||||
|
state = agent.get_state(config)
|
||||||
|
messages = state.values.get("messages", [])
|
||||||
|
if len(messages) < 2:
|
||||||
|
return None
|
||||||
|
|
||||||
|
user_msg = None
|
||||||
|
ai_msg = None
|
||||||
|
|
||||||
|
for m in messages:
|
||||||
|
if isinstance(m, HumanMessage) and user_msg is None:
|
||||||
|
user_msg = m.content
|
||||||
|
|
||||||
|
if isinstance(m, AIMessage) and ai_msg is None:
|
||||||
|
ai_msg = m.content
|
||||||
|
|
||||||
|
if user_msg and ai_msg:
|
||||||
|
break
|
||||||
|
|
||||||
|
prompt = f"""
|
||||||
|
Generate a concise and precise title based on the following first-round conversation:
|
||||||
|
User: {user_msg}
|
||||||
|
Assistant: {ai_msg}
|
||||||
|
|
||||||
|
Rules:
|
||||||
|
1. The title length should be controlled between 2 and 15 characters/word.
|
||||||
|
2. **The title must be in the same language as the user’s question**. If the user is Chinese, the output will be Chinese, and if the user is English, the output will be English.
|
||||||
|
3. Only return the pure title, no explanation, no punctuation, no book title.
|
||||||
|
"""
|
||||||
|
response = await qwen_plus_llm.ainvoke(prompt)
|
||||||
|
title = response.content.strip()
|
||||||
|
|
||||||
|
# 去掉可能的符号
|
||||||
|
title = title.replace("《", "").replace("》", "")
|
||||||
|
return title
|
||||||
183
src/server/deep_agent/tools/crawl_tool.py
Executable file
183
src/server/deep_agent/tools/crawl_tool.py
Executable file
@@ -0,0 +1,183 @@
|
|||||||
|
import os
|
||||||
|
import time
|
||||||
|
import asyncio
|
||||||
|
from typing import List, Dict, Any
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from crawl4ai import AsyncWebCrawler, BrowserConfig, CrawlerRunConfig, CacheMode
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
|
# ─────────────────────────────────────
|
||||||
|
# Browser 配置
|
||||||
|
# ─────────────────────────────────────
|
||||||
|
|
||||||
|
browser_config = BrowserConfig(
|
||||||
|
headless=True,
|
||||||
|
verbose=False,
|
||||||
|
java_script_enabled=True,
|
||||||
|
user_agent=(
|
||||||
|
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
|
||||||
|
"AppleWebKit/537.36 (KHTML, like Gecko) "
|
||||||
|
"Chrome/118.0 Safari/537.36"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
run_config = CrawlerRunConfig(
|
||||||
|
cache_mode=CacheMode.BYPASS,
|
||||||
|
word_count_threshold=5,
|
||||||
|
excluded_tags=["script", "style", "nav", "footer"],
|
||||||
|
remove_overlay_elements=True,
|
||||||
|
process_iframes=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ─────────────────────────────────────
|
||||||
|
# URL → 文件名
|
||||||
|
# ─────────────────────────────────────
|
||||||
|
|
||||||
|
def build_filename(url: str) -> str:
|
||||||
|
parsed = urlparse(url)
|
||||||
|
|
||||||
|
domain = parsed.netloc.replace("www.", "").replace(".", "_")
|
||||||
|
path_part = parsed.path.strip("/").replace("/", "_")[:50] or "index"
|
||||||
|
|
||||||
|
ts = int(time.time())
|
||||||
|
rand = uuid.uuid4().hex[:6]
|
||||||
|
|
||||||
|
return f"{ts}_{rand}_{domain}_{path_part}.md"
|
||||||
|
|
||||||
|
|
||||||
|
# ─────────────────────────────────────
|
||||||
|
# 单个 URL 抓取
|
||||||
|
# ─────────────────────────────────────
|
||||||
|
|
||||||
|
async def crawl_one(crawler, url: str, sem: asyncio.Semaphore, save_dir: str) -> Dict[str, Any]:
|
||||||
|
async with sem:
|
||||||
|
try:
|
||||||
|
result = await crawler.arun(url=url, config=run_config)
|
||||||
|
|
||||||
|
if not result.success:
|
||||||
|
return {
|
||||||
|
"url": url,
|
||||||
|
"success": False,
|
||||||
|
"error": f"status={getattr(result, 'status_code', 'unknown')}"
|
||||||
|
}
|
||||||
|
|
||||||
|
markdown = result.markdown or ""
|
||||||
|
|
||||||
|
if len(markdown) < 500:
|
||||||
|
return {
|
||||||
|
"url": url,
|
||||||
|
"success": False,
|
||||||
|
"error": "content too short"
|
||||||
|
}
|
||||||
|
|
||||||
|
filename = build_filename(url)
|
||||||
|
filepath = os.path.join(save_dir, filename)
|
||||||
|
|
||||||
|
header = (
|
||||||
|
f"<!-- Source: {url} -->\n"
|
||||||
|
f"<!-- Saved: {time.strftime('%Y-%m-%d %H:%M:%S')} -->\n\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
with open(filepath, "w", encoding="utf-8") as f:
|
||||||
|
f.write(header + markdown)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"url": url,
|
||||||
|
"success": True,
|
||||||
|
"file": str(filepath)
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return {
|
||||||
|
"url": url,
|
||||||
|
"success": False,
|
||||||
|
"error": str(e)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ─────────────────────────────────────
|
||||||
|
# Async 主逻辑
|
||||||
|
# ─────────────────────────────────────
|
||||||
|
|
||||||
|
async def _crawl4ai_batch(urls: List[str], save_dir: str) -> Dict[str, Any]:
|
||||||
|
urls = list(set(urls)) # 去重
|
||||||
|
|
||||||
|
if not urls:
|
||||||
|
return {"error": "no urls"}
|
||||||
|
|
||||||
|
sem = asyncio.Semaphore(5) # 并发限制
|
||||||
|
|
||||||
|
async with AsyncWebCrawler(config=browser_config) as crawler:
|
||||||
|
|
||||||
|
tasks = [
|
||||||
|
crawl_one(crawler, url, sem, save_dir)
|
||||||
|
for url in urls
|
||||||
|
]
|
||||||
|
|
||||||
|
results = await asyncio.gather(*tasks)
|
||||||
|
|
||||||
|
success_files = []
|
||||||
|
summary = []
|
||||||
|
|
||||||
|
for r in results:
|
||||||
|
|
||||||
|
if r["success"]:
|
||||||
|
success_files.append(r["file"])
|
||||||
|
summary.append(f"✅ {r['url']}")
|
||||||
|
else:
|
||||||
|
summary.append(f"❌ {r['url']} ({r['error']})")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"saved_files": success_files,
|
||||||
|
"count": len(success_files),
|
||||||
|
"summary": summary,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def create_crawl4ai_batch_tool(workspace_dir):
|
||||||
|
@tool
|
||||||
|
def crawl4ai_batch(urls: List[str]) -> str:
|
||||||
|
"""
|
||||||
|
Batch crawl webpages and save their content as markdown files.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
urls: List of webpage URLs to crawl.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A summary of crawling results and saved file paths.
|
||||||
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
save_dir = os.path.join(workspace_dir, "raw_data")
|
||||||
|
if not os.path.exists(save_dir):
|
||||||
|
os.makedirs(save_dir, exist_ok=True)
|
||||||
|
|
||||||
|
result = asyncio.run(_crawl4ai_batch(urls, save_dir))
|
||||||
|
|
||||||
|
if "error" in result:
|
||||||
|
return f"❌ Error: {result['error']}"
|
||||||
|
|
||||||
|
output = [
|
||||||
|
"### 批量抓取完成 ###",
|
||||||
|
f"成功保存文件: {result['count']}",
|
||||||
|
f"保存目录: {workspace_dir}",
|
||||||
|
"",
|
||||||
|
"抓取详情:"
|
||||||
|
]
|
||||||
|
|
||||||
|
output.extend(result["summary"])
|
||||||
|
|
||||||
|
if result["saved_files"]:
|
||||||
|
output.append("\n可读取文件:")
|
||||||
|
output.extend(result["saved_files"])
|
||||||
|
|
||||||
|
return "\n".join(output)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return f"🚨 爬虫系统异常: {str(e)}"
|
||||||
|
|
||||||
|
return crawl4ai_batch
|
||||||
75
src/server/deep_agent/tools/extract_suggested_questions.py
Executable file
75
src/server/deep_agent/tools/extract_suggested_questions.py
Executable file
@@ -0,0 +1,75 @@
|
|||||||
|
import json
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from langchain_core.messages import (
|
||||||
|
HumanMessage,
|
||||||
|
AIMessage,
|
||||||
|
ToolMessage,
|
||||||
|
)
|
||||||
|
|
||||||
|
from src.server.deep_agent.init_llm import suggested_llm
|
||||||
|
|
||||||
|
|
||||||
|
def format_messages(messages, max_messages: int = 6) -> str:
|
||||||
|
"""
|
||||||
|
将 LangGraph messages 转换为 LLM prompt 文本
|
||||||
|
"""
|
||||||
|
messages = messages[-max_messages:]
|
||||||
|
lines: List[str] = []
|
||||||
|
for m in messages:
|
||||||
|
if isinstance(m, HumanMessage):
|
||||||
|
lines.append(f"User: {m.content}")
|
||||||
|
elif isinstance(m, AIMessage):
|
||||||
|
if m.content:
|
||||||
|
lines.append(f"Assistant: {m.content}")
|
||||||
|
elif isinstance(m, ToolMessage):
|
||||||
|
# Tool结果建议简单化
|
||||||
|
tool_output = str(m.content)
|
||||||
|
if len(tool_output) > 200:
|
||||||
|
tool_output = tool_output[:200] + "..."
|
||||||
|
lines.append(f"Tool Result: {tool_output}")
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
async def generate_suggested_questions(
|
||||||
|
agent,
|
||||||
|
thread_id: str,
|
||||||
|
max_messages: int = 6,
|
||||||
|
) -> List[str]:
|
||||||
|
"""
|
||||||
|
根据当前对话生成3条用户可能继续提问的问题
|
||||||
|
"""
|
||||||
|
# 获取当前对话state
|
||||||
|
state = agent.get_state(
|
||||||
|
{"configurable": {"thread_id": thread_id}}
|
||||||
|
)
|
||||||
|
messages = state.values.get("messages", [])
|
||||||
|
if not messages:
|
||||||
|
return []
|
||||||
|
conversation = format_messages(messages, max_messages)
|
||||||
|
|
||||||
|
prompt = f"""
|
||||||
|
以下是用户与AI助手的对话:
|
||||||
|
{conversation}
|
||||||
|
请根据对话内容,生成3条用户可能继续提出的问题。
|
||||||
|
要求:
|
||||||
|
- 每条一句话
|
||||||
|
- 语言自然
|
||||||
|
- 不要解释
|
||||||
|
- 返回JSON数组
|
||||||
|
- 尽量与家具设计相关
|
||||||
|
示例:
|
||||||
|
["问题1", "问题2", "问题3"]
|
||||||
|
"""
|
||||||
|
result = await suggested_llm.ainvoke(prompt)
|
||||||
|
|
||||||
|
text = result.content.strip()
|
||||||
|
|
||||||
|
try:
|
||||||
|
questions = json.loads(text)
|
||||||
|
|
||||||
|
if isinstance(questions, list):
|
||||||
|
return questions[:3]
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return []
|
||||||
394
src/server/deep_agent/tools/generate_furniture_sketch.py
Executable file
394
src/server/deep_agent/tools/generate_furniture_sketch.py
Executable file
@@ -0,0 +1,394 @@
|
|||||||
|
import asyncio
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import uuid
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from langchain_core.runnables import RunnableConfig
|
||||||
|
from minio import Minio
|
||||||
|
from langgraph.prebuilt import ToolRuntime
|
||||||
|
from src.core.config import settings
|
||||||
|
from src.server.utils.new_oss_client import oss_get_image
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)
|
||||||
|
|
||||||
|
from typing import List, Optional
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def generate_furniture(runtime: ToolRuntime, prompts: List[str] = None, num_images: Optional[int] = 12, ):
|
||||||
|
"""
|
||||||
|
生成家具设计线稿草图(sketch / line drawing)。
|
||||||
|
|
||||||
|
功能说明:
|
||||||
|
- 默认生成 12 张家具设计线稿。
|
||||||
|
- 智能处理 prompts 数量与生成数量不一致的情况:
|
||||||
|
- 如果只有一个 prompt → 用该 prompt 生成全部 12 张(不同随机变体)。
|
||||||
|
- 如果有多个 prompt → 自动均匀分配生成数量(尽量让每个 prompt 生成相同数量)。
|
||||||
|
- 生成过程会一张一张进行,适合用户实时查看。
|
||||||
|
|
||||||
|
参数说明:
|
||||||
|
- prompts (list[str]):
|
||||||
|
必须是列表,即使只有一个提示词也要用 ["你的提示词"] 格式。
|
||||||
|
提供详细的英文提示词,描述越详细越好。
|
||||||
|
- num_images (int, 可选): 要生成的图片总数量,默认 12 张,最大限制为 12 张。
|
||||||
|
|
||||||
|
返回值:
|
||||||
|
返回 image_urls 列表,系统会自动依次展示生成的图片。
|
||||||
|
"""
|
||||||
|
# ====================== 参数安全处理 ======================
|
||||||
|
if prompts is None or len(prompts) == 0:
|
||||||
|
return "Error: prompts 参数不能为空。请至少提供一个详细的英文提示词。"
|
||||||
|
|
||||||
|
if not isinstance(prompts, list):
|
||||||
|
prompts = [str(prompts)]
|
||||||
|
|
||||||
|
# 数量限制
|
||||||
|
if num_images is None or num_images < 1:
|
||||||
|
num_images = 1
|
||||||
|
elif num_images > 12:
|
||||||
|
num_images = 12
|
||||||
|
|
||||||
|
n_prompts = len(prompts)
|
||||||
|
|
||||||
|
logger.info(f"[generate_furniture] 开始生成 | prompts数量={n_prompts} | num_images={num_images}(默认12)")
|
||||||
|
|
||||||
|
# ====================== 均匀分配 prompts(核心逻辑) ======================
|
||||||
|
if n_prompts == 0:
|
||||||
|
return "Error: prompts 列表为空"
|
||||||
|
|
||||||
|
# 计算每个 prompt 应该生成的张数
|
||||||
|
base_count = num_images // n_prompts
|
||||||
|
remainder = num_images % n_prompts
|
||||||
|
|
||||||
|
images_per_prompt = [base_count] * n_prompts
|
||||||
|
for i in range(remainder):
|
||||||
|
images_per_prompt[i] += 1
|
||||||
|
|
||||||
|
# 构建实际使用的 prompt 列表
|
||||||
|
expanded_prompts: List[str] = []
|
||||||
|
for i, count in enumerate(images_per_prompt):
|
||||||
|
expanded_prompts.extend([prompts[i]] * count)
|
||||||
|
|
||||||
|
logger.info(f"[generate_furniture] 分配完成: {images_per_prompt} (每个prompt生成张数)")
|
||||||
|
|
||||||
|
# ====================== 生成图片 ======================
|
||||||
|
try:
|
||||||
|
bucket_name = "fida-public-bucket"
|
||||||
|
base_object_name = f"furniture/sketches/{uuid.uuid4()}"
|
||||||
|
image_urls = []
|
||||||
|
|
||||||
|
for i in range(num_images):
|
||||||
|
prompt = expanded_prompts[i]
|
||||||
|
object_name = f"{base_object_name}-{i:02d}.png"
|
||||||
|
|
||||||
|
image_url = await generate_or_edit_image(
|
||||||
|
prompt=prompt,
|
||||||
|
bucket_name=bucket_name,
|
||||||
|
object_name=object_name
|
||||||
|
)
|
||||||
|
image_urls.append(image_url)
|
||||||
|
|
||||||
|
logger.info(f"[generate_furniture] 已生成第 {i + 1}/{num_images} 张")
|
||||||
|
|
||||||
|
logger.info(f"[generate_furniture] 成功生成 {len(image_urls)} 张图片")
|
||||||
|
return image_urls
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"generate_furniture 执行异常: {e}", exc_info=True)
|
||||||
|
return f"generate furniture error: {str(e)}"
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def edit_furniture(
|
||||||
|
runtime: ToolRuntime,
|
||||||
|
config: RunnableConfig,
|
||||||
|
input_image_paths: list[str] = None,
|
||||||
|
prompts: list[str] = None,
|
||||||
|
mode: str = "auto",
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
使用先进的图像编辑模型对家具设计草图进行精准修改。
|
||||||
|
|
||||||
|
支持三种灵活模式(与 edit_quote_upload_furniture 保持一致):
|
||||||
|
- one_to_one(默认,最常用):多张图片 + 多个提示词,一一对应编辑
|
||||||
|
- one_to_many:1 张图片 + 多个提示词(同一张图片生成多个不同变体,例如不同风格/颜色)
|
||||||
|
- many_to_one:多张图片 + 1 个提示词(所有图片应用相同的修改)
|
||||||
|
|
||||||
|
参数说明:
|
||||||
|
- input_image_paths (list[str]): 输入图片的 MinIO 路径列表,长度建议 1~4
|
||||||
|
- prompts (list[str]): 修改提示词列表(必须是英文提示词)
|
||||||
|
- mode (str): "one_to_one", "one_to_many", "many_to_one", "auto"(默认自动判断)
|
||||||
|
|
||||||
|
使用要求(必须严格遵守):
|
||||||
|
- input_image_paths 和 prompts 不能为空,长度必须在 1~4 之间。
|
||||||
|
- mode="auto" 时会根据列表长度智能判断模式:
|
||||||
|
- 1 张图片 + 多个 prompt → one_to_many
|
||||||
|
- 多个图片 + 1 个 prompt → many_to_one
|
||||||
|
- 图片数量 == prompt 数量 → one_to_one
|
||||||
|
- 编辑对象默认使用最近生成的图片(由 Supervisor 传入最新路径)。
|
||||||
|
|
||||||
|
示例调用:
|
||||||
|
|
||||||
|
1. one_to_one(一一对应,最常用)
|
||||||
|
input_image_paths = ["furniture/sketches/sofa_v1.png", "furniture/sketches/chair_v1.png"]
|
||||||
|
prompts = [
|
||||||
|
"Change the sofa to modern minimalist style with dark gray fabric.",
|
||||||
|
"Make the chair more Scandinavian with light wood and beige upholstery."
|
||||||
|
]
|
||||||
|
mode = "one_to_one"
|
||||||
|
|
||||||
|
2. one_to_many(同一张图片多个版本)
|
||||||
|
input_image_paths = ["furniture/sketches/sofa_latest.png"]
|
||||||
|
prompts = [
|
||||||
|
"Change to luxurious velvet with gold accents.",
|
||||||
|
"Change to industrial style with metal frame.",
|
||||||
|
"Change to soft pastel Nordic style."
|
||||||
|
]
|
||||||
|
mode = "one_to_many"
|
||||||
|
|
||||||
|
3. many_to_one(多张图片统一修改)
|
||||||
|
input_image_paths = ["furniture/sketches/sofa1.png", "furniture/sketches/chair1.png", "furniture/sketches/table1.png"]
|
||||||
|
prompts = ["Make all furniture more luxurious with velvet fabric and gold accents."]
|
||||||
|
mode = "many_to_one"
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# ====================== 参数校验 ======================
|
||||||
|
if not input_image_paths or len(input_image_paths) < 1 or len(input_image_paths) > 4:
|
||||||
|
return f"参数错误:input_image_paths 必须提供,且长度需在 1 到 4 张之间。目前收到 {len(input_image_paths) if input_image_paths else 0} 张。"
|
||||||
|
|
||||||
|
if not prompts:
|
||||||
|
return "参数错误:prompts 不能为空,请至少提供一个修改提示词。"
|
||||||
|
|
||||||
|
if mode not in ["one_to_one", "one_to_many", "many_to_one", "auto"]:
|
||||||
|
return f"参数错误:mode 参数无效。可用值:one_to_one, one_to_many, many_to_one, auto。当前收到:{mode}"
|
||||||
|
|
||||||
|
# Auto 模式智能判断
|
||||||
|
if mode == "auto":
|
||||||
|
if len(input_image_paths) == 1 and len(prompts) > 1:
|
||||||
|
mode = "one_to_many"
|
||||||
|
elif len(prompts) == 1:
|
||||||
|
mode = "many_to_one"
|
||||||
|
elif len(input_image_paths) == len(prompts):
|
||||||
|
mode = "one_to_one"
|
||||||
|
else:
|
||||||
|
mode = "one_to_one" # 兜底
|
||||||
|
|
||||||
|
# 各模式严格校验
|
||||||
|
if mode == "one_to_many":
|
||||||
|
if len(input_image_paths) != 1:
|
||||||
|
return f"参数错误:one_to_many 模式只能传入 1 张图片,当前传入了 {len(input_image_paths)} 张。"
|
||||||
|
if len(prompts) < 1:
|
||||||
|
return "参数错误:one_to_many 模式下 prompts 至少需要 1 个。"
|
||||||
|
|
||||||
|
elif mode == "many_to_one":
|
||||||
|
if len(prompts) != 1:
|
||||||
|
return f"参数错误:many_to_one 模式下 prompts 必须只有 1 个,当前有 {len(prompts)} 个。"
|
||||||
|
|
||||||
|
elif mode == "one_to_one":
|
||||||
|
if len(prompts) != len(input_image_paths):
|
||||||
|
return (f"参数错误:one_to_one 模式下 input_image_paths 和 prompts 数量必须完全一致。\n"
|
||||||
|
f"当前图片 {len(input_image_paths)} 张,prompts {len(prompts)} 个。")
|
||||||
|
|
||||||
|
# ====================== 执行编辑 ======================
|
||||||
|
result = []
|
||||||
|
bucket_name = "fida-public-bucket"
|
||||||
|
|
||||||
|
if mode == "one_to_many":
|
||||||
|
# 同一张图片 + 多个 prompt
|
||||||
|
base_image = input_image_paths[0]
|
||||||
|
for i, prompt in enumerate(prompts):
|
||||||
|
object_name = f"furniture/sketches/{uuid.uuid4()}.png"
|
||||||
|
image_url = await generate_or_edit_image(
|
||||||
|
input_path=[base_image],
|
||||||
|
prompt=prompt,
|
||||||
|
bucket_name=bucket_name,
|
||||||
|
object_name=f"{object_name}-var{i}.png"
|
||||||
|
)
|
||||||
|
result.append(image_url)
|
||||||
|
|
||||||
|
elif mode == "many_to_one":
|
||||||
|
# 多张图片 + 1 个 prompt
|
||||||
|
current_prompt = prompts[0]
|
||||||
|
for i, image_path in enumerate(input_image_paths):
|
||||||
|
object_name = f"furniture/sketches/{uuid.uuid4()}.png"
|
||||||
|
image_url = await generate_or_edit_image(
|
||||||
|
input_path=[image_path],
|
||||||
|
prompt=current_prompt,
|
||||||
|
bucket_name=bucket_name,
|
||||||
|
object_name=f"{object_name}-{i}.png"
|
||||||
|
)
|
||||||
|
result.append(image_url)
|
||||||
|
|
||||||
|
else:
|
||||||
|
# one_to_one:一一对应
|
||||||
|
for i in range(len(input_image_paths)):
|
||||||
|
object_name = f"furniture/sketches/{uuid.uuid4()}.png"
|
||||||
|
image_url = await generate_or_edit_image(
|
||||||
|
input_path=[input_image_paths[i]],
|
||||||
|
prompt=prompts[i],
|
||||||
|
bucket_name=bucket_name,
|
||||||
|
object_name=f"{object_name}-{i}.png"
|
||||||
|
)
|
||||||
|
result.append(image_url)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"edit_furniture 执行异常: {e}", exc_info=True)
|
||||||
|
return f"工具执行失败:{str(e)},请检查参数后重试。"
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def edit_quote_upload_furniture(image_paths: list[str] = None, mode: str = "auto", prompts: list[str] = None, ):
|
||||||
|
"""
|
||||||
|
使用先进的图像编辑模型对家具图片进行精准批量修改。
|
||||||
|
|
||||||
|
支持四种模式:
|
||||||
|
- one_to_one(最常用):多张图片 + 多个提示词,一一对应编辑
|
||||||
|
- one_to_many:多张图片 + 1个提示词(所有图片统一修改)
|
||||||
|
- many_to_one:1张图片 + 多个提示词(同一张图生成多个不同变体,例如不同颜色)
|
||||||
|
- many_to_many(新增):多张图片 + 多个提示词,一一对应(多对多交叉编辑)
|
||||||
|
|
||||||
|
参数说明:
|
||||||
|
- image_paths (list[str]): MinIO 图片路径列表,长度建议 1~4
|
||||||
|
- prompts (list[str]): 详细英文提示词列表
|
||||||
|
- mode (str): "one_to_one", "one_to_many", "many_to_one", "many_to_many", "auto"(默认自动判断)
|
||||||
|
|
||||||
|
使用要求:
|
||||||
|
- image_paths 长度必须在 1~4 之间
|
||||||
|
- mode="auto" 时会根据长度智能判断
|
||||||
|
- many_to_many 模式下:image_paths 和 prompts 的长度必须完全相同
|
||||||
|
|
||||||
|
示例:
|
||||||
|
|
||||||
|
示例1:many_to_many(多对多,一一对应)
|
||||||
|
image_paths = ["sofa1.png", "chair1.png", "table1.png"]
|
||||||
|
prompts = [
|
||||||
|
"Change to bright yellow modern style.",
|
||||||
|
"Change to deep green luxury style.",
|
||||||
|
"Change to soft beige Scandinavian style."
|
||||||
|
]
|
||||||
|
mode = "many_to_many"
|
||||||
|
|
||||||
|
示例2:many_to_one(同一张图多个颜色版本)
|
||||||
|
image_paths = ["sofa_original.png"]
|
||||||
|
prompts = ["yellow version", "green version", "blue version", "black version"]
|
||||||
|
mode = "many_to_one"
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# ====================== 参数校验(直接返回错误信息) ======================
|
||||||
|
if not image_paths or len(image_paths) < 1 or len(image_paths) > 4:
|
||||||
|
return f"参数错误:image_paths 必须提供,且长度需要在 1 到 4 张之间。目前收到 {len(image_paths) if image_paths else 0} 张。"
|
||||||
|
|
||||||
|
if not prompts:
|
||||||
|
return "参数错误:prompts 不能为空,请至少提供一个修改提示词。"
|
||||||
|
|
||||||
|
if mode not in ["one_to_one", "one_to_many", "many_to_one", "many_to_many", "auto"]:
|
||||||
|
return f"参数错误:mode 参数无效。可用值:one_to_one, one_to_many, many_to_one, many_to_many, auto。当前收到:{mode}"
|
||||||
|
|
||||||
|
# Auto 模式智能判断
|
||||||
|
if mode == "auto":
|
||||||
|
if len(image_paths) == 1 and len(prompts) > 1:
|
||||||
|
mode = "many_to_one"
|
||||||
|
elif len(prompts) == 1:
|
||||||
|
mode = "one_to_many"
|
||||||
|
elif len(image_paths) == len(prompts):
|
||||||
|
mode = "many_to_many" # 新增:数量相等时优先 many_to_many
|
||||||
|
else:
|
||||||
|
mode = "one_to_one"
|
||||||
|
|
||||||
|
# 各模式严格校验
|
||||||
|
if mode == "many_to_one":
|
||||||
|
if len(image_paths) != 1:
|
||||||
|
return f"参数错误:many_to_one 模式只能传入 1 张图片,当前传入了 {len(image_paths)} 张。"
|
||||||
|
if len(prompts) < 1:
|
||||||
|
return "参数错误:many_to_one 模式下 prompts 至少需要 1 个。"
|
||||||
|
|
||||||
|
elif mode == "one_to_many":
|
||||||
|
if len(prompts) != 1:
|
||||||
|
return f"参数错误:one_to_many 模式下 prompts 必须只有 1 个,当前有 {len(prompts)} 个。"
|
||||||
|
|
||||||
|
elif mode in ["one_to_one", "many_to_many"]:
|
||||||
|
if len(prompts) != len(image_paths):
|
||||||
|
return (f"参数错误:{mode} 模式下 image_paths 和 prompts 数量必须完全一致。\n"
|
||||||
|
f"当前 image_paths 有 {len(image_paths)} 张,prompts 有 {len(prompts)} 个。")
|
||||||
|
|
||||||
|
# ====================== 执行编辑 ======================
|
||||||
|
result = []
|
||||||
|
bucket_name = "fida-public-bucket"
|
||||||
|
|
||||||
|
if mode == "many_to_one":
|
||||||
|
# 同一张图片 + 多个 prompt
|
||||||
|
base_image = image_paths[0]
|
||||||
|
for i, prompt in enumerate(prompts):
|
||||||
|
object_name = f"furniture/sketches/{uuid.uuid4()}.png"
|
||||||
|
image_url = await generate_or_edit_image(
|
||||||
|
input_path=[base_image],
|
||||||
|
prompt=prompt,
|
||||||
|
bucket_name=bucket_name,
|
||||||
|
object_name=f"{object_name}-var{i}.png"
|
||||||
|
)
|
||||||
|
result.append(image_url)
|
||||||
|
|
||||||
|
else:
|
||||||
|
# one_to_one、many_to_many、one_to_many 统一处理
|
||||||
|
for i in range(len(image_paths)):
|
||||||
|
# 根据模式决定当前使用的 prompt
|
||||||
|
if mode == "one_to_many":
|
||||||
|
current_prompt = prompts[0]
|
||||||
|
else:
|
||||||
|
current_prompt = prompts[i] # one_to_one 和 many_to_many 都用对应位置的 prompt
|
||||||
|
|
||||||
|
object_name = f"furniture/sketches/{uuid.uuid4()}.png"
|
||||||
|
image_url = await generate_or_edit_image(
|
||||||
|
input_path=[image_paths[i]],
|
||||||
|
prompt=current_prompt,
|
||||||
|
bucket_name=bucket_name,
|
||||||
|
object_name=f"{object_name}-{i}.png"
|
||||||
|
)
|
||||||
|
result.append(image_url)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"edit_quote_upload_furniture 执行异常: {e}", exc_info=True)
|
||||||
|
return f"工具执行失败:{str(e)},请检查参数后重试。"
|
||||||
|
|
||||||
|
|
||||||
|
async def generate_or_edit_image(input_path=None, bucket_name="fida-public-bucket",
|
||||||
|
object_name=f"furniture/sketches/{uuid.uuid4()}.png",
|
||||||
|
prompt="Generate a modern minimalist dining chair made of light "
|
||||||
|
"oak wood and white leather, with slim metal legs, photographed "
|
||||||
|
"in a bright Scandinavian living room with natural sunlight, high detail, "
|
||||||
|
"8k resolution."):
|
||||||
|
if input_path is None:
|
||||||
|
input_path = []
|
||||||
|
request_data = {
|
||||||
|
"input_image_paths": input_path,
|
||||||
|
"prompt": prompt,
|
||||||
|
"bucket_name": bucket_name,
|
||||||
|
"object_name": object_name,
|
||||||
|
"width": 1024,
|
||||||
|
"height": 1024
|
||||||
|
}
|
||||||
|
async with httpx.AsyncClient(timeout=120) as client:
|
||||||
|
resp = await client.post(
|
||||||
|
f"http://{settings.FLUX2_GEN_IMG_MODEL_URL}/predict",
|
||||||
|
json=request_data,
|
||||||
|
)
|
||||||
|
result = resp.json()
|
||||||
|
image_url = result.get("output_path", None)
|
||||||
|
return image_url
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
prompt = "A classic professional hand-drawn furniture concept sketch by an experienced senior furniture designer, strictly monochrome black and white. Centered is a modern minimalist three-seater sofa with slim solid oak legs and generously proportioned soft fabric cushions. Precise pencil linework with masterful varying line weights, elegant cross-hatching and fine marker shading to define volume, depth and comfortable silhouette. Light visible construction lines. Drawn on clean A3 white sketching paper with natural subtle paper grain and slight scan texture. Soft diffused studio light from the top left creates gentle grayscale shadows that emphasize the sofa's elegant proportions and relaxed form. Highly refined, technical yet artistic traditional furniture design sketch, clean and sophisticated."
|
||||||
|
|
||||||
|
url = asyncio.run(generate_or_edit_image(prompt=prompt))
|
||||||
|
img = oss_get_image(oss_client=minio_client, bucket=url.split('/')[0], object_name=url[url.find('/') + 1:])
|
||||||
|
img.show()
|
||||||
112
src/server/deep_agent/tools/prompt_generation_tool.py
Normal file
112
src/server/deep_agent/tools/prompt_generation_tool.py
Normal file
@@ -0,0 +1,112 @@
|
|||||||
|
import logging
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from langchain_core.prompts import ChatPromptTemplate
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
from langchain_qwq import ChatQwen
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from src.core.config import settings
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# 输入
|
||||||
|
class FurnitureSketchPromptInput(BaseModel):
|
||||||
|
user_description: str = Field(..., description="用户对家具的描述,例如:'一张工业风的皮革沙发,带有金属X型腿和宽扶手'")
|
||||||
|
num_variants: int = Field(default=12, description="要生成的变种数量,默认12个")
|
||||||
|
|
||||||
|
|
||||||
|
# 输出
|
||||||
|
class FurniturePromptsOutput(BaseModel):
|
||||||
|
prompts: List[str] = Field(..., description="12个(或指定数量)明显不同的黑白家具线稿 image prompt 列表")
|
||||||
|
|
||||||
|
|
||||||
|
PROMPT_GEN_SYSTEM = """你是一位顶级家具设计 prompt 工程师,专门为黑白线稿(furniture sketch)生成高质量 image prompt。
|
||||||
|
|
||||||
|
核心强制规则(所有12个 prompt 都必须严格遵守):
|
||||||
|
- 必须是 clean black and white line drawing only
|
||||||
|
- pure white background
|
||||||
|
- 焦点只在家具本身的物理形态:silhouette, proportions, structure, legs, base, frame, joints, armrests, backrest, seat shape 和所有设计细节
|
||||||
|
- 使用 refined linework with subtle line weight variation
|
||||||
|
- 只允许 minimal soft shading for depth only
|
||||||
|
- 严格禁止:no color, no fill, no heavy shadows, no hatching, no texture rendering, no environment, no background elements, no lighting effects, no atmosphere
|
||||||
|
- 视角默认 3/4 front view, eye-level perspective
|
||||||
|
- 构图:centered composition
|
||||||
|
- 整体风格:architectural line art style + modern industrial design sketch style
|
||||||
|
|
||||||
|
任务:
|
||||||
|
根据用户提供的家具描述,生成 **12 个明显不同的** image prompt 变种。
|
||||||
|
每个变种必须在以下家具设计细节维度上有清晰、可感知的差异:
|
||||||
|
- 线条特性(极细精确、粗细强烈对比、手绘流动、动态表现力等)
|
||||||
|
- 结构侧重(整体比例、机械连接、关节细节、腿部与底部、金属框架、曲线轮廓等)
|
||||||
|
- 阴影与深度处理(几乎无阴影、轻微结构暗示、适度体积感等)
|
||||||
|
- 艺术调性(极简技术制图、粗犷工业、手绘艺术、精确建筑、高细节精致、柔和形态等)
|
||||||
|
- 微调视角或构图(标准3/4、略低角度强调腿部、略强调对称等)
|
||||||
|
|
||||||
|
确保12个 prompt 各有特色,不要相似。用户描述要自然放在 prompt 开头,然后自然衔接风格描述。
|
||||||
|
整个 prompt 要简洁有力、适合 Flux2 klein模型直接使用。
|
||||||
|
|
||||||
|
输出要求:
|
||||||
|
必须以 JSON 格式返回,严格遵循以下结构,不要添加任何解释、markdown 或额外文字。
|
||||||
|
|
||||||
|
{{
|
||||||
|
"prompts": [
|
||||||
|
"第一个完整 prompt",
|
||||||
|
"第二个完整 prompt",
|
||||||
|
...
|
||||||
|
]
|
||||||
|
}}
|
||||||
|
"""
|
||||||
|
|
||||||
|
prompt_gen_llm = ChatQwen(
|
||||||
|
model="qwen-plus",
|
||||||
|
max_tokens=3_000,
|
||||||
|
timeout=None,
|
||||||
|
max_retries=2,
|
||||||
|
streaming=False,
|
||||||
|
temperature=0.25,
|
||||||
|
top_p=0.8,
|
||||||
|
api_key=settings.QWEN_API_KEY
|
||||||
|
)
|
||||||
|
|
||||||
|
prompt_template = ChatPromptTemplate.from_messages([
|
||||||
|
("system", PROMPT_GEN_SYSTEM),
|
||||||
|
("human", "家具描述: {user_description}\n生成数量: {num_variants}\n请严格以 JSON 格式输出。")
|
||||||
|
])
|
||||||
|
|
||||||
|
prompt_chain = prompt_template | prompt_gen_llm
|
||||||
|
|
||||||
|
|
||||||
|
# 你可以把之前我给你的 3 个经典 Prompt 作为 reference_examples 放进去(few-shot)
|
||||||
|
|
||||||
|
|
||||||
|
@tool(args_schema=FurnitureSketchPromptInput)
|
||||||
|
def generate_furniture_sketch_prompts(user_description: str, num_variants: int = 12) -> List[str]:
|
||||||
|
"""
|
||||||
|
生成12个明显不同的家具黑白线稿 prompt。
|
||||||
|
成功时返回 List[str](长度为 num_variants)。
|
||||||
|
失败时返回 [user_description],保证至少有一个可用 prompt。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 使用结构化输出
|
||||||
|
structured_llm = prompt_gen_llm.with_structured_output(schema=FurniturePromptsOutput, method="json_mode")
|
||||||
|
|
||||||
|
chain = prompt_template | structured_llm
|
||||||
|
|
||||||
|
result: FurniturePromptsOutput = chain.invoke({
|
||||||
|
"user_description": user_description,
|
||||||
|
"num_variants": num_variants
|
||||||
|
})
|
||||||
|
|
||||||
|
if isinstance(result.prompts, list) and len(result.prompts) > 0:
|
||||||
|
return result.prompts[:num_variants] # 防止意外多返回
|
||||||
|
|
||||||
|
# 如果返回的 list 为空,进入兜底
|
||||||
|
logger.warning("Structured output returned empty list, falling back to user description.")
|
||||||
|
return [user_description]
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
# 失败兜底:返回用户的原始描述
|
||||||
|
logger.error(f"Failed to generate structured furniture prompts: {e}. Falling back to user description.")
|
||||||
|
return [user_description]
|
||||||
137
src/server/deep_agent/tools/report_generator_tool.py
Executable file
137
src/server/deep_agent/tools/report_generator_tool.py
Executable file
@@ -0,0 +1,137 @@
|
|||||||
|
import os
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
from typing import Optional, List, Dict
|
||||||
|
from langgraph.config import get_stream_writer
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
from langchain_core.messages import SystemMessage, HumanMessage
|
||||||
|
from src.server.deep_agent.init_llm import repoer_llm
|
||||||
|
|
||||||
|
|
||||||
|
# =========================
|
||||||
|
# Tool 输入 Schema
|
||||||
|
# =========================
|
||||||
|
|
||||||
|
class ReportInput(BaseModel):
|
||||||
|
report_topic: str = Field(
|
||||||
|
...,
|
||||||
|
description="Main topic of the report, e.g. '2026 Sofa Design Trends'"
|
||||||
|
)
|
||||||
|
structured_data: List[Dict] = Field(
|
||||||
|
...,
|
||||||
|
description="Structured retrieval result items"
|
||||||
|
)
|
||||||
|
language: Optional[str] = Field(
|
||||||
|
default="English",
|
||||||
|
description="Output language"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# =========================
|
||||||
|
# LangGraph Tool
|
||||||
|
# =========================
|
||||||
|
|
||||||
|
def create_report_generator_tool(workspace_dir):
|
||||||
|
@tool("report_generator", args_schema=ReportInput)
|
||||||
|
async def report_generator(report_topic: str, structured_data: List[Dict], language: str = "English") -> str:
|
||||||
|
"""
|
||||||
|
Generate a professional design/market report
|
||||||
|
directly from structured retrieval results.
|
||||||
|
"""
|
||||||
|
|
||||||
|
writer = get_stream_writer()
|
||||||
|
if not structured_data:
|
||||||
|
error_msg = "Error: No structured data provided."
|
||||||
|
writer({"type": "report_error", "message": error_msg})
|
||||||
|
return error_msg
|
||||||
|
|
||||||
|
collected_data_str = json.dumps(
|
||||||
|
structured_data,
|
||||||
|
ensure_ascii=False,
|
||||||
|
indent=2
|
||||||
|
)
|
||||||
|
|
||||||
|
# =========================
|
||||||
|
# Prompt
|
||||||
|
# =========================
|
||||||
|
|
||||||
|
system_prompt = f"""
|
||||||
|
You are a professional design trend analyst.
|
||||||
|
|
||||||
|
Generate a long, structured Markdown report.
|
||||||
|
|
||||||
|
REQUIREMENTS:
|
||||||
|
|
||||||
|
1. Follow MECE principle.
|
||||||
|
2. Embed images ONLY if they start with https://
|
||||||
|
using: 
|
||||||
|
3. Insert images inline.
|
||||||
|
4. Every key insight must cite source:
|
||||||
|
[Website Name](url)
|
||||||
|
5. Use Markdown headings.
|
||||||
|
6. Start directly with title.
|
||||||
|
7. Be detailed and analytical.
|
||||||
|
|
||||||
|
Output Language: {language}
|
||||||
|
"""
|
||||||
|
|
||||||
|
user_prompt = f"""
|
||||||
|
Topic: {report_topic}
|
||||||
|
|
||||||
|
Input Data:
|
||||||
|
{collected_data_str}
|
||||||
|
"""
|
||||||
|
|
||||||
|
# =========================
|
||||||
|
# 调用 LLM
|
||||||
|
# =========================
|
||||||
|
writer({"type": "report_start", "topic": report_topic, "language": language})
|
||||||
|
|
||||||
|
full_report = ""
|
||||||
|
try:
|
||||||
|
report_llm = repoer_llm.with_config(
|
||||||
|
callbacks=[]
|
||||||
|
)
|
||||||
|
async for chunk in report_llm.astream(
|
||||||
|
[
|
||||||
|
SystemMessage(content=system_prompt),
|
||||||
|
HumanMessage(content=user_prompt)
|
||||||
|
]
|
||||||
|
):
|
||||||
|
if chunk.content: # Gemini 返回的 chunk.content
|
||||||
|
delta = chunk.content
|
||||||
|
full_report += delta
|
||||||
|
# return {"type": "report_delta", "delta": delta}
|
||||||
|
writer({"type": "report_delta", "delta": delta}) # ← 实时推送给前端
|
||||||
|
writer({"type": "report_stop", "topic": report_topic, "language": language})
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = f"LLM generation failed: {str(e)}"
|
||||||
|
writer({"type": "report_error", "message": error_msg})
|
||||||
|
return error_msg
|
||||||
|
|
||||||
|
report_content = full_report.strip()
|
||||||
|
|
||||||
|
# =========================
|
||||||
|
# 保存报告
|
||||||
|
# =========================
|
||||||
|
output_dir = os.path.join(workspace_dir, "reports")
|
||||||
|
if not os.path.exists(output_dir):
|
||||||
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
|
||||||
|
safe_topic = re.sub(r'[\\/*?:"<>|]', "", report_topic.replace(" ", "_"))
|
||||||
|
writer({"type": "report_name", "delta": f"{safe_topic}.md"})
|
||||||
|
|
||||||
|
filename = f"{output_dir}/{safe_topic}.md"
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(filename, "w", encoding="utf-8") as f:
|
||||||
|
f.write(report_content)
|
||||||
|
writer({"type": "report_complete", "file_path": filename})
|
||||||
|
except Exception as e:
|
||||||
|
writer({"type": "report_save_warning", "message": str(e)})
|
||||||
|
|
||||||
|
# 返回完整内容(作为 tool result),同时正文已通过 delta 流式输出
|
||||||
|
return report_content + f"\n\n✅ Report saved to: {filename}"
|
||||||
|
|
||||||
|
return report_generator
|
||||||
158
src/server/deep_agent/tools/research_tool.py
Executable file
158
src/server/deep_agent/tools/research_tool.py
Executable file
@@ -0,0 +1,158 @@
|
|||||||
|
# import asyncio
|
||||||
|
# import json
|
||||||
|
# from datetime import datetime
|
||||||
|
# from typing import List, Set, Optional
|
||||||
|
# from langchain_core.tools import tool
|
||||||
|
# from tavily import TavilyClient
|
||||||
|
#
|
||||||
|
# from src.core.config import settings
|
||||||
|
#
|
||||||
|
# # 模拟配置加载
|
||||||
|
# TAVILY_API_KEY = settings.TAVILY_API_KEY
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# @tool
|
||||||
|
# async def topic_research(topic: list[str], max_urls: int = 5) -> str:
|
||||||
|
# """
|
||||||
|
# 深度调研工具。该工具会利用 Tavily 搜索引擎针对特定主题进行多维度搜索。
|
||||||
|
# 它会自动生成针对性的搜索词(包含年份和趋势),并返回去重后的高质量 URL 列表。
|
||||||
|
# """
|
||||||
|
# if not TAVILY_API_KEY:
|
||||||
|
# return "❌ 错误: 未配置 TAVILY_API_KEY。"
|
||||||
|
#
|
||||||
|
# client = TavilyClient(api_key=TAVILY_API_KEY)
|
||||||
|
#
|
||||||
|
# # 1. 自动生成多维度搜索词 (在工具内部快速生成)
|
||||||
|
#
|
||||||
|
# # 2. 并行执行搜索
|
||||||
|
# async def perform_search(q: str):
|
||||||
|
# # 使用 asyncio.to_thread 运行同步的 Tavily SDK
|
||||||
|
# def sync_search():
|
||||||
|
# try:
|
||||||
|
# response = client.search(
|
||||||
|
# query=q,
|
||||||
|
# search_depth="advanced",
|
||||||
|
# max_results=5,
|
||||||
|
# include_answer=False
|
||||||
|
# )
|
||||||
|
# return response.get('results', [])
|
||||||
|
# except Exception as e:
|
||||||
|
# print(f"Search error: {e}")
|
||||||
|
# return []
|
||||||
|
#
|
||||||
|
# return await asyncio.to_thread(sync_search)
|
||||||
|
#
|
||||||
|
# search_tasks = [perform_search(q) for q in topic]
|
||||||
|
# search_results_list = await asyncio.gather(*search_tasks)
|
||||||
|
#
|
||||||
|
# # 3. 结果去重与过滤
|
||||||
|
# seen_urls: Set[str] = set()
|
||||||
|
# final_urls = []
|
||||||
|
#
|
||||||
|
# # 常见的非内容页面过滤
|
||||||
|
# skip_extensions = ('.pdf', '.jpg', '.png', '.zip', '.exe')
|
||||||
|
#
|
||||||
|
# for results in search_results_list:
|
||||||
|
# for item in results:
|
||||||
|
# url = item.get('url')
|
||||||
|
# if url and url not in seen_urls:
|
||||||
|
# if not url.lower().endswith(skip_extensions):
|
||||||
|
# seen_urls.add(url)
|
||||||
|
# final_urls.append(url)
|
||||||
|
#
|
||||||
|
# # 4. 结果截断
|
||||||
|
# selected_urls = final_urls[:max_urls]
|
||||||
|
#
|
||||||
|
# # 返回 JSON 字符串,便于 Agent 下一步调用批量爬虫 (Crawl4ai)
|
||||||
|
# return json.dumps(selected_urls, ensure_ascii=False)
|
||||||
|
#
|
||||||
|
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
from typing import List, Set
|
||||||
|
|
||||||
|
from ddgs import DDGS
|
||||||
|
from langchain.tools import tool
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class TopicResearchInput(BaseModel):
|
||||||
|
"""Input for topic research tool."""
|
||||||
|
topic: List[str] = Field(description="List of separate research keyword strings. Example: ['modern sofa design', 'sustainable wood furniture']")
|
||||||
|
max_urls: int = Field(default=5, description="Maximum number of URLs to return")
|
||||||
|
|
||||||
|
|
||||||
|
@tool(args_schema=TopicResearchInput)
|
||||||
|
async def topic_research(topic: List[str], max_urls: int = 5) -> str:
|
||||||
|
"""
|
||||||
|
In-depth research tool (DuckDuckGo version).
|
||||||
|
Search based on multiple topic keywords and return a high-quality URL list (JSON string) after deduplication.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# DuckDuckGo 是同步库,需要丢到线程池
|
||||||
|
def sync_search(query: str):
|
||||||
|
try:
|
||||||
|
with DDGS() as ddgs:
|
||||||
|
results = ddgs.text(
|
||||||
|
query,
|
||||||
|
max_results=8 # 稍微多一点,后面会过滤
|
||||||
|
)
|
||||||
|
return [r.get("href") for r in results if r.get("href")]
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Search error: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def perform_search(q: str):
|
||||||
|
return await asyncio.to_thread(sync_search, q)
|
||||||
|
|
||||||
|
# 并发执行多个 query
|
||||||
|
search_tasks = [perform_search(q) for q in topic]
|
||||||
|
search_results_list = await asyncio.gather(*search_tasks)
|
||||||
|
|
||||||
|
# ========================
|
||||||
|
# 去重 + 过滤
|
||||||
|
# ========================
|
||||||
|
seen_urls: Set[str] = set()
|
||||||
|
final_urls = []
|
||||||
|
|
||||||
|
skip_extensions = ('.pdf', '.jpg', '.png', '.zip', '.exe')
|
||||||
|
|
||||||
|
for results in search_results_list:
|
||||||
|
for url in results:
|
||||||
|
if not url:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if url not in seen_urls and not url.lower().endswith(skip_extensions):
|
||||||
|
seen_urls.add(url)
|
||||||
|
final_urls.append(url)
|
||||||
|
|
||||||
|
# ========================
|
||||||
|
# 截断结果
|
||||||
|
# ========================
|
||||||
|
selected_urls = final_urls[:max_urls]
|
||||||
|
print(f" topic research !!!!!!!!!!!!!!!!!!!!! {selected_urls}")
|
||||||
|
return json.dumps(selected_urls, ensure_ascii=False)
|
||||||
|
|
||||||
|
# import asyncio
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# # 假设你已经定义了 topic_research
|
||||||
|
#
|
||||||
|
# async def test():
|
||||||
|
# topics = [
|
||||||
|
# "modern furniture design trends 2025",
|
||||||
|
# "scandinavian furniture materials",
|
||||||
|
# "minimalist living room furniture ideas"
|
||||||
|
# ]
|
||||||
|
#
|
||||||
|
# result = await topic_research.ainvoke({
|
||||||
|
# "topic": topics,
|
||||||
|
# "max_urls": 5
|
||||||
|
# })
|
||||||
|
#
|
||||||
|
# print("结果👇")
|
||||||
|
# print(result)
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# asyncio.run(test())
|
||||||
234
src/server/deep_agent/tools/structured_retrieval_tool.py
Executable file
234
src/server/deep_agent/tools/structured_retrieval_tool.py
Executable file
@@ -0,0 +1,234 @@
|
|||||||
|
import os
|
||||||
|
import re
|
||||||
|
import json
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import List, Dict, Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
from langchain_core.documents import Document
|
||||||
|
|
||||||
|
# RAG
|
||||||
|
from langchain_community.vectorstores import FAISS
|
||||||
|
from langchain_huggingface import HuggingFaceEmbeddings
|
||||||
|
from sentence_transformers import CrossEncoder
|
||||||
|
|
||||||
|
# =========================
|
||||||
|
# 全局模型(单例)
|
||||||
|
# =========================
|
||||||
|
|
||||||
|
_EMBEDDING_MODEL = HuggingFaceEmbeddings(
|
||||||
|
model_name="sentence-transformers/all-MiniLM-L6-v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
_RERANK_MODEL = CrossEncoder(
|
||||||
|
"cross-encoder/ms-marco-MiniLM-L-6-v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class StructuredRetrievalInput(BaseModel):
|
||||||
|
file_paths: List[str] = Field(..., description="List of local markdown file paths.")
|
||||||
|
query: str = Field(..., description="Extraction query")
|
||||||
|
source_url: Optional[str] = Field(None, description="Optional global source URL")
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_source_from_md(content: str) -> Optional[str]:
|
||||||
|
match = re.search(r"<!--\s*Source:\s*(.*?)\s*-->", content)
|
||||||
|
return match.group(1).strip() if match else None
|
||||||
|
|
||||||
|
|
||||||
|
# =========================
|
||||||
|
# Markdown Header Split
|
||||||
|
# =========================
|
||||||
|
|
||||||
|
def _split_markdown_by_headers(
|
||||||
|
content: str,
|
||||||
|
max_chars: int = 2000,
|
||||||
|
overlap: int = 150,
|
||||||
|
):
|
||||||
|
header_re = re.compile(
|
||||||
|
r'^(#{1,6})\s+(.+?)\s*$',
|
||||||
|
re.MULTILINE
|
||||||
|
)
|
||||||
|
|
||||||
|
matches = list(header_re.finditer(content))
|
||||||
|
|
||||||
|
if not matches:
|
||||||
|
return _chunk_text(content, max_chars, overlap)
|
||||||
|
|
||||||
|
sections = []
|
||||||
|
|
||||||
|
for i, m in enumerate(matches):
|
||||||
|
start = m.start()
|
||||||
|
end = (
|
||||||
|
matches[i + 1].start()
|
||||||
|
if i + 1 < len(matches)
|
||||||
|
else len(content)
|
||||||
|
)
|
||||||
|
|
||||||
|
block = content[start:end].strip()
|
||||||
|
if block:
|
||||||
|
sections.append(block)
|
||||||
|
|
||||||
|
final_sections = []
|
||||||
|
|
||||||
|
for s in sections:
|
||||||
|
if len(s) > max_chars:
|
||||||
|
final_sections.extend(
|
||||||
|
_chunk_text(s, max_chars, overlap)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
final_sections.append(s)
|
||||||
|
|
||||||
|
return final_sections
|
||||||
|
|
||||||
|
|
||||||
|
def _chunk_text(
|
||||||
|
text: str,
|
||||||
|
max_chars: int = 2000,
|
||||||
|
overlap: int = 150
|
||||||
|
):
|
||||||
|
text = text.strip()
|
||||||
|
if len(text) <= max_chars:
|
||||||
|
return [text]
|
||||||
|
|
||||||
|
chunks = []
|
||||||
|
start = 0
|
||||||
|
|
||||||
|
while start < len(text):
|
||||||
|
end = min(len(text), start + max_chars)
|
||||||
|
chunk = text[start:end].strip()
|
||||||
|
|
||||||
|
if chunk:
|
||||||
|
chunks.append(chunk)
|
||||||
|
|
||||||
|
if end == len(text):
|
||||||
|
break
|
||||||
|
|
||||||
|
start = max(0, end - overlap)
|
||||||
|
|
||||||
|
return chunks
|
||||||
|
|
||||||
|
|
||||||
|
def create_structured_retrieval_tool(workspace_dir):
|
||||||
|
@tool("structured_retrieval", args_schema=StructuredRetrievalInput)
|
||||||
|
def structured_retrieval(
|
||||||
|
file_paths: List[str],
|
||||||
|
query: str,
|
||||||
|
source_url: Optional[str] = None
|
||||||
|
) -> Dict:
|
||||||
|
"""
|
||||||
|
Batch structured extraction from markdown files.
|
||||||
|
- Performs vector search + re-ranking
|
||||||
|
- Saves extracted structured data as JSON file to disk
|
||||||
|
- Returns ONLY summary (status, count, file path)
|
||||||
|
"""
|
||||||
|
|
||||||
|
# ── 1. 收集所有文件內容 ──────────────────────────────────────
|
||||||
|
all_docs_pool: List[Document] = []
|
||||||
|
|
||||||
|
for path in file_paths:
|
||||||
|
if not os.path.exists(path) or not path.endswith((".md", ".markdown")):
|
||||||
|
continue
|
||||||
|
|
||||||
|
file_name = os.path.basename(path)
|
||||||
|
|
||||||
|
with open(path, "r", encoding="utf-8") as f:
|
||||||
|
content = f.read()
|
||||||
|
|
||||||
|
current_source = source_url or _extract_source_from_md(content) or "unknown"
|
||||||
|
|
||||||
|
sections = _split_markdown_by_headers(content)
|
||||||
|
|
||||||
|
for sec in sections:
|
||||||
|
all_docs_pool.append(
|
||||||
|
Document(
|
||||||
|
page_content=sec,
|
||||||
|
metadata={"source_url": current_source, "file_name": file_name}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if not all_docs_pool:
|
||||||
|
return {"status": "no_documents_found", "items_count": 0, "json_path": None}
|
||||||
|
|
||||||
|
# ── 2. Vector search ────────────────────────────────────────────
|
||||||
|
vector_store = FAISS.from_documents(all_docs_pool, _EMBEDDING_MODEL)
|
||||||
|
retrieved = vector_store.similarity_search(query, k=200)
|
||||||
|
|
||||||
|
# ── 3. 提取結構化片段 ──────────────────────────────────────────
|
||||||
|
structured_items = []
|
||||||
|
|
||||||
|
for doc in retrieved:
|
||||||
|
text = doc.page_content.strip()
|
||||||
|
if len(text) < 30:
|
||||||
|
continue
|
||||||
|
|
||||||
|
images = list(set(re.findall(r"!\[.*?\]\((.*?)\)", text)))
|
||||||
|
|
||||||
|
structured_items.append(
|
||||||
|
{
|
||||||
|
"text": text,
|
||||||
|
"images": images,
|
||||||
|
"source_url": doc.metadata.get("source_url"),
|
||||||
|
"file_name": doc.metadata.get("file_name")
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── 4. Re-rank ──────────────────────────────────────────────────
|
||||||
|
if structured_items:
|
||||||
|
unique_items = {item["text"]: item for item in structured_items}.values()
|
||||||
|
pairs = [[query, item["text"]] for item in unique_items]
|
||||||
|
scores = _RERANK_MODEL.predict(pairs)
|
||||||
|
|
||||||
|
sorted_items = sorted(
|
||||||
|
zip(scores, unique_items),
|
||||||
|
key=lambda x: x[0],
|
||||||
|
reverse=True
|
||||||
|
)
|
||||||
|
top_items = [item for _, item in sorted_items[:50]]
|
||||||
|
else:
|
||||||
|
top_items = []
|
||||||
|
|
||||||
|
# ── 5. 寫入 JSON 文件 ──────────────────────────────────────────
|
||||||
|
if not top_items:
|
||||||
|
return {"status": "no_relevant_content", "items_count": 0, "json_path": None}
|
||||||
|
|
||||||
|
# 產生有意義的檔名
|
||||||
|
safe_query = re.sub(r'[^a-zA-Z0-9\u4e00-\u9fa5]', '_', query)[:40]
|
||||||
|
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||||
|
json_filename = f"extracted_{safe_query}_{timestamp}.json"
|
||||||
|
|
||||||
|
# 建議的儲存目錄(與 crawl4ai_batch 對齊)
|
||||||
|
output_dir = os.path.join(workspace_dir, "extracted")
|
||||||
|
if not os.path.exists(output_dir):
|
||||||
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
|
||||||
|
if not os.path.exists(output_dir):
|
||||||
|
# 2. 不存在则创建(makedirs 支持创建多级目录,mkdir 只能创建单级)
|
||||||
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
|
||||||
|
json_path = os.path.join(output_dir, json_filename)
|
||||||
|
|
||||||
|
with open(json_path, "w", encoding="utf-8") as f:
|
||||||
|
json.dump(
|
||||||
|
{
|
||||||
|
"query": query,
|
||||||
|
"extracted_at": timestamp,
|
||||||
|
"item_count": len(top_items),
|
||||||
|
"items": top_items
|
||||||
|
},
|
||||||
|
f,
|
||||||
|
ensure_ascii=False,
|
||||||
|
indent=2
|
||||||
|
)
|
||||||
|
|
||||||
|
json_path = json_path.replace(workspace_dir, "")
|
||||||
|
# ── 6. 只回傳摘要 ──────────────────────────────────────────────
|
||||||
|
return {
|
||||||
|
"status": "success",
|
||||||
|
"items_count": len(top_items),
|
||||||
|
"json_path": json_path,
|
||||||
|
"summary": f"{len(top_items)} highly relevant fragments have been extracted and stored in {json_path}"
|
||||||
|
}
|
||||||
|
|
||||||
|
return structured_retrieval
|
||||||
57
src/server/deep_agent/tools/user_persona_tool.py
Executable file
57
src/server/deep_agent/tools/user_persona_tool.py
Executable file
@@ -0,0 +1,57 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from langchain_core.runnables import RunnableConfig
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
from pymongo import MongoClient
|
||||||
|
from src.core.config import MONGO_URI
|
||||||
|
|
||||||
|
client = MongoClient(MONGO_URI)
|
||||||
|
db = client["report_agent"]
|
||||||
|
collection = db["user_profiles"]
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def query_report_profile(config: RunnableConfig, ) -> dict:
|
||||||
|
"""
|
||||||
|
Query user report portrait
|
||||||
|
"""
|
||||||
|
thread_id = config['configurable']['thread_id']
|
||||||
|
doc = collection.find_one({"thread_id": thread_id})
|
||||||
|
|
||||||
|
if not doc:
|
||||||
|
return {"profile": {}}
|
||||||
|
|
||||||
|
doc.pop("_id", None)
|
||||||
|
return doc
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def update_report_profile(config: RunnableConfig, profile: dict) -> dict:
|
||||||
|
"""
|
||||||
|
Update user portrait information
|
||||||
|
"""
|
||||||
|
thread_id = config['configurable']['thread_id']
|
||||||
|
collection.update_one(
|
||||||
|
{"thread_id": thread_id},
|
||||||
|
{
|
||||||
|
"$set": {
|
||||||
|
"profile": profile
|
||||||
|
}
|
||||||
|
},
|
||||||
|
upsert=True
|
||||||
|
)
|
||||||
|
|
||||||
|
return {"status": "success", "profile": profile}
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def check_profile_complete(profile: dict) -> dict:
|
||||||
|
"""
|
||||||
|
Determine whether the image is complete
|
||||||
|
"""
|
||||||
|
required = ["style", "room_type", "budget"]
|
||||||
|
missing = [f for f in required if f not in profile]
|
||||||
|
return {
|
||||||
|
"complete": len(missing) == 0,
|
||||||
|
"missing_fields": missing
|
||||||
|
}
|
||||||
21
src/server/deep_agent/tools/vision_analyze_tool.py
Executable file
21
src/server/deep_agent/tools/vision_analyze_tool.py
Executable file
@@ -0,0 +1,21 @@
|
|||||||
|
from langchain.tools import tool
|
||||||
|
from langchain_core.messages import HumanMessage
|
||||||
|
from PIL import Image
|
||||||
|
import requests
|
||||||
|
from io import BytesIO
|
||||||
|
|
||||||
|
from src.server.deep_agent.init_llm import vision_llm
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def analyze_image(image_url: str) -> str:
|
||||||
|
"""分析给定URL的图像。输入图像URL,输出图像描述和关键观察。"""
|
||||||
|
response = requests.get(image_url)
|
||||||
|
image = Image.open(BytesIO(response.content))
|
||||||
|
# 这里使用模型直接分析图像(简化示例)
|
||||||
|
msg = HumanMessage(content=[
|
||||||
|
{"type": "text", "text": "详细描述这张图像,包括物体、颜色、场景和任何文本。"},
|
||||||
|
{"type": "image_url", "image_url": {"url": image_url}}
|
||||||
|
])
|
||||||
|
result = vision_llm.invoke([msg])
|
||||||
|
return result.content
|
||||||
144
src/server/deep_agent/utils/mongodb_util.py
Executable file
144
src/server/deep_agent/utils/mongodb_util.py
Executable file
@@ -0,0 +1,144 @@
|
|||||||
|
from typing import Optional
|
||||||
|
from pymongo import MongoClient
|
||||||
|
from pymongo.collection import Collection
|
||||||
|
from pymongo.errors import PyMongoError
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from src.core.config import MONGO_URI
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ThreadImageMinIOStore:
|
||||||
|
"""
|
||||||
|
根據 thread_id 存取/更新 current_image 的 MinIO 物件路徑(不存 binary)
|
||||||
|
|
||||||
|
儲存格式範例:
|
||||||
|
{
|
||||||
|
"thread_id": "thread_abc123",
|
||||||
|
"current_image_path": "images/2025/03/thread_abc123_latest.png",
|
||||||
|
"updated_at": ISODate,
|
||||||
|
"metadata": {"format": "png", "desc": "生成的貓圖", "size_bytes": 512345}
|
||||||
|
}
|
||||||
|
|
||||||
|
使用方式:
|
||||||
|
store = ThreadImageMinIOStore("mongodb://localhost:27017/", "deepagents_db")
|
||||||
|
store.save_image_path("thread_abc123", "images/cat/001.png", "https://minio.example.com/bucket/images/cat/001.png")
|
||||||
|
path_info = store.get_image_path("thread_abc123")
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
mongo_uri: str,
|
||||||
|
db_name: str = "deepagents_db",
|
||||||
|
collection_name: str = "agent_image_paths",
|
||||||
|
connect_timeout_ms: int = 5000,
|
||||||
|
server_selection_timeout_ms: int = 5000,
|
||||||
|
):
|
||||||
|
self.client = MongoClient(
|
||||||
|
mongo_uri,
|
||||||
|
connectTimeoutMS=connect_timeout_ms,
|
||||||
|
serverSelectionTimeoutMS=server_selection_timeout_ms,
|
||||||
|
retryWrites=True,
|
||||||
|
retryReads=True,
|
||||||
|
)
|
||||||
|
self.db = self.client[db_name]
|
||||||
|
self.collection: Collection = self.db[collection_name]
|
||||||
|
|
||||||
|
# 建立唯一索引
|
||||||
|
self.collection.create_index("thread_id", unique=True)
|
||||||
|
|
||||||
|
def save_image_path(
|
||||||
|
self,
|
||||||
|
thread_id: str,
|
||||||
|
object_path: list, # MinIO 中的相對路徑,例如 "test/123.png" 或 "images/20250320/abc.png"
|
||||||
|
metadata: Optional[dict] = None
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
保存或更新某個 thread 的 current_image MinIO 路徑
|
||||||
|
|
||||||
|
Args:
|
||||||
|
thread_id: 對話執行緒 ID
|
||||||
|
object_path: MinIO bucket 內的物件路徑 (不含 bucket 名稱)
|
||||||
|
metadata: 額外資訊,例如 {"prompt": "...", "format": "png", "width": 1024}
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 是否成功
|
||||||
|
"""
|
||||||
|
document = {
|
||||||
|
"thread_id": thread_id,
|
||||||
|
"current_image_path": object_path,
|
||||||
|
"updated_at": datetime.now(),
|
||||||
|
"metadata": metadata or {}
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = self.collection.update_one(
|
||||||
|
{"thread_id": thread_id},
|
||||||
|
{"$set": document},
|
||||||
|
upsert=True
|
||||||
|
)
|
||||||
|
action = "updated" if result.modified_count > 0 else "inserted"
|
||||||
|
logger.info(f"Image path for thread {thread_id} {action}: {object_path}")
|
||||||
|
return True
|
||||||
|
except PyMongoError as e:
|
||||||
|
logger.error(f"Failed to save image path for thread {thread_id}: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def get_image_path(self, thread_id: str) -> Optional[dict]:
|
||||||
|
"""
|
||||||
|
取得某 thread 的 current_image MinIO 資訊
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
{
|
||||||
|
"current_image_path": str,
|
||||||
|
"updated_at": datetime,
|
||||||
|
"metadata": dict
|
||||||
|
} 或 None
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
doc = self.collection.find_one({"thread_id": thread_id})
|
||||||
|
if not doc:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return {
|
||||||
|
"current_image_path": doc.get("current_image_path"),
|
||||||
|
"updated_at": doc.get("updated_at"),
|
||||||
|
"metadata": doc.get("metadata", {})
|
||||||
|
}
|
||||||
|
except PyMongoError as e:
|
||||||
|
logger.error(f"Failed to get image path for thread {thread_id}: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_object_path_only(self, thread_id: str) -> Optional[str]:
|
||||||
|
"""只取 MinIO 相對路徑,方便直接給 MinIO client 使用"""
|
||||||
|
info = self.get_image_path(thread_id)
|
||||||
|
return info["current_image_path"] if info else None
|
||||||
|
|
||||||
|
def delete_image_path(self, thread_id: str) -> bool:
|
||||||
|
"""刪除某 thread 的記錄(不影響 MinIO 實際檔案)"""
|
||||||
|
try:
|
||||||
|
result = self.collection.delete_one({"thread_id": thread_id})
|
||||||
|
if result.deleted_count > 0:
|
||||||
|
logger.info(f"Image path record for thread {thread_id} deleted")
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
except PyMongoError as e:
|
||||||
|
logger.error(f"Failed to delete image path for thread {thread_id}: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
self.client.close()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
image_store = ThreadImageMinIOStore(MONGO_URI, "agent_tool_generate_db")
|
||||||
|
success = image_store.save_image_path(
|
||||||
|
thread_id="121233",
|
||||||
|
object_path=["test/123.png"],
|
||||||
|
metadata={"prompt": "prompt", "generated_at": str(datetime.now())})
|
||||||
|
print(success)
|
||||||
|
info = image_store.get_image_path("121233")
|
||||||
|
print(info)
|
||||||
0
src/server/utils/__init__.py
Normal file → Executable file
0
src/server/utils/__init__.py
Normal file → Executable file
4
src/server/utils/generate_suggestion.py
Normal file → Executable file
4
src/server/utils/generate_suggestion.py
Normal file → Executable file
@@ -23,10 +23,10 @@ async def generate_chat_suggestions(messages, model) -> list[str]:
|
|||||||
|
|
||||||
【判断逻辑】
|
【判断逻辑】
|
||||||
1. 如果用户已经确定了【类型、材质、风格】但还没有生成过草图 -> 必须推荐 "生成设计草图"。
|
1. 如果用户已经确定了【类型、材质、风格】但还没有生成过草图 -> 必须推荐 "生成设计草图"。
|
||||||
2. 如果刚生成了草图 -> 推荐 "调整材质"、"查看三维视图"、"下载报价单" 等。
|
2. 如果刚生成了草图 -> 推荐 "换个颜色"、"换个类别" 等。
|
||||||
3. 如果用户还在犹豫 -> 推荐具体的风格或材质询问。
|
3. 如果用户还在犹豫 -> 推荐具体的风格或材质询问。
|
||||||
|
|
||||||
请直接输出 JSON 格式,包含 suggestions 字段。按钮文案要简短(中文,不超过8个字)。
|
请直接输出 JSON 格式,包含 suggestions 字段。按钮文案要简短。
|
||||||
"""),
|
"""),
|
||||||
("user", "对话历史:{history}"),
|
("user", "对话历史:{history}"),
|
||||||
])
|
])
|
||||||
|
|||||||
54
src/server/utils/mq_util.py
Executable file
54
src/server/utils/mq_util.py
Executable file
@@ -0,0 +1,54 @@
|
|||||||
|
import json
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
import aio_pika
|
||||||
|
from aio_pika import DeliveryMode, ExchangeType
|
||||||
|
from src.core.config import settings
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
EXCHANGE_NAME = "canvas_3d_exchange" # ← 修改这里
|
||||||
|
|
||||||
|
|
||||||
|
async def send_to_rabbitmq(
|
||||||
|
result: dict,
|
||||||
|
job_id: str,
|
||||||
|
status: str = "completed",
|
||||||
|
routing_key: str = "img_to_3d_results"
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
connection = await aio_pika.connect_robust(settings.RABBITMQ_URL)
|
||||||
|
|
||||||
|
async with connection:
|
||||||
|
channel = await connection.channel()
|
||||||
|
|
||||||
|
# 使用新的 Exchange 名称
|
||||||
|
exchange = await channel.declare_exchange(
|
||||||
|
name=EXCHANGE_NAME, # ← 使用常量
|
||||||
|
type=ExchangeType.DIRECT,
|
||||||
|
durable=True
|
||||||
|
)
|
||||||
|
|
||||||
|
queue = await channel.declare_queue(name=routing_key, durable=True)
|
||||||
|
await queue.bind(exchange, routing_key=routing_key)
|
||||||
|
|
||||||
|
message_body = {
|
||||||
|
"job_id": job_id,
|
||||||
|
"status": status,
|
||||||
|
"timestamp": datetime.now().isoformat(),
|
||||||
|
"task_type": routing_key, # 方便区分是哪个任务的结果
|
||||||
|
"result": result
|
||||||
|
}
|
||||||
|
|
||||||
|
message = aio_pika.Message(
|
||||||
|
body=json.dumps(message_body).encode("utf-8"),
|
||||||
|
delivery_mode=DeliveryMode.PERSISTENT,
|
||||||
|
)
|
||||||
|
|
||||||
|
await exchange.publish(message, routing_key=routing_key)
|
||||||
|
|
||||||
|
logger.info(f"✅ 发送成功 → routing_key: {routing_key} | job_id: {job_id} | status: {status}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"❌ 发送失败 → routing_key: {routing_key} | job_id: {job_id} | {e}", exc_info=True)
|
||||||
138
src/server/utils/new_oss_client.py
Normal file → Executable file
138
src/server/utils/new_oss_client.py
Normal file → Executable file
@@ -3,7 +3,7 @@ import logging
|
|||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
import urllib3
|
import urllib3
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from minio import Minio
|
from minio import Minio, S3Error
|
||||||
|
|
||||||
from src.core.config import settings
|
from src.core.config import settings
|
||||||
|
|
||||||
@@ -36,7 +36,7 @@ http_client = urllib3.PoolManager(
|
|||||||
|
|
||||||
|
|
||||||
# 获取图片
|
# 获取图片
|
||||||
def oss_get_image(oss_client, bucket, object_name, data_type):
|
def oss_get_image(oss_client, bucket, object_name):
|
||||||
# cv2 默认全通道读取
|
# cv2 默认全通道读取
|
||||||
image_object = None
|
image_object = None
|
||||||
try:
|
try:
|
||||||
@@ -57,9 +57,135 @@ def oss_upload_image(oss_client, bucket, object_name, image_bytes):
|
|||||||
return req
|
return req
|
||||||
|
|
||||||
|
|
||||||
|
def oss_upload_image_file(oss_client, bucket, object_name, file_path):
|
||||||
|
req = None
|
||||||
|
try:
|
||||||
|
req = oss_client.fput_object(
|
||||||
|
bucket_name=bucket,
|
||||||
|
object_name=object_name,
|
||||||
|
file_path=file_path
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f" | 上传图片出现异常 ######: {e}")
|
||||||
|
return req
|
||||||
|
|
||||||
|
|
||||||
|
def get_presigned_url(oss_client, bucket, object_name):
|
||||||
|
try:
|
||||||
|
presigned_url = oss_client.get_presigned_url(
|
||||||
|
"GET",
|
||||||
|
bucket_name=bucket,
|
||||||
|
object_name=object_name,
|
||||||
|
)
|
||||||
|
return presigned_url
|
||||||
|
except Exception as e:
|
||||||
|
print(f"get_presigned_url exception :{e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def is_minio_file_exist(oss_client: Minio, bucket_name: str, object_name: str) -> bool:
|
||||||
|
try:
|
||||||
|
# 核心判断:检查MinIO中指定bucket+object是否存在
|
||||||
|
oss_client.stat_object(bucket_name, object_name)
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def load_minio_file_to_state(oss_client, bucket: str, object_name: str, display_filename: str = None):
|
||||||
|
try:
|
||||||
|
# 下載 object 成 bytes
|
||||||
|
response = oss_client.get_object(
|
||||||
|
bucket_name=bucket,
|
||||||
|
object_name=object_name,
|
||||||
|
)
|
||||||
|
data_bytes = response.read()
|
||||||
|
response.close()
|
||||||
|
response.release_conn()
|
||||||
|
|
||||||
|
# 決定在 agent 裡顯示的檔名(可覆寫,避免洩漏真實 object name)
|
||||||
|
filename = display_filename or object_name.split("/")[-1]
|
||||||
|
|
||||||
|
# 回傳適合塞進 state["files"] 的格式
|
||||||
|
return {filename: data_bytes}
|
||||||
|
|
||||||
|
except S3Error as err:
|
||||||
|
raise ValueError(f"MinIO 下載失敗: {err}")
|
||||||
|
|
||||||
|
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
|
||||||
|
def check_and_extract_minio_image(url: str) -> dict[str, str]:
|
||||||
|
"""
|
||||||
|
校验URL + 提取MinIO图片路径(支持预签名地址)
|
||||||
|
返回格式: {"state": bool, "message": str, "data": str}
|
||||||
|
"""
|
||||||
|
# 1. 空值判断
|
||||||
|
if not url or not isinstance(url, str):
|
||||||
|
return {
|
||||||
|
"state": False,
|
||||||
|
"message": "URL cannot be empty or invalid format",
|
||||||
|
"data": ""
|
||||||
|
}
|
||||||
|
|
||||||
|
# 2. 解析URL
|
||||||
|
try:
|
||||||
|
parsed = urlparse(url)
|
||||||
|
if not (parsed.scheme and parsed.netloc):
|
||||||
|
return {
|
||||||
|
"state": False,
|
||||||
|
"message": "Invalid URL format",
|
||||||
|
"data": ""
|
||||||
|
}
|
||||||
|
except Exception:
|
||||||
|
return {
|
||||||
|
"state": False,
|
||||||
|
"message": "Failed to parse URL",
|
||||||
|
"data": ""
|
||||||
|
}
|
||||||
|
|
||||||
|
# 3. 域名判断
|
||||||
|
allowed_domains = {"www.minio-api.aida.com.hk", "minio-api.aida.com.hk"}
|
||||||
|
if parsed.netloc not in allowed_domains:
|
||||||
|
return {
|
||||||
|
"state": False,
|
||||||
|
"message": f"Invalid domain: {parsed.netloc}",
|
||||||
|
"data": ""
|
||||||
|
}
|
||||||
|
|
||||||
|
# 4. Get file path (ignore query parameters for presigned URL)
|
||||||
|
file_path = parsed.path.strip()
|
||||||
|
if not file_path:
|
||||||
|
return {
|
||||||
|
"state": False,
|
||||||
|
"message": "No file path found in URL",
|
||||||
|
"data": ""
|
||||||
|
}
|
||||||
|
|
||||||
|
# 5. Check if it's an image
|
||||||
|
image_exts = (".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp", ".tiff")
|
||||||
|
if not file_path.lower().endswith(image_exts):
|
||||||
|
return {
|
||||||
|
"state": False,
|
||||||
|
"message": "Not a valid image file",
|
||||||
|
"data": ""
|
||||||
|
}
|
||||||
|
|
||||||
|
# 6. Extract final path
|
||||||
|
result_path = file_path.lstrip("/")
|
||||||
|
return {
|
||||||
|
"state": True,
|
||||||
|
"message": "Success, path extracted",
|
||||||
|
"data": result_path
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
url = "aida-users/89/sketch/123-89.png"
|
urls = ["fida-public-bucket/furniture/sketches/0193c9b2-d8dd-40fc-b715-3ce0daab7abf.png-0.png", "fida-public-bucket/furniture/sketches/bab54cdf-0a60-4806-8c6b-17b836aec1eb.png-1.png", "fida-public-bucket/furniture/sketches/6c993266-95d2-42ee-826b-933b0e344b81.png-2.png"]
|
||||||
read_type = "2"
|
# read_type = "2"
|
||||||
img = oss_get_image(oss_client=minio_client, bucket=url.split('/')[0], object_name=url[url.find('/') + 1:], data_type=read_type)
|
for url in urls:
|
||||||
|
img = oss_get_image(oss_client=minio_client, bucket=url.split('/')[0], object_name=url[url.find('/') + 1:])
|
||||||
img.show()
|
img.show()
|
||||||
img.save("result.png")
|
# img.save("result.png")
|
||||||
|
# get_presigned_url(oss_client=minio_client, bucket="fida-test", object_name="furniture/sketches/07bf4cfe-4502-4821-b78f-7727bf409498.png")
|
||||||
|
|||||||
Reference in New Issue
Block a user