# AiDA_Python/app/service/recommendation_system/precompute.py
"""
预计算模块
包含数据库表结构优化、Milvus集合创建、系统图向量预计算、初始用户偏好向量生成
"""
import json
import logging
import math
import pymysql
import numpy as np
from datetime import datetime
from typing import Optional
from app.service.recommendation_system.config import (
MYSQL_CONFIG, TABLE_USER_PREFERENCE_LOG, TABLE_SYS_FILE,
RECOMMENDATION_CONFIG, REDIS_KEY_USER_PREF_PREFIX
)
from app.service.recommendation_system.vector_utils import extract_feature_vector, normalize_vector, compute_weighted_average
from app.service.recommendation_system.milvus_client import (
create_collection, insert_vectors, query_vectors_by_paths
)
from app.service.utils.redis_utils import Redis
logger = logging.getLogger(__name__)
def optimize_database_table():
"""
优化 user_preference 表结构
添加冗余字段和索引
"""
conn = None
try:
conn = pymysql.connect(**MYSQL_CONFIG)
cursor = conn.cursor()
        # 1. Add denormalised columns
        logger.info("Adding denormalised columns...")
        alter_sqls = [
            f"ALTER TABLE {TABLE_USER_PREFERENCE_LOG} ADD COLUMN category VARCHAR(100) COMMENT 'Category: lower(level3_type + \"_\" + level2_type)'",
            f"ALTER TABLE {TABLE_USER_PREFERENCE_LOG} ADD COLUMN style VARCHAR(50) COMMENT 'Style'",
            f"ALTER TABLE {TABLE_USER_PREFERENCE_LOG} ADD COLUMN is_system_sketch TINYINT(1) DEFAULT 1 COMMENT 'System sketch flag (1 = system, 0 = user)'",
            f"ALTER TABLE {TABLE_USER_PREFERENCE_LOG} ADD COLUMN sys_file_id BIGINT NULL COMMENT 'System file ID'",
]
for sql in alter_sqls:
            try:
                cursor.execute(sql)
                logger.info(f"Executed: {sql[:50]}...")
            except Exception as e:
                if "Duplicate column name" in str(e):
                    logger.info(f"Column already exists, skipping: {sql[:50]}...")
                else:
                    logger.warning(f"Failed to execute: {sql[:50]}... error: {e}")
        # 2. Create indexes (MySQL has no CREATE INDEX IF NOT EXISTS, so check first)
        logger.info("Creating indexes...")
index_definitions = [
("idx_account_category_time", ["account_id", "category", "data_time"]),
("idx_account_path", ["account_id", "path"]),
]
for index_name, columns in index_definitions:
try:
                # Check whether the index already exists
cursor.execute(f"""
SELECT COUNT(*)
FROM information_schema.statistics
WHERE table_schema = DATABASE()
AND table_name = '{TABLE_USER_PREFERENCE_LOG}'
AND index_name = '{index_name}'
""")
exists = cursor.fetchone()[0] > 0
                if exists:
                    logger.info(f"Index already exists, skipping: {index_name}")
                else:
                    # Create the index
                    columns_str = ', '.join(columns)
                    create_sql = f"CREATE INDEX {index_name} ON {TABLE_USER_PREFERENCE_LOG}({columns_str})"
                    cursor.execute(create_sql)
                    logger.info(f"Index created: {index_name}")
            except Exception as e:
                logger.warning(f"Failed to create index {index_name}: {e}")
conn.commit()
logger.info("数据库表结构优化完成")
except Exception as e:
logger.error(f"数据库表结构优化失败: {e}", exc_info=True)
if conn:
conn.rollback()
finally:
if conn:
conn.close()
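# Verification sketch (plain MySQL, outside this module; the literal table name
# is assumed from the docstring — the code resolves it via TABLE_USER_PREFERENCE_LOG):
#   SHOW COLUMNS FROM user_preference LIKE 'category';
#   SHOW INDEX FROM user_preference WHERE Key_name = 'idx_account_category_time';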
def migrate_historical_data(batch_size: int = 1000):
"""
历史数据迁移:批量更新冗余字段
Args:
batch_size: 每批处理数量
"""
conn = None
try:
conn = pymysql.connect(**MYSQL_CONFIG)
cursor = conn.cursor()
        # Count the rows that still need updating
        cursor.execute(f"""
            SELECT COUNT(*)
            FROM {TABLE_USER_PREFERENCE_LOG} u
            WHERE u.category IS NULL
        """)
        total_count = cursor.fetchone()[0]
        logger.info(f"Rows to migrate: {total_count}")
        if total_count == 0:
            logger.info("Nothing to migrate")
return
        # Process in batches. Keyset pagination on id replaces LIMIT/OFFSET:
        # the filtered set shrinks as rows are updated, so an advancing OFFSET
        # would skip rows, and user-sketch rows (which keep category NULL)
        # would be revisited on later passes.
        last_id = 0
        processed = 0
        while True:
            # Fetch one batch of not-yet-migrated rows
            cursor.execute(f"""
                SELECT u.id, u.path
                FROM {TABLE_USER_PREFERENCE_LOG} u
                WHERE u.category IS NULL AND u.id > %s
                ORDER BY u.id
                LIMIT {batch_size}
            """, (last_id,))
            records = cursor.fetchall()
            if not records:
                break
            # Update each row in the batch
            for record_id, path in records:
                # Look up the matching t_sys_file row by url
cursor.execute(f"""
SELECT id, url, style, level3_type, level2_type, deprecated
FROM {TABLE_SYS_FILE}
WHERE url = %s
LIMIT 1
""", (path,))
sys_file = cursor.fetchone()
if sys_file:
                    # System sketch: fill denormalised columns from t_sys_file
sys_file_id, url, style, level3_type, level2_type, deprecated = sys_file
category = f"{level3_type.lower()}_{level2_type.lower()}"
cursor.execute(f"""
UPDATE {TABLE_USER_PREFERENCE_LOG}
SET category = %s,
style = %s,
is_system_sketch = 1,
sys_file_id = %s
WHERE id = %s
""", (category, style, sys_file_id, record_id))
else:
                    # User sketch: no matching t_sys_file row
cursor.execute(f"""
UPDATE {TABLE_USER_PREFERENCE_LOG}
SET is_system_sketch = 0,
category = NULL,
style = NULL,
sys_file_id = NULL
WHERE id = %s
""", (record_id,))
            conn.commit()
            processed += len(records)
            last_id = records[-1][0]
            logger.info(f"Migrated {processed}/{total_count} rows")
        logger.info("Historical data migration complete")
except Exception as e:
logger.error(f"历史数据迁移失败: {e}", exc_info=True)
if conn:
conn.rollback()
finally:
if conn:
conn.close()
def precompute_system_sketch_vectors(batch_size: int = 1000, retry_times: int = 3):
"""
系统图向量预计算与导入
Args:
batch_size: 每批处理数量
retry_times: 失败重试次数
"""
conn = None
try:
conn = pymysql.connect(**MYSQL_CONFIG)
cursor = conn.cursor()
        # 1. Select candidate records
        logger.info("Querying system sketch records...")
cursor.execute(f"""
SELECT id, url, style, level3_type, level2_type, deprecated
FROM {TABLE_SYS_FILE}
WHERE level1_type = 'Images'
AND style IS NOT NULL
AND style != ''
AND deprecated != 1
""")
records = cursor.fetchall()
logger.info(f"找到 {len(records)} 条系统图记录")
if not records:
logger.warning("没有找到系统图数据")
return
        # 2. Process in batches
failed_records = []
batch_data = []
for idx, (sys_file_id, url, style, level3_type, level2_type, deprecated) in enumerate(records, 1):
try:
                # Compute the category key
                category = f"{level3_type.lower()}_{level2_type.lower()}"
                # Extract the feature vector
                feature_vector = extract_feature_vector(url)
                # An all-zero vector signals extraction failure
                if np.all(feature_vector == 0):
                    logger.warning(f"Vector extraction failed, skipping: {url}")
                    failed_records.append((sys_file_id, url))
                    continue
                # Assemble the Milvus record
data_item = {
"path": url,
"sys_file_id": sys_file_id,
"style": style,
"category": category,
"is_system_sketch": 1,
"deprecated": deprecated if deprecated else 0,
"feature_vector": feature_vector.tolist()
}
batch_data.append(data_item)
                # Flush a full batch to Milvus
                if len(batch_data) >= batch_size:
                    try:
                        insert_vectors(batch_data)
                        batch_data = []
                        logger.info(f"Processed {idx}/{len(records)} records")
                    except Exception as e:
                        logger.error(f"Batch insert failed: {e}")
                        failed_records.extend([(item["sys_file_id"], item["path"]) for item in batch_data])
                        batch_data = []
except Exception as e:
logger.error(f"处理记录失败 [{url}]: {e}")
failed_records.append((sys_file_id, url))
# 写入剩余数据
if batch_data:
try:
insert_vectors(batch_data)
except Exception as e:
logger.error(f"写入剩余数据失败: {e}")
failed_records.extend([(item["sys_file_id"], item["path"]) for item in batch_data])
        # 3. Retry failed records
        if failed_records and retry_times > 0:
            logger.info(f"Retrying {len(failed_records)} failed records...")
            for _ in range(retry_times):
                retry_failed = []
                for sys_file_id, url in failed_records:
                    try:
                        # Re-read style/category for this record; the main
                        # pass's loop variables only held the last record's values
                        cursor.execute(f"""
                            SELECT style, level3_type, level2_type
                            FROM {TABLE_SYS_FILE}
                            WHERE id = %s
                        """, (sys_file_id,))
                        row = cursor.fetchone()
                        if row is None:
                            retry_failed.append((sys_file_id, url))
                            continue
                        style, level3_type, level2_type = row
                        category = f"{level3_type.lower()}_{level2_type.lower()}"
                        feature_vector = extract_feature_vector(url)
                        if not np.all(feature_vector == 0):
                            data_item = {
                                "path": url,
                                "sys_file_id": sys_file_id,
                                "style": style,
                                "category": category,
                                "is_system_sketch": 1,
                                "deprecated": 0,
                                "feature_vector": feature_vector.tolist()
                            }
                            insert_vectors([data_item])
                        else:
                            retry_failed.append((sys_file_id, url))
                    except Exception as e:
                        logger.error(f"Retry failed [{url}]: {e}")
                        retry_failed.append((sys_file_id, url))
                failed_records = retry_failed
                if not failed_records:
                    break
        if failed_records:
            logger.warning(f"{len(failed_records)} records still failed after retries")
        logger.info("System sketch vector precomputation complete")
    except Exception as e:
        logger.error(f"System sketch vector precomputation failed: {e}", exc_info=True)
finally:
if conn:
conn.close()
def compute_user_preference_vector(
account_id: int,
category: str,
conn: Optional[pymysql.connections.Connection] = None,
max_date: Optional[datetime] = None
) -> Optional[np.ndarray]:
"""
计算用户偏好向量
Args:
account_id: 用户ID
category: 类别
conn: 数据库连接(可选)
max_date: 最大日期(可选,用于评估时只使用训练集数据)
Returns:
用户偏好向量2048维失败返回 None
"""
should_close = False
if conn is None:
conn = pymysql.connect(**MYSQL_CONFIG)
should_close = True
try:
cursor = conn.cursor()
        # 1. Fetch like records (most recent first); with max_date, only rows before that date
if max_date:
cursor.execute(f"""
SELECT path, data_time
FROM {TABLE_USER_PREFERENCE_LOG}
WHERE account_id = %s AND category = %s AND style is not null
AND data_time < %s
ORDER BY data_time DESC
""", (account_id, category, max_date))
else:
cursor.execute(f"""
SELECT path, data_time
FROM {TABLE_USER_PREFERENCE_LOG}
WHERE account_id = %s AND category = %s AND style is not null
ORDER BY data_time DESC
""", (account_id, category))
like_records = cursor.fetchall()
if not like_records:
return None
        # 2. Batch-count likes per path; with max_date, only count likes before that date
paths = [r[0] for r in like_records]
if not paths:
return None
placeholders = ','.join(['%s'] * len(paths))
if max_date:
cursor.execute(f"""
SELECT path, COUNT(*) as like_count
FROM {TABLE_USER_PREFERENCE_LOG}
WHERE account_id = %s AND category = %s AND path IN ({placeholders})
AND data_time < %s
GROUP BY path
""", (account_id, category) + tuple(paths) + (max_date,))
else:
cursor.execute(f"""
SELECT path, COUNT(*) as like_count
FROM {TABLE_USER_PREFERENCE_LOG}
WHERE account_id = %s AND category = %s AND path IN ({placeholders})
GROUP BY path
""", (account_id, category) + tuple(paths))
like_counts = {row[0]: row[1] for row in cursor.fetchall()}
        # 3. Batch-fetch vectors from Milvus
        vectors_dict = query_vectors_by_paths(paths)
        # Handle paths with no stored vector (user sketches or anomalies)
        missing_paths = [p for p in paths if p not in vectors_dict]
        if missing_paths:
            logger.info(f"User {account_id}, category {category}: {len(missing_paths)} paths need on-the-fly vectors")
            # Non-system-sketch vectors are not stored yet, so skip them.
            # They could be computed on the fly and written to Milvus; for
            # simplicity that is deferred. A real implementation would call
            # vector_utils.extract_feature_vector and insert the result.
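            # A minimal sketch of that fallback (kept disabled to preserve the
            # current behaviour). Assumptions: insert_vectors accepts the same
            # dict schema used in precompute_system_sketch_vectors, and a
            # sys_file_id of 0 is acceptable for non-system sketches.
            #
            # for p in missing_paths:
            #     vec = extract_feature_vector(p)
            #     if not np.all(vec == 0):
            #         insert_vectors([{
            #             "path": p, "sys_file_id": 0, "style": "",
            #             "category": category, "is_system_sketch": 0,
            #             "deprecated": 0, "feature_vector": vec.tolist(),
            #         }])
            #         vectors_dict[p] = {"feature_vector": vec.tolist()}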
        # 4. Compute weights and take the weighted average
vectors = []
weights = []
K_half = RECOMMENDATION_CONFIG["K_half"]
for k, (path, data_time) in enumerate(like_records, 1):
if path not in vectors_dict:
continue
vector_data = vectors_dict[path]
feature_vector = np.array(vector_data["feature_vector"])
            # Recency decay weight (halves every K_half ranks)
            d_k = 0.5 ** (k / K_half)
            # Like-count weight
            like_count = like_counts.get(path, 1)
            p_i = 1 + math.log(1 + like_count)
            # Combined weight
            w_i = d_k * p_i
            # w_i = p_i
vectors.append(feature_vector)
weights.append(w_i)
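            # Worked example (illustrative numbers, not the configured value):
            # with K_half = 50, the 50th most-recent like has d_k = 0.5; an
            # item liked 3 times has p_i = 1 + ln(4) ≈ 2.386, so w_i ≈ 1.193.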
if not vectors:
return None
        # 5. Weighted average, then L2-normalise (so IP ≈ cosine similarity)
preference_vector = compute_weighted_average(vectors, weights)
preference_vector = normalize_vector(preference_vector)
return preference_vector
except Exception as e:
logger.error(f"计算用户偏好向量失败 [user={account_id}, category={category}]: {e}", exc_info=True)
return None
finally:
if should_close and conn:
conn.close()
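# Usage sketch (hypothetical ID and category; assumes MySQL and Milvus are reachable):
#   vec = compute_user_preference_vector(account_id=123, category="plan_bedroom")
#   if vec is not None:
#       assert abs(np.linalg.norm(vec) - 1.0) < 1e-6  # vector is L2-normalised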
def generate_initial_user_preference_vectors(batch_size: int = 100):
"""
初始用户偏好向量生成
Args:
batch_size: 每批处理用户数
"""
conn = None
try:
conn = pymysql.connect(**MYSQL_CONFIG)
cursor = conn.cursor()
        # 1. Scan historical data for distinct (user, category) pairs
        logger.info("Scanning user/category combinations...")
cursor.execute(f"""
SELECT DISTINCT account_id, category
FROM {TABLE_USER_PREFERENCE_LOG}
WHERE category IS NOT NULL
AND style IS NOT NULL
""")
user_categories = cursor.fetchall()
logger.info(f"找到 {len(user_categories)} 个用户-类别组合")
if not user_categories:
logger.warning("没有找到用户-类别组合")
return
        # 2. Process each combination
processed = 0
failed = 0
for account_id, category in user_categories:
try:
                # Compute the preference vector
preference_vector = compute_user_preference_vector(account_id, category, conn)
if preference_vector is not None:
                    # Write to Redis
                    key = f"{REDIS_KEY_USER_PREF_PREFIX}:{account_id}:{category}"
                    # Serialise the vector as JSON
vector_json = json.dumps(preference_vector.tolist())
Redis.write(
key=key,
value=vector_json,
expire=RECOMMENDATION_CONFIG["redis_expire_seconds"]
)
processed += 1
else:
failed += 1
if (processed + failed) % batch_size == 0:
logger.info(f"已处理 {processed + failed}/{len(user_categories)} 个组合,成功: {processed}, 失败: {failed}")
except Exception as e:
logger.error(f"处理失败 [user={account_id}, category={category}]: {e}")
failed += 1
logger.info(f"初始用户偏好向量生成完成,成功: {processed}, 失败: {failed}")
except Exception as e:
logger.error(f"初始用户偏好向量生成失败: {e}", exc_info=True)
finally:
if conn:
conn.close()
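# Read-back sketch for the vectors written above (hypothetical: assumes the
# Redis wrapper exposes a read counterpart to Redis.write):
#   raw = Redis.read(key=f"{REDIS_KEY_USER_PREF_PREFIX}:{account_id}:{category}")
#   preference_vector = np.array(json.loads(raw)) if raw else None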
def run_precompute():
"""
运行所有预计算任务
"""
logger.info("=" * 50)
logger.info("开始预计算任务")
logger.info("=" * 50)
# 1. 优化数据库表结构
# logger.info("\n[1/5] 优化数据库表结构...")
# optimize_database_table()
# # 2. 创建 Milvus 集合
# logger.info("\n[2/5] 创建 Milvus 集合...")
# create_collection()
# 3. 历史数据迁移
# logger.info("\n[3/5] 历史数据迁移...")
# migrate_historical_data()
# # 4. 系统图向量预计算
# logger.info("\n[4/5] 系统图向量预计算...")
# precompute_system_sketch_vectors()
# 5. 初始用户偏好向量生成
logger.info("\n[5/5] 初始用户偏好向量生成...")
generate_initial_user_preference_vectors()
logger.info("=" * 50)
logger.info("预计算任务完成")
logger.info("=" * 50)
if __name__ == "__main__":
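    # Environment note (assumption): logger.info output is only visible when
    # logging has been configured by the host app; for a standalone run,
    # something like logging.basicConfig(level=logging.INFO) is needed first.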
    # # 1. Optimise the database schema
    # logger.info("\n[1/5] Optimising database schema...")
    # optimize_database_table()
    #
    # # 3. Migrate historical data
    # logger.info("\n[3/5] Migrating historical data...")
    # migrate_historical_data()
    # 5. Generate initial user preference vectors
    logger.info("\n[5/5] Generating initial user preference vectors...")
    generate_initial_user_preference_vectors()