173 lines
6.4 KiB
Python
173 lines
6.4 KiB
Python
|
|
# 预加载资源
|
||
|
|
import logging
|
||
|
|
import time
|
||
|
|
from collections import defaultdict
|
||
|
|
|
||
|
|
import numpy as np
|
||
|
|
|
||
|
|
from app.core.config import DB_CONFIG, RECOMMEND_PATH_PREFIX
|
||
|
|
|
||
|
|
# Module logger: the root logger itself (the very object logging.getLogger()
# returns with no name), so records are handled by whatever the application
# configures on root.
logger = logging.root
|
||
|
|
import pymysql
|
||
|
|
|
||
|
|
# Module-level registry shared by the loader / pre-cache functions.
# Matrix and mapping slots start as None; load_resources() assigns the first
# eight of them. Cache slots start as fresh, independent empty dicts.
matrix_data = dict(
    interaction_matrix=None,
    feature_matrix=None,
    user_index_interaction=None,
    sketch_index_interaction=None,
    user_index_feature=None,
    sketch_index_feature=None,
    iid_to_sketch=None,
    category_to_iids=None,
    # (user_id, category) -> (interaction score array, feature score array)
    cached_scores={},
    # (user_id, category) -> list of interaction-matrix column indices
    cached_valid_idxs={},
    # Not assigned by the loader / pre-cache code in this file section.
    category_sketch_idxs_inter=None,
    category_sketch_idxs_feature=None,
    user_inter_full={},
    user_feat_full={},
)
|
||
|
|
|
||
|
|
|
||
|
|
def _load_npy(name, unwrap=False):
    """Load ``{RECOMMEND_PATH_PREFIX}{name}.npy``.

    With ``unwrap=True`` the single object stored in the array is returned via
    ``.item()`` (this is how the dict mappings are saved); otherwise the array
    itself is returned. ``allow_pickle=True`` unpickles object data on load, so
    these files must come from a trusted source.
    """
    arr = np.load(f"{RECOMMEND_PATH_PREFIX}{name}.npy", allow_pickle=True)
    return arr.item() if unwrap else arr


def load_resources():
    """Load all matrices and index mappings into ``matrix_data``, then pre-cache.

    Every artifact is read into local variables first. ``matrix_data`` (and its
    score caches) are only touched once every file has loaded. A failure partway
    through therefore leaves the previously loaded data and its caches intact,
    rather than a half-updated mix with emptied caches.

    Raises:
        RuntimeError: if any file fails to load or pre-caching raises. The
            original exception is chained as ``__cause__`` and its traceback
            is logged.
    """
    try:
        start_time = time.perf_counter()

        sketch_to_iid = _load_npy("sketch_to_iid", unwrap=True)

        # This file is iterated as (iid, category) pairs; invert it into
        # category -> list of iids (in file order).
        iid_to_category = _load_npy("iid_to_category_interaction_matrix", unwrap=True)
        category_to_iids = defaultdict(list)
        for iid, category in iid_to_category.items():
            category_to_iids[category].append(iid)

        loaded = {
            "iid_to_sketch": {v: k for k, v in sketch_to_iid.items()},
            "interaction_matrix": _load_npy("interaction_matrix"),
            "user_index_interaction": _load_npy("user_index_interaction_matrix", unwrap=True),
            "sketch_index_interaction": _load_npy("sketch_index_interaction_matrix", unwrap=True),
            "feature_matrix": _load_npy("feature_matrix"),
            "user_index_feature": _load_npy("user_index_feature_matrix", unwrap=True),
            "sketch_index_feature": _load_npy("sketch_index_feature_matrix", unwrap=True),
            "category_to_iids": category_to_iids,
        }

        # Reached only after every file loaded: drop caches computed against the
        # previous data, then publish the new data.
        matrix_data["cached_scores"].clear()
        matrix_data["cached_valid_idxs"].clear()
        matrix_data.update(loaded)

        logger.info(f"资源加载完成,耗时: {time.perf_counter() - start_time:.2f}秒")

        # Trigger pre-caching with the freshly loaded data.
        precache_user_category()

    except Exception as e:
        logger.exception(f"资源加载失败: {str(e)}")
        raise RuntimeError("初始化失败") from e
|
||
|
|
|
||
|
|
|
||
|
|
def _category_sketch_idxs(category_iids, sketch_index_inter, sketch_index_feat):
    """Return ``(inter_idxs, feat_idxs)``: matrix column indices of one category's sketches.

    *category_iids* is the iid list of one category. The two mappings are
    iid -> column index for the interaction and feature matrices respectively.

    Interaction indices follow the iteration order of *sketch_index_inter*.
    Feature indices are the category iids present in *sketch_index_feat*. They
    are ordered to follow the interaction iid order first, with feature-only
    iids appended in *sketch_index_feat* order. When both mappings contain the
    same iids, position ``i`` of both lists therefore refers to the same iid,
    even if the two mappings were saved in different orders.
    """
    wanted = set(category_iids)  # O(1) membership instead of scanning a list
    inter_iids = [iid for iid in sketch_index_inter if iid in wanted]
    inter_idxs = [sketch_index_inter[iid] for iid in inter_iids]

    feat_idxs = [sketch_index_feat[iid] for iid in inter_iids if iid in sketch_index_feat]
    seen = set(inter_iids)
    feat_idxs.extend(
        idx for iid, idx in sketch_index_feat.items() if iid in wanted and iid not in seen
    )
    return inter_idxs, feat_idxs


def precache_user_category():
    """Pre-compute and cache weighted scores for every (user, category) pair.

    Pairs come from ``get_all_user_categories()``. For each pair not already in
    ``matrix_data["cached_scores"]``, two score arrays are computed:

    * interaction scores are ``interaction_matrix[user, category sketches] * 0.7``;
    * feature scores are the user's ``feature_matrix`` row over the category
      sketches, min-max normalised (``1e-8`` guards a zero range) then ``* 0.3``.

    Caching rules:

    * ``(interaction, feature)`` is stored in ``cached_scores`` and the
      interaction column indices in ``cached_valid_idxs``, but only when both
      arrays have the same length.
    * Two empty arrays have equal length and are cached, e.g. for a user missing
      from both user indexes.
    * A failure for one pair is logged and that pair is skipped.

    Returns early with a warning if any required resource is not loaded.
    """
    required = (
        "interaction_matrix",
        "feature_matrix",
        "user_index_interaction",
        "user_index_feature",
        "sketch_index_interaction",
        "sketch_index_feature",
        "category_to_iids",
    )
    # Check everything the loop dereferences, so a partial load produces one
    # warning instead of one error per (user, category) pair.
    if any(matrix_data[key] is None for key in required):
        logger.warning("资源未加载完成,跳过预缓存")
        return

    start_time = time.perf_counter()
    user_categories = get_all_user_categories()

    interaction_matrix = matrix_data["interaction_matrix"]
    feature_matrix = matrix_data["feature_matrix"]
    user_index_inter = matrix_data["user_index_interaction"]
    user_index_feat = matrix_data["user_index_feature"]
    sketch_index_inter = matrix_data["sketch_index_interaction"]
    sketch_index_feat = matrix_data["sketch_index_feature"]
    category_to_iids = matrix_data["category_to_iids"]
    cached_scores = matrix_data["cached_scores"]
    cached_valid_idxs = matrix_data["cached_valid_idxs"]

    # Sketch indices depend only on the category: compute them once per category
    # instead of rescanning both index mappings for every (user, category) pair.
    idxs_by_category = {}

    precached_count = 0
    for user_id, categories in user_categories.items():
        for category in categories:
            cache_key = (user_id, category)
            if cache_key in cached_scores:
                continue

            try:
                user_idx_inter = user_index_inter.get(user_id)
                user_idx_feat = user_index_feat.get(user_id)

                if category not in idxs_by_category:
                    # .get() so a defaultdict does not grow on unknown categories.
                    idxs_by_category[category] = _category_sketch_idxs(
                        category_to_iids.get(category, []),
                        sketch_index_inter,
                        sketch_index_feat,
                    )
                inter_idxs, feat_idxs = idxs_by_category[category]

                # Interaction scores (weight 0.7).
                if user_idx_inter is not None and inter_idxs:
                    processed_inter = interaction_matrix[user_idx_inter, inter_idxs] * 0.7
                else:
                    processed_inter = np.array([])

                # Feature scores: min-max normalised, then weight 0.3.
                if user_idx_feat is not None and feat_idxs:
                    raw_feat = feature_matrix[user_idx_feat, feat_idxs]
                    lo = np.min(raw_feat)
                    hi = np.max(raw_feat)
                    processed_feat = (raw_feat - lo) / (hi - lo + 1e-8) * 0.3
                else:
                    processed_feat = np.array([])

                if len(processed_inter) == len(processed_feat):
                    cached_scores[cache_key] = (processed_inter, processed_feat)
                    # Copy: the per-category list is shared by every user.
                    cached_valid_idxs[cache_key] = list(inter_idxs)
                    precached_count += 1

            except Exception as e:
                logger.error(f"预缓存失败 (user={user_id}, category={category}): {str(e)}")

    logger.info(f"预缓存完成,共缓存 {precached_count} 个组合,耗时: {time.perf_counter() - start_time:.2f}秒")
|
||
|
|
|
||
|
|
|
||
|
|
def get_all_user_categories():
    """Collect, for every account, the set of categories seen in its logged paths.

    Runs ``SELECT DISTINCT account_id, path`` over
    ``user_preference_log_prediction``. Each path is mapped to a category with
    ``get_category_from_path``.

    Returns:
        dict: ``{account_id: set_of_categories}``. On any exception the error is
        logged and ``{}`` is returned instead. The connection, if it was
        opened, is closed in every case.
    """
    conn = None
    try:
        conn = pymysql.connect(**DB_CONFIG)
        cur = conn.cursor()

        sql = """
        SELECT DISTINCT account_id, path
        FROM user_preference_log_prediction
        """
        cur.execute(sql)

        grouped = {}
        for account_id, path in cur.fetchall():
            category = get_category_from_path(path)
            grouped.setdefault(account_id, set()).add(category)
        return grouped

    except Exception as e:
        logger.error(f"数据库查询失败: {str(e)}")
        return {}
    finally:
        if conn:
            conn.close()
|
||
|
|
|
||
|
|
|
||
|
|
def get_category_from_path(path: str) -> str:
    """Derive a category label from a '/'-separated path.

    The path is split on '/' and the third and fourth segments (indices 2 and 3)
    are joined with an underscore. For example, ``"/a/b/c"`` gives ``"b_c"`` and
    ``"a/b/c/d"`` gives ``"c_d"``.

    Returns ``"unknown"`` when the path has fewer than four segments or is not a
    ``str`` (e.g. ``None`` or ``bytes``).
    """
    try:
        parts = path.split('/')
    except (AttributeError, TypeError):
        # None / numbers have no .split (AttributeError); bytes.split('/') rejects a
        # str separator (TypeError). Narrowed from a bare except so that
        # KeyboardInterrupt / SystemExit are no longer swallowed.
        return "unknown"
    if len(parts) < 4:
        return "unknown"
    return f"{parts[2]}_{parts[3]}"
|