新推荐接口first commit
This commit is contained in:
@@ -1,25 +1,34 @@
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from typing import List
|
||||
from collections import defaultdict
|
||||
|
||||
import numpy as np
|
||||
import pymysql
|
||||
import torch
|
||||
from PIL import Image
|
||||
from apscheduler.schedulers.background import BackgroundScheduler
|
||||
from apscheduler.triggers.cron import CronTrigger
|
||||
from fastapi import HTTPException, APIRouter
|
||||
from fastapi.responses import JSONResponse
|
||||
from minio import Minio
|
||||
from torchvision import models, transforms
|
||||
|
||||
from app.core.mysql_config import DB_CONFIG
|
||||
from app.core.new_config import settings
|
||||
import pymysql
|
||||
from app.core.config import DB_CONFIG, TABLE_CATEGORIES, RECOMMEND_PATH_PREFIX
|
||||
from minio import Minio
|
||||
import torch
|
||||
from torchvision import models, transforms
|
||||
from PIL import Image
|
||||
import os
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')
|
||||
logger = logging.getLogger()
|
||||
router = APIRouter()
|
||||
minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)
|
||||
|
||||
# MinIO 配置
|
||||
minio_client = Minio(
|
||||
"www.minio.aida.com.hk:12024",
|
||||
access_key="admin",
|
||||
secret_key="Aidlab123123!",
|
||||
secure=True
|
||||
)
|
||||
|
||||
transform = transforms.Compose([
|
||||
transforms.Resize((224, 224)),
|
||||
@@ -58,8 +67,8 @@ def extract_feature_vector_from_resnet(sketch_path: str) -> np.ndarray:
|
||||
|
||||
|
||||
# 预加载
|
||||
BRAND_FEATURES = np.load(f'{settings.RECOMMEND_PATH_PREFIX}brand_feature.npy', allow_pickle=True).item()
|
||||
SYSTEM_FEATURES = np.load(f'{settings.RECOMMEND_PATH_PREFIX}sketch_feature_dict.npy', allow_pickle=True).item()
|
||||
BRAND_FEATURES = np.load(f'{RECOMMEND_PATH_PREFIX}brand_feature.npy', allow_pickle=True).item()
|
||||
SYSTEM_FEATURES = np.load(f'{RECOMMEND_PATH_PREFIX}sketch_feature_dict.npy', allow_pickle=True).item()
|
||||
|
||||
|
||||
def save_sketch_to_iid():
|
||||
@@ -67,11 +76,11 @@ def save_sketch_to_iid():
|
||||
sketch_path: iid
|
||||
for iid, sketch_path in enumerate(SYSTEM_FEATURES.keys(), start=1)
|
||||
}
|
||||
np.save(f"{settings.RECOMMEND_PATH_PREFIX}sketch_to_iid.npy", sketch_to_iid)
|
||||
np.save(f"{RECOMMEND_PATH_PREFIX}sketch_to_iid.npy", sketch_to_iid)
|
||||
|
||||
|
||||
def load_sketch_to_iid():
|
||||
path = f"{settings.RECOMMEND_PATH_PREFIX}sketch_to_iid.npy"
|
||||
path = f"{RECOMMEND_PATH_PREFIX}sketch_to_iid.npy"
|
||||
if os.path.exists(path):
|
||||
return np.load(path, allow_pickle=True).item()
|
||||
save_sketch_to_iid()
|
||||
@@ -81,7 +90,7 @@ def load_sketch_to_iid():
|
||||
sketch_to_iid = load_sketch_to_iid()
|
||||
|
||||
|
||||
def get_new_category(gender: str, sketch_category: str) -> str:
|
||||
def getNewCategory(gender: str, sketch_category: str) -> str:
|
||||
return f"{gender.lower()}_{sketch_category.lower()}"
|
||||
|
||||
|
||||
@@ -94,8 +103,8 @@ def get_category_from_path(path: str) -> str:
|
||||
|
||||
def load_brand_matrix():
|
||||
"""单独加载 brand_matrix 和 brand_index_map"""
|
||||
mat_path = f"{settings.RECOMMEND_PATH_PREFIX}brand_matrix.npy"
|
||||
idx_path = f"{settings.RECOMMEND_PATH_PREFIX}brand_index_map.npy"
|
||||
mat_path = f"{RECOMMEND_PATH_PREFIX}brand_matrix.npy"
|
||||
idx_path = f"{RECOMMEND_PATH_PREFIX}brand_index_map.npy"
|
||||
try:
|
||||
matrix = np.load(mat_path)
|
||||
index_map = np.load(idx_path, allow_pickle=True).item()
|
||||
@@ -104,19 +113,11 @@ def load_brand_matrix():
|
||||
index_map = {}
|
||||
return matrix, index_map
|
||||
|
||||
|
||||
def cosine_similarity(vec1, vec2):
|
||||
"""计算余弦相似度(增加零值处理)"""
|
||||
norm = np.linalg.norm(vec1) * np.linalg.norm(vec2)
|
||||
return np.dot(vec1, vec2) / (norm + 1e-10) if norm != 0 else 0.0
|
||||
|
||||
|
||||
def getNewCategory(gender, sketch_category):
|
||||
print(gender)
|
||||
print(sketch_category)
|
||||
return "None"
|
||||
|
||||
|
||||
def calculate_brand_matrix(sketch_data, brand_id: int) -> np.ndarray:
|
||||
# 1. 收集品牌-分类-特征
|
||||
brand_feature = defaultdict(lambda: defaultdict(list))
|
||||
@@ -163,11 +164,11 @@ def calculate_brand_matrix(sketch_data, brand_id: int) -> np.ndarray:
|
||||
brand_matrix[row_idx, sketch_index[iid]] = cos_sim
|
||||
|
||||
# 7. 持久化
|
||||
np.save(f"{settings.RECOMMEND_PATH_PREFIX}brand_feature_matrix.npy", brand_matrix)
|
||||
np.save(f"{settings.RECOMMEND_PATH_PREFIX}brand_index_map.npy", brand_index_map)
|
||||
np.save(f"{RECOMMEND_PATH_PREFIX}brand_feature_matrix.npy", brand_matrix)
|
||||
np.save(f"{RECOMMEND_PATH_PREFIX}brand_index_map.npy", brand_index_map)
|
||||
|
||||
# 返回该品牌对应行
|
||||
return brand_matrix[row_idx:row_idx + 1]
|
||||
return brand_matrix[row_idx:row_idx+1]
|
||||
|
||||
|
||||
@router.get("/brand_dna_initialize/{brand_id}")
|
||||
@@ -177,12 +178,14 @@ async def brand_dna_initialize(brand_id: int):
|
||||
conn = pymysql.connect(**DB_CONFIG)
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("""
|
||||
SELECT id, img_url, gender, category
|
||||
FROM product_image_attribute
|
||||
WHERE library_id IN (SELECT library_id
|
||||
FROM brand_rel_library
|
||||
WHERE brand_id = %s)
|
||||
""", (brand_id,))
|
||||
SELECT id, img_url, gender, category
|
||||
FROM product_image_attribute
|
||||
WHERE library_id IN (
|
||||
SELECT library_id
|
||||
FROM brand_rel_library
|
||||
WHERE brand_id = %s
|
||||
)
|
||||
""", (brand_id,))
|
||||
sketch_data = cursor.fetchall()
|
||||
|
||||
# 触发计算并持久化,若内部出错会抛异常
|
||||
|
||||
Reference in New Issue
Block a user