2025-12-29 10:52:33 +08:00
|
|
|
|
"""
|
|
|
|
|
|
Milvus client wrapper: connection handling, collection setup, insert, query and vector search helpers.
|
|
|
|
|
|
"""
|
|
|
|
|
|
import logging
|
|
|
|
|
|
from typing import List, Dict, Optional, Any
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
|
from pymilvus import MilvusClient, FieldSchema, CollectionSchema, DataType, connections, Collection
|
|
|
|
|
|
|
2025-12-30 17:23:36 +08:00
|
|
|
|
from app.core.config import settings
|
2025-12-29 10:52:33 +08:00
|
|
|
|
from app.service.recommendation_system.config import MILVUS_COLLECTION_SKETCH_VECTORS, RECOMMENDATION_CONFIG
|
|
|
|
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
# Module-level MilvusClient singleton; created lazily by get_milvus_client().
|
|
|
|
|
|
_milvus_client = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_milvus_client() -> MilvusClient:
    """Return the shared MilvusClient, building it on the first call.

    The client is constructed from ``settings.MILVUS_URL`` and
    ``settings.MILVUS_TOKEN`` and stored in the module-level
    ``_milvus_client`` variable; later calls return that same object.

    Raises:
        Exception: anything raised while constructing ``MilvusClient``.
            The failure is logged and re-raised unchanged.
    """
    global _milvus_client

    # Guard clause: reuse the instance created by an earlier call.
    if _milvus_client is not None:
        return _milvus_client

    try:
        _milvus_client = MilvusClient(
            uri=settings.MILVUS_URL,
            token=settings.MILVUS_TOKEN,
            db_name="",
        )
        logger.info("Milvus 客户端连接成功")
    except Exception as exc:
        logger.error(f"Milvus 客户端连接失败: {exc}")
        raise

    return _milvus_client
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_collection():
    """Create the Milvus collection named by MILVUS_COLLECTION_SKETCH_VECTORS.

    Does nothing (apart from logging) when the collection already exists.
    Otherwise it opens an ORM connection under ``settings.MILVUS_ALIAS``,
    creates the collection and builds an IVF_FLAT / IP index on the vector
    field.

    Schema, as declared below:
        - path (VARCHAR, max_length=512): primary key
        - sys_file_id (INT64)
        - style (VARCHAR, max_length=50)
        - category (VARCHAR, max_length=50)
        - is_system_sketch (INT8)
        - deprecated (INT8)
        - feature_vector (FLOAT_VECTOR): dimension taken from
          ``RECOMMENDATION_CONFIG["vector_dim"]``

    NOTE(review): no field is declared with ``nullable`` or ``default_value``
    here, so inserts presumably must supply every field -- confirm against
    the Milvus version in use.
    NOTE(review): the collection is not loaded in this function; verify that
    it is loaded elsewhere before it is searched.

    Raises:
        Exception: any error raised while parsing the URL or creating the
            collection/index; it is logged with a traceback and re-raised.
    """
    client = get_milvus_client()

    # Nothing to do when the collection is already there.
    if MILVUS_COLLECTION_SKETCH_VECTORS in client.list_collections():
        logger.info(f"集合 {MILVUS_COLLECTION_SKETCH_VECTORS} 已存在")
        return

    try:
        # The ORM connection API takes host/port rather than a URI.
        host, port = _split_milvus_host_port(settings.MILVUS_URL)

        # The collection is created through the ORM API (Collection), which
        # needs a registered connection under the configured alias.
        try:
            connections.connect(
                alias=settings.MILVUS_ALIAS,
                host=host,
                port=port,
                token=settings.MILVUS_TOKEN or None,
            )
            logger.info(f"已连接到 Milvus: {host}:{port}")
        except Exception as conn_e:
            # Best effort: an existing connection is fine; anything else is
            # only logged here and will surface when Collection() is built.
            if "already exists" in str(conn_e).lower():
                logger.info("Milvus 连接已存在")
            else:
                logger.warning(f"连接 Milvus 时出现警告: {conn_e}")

        fields = [
            FieldSchema(name="path", dtype=DataType.VARCHAR, is_primary=True, max_length=512),
            FieldSchema(name="sys_file_id", dtype=DataType.INT64),
            FieldSchema(name="style", dtype=DataType.VARCHAR, max_length=50),
            FieldSchema(name="category", dtype=DataType.VARCHAR, max_length=50),
            FieldSchema(name="is_system_sketch", dtype=DataType.INT8),
            FieldSchema(name="deprecated", dtype=DataType.INT8),
            FieldSchema(
                name="feature_vector",
                dtype=DataType.FLOAT_VECTOR,
                dim=RECOMMENDATION_CONFIG["vector_dim"],
            ),
        ]

        schema = CollectionSchema(
            fields=fields,
            description="Sketch vectors collection for recommendation system",
        )

        collection = Collection(
            name=MILVUS_COLLECTION_SKETCH_VECTORS,
            schema=schema,
            using=settings.MILVUS_ALIAS,
        )

        # The index metric must match the metric used at search time ("IP").
        # For normalised vectors inner product ranks the same as cosine.
        index_params = {
            "metric_type": "IP",
            "index_type": "IVF_FLAT",
            "params": {"nlist": 1024},
        }
        collection.create_index(
            field_name="feature_vector",
            index_params=index_params,
        )

        logger.info(f"集合 {MILVUS_COLLECTION_SKETCH_VECTORS} 创建成功")

    except Exception as e:
        logger.error(f"创建集合失败: {e}", exc_info=True)
        raise


def _split_milvus_host_port(url: str, default_port: int = 19530):
    """Return ``(host, port)`` parsed from a Milvus URL like ``http://host:19530``."""
    from urllib.parse import urlsplit

    # urlsplit only recognises the authority after "//", so a scheme-less
    # value such as "host:19530" gets a bare "//" prefix first.
    parts = urlsplit(url if "://" in url else f"//{url}")

    host = parts.hostname
    if not host:
        # Unparseable authority: fall back to the scheme-stripped string.
        host = url.replace("http://", "").replace("https://", "")

    # parts.port raises ValueError for a non-numeric or out-of-range port.
    port = parts.port if parts.port is not None else default_port
    return host, port
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def insert_vectors(data: List[Dict[str, Any]]):
    """Insert a batch of rows into the sketch-vector collection.

    Args:
        data: rows to insert. Each row is a dict with the keys
            - path: str
            - sys_file_id: int
            - style: str
            - category: str
            - is_system_sketch: int, filled with 1 when the key is absent
            - deprecated: int, filled with 0 when the key is absent
            - feature_vector: list of floats
            An empty or ``None`` batch is a no-op.

    Raises:
        Exception: any error from obtaining the client or from
            ``client.insert``; insert errors are logged with a traceback and
            re-raised.
    """
    if not data:
        return

    # Apply the documented defaults on a copy of each row so the caller's
    # dicts are never mutated; keys already present in a row win.
    rows = [{"is_system_sketch": 1, "deprecated": 0, **row} for row in data]

    client = get_milvus_client()

    try:
        client.insert(
            collection_name=MILVUS_COLLECTION_SKETCH_VECTORS,
            data=rows,
        )
        logger.info(f"成功插入 {len(rows)} 条向量数据")
    except Exception as e:
        logger.error(f"插入向量失败: {e}", exc_info=True)
        raise
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def query_vectors_by_paths(paths: List[str]) -> Dict[str, Dict]:
    """Fetch the stored rows for a list of primary-key paths.

    Args:
        paths: primary-key values to look up. An empty or ``None`` list
            returns ``{}`` without contacting Milvus.

    Returns:
        A dict mapping each path that was found to its row dict (fields:
        path, feature_vector, style, category, sys_file_id,
        is_system_sketch, deprecated). Paths that were not found are absent.
        If the query itself fails, the error is logged and ``{}`` is
        returned.
    """
    if not paths:
        return {}

    def _quote(value) -> str:
        # Escape backslashes first, then single quotes, so a path containing
        # either character cannot terminate or alter the filter expression.
        escaped = str(value).replace("\\", "\\\\").replace("'", "\\'")
        return f"'{escaped}'"

    client = get_milvus_client()

    try:
        # MilvusClient.query takes the boolean expression via ``filter``.
        path_list = ", ".join(_quote(p) for p in paths)
        filter_expr = f"path in [{path_list}]"

        rows = client.query(
            collection_name=MILVUS_COLLECTION_SKETCH_VECTORS,
            filter=filter_expr,
            output_fields=["path", "feature_vector", "style", "category", "sys_file_id", "is_system_sketch", "deprecated"],
        )

        # Index the rows by primary key for O(1) lookup by the caller.
        return {row["path"]: row for row in rows}
    except Exception as e:
        logger.error(f"查询向量失败: {e}", exc_info=True)
        return {}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def search_similar_vectors(
    query_vector: np.ndarray,
    category: str,
    topk: int = 500,
    style: Optional[str] = None,
    style_boost_ratio: float = 0.2
) -> List[Dict]:
    """Run an inner-product similarity search within one category.

    Only rows with ``deprecated == 0`` and the given category are searched.
    When ``style`` is given, two searches are issued and merged:

    1. rows whose style matches, queried with the vector scaled by
       ``1 + style_boost_ratio`` (with the IP metric this scales their
       scores by the same factor);
    2. rows whose style differs, queried with the unscaled vector.

    Args:
        query_vector: query embedding; any array-like is accepted and is
            flattened to one dimension before searching.
        category: category value to filter on.
        topk: maximum number of results; also used as the per-search limit.
            A non-positive value returns ``[]`` immediately.
        style: optional style to favour; falsy means a single plain search.
        style_boost_ratio: relative score boost for matching-style rows
            (default 0.2, i.e. 20%).

    Returns:
        Up to ``topk`` dicts with keys ``path``, ``score``, ``style``,
        ``category`` and ``sys_file_id``, sorted by ``score`` descending.
        If the search fails, the error is logged and ``[]`` is returned.
    """
    # A non-positive limit can never yield results; skip the round trip.
    if topk <= 0:
        return []

    client = get_milvus_client()

    def _quote(value) -> str:
        # Escape backslashes and single quotes so caller-supplied values
        # cannot break out of the string literal in the filter expression.
        escaped = str(value).replace("\\", "\\\\").replace("'", "\\'")
        return f"'{escaped}'"

    def _search(vector, filter_expr):
        # One query vector is sent, so only the first result batch matters.
        batches = client.search(
            collection_name=MILVUS_COLLECTION_SKETCH_VECTORS,
            data=[vector.tolist()],
            anns_field="feature_vector",
            search_params={"metric_type": "IP", "params": {"nprobe": 10}},
            limit=topk,
            filter=filter_expr,
            output_fields=["path", "style", "category", "sys_file_id"],
        )
        return list(batches[0]) if batches else []

    try:
        vector = np.asarray(query_vector).reshape(-1)
        base_filter = f"category == {_quote(category)} && deprecated == 0"

        if not style:
            hits = _search(vector, base_filter)
        else:
            style_literal = _quote(style)
            # Stage 1: matching style, boosted query vector.
            hits = _search(
                vector * (1 + style_boost_ratio),
                f"{base_filter} && style == {style_literal}",
            )
            # Stage 2: every other style, unboosted query vector.
            hits.extend(
                _search(vector, f"{base_filter} && style != {style_literal}")
            )

        formatted_results = []
        for hit in hits:
            entity = hit.get("entity", {})
            formatted_results.append({
                "path": entity.get("path", ""),
                "score": hit.get("distance", 0.0),
                "style": entity.get("style", ""),
                "category": entity.get("category", ""),
                "sys_file_id": entity.get("sys_file_id"),
            })

        # The two stages are ranked together, then cut back to topk.
        formatted_results.sort(key=lambda item: item["score"], reverse=True)
        return formatted_results[:topk]

    except Exception as e:
        logger.error(f"向量检索失败: {e}", exc_info=True)
        return []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def query_random_candidates(category: str, style: Optional[str] = None, limit: int = 10) -> List[Dict]:
    """Pick random non-deprecated rows from a category (exploration branch).

    Args:
        category: category value to filter on.
        style: optional style; when truthy, only rows with this style are
            considered.
        limit: maximum number of rows to return. A non-positive value
            returns ``[]`` immediately.

    Returns:
        Up to ``limit`` row dicts with fields ``path``, ``style`` and
        ``category``. When more rows match than ``limit``, a random sample
        is returned; the candidate pool itself is capped at 10000 rows by
        the query below. If the query fails, the error is logged and ``[]``
        is returned.
    """
    import random

    # Nothing can be sampled for a non-positive limit; skip the query.
    if limit <= 0:
        return []

    client = get_milvus_client()

    def _quote(value) -> str:
        # Escape backslashes and single quotes so caller-supplied values
        # cannot break out of the string literal in the filter expression.
        escaped = str(value).replace("\\", "\\\\").replace("'", "\\'")
        return f"'{escaped}'"

    try:
        filter_expr = f"category == {_quote(category)} && deprecated == 0"
        if style:
            filter_expr += f" && style == {_quote(style)}"

        # Fetch the candidate pool (at most 10000 matching rows).
        candidates = client.query(
            collection_name=MILVUS_COLLECTION_SKETCH_VECTORS,
            filter=filter_expr,
            output_fields=["path", "style", "category"],
            limit=10000,
        )

        # Down-sample only when the pool is larger than requested.
        if len(candidates) > limit:
            candidates = random.sample(candidates, limit)

        return candidates
    except Exception as e:
        logger.error(f"随机查询候选失败: {e}", exc_info=True)
        return []
|