Files
AiDA_Python/app/service/recommendation_system/milvus_client.py
zcr 2a6c48d937
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
新推荐接口first commit
2025-12-30 17:23:36 +08:00

295 lines
9.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Milvus 客户端封装
"""
import logging
from typing import List, Dict, Optional, Any
import numpy as np
from pymilvus import MilvusClient, FieldSchema, CollectionSchema, DataType, connections, Collection
from app.core.config import settings
from app.service.recommendation_system.config import MILVUS_COLLECTION_SKETCH_VECTORS, RECOMMENDATION_CONFIG
logger = logging.getLogger(__name__)
# Module-level slot for the shared Milvus client (singleton).
# It stays None until get_milvus_client() creates the client on first use.
_milvus_client: Optional[MilvusClient] = None
def get_milvus_client() -> MilvusClient:
    """Return the shared Milvus client, creating it on the first call.

    The client is built from ``settings.MILVUS_URL``, ``settings.MILVUS_TOKEN``
    and ``settings.MILVUS_DB`` and stored in the module-level
    ``_milvus_client`` variable, so subsequent calls reuse the same instance.

    Returns:
        The cached ``MilvusClient`` instance.

    Raises:
        Exception: any error raised while constructing the client is logged
            and re-raised unchanged; nothing is cached in that case.
    """
    global _milvus_client

    # Fast path: the client has already been created.
    if _milvus_client is not None:
        return _milvus_client

    try:
        new_client = MilvusClient(
            uri=settings.MILVUS_URL,
            token=settings.MILVUS_TOKEN,
            db_name=settings.MILVUS_DB,
        )
        _milvus_client = new_client
        logger.info("Milvus 客户端连接成功")
    except Exception as e:
        logger.error(f"Milvus 客户端连接失败: {e}")
        raise

    return _milvus_client
def create_collection():
    """Create the sketch-vector collection and its vector index if missing.

    The collection name comes from ``MILVUS_COLLECTION_SKETCH_VECTORS``. If the
    shared ``MilvusClient`` already lists it, the function logs that fact and
    returns without doing anything else.

    Schema declared by this function:
        - path (VARCHAR(512), primary key): MinIO logical URL.
        - sys_file_id (INT64): system file ID.
        - style (VARCHAR(50)): style label.
        - category (VARCHAR(50)): category label.
        - is_system_sketch (INT8): marker, 1 = system sketch, 0 = user sketch.
        - deprecated (INT8): deprecation flag.
        - feature_vector (FLOAT_VECTOR): dimension is
          ``RECOMMENDATION_CONFIG["vector_dim"]``.

    NOTE(review): earlier documentation of this function described
    sys_file_id/style/category as nullable, category as VARCHAR(100), and
    defaults of 1/0 for is_system_sketch/deprecated. The schema below declares
    no ``nullable``/``default_value`` and uses max_length=50 for category --
    confirm which is intended before relying on either.

    The collection is created through the ORM-style API (``connections`` /
    ``Collection``). The ORM connection is opened against the same host, port,
    token and database as the shared ``MilvusClient`` so the collection ends up
    in the database that the rest of this module reads from and writes to.

    Raises:
        Exception: any failure while parsing the URL, building the schema,
            creating the collection or creating the index is logged with a
            traceback and re-raised. Errors from ``get_milvus_client()`` or
            ``list_collections()`` propagate without the extra log entry.
    """
    # Local import keeps this block self-contained (stdlib only).
    from urllib.parse import urlparse

    client = get_milvus_client()

    # Skip creation when the collection already exists.
    if MILVUS_COLLECTION_SKETCH_VECTORS in client.list_collections():
        logger.info(f"集合 {MILVUS_COLLECTION_SKETCH_VECTORS} 已存在")
        return

    try:
        # Parse the Milvus URL, e.g. http://host.docker.internal:19530.
        # urlparse copes with trailing slashes/paths and arbitrary schemes;
        # a bare "host:port" needs a leading "//" to be read as a netloc.
        raw_url = settings.MILVUS_URL
        parsed = urlparse(raw_url if "://" in raw_url else f"//{raw_url}")
        host = parsed.hostname or ""
        port = parsed.port or 19530  # Milvus default port when none is given

        # Open the ORM connection (if not connected yet).
        try:
            connections.connect(
                alias=settings.MILVUS_ALIAS,
                host=host,
                port=port,
                token=settings.MILVUS_TOKEN if settings.MILVUS_TOKEN else None,
                # Target the same database as the shared MilvusClient;
                # otherwise the collection would land in "default".
                db_name=settings.MILVUS_DB or "default",
                # Mirror the URL scheme: https means TLS.
                secure=parsed.scheme == "https",
            )
            logger.info(f"已连接到 Milvus: {host}:{port}")
        except Exception as conn_e:
            # Best effort: an existing connection is fine; anything else is
            # only logged here and will surface when the collection is created.
            if "already exists" in str(conn_e).lower():
                logger.info("Milvus 连接已存在")
            else:
                logger.warning(f"连接 Milvus 时出现警告: {conn_e}")

        # Field definitions.
        fields = [
            FieldSchema(name="path", dtype=DataType.VARCHAR, is_primary=True, max_length=512),
            FieldSchema(name="sys_file_id", dtype=DataType.INT64),
            FieldSchema(name="style", dtype=DataType.VARCHAR, max_length=50),
            FieldSchema(name="category", dtype=DataType.VARCHAR, max_length=50),
            FieldSchema(name="is_system_sketch", dtype=DataType.INT8),
            FieldSchema(name="deprecated", dtype=DataType.INT8),
            FieldSchema(
                name="feature_vector",
                dtype=DataType.FLOAT_VECTOR,
                dim=RECOMMENDATION_CONFIG["vector_dim"],
            ),
        ]

        schema = CollectionSchema(
            fields=fields,
            description="Sketch vectors collection for recommendation system",
        )

        collection = Collection(
            name=MILVUS_COLLECTION_SKETCH_VECTORS,
            schema=schema,
            using=settings.MILVUS_ALIAS,
        )

        # Index on the vector field. IP (inner product) is used so that the
        # index metric matches the metric used at search time; for normalized
        # vectors IP is equivalent to COSINE.
        index_params = {
            "metric_type": "IP",
            "index_type": "IVF_FLAT",
            "params": {"nlist": 1024},
        }
        collection.create_index(
            field_name="feature_vector",
            index_params=index_params,
        )
        logger.info(f"集合 {MILVUS_COLLECTION_SKETCH_VECTORS} 创建成功")
    except Exception as e:
        logger.error(f"创建集合失败: {e}", exc_info=True)
        raise
def insert_vectors(data: List[Dict[str, Any]]):
    """Insert a batch of rows into the sketch-vector collection.

    Args:
        data: rows to insert. According to the module's documentation each
            row is a dict with:
            - path: str
            - sys_file_id: int (optional)
            - style: str (optional)
            - category: str (optional)
            - is_system_sketch: int (default 1)
            - deprecated: int (default 0)
            - feature_vector: List[float] (documented as 2048-dimensional)
            An empty or falsy ``data`` makes the call a no-op.

    Raises:
        Exception: any error raised by the Milvus insert is logged with a
            traceback and re-raised.
    """
    # Nothing to insert: return before touching the client.
    if not data:
        return

    milvus = get_milvus_client()
    try:
        milvus.insert(
            collection_name=MILVUS_COLLECTION_SKETCH_VECTORS,
            data=data,
        )
        logger.info(f"成功插入 {len(data)} 条向量数据")
    except Exception as e:
        logger.error(f"插入向量失败: {e}", exc_info=True)
        raise
def query_vectors_by_paths(paths: List[str]) -> Dict[str, Dict]:
    """Fetch the stored rows for a list of ``path`` primary keys.

    Args:
        paths: primary-key values to look up. An empty or falsy value returns
            ``{}`` without contacting Milvus.

    Returns:
        A dict mapping each found ``path`` to its row. Each row is requested
        with the fields path, feature_vector, style, category, sys_file_id,
        is_system_sketch and deprecated. Paths that are not found are simply
        absent. If the query fails, the error is logged and ``{}`` is returned.
    """
    if not paths:
        return {}

    client = get_milvus_client()
    try:
        # Build "path in ['a', 'b', ...]" for MilvusClient's ``filter``
        # argument. Backslashes and single quotes are escaped so that a path
        # containing them cannot break out of its string literal.
        quoted_paths = ", ".join(
            "'" + str(p).replace("\\", "\\\\").replace("'", "\\'") + "'"
            for p in paths
        )
        filter_expr = f"path in [{quoted_paths}]"

        results = client.query(
            collection_name=MILVUS_COLLECTION_SKETCH_VECTORS,
            filter=filter_expr,
            output_fields=["path", "feature_vector", "style", "category", "sys_file_id", "is_system_sketch", "deprecated"],
        )

        # Index the rows by their primary key.
        return {row["path"]: row for row in results}
    except Exception as e:
        logger.error(f"查询向量失败: {e}", exc_info=True)
        return {}
def search_similar_vectors(
    query_vector: np.ndarray,
    category: str,
    topk: int = 500,
    style: Optional[str] = None
) -> List[Dict]:
    """Run a vector similarity search restricted to one category.

    Only rows with ``deprecated == 0`` and the given ``category`` (and
    ``style``, when provided) are considered. The search uses the inner
    product (IP) metric with ``nprobe`` 10 on the ``feature_vector`` field.

    Args:
        query_vector: the query embedding. It is flattened to one dimension
            before the search, so a ``(dim,)`` or ``(1, dim)`` array (or a
            plain sequence of floats) is accepted.
        category: category value the hits must match.
        topk: maximum number of hits to return. Values <= 0 return ``[]``
            without contacting Milvus.
        style: optional style value the hits must match.

    Returns:
        A list of dicts with the keys ``path``, ``score`` (the hit's
        ``distance``), ``style``, ``category`` and ``sys_file_id``. If the
        search fails, the error is logged and ``[]`` is returned.
    """
    # A non-positive limit cannot yield hits; skip the round trip.
    if topk <= 0:
        return []

    client = get_milvus_client()
    try:
        def _quote(value) -> str:
            # Render a value as a single-quoted filter string literal,
            # escaping backslashes and quotes so it cannot alter the filter.
            escaped = str(value).replace("\\", "\\\\").replace("'", "\\'")
            return f"'{escaped}'"

        # Filter expression passed through MilvusClient's ``filter`` argument.
        filter_expr = f"category == {_quote(category)} && deprecated == 0"
        if style:
            filter_expr += f" && style == {_quote(style)}"

        flat_vector = np.asarray(query_vector).reshape(-1).tolist()

        results = client.search(
            collection_name=MILVUS_COLLECTION_SKETCH_VECTORS,
            data=[flat_vector],
            anns_field="feature_vector",
            search_params={"metric_type": "IP", "params": {"nprobe": 10}},
            limit=topk,
            filter=filter_expr,
            output_fields=["path", "style", "category", "sys_file_id"],
        )

        # One query vector was sent, so only the first result list matters.
        hits = results[0] if results else []
        formatted_results = []
        for hit in hits:
            entity = hit.get("entity", {})
            formatted_results.append({
                "path": entity.get("path", ""),
                "score": hit.get("distance", 0.0),
                "style": entity.get("style", ""),
                "category": entity.get("category", ""),
                "sys_file_id": entity.get("sys_file_id"),
            })
        return formatted_results
    except Exception as e:
        logger.error(f"向量检索失败: {e}", exc_info=True)
        return []
def query_random_candidates(
    category: str,
    style: Optional[str] = None,
    limit: int = 10,
    *,
    pool_size: int = 10000
) -> List[Dict]:
    """Return a random sample of candidates (used for the exploration branch).

    The function queries up to ``pool_size`` non-deprecated rows matching
    ``category`` (and ``style``, when provided) and, if more than ``limit``
    rows came back, picks ``limit`` of them with ``random.sample``. The
    sample is therefore drawn only from the rows Milvus returned, not from
    every matching row when more than ``pool_size`` exist.

    Args:
        category: category value the candidates must match.
        style: optional style value the candidates must match.
        limit: number of candidates to return. Values <= 0 return ``[]``
            without contacting Milvus.
        pool_size: maximum number of rows fetched before sampling
            (keyword-only, default 10000).
            NOTE(review): Milvus enforces an upper bound on query limits --
            confirm the server's cap before raising this.

    Returns:
        A list of rows with the fields ``path``, ``style`` and ``category``.
        If the query fails, the error is logged and ``[]`` is returned.
    """
    import random

    # A non-positive sample size always yields an empty result.
    if limit <= 0:
        return []

    client = get_milvus_client()
    try:
        def _quote(value) -> str:
            # Render a value as a single-quoted filter string literal,
            # escaping backslashes and quotes so it cannot alter the filter.
            escaped = str(value).replace("\\", "\\\\").replace("'", "\\'")
            return f"'{escaped}'"

        filter_expr = f"category == {_quote(category)} && deprecated == 0"
        if style:
            filter_expr += f" && style == {_quote(style)}"

        # Fetch a large pool first, then sample from it locally.
        results = client.query(
            collection_name=MILVUS_COLLECTION_SKETCH_VECTORS,
            filter=filter_expr,
            output_fields=["path", "style", "category"],
            limit=pool_size,
        )

        if len(results) > limit:
            results = random.sample(results, limit)
        return results
    except Exception as e:
        logger.error(f"随机查询候选失败: {e}", exc_info=True)
        return []