新推荐接口first commit

This commit is contained in:
litianxiang
2025-12-29 10:52:33 +08:00
committed by zcr
parent 417528f8cd
commit fed3fcdf85
13 changed files with 2634 additions and 460 deletions

View File

@@ -0,0 +1,295 @@
"""
Milvus 客户端封装
"""
import logging
from typing import List, Dict, Optional, Any
import numpy as np
from pymilvus import MilvusClient, FieldSchema, CollectionSchema, DataType, connections, Collection
from app.core.config import MILVUS_URL, MILVUS_TOKEN, MILVUS_ALIAS
from app.service.recommendation_system.config import MILVUS_COLLECTION_SKETCH_VECTORS, RECOMMENDATION_CONFIG
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Cached Milvus client; stays None until get_milvus_client() creates it.
_milvus_client: Optional[MilvusClient] = None
def get_milvus_client() -> MilvusClient:
    """Return the shared ``MilvusClient``, creating it on first use.

    The client is built from ``MILVUS_URL``, ``MILVUS_TOKEN`` and
    ``MILVUS_ALIAS`` (the latter is passed as ``db_name``) and stored in the
    module-level ``_milvus_client`` variable so later calls reuse it.

    Returns:
        The cached client instance.

    Raises:
        Exception: Whatever ``MilvusClient(...)`` raises. The error is logged
            and re-raised unchanged; if construction fails nothing is cached.
    """
    global _milvus_client

    # Fast path: a client was already created by an earlier call.
    if _milvus_client is not None:
        return _milvus_client

    try:
        _milvus_client = MilvusClient(
            uri=MILVUS_URL,
            token=MILVUS_TOKEN,
            db_name=MILVUS_ALIAS,
        )
        logger.info("Milvus 客户端连接成功")
    except Exception as exc:
        logger.error(f"Milvus 客户端连接失败: {exc}")
        raise

    return _milvus_client
def _parse_milvus_address(url: str) -> tuple:
    """Split a Milvus URL into ``(host, port, secure)``.

    Accepts ``http://host:19530``, ``https://host:19530/``, bare
    ``host:19530`` and bare ``host``. The port falls back to 19530 when the
    URL does not carry one; ``secure`` is True only for the ``https`` scheme.
    """
    from urllib.parse import urlsplit

    # urlsplit only recognises the authority part when it follows "//", so a
    # scheme-less "host:port" needs the marker prepended.
    parts = urlsplit(url if "://" in url else f"//{url}")
    host = parts.hostname or ""
    port = parts.port if parts.port is not None else 19530
    return host, port, parts.scheme.lower() == "https"


def create_collection():
    """Create the sketch-vector collection and its vector index if missing.

    Does nothing when ``MILVUS_COLLECTION_SKETCH_VECTORS`` is already listed
    by the shared client. Otherwise the collection is created through the
    ORM-style ``connections`` / ``Collection`` API with these fields:

    - path             VARCHAR(512), primary key (documented as the MinIO
                       logical URL)
    - sys_file_id      INT64, system file id
    - style            VARCHAR(50)
    - category         VARCHAR(50)
    - is_system_sketch INT8, documented as 1 = system sketch, 0 = user sketch
    - deprecated       INT8, documented as 0 = active
    - feature_vector   FLOAT_VECTOR of ``RECOMMENDATION_CONFIG["vector_dim"]``
                       dimensions

    The schema declares neither nullable fields nor default values, so every
    inserted row has to provide all of the fields above.

    An IVF_FLAT index with the inner-product metric is created on
    ``feature_vector``.

    Raises:
        Exception: Any error raised while parsing the URL, creating the
            collection or creating the index; it is logged and re-raised.
    """
    client = get_milvus_client()

    # Nothing to do when the collection is already there.
    if MILVUS_COLLECTION_SKETCH_VECTORS in client.list_collections():
        logger.info(f"集合 {MILVUS_COLLECTION_SKETCH_VECTORS} 已存在")
        return

    try:
        host, port, secure = _parse_milvus_address(MILVUS_URL)

        # The ORM-style API needs its own named connection.
        try:
            connections.connect(
                alias=MILVUS_ALIAS,
                host=host,
                port=port,
                token=MILVUS_TOKEN if MILVUS_TOKEN else None,
                # Target the same database the shared MilvusClient uses, so
                # the collection is visible to the existence check above and
                # to later inserts and searches.
                db_name=MILVUS_ALIAS,
                # Keep TLS on when the configured URL is https.
                secure=secure,
            )
            logger.info(f"已连接到 Milvus: {host}:{port}")
        except Exception as conn_e:
            # An existing connection under this alias is fine; anything else
            # is only logged here and will surface when the collection is
            # created below.
            if "already exists" in str(conn_e).lower():
                logger.info("Milvus 连接已存在")
            else:
                logger.warning(f"连接 Milvus 时出现警告: {conn_e}")

        fields = [
            FieldSchema(name="path", dtype=DataType.VARCHAR, is_primary=True, max_length=512),
            FieldSchema(name="sys_file_id", dtype=DataType.INT64),
            FieldSchema(name="style", dtype=DataType.VARCHAR, max_length=50),
            FieldSchema(name="category", dtype=DataType.VARCHAR, max_length=50),
            FieldSchema(name="is_system_sketch", dtype=DataType.INT8),
            FieldSchema(name="deprecated", dtype=DataType.INT8),
            FieldSchema(
                name="feature_vector",
                dtype=DataType.FLOAT_VECTOR,
                dim=RECOMMENDATION_CONFIG["vector_dim"],
            ),
        ]

        schema = CollectionSchema(
            fields=fields,
            description="Sketch vectors collection for recommendation system",
        )

        collection = Collection(
            name=MILVUS_COLLECTION_SKETCH_VECTORS,
            schema=schema,
            using=MILVUS_ALIAS,
        )

        # Inner product is used here and at search time; for normalised
        # vectors it ranks results the same way cosine similarity does.
        index_params = {
            "metric_type": "IP",
            "index_type": "IVF_FLAT",
            "params": {"nlist": 1024},
        }
        collection.create_index(
            field_name="feature_vector",
            index_params=index_params,
        )
        logger.info(f"集合 {MILVUS_COLLECTION_SKETCH_VECTORS} 创建成功")
    except Exception as e:
        logger.error(f"创建集合失败: {e}", exc_info=True)
        raise
def insert_vectors(data: List[Dict[str, Any]]):
    """Insert a batch of rows into the sketch-vector collection.

    Args:
        data: Rows to insert. Each row is a dict with the keys
            - path: str
            - sys_file_id: int
            - style: str
            - category: str
            - is_system_sketch: int, defaults to 1 when the key is missing
            - deprecated: int, defaults to 0 when the key is missing
            - feature_vector: list of floats
            A single row dict is accepted as well. Empty input is a no-op.

    Raises:
        Exception: Any error raised by the Milvus insert; it is logged and
            re-raised.
    """
    if not data:
        return

    # MilvusClient.insert also takes one bare row; keep that working.
    rows = [data] if isinstance(data, dict) else data

    # The collection schema declares no server-side defaults, so the
    # documented defaults are filled in here. New dicts are built so the
    # caller's rows are left untouched.
    defaults = {"is_system_sketch": 1, "deprecated": 0}
    payload = [{**defaults, **row} for row in rows]

    client = get_milvus_client()
    try:
        client.insert(
            collection_name=MILVUS_COLLECTION_SKETCH_VECTORS,
            data=payload,
        )
        logger.info(f"成功插入 {len(payload)} 条向量数据")
    except Exception as e:
        logger.error(f"插入向量失败: {e}", exc_info=True)
        raise
def query_vectors_by_paths(paths: List[str]) -> Dict[str, Dict]:
    """Fetch the stored rows for a list of primary-key paths.

    Args:
        paths: Primary-key values to look up.

    Returns:
        A dict mapping each path that was found to its row (path,
        feature_vector, style, category, sys_file_id, is_system_sketch,
        deprecated). Returns an empty dict for empty input or when the query
        fails; failures are logged, not raised.
    """
    if not paths:
        return {}

    client = get_milvus_client()
    try:
        # Escape backslashes first, then single quotes, so a path containing
        # either cannot end the string literal early and corrupt the filter.
        quoted = ", ".join(
            "'{}'".format(str(p).replace("\\", "\\\\").replace("'", "\\'"))
            for p in paths
        )
        filter_expr = f"path in [{quoted}]"

        rows = client.query(
            collection_name=MILVUS_COLLECTION_SKETCH_VECTORS,
            filter=filter_expr,
            output_fields=["path", "feature_vector", "style", "category", "sys_file_id", "is_system_sketch", "deprecated"],
        )
        return {row["path"]: row for row in rows}
    except Exception as e:
        logger.error(f"查询向量失败: {e}", exc_info=True)
        return {}
def search_similar_vectors(
    query_vector: np.ndarray,
    category: str,
    topk: int = 500,
    style: Optional[str] = None
) -> List[Dict]:
    """Run an inner-product similarity search over non-deprecated sketches.

    Args:
        query_vector: Query embedding; converted with ``.tolist()`` before
            being sent to Milvus.
        category: Only rows with this category are searched.
        topk: Maximum number of hits to return. A non-positive value yields
            an empty list without contacting Milvus.
        style: Optional additional equality filter on ``style``.

    Returns:
        A list of dicts with the keys ``path``, ``score`` (the hit's
        ``distance``), ``style``, ``category`` and ``sys_file_id``. Returns an
        empty list when the search fails; failures are logged, not raised.
    """
    # Nothing can be returned for a non-positive limit; skip the round trip.
    if isinstance(topk, int) and topk <= 0:
        return []

    def _quote(value) -> str:
        # Escape backslashes, then single quotes, so the value cannot break
        # out of the string literal in the filter expression.
        escaped = str(value).replace("\\", "\\\\").replace("'", "\\'")
        return f"'{escaped}'"

    client = get_milvus_client()
    try:
        filter_expr = f"category == {_quote(category)} && deprecated == 0"
        if style:
            filter_expr += f" && style == {_quote(style)}"

        results = client.search(
            collection_name=MILVUS_COLLECTION_SKETCH_VECTORS,
            data=[query_vector.tolist()],
            anns_field="feature_vector",
            search_params={"metric_type": "IP", "params": {"nprobe": 10}},
            limit=topk,
            filter=filter_expr,
            output_fields=["path", "style", "category", "sys_file_id"],
        )

        # One query vector was sent, so only the first result set matters.
        formatted_results = []
        for hit in (results[0] if results else []):
            entity = hit.get("entity", {})
            formatted_results.append({
                "path": entity.get("path", ""),
                "score": hit.get("distance", 0.0),
                "style": entity.get("style", ""),
                "category": entity.get("category", ""),
                "sys_file_id": entity.get("sys_file_id"),
            })
        return formatted_results
    except Exception as e:
        logger.error(f"向量检索失败: {e}", exc_info=True)
        return []
def query_random_candidates(category: str, style: Optional[str] = None, limit: int = 10) -> List[Dict]:
    """Pick random non-deprecated candidates (used by the exploration branch).

    Up to 10000 matching rows are fetched and ``limit`` of them are sampled
    at random, so rows beyond the first 10000 matches are never chosen.

    Args:
        category: Only rows with this category are considered.
        style: Optional additional equality filter on ``style``.
        limit: Maximum number of candidates to return. A non-positive value
            yields an empty list without contacting Milvus.

    Returns:
        A list of row dicts with ``path``, ``style`` and ``category``. Returns
        an empty list when the query fails; failures are logged, not raised.
    """
    import random

    # Nothing to sample; avoid fetching thousands of rows for no result.
    if isinstance(limit, int) and limit <= 0:
        return []

    def _quote(value) -> str:
        # Escape backslashes, then single quotes, so the value cannot break
        # out of the string literal in the filter expression.
        escaped = str(value).replace("\\", "\\\\").replace("'", "\\'")
        return f"'{escaped}'"

    client = get_milvus_client()
    try:
        filter_expr = f"category == {_quote(category)} && deprecated == 0"
        if style:
            filter_expr += f" && style == {_quote(style)}"

        # Fetch a large pool first, then sample from it client-side.
        results = client.query(
            collection_name=MILVUS_COLLECTION_SKETCH_VECTORS,
            filter=filter_expr,
            output_fields=["path", "style", "category"],
            limit=10000,
        )

        if len(results) > limit:
            results = random.sample(results, limit)
        return results
    except Exception as e:
        logger.error(f"随机查询候选失败: {e}", exc_info=True)
        return []