Files
AiDA_Python/app/api/api_brand_dna_initialize.py
2025-06-10 10:54:20 +08:00

213 lines
6.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import io
import logging
import os
import sys
import threading
import time
from collections import defaultdict
from typing import List

import numpy as np
import pymysql
import torch
from apscheduler.schedulers.background import BackgroundScheduler
from apscheduler.triggers.cron import CronTrigger
from fastapi import HTTPException, APIRouter
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
from minio import Minio
from PIL import Image
from torchvision import models, transforms

from app.core.config import DB_CONFIG, TABLE_CATEGORIES, RECOMMEND_PATH_PREFIX
from app.service.recommend.service import load_resources, matrix_data
# Make stdout emit UTF-8 (this module logs Chinese text). reconfigure() switches
# the encoding of the existing stream in place, so its buffering settings
# (line buffering / write-through) are kept and no second TextIOWrapper ends up
# writing to the same underlying buffer.
if hasattr(sys.stdout, "reconfigure"):
    sys.stdout.reconfigure(encoding="utf-8")
# Root logger: records are handled by whatever handlers the application
# installs on the root.
logger = logging.getLogger()
router = APIRouter()
# MinIO configuration.
# SECURITY: the access/secret keys were committed to source control. They can
# now be supplied via the MINIO_ENDPOINT / MINIO_ACCESS_KEY / MINIO_SECRET_KEY
# environment variables; the literal fallbacks only keep existing deployments
# working. TODO: remove the fallbacks and rotate the exposed secret.
minio_client = Minio(
    os.environ.get("MINIO_ENDPOINT", "www.minio.aida.com.hk:12024"),
    access_key=os.environ.get("MINIO_ACCESS_KEY", "admin"),
    secret_key=os.environ.get("MINIO_SECRET_KEY", "Aidlab123123!"),
    secure=True,
)

# Preprocessing applied to every sketch before ResNet50:
# resize to 224x224, convert to a tensor, normalise with ImageNet mean/std.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# ResNet50 with its final fully-connected layer removed, so a forward pass
# yields the pooled feature map ([1, 2048, 1, 1] per image) instead of logits.
# NOTE(review): `pretrained=True` is deprecated in newer torchvision. Its
# equivalent is `weights=models.ResNet50_Weights.IMAGENET1K_V1`. Do not use
# `.DEFAULT` (V2 weights): it would change the features relative to any
# precomputed ones.
resnet_model = models.resnet50(pretrained=True)
resnet_model = torch.nn.Sequential(*list(resnet_model.children())[:-1])
resnet_model.eval()
def get_sketch_image_from_minio(sketch_path: str):
    """Download a sketch from MinIO and return it as a model-ready batch tensor.

    Args:
        sketch_path: ``"<bucket>/<object name>"``; everything before the first
            ``/`` is the bucket.

    Returns:
        The image run through ``transform`` with a leading batch dimension, or
        ``None`` if the path is empty/has no ``/`` or the download/decoding
        fails (the failure is logged as a warning).
    """
    path_parts = sketch_path.split('/', 1) if sketch_path else []
    if len(path_parts) != 2:
        return None
    bucket_name, file_name = path_parts
    response = None
    try:
        response = minio_client.get_object(bucket_name, file_name)
        # Force 3 channels: grayscale/palette/RGBA images would otherwise fail
        # the 3-channel Normalize and be silently dropped.
        # NOTE(review): RGBA alpha is discarded, not composited onto white;
        # confirm this matches how SYSTEM_FEATURES were extracted.
        img = Image.open(io.BytesIO(response.read())).convert("RGB")
        return transform(img).unsqueeze(0)
    except Exception as e:
        logger.warning(f"Fetch image failed [{sketch_path}]: {e}")
        return None
    finally:
        # MinIO requires closing the response and returning its pooled HTTP
        # connection; otherwise connections leak on every call.
        if response is not None:
            response.close()
            response.release_conn()
def extract_feature_vector_from_resnet(sketch_path: str) -> np.ndarray:
    """Return the ResNet50 pooled feature vector for the sketch at *sketch_path*.

    When ``get_sketch_image_from_minio`` cannot produce an image (returns
    ``None``), a float32 zero vector of length 2048 is returned instead.
    """
    batch = get_sketch_image_from_minio(sketch_path)
    if batch is not None:
        with torch.no_grad():
            pooled = resnet_model(batch)  # [1, 2048, 1, 1]
            return pooled.squeeze().cpu().numpy()
    return np.zeros(2048, dtype=np.float32)
def _load_pickled_dict(file_name: str) -> dict:
    """Load a dict stored with ``np.save`` at ``RECOMMEND_PATH_PREFIX + file_name``."""
    # allow_pickle=True unpickles the file contents: only safe for trusted,
    # locally produced files.
    return np.load(f"{RECOMMEND_PATH_PREFIX}{file_name}", allow_pickle=True).item()


# Preloaded once at import time.
BRAND_FEATURES = _load_pickled_dict("brand_feature.npy")
SYSTEM_FEATURES = _load_pickled_dict("sketch_feature_dict.npy")
def save_sketch_to_iid():
    """Number the SYSTEM_FEATURES sketch paths from 1 (dict order) and save the mapping.

    The ``{sketch_path: iid}`` dict is written to
    ``RECOMMEND_PATH_PREFIX + "sketch_to_iid.npy"``.
    """
    iid_by_path = dict(zip(SYSTEM_FEATURES.keys(), range(1, len(SYSTEM_FEATURES) + 1)))
    np.save(f"{RECOMMEND_PATH_PREFIX}sketch_to_iid.npy", iid_by_path)
def load_sketch_to_iid():
    """Return the ``{sketch_path: iid}`` mapping, building and saving it first if absent."""
    cache_file = f"{RECOMMEND_PATH_PREFIX}sketch_to_iid.npy"
    if not os.path.exists(cache_file):
        save_sketch_to_iid()
    return np.load(cache_file, allow_pickle=True).item()


# Loaded once at import; its iids define the column layout of the brand matrix.
sketch_to_iid = load_sketch_to_iid()
def getNewCategory(gender: str, sketch_category: str) -> str:
    """Build the lower-cased ``<gender>_<category>`` key used to group sketches."""
    return "_".join((gender.lower(), sketch_category.lower()))
def get_category_from_path(path: str) -> str:
    """Derive the lower-cased ``<gender>_<category>`` key from a sketch path.

    The 3rd and 4th ``/``-separated segments are taken as gender and category;
    paths with fewer than four segments yield ``"unknown_unknown"``.
    """
    segments = path.split('/')
    if len(segments) < 4:
        return "unknown_unknown"
    gender, category = segments[2], segments[3]
    return f"{gender.lower()}_{category.lower()}"
def load_brand_matrix():
    """Load the persisted brand-feature matrix and its ``{brand_id: row}`` map.

    Reads ``brand_feature_matrix.npy``, the file ``calculate_brand_matrix``
    saves. Reading any other file would make every update start from a matrix
    that does not contain the previously saved rows, so earlier brands would be
    lost. If that file does not exist yet, the legacy ``brand_matrix.npy`` is
    used as the starting point.

    Returns:
        ``(matrix, index_map)``. If either file is missing, an empty float32
        matrix with ``len(sketch_to_iid)`` columns and an empty dict.
    """
    mat_path = f"{RECOMMEND_PATH_PREFIX}brand_feature_matrix.npy"
    if not os.path.exists(mat_path):
        mat_path = f"{RECOMMEND_PATH_PREFIX}brand_matrix.npy"
    idx_path = f"{RECOMMEND_PATH_PREFIX}brand_index_map.npy"
    try:
        matrix = np.load(mat_path)
        index_map = np.load(idx_path, allow_pickle=True).item()
    except FileNotFoundError:
        # Matrix and map only make sense together: if either is missing,
        # start from scratch.
        matrix = np.zeros((0, len(sketch_to_iid)), dtype=np.float32)
        index_map = {}
    return matrix, index_map
def cosine_similarity(vec1, vec2):
    """Return the cosine similarity of two vectors, or 0.0 if either has zero norm."""
    denom = np.linalg.norm(vec1) * np.linalg.norm(vec2)
    if denom == 0:
        return 0.0
    return np.dot(vec1, vec2) / (denom + 1e-10)
def calculate_brand_matrix(sketch_data, brand_id: int) -> np.ndarray:
    """Recompute one brand's row of the brand-feature similarity matrix and persist it.

    Steps:
      1. For every ``(id, img_url, gender, category)`` row in *sketch_data*, take
         the vector ``BRAND_FEATURES[id]`` or, if absent, extract one from the
         image with ResNet50. Group the vectors by
         ``getNewCategory(gender, category)``. Rows with a missing image path,
         gender or category are skipped with a warning.
      2. Average the vectors of each category.
      3. For every sketch in ``SYSTEM_FEATURES`` whose category (derived from its
         path) the brand has, store the cosine similarity between the sketch's
         vector and the brand's category mean. All other cells of the row are 0.
      4. Save the whole matrix and the ``{brand_id: row}`` map back to disk.

    Args:
        sketch_data: rows of ``(id, img_url, gender, category)`` for the brand.
        brand_id: brand whose row is created or overwritten.

    Returns:
        A ``(1, n_sketch)`` view of the brand's row in the updated matrix.

    Raises:
        ValueError: if the stored matrix does not have one column per entry of
            ``sketch_to_iid``, or the stored row index for the brand is out of range.
    """
    # 1. Group the brand's feature vectors by "<gender>_<category>".
    category_feats = defaultdict(list)
    for _id, sketch_path, gender, sketch_category in sketch_data:
        if not (sketch_path and gender and sketch_category):
            logger.warning(f"Skip incomplete sketch row [{_id}] of brand [{brand_id}]")
            continue
        # Explicit None check: `cached or fallback` would call bool() on the
        # cached vector, which raises ValueError for a multi-element ndarray.
        feat = BRAND_FEATURES.get(_id)
        if feat is None:
            feat = extract_feature_vector_from_resnet(sketch_path)
        category_feats[getNewCategory(gender, sketch_category)].append(feat)

    # 2. Column layout: sketch iids in ascending order.
    sketch_list = sorted(sketch_to_iid.values())
    sketch_index = {iid: idx for idx, iid in enumerate(sketch_list)}
    n_sketch = len(sketch_list)

    # 3. Load the current matrix and check it matches the column layout.
    brand_matrix, brand_index_map = load_brand_matrix()
    if brand_matrix.ndim != 2 or brand_matrix.shape[1] != n_sketch:
        raise ValueError(
            f"brand matrix shape {brand_matrix.shape} does not match {n_sketch} sketches"
        )

    # 4. Reuse the brand's row, or append a fresh zero row for a new brand.
    row_idx = brand_index_map.get(brand_id)
    if row_idx is None:
        row_idx = brand_matrix.shape[0]
        brand_index_map[brand_id] = row_idx
        brand_matrix = np.vstack([
            brand_matrix,
            np.zeros((1, n_sketch), dtype=np.float32),
        ])
    elif row_idx >= brand_matrix.shape[0]:
        raise ValueError(
            f"row {row_idx} of brand [{brand_id}] is outside the "
            f"{brand_matrix.shape[0]}-row brand matrix"
        )
    else:
        # Clear old similarities so categories the brand no longer has do
        # not keep stale values.
        brand_matrix[row_idx, :] = 0

    # 5. Mean vector per category (every list is non-empty by construction).
    brand_avg = {cat: np.mean(feats, axis=0) for cat, feats in category_feats.items()}

    # 6. Fill the row with similarities to the brand's category means.
    for sketch_path, sys_vec in SYSTEM_FEATURES.items():
        iid = sketch_to_iid.get(sketch_path)
        if not iid or iid not in sketch_index:
            continue
        avg_vec = brand_avg.get(get_category_from_path(sketch_path))
        if avg_vec is not None:
            brand_matrix[row_idx, sketch_index[iid]] = cosine_similarity(avg_vec, sys_vec)

    # 7. Persist the matrix and the row map.
    np.save(f"{RECOMMEND_PATH_PREFIX}brand_feature_matrix.npy", brand_matrix)
    np.save(f"{RECOMMEND_PATH_PREFIX}brand_index_map.npy", brand_index_map)
    return brand_matrix[row_idx:row_idx + 1]
# The brand-matrix files are updated with a load -> modify -> save cycle. Now
# that the work runs in worker threads, this lock keeps concurrent requests
# from overwriting each other's rows (as before, they run one at a time).
_brand_matrix_lock = threading.Lock()


def _run_brand_dna_initialize(brand_id: int) -> None:
    """Blocking part of the endpoint: fetch the brand's sketches, then rebuild its row."""
    conn = pymysql.connect(**DB_CONFIG)
    try:
        with conn.cursor() as cursor:
            cursor.execute("""
                SELECT id, img_url, gender, category
                FROM product_image_attribute
                WHERE library_id IN (
                    SELECT library_id
                    FROM brand_rel_library
                    WHERE brand_id = %s
                )
            """, (brand_id,))
            sketch_data = cursor.fetchall()
    finally:
        # Release the DB connection before the (long) feature computation.
        conn.close()
    with _brand_matrix_lock:
        calculate_brand_matrix(sketch_data, brand_id)


@router.get("/brand_dna_initialize/{brand_id}")
async def brand_dna_initialize(brand_id: int):
    """Initialise (or refresh) the brand-DNA similarity row for *brand_id*.

    Loads the ``product_image_attribute`` rows whose ``library_id`` is linked
    to the brand in ``brand_rel_library``, then recomputes and persists the
    brand's row via ``calculate_brand_matrix``.

    Returns:
        ``{"success": True}`` on success. On any other error the exception is
        logged and a 500 JSON response ``{"success": False, "message": ...}``
        is returned.
    """
    try:
        # DB access, MinIO downloads and ResNet inference all block. Run them
        # in the thread pool so the event loop keeps serving other requests.
        await run_in_threadpool(_run_brand_dna_initialize, brand_id)
        return {"success": True}
    except HTTPException:
        # Already an explicit HTTP error: propagate unchanged.
        raise
    except Exception as e:
        logger.error(f"品牌初始化失败 [{brand_id}]: {e}", exc_info=True)
        # Failure payload with an explicit 500 status code.
        return JSONResponse(
            status_code=500,
            content={"success": False, "message": "品牌初始化失败"}
        )