117 lines
3.7 KiB
Python
117 lines
3.7 KiB
Python
|
|
import logging
|
|||
|
|
import sys
|
|||
|
|
from typing import Optional
|
|||
|
|
from fastapi import APIRouter, HTTPException, Query
|
|||
|
|
from concurrent.futures import ThreadPoolExecutor
|
|||
|
|
import threading
|
|||
|
|
|
|||
|
|
from app.schemas.response_template import ResponseModel
|
|||
|
|
from app.service.recommendation_system.import_sys_sketch_to_milvus import main as import_main
|
|||
|
|
|
|||
|
|
logger = logging.getLogger()
|
|||
|
|
router = APIRouter()
|
|||
|
|
|
|||
|
|
# 使用线程池执行器来运行长时间任务
|
|||
|
|
executor = ThreadPoolExecutor(max_workers=1)
|
|||
|
|
# 用于跟踪任务状态
|
|||
|
|
task_status = {"running": False}
|
|||
|
|
|
|||
|
|
|
|||
|
|
def run_import_task(batch_size: int, retry_times: int, limit: Optional[int], offset: int, skip_create_collection: bool):
|
|||
|
|
"""在后台线程中运行导入任务"""
|
|||
|
|
original_argv = None
|
|||
|
|
try:
|
|||
|
|
task_status["running"] = True
|
|||
|
|
# 保存原始 sys.argv
|
|||
|
|
original_argv = sys.argv.copy()
|
|||
|
|
|
|||
|
|
# 模拟命令行参数
|
|||
|
|
sys.argv = [
|
|||
|
|
"import_sys_sketch_to_milvus.py",
|
|||
|
|
"--batch-size", str(batch_size),
|
|||
|
|
"--retry-times", str(retry_times),
|
|||
|
|
]
|
|||
|
|
if limit is not None:
|
|||
|
|
sys.argv.extend(["--limit", str(limit)])
|
|||
|
|
if offset > 0:
|
|||
|
|
sys.argv.extend(["--offset", str(offset)])
|
|||
|
|
if skip_create_collection:
|
|||
|
|
sys.argv.append("--skip-create-collection")
|
|||
|
|
|
|||
|
|
import_main()
|
|||
|
|
task_status["running"] = False
|
|||
|
|
logger.info("导入任务完成")
|
|||
|
|
except Exception as e:
|
|||
|
|
task_status["running"] = False
|
|||
|
|
logger.error(f"导入任务失败: {e}", exc_info=True)
|
|||
|
|
raise
|
|||
|
|
finally:
|
|||
|
|
# 恢复原始 sys.argv
|
|||
|
|
if original_argv is not None:
|
|||
|
|
sys.argv = original_argv
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.post("/import-sys-sketch", response_model=ResponseModel)
|
|||
|
|
async def import_sys_sketch(
|
|||
|
|
batch_size: int = Query(1000, description="批量处理大小(默认:1000)"),
|
|||
|
|
retry_times: int = Query(3, description="失败重试次数(默认:3)"),
|
|||
|
|
limit: Optional[int] = Query(None, description="限制处理数量(用于测试,默认:不限制)"),
|
|||
|
|
offset: int = Query(0, description="起始偏移量(默认:0)"),
|
|||
|
|
skip_create_collection: bool = Query(False, description="跳过创建集合(如果集合已存在)"),
|
|||
|
|
):
|
|||
|
|
"""
|
|||
|
|
从 t_sys_file 导入系统图向量到 Milvus
|
|||
|
|
|
|||
|
|
该接口会异步执行导入任务,任务在后台运行。
|
|||
|
|
"""
|
|||
|
|
try:
|
|||
|
|
# 检查是否有任务正在运行
|
|||
|
|
if task_status["running"]:
|
|||
|
|
raise HTTPException(
|
|||
|
|
status_code=409,
|
|||
|
|
detail="已有导入任务正在运行,请等待完成后再试"
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
# 在后台线程中执行任务
|
|||
|
|
executor.submit(
|
|||
|
|
run_import_task,
|
|||
|
|
batch_size,
|
|||
|
|
retry_times,
|
|||
|
|
limit,
|
|||
|
|
offset,
|
|||
|
|
skip_create_collection
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
return ResponseModel(
|
|||
|
|
code=200,
|
|||
|
|
msg="导入任务已启动,正在后台执行",
|
|||
|
|
data={
|
|||
|
|
"status": "started",
|
|||
|
|
"batch_size": batch_size,
|
|||
|
|
"retry_times": retry_times,
|
|||
|
|
"limit": limit,
|
|||
|
|
"offset": offset,
|
|||
|
|
"skip_create_collection": skip_create_collection
|
|||
|
|
}
|
|||
|
|
)
|
|||
|
|
except HTTPException:
|
|||
|
|
raise
|
|||
|
|
except Exception as e:
|
|||
|
|
logger.error(f"启动导入任务失败: {e}", exc_info=True)
|
|||
|
|
raise HTTPException(status_code=500, detail=f"启动导入任务失败: {str(e)}")
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.get("/import-sys-sketch/status", response_model=ResponseModel)
|
|||
|
|
async def get_import_status():
|
|||
|
|
"""
|
|||
|
|
获取导入任务状态
|
|||
|
|
"""
|
|||
|
|
return ResponseModel(
|
|||
|
|
code=200,
|
|||
|
|
msg="OK",
|
|||
|
|
data={
|
|||
|
|
"running": task_status["running"]
|
|||
|
|
}
|
|||
|
|
)
|
|||
|
|
|