Files
AiDA_Python/app/service/recommendation_system/precompute.py

558 lines
20 KiB
Python
Raw Normal View History

2025-12-29 10:52:33 +08:00
"""
预计算模块
包含数据库表结构优化Milvus集合创建系统图向量预计算初始用户偏好向量生成
"""
import logging
import math
import pymysql
import numpy as np
2026-01-12 09:49:07 +08:00
from datetime import datetime
2025-12-29 10:52:33 +08:00
from typing import List, Dict, Tuple, Optional
from collections import defaultdict
from app.service.recommendation_system.config import (
MYSQL_CONFIG, TABLE_USER_PREFERENCE_LOG, TABLE_SYS_FILE,
RECOMMENDATION_CONFIG, REDIS_KEY_USER_PREF_PREFIX
)
from app.service.recommendation_system.vector_utils import extract_feature_vector, normalize_vector, compute_weighted_average
from app.service.recommendation_system.milvus_client import (
create_collection, insert_vectors, query_vectors_by_paths
)
from app.service.utils.redis_utils import Redis
import json
logger = logging.getLogger(__name__)
def optimize_database_table():
"""
2026-01-12 09:49:07 +08:00
优化 user_preference 表结构
2025-12-29 10:52:33 +08:00
添加冗余字段和索引
"""
conn = None
try:
conn = pymysql.connect(**MYSQL_CONFIG)
cursor = conn.cursor()
# 1. 添加冗余字段
logger.info("添加冗余字段...")
alter_sqls = [
f"ALTER TABLE {TABLE_USER_PREFERENCE_LOG} ADD COLUMN category VARCHAR(100) COMMENT '类别lower(level3_type + \"_\" + level2_type)'",
f"ALTER TABLE {TABLE_USER_PREFERENCE_LOG} ADD COLUMN style VARCHAR(50) COMMENT '风格样式'",
f"ALTER TABLE {TABLE_USER_PREFERENCE_LOG} ADD COLUMN is_system_sketch TINYINT(1) DEFAULT 1 COMMENT '是否为系统图1-是0-用户图)'",
f"ALTER TABLE {TABLE_USER_PREFERENCE_LOG} ADD COLUMN sys_file_id BIGINT NULL COMMENT '系统文件ID'",
]
for sql in alter_sqls:
try:
cursor.execute(sql)
logger.info(f"执行成功: {sql[:50]}...")
except Exception as e:
if "Duplicate column name" in str(e):
logger.info(f"字段已存在,跳过: {sql[:50]}...")
else:
logger.warning(f"执行失败: {sql[:50]}... 错误: {e}")
# 2. 创建索引MySQL 不支持 IF NOT EXISTS需要先检查
logger.info("创建索引...")
index_definitions = [
("idx_account_category_time", ["account_id", "category", "data_time"]),
("idx_account_path", ["account_id", "path"]),
]
for index_name, columns in index_definitions:
try:
# 检查索引是否已存在
cursor.execute(f"""
SELECT COUNT(*)
FROM information_schema.statistics
WHERE table_schema = DATABASE()
AND table_name = '{TABLE_USER_PREFERENCE_LOG}'
AND index_name = '{index_name}'
""")
exists = cursor.fetchone()[0] > 0
if exists:
logger.info(f"索引已存在,跳过: {index_name}")
else:
# 创建索引
columns_str = ', '.join(columns)
create_sql = f"CREATE INDEX {index_name} ON {TABLE_USER_PREFERENCE_LOG}({columns_str})"
cursor.execute(create_sql)
logger.info(f"索引创建成功: {index_name}")
except Exception as e:
logger.warning(f"索引创建失败: {index_name} 错误: {e}")
conn.commit()
logger.info("数据库表结构优化完成")
except Exception as e:
logger.error(f"数据库表结构优化失败: {e}", exc_info=True)
if conn:
conn.rollback()
finally:
if conn:
conn.close()
def migrate_historical_data(batch_size: int = 1000):
"""
历史数据迁移批量更新冗余字段
Args:
batch_size: 每批处理数量
"""
conn = None
try:
conn = pymysql.connect(**MYSQL_CONFIG)
cursor = conn.cursor()
# 查询需要更新的记录数
cursor.execute(f"""
SELECT COUNT(*)
FROM {TABLE_USER_PREFERENCE_LOG} u
WHERE u.category IS NULL
""")
total_count = cursor.fetchone()[0]
logger.info(f"需要迁移的记录数: {total_count}")
if total_count == 0:
logger.info("无需迁移数据")
return
# 分批处理
offset = 0
processed = 0
while offset < total_count:
# 查询一批记录
cursor.execute(f"""
SELECT u.id, u.path
FROM {TABLE_USER_PREFERENCE_LOG} u
WHERE u.category IS NULL
LIMIT {batch_size} OFFSET {offset}
""")
records = cursor.fetchall()
if not records:
break
# 批量更新
for record_id, path in records:
# 查询 t_sys_file 表
cursor.execute(f"""
SELECT id, url, style, level3_type, level2_type, deprecated
FROM {TABLE_SYS_FILE}
WHERE url = %s
LIMIT 1
""", (path,))
sys_file = cursor.fetchone()
if sys_file:
# 系统图
sys_file_id, url, style, level3_type, level2_type, deprecated = sys_file
category = f"{level3_type.lower()}_{level2_type.lower()}"
cursor.execute(f"""
UPDATE {TABLE_USER_PREFERENCE_LOG}
SET category = %s,
style = %s,
is_system_sketch = 1,
sys_file_id = %s
WHERE id = %s
""", (category, style, sys_file_id, record_id))
else:
# 用户图
cursor.execute(f"""
UPDATE {TABLE_USER_PREFERENCE_LOG}
SET is_system_sketch = 0,
category = NULL,
style = NULL,
sys_file_id = NULL
WHERE id = %s
""", (record_id,))
conn.commit()
processed += len(records)
offset += batch_size
logger.info(f"已迁移 {processed}/{total_count} 条记录")
logger.info("历史数据迁移完成")
except Exception as e:
logger.error(f"历史数据迁移失败: {e}", exc_info=True)
if conn:
conn.rollback()
finally:
if conn:
conn.close()
def precompute_system_sketch_vectors(batch_size: int = 1000, retry_times: int = 3):
"""
系统图向量预计算与导入
Args:
batch_size: 每批处理数量
retry_times: 失败重试次数
"""
conn = None
try:
conn = pymysql.connect(**MYSQL_CONFIG)
cursor = conn.cursor()
# 1. 数据筛选
logger.info("查询系统图数据...")
cursor.execute(f"""
SELECT id, url, style, level3_type, level2_type, deprecated
FROM {TABLE_SYS_FILE}
WHERE level1_type = 'Images'
AND style IS NOT NULL
AND style != ''
AND deprecated != 1
""")
records = cursor.fetchall()
logger.info(f"找到 {len(records)} 条系统图记录")
if not records:
logger.warning("没有找到系统图数据")
return
# 2. 批量处理
failed_records = []
batch_data = []
for idx, (sys_file_id, url, style, level3_type, level2_type, deprecated) in enumerate(records, 1):
try:
# 计算 category
category = f"{level3_type.lower()}_{level2_type.lower()}"
# 提取特征向量
feature_vector = extract_feature_vector(url)
# 检查向量是否有效
if np.all(feature_vector == 0):
logger.warning(f"向量提取失败,跳过: {url}")
failed_records.append((sys_file_id, url))
continue
# 准备数据
data_item = {
"path": url,
"sys_file_id": sys_file_id,
"style": style,
"category": category,
"is_system_sketch": 1,
"deprecated": deprecated if deprecated else 0,
"feature_vector": feature_vector.tolist()
}
batch_data.append(data_item)
# 批量写入
if len(batch_data) >= batch_size:
try:
insert_vectors(batch_data)
batch_data = []
logger.info(f"已处理 {idx}/{len(records)} 条记录")
except Exception as e:
logger.error(f"批量写入失败: {e}")
failed_records.extend([(item["sys_file_id"], item["path"]) for item in batch_data])
batch_data = []
except Exception as e:
logger.error(f"处理记录失败 [{url}]: {e}")
failed_records.append((sys_file_id, url))
# 写入剩余数据
if batch_data:
try:
insert_vectors(batch_data)
except Exception as e:
logger.error(f"写入剩余数据失败: {e}")
failed_records.extend([(item["sys_file_id"], item["path"]) for item in batch_data])
# 3. 重试失败记录
if failed_records and retry_times > 0:
logger.info(f"重试 {len(failed_records)} 条失败记录...")
for retry in range(retry_times):
retry_failed = []
for sys_file_id, url in failed_records:
try:
category = f"{level3_type.lower()}_{level2_type.lower()}"
feature_vector = extract_feature_vector(url)
if not np.all(feature_vector == 0):
data_item = {
"path": url,
"sys_file_id": sys_file_id,
"style": style,
"category": category,
"is_system_sketch": 1,
"deprecated": 0,
"feature_vector": feature_vector.tolist()
}
insert_vectors([data_item])
else:
retry_failed.append((sys_file_id, url))
except Exception as e:
logger.error(f"重试失败 [{url}]: {e}")
retry_failed.append((sys_file_id, url))
failed_records = retry_failed
if not failed_records:
break
if failed_records:
logger.warning(f"仍有 {len(failed_records)} 条记录处理失败")
logger.info("系统图向量预计算完成")
except Exception as e:
logger.error(f"系统图向量预计算失败: {e}", exc_info=True)
finally:
if conn:
conn.close()
def compute_user_preference_vector(
account_id: int,
category: str,
2026-01-12 13:34:56 +08:00
conn: Optional[pymysql.connections.Connection] = None,
max_date: Optional[datetime] = None
2025-12-29 10:52:33 +08:00
) -> Optional[np.ndarray]:
"""
计算用户偏好向量
Args:
account_id: 用户ID
category: 类别
conn: 数据库连接可选
max_date: 最大日期可选用于评估时只使用训练集数据
Returns:
用户偏好向量2048失败返回 None
"""
from datetime import datetime
should_close = False
if conn is None:
conn = pymysql.connect(**MYSQL_CONFIG)
should_close = True
try:
cursor = conn.cursor()
# 1. 获取点赞记录如果指定了max_date只查询该日期之前的数据
if max_date:
cursor.execute(f"""
SELECT path, data_time
FROM {TABLE_USER_PREFERENCE_LOG}
WHERE account_id = %s AND category = %s AND style is not null
AND data_time < %s
ORDER BY data_time DESC
""", (account_id, category, max_date))
else:
cursor.execute(f"""
SELECT path, data_time
FROM {TABLE_USER_PREFERENCE_LOG}
WHERE account_id = %s AND category = %s AND style is not null
ORDER BY data_time DESC
""", (account_id, category))
like_records = cursor.fetchall()
if not like_records:
return None
# 2. 批量查询点赞次数如果指定了max_date只统计该日期之前的点赞
paths = [r[0] for r in like_records]
if not paths:
return None
placeholders = ','.join(['%s'] * len(paths))
if max_date:
cursor.execute(f"""
SELECT path, COUNT(*) as like_count
FROM {TABLE_USER_PREFERENCE_LOG}
WHERE account_id = %s AND category = %s AND path IN ({placeholders})
AND data_time < %s
GROUP BY path
""", (account_id, category) + tuple(paths) + (max_date,))
else:
cursor.execute(f"""
SELECT path, COUNT(*) as like_count
FROM {TABLE_USER_PREFERENCE_LOG}
WHERE account_id = %s AND category = %s AND path IN ({placeholders})
GROUP BY path
""", (account_id, category) + tuple(paths))
like_counts = {row[0]: row[1] for row in cursor.fetchall()}
# 3. 批量获取向量
vectors_dict = query_vectors_by_paths(paths)
# 处理查询不到的 path用户图或异常情况
missing_paths = [p for p in paths if p not in vectors_dict]
if missing_paths:
logger.info(f"用户 {account_id} 类别 {category}{len(missing_paths)} 个 path 需要实时计算向量")
# 目前未有非系统图向量,跳过
# 这里可以实时计算并写入 Milvus但为了简化先跳过
# 实际实现中应该调用 vector_utils.extract_feature_vector 并写入 Milvus
# 4. 计算权重并加权平均
vectors = []
weights = []
K_half = RECOMMENDATION_CONFIG["K_half"]
for k, (path, data_time) in enumerate(like_records, 1):
if path not in vectors_dict:
continue
vector_data = vectors_dict[path]
feature_vector = np.array(vector_data["feature_vector"])
# 时间衰减权重
d_k = 0.5 ** (k / K_half)
# 点赞次数权重
like_count = like_counts.get(path, 1)
p_i = 1 + math.log(1 + like_count)
# 综合权重
2026-01-12 09:49:07 +08:00
w_i = d_k * p_i
# w_i = p_i
2025-12-29 10:52:33 +08:00
vectors.append(feature_vector)
weights.append(w_i)
if not vectors:
return None
# 5. 计算加权平均并做 L2 归一化IP≈cosine
preference_vector = compute_weighted_average(vectors, weights)
preference_vector = normalize_vector(preference_vector)
return preference_vector
except Exception as e:
logger.error(f"计算用户偏好向量失败 [user={account_id}, category={category}]: {e}", exc_info=True)
return None
finally:
if should_close and conn:
conn.close()
def generate_initial_user_preference_vectors(batch_size: int = 100):
"""
初始用户偏好向量生成
Args:
batch_size: 每批处理用户数
"""
conn = None
try:
conn = pymysql.connect(**MYSQL_CONFIG)
cursor = conn.cursor()
# 1. 扫描历史数据
logger.info("扫描用户和类别组合...")
cursor.execute(f"""
SELECT DISTINCT account_id, category
FROM {TABLE_USER_PREFERENCE_LOG}
WHERE category IS NOT NULL
AND style IS NOT NULL
""")
user_categories = cursor.fetchall()
logger.info(f"找到 {len(user_categories)} 个用户-类别组合")
if not user_categories:
logger.warning("没有找到用户-类别组合")
return
# 2. 批量处理
processed = 0
failed = 0
for account_id, category in user_categories:
try:
# 计算偏好向量
preference_vector = compute_user_preference_vector(account_id, category, conn)
if preference_vector is not None:
# 写入 Redis
key = f"{REDIS_KEY_USER_PREF_PREFIX}:{account_id}:{category}"
# 序列化向量(使用 JSON
vector_json = json.dumps(preference_vector.tolist())
Redis.write(
key=key,
value=vector_json,
expire=RECOMMENDATION_CONFIG["redis_expire_seconds"]
)
processed += 1
else:
failed += 1
if (processed + failed) % batch_size == 0:
logger.info(f"已处理 {processed + failed}/{len(user_categories)} 个组合,成功: {processed}, 失败: {failed}")
except Exception as e:
logger.error(f"处理失败 [user={account_id}, category={category}]: {e}")
failed += 1
logger.info(f"初始用户偏好向量生成完成,成功: {processed}, 失败: {failed}")
except Exception as e:
logger.error(f"初始用户偏好向量生成失败: {e}", exc_info=True)
finally:
if conn:
conn.close()
def run_precompute():
"""
运行所有预计算任务
"""
logger.info("=" * 50)
logger.info("开始预计算任务")
logger.info("=" * 50)
# 1. 优化数据库表结构
2026-01-12 13:03:58 +08:00
# logger.info("\n[1/5] 优化数据库表结构...")
# optimize_database_table()
2025-12-29 10:52:33 +08:00
# # 2. 创建 Milvus 集合
# logger.info("\n[2/5] 创建 Milvus 集合...")
# create_collection()
# 3. 历史数据迁移
2026-01-12 13:03:58 +08:00
# logger.info("\n[3/5] 历史数据迁移...")
# migrate_historical_data()
2025-12-29 10:52:33 +08:00
# # 4. 系统图向量预计算
# logger.info("\n[4/5] 系统图向量预计算...")
# precompute_system_sketch_vectors()
# 5. 初始用户偏好向量生成
logger.info("\n[5/5] 初始用户偏好向量生成...")
generate_initial_user_preference_vectors()
logger.info("=" * 50)
logger.info("预计算任务完成")
logger.info("=" * 50)
if __name__ == "__main__":
2026-01-12 13:34:56 +08:00
# # 1. 优化数据库表结构
# logger.info("\n[1/5] 优化数据库表结构...")
# optimize_database_table()
#
# # 3. 历史数据迁移
# logger.info("\n[3/5] 历史数据迁移...")
# migrate_historical_data()
2025-12-29 10:52:33 +08:00
# 5. 初始用户偏好向量生成
logger.info("\n[5/5] 初始用户偏好向量生成...")
generate_initial_user_preference_vectors()