feat(新功能): sketch 推荐算法
fix(修复bug): docs(文档变更): refactor(重构): test(增加测试):
This commit is contained in:
118
app/api/api_recommendation.py
Normal file
118
app/api/api_recommendation.py
Normal file
@@ -0,0 +1,118 @@
|
|||||||
|
import io
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from apscheduler.schedulers.background import BackgroundScheduler
|
||||||
|
from apscheduler.triggers.cron import CronTrigger
|
||||||
|
from fastapi import HTTPException, APIRouter
|
||||||
|
|
||||||
|
from app.service.recommend.service import load_resources, matrix_data
|
||||||
|
|
||||||
|
# Re-wrap stdout as UTF-8 so Chinese print/log output renders correctly
# regardless of the console's default encoding.
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')

# Root logger; inherits the application-wide logging configuration.
logger = logging.getLogger()

# Router mounted under the "/api" prefix by app/api/api_route.py.
router = APIRouter()
||||||
|
@router.on_event("startup")
async def startup_event():
    """Load recommendation resources at startup and schedule a daily refresh.

    Runs once when the application starts:
      1. loads all matrices/mappings via ``load_resources()``;
      2. starts a background scheduler that re-runs ``load_resources``
         every day at 00:30.

    NOTE(review): the ``BackgroundScheduler`` is a local variable and is
    never shut down — there is no handle to stop it on app shutdown.
    """
    # Initial load of matrices and caches.
    load_resources()

    # Configure the daily refresh job.
    scheduler = BackgroundScheduler()
    scheduler.add_job(
        load_resources,
        trigger=CronTrigger(hour=0, minute=30),
        name="每日资源刷新"
    )
    scheduler.start()
    logger.info("定时任务已启动")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/recommend/{user_id}/{category}/{num_recommendations}", response_model=List[str])
async def get_recommendations(user_id: int, category: str, num_recommendations: int = 10):
    """Return up to ``num_recommendations`` sketch paths for a user/category.

    Scores blend two precomputed matrices (0.7 * interaction score +
    0.3 * min-max-normalised feature similarity); the final picks are drawn
    by temperature-softmax probability sampling without replacement.

    :param user_id: e.g. ``4``
    :param category: e.g. ``female_skirt``
    :param num_recommendations: e.g. ``1``
    :return: list of MinIO object paths, e.g.
        ``["aida-sys-image/images/female/skirt/903000017.jpg"]``
    :raises HTTPException: 500 on any internal failure.
    """
    try:
        start_time = time.time()
        cache_key = (user_id, category)

        # Use the precomputed per-(user, category) scores when available.
        if cache_key in matrix_data["cached_scores"]:
            processed_inter, processed_feat = matrix_data["cached_scores"][cache_key]
            valid_sketch_idxs_inter = matrix_data["cached_valid_idxs"][cache_key]
        else:
            # Cache miss: compute on the fly (same logic as the precacher).
            user_idx_inter = matrix_data["user_index_interaction"].get(user_id)
            user_idx_feature = matrix_data["user_index_feature"].get(user_id)

            category_iids = matrix_data["category_to_iids"].get(category, [])
            # Interaction-matrix column indices for this category's sketches.
            valid_sketch_idxs_inter = [
                idx for iid, idx in matrix_data["sketch_index_interaction"].items()
                if iid in category_iids
            ]

            # Interaction scores, weighted 0.7.
            raw_inter_scores = []
            if user_idx_inter is not None and valid_sketch_idxs_inter:
                raw_inter_scores = matrix_data["interaction_matrix"][user_idx_inter, valid_sketch_idxs_inter]
            # NOTE(review): if the branch above is skipped this multiplies a
            # plain list by 0.7 ([] * 0.7 raises TypeError, caught below as a
            # 500) — confirm the intended empty-score fallback; the precacher
            # uses np.array([]) here instead.
            processed_inter = raw_inter_scores * 0.7

            # Feature-similarity scores, min-max normalised then weighted 0.3.
            valid_sketch_idxs_feature = [
                idx for iid, idx in matrix_data["sketch_index_feature"].items()
                if iid in category_iids
            ]
            raw_feat_scores = []
            if user_idx_feature is not None and valid_sketch_idxs_feature:
                raw_feat_scores = matrix_data["feature_matrix"][user_idx_feature, valid_sketch_idxs_feature]
                raw_feat_scores = (raw_feat_scores - np.min(raw_feat_scores)) / (
                        np.max(raw_feat_scores) - np.min(raw_feat_scores) + 1e-8)
                processed_feat = raw_feat_scores * 0.3
            else:
                processed_feat = np.array([])

            # Store for subsequent requests.
            matrix_data["cached_scores"][cache_key] = (processed_inter, processed_feat)
            matrix_data["cached_valid_idxs"][cache_key] = valid_sketch_idxs_inter

        # Blend the two score vectors.
        # NOTE(review): this assumes the interaction and feature column lists
        # have the same length and order — verify; otherwise broadcasting
        # fails and the request 500s.
        final_scores = processed_inter + processed_feat
        valid_sketch_idxs = matrix_data["cached_valid_idxs"][cache_key]

        # Probability sampling over the candidate sketches.
        scores = np.array(final_scores)

        # Temperature-controlled softmax; a low temperature (0.07) sharpens
        # the distribution toward the top-scoring sketches.
        def calibrated_softmax(scores, temperature=1.0):
            scores = scores / temperature
            scale = scores - max(scores)  # shift for numerical stability
            exps = np.exp(scale)
            return exps / np.sum(exps)

        probs = calibrated_softmax(scores, 0.07)

        # Sample without replacement; never more items than are available.
        # NOTE(review): np.random.choice requires len(probs) ==
        # len(valid_sketch_idxs) — holds only if the assumption above holds.
        chosen_indices = np.random.choice(
            len(valid_sketch_idxs),
            size=min(num_recommendations, len(valid_sketch_idxs)),
            p=probs,
            replace=False
        )
        # NOTE(review): valid_sketch_idxs holds matrix COLUMN indices, which
        # are then used as keys into iid_to_sketch (an iid -> path map) —
        # confirm that the column-index and iid spaces actually coincide.
        recommendations = [matrix_data["iid_to_sketch"][valid_sketch_idxs[idx]] for idx in chosen_indices]

        logger.info(f"推荐生成完成,耗时: {time.time() - start_time:.2f}秒")
        return recommendations

    except Exception as e:
        logger.error(f"推荐失败: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
|
||||||
@@ -10,6 +10,7 @@ from app.api import api_generate_image
|
|||||||
from app.api import api_image2sketch
|
from app.api import api_image2sketch
|
||||||
from app.api import api_prompt_generation
|
from app.api import api_prompt_generation
|
||||||
from app.api import api_super_resolution
|
from app.api import api_super_resolution
|
||||||
|
from app.api import api_recommendation
|
||||||
from app.api import api_test
|
from app.api import api_test
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
@@ -26,3 +27,4 @@ router.include_router(api_image2sketch.router, tags=['api_image2sketch'], prefix
|
|||||||
router.include_router(api_brighten.router, tags=['api_brighten'], prefix="/api")
|
router.include_router(api_brighten.router, tags=['api_brighten'], prefix="/api")
|
||||||
router.include_router(api_query_image.router, tags=['api_query_image'], prefix="/api")
|
router.include_router(api_query_image.router, tags=['api_query_image'], prefix="/api")
|
||||||
router.include_router(api_brand_dna.router, tags=['api_brand_dna'], prefix="/api")
|
router.include_router(api_brand_dna.router, tags=['api_brand_dna'], prefix="/api")
|
||||||
|
router.include_router(api_recommendation.router, tags=['api_recommendation'], prefix="/api")
|
||||||
|
|||||||
@@ -25,10 +25,13 @@ if DEBUG:
|
|||||||
LOGS_PATH = "logs/"
|
LOGS_PATH = "logs/"
|
||||||
CATEGORY_PATH = "service/attribute/config/descriptor/category/category_dis.csv"
|
CATEGORY_PATH = "service/attribute/config/descriptor/category/category_dis.csv"
|
||||||
SEG_CACHE_PATH = "../seg_cache/"
|
SEG_CACHE_PATH = "../seg_cache/"
|
||||||
|
RECOMMEND_PATH_PREFIX = "service/recommend/"
|
||||||
else:
|
else:
|
||||||
LOGS_PATH = "app/logs/"
|
LOGS_PATH = "app/logs/"
|
||||||
CATEGORY_PATH = "app/service/attribute/config/descriptor/category/category_dis.csv"
|
CATEGORY_PATH = "app/service/attribute/config/descriptor/category/category_dis.csv"
|
||||||
SEG_CACHE_PATH = "/seg_cache/"
|
SEG_CACHE_PATH = "/seg_cache/"
|
||||||
|
RECOMMEND_PATH_PREFIX = "app/service/recommend/"
|
||||||
|
|
||||||
|
|
||||||
# RABBITMQ_ENV = "" # 生产环境
|
# RABBITMQ_ENV = "" # 生产环境
|
||||||
RABBITMQ_ENV = "-dev" # 开发环境
|
RABBITMQ_ENV = "-dev" # 开发环境
|
||||||
@@ -36,7 +39,6 @@ RABBITMQ_ENV = "-dev" # 开发环境
|
|||||||
|
|
||||||
JAVA_STREAM_API_URL = os.getenv("JAVA_STREAM_API_URL", "https://api.aida.com.hk/api/third/party/receiveDesignResults")
|
JAVA_STREAM_API_URL = os.getenv("JAVA_STREAM_API_URL", "https://api.aida.com.hk/api/third/party/receiveDesignResults")
|
||||||
|
|
||||||
|
|
||||||
settings = Settings()
|
settings = Settings()
|
||||||
|
|
||||||
# minio 配置
|
# minio 配置
|
||||||
@@ -114,7 +116,6 @@ GMV_MODEL_NAME = 'multi_view'
|
|||||||
|
|
||||||
GMV_RABBITMQ_QUEUES = os.getenv("GMV_RABBITMQ_QUEUES", f"GenerateMultiView{RABBITMQ_ENV}")
|
GMV_RABBITMQ_QUEUES = os.getenv("GMV_RABBITMQ_QUEUES", f"GenerateMultiView{RABBITMQ_ENV}")
|
||||||
|
|
||||||
|
|
||||||
GI_MINIO_BUCKET = "aida-users"
|
GI_MINIO_BUCKET = "aida-users"
|
||||||
GI_RABBITMQ_QUEUES = os.getenv("GI_RABBITMQ_QUEUES", f"GenerateImage{RABBITMQ_ENV}")
|
GI_RABBITMQ_QUEUES = os.getenv("GI_RABBITMQ_QUEUES", f"GenerateImage{RABBITMQ_ENV}")
|
||||||
GI_SYS_IMAGE_URL = "aida-sys-image/generate_image/white_image.jpg"
|
GI_SYS_IMAGE_URL = "aida-sys-image/generate_image/white_image.jpg"
|
||||||
@@ -191,3 +192,23 @@ PRIORITY_DICT = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
QWEN_API_KEY = "sk-a6bdf594e1f54a4aa3e9d4d48f8c661f"
|
QWEN_API_KEY = "sk-a6bdf594e1f54a4aa3e9d4d48f8c661f"
|
||||||
|
|
||||||
|
DB_CONFIG = {
|
||||||
|
"host": "18.167.251.121",
|
||||||
|
"port": 3306,
|
||||||
|
"user": "root",
|
||||||
|
"password": "QWa998345",
|
||||||
|
"database": "aida",
|
||||||
|
"charset": "utf8mb4"
|
||||||
|
}
|
||||||
|
|
||||||
|
TABLE_CATEGORIES = {
|
||||||
|
"female_dress": "female/dress",
|
||||||
|
"female_outwear": "female/outwear",
|
||||||
|
"female_trousers": "female/trousers",
|
||||||
|
"female_skirt": "female/skirt",
|
||||||
|
"female_blouse": "female/blouse",
|
||||||
|
"male_tops": "male/tops",
|
||||||
|
"male_bottoms": "male/bottoms",
|
||||||
|
"male_outwear": "male/outwear"
|
||||||
|
}
|
||||||
|
|||||||
12
app/main.py
12
app/main.py
@@ -1,15 +1,17 @@
|
|||||||
import logging.config
|
import logging.config
|
||||||
from http.client import HTTPException
|
|
||||||
from fastapi.responses import JSONResponse
|
|
||||||
from fastapi import FastAPI, HTTPException, Request
|
|
||||||
|
|
||||||
import uvicorn
|
import uvicorn
|
||||||
|
from apscheduler.schedulers.background import BackgroundScheduler
|
||||||
|
from apscheduler.triggers.cron import CronTrigger
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
|
from fastapi import HTTPException, Request
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
|
||||||
from app.api.api_route import router
|
from app.api.api_route import router
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from app.core.record_api_count import count_api_calls
|
from app.core.record_api_count import count_api_calls
|
||||||
from app.schemas.response_template import ResponseModel
|
from app.schemas.response_template import ResponseModel
|
||||||
|
from app.service.recommend.service import load_resources
|
||||||
from logging_env import LOGGER_CONFIG_DICT
|
from logging_env import LOGGER_CONFIG_DICT
|
||||||
|
|
||||||
logging.config.dictConfig(LOGGER_CONFIG_DICT)
|
logging.config.dictConfig(LOGGER_CONFIG_DICT)
|
||||||
@@ -17,6 +19,8 @@ logging.getLogger("pika").setLevel(logging.WARNING)
|
|||||||
|
|
||||||
from starlette.middleware.cors import CORSMiddleware
|
from starlette.middleware.cors import CORSMiddleware
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def get_application() -> FastAPI:
|
def get_application() -> FastAPI:
|
||||||
application = FastAPI(
|
application = FastAPI(
|
||||||
@@ -51,5 +55,7 @@ async def http_exception_handler(request: Request, exc: HTTPException):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
uvicorn.run(app, host="0.0.0.0", port=8000)
|
uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||||
|
|||||||
431
app/service/recommend/scheduled_task.py
Normal file
431
app/service/recommend/scheduled_task.py
Normal file
@@ -0,0 +1,431 @@
|
|||||||
|
import pymysql
|
||||||
|
import numpy as np
|
||||||
|
from apscheduler.schedulers.blocking import BlockingScheduler
|
||||||
|
import os
|
||||||
|
import logging
|
||||||
|
from collections import defaultdict
|
||||||
|
import torch
|
||||||
|
from torchvision import models, transforms
|
||||||
|
from minio import Minio
|
||||||
|
from PIL import Image
|
||||||
|
import io
|
||||||
|
import seaborn as sns
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from scipy.sparse import csr_matrix
|
||||||
|
import matplotlib.font_manager as fm
|
||||||
|
from scipy import sparse
|
||||||
|
|
||||||
|
from app.core.config import DB_CONFIG, TABLE_CATEGORIES, RECOMMEND_PATH_PREFIX
|
||||||
|
|
||||||
|
# Pick an available CJK-capable font for matplotlib output.
try:
    # Try common Chinese fonts first.
    font_path = fm.findfont(fm.FontProperties(family=['Microsoft YaHei', 'SimHei', 'WenQuanYi Zen Hei']))
    plt.rcParams['font.sans-serif'] = fm.FontProperties(fname=font_path).get_name()
except:
    # Fall back to an English-only font.
    plt.rcParams['font.sans-serif'] = ['DejaVu Sans']
plt.rcParams['axes.unicode_minus'] = False

# NOTE(review): this second font block unconditionally overrides the
# fallback logic above and assumes Microsoft YaHei is installed — confirm
# which of the two configurations is actually intended.
font_path = fm.findfont(fm.FontProperties(family='Microsoft YaHei'))  # or another CJK-capable font
plt.rcParams['font.sans-serif'] = ['Microsoft YaHei']
plt.rcParams['axes.unicode_minus'] = False  # render minus signs correctly

# Log the scheduled task to a local file.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    filename='scheduler.log'
)

# MinIO client used to fetch sketch images.
# NOTE(review): credentials are hard-coded in source — move them to
# configuration / environment variables.
minio_client = Minio(
    "www.minio.aida.com.hk:12024",  # MinIO endpoint
    access_key="admin",
    secret_key="Aidlab123123!",
    secure=True  # HTTPS
)

# Preloaded {sketch_path: feature_vector} dict for all system sketches.
SYSTEM_FEATURES = np.load(f'{RECOMMEND_PATH_PREFIX}sketch_feature_dict.npy', allow_pickle=True).item()
|
||||||
|
|
||||||
|
|
||||||
|
# 保存sketch_to_iid到文件
|
||||||
|
def save_sketch_to_iid():
    """Build and persist the sketch-path -> iid mapping.

    iids are assigned 1..N following the iteration order of
    ``SYSTEM_FEATURES``.

    NOTE(review): this writes ``sketch_to_iid.npy`` to the current working
    directory, while the serving layer loads it from RECOMMEND_PATH_PREFIX —
    confirm the task runs with that prefix as its CWD.
    """
    lookup = {}
    for iid, sketch_path in enumerate(SYSTEM_FEATURES.keys(), start=1):
        lookup[sketch_path] = iid
    np.save('sketch_to_iid.npy', lookup)
    print("sketch_to_iid 已保存")
|
||||||
|
|
||||||
|
|
||||||
|
# 从文件加载sketch_to_iid
|
||||||
|
def load_sketch_to_iid():
    """Load the persisted sketch -> iid mapping, creating it on first use."""
    if not os.path.exists('sketch_to_iid.npy'):
        # First run: generate the mapping, persist it, then load it back.
        print("sketch_to_iid 文件不存在,正在生成并保存...")
        save_sketch_to_iid()
        return np.load('sketch_to_iid.npy', allow_pickle=True).item()
    mapping = np.load('sketch_to_iid.npy', allow_pickle=True).item()
    print("sketch_to_iid 已加载")
    return mapping
|
||||||
|
|
||||||
|
|
||||||
|
# Obtain (or lazily create) the sketch path -> iid mapping; used throughout
# the matrix builders below.
sketch_to_iid = load_sketch_to_iid()

# print(f"Total sketches: {len(sketch_to_iid)}")

# Image preprocessing pipeline (must match ResNet's training preprocessing).
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # ResNet expects 224x224 input
    transforms.ToTensor(),  # convert to tensor
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # ImageNet statistics
])

# Pretrained ResNet50 with the classification head removed, used as a fixed
# feature extractor (2048-d output, see extract_feature_vector_from_resnet).
resnet_model = models.resnet50(pretrained=True)
modules = list(resnet_model.children())[:-1]  # drop the final fully-connected layer
resnet_model = torch.nn.Sequential(*modules)
resnet_model.eval()  # inference mode
|
||||||
|
|
||||||
|
|
||||||
|
# 从 MinIO 获取图片并进行预处理
|
||||||
|
def get_sketch_image_from_minio(sketch_path):
    """Fetch a sketch image from MinIO and preprocess it for ResNet.

    :param sketch_path: "<bucket>/<object key>" path as stored in the DB.
    :return: a (1, 3, 224, 224) float tensor, or ``None`` if the fetch or
        decode fails.
    """
    # Split "<bucket>/<key>" on the first slash only.
    bucket_name, file_name = sketch_path.split('/', 1)

    obj = None
    try:
        obj = minio_client.get_object(bucket_name, file_name)
        img_data = obj.read()
        # Fix: force RGB so Normalize's 3-channel mean/std always applies,
        # even for grayscale or RGBA sources.
        img = Image.open(io.BytesIO(img_data)).convert('RGB')
        img = transform(img)
        return img.unsqueeze(0)  # add batch dimension
    except Exception as e:
        print(f"Error fetching image for {sketch_path}: {e}")
        return None
    finally:
        # Fix: release the HTTP connection back to the pool — the MinIO SDK
        # requires close() + release_conn() on every get_object response.
        if obj is not None:
            obj.close()
            obj.release_conn()
|
||||||
|
|
||||||
|
|
||||||
|
def extract_feature_vector_from_resnet(sketch_path):
    """Extract a 2048-d ResNet50 feature vector for a sketch stored in MinIO.

    :param sketch_path: "<bucket>/<object key>" path of the image.
    :return: 1-D numpy array (zeros of length 2048 when the fetch fails).
    """
    # Fetch and preprocess the image from MinIO.
    img_tensor = get_sketch_image_from_minio(sketch_path)
    if img_tensor is None:
        # Fetch failed: zero vector keeps downstream aggregation working.
        return np.zeros(2048)

    with torch.no_grad():  # inference only, no gradients needed
        feature_vector = resnet_model(img_tensor)
        return feature_vector.squeeze().cpu().numpy()  # drop the batch dimension
|
||||||
|
|
||||||
|
|
||||||
|
def update_user_matrices():
    """Recompute and persist the user interaction and feature matrices.

    Reads per-(user, sketch) like counts from MySQL, builds the interaction
    and feature matrices plus their index mappings, and saves everything as
    .npy files under RECOMMEND_PATH_PREFIX for the serving layer
    (app/service/recommend/service.load_resources) to consume.

    Never raises: failures are logged and the connection is always closed.
    """
    conn = None
    try:
        conn = pymysql.connect(**DB_CONFIG)
        cursor = conn.cursor()
        # NOTE(review): the cursor is never closed explicitly; only the
        # connection is closed in the finally block.

        # Aggregate like counts per (user, sketch path); category filtering
        # happens later inside the matrix builders.
        cursor.execute("""
            SELECT account_id, path, COUNT(*) as like_count
            FROM user_preference_log_test
            GROUP BY account_id, path
        """)
        user_data = cursor.fetchall()
        logging.info(f"成功读取{len(user_data)}条用户偏好记录")

        # Build the matrices and their index mappings.
        interaction_matrix, raw_counts_sparse, user_index_interaction_matrix, sketch_index_interaction_matrix, iid_to_category_interaction_matrix = calculate_interaction_matrix(user_data)
        # visualize_sparse_matrix(raw_counts_sparse,'交互次数矩阵', 'interaction_frequency_matrix.png')
        # visualize_sparse_matrix(interaction_matrix, '交互次数得分矩阵', 'interaction_score_matrix.png')
        # plot_interaction_count_matrix(raw_counts_sparse)
        # feature_matrix, iid_to_category_feature_matrix, user_index_feature_matrix, sketch_index_feature_matrix = calculate_feature_matrix(user_data)
        feature_matrix, user_index_feature_matrix, sketch_index_feature_matrix, iid_to_category_feature_matrix = calculate_feature_matrix(user_data)
        # visualize_sparse_matrix(feature_matrix, '系统sketch与用户category平均特征向量关联度矩阵', 'correlation_matrix.png')

        # Persist every artifact consumed by load_resources().
        np.save(f"{RECOMMEND_PATH_PREFIX}interaction_matrix.npy", interaction_matrix)
        np.save(f"{RECOMMEND_PATH_PREFIX}feature_matrix.npy", feature_matrix)
        # iid -> category mappings
        np.save(f"{RECOMMEND_PATH_PREFIX}iid_to_category_interaction_matrix.npy", iid_to_category_interaction_matrix)
        np.save(f"{RECOMMEND_PATH_PREFIX}user_index_interaction_matrix.npy", user_index_interaction_matrix)
        # user -> row-index mappings
        np.save(f"{RECOMMEND_PATH_PREFIX}iid_to_category_feature_matrix.npy", iid_to_category_feature_matrix)
        np.save(f"{RECOMMEND_PATH_PREFIX}user_index_feature_matrix.npy", user_index_feature_matrix)
        # iid -> column-index mappings
        np.save(f"{RECOMMEND_PATH_PREFIX}sketch_index_interaction_matrix.npy", sketch_index_interaction_matrix)
        np.save(f"{RECOMMEND_PATH_PREFIX}sketch_index_feature_matrix.npy", sketch_index_feature_matrix)
        # logging.info("矩阵更新完成")

    except Exception as e:
        logging.error(f"定时任务执行失败: {str(e)}", exc_info=True)
    finally:
        if conn:
            conn.close()
|
||||||
|
|
||||||
|
|
||||||
|
def plot_interaction_count_matrix(interaction_count_matrix):
    """Plot the interaction-count matrix as a heatmap, zero entries included.

    :param interaction_count_matrix: dense array or scipy CSR matrix of raw
        per-(user, sketch) like counts.

    Saves the figure to ``interaction_heatmap_with_zeros.png``.  Never
    raises: all failures are logged.
    """
    try:
        # Normalise the input to CSR so the conversion below is uniform.
        if not isinstance(interaction_count_matrix, csr_matrix):
            sparse_matrix = csr_matrix(interaction_count_matrix)
        else:
            sparse_matrix = interaction_count_matrix

        # Densify for seaborn; may be huge, so guard against MemoryError.
        try:
            dense_matrix = sparse_matrix.toarray()
        except MemoryError:
            logging.error("内存不足,无法转换为密集矩阵")
            return

        # Pick an available font, falling back to an English-only one.
        # NOTE(review): rcParams['font.sans-serif'] is assigned a plain
        # string here but a list in the except branch — confirm matplotlib
        # accepts both forms as intended.
        try:
            font_path = fm.findfont(fm.FontProperties(family='sans-serif', style='normal'))
            plt.rcParams['font.sans-serif'] = fm.FontProperties(fname=font_path).get_name()
        except:
            plt.rcParams['font.sans-serif'] = ['DejaVu Sans']  # English-only fallback
        plt.rcParams['axes.unicode_minus'] = False

        # Cap very large matrices so rendering stays feasible.
        if dense_matrix.shape[0] > 1000 or dense_matrix.shape[1] > 1000:
            dense_matrix = dense_matrix[:1000, :1000]  # only the first 1000 rows/cols

        plt.figure(figsize=(15, 10))

        # vmin=0 keeps zero cells visible; vmax anchors the colour scale.
        sns.heatmap(
            dense_matrix,
            cmap="Blues",  # alternative: "YlGnBu"
            annot=False,  # no per-cell annotations
            cbar_kws={"label": "Interaction Count"},  # colourbar label
            linewidths=0.5,
            vmin=0,  # make zero values distinguishable
            vmax=np.max(dense_matrix)  # keep the colour mapping sensible
        )

        plt.title('User-Sketch Interaction Matrix (With Zero Entries)')
        plt.xlabel('Sketch Index')
        plt.ylabel('User Index')
        plt.savefig('interaction_heatmap_with_zeros.png', dpi=150, bbox_inches='tight')
        plt.close()

        logging.info("热图已保存为 interaction_heatmap_with_zeros.png")

    except Exception as e:
        logging.error(f"绘图失败: {str(e)}", exc_info=True)
|
||||||
|
|
||||||
|
def visualize_sparse_matrix(matrix, title='Non-zero Interactions (Scatter Plot)', filename="scatter_figure_interaction.png"):
    """Scatter-plot the non-zero entries of a (user x item) matrix.

    :param matrix: dense array or scipy sparse matrix.
    :param title: figure title.
    :param filename: output PNG path.
    """
    if not sparse.issparse(matrix):
        # Normalise dense input to a sparse matrix.
        matrix = sparse.csr_matrix(matrix)

    # Coordinates and values of the non-zero entries only.
    rows, cols = matrix.nonzero()
    values = matrix.data

    # One point per non-zero cell, coloured by its value.
    plt.figure(figsize=(24, 20))
    plt.scatter(cols, rows, c=values, cmap='coolwarm', alpha=0.7, s=1)
    plt.colorbar(label='Interaction Count')
    plt.title(title)
    plt.xlabel('Item Index')
    plt.ylabel('User Index')  # fix: rows are users, not items (was 'Item Index')
    plt.savefig(filename)
    plt.close()  # fix: release the figure so repeated calls don't leak memory
|
||||||
|
|
||||||
|
def calculate_interaction_matrix(user_data):
    """Build the user x system-sketch interaction matrices.

    :param user_data: rows of (account_id, path, like_count) from MySQL.
    :return: tuple of
        - interaction_matrix: dense, per-user log-normalised scores
        - interaction_count_matrix: sparse CSR matrix of raw like counts
        - user_index: account_id -> row index
        - sketch_index: sketch iid -> column index
        - iid -> category mapping covering ALL known sketches
    """
    # Users who interacted with at least one recognised category.
    all_users = set()
    for account_id, path, like_count in user_data:
        category = get_category_from_path(path)
        if category not in TABLE_CATEGORIES.keys():
            continue
        all_users.add(account_id)

    # iids of the preloaded system sketches (user uploads are excluded).
    system_sketch_iids = {sketch_to_iid[path] for path in SYSTEM_FEATURES.keys() if path in sketch_to_iid}

    # Stable row/column index mappings (sorted for determinism).
    user_index = {uid: idx for idx, uid in enumerate(sorted(all_users))}
    sketch_index = {iid: idx for idx, iid in enumerate(sorted(system_sketch_iids))}

    # Two outputs: dense normalised scores and sparse raw counts.
    interaction_matrix = np.zeros((len(all_users), len(sketch_index)))  # normalised matrix
    data, rows, cols = [], [], []  # COO triplets for the sparse count matrix

    # Per-user maximum like count — the normalisation denominator.
    user_max_likes = defaultdict(int)
    for account_id, path, like_count in user_data:
        if sketch_to_iid.get(path) in system_sketch_iids:
            user_max_likes[account_id] = max(user_max_likes[account_id], like_count)

    # Fill both matrices.
    for account_id, path, like_count in user_data:
        sketch_iid = sketch_to_iid.get(path)
        if sketch_iid not in system_sketch_iids:
            continue

        user_idx = user_index[account_id]
        sketch_idx = sketch_index[sketch_iid]

        # Raw count for the sparse matrix.
        data.append(like_count)
        rows.append(user_idx)
        cols.append(sketch_idx)

        # Log-normalised score relative to this user's max like count.
        # NOTE(review): np.log1p(1 + x) computes log(2 + x), i.e. the "+1" is
        # applied twice — confirm whether np.log1p(like_count) was intended.
        max_like = user_max_likes.get(account_id, 1)
        interaction_matrix[user_idx, sketch_idx] = np.log1p(1 + like_count) / np.log1p(1 + max_like)

    # CSR format suits the later row-oriented reads.
    interaction_count_matrix = csr_matrix((data, (rows, cols)), shape=(len(all_users), len(sketch_index)))

    return interaction_matrix, interaction_count_matrix, user_index, sketch_index, {iid: get_category_from_path(path) for path, iid in sketch_to_iid.items()}
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_feature_matrix(user_data):
    """Build the user x system-sketch cosine-similarity (feature) matrix.

    For each (user, category), the feature vectors of the sketches the user
    liked are averaged, weighted by like count; each matrix cell is the
    cosine similarity between that average and a system sketch's vector.

    :param user_data: rows of (account_id, path, like_count) from MySQL.
    :return: tuple of (feature_matrix, user_index, sketch_index,
        iid -> category mapping covering all known sketches).
    """

    # {(account_id, category): {sketch_iid: [(feature_vector, weight), ...]}}
    user_feature_weights = defaultdict(lambda: defaultdict(list))

    all_users = set()
    all_system_sketches = set(SYSTEM_FEATURES.keys())  # NOTE(review): unused below

    # ==== Pass 1: collect feature vectors and like-count weights ====
    for account_id, path, like_count in user_data:
        category = get_category_from_path(path)
        if category not in TABLE_CATEGORIES.keys():
            continue

        sketch_iid = sketch_to_iid.get(path)
        if not sketch_iid:
            continue

        # Record the user.
        all_users.add(account_id)

        # Grab the feature and record (vector, like_count) as the weight pair.
        if path in SYSTEM_FEATURES:  # system sketch: feature is precomputed
            feature = SYSTEM_FEATURES[path]
            weight = like_count  # like_count acts as the weight
            user_feature_weights[(account_id, category)][sketch_iid].append((feature, weight))
        else:  # user-uploaded sketch: extract the feature on the fly (slow: MinIO + ResNet)
            feature = extract_feature_vector_from_resnet(path)
            weight = like_count
            user_feature_weights[(account_id, category)][sketch_iid].append((feature, weight))

    # ==== Pass 2: collect all system sketch iids ====
    system_sketch_iids = set()
    for sketch_path in SYSTEM_FEATURES:
        if iid := sketch_to_iid.get(sketch_path):
            system_sketch_iids.add(iid)

    # ==== Index mappings (sorted for determinism) ====
    user_list = sorted(all_users)
    sketch_list = sorted(system_sketch_iids)

    user_index = {uid: idx for idx, uid in enumerate(user_list)}
    sketch_index = {iid: idx for idx, iid in enumerate(sketch_list)}

    # ==== Result matrix ====
    feature_matrix = np.zeros((len(user_list), len(sketch_list)))

    # ==== Precompute each (user, category)'s weighted-average feature ====
    user_avg_features = {}
    for (account_id, category), sketches in user_feature_weights.items():
        try:
            # Flatten all (vector, weight) pairs across this user's sketches.
            all_features_weights = [(vec, weight) for vec_list in sketches.values() for vec, weight in vec_list]

            if len(all_features_weights) == 0:
                continue

            # Total weight, with a guard against division by zero.
            total_weight = sum(weight for _, weight in all_features_weights)
            if total_weight <= 0:
                total_weight = 1.0

            # Weighted average of the feature vectors.
            weighted_sum = np.zeros_like(all_features_weights[0][0])  # matches the vector dimensionality
            for vec, weight in all_features_weights:
                weighted_sum += vec * weight

            avg_vec = weighted_sum / total_weight
            user_avg_features[(account_id, category)] = avg_vec

        except Exception as e:
            logging.warning(f"用户({account_id},{category})加权特征计算失败: {str(e)}")
            continue

    # ==== Fill the similarity matrix ====
    for sketch_path, sys_vector in SYSTEM_FEATURES.items():
        sketch_iid = sketch_to_iid.get(sketch_path)

        system_sketch_category = get_category_from_path(sketch_path)
        if not sketch_iid or sketch_iid not in sketch_index:
            continue

        sketch_col = sketch_index[sketch_iid]

        # Compare every user's category-average vector with this sketch.
        for account_id in all_users:
            user_row = user_index.get(account_id)
            if user_row is None:
                continue

            # Look up the user's average vector for this sketch's category.
            try:
                user_vec = user_avg_features[(account_id, system_sketch_category)]
            except KeyError:
                # No feature data for this user in this category.
                continue

            # Cosine similarity fills the cell.
            cos_sim = cosine_similarity(user_vec, sys_vector)
            feature_matrix[user_row, sketch_col] = cos_sim

    return feature_matrix, user_index, sketch_index, {iid: get_category_from_path(path) for path, iid in sketch_to_iid.items()}
|
||||||
|
|
||||||
|
|
||||||
|
def get_category_from_path(path):
    """Parse the "<gender>_<garment>" category out of a sketch path.

    Expects paths shaped like "aida-sys-image/images/female/skirt/xxx.jpg";
    segments 2 and 3 form the category ("female_skirt").

    :param path: slash-separated object path (may be malformed or None).
    :return: "<parts[2]>_<parts[3]>", or "unknown" when the path is too
        short or not a string.
    """
    try:
        parts = path.split('/')
        # Fix: the guard checked len(parts) >= 2 while indexing parts[2] and
        # parts[3] (IndexError was silently masked by the bare except).
        # Require >= 4, matching service.get_category_from_path.
        if len(parts) >= 4:
            return f"{parts[2]}_{parts[3]}"
        return "unknown"
    except AttributeError:
        # Non-string input (e.g. None) has no .split.
        return "unknown"
|
||||||
|
|
||||||
|
|
||||||
|
def cosine_similarity(vec1, vec2):
    """Cosine similarity of two vectors; 0.0 when either vector is zero."""
    denom = np.linalg.norm(vec1) * np.linalg.norm(vec2)
    if denom == 0:
        return 0.0
    # Small epsilon keeps the division numerically safe.
    return np.dot(vec1, vec2) / (denom + 1e-10)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    try:
        # One-off run of the matrix refresh.  The commented-out scheduler
        # would instead refresh daily at 12:00 Asia/Shanghai.
        update_user_matrices()
        # scheduler = BlockingScheduler()
        # scheduler.add_job(update_user_matrices, 'cron', hour=12, timezone='Asia/Shanghai')
        # logging.info("定时任务已启动,每天12:00执行")
        # scheduler.start()
    except KeyboardInterrupt:
        logging.info("定时任务已停止")
    except Exception as e:
        logging.error(f"调度器启动失败: {str(e)}", exc_info=True)
|
||||||
172
app/service/recommend/service.py
Normal file
172
app/service/recommend/service.py
Normal file
@@ -0,0 +1,172 @@
|
|||||||
|
# 预加载资源
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from app.core.config import DB_CONFIG, RECOMMEND_PATH_PREFIX
|
||||||
|
|
||||||
|
logger = logging.getLogger()
|
||||||
|
import pymysql
|
||||||
|
|
||||||
|
# In-memory store for all recommendation resources.  Populated by
# load_resources(); read by the API layer (app/api/api_recommendation.py).
matrix_data = {
    "interaction_matrix": None,        # users x sketches interaction-score matrix
    "feature_matrix": None,            # users x sketches cosine-similarity matrix
    "user_index_interaction": None,    # user_id -> row index (interaction matrix)
    "sketch_index_interaction": None,  # sketch iid -> column index (interaction matrix)
    "user_index_feature": None,        # user_id -> row index (feature matrix)
    "sketch_index_feature": None,      # sketch iid -> column index (feature matrix)
    "iid_to_sketch": None,             # iid -> sketch path
    "category_to_iids": None,          # category -> list of sketch iids
    "cached_scores": {},               # (user_id, category) -> (inter, feat) score arrays
    "cached_valid_idxs": {},           # (user_id, category) -> candidate column indices
    "category_sketch_idxs_inter": None,    # reserved; never populated in this module
    "category_sketch_idxs_feature": None,  # reserved; never populated in this module
    "user_inter_full": dict(),         # reserved; never populated in this module
    "user_feat_full": dict(),          # reserved; never populated in this module
}
|
||||||
|
|
||||||
|
|
||||||
|
def load_resources():
    """Load all matrices and index mappings into matrix_data, then precache.

    Clears the per-(user, category) caches, loads every .npy artifact written
    by the scheduled task from RECOMMEND_PATH_PREFIX, builds the
    category -> iids reverse mapping, and finally warms the caches via
    precache_user_category().

    :raises RuntimeError: if any resource fails to load (original exception
        is chained as the cause).
    """
    try:
        start_time = time.time()

        # Invalidate stale caches before the fresh load.
        matrix_data["cached_scores"].clear()
        matrix_data["cached_valid_idxs"].clear()

        # Sketch path <-> iid mappings.
        sketch_to_iid = np.load(f'{RECOMMEND_PATH_PREFIX}sketch_to_iid.npy', allow_pickle=True).item()
        matrix_data["iid_to_sketch"] = {v: k for k, v in sketch_to_iid.items()}

        # Interaction matrix and its index mappings.
        matrix_data["interaction_matrix"] = np.load(f"{RECOMMEND_PATH_PREFIX}interaction_matrix.npy", allow_pickle=True)
        matrix_data["user_index_interaction"] = np.load(f"{RECOMMEND_PATH_PREFIX}user_index_interaction_matrix.npy", allow_pickle=True).item()
        matrix_data["sketch_index_interaction"] = np.load(f"{RECOMMEND_PATH_PREFIX}sketch_index_interaction_matrix.npy",
                                                          allow_pickle=True).item()

        # Feature matrix and its index mappings.
        matrix_data["feature_matrix"] = np.load(f"{RECOMMEND_PATH_PREFIX}feature_matrix.npy", allow_pickle=True)
        matrix_data["user_index_feature"] = np.load(f"{RECOMMEND_PATH_PREFIX}user_index_feature_matrix.npy", allow_pickle=True).item()
        matrix_data["sketch_index_feature"] = np.load(f"{RECOMMEND_PATH_PREFIX}sketch_index_feature_matrix.npy", allow_pickle=True).item()

        # Invert iid -> category into category -> [iids].
        category_to_iid_map = np.load(f"{RECOMMEND_PATH_PREFIX}iid_to_category_interaction_matrix.npy", allow_pickle=True).item()
        matrix_data["category_to_iids"] = defaultdict(list)
        for iid, cat in category_to_iid_map.items():
            matrix_data["category_to_iids"][cat].append(iid)

        logger.info(f"资源加载完成,耗时: {time.time() - start_time:.2f}秒")

        # Warm the per-(user, category) caches.
        precache_user_category()

    except Exception as e:
        # Fix: keep the stack trace in the log and chain the original cause
        # onto the RuntimeError instead of discarding it.
        logger.error(f"资源加载失败: {str(e)}", exc_info=True)
        raise RuntimeError("初始化失败") from e
|
||||||
|
|
||||||
|
|
||||||
|
def precache_user_category():
    """Precompute and cache score vectors for every known (user, category).

    Mirrors the on-demand computation in the recommendation endpoint so that
    most requests hit the cache.  Requires load_resources() to have
    populated matrix_data first; otherwise it logs a warning and returns.
    """
    if not all([
        matrix_data["interaction_matrix"] is not None,
        matrix_data["feature_matrix"] is not None,
        matrix_data["user_index_interaction"] is not None
    ]):
        logger.warning("资源未加载完成,跳过预缓存")
        return

    start_time = time.time()
    user_categories = get_all_user_categories()

    precached_count = 0
    for user_id, categories in user_categories.items():
        for category in categories:
            cache_key = (user_id, category)
            if cache_key in matrix_data["cached_scores"]:
                continue

            try:
                # Row indices for this user in both matrices (None if absent).
                user_idx_inter = matrix_data["user_index_interaction"].get(user_id)
                user_idx_feature = matrix_data["user_index_feature"].get(user_id)

                # Sketch iids belonging to the category.
                category_iids = matrix_data["category_to_iids"].get(category, [])

                # Interaction-matrix columns for those sketches.
                valid_sketch_idxs_inter = [
                    idx for iid, idx in matrix_data["sketch_index_interaction"].items()
                    if iid in category_iids
                ]

                # Interaction scores, weighted 0.7.
                if user_idx_inter is not None and valid_sketch_idxs_inter:
                    raw_inter_scores = matrix_data["interaction_matrix"][user_idx_inter, valid_sketch_idxs_inter]
                    processed_inter = raw_inter_scores * 0.7
                else:
                    processed_inter = np.array([])

                # Feature-matrix columns for those sketches.
                valid_sketch_idxs_feature = [
                    idx for iid, idx in matrix_data["sketch_index_feature"].items()
                    if iid in category_iids
                ]

                # Feature scores, min-max normalised then weighted 0.3.
                if user_idx_feature is not None and valid_sketch_idxs_feature:
                    raw_feat_scores = matrix_data["feature_matrix"][user_idx_feature, valid_sketch_idxs_feature]
                    raw_feat_scores = (raw_feat_scores - np.min(raw_feat_scores)) / (
                            np.max(raw_feat_scores) - np.min(raw_feat_scores) + 1e-8)
                    processed_feat = raw_feat_scores * 0.3
                else:
                    processed_feat = np.array([])

                # Cache only when the two vectors can be added element-wise
                # (both empty, or both the same length).
                if len(processed_inter) == len(processed_feat):
                    matrix_data["cached_scores"][cache_key] = (processed_inter, processed_feat)
                    matrix_data["cached_valid_idxs"][cache_key] = valid_sketch_idxs_inter
                    precached_count += 1

            except Exception as e:
                logger.error(f"预缓存失败 (user={user_id}, category={category}): {str(e)}")

    logger.info(f"预缓存完成,共缓存 {precached_count} 个组合,耗时: {time.time() - start_time:.2f}秒")
|
||||||
|
|
||||||
|
|
||||||
|
def get_all_user_categories():
    """Fetch every user's set of interacted categories from MySQL.

    NOTE(review): reads ``user_preference_log_prediction`` while the matrix
    builder reads ``user_preference_log_test`` — confirm the two tables are
    meant to differ.

    :return: {account_id: {category, ...}}; empty dict on any DB failure.
    """
    conn = None
    try:
        conn = pymysql.connect(**DB_CONFIG)
        cursor = conn.cursor()

        # Distinct (user, path) pairs; categories are derived from the path.
        query = """
            SELECT DISTINCT account_id, path
            FROM user_preference_log_prediction
        """
        cursor.execute(query)
        results = cursor.fetchall()

        # Group the parsed categories per user ("unknown" is kept as-is).
        user_categories = defaultdict(set)
        for account_id, path in results:
            category = get_category_from_path(path)
            user_categories[account_id].add(category)

        return dict(user_categories)

    except Exception as e:
        logger.error(f"数据库查询失败: {str(e)}")
        return {}
    finally:
        if conn:
            conn.close()
|
||||||
|
|
||||||
|
|
||||||
|
def get_category_from_path(path: str) -> str:
    """Derive the "<gender>_<garment>" category from a sketch path.

    Paths look like "aida-sys-image/images/female/skirt/xxx.jpg"; segments
    2 and 3 form the category ("female_skirt").  Returns "unknown" for any
    path that cannot be parsed.
    """
    try:
        segments = path.split('/')
    except:
        # Non-string input (e.g. None).
        return "unknown"
    return f"{segments[2]}_{segments[3]}" if len(segments) >= 4 else "unknown"
|
||||||
Reference in New Issue
Block a user