Files
AiDA_Python/app/api/api_recommendation.py

203 lines
8.5 KiB
Python
Raw Normal View History

import io
import logging
import sys
2025-12-29 10:52:33 +08:00
from typing import List, Optional
from fastapi import HTTPException, APIRouter, Query
from apscheduler.schedulers.background import BackgroundScheduler
2025-12-29 10:52:33 +08:00
from app.service.recommendation_system.recommendation_api import get_recommendations as get_new_recommendations
from app.service.recommendation_system.incremental_listener import start_background_listener
from app.service.recommendation_system.milvus_client import create_collection
# Re-wrap the process's standard output so text written to it is encoded as UTF-8.
# NOTE(review): this replaces sys.stdout as an import-time side effect and
# requires sys.stdout to expose a `.buffer` attribute — confirm this is intended
# for every environment that imports this module.
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')
# getLogger() is called without a name, so this is the root logger.
logger = logging.getLogger()
# Router on which the endpoints and the startup hook of this module are registered.
router = APIRouter()
2025-12-29 10:52:33 +08:00
# ========== 旧版推荐接口(基于 npy 矩阵,已废弃)==========
# @router.get("/recommend/{user_id}/{category}/{num_recommendations}/{brand_id}/{brand_scale}", response_model=List[str])
# async def get_recommendations(user_id: int, category: str, brand_id: int, brand_scale: float, num_recommendations: int = 10):
# """
# :param user_id: 4
# :param category: female_skirt
# :param num_recommendations: 1
# :return:
# [
# "aida-sys-image/images/female/skirt/903000017.jpg"
# ]
# """
2025-06-10 13:38:28 +08:00
# try:
2025-12-29 10:52:33 +08:00
# start_time = time.time()
# cache_key = (user_id, category)
# # === 新增:用户存在性检查 ===
# user_exists_inter = user_id in matrix_data["user_index_interaction"]
# user_exists_feat = user_id in matrix_data["user_index_feature"]
#
# # 任一矩阵不存在用户则返回随机推荐
# if not (user_exists_inter and user_exists_feat):
# logger.info(f"用户 {user_id} 数据不完整,触发随机推荐")
# return get_random_recommendations(category, num_recommendations)
#
# # 检查缓存
# if cache_key in matrix_data["cached_scores"]:
# processed_inter, processed_feat = matrix_data["cached_scores"][cache_key]
# valid_sketch_idxs_inter = matrix_data["cached_valid_idxs"][cache_key]
# else:
# # 实时计算逻辑(同原代码)
# user_idx_inter = matrix_data["user_index_interaction"].get(user_id)
# user_idx_feature = matrix_data["user_index_feature"].get(user_id)
#
# category_iids = matrix_data["category_to_iids"].get(category, [])
# valid_sketch_idxs_inter = [
# idx for iid, idx in matrix_data["sketch_index_interaction"].items()
# if iid in category_iids
# ]
#
# # 处理交互分数
# raw_inter_scores = []
# if user_idx_inter is not None and valid_sketch_idxs_inter:
# raw_inter_scores = matrix_data["interaction_matrix"][user_idx_inter, valid_sketch_idxs_inter]
# processed_inter = raw_inter_scores * 0.7
2025-12-29 10:52:33 +08:00
#
2025-12-29 10:52:33 +08:00
# # 处理特征分数
# valid_sketch_idxs_feature = [
# idx for iid, idx in matrix_data["sketch_index_feature"].items()
# if iid in category_iids
# ]
# raw_feat_scores = []
# if user_idx_feature is not None and valid_sketch_idxs_feature:
# raw_feat_scores = matrix_data["feature_matrix"][user_idx_feature, valid_sketch_idxs_feature]
# raw_feat_scores = (raw_feat_scores - np.min(raw_feat_scores)) / (
# np.max(raw_feat_scores) - np.min(raw_feat_scores) + 1e-8)
# processed_feat = raw_feat_scores
# else:
# processed_feat = np.array([])
2025-12-29 10:52:33 +08:00
#
2025-12-29 10:52:33 +08:00
# # 更新缓存
# matrix_data["cached_scores"][cache_key] = (processed_inter, processed_feat)
# matrix_data["cached_valid_idxs"][cache_key] = valid_sketch_idxs_inter
2025-06-10 13:38:28 +08:00
#
2025-12-29 10:52:33 +08:00
# # 合并分数
# if brand_id is not None:
# brand_idx_feature = matrix_data["brand_index_map"].get(brand_id)
2025-06-10 13:38:28 +08:00
#
2025-12-29 10:52:33 +08:00
# brand_feat_valid = (
# matrix_data["brand_feature_matrix"].size > 0 and # 矩阵非空
# brand_idx_feature is not None and
# valid_sketch_idxs_feature # 有可用索引
# )
2025-06-10 13:38:28 +08:00
#
2025-12-29 10:52:33 +08:00
# if brand_feat_valid:
# raw_brand_feat_scores = matrix_data["brand_feature_matrix"][
# brand_idx_feature, valid_sketch_idxs_feature
# ]
# raw_brand_feat_scores = (raw_brand_feat_scores - np.min(raw_brand_feat_scores)) / (
# np.max(raw_brand_feat_scores) - np.min(raw_brand_feat_scores) + 1e-8
# )
# processed_brand_feat = raw_brand_feat_scores
2025-06-10 13:38:28 +08:00
#
2025-12-29 10:52:33 +08:00
# # 如果 processed_feat 是空的,替换为全 0避免 shape 不一致
# if processed_feat.size == 0:
# processed_feat = np.zeros_like(processed_brand_feat)
#
# final_scores = processed_inter + 0.3 * (
# (1 - brand_scale) * processed_feat + brand_scale * processed_brand_feat
# )
# else:
# # brand 信息不可用
# final_scores = processed_inter + 0.3 * processed_feat if processed_feat.size > 0 else processed_inter
# else:
# final_scores = processed_inter + 0.3 * processed_feat if processed_feat.size > 0 else processed_inter
#
# valid_sketch_idxs = matrix_data["cached_valid_idxs"][cache_key]
#
# # 概率采样
# scores = np.array(final_scores)
#
# # 调整后的概率转换带温度控制的softmax
# def calibrated_softmax(scores, temperature=1.0):
# scores = scores / temperature
# scale = scores - max(scores)
# exps = np.exp(scale)
# return exps / np.sum(exps)
#
# probs = calibrated_softmax(scores, 0.09)
#
# chosen_indices = np.random.choice(
# len(valid_sketch_idxs),
# size=min(num_recommendations, len(valid_sketch_idxs)),
# p=probs,
# replace=False
# )
# recommendations = [matrix_data["iid_to_sketch"][valid_sketch_idxs[idx]] for idx in chosen_indices]
#
# logger.info(f"推荐生成完成,耗时: {time.time() - start_time:.2f}秒")
# return recommendations
2025-06-10 13:38:28 +08:00
# except Exception as e:
2025-12-29 10:52:33 +08:00
# logger.error(f"推荐失败: {str(e)}", exc_info=True)
# raise HTTPException(status_code=500, detail=str(e))
2026-01-12 09:49:07 +08:00
@router.on_event("startup")
async def startup_event():
    """Initialise the incremental-listener scheduled task at application startup.

    Steps performed, in order:

    1. Call ``create_collection()`` to make sure the Milvus collection exists.
       A failure here is logged and deliberately ignored so that the service
       can still start.
    2. Create a ``BackgroundScheduler``, hand it to
       ``start_background_listener`` and start it.

    Any exception raised while setting up the scheduler is logged and
    swallowed; this hook never propagates an error to the caller.

    NOTE(review): recent FastAPI releases deprecate ``on_event`` in favour of
    lifespan handlers — confirm against the FastAPI version pinned by the
    project before migrating.
    """
    try:
        # Ensure the Milvus collection exists (per the original author's note,
        # the helper returns immediately when the collection is already there).
        try:
            create_collection()
        except Exception as exc:
            # Best effort: a Milvus problem must not prevent the service from starting.
            logger.error("Milvus 集合创建/检查失败,不影响服务继续启动: %s", exc, exc_info=True)

        # Configure and start the scheduled job(s).
        scheduler = BackgroundScheduler()
        start_background_listener(scheduler)
        scheduler.start()
        logger.info("增量监听定时任务已启动")
    except Exception as e:
        # Lazy %-style arguments, consistent with the other logging calls in this module.
        logger.error("启动增量监听任务失败: %s", e, exc_info=True)
@router.get("/recommend/{user_id}/{category}", response_model=List[str])
async def recommend(
    user_id: int,
    category: str,
    style: Optional[str] = Query(
        None,
        description="风格样式(可选):若传入,则在利用分支对同 style 的候选进行加分",
    ),
):
    """New recommendation endpoint (Milvus + Redis preference vectors).

    Args:
        user_id: Identifier of the user, taken from the URL path.
        category: Category to recommend for, taken from the URL path.
        style: Optional style query parameter, forwarded unchanged to the
            recommendation service.

    Returns:
        A list with exactly one element: the first item returned by the
        recommendation service, or an empty string when the service returned
        nothing.

    Raises:
        HTTPException: Re-raised unchanged when the service layer raises one;
            otherwise raised with status 500 and the error text as ``detail``
            for any other failure.
    """
    try:
        results = get_new_recommendations(user_id, category, style)
        # Only the first recommendation is exposed; "" stands in for "no result"
        # so the response always contains exactly one string.
        path = results[0] if results else ""
        return [path]
    except HTTPException:
        # Keep the status code chosen further down the stack instead of
        # rewriting every HTTP error as a 500.
        raise
    except Exception as e:
        logger.error("新版推荐接口失败 [user=%s, category=%s]: %s", user_id, category, e, exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.get("/redis/user_pref")
async def get_all_user_preferences():
    """
    Return the Redis entries whose key matches ``<REDIS_KEY_USER_PREF_PREFIX>:*``.

    Every matching key is read individually and its value is returned as-is.
    Keys whose value is falsy are left out of the result.

    Returns:
        A dict mapping each matching Redis key to the value read for it.

    Raises:
        HTTPException: With status 500 and the error text as ``detail`` when
            anything fails, including the lazy imports below.
    """
    try:
        # Imported inside the handler so that an import failure is also
        # reported through the 500 path below.
        from app.service.utils.redis_utils import Redis
        from app.service.recommendation_system.config import REDIS_KEY_USER_PREF_PREFIX

        # Scan for every key of the form "<prefix>:*".
        matched_keys = Redis.scan_keys(f"{REDIS_KEY_USER_PREF_PREFIX}:*")

        preferences = {}
        for redis_key in matched_keys:
            stored = Redis.read(redis_key)
            if not stored:
                # Skip keys for which nothing truthy could be read.
                continue
            preferences[redis_key] = stored
        return preferences
    except Exception as err:
        logger.error("获取用户偏好数据失败: %s", err, exc_info=True)
        raise HTTPException(status_code=500, detail=str(err))