import io
import logging
import os
import sys
import threading
import time
from collections import defaultdict
from typing import List

import numpy as np
import pymysql
import torch
from apscheduler.schedulers.background import BackgroundScheduler
from apscheduler.triggers.cron import CronTrigger
from fastapi import APIRouter, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
from minio import Minio
from PIL import Image
from torchvision import models, transforms

from app.core.config import DB_CONFIG, RECOMMEND_PATH_PREFIX, TABLE_CATEGORIES
from app.service.recommend.service import load_resources, matrix_data

# Re-wrap stdout so that non-ASCII (Chinese) output is always encoded as UTF-8.
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')

logger = logging.getLogger(__name__)
router = APIRouter()

# MinIO configuration.
# NOTE(review): credentials do not belong in source control. They can now be
# supplied through environment variables; the literal fallbacks are kept only
# so existing deployments keep working and should be removed once the
# variables are configured everywhere.
minio_client = Minio(
    os.environ.get("MINIO_ENDPOINT", "www.minio.aida.com.hk:12024"),
    access_key=os.environ.get("MINIO_ACCESS_KEY", "admin"),
    secret_key=os.environ.get("MINIO_SECRET_KEY", "Aidlab123123!"),
    secure=os.environ.get("MINIO_SECURE", "true").lower() not in ("0", "false", "no"),
)

# Preprocessing applied to every sketch before it is fed to ResNet50
# (ImageNet mean/std normalisation on a 224x224 RGB tensor).
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# ResNet50 with its final fully-connected layer removed, so a forward pass
# returns the pooled feature map ([1, 2048, 1, 1]) instead of class logits.
# NOTE(review): `pretrained=True` is deprecated in newer torchvision releases
# in favour of `weights=...`; left unchanged because the installed version is
# not visible here.
resnet_model = models.resnet50(pretrained=True)
resnet_model = torch.nn.Sequential(*list(resnet_model.children())[:-1])
resnet_model.eval()


def get_sketch_image_from_minio(sketch_path: str):
    """Download a sketch from MinIO and return it as a preprocessed batch tensor.

    Args:
        sketch_path: "<bucket>/<object name>" path of the sketch.

    Returns:
        A ``[1, 3, 224, 224]`` tensor, or ``None`` when the path is empty or
        malformed, or the download/decoding fails (best effort; failures are
        logged as warnings).
    """
    if not sketch_path:
        return None
    path_parts = sketch_path.split('/', 1)
    if len(path_parts) != 2:
        return None
    bucket_name, file_name = path_parts

    try:
        response = minio_client.get_object(bucket_name, file_name)
        try:
            data = response.read()
        finally:
            # The MinIO SDK requires the response to be closed and its
            # connection returned to the pool, otherwise connections leak.
            response.close()
            response.release_conn()
        # Sketches may be grayscale or RGBA; the 3-channel Normalize (and
        # ResNet) need RGB input.
        img = Image.open(io.BytesIO(data)).convert("RGB")
        return transform(img).unsqueeze(0)
    except Exception as e:  # deliberate best effort: caller falls back to zeros
        logger.warning("Fetch image failed [%s]: %s", sketch_path, e)
        return None


def extract_feature_vector_from_resnet(sketch_path: str) -> np.ndarray:
    """Return the 2048-d ResNet50 feature vector of a sketch stored in MinIO.

    Falls back to an all-zero float32 vector when the image cannot be fetched.
    """
    img_tensor = get_sketch_image_from_minio(sketch_path)
    if img_tensor is None:
        return np.zeros(2048, dtype=np.float32)
    with torch.no_grad():
        vec = resnet_model(img_tensor)  # [1, 2048, 1, 1]
    return vec.squeeze().cpu().numpy()


# Persisted artefacts (all under RECOMMEND_PATH_PREFIX).
SKETCH_TO_IID_PATH = f"{RECOMMEND_PATH_PREFIX}sketch_to_iid.npy"
# File the brand matrix is written to (and, now, read back from).
BRAND_MATRIX_PATH = f"{RECOMMEND_PATH_PREFIX}brand_feature_matrix.npy"
# Older file name that was read but never written; used only as a fallback.
LEGACY_BRAND_MATRIX_PATH = f"{RECOMMEND_PATH_PREFIX}brand_matrix.npy"
BRAND_INDEX_MAP_PATH = f"{RECOMMEND_PATH_PREFIX}brand_index_map.npy"

# Pre-loaded feature dictionaries, stored as pickled dicts inside .npy files
# (allow_pickle is required, so these files must come from a trusted source).
# BRAND_FEATURES is looked up by product_image_attribute.id below;
# SYSTEM_FEATURES is keyed by "<bucket>/..." sketch paths.
BRAND_FEATURES = np.load(f'{RECOMMEND_PATH_PREFIX}brand_feature.npy', allow_pickle=True).item()
SYSTEM_FEATURES = np.load(f'{RECOMMEND_PATH_PREFIX}sketch_feature_dict.npy', allow_pickle=True).item()


def save_sketch_to_iid():
    """Assign 1-based item ids to the system sketches and persist the mapping."""
    sketch_to_iid = {
        sketch_path: iid
        for iid, sketch_path in enumerate(SYSTEM_FEATURES.keys(), start=1)
    }
    np.save(SKETCH_TO_IID_PATH, sketch_to_iid)


def load_sketch_to_iid():
    """Load the sketch-path -> item-id mapping, creating it on first use."""
    if not os.path.exists(SKETCH_TO_IID_PATH):
        save_sketch_to_iid()
    return np.load(SKETCH_TO_IID_PATH, allow_pickle=True).item()


sketch_to_iid = load_sketch_to_iid()


def getNewCategory(gender: str, sketch_category: str) -> str:
    """Build the "<gender>_<category>" key (lower-cased) used to group sketches."""
    return f"{gender.lower()}_{sketch_category.lower()}"


def get_category_from_path(path: str) -> str:
    """Derive the "<gender>_<category>" key from path segments 2 and 3.

    Returns "unknown_unknown" when the path has fewer than four segments.
    """
    parts = path.split('/')
    if len(parts) >= 4:
        return f"{parts[2].lower()}_{parts[3].lower()}"
    return "unknown_unknown"


# Serialises the load -> modify -> save cycle on the brand matrix files now
# that requests run in a threadpool. It only protects threads of this process;
# several worker processes could still race on the files.
_BRAND_MATRIX_LOCK = threading.Lock()


def load_brand_matrix():
    """Load the persisted brand similarity matrix and its brand_id -> row map.

    Reads the file that ``calculate_brand_matrix`` writes, falling back to the
    legacy ``brand_matrix.npy``. If either file is missing, returns an empty
    ``(0, len(sketch_to_iid))`` float32 matrix and an empty map. Map entries
    that point past the last matrix row are dropped so that callers never
    index out of range.
    """
    mat_path = (BRAND_MATRIX_PATH if os.path.exists(BRAND_MATRIX_PATH)
                else LEGACY_BRAND_MATRIX_PATH)
    try:
        matrix = np.load(mat_path)
        index_map = np.load(BRAND_INDEX_MAP_PATH, allow_pickle=True).item()
    except FileNotFoundError:
        matrix = np.zeros((0, len(sketch_to_iid)), dtype=np.float32)
        index_map = {}

    stale = [b for b, row in index_map.items() if not 0 <= row < matrix.shape[0]]
    if stale:
        logger.warning("Dropping brand index entries without a matrix row: %s", stale)
        for b in stale:
            del index_map[b]
    return matrix, index_map


def cosine_similarity(vec1, vec2):
    """Cosine similarity of two vectors; 0.0 when either vector has zero norm."""
    norm = np.linalg.norm(vec1) * np.linalg.norm(vec2)
    return np.dot(vec1, vec2) / (norm + 1e-10) if norm != 0 else 0.0


def _brand_category_averages(sketch_data) -> dict:
    """Average a brand's sketch features per "<gender>_<category>" key."""
    features_by_category = defaultdict(list)
    for _id, sketch_path, gender, sketch_category in sketch_data:
        if not gender or not sketch_category:
            logger.warning("Skip sketch %s: missing gender/category", _id)
            continue
        # Explicit None check: `array or fallback` raises on numpy arrays.
        feat = BRAND_FEATURES.get(_id)
        if feat is None:
            feat = extract_feature_vector_from_resnet(sketch_path)
        features_by_category[getNewCategory(gender, sketch_category)].append(feat)
    return {cat: np.mean(feats, axis=0) for cat, feats in features_by_category.items()}


def calculate_brand_matrix(sketch_data, brand_id: int) -> np.ndarray:
    """Compute, persist and return the similarity row of one brand.

    The brand's sketch features are averaged per "<gender>_<category>", and
    every system sketch in a matching category gets the cosine similarity
    between that average and its own feature vector. The brand's row is reset
    first, so categories the brand no longer covers end up at 0.

    Args:
        sketch_data: iterable of ``(id, img_url, gender, category)`` rows.
        brand_id: brand whose row is (re)computed.

    Returns:
        A ``(1, n_sketch)`` view of the brand's row in the updated matrix.
    """
    # Slow part (MinIO + ResNet) runs outside the file lock.
    category_avgs = _brand_category_averages(sketch_data)

    # Matrix column of every item id (ids sorted ascending).
    sketch_list = sorted(sketch_to_iid.values())
    sketch_index = {iid: idx for idx, iid in enumerate(sketch_list)}
    n_sketch = len(sketch_list)

    similarities = {}
    for sketch_path, sys_vec in SYSTEM_FEATURES.items():
        iid = sketch_to_iid.get(sketch_path)
        if not iid or iid not in sketch_index:
            continue
        avg_vec = category_avgs.get(get_category_from_path(sketch_path))
        if avg_vec is not None:
            similarities[sketch_index[iid]] = cosine_similarity(avg_vec, sys_vec)

    with _BRAND_MATRIX_LOCK:
        brand_matrix, brand_index_map = load_brand_matrix()
        if brand_id in brand_index_map:
            row_idx = brand_index_map[brand_id]
        else:
            row_idx = brand_matrix.shape[0]
            brand_index_map[brand_id] = row_idx
            brand_matrix = np.vstack([
                brand_matrix,
                np.zeros((1, n_sketch), dtype=np.float32),
            ])

        brand_matrix[row_idx, :] = 0.0
        for col, sim in similarities.items():
            brand_matrix[row_idx, col] = sim

        np.save(BRAND_MATRIX_PATH, brand_matrix)
        np.save(BRAND_INDEX_MAP_PATH, brand_index_map)

    return brand_matrix[row_idx:row_idx + 1]


_BRAND_SKETCHES_SQL = """
    SELECT id, img_url, gender, category
    FROM product_image_attribute
    WHERE library_id IN (
        SELECT library_id
        FROM brand_rel_library
        WHERE brand_id = %s
    )
"""


def _fetch_brand_sketches(brand_id: int):
    """Return ``(id, img_url, gender, category)`` rows of the brand's libraries."""
    conn = pymysql.connect(**DB_CONFIG)
    try:
        with conn.cursor() as cursor:
            cursor.execute(_BRAND_SKETCHES_SQL, (brand_id,))
            return cursor.fetchall()
    finally:
        conn.close()


def _initialize_brand(brand_id: int) -> None:
    """Blocking worker: query the brand's sketches and rebuild its matrix row."""
    sketch_data = _fetch_brand_sketches(brand_id)
    calculate_brand_matrix(sketch_data, brand_id)


@router.get("/brand_dna_initialize/{brand_id}")
async def brand_dna_initialize(brand_id: int):
    """Compute and persist the brand's similarity row.

    Returns ``{"success": True}`` on success; on failure, logs the error and
    returns a 500 JSON response. The blocking DB/MinIO/ResNet work runs in a
    threadpool so it does not stall the event loop.
    """
    try:
        await run_in_threadpool(_initialize_brand, brand_id)
    except HTTPException:
        # Already an explicit HTTP error: propagate unchanged.
        raise
    except Exception as e:
        logger.error("品牌初始化失败 [%s]: %s", brand_id, e, exc_info=True)
        return JSONResponse(
            status_code=500,
            content={"success": False, "message": "品牌初始化失败"}
        )
    return {"success": True}