""" Milvus 客户端封装 """ import logging from typing import List, Dict, Optional, Any import numpy as np from pymilvus import MilvusClient, FieldSchema, CollectionSchema, DataType, connections, Collection from app.core.config import MILVUS_URL, MILVUS_TOKEN, MILVUS_ALIAS from app.service.recommendation_system.config import MILVUS_COLLECTION_SKETCH_VECTORS, RECOMMENDATION_CONFIG logger = logging.getLogger(__name__) # Milvus 客户端(单例) _milvus_client = None def get_milvus_client() -> MilvusClient: """获取 Milvus 客户端(单例模式)""" global _milvus_client if _milvus_client is None: try: _milvus_client = MilvusClient( uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS ) logger.info("Milvus 客户端连接成功") except Exception as e: logger.error(f"Milvus 客户端连接失败: {e}") raise return _milvus_client def create_collection(): """ 创建 Milvus 集合 sketch_vectors 集合结构: - path (PK, varchar(512)) - 主键,MinIO 逻辑 URL - sys_file_id (int64, 可为NULL) - 系统文件ID - style (varchar(50), 可为NULL) - 风格样式 - category (varchar(100), 可为NULL) - 类别 - is_system_sketch (int8, 默认 1) - 标记字段:1-系统图,0-用户图 - deprecated (int8, 默认 0) - 是否废弃 - feature_vector (FloatVector(2048)) - 2048维特征向量 """ client = get_milvus_client() # 检查集合是否已存在 collections = client.list_collections() if MILVUS_COLLECTION_SKETCH_VECTORS in collections: logger.info(f"集合 {MILVUS_COLLECTION_SKETCH_VECTORS} 已存在") return try: # 解析 Milvus URL # 处理 http://host.docker.internal:19530 格式 url_clean = MILVUS_URL.replace("http://", "").replace("https://", "") if ":" in url_clean: host, port_str = url_clean.split(":", 1) port = int(port_str) else: host = url_clean port = 19530 # 使用传统 API 创建集合(更可靠) # 连接到 Milvus(如果未连接) try: connections.connect( alias=MILVUS_ALIAS, host=host, port=port, token=MILVUS_TOKEN if MILVUS_TOKEN else None ) logger.info(f"已连接到 Milvus: {host}:{port}") except Exception as conn_e: # 如果连接已存在,忽略错误 if "already exists" in str(conn_e).lower() or "Connection already exists" in str(conn_e): logger.info("Milvus 连接已存在") else: logger.warning(f"连接 Milvus 时出现警告: {conn_e}") # 定义字段 fields = [ FieldSchema(name="path", dtype=DataType.VARCHAR, is_primary=True, max_length=512), FieldSchema(name="sys_file_id", dtype=DataType.INT64), FieldSchema(name="style", dtype=DataType.VARCHAR, max_length=50), FieldSchema(name="category", dtype=DataType.VARCHAR, max_length=50), FieldSchema(name="is_system_sketch", dtype=DataType.INT8), FieldSchema(name="deprecated", dtype=DataType.INT8), FieldSchema( name="feature_vector", dtype=DataType.FLOAT_VECTOR, dim=RECOMMENDATION_CONFIG["vector_dim"] ) ] # 创建 schema schema = CollectionSchema( fields=fields, description="Sketch vectors collection for recommendation system" ) # 创建集合 collection = Collection( name=MILVUS_COLLECTION_SKETCH_VECTORS, schema=schema, using=MILVUS_ALIAS ) # 创建索引 # 注意:使用 IP(内积)作为度量类型,与搜索时保持一致 # 如果向量已归一化,IP 等价于 COSINE index_params = { "metric_type": "IP", # 内积(Inner Product) "index_type": "IVF_FLAT", "params": {"nlist": 1024} } collection.create_index( field_name="feature_vector", index_params=index_params ) logger.info(f"集合 {MILVUS_COLLECTION_SKETCH_VECTORS} 创建成功") except Exception as e: logger.error(f"创建集合失败: {e}", exc_info=True) raise def insert_vectors(data: List[Dict[str, Any]]): """ 批量插入向量到 Milvus Args: data: 数据列表,每个元素包含: - path: str - sys_file_id: int (可选) - style: str (可选) - category: str (可选) - is_system_sketch: int (默认 1) - deprecated: int (默认 0) - feature_vector: List[float] (2048维) """ if not data: return client = get_milvus_client() try: client.insert( collection_name=MILVUS_COLLECTION_SKETCH_VECTORS, data=data ) logger.info(f"成功插入 {len(data)} 条向量数据") except Exception as e: logger.error(f"插入向量失败: {e}", exc_info=True) raise def query_vectors_by_paths(paths: List[str]) -> Dict[str, Dict]: """ 根据 path 列表批量查询向量 Args: paths: path 列表 Returns: {path: {feature_vector: [...], ...}} 字典 """ if not paths: return {} client = get_milvus_client() try: # 构建查询表达式 # 使用 filter 参数而不是 expr(根据 pymilvus MilvusClient API) # 对于字符串列表,使用单引号包裹每个值 path_list = ", ".join([f"'{p}'" for p in paths]) filter_expr = f"path in [{path_list}]" results = client.query( collection_name=MILVUS_COLLECTION_SKETCH_VECTORS, filter=filter_expr, output_fields=["path", "feature_vector", "style", "category", "sys_file_id", "is_system_sketch", "deprecated"] ) # 转换为字典 result_dict = {} for r in results: result_dict[r["path"]] = r return result_dict except Exception as e: logger.error(f"查询向量失败: {e}", exc_info=True) return {} def search_similar_vectors( query_vector: np.ndarray, category: str, topk: int = 500, style: Optional[str] = None ) -> List[Dict]: """ 向量相似度检索 Args: query_vector: 查询向量(2048维) category: 类别过滤 topk: 返回数量 style: 风格过滤(可选) Returns: 检索结果列表,每个元素包含 path, score, style, category 等字段 """ client = get_milvus_client() try: # 构建过滤表达式 # 使用 filter 参数而不是 expr(根据 pymilvus MilvusClient API) filter_expr = f"category == '{category}' && deprecated == 0" if style: filter_expr += f" && style == '{style}'" # 搜索 results = client.search( collection_name=MILVUS_COLLECTION_SKETCH_VECTORS, data=[query_vector.tolist()], anns_field="feature_vector", search_params={"metric_type": "IP", "params": {"nprobe": 10}}, limit=topk, filter=filter_expr, output_fields=["path", "style", "category", "sys_file_id"] ) # 格式化结果 formatted_results = [] if results and len(results) > 0: for hit in results[0]: formatted_results.append({ "path": hit.get("entity", {}).get("path", ""), "score": hit.get("distance", 0.0), "style": hit.get("entity", {}).get("style", ""), "category": hit.get("entity", {}).get("category", ""), "sys_file_id": hit.get("entity", {}).get("sys_file_id") }) return formatted_results except Exception as e: logger.error(f"向量检索失败: {e}", exc_info=True) return [] def query_random_candidates(category: str, style: Optional[str] = None, limit: int = 10) -> List[Dict]: """ 随机查询候选(用于探索分支) Args: category: 类别 style: 风格(可选) limit: 返回数量 Returns: 候选列表 """ client = get_milvus_client() try: # 构建过滤表达式 filter_expr = f"category == '{category}' && deprecated == 0" if style: filter_expr += f" && style == '{style}'" # 查询所有符合条件的记录 results = client.query( collection_name=MILVUS_COLLECTION_SKETCH_VECTORS, filter=filter_expr, output_fields=["path", "style", "category"], limit=10000 # 先查询大量数据,然后随机选择 ) # 随机选择 if len(results) > limit: import random results = random.sample(results, limit) return results except Exception as e: logger.error(f"随机查询候选失败: {e}", exc_info=True) return []