"""Thin wrapper around the Milvus client used by the recommendation system."""
import logging
import random
from typing import Any, Dict, List, Optional, Tuple
from urllib.parse import urlsplit

import numpy as np
from pymilvus import MilvusClient, FieldSchema, CollectionSchema, DataType, connections, Collection

from app.core.config import settings
from app.service.recommendation_system.config import MILVUS_COLLECTION_SKETCH_VECTORS, RECOMMENDATION_CONFIG

logger = logging.getLogger(__name__)

# Port used when the configured Milvus URL does not carry one.
_DEFAULT_MILVUS_PORT = 19530

# Upper bound on the rows fetched before random sampling in
# query_random_candidates().
_RANDOM_CANDIDATE_POOL = 10000

# Lazily created, module-wide Milvus client (singleton).
_milvus_client = None


def _quote_str(value: Any) -> str:
    """Render *value* as a single-quoted string literal for a filter expression.

    Backslashes and single quotes are backslash-escaped so that a value
    containing them cannot terminate the literal early or change the meaning
    of the surrounding expression. Values without those characters render
    exactly as ``'value'``.
    """
    escaped = str(value).replace("\\", "\\\\").replace("'", "\\'")
    return f"'{escaped}'"


def _parse_host_port(url: str, default_port: int = _DEFAULT_MILVUS_PORT) -> Tuple[str, int]:
    """Extract ``(host, port)`` from a URL such as ``http://host:19530``.

    A scheme is optional. Trailing paths, credentials and bracketed IPv6
    hosts are handled by ``urlsplit``. When no port is present,
    *default_port* is returned.

    Raises:
        ValueError: if the port component is not a valid port number.
    """
    # Without a "//" prefix urlsplit would read "host:19530" as scheme + path.
    candidate = url if "://" in url else f"//{url}"
    parts = urlsplit(candidate)
    host = parts.hostname or ""
    port = parts.port or default_port
    return host, port


def get_milvus_client() -> MilvusClient:
    """Return the shared ``MilvusClient``, creating it on first use.

    The client is built from ``settings.MILVUS_URL`` and
    ``settings.MILVUS_TOKEN`` and cached in a module-level variable.

    Raises:
        Exception: whatever ``MilvusClient`` raises on construction; the
            error is logged and re-raised unchanged.
    """
    global _milvus_client
    if _milvus_client is None:
        try:
            _milvus_client = MilvusClient(
                uri=settings.MILVUS_URL,
                token=settings.MILVUS_TOKEN,
                db_name="",
            )
            logger.info("Milvus 客户端连接成功")
        except Exception as e:
            logger.error("Milvus 客户端连接失败: %s", e)
            raise
    return _milvus_client


def create_collection():
    """Create the sketch-vector collection if it does not exist yet.

    Fields, as declared below:
    - path (VARCHAR, max_length 512) - primary key
    - sys_file_id (INT64)
    - style (VARCHAR, max_length 50)
    - category (VARCHAR, max_length 50)
    - is_system_sketch (INT8)
    - deprecated (INT8)
    - feature_vector (FLOAT_VECTOR, dim = RECOMMENDATION_CONFIG["vector_dim"])

    No field is declared nullable and no schema-level default is set here.
    An IVF_FLAT index with the IP (inner product) metric is created on
    ``feature_vector``.

    Raises:
        Exception: any error raised while building the collection or the
            index; it is logged and re-raised.
    """
    client = get_milvus_client()

    # Nothing to do when the collection is already there.
    existing = client.list_collections()
    if MILVUS_COLLECTION_SKETCH_VECTORS in existing:
        logger.info("集合 %s 已存在", MILVUS_COLLECTION_SKETCH_VECTORS)
        return

    try:
        # The ORM-style API below needs host/port rather than a URI.
        host, port = _parse_host_port(settings.MILVUS_URL)

        # Open the ORM connection under the configured alias.
        # NOTE(review): the URL scheme is not forwarded, so an https URL is
        # connected without TLS here — confirm whether `secure=True` is needed.
        try:
            connections.connect(
                alias=settings.MILVUS_ALIAS,
                host=host,
                port=port,
                token=settings.MILVUS_TOKEN or None,
            )
            logger.info("已连接到 Milvus: %s:%s", host, port)
        except Exception as conn_e:
            # Best effort: an existing connection is fine, and any other
            # failure is only logged; the Collection call below will raise
            # if the connection is really unusable.
            if "already exists" in str(conn_e).lower():
                logger.info("Milvus 连接已存在")
            else:
                logger.warning("连接 Milvus 时出现警告: %s", conn_e)

        fields = [
            FieldSchema(name="path", dtype=DataType.VARCHAR, is_primary=True, max_length=512),
            FieldSchema(name="sys_file_id", dtype=DataType.INT64),
            FieldSchema(name="style", dtype=DataType.VARCHAR, max_length=50),
            FieldSchema(name="category", dtype=DataType.VARCHAR, max_length=50),
            FieldSchema(name="is_system_sketch", dtype=DataType.INT8),
            FieldSchema(name="deprecated", dtype=DataType.INT8),
            FieldSchema(
                name="feature_vector",
                dtype=DataType.FLOAT_VECTOR,
                dim=RECOMMENDATION_CONFIG["vector_dim"],
            ),
        ]

        schema = CollectionSchema(
            fields=fields,
            description="Sketch vectors collection for recommendation system",
        )

        collection = Collection(
            name=MILVUS_COLLECTION_SKETCH_VECTORS,
            schema=schema,
            using=settings.MILVUS_ALIAS,
        )

        # The metric must match the one used at search time (IP). For
        # L2-normalised vectors, inner product ranks the same as cosine.
        index_params = {
            "metric_type": "IP",
            "index_type": "IVF_FLAT",
            "params": {"nlist": 1024},
        }
        collection.create_index(
            field_name="feature_vector",
            index_params=index_params,
        )

        logger.info("集合 %s 创建成功", MILVUS_COLLECTION_SKETCH_VECTORS)
    except Exception as e:
        logger.error("创建集合失败: %s", e, exc_info=True)
        raise


def insert_vectors(data: List[Dict[str, Any]]):
    """Insert a batch of rows into the sketch-vector collection.

    Args:
        data: rows to insert. Each row is a dict with:
            - path: str
            - sys_file_id: int
            - style: str
            - category: str
            - is_system_sketch: int (filled with 1 when missing)
            - deprecated: int (filled with 0 when missing)
            - feature_vector: List[float]
          An empty list is a no-op.

    Raises:
        Exception: any error raised by the Milvus insert; logged and
            re-raised.
    """
    if not data:
        return

    client = get_milvus_client()

    # Apply the documented defaults on copies so the caller's dicts are
    # left untouched; explicit values in a row always win.
    rows = [{"is_system_sketch": 1, "deprecated": 0, **row} for row in data]

    try:
        client.insert(
            collection_name=MILVUS_COLLECTION_SKETCH_VECTORS,
            data=rows,
        )
        logger.info("成功插入 %s 条向量数据", len(rows))
    except Exception as e:
        logger.error("插入向量失败: %s", e, exc_info=True)
        raise


def query_vectors_by_paths(paths: List[str]) -> Dict[str, Dict]:
    """Fetch stored rows for the given primary keys.

    Args:
        paths: primary-key values to look up.

    Returns:
        Mapping ``{path: row}`` for every path that was found, where each
        row carries path, feature_vector, style, category, sys_file_id,
        is_system_sketch and deprecated. Returns an empty dict when *paths*
        is empty or when the query fails (the error is logged).
    """
    if not paths:
        return {}

    client = get_milvus_client()

    try:
        # MilvusClient takes the expression through `filter`; every value is
        # emitted as an escaped single-quoted literal.
        path_list = ", ".join(_quote_str(p) for p in paths)
        filter_expr = f"path in [{path_list}]"

        rows = client.query(
            collection_name=MILVUS_COLLECTION_SKETCH_VECTORS,
            filter=filter_expr,
            output_fields=[
                "path",
                "feature_vector",
                "style",
                "category",
                "sys_file_id",
                "is_system_sketch",
                "deprecated",
            ],
        )

        return {row["path"]: row for row in rows}
    except Exception as e:
        logger.error("查询向量失败: %s", e, exc_info=True)
        return {}


def _run_search(client: MilvusClient, vector: np.ndarray, filter_expr: str, limit: int):
    """Run one IP-metric ANN search on ``feature_vector`` with a filter."""
    return client.search(
        collection_name=MILVUS_COLLECTION_SKETCH_VECTORS,
        data=[vector.tolist()],
        anns_field="feature_vector",
        # Built per call so no dict is shared between searches.
        search_params={"metric_type": "IP", "params": {"nprobe": 10}},
        limit=limit,
        filter=filter_expr,
        output_fields=["path", "style", "category", "sys_file_id"],
    )


def search_similar_vectors(
    query_vector: np.ndarray,
    category: str,
    topk: int = 500,
    style: Optional[str] = None,
    style_boost_ratio: float = 0.2
) -> List[Dict]:
    """Search for the vectors most similar to *query_vector* within a category.

    Only rows with ``deprecated == 0`` are considered. When *style* is given,
    two searches are run: one restricted to that style with the query scaled
    by ``1 + style_boost_ratio`` (which scales the inner-product score by the
    same factor), and one over all other styles with the unscaled query. The
    two result sets are merged and re-ranked.

    Args:
        query_vector: query embedding; flattened to 1-D before searching.
        category: category the results must belong to.
        topk: maximum number of results per search and in the final list.
        style: optional style whose matches receive the score boost.
        style_boost_ratio: relative boost for style matches (default 0.2,
            i.e. 20%).

    Returns:
        Up to *topk* dicts with keys path, score, style, category and
        sys_file_id, sorted by score in descending order. Returns an empty
        list when the search fails (the error is logged).
    """
    client = get_milvus_client()

    try:
        # Accept lists and (1, dim)-shaped arrays as well as 1-D arrays.
        vector = np.asarray(query_vector).ravel()
        base_filter = f"category == {_quote_str(category)} && deprecated == 0"

        if not style:
            batches = [_run_search(client, vector, base_filter, topk)]
        else:
            style_literal = _quote_str(style)
            batches = [
                # Stage 1: rows of the requested style, with boosted scores.
                _run_search(
                    client,
                    vector * (1 + style_boost_ratio),
                    f"{base_filter} && style == {style_literal}",
                    topk,
                ),
                # Stage 2: rows of every other style, unboosted.
                _run_search(
                    client,
                    vector,
                    f"{base_filter} && style != {style_literal}",
                    topk,
                ),
            ]

        # Each search returns one hit list per query vector; only one query
        # vector is ever sent, so take element 0 of every non-empty response.
        formatted_results = []
        for batch in batches:
            if not batch:
                continue
            for hit in batch[0]:
                entity = hit.get("entity", {})
                formatted_results.append({
                    "path": entity.get("path", ""),
                    "score": hit.get("distance", 0.0),
                    "style": entity.get("style", ""),
                    "category": entity.get("category", ""),
                    "sys_file_id": entity.get("sys_file_id"),
                })

        # Re-rank the merged hits and cut back to topk.
        formatted_results.sort(key=lambda item: item["score"], reverse=True)
        return formatted_results[:topk]
    except Exception as e:
        logger.error("向量检索失败: %s", e, exc_info=True)
        return []


def query_random_candidates(category: str, style: Optional[str] = None, limit: int = 10) -> List[Dict]:
    """Return a random sample of non-deprecated candidates (exploration branch).

    At most ``_RANDOM_CANDIDATE_POOL`` matching rows are fetched, so the
    sample is drawn from that pool rather than from every matching row.

    Args:
        category: category the candidates must belong to.
        style: optional style the candidates must match.
        limit: maximum number of candidates to return.

    Returns:
        Rows with path, style and category. All fetched rows are returned
        when there are no more than *limit* of them. Returns an empty list
        when the query fails (the error is logged).
    """
    client = get_milvus_client()

    try:
        filter_expr = f"category == {_quote_str(category)} && deprecated == 0"
        if style:
            filter_expr += f" && style == {_quote_str(style)}"

        results = client.query(
            collection_name=MILVUS_COLLECTION_SKETCH_VECTORS,
            filter=filter_expr,
            output_fields=["path", "style", "category"],
            limit=_RANDOM_CANDIDATE_POOL,
        )

        if len(results) > limit:
            results = random.sample(results, limit)

        return results
    except Exception as e:
        logger.error("随机查询候选失败: %s", e, exc_info=True)
        return []