TASK:冷启动热度推荐;

This commit is contained in:
shahaibo
2025-06-10 10:54:20 +08:00
parent a14e6051b1
commit d39dee851f
4 changed files with 400 additions and 16 deletions

View File

@@ -0,0 +1,212 @@
import io
import logging
import sys
import time
from typing import List
from collections import defaultdict
import numpy as np
from apscheduler.schedulers.background import BackgroundScheduler
from apscheduler.triggers.cron import CronTrigger
from fastapi import HTTPException, APIRouter
from app.service.recommend.service import load_resources, matrix_data
import pymysql
from app.core.config import DB_CONFIG, TABLE_CATEGORIES, RECOMMEND_PATH_PREFIX
from minio import Minio
import torch
from torchvision import models, transforms
from PIL import Image
import os
from fastapi.responses import JSONResponse
# Re-wrap the process-wide stdout so that text written to it is encoded as UTF-8.
# NOTE(review): this runs as an import-time side effect of this module.
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')
# getLogger() without a name returns the root logger.
logger = logging.getLogger()
# Router on which this module's endpoints are registered.
router = APIRouter()
# MinIO configuration.
# NOTE(review): the endpoint and credentials were hard-coded in source control.
# They can now be overridden through environment variables; the literals are
# kept only as backward-compatible fallbacks and should be rotated and removed.
minio_client = Minio(
    os.getenv("MINIO_ENDPOINT", "www.minio.aida.com.hk:12024"),
    access_key=os.getenv("MINIO_ACCESS_KEY", "admin"),
    secret_key=os.getenv("MINIO_SECRET_KEY", "Aidlab123123!"),
    secure=True,
)
# Preprocessing pipeline applied to every fetched image: resize to 224x224,
# convert to a tensor, then normalise the three channels with the mean/std
# below (these values match the commonly used ImageNet statistics).
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]),
])
# ResNet50 with the final fully connected layer removed.
# NOTE(review): `pretrained=True` is deprecated in newer torchvision releases in
# favour of `weights=...` — confirm against the installed version before changing,
# since different weights would not match the precomputed feature files.
resnet_model = models.resnet50(pretrained=True)
# Keep every child module except the last one.
resnet_model = torch.nn.Sequential(*list(resnet_model.children())[:-1])
# Switch the model to evaluation mode.
resnet_model.eval()
def get_sketch_image_from_minio(sketch_path: str):
    """Download an image from MinIO and return it as a model-ready tensor.

    Args:
        sketch_path: ``"<bucket>/<object name>"``; it is split on the first
            ``/`` only, so the object name may itself contain slashes.

    Returns:
        The transformed image with a leading batch dimension added, or
        ``None`` when the path has no ``/`` or fetching/decoding fails.
    """
    path_parts = sketch_path.split('/', 1)
    if len(path_parts) != 2:
        return None
    bucket_name, file_name = path_parts

    obj = None
    try:
        obj = minio_client.get_object(bucket_name, file_name)
        # Force three channels: grayscale or RGBA images would not match the
        # three-channel mean/std used by the Normalize step in `transform`.
        img = Image.open(io.BytesIO(obj.read())).convert("RGB")
        return transform(img).unsqueeze(0)
    except Exception as e:
        # Best effort: callers treat None as "no image available".
        logger.warning(f"Fetch image failed [{sketch_path}]: {e}")
        return None
    finally:
        # The MinIO response holds a pooled HTTP connection; hand it back.
        if obj is not None:
            obj.close()
            obj.release_conn()
def extract_feature_vector_from_resnet(sketch_path: str) -> np.ndarray:
    """Return the ResNet feature vector for the image stored at `sketch_path`.

    When the image cannot be fetched, a zero vector of length 2048
    (float32) is returned instead.
    """
    batch = get_sketch_image_from_minio(sketch_path)
    if batch is not None:
        # Inference only, so gradient tracking is switched off.
        with torch.no_grad():
            features = resnet_model(batch)  # [1, 2048, 1, 1]
            return features.squeeze().cpu().numpy()
    return np.zeros(2048, dtype=np.float32)
# Preload the feature dictionaries when the module is imported.
# Each file stores a pickled object; `.item()` unwraps it from the 0-d array
# that np.load returns. BRAND_FEATURES is looked up by product image id and
# SYSTEM_FEATURES is iterated as sketch path -> feature vector further below.
# NOTE(review): allow_pickle=True executes pickle data — only load trusted files.
BRAND_FEATURES = np.load(f'{RECOMMEND_PATH_PREFIX}brand_feature.npy', allow_pickle=True).item()
SYSTEM_FEATURES = np.load(f'{RECOMMEND_PATH_PREFIX}sketch_feature_dict.npy', allow_pickle=True).item()
def save_sketch_to_iid():
    """Number every key of SYSTEM_FEATURES from 1 and save the mapping to disk."""
    mapping = {}
    for iid, sketch_path in enumerate(SYSTEM_FEATURES.keys(), start=1):
        mapping[sketch_path] = iid
    np.save(f"{RECOMMEND_PATH_PREFIX}sketch_to_iid.npy", mapping)
def load_sketch_to_iid():
    """Load the sketch-path -> iid mapping, generating the file first if absent."""
    path = f"{RECOMMEND_PATH_PREFIX}sketch_to_iid.npy"
    if not os.path.exists(path):
        # First run: build the mapping and write it before reading it back.
        save_sketch_to_iid()
    return np.load(path, allow_pickle=True).item()
# Sketch-path -> iid mapping, loaded (or built and saved first) at import time.
sketch_to_iid = load_sketch_to_iid()
def getNewCategory(gender: str, sketch_category: str) -> str:
    """Build the lower-cased ``"<gender>_<category>"`` key."""
    return "_".join((gender.lower(), sketch_category.lower()))
def get_category_from_path(path: str) -> str:
    """Derive the ``"<gender>_<category>"`` key from a slash-separated path.

    The third and fourth path segments are used, lower-cased. Paths with
    fewer than four segments yield ``"unknown_unknown"``.
    """
    segments = path.split('/')
    if len(segments) < 4:
        return "unknown_unknown"
    gender, category = segments[2], segments[3]
    return "_".join((gender.lower(), category.lower()))
def load_brand_matrix():
    """Load the persisted brand matrix together with its brand index map.

    Returns:
        A ``(matrix, index_map)`` tuple. When either file is missing, an
        empty matrix with zero rows and one column per entry of
        ``sketch_to_iid`` is returned together with an empty dict.
    """
    # This must be the file calculate_brand_matrix writes. It previously read
    # "brand_matrix.npy", which is never written, so every call fell back to
    # the empty matrix/map and earlier brands were overwritten on save.
    mat_path = f"{RECOMMEND_PATH_PREFIX}brand_feature_matrix.npy"
    idx_path = f"{RECOMMEND_PATH_PREFIX}brand_index_map.npy"
    try:
        matrix = np.load(mat_path)
        index_map = np.load(idx_path, allow_pickle=True).item()
    except FileNotFoundError:
        matrix = np.zeros((0, len(sketch_to_iid)), dtype=np.float32)
        index_map = {}
    return matrix, index_map
def cosine_similarity(vec1, vec2):
    """Cosine similarity of two vectors; 0.0 when either has zero norm."""
    magnitude = np.linalg.norm(vec1) * np.linalg.norm(vec2)
    if magnitude == 0:
        return 0.0
    # The small epsilon keeps the division stable for very small norms.
    return np.dot(vec1, vec2) / (magnitude + 1e-10)
def calculate_brand_matrix(sketch_data, brand_id: int) -> np.ndarray:
    """Compute and persist one brand's similarity row against the system sketches.

    For each category the brand's feature vectors are averaged, and every
    system sketch of that category receives the cosine similarity between
    that average and its own feature vector.

    Args:
        sketch_data: iterable of ``(id, sketch_path, gender, sketch_category)``
            rows describing the brand's images.
        brand_id: brand whose row is added or updated.

    Returns:
        The brand's row of the matrix, with shape ``(1, n_sketch)``.
    """
    # 1. Collect the brand's feature vectors per category.
    feats_by_category = defaultdict(list)
    for _id, sketch_path, gender, sketch_category in sketch_data:
        cat = getNewCategory(gender, sketch_category)
        # Compare against None explicitly: `cached or fallback` raises
        # ValueError when the cached feature is a multi-element ndarray.
        feat = BRAND_FEATURES.get(_id)
        if feat is None:
            feat = extract_feature_vector_from_resnet(sketch_path)
        feats_by_category[cat].append(feat)

    # 2. Build the iid -> column index lookup.
    sketch_list = sorted(sketch_to_iid.values())
    sketch_index = {iid: idx for idx, iid in enumerate(sketch_list)}
    n_sketch = len(sketch_list)

    # 3. Load the existing matrix, or start from an empty one.
    brand_matrix, brand_index_map = load_brand_matrix()

    # 4. Reuse the brand's row or append a new zero row.
    # NOTE(review): an existing row is not cleared, so categories the brand no
    # longer covers keep their previous similarities — confirm this is wanted.
    if brand_id in brand_index_map:
        row_idx = brand_index_map[brand_id]
    else:
        row_idx = brand_matrix.shape[0]
        brand_index_map[brand_id] = row_idx
        brand_matrix = np.vstack([
            brand_matrix,
            np.zeros((1, n_sketch), dtype=np.float32),
        ])

    # 5. Average the brand's vectors per category.
    brand_avg = {
        cat: np.mean(feats, axis=0)
        for cat, feats in feats_by_category.items()
        if feats
    }

    # 6. Fill in the similarities for sketches whose category the brand covers.
    for sketch_path, sys_vec in SYSTEM_FEATURES.items():
        iid = sketch_to_iid.get(sketch_path)
        if not iid or iid not in sketch_index:
            continue
        avg_vec = brand_avg.get(get_category_from_path(sketch_path))
        if avg_vec is not None:
            brand_matrix[row_idx, sketch_index[iid]] = cosine_similarity(avg_vec, sys_vec)

    # 7. Persist both artefacts.
    np.save(f"{RECOMMEND_PATH_PREFIX}brand_feature_matrix.npy", brand_matrix)
    np.save(f"{RECOMMEND_PATH_PREFIX}brand_index_map.npy", brand_index_map)

    # Return only this brand's row.
    return brand_matrix[row_idx:row_idx+1]
@router.get("/brand_dna_initialize/{brand_id}")
async def brand_dna_initialize(brand_id: int):
    """Build and persist the similarity row for one brand.

    Reads the brand's product images from the database and passes them to
    calculate_brand_matrix.

    Returns:
        ``{"success": True}`` on success; a JSON body with ``success=False``
        and HTTP status 500 when anything other than an HTTPException fails.

    NOTE(review): the database query and the matrix computation are blocking
    calls inside an async handler.
    """
    conn = None
    try:
        conn = pymysql.connect(**DB_CONFIG)
        # The context manager closes the cursor even if the query raises.
        with conn.cursor() as cursor:
            cursor.execute("""
SELECT id, img_url, gender, category
FROM product_image_attribute
WHERE library_id IN (
SELECT library_id
FROM brand_rel_library
WHERE brand_id = %s
)
""", (brand_id,))
            sketch_data = cursor.fetchall()
        # Compute and persist; any internal error propagates as an exception.
        calculate_brand_matrix(sketch_data, brand_id)
        return {"success": True}
    except HTTPException:
        # Already an explicit HTTP error: re-raise unchanged.
        raise
    except Exception as e:
        logger.error(f"品牌初始化失败 [{brand_id}]: {e}", exc_info=True)
        # Report the failure as JSON with a 500 status code.
        return JSONResponse(
            status_code=500,
            content={"success": False, "message": "品牌初始化失败"}
        )
    finally:
        if conn:
            conn.close()

View File

@@ -3,7 +3,10 @@ import logging
import sys
import time
from typing import List
import os
import json
import math
import random
import numpy as np
from apscheduler.schedulers.background import BackgroundScheduler
from apscheduler.triggers.cron import CronTrigger
@@ -31,18 +34,44 @@ async def startup_event():
scheduler.start()
logger.info("定时任务已启动")
def get_random_recommendations(category: str, num: int) -> List[str]:
    """Uniformly sample up to `num` sketches without replacement.

    Candidates are the iids registered for `category`; when the category is
    not present, every known iid is a candidate.
    """
    iid_to_sketch = matrix_data["iid_to_sketch"]
    every_iid = list(iid_to_sketch.keys())
    # Prefer the requested category.
    candidates = matrix_data["category_to_iids"].get(category, every_iid)
    # Never request more items than are available.
    take = min(num, len(candidates))
    picked = np.random.choice(candidates, size=take, replace=False)
    return [iid_to_sketch[iid] for iid in picked]
def softmax(scores):
    """Return the softmax of `scores` as a list of floats.

    The maximum is subtracted before exponentiating so that large scores do
    not overflow. An empty input yields an empty list.
    """
    scores = list(scores)  # also accepts generators and other iterables
    if not scores:
        return []
    max_score = max(scores)
    exp_scores = [math.exp(s - max_score) for s in scores]
    # sum_exp >= 1 because the maximum element contributes exp(0).
    sum_exp = sum(exp_scores)
    return [s / sum_exp for s in exp_scores]
@router.get("/recommend/{user_id}/{category}/{num_recommendations}", response_model=List[str])
async def get_recommendations(user_id: int, category: str, num_recommendations: int = 10):
def get_random_recommendations(category: str, num: int) -> List[str]:
    """Cold-start recommendations drawn from the preloaded heat scores.

    URLs of `category` are sampled without replacement with probabilities
    given by the softmax of their heat scores, so hot items are favoured and
    no URL is returned twice. When the category has no heat data, or the
    heat-based sampling fails, uniformly random sketches are returned instead.

    Args:
        category: category key, e.g. ``female_skirt``.
        num: maximum number of items to return.

    Returns:
        A list of at most `num` sketch URLs.
    """
    try:
        heat_dict = matrix_data.get("heat_data", {}).get(category) or {}  # {url: score}
        urls = list(heat_dict.keys())
        if urls:
            sample_size = max(0, min(num, len(urls)))
            probs = np.asarray(softmax(list(heat_dict.values())), dtype=np.float64)
            if np.count_nonzero(probs) < sample_size:
                # Softmax underflowed to 0 for low-heat items; sampling without
                # replacement needs enough non-zero entries, so add a tiny floor.
                probs = probs + 1e-12
                probs = probs / probs.sum()
            # random.choices samples WITH replacement and returned duplicates.
            picked = np.random.choice(len(urls), size=sample_size, replace=False, p=probs)
            return [urls[i] for i in picked]
    except Exception:
        # Do not fail the request because of bad heat data, but leave a trace.
        logger.warning("Heat-based recommendation failed for category %s; "
                       "falling back to random sampling", category, exc_info=True)

    # Fallback: uniformly random recommendation.
    all_iids = list(matrix_data["iid_to_sketch"].keys())
    category_iids = matrix_data["category_to_iids"].get(category, all_iids)
    sample_size = min(num, len(category_iids))
    sampled = np.random.choice(category_iids, size=sample_size, replace=False)
    return [matrix_data["iid_to_sketch"][iid] for iid in sampled]
@router.get("/recommend/{user_id}/{category}/{num_recommendations}/{brand_id}/{brand_scale}", response_model=List[str])
async def get_recommendations(user_id: int, category: str, brand_id: int, brand_scale: float, num_recommendations: int = 10):
"""
:param user_id: 4
:param category: female_skirt
@@ -95,7 +124,7 @@ async def get_recommendations(user_id: int, category: str, num_recommendations:
raw_feat_scores = matrix_data["feature_matrix"][user_idx_feature, valid_sketch_idxs_feature]
raw_feat_scores = (raw_feat_scores - np.min(raw_feat_scores)) / (
np.max(raw_feat_scores) - np.min(raw_feat_scores) + 1e-8)
processed_feat = raw_feat_scores * 0.3
processed_feat = raw_feat_scores
else:
processed_feat = np.array([])
@@ -104,7 +133,22 @@ async def get_recommendations(user_id: int, category: str, num_recommendations:
matrix_data["cached_valid_idxs"][cache_key] = valid_sketch_idxs_inter
# 合并分数
final_scores = processed_inter + processed_feat
if brand_id is not None:
if brand_id is not None:
brand_idx_feature = matrix_data["brand_index_map"].get(brand_id)
if brand_idx_feature is not None and valid_sketch_idxs_feature:
raw_brand_feat_scores = matrix_data["brand_feature_matrix"][
brand_idx_feature, valid_sketch_idxs_feature]
raw_brand_feat_scores = (raw_brand_feat_scores - np.min(raw_brand_feat_scores)) / (
np.max(raw_brand_feat_scores) - np.min(raw_brand_feat_scores) + 1e-8)
processed_brand_feat = raw_brand_feat_scores
final_scores = processed_inter + 0.3 * ((1 - brand_scale) * processed_feat + brand_scale * processed_brand_feat)
else:
final_scores = processed_inter + 0.3 * processed_feat
else:
final_scores = processed_inter + 0.3 * processed_feat
else:
final_scores = processed_inter + 0.3 * processed_feat
valid_sketch_idxs = matrix_data["cached_valid_idxs"][cache_key]
# 概率采样