Merge remote-tracking branch 'origin/develop' into develop
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
This commit is contained in:
@@ -137,7 +137,7 @@ router = APIRouter()
|
|||||||
# logger.error(f"推荐失败: {str(e)}", exc_info=True)
|
# logger.error(f"推荐失败: {str(e)}", exc_info=True)
|
||||||
# raise HTTPException(status_code=500, detail=str(e))
|
# raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
# @router.on_event("startup")
|
@router.on_event("startup")
|
||||||
async def startup_event():
|
async def startup_event():
|
||||||
"""启动时初始化增量监听任务"""
|
"""启动时初始化增量监听任务"""
|
||||||
try:
|
try:
|
||||||
@@ -173,3 +173,31 @@ async def recommend(
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("新版推荐接口失败 [user=%s, category=%s]: %s", user_id, category, e, exc_info=True)
|
logger.error("新版推荐接口失败 [user=%s, category=%s]: %s", user_id, category, e, exc_info=True)
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/redis/user_pref")
async def get_all_user_preferences():
    """
    Return every Redis entry whose key matches ``<REDIS_KEY_USER_PREF_PREFIX>:*``.

    Keys are discovered with ``Redis.scan_keys`` and each one is then read
    with ``Redis.read``. Entries whose stored value is falsy are left out of
    the response.

    Returns:
        dict mapping each matching key to the value ``Redis.read`` returned.

    Raises:
        HTTPException: status 500, with the exception text as detail, when
            anything inside the handler fails.
    """
    # NOTE(review): no auth dependency is visible on this route — confirm it
    # is protected elsewhere, since it returns every user_pref entry.
    # NOTE(review): the Redis helpers are called without await inside an async
    # handler — presumably blocking; confirm that is acceptable here.
    try:
        # Imported inside the handler, exactly as the original code does.
        from app.service.utils.redis_utils import Redis
        from app.service.recommendation_system.config import REDIS_KEY_USER_PREF_PREFIX

        matched_keys = Redis.scan_keys(f"{REDIS_KEY_USER_PREF_PREFIX}:*")
        # Pair each key with its raw stored value, one read per key, in scan order.
        entries = ((name, Redis.read(name)) for name in matched_keys)
        return {name: payload for name, payload in entries if payload}
    except Exception as e:
        logger.error("获取用户偏好数据失败: %s", e, exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
|
||||||
@@ -7,6 +7,7 @@ from app.api import api_design_pre_processing
|
|||||||
from app.api import api_generate_image
|
from app.api import api_generate_image
|
||||||
from app.api import api_mannequins_edit
|
from app.api import api_mannequins_edit
|
||||||
from app.api import api_pose_transform
|
from app.api import api_pose_transform
|
||||||
|
from app.api import api_precompute
|
||||||
from app.api import api_prompt_generation
|
from app.api import api_prompt_generation
|
||||||
from app.api import api_recommendation
|
from app.api import api_recommendation
|
||||||
from app.api import api_test
|
from app.api import api_test
|
||||||
@@ -21,6 +22,7 @@ router.include_router(api_prompt_generation.router, tags=['prompt_generation'],
|
|||||||
router.include_router(api_design_pre_processing.router, tags=['design_pre_processing'], prefix="/api")
|
router.include_router(api_design_pre_processing.router, tags=['design_pre_processing'], prefix="/api")
|
||||||
router.include_router(api_brand_dna.router, tags=['api_brand_dna'], prefix="/api")
|
router.include_router(api_brand_dna.router, tags=['api_brand_dna'], prefix="/api")
|
||||||
router.include_router(api_recommendation.router, tags=['api_recommendation'], prefix="/api")
|
router.include_router(api_recommendation.router, tags=['api_recommendation'], prefix="/api")
|
||||||
|
router.include_router(api_precompute.router, tags=['api_precompute'], prefix="/api")
|
||||||
router.include_router(api_mannequins_edit.router, tags=['api_mannequins_edit'], prefix="/api")
|
router.include_router(api_mannequins_edit.router, tags=['api_mannequins_edit'], prefix="/api")
|
||||||
router.include_router(api_pose_transform.router, tags=['api_pose_transform'], prefix="/api")
|
router.include_router(api_pose_transform.router, tags=['api_pose_transform'], prefix="/api")
|
||||||
router.include_router(api_clothing_seg.router, tags=['api_clothing_seg'], prefix="/api")
|
router.include_router(api_clothing_seg.router, tags=['api_clothing_seg'], prefix="/api")
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ REDIS_KEY_USER_PREF_PREFIX = "user_pref"
|
|||||||
RECOMMENDATION_CONFIG = {
|
RECOMMENDATION_CONFIG = {
|
||||||
# 时间衰减半衰期(用于计算时间衰减权重)
|
# 时间衰减半衰期(用于计算时间衰减权重)
|
||||||
# 值越小,最近的行为权重越大
|
# 值越小,最近的行为权重越大
|
||||||
"K_half": 20,
|
"K_half": 10,
|
||||||
|
|
||||||
# 探索与利用的比例 (0.0-1.0)
|
# 探索与利用的比例 (0.0-1.0)
|
||||||
# - 值越大,使用探索分支(随机推荐)的几率越大,结果更随机
|
# - 值越大,使用探索分支(随机推荐)的几率越大,结果更随机
|
||||||
@@ -25,7 +25,7 @@ RECOMMENDATION_CONFIG = {
|
|||||||
# 向量检索返回的候选数量
|
# 向量检索返回的候选数量
|
||||||
# 值越大,候选池越大,但计算成本也越高
|
# 值越大,候选池越大,但计算成本也越高
|
||||||
# 建议范围: 100-1000
|
# 建议范围: 100-1000
|
||||||
"topk": 1000,
|
"topk": 200,
|
||||||
|
|
||||||
# Style 加分系数(同 style 的候选进行加分)
|
# Style 加分系数(同 style 的候选进行加分)
|
||||||
# 值越大,匹配 style 的候选被选中的概率越大
|
# 值越大,匹配 style 的候选被选中的概率越大
|
||||||
@@ -53,7 +53,7 @@ RECOMMENDATION_CONFIG = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
# 数据库表名
|
# 数据库表名
|
||||||
TABLE_USER_PREFERENCE_LOG = "user_preference_log_test"
|
TABLE_USER_PREFERENCE_LOG = "user_preference"
|
||||||
TABLE_SYS_FILE = "t_sys_file"
|
TABLE_SYS_FILE = "t_sys_file"
|
||||||
|
|
||||||
# MySQL 连接配置(用于推荐系统)
|
# MySQL 连接配置(用于推荐系统)
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
"""
|
"""
|
||||||
增量监听模块
|
增量监听模块
|
||||||
实时监听 user_preference_log_test 表的新增记录,更新用户偏好向量
|
实时监听 user_preference 表的新增记录,更新用户偏好向量
|
||||||
"""
|
"""
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
@@ -48,7 +48,7 @@ class IncrementalListener:
|
|||||||
if self.last_process_time is None:
|
if self.last_process_time is None:
|
||||||
# 第一次运行,查询最近30分钟的数据
|
# 第一次运行,查询最近30分钟的数据
|
||||||
cursor.execute(f"""
|
cursor.execute(f"""
|
||||||
SELECT id, account_id, path, category, style, data_time, is_system_sketch, sys_file_id
|
SELECT id, account_id, path, category, style, data_time
|
||||||
FROM {TABLE_USER_PREFERENCE_LOG}
|
FROM {TABLE_USER_PREFERENCE_LOG}
|
||||||
WHERE data_time > DATE_SUB(NOW(), INTERVAL 30 MINUTE)
|
WHERE data_time > DATE_SUB(NOW(), INTERVAL 30 MINUTE)
|
||||||
ORDER BY data_time
|
ORDER BY data_time
|
||||||
@@ -56,7 +56,7 @@ class IncrementalListener:
|
|||||||
else:
|
else:
|
||||||
# 基于上次处理时间查询
|
# 基于上次处理时间查询
|
||||||
cursor.execute(f"""
|
cursor.execute(f"""
|
||||||
SELECT id, account_id, path, category, style, data_time, is_system_sketch, sys_file_id
|
SELECT id, account_id, path, category, style, data_time
|
||||||
FROM {TABLE_USER_PREFERENCE_LOG}
|
FROM {TABLE_USER_PREFERENCE_LOG}
|
||||||
WHERE data_time > %s
|
WHERE data_time > %s
|
||||||
ORDER BY data_time
|
ORDER BY data_time
|
||||||
@@ -258,7 +258,7 @@ class IncrementalListener:
|
|||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
# 用户图
|
# 用户图
|
||||||
# 从 user_preference_log_test 获取 category(如果有)
|
# 从 user_preference 获取 category(如果有)
|
||||||
cursor.execute(f"""
|
cursor.execute(f"""
|
||||||
SELECT category
|
SELECT category
|
||||||
FROM {TABLE_USER_PREFERENCE_LOG}
|
FROM {TABLE_USER_PREFERENCE_LOG}
|
||||||
|
|||||||
@@ -203,7 +203,8 @@ def search_similar_vectors(
|
|||||||
query_vector: np.ndarray,
|
query_vector: np.ndarray,
|
||||||
category: str,
|
category: str,
|
||||||
topk: int = 500,
|
topk: int = 500,
|
||||||
style: Optional[str] = None
|
style: Optional[str] = None,
|
||||||
|
style_boost_ratio: float = 0.2
|
||||||
) -> List[Dict]:
|
) -> List[Dict]:
|
||||||
"""
|
"""
|
||||||
向量相似度检索
|
向量相似度检索
|
||||||
@@ -212,7 +213,8 @@ def search_similar_vectors(
|
|||||||
query_vector: 查询向量(2048维)
|
query_vector: 查询向量(2048维)
|
||||||
category: 类别过滤
|
category: 类别过滤
|
||||||
topk: 返回数量
|
topk: 返回数量
|
||||||
style: 风格过滤(可选)
|
style: 风格过滤(可选)- 当提供时,会给对应style的结果加分
|
||||||
|
style_boost_ratio: 风格加分比例(默认0.2,即20%)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
检索结果列表,每个元素包含 path, score, style, category 等字段
|
检索结果列表,每个元素包含 path, score, style, category 等字段
|
||||||
@@ -220,22 +222,55 @@ def search_similar_vectors(
|
|||||||
client = get_milvus_client()
|
client = get_milvus_client()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 构建过滤表达式
|
# 如果没有指定style,使用原始逻辑
|
||||||
# 使用 filter 参数而不是 expr(根据 pymilvus MilvusClient API)
|
if not style:
|
||||||
filter_expr = f"category == '{category}' && deprecated == 0"
|
filter_expr = f"category == '{category}' && deprecated == 0"
|
||||||
if style:
|
results = client.search(
|
||||||
filter_expr += f" && style == '{style}'"
|
collection_name=MILVUS_COLLECTION_SKETCH_VECTORS,
|
||||||
|
data=[query_vector.tolist()],
|
||||||
|
anns_field="feature_vector",
|
||||||
|
search_params={"metric_type": "IP", "params": {"nprobe": 10}},
|
||||||
|
limit=topk,
|
||||||
|
filter=filter_expr,
|
||||||
|
output_fields=["path", "style", "category", "sys_file_id"]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# 有style参数时,使用两阶段搜索策略
|
||||||
|
|
||||||
# 搜索
|
# 第一阶段:搜索匹配style的向量,使用boosted query vector
|
||||||
results = client.search(
|
filter_expr_style = f"category == '{category}' && deprecated == 0 && style == '{style}'"
|
||||||
collection_name=MILVUS_COLLECTION_SKETCH_VECTORS,
|
boosted_query = query_vector * (1 + style_boost_ratio)
|
||||||
data=[query_vector.tolist()],
|
results_style = client.search(
|
||||||
anns_field="feature_vector",
|
collection_name=MILVUS_COLLECTION_SKETCH_VECTORS,
|
||||||
search_params={"metric_type": "IP", "params": {"nprobe": 10}},
|
data=[boosted_query.tolist()],
|
||||||
limit=topk,
|
anns_field="feature_vector",
|
||||||
filter=filter_expr,
|
search_params={"metric_type": "IP", "params": {"nprobe": 10}},
|
||||||
output_fields=["path", "style", "category", "sys_file_id"]
|
limit=topk,
|
||||||
)
|
filter=filter_expr_style,
|
||||||
|
output_fields=["path", "style", "category", "sys_file_id"]
|
||||||
|
)
|
||||||
|
|
||||||
|
# 第二阶段:搜索其他style的向量
|
||||||
|
filter_expr_others = f"category == '{category}' && deprecated == 0 && style != '{style}'"
|
||||||
|
results_others = client.search(
|
||||||
|
collection_name=MILVUS_COLLECTION_SKETCH_VECTORS,
|
||||||
|
data=[query_vector.tolist()],
|
||||||
|
anns_field="feature_vector",
|
||||||
|
search_params={"metric_type": "IP", "params": {"nprobe": 10}},
|
||||||
|
limit=topk,
|
||||||
|
filter=filter_expr_others,
|
||||||
|
output_fields=["path", "style", "category", "sys_file_id"]
|
||||||
|
)
|
||||||
|
|
||||||
|
# 合并结果
|
||||||
|
results = []
|
||||||
|
if results_style and len(results_style) > 0:
|
||||||
|
results.extend(results_style[0])
|
||||||
|
if results_others and len(results_others) > 0:
|
||||||
|
results.extend(results_others[0])
|
||||||
|
|
||||||
|
# 转换为单个结果列表格式
|
||||||
|
results = [results] if results else []
|
||||||
|
|
||||||
# 格式化结果
|
# 格式化结果
|
||||||
formatted_results = []
|
formatted_results = []
|
||||||
@@ -249,7 +284,10 @@ def search_similar_vectors(
|
|||||||
"sys_file_id": hit.get("entity", {}).get("sys_file_id")
|
"sys_file_id": hit.get("entity", {}).get("sys_file_id")
|
||||||
})
|
})
|
||||||
|
|
||||||
return formatted_results
|
# 按分数排序并返回topk
|
||||||
|
formatted_results.sort(key=lambda x: x["score"], reverse=True)
|
||||||
|
return formatted_results[:topk]
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"向量检索失败: {e}", exc_info=True)
|
logger.error(f"向量检索失败: {e}", exc_info=True)
|
||||||
return []
|
return []
|
||||||
@@ -280,7 +318,7 @@ def query_random_candidates(category: str, style: Optional[str] = None, limit: i
|
|||||||
collection_name=MILVUS_COLLECTION_SKETCH_VECTORS,
|
collection_name=MILVUS_COLLECTION_SKETCH_VECTORS,
|
||||||
filter=filter_expr,
|
filter=filter_expr,
|
||||||
output_fields=["path", "style", "category"],
|
output_fields=["path", "style", "category"],
|
||||||
limit=10000 # 先查询大量数据,然后随机选择
|
limit=10000
|
||||||
)
|
)
|
||||||
|
|
||||||
# 随机选择
|
# 随机选择
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import logging
|
|||||||
import math
|
import math
|
||||||
import pymysql
|
import pymysql
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from datetime import datetime
|
||||||
from typing import List, Dict, Tuple, Optional
|
from typing import List, Dict, Tuple, Optional
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
|
||||||
@@ -25,7 +26,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
def optimize_database_table():
|
def optimize_database_table():
|
||||||
"""
|
"""
|
||||||
优化 user_preference_log_test 表结构
|
优化 user_preference 表结构
|
||||||
添加冗余字段和索引
|
添加冗余字段和索引
|
||||||
"""
|
"""
|
||||||
conn = None
|
conn = None
|
||||||
@@ -317,8 +318,8 @@ def precompute_system_sketch_vectors(batch_size: int = 1000, retry_times: int =
|
|||||||
def compute_user_preference_vector(
|
def compute_user_preference_vector(
|
||||||
account_id: int,
|
account_id: int,
|
||||||
category: str,
|
category: str,
|
||||||
conn: Optional[pymysql.connections.Connection] = None
|
conn: Optional[pymysql.connections.Connection] = None,
|
||||||
# max_date: Optional[datetime] = None
|
max_date: Optional[datetime] = None
|
||||||
) -> Optional[np.ndarray]:
|
) -> Optional[np.ndarray]:
|
||||||
"""
|
"""
|
||||||
计算用户偏好向量
|
计算用户偏好向量
|
||||||
@@ -419,8 +420,8 @@ def compute_user_preference_vector(
|
|||||||
p_i = 1 + math.log(1 + like_count)
|
p_i = 1 + math.log(1 + like_count)
|
||||||
|
|
||||||
# 综合权重
|
# 综合权重
|
||||||
# w_i = d_k * p_i
|
w_i = d_k * p_i
|
||||||
w_i = p_i
|
# w_i = p_i
|
||||||
|
|
||||||
vectors.append(feature_vector)
|
vectors.append(feature_vector)
|
||||||
weights.append(w_i)
|
weights.append(w_i)
|
||||||
@@ -518,16 +519,16 @@ def run_precompute():
|
|||||||
logger.info("=" * 50)
|
logger.info("=" * 50)
|
||||||
|
|
||||||
# 1. 优化数据库表结构
|
# 1. 优化数据库表结构
|
||||||
logger.info("\n[1/5] 优化数据库表结构...")
|
# logger.info("\n[1/5] 优化数据库表结构...")
|
||||||
optimize_database_table()
|
# optimize_database_table()
|
||||||
|
|
||||||
# # 2. 创建 Milvus 集合
|
# # 2. 创建 Milvus 集合
|
||||||
# logger.info("\n[2/5] 创建 Milvus 集合...")
|
# logger.info("\n[2/5] 创建 Milvus 集合...")
|
||||||
# create_collection()
|
# create_collection()
|
||||||
|
|
||||||
# 3. 历史数据迁移
|
# 3. 历史数据迁移
|
||||||
logger.info("\n[3/5] 历史数据迁移...")
|
# logger.info("\n[3/5] 历史数据迁移...")
|
||||||
migrate_historical_data()
|
# migrate_historical_data()
|
||||||
|
|
||||||
# # 4. 系统图向量预计算
|
# # 4. 系统图向量预计算
|
||||||
# logger.info("\n[4/5] 系统图向量预计算...")
|
# logger.info("\n[4/5] 系统图向量预计算...")
|
||||||
@@ -543,13 +544,13 @@ def run_precompute():
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# 1. 优化数据库表结构
|
# # 1. 优化数据库表结构
|
||||||
logger.info("\n[1/5] 优化数据库表结构...")
|
# logger.info("\n[1/5] 优化数据库表结构...")
|
||||||
optimize_database_table()
|
# optimize_database_table()
|
||||||
|
#
|
||||||
# 3. 历史数据迁移
|
# # 3. 历史数据迁移
|
||||||
logger.info("\n[3/5] 历史数据迁移...")
|
# logger.info("\n[3/5] 历史数据迁移...")
|
||||||
migrate_historical_data()
|
# migrate_historical_data()
|
||||||
|
|
||||||
# 5. 初始用户偏好向量生成
|
# 5. 初始用户偏好向量生成
|
||||||
logger.info("\n[5/5] 初始用户偏好向量生成...")
|
logger.info("\n[5/5] 初始用户偏好向量生成...")
|
||||||
|
|||||||
@@ -91,6 +91,21 @@ class Redis(object):
|
|||||||
r = cls._get_r()
|
r = cls._get_r()
|
||||||
r.expire(name, expire_in_seconds)
|
r.expire(name, expire_in_seconds)
|
||||||
|
|
||||||
|
@classmethod
def scan_keys(cls, pattern="*"):
    """
    Collect the names of all keys matching ``pattern`` using cursor-based SCAN.

    The client returned by ``cls._get_r()`` is scanned repeatedly, starting
    from cursor 0 and requesting ``count=1000`` per call, until the server
    hands back cursor 0 again.

    Args:
        pattern: glob-style match pattern passed to SCAN; defaults to "*".

    Returns:
        list of key names as ``str`` (bytes keys are decoded as UTF-8), each
        name appearing once, in the order it was first returned.
    """
    r = cls._get_r()
    # SCAN may return the same key more than once during a full iteration, so
    # names are collected in a dict: it drops repeats and keeps first-seen order.
    unique_names = {}
    cursor = 0
    while True:
        cursor, partial_keys = r.scan(cursor, match=pattern, count=1000)
        for key in partial_keys:
            # Decode before de-duplicating so b"k" and "k" count as one key.
            name = key.decode('utf-8') if isinstance(key, bytes) else key
            unique_names.setdefault(name, None)
        # A returned cursor of 0 signals that the iteration is complete.
        if cursor == 0:
            break
    return list(unique_names)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
redis_client = Redis()
|
redis_client = Redis()
|
||||||
|
|||||||
Reference in New Issue
Block a user