新推荐接口first commit
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
This commit is contained in:
@@ -2,12 +2,7 @@
|
|||||||
推荐系统配置
|
推荐系统配置
|
||||||
"""
|
"""
|
||||||
import os
|
import os
|
||||||
from app.core.config import (
|
from app.core.config import settings
|
||||||
DB_CONFIG, DB_HOST, DB_PORT, DB_USERNAME, DB_PASSWORD, DB_NAME,
|
|
||||||
REDIS_HOST, REDIS_PORT, REDIS_DB,
|
|
||||||
MILVUS_URL, MILVUS_TOKEN, MILVUS_ALIAS,
|
|
||||||
MINIO_URL, MINIO_ACCESS, MINIO_SECRET, MINIO_SECURE
|
|
||||||
)
|
|
||||||
|
|
||||||
# Milvus 集合名称
|
# Milvus 集合名称
|
||||||
MILVUS_COLLECTION_SKETCH_VECTORS = "sketch_vectors_norm"
|
MILVUS_COLLECTION_SKETCH_VECTORS = "sketch_vectors_norm"
|
||||||
@@ -20,39 +15,39 @@ RECOMMENDATION_CONFIG = {
|
|||||||
# 时间衰减半衰期(用于计算时间衰减权重)
|
# 时间衰减半衰期(用于计算时间衰减权重)
|
||||||
# 值越小,最近的行为权重越大
|
# 值越小,最近的行为权重越大
|
||||||
"K_half": 20,
|
"K_half": 20,
|
||||||
|
|
||||||
# 探索与利用的比例 (0.0-1.0)
|
# 探索与利用的比例 (0.0-1.0)
|
||||||
# - 值越大,使用探索分支(随机推荐)的几率越大,结果更随机
|
# - 值越大,使用探索分支(随机推荐)的几率越大,结果更随机
|
||||||
# - 值越小,使用利用分支(基于用户偏好)的几率越大,结果更精准
|
# - 值越小,使用利用分支(基于用户偏好)的几率越大,结果更精准
|
||||||
# - 建议范围: 0.3-0.7,要增加随机性可提高到 0.6-0.8
|
# - 建议范围: 0.3-0.7,要增加随机性可提高到 0.6-0.8
|
||||||
"explore_ratio": 0.5,
|
"explore_ratio": 0.5,
|
||||||
|
|
||||||
# 向量检索返回的候选数量
|
# 向量检索返回的候选数量
|
||||||
# 值越大,候选池越大,但计算成本也越高
|
# 值越大,候选池越大,但计算成本也越高
|
||||||
# 建议范围: 100-1000
|
# 建议范围: 100-1000
|
||||||
"topk": 1000,
|
"topk": 1000,
|
||||||
|
|
||||||
# Style 加分系数(同 style 的候选进行加分)
|
# Style 加分系数(同 style 的候选进行加分)
|
||||||
# 值越大,匹配 style 的候选被选中的概率越大
|
# 值越大,匹配 style 的候选被选中的概率越大
|
||||||
# 要降低某个结果的重复率,可以降低此值(如 0.1 或 0.05)
|
# 要降低某个结果的重复率,可以降低此值(如 0.1 或 0.05)
|
||||||
"style_bonus": 0.2,
|
"style_bonus": 0.2,
|
||||||
|
|
||||||
# Softmax 抽样的温度参数
|
# Softmax 抽样的温度参数
|
||||||
# - 温度越高(>1.0),概率分布越均匀,结果更随机,重复率更低
|
# - 温度越高(>1.0),概率分布越均匀,结果更随机,重复率更低
|
||||||
# - 温度越低(<1.0),高分项概率越大,结果更集中,重复率更高
|
# - 温度越低(<1.0),高分项概率越大,结果更集中,重复率更高
|
||||||
# - 温度=1.0 为标准 Softmax
|
# - 温度=1.0 为标准 Softmax
|
||||||
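# - 建议范围: 1.0-3.0,要增加随机性可提高到 2.0-3.0(注意:下方当前配置值为 0.07,远低于该建议范围,抽样结果会高度集中于高分项;请确认以哪一个为准)
|
# - 建议范围: 1.0-3.0,要增加随机性可提高到 2.0-3.0(注意:下方当前配置值为 0.07,远低于该建议范围,抽样结果会高度集中于高分项;请确认以哪一个为准)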
|
||||||
"softmax_temperature": 0.07,
|
"softmax_temperature": 0.07,
|
||||||
|
|
||||||
# 监听间隔(秒)
|
# 监听间隔(秒)
|
||||||
"listen_interval_sec": 30,
|
"listen_interval_sec": 30,
|
||||||
|
|
||||||
# 批量处理大小
|
# 批量处理大小
|
||||||
"batch_size": 1000,
|
"batch_size": 1000,
|
||||||
|
|
||||||
# Redis 过期时间(秒,30天)
|
# Redis 过期时间(秒,30天)
|
||||||
"redis_expire_seconds": 2592000,
|
"redis_expire_seconds": 2592000,
|
||||||
|
|
||||||
# 向量维度
|
# 向量维度
|
||||||
"vector_dim": 2048,
|
"vector_dim": 2048,
|
||||||
}
|
}
|
||||||
@@ -63,11 +58,10 @@ TABLE_SYS_FILE = "t_sys_file"
|
|||||||
|
|
||||||
# MySQL 连接配置(用于推荐系统)
|
# MySQL 连接配置(用于推荐系统)
|
||||||
MYSQL_CONFIG = {
|
MYSQL_CONFIG = {
|
||||||
"host": DB_HOST,
|
"host": settings.MYSQL_HOST,
|
||||||
"port": DB_PORT,
|
"port": settings.MYSQL_PORT,
|
||||||
"user": DB_USERNAME,
|
"user": settings.MYSQL_USER,
|
||||||
"password": DB_PASSWORD,
|
"password": settings.MYSQL_PASSWORD,
|
||||||
"database": DB_NAME,
|
"database": settings.MYSQL_DB,
|
||||||
"charset": "utf8mb4"
|
"charset": "utf8mb4"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from typing import List, Dict, Optional, Any
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from pymilvus import MilvusClient, FieldSchema, CollectionSchema, DataType, connections, Collection
|
from pymilvus import MilvusClient, FieldSchema, CollectionSchema, DataType, connections, Collection
|
||||||
|
|
||||||
from app.core.config import MILVUS_URL, MILVUS_TOKEN, MILVUS_ALIAS
|
from app.core.config import settings
|
||||||
from app.service.recommendation_system.config import MILVUS_COLLECTION_SKETCH_VECTORS, RECOMMENDATION_CONFIG
|
from app.service.recommendation_system.config import MILVUS_COLLECTION_SKETCH_VECTORS, RECOMMENDATION_CONFIG
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -21,9 +21,9 @@ def get_milvus_client() -> MilvusClient:
|
|||||||
if _milvus_client is None:
|
if _milvus_client is None:
|
||||||
try:
|
try:
|
||||||
_milvus_client = MilvusClient(
|
_milvus_client = MilvusClient(
|
||||||
uri=MILVUS_URL,
|
uri=settings.MILVUS_URL,
|
||||||
token=MILVUS_TOKEN,
|
token=settings.MILVUS_TOKEN,
|
||||||
db_name=MILVUS_ALIAS
|
db_name=settings.MILVUS_DB,
|
||||||
)
|
)
|
||||||
logger.info("Milvus 客户端连接成功")
|
logger.info("Milvus 客户端连接成功")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -46,32 +46,32 @@ def create_collection():
|
|||||||
- feature_vector (FloatVector(2048)) - 2048维特征向量
|
- feature_vector (FloatVector(2048)) - 2048维特征向量
|
||||||
"""
|
"""
|
||||||
client = get_milvus_client()
|
client = get_milvus_client()
|
||||||
|
|
||||||
# 检查集合是否已存在
|
# 检查集合是否已存在
|
||||||
collections = client.list_collections()
|
collections = client.list_collections()
|
||||||
if MILVUS_COLLECTION_SKETCH_VECTORS in collections:
|
if MILVUS_COLLECTION_SKETCH_VECTORS in collections:
|
||||||
logger.info(f"集合 {MILVUS_COLLECTION_SKETCH_VECTORS} 已存在")
|
logger.info(f"集合 {MILVUS_COLLECTION_SKETCH_VECTORS} 已存在")
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 解析 Milvus URL
|
# 解析 Milvus URL
|
||||||
# 处理 http://host.docker.internal:19530 格式
|
# 处理 http://host.docker.internal:19530 格式
|
||||||
url_clean = MILVUS_URL.replace("http://", "").replace("https://", "")
|
url_clean = settings.MILVUS_URL.replace("http://", "").replace("https://", "")
|
||||||
if ":" in url_clean:
|
if ":" in url_clean:
|
||||||
host, port_str = url_clean.split(":", 1)
|
host, port_str = url_clean.split(":", 1)
|
||||||
port = int(port_str)
|
port = int(port_str)
|
||||||
else:
|
else:
|
||||||
host = url_clean
|
host = url_clean
|
||||||
port = 19530
|
port = 19530
|
||||||
|
|
||||||
# 使用传统 API 创建集合(更可靠)
|
# 使用传统 API 创建集合(更可靠)
|
||||||
# 连接到 Milvus(如果未连接)
|
# 连接到 Milvus(如果未连接)
|
||||||
try:
|
try:
|
||||||
connections.connect(
|
connections.connect(
|
||||||
alias=MILVUS_ALIAS,
|
alias=settings.MILVUS_ALIAS,
|
||||||
host=host,
|
host=host,
|
||||||
port=port,
|
port=port,
|
||||||
token=MILVUS_TOKEN if MILVUS_TOKEN else None
|
token=settings.MILVUS_TOKEN if settings.MILVUS_TOKEN else None
|
||||||
)
|
)
|
||||||
logger.info(f"已连接到 Milvus: {host}:{port}")
|
logger.info(f"已连接到 Milvus: {host}:{port}")
|
||||||
except Exception as conn_e:
|
except Exception as conn_e:
|
||||||
@@ -80,7 +80,7 @@ def create_collection():
|
|||||||
logger.info("Milvus 连接已存在")
|
logger.info("Milvus 连接已存在")
|
||||||
else:
|
else:
|
||||||
logger.warning(f"连接 Milvus 时出现警告: {conn_e}")
|
logger.warning(f"连接 Milvus 时出现警告: {conn_e}")
|
||||||
|
|
||||||
# 定义字段
|
# 定义字段
|
||||||
fields = [
|
fields = [
|
||||||
FieldSchema(name="path", dtype=DataType.VARCHAR, is_primary=True, max_length=512),
|
FieldSchema(name="path", dtype=DataType.VARCHAR, is_primary=True, max_length=512),
|
||||||
@@ -95,20 +95,20 @@ def create_collection():
|
|||||||
dim=RECOMMENDATION_CONFIG["vector_dim"]
|
dim=RECOMMENDATION_CONFIG["vector_dim"]
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
# 创建 schema
|
# 创建 schema
|
||||||
schema = CollectionSchema(
|
schema = CollectionSchema(
|
||||||
fields=fields,
|
fields=fields,
|
||||||
description="Sketch vectors collection for recommendation system"
|
description="Sketch vectors collection for recommendation system"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 创建集合
|
# 创建集合
|
||||||
collection = Collection(
|
collection = Collection(
|
||||||
name=MILVUS_COLLECTION_SKETCH_VECTORS,
|
name=MILVUS_COLLECTION_SKETCH_VECTORS,
|
||||||
schema=schema,
|
schema=schema,
|
||||||
using=MILVUS_ALIAS
|
using=settings.MILVUS_ALIAS
|
||||||
)
|
)
|
||||||
|
|
||||||
# 创建索引
|
# 创建索引
|
||||||
# 注意:使用 IP(内积)作为度量类型,与搜索时保持一致
|
# 注意:使用 IP(内积)作为度量类型,与搜索时保持一致
|
||||||
# 如果向量已归一化,IP 等价于 COSINE
|
# 如果向量已归一化,IP 等价于 COSINE
|
||||||
@@ -117,14 +117,14 @@ def create_collection():
|
|||||||
"index_type": "IVF_FLAT",
|
"index_type": "IVF_FLAT",
|
||||||
"params": {"nlist": 1024}
|
"params": {"nlist": 1024}
|
||||||
}
|
}
|
||||||
|
|
||||||
collection.create_index(
|
collection.create_index(
|
||||||
field_name="feature_vector",
|
field_name="feature_vector",
|
||||||
index_params=index_params
|
index_params=index_params
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"集合 {MILVUS_COLLECTION_SKETCH_VECTORS} 创建成功")
|
logger.info(f"集合 {MILVUS_COLLECTION_SKETCH_VECTORS} 创建成功")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"创建集合失败: {e}", exc_info=True)
|
logger.error(f"创建集合失败: {e}", exc_info=True)
|
||||||
raise
|
raise
|
||||||
@@ -146,9 +146,9 @@ def insert_vectors(data: List[Dict[str, Any]]):
|
|||||||
"""
|
"""
|
||||||
if not data:
|
if not data:
|
||||||
return
|
return
|
||||||
|
|
||||||
client = get_milvus_client()
|
client = get_milvus_client()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
client.insert(
|
client.insert(
|
||||||
collection_name=MILVUS_COLLECTION_SKETCH_VECTORS,
|
collection_name=MILVUS_COLLECTION_SKETCH_VECTORS,
|
||||||
@@ -172,27 +172,27 @@ def query_vectors_by_paths(paths: List[str]) -> Dict[str, Dict]:
|
|||||||
"""
|
"""
|
||||||
if not paths:
|
if not paths:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
client = get_milvus_client()
|
client = get_milvus_client()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 构建查询表达式
|
# 构建查询表达式
|
||||||
# 使用 filter 参数而不是 expr(根据 pymilvus MilvusClient API)
|
# 使用 filter 参数而不是 expr(根据 pymilvus MilvusClient API)
|
||||||
# 对于字符串列表,使用单引号包裹每个值
|
# 对于字符串列表,使用单引号包裹每个值
|
||||||
path_list = ", ".join([f"'{p}'" for p in paths])
|
path_list = ", ".join([f"'{p}'" for p in paths])
|
||||||
filter_expr = f"path in [{path_list}]"
|
filter_expr = f"path in [{path_list}]"
|
||||||
|
|
||||||
results = client.query(
|
results = client.query(
|
||||||
collection_name=MILVUS_COLLECTION_SKETCH_VECTORS,
|
collection_name=MILVUS_COLLECTION_SKETCH_VECTORS,
|
||||||
filter=filter_expr,
|
filter=filter_expr,
|
||||||
output_fields=["path", "feature_vector", "style", "category", "sys_file_id", "is_system_sketch", "deprecated"]
|
output_fields=["path", "feature_vector", "style", "category", "sys_file_id", "is_system_sketch", "deprecated"]
|
||||||
)
|
)
|
||||||
|
|
||||||
# 转换为字典
|
# 转换为字典
|
||||||
result_dict = {}
|
result_dict = {}
|
||||||
for r in results:
|
for r in results:
|
||||||
result_dict[r["path"]] = r
|
result_dict[r["path"]] = r
|
||||||
|
|
||||||
return result_dict
|
return result_dict
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"查询向量失败: {e}", exc_info=True)
|
logger.error(f"查询向量失败: {e}", exc_info=True)
|
||||||
@@ -200,10 +200,10 @@ def query_vectors_by_paths(paths: List[str]) -> Dict[str, Dict]:
|
|||||||
|
|
||||||
|
|
||||||
def search_similar_vectors(
|
def search_similar_vectors(
|
||||||
query_vector: np.ndarray,
|
query_vector: np.ndarray,
|
||||||
category: str,
|
category: str,
|
||||||
topk: int = 500,
|
topk: int = 500,
|
||||||
style: Optional[str] = None
|
style: Optional[str] = None
|
||||||
) -> List[Dict]:
|
) -> List[Dict]:
|
||||||
"""
|
"""
|
||||||
向量相似度检索
|
向量相似度检索
|
||||||
@@ -218,14 +218,14 @@ def search_similar_vectors(
|
|||||||
检索结果列表,每个元素包含 path, score, style, category 等字段
|
检索结果列表,每个元素包含 path, score, style, category 等字段
|
||||||
"""
|
"""
|
||||||
client = get_milvus_client()
|
client = get_milvus_client()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 构建过滤表达式
|
# 构建过滤表达式
|
||||||
# 使用 filter 参数而不是 expr(根据 pymilvus MilvusClient API)
|
# 使用 filter 参数而不是 expr(根据 pymilvus MilvusClient API)
|
||||||
filter_expr = f"category == '{category}' && deprecated == 0"
|
filter_expr = f"category == '{category}' && deprecated == 0"
|
||||||
if style:
|
if style:
|
||||||
filter_expr += f" && style == '{style}'"
|
filter_expr += f" && style == '{style}'"
|
||||||
|
|
||||||
# 搜索
|
# 搜索
|
||||||
results = client.search(
|
results = client.search(
|
||||||
collection_name=MILVUS_COLLECTION_SKETCH_VECTORS,
|
collection_name=MILVUS_COLLECTION_SKETCH_VECTORS,
|
||||||
@@ -236,7 +236,7 @@ def search_similar_vectors(
|
|||||||
filter=filter_expr,
|
filter=filter_expr,
|
||||||
output_fields=["path", "style", "category", "sys_file_id"]
|
output_fields=["path", "style", "category", "sys_file_id"]
|
||||||
)
|
)
|
||||||
|
|
||||||
# 格式化结果
|
# 格式化结果
|
||||||
formatted_results = []
|
formatted_results = []
|
||||||
if results and len(results) > 0:
|
if results and len(results) > 0:
|
||||||
@@ -248,7 +248,7 @@ def search_similar_vectors(
|
|||||||
"category": hit.get("entity", {}).get("category", ""),
|
"category": hit.get("entity", {}).get("category", ""),
|
||||||
"sys_file_id": hit.get("entity", {}).get("sys_file_id")
|
"sys_file_id": hit.get("entity", {}).get("sys_file_id")
|
||||||
})
|
})
|
||||||
|
|
||||||
return formatted_results
|
return formatted_results
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"向量检索失败: {e}", exc_info=True)
|
logger.error(f"向量检索失败: {e}", exc_info=True)
|
||||||
@@ -268,13 +268,13 @@ def query_random_candidates(category: str, style: Optional[str] = None, limit: i
|
|||||||
候选列表
|
候选列表
|
||||||
"""
|
"""
|
||||||
client = get_milvus_client()
|
client = get_milvus_client()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 构建过滤表达式
|
# 构建过滤表达式
|
||||||
filter_expr = f"category == '{category}' && deprecated == 0"
|
filter_expr = f"category == '{category}' && deprecated == 0"
|
||||||
if style:
|
if style:
|
||||||
filter_expr += f" && style == '{style}'"
|
filter_expr += f" && style == '{style}'"
|
||||||
|
|
||||||
# 查询符合条件的记录(最多 10000 条,受下方 limit 限制,并非全部记录)
|
# 查询符合条件的记录(最多 10000 条,受下方 limit 限制,并非全部记录)
|
||||||
results = client.query(
|
results = client.query(
|
||||||
collection_name=MILVUS_COLLECTION_SKETCH_VECTORS,
|
collection_name=MILVUS_COLLECTION_SKETCH_VECTORS,
|
||||||
@@ -282,14 +282,13 @@ def query_random_candidates(category: str, style: Optional[str] = None, limit: i
|
|||||||
output_fields=["path", "style", "category"],
|
output_fields=["path", "style", "category"],
|
||||||
limit=10000 # 先查询大量数据,然后随机选择
|
limit=10000 # 先查询大量数据,然后随机选择
|
||||||
)
|
)
|
||||||
|
|
||||||
# 随机选择
|
# 随机选择
|
||||||
if len(results) > limit:
|
if len(results) > limit:
|
||||||
import random
|
import random
|
||||||
results = random.sample(results, limit)
|
results = random.sample(results, limit)
|
||||||
|
|
||||||
return results
|
return results
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"随机查询候选失败: {e}", exc_info=True)
|
logger.error(f"随机查询候选失败: {e}", exc_info=True)
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user