Merge remote-tracking branch 'origin/develop' into develop
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped

This commit is contained in:
zcr
2026-01-12 16:18:15 +08:00
7 changed files with 129 additions and 45 deletions

View File

@@ -137,7 +137,7 @@ router = APIRouter()
# logger.error(f"推荐失败: {str(e)}", exc_info=True) # logger.error(f"推荐失败: {str(e)}", exc_info=True)
# raise HTTPException(status_code=500, detail=str(e)) # raise HTTPException(status_code=500, detail=str(e))
# @router.on_event("startup") @router.on_event("startup")
async def startup_event(): async def startup_event():
"""启动时初始化增量监听任务""" """启动时初始化增量监听任务"""
try: try:
@@ -173,3 +173,31 @@ async def recommend(
except Exception as e: except Exception as e:
logger.error("新版推荐接口失败 [user=%s, category=%s]: %s", user_id, category, e, exc_info=True) logger.error("新版推荐接口失败 [user=%s, category=%s]: %s", user_id, category, e, exc_info=True)
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@router.get("/redis/user_pref")
async def get_all_user_preferences():
    """Return every Redis entry stored under the user-preference prefix.

    Scans Redis for keys of the form ``<REDIS_KEY_USER_PREF_PREFIX>:*`` and
    returns a mapping of key -> raw stored value. Keys whose stored value is
    empty/falsy are omitted from the result.

    Raises:
        HTTPException: 500 with the underlying error message on any failure.
    """
    try:
        # Imported lazily to avoid module-load side effects at import time.
        from app.service.utils.redis_utils import Redis
        from app.service.recommendation_system.config import REDIS_KEY_USER_PREF_PREFIX

        # Collect every key in the user-preference namespace.
        matched_keys = Redis.scan_keys(f"{REDIS_KEY_USER_PREF_PREFIX}:*")

        # Keep only keys that currently hold a non-empty value.
        return {key: val for key in matched_keys if (val := Redis.read(key))}
    except Exception as e:
        logger.error("获取用户偏好数据失败: %s", e, exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))

View File

@@ -7,6 +7,7 @@ from app.api import api_design_pre_processing
from app.api import api_generate_image from app.api import api_generate_image
from app.api import api_mannequins_edit from app.api import api_mannequins_edit
from app.api import api_pose_transform from app.api import api_pose_transform
from app.api import api_precompute
from app.api import api_prompt_generation from app.api import api_prompt_generation
from app.api import api_recommendation from app.api import api_recommendation
from app.api import api_test from app.api import api_test
@@ -21,6 +22,7 @@ router.include_router(api_prompt_generation.router, tags=['prompt_generation'],
router.include_router(api_design_pre_processing.router, tags=['design_pre_processing'], prefix="/api") router.include_router(api_design_pre_processing.router, tags=['design_pre_processing'], prefix="/api")
router.include_router(api_brand_dna.router, tags=['api_brand_dna'], prefix="/api") router.include_router(api_brand_dna.router, tags=['api_brand_dna'], prefix="/api")
router.include_router(api_recommendation.router, tags=['api_recommendation'], prefix="/api") router.include_router(api_recommendation.router, tags=['api_recommendation'], prefix="/api")
router.include_router(api_precompute.router, tags=['api_precompute'], prefix="/api")
router.include_router(api_mannequins_edit.router, tags=['api_mannequins_edit'], prefix="/api") router.include_router(api_mannequins_edit.router, tags=['api_mannequins_edit'], prefix="/api")
router.include_router(api_pose_transform.router, tags=['api_pose_transform'], prefix="/api") router.include_router(api_pose_transform.router, tags=['api_pose_transform'], prefix="/api")
router.include_router(api_clothing_seg.router, tags=['api_clothing_seg'], prefix="/api") router.include_router(api_clothing_seg.router, tags=['api_clothing_seg'], prefix="/api")

View File

@@ -14,7 +14,7 @@ REDIS_KEY_USER_PREF_PREFIX = "user_pref"
RECOMMENDATION_CONFIG = { RECOMMENDATION_CONFIG = {
# 时间衰减半衰期(用于计算时间衰减权重) # 时间衰减半衰期(用于计算时间衰减权重)
# 值越小,最近的行为权重越大 # 值越小,最近的行为权重越大
"K_half": 20, "K_half": 10,
# 探索与利用的比例 (0.0-1.0) # 探索与利用的比例 (0.0-1.0)
# - 值越大,使用探索分支(随机推荐)的几率越大,结果更随机 # - 值越大,使用探索分支(随机推荐)的几率越大,结果更随机
@@ -25,7 +25,7 @@ RECOMMENDATION_CONFIG = {
# 向量检索返回的候选数量 # 向量检索返回的候选数量
# 值越大,候选池越大,但计算成本也越高 # 值越大,候选池越大,但计算成本也越高
# 建议范围: 100-1000 # 建议范围: 100-1000
"topk": 1000, "topk": 200,
# Style 加分系数(同 style 的候选进行加分) # Style 加分系数(同 style 的候选进行加分)
# 值越大,匹配 style 的候选被选中的概率越大 # 值越大,匹配 style 的候选被选中的概率越大
@@ -53,7 +53,7 @@ RECOMMENDATION_CONFIG = {
} }
# 数据库表名 # 数据库表名
TABLE_USER_PREFERENCE_LOG = "user_preference_log_test" TABLE_USER_PREFERENCE_LOG = "user_preference"
TABLE_SYS_FILE = "t_sys_file" TABLE_SYS_FILE = "t_sys_file"
# MySQL 连接配置(用于推荐系统) # MySQL 连接配置(用于推荐系统)

View File

@@ -1,6 +1,6 @@
""" """
增量监听模块 增量监听模块
实时监听 user_preference_log_test 表的新增记录,更新用户偏好向量 实时监听 user_preference 表的新增记录,更新用户偏好向量
""" """
import logging import logging
import math import math
@@ -48,7 +48,7 @@ class IncrementalListener:
if self.last_process_time is None: if self.last_process_time is None:
# 第一次运行查询最近30分钟的数据 # 第一次运行查询最近30分钟的数据
cursor.execute(f""" cursor.execute(f"""
SELECT id, account_id, path, category, style, data_time, is_system_sketch, sys_file_id SELECT id, account_id, path, category, style, data_time
FROM {TABLE_USER_PREFERENCE_LOG} FROM {TABLE_USER_PREFERENCE_LOG}
WHERE data_time > DATE_SUB(NOW(), INTERVAL 30 MINUTE) WHERE data_time > DATE_SUB(NOW(), INTERVAL 30 MINUTE)
ORDER BY data_time ORDER BY data_time
@@ -56,7 +56,7 @@ class IncrementalListener:
else: else:
# 基于上次处理时间查询 # 基于上次处理时间查询
cursor.execute(f""" cursor.execute(f"""
SELECT id, account_id, path, category, style, data_time, is_system_sketch, sys_file_id SELECT id, account_id, path, category, style, data_time
FROM {TABLE_USER_PREFERENCE_LOG} FROM {TABLE_USER_PREFERENCE_LOG}
WHERE data_time > %s WHERE data_time > %s
ORDER BY data_time ORDER BY data_time
@@ -258,7 +258,7 @@ class IncrementalListener:
} }
else: else:
# 用户图 # 用户图
# 从 user_preference_log_test 获取 category如果有 # 从 user_preference 获取 category如果有
cursor.execute(f""" cursor.execute(f"""
SELECT category SELECT category
FROM {TABLE_USER_PREFERENCE_LOG} FROM {TABLE_USER_PREFERENCE_LOG}

View File

@@ -203,7 +203,8 @@ def search_similar_vectors(
query_vector: np.ndarray, query_vector: np.ndarray,
category: str, category: str,
topk: int = 500, topk: int = 500,
style: Optional[str] = None style: Optional[str] = None,
style_boost_ratio: float = 0.2
) -> List[Dict]: ) -> List[Dict]:
""" """
向量相似度检索 向量相似度检索
@@ -212,7 +213,8 @@ def search_similar_vectors(
query_vector: 查询向量2048维 query_vector: 查询向量2048维
category: 类别过滤 category: 类别过滤
topk: 返回数量 topk: 返回数量
style: 风格过滤(可选) style: 风格过滤(可选)- 当提供时会给对应style的结果加分
style_boost_ratio: 风格加分比例默认0.1即10%
Returns: Returns:
检索结果列表,每个元素包含 path, score, style, category 等字段 检索结果列表,每个元素包含 path, score, style, category 等字段
@@ -220,22 +222,55 @@ def search_similar_vectors(
client = get_milvus_client() client = get_milvus_client()
try: try:
# 构建过滤表达式 # 如果没有指定style使用原始逻辑
# 使用 filter 参数而不是 expr根据 pymilvus MilvusClient API if not style:
filter_expr = f"category == '{category}' && deprecated == 0" filter_expr = f"category == '{category}' && deprecated == 0"
if style: results = client.search(
filter_expr += f" && style == '{style}'" collection_name=MILVUS_COLLECTION_SKETCH_VECTORS,
data=[query_vector.tolist()],
anns_field="feature_vector",
search_params={"metric_type": "IP", "params": {"nprobe": 10}},
limit=topk,
filter=filter_expr,
output_fields=["path", "style", "category", "sys_file_id"]
)
else:
# 有style参数时使用两阶段搜索策略
# 搜索 # 第一阶段搜索匹配style的向量使用boosted query vector
results = client.search( filter_expr_style = f"category == '{category}' && deprecated == 0 && style == '{style}'"
collection_name=MILVUS_COLLECTION_SKETCH_VECTORS, boosted_query = query_vector * (1 + style_boost_ratio)
data=[query_vector.tolist()], results_style = client.search(
anns_field="feature_vector", collection_name=MILVUS_COLLECTION_SKETCH_VECTORS,
search_params={"metric_type": "IP", "params": {"nprobe": 10}}, data=[boosted_query.tolist()],
limit=topk, anns_field="feature_vector",
filter=filter_expr, search_params={"metric_type": "IP", "params": {"nprobe": 10}},
output_fields=["path", "style", "category", "sys_file_id"] limit=topk,
) filter=filter_expr_style,
output_fields=["path", "style", "category", "sys_file_id"]
)
# 第二阶段搜索其他style的向量
filter_expr_others = f"category == '{category}' && deprecated == 0 && style != '{style}'"
results_others = client.search(
collection_name=MILVUS_COLLECTION_SKETCH_VECTORS,
data=[query_vector.tolist()],
anns_field="feature_vector",
search_params={"metric_type": "IP", "params": {"nprobe": 10}},
limit=topk,
filter=filter_expr_others,
output_fields=["path", "style", "category", "sys_file_id"]
)
# 合并结果
results = []
if results_style and len(results_style) > 0:
results.extend(results_style[0])
if results_others and len(results_others) > 0:
results.extend(results_others[0])
# 转换为单个结果列表格式
results = [results] if results else []
# 格式化结果 # 格式化结果
formatted_results = [] formatted_results = []
@@ -249,7 +284,10 @@ def search_similar_vectors(
"sys_file_id": hit.get("entity", {}).get("sys_file_id") "sys_file_id": hit.get("entity", {}).get("sys_file_id")
}) })
return formatted_results # 按分数排序并返回topk
formatted_results.sort(key=lambda x: x["score"], reverse=True)
return formatted_results[:topk]
except Exception as e: except Exception as e:
logger.error(f"向量检索失败: {e}", exc_info=True) logger.error(f"向量检索失败: {e}", exc_info=True)
return [] return []
@@ -280,7 +318,7 @@ def query_random_candidates(category: str, style: Optional[str] = None, limit: i
collection_name=MILVUS_COLLECTION_SKETCH_VECTORS, collection_name=MILVUS_COLLECTION_SKETCH_VECTORS,
filter=filter_expr, filter=filter_expr,
output_fields=["path", "style", "category"], output_fields=["path", "style", "category"],
limit=10000 # 先查询大量数据,然后随机选择 limit=10000
) )
# 随机选择 # 随机选择

View File

@@ -6,6 +6,7 @@ import logging
import math import math
import pymysql import pymysql
import numpy as np import numpy as np
from datetime import datetime
from typing import List, Dict, Tuple, Optional from typing import List, Dict, Tuple, Optional
from collections import defaultdict from collections import defaultdict
@@ -25,7 +26,7 @@ logger = logging.getLogger(__name__)
def optimize_database_table(): def optimize_database_table():
""" """
优化 user_preference_log_test 表结构 优化 user_preference 表结构
添加冗余字段和索引 添加冗余字段和索引
""" """
conn = None conn = None
@@ -317,8 +318,8 @@ def precompute_system_sketch_vectors(batch_size: int = 1000, retry_times: int =
def compute_user_preference_vector( def compute_user_preference_vector(
account_id: int, account_id: int,
category: str, category: str,
conn: Optional[pymysql.connections.Connection] = None conn: Optional[pymysql.connections.Connection] = None,
# max_date: Optional[datetime] = None max_date: Optional[datetime] = None
) -> Optional[np.ndarray]: ) -> Optional[np.ndarray]:
""" """
计算用户偏好向量 计算用户偏好向量
@@ -419,8 +420,8 @@ def compute_user_preference_vector(
p_i = 1 + math.log(1 + like_count) p_i = 1 + math.log(1 + like_count)
# 综合权重 # 综合权重
# w_i = d_k * p_i w_i = d_k * p_i
w_i = p_i # w_i = p_i
vectors.append(feature_vector) vectors.append(feature_vector)
weights.append(w_i) weights.append(w_i)
@@ -518,16 +519,16 @@ def run_precompute():
logger.info("=" * 50) logger.info("=" * 50)
# 1. 优化数据库表结构 # 1. 优化数据库表结构
logger.info("\n[1/5] 优化数据库表结构...") # logger.info("\n[1/5] 优化数据库表结构...")
optimize_database_table() # optimize_database_table()
# # 2. 创建 Milvus 集合 # # 2. 创建 Milvus 集合
# logger.info("\n[2/5] 创建 Milvus 集合...") # logger.info("\n[2/5] 创建 Milvus 集合...")
# create_collection() # create_collection()
# 3. 历史数据迁移 # 3. 历史数据迁移
logger.info("\n[3/5] 历史数据迁移...") # logger.info("\n[3/5] 历史数据迁移...")
migrate_historical_data() # migrate_historical_data()
# # 4. 系统图向量预计算 # # 4. 系统图向量预计算
# logger.info("\n[4/5] 系统图向量预计算...") # logger.info("\n[4/5] 系统图向量预计算...")
@@ -543,13 +544,13 @@ def run_precompute():
if __name__ == "__main__": if __name__ == "__main__":
# 1. 优化数据库表结构 # # 1. 优化数据库表结构
logger.info("\n[1/5] 优化数据库表结构...") # logger.info("\n[1/5] 优化数据库表结构...")
optimize_database_table() # optimize_database_table()
#
# 3. 历史数据迁移 # # 3. 历史数据迁移
logger.info("\n[3/5] 历史数据迁移...") # logger.info("\n[3/5] 历史数据迁移...")
migrate_historical_data() # migrate_historical_data()
# 5. 初始用户偏好向量生成 # 5. 初始用户偏好向量生成
logger.info("\n[5/5] 初始用户偏好向量生成...") logger.info("\n[5/5] 初始用户偏好向量生成...")

View File

@@ -91,6 +91,21 @@ class Redis(object):
r = cls._get_r() r = cls._get_r()
r.expire(name, expire_in_seconds) r.expire(name, expire_in_seconds)
@classmethod
def scan_keys(cls, pattern="*"):
    """Return all Redis keys matching *pattern*, decoded to ``str``.

    Walks the server with cursor-based SCAN (non-blocking, unlike KEYS)
    until the cursor returns to 0, collecting every matched key.
    """
    client = cls._get_r()
    found = []
    cursor = 0
    while True:
        cursor, batch = client.scan(cursor, match=pattern, count=1000)
        found.extend(batch)
        if cursor == 0:
            # Server signals a completed full iteration with cursor 0.
            break
    # Normalize: redis-py returns bytes unless decode_responses is set.
    return [k.decode('utf-8') if isinstance(k, bytes) else k for k in found]
if __name__ == '__main__': if __name__ == '__main__':
redis_client = Redis() redis_client = Redis()