import io
import logging
import sys
import time
from typing import List

import numpy as np
from apscheduler.schedulers.background import BackgroundScheduler
from apscheduler.triggers.cron import CronTrigger
from fastapi import HTTPException, APIRouter

from app.service.recommend.service import load_resources, matrix_data

sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')
logger = logging.getLogger()
router = APIRouter()


@router.on_event("startup")
async def startup_event():
    """Load recommendation matrices once at startup, then refresh daily at 00:30."""
    # 初始加载
    load_resources()

    # 配置定时任务
    scheduler = BackgroundScheduler()
    scheduler.add_job(
        load_resources,
        trigger=CronTrigger(hour=0, minute=30),
        name="每日资源刷新"
    )
    scheduler.start()
    logger.info("定时任务已启动")


def calibrated_softmax(scores: np.ndarray, temperature: float = 1.0) -> np.ndarray:
    """Convert raw scores to sampling probabilities via a temperature-scaled softmax.

    Lower temperatures sharpen the distribution toward the top-scoring items.
    """
    scaled = np.asarray(scores, dtype=float) / temperature
    # subtract the max before exponentiating for numerical stability
    exps = np.exp(scaled - np.max(scaled))
    return exps / np.sum(exps)


def _compute_scores(user_id: int, category: str):
    """Compute (processed_inter, processed_feat, valid_sketch_idxs) for one user/category pair."""
    user_idx_inter = matrix_data["user_index_interaction"].get(user_id)
    user_idx_feature = matrix_data["user_index_feature"].get(user_id)

    # set for O(1) membership tests in the filters below
    category_iids = set(matrix_data["category_to_iids"].get(category, []))

    valid_sketch_idxs_inter = [
        idx for iid, idx in matrix_data["sketch_index_interaction"].items()
        if iid in category_iids
    ]

    # interaction component (weight 0.7)
    if user_idx_inter is not None and valid_sketch_idxs_inter:
        processed_inter = matrix_data["interaction_matrix"][user_idx_inter, valid_sketch_idxs_inter] * 0.7
    else:
        # BUG FIX: the original left `processed_inter` unbound on this path,
        # which raised NameError later when the scores were combined.
        processed_inter = np.array([])

    # feature-similarity component (weight 0.3), min-max normalised per user/category
    valid_sketch_idxs_feature = [
        idx for iid, idx in matrix_data["sketch_index_feature"].items()
        if iid in category_iids
    ]
    if user_idx_feature is not None and valid_sketch_idxs_feature:
        raw_feat_scores = matrix_data["feature_matrix"][user_idx_feature, valid_sketch_idxs_feature]
        raw_feat_scores = (raw_feat_scores - np.min(raw_feat_scores)) / (
            np.max(raw_feat_scores) - np.min(raw_feat_scores) + 1e-8)
        processed_feat = raw_feat_scores * 0.3
    else:
        processed_feat = np.array([])

    return processed_inter, processed_feat, valid_sketch_idxs_inter


@router.get("/recommend/{user_id}/{category}/{num_recommendations}", response_model=List[str])
async def get_recommendations(user_id: int, category: str, num_recommendations: int = 10):
    """
    :param user_id: 4
    :param category: female_skirt
    :param num_recommendations: 1
    :return:
    [
        "aida-sys-image/images/female/skirt/903000017.jpg"
    ]
    """
    try:
        start_time = time.time()
        cache_key = (user_id, category)

        # 检查缓存
        if cache_key in matrix_data["cached_scores"]:
            processed_inter, processed_feat = matrix_data["cached_scores"][cache_key]
            valid_sketch_idxs = matrix_data["cached_valid_idxs"][cache_key]
        else:
            processed_inter, processed_feat, valid_sketch_idxs = _compute_scores(user_id, category)
            # 更新缓存
            matrix_data["cached_scores"][cache_key] = (processed_inter, processed_feat)
            matrix_data["cached_valid_idxs"][cache_key] = valid_sketch_idxs

        if not valid_sketch_idxs:
            # no candidate sketches for this category — return an empty list
            # (the original crashed inside the softmax on an empty array)
            return []

        # 合并分数 — the two components may have different lengths when one side
        # is missing; the original added them unconditionally and raised.
        if len(processed_inter) == len(processed_feat):
            final_scores = processed_inter + processed_feat
        elif len(processed_inter) == len(valid_sketch_idxs):
            final_scores = processed_inter
        elif len(processed_feat) == len(valid_sketch_idxs):
            final_scores = processed_feat
        else:
            # inconsistent cache entry — fall back to uniform sampling
            final_scores = np.zeros(len(valid_sketch_idxs))

        # 概率采样(带温度控制的softmax)
        probs = calibrated_softmax(np.asarray(final_scores), temperature=0.07)

        chosen_indices = np.random.choice(
            len(valid_sketch_idxs),
            size=min(num_recommendations, len(valid_sketch_idxs)),
            p=probs,
            replace=False
        )

        # BUG FIX(review): `valid_sketch_idxs` holds matrix column indices,
        # while `iid_to_sketch` is keyed by iid (iids start at 1, columns at 0).
        # Map each column back to its iid before the path lookup; the original
        # indexed `iid_to_sketch` with the column index directly.
        idx_to_iid = {idx: iid for iid, idx in matrix_data["sketch_index_interaction"].items()}
        recommendations = [
            matrix_data["iid_to_sketch"][idx_to_iid[valid_sketch_idxs[i]]]
            for i in chosen_indices
        ]

        logger.info(f"推荐生成完成,耗时: {time.time() - start_time:.2f}秒")
        return recommendations

    except Exception as e:
        logger.error(f"推荐失败: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
# ---- recommendation service configuration ----

# Database connection settings for the recommendation pipeline.
# SECURITY: credentials were previously hard-coded in source; read them from
# the environment instead.  The fallbacks preserve current behaviour, but the
# exposed password should be rotated and the defaults removed.
DB_CONFIG = {
    "host": os.getenv("RECOMMEND_DB_HOST", "18.167.251.121"),
    "port": int(os.getenv("RECOMMEND_DB_PORT", "3306")),
    "user": os.getenv("RECOMMEND_DB_USER", "root"),
    "password": os.getenv("RECOMMEND_DB_PASSWORD", "QWa998345"),
    "database": os.getenv("RECOMMEND_DB_NAME", "aida"),
    "charset": "utf8mb4"
}

# Maps the flat category key used in API routes (e.g. "female_skirt")
# to the corresponding image folder layout (e.g. "female/skirt").
TABLE_CATEGORIES = {
    "female_dress": "female/dress",
    "female_outwear": "female/outwear",
    "female_trousers": "female/trousers",
    "female_skirt": "female/skirt",
    "female_blouse": "female/blouse",
    "male_tops": "male/tops",
    "male_bottoms": "male/bottoms",
    "male_outwear": "male/outwear"
}
from PIL import Image
import io
import seaborn as sns
import matplotlib.pyplot as plt
from scipy.sparse import csr_matrix
import matplotlib.font_manager as fm
from scipy import sparse

from app.core.config import DB_CONFIG, TABLE_CATEGORIES, RECOMMEND_PATH_PREFIX

# Pick a font able to render the Chinese labels used in the plots, falling
# back to DejaVu Sans.  (The original configured the font twice; the second,
# unconditional assignment defeated the guarded fallback, so only one guarded
# setup is kept.)
try:
    _font_path = fm.findfont(fm.FontProperties(family=['Microsoft YaHei', 'SimHei', 'WenQuanYi Zen Hei']))
    plt.rcParams['font.sans-serif'] = [fm.FontProperties(fname=_font_path).get_name()]
except Exception:
    plt.rcParams['font.sans-serif'] = ['DejaVu Sans']
plt.rcParams['axes.unicode_minus'] = False  # render minus signs correctly

# 配置日志记录
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    filename='scheduler.log'
)

# MinIO client used to fetch sketch images.
# SECURITY: credentials were hard-coded; read from the environment instead
# (defaults preserve current behaviour — rotate the exposed secret).
minio_client = Minio(
    "www.minio.aida.com.hk:12024",  # MinIO endpoint
    access_key=os.getenv("MINIO_ACCESS_KEY", "admin"),
    secret_key=os.getenv("MINIO_SECRET_KEY", "Aidlab123123!"),
    secure=True  # use https
)

# Pre-computed ResNet feature vectors for every system sketch, keyed by path.
SYSTEM_FEATURES = np.load(f'{RECOMMEND_PATH_PREFIX}sketch_feature_dict.npy', allow_pickle=True).item()

# NOTE(review): service.py loads this mapping from
# f'{RECOMMEND_PATH_PREFIX}sketch_to_iid.npy', but it was previously saved and
# loaded here with a bare relative path — the two only agreed when this task
# ran from a matching working directory.  Use the shared prefix on both sides;
# confirm against the deployment layout.
_SKETCH_TO_IID_PATH = f'{RECOMMEND_PATH_PREFIX}sketch_to_iid.npy'


def save_sketch_to_iid():
    """Build and persist the sketch-path -> integer iid mapping (iids start at 1)."""
    sketch_to_iid = {sketch_path: iid for iid, sketch_path in enumerate(SYSTEM_FEATURES.keys(), start=1)}
    np.save(_SKETCH_TO_IID_PATH, sketch_to_iid)
    logging.info("sketch_to_iid 已保存")


def load_sketch_to_iid():
    """Load the persisted sketch -> iid mapping, generating it first if missing."""
    if not os.path.exists(_SKETCH_TO_IID_PATH):
        logging.info("sketch_to_iid 文件不存在,正在生成并保存...")
        save_sketch_to_iid()
    mapping = np.load(_SKETCH_TO_IID_PATH, allow_pickle=True).item()
    logging.info("sketch_to_iid 已加载")
    return mapping
# sketch path -> iid mapping used by all matrix builders below
sketch_to_iid = load_sketch_to_iid()

# Image preprocessing matching ResNet's ImageNet training setup.
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # ResNet expects 224x224 input
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Pre-trained ResNet50 with the final fully-connected layer removed, so a
# forward pass yields a 2048-d feature vector per image.
resnet_model = models.resnet50(pretrained=True)
resnet_model = torch.nn.Sequential(*list(resnet_model.children())[:-1])
resnet_model.eval()  # inference only


def get_sketch_image_from_minio(sketch_path):
    """Fetch a sketch from MinIO and return it as a preprocessed (1, 3, 224, 224) tensor.

    ``sketch_path`` is ``"<bucket>/<object path>"``.  Returns ``None`` on any failure.
    """
    if '/' not in sketch_path:
        # malformed path: the original raised an uncaught IndexError here
        print(f"Error fetching image for {sketch_path}: missing bucket prefix")
        return None
    bucket_name, file_name = sketch_path.split('/', 1)

    obj = None
    try:
        obj = minio_client.get_object(bucket_name, file_name)
        img_data = obj.read()
        # convert() guards against grayscale/RGBA images, which would break
        # the 3-channel Normalize step above
        img = Image.open(io.BytesIO(img_data)).convert('RGB')
        return transform(img).unsqueeze(0)  # add batch dimension
    except Exception as e:
        print(f"Error fetching image for {sketch_path}: {e}")
        return None
    finally:
        # release the HTTP connection back to the pool (was leaked before)
        if obj is not None:
            obj.close()
            obj.release_conn()


def extract_feature_vector_from_resnet(sketch_path):
    """Return the 2048-d ResNet feature vector for a sketch, or zeros on fetch failure."""
    img_tensor = get_sketch_image_from_minio(sketch_path)
    if img_tensor is None:
        return np.zeros(2048)

    with torch.no_grad():  # inference — no gradients needed
        feature_vector = resnet_model(img_tensor)
    return feature_vector.squeeze().cpu().numpy()


def update_user_matrices():
    """Recompute the interaction and feature matrices from the preference log and persist them.

    Runs as the daily scheduled job; all outputs are .npy files under
    RECOMMEND_PATH_PREFIX that the online service loads at startup.
    """
    conn = None
    try:
        conn = pymysql.connect(**DB_CONFIG)
        # close the cursor deterministically
        with conn.cursor() as cursor:
            # aggregate like counts per (user, sketch path)
            cursor.execute("""
                SELECT account_id, path, COUNT(*) as like_count
                FROM user_preference_log_test
                GROUP BY account_id, path
            """)
            user_data = cursor.fetchall()
        logging.info(f"成功读取{len(user_data)}条用户偏好记录")

        # 计算矩阵
        interaction_matrix, _raw_counts_sparse, user_index_inter, sketch_index_inter, iid_to_category_inter = \
            calculate_interaction_matrix(user_data)
        feature_matrix, user_index_feat, sketch_index_feat, iid_to_category_feat = \
            calculate_feature_matrix(user_data)

        # 存储矩阵 — everything the online service loads at startup
        np.save(f"{RECOMMEND_PATH_PREFIX}interaction_matrix.npy", interaction_matrix)
        np.save(f"{RECOMMEND_PATH_PREFIX}feature_matrix.npy", feature_matrix)
        np.save(f"{RECOMMEND_PATH_PREFIX}iid_to_category_interaction_matrix.npy", iid_to_category_inter)
        np.save(f"{RECOMMEND_PATH_PREFIX}user_index_interaction_matrix.npy", user_index_inter)
        np.save(f"{RECOMMEND_PATH_PREFIX}iid_to_category_feature_matrix.npy", iid_to_category_feat)
        np.save(f"{RECOMMEND_PATH_PREFIX}user_index_feature_matrix.npy", user_index_feat)
        np.save(f"{RECOMMEND_PATH_PREFIX}sketch_index_interaction_matrix.npy", sketch_index_inter)
        np.save(f"{RECOMMEND_PATH_PREFIX}sketch_index_feature_matrix.npy", sketch_index_feat)
        logging.info("矩阵更新完成")

    except Exception as e:
        logging.error(f"定时任务执行失败: {str(e)}", exc_info=True)
    finally:
        if conn:
            conn.close()
def plot_interaction_count_matrix(interaction_count_matrix):
    """Render the raw user-sketch interaction-count matrix as a heatmap (zeros included).

    Saves the figure to interaction_heatmap_with_zeros.png; large matrices are
    truncated to the top-left 1000x1000 corner to keep the plot tractable.
    """
    try:
        if not isinstance(interaction_count_matrix, csr_matrix):
            sparse_matrix = csr_matrix(interaction_count_matrix)
        else:
            sparse_matrix = interaction_count_matrix

        # 转换为密集矩阵
        try:
            dense_matrix = sparse_matrix.toarray()
        except MemoryError:
            logging.error("内存不足,无法转换为密集矩阵")
            return

        # best-effort font setup so labels render (narrowed from a bare except)
        try:
            font_path = fm.findfont(fm.FontProperties(family='sans-serif', style='normal'))
            plt.rcParams['font.sans-serif'] = [fm.FontProperties(fname=font_path).get_name()]
        except Exception:
            plt.rcParams['font.sans-serif'] = ['DejaVu Sans']
        plt.rcParams['axes.unicode_minus'] = False

        # limit display range for very large matrices
        if dense_matrix.shape[0] > 1000 or dense_matrix.shape[1] > 1000:
            dense_matrix = dense_matrix[:1000, :1000]

        plt.figure(figsize=(15, 10))
        sns.heatmap(
            dense_matrix,
            cmap="Blues",
            annot=False,
            cbar_kws={"label": "Interaction Count"},
            linewidths=0.5,
            vmin=0,  # keep zero entries visible
            vmax=np.max(dense_matrix)
        )
        plt.title('User-Sketch Interaction Matrix (With Zero Entries)')
        plt.xlabel('Sketch Index')
        plt.ylabel('User Index')
        plt.savefig('interaction_heatmap_with_zeros.png', dpi=150, bbox_inches='tight')
        plt.close()

        logging.info("热图已保存为 interaction_heatmap_with_zeros.png")

    except Exception as e:
        logging.error(f"绘图失败: {str(e)}", exc_info=True)


def visualize_sparse_matrix(matrix, title='Non-zero Interactions (Scatter Plot)', filename="scatter_figure_interaction.png"):
    """Scatter-plot the non-zero entries of a (possibly dense) matrix and save it to *filename*."""
    if not sparse.issparse(matrix):
        matrix = sparse.csr_matrix(matrix)

    rows, cols = matrix.nonzero()
    values = matrix.data

    plt.figure(figsize=(24, 20))
    plt.scatter(cols, rows, c=values, cmap='coolwarm', alpha=0.7, s=1)
    plt.colorbar(label='Interaction Count')
    plt.title(title)
    plt.xlabel('Item Index')
    plt.ylabel('Item Index')
    plt.savefig(filename)
    plt.close()  # BUG FIX: the figure was never closed, leaking memory across calls


def calculate_interaction_matrix(user_data):
    """Build the user x system-sketch interaction matrices from aggregated like counts.

    :param user_data: iterable of (account_id, path, like_count) rows
    :return: (normalised_matrix, raw_count_csr, user_index, sketch_index, iid_to_category)
    """
    # users that interacted with at least one path in a known category
    all_users = set()
    for account_id, path, like_count in user_data:
        if get_category_from_path(path) in TABLE_CATEGORIES:
            all_users.add(account_id)

    # iids of sketches shipped with the system
    system_sketch_iids = {sketch_to_iid[path] for path in SYSTEM_FEATURES.keys() if path in sketch_to_iid}

    user_index = {uid: idx for idx, uid in enumerate(sorted(all_users))}
    sketch_index = {iid: idx for idx, iid in enumerate(sorted(system_sketch_iids))}

    # dense normalised matrix plus COO triplets for the raw-count sparse matrix
    interaction_matrix = np.zeros((len(all_users), len(sketch_index)))
    data, rows, cols = [], [], []

    # per-user maximum like count, used for normalisation
    user_max_likes = defaultdict(int)
    for account_id, path, like_count in user_data:
        if sketch_to_iid.get(path) in system_sketch_iids:
            user_max_likes[account_id] = max(user_max_likes[account_id], like_count)

    for account_id, path, like_count in user_data:
        sketch_iid = sketch_to_iid.get(path)
        if sketch_iid not in system_sketch_iids:
            continue

        # BUG FIX: users whose rows all fall outside TABLE_CATEGORIES are not
        # in user_index; the original raised KeyError here and aborted the job.
        user_idx = user_index.get(account_id)
        if user_idx is None:
            continue
        sketch_idx = sketch_index[sketch_iid]

        data.append(like_count)
        rows.append(user_idx)
        cols.append(sketch_idx)

        # log-scaled count normalised by the user's own maximum.
        # NOTE(review): log1p(1 + x) == log(2 + x); if plain log1p scaling was
        # intended this double-adds one — preserved as-is, confirm intent.
        max_like = user_max_likes.get(account_id, 1)
        interaction_matrix[user_idx, sketch_idx] = np.log1p(1 + like_count) / np.log1p(1 + max_like)

    # CSR is efficient for the row-wise reads done downstream
    interaction_count_matrix = csr_matrix((data, (rows, cols)), shape=(len(all_users), len(sketch_index)))

    iid_to_category = {iid: get_category_from_path(path) for path, iid in sketch_to_iid.items()}
    return interaction_matrix, interaction_count_matrix, user_index, sketch_index, iid_to_category


def calculate_feature_matrix(user_data):
    """Build the user x system-sketch cosine-similarity matrix from weighted average features.

    For each (user, category) pair, the feature vectors of the user's liked
    sketches are averaged weighted by like count; each system sketch is then
    scored by cosine similarity against the user's average vector of the
    sketch's own category.

    :param user_data: iterable of (account_id, path, like_count) rows
    :return: (feature_matrix, user_index, sketch_index, iid_to_category)
    """
    # {(account_id, category): {sketch_iid: [(feature_vector, weight), ...]}}
    user_feature_weights = defaultdict(lambda: defaultdict(list))
    all_users = set()

    # ==== pass 1: collect feature vectors and like-count weights ====
    for account_id, path, like_count in user_data:
        category = get_category_from_path(path)
        if category not in TABLE_CATEGORIES:
            continue

        sketch_iid = sketch_to_iid.get(path)
        if sketch_iid is None:
            continue

        all_users.add(account_id)

        if path in SYSTEM_FEATURES:
            feature = SYSTEM_FEATURES[path]  # precomputed system sketch vector
        else:
            # user-uploaded sketch: extract the vector on the fly
            feature = extract_feature_vector_from_resnet(path)
        user_feature_weights[(account_id, category)][sketch_iid].append((feature, like_count))

    # ==== pass 2: collect all system sketch iids ====
    system_sketch_iids = set()
    for sketch_path in SYSTEM_FEATURES:
        if iid := sketch_to_iid.get(sketch_path):
            system_sketch_iids.add(iid)

    # ==== index mappings ====
    user_list = sorted(all_users)
    sketch_list = sorted(system_sketch_iids)
    user_index = {uid: idx for idx, uid in enumerate(user_list)}
    sketch_index = {iid: idx for idx, iid in enumerate(sketch_list)}

    feature_matrix = np.zeros((len(user_list), len(sketch_list)))

    # ==== weighted-average feature vector per (user, category) ====
    user_avg_features = {}
    for (account_id, category), sketches in user_feature_weights.items():
        try:
            pairs = [(vec, weight) for vec_list in sketches.values() for vec, weight in vec_list]
            if not pairs:
                continue

            total_weight = sum(weight for _, weight in pairs)
            if total_weight <= 0:
                total_weight = 1.0  # guard against division by zero

            weighted_sum = np.zeros_like(pairs[0][0])
            for vec, weight in pairs:
                weighted_sum += vec * weight
            user_avg_features[(account_id, category)] = weighted_sum / total_weight

        except Exception as e:
            logging.warning(f"用户({account_id},{category})加权特征计算失败: {str(e)}")
            continue

    # ==== score every (user, system sketch) pair by cosine similarity ====
    for sketch_path, sys_vector in SYSTEM_FEATURES.items():
        sketch_iid = sketch_to_iid.get(sketch_path)
        if sketch_iid is None or sketch_iid not in sketch_index:
            continue
        sketch_col = sketch_index[sketch_iid]
        sketch_category = get_category_from_path(sketch_path)

        for account_id in all_users:
            user_row = user_index.get(account_id)
            if user_row is None:
                continue
            user_vec = user_avg_features.get((account_id, sketch_category))
            if user_vec is None:
                continue  # user has no preference data in this category
            feature_matrix[user_row, sketch_col] = cosine_similarity(user_vec, sys_vector)

    iid_to_category = {iid: get_category_from_path(path) for path, iid in sketch_to_iid.items()}
    return feature_matrix, user_index, sketch_index, iid_to_category


def get_category_from_path(path):
    """Parse "<bucket>/images/<gender>/<garment>/..." into a "<gender>_<garment>" key.

    Returns "unknown" for any path that does not match the expected layout.
    """
    try:
        parts = path.split('/')
        # BUG FIX: the original checked len(parts) >= 2 but indexed parts[2]
        # and parts[3], relying on the bare except to map short paths to
        # "unknown".  Check the actual requirement instead.
        if len(parts) >= 4:
            return f"{parts[2]}_{parts[3]}"
        return "unknown"
    except Exception:  # non-string input etc.
        return "unknown"


def cosine_similarity(vec1, vec2):
    """Cosine similarity of two vectors; returns 0.0 when either has zero norm."""
    norm = np.linalg.norm(vec1) * np.linalg.norm(vec2)
    return np.dot(vec1, vec2) / (norm + 1e-10) if norm != 0 else 0.0


if __name__ == "__main__":
    try:
        update_user_matrices()
    except KeyboardInterrupt:
        logging.info("定时任务已停止")
    except Exception as e:
        logging.error(f"调度器启动失败: {str(e)}", exc_info=True)
matrix_data = {
    "interaction_matrix": None,        # dense user x sketch interaction scores
    "feature_matrix": None,            # dense user x sketch cosine similarities
    "user_index_interaction": None,    # user_id -> row in interaction_matrix
    "sketch_index_interaction": None,  # iid -> column in interaction_matrix
    "user_index_feature": None,        # user_id -> row in feature_matrix
    "sketch_index_feature": None,      # iid -> column in feature_matrix
    "iid_to_sketch": None,             # iid -> sketch path
    "category_to_iids": None,          # category key -> [iid, ...]
    "cached_scores": {},               # (user_id, category) -> (inter_scores, feat_scores)
    "cached_valid_idxs": {},           # (user_id, category) -> [column idx, ...]
    "category_sketch_idxs_inter": None,
    "category_sketch_idxs_feature": None,
    "user_inter_full": dict(),
    "user_feat_full": dict(),
}


def load_resources():
    """Load all matrices and index mappings from disk, then warm the per-user cache.

    Clears previously cached scores first so a daily refresh fully replaces
    stale data.  Raises RuntimeError (chained to the cause) on any failure.
    """
    try:
        start_time = time.time()

        # drop stale cache entries before swapping in fresh matrices
        matrix_data["cached_scores"].clear()
        matrix_data["cached_valid_idxs"].clear()

        sketch_to_iid = np.load(f'{RECOMMEND_PATH_PREFIX}sketch_to_iid.npy', allow_pickle=True).item()
        matrix_data["iid_to_sketch"] = {v: k for k, v in sketch_to_iid.items()}

        matrix_data["interaction_matrix"] = np.load(
            f"{RECOMMEND_PATH_PREFIX}interaction_matrix.npy", allow_pickle=True)
        matrix_data["user_index_interaction"] = np.load(
            f"{RECOMMEND_PATH_PREFIX}user_index_interaction_matrix.npy", allow_pickle=True).item()
        matrix_data["sketch_index_interaction"] = np.load(
            f"{RECOMMEND_PATH_PREFIX}sketch_index_interaction_matrix.npy", allow_pickle=True).item()

        matrix_data["feature_matrix"] = np.load(
            f"{RECOMMEND_PATH_PREFIX}feature_matrix.npy", allow_pickle=True)
        matrix_data["user_index_feature"] = np.load(
            f"{RECOMMEND_PATH_PREFIX}user_index_feature_matrix.npy", allow_pickle=True).item()
        matrix_data["sketch_index_feature"] = np.load(
            f"{RECOMMEND_PATH_PREFIX}sketch_index_feature_matrix.npy", allow_pickle=True).item()

        # invert the persisted iid -> category map into category -> [iids]
        iid_to_category = np.load(
            f"{RECOMMEND_PATH_PREFIX}iid_to_category_interaction_matrix.npy", allow_pickle=True).item()
        matrix_data["category_to_iids"] = defaultdict(list)
        for iid, cat in iid_to_category.items():
            matrix_data["category_to_iids"][cat].append(iid)

        logger.info(f"资源加载完成,耗时: {time.time() - start_time:.2f}秒")

        # 触发预缓存
        precache_user_category()

    except Exception as e:
        logger.error(f"资源加载失败: {str(e)}")
        # BUG FIX: chain the original exception so the root cause isn't lost
        raise RuntimeError("初始化失败") from e


def precache_user_category():
    """Pre-compute and cache recommendation scores for every known (user, category) pair."""
    if not all([
        matrix_data["interaction_matrix"] is not None,
        matrix_data["feature_matrix"] is not None,
        matrix_data["user_index_interaction"] is not None
    ]):
        logger.warning("资源未加载完成,跳过预缓存")
        return

    start_time = time.time()
    user_categories = get_all_user_categories()

    precached_count = 0
    for user_id, categories in user_categories.items():
        for category in categories:
            cache_key = (user_id, category)
            if cache_key in matrix_data["cached_scores"]:
                continue

            try:
                user_idx_inter = matrix_data["user_index_interaction"].get(user_id)
                user_idx_feature = matrix_data["user_index_feature"].get(user_id)

                # set for O(1) membership tests in the filters below
                category_iids = set(matrix_data["category_to_iids"].get(category, []))

                valid_sketch_idxs_inter = [
                    idx for iid, idx in matrix_data["sketch_index_interaction"].items()
                    if iid in category_iids
                ]

                # interaction component (weight 0.7)
                if user_idx_inter is not None and valid_sketch_idxs_inter:
                    processed_inter = matrix_data["interaction_matrix"][
                        user_idx_inter, valid_sketch_idxs_inter] * 0.7
                else:
                    processed_inter = np.array([])

                # feature component (weight 0.3), min-max normalised
                valid_sketch_idxs_feature = [
                    idx for iid, idx in matrix_data["sketch_index_feature"].items()
                    if iid in category_iids
                ]
                if user_idx_feature is not None and valid_sketch_idxs_feature:
                    raw_feat_scores = matrix_data["feature_matrix"][
                        user_idx_feature, valid_sketch_idxs_feature]
                    raw_feat_scores = (raw_feat_scores - np.min(raw_feat_scores)) / (
                        np.max(raw_feat_scores) - np.min(raw_feat_scores) + 1e-8)
                    processed_feat = raw_feat_scores * 0.3
                else:
                    processed_feat = np.array([])

                # only cache consistent pairs; mismatched lengths would break
                # the online score-combination step
                if len(processed_inter) == len(processed_feat):
                    matrix_data["cached_scores"][cache_key] = (processed_inter, processed_feat)
                    matrix_data["cached_valid_idxs"][cache_key] = valid_sketch_idxs_inter
                    precached_count += 1
                else:
                    # previously skipped silently — log so coverage gaps are visible
                    logger.debug(f"跳过缓存 (user={user_id}, category={category}): 分数长度不一致")

            except Exception as e:
                logger.error(f"预缓存失败 (user={user_id}, category={category}): {str(e)}")

    logger.info(f"预缓存完成,共缓存 {precached_count} 个组合,耗时: {time.time() - start_time:.2f}秒")


def get_all_user_categories():
    """Query the preference log and return {account_id: {category, ...}}.

    NOTE(review): this reads user_preference_log_prediction while the matrix
    job reads user_preference_log_test — confirm the two tables are meant to
    differ.  Returns {} on any database error.
    """
    conn = None
    try:
        conn = pymysql.connect(**DB_CONFIG)
        # close the cursor deterministically
        with conn.cursor() as cursor:
            cursor.execute("""
            SELECT DISTINCT account_id, path
            FROM user_preference_log_prediction
            """)
            results = cursor.fetchall()

        user_categories = defaultdict(set)
        for account_id, path in results:
            user_categories[account_id].add(get_category_from_path(path))
        return dict(user_categories)

    except Exception as e:
        logger.error(f"数据库查询失败: {str(e)}")
        return {}
    finally:
        if conn:
            conn.close()


def get_category_from_path(path: str) -> str:
    """Parse "<bucket>/images/<gender>/<garment>/..." into "<gender>_<garment>"; "unknown" otherwise."""
    try:
        parts = path.split('/')
        if len(parts) >= 4:
            return f"{parts[2]}_{parts[3]}"
        return "unknown"
    except Exception:  # non-string input etc.
        return "unknown"