feat(新功能): sketch 推荐算法

fix(修复bug):
docs(文档变更):
refactor(重构):
test(增加测试):
This commit is contained in:
zhouchengrong
2025-02-28 16:26:44 +08:00
parent 08f9f7ebf7
commit a2e78f3dd5
6 changed files with 755 additions and 5 deletions

View File

@@ -0,0 +1,118 @@
import io
import logging
import sys
import time
from typing import List
import numpy as np
from apscheduler.schedulers.background import BackgroundScheduler
from apscheduler.triggers.cron import CronTrigger
from fastapi import HTTPException, APIRouter
from app.service.recommend.service import load_resources, matrix_data
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')
logger = logging.getLogger()
router = APIRouter()
@router.on_event("startup")
async def startup_event():
# 初始加载
load_resources()
# 配置定时任务
scheduler = BackgroundScheduler()
scheduler.add_job(
load_resources,
trigger=CronTrigger(hour=0, minute=30),
name="每日资源刷新"
)
scheduler.start()
logger.info("定时任务已启动")
@router.get("/recommend/{user_id}/{category}/{num_recommendations}", response_model=List[str])
async def get_recommendations(user_id: int, category: str, num_recommendations: int = 10):
"""
:param user_id: 4
:param category: female_skirt
:param num_recommendations: 1
:return:
[
"aida-sys-image/images/female/skirt/903000017.jpg"
]
"""
try:
start_time = time.time()
cache_key = (user_id, category)
# 检查缓存
if cache_key in matrix_data["cached_scores"]:
processed_inter, processed_feat = matrix_data["cached_scores"][cache_key]
valid_sketch_idxs_inter = matrix_data["cached_valid_idxs"][cache_key]
else:
# 实时计算逻辑(同原代码)
user_idx_inter = matrix_data["user_index_interaction"].get(user_id)
user_idx_feature = matrix_data["user_index_feature"].get(user_id)
category_iids = matrix_data["category_to_iids"].get(category, [])
valid_sketch_idxs_inter = [
idx for iid, idx in matrix_data["sketch_index_interaction"].items()
if iid in category_iids
]
# 处理交互分数
raw_inter_scores = []
if user_idx_inter is not None and valid_sketch_idxs_inter:
raw_inter_scores = matrix_data["interaction_matrix"][user_idx_inter, valid_sketch_idxs_inter]
processed_inter = raw_inter_scores * 0.7
# 处理特征分数
valid_sketch_idxs_feature = [
idx for iid, idx in matrix_data["sketch_index_feature"].items()
if iid in category_iids
]
raw_feat_scores = []
if user_idx_feature is not None and valid_sketch_idxs_feature:
raw_feat_scores = matrix_data["feature_matrix"][user_idx_feature, valid_sketch_idxs_feature]
raw_feat_scores = (raw_feat_scores - np.min(raw_feat_scores)) / (
np.max(raw_feat_scores) - np.min(raw_feat_scores) + 1e-8)
processed_feat = raw_feat_scores * 0.3
else:
processed_feat = np.array([])
# 更新缓存
matrix_data["cached_scores"][cache_key] = (processed_inter, processed_feat)
matrix_data["cached_valid_idxs"][cache_key] = valid_sketch_idxs_inter
# 合并分数
final_scores = processed_inter + processed_feat
valid_sketch_idxs = matrix_data["cached_valid_idxs"][cache_key]
# 概率采样
scores = np.array(final_scores)
# 调整后的概率转换带温度控制的softmax
def calibrated_softmax(scores, temperature=1.0):
scores = scores / temperature
scale = scores - max(scores)
exps = np.exp(scale)
return exps / np.sum(exps)
probs = calibrated_softmax(scores, 0.07)
chosen_indices = np.random.choice(
len(valid_sketch_idxs),
size=min(num_recommendations, len(valid_sketch_idxs)),
p=probs,
replace=False
)
recommendations = [matrix_data["iid_to_sketch"][valid_sketch_idxs[idx]] for idx in chosen_indices]
logger.info(f"推荐生成完成,耗时: {time.time() - start_time:.2f}")
return recommendations
except Exception as e:
logger.error(f"推荐失败: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail=str(e))

View File

@@ -10,6 +10,7 @@ from app.api import api_generate_image
from app.api import api_image2sketch
from app.api import api_prompt_generation
from app.api import api_super_resolution
from app.api import api_recommendation
from app.api import api_test
router = APIRouter()
@@ -26,3 +27,4 @@ router.include_router(api_image2sketch.router, tags=['api_image2sketch'], prefix
router.include_router(api_brighten.router, tags=['api_brighten'], prefix="/api")
router.include_router(api_query_image.router, tags=['api_query_image'], prefix="/api")
router.include_router(api_brand_dna.router, tags=['api_brand_dna'], prefix="/api")
router.include_router(api_recommendation.router, tags=['api_recommendation'], prefix="/api")

View File

@@ -25,10 +25,13 @@ if DEBUG:
LOGS_PATH = "logs/"
CATEGORY_PATH = "service/attribute/config/descriptor/category/category_dis.csv"
SEG_CACHE_PATH = "../seg_cache/"
RECOMMEND_PATH_PREFIX = "service/recommend/"
else:
LOGS_PATH = "app/logs/"
CATEGORY_PATH = "app/service/attribute/config/descriptor/category/category_dis.csv"
SEG_CACHE_PATH = "/seg_cache/"
RECOMMEND_PATH_PREFIX = "app/service/recommend/"
# RABBITMQ_ENV = "" # 生产环境
RABBITMQ_ENV = "-dev" # 开发环境
@@ -36,7 +39,6 @@ RABBITMQ_ENV = "-dev" # 开发环境
JAVA_STREAM_API_URL = os.getenv("JAVA_STREAM_API_URL", "https://api.aida.com.hk/api/third/party/receiveDesignResults")
settings = Settings()
# minio 配置
@@ -114,7 +116,6 @@ GMV_MODEL_NAME = 'multi_view'
GMV_RABBITMQ_QUEUES = os.getenv("GMV_RABBITMQ_QUEUES", f"GenerateMultiView{RABBITMQ_ENV}")
GI_MINIO_BUCKET = "aida-users"
GI_RABBITMQ_QUEUES = os.getenv("GI_RABBITMQ_QUEUES", f"GenerateImage{RABBITMQ_ENV}")
GI_SYS_IMAGE_URL = "aida-sys-image/generate_image/white_image.jpg"
@@ -191,3 +192,23 @@ PRIORITY_DICT = {
}
QWEN_API_KEY = "sk-a6bdf594e1f54a4aa3e9d4d48f8c661f"
DB_CONFIG = {
"host": "18.167.251.121",
"port": 3306,
"user": "root",
"password": "QWa998345",
"database": "aida",
"charset": "utf8mb4"
}
TABLE_CATEGORIES = {
"female_dress": "female/dress",
"female_outwear": "female/outwear",
"female_trousers": "female/trousers",
"female_skirt": "female/skirt",
"female_blouse": "female/blouse",
"male_tops": "male/tops",
"male_bottoms": "male/bottoms",
"male_outwear": "male/outwear"
}

View File

@@ -1,15 +1,17 @@
import logging.config
from http.client import HTTPException
from fastapi.responses import JSONResponse
from fastapi import FastAPI, HTTPException, Request
import uvicorn
from apscheduler.schedulers.background import BackgroundScheduler
from apscheduler.triggers.cron import CronTrigger
from fastapi import FastAPI
from fastapi import HTTPException, Request
from fastapi.responses import JSONResponse
from app.api.api_route import router
from app.core.config import settings
from app.core.record_api_count import count_api_calls
from app.schemas.response_template import ResponseModel
from app.service.recommend.service import load_resources
from logging_env import LOGGER_CONFIG_DICT
logging.config.dictConfig(LOGGER_CONFIG_DICT)
@@ -17,6 +19,8 @@ logging.getLogger("pika").setLevel(logging.WARNING)
from starlette.middleware.cors import CORSMiddleware
logger = logging.getLogger(__name__)
def get_application() -> FastAPI:
application = FastAPI(
@@ -51,5 +55,7 @@ async def http_exception_handler(request: Request, exc: HTTPException):
)
if __name__ == '__main__':
uvicorn.run(app, host="0.0.0.0", port=8000)

View File

@@ -0,0 +1,431 @@
import pymysql
import numpy as np
from apscheduler.schedulers.blocking import BlockingScheduler
import os
import logging
from collections import defaultdict
import torch
from torchvision import models, transforms
from minio import Minio
from PIL import Image
import io
import seaborn as sns
import matplotlib.pyplot as plt
from scipy.sparse import csr_matrix
import matplotlib.font_manager as fm
from scipy import sparse
from app.core.config import DB_CONFIG, TABLE_CATEGORIES, RECOMMEND_PATH_PREFIX
# 自动选择可用字体
try:
# 尝试加载常见中文字体
font_path = fm.findfont(fm.FontProperties(family=['Microsoft YaHei', 'SimHei', 'WenQuanYi Zen Hei']))
plt.rcParams['font.sans-serif'] = fm.FontProperties(fname=font_path).get_name()
except:
# 回退到英文字体
plt.rcParams['font.sans-serif'] = ['DejaVu Sans']
plt.rcParams['axes.unicode_minus'] = False
# 检查系统中可用的字体并选择支持中文的字体
font_path = fm.findfont(fm.FontProperties(family='Microsoft YaHei')) # 或其他支持中文的字体
plt.rcParams['font.sans-serif'] = ['Microsoft YaHei'] # 设置为 Microsoft YaHei
plt.rcParams['axes.unicode_minus'] = False # 解决负号显示问题
# 配置日志记录
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
filename='scheduler.log'
)
# MinIO 配置信息
minio_client = Minio(
"www.minio.aida.com.hk:12024", # MinIO Endpoint
access_key="admin", # Access Key
secret_key="Aidlab123123!", # Secret Key
secure=True # 使用https
)
# 预加载系统sketch特征向量
SYSTEM_FEATURES = np.load(f'{RECOMMEND_PATH_PREFIX}sketch_feature_dict.npy', allow_pickle=True).item()
# 保存sketch_to_iid到文件
def save_sketch_to_iid():
"""保存sketch到iid的映射"""
sketch_to_iid = {sketch_path: iid for iid, sketch_path in enumerate(SYSTEM_FEATURES.keys(), start=1)}
np.save('sketch_to_iid.npy', sketch_to_iid)
print("sketch_to_iid 已保存")
# 从文件加载sketch_to_iid
def load_sketch_to_iid():
"""加载保存的sketch到iid的映射"""
if os.path.exists('sketch_to_iid.npy'):
sketch_to_iid = np.load('sketch_to_iid.npy', allow_pickle=True).item()
print("sketch_to_iid 已加载")
return sketch_to_iid
else:
# 如果文件不存在,则生成并保存
print("sketch_to_iid 文件不存在,正在生成并保存...")
save_sketch_to_iid()
return np.load('sketch_to_iid.npy', allow_pickle=True).item()
# 使用load_sketch_to_iid来获取映射
sketch_to_iid = load_sketch_to_iid()
# 在代码中其他地方使用sketch_to_iid
# print(f"Total sketches: {len(sketch_to_iid)}")
# 定义图像预处理与ResNet训练时的预处理一致
transform = transforms.Compose([
transforms.Resize((224, 224)), # ResNet 要求 224x224 的输入
transforms.ToTensor(), # 转换为 Tensor
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), # 标准化
])
# 加载预训练的 ResNet 模型 (ResNet50)
resnet_model = models.resnet50(pretrained=True)
modules = list(resnet_model.children())[:-1] # 移除最后的全连接层
resnet_model = torch.nn.Sequential(*modules)
resnet_model.eval() # 设置为评估模式
# 从 MinIO 获取图片并进行预处理
def get_sketch_image_from_minio(sketch_path):
"""
从 MinIO 获取 sketch 图像并预处理
"""
# 分割路径,获取桶名和文件路径
path_parts = sketch_path.split('/', 1) # 根据第一个斜杠分割,得到桶名和路径
bucket_name = path_parts[0] # 桶名
file_name = path_parts[1] # 文件路径(从第二部分开始)
try:
# 获取文件
obj = minio_client.get_object(bucket_name, file_name)
img_data = obj.read() # 读取图像数据
img = Image.open(io.BytesIO(img_data)) # 将数据转为图像对象
img = transform(img) # 对图像进行预处理
return img.unsqueeze(0) # 扩展维度以适应批量处理
except Exception as e:
print(f"Error fetching image for {sketch_path}: {e}")
return None
def extract_feature_vector_from_resnet(sketch_path):
"""
提取 sketch 图像的特征向量
"""
# 从 MinIO 获取图像并预处理
img_tensor = get_sketch_image_from_minio(sketch_path)
if img_tensor is None:
return np.zeros(2048) # 如果获取失败,返回零向量
with torch.no_grad(): # 在不需要计算梯度的情况下进行推断
feature_vector = resnet_model(img_tensor) # 获取 ResNet 的输出
return feature_vector.squeeze().cpu().numpy() # 转换为 NumPy 数组并去掉 batch 维度
def update_user_matrices():
"""每天更新用户交互次数矩阵和特征向量矩阵"""
conn = None
try:
conn = pymysql.connect(**DB_CONFIG)
cursor = conn.cursor()
# 修改后的查询语句移除category过滤
cursor.execute("""
SELECT account_id, path, COUNT(*) as like_count
FROM user_preference_log_test
GROUP BY account_id, path
""")
user_data = cursor.fetchall()
logging.info(f"成功读取{len(user_data)}条用户偏好记录")
# 计算矩阵
interaction_matrix, raw_counts_sparse, user_index_interaction_matrix, sketch_index_interaction_matrix, iid_to_category_interaction_matrix = calculate_interaction_matrix(user_data)
# visualize_sparse_matrix(raw_counts_sparse,'交互次数矩阵', 'interaction_frequency_matrix.png')
# visualize_sparse_matrix(interaction_matrix, '交互次数得分矩阵', 'interaction_score_matrix.png')
# plot_interaction_count_matrix(raw_counts_sparse)
# feature_matrix, iid_to_category_feature_matrix, user_index_feature_matrix, sketch_index_feature_matrix = calculate_feature_matrix(user_data)
feature_matrix, user_index_feature_matrix, sketch_index_feature_matrix, iid_to_category_feature_matrix = calculate_feature_matrix(user_data)
# visualize_sparse_matrix(feature_matrix, '系统sketch与用户category平均特征向量关联度矩阵', 'correlation_matrix.png')
# 存储矩阵
np.save(f"{RECOMMEND_PATH_PREFIX}interaction_matrix.npy", interaction_matrix)
np.save(f"{RECOMMEND_PATH_PREFIX}feature_matrix.npy", feature_matrix)
#
np.save(f"{RECOMMEND_PATH_PREFIX}iid_to_category_interaction_matrix.npy", iid_to_category_interaction_matrix)
np.save(f"{RECOMMEND_PATH_PREFIX}user_index_interaction_matrix.npy", user_index_interaction_matrix)
#
np.save(f"{RECOMMEND_PATH_PREFIX}iid_to_category_feature_matrix.npy", iid_to_category_feature_matrix)
np.save(f"{RECOMMEND_PATH_PREFIX}user_index_feature_matrix.npy", user_index_feature_matrix)
#
np.save(f"{RECOMMEND_PATH_PREFIX}sketch_index_interaction_matrix.npy", sketch_index_interaction_matrix)
np.save(f"{RECOMMEND_PATH_PREFIX}sketch_index_feature_matrix.npy", sketch_index_feature_matrix)
# logging.info("矩阵更新完成")
except Exception as e:
logging.error(f"定时任务执行失败: {str(e)}", exc_info=True)
finally:
if conn:
conn.close()
def plot_interaction_count_matrix(interaction_count_matrix):
"""绘制交互次数矩阵的分布图(热图),不隐藏零值"""
try:
if not isinstance(interaction_count_matrix, csr_matrix):
sparse_matrix = csr_matrix(interaction_count_matrix)
else:
sparse_matrix = interaction_count_matrix
# 转换为密集矩阵
try:
dense_matrix = sparse_matrix.toarray()
except MemoryError:
logging.error("内存不足,无法转换为密集矩阵")
return
# 自动检测可用中文字体
try:
font_path = fm.findfont(fm.FontProperties(family='sans-serif', style='normal'))
plt.rcParams['font.sans-serif'] = fm.FontProperties(fname=font_path).get_name()
except:
plt.rcParams['font.sans-serif'] = ['DejaVu Sans'] # 回退到英文字体
plt.rcParams['axes.unicode_minus'] = False
# 处理大矩阵的显示,限制显示范围
if dense_matrix.shape[0] > 1000 or dense_matrix.shape[1] > 1000:
dense_matrix = dense_matrix[:1000, :1000] # 只绘制前1000行列
plt.figure(figsize=(15, 10))
# 使用 `cmap` 来设置颜色,零值可以使用特定颜色,调整 `vmin` 和 `vmax` 让热图更具对比
sns.heatmap(
dense_matrix,
cmap="Blues", # 可以选择不同的颜色映射,"Blues" 或 "YlGnBu"
annot=False, # 关闭标注
cbar_kws={"label": "Interaction Count"}, # 添加颜色条标签
linewidths=0.5,
vmin=0, # 设置最小值,确保零值明显
vmax=np.max(dense_matrix) # 设置最大值,保持颜色映射的合理性
)
plt.title('User-Sketch Interaction Matrix (With Zero Entries)')
plt.xlabel('Sketch Index')
plt.ylabel('User Index')
plt.savefig('interaction_heatmap_with_zeros.png', dpi=150, bbox_inches='tight')
plt.close()
logging.info("热图已保存为 interaction_heatmap_with_zeros.png")
except Exception as e:
logging.error(f"绘图失败: {str(e)}", exc_info=True)
def visualize_sparse_matrix(matrix, title='Non-zero Interactions (Scatter Plot)', filename="scatter_figure_interaction.png"):
if not sparse.issparse(matrix):
# 转换为稀疏矩阵
matrix = sparse.csr_matrix(matrix)
# 获取非零元素的坐标和值
rows, cols = matrix.nonzero()
values = matrix.data
# 绘制散点图
plt.figure(figsize=(24, 20))
plt.scatter(cols, rows, c=values, cmap='coolwarm', alpha=0.7, s=1)
plt.colorbar(label='Interaction Count')
plt.title(title)
plt.xlabel('Item Index')
plt.ylabel('Item Index')
plt.savefig(filename)
def calculate_interaction_matrix(user_data):
"""基于新表结构的交互次数矩阵计算仅系统sketch"""
# 获取所有用户ID
all_users = set()
for account_id, path, like_count in user_data:
category = get_category_from_path(path)
if category not in TABLE_CATEGORIES.keys():
continue
all_users.add(account_id)
# 获取所有系统sketch的iid
system_sketch_iids = {sketch_to_iid[path] for path in SYSTEM_FEATURES.keys() if path in sketch_to_iid}
# 创建映射关系
user_index = {uid: idx for idx, uid in enumerate(sorted(all_users))}
sketch_index = {iid: idx for idx, iid in enumerate(sorted(system_sketch_iids))}
# 初始化双矩阵:归一化矩阵(密集)和原始计数矩阵(稀疏)
interaction_matrix = np.zeros((len(all_users), len(sketch_index))) # 归一化矩阵
data, rows, cols = [], [], [] # 用于构建稀疏矩阵的COO格式数据
# 预计算用户最大交互次数
user_max_likes = defaultdict(int)
for account_id, path, like_count in user_data:
if sketch_to_iid.get(path) in system_sketch_iids:
user_max_likes[account_id] = max(user_max_likes[account_id], like_count)
# 填充矩阵
for account_id, path, like_count in user_data:
sketch_iid = sketch_to_iid.get(path)
if sketch_iid not in system_sketch_iids:
continue
user_idx = user_index[account_id]
sketch_idx = sketch_index[sketch_iid]
# 填充稀疏矩阵数据
data.append(like_count)
rows.append(user_idx)
cols.append(sketch_idx)
# 归一化计算
max_like = user_max_likes.get(account_id, 1)
interaction_matrix[user_idx, sketch_idx] = np.log1p(1 + like_count) / np.log1p(1 + max_like)
# 构建稀疏矩阵CSR格式适合快速行操作
interaction_count_matrix = csr_matrix((data, (rows, cols)), shape=(len(all_users), len(sketch_index)))
return interaction_matrix, interaction_count_matrix, user_index, sketch_index, {iid: get_category_from_path(path) for path, iid in sketch_to_iid.items()}
def calculate_feature_matrix(user_data):
"""基于新表结构的特征矩阵计算,返回用户与系统草图的相似度矩阵(加权平均)"""
# 用户特征数据存储结构:{(account_id, category): {sketch_iid: [(feature_vector, weight)]}}
user_feature_weights = defaultdict(lambda: defaultdict(list))
# 初始化所有用户和系统草图集合
all_users = set()
all_system_sketches = set(SYSTEM_FEATURES.keys())
# ==== 第一遍遍历:收集特征向量和权重 ====
for account_id, path, like_count in user_data:
category = get_category_from_path(path)
if category not in TABLE_CATEGORIES.keys():
continue
sketch_iid = sketch_to_iid.get(path)
if not sketch_iid:
continue
# 记录用户
all_users.add(account_id)
# 提取特征并记录权重like_count
if path in SYSTEM_FEATURES: # 系统草图
feature = SYSTEM_FEATURES[path]
weight = like_count # 使用like_count作为权重
user_feature_weights[(account_id, category)][sketch_iid].append((feature, weight))
else: # 用户草图
feature = extract_feature_vector_from_resnet(path)
weight = like_count
user_feature_weights[(account_id, category)][sketch_iid].append((feature, weight))
# ==== 第二遍遍历收集所有系统草图iid ====
system_sketch_iids = set()
for sketch_path in SYSTEM_FEATURES:
if iid := sketch_to_iid.get(sketch_path):
system_sketch_iids.add(iid)
# ==== 创建索引映射 ====
user_list = sorted(all_users)
sketch_list = sorted(system_sketch_iids)
user_index = {uid: idx for idx, uid in enumerate(user_list)}
sketch_index = {iid: idx for idx, iid in enumerate(sketch_list)}
# ==== 初始化特征矩阵 ====
feature_matrix = np.zeros((len(user_list), len(sketch_list)))
# ==== 预计算加权平均特征向量 ====
user_avg_features = {}
for (account_id, category), sketches in user_feature_weights.items():
try:
# 展平所有特征向量和权重
all_features_weights = [(vec, weight) for vec_list in sketches.values() for vec, weight in vec_list]
if len(all_features_weights) == 0:
continue
# 计算总权重
total_weight = sum(weight for _, weight in all_features_weights)
if total_weight <= 0: # 防止除零错误
total_weight = 1.0
# 加权平均计算
weighted_sum = np.zeros_like(all_features_weights[0][0]) # 获取特征向量维度
for vec, weight in all_features_weights:
weighted_sum += vec * weight
avg_vec = weighted_sum / total_weight
user_avg_features[(account_id, category)] = avg_vec
except Exception as e:
logging.warning(f"用户({account_id},{category})加权特征计算失败: {str(e)}")
continue
# ==== 计算相似度并填充矩阵 ====
for sketch_path, sys_vector in SYSTEM_FEATURES.items():
sketch_iid = sketch_to_iid.get(sketch_path)
system_sketch_category = get_category_from_path(sketch_path)
if not sketch_iid or sketch_iid not in sketch_index:
continue
sketch_col = sketch_index[sketch_iid]
# 遍历所有用户
for account_id in all_users:
user_row = user_index.get(account_id)
if user_row is None:
continue
# 获取用户加权平均特征向量
try:
# 直接通过复合键获取用户特征向量
user_vec = user_avg_features[(account_id, system_sketch_category)]
except KeyError:
# 该用户在此类别下无特征数据
continue
# 计算余弦相似度
cos_sim = cosine_similarity(user_vec, sys_vector)
feature_matrix[user_row, sketch_col] = cos_sim
return feature_matrix, user_index, sketch_index, {iid: get_category_from_path(path) for path, iid in sketch_to_iid.items()}
def get_category_from_path(path):
"""从path字段解析类别"""
try:
parts = path.split('/')
if len(parts) >= 2:
return f"{parts[2]}_{parts[3]}"
return "unknown"
except:
return "unknown"
def cosine_similarity(vec1, vec2):
"""计算余弦相似度(增加零值处理)"""
norm = np.linalg.norm(vec1) * np.linalg.norm(vec2)
return np.dot(vec1, vec2) / (norm + 1e-10) if norm != 0 else 0.0
if __name__ == "__main__":
try:
update_user_matrices()
# scheduler = BlockingScheduler()
# scheduler.add_job(update_user_matrices, 'cron', hour=12, timezone='Asia/Shanghai')
# logging.info("定时任务已启动每天12:00执行")
# scheduler.start()
except KeyboardInterrupt:
logging.info("定时任务已停止")
except Exception as e:
logging.error(f"调度器启动失败: {str(e)}", exc_info=True)

View File

@@ -0,0 +1,172 @@
# 预加载资源
import logging
import time
from collections import defaultdict
import numpy as np
from app.core.config import DB_CONFIG, RECOMMEND_PATH_PREFIX
logger = logging.getLogger()
import pymysql
matrix_data = {
"interaction_matrix": None,
"feature_matrix": None,
"user_index_interaction": None,
"sketch_index_interaction": None,
"user_index_feature": None,
"sketch_index_feature": None,
"iid_to_sketch": None,
"category_to_iids": None,
"cached_scores": {},
"cached_valid_idxs": {},
"category_sketch_idxs_inter": None,
"category_sketch_idxs_feature": None,
"user_inter_full": dict(),
"user_feat_full": dict(),
}
def load_resources():
"""加载所有矩阵和映射关系,并触发预缓存"""
try:
start_time = time.time()
# 清空缓存
matrix_data["cached_scores"].clear()
matrix_data["cached_valid_idxs"].clear()
# 加载数据
sketch_to_iid = np.load(f'{RECOMMEND_PATH_PREFIX}sketch_to_iid.npy', allow_pickle=True).item()
matrix_data["iid_to_sketch"] = {v: k for k, v in sketch_to_iid.items()}
matrix_data["interaction_matrix"] = np.load(f"{RECOMMEND_PATH_PREFIX}interaction_matrix.npy", allow_pickle=True)
matrix_data["user_index_interaction"] = np.load(f"{RECOMMEND_PATH_PREFIX}user_index_interaction_matrix.npy", allow_pickle=True).item()
matrix_data["sketch_index_interaction"] = np.load(f"{RECOMMEND_PATH_PREFIX}sketch_index_interaction_matrix.npy",
allow_pickle=True).item()
matrix_data["feature_matrix"] = np.load(f"{RECOMMEND_PATH_PREFIX}feature_matrix.npy", allow_pickle=True)
matrix_data["user_index_feature"] = np.load(f"{RECOMMEND_PATH_PREFIX}user_index_feature_matrix.npy", allow_pickle=True).item()
matrix_data["sketch_index_feature"] = np.load(f"{RECOMMEND_PATH_PREFIX}sketch_index_feature_matrix.npy", allow_pickle=True).item()
category_to_iid_map = np.load(f"{RECOMMEND_PATH_PREFIX}iid_to_category_interaction_matrix.npy", allow_pickle=True).item()
matrix_data["category_to_iids"] = defaultdict(list)
for iid, cat in category_to_iid_map.items():
matrix_data["category_to_iids"][cat].append(iid)
logger.info(f"资源加载完成,耗时: {time.time() - start_time:.2f}")
# 触发预缓存
precache_user_category()
except Exception as e:
logger.error(f"资源加载失败: {str(e)}")
raise RuntimeError("初始化失败")
def precache_user_category():
"""预缓存用户-分类组合数据"""
if not all([
matrix_data["interaction_matrix"] is not None,
matrix_data["feature_matrix"] is not None,
matrix_data["user_index_interaction"] is not None
]):
logger.warning("资源未加载完成,跳过预缓存")
return
start_time = time.time()
user_categories = get_all_user_categories()
precached_count = 0
for user_id, categories in user_categories.items():
for category in categories:
cache_key = (user_id, category)
if cache_key in matrix_data["cached_scores"]:
continue
try:
# 获取用户索引
user_idx_inter = matrix_data["user_index_interaction"].get(user_id)
user_idx_feature = matrix_data["user_index_feature"].get(user_id)
# 获取类别对应的iid列表
category_iids = matrix_data["category_to_iids"].get(category, [])
# 过滤有效草图索引
valid_sketch_idxs_inter = [
idx for iid, idx in matrix_data["sketch_index_interaction"].items()
if iid in category_iids
]
# 处理交互分数
if user_idx_inter is not None and valid_sketch_idxs_inter:
raw_inter_scores = matrix_data["interaction_matrix"][user_idx_inter, valid_sketch_idxs_inter]
processed_inter = raw_inter_scores * 0.7
else:
processed_inter = np.array([])
# 处理特征分数
valid_sketch_idxs_feature = [
idx for iid, idx in matrix_data["sketch_index_feature"].items()
if iid in category_iids
]
if user_idx_feature is not None and valid_sketch_idxs_feature:
raw_feat_scores = matrix_data["feature_matrix"][user_idx_feature, valid_sketch_idxs_feature]
raw_feat_scores = (raw_feat_scores - np.min(raw_feat_scores)) / (
np.max(raw_feat_scores) - np.min(raw_feat_scores) + 1e-8)
processed_feat = raw_feat_scores * 0.3
else:
processed_feat = np.array([])
# 缓存结果
if len(processed_inter) == len(processed_feat):
matrix_data["cached_scores"][cache_key] = (processed_inter, processed_feat)
matrix_data["cached_valid_idxs"][cache_key] = valid_sketch_idxs_inter
precached_count += 1
except Exception as e:
logger.error(f"预缓存失败 (user={user_id}, category={category}): {str(e)}")
logger.info(f"预缓存完成,共缓存 {precached_count} 个组合,耗时: {time.time() - start_time:.2f}")
def get_all_user_categories():
"""获取所有用户及其对应的分类"""
conn = None
try:
conn = pymysql.connect(**DB_CONFIG)
cursor = conn.cursor()
query = """
SELECT DISTINCT account_id, path
FROM user_preference_log_prediction
"""
cursor.execute(query)
results = cursor.fetchall()
user_categories = defaultdict(set)
for account_id, path in results:
category = get_category_from_path(path)
user_categories[account_id].add(category)
return dict(user_categories)
except Exception as e:
logger.error(f"数据库查询失败: {str(e)}")
return {}
finally:
if conn:
conn.close()
def get_category_from_path(path: str) -> str:
"""从路径解析类别"""
try:
parts = path.split('/')
if len(parts) >= 4:
return f"{parts[2]}_{parts[3]}"
return "unknown"
except:
return "unknown"