Merge remote-tracking branch 'origin/dev-ltx' into develop
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
# Conflicts: # app/api/api_recommendation.py # app/service/design_fast/utils/organize.py
This commit is contained in:
@@ -9,7 +9,6 @@ from apscheduler.schedulers.background import BackgroundScheduler
|
||||
from apscheduler.triggers.cron import CronTrigger
|
||||
from fastapi import HTTPException, APIRouter
|
||||
|
||||
from app.service.recommend.service import load_resources, matrix_data
|
||||
import pymysql
|
||||
from app.core.config import DB_CONFIG, TABLE_CATEGORIES, RECOMMEND_PATH_PREFIX
|
||||
from minio import Minio
|
||||
|
||||
116
app/api/api_import_sys_sketch.py
Normal file
116
app/api/api_import_sys_sketch.py
Normal file
@@ -0,0 +1,116 @@
|
||||
import logging
|
||||
import sys
|
||||
from typing import Optional
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import threading
|
||||
|
||||
from app.schemas.response_template import ResponseModel
|
||||
from app.service.recommendation_system.import_sys_sketch_to_milvus import main as import_main
|
||||
|
||||
logger = logging.getLogger()
|
||||
router = APIRouter()
|
||||
|
||||
# 使用线程池执行器来运行长时间任务
|
||||
executor = ThreadPoolExecutor(max_workers=1)
|
||||
# 用于跟踪任务状态
|
||||
task_status = {"running": False}
|
||||
|
||||
|
||||
def run_import_task(batch_size: int, retry_times: int, limit: Optional[int], offset: int, skip_create_collection: bool):
|
||||
"""在后台线程中运行导入任务"""
|
||||
original_argv = None
|
||||
try:
|
||||
task_status["running"] = True
|
||||
# 保存原始 sys.argv
|
||||
original_argv = sys.argv.copy()
|
||||
|
||||
# 模拟命令行参数
|
||||
sys.argv = [
|
||||
"import_sys_sketch_to_milvus.py",
|
||||
"--batch-size", str(batch_size),
|
||||
"--retry-times", str(retry_times),
|
||||
]
|
||||
if limit is not None:
|
||||
sys.argv.extend(["--limit", str(limit)])
|
||||
if offset > 0:
|
||||
sys.argv.extend(["--offset", str(offset)])
|
||||
if skip_create_collection:
|
||||
sys.argv.append("--skip-create-collection")
|
||||
|
||||
import_main()
|
||||
task_status["running"] = False
|
||||
logger.info("导入任务完成")
|
||||
except Exception as e:
|
||||
task_status["running"] = False
|
||||
logger.error(f"导入任务失败: {e}", exc_info=True)
|
||||
raise
|
||||
finally:
|
||||
# 恢复原始 sys.argv
|
||||
if original_argv is not None:
|
||||
sys.argv = original_argv
|
||||
|
||||
|
||||
@router.post("/import-sys-sketch", response_model=ResponseModel)
|
||||
async def import_sys_sketch(
|
||||
batch_size: int = Query(1000, description="批量处理大小(默认:1000)"),
|
||||
retry_times: int = Query(3, description="失败重试次数(默认:3)"),
|
||||
limit: Optional[int] = Query(None, description="限制处理数量(用于测试,默认:不限制)"),
|
||||
offset: int = Query(0, description="起始偏移量(默认:0)"),
|
||||
skip_create_collection: bool = Query(False, description="跳过创建集合(如果集合已存在)"),
|
||||
):
|
||||
"""
|
||||
从 t_sys_file 导入系统图向量到 Milvus
|
||||
|
||||
该接口会异步执行导入任务,任务在后台运行。
|
||||
"""
|
||||
try:
|
||||
# 检查是否有任务正在运行
|
||||
if task_status["running"]:
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail="已有导入任务正在运行,请等待完成后再试"
|
||||
)
|
||||
|
||||
# 在后台线程中执行任务
|
||||
executor.submit(
|
||||
run_import_task,
|
||||
batch_size,
|
||||
retry_times,
|
||||
limit,
|
||||
offset,
|
||||
skip_create_collection
|
||||
)
|
||||
|
||||
return ResponseModel(
|
||||
code=200,
|
||||
msg="导入任务已启动,正在后台执行",
|
||||
data={
|
||||
"status": "started",
|
||||
"batch_size": batch_size,
|
||||
"retry_times": retry_times,
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
"skip_create_collection": skip_create_collection
|
||||
}
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"启动导入任务失败: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"启动导入任务失败: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/import-sys-sketch/status", response_model=ResponseModel)
|
||||
async def get_import_status():
|
||||
"""
|
||||
获取导入任务状态
|
||||
"""
|
||||
return ResponseModel(
|
||||
code=200,
|
||||
msg="OK",
|
||||
data={
|
||||
"running": task_status["running"]
|
||||
}
|
||||
)
|
||||
|
||||
85
app/api/api_precompute.py
Normal file
85
app/api/api_precompute.py
Normal file
@@ -0,0 +1,85 @@
|
||||
import logging
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
from app.schemas.response_template import ResponseModel
|
||||
from app.service.recommendation_system.precompute import run_precompute
|
||||
|
||||
logger = logging.getLogger()
|
||||
router = APIRouter()
|
||||
|
||||
# 使用线程池执行器来运行长时间任务
|
||||
executor = ThreadPoolExecutor(max_workers=1)
|
||||
# 用于跟踪任务状态
|
||||
task_status = {"running": False}
|
||||
|
||||
|
||||
def run_precompute_task():
|
||||
"""在后台线程中运行预计算任务"""
|
||||
try:
|
||||
task_status["running"] = True
|
||||
logger.info("开始执行预计算任务...")
|
||||
run_precompute()
|
||||
task_status["running"] = False
|
||||
logger.info("预计算任务完成")
|
||||
except Exception as e:
|
||||
task_status["running"] = False
|
||||
logger.error(f"预计算任务失败: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
@router.post("/precompute", response_model=ResponseModel)
|
||||
async def precompute():
|
||||
"""
|
||||
运行预计算任务
|
||||
|
||||
该接口会异步执行预计算任务,包括:
|
||||
1. 优化数据库表结构
|
||||
2. 历史数据迁移
|
||||
3. 初始用户偏好向量生成
|
||||
|
||||
任务在后台运行。
|
||||
"""
|
||||
try:
|
||||
# 检查是否有任务正在运行
|
||||
if task_status["running"]:
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail="已有预计算任务正在运行,请等待完成后再试"
|
||||
)
|
||||
|
||||
# 在后台线程中执行任务
|
||||
executor.submit(run_precompute_task)
|
||||
|
||||
return ResponseModel(
|
||||
code=200,
|
||||
msg="预计算任务已启动,正在后台执行",
|
||||
data={
|
||||
"status": "started",
|
||||
"tasks": [
|
||||
"优化数据库表结构",
|
||||
"历史数据迁移",
|
||||
"初始用户偏好向量生成"
|
||||
]
|
||||
}
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"启动预计算任务失败: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"启动预计算任务失败: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/precompute/status", response_model=ResponseModel)
|
||||
async def get_precompute_status():
|
||||
"""
|
||||
获取预计算任务状态
|
||||
"""
|
||||
return ResponseModel(
|
||||
code=200,
|
||||
msg="OK",
|
||||
data={
|
||||
"running": task_status["running"]
|
||||
}
|
||||
)
|
||||
|
||||
@@ -1,207 +1,175 @@
|
||||
import io
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
from typing import List
|
||||
import os
|
||||
import json
|
||||
import math
|
||||
import random
|
||||
import numpy as np
|
||||
from typing import List, Optional
|
||||
from fastapi import HTTPException, APIRouter, Query
|
||||
from apscheduler.schedulers.background import BackgroundScheduler
|
||||
from apscheduler.triggers.cron import CronTrigger
|
||||
from fastapi import HTTPException, APIRouter
|
||||
|
||||
from app.service.recommend.service import load_resources, matrix_data
|
||||
from app.service.recommendation_system.recommendation_api import get_recommendations as get_new_recommendations
|
||||
from app.service.recommendation_system.incremental_listener import start_background_listener
|
||||
from app.service.recommendation_system.milvus_client import create_collection
|
||||
|
||||
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')
|
||||
logger = logging.getLogger()
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.on_event("startup")
|
||||
async def startup_event():
|
||||
# 初始加载
|
||||
load_resources()
|
||||
|
||||
# 配置定时任务
|
||||
scheduler = BackgroundScheduler()
|
||||
scheduler.add_job(
|
||||
load_resources,
|
||||
trigger=CronTrigger(hour=0, minute=30),
|
||||
name="每日资源刷新"
|
||||
)
|
||||
scheduler.start()
|
||||
logger.info("定时任务已启动")
|
||||
|
||||
|
||||
def softmax(scores):
|
||||
max_score = max(scores)
|
||||
exp_scores = [math.exp(s - max_score) for s in scores]
|
||||
sum_exp = sum(exp_scores)
|
||||
return [s / sum_exp for s in exp_scores]
|
||||
|
||||
|
||||
# def get_random_recommendations(category: str, num: int) -> List[str]:
|
||||
# """根据预加载热度向量推荐(冷启动)"""
|
||||
# ========== 旧版推荐接口(基于 npy 矩阵,已废弃)==========
|
||||
# @router.get("/recommend/{user_id}/{category}/{num_recommendations}/{brand_id}/{brand_scale}", response_model=List[str])
|
||||
# async def get_recommendations(user_id: int, category: str, brand_id: int, brand_scale: float, num_recommendations: int = 10):
|
||||
# """
|
||||
# :param user_id: 4
|
||||
# :param category: female_skirt
|
||||
# :param num_recommendations: 1
|
||||
# :return:
|
||||
# [
|
||||
# "aida-sys-image/images/female/skirt/903000017.jpg"
|
||||
# ]
|
||||
# """
|
||||
# try:
|
||||
# heat_data = matrix_data.get("heat_data", {})
|
||||
# start_time = time.time()
|
||||
# cache_key = (user_id, category)
|
||||
# # === 新增:用户存在性检查 ===
|
||||
# user_exists_inter = user_id in matrix_data["user_index_interaction"]
|
||||
# user_exists_feat = user_id in matrix_data["user_index_feature"]
|
||||
#
|
||||
# if category not in heat_data:
|
||||
# raise ValueError(f"热度数据缺少类别 {category},使用随机推荐")
|
||||
# # 任一矩阵不存在用户则返回随机推荐
|
||||
# if not (user_exists_inter and user_exists_feat):
|
||||
# logger.info(f"用户 {user_id} 数据不完整,触发随机推荐")
|
||||
# return get_random_recommendations(category, num_recommendations)
|
||||
#
|
||||
# heat_dict = heat_data[category] # {url: score}
|
||||
# urls = list(heat_dict.keys())
|
||||
# scores = list(heat_dict.values())
|
||||
# # 检查缓存
|
||||
# if cache_key in matrix_data["cached_scores"]:
|
||||
# processed_inter, processed_feat = matrix_data["cached_scores"][cache_key]
|
||||
# valid_sketch_idxs_inter = matrix_data["cached_valid_idxs"][cache_key]
|
||||
# else:
|
||||
# # 实时计算逻辑(同原代码)
|
||||
# user_idx_inter = matrix_data["user_index_interaction"].get(user_id)
|
||||
# user_idx_feature = matrix_data["user_index_feature"].get(user_id)
|
||||
#
|
||||
# if not urls:
|
||||
# raise ValueError("该类别下无热度记录,使用随机推荐")
|
||||
# category_iids = matrix_data["category_to_iids"].get(category, [])
|
||||
# valid_sketch_idxs_inter = [
|
||||
# idx for iid, idx in matrix_data["sketch_index_interaction"].items()
|
||||
# if iid in category_iids
|
||||
# ]
|
||||
#
|
||||
# probs = softmax(scores)
|
||||
# sample_size = min(num, len(urls))
|
||||
# sampled_urls = random.choices(urls, weights=probs, k=sample_size)
|
||||
# # 处理交互分数
|
||||
# raw_inter_scores = []
|
||||
# if user_idx_inter is not None and valid_sketch_idxs_inter:
|
||||
# raw_inter_scores = matrix_data["interaction_matrix"][user_idx_inter, valid_sketch_idxs_inter]
|
||||
# processed_inter = raw_inter_scores * 0.7
|
||||
#
|
||||
# return sampled_urls
|
||||
# # 处理特征分数
|
||||
# valid_sketch_idxs_feature = [
|
||||
# idx for iid, idx in matrix_data["sketch_index_feature"].items()
|
||||
# if iid in category_iids
|
||||
# ]
|
||||
# raw_feat_scores = []
|
||||
# if user_idx_feature is not None and valid_sketch_idxs_feature:
|
||||
# raw_feat_scores = matrix_data["feature_matrix"][user_idx_feature, valid_sketch_idxs_feature]
|
||||
# raw_feat_scores = (raw_feat_scores - np.min(raw_feat_scores)) / (
|
||||
# np.max(raw_feat_scores) - np.min(raw_feat_scores) + 1e-8)
|
||||
# processed_feat = raw_feat_scores
|
||||
# else:
|
||||
# processed_feat = np.array([])
|
||||
#
|
||||
# # 更新缓存
|
||||
# matrix_data["cached_scores"][cache_key] = (processed_inter, processed_feat)
|
||||
# matrix_data["cached_valid_idxs"][cache_key] = valid_sketch_idxs_inter
|
||||
#
|
||||
# # 合并分数
|
||||
# if brand_id is not None:
|
||||
# brand_idx_feature = matrix_data["brand_index_map"].get(brand_id)
|
||||
#
|
||||
# brand_feat_valid = (
|
||||
# matrix_data["brand_feature_matrix"].size > 0 and # 矩阵非空
|
||||
# brand_idx_feature is not None and
|
||||
# valid_sketch_idxs_feature # 有可用索引
|
||||
# )
|
||||
#
|
||||
# if brand_feat_valid:
|
||||
# raw_brand_feat_scores = matrix_data["brand_feature_matrix"][
|
||||
# brand_idx_feature, valid_sketch_idxs_feature
|
||||
# ]
|
||||
# raw_brand_feat_scores = (raw_brand_feat_scores - np.min(raw_brand_feat_scores)) / (
|
||||
# np.max(raw_brand_feat_scores) - np.min(raw_brand_feat_scores) + 1e-8
|
||||
# )
|
||||
# processed_brand_feat = raw_brand_feat_scores
|
||||
#
|
||||
# # 如果 processed_feat 是空的,替换为全 0,避免 shape 不一致
|
||||
# if processed_feat.size == 0:
|
||||
# processed_feat = np.zeros_like(processed_brand_feat)
|
||||
#
|
||||
# final_scores = processed_inter + 0.3 * (
|
||||
# (1 - brand_scale) * processed_feat + brand_scale * processed_brand_feat
|
||||
# )
|
||||
# else:
|
||||
# # brand 信息不可用
|
||||
# final_scores = processed_inter + 0.3 * processed_feat if processed_feat.size > 0 else processed_inter
|
||||
# else:
|
||||
# final_scores = processed_inter + 0.3 * processed_feat if processed_feat.size > 0 else processed_inter
|
||||
#
|
||||
# valid_sketch_idxs = matrix_data["cached_valid_idxs"][cache_key]
|
||||
#
|
||||
# # 概率采样
|
||||
# scores = np.array(final_scores)
|
||||
#
|
||||
# # 调整后的概率转换(带温度控制的softmax)
|
||||
# def calibrated_softmax(scores, temperature=1.0):
|
||||
# scores = scores / temperature
|
||||
# scale = scores - max(scores)
|
||||
# exps = np.exp(scale)
|
||||
# return exps / np.sum(exps)
|
||||
#
|
||||
# probs = calibrated_softmax(scores, 0.09)
|
||||
#
|
||||
# chosen_indices = np.random.choice(
|
||||
# len(valid_sketch_idxs),
|
||||
# size=min(num_recommendations, len(valid_sketch_idxs)),
|
||||
# p=probs,
|
||||
# replace=False
|
||||
# )
|
||||
# recommendations = [matrix_data["iid_to_sketch"][valid_sketch_idxs[idx]] for idx in chosen_indices]
|
||||
#
|
||||
# logger.info(f"推荐生成完成,耗时: {time.time() - start_time:.2f}秒")
|
||||
# return recommendations
|
||||
# except Exception as e:
|
||||
# # 回退:完全随机推荐
|
||||
# all_iids = list(matrix_data["iid_to_sketch"].keys())
|
||||
# category_iids = matrix_data["category_to_iids"].get(category, all_iids)
|
||||
# sample_size = min(num, len(category_iids))
|
||||
# sampled = np.random.choice(category_iids, size=sample_size, replace=False)
|
||||
# return [matrix_data["iid_to_sketch"][iid] for iid in sampled]
|
||||
# logger.error(f"推荐失败: {str(e)}", exc_info=True)
|
||||
# raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
def get_random_recommendations(category: str, num: int) -> List[str]:
|
||||
"""全品类随机推荐"""
|
||||
all_iids = list(matrix_data["iid_to_sketch"].keys())
|
||||
# 优先从当前品类选择
|
||||
category_iids = matrix_data["category_to_iids"].get(category, all_iids)
|
||||
# 确保不超出实际数量
|
||||
sample_size = min(num, len(category_iids))
|
||||
sampled = np.random.choice(category_iids, size=sample_size, replace=False)
|
||||
return [matrix_data["iid_to_sketch"][iid] for iid in sampled]
|
||||
|
||||
|
||||
@router.get("/recommend/{user_id}/{category}/{num_recommendations}/{brand_id}/{brand_scale}", response_model=List[str])
|
||||
async def get_recommendations(user_id: int, category: str, brand_id: int, brand_scale: float, num_recommendations: int = 10):
|
||||
"""
|
||||
:param user_id: 4
|
||||
:param category: female_skirt
|
||||
:param num_recommendations: 1
|
||||
:return:
|
||||
[
|
||||
"aida-sys-image/images/female/skirt/903000017.jpg"
|
||||
]
|
||||
"""
|
||||
# @router.on_event("startup")
|
||||
async def startup_event():
|
||||
"""启动时初始化增量监听任务"""
|
||||
try:
|
||||
logger.info(f"user_id:{user_id}-----category:{category}-----brand_id:{brand_id}-----brand_scale:{brand_scale}-----num_recommendations:{num_recommendations}")
|
||||
start_time = time.time()
|
||||
cache_key = (user_id, category)
|
||||
# === 新增:用户存在性检查 ===
|
||||
user_exists_inter = user_id in matrix_data["user_index_interaction"]
|
||||
user_exists_feat = user_id in matrix_data["user_index_feature"]
|
||||
|
||||
# 任一矩阵不存在用户则返回随机推荐
|
||||
if not (user_exists_inter and user_exists_feat):
|
||||
logger.info(f"用户 {user_id} 数据不完整,触发随机推荐")
|
||||
return get_random_recommendations(category, num_recommendations)
|
||||
|
||||
# 检查缓存
|
||||
if cache_key in matrix_data["cached_scores"]:
|
||||
processed_inter, processed_feat = matrix_data["cached_scores"][cache_key]
|
||||
valid_sketch_idxs_inter = matrix_data["cached_valid_idxs"][cache_key]
|
||||
else:
|
||||
# 实时计算逻辑(同原代码)
|
||||
user_idx_inter = matrix_data["user_index_interaction"].get(user_id)
|
||||
user_idx_feature = matrix_data["user_index_feature"].get(user_id)
|
||||
|
||||
category_iids = matrix_data["category_to_iids"].get(category, [])
|
||||
valid_sketch_idxs_inter = [
|
||||
idx for iid, idx in matrix_data["sketch_index_interaction"].items()
|
||||
if iid in category_iids
|
||||
]
|
||||
|
||||
# 处理交互分数
|
||||
raw_inter_scores = []
|
||||
if user_idx_inter is not None and valid_sketch_idxs_inter:
|
||||
raw_inter_scores = matrix_data["interaction_matrix"][user_idx_inter, valid_sketch_idxs_inter]
|
||||
processed_inter = raw_inter_scores * 0.7
|
||||
|
||||
# 处理特征分数
|
||||
valid_sketch_idxs_feature = [
|
||||
idx for iid, idx in matrix_data["sketch_index_feature"].items()
|
||||
if iid in category_iids
|
||||
]
|
||||
raw_feat_scores = []
|
||||
if user_idx_feature is not None and valid_sketch_idxs_feature:
|
||||
raw_feat_scores = matrix_data["feature_matrix"][user_idx_feature, valid_sketch_idxs_feature]
|
||||
raw_feat_scores = (raw_feat_scores - np.min(raw_feat_scores)) / (
|
||||
np.max(raw_feat_scores) - np.min(raw_feat_scores) + 1e-8)
|
||||
processed_feat = raw_feat_scores
|
||||
else:
|
||||
processed_feat = np.array([])
|
||||
|
||||
# 更新缓存
|
||||
matrix_data["cached_scores"][cache_key] = (processed_inter, processed_feat)
|
||||
matrix_data["cached_valid_idxs"][cache_key] = valid_sketch_idxs_inter
|
||||
|
||||
# 合并分数
|
||||
if brand_id is not None:
|
||||
brand_idx_feature = matrix_data["brand_index_map"].get(brand_id)
|
||||
|
||||
brand_feat_valid = (
|
||||
matrix_data["brand_feature_matrix"].size > 0 and # 矩阵非空
|
||||
brand_idx_feature is not None and
|
||||
valid_sketch_idxs_feature # 有可用索引
|
||||
)
|
||||
|
||||
if brand_feat_valid:
|
||||
raw_brand_feat_scores = matrix_data["brand_feature_matrix"][
|
||||
brand_idx_feature, valid_sketch_idxs_feature
|
||||
]
|
||||
raw_brand_feat_scores = (raw_brand_feat_scores - np.min(raw_brand_feat_scores)) / (
|
||||
np.max(raw_brand_feat_scores) - np.min(raw_brand_feat_scores) + 1e-8
|
||||
)
|
||||
processed_brand_feat = raw_brand_feat_scores
|
||||
|
||||
# 如果 processed_feat 是空的,替换为全 0,避免 shape 不一致
|
||||
if processed_feat.size == 0:
|
||||
processed_feat = np.zeros_like(processed_brand_feat)
|
||||
|
||||
final_scores = processed_inter + 0.3 * (
|
||||
(1 - brand_scale) * processed_feat + brand_scale * processed_brand_feat
|
||||
)
|
||||
else:
|
||||
# brand 信息不可用
|
||||
final_scores = processed_inter + 0.3 * processed_feat if processed_feat.size > 0 else processed_inter
|
||||
else:
|
||||
final_scores = processed_inter + 0.3 * processed_feat if processed_feat.size > 0 else processed_inter
|
||||
|
||||
valid_sketch_idxs = matrix_data["cached_valid_idxs"][cache_key]
|
||||
|
||||
# 概率采样
|
||||
scores = np.array(final_scores)
|
||||
|
||||
# 调整后的概率转换(带温度控制的softmax)
|
||||
def calibrated_softmax(scores, temperature=1.0):
|
||||
scores = scores / temperature
|
||||
scale = scores - max(scores)
|
||||
exps = np.exp(scale)
|
||||
return exps / np.sum(exps)
|
||||
|
||||
probs = calibrated_softmax(scores, 0.09)
|
||||
|
||||
chosen_indices = np.random.choice(
|
||||
len(valid_sketch_idxs),
|
||||
size=min(num_recommendations, len(valid_sketch_idxs)),
|
||||
p=probs,
|
||||
replace=False
|
||||
)
|
||||
recommendations = [matrix_data["iid_to_sketch"][valid_sketch_idxs[idx]] for idx in chosen_indices]
|
||||
|
||||
logger.info(f"推荐生成完成,耗时: {time.time() - start_time:.2f}秒")
|
||||
return recommendations
|
||||
|
||||
# 确保 Milvus 集合已创建(若已存在则直接返回)
|
||||
try:
|
||||
create_collection()
|
||||
except Exception as exc:
|
||||
logger.error("Milvus 集合创建/检查失败,不影响服务继续启动: %s", exc, exc_info=True)
|
||||
|
||||
# 配置定时任务
|
||||
scheduler = BackgroundScheduler()
|
||||
start_background_listener(scheduler)
|
||||
scheduler.start()
|
||||
logger.info("增量监听定时任务已启动")
|
||||
except Exception as e:
|
||||
logger.error(f"推荐失败: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
logger.error(f"启动增量监听任务失败: {e}", exc_info=True)
|
||||
|
||||
|
||||
@router.get("/recommend/{user_id}/{category}", response_model=List[str])
|
||||
async def recommend(
|
||||
user_id: int,
|
||||
category: str,
|
||||
style: Optional[str] = Query(
|
||||
None,
|
||||
description="风格样式(可选):若传入,则在利用分支对同 style 的候选进行加分",
|
||||
),
|
||||
):
|
||||
"""新版推荐接口(Milvus + Redis 偏好向量)。"""
|
||||
try:
|
||||
results = get_new_recommendations(user_id, category, style)
|
||||
path = results[0] if results else ""
|
||||
return [path]
|
||||
except Exception as e:
|
||||
logger.error("新版推荐接口失败 [user=%s, category=%s]: %s", user_id, category, e, exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
@@ -10,8 +10,10 @@ from app.api import api_design_pre_processing
|
||||
from app.api import api_extraction_project_info
|
||||
from app.api import api_generate_image
|
||||
from app.api import api_image2sketch
|
||||
from app.api import api_import_sys_sketch
|
||||
from app.api import api_mannequins_edit
|
||||
from app.api import api_pose_transform
|
||||
from app.api import api_precompute
|
||||
from app.api import api_prompt_generation
|
||||
from app.api import api_recommendation
|
||||
from app.api import api_super_resolution
|
||||
@@ -36,3 +38,5 @@ router.include_router(api_mannequins_edit.router, tags=['api_mannequins_edit'],
|
||||
router.include_router(api_pose_transform.router, tags=['api_pose_transform'], prefix="/api")
|
||||
router.include_router(api_clothing_seg.router, tags=['api_clothing_seg'], prefix="/api")
|
||||
router.include_router(api_extraction_project_info.router, tags=['api_extraction_project_info'], prefix="/api")
|
||||
router.include_router(api_import_sys_sketch.router, tags=['api_import_sys_sketch'], prefix="/api")
|
||||
router.include_router(api_precompute.router, tags=['api_precompute'], prefix="/api")
|
||||
|
||||
Reference in New Issue
Block a user