Merge remote-tracking branch 'refs/remotes/origin/develop'
# Conflicts: # app/core/config.py # app/service/chat_robot/script/service/CallQWen.py
This commit is contained in:
7
.gitignore
vendored
7
.gitignore
vendored
@@ -142,3 +142,10 @@ app/logs/*
|
||||
*.npy
|
||||
*.pytorch
|
||||
*.jpg
|
||||
*.mp4
|
||||
*.sqlite3
|
||||
*.bin
|
||||
*.pickle
|
||||
*.csv
|
||||
*.avi
|
||||
*.json
|
||||
@@ -3,16 +3,17 @@ import logging
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
|
||||
from app.schemas.brand_dna import BrandDnaModel
|
||||
from app.schemas.brand_dna import BrandDnaModel, GenerateBrandModel
|
||||
from app.schemas.response_template import ResponseModel
|
||||
from app.service.brand_dna.service import BrandDna
|
||||
from app.service.brand_dna.service_generate_brand_info import GenerateBrandInfo
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
@router.post("/seg_product")
|
||||
def image2sketch(request_item: BrandDnaModel):
|
||||
def seg_product(request_item: BrandDnaModel):
|
||||
"""
|
||||
创建一个具有以下参数的请求体:
|
||||
- **image_url**: 提取图片url
|
||||
@@ -20,8 +21,8 @@ def image2sketch(request_item: BrandDnaModel):
|
||||
|
||||
示例参数:
|
||||
{
|
||||
"image_url": "test/image2sketch/real_Dress_3200fecdc83d0c556c2bd96aedbd7fbf.jpg_Img.jpg",
|
||||
"is_brand_dna": False
|
||||
"image_url": "aida-results/result_00006a48-e315-11ee-b7c8-b48351119060.png",
|
||||
"is_brand_dna": false
|
||||
}
|
||||
"""
|
||||
try:
|
||||
@@ -32,3 +33,27 @@ def image2sketch(request_item: BrandDnaModel):
|
||||
logger.warning(f"brand dna Run Exception @@@@@@:{e}")
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
return ResponseModel(data=result_url)
|
||||
|
||||
|
||||
@router.post("/GenerateBrand")
|
||||
def GenerateBrand(request_data: GenerateBrandModel):
|
||||
"""
|
||||
通过prompt 生成 brand name ,brand slogan , brand logo。
|
||||
创建一个具有以下参数的请求体:
|
||||
- **prompt**:
|
||||
|
||||
示例参数:
|
||||
{
|
||||
"prompt": "xiaomi",
|
||||
"user_id": "89"
|
||||
}
|
||||
"""
|
||||
try:
|
||||
logger.info(f"GenerateBrand request item is : @@@@@@:{request_data}")
|
||||
service = GenerateBrandInfo(request_data)
|
||||
data = service.get_result()
|
||||
logger.info(f"GenerateBrand response @@@@@@:{data}")
|
||||
except Exception as e:
|
||||
logger.warning(f"GenerateBrand Run Exception @@@@@@:{e}")
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
return ResponseModel(data=data)
|
||||
|
||||
212
app/api/api_brand_dna_initialize.py
Normal file
212
app/api/api_brand_dna_initialize.py
Normal file
@@ -0,0 +1,212 @@
|
||||
import io
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
from typing import List
|
||||
from collections import defaultdict
|
||||
import numpy as np
|
||||
from apscheduler.schedulers.background import BackgroundScheduler
|
||||
from apscheduler.triggers.cron import CronTrigger
|
||||
from fastapi import HTTPException, APIRouter
|
||||
|
||||
from app.service.recommend.service import load_resources, matrix_data
|
||||
import pymysql
|
||||
from app.core.config import DB_CONFIG, TABLE_CATEGORIES, RECOMMEND_PATH_PREFIX
|
||||
from minio import Minio
|
||||
import torch
|
||||
from torchvision import models, transforms
|
||||
from PIL import Image
|
||||
import os
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')
|
||||
logger = logging.getLogger()
|
||||
router = APIRouter()
|
||||
|
||||
# MinIO 配置
|
||||
minio_client = Minio(
|
||||
"www.minio.aida.com.hk:12024",
|
||||
access_key="admin",
|
||||
secret_key="Aidlab123123!",
|
||||
secure=True
|
||||
)
|
||||
|
||||
transform = transforms.Compose([
|
||||
transforms.Resize((224, 224)),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
||||
std=[0.229, 0.224, 0.225]),
|
||||
])
|
||||
|
||||
# ResNet50(去掉最后全连接层)
|
||||
resnet_model = models.resnet50(pretrained=True)
|
||||
resnet_model = torch.nn.Sequential(*list(resnet_model.children())[:-1])
|
||||
resnet_model.eval()
|
||||
|
||||
|
||||
def get_sketch_image_from_minio(sketch_path: str):
|
||||
path_parts = sketch_path.split('/', 1)
|
||||
if len(path_parts) != 2:
|
||||
return None
|
||||
bucket_name, file_name = path_parts
|
||||
try:
|
||||
obj = minio_client.get_object(bucket_name, file_name)
|
||||
img = Image.open(io.BytesIO(obj.read()))
|
||||
return transform(img).unsqueeze(0)
|
||||
except Exception as e:
|
||||
logger.warning(f"Fetch image failed [{sketch_path}]: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def extract_feature_vector_from_resnet(sketch_path: str) -> np.ndarray:
|
||||
img_tensor = get_sketch_image_from_minio(sketch_path)
|
||||
if img_tensor is None:
|
||||
return np.zeros(2048, dtype=np.float32)
|
||||
with torch.no_grad():
|
||||
vec = resnet_model(img_tensor) # [1, 2048, 1, 1]
|
||||
return vec.squeeze().cpu().numpy()
|
||||
|
||||
|
||||
# 预加载
|
||||
BRAND_FEATURES = np.load(f'{RECOMMEND_PATH_PREFIX}brand_feature.npy', allow_pickle=True).item()
|
||||
SYSTEM_FEATURES = np.load(f'{RECOMMEND_PATH_PREFIX}sketch_feature_dict.npy', allow_pickle=True).item()
|
||||
|
||||
|
||||
def save_sketch_to_iid():
|
||||
sketch_to_iid = {
|
||||
sketch_path: iid
|
||||
for iid, sketch_path in enumerate(SYSTEM_FEATURES.keys(), start=1)
|
||||
}
|
||||
np.save(f"{RECOMMEND_PATH_PREFIX}sketch_to_iid.npy", sketch_to_iid)
|
||||
|
||||
|
||||
def load_sketch_to_iid():
|
||||
path = f"{RECOMMEND_PATH_PREFIX}sketch_to_iid.npy"
|
||||
if os.path.exists(path):
|
||||
return np.load(path, allow_pickle=True).item()
|
||||
save_sketch_to_iid()
|
||||
return np.load(path, allow_pickle=True).item()
|
||||
|
||||
|
||||
sketch_to_iid = load_sketch_to_iid()
|
||||
|
||||
|
||||
def getNewCategory(gender: str, sketch_category: str) -> str:
|
||||
return f"{gender.lower()}_{sketch_category.lower()}"
|
||||
|
||||
|
||||
def get_category_from_path(path: str) -> str:
|
||||
parts = path.split('/')
|
||||
if len(parts) >= 4:
|
||||
return f"{parts[2].lower()}_{parts[3].lower()}"
|
||||
return "unknown_unknown"
|
||||
|
||||
|
||||
def load_brand_matrix():
|
||||
"""单独加载 brand_matrix 和 brand_index_map"""
|
||||
mat_path = f"{RECOMMEND_PATH_PREFIX}brand_matrix.npy"
|
||||
idx_path = f"{RECOMMEND_PATH_PREFIX}brand_index_map.npy"
|
||||
try:
|
||||
matrix = np.load(mat_path)
|
||||
index_map = np.load(idx_path, allow_pickle=True).item()
|
||||
except FileNotFoundError:
|
||||
matrix = np.zeros((0, len(sketch_to_iid)), dtype=np.float32)
|
||||
index_map = {}
|
||||
return matrix, index_map
|
||||
|
||||
def cosine_similarity(vec1, vec2):
|
||||
"""计算余弦相似度(增加零值处理)"""
|
||||
norm = np.linalg.norm(vec1) * np.linalg.norm(vec2)
|
||||
return np.dot(vec1, vec2) / (norm + 1e-10) if norm != 0 else 0.0
|
||||
|
||||
def calculate_brand_matrix(sketch_data, brand_id: int) -> np.ndarray:
|
||||
# 1. 收集品牌-分类-特征
|
||||
brand_feature = defaultdict(lambda: defaultdict(list))
|
||||
for _id, sketch_path, gender, sketch_category in sketch_data:
|
||||
cat = getNewCategory(gender, sketch_category)
|
||||
feat = BRAND_FEATURES.get(_id) or extract_feature_vector_from_resnet(sketch_path)
|
||||
brand_feature[(brand_id, cat)][_id].append(feat)
|
||||
|
||||
# 2. 构建 sketch 索引
|
||||
sketch_list = sorted(sketch_to_iid.values())
|
||||
sketch_index = {iid: idx for idx, iid in enumerate(sketch_list)}
|
||||
n_sketch = len(sketch_list)
|
||||
|
||||
# 3. 加载或初始化矩阵
|
||||
brand_matrix, brand_index_map = load_brand_matrix()
|
||||
|
||||
# 4. 增加/更新 行
|
||||
if brand_id in brand_index_map:
|
||||
row_idx = brand_index_map[brand_id]
|
||||
else:
|
||||
row_idx = brand_matrix.shape[0]
|
||||
brand_index_map[brand_id] = row_idx
|
||||
brand_matrix = np.vstack([
|
||||
brand_matrix,
|
||||
np.zeros((1, n_sketch), dtype=np.float32)
|
||||
])
|
||||
|
||||
# 5. 计算品牌-分类平均向量
|
||||
brand_avg = {}
|
||||
for key, id_dict in brand_feature.items():
|
||||
all_feats = [v for feats in id_dict.values() for v in feats]
|
||||
if all_feats:
|
||||
brand_avg[key] = np.mean(all_feats, axis=0)
|
||||
|
||||
# 6. 填充相似度
|
||||
for sketch_path, sys_vec in SYSTEM_FEATURES.items():
|
||||
iid = sketch_to_iid.get(sketch_path)
|
||||
if not iid or iid not in sketch_index:
|
||||
continue
|
||||
cat_key = (brand_id, get_category_from_path(sketch_path))
|
||||
avg_vec = brand_avg.get(cat_key)
|
||||
if avg_vec is not None:
|
||||
cos_sim = cosine_similarity(avg_vec, sys_vec)
|
||||
brand_matrix[row_idx, sketch_index[iid]] = cos_sim
|
||||
|
||||
# 7. 持久化
|
||||
np.save(f"{RECOMMEND_PATH_PREFIX}brand_feature_matrix.npy", brand_matrix)
|
||||
np.save(f"{RECOMMEND_PATH_PREFIX}brand_index_map.npy", brand_index_map)
|
||||
|
||||
# 返回该品牌对应行
|
||||
return brand_matrix[row_idx:row_idx+1]
|
||||
|
||||
|
||||
@router.get("/brand_dna_initialize/{brand_id}")
|
||||
async def brand_dna_initialize(brand_id: int):
|
||||
conn = None
|
||||
try:
|
||||
conn = pymysql.connect(**DB_CONFIG)
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("""
|
||||
SELECT id, img_url, gender, category
|
||||
FROM product_image_attribute
|
||||
WHERE library_id IN (
|
||||
SELECT library_id
|
||||
FROM brand_rel_library
|
||||
WHERE brand_id = %s
|
||||
)
|
||||
""", (brand_id,))
|
||||
sketch_data = cursor.fetchall()
|
||||
|
||||
# 触发计算并持久化,若内部出错会抛异常
|
||||
_ = calculate_brand_matrix(sketch_data, brand_id)
|
||||
|
||||
# 返回成功
|
||||
return {"success": True}
|
||||
|
||||
except HTTPException:
|
||||
# 已经是明确的 HTTPException,直接抛出
|
||||
raise
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"品牌初始化失败 [{brand_id}]: {e}", exc_info=True)
|
||||
# 返回失败的 JSON,同时设置 500 状态码
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={"success": False, "message": "品牌初始化失败"}
|
||||
)
|
||||
|
||||
finally:
|
||||
if conn:
|
||||
conn.close()
|
||||
51
app/api/api_clothing_seg.py
Normal file
51
app/api/api_clothing_seg.py
Normal file
@@ -0,0 +1,51 @@
|
||||
import json
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
|
||||
from app.schemas.response_template import ResponseModel
|
||||
from app.schemas.clothing_seg import ClothingSegModel
|
||||
from app.service.clothing_seg.service import ClothingSeg
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
@router.post("/clothing_seg")
|
||||
def clothing_seg(request_item: ClothingSegModel):
|
||||
"""
|
||||
创建一个具有以下参数的请求体:
|
||||
- **user_id**: 用户id
|
||||
- **image_data**: 图片数据
|
||||
{
|
||||
"image_url": "test/clothing_seg/dress.jpg",
|
||||
"image_type": "product"
|
||||
}
|
||||
|
||||
示例参数:
|
||||
{
|
||||
"user_id": 89,
|
||||
"image_data": [
|
||||
{
|
||||
"image_url": "test/clothing_seg/dress.jpg",
|
||||
"image_type": "sketch"
|
||||
},
|
||||
{
|
||||
"image_url": "test/clothing_seg/skirt_559.jpg",
|
||||
"image_type": "sketch"
|
||||
},
|
||||
{
|
||||
"image_url": "test/clothing_seg/10144613.jpg",
|
||||
"image_type": "product"
|
||||
}
|
||||
]
|
||||
}
|
||||
"""
|
||||
try:
|
||||
logger.info(f"clothing_seg request item is : @@@@@@:{json.dumps(request_item.dict())}")
|
||||
server = ClothingSeg(request_item)
|
||||
result_url = server.get_result()
|
||||
except Exception as e:
|
||||
logger.warning(f"clothing_seg Run Exception @@@@@@:{e}")
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
return ResponseModel(data=result_url)
|
||||
@@ -433,12 +433,12 @@ def model_process(request_data: ModelProgressModel):
|
||||
|
||||
|
||||
@router.post("/design_batch_generate")
|
||||
async def design(file: UploadFile = File(...),
|
||||
tasks_id: str = Form(...),
|
||||
user_id: str = Form(...),
|
||||
file_name: str = Form(...),
|
||||
total: int = Form(...)
|
||||
):
|
||||
async def design_batch(file: UploadFile = File(...),
|
||||
tasks_id: str = Form(...),
|
||||
user_id: str = Form(...),
|
||||
file_name: str = Form(...),
|
||||
total: int = Form(...)
|
||||
):
|
||||
dbg_config = DBGConfigModel(
|
||||
tasks_id=tasks_id,
|
||||
user_id=user_id,
|
||||
|
||||
39
app/api/api_extraction_project_info.py
Normal file
39
app/api/api_extraction_project_info.py
Normal file
@@ -0,0 +1,39 @@
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
|
||||
from app.schemas.project_info_extraction import ProjectInfoExtractionModel
|
||||
from app.schemas.response_template import ResponseModel
|
||||
from app.service.project_info_extraction.service import ProjectInfoExtraction
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
@router.post("/extraction_project_info")
|
||||
def extraction_project_info(request_data: ProjectInfoExtractionModel):
|
||||
"""
|
||||
通过prompt 提取project_name,role ,gender ,style。
|
||||
创建一个具有以下参数的请求体:
|
||||
- **prompt**:
|
||||
|
||||
示例参数:
|
||||
{
|
||||
"prompt": "海边派对主题的系列设计",
|
||||
"image_list": [
|
||||
"https://www.minio-api.aida.com.hk/test/test123.jpg?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=vXKFLSJkYeEq2DrSZvkB%2F20250519%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20250519T050808Z&X-Amz-Expires=7200&X-Amz-SignedHeaders=host&X-Amz-Signature=296ff07cc4692d0a26ddffac582064f036494af343389fe60193dc2c5dc883ff"
|
||||
],
|
||||
"file_list": [
|
||||
""
|
||||
]
|
||||
}
|
||||
"""
|
||||
try:
|
||||
logger.info(f"extraction_project_info request item is : @@@@@@:{request_data}")
|
||||
service = ProjectInfoExtraction(request_data)
|
||||
data = service.get_result()
|
||||
logger.info(f"extraction_project_info response @@@@@@:{data}")
|
||||
except Exception as e:
|
||||
logger.warning(f"extraction_project_info Run Exception @@@@@@:{e}")
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
return ResponseModel(data=data)
|
||||
@@ -3,8 +3,11 @@ import logging
|
||||
|
||||
from fastapi import APIRouter, BackgroundTasks, HTTPException
|
||||
|
||||
from app.schemas.generate_image import GenerateImageModel, GenerateProductImageModel, GenerateSingleLogoImageModel, GenerateRelightImageModel, GenerateMultiViewModel
|
||||
from app.schemas.generate_image import GenerateImageModel, GenerateProductImageModel, GenerateSingleLogoImageModel, GenerateRelightImageModel, GenerateMultiViewModel, BatchGenerateProductImageModel, BatchGenerateRelightImageModel, AgentTollGenerateImageModel
|
||||
from app.schemas.pose_transform import BatchPoseTransformModel
|
||||
from app.schemas.response_template import ResponseModel
|
||||
from app.service.generate_batch_image.service import start_product_batch_generate, start_relight_batch_generate, start_pose_transform_batch_generate
|
||||
from app.service.generate_image.service_agent_tool_generate_image import AgentToolGenerateImage
|
||||
from app.service.generate_image.service_generate_image import GenerateImage, infer_cancel as generate_image_infer_cancel
|
||||
from app.service.generate_image.service_generate_multi_view import GenerateMultiView, infer_cancel as generate_multi_view_cancel
|
||||
from app.service.generate_image.service_generate_product_image import GenerateProductImage, infer_cancel as generate_product_image_cancel
|
||||
@@ -228,3 +231,123 @@ def generate_relight_image(tasks_id: str):
|
||||
logger.warning(f"generate_relight_image_cancel_cancel Run Exception @@@@@@:{e}")
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
return ResponseModel(data=data['data'])
|
||||
|
||||
|
||||
"""batch generate img"""
|
||||
|
||||
|
||||
@router.post("/batch_generate_product_image")
|
||||
async def batch_generate_product(request_batch_item: BatchGenerateProductImageModel):
|
||||
"""
|
||||
创建一个具有以下参数的请求体:
|
||||
- **tasks_id**: 任务id 用于获取生成结果
|
||||
- **prompt**: 想要生成图片的描述词
|
||||
- **image_url**: 被生成图片的S3或minio url地址
|
||||
- **image_strength**: 生成强度,越低越接近原图
|
||||
- **product_type**: 输入single item 还是 overall item
|
||||
- **batch_size**: 批生成数量
|
||||
|
||||
|
||||
示例参数:
|
||||
{
|
||||
"tasks_id": "123-89",
|
||||
"prompt": "the best quality, masterpiece. detailed, high-res, simple background, studio photography, extremely detailed, updo, detailed face, face, close-up, HDR, UHD, 8K realistic, Highly detailed, simple background, Studio lighting",
|
||||
"image_url": "aida-results/result_00097282-ebb2-11ee-a822-b48351119060.png",
|
||||
"image_strength": 0.8,
|
||||
"product_type": "overall",
|
||||
"batch_size": 1
|
||||
}
|
||||
"""
|
||||
return await start_product_batch_generate(request_batch_item)
|
||||
|
||||
|
||||
@router.post("/batch_generate_relight_image")
|
||||
async def batch_generate_relight(request_batch_item: BatchGenerateRelightImageModel):
|
||||
"""
|
||||
创建一个具有以下参数的请求体:
|
||||
- **tasks_id**: 任务id 用于获取生成结果
|
||||
- **prompt**: 想要生成图片的描述词
|
||||
- **image_url**: 被生成图片的S3或minio url地址
|
||||
- **direction**: 光源方向 Right Light Left Light Top Light Bottom Light
|
||||
- **product_type**: 输入single item 还是 overall item
|
||||
- **batch_size**: 批生成数量
|
||||
|
||||
|
||||
示例参数:
|
||||
{
|
||||
"tasks_id": "123-89",
|
||||
"prompt": "beautiful woman, detailed face, sunshine, outdoor, warm atmosphere",
|
||||
"image_url": "aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png",
|
||||
"direction": "Right Light",
|
||||
"product_type": "overall",
|
||||
"batch_size": 1
|
||||
}
|
||||
"""
|
||||
return await start_relight_batch_generate(request_batch_item)
|
||||
|
||||
|
||||
@router.post("/batch_generate_pose_transform_image")
|
||||
async def batch_generate_pose_transform(request_batch_item: BatchPoseTransformModel):
|
||||
"""
|
||||
创建一个具有以下参数的请求体:
|
||||
- **tasks_id**: 任务id 用于取消生成任务和获取生成结果
|
||||
- **image_url**: 被生成图片的S3或minio url地址
|
||||
- **pose_id**: 1
|
||||
- **batch_size**: 批生成数量
|
||||
|
||||
|
||||
示例参数:
|
||||
{
|
||||
"tasks_id": "123-89",
|
||||
"image_url": "aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png",
|
||||
"pose_id": "1",
|
||||
"batch_size": 1
|
||||
}
|
||||
"""
|
||||
return await start_pose_transform_batch_generate(request_batch_item)
|
||||
|
||||
|
||||
"""agent tool"""
|
||||
|
||||
|
||||
@router.post("/agent_tool_generate_image")
|
||||
def agent_tool_generate_image(request_item: AgentTollGenerateImageModel, background_tasks: BackgroundTasks):
|
||||
"""
|
||||
创建一个具有以下参数的请求体:
|
||||
- **prompt**: 想要生成图片的描述词
|
||||
- **category**: 生成图片的类别,sketch print 等等
|
||||
- **gender**: 生成sketch专用,服装类别
|
||||
- **version**: 使用模型版本 fast 或者 high
|
||||
- **size**: 生成数量
|
||||
- **version**: 使用模型版本 fast 或者 high
|
||||
|
||||
|
||||
示例参数:
|
||||
{
|
||||
"prompt": "a single item of sketch of Wabi-sabi, skirt, tiered, 4k, white background",
|
||||
"category": "sketch",
|
||||
"gender": "male",
|
||||
"size":2,
|
||||
"version":"high"
|
||||
}
|
||||
"""
|
||||
try:
|
||||
logger.info(f"agent_tool_generate_image request item is : @@@@@@:{request_item.dict()}")
|
||||
request_data = request_item.dict()
|
||||
service = AgentToolGenerateImage(request_data['version'])
|
||||
image_url_list, clothing_category_list = service.get_result(
|
||||
prompt=request_data['prompt'],
|
||||
size=request_data['size'],
|
||||
version=request_data['version'],
|
||||
category=request_data['category'],
|
||||
gender=request_data['gender']
|
||||
)
|
||||
data = {
|
||||
"image_url_list": image_url_list,
|
||||
"clothing_category_list": clothing_category_list
|
||||
}
|
||||
logger.info(f"agent_tool_generate_image response item is : @@@@@@:{data}")
|
||||
except Exception as e:
|
||||
logger.warning(f"agent_tool_generate_image Run Exception @@@@@@:{e}")
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
return ResponseModel(data=data)
|
||||
|
||||
45
app/api/api_mannequins_edit.py
Normal file
45
app/api/api_mannequins_edit.py
Normal file
@@ -0,0 +1,45 @@
|
||||
import json
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
|
||||
from app.schemas.mannequin_edit import MannequinModel
|
||||
from app.schemas.response_template import ResponseModel
|
||||
from app.service.mannequins_edit.service import MannequinEditService
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
@router.post("/mannequins_edit")
|
||||
def mannequins_edit(request_data: MannequinModel):
|
||||
"""
|
||||
模特腿长调整
|
||||
创建一个具有以下参数的请求体:
|
||||
- **mannequins**: mannequins url等信息
|
||||
- **resize_pixel**: 拉伸像素量
|
||||
- **bucket_name**: bucket name
|
||||
- **mannequin_name**: 模特名称
|
||||
- **top**: 拉伸y轴点位
|
||||
- **bottom**: 拉伸y轴点位
|
||||
|
||||
|
||||
示例参数:
|
||||
- **{
|
||||
"mannequins": "aida-sys-image/models/male/dc36ce58-46c3-4b6f-8787-5ca7d6fc26e6.png",
|
||||
"resize_pixel": -50,
|
||||
"bucket_name": "test",
|
||||
"mannequin_name": "mannequin_name",
|
||||
"top" : 270,
|
||||
"bottom" : 432
|
||||
}**
|
||||
"""
|
||||
try:
|
||||
logger.info(f"mannequins_edit request item is : @@@@@@:{json.dumps(request_data.dict())}")
|
||||
service = MannequinEditService(request_data)
|
||||
data = service()
|
||||
logger.info(f"mannequins_edit response @@@@@@:{json.dumps(data)}")
|
||||
except Exception as e:
|
||||
logger.warning(f"mannequins_edit Run Exception @@@@@@:{e}")
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
return ResponseModel(data=data)
|
||||
49
app/api/api_pose_transform.py
Normal file
49
app/api/api_pose_transform.py
Normal file
@@ -0,0 +1,49 @@
|
||||
import json
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, BackgroundTasks, HTTPException
|
||||
|
||||
from app.schemas.pose_transform import PoseTransformModel
|
||||
from app.schemas.response_template import ResponseModel
|
||||
from app.service.generate_image.service_pose_transform import PoseTransformService, infer_cancel as pose_transform_infer_cancel
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
@router.post("/pose_transform")
|
||||
def pose_transform(request_item: PoseTransformModel, background_tasks: BackgroundTasks):
|
||||
"""
|
||||
创建一个具有以下参数的请求体:
|
||||
- **tasks_id**: 任务id 用于取消生成任务和获取生成结果
|
||||
- **image_url**: 被生成图片的S3或minio url地址
|
||||
- **pose_id**: 1
|
||||
|
||||
|
||||
示例参数:
|
||||
{
|
||||
"tasks_id": "123-89",
|
||||
"image_url": "aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png",
|
||||
"pose_id": "1"
|
||||
}
|
||||
"""
|
||||
try:
|
||||
logger.info(f"pose_transform request item is : @@@@@@:{json.dumps(request_item.dict())}")
|
||||
service = PoseTransformService(request_item)
|
||||
background_tasks.add_task(service.get_result)
|
||||
except Exception as e:
|
||||
logger.warning(f"pose_transform Run Exception @@@@@@:{e}")
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
return ResponseModel()
|
||||
|
||||
|
||||
@router.get("/pose_transform_cancel/{tasks_id}")
|
||||
def pose_transform_cancel(tasks_id: str):
|
||||
try:
|
||||
logger.info(f"pose_transform_cancel request item is : @@@@@@:{tasks_id}")
|
||||
data = pose_transform_infer_cancel(tasks_id)
|
||||
logger.info(f"pose_transform_cancel response @@@@@@:{data}")
|
||||
except Exception as e:
|
||||
logger.warning(f"pose_transform_cancel Run Exception @@@@@@:{e}")
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
return ResponseModel(data=data['data'])
|
||||
@@ -4,9 +4,10 @@ import time
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
|
||||
from app.schemas.prompt_generation import PromptGenerationImageModel
|
||||
from app.schemas.prompt_generation import PromptGenerationImageModel, ImageRequest
|
||||
from app.schemas.response_template import ResponseModel
|
||||
from app.service.prompt_generation.chatgpt_for_translation import translate_to_en, get_translation_from_llama3
|
||||
from app.service.prompt_generation.chatgpt_for_translation import get_translation_from_llama3, \
|
||||
get_prompt_from_image
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger()
|
||||
@@ -32,3 +33,20 @@ def prompt_generation(request_data: PromptGenerationImageModel):
|
||||
logger.warning(f"prompt_generation Run Exception @@@@@@:{e}")
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
return ResponseModel(data=data)
|
||||
|
||||
|
||||
@router.post("/img2prompt")
|
||||
def get_prompt_from_img(img: ImageRequest):
|
||||
"""
|
||||
自动识别图片并输出为prompt
|
||||
|
||||
:param img: 图片的minio地址
|
||||
:return: 图片的文字描述
|
||||
"""
|
||||
text = ("Please describe the clothing in the image and provide a line art description of the outfit. "
|
||||
"The description should allow for the reconstruction of the corresponding line art based on the details "
|
||||
"given.")
|
||||
logger.info(f"get_prompt_from_img request item is : @@@@@@:{img}")
|
||||
description = get_prompt_from_image(img, text)
|
||||
logger.info(f"生成的图片描述 response @@@@@@:{description}")
|
||||
return description
|
||||
|
||||
204
app/api/api_recommendation.py
Normal file
204
app/api/api_recommendation.py
Normal file
@@ -0,0 +1,204 @@
|
||||
import io
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
from typing import List
|
||||
import os
|
||||
import json
|
||||
import math
|
||||
import random
|
||||
import numpy as np
|
||||
from apscheduler.schedulers.background import BackgroundScheduler
|
||||
from apscheduler.triggers.cron import CronTrigger
|
||||
from fastapi import HTTPException, APIRouter
|
||||
|
||||
from app.service.recommend.service import load_resources, matrix_data
|
||||
|
||||
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')
|
||||
logger = logging.getLogger()
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.on_event("startup")
|
||||
async def startup_event():
|
||||
# 初始加载
|
||||
load_resources()
|
||||
|
||||
# 配置定时任务
|
||||
scheduler = BackgroundScheduler()
|
||||
scheduler.add_job(
|
||||
load_resources,
|
||||
trigger=CronTrigger(hour=0, minute=30),
|
||||
name="每日资源刷新"
|
||||
)
|
||||
scheduler.start()
|
||||
logger.info("定时任务已启动")
|
||||
|
||||
def softmax(scores):
|
||||
max_score = max(scores)
|
||||
exp_scores = [math.exp(s - max_score) for s in scores]
|
||||
sum_exp = sum(exp_scores)
|
||||
return [s / sum_exp for s in exp_scores]
|
||||
|
||||
# def get_random_recommendations(category: str, num: int) -> List[str]:
|
||||
# """根据预加载热度向量推荐(冷启动)"""
|
||||
# try:
|
||||
# heat_data = matrix_data.get("heat_data", {})
|
||||
#
|
||||
# if category not in heat_data:
|
||||
# raise ValueError(f"热度数据缺少类别 {category},使用随机推荐")
|
||||
#
|
||||
# heat_dict = heat_data[category] # {url: score}
|
||||
# urls = list(heat_dict.keys())
|
||||
# scores = list(heat_dict.values())
|
||||
#
|
||||
# if not urls:
|
||||
# raise ValueError("该类别下无热度记录,使用随机推荐")
|
||||
#
|
||||
# probs = softmax(scores)
|
||||
# sample_size = min(num, len(urls))
|
||||
# sampled_urls = random.choices(urls, weights=probs, k=sample_size)
|
||||
#
|
||||
# return sampled_urls
|
||||
#
|
||||
# except Exception as e:
|
||||
# # 回退:完全随机推荐
|
||||
# all_iids = list(matrix_data["iid_to_sketch"].keys())
|
||||
# category_iids = matrix_data["category_to_iids"].get(category, all_iids)
|
||||
# sample_size = min(num, len(category_iids))
|
||||
# sampled = np.random.choice(category_iids, size=sample_size, replace=False)
|
||||
# return [matrix_data["iid_to_sketch"][iid] for iid in sampled]
|
||||
|
||||
def get_random_recommendations(category: str, num: int) -> List[str]:
|
||||
"""全品类随机推荐"""
|
||||
all_iids = list(matrix_data["iid_to_sketch"].keys())
|
||||
# 优先从当前品类选择
|
||||
category_iids = matrix_data["category_to_iids"].get(category, all_iids)
|
||||
# 确保不超出实际数量
|
||||
sample_size = min(num, len(category_iids))
|
||||
sampled = np.random.choice(category_iids, size=sample_size, replace=False)
|
||||
return [matrix_data["iid_to_sketch"][iid] for iid in sampled]
|
||||
|
||||
|
||||
@router.get("/recommend/{user_id}/{category}/{num_recommendations}/{brand_id}/{brand_scale}", response_model=List[str])
|
||||
async def get_recommendations(user_id: int, category: str, brand_id: int, brand_scale: float, num_recommendations: int = 10):
|
||||
"""
|
||||
:param user_id: 4
|
||||
:param category: female_skirt
|
||||
:param num_recommendations: 1
|
||||
:return:
|
||||
[
|
||||
"aida-sys-image/images/female/skirt/903000017.jpg"
|
||||
]
|
||||
"""
|
||||
try:
|
||||
start_time = time.time()
|
||||
cache_key = (user_id, category)
|
||||
# === 新增:用户存在性检查 ===
|
||||
user_exists_inter = user_id in matrix_data["user_index_interaction"]
|
||||
user_exists_feat = user_id in matrix_data["user_index_feature"]
|
||||
|
||||
# 任一矩阵不存在用户则返回随机推荐
|
||||
if not (user_exists_inter and user_exists_feat):
|
||||
logger.info(f"用户 {user_id} 数据不完整,触发随机推荐")
|
||||
return get_random_recommendations(category, num_recommendations)
|
||||
|
||||
# 检查缓存
|
||||
if cache_key in matrix_data["cached_scores"]:
|
||||
processed_inter, processed_feat = matrix_data["cached_scores"][cache_key]
|
||||
valid_sketch_idxs_inter = matrix_data["cached_valid_idxs"][cache_key]
|
||||
else:
|
||||
# 实时计算逻辑(同原代码)
|
||||
user_idx_inter = matrix_data["user_index_interaction"].get(user_id)
|
||||
user_idx_feature = matrix_data["user_index_feature"].get(user_id)
|
||||
|
||||
category_iids = matrix_data["category_to_iids"].get(category, [])
|
||||
valid_sketch_idxs_inter = [
|
||||
idx for iid, idx in matrix_data["sketch_index_interaction"].items()
|
||||
if iid in category_iids
|
||||
]
|
||||
|
||||
# 处理交互分数
|
||||
raw_inter_scores = []
|
||||
if user_idx_inter is not None and valid_sketch_idxs_inter:
|
||||
raw_inter_scores = matrix_data["interaction_matrix"][user_idx_inter, valid_sketch_idxs_inter]
|
||||
processed_inter = raw_inter_scores * 0.7
|
||||
|
||||
# 处理特征分数
|
||||
valid_sketch_idxs_feature = [
|
||||
idx for iid, idx in matrix_data["sketch_index_feature"].items()
|
||||
if iid in category_iids
|
||||
]
|
||||
raw_feat_scores = []
|
||||
if user_idx_feature is not None and valid_sketch_idxs_feature:
|
||||
raw_feat_scores = matrix_data["feature_matrix"][user_idx_feature, valid_sketch_idxs_feature]
|
||||
raw_feat_scores = (raw_feat_scores - np.min(raw_feat_scores)) / (
|
||||
np.max(raw_feat_scores) - np.min(raw_feat_scores) + 1e-8)
|
||||
processed_feat = raw_feat_scores
|
||||
else:
|
||||
processed_feat = np.array([])
|
||||
|
||||
# 更新缓存
|
||||
matrix_data["cached_scores"][cache_key] = (processed_inter, processed_feat)
|
||||
matrix_data["cached_valid_idxs"][cache_key] = valid_sketch_idxs_inter
|
||||
|
||||
# 合并分数
|
||||
if brand_id is not None:
|
||||
brand_idx_feature = matrix_data["brand_index_map"].get(brand_id)
|
||||
|
||||
brand_feat_valid = (
|
||||
matrix_data["brand_feature_matrix"].size > 0 and # 矩阵非空
|
||||
brand_idx_feature is not None and
|
||||
valid_sketch_idxs_feature # 有可用索引
|
||||
)
|
||||
|
||||
if brand_feat_valid:
|
||||
raw_brand_feat_scores = matrix_data["brand_feature_matrix"][
|
||||
brand_idx_feature, valid_sketch_idxs_feature
|
||||
]
|
||||
raw_brand_feat_scores = (raw_brand_feat_scores - np.min(raw_brand_feat_scores)) / (
|
||||
np.max(raw_brand_feat_scores) - np.min(raw_brand_feat_scores) + 1e-8
|
||||
)
|
||||
processed_brand_feat = raw_brand_feat_scores
|
||||
|
||||
# 如果 processed_feat 是空的,替换为全 0,避免 shape 不一致
|
||||
if processed_feat.size == 0:
|
||||
processed_feat = np.zeros_like(processed_brand_feat)
|
||||
|
||||
final_scores = processed_inter + 0.3 * (
|
||||
(1 - brand_scale) * processed_feat + brand_scale * processed_brand_feat
|
||||
)
|
||||
else:
|
||||
# brand 信息不可用
|
||||
final_scores = processed_inter + 0.3 * processed_feat if processed_feat.size > 0 else processed_inter
|
||||
else:
|
||||
final_scores = processed_inter + 0.3 * processed_feat if processed_feat.size > 0 else processed_inter
|
||||
|
||||
valid_sketch_idxs = matrix_data["cached_valid_idxs"][cache_key]
|
||||
|
||||
# 概率采样
|
||||
scores = np.array(final_scores)
|
||||
|
||||
# 调整后的概率转换(带温度控制的softmax)
|
||||
def calibrated_softmax(scores, temperature=1.0):
|
||||
scores = scores / temperature
|
||||
scale = scores - max(scores)
|
||||
exps = np.exp(scale)
|
||||
return exps / np.sum(exps)
|
||||
|
||||
probs = calibrated_softmax(scores, 0.09)
|
||||
|
||||
chosen_indices = np.random.choice(
|
||||
len(valid_sketch_idxs),
|
||||
size=min(num_recommendations, len(valid_sketch_idxs)),
|
||||
p=probs,
|
||||
replace=False
|
||||
)
|
||||
recommendations = [matrix_data["iid_to_sketch"][valid_sketch_idxs[idx]] for idx in chosen_indices]
|
||||
|
||||
logger.info(f"推荐生成完成,耗时: {time.time() - start_time:.2f}秒")
|
||||
return recommendations
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"推荐失败: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
@@ -4,11 +4,16 @@ from app.api import api_attribute_retrieve, api_query_image
|
||||
from app.api import api_brand_dna
|
||||
from app.api import api_brighten
|
||||
from app.api import api_chat_robot
|
||||
from app.api import api_clothing_seg
|
||||
from app.api import api_design
|
||||
from app.api import api_design_pre_processing
|
||||
from app.api import api_extraction_project_info
|
||||
from app.api import api_generate_image
|
||||
from app.api import api_image2sketch
|
||||
from app.api import api_mannequins_edit
|
||||
from app.api import api_pose_transform
|
||||
from app.api import api_prompt_generation
|
||||
from app.api import api_recommendation
|
||||
from app.api import api_super_resolution
|
||||
from app.api import api_test
|
||||
|
||||
@@ -26,3 +31,8 @@ router.include_router(api_image2sketch.router, tags=['api_image2sketch'], prefix
|
||||
router.include_router(api_brighten.router, tags=['api_brighten'], prefix="/api")
|
||||
router.include_router(api_query_image.router, tags=['api_query_image'], prefix="/api")
|
||||
router.include_router(api_brand_dna.router, tags=['api_brand_dna'], prefix="/api")
|
||||
router.include_router(api_recommendation.router, tags=['api_recommendation'], prefix="/api")
|
||||
router.include_router(api_mannequins_edit.router, tags=['api_mannequins_edit'], prefix="/api")
|
||||
router.include_router(api_pose_transform.router, tags=['api_pose_transform'], prefix="/api")
|
||||
router.include_router(api_clothing_seg.router, tags=['api_clothing_seg'], prefix="/api")
|
||||
router.include_router(api_extraction_project_info.router, tags=['api_extraction_project_info'], prefix="/api")
|
||||
|
||||
@@ -4,7 +4,7 @@ import logging
|
||||
from fastapi import APIRouter
|
||||
from fastapi import HTTPException
|
||||
|
||||
from app.core.config import SR_RABBITMQ_QUEUES, GI_RABBITMQ_QUEUES, GPI_RABBITMQ_QUEUES, GRI_RABBITMQ_QUEUES, OSS, JAVA_STREAM_API_URL
|
||||
from app.core.config import SR_RABBITMQ_QUEUES, GI_RABBITMQ_QUEUES, GPI_RABBITMQ_QUEUES, GRI_RABBITMQ_QUEUES, OSS, JAVA_STREAM_API_URL, GMV_RABBITMQ_QUEUES, SLOGAN_RABBITMQ_QUEUES, GEN_SINGLE_LOGO_RABBITMQ_QUEUES, PS_RABBITMQ_QUEUES, BATCH_GPI_RABBITMQ_QUEUES, BATCH_GRI_RABBITMQ_QUEUES, BATCH_PS_RABBITMQ_QUEUES
|
||||
from app.schemas.response_template import ResponseModel
|
||||
|
||||
logger = logging.getLogger()
|
||||
@@ -14,10 +14,19 @@ router = APIRouter()
|
||||
@router.get("{id}")
|
||||
def test(id: int):
|
||||
data = {
|
||||
"SR_RABBITMQ_QUEUES message": SR_RABBITMQ_QUEUES,
|
||||
"GI_RABBITMQ_QUEUES": GI_RABBITMQ_QUEUES,
|
||||
"GPI_RABBITMQ_QUEUES": GPI_RABBITMQ_QUEUES,
|
||||
"GRI_RABBITMQ_QUEUES": GRI_RABBITMQ_QUEUES,
|
||||
"超分 SR_RABBITMQ_QUEUES": SR_RABBITMQ_QUEUES,
|
||||
"多视角 GMV_RABBITMQ_QUEUES": GMV_RABBITMQ_QUEUES,
|
||||
"pose transform PS_RABBITMQ_QUEUES": PS_RABBITMQ_QUEUES,
|
||||
"logan SLOGAN_RABBITMQ_QUEUES": SLOGAN_RABBITMQ_QUEUES,
|
||||
"image and single logo GI_RABBITMQ_QUEUES": GI_RABBITMQ_QUEUES,
|
||||
"to product image GPI_RABBITMQ_QUEUES": GPI_RABBITMQ_QUEUES,
|
||||
"relight GRI_RABBITMQ_QUEUES": GRI_RABBITMQ_QUEUES,
|
||||
|
||||
# batch
|
||||
"batch product BATCH_GPI_RABBITMQ_QUEUES": BATCH_GPI_RABBITMQ_QUEUES,
|
||||
"batch relight BATCH_GRI_RABBITMQ_QUEUES": BATCH_GRI_RABBITMQ_QUEUES,
|
||||
"batch pose transform BATCH_PS_RABBITMQ_QUEUES": BATCH_PS_RABBITMQ_QUEUES,
|
||||
|
||||
"JAVA_STREAM_API_URL": JAVA_STREAM_API_URL,
|
||||
"local_oss_server": OSS
|
||||
}
|
||||
|
||||
@@ -9,14 +9,14 @@ load_dotenv(os.path.join(BASE_DIR, '.env'))
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
PROJECT_NAME = os.getenv('PROJECT_NAME', 'FASTAPI BASE')
|
||||
SECRET_KEY = os.getenv('SECRET_KEY', '')
|
||||
API_PREFIX = ''
|
||||
BACKEND_CORS_ORIGINS = ['*']
|
||||
DATABASE_URL = os.getenv('SQL_DATABASE_URL', '')
|
||||
PROJECT_NAME: str = os.getenv('PROJECT_NAME', 'FASTAPI BASE')
|
||||
SECRET_KEY: str = os.getenv('SECRET_KEY', '')
|
||||
API_PREFIX: str = ''
|
||||
BACKEND_CORS_ORIGINS: list[str] = ['*']
|
||||
DATABASE_URL: str = os.getenv('SQL_DATABASE_URL', '')
|
||||
ACCESS_TOKEN_EXPIRE_SECONDS: int = 60 * 60 * 24 * 7 # Token expired after 7 days
|
||||
SECURITY_ALGORITHM = 'HS256'
|
||||
LOGGING_CONFIG_FILE = os.path.join(BASE_DIR, 'logging_env.py')
|
||||
SECURITY_ALGORITHM: str = 'HS256'
|
||||
LOGGING_CONFIG_FILE: str = os.path.join(BASE_DIR, 'logging_env.py')
|
||||
|
||||
|
||||
OSS = "minio"
|
||||
@@ -25,13 +25,20 @@ if DEBUG:
|
||||
LOGS_PATH = "logs/"
|
||||
CATEGORY_PATH = "service/attribute/config/descriptor/category/category_dis.csv"
|
||||
SEG_CACHE_PATH = "../seg_cache/"
|
||||
POSE_TRANSFORM_VIDEO_PATH = "../pose_transform_video/"
|
||||
RECOMMEND_PATH_PREFIX = "service/recommend/"
|
||||
CHROMADB_PATH = "./chromadb/"
|
||||
else:
|
||||
LOGS_PATH = "app/logs/"
|
||||
CATEGORY_PATH = "app/service/attribute/config/descriptor/category/category_dis.csv"
|
||||
SEG_CACHE_PATH = "/seg_cache/"
|
||||
POSE_TRANSFORM_VIDEO_PATH = "/pose_transform_video/"
|
||||
RECOMMEND_PATH_PREFIX = "app/service/recommend/"
|
||||
CHROMADB_PATH = "/chromadb/"
|
||||
|
||||
RABBITMQ_ENV = "-prod" # 生产环境
|
||||
# RABBITMQ_ENV = "-dev" # 开发环境
|
||||
|
||||
# RABBITMQ_ENV = "" # 生产环境
|
||||
RABBITMQ_ENV = "-dev" # 开发环境
|
||||
# RABBITMQ_ENV = "-local" # 本地测试环境
|
||||
|
||||
JAVA_STREAM_API_URL = os.getenv("JAVA_STREAM_API_URL", "https://api.aida.com.hk/api/third/party/receiveDesignResults")
|
||||
@@ -99,7 +106,7 @@ OPENAI_MODEL_LIST = {"gpt-3.5-turbo-0613",
|
||||
SR_MODEL_NAME = "super_resolution"
|
||||
SR_TRITON_URL = "10.1.1.240:10031"
|
||||
SR_MINIO_BUCKET = "aida-users"
|
||||
SR_RABBITMQ_QUEUES = f"SuperResolution{RABBITMQ_ENV}"
|
||||
SR_RABBITMQ_QUEUES = os.getenv("SR_RABBITMQ_QUEUES", f"SuperResolution{RABBITMQ_ENV}")
|
||||
|
||||
# GenerateImage service config
|
||||
FAST_GI_MODEL_URL = '10.1.1.243:10011'
|
||||
@@ -111,20 +118,20 @@ GI_MODEL_NAME = 'flux'
|
||||
GMV_MODEL_URL = '10.1.1.243:10081'
|
||||
GMV_MODEL_NAME = 'multi_view'
|
||||
|
||||
GMV_RABBITMQ_QUEUES = f"GenerateMultiView{RABBITMQ_ENV}"
|
||||
GMV_RABBITMQ_QUEUES = os.getenv("GMV_RABBITMQ_QUEUES", f"GenerateMultiView{RABBITMQ_ENV}")
|
||||
|
||||
GI_MINIO_BUCKET = "aida-users"
|
||||
GI_RABBITMQ_QUEUES = f"GenerateImage{RABBITMQ_ENV}"
|
||||
GI_RABBITMQ_QUEUES = os.getenv("GI_RABBITMQ_QUEUES", f"GenerateImage{RABBITMQ_ENV}")
|
||||
GI_SYS_IMAGE_URL = "aida-sys-image/generate_image/white_image.jpg"
|
||||
|
||||
# SLOGAN service config
|
||||
SLOGAN_RABBITMQ_QUEUES = f"Slogan{RABBITMQ_ENV}"
|
||||
SLOGAN_RABBITMQ_QUEUES = os.getenv("SLOGAN_RABBITMQ_QUEUES", f"Slogan{RABBITMQ_ENV}")
|
||||
|
||||
# Generate Single Logo service config
|
||||
GSL_MODEL_URL = '10.1.1.243:10041'
|
||||
GSL_MINIO_BUCKET = "aida-users"
|
||||
GSL_MODEL_NAME = 'stable_diffusion_xl_transparent'
|
||||
GEN_SINGLE_LOGO_RABBITMQ_QUEUES = f"GenSingleLogo{RABBITMQ_ENV}"
|
||||
GEN_SINGLE_LOGO_RABBITMQ_QUEUES = os.getenv("GEN_SINGLE_LOGO_RABBITMQ_QUEUES", f"GenSingleLogo{RABBITMQ_ENV}")
|
||||
|
||||
# Generate Product service config
|
||||
# GPI_RABBITMQ_QUEUES = os.getenv("GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES", f"ToProductImage{RABBITMQ_ENV}")
|
||||
@@ -132,17 +139,25 @@ GEN_SINGLE_LOGO_RABBITMQ_QUEUES = f"GenSingleLogo{RABBITMQ_ENV}"
|
||||
# GPI_MODEL_URL = '10.1.1.243:10051'
|
||||
|
||||
# Generate Product service config 旧版product img 模型
|
||||
GPI_RABBITMQ_QUEUES = f"ToProductImage{RABBITMQ_ENV}"
|
||||
GPI_RABBITMQ_QUEUES = os.getenv("GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES", f"ToProductImage{RABBITMQ_ENV}")
|
||||
BATCH_GPI_RABBITMQ_QUEUES = os.getenv("BATCH_GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES", f"BatchToProductImage{RABBITMQ_ENV}")
|
||||
GPI_MODEL_NAME_OVERALL = 'diffusion_ensemble_all'
|
||||
GPI_MODEL_NAME_SINGLE = 'stable_diffusion_1_5_cnet'
|
||||
GPI_MODEL_URL = '10.1.1.243:10051'
|
||||
|
||||
# Generate Single Logo service config
|
||||
GRI_RABBITMQ_QUEUES = f"Relight{RABBITMQ_ENV}"
|
||||
GRI_RABBITMQ_QUEUES = os.getenv("GEN_RELIGHT_IMAGE_RABBITMQ_QUEUES", f"Relight{RABBITMQ_ENV}")
|
||||
BATCH_GRI_RABBITMQ_QUEUES = os.getenv("BATCH_GEN_RELIGHT_IMAGE_RABBITMQ_QUEUES", f"BatchRelight{RABBITMQ_ENV}")
|
||||
GRI_MODEL_NAME_OVERALL = 'diffusion_relight_ensemble'
|
||||
GRI_MODEL_NAME_SINGLE = 'stable_diffusion_1_5_relight'
|
||||
GRI_MODEL_URL = '10.1.1.240:10051'
|
||||
|
||||
# Pose Transform service config
|
||||
|
||||
PS_RABBITMQ_QUEUES = os.getenv("PS_RABBITMQ_QUEUES", f"PoseTransform{RABBITMQ_ENV}")
|
||||
BATCH_PS_RABBITMQ_QUEUES = os.getenv("BATCH_PS_RABBITMQ_QUEUES", f"BatchPoseTransform{RABBITMQ_ENV}")
|
||||
PT_MODEL_URL = '10.1.1.243:10061'
|
||||
|
||||
# SEG service config
|
||||
SEGMENTATION = {
|
||||
"new_model_name": "seg_knet",
|
||||
@@ -152,6 +167,10 @@ SEGMENTATION = {
|
||||
}
|
||||
# ollama config
|
||||
OLLAMA_URL = "http://10.1.1.240:11434/api/embeddings"
|
||||
|
||||
# design batch
|
||||
BATCH_DESIGN_RABBITMQ_QUEUES = os.getenv("BATCH_DESIGN_RABBITMQ_QUEUES", f"DesignBatch{RABBITMQ_ENV}")
|
||||
|
||||
# DESIGN config
|
||||
DESIGN_MODEL_URL = '10.1.1.240:10000'
|
||||
AIDA_CLOTHING = "aida-clothing"
|
||||
@@ -189,3 +208,23 @@ PRIORITY_DICT = {
|
||||
}
|
||||
|
||||
QWEN_API_KEY = "sk-f31c29e61ac2498ba5e307aaa6dc10e0"
|
||||
|
||||
DB_CONFIG = {
|
||||
"host": "18.167.251.121",
|
||||
"port": 3306,
|
||||
"user": "root",
|
||||
"password": "QWa998345",
|
||||
"database": "aida",
|
||||
"charset": "utf8mb4"
|
||||
}
|
||||
|
||||
TABLE_CATEGORIES = {
|
||||
"female_dress": "female/dress",
|
||||
"female_outwear": "female/outwear",
|
||||
"female_trousers": "female/trousers",
|
||||
"female_skirt": "female/skirt",
|
||||
"female_blouse": "female/blouse",
|
||||
"male_tops": "male/tops",
|
||||
"male_bottoms": "male/bottoms",
|
||||
"male_outwear": "male/outwear"
|
||||
}
|
||||
|
||||
12
app/main.py
12
app/main.py
@@ -1,15 +1,17 @@
|
||||
import logging.config
|
||||
from http.client import HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
|
||||
import uvicorn
|
||||
from apscheduler.schedulers.background import BackgroundScheduler
|
||||
from apscheduler.triggers.cron import CronTrigger
|
||||
from fastapi import FastAPI
|
||||
from fastapi import HTTPException, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from app.api.api_route import router
|
||||
from app.core.config import settings
|
||||
from app.core.record_api_count import count_api_calls
|
||||
from app.schemas.response_template import ResponseModel
|
||||
from app.service.recommend.service import load_resources
|
||||
from logging_env import LOGGER_CONFIG_DICT
|
||||
|
||||
logging.config.dictConfig(LOGGER_CONFIG_DICT)
|
||||
@@ -17,6 +19,8 @@ logging.getLogger("pika").setLevel(logging.WARNING)
|
||||
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_application() -> FastAPI:
|
||||
application = FastAPI(
|
||||
@@ -51,5 +55,7 @@ async def http_exception_handler(request: Request, exc: HTTPException):
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||
|
||||
@@ -4,3 +4,8 @@ from pydantic import BaseModel
|
||||
class BrandDnaModel(BaseModel):
|
||||
image_url: str
|
||||
is_brand_dna: bool
|
||||
|
||||
|
||||
class GenerateBrandModel(BaseModel):
|
||||
user_id: str
|
||||
prompt: str
|
||||
|
||||
@@ -2,7 +2,6 @@ from pydantic import BaseModel
|
||||
|
||||
|
||||
class ChatRobotModel(BaseModel):
|
||||
gender: str
|
||||
message: str
|
||||
session_id: str
|
||||
user_id: int
|
||||
|
||||
6
app/schemas/clothing_seg.py
Normal file
6
app/schemas/clothing_seg.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ClothingSegModel(BaseModel):
|
||||
user_id: str
|
||||
image_data: list[dict]
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import List
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
@@ -36,3 +38,53 @@ class GenerateRelightImageModel(BaseModel):
|
||||
image_url: str
|
||||
direction: str
|
||||
product_type: str
|
||||
|
||||
|
||||
"""
|
||||
batch generate image
|
||||
"""
|
||||
|
||||
|
||||
# product任务子项
|
||||
class ProductItemModel(BaseModel):
|
||||
tasks_id: str
|
||||
image_strength: float
|
||||
prompt: str
|
||||
image_url: str
|
||||
product_type: str
|
||||
|
||||
|
||||
# product批处理 集合
|
||||
class BatchGenerateProductImageModel(BaseModel):
|
||||
batch_tasks_id: str
|
||||
user_id: str
|
||||
batch_data_list: List[ProductItemModel]
|
||||
|
||||
|
||||
# relight任务子项
|
||||
class RelightItemModel(BaseModel):
|
||||
tasks_id: str
|
||||
prompt: str
|
||||
image_url: str
|
||||
direction: str
|
||||
product_type: str
|
||||
|
||||
|
||||
# relight批处理集合
|
||||
class BatchGenerateRelightImageModel(BaseModel):
|
||||
batch_tasks_id: str
|
||||
user_id: str
|
||||
batch_data_list: List[RelightItemModel]
|
||||
|
||||
|
||||
"""
|
||||
agent tool generate image
|
||||
"""
|
||||
|
||||
|
||||
class AgentTollGenerateImageModel(BaseModel):
|
||||
prompt: str
|
||||
category: str
|
||||
gender: str
|
||||
version: str
|
||||
size: int
|
||||
|
||||
10
app/schemas/mannequin_edit.py
Normal file
10
app/schemas/mannequin_edit.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class MannequinModel(BaseModel):
|
||||
mannequins: str
|
||||
resize_pixel: float
|
||||
bucket_name: str
|
||||
mannequin_name: str
|
||||
top: int
|
||||
bottom: int
|
||||
14
app/schemas/pose_transform.py
Normal file
14
app/schemas/pose_transform.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class PoseTransformModel(BaseModel):
|
||||
image_url: str
|
||||
tasks_id: str
|
||||
pose_id: str
|
||||
|
||||
|
||||
class BatchPoseTransformModel(BaseModel):
|
||||
image_url: str
|
||||
tasks_id: str
|
||||
pose_id: str
|
||||
batch_size: int
|
||||
7
app/schemas/project_info_extraction.py
Normal file
7
app/schemas/project_info_extraction.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ProjectInfoExtractionModel(BaseModel):
|
||||
prompt: str
|
||||
image_list: list
|
||||
file_list: list
|
||||
@@ -3,3 +3,7 @@ from pydantic import BaseModel
|
||||
|
||||
class PromptGenerationImageModel(BaseModel):
|
||||
text: str
|
||||
|
||||
|
||||
class ImageRequest(BaseModel):
|
||||
img: str
|
||||
|
||||
@@ -9,9 +9,9 @@ import torch.nn.functional as F
|
||||
import tritonclient.http as httpclient
|
||||
from minio import Minio
|
||||
|
||||
from app.core.config import MINIO_URL, MINIO_ACCESS, MINIO_SECRET, MINIO_SECURE, DESIGN_MODEL_URL
|
||||
from app.core.config import MINIO_URL, MINIO_ACCESS, MINIO_SECRET, MINIO_SECURE, DESIGN_MODEL_URL, CATEGORY_PATH
|
||||
from app.schemas.brand_dna import BrandDnaModel
|
||||
from app.service.attribute.config import local_debug_const
|
||||
from app.service.attribute.config import local_debug_const, const
|
||||
from app.service.utils.generate_uuid import generate_uuid
|
||||
from app.service.utils.new_oss_client import oss_upload_image, oss_get_image
|
||||
|
||||
@@ -25,18 +25,18 @@ class BrandDna:
|
||||
self.sketch_bucket = "test"
|
||||
self.image_url = request_item.image_url
|
||||
self.is_brand_dna = request_item.is_brand_dna
|
||||
# self.attr_type = pd.read_csv(CATEGORY_PATH)
|
||||
self.attr_type = pd.read_csv(r"E:\workspace\trinity_client_aida\app\service\attribute\config\descriptor\category\category_dis.csv")
|
||||
self.attr_type = pd.read_csv(CATEGORY_PATH)
|
||||
# self.attr_type = pd.read_csv(r"E:\workspace\trinity_client_aida\app\service\attribute\config\descriptor\category\category_dis.csv")
|
||||
self.att_client = httpclient.InferenceServerClient(url=DESIGN_MODEL_URL)
|
||||
self.seg_client = httpclient.InferenceServerClient(url='10.1.1.243:30000')
|
||||
# self.const = const
|
||||
self.const = local_debug_const
|
||||
self.const = const
|
||||
# self.const = local_debug_const
|
||||
|
||||
# 获取结果
|
||||
def get_result(self):
|
||||
mask, image = self.get_seg_mask()
|
||||
cv2.imshow("", image)
|
||||
cv2.waitKey(0)
|
||||
# cv2.imshow("", image)
|
||||
# cv2.waitKey(0)
|
||||
|
||||
height, width, channels = image.shape
|
||||
result_dict = []
|
||||
@@ -50,8 +50,8 @@ class BrandDna:
|
||||
outwear_img[mask == value] = image[mask == value]
|
||||
outwear_mask_img[mask == value] = [0, 0, 255]
|
||||
|
||||
cv2.imshow("", outwear_img)
|
||||
cv2.waitKey(0)
|
||||
# cv2.imshow("", outwear_img)
|
||||
# cv2.waitKey(0)
|
||||
|
||||
# 预处理之后的input img
|
||||
preprocess_img = self.category_preprocess(outwear_img)
|
||||
@@ -89,8 +89,8 @@ class BrandDna:
|
||||
tops_img[mask == value] = image[mask == value]
|
||||
tops_mask_img[mask == value] = [0, 0, 255]
|
||||
|
||||
cv2.imshow("", tops_img)
|
||||
cv2.waitKey(0)
|
||||
# cv2.imshow("", tops_img)
|
||||
# cv2.waitKey(0)
|
||||
|
||||
# 预处理之后的input img
|
||||
preprocess_img = self.category_preprocess(tops_img)
|
||||
@@ -129,8 +129,8 @@ class BrandDna:
|
||||
bottoms_img[mask == value] = image[mask == value]
|
||||
bottoms_mask_img[mask == value] = [0, 0, 255]
|
||||
|
||||
cv2.imshow("", bottoms_img)
|
||||
cv2.waitKey(0)
|
||||
# cv2.imshow("", bottoms_img)
|
||||
# cv2.waitKey(0)
|
||||
|
||||
# 预处理之后的input img
|
||||
preprocess_img = self.category_preprocess(bottoms_img)
|
||||
@@ -327,7 +327,7 @@ if __name__ == '__main__':
|
||||
# result_url = service.get_result()
|
||||
# print(result_url)
|
||||
request_item = BrandDnaModel(
|
||||
image_url="aida-users/60/product_image/07cb5d5d-5022-44cc-b0d3-cc986cfebad1-2-60.png",
|
||||
image_url="aida-results/result_00006a48-e315-11ee-b7c8-b48351119060.png",
|
||||
is_brand_dna=True
|
||||
)
|
||||
service = BrandDna(request_item)
|
||||
|
||||
104
app/service/brand_dna/service_generate_brand_info.py
Normal file
104
app/service/brand_dna/service_generate_brand_info.py
Normal file
@@ -0,0 +1,104 @@
|
||||
import logging
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import tritonclient.grpc as grpcclient
|
||||
from langchain.output_parsers import ResponseSchema, StructuredOutputParser
|
||||
from langchain_community.chat_models import ChatTongyi
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
# from langchain_openai import ChatOpenAI
|
||||
from minio import Minio
|
||||
from tritonclient.utils import np_to_triton_dtype
|
||||
|
||||
from app.core.config import GI_MODEL_URL, MINIO_URL, MINIO_ACCESS, MINIO_SECRET, MINIO_SECURE, GI_MODEL_NAME
|
||||
from app.schemas.brand_dna import GenerateBrandModel
|
||||
from app.service.utils.generate_uuid import generate_uuid
|
||||
from app.service.utils.new_oss_client import oss_upload_image
|
||||
|
||||
|
||||
class GenerateBrandInfo:
|
||||
def __init__(self, request_data):
|
||||
# minio client init
|
||||
self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||
|
||||
# user info init
|
||||
self.user_id = request_data.user_id
|
||||
self.category = "brand_logo"
|
||||
# generate logo init
|
||||
self.grpc_client = grpcclient.InferenceServerClient(url=GI_MODEL_URL)
|
||||
self.image = np.random.randint(0, 256, (1024, 1024, 3), dtype=np.uint8)
|
||||
self.batch_size = 1
|
||||
self.mode = 'txt2img'
|
||||
|
||||
# llm generate brand info init
|
||||
self.model = ChatTongyi(model="qwen2.5-14b-instruct", api_key="sk-7658298c6b99443c98184a5e634fe6ab")
|
||||
|
||||
self.response_schemas = [
|
||||
ResponseSchema(name="brand_name", description="Brand name."),
|
||||
ResponseSchema(name="brand_slogan", description="Brand slogan."),
|
||||
ResponseSchema(name="brand_logo_prompt", description="prompt required for brand logo generation.")
|
||||
]
|
||||
self.output_parser = StructuredOutputParser.from_response_schemas(self.response_schemas)
|
||||
self.format_instructions = self.output_parser.get_format_instructions()
|
||||
self.prompt = PromptTemplate(
|
||||
template="你是一个时装品牌的设计师。根据用户输入提取出brand name,brand slogan,brand logo 描述。如果没有以上内容,需要你根据用户输入随意发挥。随后根据brand logo 描述生成一个prompt,这个prompt用于生成模型,prompt需要完全表达用户的想法并使用英文,使用简洁明了的单词不要过长。.\n{format_instructions}\n{question}",
|
||||
input_variables=["question"],
|
||||
partial_variables={"format_instructions": self.format_instructions}
|
||||
)
|
||||
self._input = self.prompt.format_prompt(question=request_data.prompt)
|
||||
|
||||
self.result_data = {}
|
||||
|
||||
def get_result(self):
|
||||
self.llm_generate_brand_info()
|
||||
self.generate_brand_logo()
|
||||
return self.result_data
|
||||
|
||||
def llm_generate_brand_info(self):
|
||||
output = self.model(self._input.to_messages())
|
||||
brand_data = self.output_parser.parse(output.content)
|
||||
self.result_data = brand_data
|
||||
self.generate_logo_prompt = brand_data['brand_logo_prompt']
|
||||
|
||||
def generate_brand_logo(self):
|
||||
prompts = [self.generate_logo_prompt] * self.batch_size
|
||||
modes = [self.mode] * self.batch_size
|
||||
images = [self.image.astype(np.float16)] * self.batch_size
|
||||
|
||||
text_obj = np.array(prompts, dtype="object").reshape((-1, 1))
|
||||
mode_obj = np.array(modes, dtype="object").reshape((-1, 1))
|
||||
image_obj = np.array(images, dtype=np.float16).reshape((-1, 1024, 1024, 3))
|
||||
|
||||
input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype))
|
||||
input_image = grpcclient.InferInput("input_image", image_obj.shape, np_to_triton_dtype(image_obj.dtype))
|
||||
input_mode = grpcclient.InferInput("mode", mode_obj.shape, np_to_triton_dtype(mode_obj.dtype))
|
||||
|
||||
input_text.set_data_from_numpy(text_obj)
|
||||
input_image.set_data_from_numpy(image_obj)
|
||||
input_mode.set_data_from_numpy(mode_obj)
|
||||
|
||||
inputs = [input_text, input_image, input_mode]
|
||||
result = self.grpc_client.infer(model_name=GI_MODEL_NAME, inputs=inputs)
|
||||
image = result.as_numpy("generated_image")
|
||||
image_result = cv2.cvtColor(np.squeeze(image.astype(np.uint8)), cv2.COLOR_RGB2BGR)
|
||||
logo_url = self.upload_logo_image(image_result, generate_uuid())
|
||||
self.result_data['brand_logo'] = logo_url
|
||||
|
||||
def upload_logo_image(self, image, object_name):
|
||||
try:
|
||||
_, img_byte_array = cv2.imencode('.jpg', image)
|
||||
object_name = f'{self.user_id}/{self.category}/{object_name}'
|
||||
req = oss_upload_image(oss_client=self.minio_client, bucket="aida-users", object_name=object_name, image_bytes=img_byte_array)
|
||||
image_url = f"aida-users/{object_name}"
|
||||
return image_url
|
||||
except Exception as e:
|
||||
logging.warning(f"upload_png_mask runtime exception : {e}")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
request_data = GenerateBrandModel(
|
||||
user_id="89",
|
||||
prompt="华为"
|
||||
)
|
||||
service = GenerateBrandInfo(request_data)
|
||||
print(service.get_result())
|
||||
32
app/service/brand_dna/test.py
Normal file
32
app/service/brand_dna/test.py
Normal file
@@ -0,0 +1,32 @@
|
||||
from dotenv import load_dotenv
|
||||
from langchain.output_parsers import StructuredOutputParser, ResponseSchema
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
# 加载.env文件的环境变量
|
||||
load_dotenv()
|
||||
|
||||
# 创建一个大语言模型,model指定了大语言模型的种类
|
||||
model = ChatOpenAI(model="qwen2.5-14b-instruct")
|
||||
|
||||
# 想要接收的响应模式
|
||||
response_schemas = [
|
||||
ResponseSchema(name="brand_name", description="Brand name."),
|
||||
ResponseSchema(name="brand_slogan", description="Brand slogan."),
|
||||
ResponseSchema(name="brand_logo_prompt", description="prompt required for brand logo generation.")
|
||||
]
|
||||
output_parser = StructuredOutputParser.from_response_schemas(response_schemas)
|
||||
format_instructions = output_parser.get_format_instructions()
|
||||
prompt = PromptTemplate(
|
||||
template="你是一个时装品牌的设计师。根据用户输入提取出brand name,brand slogan,brand logo 描述。如果没有以上内容,需要你根据用户输入随意发挥。随后根据brand logo 描述生成一个prompt,这个prompt用于生成模型.\n{format_instructions}\n{question}",
|
||||
input_variables=["question"],
|
||||
partial_variables={"format_instructions": format_instructions}
|
||||
)
|
||||
_input = prompt.format_prompt(question="brand name: cat home")
|
||||
|
||||
output = model(_input.to_messages())
|
||||
brand_data = output_parser.parse(output.content)
|
||||
|
||||
|
||||
def generate_logo(bucket_name, object_name, prompt):
|
||||
pass
|
||||
@@ -90,7 +90,6 @@ def chat(post_data):
|
||||
user_id = post_data.user_id
|
||||
session_id = post_data.session_id
|
||||
input_message = post_data.message
|
||||
gender = post_data.gender
|
||||
|
||||
# final_outputs = agent_executor(
|
||||
# {"input": input_message, "gender": gender},
|
||||
@@ -98,7 +97,7 @@ def chat(post_data):
|
||||
# session_key=f"buffer:{user_id}:{session_id}",
|
||||
# )
|
||||
|
||||
final_outputs = CallQWen.call_with_messages(input_message, gender)
|
||||
final_outputs = CallQWen.call_with_messages(input_message)
|
||||
# api_response = {
|
||||
# 'user_id': user_id,
|
||||
# 'session_id': session_id,
|
||||
|
||||
@@ -34,6 +34,39 @@ You may encounter the following types of questions:
|
||||
Be careful to use the tools, since you are actually a chat bot. Tools can only be used when essential.
|
||||
"""
|
||||
|
||||
FASHION_CHAT_BOT_PREFIX_TEMP = """
|
||||
You are a fashion design assistant with the following capabilities:
|
||||
1. Direct conversation: Answer general questions (e.g., greetings, opinions).
|
||||
2. Tool usage:
|
||||
- `get_image_from_vector_db`: Retrieve clothing items (requires gender parameter).
|
||||
- `internet_search`: Fetch real-time fashion trends.
|
||||
- `tutorial_tool`: Provide styling guides.
|
||||
|
||||
Key Rules:
|
||||
1. Tool Selection:
|
||||
- Use `get_image_from_vector_db` for clothing queries (e.g., "show men's jackets").
|
||||
- Use `internet_search` for time-sensitive queries (e.g., "2024 Paris Fashion Week trends").
|
||||
- Use `tutorial_tool` for educational requests (e.g., "how to layer outfits").
|
||||
|
||||
2. Gender Handling (for `get_image_from_vector_db` only):
|
||||
- Step 1: Check the **current user input** for gender keywords (e.g., "women/men/she/he"). If found, extract and pass as `gender`.
|
||||
- Step 2: If no gender in current input, scan the **chat history** for the most recent gender reference.
|
||||
- Step 3: If undetermined, default to `"unisex"`.
|
||||
|
||||
3. Output Format:
|
||||
- Direct replies: Keep responses under 20 words.
|
||||
- Tool calls:
|
||||
- Always include required parameters (e.g., `gender` for `get_image_from_vector_db`).
|
||||
- Auto-fill `gender` using the above rules if unspecified.
|
||||
|
||||
Examples:
|
||||
1. User: "Find red dresses for women"
|
||||
→ `get_image_from_vector_db(gender="female", query="dress")`
|
||||
2. User: "show men's jackets"
|
||||
→ `get_image_from_vector_db(gender="male", query="outwear")`
|
||||
3. User: "Show casual outfits"
|
||||
→ `get_image_from_vector_db(gender="unisex", query="casual outfits")`"""
|
||||
|
||||
TOOL_SELECT_SUFFIX = """
|
||||
Prior to proceeding, it is essential to carefully assess the question and select the appropriate tools or approach accordingly.
|
||||
For database-related questions, use SQL tools to identify relevant tables and query their schemas.
|
||||
|
||||
@@ -9,7 +9,7 @@ from app.core.config import *
|
||||
from app.service.chat_robot.script.callbacks.qwen_callback_handler import QWenCallbackHandler
|
||||
from app.service.chat_robot.script.database import CustomDatabase
|
||||
from app.service.chat_robot.script.prompt import FASHION_CHAT_BOT_PREFIX, TOOLS_FUNCTIONS_SUFFIX, TUTORIAL_TOOL_RETURN, \
|
||||
GET_LANGUAGE_PREFIX
|
||||
GET_LANGUAGE_PREFIX, FASHION_CHAT_BOT_PREFIX_TEMP
|
||||
from app.service.search_image_with_text.service import query
|
||||
|
||||
get_database_table_description = "Input is an empty string, output is a comma separated list of tables in the database."
|
||||
@@ -212,14 +212,15 @@ def get_assistant_response(messages):
|
||||
return response
|
||||
|
||||
|
||||
def call_with_messages(message, gender):
|
||||
def call_with_messages(message):
|
||||
global tool_info
|
||||
user_input = message
|
||||
print('\n')
|
||||
|
||||
messages = [
|
||||
{
|
||||
"content": FASHION_CHAT_BOT_PREFIX, # 系统message
|
||||
# "content": FASHION_CHAT_BOT_PREFIX, # 系统message
|
||||
"content": FASHION_CHAT_BOT_PREFIX_TEMP, # 修改后的系统message
|
||||
"role": "system"
|
||||
},
|
||||
{
|
||||
@@ -255,7 +256,7 @@ def call_with_messages(message, gender):
|
||||
tool_info = {"name": "search_from_internet", "role": "tool"}
|
||||
content = json.loads(assistant_output.tool_calls[0]['function']['arguments'])
|
||||
message = [
|
||||
{'role': 'assistant', 'content': content['query']}
|
||||
{'role': 'assistant', 'content': content['query'] if "query" in content.keys() else user_input}
|
||||
]
|
||||
tool_info['content'] = search_from_internet(message)
|
||||
flag = False
|
||||
@@ -282,6 +283,8 @@ def call_with_messages(message, gender):
|
||||
result_content = tool_info['content']
|
||||
elif assistant_output.tool_calls[0]['function']['name'] == 'get_image_from_vector_db':
|
||||
content = json.loads(assistant_output.tool_calls[0]['function']['arguments'])
|
||||
# todo 从历史对话中获取性别,目前无法获得性别时,默认使用female
|
||||
gender = content['gender'] if "gender" in content.keys() and content['gender'] != 'unisex' else 'female'
|
||||
tool_info = {"name": "get_image_from_vector_db", "role": "tool",
|
||||
'content': get_image_from_vector_db(gender, content['parameters']['content'] if "parameters" in content.keys() else content['content'])}
|
||||
flag = False
|
||||
|
||||
161
app/service/clothing_seg/service.py
Normal file
161
app/service/clothing_seg/service.py
Normal file
@@ -0,0 +1,161 @@
|
||||
import io
|
||||
import time
|
||||
from pprint import pprint
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import tritonclient.grpc as grpcclient
|
||||
from PIL import Image
|
||||
from minio import Minio
|
||||
from tritonclient.utils import np_to_triton_dtype
|
||||
|
||||
from app.core.config import *
|
||||
from app.schemas.clothing_seg import ClothingSegModel
|
||||
from app.service.design_fast.utils.design_ensemble import get_seg_result
|
||||
from app.service.utils.decorator import RunTime
|
||||
from app.service.utils.generate_uuid import generate_uuid
|
||||
from app.service.utils.new_oss_client import oss_get_image, oss_upload_image
|
||||
|
||||
minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||
|
||||
|
||||
class ClothingSeg:
|
||||
def __init__(self, request_data):
|
||||
self.image_data = request_data.image_data
|
||||
self.user_id = request_data.user_id
|
||||
self.triton_client = grpcclient.InferenceServerClient(url="10.1.1.243:10071")
|
||||
|
||||
@RunTime
|
||||
def get_result(self):
|
||||
self.read_image()
|
||||
self.clothing_seg()
|
||||
self.upload_image()
|
||||
for data in self.image_data:
|
||||
del data["image"]
|
||||
del data["clothing"]
|
||||
return self.image_data
|
||||
|
||||
@RunTime
|
||||
def upload_image(self):
|
||||
for data in self.image_data:
|
||||
data["clothing_url"] = []
|
||||
for clothing in data["clothing"]:
|
||||
object_name = f"{self.user_id}/clothing_seg/{generate_uuid()}.png"
|
||||
image_data = io.BytesIO()
|
||||
clothing.save(image_data, format="PNG")
|
||||
image_data.seek(0)
|
||||
image_bytes = image_data.read()
|
||||
oss_upload_image(oss_client=minio_client, bucket="aida-users", object_name=object_name, image_bytes=image_bytes)
|
||||
data["clothing_url"].append(f"aida-users/{object_name}")
|
||||
|
||||
@RunTime
|
||||
def read_image(self):
|
||||
for data in self.image_data:
|
||||
url = data["image_url"]
|
||||
image = oss_get_image(oss_client=minio_client, bucket=url.split("/", 1)[0], object_name=url.split("/", 1)[1], data_type="cv2")
|
||||
data["image"] = image
|
||||
|
||||
@RunTime
|
||||
def clothing_seg(self):
|
||||
for data in self.image_data:
|
||||
image_type = data["image_type"]
|
||||
image = data["image"]
|
||||
clothing_result = []
|
||||
if image_type == "sketch":
|
||||
if len(image.shape) == 2:
|
||||
image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
|
||||
seg_mask = get_seg_result(1, image[:, :, :3])
|
||||
else:
|
||||
seg_mask = get_seg_result(1, image[:, :, :3])
|
||||
temp = seg_mask != 0.0
|
||||
mask = (255 * (temp + 0).astype(np.uint8))
|
||||
x_min, y_min, x_max, y_max = get_bounding_box(mask)
|
||||
cropped_mask = mask[y_min:y_max + 1, x_min:x_max + 1]
|
||||
cropped_image = image[y_min:y_max + 1, x_min:x_max + 1]
|
||||
h, w = cropped_image.shape[:2]
|
||||
mask_pil = Image.fromarray(cropped_mask).convert("L")
|
||||
image_pil = Image.fromarray(cv2.cvtColor(cropped_image, cv2.COLOR_BGR2RGB))
|
||||
transparent_image = Image.new("RGBA", (w, h), (0, 0, 0, 0))
|
||||
transparent_image.paste(image_pil, (0, 0), mask=mask_pil)
|
||||
clothing_result.append(transparent_image)
|
||||
else:
|
||||
input_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||
input0_data = [input_image.astype(np.float32)] * 1
|
||||
input0_data = np.array(input0_data, dtype=np.float32)
|
||||
inputs = [
|
||||
grpcclient.InferInput(
|
||||
"INPUT0", input0_data.shape, np_to_triton_dtype(input0_data.dtype)
|
||||
),
|
||||
]
|
||||
|
||||
inputs[0].set_data_from_numpy(input0_data)
|
||||
|
||||
outputs = [
|
||||
# grpcclient.InferRequestedOutput("OUTPUT0"),
|
||||
grpcclient.InferRequestedOutput("OUTPUT1"),
|
||||
]
|
||||
|
||||
response = self.triton_client.infer("seg_clothing", inputs, request_id=str(1), outputs=outputs)
|
||||
# output0_data = response.as_numpy("OUTPUT0")
|
||||
# cv2.imwrite("output02.png", output0_data * 100)
|
||||
output1_data = response.as_numpy("OUTPUT1")
|
||||
for alpha in output1_data:
|
||||
alpha = cv2.resize(alpha, (image.shape[1], image.shape[0]), interpolation=cv2.INTER_CUBIC)
|
||||
x_min, y_min, x_max, y_max = get_bounding_box(alpha)
|
||||
cropped_mask = alpha[y_min:y_max + 1, x_min:x_max + 1]
|
||||
cropped_image = image[y_min:y_max + 1, x_min:x_max + 1]
|
||||
h, w = cropped_image.shape[:2]
|
||||
mask_pil = Image.fromarray(cropped_mask).convert("L")
|
||||
image_pil = Image.fromarray(cv2.cvtColor(cropped_image, cv2.COLOR_BGR2RGB))
|
||||
transparent_image = Image.new("RGBA", (w, h), (0, 0, 0, 0))
|
||||
transparent_image.paste(image_pil, (0, 0), mask=mask_pil)
|
||||
clothing_result.append(transparent_image)
|
||||
data["clothing"] = clothing_result
|
||||
|
||||
|
||||
@RunTime
|
||||
def get_bounding_box(mask):
|
||||
"""
|
||||
从仅包含 0 和 1 的掩码图像中获取边界框。
|
||||
|
||||
:param mask: 输入的掩码图像,二维 numpy 数组,元素为 0 或 1
|
||||
:return: 边界框坐标 (x_min, y_min, x_max, y_max)
|
||||
"""
|
||||
# 找到所有值不为 0 的像素的坐标
|
||||
rows, cols = np.where(mask != 0)
|
||||
|
||||
if len(rows) == 0 or len(cols) == 0:
|
||||
# 如果没有找到不为 0 的像素,返回全 0 的边界框
|
||||
return 0, 0, 0, 0
|
||||
|
||||
# 计算边界框的坐标
|
||||
x_min = np.min(cols)
|
||||
y_min = np.min(rows)
|
||||
x_max = np.max(cols)
|
||||
y_max = np.max(rows)
|
||||
|
||||
return x_min, y_min, x_max, y_max
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_data = ClothingSegModel(
|
||||
user_id=89,
|
||||
image_data=[
|
||||
# {
|
||||
# "image_url": "test/clothing_seg/dress.jpg",
|
||||
# "image_type": "sketch"
|
||||
# },
|
||||
# {
|
||||
# "image_url": "test/clothing_seg/skirt_559.jpg",
|
||||
# "image_type": "sketch"
|
||||
# },
|
||||
{
|
||||
"image_url": "aida-collection-element/87/Sketchboard/ab40e035-547a-48c5-9f97-1db7bf56ad77.jpg",
|
||||
"image_type": "sketch"
|
||||
}
|
||||
]
|
||||
)
|
||||
start_time = time.time()
|
||||
server = ClothingSeg(test_data)
|
||||
pprint(server.get_result())
|
||||
print(time.time() - start_time)
|
||||
@@ -5,9 +5,9 @@ from celery import Celery
|
||||
from minio import Minio
|
||||
|
||||
from app.core.config import *
|
||||
from app.service.design_batch.item import BodyItem, TopItem, BottomItem
|
||||
from app.service.design_batch.item import BodyItem, TopItem, BottomItem, AccessoriesItem
|
||||
from app.service.design_batch.utils.MQ import publish_status
|
||||
from app.service.design_batch.utils.organize import organize_body, organize_clothing
|
||||
from app.service.design_batch.utils.organize import organize_body, organize_clothing, organize_accessories
|
||||
from app.service.design_batch.utils.save_json import oss_upload_json
|
||||
from app.service.design_batch.utils.synthesis_item import update_base_size_priority, synthesis, synthesis_single
|
||||
|
||||
@@ -19,6 +19,8 @@ logging.getLogger('pika').setLevel(logging.WARNING)
|
||||
logger = logging.getLogger()
|
||||
minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||
|
||||
print("start")
|
||||
|
||||
|
||||
def process_item(item, basic):
|
||||
# 处理project中单个item
|
||||
@@ -28,9 +30,14 @@ def process_item(item, basic):
|
||||
elif item['type'].lower() in ['blouse', 'outwear', 'dress', 'tops']:
|
||||
top_server = TopItem(data=item, basic=basic, minio_client=minio_client)
|
||||
item_data = top_server.process()
|
||||
else:
|
||||
elif item['type'].lower() in ['skirt', 'trousers', 'bottoms']:
|
||||
bottom_server = BottomItem(data=item, basic=basic, minio_client=minio_client)
|
||||
item_data = bottom_server.process()
|
||||
elif item['type'].lower() in ['accessories']:
|
||||
bottom_server = AccessoriesItem(data=item, basic=basic, minio_client=minio_client)
|
||||
item_data = bottom_server.process()
|
||||
else:
|
||||
raise NotImplementedError(f"Item type {item['type']} not implemented")
|
||||
return item_data
|
||||
|
||||
|
||||
@@ -40,6 +47,10 @@ def process_layer(item, layers):
|
||||
body_layer = organize_body(item)
|
||||
layers.append(body_layer)
|
||||
return item['body_image'].size
|
||||
elif item['name'] == 'accessories':
|
||||
front_layer, back_layer = organize_accessories(item)
|
||||
layers.append(front_layer)
|
||||
layers.append(back_layer)
|
||||
else:
|
||||
front_layer, back_layer = organize_clothing(item)
|
||||
layers.append(front_layer)
|
||||
@@ -48,6 +59,9 @@ def process_layer(item, layers):
|
||||
|
||||
@celery_app.task
|
||||
def batch_design(objects_data, tasks_id, json_name):
|
||||
print(objects_data)
|
||||
print(tasks_id)
|
||||
print(json_name)
|
||||
object_response = []
|
||||
threads = []
|
||||
active_threads = 0
|
||||
@@ -71,7 +85,7 @@ def batch_design(objects_data, tasks_id, json_name):
|
||||
|
||||
for lay in layers:
|
||||
items_response['layers'].append({
|
||||
'image_category': lay['name'],
|
||||
'image_category': "body" if lay['name'] == 'mannequin' else lay['name'],
|
||||
'position': lay['position'],
|
||||
'priority': lay.get("priority", None),
|
||||
'resize_scale': lay['resize_scale'] if "resize_scale" in lay.keys() else None,
|
||||
@@ -121,6 +135,7 @@ def batch_design(objects_data, tasks_id, json_name):
|
||||
for t in threads:
|
||||
t.join()
|
||||
logger.debug(object_response)
|
||||
print(object_response)
|
||||
oss_upload_json(minio_client, object_response, json_name)
|
||||
publish_status(tasks_id, "ok", json_name)
|
||||
return object_response
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from app.service.design_batch.pipeline import *
|
||||
from app.service.design_fast.pipeline import LoadImage, KeyPoint, Segmentation, Color, PrintPainting, Scaling, Split, LoadBodyImage, ContourDetection
|
||||
|
||||
|
||||
class BaseItem:
|
||||
@@ -9,6 +9,27 @@ class BaseItem:
|
||||
self.result.update(basic)
|
||||
|
||||
|
||||
class AccessoriesItem(BaseItem):
|
||||
def __init__(self, data, basic, minio_client):
|
||||
super().__init__(data, basic)
|
||||
self.Accessories_pipeline = [
|
||||
LoadImage(minio_client),
|
||||
# KeyPoint(),
|
||||
ContourDetection(),
|
||||
# Segmentation(minio_client),
|
||||
# BackPerspective(minio_client),
|
||||
Color(minio_client),
|
||||
PrintPainting(minio_client),
|
||||
Scaling(),
|
||||
Split(minio_client)
|
||||
]
|
||||
|
||||
def process(self):
|
||||
for item in self.Accessories_pipeline:
|
||||
self.result = item(self.result)
|
||||
return self.result
|
||||
|
||||
|
||||
class TopItem(BaseItem):
|
||||
def __init__(self, data, basic, minio_client):
|
||||
super().__init__(data, basic)
|
||||
@@ -16,6 +37,7 @@ class TopItem(BaseItem):
|
||||
LoadImage(minio_client),
|
||||
KeyPoint(),
|
||||
Segmentation(minio_client),
|
||||
# BackPerspective(minio_client),
|
||||
Color(minio_client),
|
||||
PrintPainting(minio_client),
|
||||
Scaling(),
|
||||
@@ -35,7 +57,8 @@ class BottomItem(BaseItem):
|
||||
LoadImage(minio_client),
|
||||
KeyPoint(),
|
||||
ContourDetection(),
|
||||
# Segmentation(),
|
||||
Segmentation(minio_client),
|
||||
# BackPerspective(minio_client),
|
||||
Color(minio_client),
|
||||
PrintPainting(minio_client),
|
||||
Scaling(),
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from .back_perspective import BackPerspective
|
||||
from .color import Color
|
||||
from .contour_detection import ContourDetection
|
||||
from .keypoint import KeyPoint
|
||||
@@ -13,6 +14,7 @@ __all__ = [
|
||||
'KeyPoint',
|
||||
'ContourDetection',
|
||||
'Segmentation',
|
||||
'BackPerspective',
|
||||
'Color',
|
||||
'PrintPainting',
|
||||
'Scaling',
|
||||
|
||||
79
app/service/design_batch/pipeline/back_perspective.py
Normal file
79
app/service/design_batch/pipeline/back_perspective.py
Normal file
@@ -0,0 +1,79 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from app.service.design_fast.utils.design_ensemble import get_seg_result
|
||||
from app.service.utils.new_oss_client import oss_upload_image
|
||||
|
||||
|
||||
class BackPerspective:
|
||||
def __init__(self, minio_client):
|
||||
self.minio_client = minio_client
|
||||
|
||||
def __call__(self, result):
|
||||
|
||||
# 如果sketch为系统图 查看是否有对应的 背后视角图
|
||||
if result['path'].split('/')[0] == 'aida-sys-image':
|
||||
file_path = result['path'].replace("images", 'images_back', 1)
|
||||
if self.is_file_exists(bucket_name='aida-sys-image', file_name=file_path[file_path.find('/') + 1:]):
|
||||
result['back_perspective_url'] = file_path
|
||||
return result
|
||||
else:
|
||||
seg_result = get_seg_result("1", result['image'])[0]
|
||||
elif result['name'] in ['blouse', 'outwear', 'dress', 'tops']:
|
||||
seg_result = result['seg_result']
|
||||
else:
|
||||
seg_result = get_seg_result("1", result['image'])[0]
|
||||
|
||||
m = self.thicken_contours_and_display(seg_result, thickness=10, color=(0, 0, 0))
|
||||
back_sketch = result['image'].copy()
|
||||
back_sketch[m > 100] = 255
|
||||
# 上传背后视角图
|
||||
_, img_encoded = cv2.imencode(".jpg", back_sketch)
|
||||
|
||||
resp = oss_upload_image(self.minio_client, bucket='test', object_name=result['path'], image_bytes=img_encoded.tobytes())
|
||||
result['back_perspective_url'] = f"{resp.bucket_name}/{resp.object_name}"
|
||||
return result
|
||||
|
||||
def thicken_contours_and_display(self, mask, thickness=10, color=(0, 0, 0)):
|
||||
mask = mask.astype(np.uint8) * 255
|
||||
# 查找轮廓
|
||||
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||
|
||||
# 创建一个彩色副本用于绘制轮廓
|
||||
mask_color = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR)
|
||||
|
||||
def thicken_contour_inward(contour, thick):
|
||||
# 创建一个空白的黑色图像与原始掩码大小相同
|
||||
blank = np.zeros_like(mask)
|
||||
# 在空白图像上绘制白色的轮廓
|
||||
cv2.drawContours(blank, [contour], -1, 255, thickness=thick)
|
||||
# 找到轮廓的中心(可以用重心等方法近似)
|
||||
M = cv2.moments(contour)
|
||||
cx = int(M['m10'] / M['m00'])
|
||||
cy = int(M['m01'] / M['m00'])
|
||||
# 进行距离变换,离中心越近的值越小
|
||||
dist_transform = cv2.distanceTransform(255 - blank, cv2.DIST_L2, 5)
|
||||
# 根据距离变换的值来决定是否保留像素,离中心近的像素更容易被保留
|
||||
result = np.zeros_like(mask)
|
||||
for i in range(dist_transform.shape[0]):
|
||||
for j in range(dist_transform.shape[1]):
|
||||
if dist_transform[i, j] < thick:
|
||||
result[i, j] = 255
|
||||
return result
|
||||
|
||||
for contour in contours:
|
||||
thickened_contour = thicken_contour_inward(contour, thickness)
|
||||
mask_color[thickened_contour > 0] = color
|
||||
|
||||
_, binary_result = cv2.threshold(mask_color, 127, 255, cv2.THRESH_BINARY)
|
||||
|
||||
# 转换为掩码形式
|
||||
mask_result = cv2.cvtColor(binary_result, cv2.COLOR_BGR2GRAY)
|
||||
return mask_result
|
||||
|
||||
def is_file_exists(self, bucket_name, file_name):
|
||||
try:
|
||||
self.minio_client.stat_object(bucket_name, file_name)
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
@@ -14,14 +14,39 @@ class Color:
|
||||
|
||||
def __call__(self, result):
|
||||
dim_image_h, dim_image_w = result['image'].shape[0:2]
|
||||
# 渐变色
|
||||
if "gradient" in result.keys() and result['gradient'] != "":
|
||||
bucket_name = result['gradient'].split('/')[0]
|
||||
object_name = result['gradient'][result['gradient'].find('/') + 1:]
|
||||
pattern = self.get_gradient(bucket_name=bucket_name, object_name=object_name)
|
||||
resize_pattern = cv2.resize(pattern, (dim_image_w, dim_image_h), interpolation=cv2.INTER_AREA)
|
||||
# 无色
|
||||
elif "color" not in result.keys() or result['color'] == "":
|
||||
result['final_image'] = result['pattern_image'] = result['single_image'] = result['image']
|
||||
result['alpha'] = 100 / 255.0
|
||||
return result
|
||||
# 正常颜色
|
||||
else:
|
||||
pattern = self.get_pattern(result['color'])
|
||||
resize_pattern = cv2.resize(pattern, (dim_image_w, dim_image_h), interpolation=cv2.INTER_AREA)
|
||||
|
||||
if "partial_color" in result.keys() and result['partial_color'] != "":
|
||||
bucket_name = result['partial_color'].split('/')[0]
|
||||
object_name = result['partial_color'][result['partial_color'].find('/') + 1:]
|
||||
partial_color = oss_get_image(oss_client=self.minio_client, bucket=bucket_name, object_name=object_name, data_type="cv2")
|
||||
h, w = partial_color.shape[0:2]
|
||||
resize_pattern = cv2.resize(resize_pattern, (w, h), interpolation=cv2.INTER_AREA)
|
||||
# 分离出 png 图的 alpha 通道
|
||||
alpha_channel = partial_color[:, :, 3]
|
||||
# 提取 png 图的 RGB 通道
|
||||
png_rgb = partial_color[:, :, :3]
|
||||
# 创建一个与 cv 图大小相同的掩码,用于指示哪些像素需要替换
|
||||
mask = alpha_channel > 0
|
||||
# 将掩码扩展为 3 通道,以便与 cv 图进行逐元素操作
|
||||
mask_3ch = np.stack([mask] * 3, axis=-1)
|
||||
# 根据掩码将 png 图的颜色覆盖到 cv 图上
|
||||
resize_pattern[mask_3ch] = png_rgb[mask_3ch]
|
||||
resize_pattern = cv2.resize(resize_pattern, (dim_image_w, dim_image_h), interpolation=cv2.INTER_AREA)
|
||||
closed_mo = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
|
||||
gray_mo = np.expand_dims(result['gray'], axis=2).repeat(3, axis=2)
|
||||
get_image_fir = resize_pattern * (closed_mo / 255) * (gray_mo / 255)
|
||||
|
||||
@@ -4,7 +4,8 @@ import numpy as np
|
||||
from pymilvus import MilvusClient
|
||||
|
||||
from app.core.config import *
|
||||
from app.service.design_batch.utils.design_ensemble import get_keypoint_result
|
||||
from app.service.design_fast.utils.design_ensemble import get_keypoint_result
|
||||
from app.service.utils.decorator import ClassCallRunTime, RunTime
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -16,14 +17,15 @@ class KeyPoint:
|
||||
def get_name(cls):
|
||||
return cls.name
|
||||
|
||||
@ClassCallRunTime
|
||||
def __call__(self, result):
|
||||
if result['name'] in ['blouse', 'skirt', 'dress', 'outwear', 'trousers', 'tops', 'bottoms']: # 查询是否有数据 且类别相同 相同则直接读 不同则推理后更新
|
||||
# result['clothes_keypoint'] = self.infer_keypoint_result(result)
|
||||
site = 'up' if result['name'] in ['blouse', 'outwear', 'dress', 'tops'] else 'down'
|
||||
# keypoint_cache = search_keypoint_cache(result["image_id"], site)
|
||||
keypoint_cache = self.keypoint_cache(result, site)
|
||||
# keypoint_cache = self.keypoint_cache(result, site)
|
||||
keypoint_cache = False
|
||||
# 取消向量查询 直接过模型推理
|
||||
# keypoint_cache = False
|
||||
if keypoint_cache is False:
|
||||
keypoint_infer_result, site = self.infer_keypoint_result(result)
|
||||
result['clothes_keypoint'] = self.save_keypoint_cache(result["image_id"], keypoint_infer_result, site)
|
||||
@@ -87,7 +89,7 @@ class KeyPoint:
|
||||
logger.info(f"save keypoint cache milvus error : {e}")
|
||||
return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
|
||||
|
||||
# @ RunTime
|
||||
@RunTime
|
||||
def keypoint_cache(self, result, site):
|
||||
try:
|
||||
client = MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS)
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
import io
|
||||
import logging
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from app.service.utils.new_oss_client import oss_get_image
|
||||
|
||||
@@ -71,6 +74,8 @@ class LoadImage:
|
||||
keypoint = 'head_point'
|
||||
elif name == 'earring':
|
||||
keypoint = 'ear_point'
|
||||
elif name == 'accessories':
|
||||
keypoint = "accessories"
|
||||
else:
|
||||
raise KeyError(f"{name} does not belong to item category list: blouse, outwear, dress, trousers, skirt, "
|
||||
f"bag, shoes, hairstyle, earring.")
|
||||
|
||||
@@ -15,8 +15,25 @@ class PrintPainting:
|
||||
single_print = result['print']['single']
|
||||
overall_print = result['print']['overall']
|
||||
element_print = result['print']['element']
|
||||
partial_path = result['print']['partial'] if 'partial' in result['print'] else None
|
||||
result['single_image'] = None
|
||||
result['print_image'] = None
|
||||
# TODO 给result['pattern_image'] resize 到resize_scale的大小
|
||||
# TODO 给result['mask'] resize 到resize_scale的大小
|
||||
|
||||
if result['resize_scale'][0] == 1.0 and result['resize_scale'][1] == 1.0:
|
||||
pass
|
||||
else:
|
||||
height, width = result['pattern_image'].shape[:2]
|
||||
new_width = int(width * result['resize_scale'][0])
|
||||
new_height = int(height * result['resize_scale'][1])
|
||||
|
||||
result['pattern_image'] = cv2.resize(result['pattern_image'], (new_width, new_height))
|
||||
result['final_image'] = cv2.resize(result['final_image'], (new_width, new_height))
|
||||
result['mask'] = cv2.resize(result['mask'], (new_width, new_height))
|
||||
result['gray'] = cv2.resize(result['gray'], (new_width, new_height))
|
||||
|
||||
print(1)
|
||||
if overall_print['print_path_list']:
|
||||
painting_dict = {'dim_image_h': result['pattern_image'].shape[0], 'dim_image_w': result['pattern_image'].shape[1]}
|
||||
result['print_image'] = result['pattern_image']
|
||||
@@ -39,7 +56,7 @@ class PrintPainting:
|
||||
for i in range(len(single_print['print_path_list'])):
|
||||
image, image_mode = self.read_image(single_print['print_path_list'][i])
|
||||
if image_mode == "RGBA":
|
||||
new_size = (int(image.width * single_print['print_scale_list'][i]), int(image.height * single_print['print_scale_list'][i]))
|
||||
new_size = (int(result['pattern_image'].shape[1] * single_print['print_scale_list'][i][0]), int(result['pattern_image'].shape[0] * single_print['print_scale_list'][i][1]))
|
||||
|
||||
mask = image.split()[3]
|
||||
resized_source = image.resize(new_size)
|
||||
@@ -62,9 +79,12 @@ class PrintPainting:
|
||||
mask = np.expand_dims(mask, axis=2)
|
||||
mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR)
|
||||
mask = cv2.bitwise_not(mask)
|
||||
|
||||
mask = cv2.resize(mask, (int(result['final_image'].shape[1] * single_print['print_scale_list'][i][0]), int(result['final_image'].shape[0] * single_print['print_scale_list'][i][1])))
|
||||
image = cv2.resize(image, (int(result['final_image'].shape[1] * single_print['print_scale_list'][i][0]), int(result['final_image'].shape[0] * single_print['print_scale_list'][i][1])))
|
||||
# 旋转后的坐标需要重新算
|
||||
rotate_mask, _ = self.img_rotate(mask, single_print['print_angle_list'][i], single_print['print_scale_list'][i])
|
||||
rotate_image, rotated_new_size = self.img_rotate(image, single_print['print_angle_list'][i], single_print['print_scale_list'][i])
|
||||
rotate_mask, _ = self.img_rotate(mask, single_print['print_angle_list'][i])
|
||||
rotate_image, rotated_new_size = self.img_rotate(image, single_print['print_angle_list'][i])
|
||||
# x, y = int(result['print']['location'][i][0] - rotated_new_size[0] - (rotate_mask.shape[0] - image.shape[0]) / 2), int(result['print']['location'][i][1] - rotated_new_size[1] - (rotate_mask.shape[1] - image.shape[1]) / 2)
|
||||
x, y = int(single_print['location'][i][0] - rotated_new_size[0]), int(single_print['location'][i][1] - rotated_new_size[1])
|
||||
|
||||
@@ -143,7 +163,7 @@ class PrintPainting:
|
||||
for i in range(len(element_print['element_path_list'])):
|
||||
image, image_mode = self.read_image(element_print['element_path_list'][i])
|
||||
if image_mode == "RGBA":
|
||||
new_size = (int(image.width * element_print['element_scale_list'][i]), int(image.height * element_print['element_scale_list'][i]))
|
||||
new_size = (int(result['final_image'].shape[1] * element_print['element_scale_list'][i][0]), int(result['final_image'].shape[0] * element_print['element_scale_list'][i][1]))
|
||||
|
||||
mask = image.split()[3]
|
||||
resized_source = image.resize(new_size)
|
||||
@@ -165,9 +185,11 @@ class PrintPainting:
|
||||
mask = np.expand_dims(mask, axis=2)
|
||||
mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR)
|
||||
mask = cv2.bitwise_not(mask)
|
||||
mask = cv2.resize(mask, (int(result['final_image'].shape[1] * single_print['print_scale_list'][i][0]), int(result['final_image'].shape[0] * single_print['print_scale_list'][i][1])))
|
||||
image = cv2.resize(image, (int(result['final_image'].shape[1] * single_print['print_scale_list'][i][0]), int(result['final_image'].shape[0] * single_print['print_scale_list'][i][1])))
|
||||
# 旋转后的坐标需要重新算
|
||||
rotate_mask, _ = self.img_rotate(mask, element_print['element_angle_list'][i], element_print['element_scale_list'][i])
|
||||
rotate_image, rotated_new_size = self.img_rotate(image, element_print['element_angle_list'][i], element_print['element_scale_list'][i])
|
||||
rotate_mask, _ = self.img_rotate(mask, element_print['element_angle_list'][i])
|
||||
rotate_image, rotated_new_size = self.img_rotate(image, element_print['element_angle_list'][i])
|
||||
# x, y = int(result['print']['location'][i][0] - rotated_new_size[0] - (rotate_mask.shape[0] - image.shape[0]) / 2), int(result['print']['location'][i][1] - rotated_new_size[1] - (rotate_mask.shape[1] - image.shape[1]) / 2)
|
||||
x, y = int(element_print['location'][i][0] - rotated_new_size[0]), int(element_print['location'][i][1] - rotated_new_size[1])
|
||||
|
||||
@@ -241,6 +263,45 @@ class PrintPainting:
|
||||
temp_fg = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
|
||||
tmp2 = (result['final_image'] * (temp_fg / 255)).astype(np.uint8)
|
||||
result['single_image'] = cv2.add(tmp1, tmp2)
|
||||
|
||||
if partial_path:
|
||||
print_background = np.zeros((result['pattern_image'].shape[0], result['pattern_image'].shape[1], 3), dtype=np.uint8)
|
||||
mask_background = np.zeros((result['pattern_image'].shape[0], result['pattern_image'].shape[1], 3), dtype=np.uint8)
|
||||
image, image_mode = self.read_image(partial_path)
|
||||
if image_mode == "RGBA":
|
||||
new_size = (result['pattern_image'].shape[1], result['pattern_image'].shape[0])
|
||||
|
||||
mask = image.split()[3]
|
||||
resized_source = image.resize(new_size)
|
||||
resized_source_mask = mask.resize(new_size)
|
||||
|
||||
# rotated_resized_source = resized_source.rotate(-partial_print['print_angle_list'][i])
|
||||
# rotated_resized_source_mask = resized_source_mask.rotate(-partial_print['print_angle_list'][i])
|
||||
|
||||
source_image_pil = Image.fromarray(cv2.cvtColor(print_background, cv2.COLOR_BGR2RGB))
|
||||
source_image_pil_mask = Image.fromarray(cv2.cvtColor(mask_background, cv2.COLOR_BGR2RGB))
|
||||
|
||||
source_image_pil.paste(resized_source, (0, 0), resized_source)
|
||||
source_image_pil_mask.paste(resized_source_mask, (0, 0), resized_source_mask)
|
||||
|
||||
print_background = cv2.cvtColor(np.array(source_image_pil), cv2.COLOR_RGBA2BGR)
|
||||
mask_background = cv2.cvtColor(np.array(source_image_pil_mask), cv2.COLOR_RGBA2BGR)
|
||||
ret, mask_background = cv2.threshold(mask_background, 124, 255, cv2.THRESH_BINARY)
|
||||
print_mask = cv2.bitwise_and(result['mask'], cv2.cvtColor(mask_background, cv2.COLOR_BGR2GRAY))
|
||||
img_fg = cv2.bitwise_or(print_background, print_background, mask=print_mask)
|
||||
# TODO element 丢失信息
|
||||
three_channel_image = cv2.merge([cv2.bitwise_not(print_mask), cv2.bitwise_not(print_mask), cv2.bitwise_not(print_mask)])
|
||||
img_bg = cv2.bitwise_and(result['final_image'], three_channel_image)
|
||||
# mask_mo = np.expand_dims(print_mask, axis=2).repeat(3, axis=2)
|
||||
# gray_mo = np.expand_dims(result['gray'], axis=2).repeat(3, axis=2)
|
||||
# img_fg = (img_fg * (mask_mo / 255) * (gray_mo / 255)).astype(np.uint8)
|
||||
result['final_image'] = cv2.add(img_bg, img_fg)
|
||||
canvas = np.full_like(result['final_image'], 255)
|
||||
temp_bg = np.expand_dims(cv2.bitwise_not(result['mask']), axis=2).repeat(3, axis=2)
|
||||
tmp1 = (canvas * (temp_bg / 255)).astype(np.uint8)
|
||||
temp_fg = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
|
||||
tmp2 = (result['final_image'] * (temp_fg / 255)).astype(np.uint8)
|
||||
result['single_image'] = cv2.add(tmp1, tmp2)
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
@@ -360,10 +421,10 @@ class PrintPainting:
|
||||
return print_image
|
||||
|
||||
def get_print(self, print_dict):
|
||||
if 'print_scale_list' not in print_dict.keys() or print_dict['print_scale_list'][0] < 0.3:
|
||||
if 'print_scale_list' not in print_dict.keys() or print_dict['print_scale_list'][0][0] < 0.3:
|
||||
print_dict['scale'] = 0.3
|
||||
else:
|
||||
print_dict['scale'] = print_dict['print_scale_list'][0]
|
||||
print_dict['scale'] = print_dict['print_scale_list'][0][0]
|
||||
|
||||
bucket_name = print_dict['print_path_list'][0].split("/", 1)[0]
|
||||
object_name = print_dict['print_path_list'][0].split("/", 1)[1]
|
||||
@@ -386,8 +447,9 @@ class PrintPainting:
|
||||
# y_offset = random.randint(0, image.shape[1] - image_size_w)
|
||||
|
||||
# 1.拿到偏移量后和resize后的print宽高取余 得到真正偏移量
|
||||
x_offset = print_w - int(location[0][1] % print_w)
|
||||
y_offset = print_w - int(location[0][0] % print_h)
|
||||
# 偏移量增加2分之print.w 使坐标位于图中间 如果要位于左上角删除+ print_w // 2 即可
|
||||
x_offset = print_w - int(location[0][1] % print_w) + print_w // 2
|
||||
y_offset = print_h - int(location[0][0] % print_h) + print_h // 2
|
||||
|
||||
# y_offset = int(location[0][0])
|
||||
# x_offset = int(location[0][1])
|
||||
@@ -409,7 +471,7 @@ class PrintPainting:
|
||||
return high, low
|
||||
|
||||
@staticmethod
|
||||
def img_rotate(image, angel, scale):
|
||||
def img_rotate(image, angel):
|
||||
"""顺时针旋转图像任意角度
|
||||
|
||||
Args:
|
||||
@@ -424,7 +486,7 @@ class PrintPainting:
|
||||
center = (w // 2, h // 2)
|
||||
# if type(angel) is not int:
|
||||
# angel = 0
|
||||
M = cv2.getRotationMatrix2D(center, -angel, scale)
|
||||
M = cv2.getRotationMatrix2D(center, -angel, 1)
|
||||
# 调整旋转后的图像长宽
|
||||
rotated_h = int((w * np.abs(M[0, 1]) + (h * np.abs(M[0, 0]))))
|
||||
rotated_w = int((h * np.abs(M[0, 1]) + (w * np.abs(M[0, 0]))))
|
||||
@@ -433,7 +495,7 @@ class PrintPainting:
|
||||
# 旋转图像
|
||||
rotated_img = cv2.warpAffine(image, M, (rotated_w, rotated_h))
|
||||
|
||||
return rotated_img, ((rotated_img.shape[1] - image.shape[1] * scale) // 2, (rotated_img.shape[0] - image.shape[0] * scale) // 2)
|
||||
return rotated_img, ((rotated_img.shape[1] - image.shape[1]) // 2, (rotated_img.shape[0] - image.shape[0]) // 2)
|
||||
# return rotated_img, (0, 0)
|
||||
|
||||
@staticmethod
|
||||
@@ -442,8 +504,11 @@ class PrintPainting:
|
||||
angle: 旋转的角度
|
||||
crop: 是否需要进行裁剪,布尔向量
|
||||
"""
|
||||
if not isinstance(crop, bool):
|
||||
raise ValueError("The 'crop' parameter must be a boolean.")
|
||||
|
||||
crop_image = lambda img, x0, y0, w, h: img[y0:y0 + h, x0:x0 + w]
|
||||
w, h = img.shape[:2]
|
||||
h, w = img.shape[:2]
|
||||
# 旋转角度的周期是360°
|
||||
angle %= 360
|
||||
# 计算仿射变换矩阵
|
||||
@@ -455,7 +520,7 @@ class PrintPainting:
|
||||
if crop:
|
||||
# 裁剪角度的等效周期是180°
|
||||
angle_crop = angle % 180
|
||||
if angle > 90:
|
||||
if angle_crop > 90:
|
||||
angle_crop = 180 - angle_crop
|
||||
# 转化角度为弧度
|
||||
theta = angle_crop * np.pi / 180
|
||||
|
||||
@@ -46,4 +46,16 @@ class Scaling:
|
||||
result['scale'] = result['scale_bag']
|
||||
elif result['keypoint'] == 'ear_point':
|
||||
result['scale'] = result['scale_earrings']
|
||||
elif result['keypoint'] == 'accessories':
|
||||
# 由于没有识别配饰keypoint的模型 所以统一将配饰的两个关键点设定为 (0,0) (0,img.width)
|
||||
# 模特的关键点设定为(0,0) (0,320/2) 距离比例简写为 160 / img.width
|
||||
distance_clo = result['img_shape'][1]
|
||||
distance_bdy = 320 / 2
|
||||
|
||||
if distance_clo == 0:
|
||||
result['scale'] = 1
|
||||
else:
|
||||
result['scale'] = distance_bdy / distance_clo
|
||||
else:
|
||||
result['scale'] = 1
|
||||
return result
|
||||
|
||||
@@ -5,7 +5,8 @@ import cv2
|
||||
import numpy as np
|
||||
|
||||
from app.core.config import SEG_CACHE_PATH
|
||||
from app.service.design_batch.utils.design_ensemble import get_seg_result
|
||||
from app.service.design_fast.utils.design_ensemble import get_seg_result
|
||||
from app.service.utils.decorator import ClassCallRunTime
|
||||
from app.service.utils.new_oss_client import oss_get_image
|
||||
|
||||
logger = logging.getLogger()
|
||||
@@ -15,6 +16,7 @@ class Segmentation:
|
||||
def __init__(self, minio_client):
|
||||
self.minio_client = minio_client
|
||||
|
||||
@ClassCallRunTime
|
||||
def __call__(self, result):
|
||||
if "seg_mask_url" in result.keys() and result['seg_mask_url'] != "":
|
||||
seg_mask = oss_get_image(oss_client=self.minio_client, bucket=result['seg_mask_url'].split('/')[0], object_name=result['seg_mask_url'][result['seg_mask_url'].find('/') + 1:], data_type="cv2")
|
||||
@@ -31,24 +33,37 @@ class Segmentation:
|
||||
result['back_mask'] = np.array(green_mask, dtype=np.uint8) * 255
|
||||
result['mask'] = result['front_mask'] + result['back_mask']
|
||||
else:
|
||||
# 本地查询seg 缓存是否存在
|
||||
_, seg_result = self.load_seg_result(result["image_id"])
|
||||
result['seg_result'] = seg_result
|
||||
if not _:
|
||||
# preview 过模型 不缓存
|
||||
if "preview_submit" in result.keys() and result['preview_submit'] == "preview":
|
||||
# 推理获得seg 结果
|
||||
seg_result = get_seg_result(result["image_id"], result['image'])[0]
|
||||
seg_result = get_seg_result(result["image_id"], result['image'])
|
||||
# submit 过模型 缓存
|
||||
elif "preview_submit" in result.keys() and result['preview_submit'] == "submit":
|
||||
# 推理获得seg 结果
|
||||
seg_result = get_seg_result(result["image_id"], result['image'])
|
||||
self.save_seg_result(seg_result, result['image_id'])
|
||||
# null 正常流程 加载本地缓存 无缓存则过模型
|
||||
else:
|
||||
# 本地查询seg 缓存是否存在
|
||||
_, seg_result = self.load_seg_result(result["image_id"])
|
||||
# 判断缓存和实际图片size是否相同
|
||||
if not _ or result["image"].shape[:2] != seg_result.shape:
|
||||
# 推理获得seg 结果
|
||||
seg_result = get_seg_result(result["image_id"], result['image'])
|
||||
self.save_seg_result(seg_result, result['image_id'])
|
||||
result['seg_result'] = seg_result
|
||||
|
||||
# 处理前片后片
|
||||
temp_front = seg_result == 1.0
|
||||
temp_front = seg_result == 1
|
||||
result['front_mask'] = (255 * (temp_front + 0).astype(np.uint8))
|
||||
temp_back = seg_result == 2.0
|
||||
temp_back = seg_result == 2
|
||||
result['back_mask'] = (255 * (temp_back + 0).astype(np.uint8))
|
||||
result['mask'] = result['front_mask'] + result['back_mask']
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def save_seg_result(seg_result, image_id):
|
||||
file_path = f"seg_cache/{image_id}.npy"
|
||||
file_path = f"{SEG_CACHE_PATH}{image_id}.npy"
|
||||
try:
|
||||
np.save(file_path, seg_result)
|
||||
logger.debug(f"保存成功 :{os.path.abspath(file_path)}")
|
||||
@@ -57,7 +72,7 @@ class Segmentation:
|
||||
|
||||
@staticmethod
|
||||
def load_seg_result(image_id):
|
||||
file_path = f"seg_cache/{image_id}.npy"
|
||||
file_path = f"{SEG_CACHE_PATH}{image_id}.npy"
|
||||
# logger.info(f"load seg file name is :{SEG_CACHE_PATH}{image_id}.npy")
|
||||
try:
|
||||
seg_result = np.load(file_path)
|
||||
|
||||
@@ -7,10 +7,11 @@ from PIL import Image
|
||||
from cv2 import cvtColor, COLOR_BGR2RGBA
|
||||
|
||||
from app.core.config import AIDA_CLOTHING
|
||||
from app.service.design_batch.utils.conversion_image import rgb_to_rgba
|
||||
from app.service.design_batch.utils.upload_image import upload_png_mask
|
||||
from app.service.design_fast.utils.conversion_image import rgb_to_rgba
|
||||
from app.service.design_fast.utils.transparent import sketch_to_transparent
|
||||
from app.service.design_fast.utils.upload_image import upload_png_mask
|
||||
from app.service.utils.generate_uuid import generate_uuid
|
||||
from app.service.utils.new_oss_client import oss_upload_image
|
||||
from app.service.utils.new_oss_client import oss_upload_image, oss_get_image
|
||||
|
||||
|
||||
class Split(object):
|
||||
@@ -20,51 +21,95 @@ class Split(object):
|
||||
def __call__(self, result):
|
||||
try:
|
||||
|
||||
if result['name'] in ('outwear', 'dress', 'blouse', 'skirt', 'trousers', 'tops', 'bottoms'):
|
||||
front_mask = result['front_mask']
|
||||
back_mask = result['back_mask']
|
||||
if result['name'] in ('outwear', 'dress', 'blouse', 'skirt', 'trousers', 'tops', 'bottoms', 'accessories'):
|
||||
|
||||
if result['resize_scale'][0] == 1.0 and result['resize_scale'][1] == 1.0:
|
||||
front_mask = result['front_mask']
|
||||
back_mask = result['back_mask']
|
||||
else:
|
||||
height, width = result['front_mask'].shape[:2]
|
||||
new_width = int(width * result['resize_scale'][0])
|
||||
new_height = int(height * result['resize_scale'][1])
|
||||
|
||||
front_mask = cv2.resize(result['front_mask'], (new_width, new_height))
|
||||
back_mask = cv2.resize(result['back_mask'], (new_width, new_height))
|
||||
|
||||
rgba_image = rgb_to_rgba(result['final_image'], front_mask + back_mask)
|
||||
new_size = (int(rgba_image.shape[1] * result["scale"] * result["resize_scale"][0]), int(rgba_image.shape[0] * result["scale"] * result["resize_scale"][1]))
|
||||
new_size = (int(rgba_image.shape[1] * result["scale"]), int(rgba_image.shape[0] * result["scale"]))
|
||||
rgba_image = cv2.resize(rgba_image, new_size)
|
||||
result_front_image = np.zeros_like(rgba_image)
|
||||
front_mask = cv2.resize(front_mask, new_size)
|
||||
result_front_image[front_mask != 0] = rgba_image[front_mask != 0]
|
||||
result_front_image_pil = Image.fromarray(cvtColor(result_front_image, COLOR_BGR2RGBA))
|
||||
if 'transparent' in result.keys():
|
||||
# 用户自选区域transparent
|
||||
transparent = result['transparent']
|
||||
if transparent['mask_url'] is not None and transparent['mask_url'] != "":
|
||||
# 预处理用户自选区mask
|
||||
seg_mask = oss_get_image(oss_client=self.minio_client, bucket=transparent['mask_url'].split('/')[0], object_name=transparent['mask_url'][transparent['mask_url'].find('/') + 1:], data_type="cv2")
|
||||
seg_mask = cv2.resize(seg_mask, new_size, interpolation=cv2.INTER_NEAREST)
|
||||
# 转换颜色空间为 RGB(OpenCV 默认是 BGR)
|
||||
image_rgb = cv2.cvtColor(seg_mask, cv2.COLOR_BGR2RGB)
|
||||
|
||||
r, g, b = cv2.split(image_rgb)
|
||||
blue_mask = b > r
|
||||
|
||||
# 创建红色和绿色掩码
|
||||
transparent_mask = np.array(blue_mask, dtype=np.uint8) * 255
|
||||
result_front_image_pil = sketch_to_transparent(result_front_image_pil, transparent_mask, transparent["scale"])
|
||||
else:
|
||||
result_front_image_pil = sketch_to_transparent(result_front_image_pil, front_mask, transparent["scale"])
|
||||
result['front_image'], result["front_image_url"], _ = upload_png_mask(self.minio_client, result_front_image_pil, f'{generate_uuid()}', mask=None)
|
||||
|
||||
height, width = front_mask.shape
|
||||
mask_image = np.zeros((height, width, 3))
|
||||
mask_image[front_mask != 0] = [0, 0, 255]
|
||||
|
||||
if result["name"] in ('blouse', 'dress', 'outwear', 'tops'):
|
||||
result_back_image = np.zeros_like(rgba_image)
|
||||
back_mask = cv2.resize(back_mask, new_size)
|
||||
result_back_image[back_mask != 0] = rgba_image[back_mask != 0]
|
||||
result_back_image_pil = Image.fromarray(cvtColor(result_back_image, COLOR_BGR2RGBA))
|
||||
result['back_image'], result["back_image_url"], _ = upload_png_mask(self.minio_client, result_back_image_pil, f'{generate_uuid()}', mask=None)
|
||||
mask_image[back_mask != 0] = [0, 255, 0]
|
||||
# if result["name"] in ('blouse', 'dress', 'outwear', 'tops'):
|
||||
# result_back_image = np.zeros_like(rgba_image)
|
||||
# back_mask = cv2.resize(back_mask, new_size)
|
||||
# result_back_image[back_mask != 0] = rgba_image[back_mask != 0]
|
||||
# result_back_image_pil = Image.fromarray(cvtColor(result_back_image, COLOR_BGR2RGBA))
|
||||
# result['back_image'], result["back_image_url"], _ = upload_png_mask(self.minio_client, result_back_image_pil, f'{generate_uuid()}', mask=None)
|
||||
# mask_image[back_mask != 0] = [0, 255, 0]
|
||||
#
|
||||
# rbga_mask = rgb_to_rgba(mask_image, front_mask + back_mask)
|
||||
# mask_pil = Image.fromarray(cvtColor(rbga_mask.astype(np.uint8), COLOR_BGR2RGBA))
|
||||
# image_data = io.BytesIO()
|
||||
# mask_pil.save(image_data, format='PNG')
|
||||
# image_data.seek(0)
|
||||
# image_bytes = image_data.read()
|
||||
# req = oss_upload_image(oss_client=self.minio_client, bucket=AIDA_CLOTHING, object_name=f"mask/mask_{generate_uuid()}.png", image_bytes=image_bytes)
|
||||
# result['mask_url'] = req.bucket_name + "/" + req.object_name
|
||||
# else:
|
||||
# rbga_mask = rgb_to_rgba(mask_image, front_mask)
|
||||
# mask_pil = Image.fromarray(cvtColor(rbga_mask.astype(np.uint8), COLOR_BGR2RGBA))
|
||||
# image_data = io.BytesIO()
|
||||
# mask_pil.save(image_data, format='PNG')
|
||||
# image_data.seek(0)
|
||||
# image_bytes = image_data.read()
|
||||
# req = oss_upload_image(oss_client=self.minio_client, bucket=AIDA_CLOTHING, object_name=f"mask/mask_{generate_uuid()}.png", image_bytes=image_bytes)
|
||||
# result['mask_url'] = req.bucket_name + "/" + req.object_name
|
||||
# result['back_image'] = None
|
||||
# result["back_image_url"] = None
|
||||
# # result["back_mask_url"] = None
|
||||
# # result['back_mask_image'] = None
|
||||
|
||||
rbga_mask = rgb_to_rgba(mask_image, front_mask + back_mask)
|
||||
mask_pil = Image.fromarray(cvtColor(rbga_mask.astype(np.uint8), COLOR_BGR2RGBA))
|
||||
image_data = io.BytesIO()
|
||||
mask_pil.save(image_data, format='PNG')
|
||||
image_data.seek(0)
|
||||
image_bytes = image_data.read()
|
||||
req = oss_upload_image(oss_client=self.minio_client, bucket=AIDA_CLOTHING, object_name=f"mask/mask_{generate_uuid()}.png", image_bytes=image_bytes)
|
||||
result['mask_url'] = req.bucket_name + "/" + req.object_name
|
||||
else:
|
||||
rbga_mask = rgb_to_rgba(mask_image, front_mask)
|
||||
mask_pil = Image.fromarray(cvtColor(rbga_mask.astype(np.uint8), COLOR_BGR2RGBA))
|
||||
image_data = io.BytesIO()
|
||||
mask_pil.save(image_data, format='PNG')
|
||||
image_data.seek(0)
|
||||
image_bytes = image_data.read()
|
||||
req = oss_upload_image(oss_client=self.minio_client, bucket=AIDA_CLOTHING, object_name=f"mask/mask_{generate_uuid()}.png", image_bytes=image_bytes)
|
||||
result['mask_url'] = req.bucket_name + "/" + req.object_name
|
||||
result['back_image'] = None
|
||||
result["back_image_url"] = None
|
||||
# result["back_mask_url"] = None
|
||||
# result['back_mask_image'] = None
|
||||
result_back_image = np.zeros_like(rgba_image)
|
||||
back_mask = cv2.resize(back_mask, new_size)
|
||||
result_back_image[back_mask != 0] = rgba_image[back_mask != 0]
|
||||
result_back_image_pil = Image.fromarray(cvtColor(result_back_image, COLOR_BGR2RGBA))
|
||||
result['back_image'], result["back_image_url"], _ = upload_png_mask(self.minio_client, result_back_image_pil, f'{generate_uuid()}', mask=None)
|
||||
mask_image[back_mask != 0] = [0, 255, 0]
|
||||
|
||||
rbga_mask = rgb_to_rgba(mask_image, front_mask + back_mask)
|
||||
mask_pil = Image.fromarray(cvtColor(rbga_mask.astype(np.uint8), COLOR_BGR2RGBA))
|
||||
image_data = io.BytesIO()
|
||||
mask_pil.save(image_data, format='PNG')
|
||||
image_data.seek(0)
|
||||
image_bytes = image_data.read()
|
||||
req = oss_upload_image(oss_client=self.minio_client, bucket=AIDA_CLOTHING, object_name=f"mask/mask_{generate_uuid()}.png", image_bytes=image_bytes)
|
||||
result['mask_url'] = req.bucket_name + "/" + req.object_name
|
||||
# 创建中间图层
|
||||
result_pattern_image_rgba = rgb_to_rgba(result['pattern_image'], result['mask'])
|
||||
result_pattern_image_pil = Image.fromarray(cvtColor(result_pattern_image_rgba, COLOR_BGR2RGBA))
|
||||
|
||||
@@ -2,16 +2,16 @@ import json
|
||||
|
||||
import pika
|
||||
|
||||
from app.core.config import RABBITMQ_PARAMS
|
||||
from app.core.config import RABBITMQ_PARAMS, BATCH_DESIGN_RABBITMQ_QUEUES
|
||||
|
||||
|
||||
def publish_status(task_id, progress, result):
|
||||
connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
|
||||
channel = connection.channel()
|
||||
channel.queue_declare(queue='DesignBatch', durable=True)
|
||||
channel.queue_declare(queue=BATCH_DESIGN_RABBITMQ_QUEUES, durable=True)
|
||||
message = {'task_id': task_id, 'progress': progress, "result": result}
|
||||
channel.basic_publish(exchange='',
|
||||
routing_key='DesignBatch',
|
||||
routing_key=BATCH_DESIGN_RABBITMQ_QUEUES,
|
||||
body=json.dumps(message),
|
||||
properties=pika.BasicProperties(
|
||||
delivery_mode=2,
|
||||
|
||||
@@ -33,8 +33,8 @@ def organize_clothing(layer):
|
||||
mask=cv2.resize(layer['mask'], layer["front_image"].size),
|
||||
gradient_string=layer['gradient_string'] if 'gradient_string' in layer.keys() else "",
|
||||
pattern_image_url=layer['pattern_image_url'],
|
||||
pattern_image=layer['pattern_image']
|
||||
|
||||
pattern_image=layer['pattern_image'],
|
||||
# back_perspective_url=layer['back_perspective_url'] if 'back_perspective_url' in layer.keys() else ""
|
||||
)
|
||||
# 后片数据
|
||||
back_layer = dict(priority=-layer.get("priority", 0) if layer.get("layer_order", False) else PRIORITY_DICT.get(f'{layer["name"].lower()}_back', None),
|
||||
@@ -50,6 +50,46 @@ def organize_clothing(layer):
|
||||
mask=cv2.resize(layer['mask'], layer["front_image"].size),
|
||||
gradient_string=layer['gradient_string'] if 'gradient_string' in layer.keys() else "",
|
||||
pattern_image_url=layer['pattern_image_url'],
|
||||
# back_perspective_url=layer['back_perspective_url'] if 'back_perspective_url' in layer.keys() else ""
|
||||
)
|
||||
return front_layer, back_layer
|
||||
|
||||
|
||||
def organize_accessories(layer):
|
||||
# 起始坐标
|
||||
start_point = (0, 0)
|
||||
# 前片数据
|
||||
front_layer = dict(priority=layer['priority'] if layer.get("layer_order", False) else PRIORITY_DICT.get(f'{layer["name"].lower()}_front', None),
|
||||
name=f'{layer["name"].lower()}_front',
|
||||
image=layer["front_image"],
|
||||
# mask_image=layer['front_mask_image'],
|
||||
image_url=layer['front_image_url'],
|
||||
mask_url=layer['mask_url'],
|
||||
sacle=layer['scale'],
|
||||
clothes_keypoint=(0, 0),
|
||||
position=start_point,
|
||||
resize_scale=layer["resize_scale"],
|
||||
mask=cv2.resize(layer['mask'], layer["front_image"].size),
|
||||
gradient_string=layer['gradient_string'] if 'gradient_string' in layer.keys() else "",
|
||||
pattern_image_url=layer['pattern_image_url'],
|
||||
pattern_image=layer['pattern_image'],
|
||||
# back_perspective_url=layer['back_perspective_url'] if 'back_perspective_url' in layer.keys() else ""
|
||||
)
|
||||
# 后片数据
|
||||
back_layer = dict(priority=-layer.get("priority", 0) if layer.get("layer_order", False) else PRIORITY_DICT.get(f'{layer["name"].lower()}_back', None),
|
||||
name=f'{layer["name"].lower()}_back',
|
||||
image=layer["back_image"],
|
||||
# mask_image=layer['back_mask_image'],
|
||||
image_url=layer['back_image_url'],
|
||||
mask_url=layer['mask_url'],
|
||||
sacle=layer['scale'],
|
||||
clothes_keypoint=(0, 0),
|
||||
position=start_point,
|
||||
resize_scale=layer["resize_scale"],
|
||||
mask=cv2.resize(layer['mask'], layer["front_image"].size),
|
||||
gradient_string=layer['gradient_string'] if 'gradient_string' in layer.keys() else "",
|
||||
pattern_image_url=layer['pattern_image_url'],
|
||||
# back_perspective_url=layer['back_perspective_url'] if 'back_perspective_url' in layer.keys() else ""
|
||||
)
|
||||
return front_layer, back_layer
|
||||
|
||||
|
||||
@@ -200,6 +200,11 @@ def design_generate_v2(request_data):
|
||||
items_response['synthesis_url'] = synthesis_single(item_result['front_image'], item_result['back_image'])
|
||||
# 发送结果给java端
|
||||
url = JAVA_STREAM_API_URL
|
||||
# xu_pei_test_url = "https://cd21b9110505.ngrok-free.app/api/third/party/receiveDesignResults"
|
||||
|
||||
logger.info(f"java 回调 -> {url}")
|
||||
# logger.info(f"xupei java 回调 -> {xu_pei_test_url}")
|
||||
|
||||
headers = {
|
||||
'Accept': "*/*",
|
||||
'Accept-Encoding': "gzip, deflate, br",
|
||||
@@ -213,6 +218,11 @@ def design_generate_v2(request_data):
|
||||
# 打印结果
|
||||
logger.info(response.text)
|
||||
|
||||
# response = post_request(xu_pei_test_url, json_data=items_response, headers=headers)
|
||||
# if response:
|
||||
# 打印结果
|
||||
# logger.info(f"xupei test response : {response.text}")
|
||||
|
||||
for step, object in enumerate(objects_data):
|
||||
t = threading.Thread(target=process_object, args=(step, object))
|
||||
threads.append(t)
|
||||
|
||||
@@ -57,7 +57,7 @@ class BottomItem(BaseItem):
|
||||
LoadImage(minio_client),
|
||||
KeyPoint(),
|
||||
ContourDetection(),
|
||||
# Segmentation(),
|
||||
Segmentation(minio_client),
|
||||
# BackPerspective(minio_client),
|
||||
Color(minio_client),
|
||||
PrintPainting(minio_client),
|
||||
|
||||
@@ -29,6 +29,24 @@ class Color:
|
||||
else:
|
||||
pattern = self.get_pattern(result['color'])
|
||||
resize_pattern = cv2.resize(pattern, (dim_image_w, dim_image_h), interpolation=cv2.INTER_AREA)
|
||||
|
||||
if "partial_color" in result.keys() and result['partial_color'] != "":
|
||||
bucket_name = result['partial_color'].split('/')[0]
|
||||
object_name = result['partial_color'][result['partial_color'].find('/') + 1:]
|
||||
partial_color = oss_get_image(oss_client=self.minio_client, bucket=bucket_name, object_name=object_name, data_type="cv2")
|
||||
h, w = partial_color.shape[0:2]
|
||||
resize_pattern = cv2.resize(resize_pattern, (w, h), interpolation=cv2.INTER_AREA)
|
||||
# 分离出 png 图的 alpha 通道
|
||||
alpha_channel = partial_color[:, :, 3]
|
||||
# 提取 png 图的 RGB 通道
|
||||
png_rgb = partial_color[:, :, :3]
|
||||
# 创建一个与 cv 图大小相同的掩码,用于指示哪些像素需要替换
|
||||
mask = alpha_channel > 0
|
||||
# 将掩码扩展为 3 通道,以便与 cv 图进行逐元素操作
|
||||
mask_3ch = np.stack([mask] * 3, axis=-1)
|
||||
# 根据掩码将 png 图的颜色覆盖到 cv 图上
|
||||
resize_pattern[mask_3ch] = png_rgb[mask_3ch]
|
||||
resize_pattern = cv2.resize(resize_pattern, (dim_image_w, dim_image_h), interpolation=cv2.INTER_AREA)
|
||||
closed_mo = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
|
||||
gray_mo = np.expand_dims(result['gray'], axis=2).repeat(3, axis=2)
|
||||
get_image_fir = resize_pattern * (closed_mo / 255) * (gray_mo / 255)
|
||||
|
||||
@@ -15,6 +15,7 @@ class PrintPainting:
|
||||
single_print = result['print']['single']
|
||||
overall_print = result['print']['overall']
|
||||
element_print = result['print']['element']
|
||||
partial_path = result['print']['partial'] if 'partial' in result['print'] else None
|
||||
result['single_image'] = None
|
||||
result['print_image'] = None
|
||||
# TODO 给result['pattern_image'] resize 到resize_scale的大小
|
||||
@@ -32,7 +33,6 @@ class PrintPainting:
|
||||
result['mask'] = cv2.resize(result['mask'], (new_width, new_height))
|
||||
result['gray'] = cv2.resize(result['gray'], (new_width, new_height))
|
||||
|
||||
print(1)
|
||||
if overall_print['print_path_list']:
|
||||
painting_dict = {'dim_image_h': result['pattern_image'].shape[0], 'dim_image_w': result['pattern_image'].shape[1]}
|
||||
result['print_image'] = result['pattern_image']
|
||||
@@ -54,90 +54,89 @@ class PrintPainting:
|
||||
mask_background = np.zeros((result['pattern_image'].shape[0], result['pattern_image'].shape[1], 3), dtype=np.uint8)
|
||||
for i in range(len(single_print['print_path_list'])):
|
||||
image, image_mode = self.read_image(single_print['print_path_list'][i])
|
||||
if image_mode == "RGBA":
|
||||
new_size = (int(result['pattern_image'].shape[1] * single_print['print_scale_list'][i][0]), int(result['pattern_image'].shape[0] * single_print['print_scale_list'][i][1]))
|
||||
|
||||
mask = image.split()[3]
|
||||
resized_source = image.resize(new_size)
|
||||
resized_source_mask = mask.resize(new_size)
|
||||
if image_mode == "RGB":
|
||||
image_rgba = cv2.cvtColor(image, cv2.COLOR_BGR2RGBA)
|
||||
image = Image.fromarray(image_rgba)
|
||||
|
||||
rotated_resized_source = resized_source.rotate(-single_print['print_angle_list'][i])
|
||||
rotated_resized_source_mask = resized_source_mask.rotate(-single_print['print_angle_list'][i])
|
||||
|
||||
source_image_pil = Image.fromarray(cv2.cvtColor(print_background, cv2.COLOR_BGR2RGB))
|
||||
source_image_pil_mask = Image.fromarray(cv2.cvtColor(mask_background, cv2.COLOR_BGR2RGB))
|
||||
|
||||
source_image_pil.paste(rotated_resized_source, (int(single_print['location'][i][0]), int(single_print['location'][i][1])), rotated_resized_source)
|
||||
source_image_pil_mask.paste(rotated_resized_source_mask, (int(single_print['location'][i][0]), int(single_print['location'][i][1])), rotated_resized_source_mask)
|
||||
|
||||
print_background = cv2.cvtColor(np.array(source_image_pil), cv2.COLOR_RGBA2BGR)
|
||||
mask_background = cv2.cvtColor(np.array(source_image_pil_mask), cv2.COLOR_RGBA2BGR)
|
||||
ret, mask_background = cv2.threshold(mask_background, 124, 255, cv2.THRESH_BINARY)
|
||||
else:
|
||||
mask = self.get_mask_inv(image)
|
||||
mask = np.expand_dims(mask, axis=2)
|
||||
mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR)
|
||||
mask = cv2.bitwise_not(mask)
|
||||
|
||||
mask = cv2.resize(mask, (int(result['final_image'].shape[1] * single_print['print_scale_list'][i][0]), int(result['final_image'].shape[0] * single_print['print_scale_list'][i][1])))
|
||||
image = cv2.resize(image, (int(result['final_image'].shape[1] * single_print['print_scale_list'][i][0]), int(result['final_image'].shape[0] * single_print['print_scale_list'][i][1])))
|
||||
# 旋转后的坐标需要重新算
|
||||
rotate_mask, _ = self.img_rotate(mask, single_print['print_angle_list'][i])
|
||||
rotate_image, rotated_new_size = self.img_rotate(image, single_print['print_angle_list'][i])
|
||||
# x, y = int(result['print']['location'][i][0] - rotated_new_size[0] - (rotate_mask.shape[0] - image.shape[0]) / 2), int(result['print']['location'][i][1] - rotated_new_size[1] - (rotate_mask.shape[1] - image.shape[1]) / 2)
|
||||
x, y = int(single_print['location'][i][0] - rotated_new_size[0]), int(single_print['location'][i][1] - rotated_new_size[1])
|
||||
|
||||
image_x = print_background.shape[1]
|
||||
image_y = print_background.shape[0]
|
||||
print_x = rotate_image.shape[1]
|
||||
print_y = rotate_image.shape[0]
|
||||
|
||||
# 有bug
|
||||
# if x + print_x > image_x:
|
||||
# rotate_image = rotate_image[:, :x + print_x - image_x]
|
||||
# rotate_mask = rotate_mask[:, :x + print_x - image_x]
|
||||
# #
|
||||
# if y + print_y > image_y:
|
||||
# rotate_image = rotate_image[:y + print_y - image_y]
|
||||
# rotate_mask = rotate_mask[:y + print_y - image_y]
|
||||
|
||||
# 不能是并行
|
||||
# 当前第一轮的if (108以及115)是判断有没有过下界和右界。第二轮的是判断左上有没有超出。 如果这个样子的话,先裁了右边,再左移,region就会有问题
|
||||
# 先挪 再判断 最后裁剪
|
||||
|
||||
# 如果print旋转了 或者 print贴边了 则需要判断 判断左界和上界是否小于0
|
||||
if x <= 0:
|
||||
rotate_image = rotate_image[:, -x:]
|
||||
rotate_mask = rotate_mask[:, -x:]
|
||||
start_x = x = 0
|
||||
else:
|
||||
start_x = x
|
||||
|
||||
if y <= 0:
|
||||
rotate_image = rotate_image[-y:, :]
|
||||
rotate_mask = rotate_mask[-y:, :]
|
||||
start_y = y = 0
|
||||
else:
|
||||
start_y = y
|
||||
|
||||
# ------------------
|
||||
# 如果print-size大于image-size 则需要裁剪print
|
||||
|
||||
if x + print_x > image_x:
|
||||
rotate_image = rotate_image[:, :image_x - x]
|
||||
rotate_mask = rotate_mask[:, :image_x - x]
|
||||
|
||||
if y + print_y > image_y:
|
||||
rotate_image = rotate_image[:image_y - y, :]
|
||||
rotate_mask = rotate_mask[:image_y - y, :]
|
||||
|
||||
# mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]] = cv2.bitwise_xor(mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]], rotate_mask)
|
||||
# print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = cv2.add(print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]], rotate_image)
|
||||
|
||||
# mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]] = rotate_mask
|
||||
# print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = rotate_image
|
||||
mask_background = self.stack_prin(mask_background, result['pattern_image'], rotate_mask, start_y, y, start_x, x)
|
||||
print_background = self.stack_prin(print_background, result['pattern_image'], rotate_image, start_y, y, start_x, x)
|
||||
new_size = (int(result['pattern_image'].shape[1] * single_print['print_scale_list'][i][0]), int(result['pattern_image'].shape[0] * single_print['print_scale_list'][i][1]))
|
||||
mask = image.split()[3]
|
||||
resized_source = image.resize(new_size)
|
||||
resized_source_mask = mask.resize(new_size)
|
||||
rotated_resized_source = resized_source.rotate(-single_print['print_angle_list'][i])
|
||||
rotated_resized_source_mask = resized_source_mask.rotate(-single_print['print_angle_list'][i])
|
||||
source_image_pil = Image.fromarray(cv2.cvtColor(print_background, cv2.COLOR_BGR2RGB))
|
||||
source_image_pil_mask = Image.fromarray(cv2.cvtColor(mask_background, cv2.COLOR_BGR2RGB))
|
||||
source_image_pil.paste(rotated_resized_source, (int(single_print['location'][i][0]), int(single_print['location'][i][1])), rotated_resized_source)
|
||||
source_image_pil_mask.paste(rotated_resized_source_mask, (int(single_print['location'][i][0]), int(single_print['location'][i][1])), rotated_resized_source_mask)
|
||||
print_background = cv2.cvtColor(np.array(source_image_pil), cv2.COLOR_RGBA2BGR)
|
||||
mask_background = cv2.cvtColor(np.array(source_image_pil_mask), cv2.COLOR_RGBA2BGR)
|
||||
ret, mask_background = cv2.threshold(mask_background, 124, 255, cv2.THRESH_BINARY)
|
||||
# else:
|
||||
# mask = self.get_mask_inv(image)
|
||||
# mask = np.expand_dims(mask, axis=2)
|
||||
# mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR)
|
||||
# mask = cv2.bitwise_not(mask)
|
||||
#
|
||||
# mask = cv2.resize(mask, (int(result['final_image'].shape[1] * single_print['print_scale_list'][i][0]), int(result['final_image'].shape[0] * single_print['print_scale_list'][i][1])))
|
||||
# image = cv2.resize(image, (int(result['final_image'].shape[1] * single_print['print_scale_list'][i][0]), int(result['final_image'].shape[0] * single_print['print_scale_list'][i][1])))
|
||||
# # 旋转后的坐标需要重新算
|
||||
# rotate_mask, _ = self.img_rotate(mask, single_print['print_angle_list'][i])
|
||||
# rotate_image, rotated_new_size = self.img_rotate(image, single_print['print_angle_list'][i])
|
||||
# # x, y = int(result['print']['location'][i][0] - rotated_new_size[0] - (rotate_mask.shape[0] - image.shape[0]) / 2), int(result['print']['location'][i][1] - rotated_new_size[1] - (rotate_mask.shape[1] - image.shape[1]) / 2)
|
||||
# x, y = int(single_print['location'][i][0] - rotated_new_size[0]), int(single_print['location'][i][1] - rotated_new_size[1])
|
||||
#
|
||||
# image_x = print_background.shape[1] # 底图宽
|
||||
# image_y = print_background.shape[0] # 底图高
|
||||
# print_x = rotate_image.shape[1] #印花宽
|
||||
# print_y = rotate_image.shape[0] #印花高
|
||||
#
|
||||
# # 有bug
|
||||
# # if x + print_x > image_x:
|
||||
# # rotate_image = rotate_image[:, :x + print_x - image_x]
|
||||
# # rotate_mask = rotate_mask[:, :x + print_x - image_x]
|
||||
# # #
|
||||
# # if y + print_y > image_y:
|
||||
# # rotate_image = rotate_image[:y + print_y - image_y]
|
||||
# # rotate_mask = rotate_mask[:y + print_y - image_y]
|
||||
#
|
||||
# # 不能是并行
|
||||
# # 当前第一轮的if (108以及115)是判断有没有过下界和右界。第二轮的是判断左上有没有超出。 如果这个样子的话,先裁了右边,再左移,region就会有问题
|
||||
# # 先挪 再判断 最后裁剪
|
||||
#
|
||||
# # 如果print旋转了 或者 print贴边了 则需要判断 判断左界和上界是否小于0
|
||||
# if x <= 0: # 如果X轴偏移量小于0,说明印花需要被裁剪至合适大小 或当X轴偏移量大于印花宽度时,裁剪后的印花宽度为0
|
||||
# rotate_image = rotate_image[:, abs(x):]
|
||||
# rotate_mask = rotate_mask[:, abs(x):]
|
||||
# start_x = x = 0
|
||||
# else:
|
||||
# start_x = x
|
||||
#
|
||||
# if y <= 0: # 如果X轴偏移量大于0,说明印花需要被裁剪至合适大小 或当Y轴偏移量大于印花宽度时,裁剪后的印花宽度为0
|
||||
# rotate_image = rotate_image[abs(y):, :]
|
||||
# rotate_mask = rotate_mask[abs(y):, :]
|
||||
# start_y = y = 0
|
||||
# else:
|
||||
# start_y = y
|
||||
#
|
||||
# # ------------------
|
||||
# # 如果print-size大于image-size 则需要裁剪print
|
||||
#
|
||||
# if x + print_x > image_x:
|
||||
# rotate_image = rotate_image[:, :image_x - x]
|
||||
# rotate_mask = rotate_mask[:, :image_x - x]
|
||||
#
|
||||
# if y + print_y > image_y:
|
||||
# rotate_image = rotate_image[:image_y - y, :]
|
||||
# rotate_mask = rotate_mask[:image_y - y, :]
|
||||
#
|
||||
# # mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]] = cv2.bitwise_xor(mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]], rotate_mask)
|
||||
# # print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = cv2.add(print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]], rotate_image)
|
||||
#
|
||||
# # mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]] = rotate_mask
|
||||
# # print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = rotate_image
|
||||
# mask_background = self.stack_prin(mask_background, result['pattern_image'], rotate_mask, start_y, y, start_x, x)
|
||||
# print_background = self.stack_prin(print_background, result['pattern_image'], rotate_image, start_y, y, start_x, x)
|
||||
|
||||
# gray_image = cv2.cvtColor(mask_background, cv2.COLOR_BGR2GRAY)
|
||||
# print_background = cv2.bitwise_and(print_background, print_background, mask=gray_image)
|
||||
@@ -262,6 +261,45 @@ class PrintPainting:
|
||||
temp_fg = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
|
||||
tmp2 = (result['final_image'] * (temp_fg / 255)).astype(np.uint8)
|
||||
result['single_image'] = cv2.add(tmp1, tmp2)
|
||||
|
||||
if partial_path:
|
||||
print_background = np.zeros((result['pattern_image'].shape[0], result['pattern_image'].shape[1], 3), dtype=np.uint8)
|
||||
mask_background = np.zeros((result['pattern_image'].shape[0], result['pattern_image'].shape[1], 3), dtype=np.uint8)
|
||||
image, image_mode = self.read_image(partial_path)
|
||||
if image_mode == "RGBA":
|
||||
new_size = (result['pattern_image'].shape[1], result['pattern_image'].shape[0])
|
||||
|
||||
mask = image.split()[3]
|
||||
resized_source = image.resize(new_size)
|
||||
resized_source_mask = mask.resize(new_size)
|
||||
|
||||
# rotated_resized_source = resized_source.rotate(-partial_print['print_angle_list'][i])
|
||||
# rotated_resized_source_mask = resized_source_mask.rotate(-partial_print['print_angle_list'][i])
|
||||
|
||||
source_image_pil = Image.fromarray(cv2.cvtColor(print_background, cv2.COLOR_BGR2RGB))
|
||||
source_image_pil_mask = Image.fromarray(cv2.cvtColor(mask_background, cv2.COLOR_BGR2RGB))
|
||||
|
||||
source_image_pil.paste(resized_source, (0, 0), resized_source)
|
||||
source_image_pil_mask.paste(resized_source_mask, (0, 0), resized_source_mask)
|
||||
|
||||
print_background = cv2.cvtColor(np.array(source_image_pil), cv2.COLOR_RGBA2BGR)
|
||||
mask_background = cv2.cvtColor(np.array(source_image_pil_mask), cv2.COLOR_RGBA2BGR)
|
||||
ret, mask_background = cv2.threshold(mask_background, 124, 255, cv2.THRESH_BINARY)
|
||||
print_mask = cv2.bitwise_and(result['mask'], cv2.cvtColor(mask_background, cv2.COLOR_BGR2GRAY))
|
||||
img_fg = cv2.bitwise_or(print_background, print_background, mask=print_mask)
|
||||
# TODO element 丢失信息
|
||||
three_channel_image = cv2.merge([cv2.bitwise_not(print_mask), cv2.bitwise_not(print_mask), cv2.bitwise_not(print_mask)])
|
||||
img_bg = cv2.bitwise_and(result['final_image'], three_channel_image)
|
||||
# mask_mo = np.expand_dims(print_mask, axis=2).repeat(3, axis=2)
|
||||
# gray_mo = np.expand_dims(result['gray'], axis=2).repeat(3, axis=2)
|
||||
# img_fg = (img_fg * (mask_mo / 255) * (gray_mo / 255)).astype(np.uint8)
|
||||
result['final_image'] = cv2.add(img_bg, img_fg)
|
||||
canvas = np.full_like(result['final_image'], 255)
|
||||
temp_bg = np.expand_dims(cv2.bitwise_not(result['mask']), axis=2).repeat(3, axis=2)
|
||||
tmp1 = (canvas * (temp_bg / 255)).astype(np.uint8)
|
||||
temp_fg = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
|
||||
tmp2 = (result['final_image'] * (temp_fg / 255)).astype(np.uint8)
|
||||
result['single_image'] = cv2.add(tmp1, tmp2)
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
@@ -414,7 +452,6 @@ class PrintPainting:
|
||||
# y_offset = int(location[0][0])
|
||||
# x_offset = int(location[0][1])
|
||||
|
||||
|
||||
if len(image.shape) == 2:
|
||||
image = image[x_offset: x_offset + image_size_h, y_offset: y_offset + image_size_w]
|
||||
elif len(image.shape) == 3:
|
||||
|
||||
@@ -65,35 +65,51 @@ class Split(object):
|
||||
mask_image = np.zeros((height, width, 3))
|
||||
mask_image[front_mask != 0] = [0, 0, 255]
|
||||
|
||||
if result["name"] in ('blouse', 'dress', 'outwear', 'tops'):
|
||||
result_back_image = np.zeros_like(rgba_image)
|
||||
back_mask = cv2.resize(back_mask, new_size)
|
||||
result_back_image[back_mask != 0] = rgba_image[back_mask != 0]
|
||||
result_back_image_pil = Image.fromarray(cvtColor(result_back_image, COLOR_BGR2RGBA))
|
||||
result['back_image'], result["back_image_url"], _ = upload_png_mask(self.minio_client, result_back_image_pil, f'{generate_uuid()}', mask=None)
|
||||
mask_image[back_mask != 0] = [0, 255, 0]
|
||||
# if result["name"] in ('blouse', 'dress', 'outwear', 'tops'):
|
||||
# result_back_image = np.zeros_like(rgba_image)
|
||||
# back_mask = cv2.resize(back_mask, new_size)
|
||||
# result_back_image[back_mask != 0] = rgba_image[back_mask != 0]
|
||||
# result_back_image_pil = Image.fromarray(cvtColor(result_back_image, COLOR_BGR2RGBA))
|
||||
# result['back_image'], result["back_image_url"], _ = upload_png_mask(self.minio_client, result_back_image_pil, f'{generate_uuid()}', mask=None)
|
||||
# mask_image[back_mask != 0] = [0, 255, 0]
|
||||
#
|
||||
# rbga_mask = rgb_to_rgba(mask_image, front_mask + back_mask)
|
||||
# mask_pil = Image.fromarray(cvtColor(rbga_mask.astype(np.uint8), COLOR_BGR2RGBA))
|
||||
# image_data = io.BytesIO()
|
||||
# mask_pil.save(image_data, format='PNG')
|
||||
# image_data.seek(0)
|
||||
# image_bytes = image_data.read()
|
||||
# req = oss_upload_image(oss_client=self.minio_client, bucket=AIDA_CLOTHING, object_name=f"mask/mask_{generate_uuid()}.png", image_bytes=image_bytes)
|
||||
# result['mask_url'] = req.bucket_name + "/" + req.object_name
|
||||
# else:
|
||||
# rbga_mask = rgb_to_rgba(mask_image, front_mask)
|
||||
# mask_pil = Image.fromarray(cvtColor(rbga_mask.astype(np.uint8), COLOR_BGR2RGBA))
|
||||
# image_data = io.BytesIO()
|
||||
# mask_pil.save(image_data, format='PNG')
|
||||
# image_data.seek(0)
|
||||
# image_bytes = image_data.read()
|
||||
# req = oss_upload_image(oss_client=self.minio_client, bucket=AIDA_CLOTHING, object_name=f"mask/mask_{generate_uuid()}.png", image_bytes=image_bytes)
|
||||
# result['mask_url'] = req.bucket_name + "/" + req.object_name
|
||||
# result['back_image'] = None
|
||||
# result["back_image_url"] = None
|
||||
# # result["back_mask_url"] = None
|
||||
# # result['back_mask_image'] = None
|
||||
|
||||
rbga_mask = rgb_to_rgba(mask_image, front_mask + back_mask)
|
||||
mask_pil = Image.fromarray(cvtColor(rbga_mask.astype(np.uint8), COLOR_BGR2RGBA))
|
||||
image_data = io.BytesIO()
|
||||
mask_pil.save(image_data, format='PNG')
|
||||
image_data.seek(0)
|
||||
image_bytes = image_data.read()
|
||||
req = oss_upload_image(oss_client=self.minio_client, bucket=AIDA_CLOTHING, object_name=f"mask/mask_{generate_uuid()}.png", image_bytes=image_bytes)
|
||||
result['mask_url'] = req.bucket_name + "/" + req.object_name
|
||||
else:
|
||||
rbga_mask = rgb_to_rgba(mask_image, front_mask)
|
||||
mask_pil = Image.fromarray(cvtColor(rbga_mask.astype(np.uint8), COLOR_BGR2RGBA))
|
||||
image_data = io.BytesIO()
|
||||
mask_pil.save(image_data, format='PNG')
|
||||
image_data.seek(0)
|
||||
image_bytes = image_data.read()
|
||||
req = oss_upload_image(oss_client=self.minio_client, bucket=AIDA_CLOTHING, object_name=f"mask/mask_{generate_uuid()}.png", image_bytes=image_bytes)
|
||||
result['mask_url'] = req.bucket_name + "/" + req.object_name
|
||||
result['back_image'] = None
|
||||
result["back_image_url"] = None
|
||||
# result["back_mask_url"] = None
|
||||
# result['back_mask_image'] = None
|
||||
result_back_image = np.zeros_like(rgba_image)
|
||||
back_mask = cv2.resize(back_mask, new_size)
|
||||
result_back_image[back_mask != 0] = rgba_image[back_mask != 0]
|
||||
result_back_image_pil = Image.fromarray(cvtColor(result_back_image, COLOR_BGR2RGBA))
|
||||
result['back_image'], result["back_image_url"], _ = upload_png_mask(self.minio_client, result_back_image_pil, f'{generate_uuid()}', mask=None)
|
||||
mask_image[back_mask != 0] = [0, 255, 0]
|
||||
|
||||
rbga_mask = rgb_to_rgba(mask_image, front_mask + back_mask)
|
||||
mask_pil = Image.fromarray(cvtColor(rbga_mask.astype(np.uint8), COLOR_BGR2RGBA))
|
||||
image_data = io.BytesIO()
|
||||
mask_pil.save(image_data, format='PNG')
|
||||
image_data.seek(0)
|
||||
image_bytes = image_data.read()
|
||||
req = oss_upload_image(oss_client=self.minio_client, bucket=AIDA_CLOTHING, object_name=f"mask/mask_{generate_uuid()}.png", image_bytes=image_bytes)
|
||||
result['mask_url'] = req.bucket_name + "/" + req.object_name
|
||||
# 创建中间图层
|
||||
result_pattern_image_rgba = rgb_to_rgba(result['pattern_image'], result['mask'])
|
||||
result_pattern_image_pil = Image.fromarray(cvtColor(result_pattern_image_rgba, COLOR_BGR2RGBA))
|
||||
|
||||
24
app/service/generate_batch_image/service.py
Normal file
24
app/service/generate_batch_image/service.py
Normal file
@@ -0,0 +1,24 @@
|
||||
from app.service.generate_batch_image.service_batch_generate_product_image import batch_generate_product, publish_status as product_publish_status
|
||||
from app.service.generate_batch_image.service_batch_generate_relight_image import batch_generate_relight, publish_status as relight_publish_status
|
||||
from app.service.generate_batch_image.service_batch_pose_transform import batch_generate_pose_transform, publish_status as pose_transform_publish_status
|
||||
|
||||
|
||||
async def start_product_batch_generate(data):
|
||||
generate_clothes_task = batch_generate_product.delay(data.dict())
|
||||
print(generate_clothes_task)
|
||||
product_publish_status(data.batch_tasks_id, f"0/{len(data.batch_data_list)}", "")
|
||||
return {"task_id": data.batch_tasks_id, "state": generate_clothes_task.state}
|
||||
|
||||
|
||||
async def start_relight_batch_generate(data):
|
||||
generate_clothes_task = batch_generate_relight.delay(data.dict())
|
||||
print(generate_clothes_task)
|
||||
relight_publish_status(data.batch_tasks_id, f"0/{len(data.batch_data_list)}", "")
|
||||
return {"task_id": data.batch_tasks_id, "state": generate_clothes_task.state}
|
||||
|
||||
|
||||
async def start_pose_transform_batch_generate(data):
|
||||
generate_clothes_task = batch_generate_pose_transform.delay(data.dict())
|
||||
print(generate_clothes_task)
|
||||
pose_transform_publish_status(data.tasks_id, f"0/{data.batch_size}", "")
|
||||
return {"task_id": data.tasks_id, "state": generate_clothes_task.state}
|
||||
@@ -0,0 +1,242 @@
|
||||
# 旧版product
|
||||
# !/usr/bin/env python
|
||||
# -*- coding: UTF-8 -*-
|
||||
"""
|
||||
@Project :trinity_client
|
||||
@File :service_att_recognition.py
|
||||
@Author :周成融
|
||||
@Date :2023/7/26 12:01:05
|
||||
@detail :
|
||||
"""
|
||||
import json
|
||||
import logging
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import tritonclient.grpc as grpcclient
|
||||
from PIL import Image
|
||||
from celery import Celery
|
||||
from tritonclient.utils import np_to_triton_dtype
|
||||
|
||||
from app.core.config import *
|
||||
from app.schemas.generate_image import BatchGenerateProductImageModel, ProductItemModel
|
||||
from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image
|
||||
from app.service.utils.oss_client import oss_get_image
|
||||
|
||||
celery_app = Celery('product_tasks', broker=f'amqp://rabbit:123456@18.167.251.121:5672//', backend='rpc://', BROKER_CONNECTION_RETRY_ON_STARTUP=True)
|
||||
celery_app.conf.task_default_queue = 'queue_product'
|
||||
celery_app.conf.worker_log_format = '%(asctime)s %(filename)s [line:%(lineno)d] %(levelname)s %(message)s'
|
||||
celery_app.conf.worker_hijack_root_logger = False
|
||||
logger = logging.getLogger()
|
||||
logging.getLogger('pika').setLevel(logging.WARNING)
|
||||
grpc_client = grpcclient.InferenceServerClient(url=GPI_MODEL_URL)
|
||||
category = "product_image"
|
||||
|
||||
|
||||
@celery_app.task
|
||||
def batch_generate_product(batch_request_data):
|
||||
batch_size = len(batch_request_data['batch_data_list'])
|
||||
logger.info(f"batch_generate_product batch_request_data:{json.dumps(batch_request_data, indent=4)}")
|
||||
batch_tasks_id = batch_request_data['batch_tasks_id']
|
||||
user_id = batch_request_data['user_id']
|
||||
result_data_list = []
|
||||
|
||||
for i, data in enumerate(batch_request_data['batch_data_list']):
|
||||
tasks_id = data['tasks_id']
|
||||
image = pre_processing_image(data['image_url'])
|
||||
image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
|
||||
images = [image.astype(np.uint8)] * 1
|
||||
prompts = [data['prompt']] * 1
|
||||
if data['product_type'] == "single":
|
||||
text_obj = np.array(prompts, dtype="object").reshape((-1, 1))
|
||||
image_obj = np.array(images, dtype=np.uint8).reshape((-1, 768, 512, 3))
|
||||
image_strength_obj = np.array(data['image_strength'], dtype=np.float32).reshape((-1, 1))
|
||||
else:
|
||||
text_obj = np.array(prompts, dtype="object").reshape((1))
|
||||
image_obj = np.array(images, dtype=np.uint8).reshape((768, 512, 3))
|
||||
image_strength_obj = np.array(data['image_strength'], dtype=np.float32).reshape((1))
|
||||
input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype))
|
||||
input_image = grpcclient.InferInput("input_image", image_obj.shape, "UINT8")
|
||||
input_image_strength = grpcclient.InferInput("image_strength", image_strength_obj.shape, np_to_triton_dtype(image_strength_obj.dtype))
|
||||
|
||||
input_text.set_data_from_numpy(text_obj)
|
||||
input_image.set_data_from_numpy(image_obj)
|
||||
input_image_strength.set_data_from_numpy(image_strength_obj)
|
||||
|
||||
inputs = [input_text, input_image, input_image_strength]
|
||||
|
||||
try:
|
||||
if data['product_type'] == "single":
|
||||
result = grpc_client.infer(model_name=GPI_MODEL_NAME_SINGLE, inputs=inputs, priority=100)
|
||||
image = result.as_numpy("generated_cnet_image")
|
||||
else:
|
||||
result = grpc_client.infer(model_name=GPI_MODEL_NAME_OVERALL, inputs=inputs, priority=100)
|
||||
image = result.as_numpy("generated_inpaint_image")
|
||||
image_result = Image.fromarray(np.squeeze(image.astype(np.uint8)))
|
||||
except Exception as e:
|
||||
if 'mask_list' in str(e):
|
||||
e_text_obj = np.array(prompts, dtype="object").reshape((-1, 1))
|
||||
e_image_obj = np.array(images, dtype=np.uint8).reshape((-1, 768, 512, 3))
|
||||
e_image_strength_obj = np.array(data['image_strength'], dtype=np.float32).reshape((-1, 1))
|
||||
|
||||
e_input_text = grpcclient.InferInput("prompt", e_text_obj.shape, np_to_triton_dtype(e_text_obj.dtype))
|
||||
e_input_image = grpcclient.InferInput("input_image", e_image_obj.shape, "UINT8")
|
||||
e_input_image_strength = grpcclient.InferInput("image_strength", e_image_strength_obj.shape, np_to_triton_dtype(e_image_strength_obj.dtype))
|
||||
|
||||
e_input_text.set_data_from_numpy(e_text_obj)
|
||||
e_input_image.set_data_from_numpy(e_image_obj)
|
||||
e_input_image_strength.set_data_from_numpy(e_image_strength_obj)
|
||||
|
||||
result = grpc_client.infer(model_name=GPI_MODEL_NAME_SINGLE, inputs=[e_input_text, e_input_image, e_input_image_strength], priority=100)
|
||||
image = result.as_numpy("generated_cnet_image")
|
||||
image_result = Image.fromarray(np.squeeze(image.astype(np.uint8)))
|
||||
else:
|
||||
image_result = str(e)
|
||||
logger.error(image_result)
|
||||
|
||||
if isinstance(image_result, Image.Image):
|
||||
image_url = upload_SDXL_image(image_result, user_id=user_id, category=f"{category}", file_name=f"{tasks_id}-batch-{i}.png")
|
||||
data['product_img'] = image_url
|
||||
result_data_list.append(data)
|
||||
else:
|
||||
image_url = image_result
|
||||
data['product_img'] = image_url
|
||||
result_data_list.append(data)
|
||||
|
||||
# 发送每条结果
|
||||
if DEBUG:
|
||||
logger.info(f" [x]Queue : {BATCH_GPI_RABBITMQ_QUEUES} | tasks_id:{tasks_id} | progress:{i + 1}/{batch_size} | result_data:{data}")
|
||||
print(f" [x]Queue : {BATCH_GPI_RABBITMQ_QUEUES} | tasks_id:{tasks_id} | progress:{i + 1}/{batch_size} | result_data:{data}")
|
||||
else:
|
||||
publish_status(tasks_id, f"{i + 1}/{batch_size}", data)
|
||||
logger.info(f" [x]Queue : {BATCH_GPI_RABBITMQ_QUEUES} | tasks_id:{tasks_id} | progress:{i + 1}/{batch_size} | result_data:{data}")
|
||||
|
||||
# 任务完成,发送所有数据结果
|
||||
if DEBUG:
|
||||
print(result_data_list)
|
||||
logger.info(f" [x]Queue : {BATCH_GPI_RABBITMQ_QUEUES} | batch_tasks_id:{batch_tasks_id} | progress:OK | result_data_list:{result_data_list}")
|
||||
print(f" [x]Queue : {BATCH_GPI_RABBITMQ_QUEUES} | batch_tasks_id:{batch_tasks_id} | progress:OK | result_data_list:{result_data_list}")
|
||||
else:
|
||||
publish_status(batch_tasks_id, f"OK", result_data_list)
|
||||
logger.info(f" [x]Queue : {BATCH_GPI_RABBITMQ_QUEUES} | batch_tasks_id:{batch_tasks_id} | progress:OK | result_data_list:{result_data_list}")
|
||||
|
||||
|
||||
def pre_processing_image(image_url):
|
||||
image = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url[image_url.find('/') + 1:], data_type="PIL")
|
||||
# 目标图片的尺寸
|
||||
target_width = 512
|
||||
target_height = 768
|
||||
|
||||
# 原始图片的尺寸
|
||||
original_width, original_height = image.size
|
||||
|
||||
# 计算宽度和高度的缩放比例
|
||||
width_ratio = target_width / original_width
|
||||
height_ratio = target_height / original_height
|
||||
|
||||
# 选择较小的缩放比例,确保图片能完整放入目标图片中
|
||||
scale_ratio = min(width_ratio, height_ratio)
|
||||
|
||||
# 计算调整后的尺寸
|
||||
new_width = int(original_width * scale_ratio)
|
||||
new_height = int(original_height * scale_ratio)
|
||||
|
||||
# 调整图片大小
|
||||
resized_image = image.resize((new_width, new_height))
|
||||
|
||||
# 创建一个 512x768 的透明图片
|
||||
result_image = Image.new("RGBA", (target_width, target_height), (255, 255, 255, 0))
|
||||
|
||||
# 计算需要粘贴的位置,使图片居中
|
||||
x_offset = (target_width - new_width) // 2
|
||||
y_offset = (target_height - new_height) // 2
|
||||
|
||||
# 将调整大小后的图片粘贴到透明图片上
|
||||
if resized_image.mode == "RGBA":
|
||||
result_image.paste(resized_image, (x_offset, y_offset), mask=resized_image.split()[3])
|
||||
else:
|
||||
result_image.paste(resized_image, (x_offset, y_offset))
|
||||
|
||||
image = np.array(result_image)
|
||||
|
||||
# image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA)
|
||||
return image
|
||||
|
||||
|
||||
def post_processing_image(image, left, top):
|
||||
resized_image = image.resize((int(image.width * (768 / image.height)), 768))
|
||||
# 计算裁剪的坐标
|
||||
left = (resized_image.width - 512) // 2
|
||||
upper = 0
|
||||
right = left + 512
|
||||
lower = 768
|
||||
|
||||
# 进行裁剪
|
||||
cropped_image = resized_image.crop((left, upper, right, lower))
|
||||
return cropped_image
|
||||
|
||||
|
||||
def publish_status(task_id, progress, result):
|
||||
connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
|
||||
channel = connection.channel()
|
||||
channel.queue_declare(queue=BATCH_GPI_RABBITMQ_QUEUES, durable=True)
|
||||
message = {'task_id': task_id, 'progress': progress, "result": result}
|
||||
channel.basic_publish(exchange='',
|
||||
routing_key=BATCH_GPI_RABBITMQ_QUEUES,
|
||||
body=json.dumps(message),
|
||||
properties=pika.BasicProperties(
|
||||
delivery_mode=2,
|
||||
))
|
||||
connection.close()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# rd = BatchGenerateProductImageModel(
|
||||
# tasks_id="123-15-51-89",
|
||||
# image_strength=0.7,
|
||||
# prompt=" The best quality, masterpiece, real image.Outwear,high quality clothing details,8K realistic,HDR",
|
||||
# image_url="aida-results/result_40b1a2fe-e220-11ef-9bfa-0242ac150003.png",
|
||||
# product_type="overall",
|
||||
# batch_size=20
|
||||
# )
|
||||
# batch_generate_product(rd.dict())
|
||||
# rd = {
|
||||
# "user_id": "89",
|
||||
# "batch_data_list": [
|
||||
# {
|
||||
# "tasks_id": "A-123-15-51-89",
|
||||
# "image_strength": 0.7,
|
||||
# "prompt": " The best quality, ma123sterpiece, real image.Outwear,high quality clothing details,8K realistic,HDR",
|
||||
# "image_url": "aida-results/result_40b1a2fe-e220-11ef-9bfa-0242ac150003.png",
|
||||
# "product_type": "overall",
|
||||
# },
|
||||
# {
|
||||
# "tasks_id": "B-123-15-51-89",
|
||||
# "image_strength": 0.7,
|
||||
# "prompt": " The best quality, masterpiece, real image.Outwear123,high quality clothing details,8K realistic,HDR",
|
||||
# "image_url": "aida-results/result_40b1a2fe-e220-11ef-9bfa-0242ac150003.png",
|
||||
# "product_type": "overall",
|
||||
# }
|
||||
# ]
|
||||
# }
|
||||
rd = BatchGenerateProductImageModel(
|
||||
batch_tasks_id="abcd",
|
||||
user_id="89",
|
||||
batch_data_list=[
|
||||
ProductItemModel(
|
||||
tasks_id="123-5464",
|
||||
image_strength=0.7,
|
||||
product_type="overall",
|
||||
image_url="aida-results/result_40b1a2fe-e220-11ef-9bfa-0242ac150003.png",
|
||||
prompt="123"
|
||||
),
|
||||
ProductItemModel(
|
||||
tasks_id="123-5464123",
|
||||
image_strength=0.7,
|
||||
product_type="overall",
|
||||
image_url="aida-results/result_40b1a2fe-e220-11ef-9bfa-0242ac150003.png",
|
||||
prompt="123"
|
||||
)
|
||||
]
|
||||
)
|
||||
batch_generate_product(rd.dict())
|
||||
@@ -0,0 +1,250 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: UTF-8 -*-
|
||||
"""
|
||||
@Project :trinity_client
|
||||
@File :service_att_recognition.py
|
||||
@Author :周成融
|
||||
@Date :2023/7/26 12:01:05
|
||||
@detail :
|
||||
"""
|
||||
import json
|
||||
import logging
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import tritonclient.grpc as grpcclient
|
||||
from PIL import Image
|
||||
from celery import Celery
|
||||
from tritonclient.utils import np_to_triton_dtype
|
||||
|
||||
from app.core.config import *
|
||||
from app.schemas.generate_image import BatchGenerateRelightImageModel, RelightItemModel
|
||||
from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image
|
||||
from app.service.utils.oss_client import oss_get_image
|
||||
|
||||
logger = logging.getLogger()
|
||||
celery_app = Celery('relight_tasks', broker=f'amqp://rabbit:123456@18.167.251.121:5672//', backend='rpc://', BROKER_CONNECTION_RETRY_ON_STARTUP=True)
|
||||
celery_app.conf.task_default_queue = 'queue_relight'
|
||||
celery_app.conf.worker_log_format = '%(asctime)s %(filename)s [line:%(lineno)d] %(levelname)s %(message)s'
|
||||
celery_app.conf.worker_hijack_root_logger = False
|
||||
logging.getLogger('pika').setLevel(logging.WARNING)
|
||||
grpc_client = grpcclient.InferenceServerClient(url=GRI_MODEL_URL)
|
||||
category = "relight_image"
|
||||
|
||||
|
||||
@celery_app.task
|
||||
def batch_generate_relight(batch_request_data):
|
||||
batch_size = len(batch_request_data['batch_data_list'])
|
||||
logger.info(f"batch_generate_relight batch_request_data: {json.dumps(batch_request_data, indent=4)}")
|
||||
batch_tasks_id = batch_request_data['batch_tasks_id']
|
||||
user_id = batch_request_data['user_id']
|
||||
result_data_list = []
|
||||
negative_prompt = 'lowres, bad anatomy, bad hands, cropped, worst quality'
|
||||
seed = "1"
|
||||
|
||||
for i, data in enumerate(batch_request_data['batch_data_list']):
|
||||
direction = data['direction']
|
||||
|
||||
prompt = data['prompt']
|
||||
product_type = data['product_type']
|
||||
image_url = data['image_url']
|
||||
image = pre_processing_image(image_url)
|
||||
tasks_id = data['tasks_id']
|
||||
|
||||
prompts = [prompt] * 1
|
||||
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||
image = cv2.resize(image, (512, 768))
|
||||
images = [image.astype(np.uint8)] * 1
|
||||
seeds = [seed] * 1
|
||||
nagetive_prompts = [negative_prompt] * 1
|
||||
directions = [direction] * 1
|
||||
|
||||
if product_type == 'single':
|
||||
text_obj = np.array(prompts, dtype="object").reshape((-1, 1))
|
||||
image_obj = np.array(images, dtype=np.uint8).reshape((-1, 768, 512, 3))
|
||||
na_text_obj = np.array(nagetive_prompts, dtype="object").reshape((-1, 1))
|
||||
seed_obj = np.array(seeds, dtype="object").reshape((-1, 1))
|
||||
direction_obj = np.array(directions, dtype="object").reshape((-1, 1))
|
||||
else:
|
||||
text_obj = np.array(prompts, dtype="object").reshape((1))
|
||||
image_obj = np.array(images, dtype=np.uint8).reshape((768, 512, 3))
|
||||
na_text_obj = np.array(nagetive_prompts, dtype="object").reshape((1))
|
||||
seed_obj = np.array(seeds, dtype="object").reshape((1))
|
||||
direction_obj = np.array(directions, dtype="object").reshape((1))
|
||||
input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype))
|
||||
input_image = grpcclient.InferInput("input_image", image_obj.shape, "UINT8")
|
||||
input_natext = grpcclient.InferInput("negative_prompt", na_text_obj.shape, np_to_triton_dtype(na_text_obj.dtype))
|
||||
input_seed = grpcclient.InferInput("seed", seed_obj.shape, np_to_triton_dtype(seed_obj.dtype))
|
||||
input_direction = grpcclient.InferInput("direction", direction_obj.shape, np_to_triton_dtype(direction_obj.dtype))
|
||||
|
||||
input_text.set_data_from_numpy(text_obj)
|
||||
input_image.set_data_from_numpy(image_obj)
|
||||
input_natext.set_data_from_numpy(na_text_obj)
|
||||
input_seed.set_data_from_numpy(seed_obj)
|
||||
input_direction.set_data_from_numpy(direction_obj)
|
||||
|
||||
inputs = [input_text, input_natext, input_image, input_seed, input_direction]
|
||||
try:
|
||||
if data['product_type'] == "single":
|
||||
result = grpc_client.infer(model_name=GRI_MODEL_NAME_SINGLE, inputs=inputs, priority=100)
|
||||
image = result.as_numpy("generated_relight_image")
|
||||
else:
|
||||
result = grpc_client.infer(model_name=GRI_MODEL_NAME_OVERALL, inputs=inputs, priority=100)
|
||||
image = result.as_numpy("generated_inpaint_image")
|
||||
image_result = Image.fromarray(np.squeeze(image.astype(np.uint8)))
|
||||
|
||||
except Exception as e:
|
||||
print(e)
|
||||
if 'mask_list' in str(e):
|
||||
e_text_obj = np.array(prompts, dtype="object").reshape((-1, 1))
|
||||
e_image_obj = np.array(images, dtype=np.uint8).reshape((-1, 768, 512, 3))
|
||||
e_na_text_obj = np.array(nagetive_prompts, dtype="object").reshape((-1, 1))
|
||||
e_seed_obj = np.array(seeds, dtype="object").reshape((-1, 1))
|
||||
e_direction_obj = np.array(directions, dtype="object").reshape((-1, 1))
|
||||
|
||||
e_input_text = grpcclient.InferInput("prompt", e_text_obj.shape, np_to_triton_dtype(e_text_obj.dtype))
|
||||
e_input_image = grpcclient.InferInput("input_image", e_image_obj.shape, "UINT8")
|
||||
e_input_natext = grpcclient.InferInput("negative_prompt", e_na_text_obj.shape, np_to_triton_dtype(e_na_text_obj.dtype))
|
||||
e_input_seed = grpcclient.InferInput("seed", e_seed_obj.shape, np_to_triton_dtype(e_seed_obj.dtype))
|
||||
e_input_direction = grpcclient.InferInput("direction", e_direction_obj.shape, np_to_triton_dtype(e_direction_obj.dtype))
|
||||
|
||||
e_input_text.set_data_from_numpy(e_text_obj)
|
||||
e_input_image.set_data_from_numpy(e_image_obj)
|
||||
e_input_natext.set_data_from_numpy(e_na_text_obj)
|
||||
e_input_seed.set_data_from_numpy(e_seed_obj)
|
||||
e_input_direction.set_data_from_numpy(e_direction_obj)
|
||||
|
||||
e_inputs = [e_input_text, e_input_natext, e_input_image, e_input_seed, e_input_direction]
|
||||
|
||||
result = grpc_client.infer(model_name=GRI_MODEL_NAME_SINGLE, inputs=e_inputs, priority=100)
|
||||
image = result.as_numpy("generated_relight_image")
|
||||
image_result = Image.fromarray(np.squeeze(image.astype(np.uint8)))
|
||||
else:
|
||||
image_result = str(e)
|
||||
logger.error(e)
|
||||
if isinstance(image_result, Image.Image):
|
||||
image_url = upload_SDXL_image(image_result, user_id=user_id, category=f"{category}", file_name=f"{tasks_id}-batch-{i}.png")
|
||||
data['relight_img'] = image_url
|
||||
|
||||
result_data_list.append(data)
|
||||
else:
|
||||
image_url = image_result
|
||||
data['relight_img'] = image_url
|
||||
result_data_list.append(data)
|
||||
|
||||
# 发送每条结果
|
||||
if DEBUG:
|
||||
logger.info(f" [x]Queue : {BATCH_GRI_RABBITMQ_QUEUES} | tasks_id:{tasks_id} | progress:{i + 1}/{batch_size} | result_data:{data}")
|
||||
print(f" [x]Queue : {BATCH_GRI_RABBITMQ_QUEUES} | tasks_id:{tasks_id} | progress:{i + 1}/{batch_size} | result_data:{data}")
|
||||
else:
|
||||
publish_status(tasks_id, f"{i + 1}/{batch_size}", data)
|
||||
logger.info(f" [x]Queue : {BATCH_GRI_RABBITMQ_QUEUES} | tasks_id:{tasks_id} | progress:{i + 1}/{batch_size} | result_data:{data}")
|
||||
# 任务完成,发送所有数据结果
|
||||
if DEBUG:
|
||||
print(result_data_list)
|
||||
logger.info(f" [x]Queue : {BATCH_GRI_RABBITMQ_QUEUES} | batch_tasks_id:{batch_tasks_id} | progress:OK | result_data_list:{result_data_list}")
|
||||
print(f" [x]Queue : {BATCH_GRI_RABBITMQ_QUEUES} | batch_tasks_id:{batch_tasks_id} | progress:OK | result_data_list:{result_data_list}")
|
||||
else:
|
||||
publish_status(batch_tasks_id, f"OK", result_data_list)
|
||||
logger.info(f" [x]Queue : {BATCH_GRI_RABBITMQ_QUEUES} | batch_tasks_id:{batch_tasks_id} | progress:OK | result_data_list:{result_data_list}")
|
||||
|
||||
|
||||
def pre_processing_image(image_url):
|
||||
image = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url[image_url.find('/') + 1:], data_type="PIL")
|
||||
# 目标图片的尺寸
|
||||
target_width = 512
|
||||
target_height = 768
|
||||
|
||||
# 原始图片的尺寸
|
||||
original_width, original_height = image.size
|
||||
|
||||
# 计算宽度和高度的缩放比例
|
||||
width_ratio = target_width / original_width
|
||||
height_ratio = target_height / original_height
|
||||
|
||||
# 选择较小的缩放比例,确保图片能完整放入目标图片中
|
||||
scale_ratio = min(width_ratio, height_ratio)
|
||||
|
||||
# 计算调整后的尺寸
|
||||
new_width = int(original_width * scale_ratio)
|
||||
new_height = int(original_height * scale_ratio)
|
||||
|
||||
# 调整图片大小
|
||||
resized_image = image.resize((new_width, new_height))
|
||||
|
||||
# 创建一个 512x768 的透明图片
|
||||
result_image = Image.new("RGBA", (target_width, target_height), (255, 255, 255, 0))
|
||||
|
||||
# 计算需要粘贴的位置,使图片居中
|
||||
x_offset = (target_width - new_width) // 2
|
||||
y_offset = (target_height - new_height) // 2
|
||||
|
||||
# 将调整大小后的图片粘贴到透明图片上
|
||||
if resized_image.mode == "RGBA":
|
||||
result_image.paste(resized_image, (x_offset, y_offset), mask=resized_image.split()[3])
|
||||
else:
|
||||
result_image.paste(resized_image, (x_offset, y_offset))
|
||||
|
||||
image = np.array(result_image)
|
||||
|
||||
# image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA)
|
||||
return image
|
||||
|
||||
|
||||
def publish_status(task_id, progress, result):
|
||||
connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
|
||||
channel = connection.channel()
|
||||
channel.queue_declare(queue=BATCH_GRI_RABBITMQ_QUEUES, durable=True)
|
||||
message = {'task_id': task_id, 'progress': progress, "result": result}
|
||||
channel.basic_publish(exchange='',
|
||||
routing_key=BATCH_GRI_RABBITMQ_QUEUES,
|
||||
body=json.dumps(message),
|
||||
properties=pika.BasicProperties(
|
||||
delivery_mode=2,
|
||||
))
|
||||
connection.close()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
rd = BatchGenerateRelightImageModel(
|
||||
batch_tasks_id="abcd",
|
||||
user_id="89",
|
||||
batch_data_list=[
|
||||
RelightItemModel(
|
||||
tasks_id="123-5464",
|
||||
product_type="overall",
|
||||
image_url="test/703190759.png",
|
||||
prompt="Colorful black",
|
||||
direction="Right Light",
|
||||
),
|
||||
RelightItemModel(
|
||||
tasks_id="123-5464123",
|
||||
product_type="overall",
|
||||
image_url="test/703190759.png",
|
||||
direction="Right Light",
|
||||
prompt="Colorful black",
|
||||
)
|
||||
]
|
||||
)
|
||||
batch_generate_relight(rd.dict())
|
||||
# X = {
|
||||
# "batch_tasks_id": "abcd",
|
||||
# "user_id": "89",
|
||||
# "batch_data_list": [
|
||||
# {
|
||||
# "tasks_id": "123-5464",
|
||||
# "product_type": "overall",
|
||||
# "image_url": "aida-users/89/product_image/02894523-19b5-46eb-a9c6-2f512f5fec84-0-89.png",
|
||||
# "prompt": "Colorful black",
|
||||
# "direction": "Right Light",
|
||||
# },
|
||||
# {
|
||||
# "tasks_id": "123-5464",
|
||||
# "product_type": "overall",
|
||||
# "image_url": "aida-users/89/product_image/02894523-19b5-46eb-a9c6-2f512f5fec84-0-89.png",
|
||||
# "prompt": "Colorful black",
|
||||
# "direction": "Right Light",
|
||||
# }
|
||||
#
|
||||
# ]
|
||||
# }
|
||||
176
app/service/generate_batch_image/service_batch_pose_transform.py
Normal file
176
app/service/generate_batch_image/service_batch_pose_transform.py
Normal file
@@ -0,0 +1,176 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: UTF-8 -*-
|
||||
"""
|
||||
@Project :trinity_client
|
||||
@File :service_att_recognition.py
|
||||
@Author :周成融
|
||||
@Date :2023/7/26 12:01:05
|
||||
@detail :
|
||||
"""
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
from io import BytesIO
|
||||
|
||||
import imageio
|
||||
import numpy as np
|
||||
import tritonclient.grpc as grpcclient
|
||||
from PIL import Image
|
||||
from celery import Celery
|
||||
from minio import Minio
|
||||
from tritonclient.utils import np_to_triton_dtype
|
||||
|
||||
from app.core.config import *
|
||||
from app.schemas.pose_transform import BatchPoseTransformModel
|
||||
from app.service.generate_image.utils.pose_transform_upload import upload_gif, upload_video
|
||||
from app.service.utils.new_oss_client import oss_upload_image
|
||||
from app.service.utils.oss_client import oss_get_image
|
||||
|
||||
minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||
|
||||
logger = logging.getLogger()
|
||||
celery_app = Celery('post_transform_tasks', broker=f'amqp://rabbit:123456@18.167.251.121:5672//', backend='rpc://', BROKER_CONNECTION_RETRY_ON_STARTUP=True)
|
||||
celery_app.conf.task_default_queue = 'queue_post_transform'
|
||||
celery_app.conf.worker_log_format = '%(asctime)s %(filename)s [line:%(lineno)d] %(levelname)s %(message)s'
|
||||
celery_app.conf.worker_hijack_root_logger = False
|
||||
logging.getLogger('pika').setLevel(logging.WARNING)
|
||||
grpc_client = grpcclient.InferenceServerClient(url=PT_MODEL_URL)
|
||||
category = "pose_transform"
|
||||
|
||||
|
||||
def upload_first_image(image, user_id, category, file_name):
|
||||
try:
|
||||
image_data = io.BytesIO()
|
||||
image.save(image_data, format='PNG')
|
||||
image_data.seek(0)
|
||||
image_bytes = image_data.read()
|
||||
object_name = f'{user_id}/{category}/{file_name}'
|
||||
req = oss_upload_image(oss_client=minio_client, bucket=GI_MINIO_BUCKET, object_name=object_name, image_bytes=image_bytes)
|
||||
image_url = f"aida-users/{object_name}"
|
||||
return image_url
|
||||
except Exception as e:
|
||||
logging.warning(f"upload_png_mask runtime exception : {e}")
|
||||
|
||||
|
||||
def pre_processing_image(image_url):
|
||||
image = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url[image_url.find('/') + 1:], data_type="PIL")
|
||||
# 目标图片的尺寸
|
||||
target_width = 512
|
||||
target_height = 768
|
||||
|
||||
# 原始图片的尺寸
|
||||
original_width, original_height = image.size
|
||||
|
||||
# 计算宽度和高度的缩放比例
|
||||
width_ratio = target_width / original_width
|
||||
height_ratio = target_height / original_height
|
||||
|
||||
# 选择较小的缩放比例,确保图片能完整放入目标图片中
|
||||
scale_ratio = min(width_ratio, height_ratio)
|
||||
|
||||
# 计算调整后的尺寸
|
||||
new_width = int(original_width * scale_ratio)
|
||||
new_height = int(original_height * scale_ratio)
|
||||
|
||||
# 调整图片大小
|
||||
resized_image = image.resize((new_width, new_height))
|
||||
|
||||
# 创建一个 512x768 的透明图片
|
||||
result_image = Image.new("RGBA", (target_width, target_height), (255, 255, 255, 0))
|
||||
|
||||
# 计算需要粘贴的位置,使图片居中
|
||||
x_offset = (target_width - new_width) // 2
|
||||
y_offset = (target_height - new_height) // 2
|
||||
|
||||
# 将调整大小后的图片粘贴到透明图片上
|
||||
if resized_image.mode == "RGBA":
|
||||
result_image.paste(resized_image, (x_offset, y_offset), mask=resized_image.split()[3])
|
||||
else:
|
||||
result_image.paste(resized_image, (x_offset, y_offset))
|
||||
result_image = result_image.convert("RGB")
|
||||
image = np.array(result_image)
|
||||
|
||||
# image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA)
|
||||
|
||||
return image
|
||||
|
||||
|
||||
@celery_app.task
|
||||
def batch_generate_pose_transform(batch_request_data):
|
||||
logger.info(f"batch_generate_pose_transform batch_request_data: {json.dumps(batch_request_data, indent=4)}")
|
||||
batch_size = batch_request_data['batch_size']
|
||||
image_url = batch_request_data['image_url']
|
||||
image = pre_processing_image(image_url)
|
||||
pose_num = batch_request_data['pose_id']
|
||||
tasks_id = batch_request_data['tasks_id']
|
||||
user_id = tasks_id.rsplit('-', 1)[1]
|
||||
|
||||
pose_num = [pose_num] * 1
|
||||
pose_num_obj = np.array(pose_num, dtype="object").reshape((-1, 1))
|
||||
input_pose_num = grpcclient.InferInput("pose_num", pose_num_obj.shape, np_to_triton_dtype(pose_num_obj.dtype))
|
||||
input_pose_num.set_data_from_numpy(pose_num_obj)
|
||||
|
||||
image_files = [image.astype(np.uint8)] * 1
|
||||
image_files_obj = np.array(image_files, dtype=np.uint8).reshape((-1, 768, 512, 3))
|
||||
input_image_files = grpcclient.InferInput("image_file", image_files_obj.shape, "UINT8")
|
||||
input_image_files.set_data_from_numpy(image_files_obj)
|
||||
|
||||
result_url_list = []
|
||||
for i in range(batch_size):
|
||||
try:
|
||||
result = grpc_client.infer(model_name="animatex_1", inputs=[input_pose_num, input_image_files], client_timeout=60000, priority=100)
|
||||
result_data = np.squeeze(result.as_numpy("generated_image_list").astype(np.uint8))[:, :, :, ::-1]
|
||||
# 第一帧图像
|
||||
first_image = Image.fromarray(result_data[0])
|
||||
first_image_url = upload_first_image(first_image, user_id=user_id, category=f"{category}_first_img", file_name=f"{tasks_id}_batch_{i}.png")
|
||||
|
||||
# 上传GIF
|
||||
gif_buffer = BytesIO()
|
||||
imageio.mimsave(gif_buffer, result_data, format='GIF', fps=5)
|
||||
gif_buffer.seek(0)
|
||||
gif_url = upload_gif(gif_buffer=gif_buffer, user_id=user_id, category=f"{category}_gif", file_name=f"{tasks_id}_batch_{i}.gif")
|
||||
|
||||
# 上传video
|
||||
video_url = upload_video(frames=result_data, user_id=user_id, category=f"{category}_video", file_name=f"{tasks_id}_batch_{i}.mp4")
|
||||
data = {
|
||||
"gif_url": gif_url,
|
||||
"video_url": video_url,
|
||||
"first_image_url": first_image_url,
|
||||
}
|
||||
except Exception as e:
|
||||
print(e)
|
||||
data = {}
|
||||
result_url_list.append(data)
|
||||
if DEBUG is False:
|
||||
if i + 1 < batch_size:
|
||||
publish_status(tasks_id, f"{i + 1}/{batch_size}", data)
|
||||
logger.info(f" [x]Queue : {BATCH_PS_RABBITMQ_QUEUES} | tasks_id:{tasks_id} | progress:{i + 1}/{batch_size} | image_url:{image_url}")
|
||||
# print(f" [x]Queue : {BATCH_GRI_RABBITMQ_QUEUES} | tasks_id:{tasks_id} | progress:{i + 1}/{batch_size} | image_url:{image_url}")
|
||||
else:
|
||||
publish_status(tasks_id, f"OK", result_url_list)
|
||||
logger.info(f" [x]Queue : {BATCH_PS_RABBITMQ_QUEUES} | tasks_id:{tasks_id} | progress:OK | image_url:{image_url}")
|
||||
# print(f" [x]Queue : {BATCH_PS_RABBITMQ_QUEUES} | tasks_id:{tasks_id} | progress:OK | image_url:{image_url}")
|
||||
|
||||
|
||||
def publish_status(task_id, progress, result):
|
||||
connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
|
||||
channel = connection.channel()
|
||||
channel.queue_declare(queue=BATCH_PS_RABBITMQ_QUEUES, durable=True)
|
||||
message = {'task_id': task_id, 'progress': progress, "result": result}
|
||||
channel.basic_publish(exchange='',
|
||||
routing_key=BATCH_PS_RABBITMQ_QUEUES,
|
||||
body=json.dumps(message),
|
||||
properties=pika.BasicProperties(
|
||||
delivery_mode=2,
|
||||
))
|
||||
connection.close()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
rd = BatchPoseTransformModel(
|
||||
tasks_id="123-89",
|
||||
image_url='aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png',
|
||||
pose_id="1",
|
||||
batch_size=10
|
||||
)
|
||||
batch_generate_pose_transform(rd.dict())
|
||||
28
app/service/generate_batch_image/tasks.py
Normal file
28
app/service/generate_batch_image/tasks.py
Normal file
@@ -0,0 +1,28 @@
|
||||
# import logging
|
||||
#
|
||||
# from celery import Celery
|
||||
#
|
||||
# from app.service.generate_batch_image.service_batch_generate_product_image import batch_generate_product
|
||||
# from app.service.generate_batch_image.service_batch_generate_relight_image import batch_generate_relight
|
||||
# from app.service.generate_batch_image.service_batch_pose_transform import batch_generate_pose_transform
|
||||
#
|
||||
# logger = logging.getLogger()
|
||||
# celery_app = Celery('tasks', broker=f'amqp://rabbit:123456@18.167.251.121:5672//', backend='rpc://', BROKER_CONNECTION_RETRY_ON_STARTUP=True)
|
||||
# celery_app.conf.worker_log_format = '%(asctime)s %(filename)s [line:%(lineno)d] %(levelname)s %(message)s'
|
||||
# celery_app.conf.worker_hijack_root_logger = False
|
||||
# logging.getLogger('pika').setLevel(logging.WARNING)
|
||||
#
|
||||
#
|
||||
# @celery_app.task
|
||||
# def batch_pose_transform_tasks(batch_request_data):
|
||||
# batch_generate_pose_transform(batch_request_data)
|
||||
#
|
||||
#
|
||||
# @celery_app.task
|
||||
# def batch_generate_relight_tasks(batch_request_data):
|
||||
# batch_generate_relight(batch_request_data)
|
||||
#
|
||||
#
|
||||
# @celery_app.task
|
||||
# def batch_generate_product_tasks(batch_request_data):
|
||||
# batch_generate_product(batch_request_data)
|
||||
36
app/service/generate_batch_image/test.py
Normal file
36
app/service/generate_batch_image/test.py
Normal file
@@ -0,0 +1,36 @@
|
||||
from app.schemas.generate_image import BatchGenerateRelightImageModel, BatchGenerateProductImageModel
|
||||
from app.service.generate_batch_image.service_batch_generate_product_image import batch_generate_product
|
||||
|
||||
from app.service.generate_batch_image.service_batch_generate_relight_image import batch_generate_relight
|
||||
|
||||
if __name__ == '__main__':
|
||||
rd = BatchGenerateProductImageModel(
|
||||
tasks_id="test1-89",
|
||||
image_strength=0.7,
|
||||
prompt=" The best quality, masterpiece, real image.Outwear,high quality clothing details,8K realistic,HDR",
|
||||
image_url="aida-results/result_40b1a2fe-e220-11ef-9bfa-0242ac150003.png",
|
||||
product_type="single",
|
||||
batch_size=2
|
||||
)
|
||||
x = batch_generate_product.delay(rd.dict())
|
||||
print(x)
|
||||
|
||||
"""relight"""
|
||||
# rd = BatchGenerateRelightImageModel(
|
||||
# tasks_id="123-89",
|
||||
# # prompt="beautiful woman, detailed face, sunshine, outdoor, warm atmosphere",
|
||||
# prompt="Colorful black",
|
||||
# image_url='aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png',
|
||||
# direction="Right Light",
|
||||
# product_type="single",
|
||||
# batch_size=2
|
||||
# )
|
||||
# batch_generate_relight.delay(rd.dict())
|
||||
"""pose transform"""
|
||||
# rd = BatchPoseTransformModel(
|
||||
# tasks_id="123-89",
|
||||
# image_url='aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png',
|
||||
# pose_id="1",
|
||||
# batch_size=10
|
||||
# )
|
||||
# batch_pose_transform_tasks.delay(rd.dict())
|
||||
149
app/service/generate_image/service_agent_tool_generate_image.py
Normal file
149
app/service/generate_image/service_agent_tool_generate_image.py
Normal file
@@ -0,0 +1,149 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: UTF-8 -*-
|
||||
"""
|
||||
@Project :trinity_client
|
||||
@File :service_att_recognition.py
|
||||
@Author :周成融
|
||||
@Date :2023/7/26 12:01:05
|
||||
@detail :
|
||||
"""
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
import cv2
|
||||
import mmcv
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
import tritonclient.http as httpclient
|
||||
import cv2
|
||||
import numpy as np
|
||||
import tritonclient.grpc as grpcclient
|
||||
from minio import Minio
|
||||
from tritonclient.utils import np_to_triton_dtype
|
||||
from app.core.config import *
|
||||
from app.service.utils.new_oss_client import oss_upload_image
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||
|
||||
|
||||
class AgentToolGenerateImage:
|
||||
def __init__(self, version):
|
||||
if version == "fast":
|
||||
self.grpc_client = grpcclient.InferenceServerClient(url=FAST_GI_MODEL_URL)
|
||||
else:
|
||||
self.grpc_client = grpcclient.InferenceServerClient(url=GI_MODEL_URL)
|
||||
self.image = np.random.randint(0, 256, (1024, 1024, 3), dtype=np.uint8)
|
||||
self.triton_client = httpclient.InferenceServerClient(url=DESIGN_MODEL_URL)
|
||||
|
||||
def get_result(self, prompt, size, version, category, gender):
|
||||
|
||||
image_url_list = []
|
||||
image_result_list = []
|
||||
clothing_category_list = []
|
||||
try:
|
||||
prompts = [prompt] * 1
|
||||
modes = ["txt2img"] * 1
|
||||
images = [self.image.astype(np.float16)] * 1
|
||||
|
||||
text_obj = np.array(prompts, dtype="object").reshape((-1, 1))
|
||||
mode_obj = np.array(modes, dtype="object").reshape((-1, 1))
|
||||
image_obj = np.array(images, dtype=np.float16).reshape((-1, 1024, 1024, 3))
|
||||
|
||||
input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype))
|
||||
input_image = grpcclient.InferInput("input_image", image_obj.shape, np_to_triton_dtype(image_obj.dtype))
|
||||
input_mode = grpcclient.InferInput("mode", mode_obj.shape, np_to_triton_dtype(mode_obj.dtype))
|
||||
|
||||
input_text.set_data_from_numpy(text_obj)
|
||||
input_image.set_data_from_numpy(image_obj)
|
||||
input_mode.set_data_from_numpy(mode_obj)
|
||||
|
||||
inputs = [input_text, input_image, input_mode]
|
||||
for i in range(size):
|
||||
if version == "fast":
|
||||
response = self.grpc_client.infer(model_name=FAST_GI_MODEL_NAME, inputs=inputs, priority=0)
|
||||
else:
|
||||
response = self.grpc_client.infer(model_name=GI_MODEL_NAME, inputs=inputs, priority=0)
|
||||
image = response.as_numpy("generated_image")
|
||||
image_result = cv2.cvtColor(np.squeeze(image.astype(np.uint8)), cv2.COLOR_RGB2BGR)
|
||||
_, img_byte_array = cv2.imencode('.jpg', image_result)
|
||||
|
||||
req = oss_upload_image(oss_client=minio_client, bucket='test', object_name=f'{uuid.uuid1()}-{i}.jpg', image_bytes=img_byte_array)
|
||||
image_url_list.append(f"{req.bucket_name}/{req.object_name}")
|
||||
image_result_list.append(image_result)
|
||||
|
||||
if category == "sketch":
|
||||
clothing_category_list = self.get_clothing_category(image_result_list, gender)
|
||||
|
||||
return image_url_list, clothing_category_list
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
return image_url_list, clothing_category_list
|
||||
finally:
|
||||
self.grpc_client.close()
|
||||
self.triton_client.close()
|
||||
|
||||
def preprocess(self, img):
|
||||
img = mmcv.imread(img)
|
||||
img_scale = (224, 224)
|
||||
img = cv2.resize(img, img_scale)
|
||||
img = mmcv.imnormalize(
|
||||
img,
|
||||
mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]),
|
||||
to_rgb=True)
|
||||
preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
|
||||
return preprocessed_img
|
||||
|
||||
def get_category(self, image):
|
||||
inputs = [httpclient.InferInput("input__0", image.shape, datatype="FP32")]
|
||||
inputs[0].set_data_from_numpy(image, binary_data=True)
|
||||
results = self.triton_client.infer(model_name="attr_retrieve_category", inputs=inputs)
|
||||
inference_output = torch.from_numpy(results.as_numpy(f'output__0'))
|
||||
scores = inference_output.detach().numpy()
|
||||
colattr = list(attr_type['labelName'])
|
||||
maxsc = np.max(scores[0][:5])
|
||||
indexs = np.argwhere(scores == maxsc)[:, 1]
|
||||
return colattr[indexs[0]]
|
||||
|
||||
def get_clothing_category(self, images, gender):
|
||||
category_list = []
|
||||
for image in images:
|
||||
sketch = self.preprocess(image)
|
||||
if gender.lower() == "female":
|
||||
category_list.append(self.get_category(sketch))
|
||||
elif gender.lower() == "male":
|
||||
category = self.get_category(sketch)
|
||||
if category == 'Trousers' or category == 'Skirt':
|
||||
category_list.append('Bottoms')
|
||||
elif category == 'Blouse' or category == 'Dress':
|
||||
category_list.append('Tops')
|
||||
else:
|
||||
category_list.append('Outwear')
|
||||
else:
|
||||
category_list.append(self.get_category(sketch))
|
||||
return category_list
|
||||
|
||||
|
||||
attr_type = pd.read_csv(CATEGORY_PATH)
|
||||
|
||||
if __name__ == '__main__':
|
||||
request_data = {
|
||||
"prompt": "a single item of sketch of Wabi-sabi, skirt, tiered, 4k, white background",
|
||||
"category": "sketch",
|
||||
"version": "high",
|
||||
"size": 2,
|
||||
"gender": "Female",
|
||||
}
|
||||
server = AgentToolGenerateImage(request_data['version'])
|
||||
image_url_list, clothing_category_list = server.get_result(
|
||||
prompt=request_data['prompt'],
|
||||
size=request_data['size'],
|
||||
version=request_data['version'],
|
||||
category=request_data['category'],
|
||||
gender=request_data['gender']
|
||||
)
|
||||
|
||||
print(image_url_list)
|
||||
print(clothing_category_list)
|
||||
@@ -21,6 +21,7 @@ from tritonclient.utils import np_to_triton_dtype
|
||||
from app.core.config import *
|
||||
from app.schemas.generate_image import GenerateImageModel
|
||||
from app.service.generate_image.utils.image_processing import remove_background, stain_detection, generate_category_recognition, autoLevels, luminance_adjust
|
||||
from app.service.generate_image.utils.mq import publish_status
|
||||
from app.service.generate_image.utils.upload_sd_image import upload_png_sd
|
||||
from app.service.utils.oss_client import oss_get_image
|
||||
|
||||
@@ -29,12 +30,6 @@ logger = logging.getLogger()
|
||||
|
||||
class GenerateImage:
|
||||
def __init__(self, request_data):
|
||||
if DEBUG is False:
|
||||
self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
|
||||
self.channel = self.connection.channel()
|
||||
# self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
|
||||
# self.channel = self.connection.channel()
|
||||
# self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||
self.version = request_data.version
|
||||
if request_data.version == "fast":
|
||||
self.grpc_client = grpcclient.InferenceServerClient(url=FAST_GI_MODEL_URL)
|
||||
@@ -153,15 +148,14 @@ class GenerateImage:
|
||||
|
||||
inputs = [input_text, input_image, input_mode]
|
||||
if self.version == "fast":
|
||||
ctx = self.grpc_client.async_infer(model_name=FAST_GI_MODEL_NAME, inputs=inputs, callback=self.callback)
|
||||
ctx = self.grpc_client.async_infer(model_name=FAST_GI_MODEL_NAME, inputs=inputs, callback=self.callback, priority=1)
|
||||
else:
|
||||
ctx = self.grpc_client.async_infer(model_name=GI_MODEL_NAME, inputs=inputs, callback=self.callback)
|
||||
ctx = self.grpc_client.async_infer(model_name=GI_MODEL_NAME, inputs=inputs, callback=self.callback, priority=1)
|
||||
|
||||
time_out = 600
|
||||
generate_data = None
|
||||
while time_out > 0:
|
||||
generate_data, _ = self.read_tasks_status()
|
||||
# logger.info(generate_data)
|
||||
if generate_data['status'] in ["REVOKED", "FAILURE"]:
|
||||
ctx.cancel()
|
||||
break
|
||||
@@ -169,7 +163,6 @@ class GenerateImage:
|
||||
break
|
||||
time_out -= 1
|
||||
time.sleep(0.1)
|
||||
# logger.info(time_out, generate_data)
|
||||
return generate_data
|
||||
except Exception as e:
|
||||
self.generate_data['status'] = "FAILURE"
|
||||
@@ -178,10 +171,8 @@ class GenerateImage:
|
||||
raise Exception(str(e))
|
||||
finally:
|
||||
dict_generate_data, str_generate_data = self.read_tasks_status()
|
||||
if DEBUG is False:
|
||||
self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=str_generate_data)
|
||||
# self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=str_generate_data)
|
||||
logger.info(f" [x] Sent {json.dumps(dict_generate_data, indent=4)}")
|
||||
if not DEBUG:
|
||||
publish_status(str_generate_data, GI_RABBITMQ_QUEUES)
|
||||
|
||||
|
||||
def infer_cancel(tasks_id):
|
||||
@@ -195,7 +186,7 @@ def infer_cancel(tasks_id):
|
||||
if __name__ == '__main__':
|
||||
rd = GenerateImageModel(
|
||||
tasks_id="123-89",
|
||||
prompt='a single item of sketch of Wabi-sabi, skirt, tiered, 4k, white background',
|
||||
prompt="Women's clothing ,dress,technical drawing style, clean line art, no shading, no texture, flat sketch, no human body, no face, centered composition, pure white background, single garmentsingle garment only, front flat view",
|
||||
image_url="aida-collection-element/87/Printboard/842c09cf-7297-42d9-9e6e-9c17d4a13cb5.jpg",
|
||||
mode='txt2img',
|
||||
category="test",
|
||||
|
||||
@@ -17,6 +17,7 @@ import tritonclient.grpc as grpcclient
|
||||
|
||||
from app.core.config import *
|
||||
from app.schemas.generate_image import GenerateMultiViewModel
|
||||
from app.service.generate_image.utils.mq import publish_status
|
||||
from app.service.generate_image.utils.upload_sd_image import upload_png_sd
|
||||
from app.service.utils.oss_client import oss_get_image
|
||||
|
||||
@@ -25,14 +26,7 @@ logger = logging.getLogger()
|
||||
|
||||
class GenerateMultiView:
|
||||
def __init__(self, request_data):
|
||||
if DEBUG is False:
|
||||
self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
|
||||
self.channel = self.connection.channel()
|
||||
# self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
|
||||
# self.channel = self.connection.channel()
|
||||
# self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||
self.grpc_client = grpcclient.InferenceServerClient(url=GMV_MODEL_URL)
|
||||
|
||||
self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
|
||||
self.image = self.get_image(request_data.image_url)
|
||||
self.tasks_id = request_data.tasks_id
|
||||
@@ -52,16 +46,11 @@ class GenerateMultiView:
|
||||
if error:
|
||||
self.generate_data['status'] = "FAILURE"
|
||||
self.generate_data['message'] = str(error)
|
||||
# self.generate_data['data'] = str(error)
|
||||
self.redis_client.set(self.tasks_id, json.dumps(self.generate_data))
|
||||
else:
|
||||
# pil图像转成numpy数组
|
||||
images = result.as_numpy("generated_image")
|
||||
# for id, img in enumerate(images):
|
||||
# cv2.imwrite(f"{id}.png", img)
|
||||
# image_url = ""
|
||||
image_url = upload_png_sd(images[6], user_id=self.user_id, category="multi_view", file_name=f"{self.tasks_id}.png")
|
||||
# logger.info(f"upload image SUCCESS : {image_url}")
|
||||
self.generate_data['status'] = "SUCCESS"
|
||||
self.generate_data['message'] = "success"
|
||||
self.generate_data['image_url'] = str(image_url)
|
||||
@@ -103,10 +92,8 @@ class GenerateMultiView:
|
||||
raise Exception(str(e))
|
||||
finally:
|
||||
dict_generate_data, str_generate_data = self.read_tasks_status()
|
||||
if DEBUG is False:
|
||||
self.channel.basic_publish(exchange='', routing_key=GMV_RABBITMQ_QUEUES, body=str_generate_data)
|
||||
# self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=str_generate_data)
|
||||
logger.info(f" [x] Sent {json.dumps(dict_generate_data, indent=4)}")
|
||||
if not DEBUG:
|
||||
publish_status(str_generate_data, GMV_RABBITMQ_QUEUES)
|
||||
|
||||
|
||||
def infer_cancel(tasks_id):
|
||||
|
||||
@@ -212,6 +212,7 @@ from tritonclient.utils import np_to_triton_dtype
|
||||
|
||||
from app.core.config import *
|
||||
from app.schemas.generate_image import GenerateProductImageModel
|
||||
from app.service.generate_image.utils.mq import publish_status
|
||||
from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image
|
||||
from app.service.utils.oss_client import oss_get_image
|
||||
|
||||
@@ -220,12 +221,6 @@ logger = logging.getLogger()
|
||||
|
||||
class GenerateProductImage:
|
||||
def __init__(self, request_data):
|
||||
if DEBUG is False:
|
||||
self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
|
||||
self.channel = self.connection.channel()
|
||||
# self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
|
||||
# self.channel = self.connection.channel()
|
||||
# self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||
self.grpc_client = grpcclient.InferenceServerClient(url=GPI_MODEL_URL)
|
||||
self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
|
||||
self.category = "product_image"
|
||||
@@ -295,9 +290,9 @@ class GenerateProductImage:
|
||||
inputs = [input_text, input_image, input_image_strength]
|
||||
|
||||
if self.product_type == "single":
|
||||
ctx = self.grpc_client.async_infer(model_name=GPI_MODEL_NAME_SINGLE, inputs=inputs, callback=self.callback)
|
||||
ctx = self.grpc_client.async_infer(model_name=GPI_MODEL_NAME_SINGLE, inputs=inputs, callback=self.callback, priority=1)
|
||||
else:
|
||||
ctx = self.grpc_client.async_infer(model_name=GPI_MODEL_NAME_OVERALL, inputs=inputs, callback=self.callback)
|
||||
ctx = self.grpc_client.async_infer(model_name=GPI_MODEL_NAME_OVERALL, inputs=inputs, callback=self.callback, priority=1)
|
||||
|
||||
time_out = 600
|
||||
while time_out > 0:
|
||||
@@ -318,9 +313,8 @@ class GenerateProductImage:
|
||||
raise Exception(str(e))
|
||||
finally:
|
||||
dict_gen_product_data, str_gen_product_data = self.read_tasks_status()
|
||||
if DEBUG is False:
|
||||
self.channel.basic_publish(exchange='', routing_key=GPI_RABBITMQ_QUEUES, body=str_gen_product_data)
|
||||
logger.info(f" [x] Sent to: {GPI_RABBITMQ_QUEUES} data:@@@@ {json.dumps(dict_gen_product_data, indent=4)}")
|
||||
if not DEBUG:
|
||||
publish_status(str_gen_product_data, GPI_RABBITMQ_QUEUES)
|
||||
|
||||
|
||||
def infer_cancel(tasks_id):
|
||||
|
||||
@@ -20,6 +20,7 @@ from tritonclient.utils import np_to_triton_dtype
|
||||
|
||||
from app.core.config import *
|
||||
from app.schemas.generate_image import GenerateRelightImageModel
|
||||
from app.service.generate_image.utils.mq import publish_status
|
||||
from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image
|
||||
from app.service.utils.oss_client import oss_get_image
|
||||
|
||||
@@ -28,10 +29,6 @@ logger = logging.getLogger()
|
||||
|
||||
class GenerateRelightImage:
|
||||
def __init__(self, request_data):
|
||||
if DEBUG is False:
|
||||
self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
|
||||
self.channel = self.connection.channel()
|
||||
# self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||
self.grpc_client = grpcclient.InferenceServerClient(url=GRI_MODEL_URL)
|
||||
self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
|
||||
self.category = "relight_image"
|
||||
@@ -42,7 +39,7 @@ class GenerateRelightImage:
|
||||
self.negative_prompt = 'lowres, bad anatomy, bad hands, cropped, worst quality'
|
||||
self.direction = request_data.direction
|
||||
self.image_url = request_data.image_url
|
||||
self.image = oss_get_image(bucket=self.image_url.split('/')[0], object_name=self.image_url[self.image_url.find('/') + 1:], data_type="cv2")
|
||||
self.image = pre_processing_image(self.image_url)
|
||||
self.tasks_id = request_data.tasks_id
|
||||
self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
|
||||
self.gen_product_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': ''}
|
||||
@@ -114,9 +111,9 @@ class GenerateRelightImage:
|
||||
|
||||
inputs = [input_text, input_natext, input_image, input_seed, input_direction]
|
||||
if self.product_type == 'single':
|
||||
ctx = self.grpc_client.async_infer(model_name=GRI_MODEL_NAME_SINGLE, inputs=inputs, callback=self.callback)
|
||||
ctx = self.grpc_client.async_infer(model_name=GRI_MODEL_NAME_SINGLE, inputs=inputs, callback=self.callback, priority=1)
|
||||
else:
|
||||
ctx = self.grpc_client.async_infer(model_name=GRI_MODEL_NAME_OVERALL, inputs=inputs, callback=self.callback)
|
||||
ctx = self.grpc_client.async_infer(model_name=GRI_MODEL_NAME_OVERALL, inputs=inputs, callback=self.callback, priority=1)
|
||||
|
||||
time_out = 600
|
||||
while time_out > 0:
|
||||
@@ -137,10 +134,49 @@ class GenerateRelightImage:
|
||||
raise Exception(str(e))
|
||||
finally:
|
||||
dict_gen_product_data, str_gen_product_data = self.read_tasks_status()
|
||||
if DEBUG is False:
|
||||
self.channel.basic_publish(exchange='', routing_key=GRI_RABBITMQ_QUEUES, body=str_gen_product_data)
|
||||
logger.info(f" [x] Sent to: {GRI_RABBITMQ_QUEUES} data:@@@@ {json.dumps(dict_gen_product_data, indent=4)}")
|
||||
if not DEBUG:
|
||||
publish_status(str_gen_product_data, GRI_RABBITMQ_QUEUES)
|
||||
|
||||
def pre_processing_image(image_url):
|
||||
image = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url[image_url.find('/') + 1:], data_type="PIL")
|
||||
# 目标图片的尺寸
|
||||
target_width = 512
|
||||
target_height = 768
|
||||
|
||||
# 原始图片的尺寸
|
||||
original_width, original_height = image.size
|
||||
|
||||
# 计算宽度和高度的缩放比例
|
||||
width_ratio = target_width / original_width
|
||||
height_ratio = target_height / original_height
|
||||
|
||||
# 选择较小的缩放比例,确保图片能完整放入目标图片中
|
||||
scale_ratio = min(width_ratio, height_ratio)
|
||||
|
||||
# 计算调整后的尺寸
|
||||
new_width = int(original_width * scale_ratio)
|
||||
new_height = int(original_height * scale_ratio)
|
||||
|
||||
# 调整图片大小
|
||||
resized_image = image.resize((new_width, new_height))
|
||||
|
||||
# 创建一个 512x768 的透明图片
|
||||
result_image = Image.new("RGBA", (target_width, target_height), (255, 255, 255, 0))
|
||||
|
||||
# 计算需要粘贴的位置,使图片居中
|
||||
x_offset = (target_width - new_width) // 2
|
||||
y_offset = (target_height - new_height) // 2
|
||||
|
||||
# 将调整大小后的图片粘贴到透明图片上
|
||||
if resized_image.mode == "RGBA":
|
||||
result_image.paste(resized_image, (x_offset, y_offset), mask=resized_image.split()[3])
|
||||
else:
|
||||
result_image.paste(resized_image, (x_offset, y_offset))
|
||||
|
||||
image = np.array(result_image)
|
||||
|
||||
# image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA)
|
||||
return image
|
||||
|
||||
def infer_cancel(tasks_id):
|
||||
redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
|
||||
@@ -157,7 +193,7 @@ if __name__ == '__main__':
|
||||
prompt="Colorful black",
|
||||
image_url='aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png',
|
||||
direction="Right Light",
|
||||
product_type="single"
|
||||
product_type="overall"
|
||||
)
|
||||
server = GenerateRelightImage(rd)
|
||||
print(server.get_result())
|
||||
|
||||
@@ -21,6 +21,7 @@ from tritonclient.utils import np_to_triton_dtype
|
||||
from app.core.config import *
|
||||
import tritonclient.grpc as grpcclient
|
||||
from app.schemas.generate_image import GenerateSingleLogoImageModel
|
||||
from app.service.generate_image.utils.mq import publish_status
|
||||
from app.service.generate_image.utils.upload_sd_image import upload_png_sd, upload_SDXL_image
|
||||
|
||||
logger = logging.getLogger()
|
||||
@@ -28,10 +29,6 @@ logger = logging.getLogger()
|
||||
|
||||
class GenerateSingleLogoImage:
|
||||
def __init__(self, request_data):
|
||||
if DEBUG is False:
|
||||
self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
|
||||
self.channel = self.connection.channel()
|
||||
self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||
self.grpc_client = grpcclient.InferenceServerClient(url=GSL_MODEL_URL)
|
||||
self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
|
||||
self.batch_size = 1
|
||||
@@ -96,9 +93,8 @@ class GenerateSingleLogoImage:
|
||||
raise Exception(str(e))
|
||||
finally:
|
||||
dict_generate_data, str_generate_data = self.read_tasks_status()
|
||||
if DEBUG is False:
|
||||
self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=str_generate_data)
|
||||
logger.info(f" [x] Sent {json.dumps(dict_generate_data, indent=4)}")
|
||||
if not DEBUG:
|
||||
publish_status(str_generate_data, GI_RABBITMQ_QUEUES)
|
||||
|
||||
|
||||
def infer_cancel(tasks_id):
|
||||
|
||||
185
app/service/generate_image/service_pose_transform.py
Normal file
185
app/service/generate_image/service_pose_transform.py
Normal file
@@ -0,0 +1,185 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: UTF-8 -*-
|
||||
"""
|
||||
@Project :trinity_client
|
||||
@File :service_pose_transform.py
|
||||
@Author :周成融
|
||||
@Date :2023/7/26 12:01:05
|
||||
@detail :
|
||||
"""
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from io import BytesIO
|
||||
|
||||
import imageio
|
||||
import numpy as np
|
||||
import redis
|
||||
import tritonclient.grpc as grpcclient
|
||||
from PIL import Image
|
||||
from tritonclient.utils import np_to_triton_dtype
|
||||
|
||||
from app.core.config import *
|
||||
from app.schemas.pose_transform import PoseTransformModel
|
||||
from app.service.generate_image.utils.mq import publish_status
|
||||
from app.service.generate_image.utils.pose_transform_upload import upload_gif, upload_video, upload_first_image
|
||||
from app.service.utils.oss_client import oss_get_image
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
class PoseTransformService:
|
||||
def __init__(self, request_data):
|
||||
self.grpc_client = grpcclient.InferenceServerClient(url=PT_MODEL_URL)
|
||||
self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
|
||||
self.category = "pose_transform"
|
||||
self.image_url = request_data.image_url
|
||||
self.pose_num = request_data.pose_id
|
||||
self.image = pre_processing_image(request_data.image_url)
|
||||
self.tasks_id = request_data.tasks_id
|
||||
self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
|
||||
self.pose_transform_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'gif_url': '',
|
||||
'video_url': '', 'image_url': ''}
|
||||
|
||||
self.redis_client.set(self.tasks_id, json.dumps(self.pose_transform_data))
|
||||
self.redis_client.expire(self.tasks_id, 600)
|
||||
|
||||
def callback(self, result, error):
|
||||
if error:
|
||||
self.pose_transform_data['status'] = "FAILURE"
|
||||
self.pose_transform_data['message'] = str(error)
|
||||
self.redis_client.set(self.tasks_id, json.dumps(self.pose_transform_data))
|
||||
else:
|
||||
result_data = np.squeeze(result.as_numpy("generated_image_list").astype(np.uint8))[:, :, :, ::-1]
|
||||
|
||||
# 第一帧图像
|
||||
first_image = Image.fromarray(result_data[0])
|
||||
first_image_url = upload_first_image(first_image, user_id=self.user_id,
|
||||
category=f"{self.category}_first_img",
|
||||
file_name=f"{self.tasks_id}.png")
|
||||
|
||||
# 上传GIF
|
||||
gif_buffer = BytesIO()
|
||||
imageio.mimsave(gif_buffer, result_data, format='GIF', fps=5)
|
||||
gif_buffer.seek(0)
|
||||
gif_url = upload_gif(gif_buffer=gif_buffer, user_id=self.user_id, category=f"{self.category}_gif",
|
||||
file_name=f"{self.tasks_id}.gif")
|
||||
|
||||
# 上传video
|
||||
video_url = upload_video(frames=result_data, user_id=self.user_id, category=f"{self.category}_video",
|
||||
file_name=f"{self.tasks_id}.mp4")
|
||||
|
||||
self.pose_transform_data['status'] = "SUCCESS"
|
||||
self.pose_transform_data['message'] = "success"
|
||||
self.pose_transform_data['gif_url'] = str(gif_url)
|
||||
self.pose_transform_data['video_url'] = str(video_url)
|
||||
self.pose_transform_data['image_url'] = str(first_image_url)
|
||||
|
||||
self.redis_client.set(self.tasks_id, json.dumps(self.pose_transform_data))
|
||||
|
||||
def read_tasks_status(self):
|
||||
status_data = self.redis_client.get(self.tasks_id)
|
||||
return json.loads(status_data), status_data
|
||||
|
||||
def get_result(self):
|
||||
try:
|
||||
pose_num = [self.pose_num] * 1
|
||||
pose_num_obj = np.array(pose_num, dtype="object").reshape((-1, 1))
|
||||
input_pose_num = grpcclient.InferInput("pose_num", pose_num_obj.shape,
|
||||
np_to_triton_dtype(pose_num_obj.dtype))
|
||||
input_pose_num.set_data_from_numpy(pose_num_obj)
|
||||
|
||||
image_files = [self.image.astype(np.uint8)] * 1
|
||||
image_files_obj = np.array(image_files, dtype=np.uint8).reshape((-1, 768, 512, 3))
|
||||
input_image_files = grpcclient.InferInput("image_file", image_files_obj.shape, "UINT8")
|
||||
input_image_files.set_data_from_numpy(image_files_obj)
|
||||
|
||||
ctx = self.grpc_client.async_infer(model_name="animatex_1", inputs=[input_pose_num, input_image_files],
|
||||
callback=self.callback, client_timeout=60000)
|
||||
time_out = 60000
|
||||
while time_out > 0:
|
||||
pose_transform_data, _ = self.read_tasks_status()
|
||||
if pose_transform_data['status'] in ["REVOKED", "FAILURE"]:
|
||||
ctx.cancel()
|
||||
break
|
||||
elif pose_transform_data['status'] == "SUCCESS":
|
||||
break
|
||||
time_out -= 1
|
||||
time.sleep(1)
|
||||
pose_transform_data, _ = self.read_tasks_status()
|
||||
return pose_transform_data
|
||||
except Exception as e:
|
||||
self.pose_transform_data['status'] = "FAILURE"
|
||||
self.pose_transform_data['message'] = str(e)
|
||||
self.redis_client.set(self.tasks_id, json.dumps(self.pose_transform_data))
|
||||
raise Exception(str(e))
|
||||
finally:
|
||||
dict_pose_transform_data, str_pose_transform_data = self.read_tasks_status()
|
||||
if not DEBUG:
|
||||
publish_status(json.dumps(str_pose_transform_data), PS_RABBITMQ_QUEUES)
|
||||
logger.info(
|
||||
f" [x] Sent to: {PS_RABBITMQ_QUEUES} data:@@@@ {json.dumps(dict_pose_transform_data, indent=4)}")
|
||||
|
||||
|
||||
|
||||
|
||||
def infer_cancel(tasks_id):
|
||||
redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
|
||||
data = {'tasks_id': tasks_id, 'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'}
|
||||
pose_transform_data = json.dumps(data)
|
||||
redis_client.set(tasks_id, pose_transform_data)
|
||||
return data
|
||||
|
||||
|
||||
def pre_processing_image(image_url):
|
||||
image = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url[image_url.find('/') + 1:],
|
||||
data_type="PIL")
|
||||
# 目标图片的尺寸
|
||||
target_width = 512
|
||||
target_height = 768
|
||||
|
||||
# 原始图片的尺寸
|
||||
original_width, original_height = image.size
|
||||
|
||||
# 计算宽度和高度的缩放比例
|
||||
width_ratio = target_width / original_width
|
||||
height_ratio = target_height / original_height
|
||||
|
||||
# 选择较小的缩放比例,确保图片能完整放入目标图片中
|
||||
scale_ratio = min(width_ratio, height_ratio)
|
||||
|
||||
# 计算调整后的尺寸
|
||||
new_width = int(original_width * scale_ratio)
|
||||
new_height = int(original_height * scale_ratio)
|
||||
|
||||
# 调整图片大小
|
||||
resized_image = image.resize((new_width, new_height))
|
||||
|
||||
# 创建一个 512x768 的透明图片
|
||||
result_image = Image.new("RGBA", (target_width, target_height), (255, 255, 255, 0))
|
||||
|
||||
# 计算需要粘贴的位置,使图片居中
|
||||
x_offset = (target_width - new_width) // 2
|
||||
y_offset = (target_height - new_height) // 2
|
||||
|
||||
# 将调整大小后的图片粘贴到透明图片上
|
||||
if resized_image.mode == "RGBA":
|
||||
result_image.paste(resized_image, (x_offset, y_offset), mask=resized_image.split()[3])
|
||||
else:
|
||||
result_image.paste(resized_image, (x_offset, y_offset))
|
||||
result_image = result_image.convert("RGB")
|
||||
image = np.array(result_image)
|
||||
|
||||
# image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA)
|
||||
|
||||
return image
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
rd = PoseTransformModel(
|
||||
tasks_id="123-89",
|
||||
image_url='aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png',
|
||||
pose_id="1"
|
||||
)
|
||||
server = PoseTransformService(rd)
|
||||
print(server.get_result())
|
||||
23
app/service/generate_image/utils/mq.py
Normal file
23
app/service/generate_image/utils/mq.py
Normal file
@@ -0,0 +1,23 @@
|
||||
import json
|
||||
|
||||
import pika
|
||||
import logging
|
||||
|
||||
from app.core.config import RABBITMQ_PARAMS
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def publish_status(message, queue_name):
|
||||
connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
|
||||
channel = connection.channel()
|
||||
channel.queue_declare(queue=queue_name, durable=True)
|
||||
channel.basic_publish(exchange='',
|
||||
routing_key=queue_name,
|
||||
body=message,
|
||||
properties=pika.BasicProperties(
|
||||
delivery_mode=2,
|
||||
))
|
||||
connection.close()
|
||||
|
||||
logger.info(f" [x] Queue : {queue_name} | Sent message : {json.dumps(json.loads(message), indent=4)}")
|
||||
75
app/service/generate_image/utils/pose_transform_upload.py
Normal file
75
app/service/generate_image/utils/pose_transform_upload.py
Normal file
@@ -0,0 +1,75 @@
|
||||
import io
|
||||
import logging
|
||||
import os.path
|
||||
|
||||
import numpy as np
|
||||
# import boto3
|
||||
from minio import Minio
|
||||
from moviepy.video.io.ImageSequenceClip import ImageSequenceClip
|
||||
|
||||
from app.core.config import *
|
||||
from app.service.utils.new_oss_client import oss_upload_image
|
||||
|
||||
# minio 配置
|
||||
MINIO_URL = "www.minio-api.aida.com.hk"
|
||||
MINIO_ACCESS = 'vXKFLSJkYeEq2DrSZvkB'
|
||||
MINIO_SECRET = 'uKTZT3x7C43WvPN9QTc99DiRkwddWZrG9Uh3JVlR'
|
||||
MINIO_SECURE = True
|
||||
minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||
|
||||
|
||||
def upload_first_image(image, user_id, category, file_name):
|
||||
try:
|
||||
image_data = io.BytesIO()
|
||||
image.save(image_data, format='PNG')
|
||||
image_data.seek(0)
|
||||
image_bytes = image_data.read()
|
||||
object_name = f'{user_id}/{category}/{file_name}'
|
||||
req = oss_upload_image(oss_client=minio_client, bucket=GI_MINIO_BUCKET, object_name=object_name, image_bytes=image_bytes)
|
||||
image_url = f"aida-users/{object_name}"
|
||||
return image_url
|
||||
except Exception as e:
|
||||
logging.warning(f"upload_png_mask runtime exception : {e}")
|
||||
|
||||
|
||||
def upload_gif(gif_buffer, user_id, category, file_name):
|
||||
try:
|
||||
object_name = f'{user_id}/{category}/{file_name}'
|
||||
req = minio_client.put_object(
|
||||
"aida-users",
|
||||
object_name,
|
||||
gif_buffer,
|
||||
length=gif_buffer.getbuffer().nbytes,
|
||||
content_type="image/gif"
|
||||
)
|
||||
return f"aida-users/{object_name}"
|
||||
except Exception as e:
|
||||
logging.warning(f"upload_gif runtime exception : {e}")
|
||||
|
||||
|
||||
def upload_video(frames, user_id, category, file_name):
|
||||
try:
|
||||
save_path = ndarray_to_video(frames, file_name)
|
||||
object_name = f'{user_id}/{category}/{file_name}'
|
||||
minio_client.fput_object(
|
||||
"aida-users",
|
||||
object_name,
|
||||
save_path,
|
||||
content_type="video/mp4" # 指定MIME类型确保可在线播放[9](@ref)
|
||||
)
|
||||
return f"aida-users/{object_name}"
|
||||
except Exception as e:
|
||||
logging.warning(f"upload_video runtime exception : {e}")
|
||||
|
||||
|
||||
def ndarray_to_video(images, output_path, frame_size=(512, 768), fps=9):
|
||||
save_path = os.path.join(POSE_TRANSFORM_VIDEO_PATH, output_path)
|
||||
clip = ImageSequenceClip([frame for frame in images], fps=fps)
|
||||
clip.write_videofile(save_path, codec='libx264')
|
||||
|
||||
return save_path
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
images = np.random.randint(0, 256, size=(4, 768, 512, 3), dtype=np.uint8)
|
||||
print(upload_video(images, user_id=89, category='pose_transform_video', file_name="1123123.mp4"))
|
||||
114
app/service/mannequins_edit/service.py
Normal file
114
app/service/mannequins_edit/service.py
Normal file
@@ -0,0 +1,114 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from minio import Minio
|
||||
|
||||
from app.core.config import MINIO_URL, MINIO_ACCESS, MINIO_SECRET, MINIO_SECURE
|
||||
from app.schemas.mannequin_edit import MannequinModel
|
||||
from app.service.utils.new_oss_client import oss_get_image, oss_upload_image
|
||||
|
||||
minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||
|
||||
|
||||
class MannequinEditService():
|
||||
def __init__(self, request_data):
|
||||
self.resize_pixel = request_data.resize_pixel
|
||||
self.top = request_data.top
|
||||
self.bottom = request_data.bottom
|
||||
self.image = oss_get_image(oss_client=minio_client, bucket=request_data.mannequins.split('/')[0], object_name=request_data.mannequins[request_data.mannequins.find('/') + 1:], data_type="cv2")
|
||||
self.mannequin_name = request_data.mannequin_name
|
||||
self.bucket_name = request_data.bucket_name
|
||||
if self.image.shape[2] == 4:
|
||||
self.bgr = self.image[:, :, :3]
|
||||
self.alpha = self.image[:, :, 3]
|
||||
self.bgr = cv2.bitwise_and(self.bgr, self.bgr, mask=cv2.normalize(self.alpha, None, 0, 1, cv2.NORM_MINMAX))
|
||||
self.h, self.w, _ = self.bgr.shape
|
||||
else:
|
||||
self.bgr = self.image
|
||||
self.h, self.w, _ = self.bgr.shape
|
||||
self.alpha = None
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
new_mannequin = self.resize_leg(self.top, self.bottom)
|
||||
_, encoded_image = cv2.imencode('.png', new_mannequin)
|
||||
image_bytes = encoded_image.tobytes()
|
||||
req = oss_upload_image(oss_client=minio_client, bucket=self.bucket_name, object_name=f"{self.mannequin_name}.png", image_bytes=image_bytes)
|
||||
return req.bucket_name + "/" + req.object_name
|
||||
|
||||
def post_processing(self, image):
|
||||
# 原始图片的尺寸
|
||||
original_width, original_height = image.size
|
||||
|
||||
# 计算宽度和高度的缩放比例
|
||||
width_ratio = self.w / original_width
|
||||
height_ratio = self.h / original_height
|
||||
|
||||
# 选择较小的缩放比例,确保图片能完整放入目标图片中
|
||||
scale_ratio = min(width_ratio, height_ratio)
|
||||
|
||||
# 计算调整后的尺寸
|
||||
new_width = int(original_width * scale_ratio)
|
||||
new_height = int(original_height * scale_ratio)
|
||||
|
||||
# 调整图片大小
|
||||
resized_image = image.resize((new_width, new_height))
|
||||
|
||||
# 创建一个 512x768 的透明图片
|
||||
result_image = Image.new("RGBA", (self.w, self.h), (255, 255, 255, 0))
|
||||
|
||||
# 计算需要粘贴的位置,使图片居中
|
||||
x_offset = (self.w - new_width) // 2
|
||||
y_offset = (self.h - new_height) // 2
|
||||
|
||||
# 将调整大小后的图片粘贴到透明图片上
|
||||
if resized_image.mode == "RGBA":
|
||||
result_image.paste(resized_image, (x_offset, y_offset), mask=resized_image.split()[3])
|
||||
else:
|
||||
result_image.paste(resized_image, (x_offset, y_offset))
|
||||
|
||||
image = np.array(result_image)
|
||||
return image
|
||||
|
||||
def resize_leg(self, top, bottom):
|
||||
# 上部
|
||||
top_part = self.bgr[:top, :]
|
||||
top_part_alpha = self.alpha[:top, :]
|
||||
|
||||
# 需要resize 部分
|
||||
part_resize = self.bgr[top:bottom, :]
|
||||
part_resize_alpha = self.alpha[top:bottom, :]
|
||||
|
||||
# 下部
|
||||
part_bottom = self.bgr[bottom:, :]
|
||||
part_bottom_alpha = self.alpha[bottom:, :]
|
||||
|
||||
new_height = int((bottom - top) + self.resize_pixel)
|
||||
|
||||
resized_thigh = cv2.resize(part_resize, (self.w, new_height), interpolation=cv2.INTER_LINEAR)
|
||||
resized_thigh_alpha = cv2.resize(part_resize_alpha, (self.w, new_height), interpolation=cv2.INTER_LINEAR)
|
||||
|
||||
# 组合
|
||||
new_bgr = np.vstack((top_part, resized_thigh, part_bottom))
|
||||
new_bgr_alpha = np.vstack((top_part_alpha, resized_thigh_alpha, part_bottom_alpha))
|
||||
|
||||
if self.alpha is not None:
|
||||
# 拼接 alpha 通道
|
||||
# 合并 BGR 通道和 alpha 通道
|
||||
new_image = np.dstack((new_bgr, new_bgr_alpha))
|
||||
else:
|
||||
new_image = new_bgr
|
||||
new_image = self.post_processing(Image.fromarray(new_image))
|
||||
return new_image
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
request_data = MannequinModel(
|
||||
mannequins="aida-sys-image/models/male/dc36ce58-46c3-4b6f-8787-5ca7d6fc26e6.png",
|
||||
resize_pixel=-100,
|
||||
bucket_name="test",
|
||||
mannequin_name="mannequin_name",
|
||||
top=270,
|
||||
bottom=432
|
||||
)
|
||||
service = MannequinEditService(request_data)
|
||||
print(service())
|
||||
68
app/service/project_info_extraction/service.py
Normal file
68
app/service/project_info_extraction/service.py
Normal file
@@ -0,0 +1,68 @@
|
||||
from langchain.output_parsers import ResponseSchema, StructuredOutputParser
|
||||
from langchain_community.chat_models import ChatTongyi
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
|
||||
from app.schemas.project_info_extraction import ProjectInfoExtractionModel
|
||||
|
||||
style = ['NEW_CHINESE', 'COUNTRY_STYLE', 'FUTURISM', 'MINIMALISM', 'LOLITA', 'Y2K', 'BUSINESS', 'MERLAD',
|
||||
'OUTDOOR_FUNCTIONAL', 'ROCK', 'DOPAMINE', 'GOTHIC', 'POST_APOCALYPTIC', 'ROMANTIC', 'WABI_SABI']
|
||||
position = ['Overall', 'Tops', 'Bottoms', 'Outwear', 'Blouse', 'Dress', 'Trousers', 'Skirt']
|
||||
gender = ['Female', 'Male']
|
||||
age_group = ['Adult', 'Child']
|
||||
process = ['SERIES_DESIGN', 'SINGLE_DESIGN']
|
||||
|
||||
|
||||
class ProjectInfoExtraction:
|
||||
def __init__(self, request_data):
|
||||
# llm generate brand info init
|
||||
if len(request_data.image_list) or len(request_data.file_list):
|
||||
self.model = ChatTongyi(model="qwen-vl-plus", api_key="sk-7658298c6b99443c98184a5e634fe6ab")
|
||||
else:
|
||||
self.model = ChatTongyi(model="qwen2.5-14b-instruct", api_key="sk-7658298c6b99443c98184a5e634fe6ab")
|
||||
|
||||
self.response_schemas = [
|
||||
ResponseSchema(name="project_name", description="项目的名称."),
|
||||
ResponseSchema(name="process", description="项目的类型,单品还是系列."),
|
||||
ResponseSchema(name="ageGroup", description="项目设计服装的受众对象."),
|
||||
ResponseSchema(name="gender", description="项目设计服装的受众性别."),
|
||||
ResponseSchema(name="position", description="项目单品设计的部位."),
|
||||
ResponseSchema(name="style", description="项目的设计风格.")
|
||||
]
|
||||
self.output_parser = StructuredOutputParser.from_response_schemas(self.response_schemas)
|
||||
self.format_instructions = self.output_parser.get_format_instructions()
|
||||
self.prompt = PromptTemplate(
|
||||
template="你是一个时装品牌的设计师助理。根据用户输入提取出"
|
||||
"[project_name] : 项目的名称,"
|
||||
f"[process] : 项目的类型,从{process}选择."
|
||||
f"[ageGroup] : 服装的受众,从{age_group}选择."
|
||||
f"[gender] : 服装的适用性别,从{gender}选择"
|
||||
f"[position] : single_design的部位,如果[process]是SINGLE_DESIGN,从{position}中选择,如果[process]是SERIES_DESIGN,这项为空"
|
||||
f"[style] : 设计的风格,从{style}中选择"
|
||||
".\n{format_instructions}\n{question}",
|
||||
input_variables=["question"],
|
||||
partial_variables={"format_instructions": self.format_instructions}
|
||||
)
|
||||
self._input = self.prompt.format_prompt(question=request_data.prompt)
|
||||
|
||||
self.result_data = {}
|
||||
|
||||
def get_result(self):
|
||||
self.llm_extraction_project_info()
|
||||
return self.result_data
|
||||
|
||||
def llm_extraction_project_info(self):
|
||||
output = self.model(self._input.to_messages())
|
||||
project_info = self.output_parser.parse(output.content)
|
||||
self.result_data = project_info
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
request_data = ProjectInfoExtractionModel(
|
||||
prompt="性别为儿童",
|
||||
image_list=[
|
||||
'https://www.minio-api.aida.com.hk/test/019aaeed-3227-11f0-a194-0826ae3ad6b3.jpg?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=vXKFLSJkYeEq2DrSZvkB%2F20250613%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20250613T020236Z&X-Amz-Expires=604800&X-Amz-SignedHeaders=host&X-Amz-Signature=a513b706c24134071a489c34f0fa2c0f510e871b8589dc0c08a0f26ea28ee2ff'
|
||||
],
|
||||
file_list=[]
|
||||
)
|
||||
service = ProjectInfoExtraction(request_data)
|
||||
print(service.get_result())
|
||||
@@ -9,6 +9,7 @@ from retry import retry
|
||||
|
||||
from app.core.config import QWEN_API_KEY
|
||||
from app.service.chat_robot.script.service.CallQWen import get_language
|
||||
from app.service.prompt_generation.util import minio_util
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -143,6 +144,38 @@ def get_translation_from_llama3(text):
|
||||
# response = requests.post(url, data=json.dumps(payload), headers=headers)
|
||||
|
||||
|
||||
def get_prompt_from_image(image_path, text):
|
||||
start_time = time.time()
|
||||
# url = "http://localhost:11434/api/generate"
|
||||
url = "http://10.1.1.243:11434/api/generate"
|
||||
|
||||
image_base64 = minio_util.minio_url_to_base64(image_path.img)
|
||||
# image_base64 = minio_url_to_base64(image_path)
|
||||
|
||||
# 创建请求的负载 translator是自定义的翻译模型
|
||||
payload = {
|
||||
"model": "llama3.2-vision",
|
||||
"images": [image_base64],
|
||||
"prompt": f"{text}",
|
||||
"stream": False
|
||||
}
|
||||
# 将负载转换为 JSON 格式
|
||||
headers = {'Content-Type': 'application/json'}
|
||||
response = requests.post(url, data=json.dumps(payload), headers=headers)
|
||||
# 处理响应
|
||||
if response.status_code == 200:
|
||||
# print("Response from server:")
|
||||
# print(response.json())
|
||||
resp = json.loads(response.content).get("response")
|
||||
logger.info(f"sketch re-generate server runtime is {time.time() - start_time} \n, response is {resp}")
|
||||
# print("input : {}, sketch re-generate result : {}".format(text, resp))
|
||||
return resp
|
||||
else:
|
||||
logger.info(f"sketch re-generate server runtime is {time.time() - start_time} , response is {response.content}")
|
||||
print(f"Request failed with status code {response.status_code}")
|
||||
print(response.text)
|
||||
|
||||
|
||||
def main():
|
||||
"""Main function"""
|
||||
text = get_translation_from_llama3("[火焰]")
|
||||
|
||||
21
app/service/prompt_generation/util/minio_util.py
Normal file
21
app/service/prompt_generation/util/minio_util.py
Normal file
@@ -0,0 +1,21 @@
|
||||
import base64
|
||||
|
||||
from minio import Minio
|
||||
|
||||
from app.core.config import MINIO_URL, MINIO_ACCESS, MINIO_SECRET, MINIO_SECURE
|
||||
|
||||
minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||
|
||||
|
||||
def minio_url_to_base64(minio_url: str) -> str:
|
||||
bucket_name, object_name = minio_url.split("/", 1)
|
||||
|
||||
try:
|
||||
response = minio_client.get_object(bucket_name, object_name)
|
||||
image_data = response.read()
|
||||
return base64.b64encode(image_data).decode('utf-8')
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to get object: {e}")
|
||||
finally:
|
||||
if 'response' in locals():
|
||||
response.close()
|
||||
539
app/service/recommend/scheduled_task.py
Normal file
539
app/service/recommend/scheduled_task.py
Normal file
@@ -0,0 +1,539 @@
|
||||
import pymysql
|
||||
import numpy as np
|
||||
from apscheduler.schedulers.blocking import BlockingScheduler
|
||||
import os
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
import torch
|
||||
from torchvision import models, transforms
|
||||
from minio import Minio
|
||||
from PIL import Image
|
||||
import io
|
||||
import seaborn as sns
|
||||
import matplotlib.pyplot as plt
|
||||
from scipy.sparse import csr_matrix
|
||||
import matplotlib.font_manager as fm
|
||||
from scipy import sparse
|
||||
import pandas as pd
|
||||
from datetime import datetime, timedelta
|
||||
import json
|
||||
|
||||
from app.core.config import DB_CONFIG, TABLE_CATEGORIES, RECOMMEND_PATH_PREFIX
|
||||
|
||||
# 自动选择可用字体
|
||||
try:
|
||||
# 尝试加载常见中文字体
|
||||
font_path = fm.findfont(fm.FontProperties(family=['Microsoft YaHei', 'SimHei', 'WenQuanYi Zen Hei']))
|
||||
plt.rcParams['font.sans-serif'] = fm.FontProperties(fname=font_path).get_name()
|
||||
except:
|
||||
# 回退到英文字体
|
||||
plt.rcParams['font.sans-serif'] = ['DejaVu Sans']
|
||||
plt.rcParams['axes.unicode_minus'] = False
|
||||
|
||||
# 检查系统中可用的字体并选择支持中文的字体
|
||||
font_path = fm.findfont(fm.FontProperties(family='Microsoft YaHei')) # 或其他支持中文的字体
|
||||
plt.rcParams['font.sans-serif'] = ['Microsoft YaHei'] # 设置为 Microsoft YaHei
|
||||
plt.rcParams['axes.unicode_minus'] = False # 解决负号显示问题
|
||||
|
||||
# 配置日志记录
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(levelname)s - %(message)s',
|
||||
filename='scheduler.log'
|
||||
)
|
||||
|
||||
# MinIO 配置信息
|
||||
minio_client = Minio(
|
||||
"www.minio.aida.com.hk:12024", # MinIO Endpoint
|
||||
access_key="admin", # Access Key
|
||||
secret_key="Aidlab123123!", # Secret Key
|
||||
secure=True # 使用https
|
||||
)
|
||||
|
||||
# 预加载系统sketch特征向量
|
||||
SYSTEM_FEATURES = np.load(f'{RECOMMEND_PATH_PREFIX}sketch_feature_dict.npy', allow_pickle=True).item()
|
||||
|
||||
# 行为权重和衰减系数
|
||||
BEHAVIOR_CONFIG = {
|
||||
'portfolioClick': {'weight': 1, 'decay': 0.3},
|
||||
'portfolioLike': {'weight': 2, 'decay': 0.2},
|
||||
'secondCreation': {'weight': 3, 'decay': 0.1},
|
||||
'sketchLike': {'weight': 4, 'decay': 0} # 不衰减
|
||||
}
|
||||
|
||||
# 保存sketch_to_iid到文件
|
||||
def save_sketch_to_iid():
|
||||
"""保存sketch到iid的映射"""
|
||||
sketch_to_iid = {sketch_path: iid for iid, sketch_path in enumerate(SYSTEM_FEATURES.keys(), start=1)}
|
||||
np.save('sketch_to_iid.npy', sketch_to_iid)
|
||||
print("sketch_to_iid 已保存")
|
||||
|
||||
|
||||
# 从文件加载sketch_to_iid
|
||||
def load_sketch_to_iid():
|
||||
"""加载保存的sketch到iid的映射"""
|
||||
if os.path.exists('sketch_to_iid.npy'):
|
||||
sketch_to_iid = np.load('sketch_to_iid.npy', allow_pickle=True).item()
|
||||
print("sketch_to_iid 已加载")
|
||||
return sketch_to_iid
|
||||
else:
|
||||
# 如果文件不存在,则生成并保存
|
||||
print("sketch_to_iid 文件不存在,正在生成并保存...")
|
||||
save_sketch_to_iid()
|
||||
return np.load('sketch_to_iid.npy', allow_pickle=True).item()
|
||||
|
||||
|
||||
# 使用load_sketch_to_iid来获取映射
|
||||
sketch_to_iid = load_sketch_to_iid()
|
||||
|
||||
# 在代码中其他地方使用sketch_to_iid
|
||||
# print(f"Total sketches: {len(sketch_to_iid)}")
|
||||
|
||||
# 定义图像预处理(与ResNet训练时的预处理一致)
|
||||
transform = transforms.Compose([
|
||||
transforms.Resize((224, 224)), # ResNet 要求 224x224 的输入
|
||||
transforms.ToTensor(), # 转换为 Tensor
|
||||
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), # 标准化
|
||||
])
|
||||
|
||||
# 加载预训练的 ResNet 模型 (ResNet50)
|
||||
resnet_model = models.resnet50(pretrained=True)
|
||||
modules = list(resnet_model.children())[:-1] # 移除最后的全连接层
|
||||
resnet_model = torch.nn.Sequential(*modules)
|
||||
resnet_model.eval() # 设置为评估模式
|
||||
|
||||
|
||||
# 从 MinIO 获取图片并进行预处理
|
||||
def get_sketch_image_from_minio(sketch_path):
|
||||
"""
|
||||
从 MinIO 获取 sketch 图像并预处理
|
||||
"""
|
||||
# 分割路径,获取桶名和文件路径
|
||||
path_parts = sketch_path.split('/', 1) # 根据第一个斜杠分割,得到桶名和路径
|
||||
bucket_name = path_parts[0] # 桶名
|
||||
file_name = path_parts[1] # 文件路径(从第二部分开始)
|
||||
|
||||
try:
|
||||
# 获取文件
|
||||
obj = minio_client.get_object(bucket_name, file_name)
|
||||
img_data = obj.read() # 读取图像数据
|
||||
img = Image.open(io.BytesIO(img_data)) # 将数据转为图像对象
|
||||
img = transform(img) # 对图像进行预处理
|
||||
return img.unsqueeze(0) # 扩展维度以适应批量处理
|
||||
except Exception as e:
|
||||
print(f"Error fetching image for {sketch_path}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def extract_feature_vector_from_resnet(sketch_path):
|
||||
"""
|
||||
提取 sketch 图像的特征向量
|
||||
"""
|
||||
# 从 MinIO 获取图像并预处理
|
||||
img_tensor = get_sketch_image_from_minio(sketch_path)
|
||||
if img_tensor is None:
|
||||
return np.zeros(2048) # 如果获取失败,返回零向量
|
||||
|
||||
with torch.no_grad(): # 在不需要计算梯度的情况下进行推断
|
||||
feature_vector = resnet_model(img_tensor) # 获取 ResNet 的输出
|
||||
return feature_vector.squeeze().cpu().numpy() # 转换为 NumPy 数组并去掉 batch 维度
|
||||
|
||||
|
||||
def update_user_matrices():
|
||||
"""每天更新用户交互次数矩阵和特征向量矩阵"""
|
||||
conn = None
|
||||
try:
|
||||
conn = pymysql.connect(**DB_CONFIG)
|
||||
cursor = conn.cursor()
|
||||
|
||||
# 修改后的查询语句(移除category过滤)
|
||||
cursor.execute("""
|
||||
SELECT account_id, path, COUNT(*) as like_count
|
||||
FROM user_preference_log_test
|
||||
GROUP BY account_id, path
|
||||
""")
|
||||
user_data = cursor.fetchall()
|
||||
logging.info(f"成功读取{len(user_data)}条用户偏好记录")
|
||||
|
||||
# 计算矩阵
|
||||
interaction_matrix, raw_counts_sparse, user_index_interaction_matrix, sketch_index_interaction_matrix, iid_to_category_interaction_matrix = calculate_interaction_matrix(user_data)
|
||||
# visualize_sparse_matrix(raw_counts_sparse,'交互次数矩阵', 'interaction_frequency_matrix.png')
|
||||
# visualize_sparse_matrix(interaction_matrix, '交互次数得分矩阵', 'interaction_score_matrix.png')
|
||||
# plot_interaction_count_matrix(raw_counts_sparse)
|
||||
# feature_matrix, iid_to_category_feature_matrix, user_index_feature_matrix, sketch_index_feature_matrix = calculate_feature_matrix(user_data)
|
||||
feature_matrix, user_index_feature_matrix, sketch_index_feature_matrix, iid_to_category_feature_matrix = calculate_feature_matrix(user_data)
|
||||
# visualize_sparse_matrix(feature_matrix, '系统sketch与用户category平均特征向量关联度矩阵', 'correlation_matrix.png')
|
||||
# 存储矩阵
|
||||
np.save(f"{RECOMMEND_PATH_PREFIX}interaction_matrix.npy", interaction_matrix)
|
||||
np.save(f"{RECOMMEND_PATH_PREFIX}feature_matrix.npy", feature_matrix)
|
||||
#
|
||||
np.save(f"{RECOMMEND_PATH_PREFIX}iid_to_category_interaction_matrix.npy", iid_to_category_interaction_matrix)
|
||||
np.save(f"{RECOMMEND_PATH_PREFIX}user_index_interaction_matrix.npy", user_index_interaction_matrix)
|
||||
#
|
||||
np.save(f"{RECOMMEND_PATH_PREFIX}iid_to_category_feature_matrix.npy", iid_to_category_feature_matrix)
|
||||
np.save(f"{RECOMMEND_PATH_PREFIX}user_index_feature_matrix.npy", user_index_feature_matrix)
|
||||
#
|
||||
np.save(f"{RECOMMEND_PATH_PREFIX}sketch_index_interaction_matrix.npy", sketch_index_interaction_matrix)
|
||||
np.save(f"{RECOMMEND_PATH_PREFIX}sketch_index_feature_matrix.npy", sketch_index_feature_matrix)
|
||||
# logging.info("矩阵更新完成")
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"定时任务执行失败: {str(e)}", exc_info=True)
|
||||
finally:
|
||||
if conn:
|
||||
conn.close()
|
||||
|
||||
|
||||
def plot_interaction_count_matrix(interaction_count_matrix):
|
||||
"""绘制交互次数矩阵的分布图(热图),不隐藏零值"""
|
||||
try:
|
||||
if not isinstance(interaction_count_matrix, csr_matrix):
|
||||
sparse_matrix = csr_matrix(interaction_count_matrix)
|
||||
else:
|
||||
sparse_matrix = interaction_count_matrix
|
||||
|
||||
# 转换为密集矩阵
|
||||
try:
|
||||
dense_matrix = sparse_matrix.toarray()
|
||||
except MemoryError:
|
||||
logging.error("内存不足,无法转换为密集矩阵")
|
||||
return
|
||||
|
||||
# 自动检测可用中文字体
|
||||
try:
|
||||
font_path = fm.findfont(fm.FontProperties(family='sans-serif', style='normal'))
|
||||
plt.rcParams['font.sans-serif'] = fm.FontProperties(fname=font_path).get_name()
|
||||
except:
|
||||
plt.rcParams['font.sans-serif'] = ['DejaVu Sans'] # 回退到英文字体
|
||||
plt.rcParams['axes.unicode_minus'] = False
|
||||
|
||||
# 处理大矩阵的显示,限制显示范围
|
||||
if dense_matrix.shape[0] > 1000 or dense_matrix.shape[1] > 1000:
|
||||
dense_matrix = dense_matrix[:1000, :1000] # 只绘制前1000行列
|
||||
|
||||
plt.figure(figsize=(15, 10))
|
||||
|
||||
# 使用 `cmap` 来设置颜色,零值可以使用特定颜色,调整 `vmin` 和 `vmax` 让热图更具对比
|
||||
sns.heatmap(
|
||||
dense_matrix,
|
||||
cmap="Blues", # 可以选择不同的颜色映射,"Blues" 或 "YlGnBu"
|
||||
annot=False, # 关闭标注
|
||||
cbar_kws={"label": "Interaction Count"}, # 添加颜色条标签
|
||||
linewidths=0.5,
|
||||
vmin=0, # 设置最小值,确保零值明显
|
||||
vmax=np.max(dense_matrix) # 设置最大值,保持颜色映射的合理性
|
||||
)
|
||||
|
||||
plt.title('User-Sketch Interaction Matrix (With Zero Entries)')
|
||||
plt.xlabel('Sketch Index')
|
||||
plt.ylabel('User Index')
|
||||
plt.savefig('interaction_heatmap_with_zeros.png', dpi=150, bbox_inches='tight')
|
||||
plt.close()
|
||||
|
||||
logging.info("热图已保存为 interaction_heatmap_with_zeros.png")
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"绘图失败: {str(e)}", exc_info=True)
|
||||
|
||||
def visualize_sparse_matrix(matrix, title='Non-zero Interactions (Scatter Plot)', filename="scatter_figure_interaction.png"):
|
||||
if not sparse.issparse(matrix):
|
||||
# 转换为稀疏矩阵
|
||||
matrix = sparse.csr_matrix(matrix)
|
||||
|
||||
# 获取非零元素的坐标和值
|
||||
rows, cols = matrix.nonzero()
|
||||
values = matrix.data
|
||||
|
||||
# 绘制散点图
|
||||
plt.figure(figsize=(24, 20))
|
||||
plt.scatter(cols, rows, c=values, cmap='coolwarm', alpha=0.7, s=1)
|
||||
plt.colorbar(label='Interaction Count')
|
||||
plt.title(title)
|
||||
plt.xlabel('Item Index')
|
||||
plt.ylabel('Item Index')
|
||||
plt.savefig(filename)
|
||||
|
||||
def calculate_interaction_matrix(user_data):
|
||||
"""基于新表结构的交互次数矩阵计算(仅系统sketch)"""
|
||||
# 获取所有用户ID
|
||||
all_users = set()
|
||||
for account_id, path, like_count in user_data:
|
||||
category = get_category_from_path(path)
|
||||
if category not in TABLE_CATEGORIES.keys():
|
||||
continue
|
||||
all_users.add(account_id)
|
||||
|
||||
# 获取所有系统sketch的iid
|
||||
system_sketch_iids = {sketch_to_iid[path] for path in SYSTEM_FEATURES.keys() if path in sketch_to_iid}
|
||||
|
||||
# 创建映射关系
|
||||
user_index = {uid: idx for idx, uid in enumerate(sorted(all_users))}
|
||||
sketch_index = {iid: idx for idx, iid in enumerate(sorted(system_sketch_iids))}
|
||||
|
||||
# 初始化双矩阵:归一化矩阵(密集)和原始计数矩阵(稀疏)
|
||||
interaction_matrix = np.zeros((len(all_users), len(sketch_index))) # 归一化矩阵
|
||||
data, rows, cols = [], [], [] # 用于构建稀疏矩阵的COO格式数据
|
||||
|
||||
# 预计算用户最大交互次数
|
||||
user_max_likes = defaultdict(int)
|
||||
for account_id, path, like_count in user_data:
|
||||
if sketch_to_iid.get(path) in system_sketch_iids:
|
||||
user_max_likes[account_id] = max(user_max_likes[account_id], like_count)
|
||||
|
||||
# 填充矩阵
|
||||
for account_id, path, like_count in user_data:
|
||||
sketch_iid = sketch_to_iid.get(path)
|
||||
if sketch_iid not in system_sketch_iids:
|
||||
continue
|
||||
|
||||
user_idx = user_index[account_id]
|
||||
sketch_idx = sketch_index[sketch_iid]
|
||||
|
||||
# 填充稀疏矩阵数据
|
||||
data.append(like_count)
|
||||
rows.append(user_idx)
|
||||
cols.append(sketch_idx)
|
||||
|
||||
# 归一化计算
|
||||
max_like = user_max_likes.get(account_id, 1)
|
||||
interaction_matrix[user_idx, sketch_idx] = np.log1p(1 + like_count) / np.log1p(1 + max_like)
|
||||
|
||||
# 构建稀疏矩阵(CSR格式适合快速行操作)
|
||||
interaction_count_matrix = csr_matrix((data, (rows, cols)), shape=(len(all_users), len(sketch_index)))
|
||||
|
||||
return interaction_matrix, interaction_count_matrix, user_index, sketch_index, {iid: get_category_from_path(path) for path, iid in sketch_to_iid.items()}
|
||||
|
||||
|
||||
def calculate_feature_matrix(user_data):
|
||||
"""基于新表结构的特征矩阵计算,返回用户与系统草图的相似度矩阵(加权平均)"""
|
||||
|
||||
# 用户特征数据存储结构:{(account_id, category): {sketch_iid: [(feature_vector, weight)]}}
|
||||
user_feature_weights = defaultdict(lambda: defaultdict(list))
|
||||
|
||||
# 初始化所有用户和系统草图集合
|
||||
all_users = set()
|
||||
all_system_sketches = set(SYSTEM_FEATURES.keys())
|
||||
|
||||
# ==== 第一遍遍历:收集特征向量和权重 ====
|
||||
for account_id, path, like_count in user_data:
|
||||
category = get_category_from_path(path)
|
||||
if category not in TABLE_CATEGORIES.keys():
|
||||
continue
|
||||
|
||||
sketch_iid = sketch_to_iid.get(path)
|
||||
if not sketch_iid:
|
||||
continue
|
||||
|
||||
# 记录用户
|
||||
all_users.add(account_id)
|
||||
|
||||
# 提取特征并记录权重(like_count)
|
||||
if path in SYSTEM_FEATURES: # 系统草图
|
||||
feature = SYSTEM_FEATURES[path]
|
||||
weight = like_count # 使用like_count作为权重
|
||||
user_feature_weights[(account_id, category)][sketch_iid].append((feature, weight))
|
||||
else: # 用户草图
|
||||
feature = extract_feature_vector_from_resnet(path)
|
||||
weight = like_count
|
||||
user_feature_weights[(account_id, category)][sketch_iid].append((feature, weight))
|
||||
|
||||
# ==== 第二遍遍历:收集所有系统草图iid ====
|
||||
system_sketch_iids = set()
|
||||
for sketch_path in SYSTEM_FEATURES:
|
||||
if iid := sketch_to_iid.get(sketch_path):
|
||||
system_sketch_iids.add(iid)
|
||||
|
||||
# ==== 创建索引映射 ====
|
||||
user_list = sorted(all_users)
|
||||
sketch_list = sorted(system_sketch_iids)
|
||||
|
||||
user_index = {uid: idx for idx, uid in enumerate(user_list)}
|
||||
sketch_index = {iid: idx for idx, iid in enumerate(sketch_list)}
|
||||
|
||||
# ==== 初始化特征矩阵 ====
|
||||
feature_matrix = np.zeros((len(user_list), len(sketch_list)))
|
||||
|
||||
# ==== 预计算加权平均特征向量 ====
|
||||
user_avg_features = {}
|
||||
for (account_id, category), sketches in user_feature_weights.items():
|
||||
try:
|
||||
# 展平所有特征向量和权重
|
||||
all_features_weights = [(vec, weight) for vec_list in sketches.values() for vec, weight in vec_list]
|
||||
|
||||
if len(all_features_weights) == 0:
|
||||
continue
|
||||
|
||||
# 计算总权重
|
||||
total_weight = sum(weight for _, weight in all_features_weights)
|
||||
if total_weight <= 0: # 防止除零错误
|
||||
total_weight = 1.0
|
||||
|
||||
# 加权平均计算
|
||||
weighted_sum = np.zeros_like(all_features_weights[0][0]) # 获取特征向量维度
|
||||
for vec, weight in all_features_weights:
|
||||
weighted_sum += vec * weight
|
||||
|
||||
avg_vec = weighted_sum / total_weight
|
||||
user_avg_features[(account_id, category)] = avg_vec
|
||||
|
||||
except Exception as e:
|
||||
logging.warning(f"用户({account_id},{category})加权特征计算失败: {str(e)}")
|
||||
continue
|
||||
|
||||
# ==== 计算相似度并填充矩阵 ====
|
||||
for sketch_path, sys_vector in SYSTEM_FEATURES.items():
|
||||
sketch_iid = sketch_to_iid.get(sketch_path)
|
||||
|
||||
system_sketch_category = get_category_from_path(sketch_path)
|
||||
if not sketch_iid or sketch_iid not in sketch_index:
|
||||
continue
|
||||
|
||||
sketch_col = sketch_index[sketch_iid]
|
||||
|
||||
# 遍历所有用户
|
||||
for account_id in all_users:
|
||||
user_row = user_index.get(account_id)
|
||||
if user_row is None:
|
||||
continue
|
||||
|
||||
# 获取用户加权平均特征向量
|
||||
try:
|
||||
# 直接通过复合键获取用户特征向量
|
||||
user_vec = user_avg_features[(account_id, system_sketch_category)]
|
||||
except KeyError:
|
||||
# 该用户在此类别下无特征数据
|
||||
continue
|
||||
|
||||
# 计算余弦相似度
|
||||
cos_sim = cosine_similarity(user_vec, sys_vector)
|
||||
feature_matrix[user_row, sketch_col] = cos_sim
|
||||
|
||||
return feature_matrix, user_index, sketch_index, {iid: get_category_from_path(path) for path, iid in sketch_to_iid.items()}
|
||||
|
||||
|
||||
def get_category_from_path(path):
|
||||
"""从path字段解析类别"""
|
||||
try:
|
||||
parts = path.split('/')
|
||||
if len(parts) >= 2:
|
||||
return f"{parts[2]}_{parts[3]}"
|
||||
return "unknown"
|
||||
except:
|
||||
return "unknown"
|
||||
|
||||
|
||||
def cosine_similarity(vec1, vec2):
|
||||
"""计算余弦相似度(增加零值处理)"""
|
||||
norm = np.linalg.norm(vec1) * np.linalg.norm(vec2)
|
||||
return np.dot(vec1, vec2) / (norm + 1e-10) if norm != 0 else 0.0
|
||||
|
||||
|
||||
def fetch_user_behavior_data(days=30):
|
||||
"""从MySQL获取用户行为数据(整合旧查询和新需求)"""
|
||||
conn = None
|
||||
try:
|
||||
conn = pymysql.connect(**DB_CONFIG)
|
||||
|
||||
# 计算日期范围
|
||||
end_date = datetime.now()
|
||||
start_date = end_date - timedelta(days=days)
|
||||
|
||||
# 整合查询(获取完整行为数据)
|
||||
query = f"""
|
||||
SELECT
|
||||
account_id,
|
||||
behavior_type,
|
||||
gender,
|
||||
category,
|
||||
url,
|
||||
create_time
|
||||
FROM user_behavior
|
||||
WHERE create_time BETWEEN '{start_date}' AND '{end_date}'
|
||||
"""
|
||||
|
||||
df = pd.read_sql(query, conn)
|
||||
logging.info(f"成功读取{len(df)}条用户行为记录")
|
||||
return df
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"数据库查询失败: {str(e)}")
|
||||
return pd.DataFrame()
|
||||
|
||||
finally:
|
||||
if conn:
|
||||
conn.close()
|
||||
|
||||
|
||||
def calculate_heat(row, current_date):
|
||||
"""计算单个行为的热度值(每次行为独立计算,不考虑聚合次数)"""
|
||||
# 计算时间差(天)
|
||||
days_passed = (current_date - row['create_time']).days
|
||||
|
||||
# 获取行为配置(默认权重为0)
|
||||
config = BEHAVIOR_CONFIG.get(row['behavior_type'], {'weight': 0, 'decay': 0})
|
||||
|
||||
# 计算热度值 = 权重 * e^(-衰减系数 * 天数)
|
||||
return config['weight'] * np.exp(-config['decay'] * days_passed)
|
||||
|
||||
def load_heat_matrix_as_array(file_path):
|
||||
"""
|
||||
直接加载为二维numpy数组
|
||||
返回: (data_array, row_labels, col_labels)
|
||||
"""
|
||||
with open(file_path) as f:
|
||||
saved = json.load(f)
|
||||
return (
|
||||
np.array(saved['data']), # 二维矩阵
|
||||
saved['row_labels'], # 行标签列表
|
||||
saved['col_labels'] # 列标签列表
|
||||
)
|
||||
|
||||
def update_heat_matrices():
|
||||
"""每日计算并存储热度矩阵(gender_category × path)"""
|
||||
current_date = datetime.now()
|
||||
|
||||
# 获取数据
|
||||
df = fetch_user_behavior_data(30)
|
||||
if df.empty:
|
||||
logging.warning("无有效数据,跳过今日计算")
|
||||
return None
|
||||
|
||||
# 计算热度值
|
||||
df['heat'] = df.apply(calculate_heat, axis=1, current_date=current_date)
|
||||
df['gender_category'] = df['gender'] + '_' + df['category']
|
||||
|
||||
# 构建热度向量
|
||||
heat_vectors = {}
|
||||
grouped = df.groupby(['gender_category', 'url'])['heat'].sum()
|
||||
for (gender_category, url), heat in grouped.items():
|
||||
heat_vectors.setdefault(gender_category, {})[url] = heat
|
||||
|
||||
# 存储结果
|
||||
save_path = 'heat_vectors_data'
|
||||
os.makedirs(save_path, exist_ok=True)
|
||||
date_str = current_date.strftime('%Y%m%d')
|
||||
|
||||
# vectors_file = f"{save_path}/heat_vectors_{date_str}.json"
|
||||
vectors_file = f"{save_path}/heat_vectors.json"
|
||||
with open(vectors_file, 'w', encoding='utf-8') as f:
|
||||
json.dump({
|
||||
'update_time': current_date.strftime('%Y-%m-%d %H:%M:%S'),
|
||||
'data': heat_vectors
|
||||
}, f, ensure_ascii=False, indent=2)
|
||||
|
||||
logging.info(f"成功存储热度向量,共{len(heat_vectors)}个分组,日期: {date_str}")
|
||||
return heat_vectors
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
# update_user_matrices()
|
||||
# update_heat_matrices()
|
||||
scheduler = BlockingScheduler()
|
||||
scheduler.add_job(update_user_matrices, 'cron', hour=12, timezone='Asia/Shanghai')
|
||||
logging.info("定时任务已启动,每天12:00执行")
|
||||
scheduler.start()
|
||||
except KeyboardInterrupt:
|
||||
logging.info("定时任务已停止")
|
||||
except Exception as e:
|
||||
logging.error(f"调度器启动失败: {str(e)}", exc_info=True)
|
||||
240
app/service/recommend/service.py
Normal file
240
app/service/recommend/service.py
Normal file
@@ -0,0 +1,240 @@
|
||||
# 预加载资源
|
||||
import logging
|
||||
import time
|
||||
from collections import defaultdict
|
||||
import os
|
||||
import json
|
||||
import numpy as np
|
||||
|
||||
from app.core.config import DB_CONFIG, RECOMMEND_PATH_PREFIX
|
||||
|
||||
logger = logging.getLogger()
|
||||
import pymysql
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
HEAT_VECTOR_FILE = 'heat_vectors_data/heat_vectors.json' # 可动态加载或配置
|
||||
|
||||
matrix_data = {
|
||||
"interaction_matrix": None,
|
||||
"feature_matrix": None,
|
||||
"user_index_interaction": None,
|
||||
"sketch_index_interaction": None,
|
||||
"user_index_feature": None,
|
||||
"sketch_index_feature": None,
|
||||
"iid_to_sketch": None,
|
||||
"category_to_iids": None,
|
||||
"cached_scores": {},
|
||||
"cached_valid_idxs": {},
|
||||
"category_sketch_idxs_inter": None,
|
||||
"category_sketch_idxs_feature": None,
|
||||
"user_inter_full": dict(),
|
||||
"user_feat_full": dict(),
|
||||
"brand_feature_matrix": None,
|
||||
"brand_index_map": None,
|
||||
"heat_data": {},
|
||||
}
|
||||
|
||||
|
||||
def load_resources():
|
||||
"""加载所有矩阵和映射关系,并触发预缓存"""
|
||||
try:
|
||||
start_time = time.time()
|
||||
|
||||
# 清空缓存
|
||||
matrix_data["cached_scores"].clear()
|
||||
matrix_data["cached_valid_idxs"].clear()
|
||||
|
||||
# 加载数据
|
||||
sketch_to_iid = np.load(f'{RECOMMEND_PATH_PREFIX}sketch_to_iid.npy', allow_pickle=True).item()
|
||||
matrix_data["iid_to_sketch"] = {v: k for k, v in sketch_to_iid.items()}
|
||||
|
||||
matrix_data["interaction_matrix"] = np.load(f"{RECOMMEND_PATH_PREFIX}interaction_matrix.npy", allow_pickle=True)
|
||||
matrix_data["user_index_interaction"] = np.load(f"{RECOMMEND_PATH_PREFIX}user_index_interaction_matrix.npy", allow_pickle=True).item()
|
||||
matrix_data["sketch_index_interaction"] = np.load(f"{RECOMMEND_PATH_PREFIX}sketch_index_interaction_matrix.npy",
|
||||
allow_pickle=True).item()
|
||||
|
||||
matrix_data["feature_matrix"] = np.load(f"{RECOMMEND_PATH_PREFIX}feature_matrix.npy", allow_pickle=True)
|
||||
|
||||
brand_feature_path = f"{RECOMMEND_PATH_PREFIX}brand_feature_matrix.npy"
|
||||
if os.path.exists(brand_feature_path):
|
||||
matrix_data["brand_feature_matrix"] = np.load(brand_feature_path, allow_pickle=True)
|
||||
else:
|
||||
logger.warning("brand_feature_matrix 文件不存在,使用空数组")
|
||||
matrix_data["brand_feature_matrix"] = np.array([])
|
||||
|
||||
# brand_index_map
|
||||
brand_index_path = f"{RECOMMEND_PATH_PREFIX}brand_index_map.npy"
|
||||
if os.path.exists(brand_index_path):
|
||||
matrix_data["brand_index_map"] = np.load(brand_index_path, allow_pickle=True).item()
|
||||
else:
|
||||
logger.warning("brand_index_map 文件不存在,使用空字典")
|
||||
matrix_data["brand_index_map"] = {}
|
||||
|
||||
matrix_data["user_index_feature"] = np.load(f"{RECOMMEND_PATH_PREFIX}user_index_feature_matrix.npy", allow_pickle=True).item()
|
||||
|
||||
matrix_data["sketch_index_feature"] = np.load(f"{RECOMMEND_PATH_PREFIX}sketch_index_feature_matrix.npy", allow_pickle=True).item()
|
||||
|
||||
category_to_iid_map = np.load(f"{RECOMMEND_PATH_PREFIX}iid_to_category_interaction_matrix.npy", allow_pickle=True).item()
|
||||
matrix_data["category_to_iids"] = defaultdict(list)
|
||||
for iid, cat in category_to_iid_map.items():
|
||||
matrix_data["category_to_iids"][cat].append(iid)
|
||||
|
||||
logger.info(f"资源加载完成,耗时: {time.time() - start_time:.2f}秒")
|
||||
|
||||
# 触发预缓存
|
||||
precache_user_category()
|
||||
|
||||
if os.path.exists(HEAT_VECTOR_FILE):
|
||||
with open(HEAT_VECTOR_FILE, 'r', encoding='utf-8') as f:
|
||||
heat_json = json.load(f)
|
||||
matrix_data["heat_data"] = heat_json.get("data", {})
|
||||
logger.info(f"热度向量数据加载完成,共加载 {len(matrix_data['heat_data'])} 个类别")
|
||||
else:
|
||||
matrix_data["heat_data"] = {}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"资源加载失败: {str(e)}")
|
||||
raise RuntimeError("初始化失败")
|
||||
|
||||
|
||||
def precache_user_category():
|
||||
"""优化后的用户分类预缓存(添加耗时统计)"""
|
||||
if not all([
|
||||
matrix_data["interaction_matrix"] is not None,
|
||||
matrix_data["feature_matrix"] is not None,
|
||||
matrix_data["user_index_interaction"] is not None
|
||||
]):
|
||||
logger.warning("资源未加载完成,跳过预缓存")
|
||||
return
|
||||
|
||||
start_time = time.perf_counter()
|
||||
time_stats = {
|
||||
"get_all_user_categories": 0,
|
||||
"process_user_category": 0,
|
||||
"thread_execution": 0,
|
||||
"cache_update": 0,
|
||||
"total": 0,
|
||||
}
|
||||
|
||||
# 统计用户类别获取时间
|
||||
t1 = time.perf_counter()
|
||||
user_categories = get_all_user_categories()
|
||||
time_stats["get_all_user_categories"] = time.perf_counter() - t1
|
||||
|
||||
precached_count = 0
|
||||
|
||||
def process_user_category(user_id, categories):
|
||||
"""单用户类别缓存计算(统计耗时)"""
|
||||
local_cache = {}
|
||||
local_valid_idxs = {}
|
||||
t_start = time.perf_counter()
|
||||
|
||||
for category in categories:
|
||||
cache_key = (user_id, category)
|
||||
if cache_key in matrix_data["cached_scores"]:
|
||||
continue
|
||||
|
||||
try:
|
||||
user_idx_inter = matrix_data["user_index_interaction"].get(user_id)
|
||||
user_idx_feature = matrix_data["user_index_feature"].get(user_id)
|
||||
|
||||
# 统计获取类别 IID 耗时
|
||||
t_iid = time.perf_counter()
|
||||
category_iids = matrix_data["category_to_iids"].get(category, [])
|
||||
valid_sketch_idxs_inter = [matrix_data["sketch_index_interaction"][iid]
|
||||
for iid in category_iids if iid in matrix_data["sketch_index_interaction"]]
|
||||
valid_sketch_idxs_feature = [matrix_data["sketch_index_feature"][iid]
|
||||
for iid in category_iids if iid in matrix_data["sketch_index_feature"]]
|
||||
time_stats["process_user_category"] += time.perf_counter() - t_iid
|
||||
|
||||
# 统计矩阵计算耗时
|
||||
t_matrix = time.perf_counter()
|
||||
processed_inter = np.zeros(len(valid_sketch_idxs_inter))
|
||||
if user_idx_inter is not None and valid_sketch_idxs_inter:
|
||||
raw_inter_scores = matrix_data["interaction_matrix"][user_idx_inter, valid_sketch_idxs_inter]
|
||||
processed_inter = raw_inter_scores * 0.7
|
||||
|
||||
processed_feat = np.zeros(len(valid_sketch_idxs_feature))
|
||||
if user_idx_feature is not None and valid_sketch_idxs_feature:
|
||||
raw_feat_scores = matrix_data["feature_matrix"][user_idx_feature, valid_sketch_idxs_feature]
|
||||
raw_feat_scores = (raw_feat_scores - np.min(raw_feat_scores)) / (
|
||||
np.max(raw_feat_scores) - np.min(raw_feat_scores) + 1e-8)
|
||||
processed_feat = raw_feat_scores * 0.3
|
||||
time_stats["process_user_category"] += time.perf_counter() - t_matrix
|
||||
|
||||
if len(processed_inter) == len(processed_feat):
|
||||
local_cache[cache_key] = (processed_inter, processed_feat)
|
||||
local_valid_idxs[cache_key] = valid_sketch_idxs_inter
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"预缓存失败 (user={user_id}, category={category}): {str(e)}")
|
||||
|
||||
return local_cache, local_valid_idxs
|
||||
|
||||
# 统计线程执行时间
|
||||
t2 = time.perf_counter()
|
||||
with ThreadPoolExecutor(max_workers=8) as executor:
|
||||
futures = {executor.submit(process_user_category, user_id, categories): user_id for user_id, categories in user_categories.items()}
|
||||
for future in futures:
|
||||
try:
|
||||
t_cache = time.perf_counter()
|
||||
cache_part, valid_idxs_part = future.result()
|
||||
matrix_data["cached_scores"].update(cache_part)
|
||||
matrix_data["cached_valid_idxs"].update(valid_idxs_part)
|
||||
time_stats["cache_update"] += time.perf_counter() - t_cache
|
||||
precached_count += len(cache_part)
|
||||
except Exception as e:
|
||||
logger.error(f"线程执行错误: {str(e)}")
|
||||
time_stats["thread_execution"] = time.perf_counter() - t2
|
||||
|
||||
time_stats["total"] = time.perf_counter() - start_time
|
||||
|
||||
# 输出统计信息
|
||||
logger.info(f"""
|
||||
预缓存完成,共缓存 {precached_count} 组数据,耗时统计如下:
|
||||
- 获取用户类别数据: {time_stats["get_all_user_categories"]:.2f}s
|
||||
- 计算用户类别缓存: {time_stats["process_user_category"]:.2f}s
|
||||
- 线程任务执行: {time_stats["thread_execution"]:.2f}s
|
||||
- 更新缓存数据: {time_stats["cache_update"]:.2f}s
|
||||
- 总耗时: {time_stats["total"]:.2f}s
|
||||
""")
|
||||
|
||||
|
||||
def get_all_user_categories():
|
||||
"""获取所有用户及其对应的分类"""
|
||||
conn = None
|
||||
try:
|
||||
conn = pymysql.connect(**DB_CONFIG)
|
||||
cursor = conn.cursor()
|
||||
|
||||
query = """
|
||||
SELECT DISTINCT account_id, path
|
||||
FROM user_preference_log_prediction
|
||||
"""
|
||||
cursor.execute(query)
|
||||
results = cursor.fetchall()
|
||||
|
||||
user_categories = defaultdict(set)
|
||||
for account_id, path in results:
|
||||
category = get_category_from_path(path)
|
||||
user_categories[account_id].add(category)
|
||||
|
||||
return dict(user_categories)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"数据库查询失败: {str(e)}")
|
||||
return {}
|
||||
finally:
|
||||
if conn:
|
||||
conn.close()
|
||||
|
||||
|
||||
def get_category_from_path(path: str) -> str:
|
||||
"""从路径解析类别"""
|
||||
try:
|
||||
parts = path.split('/')
|
||||
if len(parts) >= 4:
|
||||
return f"{parts[2]}_{parts[3]}"
|
||||
return "unknown"
|
||||
except:
|
||||
return "unknown"
|
||||
@@ -6,7 +6,7 @@ from chromadb.config import Settings
|
||||
from chromadb.utils.embedding_functions.ollama_embedding_function import OllamaEmbeddingFunction
|
||||
from tqdm import tqdm
|
||||
|
||||
from app.core.config import OLLAMA_URL
|
||||
from app.core.config import OLLAMA_URL, CHROMADB_PATH
|
||||
|
||||
# 读取 csv 文件
|
||||
# csv_file_path = r'D:/Files/csv/output/output.csv'
|
||||
@@ -15,7 +15,7 @@ from app.core.config import OLLAMA_URL
|
||||
# df = pd.read_csv(csv_file_path, encoding='Windows-1252')
|
||||
|
||||
# 创建 Chroma 客户端
|
||||
client = chromadb.Client(Settings(is_persistent=True, persist_directory="/vector_db"))
|
||||
client = chromadb.Client(Settings(is_persistent=True, persist_directory=CHROMADB_PATH))
|
||||
# client = chromadb.Client(Settings(is_persistent=True, persist_directory="./service/search_image_with_text/vector_db"))
|
||||
# client = chromadb.Client(Settings(is_persistent=True, persist_directory="D:/workspace/AiDLab/vector_db"))
|
||||
# 创建集合
|
||||
|
||||
99
app/service/utils/redis_utils.py
Normal file
99
app/service/utils/redis_utils.py
Normal file
@@ -0,0 +1,99 @@
|
||||
import redis
|
||||
|
||||
from app.core.config import REDIS_HOST, REDIS_PORT
|
||||
|
||||
|
||||
class Redis(object):
|
||||
"""
|
||||
redis数据库操作
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _get_r():
|
||||
host = REDIS_HOST
|
||||
port = REDIS_PORT
|
||||
db = 0
|
||||
r = redis.StrictRedis(host, port, db)
|
||||
return r
|
||||
|
||||
@classmethod
|
||||
def write(cls, key, value, expire=None):
|
||||
"""
|
||||
写入键值对
|
||||
"""
|
||||
# 判断是否有过期时间,没有就设置默认值
|
||||
if expire:
|
||||
expire_in_seconds = expire
|
||||
else:
|
||||
expire_in_seconds = 100
|
||||
r = cls._get_r()
|
||||
r.set(key, value, ex=expire_in_seconds)
|
||||
|
||||
@classmethod
|
||||
def read(cls, key):
|
||||
"""
|
||||
读取键值对内容
|
||||
"""
|
||||
r = cls._get_r()
|
||||
value = r.get(key)
|
||||
return value.decode('utf-8') if value else value
|
||||
|
||||
@classmethod
|
||||
def hset(cls, name, key, value):
|
||||
"""
|
||||
写入hash表
|
||||
"""
|
||||
r = cls._get_r()
|
||||
r.hset(name, key, value)
|
||||
|
||||
@classmethod
|
||||
def hget(cls, name, key):
|
||||
"""
|
||||
读取指定hash表的键值
|
||||
"""
|
||||
r = cls._get_r()
|
||||
value = r.hget(name, key)
|
||||
return value.decode('utf-8') if value else value
|
||||
|
||||
@classmethod
|
||||
def hgetall(cls, name):
|
||||
"""
|
||||
获取指定hash表所有的值
|
||||
"""
|
||||
r = cls._get_r()
|
||||
return r.hgetall(name)
|
||||
|
||||
@classmethod
|
||||
def delete(cls, *names):
|
||||
"""
|
||||
删除一个或者多个
|
||||
"""
|
||||
r = cls._get_r()
|
||||
r.delete(*names)
|
||||
|
||||
@classmethod
|
||||
def hdel(cls, name, key):
|
||||
"""
|
||||
删除指定hash表的键值
|
||||
"""
|
||||
r = cls._get_r()
|
||||
r.hdel(name, key)
|
||||
|
||||
@classmethod
|
||||
def expire(cls, name, expire=None):
|
||||
"""
|
||||
设置过期时间
|
||||
"""
|
||||
if expire:
|
||||
expire_in_seconds = expire
|
||||
else:
|
||||
expire_in_seconds = 100
|
||||
r = cls._get_r()
|
||||
r.expire(name, expire_in_seconds)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
redis_client = Redis()
|
||||
# print(redis_client.write(key="1230", value=0))
|
||||
redis_client.write(key="1230", value=10)
|
||||
# print(redis_client.read(key="1230"))
|
||||
18
pyproject.toml
Executable file
18
pyproject.toml
Executable file
@@ -0,0 +1,18 @@
|
||||
[project]
|
||||
name = "trinity-client-aida"
|
||||
version = "0.1.0"
|
||||
description = "Add your description here"
|
||||
requires-python = ">=3.12"
|
||||
dependencies = [
|
||||
"apscheduler>=3.11.0",
|
||||
"celery>=5.5.3",
|
||||
"geventhttpclient>=2.3.4",
|
||||
"google-search-results>=2.4.2",
|
||||
"moviepy>=2.2.1",
|
||||
"numpy==1.26.4",
|
||||
"pandas-stubs==2.2.3.250527",
|
||||
"pika-stubs==0.1.3",
|
||||
"python-multipart>=0.0.20",
|
||||
"tritonclient[all]>=2.58.0",
|
||||
"types-urllib3==1.26.25.14",
|
||||
]
|
||||
BIN
requirements.txt
BIN
requirements.txt
Binary file not shown.
Reference in New Issue
Block a user