import io
import logging
import sys
import time
from typing import List
import os
import json
import math
import random

import numpy as np
from apscheduler.schedulers.background import BackgroundScheduler
from apscheduler.triggers.cron import CronTrigger
from fastapi import HTTPException, APIRouter

from app.service.recommend.service import load_resources, matrix_data

# Re-wrap stdout so text written to it is encoded as UTF-8 regardless of the
# console's default encoding (several runtime strings in this module are Chinese).
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')
logger = logging.getLogger()

router = APIRouter()

# Weight of the interaction-matrix score in the final ranking score.
_INTERACTION_WEIGHT = 0.7
# Weight of the (user feature / brand feature) blend in the final ranking score.
_FEATURE_WEIGHT = 0.3
# Softmax temperature used for sampling; small values concentrate probability
# mass on the highest-scoring sketches.
_SAMPLING_TEMPERATURE = 0.09
# Guards min-max normalization against division by zero on constant input.
_NORM_EPS = 1e-8


@router.on_event("startup")
async def startup_event():
    """Load recommendation resources once, then schedule a daily refresh at 00:30."""
    # Initial load.
    load_resources()

    # Re-run load_resources every day at 00:30 via APScheduler.
    scheduler = BackgroundScheduler()
    scheduler.add_job(
        load_resources,
        trigger=CronTrigger(hour=0, minute=30),
        name="每日资源刷新"
    )
    scheduler.start()
    logger.info("定时任务已启动")


def softmax(scores):
    """Return the softmax of a sequence of numbers as a list of floats.

    The maximum is subtracted before exponentiating for numerical stability.
    An empty input yields an empty list.
    """
    if len(scores) == 0:
        return []
    max_score = max(scores)
    exp_scores = [math.exp(s - max_score) for s in scores]
    sum_exp = sum(exp_scores)
    return [s / sum_exp for s in exp_scores]


def get_random_recommendations(category: str, num: int) -> List[str]:
    """Sample up to ``num`` distinct sketch paths uniformly at random (cold start).

    Candidates are the iids listed for ``category`` in
    ``matrix_data["category_to_iids"]``; if the category is unknown, every iid
    in ``matrix_data["iid_to_sketch"]`` is a candidate. A non-positive ``num``
    yields an empty list.
    """
    category_to_iids = matrix_data["category_to_iids"]
    if category in category_to_iids:
        candidate_iids = category_to_iids[category]
    else:
        # Unknown category: sample across all known items.
        candidate_iids = list(matrix_data["iid_to_sketch"].keys())

    # Never ask for more items than exist (or a negative count).
    sample_size = max(0, min(num, len(candidate_iids)))
    if sample_size == 0:
        return []
    sampled = np.random.choice(candidate_iids, size=sample_size, replace=False)
    return [matrix_data["iid_to_sketch"][iid] for iid in sampled]


def _category_iid_set(category: str) -> set:
    """Return the iids registered for ``category`` as a set (empty if unknown)."""
    return set(matrix_data["category_to_iids"].get(category, []))


def _category_sketch_indices(index_map, category_iids: set) -> list:
    """Return the indices from ``index_map`` (iid -> matrix index) whose iid is in the category.

    The result follows the iteration order of ``index_map``.
    """
    return [idx for iid, idx in index_map.items() if iid in category_iids]


def _min_max_normalize(values):
    """Min-max scale ``values`` to roughly [0, 1]; the epsilon avoids 0/0 on constant input."""
    low = np.min(values)
    return (values - low) / (np.max(values) - low + _NORM_EPS)


def _calibrated_softmax(scores, temperature: float = 1.0):
    """Temperature-scaled softmax over a 1-D score array, shifted by its max for stability."""
    scaled = np.asarray(scores, dtype=float) / temperature
    exps = np.exp(scaled - np.max(scaled))
    return exps / np.sum(exps)


def _compute_user_scores(user_id: int, category: str):
    """Compute per-sketch scores for one user within one category (cache-miss path).

    Returns ``(processed_inter, processed_feat, valid_idxs_inter, valid_idxs_feature)``:

    - ``processed_inter``: interaction-matrix scores for the columns in
      ``valid_idxs_inter``, multiplied by ``_INTERACTION_WEIGHT``.
    - ``processed_feat``: min-max normalized feature-matrix scores for the
      columns in ``valid_idxs_feature``, or an empty array when unavailable.
    - ``valid_idxs_inter`` is empty when there is nothing to score (unknown
      category or no interaction row for the user); callers must handle that.
    """
    user_idx_inter = matrix_data["user_index_interaction"].get(user_id)
    user_idx_feature = matrix_data["user_index_feature"].get(user_id)
    # A set keeps membership tests O(1) while scanning the index maps.
    category_iids = _category_iid_set(category)

    valid_idxs_inter = _category_sketch_indices(matrix_data["sketch_index_interaction"], category_iids)
    valid_idxs_feature = _category_sketch_indices(matrix_data["sketch_index_feature"], category_iids)

    if user_idx_inter is None or not valid_idxs_inter:
        return np.array([]), np.array([]), [], valid_idxs_feature

    processed_inter = matrix_data["interaction_matrix"][user_idx_inter, valid_idxs_inter] * _INTERACTION_WEIGHT

    if user_idx_feature is not None and valid_idxs_feature:
        processed_feat = _min_max_normalize(
            matrix_data["feature_matrix"][user_idx_feature, valid_idxs_feature]
        )
    else:
        processed_feat = np.array([])

    return processed_inter, processed_feat, valid_idxs_inter, valid_idxs_feature


def _combine_scores(processed_inter, processed_feat, category: str, brand_id, brand_scale,
                    valid_idxs_feature):
    """Merge interaction, user-feature and (when usable) brand-feature scores.

    With usable brand data the feature term is
    ``(1 - brand_scale) * user_feature + brand_scale * brand_feature``.
    Otherwise it is the user-feature scores alone, and it is dropped when no
    user-feature scores exist. ``valid_idxs_feature`` may be ``None`` (cache
    hit); it is then recomputed from ``category`` only if the brand path needs it.
    """
    # NOTE(review): processed_inter is indexed by interaction-matrix columns, while
    # processed_feat and the brand scores are indexed by feature-matrix columns.
    # Element-wise addition assumes both index maps list the category's iids in
    # the same order and count; confirm this holds in load_resources.
    if brand_id is not None:
        brand_idx_feature = matrix_data["brand_index_map"].get(brand_id)
        if matrix_data["brand_feature_matrix"].size > 0 and brand_idx_feature is not None:
            if valid_idxs_feature is None:
                valid_idxs_feature = _category_sketch_indices(
                    matrix_data["sketch_index_feature"], _category_iid_set(category)
                )
            if valid_idxs_feature:
                processed_brand_feat = _min_max_normalize(
                    matrix_data["brand_feature_matrix"][brand_idx_feature, valid_idxs_feature]
                )
                # No user-feature scores: substitute zeros so the shapes line up.
                if processed_feat.size == 0:
                    processed_feat = np.zeros_like(processed_brand_feat)
                return processed_inter + _FEATURE_WEIGHT * (
                    (1 - brand_scale) * processed_feat + brand_scale * processed_brand_feat
                )

    # Brand information unavailable.
    if processed_feat.size > 0:
        return processed_inter + _FEATURE_WEIGHT * processed_feat
    return processed_inter


def _indices_to_sketches(chosen_idxs) -> List[str]:
    """Map interaction-matrix column indices to sketch paths.

    ``iid_to_sketch`` is keyed by iid, while the candidates are values of
    ``sketch_index_interaction`` (iid -> index). Each index is therefore
    translated back to its iid before the lookup. Only the chosen indices are
    inverted.
    """
    wanted = set(chosen_idxs)
    idx_to_iid = {
        idx: iid
        for iid, idx in matrix_data["sketch_index_interaction"].items()
        if idx in wanted
    }
    return [matrix_data["iid_to_sketch"][idx_to_iid[idx]] for idx in chosen_idxs]


@router.get("/recommend/{user_id}/{category}/{num_recommendations}/{brand_id}/{brand_scale}", response_model=List[str])
async def get_recommendations(user_id: int, category: str, brand_id: int, brand_scale: float,
                              num_recommendations: int = 10):
    """
    Recommend sketches for a user within a category, optionally blended with brand features.

    Users missing from either the interaction or the feature index, and
    categories with no interaction candidates, receive uniformly random
    recommendations instead. Otherwise distinct sketches are sampled from a
    temperature-scaled softmax over the combined scores. Up to
    ``num_recommendations`` are returned, fewer if not enough candidates have
    non-zero probability.

    :param user_id: e.g. 4
    :param category: e.g. female_skirt
    :param num_recommendations: e.g. 1
    :param brand_id: brand looked up in ``brand_index_map``
    :param brand_scale: weight of brand features relative to user features in the feature term
    :return: e.g. [
        "aida-sys-image/images/female/skirt/903000017.jpg"
    ]
    :raises HTTPException: status 500 on any unexpected error.
    """
    # NOTE(review): this CPU-bound handler is `async def`, so it runs on the event
    # loop and blocks other requests while scoring; consider a plain `def`.
    try:
        start_time = time.time()
        cache_key = (user_id, category)

        # A user missing from either matrix gets random recommendations.
        user_exists_inter = user_id in matrix_data["user_index_interaction"]
        user_exists_feat = user_id in matrix_data["user_index_feature"]
        if not (user_exists_inter and user_exists_feat):
            logger.info("用户 %s 数据不完整,触发随机推荐", user_id)
            return get_random_recommendations(category, num_recommendations)

        if cache_key in matrix_data["cached_scores"]:
            processed_inter, processed_feat = matrix_data["cached_scores"][cache_key]
            valid_sketch_idxs = matrix_data["cached_valid_idxs"][cache_key]
            # Feature indices are not cached; recomputed only if the brand path needs them.
            valid_sketch_idxs_feature = None
        else:
            (processed_inter, processed_feat,
             valid_sketch_idxs, valid_sketch_idxs_feature) = _compute_user_scores(user_id, category)
            if len(valid_sketch_idxs) > 0:
                matrix_data["cached_scores"][cache_key] = (processed_inter, processed_feat)
                matrix_data["cached_valid_idxs"][cache_key] = valid_sketch_idxs

        # Nothing to score (e.g. unknown category): fall back to random recommendations.
        if len(valid_sketch_idxs) == 0:
            logger.info("类别 %s 无可用候选,触发随机推荐", category)
            return get_random_recommendations(category, num_recommendations)

        final_scores = _combine_scores(
            processed_inter, processed_feat, category, brand_id, brand_scale, valid_sketch_idxs_feature
        )

        # Probability sampling with a low-temperature softmax.
        probs = _calibrated_softmax(final_scores, _SAMPLING_TEMPERATURE)
        # Sampling without replacement raises if asked for more items than there are
        # non-zero probabilities, and exp() can underflow at this temperature.
        sample_size = max(0, min(num_recommendations, len(valid_sketch_idxs), int(np.count_nonzero(probs))))
        chosen_positions = np.random.choice(
            len(valid_sketch_idxs),
            size=sample_size,
            p=probs,
            replace=False
        )
        recommendations = _indices_to_sketches([valid_sketch_idxs[pos] for pos in chosen_positions])

        logger.info("推荐生成完成,耗时: %.2f秒", time.time() - start_time)
        return recommendations

    except Exception as e:
        logger.exception("推荐失败: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e