"""Brand "DNA" initialisation: build a brand x system-sketch similarity row.

For a brand, every product image linked to it is turned into a feature
vector (pre-computed or ResNet-50 on the fly). Vectors are averaged per
``gender_category``. Each system sketch is then scored by cosine similarity
against the average of its own category. Results are persisted as ``.npy``
files under ``settings.RECOMMEND_PATH_PREFIX``.
"""
import io
import logging
import os
import sys
import threading
from collections import defaultdict

import numpy as np
import pymysql
import torch
from PIL import Image
from fastapi import HTTPException, APIRouter
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
from minio import Minio
from torchvision import models, transforms

from app.core.mysql_config import DB_CONFIG
from app.core.new_config import settings

# Re-wrap stdout so non-ASCII output is always encoded as UTF-8,
# independent of the platform's default console encoding.
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')

logger = logging.getLogger(__name__)
router = APIRouter()

minio_client = Minio(settings.MINIO_URL,
                     access_key=settings.MINIO_ACCESS,
                     secret_key=settings.MINIO_SECRET,
                     secure=settings.MINIO_SECURE)

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# ResNet-50 with the final fully-connected layer removed; the remaining
# output ([1, 2048, 1, 1]) is used as a 2048-d image embedding.
# NOTE(review): `pretrained=True` is deprecated in newer torchvision
# releases in favour of `weights=...`; kept for version compatibility.
resnet_model = models.resnet50(pretrained=True)
resnet_model = torch.nn.Sequential(*list(resnet_model.children())[:-1])
resnet_model.eval()

# Persisted artefact paths.
_BRAND_MATRIX_PATH = f"{settings.RECOMMEND_PATH_PREFIX}brand_matrix.npy"
# Older versions of this module saved the matrix under this name while
# loading from `_BRAND_MATRIX_PATH`; it is still written for any reader
# that depends on it, and used as a load fallback.
_LEGACY_BRAND_MATRIX_PATH = f"{settings.RECOMMEND_PATH_PREFIX}brand_feature_matrix.npy"
_BRAND_INDEX_MAP_PATH = f"{settings.RECOMMEND_PATH_PREFIX}brand_index_map.npy"

# Serialises the load -> modify -> save cycle on the brand matrix files,
# since initialisations run on worker threads.
_BRAND_MATRIX_LOCK = threading.Lock()


def get_sketch_image_from_minio(sketch_path: str):
    """Download ``bucket/object`` from MinIO and return a preprocessed tensor.

    Args:
        sketch_path: ``"<bucket>/<object name>"``.

    Returns:
        A ``[1, 3, 224, 224]`` float tensor, or ``None`` when the path has no
        ``/`` separator or the download/decoding fails (failure is logged).
    """
    path_parts = sketch_path.split('/', 1)
    if len(path_parts) != 2:
        return None
    bucket_name, file_name = path_parts
    obj = None
    try:
        obj = minio_client.get_object(bucket_name, file_name)
        # Force 3 channels: grayscale / RGBA / palette images would otherwise
        # yield tensors that the 3-channel Normalize above rejects.
        img = Image.open(io.BytesIO(obj.read())).convert("RGB")
        return transform(img).unsqueeze(0)
    except Exception as e:
        logger.warning(f"Fetch image failed [{sketch_path}]: {e}")
        return None
    finally:
        if obj is not None:
            # Release the HTTP response back to the client's connection pool.
            obj.close()
            obj.release_conn()


def extract_feature_vector_from_resnet(sketch_path: str) -> np.ndarray:
    """Return the 2048-d ResNet-50 embedding of the image at *sketch_path*.

    A zero vector of the same length is returned when the image cannot be
    fetched, so callers always receive an array of a fixed shape.
    """
    img_tensor = get_sketch_image_from_minio(sketch_path)
    if img_tensor is None:
        return np.zeros(2048, dtype=np.float32)
    with torch.no_grad():
        vec = resnet_model(img_tensor)  # [1, 2048, 1, 1]
    return vec.squeeze().cpu().numpy()


# Preloaded pickled dicts stored in .npy files.
# BRAND_FEATURES is looked up by product id (presumably id -> vector; verify
# against the producer). SYSTEM_FEATURES maps system sketch path -> vector.
BRAND_FEATURES = np.load(f'{settings.RECOMMEND_PATH_PREFIX}brand_feature.npy', allow_pickle=True).item()
SYSTEM_FEATURES = np.load(f'{settings.RECOMMEND_PATH_PREFIX}sketch_feature_dict.npy', allow_pickle=True).item()


def save_sketch_to_iid():
    """Assign 1-based integer ids to system sketches and persist the mapping."""
    mapping = {
        sketch_path: iid
        for iid, sketch_path in enumerate(SYSTEM_FEATURES.keys(), start=1)
    }
    np.save(f"{settings.RECOMMEND_PATH_PREFIX}sketch_to_iid.npy", mapping)


def load_sketch_to_iid():
    """Load the sketch-path -> iid mapping, creating it on first use."""
    path = f"{settings.RECOMMEND_PATH_PREFIX}sketch_to_iid.npy"
    if not os.path.exists(path):
        save_sketch_to_iid()
    return np.load(path, allow_pickle=True).item()


sketch_to_iid = load_sketch_to_iid()


def get_new_category(gender: str, sketch_category: str) -> str:
    """Build the ``"<gender>_<category>"`` key (lower-cased)."""
    return f"{gender.lower()}_{sketch_category.lower()}"


def get_category_from_path(path: str) -> str:
    """Derive the ``"<gender>_<category>"`` key from a system sketch path.

    Uses the 3rd and 4th ``/``-separated components; paths with fewer than
    four components map to ``"unknown_unknown"``.
    """
    parts = path.split('/')
    if len(parts) >= 4:
        return f"{parts[2].lower()}_{parts[3].lower()}"
    return "unknown_unknown"


def load_brand_matrix():
    """Load the persisted brand similarity matrix and its brand -> row map.

    The matrix is read from the current path, falling back to the legacy
    path. If neither exists, an empty ``(0, len(sketch_to_iid))`` matrix and an
    empty map are returned, because saved row indices are meaningless without
    their matrix. A missing map file alone yields an empty map.

    Returns:
        ``(matrix, index_map)``.
    """
    matrix = None
    for path in (_BRAND_MATRIX_PATH, _LEGACY_BRAND_MATRIX_PATH):
        try:
            matrix = np.load(path)
            break
        except FileNotFoundError:
            continue
    if matrix is None:
        return np.zeros((0, len(sketch_to_iid)), dtype=np.float32), {}
    try:
        index_map = np.load(_BRAND_INDEX_MAP_PATH, allow_pickle=True).item()
    except FileNotFoundError:
        index_map = {}
    return matrix, index_map


def cosine_similarity(vec1, vec2):
    """Cosine similarity of two vectors; ``0.0`` if either has zero norm."""
    norm = np.linalg.norm(vec1) * np.linalg.norm(vec2)
    return np.dot(vec1, vec2) / (norm + 1e-10) if norm != 0 else 0.0


def getNewCategory(gender, sketch_category):
    """Backward-compatible camelCase alias of :func:`get_new_category`.

    This used to be a debug stub returning the literal ``"None"``. That sent
    every brand feature into a category no system sketch matches.
    """
    return get_new_category(gender, sketch_category)


def _collect_category_features(sketch_data) -> dict:
    """Group brand product feature vectors by ``gender_category`` key."""
    features_by_cat = defaultdict(list)
    for _id, sketch_path, gender, sketch_category in sketch_data:
        if gender is None or sketch_category is None:
            # Cannot be categorised, so it can never match a system sketch.
            continue
        cat = get_new_category(gender, sketch_category)
        feat = BRAND_FEATURES.get(_id)
        if feat is None:
            # Explicit None check: `array or fallback` raises ValueError for
            # multi-element numpy arrays.
            feat = extract_feature_vector_from_resnet(sketch_path)
        features_by_cat[cat].append(feat)
    return features_by_cat


def _ensure_brand_row(brand_matrix, brand_index_map, brand_id, n_sketch):
    """Return ``(matrix, row_idx)`` with a zeroed row reserved for *brand_id*.

    Pads missing sketch columns, appends rows when the index map points past
    the end of the matrix, and clears any stale scores from a previous run.
    """
    if brand_matrix.shape[1] < n_sketch:
        pad = np.zeros((brand_matrix.shape[0], n_sketch - brand_matrix.shape[1]),
                       dtype=brand_matrix.dtype)
        brand_matrix = np.hstack([brand_matrix, pad])
    row_idx = brand_index_map.get(brand_id)
    if row_idx is None:
        row_idx = brand_matrix.shape[0]
        brand_index_map[brand_id] = row_idx
    if row_idx >= brand_matrix.shape[0]:
        extra = np.zeros((row_idx + 1 - brand_matrix.shape[0], brand_matrix.shape[1]),
                         dtype=brand_matrix.dtype)
        brand_matrix = np.vstack([brand_matrix, extra])
    brand_matrix[row_idx, :] = 0
    return brand_matrix, row_idx


def calculate_brand_matrix(sketch_data, brand_id: int) -> np.ndarray:
    """Compute and persist the similarity row of *brand_id* against all sketches.

    Args:
        sketch_data: rows of ``(id, img_url, gender, category)``.
        brand_id: brand whose row is (re)computed.

    Returns:
        A ``(1, n_columns)`` copy of the brand's row.
    """
    # 1. Per-category mean feature vectors (slow part; done outside the lock).
    brand_avg = {
        cat: np.mean(feats, axis=0)
        for cat, feats in _collect_category_features(sketch_data).items()
        if feats
    }

    # 2. Column index for each system sketch iid.
    sketch_list = sorted(sketch_to_iid.values())
    sketch_index = {iid: idx for idx, iid in enumerate(sketch_list)}
    n_sketch = len(sketch_list)

    with _BRAND_MATRIX_LOCK:
        # 3-4. Load the matrix and reserve a clean row for this brand.
        brand_matrix, brand_index_map = load_brand_matrix()
        brand_matrix, row_idx = _ensure_brand_row(
            brand_matrix, brand_index_map, brand_id, n_sketch)

        # 5. Score each system sketch against its category's brand average.
        for sketch_path, sys_vec in SYSTEM_FEATURES.items():
            iid = sketch_to_iid.get(sketch_path)
            if not iid or iid not in sketch_index:
                continue
            avg_vec = brand_avg.get(get_category_from_path(sketch_path))
            if avg_vec is not None:
                brand_matrix[row_idx, sketch_index[iid]] = cosine_similarity(avg_vec, sys_vec)

        # 6. Persist to the path that load_brand_matrix reads (and the legacy one).
        np.save(_BRAND_MATRIX_PATH, brand_matrix)
        np.save(_LEGACY_BRAND_MATRIX_PATH, brand_matrix)
        np.save(_BRAND_INDEX_MAP_PATH, brand_index_map)

        return brand_matrix[row_idx:row_idx + 1].copy()


def _fetch_brand_sketch_rows(brand_id: int):
    """Fetch ``(id, img_url, gender, category)`` rows for the brand's libraries."""
    conn = pymysql.connect(**DB_CONFIG)
    try:
        with conn.cursor() as cursor:
            cursor.execute("""
                SELECT id, img_url, gender, category
                FROM product_image_attribute
                WHERE library_id IN (SELECT library_id FROM brand_rel_library WHERE brand_id = %s)
            """, (brand_id,))
            return cursor.fetchall()
    finally:
        conn.close()


def _initialize_brand(brand_id: int) -> None:
    """Blocking worker: query the brand's images and rebuild its matrix row."""
    sketch_data = _fetch_brand_sketch_rows(brand_id)
    calculate_brand_matrix(sketch_data, brand_id)


@router.get("/brand_dna_initialize/{brand_id}")
async def brand_dna_initialize(brand_id: int):
    """Recompute the brand's similarity row.

    Returns ``{"success": True}`` on success. On failure it logs the error and
    returns a 500 JSON response with ``success`` set to ``False``.
    """
    try:
        # DB, MinIO and ResNet work is blocking; keep it off the event loop.
        await run_in_threadpool(_initialize_brand, brand_id)
        return {"success": True}
    except HTTPException:
        # Already an explicit HTTP error; propagate unchanged.
        raise
    except Exception as e:
        logger.error(f"品牌初始化失败 [{brand_id}]: {e}", exc_info=True)
        return JSONResponse(
            status_code=500,
            content={"success": False, "message": "品牌初始化失败"}
        )