fix:推荐接口
This commit is contained in:
@@ -137,7 +137,7 @@ router = APIRouter()
|
||||
# logger.error(f"推荐失败: {str(e)}", exc_info=True)
|
||||
# raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
# @router.on_event("startup")
|
||||
@router.on_event("startup")
|
||||
async def startup_event():
|
||||
"""启动时初始化增量监听任务"""
|
||||
try:
|
||||
|
||||
@@ -14,7 +14,7 @@ REDIS_KEY_USER_PREF_PREFIX = "user_pref"
|
||||
RECOMMENDATION_CONFIG = {
|
||||
# 时间衰减半衰期(用于计算时间衰减权重)
|
||||
# 值越小,最近的行为权重越大
|
||||
"K_half": 20,
|
||||
"K_half": 10,
|
||||
|
||||
# 探索与利用的比例 (0.0-1.0)
|
||||
# - 值越大,使用探索分支(随机推荐)的几率越大,结果更随机
|
||||
@@ -25,7 +25,7 @@ RECOMMENDATION_CONFIG = {
|
||||
# 向量检索返回的候选数量
|
||||
# 值越大,候选池越大,但计算成本也越高
|
||||
# 建议范围: 100-1000
|
||||
"topk": 1000,
|
||||
"topk": 200,
|
||||
|
||||
# Style 加分系数(同 style 的候选进行加分)
|
||||
# 值越大,匹配 style 的候选被选中的概率越大
|
||||
@@ -53,7 +53,7 @@ RECOMMENDATION_CONFIG = {
|
||||
}
|
||||
|
||||
# 数据库表名
|
||||
TABLE_USER_PREFERENCE_LOG = "user_preference_log_test"
|
||||
TABLE_USER_PREFERENCE_LOG = "user_preference"
|
||||
TABLE_SYS_FILE = "t_sys_file"
|
||||
|
||||
# MySQL 连接配置(用于推荐系统)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""
|
||||
增量监听模块
|
||||
实时监听 user_preference_log_test 表的新增记录,更新用户偏好向量
|
||||
实时监听 user_preference 表的新增记录,更新用户偏好向量
|
||||
"""
|
||||
import logging
|
||||
import math
|
||||
@@ -258,7 +258,7 @@ class IncrementalListener:
|
||||
}
|
||||
else:
|
||||
# 用户图
|
||||
# 从 user_preference_log_test 获取 category(如果有)
|
||||
# 从 user_preference 获取 category(如果有)
|
||||
cursor.execute(f"""
|
||||
SELECT category
|
||||
FROM {TABLE_USER_PREFERENCE_LOG}
|
||||
|
||||
@@ -203,39 +203,74 @@ def search_similar_vectors(
|
||||
query_vector: np.ndarray,
|
||||
category: str,
|
||||
topk: int = 500,
|
||||
style: Optional[str] = None
|
||||
style: Optional[str] = None,
|
||||
style_boost_ratio: float = 0.2
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
向量相似度检索
|
||||
|
||||
|
||||
Args:
|
||||
query_vector: 查询向量(2048维)
|
||||
category: 类别过滤
|
||||
topk: 返回数量
|
||||
style: 风格过滤(可选)
|
||||
|
||||
style: 风格过滤(可选)- 当提供时,会给对应style的结果加分
|
||||
style_boost_ratio: 风格加分比例(默认0.1,即10%)
|
||||
|
||||
Returns:
|
||||
检索结果列表,每个元素包含 path, score, style, category 等字段
|
||||
"""
|
||||
client = get_milvus_client()
|
||||
|
||||
try:
|
||||
# 构建过滤表达式
|
||||
# 使用 filter 参数而不是 expr(根据 pymilvus MilvusClient API)
|
||||
filter_expr = f"category == '{category}' && deprecated == 0"
|
||||
if style:
|
||||
filter_expr += f" && style == '{style}'"
|
||||
# 如果没有指定style,使用原始逻辑
|
||||
if not style:
|
||||
filter_expr = f"category == '{category}' && deprecated == 0"
|
||||
results = client.search(
|
||||
collection_name=MILVUS_COLLECTION_SKETCH_VECTORS,
|
||||
data=[query_vector.tolist()],
|
||||
anns_field="feature_vector",
|
||||
search_params={"metric_type": "IP", "params": {"nprobe": 10}},
|
||||
limit=topk,
|
||||
filter=filter_expr,
|
||||
output_fields=["path", "style", "category", "sys_file_id"]
|
||||
)
|
||||
else:
|
||||
# 有style参数时,使用两阶段搜索策略
|
||||
|
||||
# 搜索
|
||||
results = client.search(
|
||||
collection_name=MILVUS_COLLECTION_SKETCH_VECTORS,
|
||||
data=[query_vector.tolist()],
|
||||
anns_field="feature_vector",
|
||||
search_params={"metric_type": "IP", "params": {"nprobe": 10}},
|
||||
limit=topk,
|
||||
filter=filter_expr,
|
||||
output_fields=["path", "style", "category", "sys_file_id"]
|
||||
)
|
||||
# 第一阶段:搜索匹配style的向量,使用boosted query vector
|
||||
filter_expr_style = f"category == '{category}' && deprecated == 0 && style == '{style}'"
|
||||
boosted_query = query_vector * (1 + style_boost_ratio)
|
||||
results_style = client.search(
|
||||
collection_name=MILVUS_COLLECTION_SKETCH_VECTORS,
|
||||
data=[boosted_query.tolist()],
|
||||
anns_field="feature_vector",
|
||||
search_params={"metric_type": "IP", "params": {"nprobe": 10}},
|
||||
limit=topk,
|
||||
filter=filter_expr_style,
|
||||
output_fields=["path", "style", "category", "sys_file_id"]
|
||||
)
|
||||
|
||||
# 第二阶段:搜索其他style的向量
|
||||
filter_expr_others = f"category == '{category}' && deprecated == 0 && style != '{style}'"
|
||||
results_others = client.search(
|
||||
collection_name=MILVUS_COLLECTION_SKETCH_VECTORS,
|
||||
data=[query_vector.tolist()],
|
||||
anns_field="feature_vector",
|
||||
search_params={"metric_type": "IP", "params": {"nprobe": 10}},
|
||||
limit=topk,
|
||||
filter=filter_expr_others,
|
||||
output_fields=["path", "style", "category", "sys_file_id"]
|
||||
)
|
||||
|
||||
# 合并结果
|
||||
results = []
|
||||
if results_style and len(results_style) > 0:
|
||||
results.extend(results_style[0])
|
||||
if results_others and len(results_others) > 0:
|
||||
results.extend(results_others[0])
|
||||
|
||||
# 转换为单个结果列表格式
|
||||
results = [results] if results else []
|
||||
|
||||
# 格式化结果
|
||||
formatted_results = []
|
||||
@@ -249,7 +284,10 @@ def search_similar_vectors(
|
||||
"sys_file_id": hit.get("entity", {}).get("sys_file_id")
|
||||
})
|
||||
|
||||
return formatted_results
|
||||
# 按分数排序并返回topk
|
||||
formatted_results.sort(key=lambda x: x["score"], reverse=True)
|
||||
return formatted_results[:topk]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"向量检索失败: {e}", exc_info=True)
|
||||
return []
|
||||
@@ -280,7 +318,7 @@ def query_random_candidates(category: str, style: Optional[str] = None, limit: i
|
||||
collection_name=MILVUS_COLLECTION_SKETCH_VECTORS,
|
||||
filter=filter_expr,
|
||||
output_fields=["path", "style", "category"],
|
||||
limit=10000 # 先查询大量数据,然后随机选择
|
||||
limit=10000
|
||||
)
|
||||
|
||||
# 随机选择
|
||||
|
||||
@@ -6,6 +6,7 @@ import logging
|
||||
import math
|
||||
import pymysql
|
||||
import numpy as np
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Tuple, Optional
|
||||
from collections import defaultdict
|
||||
|
||||
@@ -25,7 +26,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
def optimize_database_table():
|
||||
"""
|
||||
优化 user_preference_log_test 表结构
|
||||
优化 user_preference 表结构
|
||||
添加冗余字段和索引
|
||||
"""
|
||||
conn = None
|
||||
@@ -419,8 +420,8 @@ def compute_user_preference_vector(
|
||||
p_i = 1 + math.log(1 + like_count)
|
||||
|
||||
# 综合权重
|
||||
# w_i = d_k * p_i
|
||||
w_i = p_i
|
||||
w_i = d_k * p_i
|
||||
# w_i = p_i
|
||||
|
||||
vectors.append(feature_vector)
|
||||
weights.append(w_i)
|
||||
|
||||
Reference in New Issue
Block a user