Merge remote-tracking branch 'origin/dev-ltx' into develop
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
# Conflicts: # app/api/api_recommendation.py # app/service/design_fast/utils/organize.py
This commit is contained in:
@@ -9,7 +9,6 @@ from apscheduler.schedulers.background import BackgroundScheduler
|
||||
from apscheduler.triggers.cron import CronTrigger
|
||||
from fastapi import HTTPException, APIRouter
|
||||
|
||||
from app.service.recommend.service import load_resources, matrix_data
|
||||
import pymysql
|
||||
from app.core.config import DB_CONFIG, TABLE_CATEGORIES, RECOMMEND_PATH_PREFIX
|
||||
from minio import Minio
|
||||
|
||||
116
app/api/api_import_sys_sketch.py
Normal file
116
app/api/api_import_sys_sketch.py
Normal file
@@ -0,0 +1,116 @@
|
||||
import logging
|
||||
import sys
|
||||
from typing import Optional
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import threading
|
||||
|
||||
from app.schemas.response_template import ResponseModel
|
||||
from app.service.recommendation_system.import_sys_sketch_to_milvus import main as import_main
|
||||
|
||||
logger = logging.getLogger()
router = APIRouter()


# Single-worker pool: at most one long-running import job runs at a time.
executor = ThreadPoolExecutor(max_workers=1)
# Cross-request flag tracking whether an import job is currently in flight.
task_status = {"running": False}
|
||||
|
||||
|
||||
def run_import_task(batch_size: int, retry_times: int, limit: Optional[int], offset: int, skip_create_collection: bool):
    """Run the sketch-import job in a background thread.

    Simulates a CLI invocation of ``import_sys_sketch_to_milvus.main()`` by
    temporarily rewriting ``sys.argv``, then restores it afterwards.

    Args:
        batch_size: rows per batch, passed as ``--batch-size``.
        retry_times: retry count, passed as ``--retry-times``.
        limit: optional row cap, passed as ``--limit`` (omitted when None).
        offset: starting offset, passed as ``--offset`` (omitted when 0).
        skip_create_collection: when True, adds ``--skip-create-collection``.

    Raises:
        Exception: re-raises whatever ``import_main()`` raised, after logging.
    """
    original_argv = sys.argv.copy()
    task_status["running"] = True
    try:
        # Fake command-line arguments so import_main() can parse them.
        sys.argv = [
            "import_sys_sketch_to_milvus.py",
            "--batch-size", str(batch_size),
            "--retry-times", str(retry_times),
        ]
        if limit is not None:
            sys.argv.extend(["--limit", str(limit)])
        if offset > 0:
            sys.argv.extend(["--offset", str(offset)])
        if skip_create_collection:
            sys.argv.append("--skip-create-collection")

        import_main()
        logger.info("导入任务完成")
    except Exception as e:
        logger.error(f"导入任务失败: {e}", exc_info=True)
        raise
    finally:
        # Always restore argv and clear the running flag — even on unexpected
        # exceptions — so the endpoint can never get stuck in "running" state.
        sys.argv = original_argv
        task_status["running"] = False
|
||||
|
||||
|
||||
@router.post("/import-sys-sketch", response_model=ResponseModel)
async def import_sys_sketch(
    batch_size: int = Query(1000, description="批量处理大小(默认:1000)"),
    retry_times: int = Query(3, description="失败重试次数(默认:3)"),
    limit: Optional[int] = Query(None, description="限制处理数量(用于测试,默认:不限制)"),
    offset: int = Query(0, description="起始偏移量(默认:0)"),
    skip_create_collection: bool = Query(False, description="跳过创建集合(如果集合已存在)"),
):
    """Import system-sketch vectors from ``t_sys_file`` into Milvus.

    The actual work runs asynchronously in a background thread; this handler
    only rejects concurrent runs, submits the job, and acknowledges the start.
    """
    try:
        # Reject the request while a previous import is still in flight.
        if task_status["running"]:
            raise HTTPException(
                status_code=409,
                detail="已有导入任务正在运行,请等待完成后再试"
            )

        # Hand the long-running job off to the single-worker pool.
        executor.submit(
            run_import_task,
            batch_size, retry_times, limit, offset, skip_create_collection,
        )

        started_payload = {
            "status": "started",
            "batch_size": batch_size,
            "retry_times": retry_times,
            "limit": limit,
            "offset": offset,
            "skip_create_collection": skip_create_collection,
        }
        return ResponseModel(code=200, msg="导入任务已启动,正在后台执行", data=started_payload)
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"启动导入任务失败: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"启动导入任务失败: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/import-sys-sketch/status", response_model=ResponseModel)
async def get_import_status():
    """Report whether an import job is currently running."""
    # The flag is maintained by run_import_task in the worker thread.
    status_payload = {"running": task_status["running"]}
    return ResponseModel(code=200, msg="OK", data=status_payload)
|
||||
|
||||
85
app/api/api_precompute.py
Normal file
85
app/api/api_precompute.py
Normal file
@@ -0,0 +1,85 @@
|
||||
import logging
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
from app.schemas.response_template import ResponseModel
|
||||
from app.service.recommendation_system.precompute import run_precompute
|
||||
|
||||
logger = logging.getLogger()
router = APIRouter()


# Single-worker pool: at most one long-running precompute job at a time.
executor = ThreadPoolExecutor(max_workers=1)
# Cross-request flag tracking whether a precompute job is in flight.
task_status = {"running": False}
|
||||
|
||||
|
||||
def run_precompute_task():
    """Run the precompute pipeline in a background thread.

    Sets the module-level ``task_status`` flag while running and guarantees
    it is cleared afterwards.

    Raises:
        Exception: re-raises whatever ``run_precompute()`` raised, after logging.
    """
    task_status["running"] = True
    try:
        logger.info("开始执行预计算任务...")
        run_precompute()
        logger.info("预计算任务完成")
    except Exception as e:
        logger.error(f"预计算任务失败: {e}", exc_info=True)
        raise
    finally:
        # Clear the flag even on unexpected exceptions so the endpoint can
        # never get stuck reporting "running".
        task_status["running"] = False
|
||||
|
||||
|
||||
@router.post("/precompute", response_model=ResponseModel)
async def precompute():
    """Kick off the precompute pipeline asynchronously.

    The background job performs:
      1. database table-structure optimization,
      2. historical data migration,
      3. initial user preference-vector generation.

    Returns 409 when a precompute job is already running.
    """
    try:
        # Only one precompute job may run at a time.
        if task_status["running"]:
            raise HTTPException(
                status_code=409,
                detail="已有预计算任务正在运行,请等待完成后再试"
            )

        executor.submit(run_precompute_task)

        started_payload = {
            "status": "started",
            "tasks": [
                "优化数据库表结构",
                "历史数据迁移",
                "初始用户偏好向量生成"
            ]
        }
        return ResponseModel(code=200, msg="预计算任务已启动,正在后台执行", data=started_payload)
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"启动预计算任务失败: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"启动预计算任务失败: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/precompute/status", response_model=ResponseModel)
async def get_precompute_status():
    """Report whether a precompute job is currently running."""
    return ResponseModel(
        code=200,
        msg="OK",
        data={"running": task_status["running"]},
    )
|
||||
|
||||
@@ -1,207 +1,175 @@
|
||||
import io
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
from typing import List
|
||||
import os
|
||||
import json
|
||||
import math
|
||||
import random
|
||||
import numpy as np
|
||||
from typing import List, Optional
|
||||
from fastapi import HTTPException, APIRouter, Query
|
||||
from apscheduler.schedulers.background import BackgroundScheduler
|
||||
from apscheduler.triggers.cron import CronTrigger
|
||||
from fastapi import HTTPException, APIRouter
|
||||
|
||||
from app.service.recommend.service import load_resources, matrix_data
|
||||
from app.service.recommendation_system.recommendation_api import get_recommendations as get_new_recommendations
|
||||
from app.service.recommendation_system.incremental_listener import start_background_listener
|
||||
from app.service.recommendation_system.milvus_client import create_collection
|
||||
|
||||
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')
|
||||
logger = logging.getLogger()
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.on_event("startup")
|
||||
# ========== 旧版推荐接口(基于 npy 矩阵,已废弃)==========
|
||||
# @router.get("/recommend/{user_id}/{category}/{num_recommendations}/{brand_id}/{brand_scale}", response_model=List[str])
|
||||
# async def get_recommendations(user_id: int, category: str, brand_id: int, brand_scale: float, num_recommendations: int = 10):
|
||||
# """
|
||||
# :param user_id: 4
|
||||
# :param category: female_skirt
|
||||
# :param num_recommendations: 1
|
||||
# :return:
|
||||
# [
|
||||
# "aida-sys-image/images/female/skirt/903000017.jpg"
|
||||
# ]
|
||||
# """
|
||||
# try:
|
||||
# start_time = time.time()
|
||||
# cache_key = (user_id, category)
|
||||
# # === 新增:用户存在性检查 ===
|
||||
# user_exists_inter = user_id in matrix_data["user_index_interaction"]
|
||||
# user_exists_feat = user_id in matrix_data["user_index_feature"]
|
||||
#
|
||||
# # 任一矩阵不存在用户则返回随机推荐
|
||||
# if not (user_exists_inter and user_exists_feat):
|
||||
# logger.info(f"用户 {user_id} 数据不完整,触发随机推荐")
|
||||
# return get_random_recommendations(category, num_recommendations)
|
||||
#
|
||||
# # 检查缓存
|
||||
# if cache_key in matrix_data["cached_scores"]:
|
||||
# processed_inter, processed_feat = matrix_data["cached_scores"][cache_key]
|
||||
# valid_sketch_idxs_inter = matrix_data["cached_valid_idxs"][cache_key]
|
||||
# else:
|
||||
# # 实时计算逻辑(同原代码)
|
||||
# user_idx_inter = matrix_data["user_index_interaction"].get(user_id)
|
||||
# user_idx_feature = matrix_data["user_index_feature"].get(user_id)
|
||||
#
|
||||
# category_iids = matrix_data["category_to_iids"].get(category, [])
|
||||
# valid_sketch_idxs_inter = [
|
||||
# idx for iid, idx in matrix_data["sketch_index_interaction"].items()
|
||||
# if iid in category_iids
|
||||
# ]
|
||||
#
|
||||
# # 处理交互分数
|
||||
# raw_inter_scores = []
|
||||
# if user_idx_inter is not None and valid_sketch_idxs_inter:
|
||||
# raw_inter_scores = matrix_data["interaction_matrix"][user_idx_inter, valid_sketch_idxs_inter]
|
||||
# processed_inter = raw_inter_scores * 0.7
|
||||
#
|
||||
# # 处理特征分数
|
||||
# valid_sketch_idxs_feature = [
|
||||
# idx for iid, idx in matrix_data["sketch_index_feature"].items()
|
||||
# if iid in category_iids
|
||||
# ]
|
||||
# raw_feat_scores = []
|
||||
# if user_idx_feature is not None and valid_sketch_idxs_feature:
|
||||
# raw_feat_scores = matrix_data["feature_matrix"][user_idx_feature, valid_sketch_idxs_feature]
|
||||
# raw_feat_scores = (raw_feat_scores - np.min(raw_feat_scores)) / (
|
||||
# np.max(raw_feat_scores) - np.min(raw_feat_scores) + 1e-8)
|
||||
# processed_feat = raw_feat_scores
|
||||
# else:
|
||||
# processed_feat = np.array([])
|
||||
#
|
||||
# # 更新缓存
|
||||
# matrix_data["cached_scores"][cache_key] = (processed_inter, processed_feat)
|
||||
# matrix_data["cached_valid_idxs"][cache_key] = valid_sketch_idxs_inter
|
||||
#
|
||||
# # 合并分数
|
||||
# if brand_id is not None:
|
||||
# brand_idx_feature = matrix_data["brand_index_map"].get(brand_id)
|
||||
#
|
||||
# brand_feat_valid = (
|
||||
# matrix_data["brand_feature_matrix"].size > 0 and # 矩阵非空
|
||||
# brand_idx_feature is not None and
|
||||
# valid_sketch_idxs_feature # 有可用索引
|
||||
# )
|
||||
#
|
||||
# if brand_feat_valid:
|
||||
# raw_brand_feat_scores = matrix_data["brand_feature_matrix"][
|
||||
# brand_idx_feature, valid_sketch_idxs_feature
|
||||
# ]
|
||||
# raw_brand_feat_scores = (raw_brand_feat_scores - np.min(raw_brand_feat_scores)) / (
|
||||
# np.max(raw_brand_feat_scores) - np.min(raw_brand_feat_scores) + 1e-8
|
||||
# )
|
||||
# processed_brand_feat = raw_brand_feat_scores
|
||||
#
|
||||
# # 如果 processed_feat 是空的,替换为全 0,避免 shape 不一致
|
||||
# if processed_feat.size == 0:
|
||||
# processed_feat = np.zeros_like(processed_brand_feat)
|
||||
#
|
||||
# final_scores = processed_inter + 0.3 * (
|
||||
# (1 - brand_scale) * processed_feat + brand_scale * processed_brand_feat
|
||||
# )
|
||||
# else:
|
||||
# # brand 信息不可用
|
||||
# final_scores = processed_inter + 0.3 * processed_feat if processed_feat.size > 0 else processed_inter
|
||||
# else:
|
||||
# final_scores = processed_inter + 0.3 * processed_feat if processed_feat.size > 0 else processed_inter
|
||||
#
|
||||
# valid_sketch_idxs = matrix_data["cached_valid_idxs"][cache_key]
|
||||
#
|
||||
# # 概率采样
|
||||
# scores = np.array(final_scores)
|
||||
#
|
||||
# # 调整后的概率转换(带温度控制的softmax)
|
||||
# def calibrated_softmax(scores, temperature=1.0):
|
||||
# scores = scores / temperature
|
||||
# scale = scores - max(scores)
|
||||
# exps = np.exp(scale)
|
||||
# return exps / np.sum(exps)
|
||||
#
|
||||
# probs = calibrated_softmax(scores, 0.09)
|
||||
#
|
||||
# chosen_indices = np.random.choice(
|
||||
# len(valid_sketch_idxs),
|
||||
# size=min(num_recommendations, len(valid_sketch_idxs)),
|
||||
# p=probs,
|
||||
# replace=False
|
||||
# )
|
||||
# recommendations = [matrix_data["iid_to_sketch"][valid_sketch_idxs[idx]] for idx in chosen_indices]
|
||||
#
|
||||
# logger.info(f"推荐生成完成,耗时: {time.time() - start_time:.2f}秒")
|
||||
# return recommendations
|
||||
# except Exception as e:
|
||||
# logger.error(f"推荐失败: {str(e)}", exc_info=True)
|
||||
# raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
# @router.on_event("startup")
|
||||
async def startup_event():
|
||||
# 初始加载
|
||||
load_resources()
|
||||
"""启动时初始化增量监听任务"""
|
||||
try:
|
||||
# 确保 Milvus 集合已创建(若已存在则直接返回)
|
||||
try:
|
||||
create_collection()
|
||||
except Exception as exc:
|
||||
logger.error("Milvus 集合创建/检查失败,不影响服务继续启动: %s", exc, exc_info=True)
|
||||
|
||||
# 配置定时任务
|
||||
scheduler = BackgroundScheduler()
|
||||
scheduler.add_job(
|
||||
load_resources,
|
||||
trigger=CronTrigger(hour=0, minute=30),
|
||||
name="每日资源刷新"
|
||||
)
|
||||
start_background_listener(scheduler)
|
||||
scheduler.start()
|
||||
logger.info("定时任务已启动")
|
||||
|
||||
|
||||
def softmax(scores):
    """Return the softmax distribution of *scores* as a list of floats.

    Subtracts the maximum score before exponentiating for numerical
    stability; the result sums to 1.
    """
    peak = max(scores)
    weights = [math.exp(value - peak) for value in scores]
    total = sum(weights)
    return [weight / total for weight in weights]
|
||||
|
||||
|
||||
# def get_random_recommendations(category: str, num: int) -> List[str]:
|
||||
# """根据预加载热度向量推荐(冷启动)"""
|
||||
# try:
|
||||
# heat_data = matrix_data.get("heat_data", {})
|
||||
#
|
||||
# if category not in heat_data:
|
||||
# raise ValueError(f"热度数据缺少类别 {category},使用随机推荐")
|
||||
#
|
||||
# heat_dict = heat_data[category] # {url: score}
|
||||
# urls = list(heat_dict.keys())
|
||||
# scores = list(heat_dict.values())
|
||||
#
|
||||
# if not urls:
|
||||
# raise ValueError("该类别下无热度记录,使用随机推荐")
|
||||
#
|
||||
# probs = softmax(scores)
|
||||
# sample_size = min(num, len(urls))
|
||||
# sampled_urls = random.choices(urls, weights=probs, k=sample_size)
|
||||
#
|
||||
# return sampled_urls
|
||||
#
|
||||
# except Exception as e:
|
||||
# # 回退:完全随机推荐
|
||||
# all_iids = list(matrix_data["iid_to_sketch"].keys())
|
||||
# category_iids = matrix_data["category_to_iids"].get(category, all_iids)
|
||||
# sample_size = min(num, len(category_iids))
|
||||
# sampled = np.random.choice(category_iids, size=sample_size, replace=False)
|
||||
# return [matrix_data["iid_to_sketch"][iid] for iid in sampled]
|
||||
|
||||
def get_random_recommendations(category: str, num: int) -> List[str]:
    """Randomly sample up to *num* sketch paths, preferring the given category."""
    fallback = list(matrix_data["iid_to_sketch"].keys())
    # Fall back to the whole catalogue when the category has no entries.
    candidates = matrix_data["category_to_iids"].get(category, fallback)
    # Never request more samples than are actually available.
    draw_count = min(num, len(candidates))
    picked = np.random.choice(candidates, size=draw_count, replace=False)
    return [matrix_data["iid_to_sketch"][iid] for iid in picked]
|
||||
|
||||
|
||||
@router.get("/recommend/{user_id}/{category}/{num_recommendations}/{brand_id}/{brand_scale}", response_model=List[str])
|
||||
async def get_recommendations(user_id: int, category: str, brand_id: int, brand_scale: float, num_recommendations: int = 10):
|
||||
"""
|
||||
:param user_id: 4
|
||||
:param category: female_skirt
|
||||
:param num_recommendations: 1
|
||||
:return:
|
||||
[
|
||||
"aida-sys-image/images/female/skirt/903000017.jpg"
|
||||
]
|
||||
"""
|
||||
try:
|
||||
logger.info(f"user_id:{user_id}-----category:{category}-----brand_id:{brand_id}-----brand_scale:{brand_scale}-----num_recommendations:{num_recommendations}")
|
||||
start_time = time.time()
|
||||
cache_key = (user_id, category)
|
||||
# === 新增:用户存在性检查 ===
|
||||
user_exists_inter = user_id in matrix_data["user_index_interaction"]
|
||||
user_exists_feat = user_id in matrix_data["user_index_feature"]
|
||||
|
||||
# 任一矩阵不存在用户则返回随机推荐
|
||||
if not (user_exists_inter and user_exists_feat):
|
||||
logger.info(f"用户 {user_id} 数据不完整,触发随机推荐")
|
||||
return get_random_recommendations(category, num_recommendations)
|
||||
|
||||
# 检查缓存
|
||||
if cache_key in matrix_data["cached_scores"]:
|
||||
processed_inter, processed_feat = matrix_data["cached_scores"][cache_key]
|
||||
valid_sketch_idxs_inter = matrix_data["cached_valid_idxs"][cache_key]
|
||||
else:
|
||||
# 实时计算逻辑(同原代码)
|
||||
user_idx_inter = matrix_data["user_index_interaction"].get(user_id)
|
||||
user_idx_feature = matrix_data["user_index_feature"].get(user_id)
|
||||
|
||||
category_iids = matrix_data["category_to_iids"].get(category, [])
|
||||
valid_sketch_idxs_inter = [
|
||||
idx for iid, idx in matrix_data["sketch_index_interaction"].items()
|
||||
if iid in category_iids
|
||||
]
|
||||
|
||||
# 处理交互分数
|
||||
raw_inter_scores = []
|
||||
if user_idx_inter is not None and valid_sketch_idxs_inter:
|
||||
raw_inter_scores = matrix_data["interaction_matrix"][user_idx_inter, valid_sketch_idxs_inter]
|
||||
processed_inter = raw_inter_scores * 0.7
|
||||
|
||||
# 处理特征分数
|
||||
valid_sketch_idxs_feature = [
|
||||
idx for iid, idx in matrix_data["sketch_index_feature"].items()
|
||||
if iid in category_iids
|
||||
]
|
||||
raw_feat_scores = []
|
||||
if user_idx_feature is not None and valid_sketch_idxs_feature:
|
||||
raw_feat_scores = matrix_data["feature_matrix"][user_idx_feature, valid_sketch_idxs_feature]
|
||||
raw_feat_scores = (raw_feat_scores - np.min(raw_feat_scores)) / (
|
||||
np.max(raw_feat_scores) - np.min(raw_feat_scores) + 1e-8)
|
||||
processed_feat = raw_feat_scores
|
||||
else:
|
||||
processed_feat = np.array([])
|
||||
|
||||
# 更新缓存
|
||||
matrix_data["cached_scores"][cache_key] = (processed_inter, processed_feat)
|
||||
matrix_data["cached_valid_idxs"][cache_key] = valid_sketch_idxs_inter
|
||||
|
||||
# 合并分数
|
||||
if brand_id is not None:
|
||||
brand_idx_feature = matrix_data["brand_index_map"].get(brand_id)
|
||||
|
||||
brand_feat_valid = (
|
||||
matrix_data["brand_feature_matrix"].size > 0 and # 矩阵非空
|
||||
brand_idx_feature is not None and
|
||||
valid_sketch_idxs_feature # 有可用索引
|
||||
)
|
||||
|
||||
if brand_feat_valid:
|
||||
raw_brand_feat_scores = matrix_data["brand_feature_matrix"][
|
||||
brand_idx_feature, valid_sketch_idxs_feature
|
||||
]
|
||||
raw_brand_feat_scores = (raw_brand_feat_scores - np.min(raw_brand_feat_scores)) / (
|
||||
np.max(raw_brand_feat_scores) - np.min(raw_brand_feat_scores) + 1e-8
|
||||
)
|
||||
processed_brand_feat = raw_brand_feat_scores
|
||||
|
||||
# 如果 processed_feat 是空的,替换为全 0,避免 shape 不一致
|
||||
if processed_feat.size == 0:
|
||||
processed_feat = np.zeros_like(processed_brand_feat)
|
||||
|
||||
final_scores = processed_inter + 0.3 * (
|
||||
(1 - brand_scale) * processed_feat + brand_scale * processed_brand_feat
|
||||
)
|
||||
else:
|
||||
# brand 信息不可用
|
||||
final_scores = processed_inter + 0.3 * processed_feat if processed_feat.size > 0 else processed_inter
|
||||
else:
|
||||
final_scores = processed_inter + 0.3 * processed_feat if processed_feat.size > 0 else processed_inter
|
||||
|
||||
valid_sketch_idxs = matrix_data["cached_valid_idxs"][cache_key]
|
||||
|
||||
# 概率采样
|
||||
scores = np.array(final_scores)
|
||||
|
||||
# 调整后的概率转换(带温度控制的softmax)
|
||||
def calibrated_softmax(scores, temperature=1.0):
|
||||
scores = scores / temperature
|
||||
scale = scores - max(scores)
|
||||
exps = np.exp(scale)
|
||||
return exps / np.sum(exps)
|
||||
|
||||
probs = calibrated_softmax(scores, 0.09)
|
||||
|
||||
chosen_indices = np.random.choice(
|
||||
len(valid_sketch_idxs),
|
||||
size=min(num_recommendations, len(valid_sketch_idxs)),
|
||||
p=probs,
|
||||
replace=False
|
||||
)
|
||||
recommendations = [matrix_data["iid_to_sketch"][valid_sketch_idxs[idx]] for idx in chosen_indices]
|
||||
|
||||
logger.info(f"推荐生成完成,耗时: {time.time() - start_time:.2f}秒")
|
||||
return recommendations
|
||||
|
||||
logger.info("增量监听定时任务已启动")
|
||||
except Exception as e:
|
||||
logger.error(f"推荐失败: {str(e)}", exc_info=True)
|
||||
logger.error(f"启动增量监听任务失败: {e}", exc_info=True)
|
||||
|
||||
|
||||
@router.get("/recommend/{user_id}/{category}", response_model=List[str])
async def recommend(
    user_id: int,
    category: str,
    style: Optional[str] = Query(
        None,
        description="风格样式(可选):若传入,则在利用分支对同 style 的候选进行加分",
    ),
):
    """New-style recommendation endpoint (Milvus + Redis preference vectors)."""
    try:
        hits = get_new_recommendations(user_id, category, style)
        # Callers expect a single-element list; empty string when no match.
        return [hits[0] if hits else ""]
    except Exception as e:
        logger.error("新版推荐接口失败 [user=%s, category=%s]: %s", user_id, category, e, exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
|
||||
@@ -10,8 +10,10 @@ from app.api import api_design_pre_processing
|
||||
from app.api import api_extraction_project_info
|
||||
from app.api import api_generate_image
|
||||
from app.api import api_image2sketch
|
||||
from app.api import api_import_sys_sketch
|
||||
from app.api import api_mannequins_edit
|
||||
from app.api import api_pose_transform
|
||||
from app.api import api_precompute
|
||||
from app.api import api_prompt_generation
|
||||
from app.api import api_recommendation
|
||||
from app.api import api_super_resolution
|
||||
@@ -36,3 +38,5 @@ router.include_router(api_mannequins_edit.router, tags=['api_mannequins_edit'],
|
||||
router.include_router(api_pose_transform.router, tags=['api_pose_transform'], prefix="/api")
|
||||
router.include_router(api_clothing_seg.router, tags=['api_clothing_seg'], prefix="/api")
|
||||
router.include_router(api_extraction_project_info.router, tags=['api_extraction_project_info'], prefix="/api")
|
||||
router.include_router(api_import_sys_sketch.router, tags=['api_import_sys_sketch'], prefix="/api")
|
||||
router.include_router(api_precompute.router, tags=['api_precompute'], prefix="/api")
|
||||
|
||||
@@ -82,9 +82,9 @@ MILVUS_TABLE_SEG = "seg_cache"
|
||||
DB_HOST = '18.167.251.121' # 数据库主机地址
|
||||
# DB_PORT = int( 33006)
|
||||
DB_PORT = 33008 # 数据库端口
|
||||
DB_USERNAME = 'aida_con_python' # 数据库用户名
|
||||
DB_USERNAME = 'aida_con' # 数据库用户名
|
||||
DB_PASSWORD = '123456' # 数据库密码
|
||||
DB_NAME = 'aida' # 数据库库名
|
||||
DB_NAME = 'aida_back' # 数据库库名
|
||||
|
||||
# openai
|
||||
os.environ['SERPAPI_API_KEY'] = "a793513017b0718db7966207c31703d280d12435c982f1e67bbcbffa52e7632c"
|
||||
|
||||
@@ -11,7 +11,6 @@ from app.api.api_route import router
|
||||
from app.core.config import settings
|
||||
from app.core.record_api_count import count_api_calls
|
||||
from app.schemas.response_template import ResponseModel
|
||||
from app.service.recommend.service import load_resources
|
||||
from logging_env import LOGGER_CONFIG_DICT
|
||||
|
||||
logging.config.dictConfig(LOGGER_CONFIG_DICT)
|
||||
|
||||
@@ -1,240 +1,240 @@
|
||||
# 预加载资源
|
||||
import logging
|
||||
import time
|
||||
from collections import defaultdict
|
||||
import os
|
||||
import json
|
||||
import numpy as np
|
||||
|
||||
from app.core.config import DB_CONFIG, RECOMMEND_PATH_PREFIX
|
||||
|
||||
logger = logging.getLogger()
|
||||
import pymysql
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
HEAT_VECTOR_FILE = 'heat_vectors_data/heat_vectors.json' # 可动态加载或配置
|
||||
|
||||
# Global in-memory store for all recommender matrices, index mappings and
# per-(user, category) score caches. Populated by load_resources(); read by
# the recommendation endpoints and precache_user_category().
matrix_data = {
    "interaction_matrix": None,            # user x sketch interaction scores
    "feature_matrix": None,                # user x sketch feature scores
    "user_index_interaction": None,        # user_id -> row in interaction matrix
    "sketch_index_interaction": None,      # iid -> column in interaction matrix
    "user_index_feature": None,            # user_id -> row in feature matrix
    "sketch_index_feature": None,          # iid -> column in feature matrix
    "iid_to_sketch": None,                 # iid -> sketch image path
    "category_to_iids": None,              # category -> [iid, ...]
    "cached_scores": {},                   # (user_id, category) -> (inter, feat) scores
    "cached_valid_idxs": {},               # (user_id, category) -> valid column indices
    "category_sketch_idxs_inter": None,
    "category_sketch_idxs_feature": None,
    "user_inter_full": dict(),
    "user_feat_full": dict(),
    "brand_feature_matrix": None,          # brand x sketch feature scores (optional file)
    "brand_index_map": None,               # brand_id -> row in brand matrix (optional file)
    "heat_data": {},                       # category -> {url: heat score} for cold start
}
|
||||
|
||||
|
||||
def load_resources():
    """Load every matrix / mapping required by the recommender and warm caches.

    Side effects:
        Populates the module-level ``matrix_data`` dict, clears the per-user
        score caches, triggers ``precache_user_category()`` and loads the
        heat-vector JSON file when present.

    Raises:
        RuntimeError: when any resource fails to load; the original error is
            chained for diagnosis.
    """
    try:
        start_time = time.time()

        # Invalidate per-user caches before reloading the matrices.
        matrix_data["cached_scores"].clear()
        matrix_data["cached_valid_idxs"].clear()

        # sketch path -> iid, inverted to iid -> sketch path for output.
        sketch_to_iid = np.load(f'{RECOMMEND_PATH_PREFIX}sketch_to_iid.npy', allow_pickle=True).item()
        matrix_data["iid_to_sketch"] = {v: k for k, v in sketch_to_iid.items()}

        matrix_data["interaction_matrix"] = np.load(f"{RECOMMEND_PATH_PREFIX}interaction_matrix.npy", allow_pickle=True)
        matrix_data["user_index_interaction"] = np.load(f"{RECOMMEND_PATH_PREFIX}user_index_interaction_matrix.npy", allow_pickle=True).item()
        matrix_data["sketch_index_interaction"] = np.load(f"{RECOMMEND_PATH_PREFIX}sketch_index_interaction_matrix.npy",
                                                          allow_pickle=True).item()

        matrix_data["feature_matrix"] = np.load(f"{RECOMMEND_PATH_PREFIX}feature_matrix.npy", allow_pickle=True)

        # Brand matrices are optional; degrade gracefully when files are missing.
        brand_feature_path = f"{RECOMMEND_PATH_PREFIX}brand_feature_matrix.npy"
        if os.path.exists(brand_feature_path):
            matrix_data["brand_feature_matrix"] = np.load(brand_feature_path, allow_pickle=True)
        else:
            logger.warning("brand_feature_matrix 文件不存在,使用空数组")
            matrix_data["brand_feature_matrix"] = np.array([])

        brand_index_path = f"{RECOMMEND_PATH_PREFIX}brand_index_map.npy"
        if os.path.exists(brand_index_path):
            matrix_data["brand_index_map"] = np.load(brand_index_path, allow_pickle=True).item()
        else:
            logger.warning("brand_index_map 文件不存在,使用空字典")
            matrix_data["brand_index_map"] = {}

        matrix_data["user_index_feature"] = np.load(f"{RECOMMEND_PATH_PREFIX}user_index_feature_matrix.npy", allow_pickle=True).item()

        matrix_data["sketch_index_feature"] = np.load(f"{RECOMMEND_PATH_PREFIX}sketch_index_feature_matrix.npy", allow_pickle=True).item()

        # Invert iid -> category into category -> [iids].
        category_to_iid_map = np.load(f"{RECOMMEND_PATH_PREFIX}iid_to_category_interaction_matrix.npy", allow_pickle=True).item()
        matrix_data["category_to_iids"] = defaultdict(list)
        for iid, cat in category_to_iid_map.items():
            matrix_data["category_to_iids"][cat].append(iid)

        logger.info(f"资源加载完成,耗时: {time.time() - start_time:.2f}秒")

        # Warm the per-(user, category) score caches.
        precache_user_category()

        # Optional heat vectors used for cold-start recommendations.
        if os.path.exists(HEAT_VECTOR_FILE):
            with open(HEAT_VECTOR_FILE, 'r', encoding='utf-8') as f:
                heat_json = json.load(f)
            matrix_data["heat_data"] = heat_json.get("data", {})
            logger.info(f"热度向量数据加载完成,共加载 {len(matrix_data['heat_data'])} 个类别")
        else:
            matrix_data["heat_data"] = {}

    except Exception as e:
        # Keep the traceback in the log and chain the original error so the
        # caller sees the real cause, not just the generic RuntimeError.
        logger.error(f"资源加载失败: {str(e)}", exc_info=True)
        raise RuntimeError("初始化失败") from e
|
||||
|
||||
|
||||
def precache_user_category():
    """Precompute per-(user, category) score caches, with timing statistics.

    Reads the matrices loaded by load_resources() from ``matrix_data`` and
    fills ``matrix_data["cached_scores"]`` / ``matrix_data["cached_valid_idxs"]``
    using a thread pool. No-ops when resources are not loaded yet.
    """
    # Guard: skip precaching until all required matrices are in place.
    if not all([
        matrix_data["interaction_matrix"] is not None,
        matrix_data["feature_matrix"] is not None,
        matrix_data["user_index_interaction"] is not None
    ]):
        logger.warning("资源未加载完成,跳过预缓存")
        return

    start_time = time.perf_counter()
    # Accumulated wall-clock timings per phase, in seconds.
    time_stats = {
        "get_all_user_categories": 0,
        "process_user_category": 0,
        "thread_execution": 0,
        "cache_update": 0,
        "total": 0,
    }

    # Time the user -> categories lookup (database round-trip).
    t1 = time.perf_counter()
    user_categories = get_all_user_categories()
    time_stats["get_all_user_categories"] = time.perf_counter() - t1

    precached_count = 0

    def process_user_category(user_id, categories):
        """Compute cache entries for one user's categories (with timing)."""
        local_cache = {}
        local_valid_idxs = {}
        t_start = time.perf_counter()

        for category in categories:
            cache_key = (user_id, category)
            if cache_key in matrix_data["cached_scores"]:
                continue

            try:
                user_idx_inter = matrix_data["user_index_interaction"].get(user_id)
                user_idx_feature = matrix_data["user_index_feature"].get(user_id)

                # Time the category -> sketch-column-index resolution.
                t_iid = time.perf_counter()
                category_iids = matrix_data["category_to_iids"].get(category, [])
                valid_sketch_idxs_inter = [matrix_data["sketch_index_interaction"][iid]
                                           for iid in category_iids if iid in matrix_data["sketch_index_interaction"]]
                valid_sketch_idxs_feature = [matrix_data["sketch_index_feature"][iid]
                                             for iid in category_iids if iid in matrix_data["sketch_index_feature"]]
                time_stats["process_user_category"] += time.perf_counter() - t_iid

                # Time the matrix slicing / normalization work.
                t_matrix = time.perf_counter()
                processed_inter = np.zeros(len(valid_sketch_idxs_inter))
                if user_idx_inter is not None and valid_sketch_idxs_inter:
                    raw_inter_scores = matrix_data["interaction_matrix"][user_idx_inter, valid_sketch_idxs_inter]
                    processed_inter = raw_inter_scores * 0.7

                processed_feat = np.zeros(len(valid_sketch_idxs_feature))
                if user_idx_feature is not None and valid_sketch_idxs_feature:
                    raw_feat_scores = matrix_data["feature_matrix"][user_idx_feature, valid_sketch_idxs_feature]
                    # Min-max normalize feature scores; epsilon avoids /0.
                    raw_feat_scores = (raw_feat_scores - np.min(raw_feat_scores)) / (
                            np.max(raw_feat_scores) - np.min(raw_feat_scores) + 1e-8)
                    processed_feat = raw_feat_scores * 0.3
                time_stats["process_user_category"] += time.perf_counter() - t_matrix

                # Only cache when the two score vectors are alignable.
                if len(processed_inter) == len(processed_feat):
                    local_cache[cache_key] = (processed_inter, processed_feat)
                    local_valid_idxs[cache_key] = valid_sketch_idxs_inter

            except Exception as e:
                logger.error(f"预缓存失败 (user={user_id}, category={category}): {str(e)}")

        return local_cache, local_valid_idxs

    # Fan the per-user work out to a thread pool and merge results back in.
    t2 = time.perf_counter()
    with ThreadPoolExecutor(max_workers=8) as executor:
        futures = {executor.submit(process_user_category, user_id, categories): user_id for user_id, categories in user_categories.items()}
        for future in futures:
            try:
                t_cache = time.perf_counter()
                cache_part, valid_idxs_part = future.result()
                matrix_data["cached_scores"].update(cache_part)
                matrix_data["cached_valid_idxs"].update(valid_idxs_part)
                time_stats["cache_update"] += time.perf_counter() - t_cache
                precached_count += len(cache_part)
            except Exception as e:
                logger.error(f"线程执行错误: {str(e)}")
    time_stats["thread_execution"] = time.perf_counter() - t2

    time_stats["total"] = time.perf_counter() - start_time

    # Emit the timing summary.
    logger.info(f"""
    预缓存完成,共缓存 {precached_count} 组数据,耗时统计如下:
    - 获取用户类别数据: {time_stats["get_all_user_categories"]:.2f}s
    - 计算用户类别缓存: {time_stats["process_user_category"]:.2f}s
    - 线程任务执行: {time_stats["thread_execution"]:.2f}s
    - 更新缓存数据: {time_stats["cache_update"]:.2f}s
    - 总耗时: {time_stats["total"]:.2f}s
    """)
|
||||
|
||||
|
||||
def get_all_user_categories():
    """Return {account_id: set(categories)} built from the preference log.

    Queries distinct (account_id, path) rows and maps each path to a category
    via get_category_from_path(). Returns an empty dict on any database error.
    """
    conn = None
    try:
        conn = pymysql.connect(**DB_CONFIG)
        cursor = conn.cursor()

        query = """
        SELECT DISTINCT account_id, path
        FROM user_preference_log_prediction
        """
        cursor.execute(query)
        rows = cursor.fetchall()

        # Group categories per account; a set de-duplicates repeated paths.
        grouped = defaultdict(set)
        for account_id, path in rows:
            grouped[account_id].add(get_category_from_path(path))

        return dict(grouped)

    except Exception as e:
        logger.error(f"数据库查询失败: {str(e)}")
        return {}
    finally:
        if conn:
            conn.close()
|
||||
|
||||
|
||||
def get_category_from_path(path: str) -> str:
    """Derive a category key from a slash-separated storage path.

    The key joins path segments 2 and 3, e.g.
    ``"bucket/user/Images/chair/x.png"`` -> ``"Images_chair"``.

    Args:
        path: slash-separated storage path.

    Returns:
        ``"<parts[2]>_<parts[3]>"``, or ``"unknown"`` when the path has
        fewer than four segments or is not a valid string.
    """
    try:
        parts = path.split('/')
        if len(parts) >= 4:
            return f"{parts[2]}_{parts[3]}"
        return "unknown"
    # Fix: was a bare `except:`, which also swallows SystemExit and
    # KeyboardInterrupt; narrow to Exception.
    except Exception:
        return "unknown"
|
||||
# # 预加载资源
|
||||
# import logging
|
||||
# import time
|
||||
# from collections import defaultdict
|
||||
# import os
|
||||
# import json
|
||||
# import numpy as np
|
||||
#
|
||||
# from app.core.config import DB_CONFIG, RECOMMEND_PATH_PREFIX
|
||||
#
|
||||
# logger = logging.getLogger()
|
||||
# import pymysql
|
||||
# from concurrent.futures import ThreadPoolExecutor
|
||||
#
|
||||
# HEAT_VECTOR_FILE = 'heat_vectors_data/heat_vectors.json' # 可动态加载或配置
|
||||
#
|
||||
# matrix_data = {
|
||||
# "interaction_matrix": None,
|
||||
# "feature_matrix": None,
|
||||
# "user_index_interaction": None,
|
||||
# "sketch_index_interaction": None,
|
||||
# "user_index_feature": None,
|
||||
# "sketch_index_feature": None,
|
||||
# "iid_to_sketch": None,
|
||||
# "category_to_iids": None,
|
||||
# "cached_scores": {},
|
||||
# "cached_valid_idxs": {},
|
||||
# "category_sketch_idxs_inter": None,
|
||||
# "category_sketch_idxs_feature": None,
|
||||
# "user_inter_full": dict(),
|
||||
# "user_feat_full": dict(),
|
||||
# "brand_feature_matrix": None,
|
||||
# "brand_index_map": None,
|
||||
# "heat_data": {},
|
||||
# }
|
||||
#
|
||||
#
|
||||
# def load_resources():
|
||||
# """加载所有矩阵和映射关系,并触发预缓存"""
|
||||
# try:
|
||||
# start_time = time.time()
|
||||
#
|
||||
# # 清空缓存
|
||||
# matrix_data["cached_scores"].clear()
|
||||
# matrix_data["cached_valid_idxs"].clear()
|
||||
#
|
||||
# # 加载数据
|
||||
# sketch_to_iid = np.load(f'{RECOMMEND_PATH_PREFIX}sketch_to_iid.npy', allow_pickle=True).item()
|
||||
# matrix_data["iid_to_sketch"] = {v: k for k, v in sketch_to_iid.items()}
|
||||
#
|
||||
# matrix_data["interaction_matrix"] = np.load(f"{RECOMMEND_PATH_PREFIX}interaction_matrix.npy", allow_pickle=True)
|
||||
# matrix_data["user_index_interaction"] = np.load(f"{RECOMMEND_PATH_PREFIX}user_index_interaction_matrix.npy", allow_pickle=True).item()
|
||||
# matrix_data["sketch_index_interaction"] = np.load(f"{RECOMMEND_PATH_PREFIX}sketch_index_interaction_matrix.npy",
|
||||
# allow_pickle=True).item()
|
||||
#
|
||||
# matrix_data["feature_matrix"] = np.load(f"{RECOMMEND_PATH_PREFIX}feature_matrix.npy", allow_pickle=True)
|
||||
#
|
||||
# brand_feature_path = f"{RECOMMEND_PATH_PREFIX}brand_feature_matrix.npy"
|
||||
# if os.path.exists(brand_feature_path):
|
||||
# matrix_data["brand_feature_matrix"] = np.load(brand_feature_path, allow_pickle=True)
|
||||
# else:
|
||||
# logger.warning("brand_feature_matrix 文件不存在,使用空数组")
|
||||
# matrix_data["brand_feature_matrix"] = np.array([])
|
||||
#
|
||||
# # brand_index_map
|
||||
# brand_index_path = f"{RECOMMEND_PATH_PREFIX}brand_index_map.npy"
|
||||
# if os.path.exists(brand_index_path):
|
||||
# matrix_data["brand_index_map"] = np.load(brand_index_path, allow_pickle=True).item()
|
||||
# else:
|
||||
# logger.warning("brand_index_map 文件不存在,使用空字典")
|
||||
# matrix_data["brand_index_map"] = {}
|
||||
#
|
||||
# matrix_data["user_index_feature"] = np.load(f"{RECOMMEND_PATH_PREFIX}user_index_feature_matrix.npy", allow_pickle=True).item()
|
||||
#
|
||||
# matrix_data["sketch_index_feature"] = np.load(f"{RECOMMEND_PATH_PREFIX}sketch_index_feature_matrix.npy", allow_pickle=True).item()
|
||||
#
|
||||
# category_to_iid_map = np.load(f"{RECOMMEND_PATH_PREFIX}iid_to_category_interaction_matrix.npy", allow_pickle=True).item()
|
||||
# matrix_data["category_to_iids"] = defaultdict(list)
|
||||
# for iid, cat in category_to_iid_map.items():
|
||||
# matrix_data["category_to_iids"][cat].append(iid)
|
||||
#
|
||||
# logger.info(f"资源加载完成,耗时: {time.time() - start_time:.2f}秒")
|
||||
#
|
||||
# # 触发预缓存
|
||||
# precache_user_category()
|
||||
#
|
||||
# if os.path.exists(HEAT_VECTOR_FILE):
|
||||
# with open(HEAT_VECTOR_FILE, 'r', encoding='utf-8') as f:
|
||||
# heat_json = json.load(f)
|
||||
# matrix_data["heat_data"] = heat_json.get("data", {})
|
||||
# logger.info(f"热度向量数据加载完成,共加载 {len(matrix_data['heat_data'])} 个类别")
|
||||
# else:
|
||||
# matrix_data["heat_data"] = {}
|
||||
#
|
||||
# except Exception as e:
|
||||
# logger.error(f"资源加载失败: {str(e)}")
|
||||
# raise RuntimeError("初始化失败")
|
||||
#
|
||||
#
|
||||
# def precache_user_category():
|
||||
# """优化后的用户分类预缓存(添加耗时统计)"""
|
||||
# if not all([
|
||||
# matrix_data["interaction_matrix"] is not None,
|
||||
# matrix_data["feature_matrix"] is not None,
|
||||
# matrix_data["user_index_interaction"] is not None
|
||||
# ]):
|
||||
# logger.warning("资源未加载完成,跳过预缓存")
|
||||
# return
|
||||
#
|
||||
# start_time = time.perf_counter()
|
||||
# time_stats = {
|
||||
# "get_all_user_categories": 0,
|
||||
# "process_user_category": 0,
|
||||
# "thread_execution": 0,
|
||||
# "cache_update": 0,
|
||||
# "total": 0,
|
||||
# }
|
||||
#
|
||||
# # 统计用户类别获取时间
|
||||
# t1 = time.perf_counter()
|
||||
# user_categories = get_all_user_categories()
|
||||
# time_stats["get_all_user_categories"] = time.perf_counter() - t1
|
||||
#
|
||||
# precached_count = 0
|
||||
#
|
||||
# def process_user_category(user_id, categories):
|
||||
# """单用户类别缓存计算(统计耗时)"""
|
||||
# local_cache = {}
|
||||
# local_valid_idxs = {}
|
||||
# t_start = time.perf_counter()
|
||||
#
|
||||
# for category in categories:
|
||||
# cache_key = (user_id, category)
|
||||
# if cache_key in matrix_data["cached_scores"]:
|
||||
# continue
|
||||
#
|
||||
# try:
|
||||
# user_idx_inter = matrix_data["user_index_interaction"].get(user_id)
|
||||
# user_idx_feature = matrix_data["user_index_feature"].get(user_id)
|
||||
#
|
||||
# # 统计获取类别 IID 耗时
|
||||
# t_iid = time.perf_counter()
|
||||
# category_iids = matrix_data["category_to_iids"].get(category, [])
|
||||
# valid_sketch_idxs_inter = [matrix_data["sketch_index_interaction"][iid]
|
||||
# for iid in category_iids if iid in matrix_data["sketch_index_interaction"]]
|
||||
# valid_sketch_idxs_feature = [matrix_data["sketch_index_feature"][iid]
|
||||
# for iid in category_iids if iid in matrix_data["sketch_index_feature"]]
|
||||
# time_stats["process_user_category"] += time.perf_counter() - t_iid
|
||||
#
|
||||
# # 统计矩阵计算耗时
|
||||
# t_matrix = time.perf_counter()
|
||||
# processed_inter = np.zeros(len(valid_sketch_idxs_inter))
|
||||
# if user_idx_inter is not None and valid_sketch_idxs_inter:
|
||||
# raw_inter_scores = matrix_data["interaction_matrix"][user_idx_inter, valid_sketch_idxs_inter]
|
||||
# processed_inter = raw_inter_scores * 0.7
|
||||
#
|
||||
# processed_feat = np.zeros(len(valid_sketch_idxs_feature))
|
||||
# if user_idx_feature is not None and valid_sketch_idxs_feature:
|
||||
# raw_feat_scores = matrix_data["feature_matrix"][user_idx_feature, valid_sketch_idxs_feature]
|
||||
# raw_feat_scores = (raw_feat_scores - np.min(raw_feat_scores)) / (
|
||||
# np.max(raw_feat_scores) - np.min(raw_feat_scores) + 1e-8)
|
||||
# processed_feat = raw_feat_scores * 0.3
|
||||
# time_stats["process_user_category"] += time.perf_counter() - t_matrix
|
||||
#
|
||||
# if len(processed_inter) == len(processed_feat):
|
||||
# local_cache[cache_key] = (processed_inter, processed_feat)
|
||||
# local_valid_idxs[cache_key] = valid_sketch_idxs_inter
|
||||
#
|
||||
# except Exception as e:
|
||||
# logger.error(f"预缓存失败 (user={user_id}, category={category}): {str(e)}")
|
||||
#
|
||||
# return local_cache, local_valid_idxs
|
||||
#
|
||||
# # 统计线程执行时间
|
||||
# t2 = time.perf_counter()
|
||||
# with ThreadPoolExecutor(max_workers=8) as executor:
|
||||
# futures = {executor.submit(process_user_category, user_id, categories): user_id for user_id, categories in user_categories.items()}
|
||||
# for future in futures:
|
||||
# try:
|
||||
# t_cache = time.perf_counter()
|
||||
# cache_part, valid_idxs_part = future.result()
|
||||
# matrix_data["cached_scores"].update(cache_part)
|
||||
# matrix_data["cached_valid_idxs"].update(valid_idxs_part)
|
||||
# time_stats["cache_update"] += time.perf_counter() - t_cache
|
||||
# precached_count += len(cache_part)
|
||||
# except Exception as e:
|
||||
# logger.error(f"线程执行错误: {str(e)}")
|
||||
# time_stats["thread_execution"] = time.perf_counter() - t2
|
||||
#
|
||||
# time_stats["total"] = time.perf_counter() - start_time
|
||||
#
|
||||
# # 输出统计信息
|
||||
# logger.info(f"""
|
||||
# 预缓存完成,共缓存 {precached_count} 组数据,耗时统计如下:
|
||||
# - 获取用户类别数据: {time_stats["get_all_user_categories"]:.2f}s
|
||||
# - 计算用户类别缓存: {time_stats["process_user_category"]:.2f}s
|
||||
# - 线程任务执行: {time_stats["thread_execution"]:.2f}s
|
||||
# - 更新缓存数据: {time_stats["cache_update"]:.2f}s
|
||||
# - 总耗时: {time_stats["total"]:.2f}s
|
||||
# """)
|
||||
#
|
||||
#
|
||||
# def get_all_user_categories():
|
||||
# """获取所有用户及其对应的分类"""
|
||||
# conn = None
|
||||
# try:
|
||||
# conn = pymysql.connect(**DB_CONFIG)
|
||||
# cursor = conn.cursor()
|
||||
#
|
||||
# query = """
|
||||
# SELECT DISTINCT account_id, path
|
||||
# FROM user_preference_log_prediction
|
||||
# """
|
||||
# cursor.execute(query)
|
||||
# results = cursor.fetchall()
|
||||
#
|
||||
# user_categories = defaultdict(set)
|
||||
# for account_id, path in results:
|
||||
# category = get_category_from_path(path)
|
||||
# user_categories[account_id].add(category)
|
||||
#
|
||||
# return dict(user_categories)
|
||||
#
|
||||
# except Exception as e:
|
||||
# logger.error(f"数据库查询失败: {str(e)}")
|
||||
# return {}
|
||||
# finally:
|
||||
# if conn:
|
||||
# conn.close()
|
||||
#
|
||||
#
|
||||
# def get_category_from_path(path: str) -> str:
|
||||
# """从路径解析类别"""
|
||||
# try:
|
||||
# parts = path.split('/')
|
||||
# if len(parts) >= 4:
|
||||
# return f"{parts[2]}_{parts[3]}"
|
||||
# return "unknown"
|
||||
# except:
|
||||
# return "unknown"
|
||||
|
||||
1
app/service/recommendation_system/__init__.py
Normal file
1
app/service/recommendation_system/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
73
app/service/recommendation_system/config.py
Normal file
73
app/service/recommendation_system/config.py
Normal file
@@ -0,0 +1,73 @@
|
||||
"""
|
||||
推荐系统配置
|
||||
"""
|
||||
import os
|
||||
from app.core.config import (
|
||||
DB_CONFIG, DB_HOST, DB_PORT, DB_USERNAME, DB_PASSWORD, DB_NAME,
|
||||
REDIS_HOST, REDIS_PORT, REDIS_DB,
|
||||
MILVUS_URL, MILVUS_TOKEN, MILVUS_ALIAS,
|
||||
MINIO_URL, MINIO_ACCESS, MINIO_SECRET, MINIO_SECURE
|
||||
)
|
||||
|
||||
# Milvus 集合名称
|
||||
MILVUS_COLLECTION_SKETCH_VECTORS = "sketch_vectors_norm"
|
||||
|
||||
# Redis key 前缀
|
||||
REDIS_KEY_USER_PREF_PREFIX = "user_pref"
|
||||
|
||||
# 推荐系统配置参数
|
||||
RECOMMENDATION_CONFIG = {
|
||||
# 时间衰减半衰期(用于计算时间衰减权重)
|
||||
# 值越小,最近的行为权重越大
|
||||
"K_half": 20,
|
||||
|
||||
# 探索与利用的比例 (0.0-1.0)
|
||||
# - 值越大,使用探索分支(随机推荐)的几率越大,结果更随机
|
||||
# - 值越小,使用利用分支(基于用户偏好)的几率越大,结果更精准
|
||||
# - 建议范围: 0.3-0.7,要增加随机性可提高到 0.6-0.8
|
||||
"explore_ratio": 0.5,
|
||||
|
||||
# 向量检索返回的候选数量
|
||||
# 值越大,候选池越大,但计算成本也越高
|
||||
# 建议范围: 100-1000
|
||||
"topk": 1000,
|
||||
|
||||
# Style 加分系数(同 style 的候选进行加分)
|
||||
# 值越大,匹配 style 的候选被选中的概率越大
|
||||
# 要降低某个结果的重复率,可以降低此值(如 0.1 或 0.05)
|
||||
"style_bonus": 0.2,
|
||||
|
||||
# Softmax 抽样的温度参数
|
||||
# - 温度越高(>1.0),概率分布越均匀,结果更随机,重复率更低
|
||||
# - 温度越低(<1.0),高分项概率越大,结果更集中,重复率更高
|
||||
# - 温度=1.0 为标准 Softmax
|
||||
# - 建议范围: 1.0-3.0,要增加随机性可提高到 2.0-3.0
|
||||
"softmax_temperature": 0.07,
|
||||
|
||||
# 监听间隔(秒)
|
||||
"listen_interval_sec": 30,
|
||||
|
||||
# 批量处理大小
|
||||
"batch_size": 1000,
|
||||
|
||||
# Redis 过期时间(秒,30天)
|
||||
"redis_expire_seconds": 2592000,
|
||||
|
||||
# 向量维度
|
||||
"vector_dim": 2048,
|
||||
}
|
||||
|
||||
# 数据库表名
|
||||
TABLE_USER_PREFERENCE_LOG = "user_preference_log_test"
|
||||
TABLE_SYS_FILE = "t_sys_file"
|
||||
|
||||
# MySQL 连接配置(用于推荐系统)
|
||||
MYSQL_CONFIG = {
|
||||
"host": DB_HOST,
|
||||
"port": DB_PORT,
|
||||
"user": DB_USERNAME,
|
||||
"password": DB_PASSWORD,
|
||||
"database": DB_NAME,
|
||||
"charset": "utf8mb4"
|
||||
}
|
||||
|
||||
331
app/service/recommendation_system/import_sys_sketch_to_milvus.py
Normal file
331
app/service/recommendation_system/import_sys_sketch_to_milvus.py
Normal file
@@ -0,0 +1,331 @@
|
||||
"""
|
||||
独立脚本:从 t_sys_file 导入系统图向量到 Milvus
|
||||
可以单独运行,不依赖整个项目启动
|
||||
|
||||
使用方法:
|
||||
python -m app.service.recommendation_system.import_sys_sketch_to_milvus
|
||||
或
|
||||
python app/service/recommendation_system/import_sys_sketch_to_milvus.py
|
||||
"""
|
||||
import sys
|
||||
import os
|
||||
import logging
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
# 添加项目根目录到 Python 路径
|
||||
project_root = Path(__file__).parent.parent.parent.parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
import numpy as np
|
||||
import pymysql
|
||||
from tqdm import tqdm
|
||||
|
||||
from app.service.recommendation_system.config import (
|
||||
MYSQL_CONFIG, TABLE_SYS_FILE,
|
||||
RECOMMENDATION_CONFIG, MILVUS_COLLECTION_SKETCH_VECTORS
|
||||
)
|
||||
from app.service.recommendation_system.vector_utils import extract_feature_vector, normalize_vector
|
||||
from app.service.recommendation_system.milvus_client import create_collection, insert_vectors
|
||||
|
||||
# 配置日志
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(levelname)s - %(message)s',
|
||||
handlers=[
|
||||
logging.StreamHandler(),
|
||||
logging.FileHandler('import_sys_sketch.log', encoding='utf-8')
|
||||
]
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_sys_file_records(conn, limit=None, offset=0):
    """Fetch importable system-image rows from t_sys_file.

    Only rows with ``level1_type = 'Images'``, a non-empty style and
    ``deprecated != 1`` are returned, ordered by id.

    Args:
        conn: open MySQL connection.
        limit: maximum number of rows; None (or 0) means no limit.
        offset: pagination offset, applied only together with limit.

    Returns:
        Sequence of (id, url, style, level3_type, level2_type, deprecated).
    """
    sql = f"""
        SELECT id, url, style, level3_type, level2_type, deprecated
        FROM {TABLE_SYS_FILE}
        WHERE level1_type = 'Images'
        AND style IS NOT NULL
        AND style != ''
        AND deprecated != 1
        ORDER BY id
    """
    if limit:
        sql += f" LIMIT {limit} OFFSET {offset}"

    cur = conn.cursor()
    cur.execute(sql)
    rows = cur.fetchall()
    cur.close()
    return rows
|
||||
|
||||
|
||||
def get_total_count(conn):
    """Return how many importable system-image rows exist in t_sys_file."""
    count_sql = f"""
        SELECT COUNT(*)
        FROM {TABLE_SYS_FILE}
        WHERE level1_type = 'Images'
        AND style IS NOT NULL
        AND style != ''
        AND deprecated != 1
    """
    cur = conn.cursor()
    cur.execute(count_sql)
    total = cur.fetchone()[0]
    cur.close()
    return total
|
||||
|
||||
|
||||
def process_and_insert_batch(records, batch_size=1000, retry_times=3):
    """
    Extract feature vectors for the given rows and insert them into Milvus in batches.

    Each record becomes one Milvus entity keyed by its url (path). Records whose
    vector extraction fails (all-zero vector) or whose batch insert fails are
    collected and retried one record at a time, up to ``retry_times`` passes.

    Args:
        records: rows of (id, url, style, level3_type, level2_type, deprecated).
        batch_size: number of entities per bulk insert.
        retry_times: number of single-record retry passes over failures.

    Returns:
        (success_count, failed_count, failed_records) where ``failed_records``
        is a list of (sys_file_id, url) tuples still failing after all retries.
    """
    success_count = 0
    failed_count = 0
    failed_records = []
    batch_data = []

    # Progress bar over all input records
    with tqdm(total=len(records), desc="处理记录", unit="条") as pbar:
        for idx, (sys_file_id, url, style, level3_type, level2_type, deprecated) in enumerate(records):
            try:
                # Category key, e.g. "<level3>_<level2>", lower-cased
                category = f"{level3_type.lower()}_{level2_type.lower()}"

                # Extract then L2-normalise so that inner product ≈ cosine
                feature_vector = extract_feature_vector(url)
                feature_vector = normalize_vector(feature_vector)

                # An all-zero vector signals extraction failure
                if np.all(feature_vector == 0):
                    logger.warning(f"向量提取失败,跳过: {url} (id={sys_file_id})")
                    failed_count += 1
                    failed_records.append((sys_file_id, url))
                    pbar.update(1)
                    continue

                # Milvus entity payload
                data_item = {
                    "path": url,
                    "sys_file_id": sys_file_id,
                    "style": style,
                    "category": category,
                    "is_system_sketch": 1,
                    "deprecated": deprecated if deprecated else 0,
                    "feature_vector": feature_vector.tolist()
                }

                batch_data.append(data_item)

                # Flush a full batch
                if len(batch_data) >= batch_size:
                    try:
                        insert_vectors(batch_data)
                        success_count += len(batch_data)
                        batch_data = []
                        logger.info(f"已成功插入 {success_count} 条记录")
                    except Exception as e:
                        # Whole batch counts as failed; items go to the retry list
                        logger.error(f"批量写入失败: {e}")
                        failed_count += len(batch_data)
                        failed_records.extend([(item["sys_file_id"], item["path"]) for item in batch_data])
                        batch_data = []

                pbar.update(1)

            except Exception as e:
                logger.error(f"处理记录失败 [id={sys_file_id}, url={url}]: {e}")
                failed_count += 1
                failed_records.append((sys_file_id, url))
                pbar.update(1)

    # Flush the final partial batch
    if batch_data:
        try:
            insert_vectors(batch_data)
            success_count += len(batch_data)
            logger.info(f"写入剩余 {len(batch_data)} 条记录")
        except Exception as e:
            logger.error(f"写入剩余数据失败: {e}")
            failed_count += len(batch_data)
            failed_records.extend([(item["sys_file_id"], item["path"]) for item in batch_data])

    # Retry failed records, one at a time
    if failed_records and retry_times > 0:
        logger.info(f"开始重试 {len(failed_records)} 条失败记录,最多重试 {retry_times} 次...")

        for retry in range(retry_times):
            if not failed_records:
                break

            retry_failed = []
            with tqdm(total=len(failed_records), desc=f"重试第 {retry + 1} 次", unit="条") as pbar:
                for sys_file_id, url in failed_records:
                    try:
                        # Re-read the row; NOTE(review): opens a short-lived DB
                        # connection per retried record — confirm this is acceptable
                        conn = pymysql.connect(**MYSQL_CONFIG)
                        cursor = conn.cursor()
                        cursor.execute(f"""
                            SELECT id, url, style, level3_type, level2_type, deprecated
                            FROM {TABLE_SYS_FILE}
                            WHERE id = %s
                        """, (sys_file_id,))
                        record = cursor.fetchone()
                        cursor.close()
                        conn.close()

                        if not record:
                            retry_failed.append((sys_file_id, url))
                            pbar.update(1)
                            continue

                        sys_file_id, url, style, level3_type, level2_type, deprecated = record
                        category = f"{level3_type.lower()}_{level2_type.lower()}"

                        feature_vector = extract_feature_vector(url)
                        feature_vector = normalize_vector(feature_vector)
                        if np.all(feature_vector == 0):
                            retry_failed.append((sys_file_id, url))
                            pbar.update(1)
                            continue

                        data_item = {
                            "path": url,
                            "sys_file_id": sys_file_id,
                            "style": style,
                            "category": category,
                            "is_system_sketch": 1,
                            "deprecated": deprecated if deprecated else 0,
                            "feature_vector": feature_vector.tolist()
                        }

                        insert_vectors([data_item])
                        # A retried success converts a failure into a success
                        success_count += 1
                        failed_count -= 1
                        pbar.update(1)

                    except Exception as e:
                        logger.error(f"重试失败 [id={sys_file_id}, url={url}]: {e}")
                        retry_failed.append((sys_file_id, url))
                        pbar.update(1)

            failed_records = retry_failed
            if failed_records:
                logger.warning(f"第 {retry + 1} 次重试后仍有 {len(failed_records)} 条记录失败")

    return success_count, failed_count, failed_records
|
||||
|
||||
|
||||
def main():
    """CLI entry point: parse arguments, ensure the Milvus collection exists,
    then stream rows from t_sys_file into Milvus with batching and retries."""
    parser = argparse.ArgumentParser(description='从 t_sys_file 导入系统图向量到 Milvus')
    parser.add_argument('--batch-size', type=int, default=1000, help='批量处理大小(默认:1000)')
    parser.add_argument('--retry-times', type=int, default=3, help='失败重试次数(默认:3)')
    parser.add_argument('--limit', type=int, default=None, help='限制处理数量(用于测试,默认:不限制)')
    parser.add_argument('--offset', type=int, default=0, help='起始偏移量(默认:0)')
    parser.add_argument('--skip-create-collection', action='store_true', help='跳过创建集合(如果集合已存在)')

    args = parser.parse_args()

    # Banner with the effective configuration
    logger.info("=" * 60)
    logger.info("开始从 t_sys_file 导入系统图向量到 Milvus")
    logger.info("=" * 60)
    logger.info(f"配置参数:")
    logger.info(f"  - 批量大小: {args.batch_size}")
    logger.info(f"  - 重试次数: {args.retry_times}")
    logger.info(f"  - 限制数量: {args.limit if args.limit else '不限制'}")
    logger.info(f"  - 起始偏移: {args.offset}")
    logger.info("=" * 60)

    # 1. Create the Milvus collection unless the caller opted out
    if not args.skip_create_collection:
        logger.info("创建 Milvus 集合...")
        try:
            create_collection()
            logger.info("Milvus 集合创建成功(或已存在)")
        except Exception as e:
            logger.error(f"创建 Milvus 集合失败: {e}")
            return
    else:
        logger.info("跳过创建集合")

    # 2. Connect to the database
    logger.info("连接数据库...")
    try:
        conn = pymysql.connect(**MYSQL_CONFIG)
        logger.info("数据库连接成功")
    except Exception as e:
        logger.error(f"数据库连接失败: {e}")
        return

    try:
        # 3. Count matching rows (informational)
        total_count = get_total_count(conn)
        logger.info(f"找到 {total_count} 条系统图记录")

        if total_count == 0:
            logger.warning("没有找到系统图数据")
            return

        # 4. Fetch the rows to import
        logger.info("获取记录...")
        records = get_sys_file_records(conn, limit=args.limit, offset=args.offset)
        logger.info(f"获取到 {len(records)} 条记录")

        if not records:
            logger.warning("没有获取到记录")
            return

        # 5. Vectorise and bulk insert
        logger.info("开始处理记录...")
        success_count, failed_count, failed_records = process_and_insert_batch(
            records,
            batch_size=args.batch_size,
            retry_times=args.retry_times
        )

        # 6. Summary (first 10 failures listed explicitly)
        logger.info("=" * 60)
        logger.info("导入完成!")
        logger.info(f"  - 成功: {success_count} 条")
        logger.info(f"  - 失败: {failed_count} 条")
        if failed_records:
            logger.warning(f"  - 失败记录列表(前10条):")
            for sys_file_id, url in failed_records[:10]:
                logger.warning(f"    ID={sys_file_id}, URL={url}")
            if len(failed_records) > 10:
                logger.warning(f"    ... 还有 {len(failed_records) - 10} 条失败记录")
        logger.info("=" * 60)

    except Exception as e:
        logger.error(f"处理过程中发生错误: {e}", exc_info=True)
    finally:
        conn.close()
        logger.info("数据库连接已关闭")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
343
app/service/recommendation_system/incremental_listener.py
Normal file
343
app/service/recommendation_system/incremental_listener.py
Normal file
@@ -0,0 +1,343 @@
|
||||
"""
|
||||
增量监听模块
|
||||
实时监听 user_preference_log_test 表的新增记录,更新用户偏好向量
|
||||
"""
|
||||
import logging
|
||||
import math
|
||||
import pymysql
|
||||
import numpy as np
|
||||
from typing import List, Dict, Set, Tuple, Optional
|
||||
from datetime import datetime
|
||||
from collections import defaultdict
|
||||
|
||||
from apscheduler.schedulers.background import BackgroundScheduler
|
||||
from apscheduler.schedulers.blocking import BlockingScheduler
|
||||
|
||||
from app.service.recommendation_system.config import (
|
||||
MYSQL_CONFIG, TABLE_USER_PREFERENCE_LOG, TABLE_SYS_FILE,
|
||||
RECOMMENDATION_CONFIG, REDIS_KEY_USER_PREF_PREFIX
|
||||
)
|
||||
from app.service.recommendation_system.vector_utils import extract_feature_vector, compute_weighted_average, normalize_vector
|
||||
from app.service.recommendation_system.milvus_client import query_vectors_by_paths, insert_vectors
|
||||
from app.service.utils.redis_utils import Redis
|
||||
import json
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class IncrementalListener:
|
||||
"""增量监听器"""
|
||||
|
||||
    def __init__(self):
        # Timestamp (data_time) of the newest processed record; None until first poll
        self.last_process_time = None
        # (account_id, category) combinations already handled in the current cycle
        self.processed_combinations: Set[Tuple[int, str]] = set()
        # Polling interval in seconds, taken from the shared config
        self.listen_interval = RECOMMENDATION_CONFIG["listen_interval_sec"]
|
||||
|
||||
    def get_new_like_records(self) -> List[Tuple]:
        """
        Fetch like records added since the last poll.

        On the first poll (``last_process_time`` is None) the look-back window
        is the most recent 30 minutes; afterwards everything newer than the
        last processed record's ``data_time`` is returned.

        Returns:
            Rows of (id, account_id, path, category, style, data_time,
            is_system_sketch, sys_file_id) ordered by data_time ascending;
            [] on any error.
        """
        conn = None
        try:
            conn = pymysql.connect(**MYSQL_CONFIG)
            cursor = conn.cursor()

            if self.last_process_time is None:
                # First run: look back 30 minutes
                cursor.execute(f"""
                    SELECT id, account_id, path, category, style, data_time, is_system_sketch, sys_file_id
                    FROM {TABLE_USER_PREFERENCE_LOG}
                    WHERE data_time > DATE_SUB(NOW(), INTERVAL 30 MINUTE)
                    ORDER BY data_time
                """)
            else:
                # Subsequent runs: everything after the last processed timestamp
                cursor.execute(f"""
                    SELECT id, account_id, path, category, style, data_time, is_system_sketch, sys_file_id
                    FROM {TABLE_USER_PREFERENCE_LOG}
                    WHERE data_time > %s
                    ORDER BY data_time
                """, (self.last_process_time,))

            records = cursor.fetchall()
            return records

        except Exception as e:
            logger.error(f"获取新增点赞记录失败: {e}", exc_info=True)
            return []
        finally:
            if conn:
                conn.close()
|
||||
|
||||
    def process_new_records(self, records: List[Tuple]):
        """
        Group new like records by (account_id, category) and refresh the
        preference vector once per distinct combination.

        Args:
            records: rows as returned by ``get_new_like_records``, ordered by
                data_time ascending (the last row advances the watermark).
        """
        if not records:
            return

        # Group records per user + category
        user_category_records = defaultdict(list)
        for record in records:
            account_id = record[1]
            category = record[3]
            if category:  # only process records that carry a category
                user_category_records[(account_id, category)].append(record)

        # Deduplicate: handle each (account_id, category) at most once per cycle
        to_process = []
        for (account_id, category), recs in user_category_records.items():
            if (account_id, category) not in self.processed_combinations:
                to_process.append((account_id, category, recs))
                self.processed_combinations.add((account_id, category))

        logger.info(f"需要处理 {len(to_process)} 个用户-类别组合")

        # Recompute the preference vector for each combination
        for account_id, category, recs in to_process:
            try:
                self.update_user_preference_vector(account_id, category)
            except Exception as e:
                logger.error(f"更新用户偏好向量失败 [user={account_id}, category={category}]: {e}", exc_info=True)

        # Advance the watermark to the newest processed record
        if records:
            self.last_process_time = records[-1][5]  # data_time
            # Reset the dedup set so the same user/category is not skipped next cycle
            self.processed_combinations.clear()
|
||||
|
||||
    def update_user_preference_vector(self, account_id: int, category: str):
        """
        Recompute and persist one user's preference vector for a category.

        The vector is the weighted average of the feature vectors of all
        sketches the user liked in this category, where
        ``weight = 0.5 ** (rank / K_half) * (1 + log(1 + like_count))``
        (rank 1 = most recent like). The result is L2-normalised and stored
        in Redis as a JSON list.

        Args:
            account_id: user id.
            category: category key.

        Raises:
            Re-raises any unexpected error after logging it.
        """
        conn = None
        try:
            conn = pymysql.connect(**MYSQL_CONFIG)
            cursor = conn.cursor()

            # 1. All like records for this user + category, newest first
            cursor.execute(f"""
                SELECT path, data_time
                FROM {TABLE_USER_PREFERENCE_LOG}
                WHERE account_id = %s AND category = %s
                ORDER BY data_time DESC
            """, (account_id, category))

            like_records = cursor.fetchall()

            if not like_records:
                return

            # 2. Like counts per path (drives the popularity weight)
            paths = [r[0] for r in like_records]
            placeholders = ','.join(['%s'] * len(paths))
            cursor.execute(f"""
                SELECT path, COUNT(*) as like_count
                FROM {TABLE_USER_PREFERENCE_LOG}
                WHERE account_id = %s AND category = %s AND path IN ({placeholders})
                GROUP BY path
            """, (account_id, category) + tuple(paths))

            like_counts = {row[0]: row[1] for row in cursor.fetchall()}

            # 3. Bulk-fetch feature vectors from Milvus
            vectors_dict = query_vectors_by_paths(paths)

            # Compute vectors on the fly for paths missing from Milvus
            # (new user sketches or earlier import failures)
            missing_paths = [p for p in paths if p not in vectors_dict]
            if missing_paths:
                logger.info(f"用户 {account_id} 类别 {category} 有 {len(missing_paths)} 个 path 需要实时计算向量")
                self._compute_and_insert_missing_vectors(missing_paths, conn)
                # Re-query after back-filling
                vectors_dict = query_vectors_by_paths(paths)

            # 4. Per-record weight = time decay * popularity
            vectors = []
            weights = []
            K_half = RECOMMENDATION_CONFIG["K_half"]

            for k, (path, data_time) in enumerate(like_records, 1):
                if path not in vectors_dict:
                    continue

                vector_data = vectors_dict[path]
                feature_vector = np.array(vector_data["feature_vector"])

                # Time-decay weight (rank 1 = most recent like)
                d_k = 0.5 ** (k / K_half)

                # Popularity weight from repeated likes of the same path
                like_count = like_counts.get(path, 1)
                p_i = 1 + math.log(1 + like_count)

                # Combined weight
                w_i = d_k * p_i

                vectors.append(feature_vector)
                weights.append(w_i)

            if not vectors:
                logger.warning(f"用户 {account_id} 类别 {category} 没有有效向量")
                return

            # 5. Weighted average + L2 normalisation so that IP ≈ cosine
            preference_vector = compute_weighted_average(vectors, weights)
            preference_vector = normalize_vector(preference_vector)

            # 6. Persist to Redis with an expiry
            key = f"{REDIS_KEY_USER_PREF_PREFIX}:{account_id}:{category}"
            vector_json = json.dumps(preference_vector.tolist())
            Redis.write(
                key=key,
                value=vector_json,
                expire=RECOMMENDATION_CONFIG["redis_expire_seconds"]
            )

            logger.debug(f"用户偏好向量更新成功 [user={account_id}, category={category}]")

        except Exception as e:
            logger.error(f"更新用户偏好向量失败 [user={account_id}, category={category}]: {e}", exc_info=True)
            raise
        finally:
            if conn:
                conn.close()
|
||||
|
||||
    def _compute_and_insert_missing_vectors(self, paths: List[str], conn: pymysql.connections.Connection):
        """
        Extract and insert feature vectors for paths absent from Milvus.

        A path matching a ``t_sys_file`` url is stored as a system sketch
        (style/category taken from that row); otherwise it is stored as a
        user sketch whose category is borrowed from the preference log when
        available.

        Args:
            paths: paths missing from Milvus.
            conn: open MySQL connection (reused by the caller; not closed here).
        """
        cursor = conn.cursor()
        data_to_insert = []

        for path in paths:
            try:
                # Determine the data source: is this a system sketch?
                cursor.execute(f"""
                    SELECT id, url, style, level3_type, level2_type, deprecated
                    FROM {TABLE_SYS_FILE}
                    WHERE url = %s
                    LIMIT 1
                """, (path,))

                sys_file = cursor.fetchone()

                # Extract the feature vector (all-zero means extraction failure).
                # NOTE(review): unlike the bulk-import path, the vector is NOT
                # L2-normalised here — confirm whether that is intentional.
                feature_vector = extract_feature_vector(path)

                if np.all(feature_vector == 0):
                    logger.warning(f"向量提取失败,跳过: {path}")
                    continue

                if sys_file:
                    # System sketch: metadata comes from t_sys_file
                    sys_file_id, url, style, level3_type, level2_type, deprecated = sys_file
                    category = f"{level3_type.lower()}_{level2_type.lower()}"

                    data_item = {
                        "path": path,
                        "sys_file_id": sys_file_id,
                        "style": style,
                        "category": category,
                        "is_system_sketch": 1,
                        "deprecated": deprecated if deprecated else 0,
                        "feature_vector": feature_vector.tolist()
                    }
                else:
                    # User sketch: borrow a category from the preference log if any
                    cursor.execute(f"""
                        SELECT category
                        FROM {TABLE_USER_PREFERENCE_LOG}
                        WHERE path = %s AND category IS NOT NULL
                        LIMIT 1
                    """, (path,))

                    category_result = cursor.fetchone()
                    category = category_result[0] if category_result else None

                    data_item = {
                        "path": path,
                        "sys_file_id": None,
                        "style": None,
                        "category": category,
                        "is_system_sketch": 0,
                        "deprecated": 0,
                        "feature_vector": feature_vector.tolist()
                    }

                data_to_insert.append(data_item)

            except Exception as e:
                logger.error(f"处理缺失向量失败 [{path}]: {e}")

        # Bulk insert everything that succeeded
        if data_to_insert:
            try:
                insert_vectors(data_to_insert)
                logger.info(f"成功插入 {len(data_to_insert)} 个缺失向量")
            except Exception as e:
                logger.error(f"插入缺失向量失败: {e}")
|
||||
|
||||
def process_once(self):
|
||||
"""单次轮询任务,供调度器调用"""
|
||||
try:
|
||||
records = self.get_new_like_records()
|
||||
|
||||
if records:
|
||||
logger.info(f"发现 {len(records)} 条新增记录")
|
||||
self.process_new_records(records)
|
||||
else:
|
||||
logger.debug("没有新增记录")
|
||||
except Exception as e:
|
||||
logger.error(f"监听轮询异常: {e}", exc_info=True)
|
||||
|
||||
|
||||
def start_background_listener(scheduler: BackgroundScheduler):
    """Register the incremental-listener polling job on an existing background scheduler."""
    listener = IncrementalListener()
    job_options = dict(
        seconds=listener.listen_interval,
        max_instances=1,
        coalesce=True,
        id="recommendation_incremental_listener",
        replace_existing=True,
    )
    scheduler.add_job(listener.process_once, "interval", **job_options)
    logger.info("增量监听任务已注册到调度器")
|
||||
|
||||
|
||||
def start_blocking_listener():
    """Start the incremental listener with a blocking scheduler (standalone-script mode)."""
    listener = IncrementalListener()
    scheduler = BlockingScheduler()
    job_options = dict(
        trigger="interval",
        seconds=listener.listen_interval,
        max_instances=1,
        coalesce=True,
        id="recommendation_incremental_listener",
        replace_existing=True,
    )
    scheduler.add_job(listener.process_once, **job_options)
    logger.info("增量监听调度器已启动(BlockingScheduler)")
    scheduler.start()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Standalone entry point: run the incremental listener with a blocking scheduler.
    start_blocking_listener()
|
||||
|
||||
295
app/service/recommendation_system/milvus_client.py
Normal file
295
app/service/recommendation_system/milvus_client.py
Normal file
@@ -0,0 +1,295 @@
|
||||
"""
|
||||
Milvus 客户端封装
|
||||
"""
|
||||
import logging
|
||||
from typing import List, Dict, Optional, Any
|
||||
import numpy as np
|
||||
from pymilvus import MilvusClient, FieldSchema, CollectionSchema, DataType, connections, Collection
|
||||
|
||||
from app.core.config import MILVUS_URL, MILVUS_TOKEN, MILVUS_ALIAS
|
||||
from app.service.recommendation_system.config import MILVUS_COLLECTION_SKETCH_VECTORS, RECOMMENDATION_CONFIG
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Milvus 客户端(单例)
|
||||
_milvus_client = None
|
||||
|
||||
|
||||
def get_milvus_client() -> MilvusClient:
    """Return the process-wide MilvusClient, creating it lazily on first use.

    Raises:
        Exception: re-raised when the client cannot be constructed/connected.
    """
    global _milvus_client
    if _milvus_client is not None:
        return _milvus_client
    try:
        _milvus_client = MilvusClient(
            uri=MILVUS_URL,
            token=MILVUS_TOKEN,
            db_name=MILVUS_ALIAS
        )
        logger.info("Milvus 客户端连接成功")
    except Exception as e:
        logger.error(f"Milvus 客户端连接失败: {e}")
        raise
    return _milvus_client
|
||||
|
||||
|
||||
def create_collection():
    """
    Create the Milvus collection ``sketch_vectors`` if it does not already exist.

    Collection schema:
    - path (PK, varchar(512))          - primary key, MinIO logical URL
    - sys_file_id (int64)              - system file id
    - style (varchar(50))              - style label
    - category (varchar(50))           - category key
    - is_system_sketch (int8)          - 1 = system sketch, 0 = user sketch
    - deprecated (int8)                - soft-delete flag
    - feature_vector (FloatVector)     - feature vector (dim from RECOMMENDATION_CONFIG)

    Raises:
        Exception: re-raised when collection/index creation fails.
    """
    client = get_milvus_client()

    # Idempotency: bail out if the collection already exists.
    collections = client.list_collections()
    if MILVUS_COLLECTION_SKETCH_VECTORS in collections:
        logger.info(f"集合 {MILVUS_COLLECTION_SKETCH_VECTORS} 已存在")
        return

    try:
        # Parse host/port out of MILVUS_URL
        # (handles forms like http://host.docker.internal:19530).
        url_clean = MILVUS_URL.replace("http://", "").replace("https://", "")
        if ":" in url_clean:
            host, port_str = url_clean.split(":", 1)
            port = int(port_str)
        else:
            host = url_clean
            port = 19530  # default Milvus port when none is given

        # Use the legacy ORM-style API (connections/Collection) for creation,
        # which is more reliable for schema + index setup.
        try:
            connections.connect(
                alias=MILVUS_ALIAS,
                host=host,
                port=port,
                token=MILVUS_TOKEN if MILVUS_TOKEN else None
            )
            logger.info(f"已连接到 Milvus: {host}:{port}")
        except Exception as conn_e:
            # An existing connection under the same alias is not an error.
            if "already exists" in str(conn_e).lower() or "Connection already exists" in str(conn_e):
                logger.info("Milvus 连接已存在")
            else:
                logger.warning(f"连接 Milvus 时出现警告: {conn_e}")

        # Field definitions (see docstring for semantics).
        fields = [
            FieldSchema(name="path", dtype=DataType.VARCHAR, is_primary=True, max_length=512),
            FieldSchema(name="sys_file_id", dtype=DataType.INT64),
            FieldSchema(name="style", dtype=DataType.VARCHAR, max_length=50),
            FieldSchema(name="category", dtype=DataType.VARCHAR, max_length=50),
            FieldSchema(name="is_system_sketch", dtype=DataType.INT8),
            FieldSchema(name="deprecated", dtype=DataType.INT8),
            FieldSchema(
                name="feature_vector",
                dtype=DataType.FLOAT_VECTOR,
                dim=RECOMMENDATION_CONFIG["vector_dim"]
            )
        ]

        # Assemble the collection schema.
        schema = CollectionSchema(
            fields=fields,
            description="Sketch vectors collection for recommendation system"
        )

        # Create the collection under the configured alias.
        collection = Collection(
            name=MILVUS_COLLECTION_SKETCH_VECTORS,
            schema=schema,
            using=MILVUS_ALIAS
        )

        # Build the vector index.
        # NOTE: metric_type IP (inner product) matches the search side;
        # with L2-normalized vectors, IP is equivalent to cosine similarity.
        index_params = {
            "metric_type": "IP",  # inner product
            "index_type": "IVF_FLAT",
            "params": {"nlist": 1024}
        }

        collection.create_index(
            field_name="feature_vector",
            index_params=index_params
        )

        logger.info(f"集合 {MILVUS_COLLECTION_SKETCH_VECTORS} 创建成功")

    except Exception as e:
        logger.error(f"创建集合失败: {e}", exc_info=True)
        raise
|
||||
|
||||
|
||||
def insert_vectors(data: List[Dict[str, Any]]):
    """Bulk-insert sketch vector rows into the Milvus collection.

    Args:
        data: list of row dicts, each with:
            - path: str (primary key)
            - sys_file_id: int (optional)
            - style: str (optional)
            - category: str (optional)
            - is_system_sketch: int (default 1)
            - deprecated: int (default 0)
            - feature_vector: List[float]

    Raises:
        Exception: re-raised when the Milvus insert fails.
    """
    if not data:
        return

    milvus = get_milvus_client()
    try:
        milvus.insert(
            collection_name=MILVUS_COLLECTION_SKETCH_VECTORS,
            data=data,
        )
    except Exception as e:
        logger.error(f"插入向量失败: {e}", exc_info=True)
        raise
    logger.info(f"成功插入 {len(data)} 条向量数据")
|
||||
|
||||
|
||||
def query_vectors_by_paths(paths: List[str]) -> Dict[str, Dict]:
    """
    Batch-fetch vector rows from Milvus keyed by path.

    Fix over the previous version: each path is now quoted with embedded
    backslashes and single quotes escaped, so a path containing ``'`` can no
    longer break (or poison) the filter expression built by string
    interpolation.

    Args:
        paths: list of MinIO logical URLs (primary keys in the collection)

    Returns:
        {path: row_dict} for every path found; missing paths are absent.
        Returns {} on error or empty input.
    """
    if not paths:
        return {}

    client = get_milvus_client()

    try:
        # Quote each value for the Milvus boolean expression, escaping
        # backslashes first, then single quotes.
        def _quote(p: str) -> str:
            return "'" + p.replace("\\", "\\\\").replace("'", "\\'") + "'"

        filter_expr = f"path in [{', '.join(_quote(p) for p in paths)}]"

        # MilvusClient.query takes `filter=` (not `expr=`).
        results = client.query(
            collection_name=MILVUS_COLLECTION_SKETCH_VECTORS,
            filter=filter_expr,
            output_fields=["path", "feature_vector", "style", "category", "sys_file_id", "is_system_sketch", "deprecated"]
        )

        # Index the rows by their primary key.
        return {row["path"]: row for row in results}
    except Exception as e:
        logger.error(f"查询向量失败: {e}", exc_info=True)
        return {}
|
||||
|
||||
|
||||
def search_similar_vectors(
    query_vector: np.ndarray,
    category: str,
    topk: int = 500,
    style: Optional[str] = None
) -> List[Dict]:
    """
    Vector similarity search against the sketch collection.

    Args:
        query_vector: query embedding (dim must match the collection)
        category: mandatory category filter
        topk: maximum number of hits to return
        style: optional style filter

    Returns:
        List of hit dicts with path, score (IP distance), style, category,
        sys_file_id. Empty list on error.
    """
    client = get_milvus_client()

    try:
        # Build the scalar filter expression.
        # MilvusClient.search takes `filter=` (not `expr=`).
        # NOTE(review): category/style are interpolated unescaped — assumed to
        # come from trusted config values; confirm before accepting user input.
        filter_expr = f"category == '{category}' && deprecated == 0"
        if style:
            filter_expr += f" && style == '{style}'"

        # ANN search; metric_type IP matches the index created for this field.
        results = client.search(
            collection_name=MILVUS_COLLECTION_SKETCH_VECTORS,
            data=[query_vector.tolist()],
            anns_field="feature_vector",
            search_params={"metric_type": "IP", "params": {"nprobe": 10}},
            limit=topk,
            filter=filter_expr,
            output_fields=["path", "style", "category", "sys_file_id"]
        )

        # Flatten the first (and only) query's hits into plain dicts.
        formatted_results = []
        if results and len(results) > 0:
            for hit in results[0]:
                formatted_results.append({
                    "path": hit.get("entity", {}).get("path", ""),
                    "score": hit.get("distance", 0.0),
                    "style": hit.get("entity", {}).get("style", ""),
                    "category": hit.get("entity", {}).get("category", ""),
                    "sys_file_id": hit.get("entity", {}).get("sys_file_id")
                })

        return formatted_results
    except Exception as e:
        logger.error(f"向量检索失败: {e}", exc_info=True)
        return []
|
||||
|
||||
|
||||
def query_random_candidates(category: str, style: Optional[str] = None, limit: int = 10) -> List[Dict]:
    """
    Fetch random candidate sketches (used by the exploration branch).

    Args:
        category: category filter
        style: optional style filter
        limit: number of candidates to return

    Returns:
        Candidate row dicts (path, style, category); empty list on error.
    """
    client = get_milvus_client()

    try:
        # Assemble the scalar filter from the mandatory and optional parts.
        conditions = [f"category == '{category}'", "deprecated == 0"]
        if style:
            conditions.append(f"style == '{style}'")
        filter_expr = " && ".join(conditions)

        # Pull a large pool first, then down-sample client-side.
        results = client.query(
            collection_name=MILVUS_COLLECTION_SKETCH_VECTORS,
            filter=filter_expr,
            output_fields=["path", "style", "category"],
            limit=10000
        )

        if len(results) > limit:
            import random
            results = random.sample(results, limit)

        return results
    except Exception as e:
        logger.error(f"随机查询候选失败: {e}", exc_info=True)
        return []
|
||||
|
||||
556
app/service/recommendation_system/precompute.py
Normal file
556
app/service/recommendation_system/precompute.py
Normal file
@@ -0,0 +1,556 @@
|
||||
"""
|
||||
预计算模块
|
||||
包含:数据库表结构优化、Milvus集合创建、系统图向量预计算、初始用户偏好向量生成
|
||||
"""
|
||||
import logging
|
||||
import math
|
||||
import pymysql
|
||||
import numpy as np
|
||||
from typing import List, Dict, Tuple, Optional
|
||||
from collections import defaultdict
|
||||
|
||||
from app.service.recommendation_system.config import (
|
||||
MYSQL_CONFIG, TABLE_USER_PREFERENCE_LOG, TABLE_SYS_FILE,
|
||||
RECOMMENDATION_CONFIG, REDIS_KEY_USER_PREF_PREFIX
|
||||
)
|
||||
from app.service.recommendation_system.vector_utils import extract_feature_vector, normalize_vector, compute_weighted_average
|
||||
from app.service.recommendation_system.milvus_client import (
|
||||
create_collection, insert_vectors, query_vectors_by_paths
|
||||
)
|
||||
from app.service.utils.redis_utils import Redis
|
||||
import json
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def optimize_database_table():
    """
    Optimize the user-preference-log table schema.

    Adds denormalized columns (category, style, is_system_sketch, sys_file_id)
    and the indexes that the recommendation queries rely on. Idempotent:
    existing columns/indexes are detected and skipped. Rolls back on failure.
    """
    conn = None
    try:
        conn = pymysql.connect(**MYSQL_CONFIG)
        cursor = conn.cursor()

        # 1. Add the denormalized columns.
        logger.info("添加冗余字段...")
        alter_sqls = [
            f"ALTER TABLE {TABLE_USER_PREFERENCE_LOG} ADD COLUMN category VARCHAR(100) COMMENT '类别:lower(level3_type + \"_\" + level2_type)'",
            f"ALTER TABLE {TABLE_USER_PREFERENCE_LOG} ADD COLUMN style VARCHAR(50) COMMENT '风格样式'",
            f"ALTER TABLE {TABLE_USER_PREFERENCE_LOG} ADD COLUMN is_system_sketch TINYINT(1) DEFAULT 1 COMMENT '是否为系统图(1-是,0-用户图)'",
            f"ALTER TABLE {TABLE_USER_PREFERENCE_LOG} ADD COLUMN sys_file_id BIGINT NULL COMMENT '系统文件ID'",
        ]

        for sql in alter_sqls:
            try:
                cursor.execute(sql)
                logger.info(f"执行成功: {sql[:50]}...")
            except Exception as e:
                # MySQL raises when re-adding an existing column; treat as already done.
                if "Duplicate column name" in str(e):
                    logger.info(f"字段已存在,跳过: {sql[:50]}...")
                else:
                    logger.warning(f"执行失败: {sql[:50]}... 错误: {e}")

        # 2. Create indexes. MySQL has no CREATE INDEX IF NOT EXISTS, so
        #    existence is checked in information_schema first.
        logger.info("创建索引...")
        index_definitions = [
            ("idx_account_category_time", ["account_id", "category", "data_time"]),
            ("idx_account_path", ["account_id", "path"]),
        ]

        for index_name, columns in index_definitions:
            try:
                # Does the index already exist on this table?
                cursor.execute(f"""
                    SELECT COUNT(*)
                    FROM information_schema.statistics
                    WHERE table_schema = DATABASE()
                    AND table_name = '{TABLE_USER_PREFERENCE_LOG}'
                    AND index_name = '{index_name}'
                """)
                exists = cursor.fetchone()[0] > 0

                if exists:
                    logger.info(f"索引已存在,跳过: {index_name}")
                else:
                    # Create the missing index.
                    columns_str = ', '.join(columns)
                    create_sql = f"CREATE INDEX {index_name} ON {TABLE_USER_PREFERENCE_LOG}({columns_str})"
                    cursor.execute(create_sql)
                    logger.info(f"索引创建成功: {index_name}")
            except Exception as e:
                logger.warning(f"索引创建失败: {index_name} 错误: {e}")

        conn.commit()
        logger.info("数据库表结构优化完成")

    except Exception as e:
        logger.error(f"数据库表结构优化失败: {e}", exc_info=True)
        if conn:
            conn.rollback()
    finally:
        if conn:
            conn.close()
|
||||
|
||||
|
||||
def migrate_historical_data(batch_size: int = 1000):
    """
    Backfill the denormalized columns (category/style/is_system_sketch/sys_file_id)
    for historical user-preference rows.

    Fix over the previous version: it paginated with LIMIT/OFFSET over
    ``WHERE category IS NULL`` while the same loop UPDATEd matching rows out
    of that filter — every committed batch shifted the result set, so OFFSET
    skipped roughly half the rows. Pagination is now keyset-based
    (``id > last_id ORDER BY id``), which is stable under concurrent updates.

    Args:
        batch_size: number of rows fetched per batch
    """
    conn = None
    try:
        conn = pymysql.connect(**MYSQL_CONFIG)
        cursor = conn.cursor()

        # Count rows still needing migration (for progress logging only).
        cursor.execute(f"""
            SELECT COUNT(*)
            FROM {TABLE_USER_PREFERENCE_LOG} u
            WHERE u.category IS NULL
        """)
        total_count = cursor.fetchone()[0]
        logger.info(f"需要迁移的记录数: {total_count}")

        if total_count == 0:
            logger.info("无需迁移数据")
            return

        processed = 0
        last_id = 0  # keyset cursor: highest primary key handled so far

        while True:
            # Fetch the next batch strictly after the last handled id.
            cursor.execute(f"""
                SELECT u.id, u.path
                FROM {TABLE_USER_PREFERENCE_LOG} u
                WHERE u.category IS NULL AND u.id > %s
                ORDER BY u.id
                LIMIT {batch_size}
            """, (last_id,))
            records = cursor.fetchall()

            if not records:
                break

            for record_id, path in records:
                # Look the path up in the system-file table.
                cursor.execute(f"""
                    SELECT id, url, style, level3_type, level2_type, deprecated
                    FROM {TABLE_SYS_FILE}
                    WHERE url = %s
                    LIMIT 1
                """, (path,))
                sys_file = cursor.fetchone()

                if sys_file:
                    # System sketch: copy its metadata onto the log row.
                    sys_file_id, url, style, level3_type, level2_type, deprecated = sys_file
                    category = f"{level3_type.lower()}_{level2_type.lower()}"

                    cursor.execute(f"""
                        UPDATE {TABLE_USER_PREFERENCE_LOG}
                        SET category = %s,
                            style = %s,
                            is_system_sketch = 1,
                            sys_file_id = %s
                        WHERE id = %s
                    """, (category, style, sys_file_id, record_id))
                else:
                    # User-uploaded sketch: mark it and clear metadata.
                    cursor.execute(f"""
                        UPDATE {TABLE_USER_PREFERENCE_LOG}
                        SET is_system_sketch = 0,
                            category = NULL,
                            style = NULL,
                            sys_file_id = NULL
                        WHERE id = %s
                    """, (record_id,))

            conn.commit()
            processed += len(records)
            last_id = records[-1][0]  # advance the keyset cursor
            logger.info(f"已迁移 {processed}/{total_count} 条记录")

        logger.info("历史数据迁移完成")

    except Exception as e:
        logger.error(f"历史数据迁移失败: {e}", exc_info=True)
        if conn:
            conn.rollback()
    finally:
        if conn:
            conn.close()
|
||||
|
||||
|
||||
def precompute_system_sketch_vectors(batch_size: int = 1000, retry_times: int = 3):
    """
    Precompute feature vectors for all system sketches and import them into Milvus.

    Fix over the previous version: failed records now carry their own
    (style, category) metadata, so the retry pass no longer reuses the loop
    variables left over from the *last* record of the main pass — which
    silently wrote wrong style/category for every retried row.

    Args:
        batch_size: number of rows per batched Milvus insert
        retry_times: number of retry passes over failed rows
    """
    conn = None
    try:
        conn = pymysql.connect(**MYSQL_CONFIG)
        cursor = conn.cursor()

        # 1. Select candidate system sketches.
        logger.info("查询系统图数据...")
        cursor.execute(f"""
            SELECT id, url, style, level3_type, level2_type, deprecated
            FROM {TABLE_SYS_FILE}
            WHERE level1_type = 'Images'
            AND style IS NOT NULL
            AND style != ''
            AND deprecated != 1
        """)
        records = cursor.fetchall()
        logger.info(f"找到 {len(records)} 条系统图记录")

        if not records:
            logger.warning("没有找到系统图数据")
            return

        # 2. Batch processing. Each failed entry stores
        #    (sys_file_id, url, style, category) so retries use the right metadata.
        failed_records = []
        batch_data = []

        for idx, (sys_file_id, url, style, level3_type, level2_type, deprecated) in enumerate(records, 1):
            try:
                category = f"{level3_type.lower()}_{level2_type.lower()}"

                # Extract the embedding; an all-zero vector signals failure.
                feature_vector = extract_feature_vector(url)
                if np.all(feature_vector == 0):
                    logger.warning(f"向量提取失败,跳过: {url}")
                    failed_records.append((sys_file_id, url, style, category))
                    continue

                batch_data.append({
                    "path": url,
                    "sys_file_id": sys_file_id,
                    "style": style,
                    "category": category,
                    "is_system_sketch": 1,
                    "deprecated": deprecated if deprecated else 0,
                    "feature_vector": feature_vector.tolist()
                })

                # Flush a full batch.
                if len(batch_data) >= batch_size:
                    try:
                        insert_vectors(batch_data)
                        logger.info(f"已处理 {idx}/{len(records)} 条记录")
                    except Exception as e:
                        logger.error(f"批量写入失败: {e}")
                        failed_records.extend(
                            (item["sys_file_id"], item["path"], item["style"], item["category"])
                            for item in batch_data
                        )
                    batch_data = []

            except Exception as e:
                logger.error(f"处理记录失败 [{url}]: {e}")
                failed_records.append((
                    sys_file_id, url, style,
                    f"{level3_type.lower()}_{level2_type.lower()}" if level3_type and level2_type else None
                ))

        # Flush the remainder.
        if batch_data:
            try:
                insert_vectors(batch_data)
            except Exception as e:
                logger.error(f"写入剩余数据失败: {e}")
                failed_records.extend(
                    (item["sys_file_id"], item["path"], item["style"], item["category"])
                    for item in batch_data
                )

        # 3. Retry failed rows, each with its own metadata.
        if failed_records and retry_times > 0:
            logger.info(f"重试 {len(failed_records)} 条失败记录...")
            for _retry in range(retry_times):
                retry_failed = []
                for sys_file_id, url, style, category in failed_records:
                    if category is None:
                        # Metadata unavailable — cannot insert a meaningful row.
                        retry_failed.append((sys_file_id, url, style, category))
                        continue
                    try:
                        feature_vector = extract_feature_vector(url)
                        if np.all(feature_vector == 0):
                            retry_failed.append((sys_file_id, url, style, category))
                            continue
                        insert_vectors([{
                            "path": url,
                            "sys_file_id": sys_file_id,
                            "style": style,
                            "category": category,
                            "is_system_sketch": 1,
                            "deprecated": 0,
                            "feature_vector": feature_vector.tolist()
                        }])
                    except Exception as e:
                        logger.error(f"重试失败 [{url}]: {e}")
                        retry_failed.append((sys_file_id, url, style, category))

                failed_records = retry_failed
                if not failed_records:
                    break

        if failed_records:
            logger.warning(f"仍有 {len(failed_records)} 条记录处理失败")

        logger.info("系统图向量预计算完成")

    except Exception as e:
        logger.error(f"系统图向量预计算失败: {e}", exc_info=True)
    finally:
        if conn:
            conn.close()
|
||||
|
||||
|
||||
def compute_user_preference_vector(
    account_id: int,
    category: str,
    conn: Optional[pymysql.connections.Connection] = None,
    max_date: Optional["datetime"] = None
) -> Optional[np.ndarray]:
    """
    Compute a user's preference vector for one category.

    Fix over the previous version: the ``max_date`` parameter had been
    commented out of the signature while the body still referenced it,
    raising NameError on every call. It is restored with a backward-compatible
    default of None.

    Args:
        account_id: user id
        category: category key (e.g. "female_skirt")
        conn: optional existing DB connection (opened/closed here when omitted)
        max_date: optional upper bound on data_time (used in evaluation to
            restrict to training-set rows)

    Returns:
        L2-normalized preference vector, or None when no usable data exists.
    """
    should_close = False
    if conn is None:
        conn = pymysql.connect(**MYSQL_CONFIG)
        should_close = True

    try:
        cursor = conn.cursor()

        # 1. Like records, most recent first (optionally limited to < max_date).
        if max_date:
            cursor.execute(f"""
                SELECT path, data_time
                FROM {TABLE_USER_PREFERENCE_LOG}
                WHERE account_id = %s AND category = %s AND style is not null
                AND data_time < %s
                ORDER BY data_time DESC
            """, (account_id, category, max_date))
        else:
            cursor.execute(f"""
                SELECT path, data_time
                FROM {TABLE_USER_PREFERENCE_LOG}
                WHERE account_id = %s AND category = %s AND style is not null
                ORDER BY data_time DESC
            """, (account_id, category))

        like_records = cursor.fetchall()
        if not like_records:
            return None

        # 2. Per-path like counts (same max_date restriction).
        paths = [r[0] for r in like_records]
        placeholders = ','.join(['%s'] * len(paths))
        if max_date:
            cursor.execute(f"""
                SELECT path, COUNT(*) as like_count
                FROM {TABLE_USER_PREFERENCE_LOG}
                WHERE account_id = %s AND category = %s AND path IN ({placeholders})
                AND data_time < %s
                GROUP BY path
            """, (account_id, category) + tuple(paths) + (max_date,))
        else:
            cursor.execute(f"""
                SELECT path, COUNT(*) as like_count
                FROM {TABLE_USER_PREFERENCE_LOG}
                WHERE account_id = %s AND category = %s AND path IN ({placeholders})
                GROUP BY path
            """, (account_id, category) + tuple(paths))

        like_counts = {row[0]: row[1] for row in cursor.fetchall()}

        # 3. Batch-fetch stored vectors from Milvus.
        vectors_dict = query_vectors_by_paths(paths)

        # Paths without stored vectors (user sketches / anomalies) are skipped;
        # a full implementation would extract and insert them on the fly.
        missing_paths = [p for p in paths if p not in vectors_dict]
        if missing_paths:
            logger.info(f"用户 {account_id} 类别 {category} 有 {len(missing_paths)} 个 path 需要实时计算向量")

        # 4. Weighted aggregation. Time decay is currently disabled — the
        #    weight is the like-count factor only (see git history for the
        #    K_half-based decayed variant).
        vectors = []
        weights = []

        for path, _data_time in like_records:
            if path not in vectors_dict:
                continue

            feature_vector = np.array(vectors_dict[path]["feature_vector"])

            # Like-count weight: 1 + ln(1 + count).
            like_count = like_counts.get(path, 1)
            p_i = 1 + math.log(1 + like_count)

            vectors.append(feature_vector)
            weights.append(p_i)

        if not vectors:
            return None

        # 5. Weighted average + L2 normalization (makes IP ≈ cosine).
        preference_vector = compute_weighted_average(vectors, weights)
        return normalize_vector(preference_vector)

    except Exception as e:
        logger.error(f"计算用户偏好向量失败 [user={account_id}, category={category}]: {e}", exc_info=True)
        return None
    finally:
        if should_close and conn:
            conn.close()
|
||||
|
||||
|
||||
def generate_initial_user_preference_vectors(batch_size: int = 100):
    """
    Generate initial preference vectors for every (user, category) pair
    found in the preference log and cache them in Redis.

    Args:
        batch_size: progress-log interval (combinations per log line)
    """
    conn = None
    try:
        conn = pymysql.connect(**MYSQL_CONFIG)
        cursor = conn.cursor()

        # 1. Scan history for distinct (user, category) pairs with usable rows.
        logger.info("扫描用户和类别组合...")
        cursor.execute(f"""
            SELECT DISTINCT account_id, category
            FROM {TABLE_USER_PREFERENCE_LOG}
            WHERE category IS NOT NULL
            AND style IS NOT NULL
        """)

        user_categories = cursor.fetchall()
        logger.info(f"找到 {len(user_categories)} 个用户-类别组合")

        if not user_categories:
            logger.warning("没有找到用户-类别组合")
            return

        # 2. Compute and cache each pair's preference vector.
        processed = 0
        failed = 0

        for account_id, category in user_categories:
            try:
                # Weighted-average embedding over the user's liked sketches.
                preference_vector = compute_user_preference_vector(account_id, category, conn)

                if preference_vector is not None:
                    # Cache in Redis as a JSON-serialized float list.
                    key = f"{REDIS_KEY_USER_PREF_PREFIX}:{account_id}:{category}"
                    vector_json = json.dumps(preference_vector.tolist())
                    Redis.write(
                        key=key,
                        value=vector_json,
                        expire=RECOMMENDATION_CONFIG["redis_expire_seconds"]
                    )
                    processed += 1
                else:
                    failed += 1

                if (processed + failed) % batch_size == 0:
                    logger.info(f"已处理 {processed + failed}/{len(user_categories)} 个组合,成功: {processed}, 失败: {failed}")

            except Exception as e:
                logger.error(f"处理失败 [user={account_id}, category={category}]: {e}")
                failed += 1

        logger.info(f"初始用户偏好向量生成完成,成功: {processed}, 失败: {failed}")

    except Exception as e:
        logger.error(f"初始用户偏好向量生成失败: {e}", exc_info=True)
    finally:
        if conn:
            conn.close()
|
||||
|
||||
|
||||
def run_precompute():
    """Run the full precompute pipeline in order.

    Steps 2 (collection creation) and 4 (system-sketch vectorization) are
    intentionally disabled here — they are handled by the import script.
    """
    banner = "=" * 50
    logger.info(banner)
    logger.info("开始预计算任务")
    logger.info(banner)

    logger.info("\n[1/5] 优化数据库表结构...")
    optimize_database_table()

    # # 2. 创建 Milvus 集合
    # logger.info("\n[2/5] 创建 Milvus 集合...")
    # create_collection()

    logger.info("\n[3/5] 历史数据迁移...")
    migrate_historical_data()

    # # 4. 系统图向量预计算
    # logger.info("\n[4/5] 系统图向量预计算...")
    # precompute_system_sketch_vectors()

    logger.info("\n[5/5] 初始用户偏好向量生成...")
    generate_initial_user_preference_vectors()

    logger.info(banner)
    logger.info("预计算任务完成")
    logger.info(banner)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Deduplication fix: this block previously re-implemented steps 1/3/5 of
    # the pipeline inline; run the canonical entry point instead so the two
    # code paths cannot drift apart.
    run_precompute()
|
||||
214
app/service/recommendation_system/recommendation_api.py
Normal file
214
app/service/recommendation_system/recommendation_api.py
Normal file
@@ -0,0 +1,214 @@
|
||||
"""
|
||||
推荐接口实现
|
||||
实现探索/利用分支、向量检索、Softmax抽样等功能
|
||||
"""
|
||||
import logging
|
||||
import math
|
||||
import random
|
||||
import numpy as np
|
||||
from typing import List, Dict, Optional
|
||||
|
||||
from app.service.recommendation_system.config import RECOMMENDATION_CONFIG, REDIS_KEY_USER_PREF_PREFIX
|
||||
from app.service.recommendation_system.milvus_client import search_similar_vectors, query_random_candidates
|
||||
from app.service.recommendation_system.precompute import compute_user_preference_vector
|
||||
from app.service.recommendation_system.vector_utils import normalize_vector
|
||||
from app.service.utils.redis_utils import Redis
|
||||
import json
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_user_preference_vector(user_id: int, category: str) -> Optional[np.ndarray]:
    """
    Fetch a user's preference vector, preferring the Redis cache.

    On a cache miss (or unparsable cache entry) the vector is recomputed
    from the preference log and written back to Redis with the configured TTL.

    Args:
        user_id: user id
        category: category key

    Returns:
        Preference vector as float32 ndarray, or None when it cannot be computed.
    """
    key = f"{REDIS_KEY_USER_PREF_PREFIX}:{user_id}:{category}"

    # 1. Cache lookup.
    cached = Redis.read(key)
    if cached:
        try:
            return np.array(json.loads(cached), dtype=np.float32)
        except Exception as e:
            logger.warning(f"解析 Redis 向量失败 [user={user_id}, category={category}]: {e}")

    # 2. Cache miss: compute on demand and populate the cache.
    logger.info(f"Redis 中不存在用户偏好向量,实时计算 [user={user_id}, category={category}]")
    preference_vector = compute_user_preference_vector(user_id, category)

    if preference_vector is not None:
        Redis.write(
            key=key,
            value=json.dumps(preference_vector.tolist()),
            expire=RECOMMENDATION_CONFIG["redis_expire_seconds"]
        )

    return preference_vector
|
||||
|
||||
|
||||
def explore_branch(category: str, style: Optional[str] = None) -> List[str]:
    """
    Exploration branch: recommend one random candidate.

    Fixes over the previous version: the docstring claimed dict results while
    the function returns path strings; the redundant function-local
    ``import random`` (already imported at module level) is removed; and the
    ``random.sample(..., 1)`` + slice dance is replaced by ``random.choice``.

    Args:
        category: category filter
        style: optional style filter

    Returns:
        A list with at most one MinIO path string (empty when no candidates).
    """
    # Query a small random pool, then pick one candidate from it.
    pool_size = 10
    candidates = query_random_candidates(category, style, limit=pool_size)

    if not candidates:
        logger.warning(f"探索分支:类别 {category} 没有候选数据")
        return []

    chosen = random.choice(candidates)
    return [chosen.get("path", "")]
|
||||
|
||||
|
||||
def exploit_branch(
    user_id: int,
    category: str,
    style: Optional[str] = None
) -> List[str]:
    """
    Exploitation branch: recommend via vector similarity.

    Fixes over the previous version: the docstring documented a nonexistent
    ``num_recommendations`` parameter; a dead second ``if not results`` check
    (results was already verified non-empty) is removed; and the duplicated
    style/no-style scoring loops are merged into one.

    Args:
        user_id: user id
        category: category filter
        style: optional style; same-style candidates get a score bonus

    Returns:
        A list with at most one MinIO path string; falls back to the
        exploration branch when the preference vector or search results
        are unavailable.
    """
    # 1. User preference vector (Redis cache with on-demand compute).
    embedding = get_user_preference_vector(user_id, category)
    if embedding is None:
        logger.warning(f"利用分支:无法获取用户偏好向量,回退到探索分支 [user={user_id}, category={category}]")
        return explore_branch(category, style)

    # 2. Milvus similarity search (inner product).
    topk = RECOMMENDATION_CONFIG["topk"]
    results = search_similar_vectors(embedding, category, topk)
    if not results:
        logger.warning(f"利用分支:向量检索无结果,回退到探索分支 [user={user_id}, category={category}]")
        return explore_branch(category, style)

    # 3. Optional style bonus: boost same-style candidates by (1 + style_bonus).
    style_bonus = RECOMMENDATION_CONFIG["style_bonus"]
    for result in results:
        score = result["score"]
        if style and result.get("style") == style:
            score = score * (1 + style_bonus)
        result["final_score"] = score

    # 4. Temperature softmax over final scores, then sample one candidate.
    scores = [r["final_score"] for r in results]
    probabilities = softmax_with_temperature(scores, RECOMMENDATION_CONFIG["softmax_temperature"])
    selected_index = int(np.random.choice(len(results), p=probabilities))

    # 5. Return the sampled candidate's path.
    return [results[selected_index].get("path", "")]
|
||||
|
||||
|
||||
def softmax_with_temperature(scores: List[float], temperature: float = 1.0) -> List[float]:
    """
    Convert raw scores into a probability distribution via temperature-scaled softmax.

    Args:
        scores: raw scores
        temperature: softmax temperature; lower values sharpen the distribution

    Returns:
        Probabilities summing to 1. Empty input yields an empty list; if every
        exponential underflows to zero, a uniform distribution is returned.
    """
    if not scores:
        return []

    # Scale by temperature, then subtract the max for numerical stability.
    scaled = [value / temperature for value in scores]
    peak = max(scaled)
    exps = [math.exp(value - peak) for value in scaled]

    total = sum(exps)
    if total == 0:
        # Degenerate case: fall back to a uniform distribution.
        uniform = 1.0 / len(scores)
        return [uniform] * len(scores)

    return [value / total for value in exps]
|
||||
|
||||
|
||||
def get_recommendations(
    user_id: int,
    category: str,
    style: Optional[str] = None
) -> List[str]:
    """
    Main recommendation entry point.

    Flips a biased coin (``explore_ratio``) between the exploration branch
    (random candidate) and the exploitation branch (vector similarity).
    Any unexpected failure falls back to exploration.

    Args:
        user_id: user id
        category: category key (e.g. female_skirt)
        style: optional style; when given, same-style candidates are boosted
            in the exploitation branch

    Returns:
        A list of recommended MinIO path strings.
    """
    try:
        # Exploration/exploitation decision.
        use_explore = random.random() < RECOMMENDATION_CONFIG["explore_ratio"]

        if use_explore:
            logger.debug(f"探索分支 [user={user_id}, category={category}]")
            return explore_branch(category, style)

        logger.debug(f"利用分支 [user={user_id}, category={category}]")
        return exploit_branch(user_id, category, style)

    except Exception as e:
        logger.error(f"获取推荐结果失败 [user={user_id}, category={category}]: {e}", exc_info=True)
        # Fault tolerance: degrade to the exploration branch.
        return explore_branch(category, style)
|
||||
|
||||
189
app/service/recommendation_system/vector_utils.py
Normal file
189
app/service/recommendation_system/vector_utils.py
Normal file
@@ -0,0 +1,189 @@
|
||||
"""
|
||||
向量计算工具类
|
||||
包含 ResNet50 特征提取、向量归一化等功能
|
||||
"""
|
||||
import io
|
||||
import logging
|
||||
import numpy as np
|
||||
import torch
|
||||
from torchvision import models, transforms
|
||||
from PIL import Image
|
||||
from minio import Minio
|
||||
|
||||
from app.core.config import MINIO_URL, MINIO_ACCESS, MINIO_SECRET, MINIO_SECURE
|
||||
from app.service.recommendation_system.config import RECOMMENDATION_CONFIG
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Image preprocessing pipeline (matches the preprocessing ResNet was trained with).
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # ResNet expects 224x224 input
    transforms.ToTensor(),  # convert the PIL image to a CHW float tensor
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # ImageNet normalization
])

# Pretrained ResNet50 with the final fully-connected layer removed.
# Lazily initialized by get_resnet_model(); stays None until first use.
_resnet_model = None
|
||||
|
||||
def get_resnet_model():
    """Return the shared ResNet50 feature extractor, loading it on first use.

    The classification head (final fully-connected layer) is stripped so the
    network outputs the pooled 2048-dim feature instead of class logits. The
    model is cached in a module-level singleton and kept in eval mode.
    """
    global _resnet_model
    if _resnet_model is not None:
        return _resnet_model

    logger.info("加载 ResNet50 模型...")
    backbone = models.resnet50(pretrained=True)
    # Keep every layer up to (and including) the global average pool,
    # dropping only the final fully-connected classifier.
    feature_layers = list(backbone.children())[:-1]
    _resnet_model = torch.nn.Sequential(*feature_layers)
    _resnet_model.eval()  # inference mode: freezes dropout / batch-norm statistics
    logger.info("ResNet50 模型加载完成")
    return _resnet_model
||||
|
||||
|
||||
# MinIO client singleton; lazily created by get_minio_client().
_minio_client = None
||||
|
||||
|
||||
def get_minio_client():
    """Return the shared MinIO client, creating it on the first call.

    Connection settings come from app.core.config; the client instance is
    cached in a module-level singleton and reused by all callers.
    """
    global _minio_client
    if _minio_client is not None:
        return _minio_client

    _minio_client = Minio(
        MINIO_URL,
        access_key=MINIO_ACCESS,
        secret_key=MINIO_SECRET,
        secure=MINIO_SECURE,
    )
    return _minio_client
||||
|
||||
|
||||
def get_image_from_minio(path: str) -> "Image.Image | None":
    """
    Fetch an image from MinIO.

    Args:
        path: MinIO logical URL in the form "bucket_name/object_name".

    Returns:
        PIL Image object, or None on any failure (malformed path, network
        error, undecodable image data). Return annotation fixed: the
        original declared `Image.Image` but returned None on failure.
    """
    try:
        # Split into bucket name and object key on the first '/' only,
        # so object keys may themselves contain slashes.
        path_parts = path.split('/', 1)
        if len(path_parts) != 2:
            logger.error(f"路径格式错误: {path}")
            return None

        bucket_name, file_name = path_parts
        minio_client = get_minio_client()

        # Download the object and decode it in memory.
        obj = minio_client.get_object(bucket_name, file_name)
        try:
            img_data = obj.read()  # raw image bytes
        finally:
            # Release the HTTP connection back to the pool. The original
            # code skipped this, leaking a pooled connection per call
            # (MinIO SDK requires close() + release_conn()).
            obj.close()
            obj.release_conn()

        img = Image.open(io.BytesIO(img_data))  # decode bytes into a PIL image

        return img
    except Exception as e:
        logger.error(f"从 MinIO 获取图片失败 [{path}]: {e}")
        return None
||||
|
||||
|
||||
def extract_feature_vector(path: str) -> np.ndarray:
    """
    Extract an image feature vector (2048-dim) with ResNet50.

    Args:
        path: MinIO logical URL of the image.

    Returns:
        numpy float32 array of RECOMMENDATION_CONFIG["vector_dim"] entries;
        a zero vector on any failure (missing image, conversion error,
        model error).
    """
    try:
        target_dim = RECOMMENDATION_CONFIG["vector_dim"]

        # Download the image; bail out with a zero vector if unavailable.
        image = get_image_from_minio(path)
        if image is None:
            logger.warning(f"无法获取图片,返回零向量: {path}")
            return np.zeros(target_dim, dtype=np.float32)

        # Some stored images are RGBA/CMYK; force 3-channel RGB so the
        # 3-channel normalization constants in `transform` apply.
        if image.mode != "RGB":
            try:
                image = image.convert("RGB")
            except Exception:
                logger.warning(f"无法转换图片为RGB,返回零向量: {path}")
                return np.zeros(target_dim, dtype=np.float32)

        batch = transform(image).unsqueeze(0)  # add a batch dimension for the model

        # Run the frozen backbone without tracking gradients, then drop the
        # batch dimension and move the result to a NumPy array.
        model = get_resnet_model()
        with torch.no_grad():
            features = model(batch).squeeze().cpu().numpy()

        if features.ndim > 1:
            features = features.flatten()

        # Defensive dimension check: truncate or zero-pad if the model
        # output does not match the configured vector size.
        if len(features) != target_dim:
            logger.warning(f"向量维度不正确: {len(features)}, 期望: {target_dim}")
            if len(features) > target_dim:
                features = features[:target_dim]
            else:
                padded = np.zeros(target_dim, dtype=np.float32)
                padded[:len(features)] = features
                features = padded

        return features.astype(np.float32)
    except Exception as e:
        logger.error(f"提取特征向量失败 [{path}]: {e}", exc_info=True)
        return np.zeros(RECOMMENDATION_CONFIG["vector_dim"], dtype=np.float32)
||||
|
||||
|
||||
def normalize_vector(vector: np.ndarray) -> np.ndarray:
    """
    L2-normalize a vector.

    Args:
        vector: input vector.

    Returns:
        The vector scaled to unit L2 norm; an all-zero vector is returned
        unchanged to avoid division by zero.
    """
    magnitude = np.linalg.norm(vector)
    return vector if magnitude == 0 else vector / magnitude
||||
|
||||
|
||||
def compute_weighted_average(vectors: list, weights: list) -> np.ndarray:
    """
    Compute the weighted average of a list of vectors.

    Args:
        vectors: list of equal-length vectors (anything np.asarray accepts).
        weights: list of per-vector weights; if the lists differ in length,
            extra entries on either side are ignored (zip semantics).

    Returns:
        Weighted average vector (weighted sum divided by the weight total),
        as float64. Not L2-normalized, so the magnitude stays at the
        averaged scale rather than growing with the number of vectors.
        Returns a zero vector for empty input, and the raw (zero) weighted
        sum when the weights sum to zero, avoiding division by zero.
    """
    if not vectors or not weights:
        return np.zeros(RECOMMENDATION_CONFIG["vector_dim"], dtype=np.float32)

    # Accumulate in float64 regardless of the input dtype. The original
    # used np.zeros_like(vectors[0]) which inherits the first vector's
    # dtype: with integer vectors and float weights, `+=` then fails with
    # a same-kind casting error instead of computing the average.
    vectors = [np.asarray(v, dtype=np.float64) for v in vectors]
    weights = np.asarray(weights, dtype=np.float64)

    # Weighted sum of the paired vectors.
    weighted_sum = np.zeros_like(vectors[0])
    for vec, w in zip(vectors, weights):
        weighted_sum += vec * w

    # Divide by the total weight (no L2 normalization) so the magnitude
    # does not scale linearly with the number of vectors.
    weight_total = weights.sum()
    if weight_total == 0:
        return weighted_sum
    return weighted_sum / weight_total
||||
|
||||
Reference in New Issue
Block a user