diff --git a/app/api/api_brand_dna_initialize.py b/app/api/api_brand_dna_initialize.py index b58b250..9388bdd 100644 --- a/app/api/api_brand_dna_initialize.py +++ b/app/api/api_brand_dna_initialize.py @@ -1,25 +1,34 @@ import io import logging -import os import sys +import time +from typing import List from collections import defaultdict - import numpy as np -import pymysql -import torch -from PIL import Image +from apscheduler.schedulers.background import BackgroundScheduler +from apscheduler.triggers.cron import CronTrigger from fastapi import HTTPException, APIRouter -from fastapi.responses import JSONResponse -from minio import Minio -from torchvision import models, transforms -from app.core.mysql_config import DB_CONFIG -from app.core.new_config import settings +import pymysql +from app.core.config import DB_CONFIG, TABLE_CATEGORIES, RECOMMEND_PATH_PREFIX +from minio import Minio +import torch +from torchvision import models, transforms +from PIL import Image +import os +from fastapi.responses import JSONResponse sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8') logger = logging.getLogger() router = APIRouter() -minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE) + +# MinIO 配置 +minio_client = Minio( + "www.minio.aida.com.hk:12024", + access_key="admin", + secret_key="Aidlab123123!", + secure=True +) transform = transforms.Compose([ transforms.Resize((224, 224)), @@ -58,8 +67,8 @@ def extract_feature_vector_from_resnet(sketch_path: str) -> np.ndarray: # 预加载 -BRAND_FEATURES = np.load(f'{settings.RECOMMEND_PATH_PREFIX}brand_feature.npy', allow_pickle=True).item() -SYSTEM_FEATURES = np.load(f'{settings.RECOMMEND_PATH_PREFIX}sketch_feature_dict.npy', allow_pickle=True).item() +BRAND_FEATURES = np.load(f'{RECOMMEND_PATH_PREFIX}brand_feature.npy', allow_pickle=True).item() +SYSTEM_FEATURES = np.load(f'{RECOMMEND_PATH_PREFIX}sketch_feature_dict.npy', 
allow_pickle=True).item() def save_sketch_to_iid(): @@ -67,11 +76,11 @@ def save_sketch_to_iid(): sketch_path: iid for iid, sketch_path in enumerate(SYSTEM_FEATURES.keys(), start=1) } - np.save(f"{settings.RECOMMEND_PATH_PREFIX}sketch_to_iid.npy", sketch_to_iid) + np.save(f"{RECOMMEND_PATH_PREFIX}sketch_to_iid.npy", sketch_to_iid) def load_sketch_to_iid(): - path = f"{settings.RECOMMEND_PATH_PREFIX}sketch_to_iid.npy" + path = f"{RECOMMEND_PATH_PREFIX}sketch_to_iid.npy" if os.path.exists(path): return np.load(path, allow_pickle=True).item() save_sketch_to_iid() @@ -81,7 +90,7 @@ def load_sketch_to_iid(): sketch_to_iid = load_sketch_to_iid() -def get_new_category(gender: str, sketch_category: str) -> str: +def getNewCategory(gender: str, sketch_category: str) -> str: return f"{gender.lower()}_{sketch_category.lower()}" @@ -94,8 +103,8 @@ def get_category_from_path(path: str) -> str: def load_brand_matrix(): """单独加载 brand_matrix 和 brand_index_map""" - mat_path = f"{settings.RECOMMEND_PATH_PREFIX}brand_matrix.npy" - idx_path = f"{settings.RECOMMEND_PATH_PREFIX}brand_index_map.npy" + mat_path = f"{RECOMMEND_PATH_PREFIX}brand_matrix.npy" + idx_path = f"{RECOMMEND_PATH_PREFIX}brand_index_map.npy" try: matrix = np.load(mat_path) index_map = np.load(idx_path, allow_pickle=True).item() @@ -104,19 +113,11 @@ def load_brand_matrix(): index_map = {} return matrix, index_map - def cosine_similarity(vec1, vec2): """计算余弦相似度(增加零值处理)""" norm = np.linalg.norm(vec1) * np.linalg.norm(vec2) return np.dot(vec1, vec2) / (norm + 1e-10) if norm != 0 else 0.0 - -def getNewCategory(gender, sketch_category): - print(gender) - print(sketch_category) - return "None" - - def calculate_brand_matrix(sketch_data, brand_id: int) -> np.ndarray: # 1. 收集品牌-分类-特征 brand_feature = defaultdict(lambda: defaultdict(list)) @@ -163,11 +164,11 @@ def calculate_brand_matrix(sketch_data, brand_id: int) -> np.ndarray: brand_matrix[row_idx, sketch_index[iid]] = cos_sim # 7. 
持久化 - np.save(f"{settings.RECOMMEND_PATH_PREFIX}brand_feature_matrix.npy", brand_matrix) - np.save(f"{settings.RECOMMEND_PATH_PREFIX}brand_index_map.npy", brand_index_map) + np.save(f"{RECOMMEND_PATH_PREFIX}brand_feature_matrix.npy", brand_matrix) + np.save(f"{RECOMMEND_PATH_PREFIX}brand_index_map.npy", brand_index_map) # 返回该品牌对应行 - return brand_matrix[row_idx:row_idx + 1] + return brand_matrix[row_idx:row_idx+1] @router.get("/brand_dna_initialize/{brand_id}") @@ -177,12 +178,14 @@ async def brand_dna_initialize(brand_id: int): conn = pymysql.connect(**DB_CONFIG) cursor = conn.cursor() cursor.execute(""" - SELECT id, img_url, gender, category - FROM product_image_attribute - WHERE library_id IN (SELECT library_id - FROM brand_rel_library - WHERE brand_id = %s) - """, (brand_id,)) + SELECT id, img_url, gender, category + FROM product_image_attribute + WHERE library_id IN ( + SELECT library_id + FROM brand_rel_library + WHERE brand_id = %s + ) + """, (brand_id,)) sketch_data = cursor.fetchall() # 触发计算并持久化,若内部出错会抛异常 diff --git a/app/api/api_import_sys_sketch.py b/app/api/api_import_sys_sketch.py new file mode 100644 index 0000000..3654124 --- /dev/null +++ b/app/api/api_import_sys_sketch.py @@ -0,0 +1,116 @@ +import logging +import sys +from typing import Optional +from fastapi import APIRouter, HTTPException, Query +from concurrent.futures import ThreadPoolExecutor +import threading + +from app.schemas.response_template import ResponseModel +from app.service.recommendation_system.import_sys_sketch_to_milvus import main as import_main + +logger = logging.getLogger() +router = APIRouter() + +# 使用线程池执行器来运行长时间任务 +executor = ThreadPoolExecutor(max_workers=1) +# 用于跟踪任务状态 +task_status = {"running": False} + + +def run_import_task(batch_size: int, retry_times: int, limit: Optional[int], offset: int, skip_create_collection: bool): + """在后台线程中运行导入任务""" + original_argv = None + try: + task_status["running"] = True + # 保存原始 sys.argv + original_argv = sys.argv.copy() + + # 
模拟命令行参数 + sys.argv = [ + "import_sys_sketch_to_milvus.py", + "--batch-size", str(batch_size), + "--retry-times", str(retry_times), + ] + if limit is not None: + sys.argv.extend(["--limit", str(limit)]) + if offset > 0: + sys.argv.extend(["--offset", str(offset)]) + if skip_create_collection: + sys.argv.append("--skip-create-collection") + + import_main() + task_status["running"] = False + logger.info("导入任务完成") + except Exception as e: + task_status["running"] = False + logger.error(f"导入任务失败: {e}", exc_info=True) + raise + finally: + # 恢复原始 sys.argv + if original_argv is not None: + sys.argv = original_argv + + +@router.post("/import-sys-sketch", response_model=ResponseModel) +async def import_sys_sketch( + batch_size: int = Query(1000, description="批量处理大小(默认:1000)"), + retry_times: int = Query(3, description="失败重试次数(默认:3)"), + limit: Optional[int] = Query(None, description="限制处理数量(用于测试,默认:不限制)"), + offset: int = Query(0, description="起始偏移量(默认:0)"), + skip_create_collection: bool = Query(False, description="跳过创建集合(如果集合已存在)"), +): + """ + 从 t_sys_file 导入系统图向量到 Milvus + + 该接口会异步执行导入任务,任务在后台运行。 + """ + try: + # 检查是否有任务正在运行 + if task_status["running"]: + raise HTTPException( + status_code=409, + detail="已有导入任务正在运行,请等待完成后再试" + ) + + # 在后台线程中执行任务 + executor.submit( + run_import_task, + batch_size, + retry_times, + limit, + offset, + skip_create_collection + ) + + return ResponseModel( + code=200, + msg="导入任务已启动,正在后台执行", + data={ + "status": "started", + "batch_size": batch_size, + "retry_times": retry_times, + "limit": limit, + "offset": offset, + "skip_create_collection": skip_create_collection + } + ) + except HTTPException: + raise + except Exception as e: + logger.error(f"启动导入任务失败: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=f"启动导入任务失败: {str(e)}") + + +@router.get("/import-sys-sketch/status", response_model=ResponseModel) +async def get_import_status(): + """ + 获取导入任务状态 + """ + return ResponseModel( + code=200, + msg="OK", + data={ + "running": 
task_status["running"] + } + ) + diff --git a/app/api/api_precompute.py b/app/api/api_precompute.py new file mode 100644 index 0000000..afebac7 --- /dev/null +++ b/app/api/api_precompute.py @@ -0,0 +1,85 @@ +import logging +from fastapi import APIRouter, HTTPException +from concurrent.futures import ThreadPoolExecutor + +from app.schemas.response_template import ResponseModel +from app.service.recommendation_system.precompute import run_precompute + +logger = logging.getLogger() +router = APIRouter() + +# 使用线程池执行器来运行长时间任务 +executor = ThreadPoolExecutor(max_workers=1) +# 用于跟踪任务状态 +task_status = {"running": False} + + +def run_precompute_task(): + """在后台线程中运行预计算任务""" + try: + task_status["running"] = True + logger.info("开始执行预计算任务...") + run_precompute() + task_status["running"] = False + logger.info("预计算任务完成") + except Exception as e: + task_status["running"] = False + logger.error(f"预计算任务失败: {e}", exc_info=True) + raise + + +@router.post("/precompute", response_model=ResponseModel) +async def precompute(): + """ + 运行预计算任务 + + 该接口会异步执行预计算任务,包括: + 1. 优化数据库表结构 + 2. 历史数据迁移 + 3. 
初始用户偏好向量生成 + + 任务在后台运行。 + """ + try: + # 检查是否有任务正在运行 + if task_status["running"]: + raise HTTPException( + status_code=409, + detail="已有预计算任务正在运行,请等待完成后再试" + ) + + # 在后台线程中执行任务 + executor.submit(run_precompute_task) + + return ResponseModel( + code=200, + msg="预计算任务已启动,正在后台执行", + data={ + "status": "started", + "tasks": [ + "优化数据库表结构", + "历史数据迁移", + "初始用户偏好向量生成" + ] + } + ) + except HTTPException: + raise + except Exception as e: + logger.error(f"启动预计算任务失败: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=f"启动预计算任务失败: {str(e)}") + + +@router.get("/precompute/status", response_model=ResponseModel) +async def get_precompute_status(): + """ + 获取预计算任务状态 + """ + return ResponseModel( + code=200, + msg="OK", + data={ + "running": task_status["running"] + } + ) + diff --git a/app/api/api_recommendation.py b/app/api/api_recommendation.py index b81e240..e5b86b1 100644 --- a/app/api/api_recommendation.py +++ b/app/api/api_recommendation.py @@ -1,206 +1,175 @@ import io import logging -import math import sys -import time -from typing import List - -import numpy as np +from typing import List, Optional +from fastapi import HTTPException, APIRouter, Query from apscheduler.schedulers.background import BackgroundScheduler -from apscheduler.triggers.cron import CronTrigger -from fastapi import HTTPException, APIRouter -from app.service.recommend.service import load_resources, matrix_data +from app.service.recommendation_system.recommendation_api import get_recommendations as get_new_recommendations +from app.service.recommendation_system.incremental_listener import start_background_listener +from app.service.recommendation_system.milvus_client import create_collection sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8') logger = logging.getLogger() router = APIRouter() -@router.on_event("startup") -async def startup_event(): - # 初始加载 - load_resources() - - # 配置定时任务 - scheduler = BackgroundScheduler() - scheduler.add_job( - load_resources, - 
trigger=CronTrigger(hour=0, minute=30), - name="每日资源刷新" - ) - scheduler.start() - logger.info("定时任务已启动") - - -def softmax(scores): - max_score = max(scores) - exp_scores = [math.exp(s - max_score) for s in scores] - sum_exp = sum(exp_scores) - return [s / sum_exp for s in exp_scores] - - -# def get_random_recommendations(category: str, num: int) -> List[str]: -# """根据预加载热度向量推荐(冷启动)""" +# ========== 旧版推荐接口(基于 npy 矩阵,已废弃)========== +# @router.get("/recommend/{user_id}/{category}/{num_recommendations}/{brand_id}/{brand_scale}", response_model=List[str]) +# async def get_recommendations(user_id: int, category: str, brand_id: int, brand_scale: float, num_recommendations: int = 10): +# """ +# :param user_id: 4 +# :param category: female_skirt +# :param num_recommendations: 1 +# :return: +# [ +# "aida-sys-image/images/female/skirt/903000017.jpg" +# ] +# """ # try: -# heat_data = matrix_data.get("heat_data", {}) +# start_time = time.time() +# cache_key = (user_id, category) +# # === 新增:用户存在性检查 === +# user_exists_inter = user_id in matrix_data["user_index_interaction"] +# user_exists_feat = user_id in matrix_data["user_index_feature"] # -# if category not in heat_data: -# raise ValueError(f"热度数据缺少类别 {category},使用随机推荐") +# # 任一矩阵不存在用户则返回随机推荐 +# if not (user_exists_inter and user_exists_feat): +# logger.info(f"用户 {user_id} 数据不完整,触发随机推荐") +# return get_random_recommendations(category, num_recommendations) # -# heat_dict = heat_data[category] # {url: score} -# urls = list(heat_dict.keys()) -# scores = list(heat_dict.values()) +# # 检查缓存 +# if cache_key in matrix_data["cached_scores"]: +# processed_inter, processed_feat = matrix_data["cached_scores"][cache_key] +# valid_sketch_idxs_inter = matrix_data["cached_valid_idxs"][cache_key] +# else: +# # 实时计算逻辑(同原代码) +# user_idx_inter = matrix_data["user_index_interaction"].get(user_id) +# user_idx_feature = matrix_data["user_index_feature"].get(user_id) # -# if not urls: -# raise ValueError("该类别下无热度记录,使用随机推荐") +# category_iids = 
matrix_data["category_to_iids"].get(category, []) +# valid_sketch_idxs_inter = [ +# idx for iid, idx in matrix_data["sketch_index_interaction"].items() +# if iid in category_iids +# ] # -# probs = softmax(scores) -# sample_size = min(num, len(urls)) -# sampled_urls = random.choices(urls, weights=probs, k=sample_size) +# # 处理交互分数 +# raw_inter_scores = [] +# if user_idx_inter is not None and valid_sketch_idxs_inter: +# raw_inter_scores = matrix_data["interaction_matrix"][user_idx_inter, valid_sketch_idxs_inter] +# processed_inter = raw_inter_scores * 0.7 # -# return sampled_urls +# # 处理特征分数 +# valid_sketch_idxs_feature = [ +# idx for iid, idx in matrix_data["sketch_index_feature"].items() +# if iid in category_iids +# ] +# raw_feat_scores = [] +# if user_idx_feature is not None and valid_sketch_idxs_feature: +# raw_feat_scores = matrix_data["feature_matrix"][user_idx_feature, valid_sketch_idxs_feature] +# raw_feat_scores = (raw_feat_scores - np.min(raw_feat_scores)) / ( +# np.max(raw_feat_scores) - np.min(raw_feat_scores) + 1e-8) +# processed_feat = raw_feat_scores +# else: +# processed_feat = np.array([]) # +# # 更新缓存 +# matrix_data["cached_scores"][cache_key] = (processed_inter, processed_feat) +# matrix_data["cached_valid_idxs"][cache_key] = valid_sketch_idxs_inter +# +# # 合并分数 +# if brand_id is not None: +# brand_idx_feature = matrix_data["brand_index_map"].get(brand_id) +# +# brand_feat_valid = ( +# matrix_data["brand_feature_matrix"].size > 0 and # 矩阵非空 +# brand_idx_feature is not None and +# valid_sketch_idxs_feature # 有可用索引 +# ) +# +# if brand_feat_valid: +# raw_brand_feat_scores = matrix_data["brand_feature_matrix"][ +# brand_idx_feature, valid_sketch_idxs_feature +# ] +# raw_brand_feat_scores = (raw_brand_feat_scores - np.min(raw_brand_feat_scores)) / ( +# np.max(raw_brand_feat_scores) - np.min(raw_brand_feat_scores) + 1e-8 +# ) +# processed_brand_feat = raw_brand_feat_scores +# +# # 如果 processed_feat 是空的,替换为全 0,避免 shape 不一致 +# if processed_feat.size == 0: 
+# processed_feat = np.zeros_like(processed_brand_feat) +# +# final_scores = processed_inter + 0.3 * ( +# (1 - brand_scale) * processed_feat + brand_scale * processed_brand_feat +# ) +# else: +# # brand 信息不可用 +# final_scores = processed_inter + 0.3 * processed_feat if processed_feat.size > 0 else processed_inter +# else: +# final_scores = processed_inter + 0.3 * processed_feat if processed_feat.size > 0 else processed_inter +# +# valid_sketch_idxs = matrix_data["cached_valid_idxs"][cache_key] +# +# # 概率采样 +# scores = np.array(final_scores) +# +# # 调整后的概率转换(带温度控制的softmax) +# def calibrated_softmax(scores, temperature=1.0): +# scores = scores / temperature +# scale = scores - max(scores) +# exps = np.exp(scale) +# return exps / np.sum(exps) +# +# probs = calibrated_softmax(scores, 0.09) +# +# chosen_indices = np.random.choice( +# len(valid_sketch_idxs), +# size=min(num_recommendations, len(valid_sketch_idxs)), +# p=probs, +# replace=False +# ) +# recommendations = [matrix_data["iid_to_sketch"][valid_sketch_idxs[idx]] for idx in chosen_indices] +# +# logger.info(f"推荐生成完成,耗时: {time.time() - start_time:.2f}秒") +# return recommendations # except Exception as e: -# # 回退:完全随机推荐 -# all_iids = list(matrix_data["iid_to_sketch"].keys()) -# category_iids = matrix_data["category_to_iids"].get(category, all_iids) -# sample_size = min(num, len(category_iids)) -# sampled = np.random.choice(category_iids, size=sample_size, replace=False) -# return [matrix_data["iid_to_sketch"][iid] for iid in sampled] +# logger.error(f"推荐失败: {str(e)}", exc_info=True) +# raise HTTPException(status_code=500, detail=str(e)) -def get_random_recommendations(category: str, num: int) -> List[str]: - """全品类随机推荐""" - all_iids = list(matrix_data["iid_to_sketch"].keys()) - # 优先从当前品类选择 - category_iids = matrix_data["category_to_iids"].get(category, all_iids) - # 确保不超出实际数量 - sample_size = min(num, len(category_iids)) - sampled = np.random.choice(category_iids, size=sample_size, replace=False) - return 
[matrix_data["iid_to_sketch"][iid] for iid in sampled] - - -@router.get("/recommend/{user_id}/{category}/{num_recommendations}/{brand_id}/{brand_scale}", response_model=List[str]) -async def get_recommendations(user_id: int, category: str, brand_id: int, brand_scale: float, num_recommendations: int = 10): - """ - @param user_id: 4 - @param category: female_skirt - @param num_recommendations: 1 - @return: - [ - "aida-sys-image/images/female/skirt/903000017.jpg" - ] - - """ +# @router.on_event("startup") +async def startup_event(): + """启动时初始化增量监听任务""" try: - logger.info(f"user_id:{user_id}-----category:{category}-----brand_id:{brand_id}-----brand_scale:{brand_scale}-----num_recommendations:{num_recommendations}") - start_time = time.time() - cache_key = (user_id, category) - # === 新增:用户存在性检查 === - user_exists_inter = user_id in matrix_data["user_index_interaction"] - user_exists_feat = user_id in matrix_data["user_index_feature"] - - # 任一矩阵不存在用户则返回随机推荐 - if not (user_exists_inter and user_exists_feat): - logger.info(f"用户 {user_id} 数据不完整,触发随机推荐") - return get_random_recommendations(category, num_recommendations) - - # 检查缓存 - if cache_key in matrix_data["cached_scores"]: - processed_inter, processed_feat = matrix_data["cached_scores"][cache_key] - valid_sketch_idxs_inter = matrix_data["cached_valid_idxs"][cache_key] - else: - # 实时计算逻辑(同原代码) - user_idx_inter = matrix_data["user_index_interaction"].get(user_id) - user_idx_feature = matrix_data["user_index_feature"].get(user_id) - - category_iids = matrix_data["category_to_iids"].get(category, []) - valid_sketch_idxs_inter = [ - idx for iid, idx in matrix_data["sketch_index_interaction"].items() - if iid in category_iids - ] - - # 处理交互分数 - raw_inter_scores = [] - if user_idx_inter is not None and valid_sketch_idxs_inter: - raw_inter_scores = matrix_data["interaction_matrix"][user_idx_inter, valid_sketch_idxs_inter] - processed_inter = raw_inter_scores * 0.7 - - # 处理特征分数 - valid_sketch_idxs_feature = [ - idx for iid, idx 
in matrix_data["sketch_index_feature"].items() - if iid in category_iids - ] - raw_feat_scores = [] - if user_idx_feature is not None and valid_sketch_idxs_feature: - raw_feat_scores = matrix_data["feature_matrix"][user_idx_feature, valid_sketch_idxs_feature] - raw_feat_scores = (raw_feat_scores - np.min(raw_feat_scores)) / ( - np.max(raw_feat_scores) - np.min(raw_feat_scores) + 1e-8) - processed_feat = raw_feat_scores - else: - processed_feat = np.array([]) - - # 更新缓存 - matrix_data["cached_scores"][cache_key] = (processed_inter, processed_feat) - matrix_data["cached_valid_idxs"][cache_key] = valid_sketch_idxs_inter - - # 合并分数 - if brand_id is not None: - brand_idx_feature = matrix_data["brand_index_map"].get(brand_id) - - brand_feat_valid = ( - matrix_data["brand_feature_matrix"].size > 0 and # 矩阵非空 - brand_idx_feature is not None and - valid_sketch_idxs_feature # 有可用索引 - ) - - if brand_feat_valid: - raw_brand_feat_scores = matrix_data["brand_feature_matrix"][ - brand_idx_feature, valid_sketch_idxs_feature - ] - raw_brand_feat_scores = (raw_brand_feat_scores - np.min(raw_brand_feat_scores)) / ( - np.max(raw_brand_feat_scores) - np.min(raw_brand_feat_scores) + 1e-8 - ) - processed_brand_feat = raw_brand_feat_scores - - # 如果 processed_feat 是空的,替换为全 0,避免 shape 不一致 - if processed_feat.size == 0: - processed_feat = np.zeros_like(processed_brand_feat) - - final_scores = processed_inter + 0.3 * ( - (1 - brand_scale) * processed_feat + brand_scale * processed_brand_feat - ) - else: - # brand 信息不可用 - final_scores = processed_inter + 0.3 * processed_feat if processed_feat.size > 0 else processed_inter - else: - final_scores = processed_inter + 0.3 * processed_feat if processed_feat.size > 0 else processed_inter - - valid_sketch_idxs = matrix_data["cached_valid_idxs"][cache_key] - - # 概率采样 - scores = np.array(final_scores) - - # 调整后的概率转换(带温度控制的softmax) - def calibrated_softmax(scores, temperature=1.0): - scores = scores / temperature - scale = scores - max(scores) - exps = 
np.exp(scale) - return exps / np.sum(exps) - - probs = calibrated_softmax(scores, 0.09) - - chosen_indices = np.random.choice( - len(valid_sketch_idxs), - size=min(num_recommendations, len(valid_sketch_idxs)), - p=probs, - replace=False - ) - recommendations = [matrix_data["iid_to_sketch"][valid_sketch_idxs[idx]] for idx in chosen_indices] - - logger.info(f"推荐生成完成,耗时: {time.time() - start_time:.2f}秒") - return recommendations - + # 确保 Milvus 集合已创建(若已存在则直接返回) + try: + create_collection() + except Exception as exc: + logger.error("Milvus 集合创建/检查失败,不影响服务继续启动: %s", exc, exc_info=True) + + # 配置定时任务 + scheduler = BackgroundScheduler() + start_background_listener(scheduler) + scheduler.start() + logger.info("增量监听定时任务已启动") except Exception as e: - logger.error(f"推荐失败: {str(e)}", exc_info=True) - raise HTTPException(status_code=500, detail=str(e)) + logger.error(f"启动增量监听任务失败: {e}", exc_info=True) + + +@router.get("/recommend/{user_id}/{category}", response_model=List[str]) +async def recommend( + user_id: int, + category: str, + style: Optional[str] = Query( + None, + description="风格样式(可选):若传入,则在利用分支对同 style 的候选进行加分", + ), +): + """新版推荐接口(Milvus + Redis 偏好向量)。""" + try: + results = get_new_recommendations(user_id, category, style) + path = results[0] if results else "" + return [path] + except Exception as e: + logger.error("新版推荐接口失败 [user=%s, category=%s]: %s", user_id, category, e, exc_info=True) + raise HTTPException(status_code=500, detail=str(e)) \ No newline at end of file diff --git a/app/service/recommend/service.py b/app/service/recommend/service.py index 96f3704..6fcb464 100644 --- a/app/service/recommend/service.py +++ b/app/service/recommend/service.py @@ -1,241 +1,240 @@ -# 预加载资源 -import logging -import time -from collections import defaultdict -import os -import json -import numpy as np - -from app.core.config import settings -from app.core.mysql_config import DB_CONFIG - -logger = logging.getLogger() -import pymysql -from concurrent.futures import 
ThreadPoolExecutor - -HEAT_VECTOR_FILE = 'heat_vectors_data/heat_vectors.json' # 可动态加载或配置 - -matrix_data = { - "interaction_matrix": None, - "feature_matrix": None, - "user_index_interaction": None, - "sketch_index_interaction": None, - "user_index_feature": None, - "sketch_index_feature": None, - "iid_to_sketch": None, - "category_to_iids": None, - "cached_scores": {}, - "cached_valid_idxs": {}, - "category_sketch_idxs_inter": None, - "category_sketch_idxs_feature": None, - "user_inter_full": dict(), - "user_feat_full": dict(), - "brand_feature_matrix": None, - "brand_index_map": None, - "heat_data": {}, -} - - -def load_resources(): - """加载所有矩阵和映射关系,并触发预缓存""" - try: - start_time = time.time() - - # 清空缓存 - matrix_data["cached_scores"].clear() - matrix_data["cached_valid_idxs"].clear() - - # 加载数据 - sketch_to_iid = np.load(f'{settings.RECOMMEND_PATH_PREFIX}sketch_to_iid.npy', allow_pickle=True).item() - matrix_data["iid_to_sketch"] = {v: k for k, v in sketch_to_iid.items()} - - matrix_data["interaction_matrix"] = np.load(f"{settings.RECOMMEND_PATH_PREFIX}interaction_matrix.npy", allow_pickle=True) - matrix_data["user_index_interaction"] = np.load(f"{settings.RECOMMEND_PATH_PREFIX}user_index_interaction_matrix.npy", allow_pickle=True).item() - matrix_data["sketch_index_interaction"] = np.load(f"{settings.RECOMMEND_PATH_PREFIX}sketch_index_interaction_matrix.npy", - allow_pickle=True).item() - - matrix_data["feature_matrix"] = np.load(f"{settings.RECOMMEND_PATH_PREFIX}feature_matrix.npy", allow_pickle=True) - - brand_feature_path = f"{settings.RECOMMEND_PATH_PREFIX}brand_feature_matrix.npy" - if os.path.exists(brand_feature_path): - matrix_data["brand_feature_matrix"] = np.load(brand_feature_path, allow_pickle=True) - else: - logger.warning("brand_feature_matrix 文件不存在,使用空数组") - matrix_data["brand_feature_matrix"] = np.array([]) - - # brand_index_map - brand_index_path = f"{settings.RECOMMEND_PATH_PREFIX}brand_index_map.npy" - if os.path.exists(brand_index_path): - 
matrix_data["brand_index_map"] = np.load(brand_index_path, allow_pickle=True).item() - else: - logger.warning("brand_index_map 文件不存在,使用空字典") - matrix_data["brand_index_map"] = {} - - matrix_data["user_index_feature"] = np.load(f"{settings.RECOMMEND_PATH_PREFIX}user_index_feature_matrix.npy", allow_pickle=True).item() - - matrix_data["sketch_index_feature"] = np.load(f"{settings.RECOMMEND_PATH_PREFIX}sketch_index_feature_matrix.npy", allow_pickle=True).item() - - category_to_iid_map = np.load(f"{settings.RECOMMEND_PATH_PREFIX}iid_to_category_interaction_matrix.npy", allow_pickle=True).item() - matrix_data["category_to_iids"] = defaultdict(list) - for iid, cat in category_to_iid_map.items(): - matrix_data["category_to_iids"][cat].append(iid) - - logger.info(f"资源加载完成,耗时: {time.time() - start_time:.2f}秒") - - # 触发预缓存 - precache_user_category() - - if os.path.exists(HEAT_VECTOR_FILE): - with open(HEAT_VECTOR_FILE, 'r', encoding='utf-8') as f: - heat_json = json.load(f) - matrix_data["heat_data"] = heat_json.get("data", {}) - logger.info(f"热度向量数据加载完成,共加载 {len(matrix_data['heat_data'])} 个类别") - else: - matrix_data["heat_data"] = {} - - except Exception as e: - logger.error(f"资源加载失败: {str(e)}") - raise RuntimeError("初始化失败") - - -def precache_user_category(): - """优化后的用户分类预缓存(添加耗时统计)""" - if not all([ - matrix_data["interaction_matrix"] is not None, - matrix_data["feature_matrix"] is not None, - matrix_data["user_index_interaction"] is not None - ]): - logger.warning("资源未加载完成,跳过预缓存") - return - - start_time = time.perf_counter() - time_stats = { - "get_all_user_categories": 0, - "process_user_category": 0, - "thread_execution": 0, - "cache_update": 0, - "total": 0, - } - - # 统计用户类别获取时间 - t1 = time.perf_counter() - user_categories = get_all_user_categories() - time_stats["get_all_user_categories"] = time.perf_counter() - t1 - - precached_count = 0 - - def process_user_category(user_id, categories): - """单用户类别缓存计算(统计耗时)""" - local_cache = {} - local_valid_idxs = {} - 
time.perf_counter() - - for category in categories: - cache_key = (user_id, category) - if cache_key in matrix_data["cached_scores"]: - continue - - try: - user_idx_inter = matrix_data["user_index_interaction"].get(user_id) - user_idx_feature = matrix_data["user_index_feature"].get(user_id) - - # 统计获取类别 IID 耗时 - t_iid = time.perf_counter() - category_iids = matrix_data["category_to_iids"].get(category, []) - valid_sketch_idxs_inter = [matrix_data["sketch_index_interaction"][iid] - for iid in category_iids if iid in matrix_data["sketch_index_interaction"]] - valid_sketch_idxs_feature = [matrix_data["sketch_index_feature"][iid] - for iid in category_iids if iid in matrix_data["sketch_index_feature"]] - time_stats["process_user_category"] += time.perf_counter() - t_iid - - # 统计矩阵计算耗时 - t_matrix = time.perf_counter() - processed_inter = np.zeros(len(valid_sketch_idxs_inter)) - if user_idx_inter is not None and valid_sketch_idxs_inter: - raw_inter_scores = matrix_data["interaction_matrix"][user_idx_inter, valid_sketch_idxs_inter] - processed_inter = raw_inter_scores * 0.7 - - processed_feat = np.zeros(len(valid_sketch_idxs_feature)) - if user_idx_feature is not None and valid_sketch_idxs_feature: - raw_feat_scores = matrix_data["feature_matrix"][user_idx_feature, valid_sketch_idxs_feature] - raw_feat_scores = (raw_feat_scores - np.min(raw_feat_scores)) / ( - np.max(raw_feat_scores) - np.min(raw_feat_scores) + 1e-8) - processed_feat = raw_feat_scores * 0.3 - time_stats["process_user_category"] += time.perf_counter() - t_matrix - - if len(processed_inter) == len(processed_feat): - local_cache[cache_key] = (processed_inter, processed_feat) - local_valid_idxs[cache_key] = valid_sketch_idxs_inter - - except Exception as e: - logger.error(f"预缓存失败 (user={user_id}, category={category}): {str(e)}") - - return local_cache, local_valid_idxs - - # 统计线程执行时间 - t2 = time.perf_counter() - with ThreadPoolExecutor(max_workers=8) as executor: - futures = 
{executor.submit(process_user_category, user_id, categories): user_id for user_id, categories in user_categories.items()} - for future in futures: - try: - t_cache = time.perf_counter() - cache_part, valid_idxs_part = future.result() - matrix_data["cached_scores"].update(cache_part) - matrix_data["cached_valid_idxs"].update(valid_idxs_part) - time_stats["cache_update"] += time.perf_counter() - t_cache - precached_count += len(cache_part) - except Exception as e: - logger.error(f"线程执行错误: {str(e)}") - time_stats["thread_execution"] = time.perf_counter() - t2 - - time_stats["total"] = time.perf_counter() - start_time - - # 输出统计信息 - logger.info(f""" - 预缓存完成,共缓存 {precached_count} 组数据,耗时统计如下: - - 获取用户类别数据: {time_stats["get_all_user_categories"]:.2f}s - - 计算用户类别缓存: {time_stats["process_user_category"]:.2f}s - - 线程任务执行: {time_stats["thread_execution"]:.2f}s - - 更新缓存数据: {time_stats["cache_update"]:.2f}s - - 总耗时: {time_stats["total"]:.2f}s - """) - - -def get_all_user_categories(): - """获取所有用户及其对应的分类""" - conn = None - try: - conn = pymysql.connect(**DB_CONFIG) - cursor = conn.cursor() - - query = """ - SELECT DISTINCT account_id, path - FROM user_preference_log_prediction \ - """ - cursor.execute(query) - results = cursor.fetchall() - - user_categories = defaultdict(set) - for account_id, path in results: - category = get_category_from_path(path) - user_categories[account_id].add(category) - - return dict(user_categories) - - except Exception as e: - logger.error(f"数据库查询失败: {str(e)}") - return {} - finally: - if conn: - conn.close() - - -def get_category_from_path(path: str) -> str: - """从路径解析类别""" - try: - parts = path.split('/') - if len(parts) >= 4: - return f"{parts[2]}_{parts[3]}" - return "unknown" - except: - return "unknown" +# # 预加载资源 +# import logging +# import time +# from collections import defaultdict +# import os +# import json +# import numpy as np +# +# from app.core.config import DB_CONFIG, RECOMMEND_PATH_PREFIX +# +# logger = logging.getLogger() +# import 
pymysql +# from concurrent.futures import ThreadPoolExecutor +# +# HEAT_VECTOR_FILE = 'heat_vectors_data/heat_vectors.json' # 可动态加载或配置 +# +# matrix_data = { +# "interaction_matrix": None, +# "feature_matrix": None, +# "user_index_interaction": None, +# "sketch_index_interaction": None, +# "user_index_feature": None, +# "sketch_index_feature": None, +# "iid_to_sketch": None, +# "category_to_iids": None, +# "cached_scores": {}, +# "cached_valid_idxs": {}, +# "category_sketch_idxs_inter": None, +# "category_sketch_idxs_feature": None, +# "user_inter_full": dict(), +# "user_feat_full": dict(), +# "brand_feature_matrix": None, +# "brand_index_map": None, +# "heat_data": {}, +# } +# +# +# def load_resources(): +# """加载所有矩阵和映射关系,并触发预缓存""" +# try: +# start_time = time.time() +# +# # 清空缓存 +# matrix_data["cached_scores"].clear() +# matrix_data["cached_valid_idxs"].clear() +# +# # 加载数据 +# sketch_to_iid = np.load(f'{RECOMMEND_PATH_PREFIX}sketch_to_iid.npy', allow_pickle=True).item() +# matrix_data["iid_to_sketch"] = {v: k for k, v in sketch_to_iid.items()} +# +# matrix_data["interaction_matrix"] = np.load(f"{RECOMMEND_PATH_PREFIX}interaction_matrix.npy", allow_pickle=True) +# matrix_data["user_index_interaction"] = np.load(f"{RECOMMEND_PATH_PREFIX}user_index_interaction_matrix.npy", allow_pickle=True).item() +# matrix_data["sketch_index_interaction"] = np.load(f"{RECOMMEND_PATH_PREFIX}sketch_index_interaction_matrix.npy", +# allow_pickle=True).item() +# +# matrix_data["feature_matrix"] = np.load(f"{RECOMMEND_PATH_PREFIX}feature_matrix.npy", allow_pickle=True) +# +# brand_feature_path = f"{RECOMMEND_PATH_PREFIX}brand_feature_matrix.npy" +# if os.path.exists(brand_feature_path): +# matrix_data["brand_feature_matrix"] = np.load(brand_feature_path, allow_pickle=True) +# else: +# logger.warning("brand_feature_matrix 文件不存在,使用空数组") +# matrix_data["brand_feature_matrix"] = np.array([]) +# +# # brand_index_map +# brand_index_path = f"{RECOMMEND_PATH_PREFIX}brand_index_map.npy" +# if 
os.path.exists(brand_index_path): +# matrix_data["brand_index_map"] = np.load(brand_index_path, allow_pickle=True).item() +# else: +# logger.warning("brand_index_map 文件不存在,使用空字典") +# matrix_data["brand_index_map"] = {} +# +# matrix_data["user_index_feature"] = np.load(f"{RECOMMEND_PATH_PREFIX}user_index_feature_matrix.npy", allow_pickle=True).item() +# +# matrix_data["sketch_index_feature"] = np.load(f"{RECOMMEND_PATH_PREFIX}sketch_index_feature_matrix.npy", allow_pickle=True).item() +# +# category_to_iid_map = np.load(f"{RECOMMEND_PATH_PREFIX}iid_to_category_interaction_matrix.npy", allow_pickle=True).item() +# matrix_data["category_to_iids"] = defaultdict(list) +# for iid, cat in category_to_iid_map.items(): +# matrix_data["category_to_iids"][cat].append(iid) +# +# logger.info(f"资源加载完成,耗时: {time.time() - start_time:.2f}秒") +# +# # 触发预缓存 +# precache_user_category() +# +# if os.path.exists(HEAT_VECTOR_FILE): +# with open(HEAT_VECTOR_FILE, 'r', encoding='utf-8') as f: +# heat_json = json.load(f) +# matrix_data["heat_data"] = heat_json.get("data", {}) +# logger.info(f"热度向量数据加载完成,共加载 {len(matrix_data['heat_data'])} 个类别") +# else: +# matrix_data["heat_data"] = {} +# +# except Exception as e: +# logger.error(f"资源加载失败: {str(e)}") +# raise RuntimeError("初始化失败") +# +# +# def precache_user_category(): +# """优化后的用户分类预缓存(添加耗时统计)""" +# if not all([ +# matrix_data["interaction_matrix"] is not None, +# matrix_data["feature_matrix"] is not None, +# matrix_data["user_index_interaction"] is not None +# ]): +# logger.warning("资源未加载完成,跳过预缓存") +# return +# +# start_time = time.perf_counter() +# time_stats = { +# "get_all_user_categories": 0, +# "process_user_category": 0, +# "thread_execution": 0, +# "cache_update": 0, +# "total": 0, +# } +# +# # 统计用户类别获取时间 +# t1 = time.perf_counter() +# user_categories = get_all_user_categories() +# time_stats["get_all_user_categories"] = time.perf_counter() - t1 +# +# precached_count = 0 +# +# def process_user_category(user_id, categories): +# 
"""单用户类别缓存计算(统计耗时)""" +# local_cache = {} +# local_valid_idxs = {} +# t_start = time.perf_counter() +# +# for category in categories: +# cache_key = (user_id, category) +# if cache_key in matrix_data["cached_scores"]: +# continue +# +# try: +# user_idx_inter = matrix_data["user_index_interaction"].get(user_id) +# user_idx_feature = matrix_data["user_index_feature"].get(user_id) +# +# # 统计获取类别 IID 耗时 +# t_iid = time.perf_counter() +# category_iids = matrix_data["category_to_iids"].get(category, []) +# valid_sketch_idxs_inter = [matrix_data["sketch_index_interaction"][iid] +# for iid in category_iids if iid in matrix_data["sketch_index_interaction"]] +# valid_sketch_idxs_feature = [matrix_data["sketch_index_feature"][iid] +# for iid in category_iids if iid in matrix_data["sketch_index_feature"]] +# time_stats["process_user_category"] += time.perf_counter() - t_iid +# +# # 统计矩阵计算耗时 +# t_matrix = time.perf_counter() +# processed_inter = np.zeros(len(valid_sketch_idxs_inter)) +# if user_idx_inter is not None and valid_sketch_idxs_inter: +# raw_inter_scores = matrix_data["interaction_matrix"][user_idx_inter, valid_sketch_idxs_inter] +# processed_inter = raw_inter_scores * 0.7 +# +# processed_feat = np.zeros(len(valid_sketch_idxs_feature)) +# if user_idx_feature is not None and valid_sketch_idxs_feature: +# raw_feat_scores = matrix_data["feature_matrix"][user_idx_feature, valid_sketch_idxs_feature] +# raw_feat_scores = (raw_feat_scores - np.min(raw_feat_scores)) / ( +# np.max(raw_feat_scores) - np.min(raw_feat_scores) + 1e-8) +# processed_feat = raw_feat_scores * 0.3 +# time_stats["process_user_category"] += time.perf_counter() - t_matrix +# +# if len(processed_inter) == len(processed_feat): +# local_cache[cache_key] = (processed_inter, processed_feat) +# local_valid_idxs[cache_key] = valid_sketch_idxs_inter +# +# except Exception as e: +# logger.error(f"预缓存失败 (user={user_id}, category={category}): {str(e)}") +# +# return local_cache, local_valid_idxs +# +# # 统计线程执行时间 +# 
t2 = time.perf_counter() +# with ThreadPoolExecutor(max_workers=8) as executor: +# futures = {executor.submit(process_user_category, user_id, categories): user_id for user_id, categories in user_categories.items()} +# for future in futures: +# try: +# t_cache = time.perf_counter() +# cache_part, valid_idxs_part = future.result() +# matrix_data["cached_scores"].update(cache_part) +# matrix_data["cached_valid_idxs"].update(valid_idxs_part) +# time_stats["cache_update"] += time.perf_counter() - t_cache +# precached_count += len(cache_part) +# except Exception as e: +# logger.error(f"线程执行错误: {str(e)}") +# time_stats["thread_execution"] = time.perf_counter() - t2 +# +# time_stats["total"] = time.perf_counter() - start_time +# +# # 输出统计信息 +# logger.info(f""" +# 预缓存完成,共缓存 {precached_count} 组数据,耗时统计如下: +# - 获取用户类别数据: {time_stats["get_all_user_categories"]:.2f}s +# - 计算用户类别缓存: {time_stats["process_user_category"]:.2f}s +# - 线程任务执行: {time_stats["thread_execution"]:.2f}s +# - 更新缓存数据: {time_stats["cache_update"]:.2f}s +# - 总耗时: {time_stats["total"]:.2f}s +# """) +# +# +# def get_all_user_categories(): +# """获取所有用户及其对应的分类""" +# conn = None +# try: +# conn = pymysql.connect(**DB_CONFIG) +# cursor = conn.cursor() +# +# query = """ +# SELECT DISTINCT account_id, path +# FROM user_preference_log_prediction +# """ +# cursor.execute(query) +# results = cursor.fetchall() +# +# user_categories = defaultdict(set) +# for account_id, path in results: +# category = get_category_from_path(path) +# user_categories[account_id].add(category) +# +# return dict(user_categories) +# +# except Exception as e: +# logger.error(f"数据库查询失败: {str(e)}") +# return {} +# finally: +# if conn: +# conn.close() +# +# +# def get_category_from_path(path: str) -> str: +# """从路径解析类别""" +# try: +# parts = path.split('/') +# if len(parts) >= 4: +# return f"{parts[2]}_{parts[3]}" +# return "unknown" +# except: +# return "unknown" diff --git a/app/service/recommendation_system/__init__.py 
"""
Recommendation-system configuration.

Centralises the tunable parameters, Milvus/Redis naming and the MySQL
connection settings used by the recommendation modules.
"""
import os
from app.core.config import (
    DB_CONFIG, DB_HOST, DB_PORT, DB_USERNAME, DB_PASSWORD, DB_NAME,
    REDIS_HOST, REDIS_PORT, REDIS_DB,
    MILVUS_URL, MILVUS_TOKEN, MILVUS_ALIAS,
    MINIO_URL, MINIO_ACCESS, MINIO_SECRET, MINIO_SECURE
)

# Milvus collection name
MILVUS_COLLECTION_SKETCH_VECTORS = "sketch_vectors_norm"

# Redis key prefix for per-user preference vectors
REDIS_KEY_USER_PREF_PREFIX = "user_pref"

# Tunable recommendation parameters
RECOMMENDATION_CONFIG = {
    # Half-life used for time-decay weighting of user actions.
    # Smaller values make recent behaviour dominate.
    "K_half": 20,

    # Explore/exploit split (0.0-1.0).
    # - Higher => the random "explore" branch fires more often (more variety).
    # - Lower => the preference-driven "exploit" branch dominates (more precise).
    # - Suggested range 0.3-0.7; raise towards 0.6-0.8 for extra randomness.
    "explore_ratio": 0.5,

    # Number of candidates returned by vector search.
    # Bigger pool = better recall but higher compute cost. Suggested 100-1000.
    "topk": 1000,

    # Bonus added to candidates sharing the user's style.
    # Larger values make style matches more likely to be picked;
    # lower it (e.g. 0.1 or 0.05) to reduce repeats of the same items.
    "style_bonus": 0.2,

    # Softmax sampling temperature.
    # - >1.0 flattens the distribution (more random, fewer repeats).
    # - <1.0 sharpens it (top-scored items dominate, more repeats).
    # - 1.0 is standard Softmax.
    # NOTE: the current default of 0.07 is deliberately very sharp —
    # sampling is close to greedy top-1. Raise towards 1.0-3.0 if more
    # diversity is wanted.
    "softmax_temperature": 0.07,

    # Incremental-listener polling interval (seconds)
    "listen_interval_sec": 30,

    # Batch size for bulk processing
    "batch_size": 1000,

    # Redis TTL in seconds (30 days)
    "redis_expire_seconds": 2592000,

    # Feature-vector dimensionality (2048-d ResNet-style embeddings)
    "vector_dim": 2048,
}

# Database table names
TABLE_USER_PREFERENCE_LOG = "user_preference_log_test"
TABLE_SYS_FILE = "t_sys_file"

# MySQL connection settings used by the recommendation system
MYSQL_CONFIG = {
    "host": DB_HOST,
    "port": DB_PORT,
    "user": DB_USERNAME,
    "password": DB_PASSWORD,
    "database": DB_NAME,
    "charset": "utf8mb4"
}
b/app/service/recommendation_system/import_sys_sketch_to_milvus.py new file mode 100644 index 0000000..b055089 --- /dev/null +++ b/app/service/recommendation_system/import_sys_sketch_to_milvus.py @@ -0,0 +1,331 @@ +""" +独立脚本:从 t_sys_file 导入系统图向量到 Milvus +可以单独运行,不依赖整个项目启动 + +使用方法: + python -m app.service.recommendation_system.import_sys_sketch_to_milvus + 或 + python app/service/recommendation_system/import_sys_sketch_to_milvus.py +""" +import sys +import os +import logging +import argparse +from pathlib import Path + +# 添加项目根目录到 Python 路径 +project_root = Path(__file__).parent.parent.parent.parent +sys.path.insert(0, str(project_root)) + +import numpy as np +import pymysql +from tqdm import tqdm + +from app.service.recommendation_system.config import ( + MYSQL_CONFIG, TABLE_SYS_FILE, + RECOMMENDATION_CONFIG, MILVUS_COLLECTION_SKETCH_VECTORS +) +from app.service.recommendation_system.vector_utils import extract_feature_vector, normalize_vector +from app.service.recommendation_system.milvus_client import create_collection, insert_vectors + +# 配置日志 +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(levelname)s - %(message)s', + handlers=[ + logging.StreamHandler(), + logging.FileHandler('import_sys_sketch.log', encoding='utf-8') + ] +) +logger = logging.getLogger(__name__) + + +def get_sys_file_records(conn, limit=None, offset=0): + """ + 从 t_sys_file 表获取系统图记录 + + Args: + conn: 数据库连接 + limit: 限制数量(None 表示不限制) + offset: 偏移量 + + Returns: + 记录列表,每个元素为 (id, url, style, level3_type, level2_type, deprecated) + """ + cursor = conn.cursor() + + query = f""" + SELECT id, url, style, level3_type, level2_type, deprecated + FROM {TABLE_SYS_FILE} + WHERE level1_type = 'Images' + AND style IS NOT NULL + AND style != '' + AND deprecated != 1 + ORDER BY id + """ + + if limit: + query += f" LIMIT {limit} OFFSET {offset}" + + cursor.execute(query) + records = cursor.fetchall() + cursor.close() + + return records + + +def get_total_count(conn): + """获取总记录数""" + cursor = 
conn.cursor() + cursor.execute(f""" + SELECT COUNT(*) + FROM {TABLE_SYS_FILE} + WHERE level1_type = 'Images' + AND style IS NOT NULL + AND style != '' + AND deprecated != 1 + """) + count = cursor.fetchone()[0] + cursor.close() + return count + + +def process_and_insert_batch(records, batch_size=1000, retry_times=3): + """ + 处理并批量插入向量 + + Args: + records: 记录列表 + batch_size: 批量大小 + retry_times: 失败重试次数 + + Returns: + (成功数量, 失败数量) + """ + success_count = 0 + failed_count = 0 + failed_records = [] + batch_data = [] + + # 使用 tqdm 显示进度 + with tqdm(total=len(records), desc="处理记录", unit="条") as pbar: + for idx, (sys_file_id, url, style, level3_type, level2_type, deprecated) in enumerate(records): + try: + # 计算 category + category = f"{level3_type.lower()}_{level2_type.lower()}" + + # 提取特征向量 + feature_vector = extract_feature_vector(url) + # 归一化,便于 IP≈cosine 度量 + feature_vector = normalize_vector(feature_vector) + + # 检查向量是否有效 + if np.all(feature_vector == 0): + logger.warning(f"向量提取失败,跳过: {url} (id={sys_file_id})") + failed_count += 1 + failed_records.append((sys_file_id, url)) + pbar.update(1) + continue + + # 准备数据 + data_item = { + "path": url, + "sys_file_id": sys_file_id, + "style": style, + "category": category, + "is_system_sketch": 1, + "deprecated": deprecated if deprecated else 0, + "feature_vector": feature_vector.tolist() + } + + batch_data.append(data_item) + + # 批量写入 + if len(batch_data) >= batch_size: + try: + insert_vectors(batch_data) + success_count += len(batch_data) + batch_data = [] + logger.info(f"已成功插入 {success_count} 条记录") + except Exception as e: + logger.error(f"批量写入失败: {e}") + failed_count += len(batch_data) + failed_records.extend([(item["sys_file_id"], item["path"]) for item in batch_data]) + batch_data = [] + + pbar.update(1) + + except Exception as e: + logger.error(f"处理记录失败 [id={sys_file_id}, url={url}]: {e}") + failed_count += 1 + failed_records.append((sys_file_id, url)) + pbar.update(1) + + # 写入剩余数据 + if batch_data: + try: + 
insert_vectors(batch_data) + success_count += len(batch_data) + logger.info(f"写入剩余 {len(batch_data)} 条记录") + except Exception as e: + logger.error(f"写入剩余数据失败: {e}") + failed_count += len(batch_data) + failed_records.extend([(item["sys_file_id"], item["path"]) for item in batch_data]) + + # 重试失败记录 + if failed_records and retry_times > 0: + logger.info(f"开始重试 {len(failed_records)} 条失败记录,最多重试 {retry_times} 次...") + + for retry in range(retry_times): + if not failed_records: + break + + retry_failed = [] + with tqdm(total=len(failed_records), desc=f"重试第 {retry + 1} 次", unit="条") as pbar: + for sys_file_id, url in failed_records: + try: + # 重新查询记录信息 + conn = pymysql.connect(**MYSQL_CONFIG) + cursor = conn.cursor() + cursor.execute(f""" + SELECT id, url, style, level3_type, level2_type, deprecated + FROM {TABLE_SYS_FILE} + WHERE id = %s + """, (sys_file_id,)) + record = cursor.fetchone() + cursor.close() + conn.close() + + if not record: + retry_failed.append((sys_file_id, url)) + pbar.update(1) + continue + + sys_file_id, url, style, level3_type, level2_type, deprecated = record + category = f"{level3_type.lower()}_{level2_type.lower()}" + + feature_vector = extract_feature_vector(url) + feature_vector = normalize_vector(feature_vector) + if np.all(feature_vector == 0): + retry_failed.append((sys_file_id, url)) + pbar.update(1) + continue + + data_item = { + "path": url, + "sys_file_id": sys_file_id, + "style": style, + "category": category, + "is_system_sketch": 1, + "deprecated": deprecated if deprecated else 0, + "feature_vector": feature_vector.tolist() + } + + insert_vectors([data_item]) + success_count += 1 + failed_count -= 1 + pbar.update(1) + + except Exception as e: + logger.error(f"重试失败 [id={sys_file_id}, url={url}]: {e}") + retry_failed.append((sys_file_id, url)) + pbar.update(1) + + failed_records = retry_failed + if failed_records: + logger.warning(f"第 {retry + 1} 次重试后仍有 {len(failed_records)} 条记录失败") + + return success_count, failed_count, failed_records + 
+ +def main(): + """主函数""" + parser = argparse.ArgumentParser(description='从 t_sys_file 导入系统图向量到 Milvus') + parser.add_argument('--batch-size', type=int, default=1000, help='批量处理大小(默认:1000)') + parser.add_argument('--retry-times', type=int, default=3, help='失败重试次数(默认:3)') + parser.add_argument('--limit', type=int, default=None, help='限制处理数量(用于测试,默认:不限制)') + parser.add_argument('--offset', type=int, default=0, help='起始偏移量(默认:0)') + parser.add_argument('--skip-create-collection', action='store_true', help='跳过创建集合(如果集合已存在)') + + args = parser.parse_args() + + logger.info("=" * 60) + logger.info("开始从 t_sys_file 导入系统图向量到 Milvus") + logger.info("=" * 60) + logger.info(f"配置参数:") + logger.info(f" - 批量大小: {args.batch_size}") + logger.info(f" - 重试次数: {args.retry_times}") + logger.info(f" - 限制数量: {args.limit if args.limit else '不限制'}") + logger.info(f" - 起始偏移: {args.offset}") + logger.info("=" * 60) + + # 1. 创建 Milvus 集合 + if not args.skip_create_collection: + logger.info("创建 Milvus 集合...") + try: + create_collection() + logger.info("Milvus 集合创建成功(或已存在)") + except Exception as e: + logger.error(f"创建 Milvus 集合失败: {e}") + return + else: + logger.info("跳过创建集合") + + # 2. 连接数据库 + logger.info("连接数据库...") + try: + conn = pymysql.connect(**MYSQL_CONFIG) + logger.info("数据库连接成功") + except Exception as e: + logger.error(f"数据库连接失败: {e}") + return + + try: + # 3. 获取总记录数 + total_count = get_total_count(conn) + logger.info(f"找到 {total_count} 条系统图记录") + + if total_count == 0: + logger.warning("没有找到系统图数据") + return + + # 4. 获取记录 + logger.info("获取记录...") + records = get_sys_file_records(conn, limit=args.limit, offset=args.offset) + logger.info(f"获取到 {len(records)} 条记录") + + if not records: + logger.warning("没有获取到记录") + return + + # 5. 处理并插入 + logger.info("开始处理记录...") + success_count, failed_count, failed_records = process_and_insert_batch( + records, + batch_size=args.batch_size, + retry_times=args.retry_times + ) + + # 6. 
输出结果 + logger.info("=" * 60) + logger.info("导入完成!") + logger.info(f" - 成功: {success_count} 条") + logger.info(f" - 失败: {failed_count} 条") + if failed_records: + logger.warning(f" - 失败记录列表(前10条):") + for sys_file_id, url in failed_records[:10]: + logger.warning(f" ID={sys_file_id}, URL={url}") + if len(failed_records) > 10: + logger.warning(f" ... 还有 {len(failed_records) - 10} 条失败记录") + logger.info("=" * 60) + + except Exception as e: + logger.error(f"处理过程中发生错误: {e}", exc_info=True) + finally: + conn.close() + logger.info("数据库连接已关闭") + + +if __name__ == "__main__": + main() + diff --git a/app/service/recommendation_system/incremental_listener.py b/app/service/recommendation_system/incremental_listener.py new file mode 100644 index 0000000..08c3b21 --- /dev/null +++ b/app/service/recommendation_system/incremental_listener.py @@ -0,0 +1,343 @@ +""" +增量监听模块 +实时监听 user_preference_log_test 表的新增记录,更新用户偏好向量 +""" +import logging +import math +import pymysql +import numpy as np +from typing import List, Dict, Set, Tuple, Optional +from datetime import datetime +from collections import defaultdict + +from apscheduler.schedulers.background import BackgroundScheduler +from apscheduler.schedulers.blocking import BlockingScheduler + +from app.service.recommendation_system.config import ( + MYSQL_CONFIG, TABLE_USER_PREFERENCE_LOG, TABLE_SYS_FILE, + RECOMMENDATION_CONFIG, REDIS_KEY_USER_PREF_PREFIX +) +from app.service.recommendation_system.vector_utils import extract_feature_vector, compute_weighted_average, normalize_vector +from app.service.recommendation_system.milvus_client import query_vectors_by_paths, insert_vectors +from app.service.utils.redis_utils import Redis +import json + +logger = logging.getLogger(__name__) + + +class IncrementalListener: + """增量监听器""" + + def __init__(self): + self.last_process_time = None + self.processed_combinations: Set[Tuple[int, str]] = set() # 已处理的 (account_id, category) 组合 + self.listen_interval = 
RECOMMENDATION_CONFIG["listen_interval_sec"] + + def get_new_like_records(self) -> List[Tuple]: + """ + 获取新增点赞记录 + + Returns: + 记录列表,每个元素为 (id, account_id, path, category, style, data_time, is_system_sketch, sys_file_id) + """ + conn = None + try: + conn = pymysql.connect(**MYSQL_CONFIG) + cursor = conn.cursor() + + if self.last_process_time is None: + # 第一次运行,查询最近30分钟的数据 + cursor.execute(f""" + SELECT id, account_id, path, category, style, data_time, is_system_sketch, sys_file_id + FROM {TABLE_USER_PREFERENCE_LOG} + WHERE data_time > DATE_SUB(NOW(), INTERVAL 30 MINUTE) + ORDER BY data_time + """) + else: + # 基于上次处理时间查询 + cursor.execute(f""" + SELECT id, account_id, path, category, style, data_time, is_system_sketch, sys_file_id + FROM {TABLE_USER_PREFERENCE_LOG} + WHERE data_time > %s + ORDER BY data_time + """, (self.last_process_time,)) + + records = cursor.fetchall() + return records + + except Exception as e: + logger.error(f"获取新增点赞记录失败: {e}", exc_info=True) + return [] + finally: + if conn: + conn.close() + + def process_new_records(self, records: List[Tuple]): + """ + 处理新增记录 + + Args: + records: 记录列表 + """ + if not records: + return + + # 按用户+类别分组 + user_category_records = defaultdict(list) + for record in records: + account_id = record[1] + category = record[3] + if category: # 只处理有类别的记录 + user_category_records[(account_id, category)].append(record) + + # 去重:只处理一次每个 (account_id, category) 组合 + to_process = [] + for (account_id, category), recs in user_category_records.items(): + if (account_id, category) not in self.processed_combinations: + to_process.append((account_id, category, recs)) + self.processed_combinations.add((account_id, category)) + + logger.info(f"需要处理 {len(to_process)} 个用户-类别组合") + + # 处理每个组合 + for account_id, category, recs in to_process: + try: + self.update_user_preference_vector(account_id, category) + except Exception as e: + logger.error(f"更新用户偏好向量失败 [user={account_id}, category={category}]: {e}", exc_info=True) + + # 更新最后处理时间 + if 
records: + self.last_process_time = records[-1][5] # data_time + # 重置去重集合,确保下次周期不会跳过同一用户-类别 + self.processed_combinations.clear() + + def update_user_preference_vector(self, account_id: int, category: str): + """ + 更新用户偏好向量 + + Args: + account_id: 用户ID + category: 类别 + """ + conn = None + try: + conn = pymysql.connect(**MYSQL_CONFIG) + cursor = conn.cursor() + + # 1. 获取该用户该类别的所有点赞记录 + cursor.execute(f""" + SELECT path, data_time + FROM {TABLE_USER_PREFERENCE_LOG} + WHERE account_id = %s AND category = %s + ORDER BY data_time DESC + """, (account_id, category)) + + like_records = cursor.fetchall() + + if not like_records: + return + + # 2. 批量查询点赞次数 + paths = [r[0] for r in like_records] + placeholders = ','.join(['%s'] * len(paths)) + cursor.execute(f""" + SELECT path, COUNT(*) as like_count + FROM {TABLE_USER_PREFERENCE_LOG} + WHERE account_id = %s AND category = %s AND path IN ({placeholders}) + GROUP BY path + """, (account_id, category) + tuple(paths)) + + like_counts = {row[0]: row[1] for row in cursor.fetchall()} + + # 3. 批量获取向量 + vectors_dict = query_vectors_by_paths(paths) + + # 处理查询不到的 path(新用户图或异常情况) + missing_paths = [p for p in paths if p not in vectors_dict] + if missing_paths: + logger.info(f"用户 {account_id} 类别 {category} 有 {len(missing_paths)} 个 path 需要实时计算向量") + self._compute_and_insert_missing_vectors(missing_paths, conn) + # 重新查询 + vectors_dict = query_vectors_by_paths(paths) + + # 4. 
计算权重并加权平均 + vectors = [] + weights = [] + K_half = RECOMMENDATION_CONFIG["K_half"] + + for k, (path, data_time) in enumerate(like_records, 1): + if path not in vectors_dict: + continue + + vector_data = vectors_dict[path] + feature_vector = np.array(vector_data["feature_vector"]) + + # 时间衰减权重 + d_k = 0.5 ** (k / K_half) + + # 点赞次数权重 + like_count = like_counts.get(path, 1) + p_i = 1 + math.log(1 + like_count) + + # 综合权重 + w_i = d_k * p_i + + vectors.append(feature_vector) + weights.append(w_i) + + if not vectors: + logger.warning(f"用户 {account_id} 类别 {category} 没有有效向量") + return + + # 5. 计算加权平均并做 L2 归一化,IP≈cosine + preference_vector = compute_weighted_average(vectors, weights) + preference_vector = normalize_vector(preference_vector) + + # 6. 写入 Redis + key = f"{REDIS_KEY_USER_PREF_PREFIX}:{account_id}:{category}" + vector_json = json.dumps(preference_vector.tolist()) + Redis.write( + key=key, + value=vector_json, + expire=RECOMMENDATION_CONFIG["redis_expire_seconds"] + ) + + logger.debug(f"用户偏好向量更新成功 [user={account_id}, category={category}]") + + except Exception as e: + logger.error(f"更新用户偏好向量失败 [user={account_id}, category={category}]: {e}", exc_info=True) + raise + finally: + if conn: + conn.close() + + def _compute_and_insert_missing_vectors(self, paths: List[str], conn: pymysql.connections.Connection): + """ + 计算并插入缺失的向量 + + Args: + paths: 缺失的 path 列表 + conn: 数据库连接 + """ + cursor = conn.cursor() + data_to_insert = [] + + for path in paths: + try: + # 判断数据来源(查询 t_sys_file 表) + cursor.execute(f""" + SELECT id, url, style, level3_type, level2_type, deprecated + FROM {TABLE_SYS_FILE} + WHERE url = %s + LIMIT 1 + """, (path,)) + + sys_file = cursor.fetchone() + + # 提取特征向量 + feature_vector = extract_feature_vector(path) + + if np.all(feature_vector == 0): + logger.warning(f"向量提取失败,跳过: {path}") + continue + + if sys_file: + # 系统图 + sys_file_id, url, style, level3_type, level2_type, deprecated = sys_file + category = f"{level3_type.lower()}_{level2_type.lower()}" + + 
data_item = { + "path": path, + "sys_file_id": sys_file_id, + "style": style, + "category": category, + "is_system_sketch": 1, + "deprecated": deprecated if deprecated else 0, + "feature_vector": feature_vector.tolist() + } + else: + # 用户图 + # 从 user_preference_log_test 获取 category(如果有) + cursor.execute(f""" + SELECT category + FROM {TABLE_USER_PREFERENCE_LOG} + WHERE path = %s AND category IS NOT NULL + LIMIT 1 + """, (path,)) + + category_result = cursor.fetchone() + category = category_result[0] if category_result else None + + data_item = { + "path": path, + "sys_file_id": None, + "style": None, + "category": category, + "is_system_sketch": 0, + "deprecated": 0, + "feature_vector": feature_vector.tolist() + } + + data_to_insert.append(data_item) + + except Exception as e: + logger.error(f"处理缺失向量失败 [{path}]: {e}") + + # 批量插入 + if data_to_insert: + try: + insert_vectors(data_to_insert) + logger.info(f"成功插入 {len(data_to_insert)} 个缺失向量") + except Exception as e: + logger.error(f"插入缺失向量失败: {e}") + + def process_once(self): + """单次轮询任务,供调度器调用""" + try: + records = self.get_new_like_records() + + if records: + logger.info(f"发现 {len(records)} 条新增记录") + self.process_new_records(records) + else: + logger.debug("没有新增记录") + except Exception as e: + logger.error(f"监听轮询异常: {e}", exc_info=True) + + +def start_background_listener(scheduler: BackgroundScheduler): + """将增量监听任务注册到后台调度器""" + listener = IncrementalListener() + scheduler.add_job( + listener.process_once, + "interval", + seconds=listener.listen_interval, + max_instances=1, + coalesce=True, + id="recommendation_incremental_listener", + replace_existing=True, + ) + logger.info("增量监听任务已注册到调度器") + + +def start_blocking_listener(): + """以阻塞方式启动调度器(用于独立脚本运行)""" + listener = IncrementalListener() + scheduler = BlockingScheduler() + scheduler.add_job( + listener.process_once, + "interval", + seconds=listener.listen_interval, + max_instances=1, + coalesce=True, + id="recommendation_incremental_listener", + 
"""
Milvus client wrapper for the recommendation system.
"""
import logging
from typing import List, Dict, Optional, Any
import numpy as np
from pymilvus import MilvusClient, FieldSchema, CollectionSchema, DataType, connections, Collection

from app.core.config import MILVUS_URL, MILVUS_TOKEN, MILVUS_ALIAS
from app.service.recommendation_system.config import MILVUS_COLLECTION_SKETCH_VECTORS, RECOMMENDATION_CONFIG

logger = logging.getLogger(__name__)

# Lazily-created, process-wide Milvus client
_milvus_client = None


def get_milvus_client() -> MilvusClient:
    """Return the shared MilvusClient, creating and caching it on first call.

    Raises:
        Exception: re-raises any connection failure from MilvusClient.
    """
    global _milvus_client
    if _milvus_client is None:
        try:
            _milvus_client = MilvusClient(
                uri=MILVUS_URL,
                token=MILVUS_TOKEN,
                db_name=MILVUS_ALIAS
            )
            logger.info("Milvus 客户端连接成功")
        except Exception as e:
            logger.error(f"Milvus 客户端连接失败: {e}")
            raise
    return _milvus_client


def create_collection():
    """
    Create the sketch-vector collection (named by
    MILVUS_COLLECTION_SKETCH_VECTORS, currently "sketch_vectors_norm")
    if it does not already exist.

    Schema:
        - path (PK, varchar(512)): MinIO logical URL
        - sys_file_id (int64): system file id
        - style (varchar(50)): style label
        - category (varchar(50)): category label
        - is_system_sketch (int8): 1 = system sketch, 0 = user sketch
        - deprecated (int8): 1 = deprecated
        - feature_vector (FloatVector(2048)): feature embedding
    """
    client = get_milvus_client()

    # Nothing to do if the collection already exists.
    collections = client.list_collections()
    if MILVUS_COLLECTION_SKETCH_VECTORS in collections:
        logger.info(f"集合 {MILVUS_COLLECTION_SKETCH_VECTORS} 已存在")
        return

    try:
        # Parse host/port out of the Milvus URL (e.g. http://host:19530).
        url_clean = MILVUS_URL.replace("http://", "").replace("https://", "")
        if ":" in url_clean:
            host, port_str = url_clean.split(":", 1)
            port = int(port_str)
        else:
            host = url_clean
            port = 19530

        # Collection creation goes through the legacy ORM API (more reliable
        # than MilvusClient for schema/index setup).
        try:
            connections.connect(
                alias=MILVUS_ALIAS,
                host=host,
                port=port,
                token=MILVUS_TOKEN if MILVUS_TOKEN else None
            )
            logger.info(f"已连接到 Milvus: {host}:{port}")
        except Exception as conn_e:
            # An already-established connection is fine; anything else is only
            # warned about because the MilvusClient path may still work.
            if "already exists" in str(conn_e).lower() or "Connection already exists" in str(conn_e):
                logger.info("Milvus 连接已存在")
            else:
                logger.warning(f"连接 Milvus 时出现警告: {conn_e}")

        fields = [
            FieldSchema(name="path", dtype=DataType.VARCHAR, is_primary=True, max_length=512),
            FieldSchema(name="sys_file_id", dtype=DataType.INT64),
            FieldSchema(name="style", dtype=DataType.VARCHAR, max_length=50),
            FieldSchema(name="category", dtype=DataType.VARCHAR, max_length=50),
            FieldSchema(name="is_system_sketch", dtype=DataType.INT8),
            FieldSchema(name="deprecated", dtype=DataType.INT8),
            FieldSchema(
                name="feature_vector",
                dtype=DataType.FLOAT_VECTOR,
                dim=RECOMMENDATION_CONFIG["vector_dim"]
            )
        ]

        schema = CollectionSchema(
            fields=fields,
            description="Sketch vectors collection for recommendation system"
        )

        collection = Collection(
            name=MILVUS_COLLECTION_SKETCH_VECTORS,
            schema=schema,
            using=MILVUS_ALIAS
        )

        # IP (inner product) metric matches the search path; with L2-normalised
        # vectors IP is equivalent to cosine similarity.
        index_params = {
            "metric_type": "IP",
            "index_type": "IVF_FLAT",
            "params": {"nlist": 1024}
        }

        collection.create_index(
            field_name="feature_vector",
            index_params=index_params
        )

        logger.info(f"集合 {MILVUS_COLLECTION_SKETCH_VECTORS} 创建成功")

    except Exception as e:
        logger.error(f"创建集合失败: {e}", exc_info=True)
        raise
def quote_milvus_literal(value: str) -> str:
    """Escape a string for safe embedding inside a single-quoted Milvus
    filter literal.

    Backslashes and single quotes are escaped so that data values cannot
    terminate the quoted literal early (breaking — or injecting into — the
    boolean filter expression).
    """
    return value.replace("\\", "\\\\").replace("'", "\\'")


def query_vectors_by_paths(paths: List[str]) -> Dict[str, Dict]:
    """
    Batch-fetch stored rows by primary-key path.

    Args:
        paths: list of MinIO logical URLs (collection primary keys).

    Returns:
        Mapping of path -> full Milvus row (including ``feature_vector``).
        Empty dict for empty input or on query failure.
    """
    if not paths:
        return {}

    client = get_milvus_client()

    try:
        # Escape each value before interpolation: paths are data, not code.
        quoted = ", ".join(f"'{quote_milvus_literal(p)}'" for p in paths)
        filter_expr = f"path in [{quoted}]"

        rows = client.query(
            collection_name=MILVUS_COLLECTION_SKETCH_VECTORS,
            filter=filter_expr,
            output_fields=["path", "feature_vector", "style", "category",
                           "sys_file_id", "is_system_sketch", "deprecated"]
        )

        return {row["path"]: row for row in rows}
    except Exception as e:
        logger.error(f"查询向量失败: {e}", exc_info=True)
        return {}


def search_similar_vectors(
    query_vector: np.ndarray,
    category: str,
    topk: int = 500,
    style: Optional[str] = None
) -> List[Dict]:
    """
    ANN search over the sketch collection, filtered by category (and,
    optionally, style); deprecated entries are excluded.

    Args:
        query_vector: query embedding (IP metric; ~cosine when normalised).
        category: exact category filter.
        topk: maximum number of hits to return.
        style: optional exact style filter.

    Returns:
        List of dicts with path/score/style/category/sys_file_id fields;
        empty list on failure.
    """
    client = get_milvus_client()

    try:
        # Values are escaped before interpolation so quotes in the data
        # cannot break or inject into the filter expression.
        filter_expr = f"category == '{quote_milvus_literal(category)}' && deprecated == 0"
        if style:
            filter_expr += f" && style == '{quote_milvus_literal(style)}'"

        results = client.search(
            collection_name=MILVUS_COLLECTION_SKETCH_VECTORS,
            data=[query_vector.tolist()],
            anns_field="feature_vector",
            search_params={"metric_type": "IP", "params": {"nprobe": 10}},
            limit=topk,
            filter=filter_expr,
            output_fields=["path", "style", "category", "sys_file_id"]
        )

        formatted_results = []
        if results and len(results) > 0:
            for hit in results[0]:
                entity = hit.get("entity", {})
                formatted_results.append({
                    "path": entity.get("path", ""),
                    "score": hit.get("distance", 0.0),
                    "style": entity.get("style", ""),
                    "category": entity.get("category", ""),
                    "sys_file_id": entity.get("sys_file_id")
                })

        return formatted_results
    except Exception as e:
        logger.error(f"向量检索失败: {e}", exc_info=True)
        return []


def query_random_candidates(category: str, style: Optional[str] = None, limit: int = 10) -> List[Dict]:
    """
    Fetch up to ``limit`` random candidates for the exploration branch.

    Args:
        category: exact category filter.
        style: optional exact style filter.
        limit: number of candidates to return.

    Returns:
        List of candidate dicts (path/style/category); empty list on failure.
    """
    import random

    client = get_milvus_client()

    try:
        filter_expr = f"category == '{quote_milvus_literal(category)}' && deprecated == 0"
        if style:
            filter_expr += f" && style == '{quote_milvus_literal(style)}'"

        # Milvus has no server-side random ordering: pull a large slice
        # first, then sample locally.
        rows = client.query(
            collection_name=MILVUS_COLLECTION_SKETCH_VECTORS,
            filter=filter_expr,
            output_fields=["path", "style", "category"],
            limit=10000
        )

        if len(rows) > limit:
            rows = random.sample(rows, limit)

        return rows
    except Exception as e:
        logger.error(f"随机查询候选失败: {e}", exc_info=True)
        return []
app.service.recommendation_system.vector_utils import extract_feature_vector, normalize_vector, compute_weighted_average +from app.service.recommendation_system.milvus_client import ( + create_collection, insert_vectors, query_vectors_by_paths +) +from app.service.utils.redis_utils import Redis +import json + +logger = logging.getLogger(__name__) + + +def optimize_database_table(): + """ + 优化 user_preference_log_test 表结构 + 添加冗余字段和索引 + """ + conn = None + try: + conn = pymysql.connect(**MYSQL_CONFIG) + cursor = conn.cursor() + + # 1. 添加冗余字段 + logger.info("添加冗余字段...") + alter_sqls = [ + f"ALTER TABLE {TABLE_USER_PREFERENCE_LOG} ADD COLUMN category VARCHAR(100) COMMENT '类别:lower(level3_type + \"_\" + level2_type)'", + f"ALTER TABLE {TABLE_USER_PREFERENCE_LOG} ADD COLUMN style VARCHAR(50) COMMENT '风格样式'", + f"ALTER TABLE {TABLE_USER_PREFERENCE_LOG} ADD COLUMN is_system_sketch TINYINT(1) DEFAULT 1 COMMENT '是否为系统图(1-是,0-用户图)'", + f"ALTER TABLE {TABLE_USER_PREFERENCE_LOG} ADD COLUMN sys_file_id BIGINT NULL COMMENT '系统文件ID'", + ] + + for sql in alter_sqls: + try: + cursor.execute(sql) + logger.info(f"执行成功: {sql[:50]}...") + except Exception as e: + if "Duplicate column name" in str(e): + logger.info(f"字段已存在,跳过: {sql[:50]}...") + else: + logger.warning(f"执行失败: {sql[:50]}... 错误: {e}") + + # 2. 
创建索引(MySQL 不支持 IF NOT EXISTS,需要先检查) + logger.info("创建索引...") + index_definitions = [ + ("idx_account_category_time", ["account_id", "category", "data_time"]), + ("idx_account_path", ["account_id", "path"]), + ] + + for index_name, columns in index_definitions: + try: + # 检查索引是否已存在 + cursor.execute(f""" + SELECT COUNT(*) + FROM information_schema.statistics + WHERE table_schema = DATABASE() + AND table_name = '{TABLE_USER_PREFERENCE_LOG}' + AND index_name = '{index_name}' + """) + exists = cursor.fetchone()[0] > 0 + + if exists: + logger.info(f"索引已存在,跳过: {index_name}") + else: + # 创建索引 + columns_str = ', '.join(columns) + create_sql = f"CREATE INDEX {index_name} ON {TABLE_USER_PREFERENCE_LOG}({columns_str})" + cursor.execute(create_sql) + logger.info(f"索引创建成功: {index_name}") + except Exception as e: + logger.warning(f"索引创建失败: {index_name} 错误: {e}") + + conn.commit() + logger.info("数据库表结构优化完成") + + except Exception as e: + logger.error(f"数据库表结构优化失败: {e}", exc_info=True) + if conn: + conn.rollback() + finally: + if conn: + conn.close() + + +def migrate_historical_data(batch_size: int = 1000): + """ + 历史数据迁移:批量更新冗余字段 + + Args: + batch_size: 每批处理数量 + """ + conn = None + try: + conn = pymysql.connect(**MYSQL_CONFIG) + cursor = conn.cursor() + + # 查询需要更新的记录数 + cursor.execute(f""" + SELECT COUNT(*) + FROM {TABLE_USER_PREFERENCE_LOG} u + WHERE u.category IS NULL + """) + total_count = cursor.fetchone()[0] + logger.info(f"需要迁移的记录数: {total_count}") + + if total_count == 0: + logger.info("无需迁移数据") + return + + # 分批处理 + offset = 0 + processed = 0 + + while offset < total_count: + # 查询一批记录 + cursor.execute(f""" + SELECT u.id, u.path + FROM {TABLE_USER_PREFERENCE_LOG} u + WHERE u.category IS NULL + LIMIT {batch_size} OFFSET {offset} + """) + records = cursor.fetchall() + + if not records: + break + + # 批量更新 + for record_id, path in records: + # 查询 t_sys_file 表 + cursor.execute(f""" + SELECT id, url, style, level3_type, level2_type, deprecated + FROM {TABLE_SYS_FILE} + WHERE url = 
%s + LIMIT 1 + """, (path,)) + + sys_file = cursor.fetchone() + + if sys_file: + # 系统图 + sys_file_id, url, style, level3_type, level2_type, deprecated = sys_file + category = f"{level3_type.lower()}_{level2_type.lower()}" + + cursor.execute(f""" + UPDATE {TABLE_USER_PREFERENCE_LOG} + SET category = %s, + style = %s, + is_system_sketch = 1, + sys_file_id = %s + WHERE id = %s + """, (category, style, sys_file_id, record_id)) + else: + # 用户图 + cursor.execute(f""" + UPDATE {TABLE_USER_PREFERENCE_LOG} + SET is_system_sketch = 0, + category = NULL, + style = NULL, + sys_file_id = NULL + WHERE id = %s + """, (record_id,)) + + conn.commit() + processed += len(records) + offset += batch_size + logger.info(f"已迁移 {processed}/{total_count} 条记录") + + logger.info("历史数据迁移完成") + + except Exception as e: + logger.error(f"历史数据迁移失败: {e}", exc_info=True) + if conn: + conn.rollback() + finally: + if conn: + conn.close() + + +def precompute_system_sketch_vectors(batch_size: int = 1000, retry_times: int = 3): + """ + 系统图向量预计算与导入 + + Args: + batch_size: 每批处理数量 + retry_times: 失败重试次数 + """ + conn = None + try: + conn = pymysql.connect(**MYSQL_CONFIG) + cursor = conn.cursor() + + # 1. 数据筛选 + logger.info("查询系统图数据...") + cursor.execute(f""" + SELECT id, url, style, level3_type, level2_type, deprecated + FROM {TABLE_SYS_FILE} + WHERE level1_type = 'Images' + AND style IS NOT NULL + AND style != '' + AND deprecated != 1 + """) + records = cursor.fetchall() + logger.info(f"找到 {len(records)} 条系统图记录") + + if not records: + logger.warning("没有找到系统图数据") + return + + # 2. 
批量处理 + failed_records = [] + batch_data = [] + + for idx, (sys_file_id, url, style, level3_type, level2_type, deprecated) in enumerate(records, 1): + try: + # 计算 category + category = f"{level3_type.lower()}_{level2_type.lower()}" + + # 提取特征向量 + feature_vector = extract_feature_vector(url) + + # 检查向量是否有效 + if np.all(feature_vector == 0): + logger.warning(f"向量提取失败,跳过: {url}") + failed_records.append((sys_file_id, url)) + continue + + # 准备数据 + data_item = { + "path": url, + "sys_file_id": sys_file_id, + "style": style, + "category": category, + "is_system_sketch": 1, + "deprecated": deprecated if deprecated else 0, + "feature_vector": feature_vector.tolist() + } + + batch_data.append(data_item) + + # 批量写入 + if len(batch_data) >= batch_size: + try: + insert_vectors(batch_data) + batch_data = [] + logger.info(f"已处理 {idx}/{len(records)} 条记录") + except Exception as e: + logger.error(f"批量写入失败: {e}") + failed_records.extend([(item["sys_file_id"], item["path"]) for item in batch_data]) + batch_data = [] + + except Exception as e: + logger.error(f"处理记录失败 [{url}]: {e}") + failed_records.append((sys_file_id, url)) + + # 写入剩余数据 + if batch_data: + try: + insert_vectors(batch_data) + except Exception as e: + logger.error(f"写入剩余数据失败: {e}") + failed_records.extend([(item["sys_file_id"], item["path"]) for item in batch_data]) + + # 3. 
重试失败记录 + if failed_records and retry_times > 0: + logger.info(f"重试 {len(failed_records)} 条失败记录...") + for retry in range(retry_times): + retry_failed = [] + for sys_file_id, url in failed_records: + try: + category = f"{level3_type.lower()}_{level2_type.lower()}" + feature_vector = extract_feature_vector(url) + if not np.all(feature_vector == 0): + data_item = { + "path": url, + "sys_file_id": sys_file_id, + "style": style, + "category": category, + "is_system_sketch": 1, + "deprecated": 0, + "feature_vector": feature_vector.tolist() + } + insert_vectors([data_item]) + else: + retry_failed.append((sys_file_id, url)) + except Exception as e: + logger.error(f"重试失败 [{url}]: {e}") + retry_failed.append((sys_file_id, url)) + + failed_records = retry_failed + if not failed_records: + break + + if failed_records: + logger.warning(f"仍有 {len(failed_records)} 条记录处理失败") + + logger.info("系统图向量预计算完成") + + except Exception as e: + logger.error(f"系统图向量预计算失败: {e}", exc_info=True) + finally: + if conn: + conn.close() + + +def compute_user_preference_vector( + account_id: int, + category: str, + conn: Optional[pymysql.connections.Connection] = None + # max_date: Optional[datetime] = None +) -> Optional[np.ndarray]: + """ + 计算用户偏好向量 + + Args: + account_id: 用户ID + category: 类别 + conn: 数据库连接(可选) + max_date: 最大日期(可选,用于评估时只使用训练集数据) + + Returns: + 用户偏好向量(2048维),失败返回 None + """ + from datetime import datetime + + should_close = False + if conn is None: + conn = pymysql.connect(**MYSQL_CONFIG) + should_close = True + + try: + cursor = conn.cursor() + + # 1. 
获取点赞记录(如果指定了max_date,只查询该日期之前的数据) + if max_date: + cursor.execute(f""" + SELECT path, data_time + FROM {TABLE_USER_PREFERENCE_LOG} + WHERE account_id = %s AND category = %s AND style is not null + AND data_time < %s + ORDER BY data_time DESC + """, (account_id, category, max_date)) + else: + cursor.execute(f""" + SELECT path, data_time + FROM {TABLE_USER_PREFERENCE_LOG} + WHERE account_id = %s AND category = %s AND style is not null + ORDER BY data_time DESC + """, (account_id, category)) + + like_records = cursor.fetchall() + + if not like_records: + return None + + # 2. 批量查询点赞次数(如果指定了max_date,只统计该日期之前的点赞) + paths = [r[0] for r in like_records] + if not paths: + return None + + placeholders = ','.join(['%s'] * len(paths)) + if max_date: + cursor.execute(f""" + SELECT path, COUNT(*) as like_count + FROM {TABLE_USER_PREFERENCE_LOG} + WHERE account_id = %s AND category = %s AND path IN ({placeholders}) + AND data_time < %s + GROUP BY path + """, (account_id, category) + tuple(paths) + (max_date,)) + else: + cursor.execute(f""" + SELECT path, COUNT(*) as like_count + FROM {TABLE_USER_PREFERENCE_LOG} + WHERE account_id = %s AND category = %s AND path IN ({placeholders}) + GROUP BY path + """, (account_id, category) + tuple(paths)) + + like_counts = {row[0]: row[1] for row in cursor.fetchall()} + + # 3. 批量获取向量 + vectors_dict = query_vectors_by_paths(paths) + + # 处理查询不到的 path(用户图或异常情况) + missing_paths = [p for p in paths if p not in vectors_dict] + if missing_paths: + logger.info(f"用户 {account_id} 类别 {category} 有 {len(missing_paths)} 个 path 需要实时计算向量") + # 目前未有非系统图向量,跳过 + # 这里可以实时计算并写入 Milvus,但为了简化,先跳过 + # 实际实现中应该调用 vector_utils.extract_feature_vector 并写入 Milvus + + # 4. 
计算权重并加权平均 + vectors = [] + weights = [] + K_half = RECOMMENDATION_CONFIG["K_half"] + + for k, (path, data_time) in enumerate(like_records, 1): + if path not in vectors_dict: + continue + + vector_data = vectors_dict[path] + feature_vector = np.array(vector_data["feature_vector"]) + + # 时间衰减权重 + d_k = 0.5 ** (k / K_half) + + # 点赞次数权重 + like_count = like_counts.get(path, 1) + p_i = 1 + math.log(1 + like_count) + + # 综合权重 + # w_i = d_k * p_i + w_i = p_i + + vectors.append(feature_vector) + weights.append(w_i) + + if not vectors: + return None + + # 5. 计算加权平均并做 L2 归一化,IP≈cosine + preference_vector = compute_weighted_average(vectors, weights) + preference_vector = normalize_vector(preference_vector) + + return preference_vector + + except Exception as e: + logger.error(f"计算用户偏好向量失败 [user={account_id}, category={category}]: {e}", exc_info=True) + return None + finally: + if should_close and conn: + conn.close() + + +def generate_initial_user_preference_vectors(batch_size: int = 100): + """ + 初始用户偏好向量生成 + + Args: + batch_size: 每批处理用户数 + """ + conn = None + try: + conn = pymysql.connect(**MYSQL_CONFIG) + cursor = conn.cursor() + + # 1. 扫描历史数据 + logger.info("扫描用户和类别组合...") + cursor.execute(f""" + SELECT DISTINCT account_id, category + FROM {TABLE_USER_PREFERENCE_LOG} + WHERE category IS NOT NULL + AND style IS NOT NULL + """) + + user_categories = cursor.fetchall() + logger.info(f"找到 {len(user_categories)} 个用户-类别组合") + + if not user_categories: + logger.warning("没有找到用户-类别组合") + return + + # 2. 
批量处理 + processed = 0 + failed = 0 + + for account_id, category in user_categories: + try: + # 计算偏好向量 + preference_vector = compute_user_preference_vector(account_id, category, conn) + + if preference_vector is not None: + # 写入 Redis + key = f"{REDIS_KEY_USER_PREF_PREFIX}:{account_id}:{category}" + # 序列化向量(使用 JSON) + vector_json = json.dumps(preference_vector.tolist()) + Redis.write( + key=key, + value=vector_json, + expire=RECOMMENDATION_CONFIG["redis_expire_seconds"] + ) + processed += 1 + else: + failed += 1 + + if (processed + failed) % batch_size == 0: + logger.info(f"已处理 {processed + failed}/{len(user_categories)} 个组合,成功: {processed}, 失败: {failed}") + + except Exception as e: + logger.error(f"处理失败 [user={account_id}, category={category}]: {e}") + failed += 1 + + logger.info(f"初始用户偏好向量生成完成,成功: {processed}, 失败: {failed}") + + except Exception as e: + logger.error(f"初始用户偏好向量生成失败: {e}", exc_info=True) + finally: + if conn: + conn.close() + + +def run_precompute(): + """ + 运行所有预计算任务 + """ + logger.info("=" * 50) + logger.info("开始预计算任务") + logger.info("=" * 50) + + # 1. 优化数据库表结构 + logger.info("\n[1/5] 优化数据库表结构...") + optimize_database_table() + + # # 2. 创建 Milvus 集合 + # logger.info("\n[2/5] 创建 Milvus 集合...") + # create_collection() + + # 3. 历史数据迁移 + logger.info("\n[3/5] 历史数据迁移...") + migrate_historical_data() + + # # 4. 系统图向量预计算 + # logger.info("\n[4/5] 系统图向量预计算...") + # precompute_system_sketch_vectors() + + # 5. 初始用户偏好向量生成 + logger.info("\n[5/5] 初始用户偏好向量生成...") + generate_initial_user_preference_vectors() + + logger.info("=" * 50) + logger.info("预计算任务完成") + logger.info("=" * 50) + + +if __name__ == "__main__": + # 1. 优化数据库表结构 + logger.info("\n[1/5] 优化数据库表结构...") + optimize_database_table() + + # 3. 历史数据迁移 + logger.info("\n[3/5] 历史数据迁移...") + migrate_historical_data() + + # 5. 
初始用户偏好向量生成 + logger.info("\n[5/5] 初始用户偏好向量生成...") + generate_initial_user_preference_vectors() diff --git a/app/service/recommendation_system/recommendation_api.py b/app/service/recommendation_system/recommendation_api.py new file mode 100644 index 0000000..7a856b8 --- /dev/null +++ b/app/service/recommendation_system/recommendation_api.py @@ -0,0 +1,214 @@ +""" +推荐接口实现 +实现探索/利用分支、向量检索、Softmax抽样等功能 +""" +import logging +import math +import random +import numpy as np +from typing import List, Dict, Optional + +from app.service.recommendation_system.config import RECOMMENDATION_CONFIG, REDIS_KEY_USER_PREF_PREFIX +from app.service.recommendation_system.milvus_client import search_similar_vectors, query_random_candidates +from app.service.recommendation_system.precompute import compute_user_preference_vector +from app.service.recommendation_system.vector_utils import normalize_vector +from app.service.utils.redis_utils import Redis +import json + +logger = logging.getLogger(__name__) + + +def get_user_preference_vector(user_id: int, category: str) -> Optional[np.ndarray]: + """ + 获取用户偏好向量 + + Args: + user_id: 用户ID + category: 类别 + + Returns: + 用户偏好向量(2048维),失败返回 None + """ + # 1. 从 Redis 获取 + key = f"{REDIS_KEY_USER_PREF_PREFIX}:{user_id}:{category}" + vector_json = Redis.read(key) + + if vector_json: + try: + vector_list = json.loads(vector_json) + return np.array(vector_list, dtype=np.float32) + except Exception as e: + logger.warning(f"解析 Redis 向量失败 [user={user_id}, category={category}]: {e}") + + # 2. 
如果不存在,实时计算 + logger.info(f"Redis 中不存在用户偏好向量,实时计算 [user={user_id}, category={category}]") + preference_vector = compute_user_preference_vector(user_id, category) + + if preference_vector is not None: + # 写入 Redis + vector_json = json.dumps(preference_vector.tolist()) + Redis.write( + key=key, + value=vector_json, + expire=RECOMMENDATION_CONFIG["redis_expire_seconds"] + ) + + return preference_vector + + +def explore_branch(category: str, style: Optional[str] = None) -> List[str]: + """ + 探索分支(随机推荐) + + Args: + category: 类别 + style: 风格(可选) + + Returns: + 推荐结果列表,每个元素包含 path, style, category 等字段 + """ + # 查询候选(随机池) + pool_size = 10 # 固定查询10个,然后随机选择 + + candidates = query_random_candidates(category, style, limit=pool_size) + + if not candidates: + logger.warning(f"探索分支:类别 {category} 没有候选数据") + return [] + + # 随机选择 + if len(candidates) > 1: + import random + candidates = random.sample(candidates, 1) + + # 格式化返回结果 + return [candidate.get("path", "") for candidate in candidates[:1]] + + +def exploit_branch( + user_id: int, + category: str, + style: Optional[str] = None +) -> List[str]: + """ + 利用分支(基于向量相似度推荐) + + Args: + user_id: 用户ID + category: 类别 + num_recommendations: 返回数量 + style: 风格(可选,用于加分) + + Returns: + 推荐结果列表,每个元素包含 path, style, category, similarity, sample_score 等字段 + """ + # 1. 获取用户偏好向量 + embedding = get_user_preference_vector(user_id, category) + + if embedding is None: + logger.warning(f"利用分支:无法获取用户偏好向量,回退到探索分支 [user={user_id}, category={category}]") + return explore_branch(category, style) + + # 2. Milvus 相似度检索(内积 IP) + topk = RECOMMENDATION_CONFIG["topk"] + results = search_similar_vectors(embedding, category, topk) + + if not results: + logger.warning(f"利用分支:向量检索无结果,回退到探索分支 [user={user_id}, category={category}]") + return explore_branch(category, style) + + # 3. 
Style 加分(可选,需传入 style 参数) + style_bonus = RECOMMENDATION_CONFIG["style_bonus"] + if style: + for result in results: + similarity = result["score"] + if result.get("style") == style: + # 加分:相似度 * (1 + style_bonus) + similarity = similarity * (1 + style_bonus) + result["final_score"] = similarity + else: + for result in results: + result["final_score"] = result["score"] + + # 4. Softmax 抽样 + scores = [r["final_score"] for r in results] + probabilities = softmax_with_temperature(scores, RECOMMENDATION_CONFIG["softmax_temperature"]) + + # 根据概率抽样 + if not results: + return [] + + selected_index = np.random.choice(len(results), size=1, p=probabilities, replace=False) + selected_results = [results[int(selected_index[0])]] + + # 5. 返回结果 + return [result.get("path", "") for result in selected_results] + + +def softmax_with_temperature(scores: List[float], temperature: float = 1.0) -> List[float]: + """ + Softmax 函数(带温度参数) + + Args: + scores: 分数列表 + temperature: 温度参数 + + Returns: + 概率列表 + """ + if not scores: + return [] + + # 除以温度 + scaled_scores = [s / temperature for s in scores] + + # 减去最大值(数值稳定性) + max_score = max(scaled_scores) + exp_scores = [math.exp(s - max_score) for s in scaled_scores] + + # 归一化 + sum_exp = sum(exp_scores) + if sum_exp == 0: + # 如果所有分数都是负无穷或非常小,返回均匀分布 + return [1.0 / len(scores)] * len(scores) + + probabilities = [exp_s / sum_exp for exp_s in exp_scores] + return probabilities + + +def get_recommendations( + user_id: int, + category: str, + style: Optional[str] = None +) -> List[str]: + """ + 获取推荐结果(主函数) + + Args: + user_id: 用户ID + category: 类别(如 female_skirt) + num_recommendations: 返回推荐数量(默认 1) + style: 风格(可选):若传入,则在利用分支对同 style 的候选进行加分 + + Returns: + 推荐结果列表,每个元素包含 path 等字段 + """ + try: + # 1. 读取配置参数 + explore_ratio = RECOMMENDATION_CONFIG["explore_ratio"] + + # 2. 
探索/利用决策 + r = random.random() # 生成随机数 (0-1) + + if r < explore_ratio: + logger.debug(f"探索分支 [user={user_id}, category={category}]") + return explore_branch(category, style) + + logger.debug(f"利用分支 [user={user_id}, category={category}]") + return exploit_branch(user_id, category, style) + + except Exception as e: + logger.error(f"获取推荐结果失败 [user={user_id}, category={category}]: {e}", exc_info=True) + # 容错:回退到探索分支 + return explore_branch(category, style) + diff --git a/app/service/recommendation_system/vector_utils.py b/app/service/recommendation_system/vector_utils.py new file mode 100644 index 0000000..05d8622 --- /dev/null +++ b/app/service/recommendation_system/vector_utils.py @@ -0,0 +1,189 @@ +""" +向量计算工具类 +包含 ResNet50 特征提取、向量归一化等功能 +""" +import io +import logging +import numpy as np +import torch +from torchvision import models, transforms +from PIL import Image +from minio import Minio + +from app.core.config import MINIO_URL, MINIO_ACCESS, MINIO_SECRET, MINIO_SECURE +from app.service.recommendation_system.config import RECOMMENDATION_CONFIG + +logger = logging.getLogger(__name__) + +# 图像预处理(与ResNet训练时的预处理一致) +transform = transforms.Compose([ + transforms.Resize((224, 224)), # ResNet 要求 224x224 的输入 + transforms.ToTensor(), # 转换为 Tensor + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), # 标准化 +]) + +# 加载预训练的 ResNet50 模型(去掉最后全连接层) +_resnet_model = None + + +def get_resnet_model(): + """获取 ResNet50 模型(单例模式)""" + global _resnet_model + if _resnet_model is None: + logger.info("加载 ResNet50 模型...") + _resnet_model = models.resnet50(pretrained=True) + modules = list(_resnet_model.children())[:-1] # 移除最后的全连接层 + _resnet_model = torch.nn.Sequential(*modules) + _resnet_model.eval() # 设置为评估模式 + logger.info("ResNet50 模型加载完成") + return _resnet_model + + +# MinIO 客户端(单例) +_minio_client = None + + +def get_minio_client(): + """获取 MinIO 客户端(单例模式)""" + global _minio_client + if _minio_client is None: + _minio_client = Minio( + MINIO_URL, + 
access_key=MINIO_ACCESS, + secret_key=MINIO_SECRET, + secure=MINIO_SECURE + ) + return _minio_client + + +def get_image_from_minio(path: str) -> Image.Image: + """ + 从 MinIO 获取图片 + + Args: + path: MinIO 逻辑 URL,格式如 "bucket_name/object_name" + + Returns: + PIL Image 对象,失败返回 None + """ + try: + # 分割路径,获取桶名和文件路径 + path_parts = path.split('/', 1) + if len(path_parts) != 2: + logger.error(f"路径格式错误: {path}") + return None + + bucket_name, file_name = path_parts + minio_client = get_minio_client() + + # 获取文件 + obj = minio_client.get_object(bucket_name, file_name) + img_data = obj.read() # 读取图像数据 + img = Image.open(io.BytesIO(img_data)) # 将数据转为图像对象 + + return img + except Exception as e: + logger.error(f"从 MinIO 获取图片失败 [{path}]: {e}") + return None + + +def extract_feature_vector(path: str) -> np.ndarray: + """ + 使用 ResNet50 提取图片特征向量(2048维) + + Args: + path: MinIO 逻辑 URL + + Returns: + 2048维特征向量(numpy array),失败返回零向量 + """ + try: + # 从 MinIO 获取图像 + img = get_image_from_minio(path) + if img is None: + logger.warning(f"无法获取图片,返回零向量: {path}") + return np.zeros(RECOMMENDATION_CONFIG["vector_dim"], dtype=np.float32) + + # 预处理 + # 部分 MinIO 图片可能是 RGBA/CMYK,转换成 RGB 以匹配 3 通道标准化参数 + if img.mode != "RGB": + try: + img = img.convert("RGB") + except Exception: + logger.warning(f"无法转换图片为RGB,返回零向量: {path}") + return np.zeros(RECOMMENDATION_CONFIG["vector_dim"], dtype=np.float32) + + img_tensor = transform(img).unsqueeze(0) # 扩展维度以适应批量处理 + + # 提取特征 + resnet_model = get_resnet_model() + with torch.no_grad(): # 在不需要计算梯度的情况下进行推断 + feature_vector = resnet_model(img_tensor) # 获取 ResNet 的输出 + feature_vector = feature_vector.squeeze().cpu().numpy() # 转换为 NumPy 数组并去掉 batch 维度 + + # 确保是 2048 维 + if feature_vector.ndim > 1: + feature_vector = feature_vector.flatten() + + # 确保维度正确 + if len(feature_vector) != RECOMMENDATION_CONFIG["vector_dim"]: + logger.warning(f"向量维度不正确: {len(feature_vector)}, 期望: {RECOMMENDATION_CONFIG['vector_dim']}") + # 如果维度不对,尝试调整 + if len(feature_vector) > 
RECOMMENDATION_CONFIG["vector_dim"]: + feature_vector = feature_vector[:RECOMMENDATION_CONFIG["vector_dim"]] + else: + padded = np.zeros(RECOMMENDATION_CONFIG["vector_dim"], dtype=np.float32) + padded[:len(feature_vector)] = feature_vector + feature_vector = padded + + return feature_vector.astype(np.float32) + except Exception as e: + logger.error(f"提取特征向量失败 [{path}]: {e}", exc_info=True) + return np.zeros(RECOMMENDATION_CONFIG["vector_dim"], dtype=np.float32) + + +def normalize_vector(vector: np.ndarray) -> np.ndarray: + """ + L2 归一化向量 + + Args: + vector: 输入向量 + + Returns: + 归一化后的向量 + """ + norm = np.linalg.norm(vector) + if norm == 0: + return vector + return vector / norm + + +def compute_weighted_average(vectors: list, weights: list) -> np.ndarray: + """ + 计算加权平均向量 + + Args: + vectors: 向量列表 + weights: 权重列表 + + Returns: + 加权平均向量(不做归一化,模长为加权平均后的尺度) + """ + if not vectors or not weights: + return np.zeros(RECOMMENDATION_CONFIG["vector_dim"], dtype=np.float32) + + # 确保所有向量都是 numpy array + vectors = [np.array(v) for v in vectors] + weights = np.array(weights) + + # 计算加权和 + weighted_sum = np.zeros_like(vectors[0]) + for v, w in zip(vectors, weights): + weighted_sum += v * w + + # 返回加权平均(除以权重和,不做 L2 归一化,模长不会随条数线性暴涨) + weight_total = weights.sum() + if weight_total == 0: + return weighted_sum + return weighted_sum / weight_total +