feat(新功能): sketch 推荐算法
fix(修复bug): docs(文档变更): refactor(重构): test(增加测试):
This commit is contained in:
118
app/api/api_recommendation.py
Normal file
118
app/api/api_recommendation.py
Normal file
@@ -0,0 +1,118 @@
|
||||
import io
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
from apscheduler.schedulers.background import BackgroundScheduler
|
||||
from apscheduler.triggers.cron import CronTrigger
|
||||
from fastapi import HTTPException, APIRouter
|
||||
|
||||
from app.service.recommend.service import load_resources, matrix_data
|
||||
|
||||
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')
|
||||
logger = logging.getLogger()
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.on_event("startup")
|
||||
async def startup_event():
|
||||
# 初始加载
|
||||
load_resources()
|
||||
|
||||
# 配置定时任务
|
||||
scheduler = BackgroundScheduler()
|
||||
scheduler.add_job(
|
||||
load_resources,
|
||||
trigger=CronTrigger(hour=0, minute=30),
|
||||
name="每日资源刷新"
|
||||
)
|
||||
scheduler.start()
|
||||
logger.info("定时任务已启动")
|
||||
|
||||
|
||||
@router.get("/recommend/{user_id}/{category}/{num_recommendations}", response_model=List[str])
|
||||
async def get_recommendations(user_id: int, category: str, num_recommendations: int = 10):
|
||||
"""
|
||||
:param user_id: 4
|
||||
:param category: female_skirt
|
||||
:param num_recommendations: 1
|
||||
:return:
|
||||
[
|
||||
"aida-sys-image/images/female/skirt/903000017.jpg"
|
||||
]
|
||||
"""
|
||||
try:
|
||||
start_time = time.time()
|
||||
cache_key = (user_id, category)
|
||||
|
||||
# 检查缓存
|
||||
if cache_key in matrix_data["cached_scores"]:
|
||||
processed_inter, processed_feat = matrix_data["cached_scores"][cache_key]
|
||||
valid_sketch_idxs_inter = matrix_data["cached_valid_idxs"][cache_key]
|
||||
else:
|
||||
# 实时计算逻辑(同原代码)
|
||||
user_idx_inter = matrix_data["user_index_interaction"].get(user_id)
|
||||
user_idx_feature = matrix_data["user_index_feature"].get(user_id)
|
||||
|
||||
category_iids = matrix_data["category_to_iids"].get(category, [])
|
||||
valid_sketch_idxs_inter = [
|
||||
idx for iid, idx in matrix_data["sketch_index_interaction"].items()
|
||||
if iid in category_iids
|
||||
]
|
||||
|
||||
# 处理交互分数
|
||||
raw_inter_scores = []
|
||||
if user_idx_inter is not None and valid_sketch_idxs_inter:
|
||||
raw_inter_scores = matrix_data["interaction_matrix"][user_idx_inter, valid_sketch_idxs_inter]
|
||||
processed_inter = raw_inter_scores * 0.7
|
||||
|
||||
# 处理特征分数
|
||||
valid_sketch_idxs_feature = [
|
||||
idx for iid, idx in matrix_data["sketch_index_feature"].items()
|
||||
if iid in category_iids
|
||||
]
|
||||
raw_feat_scores = []
|
||||
if user_idx_feature is not None and valid_sketch_idxs_feature:
|
||||
raw_feat_scores = matrix_data["feature_matrix"][user_idx_feature, valid_sketch_idxs_feature]
|
||||
raw_feat_scores = (raw_feat_scores - np.min(raw_feat_scores)) / (
|
||||
np.max(raw_feat_scores) - np.min(raw_feat_scores) + 1e-8)
|
||||
processed_feat = raw_feat_scores * 0.3
|
||||
else:
|
||||
processed_feat = np.array([])
|
||||
|
||||
# 更新缓存
|
||||
matrix_data["cached_scores"][cache_key] = (processed_inter, processed_feat)
|
||||
matrix_data["cached_valid_idxs"][cache_key] = valid_sketch_idxs_inter
|
||||
|
||||
# 合并分数
|
||||
final_scores = processed_inter + processed_feat
|
||||
valid_sketch_idxs = matrix_data["cached_valid_idxs"][cache_key]
|
||||
|
||||
# 概率采样
|
||||
scores = np.array(final_scores)
|
||||
|
||||
# 调整后的概率转换(带温度控制的softmax)
|
||||
def calibrated_softmax(scores, temperature=1.0):
|
||||
scores = scores / temperature
|
||||
scale = scores - max(scores)
|
||||
exps = np.exp(scale)
|
||||
return exps / np.sum(exps)
|
||||
|
||||
probs = calibrated_softmax(scores, 0.07)
|
||||
|
||||
chosen_indices = np.random.choice(
|
||||
len(valid_sketch_idxs),
|
||||
size=min(num_recommendations, len(valid_sketch_idxs)),
|
||||
p=probs,
|
||||
replace=False
|
||||
)
|
||||
recommendations = [matrix_data["iid_to_sketch"][valid_sketch_idxs[idx]] for idx in chosen_indices]
|
||||
|
||||
logger.info(f"推荐生成完成,耗时: {time.time() - start_time:.2f}秒")
|
||||
return recommendations
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"推荐失败: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
@@ -10,6 +10,7 @@ from app.api import api_generate_image
|
||||
from app.api import api_image2sketch
|
||||
from app.api import api_prompt_generation
|
||||
from app.api import api_super_resolution
|
||||
from app.api import api_recommendation
|
||||
from app.api import api_test
|
||||
|
||||
router = APIRouter()
|
||||
@@ -26,3 +27,4 @@ router.include_router(api_image2sketch.router, tags=['api_image2sketch'], prefix
|
||||
router.include_router(api_brighten.router, tags=['api_brighten'], prefix="/api")
|
||||
router.include_router(api_query_image.router, tags=['api_query_image'], prefix="/api")
|
||||
router.include_router(api_brand_dna.router, tags=['api_brand_dna'], prefix="/api")
|
||||
router.include_router(api_recommendation.router, tags=['api_recommendation'], prefix="/api")
|
||||
|
||||
@@ -25,10 +25,13 @@ if DEBUG:
|
||||
LOGS_PATH = "logs/"
|
||||
CATEGORY_PATH = "service/attribute/config/descriptor/category/category_dis.csv"
|
||||
SEG_CACHE_PATH = "../seg_cache/"
|
||||
RECOMMEND_PATH_PREFIX = "service/recommend/"
|
||||
else:
|
||||
LOGS_PATH = "app/logs/"
|
||||
CATEGORY_PATH = "app/service/attribute/config/descriptor/category/category_dis.csv"
|
||||
SEG_CACHE_PATH = "/seg_cache/"
|
||||
RECOMMEND_PATH_PREFIX = "app/service/recommend/"
|
||||
|
||||
|
||||
# RABBITMQ_ENV = "" # 生产环境
|
||||
RABBITMQ_ENV = "-dev" # 开发环境
|
||||
@@ -36,7 +39,6 @@ RABBITMQ_ENV = "-dev" # 开发环境
|
||||
|
||||
JAVA_STREAM_API_URL = os.getenv("JAVA_STREAM_API_URL", "https://api.aida.com.hk/api/third/party/receiveDesignResults")
|
||||
|
||||
|
||||
settings = Settings()
|
||||
|
||||
# minio 配置
|
||||
@@ -114,7 +116,6 @@ GMV_MODEL_NAME = 'multi_view'
|
||||
|
||||
GMV_RABBITMQ_QUEUES = os.getenv("GMV_RABBITMQ_QUEUES", f"GenerateMultiView{RABBITMQ_ENV}")
|
||||
|
||||
|
||||
GI_MINIO_BUCKET = "aida-users"
|
||||
GI_RABBITMQ_QUEUES = os.getenv("GI_RABBITMQ_QUEUES", f"GenerateImage{RABBITMQ_ENV}")
|
||||
GI_SYS_IMAGE_URL = "aida-sys-image/generate_image/white_image.jpg"
|
||||
@@ -191,3 +192,23 @@ PRIORITY_DICT = {
|
||||
}
|
||||
|
||||
QWEN_API_KEY = "sk-a6bdf594e1f54a4aa3e9d4d48f8c661f"
|
||||
|
||||
DB_CONFIG = {
|
||||
"host": "18.167.251.121",
|
||||
"port": 3306,
|
||||
"user": "root",
|
||||
"password": "QWa998345",
|
||||
"database": "aida",
|
||||
"charset": "utf8mb4"
|
||||
}
|
||||
|
||||
TABLE_CATEGORIES = {
|
||||
"female_dress": "female/dress",
|
||||
"female_outwear": "female/outwear",
|
||||
"female_trousers": "female/trousers",
|
||||
"female_skirt": "female/skirt",
|
||||
"female_blouse": "female/blouse",
|
||||
"male_tops": "male/tops",
|
||||
"male_bottoms": "male/bottoms",
|
||||
"male_outwear": "male/outwear"
|
||||
}
|
||||
|
||||
12
app/main.py
12
app/main.py
@@ -1,15 +1,17 @@
|
||||
import logging.config
|
||||
from http.client import HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
|
||||
import uvicorn
|
||||
from apscheduler.schedulers.background import BackgroundScheduler
|
||||
from apscheduler.triggers.cron import CronTrigger
|
||||
from fastapi import FastAPI
|
||||
from fastapi import HTTPException, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from app.api.api_route import router
|
||||
from app.core.config import settings
|
||||
from app.core.record_api_count import count_api_calls
|
||||
from app.schemas.response_template import ResponseModel
|
||||
from app.service.recommend.service import load_resources
|
||||
from logging_env import LOGGER_CONFIG_DICT
|
||||
|
||||
logging.config.dictConfig(LOGGER_CONFIG_DICT)
|
||||
@@ -17,6 +19,8 @@ logging.getLogger("pika").setLevel(logging.WARNING)
|
||||
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_application() -> FastAPI:
|
||||
application = FastAPI(
|
||||
@@ -51,5 +55,7 @@ async def http_exception_handler(request: Request, exc: HTTPException):
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||
|
||||
431
app/service/recommend/scheduled_task.py
Normal file
431
app/service/recommend/scheduled_task.py
Normal file
@@ -0,0 +1,431 @@
|
||||
import pymysql
|
||||
import numpy as np
|
||||
from apscheduler.schedulers.blocking import BlockingScheduler
|
||||
import os
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
import torch
|
||||
from torchvision import models, transforms
|
||||
from minio import Minio
|
||||
from PIL import Image
|
||||
import io
|
||||
import seaborn as sns
|
||||
import matplotlib.pyplot as plt
|
||||
from scipy.sparse import csr_matrix
|
||||
import matplotlib.font_manager as fm
|
||||
from scipy import sparse
|
||||
|
||||
from app.core.config import DB_CONFIG, TABLE_CATEGORIES, RECOMMEND_PATH_PREFIX
|
||||
|
||||
# 自动选择可用字体
|
||||
try:
|
||||
# 尝试加载常见中文字体
|
||||
font_path = fm.findfont(fm.FontProperties(family=['Microsoft YaHei', 'SimHei', 'WenQuanYi Zen Hei']))
|
||||
plt.rcParams['font.sans-serif'] = fm.FontProperties(fname=font_path).get_name()
|
||||
except:
|
||||
# 回退到英文字体
|
||||
plt.rcParams['font.sans-serif'] = ['DejaVu Sans']
|
||||
plt.rcParams['axes.unicode_minus'] = False
|
||||
|
||||
# 检查系统中可用的字体并选择支持中文的字体
|
||||
font_path = fm.findfont(fm.FontProperties(family='Microsoft YaHei')) # 或其他支持中文的字体
|
||||
plt.rcParams['font.sans-serif'] = ['Microsoft YaHei'] # 设置为 Microsoft YaHei
|
||||
plt.rcParams['axes.unicode_minus'] = False # 解决负号显示问题
|
||||
|
||||
# 配置日志记录
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(levelname)s - %(message)s',
|
||||
filename='scheduler.log'
|
||||
)
|
||||
|
||||
# MinIO 配置信息
|
||||
minio_client = Minio(
|
||||
"www.minio.aida.com.hk:12024", # MinIO Endpoint
|
||||
access_key="admin", # Access Key
|
||||
secret_key="Aidlab123123!", # Secret Key
|
||||
secure=True # 使用https
|
||||
)
|
||||
|
||||
# 预加载系统sketch特征向量
|
||||
SYSTEM_FEATURES = np.load(f'{RECOMMEND_PATH_PREFIX}sketch_feature_dict.npy', allow_pickle=True).item()
|
||||
|
||||
|
||||
# 保存sketch_to_iid到文件
|
||||
def save_sketch_to_iid():
|
||||
"""保存sketch到iid的映射"""
|
||||
sketch_to_iid = {sketch_path: iid for iid, sketch_path in enumerate(SYSTEM_FEATURES.keys(), start=1)}
|
||||
np.save('sketch_to_iid.npy', sketch_to_iid)
|
||||
print("sketch_to_iid 已保存")
|
||||
|
||||
|
||||
# 从文件加载sketch_to_iid
|
||||
def load_sketch_to_iid():
|
||||
"""加载保存的sketch到iid的映射"""
|
||||
if os.path.exists('sketch_to_iid.npy'):
|
||||
sketch_to_iid = np.load('sketch_to_iid.npy', allow_pickle=True).item()
|
||||
print("sketch_to_iid 已加载")
|
||||
return sketch_to_iid
|
||||
else:
|
||||
# 如果文件不存在,则生成并保存
|
||||
print("sketch_to_iid 文件不存在,正在生成并保存...")
|
||||
save_sketch_to_iid()
|
||||
return np.load('sketch_to_iid.npy', allow_pickle=True).item()
|
||||
|
||||
|
||||
# 使用load_sketch_to_iid来获取映射
|
||||
sketch_to_iid = load_sketch_to_iid()
|
||||
|
||||
# 在代码中其他地方使用sketch_to_iid
|
||||
# print(f"Total sketches: {len(sketch_to_iid)}")
|
||||
|
||||
# 定义图像预处理(与ResNet训练时的预处理一致)
|
||||
transform = transforms.Compose([
|
||||
transforms.Resize((224, 224)), # ResNet 要求 224x224 的输入
|
||||
transforms.ToTensor(), # 转换为 Tensor
|
||||
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), # 标准化
|
||||
])
|
||||
|
||||
# 加载预训练的 ResNet 模型 (ResNet50)
|
||||
resnet_model = models.resnet50(pretrained=True)
|
||||
modules = list(resnet_model.children())[:-1] # 移除最后的全连接层
|
||||
resnet_model = torch.nn.Sequential(*modules)
|
||||
resnet_model.eval() # 设置为评估模式
|
||||
|
||||
|
||||
# 从 MinIO 获取图片并进行预处理
|
||||
def get_sketch_image_from_minio(sketch_path):
|
||||
"""
|
||||
从 MinIO 获取 sketch 图像并预处理
|
||||
"""
|
||||
# 分割路径,获取桶名和文件路径
|
||||
path_parts = sketch_path.split('/', 1) # 根据第一个斜杠分割,得到桶名和路径
|
||||
bucket_name = path_parts[0] # 桶名
|
||||
file_name = path_parts[1] # 文件路径(从第二部分开始)
|
||||
|
||||
try:
|
||||
# 获取文件
|
||||
obj = minio_client.get_object(bucket_name, file_name)
|
||||
img_data = obj.read() # 读取图像数据
|
||||
img = Image.open(io.BytesIO(img_data)) # 将数据转为图像对象
|
||||
img = transform(img) # 对图像进行预处理
|
||||
return img.unsqueeze(0) # 扩展维度以适应批量处理
|
||||
except Exception as e:
|
||||
print(f"Error fetching image for {sketch_path}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def extract_feature_vector_from_resnet(sketch_path):
|
||||
"""
|
||||
提取 sketch 图像的特征向量
|
||||
"""
|
||||
# 从 MinIO 获取图像并预处理
|
||||
img_tensor = get_sketch_image_from_minio(sketch_path)
|
||||
if img_tensor is None:
|
||||
return np.zeros(2048) # 如果获取失败,返回零向量
|
||||
|
||||
with torch.no_grad(): # 在不需要计算梯度的情况下进行推断
|
||||
feature_vector = resnet_model(img_tensor) # 获取 ResNet 的输出
|
||||
return feature_vector.squeeze().cpu().numpy() # 转换为 NumPy 数组并去掉 batch 维度
|
||||
|
||||
|
||||
def update_user_matrices():
|
||||
"""每天更新用户交互次数矩阵和特征向量矩阵"""
|
||||
conn = None
|
||||
try:
|
||||
conn = pymysql.connect(**DB_CONFIG)
|
||||
cursor = conn.cursor()
|
||||
|
||||
# 修改后的查询语句(移除category过滤)
|
||||
cursor.execute("""
|
||||
SELECT account_id, path, COUNT(*) as like_count
|
||||
FROM user_preference_log_test
|
||||
GROUP BY account_id, path
|
||||
""")
|
||||
user_data = cursor.fetchall()
|
||||
logging.info(f"成功读取{len(user_data)}条用户偏好记录")
|
||||
|
||||
# 计算矩阵
|
||||
interaction_matrix, raw_counts_sparse, user_index_interaction_matrix, sketch_index_interaction_matrix, iid_to_category_interaction_matrix = calculate_interaction_matrix(user_data)
|
||||
# visualize_sparse_matrix(raw_counts_sparse,'交互次数矩阵', 'interaction_frequency_matrix.png')
|
||||
# visualize_sparse_matrix(interaction_matrix, '交互次数得分矩阵', 'interaction_score_matrix.png')
|
||||
# plot_interaction_count_matrix(raw_counts_sparse)
|
||||
# feature_matrix, iid_to_category_feature_matrix, user_index_feature_matrix, sketch_index_feature_matrix = calculate_feature_matrix(user_data)
|
||||
feature_matrix, user_index_feature_matrix, sketch_index_feature_matrix, iid_to_category_feature_matrix = calculate_feature_matrix(user_data)
|
||||
# visualize_sparse_matrix(feature_matrix, '系统sketch与用户category平均特征向量关联度矩阵', 'correlation_matrix.png')
|
||||
# 存储矩阵
|
||||
np.save(f"{RECOMMEND_PATH_PREFIX}interaction_matrix.npy", interaction_matrix)
|
||||
np.save(f"{RECOMMEND_PATH_PREFIX}feature_matrix.npy", feature_matrix)
|
||||
#
|
||||
np.save(f"{RECOMMEND_PATH_PREFIX}iid_to_category_interaction_matrix.npy", iid_to_category_interaction_matrix)
|
||||
np.save(f"{RECOMMEND_PATH_PREFIX}user_index_interaction_matrix.npy", user_index_interaction_matrix)
|
||||
#
|
||||
np.save(f"{RECOMMEND_PATH_PREFIX}iid_to_category_feature_matrix.npy", iid_to_category_feature_matrix)
|
||||
np.save(f"{RECOMMEND_PATH_PREFIX}user_index_feature_matrix.npy", user_index_feature_matrix)
|
||||
#
|
||||
np.save(f"{RECOMMEND_PATH_PREFIX}sketch_index_interaction_matrix.npy", sketch_index_interaction_matrix)
|
||||
np.save(f"{RECOMMEND_PATH_PREFIX}sketch_index_feature_matrix.npy", sketch_index_feature_matrix)
|
||||
# logging.info("矩阵更新完成")
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"定时任务执行失败: {str(e)}", exc_info=True)
|
||||
finally:
|
||||
if conn:
|
||||
conn.close()
|
||||
|
||||
|
||||
def plot_interaction_count_matrix(interaction_count_matrix):
|
||||
"""绘制交互次数矩阵的分布图(热图),不隐藏零值"""
|
||||
try:
|
||||
if not isinstance(interaction_count_matrix, csr_matrix):
|
||||
sparse_matrix = csr_matrix(interaction_count_matrix)
|
||||
else:
|
||||
sparse_matrix = interaction_count_matrix
|
||||
|
||||
# 转换为密集矩阵
|
||||
try:
|
||||
dense_matrix = sparse_matrix.toarray()
|
||||
except MemoryError:
|
||||
logging.error("内存不足,无法转换为密集矩阵")
|
||||
return
|
||||
|
||||
# 自动检测可用中文字体
|
||||
try:
|
||||
font_path = fm.findfont(fm.FontProperties(family='sans-serif', style='normal'))
|
||||
plt.rcParams['font.sans-serif'] = fm.FontProperties(fname=font_path).get_name()
|
||||
except:
|
||||
plt.rcParams['font.sans-serif'] = ['DejaVu Sans'] # 回退到英文字体
|
||||
plt.rcParams['axes.unicode_minus'] = False
|
||||
|
||||
# 处理大矩阵的显示,限制显示范围
|
||||
if dense_matrix.shape[0] > 1000 or dense_matrix.shape[1] > 1000:
|
||||
dense_matrix = dense_matrix[:1000, :1000] # 只绘制前1000行列
|
||||
|
||||
plt.figure(figsize=(15, 10))
|
||||
|
||||
# 使用 `cmap` 来设置颜色,零值可以使用特定颜色,调整 `vmin` 和 `vmax` 让热图更具对比
|
||||
sns.heatmap(
|
||||
dense_matrix,
|
||||
cmap="Blues", # 可以选择不同的颜色映射,"Blues" 或 "YlGnBu"
|
||||
annot=False, # 关闭标注
|
||||
cbar_kws={"label": "Interaction Count"}, # 添加颜色条标签
|
||||
linewidths=0.5,
|
||||
vmin=0, # 设置最小值,确保零值明显
|
||||
vmax=np.max(dense_matrix) # 设置最大值,保持颜色映射的合理性
|
||||
)
|
||||
|
||||
plt.title('User-Sketch Interaction Matrix (With Zero Entries)')
|
||||
plt.xlabel('Sketch Index')
|
||||
plt.ylabel('User Index')
|
||||
plt.savefig('interaction_heatmap_with_zeros.png', dpi=150, bbox_inches='tight')
|
||||
plt.close()
|
||||
|
||||
logging.info("热图已保存为 interaction_heatmap_with_zeros.png")
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"绘图失败: {str(e)}", exc_info=True)
|
||||
|
||||
def visualize_sparse_matrix(matrix, title='Non-zero Interactions (Scatter Plot)', filename="scatter_figure_interaction.png"):
|
||||
if not sparse.issparse(matrix):
|
||||
# 转换为稀疏矩阵
|
||||
matrix = sparse.csr_matrix(matrix)
|
||||
|
||||
# 获取非零元素的坐标和值
|
||||
rows, cols = matrix.nonzero()
|
||||
values = matrix.data
|
||||
|
||||
# 绘制散点图
|
||||
plt.figure(figsize=(24, 20))
|
||||
plt.scatter(cols, rows, c=values, cmap='coolwarm', alpha=0.7, s=1)
|
||||
plt.colorbar(label='Interaction Count')
|
||||
plt.title(title)
|
||||
plt.xlabel('Item Index')
|
||||
plt.ylabel('Item Index')
|
||||
plt.savefig(filename)
|
||||
|
||||
def calculate_interaction_matrix(user_data):
|
||||
"""基于新表结构的交互次数矩阵计算(仅系统sketch)"""
|
||||
# 获取所有用户ID
|
||||
all_users = set()
|
||||
for account_id, path, like_count in user_data:
|
||||
category = get_category_from_path(path)
|
||||
if category not in TABLE_CATEGORIES.keys():
|
||||
continue
|
||||
all_users.add(account_id)
|
||||
|
||||
# 获取所有系统sketch的iid
|
||||
system_sketch_iids = {sketch_to_iid[path] for path in SYSTEM_FEATURES.keys() if path in sketch_to_iid}
|
||||
|
||||
# 创建映射关系
|
||||
user_index = {uid: idx for idx, uid in enumerate(sorted(all_users))}
|
||||
sketch_index = {iid: idx for idx, iid in enumerate(sorted(system_sketch_iids))}
|
||||
|
||||
# 初始化双矩阵:归一化矩阵(密集)和原始计数矩阵(稀疏)
|
||||
interaction_matrix = np.zeros((len(all_users), len(sketch_index))) # 归一化矩阵
|
||||
data, rows, cols = [], [], [] # 用于构建稀疏矩阵的COO格式数据
|
||||
|
||||
# 预计算用户最大交互次数
|
||||
user_max_likes = defaultdict(int)
|
||||
for account_id, path, like_count in user_data:
|
||||
if sketch_to_iid.get(path) in system_sketch_iids:
|
||||
user_max_likes[account_id] = max(user_max_likes[account_id], like_count)
|
||||
|
||||
# 填充矩阵
|
||||
for account_id, path, like_count in user_data:
|
||||
sketch_iid = sketch_to_iid.get(path)
|
||||
if sketch_iid not in system_sketch_iids:
|
||||
continue
|
||||
|
||||
user_idx = user_index[account_id]
|
||||
sketch_idx = sketch_index[sketch_iid]
|
||||
|
||||
# 填充稀疏矩阵数据
|
||||
data.append(like_count)
|
||||
rows.append(user_idx)
|
||||
cols.append(sketch_idx)
|
||||
|
||||
# 归一化计算
|
||||
max_like = user_max_likes.get(account_id, 1)
|
||||
interaction_matrix[user_idx, sketch_idx] = np.log1p(1 + like_count) / np.log1p(1 + max_like)
|
||||
|
||||
# 构建稀疏矩阵(CSR格式适合快速行操作)
|
||||
interaction_count_matrix = csr_matrix((data, (rows, cols)), shape=(len(all_users), len(sketch_index)))
|
||||
|
||||
return interaction_matrix, interaction_count_matrix, user_index, sketch_index, {iid: get_category_from_path(path) for path, iid in sketch_to_iid.items()}
|
||||
|
||||
|
||||
def calculate_feature_matrix(user_data):
|
||||
"""基于新表结构的特征矩阵计算,返回用户与系统草图的相似度矩阵(加权平均)"""
|
||||
|
||||
# 用户特征数据存储结构:{(account_id, category): {sketch_iid: [(feature_vector, weight)]}}
|
||||
user_feature_weights = defaultdict(lambda: defaultdict(list))
|
||||
|
||||
# 初始化所有用户和系统草图集合
|
||||
all_users = set()
|
||||
all_system_sketches = set(SYSTEM_FEATURES.keys())
|
||||
|
||||
# ==== 第一遍遍历:收集特征向量和权重 ====
|
||||
for account_id, path, like_count in user_data:
|
||||
category = get_category_from_path(path)
|
||||
if category not in TABLE_CATEGORIES.keys():
|
||||
continue
|
||||
|
||||
sketch_iid = sketch_to_iid.get(path)
|
||||
if not sketch_iid:
|
||||
continue
|
||||
|
||||
# 记录用户
|
||||
all_users.add(account_id)
|
||||
|
||||
# 提取特征并记录权重(like_count)
|
||||
if path in SYSTEM_FEATURES: # 系统草图
|
||||
feature = SYSTEM_FEATURES[path]
|
||||
weight = like_count # 使用like_count作为权重
|
||||
user_feature_weights[(account_id, category)][sketch_iid].append((feature, weight))
|
||||
else: # 用户草图
|
||||
feature = extract_feature_vector_from_resnet(path)
|
||||
weight = like_count
|
||||
user_feature_weights[(account_id, category)][sketch_iid].append((feature, weight))
|
||||
|
||||
# ==== 第二遍遍历:收集所有系统草图iid ====
|
||||
system_sketch_iids = set()
|
||||
for sketch_path in SYSTEM_FEATURES:
|
||||
if iid := sketch_to_iid.get(sketch_path):
|
||||
system_sketch_iids.add(iid)
|
||||
|
||||
# ==== 创建索引映射 ====
|
||||
user_list = sorted(all_users)
|
||||
sketch_list = sorted(system_sketch_iids)
|
||||
|
||||
user_index = {uid: idx for idx, uid in enumerate(user_list)}
|
||||
sketch_index = {iid: idx for idx, iid in enumerate(sketch_list)}
|
||||
|
||||
# ==== 初始化特征矩阵 ====
|
||||
feature_matrix = np.zeros((len(user_list), len(sketch_list)))
|
||||
|
||||
# ==== 预计算加权平均特征向量 ====
|
||||
user_avg_features = {}
|
||||
for (account_id, category), sketches in user_feature_weights.items():
|
||||
try:
|
||||
# 展平所有特征向量和权重
|
||||
all_features_weights = [(vec, weight) for vec_list in sketches.values() for vec, weight in vec_list]
|
||||
|
||||
if len(all_features_weights) == 0:
|
||||
continue
|
||||
|
||||
# 计算总权重
|
||||
total_weight = sum(weight for _, weight in all_features_weights)
|
||||
if total_weight <= 0: # 防止除零错误
|
||||
total_weight = 1.0
|
||||
|
||||
# 加权平均计算
|
||||
weighted_sum = np.zeros_like(all_features_weights[0][0]) # 获取特征向量维度
|
||||
for vec, weight in all_features_weights:
|
||||
weighted_sum += vec * weight
|
||||
|
||||
avg_vec = weighted_sum / total_weight
|
||||
user_avg_features[(account_id, category)] = avg_vec
|
||||
|
||||
except Exception as e:
|
||||
logging.warning(f"用户({account_id},{category})加权特征计算失败: {str(e)}")
|
||||
continue
|
||||
|
||||
# ==== 计算相似度并填充矩阵 ====
|
||||
for sketch_path, sys_vector in SYSTEM_FEATURES.items():
|
||||
sketch_iid = sketch_to_iid.get(sketch_path)
|
||||
|
||||
system_sketch_category = get_category_from_path(sketch_path)
|
||||
if not sketch_iid or sketch_iid not in sketch_index:
|
||||
continue
|
||||
|
||||
sketch_col = sketch_index[sketch_iid]
|
||||
|
||||
# 遍历所有用户
|
||||
for account_id in all_users:
|
||||
user_row = user_index.get(account_id)
|
||||
if user_row is None:
|
||||
continue
|
||||
|
||||
# 获取用户加权平均特征向量
|
||||
try:
|
||||
# 直接通过复合键获取用户特征向量
|
||||
user_vec = user_avg_features[(account_id, system_sketch_category)]
|
||||
except KeyError:
|
||||
# 该用户在此类别下无特征数据
|
||||
continue
|
||||
|
||||
# 计算余弦相似度
|
||||
cos_sim = cosine_similarity(user_vec, sys_vector)
|
||||
feature_matrix[user_row, sketch_col] = cos_sim
|
||||
|
||||
return feature_matrix, user_index, sketch_index, {iid: get_category_from_path(path) for path, iid in sketch_to_iid.items()}
|
||||
|
||||
|
||||
def get_category_from_path(path):
|
||||
"""从path字段解析类别"""
|
||||
try:
|
||||
parts = path.split('/')
|
||||
if len(parts) >= 2:
|
||||
return f"{parts[2]}_{parts[3]}"
|
||||
return "unknown"
|
||||
except:
|
||||
return "unknown"
|
||||
|
||||
|
||||
def cosine_similarity(vec1, vec2):
|
||||
"""计算余弦相似度(增加零值处理)"""
|
||||
norm = np.linalg.norm(vec1) * np.linalg.norm(vec2)
|
||||
return np.dot(vec1, vec2) / (norm + 1e-10) if norm != 0 else 0.0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
update_user_matrices()
|
||||
# scheduler = BlockingScheduler()
|
||||
# scheduler.add_job(update_user_matrices, 'cron', hour=12, timezone='Asia/Shanghai')
|
||||
# logging.info("定时任务已启动,每天12:00执行")
|
||||
# scheduler.start()
|
||||
except KeyboardInterrupt:
|
||||
logging.info("定时任务已停止")
|
||||
except Exception as e:
|
||||
logging.error(f"调度器启动失败: {str(e)}", exc_info=True)
|
||||
172
app/service/recommend/service.py
Normal file
172
app/service/recommend/service.py
Normal file
@@ -0,0 +1,172 @@
|
||||
# 预加载资源
|
||||
import logging
|
||||
import time
|
||||
from collections import defaultdict
|
||||
|
||||
import numpy as np
|
||||
|
||||
from app.core.config import DB_CONFIG, RECOMMEND_PATH_PREFIX
|
||||
|
||||
logger = logging.getLogger()
|
||||
import pymysql
|
||||
|
||||
matrix_data = {
|
||||
"interaction_matrix": None,
|
||||
"feature_matrix": None,
|
||||
"user_index_interaction": None,
|
||||
"sketch_index_interaction": None,
|
||||
"user_index_feature": None,
|
||||
"sketch_index_feature": None,
|
||||
"iid_to_sketch": None,
|
||||
"category_to_iids": None,
|
||||
"cached_scores": {},
|
||||
"cached_valid_idxs": {},
|
||||
"category_sketch_idxs_inter": None,
|
||||
"category_sketch_idxs_feature": None,
|
||||
"user_inter_full": dict(),
|
||||
"user_feat_full": dict(),
|
||||
}
|
||||
|
||||
|
||||
def load_resources():
|
||||
"""加载所有矩阵和映射关系,并触发预缓存"""
|
||||
try:
|
||||
start_time = time.time()
|
||||
|
||||
# 清空缓存
|
||||
matrix_data["cached_scores"].clear()
|
||||
matrix_data["cached_valid_idxs"].clear()
|
||||
|
||||
# 加载数据
|
||||
sketch_to_iid = np.load(f'{RECOMMEND_PATH_PREFIX}sketch_to_iid.npy', allow_pickle=True).item()
|
||||
matrix_data["iid_to_sketch"] = {v: k for k, v in sketch_to_iid.items()}
|
||||
|
||||
matrix_data["interaction_matrix"] = np.load(f"{RECOMMEND_PATH_PREFIX}interaction_matrix.npy", allow_pickle=True)
|
||||
matrix_data["user_index_interaction"] = np.load(f"{RECOMMEND_PATH_PREFIX}user_index_interaction_matrix.npy", allow_pickle=True).item()
|
||||
matrix_data["sketch_index_interaction"] = np.load(f"{RECOMMEND_PATH_PREFIX}sketch_index_interaction_matrix.npy",
|
||||
allow_pickle=True).item()
|
||||
|
||||
matrix_data["feature_matrix"] = np.load(f"{RECOMMEND_PATH_PREFIX}feature_matrix.npy", allow_pickle=True)
|
||||
matrix_data["user_index_feature"] = np.load(f"{RECOMMEND_PATH_PREFIX}user_index_feature_matrix.npy", allow_pickle=True).item()
|
||||
matrix_data["sketch_index_feature"] = np.load(f"{RECOMMEND_PATH_PREFIX}sketch_index_feature_matrix.npy", allow_pickle=True).item()
|
||||
|
||||
category_to_iid_map = np.load(f"{RECOMMEND_PATH_PREFIX}iid_to_category_interaction_matrix.npy", allow_pickle=True).item()
|
||||
matrix_data["category_to_iids"] = defaultdict(list)
|
||||
for iid, cat in category_to_iid_map.items():
|
||||
matrix_data["category_to_iids"][cat].append(iid)
|
||||
|
||||
logger.info(f"资源加载完成,耗时: {time.time() - start_time:.2f}秒")
|
||||
|
||||
# 触发预缓存
|
||||
precache_user_category()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"资源加载失败: {str(e)}")
|
||||
raise RuntimeError("初始化失败")
|
||||
|
||||
|
||||
def precache_user_category():
|
||||
"""预缓存用户-分类组合数据"""
|
||||
if not all([
|
||||
matrix_data["interaction_matrix"] is not None,
|
||||
matrix_data["feature_matrix"] is not None,
|
||||
matrix_data["user_index_interaction"] is not None
|
||||
]):
|
||||
logger.warning("资源未加载完成,跳过预缓存")
|
||||
return
|
||||
|
||||
start_time = time.time()
|
||||
user_categories = get_all_user_categories()
|
||||
|
||||
precached_count = 0
|
||||
for user_id, categories in user_categories.items():
|
||||
for category in categories:
|
||||
cache_key = (user_id, category)
|
||||
if cache_key in matrix_data["cached_scores"]:
|
||||
continue
|
||||
|
||||
try:
|
||||
# 获取用户索引
|
||||
user_idx_inter = matrix_data["user_index_interaction"].get(user_id)
|
||||
user_idx_feature = matrix_data["user_index_feature"].get(user_id)
|
||||
|
||||
# 获取类别对应的iid列表
|
||||
category_iids = matrix_data["category_to_iids"].get(category, [])
|
||||
|
||||
# 过滤有效草图索引
|
||||
valid_sketch_idxs_inter = [
|
||||
idx for iid, idx in matrix_data["sketch_index_interaction"].items()
|
||||
if iid in category_iids
|
||||
]
|
||||
|
||||
# 处理交互分数
|
||||
if user_idx_inter is not None and valid_sketch_idxs_inter:
|
||||
raw_inter_scores = matrix_data["interaction_matrix"][user_idx_inter, valid_sketch_idxs_inter]
|
||||
processed_inter = raw_inter_scores * 0.7
|
||||
else:
|
||||
processed_inter = np.array([])
|
||||
|
||||
# 处理特征分数
|
||||
valid_sketch_idxs_feature = [
|
||||
idx for iid, idx in matrix_data["sketch_index_feature"].items()
|
||||
if iid in category_iids
|
||||
]
|
||||
|
||||
if user_idx_feature is not None and valid_sketch_idxs_feature:
|
||||
raw_feat_scores = matrix_data["feature_matrix"][user_idx_feature, valid_sketch_idxs_feature]
|
||||
raw_feat_scores = (raw_feat_scores - np.min(raw_feat_scores)) / (
|
||||
np.max(raw_feat_scores) - np.min(raw_feat_scores) + 1e-8)
|
||||
processed_feat = raw_feat_scores * 0.3
|
||||
else:
|
||||
processed_feat = np.array([])
|
||||
|
||||
# 缓存结果
|
||||
if len(processed_inter) == len(processed_feat):
|
||||
matrix_data["cached_scores"][cache_key] = (processed_inter, processed_feat)
|
||||
matrix_data["cached_valid_idxs"][cache_key] = valid_sketch_idxs_inter
|
||||
precached_count += 1
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"预缓存失败 (user={user_id}, category={category}): {str(e)}")
|
||||
|
||||
logger.info(f"预缓存完成,共缓存 {precached_count} 个组合,耗时: {time.time() - start_time:.2f}秒")
|
||||
|
||||
|
||||
def get_all_user_categories():
|
||||
"""获取所有用户及其对应的分类"""
|
||||
conn = None
|
||||
try:
|
||||
conn = pymysql.connect(**DB_CONFIG)
|
||||
cursor = conn.cursor()
|
||||
|
||||
query = """
|
||||
SELECT DISTINCT account_id, path
|
||||
FROM user_preference_log_prediction
|
||||
"""
|
||||
cursor.execute(query)
|
||||
results = cursor.fetchall()
|
||||
|
||||
user_categories = defaultdict(set)
|
||||
for account_id, path in results:
|
||||
category = get_category_from_path(path)
|
||||
user_categories[account_id].add(category)
|
||||
|
||||
return dict(user_categories)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"数据库查询失败: {str(e)}")
|
||||
return {}
|
||||
finally:
|
||||
if conn:
|
||||
conn.close()
|
||||
|
||||
|
||||
def get_category_from_path(path: str) -> str:
|
||||
"""从路径解析类别"""
|
||||
try:
|
||||
parts = path.split('/')
|
||||
if len(parts) >= 4:
|
||||
return f"{parts[2]}_{parts[3]}"
|
||||
return "unknown"
|
||||
except:
|
||||
return "unknown"
|
||||
Reference in New Issue
Block a user