# 213 lines · 6.8 KiB · Python
import io
|
||
import logging
|
||
import sys
|
||
import time
|
||
from typing import List
|
||
from collections import defaultdict
|
||
import numpy as np
|
||
from apscheduler.schedulers.background import BackgroundScheduler
|
||
from apscheduler.triggers.cron import CronTrigger
|
||
from fastapi import HTTPException, APIRouter
|
||
|
||
from app.service.recommend.service import load_resources, matrix_data
|
||
import pymysql
|
||
from app.core.config import DB_CONFIG, TABLE_CATEGORIES, RECOMMEND_PATH_PREFIX
|
||
from minio import Minio
|
||
import torch
|
||
from torchvision import models, transforms
|
||
from PIL import Image
|
||
import os
|
||
from fastapi.responses import JSONResponse
|
||
|
||
# Force UTF-8 on stdout so non-ASCII output (this module logs Chinese text)
# does not fail on consoles with a narrower default encoding.
# reconfigure() switches the encoding of the existing stream in place, so no
# second wrapper is created over the same underlying buffer; the explicit
# re-wrap is kept only as a fallback for streams that lack reconfigure().
if hasattr(sys.stdout, "reconfigure"):
    sys.stdout.reconfigure(encoding='utf-8')
else:
    sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')

# Root logger (no name passed) and the router that the endpoints attach to.
logger = logging.getLogger()
router = APIRouter()
|
||
|
||
# MinIO configuration.
# Endpoint and credentials can be overridden through environment variables;
# the literal defaults keep existing deployments working unchanged.
# SECURITY NOTE(review): the default credentials below are committed in
# source — rotate them and supply MINIO_ACCESS_KEY / MINIO_SECRET_KEY from
# the environment instead.
minio_client = Minio(
    os.environ.get("MINIO_ENDPOINT", "www.minio.aida.com.hk:12024"),
    access_key=os.environ.get("MINIO_ACCESS_KEY", "admin"),
    secret_key=os.environ.get("MINIO_SECRET_KEY", "Aidlab123123!"),
    secure=True
)
|
||
|
||
# Preprocessing applied to every fetched sketch: fixed 224x224 resize,
# conversion to a tensor, then per-channel normalisation.
_CHANNEL_MEAN = [0.485, 0.456, 0.406]
_CHANNEL_STD = [0.229, 0.224, 0.225]
_preprocess_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD),
]
transform = transforms.Compose(_preprocess_steps)
|
||
|
||
# ResNet50 with its last child module (the final fully connected layer)
# removed, switched to evaluation mode.
# NOTE(review): newer torchvision releases prefer `weights=` over
# `pretrained=` — confirm against the pinned torchvision version.
_backbone_layers = list(models.resnet50(pretrained=True).children())
resnet_model = torch.nn.Sequential(*_backbone_layers[:-1]).eval()
|
||
|
||
|
||
def get_sketch_image_from_minio(sketch_path: str):
    """Fetch a sketch from MinIO and return it as a preprocessed tensor.

    ``sketch_path`` is split on its first ``/`` into a bucket name and an
    object key.

    Returns:
        The output of ``transform`` with a leading batch axis added, or
        ``None`` when the path contains no ``/`` or when fetching/decoding
        fails (the failure is logged as a warning).
    """
    bucket_name, sep, file_name = sketch_path.partition('/')
    if not sep:
        return None

    obj = None
    try:
        obj = minio_client.get_object(bucket_name, file_name)
        img = Image.open(io.BytesIO(obj.read()))
        # Force three channels: grayscale or RGBA images would otherwise not
        # match the 3-channel normalisation in `transform` and be dropped.
        return transform(img.convert('RGB')).unsqueeze(0)
    except Exception as e:
        # Best-effort fetch: callers treat None as "no image available".
        logger.warning(f"Fetch image failed [{sketch_path}]: {e}")
        return None
    finally:
        # Return the HTTP connection to the pool whether or not decoding
        # succeeded.
        if obj is not None:
            obj.close()
            obj.release_conn()
|
||
|
||
|
||
def extract_feature_vector_from_resnet(sketch_path: str) -> np.ndarray:
    """Return the ResNet feature vector for the sketch at ``sketch_path``.

    The image is fetched with ``get_sketch_image_from_minio``; when that
    yields ``None`` a zero vector of length 2048 (float32) is returned.
    """
    batch = get_sketch_image_from_minio(sketch_path)
    if batch is not None:
        # Inference only: no gradients are tracked.
        with torch.no_grad():
            features = resnet_model(batch)  # [1, 2048, 1, 1]
        return features.squeeze().cpu().numpy()
    return np.zeros(2048, dtype=np.float32)
|
||
|
||
|
||
# Preload the cached feature dictionaries once, at import time.
# Each .npy file wraps a single pickled object that `.item()` unwraps.
# NOTE(review): allow_pickle=True unpickles file contents — these files must
# come from a trusted source.
_brand_feature_archive = np.load(f'{RECOMMEND_PATH_PREFIX}brand_feature.npy', allow_pickle=True)
BRAND_FEATURES = _brand_feature_archive.item()  # looked up by product image id
_system_feature_archive = np.load(f'{RECOMMEND_PATH_PREFIX}sketch_feature_dict.npy', allow_pickle=True)
SYSTEM_FEATURES = _system_feature_archive.item()  # iterated as sketch path -> vector
|
||
|
||
|
||
def save_sketch_to_iid():
    """Build and persist the sketch-path -> integer-id mapping.

    Ids are assigned consecutively, starting at 1, in the iteration order of
    ``SYSTEM_FEATURES``. The mapping is written to ``sketch_to_iid.npy``
    under ``RECOMMEND_PATH_PREFIX``.
    """
    mapping = {}
    for iid, sketch_path in enumerate(SYSTEM_FEATURES.keys(), start=1):
        mapping[sketch_path] = iid
    target = f"{RECOMMEND_PATH_PREFIX}sketch_to_iid.npy"
    np.save(target, mapping)
|
||
|
||
|
||
def load_sketch_to_iid():
    """Load the sketch-path -> id mapping, creating the file first if absent.

    Returns:
        The object stored in ``sketch_to_iid.npy`` under
        ``RECOMMEND_PATH_PREFIX``.
    """
    path = f"{RECOMMEND_PATH_PREFIX}sketch_to_iid.npy"
    # Generate the mapping on first use, then read it back like any other run.
    if not os.path.exists(path):
        save_sketch_to_iid()
    return np.load(path, allow_pickle=True).item()
|
||
|
||
|
||
# Module-level sketch-path -> integer-id mapping, loaded (or built on first
# run) once at import time.
sketch_to_iid = load_sketch_to_iid()
|
||
|
||
|
||
def getNewCategory(gender: str, sketch_category: str) -> str:
    """Return the lower-cased ``"<gender>_<category>"`` key for a sketch."""
    pieces = (gender.lower(), sketch_category.lower())
    return "_".join(pieces)
|
||
|
||
|
||
def get_category_from_path(path: str) -> str:
    """Derive the ``"<gender>_<category>"`` key from a sketch path.

    The third and fourth ``/``-separated segments are lower-cased and joined
    with ``_``. Paths with fewer than four segments yield
    ``"unknown_unknown"``.
    """
    segments = path.split('/')
    if len(segments) < 4:
        return "unknown_unknown"
    gender, category = segments[2], segments[3]
    return f"{gender.lower()}_{category.lower()}"
|
||
|
||
|
||
def load_brand_matrix():
    """Load the persisted brand matrix and its brand-id -> row-index map.

    The matrix is read from ``brand_feature_matrix.npy`` — the file
    ``calculate_brand_matrix`` writes alongside ``brand_index_map.npy`` — so
    that previously computed rows are actually picked up again. The older
    ``brand_matrix.npy`` name is still tried as a fallback.

    Returns:
        ``(matrix, index_map)``. When no matrix/index pair can be read, an
        empty ``(0, len(sketch_to_iid))`` float32 matrix and an empty dict.
    """
    idx_path = f"{RECOMMEND_PATH_PREFIX}brand_index_map.npy"
    mat_candidates = (
        f"{RECOMMEND_PATH_PREFIX}brand_feature_matrix.npy",
        f"{RECOMMEND_PATH_PREFIX}brand_matrix.npy",
    )
    for mat_path in mat_candidates:
        try:
            matrix = np.load(mat_path)
            index_map = np.load(idx_path, allow_pickle=True).item()
        except FileNotFoundError:
            # Matrix and index map are only meaningful as a pair; try the
            # next candidate, or fall through to a fresh start.
            continue
        return matrix, index_map

    matrix = np.zeros((0, len(sketch_to_iid)), dtype=np.float32)
    index_map = {}
    return matrix, index_map
|
||
|
||
def cosine_similarity(vec1, vec2):
    """Return the cosine similarity of two vectors, or 0.0 for a zero vector.

    A small epsilon is added to the denominator of the non-zero case.
    """
    norm_product = np.linalg.norm(vec1) * np.linalg.norm(vec2)
    if norm_product == 0:
        return 0.0
    return np.dot(vec1, vec2) / (norm_product + 1e-10)
|
||
|
||
def calculate_brand_matrix(sketch_data, brand_id: int) -> np.ndarray:
    """Recompute and persist one brand's similarity row over system sketches.

    For every category present in ``sketch_data`` the brand's feature vectors
    are averaged; each system sketch whose path maps to that category gets
    the cosine similarity between the average and its own vector. System
    sketches in other categories are left at 0.

    Args:
        sketch_data: Iterable of ``(id, sketch_path, gender, category)`` rows.
        brand_id: Brand whose row is added or refreshed.

    Returns:
        A ``(1, n_sketch)`` slice of the matrix holding this brand's row.

    Side effects:
        Writes ``brand_feature_matrix.npy`` and ``brand_index_map.npy`` under
        ``RECOMMEND_PATH_PREFIX``.
    """
    # 1. Collect the brand's feature vectors per "<gender>_<category>" key.
    features_by_category = defaultdict(list)
    for _id, sketch_path, gender, sketch_category in sketch_data:
        cat = getNewCategory(gender, sketch_category)
        # Explicit missing-check: `cached or fallback` raises ValueError when
        # the cached feature is a multi-element ndarray.
        feat = BRAND_FEATURES.get(_id)
        if feat is None or np.size(feat) == 0:
            feat = extract_feature_vector_from_resnet(sketch_path)
        features_by_category[cat].append(feat)

    # 2. Build the sketch id -> column index lookup.
    sketch_list = sorted(sketch_to_iid.values())
    sketch_index = {iid: idx for idx, iid in enumerate(sketch_list)}
    n_sketch = len(sketch_list)

    # 3. Load the persisted matrix, or start from an empty one.
    brand_matrix, brand_index_map = load_brand_matrix()

    # 4. Add a row for a new brand, or clear the existing one so that
    #    similarities for categories the brand no longer has do not linger.
    row_idx = brand_index_map.get(brand_id)
    if row_idx is None:
        row_idx = brand_matrix.shape[0]
        brand_index_map[brand_id] = row_idx
        brand_matrix = np.vstack([
            brand_matrix,
            np.zeros((1, n_sketch), dtype=np.float32),
        ])
    else:
        brand_matrix[row_idx] = 0.0

    # 5. Average the brand's vectors per category.
    brand_avg = {
        cat: np.mean(feats, axis=0)
        for cat, feats in features_by_category.items()
    }

    # 6. Fill in the similarity for every system sketch with a known category.
    for sketch_path, sys_vec in SYSTEM_FEATURES.items():
        iid = sketch_to_iid.get(sketch_path)
        if not iid or iid not in sketch_index:
            continue
        avg_vec = brand_avg.get(get_category_from_path(sketch_path))
        if avg_vec is not None:
            brand_matrix[row_idx, sketch_index[iid]] = cosine_similarity(avg_vec, sys_vec)

    # 7. Persist the matrix and the brand -> row map.
    np.save(f"{RECOMMEND_PATH_PREFIX}brand_feature_matrix.npy", brand_matrix)
    np.save(f"{RECOMMEND_PATH_PREFIX}brand_index_map.npy", brand_index_map)

    # Return only this brand's row, keeping it two-dimensional.
    return brand_matrix[row_idx:row_idx+1]
|
||
|
||
|
||
@router.get("/brand_dna_initialize/{brand_id}")
async def brand_dna_initialize(brand_id: int):
    """Compute and persist the similarity row for ``brand_id``.

    Reads the brand's product images (id, image URL, gender, category) from
    the database, passes them to ``calculate_brand_matrix``, and reports the
    outcome.

    Returns:
        ``{"success": True}`` on success; on any unexpected error a 500
        ``JSONResponse`` with ``success: False``. ``HTTPException`` is
        re-raised untouched.
    """
    conn = None
    try:
        conn = pymysql.connect(**DB_CONFIG)
        # The context manager closes the cursor even when the query fails.
        with conn.cursor() as cursor:
            cursor.execute("""
                SELECT id, img_url, gender, category
                FROM product_image_attribute
                WHERE library_id IN (
                    SELECT library_id
                    FROM brand_rel_library
                    WHERE brand_id = %s
                )
            """, (brand_id,))
            sketch_data = cursor.fetchall()

        # Compute and persist; any internal failure raises and is handled
        # below.
        calculate_brand_matrix(sketch_data, brand_id)

        return {"success": True}

    except HTTPException:
        # Already an explicit HTTP error: let FastAPI handle it as is.
        raise

    except Exception as e:
        logger.error(f"品牌初始化失败 [{brand_id}]: {e}", exc_info=True)
        # Report the failure as JSON with a 500 status code.
        return JSONResponse(
            status_code=500,
            content={"success": False, "message": "品牌初始化失败"}
        )

    finally:
        if conn:
            conn.close()
|