新推荐接口first commit

This commit is contained in:
litianxiang
2025-12-29 10:52:33 +08:00
committed by zcr
parent 417528f8cd
commit fed3fcdf85
13 changed files with 2634 additions and 460 deletions

View File

@@ -1,25 +1,34 @@
import io
import logging
import os
import sys
import time
from typing import List
from collections import defaultdict
import numpy as np
import pymysql
import torch
from PIL import Image
from apscheduler.schedulers.background import BackgroundScheduler
from apscheduler.triggers.cron import CronTrigger
from fastapi import HTTPException, APIRouter
from fastapi.responses import JSONResponse
from minio import Minio
from torchvision import models, transforms
from app.core.mysql_config import DB_CONFIG
from app.core.new_config import settings
import pymysql
from app.core.config import DB_CONFIG, TABLE_CATEGORIES, RECOMMEND_PATH_PREFIX
from minio import Minio
import torch
from torchvision import models, transforms
from PIL import Image
import os
from fastapi.responses import JSONResponse
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')
logger = logging.getLogger()
router = APIRouter()
minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)
# MinIO 配置
minio_client = Minio(
"www.minio.aida.com.hk:12024",
access_key="admin",
secret_key="Aidlab123123!",
secure=True
)
transform = transforms.Compose([
transforms.Resize((224, 224)),
@@ -58,8 +67,8 @@ def extract_feature_vector_from_resnet(sketch_path: str) -> np.ndarray:
# 预加载
BRAND_FEATURES = np.load(f'{settings.RECOMMEND_PATH_PREFIX}brand_feature.npy', allow_pickle=True).item()
SYSTEM_FEATURES = np.load(f'{settings.RECOMMEND_PATH_PREFIX}sketch_feature_dict.npy', allow_pickle=True).item()
BRAND_FEATURES = np.load(f'{RECOMMEND_PATH_PREFIX}brand_feature.npy', allow_pickle=True).item()
SYSTEM_FEATURES = np.load(f'{RECOMMEND_PATH_PREFIX}sketch_feature_dict.npy', allow_pickle=True).item()
def save_sketch_to_iid():
@@ -67,11 +76,11 @@ def save_sketch_to_iid():
sketch_path: iid
for iid, sketch_path in enumerate(SYSTEM_FEATURES.keys(), start=1)
}
np.save(f"{settings.RECOMMEND_PATH_PREFIX}sketch_to_iid.npy", sketch_to_iid)
np.save(f"{RECOMMEND_PATH_PREFIX}sketch_to_iid.npy", sketch_to_iid)
def load_sketch_to_iid():
path = f"{settings.RECOMMEND_PATH_PREFIX}sketch_to_iid.npy"
path = f"{RECOMMEND_PATH_PREFIX}sketch_to_iid.npy"
if os.path.exists(path):
return np.load(path, allow_pickle=True).item()
save_sketch_to_iid()
@@ -81,7 +90,7 @@ def load_sketch_to_iid():
sketch_to_iid = load_sketch_to_iid()
def get_new_category(gender: str, sketch_category: str) -> str:
def getNewCategory(gender: str, sketch_category: str) -> str:
return f"{gender.lower()}_{sketch_category.lower()}"
@@ -94,8 +103,8 @@ def get_category_from_path(path: str) -> str:
def load_brand_matrix():
"""单独加载 brand_matrix 和 brand_index_map"""
mat_path = f"{settings.RECOMMEND_PATH_PREFIX}brand_matrix.npy"
idx_path = f"{settings.RECOMMEND_PATH_PREFIX}brand_index_map.npy"
mat_path = f"{RECOMMEND_PATH_PREFIX}brand_matrix.npy"
idx_path = f"{RECOMMEND_PATH_PREFIX}brand_index_map.npy"
try:
matrix = np.load(mat_path)
index_map = np.load(idx_path, allow_pickle=True).item()
@@ -104,19 +113,11 @@ def load_brand_matrix():
index_map = {}
return matrix, index_map
def cosine_similarity(vec1, vec2):
"""计算余弦相似度(增加零值处理)"""
norm = np.linalg.norm(vec1) * np.linalg.norm(vec2)
return np.dot(vec1, vec2) / (norm + 1e-10) if norm != 0 else 0.0
def getNewCategory(gender, sketch_category):
print(gender)
print(sketch_category)
return "None"
def calculate_brand_matrix(sketch_data, brand_id: int) -> np.ndarray:
# 1. 收集品牌-分类-特征
brand_feature = defaultdict(lambda: defaultdict(list))
@@ -163,11 +164,11 @@ def calculate_brand_matrix(sketch_data, brand_id: int) -> np.ndarray:
brand_matrix[row_idx, sketch_index[iid]] = cos_sim
# 7. 持久化
np.save(f"{settings.RECOMMEND_PATH_PREFIX}brand_feature_matrix.npy", brand_matrix)
np.save(f"{settings.RECOMMEND_PATH_PREFIX}brand_index_map.npy", brand_index_map)
np.save(f"{RECOMMEND_PATH_PREFIX}brand_feature_matrix.npy", brand_matrix)
np.save(f"{RECOMMEND_PATH_PREFIX}brand_index_map.npy", brand_index_map)
# 返回该品牌对应行
return brand_matrix[row_idx:row_idx + 1]
return brand_matrix[row_idx:row_idx+1]
@router.get("/brand_dna_initialize/{brand_id}")
@@ -177,12 +178,14 @@ async def brand_dna_initialize(brand_id: int):
conn = pymysql.connect(**DB_CONFIG)
cursor = conn.cursor()
cursor.execute("""
SELECT id, img_url, gender, category
FROM product_image_attribute
WHERE library_id IN (SELECT library_id
FROM brand_rel_library
WHERE brand_id = %s)
""", (brand_id,))
SELECT id, img_url, gender, category
FROM product_image_attribute
WHERE library_id IN (
SELECT library_id
FROM brand_rel_library
WHERE brand_id = %s
)
""", (brand_id,))
sketch_data = cursor.fetchall()
# 触发计算并持久化,若内部出错会抛异常