import io
import logging
import sys
from typing import List, Optional

from fastapi import HTTPException, APIRouter, Query
from apscheduler.schedulers.background import BackgroundScheduler

from app.service.recommendation_system.recommendation_api import get_recommendations as get_new_recommendations
from app.service.recommendation_system.incremental_listener import start_background_listener
from app.service.recommendation_system.milvus_client import create_collection


def _force_utf8_stdout() -> None:
    """Switch the existing ``sys.stdout`` text stream to UTF-8 in place.

    Uses ``TextIOWrapper.reconfigure`` on the current stream object instead of
    wrapping ``sys.stdout.buffer`` in a second ``TextIOWrapper``. Wrapping crashes
    at import time when ``sys.stdout`` is ``None`` or has no ``.buffer``. It also
    leaves two independent text buffers writing to one byte stream, and it
    resets line buffering. Streams without ``reconfigure`` are left untouched.
    """
    reconfigure = getattr(sys.stdout, "reconfigure", None)
    if reconfigure is not None:
        reconfigure(encoding="utf-8")


# Presumably needed so non-ASCII text (this module's messages are Chinese) can be
# printed on consoles whose default encoding is not UTF-8.
_force_utf8_stdout()

# NOTE: this is intentionally the *root* logger (no name), as it always was.
# A named logger created at import time could be disabled by a later
# logging.config.dictConfig call that uses disable_existing_loggers=True.
logger = logging.getLogger()
router = APIRouter()

# ========== 旧版推荐接口(基于 npy 矩阵,已废弃)==========
# Legacy recommendation endpoint (npy-matrix based, deprecated). Kept for
# reference only: it relies on names (matrix_data, np, time,
# get_random_recommendations) that this module does not import.
# @router.get("/recommend/{user_id}/{category}/{num_recommendations}/{brand_id}/{brand_scale}", response_model=List[str])
# async def get_recommendations(user_id: int, category: str, brand_id: int, brand_scale: float, num_recommendations: int = 10):
#     """
#     :param user_id: 4
#     :param category: female_skirt
#     :param num_recommendations: 1
#     :return:
#     [
#         "aida-sys-image/images/female/skirt/903000017.jpg"
#     ]
#     """
#     try:
#         start_time = time.time()
#         cache_key = (user_id, category)
#         # === 新增:用户存在性检查 ===
#         user_exists_inter = user_id in matrix_data["user_index_interaction"]
#         user_exists_feat = user_id in matrix_data["user_index_feature"]
#
#         # 任一矩阵不存在用户则返回随机推荐
#         if not (user_exists_inter and user_exists_feat):
#             logger.info(f"用户 {user_id} 数据不完整,触发随机推荐")
#             return get_random_recommendations(category, num_recommendations)
#
#         # 检查缓存
#         if cache_key in matrix_data["cached_scores"]:
#             processed_inter, processed_feat = matrix_data["cached_scores"][cache_key]
#             valid_sketch_idxs_inter = matrix_data["cached_valid_idxs"][cache_key]
#         else:
#             # 实时计算逻辑(同原代码)
#             user_idx_inter = matrix_data["user_index_interaction"].get(user_id)
#             user_idx_feature = matrix_data["user_index_feature"].get(user_id)
#
#             category_iids = matrix_data["category_to_iids"].get(category, [])
#             valid_sketch_idxs_inter = [
#                 idx for iid, idx in matrix_data["sketch_index_interaction"].items()
#                 if iid in category_iids
#             ]
#
#             # 处理交互分数
#             raw_inter_scores = []
#             if user_idx_inter is not None and valid_sketch_idxs_inter:
#                 raw_inter_scores = matrix_data["interaction_matrix"][user_idx_inter, valid_sketch_idxs_inter]
#             processed_inter = raw_inter_scores * 0.7
#
#             # 处理特征分数
#             valid_sketch_idxs_feature = [
#                 idx for iid, idx in matrix_data["sketch_index_feature"].items()
#                 if iid in category_iids
#             ]
#             raw_feat_scores = []
#             if user_idx_feature is not None and valid_sketch_idxs_feature:
#                 raw_feat_scores = matrix_data["feature_matrix"][user_idx_feature, valid_sketch_idxs_feature]
#                 raw_feat_scores = (raw_feat_scores - np.min(raw_feat_scores)) / (
#                     np.max(raw_feat_scores) - np.min(raw_feat_scores) + 1e-8)
#                 processed_feat = raw_feat_scores
#             else:
#                 processed_feat = np.array([])
#
#             # 更新缓存
#             matrix_data["cached_scores"][cache_key] = (processed_inter, processed_feat)
#             matrix_data["cached_valid_idxs"][cache_key] = valid_sketch_idxs_inter
#
#         # 合并分数
#         if brand_id is not None:
#             brand_idx_feature = matrix_data["brand_index_map"].get(brand_id)
#
#             brand_feat_valid = (
#                 matrix_data["brand_feature_matrix"].size > 0 and  # 矩阵非空
#                 brand_idx_feature is not None and
#                 valid_sketch_idxs_feature  # 有可用索引
#             )
#
#             if brand_feat_valid:
#                 raw_brand_feat_scores = matrix_data["brand_feature_matrix"][
#                     brand_idx_feature, valid_sketch_idxs_feature
#                 ]
#                 raw_brand_feat_scores = (raw_brand_feat_scores - np.min(raw_brand_feat_scores)) / (
#                     np.max(raw_brand_feat_scores) - np.min(raw_brand_feat_scores) + 1e-8
#                 )
#                 processed_brand_feat = raw_brand_feat_scores
#
#                 # 如果 processed_feat 是空的,替换为全 0,避免 shape 不一致
#                 if processed_feat.size == 0:
#                     processed_feat = np.zeros_like(processed_brand_feat)
#
#                 final_scores = processed_inter + 0.3 * (
#                     (1 - brand_scale) * processed_feat + brand_scale * processed_brand_feat
#                 )
#             else:
#                 # brand 信息不可用
#                 final_scores = processed_inter + 0.3 * processed_feat if processed_feat.size > 0 else processed_inter
#         else:
#             final_scores = processed_inter + 0.3 * processed_feat if processed_feat.size > 0 else processed_inter
#
#         valid_sketch_idxs = matrix_data["cached_valid_idxs"][cache_key]
#
#         # 概率采样
#         scores = np.array(final_scores)
#
#         # 调整后的概率转换(带温度控制的softmax)
#         def calibrated_softmax(scores, temperature=1.0):
#             scores = scores / temperature
#             scale = scores - max(scores)
#             exps = np.exp(scale)
#             return exps / np.sum(exps)
#
#         probs = calibrated_softmax(scores, 0.09)
#
#         chosen_indices = np.random.choice(
#             len(valid_sketch_idxs),
#             size=min(num_recommendations, len(valid_sketch_idxs)),
#             p=probs,
#             replace=False
#         )
#         recommendations = [matrix_data["iid_to_sketch"][valid_sketch_idxs[idx]] for idx in chosen_indices]
#
#         logger.info(f"推荐生成完成,耗时: {time.time() - start_time:.2f}秒")
#         return recommendations
#     except Exception as e:
#         logger.error(f"推荐失败: {str(e)}", exc_info=True)
#         raise HTTPException(status_code=500, detail=str(e))
# ========== end of legacy (deprecated) endpoint ==========


# Handle to the background scheduler started in startup_event. Kept at module
# level for two reasons:
# - the shutdown hook can stop it;
# - a repeated startup event in the same process does not create a second
#   scheduler running the same incremental-listener jobs.
_scheduler: Optional[BackgroundScheduler] = None


@router.on_event("startup")
async def startup_event():
    """Initialise the incremental-listener scheduled task when the app starts.

    Steps:
      1. Raise the ``apscheduler`` logger to WARNING to hide its INFO logs.
      2. Call ``create_collection()`` to make sure the Milvus collection exists.
         A failure there is logged and startup continues.
      3. Create a ``BackgroundScheduler``, register the listener jobs with
         ``start_background_listener`` and start it. This step is skipped when a
         scheduler from an earlier startup event is still running.

    Exceptions are logged rather than raised, so a failure here never prevents
    the service from starting.
    """
    global _scheduler
    try:
        # Silence apscheduler's INFO-level logging.
        logging.getLogger("apscheduler").setLevel(logging.WARNING)

        # Ensure the Milvus collection exists (the original note says this
        # returns immediately if it already exists).
        try:
            create_collection()
        except Exception as exc:
            logger.error("Milvus 集合创建/检查失败,不影响服务继续启动: %s", exc, exc_info=True)

        if _scheduler is not None and _scheduler.running:
            logger.info("增量监听定时任务已在运行,跳过重复启动")
            return

        scheduler = BackgroundScheduler()
        start_background_listener(scheduler)
        scheduler.start()
        # Only publish the handle once the scheduler actually started.
        _scheduler = scheduler
        logger.info("增量监听定时任务已启动")
    except Exception as e:
        logger.error("启动增量监听任务失败: %s", e, exc_info=True)


@router.on_event("shutdown")
async def shutdown_event():
    """Stop the scheduler started by ``startup_event``, if it is running.

    ``wait=False`` keeps the event loop from blocking on a job that is still
    executing. Errors are logged, never raised.
    """
    global _scheduler
    scheduler, _scheduler = _scheduler, None
    if scheduler is None or not scheduler.running:
        return
    try:
        scheduler.shutdown(wait=False)
        logger.info("增量监听定时任务已停止")
    except Exception as e:
        logger.error("停止增量监听任务失败: %s", e, exc_info=True)


@router.get("/recommend/{user_id}/{category}", response_model=List[str])
async def recommend(
    user_id: int,
    category: str,
    style: Optional[str] = Query(
        None,
        description="风格样式(可选):若传入,则在利用分支对同 style 的候选进行加分",
    ),
):
    """New recommendation endpoint (Milvus + Redis preference vectors).

    Args:
        user_id: User to recommend for (path parameter).
        category: Category to recommend from (path parameter).
        style: Optional style. Per the query description, when given,
            candidates with the same style get a score bonus in the
            exploitation branch.

    Returns:
        A one-element list holding the first item returned by
        ``get_new_recommendations``, or ``[""]`` when it returned nothing.

    Raises:
        HTTPException: Status 500 (detail = error message) on any failure.
    """
    # Local import: fastapi.concurrency re-exports Starlette's threadpool helper.
    from fastapi.concurrency import run_in_threadpool

    try:
        # get_new_recommendations is a plain synchronous call (presumably
        # network I/O against Milvus/Redis, per the docstring). Calling it
        # directly in this coroutine would stall the event loop for every
        # other request. Run it in FastAPI's worker threadpool instead, which is
        # how FastAPI runs plain `def` endpoints.
        results = await run_in_threadpool(get_new_recommendations, user_id, category, style)
        path = results[0] if results else ""
        return [path]
    except Exception as e:
        logger.error("新版推荐接口失败 [user=%s, category=%s]: %s", user_id, category, e, exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.get("/redis/user_pref")
async def get_all_user_preferences():
    """Return every Redis key with the user-preference prefix and its raw value.

    Keys are found with ``Redis.scan_keys(f"{REDIS_KEY_USER_PREF_PREFIX}:*")``.
    Keys whose value reads back falsy (e.g. ``None`` or empty) are omitted from
    the result.

    NOTE(review): this returns the preference data of *all* users, and no
    authentication is visible here. Confirm the route is protected or
    restricted to internal/debug use. Values are read one key at a time, so
    the cost grows linearly with the number of matching keys.

    Raises:
        HTTPException: Status 500 (detail = error message) on any failure.
    """
    # Local import: fastapi.concurrency re-exports Starlette's threadpool helper.
    from fastapi.concurrency import run_in_threadpool

    try:
        # The Redis scan and per-key reads are synchronous calls. Run them off
        # the event loop so this potentially long endpoint does not block
        # other requests.
        return await run_in_threadpool(_read_all_user_preferences)
    except Exception as e:
        logger.error("获取用户偏好数据失败: %s", e, exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e


def _read_all_user_preferences() -> dict:
    """Scan user-preference keys synchronously and map each key to its raw value."""
    # Imported lazily, exactly as the endpoint originally did.
    from app.service.utils.redis_utils import Redis
    from app.service.recommendation_system.config import REDIS_KEY_USER_PREF_PREFIX

    pattern = f"{REDIS_KEY_USER_PREF_PREFIX}:*"
    result = {}
    for key in Redis.scan_keys(pattern):
        value = Redis.read(key)
        if value:
            result[key] = value
    return result