Merge remote-tracking branch 'refs/remotes/origin/develop'

# Conflicts:
#	app/core/config.py
#	app/service/chat_robot/script/service/CallQWen.py
This commit is contained in:
zchen
2025-09-02 20:08:54 +08:00
74 changed files with 5502 additions and 310 deletions

7
.gitignore vendored
View File

@@ -142,3 +142,10 @@ app/logs/*
*.npy
*.pytorch
*.jpg
*.mp4
*.sqlite3
*.bin
*.pickle
*.csv
*.avi
*.json

View File

@@ -3,16 +3,17 @@ import logging
from fastapi import APIRouter, HTTPException
from app.schemas.brand_dna import BrandDnaModel
from app.schemas.brand_dna import BrandDnaModel, GenerateBrandModel
from app.schemas.response_template import ResponseModel
from app.service.brand_dna.service import BrandDna
from app.service.brand_dna.service_generate_brand_info import GenerateBrandInfo
router = APIRouter()
logger = logging.getLogger()
@router.post("/seg_product")
def image2sketch(request_item: BrandDnaModel):
def seg_product(request_item: BrandDnaModel):
"""
创建一个具有以下参数的请求体:
- **image_url**: 提取图片url
@@ -20,8 +21,8 @@ def image2sketch(request_item: BrandDnaModel):
示例参数:
{
"image_url": "test/image2sketch/real_Dress_3200fecdc83d0c556c2bd96aedbd7fbf.jpg_Img.jpg",
"is_brand_dna": False
"image_url": "aida-results/result_00006a48-e315-11ee-b7c8-b48351119060.png",
"is_brand_dna": false
}
"""
try:
@@ -32,3 +33,27 @@ def image2sketch(request_item: BrandDnaModel):
logger.warning(f"brand dna Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))
return ResponseModel(data=result_url)
@router.post("/GenerateBrand")
def GenerateBrand(request_data: GenerateBrandModel):
"""
通过prompt 生成 brand name brand slogan , brand logo。
创建一个具有以下参数的请求体:
- **prompt**:
示例参数:
{
"prompt": "xiaomi",
"user_id": "89"
}
"""
try:
logger.info(f"GenerateBrand request item is : @@@@@@:{request_data}")
service = GenerateBrandInfo(request_data)
data = service.get_result()
logger.info(f"GenerateBrand response @@@@@@:{data}")
except Exception as e:
logger.warning(f"GenerateBrand Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))
return ResponseModel(data=data)

View File

@@ -0,0 +1,212 @@
import io
import logging
import sys
import time
from typing import List
from collections import defaultdict
import numpy as np
from apscheduler.schedulers.background import BackgroundScheduler
from apscheduler.triggers.cron import CronTrigger
from fastapi import HTTPException, APIRouter
from app.service.recommend.service import load_resources, matrix_data
import pymysql
from app.core.config import DB_CONFIG, TABLE_CATEGORIES, RECOMMEND_PATH_PREFIX
from minio import Minio
import torch
from torchvision import models, transforms
from PIL import Image
import os
from fastapi.responses import JSONResponse
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')
logger = logging.getLogger()
router = APIRouter()
# MinIO client configuration
# NOTE(review): endpoint and credentials are hard-coded in source control —
# move them to environment variables / app.core.config and rotate the
# exposed secret key.
minio_client = Minio(
    "www.minio.aida.com.hk:12024",
    access_key="admin",
    secret_key="Aidlab123123!",
    secure=True
)
# Preprocessing pipeline for the ImageNet-pretrained backbone:
# 224x224 resize, tensor conversion, ImageNet mean/std normalization
# (expects a 3-channel RGB input).
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
# ResNet50 with the final fully-connected layer stripped, leaving a
# 2048-d pooled feature extractor; eval() freezes dropout/batch-norm.
resnet_model = models.resnet50(pretrained=True)
resnet_model = torch.nn.Sequential(*list(resnet_model.children())[:-1])
resnet_model.eval()
def get_sketch_image_from_minio(sketch_path: str):
    """Fetch an image from MinIO and return it as a normalized tensor.

    Args:
        sketch_path: "<bucket>/<object>" style path.

    Returns:
        A [1, 3, 224, 224] float tensor ready for the ResNet backbone,
        or None when the path is malformed or the fetch/decode fails.
    """
    path_parts = sketch_path.split('/', 1)
    if len(path_parts) != 2:
        return None
    bucket_name, file_name = path_parts
    obj = None
    try:
        obj = minio_client.get_object(bucket_name, file_name)
        # Force 3-channel RGB: grayscale/RGBA sources would otherwise break
        # the 3-channel Normalize step in `transform`.
        img = Image.open(io.BytesIO(obj.read())).convert("RGB")
        return transform(img).unsqueeze(0)
    except Exception as e:
        logger.warning(f"Fetch image failed [{sketch_path}]: {e}")
        return None
    finally:
        # minio-py contract: close the response and return the HTTP
        # connection to the pool, otherwise connections leak.
        if obj is not None:
            obj.close()
            obj.release_conn()
def extract_feature_vector_from_resnet(sketch_path: str) -> np.ndarray:
    """Embed the image at *sketch_path* with the headless ResNet50.

    Returns a 2048-d feature vector; an all-zero vector when the image
    cannot be fetched or decoded.
    """
    tensor = get_sketch_image_from_minio(sketch_path)
    if tensor is None:
        # Fetch/decode failed — fall back to a zero embedding.
        return np.zeros(2048, dtype=np.float32)
    with torch.no_grad():
        features = resnet_model(tensor)  # shape [1, 2048, 1, 1]
    return features.squeeze().cpu().numpy()
# 预加载
BRAND_FEATURES = np.load(f'{RECOMMEND_PATH_PREFIX}brand_feature.npy', allow_pickle=True).item()
SYSTEM_FEATURES = np.load(f'{RECOMMEND_PATH_PREFIX}sketch_feature_dict.npy', allow_pickle=True).item()
def save_sketch_to_iid():
    """Persist the sketch-path -> item-id map derived from SYSTEM_FEATURES.

    Ids are assigned in dictionary iteration order, starting at 1.
    """
    mapping = {}
    for item_id, path in enumerate(SYSTEM_FEATURES, start=1):
        mapping[path] = item_id
    np.save(f"{RECOMMEND_PATH_PREFIX}sketch_to_iid.npy", mapping)
def load_sketch_to_iid():
    """Load the sketch->iid map from disk, building it first when absent."""
    path = f"{RECOMMEND_PATH_PREFIX}sketch_to_iid.npy"
    if not os.path.exists(path):
        # First run: materialize the mapping before loading it back.
        save_sketch_to_iid()
    return np.load(path, allow_pickle=True).item()
sketch_to_iid = load_sketch_to_iid()
def getNewCategory(gender: str, sketch_category: str) -> str:
    """Join gender and category into the canonical "<gender>_<category>" key."""
    return "_".join(part.lower() for part in (gender, sketch_category))
def get_category_from_path(path: str) -> str:
    """Derive "<gender>_<category>" from a sketch path.

    Expects at least four '/'-separated segments, with gender and
    category in positions 2 and 3; anything shallower maps to
    "unknown_unknown".
    """
    segments = path.split('/')
    if len(segments) < 4:
        return "unknown_unknown"
    gender, category = segments[2], segments[3]
    return f"{gender}_{category}".lower()
def load_brand_matrix():
    """Load the persisted brand-similarity matrix and its brand->row map.

    Returns:
        (matrix, index_map); an empty (0, n_sketch) matrix and empty map
        when the files have not been written yet.
    """
    # Fix: calculate_brand_matrix persists under "brand_feature_matrix.npy",
    # but this loader previously read "brand_matrix.npy" — a file nothing in
    # this module writes — so it always hit the empty fallback and earlier
    # brands' rows were silently discarded on every recompute.
    # TODO(review): confirm no other module writes "brand_matrix.npy".
    mat_path = f"{RECOMMEND_PATH_PREFIX}brand_feature_matrix.npy"
    idx_path = f"{RECOMMEND_PATH_PREFIX}brand_index_map.npy"
    try:
        matrix = np.load(mat_path)
        index_map = np.load(idx_path, allow_pickle=True).item()
    except FileNotFoundError:
        matrix = np.zeros((0, len(sketch_to_iid)), dtype=np.float32)
        index_map = {}
    return matrix, index_map
def cosine_similarity(vec1, vec2):
    """Cosine similarity between two vectors; 0.0 when either is all-zero."""
    denom = np.linalg.norm(vec1) * np.linalg.norm(vec2)
    if denom == 0:
        return 0.0
    # Tiny epsilon keeps the division numerically safe for near-zero norms.
    return np.dot(vec1, vec2) / (denom + 1e-10)
def calculate_brand_matrix(sketch_data, brand_id: int) -> np.ndarray:
    """Compute one brand's similarity row against every system sketch.

    Args:
        sketch_data: rows of (id, sketch_path, gender, category) belonging
            to the brand's libraries.
        brand_id: brand whose row is (re)computed.

    Returns:
        The brand's row as a (1, n_sketch) array. Side effect: persists the
        full matrix and the brand->row index map to disk.
    """
    # 1. Collect per-(brand, category) feature vectors.
    brand_feature = defaultdict(lambda: defaultdict(list))
    for _id, sketch_path, gender, sketch_category in sketch_data:
        cat = getNewCategory(gender, sketch_category)
        # Bug fix: `BRAND_FEATURES.get(_id) or fallback` raised
        # "truth value of an array is ambiguous" whenever a cached numpy
        # feature existed; test for None explicitly instead.
        feat = BRAND_FEATURES.get(_id)
        if feat is None:
            feat = extract_feature_vector_from_resnet(sketch_path)
        brand_feature[(brand_id, cat)][_id].append(feat)
    # 2. Build the sketch index (iid -> matrix column).
    sketch_list = sorted(sketch_to_iid.values())
    sketch_index = {iid: idx for idx, iid in enumerate(sketch_list)}
    n_sketch = len(sketch_list)
    # 3. Load or initialize the persisted matrix.
    brand_matrix, brand_index_map = load_brand_matrix()
    # 4. Reuse the brand's existing row, or append a zero row for a new brand.
    if brand_id in brand_index_map:
        row_idx = brand_index_map[brand_id]
    else:
        row_idx = brand_matrix.shape[0]
        brand_index_map[brand_id] = row_idx
        brand_matrix = np.vstack([
            brand_matrix,
            np.zeros((1, n_sketch), dtype=np.float32)
        ])
    # 5. Average feature vector per (brand, category).
    brand_avg = {}
    for key, id_dict in brand_feature.items():
        all_feats = [v for feats in id_dict.values() for v in feats]
        if all_feats:
            brand_avg[key] = np.mean(all_feats, axis=0)
    # 6. Fill similarities against system sketches of the matching category.
    for sketch_path, sys_vec in SYSTEM_FEATURES.items():
        iid = sketch_to_iid.get(sketch_path)
        if not iid or iid not in sketch_index:
            continue
        cat_key = (brand_id, get_category_from_path(sketch_path))
        avg_vec = brand_avg.get(cat_key)
        if avg_vec is not None:
            cos_sim = cosine_similarity(avg_vec, sys_vec)
            brand_matrix[row_idx, sketch_index[iid]] = cos_sim
    # 7. Persist matrix and index map.
    np.save(f"{RECOMMEND_PATH_PREFIX}brand_feature_matrix.npy", brand_matrix)
    np.save(f"{RECOMMEND_PATH_PREFIX}brand_index_map.npy", brand_index_map)
    # Return only this brand's row.
    return brand_matrix[row_idx:row_idx+1]
@router.get("/brand_dna_initialize/{brand_id}")
async def brand_dna_initialize(brand_id: int):
conn = None
try:
conn = pymysql.connect(**DB_CONFIG)
cursor = conn.cursor()
cursor.execute("""
SELECT id, img_url, gender, category
FROM product_image_attribute
WHERE library_id IN (
SELECT library_id
FROM brand_rel_library
WHERE brand_id = %s
)
""", (brand_id,))
sketch_data = cursor.fetchall()
# 触发计算并持久化,若内部出错会抛异常
_ = calculate_brand_matrix(sketch_data, brand_id)
# 返回成功
return {"success": True}
except HTTPException:
# 已经是明确的 HTTPException直接抛出
raise
except Exception as e:
logger.error(f"品牌初始化失败 [{brand_id}]: {e}", exc_info=True)
# 返回失败的 JSON同时设置 500 状态码
return JSONResponse(
status_code=500,
content={"success": False, "message": "品牌初始化失败"}
)
finally:
if conn:
conn.close()

View File

@@ -0,0 +1,51 @@
import json
import logging
from fastapi import APIRouter, HTTPException
from app.schemas.response_template import ResponseModel
from app.schemas.clothing_seg import ClothingSegModel
from app.service.clothing_seg.service import ClothingSeg
router = APIRouter()
logger = logging.getLogger()
@router.post("/clothing_seg")
def clothing_seg(request_item: ClothingSegModel):
"""
创建一个具有以下参数的请求体:
- **user_id**: 用户id
- **image_data**: 图片数据
{
"image_url": "test/clothing_seg/dress.jpg",
"image_type": "product"
}
示例参数:
{
"user_id": 89,
"image_data": [
{
"image_url": "test/clothing_seg/dress.jpg",
"image_type": "sketch"
},
{
"image_url": "test/clothing_seg/skirt_559.jpg",
"image_type": "sketch"
},
{
"image_url": "test/clothing_seg/10144613.jpg",
"image_type": "product"
}
]
}
"""
try:
logger.info(f"clothing_seg request item is : @@@@@@:{json.dumps(request_item.dict())}")
server = ClothingSeg(request_item)
result_url = server.get_result()
except Exception as e:
logger.warning(f"clothing_seg Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))
return ResponseModel(data=result_url)

View File

@@ -433,12 +433,12 @@ def model_process(request_data: ModelProgressModel):
@router.post("/design_batch_generate")
async def design(file: UploadFile = File(...),
tasks_id: str = Form(...),
user_id: str = Form(...),
file_name: str = Form(...),
total: int = Form(...)
):
async def design_batch(file: UploadFile = File(...),
tasks_id: str = Form(...),
user_id: str = Form(...),
file_name: str = Form(...),
total: int = Form(...)
):
dbg_config = DBGConfigModel(
tasks_id=tasks_id,
user_id=user_id,

View File

@@ -0,0 +1,39 @@
import logging
from fastapi import APIRouter, HTTPException
from app.schemas.project_info_extraction import ProjectInfoExtractionModel
from app.schemas.response_template import ResponseModel
from app.service.project_info_extraction.service import ProjectInfoExtraction
router = APIRouter()
logger = logging.getLogger()
@router.post("/extraction_project_info")
def extraction_project_info(request_data: ProjectInfoExtractionModel):
"""
通过prompt 提取project_name,role ,gender ,style。
创建一个具有以下参数的请求体:
- **prompt**:
示例参数:
{
"prompt": "海边派对主题的系列设计",
"image_list": [
"https://www.minio-api.aida.com.hk/test/test123.jpg?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=vXKFLSJkYeEq2DrSZvkB%2F20250519%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20250519T050808Z&X-Amz-Expires=7200&X-Amz-SignedHeaders=host&X-Amz-Signature=296ff07cc4692d0a26ddffac582064f036494af343389fe60193dc2c5dc883ff"
],
"file_list": [
""
]
}
"""
try:
logger.info(f"extraction_project_info request item is : @@@@@@:{request_data}")
service = ProjectInfoExtraction(request_data)
data = service.get_result()
logger.info(f"extraction_project_info response @@@@@@:{data}")
except Exception as e:
logger.warning(f"extraction_project_info Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))
return ResponseModel(data=data)

View File

@@ -3,8 +3,11 @@ import logging
from fastapi import APIRouter, BackgroundTasks, HTTPException
from app.schemas.generate_image import GenerateImageModel, GenerateProductImageModel, GenerateSingleLogoImageModel, GenerateRelightImageModel, GenerateMultiViewModel
from app.schemas.generate_image import GenerateImageModel, GenerateProductImageModel, GenerateSingleLogoImageModel, GenerateRelightImageModel, GenerateMultiViewModel, BatchGenerateProductImageModel, BatchGenerateRelightImageModel, AgentTollGenerateImageModel
from app.schemas.pose_transform import BatchPoseTransformModel
from app.schemas.response_template import ResponseModel
from app.service.generate_batch_image.service import start_product_batch_generate, start_relight_batch_generate, start_pose_transform_batch_generate
from app.service.generate_image.service_agent_tool_generate_image import AgentToolGenerateImage
from app.service.generate_image.service_generate_image import GenerateImage, infer_cancel as generate_image_infer_cancel
from app.service.generate_image.service_generate_multi_view import GenerateMultiView, infer_cancel as generate_multi_view_cancel
from app.service.generate_image.service_generate_product_image import GenerateProductImage, infer_cancel as generate_product_image_cancel
@@ -228,3 +231,123 @@ def generate_relight_image(tasks_id: str):
logger.warning(f"generate_relight_image_cancel_cancel Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))
return ResponseModel(data=data['data'])
"""batch generate img"""
@router.post("/batch_generate_product_image")
async def batch_generate_product(request_batch_item: BatchGenerateProductImageModel):
"""
创建一个具有以下参数的请求体:
- **tasks_id**: 任务id 用于获取生成结果
- **prompt**: 想要生成图片的描述词
- **image_url**: 被生成图片的S3或minio url地址
- **image_strength**: 生成强度,越低越接近原图
- **product_type**: 输入single item 还是 overall item
- **batch_size**: 批生成数量
示例参数:
{
"tasks_id": "123-89",
"prompt": "the best quality, masterpiece. detailed, high-res, simple background, studio photography, extremely detailed, updo, detailed face, face, close-up, HDR, UHD, 8K realistic, Highly detailed, simple background, Studio lighting",
"image_url": "aida-results/result_00097282-ebb2-11ee-a822-b48351119060.png",
"image_strength": 0.8,
"product_type": "overall",
"batch_size": 1
}
"""
return await start_product_batch_generate(request_batch_item)
@router.post("/batch_generate_relight_image")
async def batch_generate_relight(request_batch_item: BatchGenerateRelightImageModel):
"""
创建一个具有以下参数的请求体:
- **tasks_id**: 任务id 用于获取生成结果
- **prompt**: 想要生成图片的描述词
- **image_url**: 被生成图片的S3或minio url地址
- **direction**: 光源方向 Right Light Left Light Top Light Bottom Light
- **product_type**: 输入single item 还是 overall item
- **batch_size**: 批生成数量
示例参数:
{
"tasks_id": "123-89",
"prompt": "beautiful woman, detailed face, sunshine, outdoor, warm atmosphere",
"image_url": "aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png",
"direction": "Right Light",
"product_type": "overall",
"batch_size": 1
}
"""
return await start_relight_batch_generate(request_batch_item)
@router.post("/batch_generate_pose_transform_image")
async def batch_generate_pose_transform(request_batch_item: BatchPoseTransformModel):
"""
创建一个具有以下参数的请求体:
- **tasks_id**: 任务id 用于取消生成任务和获取生成结果
- **image_url**: 被生成图片的S3或minio url地址
- **pose_id**: 1
- **batch_size**: 批生成数量
示例参数:
{
"tasks_id": "123-89",
"image_url": "aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png",
"pose_id": "1",
"batch_size": 1
}
"""
return await start_pose_transform_batch_generate(request_batch_item)
"""agent tool"""
@router.post("/agent_tool_generate_image")
def agent_tool_generate_image(request_item: AgentTollGenerateImageModel, background_tasks: BackgroundTasks):
"""
创建一个具有以下参数的请求体:
- **prompt**: 想要生成图片的描述词
- **category**: 生成图片的类别sketch print 等等
- **gender**: 生成sketch专用服装类别
- **version**: 使用模型版本 fast 或者 high
- **size**: 生成数量
- **version**: 使用模型版本 fast 或者 high
示例参数:
{
"prompt": "a single item of sketch of Wabi-sabi, skirt, tiered, 4k, white background",
"category": "sketch",
"gender": "male",
"size":2,
"version":"high"
}
"""
try:
logger.info(f"agent_tool_generate_image request item is : @@@@@@:{request_item.dict()}")
request_data = request_item.dict()
service = AgentToolGenerateImage(request_data['version'])
image_url_list, clothing_category_list = service.get_result(
prompt=request_data['prompt'],
size=request_data['size'],
version=request_data['version'],
category=request_data['category'],
gender=request_data['gender']
)
data = {
"image_url_list": image_url_list,
"clothing_category_list": clothing_category_list
}
logger.info(f"agent_tool_generate_image response item is : @@@@@@:{data}")
except Exception as e:
logger.warning(f"agent_tool_generate_image Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))
return ResponseModel(data=data)

View File

@@ -0,0 +1,45 @@
import json
import logging
from fastapi import APIRouter, HTTPException
from app.schemas.mannequin_edit import MannequinModel
from app.schemas.response_template import ResponseModel
from app.service.mannequins_edit.service import MannequinEditService
router = APIRouter()
logger = logging.getLogger()
@router.post("/mannequins_edit")
def mannequins_edit(request_data: MannequinModel):
"""
模特腿长调整
创建一个具有以下参数的请求体:
- **mannequins**: mannequins url等信息
- **resize_pixel**: 拉伸像素量
- **bucket_name**: bucket name
- **mannequin_name**: 模特名称
- **top**: 拉伸y轴点位
- **bottom**: 拉伸y轴点位
示例参数:
- **{
"mannequins": "aida-sys-image/models/male/dc36ce58-46c3-4b6f-8787-5ca7d6fc26e6.png",
"resize_pixel": -50,
"bucket_name": "test",
"mannequin_name": "mannequin_name",
"top" : 270,
"bottom" : 432
}**
"""
try:
logger.info(f"mannequins_edit request item is : @@@@@@:{json.dumps(request_data.dict())}")
service = MannequinEditService(request_data)
data = service()
logger.info(f"mannequins_edit response @@@@@@:{json.dumps(data)}")
except Exception as e:
logger.warning(f"mannequins_edit Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))
return ResponseModel(data=data)

View File

@@ -0,0 +1,49 @@
import json
import logging
from fastapi import APIRouter, BackgroundTasks, HTTPException
from app.schemas.pose_transform import PoseTransformModel
from app.schemas.response_template import ResponseModel
from app.service.generate_image.service_pose_transform import PoseTransformService, infer_cancel as pose_transform_infer_cancel
router = APIRouter()
logger = logging.getLogger()
@router.post("/pose_transform")
def pose_transform(request_item: PoseTransformModel, background_tasks: BackgroundTasks):
"""
创建一个具有以下参数的请求体:
- **tasks_id**: 任务id 用于取消生成任务和获取生成结果
- **image_url**: 被生成图片的S3或minio url地址
- **pose_id**: 1
示例参数:
{
"tasks_id": "123-89",
"image_url": "aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png",
"pose_id": "1"
}
"""
try:
logger.info(f"pose_transform request item is : @@@@@@:{json.dumps(request_item.dict())}")
service = PoseTransformService(request_item)
background_tasks.add_task(service.get_result)
except Exception as e:
logger.warning(f"pose_transform Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))
return ResponseModel()
@router.get("/pose_transform_cancel/{tasks_id}")
def pose_transform_cancel(tasks_id: str):
try:
logger.info(f"pose_transform_cancel request item is : @@@@@@:{tasks_id}")
data = pose_transform_infer_cancel(tasks_id)
logger.info(f"pose_transform_cancel response @@@@@@:{data}")
except Exception as e:
logger.warning(f"pose_transform_cancel Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))
return ResponseModel(data=data['data'])

View File

@@ -4,9 +4,10 @@ import time
from fastapi import APIRouter, HTTPException
from app.schemas.prompt_generation import PromptGenerationImageModel
from app.schemas.prompt_generation import PromptGenerationImageModel, ImageRequest
from app.schemas.response_template import ResponseModel
from app.service.prompt_generation.chatgpt_for_translation import translate_to_en, get_translation_from_llama3
from app.service.prompt_generation.chatgpt_for_translation import get_translation_from_llama3, \
get_prompt_from_image
router = APIRouter()
logger = logging.getLogger()
@@ -32,3 +33,20 @@ def prompt_generation(request_data: PromptGenerationImageModel):
logger.warning(f"prompt_generation Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))
return ResponseModel(data=data)
@router.post("/img2prompt")
def get_prompt_from_img(img: ImageRequest):
"""
自动识别图片并输出为prompt
:param img: 图片的minio地址
:return: 图片的文字描述
"""
text = ("Please describe the clothing in the image and provide a line art description of the outfit. "
"The description should allow for the reconstruction of the corresponding line art based on the details "
"given.")
logger.info(f"get_prompt_from_img request item is : @@@@@@:{img}")
description = get_prompt_from_image(img, text)
logger.info(f"生成的图片描述 response @@@@@@:{description}")
return description

View File

@@ -0,0 +1,204 @@
import io
import logging
import sys
import time
from typing import List
import os
import json
import math
import random
import numpy as np
from apscheduler.schedulers.background import BackgroundScheduler
from apscheduler.triggers.cron import CronTrigger
from fastapi import HTTPException, APIRouter
from app.service.recommend.service import load_resources, matrix_data
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')
logger = logging.getLogger()
router = APIRouter()
@router.on_event("startup")
async def startup_event():
# 初始加载
load_resources()
# 配置定时任务
scheduler = BackgroundScheduler()
scheduler.add_job(
load_resources,
trigger=CronTrigger(hour=0, minute=30),
name="每日资源刷新"
)
scheduler.start()
logger.info("定时任务已启动")
def softmax(scores):
    """Numerically stable softmax over a sequence of raw scores."""
    # Shift by the maximum so exp() never overflows.
    peak = max(scores)
    exps = []
    for s in scores:
        exps.append(math.exp(s - peak))
    total = sum(exps)
    return [e / total for e in exps]
# def get_random_recommendations(category: str, num: int) -> List[str]:
# """根据预加载热度向量推荐(冷启动)"""
# try:
# heat_data = matrix_data.get("heat_data", {})
#
# if category not in heat_data:
# raise ValueError(f"热度数据缺少类别 {category},使用随机推荐")
#
# heat_dict = heat_data[category] # {url: score}
# urls = list(heat_dict.keys())
# scores = list(heat_dict.values())
#
# if not urls:
# raise ValueError("该类别下无热度记录,使用随机推荐")
#
# probs = softmax(scores)
# sample_size = min(num, len(urls))
# sampled_urls = random.choices(urls, weights=probs, k=sample_size)
#
# return sampled_urls
#
# except Exception as e:
# # 回退:完全随机推荐
# all_iids = list(matrix_data["iid_to_sketch"].keys())
# category_iids = matrix_data["category_to_iids"].get(category, all_iids)
# sample_size = min(num, len(category_iids))
# sampled = np.random.choice(category_iids, size=sample_size, replace=False)
# return [matrix_data["iid_to_sketch"][iid] for iid in sampled]
def get_random_recommendations(category: str, num: int) -> List[str]:
    """Random fallback recommendations (cold start).

    Samples without replacement from the requested category's item ids,
    falling back to the full catalogue when the category is unknown.
    """
    iid_to_sketch = matrix_data["iid_to_sketch"]
    fallback = list(iid_to_sketch.keys())
    # Prefer items of the requested category.
    candidates = matrix_data["category_to_iids"].get(category, fallback)
    # Never request more samples than exist.
    k = min(num, len(candidates))
    picked = np.random.choice(candidates, size=k, replace=False)
    return [iid_to_sketch[iid] for iid in picked]
@router.get("/recommend/{user_id}/{category}/{num_recommendations}/{brand_id}/{brand_scale}", response_model=List[str])
async def get_recommendations(user_id: int, category: str, brand_id: int, brand_scale: float, num_recommendations: int = 10):
"""
:param user_id: 4
:param category: female_skirt
:param num_recommendations: 1
:return:
[
"aida-sys-image/images/female/skirt/903000017.jpg"
]
"""
try:
start_time = time.time()
cache_key = (user_id, category)
# === 新增:用户存在性检查 ===
user_exists_inter = user_id in matrix_data["user_index_interaction"]
user_exists_feat = user_id in matrix_data["user_index_feature"]
# 任一矩阵不存在用户则返回随机推荐
if not (user_exists_inter and user_exists_feat):
logger.info(f"用户 {user_id} 数据不完整,触发随机推荐")
return get_random_recommendations(category, num_recommendations)
# 检查缓存
if cache_key in matrix_data["cached_scores"]:
processed_inter, processed_feat = matrix_data["cached_scores"][cache_key]
valid_sketch_idxs_inter = matrix_data["cached_valid_idxs"][cache_key]
else:
# 实时计算逻辑(同原代码)
user_idx_inter = matrix_data["user_index_interaction"].get(user_id)
user_idx_feature = matrix_data["user_index_feature"].get(user_id)
category_iids = matrix_data["category_to_iids"].get(category, [])
valid_sketch_idxs_inter = [
idx for iid, idx in matrix_data["sketch_index_interaction"].items()
if iid in category_iids
]
# 处理交互分数
raw_inter_scores = []
if user_idx_inter is not None and valid_sketch_idxs_inter:
raw_inter_scores = matrix_data["interaction_matrix"][user_idx_inter, valid_sketch_idxs_inter]
processed_inter = raw_inter_scores * 0.7
# 处理特征分数
valid_sketch_idxs_feature = [
idx for iid, idx in matrix_data["sketch_index_feature"].items()
if iid in category_iids
]
raw_feat_scores = []
if user_idx_feature is not None and valid_sketch_idxs_feature:
raw_feat_scores = matrix_data["feature_matrix"][user_idx_feature, valid_sketch_idxs_feature]
raw_feat_scores = (raw_feat_scores - np.min(raw_feat_scores)) / (
np.max(raw_feat_scores) - np.min(raw_feat_scores) + 1e-8)
processed_feat = raw_feat_scores
else:
processed_feat = np.array([])
# 更新缓存
matrix_data["cached_scores"][cache_key] = (processed_inter, processed_feat)
matrix_data["cached_valid_idxs"][cache_key] = valid_sketch_idxs_inter
# 合并分数
if brand_id is not None:
brand_idx_feature = matrix_data["brand_index_map"].get(brand_id)
brand_feat_valid = (
matrix_data["brand_feature_matrix"].size > 0 and # 矩阵非空
brand_idx_feature is not None and
valid_sketch_idxs_feature # 有可用索引
)
if brand_feat_valid:
raw_brand_feat_scores = matrix_data["brand_feature_matrix"][
brand_idx_feature, valid_sketch_idxs_feature
]
raw_brand_feat_scores = (raw_brand_feat_scores - np.min(raw_brand_feat_scores)) / (
np.max(raw_brand_feat_scores) - np.min(raw_brand_feat_scores) + 1e-8
)
processed_brand_feat = raw_brand_feat_scores
# 如果 processed_feat 是空的,替换为全 0避免 shape 不一致
if processed_feat.size == 0:
processed_feat = np.zeros_like(processed_brand_feat)
final_scores = processed_inter + 0.3 * (
(1 - brand_scale) * processed_feat + brand_scale * processed_brand_feat
)
else:
# brand 信息不可用
final_scores = processed_inter + 0.3 * processed_feat if processed_feat.size > 0 else processed_inter
else:
final_scores = processed_inter + 0.3 * processed_feat if processed_feat.size > 0 else processed_inter
valid_sketch_idxs = matrix_data["cached_valid_idxs"][cache_key]
# 概率采样
scores = np.array(final_scores)
# 调整后的概率转换带温度控制的softmax
def calibrated_softmax(scores, temperature=1.0):
scores = scores / temperature
scale = scores - max(scores)
exps = np.exp(scale)
return exps / np.sum(exps)
probs = calibrated_softmax(scores, 0.09)
chosen_indices = np.random.choice(
len(valid_sketch_idxs),
size=min(num_recommendations, len(valid_sketch_idxs)),
p=probs,
replace=False
)
recommendations = [matrix_data["iid_to_sketch"][valid_sketch_idxs[idx]] for idx in chosen_indices]
logger.info(f"推荐生成完成,耗时: {time.time() - start_time:.2f}")
return recommendations
except Exception as e:
logger.error(f"推荐失败: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail=str(e))

View File

@@ -4,11 +4,16 @@ from app.api import api_attribute_retrieve, api_query_image
from app.api import api_brand_dna
from app.api import api_brighten
from app.api import api_chat_robot
from app.api import api_clothing_seg
from app.api import api_design
from app.api import api_design_pre_processing
from app.api import api_extraction_project_info
from app.api import api_generate_image
from app.api import api_image2sketch
from app.api import api_mannequins_edit
from app.api import api_pose_transform
from app.api import api_prompt_generation
from app.api import api_recommendation
from app.api import api_super_resolution
from app.api import api_test
@@ -26,3 +31,8 @@ router.include_router(api_image2sketch.router, tags=['api_image2sketch'], prefix
router.include_router(api_brighten.router, tags=['api_brighten'], prefix="/api")
router.include_router(api_query_image.router, tags=['api_query_image'], prefix="/api")
router.include_router(api_brand_dna.router, tags=['api_brand_dna'], prefix="/api")
router.include_router(api_recommendation.router, tags=['api_recommendation'], prefix="/api")
router.include_router(api_mannequins_edit.router, tags=['api_mannequins_edit'], prefix="/api")
router.include_router(api_pose_transform.router, tags=['api_pose_transform'], prefix="/api")
router.include_router(api_clothing_seg.router, tags=['api_clothing_seg'], prefix="/api")
router.include_router(api_extraction_project_info.router, tags=['api_extraction_project_info'], prefix="/api")

View File

@@ -4,7 +4,7 @@ import logging
from fastapi import APIRouter
from fastapi import HTTPException
from app.core.config import SR_RABBITMQ_QUEUES, GI_RABBITMQ_QUEUES, GPI_RABBITMQ_QUEUES, GRI_RABBITMQ_QUEUES, OSS, JAVA_STREAM_API_URL
from app.core.config import SR_RABBITMQ_QUEUES, GI_RABBITMQ_QUEUES, GPI_RABBITMQ_QUEUES, GRI_RABBITMQ_QUEUES, OSS, JAVA_STREAM_API_URL, GMV_RABBITMQ_QUEUES, SLOGAN_RABBITMQ_QUEUES, GEN_SINGLE_LOGO_RABBITMQ_QUEUES, PS_RABBITMQ_QUEUES, BATCH_GPI_RABBITMQ_QUEUES, BATCH_GRI_RABBITMQ_QUEUES, BATCH_PS_RABBITMQ_QUEUES
from app.schemas.response_template import ResponseModel
logger = logging.getLogger()
@@ -14,10 +14,19 @@ router = APIRouter()
@router.get("{id}")
def test(id: int):
data = {
"SR_RABBITMQ_QUEUES message": SR_RABBITMQ_QUEUES,
"GI_RABBITMQ_QUEUES": GI_RABBITMQ_QUEUES,
"GPI_RABBITMQ_QUEUES": GPI_RABBITMQ_QUEUES,
"GRI_RABBITMQ_QUEUES": GRI_RABBITMQ_QUEUES,
"超分 SR_RABBITMQ_QUEUES": SR_RABBITMQ_QUEUES,
"多视角 GMV_RABBITMQ_QUEUES": GMV_RABBITMQ_QUEUES,
"pose transform PS_RABBITMQ_QUEUES": PS_RABBITMQ_QUEUES,
"logan SLOGAN_RABBITMQ_QUEUES": SLOGAN_RABBITMQ_QUEUES,
"image and single logo GI_RABBITMQ_QUEUES": GI_RABBITMQ_QUEUES,
"to product image GPI_RABBITMQ_QUEUES": GPI_RABBITMQ_QUEUES,
"relight GRI_RABBITMQ_QUEUES": GRI_RABBITMQ_QUEUES,
# batch
"batch product BATCH_GPI_RABBITMQ_QUEUES": BATCH_GPI_RABBITMQ_QUEUES,
"batch relight BATCH_GRI_RABBITMQ_QUEUES": BATCH_GRI_RABBITMQ_QUEUES,
"batch pose transform BATCH_PS_RABBITMQ_QUEUES": BATCH_PS_RABBITMQ_QUEUES,
"JAVA_STREAM_API_URL": JAVA_STREAM_API_URL,
"local_oss_server": OSS
}

View File

@@ -9,14 +9,14 @@ load_dotenv(os.path.join(BASE_DIR, '.env'))
class Settings(BaseSettings):
PROJECT_NAME = os.getenv('PROJECT_NAME', 'FASTAPI BASE')
SECRET_KEY = os.getenv('SECRET_KEY', '')
API_PREFIX = ''
BACKEND_CORS_ORIGINS = ['*']
DATABASE_URL = os.getenv('SQL_DATABASE_URL', '')
PROJECT_NAME: str = os.getenv('PROJECT_NAME', 'FASTAPI BASE')
SECRET_KEY: str = os.getenv('SECRET_KEY', '')
API_PREFIX: str = ''
BACKEND_CORS_ORIGINS: list[str] = ['*']
DATABASE_URL: str = os.getenv('SQL_DATABASE_URL', '')
ACCESS_TOKEN_EXPIRE_SECONDS: int = 60 * 60 * 24 * 7 # Token expired after 7 days
SECURITY_ALGORITHM = 'HS256'
LOGGING_CONFIG_FILE = os.path.join(BASE_DIR, 'logging_env.py')
SECURITY_ALGORITHM: str = 'HS256'
LOGGING_CONFIG_FILE: str = os.path.join(BASE_DIR, 'logging_env.py')
OSS = "minio"
@@ -25,13 +25,20 @@ if DEBUG:
LOGS_PATH = "logs/"
CATEGORY_PATH = "service/attribute/config/descriptor/category/category_dis.csv"
SEG_CACHE_PATH = "../seg_cache/"
POSE_TRANSFORM_VIDEO_PATH = "../pose_transform_video/"
RECOMMEND_PATH_PREFIX = "service/recommend/"
CHROMADB_PATH = "./chromadb/"
else:
LOGS_PATH = "app/logs/"
CATEGORY_PATH = "app/service/attribute/config/descriptor/category/category_dis.csv"
SEG_CACHE_PATH = "/seg_cache/"
POSE_TRANSFORM_VIDEO_PATH = "/pose_transform_video/"
RECOMMEND_PATH_PREFIX = "app/service/recommend/"
CHROMADB_PATH = "/chromadb/"
RABBITMQ_ENV = "-prod" # 生产环境
# RABBITMQ_ENV = "-dev" # 开发环境
# RABBITMQ_ENV = "" # 生产环境
RABBITMQ_ENV = "-dev" # 开发环境
# RABBITMQ_ENV = "-local" # 本地测试环境
JAVA_STREAM_API_URL = os.getenv("JAVA_STREAM_API_URL", "https://api.aida.com.hk/api/third/party/receiveDesignResults")
@@ -99,7 +106,7 @@ OPENAI_MODEL_LIST = {"gpt-3.5-turbo-0613",
SR_MODEL_NAME = "super_resolution"
SR_TRITON_URL = "10.1.1.240:10031"
SR_MINIO_BUCKET = "aida-users"
SR_RABBITMQ_QUEUES = f"SuperResolution{RABBITMQ_ENV}"
SR_RABBITMQ_QUEUES = os.getenv("SR_RABBITMQ_QUEUES", f"SuperResolution{RABBITMQ_ENV}")
# GenerateImage service config
FAST_GI_MODEL_URL = '10.1.1.243:10011'
@@ -111,20 +118,20 @@ GI_MODEL_NAME = 'flux'
GMV_MODEL_URL = '10.1.1.243:10081'
GMV_MODEL_NAME = 'multi_view'
GMV_RABBITMQ_QUEUES = f"GenerateMultiView{RABBITMQ_ENV}"
GMV_RABBITMQ_QUEUES = os.getenv("GMV_RABBITMQ_QUEUES", f"GenerateMultiView{RABBITMQ_ENV}")
GI_MINIO_BUCKET = "aida-users"
GI_RABBITMQ_QUEUES = f"GenerateImage{RABBITMQ_ENV}"
GI_RABBITMQ_QUEUES = os.getenv("GI_RABBITMQ_QUEUES", f"GenerateImage{RABBITMQ_ENV}")
GI_SYS_IMAGE_URL = "aida-sys-image/generate_image/white_image.jpg"
# SLOGAN service config
SLOGAN_RABBITMQ_QUEUES = f"Slogan{RABBITMQ_ENV}"
SLOGAN_RABBITMQ_QUEUES = os.getenv("SLOGAN_RABBITMQ_QUEUES", f"Slogan{RABBITMQ_ENV}")
# Generate Single Logo service config
GSL_MODEL_URL = '10.1.1.243:10041'
GSL_MINIO_BUCKET = "aida-users"
GSL_MODEL_NAME = 'stable_diffusion_xl_transparent'
GEN_SINGLE_LOGO_RABBITMQ_QUEUES = f"GenSingleLogo{RABBITMQ_ENV}"
GEN_SINGLE_LOGO_RABBITMQ_QUEUES = os.getenv("GEN_SINGLE_LOGO_RABBITMQ_QUEUES", f"GenSingleLogo{RABBITMQ_ENV}")
# Generate Product service config
# GPI_RABBITMQ_QUEUES = os.getenv("GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES", f"ToProductImage{RABBITMQ_ENV}")
@@ -132,17 +139,25 @@ GEN_SINGLE_LOGO_RABBITMQ_QUEUES = f"GenSingleLogo{RABBITMQ_ENV}"
# GPI_MODEL_URL = '10.1.1.243:10051'
# Generate Product service config 旧版product img 模型
GPI_RABBITMQ_QUEUES = f"ToProductImage{RABBITMQ_ENV}"
GPI_RABBITMQ_QUEUES = os.getenv("GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES", f"ToProductImage{RABBITMQ_ENV}")
BATCH_GPI_RABBITMQ_QUEUES = os.getenv("BATCH_GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES", f"BatchToProductImage{RABBITMQ_ENV}")
GPI_MODEL_NAME_OVERALL = 'diffusion_ensemble_all'
GPI_MODEL_NAME_SINGLE = 'stable_diffusion_1_5_cnet'
GPI_MODEL_URL = '10.1.1.243:10051'
# Generate Single Logo service config
GRI_RABBITMQ_QUEUES = f"Relight{RABBITMQ_ENV}"
GRI_RABBITMQ_QUEUES = os.getenv("GEN_RELIGHT_IMAGE_RABBITMQ_QUEUES", f"Relight{RABBITMQ_ENV}")
BATCH_GRI_RABBITMQ_QUEUES = os.getenv("BATCH_GEN_RELIGHT_IMAGE_RABBITMQ_QUEUES", f"BatchRelight{RABBITMQ_ENV}")
GRI_MODEL_NAME_OVERALL = 'diffusion_relight_ensemble'
GRI_MODEL_NAME_SINGLE = 'stable_diffusion_1_5_relight'
GRI_MODEL_URL = '10.1.1.240:10051'
# Pose Transform service config
PS_RABBITMQ_QUEUES = os.getenv("PS_RABBITMQ_QUEUES", f"PoseTransform{RABBITMQ_ENV}")
BATCH_PS_RABBITMQ_QUEUES = os.getenv("BATCH_PS_RABBITMQ_QUEUES", f"BatchPoseTransform{RABBITMQ_ENV}")
PT_MODEL_URL = '10.1.1.243:10061'
# SEG service config
SEGMENTATION = {
"new_model_name": "seg_knet",
@@ -152,6 +167,10 @@ SEGMENTATION = {
}
# ollama config
OLLAMA_URL = "http://10.1.1.240:11434/api/embeddings"
# design batch
BATCH_DESIGN_RABBITMQ_QUEUES = os.getenv("BATCH_DESIGN_RABBITMQ_QUEUES", f"DesignBatch{RABBITMQ_ENV}")
# DESIGN config
DESIGN_MODEL_URL = '10.1.1.240:10000'
AIDA_CLOTHING = "aida-clothing"
@@ -189,3 +208,23 @@ PRIORITY_DICT = {
}
QWEN_API_KEY = "sk-f31c29e61ac2498ba5e307aaa6dc10e0"
DB_CONFIG = {
"host": "18.167.251.121",
"port": 3306,
"user": "root",
"password": "QWa998345",
"database": "aida",
"charset": "utf8mb4"
}
TABLE_CATEGORIES = {
"female_dress": "female/dress",
"female_outwear": "female/outwear",
"female_trousers": "female/trousers",
"female_skirt": "female/skirt",
"female_blouse": "female/blouse",
"male_tops": "male/tops",
"male_bottoms": "male/bottoms",
"male_outwear": "male/outwear"
}

View File

@@ -1,15 +1,17 @@
import logging.config
from http.client import HTTPException
from fastapi.responses import JSONResponse
from fastapi import FastAPI, HTTPException, Request
import uvicorn
from apscheduler.schedulers.background import BackgroundScheduler
from apscheduler.triggers.cron import CronTrigger
from fastapi import FastAPI
from fastapi import HTTPException, Request
from fastapi.responses import JSONResponse
from app.api.api_route import router
from app.core.config import settings
from app.core.record_api_count import count_api_calls
from app.schemas.response_template import ResponseModel
from app.service.recommend.service import load_resources
from logging_env import LOGGER_CONFIG_DICT
logging.config.dictConfig(LOGGER_CONFIG_DICT)
@@ -17,6 +19,8 @@ logging.getLogger("pika").setLevel(logging.WARNING)
from starlette.middleware.cors import CORSMiddleware
logger = logging.getLogger(__name__)
def get_application() -> FastAPI:
application = FastAPI(
@@ -51,5 +55,7 @@ async def http_exception_handler(request: Request, exc: HTTPException):
)
if __name__ == '__main__':
uvicorn.run(app, host="0.0.0.0", port=8000)

View File

@@ -4,3 +4,8 @@ from pydantic import BaseModel
class BrandDnaModel(BaseModel):
    """Request body for product segmentation / brand-DNA extraction."""
    # OSS object path of the source image, e.g. "aida-results/result_<uuid>.png".
    image_url: str
    # True when the request is a brand-DNA extraction rather than a plain segmentation.
    is_brand_dna: bool
class GenerateBrandModel(BaseModel):
    """Request body for /GenerateBrand: produce brand name, slogan and logo from a prompt."""
    # Requesting user's id; used as the OSS object-name prefix when the logo is uploaded.
    user_id: str
    # Free-text brand description, e.g. "xiaomi".
    prompt: str

View File

@@ -2,7 +2,6 @@ from pydantic import BaseModel
class ChatRobotModel(BaseModel):
gender: str
message: str
session_id: str
user_id: int

View File

@@ -0,0 +1,6 @@
from pydantic import BaseModel
class ClothingSegModel(BaseModel):
    """Request body for the clothing-segmentation service."""
    # Requesting user's id; used in the uploaded result object names.
    user_id: str
    # One dict per image; each entry is expected to carry "image_url" and
    # "image_type" ("sketch" or photo) — see ClothingSeg.read_image / clothing_seg.
    image_data: list[dict]

View File

@@ -1,3 +1,5 @@
from typing import List
from pydantic import BaseModel
@@ -36,3 +38,53 @@ class GenerateRelightImageModel(BaseModel):
image_url: str
direction: str
product_type: str
"""
batch generate image
"""
# Single sub-task of a batch product-image request.
class ProductItemModel(BaseModel):
    """One generate-product-image task inside a batch."""
    # Id of this individual task within the batch.
    tasks_id: str
    # presumably the img2img strength passed to the diffusion model — TODO confirm
    image_strength: float
    # Text prompt for the generation model.
    prompt: str
    # OSS path of the source image.
    image_url: str
    # Product category token understood by the generation service.
    product_type: str
# Collection wrapping a whole batch of product-image tasks.
class BatchGenerateProductImageModel(BaseModel):
    """Batch request: a list of ProductItemModel tasks for one user."""
    # Id of the batch as a whole (individual tasks carry their own tasks_id).
    batch_tasks_id: str
    user_id: str
    batch_data_list: List[ProductItemModel]
# Single sub-task of a batch relight request.
class RelightItemModel(BaseModel):
    """One relight-image task inside a batch."""
    tasks_id: str
    prompt: str
    image_url: str
    # Light direction token; NOTE(review): semantics defined by the relight service — confirm.
    direction: str
    product_type: str
# Collection wrapping a whole batch of relight tasks.
class BatchGenerateRelightImageModel(BaseModel):
    """Batch request: a list of RelightItemModel tasks for one user."""
    batch_tasks_id: str
    user_id: str
    batch_data_list: List[RelightItemModel]
"""
agent tool generate image
"""
class AgentTollGenerateImageModel(BaseModel):
    """Request body for agent-tool image generation.

    NOTE(review): "Toll" looks like a typo for "Tool", but the class name is
    part of the public schema/import surface, so it is left unchanged.
    """
    prompt: str
    category: str
    gender: str
    version: str
    # presumably the number of images to generate — TODO confirm against the service
    size: int

View File

@@ -0,0 +1,10 @@
from pydantic import BaseModel
class MannequinModel(BaseModel):
    """Request body describing a mannequin asset and its placement/scaling."""
    mannequins: str
    # Resize factor; NOTE(review): units/semantics not visible here — confirm with caller.
    resize_pixel: float
    # OSS bucket holding the mannequin asset.
    bucket_name: str
    mannequin_name: str
    top: int
    bottom: int

View File

@@ -0,0 +1,14 @@
from pydantic import BaseModel
class PoseTransformModel(BaseModel):
    """Request body for a single pose-transform task."""
    # OSS path of the source image.
    image_url: str
    tasks_id: str
    # Identifier of the target pose.
    pose_id: str
class BatchPoseTransformModel(BaseModel):
    """Request body for a batched pose-transform task."""
    image_url: str
    tasks_id: str
    pose_id: str
    # Number of items processed in this batch.
    batch_size: int

View File

@@ -0,0 +1,7 @@
from pydantic import BaseModel
class ProjectInfoExtractionModel(BaseModel):
    """Request body for project-information extraction from a prompt plus attachments."""
    prompt: str
    # Attached images; element schema not visible here — confirm with the service.
    image_list: list
    # Attached files; element schema not visible here — confirm with the service.
    file_list: list

View File

@@ -3,3 +3,7 @@ from pydantic import BaseModel
class PromptGenerationImageModel(BaseModel):
    """Request body carrying free text for prompt-based image generation."""
    text: str
class ImageRequest(BaseModel):
    """Request body carrying a single image payload.

    NOTE(review): whether ``img`` is base64 data or a URL is not visible here — confirm.
    """
    img: str

View File

@@ -9,9 +9,9 @@ import torch.nn.functional as F
import tritonclient.http as httpclient
from minio import Minio
from app.core.config import MINIO_URL, MINIO_ACCESS, MINIO_SECRET, MINIO_SECURE, DESIGN_MODEL_URL
from app.core.config import MINIO_URL, MINIO_ACCESS, MINIO_SECRET, MINIO_SECURE, DESIGN_MODEL_URL, CATEGORY_PATH
from app.schemas.brand_dna import BrandDnaModel
from app.service.attribute.config import local_debug_const
from app.service.attribute.config import local_debug_const, const
from app.service.utils.generate_uuid import generate_uuid
from app.service.utils.new_oss_client import oss_upload_image, oss_get_image
@@ -25,18 +25,18 @@ class BrandDna:
self.sketch_bucket = "test"
self.image_url = request_item.image_url
self.is_brand_dna = request_item.is_brand_dna
# self.attr_type = pd.read_csv(CATEGORY_PATH)
self.attr_type = pd.read_csv(r"E:\workspace\trinity_client_aida\app\service\attribute\config\descriptor\category\category_dis.csv")
self.attr_type = pd.read_csv(CATEGORY_PATH)
# self.attr_type = pd.read_csv(r"E:\workspace\trinity_client_aida\app\service\attribute\config\descriptor\category\category_dis.csv")
self.att_client = httpclient.InferenceServerClient(url=DESIGN_MODEL_URL)
self.seg_client = httpclient.InferenceServerClient(url='10.1.1.243:30000')
# self.const = const
self.const = local_debug_const
self.const = const
# self.const = local_debug_const
# 获取结果
def get_result(self):
mask, image = self.get_seg_mask()
cv2.imshow("", image)
cv2.waitKey(0)
# cv2.imshow("", image)
# cv2.waitKey(0)
height, width, channels = image.shape
result_dict = []
@@ -50,8 +50,8 @@ class BrandDna:
outwear_img[mask == value] = image[mask == value]
outwear_mask_img[mask == value] = [0, 0, 255]
cv2.imshow("", outwear_img)
cv2.waitKey(0)
# cv2.imshow("", outwear_img)
# cv2.waitKey(0)
# 预处理之后的input img
preprocess_img = self.category_preprocess(outwear_img)
@@ -89,8 +89,8 @@ class BrandDna:
tops_img[mask == value] = image[mask == value]
tops_mask_img[mask == value] = [0, 0, 255]
cv2.imshow("", tops_img)
cv2.waitKey(0)
# cv2.imshow("", tops_img)
# cv2.waitKey(0)
# 预处理之后的input img
preprocess_img = self.category_preprocess(tops_img)
@@ -129,8 +129,8 @@ class BrandDna:
bottoms_img[mask == value] = image[mask == value]
bottoms_mask_img[mask == value] = [0, 0, 255]
cv2.imshow("", bottoms_img)
cv2.waitKey(0)
# cv2.imshow("", bottoms_img)
# cv2.waitKey(0)
# 预处理之后的input img
preprocess_img = self.category_preprocess(bottoms_img)
@@ -327,7 +327,7 @@ if __name__ == '__main__':
# result_url = service.get_result()
# print(result_url)
request_item = BrandDnaModel(
image_url="aida-users/60/product_image/07cb5d5d-5022-44cc-b0d3-cc986cfebad1-2-60.png",
image_url="aida-results/result_00006a48-e315-11ee-b7c8-b48351119060.png",
is_brand_dna=True
)
service = BrandDna(request_item)

View File

@@ -0,0 +1,104 @@
import logging
import cv2
import numpy as np
import tritonclient.grpc as grpcclient
from langchain.output_parsers import ResponseSchema, StructuredOutputParser
from langchain_community.chat_models import ChatTongyi
from langchain_core.prompts import PromptTemplate
# from langchain_openai import ChatOpenAI
from minio import Minio
from tritonclient.utils import np_to_triton_dtype
from app.core.config import GI_MODEL_URL, MINIO_URL, MINIO_ACCESS, MINIO_SECRET, MINIO_SECURE, GI_MODEL_NAME
from app.schemas.brand_dna import GenerateBrandModel
from app.service.utils.generate_uuid import generate_uuid
from app.service.utils.new_oss_client import oss_upload_image
class GenerateBrandInfo:
    """Generate a brand name, slogan and logo image from a user prompt.

    Pipeline (see ``get_result``):
      1. ``llm_generate_brand_info`` asks the LLM (ChatTongyi) for a structured
         answer: brand_name / brand_slogan / brand_logo_prompt.
      2. ``generate_brand_logo`` renders the logo via the Triton image model
         and uploads it to OSS, adding ``brand_logo`` (the url) to the result.
    """

    def __init__(self, request_data):
        # OSS (minio) client used to upload the generated logo.
        self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
        # Requesting user; becomes part of the upload object name.
        self.user_id = request_data.user_id
        self.category = "brand_logo"
        # Triton gRPC client for the image-generation model.
        self.grpc_client = grpcclient.InferenceServerClient(url=GI_MODEL_URL)
        # Placeholder input: the txt2img mode still requires an image tensor.
        self.image = np.random.randint(0, 256, (1024, 1024, 3), dtype=np.uint8)
        self.batch_size = 1
        self.mode = 'txt2img'
        # LLM that extracts/invents the brand info.
        # SECURITY NOTE(review): API key is hard-coded — move it to config/env.
        self.model = ChatTongyi(model="qwen2.5-14b-instruct", api_key="sk-7658298c6b99443c98184a5e634fe6ab")
        self.response_schemas = [
            ResponseSchema(name="brand_name", description="Brand name."),
            ResponseSchema(name="brand_slogan", description="Brand slogan."),
            ResponseSchema(name="brand_logo_prompt", description="prompt required for brand logo generation.")
        ]
        self.output_parser = StructuredOutputParser.from_response_schemas(self.response_schemas)
        self.format_instructions = self.output_parser.get_format_instructions()
        self.prompt = PromptTemplate(
            template="你是一个时装品牌的设计师。根据用户输入提取出brand namebrand slogan,brand logo 描述。如果没有以上内容需要你根据用户输入随意发挥。随后根据brand logo 描述生成一个prompt这个prompt用于生成模型prompt需要完全表达用户的想法并使用英文使用简洁明了的单词不要过长。.\n{format_instructions}\n{question}",
            input_variables=["question"],
            partial_variables={"format_instructions": self.format_instructions}
        )
        self._input = self.prompt.format_prompt(question=request_data.prompt)
        self.result_data = {}

    def get_result(self):
        """Run both steps and return the combined dict (brand info + logo url)."""
        self.llm_generate_brand_info()
        self.generate_brand_logo()
        return self.result_data

    def llm_generate_brand_info(self):
        """Query the LLM and parse its structured output into ``self.result_data``."""
        output = self.model(self._input.to_messages())
        brand_data = self.output_parser.parse(output.content)
        self.result_data = brand_data
        # Remember the logo prompt for the image-generation step.
        self.generate_logo_prompt = brand_data['brand_logo_prompt']

    def generate_brand_logo(self):
        """Render the logo on Triton and store its OSS url in ``result_data['brand_logo']``."""
        prompts = [self.generate_logo_prompt] * self.batch_size
        modes = [self.mode] * self.batch_size
        images = [self.image.astype(np.float16)] * self.batch_size
        text_obj = np.array(prompts, dtype="object").reshape((-1, 1))
        mode_obj = np.array(modes, dtype="object").reshape((-1, 1))
        image_obj = np.array(images, dtype=np.float16).reshape((-1, 1024, 1024, 3))
        input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype))
        input_image = grpcclient.InferInput("input_image", image_obj.shape, np_to_triton_dtype(image_obj.dtype))
        input_mode = grpcclient.InferInput("mode", mode_obj.shape, np_to_triton_dtype(mode_obj.dtype))
        input_text.set_data_from_numpy(text_obj)
        input_image.set_data_from_numpy(image_obj)
        input_mode.set_data_from_numpy(mode_obj)
        inputs = [input_text, input_image, input_mode]
        result = self.grpc_client.infer(model_name=GI_MODEL_NAME, inputs=inputs)
        image = result.as_numpy("generated_image")
        # Model output is RGB; convert to BGR for cv2 encoding.
        image_result = cv2.cvtColor(np.squeeze(image.astype(np.uint8)), cv2.COLOR_RGB2BGR)
        logo_url = self.upload_logo_image(image_result, generate_uuid())
        self.result_data['brand_logo'] = logo_url

    def upload_logo_image(self, image, object_name):
        """Upload the logo (JPEG) to OSS and return its url, or None on failure.

        Best-effort by design: failures are logged and surfaced as a None url.
        """
        try:
            _, img_byte_array = cv2.imencode('.jpg', image)
            object_name = f'{self.user_id}/{self.category}/{object_name}'
            oss_upload_image(oss_client=self.minio_client, bucket="aida-users", object_name=object_name, image_bytes=img_byte_array)
            return f"aida-users/{object_name}"
        except Exception as e:
            # Fixed: the log previously referenced "upload_png_mask" (copy/paste error).
            logging.warning(f"upload_logo_image runtime exception : {e}")
            return None
if __name__ == '__main__':
    # Manual smoke test: run the full brand-generation pipeline for one prompt.
    demo_request = GenerateBrandModel(
        user_id="89",
        prompt="华为"
    )
    print(GenerateBrandInfo(demo_request).get_result())

View File

@@ -0,0 +1,32 @@
# NOTE(review): experimental/scratch script — it issues an LLM request at module
# import time (no __main__ guard). Do not import it from production code.
from dotenv import load_dotenv
from langchain.output_parsers import StructuredOutputParser, ResponseSchema
from langchain_core.prompts import PromptTemplate
from langchain_openai import ChatOpenAI

# Load environment variables from the .env file (API keys, endpoints, ...).
load_dotenv()

# Create the chat LLM; the model name selects the backend deployment.
model = ChatOpenAI(model="qwen2.5-14b-instruct")

# Structured response schema we ask the LLM to follow.
response_schemas = [
    ResponseSchema(name="brand_name", description="Brand name."),
    ResponseSchema(name="brand_slogan", description="Brand slogan."),
    ResponseSchema(name="brand_logo_prompt", description="prompt required for brand logo generation.")
]
output_parser = StructuredOutputParser.from_response_schemas(response_schemas)
format_instructions = output_parser.get_format_instructions()
prompt = PromptTemplate(
    template="你是一个时装品牌的设计师。根据用户输入提取出brand namebrand slogan,brand logo 描述。如果没有以上内容需要你根据用户输入随意发挥。随后根据brand logo 描述生成一个prompt这个prompt用于生成模型.\n{format_instructions}\n{question}",
    input_variables=["question"],
    partial_variables={"format_instructions": format_instructions}
)
# Runs at import: format the sample question, call the model, parse the answer.
_input = prompt.format_prompt(question="brand name: cat home")
output = model(_input.to_messages())
brand_data = output_parser.parse(output.content)


def generate_logo(bucket_name, object_name, prompt):
    # Unimplemented placeholder for logo rendering/upload.
    pass

View File

@@ -90,7 +90,6 @@ def chat(post_data):
user_id = post_data.user_id
session_id = post_data.session_id
input_message = post_data.message
gender = post_data.gender
# final_outputs = agent_executor(
# {"input": input_message, "gender": gender},
@@ -98,7 +97,7 @@ def chat(post_data):
# session_key=f"buffer:{user_id}:{session_id}",
# )
final_outputs = CallQWen.call_with_messages(input_message, gender)
final_outputs = CallQWen.call_with_messages(input_message)
# api_response = {
# 'user_id': user_id,
# 'session_id': session_id,

View File

@@ -34,6 +34,39 @@ You may encounter the following types of questions:
Be careful to use the tools, since you are actually a chat bot. Tools can only be used when essential.
"""
FASHION_CHAT_BOT_PREFIX_TEMP = """
You are a fashion design assistant with the following capabilities:
1. Direct conversation: Answer general questions (e.g., greetings, opinions).
2. Tool usage:
- `get_image_from_vector_db`: Retrieve clothing items (requires gender parameter).
- `internet_search`: Fetch real-time fashion trends.
- `tutorial_tool`: Provide styling guides.
Key Rules:
1. Tool Selection:
- Use `get_image_from_vector_db` for clothing queries (e.g., "show men's jackets").
- Use `internet_search` for time-sensitive queries (e.g., "2024 Paris Fashion Week trends").
- Use `tutorial_tool` for educational requests (e.g., "how to layer outfits").
2. Gender Handling (for `get_image_from_vector_db` only):
- Step 1: Check the **current user input** for gender keywords (e.g., "women/men/she/he"). If found, extract and pass as `gender`.
- Step 2: If no gender in current input, scan the **chat history** for the most recent gender reference.
- Step 3: If undetermined, default to `"unisex"`.
3. Output Format:
- Direct replies: Keep responses under 20 words.
- Tool calls:
- Always include required parameters (e.g., `gender` for `get_image_from_vector_db`).
- Auto-fill `gender` using the above rules if unspecified.
Examples:
1. User: "Find red dresses for women"
→ `get_image_from_vector_db(gender="female", query="dress")`
2. User: "show men's jackets"
→ `get_image_from_vector_db(gender="male", query="outwear")`
3. User: "Show casual outfits"
→ `get_image_from_vector_db(gender="unisex", query="casual outfits")`"""
TOOL_SELECT_SUFFIX = """
Prior to proceeding, it is essential to carefully assess the question and select the appropriate tools or approach accordingly.
For database-related questions, use SQL tools to identify relevant tables and query their schemas.

View File

@@ -9,7 +9,7 @@ from app.core.config import *
from app.service.chat_robot.script.callbacks.qwen_callback_handler import QWenCallbackHandler
from app.service.chat_robot.script.database import CustomDatabase
from app.service.chat_robot.script.prompt import FASHION_CHAT_BOT_PREFIX, TOOLS_FUNCTIONS_SUFFIX, TUTORIAL_TOOL_RETURN, \
GET_LANGUAGE_PREFIX
GET_LANGUAGE_PREFIX, FASHION_CHAT_BOT_PREFIX_TEMP
from app.service.search_image_with_text.service import query
get_database_table_description = "Input is an empty string, output is a comma separated list of tables in the database."
@@ -212,14 +212,15 @@ def get_assistant_response(messages):
return response
def call_with_messages(message, gender):
def call_with_messages(message):
global tool_info
user_input = message
print('\n')
messages = [
{
"content": FASHION_CHAT_BOT_PREFIX, # 系统message
# "content": FASHION_CHAT_BOT_PREFIX, # 系统message
"content": FASHION_CHAT_BOT_PREFIX_TEMP, # 修改后的系统message
"role": "system"
},
{
@@ -255,7 +256,7 @@ def call_with_messages(message, gender):
tool_info = {"name": "search_from_internet", "role": "tool"}
content = json.loads(assistant_output.tool_calls[0]['function']['arguments'])
message = [
{'role': 'assistant', 'content': content['query']}
{'role': 'assistant', 'content': content['query'] if "query" in content.keys() else user_input}
]
tool_info['content'] = search_from_internet(message)
flag = False
@@ -282,6 +283,8 @@ def call_with_messages(message, gender):
result_content = tool_info['content']
elif assistant_output.tool_calls[0]['function']['name'] == 'get_image_from_vector_db':
content = json.loads(assistant_output.tool_calls[0]['function']['arguments'])
# todo 从历史对话中获取性别目前无法获得性别时默认使用female
gender = content['gender'] if "gender" in content.keys() and content['gender'] != 'unisex' else 'female'
tool_info = {"name": "get_image_from_vector_db", "role": "tool",
'content': get_image_from_vector_db(gender, content['parameters']['content'] if "parameters" in content.keys() else content['content'])}
flag = False

View File

@@ -0,0 +1,161 @@
import io
import time
from pprint import pprint
import cv2
import numpy as np
import tritonclient.grpc as grpcclient
from PIL import Image
from minio import Minio
from tritonclient.utils import np_to_triton_dtype
from app.core.config import *
from app.schemas.clothing_seg import ClothingSegModel
from app.service.design_fast.utils.design_ensemble import get_seg_result
from app.service.utils.decorator import RunTime
from app.service.utils.generate_uuid import generate_uuid
from app.service.utils.new_oss_client import oss_get_image, oss_upload_image
minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
class ClothingSeg:
    """Segment individual clothing items out of user images.

    For each entry in ``image_data`` the service downloads the image from OSS,
    extracts transparent-background clothing crops (a dedicated path for
    "sketch" images, otherwise the Triton "seg_clothing" model), uploads the
    crops back to OSS and returns the metadata with a "clothing_url" list.
    """

    def __init__(self, request_data):
        # List of dicts, each carrying "image_url" and "image_type".
        self.image_data = request_data.image_data
        self.user_id = request_data.user_id
        # Triton gRPC client for the photo segmentation model.
        self.triton_client = grpcclient.InferenceServerClient(url="10.1.1.243:10071")

    @RunTime
    def get_result(self):
        """Run the full pipeline; returns image_data with urls, minus raw image objects."""
        self.read_image()
        self.clothing_seg()
        self.upload_image()
        # Drop in-memory images/crops so the response stays JSON-serializable.
        for data in self.image_data:
            del data["image"]
            del data["clothing"]
        return self.image_data

    @RunTime
    def upload_image(self):
        """Upload every extracted crop as a PNG and record its OSS url per image."""
        for data in self.image_data:
            data["clothing_url"] = []
            for clothing in data["clothing"]:
                object_name = f"{self.user_id}/clothing_seg/{generate_uuid()}.png"
                image_data = io.BytesIO()
                clothing.save(image_data, format="PNG")
                image_data.seek(0)
                image_bytes = image_data.read()
                oss_upload_image(oss_client=minio_client, bucket="aida-users", object_name=object_name, image_bytes=image_bytes)
                data["clothing_url"].append(f"aida-users/{object_name}")

    @RunTime
    def read_image(self):
        """Download each source image from OSS (as a cv2 array) into data["image"]."""
        for data in self.image_data:
            url = data["image_url"]
            # url format is "<bucket>/<object_name>".
            image = oss_get_image(oss_client=minio_client, bucket=url.split("/", 1)[0], object_name=url.split("/", 1)[1], data_type="cv2")
            data["image"] = image

    @RunTime
    def clothing_seg(self):
        """Extract transparent clothing crops into data["clothing"] for every image."""
        for data in self.image_data:
            image_type = data["image_type"]
            image = data["image"]
            clothing_result = []
            if image_type == "sketch":
                # Sketch path: local segmentation helper; single crop per image.
                if len(image.shape) == 2:
                    # Grayscale sketch — promote to 3 channels first.
                    image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
                    seg_mask = get_seg_result(1, image[:, :, :3])
                else:
                    seg_mask = get_seg_result(1, image[:, :, :3])
                # Binarize: any non-zero class becomes 255.
                temp = seg_mask != 0.0
                mask = (255 * (temp + 0).astype(np.uint8))
                x_min, y_min, x_max, y_max = get_bounding_box(mask)
                cropped_mask = mask[y_min:y_max + 1, x_min:x_max + 1]
                cropped_image = image[y_min:y_max + 1, x_min:x_max + 1]
                h, w = cropped_image.shape[:2]
                # Paste the crop onto a fully transparent canvas using the mask as alpha.
                mask_pil = Image.fromarray(cropped_mask).convert("L")
                image_pil = Image.fromarray(cv2.cvtColor(cropped_image, cv2.COLOR_BGR2RGB))
                transparent_image = Image.new("RGBA", (w, h), (0, 0, 0, 0))
                transparent_image.paste(image_pil, (0, 0), mask=mask_pil)
                clothing_result.append(transparent_image)
            else:
                # Photo path: run the Triton "seg_clothing" model; one crop per
                # alpha map it returns.
                input_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
                input0_data = [input_image.astype(np.float32)] * 1
                input0_data = np.array(input0_data, dtype=np.float32)
                inputs = [
                    grpcclient.InferInput(
                        "INPUT0", input0_data.shape, np_to_triton_dtype(input0_data.dtype)
                    ),
                ]
                inputs[0].set_data_from_numpy(input0_data)
                outputs = [
                    # grpcclient.InferRequestedOutput("OUTPUT0"),
                    grpcclient.InferRequestedOutput("OUTPUT1"),
                ]
                response = self.triton_client.infer("seg_clothing", inputs, request_id=str(1), outputs=outputs)
                # output0_data = response.as_numpy("OUTPUT0")
                # cv2.imwrite("output02.png", output0_data * 100)
                output1_data = response.as_numpy("OUTPUT1")
                for alpha in output1_data:
                    # Alpha maps come back at model resolution — rescale to the source image.
                    alpha = cv2.resize(alpha, (image.shape[1], image.shape[0]), interpolation=cv2.INTER_CUBIC)
                    x_min, y_min, x_max, y_max = get_bounding_box(alpha)
                    cropped_mask = alpha[y_min:y_max + 1, x_min:x_max + 1]
                    cropped_image = image[y_min:y_max + 1, x_min:x_max + 1]
                    h, w = cropped_image.shape[:2]
                    mask_pil = Image.fromarray(cropped_mask).convert("L")
                    image_pil = Image.fromarray(cv2.cvtColor(cropped_image, cv2.COLOR_BGR2RGB))
                    transparent_image = Image.new("RGBA", (w, h), (0, 0, 0, 0))
                    transparent_image.paste(image_pil, (0, 0), mask=mask_pil)
                    clothing_result.append(transparent_image)
            data["clothing"] = clothing_result
@RunTime
def get_bounding_box(mask):
    """Compute the bounding box of the non-zero region of a 2-D mask.

    :param mask: 2-D numpy array; any non-zero pixel counts as foreground
    :return: (x_min, y_min, x_max, y_max); all zeros when the mask is empty
    """
    ys, xs = np.nonzero(mask)
    if xs.size == 0:
        # No foreground at all — degenerate all-zero box.
        return 0, 0, 0, 0
    return xs.min(), ys.min(), xs.max(), ys.max()
if __name__ == "__main__":
test_data = ClothingSegModel(
user_id=89,
image_data=[
# {
# "image_url": "test/clothing_seg/dress.jpg",
# "image_type": "sketch"
# },
# {
# "image_url": "test/clothing_seg/skirt_559.jpg",
# "image_type": "sketch"
# },
{
"image_url": "aida-collection-element/87/Sketchboard/ab40e035-547a-48c5-9f97-1db7bf56ad77.jpg",
"image_type": "sketch"
}
]
)
start_time = time.time()
server = ClothingSeg(test_data)
pprint(server.get_result())
print(time.time() - start_time)

View File

@@ -5,9 +5,9 @@ from celery import Celery
from minio import Minio
from app.core.config import *
from app.service.design_batch.item import BodyItem, TopItem, BottomItem
from app.service.design_batch.item import BodyItem, TopItem, BottomItem, AccessoriesItem
from app.service.design_batch.utils.MQ import publish_status
from app.service.design_batch.utils.organize import organize_body, organize_clothing
from app.service.design_batch.utils.organize import organize_body, organize_clothing, organize_accessories
from app.service.design_batch.utils.save_json import oss_upload_json
from app.service.design_batch.utils.synthesis_item import update_base_size_priority, synthesis, synthesis_single
@@ -19,6 +19,8 @@ logging.getLogger('pika').setLevel(logging.WARNING)
logger = logging.getLogger()
minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
print("start")
def process_item(item, basic):
# 处理project中单个item
@@ -28,9 +30,14 @@ def process_item(item, basic):
elif item['type'].lower() in ['blouse', 'outwear', 'dress', 'tops']:
top_server = TopItem(data=item, basic=basic, minio_client=minio_client)
item_data = top_server.process()
else:
elif item['type'].lower() in ['skirt', 'trousers', 'bottoms']:
bottom_server = BottomItem(data=item, basic=basic, minio_client=minio_client)
item_data = bottom_server.process()
elif item['type'].lower() in ['accessories']:
bottom_server = AccessoriesItem(data=item, basic=basic, minio_client=minio_client)
item_data = bottom_server.process()
else:
raise NotImplementedError(f"Item type {item['type']} not implemented")
return item_data
@@ -40,6 +47,10 @@ def process_layer(item, layers):
body_layer = organize_body(item)
layers.append(body_layer)
return item['body_image'].size
elif item['name'] == 'accessories':
front_layer, back_layer = organize_accessories(item)
layers.append(front_layer)
layers.append(back_layer)
else:
front_layer, back_layer = organize_clothing(item)
layers.append(front_layer)
@@ -48,6 +59,9 @@ def process_layer(item, layers):
@celery_app.task
def batch_design(objects_data, tasks_id, json_name):
print(objects_data)
print(tasks_id)
print(json_name)
object_response = []
threads = []
active_threads = 0
@@ -71,7 +85,7 @@ def batch_design(objects_data, tasks_id, json_name):
for lay in layers:
items_response['layers'].append({
'image_category': lay['name'],
'image_category': "body" if lay['name'] == 'mannequin' else lay['name'],
'position': lay['position'],
'priority': lay.get("priority", None),
'resize_scale': lay['resize_scale'] if "resize_scale" in lay.keys() else None,
@@ -121,6 +135,7 @@ def batch_design(objects_data, tasks_id, json_name):
for t in threads:
t.join()
logger.debug(object_response)
print(object_response)
oss_upload_json(minio_client, object_response, json_name)
publish_status(tasks_id, "ok", json_name)
return object_response

View File

@@ -1,4 +1,4 @@
from app.service.design_batch.pipeline import *
from app.service.design_fast.pipeline import LoadImage, KeyPoint, Segmentation, Color, PrintPainting, Scaling, Split, LoadBodyImage, ContourDetection
class BaseItem:
@@ -9,6 +9,27 @@ class BaseItem:
self.result.update(basic)
class AccessoriesItem(BaseItem):
def __init__(self, data, basic, minio_client):
super().__init__(data, basic)
self.Accessories_pipeline = [
LoadImage(minio_client),
# KeyPoint(),
ContourDetection(),
# Segmentation(minio_client),
# BackPerspective(minio_client),
Color(minio_client),
PrintPainting(minio_client),
Scaling(),
Split(minio_client)
]
def process(self):
for item in self.Accessories_pipeline:
self.result = item(self.result)
return self.result
class TopItem(BaseItem):
def __init__(self, data, basic, minio_client):
super().__init__(data, basic)
@@ -16,6 +37,7 @@ class TopItem(BaseItem):
LoadImage(minio_client),
KeyPoint(),
Segmentation(minio_client),
# BackPerspective(minio_client),
Color(minio_client),
PrintPainting(minio_client),
Scaling(),
@@ -35,7 +57,8 @@ class BottomItem(BaseItem):
LoadImage(minio_client),
KeyPoint(),
ContourDetection(),
# Segmentation(),
Segmentation(minio_client),
# BackPerspective(minio_client),
Color(minio_client),
PrintPainting(minio_client),
Scaling(),

View File

@@ -1,3 +1,4 @@
from .back_perspective import BackPerspective
from .color import Color
from .contour_detection import ContourDetection
from .keypoint import KeyPoint
@@ -13,6 +14,7 @@ __all__ = [
'KeyPoint',
'ContourDetection',
'Segmentation',
'BackPerspective',
'Color',
'PrintPainting',
'Scaling',

View File

@@ -0,0 +1,79 @@
import cv2
import numpy as np
from app.service.design_fast.utils.design_ensemble import get_seg_result
from app.service.utils.new_oss_client import oss_upload_image
class BackPerspective:
    """Pipeline stage that derives a "back view" sketch for an item.

    For system images (``aida-sys-image`` bucket) a pre-rendered back image is
    used when it exists; otherwise the garment is segmented, a band along its
    external contour is painted white on a copy of the sketch, and the result
    is uploaded to OSS. Sets ``result['back_perspective_url']``.
    """

    def __init__(self, minio_client):
        self.minio_client = minio_client

    def __call__(self, result):
        # If the sketch is a system image, first look for a pre-rendered
        # back-view counterpart ("images" -> "images_back").
        if result['path'].split('/')[0] == 'aida-sys-image':
            file_path = result['path'].replace("images", 'images_back', 1)
            if self.is_file_exists(bucket_name='aida-sys-image', file_name=file_path[file_path.find('/') + 1:]):
                result['back_perspective_url'] = file_path
                return result
            else:
                seg_result = get_seg_result("1", result['image'])[0]
        elif result['name'] in ['blouse', 'outwear', 'dress', 'tops']:
            # Tops were already segmented earlier in the pipeline.
            seg_result = result['seg_result']
        else:
            seg_result = get_seg_result("1", result['image'])[0]
        # Whiten a 10px-wide band around the garment contour.
        m = self.thicken_contours_and_display(seg_result, thickness=10, color=(0, 0, 0))
        back_sketch = result['image'].copy()
        back_sketch[m > 100] = 255
        # Upload the back-view sketch to OSS.
        _, img_encoded = cv2.imencode(".jpg", back_sketch)
        resp = oss_upload_image(self.minio_client, bucket='test', object_name=result['path'], image_bytes=img_encoded.tobytes())
        result['back_perspective_url'] = f"{resp.bucket_name}/{resp.object_name}"
        return result

    def thicken_contours_and_display(self, mask, thickness=10, color=(0, 0, 0)):
        """Return a binary mask of a band of width ``thickness`` along each external contour.

        :param mask: 0/1 segmentation mask
        :param thickness: band width in pixels
        :param color: BGR color drawn onto the working copy (band pixels)
        :return: single-channel 0/255 mask of the thickened contour band
        """
        mask = mask.astype(np.uint8) * 255
        # External contours of the garment region.
        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        # Color working copy used to accumulate the band pixels.
        mask_color = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR)

        def thicken_contour_inward(contour, thick):
            # Draw this contour (already widened to ``thick``) on a blank canvas.
            blank = np.zeros_like(mask)
            cv2.drawContours(blank, [contour], -1, 255, thickness=thick)
            # Distance from every pixel to the drawn band; pixels closer than
            # ``thick`` are considered part of the thickened contour.
            dist_transform = cv2.distanceTransform(255 - blank, cv2.DIST_L2, 5)
            # Fixed: replaces an O(H*W) Python per-pixel loop with a vectorized
            # comparison, and drops an unused cv2.moments() centroid computation
            # that raised ZeroDivisionError for degenerate (zero-area) contours.
            return np.where(dist_transform < thick, 255, 0).astype(mask.dtype)

        for contour in contours:
            thickened_contour = thicken_contour_inward(contour, thickness)
            mask_color[thickened_contour > 0] = color
        _, binary_result = cv2.threshold(mask_color, 127, 255, cv2.THRESH_BINARY)
        # Collapse back to a single-channel mask.
        mask_result = cv2.cvtColor(binary_result, cv2.COLOR_BGR2GRAY)
        return mask_result

    def is_file_exists(self, bucket_name, file_name):
        """Return True when the object exists in the bucket (stat succeeds)."""
        try:
            self.minio_client.stat_object(bucket_name, file_name)
            return True
        except Exception:
            return False

View File

@@ -14,14 +14,39 @@ class Color:
def __call__(self, result):
dim_image_h, dim_image_w = result['image'].shape[0:2]
# 渐变色
if "gradient" in result.keys() and result['gradient'] != "":
bucket_name = result['gradient'].split('/')[0]
object_name = result['gradient'][result['gradient'].find('/') + 1:]
pattern = self.get_gradient(bucket_name=bucket_name, object_name=object_name)
resize_pattern = cv2.resize(pattern, (dim_image_w, dim_image_h), interpolation=cv2.INTER_AREA)
# 无色
elif "color" not in result.keys() or result['color'] == "":
result['final_image'] = result['pattern_image'] = result['single_image'] = result['image']
result['alpha'] = 100 / 255.0
return result
# 正常颜色
else:
pattern = self.get_pattern(result['color'])
resize_pattern = cv2.resize(pattern, (dim_image_w, dim_image_h), interpolation=cv2.INTER_AREA)
if "partial_color" in result.keys() and result['partial_color'] != "":
bucket_name = result['partial_color'].split('/')[0]
object_name = result['partial_color'][result['partial_color'].find('/') + 1:]
partial_color = oss_get_image(oss_client=self.minio_client, bucket=bucket_name, object_name=object_name, data_type="cv2")
h, w = partial_color.shape[0:2]
resize_pattern = cv2.resize(resize_pattern, (w, h), interpolation=cv2.INTER_AREA)
# 分离出 png 图的 alpha 通道
alpha_channel = partial_color[:, :, 3]
# 提取 png 图的 RGB 通道
png_rgb = partial_color[:, :, :3]
# 创建一个与 cv 图大小相同的掩码,用于指示哪些像素需要替换
mask = alpha_channel > 0
# 将掩码扩展为 3 通道,以便与 cv 图进行逐元素操作
mask_3ch = np.stack([mask] * 3, axis=-1)
# 根据掩码将 png 图的颜色覆盖到 cv 图上
resize_pattern[mask_3ch] = png_rgb[mask_3ch]
resize_pattern = cv2.resize(resize_pattern, (dim_image_w, dim_image_h), interpolation=cv2.INTER_AREA)
closed_mo = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
gray_mo = np.expand_dims(result['gray'], axis=2).repeat(3, axis=2)
get_image_fir = resize_pattern * (closed_mo / 255) * (gray_mo / 255)

View File

@@ -4,7 +4,8 @@ import numpy as np
from pymilvus import MilvusClient
from app.core.config import *
from app.service.design_batch.utils.design_ensemble import get_keypoint_result
from app.service.design_fast.utils.design_ensemble import get_keypoint_result
from app.service.utils.decorator import ClassCallRunTime, RunTime
logger = logging.getLogger(__name__)
@@ -16,14 +17,15 @@ class KeyPoint:
def get_name(cls):
return cls.name
@ClassCallRunTime
def __call__(self, result):
if result['name'] in ['blouse', 'skirt', 'dress', 'outwear', 'trousers', 'tops', 'bottoms']: # 查询是否有数据 且类别相同 相同则直接读 不同则推理后更新
# result['clothes_keypoint'] = self.infer_keypoint_result(result)
site = 'up' if result['name'] in ['blouse', 'outwear', 'dress', 'tops'] else 'down'
# keypoint_cache = search_keypoint_cache(result["image_id"], site)
keypoint_cache = self.keypoint_cache(result, site)
# keypoint_cache = self.keypoint_cache(result, site)
keypoint_cache = False
# 取消向量查询 直接过模型推理
# keypoint_cache = False
if keypoint_cache is False:
keypoint_infer_result, site = self.infer_keypoint_result(result)
result['clothes_keypoint'] = self.save_keypoint_cache(result["image_id"], keypoint_infer_result, site)
@@ -87,7 +89,7 @@ class KeyPoint:
logger.info(f"save keypoint cache milvus error : {e}")
return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
# @ RunTime
@RunTime
def keypoint_cache(self, result, site):
try:
client = MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS)

View File

@@ -1,6 +1,9 @@
import io
import logging
import cv2
import numpy as np
from PIL import Image
from app.service.utils.new_oss_client import oss_get_image
@@ -71,6 +74,8 @@ class LoadImage:
keypoint = 'head_point'
elif name == 'earring':
keypoint = 'ear_point'
elif name == 'accessories':
keypoint = "accessories"
else:
raise KeyError(f"{name} does not belong to item category list: blouse, outwear, dress, trousers, skirt, "
f"bag, shoes, hairstyle, earring.")

View File

@@ -15,8 +15,25 @@ class PrintPainting:
single_print = result['print']['single']
overall_print = result['print']['overall']
element_print = result['print']['element']
partial_path = result['print']['partial'] if 'partial' in result['print'] else None
result['single_image'] = None
result['print_image'] = None
# TODO 给result['pattern_image'] resize 到resize_scale的大小
# TODO 给result['mask'] resize 到resize_scale的大小
if result['resize_scale'][0] == 1.0 and result['resize_scale'][1] == 1.0:
pass
else:
height, width = result['pattern_image'].shape[:2]
new_width = int(width * result['resize_scale'][0])
new_height = int(height * result['resize_scale'][1])
result['pattern_image'] = cv2.resize(result['pattern_image'], (new_width, new_height))
result['final_image'] = cv2.resize(result['final_image'], (new_width, new_height))
result['mask'] = cv2.resize(result['mask'], (new_width, new_height))
result['gray'] = cv2.resize(result['gray'], (new_width, new_height))
print(1)
if overall_print['print_path_list']:
painting_dict = {'dim_image_h': result['pattern_image'].shape[0], 'dim_image_w': result['pattern_image'].shape[1]}
result['print_image'] = result['pattern_image']
@@ -39,7 +56,7 @@ class PrintPainting:
for i in range(len(single_print['print_path_list'])):
image, image_mode = self.read_image(single_print['print_path_list'][i])
if image_mode == "RGBA":
new_size = (int(image.width * single_print['print_scale_list'][i]), int(image.height * single_print['print_scale_list'][i]))
new_size = (int(result['pattern_image'].shape[1] * single_print['print_scale_list'][i][0]), int(result['pattern_image'].shape[0] * single_print['print_scale_list'][i][1]))
mask = image.split()[3]
resized_source = image.resize(new_size)
@@ -62,9 +79,12 @@ class PrintPainting:
mask = np.expand_dims(mask, axis=2)
mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR)
mask = cv2.bitwise_not(mask)
mask = cv2.resize(mask, (int(result['final_image'].shape[1] * single_print['print_scale_list'][i][0]), int(result['final_image'].shape[0] * single_print['print_scale_list'][i][1])))
image = cv2.resize(image, (int(result['final_image'].shape[1] * single_print['print_scale_list'][i][0]), int(result['final_image'].shape[0] * single_print['print_scale_list'][i][1])))
# 旋转后的坐标需要重新算
rotate_mask, _ = self.img_rotate(mask, single_print['print_angle_list'][i], single_print['print_scale_list'][i])
rotate_image, rotated_new_size = self.img_rotate(image, single_print['print_angle_list'][i], single_print['print_scale_list'][i])
rotate_mask, _ = self.img_rotate(mask, single_print['print_angle_list'][i])
rotate_image, rotated_new_size = self.img_rotate(image, single_print['print_angle_list'][i])
# x, y = int(result['print']['location'][i][0] - rotated_new_size[0] - (rotate_mask.shape[0] - image.shape[0]) / 2), int(result['print']['location'][i][1] - rotated_new_size[1] - (rotate_mask.shape[1] - image.shape[1]) / 2)
x, y = int(single_print['location'][i][0] - rotated_new_size[0]), int(single_print['location'][i][1] - rotated_new_size[1])
@@ -143,7 +163,7 @@ class PrintPainting:
for i in range(len(element_print['element_path_list'])):
image, image_mode = self.read_image(element_print['element_path_list'][i])
if image_mode == "RGBA":
new_size = (int(image.width * element_print['element_scale_list'][i]), int(image.height * element_print['element_scale_list'][i]))
new_size = (int(result['final_image'].shape[1] * element_print['element_scale_list'][i][0]), int(result['final_image'].shape[0] * element_print['element_scale_list'][i][1]))
mask = image.split()[3]
resized_source = image.resize(new_size)
@@ -165,9 +185,11 @@ class PrintPainting:
mask = np.expand_dims(mask, axis=2)
mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR)
mask = cv2.bitwise_not(mask)
mask = cv2.resize(mask, (int(result['final_image'].shape[1] * single_print['print_scale_list'][i][0]), int(result['final_image'].shape[0] * single_print['print_scale_list'][i][1])))
image = cv2.resize(image, (int(result['final_image'].shape[1] * single_print['print_scale_list'][i][0]), int(result['final_image'].shape[0] * single_print['print_scale_list'][i][1])))
# 旋转后的坐标需要重新算
rotate_mask, _ = self.img_rotate(mask, element_print['element_angle_list'][i], element_print['element_scale_list'][i])
rotate_image, rotated_new_size = self.img_rotate(image, element_print['element_angle_list'][i], element_print['element_scale_list'][i])
rotate_mask, _ = self.img_rotate(mask, element_print['element_angle_list'][i])
rotate_image, rotated_new_size = self.img_rotate(image, element_print['element_angle_list'][i])
# x, y = int(result['print']['location'][i][0] - rotated_new_size[0] - (rotate_mask.shape[0] - image.shape[0]) / 2), int(result['print']['location'][i][1] - rotated_new_size[1] - (rotate_mask.shape[1] - image.shape[1]) / 2)
x, y = int(element_print['location'][i][0] - rotated_new_size[0]), int(element_print['location'][i][1] - rotated_new_size[1])
@@ -241,6 +263,45 @@ class PrintPainting:
temp_fg = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
tmp2 = (result['final_image'] * (temp_fg / 255)).astype(np.uint8)
result['single_image'] = cv2.add(tmp1, tmp2)
if partial_path:
print_background = np.zeros((result['pattern_image'].shape[0], result['pattern_image'].shape[1], 3), dtype=np.uint8)
mask_background = np.zeros((result['pattern_image'].shape[0], result['pattern_image'].shape[1], 3), dtype=np.uint8)
image, image_mode = self.read_image(partial_path)
if image_mode == "RGBA":
new_size = (result['pattern_image'].shape[1], result['pattern_image'].shape[0])
mask = image.split()[3]
resized_source = image.resize(new_size)
resized_source_mask = mask.resize(new_size)
# rotated_resized_source = resized_source.rotate(-partial_print['print_angle_list'][i])
# rotated_resized_source_mask = resized_source_mask.rotate(-partial_print['print_angle_list'][i])
source_image_pil = Image.fromarray(cv2.cvtColor(print_background, cv2.COLOR_BGR2RGB))
source_image_pil_mask = Image.fromarray(cv2.cvtColor(mask_background, cv2.COLOR_BGR2RGB))
source_image_pil.paste(resized_source, (0, 0), resized_source)
source_image_pil_mask.paste(resized_source_mask, (0, 0), resized_source_mask)
print_background = cv2.cvtColor(np.array(source_image_pil), cv2.COLOR_RGBA2BGR)
mask_background = cv2.cvtColor(np.array(source_image_pil_mask), cv2.COLOR_RGBA2BGR)
ret, mask_background = cv2.threshold(mask_background, 124, 255, cv2.THRESH_BINARY)
print_mask = cv2.bitwise_and(result['mask'], cv2.cvtColor(mask_background, cv2.COLOR_BGR2GRAY))
img_fg = cv2.bitwise_or(print_background, print_background, mask=print_mask)
# TODO element 丢失信息
three_channel_image = cv2.merge([cv2.bitwise_not(print_mask), cv2.bitwise_not(print_mask), cv2.bitwise_not(print_mask)])
img_bg = cv2.bitwise_and(result['final_image'], three_channel_image)
# mask_mo = np.expand_dims(print_mask, axis=2).repeat(3, axis=2)
# gray_mo = np.expand_dims(result['gray'], axis=2).repeat(3, axis=2)
# img_fg = (img_fg * (mask_mo / 255) * (gray_mo / 255)).astype(np.uint8)
result['final_image'] = cv2.add(img_bg, img_fg)
canvas = np.full_like(result['final_image'], 255)
temp_bg = np.expand_dims(cv2.bitwise_not(result['mask']), axis=2).repeat(3, axis=2)
tmp1 = (canvas * (temp_bg / 255)).astype(np.uint8)
temp_fg = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
tmp2 = (result['final_image'] * (temp_fg / 255)).astype(np.uint8)
result['single_image'] = cv2.add(tmp1, tmp2)
return result
@staticmethod
@@ -360,10 +421,10 @@ class PrintPainting:
return print_image
def get_print(self, print_dict):
if 'print_scale_list' not in print_dict.keys() or print_dict['print_scale_list'][0] < 0.3:
if 'print_scale_list' not in print_dict.keys() or print_dict['print_scale_list'][0][0] < 0.3:
print_dict['scale'] = 0.3
else:
print_dict['scale'] = print_dict['print_scale_list'][0]
print_dict['scale'] = print_dict['print_scale_list'][0][0]
bucket_name = print_dict['print_path_list'][0].split("/", 1)[0]
object_name = print_dict['print_path_list'][0].split("/", 1)[1]
@@ -386,8 +447,9 @@ class PrintPainting:
# y_offset = random.randint(0, image.shape[1] - image_size_w)
# 1.拿到偏移量后和resize后的print宽高取余 得到真正偏移量
x_offset = print_w - int(location[0][1] % print_w)
y_offset = print_w - int(location[0][0] % print_h)
# 偏移量增加2分之print.w 使坐标位于图中间 如果要位于左上角删除+ print_w // 2 即可
x_offset = print_w - int(location[0][1] % print_w) + print_w // 2
y_offset = print_h - int(location[0][0] % print_h) + print_h // 2
# y_offset = int(location[0][0])
# x_offset = int(location[0][1])
@@ -409,7 +471,7 @@ class PrintPainting:
return high, low
@staticmethod
def img_rotate(image, angel, scale):
def img_rotate(image, angel):
"""顺时针旋转图像任意角度
Args:
@@ -424,7 +486,7 @@ class PrintPainting:
center = (w // 2, h // 2)
# if type(angel) is not int:
# angel = 0
M = cv2.getRotationMatrix2D(center, -angel, scale)
M = cv2.getRotationMatrix2D(center, -angel, 1)
# 调整旋转后的图像长宽
rotated_h = int((w * np.abs(M[0, 1]) + (h * np.abs(M[0, 0]))))
rotated_w = int((h * np.abs(M[0, 1]) + (w * np.abs(M[0, 0]))))
@@ -433,7 +495,7 @@ class PrintPainting:
# 旋转图像
rotated_img = cv2.warpAffine(image, M, (rotated_w, rotated_h))
return rotated_img, ((rotated_img.shape[1] - image.shape[1] * scale) // 2, (rotated_img.shape[0] - image.shape[0] * scale) // 2)
return rotated_img, ((rotated_img.shape[1] - image.shape[1]) // 2, (rotated_img.shape[0] - image.shape[0]) // 2)
# return rotated_img, (0, 0)
@staticmethod
@@ -442,8 +504,11 @@ class PrintPainting:
angle: 旋转的角度
crop: 是否需要进行裁剪,布尔向量
"""
if not isinstance(crop, bool):
raise ValueError("The 'crop' parameter must be a boolean.")
crop_image = lambda img, x0, y0, w, h: img[y0:y0 + h, x0:x0 + w]
w, h = img.shape[:2]
h, w = img.shape[:2]
# 旋转角度的周期是360°
angle %= 360
# 计算仿射变换矩阵
@@ -455,7 +520,7 @@ class PrintPainting:
if crop:
# 裁剪角度的等效周期是180°
angle_crop = angle % 180
if angle > 90:
if angle_crop > 90:
angle_crop = 180 - angle_crop
# 转化角度为弧度
theta = angle_crop * np.pi / 180

View File

@@ -46,4 +46,16 @@ class Scaling:
result['scale'] = result['scale_bag']
elif result['keypoint'] == 'ear_point':
result['scale'] = result['scale_earrings']
elif result['keypoint'] == 'accessories':
# 由于没有识别配饰keypoint的模型 所以统一将配饰的两个关键点设定为 (0,0) (0,img.width)
# 模特的关键点设定为(0,0) (0,320/2) 距离比例简写为 160 / img.width
distance_clo = result['img_shape'][1]
distance_bdy = 320 / 2
if distance_clo == 0:
result['scale'] = 1
else:
result['scale'] = distance_bdy / distance_clo
else:
result['scale'] = 1
return result

View File

@@ -5,7 +5,8 @@ import cv2
import numpy as np
from app.core.config import SEG_CACHE_PATH
from app.service.design_batch.utils.design_ensemble import get_seg_result
from app.service.design_fast.utils.design_ensemble import get_seg_result
from app.service.utils.decorator import ClassCallRunTime
from app.service.utils.new_oss_client import oss_get_image
logger = logging.getLogger()
@@ -15,6 +16,7 @@ class Segmentation:
def __init__(self, minio_client):
self.minio_client = minio_client
@ClassCallRunTime
def __call__(self, result):
if "seg_mask_url" in result.keys() and result['seg_mask_url'] != "":
seg_mask = oss_get_image(oss_client=self.minio_client, bucket=result['seg_mask_url'].split('/')[0], object_name=result['seg_mask_url'][result['seg_mask_url'].find('/') + 1:], data_type="cv2")
@@ -31,24 +33,37 @@ class Segmentation:
result['back_mask'] = np.array(green_mask, dtype=np.uint8) * 255
result['mask'] = result['front_mask'] + result['back_mask']
else:
# 本地查询seg 缓存是否存在
_, seg_result = self.load_seg_result(result["image_id"])
result['seg_result'] = seg_result
if not _:
# preview 过模型 不缓存
if "preview_submit" in result.keys() and result['preview_submit'] == "preview":
# 推理获得seg 结果
seg_result = get_seg_result(result["image_id"], result['image'])[0]
seg_result = get_seg_result(result["image_id"], result['image'])
# submit 过模型 缓存
elif "preview_submit" in result.keys() and result['preview_submit'] == "submit":
# 推理获得seg 结果
seg_result = get_seg_result(result["image_id"], result['image'])
self.save_seg_result(seg_result, result['image_id'])
# null 正常流程 加载本地缓存 无缓存则过模型
else:
# 本地查询seg 缓存是否存在
_, seg_result = self.load_seg_result(result["image_id"])
# 判断缓存和实际图片size是否相同
if not _ or result["image"].shape[:2] != seg_result.shape:
# 推理获得seg 结果
seg_result = get_seg_result(result["image_id"], result['image'])
self.save_seg_result(seg_result, result['image_id'])
result['seg_result'] = seg_result
# 处理前片后片
temp_front = seg_result == 1.0
temp_front = seg_result == 1
result['front_mask'] = (255 * (temp_front + 0).astype(np.uint8))
temp_back = seg_result == 2.0
temp_back = seg_result == 2
result['back_mask'] = (255 * (temp_back + 0).astype(np.uint8))
result['mask'] = result['front_mask'] + result['back_mask']
return result
@staticmethod
def save_seg_result(seg_result, image_id):
file_path = f"seg_cache/{image_id}.npy"
file_path = f"{SEG_CACHE_PATH}{image_id}.npy"
try:
np.save(file_path, seg_result)
logger.debug(f"保存成功 {os.path.abspath(file_path)}")
@@ -57,7 +72,7 @@ class Segmentation:
@staticmethod
def load_seg_result(image_id):
file_path = f"seg_cache/{image_id}.npy"
file_path = f"{SEG_CACHE_PATH}{image_id}.npy"
# logger.info(f"load seg file name is :{SEG_CACHE_PATH}{image_id}.npy")
try:
seg_result = np.load(file_path)

View File

@@ -7,10 +7,11 @@ from PIL import Image
from cv2 import cvtColor, COLOR_BGR2RGBA
from app.core.config import AIDA_CLOTHING
from app.service.design_batch.utils.conversion_image import rgb_to_rgba
from app.service.design_batch.utils.upload_image import upload_png_mask
from app.service.design_fast.utils.conversion_image import rgb_to_rgba
from app.service.design_fast.utils.transparent import sketch_to_transparent
from app.service.design_fast.utils.upload_image import upload_png_mask
from app.service.utils.generate_uuid import generate_uuid
from app.service.utils.new_oss_client import oss_upload_image
from app.service.utils.new_oss_client import oss_upload_image, oss_get_image
class Split(object):
@@ -20,51 +21,95 @@ class Split(object):
def __call__(self, result):
try:
if result['name'] in ('outwear', 'dress', 'blouse', 'skirt', 'trousers', 'tops', 'bottoms'):
front_mask = result['front_mask']
back_mask = result['back_mask']
if result['name'] in ('outwear', 'dress', 'blouse', 'skirt', 'trousers', 'tops', 'bottoms', 'accessories'):
if result['resize_scale'][0] == 1.0 and result['resize_scale'][1] == 1.0:
front_mask = result['front_mask']
back_mask = result['back_mask']
else:
height, width = result['front_mask'].shape[:2]
new_width = int(width * result['resize_scale'][0])
new_height = int(height * result['resize_scale'][1])
front_mask = cv2.resize(result['front_mask'], (new_width, new_height))
back_mask = cv2.resize(result['back_mask'], (new_width, new_height))
rgba_image = rgb_to_rgba(result['final_image'], front_mask + back_mask)
new_size = (int(rgba_image.shape[1] * result["scale"] * result["resize_scale"][0]), int(rgba_image.shape[0] * result["scale"] * result["resize_scale"][1]))
new_size = (int(rgba_image.shape[1] * result["scale"]), int(rgba_image.shape[0] * result["scale"]))
rgba_image = cv2.resize(rgba_image, new_size)
result_front_image = np.zeros_like(rgba_image)
front_mask = cv2.resize(front_mask, new_size)
result_front_image[front_mask != 0] = rgba_image[front_mask != 0]
result_front_image_pil = Image.fromarray(cvtColor(result_front_image, COLOR_BGR2RGBA))
if 'transparent' in result.keys():
# 用户自选区域transparent
transparent = result['transparent']
if transparent['mask_url'] is not None and transparent['mask_url'] != "":
# 预处理用户自选区mask
seg_mask = oss_get_image(oss_client=self.minio_client, bucket=transparent['mask_url'].split('/')[0], object_name=transparent['mask_url'][transparent['mask_url'].find('/') + 1:], data_type="cv2")
seg_mask = cv2.resize(seg_mask, new_size, interpolation=cv2.INTER_NEAREST)
# 转换颜色空间为 RGBOpenCV 默认是 BGR
image_rgb = cv2.cvtColor(seg_mask, cv2.COLOR_BGR2RGB)
r, g, b = cv2.split(image_rgb)
blue_mask = b > r
# 创建红色和绿色掩码
transparent_mask = np.array(blue_mask, dtype=np.uint8) * 255
result_front_image_pil = sketch_to_transparent(result_front_image_pil, transparent_mask, transparent["scale"])
else:
result_front_image_pil = sketch_to_transparent(result_front_image_pil, front_mask, transparent["scale"])
result['front_image'], result["front_image_url"], _ = upload_png_mask(self.minio_client, result_front_image_pil, f'{generate_uuid()}', mask=None)
height, width = front_mask.shape
mask_image = np.zeros((height, width, 3))
mask_image[front_mask != 0] = [0, 0, 255]
if result["name"] in ('blouse', 'dress', 'outwear', 'tops'):
result_back_image = np.zeros_like(rgba_image)
back_mask = cv2.resize(back_mask, new_size)
result_back_image[back_mask != 0] = rgba_image[back_mask != 0]
result_back_image_pil = Image.fromarray(cvtColor(result_back_image, COLOR_BGR2RGBA))
result['back_image'], result["back_image_url"], _ = upload_png_mask(self.minio_client, result_back_image_pil, f'{generate_uuid()}', mask=None)
mask_image[back_mask != 0] = [0, 255, 0]
# if result["name"] in ('blouse', 'dress', 'outwear', 'tops'):
# result_back_image = np.zeros_like(rgba_image)
# back_mask = cv2.resize(back_mask, new_size)
# result_back_image[back_mask != 0] = rgba_image[back_mask != 0]
# result_back_image_pil = Image.fromarray(cvtColor(result_back_image, COLOR_BGR2RGBA))
# result['back_image'], result["back_image_url"], _ = upload_png_mask(self.minio_client, result_back_image_pil, f'{generate_uuid()}', mask=None)
# mask_image[back_mask != 0] = [0, 255, 0]
#
# rbga_mask = rgb_to_rgba(mask_image, front_mask + back_mask)
# mask_pil = Image.fromarray(cvtColor(rbga_mask.astype(np.uint8), COLOR_BGR2RGBA))
# image_data = io.BytesIO()
# mask_pil.save(image_data, format='PNG')
# image_data.seek(0)
# image_bytes = image_data.read()
# req = oss_upload_image(oss_client=self.minio_client, bucket=AIDA_CLOTHING, object_name=f"mask/mask_{generate_uuid()}.png", image_bytes=image_bytes)
# result['mask_url'] = req.bucket_name + "/" + req.object_name
# else:
# rbga_mask = rgb_to_rgba(mask_image, front_mask)
# mask_pil = Image.fromarray(cvtColor(rbga_mask.astype(np.uint8), COLOR_BGR2RGBA))
# image_data = io.BytesIO()
# mask_pil.save(image_data, format='PNG')
# image_data.seek(0)
# image_bytes = image_data.read()
# req = oss_upload_image(oss_client=self.minio_client, bucket=AIDA_CLOTHING, object_name=f"mask/mask_{generate_uuid()}.png", image_bytes=image_bytes)
# result['mask_url'] = req.bucket_name + "/" + req.object_name
# result['back_image'] = None
# result["back_image_url"] = None
# # result["back_mask_url"] = None
# # result['back_mask_image'] = None
rbga_mask = rgb_to_rgba(mask_image, front_mask + back_mask)
mask_pil = Image.fromarray(cvtColor(rbga_mask.astype(np.uint8), COLOR_BGR2RGBA))
image_data = io.BytesIO()
mask_pil.save(image_data, format='PNG')
image_data.seek(0)
image_bytes = image_data.read()
req = oss_upload_image(oss_client=self.minio_client, bucket=AIDA_CLOTHING, object_name=f"mask/mask_{generate_uuid()}.png", image_bytes=image_bytes)
result['mask_url'] = req.bucket_name + "/" + req.object_name
else:
rbga_mask = rgb_to_rgba(mask_image, front_mask)
mask_pil = Image.fromarray(cvtColor(rbga_mask.astype(np.uint8), COLOR_BGR2RGBA))
image_data = io.BytesIO()
mask_pil.save(image_data, format='PNG')
image_data.seek(0)
image_bytes = image_data.read()
req = oss_upload_image(oss_client=self.minio_client, bucket=AIDA_CLOTHING, object_name=f"mask/mask_{generate_uuid()}.png", image_bytes=image_bytes)
result['mask_url'] = req.bucket_name + "/" + req.object_name
result['back_image'] = None
result["back_image_url"] = None
# result["back_mask_url"] = None
# result['back_mask_image'] = None
result_back_image = np.zeros_like(rgba_image)
back_mask = cv2.resize(back_mask, new_size)
result_back_image[back_mask != 0] = rgba_image[back_mask != 0]
result_back_image_pil = Image.fromarray(cvtColor(result_back_image, COLOR_BGR2RGBA))
result['back_image'], result["back_image_url"], _ = upload_png_mask(self.minio_client, result_back_image_pil, f'{generate_uuid()}', mask=None)
mask_image[back_mask != 0] = [0, 255, 0]
rbga_mask = rgb_to_rgba(mask_image, front_mask + back_mask)
mask_pil = Image.fromarray(cvtColor(rbga_mask.astype(np.uint8), COLOR_BGR2RGBA))
image_data = io.BytesIO()
mask_pil.save(image_data, format='PNG')
image_data.seek(0)
image_bytes = image_data.read()
req = oss_upload_image(oss_client=self.minio_client, bucket=AIDA_CLOTHING, object_name=f"mask/mask_{generate_uuid()}.png", image_bytes=image_bytes)
result['mask_url'] = req.bucket_name + "/" + req.object_name
# 创建中间图层
result_pattern_image_rgba = rgb_to_rgba(result['pattern_image'], result['mask'])
result_pattern_image_pil = Image.fromarray(cvtColor(result_pattern_image_rgba, COLOR_BGR2RGBA))

View File

@@ -2,16 +2,16 @@ import json
import pika
from app.core.config import RABBITMQ_PARAMS
from app.core.config import RABBITMQ_PARAMS, BATCH_DESIGN_RABBITMQ_QUEUES
def publish_status(task_id, progress, result):
connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
channel = connection.channel()
channel.queue_declare(queue='DesignBatch', durable=True)
channel.queue_declare(queue=BATCH_DESIGN_RABBITMQ_QUEUES, durable=True)
message = {'task_id': task_id, 'progress': progress, "result": result}
channel.basic_publish(exchange='',
routing_key='DesignBatch',
routing_key=BATCH_DESIGN_RABBITMQ_QUEUES,
body=json.dumps(message),
properties=pika.BasicProperties(
delivery_mode=2,

View File

@@ -33,8 +33,8 @@ def organize_clothing(layer):
mask=cv2.resize(layer['mask'], layer["front_image"].size),
gradient_string=layer['gradient_string'] if 'gradient_string' in layer.keys() else "",
pattern_image_url=layer['pattern_image_url'],
pattern_image=layer['pattern_image']
pattern_image=layer['pattern_image'],
# back_perspective_url=layer['back_perspective_url'] if 'back_perspective_url' in layer.keys() else ""
)
# 后片数据
back_layer = dict(priority=-layer.get("priority", 0) if layer.get("layer_order", False) else PRIORITY_DICT.get(f'{layer["name"].lower()}_back', None),
@@ -50,6 +50,46 @@ def organize_clothing(layer):
mask=cv2.resize(layer['mask'], layer["front_image"].size),
gradient_string=layer['gradient_string'] if 'gradient_string' in layer.keys() else "",
pattern_image_url=layer['pattern_image_url'],
# back_perspective_url=layer['back_perspective_url'] if 'back_perspective_url' in layer.keys() else ""
)
return front_layer, back_layer
def organize_accessories(layer):
    """Split one accessory layer dict into a (front_layer, back_layer) pair.

    Mirrors ``organize_clothing`` for items of category ``accessories``:
    both output layers are anchored at the origin and carry a fixed
    ``clothes_keypoint`` of ``(0, 0)`` because there is no keypoint model
    for accessories.

    Args:
        layer: dict describing the rendered accessory. Keys read here:
            ``name``, ``front_image`` / ``back_image`` (PIL images),
            ``front_image_url`` / ``back_image_url``, ``mask_url``,
            ``scale``, ``resize_scale``, ``mask`` (cv2 array),
            ``pattern_image`` / ``pattern_image_url``, optional
            ``priority`` / ``layer_order`` / ``gradient_string``.

    Returns:
        tuple[dict, dict]: ``(front_layer, back_layer)`` compositing specs.
    """
    base_name = layer["name"].lower()
    # When the caller supplies an explicit layer_order, the front keeps the
    # given priority and the back gets its negation; otherwise both fall
    # back to the static PRIORITY_DICT lookup.
    use_custom_order = layer.get("layer_order", False)
    # Both sides start compositing at the image origin.
    anchor = (0, 0)
    gradient = layer['gradient_string'] if 'gradient_string' in layer else ""
    front_layer = {
        'priority': layer['priority'] if use_custom_order else PRIORITY_DICT.get(f'{base_name}_front', None),
        'name': f'{base_name}_front',
        'image': layer["front_image"],
        # 'mask_image': layer['front_mask_image'],
        'image_url': layer['front_image_url'],
        'mask_url': layer['mask_url'],
        # NOTE(review): key spelled 'sacle' (sic) — kept as-is because
        # downstream consumers read this exact key; confirm before renaming.
        'sacle': layer['scale'],
        'clothes_keypoint': (0, 0),
        'position': anchor,
        'resize_scale': layer["resize_scale"],
        # Fit the mask to the front image's pixel size.
        'mask': cv2.resize(layer['mask'], layer["front_image"].size),
        'gradient_string': gradient,
        'pattern_image_url': layer['pattern_image_url'],
        'pattern_image': layer['pattern_image'],
        # 'back_perspective_url': layer.get('back_perspective_url', "")
    }
    back_layer = {
        'priority': -layer.get("priority", 0) if use_custom_order else PRIORITY_DICT.get(f'{base_name}_back', None),
        'name': f'{base_name}_back',
        'image': layer["back_image"],
        # 'mask_image': layer['back_mask_image'],
        'image_url': layer['back_image_url'],
        'mask_url': layer['mask_url'],
        'sacle': layer['scale'],
        'clothes_keypoint': (0, 0),
        'position': anchor,
        'resize_scale': layer["resize_scale"],
        # NOTE(review): resized against the FRONT image's size, matching the
        # original implementation — presumably front/back share dimensions.
        'mask': cv2.resize(layer['mask'], layer["front_image"].size),
        'gradient_string': gradient,
        'pattern_image_url': layer['pattern_image_url'],
        # 'back_perspective_url': layer.get('back_perspective_url', "")
    }
    return front_layer, back_layer

View File

@@ -200,6 +200,11 @@ def design_generate_v2(request_data):
items_response['synthesis_url'] = synthesis_single(item_result['front_image'], item_result['back_image'])
# 发送结果给java端
url = JAVA_STREAM_API_URL
# xu_pei_test_url = "https://cd21b9110505.ngrok-free.app/api/third/party/receiveDesignResults"
logger.info(f"java 回调 -> {url}")
# logger.info(f"xupei java 回调 -> {xu_pei_test_url}")
headers = {
'Accept': "*/*",
'Accept-Encoding': "gzip, deflate, br",
@@ -213,6 +218,11 @@ def design_generate_v2(request_data):
# 打印结果
logger.info(response.text)
# response = post_request(xu_pei_test_url, json_data=items_response, headers=headers)
# if response:
# 打印结果
# logger.info(f"xupei test response : {response.text}")
for step, object in enumerate(objects_data):
t = threading.Thread(target=process_object, args=(step, object))
threads.append(t)

View File

@@ -57,7 +57,7 @@ class BottomItem(BaseItem):
LoadImage(minio_client),
KeyPoint(),
ContourDetection(),
# Segmentation(),
Segmentation(minio_client),
# BackPerspective(minio_client),
Color(minio_client),
PrintPainting(minio_client),

View File

@@ -29,6 +29,24 @@ class Color:
else:
pattern = self.get_pattern(result['color'])
resize_pattern = cv2.resize(pattern, (dim_image_w, dim_image_h), interpolation=cv2.INTER_AREA)
if "partial_color" in result.keys() and result['partial_color'] != "":
bucket_name = result['partial_color'].split('/')[0]
object_name = result['partial_color'][result['partial_color'].find('/') + 1:]
partial_color = oss_get_image(oss_client=self.minio_client, bucket=bucket_name, object_name=object_name, data_type="cv2")
h, w = partial_color.shape[0:2]
resize_pattern = cv2.resize(resize_pattern, (w, h), interpolation=cv2.INTER_AREA)
# 分离出 png 图的 alpha 通道
alpha_channel = partial_color[:, :, 3]
# 提取 png 图的 RGB 通道
png_rgb = partial_color[:, :, :3]
# 创建一个与 cv 图大小相同的掩码,用于指示哪些像素需要替换
mask = alpha_channel > 0
# 将掩码扩展为 3 通道,以便与 cv 图进行逐元素操作
mask_3ch = np.stack([mask] * 3, axis=-1)
# 根据掩码将 png 图的颜色覆盖到 cv 图上
resize_pattern[mask_3ch] = png_rgb[mask_3ch]
resize_pattern = cv2.resize(resize_pattern, (dim_image_w, dim_image_h), interpolation=cv2.INTER_AREA)
closed_mo = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
gray_mo = np.expand_dims(result['gray'], axis=2).repeat(3, axis=2)
get_image_fir = resize_pattern * (closed_mo / 255) * (gray_mo / 255)

View File

@@ -15,6 +15,7 @@ class PrintPainting:
single_print = result['print']['single']
overall_print = result['print']['overall']
element_print = result['print']['element']
partial_path = result['print']['partial'] if 'partial' in result['print'] else None
result['single_image'] = None
result['print_image'] = None
# TODO 给result['pattern_image'] resize 到resize_scale的大小
@@ -32,7 +33,6 @@ class PrintPainting:
result['mask'] = cv2.resize(result['mask'], (new_width, new_height))
result['gray'] = cv2.resize(result['gray'], (new_width, new_height))
print(1)
if overall_print['print_path_list']:
painting_dict = {'dim_image_h': result['pattern_image'].shape[0], 'dim_image_w': result['pattern_image'].shape[1]}
result['print_image'] = result['pattern_image']
@@ -54,90 +54,89 @@ class PrintPainting:
mask_background = np.zeros((result['pattern_image'].shape[0], result['pattern_image'].shape[1], 3), dtype=np.uint8)
for i in range(len(single_print['print_path_list'])):
image, image_mode = self.read_image(single_print['print_path_list'][i])
if image_mode == "RGBA":
new_size = (int(result['pattern_image'].shape[1] * single_print['print_scale_list'][i][0]), int(result['pattern_image'].shape[0] * single_print['print_scale_list'][i][1]))
mask = image.split()[3]
resized_source = image.resize(new_size)
resized_source_mask = mask.resize(new_size)
if image_mode == "RGB":
image_rgba = cv2.cvtColor(image, cv2.COLOR_BGR2RGBA)
image = Image.fromarray(image_rgba)
rotated_resized_source = resized_source.rotate(-single_print['print_angle_list'][i])
rotated_resized_source_mask = resized_source_mask.rotate(-single_print['print_angle_list'][i])
source_image_pil = Image.fromarray(cv2.cvtColor(print_background, cv2.COLOR_BGR2RGB))
source_image_pil_mask = Image.fromarray(cv2.cvtColor(mask_background, cv2.COLOR_BGR2RGB))
source_image_pil.paste(rotated_resized_source, (int(single_print['location'][i][0]), int(single_print['location'][i][1])), rotated_resized_source)
source_image_pil_mask.paste(rotated_resized_source_mask, (int(single_print['location'][i][0]), int(single_print['location'][i][1])), rotated_resized_source_mask)
print_background = cv2.cvtColor(np.array(source_image_pil), cv2.COLOR_RGBA2BGR)
mask_background = cv2.cvtColor(np.array(source_image_pil_mask), cv2.COLOR_RGBA2BGR)
ret, mask_background = cv2.threshold(mask_background, 124, 255, cv2.THRESH_BINARY)
else:
mask = self.get_mask_inv(image)
mask = np.expand_dims(mask, axis=2)
mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR)
mask = cv2.bitwise_not(mask)
mask = cv2.resize(mask, (int(result['final_image'].shape[1] * single_print['print_scale_list'][i][0]), int(result['final_image'].shape[0] * single_print['print_scale_list'][i][1])))
image = cv2.resize(image, (int(result['final_image'].shape[1] * single_print['print_scale_list'][i][0]), int(result['final_image'].shape[0] * single_print['print_scale_list'][i][1])))
# 旋转后的坐标需要重新算
rotate_mask, _ = self.img_rotate(mask, single_print['print_angle_list'][i])
rotate_image, rotated_new_size = self.img_rotate(image, single_print['print_angle_list'][i])
# x, y = int(result['print']['location'][i][0] - rotated_new_size[0] - (rotate_mask.shape[0] - image.shape[0]) / 2), int(result['print']['location'][i][1] - rotated_new_size[1] - (rotate_mask.shape[1] - image.shape[1]) / 2)
x, y = int(single_print['location'][i][0] - rotated_new_size[0]), int(single_print['location'][i][1] - rotated_new_size[1])
image_x = print_background.shape[1]
image_y = print_background.shape[0]
print_x = rotate_image.shape[1]
print_y = rotate_image.shape[0]
# 有bug
# if x + print_x > image_x:
# rotate_image = rotate_image[:, :x + print_x - image_x]
# rotate_mask = rotate_mask[:, :x + print_x - image_x]
# #
# if y + print_y > image_y:
# rotate_image = rotate_image[:y + print_y - image_y]
# rotate_mask = rotate_mask[:y + print_y - image_y]
# 不能是并行
# 当前第一轮的if 108以及115是判断有没有过下界和右界。第二轮的是判断左上有没有超出。 如果这个样子的话先裁了右边再左移region就会有问题
# 先挪 再判断 最后裁剪
# 如果print旋转了 或者 print贴边了 则需要判断 判断左界和上界是否小于0
if x <= 0:
rotate_image = rotate_image[:, -x:]
rotate_mask = rotate_mask[:, -x:]
start_x = x = 0
else:
start_x = x
if y <= 0:
rotate_image = rotate_image[-y:, :]
rotate_mask = rotate_mask[-y:, :]
start_y = y = 0
else:
start_y = y
# ------------------
# 如果print-size大于image-size 则需要裁剪print
if x + print_x > image_x:
rotate_image = rotate_image[:, :image_x - x]
rotate_mask = rotate_mask[:, :image_x - x]
if y + print_y > image_y:
rotate_image = rotate_image[:image_y - y, :]
rotate_mask = rotate_mask[:image_y - y, :]
# mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]] = cv2.bitwise_xor(mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]], rotate_mask)
# print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = cv2.add(print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]], rotate_image)
# mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]] = rotate_mask
# print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = rotate_image
mask_background = self.stack_prin(mask_background, result['pattern_image'], rotate_mask, start_y, y, start_x, x)
print_background = self.stack_prin(print_background, result['pattern_image'], rotate_image, start_y, y, start_x, x)
new_size = (int(result['pattern_image'].shape[1] * single_print['print_scale_list'][i][0]), int(result['pattern_image'].shape[0] * single_print['print_scale_list'][i][1]))
mask = image.split()[3]
resized_source = image.resize(new_size)
resized_source_mask = mask.resize(new_size)
rotated_resized_source = resized_source.rotate(-single_print['print_angle_list'][i])
rotated_resized_source_mask = resized_source_mask.rotate(-single_print['print_angle_list'][i])
source_image_pil = Image.fromarray(cv2.cvtColor(print_background, cv2.COLOR_BGR2RGB))
source_image_pil_mask = Image.fromarray(cv2.cvtColor(mask_background, cv2.COLOR_BGR2RGB))
source_image_pil.paste(rotated_resized_source, (int(single_print['location'][i][0]), int(single_print['location'][i][1])), rotated_resized_source)
source_image_pil_mask.paste(rotated_resized_source_mask, (int(single_print['location'][i][0]), int(single_print['location'][i][1])), rotated_resized_source_mask)
print_background = cv2.cvtColor(np.array(source_image_pil), cv2.COLOR_RGBA2BGR)
mask_background = cv2.cvtColor(np.array(source_image_pil_mask), cv2.COLOR_RGBA2BGR)
ret, mask_background = cv2.threshold(mask_background, 124, 255, cv2.THRESH_BINARY)
# else:
# mask = self.get_mask_inv(image)
# mask = np.expand_dims(mask, axis=2)
# mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR)
# mask = cv2.bitwise_not(mask)
#
# mask = cv2.resize(mask, (int(result['final_image'].shape[1] * single_print['print_scale_list'][i][0]), int(result['final_image'].shape[0] * single_print['print_scale_list'][i][1])))
# image = cv2.resize(image, (int(result['final_image'].shape[1] * single_print['print_scale_list'][i][0]), int(result['final_image'].shape[0] * single_print['print_scale_list'][i][1])))
# # 旋转后的坐标需要重新算
# rotate_mask, _ = self.img_rotate(mask, single_print['print_angle_list'][i])
# rotate_image, rotated_new_size = self.img_rotate(image, single_print['print_angle_list'][i])
# # x, y = int(result['print']['location'][i][0] - rotated_new_size[0] - (rotate_mask.shape[0] - image.shape[0]) / 2), int(result['print']['location'][i][1] - rotated_new_size[1] - (rotate_mask.shape[1] - image.shape[1]) / 2)
# x, y = int(single_print['location'][i][0] - rotated_new_size[0]), int(single_print['location'][i][1] - rotated_new_size[1])
#
# image_x = print_background.shape[1] # 底图宽
# image_y = print_background.shape[0] # 底图高
# print_x = rotate_image.shape[1] #印花宽
# print_y = rotate_image.shape[0] #印花高
#
# # 有bug
# # if x + print_x > image_x:
# # rotate_image = rotate_image[:, :x + print_x - image_x]
# # rotate_mask = rotate_mask[:, :x + print_x - image_x]
# # #
# # if y + print_y > image_y:
# # rotate_image = rotate_image[:y + print_y - image_y]
# # rotate_mask = rotate_mask[:y + print_y - image_y]
#
# # 不能是并行
# # 当前第一轮的if 108以及115是判断有没有过下界和右界。第二轮的是判断左上有没有超出。 如果这个样子的话先裁了右边再左移region就会有问题
# # 先挪 再判断 最后裁剪
#
# # 如果print旋转了 或者 print贴边了 则需要判断 判断左界和上界是否小于0
# if x <= 0: # 如果X轴偏移量小于0说明印花需要被裁剪至合适大小 或当X轴偏移量大于印花宽度时裁剪后的印花宽度为0
# rotate_image = rotate_image[:, abs(x):]
# rotate_mask = rotate_mask[:, abs(x):]
# start_x = x = 0
# else:
# start_x = x
#
# if y <= 0: # 如果X轴偏移量大于0说明印花需要被裁剪至合适大小 或当Y轴偏移量大于印花宽度时裁剪后的印花宽度为0
# rotate_image = rotate_image[abs(y):, :]
# rotate_mask = rotate_mask[abs(y):, :]
# start_y = y = 0
# else:
# start_y = y
#
# # ------------------
# # 如果print-size大于image-size 则需要裁剪print
#
# if x + print_x > image_x:
# rotate_image = rotate_image[:, :image_x - x]
# rotate_mask = rotate_mask[:, :image_x - x]
#
# if y + print_y > image_y:
# rotate_image = rotate_image[:image_y - y, :]
# rotate_mask = rotate_mask[:image_y - y, :]
#
# # mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]] = cv2.bitwise_xor(mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]], rotate_mask)
# # print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = cv2.add(print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]], rotate_image)
#
# # mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]] = rotate_mask
# # print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = rotate_image
# mask_background = self.stack_prin(mask_background, result['pattern_image'], rotate_mask, start_y, y, start_x, x)
# print_background = self.stack_prin(print_background, result['pattern_image'], rotate_image, start_y, y, start_x, x)
# gray_image = cv2.cvtColor(mask_background, cv2.COLOR_BGR2GRAY)
# print_background = cv2.bitwise_and(print_background, print_background, mask=gray_image)
@@ -262,6 +261,45 @@ class PrintPainting:
temp_fg = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
tmp2 = (result['final_image'] * (temp_fg / 255)).astype(np.uint8)
result['single_image'] = cv2.add(tmp1, tmp2)
if partial_path:
print_background = np.zeros((result['pattern_image'].shape[0], result['pattern_image'].shape[1], 3), dtype=np.uint8)
mask_background = np.zeros((result['pattern_image'].shape[0], result['pattern_image'].shape[1], 3), dtype=np.uint8)
image, image_mode = self.read_image(partial_path)
if image_mode == "RGBA":
new_size = (result['pattern_image'].shape[1], result['pattern_image'].shape[0])
mask = image.split()[3]
resized_source = image.resize(new_size)
resized_source_mask = mask.resize(new_size)
# rotated_resized_source = resized_source.rotate(-partial_print['print_angle_list'][i])
# rotated_resized_source_mask = resized_source_mask.rotate(-partial_print['print_angle_list'][i])
source_image_pil = Image.fromarray(cv2.cvtColor(print_background, cv2.COLOR_BGR2RGB))
source_image_pil_mask = Image.fromarray(cv2.cvtColor(mask_background, cv2.COLOR_BGR2RGB))
source_image_pil.paste(resized_source, (0, 0), resized_source)
source_image_pil_mask.paste(resized_source_mask, (0, 0), resized_source_mask)
print_background = cv2.cvtColor(np.array(source_image_pil), cv2.COLOR_RGBA2BGR)
mask_background = cv2.cvtColor(np.array(source_image_pil_mask), cv2.COLOR_RGBA2BGR)
ret, mask_background = cv2.threshold(mask_background, 124, 255, cv2.THRESH_BINARY)
print_mask = cv2.bitwise_and(result['mask'], cv2.cvtColor(mask_background, cv2.COLOR_BGR2GRAY))
img_fg = cv2.bitwise_or(print_background, print_background, mask=print_mask)
# TODO element 丢失信息
three_channel_image = cv2.merge([cv2.bitwise_not(print_mask), cv2.bitwise_not(print_mask), cv2.bitwise_not(print_mask)])
img_bg = cv2.bitwise_and(result['final_image'], three_channel_image)
# mask_mo = np.expand_dims(print_mask, axis=2).repeat(3, axis=2)
# gray_mo = np.expand_dims(result['gray'], axis=2).repeat(3, axis=2)
# img_fg = (img_fg * (mask_mo / 255) * (gray_mo / 255)).astype(np.uint8)
result['final_image'] = cv2.add(img_bg, img_fg)
canvas = np.full_like(result['final_image'], 255)
temp_bg = np.expand_dims(cv2.bitwise_not(result['mask']), axis=2).repeat(3, axis=2)
tmp1 = (canvas * (temp_bg / 255)).astype(np.uint8)
temp_fg = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
tmp2 = (result['final_image'] * (temp_fg / 255)).astype(np.uint8)
result['single_image'] = cv2.add(tmp1, tmp2)
return result
@staticmethod
@@ -414,7 +452,6 @@ class PrintPainting:
# y_offset = int(location[0][0])
# x_offset = int(location[0][1])
if len(image.shape) == 2:
image = image[x_offset: x_offset + image_size_h, y_offset: y_offset + image_size_w]
elif len(image.shape) == 3:

View File

@@ -65,35 +65,51 @@ class Split(object):
mask_image = np.zeros((height, width, 3))
mask_image[front_mask != 0] = [0, 0, 255]
if result["name"] in ('blouse', 'dress', 'outwear', 'tops'):
result_back_image = np.zeros_like(rgba_image)
back_mask = cv2.resize(back_mask, new_size)
result_back_image[back_mask != 0] = rgba_image[back_mask != 0]
result_back_image_pil = Image.fromarray(cvtColor(result_back_image, COLOR_BGR2RGBA))
result['back_image'], result["back_image_url"], _ = upload_png_mask(self.minio_client, result_back_image_pil, f'{generate_uuid()}', mask=None)
mask_image[back_mask != 0] = [0, 255, 0]
# if result["name"] in ('blouse', 'dress', 'outwear', 'tops'):
# result_back_image = np.zeros_like(rgba_image)
# back_mask = cv2.resize(back_mask, new_size)
# result_back_image[back_mask != 0] = rgba_image[back_mask != 0]
# result_back_image_pil = Image.fromarray(cvtColor(result_back_image, COLOR_BGR2RGBA))
# result['back_image'], result["back_image_url"], _ = upload_png_mask(self.minio_client, result_back_image_pil, f'{generate_uuid()}', mask=None)
# mask_image[back_mask != 0] = [0, 255, 0]
#
# rbga_mask = rgb_to_rgba(mask_image, front_mask + back_mask)
# mask_pil = Image.fromarray(cvtColor(rbga_mask.astype(np.uint8), COLOR_BGR2RGBA))
# image_data = io.BytesIO()
# mask_pil.save(image_data, format='PNG')
# image_data.seek(0)
# image_bytes = image_data.read()
# req = oss_upload_image(oss_client=self.minio_client, bucket=AIDA_CLOTHING, object_name=f"mask/mask_{generate_uuid()}.png", image_bytes=image_bytes)
# result['mask_url'] = req.bucket_name + "/" + req.object_name
# else:
# rbga_mask = rgb_to_rgba(mask_image, front_mask)
# mask_pil = Image.fromarray(cvtColor(rbga_mask.astype(np.uint8), COLOR_BGR2RGBA))
# image_data = io.BytesIO()
# mask_pil.save(image_data, format='PNG')
# image_data.seek(0)
# image_bytes = image_data.read()
# req = oss_upload_image(oss_client=self.minio_client, bucket=AIDA_CLOTHING, object_name=f"mask/mask_{generate_uuid()}.png", image_bytes=image_bytes)
# result['mask_url'] = req.bucket_name + "/" + req.object_name
# result['back_image'] = None
# result["back_image_url"] = None
# # result["back_mask_url"] = None
# # result['back_mask_image'] = None
rbga_mask = rgb_to_rgba(mask_image, front_mask + back_mask)
mask_pil = Image.fromarray(cvtColor(rbga_mask.astype(np.uint8), COLOR_BGR2RGBA))
image_data = io.BytesIO()
mask_pil.save(image_data, format='PNG')
image_data.seek(0)
image_bytes = image_data.read()
req = oss_upload_image(oss_client=self.minio_client, bucket=AIDA_CLOTHING, object_name=f"mask/mask_{generate_uuid()}.png", image_bytes=image_bytes)
result['mask_url'] = req.bucket_name + "/" + req.object_name
else:
rbga_mask = rgb_to_rgba(mask_image, front_mask)
mask_pil = Image.fromarray(cvtColor(rbga_mask.astype(np.uint8), COLOR_BGR2RGBA))
image_data = io.BytesIO()
mask_pil.save(image_data, format='PNG')
image_data.seek(0)
image_bytes = image_data.read()
req = oss_upload_image(oss_client=self.minio_client, bucket=AIDA_CLOTHING, object_name=f"mask/mask_{generate_uuid()}.png", image_bytes=image_bytes)
result['mask_url'] = req.bucket_name + "/" + req.object_name
result['back_image'] = None
result["back_image_url"] = None
# result["back_mask_url"] = None
# result['back_mask_image'] = None
result_back_image = np.zeros_like(rgba_image)
back_mask = cv2.resize(back_mask, new_size)
result_back_image[back_mask != 0] = rgba_image[back_mask != 0]
result_back_image_pil = Image.fromarray(cvtColor(result_back_image, COLOR_BGR2RGBA))
result['back_image'], result["back_image_url"], _ = upload_png_mask(self.minio_client, result_back_image_pil, f'{generate_uuid()}', mask=None)
mask_image[back_mask != 0] = [0, 255, 0]
rbga_mask = rgb_to_rgba(mask_image, front_mask + back_mask)
mask_pil = Image.fromarray(cvtColor(rbga_mask.astype(np.uint8), COLOR_BGR2RGBA))
image_data = io.BytesIO()
mask_pil.save(image_data, format='PNG')
image_data.seek(0)
image_bytes = image_data.read()
req = oss_upload_image(oss_client=self.minio_client, bucket=AIDA_CLOTHING, object_name=f"mask/mask_{generate_uuid()}.png", image_bytes=image_bytes)
result['mask_url'] = req.bucket_name + "/" + req.object_name
# 创建中间图层
result_pattern_image_rgba = rgb_to_rgba(result['pattern_image'], result['mask'])
result_pattern_image_pil = Image.fromarray(cvtColor(result_pattern_image_rgba, COLOR_BGR2RGBA))

View File

@@ -0,0 +1,24 @@
from app.service.generate_batch_image.service_batch_generate_product_image import batch_generate_product, publish_status as product_publish_status
from app.service.generate_batch_image.service_batch_generate_relight_image import batch_generate_relight, publish_status as relight_publish_status
from app.service.generate_batch_image.service_batch_pose_transform import batch_generate_pose_transform, publish_status as pose_transform_publish_status
async def start_product_batch_generate(data):
    """Queue a batch product-image generation job on Celery.

    Publishes an initial ``0/<batch size>`` progress message to RabbitMQ and
    returns the batch task id together with the Celery task state.
    """
    celery_result = batch_generate_product.delay(data.dict())
    print(celery_result)
    total_items = len(data.batch_data_list)
    product_publish_status(data.batch_tasks_id, f"0/{total_items}", "")
    return {"task_id": data.batch_tasks_id, "state": celery_result.state}
async def start_relight_batch_generate(data):
    """Queue a batch relight-image generation job on Celery.

    Publishes an initial ``0/<batch size>`` progress message to RabbitMQ and
    returns the batch task id together with the Celery task state.
    """
    celery_result = batch_generate_relight.delay(data.dict())
    print(celery_result)
    total_items = len(data.batch_data_list)
    relight_publish_status(data.batch_tasks_id, f"0/{total_items}", "")
    return {"task_id": data.batch_tasks_id, "state": celery_result.state}
async def start_pose_transform_batch_generate(data):
    """Queue a batch pose-transform generation job on Celery.

    Publishes an initial ``0/<batch size>`` progress message to RabbitMQ and
    returns the task id together with the Celery task state.
    """
    celery_result = batch_generate_pose_transform.delay(data.dict())
    print(celery_result)
    pose_transform_publish_status(data.tasks_id, f"0/{data.batch_size}", "")
    return {"task_id": data.tasks_id, "state": celery_result.state}

View File

@@ -0,0 +1,242 @@
# 旧版product
# !/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project trinity_client
@File service_att_recognition.py
@Author :周成融
@Date 2023/7/26 12:01:05
@detail
"""
import json
import logging
import cv2
import numpy as np
import tritonclient.grpc as grpcclient
from PIL import Image
from celery import Celery
from tritonclient.utils import np_to_triton_dtype
from app.core.config import *
from app.schemas.generate_image import BatchGenerateProductImageModel, ProductItemModel
from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image
from app.service.utils.oss_client import oss_get_image
# Celery application for the batch product-image pipeline.
# NOTE(review): broker credentials and host are hard-coded in source
# (amqp://rabbit:123456@18.167.251.121) — move them to configuration/secrets.
celery_app = Celery('product_tasks', broker=f'amqp://rabbit:123456@18.167.251.121:5672//', backend='rpc://', BROKER_CONNECTION_RETRY_ON_STARTUP=True)
celery_app.conf.task_default_queue = 'queue_product'
celery_app.conf.worker_log_format = '%(asctime)s %(filename)s [line:%(lineno)d] %(levelname)s %(message)s'
# Keep this module's own logging configuration instead of Celery's.
celery_app.conf.worker_hijack_root_logger = False
logger = logging.getLogger()
# pika is chatty at INFO level; keep only warnings and above.
logging.getLogger('pika').setLevel(logging.WARNING)
# Shared gRPC client to the Triton inference server for product generation.
grpc_client = grpcclient.InferenceServerClient(url=GPI_MODEL_URL)
# OSS upload category (path segment) for generated product images.
category = "product_image"
@celery_app.task
def batch_generate_product(batch_request_data):
    """Celery task: generate a product image for every item in a batch.

    For each entry in ``batch_request_data['batch_data_list']`` the source
    image is letterboxed onto a 512x768 canvas, sent to the Triton server
    (single-garment or overall model depending on ``product_type``), and the
    generated image is uploaded to OSS. Per-item progress and the final
    aggregated result list are published to RabbitMQ via ``publish_status``
    (log/print only when ``DEBUG``).

    Expected keys (dict form of ``BatchGenerateProductImageModel``):
    ``batch_tasks_id``, ``user_id`` and ``batch_data_list`` whose items carry
    ``tasks_id``, ``image_url``, ``prompt``, ``image_strength`` and
    ``product_type``.
    """
    batch_size = len(batch_request_data['batch_data_list'])
    logger.info(f"batch_generate_product batch_request_data:{json.dumps(batch_request_data, indent=4)}")
    batch_tasks_id = batch_request_data['batch_tasks_id']
    user_id = batch_request_data['user_id']
    result_data_list = []
    for i, data in enumerate(batch_request_data['batch_data_list']):
        tasks_id = data['tasks_id']
        # Letterbox the OSS image onto a 512x768 RGBA canvas, then drop alpha.
        image = pre_processing_image(data['image_url'])
        image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
        images = [image.astype(np.uint8)] * 1
        prompts = [data['prompt']] * 1
        if data['product_type'] == "single":
            # Single-garment model takes batched tensors (leading batch dim).
            text_obj = np.array(prompts, dtype="object").reshape((-1, 1))
            image_obj = np.array(images, dtype=np.uint8).reshape((-1, 768, 512, 3))
            image_strength_obj = np.array(data['image_strength'], dtype=np.float32).reshape((-1, 1))
        else:
            # Overall model takes unbatched tensors.
            text_obj = np.array(prompts, dtype="object").reshape((1))
            image_obj = np.array(images, dtype=np.uint8).reshape((768, 512, 3))
            image_strength_obj = np.array(data['image_strength'], dtype=np.float32).reshape((1))
        input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype))
        input_image = grpcclient.InferInput("input_image", image_obj.shape, "UINT8")
        input_image_strength = grpcclient.InferInput("image_strength", image_strength_obj.shape, np_to_triton_dtype(image_strength_obj.dtype))
        input_text.set_data_from_numpy(text_obj)
        input_image.set_data_from_numpy(image_obj)
        input_image_strength.set_data_from_numpy(image_strength_obj)
        inputs = [input_text, input_image, input_image_strength]
        try:
            if data['product_type'] == "single":
                result = grpc_client.infer(model_name=GPI_MODEL_NAME_SINGLE, inputs=inputs, priority=100)
                image = result.as_numpy("generated_cnet_image")
            else:
                result = grpc_client.infer(model_name=GPI_MODEL_NAME_OVERALL, inputs=inputs, priority=100)
                image = result.as_numpy("generated_inpaint_image")
            image_result = Image.fromarray(np.squeeze(image.astype(np.uint8)))
        except Exception as e:
            if 'mask_list' in str(e):
                # The server rejected the request complaining about 'mask_list':
                # rebuild the inputs in batched form and retry against the
                # single-garment model.
                # NOTE(review): if this retry raises too, the whole batch task
                # aborts mid-loop — confirm that is acceptable.
                e_text_obj = np.array(prompts, dtype="object").reshape((-1, 1))
                e_image_obj = np.array(images, dtype=np.uint8).reshape((-1, 768, 512, 3))
                e_image_strength_obj = np.array(data['image_strength'], dtype=np.float32).reshape((-1, 1))
                e_input_text = grpcclient.InferInput("prompt", e_text_obj.shape, np_to_triton_dtype(e_text_obj.dtype))
                e_input_image = grpcclient.InferInput("input_image", e_image_obj.shape, "UINT8")
                e_input_image_strength = grpcclient.InferInput("image_strength", e_image_strength_obj.shape, np_to_triton_dtype(e_image_strength_obj.dtype))
                e_input_text.set_data_from_numpy(e_text_obj)
                e_input_image.set_data_from_numpy(e_image_obj)
                e_input_image_strength.set_data_from_numpy(e_image_strength_obj)
                result = grpc_client.infer(model_name=GPI_MODEL_NAME_SINGLE, inputs=[e_input_text, e_input_image, e_input_image_strength], priority=100)
                image = result.as_numpy("generated_cnet_image")
                image_result = Image.fromarray(np.squeeze(image.astype(np.uint8)))
            else:
                # Any other failure: keep the error text in place of an image.
                image_result = str(e)
                logger.error(image_result)
        if isinstance(image_result, Image.Image):
            image_url = upload_SDXL_image(image_result, user_id=user_id, category=f"{category}", file_name=f"{tasks_id}-batch-{i}.png")
            data['product_img'] = image_url
            result_data_list.append(data)
        else:
            # Generation failed: the error text is propagated as the "URL".
            image_url = image_result
            data['product_img'] = image_url
            result_data_list.append(data)
        # Publish per-item progress (log/print only when DEBUG).
        if DEBUG:
            logger.info(f" [x]Queue : {BATCH_GPI_RABBITMQ_QUEUES} | tasks_id{tasks_id} | progress{i + 1}/{batch_size} | result_data{data}")
            print(f" [x]Queue : {BATCH_GPI_RABBITMQ_QUEUES} | tasks_id{tasks_id} | progress{i + 1}/{batch_size} | result_data{data}")
        else:
            publish_status(tasks_id, f"{i + 1}/{batch_size}", data)
            logger.info(f" [x]Queue : {BATCH_GPI_RABBITMQ_QUEUES} | tasks_id{tasks_id} | progress{i + 1}/{batch_size} | result_data{data}")
    # Batch finished: publish the aggregated result list.
    if DEBUG:
        print(result_data_list)
        logger.info(f" [x]Queue : {BATCH_GPI_RABBITMQ_QUEUES} | batch_tasks_id{batch_tasks_id} | progressOK | result_data_list{result_data_list}")
        print(f" [x]Queue : {BATCH_GPI_RABBITMQ_QUEUES} | batch_tasks_id{batch_tasks_id} | progressOK | result_data_list{result_data_list}")
    else:
        publish_status(batch_tasks_id, f"OK", result_data_list)
        logger.info(f" [x]Queue : {BATCH_GPI_RABBITMQ_QUEUES} | batch_tasks_id{batch_tasks_id} | progressOK | result_data_list{result_data_list}")
def pre_processing_image(image_url, target_width=512, target_height=768):
    """Fetch an image from OSS and letterbox it onto a transparent canvas.

    The source image is scaled (aspect ratio preserved) so that it fits
    entirely inside a ``target_width`` x ``target_height`` canvas, pasted
    centred onto a fully transparent RGBA background, and returned as a numpy
    array. The defaults match the 512x768 input expected by the Triton
    generation models.

    :param image_url: ``"<bucket>/<object_name>"`` path of the image in OSS.
    :param target_width: output canvas width in pixels (default 512).
    :param target_height: output canvas height in pixels (default 768).
    :return: ``numpy.ndarray`` of shape (target_height, target_width, 4), RGBA.
    """
    # NOTE(review): unlike other call sites in this project, no oss_client is
    # passed here — confirm oss_get_image provides a default client.
    image = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url[image_url.find('/') + 1:], data_type="PIL")
    original_width, original_height = image.size
    # Use the smaller of the two ratios so the whole image fits the canvas.
    scale_ratio = min(target_width / original_width, target_height / original_height)
    new_width = int(original_width * scale_ratio)
    new_height = int(original_height * scale_ratio)
    resized_image = image.resize((new_width, new_height))
    # Transparent canvas: alpha is 0 everywhere outside the pasted image.
    result_image = Image.new("RGBA", (target_width, target_height), (255, 255, 255, 0))
    # Offsets that centre the resized image on the canvas.
    x_offset = (target_width - new_width) // 2
    y_offset = (target_height - new_height) // 2
    if resized_image.mode == "RGBA":
        # Respect the source alpha channel when compositing.
        result_image.paste(resized_image, (x_offset, y_offset), mask=resized_image.split()[3])
    else:
        result_image.paste(resized_image, (x_offset, y_offset))
    return np.array(result_image)
def post_processing_image(image, left, top):
    """Scale *image* to 768 px tall (keeping aspect ratio) and centre-crop
    a 512x768 region from it.

    NOTE(review): the ``left`` and ``top`` arguments are ignored — ``left``
    is immediately overwritten with the centred crop offset and ``top`` is
    never read. Confirm with callers whether they were meant to control the
    crop position.
    """
    scaled = image.resize((int(image.width * (768 / image.height)), 768))
    # Horizontal offset that centres a 512-wide window on the scaled image.
    crop_left = (scaled.width - 512) // 2
    return scaled.crop((crop_left, 0, crop_left + 512, 768))
def publish_status(task_id, progress, result):
    """Publish a progress message for *task_id* to the batch GPI RabbitMQ queue.

    Opens a fresh blocking connection per call, declares the durable queue and
    sends a persistent JSON message of the form
    ``{"task_id": ..., "progress": ..., "result": ...}``.

    NOTE(review): ``pika`` is not imported by this module directly — it is
    presumably re-exported via ``from app.core.config import *``; verify.
    """
    connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
    channel = connection.channel()
    channel.queue_declare(queue=BATCH_GPI_RABBITMQ_QUEUES, durable=True)
    payload = json.dumps({'task_id': task_id, 'progress': progress, "result": result})
    channel.basic_publish(
        exchange='',
        routing_key=BATCH_GPI_RABBITMQ_QUEUES,
        body=payload,
        # delivery_mode=2 marks the message persistent.
        properties=pika.BasicProperties(delivery_mode=2),
    )
    connection.close()
if __name__ == '__main__':
    # Ad-hoc manual smoke test: builds a two-item batch request and calls the
    # task function directly (bypassing Celery's .delay), so it runs
    # synchronously in this process against the live Triton/OSS services.
    # rd = BatchGenerateProductImageModel(
    #     tasks_id="123-15-51-89",
    #     image_strength=0.7,
    #     prompt=" The best quality, masterpiece, real image.Outwear,high quality clothing details,8K realistic,HDR",
    #     image_url="aida-results/result_40b1a2fe-e220-11ef-9bfa-0242ac150003.png",
    #     product_type="overall",
    #     batch_size=20
    # )
    # batch_generate_product(rd.dict())
    # rd = {
    #     "user_id": "89",
    #     "batch_data_list": [
    #         {
    #             "tasks_id": "A-123-15-51-89",
    #             "image_strength": 0.7,
    #             "prompt": " The best quality, ma123sterpiece, real image.Outwear,high quality clothing details,8K realistic,HDR",
    #             "image_url": "aida-results/result_40b1a2fe-e220-11ef-9bfa-0242ac150003.png",
    #             "product_type": "overall",
    #         },
    #         {
    #             "tasks_id": "B-123-15-51-89",
    #             "image_strength": 0.7,
    #             "prompt": " The best quality, masterpiece, real image.Outwear123,high quality clothing details,8K realistic,HDR",
    #             "image_url": "aida-results/result_40b1a2fe-e220-11ef-9bfa-0242ac150003.png",
    #             "product_type": "overall",
    #         }
    #     ]
    # }
    rd = BatchGenerateProductImageModel(
        batch_tasks_id="abcd",
        user_id="89",
        batch_data_list=[
            ProductItemModel(
                tasks_id="123-5464",
                image_strength=0.7,
                product_type="overall",
                image_url="aida-results/result_40b1a2fe-e220-11ef-9bfa-0242ac150003.png",
                prompt="123"
            ),
            ProductItemModel(
                tasks_id="123-5464123",
                image_strength=0.7,
                product_type="overall",
                image_url="aida-results/result_40b1a2fe-e220-11ef-9bfa-0242ac150003.png",
                prompt="123"
            )
        ]
    )
    batch_generate_product(rd.dict())

View File

@@ -0,0 +1,250 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project trinity_client
@File service_att_recognition.py
@Author :周成融
@Date 2023/7/26 12:01:05
@detail
"""
import json
import logging
import cv2
import numpy as np
import tritonclient.grpc as grpcclient
from PIL import Image
from celery import Celery
from tritonclient.utils import np_to_triton_dtype
from app.core.config import *
from app.schemas.generate_image import BatchGenerateRelightImageModel, RelightItemModel
from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image
from app.service.utils.oss_client import oss_get_image
logger = logging.getLogger()
# Celery application for the batch relight pipeline.
# NOTE(review): broker credentials and host are hard-coded in source
# (amqp://rabbit:123456@18.167.251.121) — move them to configuration/secrets.
celery_app = Celery('relight_tasks', broker=f'amqp://rabbit:123456@18.167.251.121:5672//', backend='rpc://', BROKER_CONNECTION_RETRY_ON_STARTUP=True)
celery_app.conf.task_default_queue = 'queue_relight'
celery_app.conf.worker_log_format = '%(asctime)s %(filename)s [line:%(lineno)d] %(levelname)s %(message)s'
# Keep this module's own logging configuration instead of Celery's.
celery_app.conf.worker_hijack_root_logger = False
# pika is chatty at INFO level; keep only warnings and above.
logging.getLogger('pika').setLevel(logging.WARNING)
# Shared gRPC client to the Triton inference server for relighting.
grpc_client = grpcclient.InferenceServerClient(url=GRI_MODEL_URL)
# OSS upload category (path segment) for generated relight images.
category = "relight_image"
@celery_app.task
def batch_generate_relight(batch_request_data):
    """Celery task: relight every image in a batch via the Triton server.

    Each item in ``batch_request_data['batch_data_list']`` is fetched from
    OSS, resized to 512x768 and sent with its prompt, a fixed seed/negative
    prompt and its light ``direction`` to the single or overall relight model
    (chosen by ``product_type``). Results are uploaded to OSS; per-item
    progress and the final aggregated list are published to RabbitMQ via
    ``publish_status`` (log/print only when ``DEBUG``).
    """
    batch_size = len(batch_request_data['batch_data_list'])
    logger.info(f"batch_generate_relight batch_request_data: {json.dumps(batch_request_data, indent=4)}")
    batch_tasks_id = batch_request_data['batch_tasks_id']
    user_id = batch_request_data['user_id']
    result_data_list = []
    # Fixed negative prompt and seed applied to every item in the batch.
    negative_prompt = 'lowres, bad anatomy, bad hands, cropped, worst quality'
    seed = "1"
    for i, data in enumerate(batch_request_data['batch_data_list']):
        direction = data['direction']
        prompt = data['prompt']
        product_type = data['product_type']
        image_url = data['image_url']
        image = pre_processing_image(image_url)
        tasks_id = data['tasks_id']
        prompts = [prompt] * 1
        # NOTE(review): pre_processing_image returns a 4-channel RGBA array,
        # while COLOR_BGR2RGB expects 3 channels (the product task uses
        # COLOR_RGBA2RGB here) — confirm this conversion does not raise.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = cv2.resize(image, (512, 768))
        images = [image.astype(np.uint8)] * 1
        seeds = [seed] * 1
        nagetive_prompts = [negative_prompt] * 1
        directions = [direction] * 1
        if product_type == 'single':
            # Single model takes batched tensors (leading batch dimension).
            text_obj = np.array(prompts, dtype="object").reshape((-1, 1))
            image_obj = np.array(images, dtype=np.uint8).reshape((-1, 768, 512, 3))
            na_text_obj = np.array(nagetive_prompts, dtype="object").reshape((-1, 1))
            seed_obj = np.array(seeds, dtype="object").reshape((-1, 1))
            direction_obj = np.array(directions, dtype="object").reshape((-1, 1))
        else:
            # Overall model takes unbatched tensors.
            text_obj = np.array(prompts, dtype="object").reshape((1))
            image_obj = np.array(images, dtype=np.uint8).reshape((768, 512, 3))
            na_text_obj = np.array(nagetive_prompts, dtype="object").reshape((1))
            seed_obj = np.array(seeds, dtype="object").reshape((1))
            direction_obj = np.array(directions, dtype="object").reshape((1))
        input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype))
        input_image = grpcclient.InferInput("input_image", image_obj.shape, "UINT8")
        input_natext = grpcclient.InferInput("negative_prompt", na_text_obj.shape, np_to_triton_dtype(na_text_obj.dtype))
        input_seed = grpcclient.InferInput("seed", seed_obj.shape, np_to_triton_dtype(seed_obj.dtype))
        input_direction = grpcclient.InferInput("direction", direction_obj.shape, np_to_triton_dtype(direction_obj.dtype))
        input_text.set_data_from_numpy(text_obj)
        input_image.set_data_from_numpy(image_obj)
        input_natext.set_data_from_numpy(na_text_obj)
        input_seed.set_data_from_numpy(seed_obj)
        input_direction.set_data_from_numpy(direction_obj)
        inputs = [input_text, input_natext, input_image, input_seed, input_direction]
        try:
            if data['product_type'] == "single":
                result = grpc_client.infer(model_name=GRI_MODEL_NAME_SINGLE, inputs=inputs, priority=100)
                image = result.as_numpy("generated_relight_image")
            else:
                result = grpc_client.infer(model_name=GRI_MODEL_NAME_OVERALL, inputs=inputs, priority=100)
                image = result.as_numpy("generated_inpaint_image")
            image_result = Image.fromarray(np.squeeze(image.astype(np.uint8)))
        except Exception as e:
            print(e)
            if 'mask_list' in str(e):
                # The server rejected the request complaining about 'mask_list':
                # rebuild the inputs in batched form and retry against the
                # single relight model.
                # NOTE(review): if this retry raises too, the whole batch task
                # aborts mid-loop — confirm that is acceptable.
                e_text_obj = np.array(prompts, dtype="object").reshape((-1, 1))
                e_image_obj = np.array(images, dtype=np.uint8).reshape((-1, 768, 512, 3))
                e_na_text_obj = np.array(nagetive_prompts, dtype="object").reshape((-1, 1))
                e_seed_obj = np.array(seeds, dtype="object").reshape((-1, 1))
                e_direction_obj = np.array(directions, dtype="object").reshape((-1, 1))
                e_input_text = grpcclient.InferInput("prompt", e_text_obj.shape, np_to_triton_dtype(e_text_obj.dtype))
                e_input_image = grpcclient.InferInput("input_image", e_image_obj.shape, "UINT8")
                e_input_natext = grpcclient.InferInput("negative_prompt", e_na_text_obj.shape, np_to_triton_dtype(e_na_text_obj.dtype))
                e_input_seed = grpcclient.InferInput("seed", e_seed_obj.shape, np_to_triton_dtype(e_seed_obj.dtype))
                e_input_direction = grpcclient.InferInput("direction", e_direction_obj.shape, np_to_triton_dtype(e_direction_obj.dtype))
                e_input_text.set_data_from_numpy(e_text_obj)
                e_input_image.set_data_from_numpy(e_image_obj)
                e_input_natext.set_data_from_numpy(e_na_text_obj)
                e_input_seed.set_data_from_numpy(e_seed_obj)
                e_input_direction.set_data_from_numpy(e_direction_obj)
                e_inputs = [e_input_text, e_input_natext, e_input_image, e_input_seed, e_input_direction]
                result = grpc_client.infer(model_name=GRI_MODEL_NAME_SINGLE, inputs=e_inputs, priority=100)
                image = result.as_numpy("generated_relight_image")
                image_result = Image.fromarray(np.squeeze(image.astype(np.uint8)))
            else:
                # Any other failure: keep the error text in place of an image.
                image_result = str(e)
                logger.error(e)
        if isinstance(image_result, Image.Image):
            image_url = upload_SDXL_image(image_result, user_id=user_id, category=f"{category}", file_name=f"{tasks_id}-batch-{i}.png")
            data['relight_img'] = image_url
            result_data_list.append(data)
        else:
            # Generation failed: the error text is propagated as the "URL".
            image_url = image_result
            data['relight_img'] = image_url
            result_data_list.append(data)
        # Publish per-item progress (log/print only when DEBUG).
        if DEBUG:
            logger.info(f" [x]Queue : {BATCH_GRI_RABBITMQ_QUEUES} | tasks_id{tasks_id} | progress{i + 1}/{batch_size} | result_data{data}")
            print(f" [x]Queue : {BATCH_GRI_RABBITMQ_QUEUES} | tasks_id{tasks_id} | progress{i + 1}/{batch_size} | result_data{data}")
        else:
            publish_status(tasks_id, f"{i + 1}/{batch_size}", data)
            logger.info(f" [x]Queue : {BATCH_GRI_RABBITMQ_QUEUES} | tasks_id{tasks_id} | progress{i + 1}/{batch_size} | result_data{data}")
    # Batch finished: publish the aggregated result list.
    if DEBUG:
        print(result_data_list)
        logger.info(f" [x]Queue : {BATCH_GRI_RABBITMQ_QUEUES} | batch_tasks_id{batch_tasks_id} | progressOK | result_data_list{result_data_list}")
        print(f" [x]Queue : {BATCH_GRI_RABBITMQ_QUEUES} | batch_tasks_id{batch_tasks_id} | progressOK | result_data_list{result_data_list}")
    else:
        publish_status(batch_tasks_id, f"OK", result_data_list)
        logger.info(f" [x]Queue : {BATCH_GRI_RABBITMQ_QUEUES} | batch_tasks_id{batch_tasks_id} | progressOK | result_data_list{result_data_list}")
def pre_processing_image(image_url):
    """Fetch an image from OSS and letterbox it onto a 512x768 transparent canvas.

    The source is scaled with its aspect ratio preserved so it fits inside the
    512x768 target, then pasted centered onto a fully transparent RGBA canvas.

    :param image_url: "<bucket>/<object_name>" style OSS path.
    :return: numpy array of the RGBA canvas (height 768, width 512, 4 channels).
    """
    source = oss_get_image(bucket=image_url.split('/')[0],
                           object_name=image_url[image_url.find('/') + 1:],
                           data_type="PIL")
    canvas_w, canvas_h = 512, 768
    src_w, src_h = source.size
    # Uniform scale factor so the whole picture fits inside the canvas.
    ratio = min(canvas_w / src_w, canvas_h / src_h)
    scaled = source.resize((int(src_w * ratio), int(src_h * ratio)))
    # Transparent canvas; the scaled image gets pasted centered on it.
    canvas = Image.new("RGBA", (canvas_w, canvas_h), (255, 255, 255, 0))
    paste_x = (canvas_w - scaled.width) // 2
    paste_y = (canvas_h - scaled.height) // 2
    if scaled.mode == "RGBA":
        # Use the alpha channel as the paste mask to keep transparency.
        canvas.paste(scaled, (paste_x, paste_y), mask=scaled.split()[3])
    else:
        canvas.paste(scaled, (paste_x, paste_y))
    return np.array(canvas)
def publish_status(task_id, progress, result):
    """Publish a progress message for *task_id* to the batch-relight queue.

    Opens a short-lived RabbitMQ connection, declares the durable queue, and
    publishes a persistent (delivery_mode=2) JSON payload.

    :param task_id: task identifier echoed back to the consumer.
    :param progress: progress string, e.g. "2/4" or "OK".
    :param result: JSON-serializable result payload for this step.
    """
    message = {'task_id': task_id, 'progress': progress, "result": result}
    connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
    try:
        channel = connection.channel()
        channel.queue_declare(queue=BATCH_GRI_RABBITMQ_QUEUES, durable=True)
        channel.basic_publish(exchange='',
                              routing_key=BATCH_GRI_RABBITMQ_QUEUES,
                              body=json.dumps(message),
                              properties=pika.BasicProperties(
                                  delivery_mode=2,  # persist message to disk
                              ))
    finally:
        # Always release the connection, even if declare/publish raised
        # (previously a failure here leaked the connection).
        connection.close()
if __name__ == '__main__':
    # Manual smoke test: build a two-item relight batch request and run the
    # batch pipeline synchronously (no Celery broker involved).
    rd = BatchGenerateRelightImageModel(
        batch_tasks_id="abcd",
        user_id="89",
        batch_data_list=[
            RelightItemModel(
                tasks_id="123-5464",
                product_type="overall",
                image_url="test/703190759.png",
                prompt="Colorful black",
                direction="Right Light",
            ),
            RelightItemModel(
                tasks_id="123-5464123",
                product_type="overall",
                image_url="test/703190759.png",
                direction="Right Light",
                prompt="Colorful black",
            )
        ]
    )
    batch_generate_relight(rd.dict())
# X = {
# "batch_tasks_id": "abcd",
# "user_id": "89",
# "batch_data_list": [
# {
# "tasks_id": "123-5464",
# "product_type": "overall",
# "image_url": "aida-users/89/product_image/02894523-19b5-46eb-a9c6-2f512f5fec84-0-89.png",
# "prompt": "Colorful black",
# "direction": "Right Light",
# },
# {
# "tasks_id": "123-5464",
# "product_type": "overall",
# "image_url": "aida-users/89/product_image/02894523-19b5-46eb-a9c6-2f512f5fec84-0-89.png",
# "prompt": "Colorful black",
# "direction": "Right Light",
# }
#
# ]
# }

View File

@@ -0,0 +1,176 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project trinity_client
@File service_att_recognition.py
@Author :周成融
@Date 2023/7/26 12:01:05
@detail
"""
import io
import json
import logging
from io import BytesIO
import imageio
import numpy as np
import tritonclient.grpc as grpcclient
from PIL import Image
from celery import Celery
from minio import Minio
from tritonclient.utils import np_to_triton_dtype
from app.core.config import *
from app.schemas.pose_transform import BatchPoseTransformModel
from app.service.generate_image.utils.pose_transform_upload import upload_gif, upload_video
from app.service.utils.new_oss_client import oss_upload_image
from app.service.utils.oss_client import oss_get_image
minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
logger = logging.getLogger()
celery_app = Celery('post_transform_tasks', broker=f'amqp://rabbit:123456@18.167.251.121:5672//', backend='rpc://', BROKER_CONNECTION_RETRY_ON_STARTUP=True)
celery_app.conf.task_default_queue = 'queue_post_transform'
celery_app.conf.worker_log_format = '%(asctime)s %(filename)s [line:%(lineno)d] %(levelname)s %(message)s'
celery_app.conf.worker_hijack_root_logger = False
logging.getLogger('pika').setLevel(logging.WARNING)
grpc_client = grpcclient.InferenceServerClient(url=PT_MODEL_URL)
category = "pose_transform"
def upload_first_image(image, user_id, category, file_name):
    """Upload a PIL image as PNG to OSS and return its "aida-users/..." URL.

    :param image: PIL.Image to upload.
    :param user_id: owner id, used as the top-level object prefix.
    :param category: sub-folder under the user prefix.
    :param file_name: object file name (should end in ``.png``).
    :return: "aida-users/<user_id>/<category>/<file_name>" on success, or
        None if the upload failed (best-effort: the error is only logged).
    """
    try:
        image_data = io.BytesIO()
        image.save(image_data, format='PNG')
        image_bytes = image_data.getvalue()
        object_name = f'{user_id}/{category}/{file_name}'
        oss_upload_image(oss_client=minio_client, bucket=GI_MINIO_BUCKET, object_name=object_name, image_bytes=image_bytes)
        return f"aida-users/{object_name}"
    except Exception as e:
        # Deliberately swallow: a failed thumbnail upload must not abort the
        # surrounding batch task. (Previous message said "upload_png_mask",
        # a copy-paste leftover; use the module logger, not the root logger.)
        logger.warning(f"upload_first_image runtime exception : {e}")
def pre_processing_image(image_url):
    """Download an OSS image and fit it onto a centered 512x768 RGB canvas.

    The picture is resized with its aspect ratio intact so it fits inside
    512x768, pasted centered onto a transparent canvas, then flattened to RGB
    for the model input.

    :param image_url: "<bucket>/<object_name>" OSS path.
    :return: numpy array shaped (768, 512, 3).
    """
    pil_img = oss_get_image(bucket=image_url.split('/')[0],
                            object_name=image_url[image_url.find('/') + 1:],
                            data_type="PIL")
    target_w, target_h = 512, 768
    w, h = pil_img.size
    # Pick the smaller ratio so the whole image fits inside the target box.
    scale = min(target_w / w, target_h / h)
    fitted = pil_img.resize((int(w * scale), int(h * scale)))
    background = Image.new("RGBA", (target_w, target_h), (255, 255, 255, 0))
    left = (target_w - fitted.width) // 2
    top = (target_h - fitted.height) // 2
    if fitted.mode == "RGBA":
        # Paste through the alpha channel so transparency is respected.
        background.paste(fitted, (left, top), mask=fitted.split()[3])
    else:
        background.paste(fitted, (left, top))
    # Drop alpha: the downstream reshape expects exactly 3 channels.
    return np.array(background.convert("RGB"))
@celery_app.task
def batch_generate_pose_transform(batch_request_data):
    """Celery task: run N pose-transform inferences for one source image.

    For each of ``batch_size`` rounds the Triton ``animatex_1`` model is
    invoked; the resulting frame stack is uploaded as a first-frame PNG, a
    GIF and an MP4.  Per-round progress and, on the last round, the full
    result list are published to ``BATCH_PS_RABBITMQ_QUEUES`` unless DEBUG
    is on.

    :param batch_request_data: dict with keys ``batch_size``, ``image_url``,
        ``pose_id`` and ``tasks_id`` (``tasks_id`` must end in ``-<user_id>``).
    """
    logger.info(f"batch_generate_pose_transform batch_request_data: {json.dumps(batch_request_data, indent=4)}")
    batch_size = batch_request_data['batch_size']
    image_url = batch_request_data['image_url']
    image = pre_processing_image(image_url)
    pose_num = batch_request_data['pose_id']
    tasks_id = batch_request_data['tasks_id']
    user_id = tasks_id.rsplit('-', 1)[1]
    # Triton expects batched tensors: object dtype for the pose id, raw
    # uint8 (batch, 768, 512, 3) for the image.
    pose_num = [pose_num] * 1
    pose_num_obj = np.array(pose_num, dtype="object").reshape((-1, 1))
    input_pose_num = grpcclient.InferInput("pose_num", pose_num_obj.shape, np_to_triton_dtype(pose_num_obj.dtype))
    input_pose_num.set_data_from_numpy(pose_num_obj)
    image_files = [image.astype(np.uint8)] * 1
    image_files_obj = np.array(image_files, dtype=np.uint8).reshape((-1, 768, 512, 3))
    input_image_files = grpcclient.InferInput("image_file", image_files_obj.shape, "UINT8")
    input_image_files.set_data_from_numpy(image_files_obj)
    result_url_list = []
    for i in range(batch_size):
        try:
            result = grpc_client.infer(model_name="animatex_1", inputs=[input_pose_num, input_image_files], client_timeout=60000, priority=100)
            # Frame stack from the model; last-axis flip converts BGR -> RGB.
            result_data = np.squeeze(result.as_numpy("generated_image_list").astype(np.uint8))[:, :, :, ::-1]
            # First frame as a standalone preview image.
            first_image = Image.fromarray(result_data[0])
            first_image_url = upload_first_image(first_image, user_id=user_id, category=f"{category}_first_img", file_name=f"{tasks_id}_batch_{i}.png")
            # Upload the GIF.
            gif_buffer = BytesIO()
            imageio.mimsave(gif_buffer, result_data, format='GIF', fps=5)
            gif_buffer.seek(0)
            gif_url = upload_gif(gif_buffer=gif_buffer, user_id=user_id, category=f"{category}_gif", file_name=f"{tasks_id}_batch_{i}.gif")
            # Upload the video.
            video_url = upload_video(frames=result_data, user_id=user_id, category=f"{category}_video", file_name=f"{tasks_id}_batch_{i}.mp4")
            data = {
                "gif_url": gif_url,
                "video_url": video_url,
                "first_image_url": first_image_url,
            }
        except Exception as e:
            # Best-effort: a failed round yields an empty entry instead of
            # aborting the batch. Log with traceback instead of print() so
            # the error reaches the configured Celery worker logs.
            logger.exception(f"batch_generate_pose_transform round {i} failed: {e}")
            data = {}
        result_url_list.append(data)
        if not DEBUG:
            if i + 1 < batch_size:
                publish_status(tasks_id, f"{i + 1}/{batch_size}", data)
                logger.info(f" [x]Queue : {BATCH_PS_RABBITMQ_QUEUES} | tasks_id{tasks_id} | progress{i + 1}/{batch_size} | image_url{image_url}")
            else:
                # Final round: publish the accumulated list of all results.
                publish_status(tasks_id, f"OK", result_url_list)
                logger.info(f" [x]Queue : {BATCH_PS_RABBITMQ_QUEUES} | tasks_id{tasks_id} | progressOK | image_url{image_url}")
def publish_status(task_id, progress, result):
    """Publish a pose-transform progress update to the batch queue.

    A short-lived RabbitMQ connection is opened per call; the message is
    marked persistent (delivery_mode=2) on a durable queue.

    NOTE(review): ``pika`` is not imported explicitly in this module — it
    presumably arrives via ``from app.core.config import *``; verify.

    :param task_id: task identifier echoed back to the consumer.
    :param progress: progress string, e.g. "2/10" or "OK".
    :param result: JSON-serializable payload for this step.
    """
    message = {'task_id': task_id, 'progress': progress, "result": result}
    connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
    try:
        channel = connection.channel()
        channel.queue_declare(queue=BATCH_PS_RABBITMQ_QUEUES, durable=True)
        channel.basic_publish(exchange='',
                              routing_key=BATCH_PS_RABBITMQ_QUEUES,
                              body=json.dumps(message),
                              properties=pika.BasicProperties(
                                  delivery_mode=2,  # persist message to disk
                              ))
    finally:
        # Close even on publish failure (previously leaked the connection).
        connection.close()
if __name__ == '__main__':
    # Manual smoke test: run one 10-round pose-transform batch directly
    # (calling the task function synchronously, bypassing the broker).
    rd = BatchPoseTransformModel(
        tasks_id="123-89",
        image_url='aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png',
        pose_id="1",
        batch_size=10
    )
    batch_generate_pose_transform(rd.dict())

View File

@@ -0,0 +1,28 @@
# import logging
#
# from celery import Celery
#
# from app.service.generate_batch_image.service_batch_generate_product_image import batch_generate_product
# from app.service.generate_batch_image.service_batch_generate_relight_image import batch_generate_relight
# from app.service.generate_batch_image.service_batch_pose_transform import batch_generate_pose_transform
#
# logger = logging.getLogger()
# celery_app = Celery('tasks', broker=f'amqp://rabbit:123456@18.167.251.121:5672//', backend='rpc://', BROKER_CONNECTION_RETRY_ON_STARTUP=True)
# celery_app.conf.worker_log_format = '%(asctime)s %(filename)s [line:%(lineno)d] %(levelname)s %(message)s'
# celery_app.conf.worker_hijack_root_logger = False
# logging.getLogger('pika').setLevel(logging.WARNING)
#
#
# @celery_app.task
# def batch_pose_transform_tasks(batch_request_data):
# batch_generate_pose_transform(batch_request_data)
#
#
# @celery_app.task
# def batch_generate_relight_tasks(batch_request_data):
# batch_generate_relight(batch_request_data)
#
#
# @celery_app.task
# def batch_generate_product_tasks(batch_request_data):
# batch_generate_product(batch_request_data)

View File

@@ -0,0 +1,36 @@
from app.schemas.generate_image import BatchGenerateRelightImageModel, BatchGenerateProductImageModel
from app.service.generate_batch_image.service_batch_generate_product_image import batch_generate_product
from app.service.generate_batch_image.service_batch_generate_relight_image import batch_generate_relight
if __name__ == '__main__':
    # Manual smoke test: enqueue a batch product-image generation job on the
    # Celery broker and print the AsyncResult handle.
    rd = BatchGenerateProductImageModel(
        tasks_id="test1-89",
        image_strength=0.7,
        prompt=" The best quality, masterpiece, real image.Outwear,high quality clothing details,8K realistic,HDR",
        image_url="aida-results/result_40b1a2fe-e220-11ef-9bfa-0242ac150003.png",
        product_type="single",
        batch_size=2
    )
    x = batch_generate_product.delay(rd.dict())
    print(x)
    """relight"""
    # rd = BatchGenerateRelightImageModel(
    #     tasks_id="123-89",
    #     # prompt="beautiful woman, detailed face, sunshine, outdoor, warm atmosphere",
    #     prompt="Colorful black",
    #     image_url='aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png',
    #     direction="Right Light",
    #     product_type="single",
    #     batch_size=2
    # )
    # batch_generate_relight.delay(rd.dict())
    """pose transform"""
    # rd = BatchPoseTransformModel(
    #     tasks_id="123-89",
    #     image_url='aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png',
    #     pose_id="1",
    #     batch_size=10
    # )
    # batch_pose_transform_tasks.delay(rd.dict())

View File

@@ -0,0 +1,149 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project trinity_client
@File service_att_recognition.py
@Author :周成融
@Date 2023/7/26 12:01:05
@detail
"""
import logging
import time
import uuid
import cv2
import mmcv
import numpy as np
import pandas as pd
import torch
import tritonclient.http as httpclient
import cv2
import numpy as np
import tritonclient.grpc as grpcclient
from minio import Minio
from tritonclient.utils import np_to_triton_dtype
from app.core.config import *
from app.service.utils.new_oss_client import oss_upload_image
logger = logging.getLogger()
minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
class AgentToolGenerateImage:
    """Agent tool: generate design images via Triton and classify sketches.

    Wraps two Triton endpoints: a gRPC image generator (a "fast" or default
    URL picked at construction time) and an HTTP attribute model used to
    infer the clothing category of generated sketch images.

    NOTE(review): ``get_result`` closes both clients in its ``finally``
    block, so an instance is effectively single-use — confirm this is
    intended before reusing instances.
    """

    def __init__(self, version):
        # "fast" selects the low-latency generator endpoint; anything else
        # falls back to the default endpoint.
        if version == "fast":
            self.grpc_client = grpcclient.InferenceServerClient(url=FAST_GI_MODEL_URL)
        else:
            self.grpc_client = grpcclient.InferenceServerClient(url=GI_MODEL_URL)
        # Random placeholder image: txt2img mode still requires an image input.
        self.image = np.random.randint(0, 256, (1024, 1024, 3), dtype=np.uint8)
        self.triton_client = httpclient.InferenceServerClient(url=DESIGN_MODEL_URL)

    def get_result(self, prompt, size, version, category, gender):
        """Generate *size* images for *prompt*, upload them, classify sketches.

        :param prompt: text prompt for the generator.
        :param size: number of images to generate (one infer call each).
        :param version: "fast" selects the fast model name per call.
        :param category: when "sketch", the generated images are also run
            through the clothing-category classifier.
        :param gender: controls the category mapping in
            :meth:`get_clothing_category`.
        :return: ``(image_url_list, clothing_category_list)``; on exception
            the lists accumulated so far are returned (error only logged).
        """
        image_url_list = []
        image_result_list = []
        clothing_category_list = []
        try:
            # Build the three Triton inputs once; they are reused per round.
            prompts = [prompt] * 1
            modes = ["txt2img"] * 1
            images = [self.image.astype(np.float16)] * 1
            text_obj = np.array(prompts, dtype="object").reshape((-1, 1))
            mode_obj = np.array(modes, dtype="object").reshape((-1, 1))
            image_obj = np.array(images, dtype=np.float16).reshape((-1, 1024, 1024, 3))
            input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype))
            input_image = grpcclient.InferInput("input_image", image_obj.shape, np_to_triton_dtype(image_obj.dtype))
            input_mode = grpcclient.InferInput("mode", mode_obj.shape, np_to_triton_dtype(mode_obj.dtype))
            input_text.set_data_from_numpy(text_obj)
            input_image.set_data_from_numpy(image_obj)
            input_mode.set_data_from_numpy(mode_obj)
            inputs = [input_text, input_image, input_mode]
            for i in range(size):
                if version == "fast":
                    response = self.grpc_client.infer(model_name=FAST_GI_MODEL_NAME, inputs=inputs, priority=0)
                else:
                    response = self.grpc_client.infer(model_name=GI_MODEL_NAME, inputs=inputs, priority=0)
                image = response.as_numpy("generated_image")
                # Model output is RGB; cv2 encodes BGR, hence the conversion.
                image_result = cv2.cvtColor(np.squeeze(image.astype(np.uint8)), cv2.COLOR_RGB2BGR)
                _, img_byte_array = cv2.imencode('.jpg', image_result)
                req = oss_upload_image(oss_client=minio_client, bucket='test', object_name=f'{uuid.uuid1()}-{i}.jpg', image_bytes=img_byte_array)
                image_url_list.append(f"{req.bucket_name}/{req.object_name}")
                image_result_list.append(image_result)
            if category == "sketch":
                clothing_category_list = self.get_clothing_category(image_result_list, gender)
            return image_url_list, clothing_category_list
        except Exception as e:
            logger.error(e)
            return image_url_list, clothing_category_list
        finally:
            # See class NOTE: clients are closed unconditionally here.
            self.grpc_client.close()
            self.triton_client.close()

    def preprocess(self, img):
        """Resize + normalize an image to the classifier's (1, 3, 224, 224) input."""
        img = mmcv.imread(img)
        img_scale = (224, 224)
        img = cv2.resize(img, img_scale)
        img = mmcv.imnormalize(
            img,
            mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]),
            to_rgb=True)
        # Channel-first plus batch dimension.
        preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
        return preprocessed_img

    def get_category(self, image):
        """Run the attribute model on one preprocessed image, return its label.

        Picks the label with the highest score among the first five classes
        of the ``attr_type`` table.
        """
        inputs = [httpclient.InferInput("input__0", image.shape, datatype="FP32")]
        inputs[0].set_data_from_numpy(image, binary_data=True)
        results = self.triton_client.infer(model_name="attr_retrieve_category", inputs=inputs)
        inference_output = torch.from_numpy(results.as_numpy(f'output__0'))
        scores = inference_output.detach().numpy()
        colattr = list(attr_type['labelName'])
        # Only the first 5 scores are considered valid category classes.
        maxsc = np.max(scores[0][:5])
        indexs = np.argwhere(scores == maxsc)[:, 1]
        return colattr[indexs[0]]

    def get_clothing_category(self, images, gender):
        """Classify each image; for "male" collapse labels into coarse groups.

        :param images: list of BGR image arrays (as produced by get_result).
        :param gender: "female" keeps raw labels; "male" maps to
            Bottoms/Tops/Outwear; anything else keeps raw labels.
        :return: list of category strings, one per input image.
        """
        category_list = []
        for image in images:
            sketch = self.preprocess(image)
            if gender.lower() == "female":
                category_list.append(self.get_category(sketch))
            elif gender.lower() == "male":
                category = self.get_category(sketch)
                if category == 'Trousers' or category == 'Skirt':
                    category_list.append('Bottoms')
                elif category == 'Blouse' or category == 'Dress':
                    category_list.append('Tops')
                else:
                    category_list.append('Outwear')
            else:
                category_list.append(self.get_category(sketch))
        return category_list
# Label table for the attribute classifier; get_category indexes 'labelName'.
attr_type = pd.read_csv(CATEGORY_PATH)


if __name__ == '__main__':
    # Manual smoke test: generate two sketch images and print their URLs
    # plus the inferred clothing categories.
    request_data = {
        "prompt": "a single item of sketch of Wabi-sabi, skirt, tiered, 4k, white background",
        "category": "sketch",
        "version": "high",
        "size": 2,
        "gender": "Female",
    }
    server = AgentToolGenerateImage(request_data['version'])
    image_url_list, clothing_category_list = server.get_result(
        prompt=request_data['prompt'],
        size=request_data['size'],
        version=request_data['version'],
        category=request_data['category'],
        gender=request_data['gender']
    )
    print(image_url_list)
    print(clothing_category_list)

View File

@@ -21,6 +21,7 @@ from tritonclient.utils import np_to_triton_dtype
from app.core.config import *
from app.schemas.generate_image import GenerateImageModel
from app.service.generate_image.utils.image_processing import remove_background, stain_detection, generate_category_recognition, autoLevels, luminance_adjust
from app.service.generate_image.utils.mq import publish_status
from app.service.generate_image.utils.upload_sd_image import upload_png_sd
from app.service.utils.oss_client import oss_get_image
@@ -29,12 +30,6 @@ logger = logging.getLogger()
class GenerateImage:
def __init__(self, request_data):
if DEBUG is False:
self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
self.channel = self.connection.channel()
# self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
# self.channel = self.connection.channel()
# self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
self.version = request_data.version
if request_data.version == "fast":
self.grpc_client = grpcclient.InferenceServerClient(url=FAST_GI_MODEL_URL)
@@ -153,15 +148,14 @@ class GenerateImage:
inputs = [input_text, input_image, input_mode]
if self.version == "fast":
ctx = self.grpc_client.async_infer(model_name=FAST_GI_MODEL_NAME, inputs=inputs, callback=self.callback)
ctx = self.grpc_client.async_infer(model_name=FAST_GI_MODEL_NAME, inputs=inputs, callback=self.callback, priority=1)
else:
ctx = self.grpc_client.async_infer(model_name=GI_MODEL_NAME, inputs=inputs, callback=self.callback)
ctx = self.grpc_client.async_infer(model_name=GI_MODEL_NAME, inputs=inputs, callback=self.callback, priority=1)
time_out = 600
generate_data = None
while time_out > 0:
generate_data, _ = self.read_tasks_status()
# logger.info(generate_data)
if generate_data['status'] in ["REVOKED", "FAILURE"]:
ctx.cancel()
break
@@ -169,7 +163,6 @@ class GenerateImage:
break
time_out -= 1
time.sleep(0.1)
# logger.info(time_out, generate_data)
return generate_data
except Exception as e:
self.generate_data['status'] = "FAILURE"
@@ -178,10 +171,8 @@ class GenerateImage:
raise Exception(str(e))
finally:
dict_generate_data, str_generate_data = self.read_tasks_status()
if DEBUG is False:
self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=str_generate_data)
# self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=str_generate_data)
logger.info(f" [x] Sent {json.dumps(dict_generate_data, indent=4)}")
if not DEBUG:
publish_status(str_generate_data, GI_RABBITMQ_QUEUES)
def infer_cancel(tasks_id):
@@ -195,7 +186,7 @@ def infer_cancel(tasks_id):
if __name__ == '__main__':
rd = GenerateImageModel(
tasks_id="123-89",
prompt='a single item of sketch of Wabi-sabi, skirt, tiered, 4k, white background',
prompt="Women's clothing ,dress,technical drawing style, clean line art, no shading, no texture, flat sketch, no human body, no face, centered composition, pure white background, single garmentsingle garment only, front flat view",
image_url="aida-collection-element/87/Printboard/842c09cf-7297-42d9-9e6e-9c17d4a13cb5.jpg",
mode='txt2img',
category="test",

View File

@@ -17,6 +17,7 @@ import tritonclient.grpc as grpcclient
from app.core.config import *
from app.schemas.generate_image import GenerateMultiViewModel
from app.service.generate_image.utils.mq import publish_status
from app.service.generate_image.utils.upload_sd_image import upload_png_sd
from app.service.utils.oss_client import oss_get_image
@@ -25,14 +26,7 @@ logger = logging.getLogger()
class GenerateMultiView:
def __init__(self, request_data):
if DEBUG is False:
self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
self.channel = self.connection.channel()
# self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
# self.channel = self.connection.channel()
# self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
self.grpc_client = grpcclient.InferenceServerClient(url=GMV_MODEL_URL)
self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
self.image = self.get_image(request_data.image_url)
self.tasks_id = request_data.tasks_id
@@ -52,16 +46,11 @@ class GenerateMultiView:
if error:
self.generate_data['status'] = "FAILURE"
self.generate_data['message'] = str(error)
# self.generate_data['data'] = str(error)
self.redis_client.set(self.tasks_id, json.dumps(self.generate_data))
else:
# pil图像转成numpy数组
images = result.as_numpy("generated_image")
# for id, img in enumerate(images):
# cv2.imwrite(f"{id}.png", img)
# image_url = ""
image_url = upload_png_sd(images[6], user_id=self.user_id, category="multi_view", file_name=f"{self.tasks_id}.png")
# logger.info(f"upload image SUCCESS {image_url}")
self.generate_data['status'] = "SUCCESS"
self.generate_data['message'] = "success"
self.generate_data['image_url'] = str(image_url)
@@ -103,10 +92,8 @@ class GenerateMultiView:
raise Exception(str(e))
finally:
dict_generate_data, str_generate_data = self.read_tasks_status()
if DEBUG is False:
self.channel.basic_publish(exchange='', routing_key=GMV_RABBITMQ_QUEUES, body=str_generate_data)
# self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=str_generate_data)
logger.info(f" [x] Sent {json.dumps(dict_generate_data, indent=4)}")
if not DEBUG:
publish_status(str_generate_data, GMV_RABBITMQ_QUEUES)
def infer_cancel(tasks_id):

View File

@@ -212,6 +212,7 @@ from tritonclient.utils import np_to_triton_dtype
from app.core.config import *
from app.schemas.generate_image import GenerateProductImageModel
from app.service.generate_image.utils.mq import publish_status
from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image
from app.service.utils.oss_client import oss_get_image
@@ -220,12 +221,6 @@ logger = logging.getLogger()
class GenerateProductImage:
def __init__(self, request_data):
if DEBUG is False:
self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
self.channel = self.connection.channel()
# self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
# self.channel = self.connection.channel()
# self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
self.grpc_client = grpcclient.InferenceServerClient(url=GPI_MODEL_URL)
self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
self.category = "product_image"
@@ -295,9 +290,9 @@ class GenerateProductImage:
inputs = [input_text, input_image, input_image_strength]
if self.product_type == "single":
ctx = self.grpc_client.async_infer(model_name=GPI_MODEL_NAME_SINGLE, inputs=inputs, callback=self.callback)
ctx = self.grpc_client.async_infer(model_name=GPI_MODEL_NAME_SINGLE, inputs=inputs, callback=self.callback, priority=1)
else:
ctx = self.grpc_client.async_infer(model_name=GPI_MODEL_NAME_OVERALL, inputs=inputs, callback=self.callback)
ctx = self.grpc_client.async_infer(model_name=GPI_MODEL_NAME_OVERALL, inputs=inputs, callback=self.callback, priority=1)
time_out = 600
while time_out > 0:
@@ -318,9 +313,8 @@ class GenerateProductImage:
raise Exception(str(e))
finally:
dict_gen_product_data, str_gen_product_data = self.read_tasks_status()
if DEBUG is False:
self.channel.basic_publish(exchange='', routing_key=GPI_RABBITMQ_QUEUES, body=str_gen_product_data)
logger.info(f" [x] Sent to {GPI_RABBITMQ_QUEUES} data@@@@ {json.dumps(dict_gen_product_data, indent=4)}")
if not DEBUG:
publish_status(str_gen_product_data, GPI_RABBITMQ_QUEUES)
def infer_cancel(tasks_id):

View File

@@ -20,6 +20,7 @@ from tritonclient.utils import np_to_triton_dtype
from app.core.config import *
from app.schemas.generate_image import GenerateRelightImageModel
from app.service.generate_image.utils.mq import publish_status
from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image
from app.service.utils.oss_client import oss_get_image
@@ -28,10 +29,6 @@ logger = logging.getLogger()
class GenerateRelightImage:
def __init__(self, request_data):
if DEBUG is False:
self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
self.channel = self.connection.channel()
# self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
self.grpc_client = grpcclient.InferenceServerClient(url=GRI_MODEL_URL)
self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
self.category = "relight_image"
@@ -42,7 +39,7 @@ class GenerateRelightImage:
self.negative_prompt = 'lowres, bad anatomy, bad hands, cropped, worst quality'
self.direction = request_data.direction
self.image_url = request_data.image_url
self.image = oss_get_image(bucket=self.image_url.split('/')[0], object_name=self.image_url[self.image_url.find('/') + 1:], data_type="cv2")
self.image = pre_processing_image(self.image_url)
self.tasks_id = request_data.tasks_id
self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
self.gen_product_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': ''}
@@ -114,9 +111,9 @@ class GenerateRelightImage:
inputs = [input_text, input_natext, input_image, input_seed, input_direction]
if self.product_type == 'single':
ctx = self.grpc_client.async_infer(model_name=GRI_MODEL_NAME_SINGLE, inputs=inputs, callback=self.callback)
ctx = self.grpc_client.async_infer(model_name=GRI_MODEL_NAME_SINGLE, inputs=inputs, callback=self.callback, priority=1)
else:
ctx = self.grpc_client.async_infer(model_name=GRI_MODEL_NAME_OVERALL, inputs=inputs, callback=self.callback)
ctx = self.grpc_client.async_infer(model_name=GRI_MODEL_NAME_OVERALL, inputs=inputs, callback=self.callback, priority=1)
time_out = 600
while time_out > 0:
@@ -137,10 +134,49 @@ class GenerateRelightImage:
raise Exception(str(e))
finally:
dict_gen_product_data, str_gen_product_data = self.read_tasks_status()
if DEBUG is False:
self.channel.basic_publish(exchange='', routing_key=GRI_RABBITMQ_QUEUES, body=str_gen_product_data)
logger.info(f" [x] Sent to {GRI_RABBITMQ_QUEUES} data@@@@ {json.dumps(dict_gen_product_data, indent=4)}")
if not DEBUG:
publish_status(str_gen_product_data, GRI_RABBITMQ_QUEUES)
def pre_processing_image(image_url):
image = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url[image_url.find('/') + 1:], data_type="PIL")
# 目标图片的尺寸
target_width = 512
target_height = 768
# 原始图片的尺寸
original_width, original_height = image.size
# 计算宽度和高度的缩放比例
width_ratio = target_width / original_width
height_ratio = target_height / original_height
# 选择较小的缩放比例,确保图片能完整放入目标图片中
scale_ratio = min(width_ratio, height_ratio)
# 计算调整后的尺寸
new_width = int(original_width * scale_ratio)
new_height = int(original_height * scale_ratio)
# 调整图片大小
resized_image = image.resize((new_width, new_height))
# 创建一个 512x768 的透明图片
result_image = Image.new("RGBA", (target_width, target_height), (255, 255, 255, 0))
# 计算需要粘贴的位置,使图片居中
x_offset = (target_width - new_width) // 2
y_offset = (target_height - new_height) // 2
# 将调整大小后的图片粘贴到透明图片上
if resized_image.mode == "RGBA":
result_image.paste(resized_image, (x_offset, y_offset), mask=resized_image.split()[3])
else:
result_image.paste(resized_image, (x_offset, y_offset))
image = np.array(result_image)
# image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA)
return image
def infer_cancel(tasks_id):
redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
@@ -157,7 +193,7 @@ if __name__ == '__main__':
prompt="Colorful black",
image_url='aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png',
direction="Right Light",
product_type="single"
product_type="overall"
)
server = GenerateRelightImage(rd)
print(server.get_result())

View File

@@ -21,6 +21,7 @@ from tritonclient.utils import np_to_triton_dtype
from app.core.config import *
import tritonclient.grpc as grpcclient
from app.schemas.generate_image import GenerateSingleLogoImageModel
from app.service.generate_image.utils.mq import publish_status
from app.service.generate_image.utils.upload_sd_image import upload_png_sd, upload_SDXL_image
logger = logging.getLogger()
@@ -28,10 +29,6 @@ logger = logging.getLogger()
class GenerateSingleLogoImage:
def __init__(self, request_data):
if DEBUG is False:
self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
self.channel = self.connection.channel()
self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
self.grpc_client = grpcclient.InferenceServerClient(url=GSL_MODEL_URL)
self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
self.batch_size = 1
@@ -96,9 +93,8 @@ class GenerateSingleLogoImage:
raise Exception(str(e))
finally:
dict_generate_data, str_generate_data = self.read_tasks_status()
if DEBUG is False:
self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=str_generate_data)
logger.info(f" [x] Sent {json.dumps(dict_generate_data, indent=4)}")
if not DEBUG:
publish_status(str_generate_data, GI_RABBITMQ_QUEUES)
def infer_cancel(tasks_id):

View File

@@ -0,0 +1,185 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project trinity_client
@File service_pose_transform.py
@Author :周成融
@Date 2023/7/26 12:01:05
@detail
"""
import json
import logging
import time
from io import BytesIO
import imageio
import numpy as np
import redis
import tritonclient.grpc as grpcclient
from PIL import Image
from tritonclient.utils import np_to_triton_dtype
from app.core.config import *
from app.schemas.pose_transform import PoseTransformModel
from app.service.generate_image.utils.mq import publish_status
from app.service.generate_image.utils.pose_transform_upload import upload_gif, upload_video, upload_first_image
from app.service.utils.oss_client import oss_get_image
logger = logging.getLogger()
class PoseTransformService:
    """One pose-transform job: turn a product image into an animated clip.

    Sends the letterboxed image plus a pose id to the Triton ``animatex_1``
    model asynchronously, tracks task status in Redis keyed by ``tasks_id``,
    and on success uploads the first frame (PNG), a GIF and an MP4.  The
    final status JSON is published to ``PS_RABBITMQ_QUEUES`` unless DEBUG
    is on.
    """

    def __init__(self, request_data):
        self.grpc_client = grpcclient.InferenceServerClient(url=PT_MODEL_URL)
        self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
        self.category = "pose_transform"
        self.image_url = request_data.image_url
        self.pose_num = request_data.pose_id
        # 512x768 letterboxed array expected by the model.
        self.image = pre_processing_image(request_data.image_url)
        self.tasks_id = request_data.tasks_id
        # tasks_id convention is "<task>-<user_id>": user id is the suffix.
        self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
        self.pose_transform_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'gif_url': '',
                                    'video_url': '', 'image_url': ''}
        self.redis_client.set(self.tasks_id, json.dumps(self.pose_transform_data))
        # Auto-expire so abandoned tasks don't linger in Redis.
        self.redis_client.expire(self.tasks_id, 600)

    def callback(self, result, error):
        """Triton async completion hook: upload artifacts and update Redis."""
        if error:
            self.pose_transform_data['status'] = "FAILURE"
            self.pose_transform_data['message'] = str(error)
            self.redis_client.set(self.tasks_id, json.dumps(self.pose_transform_data))
        else:
            # Frame stack from the model; last-axis flip converts BGR -> RGB.
            result_data = np.squeeze(result.as_numpy("generated_image_list").astype(np.uint8))[:, :, :, ::-1]
            # First frame as a standalone preview image.
            first_image = Image.fromarray(result_data[0])
            first_image_url = upload_first_image(first_image, user_id=self.user_id,
                                                 category=f"{self.category}_first_img",
                                                 file_name=f"{self.tasks_id}.png")
            # Upload the GIF.
            gif_buffer = BytesIO()
            imageio.mimsave(gif_buffer, result_data, format='GIF', fps=5)
            gif_buffer.seek(0)
            gif_url = upload_gif(gif_buffer=gif_buffer, user_id=self.user_id, category=f"{self.category}_gif",
                                 file_name=f"{self.tasks_id}.gif")
            # Upload the video.
            video_url = upload_video(frames=result_data, user_id=self.user_id, category=f"{self.category}_video",
                                     file_name=f"{self.tasks_id}.mp4")
            self.pose_transform_data['status'] = "SUCCESS"
            self.pose_transform_data['message'] = "success"
            self.pose_transform_data['gif_url'] = str(gif_url)
            self.pose_transform_data['video_url'] = str(video_url)
            self.pose_transform_data['image_url'] = str(first_image_url)
            self.redis_client.set(self.tasks_id, json.dumps(self.pose_transform_data))

    def read_tasks_status(self):
        """Return (status_dict, raw_json_string) for this task from Redis."""
        status_data = self.redis_client.get(self.tasks_id)
        return json.loads(status_data), status_data

    def get_result(self):
        """Dispatch the inference, poll Redis until done, return status dict.

        :return: the final status dict (SUCCESS/FAILURE/REVOKED payload).
        :raises Exception: re-raises any dispatch error after recording
            FAILURE in Redis.
        """
        try:
            pose_num = [self.pose_num] * 1
            pose_num_obj = np.array(pose_num, dtype="object").reshape((-1, 1))
            input_pose_num = grpcclient.InferInput("pose_num", pose_num_obj.shape,
                                                   np_to_triton_dtype(pose_num_obj.dtype))
            input_pose_num.set_data_from_numpy(pose_num_obj)
            image_files = [self.image.astype(np.uint8)] * 1
            image_files_obj = np.array(image_files, dtype=np.uint8).reshape((-1, 768, 512, 3))
            input_image_files = grpcclient.InferInput("image_file", image_files_obj.shape, "UINT8")
            input_image_files.set_data_from_numpy(image_files_obj)
            ctx = self.grpc_client.async_infer(model_name="animatex_1", inputs=[input_pose_num, input_image_files],
                                               callback=self.callback, client_timeout=60000)
            # Poll Redis for a terminal state; cancel the inference if the
            # task was revoked or already failed.
            time_out = 60000
            while time_out > 0:
                pose_transform_data, _ = self.read_tasks_status()
                if pose_transform_data['status'] in ["REVOKED", "FAILURE"]:
                    ctx.cancel()
                    break
                elif pose_transform_data['status'] == "SUCCESS":
                    break
                time_out -= 1
                time.sleep(1)
            pose_transform_data, _ = self.read_tasks_status()
            return pose_transform_data
        except Exception as e:
            self.pose_transform_data['status'] = "FAILURE"
            self.pose_transform_data['message'] = str(e)
            self.redis_client.set(self.tasks_id, json.dumps(self.pose_transform_data))
            raise Exception(str(e))
        finally:
            dict_pose_transform_data, str_pose_transform_data = self.read_tasks_status()
            if not DEBUG:
                # BUGFIX: str_pose_transform_data is already the JSON string
                # stored in Redis; wrapping it in json.dumps() again (as the
                # previous code did) double-encoded the payload. Sibling
                # services pass the raw string to publish_status.
                publish_status(str_pose_transform_data, PS_RABBITMQ_QUEUES)
            logger.info(
                f" [x] Sent to {PS_RABBITMQ_QUEUES} data@@@@ {json.dumps(dict_pose_transform_data, indent=4)}")
def infer_cancel(tasks_id):
    """Mark the task identified by *tasks_id* as revoked in Redis.

    Overwrites the task's status record with a REVOKED payload so the polling
    loop in ``PoseTransformService.get_result`` cancels the in-flight
    inference. Returns the payload that was written.
    """
    client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
    payload = {
        'tasks_id': tasks_id,
        'status': 'REVOKED',
        'message': "revoked",
        'data': 'revoked',
    }
    client.set(tasks_id, json.dumps(payload))
    return payload
def pre_processing_image(image_url):
    """Fetch an image from OSS and letterbox it onto a 512x768 RGB canvas.

    The image is scaled with its aspect ratio preserved so it fits inside
    512x768, centred on a transparent canvas, flattened to RGB and returned
    as a numpy array.
    """
    bucket = image_url.split('/')[0]
    object_name = image_url[image_url.find('/') + 1:]
    image = oss_get_image(bucket=bucket, object_name=object_name, data_type="PIL")
    target_width, target_height = 512, 768
    source_width, source_height = image.size
    # Pick the smaller ratio so the whole picture fits inside the frame.
    scale = min(target_width / source_width, target_height / source_height)
    new_width = int(source_width * scale)
    new_height = int(source_height * scale)
    scaled = image.resize((new_width, new_height))
    # Transparent canvas; the scaled image is pasted centred onto it.
    canvas = Image.new("RGBA", (target_width, target_height), (255, 255, 255, 0))
    paste_at = ((target_width - new_width) // 2, (target_height - new_height) // 2)
    if scaled.mode == "RGBA":
        # Use the source alpha channel as the paste mask so transparency survives.
        canvas.paste(scaled, paste_at, mask=scaled.split()[3])
    else:
        canvas.paste(scaled, paste_at)
    return np.array(canvas.convert("RGB"))
if __name__ == '__main__':
    # Ad-hoc smoke test: run one pose-transform inference for user 89.
    # NOTE(review): PoseTransformModel is not imported in the visible part of
    # this file — confirm the import exists near the top.
    rd = PoseTransformModel(
        tasks_id="123-89",
        image_url='aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png',
        pose_id="1"
    )
    server = PoseTransformService(rd)
    print(server.get_result())

View File

@@ -0,0 +1,23 @@
import json
import pika
import logging
from app.core.config import RABBITMQ_PARAMS
logger = logging.getLogger(__name__)
def publish_status(message, queue_name):
    """Publish *message* (a JSON string) to the durable RabbitMQ queue *queue_name*.

    Opens a short-lived blocking connection, declares the queue as durable and
    publishes the message as persistent (delivery_mode=2). The connection is
    always closed, even when declaring or publishing raises.
    """
    connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
    try:
        channel = connection.channel()
        channel.queue_declare(queue=queue_name, durable=True)
        channel.basic_publish(exchange='',
                              routing_key=queue_name,
                              body=message,
                              properties=pika.BasicProperties(
                                  delivery_mode=2,
                              ))
    finally:
        # BUGFIX: the original leaked the connection when basic_publish raised.
        connection.close()
    logger.info(f" [x] Queue : {queue_name} | Sent message : {json.dumps(json.loads(message), indent=4)}")

View File

@@ -0,0 +1,75 @@
import io
import logging
import os.path
import numpy as np
# import boto3
from minio import Minio
from moviepy.video.io.ImageSequenceClip import ImageSequenceClip
from app.core.config import *
from app.service.utils.new_oss_client import oss_upload_image
# MinIO configuration for the upload helpers below.
# SECURITY(review): credentials are hard-coded in source control; move them to
# configuration/environment variables and rotate the exposed keys.
MINIO_URL = "www.minio-api.aida.com.hk"
MINIO_ACCESS = 'vXKFLSJkYeEq2DrSZvkB'
MINIO_SECRET = 'uKTZT3x7C43WvPN9QTc99DiRkwddWZrG9Uh3JVlR'
MINIO_SECURE = True
minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
def upload_first_image(image, user_id, category, file_name):
    """Upload a PIL image as PNG to MinIO and return its bucket-relative URL.

    Returns ``None`` (after logging a warning) when encoding or upload fails.
    """
    try:
        image_data = io.BytesIO()
        image.save(image_data, format='PNG')
        image_data.seek(0)
        image_bytes = image_data.read()
        object_name = f'{user_id}/{category}/{file_name}'
        oss_upload_image(oss_client=minio_client, bucket=GI_MINIO_BUCKET, object_name=object_name, image_bytes=image_bytes)
        # NOTE(review): the returned URL uses a hard-coded "aida-users" prefix
        # while the upload targets GI_MINIO_BUCKET — confirm the two always match.
        return f"aida-users/{object_name}"
    except Exception as e:
        # BUGFIX: the original log line named the wrong function ("upload_png_mask").
        logging.warning(f"upload_first_image runtime exception : {e}")
        return None
def upload_gif(gif_buffer, user_id, category, file_name):
    """Upload an in-memory GIF (a BytesIO) to the ``aida-users`` bucket.

    Returns the bucket-relative object URL, or ``None`` when the upload raises
    (the error is logged as a warning).
    """
    object_name = f'{user_id}/{category}/{file_name}'
    try:
        minio_client.put_object(
            "aida-users",
            object_name,
            gif_buffer,
            length=gif_buffer.getbuffer().nbytes,
            content_type="image/gif"
        )
    except Exception as e:
        logging.warning(f"upload_gif runtime exception : {e}")
        return None
    return f"aida-users/{object_name}"
def upload_video(frames, user_id, category, file_name):
    """Render *frames* to an MP4 file, upload it to MinIO and return its URL.

    The intermediate file written by ``ndarray_to_video`` is deleted after the
    upload attempt. Returns ``None`` when rendering or uploading fails (the
    error is logged as a warning).
    """
    save_path = None
    try:
        save_path = ndarray_to_video(frames, file_name)
        object_name = f'{user_id}/{category}/{file_name}'
        minio_client.fput_object(
            "aida-users",
            object_name,
            save_path,
            content_type="video/mp4"  # MP4 MIME type so the object streams in-browser
        )
        return f"aida-users/{object_name}"
    except Exception as e:
        logging.warning(f"upload_video runtime exception : {e}")
        return None
    finally:
        # BUGFIX: the original left the rendered file on disk after every call.
        if save_path and os.path.exists(save_path):
            os.remove(save_path)
def ndarray_to_video(images, output_path, frame_size=(512, 768), fps=9):
    """Write a sequence of HxWx3 uint8 frames to an H.264-encoded MP4 file.

    ``frame_size`` is accepted for interface compatibility but is not used:
    frames are written at their native resolution. Returns the path of the
    file written under ``POSE_TRANSFORM_VIDEO_PATH``.
    """
    destination = os.path.join(POSE_TRANSFORM_VIDEO_PATH, output_path)
    clip = ImageSequenceClip(list(images), fps=fps)
    clip.write_videofile(destination, codec='libx264')
    return destination
if __name__ == '__main__':
    # Smoke test: upload four random 768x512 frames as a short video for user 89.
    images = np.random.randint(0, 256, size=(4, 768, 512, 3), dtype=np.uint8)
    print(upload_video(images, user_id=89, category='pose_transform_video', file_name="1123123.mp4"))

View File

@@ -0,0 +1,114 @@
import cv2
import numpy as np
from PIL import Image
from minio import Minio
from app.core.config import MINIO_URL, MINIO_ACCESS, MINIO_SECRET, MINIO_SECURE
from app.schemas.mannequin_edit import MannequinModel
from app.service.utils.new_oss_client import oss_get_image, oss_upload_image
minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
class MannequinEditService():
    """Stretch or shrink a horizontal band of a mannequin image and re-upload it.

    The mannequin is fetched from OSS; the rows in ``[top, bottom)`` are resized
    vertically by ``resize_pixel`` pixels, the result is letterboxed back onto
    the original canvas size and uploaded to ``bucket_name`` as
    ``<mannequin_name>.png``.
    """

    def __init__(self, request_data):
        self.resize_pixel = request_data.resize_pixel
        self.top = request_data.top
        self.bottom = request_data.bottom
        self.image = oss_get_image(oss_client=minio_client, bucket=request_data.mannequins.split('/')[0], object_name=request_data.mannequins[request_data.mannequins.find('/') + 1:], data_type="cv2")
        self.mannequin_name = request_data.mannequin_name
        self.bucket_name = request_data.bucket_name
        if self.image.shape[2] == 4:
            # Split off the alpha channel and zero out fully-transparent pixels
            # in the colour planes.
            self.bgr = self.image[:, :, :3]
            self.alpha = self.image[:, :, 3]
            self.bgr = cv2.bitwise_and(self.bgr, self.bgr, mask=cv2.normalize(self.alpha, None, 0, 1, cv2.NORM_MINMAX))
            self.h, self.w, _ = self.bgr.shape
        else:
            self.bgr = self.image
            self.h, self.w, _ = self.bgr.shape
            self.alpha = None

    def __call__(self, *args, **kwargs):
        """Run the edit and return the uploaded object's ``bucket/object`` URL."""
        new_mannequin = self.resize_leg(self.top, self.bottom)
        _, encoded_image = cv2.imencode('.png', new_mannequin)
        image_bytes = encoded_image.tobytes()
        req = oss_upload_image(oss_client=minio_client, bucket=self.bucket_name, object_name=f"{self.mannequin_name}.png", image_bytes=image_bytes)
        return req.bucket_name + "/" + req.object_name

    def post_processing(self, image):
        """Letterbox *image* back onto a transparent w x h canvas, centred.

        Returns the composited canvas as a numpy array.
        """
        original_width, original_height = image.size
        # Pick the smaller ratio so the edited image fits inside the canvas.
        scale_ratio = min(self.w / original_width, self.h / original_height)
        new_width = int(original_width * scale_ratio)
        new_height = int(original_height * scale_ratio)
        resized_image = image.resize((new_width, new_height))
        result_image = Image.new("RGBA", (self.w, self.h), (255, 255, 255, 0))
        x_offset = (self.w - new_width) // 2
        y_offset = (self.h - new_height) // 2
        if resized_image.mode == "RGBA":
            # Keep transparency by using the alpha channel as the paste mask.
            result_image.paste(resized_image, (x_offset, y_offset), mask=resized_image.split()[3])
        else:
            result_image.paste(resized_image, (x_offset, y_offset))
        return np.array(result_image)

    def resize_leg(self, top, bottom):
        """Vertically resize rows ``[top:bottom)`` by ``self.resize_pixel`` pixels."""
        top_part = self.bgr[:top, :]
        part_resize = self.bgr[top:bottom, :]
        part_bottom = self.bgr[bottom:, :]
        new_height = int((bottom - top) + self.resize_pixel)
        resized_thigh = cv2.resize(part_resize, (self.w, new_height), interpolation=cv2.INTER_LINEAR)
        new_bgr = np.vstack((top_part, resized_thigh, part_bottom))
        if self.alpha is not None:
            # BUGFIX: the original sliced self.alpha unconditionally before the
            # None check, crashing with a TypeError on 3-channel inputs. Apply
            # the same split/resize/stack to the alpha channel so transparency
            # follows the stretched geometry.
            top_part_alpha = self.alpha[:top, :]
            part_resize_alpha = self.alpha[top:bottom, :]
            part_bottom_alpha = self.alpha[bottom:, :]
            resized_thigh_alpha = cv2.resize(part_resize_alpha, (self.w, new_height), interpolation=cv2.INTER_LINEAR)
            new_bgr_alpha = np.vstack((top_part_alpha, resized_thigh_alpha, part_bottom_alpha))
            new_image = np.dstack((new_bgr, new_bgr_alpha))
        else:
            new_image = new_bgr
        return self.post_processing(Image.fromarray(new_image))
if __name__ == '__main__':
    # Ad-hoc smoke test: shorten the legs of a sample mannequin by 100 px.
    request_data = MannequinModel(
        mannequins="aida-sys-image/models/male/dc36ce58-46c3-4b6f-8787-5ca7d6fc26e6.png",
        resize_pixel=-100,
        bucket_name="test",
        mannequin_name="mannequin_name",
        top=270,
        bottom=432
    )
    service = MannequinEditService(request_data)
    print(service())

View File

@@ -0,0 +1,68 @@
from langchain.output_parsers import ResponseSchema, StructuredOutputParser
from langchain_community.chat_models import ChatTongyi
from langchain_core.prompts import PromptTemplate
from app.schemas.project_info_extraction import ProjectInfoExtractionModel
# Closed vocabularies the LLM is instructed to choose from when extracting
# project attributes; each list is interpolated into the prompt template below.
style = ['NEW_CHINESE', 'COUNTRY_STYLE', 'FUTURISM', 'MINIMALISM', 'LOLITA', 'Y2K', 'BUSINESS', 'MERLAD',
         'OUTDOOR_FUNCTIONAL', 'ROCK', 'DOPAMINE', 'GOTHIC', 'POST_APOCALYPTIC', 'ROMANTIC', 'WABI_SABI']
position = ['Overall', 'Tops', 'Bottoms', 'Outwear', 'Blouse', 'Dress', 'Trousers', 'Skirt']
gender = ['Female', 'Male']
age_group = ['Adult', 'Child']
process = ['SERIES_DESIGN', 'SINGLE_DESIGN']
class ProjectInfoExtraction:
    """Extract structured fashion-project attributes from a free-form prompt via Tongyi LLMs.

    Uses a vision model when the request carries images/files, otherwise a
    text-only model, and parses the response into a dict with keys
    project_name / process / ageGroup / gender / position / style.
    """

    def __init__(self, request_data):
        # Pick a vision-capable model only when images or files are attached.
        # SECURITY(review): the API key is hard-coded in source; move it to
        # configuration/environment variables and rotate the exposed key.
        if len(request_data.image_list) or len(request_data.file_list):
            self.model = ChatTongyi(model="qwen-vl-plus", api_key="sk-7658298c6b99443c98184a5e634fe6ab")
        else:
            self.model = ChatTongyi(model="qwen2.5-14b-instruct", api_key="sk-7658298c6b99443c98184a5e634fe6ab")
        # Schema the structured-output parser enforces on the LLM response.
        self.response_schemas = [
            ResponseSchema(name="project_name", description="项目的名称."),
            ResponseSchema(name="process", description="项目的类型,单品还是系列."),
            ResponseSchema(name="ageGroup", description="项目设计服装的受众对象."),
            ResponseSchema(name="gender", description="项目设计服装的受众性别."),
            ResponseSchema(name="position", description="项目单品设计的部位."),
            ResponseSchema(name="style", description="项目的设计风格.")
        ]
        self.output_parser = StructuredOutputParser.from_response_schemas(self.response_schemas)
        self.format_instructions = self.output_parser.get_format_instructions()
        # Prompt embeds the closed vocabularies defined at module level.
        self.prompt = PromptTemplate(
            template="你是一个时装品牌的设计师助理。根据用户输入提取出"
                     "[project_name] : 项目的名称,"
                     f"[process] : 项目的类型,从{process}选择."
                     f"[ageGroup] : 服装的受众,从{age_group}选择."
                     f"[gender] : 服装的适用性别,从{gender}选择"
                     f"[position] : single_design的部位如果[process]是SINGLE_DESIGN,从{position}中选择,如果[process]是SERIES_DESIGN这项为空"
                     f"[style] : 设计的风格,从{style}中选择"
                     ".\n{format_instructions}\n{question}",
            input_variables=["question"],
            partial_variables={"format_instructions": self.format_instructions}
        )
        self._input = self.prompt.format_prompt(question=request_data.prompt)
        self.result_data = {}

    def get_result(self):
        """Run the extraction and return the parsed attribute dict."""
        self.llm_extraction_project_info()
        return self.result_data

    def llm_extraction_project_info(self):
        """Call the model and parse its structured output into ``result_data``."""
        output = self.model(self._input.to_messages())
        project_info = self.output_parser.parse(output.content)
        self.result_data = project_info
if __name__ == '__main__':
    # Ad-hoc smoke test: extract project info from a short Chinese prompt
    # plus one attached image (forces the vision model path).
    request_data = ProjectInfoExtractionModel(
        prompt="性别为儿童",
        image_list=[
            'https://www.minio-api.aida.com.hk/test/019aaeed-3227-11f0-a194-0826ae3ad6b3.jpg?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=vXKFLSJkYeEq2DrSZvkB%2F20250613%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20250613T020236Z&X-Amz-Expires=604800&X-Amz-SignedHeaders=host&X-Amz-Signature=a513b706c24134071a489c34f0fa2c0f510e871b8589dc0c08a0f26ea28ee2ff'
        ],
        file_list=[]
    )
    service = ProjectInfoExtraction(request_data)
    print(service.get_result())

View File

@@ -9,6 +9,7 @@ from retry import retry
from app.core.config import QWEN_API_KEY
from app.service.chat_robot.script.service.CallQWen import get_language
from app.service.prompt_generation.util import minio_util
logger = logging.getLogger(__name__)
@@ -143,6 +144,38 @@ def get_translation_from_llama3(text):
# response = requests.post(url, data=json.dumps(payload), headers=headers)
def get_prompt_from_image(image_path, text):
    """Ask the llama3.2-vision Ollama endpoint to respond to *text* about an image.

    ``image_path.img`` is a MinIO ``bucket/object`` path; it is downloaded,
    base64-encoded and sent alongside *text* as the prompt. Returns the model's
    response string, or ``None`` when the HTTP call does not return 200.
    """
    start_time = time.time()
    # url = "http://localhost:11434/api/generate"
    url = "http://10.1.1.243:11434/api/generate"
    image_base64 = minio_util.minio_url_to_base64(image_path.img)
    payload = {
        "model": "llama3.2-vision",
        "images": [image_base64],
        "prompt": f"{text}",
        "stream": False
    }
    headers = {'Content-Type': 'application/json'}
    response = requests.post(url, data=json.dumps(payload), headers=headers)
    if response.status_code == 200:
        resp = json.loads(response.content).get("response")
        logger.info(f"sketch re-generate server runtime is {time.time() - start_time} \n, response is {resp}")
        return resp
    # BUGFIX: the original printed the failure to stdout and fell off the end
    # (implicit None); log it and return None explicitly.
    logger.info(f"sketch re-generate server runtime is {time.time() - start_time} , response is {response.content}")
    logger.warning(f"Request failed with status code {response.status_code}: {response.text}")
    return None
def main():
"""Main function"""
text = get_translation_from_llama3("[火焰]")

View File

@@ -0,0 +1,21 @@
import base64
from minio import Minio
from app.core.config import MINIO_URL, MINIO_ACCESS, MINIO_SECRET, MINIO_SECURE
minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
def minio_url_to_base64(minio_url: str) -> str:
    """Download ``bucket/object`` from MinIO and return the content base64-encoded.

    Raises ``RuntimeError`` (chained to the original error) when the object
    cannot be fetched. The HTTP response is always closed and its connection
    returned to the pool.
    """
    bucket_name, object_name = minio_url.split("/", 1)
    # Initialize up-front instead of probing locals() in the finally block.
    response = None
    try:
        response = minio_client.get_object(bucket_name, object_name)
        image_data = response.read()
        return base64.b64encode(image_data).decode('utf-8')
    except Exception as e:
        raise RuntimeError(f"Failed to get object: {e}") from e
    finally:
        if response is not None:
            response.close()
            # MinIO SDK docs recommend releasing the connection after close().
            response.release_conn()

View File

@@ -0,0 +1,539 @@
import pymysql
import numpy as np
from apscheduler.schedulers.blocking import BlockingScheduler
import os
import logging
from collections import defaultdict
import torch
from torchvision import models, transforms
from minio import Minio
from PIL import Image
import io
import seaborn as sns
import matplotlib.pyplot as plt
from scipy.sparse import csr_matrix
import matplotlib.font_manager as fm
from scipy import sparse
import pandas as pd
from datetime import datetime, timedelta
import json
from app.core.config import DB_CONFIG, TABLE_CATEGORIES, RECOMMEND_PATH_PREFIX
# Automatically pick a font that can render the Chinese plot labels.
try:
    # Try the common Chinese fonts first.
    font_path = fm.findfont(fm.FontProperties(family=['Microsoft YaHei', 'SimHei', 'WenQuanYi Zen Hei']))
    plt.rcParams['font.sans-serif'] = fm.FontProperties(fname=font_path).get_name()
except Exception:
    # BUGFIX: was a bare "except:", which also swallows KeyboardInterrupt /
    # SystemExit. Fall back to an ASCII font.
    plt.rcParams['font.sans-serif'] = ['DejaVu Sans']
plt.rcParams['axes.unicode_minus'] = False

# NOTE(review): this second lookup unconditionally overrides the fallback
# chosen above with Microsoft YaHei — confirm the duplication is intentional.
font_path = fm.findfont(fm.FontProperties(family='Microsoft YaHei'))
plt.rcParams['font.sans-serif'] = ['Microsoft YaHei']
plt.rcParams['axes.unicode_minus'] = False  # render the minus sign correctly
# Logging configuration for the scheduler process (file-based).
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    filename='scheduler.log'
)
# MinIO connection used to fetch sketch images.
# SECURITY(review): credentials are hard-coded in source control; move them to
# configuration/environment variables and rotate the exposed keys.
minio_client = Minio(
    "www.minio.aida.com.hk:12024",  # MinIO endpoint
    access_key="admin",  # access key
    secret_key="Aidlab123123!",  # secret key
    secure=True  # use HTTPS
)
# Pre-loaded {sketch_path: feature_vector} mapping for all system sketches.
SYSTEM_FEATURES = np.load(f'{RECOMMEND_PATH_PREFIX}sketch_feature_dict.npy', allow_pickle=True).item()
# Behaviour weights and exponential time-decay coefficients used by calculate_heat.
BEHAVIOR_CONFIG = {
    'portfolioClick': {'weight': 1, 'decay': 0.3},
    'portfolioLike': {'weight': 2, 'decay': 0.2},
    'secondCreation': {'weight': 3, 'decay': 0.1},
    'sketchLike': {'weight': 4, 'decay': 0}  # no decay
}
# 保存sketch_to_iid到文件
def save_sketch_to_iid():
    """Persist the sketch-path -> integer-id mapping to ``sketch_to_iid.npy``.

    Ids are assigned in enumeration order of SYSTEM_FEATURES, starting at 1.
    NOTE(review): this writes to the working directory, while other modules
    load the mapping from RECOMMEND_PATH_PREFIX — confirm both resolve to the
    same file.
    """
    mapping = {path: iid for iid, path in enumerate(SYSTEM_FEATURES.keys(), start=1)}
    np.save('sketch_to_iid.npy', mapping)
    print("sketch_to_iid 已保存")
# 从文件加载sketch_to_iid
def load_sketch_to_iid():
    """Load the sketch-path -> id mapping, generating it first if absent."""
    if not os.path.exists('sketch_to_iid.npy'):
        # Mapping file missing: build and persist it, then load the fresh copy.
        print("sketch_to_iid 文件不存在,正在生成并保存...")
        save_sketch_to_iid()
        return np.load('sketch_to_iid.npy', allow_pickle=True).item()
    mapping = np.load('sketch_to_iid.npy', allow_pickle=True).item()
    print("sketch_to_iid 已加载")
    return mapping
# Load (or lazily build) the sketch-path -> id mapping used module-wide.
sketch_to_iid = load_sketch_to_iid()
# sketch_to_iid is consumed by the matrix-building functions below.
# print(f"Total sketches: {len(sketch_to_iid)}")
# Image preprocessing matching the transforms ResNet was trained with.
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # ResNet expects 224x224 input
    transforms.ToTensor(),  # convert to a Tensor
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # ImageNet normalization
])
# Load a pretrained ResNet-50 and strip the classification head so the model
# outputs 2048-d pooled feature vectors.
resnet_model = models.resnet50(pretrained=True)
modules = list(resnet_model.children())[:-1]  # drop the final fully-connected layer
resnet_model = torch.nn.Sequential(*modules)
resnet_model.eval()  # inference mode
# 从 MinIO 获取图片并进行预处理
def get_sketch_image_from_minio(sketch_path):
    """Fetch a sketch from MinIO and return it as a preprocessed 1x3x224x224 tensor.

    *sketch_path* is ``bucket/object``. Returns ``None`` (after printing the
    error) when download or decoding fails — callers rely on that contract.
    """
    # Split on the first slash: bucket name, then the object path.
    bucket_name, file_name = sketch_path.split('/', 1)
    try:
        obj = minio_client.get_object(bucket_name, file_name)
        img_data = obj.read()
        img = Image.open(io.BytesIO(img_data))
        # BUGFIX: force 3-channel RGB — palette/RGBA/grayscale images would
        # otherwise yield the wrong channel count and break Normalize(),
        # which expects exactly 3 channels.
        img = img.convert("RGB")
        img = transform(img)
        return img.unsqueeze(0)  # add a batch dimension
    except Exception as e:
        print(f"Error fetching image for {sketch_path}: {e}")
        return None
def extract_feature_vector_from_resnet(sketch_path):
    """Return the 2048-d ResNet-50 embedding of a sketch (zero vector on failure)."""
    batch = get_sketch_image_from_minio(sketch_path)
    if batch is None:
        # Download/decode failed: fall back to zeros so callers need no
        # special-casing for errors.
        return np.zeros(2048)
    with torch.no_grad():
        features = resnet_model(batch)
    return features.squeeze().cpu().numpy()
def update_user_matrices():
    """Daily job: rebuild the user interaction-count and feature matrices and save them.

    Reads aggregated like counts per (account, sketch path) from MySQL, builds
    the interaction and feature matrices plus their index mappings, and saves
    everything as .npy files under RECOMMEND_PATH_PREFIX. Errors are logged;
    the connection is always closed.
    """
    conn = None
    try:
        conn = pymysql.connect(**DB_CONFIG)
        cursor = conn.cursor()
        # Query without a category filter; filtering happens downstream.
        cursor.execute("""
            SELECT account_id, path, COUNT(*) as like_count
            FROM user_preference_log_test
            GROUP BY account_id, path
        """)
        user_data = cursor.fetchall()
        logging.info(f"成功读取{len(user_data)}条用户偏好记录")
        # Build the matrices (visualisation calls kept disabled below).
        interaction_matrix, raw_counts_sparse, user_index_interaction_matrix, sketch_index_interaction_matrix, iid_to_category_interaction_matrix = calculate_interaction_matrix(user_data)
        # visualize_sparse_matrix(raw_counts_sparse,'交互次数矩阵', 'interaction_frequency_matrix.png')
        # visualize_sparse_matrix(interaction_matrix, '交互次数得分矩阵', 'interaction_score_matrix.png')
        # plot_interaction_count_matrix(raw_counts_sparse)
        # feature_matrix, iid_to_category_feature_matrix, user_index_feature_matrix, sketch_index_feature_matrix = calculate_feature_matrix(user_data)
        feature_matrix, user_index_feature_matrix, sketch_index_feature_matrix, iid_to_category_feature_matrix = calculate_feature_matrix(user_data)
        # visualize_sparse_matrix(feature_matrix, '系统sketch与用户category平均特征向量关联度矩阵', 'correlation_matrix.png')
        # Persist matrices and index mappings for the serving process.
        np.save(f"{RECOMMEND_PATH_PREFIX}interaction_matrix.npy", interaction_matrix)
        np.save(f"{RECOMMEND_PATH_PREFIX}feature_matrix.npy", feature_matrix)
        #
        np.save(f"{RECOMMEND_PATH_PREFIX}iid_to_category_interaction_matrix.npy", iid_to_category_interaction_matrix)
        np.save(f"{RECOMMEND_PATH_PREFIX}user_index_interaction_matrix.npy", user_index_interaction_matrix)
        #
        np.save(f"{RECOMMEND_PATH_PREFIX}iid_to_category_feature_matrix.npy", iid_to_category_feature_matrix)
        np.save(f"{RECOMMEND_PATH_PREFIX}user_index_feature_matrix.npy", user_index_feature_matrix)
        #
        np.save(f"{RECOMMEND_PATH_PREFIX}sketch_index_interaction_matrix.npy", sketch_index_interaction_matrix)
        np.save(f"{RECOMMEND_PATH_PREFIX}sketch_index_feature_matrix.npy", sketch_index_feature_matrix)
        # logging.info("矩阵更新完成")
    except Exception as e:
        logging.error(f"定时任务执行失败: {str(e)}", exc_info=True)
    finally:
        if conn:
            conn.close()
def plot_interaction_count_matrix(interaction_count_matrix):
    """Render the interaction-count matrix as a heatmap (zero entries included).

    Accepts a sparse or dense matrix; very large matrices are truncated to the
    first 1000x1000 cells. The figure is saved to
    ``interaction_heatmap_with_zeros.png``. All failures are logged, not raised.
    """
    try:
        if not isinstance(interaction_count_matrix, csr_matrix):
            sparse_matrix = csr_matrix(interaction_count_matrix)
        else:
            sparse_matrix = interaction_count_matrix
        # Densify for seaborn; may fail for huge matrices.
        try:
            dense_matrix = sparse_matrix.toarray()
        except MemoryError:
            logging.error("内存不足,无法转换为密集矩阵")
            return
        # Auto-detect a usable font for the labels.
        try:
            font_path = fm.findfont(fm.FontProperties(family='sans-serif', style='normal'))
            plt.rcParams['font.sans-serif'] = fm.FontProperties(fname=font_path).get_name()
        except:
            plt.rcParams['font.sans-serif'] = ['DejaVu Sans']  # fall back to an ASCII font
        plt.rcParams['axes.unicode_minus'] = False
        # Cap the rendered area for very large matrices.
        if dense_matrix.shape[0] > 1000 or dense_matrix.shape[1] > 1000:
            dense_matrix = dense_matrix[:1000, :1000]  # only plot the first 1000 rows/cols
        plt.figure(figsize=(15, 10))
        # vmin=0 keeps zero cells visible; vmax anchors the colour scale.
        sns.heatmap(
            dense_matrix,
            cmap="Blues",  # alternative colormaps: "Blues" or "YlGnBu"
            annot=False,  # no per-cell annotations
            cbar_kws={"label": "Interaction Count"},  # colourbar label
            linewidths=0.5,
            vmin=0,  # minimum so zero entries are distinguishable
            vmax=np.max(dense_matrix)  # maximum keeps the mapping sensible
        )
        plt.title('User-Sketch Interaction Matrix (With Zero Entries)')
        plt.xlabel('Sketch Index')
        plt.ylabel('User Index')
        plt.savefig('interaction_heatmap_with_zeros.png', dpi=150, bbox_inches='tight')
        plt.close()
        logging.info("热图已保存为 interaction_heatmap_with_zeros.png")
    except Exception as e:
        logging.error(f"绘图失败: {str(e)}", exc_info=True)
def visualize_sparse_matrix(matrix, title='Non-zero Interactions (Scatter Plot)', filename="scatter_figure_interaction.png"):
    """Scatter-plot the non-zero entries of a (possibly dense) matrix.

    Each point is (col, row) coloured by the stored value. The figure is saved
    to *filename* and closed so repeated calls do not leak matplotlib figures.
    """
    if not sparse.issparse(matrix):
        # Normalize dense inputs to CSR so .nonzero()/.data behave uniformly.
        matrix = sparse.csr_matrix(matrix)
    rows, cols = matrix.nonzero()
    values = matrix.data
    plt.figure(figsize=(24, 20))
    plt.scatter(cols, rows, c=values, cmap='coolwarm', alpha=0.7, s=1)
    plt.colorbar(label='Interaction Count')
    plt.title(title)
    plt.xlabel('Item Index')
    # BUGFIX: both axes were labelled 'Item Index'; rows index users.
    plt.ylabel('User Index')
    plt.savefig(filename)
    # BUGFIX: the figure was never closed, leaking memory across calls.
    plt.close()
def calculate_interaction_matrix(user_data):
    """Build the user x system-sketch interaction matrices from like counts.

    Returns ``(normalized_matrix, raw_count_csr, user_index, sketch_index,
    iid_to_category)`` where the normalized value is
    ``log1p(1 + likes) / log1p(1 + user_max_likes)``.
    """
    # Collect user ids, keeping only rows whose path category is configured.
    all_users = set()
    for account_id, path, like_count in user_data:
        category = get_category_from_path(path)
        if category not in TABLE_CATEGORIES.keys():
            continue
        all_users.add(account_id)
    # Ids of every system sketch known to the mapping.
    system_sketch_iids = {sketch_to_iid[path] for path in SYSTEM_FEATURES.keys() if path in sketch_to_iid}
    # Row/column index mappings (sorted for determinism).
    user_index = {uid: idx for idx, uid in enumerate(sorted(all_users))}
    sketch_index = {iid: idx for idx, iid in enumerate(sorted(system_sketch_iids))}
    # Two matrices: dense normalized scores, and raw counts kept sparse.
    interaction_matrix = np.zeros((len(all_users), len(sketch_index)))  # normalized matrix
    data, rows, cols = [], [], []  # COO triplets for the sparse count matrix
    # Pre-compute each user's maximum like count (normalization denominator).
    user_max_likes = defaultdict(int)
    for account_id, path, like_count in user_data:
        if sketch_to_iid.get(path) in system_sketch_iids:
            user_max_likes[account_id] = max(user_max_likes[account_id], like_count)
    # Fill both matrices.
    # NOTE(review): unlike the all_users pass above, this loop does not
    # re-apply the TABLE_CATEGORIES filter — user_index[account_id] can raise
    # KeyError for a user whose only rows were category-filtered; confirm
    # system sketches always carry a configured category.
    for account_id, path, like_count in user_data:
        sketch_iid = sketch_to_iid.get(path)
        if sketch_iid not in system_sketch_iids:
            continue
        user_idx = user_index[account_id]
        sketch_idx = sketch_index[sketch_iid]
        # Raw count triplet.
        data.append(like_count)
        rows.append(user_idx)
        cols.append(sketch_idx)
        # Per-user log normalization.
        max_like = user_max_likes.get(account_id, 1)
        interaction_matrix[user_idx, sketch_idx] = np.log1p(1 + like_count) / np.log1p(1 + max_like)
    # CSR layout gives fast row slicing for serving.
    interaction_count_matrix = csr_matrix((data, (rows, cols)), shape=(len(all_users), len(sketch_index)))
    return interaction_matrix, interaction_count_matrix, user_index, sketch_index, {iid: get_category_from_path(path) for path, iid in sketch_to_iid.items()}
def calculate_feature_matrix(user_data):
    """Build a user x system-sketch cosine-similarity matrix from weighted features.

    For each (user, category) pair a like-count-weighted average feature vector
    is computed from the sketches the user interacted with; each cell is the
    cosine similarity between that average and the system sketch's vector.
    Returns ``(feature_matrix, user_index, sketch_index, iid_to_category)``.
    """
    # {(account_id, category): {sketch_iid: [(feature_vector, weight), ...]}}
    user_feature_weights = defaultdict(lambda: defaultdict(list))
    # All observed users and the full system-sketch set.
    all_users = set()
    all_system_sketches = set(SYSTEM_FEATURES.keys())
    # ==== Pass 1: collect feature vectors and their like-count weights ====
    for account_id, path, like_count in user_data:
        category = get_category_from_path(path)
        if category not in TABLE_CATEGORIES.keys():
            continue
        sketch_iid = sketch_to_iid.get(path)
        if not sketch_iid:
            continue
        # Record the user.
        all_users.add(account_id)
        # Look up (system sketch) or compute (user sketch) the feature vector;
        # the like count acts as the weight in the average.
        if path in SYSTEM_FEATURES:  # system sketch
            feature = SYSTEM_FEATURES[path]
            weight = like_count  # use like_count as the weight
            user_feature_weights[(account_id, category)][sketch_iid].append((feature, weight))
        else:  # user sketch
            feature = extract_feature_vector_from_resnet(path)
            weight = like_count
            user_feature_weights[(account_id, category)][sketch_iid].append((feature, weight))
    # ==== Pass 2: collect every system sketch iid ====
    system_sketch_iids = set()
    for sketch_path in SYSTEM_FEATURES:
        if iid := sketch_to_iid.get(sketch_path):
            system_sketch_iids.add(iid)
    # ==== Index mappings (sorted for determinism) ====
    user_list = sorted(all_users)
    sketch_list = sorted(system_sketch_iids)
    user_index = {uid: idx for idx, uid in enumerate(user_list)}
    sketch_index = {iid: idx for idx, iid in enumerate(sketch_list)}
    # ==== Similarity matrix, initialised to zeros ====
    feature_matrix = np.zeros((len(user_list), len(sketch_list)))
    # ==== Pre-compute each (user, category) weighted-average feature vector ====
    user_avg_features = {}
    for (account_id, category), sketches in user_feature_weights.items():
        try:
            # Flatten all (vector, weight) pairs for this user/category.
            all_features_weights = [(vec, weight) for vec_list in sketches.values() for vec, weight in vec_list]
            if len(all_features_weights) == 0:
                continue
            # Total weight for the weighted mean.
            total_weight = sum(weight for _, weight in all_features_weights)
            if total_weight <= 0:  # guard against division by zero
                total_weight = 1.0
            # Weighted average of the feature vectors.
            weighted_sum = np.zeros_like(all_features_weights[0][0])  # same dimensionality as the vectors
            for vec, weight in all_features_weights:
                weighted_sum += vec * weight
            avg_vec = weighted_sum / total_weight
            user_avg_features[(account_id, category)] = avg_vec
        except Exception as e:
            logging.warning(f"用户({account_id},{category})加权特征计算失败: {str(e)}")
            continue
    # ==== Fill the matrix with cosine similarities ====
    for sketch_path, sys_vector in SYSTEM_FEATURES.items():
        sketch_iid = sketch_to_iid.get(sketch_path)
        system_sketch_category = get_category_from_path(sketch_path)
        if not sketch_iid or sketch_iid not in sketch_index:
            continue
        sketch_col = sketch_index[sketch_iid]
        # Compare every user's category-average vector to this sketch.
        for account_id in all_users:
            user_row = user_index.get(account_id)
            if user_row is None:
                continue
            # Fetch the user's weighted-average vector for this sketch's category.
            try:
                # Direct lookup by the (user, category) composite key.
                user_vec = user_avg_features[(account_id, system_sketch_category)]
            except KeyError:
                # The user has no feature data in this category.
                continue
            # Cosine similarity between user profile and sketch.
            cos_sim = cosine_similarity(user_vec, sys_vector)
            feature_matrix[user_row, sketch_col] = cos_sim
    return feature_matrix, user_index, sketch_index, {iid: get_category_from_path(path) for path, iid in sketch_to_iid.items()}
def get_category_from_path(path):
    """Derive a ``<part2>_<part3>`` category key from a slash-separated path.

    e.g. ``"bucket/x/Female/Dress/img.png"`` -> ``"Female_Dress"``. Returns
    ``"unknown"`` for paths with fewer than four segments or on any error.
    """
    try:
        parts = path.split('/')
        # BUGFIX: the original checked len(parts) >= 2 but indexed
        # parts[2]/parts[3], relying on a bare except to mask the IndexError.
        if len(parts) >= 4:
            return f"{parts[2]}_{parts[3]}"
        return "unknown"
    except Exception:
        return "unknown"
def cosine_similarity(vec1, vec2):
    """Cosine similarity of two vectors; returns 0.0 when either has zero norm."""
    norm_product = np.linalg.norm(vec1) * np.linalg.norm(vec2)
    if norm_product == 0:
        return 0.0
    # Small epsilon keeps the division numerically safe for tiny norms.
    return np.dot(vec1, vec2) / (norm_product + 1e-10)
def fetch_user_behavior_data(days=30):
    """Load the last *days* days of user-behaviour rows into a DataFrame.

    Columns: account_id, behavior_type, gender, category, url, create_time.
    Returns an empty DataFrame on any database error; the connection is
    always closed.
    """
    conn = None
    try:
        conn = pymysql.connect(**DB_CONFIG)
        # Date window for the query.
        end_date = datetime.now()
        start_date = end_date - timedelta(days=days)
        # Parameterized query — the original interpolated the datetimes
        # directly into the SQL string.
        query = """
            SELECT
                account_id,
                behavior_type,
                gender,
                category,
                url,
                create_time
            FROM user_behavior
            WHERE create_time BETWEEN %s AND %s
        """
        df = pd.read_sql(query, conn, params=(start_date, end_date))
        logging.info(f"成功读取{len(df)}条用户行为记录")
        return df
    except Exception as e:
        logging.error(f"数据库查询失败: {str(e)}")
        return pd.DataFrame()
    finally:
        if conn:
            conn.close()
def calculate_heat(row, current_date):
    """Heat contribution of one behaviour row: ``weight * exp(-decay * age_days)``.

    Each row is scored independently; unknown behaviour types contribute 0.
    """
    age_days = (current_date - row['create_time']).days
    # Fall back to zero weight for behaviour types missing from the config.
    config = BEHAVIOR_CONFIG.get(row['behavior_type'], {'weight': 0, 'decay': 0})
    return config['weight'] * np.exp(-config['decay'] * age_days)
def load_heat_matrix_as_array(file_path):
    """Load a saved heat matrix as ``(2-D array, row_labels, col_labels)``.

    *file_path* must be a JSON file with ``data``, ``row_labels`` and
    ``col_labels`` keys.
    """
    with open(file_path) as fh:
        payload = json.load(fh)
    matrix = np.array(payload['data'])  # 2-D matrix
    return matrix, payload['row_labels'], payload['col_labels']
def update_heat_matrices():
    """Daily job: compute per-(gender_category, url) heat vectors and save them as JSON.

    Pulls 30 days of behaviour data, scores each row with ``calculate_heat``,
    sums heat per (gender_category, url), and writes the result (with an
    update timestamp) to ``heat_vectors_data/heat_vectors.json``. Returns the
    heat-vector dict, or ``None`` when there is no data.
    """
    current_date = datetime.now()
    # Fetch the behaviour window.
    df = fetch_user_behavior_data(30)
    if df.empty:
        logging.warning("无有效数据,跳过今日计算")
        return None
    # Per-row heat score and the gender_category grouping key.
    df['heat'] = df.apply(calculate_heat, axis=1, current_date=current_date)
    df['gender_category'] = df['gender'] + '_' + df['category']
    # Aggregate heat per (gender_category, url) into nested dicts.
    heat_vectors = {}
    grouped = df.groupby(['gender_category', 'url'])['heat'].sum()
    for (gender_category, url), heat in grouped.items():
        heat_vectors.setdefault(gender_category, {})[url] = heat
    # Persist the result (single rolling file; the dated name is disabled).
    save_path = 'heat_vectors_data'
    os.makedirs(save_path, exist_ok=True)
    date_str = current_date.strftime('%Y%m%d')
    # vectors_file = f"{save_path}/heat_vectors_{date_str}.json"
    vectors_file = f"{save_path}/heat_vectors.json"
    with open(vectors_file, 'w', encoding='utf-8') as f:
        json.dump({
            'update_time': current_date.strftime('%Y-%m-%d %H:%M:%S'),
            'data': heat_vectors
        }, f, ensure_ascii=False, indent=2)
    logging.info(f"成功存储热度向量,共{len(heat_vectors)}个分组,日期: {date_str}")
    return heat_vectors
if __name__ == "__main__":
    try:
        # Manual one-off runs (disabled; the scheduler triggers them instead).
        # update_user_matrices()
        # update_heat_matrices()
        # Block forever, running the matrix rebuild daily at 12:00 Shanghai time.
        scheduler = BlockingScheduler()
        scheduler.add_job(update_user_matrices, 'cron', hour=12, timezone='Asia/Shanghai')
        logging.info("定时任务已启动每天12:00执行")
        scheduler.start()
    except KeyboardInterrupt:
        logging.info("定时任务已停止")
    except Exception as e:
        logging.error(f"调度器启动失败: {str(e)}", exc_info=True)

View File

@@ -0,0 +1,240 @@
# 预加载资源
import logging
import time
from collections import defaultdict
import os
import json
import numpy as np
from app.core.config import DB_CONFIG, RECOMMEND_PATH_PREFIX
logger = logging.getLogger()
import pymysql
from concurrent.futures import ThreadPoolExecutor
# Path of the heat-vector JSON produced by the daily job (could be made configurable).
HEAT_VECTOR_FILE = 'heat_vectors_data/heat_vectors.json'
# Process-wide cache of every matrix, index mapping and per-user cache the
# recommender serves from; populated/refreshed by load_resources().
matrix_data = {
    "interaction_matrix": None,
    "feature_matrix": None,
    "user_index_interaction": None,
    "sketch_index_interaction": None,
    "user_index_feature": None,
    "sketch_index_feature": None,
    "iid_to_sketch": None,
    "category_to_iids": None,
    "cached_scores": {},
    "cached_valid_idxs": {},
    "category_sketch_idxs_inter": None,
    "category_sketch_idxs_feature": None,
    "user_inter_full": dict(),
    "user_feat_full": dict(),
    "brand_feature_matrix": None,
    "brand_index_map": None,
    "heat_data": {},
}
def load_resources():
    """Load every matrix / index map into ``matrix_data`` and warm the caches.

    Reads the ``.npy`` artifacts under ``RECOMMEND_PATH_PREFIX``, rebuilds the
    category -> iid index, triggers :func:`precache_user_category`, and finally
    loads the precomputed heat vectors from ``HEAT_VECTOR_FILE`` (optional).

    Raises:
        RuntimeError: if any resource fails to load; the original exception
            is chained as the cause.
    """
    try:
        start_time = time.time()
        # Drop previously cached scores so stale entries cannot survive a reload.
        matrix_data["cached_scores"].clear()
        matrix_data["cached_valid_idxs"].clear()
        # sketch -> iid mapping, inverted for iid -> sketch lookups.
        sketch_to_iid = np.load(f'{RECOMMEND_PATH_PREFIX}sketch_to_iid.npy', allow_pickle=True).item()
        matrix_data["iid_to_sketch"] = {v: k for k, v in sketch_to_iid.items()}
        matrix_data["interaction_matrix"] = np.load(
            f"{RECOMMEND_PATH_PREFIX}interaction_matrix.npy", allow_pickle=True)
        matrix_data["user_index_interaction"] = np.load(
            f"{RECOMMEND_PATH_PREFIX}user_index_interaction_matrix.npy", allow_pickle=True).item()
        matrix_data["sketch_index_interaction"] = np.load(
            f"{RECOMMEND_PATH_PREFIX}sketch_index_interaction_matrix.npy", allow_pickle=True).item()
        matrix_data["feature_matrix"] = np.load(
            f"{RECOMMEND_PATH_PREFIX}feature_matrix.npy", allow_pickle=True)
        # Brand artifacts are optional: fall back to empty containers when absent.
        brand_feature_path = f"{RECOMMEND_PATH_PREFIX}brand_feature_matrix.npy"
        if os.path.exists(brand_feature_path):
            matrix_data["brand_feature_matrix"] = np.load(brand_feature_path, allow_pickle=True)
        else:
            logger.warning("brand_feature_matrix 文件不存在,使用空数组")
            matrix_data["brand_feature_matrix"] = np.array([])
        brand_index_path = f"{RECOMMEND_PATH_PREFIX}brand_index_map.npy"
        if os.path.exists(brand_index_path):
            matrix_data["brand_index_map"] = np.load(brand_index_path, allow_pickle=True).item()
        else:
            logger.warning("brand_index_map 文件不存在,使用空字典")
            matrix_data["brand_index_map"] = {}
        matrix_data["user_index_feature"] = np.load(
            f"{RECOMMEND_PATH_PREFIX}user_index_feature_matrix.npy", allow_pickle=True).item()
        matrix_data["sketch_index_feature"] = np.load(
            f"{RECOMMEND_PATH_PREFIX}sketch_index_feature_matrix.npy", allow_pickle=True).item()
        # Invert iid -> category into category -> [iids] for fast per-category lookups.
        category_to_iid_map = np.load(
            f"{RECOMMEND_PATH_PREFIX}iid_to_category_interaction_matrix.npy", allow_pickle=True).item()
        matrix_data["category_to_iids"] = defaultdict(list)
        for iid, cat in category_to_iid_map.items():
            matrix_data["category_to_iids"][cat].append(iid)
        logger.info(f"资源加载完成,耗时: {time.time() - start_time:.2f}")
        # Warm the per-(user, category) score caches now that matrices exist.
        precache_user_category()
        # Heat vectors are optional as well.
        if os.path.exists(HEAT_VECTOR_FILE):
            with open(HEAT_VECTOR_FILE, 'r', encoding='utf-8') as f:
                heat_json = json.load(f)
            matrix_data["heat_data"] = heat_json.get("data", {})
            logger.info(f"热度向量数据加载完成,共加载 {len(matrix_data['heat_data'])} 个类别")
        else:
            matrix_data["heat_data"] = {}
    except Exception as e:
        # exc_info keeps the traceback in the log; "from e" preserves the
        # cause chain for callers instead of masking the real failure.
        logger.error(f"资源加载失败: {str(e)}", exc_info=True)
        raise RuntimeError("初始化失败") from e
def precache_user_category():
    """Precompute per-(user, category) weighted score arrays into matrix_data,
    with per-phase timing statistics logged at the end.

    For every (user, category) pair found in the preference log it caches:
      - interaction scores * 0.7 (from interaction_matrix)
      - min-max-normalized feature scores * 0.3 (from feature_matrix)
    plus the interaction-matrix column indices the scores correspond to.
    """
    # Guard: all required resources must already be loaded by load_resources().
    if not all([
        matrix_data["interaction_matrix"] is not None,
        matrix_data["feature_matrix"] is not None,
        matrix_data["user_index_interaction"] is not None
    ]):
        logger.warning("资源未加载完成,跳过预缓存")
        return
    start_time = time.perf_counter()
    # Accumulated wall-clock time per phase, reported in the final log line.
    time_stats = {
        "get_all_user_categories": 0,
        "process_user_category": 0,
        "thread_execution": 0,
        "cache_update": 0,
        "total": 0,
    }
    # Time the DB fetch of all users and their categories.
    t1 = time.perf_counter()
    user_categories = get_all_user_categories()
    time_stats["get_all_user_categories"] = time.perf_counter() - t1
    precached_count = 0
    def process_user_category(user_id, categories):
        """Compute the cache entries for one user's categories (runs in a worker thread).

        Returns (local_cache, local_valid_idxs) dicts keyed by (user_id, category).
        NOTE(review): this closure mutates time_stats from multiple threads
        without a lock — stats may under-count, though the cache data itself
        is built thread-locally.
        """
        local_cache = {}
        local_valid_idxs = {}
        t_start = time.perf_counter()  # NOTE(review): assigned but never used
        for category in categories:
            cache_key = (user_id, category)
            # Skip pairs already cached (e.g. from a previous warm-up).
            if cache_key in matrix_data["cached_scores"]:
                continue
            try:
                user_idx_inter = matrix_data["user_index_interaction"].get(user_id)
                user_idx_feature = matrix_data["user_index_feature"].get(user_id)
                # Time the category -> valid column-index resolution.
                t_iid = time.perf_counter()
                category_iids = matrix_data["category_to_iids"].get(category, [])
                valid_sketch_idxs_inter = [matrix_data["sketch_index_interaction"][iid]
                                           for iid in category_iids if iid in matrix_data["sketch_index_interaction"]]
                valid_sketch_idxs_feature = [matrix_data["sketch_index_feature"][iid]
                                             for iid in category_iids if iid in matrix_data["sketch_index_feature"]]
                time_stats["process_user_category"] += time.perf_counter() - t_iid
                # Time the matrix slicing / weighting work.
                t_matrix = time.perf_counter()
                processed_inter = np.zeros(len(valid_sketch_idxs_inter))
                if user_idx_inter is not None and valid_sketch_idxs_inter:
                    raw_inter_scores = matrix_data["interaction_matrix"][user_idx_inter, valid_sketch_idxs_inter]
                    processed_inter = raw_inter_scores * 0.7  # interaction weight
                processed_feat = np.zeros(len(valid_sketch_idxs_feature))
                if user_idx_feature is not None and valid_sketch_idxs_feature:
                    raw_feat_scores = matrix_data["feature_matrix"][user_idx_feature, valid_sketch_idxs_feature]
                    # Min-max normalize (epsilon avoids division by zero on constant rows).
                    raw_feat_scores = (raw_feat_scores - np.min(raw_feat_scores)) / (
                        np.max(raw_feat_scores) - np.min(raw_feat_scores) + 1e-8)
                    processed_feat = raw_feat_scores * 0.3  # feature weight
                time_stats["process_user_category"] += time.perf_counter() - t_matrix
                # Only cache when both score arrays align element-for-element.
                if len(processed_inter) == len(processed_feat):
                    local_cache[cache_key] = (processed_inter, processed_feat)
                    local_valid_idxs[cache_key] = valid_sketch_idxs_inter
            except Exception as e:
                logger.error(f"预缓存失败 (user={user_id}, category={category}): {str(e)}")
        return local_cache, local_valid_idxs
    # Time the threaded fan-out over all users.
    t2 = time.perf_counter()
    with ThreadPoolExecutor(max_workers=8) as executor:
        futures = {executor.submit(process_user_category, user_id, categories): user_id
                   for user_id, categories in user_categories.items()}
        # NOTE(review): iterating the dict consumes results in submission order,
        # not completion order (as_completed would overlap waits).
        for future in futures:
            try:
                t_cache = time.perf_counter()
                cache_part, valid_idxs_part = future.result()
                matrix_data["cached_scores"].update(cache_part)
                matrix_data["cached_valid_idxs"].update(valid_idxs_part)
                time_stats["cache_update"] += time.perf_counter() - t_cache
                precached_count += len(cache_part)
            except Exception as e:
                logger.error(f"线程执行错误: {str(e)}")
    time_stats["thread_execution"] = time.perf_counter() - t2
    time_stats["total"] = time.perf_counter() - start_time
    # Emit the per-phase timing summary.
    logger.info(f"""
    预缓存完成,共缓存 {precached_count} 组数据,耗时统计如下:
    - 获取用户类别数据: {time_stats["get_all_user_categories"]:.2f}s
    - 计算用户类别缓存: {time_stats["process_user_category"]:.2f}s
    - 线程任务执行: {time_stats["thread_execution"]:.2f}s
    - 更新缓存数据: {time_stats["cache_update"]:.2f}s
    - 总耗时: {time_stats["total"]:.2f}s
    """)
def get_all_user_categories():
    """Fetch every user and the set of categories they have interacted with.

    Queries distinct (account_id, path) rows from the preference-log table
    and maps each path to a category via :func:`get_category_from_path`.

    Returns:
        dict: account_id -> set of category strings; empty dict on any
        database failure (error is logged, not raised).
    """
    conn = None
    try:
        conn = pymysql.connect(**DB_CONFIG)
        # Context manager guarantees the cursor is closed even on query errors
        # (the bare cursor in the previous version was never closed).
        with conn.cursor() as cursor:
            query = """
            SELECT DISTINCT account_id, path
            FROM user_preference_log_prediction
            """
            cursor.execute(query)
            results = cursor.fetchall()
        user_categories = defaultdict(set)
        for account_id, path in results:
            user_categories[account_id].add(get_category_from_path(path))
        return dict(user_categories)
    except Exception as e:
        logger.error(f"数据库查询失败: {str(e)}")
        return {}
    finally:
        if conn:
            conn.close()
def get_category_from_path(path: str) -> str:
    """Derive a category label from a slash-separated resource path.

    Joins path segments 2 and 3 (0-based) with an underscore, matching the
    "{gender}_{category}" keys used by the heat/cache structures.

    Args:
        path: Slash-separated path string; may be None or a non-string.

    Returns:
        "<segment2>_<segment3>" when the path has at least 4 segments,
        otherwise "unknown".
    """
    try:
        parts = path.split('/')
    # Narrowed from a bare "except:" (which also swallowed KeyboardInterrupt/
    # SystemExit): only None / non-string inputs can fail here.
    except (AttributeError, TypeError):
        return "unknown"
    if len(parts) >= 4:
        return f"{parts[2]}_{parts[3]}"
    return "unknown"

View File

@@ -6,7 +6,7 @@ from chromadb.config import Settings
from chromadb.utils.embedding_functions.ollama_embedding_function import OllamaEmbeddingFunction
from tqdm import tqdm
from app.core.config import OLLAMA_URL
from app.core.config import OLLAMA_URL, CHROMADB_PATH
# 读取 csv 文件
# csv_file_path = r'D:/Files/csv/output/output.csv'
@@ -15,7 +15,7 @@ from app.core.config import OLLAMA_URL
# df = pd.read_csv(csv_file_path, encoding='Windows-1252')
# 创建 Chroma 客户端
client = chromadb.Client(Settings(is_persistent=True, persist_directory="/vector_db"))
client = chromadb.Client(Settings(is_persistent=True, persist_directory=CHROMADB_PATH))
# client = chromadb.Client(Settings(is_persistent=True, persist_directory="./service/search_image_with_text/vector_db"))
# client = chromadb.Client(Settings(is_persistent=True, persist_directory="D:/workspace/AiDLab/vector_db"))
# 创建集合

View File

@@ -0,0 +1,99 @@
import redis
from app.core.config import REDIS_HOST, REDIS_PORT
class Redis(object):
    """
    Convenience wrapper around basic Redis string and hash operations.

    Every call opens a fresh ``StrictRedis`` client against the configured
    host/port (database 0); non-empty values read back are decoded as UTF-8.
    """

    @staticmethod
    def _get_r():
        # Build a client for the configured endpoint, database 0.
        return redis.StrictRedis(REDIS_HOST, REDIS_PORT, 0)

    @classmethod
    def write(cls, key, value, expire=None):
        """Store a key/value pair; TTL falls back to 100 seconds."""
        cls._get_r().set(key, value, ex=expire or 100)

    @classmethod
    def read(cls, key):
        """Return the decoded value of *key*; None/empty values pass through."""
        raw = cls._get_r().get(key)
        if not raw:
            return raw
        return raw.decode('utf-8')

    @classmethod
    def hset(cls, name, key, value):
        """Set field *key* of hash *name* to *value*."""
        cls._get_r().hset(name, key, value)

    @classmethod
    def hget(cls, name, key):
        """Return the decoded value of field *key* in hash *name*."""
        raw = cls._get_r().hget(name, key)
        if not raw:
            return raw
        return raw.decode('utf-8')

    @classmethod
    def hgetall(cls, name):
        """Return all field/value pairs of hash *name* (raw bytes)."""
        return cls._get_r().hgetall(name)

    @classmethod
    def delete(cls, *names):
        """Delete one or more keys."""
        cls._get_r().delete(*names)

    @classmethod
    def hdel(cls, name, key):
        """Delete field *key* from hash *name*."""
        cls._get_r().hdel(name, key)

    @classmethod
    def expire(cls, name, expire=None):
        """Set a TTL on *name*; falls back to 100 seconds."""
        cls._get_r().expire(name, expire or 100)
if __name__ == '__main__':
    # Ad-hoc smoke test: write one key with the default 100-second TTL.
    redis_client = Redis()
    # print(redis_client.write(key="1230", value=0))
    redis_client.write(key="1230", value=10)
    # print(redis_client.read(key="1230"))

18
pyproject.toml Executable file
View File

@@ -0,0 +1,18 @@
[project]
name = "trinity-client-aida"
version = "0.1.0"
description = "Add your description here"
requires-python = ">=3.12"
dependencies = [
"apscheduler>=3.11.0",
"celery>=5.5.3",
"geventhttpclient>=2.3.4",
"google-search-results>=2.4.2",
"moviepy>=2.2.1",
"numpy==1.26.4",
"pandas-stubs==2.2.3.250527",
"pika-stubs==0.1.3",
"python-multipart>=0.0.20",
"tritonclient[all]>=2.58.0",
"types-urllib3==1.26.25.14",
]

Binary file not shown.

1193
uv.lock generated Executable file

File diff suppressed because it is too large Load Diff