# Preload recommendation resources: score matrices, index mappings and a
# per-(user, category) score cache.
import logging
import time
from collections import defaultdict

import numpy as np
import pymysql

from app.core.config import DB_CONFIG, RECOMMEND_PATH_PREFIX

logger = logging.getLogger(__name__)

# Weights applied to the interaction scores and the min-max scaled feature scores.
_INTERACTION_WEIGHT = 0.7
_FEATURE_WEIGHT = 0.3

# Module-level store filled by load_resources() and precache_user_category().
# NOTE(review): keys not written in this module ("category_sketch_idxs_*",
# "user_*_full") are presumably populated or read elsewhere — confirm.
matrix_data = {
    # Score matrices indexed as [user_row_idx, sketch_column_idxs].
    "interaction_matrix": None,
    "feature_matrix": None,
    # user_id -> row index, and iid -> column index, for each matrix.
    "user_index_interaction": None,
    "sketch_index_interaction": None,
    "user_index_feature": None,
    "sketch_index_feature": None,
    # Inverse of the mapping stored in sketch_to_iid.npy.
    "iid_to_sketch": None,
    # category -> [iid, ...]
    "category_to_iids": None,
    # (user_id, category) -> (weighted interaction scores, weighted feature scores)
    "cached_scores": {},
    # (user_id, category) -> interaction-matrix column indices for the cached scores
    "cached_valid_idxs": {},
    "category_sketch_idxs_inter": None,
    "category_sketch_idxs_feature": None,
    "user_inter_full": dict(),
    "user_feat_full": dict(),
}


def _load_npy(name, *, unwrap=False):
    """Load ``{RECOMMEND_PATH_PREFIX}{name}.npy``.

    With *unwrap*, the loaded array is unwrapped via ``.item()`` (used for the
    files that store a pickled mapping object).
    """
    # NOTE(review): allow_pickle=True executes pickle on load; these files must
    # come from a trusted source.
    array = np.load(f"{RECOMMEND_PATH_PREFIX}{name}.npy", allow_pickle=True)
    return array.item() if unwrap else array


def load_resources():
    """Load all matrices and mappings into ``matrix_data``, then warm the score cache.

    Every file is read before ``matrix_data`` is modified, so a failure while
    reading leaves the previously loaded resources untouched.

    Raises:
        RuntimeError: if loading or pre-caching fails; the original exception
            is chained as ``__cause__``.
    """
    try:
        start_time = time.perf_counter()

        sketch_to_iid = _load_npy("sketch_to_iid", unwrap=True)
        staged = {
            "iid_to_sketch": {iid: sketch for sketch, iid in sketch_to_iid.items()},
            "interaction_matrix": _load_npy("interaction_matrix"),
            "user_index_interaction": _load_npy("user_index_interaction_matrix", unwrap=True),
            "sketch_index_interaction": _load_npy("sketch_index_interaction_matrix", unwrap=True),
            "feature_matrix": _load_npy("feature_matrix"),
            "user_index_feature": _load_npy("user_index_feature_matrix", unwrap=True),
            "sketch_index_feature": _load_npy("sketch_index_feature_matrix", unwrap=True),
        }

        # The file maps iid -> category; invert it into category -> [iid, ...].
        iid_to_category = _load_npy("iid_to_category_interaction_matrix", unwrap=True)
        category_to_iids = defaultdict(list)
        for iid, category in iid_to_category.items():
            category_to_iids[category].append(iid)
        staged["category_to_iids"] = category_to_iids

        # Cached scores were computed from the old matrices; drop them before
        # swapping the new resources in.
        matrix_data["cached_scores"].clear()
        matrix_data["cached_valid_idxs"].clear()
        matrix_data.update(staged)

        logger.info(f"资源加载完成,耗时: {time.perf_counter() - start_time:.2f}秒")

        # Warm the per-(user, category) score cache.
        precache_user_category()
    except Exception as e:
        logger.exception(f"资源加载失败: {str(e)}")
        raise RuntimeError("初始化失败") from e


def _category_sketches(category_iids, sketch_index):
    """Return ``(iids, column_idxs)`` for the *category_iids* present in *sketch_index*.

    Both lists follow *category_iids* order, so selections made from different
    index maps for the same category share one ordering.
    """
    iids = [iid for iid in category_iids if iid in sketch_index]
    return iids, [sketch_index[iid] for iid in iids]


def _min_max_scale(values, eps=1e-8):
    """Linearly rescale *values* towards [0, 1]; *eps* guards against a zero range."""
    low = np.min(values)
    return (values - low) / (np.max(values) - low + eps)


def precache_user_category():
    """Pre-compute weighted scores for every (user, category) pair in the preference log.

    For each pair:
      * the user's interaction-matrix row is sliced to the category's sketches
        and multiplied by 0.7;
      * the user's feature-matrix row is sliced the same way, min-max scaled
        and multiplied by 0.3.
    Results go to ``matrix_data["cached_scores"]``, and the interaction column
    indices go to ``matrix_data["cached_valid_idxs"]``. A pair is cached only
    when both score vectors cover the same sketches in the same order.
    Already-cached pairs are skipped. A failure on one pair is logged and does
    not stop the others.
    """
    required = (
        "interaction_matrix",
        "feature_matrix",
        "user_index_interaction",
        "user_index_feature",
        "sketch_index_interaction",
        "sketch_index_feature",
        "category_to_iids",
    )
    if any(matrix_data[key] is None for key in required):
        logger.warning("资源未加载完成,跳过预缓存")
        return

    start_time = time.perf_counter()
    user_categories = get_all_user_categories()

    interaction_matrix = matrix_data["interaction_matrix"]
    feature_matrix = matrix_data["feature_matrix"]
    user_index_interaction = matrix_data["user_index_interaction"]
    user_index_feature = matrix_data["user_index_feature"]
    sketch_index_interaction = matrix_data["sketch_index_interaction"]
    sketch_index_feature = matrix_data["sketch_index_feature"]
    category_to_iids = matrix_data["category_to_iids"]
    cached_scores = matrix_data["cached_scores"]
    cached_valid_idxs = matrix_data["cached_valid_idxs"]

    # Sketch selections depend only on the category: compute them once per
    # category instead of rescanning both index maps for every user.
    per_category = {}
    precached_count = 0

    for user_id, categories in user_categories.items():
        for category in categories:
            cache_key = (user_id, category)
            if cache_key in cached_scores:
                continue
            try:
                if category not in per_category:
                    category_iids = category_to_iids.get(category, [])
                    per_category[category] = (
                        _category_sketches(category_iids, sketch_index_interaction),
                        _category_sketches(category_iids, sketch_index_feature),
                    )
                (inter_iids, inter_idxs), (feat_iids, feat_idxs) = per_category[category]

                user_idx_inter = user_index_interaction.get(user_id)
                user_idx_feature = user_index_feature.get(user_id)

                if user_idx_inter is not None and inter_idxs:
                    processed_inter = interaction_matrix[user_idx_inter, inter_idxs] * _INTERACTION_WEIGHT
                    scored_inter_iids = inter_iids
                else:
                    processed_inter = np.array([])
                    scored_inter_iids = []

                if user_idx_feature is not None and feat_idxs:
                    raw_feat = feature_matrix[user_idx_feature, feat_idxs]
                    processed_feat = _min_max_scale(raw_feat) * _FEATURE_WEIGHT
                    scored_feat_iids = feat_iids
                else:
                    processed_feat = np.array([])
                    scored_feat_iids = []

                # Equal lengths alone are not enough: two equally long vectors
                # could still refer to different sketches, so compare the iids.
                if scored_inter_iids == scored_feat_iids:
                    cached_scores[cache_key] = (processed_inter, processed_feat)
                    # Copy so cache entries never share the per-category list.
                    cached_valid_idxs[cache_key] = list(inter_idxs)
                    precached_count += 1
            except Exception as e:
                logger.error(f"预缓存失败 (user={user_id}, category={category}): {str(e)}")

    logger.info(f"预缓存完成,共缓存 {precached_count} 个组合,耗时: {time.perf_counter() - start_time:.2f}秒")


def get_all_user_categories():
    """Return ``{account_id: {category, ...}}`` built from ``user_preference_log_prediction``.

    Categories are derived from each row's ``path`` via ``get_category_from_path``.
    If the database query fails, the error is logged and ``{}`` is returned.
    """
    query = """
        SELECT DISTINCT account_id, path
        FROM user_preference_log_prediction
    """
    conn = None
    try:
        conn = pymysql.connect(**DB_CONFIG)
        with conn.cursor() as cursor:
            cursor.execute(query)
            rows = cursor.fetchall()
    except Exception as e:
        logger.exception(f"数据库查询失败: {str(e)}")
        return {}
    finally:
        if conn is not None:
            conn.close()

    user_categories = defaultdict(set)
    for account_id, path in rows:
        user_categories[account_id].add(get_category_from_path(path))
    return dict(user_categories)


def get_category_from_path(path: str) -> str:
    """Derive a category label from a ``/``-separated path.

    The label joins the 3rd and 4th segments with ``_``
    (e.g. ``"a/b/c/d/e"`` -> ``"c_d"``). Returns ``"unknown"`` when *path* is
    not a string or has fewer than four segments.
    """
    if not isinstance(path, str):
        return "unknown"
    parts = path.split("/")
    if len(parts) < 4:
        return "unknown"
    return f"{parts[2]}_{parts[3]}"