""" 预计算模块 包含:数据库表结构优化、Milvus集合创建、系统图向量预计算、初始用户偏好向量生成 """ import logging import math import pymysql import numpy as np from datetime import datetime from typing import List, Dict, Tuple, Optional from collections import defaultdict from app.service.recommendation_system.config import ( MYSQL_CONFIG, TABLE_USER_PREFERENCE_LOG, TABLE_SYS_FILE, RECOMMENDATION_CONFIG, REDIS_KEY_USER_PREF_PREFIX ) from app.service.recommendation_system.vector_utils import extract_feature_vector, normalize_vector, compute_weighted_average from app.service.recommendation_system.milvus_client import ( create_collection, insert_vectors, query_vectors_by_paths ) from app.service.utils.redis_utils import Redis import json logger = logging.getLogger(__name__) def optimize_database_table(): """ 优化 user_preference 表结构 添加冗余字段和索引 """ conn = None try: conn = pymysql.connect(**MYSQL_CONFIG) cursor = conn.cursor() # 1. 添加冗余字段 logger.info("添加冗余字段...") alter_sqls = [ f"ALTER TABLE {TABLE_USER_PREFERENCE_LOG} ADD COLUMN category VARCHAR(100) COMMENT '类别:lower(level3_type + \"_\" + level2_type)'", f"ALTER TABLE {TABLE_USER_PREFERENCE_LOG} ADD COLUMN style VARCHAR(50) COMMENT '风格样式'", f"ALTER TABLE {TABLE_USER_PREFERENCE_LOG} ADD COLUMN is_system_sketch TINYINT(1) DEFAULT 1 COMMENT '是否为系统图(1-是,0-用户图)'", f"ALTER TABLE {TABLE_USER_PREFERENCE_LOG} ADD COLUMN sys_file_id BIGINT NULL COMMENT '系统文件ID'", ] for sql in alter_sqls: try: cursor.execute(sql) logger.info(f"执行成功: {sql[:50]}...") except Exception as e: if "Duplicate column name" in str(e): logger.info(f"字段已存在,跳过: {sql[:50]}...") else: logger.warning(f"执行失败: {sql[:50]}... 错误: {e}") # 2. 创建索引(MySQL 不支持 IF NOT EXISTS,需要先检查) logger.info("创建索引...") index_definitions = [ ("idx_account_category_time", ["account_id", "category", "data_time"]), ("idx_account_path", ["account_id", "path"]), ] for index_name, columns in index_definitions: try: # 检查索引是否已存在 cursor.execute(f""" SELECT COUNT(*) FROM information_schema.statistics WHERE table_schema = DATABASE() AND table_name = '{TABLE_USER_PREFERENCE_LOG}' AND index_name = '{index_name}' """) exists = cursor.fetchone()[0] > 0 if exists: logger.info(f"索引已存在,跳过: {index_name}") else: # 创建索引 columns_str = ', '.join(columns) create_sql = f"CREATE INDEX {index_name} ON {TABLE_USER_PREFERENCE_LOG}({columns_str})" cursor.execute(create_sql) logger.info(f"索引创建成功: {index_name}") except Exception as e: logger.warning(f"索引创建失败: {index_name} 错误: {e}") conn.commit() logger.info("数据库表结构优化完成") except Exception as e: logger.error(f"数据库表结构优化失败: {e}", exc_info=True) if conn: conn.rollback() finally: if conn: conn.close() def migrate_historical_data(batch_size: int = 1000): """ 历史数据迁移:批量更新冗余字段 Args: batch_size: 每批处理数量 """ conn = None try: conn = pymysql.connect(**MYSQL_CONFIG) cursor = conn.cursor() # 查询需要更新的记录数 cursor.execute(f""" SELECT COUNT(*) FROM {TABLE_USER_PREFERENCE_LOG} u WHERE u.category IS NULL """) total_count = cursor.fetchone()[0] logger.info(f"需要迁移的记录数: {total_count}") if total_count == 0: logger.info("无需迁移数据") return # 分批处理 offset = 0 processed = 0 while offset < total_count: # 查询一批记录 cursor.execute(f""" SELECT u.id, u.path FROM {TABLE_USER_PREFERENCE_LOG} u WHERE u.category IS NULL LIMIT {batch_size} OFFSET {offset} """) records = cursor.fetchall() if not records: break # 批量更新 for record_id, path in records: # 查询 t_sys_file 表 cursor.execute(f""" SELECT id, url, style, level3_type, level2_type, deprecated FROM {TABLE_SYS_FILE} WHERE url = %s LIMIT 1 """, (path,)) sys_file = cursor.fetchone() if sys_file: # 系统图 sys_file_id, url, style, level3_type, level2_type, deprecated = sys_file category = f"{level3_type.lower()}_{level2_type.lower()}" cursor.execute(f""" UPDATE {TABLE_USER_PREFERENCE_LOG} SET category = %s, style = %s, is_system_sketch = 1, sys_file_id = %s WHERE id = %s """, (category, style, sys_file_id, record_id)) else: # 用户图 cursor.execute(f""" UPDATE {TABLE_USER_PREFERENCE_LOG} SET is_system_sketch = 0, category = NULL, style = NULL, sys_file_id = NULL WHERE id = %s """, (record_id,)) conn.commit() processed += len(records) offset += batch_size logger.info(f"已迁移 {processed}/{total_count} 条记录") logger.info("历史数据迁移完成") except Exception as e: logger.error(f"历史数据迁移失败: {e}", exc_info=True) if conn: conn.rollback() finally: if conn: conn.close() def precompute_system_sketch_vectors(batch_size: int = 1000, retry_times: int = 3): """ 系统图向量预计算与导入 Args: batch_size: 每批处理数量 retry_times: 失败重试次数 """ conn = None try: conn = pymysql.connect(**MYSQL_CONFIG) cursor = conn.cursor() # 1. 数据筛选 logger.info("查询系统图数据...") cursor.execute(f""" SELECT id, url, style, level3_type, level2_type, deprecated FROM {TABLE_SYS_FILE} WHERE level1_type = 'Images' AND style IS NOT NULL AND style != '' AND deprecated != 1 """) records = cursor.fetchall() logger.info(f"找到 {len(records)} 条系统图记录") if not records: logger.warning("没有找到系统图数据") return # 2. 批量处理 failed_records = [] batch_data = [] for idx, (sys_file_id, url, style, level3_type, level2_type, deprecated) in enumerate(records, 1): try: # 计算 category category = f"{level3_type.lower()}_{level2_type.lower()}" # 提取特征向量 feature_vector = extract_feature_vector(url) # 检查向量是否有效 if np.all(feature_vector == 0): logger.warning(f"向量提取失败,跳过: {url}") failed_records.append((sys_file_id, url)) continue # 准备数据 data_item = { "path": url, "sys_file_id": sys_file_id, "style": style, "category": category, "is_system_sketch": 1, "deprecated": deprecated if deprecated else 0, "feature_vector": feature_vector.tolist() } batch_data.append(data_item) # 批量写入 if len(batch_data) >= batch_size: try: insert_vectors(batch_data) batch_data = [] logger.info(f"已处理 {idx}/{len(records)} 条记录") except Exception as e: logger.error(f"批量写入失败: {e}") failed_records.extend([(item["sys_file_id"], item["path"]) for item in batch_data]) batch_data = [] except Exception as e: logger.error(f"处理记录失败 [{url}]: {e}") failed_records.append((sys_file_id, url)) # 写入剩余数据 if batch_data: try: insert_vectors(batch_data) except Exception as e: logger.error(f"写入剩余数据失败: {e}") failed_records.extend([(item["sys_file_id"], item["path"]) for item in batch_data]) # 3. 重试失败记录 if failed_records and retry_times > 0: logger.info(f"重试 {len(failed_records)} 条失败记录...") for retry in range(retry_times): retry_failed = [] for sys_file_id, url in failed_records: try: category = f"{level3_type.lower()}_{level2_type.lower()}" feature_vector = extract_feature_vector(url) if not np.all(feature_vector == 0): data_item = { "path": url, "sys_file_id": sys_file_id, "style": style, "category": category, "is_system_sketch": 1, "deprecated": 0, "feature_vector": feature_vector.tolist() } insert_vectors([data_item]) else: retry_failed.append((sys_file_id, url)) except Exception as e: logger.error(f"重试失败 [{url}]: {e}") retry_failed.append((sys_file_id, url)) failed_records = retry_failed if not failed_records: break if failed_records: logger.warning(f"仍有 {len(failed_records)} 条记录处理失败") logger.info("系统图向量预计算完成") except Exception as e: logger.error(f"系统图向量预计算失败: {e}", exc_info=True) finally: if conn: conn.close() def compute_user_preference_vector( account_id: int, category: str, conn: Optional[pymysql.connections.Connection] = None, max_date: Optional[datetime] = None ) -> Optional[np.ndarray]: """ 计算用户偏好向量 Args: account_id: 用户ID category: 类别 conn: 数据库连接(可选) max_date: 最大日期(可选,用于评估时只使用训练集数据) Returns: 用户偏好向量(2048维),失败返回 None """ from datetime import datetime should_close = False if conn is None: conn = pymysql.connect(**MYSQL_CONFIG) should_close = True try: cursor = conn.cursor() # 1. 获取点赞记录(如果指定了max_date,只查询该日期之前的数据) if max_date: cursor.execute(f""" SELECT path, data_time FROM {TABLE_USER_PREFERENCE_LOG} WHERE account_id = %s AND category = %s AND style is not null AND data_time < %s ORDER BY data_time DESC """, (account_id, category, max_date)) else: cursor.execute(f""" SELECT path, data_time FROM {TABLE_USER_PREFERENCE_LOG} WHERE account_id = %s AND category = %s AND style is not null ORDER BY data_time DESC """, (account_id, category)) like_records = cursor.fetchall() if not like_records: return None # 2. 批量查询点赞次数(如果指定了max_date,只统计该日期之前的点赞) paths = [r[0] for r in like_records] if not paths: return None placeholders = ','.join(['%s'] * len(paths)) if max_date: cursor.execute(f""" SELECT path, COUNT(*) as like_count FROM {TABLE_USER_PREFERENCE_LOG} WHERE account_id = %s AND category = %s AND path IN ({placeholders}) AND data_time < %s GROUP BY path """, (account_id, category) + tuple(paths) + (max_date,)) else: cursor.execute(f""" SELECT path, COUNT(*) as like_count FROM {TABLE_USER_PREFERENCE_LOG} WHERE account_id = %s AND category = %s AND path IN ({placeholders}) GROUP BY path """, (account_id, category) + tuple(paths)) like_counts = {row[0]: row[1] for row in cursor.fetchall()} # 3. 批量获取向量 vectors_dict = query_vectors_by_paths(paths) # 处理查询不到的 path(用户图或异常情况) missing_paths = [p for p in paths if p not in vectors_dict] if missing_paths: logger.info(f"用户 {account_id} 类别 {category} 有 {len(missing_paths)} 个 path 需要实时计算向量") # 目前未有非系统图向量,跳过 # 这里可以实时计算并写入 Milvus,但为了简化,先跳过 # 实际实现中应该调用 vector_utils.extract_feature_vector 并写入 Milvus # 4. 计算权重并加权平均 vectors = [] weights = [] K_half = RECOMMENDATION_CONFIG["K_half"] for k, (path, data_time) in enumerate(like_records, 1): if path not in vectors_dict: continue vector_data = vectors_dict[path] feature_vector = np.array(vector_data["feature_vector"]) # 时间衰减权重 d_k = 0.5 ** (k / K_half) # 点赞次数权重 like_count = like_counts.get(path, 1) p_i = 1 + math.log(1 + like_count) # 综合权重 w_i = d_k * p_i # w_i = p_i vectors.append(feature_vector) weights.append(w_i) if not vectors: return None # 5. 计算加权平均并做 L2 归一化,IP≈cosine preference_vector = compute_weighted_average(vectors, weights) preference_vector = normalize_vector(preference_vector) return preference_vector except Exception as e: logger.error(f"计算用户偏好向量失败 [user={account_id}, category={category}]: {e}", exc_info=True) return None finally: if should_close and conn: conn.close() def generate_initial_user_preference_vectors(batch_size: int = 100): """ 初始用户偏好向量生成 Args: batch_size: 每批处理用户数 """ conn = None try: conn = pymysql.connect(**MYSQL_CONFIG) cursor = conn.cursor() # 1. 扫描历史数据 logger.info("扫描用户和类别组合...") cursor.execute(f""" SELECT DISTINCT account_id, category FROM {TABLE_USER_PREFERENCE_LOG} WHERE category IS NOT NULL AND style IS NOT NULL """) user_categories = cursor.fetchall() logger.info(f"找到 {len(user_categories)} 个用户-类别组合") if not user_categories: logger.warning("没有找到用户-类别组合") return # 2. 批量处理 processed = 0 failed = 0 for account_id, category in user_categories: try: # 计算偏好向量 preference_vector = compute_user_preference_vector(account_id, category, conn) if preference_vector is not None: # 写入 Redis key = f"{REDIS_KEY_USER_PREF_PREFIX}:{account_id}:{category}" # 序列化向量(使用 JSON) vector_json = json.dumps(preference_vector.tolist()) Redis.write( key=key, value=vector_json, expire=RECOMMENDATION_CONFIG["redis_expire_seconds"] ) processed += 1 else: failed += 1 if (processed + failed) % batch_size == 0: logger.info(f"已处理 {processed + failed}/{len(user_categories)} 个组合,成功: {processed}, 失败: {failed}") except Exception as e: logger.error(f"处理失败 [user={account_id}, category={category}]: {e}") failed += 1 logger.info(f"初始用户偏好向量生成完成,成功: {processed}, 失败: {failed}") except Exception as e: logger.error(f"初始用户偏好向量生成失败: {e}", exc_info=True) finally: if conn: conn.close() def run_precompute(): """ 运行所有预计算任务 """ logger.info("=" * 50) logger.info("开始预计算任务") logger.info("=" * 50) # 1. 优化数据库表结构 # logger.info("\n[1/5] 优化数据库表结构...") # optimize_database_table() # # 2. 创建 Milvus 集合 # logger.info("\n[2/5] 创建 Milvus 集合...") # create_collection() # 3. 历史数据迁移 # logger.info("\n[3/5] 历史数据迁移...") # migrate_historical_data() # # 4. 系统图向量预计算 # logger.info("\n[4/5] 系统图向量预计算...") # precompute_system_sketch_vectors() # 5. 初始用户偏好向量生成 logger.info("\n[5/5] 初始用户偏好向量生成...") generate_initial_user_preference_vectors() logger.info("=" * 50) logger.info("预计算任务完成") logger.info("=" * 50) if __name__ == "__main__": # # 1. 优化数据库表结构 # logger.info("\n[1/5] 优化数据库表结构...") # optimize_database_table() # # # 3. 历史数据迁移 # logger.info("\n[3/5] 历史数据迁移...") # migrate_historical_data() # 5. 初始用户偏好向量生成 logger.info("\n[5/5] 初始用户偏好向量生成...") generate_initial_user_preference_vectors()