Compare commits
33 Commits
417528f8cd
...
dev-ltx
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fb46a9521d | ||
|
|
b90688f835 | ||
| 7e30779aec | |||
| f7294f5966 | |||
| 0ac5a4e0a8 | |||
| 40b57b749c | |||
|
|
b8a538a8a1 | ||
|
|
29b4f43a27 | ||
|
|
69dc20207d | ||
|
|
18979af604 | ||
|
|
74406f9be4 | ||
|
|
df99e3ac76 | ||
|
|
19346c2eb7 | ||
|
|
2af9cbfe78 | ||
| fe12b5697d | |||
| c04d4877b0 | |||
| 91016e6cae | |||
| 0f4bb260ad | |||
| c792106f02 | |||
| deac5a4cab | |||
| 15682036b3 | |||
| 9ba3a0ca49 | |||
| f6963070fb | |||
| 12f5ca3ca3 | |||
| 19110f51bf | |||
| e04636ce21 | |||
| 2a50e7040e | |||
| a6f3bda9f7 | |||
| c18f45e549 | |||
| 4951fab71a | |||
| aa57478852 | |||
| 2a6c48d937 | |||
|
|
fed3fcdf85 |
@@ -1,25 +1,34 @@
|
|||||||
import io
|
import io
|
||||||
import logging
|
import logging
|
||||||
import os
|
|
||||||
import sys
|
import sys
|
||||||
|
import time
|
||||||
|
from typing import List
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pymysql
|
from apscheduler.schedulers.background import BackgroundScheduler
|
||||||
import torch
|
from apscheduler.triggers.cron import CronTrigger
|
||||||
from PIL import Image
|
|
||||||
from fastapi import HTTPException, APIRouter
|
from fastapi import HTTPException, APIRouter
|
||||||
from fastapi.responses import JSONResponse
|
|
||||||
from minio import Minio
|
|
||||||
from torchvision import models, transforms
|
|
||||||
|
|
||||||
from app.core.mysql_config import DB_CONFIG
|
import pymysql
|
||||||
from app.core.new_config import settings
|
from app.core.config import DB_CONFIG, TABLE_CATEGORIES, RECOMMEND_PATH_PREFIX
|
||||||
|
from minio import Minio
|
||||||
|
import torch
|
||||||
|
from torchvision import models, transforms
|
||||||
|
from PIL import Image
|
||||||
|
import os
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
|
||||||
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')
|
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')
|
||||||
logger = logging.getLogger()
|
logger = logging.getLogger()
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)
|
|
||||||
|
# MinIO 配置
|
||||||
|
minio_client = Minio(
|
||||||
|
"www.minio.aida.com.hk:12024",
|
||||||
|
access_key="admin",
|
||||||
|
secret_key="Aidlab123123!",
|
||||||
|
secure=True
|
||||||
|
)
|
||||||
|
|
||||||
transform = transforms.Compose([
|
transform = transforms.Compose([
|
||||||
transforms.Resize((224, 224)),
|
transforms.Resize((224, 224)),
|
||||||
@@ -58,8 +67,8 @@ def extract_feature_vector_from_resnet(sketch_path: str) -> np.ndarray:
|
|||||||
|
|
||||||
|
|
||||||
# 预加载
|
# 预加载
|
||||||
BRAND_FEATURES = np.load(f'{settings.RECOMMEND_PATH_PREFIX}brand_feature.npy', allow_pickle=True).item()
|
BRAND_FEATURES = np.load(f'{RECOMMEND_PATH_PREFIX}brand_feature.npy', allow_pickle=True).item()
|
||||||
SYSTEM_FEATURES = np.load(f'{settings.RECOMMEND_PATH_PREFIX}sketch_feature_dict.npy', allow_pickle=True).item()
|
SYSTEM_FEATURES = np.load(f'{RECOMMEND_PATH_PREFIX}sketch_feature_dict.npy', allow_pickle=True).item()
|
||||||
|
|
||||||
|
|
||||||
def save_sketch_to_iid():
|
def save_sketch_to_iid():
|
||||||
@@ -67,11 +76,11 @@ def save_sketch_to_iid():
|
|||||||
sketch_path: iid
|
sketch_path: iid
|
||||||
for iid, sketch_path in enumerate(SYSTEM_FEATURES.keys(), start=1)
|
for iid, sketch_path in enumerate(SYSTEM_FEATURES.keys(), start=1)
|
||||||
}
|
}
|
||||||
np.save(f"{settings.RECOMMEND_PATH_PREFIX}sketch_to_iid.npy", sketch_to_iid)
|
np.save(f"{RECOMMEND_PATH_PREFIX}sketch_to_iid.npy", sketch_to_iid)
|
||||||
|
|
||||||
|
|
||||||
def load_sketch_to_iid():
|
def load_sketch_to_iid():
|
||||||
path = f"{settings.RECOMMEND_PATH_PREFIX}sketch_to_iid.npy"
|
path = f"{RECOMMEND_PATH_PREFIX}sketch_to_iid.npy"
|
||||||
if os.path.exists(path):
|
if os.path.exists(path):
|
||||||
return np.load(path, allow_pickle=True).item()
|
return np.load(path, allow_pickle=True).item()
|
||||||
save_sketch_to_iid()
|
save_sketch_to_iid()
|
||||||
@@ -81,7 +90,7 @@ def load_sketch_to_iid():
|
|||||||
sketch_to_iid = load_sketch_to_iid()
|
sketch_to_iid = load_sketch_to_iid()
|
||||||
|
|
||||||
|
|
||||||
def get_new_category(gender: str, sketch_category: str) -> str:
|
def getNewCategory(gender: str, sketch_category: str) -> str:
|
||||||
return f"{gender.lower()}_{sketch_category.lower()}"
|
return f"{gender.lower()}_{sketch_category.lower()}"
|
||||||
|
|
||||||
|
|
||||||
@@ -94,8 +103,8 @@ def get_category_from_path(path: str) -> str:
|
|||||||
|
|
||||||
def load_brand_matrix():
|
def load_brand_matrix():
|
||||||
"""单独加载 brand_matrix 和 brand_index_map"""
|
"""单独加载 brand_matrix 和 brand_index_map"""
|
||||||
mat_path = f"{settings.RECOMMEND_PATH_PREFIX}brand_matrix.npy"
|
mat_path = f"{RECOMMEND_PATH_PREFIX}brand_matrix.npy"
|
||||||
idx_path = f"{settings.RECOMMEND_PATH_PREFIX}brand_index_map.npy"
|
idx_path = f"{RECOMMEND_PATH_PREFIX}brand_index_map.npy"
|
||||||
try:
|
try:
|
||||||
matrix = np.load(mat_path)
|
matrix = np.load(mat_path)
|
||||||
index_map = np.load(idx_path, allow_pickle=True).item()
|
index_map = np.load(idx_path, allow_pickle=True).item()
|
||||||
@@ -104,19 +113,11 @@ def load_brand_matrix():
|
|||||||
index_map = {}
|
index_map = {}
|
||||||
return matrix, index_map
|
return matrix, index_map
|
||||||
|
|
||||||
|
|
||||||
def cosine_similarity(vec1, vec2):
|
def cosine_similarity(vec1, vec2):
|
||||||
"""计算余弦相似度(增加零值处理)"""
|
"""计算余弦相似度(增加零值处理)"""
|
||||||
norm = np.linalg.norm(vec1) * np.linalg.norm(vec2)
|
norm = np.linalg.norm(vec1) * np.linalg.norm(vec2)
|
||||||
return np.dot(vec1, vec2) / (norm + 1e-10) if norm != 0 else 0.0
|
return np.dot(vec1, vec2) / (norm + 1e-10) if norm != 0 else 0.0
|
||||||
|
|
||||||
|
|
||||||
def getNewCategory(gender, sketch_category):
|
|
||||||
print(gender)
|
|
||||||
print(sketch_category)
|
|
||||||
return "None"
|
|
||||||
|
|
||||||
|
|
||||||
def calculate_brand_matrix(sketch_data, brand_id: int) -> np.ndarray:
|
def calculate_brand_matrix(sketch_data, brand_id: int) -> np.ndarray:
|
||||||
# 1. 收集品牌-分类-特征
|
# 1. 收集品牌-分类-特征
|
||||||
brand_feature = defaultdict(lambda: defaultdict(list))
|
brand_feature = defaultdict(lambda: defaultdict(list))
|
||||||
@@ -163,11 +164,11 @@ def calculate_brand_matrix(sketch_data, brand_id: int) -> np.ndarray:
|
|||||||
brand_matrix[row_idx, sketch_index[iid]] = cos_sim
|
brand_matrix[row_idx, sketch_index[iid]] = cos_sim
|
||||||
|
|
||||||
# 7. 持久化
|
# 7. 持久化
|
||||||
np.save(f"{settings.RECOMMEND_PATH_PREFIX}brand_feature_matrix.npy", brand_matrix)
|
np.save(f"{RECOMMEND_PATH_PREFIX}brand_feature_matrix.npy", brand_matrix)
|
||||||
np.save(f"{settings.RECOMMEND_PATH_PREFIX}brand_index_map.npy", brand_index_map)
|
np.save(f"{RECOMMEND_PATH_PREFIX}brand_index_map.npy", brand_index_map)
|
||||||
|
|
||||||
# 返回该品牌对应行
|
# 返回该品牌对应行
|
||||||
return brand_matrix[row_idx:row_idx + 1]
|
return brand_matrix[row_idx:row_idx+1]
|
||||||
|
|
||||||
|
|
||||||
@router.get("/brand_dna_initialize/{brand_id}")
|
@router.get("/brand_dna_initialize/{brand_id}")
|
||||||
@@ -179,9 +180,11 @@ async def brand_dna_initialize(brand_id: int):
|
|||||||
cursor.execute("""
|
cursor.execute("""
|
||||||
SELECT id, img_url, gender, category
|
SELECT id, img_url, gender, category
|
||||||
FROM product_image_attribute
|
FROM product_image_attribute
|
||||||
WHERE library_id IN (SELECT library_id
|
WHERE library_id IN (
|
||||||
|
SELECT library_id
|
||||||
FROM brand_rel_library
|
FROM brand_rel_library
|
||||||
WHERE brand_id = %s)
|
WHERE brand_id = %s
|
||||||
|
)
|
||||||
""", (brand_id,))
|
""", (brand_id,))
|
||||||
sketch_data = cursor.fetchall()
|
sketch_data = cursor.fetchall()
|
||||||
|
|
||||||
|
|||||||
@@ -1,9 +1,10 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
import requests
|
||||||
from fastapi import APIRouter, HTTPException, BackgroundTasks
|
from fastapi import APIRouter, HTTPException, BackgroundTasks
|
||||||
|
|
||||||
from app.schemas.design import DesignModel, ModelProgressModel, DesignStreamModel
|
from app.schemas.design import DesignModel, ModelProgressModel, DesignStreamModel, SAMRequestModel
|
||||||
from app.schemas.response_template import ResponseModel
|
from app.schemas.response_template import ResponseModel
|
||||||
from app.service.design_fast.design_generate import design_generate, design_generate_v2
|
from app.service.design_fast.design_generate import design_generate, design_generate_v2
|
||||||
from app.service.design_fast.model_process_service import model_transpose
|
from app.service.design_fast.model_process_service import model_transpose
|
||||||
@@ -15,16 +16,29 @@ logger = logging.getLogger()
|
|||||||
@router.post("/design")
|
@router.post("/design")
|
||||||
def design(request_data: DesignModel):
|
def design(request_data: DesignModel):
|
||||||
"""
|
"""
|
||||||
objects.items.transparent:
|
- **objects.items.transparent**:
|
||||||
|
```json
|
||||||
"transparent":{
|
"transparent":{
|
||||||
"mask_url":"test/transparent_test/transparent_mask.png",
|
"mask_url":"test/transparent_test/transparent_mask.png",
|
||||||
"scale":0.1
|
"scale":0.1
|
||||||
},
|
},
|
||||||
mask_url 为空"" -> 单件衣服透明
|
```
|
||||||
mask_url 非空"mask_url" -> 区域透明
|
- **mask_url** 为空"" -> 单件衣服透明
|
||||||
|
- **mask_url** 非空"mask_url" -> 区域透明
|
||||||
|
- **transpose** 镜像模式 ,:"top_bottom"或"left_right"
|
||||||
|
- **rotate** 45,
|
||||||
|
|
||||||
创建一个具有以下参数的请求体:
|
- ** design 参数变更:
|
||||||
|
design detail 请求参数中 basic -> preview_submit 替换为design_type 可选参数 default ,merge (移除preview和submit)
|
||||||
|
design_type 参数说明:
|
||||||
|
defuault模式下 请求参数不变
|
||||||
|
merge模式下 items -> 每个item需要新增 merge_image_path , merge_image_path为前端处理 print color等操作后的单件结果图
|
||||||
|
|
||||||
|
**
|
||||||
|
|
||||||
|
- 创建一个具有以下参数的请求体:
|
||||||
示例参数:
|
示例参数:
|
||||||
|
```json
|
||||||
{
|
{
|
||||||
"objects": [
|
"objects": [
|
||||||
{
|
{
|
||||||
@@ -56,7 +70,7 @@ def design(request_data: DesignModel):
|
|||||||
]
|
]
|
||||||
},
|
},
|
||||||
"layer_order": true,
|
"layer_order": true,
|
||||||
"preview_submit": "submit",
|
"design_type": "preview",
|
||||||
"scale_bag": 0.7,
|
"scale_bag": 0.7,
|
||||||
"scale_earrings": 0.16,
|
"scale_earrings": 0.16,
|
||||||
"self_template": true,
|
"self_template": true,
|
||||||
@@ -65,14 +79,19 @@ def design(request_data: DesignModel):
|
|||||||
},
|
},
|
||||||
"items": [
|
"items": [
|
||||||
{
|
{
|
||||||
"businessId": 2377945,
|
"businessId": 2115382,
|
||||||
"color": "209 196 171",
|
"color": "",
|
||||||
"image_id": 189410,
|
"image_id": 61686,
|
||||||
"offset": [
|
"offset": [
|
||||||
0,
|
0,
|
||||||
0
|
0
|
||||||
],
|
],
|
||||||
"path": "aida-collection-element/89/Sketchboard/53d38bd5-f77b-4034-ada2-45f1e2ebe00c.png",
|
"path": "aida-sys-image/images/female/dress/0628000564.jpg",
|
||||||
|
"transpose": [
|
||||||
|
1,
|
||||||
|
1
|
||||||
|
],
|
||||||
|
"rotate": 45,
|
||||||
"print": {
|
"print": {
|
||||||
"element": {
|
"element": {
|
||||||
"element_angle_list": [],
|
"element_angle_list": [],
|
||||||
@@ -81,85 +100,30 @@ def design(request_data: DesignModel):
|
|||||||
"location": []
|
"location": []
|
||||||
},
|
},
|
||||||
"overall": {
|
"overall": {
|
||||||
"location": [],
|
"location": [
|
||||||
"print_angle_list": [],
|
[
|
||||||
"print_path_list": [],
|
53.0,
|
||||||
"print_scale_list": []
|
118.5
|
||||||
},
|
]
|
||||||
"single": {
|
|
||||||
"location": [],
|
|
||||||
"print_angle_list": [],
|
|
||||||
"print_path_list": [],
|
|
||||||
"print_scale_list": []
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"priority": 12,
|
|
||||||
"resize_scale": [
|
|
||||||
1.0,
|
|
||||||
1.0
|
|
||||||
],
|
],
|
||||||
"seg_mask_url": "aida-clothing/mask/mask_8e96ddb0-e466-11f0-8de2-0242ac130002.png",
|
"print_angle_list": [
|
||||||
"type": "Outwear"
|
0.0
|
||||||
},
|
|
||||||
{
|
|
||||||
"businessId": 2377946,
|
|
||||||
"color": "122 152 139",
|
|
||||||
"image_id": 81868,
|
|
||||||
"offset": [
|
|
||||||
0,
|
|
||||||
0
|
|
||||||
],
|
],
|
||||||
"path": "aida-sys-image/images/female/blouse/0825001443.jpg",
|
"print_path_list": [
|
||||||
"print": {
|
"aida-users/89/print/02d57aa8-f342-4e1d-b02c-b278f94dcfe6-3-89.png"
|
||||||
"element": {
|
|
||||||
"element_angle_list": [],
|
|
||||||
"element_path_list": [],
|
|
||||||
"element_scale_list": [],
|
|
||||||
"location": []
|
|
||||||
},
|
|
||||||
"overall": {
|
|
||||||
"location": [],
|
|
||||||
"print_angle_list": [],
|
|
||||||
"print_path_list": [],
|
|
||||||
"print_scale_list": []
|
|
||||||
},
|
|
||||||
"single": {
|
|
||||||
"location": [],
|
|
||||||
"print_angle_list": [],
|
|
||||||
"print_path_list": [],
|
|
||||||
"print_scale_list": []
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"priority": 11,
|
|
||||||
"resize_scale": [
|
|
||||||
1.0,
|
|
||||||
1.0
|
|
||||||
],
|
],
|
||||||
"seg_mask_url": "aida-clothing/mask/mask_8f0fab78-e466-11f0-8de2-0242ac130002.png",
|
"print_scale_list": [
|
||||||
"type": "Blouse"
|
[
|
||||||
},
|
0.5,
|
||||||
{
|
0.5
|
||||||
"businessId": 2377947,
|
]
|
||||||
"color": "111 78 63",
|
|
||||||
"gradient": "aida-gradient/517c3a4d-aed7-4423-aa99-7b60d3577df1.png",
|
|
||||||
"image_id": 116494,
|
|
||||||
"offset": [
|
|
||||||
0,
|
|
||||||
0
|
|
||||||
],
|
],
|
||||||
"path": "aida-sys-image/images/female/skirt/0825000219.jpg",
|
"gap": [
|
||||||
"print": {
|
[
|
||||||
"element": {
|
10,
|
||||||
"element_angle_list": [],
|
10
|
||||||
"element_path_list": [],
|
]
|
||||||
"element_scale_list": [],
|
]
|
||||||
"location": []
|
|
||||||
},
|
|
||||||
"overall": {
|
|
||||||
"location": [],
|
|
||||||
"print_angle_list": [],
|
|
||||||
"print_path_list": [],
|
|
||||||
"print_scale_list": []
|
|
||||||
},
|
},
|
||||||
"single": {
|
"single": {
|
||||||
"location": [],
|
"location": [],
|
||||||
@@ -173,8 +137,8 @@ def design(request_data: DesignModel):
|
|||||||
1.0,
|
1.0,
|
||||||
1.0
|
1.0
|
||||||
],
|
],
|
||||||
"seg_mask_url": "aida-clothing/mask/mask_8f6191fe-e466-11f0-8de2-0242ac130002.png",
|
"seg_mask_url": "aida-clothing/mask/mask_9698b428-eb93-11f0-9327-0242c0a80003.png",
|
||||||
"type": "Skirt"
|
"type": "Dress"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"body_path": "aida-sys-image/models/female/2e4815b9-1191-419d-94ed-5771239ca4a5.png",
|
"body_path": "aida-sys-image/models/female/2e4815b9-1191-419d-94ed-5771239ca4a5.png",
|
||||||
@@ -186,6 +150,7 @@ def design(request_data: DesignModel):
|
|||||||
],
|
],
|
||||||
"process_id": "89"
|
"process_id": "89"
|
||||||
}
|
}
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
# logger.info(f"design request item is : @@@@@@:{json.dumps(request_data.dict(),indent=4)}")
|
# logger.info(f"design request item is : @@@@@@:{json.dumps(request_data.dict(),indent=4)}")
|
||||||
# data = generate(request_data=request_data)
|
# data = generate(request_data=request_data)
|
||||||
@@ -421,6 +386,52 @@ async def design_v2(request_data: DesignStreamModel, background_tasks: Backgroun
|
|||||||
return ResponseModel()
|
return ResponseModel()
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/seg_anything")
|
||||||
|
async def seg_anything(request_data: SAMRequestModel):
|
||||||
|
"""
|
||||||
|
**Segment Anything 交互式分割接口**
|
||||||
|
|
||||||
|
通过传入图片路径和点击的点坐标,返回分割后的掩码数据。
|
||||||
|
|
||||||
|
### 参数说明:
|
||||||
|
- **user_id**:用户id 用于存储分割图
|
||||||
|
- **image_path**: 图片在服务器或云端的相对路径。
|
||||||
|
- **type**: 推理类型
|
||||||
|
- **box**: 框选矩形点位信息
|
||||||
|
- **points**: 交互点的坐标列表。每个点为 [x, y] 像素格式。
|
||||||
|
- **labels**: 坐标点的属性标签,必须与 points 长度一致:
|
||||||
|
- 1: **前景点** (代表想要分割出的区域)
|
||||||
|
- 0: **背景点** (代表想要排除的区域)
|
||||||
|
|
||||||
|
### 请求体示例:
|
||||||
|
```json
|
||||||
|
point
|
||||||
|
{
|
||||||
|
"user_id": 1,
|
||||||
|
"image_path": "aida-users/89/sketch/4e8fe37d-7068-400a-ac94-c01647fa5f6f.png",
|
||||||
|
"type":"point",
|
||||||
|
"points": [[310, 403], [493, 375], [261, 266], [404, 484]],
|
||||||
|
"labels": [1, 1, 0, 1]
|
||||||
|
}
|
||||||
|
|
||||||
|
box
|
||||||
|
{
|
||||||
|
"user_id": 1,
|
||||||
|
"image_path": "aida-users/89/sketch/4e8fe37d-7068-400a-ac94-c01647fa5f6f.png",
|
||||||
|
"type":"box",
|
||||||
|
"box": [350, 286, 544, 520]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
logger.info(f"seg_anything request item is : @@@@@@:{json.dumps(request_data.dict(), indent=4)}")
|
||||||
|
data = requests.post("http://10.1.1.240:10075/predict", json=request_data.dict())
|
||||||
|
logger.info(f"seg_anything response @@@@@@:{json.dumps(json.loads(data.content), indent=4)}")
|
||||||
|
return ResponseModel(data=json.loads(data.content))
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"seg_anything Run Exception @@@@@@:{e}")
|
||||||
|
|
||||||
|
|
||||||
# @router.post('/get_progress')
|
# @router.post('/get_progress')
|
||||||
# def get_progress(request_data: DesignProgressModel):
|
# def get_progress(request_data: DesignProgressModel):
|
||||||
# """
|
# """
|
||||||
|
|||||||
116
app/api/api_import_sys_sketch.py
Normal file
116
app/api/api_import_sys_sketch.py
Normal file
@@ -0,0 +1,116 @@
|
|||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
from typing import Optional
|
||||||
|
from fastapi import APIRouter, HTTPException, Query
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
import threading
|
||||||
|
|
||||||
|
from app.schemas.response_template import ResponseModel
|
||||||
|
from app.service.recommendation_system.import_sys_sketch_to_milvus import main as import_main
|
||||||
|
|
||||||
|
logger = logging.getLogger()
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
# 使用线程池执行器来运行长时间任务
|
||||||
|
executor = ThreadPoolExecutor(max_workers=1)
|
||||||
|
# 用于跟踪任务状态
|
||||||
|
task_status = {"running": False}
|
||||||
|
|
||||||
|
|
||||||
|
def run_import_task(batch_size: int, retry_times: int, limit: Optional[int], offset: int, skip_create_collection: bool):
|
||||||
|
"""在后台线程中运行导入任务"""
|
||||||
|
original_argv = None
|
||||||
|
try:
|
||||||
|
task_status["running"] = True
|
||||||
|
# 保存原始 sys.argv
|
||||||
|
original_argv = sys.argv.copy()
|
||||||
|
|
||||||
|
# 模拟命令行参数
|
||||||
|
sys.argv = [
|
||||||
|
"import_sys_sketch_to_milvus.py",
|
||||||
|
"--batch-size", str(batch_size),
|
||||||
|
"--retry-times", str(retry_times),
|
||||||
|
]
|
||||||
|
if limit is not None:
|
||||||
|
sys.argv.extend(["--limit", str(limit)])
|
||||||
|
if offset > 0:
|
||||||
|
sys.argv.extend(["--offset", str(offset)])
|
||||||
|
if skip_create_collection:
|
||||||
|
sys.argv.append("--skip-create-collection")
|
||||||
|
|
||||||
|
import_main()
|
||||||
|
task_status["running"] = False
|
||||||
|
logger.info("导入任务完成")
|
||||||
|
except Exception as e:
|
||||||
|
task_status["running"] = False
|
||||||
|
logger.error(f"导入任务失败: {e}", exc_info=True)
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
# 恢复原始 sys.argv
|
||||||
|
if original_argv is not None:
|
||||||
|
sys.argv = original_argv
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/import-sys-sketch", response_model=ResponseModel)
|
||||||
|
async def import_sys_sketch(
|
||||||
|
batch_size: int = Query(1000, description="批量处理大小(默认:1000)"),
|
||||||
|
retry_times: int = Query(3, description="失败重试次数(默认:3)"),
|
||||||
|
limit: Optional[int] = Query(None, description="限制处理数量(用于测试,默认:不限制)"),
|
||||||
|
offset: int = Query(0, description="起始偏移量(默认:0)"),
|
||||||
|
skip_create_collection: bool = Query(False, description="跳过创建集合(如果集合已存在)"),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
从 t_sys_file 导入系统图向量到 Milvus
|
||||||
|
|
||||||
|
该接口会异步执行导入任务,任务在后台运行。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 检查是否有任务正在运行
|
||||||
|
if task_status["running"]:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=409,
|
||||||
|
detail="已有导入任务正在运行,请等待完成后再试"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 在后台线程中执行任务
|
||||||
|
executor.submit(
|
||||||
|
run_import_task,
|
||||||
|
batch_size,
|
||||||
|
retry_times,
|
||||||
|
limit,
|
||||||
|
offset,
|
||||||
|
skip_create_collection
|
||||||
|
)
|
||||||
|
|
||||||
|
return ResponseModel(
|
||||||
|
code=200,
|
||||||
|
msg="导入任务已启动,正在后台执行",
|
||||||
|
data={
|
||||||
|
"status": "started",
|
||||||
|
"batch_size": batch_size,
|
||||||
|
"retry_times": retry_times,
|
||||||
|
"limit": limit,
|
||||||
|
"offset": offset,
|
||||||
|
"skip_create_collection": skip_create_collection
|
||||||
|
}
|
||||||
|
)
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"启动导入任务失败: {e}", exc_info=True)
|
||||||
|
raise HTTPException(status_code=500, detail=f"启动导入任务失败: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/import-sys-sketch/status", response_model=ResponseModel)
|
||||||
|
async def get_import_status():
|
||||||
|
"""
|
||||||
|
获取导入任务状态
|
||||||
|
"""
|
||||||
|
return ResponseModel(
|
||||||
|
code=200,
|
||||||
|
msg="OK",
|
||||||
|
data={
|
||||||
|
"running": task_status["running"]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
85
app/api/api_precompute.py
Normal file
85
app/api/api_precompute.py
Normal file
@@ -0,0 +1,85 @@
|
|||||||
|
import logging
|
||||||
|
from fastapi import APIRouter, HTTPException
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
|
||||||
|
from app.schemas.response_template import ResponseModel
|
||||||
|
from app.service.recommendation_system.precompute import run_precompute
|
||||||
|
|
||||||
|
logger = logging.getLogger()
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
# 使用线程池执行器来运行长时间任务
|
||||||
|
executor = ThreadPoolExecutor(max_workers=1)
|
||||||
|
# 用于跟踪任务状态
|
||||||
|
task_status = {"running": False}
|
||||||
|
|
||||||
|
|
||||||
|
def run_precompute_task():
|
||||||
|
"""在后台线程中运行预计算任务"""
|
||||||
|
try:
|
||||||
|
task_status["running"] = True
|
||||||
|
logger.info("开始执行预计算任务...")
|
||||||
|
run_precompute()
|
||||||
|
task_status["running"] = False
|
||||||
|
logger.info("预计算任务完成")
|
||||||
|
except Exception as e:
|
||||||
|
task_status["running"] = False
|
||||||
|
logger.error(f"预计算任务失败: {e}", exc_info=True)
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/precompute", response_model=ResponseModel)
|
||||||
|
async def precompute():
|
||||||
|
"""
|
||||||
|
运行预计算任务
|
||||||
|
|
||||||
|
该接口会异步执行预计算任务,包括:
|
||||||
|
1. 优化数据库表结构
|
||||||
|
2. 历史数据迁移
|
||||||
|
3. 初始用户偏好向量生成
|
||||||
|
|
||||||
|
任务在后台运行。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 检查是否有任务正在运行
|
||||||
|
if task_status["running"]:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=409,
|
||||||
|
detail="已有预计算任务正在运行,请等待完成后再试"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 在后台线程中执行任务
|
||||||
|
executor.submit(run_precompute_task)
|
||||||
|
|
||||||
|
return ResponseModel(
|
||||||
|
code=200,
|
||||||
|
msg="预计算任务已启动,正在后台执行",
|
||||||
|
data={
|
||||||
|
"status": "started",
|
||||||
|
"tasks": [
|
||||||
|
"优化数据库表结构",
|
||||||
|
"历史数据迁移",
|
||||||
|
"初始用户偏好向量生成"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"启动预计算任务失败: {e}", exc_info=True)
|
||||||
|
raise HTTPException(status_code=500, detail=f"启动预计算任务失败: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/precompute/status", response_model=ResponseModel)
|
||||||
|
async def get_precompute_status():
|
||||||
|
"""
|
||||||
|
获取预计算任务状态
|
||||||
|
"""
|
||||||
|
return ResponseModel(
|
||||||
|
code=200,
|
||||||
|
msg="OK",
|
||||||
|
data={
|
||||||
|
"running": task_status["running"]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
@@ -1,206 +1,206 @@
|
|||||||
import io
|
import io
|
||||||
import logging
|
import logging
|
||||||
import math
|
|
||||||
import sys
|
import sys
|
||||||
import time
|
from typing import List, Optional
|
||||||
from typing import List
|
from fastapi import HTTPException, APIRouter, Query
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
from apscheduler.schedulers.background import BackgroundScheduler
|
from apscheduler.schedulers.background import BackgroundScheduler
|
||||||
from apscheduler.triggers.cron import CronTrigger
|
|
||||||
from fastapi import HTTPException, APIRouter
|
|
||||||
|
|
||||||
from app.service.recommend.service import load_resources, matrix_data
|
from app.service.recommendation_system.recommendation_api import get_recommendations as get_new_recommendations
|
||||||
|
from app.service.recommendation_system.incremental_listener import start_background_listener
|
||||||
|
from app.service.recommendation_system.milvus_client import create_collection
|
||||||
|
|
||||||
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')
|
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')
|
||||||
logger = logging.getLogger()
|
logger = logging.getLogger()
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
# ========== 旧版推荐接口(基于 npy 矩阵,已废弃)==========
|
||||||
|
# @router.get("/recommend/{user_id}/{category}/{num_recommendations}/{brand_id}/{brand_scale}", response_model=List[str])
|
||||||
|
# async def get_recommendations(user_id: int, category: str, brand_id: int, brand_scale: float, num_recommendations: int = 10):
|
||||||
|
# """
|
||||||
|
# :param user_id: 4
|
||||||
|
# :param category: female_skirt
|
||||||
|
# :param num_recommendations: 1
|
||||||
|
# :return:
|
||||||
|
# [
|
||||||
|
# "aida-sys-image/images/female/skirt/903000017.jpg"
|
||||||
|
# ]
|
||||||
|
# """
|
||||||
|
# try:
|
||||||
|
# start_time = time.time()
|
||||||
|
# cache_key = (user_id, category)
|
||||||
|
# # === 新增:用户存在性检查 ===
|
||||||
|
# user_exists_inter = user_id in matrix_data["user_index_interaction"]
|
||||||
|
# user_exists_feat = user_id in matrix_data["user_index_feature"]
|
||||||
|
#
|
||||||
|
# # 任一矩阵不存在用户则返回随机推荐
|
||||||
|
# if not (user_exists_inter and user_exists_feat):
|
||||||
|
# logger.info(f"用户 {user_id} 数据不完整,触发随机推荐")
|
||||||
|
# return get_random_recommendations(category, num_recommendations)
|
||||||
|
#
|
||||||
|
# # 检查缓存
|
||||||
|
# if cache_key in matrix_data["cached_scores"]:
|
||||||
|
# processed_inter, processed_feat = matrix_data["cached_scores"][cache_key]
|
||||||
|
# valid_sketch_idxs_inter = matrix_data["cached_valid_idxs"][cache_key]
|
||||||
|
# else:
|
||||||
|
# # 实时计算逻辑(同原代码)
|
||||||
|
# user_idx_inter = matrix_data["user_index_interaction"].get(user_id)
|
||||||
|
# user_idx_feature = matrix_data["user_index_feature"].get(user_id)
|
||||||
|
#
|
||||||
|
# category_iids = matrix_data["category_to_iids"].get(category, [])
|
||||||
|
# valid_sketch_idxs_inter = [
|
||||||
|
# idx for iid, idx in matrix_data["sketch_index_interaction"].items()
|
||||||
|
# if iid in category_iids
|
||||||
|
# ]
|
||||||
|
#
|
||||||
|
# # 处理交互分数
|
||||||
|
# raw_inter_scores = []
|
||||||
|
# if user_idx_inter is not None and valid_sketch_idxs_inter:
|
||||||
|
# raw_inter_scores = matrix_data["interaction_matrix"][user_idx_inter, valid_sketch_idxs_inter]
|
||||||
|
# processed_inter = raw_inter_scores * 0.7
|
||||||
|
#
|
||||||
|
# # 处理特征分数
|
||||||
|
# valid_sketch_idxs_feature = [
|
||||||
|
# idx for iid, idx in matrix_data["sketch_index_feature"].items()
|
||||||
|
# if iid in category_iids
|
||||||
|
# ]
|
||||||
|
# raw_feat_scores = []
|
||||||
|
# if user_idx_feature is not None and valid_sketch_idxs_feature:
|
||||||
|
# raw_feat_scores = matrix_data["feature_matrix"][user_idx_feature, valid_sketch_idxs_feature]
|
||||||
|
# raw_feat_scores = (raw_feat_scores - np.min(raw_feat_scores)) / (
|
||||||
|
# np.max(raw_feat_scores) - np.min(raw_feat_scores) + 1e-8)
|
||||||
|
# processed_feat = raw_feat_scores
|
||||||
|
# else:
|
||||||
|
# processed_feat = np.array([])
|
||||||
|
#
|
||||||
|
# # 更新缓存
|
||||||
|
# matrix_data["cached_scores"][cache_key] = (processed_inter, processed_feat)
|
||||||
|
# matrix_data["cached_valid_idxs"][cache_key] = valid_sketch_idxs_inter
|
||||||
|
#
|
||||||
|
# # 合并分数
|
||||||
|
# if brand_id is not None:
|
||||||
|
# brand_idx_feature = matrix_data["brand_index_map"].get(brand_id)
|
||||||
|
#
|
||||||
|
# brand_feat_valid = (
|
||||||
|
# matrix_data["brand_feature_matrix"].size > 0 and # 矩阵非空
|
||||||
|
# brand_idx_feature is not None and
|
||||||
|
# valid_sketch_idxs_feature # 有可用索引
|
||||||
|
# )
|
||||||
|
#
|
||||||
|
# if brand_feat_valid:
|
||||||
|
# raw_brand_feat_scores = matrix_data["brand_feature_matrix"][
|
||||||
|
# brand_idx_feature, valid_sketch_idxs_feature
|
||||||
|
# ]
|
||||||
|
# raw_brand_feat_scores = (raw_brand_feat_scores - np.min(raw_brand_feat_scores)) / (
|
||||||
|
# np.max(raw_brand_feat_scores) - np.min(raw_brand_feat_scores) + 1e-8
|
||||||
|
# )
|
||||||
|
# processed_brand_feat = raw_brand_feat_scores
|
||||||
|
#
|
||||||
|
# # 如果 processed_feat 是空的,替换为全 0,避免 shape 不一致
|
||||||
|
# if processed_feat.size == 0:
|
||||||
|
# processed_feat = np.zeros_like(processed_brand_feat)
|
||||||
|
#
|
||||||
|
# final_scores = processed_inter + 0.3 * (
|
||||||
|
# (1 - brand_scale) * processed_feat + brand_scale * processed_brand_feat
|
||||||
|
# )
|
||||||
|
# else:
|
||||||
|
# # brand 信息不可用
|
||||||
|
# final_scores = processed_inter + 0.3 * processed_feat if processed_feat.size > 0 else processed_inter
|
||||||
|
# else:
|
||||||
|
# final_scores = processed_inter + 0.3 * processed_feat if processed_feat.size > 0 else processed_inter
|
||||||
|
#
|
||||||
|
# valid_sketch_idxs = matrix_data["cached_valid_idxs"][cache_key]
|
||||||
|
#
|
||||||
|
# # 概率采样
|
||||||
|
# scores = np.array(final_scores)
|
||||||
|
#
|
||||||
|
# # 调整后的概率转换(带温度控制的softmax)
|
||||||
|
# def calibrated_softmax(scores, temperature=1.0):
|
||||||
|
# scores = scores / temperature
|
||||||
|
# scale = scores - max(scores)
|
||||||
|
# exps = np.exp(scale)
|
||||||
|
# return exps / np.sum(exps)
|
||||||
|
#
|
||||||
|
# probs = calibrated_softmax(scores, 0.09)
|
||||||
|
#
|
||||||
|
# chosen_indices = np.random.choice(
|
||||||
|
# len(valid_sketch_idxs),
|
||||||
|
# size=min(num_recommendations, len(valid_sketch_idxs)),
|
||||||
|
# p=probs,
|
||||||
|
# replace=False
|
||||||
|
# )
|
||||||
|
# recommendations = [matrix_data["iid_to_sketch"][valid_sketch_idxs[idx]] for idx in chosen_indices]
|
||||||
|
#
|
||||||
|
# logger.info(f"推荐生成完成,耗时: {time.time() - start_time:.2f}秒")
|
||||||
|
# return recommendations
|
||||||
|
# except Exception as e:
|
||||||
|
# logger.error(f"推荐失败: {str(e)}", exc_info=True)
|
||||||
|
# raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
@router.on_event("startup")
|
@router.on_event("startup")
|
||||||
async def startup_event():
|
async def startup_event():
|
||||||
# 初始加载
|
"""启动时初始化增量监听任务"""
|
||||||
load_resources()
|
try:
|
||||||
|
# 屏蔽 apscheduler 的 INFO 日志
|
||||||
|
logging.getLogger("apscheduler").setLevel(logging.WARNING)
|
||||||
|
|
||||||
|
# 确保 Milvus 集合已创建(若已存在则直接返回)
|
||||||
|
try:
|
||||||
|
create_collection()
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("Milvus 集合创建/检查失败,不影响服务继续启动: %s", exc, exc_info=True)
|
||||||
|
|
||||||
# 配置定时任务
|
# 配置定时任务
|
||||||
scheduler = BackgroundScheduler()
|
scheduler = BackgroundScheduler()
|
||||||
scheduler.add_job(
|
start_background_listener(scheduler)
|
||||||
load_resources,
|
|
||||||
trigger=CronTrigger(hour=0, minute=30),
|
|
||||||
name="每日资源刷新"
|
|
||||||
)
|
|
||||||
scheduler.start()
|
scheduler.start()
|
||||||
logger.info("定时任务已启动")
|
logger.info("增量监听定时任务已启动")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"启动增量监听任务失败: {e}", exc_info=True)
|
||||||
|
|
||||||
|
|
||||||
def softmax(scores):
|
@router.get("/recommend/{user_id}/{category}", response_model=List[str])
|
||||||
max_score = max(scores)
|
async def recommend(
|
||||||
exp_scores = [math.exp(s - max_score) for s in scores]
|
user_id: int,
|
||||||
sum_exp = sum(exp_scores)
|
category: str,
|
||||||
return [s / sum_exp for s in exp_scores]
|
style: Optional[str] = Query(
|
||||||
|
None,
|
||||||
|
description="风格样式(可选):若传入,则在利用分支对同 style 的候选进行加分",
|
||||||
|
),
|
||||||
|
):
|
||||||
|
"""新版推荐接口(Milvus + Redis 偏好向量)。"""
|
||||||
|
try:
|
||||||
|
results = get_new_recommendations(user_id, category, style)
|
||||||
|
path = results[0] if results else ""
|
||||||
|
return [path]
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("新版推荐接口失败 [user=%s, category=%s]: %s", user_id, category, e, exc_info=True)
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
# def get_random_recommendations(category: str, num: int) -> List[str]:
|
@router.get("/redis/user_pref")
|
||||||
# """根据预加载热度向量推荐(冷启动)"""
|
async def get_all_user_preferences():
|
||||||
# try:
|
|
||||||
# heat_data = matrix_data.get("heat_data", {})
|
|
||||||
#
|
|
||||||
# if category not in heat_data:
|
|
||||||
# raise ValueError(f"热度数据缺少类别 {category},使用随机推荐")
|
|
||||||
#
|
|
||||||
# heat_dict = heat_data[category] # {url: score}
|
|
||||||
# urls = list(heat_dict.keys())
|
|
||||||
# scores = list(heat_dict.values())
|
|
||||||
#
|
|
||||||
# if not urls:
|
|
||||||
# raise ValueError("该类别下无热度记录,使用随机推荐")
|
|
||||||
#
|
|
||||||
# probs = softmax(scores)
|
|
||||||
# sample_size = min(num, len(urls))
|
|
||||||
# sampled_urls = random.choices(urls, weights=probs, k=sample_size)
|
|
||||||
#
|
|
||||||
# return sampled_urls
|
|
||||||
#
|
|
||||||
# except Exception as e:
|
|
||||||
# # 回退:完全随机推荐
|
|
||||||
# all_iids = list(matrix_data["iid_to_sketch"].keys())
|
|
||||||
# category_iids = matrix_data["category_to_iids"].get(category, all_iids)
|
|
||||||
# sample_size = min(num, len(category_iids))
|
|
||||||
# sampled = np.random.choice(category_iids, size=sample_size, replace=False)
|
|
||||||
# return [matrix_data["iid_to_sketch"][iid] for iid in sampled]
|
|
||||||
|
|
||||||
def get_random_recommendations(category: str, num: int) -> List[str]:
|
|
||||||
"""全品类随机推荐"""
|
|
||||||
all_iids = list(matrix_data["iid_to_sketch"].keys())
|
|
||||||
# 优先从当前品类选择
|
|
||||||
category_iids = matrix_data["category_to_iids"].get(category, all_iids)
|
|
||||||
# 确保不超出实际数量
|
|
||||||
sample_size = min(num, len(category_iids))
|
|
||||||
sampled = np.random.choice(category_iids, size=sample_size, replace=False)
|
|
||||||
return [matrix_data["iid_to_sketch"][iid] for iid in sampled]
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/recommend/{user_id}/{category}/{num_recommendations}/{brand_id}/{brand_scale}", response_model=List[str])
|
|
||||||
async def get_recommendations(user_id: int, category: str, brand_id: int, brand_scale: float, num_recommendations: int = 10):
|
|
||||||
"""
|
"""
|
||||||
@param user_id: 4
|
获取所有以 user_pref 为前缀的 Redis key 信息
|
||||||
@param category: female_skirt
|
|
||||||
@param num_recommendations: 1
|
|
||||||
@return:
|
|
||||||
[
|
|
||||||
"aida-sys-image/images/female/skirt/903000017.jpg"
|
|
||||||
]
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
logger.info(f"user_id:{user_id}-----category:{category}-----brand_id:{brand_id}-----brand_scale:{brand_scale}-----num_recommendations:{num_recommendations}")
|
from app.service.utils.redis_utils import Redis
|
||||||
start_time = time.time()
|
from app.service.recommendation_system.config import REDIS_KEY_USER_PREF_PREFIX
|
||||||
cache_key = (user_id, category)
|
|
||||||
# === 新增:用户存在性检查 ===
|
|
||||||
user_exists_inter = user_id in matrix_data["user_index_interaction"]
|
|
||||||
user_exists_feat = user_id in matrix_data["user_index_feature"]
|
|
||||||
|
|
||||||
# 任一矩阵不存在用户则返回随机推荐
|
# 扫描所有匹配 user_pref:* 的 key
|
||||||
if not (user_exists_inter and user_exists_feat):
|
pattern = f"{REDIS_KEY_USER_PREF_PREFIX}:*"
|
||||||
logger.info(f"用户 {user_id} 数据不完整,触发随机推荐")
|
keys = Redis.scan_keys(pattern)
|
||||||
return get_random_recommendations(category, num_recommendations)
|
|
||||||
|
|
||||||
# 检查缓存
|
# 直接返回所有 key 和原始 value
|
||||||
if cache_key in matrix_data["cached_scores"]:
|
result = {}
|
||||||
processed_inter, processed_feat = matrix_data["cached_scores"][cache_key]
|
for key in keys:
|
||||||
valid_sketch_idxs_inter = matrix_data["cached_valid_idxs"][cache_key]
|
# 读取对应的值
|
||||||
else:
|
value = Redis.read(key)
|
||||||
# 实时计算逻辑(同原代码)
|
if value:
|
||||||
user_idx_inter = matrix_data["user_index_interaction"].get(user_id)
|
result[key] = value
|
||||||
user_idx_feature = matrix_data["user_index_feature"].get(user_id)
|
|
||||||
|
|
||||||
category_iids = matrix_data["category_to_iids"].get(category, [])
|
return result
|
||||||
valid_sketch_idxs_inter = [
|
|
||||||
idx for iid, idx in matrix_data["sketch_index_interaction"].items()
|
|
||||||
if iid in category_iids
|
|
||||||
]
|
|
||||||
|
|
||||||
# 处理交互分数
|
|
||||||
raw_inter_scores = []
|
|
||||||
if user_idx_inter is not None and valid_sketch_idxs_inter:
|
|
||||||
raw_inter_scores = matrix_data["interaction_matrix"][user_idx_inter, valid_sketch_idxs_inter]
|
|
||||||
processed_inter = raw_inter_scores * 0.7
|
|
||||||
|
|
||||||
# 处理特征分数
|
|
||||||
valid_sketch_idxs_feature = [
|
|
||||||
idx for iid, idx in matrix_data["sketch_index_feature"].items()
|
|
||||||
if iid in category_iids
|
|
||||||
]
|
|
||||||
raw_feat_scores = []
|
|
||||||
if user_idx_feature is not None and valid_sketch_idxs_feature:
|
|
||||||
raw_feat_scores = matrix_data["feature_matrix"][user_idx_feature, valid_sketch_idxs_feature]
|
|
||||||
raw_feat_scores = (raw_feat_scores - np.min(raw_feat_scores)) / (
|
|
||||||
np.max(raw_feat_scores) - np.min(raw_feat_scores) + 1e-8)
|
|
||||||
processed_feat = raw_feat_scores
|
|
||||||
else:
|
|
||||||
processed_feat = np.array([])
|
|
||||||
|
|
||||||
# 更新缓存
|
|
||||||
matrix_data["cached_scores"][cache_key] = (processed_inter, processed_feat)
|
|
||||||
matrix_data["cached_valid_idxs"][cache_key] = valid_sketch_idxs_inter
|
|
||||||
|
|
||||||
# 合并分数
|
|
||||||
if brand_id is not None:
|
|
||||||
brand_idx_feature = matrix_data["brand_index_map"].get(brand_id)
|
|
||||||
|
|
||||||
brand_feat_valid = (
|
|
||||||
matrix_data["brand_feature_matrix"].size > 0 and # 矩阵非空
|
|
||||||
brand_idx_feature is not None and
|
|
||||||
valid_sketch_idxs_feature # 有可用索引
|
|
||||||
)
|
|
||||||
|
|
||||||
if brand_feat_valid:
|
|
||||||
raw_brand_feat_scores = matrix_data["brand_feature_matrix"][
|
|
||||||
brand_idx_feature, valid_sketch_idxs_feature
|
|
||||||
]
|
|
||||||
raw_brand_feat_scores = (raw_brand_feat_scores - np.min(raw_brand_feat_scores)) / (
|
|
||||||
np.max(raw_brand_feat_scores) - np.min(raw_brand_feat_scores) + 1e-8
|
|
||||||
)
|
|
||||||
processed_brand_feat = raw_brand_feat_scores
|
|
||||||
|
|
||||||
# 如果 processed_feat 是空的,替换为全 0,避免 shape 不一致
|
|
||||||
if processed_feat.size == 0:
|
|
||||||
processed_feat = np.zeros_like(processed_brand_feat)
|
|
||||||
|
|
||||||
final_scores = processed_inter + 0.3 * (
|
|
||||||
(1 - brand_scale) * processed_feat + brand_scale * processed_brand_feat
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# brand 信息不可用
|
|
||||||
final_scores = processed_inter + 0.3 * processed_feat if processed_feat.size > 0 else processed_inter
|
|
||||||
else:
|
|
||||||
final_scores = processed_inter + 0.3 * processed_feat if processed_feat.size > 0 else processed_inter
|
|
||||||
|
|
||||||
valid_sketch_idxs = matrix_data["cached_valid_idxs"][cache_key]
|
|
||||||
|
|
||||||
# 概率采样
|
|
||||||
scores = np.array(final_scores)
|
|
||||||
|
|
||||||
# 调整后的概率转换(带温度控制的softmax)
|
|
||||||
def calibrated_softmax(scores, temperature=1.0):
|
|
||||||
scores = scores / temperature
|
|
||||||
scale = scores - max(scores)
|
|
||||||
exps = np.exp(scale)
|
|
||||||
return exps / np.sum(exps)
|
|
||||||
|
|
||||||
probs = calibrated_softmax(scores, 0.09)
|
|
||||||
|
|
||||||
chosen_indices = np.random.choice(
|
|
||||||
len(valid_sketch_idxs),
|
|
||||||
size=min(num_recommendations, len(valid_sketch_idxs)),
|
|
||||||
p=probs,
|
|
||||||
replace=False
|
|
||||||
)
|
|
||||||
recommendations = [matrix_data["iid_to_sketch"][valid_sketch_idxs[idx]] for idx in chosen_indices]
|
|
||||||
|
|
||||||
logger.info(f"推荐生成完成,耗时: {time.time() - start_time:.2f}秒")
|
|
||||||
return recommendations
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"推荐失败: {str(e)}", exc_info=True)
|
logger.error("获取用户偏好数据失败: %s", e, exc_info=True)
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
@@ -7,6 +7,7 @@ from app.api import api_design_pre_processing
|
|||||||
from app.api import api_generate_image
|
from app.api import api_generate_image
|
||||||
from app.api import api_mannequins_edit
|
from app.api import api_mannequins_edit
|
||||||
from app.api import api_pose_transform
|
from app.api import api_pose_transform
|
||||||
|
from app.api import api_precompute
|
||||||
from app.api import api_prompt_generation
|
from app.api import api_prompt_generation
|
||||||
from app.api import api_recommendation
|
from app.api import api_recommendation
|
||||||
from app.api import api_test
|
from app.api import api_test
|
||||||
@@ -21,6 +22,7 @@ router.include_router(api_prompt_generation.router, tags=['prompt_generation'],
|
|||||||
router.include_router(api_design_pre_processing.router, tags=['design_pre_processing'], prefix="/api")
|
router.include_router(api_design_pre_processing.router, tags=['design_pre_processing'], prefix="/api")
|
||||||
router.include_router(api_brand_dna.router, tags=['api_brand_dna'], prefix="/api")
|
router.include_router(api_brand_dna.router, tags=['api_brand_dna'], prefix="/api")
|
||||||
router.include_router(api_recommendation.router, tags=['api_recommendation'], prefix="/api")
|
router.include_router(api_recommendation.router, tags=['api_recommendation'], prefix="/api")
|
||||||
|
router.include_router(api_precompute.router, tags=['api_precompute'], prefix="/api")
|
||||||
router.include_router(api_mannequins_edit.router, tags=['api_mannequins_edit'], prefix="/api")
|
router.include_router(api_mannequins_edit.router, tags=['api_mannequins_edit'], prefix="/api")
|
||||||
router.include_router(api_pose_transform.router, tags=['api_pose_transform'], prefix="/api")
|
router.include_router(api_pose_transform.router, tags=['api_pose_transform'], prefix="/api")
|
||||||
router.include_router(api_clothing_seg.router, tags=['api_clothing_seg'], prefix="/api")
|
router.include_router(api_clothing_seg.router, tags=['api_clothing_seg'], prefix="/api")
|
||||||
|
|||||||
@@ -36,7 +36,7 @@ class Settings(BaseSettings):
|
|||||||
|
|
||||||
# --- mysql 配置信息 ---
|
# --- mysql 配置信息 ---
|
||||||
MYSQL_HOST: str = Field(default='', description="")
|
MYSQL_HOST: str = Field(default='', description="")
|
||||||
MYSQL_PORT: str = Field(default='', description="")
|
MYSQL_PORT: int = Field(default='', description="")
|
||||||
MYSQL_USER: str = Field(default='', description="")
|
MYSQL_USER: str = Field(default='', description="")
|
||||||
MYSQL_PASSWORD: str = Field(default='', description="")
|
MYSQL_PASSWORD: str = Field(default='', description="")
|
||||||
MYSQL_DB: str = Field(default='', description="")
|
MYSQL_DB: str = Field(default='', description="")
|
||||||
|
|||||||
@@ -1,4 +1,15 @@
|
|||||||
from pydantic import BaseModel
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class SAMRequestModel(BaseModel):
|
||||||
|
user_id: int = Field(..., description="用户id, 必填字段")
|
||||||
|
image_path: str = Field(..., description="图片路径,必填字段")
|
||||||
|
type: str = Field(..., description="推理类型,必填字段")
|
||||||
|
points: Optional[List[List[float]]] = None
|
||||||
|
labels: Optional[List[int]] = None
|
||||||
|
box: Optional[List[int]] = None
|
||||||
|
|
||||||
|
|
||||||
class DesignModel(BaseModel):
|
class DesignModel(BaseModel):
|
||||||
|
|||||||
@@ -6,10 +6,10 @@ import requests
|
|||||||
from minio import Minio
|
from minio import Minio
|
||||||
|
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from app.service.design_fast.item import BodyItem, TopItem, BottomItem, OthersItem
|
from app.service.design_fast.item import BodyItem, TopItem, BottomItem, OthersItem, TopMergeItem, BottomMergeItem, OthersMergeItem
|
||||||
from app.service.design_fast.utils.organize import organize_body, organize_clothing, organize_others
|
from app.service.design_fast.utils.organize import organize_body, organize_clothing, organize_others
|
||||||
from app.service.design_fast.utils.progress import final_progress, update_progress
|
from app.service.design_fast.utils.progress import final_progress, update_progress
|
||||||
from app.service.design_fast.utils.synthesis_item import synthesis, synthesis_single, update_base_size_priority
|
from app.service.design_fast.utils.synthesis_item import synthesis, synthesis_single, update_base_size_priority, merge
|
||||||
from app.service.utils.decorator import RunTime
|
from app.service.utils.decorator import RunTime
|
||||||
|
|
||||||
id_lock = threading.Lock()
|
id_lock = threading.Lock()
|
||||||
@@ -19,22 +19,46 @@ logger = logging.getLogger()
|
|||||||
minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)
|
minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)
|
||||||
|
|
||||||
|
|
||||||
def process_item(item, basic):
|
def process_item(item, basic, design_type):
|
||||||
# 处理project中单个item
|
# 1. 定义映射配置
|
||||||
if item['type'] == "Body":
|
# key 为 item_type 的小写,value 为对应的处理类
|
||||||
body_server = BodyItem(data=item, basic=basic, minio_client=minio_client)
|
DESIGN_MAP = {
|
||||||
item_data = body_server.process()
|
'body': BodyItem,
|
||||||
elif item['type'].lower() in ['blouse', 'outwear', 'dress', 'tops']:
|
'blouse': TopItem, 'outwear': TopItem,
|
||||||
top_server = TopItem(data=item, basic=basic, minio_client=minio_client)
|
'dress': TopItem, 'tops': TopItem,
|
||||||
item_data = top_server.process()
|
'skirt': BottomItem, 'trousers': BottomItem,
|
||||||
elif item['type'].lower() in ['skirt', 'trousers', 'bottoms']:
|
'bottoms': BottomItem,
|
||||||
bottom_server = BottomItem(data=item, basic=basic, minio_client=minio_client)
|
'others': OthersItem
|
||||||
item_data = bottom_server.process()
|
}
|
||||||
elif item['type'].lower() in ['others']:
|
|
||||||
bottom_server = OthersItem(data=item, basic=basic, minio_client=minio_client)
|
MERGE_MAP = {
|
||||||
item_data = bottom_server.process()
|
'body_merge': BodyItem,
|
||||||
|
'blouse_merge': TopMergeItem, 'outwear_merge': TopMergeItem,
|
||||||
|
'dress_merge': TopMergeItem, 'tops_merge': TopMergeItem,
|
||||||
|
'skirt_merge': BottomMergeItem, 'trousers_merge': BottomMergeItem,
|
||||||
|
'bottoms_merge': BottomMergeItem,
|
||||||
|
'others_merge': OthersMergeItem
|
||||||
|
}
|
||||||
|
|
||||||
|
# 2. 根据 design_type 选择映射表
|
||||||
|
mapping = MERGE_MAP if design_type == 'merge' else DESIGN_MAP
|
||||||
|
|
||||||
|
if design_type == 'merge':
|
||||||
|
item_type_key = f"{item['type'].lower()}_merge"
|
||||||
|
elif design_type == 'default':
|
||||||
|
item_type_key = item['type'].lower()
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(f"Item type {item['type']} not implemented")
|
item_type_key = item['type'].lower()
|
||||||
|
|
||||||
|
handler_class = mapping.get(item_type_key)
|
||||||
|
|
||||||
|
if not handler_class:
|
||||||
|
raise NotImplementedError(f"Item type {item['type']} not implemented for design_type={design_type}")
|
||||||
|
|
||||||
|
# 4. 统一实例化并执行
|
||||||
|
# 注意:这里假设所有 Item 类构造函数签名一致
|
||||||
|
server = handler_class(data=item, basic=basic, minio_client=minio_client)
|
||||||
|
item_data = server.process()
|
||||||
return item_data
|
return item_data
|
||||||
|
|
||||||
|
|
||||||
@@ -44,7 +68,7 @@ def process_layer(item, layers):
|
|||||||
body_layer = organize_body(item)
|
body_layer = organize_body(item)
|
||||||
layers.append(body_layer)
|
layers.append(body_layer)
|
||||||
return item['body_image'].size
|
return item['body_image'].size
|
||||||
elif item['name'] == 'others':
|
elif item['name'] in ['others', 'others_merge']:
|
||||||
front_layer, back_layer = organize_others(item)
|
front_layer, back_layer = organize_others(item)
|
||||||
layers.append(front_layer)
|
layers.append(front_layer)
|
||||||
layers.append(back_layer)
|
layers.append(back_layer)
|
||||||
@@ -70,10 +94,11 @@ def design_generate(request_data):
|
|||||||
nonlocal active_threads
|
nonlocal active_threads
|
||||||
basic = object['basic']
|
basic = object['basic']
|
||||||
items_response = {'layers': [], 'objectSign': object['objectSign'] if 'objectSign' in object.keys() else ""}
|
items_response = {'layers': [], 'objectSign': object['objectSign'] if 'objectSign' in object.keys() else ""}
|
||||||
|
design_type = basic.get('design_type', "default")
|
||||||
if basic['single_overall'] == "overall":
|
if basic['single_overall'] == "overall":
|
||||||
item_results = []
|
item_results = []
|
||||||
for item in object['items']:
|
for item in object['items']:
|
||||||
item_results.append(process_item(item, basic))
|
item_results.append(process_item(item, basic, design_type))
|
||||||
layers = []
|
layers = []
|
||||||
for item in item_results:
|
for item in item_results:
|
||||||
process_layer(item, layers)
|
process_layer(item, layers)
|
||||||
@@ -93,10 +118,17 @@ def design_generate(request_data):
|
|||||||
'image_url': lay['image_url'] if 'image_url' in lay.keys() else None,
|
'image_url': lay['image_url'] if 'image_url' in lay.keys() else None,
|
||||||
'pattern_overall_image_url': lay['pattern_overall_image_url'] if 'pattern_overall_image_url' in lay.keys() else None,
|
'pattern_overall_image_url': lay['pattern_overall_image_url'] if 'pattern_overall_image_url' in lay.keys() else None,
|
||||||
'pattern_print_image_url': lay['pattern_print_image_url'] if 'pattern_print_image_url' in lay.keys() else None,
|
'pattern_print_image_url': lay['pattern_print_image_url'] if 'pattern_print_image_url' in lay.keys() else None,
|
||||||
|
'transpose': lay.get('transpose', None),
|
||||||
|
'rotate': lay.get('rotate', None),
|
||||||
# 'back_perspective_url': lay['back_perspective_url'] if 'back_perspective_url' in lay.keys() else None,
|
# 'back_perspective_url': lay['back_perspective_url'] if 'back_perspective_url' in lay.keys() else None,
|
||||||
})
|
})
|
||||||
|
if basic.get('design_type') == 'default':
|
||||||
items_response['synthesis_url'] = synthesis(layers, new_size, basic)
|
items_response['synthesis_url'] = synthesis(layers, new_size, basic)
|
||||||
|
elif basic.get('design_type') == 'merge':
|
||||||
|
items_response['synthesis_url'] = merge(layers, new_size, basic)
|
||||||
|
else:
|
||||||
|
items_response['synthesis_url'] = synthesis(layers, new_size, basic)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
item_result = process_item(object['items'][0], basic)
|
item_result = process_item(object['items'][0], basic)
|
||||||
items_response['layers'].append({
|
items_response['layers'].append({
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ class BaseItem:
|
|||||||
self.result['name'] = data['type'].lower()
|
self.result['name'] = data['type'].lower()
|
||||||
self.result.pop("type")
|
self.result.pop("type")
|
||||||
self.result.update(basic)
|
self.result.update(basic)
|
||||||
|
self.result['design_type'] = basic.get('design_type', None)
|
||||||
|
|
||||||
|
|
||||||
class OthersItem(BaseItem):
|
class OthersItem(BaseItem):
|
||||||
@@ -14,13 +15,7 @@ class OthersItem(BaseItem):
|
|||||||
super().__init__(data, basic)
|
super().__init__(data, basic)
|
||||||
self.Others_pipeline = [
|
self.Others_pipeline = [
|
||||||
LoadImage(minio_client),
|
LoadImage(minio_client),
|
||||||
# KeyPoint(),
|
|
||||||
# ContourDetection(),
|
|
||||||
Segmentation(minio_client),
|
Segmentation(minio_client),
|
||||||
# BackPerspective(minio_client),
|
|
||||||
Color(minio_client),
|
|
||||||
NoSegPrintPainting(minio_client),
|
|
||||||
PrintPainting(minio_client),
|
|
||||||
Scaling(),
|
Scaling(),
|
||||||
Split(minio_client)
|
Split(minio_client)
|
||||||
]
|
]
|
||||||
@@ -74,6 +69,65 @@ class BottomItem(BaseItem):
|
|||||||
return self.result
|
return self.result
|
||||||
|
|
||||||
|
|
||||||
|
"""merge"""
|
||||||
|
|
||||||
|
|
||||||
|
class OthersMergeItem(BaseItem):
|
||||||
|
def __init__(self, data, basic, minio_client):
|
||||||
|
super().__init__(data, basic)
|
||||||
|
self.Others_pipeline = [
|
||||||
|
LoadImage(minio_client),
|
||||||
|
# KeyPoint(),
|
||||||
|
# ContourDetection(),
|
||||||
|
Segmentation(minio_client),
|
||||||
|
# BackPerspective(minio_client),
|
||||||
|
Color(minio_client),
|
||||||
|
NoSegPrintPainting(minio_client),
|
||||||
|
PrintPainting(minio_client),
|
||||||
|
Scaling(),
|
||||||
|
Split(minio_client)
|
||||||
|
]
|
||||||
|
|
||||||
|
def process(self):
|
||||||
|
for item in self.Others_pipeline:
|
||||||
|
self.result = item(self.result)
|
||||||
|
return self.result
|
||||||
|
|
||||||
|
|
||||||
|
class TopMergeItem(BaseItem):
|
||||||
|
def __init__(self, data, basic, minio_client):
|
||||||
|
super().__init__(data, basic)
|
||||||
|
self.top_pipeline = [
|
||||||
|
LoadImage(minio_client),
|
||||||
|
KeyPoint(),
|
||||||
|
Segmentation(minio_client),
|
||||||
|
Scaling(),
|
||||||
|
Split(minio_client)
|
||||||
|
]
|
||||||
|
|
||||||
|
def process(self):
|
||||||
|
for item in self.top_pipeline:
|
||||||
|
self.result = item(self.result)
|
||||||
|
return self.result
|
||||||
|
|
||||||
|
|
||||||
|
class BottomMergeItem(BaseItem):
|
||||||
|
def __init__(self, data, basic, minio_client):
|
||||||
|
super().__init__(data, basic)
|
||||||
|
self.bottom_pipeline = [
|
||||||
|
LoadImage(minio_client),
|
||||||
|
KeyPoint(),
|
||||||
|
Segmentation(minio_client),
|
||||||
|
Scaling(),
|
||||||
|
Split(minio_client)
|
||||||
|
]
|
||||||
|
|
||||||
|
def process(self):
|
||||||
|
for item in self.bottom_pipeline:
|
||||||
|
self.result = item(self.result)
|
||||||
|
return self.result
|
||||||
|
|
||||||
|
|
||||||
class BodyItem(BaseItem):
|
class BodyItem(BaseItem):
|
||||||
def __init__(self, data, basic, minio_client):
|
def __init__(self, data, basic, minio_client):
|
||||||
super().__init__(data, basic)
|
super().__init__(data, basic)
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from pymilvus import MilvusClient
|
# from pymilvus import MilvusClient
|
||||||
|
|
||||||
from app.core.config import KEYPOINT_RESULT_TABLE_FIELD_SET, MILVUS_TABLE_KEYPOINT, settings
|
from app.core.config import KEYPOINT_RESULT_TABLE_FIELD_SET, MILVUS_TABLE_KEYPOINT, settings
|
||||||
from app.service.design_fast.utils.design_ensemble import get_keypoint_result
|
from app.service.design_fast.utils.design_ensemble import get_keypoint_result
|
||||||
@@ -54,63 +54,64 @@ class KeyPoint:
|
|||||||
"keypoint_vector": result.tolist()
|
"keypoint_vector": result.tolist()
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
try:
|
|
||||||
client = MilvusClient(uri=settings.MILVUS_URL, token=settings.MILVUS_TOKEN, db_name=settings.MILVUS_ALIAS)
|
|
||||||
client.upsert(collection_name=MILVUS_TABLE_KEYPOINT, data=data)
|
|
||||||
client.close()
|
|
||||||
return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
|
|
||||||
except Exception as e:
|
|
||||||
logger.info(f"save keypoint cache milvus error : {e}")
|
|
||||||
return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
|
return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
|
||||||
|
|
||||||
@staticmethod
|
# try:
|
||||||
def update_keypoint_cache(keypoint_id, infer_result, search_result, site):
|
# client = MilvusClient(uri=settings.MILVUS_URL, token=settings.MILVUS_TOKEN, db_name=settings.MILVUS_ALIAS)
|
||||||
if site == "up":
|
# client.upsert(collection_name=MILVUS_TABLE_KEYPOINT, data=data)
|
||||||
# 需要的是up 即推理出来的是up 那么查询的就是down
|
# client.close()
|
||||||
result = np.concatenate([infer_result.flatten(), search_result[-4:]])
|
# except Exception as e:
|
||||||
else:
|
# logger.info(f"save keypoint cache milvus error : {e}")
|
||||||
# 需要的是down 即推理出来的是down 那么查询的就是up
|
# return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
|
||||||
result = np.concatenate([search_result[:20], infer_result.flatten()])
|
|
||||||
data = [
|
|
||||||
{"keypoint_id": keypoint_id,
|
|
||||||
"keypoint_site": "all",
|
|
||||||
"keypoint_vector": result.tolist()
|
|
||||||
}
|
|
||||||
]
|
|
||||||
|
|
||||||
try:
|
# @staticmethod
|
||||||
client = MilvusClient(uri=settings.MILVUS_URL, token=settings.MILVUS_TOKEN, db_name=settings.MILVUS_ALIAS)
|
# def update_keypoint_cache(keypoint_id, infer_result, search_result, site):
|
||||||
client.upsert(
|
# if site == "up":
|
||||||
collection_name=MILVUS_TABLE_KEYPOINT,
|
# # 需要的是up 即推理出来的是up 那么查询的就是down
|
||||||
data=data
|
# result = np.concatenate([infer_result.flatten(), search_result[-4:]])
|
||||||
)
|
# else:
|
||||||
return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
|
# # 需要的是down 即推理出来的是down 那么查询的就是up
|
||||||
except Exception as e:
|
# result = np.concatenate([search_result[:20], infer_result.flatten()])
|
||||||
logger.info(f"save keypoint cache milvus error : {e}")
|
# data = [
|
||||||
return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
|
# {"keypoint_id": keypoint_id,
|
||||||
|
# "keypoint_site": "all",
|
||||||
|
# "keypoint_vector": result.tolist()
|
||||||
|
# }
|
||||||
|
# ]
|
||||||
|
#
|
||||||
|
# try:
|
||||||
|
# client = MilvusClient(uri=settings.MILVUS_URL, token=settings.MILVUS_TOKEN, db_name=settings.MILVUS_ALIAS)
|
||||||
|
# client.upsert(
|
||||||
|
# collection_name=MILVUS_TABLE_KEYPOINT,
|
||||||
|
# data=data
|
||||||
|
# )
|
||||||
|
# return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
|
||||||
|
# except Exception as e:
|
||||||
|
# logger.info(f"save keypoint cache milvus error : {e}")
|
||||||
|
# return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
|
||||||
|
|
||||||
@RunTime
|
# @RunTime
|
||||||
def keypoint_cache(self, result, site):
|
# def keypoint_cache(self, result, site):
|
||||||
try:
|
# try:
|
||||||
client = MilvusClient(uri=settings.MILVUS_URL, token=settings.MILVUS_TOKEN, db_name=settings.MILVUS_ALIAS)
|
# client = MilvusClient(uri=settings.MILVUS_URL, token=settings.MILVUS_TOKEN, db_name=settings.MILVUS_ALIAS)
|
||||||
keypoint_id = result['image_id']
|
# keypoint_id = result['image_id']
|
||||||
res = client.query(
|
# res = client.query(
|
||||||
collection_name=MILVUS_TABLE_KEYPOINT,
|
# collection_name=MILVUS_TABLE_KEYPOINT,
|
||||||
# ids=[keypoint_id],
|
# # ids=[keypoint_id],
|
||||||
filter=f"keypoint_id == {keypoint_id}",
|
# filter=f"keypoint_id == {keypoint_id}",
|
||||||
output_fields=['keypoint_vector', 'keypoint_site']
|
# output_fields=['keypoint_vector', 'keypoint_site']
|
||||||
)
|
# )
|
||||||
if len(res) == 0:
|
# if len(res) == 0:
|
||||||
# 没有结果 直接推理拿结果 并保存
|
# # 没有结果 直接推理拿结果 并保存
|
||||||
keypoint_infer_result, site = self.infer_keypoint_result(result)
|
# keypoint_infer_result, site = self.infer_keypoint_result(result)
|
||||||
return self.save_keypoint_cache(result['image_id'], keypoint_infer_result, site)
|
# return self.save_keypoint_cache(result['image_id'], keypoint_infer_result, site)
|
||||||
elif res[0]["keypoint_site"] == "all" or res[0]["keypoint_site"] == site:
|
# elif res[0]["keypoint_site"] == "all" or res[0]["keypoint_site"] == site:
|
||||||
# 需要的类型和查询的类型一致,或者查询的类型为all 则直接返回查询的结果
|
# # 需要的类型和查询的类型一致,或者查询的类型为all 则直接返回查询的结果
|
||||||
return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, np.array(res[0]['keypoint_vector']).astype(int).reshape(12, 2).tolist()))
|
# return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, np.array(res[0]['keypoint_vector']).astype(int).reshape(12, 2).tolist()))
|
||||||
elif res[0]["keypoint_site"] != site:
|
# elif res[0]["keypoint_site"] != site:
|
||||||
# 需要的类型和查询到的不一致,则更新类型为all
|
# # 需要的类型和查询到的不一致,则更新类型为all
|
||||||
keypoint_infer_result, site = self.infer_keypoint_result(result)
|
# keypoint_infer_result, site = self.infer_keypoint_result(result)
|
||||||
return self.update_keypoint_cache(result["image_id"], keypoint_infer_result, res[0]['keypoint_vector'], site)
|
# return self.update_keypoint_cache(result["image_id"], keypoint_infer_result, res[0]['keypoint_vector'], site)
|
||||||
except Exception as e:
|
# except Exception as e:
|
||||||
logger.info(f"search keypoint cache milvus error {e}")
|
# logger.info(f"search keypoint cache milvus error {e}")
|
||||||
return False
|
# return False
|
||||||
|
|||||||
@@ -35,15 +35,9 @@ class LoadImage:
|
|||||||
return cls.name
|
return cls.name
|
||||||
|
|
||||||
def __call__(self, result):
|
def __call__(self, result):
|
||||||
|
if result.get("merge_image_path"):
|
||||||
|
result['merge_image'], _ = self.read_image(result['merge_image_path'])
|
||||||
result['image'], result['pre_mask'] = self.read_image(result['path'])
|
result['image'], result['pre_mask'] = self.read_image(result['path'])
|
||||||
# if 'extract_lines' in result.keys():
|
|
||||||
# if result['extract_lines']:
|
|
||||||
# result['gray'] = self.get_lines(cv2.cvtColor(result['image'], cv2.COLOR_BGR2GRAY), result['path'])
|
|
||||||
# else:
|
|
||||||
# result['gray'] = cv2.cvtColor(result['image'], cv2.COLOR_BGR2GRAY)
|
|
||||||
# else:
|
|
||||||
# result['gray'] = cv2.cvtColor(result['image'], cv2.COLOR_BGR2GRAY)
|
|
||||||
|
|
||||||
result['gray'] = self.get_lines(cv2.cvtColor(result['image'], cv2.COLOR_BGR2GRAY))
|
result['gray'] = self.get_lines(cv2.cvtColor(result['image'], cv2.COLOR_BGR2GRAY))
|
||||||
result['keypoint'] = self.get_keypoint(result['name'])
|
result['keypoint'] = self.get_keypoint(result['name'])
|
||||||
result['img_shape'] = result['image'].shape
|
result['img_shape'] = result['image'].shape
|
||||||
@@ -61,21 +55,6 @@ class LoadImage:
|
|||||||
mask = skeleton
|
mask = skeleton
|
||||||
result = np.ones_like(img) * 255
|
result = np.ones_like(img) * 255
|
||||||
result[mask] = img[mask]
|
result[mask] = img[mask]
|
||||||
|
|
||||||
# 步骤2:细化边缘(可选,让线条更干净)
|
|
||||||
# kernel = np.ones((1, 1), np.uint8)
|
|
||||||
# clean = cv2.morphologyEx(binary, cv2.MORPH_OPEN, kernel)
|
|
||||||
|
|
||||||
# thinned = cv2.ximgproc.thinning(binary, thinningType=cv2.ximgproc.THINNING_ZHANGSUEN) # thinning算法细化线条
|
|
||||||
# mask = thinned > 0
|
|
||||||
# result = np.ones_like(img) * 255
|
|
||||||
# result[mask] = img[mask]
|
|
||||||
|
|
||||||
# 步骤3:反转回 白底黑线
|
|
||||||
# lines = cv2.bitwise_not(thinned)
|
|
||||||
# cv2.imwrite(os.path.join('/home/user/PycharmProjects/trinity_client_aida/test/lines_original_result_5', f"Original_{path.replace('/', '-')}.png"), img)
|
|
||||||
# cv2.imwrite(os.path.join('/home/user/PycharmProjects/trinity_client_aida/test/lines_original_result_5', f"Line_{path.replace('/', '-')}.png"), result)
|
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def read_image(self, image_path):
|
def read_image(self, image_path):
|
||||||
@@ -96,19 +75,19 @@ class LoadImage:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_keypoint(name):
|
def get_keypoint(name):
|
||||||
if name == 'blouse' or name == 'outwear' or name == 'dress' or name == 'tops':
|
if name in ['blouse', 'outwear', 'dress', 'tops', 'blouse_merge', 'outwear_merge', 'dress_merge', 'tops_merge']:
|
||||||
keypoint = 'shoulder'
|
keypoint = 'shoulder'
|
||||||
elif name == 'trousers' or name == 'skirt' or name == 'bottoms':
|
elif name in ['trousers', 'skirt', 'bottoms', 'trousers_merge', 'skirt_merge', 'bottoms_merge']:
|
||||||
keypoint = 'waistband'
|
keypoint = 'waistband'
|
||||||
elif name == 'bag':
|
elif name in ['bag', 'bag_merge']:
|
||||||
keypoint = 'hand_point'
|
keypoint = 'hand_point'
|
||||||
elif name == 'shoes':
|
elif name in ['shoes', 'shoes_merge']:
|
||||||
keypoint = 'toe'
|
keypoint = 'toe'
|
||||||
elif name == 'hairstyle':
|
elif name in ['hairstyle', 'hairstyle_merge']:
|
||||||
keypoint = 'head_point'
|
keypoint = 'head_point'
|
||||||
elif name == 'earring':
|
elif name in ['earring', 'earring_merge']:
|
||||||
keypoint = 'ear_point'
|
keypoint = 'ear_point'
|
||||||
elif name == 'others':
|
elif name in ['others', 'others_merge']:
|
||||||
keypoint = "others"
|
keypoint = "others"
|
||||||
else:
|
else:
|
||||||
raise KeyError(f"{name} does not belong to item category list: blouse, outwear, dress, trousers, skirt, "
|
raise KeyError(f"{name} does not belong to item category list: blouse, outwear, dress, trousers, skirt, "
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ from app.service.utils.new_oss_client import oss_get_image
|
|||||||
|
|
||||||
class NoSegPrintPainting:
|
class NoSegPrintPainting:
|
||||||
def __init__(self, minio_client):
|
def __init__(self, minio_client):
|
||||||
self.random_seed = random.randint(0, 1000)
|
|
||||||
self.minio_client = minio_client
|
self.minio_client = minio_client
|
||||||
|
|
||||||
def __call__(self, result):
|
def __call__(self, result):
|
||||||
@@ -21,16 +20,8 @@ class NoSegPrintPainting:
|
|||||||
|
|
||||||
if overall_print['print_path_list']:
|
if overall_print['print_path_list']:
|
||||||
painting_dict = {'dim_image_h': result['pattern_image'].shape[0], 'dim_image_w': result['pattern_image'].shape[1]}
|
painting_dict = {'dim_image_h': result['pattern_image'].shape[0], 'dim_image_w': result['pattern_image'].shape[1]}
|
||||||
if "print_angle_list" in overall_print.keys() and overall_print['print_angle_list'][0] != 0:
|
# 获取平铺 + 旋转 的overall print
|
||||||
painting_dict = self.painting_collection(painting_dict, overall_print, print_trigger=True)
|
painting_dict = self.painting_collection(painting_dict, overall_print)
|
||||||
painting_dict['tile_print'] = self.rotate_crop_image(img=painting_dict['tile_print'], angle=-overall_print['print_angle_list'][0], crop=True)
|
|
||||||
painting_dict['mask_inv_print'] = self.rotate_crop_image(img=painting_dict['mask_inv_print'], angle=-overall_print['print_angle_list'][0], crop=True)
|
|
||||||
|
|
||||||
# resize 到sketch大小
|
|
||||||
painting_dict['tile_print'] = self.resize_and_crop(img=painting_dict['tile_print'], target_width=painting_dict['dim_image_w'], target_height=painting_dict['dim_image_h'])
|
|
||||||
painting_dict['mask_inv_print'] = self.resize_and_crop(img=painting_dict['mask_inv_print'], target_width=painting_dict['dim_image_w'], target_height=painting_dict['dim_image_h'])
|
|
||||||
else:
|
|
||||||
painting_dict = self.painting_collection(painting_dict, overall_print, print_trigger=True, is_single=False)
|
|
||||||
result['no_seg_sketch_overall'] = result['no_seg_sketch_print'] = self.printpaint(result, painting_dict, print_=True)
|
result['no_seg_sketch_overall'] = result['no_seg_sketch_print'] = self.printpaint(result, painting_dict, print_=True)
|
||||||
result['pattern_image'] = result['no_seg_sketch_overall']
|
result['pattern_image'] = result['no_seg_sketch_overall']
|
||||||
|
|
||||||
@@ -151,7 +142,6 @@ class NoSegPrintPainting:
|
|||||||
temp_fg = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
|
temp_fg = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
|
||||||
tmp2 = (result['final_image'] * (temp_fg / 255)).astype(np.uint8)
|
tmp2 = (result['final_image'] * (temp_fg / 255)).astype(np.uint8)
|
||||||
result['no_seg_sketch_print'] = cv2.add(tmp1, tmp2)
|
result['no_seg_sketch_print'] = cv2.add(tmp1, tmp2)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -166,26 +156,21 @@ class NoSegPrintPainting:
|
|||||||
print_background = img1_bg + img2_fg
|
print_background = img1_bg + img2_fg
|
||||||
return print_background
|
return print_background
|
||||||
|
|
||||||
def painting_collection(self, painting_dict, print_dict, print_trigger=False, is_single=False):
|
def painting_collection(self, painting_dict, print_dict):
|
||||||
if print_trigger:
|
|
||||||
print_ = self.get_print(print_dict)
|
print_ = self.get_print(print_dict)
|
||||||
painting_dict['Trigger'] = not is_single
|
|
||||||
painting_dict['location'] = print_['location']
|
painting_dict['location'] = print_['location']
|
||||||
single_mask_inv_print = self.get_mask_inv(print_['image'])
|
|
||||||
dim_max = max(painting_dict['dim_image_h'], painting_dict['dim_image_w'])
|
dim_max = max(painting_dict['dim_image_h'], painting_dict['dim_image_w'])
|
||||||
dim_pattern = (int(dim_max * print_['scale'] / 5), int(dim_max * print_['scale'] / 5))
|
dim_pattern = (int(dim_max * print_['scale'] / 5), int(dim_max * print_['scale'] / 5))
|
||||||
if not is_single:
|
gap = print_dict.get('gap', [[0, 0]])[0]
|
||||||
# 如果print 模式为overall 且 有角度的话 , 组合的print为正方形,方便裁剪
|
painting_dict['tile_print'] = tile_image(pattern=print_['image'],
|
||||||
if "print_angle_list" in print_dict.keys() and print_dict['print_angle_list'][0] != 0:
|
dim=dim_pattern,
|
||||||
painting_dict['mask_inv_print'] = self.tile_image(single_mask_inv_print, dim_pattern, print_['scale'], dim_max, dim_max, painting_dict['location'], trigger=True)
|
gap_x=gap[0],
|
||||||
painting_dict['tile_print'] = self.tile_image(print_['image'], dim_pattern, print_['scale'], dim_max, dim_max, painting_dict['location'], trigger=True)
|
gap_y=gap[1],
|
||||||
else:
|
canvas_h=painting_dict['dim_image_h'],
|
||||||
painting_dict['mask_inv_print'] = self.tile_image(single_mask_inv_print, dim_pattern, print_['scale'], painting_dict['dim_image_h'], painting_dict['dim_image_w'], painting_dict['location'], trigger=True)
|
canvas_w=painting_dict['dim_image_w'],
|
||||||
painting_dict['tile_print'] = self.tile_image(print_['image'], dim_pattern, print_['scale'], painting_dict['dim_image_h'], painting_dict['dim_image_w'], painting_dict['location'], trigger=True)
|
location=painting_dict['location'],
|
||||||
else:
|
angle=45)
|
||||||
painting_dict['mask_inv_print'] = self.tile_image(single_mask_inv_print, dim_pattern, print_['scale'], painting_dict['dim_image_h'], painting_dict['dim_image_w'], painting_dict['location'])
|
painting_dict['mask_inv_print'] = np.zeros(painting_dict['tile_print'].shape[:2], dtype=np.uint8)
|
||||||
painting_dict['tile_print'] = self.tile_image(print_['image'], dim_pattern, print_['scale'], painting_dict['dim_image_h'], painting_dict['dim_image_w'], painting_dict['location'])
|
|
||||||
painting_dict['dim_print_h'], painting_dict['dim_print_w'] = dim_pattern
|
|
||||||
return painting_dict
|
return painting_dict
|
||||||
|
|
||||||
def tile_image(self, pattern, dim, scale, dim_image_h, dim_image_w, location, trigger=False):
|
def tile_image(self, pattern, dim, scale, dim_image_h, dim_image_w, location, trigger=False):
|
||||||
@@ -219,33 +204,32 @@ class NoSegPrintPainting:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def printpaint(result, painting_dict, print_=False):
|
def printpaint(result, painting_dict, print_=False):
|
||||||
|
if print_:
|
||||||
if print_ and painting_dict['Trigger']:
|
|
||||||
print_mask = cv2.bitwise_and(result['mask'], cv2.bitwise_not(painting_dict['mask_inv_print']))
|
print_mask = cv2.bitwise_and(result['mask'], cv2.bitwise_not(painting_dict['mask_inv_print']))
|
||||||
img_fg = cv2.bitwise_and(painting_dict['tile_print'], painting_dict['tile_print'], mask=print_mask)
|
img_fg = cv2.bitwise_and(painting_dict['tile_print'], painting_dict['tile_print'], mask=print_mask)
|
||||||
else:
|
else:
|
||||||
print_mask = result['mask']
|
print_mask = result['mask']
|
||||||
img_fg = result['final_image']
|
img_fg = result['final_image']
|
||||||
if print_ and not painting_dict['Trigger']:
|
# if print_ and not painting_dict['Trigger']:
|
||||||
index_ = None
|
# index_ = None
|
||||||
try:
|
# try:
|
||||||
index_ = len(painting_dict['location'])
|
# index_ = len(painting_dict['location'])
|
||||||
except:
|
# except:
|
||||||
assert f'there must be parameter of location if choose IfSingle'
|
# assert f'there must be parameter of location if choose IfSingle'
|
||||||
|
#
|
||||||
for i in range(index_):
|
# for i in range(index_):
|
||||||
start_h, start_w = int(painting_dict['location'][i][1]), int(painting_dict['location'][i][0])
|
# start_h, start_w = int(painting_dict['location'][i][1]), int(painting_dict['location'][i][0])
|
||||||
|
#
|
||||||
length_h = min(start_h + painting_dict['dim_print_h'], img_fg.shape[0])
|
# length_h = min(start_h + painting_dict['dim_print_h'], img_fg.shape[0])
|
||||||
length_w = min(start_w + painting_dict['dim_print_w'], img_fg.shape[1])
|
# length_w = min(start_w + painting_dict['dim_print_w'], img_fg.shape[1])
|
||||||
|
#
|
||||||
change_region = img_fg[start_h: length_h, start_w: length_w, :]
|
# change_region = img_fg[start_h: length_h, start_w: length_w, :]
|
||||||
# problem in change_mask
|
# # problem in change_mask
|
||||||
change_mask = print_mask[start_h: length_h, start_w: length_w]
|
# change_mask = print_mask[start_h: length_h, start_w: length_w]
|
||||||
# get real part into change mask
|
# # get real part into change mask
|
||||||
_, change_mask = cv2.threshold(change_mask, 220, 255, cv2.THRESH_BINARY)
|
# _, change_mask = cv2.threshold(change_mask, 220, 255, cv2.THRESH_BINARY)
|
||||||
cv2.bitwise_not(painting_dict['mask_inv_print'])
|
# cv2.bitwise_not(painting_dict['mask_inv_print'])
|
||||||
img_fg[start_h:start_h + painting_dict['dim_print_h'], start_w:start_w + painting_dict['dim_print_w'], :] = change_region
|
# img_fg[start_h:start_h + painting_dict['dim_print_h'], start_w:start_w + painting_dict['dim_print_w'], :] = change_region
|
||||||
|
|
||||||
clothes_mask_print = cv2.bitwise_not(print_mask)
|
clothes_mask_print = cv2.bitwise_not(print_mask)
|
||||||
|
|
||||||
@@ -277,8 +261,6 @@ class NoSegPrintPainting:
|
|||||||
print_w = print_shape[1]
|
print_w = print_shape[1]
|
||||||
print_h = print_shape[0]
|
print_h = print_shape[0]
|
||||||
|
|
||||||
random.seed(self.random_seed)
|
|
||||||
|
|
||||||
# 1.拿到偏移量后和resize后的print宽高取余 得到真正偏移量
|
# 1.拿到偏移量后和resize后的print宽高取余 得到真正偏移量
|
||||||
# 偏移量增加2分之print.w 使坐标位于图中间 如果要位于左上角删除+ print_w // 2 即可
|
# 偏移量增加2分之print.w 使坐标位于图中间 如果要位于左上角删除+ print_w // 2 即可
|
||||||
x_offset = print_w - int(location[0][1] % print_w) + print_w // 2
|
x_offset = print_w - int(location[0][1] % print_w) + print_w // 2
|
||||||
@@ -420,3 +402,96 @@ class NoSegPrintPainting:
|
|||||||
cropped_img = resized_img[start_y:start_y + target_height, :]
|
cropped_img = resized_img[start_y:start_y + target_height, :]
|
||||||
|
|
||||||
return cropped_img
|
return cropped_img
|
||||||
|
|
||||||
|
|
||||||
|
def tile_image(pattern, dim, gap_x, gap_y, canvas_h, canvas_w, location, angle=0):
|
||||||
|
"""
|
||||||
|
按照指定的 X/Y 间距平铺印花,并支持旋转
|
||||||
|
:param angle: 旋转角度 (度数, 逆时针)
|
||||||
|
"""
|
||||||
|
# 1. 确保输入是 RGBA
|
||||||
|
if pattern.shape[2] == 3:
|
||||||
|
pattern = cv2.cvtColor(pattern, cv2.COLOR_BGR2BGRA)
|
||||||
|
|
||||||
|
# 2. 缩放与旋转印花
|
||||||
|
resized_p = cv2.resize(pattern, dim, interpolation=cv2.INTER_AREA)
|
||||||
|
rotated_p = rotate_image(resized_p, angle)
|
||||||
|
p_h, p_w = rotated_p.shape[:2]
|
||||||
|
|
||||||
|
# 3. 创建透明单元格
|
||||||
|
cell_h, cell_w = p_h + gap_y, p_w + gap_x
|
||||||
|
unit_cell = np.zeros((cell_h, cell_w, 4), dtype=np.uint8)
|
||||||
|
unit_cell[:p_h, :p_w, :] = rotated_p
|
||||||
|
|
||||||
|
# 4. 执行平铺
|
||||||
|
tiles_y = (canvas_h // cell_h) + 2
|
||||||
|
tiles_x = (canvas_w // cell_w) + 2
|
||||||
|
full_tiled = np.tile(unit_cell, (tiles_y, tiles_x, 1))
|
||||||
|
|
||||||
|
# 5. 裁剪平铺层
|
||||||
|
offset_x = int(location[0][1] % cell_w)
|
||||||
|
offset_y = int(location[0][0] % cell_h)
|
||||||
|
tiled_layer = full_tiled[offset_y: offset_y + canvas_h,
|
||||||
|
offset_x: offset_x + canvas_w]
|
||||||
|
|
||||||
|
# 6. 创建纯白色背景并合成
|
||||||
|
# 创建一个纯白色的 BGR 画布
|
||||||
|
white_background = np.full((canvas_h, canvas_w, 3), 255, dtype=np.uint8)
|
||||||
|
|
||||||
|
# 分离平铺层的颜色通道和 Alpha 通道
|
||||||
|
tiled_bgr = tiled_layer[:, :, :3]
|
||||||
|
alpha_mask = tiled_layer[:, :, 3] / 255.0 # 归一化到 0-1
|
||||||
|
alpha_mask = cv2.merge([alpha_mask, alpha_mask, alpha_mask]) # 扩展到 3 通道
|
||||||
|
|
||||||
|
# 执行 Alpha 混合:结果 = 平铺层 * alpha + 背景 * (1 - alpha)
|
||||||
|
result = (tiled_bgr * alpha_mask + white_background * (1 - alpha_mask)).astype(np.uint8)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def rotate_image(image, angle):
|
||||||
|
"""
|
||||||
|
旋转图片并保持完整内容(自动扩大画布)
|
||||||
|
"""
|
||||||
|
if angle == 0:
|
||||||
|
return image
|
||||||
|
|
||||||
|
(h, w) = image.shape[:2]
|
||||||
|
(cX, cY) = (w // 2, h // 2)
|
||||||
|
|
||||||
|
# 获取旋转矩阵
|
||||||
|
M = cv2.getRotationMatrix2D((cX, cY), angle, 1.0)
|
||||||
|
|
||||||
|
# 计算旋转后新边界的 sine 和 cosine
|
||||||
|
cos = np.abs(M[0, 0])
|
||||||
|
sin = np.abs(M[0, 1])
|
||||||
|
|
||||||
|
# 计算新的画布尺寸
|
||||||
|
nW = int((h * sin) + (w * cos))
|
||||||
|
nH = int((h * cos) + (w * sin))
|
||||||
|
|
||||||
|
# 调整旋转矩阵以考虑平移
|
||||||
|
M[0, 2] += (nW / 2) - cX
|
||||||
|
M[1, 2] += (nH / 2) - cY
|
||||||
|
|
||||||
|
# 执行旋转
|
||||||
|
return cv2.warpAffine(image, M, (nW, nH))
|
||||||
|
|
||||||
|
|
||||||
|
def crop_image(image, image_size_h, image_size_w, location, print_shape):
|
||||||
|
print_w = print_shape[1]
|
||||||
|
print_h = print_shape[0]
|
||||||
|
|
||||||
|
# 1.拿到偏移量后和resize后的print宽高取余 得到真正偏移量
|
||||||
|
# 偏移量增加2分之print.w 使坐标位于图中间 如果要位于左上角删除+ print_w // 2 即可
|
||||||
|
x_offset = print_w - int(location[0][1] % print_w) + print_w // 2
|
||||||
|
y_offset = print_h - int(location[0][0] % print_h) + print_h // 2
|
||||||
|
|
||||||
|
# y_offset = int(location[0][0])
|
||||||
|
# x_offset = int(location[0][1])
|
||||||
|
|
||||||
|
if len(image.shape) == 2:
|
||||||
|
image = image[x_offset: x_offset + image_size_h, y_offset: y_offset + image_size_w]
|
||||||
|
elif len(image.shape) == 3:
|
||||||
|
image = image[x_offset: x_offset + image_size_h, y_offset: y_offset + image_size_w, :]
|
||||||
|
return image
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ from app.service.utils.new_oss_client import oss_get_image
|
|||||||
|
|
||||||
class PrintPainting:
|
class PrintPainting:
|
||||||
def __init__(self, minio_client):
|
def __init__(self, minio_client):
|
||||||
self.random_seed = None
|
|
||||||
self.minio_client = minio_client
|
self.minio_client = minio_client
|
||||||
|
|
||||||
def __call__(self, result):
|
def __call__(self, result):
|
||||||
@@ -39,23 +38,14 @@ class PrintPainting:
|
|||||||
overall_print['location'][0] = [x * y for x, y in zip(overall_print['location'][0], result['resize_scale'])]
|
overall_print['location'][0] = [x * y for x, y in zip(overall_print['location'][0], result['resize_scale'])]
|
||||||
painting_dict = {'dim_image_h': result['pattern_image'].shape[0], 'dim_image_w': result['pattern_image'].shape[1]}
|
painting_dict = {'dim_image_h': result['pattern_image'].shape[0], 'dim_image_w': result['pattern_image'].shape[1]}
|
||||||
result['print_image'] = result['pattern_image']
|
result['print_image'] = result['pattern_image']
|
||||||
if "print_angle_list" in overall_print.keys() and overall_print['print_angle_list'][0] != 0:
|
# 获取平铺 + 旋转 的overall print
|
||||||
painting_dict = self.painting_collection(painting_dict, overall_print, print_trigger=True)
|
painting_dict = self.painting_collection(painting_dict, overall_print)
|
||||||
painting_dict['tile_print'] = self.rotate_crop_image(img=painting_dict['tile_print'], angle=-overall_print['print_angle_list'][0], crop=True)
|
|
||||||
painting_dict['mask_inv_print'] = self.rotate_crop_image(img=painting_dict['mask_inv_print'], angle=-overall_print['print_angle_list'][0], crop=True)
|
|
||||||
|
|
||||||
# resize 到sketch大小
|
|
||||||
painting_dict['tile_print'] = self.resize_and_crop(img=painting_dict['tile_print'], target_width=painting_dict['dim_image_w'], target_height=painting_dict['dim_image_h'])
|
|
||||||
painting_dict['mask_inv_print'] = self.resize_and_crop(img=painting_dict['mask_inv_print'], target_width=painting_dict['dim_image_w'], target_height=painting_dict['dim_image_h'])
|
|
||||||
else:
|
|
||||||
painting_dict = self.painting_collection(painting_dict, overall_print, print_trigger=True, is_single=False)
|
|
||||||
result['print_image'] = self.printpaint(result, painting_dict, print_=True)
|
result['print_image'] = self.printpaint(result, painting_dict, print_=True)
|
||||||
result['single_image'] = result['final_image'] = result['pattern_image'] = result['print_image']
|
result['single_image'] = result['final_image'] = result['pattern_image'] = result['print_image']
|
||||||
|
|
||||||
if single_print['print_path_list']:
|
if single_print['print_path_list']:
|
||||||
# 2025-9-19 印花调整 印花坐标按照sketch的缩放比调整
|
# 2025-9-19 印花调整 印花坐标按照sketch的缩放比调整
|
||||||
sketch_resize_scale = result['resize_scale']
|
sketch_resize_scale = result['resize_scale']
|
||||||
|
|
||||||
print_background = np.zeros((result['pattern_image'].shape[0], result['pattern_image'].shape[1], 3), dtype=np.uint8)
|
print_background = np.zeros((result['pattern_image'].shape[0], result['pattern_image'].shape[1], 3), dtype=np.uint8)
|
||||||
mask_background = np.zeros((result['pattern_image'].shape[0], result['pattern_image'].shape[1], 3), dtype=np.uint8)
|
mask_background = np.zeros((result['pattern_image'].shape[0], result['pattern_image'].shape[1], 3), dtype=np.uint8)
|
||||||
for i in range(len(single_print['print_path_list'])):
|
for i in range(len(single_print['print_path_list'])):
|
||||||
@@ -78,75 +68,6 @@ class PrintPainting:
|
|||||||
print_background = cv2.cvtColor(np.array(source_image_pil), cv2.COLOR_RGBA2BGR)
|
print_background = cv2.cvtColor(np.array(source_image_pil), cv2.COLOR_RGBA2BGR)
|
||||||
mask_background = cv2.cvtColor(np.array(source_image_pil_mask), cv2.COLOR_RGBA2BGR)
|
mask_background = cv2.cvtColor(np.array(source_image_pil_mask), cv2.COLOR_RGBA2BGR)
|
||||||
ret, mask_background = cv2.threshold(mask_background, 124, 255, cv2.THRESH_BINARY)
|
ret, mask_background = cv2.threshold(mask_background, 124, 255, cv2.THRESH_BINARY)
|
||||||
# else:
|
|
||||||
# mask = self.get_mask_inv(image)
|
|
||||||
# mask = np.expand_dims(mask, axis=2)
|
|
||||||
# mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR)
|
|
||||||
# mask = cv2.bitwise_not(mask)
|
|
||||||
#
|
|
||||||
# mask = cv2.resize(mask, (int(result['final_image'].shape[1] * single_print['print_scale_list'][i][0]), int(result['final_image'].shape[0] * single_print['print_scale_list'][i][1])))
|
|
||||||
# image = cv2.resize(image, (int(result['final_image'].shape[1] * single_print['print_scale_list'][i][0]), int(result['final_image'].shape[0] * single_print['print_scale_list'][i][1])))
|
|
||||||
# # 旋转后的坐标需要重新算
|
|
||||||
# rotate_mask, _ = self.img_rotate(mask, single_print['print_angle_list'][i])
|
|
||||||
# rotate_image, rotated_new_size = self.img_rotate(image, single_print['print_angle_list'][i])
|
|
||||||
# # x, y = int(result['print']['location'][i][0] - rotated_new_size[0] - (rotate_mask.shape[0] - image.shape[0]) / 2), int(result['print']['location'][i][1] - rotated_new_size[1] - (rotate_mask.shape[1] - image.shape[1]) / 2)
|
|
||||||
# x, y = int(single_print['location'][i][0] - rotated_new_size[0]), int(single_print['location'][i][1] - rotated_new_size[1])
|
|
||||||
#
|
|
||||||
# image_x = print_background.shape[1] # 底图宽
|
|
||||||
# image_y = print_background.shape[0] # 底图高
|
|
||||||
# print_x = rotate_image.shape[1] #印花宽
|
|
||||||
# print_y = rotate_image.shape[0] #印花高
|
|
||||||
#
|
|
||||||
# # 有bug
|
|
||||||
# # if x + print_x > image_x:
|
|
||||||
# # rotate_image = rotate_image[:, :x + print_x - image_x]
|
|
||||||
# # rotate_mask = rotate_mask[:, :x + print_x - image_x]
|
|
||||||
# # #
|
|
||||||
# # if y + print_y > image_y:
|
|
||||||
# # rotate_image = rotate_image[:y + print_y - image_y]
|
|
||||||
# # rotate_mask = rotate_mask[:y + print_y - image_y]
|
|
||||||
#
|
|
||||||
# # 不能是并行
|
|
||||||
# # 当前第一轮的if (108以及115)是判断有没有过下界和右界。第二轮的是判断左上有没有超出。 如果这个样子的话,先裁了右边,再左移,region就会有问题
|
|
||||||
# # 先挪 再判断 最后裁剪
|
|
||||||
#
|
|
||||||
# # 如果print旋转了 或者 print贴边了 则需要判断 判断左界和上界是否小于0
|
|
||||||
# if x <= 0: # 如果X轴偏移量小于0,说明印花需要被裁剪至合适大小 或当X轴偏移量大于印花宽度时,裁剪后的印花宽度为0
|
|
||||||
# rotate_image = rotate_image[:, abs(x):]
|
|
||||||
# rotate_mask = rotate_mask[:, abs(x):]
|
|
||||||
# start_x = x = 0
|
|
||||||
# else:
|
|
||||||
# start_x = x
|
|
||||||
#
|
|
||||||
# if y <= 0: # 如果X轴偏移量大于0,说明印花需要被裁剪至合适大小 或当Y轴偏移量大于印花宽度时,裁剪后的印花宽度为0
|
|
||||||
# rotate_image = rotate_image[abs(y):, :]
|
|
||||||
# rotate_mask = rotate_mask[abs(y):, :]
|
|
||||||
# start_y = y = 0
|
|
||||||
# else:
|
|
||||||
# start_y = y
|
|
||||||
#
|
|
||||||
# # ------------------
|
|
||||||
# # 如果print-size大于image-size 则需要裁剪print
|
|
||||||
#
|
|
||||||
# if x + print_x > image_x:
|
|
||||||
# rotate_image = rotate_image[:, :image_x - x]
|
|
||||||
# rotate_mask = rotate_mask[:, :image_x - x]
|
|
||||||
#
|
|
||||||
# if y + print_y > image_y:
|
|
||||||
# rotate_image = rotate_image[:image_y - y, :]
|
|
||||||
# rotate_mask = rotate_mask[:image_y - y, :]
|
|
||||||
#
|
|
||||||
# # mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]] = cv2.bitwise_xor(mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]], rotate_mask)
|
|
||||||
# # print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = cv2.add(print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]], rotate_image)
|
|
||||||
#
|
|
||||||
# # mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]] = rotate_mask
|
|
||||||
# # print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = rotate_image
|
|
||||||
# mask_background = self.stack_prin(mask_background, result['pattern_image'], rotate_mask, start_y, y, start_x, x)
|
|
||||||
# print_background = self.stack_prin(print_background, result['pattern_image'], rotate_image, start_y, y, start_x, x)
|
|
||||||
|
|
||||||
# gray_image = cv2.cvtColor(mask_background, cv2.COLOR_BGR2GRAY)
|
|
||||||
# print_background = cv2.bitwise_and(print_background, print_background, mask=gray_image)
|
|
||||||
|
|
||||||
print_mask = cv2.bitwise_and(result['mask'], cv2.cvtColor(mask_background, cv2.COLOR_BGR2GRAY))
|
print_mask = cv2.bitwise_and(result['mask'], cv2.cvtColor(mask_background, cv2.COLOR_BGR2GRAY))
|
||||||
img_fg = cv2.bitwise_or(print_background, print_background, mask=print_mask)
|
img_fg = cv2.bitwise_or(print_background, print_background, mask=print_mask)
|
||||||
img_bg = cv2.bitwise_and(result['pattern_image'], result['pattern_image'], mask=cv2.bitwise_not(print_mask))
|
img_bg = cv2.bitwise_and(result['pattern_image'], result['pattern_image'], mask=cv2.bitwise_not(print_mask))
|
||||||
@@ -166,7 +87,6 @@ class PrintPainting:
|
|||||||
if element_print['element_path_list']:
|
if element_print['element_path_list']:
|
||||||
# 2025-9-19 印花调整 印花坐标按照sketch的缩放比调整
|
# 2025-9-19 印花调整 印花坐标按照sketch的缩放比调整
|
||||||
sketch_resize_scale = result['resize_scale']
|
sketch_resize_scale = result['resize_scale']
|
||||||
|
|
||||||
print_background = np.zeros((result['final_image'].shape[0], result['final_image'].shape[1], 3), dtype=np.uint8)
|
print_background = np.zeros((result['final_image'].shape[0], result['final_image'].shape[1], 3), dtype=np.uint8)
|
||||||
mask_background = np.zeros((result['final_image'].shape[0], result['final_image'].shape[1], 3), dtype=np.uint8)
|
mask_background = np.zeros((result['final_image'].shape[0], result['final_image'].shape[1], 3), dtype=np.uint8)
|
||||||
for i in range(len(element_print['element_path_list'])):
|
for i in range(len(element_print['element_path_list'])):
|
||||||
@@ -207,20 +127,6 @@ class PrintPainting:
|
|||||||
print_x = rotate_image.shape[1]
|
print_x = rotate_image.shape[1]
|
||||||
print_y = rotate_image.shape[0]
|
print_y = rotate_image.shape[0]
|
||||||
|
|
||||||
# 有bug
|
|
||||||
# if x + print_x > image_x:
|
|
||||||
# rotate_image = rotate_image[:, :x + print_x - image_x]
|
|
||||||
# rotate_mask = rotate_mask[:, :x + print_x - image_x]
|
|
||||||
# #
|
|
||||||
# if y + print_y > image_y:
|
|
||||||
# rotate_image = rotate_image[:y + print_y - image_y]
|
|
||||||
# rotate_mask = rotate_mask[:y + print_y - image_y]
|
|
||||||
|
|
||||||
# 不能是并行
|
|
||||||
# 当前第一轮的if (108以及115)是判断有没有过下界和右界。第二轮的是判断左上有没有超出。 如果这个样子的话,先裁了右边,再左移,region就会有问题
|
|
||||||
# 先挪 再判断 最后裁剪
|
|
||||||
|
|
||||||
# 如果print旋转了 或者 print贴边了 则需要判断 判断左界和上界是否小于0
|
|
||||||
if x <= 0:
|
if x <= 0:
|
||||||
rotate_image = rotate_image[:, -x:]
|
rotate_image = rotate_image[:, -x:]
|
||||||
rotate_mask = rotate_mask[:, -x:]
|
rotate_mask = rotate_mask[:, -x:]
|
||||||
@@ -235,9 +141,6 @@ class PrintPainting:
|
|||||||
else:
|
else:
|
||||||
start_y = y
|
start_y = y
|
||||||
|
|
||||||
# ------------------
|
|
||||||
# 如果print-size大于image-size 则需要裁剪print
|
|
||||||
|
|
||||||
if x + print_x > image_x:
|
if x + print_x > image_x:
|
||||||
rotate_image = rotate_image[:, :image_x - x]
|
rotate_image = rotate_image[:, :image_x - x]
|
||||||
rotate_mask = rotate_mask[:, :image_x - x]
|
rotate_mask = rotate_mask[:, :image_x - x]
|
||||||
@@ -246,11 +149,6 @@ class PrintPainting:
|
|||||||
rotate_image = rotate_image[:image_y - y, :]
|
rotate_image = rotate_image[:image_y - y, :]
|
||||||
rotate_mask = rotate_mask[:image_y - y, :]
|
rotate_mask = rotate_mask[:image_y - y, :]
|
||||||
|
|
||||||
# mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]] = cv2.bitwise_xor(mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]], rotate_mask)
|
|
||||||
# print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = cv2.add(print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]], rotate_image)
|
|
||||||
|
|
||||||
# mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]] = rotate_mask
|
|
||||||
# print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = rotate_image
|
|
||||||
mask_background = self.stack_prin(mask_background, result['pattern_image'], rotate_mask, start_y, y, start_x, x)
|
mask_background = self.stack_prin(mask_background, result['pattern_image'], rotate_mask, start_y, y, start_x, x)
|
||||||
print_background = self.stack_prin(print_background, result['pattern_image'], rotate_image, start_y, y, start_x, x)
|
print_background = self.stack_prin(print_background, result['pattern_image'], rotate_image, start_y, y, start_x, x)
|
||||||
|
|
||||||
@@ -298,12 +196,8 @@ class PrintPainting:
|
|||||||
ret, mask_background = cv2.threshold(mask_background, 124, 255, cv2.THRESH_BINARY)
|
ret, mask_background = cv2.threshold(mask_background, 124, 255, cv2.THRESH_BINARY)
|
||||||
print_mask = cv2.bitwise_and(result['mask'], cv2.cvtColor(mask_background, cv2.COLOR_BGR2GRAY))
|
print_mask = cv2.bitwise_and(result['mask'], cv2.cvtColor(mask_background, cv2.COLOR_BGR2GRAY))
|
||||||
img_fg = cv2.bitwise_or(print_background, print_background, mask=print_mask)
|
img_fg = cv2.bitwise_or(print_background, print_background, mask=print_mask)
|
||||||
# TODO element 丢失信息
|
|
||||||
three_channel_image = cv2.merge([cv2.bitwise_not(print_mask), cv2.bitwise_not(print_mask), cv2.bitwise_not(print_mask)])
|
three_channel_image = cv2.merge([cv2.bitwise_not(print_mask), cv2.bitwise_not(print_mask), cv2.bitwise_not(print_mask)])
|
||||||
img_bg = cv2.bitwise_and(result['final_image'], three_channel_image)
|
img_bg = cv2.bitwise_and(result['final_image'], three_channel_image)
|
||||||
# mask_mo = np.expand_dims(print_mask, axis=2).repeat(3, axis=2)
|
|
||||||
# gray_mo = np.expand_dims(result['gray'], axis=2).repeat(3, axis=2)
|
|
||||||
# img_fg = (img_fg * (mask_mo / 255) * (gray_mo / 255)).astype(np.uint8)
|
|
||||||
result['final_image'] = cv2.add(img_bg, img_fg)
|
result['final_image'] = cv2.add(img_bg, img_fg)
|
||||||
canvas = np.full_like(result['final_image'], 255)
|
canvas = np.full_like(result['final_image'], 255)
|
||||||
temp_bg = np.expand_dims(cv2.bitwise_not(result['mask']), axis=2).repeat(3, axis=2)
|
temp_bg = np.expand_dims(cv2.bitwise_not(result['mask']), axis=2).repeat(3, axis=2)
|
||||||
@@ -325,27 +219,21 @@ class PrintPainting:
|
|||||||
print_background = img1_bg + img2_fg
|
print_background = img1_bg + img2_fg
|
||||||
return print_background
|
return print_background
|
||||||
|
|
||||||
def painting_collection(self, painting_dict, print_dict, print_trigger=False, is_single=False):
|
def painting_collection(self, painting_dict, print_dict):
|
||||||
if print_trigger:
|
|
||||||
print_ = self.get_print(print_dict)
|
print_ = self.get_print(print_dict)
|
||||||
painting_dict['Trigger'] = not is_single
|
|
||||||
painting_dict['location'] = print_['location']
|
painting_dict['location'] = print_['location']
|
||||||
single_mask_inv_print = self.get_mask_inv(print_['image'])
|
|
||||||
dim_max = max(painting_dict['dim_image_h'], painting_dict['dim_image_w'])
|
dim_max = max(painting_dict['dim_image_h'], painting_dict['dim_image_w'])
|
||||||
dim_pattern = (int(dim_max * print_['scale'] / 5), int(dim_max * print_['scale'] / 5))
|
dim_pattern = (int(dim_max * print_['scale'] / 5), int(dim_max * print_['scale'] / 5))
|
||||||
if not is_single:
|
gap = print_dict.get('gap', [[0, 0]])[0]
|
||||||
self.random_seed = random.randint(0, 1000)
|
painting_dict['tile_print'] = tile_image(pattern=print_['image'],
|
||||||
# 如果print 模式为overall 且 有角度的话 , 组合的print为正方形,方便裁剪
|
dim=dim_pattern,
|
||||||
if "print_angle_list" in print_dict.keys() and print_dict['print_angle_list'][0] != 0:
|
gap_x=gap[0],
|
||||||
painting_dict['mask_inv_print'] = self.tile_image(single_mask_inv_print, dim_pattern, print_['scale'], dim_max, dim_max, painting_dict['location'], trigger=True)
|
gap_y=gap[1],
|
||||||
painting_dict['tile_print'] = self.tile_image(print_['image'], dim_pattern, print_['scale'], dim_max, dim_max, painting_dict['location'], trigger=True)
|
canvas_h=painting_dict['dim_image_h'],
|
||||||
else:
|
canvas_w=painting_dict['dim_image_w'],
|
||||||
painting_dict['mask_inv_print'] = self.tile_image(single_mask_inv_print, dim_pattern, print_['scale'], painting_dict['dim_image_h'], painting_dict['dim_image_w'], painting_dict['location'], trigger=True)
|
location=painting_dict['location'],
|
||||||
painting_dict['tile_print'] = self.tile_image(print_['image'], dim_pattern, print_['scale'], painting_dict['dim_image_h'], painting_dict['dim_image_w'], painting_dict['location'], trigger=True)
|
angle=45)
|
||||||
else:
|
painting_dict['mask_inv_print'] = np.zeros(painting_dict['tile_print'].shape[:2], dtype=np.uint8)
|
||||||
painting_dict['mask_inv_print'] = self.tile_image(single_mask_inv_print, dim_pattern, print_['scale'], painting_dict['dim_image_h'], painting_dict['dim_image_w'], painting_dict['location'])
|
|
||||||
painting_dict['tile_print'] = self.tile_image(print_['image'], dim_pattern, print_['scale'], painting_dict['dim_image_h'], painting_dict['dim_image_w'], painting_dict['location'])
|
|
||||||
painting_dict['dim_print_h'], painting_dict['dim_print_w'] = dim_pattern
|
|
||||||
return painting_dict
|
return painting_dict
|
||||||
|
|
||||||
def tile_image(self, pattern, dim, scale, dim_image_h, dim_image_w, location, trigger=False):
|
def tile_image(self, pattern, dim, scale, dim_image_h, dim_image_w, location, trigger=False):
|
||||||
@@ -374,51 +262,37 @@ class PrintPainting:
|
|||||||
mask_inv = cv2.inRange(print_tile, lower, upper)
|
mask_inv = cv2.inRange(print_tile, lower, upper)
|
||||||
return mask_inv
|
return mask_inv
|
||||||
else:
|
else:
|
||||||
# bg_color = cv2.cvtColor(print_, cv2.COLOR_BGR2LAB)[0][0]
|
|
||||||
# print_tile = cv2.cvtColor(print_, cv2.COLOR_BGR2LAB)
|
|
||||||
# bg_l, bg_a, bg_b = bg_color[0], bg_color[1], bg_color[2]
|
|
||||||
# bg_L_high, bg_L_low = self.get_low_high_lab(bg_l, L=True)
|
|
||||||
# bg_a_high, bg_a_low = self.get_low_high_lab(bg_a)
|
|
||||||
# bg_b_high, bg_b_low = self.get_low_high_lab(bg_b)
|
|
||||||
# lower = np.array([bg_L_low, bg_a_low, bg_b_low])
|
|
||||||
# upper = np.array([bg_L_high, bg_a_high, bg_b_high])
|
|
||||||
|
|
||||||
# print_tile = cv2.cvtColor(print_, cv2.COLOR_BGR2LAB)
|
|
||||||
# mask_inv = cv2.cvtColor(print_tile, cv2.COLOR_BGR2GRAY)
|
|
||||||
|
|
||||||
# mask_inv = cv2.cvtColor(print_, cv2.COLOR_BGR2GRAY)
|
|
||||||
mask_inv = np.zeros(print_.shape[:2], dtype=np.uint8)
|
mask_inv = np.zeros(print_.shape[:2], dtype=np.uint8)
|
||||||
return mask_inv
|
return mask_inv
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def printpaint(result, painting_dict, print_=False):
|
def printpaint(result, painting_dict, print_=False):
|
||||||
|
if print_:
|
||||||
if print_ and painting_dict['Trigger']:
|
|
||||||
print_mask = cv2.bitwise_and(result['mask'], cv2.bitwise_not(painting_dict['mask_inv_print']))
|
print_mask = cv2.bitwise_and(result['mask'], cv2.bitwise_not(painting_dict['mask_inv_print']))
|
||||||
img_fg = cv2.bitwise_and(painting_dict['tile_print'], painting_dict['tile_print'], mask=print_mask)
|
img_fg = cv2.bitwise_and(painting_dict['tile_print'], painting_dict['tile_print'], mask=print_mask)
|
||||||
else:
|
else:
|
||||||
print_mask = result['mask']
|
print_mask = result['mask']
|
||||||
img_fg = result['final_image']
|
img_fg = result['final_image']
|
||||||
if print_ and not painting_dict['Trigger']:
|
# if print_ and not painting_dict['Trigger']:
|
||||||
index_ = None
|
# index_ = None
|
||||||
try:
|
# try:
|
||||||
index_ = len(painting_dict['location'])
|
# index_ = len(painting_dict['location'])
|
||||||
except:
|
# except:
|
||||||
assert f'there must be parameter of location if choose IfSingle'
|
# assert f'there must be parameter of location if choose IfSingle'
|
||||||
|
#
|
||||||
for i in range(index_):
|
# for i in range(index_):
|
||||||
start_h, start_w = int(painting_dict['location'][i][1]), int(painting_dict['location'][i][0])
|
# start_h, start_w = int(painting_dict['location'][i][1]), int(painting_dict['location'][i][0])
|
||||||
|
#
|
||||||
length_h = min(start_h + painting_dict['dim_print_h'], img_fg.shape[0])
|
# length_h = min(start_h + painting_dict['dim_print_h'], img_fg.shape[0])
|
||||||
length_w = min(start_w + painting_dict['dim_print_w'], img_fg.shape[1])
|
# length_w = min(start_w + painting_dict['dim_print_w'], img_fg.shape[1])
|
||||||
|
#
|
||||||
change_region = img_fg[start_h: length_h, start_w: length_w, :]
|
# change_region = img_fg[start_h: length_h, start_w: length_w, :]
|
||||||
# problem in change_mask
|
# # problem in change_mask
|
||||||
change_mask = print_mask[start_h: length_h, start_w: length_w]
|
# change_mask = print_mask[start_h: length_h, start_w: length_w]
|
||||||
# get real part into change mask
|
# # get real part into change mask
|
||||||
_, change_mask = cv2.threshold(change_mask, 220, 255, cv2.THRESH_BINARY)
|
# _, change_mask = cv2.threshold(change_mask, 220, 255, cv2.THRESH_BINARY)
|
||||||
cv2.bitwise_not(painting_dict['mask_inv_print'])
|
# cv2.bitwise_not(painting_dict['mask_inv_print'])
|
||||||
img_fg[start_h:start_h + painting_dict['dim_print_h'], start_w:start_w + painting_dict['dim_print_w'], :] = change_region
|
# img_fg[start_h:start_h + painting_dict['dim_print_h'], start_w:start_w + painting_dict['dim_print_w'], :] = change_region
|
||||||
|
|
||||||
clothes_mask_print = cv2.bitwise_not(print_mask)
|
clothes_mask_print = cv2.bitwise_not(print_mask)
|
||||||
|
|
||||||
@@ -450,11 +324,6 @@ class PrintPainting:
|
|||||||
print_w = print_shape[1]
|
print_w = print_shape[1]
|
||||||
print_h = print_shape[0]
|
print_h = print_shape[0]
|
||||||
|
|
||||||
random.seed(self.random_seed)
|
|
||||||
# logging.info(f'overall print location : {location}')
|
|
||||||
# x_offset = random.randint(0, image.shape[0] - image_size_h)
|
|
||||||
# y_offset = random.randint(0, image.shape[1] - image_size_w)
|
|
||||||
|
|
||||||
# 1.拿到偏移量后和resize后的print宽高取余 得到真正偏移量
|
# 1.拿到偏移量后和resize后的print宽高取余 得到真正偏移量
|
||||||
# 偏移量增加2分之print.w 使坐标位于图中间 如果要位于左上角删除+ print_w // 2 即可
|
# 偏移量增加2分之print.w 使坐标位于图中间 如果要位于左上角删除+ print_w // 2 即可
|
||||||
x_offset = print_w - int(location[0][1] % print_w) + print_w // 2
|
x_offset = print_w - int(location[0][1] % print_w) + print_w // 2
|
||||||
@@ -596,3 +465,96 @@ class PrintPainting:
|
|||||||
cropped_img = resized_img[start_y:start_y + target_height, :]
|
cropped_img = resized_img[start_y:start_y + target_height, :]
|
||||||
|
|
||||||
return cropped_img
|
return cropped_img
|
||||||
|
|
||||||
|
|
||||||
|
def tile_image(pattern, dim, gap_x, gap_y, canvas_h, canvas_w, location, angle=0):
|
||||||
|
"""
|
||||||
|
按照指定的 X/Y 间距平铺印花,并支持旋转
|
||||||
|
:param angle: 旋转角度 (度数, 逆时针)
|
||||||
|
"""
|
||||||
|
# 1. 确保输入是 RGBA
|
||||||
|
if pattern.shape[2] == 3:
|
||||||
|
pattern = cv2.cvtColor(pattern, cv2.COLOR_BGR2BGRA)
|
||||||
|
|
||||||
|
# 2. 缩放与旋转印花
|
||||||
|
resized_p = cv2.resize(pattern, dim, interpolation=cv2.INTER_AREA)
|
||||||
|
rotated_p = rotate_image(resized_p, angle)
|
||||||
|
p_h, p_w = rotated_p.shape[:2]
|
||||||
|
|
||||||
|
# 3. 创建透明单元格
|
||||||
|
cell_h, cell_w = p_h + gap_y, p_w + gap_x
|
||||||
|
unit_cell = np.zeros((cell_h, cell_w, 4), dtype=np.uint8)
|
||||||
|
unit_cell[:p_h, :p_w, :] = rotated_p
|
||||||
|
|
||||||
|
# 4. 执行平铺
|
||||||
|
tiles_y = (canvas_h // cell_h) + 2
|
||||||
|
tiles_x = (canvas_w // cell_w) + 2
|
||||||
|
full_tiled = np.tile(unit_cell, (tiles_y, tiles_x, 1))
|
||||||
|
|
||||||
|
# 5. 裁剪平铺层
|
||||||
|
offset_x = int(location[0][1] % cell_w)
|
||||||
|
offset_y = int(location[0][0] % cell_h)
|
||||||
|
tiled_layer = full_tiled[offset_y: offset_y + canvas_h,
|
||||||
|
offset_x: offset_x + canvas_w]
|
||||||
|
|
||||||
|
# 6. 创建纯白色背景并合成
|
||||||
|
# 创建一个纯白色的 BGR 画布
|
||||||
|
white_background = np.full((canvas_h, canvas_w, 3), 255, dtype=np.uint8)
|
||||||
|
|
||||||
|
# 分离平铺层的颜色通道和 Alpha 通道
|
||||||
|
tiled_bgr = tiled_layer[:, :, :3]
|
||||||
|
alpha_mask = tiled_layer[:, :, 3] / 255.0 # 归一化到 0-1
|
||||||
|
alpha_mask = cv2.merge([alpha_mask, alpha_mask, alpha_mask]) # 扩展到 3 通道
|
||||||
|
|
||||||
|
# 执行 Alpha 混合:结果 = 平铺层 * alpha + 背景 * (1 - alpha)
|
||||||
|
result = (tiled_bgr * alpha_mask + white_background * (1 - alpha_mask)).astype(np.uint8)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def rotate_image(image, angle):
|
||||||
|
"""
|
||||||
|
旋转图片并保持完整内容(自动扩大画布)
|
||||||
|
"""
|
||||||
|
if angle == 0:
|
||||||
|
return image
|
||||||
|
|
||||||
|
(h, w) = image.shape[:2]
|
||||||
|
(cX, cY) = (w // 2, h // 2)
|
||||||
|
|
||||||
|
# 获取旋转矩阵
|
||||||
|
M = cv2.getRotationMatrix2D((cX, cY), angle, 1.0)
|
||||||
|
|
||||||
|
# 计算旋转后新边界的 sine 和 cosine
|
||||||
|
cos = np.abs(M[0, 0])
|
||||||
|
sin = np.abs(M[0, 1])
|
||||||
|
|
||||||
|
# 计算新的画布尺寸
|
||||||
|
nW = int((h * sin) + (w * cos))
|
||||||
|
nH = int((h * cos) + (w * sin))
|
||||||
|
|
||||||
|
# 调整旋转矩阵以考虑平移
|
||||||
|
M[0, 2] += (nW / 2) - cX
|
||||||
|
M[1, 2] += (nH / 2) - cY
|
||||||
|
|
||||||
|
# 执行旋转
|
||||||
|
return cv2.warpAffine(image, M, (nW, nH))
|
||||||
|
|
||||||
|
|
||||||
|
def crop_image(image, image_size_h, image_size_w, location, print_shape):
|
||||||
|
print_w = print_shape[1]
|
||||||
|
print_h = print_shape[0]
|
||||||
|
|
||||||
|
# 1.拿到偏移量后和resize后的print宽高取余 得到真正偏移量
|
||||||
|
# 偏移量增加2分之print.w 使坐标位于图中间 如果要位于左上角删除+ print_w // 2 即可
|
||||||
|
x_offset = print_w - int(location[0][1] % print_w) + print_w // 2
|
||||||
|
y_offset = print_h - int(location[0][0] % print_h) + print_h // 2
|
||||||
|
|
||||||
|
# y_offset = int(location[0][0])
|
||||||
|
# x_offset = int(location[0][1])
|
||||||
|
|
||||||
|
if len(image.shape) == 2:
|
||||||
|
image = image[x_offset: x_offset + image_size_h, y_offset: y_offset + image_size_w]
|
||||||
|
elif len(image.shape) == 3:
|
||||||
|
image = image[x_offset: x_offset + image_size_h, y_offset: y_offset + image_size_w, :]
|
||||||
|
return image
|
||||||
|
|||||||
@@ -34,15 +34,15 @@ class Segmentation:
|
|||||||
result['mask'] = result['front_mask'] + result['back_mask']
|
result['mask'] = result['front_mask'] + result['back_mask']
|
||||||
else:
|
else:
|
||||||
# preview 过模型 不缓存
|
# preview 过模型 不缓存
|
||||||
if "preview_submit" in result.keys() and result['preview_submit'] == "preview":
|
if result.get("design_type", None) == "merge":
|
||||||
# 推理获得seg 结果
|
|
||||||
seg_result = get_seg_result(result['image'])
|
seg_result = get_seg_result(result['image'])
|
||||||
# submit 过模型 缓存
|
# 默认design 模式 - 过模型 缓存
|
||||||
elif "preview_submit" in result.keys() and result['preview_submit'] == "submit":
|
# elif result.get("design_type", None) == "submit":
|
||||||
# 推理获得seg 结果
|
# 推理获得seg 结果
|
||||||
seg_result = get_seg_result(result['image'])
|
# seg_result = get_seg_result(result['image'])
|
||||||
self.save_seg_result(seg_result, result['image_id'])
|
# self.save_seg_result(seg_result, result['image_id'])
|
||||||
# null 正常流程 加载本地缓存 无缓存则过模型
|
|
||||||
|
# 默认模式- 加载模型,找不到则过模型推理,推理后保存到本地
|
||||||
else:
|
else:
|
||||||
# 本地查询seg 缓存是否存在
|
# 本地查询seg 缓存是否存在
|
||||||
_, seg_result = self.load_seg_result(result["image_id"])
|
_, seg_result = self.load_seg_result(result["image_id"])
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import logging
|
|||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
from celery.bin.result import result
|
||||||
|
|
||||||
from app.service.design_fast.utils.conversion_image import rgb_to_rgba
|
from app.service.design_fast.utils.conversion_image import rgb_to_rgba
|
||||||
from app.service.design_fast.utils.transparent import sketch_to_transparent
|
from app.service.design_fast.utils.transparent import sketch_to_transparent
|
||||||
@@ -19,6 +20,36 @@ class Split(object):
|
|||||||
def __call__(self, result):
|
def __call__(self, result):
|
||||||
try:
|
try:
|
||||||
if result['name'] in ('outwear', 'dress', 'blouse', 'skirt', 'trousers', 'tops', 'bottoms', 'others'):
|
if result['name'] in ('outwear', 'dress', 'blouse', 'skirt', 'trousers', 'tops', 'bottoms', 'others'):
|
||||||
|
if result.get('design_type', None) == 'merge':
|
||||||
|
# merge 不需要返回mask (红绿图)
|
||||||
|
if result['resize_scale'][0] == 1.0 and result['resize_scale'][1] == 1.0:
|
||||||
|
front_mask = result['front_mask']
|
||||||
|
back_mask = result['back_mask']
|
||||||
|
else:
|
||||||
|
height, width = result['front_mask'].shape[:2]
|
||||||
|
new_width = int(width * result['resize_scale'][0])
|
||||||
|
new_height = int(height * result['resize_scale'][1])
|
||||||
|
|
||||||
|
front_mask = cv2.resize(result['front_mask'], (new_width, new_height), interpolation=cv2.INTER_AREA)
|
||||||
|
back_mask = cv2.resize(result['back_mask'], (new_width, new_height), interpolation=cv2.INTER_AREA)
|
||||||
|
result['merge_image'] = cv2.resize(result['merge_image'], (new_width, new_height), interpolation=cv2.INTER_AREA)
|
||||||
|
|
||||||
|
rgba_image = rgb_to_rgba(result['merge_image'], front_mask + back_mask)
|
||||||
|
new_size = (int(rgba_image.shape[1] * result["scale"]), int(rgba_image.shape[0] * result["scale"]))
|
||||||
|
rgba_image = cv2.resize(rgba_image, new_size, interpolation=cv2.INTER_AREA)
|
||||||
|
result_front_image = np.zeros_like(rgba_image)
|
||||||
|
front_mask = cv2.resize(front_mask, new_size, interpolation=cv2.INTER_AREA)
|
||||||
|
result_front_image[front_mask != 0] = rgba_image[front_mask != 0]
|
||||||
|
result_front_image_pil = Image.fromarray(cv2.cvtColor(result_front_image, cv2.COLOR_BGR2RGBA))
|
||||||
|
result['front_image'], result["front_image_url"], _ = upload_png_mask(self.minio_client, result_front_image_pil, f'{generate_uuid()}', mask=None)
|
||||||
|
|
||||||
|
result_back_image = np.zeros_like(rgba_image)
|
||||||
|
back_mask = cv2.resize(back_mask, new_size, interpolation=cv2.INTER_AREA)
|
||||||
|
result_back_image[back_mask != 0] = rgba_image[back_mask != 0]
|
||||||
|
result_back_image_pil = Image.fromarray(cv2.cvtColor(result_back_image, cv2.COLOR_BGR2RGBA))
|
||||||
|
result['back_image'], result["back_image_url"], _ = upload_png_mask(self.minio_client, result_back_image_pil, f'{generate_uuid()}', mask=None)
|
||||||
|
return result
|
||||||
|
else:
|
||||||
ori_front_mask = result['front_mask'].copy()
|
ori_front_mask = result['front_mask'].copy()
|
||||||
ori_back_mask = result['back_mask'].copy()
|
ori_back_mask = result['back_mask'].copy()
|
||||||
|
|
||||||
@@ -60,46 +91,9 @@ class Split(object):
|
|||||||
result_front_image_pil = sketch_to_transparent(result_front_image_pil, front_mask, transparent["scale"])
|
result_front_image_pil = sketch_to_transparent(result_front_image_pil, front_mask, transparent["scale"])
|
||||||
result['front_image'], result["front_image_url"], _ = upload_png_mask(self.minio_client, result_front_image_pil, f'{generate_uuid()}', mask=None)
|
result['front_image'], result["front_image_url"], _ = upload_png_mask(self.minio_client, result_front_image_pil, f'{generate_uuid()}', mask=None)
|
||||||
|
|
||||||
# 前片部分 (红图部分)
|
|
||||||
# height, width = front_mask.shape
|
|
||||||
# mask_image = np.zeros((height, width, 3))
|
|
||||||
# mask_image[front_mask != 0] = [0, 0, 255]
|
|
||||||
|
|
||||||
# 切换为原始图片尺寸-------------------------------
|
|
||||||
height, width = ori_front_mask.shape
|
height, width = ori_front_mask.shape
|
||||||
mask_image = np.zeros((height, width, 3))
|
mask_image = np.zeros((height, width, 3))
|
||||||
mask_image[ori_front_mask != 0] = [0, 0, 255]
|
mask_image[ori_front_mask != 0] = [0, 0, 255]
|
||||||
# -----------------------------------------------
|
|
||||||
|
|
||||||
# if result["name"] in ('blouse', 'dress', 'outwear', 'tops'):
|
|
||||||
# result_back_image = np.zeros_like(rgba_image)
|
|
||||||
# back_mask = cv2.resize(back_mask, new_size, interpolation=cv2.INTER_AREA)
|
|
||||||
# result_back_image[back_mask != 0] = rgba_image[back_mask != 0]
|
|
||||||
# result_back_image_pil = Image.fromarray(cvtColor(result_back_image, COLOR_BGR2RGBA))
|
|
||||||
# result['back_image'], result["back_image_url"], _ = upload_png_mask(self.minio_client, result_back_image_pil, f'{generate_uuid()}', mask=None)
|
|
||||||
# mask_image[back_mask != 0] = [0, 255, 0]
|
|
||||||
#
|
|
||||||
# rbga_mask = rgb_to_rgba(mask_image, front_mask + back_mask)
|
|
||||||
# mask_pil = Image.fromarray(cvtColor(rbga_mask.astype(np.uint8), COLOR_BGR2RGBA))
|
|
||||||
# image_data = io.BytesIO()
|
|
||||||
# mask_pil.save(image_data, format='PNG')
|
|
||||||
# image_data.seek(0)
|
|
||||||
# image_bytes = image_data.read()
|
|
||||||
# req = oss_upload_image(oss_client=self.minio_client, bucket=AIDA_CLOTHING, object_name=f"mask/mask_{generate_uuid()}.png", image_bytes=image_bytes)
|
|
||||||
# result['mask_url'] = req.bucket_name + "/" + req.object_name
|
|
||||||
# else:
|
|
||||||
# rbga_mask = rgb_to_rgba(mask_image, front_mask)
|
|
||||||
# mask_pil = Image.fromarray(cvtColor(rbga_mask.astype(np.uint8), COLOR_BGR2RGBA))
|
|
||||||
# image_data = io.BytesIO()
|
|
||||||
# mask_pil.save(image_data, format='PNG')
|
|
||||||
# image_data.seek(0)
|
|
||||||
# image_bytes = image_data.read()
|
|
||||||
# req = oss_upload_image(oss_client=self.minio_client, bucket=AIDA_CLOTHING, object_name=f"mask/mask_{generate_uuid()}.png", image_bytes=image_bytes)
|
|
||||||
# result['mask_url'] = req.bucket_name + "/" + req.object_name
|
|
||||||
# result['back_image'] = None
|
|
||||||
# result["back_image_url"] = None
|
|
||||||
# # result["back_mask_url"] = None
|
|
||||||
# # result['back_mask_image'] = None
|
|
||||||
|
|
||||||
result_back_image = np.zeros_like(rgba_image)
|
result_back_image = np.zeros_like(rgba_image)
|
||||||
back_mask = cv2.resize(back_mask, new_size, interpolation=cv2.INTER_AREA)
|
back_mask = cv2.resize(back_mask, new_size, interpolation=cv2.INTER_AREA)
|
||||||
@@ -118,6 +112,14 @@ class Split(object):
|
|||||||
image_bytes = image_data.read()
|
image_bytes = image_data.read()
|
||||||
req = oss_upload_image(oss_client=self.minio_client, bucket="aida-clothing", object_name=f"mask/mask_{generate_uuid()}.png", image_bytes=image_bytes)
|
req = oss_upload_image(oss_client=self.minio_client, bucket="aida-clothing", object_name=f"mask/mask_{generate_uuid()}.png", image_bytes=image_bytes)
|
||||||
result['mask_url'] = req.bucket_name + "/" + req.object_name
|
result['mask_url'] = req.bucket_name + "/" + req.object_name
|
||||||
|
|
||||||
|
# 创建中间图层(未分割图层) 1.color + overall_print 2.color + overall_print + print
|
||||||
|
result_pattern_overall_image_pil = Image.fromarray(cv2.cvtColor(rgb_to_rgba(result['no_seg_sketch_overall'], ori_front_mask + ori_back_mask), cv2.COLOR_BGR2RGBA))
|
||||||
|
result['pattern_overall_image'], result['pattern_overall_image_url'], _ = upload_png_mask(self.minio_client, result_pattern_overall_image_pil, f'{generate_uuid()}')
|
||||||
|
|
||||||
|
result_pattern_print_image_pil = Image.fromarray(cv2.cvtColor(rgb_to_rgba(result['no_seg_sketch_print'], ori_front_mask + ori_back_mask), cv2.COLOR_BGR2RGBA))
|
||||||
|
result['pattern_print_image'], result['pattern_print_image_url'], _ = upload_png_mask(self.minio_client, result_pattern_print_image_pil, f'{generate_uuid()}')
|
||||||
|
return result
|
||||||
else:
|
else:
|
||||||
ori_front_mask, ori_back_mask = None, None
|
ori_front_mask, ori_back_mask = None, None
|
||||||
# 创建中间图层(未分割图层) 1.color + overall_print 2.color + overall_print + print
|
# 创建中间图层(未分割图层) 1.color + overall_print 2.color + overall_print + print
|
||||||
@@ -127,5 +129,6 @@ class Split(object):
|
|||||||
result_pattern_print_image_pil = Image.fromarray(cv2.cvtColor(rgb_to_rgba(result['no_seg_sketch_print'], ori_front_mask + ori_back_mask), cv2.COLOR_BGR2RGBA))
|
result_pattern_print_image_pil = Image.fromarray(cv2.cvtColor(rgb_to_rgba(result['no_seg_sketch_print'], ori_front_mask + ori_back_mask), cv2.COLOR_BGR2RGBA))
|
||||||
result['pattern_print_image'], result['pattern_print_image_url'], _ = upload_png_mask(self.minio_client, result_pattern_print_image_pil, f'{generate_uuid()}')
|
result['pattern_print_image'], result['pattern_print_image_url'], _ = upload_png_mask(self.minio_client, result_pattern_print_image_pil, f'{generate_uuid()}')
|
||||||
return result
|
return result
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.warning(f"split runtime exception : {e} image_id : {result['image_id']}")
|
logging.warning(f"split runtime exception : {e} image_id : {result['image_id']}")
|
||||||
|
|||||||
@@ -23,20 +23,23 @@ def organize_clothing(layer):
|
|||||||
front_layer = dict(priority=layer['priority'] if layer.get("layer_order", False) else PRIORITY_DICT.get(f'{layer["name"].lower()}_front', None),
|
front_layer = dict(priority=layer['priority'] if layer.get("layer_order", False) else PRIORITY_DICT.get(f'{layer["name"].lower()}_front', None),
|
||||||
name=f'{layer["name"].lower()}_front',
|
name=f'{layer["name"].lower()}_front',
|
||||||
image=layer["front_image"],
|
image=layer["front_image"],
|
||||||
|
merge_image=layer["front_image"],
|
||||||
# mask_image=layer['front_mask_image'],
|
# mask_image=layer['front_mask_image'],
|
||||||
image_url=layer['front_image_url'],
|
image_url=layer['front_image_url'],
|
||||||
mask_url=layer['mask_url'],
|
mask_url=layer.get("mask_url", None),
|
||||||
sacle=layer['scale'],
|
sacle=layer['scale'],
|
||||||
clothes_keypoint=layer['clothes_keypoint'],
|
clothes_keypoint=layer['clothes_keypoint'],
|
||||||
position=start_point,
|
position=start_point,
|
||||||
resize_scale=layer["resize_scale"],
|
resize_scale=layer["resize_scale"],
|
||||||
mask=cv2.resize(layer['mask'], layer["front_image"].size),
|
mask=cv2.resize(layer['mask'], layer["front_image"].size),
|
||||||
gradient_string=layer['gradient_string'] if 'gradient_string' in layer.keys() else "",
|
gradient_string=layer['gradient_string'] if 'gradient_string' in layer.keys() else "",
|
||||||
pattern_overall_image_url=layer['pattern_overall_image_url'],
|
pattern_overall_image_url=layer.get('pattern_overall_image_url', None),
|
||||||
pattern_print_image_url=layer['pattern_print_image_url'],
|
pattern_print_image_url=layer.get('pattern_print_image_url', None),
|
||||||
|
|
||||||
pattern_image=layer['pattern_image'],
|
pattern_image=layer.get('pattern_image', None),
|
||||||
# back_perspective_url=layer['back_perspective_url'] if 'back_perspective_url' in layer.keys() else ""
|
# back_perspective_url=layer['back_perspective_url'] if 'back_perspective_url' in layer.keys() else ""
|
||||||
|
transpose=layer.get("transpose", [1, 1]), # 默认为1, 1代表不镜像
|
||||||
|
rotate=layer.get('rotate', 0),
|
||||||
)
|
)
|
||||||
# 后片数据
|
# 后片数据
|
||||||
back_layer = dict(priority=-layer.get("priority", 0) if layer.get("layer_order", False) else PRIORITY_DICT.get(f'{layer["name"].lower()}_back', None),
|
back_layer = dict(priority=-layer.get("priority", 0) if layer.get("layer_order", False) else PRIORITY_DICT.get(f'{layer["name"].lower()}_back', None),
|
||||||
@@ -44,16 +47,18 @@ def organize_clothing(layer):
|
|||||||
image=layer["back_image"],
|
image=layer["back_image"],
|
||||||
# mask_image=layer['back_mask_image'],
|
# mask_image=layer['back_mask_image'],
|
||||||
image_url=layer['back_image_url'],
|
image_url=layer['back_image_url'],
|
||||||
mask_url=layer['mask_url'],
|
mask_url=layer.get('mask_url', None),
|
||||||
sacle=layer['scale'],
|
sacle=layer['scale'],
|
||||||
clothes_keypoint=layer['clothes_keypoint'],
|
clothes_keypoint=layer['clothes_keypoint'],
|
||||||
position=start_point,
|
position=start_point,
|
||||||
resize_scale=layer["resize_scale"],
|
resize_scale=layer["resize_scale"],
|
||||||
mask=cv2.resize(layer['mask'], layer["front_image"].size),
|
mask=cv2.resize(layer['mask'], layer["front_image"].size),
|
||||||
gradient_string=layer['gradient_string'] if 'gradient_string' in layer.keys() else "",
|
gradient_string=layer['gradient_string'] if 'gradient_string' in layer.keys() else "",
|
||||||
pattern_overall_image_url=layer['pattern_overall_image_url'],
|
pattern_overall_image_url=layer.get('pattern_overall_image_url', None),
|
||||||
pattern_print_image_url=layer['pattern_print_image_url'],
|
pattern_print_image_url=layer.get('pattern_print_image_url', None),
|
||||||
# back_perspective_url=layer['back_perspective_url'] if 'back_perspective_url' in layer.keys() else ""
|
# back_perspective_url=layer['back_perspective_url'] if 'back_perspective_url' in layer.keys() else ""
|
||||||
|
transpose=layer.get("transpose", [1, 1]), # 默认为1, 1代表不镜像
|
||||||
|
rotate=layer.get('rotate', 0),
|
||||||
)
|
)
|
||||||
return front_layer, back_layer
|
return front_layer, back_layer
|
||||||
|
|
||||||
@@ -76,16 +81,16 @@ def organize_others(layer):
|
|||||||
image=layer["front_image"],
|
image=layer["front_image"],
|
||||||
# mask_image=layer['front_mask_image'],
|
# mask_image=layer['front_mask_image'],
|
||||||
image_url=layer['front_image_url'],
|
image_url=layer['front_image_url'],
|
||||||
mask_url=layer['mask_url'],
|
mask_url=layer.get('mask_url', None),
|
||||||
sacle=layer['scale'],
|
sacle=layer['scale'],
|
||||||
clothes_keypoint=(0, 0),
|
clothes_keypoint=(0, 0),
|
||||||
position=start_point,
|
position=start_point,
|
||||||
resize_scale=layer["resize_scale"],
|
resize_scale=layer["resize_scale"],
|
||||||
mask=cv2.resize(layer['mask'], layer["front_image"].size),
|
mask=cv2.resize(layer['mask'], layer["front_image"].size),
|
||||||
gradient_string=layer['gradient_string'] if 'gradient_string' in layer.keys() else "",
|
gradient_string=layer['gradient_string'] if 'gradient_string' in layer.keys() else "",
|
||||||
pattern_overall_image_url=layer['pattern_overall_image_url'],
|
pattern_overall_image_url=layer.get('pattern_overall_image_url', None),
|
||||||
pattern_print_image_url=layer['pattern_print_image_url'],
|
pattern_print_image_url=layer.get('pattern_print_image_url', None),
|
||||||
pattern_image=layer['pattern_image'],
|
pattern_image=layer.get('pattern_image', None),
|
||||||
# back_perspective_url=layer['back_perspective_url'] if 'back_perspective_url' in layer.keys() else ""
|
# back_perspective_url=layer['back_perspective_url'] if 'back_perspective_url' in layer.keys() else ""
|
||||||
)
|
)
|
||||||
# 后片数据
|
# 后片数据
|
||||||
@@ -94,15 +99,15 @@ def organize_others(layer):
|
|||||||
image=layer["back_image"],
|
image=layer["back_image"],
|
||||||
# mask_image=layer['back_mask_image'],
|
# mask_image=layer['back_mask_image'],
|
||||||
image_url=layer['back_image_url'],
|
image_url=layer['back_image_url'],
|
||||||
mask_url=layer['mask_url'],
|
mask_url=layer.get('mask_url', None),
|
||||||
sacle=layer['scale'],
|
sacle=layer['scale'],
|
||||||
clothes_keypoint=(0, 0),
|
clothes_keypoint=(0, 0),
|
||||||
position=start_point,
|
position=start_point,
|
||||||
resize_scale=layer["resize_scale"],
|
resize_scale=layer["resize_scale"],
|
||||||
mask=cv2.resize(layer['mask'], layer["front_image"].size),
|
mask=cv2.resize(layer['mask'], layer["front_image"].size),
|
||||||
gradient_string=layer['gradient_string'] if 'gradient_string' in layer.keys() else "",
|
gradient_string=layer['gradient_string'] if 'gradient_string' in layer.keys() else "",
|
||||||
pattern_overall_image_url=layer['pattern_overall_image_url'],
|
pattern_overall_image_url=layer.get('pattern_overall_image_url', None),
|
||||||
pattern_print_image_url=layer['pattern_print_image_url'],
|
pattern_print_image_url=layer.get('pattern_print_image_url', None),
|
||||||
# back_perspective_url=layer['back_perspective_url'] if 'back_perspective_url' in layer.keys() else ""
|
# back_perspective_url=layer['back_perspective_url'] if 'back_perspective_url' in layer.keys() else ""
|
||||||
)
|
)
|
||||||
return front_layer, back_layer
|
return front_layer, back_layer
|
||||||
|
|||||||
@@ -151,9 +151,11 @@ def synthesis(data, size, basic_info):
|
|||||||
if layer['image'] is not None:
|
if layer['image'] is not None:
|
||||||
if layer['name'] != "body":
|
if layer['name'] != "body":
|
||||||
test_image = Image.new('RGBA', size, (0, 0, 0, 0))
|
test_image = Image.new('RGBA', size, (0, 0, 0, 0))
|
||||||
test_image.paste(layer['image'], (layer['adaptive_position'][1], layer['adaptive_position'][0]), layer['image'])
|
paste_img, position = transpose_rotate(layer, layer['image'])
|
||||||
|
test_image.paste(paste_img, position, paste_img)
|
||||||
mask_data = np.where(all_mask > 0, 255, 0).astype(np.uint8)
|
mask_data = np.where(all_mask > 0, 255, 0).astype(np.uint8)
|
||||||
mask_alpha = Image.fromarray(mask_data)
|
mask_alpha = Image.fromarray(mask_data)
|
||||||
|
mask_alpha.paste(paste_img.getchannel('A'), position, paste_img.getchannel('A'))
|
||||||
cropped_image = Image.composite(test_image, Image.new("RGBA", test_image.size, (255, 255, 255, 0)), mask_alpha)
|
cropped_image = Image.composite(test_image, Image.new("RGBA", test_image.size, (255, 255, 255, 0)), mask_alpha)
|
||||||
base_image.paste(test_image, (0, 0), cropped_image) # test_image 已经按照坐标贴到最大宽值的图片上 坐着这里坐标为00
|
base_image.paste(test_image, (0, 0), cropped_image) # test_image 已经按照坐标贴到最大宽值的图片上 坐着这里坐标为00
|
||||||
else:
|
else:
|
||||||
@@ -185,6 +187,111 @@ def synthesis(data, size, basic_info):
|
|||||||
logging.warning(f"synthesis runtime exception : {e}")
|
logging.warning(f"synthesis runtime exception : {e}")
|
||||||
|
|
||||||
|
|
||||||
|
def merge(data, size, basic_info):
|
||||||
|
# out_of_bounds_control: 是否允许服装越界 True 允许 False 不允许 默认情况允许
|
||||||
|
out_of_bounds_control = basic_info.get('out_of_bounds_control', True)
|
||||||
|
# 创建底图
|
||||||
|
base_image = Image.new('RGBA', size, (0, 0, 0, 0))
|
||||||
|
try:
|
||||||
|
all_mask_shape = (size[1], size[0])
|
||||||
|
body_mask = None
|
||||||
|
for d in data:
|
||||||
|
if d['name'] == 'body' or d['name'] == 'mannequin':
|
||||||
|
# 创建一个新的宽高透明图像, 把模特贴上去获取mask
|
||||||
|
transparent_image = Image.new("RGBA", size, (0, 0, 0, 0))
|
||||||
|
transparent_image.paste(d['image'], (d['adaptive_position'][1], d['adaptive_position'][0]), d['image']) # 此处可变数组会被paste篡改值,所以使用下标获取position
|
||||||
|
body_mask = np.array(transparent_image.split()[3])
|
||||||
|
|
||||||
|
# 根据新的坐标获取新的肩点
|
||||||
|
left_shoulder = [x + y for x, y in zip(basic_info['body_point_test']['shoulder_left'], [d['adaptive_position'][1], d['adaptive_position'][0]])]
|
||||||
|
right_shoulder = [x + y for x, y in zip(basic_info['body_point_test']['shoulder_right'], [d['adaptive_position'][1], d['adaptive_position'][0]])]
|
||||||
|
body_mask[:min(left_shoulder[1], right_shoulder[1]), left_shoulder[0]:right_shoulder[0]] = 255
|
||||||
|
_, binary_body_mask = cv2.threshold(body_mask, 127, 255, cv2.THRESH_BINARY)
|
||||||
|
top_outer_mask = np.array(binary_body_mask)
|
||||||
|
bottom_outer_mask = np.array(binary_body_mask)
|
||||||
|
others_outer_mask = np.array(binary_body_mask)
|
||||||
|
|
||||||
|
top = True
|
||||||
|
bottom = True
|
||||||
|
others = True
|
||||||
|
i = len(data)
|
||||||
|
while i:
|
||||||
|
i -= 1
|
||||||
|
if top and data[i]['name'] in ["blouse_front", "outwear_front", "dress_front", "tops_front"]:
|
||||||
|
if out_of_bounds_control:
|
||||||
|
top = True
|
||||||
|
else:
|
||||||
|
top = False
|
||||||
|
mask_shape = data[i]['mask'].shape
|
||||||
|
y_offset, x_offset = data[i]['adaptive_position']
|
||||||
|
# 初始化叠加区域的起始和结束位置
|
||||||
|
all_y_start, all_y_end, mask_y_start, mask_y_end = positioning(all_mask_shape=all_mask_shape[0], mask_shape=mask_shape[0], offset=y_offset)
|
||||||
|
all_x_start, all_x_end, mask_x_start, mask_x_end = positioning(all_mask_shape=all_mask_shape[1], mask_shape=mask_shape[1], offset=x_offset)
|
||||||
|
# 将叠加区域赋值为相应的像素值
|
||||||
|
_, sketch_mask = cv2.threshold(data[i]['mask'], 127, 255, cv2.THRESH_BINARY)
|
||||||
|
background = np.zeros_like(top_outer_mask)
|
||||||
|
background[all_y_start:all_y_end, all_x_start:all_x_end] = sketch_mask[mask_y_start:mask_y_end, mask_x_start:mask_x_end]
|
||||||
|
top_outer_mask = background + top_outer_mask
|
||||||
|
elif bottom and data[i]['name'] in ["trousers_front", "skirt_front", "bottoms_front", "dress_front"]:
|
||||||
|
# bottom = False
|
||||||
|
mask_shape = data[i]['mask'].shape
|
||||||
|
y_offset, x_offset = data[i]['adaptive_position']
|
||||||
|
# 初始化叠加区域的起始和结束位置
|
||||||
|
all_y_start, all_y_end, mask_y_start, mask_y_end = positioning(all_mask_shape=all_mask_shape[0], mask_shape=mask_shape[0], offset=y_offset)
|
||||||
|
all_x_start, all_x_end, mask_x_start, mask_x_end = positioning(all_mask_shape=all_mask_shape[1], mask_shape=mask_shape[1], offset=x_offset)
|
||||||
|
# 将叠加区域赋值为相应的像素值
|
||||||
|
_, sketch_mask = cv2.threshold(data[i]['mask'], 127, 255, cv2.THRESH_BINARY)
|
||||||
|
background = np.zeros_like(top_outer_mask)
|
||||||
|
background[all_y_start:all_y_end, all_x_start:all_x_end] = sketch_mask[mask_y_start:mask_y_end, mask_x_start:mask_x_end]
|
||||||
|
bottom_outer_mask = background + bottom_outer_mask
|
||||||
|
elif others and data[i]['name'] in ['others_front']:
|
||||||
|
mask_shape = data[i]['mask'].shape
|
||||||
|
y_offset, x_offset = data[i]['adaptive_position']
|
||||||
|
# 初始化叠加区域的起始和结束位置
|
||||||
|
all_y_start, all_y_end, mask_y_start, mask_y_end = positioning(all_mask_shape=all_mask_shape[0], mask_shape=mask_shape[0], offset=y_offset)
|
||||||
|
all_x_start, all_x_end, mask_x_start, mask_x_end = positioning(all_mask_shape=all_mask_shape[1], mask_shape=mask_shape[1], offset=x_offset)
|
||||||
|
# 将叠加区域赋值为相应的像素值
|
||||||
|
_, sketch_mask = cv2.threshold(data[i]['mask'], 127, 255, cv2.THRESH_BINARY)
|
||||||
|
background = np.zeros_like(top_outer_mask)
|
||||||
|
background[all_y_start:all_y_end, all_x_start:all_x_end] = sketch_mask[mask_y_start:mask_y_end, mask_x_start:mask_x_end]
|
||||||
|
others_outer_mask = background + others_outer_mask
|
||||||
|
pass
|
||||||
|
elif bottom is False and top is False:
|
||||||
|
break
|
||||||
|
|
||||||
|
all_mask = cv2.bitwise_or(top_outer_mask, bottom_outer_mask)
|
||||||
|
all_mask = cv2.bitwise_or(all_mask, others_outer_mask)
|
||||||
|
|
||||||
|
for layer in data:
|
||||||
|
if layer['image'] is not None:
|
||||||
|
if layer['name'] != "body":
|
||||||
|
test_image = Image.new('RGBA', size, (0, 0, 0, 0))
|
||||||
|
paste_img, position = transpose_rotate(layer, layer['image'])
|
||||||
|
test_image.paste(paste_img, position, paste_img)
|
||||||
|
mask_data = np.where(all_mask > 0, 255, 0).astype(np.uint8)
|
||||||
|
mask_alpha = Image.fromarray(mask_data)
|
||||||
|
mask_alpha.paste(paste_img.getchannel('A'), position, paste_img.getchannel('A'))
|
||||||
|
cropped_image = Image.composite(test_image, Image.new("RGBA", test_image.size, (255, 255, 255, 0)), mask_alpha)
|
||||||
|
base_image.paste(test_image, (0, 0), cropped_image) # test_image 已经按照坐标贴到最大宽值的图片上 坐着这里坐标为00
|
||||||
|
else:
|
||||||
|
base_image.paste(layer['merge_image'], (layer['adaptive_position'][1], layer['adaptive_position'][0]), layer['merge_image'])
|
||||||
|
|
||||||
|
result_image = base_image
|
||||||
|
|
||||||
|
image_data = io.BytesIO()
|
||||||
|
result_image.save(image_data, format='PNG')
|
||||||
|
image_data.seek(0)
|
||||||
|
|
||||||
|
# oss upload
|
||||||
|
image_bytes = image_data.read()
|
||||||
|
bucket_name = "aida-results"
|
||||||
|
object_name = f'result_{generate_uuid()}.png'
|
||||||
|
oss_upload_image(oss_client=minio_client, bucket=bucket_name, object_name=object_name, image_bytes=image_bytes)
|
||||||
|
return f"{bucket_name}/{object_name}"
|
||||||
|
except Exception as e:
|
||||||
|
logging.warning(f"synthesis runtime exception : {e}")
|
||||||
|
|
||||||
|
|
||||||
def synthesis_single(front_image, back_image):
|
def synthesis_single(front_image, back_image):
|
||||||
result_image = None
|
result_image = None
|
||||||
if front_image:
|
if front_image:
|
||||||
@@ -232,3 +339,35 @@ def update_base_size_priority(layers):
|
|||||||
for info in layers:
|
for info in layers:
|
||||||
info['adaptive_position'] = (info['position'][0], info['position'][1] - min_x)
|
info['adaptive_position'] = (info['position'][0], info['position'][1] - min_x)
|
||||||
return layers, (new_width, new_height)
|
return layers, (new_width, new_height)
|
||||||
|
|
||||||
|
|
||||||
|
def transpose_rotate(layer, image):
|
||||||
|
# transpose[0]是左右 transpose[1]是上下
|
||||||
|
transpose = layer.get('transpose', [1, 1]) # 默认为1, 1代表不镜像
|
||||||
|
|
||||||
|
rotate = layer.get('rotate', 0)
|
||||||
|
paste_x, paste_y = layer['adaptive_position'][1], layer['adaptive_position'][0]
|
||||||
|
|
||||||
|
# transpose左右是1 上下是-1
|
||||||
|
if transpose[0] != 1:
|
||||||
|
# 左右
|
||||||
|
image = image.transpose(0)
|
||||||
|
|
||||||
|
if transpose[1] != 1:
|
||||||
|
# 上下
|
||||||
|
image = image.transpose(1)
|
||||||
|
|
||||||
|
if rotate:
|
||||||
|
image = image.rotate(-rotate, expand=True)
|
||||||
|
# 4. 计算粘贴位置以保持视觉中心一致
|
||||||
|
# 原本 (15, 36) 是 288*288 的左上角,我们计算其中心点
|
||||||
|
target_center_x = 15 + 288 // 2
|
||||||
|
target_center_y = 36 + 288 // 2
|
||||||
|
|
||||||
|
# 获取旋转后图像的新尺寸
|
||||||
|
new_w, new_h = image.size
|
||||||
|
|
||||||
|
# 计算新的左上角坐标,使得旋转后的图像中心依然在原定的中心位置
|
||||||
|
paste_x = target_center_x - new_w // 2
|
||||||
|
paste_y = target_center_y - new_h // 2
|
||||||
|
return image, (paste_x, paste_y)
|
||||||
|
|||||||
@@ -1,241 +1,240 @@
|
|||||||
# 预加载资源
|
# # 预加载资源
|
||||||
import logging
|
# import logging
|
||||||
import time
|
# import time
|
||||||
from collections import defaultdict
|
# from collections import defaultdict
|
||||||
import os
|
# import os
|
||||||
import json
|
# import json
|
||||||
import numpy as np
|
# import numpy as np
|
||||||
|
#
|
||||||
from app.core.config import settings
|
# from app.core.config import DB_CONFIG, RECOMMEND_PATH_PREFIX
|
||||||
from app.core.mysql_config import DB_CONFIG
|
#
|
||||||
|
# logger = logging.getLogger()
|
||||||
logger = logging.getLogger()
|
# import pymysql
|
||||||
import pymysql
|
# from concurrent.futures import ThreadPoolExecutor
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
#
|
||||||
|
# HEAT_VECTOR_FILE = 'heat_vectors_data/heat_vectors.json' # 可动态加载或配置
|
||||||
HEAT_VECTOR_FILE = 'heat_vectors_data/heat_vectors.json' # 可动态加载或配置
|
#
|
||||||
|
# matrix_data = {
|
||||||
matrix_data = {
|
# "interaction_matrix": None,
|
||||||
"interaction_matrix": None,
|
# "feature_matrix": None,
|
||||||
"feature_matrix": None,
|
# "user_index_interaction": None,
|
||||||
"user_index_interaction": None,
|
# "sketch_index_interaction": None,
|
||||||
"sketch_index_interaction": None,
|
# "user_index_feature": None,
|
||||||
"user_index_feature": None,
|
# "sketch_index_feature": None,
|
||||||
"sketch_index_feature": None,
|
# "iid_to_sketch": None,
|
||||||
"iid_to_sketch": None,
|
# "category_to_iids": None,
|
||||||
"category_to_iids": None,
|
# "cached_scores": {},
|
||||||
"cached_scores": {},
|
# "cached_valid_idxs": {},
|
||||||
"cached_valid_idxs": {},
|
# "category_sketch_idxs_inter": None,
|
||||||
"category_sketch_idxs_inter": None,
|
# "category_sketch_idxs_feature": None,
|
||||||
"category_sketch_idxs_feature": None,
|
# "user_inter_full": dict(),
|
||||||
"user_inter_full": dict(),
|
# "user_feat_full": dict(),
|
||||||
"user_feat_full": dict(),
|
# "brand_feature_matrix": None,
|
||||||
"brand_feature_matrix": None,
|
# "brand_index_map": None,
|
||||||
"brand_index_map": None,
|
# "heat_data": {},
|
||||||
"heat_data": {},
|
# }
|
||||||
}
|
#
|
||||||
|
#
|
||||||
|
# def load_resources():
|
||||||
def load_resources():
|
# """加载所有矩阵和映射关系,并触发预缓存"""
|
||||||
"""加载所有矩阵和映射关系,并触发预缓存"""
|
# try:
|
||||||
try:
|
# start_time = time.time()
|
||||||
start_time = time.time()
|
#
|
||||||
|
# # 清空缓存
|
||||||
# 清空缓存
|
# matrix_data["cached_scores"].clear()
|
||||||
matrix_data["cached_scores"].clear()
|
# matrix_data["cached_valid_idxs"].clear()
|
||||||
matrix_data["cached_valid_idxs"].clear()
|
#
|
||||||
|
# # 加载数据
|
||||||
# 加载数据
|
# sketch_to_iid = np.load(f'{RECOMMEND_PATH_PREFIX}sketch_to_iid.npy', allow_pickle=True).item()
|
||||||
sketch_to_iid = np.load(f'{settings.RECOMMEND_PATH_PREFIX}sketch_to_iid.npy', allow_pickle=True).item()
|
# matrix_data["iid_to_sketch"] = {v: k for k, v in sketch_to_iid.items()}
|
||||||
matrix_data["iid_to_sketch"] = {v: k for k, v in sketch_to_iid.items()}
|
#
|
||||||
|
# matrix_data["interaction_matrix"] = np.load(f"{RECOMMEND_PATH_PREFIX}interaction_matrix.npy", allow_pickle=True)
|
||||||
matrix_data["interaction_matrix"] = np.load(f"{settings.RECOMMEND_PATH_PREFIX}interaction_matrix.npy", allow_pickle=True)
|
# matrix_data["user_index_interaction"] = np.load(f"{RECOMMEND_PATH_PREFIX}user_index_interaction_matrix.npy", allow_pickle=True).item()
|
||||||
matrix_data["user_index_interaction"] = np.load(f"{settings.RECOMMEND_PATH_PREFIX}user_index_interaction_matrix.npy", allow_pickle=True).item()
|
# matrix_data["sketch_index_interaction"] = np.load(f"{RECOMMEND_PATH_PREFIX}sketch_index_interaction_matrix.npy",
|
||||||
matrix_data["sketch_index_interaction"] = np.load(f"{settings.RECOMMEND_PATH_PREFIX}sketch_index_interaction_matrix.npy",
|
# allow_pickle=True).item()
|
||||||
allow_pickle=True).item()
|
#
|
||||||
|
# matrix_data["feature_matrix"] = np.load(f"{RECOMMEND_PATH_PREFIX}feature_matrix.npy", allow_pickle=True)
|
||||||
matrix_data["feature_matrix"] = np.load(f"{settings.RECOMMEND_PATH_PREFIX}feature_matrix.npy", allow_pickle=True)
|
#
|
||||||
|
# brand_feature_path = f"{RECOMMEND_PATH_PREFIX}brand_feature_matrix.npy"
|
||||||
brand_feature_path = f"{settings.RECOMMEND_PATH_PREFIX}brand_feature_matrix.npy"
|
# if os.path.exists(brand_feature_path):
|
||||||
if os.path.exists(brand_feature_path):
|
# matrix_data["brand_feature_matrix"] = np.load(brand_feature_path, allow_pickle=True)
|
||||||
matrix_data["brand_feature_matrix"] = np.load(brand_feature_path, allow_pickle=True)
|
# else:
|
||||||
else:
|
# logger.warning("brand_feature_matrix 文件不存在,使用空数组")
|
||||||
logger.warning("brand_feature_matrix 文件不存在,使用空数组")
|
# matrix_data["brand_feature_matrix"] = np.array([])
|
||||||
matrix_data["brand_feature_matrix"] = np.array([])
|
#
|
||||||
|
# # brand_index_map
|
||||||
# brand_index_map
|
# brand_index_path = f"{RECOMMEND_PATH_PREFIX}brand_index_map.npy"
|
||||||
brand_index_path = f"{settings.RECOMMEND_PATH_PREFIX}brand_index_map.npy"
|
# if os.path.exists(brand_index_path):
|
||||||
if os.path.exists(brand_index_path):
|
# matrix_data["brand_index_map"] = np.load(brand_index_path, allow_pickle=True).item()
|
||||||
matrix_data["brand_index_map"] = np.load(brand_index_path, allow_pickle=True).item()
|
# else:
|
||||||
else:
|
# logger.warning("brand_index_map 文件不存在,使用空字典")
|
||||||
logger.warning("brand_index_map 文件不存在,使用空字典")
|
# matrix_data["brand_index_map"] = {}
|
||||||
matrix_data["brand_index_map"] = {}
|
#
|
||||||
|
# matrix_data["user_index_feature"] = np.load(f"{RECOMMEND_PATH_PREFIX}user_index_feature_matrix.npy", allow_pickle=True).item()
|
||||||
matrix_data["user_index_feature"] = np.load(f"{settings.RECOMMEND_PATH_PREFIX}user_index_feature_matrix.npy", allow_pickle=True).item()
|
#
|
||||||
|
# matrix_data["sketch_index_feature"] = np.load(f"{RECOMMEND_PATH_PREFIX}sketch_index_feature_matrix.npy", allow_pickle=True).item()
|
||||||
matrix_data["sketch_index_feature"] = np.load(f"{settings.RECOMMEND_PATH_PREFIX}sketch_index_feature_matrix.npy", allow_pickle=True).item()
|
#
|
||||||
|
# category_to_iid_map = np.load(f"{RECOMMEND_PATH_PREFIX}iid_to_category_interaction_matrix.npy", allow_pickle=True).item()
|
||||||
category_to_iid_map = np.load(f"{settings.RECOMMEND_PATH_PREFIX}iid_to_category_interaction_matrix.npy", allow_pickle=True).item()
|
# matrix_data["category_to_iids"] = defaultdict(list)
|
||||||
matrix_data["category_to_iids"] = defaultdict(list)
|
# for iid, cat in category_to_iid_map.items():
|
||||||
for iid, cat in category_to_iid_map.items():
|
# matrix_data["category_to_iids"][cat].append(iid)
|
||||||
matrix_data["category_to_iids"][cat].append(iid)
|
#
|
||||||
|
# logger.info(f"资源加载完成,耗时: {time.time() - start_time:.2f}秒")
|
||||||
logger.info(f"资源加载完成,耗时: {time.time() - start_time:.2f}秒")
|
#
|
||||||
|
# # 触发预缓存
|
||||||
# 触发预缓存
|
# precache_user_category()
|
||||||
precache_user_category()
|
#
|
||||||
|
# if os.path.exists(HEAT_VECTOR_FILE):
|
||||||
if os.path.exists(HEAT_VECTOR_FILE):
|
# with open(HEAT_VECTOR_FILE, 'r', encoding='utf-8') as f:
|
||||||
with open(HEAT_VECTOR_FILE, 'r', encoding='utf-8') as f:
|
# heat_json = json.load(f)
|
||||||
heat_json = json.load(f)
|
# matrix_data["heat_data"] = heat_json.get("data", {})
|
||||||
matrix_data["heat_data"] = heat_json.get("data", {})
|
# logger.info(f"热度向量数据加载完成,共加载 {len(matrix_data['heat_data'])} 个类别")
|
||||||
logger.info(f"热度向量数据加载完成,共加载 {len(matrix_data['heat_data'])} 个类别")
|
# else:
|
||||||
else:
|
# matrix_data["heat_data"] = {}
|
||||||
matrix_data["heat_data"] = {}
|
#
|
||||||
|
# except Exception as e:
|
||||||
except Exception as e:
|
# logger.error(f"资源加载失败: {str(e)}")
|
||||||
logger.error(f"资源加载失败: {str(e)}")
|
# raise RuntimeError("初始化失败")
|
||||||
raise RuntimeError("初始化失败")
|
#
|
||||||
|
#
|
||||||
|
# def precache_user_category():
|
||||||
def precache_user_category():
|
# """优化后的用户分类预缓存(添加耗时统计)"""
|
||||||
"""优化后的用户分类预缓存(添加耗时统计)"""
|
# if not all([
|
||||||
if not all([
|
# matrix_data["interaction_matrix"] is not None,
|
||||||
matrix_data["interaction_matrix"] is not None,
|
# matrix_data["feature_matrix"] is not None,
|
||||||
matrix_data["feature_matrix"] is not None,
|
# matrix_data["user_index_interaction"] is not None
|
||||||
matrix_data["user_index_interaction"] is not None
|
# ]):
|
||||||
]):
|
# logger.warning("资源未加载完成,跳过预缓存")
|
||||||
logger.warning("资源未加载完成,跳过预缓存")
|
# return
|
||||||
return
|
#
|
||||||
|
# start_time = time.perf_counter()
|
||||||
start_time = time.perf_counter()
|
# time_stats = {
|
||||||
time_stats = {
|
# "get_all_user_categories": 0,
|
||||||
"get_all_user_categories": 0,
|
# "process_user_category": 0,
|
||||||
"process_user_category": 0,
|
# "thread_execution": 0,
|
||||||
"thread_execution": 0,
|
# "cache_update": 0,
|
||||||
"cache_update": 0,
|
# "total": 0,
|
||||||
"total": 0,
|
# }
|
||||||
}
|
#
|
||||||
|
# # 统计用户类别获取时间
|
||||||
# 统计用户类别获取时间
|
# t1 = time.perf_counter()
|
||||||
t1 = time.perf_counter()
|
# user_categories = get_all_user_categories()
|
||||||
user_categories = get_all_user_categories()
|
# time_stats["get_all_user_categories"] = time.perf_counter() - t1
|
||||||
time_stats["get_all_user_categories"] = time.perf_counter() - t1
|
#
|
||||||
|
# precached_count = 0
|
||||||
precached_count = 0
|
#
|
||||||
|
# def process_user_category(user_id, categories):
|
||||||
def process_user_category(user_id, categories):
|
# """单用户类别缓存计算(统计耗时)"""
|
||||||
"""单用户类别缓存计算(统计耗时)"""
|
# local_cache = {}
|
||||||
local_cache = {}
|
# local_valid_idxs = {}
|
||||||
local_valid_idxs = {}
|
# t_start = time.perf_counter()
|
||||||
time.perf_counter()
|
#
|
||||||
|
# for category in categories:
|
||||||
for category in categories:
|
# cache_key = (user_id, category)
|
||||||
cache_key = (user_id, category)
|
# if cache_key in matrix_data["cached_scores"]:
|
||||||
if cache_key in matrix_data["cached_scores"]:
|
# continue
|
||||||
continue
|
#
|
||||||
|
# try:
|
||||||
try:
|
# user_idx_inter = matrix_data["user_index_interaction"].get(user_id)
|
||||||
user_idx_inter = matrix_data["user_index_interaction"].get(user_id)
|
# user_idx_feature = matrix_data["user_index_feature"].get(user_id)
|
||||||
user_idx_feature = matrix_data["user_index_feature"].get(user_id)
|
#
|
||||||
|
# # 统计获取类别 IID 耗时
|
||||||
# 统计获取类别 IID 耗时
|
# t_iid = time.perf_counter()
|
||||||
t_iid = time.perf_counter()
|
# category_iids = matrix_data["category_to_iids"].get(category, [])
|
||||||
category_iids = matrix_data["category_to_iids"].get(category, [])
|
# valid_sketch_idxs_inter = [matrix_data["sketch_index_interaction"][iid]
|
||||||
valid_sketch_idxs_inter = [matrix_data["sketch_index_interaction"][iid]
|
# for iid in category_iids if iid in matrix_data["sketch_index_interaction"]]
|
||||||
for iid in category_iids if iid in matrix_data["sketch_index_interaction"]]
|
# valid_sketch_idxs_feature = [matrix_data["sketch_index_feature"][iid]
|
||||||
valid_sketch_idxs_feature = [matrix_data["sketch_index_feature"][iid]
|
# for iid in category_iids if iid in matrix_data["sketch_index_feature"]]
|
||||||
for iid in category_iids if iid in matrix_data["sketch_index_feature"]]
|
# time_stats["process_user_category"] += time.perf_counter() - t_iid
|
||||||
time_stats["process_user_category"] += time.perf_counter() - t_iid
|
#
|
||||||
|
# # 统计矩阵计算耗时
|
||||||
# 统计矩阵计算耗时
|
# t_matrix = time.perf_counter()
|
||||||
t_matrix = time.perf_counter()
|
# processed_inter = np.zeros(len(valid_sketch_idxs_inter))
|
||||||
processed_inter = np.zeros(len(valid_sketch_idxs_inter))
|
# if user_idx_inter is not None and valid_sketch_idxs_inter:
|
||||||
if user_idx_inter is not None and valid_sketch_idxs_inter:
|
# raw_inter_scores = matrix_data["interaction_matrix"][user_idx_inter, valid_sketch_idxs_inter]
|
||||||
raw_inter_scores = matrix_data["interaction_matrix"][user_idx_inter, valid_sketch_idxs_inter]
|
# processed_inter = raw_inter_scores * 0.7
|
||||||
processed_inter = raw_inter_scores * 0.7
|
#
|
||||||
|
# processed_feat = np.zeros(len(valid_sketch_idxs_feature))
|
||||||
processed_feat = np.zeros(len(valid_sketch_idxs_feature))
|
# if user_idx_feature is not None and valid_sketch_idxs_feature:
|
||||||
if user_idx_feature is not None and valid_sketch_idxs_feature:
|
# raw_feat_scores = matrix_data["feature_matrix"][user_idx_feature, valid_sketch_idxs_feature]
|
||||||
raw_feat_scores = matrix_data["feature_matrix"][user_idx_feature, valid_sketch_idxs_feature]
|
# raw_feat_scores = (raw_feat_scores - np.min(raw_feat_scores)) / (
|
||||||
raw_feat_scores = (raw_feat_scores - np.min(raw_feat_scores)) / (
|
# np.max(raw_feat_scores) - np.min(raw_feat_scores) + 1e-8)
|
||||||
np.max(raw_feat_scores) - np.min(raw_feat_scores) + 1e-8)
|
# processed_feat = raw_feat_scores * 0.3
|
||||||
processed_feat = raw_feat_scores * 0.3
|
# time_stats["process_user_category"] += time.perf_counter() - t_matrix
|
||||||
time_stats["process_user_category"] += time.perf_counter() - t_matrix
|
#
|
||||||
|
# if len(processed_inter) == len(processed_feat):
|
||||||
if len(processed_inter) == len(processed_feat):
|
# local_cache[cache_key] = (processed_inter, processed_feat)
|
||||||
local_cache[cache_key] = (processed_inter, processed_feat)
|
# local_valid_idxs[cache_key] = valid_sketch_idxs_inter
|
||||||
local_valid_idxs[cache_key] = valid_sketch_idxs_inter
|
#
|
||||||
|
# except Exception as e:
|
||||||
except Exception as e:
|
# logger.error(f"预缓存失败 (user={user_id}, category={category}): {str(e)}")
|
||||||
logger.error(f"预缓存失败 (user={user_id}, category={category}): {str(e)}")
|
#
|
||||||
|
# return local_cache, local_valid_idxs
|
||||||
return local_cache, local_valid_idxs
|
#
|
||||||
|
# # 统计线程执行时间
|
||||||
# 统计线程执行时间
|
# t2 = time.perf_counter()
|
||||||
t2 = time.perf_counter()
|
# with ThreadPoolExecutor(max_workers=8) as executor:
|
||||||
with ThreadPoolExecutor(max_workers=8) as executor:
|
# futures = {executor.submit(process_user_category, user_id, categories): user_id for user_id, categories in user_categories.items()}
|
||||||
futures = {executor.submit(process_user_category, user_id, categories): user_id for user_id, categories in user_categories.items()}
|
# for future in futures:
|
||||||
for future in futures:
|
# try:
|
||||||
try:
|
# t_cache = time.perf_counter()
|
||||||
t_cache = time.perf_counter()
|
# cache_part, valid_idxs_part = future.result()
|
||||||
cache_part, valid_idxs_part = future.result()
|
# matrix_data["cached_scores"].update(cache_part)
|
||||||
matrix_data["cached_scores"].update(cache_part)
|
# matrix_data["cached_valid_idxs"].update(valid_idxs_part)
|
||||||
matrix_data["cached_valid_idxs"].update(valid_idxs_part)
|
# time_stats["cache_update"] += time.perf_counter() - t_cache
|
||||||
time_stats["cache_update"] += time.perf_counter() - t_cache
|
# precached_count += len(cache_part)
|
||||||
precached_count += len(cache_part)
|
# except Exception as e:
|
||||||
except Exception as e:
|
# logger.error(f"线程执行错误: {str(e)}")
|
||||||
logger.error(f"线程执行错误: {str(e)}")
|
# time_stats["thread_execution"] = time.perf_counter() - t2
|
||||||
time_stats["thread_execution"] = time.perf_counter() - t2
|
#
|
||||||
|
# time_stats["total"] = time.perf_counter() - start_time
|
||||||
time_stats["total"] = time.perf_counter() - start_time
|
#
|
||||||
|
# # 输出统计信息
|
||||||
# 输出统计信息
|
# logger.info(f"""
|
||||||
logger.info(f"""
|
# 预缓存完成,共缓存 {precached_count} 组数据,耗时统计如下:
|
||||||
预缓存完成,共缓存 {precached_count} 组数据,耗时统计如下:
|
# - 获取用户类别数据: {time_stats["get_all_user_categories"]:.2f}s
|
||||||
- 获取用户类别数据: {time_stats["get_all_user_categories"]:.2f}s
|
# - 计算用户类别缓存: {time_stats["process_user_category"]:.2f}s
|
||||||
- 计算用户类别缓存: {time_stats["process_user_category"]:.2f}s
|
# - 线程任务执行: {time_stats["thread_execution"]:.2f}s
|
||||||
- 线程任务执行: {time_stats["thread_execution"]:.2f}s
|
# - 更新缓存数据: {time_stats["cache_update"]:.2f}s
|
||||||
- 更新缓存数据: {time_stats["cache_update"]:.2f}s
|
# - 总耗时: {time_stats["total"]:.2f}s
|
||||||
- 总耗时: {time_stats["total"]:.2f}s
|
# """)
|
||||||
""")
|
#
|
||||||
|
#
|
||||||
|
# def get_all_user_categories():
|
||||||
def get_all_user_categories():
|
# """获取所有用户及其对应的分类"""
|
||||||
"""获取所有用户及其对应的分类"""
|
# conn = None
|
||||||
conn = None
|
# try:
|
||||||
try:
|
# conn = pymysql.connect(**DB_CONFIG)
|
||||||
conn = pymysql.connect(**DB_CONFIG)
|
# cursor = conn.cursor()
|
||||||
cursor = conn.cursor()
|
#
|
||||||
|
# query = """
|
||||||
query = """
|
# SELECT DISTINCT account_id, path
|
||||||
SELECT DISTINCT account_id, path
|
# FROM user_preference_log_prediction
|
||||||
FROM user_preference_log_prediction \
|
# """
|
||||||
"""
|
# cursor.execute(query)
|
||||||
cursor.execute(query)
|
# results = cursor.fetchall()
|
||||||
results = cursor.fetchall()
|
#
|
||||||
|
# user_categories = defaultdict(set)
|
||||||
user_categories = defaultdict(set)
|
# for account_id, path in results:
|
||||||
for account_id, path in results:
|
# category = get_category_from_path(path)
|
||||||
category = get_category_from_path(path)
|
# user_categories[account_id].add(category)
|
||||||
user_categories[account_id].add(category)
|
#
|
||||||
|
# return dict(user_categories)
|
||||||
return dict(user_categories)
|
#
|
||||||
|
# except Exception as e:
|
||||||
except Exception as e:
|
# logger.error(f"数据库查询失败: {str(e)}")
|
||||||
logger.error(f"数据库查询失败: {str(e)}")
|
# return {}
|
||||||
return {}
|
# finally:
|
||||||
finally:
|
# if conn:
|
||||||
if conn:
|
# conn.close()
|
||||||
conn.close()
|
#
|
||||||
|
#
|
||||||
|
# def get_category_from_path(path: str) -> str:
|
||||||
def get_category_from_path(path: str) -> str:
|
# """从路径解析类别"""
|
||||||
"""从路径解析类别"""
|
# try:
|
||||||
try:
|
# parts = path.split('/')
|
||||||
parts = path.split('/')
|
# if len(parts) >= 4:
|
||||||
if len(parts) >= 4:
|
# return f"{parts[2]}_{parts[3]}"
|
||||||
return f"{parts[2]}_{parts[3]}"
|
# return "unknown"
|
||||||
return "unknown"
|
# except:
|
||||||
except:
|
# return "unknown"
|
||||||
return "unknown"
|
|
||||||
|
|||||||
1
app/service/recommendation_system/__init__.py
Normal file
1
app/service/recommendation_system/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
|
||||||
67
app/service/recommendation_system/config.py
Normal file
67
app/service/recommendation_system/config.py
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
"""
|
||||||
|
推荐系统配置
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
from app.core.config import settings
|
||||||
|
|
||||||
|
# Milvus 集合名称
|
||||||
|
MILVUS_COLLECTION_SKETCH_VECTORS = "sketch_vectors_norm"
|
||||||
|
|
||||||
|
# Redis key 前缀
|
||||||
|
REDIS_KEY_USER_PREF_PREFIX = "user_pref"
|
||||||
|
|
||||||
|
# 推荐系统配置参数
|
||||||
|
RECOMMENDATION_CONFIG = {
|
||||||
|
# 时间衰减半衰期(用于计算时间衰减权重)
|
||||||
|
# 值越小,最近的行为权重越大
|
||||||
|
"K_half": 10,
|
||||||
|
|
||||||
|
# 探索与利用的比例 (0.0-1.0)
|
||||||
|
# - 值越大,使用探索分支(随机推荐)的几率越大,结果更随机
|
||||||
|
# - 值越小,使用利用分支(基于用户偏好)的几率越大,结果更精准
|
||||||
|
# - 建议范围: 0.3-0.7,要增加随机性可提高到 0.6-0.8
|
||||||
|
"explore_ratio": 0.5,
|
||||||
|
|
||||||
|
# 向量检索返回的候选数量
|
||||||
|
# 值越大,候选池越大,但计算成本也越高
|
||||||
|
# 建议范围: 100-1000
|
||||||
|
"topk": 200,
|
||||||
|
|
||||||
|
# Style 加分系数(同 style 的候选进行加分)
|
||||||
|
# 值越大,匹配 style 的候选被选中的概率越大
|
||||||
|
# 要降低某个结果的重复率,可以降低此值(如 0.1 或 0.05)
|
||||||
|
"style_bonus": 0.2,
|
||||||
|
|
||||||
|
# Softmax 抽样的温度参数
|
||||||
|
# - 温度越高(>1.0),概率分布越均匀,结果更随机,重复率更低
|
||||||
|
# - 温度越低(<1.0),高分项概率越大,结果更集中,重复率更高
|
||||||
|
# - 温度=1.0 为标准 Softmax
|
||||||
|
# - 建议范围: 1.0-3.0,要增加随机性可提高到 2.0-3.0
|
||||||
|
"softmax_temperature": 0.07,
|
||||||
|
|
||||||
|
# 监听间隔(秒)
|
||||||
|
"listen_interval_sec": 30,
|
||||||
|
|
||||||
|
# 批量处理大小
|
||||||
|
"batch_size": 1000,
|
||||||
|
|
||||||
|
# Redis 过期时间(秒,30天)
|
||||||
|
"redis_expire_seconds": 2592000,
|
||||||
|
|
||||||
|
# 向量维度
|
||||||
|
"vector_dim": 2048,
|
||||||
|
}
|
||||||
|
|
||||||
|
# 数据库表名
|
||||||
|
TABLE_USER_PREFERENCE_LOG = "user_preference"
|
||||||
|
TABLE_SYS_FILE = "t_sys_file"
|
||||||
|
|
||||||
|
# MySQL 连接配置(用于推荐系统)
|
||||||
|
MYSQL_CONFIG = {
|
||||||
|
"host": settings.MYSQL_HOST,
|
||||||
|
"port": settings.MYSQL_PORT,
|
||||||
|
"user": settings.MYSQL_USER,
|
||||||
|
"password": settings.MYSQL_PASSWORD,
|
||||||
|
"database": settings.MYSQL_DB,
|
||||||
|
"charset": "utf8mb4"
|
||||||
|
}
|
||||||
331
app/service/recommendation_system/import_sys_sketch_to_milvus.py
Normal file
331
app/service/recommendation_system/import_sys_sketch_to_milvus.py
Normal file
@@ -0,0 +1,331 @@
|
|||||||
|
"""
|
||||||
|
独立脚本:从 t_sys_file 导入系统图向量到 Milvus
|
||||||
|
可以单独运行,不依赖整个项目启动
|
||||||
|
|
||||||
|
使用方法:
|
||||||
|
python -m app.service.recommendation_system.import_sys_sketch_to_milvus
|
||||||
|
或
|
||||||
|
python app/service/recommendation_system/import_sys_sketch_to_milvus.py
|
||||||
|
"""
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
import logging
|
||||||
|
import argparse
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# 添加项目根目录到 Python 路径
|
||||||
|
project_root = Path(__file__).parent.parent.parent.parent
|
||||||
|
sys.path.insert(0, str(project_root))
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pymysql
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from app.service.recommendation_system.config import (
|
||||||
|
MYSQL_CONFIG, TABLE_SYS_FILE,
|
||||||
|
RECOMMENDATION_CONFIG, MILVUS_COLLECTION_SKETCH_VECTORS
|
||||||
|
)
|
||||||
|
from app.service.recommendation_system.vector_utils import extract_feature_vector, normalize_vector
|
||||||
|
from app.service.recommendation_system.milvus_client import create_collection, insert_vectors
|
||||||
|
|
||||||
|
# 配置日志
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format='%(asctime)s - %(levelname)s - %(message)s',
|
||||||
|
handlers=[
|
||||||
|
logging.StreamHandler(),
|
||||||
|
logging.FileHandler('import_sys_sketch.log', encoding='utf-8')
|
||||||
|
]
|
||||||
|
)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def get_sys_file_records(conn, limit=None, offset=0):
|
||||||
|
"""
|
||||||
|
从 t_sys_file 表获取系统图记录
|
||||||
|
|
||||||
|
Args:
|
||||||
|
conn: 数据库连接
|
||||||
|
limit: 限制数量(None 表示不限制)
|
||||||
|
offset: 偏移量
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
记录列表,每个元素为 (id, url, style, level3_type, level2_type, deprecated)
|
||||||
|
"""
|
||||||
|
cursor = conn.cursor()
|
||||||
|
|
||||||
|
query = f"""
|
||||||
|
SELECT id, url, style, level3_type, level2_type, deprecated
|
||||||
|
FROM {TABLE_SYS_FILE}
|
||||||
|
WHERE level1_type = 'Images'
|
||||||
|
AND style IS NOT NULL
|
||||||
|
AND style != ''
|
||||||
|
AND deprecated != 1
|
||||||
|
ORDER BY id
|
||||||
|
"""
|
||||||
|
|
||||||
|
if limit:
|
||||||
|
query += f" LIMIT {limit} OFFSET {offset}"
|
||||||
|
|
||||||
|
cursor.execute(query)
|
||||||
|
records = cursor.fetchall()
|
||||||
|
cursor.close()
|
||||||
|
|
||||||
|
return records
|
||||||
|
|
||||||
|
|
||||||
|
def get_total_count(conn):
|
||||||
|
"""获取总记录数"""
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute(f"""
|
||||||
|
SELECT COUNT(*)
|
||||||
|
FROM {TABLE_SYS_FILE}
|
||||||
|
WHERE level1_type = 'Images'
|
||||||
|
AND style IS NOT NULL
|
||||||
|
AND style != ''
|
||||||
|
AND deprecated != 1
|
||||||
|
""")
|
||||||
|
count = cursor.fetchone()[0]
|
||||||
|
cursor.close()
|
||||||
|
return count
|
||||||
|
|
||||||
|
|
||||||
|
def process_and_insert_batch(records, batch_size=1000, retry_times=3):
|
||||||
|
"""
|
||||||
|
处理并批量插入向量
|
||||||
|
|
||||||
|
Args:
|
||||||
|
records: 记录列表
|
||||||
|
batch_size: 批量大小
|
||||||
|
retry_times: 失败重试次数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(成功数量, 失败数量)
|
||||||
|
"""
|
||||||
|
success_count = 0
|
||||||
|
failed_count = 0
|
||||||
|
failed_records = []
|
||||||
|
batch_data = []
|
||||||
|
|
||||||
|
# 使用 tqdm 显示进度
|
||||||
|
with tqdm(total=len(records), desc="处理记录", unit="条") as pbar:
|
||||||
|
for idx, (sys_file_id, url, style, level3_type, level2_type, deprecated) in enumerate(records):
|
||||||
|
try:
|
||||||
|
# 计算 category
|
||||||
|
category = f"{level3_type.lower()}_{level2_type.lower()}"
|
||||||
|
|
||||||
|
# 提取特征向量
|
||||||
|
feature_vector = extract_feature_vector(url)
|
||||||
|
# 归一化,便于 IP≈cosine 度量
|
||||||
|
feature_vector = normalize_vector(feature_vector)
|
||||||
|
|
||||||
|
# 检查向量是否有效
|
||||||
|
if np.all(feature_vector == 0):
|
||||||
|
logger.warning(f"向量提取失败,跳过: {url} (id={sys_file_id})")
|
||||||
|
failed_count += 1
|
||||||
|
failed_records.append((sys_file_id, url))
|
||||||
|
pbar.update(1)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 准备数据
|
||||||
|
data_item = {
|
||||||
|
"path": url,
|
||||||
|
"sys_file_id": sys_file_id,
|
||||||
|
"style": style,
|
||||||
|
"category": category,
|
||||||
|
"is_system_sketch": 1,
|
||||||
|
"deprecated": deprecated if deprecated else 0,
|
||||||
|
"feature_vector": feature_vector.tolist()
|
||||||
|
}
|
||||||
|
|
||||||
|
batch_data.append(data_item)
|
||||||
|
|
||||||
|
# 批量写入
|
||||||
|
if len(batch_data) >= batch_size:
|
||||||
|
try:
|
||||||
|
insert_vectors(batch_data)
|
||||||
|
success_count += len(batch_data)
|
||||||
|
batch_data = []
|
||||||
|
logger.info(f"已成功插入 {success_count} 条记录")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"批量写入失败: {e}")
|
||||||
|
failed_count += len(batch_data)
|
||||||
|
failed_records.extend([(item["sys_file_id"], item["path"]) for item in batch_data])
|
||||||
|
batch_data = []
|
||||||
|
|
||||||
|
pbar.update(1)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"处理记录失败 [id={sys_file_id}, url={url}]: {e}")
|
||||||
|
failed_count += 1
|
||||||
|
failed_records.append((sys_file_id, url))
|
||||||
|
pbar.update(1)
|
||||||
|
|
||||||
|
# 写入剩余数据
|
||||||
|
if batch_data:
|
||||||
|
try:
|
||||||
|
insert_vectors(batch_data)
|
||||||
|
success_count += len(batch_data)
|
||||||
|
logger.info(f"写入剩余 {len(batch_data)} 条记录")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"写入剩余数据失败: {e}")
|
||||||
|
failed_count += len(batch_data)
|
||||||
|
failed_records.extend([(item["sys_file_id"], item["path"]) for item in batch_data])
|
||||||
|
|
||||||
|
# 重试失败记录
|
||||||
|
if failed_records and retry_times > 0:
|
||||||
|
logger.info(f"开始重试 {len(failed_records)} 条失败记录,最多重试 {retry_times} 次...")
|
||||||
|
|
||||||
|
for retry in range(retry_times):
|
||||||
|
if not failed_records:
|
||||||
|
break
|
||||||
|
|
||||||
|
retry_failed = []
|
||||||
|
with tqdm(total=len(failed_records), desc=f"重试第 {retry + 1} 次", unit="条") as pbar:
|
||||||
|
for sys_file_id, url in failed_records:
|
||||||
|
try:
|
||||||
|
# 重新查询记录信息
|
||||||
|
conn = pymysql.connect(**MYSQL_CONFIG)
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute(f"""
|
||||||
|
SELECT id, url, style, level3_type, level2_type, deprecated
|
||||||
|
FROM {TABLE_SYS_FILE}
|
||||||
|
WHERE id = %s
|
||||||
|
""", (sys_file_id,))
|
||||||
|
record = cursor.fetchone()
|
||||||
|
cursor.close()
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
if not record:
|
||||||
|
retry_failed.append((sys_file_id, url))
|
||||||
|
pbar.update(1)
|
||||||
|
continue
|
||||||
|
|
||||||
|
sys_file_id, url, style, level3_type, level2_type, deprecated = record
|
||||||
|
category = f"{level3_type.lower()}_{level2_type.lower()}"
|
||||||
|
|
||||||
|
feature_vector = extract_feature_vector(url)
|
||||||
|
feature_vector = normalize_vector(feature_vector)
|
||||||
|
if np.all(feature_vector == 0):
|
||||||
|
retry_failed.append((sys_file_id, url))
|
||||||
|
pbar.update(1)
|
||||||
|
continue
|
||||||
|
|
||||||
|
data_item = {
|
||||||
|
"path": url,
|
||||||
|
"sys_file_id": sys_file_id,
|
||||||
|
"style": style,
|
||||||
|
"category": category,
|
||||||
|
"is_system_sketch": 1,
|
||||||
|
"deprecated": deprecated if deprecated else 0,
|
||||||
|
"feature_vector": feature_vector.tolist()
|
||||||
|
}
|
||||||
|
|
||||||
|
insert_vectors([data_item])
|
||||||
|
success_count += 1
|
||||||
|
failed_count -= 1
|
||||||
|
pbar.update(1)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"重试失败 [id={sys_file_id}, url={url}]: {e}")
|
||||||
|
retry_failed.append((sys_file_id, url))
|
||||||
|
pbar.update(1)
|
||||||
|
|
||||||
|
failed_records = retry_failed
|
||||||
|
if failed_records:
|
||||||
|
logger.warning(f"第 {retry + 1} 次重试后仍有 {len(failed_records)} 条记录失败")
|
||||||
|
|
||||||
|
return success_count, failed_count, failed_records
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""主函数"""
|
||||||
|
parser = argparse.ArgumentParser(description='从 t_sys_file 导入系统图向量到 Milvus')
|
||||||
|
parser.add_argument('--batch-size', type=int, default=1000, help='批量处理大小(默认:1000)')
|
||||||
|
parser.add_argument('--retry-times', type=int, default=3, help='失败重试次数(默认:3)')
|
||||||
|
parser.add_argument('--limit', type=int, default=None, help='限制处理数量(用于测试,默认:不限制)')
|
||||||
|
parser.add_argument('--offset', type=int, default=0, help='起始偏移量(默认:0)')
|
||||||
|
parser.add_argument('--skip-create-collection', action='store_true', help='跳过创建集合(如果集合已存在)')
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
logger.info("=" * 60)
|
||||||
|
logger.info("开始从 t_sys_file 导入系统图向量到 Milvus")
|
||||||
|
logger.info("=" * 60)
|
||||||
|
logger.info(f"配置参数:")
|
||||||
|
logger.info(f" - 批量大小: {args.batch_size}")
|
||||||
|
logger.info(f" - 重试次数: {args.retry_times}")
|
||||||
|
logger.info(f" - 限制数量: {args.limit if args.limit else '不限制'}")
|
||||||
|
logger.info(f" - 起始偏移: {args.offset}")
|
||||||
|
logger.info("=" * 60)
|
||||||
|
|
||||||
|
# 1. 创建 Milvus 集合
|
||||||
|
if not args.skip_create_collection:
|
||||||
|
logger.info("创建 Milvus 集合...")
|
||||||
|
try:
|
||||||
|
create_collection()
|
||||||
|
logger.info("Milvus 集合创建成功(或已存在)")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"创建 Milvus 集合失败: {e}")
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
logger.info("跳过创建集合")
|
||||||
|
|
||||||
|
# 2. 连接数据库
|
||||||
|
logger.info("连接数据库...")
|
||||||
|
try:
|
||||||
|
conn = pymysql.connect(**MYSQL_CONFIG)
|
||||||
|
logger.info("数据库连接成功")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"数据库连接失败: {e}")
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 3. 获取总记录数
|
||||||
|
total_count = get_total_count(conn)
|
||||||
|
logger.info(f"找到 {total_count} 条系统图记录")
|
||||||
|
|
||||||
|
if total_count == 0:
|
||||||
|
logger.warning("没有找到系统图数据")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 4. 获取记录
|
||||||
|
logger.info("获取记录...")
|
||||||
|
records = get_sys_file_records(conn, limit=args.limit, offset=args.offset)
|
||||||
|
logger.info(f"获取到 {len(records)} 条记录")
|
||||||
|
|
||||||
|
if not records:
|
||||||
|
logger.warning("没有获取到记录")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 5. 处理并插入
|
||||||
|
logger.info("开始处理记录...")
|
||||||
|
success_count, failed_count, failed_records = process_and_insert_batch(
|
||||||
|
records,
|
||||||
|
batch_size=args.batch_size,
|
||||||
|
retry_times=args.retry_times
|
||||||
|
)
|
||||||
|
|
||||||
|
# 6. 输出结果
|
||||||
|
logger.info("=" * 60)
|
||||||
|
logger.info("导入完成!")
|
||||||
|
logger.info(f" - 成功: {success_count} 条")
|
||||||
|
logger.info(f" - 失败: {failed_count} 条")
|
||||||
|
if failed_records:
|
||||||
|
logger.warning(f" - 失败记录列表(前10条):")
|
||||||
|
for sys_file_id, url in failed_records[:10]:
|
||||||
|
logger.warning(f" ID={sys_file_id}, URL={url}")
|
||||||
|
if len(failed_records) > 10:
|
||||||
|
logger.warning(f" ... 还有 {len(failed_records) - 10} 条失败记录")
|
||||||
|
logger.info("=" * 60)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"处理过程中发生错误: {e}", exc_info=True)
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
logger.info("数据库连接已关闭")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
|
||||||
347
app/service/recommendation_system/incremental_listener.py
Normal file
347
app/service/recommendation_system/incremental_listener.py
Normal file
@@ -0,0 +1,347 @@
|
|||||||
|
"""
|
||||||
|
增量监听模块
|
||||||
|
实时监听 user_preference 表的新增记录,更新用户偏好向量
|
||||||
|
"""
|
||||||
|
import logging
|
||||||
|
import math
|
||||||
|
import pymysql
|
||||||
|
import numpy as np
|
||||||
|
from typing import List, Dict, Set, Tuple, Optional
|
||||||
|
from datetime import datetime
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
|
from apscheduler.schedulers.background import BackgroundScheduler
|
||||||
|
from apscheduler.schedulers.blocking import BlockingScheduler
|
||||||
|
|
||||||
|
from app.service.recommendation_system.config import (
|
||||||
|
MYSQL_CONFIG, TABLE_USER_PREFERENCE_LOG, TABLE_SYS_FILE,
|
||||||
|
RECOMMENDATION_CONFIG, REDIS_KEY_USER_PREF_PREFIX
|
||||||
|
)
|
||||||
|
from app.service.recommendation_system.vector_utils import extract_feature_vector, compute_weighted_average, normalize_vector
|
||||||
|
from app.service.recommendation_system.milvus_client import query_vectors_by_paths, insert_vectors
|
||||||
|
from app.service.utils.redis_utils import Redis
|
||||||
|
import json
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class IncrementalListener:
|
||||||
|
"""增量监听器"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.last_process_time = None
|
||||||
|
self.processed_combinations: Set[Tuple[int, str]] = set() # 已处理的 (account_id, category) 组合
|
||||||
|
self.listen_interval = RECOMMENDATION_CONFIG["listen_interval_sec"]
|
||||||
|
|
||||||
|
def get_new_like_records(self) -> List[Tuple]:
|
||||||
|
"""
|
||||||
|
获取新增点赞记录
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
记录列表,每个元素为 (id, account_id, path, category, style, data_time, is_system_sketch, sys_file_id)
|
||||||
|
"""
|
||||||
|
conn = None
|
||||||
|
try:
|
||||||
|
conn = pymysql.connect(**MYSQL_CONFIG)
|
||||||
|
cursor = conn.cursor()
|
||||||
|
|
||||||
|
if self.last_process_time is None:
|
||||||
|
# 第一次运行,查询最近30分钟的数据
|
||||||
|
cursor.execute(f"""
|
||||||
|
SELECT id, account_id, path, category, style, data_time
|
||||||
|
FROM {TABLE_USER_PREFERENCE_LOG}
|
||||||
|
WHERE data_time > DATE_SUB(NOW(), INTERVAL 30 MINUTE)
|
||||||
|
ORDER BY data_time
|
||||||
|
""")
|
||||||
|
else:
|
||||||
|
# 基于上次处理时间查询
|
||||||
|
cursor.execute(f"""
|
||||||
|
SELECT id, account_id, path, category, style, data_time
|
||||||
|
FROM {TABLE_USER_PREFERENCE_LOG}
|
||||||
|
WHERE data_time > %s
|
||||||
|
ORDER BY data_time
|
||||||
|
""", (self.last_process_time,))
|
||||||
|
|
||||||
|
records = cursor.fetchall()
|
||||||
|
return records
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取新增点赞记录失败: {e}", exc_info=True)
|
||||||
|
return []
|
||||||
|
finally:
|
||||||
|
if conn:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
def process_new_records(self, records: List[Tuple]):
|
||||||
|
"""
|
||||||
|
处理新增记录
|
||||||
|
|
||||||
|
Args:
|
||||||
|
records: 记录列表
|
||||||
|
"""
|
||||||
|
if not records:
|
||||||
|
return
|
||||||
|
|
||||||
|
# 按用户+类别分组
|
||||||
|
user_category_records = defaultdict(list)
|
||||||
|
for record in records:
|
||||||
|
account_id = record[1]
|
||||||
|
category = record[3]
|
||||||
|
if category: # 只处理有类别的记录
|
||||||
|
user_category_records[(account_id, category)].append(record)
|
||||||
|
|
||||||
|
# 去重:只处理一次每个 (account_id, category) 组合
|
||||||
|
to_process = []
|
||||||
|
for (account_id, category), recs in user_category_records.items():
|
||||||
|
if (account_id, category) not in self.processed_combinations:
|
||||||
|
to_process.append((account_id, category, recs))
|
||||||
|
self.processed_combinations.add((account_id, category))
|
||||||
|
|
||||||
|
logger.info(f"需要处理 {len(to_process)} 个用户-类别组合")
|
||||||
|
|
||||||
|
# 处理每个组合
|
||||||
|
for account_id, category, recs in to_process:
|
||||||
|
try:
|
||||||
|
self.update_user_preference_vector(account_id, category)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"更新用户偏好向量失败 [user={account_id}, category={category}]: {e}", exc_info=True)
|
||||||
|
|
||||||
|
# 更新最后处理时间
|
||||||
|
if records:
|
||||||
|
self.last_process_time = records[-1][5] # data_time
|
||||||
|
# 重置去重集合,确保下次周期不会跳过同一用户-类别
|
||||||
|
self.processed_combinations.clear()
|
||||||
|
|
||||||
|
def update_user_preference_vector(self, account_id: int, category: str):
|
||||||
|
"""
|
||||||
|
更新用户偏好向量
|
||||||
|
|
||||||
|
Args:
|
||||||
|
account_id: 用户ID
|
||||||
|
category: 类别
|
||||||
|
"""
|
||||||
|
conn = None
|
||||||
|
try:
|
||||||
|
conn = pymysql.connect(**MYSQL_CONFIG)
|
||||||
|
cursor = conn.cursor()
|
||||||
|
|
||||||
|
# 1. 获取该用户该类别的所有点赞记录
|
||||||
|
cursor.execute(f"""
|
||||||
|
SELECT path, data_time
|
||||||
|
FROM {TABLE_USER_PREFERENCE_LOG}
|
||||||
|
WHERE account_id = %s AND category = %s
|
||||||
|
ORDER BY data_time DESC
|
||||||
|
""", (account_id, category))
|
||||||
|
|
||||||
|
like_records = cursor.fetchall()
|
||||||
|
|
||||||
|
if not like_records:
|
||||||
|
return
|
||||||
|
|
||||||
|
# 2. 批量查询点赞次数
|
||||||
|
paths = [r[0] for r in like_records]
|
||||||
|
placeholders = ','.join(['%s'] * len(paths))
|
||||||
|
cursor.execute(f"""
|
||||||
|
SELECT path, COUNT(*) as like_count
|
||||||
|
FROM {TABLE_USER_PREFERENCE_LOG}
|
||||||
|
WHERE account_id = %s AND category = %s AND path IN ({placeholders})
|
||||||
|
GROUP BY path
|
||||||
|
""", (account_id, category) + tuple(paths))
|
||||||
|
|
||||||
|
like_counts = {row[0]: row[1] for row in cursor.fetchall()}
|
||||||
|
|
||||||
|
# 3. 批量获取向量
|
||||||
|
vectors_dict = query_vectors_by_paths(paths)
|
||||||
|
|
||||||
|
# 处理查询不到的 path(新用户图或异常情况)
|
||||||
|
missing_paths = [p for p in paths if p not in vectors_dict]
|
||||||
|
if missing_paths:
|
||||||
|
logger.info(f"用户 {account_id} 类别 {category} 有 {len(missing_paths)} 个 path 需要实时计算向量")
|
||||||
|
self._compute_and_insert_missing_vectors(missing_paths, conn)
|
||||||
|
# 重新查询
|
||||||
|
vectors_dict = query_vectors_by_paths(paths)
|
||||||
|
|
||||||
|
# 4. 计算权重并加权平均
|
||||||
|
vectors = []
|
||||||
|
weights = []
|
||||||
|
K_half = RECOMMENDATION_CONFIG["K_half"]
|
||||||
|
|
||||||
|
for k, (path, data_time) in enumerate(like_records, 1):
|
||||||
|
if path not in vectors_dict:
|
||||||
|
continue
|
||||||
|
|
||||||
|
vector_data = vectors_dict[path]
|
||||||
|
feature_vector = np.array(vector_data["feature_vector"])
|
||||||
|
|
||||||
|
# 时间衰减权重
|
||||||
|
d_k = 0.5 ** (k / K_half)
|
||||||
|
|
||||||
|
# 点赞次数权重
|
||||||
|
like_count = like_counts.get(path, 1)
|
||||||
|
p_i = 1 + math.log(1 + like_count)
|
||||||
|
|
||||||
|
# 综合权重
|
||||||
|
w_i = d_k * p_i
|
||||||
|
|
||||||
|
vectors.append(feature_vector)
|
||||||
|
weights.append(w_i)
|
||||||
|
|
||||||
|
if not vectors:
|
||||||
|
logger.warning(f"用户 {account_id} 类别 {category} 没有有效向量")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 5. 计算加权平均并做 L2 归一化,IP≈cosine
|
||||||
|
preference_vector = compute_weighted_average(vectors, weights)
|
||||||
|
preference_vector = normalize_vector(preference_vector)
|
||||||
|
|
||||||
|
# 6. 写入 Redis
|
||||||
|
key = f"{REDIS_KEY_USER_PREF_PREFIX}:{account_id}:{category}"
|
||||||
|
vector_json = json.dumps(preference_vector.tolist())
|
||||||
|
Redis.write(
|
||||||
|
key=key,
|
||||||
|
value=vector_json,
|
||||||
|
expire=RECOMMENDATION_CONFIG["redis_expire_seconds"]
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug(f"用户偏好向量更新成功 [user={account_id}, category={category}]")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"更新用户偏好向量失败 [user={account_id}, category={category}]: {e}", exc_info=True)
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
if conn:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
def _compute_and_insert_missing_vectors(self, paths: List[str], conn: pymysql.connections.Connection):
|
||||||
|
"""
|
||||||
|
计算并插入缺失的向量
|
||||||
|
|
||||||
|
Args:
|
||||||
|
paths: 缺失的 path 列表
|
||||||
|
conn: 数据库连接
|
||||||
|
"""
|
||||||
|
cursor = conn.cursor()
|
||||||
|
data_to_insert = []
|
||||||
|
|
||||||
|
for path in paths:
|
||||||
|
try:
|
||||||
|
# 判断数据来源(查询 t_sys_file 表)
|
||||||
|
cursor.execute(f"""
|
||||||
|
SELECT id, url, style, level3_type, level2_type, deprecated
|
||||||
|
FROM {TABLE_SYS_FILE}
|
||||||
|
WHERE url = %s
|
||||||
|
LIMIT 1
|
||||||
|
""", (path,))
|
||||||
|
|
||||||
|
sys_file = cursor.fetchone()
|
||||||
|
|
||||||
|
# 提取特征向量
|
||||||
|
feature_vector = extract_feature_vector(path)
|
||||||
|
|
||||||
|
if np.all(feature_vector == 0):
|
||||||
|
logger.warning(f"向量提取失败,跳过: {path}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if sys_file:
|
||||||
|
# 系统图
|
||||||
|
sys_file_id, url, style, level3_type, level2_type, deprecated = sys_file
|
||||||
|
category = f"{level3_type.lower()}_{level2_type.lower()}"
|
||||||
|
|
||||||
|
data_item = {
|
||||||
|
"path": path,
|
||||||
|
"sys_file_id": sys_file_id,
|
||||||
|
"style": style,
|
||||||
|
"category": category,
|
||||||
|
"is_system_sketch": 1,
|
||||||
|
"deprecated": deprecated if deprecated else 0,
|
||||||
|
"feature_vector": feature_vector.tolist()
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
# 用户图
|
||||||
|
# 从 user_preference 获取 category(如果有)
|
||||||
|
cursor.execute(f"""
|
||||||
|
SELECT category
|
||||||
|
FROM {TABLE_USER_PREFERENCE_LOG}
|
||||||
|
WHERE path = %s AND category IS NOT NULL
|
||||||
|
LIMIT 1
|
||||||
|
""", (path,))
|
||||||
|
|
||||||
|
category_result = cursor.fetchone()
|
||||||
|
category = category_result[0] if category_result else None
|
||||||
|
|
||||||
|
data_item = {
|
||||||
|
"path": path,
|
||||||
|
"sys_file_id": None,
|
||||||
|
"style": None,
|
||||||
|
"category": category,
|
||||||
|
"is_system_sketch": 0,
|
||||||
|
"deprecated": 0,
|
||||||
|
"feature_vector": feature_vector.tolist()
|
||||||
|
}
|
||||||
|
|
||||||
|
data_to_insert.append(data_item)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"处理缺失向量失败 [{path}]: {e}")
|
||||||
|
|
||||||
|
# 批量插入
|
||||||
|
if data_to_insert:
|
||||||
|
try:
|
||||||
|
insert_vectors(data_to_insert)
|
||||||
|
logger.info(f"成功插入 {len(data_to_insert)} 个缺失向量")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"插入缺失向量失败: {e}")
|
||||||
|
|
||||||
|
def process_once(self):
|
||||||
|
"""单次轮询任务,供调度器调用"""
|
||||||
|
try:
|
||||||
|
records = self.get_new_like_records()
|
||||||
|
|
||||||
|
if records:
|
||||||
|
logger.info(f"发现 {len(records)} 条新增记录")
|
||||||
|
self.process_new_records(records)
|
||||||
|
else:
|
||||||
|
logger.debug("没有新增记录")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"监听轮询异常: {e}", exc_info=True)
|
||||||
|
|
||||||
|
|
||||||
|
def start_background_listener(scheduler: BackgroundScheduler):
|
||||||
|
"""将增量监听任务注册到后台调度器"""
|
||||||
|
# 降低 apscheduler 的日志级别,避免大量刷屏
|
||||||
|
logging.getLogger('apscheduler.executors.default').setLevel(logging.WARNING)
|
||||||
|
logging.getLogger('apscheduler.scheduler').setLevel(logging.WARNING)
|
||||||
|
|
||||||
|
listener = IncrementalListener()
|
||||||
|
scheduler.add_job(
|
||||||
|
listener.process_once,
|
||||||
|
"interval",
|
||||||
|
seconds=listener.listen_interval,
|
||||||
|
max_instances=1,
|
||||||
|
coalesce=True,
|
||||||
|
id="recommendation_incremental_listener",
|
||||||
|
replace_existing=True,
|
||||||
|
)
|
||||||
|
logger.info("增量监听任务已注册到调度器")
|
||||||
|
|
||||||
|
|
||||||
|
def start_blocking_listener():
|
||||||
|
"""以阻塞方式启动调度器(用于独立脚本运行)"""
|
||||||
|
listener = IncrementalListener()
|
||||||
|
scheduler = BlockingScheduler()
|
||||||
|
scheduler.add_job(
|
||||||
|
listener.process_once,
|
||||||
|
"interval",
|
||||||
|
seconds=listener.listen_interval,
|
||||||
|
max_instances=1,
|
||||||
|
coalesce=True,
|
||||||
|
id="recommendation_incremental_listener",
|
||||||
|
replace_existing=True,
|
||||||
|
)
|
||||||
|
logger.info("增量监听调度器已启动(BlockingScheduler)")
|
||||||
|
scheduler.start()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
start_blocking_listener()
|
||||||
|
|
||||||
332
app/service/recommendation_system/milvus_client.py
Normal file
332
app/service/recommendation_system/milvus_client.py
Normal file
@@ -0,0 +1,332 @@
|
|||||||
|
"""
|
||||||
|
Milvus 客户端封装
|
||||||
|
"""
|
||||||
|
import logging
|
||||||
|
from typing import List, Dict, Optional, Any
|
||||||
|
import numpy as np
|
||||||
|
from pymilvus import MilvusClient, FieldSchema, CollectionSchema, DataType, connections, Collection
|
||||||
|
|
||||||
|
from app.core.config import settings
|
||||||
|
from app.service.recommendation_system.config import MILVUS_COLLECTION_SKETCH_VECTORS, RECOMMENDATION_CONFIG
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Milvus 客户端(单例)
|
||||||
|
_milvus_client = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_milvus_client() -> MilvusClient:
|
||||||
|
"""获取 Milvus 客户端(单例模式)"""
|
||||||
|
global _milvus_client
|
||||||
|
if _milvus_client is None:
|
||||||
|
try:
|
||||||
|
_milvus_client = MilvusClient(
|
||||||
|
uri=settings.MILVUS_URL,
|
||||||
|
token=settings.MILVUS_TOKEN,
|
||||||
|
db_name="",
|
||||||
|
)
|
||||||
|
logger.info("Milvus 客户端连接成功")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Milvus 客户端连接失败: {e}")
|
||||||
|
raise
|
||||||
|
return _milvus_client
|
||||||
|
|
||||||
|
|
||||||
|
def create_collection():
|
||||||
|
"""
|
||||||
|
创建 Milvus 集合 sketch_vectors
|
||||||
|
|
||||||
|
集合结构:
|
||||||
|
- path (PK, varchar(512)) - 主键,MinIO 逻辑 URL
|
||||||
|
- sys_file_id (int64, 可为NULL) - 系统文件ID
|
||||||
|
- style (varchar(50), 可为NULL) - 风格样式
|
||||||
|
- category (varchar(100), 可为NULL) - 类别
|
||||||
|
- is_system_sketch (int8, 默认 1) - 标记字段:1-系统图,0-用户图
|
||||||
|
- deprecated (int8, 默认 0) - 是否废弃
|
||||||
|
- feature_vector (FloatVector(2048)) - 2048维特征向量
|
||||||
|
"""
|
||||||
|
client = get_milvus_client()
|
||||||
|
|
||||||
|
# 检查集合是否已存在
|
||||||
|
collections = client.list_collections()
|
||||||
|
if MILVUS_COLLECTION_SKETCH_VECTORS in collections:
|
||||||
|
logger.info(f"集合 {MILVUS_COLLECTION_SKETCH_VECTORS} 已存在")
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 解析 Milvus URL
|
||||||
|
# 处理 http://host.docker.internal:19530 格式
|
||||||
|
url_clean = settings.MILVUS_URL.replace("http://", "").replace("https://", "")
|
||||||
|
if ":" in url_clean:
|
||||||
|
host, port_str = url_clean.split(":", 1)
|
||||||
|
port = int(port_str)
|
||||||
|
else:
|
||||||
|
host = url_clean
|
||||||
|
port = 19530
|
||||||
|
|
||||||
|
# 使用传统 API 创建集合(更可靠)
|
||||||
|
# 连接到 Milvus(如果未连接)
|
||||||
|
try:
|
||||||
|
connections.connect(
|
||||||
|
alias=settings.MILVUS_ALIAS,
|
||||||
|
host=host,
|
||||||
|
port=port,
|
||||||
|
token=settings.MILVUS_TOKEN if settings.MILVUS_TOKEN else None
|
||||||
|
)
|
||||||
|
logger.info(f"已连接到 Milvus: {host}:{port}")
|
||||||
|
except Exception as conn_e:
|
||||||
|
# 如果连接已存在,忽略错误
|
||||||
|
if "already exists" in str(conn_e).lower() or "Connection already exists" in str(conn_e):
|
||||||
|
logger.info("Milvus 连接已存在")
|
||||||
|
else:
|
||||||
|
logger.warning(f"连接 Milvus 时出现警告: {conn_e}")
|
||||||
|
|
||||||
|
# 定义字段
|
||||||
|
fields = [
|
||||||
|
FieldSchema(name="path", dtype=DataType.VARCHAR, is_primary=True, max_length=512),
|
||||||
|
FieldSchema(name="sys_file_id", dtype=DataType.INT64),
|
||||||
|
FieldSchema(name="style", dtype=DataType.VARCHAR, max_length=50),
|
||||||
|
FieldSchema(name="category", dtype=DataType.VARCHAR, max_length=50),
|
||||||
|
FieldSchema(name="is_system_sketch", dtype=DataType.INT8),
|
||||||
|
FieldSchema(name="deprecated", dtype=DataType.INT8),
|
||||||
|
FieldSchema(
|
||||||
|
name="feature_vector",
|
||||||
|
dtype=DataType.FLOAT_VECTOR,
|
||||||
|
dim=RECOMMENDATION_CONFIG["vector_dim"]
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
# 创建 schema
|
||||||
|
schema = CollectionSchema(
|
||||||
|
fields=fields,
|
||||||
|
description="Sketch vectors collection for recommendation system"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 创建集合
|
||||||
|
collection = Collection(
|
||||||
|
name=MILVUS_COLLECTION_SKETCH_VECTORS,
|
||||||
|
schema=schema,
|
||||||
|
using=settings.MILVUS_ALIAS
|
||||||
|
)
|
||||||
|
|
||||||
|
# 创建索引
|
||||||
|
# 注意:使用 IP(内积)作为度量类型,与搜索时保持一致
|
||||||
|
# 如果向量已归一化,IP 等价于 COSINE
|
||||||
|
index_params = {
|
||||||
|
"metric_type": "IP", # 内积(Inner Product)
|
||||||
|
"index_type": "IVF_FLAT",
|
||||||
|
"params": {"nlist": 1024}
|
||||||
|
}
|
||||||
|
|
||||||
|
collection.create_index(
|
||||||
|
field_name="feature_vector",
|
||||||
|
index_params=index_params
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"集合 {MILVUS_COLLECTION_SKETCH_VECTORS} 创建成功")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"创建集合失败: {e}", exc_info=True)
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
def insert_vectors(data: List[Dict[str, Any]]):
|
||||||
|
"""
|
||||||
|
批量插入向量到 Milvus
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: 数据列表,每个元素包含:
|
||||||
|
- path: str
|
||||||
|
- sys_file_id: int (可选)
|
||||||
|
- style: str (可选)
|
||||||
|
- category: str (可选)
|
||||||
|
- is_system_sketch: int (默认 1)
|
||||||
|
- deprecated: int (默认 0)
|
||||||
|
- feature_vector: List[float] (2048维)
|
||||||
|
"""
|
||||||
|
if not data:
|
||||||
|
return
|
||||||
|
|
||||||
|
client = get_milvus_client()
|
||||||
|
|
||||||
|
try:
|
||||||
|
client.insert(
|
||||||
|
collection_name=MILVUS_COLLECTION_SKETCH_VECTORS,
|
||||||
|
data=data
|
||||||
|
)
|
||||||
|
logger.info(f"成功插入 {len(data)} 条向量数据")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"插入向量失败: {e}", exc_info=True)
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
def query_vectors_by_paths(paths: List[str]) -> Dict[str, Dict]:
|
||||||
|
"""
|
||||||
|
根据 path 列表批量查询向量
|
||||||
|
|
||||||
|
Args:
|
||||||
|
paths: path 列表
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
{path: {feature_vector: [...], ...}} 字典
|
||||||
|
"""
|
||||||
|
if not paths:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
client = get_milvus_client()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 构建查询表达式
|
||||||
|
# 使用 filter 参数而不是 expr(根据 pymilvus MilvusClient API)
|
||||||
|
# 对于字符串列表,使用单引号包裹每个值
|
||||||
|
path_list = ", ".join([f"'{p}'" for p in paths])
|
||||||
|
filter_expr = f"path in [{path_list}]"
|
||||||
|
|
||||||
|
results = client.query(
|
||||||
|
collection_name=MILVUS_COLLECTION_SKETCH_VECTORS,
|
||||||
|
filter=filter_expr,
|
||||||
|
output_fields=["path", "feature_vector", "style", "category", "sys_file_id", "is_system_sketch", "deprecated"]
|
||||||
|
)
|
||||||
|
|
||||||
|
# 转换为字典
|
||||||
|
result_dict = {}
|
||||||
|
for r in results:
|
||||||
|
result_dict[r["path"]] = r
|
||||||
|
|
||||||
|
return result_dict
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"查询向量失败: {e}", exc_info=True)
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
def search_similar_vectors(
|
||||||
|
query_vector: np.ndarray,
|
||||||
|
category: str,
|
||||||
|
topk: int = 500,
|
||||||
|
style: Optional[str] = None,
|
||||||
|
style_boost_ratio: float = 0.2
|
||||||
|
) -> List[Dict]:
|
||||||
|
"""
|
||||||
|
向量相似度检索
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query_vector: 查询向量(2048维)
|
||||||
|
category: 类别过滤
|
||||||
|
topk: 返回数量
|
||||||
|
style: 风格过滤(可选)- 当提供时,会给对应style的结果加分
|
||||||
|
style_boost_ratio: 风格加分比例(默认0.1,即10%)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
检索结果列表,每个元素包含 path, score, style, category 等字段
|
||||||
|
"""
|
||||||
|
client = get_milvus_client()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 如果没有指定style,使用原始逻辑
|
||||||
|
if not style:
|
||||||
|
filter_expr = f"category == '{category}' && deprecated == 0"
|
||||||
|
results = client.search(
|
||||||
|
collection_name=MILVUS_COLLECTION_SKETCH_VECTORS,
|
||||||
|
data=[query_vector.tolist()],
|
||||||
|
anns_field="feature_vector",
|
||||||
|
search_params={"metric_type": "IP", "params": {"nprobe": 10}},
|
||||||
|
limit=topk,
|
||||||
|
filter=filter_expr,
|
||||||
|
output_fields=["path", "style", "category", "sys_file_id"]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# 有style参数时,使用两阶段搜索策略
|
||||||
|
|
||||||
|
# 第一阶段:搜索匹配style的向量,使用boosted query vector
|
||||||
|
filter_expr_style = f"category == '{category}' && deprecated == 0 && style == '{style}'"
|
||||||
|
boosted_query = query_vector * (1 + style_boost_ratio)
|
||||||
|
results_style = client.search(
|
||||||
|
collection_name=MILVUS_COLLECTION_SKETCH_VECTORS,
|
||||||
|
data=[boosted_query.tolist()],
|
||||||
|
anns_field="feature_vector",
|
||||||
|
search_params={"metric_type": "IP", "params": {"nprobe": 10}},
|
||||||
|
limit=topk,
|
||||||
|
filter=filter_expr_style,
|
||||||
|
output_fields=["path", "style", "category", "sys_file_id"]
|
||||||
|
)
|
||||||
|
|
||||||
|
# 第二阶段:搜索其他style的向量
|
||||||
|
filter_expr_others = f"category == '{category}' && deprecated == 0 && style != '{style}'"
|
||||||
|
results_others = client.search(
|
||||||
|
collection_name=MILVUS_COLLECTION_SKETCH_VECTORS,
|
||||||
|
data=[query_vector.tolist()],
|
||||||
|
anns_field="feature_vector",
|
||||||
|
search_params={"metric_type": "IP", "params": {"nprobe": 10}},
|
||||||
|
limit=topk,
|
||||||
|
filter=filter_expr_others,
|
||||||
|
output_fields=["path", "style", "category", "sys_file_id"]
|
||||||
|
)
|
||||||
|
|
||||||
|
# 合并结果
|
||||||
|
results = []
|
||||||
|
if results_style and len(results_style) > 0:
|
||||||
|
results.extend(results_style[0])
|
||||||
|
if results_others and len(results_others) > 0:
|
||||||
|
results.extend(results_others[0])
|
||||||
|
|
||||||
|
# 转换为单个结果列表格式
|
||||||
|
results = [results] if results else []
|
||||||
|
|
||||||
|
# 格式化结果
|
||||||
|
formatted_results = []
|
||||||
|
if results and len(results) > 0:
|
||||||
|
for hit in results[0]:
|
||||||
|
formatted_results.append({
|
||||||
|
"path": hit.get("entity", {}).get("path", ""),
|
||||||
|
"score": hit.get("distance", 0.0),
|
||||||
|
"style": hit.get("entity", {}).get("style", ""),
|
||||||
|
"category": hit.get("entity", {}).get("category", ""),
|
||||||
|
"sys_file_id": hit.get("entity", {}).get("sys_file_id")
|
||||||
|
})
|
||||||
|
|
||||||
|
# 按分数排序并返回topk
|
||||||
|
formatted_results.sort(key=lambda x: x["score"], reverse=True)
|
||||||
|
return formatted_results[:topk]
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"向量检索失败: {e}", exc_info=True)
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
def query_random_candidates(category: str, style: Optional[str] = None, limit: int = 10) -> List[Dict]:
|
||||||
|
"""
|
||||||
|
随机查询候选(用于探索分支)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
category: 类别
|
||||||
|
style: 风格(可选)
|
||||||
|
limit: 返回数量
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
候选列表
|
||||||
|
"""
|
||||||
|
client = get_milvus_client()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 构建过滤表达式
|
||||||
|
filter_expr = f"category == '{category}' && deprecated == 0"
|
||||||
|
if style:
|
||||||
|
filter_expr += f" && style == '{style}'"
|
||||||
|
|
||||||
|
# 查询所有符合条件的记录
|
||||||
|
results = client.query(
|
||||||
|
collection_name=MILVUS_COLLECTION_SKETCH_VECTORS,
|
||||||
|
filter=filter_expr,
|
||||||
|
output_fields=["path", "style", "category"],
|
||||||
|
limit=10000
|
||||||
|
)
|
||||||
|
|
||||||
|
# 随机选择
|
||||||
|
if len(results) > limit:
|
||||||
|
import random
|
||||||
|
results = random.sample(results, limit)
|
||||||
|
|
||||||
|
return results
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"随机查询候选失败: {e}", exc_info=True)
|
||||||
|
return []
|
||||||
557
app/service/recommendation_system/precompute.py
Normal file
557
app/service/recommendation_system/precompute.py
Normal file
@@ -0,0 +1,557 @@
|
|||||||
|
"""
|
||||||
|
预计算模块
|
||||||
|
包含:数据库表结构优化、Milvus集合创建、系统图向量预计算、初始用户偏好向量生成
|
||||||
|
"""
|
||||||
|
import logging
|
||||||
|
import math
|
||||||
|
import pymysql
|
||||||
|
import numpy as np
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import List, Dict, Tuple, Optional
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
|
from app.service.recommendation_system.config import (
|
||||||
|
MYSQL_CONFIG, TABLE_USER_PREFERENCE_LOG, TABLE_SYS_FILE,
|
||||||
|
RECOMMENDATION_CONFIG, REDIS_KEY_USER_PREF_PREFIX
|
||||||
|
)
|
||||||
|
from app.service.recommendation_system.vector_utils import extract_feature_vector, normalize_vector, compute_weighted_average
|
||||||
|
from app.service.recommendation_system.milvus_client import (
|
||||||
|
create_collection, insert_vectors, query_vectors_by_paths
|
||||||
|
)
|
||||||
|
from app.service.utils.redis_utils import Redis
|
||||||
|
import json
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def optimize_database_table():
|
||||||
|
"""
|
||||||
|
优化 user_preference 表结构
|
||||||
|
添加冗余字段和索引
|
||||||
|
"""
|
||||||
|
conn = None
|
||||||
|
try:
|
||||||
|
conn = pymysql.connect(**MYSQL_CONFIG)
|
||||||
|
cursor = conn.cursor()
|
||||||
|
|
||||||
|
# 1. 添加冗余字段
|
||||||
|
logger.info("添加冗余字段...")
|
||||||
|
alter_sqls = [
|
||||||
|
f"ALTER TABLE {TABLE_USER_PREFERENCE_LOG} ADD COLUMN category VARCHAR(100) COMMENT '类别:lower(level3_type + \"_\" + level2_type)'",
|
||||||
|
f"ALTER TABLE {TABLE_USER_PREFERENCE_LOG} ADD COLUMN style VARCHAR(50) COMMENT '风格样式'",
|
||||||
|
f"ALTER TABLE {TABLE_USER_PREFERENCE_LOG} ADD COLUMN is_system_sketch TINYINT(1) DEFAULT 1 COMMENT '是否为系统图(1-是,0-用户图)'",
|
||||||
|
f"ALTER TABLE {TABLE_USER_PREFERENCE_LOG} ADD COLUMN sys_file_id BIGINT NULL COMMENT '系统文件ID'",
|
||||||
|
]
|
||||||
|
|
||||||
|
for sql in alter_sqls:
|
||||||
|
try:
|
||||||
|
cursor.execute(sql)
|
||||||
|
logger.info(f"执行成功: {sql[:50]}...")
|
||||||
|
except Exception as e:
|
||||||
|
if "Duplicate column name" in str(e):
|
||||||
|
logger.info(f"字段已存在,跳过: {sql[:50]}...")
|
||||||
|
else:
|
||||||
|
logger.warning(f"执行失败: {sql[:50]}... 错误: {e}")
|
||||||
|
|
||||||
|
# 2. 创建索引(MySQL 不支持 IF NOT EXISTS,需要先检查)
|
||||||
|
logger.info("创建索引...")
|
||||||
|
index_definitions = [
|
||||||
|
("idx_account_category_time", ["account_id", "category", "data_time"]),
|
||||||
|
("idx_account_path", ["account_id", "path"]),
|
||||||
|
]
|
||||||
|
|
||||||
|
for index_name, columns in index_definitions:
|
||||||
|
try:
|
||||||
|
# 检查索引是否已存在
|
||||||
|
cursor.execute(f"""
|
||||||
|
SELECT COUNT(*)
|
||||||
|
FROM information_schema.statistics
|
||||||
|
WHERE table_schema = DATABASE()
|
||||||
|
AND table_name = '{TABLE_USER_PREFERENCE_LOG}'
|
||||||
|
AND index_name = '{index_name}'
|
||||||
|
""")
|
||||||
|
exists = cursor.fetchone()[0] > 0
|
||||||
|
|
||||||
|
if exists:
|
||||||
|
logger.info(f"索引已存在,跳过: {index_name}")
|
||||||
|
else:
|
||||||
|
# 创建索引
|
||||||
|
columns_str = ', '.join(columns)
|
||||||
|
create_sql = f"CREATE INDEX {index_name} ON {TABLE_USER_PREFERENCE_LOG}({columns_str})"
|
||||||
|
cursor.execute(create_sql)
|
||||||
|
logger.info(f"索引创建成功: {index_name}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"索引创建失败: {index_name} 错误: {e}")
|
||||||
|
|
||||||
|
conn.commit()
|
||||||
|
logger.info("数据库表结构优化完成")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"数据库表结构优化失败: {e}", exc_info=True)
|
||||||
|
if conn:
|
||||||
|
conn.rollback()
|
||||||
|
finally:
|
||||||
|
if conn:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
|
||||||
|
def migrate_historical_data(batch_size: int = 1000):
|
||||||
|
"""
|
||||||
|
历史数据迁移:批量更新冗余字段
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch_size: 每批处理数量
|
||||||
|
"""
|
||||||
|
conn = None
|
||||||
|
try:
|
||||||
|
conn = pymysql.connect(**MYSQL_CONFIG)
|
||||||
|
cursor = conn.cursor()
|
||||||
|
|
||||||
|
# 查询需要更新的记录数
|
||||||
|
cursor.execute(f"""
|
||||||
|
SELECT COUNT(*)
|
||||||
|
FROM {TABLE_USER_PREFERENCE_LOG} u
|
||||||
|
WHERE u.category IS NULL
|
||||||
|
""")
|
||||||
|
total_count = cursor.fetchone()[0]
|
||||||
|
logger.info(f"需要迁移的记录数: {total_count}")
|
||||||
|
|
||||||
|
if total_count == 0:
|
||||||
|
logger.info("无需迁移数据")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 分批处理
|
||||||
|
offset = 0
|
||||||
|
processed = 0
|
||||||
|
|
||||||
|
while offset < total_count:
|
||||||
|
# 查询一批记录
|
||||||
|
cursor.execute(f"""
|
||||||
|
SELECT u.id, u.path
|
||||||
|
FROM {TABLE_USER_PREFERENCE_LOG} u
|
||||||
|
WHERE u.category IS NULL
|
||||||
|
LIMIT {batch_size} OFFSET {offset}
|
||||||
|
""")
|
||||||
|
records = cursor.fetchall()
|
||||||
|
|
||||||
|
if not records:
|
||||||
|
break
|
||||||
|
|
||||||
|
# 批量更新
|
||||||
|
for record_id, path in records:
|
||||||
|
# 查询 t_sys_file 表
|
||||||
|
cursor.execute(f"""
|
||||||
|
SELECT id, url, style, level3_type, level2_type, deprecated
|
||||||
|
FROM {TABLE_SYS_FILE}
|
||||||
|
WHERE url = %s
|
||||||
|
LIMIT 1
|
||||||
|
""", (path,))
|
||||||
|
|
||||||
|
sys_file = cursor.fetchone()
|
||||||
|
|
||||||
|
if sys_file:
|
||||||
|
# 系统图
|
||||||
|
sys_file_id, url, style, level3_type, level2_type, deprecated = sys_file
|
||||||
|
category = f"{level3_type.lower()}_{level2_type.lower()}"
|
||||||
|
|
||||||
|
cursor.execute(f"""
|
||||||
|
UPDATE {TABLE_USER_PREFERENCE_LOG}
|
||||||
|
SET category = %s,
|
||||||
|
style = %s,
|
||||||
|
is_system_sketch = 1,
|
||||||
|
sys_file_id = %s
|
||||||
|
WHERE id = %s
|
||||||
|
""", (category, style, sys_file_id, record_id))
|
||||||
|
else:
|
||||||
|
# 用户图
|
||||||
|
cursor.execute(f"""
|
||||||
|
UPDATE {TABLE_USER_PREFERENCE_LOG}
|
||||||
|
SET is_system_sketch = 0,
|
||||||
|
category = NULL,
|
||||||
|
style = NULL,
|
||||||
|
sys_file_id = NULL
|
||||||
|
WHERE id = %s
|
||||||
|
""", (record_id,))
|
||||||
|
|
||||||
|
conn.commit()
|
||||||
|
processed += len(records)
|
||||||
|
offset += batch_size
|
||||||
|
logger.info(f"已迁移 {processed}/{total_count} 条记录")
|
||||||
|
|
||||||
|
logger.info("历史数据迁移完成")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"历史数据迁移失败: {e}", exc_info=True)
|
||||||
|
if conn:
|
||||||
|
conn.rollback()
|
||||||
|
finally:
|
||||||
|
if conn:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
|
||||||
|
def precompute_system_sketch_vectors(batch_size: int = 1000, retry_times: int = 3):
|
||||||
|
"""
|
||||||
|
系统图向量预计算与导入
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch_size: 每批处理数量
|
||||||
|
retry_times: 失败重试次数
|
||||||
|
"""
|
||||||
|
conn = None
|
||||||
|
try:
|
||||||
|
conn = pymysql.connect(**MYSQL_CONFIG)
|
||||||
|
cursor = conn.cursor()
|
||||||
|
|
||||||
|
# 1. 数据筛选
|
||||||
|
logger.info("查询系统图数据...")
|
||||||
|
cursor.execute(f"""
|
||||||
|
SELECT id, url, style, level3_type, level2_type, deprecated
|
||||||
|
FROM {TABLE_SYS_FILE}
|
||||||
|
WHERE level1_type = 'Images'
|
||||||
|
AND style IS NOT NULL
|
||||||
|
AND style != ''
|
||||||
|
AND deprecated != 1
|
||||||
|
""")
|
||||||
|
records = cursor.fetchall()
|
||||||
|
logger.info(f"找到 {len(records)} 条系统图记录")
|
||||||
|
|
||||||
|
if not records:
|
||||||
|
logger.warning("没有找到系统图数据")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 2. 批量处理
|
||||||
|
failed_records = []
|
||||||
|
batch_data = []
|
||||||
|
|
||||||
|
for idx, (sys_file_id, url, style, level3_type, level2_type, deprecated) in enumerate(records, 1):
|
||||||
|
try:
|
||||||
|
# 计算 category
|
||||||
|
category = f"{level3_type.lower()}_{level2_type.lower()}"
|
||||||
|
|
||||||
|
# 提取特征向量
|
||||||
|
feature_vector = extract_feature_vector(url)
|
||||||
|
|
||||||
|
# 检查向量是否有效
|
||||||
|
if np.all(feature_vector == 0):
|
||||||
|
logger.warning(f"向量提取失败,跳过: {url}")
|
||||||
|
failed_records.append((sys_file_id, url))
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 准备数据
|
||||||
|
data_item = {
|
||||||
|
"path": url,
|
||||||
|
"sys_file_id": sys_file_id,
|
||||||
|
"style": style,
|
||||||
|
"category": category,
|
||||||
|
"is_system_sketch": 1,
|
||||||
|
"deprecated": deprecated if deprecated else 0,
|
||||||
|
"feature_vector": feature_vector.tolist()
|
||||||
|
}
|
||||||
|
|
||||||
|
batch_data.append(data_item)
|
||||||
|
|
||||||
|
# 批量写入
|
||||||
|
if len(batch_data) >= batch_size:
|
||||||
|
try:
|
||||||
|
insert_vectors(batch_data)
|
||||||
|
batch_data = []
|
||||||
|
logger.info(f"已处理 {idx}/{len(records)} 条记录")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"批量写入失败: {e}")
|
||||||
|
failed_records.extend([(item["sys_file_id"], item["path"]) for item in batch_data])
|
||||||
|
batch_data = []
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"处理记录失败 [{url}]: {e}")
|
||||||
|
failed_records.append((sys_file_id, url))
|
||||||
|
|
||||||
|
# 写入剩余数据
|
||||||
|
if batch_data:
|
||||||
|
try:
|
||||||
|
insert_vectors(batch_data)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"写入剩余数据失败: {e}")
|
||||||
|
failed_records.extend([(item["sys_file_id"], item["path"]) for item in batch_data])
|
||||||
|
|
||||||
|
# 3. 重试失败记录
|
||||||
|
if failed_records and retry_times > 0:
|
||||||
|
logger.info(f"重试 {len(failed_records)} 条失败记录...")
|
||||||
|
for retry in range(retry_times):
|
||||||
|
retry_failed = []
|
||||||
|
for sys_file_id, url in failed_records:
|
||||||
|
try:
|
||||||
|
category = f"{level3_type.lower()}_{level2_type.lower()}"
|
||||||
|
feature_vector = extract_feature_vector(url)
|
||||||
|
if not np.all(feature_vector == 0):
|
||||||
|
data_item = {
|
||||||
|
"path": url,
|
||||||
|
"sys_file_id": sys_file_id,
|
||||||
|
"style": style,
|
||||||
|
"category": category,
|
||||||
|
"is_system_sketch": 1,
|
||||||
|
"deprecated": 0,
|
||||||
|
"feature_vector": feature_vector.tolist()
|
||||||
|
}
|
||||||
|
insert_vectors([data_item])
|
||||||
|
else:
|
||||||
|
retry_failed.append((sys_file_id, url))
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"重试失败 [{url}]: {e}")
|
||||||
|
retry_failed.append((sys_file_id, url))
|
||||||
|
|
||||||
|
failed_records = retry_failed
|
||||||
|
if not failed_records:
|
||||||
|
break
|
||||||
|
|
||||||
|
if failed_records:
|
||||||
|
logger.warning(f"仍有 {len(failed_records)} 条记录处理失败")
|
||||||
|
|
||||||
|
logger.info("系统图向量预计算完成")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"系统图向量预计算失败: {e}", exc_info=True)
|
||||||
|
finally:
|
||||||
|
if conn:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
|
||||||
|
def compute_user_preference_vector(
|
||||||
|
account_id: int,
|
||||||
|
category: str,
|
||||||
|
conn: Optional[pymysql.connections.Connection] = None,
|
||||||
|
max_date: Optional[datetime] = None
|
||||||
|
) -> Optional[np.ndarray]:
|
||||||
|
"""
|
||||||
|
计算用户偏好向量
|
||||||
|
|
||||||
|
Args:
|
||||||
|
account_id: 用户ID
|
||||||
|
category: 类别
|
||||||
|
conn: 数据库连接(可选)
|
||||||
|
max_date: 最大日期(可选,用于评估时只使用训练集数据)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
用户偏好向量(2048维),失败返回 None
|
||||||
|
"""
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
should_close = False
|
||||||
|
if conn is None:
|
||||||
|
conn = pymysql.connect(**MYSQL_CONFIG)
|
||||||
|
should_close = True
|
||||||
|
|
||||||
|
try:
|
||||||
|
cursor = conn.cursor()
|
||||||
|
|
||||||
|
# 1. 获取点赞记录(如果指定了max_date,只查询该日期之前的数据)
|
||||||
|
if max_date:
|
||||||
|
cursor.execute(f"""
|
||||||
|
SELECT path, data_time
|
||||||
|
FROM {TABLE_USER_PREFERENCE_LOG}
|
||||||
|
WHERE account_id = %s AND category = %s AND style is not null
|
||||||
|
AND data_time < %s
|
||||||
|
ORDER BY data_time DESC
|
||||||
|
""", (account_id, category, max_date))
|
||||||
|
else:
|
||||||
|
cursor.execute(f"""
|
||||||
|
SELECT path, data_time
|
||||||
|
FROM {TABLE_USER_PREFERENCE_LOG}
|
||||||
|
WHERE account_id = %s AND category = %s AND style is not null
|
||||||
|
ORDER BY data_time DESC
|
||||||
|
""", (account_id, category))
|
||||||
|
|
||||||
|
like_records = cursor.fetchall()
|
||||||
|
|
||||||
|
if not like_records:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 2. 批量查询点赞次数(如果指定了max_date,只统计该日期之前的点赞)
|
||||||
|
paths = [r[0] for r in like_records]
|
||||||
|
if not paths:
|
||||||
|
return None
|
||||||
|
|
||||||
|
placeholders = ','.join(['%s'] * len(paths))
|
||||||
|
if max_date:
|
||||||
|
cursor.execute(f"""
|
||||||
|
SELECT path, COUNT(*) as like_count
|
||||||
|
FROM {TABLE_USER_PREFERENCE_LOG}
|
||||||
|
WHERE account_id = %s AND category = %s AND path IN ({placeholders})
|
||||||
|
AND data_time < %s
|
||||||
|
GROUP BY path
|
||||||
|
""", (account_id, category) + tuple(paths) + (max_date,))
|
||||||
|
else:
|
||||||
|
cursor.execute(f"""
|
||||||
|
SELECT path, COUNT(*) as like_count
|
||||||
|
FROM {TABLE_USER_PREFERENCE_LOG}
|
||||||
|
WHERE account_id = %s AND category = %s AND path IN ({placeholders})
|
||||||
|
GROUP BY path
|
||||||
|
""", (account_id, category) + tuple(paths))
|
||||||
|
|
||||||
|
like_counts = {row[0]: row[1] for row in cursor.fetchall()}
|
||||||
|
|
||||||
|
# 3. 批量获取向量
|
||||||
|
vectors_dict = query_vectors_by_paths(paths)
|
||||||
|
|
||||||
|
# 处理查询不到的 path(用户图或异常情况)
|
||||||
|
missing_paths = [p for p in paths if p not in vectors_dict]
|
||||||
|
if missing_paths:
|
||||||
|
logger.info(f"用户 {account_id} 类别 {category} 有 {len(missing_paths)} 个 path 需要实时计算向量")
|
||||||
|
# 目前未有非系统图向量,跳过
|
||||||
|
# 这里可以实时计算并写入 Milvus,但为了简化,先跳过
|
||||||
|
# 实际实现中应该调用 vector_utils.extract_feature_vector 并写入 Milvus
|
||||||
|
|
||||||
|
# 4. 计算权重并加权平均
|
||||||
|
vectors = []
|
||||||
|
weights = []
|
||||||
|
K_half = RECOMMENDATION_CONFIG["K_half"]
|
||||||
|
|
||||||
|
for k, (path, data_time) in enumerate(like_records, 1):
|
||||||
|
if path not in vectors_dict:
|
||||||
|
continue
|
||||||
|
|
||||||
|
vector_data = vectors_dict[path]
|
||||||
|
feature_vector = np.array(vector_data["feature_vector"])
|
||||||
|
|
||||||
|
# 时间衰减权重
|
||||||
|
d_k = 0.5 ** (k / K_half)
|
||||||
|
|
||||||
|
# 点赞次数权重
|
||||||
|
like_count = like_counts.get(path, 1)
|
||||||
|
p_i = 1 + math.log(1 + like_count)
|
||||||
|
|
||||||
|
# 综合权重
|
||||||
|
w_i = d_k * p_i
|
||||||
|
# w_i = p_i
|
||||||
|
|
||||||
|
vectors.append(feature_vector)
|
||||||
|
weights.append(w_i)
|
||||||
|
|
||||||
|
if not vectors:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 5. 计算加权平均并做 L2 归一化,IP≈cosine
|
||||||
|
preference_vector = compute_weighted_average(vectors, weights)
|
||||||
|
preference_vector = normalize_vector(preference_vector)
|
||||||
|
|
||||||
|
return preference_vector
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"计算用户偏好向量失败 [user={account_id}, category={category}]: {e}", exc_info=True)
|
||||||
|
return None
|
||||||
|
finally:
|
||||||
|
if should_close and conn:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
|
||||||
|
def generate_initial_user_preference_vectors(batch_size: int = 100):
|
||||||
|
"""
|
||||||
|
初始用户偏好向量生成
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch_size: 每批处理用户数
|
||||||
|
"""
|
||||||
|
conn = None
|
||||||
|
try:
|
||||||
|
conn = pymysql.connect(**MYSQL_CONFIG)
|
||||||
|
cursor = conn.cursor()
|
||||||
|
|
||||||
|
# 1. 扫描历史数据
|
||||||
|
logger.info("扫描用户和类别组合...")
|
||||||
|
cursor.execute(f"""
|
||||||
|
SELECT DISTINCT account_id, category
|
||||||
|
FROM {TABLE_USER_PREFERENCE_LOG}
|
||||||
|
WHERE category IS NOT NULL
|
||||||
|
AND style IS NOT NULL
|
||||||
|
""")
|
||||||
|
|
||||||
|
user_categories = cursor.fetchall()
|
||||||
|
logger.info(f"找到 {len(user_categories)} 个用户-类别组合")
|
||||||
|
|
||||||
|
if not user_categories:
|
||||||
|
logger.warning("没有找到用户-类别组合")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 2. 批量处理
|
||||||
|
processed = 0
|
||||||
|
failed = 0
|
||||||
|
|
||||||
|
for account_id, category in user_categories:
|
||||||
|
try:
|
||||||
|
# 计算偏好向量
|
||||||
|
preference_vector = compute_user_preference_vector(account_id, category, conn)
|
||||||
|
|
||||||
|
if preference_vector is not None:
|
||||||
|
# 写入 Redis
|
||||||
|
key = f"{REDIS_KEY_USER_PREF_PREFIX}:{account_id}:{category}"
|
||||||
|
# 序列化向量(使用 JSON)
|
||||||
|
vector_json = json.dumps(preference_vector.tolist())
|
||||||
|
Redis.write(
|
||||||
|
key=key,
|
||||||
|
value=vector_json,
|
||||||
|
expire=RECOMMENDATION_CONFIG["redis_expire_seconds"]
|
||||||
|
)
|
||||||
|
processed += 1
|
||||||
|
else:
|
||||||
|
failed += 1
|
||||||
|
|
||||||
|
if (processed + failed) % batch_size == 0:
|
||||||
|
logger.info(f"已处理 {processed + failed}/{len(user_categories)} 个组合,成功: {processed}, 失败: {failed}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"处理失败 [user={account_id}, category={category}]: {e}")
|
||||||
|
failed += 1
|
||||||
|
|
||||||
|
logger.info(f"初始用户偏好向量生成完成,成功: {processed}, 失败: {failed}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"初始用户偏好向量生成失败: {e}", exc_info=True)
|
||||||
|
finally:
|
||||||
|
if conn:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
|
||||||
|
def run_precompute():
|
||||||
|
"""
|
||||||
|
运行所有预计算任务
|
||||||
|
"""
|
||||||
|
logger.info("=" * 50)
|
||||||
|
logger.info("开始预计算任务")
|
||||||
|
logger.info("=" * 50)
|
||||||
|
|
||||||
|
# 1. 优化数据库表结构
|
||||||
|
# logger.info("\n[1/5] 优化数据库表结构...")
|
||||||
|
# optimize_database_table()
|
||||||
|
|
||||||
|
# # 2. 创建 Milvus 集合
|
||||||
|
# logger.info("\n[2/5] 创建 Milvus 集合...")
|
||||||
|
# create_collection()
|
||||||
|
|
||||||
|
# 3. 历史数据迁移
|
||||||
|
# logger.info("\n[3/5] 历史数据迁移...")
|
||||||
|
# migrate_historical_data()
|
||||||
|
|
||||||
|
# # 4. 系统图向量预计算
|
||||||
|
# logger.info("\n[4/5] 系统图向量预计算...")
|
||||||
|
# precompute_system_sketch_vectors()
|
||||||
|
|
||||||
|
# 5. 初始用户偏好向量生成
|
||||||
|
logger.info("\n[5/5] 初始用户偏好向量生成...")
|
||||||
|
generate_initial_user_preference_vectors()
|
||||||
|
|
||||||
|
logger.info("=" * 50)
|
||||||
|
logger.info("预计算任务完成")
|
||||||
|
logger.info("=" * 50)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# # 1. 优化数据库表结构
|
||||||
|
# logger.info("\n[1/5] 优化数据库表结构...")
|
||||||
|
# optimize_database_table()
|
||||||
|
#
|
||||||
|
# # 3. 历史数据迁移
|
||||||
|
# logger.info("\n[3/5] 历史数据迁移...")
|
||||||
|
# migrate_historical_data()
|
||||||
|
|
||||||
|
# 5. 初始用户偏好向量生成
|
||||||
|
logger.info("\n[5/5] 初始用户偏好向量生成...")
|
||||||
|
generate_initial_user_preference_vectors()
|
||||||
214
app/service/recommendation_system/recommendation_api.py
Normal file
214
app/service/recommendation_system/recommendation_api.py
Normal file
@@ -0,0 +1,214 @@
|
|||||||
|
"""
|
||||||
|
推荐接口实现
|
||||||
|
实现探索/利用分支、向量检索、Softmax抽样等功能
|
||||||
|
"""
|
||||||
|
import logging
|
||||||
|
import math
|
||||||
|
import random
|
||||||
|
import numpy as np
|
||||||
|
from typing import List, Dict, Optional
|
||||||
|
|
||||||
|
from app.service.recommendation_system.config import RECOMMENDATION_CONFIG, REDIS_KEY_USER_PREF_PREFIX
|
||||||
|
from app.service.recommendation_system.milvus_client import search_similar_vectors, query_random_candidates
|
||||||
|
from app.service.recommendation_system.precompute import compute_user_preference_vector
|
||||||
|
from app.service.recommendation_system.vector_utils import normalize_vector
|
||||||
|
from app.service.utils.redis_utils import Redis
|
||||||
|
import json
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def get_user_preference_vector(user_id: int, category: str) -> Optional[np.ndarray]:
|
||||||
|
"""
|
||||||
|
获取用户偏好向量
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: 用户ID
|
||||||
|
category: 类别
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
用户偏好向量(2048维),失败返回 None
|
||||||
|
"""
|
||||||
|
# 1. 从 Redis 获取
|
||||||
|
key = f"{REDIS_KEY_USER_PREF_PREFIX}:{user_id}:{category}"
|
||||||
|
vector_json = Redis.read(key)
|
||||||
|
|
||||||
|
if vector_json:
|
||||||
|
try:
|
||||||
|
vector_list = json.loads(vector_json)
|
||||||
|
return np.array(vector_list, dtype=np.float32)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"解析 Redis 向量失败 [user={user_id}, category={category}]: {e}")
|
||||||
|
|
||||||
|
# 2. 如果不存在,实时计算
|
||||||
|
logger.info(f"Redis 中不存在用户偏好向量,实时计算 [user={user_id}, category={category}]")
|
||||||
|
preference_vector = compute_user_preference_vector(user_id, category)
|
||||||
|
|
||||||
|
if preference_vector is not None:
|
||||||
|
# 写入 Redis
|
||||||
|
vector_json = json.dumps(preference_vector.tolist())
|
||||||
|
Redis.write(
|
||||||
|
key=key,
|
||||||
|
value=vector_json,
|
||||||
|
expire=RECOMMENDATION_CONFIG["redis_expire_seconds"]
|
||||||
|
)
|
||||||
|
|
||||||
|
return preference_vector
|
||||||
|
|
||||||
|
|
||||||
|
def explore_branch(category: str, style: Optional[str] = None) -> List[str]:
|
||||||
|
"""
|
||||||
|
探索分支(随机推荐)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
category: 类别
|
||||||
|
style: 风格(可选)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
推荐结果列表,每个元素包含 path, style, category 等字段
|
||||||
|
"""
|
||||||
|
# 查询候选(随机池)
|
||||||
|
pool_size = 10 # 固定查询10个,然后随机选择
|
||||||
|
|
||||||
|
candidates = query_random_candidates(category, style, limit=pool_size)
|
||||||
|
|
||||||
|
if not candidates:
|
||||||
|
logger.warning(f"探索分支:类别 {category} 没有候选数据")
|
||||||
|
return []
|
||||||
|
|
||||||
|
# 随机选择
|
||||||
|
if len(candidates) > 1:
|
||||||
|
import random
|
||||||
|
candidates = random.sample(candidates, 1)
|
||||||
|
|
||||||
|
# 格式化返回结果
|
||||||
|
return [candidate.get("path", "") for candidate in candidates[:1]]
|
||||||
|
|
||||||
|
|
||||||
|
def exploit_branch(
|
||||||
|
user_id: int,
|
||||||
|
category: str,
|
||||||
|
style: Optional[str] = None
|
||||||
|
) -> List[str]:
|
||||||
|
"""
|
||||||
|
利用分支(基于向量相似度推荐)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: 用户ID
|
||||||
|
category: 类别
|
||||||
|
num_recommendations: 返回数量
|
||||||
|
style: 风格(可选,用于加分)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
推荐结果列表,每个元素包含 path, style, category, similarity, sample_score 等字段
|
||||||
|
"""
|
||||||
|
# 1. 获取用户偏好向量
|
||||||
|
embedding = get_user_preference_vector(user_id, category)
|
||||||
|
|
||||||
|
if embedding is None:
|
||||||
|
logger.warning(f"利用分支:无法获取用户偏好向量,回退到探索分支 [user={user_id}, category={category}]")
|
||||||
|
return explore_branch(category, style)
|
||||||
|
|
||||||
|
# 2. Milvus 相似度检索(内积 IP)
|
||||||
|
topk = RECOMMENDATION_CONFIG["topk"]
|
||||||
|
results = search_similar_vectors(embedding, category, topk)
|
||||||
|
|
||||||
|
if not results:
|
||||||
|
logger.warning(f"利用分支:向量检索无结果,回退到探索分支 [user={user_id}, category={category}]")
|
||||||
|
return explore_branch(category, style)
|
||||||
|
|
||||||
|
# 3. Style 加分(可选,需传入 style 参数)
|
||||||
|
style_bonus = RECOMMENDATION_CONFIG["style_bonus"]
|
||||||
|
if style:
|
||||||
|
for result in results:
|
||||||
|
similarity = result["score"]
|
||||||
|
if result.get("style") == style:
|
||||||
|
# 加分:相似度 * (1 + style_bonus)
|
||||||
|
similarity = similarity * (1 + style_bonus)
|
||||||
|
result["final_score"] = similarity
|
||||||
|
else:
|
||||||
|
for result in results:
|
||||||
|
result["final_score"] = result["score"]
|
||||||
|
|
||||||
|
# 4. Softmax 抽样
|
||||||
|
scores = [r["final_score"] for r in results]
|
||||||
|
probabilities = softmax_with_temperature(scores, RECOMMENDATION_CONFIG["softmax_temperature"])
|
||||||
|
|
||||||
|
# 根据概率抽样
|
||||||
|
if not results:
|
||||||
|
return []
|
||||||
|
|
||||||
|
selected_index = np.random.choice(len(results), size=1, p=probabilities, replace=False)
|
||||||
|
selected_results = [results[int(selected_index[0])]]
|
||||||
|
|
||||||
|
# 5. 返回结果
|
||||||
|
return [result.get("path", "") for result in selected_results]
|
||||||
|
|
||||||
|
|
||||||
|
def softmax_with_temperature(scores: List[float], temperature: float = 1.0) -> List[float]:
|
||||||
|
"""
|
||||||
|
Softmax 函数(带温度参数)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scores: 分数列表
|
||||||
|
temperature: 温度参数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
概率列表
|
||||||
|
"""
|
||||||
|
if not scores:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# 除以温度
|
||||||
|
scaled_scores = [s / temperature for s in scores]
|
||||||
|
|
||||||
|
# 减去最大值(数值稳定性)
|
||||||
|
max_score = max(scaled_scores)
|
||||||
|
exp_scores = [math.exp(s - max_score) for s in scaled_scores]
|
||||||
|
|
||||||
|
# 归一化
|
||||||
|
sum_exp = sum(exp_scores)
|
||||||
|
if sum_exp == 0:
|
||||||
|
# 如果所有分数都是负无穷或非常小,返回均匀分布
|
||||||
|
return [1.0 / len(scores)] * len(scores)
|
||||||
|
|
||||||
|
probabilities = [exp_s / sum_exp for exp_s in exp_scores]
|
||||||
|
return probabilities
|
||||||
|
|
||||||
|
|
||||||
|
def get_recommendations(
|
||||||
|
user_id: int,
|
||||||
|
category: str,
|
||||||
|
style: Optional[str] = None
|
||||||
|
) -> List[str]:
|
||||||
|
"""
|
||||||
|
获取推荐结果(主函数)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: 用户ID
|
||||||
|
category: 类别(如 female_skirt)
|
||||||
|
num_recommendations: 返回推荐数量(默认 1)
|
||||||
|
style: 风格(可选):若传入,则在利用分支对同 style 的候选进行加分
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
推荐结果列表,每个元素包含 path 等字段
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 1. 读取配置参数
|
||||||
|
explore_ratio = RECOMMENDATION_CONFIG["explore_ratio"]
|
||||||
|
|
||||||
|
# 2. 探索/利用决策
|
||||||
|
r = random.random() # 生成随机数 (0-1)
|
||||||
|
|
||||||
|
if r < explore_ratio:
|
||||||
|
logger.debug(f"探索分支 [user={user_id}, category={category}]")
|
||||||
|
return explore_branch(category, style)
|
||||||
|
|
||||||
|
logger.debug(f"利用分支 [user={user_id}, category={category}]")
|
||||||
|
return exploit_branch(user_id, category, style)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取推荐结果失败 [user={user_id}, category={category}]: {e}", exc_info=True)
|
||||||
|
# 容错:回退到探索分支
|
||||||
|
return explore_branch(category, style)
|
||||||
|
|
||||||
189
app/service/recommendation_system/vector_utils.py
Normal file
189
app/service/recommendation_system/vector_utils.py
Normal file
@@ -0,0 +1,189 @@
|
|||||||
|
"""
|
||||||
|
向量计算工具类
|
||||||
|
包含 ResNet50 特征提取、向量归一化等功能
|
||||||
|
"""
|
||||||
|
import io
|
||||||
|
import logging
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from torchvision import models, transforms
|
||||||
|
from PIL import Image
|
||||||
|
from minio import Minio
|
||||||
|
|
||||||
|
from app.core.config import settings
|
||||||
|
from app.service.recommendation_system.config import RECOMMENDATION_CONFIG
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# 图像预处理(与ResNet训练时的预处理一致)
|
||||||
|
transform = transforms.Compose([
|
||||||
|
transforms.Resize((224, 224)), # ResNet 要求 224x224 的输入
|
||||||
|
transforms.ToTensor(), # 转换为 Tensor
|
||||||
|
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), # 标准化
|
||||||
|
])
|
||||||
|
|
||||||
|
# 加载预训练的 ResNet50 模型(去掉最后全连接层)
|
||||||
|
_resnet_model = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_resnet_model():
|
||||||
|
"""获取 ResNet50 模型(单例模式)"""
|
||||||
|
global _resnet_model
|
||||||
|
if _resnet_model is None:
|
||||||
|
logger.info("加载 ResNet50 模型...")
|
||||||
|
_resnet_model = models.resnet50(pretrained=True)
|
||||||
|
modules = list(_resnet_model.children())[:-1] # 移除最后的全连接层
|
||||||
|
_resnet_model = torch.nn.Sequential(*modules)
|
||||||
|
_resnet_model.eval() # 设置为评估模式
|
||||||
|
logger.info("ResNet50 模型加载完成")
|
||||||
|
return _resnet_model
|
||||||
|
|
||||||
|
|
||||||
|
# MinIO 客户端(单例)
|
||||||
|
_minio_client = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_minio_client():
|
||||||
|
"""获取 MinIO 客户端(单例模式)"""
|
||||||
|
global _minio_client
|
||||||
|
if _minio_client is None:
|
||||||
|
_minio_client = Minio(
|
||||||
|
settings.MINIO_URL,
|
||||||
|
access_key=settings.MINIO_ACCESS,
|
||||||
|
secret_key=settings.MINIO_SECRET,
|
||||||
|
secure=settings.MINIO_SECURE
|
||||||
|
)
|
||||||
|
return _minio_client
|
||||||
|
|
||||||
|
|
||||||
|
def get_image_from_minio(path: str) -> Image.Image:
|
||||||
|
"""
|
||||||
|
从 MinIO 获取图片
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path: MinIO 逻辑 URL,格式如 "bucket_name/object_name"
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
PIL Image 对象,失败返回 None
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 分割路径,获取桶名和文件路径
|
||||||
|
path_parts = path.split('/', 1)
|
||||||
|
if len(path_parts) != 2:
|
||||||
|
logger.error(f"路径格式错误: {path}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
bucket_name, file_name = path_parts
|
||||||
|
minio_client = get_minio_client()
|
||||||
|
|
||||||
|
# 获取文件
|
||||||
|
obj = minio_client.get_object(bucket_name, file_name)
|
||||||
|
img_data = obj.read() # 读取图像数据
|
||||||
|
img = Image.open(io.BytesIO(img_data)) # 将数据转为图像对象
|
||||||
|
|
||||||
|
return img
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"从 MinIO 获取图片失败 [{path}]: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def extract_feature_vector(path: str) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
使用 ResNet50 提取图片特征向量(2048维)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path: MinIO 逻辑 URL
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
2048维特征向量(numpy array),失败返回零向量
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 从 MinIO 获取图像
|
||||||
|
img = get_image_from_minio(path)
|
||||||
|
if img is None:
|
||||||
|
logger.warning(f"无法获取图片,返回零向量: {path}")
|
||||||
|
return np.zeros(RECOMMENDATION_CONFIG["vector_dim"], dtype=np.float32)
|
||||||
|
|
||||||
|
# 预处理
|
||||||
|
# 部分 MinIO 图片可能是 RGBA/CMYK,转换成 RGB 以匹配 3 通道标准化参数
|
||||||
|
if img.mode != "RGB":
|
||||||
|
try:
|
||||||
|
img = img.convert("RGB")
|
||||||
|
except Exception:
|
||||||
|
logger.warning(f"无法转换图片为RGB,返回零向量: {path}")
|
||||||
|
return np.zeros(RECOMMENDATION_CONFIG["vector_dim"], dtype=np.float32)
|
||||||
|
|
||||||
|
img_tensor = transform(img).unsqueeze(0) # 扩展维度以适应批量处理
|
||||||
|
|
||||||
|
# 提取特征
|
||||||
|
resnet_model = get_resnet_model()
|
||||||
|
with torch.no_grad(): # 在不需要计算梯度的情况下进行推断
|
||||||
|
feature_vector = resnet_model(img_tensor) # 获取 ResNet 的输出
|
||||||
|
feature_vector = feature_vector.squeeze().cpu().numpy() # 转换为 NumPy 数组并去掉 batch 维度
|
||||||
|
|
||||||
|
# 确保是 2048 维
|
||||||
|
if feature_vector.ndim > 1:
|
||||||
|
feature_vector = feature_vector.flatten()
|
||||||
|
|
||||||
|
# 确保维度正确
|
||||||
|
if len(feature_vector) != RECOMMENDATION_CONFIG["vector_dim"]:
|
||||||
|
logger.warning(f"向量维度不正确: {len(feature_vector)}, 期望: {RECOMMENDATION_CONFIG['vector_dim']}")
|
||||||
|
# 如果维度不对,尝试调整
|
||||||
|
if len(feature_vector) > RECOMMENDATION_CONFIG["vector_dim"]:
|
||||||
|
feature_vector = feature_vector[:RECOMMENDATION_CONFIG["vector_dim"]]
|
||||||
|
else:
|
||||||
|
padded = np.zeros(RECOMMENDATION_CONFIG["vector_dim"], dtype=np.float32)
|
||||||
|
padded[:len(feature_vector)] = feature_vector
|
||||||
|
feature_vector = padded
|
||||||
|
|
||||||
|
return feature_vector.astype(np.float32)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"提取特征向量失败 [{path}]: {e}", exc_info=True)
|
||||||
|
return np.zeros(RECOMMENDATION_CONFIG["vector_dim"], dtype=np.float32)
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_vector(vector: np.ndarray) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
L2 归一化向量
|
||||||
|
|
||||||
|
Args:
|
||||||
|
vector: 输入向量
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
归一化后的向量
|
||||||
|
"""
|
||||||
|
norm = np.linalg.norm(vector)
|
||||||
|
if norm == 0:
|
||||||
|
return vector
|
||||||
|
return vector / norm
|
||||||
|
|
||||||
|
|
||||||
|
def compute_weighted_average(vectors: list, weights: list) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
计算加权平均向量
|
||||||
|
|
||||||
|
Args:
|
||||||
|
vectors: 向量列表
|
||||||
|
weights: 权重列表
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
加权平均向量(不做归一化,模长为加权平均后的尺度)
|
||||||
|
"""
|
||||||
|
if not vectors or not weights:
|
||||||
|
return np.zeros(RECOMMENDATION_CONFIG["vector_dim"], dtype=np.float32)
|
||||||
|
|
||||||
|
# 确保所有向量都是 numpy array
|
||||||
|
vectors = [np.array(v) for v in vectors]
|
||||||
|
weights = np.array(weights)
|
||||||
|
|
||||||
|
# 计算加权和
|
||||||
|
weighted_sum = np.zeros_like(vectors[0])
|
||||||
|
for v, w in zip(vectors, weights):
|
||||||
|
weighted_sum += v * w
|
||||||
|
|
||||||
|
# 返回加权平均(除以权重和,不做 L2 归一化,模长不会随条数线性暴涨)
|
||||||
|
weight_total = weights.sum()
|
||||||
|
if weight_total == 0:
|
||||||
|
return weighted_sum
|
||||||
|
return weighted_sum / weight_total
|
||||||
|
|
||||||
@@ -81,7 +81,7 @@ if __name__ == '__main__':
|
|||||||
# url = "aida-users/89/sketchboard/female/Dress/e6724ab7-8d3f-4677-abe0-c3e42ab7af85.jpeg"
|
# url = "aida-users/89/sketchboard/female/Dress/e6724ab7-8d3f-4677-abe0-c3e42ab7af85.jpeg"
|
||||||
# url = "aida-users/87/print/956614a2-7e75-4fbe-9ed0-c1831e37a2c9-4-87.png"
|
# url = "aida-users/87/print/956614a2-7e75-4fbe-9ed0-c1831e37a2c9-4-87.png"
|
||||||
# url = "aida-users/89/single_logo/123-89.png"
|
# url = "aida-users/89/single_logo/123-89.png"
|
||||||
url = "lanecarford/lc_stylist_agent_outfit_items/141/ee25ec85-d504-4b42-9a18-db6682fe9e3b-6.jpg"
|
url = "aida-results/result_a7adcbd8-ef8d-11f0-8c92-0966ede33ab5.png"
|
||||||
|
|
||||||
# url = "aida-collection-element/12148/Sketchboard/95ea577b-305b-4a62-b30a-39c0dd3ddb3f.png"
|
# url = "aida-collection-element/12148/Sketchboard/95ea577b-305b-4a62-b30a-39c0dd3ddb3f.png"
|
||||||
read_type = "2"
|
read_type = "2"
|
||||||
|
|||||||
@@ -91,6 +91,21 @@ class Redis(object):
|
|||||||
r = cls._get_r()
|
r = cls._get_r()
|
||||||
r.expire(name, expire_in_seconds)
|
r.expire(name, expire_in_seconds)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def scan_keys(cls, pattern="*"):
|
||||||
|
"""
|
||||||
|
扫描匹配模式的key
|
||||||
|
"""
|
||||||
|
r = cls._get_r()
|
||||||
|
keys = []
|
||||||
|
cursor = 0
|
||||||
|
while True:
|
||||||
|
cursor, partial_keys = r.scan(cursor, match=pattern, count=1000)
|
||||||
|
keys.extend(partial_keys)
|
||||||
|
if cursor == 0:
|
||||||
|
break
|
||||||
|
return [key.decode('utf-8') if isinstance(key, bytes) else key for key in keys]
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
redis_client = Redis()
|
redis_client = Redis()
|
||||||
|
|||||||
@@ -11,3 +11,15 @@ services:
|
|||||||
- ./seg_cache:/seg_cache
|
- ./seg_cache:/seg_cache
|
||||||
ports:
|
ports:
|
||||||
- "10200:80"
|
- "10200:80"
|
||||||
|
depends_on:
|
||||||
|
- redis
|
||||||
|
redis:
|
||||||
|
image: redis
|
||||||
|
container_name: aida_redis
|
||||||
|
restart: always
|
||||||
|
ports:
|
||||||
|
- "6400:6379"
|
||||||
|
volumes:
|
||||||
|
- ./redis/data:/data
|
||||||
|
- ./redis/conf/redis.conf:/etc/redis/redis.conf
|
||||||
|
command: redis-server /etc/redis/redis.conf --appendonly yes
|
||||||
@@ -1,10 +1,15 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
|
|
||||||
LOGGER_CONFIG_DICT = {
|
LOGGER_CONFIG_DICT = {
|
||||||
'version': 1,
|
'version': 1,
|
||||||
'disable_existing_loggers': False,
|
'disable_existing_loggers': False,
|
||||||
'formatters': {
|
'formatters': {
|
||||||
'simple': {'format': '%(asctime)s %(filename)s [line:%(lineno)d] %(levelname)s %(message)s'}
|
'simple': {
|
||||||
|
'format': '%(asctime)s %(filename)s [line:%(lineno)d] %(levelname)s %(message)s',
|
||||||
|
'datefmt': '%Y-%m-%d %H:%M:%S' # 补充日期格式,日志更易读
|
||||||
|
}
|
||||||
},
|
},
|
||||||
'handlers': {
|
'handlers': {
|
||||||
'console': {
|
'console': {
|
||||||
@@ -17,7 +22,7 @@ LOGGER_CONFIG_DICT = {
|
|||||||
'class': 'logging.handlers.RotatingFileHandler',
|
'class': 'logging.handlers.RotatingFileHandler',
|
||||||
'level': 'INFO',
|
'level': 'INFO',
|
||||||
'formatter': 'simple',
|
'formatter': 'simple',
|
||||||
'filename': f'{settings.LOGS_PATH}info.log',
|
'filename': os.path.join(settings.LOGS_PATH, 'info.log'),
|
||||||
'maxBytes': 10485760,
|
'maxBytes': 10485760,
|
||||||
'backupCount': 50,
|
'backupCount': 50,
|
||||||
'encoding': 'utf8',
|
'encoding': 'utf8',
|
||||||
@@ -26,7 +31,7 @@ LOGGER_CONFIG_DICT = {
|
|||||||
'class': 'logging.handlers.RotatingFileHandler',
|
'class': 'logging.handlers.RotatingFileHandler',
|
||||||
'level': 'ERROR',
|
'level': 'ERROR',
|
||||||
'formatter': 'simple',
|
'formatter': 'simple',
|
||||||
'filename': f'{settings.LOGS_PATH}error.log',
|
'filename': os.path.join(settings.LOGS_PATH, 'error.log'),
|
||||||
'maxBytes': 10485760,
|
'maxBytes': 10485760,
|
||||||
'backupCount': 20,
|
'backupCount': 20,
|
||||||
'encoding': 'utf8',
|
'encoding': 'utf8',
|
||||||
@@ -35,7 +40,7 @@ LOGGER_CONFIG_DICT = {
|
|||||||
'class': 'logging.handlers.RotatingFileHandler',
|
'class': 'logging.handlers.RotatingFileHandler',
|
||||||
'level': 'DEBUG',
|
'level': 'DEBUG',
|
||||||
'formatter': 'simple',
|
'formatter': 'simple',
|
||||||
'filename': f'{settings.LOGS_PATH}debug.log',
|
'filename': os.path.join(settings.LOGS_PATH, 'debug.log'),
|
||||||
'maxBytes': 10485760,
|
'maxBytes': 10485760,
|
||||||
'backupCount': 50,
|
'backupCount': 50,
|
||||||
'encoding': 'utf8',
|
'encoding': 'utf8',
|
||||||
@@ -45,7 +50,7 @@ LOGGER_CONFIG_DICT = {
|
|||||||
'my_module': {'level': 'INFO', 'handlers': ['console'], 'propagate': 'no'}
|
'my_module': {'level': 'INFO', 'handlers': ['console'], 'propagate': 'no'}
|
||||||
},
|
},
|
||||||
'root': {
|
'root': {
|
||||||
'level': 'INFO',
|
'level': 'DEBUG',
|
||||||
'handlers': ['error_file_handler', 'info_file_handler', 'debug_file_handler', 'console'],
|
'handlers': ['error_file_handler', 'info_file_handler', 'debug_file_handler', 'console'],
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user