fix:推荐接口

This commit is contained in:
litianxiang
2026-01-12 09:49:07 +08:00
parent c792106f02
commit 2af9cbfe78
5 changed files with 69 additions and 30 deletions

View File

@@ -137,7 +137,7 @@ router = APIRouter()
# logger.error(f"推荐失败: {str(e)}", exc_info=True) # logger.error(f"推荐失败: {str(e)}", exc_info=True)
# raise HTTPException(status_code=500, detail=str(e)) # raise HTTPException(status_code=500, detail=str(e))
# @router.on_event("startup") @router.on_event("startup")
async def startup_event(): async def startup_event():
"""启动时初始化增量监听任务""" """启动时初始化增量监听任务"""
try: try:

View File

@@ -14,7 +14,7 @@ REDIS_KEY_USER_PREF_PREFIX = "user_pref"
RECOMMENDATION_CONFIG = { RECOMMENDATION_CONFIG = {
# 时间衰减半衰期(用于计算时间衰减权重) # 时间衰减半衰期(用于计算时间衰减权重)
# 值越小,最近的行为权重越大 # 值越小,最近的行为权重越大
"K_half": 20, "K_half": 10,
# 探索与利用的比例 (0.0-1.0) # 探索与利用的比例 (0.0-1.0)
# - 值越大,使用探索分支(随机推荐)的几率越大,结果更随机 # - 值越大,使用探索分支(随机推荐)的几率越大,结果更随机
@@ -25,7 +25,7 @@ RECOMMENDATION_CONFIG = {
# 向量检索返回的候选数量 # 向量检索返回的候选数量
# 值越大,候选池越大,但计算成本也越高 # 值越大,候选池越大,但计算成本也越高
# 建议范围: 100-1000 # 建议范围: 100-1000
"topk": 1000, "topk": 200,
# Style 加分系数(同 style 的候选进行加分) # Style 加分系数(同 style 的候选进行加分)
# 值越大,匹配 style 的候选被选中的概率越大 # 值越大,匹配 style 的候选被选中的概率越大
@@ -53,7 +53,7 @@ RECOMMENDATION_CONFIG = {
} }
# 数据库表名 # 数据库表名
TABLE_USER_PREFERENCE_LOG = "user_preference_log_test" TABLE_USER_PREFERENCE_LOG = "user_preference"
TABLE_SYS_FILE = "t_sys_file" TABLE_SYS_FILE = "t_sys_file"
# MySQL 连接配置(用于推荐系统) # MySQL 连接配置(用于推荐系统)

View File

@@ -1,6 +1,6 @@
""" """
增量监听模块 增量监听模块
实时监听 user_preference_log_test 表的新增记录,更新用户偏好向量 实时监听 user_preference 表的新增记录,更新用户偏好向量
""" """
import logging import logging
import math import math
@@ -258,7 +258,7 @@ class IncrementalListener:
} }
else: else:
# 用户图 # 用户图
# 从 user_preference_log_test 获取 category如果有 # 从 user_preference 获取 category如果有
cursor.execute(f""" cursor.execute(f"""
SELECT category SELECT category
FROM {TABLE_USER_PREFERENCE_LOG} FROM {TABLE_USER_PREFERENCE_LOG}

View File

@@ -203,39 +203,74 @@ def search_similar_vectors(
query_vector: np.ndarray, query_vector: np.ndarray,
category: str, category: str,
topk: int = 500, topk: int = 500,
style: Optional[str] = None style: Optional[str] = None,
style_boost_ratio: float = 0.2
) -> List[Dict]: ) -> List[Dict]:
""" """
向量相似度检索 向量相似度检索
Args: Args:
query_vector: 查询向量2048维 query_vector: 查询向量2048维
category: 类别过滤 category: 类别过滤
topk: 返回数量 topk: 返回数量
style: 风格过滤(可选) style: 风格过滤(可选)- 当提供时会给对应style的结果加分
style_boost_ratio: 风格加分比例,默认0.2(即20%)
Returns: Returns:
检索结果列表,每个元素包含 path, score, style, category 等字段 检索结果列表,每个元素包含 path, score, style, category 等字段
""" """
client = get_milvus_client() client = get_milvus_client()
try: try:
# 构建过滤表达式 # 如果没有指定style使用原始逻辑
# 使用 filter 参数而不是 expr根据 pymilvus MilvusClient API if not style:
filter_expr = f"category == '{category}' && deprecated == 0" filter_expr = f"category == '{category}' && deprecated == 0"
if style: results = client.search(
filter_expr += f" && style == '{style}'" collection_name=MILVUS_COLLECTION_SKETCH_VECTORS,
data=[query_vector.tolist()],
anns_field="feature_vector",
search_params={"metric_type": "IP", "params": {"nprobe": 10}},
limit=topk,
filter=filter_expr,
output_fields=["path", "style", "category", "sys_file_id"]
)
else:
# 有style参数时使用两阶段搜索策略
# 搜索 # 第一阶段搜索匹配style的向量使用boosted query vector
results = client.search( filter_expr_style = f"category == '{category}' && deprecated == 0 && style == '{style}'"
collection_name=MILVUS_COLLECTION_SKETCH_VECTORS, boosted_query = query_vector * (1 + style_boost_ratio)
data=[query_vector.tolist()], results_style = client.search(
anns_field="feature_vector", collection_name=MILVUS_COLLECTION_SKETCH_VECTORS,
search_params={"metric_type": "IP", "params": {"nprobe": 10}}, data=[boosted_query.tolist()],
limit=topk, anns_field="feature_vector",
filter=filter_expr, search_params={"metric_type": "IP", "params": {"nprobe": 10}},
output_fields=["path", "style", "category", "sys_file_id"] limit=topk,
) filter=filter_expr_style,
output_fields=["path", "style", "category", "sys_file_id"]
)
# 第二阶段搜索其他style的向量
filter_expr_others = f"category == '{category}' && deprecated == 0 && style != '{style}'"
results_others = client.search(
collection_name=MILVUS_COLLECTION_SKETCH_VECTORS,
data=[query_vector.tolist()],
anns_field="feature_vector",
search_params={"metric_type": "IP", "params": {"nprobe": 10}},
limit=topk,
filter=filter_expr_others,
output_fields=["path", "style", "category", "sys_file_id"]
)
# 合并结果
results = []
if results_style and len(results_style) > 0:
results.extend(results_style[0])
if results_others and len(results_others) > 0:
results.extend(results_others[0])
# 转换为单个结果列表格式
results = [results] if results else []
# 格式化结果 # 格式化结果
formatted_results = [] formatted_results = []
@@ -249,7 +284,10 @@ def search_similar_vectors(
"sys_file_id": hit.get("entity", {}).get("sys_file_id") "sys_file_id": hit.get("entity", {}).get("sys_file_id")
}) })
return formatted_results # 按分数排序并返回topk
formatted_results.sort(key=lambda x: x["score"], reverse=True)
return formatted_results[:topk]
except Exception as e: except Exception as e:
logger.error(f"向量检索失败: {e}", exc_info=True) logger.error(f"向量检索失败: {e}", exc_info=True)
return [] return []
@@ -280,7 +318,7 @@ def query_random_candidates(category: str, style: Optional[str] = None, limit: i
collection_name=MILVUS_COLLECTION_SKETCH_VECTORS, collection_name=MILVUS_COLLECTION_SKETCH_VECTORS,
filter=filter_expr, filter=filter_expr,
output_fields=["path", "style", "category"], output_fields=["path", "style", "category"],
limit=10000 # 先查询大量数据,然后随机选择 limit=10000
) )
# 随机选择 # 随机选择

View File

@@ -6,6 +6,7 @@ import logging
import math import math
import pymysql import pymysql
import numpy as np import numpy as np
from datetime import datetime
from typing import List, Dict, Tuple, Optional from typing import List, Dict, Tuple, Optional
from collections import defaultdict from collections import defaultdict
@@ -25,7 +26,7 @@ logger = logging.getLogger(__name__)
def optimize_database_table(): def optimize_database_table():
""" """
优化 user_preference_log_test 表结构 优化 user_preference 表结构
添加冗余字段和索引 添加冗余字段和索引
""" """
conn = None conn = None
@@ -419,8 +420,8 @@ def compute_user_preference_vector(
p_i = 1 + math.log(1 + like_count) p_i = 1 + math.log(1 + like_count)
# 综合权重 # 综合权重
# w_i = d_k * p_i w_i = d_k * p_i
w_i = p_i # w_i = p_i
vectors.append(feature_vector) vectors.append(feature_vector)
weights.append(w_i) weights.append(w_i)