From 2af9cbfe780b6ea13ca5791715ee58ae04c1bf08 Mon Sep 17 00:00:00 2001 From: litianxiang Date: Mon, 12 Jan 2026 09:49:07 +0800 Subject: [PATCH] =?UTF-8?q?fix:=E6=8E=A8=E8=8D=90=E6=8E=A5=E5=8F=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/api_recommendation.py | 2 +- app/service/recommendation_system/config.py | 6 +- .../incremental_listener.py | 4 +- .../recommendation_system/milvus_client.py | 80 ++++++++++++++----- .../recommendation_system/precompute.py | 7 +- 5 files changed, 69 insertions(+), 30 deletions(-) diff --git a/app/api/api_recommendation.py b/app/api/api_recommendation.py index e5b86b1..24b388e 100644 --- a/app/api/api_recommendation.py +++ b/app/api/api_recommendation.py @@ -137,7 +137,7 @@ router = APIRouter() # logger.error(f"推荐失败: {str(e)}", exc_info=True) # raise HTTPException(status_code=500, detail=str(e)) -# @router.on_event("startup") +@router.on_event("startup") async def startup_event(): """启动时初始化增量监听任务""" try: diff --git a/app/service/recommendation_system/config.py b/app/service/recommendation_system/config.py index 42221d1..bd48362 100644 --- a/app/service/recommendation_system/config.py +++ b/app/service/recommendation_system/config.py @@ -14,7 +14,7 @@ REDIS_KEY_USER_PREF_PREFIX = "user_pref" RECOMMENDATION_CONFIG = { # 时间衰减半衰期(用于计算时间衰减权重) # 值越小,最近的行为权重越大 - "K_half": 20, + "K_half": 10, # 探索与利用的比例 (0.0-1.0) # - 值越大,使用探索分支(随机推荐)的几率越大,结果更随机 @@ -25,7 +25,7 @@ RECOMMENDATION_CONFIG = { # 向量检索返回的候选数量 # 值越大,候选池越大,但计算成本也越高 # 建议范围: 100-1000 - "topk": 1000, + "topk": 200, # Style 加分系数(同 style 的候选进行加分) # 值越大,匹配 style 的候选被选中的概率越大 @@ -53,7 +53,7 @@ RECOMMENDATION_CONFIG = { } # 数据库表名 -TABLE_USER_PREFERENCE_LOG = "user_preference_log_test" +TABLE_USER_PREFERENCE_LOG = "user_preference" TABLE_SYS_FILE = "t_sys_file" # MySQL 连接配置(用于推荐系统) diff --git a/app/service/recommendation_system/incremental_listener.py b/app/service/recommendation_system/incremental_listener.py index 08c3b21..93ecf34 
100644 --- a/app/service/recommendation_system/incremental_listener.py +++ b/app/service/recommendation_system/incremental_listener.py @@ -1,6 +1,6 @@ """ 增量监听模块 -实时监听 user_preference_log_test 表的新增记录,更新用户偏好向量 +实时监听 user_preference 表的新增记录,更新用户偏好向量 """ import logging import math @@ -258,7 +258,7 @@ class IncrementalListener: } else: # 用户图 - # 从 user_preference_log_test 获取 category(如果有) + # 从 user_preference 获取 category(如果有) cursor.execute(f""" SELECT category FROM {TABLE_USER_PREFERENCE_LOG} diff --git a/app/service/recommendation_system/milvus_client.py b/app/service/recommendation_system/milvus_client.py index a027f99..5fefa71 100644 --- a/app/service/recommendation_system/milvus_client.py +++ b/app/service/recommendation_system/milvus_client.py @@ -203,39 +203,74 @@ def search_similar_vectors( query_vector: np.ndarray, category: str, topk: int = 500, - style: Optional[str] = None + style: Optional[str] = None, + style_boost_ratio: float = 0.2 ) -> List[Dict]: """ 向量相似度检索 - + Args: query_vector: 查询向量(2048维) category: 类别过滤 topk: 返回数量 - style: 风格过滤(可选) - + style: 风格过滤(可选)- 当提供时,会给对应style的结果加分 + style_boost_ratio: 风格加分比例(默认0.2,即20%) + Returns: 检索结果列表,每个元素包含 path, score, style, category 等字段 """ client = get_milvus_client() try: - # 构建过滤表达式 - # 使用 filter 参数而不是 expr(根据 pymilvus MilvusClient API) - filter_expr = f"category == '{category}' && deprecated == 0" - if style: - filter_expr += f" && style == '{style}'" + # 如果没有指定style,使用原始逻辑 + if not style: + filter_expr = f"category == '{category}' && deprecated == 0" + results = client.search( + collection_name=MILVUS_COLLECTION_SKETCH_VECTORS, + data=[query_vector.tolist()], + anns_field="feature_vector", + search_params={"metric_type": "IP", "params": {"nprobe": 10}}, + limit=topk, + filter=filter_expr, + output_fields=["path", "style", "category", "sys_file_id"] + ) + else: + # 有style参数时,使用两阶段搜索策略 - # 搜索 - results = client.search( - collection_name=MILVUS_COLLECTION_SKETCH_VECTORS, - data=[query_vector.tolist()], 
anns_field="feature_vector", - search_params={"metric_type": "IP", "params": {"nprobe": 10}}, - limit=topk, - filter=filter_expr, - output_fields=["path", "style", "category", "sys_file_id"] - ) + # 第一阶段:搜索匹配style的向量,使用boosted query vector + filter_expr_style = f"category == '{category}' && deprecated == 0 && style == '{style}'" + boosted_query = query_vector * (1 + style_boost_ratio) + results_style = client.search( + collection_name=MILVUS_COLLECTION_SKETCH_VECTORS, + data=[boosted_query.tolist()], + anns_field="feature_vector", + search_params={"metric_type": "IP", "params": {"nprobe": 10}}, + limit=topk, + filter=filter_expr_style, + output_fields=["path", "style", "category", "sys_file_id"] + ) + + # 第二阶段:搜索其他style的向量 + filter_expr_others = f"category == '{category}' && deprecated == 0 && style != '{style}'" + results_others = client.search( + collection_name=MILVUS_COLLECTION_SKETCH_VECTORS, + data=[query_vector.tolist()], + anns_field="feature_vector", + search_params={"metric_type": "IP", "params": {"nprobe": 10}}, + limit=topk, + filter=filter_expr_others, + output_fields=["path", "style", "category", "sys_file_id"] + ) + + # 合并结果 + results = [] + if results_style and len(results_style) > 0: + results.extend(results_style[0]) + if results_others and len(results_others) > 0: + results.extend(results_others[0]) + + # 转换为单个结果列表格式 + results = [results] if results else [] # 格式化结果 formatted_results = [] @@ -249,7 +284,10 @@ def search_similar_vectors( "sys_file_id": hit.get("entity", {}).get("sys_file_id") }) - return formatted_results + # 按分数排序并返回topk + formatted_results.sort(key=lambda x: x["score"], reverse=True) + return formatted_results[:topk] + except Exception as e: logger.error(f"向量检索失败: {e}", exc_info=True) return [] @@ -280,7 +318,7 @@ def query_random_candidates(category: str, style: Optional[str] = None, limit: i collection_name=MILVUS_COLLECTION_SKETCH_VECTORS, filter=filter_expr, output_fields=["path", "style", "category"], - limit=10000 # 
先查询大量数据,然后随机选择 + limit=10000 ) # 随机选择 diff --git a/app/service/recommendation_system/precompute.py b/app/service/recommendation_system/precompute.py index c4797d1..30904fa 100644 --- a/app/service/recommendation_system/precompute.py +++ b/app/service/recommendation_system/precompute.py @@ -6,6 +6,7 @@ import logging import math import pymysql import numpy as np +from datetime import datetime from typing import List, Dict, Tuple, Optional from collections import defaultdict @@ -25,7 +26,7 @@ logger = logging.getLogger(__name__) def optimize_database_table(): """ - 优化 user_preference_log_test 表结构 + 优化 user_preference 表结构 添加冗余字段和索引 """ conn = None @@ -419,8 +420,8 @@ def compute_user_preference_vector( p_i = 1 + math.log(1 + like_count) # 综合权重 - # w_i = d_k * p_i - w_i = p_i + w_i = d_k * p_i + # w_i = p_i vectors.append(feature_vector) weights.append(w_i)