Merge remote-tracking branch 'origin/develop' into develop
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped

This commit is contained in:
zcr
2026-01-12 16:18:15 +08:00
7 changed files with 129 additions and 45 deletions

View File

@@ -14,7 +14,7 @@ REDIS_KEY_USER_PREF_PREFIX = "user_pref"
RECOMMENDATION_CONFIG = {
# 时间衰减半衰期(用于计算时间衰减权重)
# 值越小,最近的行为权重越大
"K_half": 20,
"K_half": 10,
# 探索与利用的比例 (0.0-1.0)
# - 值越大,使用探索分支(随机推荐)的几率越大,结果更随机
@@ -25,7 +25,7 @@ RECOMMENDATION_CONFIG = {
# 向量检索返回的候选数量
# 值越大,候选池越大,但计算成本也越高
# 建议范围: 100-1000
"topk": 1000,
"topk": 200,
# Style 加分系数(同 style 的候选进行加分)
# 值越大,匹配 style 的候选被选中的概率越大
@@ -53,7 +53,7 @@ RECOMMENDATION_CONFIG = {
}
# 数据库表名
TABLE_USER_PREFERENCE_LOG = "user_preference_log_test"
TABLE_USER_PREFERENCE_LOG = "user_preference"
TABLE_SYS_FILE = "t_sys_file"
# MySQL 连接配置(用于推荐系统)

View File

@@ -1,6 +1,6 @@
"""
增量监听模块
实时监听 user_preference_log_test 表的新增记录,更新用户偏好向量
实时监听 user_preference 表的新增记录,更新用户偏好向量
"""
import logging
import math
@@ -48,7 +48,7 @@ class IncrementalListener:
if self.last_process_time is None:
# 第一次运行查询最近30分钟的数据
cursor.execute(f"""
SELECT id, account_id, path, category, style, data_time, is_system_sketch, sys_file_id
SELECT id, account_id, path, category, style, data_time
FROM {TABLE_USER_PREFERENCE_LOG}
WHERE data_time > DATE_SUB(NOW(), INTERVAL 30 MINUTE)
ORDER BY data_time
@@ -56,7 +56,7 @@ class IncrementalListener:
else:
# 基于上次处理时间查询
cursor.execute(f"""
SELECT id, account_id, path, category, style, data_time, is_system_sketch, sys_file_id
SELECT id, account_id, path, category, style, data_time
FROM {TABLE_USER_PREFERENCE_LOG}
WHERE data_time > %s
ORDER BY data_time
@@ -258,7 +258,7 @@ class IncrementalListener:
}
else:
# 用户图
# 从 user_preference_log_test 获取 category如果有
# 从 user_preference 获取 category如果有
cursor.execute(f"""
SELECT category
FROM {TABLE_USER_PREFERENCE_LOG}

View File

@@ -203,39 +203,74 @@ def search_similar_vectors(
query_vector: np.ndarray,
category: str,
topk: int = 500,
style: Optional[str] = None
style: Optional[str] = None,
style_boost_ratio: float = 0.2
) -> List[Dict]:
"""
向量相似度检索
Args:
query_vector: 查询向量2048维
category: 类别过滤
topk: 返回数量
style: 风格过滤(可选)
style: 风格过滤(可选)- 当提供时会给对应style的结果加分
style_boost_ratio: 风格加分比例默认0.1即10%
Returns:
检索结果列表,每个元素包含 path, score, style, category 等字段
"""
client = get_milvus_client()
try:
# 构建过滤表达式
# 使用 filter 参数而不是 expr根据 pymilvus MilvusClient API
filter_expr = f"category == '{category}' && deprecated == 0"
if style:
filter_expr += f" && style == '{style}'"
# 如果没有指定style使用原始逻辑
if not style:
filter_expr = f"category == '{category}' && deprecated == 0"
results = client.search(
collection_name=MILVUS_COLLECTION_SKETCH_VECTORS,
data=[query_vector.tolist()],
anns_field="feature_vector",
search_params={"metric_type": "IP", "params": {"nprobe": 10}},
limit=topk,
filter=filter_expr,
output_fields=["path", "style", "category", "sys_file_id"]
)
else:
# 有style参数时使用两阶段搜索策略
# 搜索
results = client.search(
collection_name=MILVUS_COLLECTION_SKETCH_VECTORS,
data=[query_vector.tolist()],
anns_field="feature_vector",
search_params={"metric_type": "IP", "params": {"nprobe": 10}},
limit=topk,
filter=filter_expr,
output_fields=["path", "style", "category", "sys_file_id"]
)
# 第一阶段搜索匹配style的向量使用boosted query vector
filter_expr_style = f"category == '{category}' && deprecated == 0 && style == '{style}'"
boosted_query = query_vector * (1 + style_boost_ratio)
results_style = client.search(
collection_name=MILVUS_COLLECTION_SKETCH_VECTORS,
data=[boosted_query.tolist()],
anns_field="feature_vector",
search_params={"metric_type": "IP", "params": {"nprobe": 10}},
limit=topk,
filter=filter_expr_style,
output_fields=["path", "style", "category", "sys_file_id"]
)
# 第二阶段搜索其他style的向量
filter_expr_others = f"category == '{category}' && deprecated == 0 && style != '{style}'"
results_others = client.search(
collection_name=MILVUS_COLLECTION_SKETCH_VECTORS,
data=[query_vector.tolist()],
anns_field="feature_vector",
search_params={"metric_type": "IP", "params": {"nprobe": 10}},
limit=topk,
filter=filter_expr_others,
output_fields=["path", "style", "category", "sys_file_id"]
)
# 合并结果
results = []
if results_style and len(results_style) > 0:
results.extend(results_style[0])
if results_others and len(results_others) > 0:
results.extend(results_others[0])
# 转换为单个结果列表格式
results = [results] if results else []
# 格式化结果
formatted_results = []
@@ -249,7 +284,10 @@ def search_similar_vectors(
"sys_file_id": hit.get("entity", {}).get("sys_file_id")
})
return formatted_results
# 按分数排序并返回topk
formatted_results.sort(key=lambda x: x["score"], reverse=True)
return formatted_results[:topk]
except Exception as e:
logger.error(f"向量检索失败: {e}", exc_info=True)
return []
@@ -280,7 +318,7 @@ def query_random_candidates(category: str, style: Optional[str] = None, limit: i
collection_name=MILVUS_COLLECTION_SKETCH_VECTORS,
filter=filter_expr,
output_fields=["path", "style", "category"],
limit=10000 # 先查询大量数据,然后随机选择
limit=10000
)
# 随机选择

View File

@@ -6,6 +6,7 @@ import logging
import math
import pymysql
import numpy as np
from datetime import datetime
from typing import List, Dict, Tuple, Optional
from collections import defaultdict
@@ -25,7 +26,7 @@ logger = logging.getLogger(__name__)
def optimize_database_table():
"""
优化 user_preference_log_test 表结构
优化 user_preference 表结构
添加冗余字段和索引
"""
conn = None
@@ -317,8 +318,8 @@ def precompute_system_sketch_vectors(batch_size: int = 1000, retry_times: int =
def compute_user_preference_vector(
account_id: int,
category: str,
conn: Optional[pymysql.connections.Connection] = None
# max_date: Optional[datetime] = None
conn: Optional[pymysql.connections.Connection] = None,
max_date: Optional[datetime] = None
) -> Optional[np.ndarray]:
"""
计算用户偏好向量
@@ -419,8 +420,8 @@ def compute_user_preference_vector(
p_i = 1 + math.log(1 + like_count)
# 综合权重
# w_i = d_k * p_i
w_i = p_i
w_i = d_k * p_i
# w_i = p_i
vectors.append(feature_vector)
weights.append(w_i)
@@ -518,16 +519,16 @@ def run_precompute():
logger.info("=" * 50)
# 1. 优化数据库表结构
logger.info("\n[1/5] 优化数据库表结构...")
optimize_database_table()
# logger.info("\n[1/5] 优化数据库表结构...")
# optimize_database_table()
# # 2. 创建 Milvus 集合
# logger.info("\n[2/5] 创建 Milvus 集合...")
# create_collection()
# 3. 历史数据迁移
logger.info("\n[3/5] 历史数据迁移...")
migrate_historical_data()
# logger.info("\n[3/5] 历史数据迁移...")
# migrate_historical_data()
# # 4. 系统图向量预计算
# logger.info("\n[4/5] 系统图向量预计算...")
@@ -543,13 +544,13 @@ def run_precompute():
if __name__ == "__main__":
# 1. 优化数据库表结构
logger.info("\n[1/5] 优化数据库表结构...")
optimize_database_table()
# 3. 历史数据迁移
logger.info("\n[3/5] 历史数据迁移...")
migrate_historical_data()
# # 1. 优化数据库表结构
# logger.info("\n[1/5] 优化数据库表结构...")
# optimize_database_table()
#
# # 3. 历史数据迁移
# logger.info("\n[3/5] 历史数据迁移...")
# migrate_historical_data()
# 5. 初始用户偏好向量生成
logger.info("\n[5/5] 初始用户偏好向量生成...")

View File

@@ -91,6 +91,21 @@ class Redis(object):
r = cls._get_r()
r.expire(name, expire_in_seconds)
@classmethod
def scan_keys(cls, pattern="*"):
    """Return all Redis keys matching *pattern*, using cursor-based SCAN.

    Iterates the keyspace with non-blocking SCAN (batches of 1000) rather
    than KEYS, then normalizes any bytes keys to str.
    """
    client = cls._get_r()
    collected = []
    cursor = 0
    while True:
        cursor, batch = client.scan(cursor, match=pattern, count=1000)
        collected.extend(batch)
        if cursor == 0:
            # SCAN returns cursor 0 once the full iteration is complete.
            break
    decoded = []
    for key in collected:
        decoded.append(key.decode('utf-8') if isinstance(key, bytes) else key)
    return decoded
if __name__ == '__main__':
redis_client = Redis()