""" Milvus 客户端封装 """ import logging from typing import List, Dict, Optional, Any import numpy as np from pymilvus import MilvusClient, FieldSchema, CollectionSchema, DataType, connections, Collection from app.core.config import settings from app.service.recommendation_system.config import MILVUS_COLLECTION_SKETCH_VECTORS, RECOMMENDATION_CONFIG logger = logging.getLogger(__name__) # Milvus 客户端(单例) _milvus_client = None def get_milvus_client() -> MilvusClient: """获取 Milvus 客户端(单例模式)""" global _milvus_client if _milvus_client is None: try: _milvus_client = MilvusClient( uri=settings.MILVUS_URL, token=settings.MILVUS_TOKEN, db_name=settings.MILVUS_DB, ) logger.info("Milvus 客户端连接成功") except Exception as e: logger.error(f"Milvus 客户端连接失败: {e}") raise return _milvus_client def create_collection(): """ 创建 Milvus 集合 sketch_vectors 集合结构: - path (PK, varchar(512)) - 主键,MinIO 逻辑 URL - sys_file_id (int64, 可为NULL) - 系统文件ID - style (varchar(50), 可为NULL) - 风格样式 - category (varchar(100), 可为NULL) - 类别 - is_system_sketch (int8, 默认 1) - 标记字段:1-系统图,0-用户图 - deprecated (int8, 默认 0) - 是否废弃 - feature_vector (FloatVector(2048)) - 2048维特征向量 """ client = get_milvus_client() # 检查集合是否已存在 collections = client.list_collections() if MILVUS_COLLECTION_SKETCH_VECTORS in collections: logger.info(f"集合 {MILVUS_COLLECTION_SKETCH_VECTORS} 已存在") return try: # 解析 Milvus URL # 处理 http://host.docker.internal:19530 格式 url_clean = settings.MILVUS_URL.replace("http://", "").replace("https://", "") if ":" in url_clean: host, port_str = url_clean.split(":", 1) port = int(port_str) else: host = url_clean port = 19530 # 使用传统 API 创建集合(更可靠) # 连接到 Milvus(如果未连接) try: connections.connect( alias=settings.MILVUS_ALIAS, host=host, port=port, token=settings.MILVUS_TOKEN if settings.MILVUS_TOKEN else None ) logger.info(f"已连接到 Milvus: {host}:{port}") except Exception as conn_e: # 如果连接已存在,忽略错误 if "already exists" in str(conn_e).lower() or "Connection already exists" in str(conn_e): logger.info("Milvus 连接已存在") else: logger.warning(f"连接 Milvus 时出现警告: {conn_e}") # 定义字段 fields = [ FieldSchema(name="path", dtype=DataType.VARCHAR, is_primary=True, max_length=512), FieldSchema(name="sys_file_id", dtype=DataType.INT64), FieldSchema(name="style", dtype=DataType.VARCHAR, max_length=50), FieldSchema(name="category", dtype=DataType.VARCHAR, max_length=50), FieldSchema(name="is_system_sketch", dtype=DataType.INT8), FieldSchema(name="deprecated", dtype=DataType.INT8), FieldSchema( name="feature_vector", dtype=DataType.FLOAT_VECTOR, dim=RECOMMENDATION_CONFIG["vector_dim"] ) ] # 创建 schema schema = CollectionSchema( fields=fields, description="Sketch vectors collection for recommendation system" ) # 创建集合 collection = Collection( name=MILVUS_COLLECTION_SKETCH_VECTORS, schema=schema, using=settings.MILVUS_ALIAS ) # 创建索引 # 注意:使用 IP(内积)作为度量类型,与搜索时保持一致 # 如果向量已归一化,IP 等价于 COSINE index_params = { "metric_type": "IP", # 内积(Inner Product) "index_type": "IVF_FLAT", "params": {"nlist": 1024} } collection.create_index( field_name="feature_vector", index_params=index_params ) logger.info(f"集合 {MILVUS_COLLECTION_SKETCH_VECTORS} 创建成功") except Exception as e: logger.error(f"创建集合失败: {e}", exc_info=True) raise def insert_vectors(data: List[Dict[str, Any]]): """ 批量插入向量到 Milvus Args: data: 数据列表,每个元素包含: - path: str - sys_file_id: int (可选) - style: str (可选) - category: str (可选) - is_system_sketch: int (默认 1) - deprecated: int (默认 0) - feature_vector: List[float] (2048维) """ if not data: return client = get_milvus_client() try: client.insert( collection_name=MILVUS_COLLECTION_SKETCH_VECTORS, data=data ) logger.info(f"成功插入 {len(data)} 条向量数据") except Exception as e: logger.error(f"插入向量失败: {e}", exc_info=True) raise def query_vectors_by_paths(paths: List[str]) -> Dict[str, Dict]: """ 根据 path 列表批量查询向量 Args: paths: path 列表 Returns: {path: {feature_vector: [...], ...}} 字典 """ if not paths: return {} client = get_milvus_client() try: # 构建查询表达式 # 使用 filter 参数而不是 expr(根据 pymilvus MilvusClient API) # 对于字符串列表,使用单引号包裹每个值 path_list = ", ".join([f"'{p}'" for p in paths]) filter_expr = f"path in [{path_list}]" results = client.query( collection_name=MILVUS_COLLECTION_SKETCH_VECTORS, filter=filter_expr, output_fields=["path", "feature_vector", "style", "category", "sys_file_id", "is_system_sketch", "deprecated"] ) # 转换为字典 result_dict = {} for r in results: result_dict[r["path"]] = r return result_dict except Exception as e: logger.error(f"查询向量失败: {e}", exc_info=True) return {} def search_similar_vectors( query_vector: np.ndarray, category: str, topk: int = 500, style: Optional[str] = None ) -> List[Dict]: """ 向量相似度检索 Args: query_vector: 查询向量(2048维) category: 类别过滤 topk: 返回数量 style: 风格过滤(可选) Returns: 检索结果列表,每个元素包含 path, score, style, category 等字段 """ client = get_milvus_client() try: # 构建过滤表达式 # 使用 filter 参数而不是 expr(根据 pymilvus MilvusClient API) filter_expr = f"category == '{category}' && deprecated == 0" if style: filter_expr += f" && style == '{style}'" # 搜索 results = client.search( collection_name=MILVUS_COLLECTION_SKETCH_VECTORS, data=[query_vector.tolist()], anns_field="feature_vector", search_params={"metric_type": "IP", "params": {"nprobe": 10}}, limit=topk, filter=filter_expr, output_fields=["path", "style", "category", "sys_file_id"] ) # 格式化结果 formatted_results = [] if results and len(results) > 0: for hit in results[0]: formatted_results.append({ "path": hit.get("entity", {}).get("path", ""), "score": hit.get("distance", 0.0), "style": hit.get("entity", {}).get("style", ""), "category": hit.get("entity", {}).get("category", ""), "sys_file_id": hit.get("entity", {}).get("sys_file_id") }) return formatted_results except Exception as e: logger.error(f"向量检索失败: {e}", exc_info=True) return [] def query_random_candidates(category: str, style: Optional[str] = None, limit: int = 10) -> List[Dict]: """ 随机查询候选(用于探索分支) Args: category: 类别 style: 风格(可选) limit: 返回数量 Returns: 候选列表 """ client = get_milvus_client() try: # 构建过滤表达式 filter_expr = f"category == '{category}' && deprecated == 0" if style: filter_expr += f" && style == '{style}'" # 查询所有符合条件的记录 results = client.query( collection_name=MILVUS_COLLECTION_SKETCH_VECTORS, filter=filter_expr, output_fields=["path", "style", "category"], limit=10000 # 先查询大量数据,然后随机选择 ) # 随机选择 if len(results) > limit: import random results = random.sample(results, limit) return results except Exception as e: logger.error(f"随机查询候选失败: {e}", exc_info=True) return []