From 2a6c48d937644ff435f8706bb841dc4897ae3909 Mon Sep 17 00:00:00 2001 From: zcr Date: Tue, 30 Dec 2025 17:23:36 +0800 Subject: [PATCH] =?UTF-8?q?=E6=96=B0=E6=8E=A8=E8=8D=90=E6=8E=A5=E5=8F=A3fi?= =?UTF-8?q?rst=20commit?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/service/recommendation_system/config.py | 34 ++++----- .../recommendation_system/milvus_client.py | 75 +++++++++---------- 2 files changed, 51 insertions(+), 58 deletions(-) diff --git a/app/service/recommendation_system/config.py b/app/service/recommendation_system/config.py index 9e6f40b..42221d1 100644 --- a/app/service/recommendation_system/config.py +++ b/app/service/recommendation_system/config.py @@ -2,12 +2,7 @@ 推荐系统配置 """ import os -from app.core.config import ( - DB_CONFIG, DB_HOST, DB_PORT, DB_USERNAME, DB_PASSWORD, DB_NAME, - REDIS_HOST, REDIS_PORT, REDIS_DB, - MILVUS_URL, MILVUS_TOKEN, MILVUS_ALIAS, - MINIO_URL, MINIO_ACCESS, MINIO_SECRET, MINIO_SECURE -) +from app.core.config import settings # Milvus 集合名称 MILVUS_COLLECTION_SKETCH_VECTORS = "sketch_vectors_norm" @@ -20,39 +15,39 @@ RECOMMENDATION_CONFIG = { # 时间衰减半衰期(用于计算时间衰减权重) # 值越小,最近的行为权重越大 "K_half": 20, - + # 探索与利用的比例 (0.0-1.0) # - 值越大,使用探索分支(随机推荐)的几率越大,结果更随机 # - 值越小,使用利用分支(基于用户偏好)的几率越大,结果更精准 # - 建议范围: 0.3-0.7,要增加随机性可提高到 0.6-0.8 "explore_ratio": 0.5, - + # 向量检索返回的候选数量 # 值越大,候选池越大,但计算成本也越高 # 建议范围: 100-1000 "topk": 1000, - + # Style 加分系数(同 style 的候选进行加分) # 值越大,匹配 style 的候选被选中的概率越大 # 要降低某个结果的重复率,可以降低此值(如 0.1 或 0.05) "style_bonus": 0.2, - + # Softmax 抽样的温度参数 # - 温度越高(>1.0),概率分布越均匀,结果更随机,重复率更低 # - 温度越低(<1.0),高分项概率越大,结果更集中,重复率更高 # - 温度=1.0 为标准 Softmax # - 建议范围: 1.0-3.0,要增加随机性可提高到 2.0-3.0 "softmax_temperature": 0.07, - + # 监听间隔(秒) "listen_interval_sec": 30, - + # 批量处理大小 "batch_size": 1000, - + # Redis 过期时间(秒,30天) "redis_expire_seconds": 2592000, - + # 向量维度 "vector_dim": 2048, } @@ -63,11 +58,10 @@ TABLE_SYS_FILE = "t_sys_file" # MySQL 连接配置(用于推荐系统) MYSQL_CONFIG = { - "host": DB_HOST, - "port": DB_PORT, - "user": DB_USERNAME, - "password": DB_PASSWORD, - "database": DB_NAME, + "host": settings.MYSQL_HOST, + "port": settings.MYSQL_PORT, + "user": settings.MYSQL_USER, + "password": settings.MYSQL_PASSWORD, + "database": settings.MYSQL_DB, "charset": "utf8mb4" } - diff --git a/app/service/recommendation_system/milvus_client.py b/app/service/recommendation_system/milvus_client.py index b17cf2c..0f4ef75 100644 --- a/app/service/recommendation_system/milvus_client.py +++ b/app/service/recommendation_system/milvus_client.py @@ -6,7 +6,7 @@ from typing import List, Dict, Optional, Any import numpy as np from pymilvus import MilvusClient, FieldSchema, CollectionSchema, DataType, connections, Collection -from app.core.config import MILVUS_URL, MILVUS_TOKEN, MILVUS_ALIAS +from app.core.config import settings from app.service.recommendation_system.config import MILVUS_COLLECTION_SKETCH_VECTORS, RECOMMENDATION_CONFIG logger = logging.getLogger(__name__) @@ -21,9 +21,9 @@ def get_milvus_client() -> MilvusClient: if _milvus_client is None: try: _milvus_client = MilvusClient( - uri=MILVUS_URL, - token=MILVUS_TOKEN, - db_name=MILVUS_ALIAS + uri=settings.MILVUS_URL, + token=settings.MILVUS_TOKEN, + db_name=settings.MILVUS_DB, ) logger.info("Milvus 客户端连接成功") except Exception as e: @@ -46,32 +46,32 @@ def create_collection(): - feature_vector (FloatVector(2048)) - 2048维特征向量 """ client = get_milvus_client() - + # 检查集合是否已存在 collections = client.list_collections() if MILVUS_COLLECTION_SKETCH_VECTORS in collections: logger.info(f"集合 {MILVUS_COLLECTION_SKETCH_VECTORS} 已存在") return - + try: # 解析 Milvus URL # 处理 http://host.docker.internal:19530 格式 - url_clean = MILVUS_URL.replace("http://", "").replace("https://", "") + url_clean = settings.MILVUS_URL.replace("http://", "").replace("https://", "") if ":" in url_clean: host, port_str = url_clean.split(":", 1) port = int(port_str) else: host = url_clean port = 19530 - + # 使用传统 API 创建集合(更可靠) # 连接到 Milvus(如果未连接) try: connections.connect( - alias=MILVUS_ALIAS, + alias=settings.MILVUS_ALIAS, host=host, port=port, - token=MILVUS_TOKEN if MILVUS_TOKEN else None + token=settings.MILVUS_TOKEN if settings.MILVUS_TOKEN else None ) logger.info(f"已连接到 Milvus: {host}:{port}") except Exception as conn_e: @@ -80,7 +80,7 @@ def create_collection(): logger.info("Milvus 连接已存在") else: logger.warning(f"连接 Milvus 时出现警告: {conn_e}") - + # 定义字段 fields = [ FieldSchema(name="path", dtype=DataType.VARCHAR, is_primary=True, max_length=512), @@ -95,20 +95,20 @@ def create_collection(): dim=RECOMMENDATION_CONFIG["vector_dim"] ) ] - + # 创建 schema schema = CollectionSchema( fields=fields, description="Sketch vectors collection for recommendation system" ) - + # 创建集合 collection = Collection( name=MILVUS_COLLECTION_SKETCH_VECTORS, schema=schema, - using=MILVUS_ALIAS + using=settings.MILVUS_ALIAS ) - + # 创建索引 # 注意:使用 IP(内积)作为度量类型,与搜索时保持一致 # 如果向量已归一化,IP 等价于 COSINE @@ -117,14 +117,14 @@ def create_collection(): "index_type": "IVF_FLAT", "params": {"nlist": 1024} } - + collection.create_index( field_name="feature_vector", index_params=index_params ) - + logger.info(f"集合 {MILVUS_COLLECTION_SKETCH_VECTORS} 创建成功") - + except Exception as e: logger.error(f"创建集合失败: {e}", exc_info=True) raise @@ -146,9 +146,9 @@ def insert_vectors(data: List[Dict[str, Any]]): """ if not data: return - + client = get_milvus_client() - + try: client.insert( collection_name=MILVUS_COLLECTION_SKETCH_VECTORS, @@ -172,27 +172,27 @@ def query_vectors_by_paths(paths: List[str]) -> Dict[str, Dict]: """ if not paths: return {} - + client = get_milvus_client() - + try: # 构建查询表达式 # 使用 filter 参数而不是 expr(根据 pymilvus MilvusClient API) # 对于字符串列表,使用单引号包裹每个值 path_list = ", ".join([f"'{p}'" for p in paths]) filter_expr = f"path in [{path_list}]" - + results = client.query( collection_name=MILVUS_COLLECTION_SKETCH_VECTORS, filter=filter_expr, output_fields=["path", "feature_vector", "style", "category", "sys_file_id", "is_system_sketch", "deprecated"] ) - + # 转换为字典 result_dict = {} for r in results: result_dict[r["path"]] = r - + return result_dict except Exception as e: logger.error(f"查询向量失败: {e}", exc_info=True) @@ -200,10 +200,10 @@ def query_vectors_by_paths(paths: List[str]) -> Dict[str, Dict]: def search_similar_vectors( - query_vector: np.ndarray, - category: str, - topk: int = 500, - style: Optional[str] = None + query_vector: np.ndarray, + category: str, + topk: int = 500, + style: Optional[str] = None ) -> List[Dict]: """ 向量相似度检索 @@ -218,14 +218,14 @@ def search_similar_vectors( 检索结果列表,每个元素包含 path, score, style, category 等字段 """ client = get_milvus_client() - + try: # 构建过滤表达式 # 使用 filter 参数而不是 expr(根据 pymilvus MilvusClient API) filter_expr = f"category == '{category}' && deprecated == 0" if style: filter_expr += f" && style == '{style}'" - + # 搜索 results = client.search( collection_name=MILVUS_COLLECTION_SKETCH_VECTORS, @@ -236,7 +236,7 @@ def search_similar_vectors( filter=filter_expr, output_fields=["path", "style", "category", "sys_file_id"] ) - + # 格式化结果 formatted_results = [] if results and len(results) > 0: @@ -248,7 +248,7 @@ def search_similar_vectors( "category": hit.get("entity", {}).get("category", ""), "sys_file_id": hit.get("entity", {}).get("sys_file_id") }) - + return formatted_results except Exception as e: logger.error(f"向量检索失败: {e}", exc_info=True) @@ -268,13 +268,13 @@ def query_random_candidates(category: str, style: Optional[str] = None, limit: i 候选列表 """ client = get_milvus_client() - + try: # 构建过滤表达式 filter_expr = f"category == '{category}' && deprecated == 0" if style: filter_expr += f" && style == '{style}'" - + # 查询所有符合条件的记录 results = client.query( collection_name=MILVUS_COLLECTION_SKETCH_VECTORS, @@ -282,14 +282,13 @@ def query_random_candidates(category: str, style: Optional[str] = None, limit: i output_fields=["path", "style", "category"], limit=10000 # 先查询大量数据,然后随机选择 ) - + # 随机选择 if len(results) > limit: import random results = random.sample(results, limit) - + return results except Exception as e: logger.error(f"随机查询候选失败: {e}", exc_info=True) return [] -