117 lines
3.7 KiB
Python
117 lines
3.7 KiB
Python
import logging
|
||
import sys
|
||
from typing import Optional
|
||
from fastapi import APIRouter, HTTPException, Query
|
||
from concurrent.futures import ThreadPoolExecutor
|
||
import threading
|
||
|
||
from app.schemas.response_template import ResponseModel
|
||
from app.service.recommendation_system.import_sys_sketch_to_milvus import main as import_main
|
||
|
||
logger = logging.getLogger()
|
||
router = APIRouter()
|
||
|
||
# 使用线程池执行器来运行长时间任务
|
||
executor = ThreadPoolExecutor(max_workers=1)
|
||
# 用于跟踪任务状态
|
||
task_status = {"running": False}
|
||
|
||
|
||
def run_import_task(batch_size: int, retry_times: int, limit: Optional[int], offset: int, skip_create_collection: bool):
|
||
"""在后台线程中运行导入任务"""
|
||
original_argv = None
|
||
try:
|
||
task_status["running"] = True
|
||
# 保存原始 sys.argv
|
||
original_argv = sys.argv.copy()
|
||
|
||
# 模拟命令行参数
|
||
sys.argv = [
|
||
"import_sys_sketch_to_milvus.py",
|
||
"--batch-size", str(batch_size),
|
||
"--retry-times", str(retry_times),
|
||
]
|
||
if limit is not None:
|
||
sys.argv.extend(["--limit", str(limit)])
|
||
if offset > 0:
|
||
sys.argv.extend(["--offset", str(offset)])
|
||
if skip_create_collection:
|
||
sys.argv.append("--skip-create-collection")
|
||
|
||
import_main()
|
||
task_status["running"] = False
|
||
logger.info("导入任务完成")
|
||
except Exception as e:
|
||
task_status["running"] = False
|
||
logger.error(f"导入任务失败: {e}", exc_info=True)
|
||
raise
|
||
finally:
|
||
# 恢复原始 sys.argv
|
||
if original_argv is not None:
|
||
sys.argv = original_argv
|
||
|
||
|
||
@router.post("/import-sys-sketch", response_model=ResponseModel)
|
||
async def import_sys_sketch(
|
||
batch_size: int = Query(1000, description="批量处理大小(默认:1000)"),
|
||
retry_times: int = Query(3, description="失败重试次数(默认:3)"),
|
||
limit: Optional[int] = Query(None, description="限制处理数量(用于测试,默认:不限制)"),
|
||
offset: int = Query(0, description="起始偏移量(默认:0)"),
|
||
skip_create_collection: bool = Query(False, description="跳过创建集合(如果集合已存在)"),
|
||
):
|
||
"""
|
||
从 t_sys_file 导入系统图向量到 Milvus
|
||
|
||
该接口会异步执行导入任务,任务在后台运行。
|
||
"""
|
||
try:
|
||
# 检查是否有任务正在运行
|
||
if task_status["running"]:
|
||
raise HTTPException(
|
||
status_code=409,
|
||
detail="已有导入任务正在运行,请等待完成后再试"
|
||
)
|
||
|
||
# 在后台线程中执行任务
|
||
executor.submit(
|
||
run_import_task,
|
||
batch_size,
|
||
retry_times,
|
||
limit,
|
||
offset,
|
||
skip_create_collection
|
||
)
|
||
|
||
return ResponseModel(
|
||
code=200,
|
||
msg="导入任务已启动,正在后台执行",
|
||
data={
|
||
"status": "started",
|
||
"batch_size": batch_size,
|
||
"retry_times": retry_times,
|
||
"limit": limit,
|
||
"offset": offset,
|
||
"skip_create_collection": skip_create_collection
|
||
}
|
||
)
|
||
except HTTPException:
|
||
raise
|
||
except Exception as e:
|
||
logger.error(f"启动导入任务失败: {e}", exc_info=True)
|
||
raise HTTPException(status_code=500, detail=f"启动导入任务失败: {str(e)}")
|
||
|
||
|
||
@router.get("/import-sys-sketch/status", response_model=ResponseModel)
|
||
async def get_import_status():
|
||
"""
|
||
获取导入任务状态
|
||
"""
|
||
return ResponseModel(
|
||
code=200,
|
||
msg="OK",
|
||
data={
|
||
"running": task_status["running"]
|
||
}
|
||
)
|
||
|