fix:推荐接口

This commit is contained in:
litianxiang
2026-01-12 09:49:07 +08:00
parent c792106f02
commit 2af9cbfe78
5 changed files with 69 additions and 30 deletions

View File

@@ -203,39 +203,74 @@ def search_similar_vectors(
query_vector: np.ndarray,
category: str,
topk: int = 500,
style: Optional[str] = None
style: Optional[str] = None,
style_boost_ratio: float = 0.2
) -> List[Dict]:
"""
向量相似度检索
Args:
query_vector: 查询向量2048维
category: 类别过滤
topk: 返回数量
style: 风格过滤(可选)
style: 风格过滤(可选)- 当提供时会给对应style的结果加分
style_boost_ratio: 风格加分比例默认0.1即10%
Returns:
检索结果列表,每个元素包含 path, score, style, category 等字段
"""
client = get_milvus_client()
try:
# 构建过滤表达式
# 使用 filter 参数而不是 expr根据 pymilvus MilvusClient API
filter_expr = f"category == '{category}' && deprecated == 0"
if style:
filter_expr += f" && style == '{style}'"
# 如果没有指定style使用原始逻辑
if not style:
filter_expr = f"category == '{category}' && deprecated == 0"
results = client.search(
collection_name=MILVUS_COLLECTION_SKETCH_VECTORS,
data=[query_vector.tolist()],
anns_field="feature_vector",
search_params={"metric_type": "IP", "params": {"nprobe": 10}},
limit=topk,
filter=filter_expr,
output_fields=["path", "style", "category", "sys_file_id"]
)
else:
# 有style参数时使用两阶段搜索策略
# 搜索
results = client.search(
collection_name=MILVUS_COLLECTION_SKETCH_VECTORS,
data=[query_vector.tolist()],
anns_field="feature_vector",
search_params={"metric_type": "IP", "params": {"nprobe": 10}},
limit=topk,
filter=filter_expr,
output_fields=["path", "style", "category", "sys_file_id"]
)
# 第一阶段搜索匹配style的向量使用boosted query vector
filter_expr_style = f"category == '{category}' && deprecated == 0 && style == '{style}'"
boosted_query = query_vector * (1 + style_boost_ratio)
results_style = client.search(
collection_name=MILVUS_COLLECTION_SKETCH_VECTORS,
data=[boosted_query.tolist()],
anns_field="feature_vector",
search_params={"metric_type": "IP", "params": {"nprobe": 10}},
limit=topk,
filter=filter_expr_style,
output_fields=["path", "style", "category", "sys_file_id"]
)
# 第二阶段搜索其他style的向量
filter_expr_others = f"category == '{category}' && deprecated == 0 && style != '{style}'"
results_others = client.search(
collection_name=MILVUS_COLLECTION_SKETCH_VECTORS,
data=[query_vector.tolist()],
anns_field="feature_vector",
search_params={"metric_type": "IP", "params": {"nprobe": 10}},
limit=topk,
filter=filter_expr_others,
output_fields=["path", "style", "category", "sys_file_id"]
)
# 合并结果
results = []
if results_style and len(results_style) > 0:
results.extend(results_style[0])
if results_others and len(results_others) > 0:
results.extend(results_others[0])
# 转换为单个结果列表格式
results = [results] if results else []
# 格式化结果
formatted_results = []
@@ -249,7 +284,10 @@ def search_similar_vectors(
"sys_file_id": hit.get("entity", {}).get("sys_file_id")
})
return formatted_results
# 按分数排序并返回topk
formatted_results.sort(key=lambda x: x["score"], reverse=True)
return formatted_results[:topk]
except Exception as e:
logger.error(f"向量检索失败: {e}", exc_info=True)
return []
@@ -280,7 +318,7 @@ def query_random_candidates(category: str, style: Optional[str] = None, limit: i
collection_name=MILVUS_COLLECTION_SKETCH_VECTORS,
filter=filter_expr,
output_fields=["path", "style", "category"],
limit=10000 # 先查询大量数据,然后随机选择
limit=10000
)
# 随机选择