新推荐接口first commit
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped

This commit is contained in:
zcr
2025-12-30 17:23:36 +08:00
parent fed3fcdf85
commit 2a6c48d937
2 changed files with 51 additions and 58 deletions

View File

@@ -2,12 +2,7 @@
推荐系统配置 推荐系统配置
""" """
import os import os
from app.core.config import ( from app.core.config import settings
DB_CONFIG, DB_HOST, DB_PORT, DB_USERNAME, DB_PASSWORD, DB_NAME,
REDIS_HOST, REDIS_PORT, REDIS_DB,
MILVUS_URL, MILVUS_TOKEN, MILVUS_ALIAS,
MINIO_URL, MINIO_ACCESS, MINIO_SECRET, MINIO_SECURE
)
# Milvus 集合名称 # Milvus 集合名称
MILVUS_COLLECTION_SKETCH_VECTORS = "sketch_vectors_norm" MILVUS_COLLECTION_SKETCH_VECTORS = "sketch_vectors_norm"
@@ -20,39 +15,39 @@ RECOMMENDATION_CONFIG = {
# 时间衰减半衰期(用于计算时间衰减权重) # 时间衰减半衰期(用于计算时间衰减权重)
# 值越小,最近的行为权重越大 # 值越小,最近的行为权重越大
"K_half": 20, "K_half": 20,
# 探索与利用的比例 (0.0-1.0) # 探索与利用的比例 (0.0-1.0)
# - 值越大,使用探索分支(随机推荐)的几率越大,结果更随机 # - 值越大,使用探索分支(随机推荐)的几率越大,结果更随机
# - 值越小,使用利用分支(基于用户偏好)的几率越大,结果更精准 # - 值越小,使用利用分支(基于用户偏好)的几率越大,结果更精准
# - 建议范围: 0.3-0.7,要增加随机性可提高到 0.6-0.8 # - 建议范围: 0.3-0.7,要增加随机性可提高到 0.6-0.8
"explore_ratio": 0.5, "explore_ratio": 0.5,
# 向量检索返回的候选数量 # 向量检索返回的候选数量
# 值越大,候选池越大,但计算成本也越高 # 值越大,候选池越大,但计算成本也越高
# 建议范围: 100-1000 # 建议范围: 100-1000
"topk": 1000, "topk": 1000,
# Style 加分系数(同 style 的候选进行加分) # Style 加分系数(同 style 的候选进行加分)
# 值越大,匹配 style 的候选被选中的概率越大 # 值越大,匹配 style 的候选被选中的概率越大
# 要降低某个结果的重复率,可以降低此值(如 0.1 或 0.05) # 要降低某个结果的重复率,可以降低此值(如 0.1 或 0.05)
"style_bonus": 0.2, "style_bonus": 0.2,
# Softmax 抽样的温度参数 # Softmax 抽样的温度参数
# - 温度越高(>1.0),概率分布越均匀,结果更随机,重复率更低 # - 温度越高(>1.0),概率分布越均匀,结果更随机,重复率更低
# - 温度越低(<1.0),高分项概率越大,结果更集中,重复率更高 # - 温度越低(<1.0),高分项概率越大,结果更集中,重复率更高
# - 温度=1.0 为标准 Softmax # - 温度=1.0 为标准 Softmax
# - 建议范围: 1.0-3.0,要增加随机性可提高到 2.0-3.0(注意: 当前取值 0.07 不在该建议范围内,待确认) # - 建议范围: 1.0-3.0,要增加随机性可提高到 2.0-3.0(注意: 当前取值 0.07 不在该建议范围内,待确认)
"softmax_temperature": 0.07, "softmax_temperature": 0.07,
# 监听间隔(秒) # 监听间隔(秒)
"listen_interval_sec": 30, "listen_interval_sec": 30,
# 批量处理大小 # 批量处理大小
"batch_size": 1000, "batch_size": 1000,
# Redis 过期时间(30天) # Redis 过期时间(30天)
"redis_expire_seconds": 2592000, "redis_expire_seconds": 2592000,
# 向量维度 # 向量维度
"vector_dim": 2048, "vector_dim": 2048,
} }
@@ -63,11 +58,10 @@ TABLE_SYS_FILE = "t_sys_file"
# MySQL 连接配置(用于推荐系统) # MySQL 连接配置(用于推荐系统)
MYSQL_CONFIG = { MYSQL_CONFIG = {
"host": DB_HOST, "host": settings.MYSQL_HOST,
"port": DB_PORT, "port": settings.MYSQL_PORT,
"user": DB_USERNAME, "user": settings.MYSQL_USER,
"password": DB_PASSWORD, "password": settings.MYSQL_PASSWORD,
"database": DB_NAME, "database": settings.MYSQL_DB,
"charset": "utf8mb4" "charset": "utf8mb4"
} }

View File

@@ -6,7 +6,7 @@ from typing import List, Dict, Optional, Any
import numpy as np import numpy as np
from pymilvus import MilvusClient, FieldSchema, CollectionSchema, DataType, connections, Collection from pymilvus import MilvusClient, FieldSchema, CollectionSchema, DataType, connections, Collection
from app.core.config import MILVUS_URL, MILVUS_TOKEN, MILVUS_ALIAS from app.core.config import settings
from app.service.recommendation_system.config import MILVUS_COLLECTION_SKETCH_VECTORS, RECOMMENDATION_CONFIG from app.service.recommendation_system.config import MILVUS_COLLECTION_SKETCH_VECTORS, RECOMMENDATION_CONFIG
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -21,9 +21,9 @@ def get_milvus_client() -> MilvusClient:
if _milvus_client is None: if _milvus_client is None:
try: try:
_milvus_client = MilvusClient( _milvus_client = MilvusClient(
uri=MILVUS_URL, uri=settings.MILVUS_URL,
token=MILVUS_TOKEN, token=settings.MILVUS_TOKEN,
db_name=MILVUS_ALIAS db_name=settings.MILVUS_DB,
) )
logger.info("Milvus 客户端连接成功") logger.info("Milvus 客户端连接成功")
except Exception as e: except Exception as e:
@@ -46,32 +46,32 @@ def create_collection():
- feature_vector (FloatVector(2048)) - 2048维特征向量 - feature_vector (FloatVector(2048)) - 2048维特征向量
""" """
client = get_milvus_client() client = get_milvus_client()
# 检查集合是否已存在 # 检查集合是否已存在
collections = client.list_collections() collections = client.list_collections()
if MILVUS_COLLECTION_SKETCH_VECTORS in collections: if MILVUS_COLLECTION_SKETCH_VECTORS in collections:
logger.info(f"集合 {MILVUS_COLLECTION_SKETCH_VECTORS} 已存在") logger.info(f"集合 {MILVUS_COLLECTION_SKETCH_VECTORS} 已存在")
return return
try: try:
# 解析 Milvus URL # 解析 Milvus URL
# 处理 http://host.docker.internal:19530 格式 # 处理 http://host.docker.internal:19530 格式
url_clean = MILVUS_URL.replace("http://", "").replace("https://", "") url_clean = settings.MILVUS_URL.replace("http://", "").replace("https://", "")
if ":" in url_clean: if ":" in url_clean:
host, port_str = url_clean.split(":", 1) host, port_str = url_clean.split(":", 1)
port = int(port_str) port = int(port_str)
else: else:
host = url_clean host = url_clean
port = 19530 port = 19530
# 使用传统 API 创建集合(更可靠) # 使用传统 API 创建集合(更可靠)
# 连接到 Milvus(如果未连接) # 连接到 Milvus(如果未连接)
try: try:
connections.connect( connections.connect(
alias=MILVUS_ALIAS, alias=settings.MILVUS_ALIAS,
host=host, host=host,
port=port, port=port,
token=MILVUS_TOKEN if MILVUS_TOKEN else None token=settings.MILVUS_TOKEN if settings.MILVUS_TOKEN else None
) )
logger.info(f"已连接到 Milvus: {host}:{port}") logger.info(f"已连接到 Milvus: {host}:{port}")
except Exception as conn_e: except Exception as conn_e:
@@ -80,7 +80,7 @@ def create_collection():
logger.info("Milvus 连接已存在") logger.info("Milvus 连接已存在")
else: else:
logger.warning(f"连接 Milvus 时出现警告: {conn_e}") logger.warning(f"连接 Milvus 时出现警告: {conn_e}")
# 定义字段 # 定义字段
fields = [ fields = [
FieldSchema(name="path", dtype=DataType.VARCHAR, is_primary=True, max_length=512), FieldSchema(name="path", dtype=DataType.VARCHAR, is_primary=True, max_length=512),
@@ -95,20 +95,20 @@ def create_collection():
dim=RECOMMENDATION_CONFIG["vector_dim"] dim=RECOMMENDATION_CONFIG["vector_dim"]
) )
] ]
# 创建 schema # 创建 schema
schema = CollectionSchema( schema = CollectionSchema(
fields=fields, fields=fields,
description="Sketch vectors collection for recommendation system" description="Sketch vectors collection for recommendation system"
) )
# 创建集合 # 创建集合
collection = Collection( collection = Collection(
name=MILVUS_COLLECTION_SKETCH_VECTORS, name=MILVUS_COLLECTION_SKETCH_VECTORS,
schema=schema, schema=schema,
using=MILVUS_ALIAS using=settings.MILVUS_ALIAS
) )
# 创建索引 # 创建索引
# 注意:使用 IP(内积)作为度量类型,与搜索时保持一致 # 注意:使用 IP(内积)作为度量类型,与搜索时保持一致
# 如果向量已归一化,IP 等价于 COSINE # 如果向量已归一化,IP 等价于 COSINE
@@ -117,14 +117,14 @@ def create_collection():
"index_type": "IVF_FLAT", "index_type": "IVF_FLAT",
"params": {"nlist": 1024} "params": {"nlist": 1024}
} }
collection.create_index( collection.create_index(
field_name="feature_vector", field_name="feature_vector",
index_params=index_params index_params=index_params
) )
logger.info(f"集合 {MILVUS_COLLECTION_SKETCH_VECTORS} 创建成功") logger.info(f"集合 {MILVUS_COLLECTION_SKETCH_VECTORS} 创建成功")
except Exception as e: except Exception as e:
logger.error(f"创建集合失败: {e}", exc_info=True) logger.error(f"创建集合失败: {e}", exc_info=True)
raise raise
@@ -146,9 +146,9 @@ def insert_vectors(data: List[Dict[str, Any]]):
""" """
if not data: if not data:
return return
client = get_milvus_client() client = get_milvus_client()
try: try:
client.insert( client.insert(
collection_name=MILVUS_COLLECTION_SKETCH_VECTORS, collection_name=MILVUS_COLLECTION_SKETCH_VECTORS,
@@ -172,27 +172,27 @@ def query_vectors_by_paths(paths: List[str]) -> Dict[str, Dict]:
""" """
if not paths: if not paths:
return {} return {}
client = get_milvus_client() client = get_milvus_client()
try: try:
# 构建查询表达式 # 构建查询表达式
# 使用 filter 参数而不是 expr(根据 pymilvus MilvusClient API) # 使用 filter 参数而不是 expr(根据 pymilvus MilvusClient API)
# 对于字符串列表,使用单引号包裹每个值 # 对于字符串列表,使用单引号包裹每个值
path_list = ", ".join([f"'{p}'" for p in paths]) path_list = ", ".join([f"'{p}'" for p in paths])
filter_expr = f"path in [{path_list}]" filter_expr = f"path in [{path_list}]"
results = client.query( results = client.query(
collection_name=MILVUS_COLLECTION_SKETCH_VECTORS, collection_name=MILVUS_COLLECTION_SKETCH_VECTORS,
filter=filter_expr, filter=filter_expr,
output_fields=["path", "feature_vector", "style", "category", "sys_file_id", "is_system_sketch", "deprecated"] output_fields=["path", "feature_vector", "style", "category", "sys_file_id", "is_system_sketch", "deprecated"]
) )
# 转换为字典 # 转换为字典
result_dict = {} result_dict = {}
for r in results: for r in results:
result_dict[r["path"]] = r result_dict[r["path"]] = r
return result_dict return result_dict
except Exception as e: except Exception as e:
logger.error(f"查询向量失败: {e}", exc_info=True) logger.error(f"查询向量失败: {e}", exc_info=True)
@@ -200,10 +200,10 @@ def query_vectors_by_paths(paths: List[str]) -> Dict[str, Dict]:
def search_similar_vectors( def search_similar_vectors(
query_vector: np.ndarray, query_vector: np.ndarray,
category: str, category: str,
topk: int = 500, topk: int = 500,
style: Optional[str] = None style: Optional[str] = None
) -> List[Dict]: ) -> List[Dict]:
""" """
向量相似度检索 向量相似度检索
@@ -218,14 +218,14 @@ def search_similar_vectors(
检索结果列表,每个元素包含 path, score, style, category 等字段 检索结果列表,每个元素包含 path, score, style, category 等字段
""" """
client = get_milvus_client() client = get_milvus_client()
try: try:
# 构建过滤表达式 # 构建过滤表达式
# 使用 filter 参数而不是 expr(根据 pymilvus MilvusClient API) # 使用 filter 参数而不是 expr(根据 pymilvus MilvusClient API)
filter_expr = f"category == '{category}' && deprecated == 0" filter_expr = f"category == '{category}' && deprecated == 0"
if style: if style:
filter_expr += f" && style == '{style}'" filter_expr += f" && style == '{style}'"
# 搜索 # 搜索
results = client.search( results = client.search(
collection_name=MILVUS_COLLECTION_SKETCH_VECTORS, collection_name=MILVUS_COLLECTION_SKETCH_VECTORS,
@@ -236,7 +236,7 @@ def search_similar_vectors(
filter=filter_expr, filter=filter_expr,
output_fields=["path", "style", "category", "sys_file_id"] output_fields=["path", "style", "category", "sys_file_id"]
) )
# 格式化结果 # 格式化结果
formatted_results = [] formatted_results = []
if results and len(results) > 0: if results and len(results) > 0:
@@ -248,7 +248,7 @@ def search_similar_vectors(
"category": hit.get("entity", {}).get("category", ""), "category": hit.get("entity", {}).get("category", ""),
"sys_file_id": hit.get("entity", {}).get("sys_file_id") "sys_file_id": hit.get("entity", {}).get("sys_file_id")
}) })
return formatted_results return formatted_results
except Exception as e: except Exception as e:
logger.error(f"向量检索失败: {e}", exc_info=True) logger.error(f"向量检索失败: {e}", exc_info=True)
@@ -268,13 +268,13 @@ def query_random_candidates(category: str, style: Optional[str] = None, limit: i
候选列表 候选列表
""" """
client = get_milvus_client() client = get_milvus_client()
try: try:
# 构建过滤表达式 # 构建过滤表达式
filter_expr = f"category == '{category}' && deprecated == 0" filter_expr = f"category == '{category}' && deprecated == 0"
if style: if style:
filter_expr += f" && style == '{style}'" filter_expr += f" && style == '{style}'"
# 查询所有符合条件的记录 # 查询所有符合条件的记录
results = client.query( results = client.query(
collection_name=MILVUS_COLLECTION_SKETCH_VECTORS, collection_name=MILVUS_COLLECTION_SKETCH_VECTORS,
@@ -282,14 +282,13 @@ def query_random_candidates(category: str, style: Optional[str] = None, limit: i
output_fields=["path", "style", "category"], output_fields=["path", "style", "category"],
limit=10000 # 先查询大量数据,然后随机选择 limit=10000 # 先查询大量数据,然后随机选择
) )
# 随机选择 # 随机选择
if len(results) > limit: if len(results) > limit:
import random import random
results = random.sample(results, limit) results = random.sample(results, limit)
return results return results
except Exception as e: except Exception as e:
logger.error(f"随机查询候选失败: {e}", exc_info=True) logger.error(f"随机查询候选失败: {e}", exc_info=True)
return [] return []