All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
558 lines
20 KiB
Python
558 lines
20 KiB
Python
"""
|
||
预计算模块
|
||
包含:数据库表结构优化、Milvus集合创建、系统图向量预计算、初始用户偏好向量生成
|
||
"""
|
||
import logging
|
||
import math
|
||
import pymysql
|
||
import numpy as np
|
||
from datetime import datetime
|
||
from typing import List, Dict, Tuple, Optional
|
||
from collections import defaultdict
|
||
|
||
from app.service.recommendation_system.config import (
|
||
MYSQL_CONFIG, TABLE_USER_PREFERENCE_LOG, TABLE_SYS_FILE,
|
||
RECOMMENDATION_CONFIG, REDIS_KEY_USER_PREF_PREFIX
|
||
)
|
||
from app.service.recommendation_system.vector_utils import extract_feature_vector, normalize_vector, compute_weighted_average
|
||
from app.service.recommendation_system.milvus_client import (
|
||
create_collection, insert_vectors, query_vectors_by_paths
|
||
)
|
||
from app.service.utils.redis_utils import Redis
|
||
import json
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
def optimize_database_table():
|
||
"""
|
||
优化 user_preference 表结构
|
||
添加冗余字段和索引
|
||
"""
|
||
conn = None
|
||
try:
|
||
conn = pymysql.connect(**MYSQL_CONFIG)
|
||
cursor = conn.cursor()
|
||
|
||
# 1. 添加冗余字段
|
||
logger.info("添加冗余字段...")
|
||
alter_sqls = [
|
||
f"ALTER TABLE {TABLE_USER_PREFERENCE_LOG} ADD COLUMN category VARCHAR(100) COMMENT '类别:lower(level3_type + \"_\" + level2_type)'",
|
||
f"ALTER TABLE {TABLE_USER_PREFERENCE_LOG} ADD COLUMN style VARCHAR(50) COMMENT '风格样式'",
|
||
f"ALTER TABLE {TABLE_USER_PREFERENCE_LOG} ADD COLUMN is_system_sketch TINYINT(1) DEFAULT 1 COMMENT '是否为系统图(1-是,0-用户图)'",
|
||
f"ALTER TABLE {TABLE_USER_PREFERENCE_LOG} ADD COLUMN sys_file_id BIGINT NULL COMMENT '系统文件ID'",
|
||
]
|
||
|
||
for sql in alter_sqls:
|
||
try:
|
||
cursor.execute(sql)
|
||
logger.info(f"执行成功: {sql[:50]}...")
|
||
except Exception as e:
|
||
if "Duplicate column name" in str(e):
|
||
logger.info(f"字段已存在,跳过: {sql[:50]}...")
|
||
else:
|
||
logger.warning(f"执行失败: {sql[:50]}... 错误: {e}")
|
||
|
||
# 2. 创建索引(MySQL 不支持 IF NOT EXISTS,需要先检查)
|
||
logger.info("创建索引...")
|
||
index_definitions = [
|
||
("idx_account_category_time", ["account_id", "category", "data_time"]),
|
||
("idx_account_path", ["account_id", "path"]),
|
||
]
|
||
|
||
for index_name, columns in index_definitions:
|
||
try:
|
||
# 检查索引是否已存在
|
||
cursor.execute(f"""
|
||
SELECT COUNT(*)
|
||
FROM information_schema.statistics
|
||
WHERE table_schema = DATABASE()
|
||
AND table_name = '{TABLE_USER_PREFERENCE_LOG}'
|
||
AND index_name = '{index_name}'
|
||
""")
|
||
exists = cursor.fetchone()[0] > 0
|
||
|
||
if exists:
|
||
logger.info(f"索引已存在,跳过: {index_name}")
|
||
else:
|
||
# 创建索引
|
||
columns_str = ', '.join(columns)
|
||
create_sql = f"CREATE INDEX {index_name} ON {TABLE_USER_PREFERENCE_LOG}({columns_str})"
|
||
cursor.execute(create_sql)
|
||
logger.info(f"索引创建成功: {index_name}")
|
||
except Exception as e:
|
||
logger.warning(f"索引创建失败: {index_name} 错误: {e}")
|
||
|
||
conn.commit()
|
||
logger.info("数据库表结构优化完成")
|
||
|
||
except Exception as e:
|
||
logger.error(f"数据库表结构优化失败: {e}", exc_info=True)
|
||
if conn:
|
||
conn.rollback()
|
||
finally:
|
||
if conn:
|
||
conn.close()
|
||
|
||
|
||
def migrate_historical_data(batch_size: int = 1000):
|
||
"""
|
||
历史数据迁移:批量更新冗余字段
|
||
|
||
Args:
|
||
batch_size: 每批处理数量
|
||
"""
|
||
conn = None
|
||
try:
|
||
conn = pymysql.connect(**MYSQL_CONFIG)
|
||
cursor = conn.cursor()
|
||
|
||
# 查询需要更新的记录数
|
||
cursor.execute(f"""
|
||
SELECT COUNT(*)
|
||
FROM {TABLE_USER_PREFERENCE_LOG} u
|
||
WHERE u.category IS NULL
|
||
""")
|
||
total_count = cursor.fetchone()[0]
|
||
logger.info(f"需要迁移的记录数: {total_count}")
|
||
|
||
if total_count == 0:
|
||
logger.info("无需迁移数据")
|
||
return
|
||
|
||
# 分批处理
|
||
offset = 0
|
||
processed = 0
|
||
|
||
while offset < total_count:
|
||
# 查询一批记录
|
||
cursor.execute(f"""
|
||
SELECT u.id, u.path
|
||
FROM {TABLE_USER_PREFERENCE_LOG} u
|
||
WHERE u.category IS NULL
|
||
LIMIT {batch_size} OFFSET {offset}
|
||
""")
|
||
records = cursor.fetchall()
|
||
|
||
if not records:
|
||
break
|
||
|
||
# 批量更新
|
||
for record_id, path in records:
|
||
# 查询 t_sys_file 表
|
||
cursor.execute(f"""
|
||
SELECT id, url, style, level3_type, level2_type, deprecated
|
||
FROM {TABLE_SYS_FILE}
|
||
WHERE url = %s
|
||
LIMIT 1
|
||
""", (path,))
|
||
|
||
sys_file = cursor.fetchone()
|
||
|
||
if sys_file:
|
||
# 系统图
|
||
sys_file_id, url, style, level3_type, level2_type, deprecated = sys_file
|
||
category = f"{level3_type.lower()}_{level2_type.lower()}"
|
||
|
||
cursor.execute(f"""
|
||
UPDATE {TABLE_USER_PREFERENCE_LOG}
|
||
SET category = %s,
|
||
style = %s,
|
||
is_system_sketch = 1,
|
||
sys_file_id = %s
|
||
WHERE id = %s
|
||
""", (category, style, sys_file_id, record_id))
|
||
else:
|
||
# 用户图
|
||
cursor.execute(f"""
|
||
UPDATE {TABLE_USER_PREFERENCE_LOG}
|
||
SET is_system_sketch = 0,
|
||
category = NULL,
|
||
style = NULL,
|
||
sys_file_id = NULL
|
||
WHERE id = %s
|
||
""", (record_id,))
|
||
|
||
conn.commit()
|
||
processed += len(records)
|
||
offset += batch_size
|
||
logger.info(f"已迁移 {processed}/{total_count} 条记录")
|
||
|
||
logger.info("历史数据迁移完成")
|
||
|
||
except Exception as e:
|
||
logger.error(f"历史数据迁移失败: {e}", exc_info=True)
|
||
if conn:
|
||
conn.rollback()
|
||
finally:
|
||
if conn:
|
||
conn.close()
|
||
|
||
|
||
def precompute_system_sketch_vectors(batch_size: int = 1000, retry_times: int = 3):
|
||
"""
|
||
系统图向量预计算与导入
|
||
|
||
Args:
|
||
batch_size: 每批处理数量
|
||
retry_times: 失败重试次数
|
||
"""
|
||
conn = None
|
||
try:
|
||
conn = pymysql.connect(**MYSQL_CONFIG)
|
||
cursor = conn.cursor()
|
||
|
||
# 1. 数据筛选
|
||
logger.info("查询系统图数据...")
|
||
cursor.execute(f"""
|
||
SELECT id, url, style, level3_type, level2_type, deprecated
|
||
FROM {TABLE_SYS_FILE}
|
||
WHERE level1_type = 'Images'
|
||
AND style IS NOT NULL
|
||
AND style != ''
|
||
AND deprecated != 1
|
||
""")
|
||
records = cursor.fetchall()
|
||
logger.info(f"找到 {len(records)} 条系统图记录")
|
||
|
||
if not records:
|
||
logger.warning("没有找到系统图数据")
|
||
return
|
||
|
||
# 2. 批量处理
|
||
failed_records = []
|
||
batch_data = []
|
||
|
||
for idx, (sys_file_id, url, style, level3_type, level2_type, deprecated) in enumerate(records, 1):
|
||
try:
|
||
# 计算 category
|
||
category = f"{level3_type.lower()}_{level2_type.lower()}"
|
||
|
||
# 提取特征向量
|
||
feature_vector = extract_feature_vector(url)
|
||
|
||
# 检查向量是否有效
|
||
if np.all(feature_vector == 0):
|
||
logger.warning(f"向量提取失败,跳过: {url}")
|
||
failed_records.append((sys_file_id, url))
|
||
continue
|
||
|
||
# 准备数据
|
||
data_item = {
|
||
"path": url,
|
||
"sys_file_id": sys_file_id,
|
||
"style": style,
|
||
"category": category,
|
||
"is_system_sketch": 1,
|
||
"deprecated": deprecated if deprecated else 0,
|
||
"feature_vector": feature_vector.tolist()
|
||
}
|
||
|
||
batch_data.append(data_item)
|
||
|
||
# 批量写入
|
||
if len(batch_data) >= batch_size:
|
||
try:
|
||
insert_vectors(batch_data)
|
||
batch_data = []
|
||
logger.info(f"已处理 {idx}/{len(records)} 条记录")
|
||
except Exception as e:
|
||
logger.error(f"批量写入失败: {e}")
|
||
failed_records.extend([(item["sys_file_id"], item["path"]) for item in batch_data])
|
||
batch_data = []
|
||
|
||
except Exception as e:
|
||
logger.error(f"处理记录失败 [{url}]: {e}")
|
||
failed_records.append((sys_file_id, url))
|
||
|
||
# 写入剩余数据
|
||
if batch_data:
|
||
try:
|
||
insert_vectors(batch_data)
|
||
except Exception as e:
|
||
logger.error(f"写入剩余数据失败: {e}")
|
||
failed_records.extend([(item["sys_file_id"], item["path"]) for item in batch_data])
|
||
|
||
# 3. 重试失败记录
|
||
if failed_records and retry_times > 0:
|
||
logger.info(f"重试 {len(failed_records)} 条失败记录...")
|
||
for retry in range(retry_times):
|
||
retry_failed = []
|
||
for sys_file_id, url in failed_records:
|
||
try:
|
||
category = f"{level3_type.lower()}_{level2_type.lower()}"
|
||
feature_vector = extract_feature_vector(url)
|
||
if not np.all(feature_vector == 0):
|
||
data_item = {
|
||
"path": url,
|
||
"sys_file_id": sys_file_id,
|
||
"style": style,
|
||
"category": category,
|
||
"is_system_sketch": 1,
|
||
"deprecated": 0,
|
||
"feature_vector": feature_vector.tolist()
|
||
}
|
||
insert_vectors([data_item])
|
||
else:
|
||
retry_failed.append((sys_file_id, url))
|
||
except Exception as e:
|
||
logger.error(f"重试失败 [{url}]: {e}")
|
||
retry_failed.append((sys_file_id, url))
|
||
|
||
failed_records = retry_failed
|
||
if not failed_records:
|
||
break
|
||
|
||
if failed_records:
|
||
logger.warning(f"仍有 {len(failed_records)} 条记录处理失败")
|
||
|
||
logger.info("系统图向量预计算完成")
|
||
|
||
except Exception as e:
|
||
logger.error(f"系统图向量预计算失败: {e}", exc_info=True)
|
||
finally:
|
||
if conn:
|
||
conn.close()
|
||
|
||
|
||
def compute_user_preference_vector(
|
||
account_id: int,
|
||
category: str,
|
||
conn: Optional[pymysql.connections.Connection] = None
|
||
# max_date: Optional[datetime] = None
|
||
) -> Optional[np.ndarray]:
|
||
"""
|
||
计算用户偏好向量
|
||
|
||
Args:
|
||
account_id: 用户ID
|
||
category: 类别
|
||
conn: 数据库连接(可选)
|
||
max_date: 最大日期(可选,用于评估时只使用训练集数据)
|
||
|
||
Returns:
|
||
用户偏好向量(2048维),失败返回 None
|
||
"""
|
||
from datetime import datetime
|
||
|
||
should_close = False
|
||
if conn is None:
|
||
conn = pymysql.connect(**MYSQL_CONFIG)
|
||
should_close = True
|
||
|
||
try:
|
||
cursor = conn.cursor()
|
||
|
||
# 1. 获取点赞记录(如果指定了max_date,只查询该日期之前的数据)
|
||
if max_date:
|
||
cursor.execute(f"""
|
||
SELECT path, data_time
|
||
FROM {TABLE_USER_PREFERENCE_LOG}
|
||
WHERE account_id = %s AND category = %s AND style is not null
|
||
AND data_time < %s
|
||
ORDER BY data_time DESC
|
||
""", (account_id, category, max_date))
|
||
else:
|
||
cursor.execute(f"""
|
||
SELECT path, data_time
|
||
FROM {TABLE_USER_PREFERENCE_LOG}
|
||
WHERE account_id = %s AND category = %s AND style is not null
|
||
ORDER BY data_time DESC
|
||
""", (account_id, category))
|
||
|
||
like_records = cursor.fetchall()
|
||
|
||
if not like_records:
|
||
return None
|
||
|
||
# 2. 批量查询点赞次数(如果指定了max_date,只统计该日期之前的点赞)
|
||
paths = [r[0] for r in like_records]
|
||
if not paths:
|
||
return None
|
||
|
||
placeholders = ','.join(['%s'] * len(paths))
|
||
if max_date:
|
||
cursor.execute(f"""
|
||
SELECT path, COUNT(*) as like_count
|
||
FROM {TABLE_USER_PREFERENCE_LOG}
|
||
WHERE account_id = %s AND category = %s AND path IN ({placeholders})
|
||
AND data_time < %s
|
||
GROUP BY path
|
||
""", (account_id, category) + tuple(paths) + (max_date,))
|
||
else:
|
||
cursor.execute(f"""
|
||
SELECT path, COUNT(*) as like_count
|
||
FROM {TABLE_USER_PREFERENCE_LOG}
|
||
WHERE account_id = %s AND category = %s AND path IN ({placeholders})
|
||
GROUP BY path
|
||
""", (account_id, category) + tuple(paths))
|
||
|
||
like_counts = {row[0]: row[1] for row in cursor.fetchall()}
|
||
|
||
# 3. 批量获取向量
|
||
vectors_dict = query_vectors_by_paths(paths)
|
||
|
||
# 处理查询不到的 path(用户图或异常情况)
|
||
missing_paths = [p for p in paths if p not in vectors_dict]
|
||
if missing_paths:
|
||
logger.info(f"用户 {account_id} 类别 {category} 有 {len(missing_paths)} 个 path 需要实时计算向量")
|
||
# 目前未有非系统图向量,跳过
|
||
# 这里可以实时计算并写入 Milvus,但为了简化,先跳过
|
||
# 实际实现中应该调用 vector_utils.extract_feature_vector 并写入 Milvus
|
||
|
||
# 4. 计算权重并加权平均
|
||
vectors = []
|
||
weights = []
|
||
K_half = RECOMMENDATION_CONFIG["K_half"]
|
||
|
||
for k, (path, data_time) in enumerate(like_records, 1):
|
||
if path not in vectors_dict:
|
||
continue
|
||
|
||
vector_data = vectors_dict[path]
|
||
feature_vector = np.array(vector_data["feature_vector"])
|
||
|
||
# 时间衰减权重
|
||
d_k = 0.5 ** (k / K_half)
|
||
|
||
# 点赞次数权重
|
||
like_count = like_counts.get(path, 1)
|
||
p_i = 1 + math.log(1 + like_count)
|
||
|
||
# 综合权重
|
||
w_i = d_k * p_i
|
||
# w_i = p_i
|
||
|
||
vectors.append(feature_vector)
|
||
weights.append(w_i)
|
||
|
||
if not vectors:
|
||
return None
|
||
|
||
# 5. 计算加权平均并做 L2 归一化,IP≈cosine
|
||
preference_vector = compute_weighted_average(vectors, weights)
|
||
preference_vector = normalize_vector(preference_vector)
|
||
|
||
return preference_vector
|
||
|
||
except Exception as e:
|
||
logger.error(f"计算用户偏好向量失败 [user={account_id}, category={category}]: {e}", exc_info=True)
|
||
return None
|
||
finally:
|
||
if should_close and conn:
|
||
conn.close()
|
||
|
||
|
||
def generate_initial_user_preference_vectors(batch_size: int = 100):
|
||
"""
|
||
初始用户偏好向量生成
|
||
|
||
Args:
|
||
batch_size: 每批处理用户数
|
||
"""
|
||
conn = None
|
||
try:
|
||
conn = pymysql.connect(**MYSQL_CONFIG)
|
||
cursor = conn.cursor()
|
||
|
||
# 1. 扫描历史数据
|
||
logger.info("扫描用户和类别组合...")
|
||
cursor.execute(f"""
|
||
SELECT DISTINCT account_id, category
|
||
FROM {TABLE_USER_PREFERENCE_LOG}
|
||
WHERE category IS NOT NULL
|
||
AND style IS NOT NULL
|
||
""")
|
||
|
||
user_categories = cursor.fetchall()
|
||
logger.info(f"找到 {len(user_categories)} 个用户-类别组合")
|
||
|
||
if not user_categories:
|
||
logger.warning("没有找到用户-类别组合")
|
||
return
|
||
|
||
# 2. 批量处理
|
||
processed = 0
|
||
failed = 0
|
||
|
||
for account_id, category in user_categories:
|
||
try:
|
||
# 计算偏好向量
|
||
preference_vector = compute_user_preference_vector(account_id, category, conn)
|
||
|
||
if preference_vector is not None:
|
||
# 写入 Redis
|
||
key = f"{REDIS_KEY_USER_PREF_PREFIX}:{account_id}:{category}"
|
||
# 序列化向量(使用 JSON)
|
||
vector_json = json.dumps(preference_vector.tolist())
|
||
Redis.write(
|
||
key=key,
|
||
value=vector_json,
|
||
expire=RECOMMENDATION_CONFIG["redis_expire_seconds"]
|
||
)
|
||
processed += 1
|
||
else:
|
||
failed += 1
|
||
|
||
if (processed + failed) % batch_size == 0:
|
||
logger.info(f"已处理 {processed + failed}/{len(user_categories)} 个组合,成功: {processed}, 失败: {failed}")
|
||
|
||
except Exception as e:
|
||
logger.error(f"处理失败 [user={account_id}, category={category}]: {e}")
|
||
failed += 1
|
||
|
||
logger.info(f"初始用户偏好向量生成完成,成功: {processed}, 失败: {failed}")
|
||
|
||
except Exception as e:
|
||
logger.error(f"初始用户偏好向量生成失败: {e}", exc_info=True)
|
||
finally:
|
||
if conn:
|
||
conn.close()
|
||
|
||
|
||
def run_precompute():
|
||
"""
|
||
运行所有预计算任务
|
||
"""
|
||
logger.info("=" * 50)
|
||
logger.info("开始预计算任务")
|
||
logger.info("=" * 50)
|
||
|
||
# 1. 优化数据库表结构
|
||
# logger.info("\n[1/5] 优化数据库表结构...")
|
||
# optimize_database_table()
|
||
|
||
# # 2. 创建 Milvus 集合
|
||
# logger.info("\n[2/5] 创建 Milvus 集合...")
|
||
# create_collection()
|
||
|
||
# 3. 历史数据迁移
|
||
# logger.info("\n[3/5] 历史数据迁移...")
|
||
# migrate_historical_data()
|
||
|
||
# # 4. 系统图向量预计算
|
||
# logger.info("\n[4/5] 系统图向量预计算...")
|
||
# precompute_system_sketch_vectors()
|
||
|
||
# 5. 初始用户偏好向量生成
|
||
logger.info("\n[5/5] 初始用户偏好向量生成...")
|
||
generate_initial_user_preference_vectors()
|
||
|
||
logger.info("=" * 50)
|
||
logger.info("预计算任务完成")
|
||
logger.info("=" * 50)
|
||
|
||
|
||
if __name__ == "__main__":
|
||
# 1. 优化数据库表结构
|
||
logger.info("\n[1/5] 优化数据库表结构...")
|
||
optimize_database_table()
|
||
|
||
# 3. 历史数据迁移
|
||
logger.info("\n[3/5] 历史数据迁移...")
|
||
migrate_historical_data()
|
||
|
||
# 5. 初始用户偏好向量生成
|
||
logger.info("\n[5/5] 初始用户偏好向量生成...")
|
||
generate_initial_user_preference_vectors()
|