diff --git a/app/api/api_recommendation.py b/app/api/api_recommendation.py index e5b86b1..24ab52c 100644 --- a/app/api/api_recommendation.py +++ b/app/api/api_recommendation.py @@ -137,7 +137,7 @@ router = APIRouter() # logger.error(f"推荐失败: {str(e)}", exc_info=True) # raise HTTPException(status_code=500, detail=str(e)) -# @router.on_event("startup") +@router.on_event("startup") async def startup_event(): """启动时初始化增量监听任务""" try: @@ -172,4 +172,32 @@ async def recommend( return [path] except Exception as e: logger.error("新版推荐接口失败 [user=%s, category=%s]: %s", user_id, category, e, exc_info=True) + raise HTTPException(status_code=500, detail=str(e)) + + +@router.get("/redis/user_pref") +async def get_all_user_preferences(): + """ + 获取所有以 user_pref 为前缀的 Redis key 信息 + """ + try: + from app.service.utils.redis_utils import Redis + from app.service.recommendation_system.config import REDIS_KEY_USER_PREF_PREFIX + + # 扫描所有匹配 user_pref:* 的 key + pattern = f"{REDIS_KEY_USER_PREF_PREFIX}:*" + keys = Redis.scan_keys(pattern) + + # 直接返回所有 key 和原始 value + result = {} + for key in keys: + # 读取对应的值 + value = Redis.read(key) + if value: + result[key] = value + + return result + + except Exception as e: + logger.error("获取用户偏好数据失败: %s", e, exc_info=True) raise HTTPException(status_code=500, detail=str(e)) \ No newline at end of file diff --git a/app/api/api_route.py b/app/api/api_route.py index 25f314a..1af7b9f 100644 --- a/app/api/api_route.py +++ b/app/api/api_route.py @@ -7,6 +7,7 @@ from app.api import api_design_pre_processing from app.api import api_generate_image from app.api import api_mannequins_edit from app.api import api_pose_transform +from app.api import api_precompute from app.api import api_prompt_generation from app.api import api_recommendation from app.api import api_test @@ -21,6 +22,7 @@ router.include_router(api_prompt_generation.router, tags=['prompt_generation'], router.include_router(api_design_pre_processing.router, tags=['design_pre_processing'], prefix="/api") router.include_router(api_brand_dna.router, tags=['api_brand_dna'], prefix="/api") router.include_router(api_recommendation.router, tags=['api_recommendation'], prefix="/api") +router.include_router(api_precompute.router, tags=['api_precompute'], prefix="/api") router.include_router(api_mannequins_edit.router, tags=['api_mannequins_edit'], prefix="/api") router.include_router(api_pose_transform.router, tags=['api_pose_transform'], prefix="/api") router.include_router(api_clothing_seg.router, tags=['api_clothing_seg'], prefix="/api") diff --git a/app/service/recommendation_system/config.py b/app/service/recommendation_system/config.py index 42221d1..bd48362 100644 --- a/app/service/recommendation_system/config.py +++ b/app/service/recommendation_system/config.py @@ -14,7 +14,7 @@ REDIS_KEY_USER_PREF_PREFIX = "user_pref" RECOMMENDATION_CONFIG = { # 时间衰减半衰期(用于计算时间衰减权重) # 值越小,最近的行为权重越大 - "K_half": 20, + "K_half": 10, # 探索与利用的比例 (0.0-1.0) # - 值越大,使用探索分支(随机推荐)的几率越大,结果更随机 @@ -25,7 +25,7 @@ RECOMMENDATION_CONFIG = { # 向量检索返回的候选数量 # 值越大,候选池越大,但计算成本也越高 # 建议范围: 100-1000 - "topk": 1000, + "topk": 200, # Style 加分系数(同 style 的候选进行加分) # 值越大,匹配 style 的候选被选中的概率越大 @@ -53,7 +53,7 @@ RECOMMENDATION_CONFIG = { } # 数据库表名 -TABLE_USER_PREFERENCE_LOG = "user_preference_log_test" +TABLE_USER_PREFERENCE_LOG = "user_preference" TABLE_SYS_FILE = "t_sys_file" # MySQL 连接配置(用于推荐系统) diff --git a/app/service/recommendation_system/incremental_listener.py b/app/service/recommendation_system/incremental_listener.py index 08c3b21..bc662ee 100644 --- a/app/service/recommendation_system/incremental_listener.py +++ b/app/service/recommendation_system/incremental_listener.py @@ -1,6 +1,6 @@ """ 增量监听模块 -实时监听 user_preference_log_test 表的新增记录,更新用户偏好向量 +实时监听 user_preference 表的新增记录,更新用户偏好向量 """ import logging import math @@ -48,7 +48,7 @@ class IncrementalListener: if self.last_process_time is None: # 第一次运行,查询最近30分钟的数据 cursor.execute(f""" - SELECT id, account_id, path, category, style, data_time, is_system_sketch, sys_file_id + SELECT id, account_id, path, category, style, data_time FROM {TABLE_USER_PREFERENCE_LOG} WHERE data_time > DATE_SUB(NOW(), INTERVAL 30 MINUTE) ORDER BY data_time @@ -56,7 +56,7 @@ class IncrementalListener: else: # 基于上次处理时间查询 cursor.execute(f""" - SELECT id, account_id, path, category, style, data_time, is_system_sketch, sys_file_id + SELECT id, account_id, path, category, style, data_time FROM {TABLE_USER_PREFERENCE_LOG} WHERE data_time > %s ORDER BY data_time @@ -258,7 +258,7 @@ class IncrementalListener: } else: # 用户图 - # 从 user_preference_log_test 获取 category(如果有) + # 从 user_preference 获取 category(如果有) cursor.execute(f""" SELECT category FROM {TABLE_USER_PREFERENCE_LOG} diff --git a/app/service/recommendation_system/milvus_client.py b/app/service/recommendation_system/milvus_client.py index a027f99..5fefa71 100644 --- a/app/service/recommendation_system/milvus_client.py +++ b/app/service/recommendation_system/milvus_client.py @@ -203,39 +203,74 @@ def search_similar_vectors( query_vector: np.ndarray, category: str, topk: int = 500, - style: Optional[str] = None + style: Optional[str] = None, + style_boost_ratio: float = 0.2 ) -> List[Dict]: """ 向量相似度检索 - + Args: query_vector: 查询向量(2048维) category: 类别过滤 topk: 返回数量 - style: 风格过滤(可选) - + style: 风格过滤(可选)- 当提供时,会给对应style的结果加分 + style_boost_ratio: 风格加分比例(默认0.1,即10%) + Returns: 检索结果列表,每个元素包含 path, score, style, category 等字段 """ client = get_milvus_client() try: - # 构建过滤表达式 - # 使用 filter 参数而不是 expr(根据 pymilvus MilvusClient API) - filter_expr = f"category == '{category}' && deprecated == 0" - if style: - filter_expr += f" && style == '{style}'" + # 如果没有指定style,使用原始逻辑 + if not style: + filter_expr = f"category == '{category}' && deprecated == 0" + results = client.search( + collection_name=MILVUS_COLLECTION_SKETCH_VECTORS, + data=[query_vector.tolist()], + anns_field="feature_vector", + search_params={"metric_type": "IP", "params": {"nprobe": 10}}, + limit=topk, + filter=filter_expr, + output_fields=["path", "style", "category", "sys_file_id"] + ) + else: + # 有style参数时,使用两阶段搜索策略 - # 搜索 - results = client.search( - collection_name=MILVUS_COLLECTION_SKETCH_VECTORS, - data=[query_vector.tolist()], - anns_field="feature_vector", - search_params={"metric_type": "IP", "params": {"nprobe": 10}}, - limit=topk, - filter=filter_expr, - output_fields=["path", "style", "category", "sys_file_id"] - ) + # 第一阶段:搜索匹配style的向量,使用boosted query vector + filter_expr_style = f"category == '{category}' && deprecated == 0 && style == '{style}'" + boosted_query = query_vector * (1 + style_boost_ratio) + results_style = client.search( + collection_name=MILVUS_COLLECTION_SKETCH_VECTORS, + data=[boosted_query.tolist()], + anns_field="feature_vector", + search_params={"metric_type": "IP", "params": {"nprobe": 10}}, + limit=topk, + filter=filter_expr_style, + output_fields=["path", "style", "category", "sys_file_id"] + ) + + # 第二阶段:搜索其他style的向量 + filter_expr_others = f"category == '{category}' && deprecated == 0 && style != '{style}'" + results_others = client.search( + collection_name=MILVUS_COLLECTION_SKETCH_VECTORS, + data=[query_vector.tolist()], + anns_field="feature_vector", + search_params={"metric_type": "IP", "params": {"nprobe": 10}}, + limit=topk, + filter=filter_expr_others, + output_fields=["path", "style", "category", "sys_file_id"] + ) + + # 合并结果 + results = [] + if results_style and len(results_style) > 0: + results.extend(results_style[0]) + if results_others and len(results_others) > 0: + results.extend(results_others[0]) + + # 转换为单个结果列表格式 + results = [results] if results else [] # 格式化结果 formatted_results = [] @@ -249,7 +284,10 @@ def search_similar_vectors( "sys_file_id": hit.get("entity", {}).get("sys_file_id") }) - return formatted_results + # 按分数排序并返回topk + formatted_results.sort(key=lambda x: x["score"], reverse=True) + return formatted_results[:topk] + except Exception as e: logger.error(f"向量检索失败: {e}", exc_info=True) return [] @@ -280,7 +318,7 @@ def query_random_candidates(category: str, style: Optional[str] = None, limit: i collection_name=MILVUS_COLLECTION_SKETCH_VECTORS, filter=filter_expr, output_fields=["path", "style", "category"], - limit=10000 # 先查询大量数据,然后随机选择 + limit=10000 ) # 随机选择 diff --git a/app/service/recommendation_system/precompute.py b/app/service/recommendation_system/precompute.py index c4797d1..235e80f 100644 --- a/app/service/recommendation_system/precompute.py +++ b/app/service/recommendation_system/precompute.py @@ -6,6 +6,7 @@ import logging import math import pymysql import numpy as np +from datetime import datetime from typing import List, Dict, Tuple, Optional from collections import defaultdict @@ -25,7 +26,7 @@ logger = logging.getLogger(__name__) def optimize_database_table(): """ - 优化 user_preference_log_test 表结构 + 优化 user_preference 表结构 添加冗余字段和索引 """ conn = None @@ -317,8 +318,8 @@ def precompute_system_sketch_vectors(batch_size: int = 1000, retry_times: int = def compute_user_preference_vector( account_id: int, category: str, - conn: Optional[pymysql.connections.Connection] = None - # max_date: Optional[datetime] = None + conn: Optional[pymysql.connections.Connection] = None, + max_date: Optional[datetime] = None ) -> Optional[np.ndarray]: """ 计算用户偏好向量 @@ -419,8 +420,8 @@ def compute_user_preference_vector( p_i = 1 + math.log(1 + like_count) # 综合权重 - # w_i = d_k * p_i - w_i = p_i + w_i = d_k * p_i + # w_i = p_i vectors.append(feature_vector) weights.append(w_i) @@ -518,16 +519,16 @@ def run_precompute(): logger.info("=" * 50) # 1. 优化数据库表结构 - logger.info("\n[1/5] 优化数据库表结构...") - optimize_database_table() + # logger.info("\n[1/5] 优化数据库表结构...") + # optimize_database_table() # # 2. 创建 Milvus 集合 # logger.info("\n[2/5] 创建 Milvus 集合...") # create_collection() # 3. 历史数据迁移 - logger.info("\n[3/5] 历史数据迁移...") - migrate_historical_data() + # logger.info("\n[3/5] 历史数据迁移...") + # migrate_historical_data() # # 4. 系统图向量预计算 # logger.info("\n[4/5] 系统图向量预计算...") @@ -543,13 +544,13 @@ def run_precompute(): if __name__ == "__main__": - # 1. 优化数据库表结构 - logger.info("\n[1/5] 优化数据库表结构...") - optimize_database_table() - - # 3. 历史数据迁移 - logger.info("\n[3/5] 历史数据迁移...") - migrate_historical_data() + # # 1. 优化数据库表结构 + # logger.info("\n[1/5] 优化数据库表结构...") + # optimize_database_table() + # + # # 3. 历史数据迁移 + # logger.info("\n[3/5] 历史数据迁移...") + # migrate_historical_data() # 5. 初始用户偏好向量生成 logger.info("\n[5/5] 初始用户偏好向量生成...") diff --git a/app/service/utils/redis_utils.py b/app/service/utils/redis_utils.py index a2d446d..8761fde 100644 --- a/app/service/utils/redis_utils.py +++ b/app/service/utils/redis_utils.py @@ -91,6 +91,21 @@ class Redis(object): r = cls._get_r() r.expire(name, expire_in_seconds) + @classmethod + def scan_keys(cls, pattern="*"): + """ + 扫描匹配模式的key + """ + r = cls._get_r() + keys = [] + cursor = 0 + while True: + cursor, partial_keys = r.scan(cursor, match=pattern, count=1000) + keys.extend(partial_keys) + if cursor == 0: + break + return [key.decode('utf-8') if isinstance(key, bytes) else key for key in keys] + if __name__ == '__main__': redis_client = Redis()