feat : 代码梳理 移除所有敏感密钥 通过环境变量方式配置
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
This commit is contained in:
4
.gitignore
vendored
4
.gitignore
vendored
@@ -148,4 +148,6 @@ app/logs/*
|
||||
*.pickle
|
||||
*.csv
|
||||
*.avi
|
||||
*.json
|
||||
*.json
|
||||
*.env*
|
||||
config.backup.py
|
||||
@@ -23,11 +23,11 @@
|
||||
$ pip install mmcv==1.4.2 -f https://download.openmmlab.com/mmcv/dist/cu117/torch1.13/index.html
|
||||
|
||||
|
||||
2. 启动服务器
|
||||
1. 启动服务器
|
||||
|
||||
$ uvicorn app.main:app --host 0.0.0.0 --port 8000
|
||||
|
||||
3. 打开 http://127.0.0.1:8000/docs
|
||||
2. 打开 http://127.0.0.1:8000/docs
|
||||
|
||||
Docker 部署
|
||||
---------------
|
||||
|
||||
@@ -2,8 +2,7 @@ import json
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
|
||||
from app.core.config import DEBUG
|
||||
from app.core.config import settings
|
||||
from app.schemas.attribute_retrieve import *
|
||||
from app.schemas.response_template import ResponseModel
|
||||
from app.service.attribute.config import const, local_debug_const
|
||||
@@ -35,13 +34,13 @@ def attribute_recognition(request_item: list[AttributeRecognitionModel]):
|
||||
"""
|
||||
try:
|
||||
for item in request_item:
|
||||
logger.debug(f"attribute_recognition request item is : @@@@@@:{json.dumps(item.dict())}")
|
||||
if DEBUG:
|
||||
logger.info(f"attribute_recognition request item is : @@@@@@:{json.dumps(item.dict(), indent=4)}")
|
||||
if settings.DEBUG:
|
||||
service = AttributeRecognition(const=local_debug_const, request_data=request_item)
|
||||
else:
|
||||
service = AttributeRecognition(const=const, request_data=request_item)
|
||||
data = service.get_result()
|
||||
logger.debug(f"attribute_recognition response @@@@@@:{json.dumps(data)}")
|
||||
logger.info(f"attribute_recognition response @@@@@@:{json.dumps(data, indent=4)}")
|
||||
except Exception as e:
|
||||
logger.warning(f"attribute_recognition Run Exception @@@@@@:{e}")
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
@@ -67,10 +66,10 @@ def category_recognition(request_item: list[CategoryRecognitionModel]):
|
||||
"""
|
||||
try:
|
||||
for item in request_item:
|
||||
logger.info(f"category_recognition request item is : @@@@@@:{json.dumps(item.dict())}")
|
||||
logger.info(f"category_recognition request item is : @@@@@@:{json.dumps(item.dict(), indent=4)}")
|
||||
service = CategoryRecognition(request_data=request_item)
|
||||
data = service.get_result()
|
||||
logger.info(f"category_recognition response @@@@@@:{json.dumps(data)}")
|
||||
logger.info(f"category_recognition response @@@@@@:{json.dumps(data, indent=4)}")
|
||||
except Exception as e:
|
||||
logger.warning(f"category_recognition Run Exception @@@@@@:{e}")
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
@@ -26,7 +26,7 @@ def seg_product(request_item: BrandDnaModel):
|
||||
}
|
||||
"""
|
||||
try:
|
||||
logger.info(f"brand dna request item is : @@@@@@:{json.dumps(request_item.dict())}")
|
||||
logger.info(f"brand dna request item is : @@@@@@:{json.dumps(request_item.dict(),indent=4)}")
|
||||
service = BrandDna(request_item)
|
||||
result_url = service.get_result()
|
||||
except Exception as e:
|
||||
@@ -36,7 +36,7 @@ def seg_product(request_item: BrandDnaModel):
|
||||
|
||||
|
||||
@router.post("/GenerateBrand")
|
||||
def GenerateBrand(request_data: GenerateBrandModel):
|
||||
def generate_brand(request_data: GenerateBrandModel):
|
||||
"""
|
||||
通过prompt 生成 brand name ,brand slogan , brand logo。
|
||||
创建一个具有以下参数的请求体:
|
||||
|
||||
@@ -1,34 +1,25 @@
|
||||
import io
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
from typing import List
|
||||
from collections import defaultdict
|
||||
import numpy as np
|
||||
from apscheduler.schedulers.background import BackgroundScheduler
|
||||
from apscheduler.triggers.cron import CronTrigger
|
||||
from fastapi import HTTPException, APIRouter
|
||||
|
||||
import pymysql
|
||||
from app.core.config import DB_CONFIG, TABLE_CATEGORIES, RECOMMEND_PATH_PREFIX
|
||||
from minio import Minio
|
||||
import torch
|
||||
from torchvision import models, transforms
|
||||
from PIL import Image
|
||||
import os
|
||||
import sys
|
||||
from collections import defaultdict
|
||||
|
||||
import numpy as np
|
||||
import pymysql
|
||||
import torch
|
||||
from PIL import Image
|
||||
from fastapi import HTTPException, APIRouter
|
||||
from fastapi.responses import JSONResponse
|
||||
from minio import Minio
|
||||
from torchvision import models, transforms
|
||||
|
||||
from app.core.mysql_config import DB_CONFIG
|
||||
from app.core.new_config import settings
|
||||
|
||||
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')
|
||||
logger = logging.getLogger()
|
||||
router = APIRouter()
|
||||
|
||||
# MinIO 配置
|
||||
minio_client = Minio(
|
||||
"www.minio.aida.com.hk:12024",
|
||||
access_key="admin",
|
||||
secret_key="Aidlab123123!",
|
||||
secure=True
|
||||
)
|
||||
minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)
|
||||
|
||||
transform = transforms.Compose([
|
||||
transforms.Resize((224, 224)),
|
||||
@@ -67,8 +58,8 @@ def extract_feature_vector_from_resnet(sketch_path: str) -> np.ndarray:
|
||||
|
||||
|
||||
# 预加载
|
||||
BRAND_FEATURES = np.load(f'{RECOMMEND_PATH_PREFIX}brand_feature.npy', allow_pickle=True).item()
|
||||
SYSTEM_FEATURES = np.load(f'{RECOMMEND_PATH_PREFIX}sketch_feature_dict.npy', allow_pickle=True).item()
|
||||
BRAND_FEATURES = np.load(f'{settings.RECOMMEND_PATH_PREFIX}brand_feature.npy', allow_pickle=True).item()
|
||||
SYSTEM_FEATURES = np.load(f'{settings.RECOMMEND_PATH_PREFIX}sketch_feature_dict.npy', allow_pickle=True).item()
|
||||
|
||||
|
||||
def save_sketch_to_iid():
|
||||
@@ -76,11 +67,11 @@ def save_sketch_to_iid():
|
||||
sketch_path: iid
|
||||
for iid, sketch_path in enumerate(SYSTEM_FEATURES.keys(), start=1)
|
||||
}
|
||||
np.save(f"{RECOMMEND_PATH_PREFIX}sketch_to_iid.npy", sketch_to_iid)
|
||||
np.save(f"{settings.RECOMMEND_PATH_PREFIX}sketch_to_iid.npy", sketch_to_iid)
|
||||
|
||||
|
||||
def load_sketch_to_iid():
|
||||
path = f"{RECOMMEND_PATH_PREFIX}sketch_to_iid.npy"
|
||||
path = f"{settings.RECOMMEND_PATH_PREFIX}sketch_to_iid.npy"
|
||||
if os.path.exists(path):
|
||||
return np.load(path, allow_pickle=True).item()
|
||||
save_sketch_to_iid()
|
||||
@@ -90,7 +81,7 @@ def load_sketch_to_iid():
|
||||
sketch_to_iid = load_sketch_to_iid()
|
||||
|
||||
|
||||
def getNewCategory(gender: str, sketch_category: str) -> str:
|
||||
def get_new_category(gender: str, sketch_category: str) -> str:
|
||||
return f"{gender.lower()}_{sketch_category.lower()}"
|
||||
|
||||
|
||||
@@ -103,8 +94,8 @@ def get_category_from_path(path: str) -> str:
|
||||
|
||||
def load_brand_matrix():
|
||||
"""单独加载 brand_matrix 和 brand_index_map"""
|
||||
mat_path = f"{RECOMMEND_PATH_PREFIX}brand_matrix.npy"
|
||||
idx_path = f"{RECOMMEND_PATH_PREFIX}brand_index_map.npy"
|
||||
mat_path = f"{settings.RECOMMEND_PATH_PREFIX}brand_matrix.npy"
|
||||
idx_path = f"{settings.RECOMMEND_PATH_PREFIX}brand_index_map.npy"
|
||||
try:
|
||||
matrix = np.load(mat_path)
|
||||
index_map = np.load(idx_path, allow_pickle=True).item()
|
||||
@@ -113,11 +104,19 @@ def load_brand_matrix():
|
||||
index_map = {}
|
||||
return matrix, index_map
|
||||
|
||||
|
||||
def cosine_similarity(vec1, vec2):
|
||||
"""计算余弦相似度(增加零值处理)"""
|
||||
norm = np.linalg.norm(vec1) * np.linalg.norm(vec2)
|
||||
return np.dot(vec1, vec2) / (norm + 1e-10) if norm != 0 else 0.0
|
||||
|
||||
|
||||
def getNewCategory(gender, sketch_category):
|
||||
print(gender)
|
||||
print(sketch_category)
|
||||
return "None"
|
||||
|
||||
|
||||
def calculate_brand_matrix(sketch_data, brand_id: int) -> np.ndarray:
|
||||
# 1. 收集品牌-分类-特征
|
||||
brand_feature = defaultdict(lambda: defaultdict(list))
|
||||
@@ -164,11 +163,11 @@ def calculate_brand_matrix(sketch_data, brand_id: int) -> np.ndarray:
|
||||
brand_matrix[row_idx, sketch_index[iid]] = cos_sim
|
||||
|
||||
# 7. 持久化
|
||||
np.save(f"{RECOMMEND_PATH_PREFIX}brand_feature_matrix.npy", brand_matrix)
|
||||
np.save(f"{RECOMMEND_PATH_PREFIX}brand_index_map.npy", brand_index_map)
|
||||
np.save(f"{settings.RECOMMEND_PATH_PREFIX}brand_feature_matrix.npy", brand_matrix)
|
||||
np.save(f"{settings.RECOMMEND_PATH_PREFIX}brand_index_map.npy", brand_index_map)
|
||||
|
||||
# 返回该品牌对应行
|
||||
return brand_matrix[row_idx:row_idx+1]
|
||||
return brand_matrix[row_idx:row_idx + 1]
|
||||
|
||||
|
||||
@router.get("/brand_dna_initialize/{brand_id}")
|
||||
@@ -178,14 +177,12 @@ async def brand_dna_initialize(brand_id: int):
|
||||
conn = pymysql.connect(**DB_CONFIG)
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("""
|
||||
SELECT id, img_url, gender, category
|
||||
FROM product_image_attribute
|
||||
WHERE library_id IN (
|
||||
SELECT library_id
|
||||
FROM brand_rel_library
|
||||
WHERE brand_id = %s
|
||||
)
|
||||
""", (brand_id,))
|
||||
SELECT id, img_url, gender, category
|
||||
FROM product_image_attribute
|
||||
WHERE library_id IN (SELECT library_id
|
||||
FROM brand_rel_library
|
||||
WHERE brand_id = %s)
|
||||
""", (brand_id,))
|
||||
sketch_data = cursor.fetchall()
|
||||
|
||||
# 触发计算并持久化,若内部出错会抛异常
|
||||
|
||||
@@ -5,10 +5,11 @@ import time
|
||||
|
||||
from PIL import ImageEnhance
|
||||
from fastapi import APIRouter, HTTPException
|
||||
|
||||
from minio import Minio
|
||||
from app.core.config import settings
|
||||
from app.schemas.brighten import BrightenModel
|
||||
from app.schemas.response_template import ResponseModel
|
||||
from app.service.utils.oss_client import oss_get_image, oss_upload_image
|
||||
from app.service.utils.new_oss_client import oss_get_image, oss_upload_image
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger()
|
||||
@@ -20,6 +21,9 @@ def increase_brightness(img, factor):
|
||||
return bright_img
|
||||
|
||||
|
||||
minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)
|
||||
|
||||
|
||||
@router.post("/brighten")
|
||||
async def brighten(request_item: BrightenModel):
|
||||
"""
|
||||
@@ -35,14 +39,14 @@ async def brighten(request_item: BrightenModel):
|
||||
"""
|
||||
try:
|
||||
start_time = time.time()
|
||||
logger.info(f"brighten request item is : @@@@@@:{json.dumps(request_item.dict())}")
|
||||
image = oss_get_image(bucket=request_item.image_url.split('/')[0], object_name=request_item.image_url[request_item.image_url.find('/') + 1:], data_type="PIL")
|
||||
logger.info(f"brighten request item is : @@@@@@:{json.dumps(request_item.dict(),indent=4)}")
|
||||
image = oss_get_image(oss_client=minio_client, bucket=request_item.image_url.split('/')[0], object_name=request_item.image_url[request_item.image_url.find('/') + 1:], data_type="PIL")
|
||||
new_image = increase_brightness(image, request_item.brighten_value)
|
||||
image_data = io.BytesIO()
|
||||
new_image.save(image_data, format='PNG')
|
||||
image_data.seek(0)
|
||||
image_bytes = image_data.read()
|
||||
req = oss_upload_image(bucket=request_item.image_url.split('/')[0], object_name=request_item.image_url[request_item.image_url.find('/') + 1:], image_bytes=image_bytes)
|
||||
req = oss_upload_image(oss_client=minio_client, bucket=request_item.image_url.split('/')[0], object_name=request_item.image_url[request_item.image_url.find('/') + 1:], image_bytes=image_bytes)
|
||||
brighten_url = f"{req.bucket_name}/{req.object_name}"
|
||||
logger.info(f"run time is : {time.time() - start_time}")
|
||||
except Exception as e:
|
||||
|
||||
@@ -30,9 +30,9 @@ def chat_robot(request_data: ChatRobotModel):
|
||||
}
|
||||
"""
|
||||
try:
|
||||
logger.info(f"chat_robot request item is : @@@@@@:{json.dumps(request_data.dict())}")
|
||||
logger.info(f"chat_robot request item is : @@@@@@:{json.dumps(request_data.dict(),indent=4)}")
|
||||
data = chat(post_data=request_data)
|
||||
logger.info(f"chat_robot response @@@@@@:{json.dumps(data)}")
|
||||
logger.info(f"chat_robot response @@@@@@:{json.dumps(data, indent=4)}")
|
||||
except Exception as e:
|
||||
logger.warning(f"chat_robot Run Exception @@@@@@:{e}")
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
@@ -42,7 +42,7 @@ def clothing_seg(request_item: ClothingSegModel):
|
||||
}
|
||||
"""
|
||||
try:
|
||||
logger.info(f"clothing_seg request item is : @@@@@@:{json.dumps(request_item.dict())}")
|
||||
logger.info(f"clothing_seg request item is : @@@@@@:{json.dumps(request_item.dict(),indent=4)}")
|
||||
server = ClothingSeg(request_item)
|
||||
result_url = server.get_result()
|
||||
except Exception as e:
|
||||
|
||||
@@ -1,203 +1,201 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
|
||||
from fastapi import APIRouter, HTTPException, UploadFile, File, Form, BackgroundTasks
|
||||
from fastapi import APIRouter, HTTPException, BackgroundTasks
|
||||
|
||||
from app.schemas.design import DesignModel, DesignProgressModel, ModelProgressModel, DBGConfigModel, DesignStreamModel
|
||||
from app.schemas.design import DesignModel, ModelProgressModel, DesignStreamModel
|
||||
from app.schemas.response_template import ResponseModel
|
||||
from app.service.design.model_process_service import model_transpose
|
||||
from app.service.design_batch.service import start_design_batch_generate
|
||||
from app.service.design_fast.design_generate import design_generate, design_generate_v2
|
||||
from app.service.design_fast.utils.redis_utils import Redis
|
||||
from app.service.design_fast.model_process_service import model_transpose
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
@router.post("/design")
|
||||
def design(request_data: DesignModel, background_tasks: BackgroundTasks):
|
||||
def design(request_data: DesignModel):
|
||||
"""
|
||||
objects.items.transparent:
|
||||
"transparent":{
|
||||
"mask_url":"test/transparent_test/transparent_mask.png",
|
||||
"scale":0.1
|
||||
},
|
||||
mask_url 为空"" -> 单件衣服透明
|
||||
mask_url 非空"mask_url" -> 区域透明
|
||||
objects.items.transparent:
|
||||
"transparent":{
|
||||
"mask_url":"test/transparent_test/transparent_mask.png",
|
||||
"scale":0.1
|
||||
},
|
||||
mask_url 为空"" -> 单件衣服透明
|
||||
mask_url 非空"mask_url" -> 区域透明
|
||||
|
||||
创建一个具有以下参数的请求体:
|
||||
示例参数:
|
||||
{
|
||||
"objects": [
|
||||
{
|
||||
"basic": {
|
||||
"body_point_test": {
|
||||
"waistband_right": [
|
||||
200,
|
||||
241
|
||||
],
|
||||
"hand_point_right": [
|
||||
223,
|
||||
297
|
||||
],
|
||||
"waistband_left": [
|
||||
112,
|
||||
241
|
||||
],
|
||||
"hand_point_left": [
|
||||
92,
|
||||
305
|
||||
],
|
||||
"shoulder_left": [
|
||||
99,
|
||||
116
|
||||
],
|
||||
"shoulder_right": [
|
||||
215,
|
||||
116
|
||||
]
|
||||
},
|
||||
"layer_order": true,
|
||||
"scale_bag": 0.7,
|
||||
"scale_earrings": 0.16,
|
||||
"self_template": true,
|
||||
"single_overall": "overall",
|
||||
"switch_category": ""
|
||||
},
|
||||
"items": [
|
||||
{
|
||||
"businessId": 270372,
|
||||
"color": "30 28 28",
|
||||
"image_id": 69780,
|
||||
"offset": [
|
||||
0,
|
||||
0
|
||||
],
|
||||
"path": "aida-sys-image/images/female/trousers/0825000630.jpg",
|
||||
"print": {
|
||||
"element": {
|
||||
"element_angle_list": [],
|
||||
"element_path_list": [],
|
||||
"element_scale_list": [],
|
||||
"location": []
|
||||
},
|
||||
"overall": {
|
||||
"location": [],
|
||||
"print_angle_list": [],
|
||||
"print_path_list": [],
|
||||
"print_scale_list": []
|
||||
},
|
||||
"single": {
|
||||
"location": [],
|
||||
"print_angle_list": [],
|
||||
"print_path_list": [],
|
||||
"print_scale_list": []
|
||||
}
|
||||
创建一个具有以下参数的请求体:
|
||||
示例参数:
|
||||
{
|
||||
"objects": [
|
||||
{
|
||||
"basic": {
|
||||
"body_point_test": {
|
||||
"waistband_right": [
|
||||
203,
|
||||
249
|
||||
],
|
||||
"hand_point_right": [
|
||||
229,
|
||||
343
|
||||
],
|
||||
"waistband_left": [
|
||||
119,
|
||||
248
|
||||
],
|
||||
"hand_point_left": [
|
||||
97,
|
||||
343
|
||||
],
|
||||
"shoulder_left": [
|
||||
108,
|
||||
107
|
||||
],
|
||||
"shoulder_right": [
|
||||
212,
|
||||
107
|
||||
]
|
||||
},
|
||||
"priority": 10,
|
||||
"resize_scale": [
|
||||
1.0,
|
||||
1.0
|
||||
],
|
||||
"type": "Trousers"
|
||||
"layer_order": true,
|
||||
"preview_submit": "submit",
|
||||
"scale_bag": 0.7,
|
||||
"scale_earrings": 0.16,
|
||||
"self_template": true,
|
||||
"single_overall": "overall",
|
||||
"switch_category": ""
|
||||
},
|
||||
{
|
||||
"businessId": 270373,
|
||||
"color": "30 28 28",
|
||||
"image_id": 98243,
|
||||
"offset": [
|
||||
0,
|
||||
0
|
||||
],
|
||||
"path": "aida-sys-image/images/female/blouse/0902003811.jpg",
|
||||
"print": {
|
||||
"element": {
|
||||
"element_angle_list": [],
|
||||
"element_path_list": [],
|
||||
"element_scale_list": [],
|
||||
"location": []
|
||||
"items": [
|
||||
{
|
||||
"businessId": 2377945,
|
||||
"color": "209 196 171",
|
||||
"image_id": 189410,
|
||||
"offset": [
|
||||
0,
|
||||
0
|
||||
],
|
||||
"path": "aida-collection-element/89/Sketchboard/53d38bd5-f77b-4034-ada2-45f1e2ebe00c.png",
|
||||
"print": {
|
||||
"element": {
|
||||
"element_angle_list": [],
|
||||
"element_path_list": [],
|
||||
"element_scale_list": [],
|
||||
"location": []
|
||||
},
|
||||
"overall": {
|
||||
"location": [],
|
||||
"print_angle_list": [],
|
||||
"print_path_list": [],
|
||||
"print_scale_list": []
|
||||
},
|
||||
"single": {
|
||||
"location": [],
|
||||
"print_angle_list": [],
|
||||
"print_path_list": [],
|
||||
"print_scale_list": []
|
||||
}
|
||||
},
|
||||
"overall": {
|
||||
"location": [],
|
||||
"print_angle_list": [],
|
||||
"print_path_list": [],
|
||||
"print_scale_list": []
|
||||
},
|
||||
"single": {
|
||||
"location": [],
|
||||
"print_angle_list": [],
|
||||
"print_path_list": [],
|
||||
"print_scale_list": []
|
||||
}
|
||||
"priority": 12,
|
||||
"resize_scale": [
|
||||
1.0,
|
||||
1.0
|
||||
],
|
||||
"seg_mask_url": "aida-clothing/mask/mask_8e96ddb0-e466-11f0-8de2-0242ac130002.png",
|
||||
"type": "Outwear"
|
||||
},
|
||||
"priority": 11,
|
||||
"resize_scale": [
|
||||
1.0,
|
||||
1.0
|
||||
],
|
||||
"type": "Blouse"
|
||||
},
|
||||
{
|
||||
"businessId": 270374,
|
||||
"color": "172 68 68",
|
||||
"image_id": 98244,
|
||||
"offset": [
|
||||
0,
|
||||
0
|
||||
],
|
||||
"path": "aida-sys-image/images/female/outwear/0825000410.jpg",
|
||||
"print": {
|
||||
"element": {
|
||||
"element_angle_list": [],
|
||||
"element_path_list": [],
|
||||
"element_scale_list": [],
|
||||
"location": []
|
||||
{
|
||||
"businessId": 2377946,
|
||||
"color": "122 152 139",
|
||||
"image_id": 81868,
|
||||
"offset": [
|
||||
0,
|
||||
0
|
||||
],
|
||||
"path": "aida-sys-image/images/female/blouse/0825001443.jpg",
|
||||
"print": {
|
||||
"element": {
|
||||
"element_angle_list": [],
|
||||
"element_path_list": [],
|
||||
"element_scale_list": [],
|
||||
"location": []
|
||||
},
|
||||
"overall": {
|
||||
"location": [],
|
||||
"print_angle_list": [],
|
||||
"print_path_list": [],
|
||||
"print_scale_list": []
|
||||
},
|
||||
"single": {
|
||||
"location": [],
|
||||
"print_angle_list": [],
|
||||
"print_path_list": [],
|
||||
"print_scale_list": []
|
||||
}
|
||||
},
|
||||
"overall": {
|
||||
"location": [],
|
||||
"print_angle_list": [],
|
||||
"print_path_list": [],
|
||||
"print_scale_list": []
|
||||
"priority": 11,
|
||||
"resize_scale": [
|
||||
1.0,
|
||||
1.0
|
||||
],
|
||||
"seg_mask_url": "aida-clothing/mask/mask_8f0fab78-e466-11f0-8de2-0242ac130002.png",
|
||||
"type": "Blouse"
|
||||
},
|
||||
{
|
||||
"businessId": 2377947,
|
||||
"color": "111 78 63",
|
||||
"gradient": "aida-gradient/517c3a4d-aed7-4423-aa99-7b60d3577df1.png",
|
||||
"image_id": 116494,
|
||||
"offset": [
|
||||
0,
|
||||
0
|
||||
],
|
||||
"path": "aida-sys-image/images/female/skirt/0825000219.jpg",
|
||||
"print": {
|
||||
"element": {
|
||||
"element_angle_list": [],
|
||||
"element_path_list": [],
|
||||
"element_scale_list": [],
|
||||
"location": []
|
||||
},
|
||||
"overall": {
|
||||
"location": [],
|
||||
"print_angle_list": [],
|
||||
"print_path_list": [],
|
||||
"print_scale_list": []
|
||||
},
|
||||
"single": {
|
||||
"location": [],
|
||||
"print_angle_list": [],
|
||||
"print_path_list": [],
|
||||
"print_scale_list": []
|
||||
}
|
||||
},
|
||||
"single": {
|
||||
"location": [],
|
||||
"print_angle_list": [],
|
||||
"print_path_list": [],
|
||||
"print_scale_list": []
|
||||
}
|
||||
"priority": 10,
|
||||
"resize_scale": [
|
||||
1.0,
|
||||
1.0
|
||||
],
|
||||
"seg_mask_url": "aida-clothing/mask/mask_8f6191fe-e466-11f0-8de2-0242ac130002.png",
|
||||
"type": "Skirt"
|
||||
},
|
||||
"priority": 12,
|
||||
"resize_scale": [
|
||||
1.0,
|
||||
1.0
|
||||
],
|
||||
"transparent":{
|
||||
"mask_url":"test/transparent_test/transparent_mask.png",
|
||||
"scale":0.1
|
||||
},
|
||||
"type": "Outwear"
|
||||
},
|
||||
{
|
||||
"body_path": "aida-sys-image/models/female/5bdfe7ca-64eb-44e4-b03d-8e517520c795.png",
|
||||
"image_id": 96090,
|
||||
"type": "Body"
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"process_id": "83"
|
||||
}
|
||||
"""
|
||||
# logger.info(f"design request item is : @@@@@@:{json.dumps(request_data.dict())}")
|
||||
{
|
||||
"body_path": "aida-sys-image/models/female/2e4815b9-1191-419d-94ed-5771239ca4a5.png",
|
||||
"image_id": 67277,
|
||||
"type": "Body"
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"process_id": "89"
|
||||
}
|
||||
"""
|
||||
# logger.info(f"design request item is : @@@@@@:{json.dumps(request_data.dict(),indent=4)}")
|
||||
# data = generate(request_data=request_data)
|
||||
# logger.info(f"design response @@@@@@:{json.dumps(data)}")
|
||||
# logger.info(f"design response @@@@@@:{json.dumps(data, indent=4)}")
|
||||
#
|
||||
|
||||
try:
|
||||
logger.info(f"design request item is : @@@@@@:{json.dumps(request_data.dict())}")
|
||||
logger.info(f"design request item is : @@@@@@:{json.dumps(request_data.dict(), indent=4)}")
|
||||
data = design_generate(request_data=request_data)
|
||||
logger.info(f"design response @@@@@@:{json.dumps(data)}")
|
||||
logger.info(f"design response @@@@@@:{json.dumps(data, indent=4)}")
|
||||
except Exception as e:
|
||||
logger.warning(f"design Run Exception @@@@@@:{e}")
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
@@ -215,47 +213,48 @@ async def design_v2(request_data: DesignStreamModel, background_tasks: Backgroun
|
||||
"basic": {
|
||||
"body_point_test": {
|
||||
"waistband_right": [
|
||||
200,
|
||||
241
|
||||
203,
|
||||
249
|
||||
],
|
||||
"hand_point_right": [
|
||||
223,
|
||||
297
|
||||
229,
|
||||
343
|
||||
],
|
||||
"waistband_left": [
|
||||
112,
|
||||
241
|
||||
119,
|
||||
248
|
||||
],
|
||||
"hand_point_left": [
|
||||
92,
|
||||
305
|
||||
97,
|
||||
343
|
||||
],
|
||||
"shoulder_left": [
|
||||
99,
|
||||
116
|
||||
108,
|
||||
107
|
||||
],
|
||||
"relation_type": "System",
|
||||
"shoulder_right": [
|
||||
215,
|
||||
116
|
||||
]
|
||||
212,
|
||||
107
|
||||
],
|
||||
"relation_id": 1020356
|
||||
},
|
||||
"layer_order": true,
|
||||
"layer_order": false,
|
||||
"scale_bag": 0.7,
|
||||
"scale_earrings": 0.16,
|
||||
"self_template": true,
|
||||
"self_template": false,
|
||||
"single_overall": "overall",
|
||||
"switch_category": ""
|
||||
},
|
||||
"items": [
|
||||
{
|
||||
"businessId": 270372,
|
||||
"color": "30 28 28",
|
||||
"image_id": 69780,
|
||||
"color": "209 196 171",
|
||||
"image_id": 84093,
|
||||
"offset": [
|
||||
0,
|
||||
0
|
||||
1,
|
||||
1
|
||||
],
|
||||
"path": "aida-sys-image/images/female/trousers/0825000630.jpg",
|
||||
"path": "aida-users/89/sketchboard/female/Outwear/0943d209-7ce0-408c-bc61-83f15da94138.png",
|
||||
"print": {
|
||||
"element": {
|
||||
"element_angle_list": [],
|
||||
@@ -264,10 +263,23 @@ async def design_v2(request_data: DesignStreamModel, background_tasks: Backgroun
|
||||
"location": []
|
||||
},
|
||||
"overall": {
|
||||
"location": [],
|
||||
"print_angle_list": [],
|
||||
"location": [
|
||||
[
|
||||
0.0,
|
||||
0.0
|
||||
]
|
||||
],
|
||||
"print_angle_list": [
|
||||
0.0,
|
||||
0.0
|
||||
],
|
||||
"print_path_list": [],
|
||||
"print_scale_list": []
|
||||
"print_scale_list": [
|
||||
[
|
||||
0.0,
|
||||
0.0
|
||||
]
|
||||
]
|
||||
},
|
||||
"single": {
|
||||
"location": [],
|
||||
@@ -276,22 +288,20 @@ async def design_v2(request_data: DesignStreamModel, background_tasks: Backgroun
|
||||
"print_scale_list": []
|
||||
}
|
||||
},
|
||||
"priority": 10,
|
||||
"resize_scale": [
|
||||
1.0,
|
||||
1.0
|
||||
],
|
||||
"type": "Trousers"
|
||||
"type": "Outwear"
|
||||
},
|
||||
{
|
||||
"businessId": 270373,
|
||||
"color": "30 28 28",
|
||||
"image_id": 98243,
|
||||
"color": "63 71 73",
|
||||
"image_id": 100496,
|
||||
"offset": [
|
||||
0,
|
||||
0
|
||||
1,
|
||||
1
|
||||
],
|
||||
"path": "aida-sys-image/images/female/blouse/0902003811.jpg",
|
||||
"path": "aida-sys-image/images/female/blouse/0628001684.jpg",
|
||||
"print": {
|
||||
"element": {
|
||||
"element_angle_list": [],
|
||||
@@ -300,10 +310,23 @@ async def design_v2(request_data: DesignStreamModel, background_tasks: Backgroun
|
||||
"location": []
|
||||
},
|
||||
"overall": {
|
||||
"location": [],
|
||||
"print_angle_list": [],
|
||||
"location": [
|
||||
[
|
||||
0.0,
|
||||
0.0
|
||||
]
|
||||
],
|
||||
"print_angle_list": [
|
||||
0.0,
|
||||
0.0
|
||||
],
|
||||
"print_path_list": [],
|
||||
"print_scale_list": []
|
||||
"print_scale_list": [
|
||||
[
|
||||
0.0,
|
||||
0.0
|
||||
]
|
||||
]
|
||||
},
|
||||
"single": {
|
||||
"location": [],
|
||||
@@ -312,7 +335,6 @@ async def design_v2(request_data: DesignStreamModel, background_tasks: Backgroun
|
||||
"print_scale_list": []
|
||||
}
|
||||
},
|
||||
"priority": 11,
|
||||
"resize_scale": [
|
||||
1.0,
|
||||
1.0
|
||||
@@ -320,14 +342,14 @@ async def design_v2(request_data: DesignStreamModel, background_tasks: Backgroun
|
||||
"type": "Blouse"
|
||||
},
|
||||
{
|
||||
"businessId": 270374,
|
||||
"color": "172 68 68",
|
||||
"image_id": 98244,
|
||||
"color": "111 78 63",
|
||||
"gradient": "aida-gradient/f69b98e8-4248-4f7a-98a2-21bac41bf3e0.png",
|
||||
"image_id": 92193,
|
||||
"offset": [
|
||||
0,
|
||||
0
|
||||
1,
|
||||
1
|
||||
],
|
||||
"path": "aida-sys-image/images/female/outwear/0825000410.jpg",
|
||||
"path": "aida-sys-image/images/female/trousers/0825001160.jpg",
|
||||
"print": {
|
||||
"element": {
|
||||
"element_angle_list": [],
|
||||
@@ -336,10 +358,23 @@ async def design_v2(request_data: DesignStreamModel, background_tasks: Backgroun
|
||||
"location": []
|
||||
},
|
||||
"overall": {
|
||||
"location": [],
|
||||
"print_angle_list": [],
|
||||
"location": [
|
||||
[
|
||||
0.0,
|
||||
0.0
|
||||
]
|
||||
],
|
||||
"print_angle_list": [
|
||||
0.0,
|
||||
0.0
|
||||
],
|
||||
"print_path_list": [],
|
||||
"print_scale_list": []
|
||||
"print_scale_list": [
|
||||
[
|
||||
0.0,
|
||||
0.0
|
||||
]
|
||||
]
|
||||
},
|
||||
"single": {
|
||||
"location": [],
|
||||
@@ -348,31 +383,37 @@ async def design_v2(request_data: DesignStreamModel, background_tasks: Backgroun
|
||||
"print_scale_list": []
|
||||
}
|
||||
},
|
||||
"priority": 12,
|
||||
"resize_scale": [
|
||||
1.0,
|
||||
1.0
|
||||
],
|
||||
"transparent":{
|
||||
"mask_url":"test/transparent_test/transparent_mask.png",
|
||||
"scale":0.1
|
||||
},
|
||||
"type": "Outwear"
|
||||
"type": "Trousers"
|
||||
},
|
||||
{
|
||||
"body_path": "aida-sys-image/models/female/5bdfe7ca-64eb-44e4-b03d-8e517520c795.png",
|
||||
"image_id": 96090,
|
||||
"body_path": "aida-sys-image/models/female/2e4815b9-1191-419d-94ed-5771239ca4a5.png",
|
||||
"image_id": 67277,
|
||||
"offset": [
|
||||
1,
|
||||
1
|
||||
],
|
||||
"resize_scale": [
|
||||
1.0,
|
||||
1.0
|
||||
],
|
||||
"type": "Body"
|
||||
}
|
||||
]
|
||||
],
|
||||
"objectSign": "65830966"
|
||||
}
|
||||
],
|
||||
"process_id": "83"
|
||||
"process_id": "4802946666428422",
|
||||
"requestId": "1d1e7641-0d62-4da2-adc0-b4404910723c",
|
||||
"callback_url": "https://api.aida.com.hk/api/third/party/receiveDesignResults"
|
||||
}
|
||||
"""
|
||||
try:
|
||||
# 异步
|
||||
logger.info(f"generate_image request item is : @@@@@@:{json.dumps(request_data.dict())}")
|
||||
logger.info(f"generate_image request item is : @@@@@@:{json.dumps(request_data.dict(), indent=4)}")
|
||||
background_tasks.add_task(design_generate_v2, request_data)
|
||||
except Exception as e:
|
||||
logger.warning(f"design Run Exception @@@@@@:{e}")
|
||||
@@ -380,30 +421,30 @@ async def design_v2(request_data: DesignStreamModel, background_tasks: Backgroun
|
||||
return ResponseModel()
|
||||
|
||||
|
||||
@router.post('/get_progress')
|
||||
def get_progress(request_data: DesignProgressModel):
|
||||
"""
|
||||
获取design 进度
|
||||
创建一个具有以下参数的请求体:
|
||||
- **process_id**: 进度id
|
||||
|
||||
示例参数:
|
||||
{
|
||||
"process_id": "6878547032381675"
|
||||
}
|
||||
"""
|
||||
try:
|
||||
logger.info(f"get_progress request item is : @@@@@@:{json.dumps(request_data.dict())}")
|
||||
process_id = request_data.process_id
|
||||
r = Redis()
|
||||
data = r.read(key=process_id)
|
||||
if data is None:
|
||||
raise ValueError(f"No progress ID: {process_id}")
|
||||
logging.info(f"get_progress process_id @@@@@@ : {process_id} , progress : {json.dumps(data)}")
|
||||
except Exception as e:
|
||||
logger.warning(f"get_progress Run Exception @@@@@@:{e}")
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
return ResponseModel(data=data)
|
||||
# @router.post('/get_progress')
|
||||
# def get_progress(request_data: DesignProgressModel):
|
||||
# """
|
||||
# 获取design 进度
|
||||
# 创建一个具有以下参数的请求体:
|
||||
# - **process_id**: 进度id
|
||||
#
|
||||
# 示例参数:
|
||||
# {
|
||||
# "process_id": "6878547032381675"
|
||||
# }
|
||||
# """
|
||||
# try:
|
||||
# logger.info(f"get_progress request item is : @@@@@@:{json.dumps(request_data.dict(), indent=4)}")
|
||||
# process_id = request_data.process_id
|
||||
# r = Redis()
|
||||
# data = r.read(key=process_id)
|
||||
# if data is None:
|
||||
# raise ValueError(f"No progress ID: {process_id}")
|
||||
# logging.info(f"get_progress process_id @@@@@@ : {process_id} , progress : {json.dumps(data, indent=4)}")
|
||||
# except Exception as e:
|
||||
# logger.warning(f"get_progress Run Exception @@@@@@:{e}")
|
||||
# raise HTTPException(status_code=404, detail=str(e))
|
||||
# return ResponseModel(data=data)
|
||||
|
||||
|
||||
@router.post('/model_process')
|
||||
@@ -419,44 +460,42 @@ def model_process(request_data: ModelProgressModel):
|
||||
}
|
||||
"""
|
||||
try:
|
||||
logger.info(f"model_process request item is : @@@@@@:{json.dumps(request_data.dict())}")
|
||||
logger.info(f"model_process request item is : @@@@@@:{json.dumps(request_data.dict(), indent=4)}")
|
||||
|
||||
data = model_transpose(image_path=request_data.model_path)
|
||||
logger.info(f"model_process response @@@@@@:{json.dumps(data)}")
|
||||
logger.info(f"model_process response @@@@@@:{json.dumps(data, indent=4)}")
|
||||
except Exception as e:
|
||||
logger.warning(f"model_process Run Exception @@@@@@:{e}")
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
return ResponseModel(data=data)
|
||||
|
||||
|
||||
# ##############################################################
|
||||
|
||||
|
||||
@router.post("/design_batch_generate")
|
||||
async def design_batch(file: UploadFile = File(...),
|
||||
tasks_id: str = Form(...),
|
||||
user_id: str = Form(...),
|
||||
file_name: str = Form(...),
|
||||
total: int = Form(...)
|
||||
):
|
||||
dbg_config = DBGConfigModel(
|
||||
tasks_id=tasks_id,
|
||||
user_id=user_id,
|
||||
file_name=file_name,
|
||||
total=total
|
||||
)
|
||||
contents = await file.read()
|
||||
file_name = file.filename
|
||||
await save_request_file(contents, file_name)
|
||||
return await start_design_batch_generate(dbg_config, contents)
|
||||
|
||||
|
||||
async def save_request_file(contents, file_name):
|
||||
# 创建保存文件的目录(如果不存在)
|
||||
save_dir = os.path.join(os.getcwd(), "service/design_batch", "request_data")
|
||||
if not os.path.exists(save_dir):
|
||||
os.makedirs(save_dir)
|
||||
# 处理文件
|
||||
file_path = os.path.join(save_dir, file_name)
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(contents)
|
||||
"""design 批量处理 停用"""
|
||||
# @router.post("/design_batch_generate")
|
||||
# async def design_batch(file: UploadFile = File(...),
|
||||
# tasks_id: str = Form(...),
|
||||
# user_id: str = Form(...),
|
||||
# file_name: str = Form(...),
|
||||
# total: int = Form(...)
|
||||
# ):
|
||||
# dbg_config = DBGConfigModel(
|
||||
# tasks_id=tasks_id,
|
||||
# user_id=user_id,
|
||||
# file_name=file_name,
|
||||
# total=total
|
||||
# )
|
||||
# contents = await file.read()
|
||||
# file_name = file.filename
|
||||
# await save_request_file(contents, file_name)
|
||||
# return await start_design_batch_generate(dbg_config, contents)
|
||||
#
|
||||
#
|
||||
# async def save_request_file(contents, file_name):
|
||||
# # 创建保存文件的目录(如果不存在)
|
||||
# save_dir = os.path.join(os.getcwd(), "service/design_batch", "request_data")
|
||||
# if not os.path.exists(save_dir):
|
||||
# os.makedirs(save_dir)
|
||||
# # 处理文件
|
||||
# file_path = os.path.join(save_dir, file_name)
|
||||
# with open(file_path, "wb") as f:
|
||||
# f.write(contents)
|
||||
|
||||
@@ -30,10 +30,10 @@ def design_pre_processing(request_data: DesignPreProcessingModel):
|
||||
}
|
||||
"""
|
||||
try:
|
||||
logger.info(f"design_pre_processing request item is : @@@@@@:{json.dumps(request_data.dict())}")
|
||||
logger.info(f"design_pre_processing request item is : @@@@@@:{json.dumps(request_data.dict(), indent=4)}")
|
||||
server = DesignPreprocessing()
|
||||
data = server.pipeline(image_list=request_data.sketches)
|
||||
logger.info(f"design response @@@@@@:{json.dumps(data)}")
|
||||
logger.info(f"design response @@@@@@:{json.dumps(data, indent=4)}")
|
||||
except Exception as e:
|
||||
logger.warning(f"design Run Exception @@@@@@:{e}")
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
@@ -33,18 +33,30 @@ def generate_image(request_item: GenerateImageModel, background_tasks: Backgroun
|
||||
- **version**: 使用模型版本 fast 或者 high
|
||||
|
||||
示例参数:
|
||||
1. txt 2 img
|
||||
{
|
||||
"tasks_id": "123-89",
|
||||
"prompt": "skeleton sitting by the side of a river looking soulful, concert poster, 4k, artistic",
|
||||
"image_url": "aida-collection-element/87/Printboard/842c09cf-7297-42d9-9e6e-9c17d4a13cb5.jpg",
|
||||
"mode": "img2img",
|
||||
"category": "sketch",
|
||||
"gender": "male",
|
||||
"version": "fast"
|
||||
"tasks_id": "bd2cf809-24bc-49a6-91c9-193c6272a52e-2-89",
|
||||
"prompt": "a single item of sketch of dress, 4k, white background",
|
||||
"image_url": "",
|
||||
"mode": "txt2img",
|
||||
"category": "sketch",
|
||||
"gender": "Female",
|
||||
"version": "fast"
|
||||
}
|
||||
2. img 2 img
|
||||
{
|
||||
"tasks_id": "b861d4fa-5ae3-4a30-9c7a-7ba6bb9aa37b-1-89",
|
||||
"prompt": "a single item of sketch of dress, 4k, white background",
|
||||
"image_url": "aida-collection-element/89/Sketchboard/548da3a2-834f-49a7-b52c-e729c5ab5062.png",
|
||||
"mode": "img2img",
|
||||
"category": "sketch",
|
||||
"gender": "Female",
|
||||
"version": "fast"
|
||||
}
|
||||
|
||||
"""
|
||||
try:
|
||||
logger.info(f"generate_image request item is : @@@@@@:{json.dumps(request_item.dict())}")
|
||||
logger.info(f"generate_image request item is : @@@@@@:{json.dumps(request_item.dict(), indent=4)}")
|
||||
service = GenerateImage(request_item)
|
||||
background_tasks.add_task(service.get_result)
|
||||
except Exception as e:
|
||||
@@ -65,42 +77,41 @@ def generate_image(tasks_id: str):
|
||||
return ResponseModel(data=data['data'])
|
||||
|
||||
|
||||
'''multi view'''
|
||||
'''multi view 停用'''
|
||||
|
||||
# @router.post("/generate_multi_view")
|
||||
# def generate_multi_view(request_item: GenerateMultiViewModel, background_tasks: BackgroundTasks):
|
||||
# """
|
||||
# 创建一个具有以下参数的请求体:
|
||||
# - **tasks_id**: 任务id 用于取消生成任务和获取生成结果
|
||||
# - **image_url**: 前视角图的输入,minio或S3 url 地址
|
||||
#
|
||||
# 示例参数:
|
||||
# {
|
||||
# "tasks_id": "123-89",
|
||||
# "image_url": "aida-collection-element/87/Printboard/842c09cf-7297-42d9-9e6e-9c17d4a13cb5.jpg"
|
||||
# }
|
||||
# """
|
||||
# try:
|
||||
# logger.info(f"generate_multi_view request item is : @@@@@@:{json.dumps(request_item.dict(),indent=4)}")
|
||||
# service = GenerateMultiView(request_item)
|
||||
# background_tasks.add_task(service.get_result)
|
||||
# except Exception as e:
|
||||
# logger.warning(f"generate_multi_view Run Exception @@@@@@:{e}")
|
||||
# raise HTTPException(status_code=404, detail=str(e))
|
||||
# return ResponseModel()
|
||||
|
||||
|
||||
@router.post("/generate_multi_view")
|
||||
def generate_multi_view(request_item: GenerateMultiViewModel, background_tasks: BackgroundTasks):
|
||||
"""
|
||||
创建一个具有以下参数的请求体:
|
||||
- **tasks_id**: 任务id 用于取消生成任务和获取生成结果
|
||||
- **image_url**: 前视角图的输入,minio或S3 url 地址
|
||||
|
||||
示例参数:
|
||||
{
|
||||
"tasks_id": "123-89",
|
||||
"image_url": "aida-collection-element/87/Printboard/842c09cf-7297-42d9-9e6e-9c17d4a13cb5.jpg"
|
||||
}
|
||||
"""
|
||||
try:
|
||||
logger.info(f"generate_multi_view request item is : @@@@@@:{json.dumps(request_item.dict())}")
|
||||
service = GenerateMultiView(request_item)
|
||||
background_tasks.add_task(service.get_result)
|
||||
except Exception as e:
|
||||
logger.warning(f"generate_multi_view Run Exception @@@@@@:{e}")
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
return ResponseModel()
|
||||
|
||||
|
||||
@router.get("/generate_multi_view_cancel/{tasks_id}")
|
||||
def generate_multi_view(tasks_id: str):
|
||||
try:
|
||||
logger.info(f"generate_cancel request item is : @@@@@@:{tasks_id}")
|
||||
data = generate_multi_view_cancel(tasks_id)
|
||||
logger.info(f"generate_cancel response @@@@@@:{data}")
|
||||
except Exception as e:
|
||||
logger.warning(f"generate_cancel Run Exception @@@@@@:{e}")
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
return ResponseModel(data=data['data'])
|
||||
# @router.get("/generate_multi_view_cancel/{tasks_id}")
|
||||
# def generate_multi_view(tasks_id: str):
|
||||
# try:
|
||||
# logger.info(f"generate_cancel request item is : @@@@@@:{tasks_id}")
|
||||
# data = generate_multi_view_cancel(tasks_id)
|
||||
# logger.info(f"generate_cancel response @@@@@@:{data}")
|
||||
# except Exception as e:
|
||||
# logger.warning(f"generate_cancel Run Exception @@@@@@:{e}")
|
||||
# raise HTTPException(status_code=404, detail=str(e))
|
||||
# return ResponseModel(data=data['data'])
|
||||
|
||||
|
||||
'''single logo'''
|
||||
@@ -122,7 +133,7 @@ def generate_single_logo(request_item: GenerateSingleLogoImageModel, background_
|
||||
}
|
||||
"""
|
||||
try:
|
||||
logger.info(f"generate_single_logo request item is : @@@@@@:{json.dumps(request_item.dict())}")
|
||||
logger.info(f"generate_single_logo request item is : @@@@@@:{json.dumps(request_item.dict(), indent=4)}")
|
||||
service = GenerateSingleLogoImage(request_item)
|
||||
background_tasks.add_task(service.get_result)
|
||||
except Exception as e:
|
||||
@@ -167,7 +178,7 @@ def generate_product_image(request_item: GenerateProductImageModel, background_t
|
||||
}
|
||||
"""
|
||||
try:
|
||||
logger.info(f"generate_product_image request item is : @@@@@@:{json.dumps(request_item.dict())}")
|
||||
logger.info(f"generate_product_image request item is : @@@@@@:{json.dumps(request_item.dict(),indent=4)}")
|
||||
service = GenerateProductImage(request_item)
|
||||
background_tasks.add_task(service.get_result)
|
||||
except Exception as e:
|
||||
@@ -188,166 +199,164 @@ def generate_product_image(tasks_id: str):
|
||||
return ResponseModel(data=data['data'])
|
||||
|
||||
|
||||
'''relight image'''
|
||||
'''relight image 停用'''
|
||||
|
||||
# @router.post("/generate_relight_image")
|
||||
# def generate_relight_image(request_item: GenerateRelightImageModel, background_tasks: BackgroundTasks):
|
||||
# """
|
||||
# 创建一个具有以下参数的请求体:
|
||||
# - **tasks_id**: 任务id 用于取消生成任务和获取生成结果
|
||||
# - **prompt**: 想要生成图片的描述词
|
||||
# - **image_url**: 被生成图片的S3或minio url地址
|
||||
# - **direction**: 光源方向 Right Light Left Light Top Light Bottom Light
|
||||
# - **product_type**: 输入single item 还是 overall item
|
||||
#
|
||||
#
|
||||
# 示例参数:
|
||||
# {
|
||||
# "tasks_id": "123-89",
|
||||
# "prompt": "beautiful woman, detailed face, sunshine, outdoor, warm atmosphere",
|
||||
# "image_url": "aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png",
|
||||
# "direction": "Right Light",
|
||||
# "product_type": "overall"
|
||||
# }
|
||||
# """
|
||||
# try:
|
||||
# logger.info(f"generate_relight_image request item is : @@@@@@:{json.dumps(request_item.dict(),indent=4)}")
|
||||
# service = GenerateRelightImage(request_item)
|
||||
# background_tasks.add_task(service.get_result)
|
||||
# except Exception as e:
|
||||
# logger.warning(f"generate_relight_image Run Exception @@@@@@:{e}")
|
||||
# raise HTTPException(status_code=404, detail=str(e))
|
||||
# return ResponseModel()
|
||||
#
|
||||
#
|
||||
# @router.get("/generate_relight_image_cancel_cancel/{tasks_id}")
|
||||
# def generate_relight_image(tasks_id: str):
|
||||
# try:
|
||||
# logger.info(f"generate_relight_image_cancel_cancel request item is : @@@@@@:{tasks_id}")
|
||||
# data = generate_relight_image_cancel(tasks_id)
|
||||
# logger.info(f"generate_relight_image_cancel_cancel response @@@@@@:{data}")
|
||||
# except Exception as e:
|
||||
# logger.warning(f"generate_relight_image_cancel_cancel Run Exception @@@@@@:{e}")
|
||||
# raise HTTPException(status_code=404, detail=str(e))
|
||||
# return ResponseModel(data=data['data'])
|
||||
|
||||
|
||||
@router.post("/generate_relight_image")
|
||||
def generate_relight_image(request_item: GenerateRelightImageModel, background_tasks: BackgroundTasks):
|
||||
"""
|
||||
创建一个具有以下参数的请求体:
|
||||
- **tasks_id**: 任务id 用于取消生成任务和获取生成结果
|
||||
- **prompt**: 想要生成图片的描述词
|
||||
- **image_url**: 被生成图片的S3或minio url地址
|
||||
- **direction**: 光源方向 Right Light Left Light Top Light Bottom Light
|
||||
- **product_type**: 输入single item 还是 overall item
|
||||
"""batch generate img 停用"""
|
||||
|
||||
# @router.post("/batch_generate_product_image")
|
||||
# async def batch_generate_product(request_batch_item: BatchGenerateProductImageModel):
|
||||
# """
|
||||
# 创建一个具有以下参数的请求体:
|
||||
# - **tasks_id**: 任务id 用于获取生成结果
|
||||
# - **prompt**: 想要生成图片的描述词
|
||||
# - **image_url**: 被生成图片的S3或minio url地址
|
||||
# - **image_strength**: 生成强度,越低越接近原图
|
||||
# - **product_type**: 输入single item 还是 overall item
|
||||
# - **batch_size**: 批生成数量
|
||||
#
|
||||
#
|
||||
# 示例参数:
|
||||
# {
|
||||
# "tasks_id": "123-89",
|
||||
# "prompt": "the best quality, masterpiece. detailed, high-res, simple background, studio photography, extremely detailed, updo, detailed face, face, close-up, HDR, UHD, 8K realistic, Highly detailed, simple background, Studio lighting",
|
||||
# "image_url": "aida-results/result_00097282-ebb2-11ee-a822-b48351119060.png",
|
||||
# "image_strength": 0.8,
|
||||
# "product_type": "overall",
|
||||
# "batch_size": 1
|
||||
# }
|
||||
# """
|
||||
# return await start_product_batch_generate(request_batch_item)
|
||||
#
|
||||
#
|
||||
# @router.post("/batch_generate_relight_image")
|
||||
# async def batch_generate_relight(request_batch_item: BatchGenerateRelightImageModel):
|
||||
# """
|
||||
# 创建一个具有以下参数的请求体:
|
||||
# - **tasks_id**: 任务id 用于获取生成结果
|
||||
# - **prompt**: 想要生成图片的描述词
|
||||
# - **image_url**: 被生成图片的S3或minio url地址
|
||||
# - **direction**: 光源方向 Right Light Left Light Top Light Bottom Light
|
||||
# - **product_type**: 输入single item 还是 overall item
|
||||
# - **batch_size**: 批生成数量
|
||||
#
|
||||
#
|
||||
# 示例参数:
|
||||
# {
|
||||
# "tasks_id": "123-89",
|
||||
# "prompt": "beautiful woman, detailed face, sunshine, outdoor, warm atmosphere",
|
||||
# "image_url": "aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png",
|
||||
# "direction": "Right Light",
|
||||
# "product_type": "overall",
|
||||
# "batch_size": 1
|
||||
# }
|
||||
# """
|
||||
# return await start_relight_batch_generate(request_batch_item)
|
||||
|
||||
|
||||
示例参数:
|
||||
{
|
||||
"tasks_id": "123-89",
|
||||
"prompt": "beautiful woman, detailed face, sunshine, outdoor, warm atmosphere",
|
||||
"image_url": "aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png",
|
||||
"direction": "Right Light",
|
||||
"product_type": "overall"
|
||||
}
|
||||
"""
|
||||
try:
|
||||
logger.info(f"generate_relight_image request item is : @@@@@@:{json.dumps(request_item.dict())}")
|
||||
service = GenerateRelightImage(request_item)
|
||||
background_tasks.add_task(service.get_result)
|
||||
except Exception as e:
|
||||
logger.warning(f"generate_relight_image Run Exception @@@@@@:{e}")
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
return ResponseModel()
|
||||
|
||||
|
||||
@router.get("/generate_relight_image_cancel_cancel/{tasks_id}")
|
||||
def generate_relight_image(tasks_id: str):
|
||||
try:
|
||||
logger.info(f"generate_relight_image_cancel_cancel request item is : @@@@@@:{tasks_id}")
|
||||
data = generate_relight_image_cancel(tasks_id)
|
||||
logger.info(f"generate_relight_image_cancel_cancel response @@@@@@:{data}")
|
||||
except Exception as e:
|
||||
logger.warning(f"generate_relight_image_cancel_cancel Run Exception @@@@@@:{e}")
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
return ResponseModel(data=data['data'])
|
||||
|
||||
|
||||
"""batch generate img"""
|
||||
|
||||
|
||||
@router.post("/batch_generate_product_image")
|
||||
async def batch_generate_product(request_batch_item: BatchGenerateProductImageModel):
|
||||
"""
|
||||
创建一个具有以下参数的请求体:
|
||||
- **tasks_id**: 任务id 用于获取生成结果
|
||||
- **prompt**: 想要生成图片的描述词
|
||||
- **image_url**: 被生成图片的S3或minio url地址
|
||||
- **image_strength**: 生成强度,越低越接近原图
|
||||
- **product_type**: 输入single item 还是 overall item
|
||||
- **batch_size**: 批生成数量
|
||||
|
||||
|
||||
示例参数:
|
||||
{
|
||||
"tasks_id": "123-89",
|
||||
"prompt": "the best quality, masterpiece. detailed, high-res, simple background, studio photography, extremely detailed, updo, detailed face, face, close-up, HDR, UHD, 8K realistic, Highly detailed, simple background, Studio lighting",
|
||||
"image_url": "aida-results/result_00097282-ebb2-11ee-a822-b48351119060.png",
|
||||
"image_strength": 0.8,
|
||||
"product_type": "overall",
|
||||
"batch_size": 1
|
||||
}
|
||||
"""
|
||||
return await start_product_batch_generate(request_batch_item)
|
||||
|
||||
|
||||
@router.post("/batch_generate_relight_image")
|
||||
async def batch_generate_relight(request_batch_item: BatchGenerateRelightImageModel):
|
||||
"""
|
||||
创建一个具有以下参数的请求体:
|
||||
- **tasks_id**: 任务id 用于获取生成结果
|
||||
- **prompt**: 想要生成图片的描述词
|
||||
- **image_url**: 被生成图片的S3或minio url地址
|
||||
- **direction**: 光源方向 Right Light Left Light Top Light Bottom Light
|
||||
- **product_type**: 输入single item 还是 overall item
|
||||
- **batch_size**: 批生成数量
|
||||
|
||||
|
||||
示例参数:
|
||||
{
|
||||
"tasks_id": "123-89",
|
||||
"prompt": "beautiful woman, detailed face, sunshine, outdoor, warm atmosphere",
|
||||
"image_url": "aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png",
|
||||
"direction": "Right Light",
|
||||
"product_type": "overall",
|
||||
"batch_size": 1
|
||||
}
|
||||
"""
|
||||
return await start_relight_batch_generate(request_batch_item)
|
||||
|
||||
|
||||
@router.post("/batch_generate_pose_transform_image")
|
||||
async def batch_generate_pose_transform(request_batch_item: BatchPoseTransformModel):
|
||||
"""
|
||||
创建一个具有以下参数的请求体:
|
||||
- **tasks_id**: 任务id 用于取消生成任务和获取生成结果
|
||||
- **image_url**: 被生成图片的S3或minio url地址
|
||||
- **pose_id**: 1
|
||||
- **batch_size**: 批生成数量
|
||||
|
||||
|
||||
示例参数:
|
||||
{
|
||||
"tasks_id": "123-89",
|
||||
"image_url": "aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png",
|
||||
"pose_id": "1",
|
||||
"batch_size": 1
|
||||
}
|
||||
"""
|
||||
return await start_pose_transform_batch_generate(request_batch_item)
|
||||
|
||||
|
||||
"""agent tool"""
|
||||
|
||||
|
||||
@router.post("/agent_tool_generate_image")
|
||||
def agent_tool_generate_image(request_item: AgentTollGenerateImageModel, background_tasks: BackgroundTasks):
|
||||
"""
|
||||
创建一个具有以下参数的请求体:
|
||||
- **prompt**: 想要生成图片的描述词
|
||||
- **category**: 生成图片的类别,sketch print 等等
|
||||
- **gender**: 生成sketch专用,服装类别
|
||||
- **version**: 使用模型版本 fast 或者 high
|
||||
- **size**: 生成数量
|
||||
- **version**: 使用模型版本 fast 或者 high
|
||||
|
||||
|
||||
示例参数:
|
||||
{
|
||||
"prompt": "a single item of sketch of Wabi-sabi, skirt, tiered, 4k, white background",
|
||||
"category": "sketch",
|
||||
"gender": "male",
|
||||
"size":2,
|
||||
"version":"high"
|
||||
}
|
||||
"""
|
||||
try:
|
||||
logger.info(f"agent_tool_generate_image request item is : @@@@@@:{request_item.dict()}")
|
||||
request_data = request_item.dict()
|
||||
service = AgentToolGenerateImage(request_data['version'])
|
||||
image_url_list, clothing_category_list = service.get_result(
|
||||
prompt=request_data['prompt'],
|
||||
size=request_data['size'],
|
||||
version=request_data['version'],
|
||||
category=request_data['category'],
|
||||
gender=request_data['gender']
|
||||
)
|
||||
data = {
|
||||
"image_url_list": image_url_list,
|
||||
"clothing_category_list": clothing_category_list
|
||||
}
|
||||
logger.info(f"agent_tool_generate_image response item is : @@@@@@:{data}")
|
||||
except Exception as e:
|
||||
logger.warning(f"agent_tool_generate_image Run Exception @@@@@@:{e}")
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
return ResponseModel(data=data)
|
||||
# @router.post("/batch_generate_pose_transform_image")
|
||||
# async def batch_generate_pose_transform(request_batch_item: BatchPoseTransformModel):
|
||||
# """
|
||||
# 创建一个具有以下参数的请求体:
|
||||
# - **tasks_id**: 任务id 用于取消生成任务和获取生成结果
|
||||
# - **image_url**: 被生成图片的S3或minio url地址
|
||||
# - **pose_id**: 1
|
||||
# - **batch_size**: 批生成数量
|
||||
#
|
||||
#
|
||||
# 示例参数:
|
||||
# {
|
||||
# "tasks_id": "123-89",
|
||||
# "image_url": "aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png",
|
||||
# "pose_id": "1",
|
||||
# "batch_size": 1
|
||||
# }
|
||||
# """
|
||||
# return await start_pose_transform_batch_generate(request_batch_item)
|
||||
#
|
||||
#
|
||||
# """agent tool"""
|
||||
#
|
||||
#
|
||||
# @router.post("/agent_tool_generate_image")
|
||||
# def agent_tool_generate_image(request_item: AgentTollGenerateImageModel):
|
||||
# """
|
||||
# 创建一个具有以下参数的请求体:
|
||||
# - **prompt**: 想要生成图片的描述词
|
||||
# - **category**: 生成图片的类别,sketch print 等等
|
||||
# - **gender**: 生成sketch专用,服装类别
|
||||
# - **version**: 使用模型版本 fast 或者 high
|
||||
# - **size**: 生成数量
|
||||
# - **version**: 使用模型版本 fast 或者 high
|
||||
#
|
||||
#
|
||||
# 示例参数:
|
||||
# {
|
||||
# "prompt": "a single item of sketch of Wabi-sabi, skirt, tiered, 4k, white background",
|
||||
# "category": "sketch",
|
||||
# "gender": "male",
|
||||
# "size":2,
|
||||
# "version":"high"
|
||||
# }
|
||||
# """
|
||||
# try:
|
||||
# logger.info(f"agent_tool_generate_image request item is : @@@@@@:{request_item.dict()}")
|
||||
# request_data = request_item.dict()
|
||||
# service = AgentToolGenerateImage(request_data['version'])
|
||||
# image_url_list, clothing_category_list = service.get_result(
|
||||
# prompt=request_data['prompt'],
|
||||
# size=request_data['size'],
|
||||
# version=request_data['version'],
|
||||
# category=request_data['category'],
|
||||
# gender=request_data['gender']
|
||||
# )
|
||||
# data = {
|
||||
# "image_url_list": image_url_list,
|
||||
# "clothing_category_list": clothing_category_list
|
||||
# }
|
||||
# logger.info(f"agent_tool_generate_image response item is : @@@@@@:{data}")
|
||||
# except Exception as e:
|
||||
# logger.warning(f"agent_tool_generate_image Run Exception @@@@@@:{e}")
|
||||
# raise HTTPException(status_code=404, detail=str(e))
|
||||
# return ResponseModel(data=data)
|
||||
|
||||
@@ -14,22 +14,22 @@ logger = logging.getLogger()
|
||||
@router.post("/image2sketch")
|
||||
def image2sketch(request_item: Image2SketchModel):
|
||||
"""
|
||||
创建一个具有以下参数的请求体:
|
||||
- **image_url**: 提取图片url
|
||||
- **default_style**: 原始、 1、2、3、4、5
|
||||
- **sketch_bucket**: sketch保存的bucket
|
||||
- **sketch_name**: sketch保存的object name
|
||||
创建一个具有以下参数的请求体:
|
||||
- **image_url**: 提取图片url
|
||||
- **default_style**: 原始、 1、2、3、4、5
|
||||
- **sketch_bucket**: sketch保存的bucket
|
||||
- **sketch_name**: sketch保存的object name
|
||||
|
||||
示例参数:
|
||||
{
|
||||
"image_url": "test/image2sketch/real_Dress_3200fecdc83d0c556c2bd96aedbd7fbf.jpg_Img.jpg",
|
||||
"default_style": 0,
|
||||
"sketch_bucket": "test",
|
||||
"sketch_name": "image2sketch/area_fill_img.png"
|
||||
}
|
||||
"""
|
||||
示例参数:
|
||||
{
|
||||
"image_url": "test/image2sketch/real_Dress_3200fecdc83d0c556c2bd96aedbd7fbf.jpg_Img.jpg",
|
||||
"default_style": 0,
|
||||
"sketch_bucket": "test",
|
||||
"sketch_name": "image2sketch/area_fill_img.png"
|
||||
}
|
||||
"""
|
||||
try:
|
||||
logger.info(f"image2sketch request item is : @@@@@@:{json.dumps(request_item.dict())}")
|
||||
logger.info(f"image2sketch request item is : @@@@@@:{json.dumps(request_item.dict(),indent=4)}")
|
||||
service = LineArtService(request_item)
|
||||
result_url = service.get_result()
|
||||
except Exception as e:
|
||||
|
||||
@@ -1,116 +0,0 @@
|
||||
import logging
|
||||
import sys
|
||||
from typing import Optional
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import threading
|
||||
|
||||
from app.schemas.response_template import ResponseModel
|
||||
from app.service.recommendation_system.import_sys_sketch_to_milvus import main as import_main
|
||||
|
||||
logger = logging.getLogger()
|
||||
router = APIRouter()
|
||||
|
||||
# 使用线程池执行器来运行长时间任务
|
||||
executor = ThreadPoolExecutor(max_workers=1)
|
||||
# 用于跟踪任务状态
|
||||
task_status = {"running": False}
|
||||
|
||||
|
||||
def run_import_task(batch_size: int, retry_times: int, limit: Optional[int], offset: int, skip_create_collection: bool):
|
||||
"""在后台线程中运行导入任务"""
|
||||
original_argv = None
|
||||
try:
|
||||
task_status["running"] = True
|
||||
# 保存原始 sys.argv
|
||||
original_argv = sys.argv.copy()
|
||||
|
||||
# 模拟命令行参数
|
||||
sys.argv = [
|
||||
"import_sys_sketch_to_milvus.py",
|
||||
"--batch-size", str(batch_size),
|
||||
"--retry-times", str(retry_times),
|
||||
]
|
||||
if limit is not None:
|
||||
sys.argv.extend(["--limit", str(limit)])
|
||||
if offset > 0:
|
||||
sys.argv.extend(["--offset", str(offset)])
|
||||
if skip_create_collection:
|
||||
sys.argv.append("--skip-create-collection")
|
||||
|
||||
import_main()
|
||||
task_status["running"] = False
|
||||
logger.info("导入任务完成")
|
||||
except Exception as e:
|
||||
task_status["running"] = False
|
||||
logger.error(f"导入任务失败: {e}", exc_info=True)
|
||||
raise
|
||||
finally:
|
||||
# 恢复原始 sys.argv
|
||||
if original_argv is not None:
|
||||
sys.argv = original_argv
|
||||
|
||||
|
||||
@router.post("/import-sys-sketch", response_model=ResponseModel)
|
||||
async def import_sys_sketch(
|
||||
batch_size: int = Query(1000, description="批量处理大小(默认:1000)"),
|
||||
retry_times: int = Query(3, description="失败重试次数(默认:3)"),
|
||||
limit: Optional[int] = Query(None, description="限制处理数量(用于测试,默认:不限制)"),
|
||||
offset: int = Query(0, description="起始偏移量(默认:0)"),
|
||||
skip_create_collection: bool = Query(False, description="跳过创建集合(如果集合已存在)"),
|
||||
):
|
||||
"""
|
||||
从 t_sys_file 导入系统图向量到 Milvus
|
||||
|
||||
该接口会异步执行导入任务,任务在后台运行。
|
||||
"""
|
||||
try:
|
||||
# 检查是否有任务正在运行
|
||||
if task_status["running"]:
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail="已有导入任务正在运行,请等待完成后再试"
|
||||
)
|
||||
|
||||
# 在后台线程中执行任务
|
||||
executor.submit(
|
||||
run_import_task,
|
||||
batch_size,
|
||||
retry_times,
|
||||
limit,
|
||||
offset,
|
||||
skip_create_collection
|
||||
)
|
||||
|
||||
return ResponseModel(
|
||||
code=200,
|
||||
msg="导入任务已启动,正在后台执行",
|
||||
data={
|
||||
"status": "started",
|
||||
"batch_size": batch_size,
|
||||
"retry_times": retry_times,
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
"skip_create_collection": skip_create_collection
|
||||
}
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"启动导入任务失败: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"启动导入任务失败: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/import-sys-sketch/status", response_model=ResponseModel)
|
||||
async def get_import_status():
|
||||
"""
|
||||
获取导入任务状态
|
||||
"""
|
||||
return ResponseModel(
|
||||
code=200,
|
||||
msg="OK",
|
||||
data={
|
||||
"running": task_status["running"]
|
||||
}
|
||||
)
|
||||
|
||||
@@ -35,10 +35,10 @@ def mannequins_edit(request_data: MannequinModel):
|
||||
}**
|
||||
"""
|
||||
try:
|
||||
logger.info(f"mannequins_edit request item is : @@@@@@:{json.dumps(request_data.dict())}")
|
||||
logger.info(f"mannequins_edit request item is : @@@@@@:{json.dumps(request_data.dict(),indent=4)}")
|
||||
service = MannequinEditService(request_data)
|
||||
data = service()
|
||||
logger.info(f"mannequins_edit response @@@@@@:{json.dumps(data)}")
|
||||
logger.info(f"mannequins_edit response @@@@@@:{json.dumps(data, indent=4)}")
|
||||
except Exception as e:
|
||||
logger.warning(f"mannequins_edit Run Exception @@@@@@:{e}")
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
@@ -4,55 +4,55 @@ import logging
|
||||
import requests
|
||||
from fastapi import APIRouter, BackgroundTasks, HTTPException
|
||||
|
||||
from app.core.config import COMFYUI_SERVER_ADDRESS
|
||||
from app.core.config import settings
|
||||
from app.schemas.comfyui_i2v import ComfyuiI2VModel, ComfyuiFLF2VModel
|
||||
from app.schemas.pose_transform import PoseTransformModel
|
||||
from app.schemas.response_template import ResponseModel
|
||||
from app.service.comfyui_I2V.flf2v_server import ComfyUIServerFLF2V
|
||||
from app.service.comfyui_I2V.i2v_server import ComfyUIServerI2V
|
||||
from app.service.comfyui_I2V.pose2v_server import ComfyUIServerPose2V
|
||||
from app.service.generate_image.service_pose_transform import PoseTransformService, infer_cancel as pose_transform_infer_cancel
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger()
|
||||
|
||||
"""停用"""
|
||||
|
||||
@router.post("/pose_transform")
|
||||
def pose_transform(request_item: PoseTransformModel, background_tasks: BackgroundTasks):
|
||||
"""
|
||||
创建一个具有以下参数的请求体:
|
||||
- **tasks_id**: 任务id 用于取消生成任务和获取生成结果
|
||||
- **image_url**: 被生成图片的S3或minio url地址
|
||||
- **pose_id**: 1
|
||||
# @router.post("/pose_transform")
|
||||
# def pose_transform(request_item: PoseTransformModel, background_tasks: BackgroundTasks):
|
||||
# """
|
||||
# 创建一个具有以下参数的请求体:
|
||||
# - **tasks_id**: 任务id 用于取消生成任务和获取生成结果
|
||||
# - **image_url**: 被生成图片的S3或minio url地址
|
||||
# - **pose_id**: 1
|
||||
#
|
||||
#
|
||||
# 示例参数:
|
||||
# {
|
||||
# "tasks_id": "123-89",
|
||||
# "image_url": "aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png",
|
||||
# "pose_id": "1"
|
||||
# }
|
||||
# """
|
||||
# try:
|
||||
# logger.info(f"pose_transform request item is : @@@@@@:{json.dumps(request_item.dict(),indent=4)}")
|
||||
# service = PoseTransformService(request_item)
|
||||
# background_tasks.add_task(service.get_result)
|
||||
# except Exception as e:
|
||||
# logger.warning(f"pose_transform Run Exception @@@@@@:{e}")
|
||||
# raise HTTPException(status_code=404, detail=str(e))
|
||||
# return ResponseModel()
|
||||
|
||||
|
||||
示例参数:
|
||||
{
|
||||
"tasks_id": "123-89",
|
||||
"image_url": "aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png",
|
||||
"pose_id": "1"
|
||||
}
|
||||
"""
|
||||
try:
|
||||
logger.info(f"pose_transform request item is : @@@@@@:{json.dumps(request_item.dict())}")
|
||||
service = PoseTransformService(request_item)
|
||||
background_tasks.add_task(service.get_result)
|
||||
except Exception as e:
|
||||
logger.warning(f"pose_transform Run Exception @@@@@@:{e}")
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
return ResponseModel()
|
||||
|
||||
|
||||
@router.get("/pose_transform_cancel/{tasks_id}")
|
||||
def pose_transform_cancel(tasks_id: str):
|
||||
try:
|
||||
logger.info(f"pose_transform_cancel request item is : @@@@@@:{tasks_id}")
|
||||
data = pose_transform_infer_cancel(tasks_id)
|
||||
logger.info(f"pose_transform_cancel response @@@@@@:{data}")
|
||||
except Exception as e:
|
||||
logger.warning(f"pose_transform_cancel Run Exception @@@@@@:{e}")
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
return ResponseModel(data=data['data'])
|
||||
# @router.get("/pose_transform_cancel/{tasks_id}")
|
||||
# def pose_transform_cancel(tasks_id: str):
|
||||
# try:
|
||||
# logger.info(f"pose_transform_cancel request item is : @@@@@@:{tasks_id}")
|
||||
# data = pose_transform_infer_cancel(tasks_id)
|
||||
# logger.info(f"pose_transform_cancel response @@@@@@:{data}")
|
||||
# except Exception as e:
|
||||
# logger.warning(f"pose_transform_cancel Run Exception @@@@@@:{e}")
|
||||
# raise HTTPException(status_code=404, detail=str(e))
|
||||
# return ResponseModel(data=data['data'])
|
||||
|
||||
|
||||
"""
|
||||
@@ -77,7 +77,7 @@ def comfyui_image_pose_2_video(request_item: PoseTransformModel, background_task
|
||||
}
|
||||
"""
|
||||
try:
|
||||
logger.info(f"image_pose_2_video request item is : @@@@@@:{json.dumps(request_item.dict())}")
|
||||
logger.info(f"image_pose_2_video request item is : @@@@@@:{json.dumps(request_item.dict(),indent=4)}")
|
||||
service = ComfyUIServerPose2V(request_item)
|
||||
background_tasks.add_task(service.get_result)
|
||||
except Exception as e:
|
||||
@@ -107,7 +107,7 @@ def comfyui_image_2_video(request_item: ComfyuiI2VModel, background_tasks: Backg
|
||||
}
|
||||
"""
|
||||
try:
|
||||
logger.info(f"image_2_video request item is : @@@@@@:{json.dumps(request_item.dict())}")
|
||||
logger.info(f"image_2_video request item is : @@@@@@:{json.dumps(request_item.dict(),indent=4)}")
|
||||
service = ComfyUIServerI2V(request_item)
|
||||
background_tasks.add_task(service.get_result)
|
||||
except Exception as e:
|
||||
@@ -139,7 +139,7 @@ def comfyui_flf_2_video(request_item: ComfyuiFLF2VModel, background_tasks: Backg
|
||||
}
|
||||
"""
|
||||
try:
|
||||
logger.info(f"flf_2_video request item is : @@@@@@:{json.dumps(request_item.dict())}")
|
||||
logger.info(f"flf_2_video request item is : @@@@@@:{json.dumps(request_item.dict(),indent=4)}")
|
||||
service = ComfyUIServerFLF2V(request_item)
|
||||
background_tasks.add_task(service.get_result)
|
||||
except Exception as e:
|
||||
@@ -153,7 +153,7 @@ def comfyui_i_2_video_cancel(tasks_id: str):
|
||||
try:
|
||||
logger.info(f"comfyui_i_2_video_cancel request item is : @@@@@@:{tasks_id}")
|
||||
response = requests.post(
|
||||
f"http://{COMFYUI_SERVER_ADDRESS}/interrupt",
|
||||
f"http://{settings.COMFYUI_SERVER_ADDRESS}/interrupt",
|
||||
json={"prompt_id": tasks_id}
|
||||
)
|
||||
data = {}
|
||||
|
||||
@@ -1,85 +0,0 @@
|
||||
import logging
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
from app.schemas.response_template import ResponseModel
|
||||
from app.service.recommendation_system.precompute import run_precompute
|
||||
|
||||
logger = logging.getLogger()
|
||||
router = APIRouter()
|
||||
|
||||
# 使用线程池执行器来运行长时间任务
|
||||
executor = ThreadPoolExecutor(max_workers=1)
|
||||
# 用于跟踪任务状态
|
||||
task_status = {"running": False}
|
||||
|
||||
|
||||
def run_precompute_task():
|
||||
"""在后台线程中运行预计算任务"""
|
||||
try:
|
||||
task_status["running"] = True
|
||||
logger.info("开始执行预计算任务...")
|
||||
run_precompute()
|
||||
task_status["running"] = False
|
||||
logger.info("预计算任务完成")
|
||||
except Exception as e:
|
||||
task_status["running"] = False
|
||||
logger.error(f"预计算任务失败: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
@router.post("/precompute", response_model=ResponseModel)
|
||||
async def precompute():
|
||||
"""
|
||||
运行预计算任务
|
||||
|
||||
该接口会异步执行预计算任务,包括:
|
||||
1. 优化数据库表结构
|
||||
2. 历史数据迁移
|
||||
3. 初始用户偏好向量生成
|
||||
|
||||
任务在后台运行。
|
||||
"""
|
||||
try:
|
||||
# 检查是否有任务正在运行
|
||||
if task_status["running"]:
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail="已有预计算任务正在运行,请等待完成后再试"
|
||||
)
|
||||
|
||||
# 在后台线程中执行任务
|
||||
executor.submit(run_precompute_task)
|
||||
|
||||
return ResponseModel(
|
||||
code=200,
|
||||
msg="预计算任务已启动,正在后台执行",
|
||||
data={
|
||||
"status": "started",
|
||||
"tasks": [
|
||||
"优化数据库表结构",
|
||||
"历史数据迁移",
|
||||
"初始用户偏好向量生成"
|
||||
]
|
||||
}
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"启动预计算任务失败: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"启动预计算任务失败: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/precompute/status", response_model=ResponseModel)
|
||||
async def get_precompute_status():
|
||||
"""
|
||||
获取预计算任务状态
|
||||
"""
|
||||
return ResponseModel(
|
||||
code=200,
|
||||
msg="OK",
|
||||
data={
|
||||
"running": task_status["running"]
|
||||
}
|
||||
)
|
||||
|
||||
@@ -1,13 +1,10 @@
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
|
||||
from app.schemas.prompt_generation import PromptGenerationImageModel, ImageRequest
|
||||
from app.schemas.response_template import ResponseModel
|
||||
from app.service.prompt_generation.chatgpt_for_translation import get_translation_from_llama3, \
|
||||
get_prompt_from_image
|
||||
from app.service.prompt_generation.chatgpt_for_translation import get_translation_from_llama3, get_prompt_from_image
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger()
|
||||
@@ -34,19 +31,19 @@ def prompt_generation(request_data: PromptGenerationImageModel):
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
return ResponseModel(data=data)
|
||||
|
||||
|
||||
@router.post("/img2prompt")
|
||||
def get_prompt_from_img(img: ImageRequest):
|
||||
"""
|
||||
自动识别图片并输出为prompt
|
||||
|
||||
:param img: 图片的minio地址
|
||||
:return: 图片的文字描述
|
||||
"""
|
||||
text = ("Please describe the clothing in the image and provide a line art description of the outfit. "
|
||||
"The description should allow for the reconstruction of the corresponding line art based on the details "
|
||||
"given.")
|
||||
logger.info(f"get_prompt_from_img request item is : @@@@@@:{img}")
|
||||
description = get_prompt_from_image(img, text)
|
||||
logger.info(f"生成的图片描述 response @@@@@@:{description}")
|
||||
return description
|
||||
# 停用
|
||||
# @router.post("/img2prompt")
|
||||
# def get_prompt_from_img(img: ImageRequest):
|
||||
# """
|
||||
# 自动识别图片并输出为prompt
|
||||
#
|
||||
# :param img: 图片的minio地址
|
||||
# :return: 图片的文字描述
|
||||
# """
|
||||
# text = ("Please describe the clothing in the image and provide a line art description of the outfit. "
|
||||
# "The description should allow for the reconstruction of the corresponding line art based on the details "
|
||||
# "given.")
|
||||
# logger.info(f"get_prompt_from_img request item is : @@@@@@:{img}")
|
||||
# description = get_prompt_from_image(img, text)
|
||||
# logger.info(f"生成的图片描述 response @@@@@@:{description}")
|
||||
# return description
|
||||
|
||||
@@ -26,9 +26,9 @@ def query_image(request_data: QueryImageModel):
|
||||
}
|
||||
"""
|
||||
try:
|
||||
logger.info(f"query_image request item is : @@@@@@:{json.dumps(request_data.dict())}")
|
||||
logger.info(f"query_image request item is : @@@@@@:{json.dumps(request_data.dict(), indent=4)}")
|
||||
data = query(request_data.gender, request_data.content)
|
||||
logger.info(f"query_image response @@@@@@:{json.dumps(data)}")
|
||||
logger.info(f"query_image response @@@@@@:{json.dumps(data, indent=4)}")
|
||||
except Exception as e:
|
||||
logger.warning(f"query_image Run Exception @@@@@@:{e}")
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
@@ -1,175 +1,206 @@
|
||||
import io
|
||||
import logging
|
||||
import math
|
||||
import sys
|
||||
from typing import List, Optional
|
||||
from fastapi import HTTPException, APIRouter, Query
|
||||
from apscheduler.schedulers.background import BackgroundScheduler
|
||||
import time
|
||||
from typing import List
|
||||
|
||||
from app.service.recommendation_system.recommendation_api import get_recommendations as get_new_recommendations
|
||||
from app.service.recommendation_system.incremental_listener import start_background_listener
|
||||
from app.service.recommendation_system.milvus_client import create_collection
|
||||
import numpy as np
|
||||
from apscheduler.schedulers.background import BackgroundScheduler
|
||||
from apscheduler.triggers.cron import CronTrigger
|
||||
from fastapi import HTTPException, APIRouter
|
||||
|
||||
from app.service.recommend.service import load_resources, matrix_data
|
||||
|
||||
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')
|
||||
logger = logging.getLogger()
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# ========== 旧版推荐接口(基于 npy 矩阵,已废弃)==========
|
||||
# @router.get("/recommend/{user_id}/{category}/{num_recommendations}/{brand_id}/{brand_scale}", response_model=List[str])
|
||||
# async def get_recommendations(user_id: int, category: str, brand_id: int, brand_scale: float, num_recommendations: int = 10):
|
||||
# """
|
||||
# :param user_id: 4
|
||||
# :param category: female_skirt
|
||||
# :param num_recommendations: 1
|
||||
# :return:
|
||||
# [
|
||||
# "aida-sys-image/images/female/skirt/903000017.jpg"
|
||||
# ]
|
||||
# """
|
||||
# try:
|
||||
# start_time = time.time()
|
||||
# cache_key = (user_id, category)
|
||||
# # === 新增:用户存在性检查 ===
|
||||
# user_exists_inter = user_id in matrix_data["user_index_interaction"]
|
||||
# user_exists_feat = user_id in matrix_data["user_index_feature"]
|
||||
#
|
||||
# # 任一矩阵不存在用户则返回随机推荐
|
||||
# if not (user_exists_inter and user_exists_feat):
|
||||
# logger.info(f"用户 {user_id} 数据不完整,触发随机推荐")
|
||||
# return get_random_recommendations(category, num_recommendations)
|
||||
#
|
||||
# # 检查缓存
|
||||
# if cache_key in matrix_data["cached_scores"]:
|
||||
# processed_inter, processed_feat = matrix_data["cached_scores"][cache_key]
|
||||
# valid_sketch_idxs_inter = matrix_data["cached_valid_idxs"][cache_key]
|
||||
# else:
|
||||
# # 实时计算逻辑(同原代码)
|
||||
# user_idx_inter = matrix_data["user_index_interaction"].get(user_id)
|
||||
# user_idx_feature = matrix_data["user_index_feature"].get(user_id)
|
||||
#
|
||||
# category_iids = matrix_data["category_to_iids"].get(category, [])
|
||||
# valid_sketch_idxs_inter = [
|
||||
# idx for iid, idx in matrix_data["sketch_index_interaction"].items()
|
||||
# if iid in category_iids
|
||||
# ]
|
||||
#
|
||||
# # 处理交互分数
|
||||
# raw_inter_scores = []
|
||||
# if user_idx_inter is not None and valid_sketch_idxs_inter:
|
||||
# raw_inter_scores = matrix_data["interaction_matrix"][user_idx_inter, valid_sketch_idxs_inter]
|
||||
# processed_inter = raw_inter_scores * 0.7
|
||||
#
|
||||
# # 处理特征分数
|
||||
# valid_sketch_idxs_feature = [
|
||||
# idx for iid, idx in matrix_data["sketch_index_feature"].items()
|
||||
# if iid in category_iids
|
||||
# ]
|
||||
# raw_feat_scores = []
|
||||
# if user_idx_feature is not None and valid_sketch_idxs_feature:
|
||||
# raw_feat_scores = matrix_data["feature_matrix"][user_idx_feature, valid_sketch_idxs_feature]
|
||||
# raw_feat_scores = (raw_feat_scores - np.min(raw_feat_scores)) / (
|
||||
# np.max(raw_feat_scores) - np.min(raw_feat_scores) + 1e-8)
|
||||
# processed_feat = raw_feat_scores
|
||||
# else:
|
||||
# processed_feat = np.array([])
|
||||
#
|
||||
# # 更新缓存
|
||||
# matrix_data["cached_scores"][cache_key] = (processed_inter, processed_feat)
|
||||
# matrix_data["cached_valid_idxs"][cache_key] = valid_sketch_idxs_inter
|
||||
#
|
||||
# # 合并分数
|
||||
# if brand_id is not None:
|
||||
# brand_idx_feature = matrix_data["brand_index_map"].get(brand_id)
|
||||
#
|
||||
# brand_feat_valid = (
|
||||
# matrix_data["brand_feature_matrix"].size > 0 and # 矩阵非空
|
||||
# brand_idx_feature is not None and
|
||||
# valid_sketch_idxs_feature # 有可用索引
|
||||
# )
|
||||
#
|
||||
# if brand_feat_valid:
|
||||
# raw_brand_feat_scores = matrix_data["brand_feature_matrix"][
|
||||
# brand_idx_feature, valid_sketch_idxs_feature
|
||||
# ]
|
||||
# raw_brand_feat_scores = (raw_brand_feat_scores - np.min(raw_brand_feat_scores)) / (
|
||||
# np.max(raw_brand_feat_scores) - np.min(raw_brand_feat_scores) + 1e-8
|
||||
# )
|
||||
# processed_brand_feat = raw_brand_feat_scores
|
||||
#
|
||||
# # 如果 processed_feat 是空的,替换为全 0,避免 shape 不一致
|
||||
# if processed_feat.size == 0:
|
||||
# processed_feat = np.zeros_like(processed_brand_feat)
|
||||
#
|
||||
# final_scores = processed_inter + 0.3 * (
|
||||
# (1 - brand_scale) * processed_feat + brand_scale * processed_brand_feat
|
||||
# )
|
||||
# else:
|
||||
# # brand 信息不可用
|
||||
# final_scores = processed_inter + 0.3 * processed_feat if processed_feat.size > 0 else processed_inter
|
||||
# else:
|
||||
# final_scores = processed_inter + 0.3 * processed_feat if processed_feat.size > 0 else processed_inter
|
||||
#
|
||||
# valid_sketch_idxs = matrix_data["cached_valid_idxs"][cache_key]
|
||||
#
|
||||
# # 概率采样
|
||||
# scores = np.array(final_scores)
|
||||
#
|
||||
# # 调整后的概率转换(带温度控制的softmax)
|
||||
# def calibrated_softmax(scores, temperature=1.0):
|
||||
# scores = scores / temperature
|
||||
# scale = scores - max(scores)
|
||||
# exps = np.exp(scale)
|
||||
# return exps / np.sum(exps)
|
||||
#
|
||||
# probs = calibrated_softmax(scores, 0.09)
|
||||
#
|
||||
# chosen_indices = np.random.choice(
|
||||
# len(valid_sketch_idxs),
|
||||
# size=min(num_recommendations, len(valid_sketch_idxs)),
|
||||
# p=probs,
|
||||
# replace=False
|
||||
# )
|
||||
# recommendations = [matrix_data["iid_to_sketch"][valid_sketch_idxs[idx]] for idx in chosen_indices]
|
||||
#
|
||||
# logger.info(f"推荐生成完成,耗时: {time.time() - start_time:.2f}秒")
|
||||
# return recommendations
|
||||
# except Exception as e:
|
||||
# logger.error(f"推荐失败: {str(e)}", exc_info=True)
|
||||
# raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
# @router.on_event("startup")
|
||||
@router.on_event("startup")
|
||||
async def startup_event():
|
||||
"""启动时初始化增量监听任务"""
|
||||
try:
|
||||
# 确保 Milvus 集合已创建(若已存在则直接返回)
|
||||
try:
|
||||
create_collection()
|
||||
except Exception as exc:
|
||||
logger.error("Milvus 集合创建/检查失败,不影响服务继续启动: %s", exc, exc_info=True)
|
||||
|
||||
# 配置定时任务
|
||||
scheduler = BackgroundScheduler()
|
||||
start_background_listener(scheduler)
|
||||
scheduler.start()
|
||||
logger.info("增量监听定时任务已启动")
|
||||
except Exception as e:
|
||||
logger.error(f"启动增量监听任务失败: {e}", exc_info=True)
|
||||
# 初始加载
|
||||
load_resources()
|
||||
|
||||
# 配置定时任务
|
||||
scheduler = BackgroundScheduler()
|
||||
scheduler.add_job(
|
||||
load_resources,
|
||||
trigger=CronTrigger(hour=0, minute=30),
|
||||
name="每日资源刷新"
|
||||
)
|
||||
scheduler.start()
|
||||
logger.info("定时任务已启动")
|
||||
|
||||
|
||||
@router.get("/recommend/{user_id}/{category}", response_model=List[str])
|
||||
async def recommend(
|
||||
user_id: int,
|
||||
category: str,
|
||||
style: Optional[str] = Query(
|
||||
None,
|
||||
description="风格样式(可选):若传入,则在利用分支对同 style 的候选进行加分",
|
||||
),
|
||||
):
|
||||
"""新版推荐接口(Milvus + Redis 偏好向量)。"""
|
||||
def softmax(scores):
|
||||
max_score = max(scores)
|
||||
exp_scores = [math.exp(s - max_score) for s in scores]
|
||||
sum_exp = sum(exp_scores)
|
||||
return [s / sum_exp for s in exp_scores]
|
||||
|
||||
|
||||
# def get_random_recommendations(category: str, num: int) -> List[str]:
|
||||
# """根据预加载热度向量推荐(冷启动)"""
|
||||
# try:
|
||||
# heat_data = matrix_data.get("heat_data", {})
|
||||
#
|
||||
# if category not in heat_data:
|
||||
# raise ValueError(f"热度数据缺少类别 {category},使用随机推荐")
|
||||
#
|
||||
# heat_dict = heat_data[category] # {url: score}
|
||||
# urls = list(heat_dict.keys())
|
||||
# scores = list(heat_dict.values())
|
||||
#
|
||||
# if not urls:
|
||||
# raise ValueError("该类别下无热度记录,使用随机推荐")
|
||||
#
|
||||
# probs = softmax(scores)
|
||||
# sample_size = min(num, len(urls))
|
||||
# sampled_urls = random.choices(urls, weights=probs, k=sample_size)
|
||||
#
|
||||
# return sampled_urls
|
||||
#
|
||||
# except Exception as e:
|
||||
# # 回退:完全随机推荐
|
||||
# all_iids = list(matrix_data["iid_to_sketch"].keys())
|
||||
# category_iids = matrix_data["category_to_iids"].get(category, all_iids)
|
||||
# sample_size = min(num, len(category_iids))
|
||||
# sampled = np.random.choice(category_iids, size=sample_size, replace=False)
|
||||
# return [matrix_data["iid_to_sketch"][iid] for iid in sampled]
|
||||
|
||||
def get_random_recommendations(category: str, num: int) -> List[str]:
|
||||
"""全品类随机推荐"""
|
||||
all_iids = list(matrix_data["iid_to_sketch"].keys())
|
||||
# 优先从当前品类选择
|
||||
category_iids = matrix_data["category_to_iids"].get(category, all_iids)
|
||||
# 确保不超出实际数量
|
||||
sample_size = min(num, len(category_iids))
|
||||
sampled = np.random.choice(category_iids, size=sample_size, replace=False)
|
||||
return [matrix_data["iid_to_sketch"][iid] for iid in sampled]
|
||||
|
||||
|
||||
@router.get("/recommend/{user_id}/{category}/{num_recommendations}/{brand_id}/{brand_scale}", response_model=List[str])
|
||||
async def get_recommendations(user_id: int, category: str, brand_id: int, brand_scale: float, num_recommendations: int = 10):
|
||||
"""
|
||||
@param user_id: 4
|
||||
@param category: female_skirt
|
||||
@param num_recommendations: 1
|
||||
@return:
|
||||
[
|
||||
"aida-sys-image/images/female/skirt/903000017.jpg"
|
||||
]
|
||||
|
||||
"""
|
||||
try:
|
||||
results = get_new_recommendations(user_id, category, style)
|
||||
path = results[0] if results else ""
|
||||
return [path]
|
||||
logger.info(f"user_id:{user_id}-----category:{category}-----brand_id:{brand_id}-----brand_scale:{brand_scale}-----num_recommendations:{num_recommendations}")
|
||||
start_time = time.time()
|
||||
cache_key = (user_id, category)
|
||||
# === 新增:用户存在性检查 ===
|
||||
user_exists_inter = user_id in matrix_data["user_index_interaction"]
|
||||
user_exists_feat = user_id in matrix_data["user_index_feature"]
|
||||
|
||||
# 任一矩阵不存在用户则返回随机推荐
|
||||
if not (user_exists_inter and user_exists_feat):
|
||||
logger.info(f"用户 {user_id} 数据不完整,触发随机推荐")
|
||||
return get_random_recommendations(category, num_recommendations)
|
||||
|
||||
# 检查缓存
|
||||
if cache_key in matrix_data["cached_scores"]:
|
||||
processed_inter, processed_feat = matrix_data["cached_scores"][cache_key]
|
||||
valid_sketch_idxs_inter = matrix_data["cached_valid_idxs"][cache_key]
|
||||
else:
|
||||
# 实时计算逻辑(同原代码)
|
||||
user_idx_inter = matrix_data["user_index_interaction"].get(user_id)
|
||||
user_idx_feature = matrix_data["user_index_feature"].get(user_id)
|
||||
|
||||
category_iids = matrix_data["category_to_iids"].get(category, [])
|
||||
valid_sketch_idxs_inter = [
|
||||
idx for iid, idx in matrix_data["sketch_index_interaction"].items()
|
||||
if iid in category_iids
|
||||
]
|
||||
|
||||
# 处理交互分数
|
||||
raw_inter_scores = []
|
||||
if user_idx_inter is not None and valid_sketch_idxs_inter:
|
||||
raw_inter_scores = matrix_data["interaction_matrix"][user_idx_inter, valid_sketch_idxs_inter]
|
||||
processed_inter = raw_inter_scores * 0.7
|
||||
|
||||
# 处理特征分数
|
||||
valid_sketch_idxs_feature = [
|
||||
idx for iid, idx in matrix_data["sketch_index_feature"].items()
|
||||
if iid in category_iids
|
||||
]
|
||||
raw_feat_scores = []
|
||||
if user_idx_feature is not None and valid_sketch_idxs_feature:
|
||||
raw_feat_scores = matrix_data["feature_matrix"][user_idx_feature, valid_sketch_idxs_feature]
|
||||
raw_feat_scores = (raw_feat_scores - np.min(raw_feat_scores)) / (
|
||||
np.max(raw_feat_scores) - np.min(raw_feat_scores) + 1e-8)
|
||||
processed_feat = raw_feat_scores
|
||||
else:
|
||||
processed_feat = np.array([])
|
||||
|
||||
# 更新缓存
|
||||
matrix_data["cached_scores"][cache_key] = (processed_inter, processed_feat)
|
||||
matrix_data["cached_valid_idxs"][cache_key] = valid_sketch_idxs_inter
|
||||
|
||||
# 合并分数
|
||||
if brand_id is not None:
|
||||
brand_idx_feature = matrix_data["brand_index_map"].get(brand_id)
|
||||
|
||||
brand_feat_valid = (
|
||||
matrix_data["brand_feature_matrix"].size > 0 and # 矩阵非空
|
||||
brand_idx_feature is not None and
|
||||
valid_sketch_idxs_feature # 有可用索引
|
||||
)
|
||||
|
||||
if brand_feat_valid:
|
||||
raw_brand_feat_scores = matrix_data["brand_feature_matrix"][
|
||||
brand_idx_feature, valid_sketch_idxs_feature
|
||||
]
|
||||
raw_brand_feat_scores = (raw_brand_feat_scores - np.min(raw_brand_feat_scores)) / (
|
||||
np.max(raw_brand_feat_scores) - np.min(raw_brand_feat_scores) + 1e-8
|
||||
)
|
||||
processed_brand_feat = raw_brand_feat_scores
|
||||
|
||||
# 如果 processed_feat 是空的,替换为全 0,避免 shape 不一致
|
||||
if processed_feat.size == 0:
|
||||
processed_feat = np.zeros_like(processed_brand_feat)
|
||||
|
||||
final_scores = processed_inter + 0.3 * (
|
||||
(1 - brand_scale) * processed_feat + brand_scale * processed_brand_feat
|
||||
)
|
||||
else:
|
||||
# brand 信息不可用
|
||||
final_scores = processed_inter + 0.3 * processed_feat if processed_feat.size > 0 else processed_inter
|
||||
else:
|
||||
final_scores = processed_inter + 0.3 * processed_feat if processed_feat.size > 0 else processed_inter
|
||||
|
||||
valid_sketch_idxs = matrix_data["cached_valid_idxs"][cache_key]
|
||||
|
||||
# 概率采样
|
||||
scores = np.array(final_scores)
|
||||
|
||||
# 调整后的概率转换(带温度控制的softmax)
|
||||
def calibrated_softmax(scores, temperature=1.0):
|
||||
scores = scores / temperature
|
||||
scale = scores - max(scores)
|
||||
exps = np.exp(scale)
|
||||
return exps / np.sum(exps)
|
||||
|
||||
probs = calibrated_softmax(scores, 0.09)
|
||||
|
||||
chosen_indices = np.random.choice(
|
||||
len(valid_sketch_idxs),
|
||||
size=min(num_recommendations, len(valid_sketch_idxs)),
|
||||
p=probs,
|
||||
replace=False
|
||||
)
|
||||
recommendations = [matrix_data["iid_to_sketch"][valid_sketch_idxs[idx]] for idx in chosen_indices]
|
||||
|
||||
logger.info(f"推荐生成完成,耗时: {time.time() - start_time:.2f}秒")
|
||||
return recommendations
|
||||
|
||||
except Exception as e:
|
||||
logger.error("新版推荐接口失败 [user=%s, category=%s]: %s", user_id, category, e, exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
logger.error(f"推荐失败: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@@ -1,42 +1,40 @@
|
||||
from fastapi import APIRouter
|
||||
|
||||
from app.api import api_attribute_retrieve, api_query_image
|
||||
from app.api import api_attribute_retrieve
|
||||
from app.api import api_brand_dna
|
||||
from app.api import api_brighten
|
||||
from app.api import api_chat_robot
|
||||
from app.api import api_clothing_seg
|
||||
from app.api import api_design
|
||||
from app.api import api_design_pre_processing
|
||||
from app.api import api_extraction_project_info
|
||||
from app.api import api_generate_image
|
||||
from app.api import api_image2sketch
|
||||
from app.api import api_import_sys_sketch
|
||||
from app.api import api_mannequins_edit
|
||||
from app.api import api_pose_transform
|
||||
from app.api import api_precompute
|
||||
from app.api import api_prompt_generation
|
||||
from app.api import api_recommendation
|
||||
from app.api import api_super_resolution
|
||||
from app.api import api_test
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
router.include_router(api_test.router, tags=["test"], prefix="/test")
|
||||
router.include_router(api_super_resolution.router, tags=["super_resolution"], prefix="/api")
|
||||
router.include_router(api_generate_image.router, tags=["generate_image"], prefix="/api")
|
||||
router.include_router(api_attribute_retrieve.router, tags=["attribute_retrieve"], prefix="/api")
|
||||
router.include_router(api_design.router, tags=['design'], prefix="/api")
|
||||
router.include_router(api_chat_robot.router, tags=['chat_robot'], prefix="/api")
|
||||
router.include_router(api_prompt_generation.router, tags=['prompt_generation'], prefix="/api")
|
||||
router.include_router(api_design_pre_processing.router, tags=['design_pre_processing'], prefix="/api")
|
||||
router.include_router(api_image2sketch.router, tags=['api_image2sketch'], prefix="/api")
|
||||
router.include_router(api_brighten.router, tags=['api_brighten'], prefix="/api")
|
||||
router.include_router(api_query_image.router, tags=['api_query_image'], prefix="/api")
|
||||
router.include_router(api_brand_dna.router, tags=['api_brand_dna'], prefix="/api")
|
||||
router.include_router(api_recommendation.router, tags=['api_recommendation'], prefix="/api")
|
||||
router.include_router(api_mannequins_edit.router, tags=['api_mannequins_edit'], prefix="/api")
|
||||
router.include_router(api_pose_transform.router, tags=['api_pose_transform'], prefix="/api")
|
||||
router.include_router(api_clothing_seg.router, tags=['api_clothing_seg'], prefix="/api")
|
||||
router.include_router(api_extraction_project_info.router, tags=['api_extraction_project_info'], prefix="/api")
|
||||
router.include_router(api_import_sys_sketch.router, tags=['api_import_sys_sketch'], prefix="/api")
|
||||
router.include_router(api_precompute.router, tags=['api_precompute'], prefix="/api")
|
||||
|
||||
"""停用"""
|
||||
# from app.api import api_chat_robot
|
||||
# from app.api import api_query_image
|
||||
# from app.api import api_brighten
|
||||
# from app.api import api_extraction_project_info
|
||||
# from app.api import api_image2sketch
|
||||
# from app.api import api_super_resolution
|
||||
# router.include_router(api_image2sketch.router, tags=['api_image2sketch'], prefix="/api")
|
||||
# router.include_router(api_query_image.router, tags=['api_query_image'], prefix="/api")
|
||||
# router.include_router(api_chat_robot.router, tags=['chat_robot'], prefix="/api")
|
||||
# router.include_router(api_super_resolution.router, tags=["super_resolution"], prefix="/api")
|
||||
# router.include_router(api_brighten.router, tags=['api_brighten'], prefix="/api")
|
||||
# router.include_router(api_extraction_project_info.router, tags=['api_extraction_project_info'], prefix="/api")
|
||||
|
||||
@@ -27,7 +27,7 @@ def super_resolution(request_item: SuperResolutionModel, background_tasks: Backg
|
||||
}
|
||||
"""
|
||||
try:
|
||||
logger.info(f"super_resolution request item is : @@@@@@:{json.dumps(request_item.dict())}")
|
||||
logger.info(f"super_resolution request item is : @@@@@@:{json.dumps(request_item.dict(),indent=4)}")
|
||||
service = SuperResolution(request_item)
|
||||
background_tasks.add_task(service.sr_result)
|
||||
except Exception as e:
|
||||
|
||||
@@ -4,8 +4,7 @@ import logging
|
||||
from fastapi import APIRouter
|
||||
from fastapi import HTTPException
|
||||
|
||||
from app.core.config import SR_RABBITMQ_QUEUES, GI_RABBITMQ_QUEUES, GPI_RABBITMQ_QUEUES, GRI_RABBITMQ_QUEUES, OSS, JAVA_STREAM_API_URL, GMV_RABBITMQ_QUEUES, SLOGAN_RABBITMQ_QUEUES, GEN_SINGLE_LOGO_RABBITMQ_QUEUES, PS_RABBITMQ_QUEUES, BATCH_GPI_RABBITMQ_QUEUES, BATCH_GRI_RABBITMQ_QUEUES, \
|
||||
BATCH_PS_RABBITMQ_QUEUES, RABBITMQ_ENV
|
||||
from app.core.config import settings, SR_RABBITMQ_QUEUES, GMV_RABBITMQ_QUEUES, PS_RABBITMQ_QUEUES, SLOGAN_RABBITMQ_QUEUES, GI_RABBITMQ_QUEUES, GPI_RABBITMQ_QUEUES, GRI_RABBITMQ_QUEUES, BATCH_GPI_RABBITMQ_QUEUES, BATCH_GRI_RABBITMQ_QUEUES, BATCH_PS_RABBITMQ_QUEUES
|
||||
from app.schemas.response_template import ResponseModel
|
||||
|
||||
logger = logging.getLogger()
|
||||
@@ -15,9 +14,9 @@ router = APIRouter()
|
||||
@router.get("{id}")
|
||||
def test(id: int):
|
||||
data = {
|
||||
"RABBITMQ_ENV":RABBITMQ_ENV,
|
||||
"超分 SR_RABBITMQ_QUEUES": SR_RABBITMQ_QUEUES,
|
||||
"多视角 GMV_RABBITMQ_QUEUES": GMV_RABBITMQ_QUEUES,
|
||||
"RABBITMQ_ENV": settings.SERVE_ENV,
|
||||
# "超分 SR_RABBITMQ_QUEUES": SR_RABBITMQ_QUEUES,
|
||||
# "多视角 GMV_RABBITMQ_QUEUES": GMV_RABBITMQ_QUEUES,
|
||||
"pose transform PS_RABBITMQ_QUEUES": PS_RABBITMQ_QUEUES,
|
||||
"logan SLOGAN_RABBITMQ_QUEUES": SLOGAN_RABBITMQ_QUEUES,
|
||||
"image and single logo GI_RABBITMQ_QUEUES": GI_RABBITMQ_QUEUES,
|
||||
@@ -29,10 +28,9 @@ def test(id: int):
|
||||
"batch relight BATCH_GRI_RABBITMQ_QUEUES": BATCH_GRI_RABBITMQ_QUEUES,
|
||||
"batch pose transform BATCH_PS_RABBITMQ_QUEUES": BATCH_PS_RABBITMQ_QUEUES,
|
||||
|
||||
"JAVA_STREAM_API_URL": JAVA_STREAM_API_URL,
|
||||
"local_oss_server": OSS
|
||||
"JAVA_STREAM_API_URL": settings.JAVA_STREAM_API_URL,
|
||||
}
|
||||
logger.info(json.dumps(data))
|
||||
logger.info(json.dumps(data, ensure_ascii=False, indent=4))
|
||||
if id == 1:
|
||||
raise HTTPException(status_code=404, detail="Item not found")
|
||||
|
||||
|
||||
235
app/core/config.backup.py
Normal file
235
app/core/config.backup.py
Normal file
@@ -0,0 +1,235 @@
|
||||
import os
|
||||
|
||||
import pika
|
||||
from dotenv import load_dotenv
|
||||
from pydantic import BaseSettings
|
||||
|
||||
BASE_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../'))
|
||||
load_dotenv(os.path.join(BASE_DIR, '.env'))
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
PROJECT_NAME: str = 'FASTAPI BASE'
|
||||
SECRET_KEY: str = ''
|
||||
API_PREFIX: str = ''
|
||||
BACKEND_CORS_ORIGINS: list[str] = ['*']
|
||||
DATABASE_URL: str = ''
|
||||
ACCESS_TOKEN_EXPIRE_SECONDS: int = 60 * 60 * 24 * 7 # Token expired after 7 days
|
||||
SECURITY_ALGORITHM: str = 'HS256'
|
||||
LOGGING_CONFIG_FILE: str = os.path.join(BASE_DIR, 'logging_env.py')
|
||||
|
||||
|
||||
OSS = "minio"
|
||||
DEBUG = False
|
||||
if DEBUG:
|
||||
LOGS_PATH = "logs/"
|
||||
CATEGORY_PATH = "service/attribute/config/descriptor/category/category_dis.csv"
|
||||
SEG_CACHE_PATH = "../seg_cache/"
|
||||
POSE_TRANSFORM_VIDEO_PATH = "../pose_transform_video/"
|
||||
RECOMMEND_PATH_PREFIX = "service/recommend/"
|
||||
CHROMADB_PATH = "./chromadb/"
|
||||
else:
|
||||
LOGS_PATH = "app/logs/"
|
||||
CATEGORY_PATH = "app/service/attribute/config/descriptor/category/category_dis.csv"
|
||||
SEG_CACHE_PATH = "/seg_cache/"
|
||||
POSE_TRANSFORM_VIDEO_PATH = "/pose_transform_video/"
|
||||
RECOMMEND_PATH_PREFIX = "app/service/recommend/"
|
||||
CHROMADB_PATH = "/chromadb/"
|
||||
|
||||
# RABBITMQ_ENV = "" # 生产环境
|
||||
RABBITMQ_ENV = os.getenv("RABBITMQ_ENV", "-dev")
|
||||
# RABBITMQ_ENV = "-local" # 本地测试环境
|
||||
|
||||
if RABBITMQ_ENV == "-dev":
|
||||
JAVA_STREAM_API_URL = f"https://develop.api.aida.com.hk/api/third/party/receiveDesignResults"
|
||||
elif RABBITMQ_ENV == "-prod":
|
||||
JAVA_STREAM_API_URL = f"https://api.aida.com.hk/api/third/party/receiveDesignResults"
|
||||
|
||||
settings = Settings()
|
||||
|
||||
# minio 配置
|
||||
MINIO_URL = "www.minio-api.aida.com.hk"
|
||||
MINIO_ACCESS = 'vXKFLSJkYeEq2DrSZvkB'
|
||||
MINIO_SECRET = 'uKTZT3x7C43WvPN9QTc99DiRkwddWZrG9Uh3JVlR'
|
||||
MINIO_SECURE = True
|
||||
|
||||
# S3 配置
|
||||
S3_ACCESS_KEY = "AKIAVD3OJIMF6UJFLSHZ"
|
||||
S3_AWS_SECRET_ACCESS_KEY = "LNIwFFB27/QedtZ+Q/viVUoX9F5x1DbuM8N0DkD8"
|
||||
S3_REGION_NAME = "ap-east-1"
|
||||
|
||||
# redis 配置
|
||||
REDIS_HOST = "10.1.1.240"
|
||||
REDIS_PORT = "6379"
|
||||
REDIS_DB = "2"
|
||||
|
||||
# rabbitmq config
|
||||
RABBITMQ_PARAMS = {
|
||||
"host": "18.167.251.121",
|
||||
"port": 5672,
|
||||
"credentials": pika.credentials.PlainCredentials(username='rabbit', password='123456'),
|
||||
"virtual_host": "/"
|
||||
}
|
||||
|
||||
# milvus 配置
|
||||
MILVUS_URL = "http://10.1.1.240:19530"
|
||||
MILVUS_TOKEN = "root:Milvus"
|
||||
MILVUS_ALIAS = "default"
|
||||
MILVUS_TABLE_KEYPOINT = "keypoint_cache_2"
|
||||
MILVUS_TABLE_SEG = "seg_cache"
|
||||
|
||||
# Mysql 配置
|
||||
DB_HOST = '18.167.251.121' # 数据库主机地址
|
||||
# DB_PORT = int( 33006)
|
||||
DB_PORT = 33008 # 数据库端口
|
||||
DB_USERNAME = 'aida_con_python' # 数据库用户名
|
||||
DB_PASSWORD = '123456' # 数据库密码
|
||||
DB_NAME = 'aida' # 数据库库名
|
||||
|
||||
# openai
|
||||
os.environ['SERPAPI_API_KEY'] = "a793513017b0718db7966207c31703d280d12435c982f1e67bbcbffa52e7632c"
|
||||
OPENAI_STREAM = True
|
||||
BUFFER_THRESHOLD = 6 # must be even number
|
||||
SINGLE_TOKEN_THRESHOLD = 200
|
||||
TOKEN_THRESHOLD = 600
|
||||
OPENAI_TEMPERATURE = 0
|
||||
|
||||
# OPENAI_API_KEY = "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"
|
||||
OPENAI_API_KEY = "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"
|
||||
OPENAI_MODEL = "gpt-3.5-turbo-0613"
|
||||
OPENAI_MODEL_LIST = {"gpt-3.5-turbo-0613",
|
||||
"gpt-3.5-turbo-16k-0613",
|
||||
"gpt-4-0314",
|
||||
"gpt-4-32k-0314",
|
||||
"gpt-4-0613",
|
||||
"gpt-4-32k-0613", }
|
||||
|
||||
# SR service config
|
||||
SR_MODEL_NAME = "super_resolution"
|
||||
SR_TRITON_URL = "10.1.1.240:10031"
|
||||
SR_MINIO_BUCKET = "aida-users"
|
||||
SR_RABBITMQ_QUEUES = f"SuperResolution{RABBITMQ_ENV}"
|
||||
|
||||
# GenerateImage service config
|
||||
FAST_GI_MODEL_URL = '10.1.1.243:10011'
|
||||
FAST_GI_MODEL_NAME = 'stable_diffusion_xl'
|
||||
|
||||
GI_MODEL_URL = '10.1.1.240:10061'
|
||||
GI_MODEL_NAME = 'flux'
|
||||
|
||||
GMV_MODEL_URL = '10.1.1.243:10081'
|
||||
GMV_MODEL_NAME = 'multi_view'
|
||||
|
||||
GMV_RABBITMQ_QUEUES = f"GenerateMultiView{RABBITMQ_ENV}"
|
||||
|
||||
GI_MINIO_BUCKET = "aida-users"
|
||||
GI_RABBITMQ_QUEUES = f"GenerateImage{RABBITMQ_ENV}"
|
||||
GI_SYS_IMAGE_URL = "aida-sys-image/generate_image/white_image.jpg"
|
||||
|
||||
# SLOGAN service config
|
||||
SLOGAN_RABBITMQ_QUEUES = f"Slogan{RABBITMQ_ENV}"
|
||||
|
||||
# Generate Single Logo service config
|
||||
GSL_MODEL_URL = '10.1.1.243:10041'
|
||||
GSL_MINIO_BUCKET = "aida-users"
|
||||
GSL_MODEL_NAME = 'stable_diffusion_xl_transparent'
|
||||
GEN_SINGLE_LOGO_RABBITMQ_QUEUES = f"GenSingleLogo{RABBITMQ_ENV}"
|
||||
|
||||
# Generate Product service config
|
||||
# GPI_RABBITMQ_QUEUES = os.getenv("GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES", f"ToProductImage{RABBITMQ_ENV}")
|
||||
# GPI_MODEL_NAME_OVERALL = 'sdxl_ensemble_all'
|
||||
# GPI_MODEL_URL = '10.1.1.243:10051'
|
||||
|
||||
# Generate Product service config 旧版product img 模型
|
||||
GPI_RABBITMQ_QUEUES = f"ToProductImage{RABBITMQ_ENV}"
|
||||
BATCH_GPI_RABBITMQ_QUEUES = f"BatchToProductImage{RABBITMQ_ENV}"
|
||||
GPI_MODEL_NAME_OVERALL = 'diffusion_ensemble_all'
|
||||
GPI_MODEL_NAME_SINGLE = 'stable_diffusion_1_5_cnet'
|
||||
GPI_MODEL_URL = '10.1.1.243:10051'
|
||||
|
||||
# Generate Single Logo service config
|
||||
GRI_RABBITMQ_QUEUES = f"Relight{RABBITMQ_ENV}"
|
||||
BATCH_GRI_RABBITMQ_QUEUES = f"BatchRelight{RABBITMQ_ENV}"
|
||||
GRI_MODEL_NAME_OVERALL = 'diffusion_relight_ensemble'
|
||||
GRI_MODEL_NAME_SINGLE = 'stable_diffusion_1_5_relight'
|
||||
GRI_MODEL_URL = '10.1.1.240:10051'
|
||||
|
||||
# Pose Transform service config
|
||||
|
||||
PS_RABBITMQ_QUEUES = f"PoseTransform{RABBITMQ_ENV}"
|
||||
BATCH_PS_RABBITMQ_QUEUES = f"BatchPoseTransform{RABBITMQ_ENV}"
|
||||
PT_MODEL_URL = '10.1.1.243:10061'
|
||||
|
||||
# SEG service config
|
||||
SEGMENTATION = {
|
||||
"new_model_name": "seg_knet",
|
||||
"name": "seg_ocrnet_hr18",
|
||||
"input": "seg_input__0",
|
||||
"output": "seg_output__0",
|
||||
}
|
||||
# ollama config
|
||||
OLLAMA_URL = "http://10.1.1.240:11434/api/embeddings"
|
||||
|
||||
# design batch
|
||||
BATCH_DESIGN_RABBITMQ_QUEUES = f"DesignBatch{RABBITMQ_ENV}"
|
||||
|
||||
# DESIGN config
|
||||
DESIGN_MODEL_URL = '10.1.1.240:10000'
|
||||
AIDA_CLOTHING = "aida-clothing"
|
||||
KEYPOINT_RESULT_TABLE_FIELD_SET = ('neckline_left', 'neckline_right', 'shoulder_left', 'shoulder_right', 'armpit_left', 'armpit_right',
|
||||
'cuff_left_in', 'cuff_left_out', 'cuff_right_in', 'cuff_right_out', 'waistband_left', 'waistband_right')
|
||||
|
||||
# DESIGN 预处理
|
||||
IF_DEBUG_SHOW = False
|
||||
|
||||
# 优先级
|
||||
PRIORITY_DICT = {
|
||||
'earring_front': 99,
|
||||
'bag_front': 98,
|
||||
'hairstyle_front': 97,
|
||||
'outwear_front': 20,
|
||||
'tops_front': 19,
|
||||
'dress_front': 18,
|
||||
'blouse_front': 17,
|
||||
'skirt_front': 16,
|
||||
'trousers_front': 15,
|
||||
'bottoms_front': 14,
|
||||
'shoes_right': 1,
|
||||
'shoes_left': 1,
|
||||
'body': 0,
|
||||
'bottoms_back': -14,
|
||||
'trousers_back': -15,
|
||||
'skirt_back': -16,
|
||||
'blouse_back': -17,
|
||||
'dress_back': -18,
|
||||
'tops_back': -19,
|
||||
'outwear_back': -20,
|
||||
'hairstyle_back': -97,
|
||||
'bag_back': -98,
|
||||
'earring_back': -99,
|
||||
}
|
||||
|
||||
QWEN_API_KEY = "sk-f31c29e61ac2498ba5e307aaa6dc10e0"
|
||||
|
||||
DB_CONFIG = {
|
||||
"host": "18.167.251.121",
|
||||
"port": 3306,
|
||||
"user": "root",
|
||||
"password": "QWa998345",
|
||||
"database": "aida",
|
||||
"charset": "utf8mb4"
|
||||
}
|
||||
|
||||
TABLE_CATEGORIES = {
|
||||
"female_dress": "female/dress",
|
||||
"female_outwear": "female/outwear",
|
||||
"female_trousers": "female/trousers",
|
||||
"female_skirt": "female/skirt",
|
||||
"female_blouse": "female/blouse",
|
||||
"male_tops": "male/tops",
|
||||
"male_bottoms": "male/bottoms",
|
||||
"male_outwear": "male/outwear"
|
||||
}
|
||||
|
||||
# --- ComfyUI 配置信息 ---
|
||||
COMFYUI_SERVER_ADDRESS = "10.1.2.227:8080" # 替换为您的 ComfyUI 服务器地址
|
||||
@@ -1,188 +1,91 @@
|
||||
import os
|
||||
|
||||
import pika
|
||||
from dotenv import load_dotenv
|
||||
from pydantic import BaseSettings
|
||||
|
||||
BASE_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../'))
|
||||
load_dotenv(os.path.join(BASE_DIR, '.env'))
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
PROJECT_NAME: str = 'FASTAPI BASE'
|
||||
SECRET_KEY: str = ''
|
||||
API_PREFIX: str = ''
|
||||
BACKEND_CORS_ORIGINS: list[str] = ['*']
|
||||
DATABASE_URL: str = ''
|
||||
ACCESS_TOKEN_EXPIRE_SECONDS: int = 60 * 60 * 24 * 7 # Token expired after 7 days
|
||||
SECURITY_ALGORITHM: str = 'HS256'
|
||||
LOGGING_CONFIG_FILE: str = os.path.join(BASE_DIR, 'logging_env.py')
|
||||
"""
|
||||
应用配置类。Pydantic Settings 会自动从环境变量和 .env 文件中加载这些值。
|
||||
"""
|
||||
model_config = SettingsConfigDict(
|
||||
env_file='.env',
|
||||
env_file_encoding='utf-8',
|
||||
# extra='ignore' # 忽略环境变量中多余的键
|
||||
)
|
||||
# --- 服务端口配置信息 ---
|
||||
PORT: int = Field(default=8001, description="")
|
||||
# --- 服务环境 配置信息 ---
|
||||
SERVE_ENV: str = Field(default='', description="")
|
||||
# --- 开发状态 配置信息 ---
|
||||
DEBUG: bool = Field(default=False, description="")
|
||||
# --- 千问api 配置信息 ---
|
||||
QWEN_API_KEY: str = Field(default="", description="")
|
||||
|
||||
# --- ComfyUI 配置信息 ---
|
||||
COMFYUI_SERVER_ADDRESS: str = Field(default='', description="")
|
||||
|
||||
OSS = "minio"
|
||||
DEBUG = False
|
||||
if DEBUG:
|
||||
LOGS_PATH = "logs/"
|
||||
CATEGORY_PATH = "service/attribute/config/descriptor/category/category_dis.csv"
|
||||
SEG_CACHE_PATH = "../seg_cache/"
|
||||
POSE_TRANSFORM_VIDEO_PATH = "../pose_transform_video/"
|
||||
RECOMMEND_PATH_PREFIX = "service/recommend/"
|
||||
CHROMADB_PATH = "./chromadb/"
|
||||
else:
|
||||
LOGS_PATH = "app/logs/"
|
||||
CATEGORY_PATH = "app/service/attribute/config/descriptor/category/category_dis.csv"
|
||||
SEG_CACHE_PATH = "/seg_cache/"
|
||||
POSE_TRANSFORM_VIDEO_PATH = "/pose_transform_video/"
|
||||
RECOMMEND_PATH_PREFIX = "app/service/recommend/"
|
||||
CHROMADB_PATH = "/chromadb/"
|
||||
# --- minio 配置信息 ---
|
||||
MINIO_URL: str = Field(default='', description="")
|
||||
MINIO_ACCESS: str = Field(default='', description="")
|
||||
MINIO_SECRET: str = Field(default='', description="")
|
||||
MINIO_SECURE: bool = Field(default=True, description="")
|
||||
|
||||
# RABBITMQ_ENV = "" # 生产环境
|
||||
RABBITMQ_ENV = os.getenv("RABBITMQ_ENV", "-dev")
|
||||
# RABBITMQ_ENV = "-local" # 本地测试环境
|
||||
# --- redis 配置信息 ---
|
||||
REDIS_HOST: str = Field(default='', description="")
|
||||
REDIS_PORT: str = Field(default='', description="")
|
||||
REDIS_DB: int = Field(default=0, description="")
|
||||
|
||||
# --- mysql 配置信息 ---
|
||||
MYSQL_HOST: str = Field(default='', description="")
|
||||
MYSQL_PORT: str = Field(default='', description="")
|
||||
MYSQL_USER: str = Field(default='', description="")
|
||||
MYSQL_PASSWORD: str = Field(default='', description="")
|
||||
MYSQL_DB: str = Field(default='', description="")
|
||||
MYSQL_CHARSET: str = Field(default='utf8mb4', description="")
|
||||
|
||||
# --- rabbit-mq 配置信息 ---
|
||||
MQ_HOST: str = Field(default='', description="")
|
||||
MQ_PORT: str = Field(default='', description="")
|
||||
MQ_USERNAME: str = Field(default='', description="")
|
||||
MQ_PASSWORD: str = Field(default='', description="")
|
||||
MQ_VIRTUAL_HOST: str = Field(default='/', description="")
|
||||
MQ_ENV: str = Field(default='', description="")
|
||||
|
||||
# --- milvus 配置信息 ---
|
||||
MILVUS_URL: str = Field(default='', description="")
|
||||
MILVUS_TOKEN: str = Field(default='', description="")
|
||||
MILVUS_ALIAS: str = Field(default='', description="")
|
||||
|
||||
# --- ollama 配置信息 ---
|
||||
CHROMADB_PATH: str = Field(default='', description="")
|
||||
|
||||
# --- ollama 配置信息 ---
|
||||
OLLAMA_URL: str = Field(default='', description="")
|
||||
|
||||
# --- Design Callback Java 接口 ---
|
||||
JAVA_STREAM_API_URL: str = Field(default='', description="")
|
||||
|
||||
# --- 其他配置信息 以下均为Docker容器内配置---
|
||||
LOGS_PATH: str = Field(default="/logs/", description="")
|
||||
CATEGORY_PATH: str = Field(default="/app/service/attribute/config/descriptor/category/category_dis.csv", description="")
|
||||
SEG_CACHE_PATH: str = Field(default="/seg_cache/", description="")
|
||||
RECOMMEND_PATH_PREFIX: str = Field(default="/app/service/recommend/", description="")
|
||||
|
||||
if RABBITMQ_ENV == "-dev":
|
||||
JAVA_STREAM_API_URL = f"https://develop.api.aida.com.hk/api/third/party/receiveDesignResults"
|
||||
elif RABBITMQ_ENV == "-prod":
|
||||
JAVA_STREAM_API_URL = f"https://api.aida.com.hk/api/third/party/receiveDesignResults"
|
||||
|
||||
settings = Settings()
|
||||
|
||||
# minio 配置
|
||||
MINIO_URL = "www.minio-api.aida.com.hk"
|
||||
MINIO_ACCESS = 'vXKFLSJkYeEq2DrSZvkB'
|
||||
MINIO_SECRET = 'uKTZT3x7C43WvPN9QTc99DiRkwddWZrG9Uh3JVlR'
|
||||
MINIO_SECURE = True
|
||||
|
||||
# S3 配置
|
||||
S3_ACCESS_KEY = "AKIAVD3OJIMF6UJFLSHZ"
|
||||
S3_AWS_SECRET_ACCESS_KEY = "LNIwFFB27/QedtZ+Q/viVUoX9F5x1DbuM8N0DkD8"
|
||||
S3_REGION_NAME = "ap-east-1"
|
||||
|
||||
# redis 配置
|
||||
REDIS_HOST = "10.1.1.240"
|
||||
REDIS_PORT = "6379"
|
||||
REDIS_DB = "2"
|
||||
|
||||
# rabbitmq config
|
||||
RABBITMQ_PARAMS = {
|
||||
"host": "18.167.251.121",
|
||||
"port": 5672,
|
||||
"credentials": pika.credentials.PlainCredentials(username='rabbit', password='123456'),
|
||||
"virtual_host": "/"
|
||||
"""Design 服务"""
|
||||
# 推荐服装类别映射
|
||||
TABLE_CATEGORIES = {
|
||||
"female_dress": "female/dress",
|
||||
"female_outwear": "female/outwear",
|
||||
"female_trousers": "female/trousers",
|
||||
"female_skirt": "female/skirt",
|
||||
"female_blouse": "female/blouse",
|
||||
"male_tops": "male/tops",
|
||||
"male_bottoms": "male/bottoms",
|
||||
"male_outwear": "male/outwear"
|
||||
}
|
||||
|
||||
# milvus 配置
|
||||
MILVUS_URL = "http://10.1.1.240:19530"
|
||||
MILVUS_TOKEN = "root:Milvus"
|
||||
MILVUS_ALIAS = "default"
|
||||
MILVUS_TABLE_KEYPOINT = "keypoint_cache_2"
|
||||
MILVUS_TABLE_SEG = "seg_cache"
|
||||
|
||||
# Mysql 配置
|
||||
DB_HOST = '18.167.251.121' # 数据库主机地址
|
||||
# DB_PORT = int( 33006)
|
||||
DB_PORT = 33008 # 数据库端口
|
||||
DB_USERNAME = 'aida_con' # 数据库用户名
|
||||
DB_PASSWORD = '123456' # 数据库密码
|
||||
DB_NAME = 'aida_back' # 数据库库名
|
||||
|
||||
# openai
|
||||
os.environ['SERPAPI_API_KEY'] = "a793513017b0718db7966207c31703d280d12435c982f1e67bbcbffa52e7632c"
|
||||
OPENAI_STREAM = True
|
||||
BUFFER_THRESHOLD = 6 # must be even number
|
||||
SINGLE_TOKEN_THRESHOLD = 200
|
||||
TOKEN_THRESHOLD = 600
|
||||
OPENAI_TEMPERATURE = 0
|
||||
|
||||
# OPENAI_API_KEY = "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"
|
||||
OPENAI_API_KEY = "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"
|
||||
OPENAI_MODEL = "gpt-3.5-turbo-0613"
|
||||
OPENAI_MODEL_LIST = {"gpt-3.5-turbo-0613",
|
||||
"gpt-3.5-turbo-16k-0613",
|
||||
"gpt-4-0314",
|
||||
"gpt-4-32k-0314",
|
||||
"gpt-4-0613",
|
||||
"gpt-4-32k-0613", }
|
||||
|
||||
# SR service config
|
||||
SR_MODEL_NAME = "super_resolution"
|
||||
SR_TRITON_URL = "10.1.1.240:10031"
|
||||
SR_MINIO_BUCKET = "aida-users"
|
||||
SR_RABBITMQ_QUEUES = f"SuperResolution{RABBITMQ_ENV}"
|
||||
|
||||
# GenerateImage service config
|
||||
FAST_GI_MODEL_URL = '10.1.1.243:10011'
|
||||
FAST_GI_MODEL_NAME = 'stable_diffusion_xl'
|
||||
|
||||
GI_MODEL_URL = '10.1.1.240:10061'
|
||||
GI_MODEL_NAME = 'flux'
|
||||
|
||||
GMV_MODEL_URL = '10.1.1.243:10081'
|
||||
GMV_MODEL_NAME = 'multi_view'
|
||||
|
||||
GMV_RABBITMQ_QUEUES = f"GenerateMultiView{RABBITMQ_ENV}"
|
||||
|
||||
GI_MINIO_BUCKET = "aida-users"
|
||||
GI_RABBITMQ_QUEUES = f"GenerateImage{RABBITMQ_ENV}"
|
||||
GI_SYS_IMAGE_URL = "aida-sys-image/generate_image/white_image.jpg"
|
||||
|
||||
# SLOGAN service config
|
||||
SLOGAN_RABBITMQ_QUEUES = f"Slogan{RABBITMQ_ENV}"
|
||||
|
||||
# Generate Single Logo service config
|
||||
GSL_MODEL_URL = '10.1.1.243:10041'
|
||||
GSL_MINIO_BUCKET = "aida-users"
|
||||
GSL_MODEL_NAME = 'stable_diffusion_xl_transparent'
|
||||
GEN_SINGLE_LOGO_RABBITMQ_QUEUES = f"GenSingleLogo{RABBITMQ_ENV}"
|
||||
|
||||
# Generate Product service config
|
||||
# GPI_RABBITMQ_QUEUES = os.getenv("GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES", f"ToProductImage{RABBITMQ_ENV}")
|
||||
# GPI_MODEL_NAME_OVERALL = 'sdxl_ensemble_all'
|
||||
# GPI_MODEL_URL = '10.1.1.243:10051'
|
||||
|
||||
# Generate Product service config 旧版product img 模型
|
||||
GPI_RABBITMQ_QUEUES = f"ToProductImage{RABBITMQ_ENV}"
|
||||
BATCH_GPI_RABBITMQ_QUEUES = f"BatchToProductImage{RABBITMQ_ENV}"
|
||||
GPI_MODEL_NAME_OVERALL = 'diffusion_ensemble_all'
|
||||
GPI_MODEL_NAME_SINGLE = 'stable_diffusion_1_5_cnet'
|
||||
GPI_MODEL_URL = '10.1.1.243:10051'
|
||||
|
||||
# Generate Single Logo service config
|
||||
GRI_RABBITMQ_QUEUES = f"Relight{RABBITMQ_ENV}"
|
||||
BATCH_GRI_RABBITMQ_QUEUES = f"BatchRelight{RABBITMQ_ENV}"
|
||||
GRI_MODEL_NAME_OVERALL = 'diffusion_relight_ensemble'
|
||||
GRI_MODEL_NAME_SINGLE = 'stable_diffusion_1_5_relight'
|
||||
GRI_MODEL_URL = '10.1.1.240:10051'
|
||||
|
||||
# Pose Transform service config
|
||||
|
||||
PS_RABBITMQ_QUEUES = f"PoseTransform{RABBITMQ_ENV}"
|
||||
BATCH_PS_RABBITMQ_QUEUES = f"BatchPoseTransform{RABBITMQ_ENV}"
|
||||
PT_MODEL_URL = '10.1.1.243:10061'
|
||||
|
||||
# SEG service config
|
||||
SEGMENTATION = {
|
||||
"new_model_name": "seg_knet",
|
||||
"name": "seg_ocrnet_hr18",
|
||||
"input": "seg_input__0",
|
||||
"output": "seg_output__0",
|
||||
}
|
||||
# ollama config
|
||||
OLLAMA_URL = "http://10.1.1.240:11434/api/embeddings"
|
||||
|
||||
# design batch
|
||||
BATCH_DESIGN_RABBITMQ_QUEUES = f"DesignBatch{RABBITMQ_ENV}"
|
||||
|
||||
# DESIGN config
|
||||
DESIGN_MODEL_URL = '10.1.1.240:10000'
|
||||
AIDA_CLOTHING = "aida-clothing"
|
||||
KEYPOINT_RESULT_TABLE_FIELD_SET = ('neckline_left', 'neckline_right', 'shoulder_left', 'shoulder_right', 'armpit_left', 'armpit_right',
|
||||
'cuff_left_in', 'cuff_left_out', 'cuff_right_in', 'cuff_right_out', 'waistband_left', 'waistband_right')
|
||||
|
||||
# DESIGN 预处理
|
||||
IF_DEBUG_SHOW = False
|
||||
|
||||
# 优先级
|
||||
# Design前后排优先级
|
||||
PRIORITY_DICT = {
|
||||
'earring_front': 99,
|
||||
'bag_front': 98,
|
||||
@@ -208,28 +111,71 @@ PRIORITY_DICT = {
|
||||
'bag_back': -98,
|
||||
'earring_back': -99,
|
||||
}
|
||||
# Design 关键点字段
|
||||
KEYPOINT_RESULT_TABLE_FIELD_SET = ('neckline_left', 'neckline_right', 'shoulder_left', 'shoulder_right', 'armpit_left', 'armpit_right', 'cuff_left_in', 'cuff_left_out', 'cuff_right_in', 'cuff_right_out', 'waistband_left', 'waistband_right')
|
||||
# milvus配置信息
|
||||
MILVUS_TABLE_KEYPOINT = "keypoint_cache_2"
|
||||
|
||||
QWEN_API_KEY = "sk-f31c29e61ac2498ba5e307aaa6dc10e0"
|
||||
# ollama 地址
|
||||
OLLAMA_URL = "http://10.1.1.240:11434/api/embeddings"
|
||||
|
||||
DB_CONFIG = {
|
||||
"host": "18.167.251.121",
|
||||
"port": 3306,
|
||||
"user": "root",
|
||||
"password": "QWa998345",
|
||||
"database": "aida",
|
||||
"charset": "utf8mb4"
|
||||
}
|
||||
"""Triton Server Config"""
|
||||
# Design
|
||||
DESIGN_MODEL_URL = '10.1.1.240:10000'
|
||||
DESIGN_MODEL_NAME = 'seg_knet'
|
||||
# Generate Image
|
||||
GI_MODEL_URL = '10.1.1.240:10061'
|
||||
GI_MODEL_NAME = 'flux'
|
||||
# Generate Single Logo
|
||||
GSL_MODEL_URL = '10.1.1.243:10041'
|
||||
GSL_MODEL_NAME = 'stable_diffusion_xl_transparent'
|
||||
# Generate Product (整套和单品)
|
||||
GPI_MODEL_URL = '10.1.1.243:10051'
|
||||
GPI_MODEL_NAME_OVERALL = 'diffusion_ensemble_all'
|
||||
GPI_MODEL_NAME_SINGLE = 'stable_diffusion_1_5_cnet'
|
||||
|
||||
TABLE_CATEGORIES = {
|
||||
"female_dress": "female/dress",
|
||||
"female_outwear": "female/outwear",
|
||||
"female_trousers": "female/trousers",
|
||||
"female_skirt": "female/skirt",
|
||||
"female_blouse": "female/blouse",
|
||||
"male_tops": "male/tops",
|
||||
"male_bottoms": "male/bottoms",
|
||||
"male_outwear": "male/outwear"
|
||||
}
|
||||
# 以下停用中...*************
|
||||
# 多视角生成
|
||||
GMV_MODEL_URL = '10.1.1.243:10081'
|
||||
GMV_MODEL_NAME = 'multi_view'
|
||||
# 超分
|
||||
SR_MODEL_NAME = "super_resolution"
|
||||
SR_TRITON_URL = "10.1.1.240:10031"
|
||||
# 打光
|
||||
GRI_MODEL_URL = '10.1.1.240:10051'
|
||||
GRI_MODEL_NAME_OVERALL = 'diffusion_relight_ensemble'
|
||||
GRI_MODEL_NAME_SINGLE = 'stable_diffusion_1_5_relight'
|
||||
# agent 图片生成
|
||||
FAST_GI_MODEL_URL = '10.1.1.243:10011'
|
||||
FAST_GI_MODEL_NAME = 'stable_diffusion_xl'
|
||||
# 图转视频 triton版
|
||||
PT_MODEL_URL = '10.1.1.243:10061'
|
||||
|
||||
# --- ComfyUI 配置信息 ---
|
||||
COMFYUI_SERVER_ADDRESS = "10.1.2.227:8080" # 替换为您的 ComfyUI 服务器地址
|
||||
# *************
|
||||
|
||||
"""MQ 队列信息"""
|
||||
# 生成图片 moodboard printboard sketchboard
|
||||
GI_RABBITMQ_QUEUES = f"GenerateImage-{settings.SERVE_ENV}"
|
||||
# 生成slogan
|
||||
SLOGAN_RABBITMQ_QUEUES = f"Slogan-{settings.SERVE_ENV}"
|
||||
# 转产品图
|
||||
GPI_RABBITMQ_QUEUES = f"ToProductImage-{settings.SERVE_ENV}"
|
||||
# 产品图转视频
|
||||
PS_RABBITMQ_QUEUES = f"PoseTransform-{settings.SERVE_ENV}"
|
||||
|
||||
# 以下停用中...*************
|
||||
# 产品图打光
|
||||
GRI_RABBITMQ_QUEUES = f"Relight-{settings.SERVE_ENV}"
|
||||
# 超分
|
||||
SR_RABBITMQ_QUEUES = f"SuperResolution-{settings.SERVE_ENV}"
|
||||
# 生成多视图
|
||||
GMV_RABBITMQ_QUEUES = f"GenerateMultiView-{settings.SERVE_ENV}"
|
||||
# 批量转产品图
|
||||
BATCH_GPI_RABBITMQ_QUEUES = f"BatchToProductImage-{settings.SERVE_ENV}"
|
||||
# 批量打光
|
||||
BATCH_GRI_RABBITMQ_QUEUES = f"BatchRelight-{settings.SERVE_ENV}"
|
||||
# 批量图片转视频
|
||||
BATCH_PS_RABBITMQ_QUEUES = f"BatchPoseTransform-{settings.SERVE_ENV}"
|
||||
# 批量design
|
||||
BATCH_DESIGN_RABBITMQ_QUEUES = f"DesignBatch-{settings.SERVE_ENV}"
|
||||
# *************
|
||||
|
||||
10
app/core/mysql_config.py
Normal file
10
app/core/mysql_config.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from app.core.config import settings
|
||||
|
||||
DB_CONFIG = {
|
||||
"host": settings.MYSQL_HOST,
|
||||
"port": settings.MYSQL_PORT,
|
||||
"user": settings.MYSQL_USER,
|
||||
"password": settings.MYSQL_PASSWORD,
|
||||
"database": settings.MYSQL_DB,
|
||||
"charset": settings.MYSQL_CHARSET,
|
||||
}
|
||||
10
app/core/rabbit_mq_config.py
Normal file
10
app/core/rabbit_mq_config.py
Normal file
@@ -0,0 +1,10 @@
|
||||
# rabbitmq config
|
||||
import pika
|
||||
from app.core.config import settings
|
||||
|
||||
RABBITMQ_PARAMS = {
|
||||
"host": settings.MQ_HOST,
|
||||
"port": settings.MQ_PORT,
|
||||
"credentials": pika.credentials.PlainCredentials(username=settings.MQ_USERNAME, password=settings.MQ_PASSWORD),
|
||||
"virtual_host": settings.MQ_VIRTUAL_HOST,
|
||||
}
|
||||
@@ -79,12 +79,8 @@
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
],
|
||||
"process_id": "87",
|
||||
"tasks_id": ,
|
||||
"tasks_id": ""
|
||||
}
|
||||
|
||||
|
||||
//用 openai jsonl
|
||||
//
|
||||
34
app/main.py
34
app/main.py
@@ -1,10 +1,17 @@
|
||||
# 1. 这里的顺序至关重要!必须在最顶端
|
||||
import sys
|
||||
|
||||
try:
|
||||
import asyncore
|
||||
except ImportError:
|
||||
import pyasyncore
|
||||
|
||||
sys.modules['asyncore'] = pyasyncore
|
||||
import logging.config
|
||||
|
||||
import uvicorn
|
||||
from apscheduler.schedulers.background import BackgroundScheduler
|
||||
from apscheduler.triggers.cron import CronTrigger
|
||||
from fastapi import FastAPI
|
||||
from fastapi import HTTPException, Request
|
||||
from fastapi import HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from app.api.api_route import router
|
||||
@@ -12,19 +19,22 @@ from app.core.config import settings
|
||||
from app.core.record_api_count import count_api_calls
|
||||
from app.schemas.response_template import ResponseModel
|
||||
from logging_env import LOGGER_CONFIG_DICT
|
||||
from dotenv import load_dotenv
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
|
||||
logging.config.dictConfig(LOGGER_CONFIG_DICT)
|
||||
logging.getLogger("pika").setLevel(logging.WARNING)
|
||||
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
def get_application() -> FastAPI:
|
||||
application = FastAPI(
|
||||
title=settings.PROJECT_NAME, docs_url="/docs", redoc_url='/re-docs',
|
||||
openapi_url=f"{settings.API_PREFIX}/openapi.json",
|
||||
docs_url="/docs",
|
||||
redoc_url='/re-docs',
|
||||
openapi_url=f"/openapi.json",
|
||||
description='''
|
||||
Base frame with FastAPI
|
||||
- Super Resolution API
|
||||
@@ -33,13 +43,13 @@ def get_application() -> FastAPI:
|
||||
)
|
||||
application.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=[str(origin) for origin in settings.BACKEND_CORS_ORIGINS],
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
application.middleware("http")(count_api_calls)
|
||||
application.include_router(router=router, prefix=settings.API_PREFIX)
|
||||
application.include_router(router=router)
|
||||
return application
|
||||
|
||||
|
||||
@@ -47,14 +57,12 @@ app = get_application()
|
||||
|
||||
|
||||
@app.exception_handler(HTTPException)
|
||||
async def http_exception_handler(request: Request, exc: HTTPException):
|
||||
async def http_exception_handler(exc: HTTPException):
|
||||
return JSONResponse(
|
||||
status_code=exc.status_code,
|
||||
content=ResponseModel(code=exc.status_code, msg=exc.detail, data=exc.detail).dict()
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||
uvicorn.run(app, host="0.0.0.0", port=settings.PORT)
|
||||
|
||||
@@ -1,22 +1,24 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: UTF-8 -*-
|
||||
import logging
|
||||
from pprint import pprint
|
||||
import torch
|
||||
|
||||
import cv2
|
||||
import mmcv
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from minio import Minio
|
||||
import torch
|
||||
import tritonclient.http as httpclient
|
||||
from app.core.config import *
|
||||
from minio import Minio
|
||||
|
||||
from app.core.config import settings, DESIGN_MODEL_URL
|
||||
from app.schemas.attribute_retrieve import AttributeRecognitionModel
|
||||
from app.service.utils.oss_client import oss_get_image
|
||||
from app.service.utils.new_oss_client import oss_get_image
|
||||
|
||||
minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)
|
||||
|
||||
|
||||
class AttributeRecognition:
|
||||
def __init__(self, const, request_data):
|
||||
# self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||
self.request_data = []
|
||||
for i, sketch in enumerate(request_data):
|
||||
self.request_data.append(
|
||||
@@ -96,11 +98,12 @@ class AttributeRecognition:
|
||||
res = {**dict1, **dict2}
|
||||
return res
|
||||
|
||||
def get_image(self, url):
|
||||
@staticmethod
|
||||
def get_image(url):
|
||||
# response = self.minio_client.get_object(url.split("/", 1)[0], url.split("/", 1)[1])
|
||||
# img = np.frombuffer(response.data, np.uint8) # 转成8位无符号整型
|
||||
# img = cv2.imdecode(img, cv2.IMREAD_COLOR) #
|
||||
img = oss_get_image(bucket=url.split("/", 1)[0], object_name=url.split("/", 1)[1], data_type="cv2")
|
||||
img = oss_get_image(oss_client=minio_client, bucket=url.split("/", 1)[0], object_name=url.split("/", 1)[1], data_type="cv2")
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||
return img
|
||||
|
||||
|
||||
@@ -7,24 +7,25 @@
|
||||
@Date :2023/9/16 18:31:08
|
||||
@detail :
|
||||
"""
|
||||
from minio import Minio
|
||||
from skimage import transform
|
||||
import cv2
|
||||
import mmcv
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from minio import Minio
|
||||
import tritonclient.http as httpclient
|
||||
import torch
|
||||
|
||||
from app.core.config import *
|
||||
from app.core.config import settings, DESIGN_MODEL_URL
|
||||
from app.schemas.attribute_retrieve import CategoryRecognitionModel
|
||||
from app.service.utils.oss_client import oss_get_image
|
||||
from app.service.utils.new_oss_client import oss_get_image
|
||||
|
||||
minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)
|
||||
|
||||
|
||||
class CategoryRecognition:
|
||||
def __init__(self, request_data):
|
||||
self.attr_type = pd.read_csv(CATEGORY_PATH)
|
||||
# self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||
self.attr_type = pd.read_csv(settings.CATEGORY_PATH)
|
||||
self.request_data = []
|
||||
self.triton_client = httpclient.InferenceServerClient(url=DESIGN_MODEL_URL)
|
||||
for sketch in request_data:
|
||||
@@ -46,13 +47,14 @@ class CategoryRecognition:
|
||||
preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
|
||||
return preprocessed_img
|
||||
|
||||
def get_image(self, url):
|
||||
@staticmethod
|
||||
def get_image(url):
|
||||
# Get data of an object.
|
||||
# Read data from response.
|
||||
# response = self.minio_client.get_object(url.split("/", 1)[0], url.split("/", 1)[1])
|
||||
# img = np.frombuffer(response.data, np.uint8) # 转成8位无符号整型
|
||||
# img = cv2.imdecode(img, cv2.IMREAD_COLOR) # 解码
|
||||
img = oss_get_image(bucket=url.split("/", 1)[0], object_name=url.split("/", 1)[1], data_type="cv2")
|
||||
img = oss_get_image(oss_client=minio_client, bucket=url.split("/", 1)[0], object_name=url.split("/", 1)[1], data_type="cv2")
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||
return img
|
||||
|
||||
@@ -68,7 +70,7 @@ class CategoryRecognition:
|
||||
|
||||
colattr = list(self.attr_type['labelName'])
|
||||
|
||||
task = self.attr_type['taskName'][0]
|
||||
# self.attr_type['taskName'][0]
|
||||
|
||||
maxsc = np.max(scores[0][:5])
|
||||
indexs = np.argwhere(scores == maxsc)[:, 1]
|
||||
|
||||
@@ -9,15 +9,16 @@ import torch.nn.functional as F
|
||||
import tritonclient.http as httpclient
|
||||
from minio import Minio
|
||||
|
||||
from app.core.config import MINIO_URL, MINIO_ACCESS, MINIO_SECRET, MINIO_SECURE, DESIGN_MODEL_URL, CATEGORY_PATH
|
||||
from app.core.config import DESIGN_MODEL_URL
|
||||
from app.core.config import settings
|
||||
from app.schemas.brand_dna import BrandDnaModel
|
||||
from app.service.attribute.config import local_debug_const, const
|
||||
from app.service.attribute.config import const
|
||||
from app.service.utils.generate_uuid import generate_uuid
|
||||
from app.service.utils.new_oss_client import oss_upload_image, oss_get_image
|
||||
|
||||
logger = logging.getLogger()
|
||||
minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)
|
||||
|
||||
minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
class BrandDna:
|
||||
@@ -25,7 +26,7 @@ class BrandDna:
|
||||
self.sketch_bucket = "test"
|
||||
self.image_url = request_item.image_url
|
||||
self.is_brand_dna = request_item.is_brand_dna
|
||||
self.attr_type = pd.read_csv(CATEGORY_PATH)
|
||||
self.attr_type = pd.read_csv(settings.CATEGORY_PATH)
|
||||
# self.attr_type = pd.read_csv(r"E:\workspace\trinity_client_aida\app\service\attribute\config\descriptor\category\category_dis.csv")
|
||||
self.att_client = httpclient.InferenceServerClient(url=DESIGN_MODEL_URL)
|
||||
self.seg_client = httpclient.InferenceServerClient(url='10.1.1.243:30000')
|
||||
|
||||
@@ -3,23 +3,25 @@ import logging
|
||||
import cv2
|
||||
import numpy as np
|
||||
import tritonclient.grpc as grpcclient
|
||||
from langchain.output_parsers import ResponseSchema, StructuredOutputParser
|
||||
from langchain_classic.output_parsers import ResponseSchema, StructuredOutputParser
|
||||
from langchain_community.chat_models import ChatTongyi
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
# from langchain_openai import ChatOpenAI
|
||||
from minio import Minio
|
||||
from tritonclient.utils import np_to_triton_dtype
|
||||
|
||||
from app.core.config import GI_MODEL_URL, MINIO_URL, MINIO_ACCESS, MINIO_SECRET, MINIO_SECURE, GI_MODEL_NAME
|
||||
from app.core.config import GI_MODEL_URL, GI_MODEL_NAME
|
||||
from app.schemas.brand_dna import GenerateBrandModel
|
||||
from app.service.utils.generate_uuid import generate_uuid
|
||||
from app.service.utils.new_oss_client import oss_upload_image
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
|
||||
class GenerateBrandInfo:
|
||||
def __init__(self, request_data):
|
||||
# minio client init
|
||||
self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||
self.generate_logo_prompt = None
|
||||
self.minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)
|
||||
|
||||
# user info init
|
||||
self.user_id = request_data.user_id
|
||||
@@ -55,7 +57,7 @@ class GenerateBrandInfo:
|
||||
return self.result_data
|
||||
|
||||
def llm_generate_brand_info(self):
|
||||
output = self.model(self._input.to_messages())
|
||||
output = self.model.invoke(self._input.to_messages())
|
||||
brand_data = self.output_parser.parse(output.content)
|
||||
self.result_data = brand_data
|
||||
self.generate_logo_prompt = brand_data['brand_logo_prompt']
|
||||
@@ -87,8 +89,8 @@ class GenerateBrandInfo:
|
||||
def upload_logo_image(self, image, object_name):
|
||||
try:
|
||||
_, img_byte_array = cv2.imencode('.jpg', image)
|
||||
object_name = f'{self.user_id}/{self.category}/{object_name}'
|
||||
req = oss_upload_image(oss_client=self.minio_client, bucket="aida-users", object_name=object_name, image_bytes=img_byte_array)
|
||||
object_name = f'{self.user_id}/{self.category}/{object_name}.jpg'
|
||||
oss_upload_image(oss_client=self.minio_client, bucket="aida-users", object_name=object_name, image_bytes=img_byte_array)
|
||||
image_url = f"aida-users/{object_name}"
|
||||
return image_url
|
||||
except Exception as e:
|
||||
|
||||
@@ -1,32 +0,0 @@
|
||||
from dotenv import load_dotenv
|
||||
from langchain.output_parsers import StructuredOutputParser, ResponseSchema
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
# 加载.env文件的环境变量
|
||||
load_dotenv()
|
||||
|
||||
# 创建一个大语言模型,model指定了大语言模型的种类
|
||||
model = ChatOpenAI(model="qwen2.5-14b-instruct")
|
||||
|
||||
# 想要接收的响应模式
|
||||
response_schemas = [
|
||||
ResponseSchema(name="brand_name", description="Brand name."),
|
||||
ResponseSchema(name="brand_slogan", description="Brand slogan."),
|
||||
ResponseSchema(name="brand_logo_prompt", description="prompt required for brand logo generation.")
|
||||
]
|
||||
output_parser = StructuredOutputParser.from_response_schemas(response_schemas)
|
||||
format_instructions = output_parser.get_format_instructions()
|
||||
prompt = PromptTemplate(
|
||||
template="你是一个时装品牌的设计师。根据用户输入提取出brand name,brand slogan,brand logo 描述。如果没有以上内容,需要你根据用户输入随意发挥。随后根据brand logo 描述生成一个prompt,这个prompt用于生成模型.\n{format_instructions}\n{question}",
|
||||
input_variables=["question"],
|
||||
partial_variables={"format_instructions": format_instructions}
|
||||
)
|
||||
_input = prompt.format_prompt(question="brand name: cat home")
|
||||
|
||||
output = model(_input.to_messages())
|
||||
brand_data = output_parser.parse(output.content)
|
||||
|
||||
|
||||
def generate_logo(bucket_name, object_name, prompt):
|
||||
pass
|
||||
@@ -3,27 +3,20 @@ import json
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional, Union, Tuple
|
||||
|
||||
from langchain.agents import AgentExecutor
|
||||
from langchain.callbacks.manager import Callbacks, CallbackManager
|
||||
from langchain.load.dump import dumpd
|
||||
from langchain.schema import RUN_KEY, RunInfo
|
||||
from langchain_classic.agents import AgentExecutor
|
||||
from langchain_classic.schema import RUN_KEY
|
||||
from langchain_core.agents import AgentAction, AgentFinish
|
||||
from langchain_core.callbacks import Callbacks, CallbackManager
|
||||
from langchain_core.load import dumpd
|
||||
from langchain_core.outputs import RunInfo
|
||||
|
||||
|
||||
class CustomAgentExecutor(AgentExecutor):
|
||||
def __call__(
|
||||
self,
|
||||
inputs: Union[Dict[str, Any], Any],
|
||||
return_only_outputs: bool = False,
|
||||
callbacks: Callbacks = None,
|
||||
session_key: str = "",
|
||||
*,
|
||||
tags: Optional[List[str]] = None,
|
||||
include_run_info: bool = False,
|
||||
) -> Dict[str, Any]:
|
||||
def __call__(self, inputs: Union[Dict[str, Any], Any], return_only_outputs: bool = False, callbacks: Callbacks = None, session_key: str = "", *, tags: Optional[List[str]] = None, include_run_info: bool = False, **kwargs) -> Dict[str, Any]:
|
||||
"""Run the logic of this chain and add to output if desired.
|
||||
|
||||
Args:
|
||||
**kwargs:
|
||||
inputs: Dictionary of inputs, or single input if chain expects
|
||||
only one param.
|
||||
return_only_outputs: boolean for whether to return only outputs in the
|
||||
@@ -72,7 +65,7 @@ class CustomAgentExecutor(AgentExecutor):
|
||||
"""Validate and prep outputs."""
|
||||
self._validate_outputs(outputs)
|
||||
if self.memory is not None and outputs['need_record']:
|
||||
self.memory.save_context(inputs, outputs, session_key)
|
||||
self.memory.save_context(inputs, outputs)
|
||||
if return_only_outputs:
|
||||
return outputs
|
||||
else:
|
||||
@@ -95,7 +88,7 @@ class CustomAgentExecutor(AgentExecutor):
|
||||
)
|
||||
inputs = {list(_input_keys)[0]: inputs}
|
||||
if self.memory is not None:
|
||||
external_context = self.memory.load_memory_variables(inputs, session_key)
|
||||
external_context = self.memory.load_memory_variables(inputs)
|
||||
inputs = dict(inputs, **external_context)
|
||||
self._validate_inputs(inputs)
|
||||
return inputs
|
||||
@@ -119,7 +112,8 @@ class CustomAgentExecutor(AgentExecutor):
|
||||
{return_value_key: observation},
|
||||
"",
|
||||
)
|
||||
except:
|
||||
except Exception as e:
|
||||
print(e)
|
||||
pass
|
||||
|
||||
# Invalid tools won't be in the map, so we return False.
|
||||
|
||||
@@ -1,26 +1,15 @@
|
||||
import json
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from json import JSONDecodeError
|
||||
from typing import List, Tuple, Any, Union
|
||||
from dataclasses import dataclass
|
||||
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.agents import (
|
||||
OpenAIFunctionsAgent,
|
||||
)
|
||||
from langchain.schema import (
|
||||
AgentAction,
|
||||
AgentFinish,
|
||||
BaseMessage,
|
||||
OutputParserException
|
||||
)
|
||||
from langchain.schema.messages import (
|
||||
AIMessage,
|
||||
FunctionMessage
|
||||
)
|
||||
from langchain.tools import BaseTool, StructuredTool
|
||||
# from langchain.tools.convert_to_openai import FunctionDescription
|
||||
from langchain.utils.openai_functions import FunctionDescription
|
||||
from langchain_classic.agents import OpenAIFunctionsAgent
|
||||
from langchain_community.utils.ernie_functions import FunctionDescription
|
||||
from langchain_core.agents import AgentAction, AgentFinish
|
||||
from langchain_core.callbacks import Callbacks
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
from langchain_core.messages import BaseMessage, AIMessage, FunctionMessage
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -76,7 +65,6 @@ def _create_function_message(
|
||||
content = observation
|
||||
return FunctionMessage(
|
||||
name=agent_action.tool,
|
||||
content=content,
|
||||
)
|
||||
|
||||
|
||||
@@ -177,6 +165,7 @@ class ConversationalFunctionsAgent(OpenAIFunctionsAgent):
|
||||
into it.
|
||||
|
||||
Args:
|
||||
callbacks:
|
||||
intermediate_steps: Steps the LLM has taken to date, along with observations
|
||||
**kwargs: User inputs.
|
||||
**kwargs: Including user's input string
|
||||
|
||||
@@ -2,18 +2,16 @@
|
||||
from typing import Any, Dict
|
||||
|
||||
from langchain_community.callbacks.openai_info import OpenAICallbackHandler
|
||||
from langchain.schema import LLMResult
|
||||
from langchain_community.callbacks.openai_info import standardize_model_name, MODEL_COST_PER_1K_TOKENS, \
|
||||
get_openai_token_cost_for_model
|
||||
|
||||
|
||||
# from langchain.callbacks.openai_info import standardize_model_name, MODEL_COST_PER_1K_TOKENS, get_openai_token_cost_for_model
|
||||
from langchain_core.outputs import LLMResult
|
||||
|
||||
|
||||
class OpenAITokenRecordCallbackHandler(OpenAICallbackHandler):
|
||||
need_record: bool = True
|
||||
response_type: str = "string"
|
||||
"""Callback Handler that tracks OpenAI info and write to redis after agent finish"""
|
||||
|
||||
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||
"""Collect token usage."""
|
||||
if response.llm_output is None:
|
||||
@@ -22,7 +20,7 @@ class OpenAITokenRecordCallbackHandler(OpenAICallbackHandler):
|
||||
if "token_usage" not in response.llm_output:
|
||||
return None
|
||||
if "function_call" in response.generations[0][0].message.additional_kwargs:
|
||||
if response.generations[0][0].message.additional_kwargs["function_call"]["name"] in ["sql_db_query", "sql_db_schema","tutorial_tool"]:
|
||||
if response.generations[0][0].message.additional_kwargs["function_call"]["name"] in ["sql_db_query", "sql_db_schema", "tutorial_tool"]:
|
||||
self.need_record = False
|
||||
if response.generations[0][0].message.additional_kwargs["function_call"]["name"] == "sql_db_query":
|
||||
self.response_type = "image"
|
||||
@@ -39,6 +37,7 @@ class OpenAITokenRecordCallbackHandler(OpenAICallbackHandler):
|
||||
self.total_tokens += token_usage.get("total_tokens", 0)
|
||||
self.prompt_tokens += prompt_tokens
|
||||
self.completion_tokens += completion_tokens
|
||||
return None
|
||||
|
||||
def on_chain_end(self, outputs: Dict, **kwargs: Any) -> None:
|
||||
"""Write token usage to redis."""
|
||||
|
||||
@@ -44,12 +44,17 @@ class CustomDatabase(SQLDatabase):
|
||||
final_str = "\n\n".join(tables)
|
||||
return final_str
|
||||
|
||||
def run(self, command: str, fetch: str = "all") -> str:
|
||||
def run(self, command: str, fetch: str = "all", **kwargs) -> str:
|
||||
"""Execute a SQL command and return a string representing the results.
|
||||
|
||||
If the statement returns rows, a string of the results is returned.
|
||||
If the statement returns no rows, an empty string is returned.
|
||||
|
||||
Args:
|
||||
command:
|
||||
fetch:
|
||||
**kwargs:
|
||||
|
||||
"""
|
||||
with self._engine.begin() as connection:
|
||||
if self._schema is not None:
|
||||
|
||||
@@ -1,15 +1,15 @@
|
||||
import json
|
||||
import logging
|
||||
|
||||
from langchain.agents import Tool
|
||||
from langchain.callbacks import FileCallbackHandler
|
||||
from langchain.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate, MessagesPlaceholder
|
||||
from langchain.schema import SystemMessage, AIMessage
|
||||
from langchain.utilities import SerpAPIWrapper
|
||||
from langchain_community.utilities import SerpAPIWrapper
|
||||
from langchain_core.callbacks import FileCallbackHandler
|
||||
from langchain_core.messages import SystemMessage, AIMessage
|
||||
from langchain_core.prompts import MessagesPlaceholder, HumanMessagePromptTemplate, ChatPromptTemplate
|
||||
from langchain_core.tools import Tool
|
||||
from langchain_community.chat_models import ChatTongyi
|
||||
from loguru import logger
|
||||
|
||||
from app.core.config import *
|
||||
from app.core.config import settings
|
||||
from app.service.chat_robot.script.agents import CustomAgentExecutor, ConversationalFunctionsAgent
|
||||
from app.service.chat_robot.script.database import CustomDatabase
|
||||
from app.service.chat_robot.script.memory import UserConversationBufferWindowMemory
|
||||
@@ -30,10 +30,10 @@ log_handler = FileCallbackHandler(logfile)
|
||||
# # callbacks=[OpenAICallbackHandler()]
|
||||
# )
|
||||
|
||||
llm = ChatTongyi(api_key=QWEN_API_KEY)
|
||||
llm = ChatTongyi(api_key=settings.QWEN_API_KEY)
|
||||
|
||||
search = SerpAPIWrapper()
|
||||
db = CustomDatabase.from_uri(f'mysql+pymysql://{DB_USERNAME}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/attribute_retrieval_V3',
|
||||
db = CustomDatabase.from_uri(f'mysql+pymysql://{settings.DB_USERNAME}:{settings.DB_PASSWORD}@{settings.DB_HOST}:{settings.DB_PORT}/attribute_retrieval_V3',
|
||||
include_tables=['female_top', 'female_skirt', 'female_pants', 'female_dress',
|
||||
'female_outwear', 'male_bottom', 'male_top', 'male_outwear'],
|
||||
engine_args={"pool_recycle": 7200})
|
||||
@@ -43,11 +43,11 @@ tools = [
|
||||
description="Can be used to perform Internet searches",
|
||||
func=search.run
|
||||
),
|
||||
QuerySQLDataBaseTool(db=db, return_direct=False),
|
||||
QuerySQLDataBaseTool(db=db),
|
||||
InfoSQLDatabaseTool(db=db),
|
||||
ListSQLDatabaseTool(db=db),
|
||||
# QuerySQLCheckerTool(db=db, llm=OpenAI(temperature=0, openai_api_key=OPENAI_API_KEY)),
|
||||
QuerySQLCheckerTool(db=db, llm=ChatTongyi(temperature=0, api_key=QWEN_API_KEY)),
|
||||
QuerySQLCheckerTool(db=db, llm=ChatTongyi(api_key=settings.QWEN_API_KEY)),
|
||||
# Tool(
|
||||
# name="tutorial_tool",
|
||||
# description="Utilize this tool to retrieve specific statements related to user guidance tutorials."
|
||||
@@ -133,5 +133,5 @@ def chat(post_data):
|
||||
'completion_tokens': final_outputs['completion_tokens'],
|
||||
'response_type': final_outputs["response_type"]
|
||||
}
|
||||
logging.info(json.dumps(api_response))
|
||||
logging.info(json.dumps(api_response, indent=4))
|
||||
return api_response
|
||||
|
||||
@@ -3,13 +3,12 @@ from typing import Any, Dict, List, Tuple
|
||||
import json
|
||||
|
||||
import redis
|
||||
from langchain_classic.memory.chat_memory import BaseChatMemory
|
||||
from langchain_classic.memory.utils import get_prompt_input_key
|
||||
from langchain_core.messages import messages_from_dict, get_buffer_string, BaseMessage, HumanMessage, AIMessage, message_to_dict
|
||||
from redis import Redis
|
||||
from langchain.memory.chat_memory import BaseChatMemory
|
||||
from langchain.schema.messages import BaseMessage, get_buffer_string, HumanMessage, AIMessage
|
||||
from langchain.schema.messages import _message_to_dict, messages_from_dict
|
||||
from langchain.memory.utils import get_prompt_input_key
|
||||
|
||||
from app.core.config import *
|
||||
from app.core.config import settings
|
||||
|
||||
|
||||
class UserConversationBufferWindowMemory(BaseChatMemory):
|
||||
@@ -24,8 +23,8 @@ class UserConversationBufferWindowMemory(BaseChatMemory):
|
||||
@classmethod
|
||||
def from_redis(
|
||||
cls,
|
||||
host: str = REDIS_HOST,
|
||||
port: int = REDIS_PORT,
|
||||
host: str = settings.REDIS_HOST,
|
||||
port: int = settings.REDIS_PORT,
|
||||
db: int = 3,
|
||||
**kwargs
|
||||
):
|
||||
@@ -79,7 +78,7 @@ class UserConversationBufferWindowMemory(BaseChatMemory):
|
||||
return inputs[prompt_input_key], outputs[output_key]
|
||||
|
||||
def add_message(self, key: str, message: BaseMessage) -> None:
|
||||
self.redis_client.lpush(key, json.dumps(_message_to_dict(message)))
|
||||
self.redis_client.lpush(key, json.dumps(message_to_dict(message)))
|
||||
|
||||
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str], key: str = "") -> None:
|
||||
"""Save context from this conversation to buffer."""
|
||||
|
||||
@@ -5,10 +5,10 @@ from dashscope import Generation
|
||||
from retry import retry
|
||||
from urllib3.exceptions import NewConnectionError
|
||||
|
||||
from app.core.config import *
|
||||
from app.core.config import settings
|
||||
from app.service.chat_robot.script.callbacks.qwen_callback_handler import QWenCallbackHandler
|
||||
from app.service.chat_robot.script.database import CustomDatabase
|
||||
from app.service.chat_robot.script.prompt import FASHION_CHAT_BOT_PREFIX, TOOLS_FUNCTIONS_SUFFIX, TUTORIAL_TOOL_RETURN, \
|
||||
from app.service.chat_robot.script.prompt import TOOLS_FUNCTIONS_SUFFIX, TUTORIAL_TOOL_RETURN, \
|
||||
GET_LANGUAGE_PREFIX, FASHION_CHAT_BOT_PREFIX_TEMP
|
||||
from app.service.search_image_with_text.service import query
|
||||
|
||||
@@ -149,7 +149,7 @@ tools = [
|
||||
}
|
||||
]
|
||||
|
||||
db = CustomDatabase.from_uri(f'mysql+pymysql://{DB_USERNAME}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/attribute_retrieval_V3',
|
||||
db = CustomDatabase.from_uri(f'mysql+pymysql://{settings.MYSQL_USER}:{settings.MYSQL_PASSWORD}@{settings.MYSQL_HOST}:{settings.MYSQL_PORT}/attribute_retrieval_V3',
|
||||
include_tables=['female_top', 'female_skirt', 'female_pants', 'female_dress',
|
||||
'female_outwear', 'male_bottom', 'male_top', 'male_outwear'],
|
||||
engine_args={"pool_recycle": 7200})
|
||||
@@ -159,7 +159,7 @@ qwen = QWenCallbackHandler()
|
||||
def search_from_internet(message):
|
||||
response = Generation.call(
|
||||
model='qwen-turbo',
|
||||
api_key=QWEN_API_KEY,
|
||||
api_key=settings.QWEN_API_KEY,
|
||||
messages=message,
|
||||
prompt='The output must be in English.Keep the final result under 200 words.'
|
||||
# tools=tools,
|
||||
@@ -190,7 +190,7 @@ def get_image_from_vector_db(gender, content):
|
||||
def get_response(messages):
|
||||
response = Generation.call(
|
||||
model='qwen-max',
|
||||
api_key=QWEN_API_KEY,
|
||||
api_key=settings.QWEN_API_KEY,
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
# seed=random.randint(1, 10000), # 设置随机数种子seed,如果没有设置,则随机数种子默认为1234
|
||||
@@ -203,7 +203,7 @@ def get_response(messages):
|
||||
def get_assistant_response(messages):
|
||||
response = Generation.call(
|
||||
model='qwen-max',
|
||||
api_key=QWEN_API_KEY,
|
||||
api_key=settings.QWEN_API_KEY,
|
||||
messages=messages,
|
||||
# seed=random.randint(1, 10000), # 设置随机数种子seed,如果没有设置,则随机数种子默认为1234
|
||||
result_format='message', # 将输出设置为message形式
|
||||
@@ -212,8 +212,10 @@ def get_assistant_response(messages):
|
||||
return response
|
||||
|
||||
|
||||
global tool_info
|
||||
|
||||
|
||||
def call_with_messages(message):
|
||||
global tool_info
|
||||
user_input = message
|
||||
print('\n')
|
||||
|
||||
@@ -241,7 +243,7 @@ def call_with_messages(message):
|
||||
response_type = "chat"
|
||||
|
||||
while flag and count <= 3:
|
||||
first_response = get_response(messages)
|
||||
first_response = get_response
|
||||
assistant_output = first_response.output.choices[0].message
|
||||
QWenCallbackHandler.on_llm_end(qwen, first_response.usage)
|
||||
print(f"\n大模型第 {count} 轮输出信息:{first_response}\n")
|
||||
@@ -260,7 +262,7 @@ def call_with_messages(message):
|
||||
]
|
||||
tool_info['content'] = search_from_internet(message)
|
||||
flag = False
|
||||
result_content = tool_info['content'].output.text
|
||||
result_content = tool_info['content']
|
||||
# 如果模型选择的工具是get_database_table
|
||||
# elif assistant_output.tool_calls[0]['function']['name'] == 'get_database_table':
|
||||
# tool_info = {"name": "get_database_table", "role": "tool", 'content': get_database_table()}
|
||||
|
||||
@@ -2,21 +2,15 @@
|
||||
"""Tools for interacting with a SQL database."""
|
||||
from typing import Any, Dict, Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Extra, Field, root_validator
|
||||
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForToolRun,
|
||||
CallbackManagerForToolRun,
|
||||
)
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.prompts import PromptTemplate
|
||||
from langchain_community.tools.sql_database.prompt import QUERY_CHECKER
|
||||
from langchain_community.tools.sql_database.tool import _QuerySQLCheckerToolInput
|
||||
# from langchain.sql_database import SQLDatabase
|
||||
from langchain_community.utilities import SQLDatabase
|
||||
from langchain.tools.base import BaseTool
|
||||
from langchain_community.tools.sql_database.prompt import QUERY_CHECKER
|
||||
|
||||
from langchain_community.tools.sql_database.tool import QuerySQLCheckerTool, _QuerySQLCheckerToolInput
|
||||
from langchain_core.callbacks import CallbackManagerForToolRun, AsyncCallbackManagerForToolRun
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
from langchain_core.tools import BaseTool
|
||||
from pydantic import BaseModel, Extra, Field, root_validator
|
||||
|
||||
|
||||
class BaseSQLDatabaseTool(BaseModel):
|
||||
@@ -62,7 +56,7 @@ class QuerySQLDataBaseTool(BaseSQLDatabaseTool, BaseTool):
|
||||
"LIMIT 1'"
|
||||
"Example Input 2: 'SELECT img_name FROM top WHERE sleeve_length = 'Long' AND type = 'Blouse' "
|
||||
"order by rand() LIMIT 2'"
|
||||
)
|
||||
)
|
||||
|
||||
def _run(
|
||||
self,
|
||||
@@ -95,9 +89,9 @@ class InfoSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool):
|
||||
"Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables."
|
||||
"There are eight tables covering eight fashion categories: female_top, female_pants, female_dress,"
|
||||
"female_skirt, female_outwear, male_bottom, male_top, and male_outwear."
|
||||
|
||||
|
||||
"Example Input: 'female_outwear, male_top'"
|
||||
)
|
||||
)
|
||||
|
||||
def _run(
|
||||
self,
|
||||
@@ -183,11 +177,11 @@ class QuerySQLCheckerTool(BaseSQLDatabaseTool, BaseTool):
|
||||
args_schema: Type[BaseModel] = _QuerySQLCheckerToolInput
|
||||
|
||||
@root_validator(pre=True)
|
||||
def initialize_llm_chain(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
def initialize_llm_chain(self, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
if "llm_chain" not in values:
|
||||
# from langchain.chains.llm import LLMChain
|
||||
|
||||
llm = values.get("llm") # type: ignore[arg-type]
|
||||
llm = values.get("llm") # type: ignore[arg-type]
|
||||
prompt = PromptTemplate(
|
||||
template=QUERY_CHECKER, input_variables=["dialect", "query"]
|
||||
)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import Any
|
||||
|
||||
from langchain.tools.base import BaseTool
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
from app.service.chat_robot.script.prompt import TUTORIAL_TOOL_RETURN
|
||||
|
||||
|
||||
@@ -9,14 +9,14 @@ from PIL import Image
|
||||
from minio import Minio
|
||||
from tritonclient.utils import np_to_triton_dtype
|
||||
|
||||
from app.core.config import *
|
||||
from app.core.config import settings
|
||||
from app.schemas.clothing_seg import ClothingSegModel
|
||||
from app.service.design_fast.utils.design_ensemble import get_seg_result
|
||||
from app.service.utils.decorator import RunTime
|
||||
from app.service.utils.generate_uuid import generate_uuid
|
||||
from app.service.utils.new_oss_client import oss_get_image, oss_upload_image
|
||||
|
||||
minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||
minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)
|
||||
|
||||
|
||||
class ClothingSeg:
|
||||
@@ -64,9 +64,9 @@ class ClothingSeg:
|
||||
if image_type == "sketch":
|
||||
if len(image.shape) == 2:
|
||||
image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
|
||||
seg_mask = get_seg_result(1, image[:, :, :3])
|
||||
seg_mask = get_seg_result(image[:, :, :3])
|
||||
else:
|
||||
seg_mask = get_seg_result(1, image[:, :, :3])
|
||||
seg_mask = get_seg_result(image[:, :, :3])
|
||||
temp = seg_mask != 0.0
|
||||
mask = (255 * (temp + 0).astype(np.uint8))
|
||||
x_min, y_min, x_max, y_max = get_bounding_box(mask)
|
||||
|
||||
@@ -12,7 +12,8 @@ from PIL import Image
|
||||
from minio import Minio, S3Error
|
||||
from moviepy.video.io.VideoFileClip import VideoFileClip
|
||||
|
||||
from app.core.config import MINIO_URL, MINIO_ACCESS, MINIO_SECRET, MINIO_SECURE, COMFYUI_SERVER_ADDRESS, PS_RABBITMQ_QUEUES, DEBUG
|
||||
from app.core.config import PS_RABBITMQ_QUEUES
|
||||
from app.core.config import settings
|
||||
from app.schemas.comfyui_i2v import ComfyuiFLF2VModel
|
||||
from app.service.generate_image.utils.mq import publish_status
|
||||
|
||||
@@ -305,13 +306,14 @@ workflow_json = {
|
||||
|
||||
class ComfyUIServerFLF2V:
|
||||
def __init__(self, request_data):
|
||||
self.pose_transform_data = None
|
||||
self.start_image_url = request_data.start_image_url
|
||||
self.end_image_url = request_data.end_image_url
|
||||
self.prompt = request_data.prompt
|
||||
self.tasks_id = request_data.tasks_id
|
||||
self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
|
||||
self.server_status_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'gif_url': '', 'video_url': '', 'image_url': ''}
|
||||
self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||
self.minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)
|
||||
|
||||
def get_result(self):
|
||||
workflow_json['6']['inputs']['text'] = self.prompt
|
||||
@@ -341,7 +343,7 @@ class ComfyUIServerFLF2V:
|
||||
# 1. 提交任务
|
||||
prompt_response = self.queue_prompt(workflow_json, self.tasks_id)
|
||||
if not prompt_response:
|
||||
return
|
||||
return None
|
||||
|
||||
prompt_id = prompt_response.get("prompt_id")
|
||||
logger.info(f" 任务已提交,Prompt ID: {prompt_id}")
|
||||
@@ -361,6 +363,7 @@ class ComfyUIServerFLF2V:
|
||||
}
|
||||
logger.info(file_list)
|
||||
return self.process_and_upload_comfyui_video(filename=file_list['filename'], subfolder=file_list['subfolder'], prompt_id=prompt_response['prompt_id']), prompt_id
|
||||
return None
|
||||
|
||||
def download_from_minio_in_memory(self, image_url):
|
||||
bucket = image_url.split('/')[0]
|
||||
@@ -391,8 +394,9 @@ class ComfyUIServerFLF2V:
|
||||
logger.error(f"❌ MinIO 下载过程中发生未知错误: {e}")
|
||||
return None, None
|
||||
|
||||
def upload_in_memory_file_to_comfyui(self, in_memory_file, filename):
|
||||
upload_url = f"http://{COMFYUI_SERVER_ADDRESS}/upload/image"
|
||||
@staticmethod
|
||||
def upload_in_memory_file_to_comfyui(in_memory_file, filename):
|
||||
upload_url = f"http://{settings.COMFYUI_SERVER_ADDRESS}/upload/image"
|
||||
|
||||
data = {
|
||||
"overwrite": "true",
|
||||
@@ -430,7 +434,7 @@ class ComfyUIServerFLF2V:
|
||||
# 1. 从 ComfyUI 获取视频二进制数据
|
||||
mp4_bytes = self.get_comfyui_video_bytes(filename, subfolder)
|
||||
if not mp4_bytes:
|
||||
return
|
||||
return None
|
||||
|
||||
# 2. 准备进行视频处理
|
||||
# moviepy 不支持直接使用 bytes,需要将 bytes 写入一个 BytesIO 或临时文件
|
||||
@@ -518,7 +522,7 @@ class ComfyUIServerFLF2V:
|
||||
self.pose_transform_data = {'tasks_id': self.tasks_id, 'status': 'SUCCESS', 'message': "success", 'gif_url': f'aida-users/{GIF_OBJECT}', 'video_url': f'aida-users/{MP4_OBJECT}', 'image_url': f'aida-users/{FRAME_OBJECT}'}
|
||||
|
||||
# 推送消息
|
||||
if not DEBUG:
|
||||
if not settings.DEBUG:
|
||||
publish_status(json.dumps(self.pose_transform_data), PS_RABBITMQ_QUEUES)
|
||||
logger.info(
|
||||
f" [x] Sent to: {PS_RABBITMQ_QUEUES} data:@@@@ {json.dumps(self.pose_transform_data, indent=4)}")
|
||||
@@ -530,13 +534,14 @@ class ComfyUIServerFLF2V:
|
||||
return None
|
||||
|
||||
# --- 辅助函数:提交任务到队列 ---
|
||||
def queue_prompt(self, prompt, client_id):
|
||||
@staticmethod
|
||||
def queue_prompt(prompt, client_id):
|
||||
"""向 ComfyUI 提交工作流提示。"""
|
||||
p = {"prompt": prompt, "client_id": client_id, "prompt_id": client_id}
|
||||
data = json.dumps(p).encode('utf-8')
|
||||
|
||||
# 提交任务到 /prompt 端点
|
||||
response = requests.post(f"http://{COMFYUI_SERVER_ADDRESS}/prompt", data=data)
|
||||
response = requests.post(f"http://{settings.COMFYUI_SERVER_ADDRESS}/prompt", data=data)
|
||||
# print(f"-------------{response.text}")
|
||||
# print(f"------------{client_id}")
|
||||
|
||||
@@ -547,9 +552,10 @@ class ComfyUIServerFLF2V:
|
||||
logger.warning(response.text)
|
||||
return None
|
||||
|
||||
def poll_history(self, prompt_id, interval_seconds=5):
|
||||
@staticmethod
|
||||
def poll_history(prompt_id, interval_seconds=5):
|
||||
"""步骤 2: 轮询 /history/{prompt_id} 检查任务是否完成"""
|
||||
url = f"http://{COMFYUI_SERVER_ADDRESS}/history/{prompt_id}"
|
||||
url = f"http://{settings.COMFYUI_SERVER_ADDRESS}/history/{prompt_id}"
|
||||
|
||||
logger.info(f"⏳ 开始轮询状态 (间隔 {interval_seconds} 秒)...")
|
||||
|
||||
@@ -574,7 +580,8 @@ class ComfyUIServerFLF2V:
|
||||
logger.info(f"⚠️ 轮询时发生错误: {e}")
|
||||
pass
|
||||
|
||||
def get_comfyui_video_bytes(self, filename: str, subfolder: str, file_type: str = "output"):
|
||||
@staticmethod
|
||||
def get_comfyui_video_bytes(filename: str, subfolder: str, file_type: str = "output"):
|
||||
"""
|
||||
从 ComfyUI 的 /view 端点获取视频文件的二进制数据。
|
||||
|
||||
@@ -586,7 +593,7 @@ class ComfyUIServerFLF2V:
|
||||
返回:
|
||||
- 视频文件的二进制内容 (bytes) 或 None。
|
||||
"""
|
||||
url = f"http://{COMFYUI_SERVER_ADDRESS}/view"
|
||||
url = f"http://{settings.COMFYUI_SERVER_ADDRESS}/view"
|
||||
params = {
|
||||
"filename": filename,
|
||||
"subfolder": subfolder,
|
||||
|
||||
@@ -12,8 +12,8 @@ from PIL import Image
|
||||
from minio import Minio, S3Error
|
||||
from moviepy.video.io.VideoFileClip import VideoFileClip
|
||||
|
||||
from app.core.config import MINIO_URL, MINIO_ACCESS, MINIO_SECRET, MINIO_SECURE, COMFYUI_SERVER_ADDRESS, PS_RABBITMQ_QUEUES, DEBUG
|
||||
from app.schemas.comfyui_i2v import ComfyuiPose2VModel, ComfyuiI2VModel
|
||||
from app.core.config import PS_RABBITMQ_QUEUES, settings
|
||||
from app.schemas.comfyui_i2v import ComfyuiI2VModel
|
||||
from app.service.generate_image.utils.mq import publish_status
|
||||
|
||||
logger = logging.getLogger()
|
||||
@@ -293,13 +293,14 @@ workflow_json = {
|
||||
|
||||
class ComfyUIServerI2V:
|
||||
def __init__(self, request_data):
|
||||
self.pose_transform_data = None
|
||||
self.image_url = request_data.image_url
|
||||
self.prompt = request_data.prompt
|
||||
|
||||
self.tasks_id = request_data.tasks_id
|
||||
self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
|
||||
self.server_status_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'gif_url': '', 'video_url': '', 'image_url': ''}
|
||||
self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||
self.minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)
|
||||
|
||||
def get_result(self):
|
||||
workflow_json['93']['inputs']['text'] = self.prompt
|
||||
@@ -319,7 +320,7 @@ class ComfyUIServerI2V:
|
||||
# 1. 提交任务
|
||||
prompt_response = self.queue_prompt(workflow_json, self.tasks_id)
|
||||
if not prompt_response:
|
||||
return
|
||||
return None
|
||||
prompt_id = prompt_response.get("prompt_id")
|
||||
logger.info(f" 任务已提交,Prompt ID: {prompt_id}")
|
||||
outputs = self.poll_history(prompt_id)
|
||||
@@ -339,6 +340,7 @@ class ComfyUIServerI2V:
|
||||
}
|
||||
logger.info(file_list)
|
||||
return self.process_and_upload_comfyui_video(filename=file_list['filename'], subfolder=file_list['subfolder'], prompt_id=prompt_response['prompt_id']), prompt_id
|
||||
return None
|
||||
|
||||
def download_from_minio_in_memory(self, image_url):
|
||||
bucket = image_url.split('/')[0]
|
||||
@@ -369,8 +371,9 @@ class ComfyUIServerI2V:
|
||||
logger.error(f"❌ MinIO 下载过程中发生未知错误: {e}")
|
||||
return None, None
|
||||
|
||||
def upload_in_memory_file_to_comfyui(self, in_memory_file, filename):
|
||||
upload_url = f"http://{COMFYUI_SERVER_ADDRESS}/upload/image"
|
||||
@staticmethod
|
||||
def upload_in_memory_file_to_comfyui(in_memory_file, filename):
|
||||
upload_url = f"http://{settings.COMFYUI_SERVER_ADDRESS}/upload/image"
|
||||
|
||||
data = {
|
||||
"overwrite": "true",
|
||||
@@ -408,7 +411,7 @@ class ComfyUIServerI2V:
|
||||
# 1. 从 ComfyUI 获取视频二进制数据
|
||||
mp4_bytes = self.get_comfyui_video_bytes(filename, subfolder)
|
||||
if not mp4_bytes:
|
||||
return
|
||||
return None
|
||||
|
||||
# 2. 准备进行视频处理
|
||||
# moviepy 不支持直接使用 bytes,需要将 bytes 写入一个 BytesIO 或临时文件
|
||||
@@ -496,7 +499,7 @@ class ComfyUIServerI2V:
|
||||
self.pose_transform_data = {'tasks_id': self.tasks_id, 'status': 'SUCCESS', 'message': "success", 'gif_url': f'aida-users/{GIF_OBJECT}', 'video_url': f'aida-users/{MP4_OBJECT}', 'image_url': f'aida-users/{FRAME_OBJECT}'}
|
||||
|
||||
# 推送消息
|
||||
if not DEBUG:
|
||||
if not settings.DEBUG:
|
||||
publish_status(json.dumps(self.pose_transform_data), PS_RABBITMQ_QUEUES)
|
||||
logger.info(
|
||||
f" [x] Sent to: {PS_RABBITMQ_QUEUES} data:@@@@ {json.dumps(self.pose_transform_data, indent=4)}")
|
||||
@@ -508,13 +511,14 @@ class ComfyUIServerI2V:
|
||||
return None
|
||||
|
||||
# --- 辅助函数:提交任务到队列 ---
|
||||
def queue_prompt(self, prompt, client_id):
|
||||
@staticmethod
|
||||
def queue_prompt(prompt, client_id):
|
||||
"""向 ComfyUI 提交工作流提示。"""
|
||||
p = {"prompt": prompt, "client_id": client_id, "prompt_id": client_id}
|
||||
data = json.dumps(p).encode('utf-8')
|
||||
|
||||
# 提交任务到 /prompt 端点
|
||||
response = requests.post(f"http://{COMFYUI_SERVER_ADDRESS}/prompt", data=data)
|
||||
response = requests.post(f"http://{settings.COMFYUI_SERVER_ADDRESS}/prompt", data=data)
|
||||
# print(f"-------------{response.text}")
|
||||
# print(f"------------{client_id}")
|
||||
|
||||
@@ -525,9 +529,10 @@ class ComfyUIServerI2V:
|
||||
logger.warning(response.text)
|
||||
return None
|
||||
|
||||
def poll_history(self, prompt_id, interval_seconds=5):
|
||||
@staticmethod
|
||||
def poll_history(prompt_id, interval_seconds=5):
|
||||
"""步骤 2: 轮询 /history/{prompt_id} 检查任务是否完成"""
|
||||
url = f"http://{COMFYUI_SERVER_ADDRESS}/history/{prompt_id}"
|
||||
url = f"http://{settings.COMFYUI_SERVER_ADDRESS}/history/{prompt_id}"
|
||||
|
||||
logger.info(f"⏳ 开始轮询状态 (间隔 {interval_seconds} 秒)...")
|
||||
|
||||
@@ -552,7 +557,8 @@ class ComfyUIServerI2V:
|
||||
logger.info(f"⚠️ 轮询时发生错误: {e}")
|
||||
pass
|
||||
|
||||
def get_comfyui_video_bytes(self, filename: str, subfolder: str, file_type: str = "output"):
|
||||
@staticmethod
|
||||
def get_comfyui_video_bytes(filename: str, subfolder: str, file_type: str = "output"):
|
||||
"""
|
||||
从 ComfyUI 的 /view 端点获取视频文件的二进制数据。
|
||||
|
||||
@@ -564,7 +570,7 @@ class ComfyUIServerI2V:
|
||||
返回:
|
||||
- 视频文件的二进制内容 (bytes) 或 None。
|
||||
"""
|
||||
url = f"http://{COMFYUI_SERVER_ADDRESS}/view"
|
||||
url = f"http://{settings.COMFYUI_SERVER_ADDRESS}/view"
|
||||
params = {
|
||||
"filename": filename,
|
||||
"subfolder": subfolder,
|
||||
|
||||
@@ -13,7 +13,7 @@ from PIL import Image
|
||||
from minio import Minio, S3Error
|
||||
from moviepy.video.io.VideoFileClip import VideoFileClip
|
||||
|
||||
from app.core.config import REDIS_HOST, REDIS_PORT, REDIS_DB, MINIO_URL, MINIO_ACCESS, MINIO_SECRET, MINIO_SECURE, COMFYUI_SERVER_ADDRESS, PS_RABBITMQ_QUEUES, DEBUG
|
||||
from app.core.config import settings
|
||||
from app.schemas.comfyui_i2v import ComfyuiPose2VModel
|
||||
from app.service.generate_image.utils.mq import publish_status
|
||||
|
||||
@@ -371,11 +371,11 @@ class ComfyUIServerPose2V:
|
||||
self.pose_num = request_data.pose_id
|
||||
self.tasks_id = request_data.tasks_id
|
||||
self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
|
||||
self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
|
||||
self.redis_client = redis.StrictRedis(host=settings.REDIS_HOST, port=settings.REDIS_PORT, db=settings.REDIS_DB, decode_responses=True)
|
||||
self.pose_transform_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'gif_url': '', 'video_url': '', 'image_url': ''}
|
||||
self.redis_client.set(self.tasks_id, json.dumps(self.pose_transform_data))
|
||||
self.redis_client.expire(self.tasks_id, 600)
|
||||
self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||
self.minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)
|
||||
|
||||
def get_result(self):
|
||||
workflow_json['174']['inputs']['file'] = video_map[self.pose_num]
|
||||
@@ -389,7 +389,7 @@ class ComfyUIServerPose2V:
|
||||
# 1. 提交任务
|
||||
prompt_response = self.queue_prompt(workflow_json, self.tasks_id)
|
||||
if not prompt_response:
|
||||
return
|
||||
return None
|
||||
|
||||
prompt_id = prompt_response.get("prompt_id")
|
||||
logger.info(f" 任务已提交,Prompt ID: {prompt_id}")
|
||||
@@ -411,6 +411,7 @@ class ComfyUIServerPose2V:
|
||||
}
|
||||
logger.info(file_list)
|
||||
return self.process_and_upload_comfyui_video(filename=file_list['filename'], subfolder=file_list['subfolder'], prompt_id=prompt_response['prompt_id']), prompt_id
|
||||
return None
|
||||
|
||||
def read_tasks_status(self):
|
||||
status_data = self.redis_client.get(self.tasks_id)
|
||||
@@ -492,8 +493,9 @@ class ComfyUIServerPose2V:
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 发生未知错误: {e}")
|
||||
|
||||
def upload_in_memory_file_to_comfyui(self, in_memory_file, filename):
|
||||
upload_url = f"http://{COMFYUI_SERVER_ADDRESS}/upload/image"
|
||||
@staticmethod
|
||||
def upload_in_memory_file_to_comfyui(in_memory_file, filename):
|
||||
upload_url = f"http://{settings.COMFYUI_SERVER_ADDRESS}/upload/image"
|
||||
|
||||
data = {
|
||||
"overwrite": "true",
|
||||
@@ -531,7 +533,7 @@ class ComfyUIServerPose2V:
|
||||
# 1. 从 ComfyUI 获取视频二进制数据
|
||||
mp4_bytes = self.get_comfyui_video_bytes(filename, subfolder)
|
||||
if not mp4_bytes:
|
||||
return
|
||||
return None
|
||||
|
||||
# 2. 准备进行视频处理
|
||||
# moviepy 不支持直接使用 bytes,需要将 bytes 写入一个 BytesIO 或临时文件
|
||||
@@ -619,10 +621,10 @@ class ComfyUIServerPose2V:
|
||||
self.pose_transform_data = {'tasks_id': self.tasks_id, 'status': 'SUCCESS', 'message': "success", 'gif_url': f'aida-users/{GIF_OBJECT}', 'video_url': f'aida-users/{MP4_OBJECT}', 'image_url': f'aida-users/{FRAME_OBJECT}'}
|
||||
|
||||
# 推送消息
|
||||
if not DEBUG:
|
||||
publish_status(json.dumps(self.pose_transform_data), PS_RABBITMQ_QUEUES)
|
||||
if not settings.DEBUG:
|
||||
publish_status(json.dumps(self.pose_transform_data), settings.COMFYUI_SERVER_ADDRESS)
|
||||
logger.info(
|
||||
f" [x] Sent to: {PS_RABBITMQ_QUEUES} data:@@@@ {json.dumps(self.pose_transform_data, indent=4)}")
|
||||
f" [x] Sent to: {settings.COMFYUI_SERVER_ADDRESS} data:@@@@ {json.dumps(self.pose_transform_data, indent=4)}")
|
||||
|
||||
return "\n🎉 所有任务完成!"
|
||||
|
||||
@@ -631,13 +633,15 @@ class ComfyUIServerPose2V:
|
||||
return None
|
||||
|
||||
# --- 辅助函数:提交任务到队列 ---
|
||||
def queue_prompt(self, prompt, client_id):
|
||||
@staticmethod
|
||||
def queue_prompt(prompt, client_id):
|
||||
"""向 ComfyUI 提交工作流提示。"""
|
||||
p = {"prompt": prompt, "client_id": client_id, "prompt_id": client_id}
|
||||
data = json.dumps(p).encode('utf-8')
|
||||
|
||||
# 提交任务到 /prompt 端点
|
||||
response = requests.post(f"http://{COMFYUI_SERVER_ADDRESS}/prompt", data=data)
|
||||
# noinspection HttpUrlsUsage
|
||||
response = requests.post(f"http://{settings.COMFYUI_SERVER_ADDRESS}/prompt", data=data)
|
||||
# print(f"-------------{response.text}")
|
||||
# print(f"------------{client_id}")
|
||||
|
||||
@@ -648,9 +652,10 @@ class ComfyUIServerPose2V:
|
||||
logger.warning(response.text)
|
||||
return None
|
||||
|
||||
def poll_history(self, prompt_id, interval_seconds=5):
|
||||
@staticmethod
|
||||
def poll_history(prompt_id, interval_seconds=5):
|
||||
"""步骤 2: 轮询 /history/{prompt_id} 检查任务是否完成"""
|
||||
url = f"http://{COMFYUI_SERVER_ADDRESS}/history/{prompt_id}"
|
||||
url = f"http://{settings.COMFYUI_SERVER_ADDRESS}/history/{prompt_id}"
|
||||
|
||||
logger.info(f"⏳ 开始轮询状态 (间隔 {interval_seconds} 秒)...")
|
||||
|
||||
@@ -675,7 +680,8 @@ class ComfyUIServerPose2V:
|
||||
logger.info(f"⚠️ 轮询时发生错误: {e}")
|
||||
pass
|
||||
|
||||
def get_comfyui_video_bytes(self, filename: str, subfolder: str, file_type: str = "output"):
|
||||
@staticmethod
|
||||
def get_comfyui_video_bytes(filename: str, subfolder: str, file_type: str = "output"):
|
||||
"""
|
||||
从 ComfyUI 的 /view 端点获取视频文件的二进制数据。
|
||||
|
||||
@@ -687,7 +693,7 @@ class ComfyUIServerPose2V:
|
||||
返回:
|
||||
- 视频文件的二进制内容 (bytes) 或 None。
|
||||
"""
|
||||
url = f"http://{COMFYUI_SERVER_ADDRESS}/view"
|
||||
url = f"http://{settings.COMFYUI_SERVER_ADDRESS}/view"
|
||||
params = {
|
||||
"filename": filename,
|
||||
"subfolder": subfolder,
|
||||
|
||||
@@ -1,116 +0,0 @@
|
||||
import logging
|
||||
|
||||
import numpy as np
|
||||
import cv2
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def show(img, win_name="temp"):
    """Open an OpenCV window named *win_name* showing *img* and block until a key is pressed."""
    cv2.imshow(win_name, img)
    cv2.waitKey(0)
|
||||
|
||||
|
||||
def crop(img):
    """Return the central 1040x680 region of *img*, shifted 30 px below the vertical centre."""
    center_y = int(img.shape[0] / 2 + 30)
    center_x = int(img.shape[1] / 2)
    return img[center_y - 520: center_y + 520, center_x - 340: center_x + 340]
|
||||
|
||||
|
||||
class Layer(object):
    """Ordered collection of renderable layer dicts, each identified by its 'name' key."""

    def __init__(self):
        # Layers in insertion order until sort() re-orders them by priority.
        self._layer = []

    @property
    def layer(self):
        """The underlying list of layer dicts."""
        return self._layer

    def insert(self, layer_instance):
        """Append *layer_instance*; the 'body' layer is additionally kept on self._body."""
        if layer_instance['name'] == 'body':
            # Remember the body layer separately for quick access.
            self._body = layer_instance
        self._layer.append(layer_instance)

    def sort(self, priority):
        """Order layers in place by the priority value looked up under each layer's name."""
        def rank(entry):
            return priority[entry['name']]

        self._layer.sort(key=rank)
|
||||
|
||||
# def merge(self, cfg):
|
||||
# """
|
||||
# opencv shape order (height, width, channel)
|
||||
# image coordinate system:
|
||||
# |------------->x (width)
|
||||
# |
|
||||
# |
|
||||
# |
|
||||
# y (height)
|
||||
# Returns:
|
||||
#
|
||||
#
|
||||
# """
|
||||
# base_image = Image.new('RGBA', self._layer[1]['image'].size, (0, 0, 0, 0))
|
||||
# for layer in self._layer:
|
||||
# y, x = layer['position']
|
||||
# base_image.paste(layer['image'], (x, y), layer['image'])
|
||||
# # base_image.show()
|
||||
#
|
||||
# for x in self._layer:
|
||||
# if np.all(x['mask'] == 0):
|
||||
# continue
|
||||
# # obtain region of interest about roi(roi) and item-image(roi_image, roi_mask)
|
||||
# roi, roi_mask, roi_image, signal = self.get_roi(dst=dst, image=x)
|
||||
# temp_bg = np.expand_dims(cv2.bitwise_not(roi_mask), axis=2).repeat(3, axis=2)
|
||||
# tmp1 = (roi * (temp_bg / 255)).astype(np.uint8)
|
||||
# temp_fg = np.expand_dims(roi_mask, axis=2).repeat(3, axis=2)
|
||||
# tmp2 = (roi_image * (temp_fg / 255)).astype(np.uint8)
|
||||
#
|
||||
# roi[:] = cv2.add(tmp1, tmp2)
|
||||
# # show(cv2.resize(dst, (int(dst.shape[1] * 0.5), int(dst.shape[0] * 0.5)), interpolation=cv2.INTER_AREA),
|
||||
# # win_name=x.get('name'))
|
||||
# # crop image and get the central part
|
||||
# if cfg.get('basic')['self_template'] == False:
|
||||
# dst_roi = crop(dst)
|
||||
# else:
|
||||
# dst_roi = dst
|
||||
# return dst_roi, signal
|
||||
#
|
||||
# @staticmethod
|
||||
# def get_roi(dst, image):
|
||||
# signal = False
|
||||
# dst_y, dst_x = dst.shape[:2]
|
||||
# roi_height, roi_width = image['mask'].shape
|
||||
# roi_y0, roi_x0 = image['position']
|
||||
#
|
||||
# if roi_y0 < 0:
|
||||
# roi_yin = 0
|
||||
# mask_yin = -roi_y0
|
||||
# signal = True
|
||||
# else:
|
||||
# roi_yin = roi_y0
|
||||
# mask_yin = 0
|
||||
# if roi_y0 + roi_height > dst_y:
|
||||
# roi_yout = dst_y
|
||||
# mask_yout = dst_y - roi_y0
|
||||
# signal = True
|
||||
# else:
|
||||
# roi_yout = roi_height + roi_y0
|
||||
# mask_yout = roi_height
|
||||
# # x part
|
||||
# if roi_x0 < 0:
|
||||
# roi_xin = 0
|
||||
# mask_xin = -roi_x0
|
||||
# signal = True
|
||||
# else:
|
||||
# roi_xin = roi_x0
|
||||
# mask_xin = 0
|
||||
# if roi_x0 + roi_width > dst_x:
|
||||
# roi_xout = dst_x
|
||||
# mask_xout = dst_x - roi_x0
|
||||
# signal = True
|
||||
# else:
|
||||
# roi_xout = roi_width + roi_x0
|
||||
# mask_xout = roi_width
|
||||
#
|
||||
# roi = dst[roi_yin: roi_yout, roi_xin: roi_xout]
|
||||
# roi_mask = image['mask'][mask_yin: mask_yout, mask_xin: mask_xout]
|
||||
# roi_image = image['image'][mask_yin: mask_yout, mask_xin: mask_xout]
|
||||
# return roi, roi_mask, roi_image, signal
|
||||
@@ -1,45 +0,0 @@
|
||||
class Priority(object):
    """Item layer priority levels.

    Accessories get fixed extreme priorities; clothing categories named in
    *item_list* are re-ranked so that earlier entries render above later ones.
    """

    def __init__(self, item_list):
        self._priority = {
            'earring_front': 99,
            'bag_front': 98,
            'hairstyle_front': 97,
            'outwear_front': 20,
            'bottoms_front': 19,
            'dress_front': 18,
            'blouse_front': 17,
            'skirt_front': 16,
            'trousers_front': 15,
            'tops_front': 14,
            'shoes_right': 1,
            'shoes_left': 1,
            'body': 0,
            'tops_back': -14,
            'trousers_back': -15,
            'skirt_back': -16,
            'blouse_back': -17,
            'dress_back': -18,
            'bottoms_back': -19,
            'outwear_back': -20,
            'hairstyle_back': -97,
            'bag_back': -98,
            'earring_back': -99,
        }
        # Re-ranked clothing priorities start at this value and decrease per item.
        self.clothing_start_num = 10
        if not isinstance(item_list, list):
            raise ValueError('item_list must be a list!')
        allowed = ('outwear', 'dress', 'blouse', 'skirt', 'trousers', 'tops', 'bottoms')
        for raw in item_list:
            cate = raw.lower()
            if cate not in allowed:
                raise ValueError(f'Item type error. Cannot recognize {cate}')
        for i, raw in enumerate(item_list):
            cate = raw.lower()
            rank = self.clothing_start_num - i
            # Front and back of the same garment are mirrored around zero.
            self._priority[f'{cate}_front'] = rank
            self._priority[f'{cate}_back'] = -rank

    @property
    def priority(self):
        """Mapping from layer name to integer priority."""
        return self._priority
|
||||
@@ -1,16 +0,0 @@
|
||||
from .builder import ITEMS, build_item
|
||||
from .clothing import Clothing # 4.0 sec
|
||||
from .body import Body
|
||||
from .top import Top, Blouse, Outwear, Dress
|
||||
from .bottom import Bottom, Trousers, Skirt
|
||||
from .shoes import Shoes
|
||||
from .bag import Bag
|
||||
from .others import Hairstyle, Earring
|
||||
|
||||
__all__ = [
|
||||
'ITEMS', 'build_item',
|
||||
'Clothing', 'Body',
|
||||
'Top', 'Blouse', 'Outwear', 'Dress',
|
||||
'Bottom', 'Trousers', 'Skirt',
|
||||
'Shoes', 'Bag', 'Hairstyle', 'Earring'
|
||||
]
|
||||
@@ -1,45 +0,0 @@
|
||||
import random
|
||||
|
||||
from .builder import ITEMS
|
||||
from .clothing import Clothing
|
||||
|
||||
|
||||
@ITEMS.register_module()
class Bag(Clothing):
    """Bag item; positioned at a randomly chosen hand ('left' or 'right')."""

    def __init__(self, **kwargs):
        # Fixed processing pipeline: load -> keypoints -> contour mask ->
        # colour painting -> scaling -> front/back split.
        pipeline = [
            dict(type='LoadImageFromFile', path=kwargs['path'], color=kwargs['color']),
            dict(type='KeypointDetection'),
            dict(type='ContourDetection'),
            dict(type='Painting'),
            dict(type='Scaling'),
            dict(type='Split'),
            # dict(type='ImageShow', key=['image', 'mask', 'pattern_image']),
        ]
        kwargs.update(pipeline=pipeline)
        super(Bag, self).__init__(**kwargs)

    @staticmethod
    def calculate_start_point(keypoint_type, scale, clothes_point, body_point):
        """
        align left

        Picks a random body side, then offsets the chosen hand keypoint by the
        scaled bag keypoint so the bag hangs from that hand.

        Args:
            keypoint_type: string, "hand_point"
            scale: float
            clothes_point: dict keyed by keypoint name; values are "x_y" strings
                (NOTE(review): indexed here by *keypoint_type*, not by side —
                confirm the cache stores a single shared point for both hands)
            body_point: dict, containing keypoint data of body figure

        Returns:
            start_point: tuple (y', x')
            x' = y_body - y1 * scale
            y' = x_body - x1 * scale
        """
        # Randomly hang the bag from the left or right hand.
        location = random.choice(seq=['left', 'right'])
        if location == 'left':
            side_indicator = f'{keypoint_type}_left'
        else:
            side_indicator = f'{keypoint_type}_right'
        # clothes_point = {k: tuple(map(lambda x: int(scale * x), v[0: 2])) for k, v in clothes_point.items()}
        # "x_y" string -> ints, scaled, then subtracted from the body keypoint.
        start_point = (body_point[side_indicator][1] - int(int(clothes_point[keypoint_type].split("_")[1]) * scale),
                       body_point[side_indicator][0] - int(int(clothes_point[keypoint_type].split("_")[0]) * scale))
        return start_point
|
||||
@@ -1,36 +0,0 @@
|
||||
import cv2
|
||||
|
||||
from .builder import ITEMS
|
||||
from .pipelines import Compose
|
||||
|
||||
|
||||
@ITEMS.register_module()
class Body(object):
    """Mannequin/body figure: loads the base image all other items are layered onto."""

    def __init__(self, **kwargs):
        # Single-step pipeline: fetch the body image from object storage.
        steps = [dict(type='LoadBodyImageFromFile', body_path=kwargs['body_path'])]
        self.pipeline = Compose(steps)
        self.result = {}

    def process(self):
        """Run the loading pipeline, populating self.result in place."""
        self.pipeline(self.result)

    def organize(self, layer):
        """Insert the body as the priority-0 base layer at the origin into *layer*."""
        # NOTE(review): 'sacle' is a long-standing typo for 'scale' that
        # downstream consumers key on — do not rename in isolation.
        layer.insert(dict(priority=0,
                          name=type(self).__name__.lower(),
                          image=self.result['body_image'],
                          image_url=self.result['image_url'],
                          mask_image=None,
                          mask_url=None,
                          sacle=1,
                          position=(0, 0)))

    @staticmethod
    def show(img):
        """Debug helper: display *img* in an OpenCV window until a key is pressed."""
        cv2.imshow('', img)
        cv2.waitKey(0)
|
||||
@@ -1,39 +0,0 @@
|
||||
from .builder import ITEMS
|
||||
from .clothing import Clothing
|
||||
|
||||
|
||||
@ITEMS.register_module()
class Bottom(Clothing):
    """Lower-body garment; supplies the default processing pipeline when none is given."""

    def __init__(self, pipeline, **kwargs):
        if pipeline is None:
            # Default pipeline: load -> keypoints -> contour mask -> colour and
            # print painting -> scaling -> front/back split.
            pipeline = [
                dict(type='LoadImageFromFile', path=kwargs['path'], color=kwargs['color'], print_dict=kwargs['print']),
                dict(type='KeypointDetection'),
                dict(type='ContourDetection'),
                dict(type='Painting', painting_flag=True),
                dict(type='PrintPainting', print_flag=True),
                dict(type='Scaling'),
                dict(type='Split'),
            ]
        kwargs['pipeline'] = pipeline
        super(Bottom, self).__init__(**kwargs)
|
||||
|
||||
|
||||
@ITEMS.register_module()
class Trousers(Bottom):
    """Trousers item; identical to Bottom with the default pipeline."""

    def __init__(self, pipeline=None, **kwargs):
        super().__init__(pipeline, **kwargs)
|
||||
|
||||
|
||||
@ITEMS.register_module()
class Skirt(Bottom):
    """Skirt item; identical to Bottom with the default pipeline."""

    def __init__(self, pipeline=None, **kwargs):
        super().__init__(pipeline, **kwargs)
|
||||
|
||||
|
||||
@ITEMS.register_module()
class Bottoms(Bottom):
    """Generic bottoms item; identical to Bottom with the default pipeline."""

    def __init__(self, pipeline=None, **kwargs):
        super().__init__(pipeline, **kwargs)
|
||||
@@ -1,9 +0,0 @@
|
||||
from mmcv.utils import Registry, build_from_cfg
|
||||
|
||||
ITEMS = Registry('item')
|
||||
PIPELINES = Registry('pipeline')
|
||||
|
||||
|
||||
def build_item(cfg, default_args=None):
    """Instantiate a registered item from its config dict via the ITEMS registry."""
    return build_from_cfg(cfg, ITEMS, default_args)
|
||||
@@ -1,100 +0,0 @@
|
||||
import cv2
|
||||
|
||||
from app.core.config import PRIORITY_DICT
|
||||
from .builder import ITEMS
|
||||
from .pipelines import Compose
|
||||
|
||||
|
||||
@ITEMS.register_module()
class Clothing(object):
    """Base class for garment items processed through a configurable pipeline.

    Subclasses supply the pipeline; process() fills self.result in place, and
    organize() converts the result into front/back layer dicts for rendering.
    """

    def __init__(self, pipeline, **kwargs):
        self.pipeline = Compose(pipeline)
        # Result dict carries the item name plus all constructor kwargs and is
        # mutated in place by every pipeline stage.
        self.result = dict(name=type(self).__name__.lower(), **kwargs)

    def process(self):
        """Run the full pipeline over self.result."""
        self.pipeline(self.result)

    def apply_scale(self, img):
        """Return *img* resized by self.result['scale'] using area interpolation."""
        scale = self.result['scale']
        # (Fixed: a duplicated height/width assignment behind a dead
        # `len(img.shape) > 2` branch was removed — both arms were identical.)
        height, width = img.shape[0: 2]
        return cv2.resize(img, (int(width * scale), int(height * scale)), interpolation=cv2.INTER_AREA)

    def organize(self, layer):
        """Build front and back layer dicts for this item and insert them into *layer*."""
        start_point = self.calculate_start_point(self.result['keypoint'], self.result['scale'],
                                                 self.result['clothes_keypoint'], self.result['body_point_test'],
                                                 self.result["offset"], self.result["resize_scale"])

        # An explicit per-item layer order wins; otherwise fall back to the
        # global PRIORITY_DICT lookup for this category.
        # NOTE: 'sacle' is a long-standing key typo consumers rely on.
        front_layer = dict(priority=self.result.get("priority", None) if self.result.get("layer_order", False) else PRIORITY_DICT.get(f'{type(self).__name__.lower()}_front', None),
                           name=f'{type(self).__name__.lower()}_front',
                           image=self.result["front_image"],
                           image_url=self.result['front_image_url'],
                           mask_url=self.result['mask_url'],
                           sacle=self.result['scale'],
                           clothes_keypoint=self.result['clothes_keypoint'],
                           position=start_point,
                           resize_scale=self.result["resize_scale"],
                           mask=cv2.resize(self.result['mask'], self.result["front_image"].size),
                           gradient_string=self.result.get('gradient_string', ""),
                           pattern_image_url=self.result['pattern_image_url'],
                           pattern_image=self.result['pattern_image'])
        layer.insert(front_layer)

        # Back layer mirrors the front; an explicit priority is negated.
        back_layer = dict(priority=-self.result.get("priority", 0) if self.result.get("layer_order", False) else PRIORITY_DICT.get(f'{type(self).__name__.lower()}_back', None),
                          name=f'{type(self).__name__.lower()}_back',
                          image=self.result["back_image"],
                          image_url=self.result['back_image_url'],
                          mask_url=self.result['mask_url'],
                          sacle=self.result['scale'],
                          clothes_keypoint=self.result['clothes_keypoint'],
                          position=start_point,
                          resize_scale=self.result["resize_scale"],
                          mask=cv2.resize(self.result['mask'], self.result["front_image"].size),
                          gradient_string=self.result.get('gradient_string', ""),
                          pattern_image_url=self.result['pattern_image_url'])
        layer.insert(back_layer)

    @staticmethod
    def calculate_start_point(keypoint_type, scale, clothes_point, body_point, offset, resize_scale):
        """
        Align left.

        Computes the (y, x) paste position so the garment's left keypoint lands
        on the matching body keypoint (plus *offset*), compensating for *scale*.

        Args:
            keypoint_type: string, "waistband" | "shoulder" | "ear_point"
            scale: float
            clothes_point: dict keyed by '<keypoint>_left' etc.; values are
                [x, y] pairs from the keypoint cache
            body_point: dict, containing keypoint data of body figure
            offset: (x, y) extra displacement added to the body keypoint
            resize_scale: unused here; kept for interface compatibility

        Returns:
            start_point: tuple (y', x')
        """
        side_indicator = f'{keypoint_type}_left'
        start_point = (
            int(body_point[side_indicator][1] + offset[1] - int(clothes_point[side_indicator][0]) * scale),  # y
            int(body_point[side_indicator][0] + offset[0] - int(clothes_point[side_indicator][1]) * scale)   # x
        )
        return start_point
|
||||
@@ -1,59 +0,0 @@
|
||||
from .builder import ITEMS
|
||||
from .clothing import Clothing
|
||||
|
||||
|
||||
@ITEMS.register_module()
class Hairstyle(Clothing):
    """Hairstyle item; aligned to the body's top-of-head ('up') keypoint."""

    def __init__(self, **kwargs):
        # Fixed processing pipeline: load -> keypoints -> contour mask ->
        # colour painting -> scaling -> front/back split.
        pipeline = [
            dict(type='LoadImageFromFile', path=kwargs['path']),
            dict(type='KeypointDetection'),
            dict(type='ContourDetection'),
            dict(type='Painting'),
            dict(type='Scaling'),
            dict(type='Split'),
        ]
        kwargs.update(pipeline=pipeline)
        super(Hairstyle, self).__init__(**kwargs)

    @staticmethod
    def calculate_start_point(keypoint_type, scale, clothes_point, body_point):
        """
        align up

        Args:
            keypoint_type: string, "head_point"
            scale: float
            clothes_point: dict keyed by '<keypoint>_up'; values are "x_y" strings
            body_point: dict, containing keypoint data of body figure

        Returns:
            start_point: tuple (y', x')
            y' = y_body - y1 * scale
            x' = x_body - x1 * scale
        """
        side_indicator = f'{keypoint_type}_up'
        # BUG FIX: convert each "x_y" coordinate to int BEFORE multiplying by
        # *scale*.  Previously the raw string was multiplied directly, which
        # raises TypeError for a float scale and, for an integer scale,
        # silently repeats the string (e.g. "12" * 2 -> "1212") producing a
        # wildly wrong coordinate.  Sibling Bag already did the inner int().
        x_str, y_str = clothes_point[side_indicator].split("_")[:2]
        start_point = (
            int(body_point[side_indicator][1] - int(int(y_str) * scale)),
            int(body_point[side_indicator][0] - int(int(x_str) * scale))
        )
        return start_point
|
||||
|
||||
|
||||
@ITEMS.register_module()
class Earring(Clothing):
    """Earring item; uses the standard accessory pipeline without colour arguments."""

    def __init__(self, **kwargs):
        # Load -> keypoints -> contour mask -> painting -> scaling -> split.
        stages = [
            dict(type='LoadImageFromFile', path=kwargs['path']),
            dict(type='KeypointDetection'),
            dict(type='ContourDetection'),
            dict(type='Painting'),
            dict(type='Scaling'),
            dict(type='Split'),
        ]
        kwargs['pipeline'] = stages
        super(Earring, self).__init__(**kwargs)
|
||||
@@ -1,19 +0,0 @@
|
||||
from .compose import Compose
from .loading import LoadImageFromFile, LoadBodyImageFromFile, ImageShow
from .keypoints import KeypointDetection
from .segmentation import Segmentation
from .painting import Painting, PrintPainting
from .scale import Scaling
from .contour_detection import ContourDetection
from .split import Split

# Public pipeline API re-exported at package level.
__all__ = [
    'Compose',
    'LoadImageFromFile', 'LoadBodyImageFromFile', 'ImageShow',
    'KeypointDetection',
    'Segmentation',
    'Painting', 'PrintPainting',
    'Scaling',
    'ContourDetection',
    'Split',  # BUG FIX: was lowercase 'split'; the exported class is Split,
              # so `from ...pipelines import *` raised AttributeError.
]
|
||||
@@ -1,36 +0,0 @@
|
||||
import collections
|
||||
|
||||
from mmcv.utils import build_from_cfg
|
||||
|
||||
from ..builder import PIPELINES
|
||||
|
||||
|
||||
@PIPELINES.register_module()
class Compose(object):
    """Chain several pipeline transforms into a single callable."""

    def __init__(self, transforms):
        assert isinstance(transforms, collections.abc.Sequence)
        self.transforms = []
        for cfg in transforms:
            if isinstance(cfg, dict):
                # Config dicts are instantiated through the PIPELINES registry.
                self.transforms.append(build_from_cfg(cfg, PIPELINES))
            elif callable(cfg):
                self.transforms.append(cfg)
            else:
                raise TypeError('transform must be callable or a dict')

    def __call__(self, data):
        """Call function to apply transforms sequentially.

        Args:
            data (dict): A result dict contains the data to transform.

        Returns:
            dict: Transformed data, or None if any transform returned None.
        """
        for transform in self.transforms:
            data = transform(data)
            if data is None:
                # Short-circuit: a transform signalled failure/abort.
                return None
        return data
|
||||
@@ -1,59 +0,0 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from ..builder import PIPELINES
|
||||
|
||||
|
||||
@PIPELINES.register_module()
class ContourDetection(object):
    """Derive a binary garment mask from the item image via Canny-edge contours."""

    def __init__(self):
        pass

    def __call__(self, result):
        image = result['image']
        contours = self.get_contours(image)
        mask = np.zeros(image.shape[:2], np.uint8)
        if result['name'] == 'shoes':
            # Shoes come as a pair: fill the two largest contours.
            # NOTE(review): raises IndexError when fewer than two contours are
            # found — behaviour preserved from the original implementation.
            for idx in range(2):
                outline = contours[idx]
                eps = 0.001 * cv2.arcLength(outline, True)
                poly = cv2.approxPolyDP(outline, eps, True)
                cv2.drawContours(mask, [poly], -1, 255, -1)
        elif len(contours):
            # Single garment: fill only the largest contour.
            outline = contours[0]
            eps = 0.001 * cv2.arcLength(outline, True)
            poly = cv2.approxPolyDP(outline, eps, True)
            cv2.drawContours(mask, [poly], -1, 255, -1)
        else:
            # No contour detected: fall back to a fully opaque mask.
            mask = np.ones(image.shape[:2], np.uint8) * 255
        # TODO: fix images that come out partially transparent (planned for a
        # later release; see the thresholding approach sketched previously).
        if result['pre_mask'] is None:
            result['mask'] = mask
        else:
            # Combine with any alpha-channel mask extracted at load time.
            result['mask'] = cv2.bitwise_and(mask, result['pre_mask'])
        result['front_mask'] = result['mask']
        result['back_mask'] = result['mask']
        return result

    @staticmethod
    def get_contours(image):
        """Return the external contours of *image*, largest area first."""
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        edges = cv2.Canny(gray, 10, 150)
        kernel = np.ones((5, 5), np.uint8)
        # Close small gaps in the edge map before contour extraction.
        edges = cv2.dilate(edges, kernel=kernel, iterations=1)
        edges = cv2.erode(edges, kernel=kernel, iterations=1)
        found, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        return sorted(found, key=cv2.contourArea, reverse=True)
|
||||
@@ -1,140 +0,0 @@
|
||||
import logging
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
from pymilvus import MilvusClient
|
||||
|
||||
from app.core.config import *
|
||||
from app.service.utils.decorator import RunTime, ClassCallRunTime
|
||||
from ..builder import PIPELINES
|
||||
from ...utils.design_ensemble import get_keypoint_result
|
||||
|
||||
|
||||
@PIPELINES.register_module()
class KeypointDetection(object):
    """Pipeline stage: look up (or infer and cache) garment keypoints via Milvus.

    path here: abstract path
    """

    def __call__(self, result):
        # Only clothing categories carry cached keypoints; other items pass through.
        if result['name'] in ['blouse', 'skirt', 'dress', 'outwear', 'trousers', 'tops', 'bottoms']:
            site = 'up' if result['name'] in ['blouse', 'outwear', 'dress', 'tops'] else 'down'
            keypoint_cache = self.keypoint_cache(result, site)
            if keypoint_cache is False:
                # Cache backend unavailable: fall back to pure inference + best-effort save.
                keypoint_infer_result, site = self.infer_keypoint_result(result)
                result['clothes_keypoint'] = self.save_keypoint_cache(result["image_id"], keypoint_infer_result, site)
            else:
                result['clothes_keypoint'] = keypoint_cache
        return result

    @staticmethod
    def infer_keypoint_result(result):
        """Run the keypoint model on the item image; returns (points, site)."""
        site = 'up' if result['name'] in ['blouse', 'outwear', 'dress', 'tops'] else 'down'
        keypoint_infer_result = get_keypoint_result(result["image"], site)
        return keypoint_infer_result, site

    @staticmethod
    def save_keypoint_cache(keypoint_id, cache, site):
        """Upsert the inferred keypoints of one half ('up'/'down') into Milvus.

        The stored vector always holds 24 values (12 x/y pairs); the half that
        was not inferred is zero-padded. Returns the keypoints as a dict keyed
        by KEYPOINT_RESULT_TABLE_FIELD_SET regardless of cache success.
        """
        if site == "down":
            padded = np.concatenate([np.zeros(20, dtype=int), cache.flatten()])
        else:
            padded = np.concatenate([cache.flatten(), np.zeros(4, dtype=int)])
        # Compute the return value once; it is identical on both paths.
        keypoint_dict = dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, padded.reshape(12, 2).astype(int).tolist()))
        data = [
            {"keypoint_id": keypoint_id,
             "keypoint_site": site,
             "keypoint_vector": padded.tolist()
             }
        ]
        client = None
        try:
            client = MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS)
            client.upsert(collection_name=MILVUS_TABLE_KEYPOINT, data=data)
        except Exception as e:
            # Cache write is best-effort: log and fall through to the computed result.
            logging.info(f"save keypoint cache milvus error : {e}")
        finally:
            if client is not None:
                client.close()  # FIX: connection was leaked on the exception path
        return keypoint_dict

    @staticmethod
    def update_keypoint_cache(keypoint_id, infer_result, search_result, site):
        """Merge a freshly inferred half with the cached other half; store as 'all'."""
        if site == "up":
            # Inferred 'up': keep the cached 'down' tail (last 4 values).
            merged = np.concatenate([infer_result.flatten(), search_result[-4:]])
        else:
            # Inferred 'down': keep the cached 'up' head (first 20 values).
            merged = np.concatenate([search_result[:20], infer_result.flatten()])
        keypoint_dict = dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, merged.reshape(12, 2).astype(int).tolist()))
        data = [
            {"keypoint_id": keypoint_id,
             "keypoint_site": "all",
             "keypoint_vector": merged.tolist()
             }
        ]
        client = None
        try:
            client = MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS)
            client.upsert(
                collection_name=MILVUS_TABLE_KEYPOINT,
                data=data
            )
        except Exception as e:
            logging.info(f"save keypoint cache milvus error : {e}")
        finally:
            if client is not None:
                client.close()  # FIX: connection was never closed before
        return keypoint_dict

    def keypoint_cache(self, result, site):
        """Query Milvus for cached keypoints of result['image_id'].

        Returns the keypoint dict, or False when the cache lookup itself
        failed (callers then fall back to inference).
        """
        client = None
        try:
            client = MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS)
            keypoint_id = result['image_id']
            res = client.query(
                collection_name=MILVUS_TABLE_KEYPOINT,
                filter=f"keypoint_id == {keypoint_id}",
                output_fields=['keypoint_vector', 'keypoint_site']
            )
            if len(res) == 0:
                # Cache miss: infer now and persist the result.
                keypoint_infer_result, site = self.infer_keypoint_result(result)
                return self.save_keypoint_cache(result['image_id'], keypoint_infer_result, site)
            elif res[0]["keypoint_site"] == "all" or res[0]["keypoint_site"] == site:
                # Cached entry already covers the half we need.
                return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET,
                                np.array(res[0]['keypoint_vector']).astype(int).reshape(12, 2).tolist()))
            elif res[0]["keypoint_site"] != site:
                # Other half cached: infer this half and upgrade the entry to 'all'.
                keypoint_infer_result, site = self.infer_keypoint_result(result)
                return self.update_keypoint_cache(result["image_id"], keypoint_infer_result, res[0]['keypoint_vector'], site)
        except Exception as e:
            logging.info(f"search keypoint cache milvus error {e}")
            return False
        finally:
            if client is not None:
                client.close()  # FIX: connection was never closed before
|
||||
@@ -1,134 +0,0 @@
|
||||
import cv2
|
||||
|
||||
from app.service.utils.oss_client import oss_get_image
|
||||
from ..builder import PIPELINES
|
||||
|
||||
|
||||
@PIPELINES.register_module()
class LoadImageFromFile(object):
    """Pipeline stage: fetch an item image from OSS and seed the result dict.

    Populates the image/pre_mask/gray/keypoint/path/shape/color/print_dict
    keys that later stages (keypoint detection, painting, scaling) read.
    """

    def __init__(self, path, color=None, print_dict=None):
        # "bucket/object" style OSS path of the item image.
        self.path = path
        # Optional solid-colour / print configuration forwarded into the result.
        self.color = color
        self.print_dict = print_dict
        # self.minio_client = Minio(f"{MINIO_URL}", access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)

    # @ClassCallRunTime
    def __call__(self, result):
        """Load the image and record its metadata on *result*; returns *result*."""
        result['image'], result['pre_mask'] = self.read_image(self.path)
        result['gray'] = cv2.cvtColor(result['image'], cv2.COLOR_BGR2GRAY)
        result['keypoint'] = self.get_keypoint(result['name'])
        result['path'] = self.path
        result['img_shape'] = result['image'].shape
        result['ori_shape'] = result['image'].shape
        result['color'] = self.color if self.color is not None else None
        result['print_dict'] = self.print_dict
        return result

    @staticmethod
    def get_keypoint(name):
        """Map an item category name to the body keypoint it is anchored to.

        Raises:
            KeyError: if *name* is not a known item category.
        """
        if name == 'blouse' or name == 'outwear' or name == 'dress' or name == 'tops':
            keypoint = 'shoulder'
        elif name == 'trousers' or name == 'skirt' or name == 'bottoms':
            keypoint = 'waistband'
        elif name == 'bag':
            keypoint = 'hand_point'
        elif name == 'shoes':
            keypoint = 'toe'
        elif name == 'hairstyle':
            keypoint = 'head_point'
        elif name == 'earring':
            keypoint = 'ear_point'
        else:
            raise KeyError(f"{name} does not belong to item category list: blouse, outwear, dress, trousers, skirt, "
                           f"bag, shoes, hairstyle, earring.")
        return keypoint

    @staticmethod
    def read_image(image_path):
        """Fetch *image_path* ("bucket/object") from OSS as a BGR cv2 image.

        Returns:
            (image, alpha_mask): alpha_mask is None unless the source image
            carried a 4th (alpha) channel.
        """
        image_mask = None
        image = oss_get_image(bucket=image_path.split("/", 1)[0], object_name=image_path.split("/", 1)[1], data_type="cv2")
        if len(image.shape) == 2:
            # Grayscale source: expand to 3 channels.
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        if image.shape[2] == 4:  # 4-channel image: split off the alpha channel as a mask
            image_mask = image[:, :, 3]
            image = image[:, :, :3]

        # NOTE(review): tuple comparison is lexicographic — this upscales only
        # when height < 50, or height == 50 and width <= 50; it does NOT mean
        # "both dimensions <= 50". Confirm the intended condition.
        if image.shape[:2] <= (50, 50):
            # compute the doubled size
            new_size = (image.shape[1] * 2, image.shape[0] * 2)
            # resize up so tiny images survive later processing
            image = cv2.resize(image, new_size, interpolation=cv2.INTER_LINEAR)
        return image, image_mask
|
||||
|
||||
|
||||
@PIPELINES.register_module()
class LoadBodyImageFromFile(object):
    """Pipeline stage: fetch the mannequin/body image from OSS as a PIL image."""

    def __init__(self, body_path):
        # "bucket/object" style OSS path of the body image.
        self.body_path = body_path

    def __call__(self, result):
        """Record the body path and load the image onto *result*; returns *result*."""
        result["image_url"] = result['body_path'] = self.body_path
        result["name"] = "mannequin"
        parts = self.body_path.split("/", 1)
        result['body_image'] = oss_get_image(bucket=parts[0], object_name=parts[1], data_type="PIL")
        return result
|
||||
|
||||
|
||||
@PIPELINES.register_module()
class ImageShow(object):
    """Debug pipeline stage: display one or more images from the result dict."""

    def __init__(self, key):
        # A single result key (shown via OpenCV) or a list of keys (matplotlib).
        self.key = key

    # @ RunTime
    def __call__(self, result):
        """Show the configured image(s); returns *result* unchanged."""
        import matplotlib.pyplot as plt
        if isinstance(self.key, list):
            for name in self.key:
                plt.imshow(result[name])
                plt.title(name)
                plt.show()
        elif isinstance(self.key, str):
            shrunk = self._resize_img(result[self.key])
            cv2.imshow(self.key, shrunk)
            cv2.waitKey(0)
        else:
            raise TypeError(f'key should be string but got type {type(self.key)}.')
        return result

    @staticmethod
    def _resize_img(img):
        """Shrink *img* so neither side exceeds 400 px, preserving aspect ratio."""
        height, width = img.shape[0], img.shape[1]
        if height > 400 or width > 400:
            ratio = min(400 / height, 400 / width)
            img = cv2.resize(img, (int(ratio * width), int(ratio * height)))
        return img
|
||||
@@ -1,605 +0,0 @@
|
||||
import logging
|
||||
import random
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from app.service.utils.oss_client import oss_get_image
|
||||
from ..builder import PIPELINES
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
@PIPELINES.register_module()
class Painting(object):
    """Pipeline stage that fills a garment mask with a solid color or gradient.

    When painting applies (not hairstyle/earring, painting enabled, and a
    color is given) the stage composes ``pattern_image`` / ``final_image`` /
    ``single_image`` from the mask, the gray shading map, and either a
    gradient texture fetched from OSS or a 1x1 solid-color pattern.
    Otherwise the original image is simply masked.
    """

    def __init__(self, painting_flag=True):
        # Global switch; when False the stage only masks the source image.
        self.painting_flag = painting_flag

    # @ClassCallRunTime
    def __call__(self, result):
        """Paint ``result`` in place and return it."""
        if result['name'] not in ['hairstyle', 'earring'] and self.painting_flag and result['color'] != 'none':
            dim_image_h, dim_image_w = result['image'].shape[0:2]
            if "gradient" in result.keys() and result['gradient'] != "":
                # ``gradient`` is "<bucket>/<object_name>": fetch and resize it.
                bucket_name = result['gradient'].split('/')[0]
                object_name = result['gradient'][result['gradient'].find('/') + 1:]
                pattern = self.get_gradient(bucket_name=bucket_name, object_name=object_name)
                resize_pattern = cv2.resize(pattern, (dim_image_w, dim_image_h), interpolation=cv2.INTER_AREA)
            else:
                # Solid color: build a 1x1 BGR patch and stretch it to size.
                pattern = self.get_pattern(result['color'])
                resize_pattern = cv2.resize(pattern, (dim_image_w, dim_image_h), interpolation=cv2.INTER_AREA)
            # Modulate the pattern by the garment mask and the gray shading map.
            closed_mo = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
            gray_mo = np.expand_dims(result['gray'], axis=2).repeat(3, axis=2)
            get_image_fir = resize_pattern * (closed_mo / 255) * (gray_mo / 255)
            result['pattern_image'] = get_image_fir.astype(np.uint8)
            result['final_image'] = result['pattern_image']
            # Compose the painted garment onto a white canvas for 'single_image'.
            canvas = np.full_like(result['final_image'], 255)
            temp_bg = np.expand_dims(cv2.bitwise_not(result['mask']), axis=2).repeat(3, axis=2)
            tmp1 = (canvas * (temp_bg / 255)).astype(np.uint8)
            temp_fg = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
            tmp2 = (result['final_image'] * (temp_fg / 255)).astype(np.uint8)
            result['single_image'] = cv2.add(tmp1, tmp2)
            result['alpha'] = 100 / 255.0
        else:
            # No painting: keep the original pixels inside the mask.
            closed_mo = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
            get_image_fir = result['image'] * (closed_mo / 255)
            result['pattern_image'] = get_image_fir.astype(np.uint8)
            result['final_image'] = result['pattern_image']
        return result

    @staticmethod
    def get_gradient(bucket_name, object_name):
        """Fetch a gradient texture from OSS as a 3-channel BGR image."""
        image = oss_get_image(bucket=bucket_name, object_name=object_name, data_type="cv2")
        if image.shape[2] == 4:
            # Drop the alpha channel so downstream arithmetic stays 3-channel.
            image = cv2.cvtColor(image, cv2.COLOR_BGRA2BGR)
        return image

    @staticmethod
    def crop_image(image, image_size_h, image_size_w):
        """Randomly crop an ``image_size_h`` x ``image_size_w`` window from ``image``."""
        x_offset = np.random.randint(low=0, high=int(image_size_h / 5) - 6)
        y_offset = np.random.randint(low=0, high=int(image_size_w / 5) - 6)
        image = image[x_offset: x_offset + image_size_h, y_offset: y_offset + image_size_w, :]
        return image

    @staticmethod
    def get_pattern(single_color):
        """Build a 1x1 BGR patch from an "R G B" string.

        Raises:
            ValueError: if ``single_color`` is None.
        """
        if single_color is None:
            # Bug fix: the original ``raise False`` is invalid — raising a
            # non-exception itself raises "TypeError: exceptions must derive
            # from BaseException". Raise an explicit, descriptive error instead.
            raise ValueError("single_color must be an 'R G B' string, got None")
        R, G, B = single_color.split(' ')
        pattern = np.zeros([1, 1, 3], np.uint8)
        # OpenCV uses BGR channel order.
        pattern[0, 0, 0] = int(B)
        pattern[0, 0, 1] = int(G)
        pattern[0, 0, 2] = int(R)
        return pattern
|
||||
|
||||
|
||||
@PIPELINES.register_module()
class PrintPainting(object):
    """Pipeline stage that composites prints onto a painted garment image.

    Three print modes are handled in ``__call__``, driven by
    ``result['print']``:

    * ``overall`` — a single print tiled (or rotated and tiled) across the
      whole garment region;
    * ``single``  — individual prints pasted at explicit locations;
    * ``element`` — decorative elements pasted on top of the final image.

    The stage mutates ``result`` in place (``print_image``, ``final_image``,
    ``single_image``, ``pattern_image``) and returns it.
    """

    def __init__(self, print_flag=True):
        # NOTE(review): ``print_flag`` is stored but never read in this class —
        # presumably checked by the pipeline elsewhere; confirm before removal.
        self.print_flag = print_flag

    # @ClassCallRunTime
    def __call__(self, result):
        """Apply overall / single / element prints to ``result`` and return it."""
        single_print = result['print']['single']
        overall_print = result['print']['overall']
        element_print = result['print']['element']
        result['single_image'] = None
        result['print_image'] = None
        # ---- overall print: tile one pattern over the whole garment --------
        if overall_print['print_path_list']:
            painting_dict = {'dim_image_h': result['pattern_image'].shape[0], 'dim_image_w': result['pattern_image'].shape[1]}
            result['print_image'] = result['pattern_image']
            if "print_angle_list" in overall_print.keys() and overall_print['print_angle_list'][0] != 0:
                # Rotated overall print: tile into a square first so the
                # rotation can be cropped without black corners.
                painting_dict = self.painting_collection(painting_dict, overall_print, print_trigger=True)
                painting_dict['tile_print'] = self.rotate_crop_image(img=painting_dict['tile_print'], angle=-overall_print['print_angle_list'][0], crop=True)
                painting_dict['mask_inv_print'] = self.rotate_crop_image(img=painting_dict['mask_inv_print'], angle=-overall_print['print_angle_list'][0], crop=True)

                # Resize back to the sketch size.
                painting_dict['tile_print'] = self.resize_and_crop(img=painting_dict['tile_print'], target_width=painting_dict['dim_image_w'], target_height=painting_dict['dim_image_h'])
                painting_dict['mask_inv_print'] = self.resize_and_crop(img=painting_dict['mask_inv_print'], target_width=painting_dict['dim_image_w'], target_height=painting_dict['dim_image_h'])
            else:
                painting_dict = self.painting_collection(painting_dict, overall_print, print_trigger=True, is_single=False)
            result['print_image'] = self.printpaint(result, painting_dict, print_=True)
            result['single_image'] = result['final_image'] = result['pattern_image'] = result['print_image']

        # ---- single prints: paste each print at its given location ---------
        if single_print['print_path_list']:
            print_background = np.zeros((result['pattern_image'].shape[0], result['pattern_image'].shape[1], 3), dtype=np.uint8)
            mask_background = np.zeros((result['pattern_image'].shape[0], result['pattern_image'].shape[1], 3), dtype=np.uint8)
            for i in range(len(single_print['print_path_list'])):
                image, image_mode = self.read_image(single_print['print_path_list'][i])
                if image_mode == "RGBA":
                    # RGBA prints: use the alpha channel as the paste mask.
                    new_size = (int(image.width * single_print['print_scale_list'][i]), int(image.height * single_print['print_scale_list'][i]))

                    mask = image.split()[3]
                    resized_source = image.resize(new_size)
                    resized_source_mask = mask.resize(new_size)

                    rotated_resized_source = resized_source.rotate(-single_print['print_angle_list'][i])
                    rotated_resized_source_mask = resized_source_mask.rotate(-single_print['print_angle_list'][i])

                    source_image_pil = Image.fromarray(cv2.cvtColor(print_background, cv2.COLOR_BGR2RGB))
                    source_image_pil_mask = Image.fromarray(cv2.cvtColor(mask_background, cv2.COLOR_BGR2RGB))

                    source_image_pil.paste(rotated_resized_source, (int(single_print['location'][i][0]), int(single_print['location'][i][1])), rotated_resized_source)
                    source_image_pil_mask.paste(rotated_resized_source_mask, (int(single_print['location'][i][0]), int(single_print['location'][i][1])), rotated_resized_source_mask)

                    print_background = cv2.cvtColor(np.array(source_image_pil), cv2.COLOR_RGBA2BGR)
                    mask_background = cv2.cvtColor(np.array(source_image_pil_mask), cv2.COLOR_RGBA2BGR)
                    ret, mask_background = cv2.threshold(mask_background, 124, 255, cv2.THRESH_BINARY)
                else:
                    # Opaque prints: derive a background mask from the image itself.
                    mask = self.get_mask_inv(image)
                    mask = np.expand_dims(mask, axis=2)
                    mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR)
                    mask = cv2.bitwise_not(mask)
                    # Coordinates must be recomputed after rotation.
                    rotate_mask, _ = self.img_rotate(mask, single_print['print_angle_list'][i], single_print['print_scale_list'][i])
                    rotate_image, rotated_new_size = self.img_rotate(image, single_print['print_angle_list'][i], single_print['print_scale_list'][i])
                    x, y = int(single_print['location'][i][0] - rotated_new_size[0]), int(single_print['location'][i][1] - rotated_new_size[1])

                    image_x = print_background.shape[1]
                    image_y = print_background.shape[0]
                    print_x = rotate_image.shape[1]
                    print_y = rotate_image.shape[0]

                    # Order matters: first shift the print into view (clip the
                    # left/top overhang), THEN clip the right/bottom overhang.
                    # Doing both in one pass corrupts the pasted region.

                    # If the print was rotated or placed at the edge, check
                    # whether the left/top bound went below 0.
                    if x <= 0:
                        rotate_image = rotate_image[:, -x:]
                        rotate_mask = rotate_mask[:, -x:]
                        start_x = x = 0
                    else:
                        start_x = x

                    if y <= 0:
                        rotate_image = rotate_image[-y:, :]
                        rotate_mask = rotate_mask[-y:, :]
                        start_y = y = 0
                    else:
                        start_y = y

                    # ------------------
                    # If the print exceeds the image bounds, crop the print.

                    if x + print_x > image_x:
                        rotate_image = rotate_image[:, :image_x - x]
                        rotate_mask = rotate_mask[:, :image_x - x]

                    if y + print_y > image_y:
                        rotate_image = rotate_image[:image_y - y, :]
                        rotate_mask = rotate_mask[:image_y - y, :]

                    mask_background = self.stack_prin(mask_background, result['pattern_image'], rotate_mask, start_y, y, start_x, x)
                    print_background = self.stack_prin(print_background, result['pattern_image'], rotate_image, start_y, y, start_x, x)

            # Restrict prints to the garment mask, shade them by the gray map,
            # then composite over the painted pattern.
            print_mask = cv2.bitwise_and(result['mask'], cv2.cvtColor(mask_background, cv2.COLOR_BGR2GRAY))
            img_fg = cv2.bitwise_or(print_background, print_background, mask=print_mask)
            img_bg = cv2.bitwise_and(result['pattern_image'], result['pattern_image'], mask=cv2.bitwise_not(print_mask))
            mask_mo = np.expand_dims(print_mask, axis=2).repeat(3, axis=2)
            gray_mo = np.expand_dims(result['gray'], axis=2).repeat(3, axis=2)
            img_fg = (img_fg * (mask_mo / 255) * (gray_mo / 255)).astype(np.uint8)
            result['final_image'] = cv2.add(img_bg, img_fg)
            # Place the composed garment on a white canvas for 'single_image'.
            canvas = np.full_like(result['final_image'], 255)
            temp_bg = np.expand_dims(cv2.bitwise_not(result['mask']), axis=2).repeat(3, axis=2)
            tmp1 = (canvas * (temp_bg / 255)).astype(np.uint8)
            temp_fg = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
            tmp2 = (result['final_image'] * (temp_fg / 255)).astype(np.uint8)
            result['single_image'] = cv2.add(tmp1, tmp2)

        # ---- element prints: same pasting flow, composited on final_image --
        if element_print['element_path_list']:
            print_background = np.zeros((result['final_image'].shape[0], result['final_image'].shape[1], 3), dtype=np.uint8)
            mask_background = np.zeros((result['final_image'].shape[0], result['final_image'].shape[1], 3), dtype=np.uint8)
            for i in range(len(element_print['element_path_list'])):
                image, image_mode = self.read_image(element_print['element_path_list'][i])
                if image_mode == "RGBA":
                    new_size = (int(image.width * element_print['element_scale_list'][i]), int(image.height * element_print['element_scale_list'][i]))

                    mask = image.split()[3]
                    resized_source = image.resize(new_size)
                    resized_source_mask = mask.resize(new_size)

                    rotated_resized_source = resized_source.rotate(-element_print['element_angle_list'][i])
                    rotated_resized_source_mask = resized_source_mask.rotate(-element_print['element_angle_list'][i])

                    source_image_pil = Image.fromarray(cv2.cvtColor(print_background, cv2.COLOR_BGR2RGB))
                    source_image_pil_mask = Image.fromarray(cv2.cvtColor(mask_background, cv2.COLOR_BGR2RGB))

                    source_image_pil.paste(rotated_resized_source, (int(element_print['location'][i][0]), int(element_print['location'][i][1])), rotated_resized_source)
                    source_image_pil_mask.paste(rotated_resized_source_mask, (int(element_print['location'][i][0]), int(element_print['location'][i][1])), rotated_resized_source_mask)

                    print_background = cv2.cvtColor(np.array(source_image_pil), cv2.COLOR_RGBA2BGR)
                    # NOTE(review): unlike the single-print branch, the element
                    # branch does NOT threshold mask_background here — confirm
                    # whether that asymmetry is intentional.
                    mask_background = cv2.cvtColor(np.array(source_image_pil_mask), cv2.COLOR_RGBA2BGR)
                else:
                    mask = self.get_mask_inv(image)
                    mask = np.expand_dims(mask, axis=2)
                    mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR)
                    mask = cv2.bitwise_not(mask)
                    # Coordinates must be recomputed after rotation.
                    rotate_mask, _ = self.img_rotate(mask, element_print['element_angle_list'][i], element_print['element_scale_list'][i])
                    rotate_image, rotated_new_size = self.img_rotate(image, element_print['element_angle_list'][i], element_print['element_scale_list'][i])
                    x, y = int(element_print['location'][i][0] - rotated_new_size[0]), int(element_print['location'][i][1] - rotated_new_size[1])

                    image_x = print_background.shape[1]
                    image_y = print_background.shape[0]
                    print_x = rotate_image.shape[1]
                    print_y = rotate_image.shape[0]

                    # Same move-then-clip ordering as the single-print branch.

                    # If the element was rotated or placed at the edge, check
                    # whether the left/top bound went below 0.
                    if x <= 0:
                        rotate_image = rotate_image[:, -x:]
                        rotate_mask = rotate_mask[:, -x:]
                        start_x = x = 0
                    else:
                        start_x = x

                    if y <= 0:
                        rotate_image = rotate_image[-y:, :]
                        rotate_mask = rotate_mask[-y:, :]
                        start_y = y = 0
                    else:
                        start_y = y

                    # ------------------
                    # If the element exceeds the image bounds, crop the element.

                    if x + print_x > image_x:
                        rotate_image = rotate_image[:, :image_x - x]
                        rotate_mask = rotate_mask[:, :image_x - x]

                    if y + print_y > image_y:
                        rotate_image = rotate_image[:image_y - y, :]
                        rotate_mask = rotate_mask[:image_y - y, :]

                    mask_background = self.stack_prin(mask_background, result['pattern_image'], rotate_mask, start_y, y, start_x, x)
                    print_background = self.stack_prin(print_background, result['pattern_image'], rotate_image, start_y, y, start_x, x)

            print_mask = cv2.bitwise_and(result['mask'], cv2.cvtColor(mask_background, cv2.COLOR_BGR2GRAY))
            img_fg = cv2.bitwise_or(print_background, print_background, mask=print_mask)
            # TODO: element compositing loses shading information (no gray map
            # modulation here, unlike the single-print branch).
            three_channel_image = cv2.merge([cv2.bitwise_not(print_mask), cv2.bitwise_not(print_mask), cv2.bitwise_not(print_mask)])
            img_bg = cv2.bitwise_and(result['final_image'], three_channel_image)
            result['final_image'] = cv2.add(img_bg, img_fg)
            canvas = np.full_like(result['final_image'], 255)
            temp_bg = np.expand_dims(cv2.bitwise_not(result['mask']), axis=2).repeat(3, axis=2)
            tmp1 = (canvas * (temp_bg / 255)).astype(np.uint8)
            temp_fg = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
            tmp2 = (result['final_image'] * (temp_fg / 255)).astype(np.uint8)
            result['single_image'] = cv2.add(tmp1, tmp2)
        return result

    @staticmethod
    def stack_prin(print_background, pattern_image, rotate_image, start_y, y, start_x, x):
        """Stamp ``rotate_image`` onto ``print_background`` at (start_x, start_y).

        Non-black pixels of the stamped region replace the background; black
        pixels keep the existing background content.
        """
        temp_print = np.zeros((pattern_image.shape[0], pattern_image.shape[1], 3), dtype=np.uint8)
        temp_print[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = rotate_image
        img2gray = cv2.cvtColor(temp_print, cv2.COLOR_BGR2GRAY)
        ret, mask_ = cv2.threshold(img2gray, 1, 255, cv2.THRESH_BINARY)
        mask_inv = cv2.bitwise_not(mask_)
        img1_bg = cv2.bitwise_and(print_background, print_background, mask=mask_inv)
        img2_fg = cv2.bitwise_and(temp_print, temp_print, mask=mask_)
        print_background = img1_bg + img2_fg
        return print_background

    def painting_collection(self, painting_dict, print_dict, print_trigger=False, is_single=False):
        """Prepare the tiled print and its inverse mask for ``printpaint``.

        Populates ``painting_dict`` with ``Trigger``, ``location``,
        ``tile_print``, ``mask_inv_print`` and the print tile dimensions.
        """
        if print_trigger:
            print_ = self.get_print(print_dict)
            painting_dict['Trigger'] = not is_single
            painting_dict['location'] = print_['location']
            single_mask_inv_print = self.get_mask_inv(print_['image'])
            dim_max = max(painting_dict['dim_image_h'], painting_dict['dim_image_w'])
            # One tile is scale/5 of the longest image side.
            dim_pattern = (int(dim_max * print_['scale'] / 5), int(dim_max * print_['scale'] / 5))
            if not is_single:
                self.random_seed = random.randint(0, 1000)
                # If the print mode is overall AND an angle is set, build the
                # combined print as a square so it can be cropped after rotation.
                if "print_angle_list" in print_dict.keys() and print_dict['print_angle_list'][0] != 0:
                    painting_dict['mask_inv_print'] = self.tile_image(single_mask_inv_print, dim_pattern, print_['scale'], dim_max, dim_max, painting_dict['location'], trigger=True)
                    painting_dict['tile_print'] = self.tile_image(print_['image'], dim_pattern, print_['scale'], dim_max, dim_max, painting_dict['location'], trigger=True)
                else:
                    painting_dict['mask_inv_print'] = self.tile_image(single_mask_inv_print, dim_pattern, print_['scale'], painting_dict['dim_image_h'], painting_dict['dim_image_w'], painting_dict['location'], trigger=True)
                    painting_dict['tile_print'] = self.tile_image(print_['image'], dim_pattern, print_['scale'], painting_dict['dim_image_h'], painting_dict['dim_image_w'], painting_dict['location'], trigger=True)
            else:
                painting_dict['mask_inv_print'] = self.tile_image(single_mask_inv_print, dim_pattern, print_['scale'], painting_dict['dim_image_h'], painting_dict['dim_image_w'], painting_dict['location'])
                painting_dict['tile_print'] = self.tile_image(print_['image'], dim_pattern, print_['scale'], painting_dict['dim_image_h'], painting_dict['dim_image_w'], painting_dict['location'])
            painting_dict['dim_print_h'], painting_dict['dim_print_w'] = dim_pattern
        return painting_dict

    def tile_image(self, pattern, dim, scale, dim_image_h, dim_image_w, location, trigger=False):
        """Resize ``pattern`` to ``dim`` and, when ``trigger``, tile and crop it
        to cover a ``dim_image_h`` x ``dim_image_w`` area."""
        tile = None
        if not trigger:
            tile = cv2.resize(pattern, dim, interpolation=cv2.INTER_AREA)
        else:
            resize_pattern = cv2.resize(pattern, dim, interpolation=cv2.INTER_AREA)
            # 2-D patterns are masks, 3-D patterns are color images.
            if len(pattern.shape) == 2:
                tile = np.tile(resize_pattern, (int((5 + 1) / scale) + 4, int((5 + 1) / scale) + 4))
            if len(pattern.shape) == 3:
                tile = np.tile(resize_pattern, (int((5 + 1) / scale) + 4, int((5 + 1) / scale) + 4, 1))
            tile = self.crop_image(tile, dim_image_h, dim_image_w, location, resize_pattern.shape)
        return tile

    def get_mask_inv(self, print_):
        """Return a mask of the (near-)white background of ``print_``.

        If the top-left pixel is pure white, the background is segmented in
        Lab space within +/-30 of that pixel's color; otherwise the print is
        assumed fully opaque and an all-zero mask is returned.
        """
        if print_[0][0][0] == 255 and print_[0][0][1] == 255 and print_[0][0][2] == 255:
            bg_color = cv2.cvtColor(print_, cv2.COLOR_BGR2LAB)[0][0]
            print_tile = cv2.cvtColor(print_, cv2.COLOR_BGR2LAB)
            bg_l, bg_a, bg_b = bg_color[0], bg_color[1], bg_color[2]
            bg_L_high, bg_L_low = self.get_low_high_lab(bg_l, L=True)
            bg_a_high, bg_a_low = self.get_low_high_lab(bg_a)
            bg_b_high, bg_b_low = self.get_low_high_lab(bg_b)
            lower = np.array([bg_L_low, bg_a_low, bg_b_low])
            upper = np.array([bg_L_high, bg_a_high, bg_b_high])
            mask_inv = cv2.inRange(print_tile, lower, upper)
            return mask_inv
        else:
            # (Earlier gray-based segmentation attempts removed; the whole
            # print is treated as foreground.)
            mask_inv = np.zeros(print_.shape[:2], dtype=np.uint8)
            return mask_inv

    @staticmethod
    def printpaint(result, painting_dict, print_=False):
        """Blend the tiled print (or single placements) into the pattern image.

        Returns the composited BGR image; does not mutate ``result``.
        """

        if print_ and painting_dict['Trigger']:
            # Overall mode: print everywhere inside the garment mask, except
            # where the print's own background mask says so.
            print_mask = cv2.bitwise_and(result['mask'], cv2.bitwise_not(painting_dict['mask_inv_print']))
            img_fg = cv2.bitwise_and(painting_dict['tile_print'], painting_dict['tile_print'], mask=print_mask)
        else:
            print_mask = result['mask']
            img_fg = result['final_image']
        if print_ and not painting_dict['Trigger']:
            index_ = None
            try:
                index_ = len(painting_dict['location'])
            except:
                # NOTE(review): ``assert <non-empty f-string>`` is always true,
                # so this never fires — the intent was probably
                # ``raise AssertionError(...)`` or ``assert False, msg``.
                assert f'there must be parameter of location if choose IfSingle'

            for i in range(index_):
                # location is (x, y); rows are the second coordinate.
                start_h, start_w = int(painting_dict['location'][i][1]), int(painting_dict['location'][i][0])

                length_h = min(start_h + painting_dict['dim_print_h'], img_fg.shape[0])
                length_w = min(start_w + painting_dict['dim_print_w'], img_fg.shape[1])

                change_region = img_fg[start_h: length_h, start_w: length_w, :]
                # problem in change_mask
                change_mask = print_mask[start_h: length_h, start_w: length_w]
                # Keep only the strongly-masked part in change_mask.
                _, change_mask = cv2.threshold(change_mask, 220, 255, cv2.THRESH_BINARY)
                # NOTE(review): ``mask`` and ``change_mask`` are computed but
                # never used below — possibly leftover from an older blend.
                mask = cv2.bitwise_not(painting_dict['mask_inv_print'])
                img_fg[start_h:start_h + painting_dict['dim_print_h'], start_w:start_w + painting_dict['dim_print_w'], :] = change_region

        clothes_mask_print = cv2.bitwise_not(print_mask)

        # Background = pattern outside the print; foreground = shaded print.
        img_bg = cv2.bitwise_and(result['pattern_image'], result['pattern_image'], mask=clothes_mask_print)
        mask_mo = np.expand_dims(print_mask, axis=2).repeat(3, axis=2)
        gray_mo = np.expand_dims(result['gray'], axis=2).repeat(3, axis=2)
        img_fg = (img_fg * (mask_mo / 255) * (gray_mo / 255)).astype(np.uint8)
        print_image = cv2.add(img_bg, img_fg)
        return print_image

    @staticmethod
    def get_print(print_dict):
        """Load the first print image from OSS and normalize its scale.

        Sets ``print_dict['scale']`` (floored at 0.3) and
        ``print_dict['image']`` (BGR ndarray); returns the mutated dict.
        """
        if 'print_scale_list' not in print_dict.keys() or print_dict['print_scale_list'][0] < 0.3:
            print_dict['scale'] = 0.3
        else:
            print_dict['scale'] = print_dict['print_scale_list'][0]

        bucket_name = print_dict['print_path_list'][0].split("/", 1)[0]
        object_name = print_dict['print_path_list'][0].split("/", 1)[1]
        image = oss_get_image(bucket=bucket_name, object_name=object_name, data_type="PIL")
        # If the image is RGBA, paste it onto pure white so transparency does
        # not turn black when dropping the alpha channel.
        if image.mode == "RGBA":
            new_background = Image.new('RGB', image.size, (255, 255, 255))
            new_background.paste(image, mask=image.split()[3])
            image = new_background
        print_dict['image'] = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
        return print_dict

    def crop_image(self, image, image_size_h, image_size_w, location, print_shape):
        """Crop a window from the tiled pattern, phase-aligned to ``location``."""
        print_w = print_shape[1]
        print_h = print_shape[0]

        random.seed(self.random_seed)
        # 1. Take the location modulo the resized print size to get the true
        #    tiling offset.
        # NOTE(review): x_offset uses print_w with location[0][1] and y_offset
        # uses print_w with ``% print_h`` — the width/height pairing looks
        # swapped; confirm against the intended tiling alignment.
        x_offset = print_w - int(location[0][1] % print_w)
        y_offset = print_w - int(location[0][0] % print_h)

        if len(image.shape) == 2:
            image = image[x_offset: x_offset + image_size_h, y_offset: y_offset + image_size_w]
        elif len(image.shape) == 3:
            image = image[x_offset: x_offset + image_size_h, y_offset: y_offset + image_size_w, :]
        return image

    @staticmethod
    def get_low_high_lab(Lab_value, L=False):
        """Return (high, low) bounds of ``Lab_value`` +/- 30, clamped to [0, 255].

        NOTE(review): both branches are currently identical, so the ``L`` flag
        has no effect — possibly a planned per-channel tolerance that was never
        differentiated.
        """
        if L:
            high = Lab_value + 30 if Lab_value + 30 < 255 else 255
            low = Lab_value - 30 if Lab_value - 30 > 0 else 0
        else:
            high = Lab_value + 30 if Lab_value + 30 < 255 else 255
            low = Lab_value - 30 if Lab_value - 30 > 0 else 0
        return high, low

    @staticmethod
    def img_rotate(image, angel, scale):
        """Rotate an image clockwise by an arbitrary angle, expanding the canvas.

        Args:
            image (np.array): source image.
            angel (float): rotation angle (applied as ``-angel`` to
                ``getRotationMatrix2D``, i.e. clockwise).
            scale: isotropic scale factor applied during the rotation.

        Returns:
            tuple: the rotated image and the (x, y) offset between the expanded
            canvas and the scaled original.
        """

        h, w = image.shape[:2]
        center = (w // 2, h // 2)
        M = cv2.getRotationMatrix2D(center, -angel, scale)
        # Enlarge the output so the rotated image is not clipped.
        rotated_h = int((w * np.abs(M[0, 1]) + (h * np.abs(M[0, 0]))))
        rotated_w = int((h * np.abs(M[0, 1]) + (w * np.abs(M[0, 0]))))
        M[0, 2] += (rotated_w - w) // 2
        M[1, 2] += (rotated_h - h) // 2
        # Rotate the image.
        rotated_img = cv2.warpAffine(image, M, (rotated_w, rotated_h))

        return rotated_img, ((rotated_img.shape[1] - image.shape[1] * scale) // 2, (rotated_img.shape[0] - image.shape[0] * scale) // 2)

    @staticmethod
    def rotate_crop_image(img, angle, crop):
        """Rotate ``img`` in place (fixed canvas) and optionally crop the black
        borders introduced by the rotation.

        Args:
            angle: rotation angle in degrees.
            crop: bool — whether to crop away the black corners.
        """
        crop_image = lambda img, x0, y0, w, h: img[y0:y0 + h, x0:x0 + w]
        w, h = img.shape[:2]
        # Rotation is periodic in 360 degrees.
        angle %= 360
        # Build the affine rotation matrix.
        M_rotation = cv2.getRotationMatrix2D((w / 2, h / 2), angle, 1)
        # Rotate within the original canvas size.
        img_rotated = cv2.warpAffine(img, M_rotation, (w, h))

        # Remove the black border if requested.
        if crop:
            # The crop geometry is periodic in 180 degrees.
            angle_crop = angle % 180
            if angle > 90:
                angle_crop = 180 - angle_crop
            # Convert to radians.
            theta = angle_crop * np.pi / 180
            # Height/width ratio.
            hw_ratio = float(h) / float(w)
            # Numerator of the crop-side coefficient.
            tan_theta = np.tan(theta)
            numerator = np.cos(theta) + np.sin(theta) * np.tan(theta)

            # Denominator term depending on the aspect ratio.
            r = hw_ratio if h > w else 1 / hw_ratio
            denominator = r * tan_theta + 1
            # Final side-length coefficient.
            crop_mult = numerator / denominator

            # Compute the crop window.
            w_crop = int(crop_mult * w)
            h_crop = int(crop_mult * h)
            x0 = int((w - w_crop) / 2)
            y0 = int((h - h_crop) / 2)

            img_rotated = crop_image(img_rotated, x0, y0, w_crop, h_crop)

        return img_rotated

    @staticmethod
    def read_image(image_url):
        """Fetch ``"<bucket>/<object>"`` from OSS.

        Returns ``(PIL RGBA image, "RGBA")`` when the source has an alpha
        channel, otherwise ``(BGR ndarray, "RGB")``.
        """
        image = oss_get_image(bucket=image_url.split("/", 1)[0], object_name=image_url.split("/", 1)[1], data_type="cv2")
        if image.shape[2] == 4:
            image_rgb = cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA)
            image = Image.fromarray(image_rgb)
            image_mode = "RGBA"
        else:
            image_mode = "RGB"
        return image, image_mode

    @staticmethod
    def resize_and_crop(img, target_width, target_height):
        """Resize ``img`` to cover the target size, then center-crop the excess."""
        # Original image dimensions.
        original_height, original_width = img.shape[:2]

        # Aspect ratio of the target size.
        target_ratio = target_width / target_height

        # Aspect ratio of the original image.
        original_ratio = original_width / original_height

        # Resize then crop.
        if original_ratio > target_ratio:
            # Original is wider: resize by height, then crop the width.
            new_height = target_height
            new_width = int(original_width * (target_height / original_height))
            resized_img = cv2.resize(img, (new_width, new_height))
            # Crop the width.
            start_x = (new_width - target_width) // 2
            cropped_img = resized_img[:, start_x:start_x + target_width]
        else:
            # Original is taller: resize by width, then crop the height.
            new_width = target_width
            new_height = int(original_height * (target_width / original_width))
            resized_img = cv2.resize(img, (new_width, new_height))
            # Crop the height.
            start_y = (new_height - target_height) // 2
            cropped_img = resized_img[start_y:start_y + target_height, :]

        return cropped_img
|
||||
@@ -1,57 +0,0 @@
|
||||
import math
|
||||
|
||||
import cv2
|
||||
|
||||
from app.service.utils.decorator import ClassCallRunTime
|
||||
from ..builder import PIPELINES
|
||||
|
||||
|
||||
@PIPELINES.register_module()
class Scaling(object):
    """Pipeline stage that sets result['scale'] — the ratio between a
    reference distance on the body figure and the matching distance on the
    garment image — so the garment can later be resized onto the mannequin.
    """

    def __init__(self):
        pass

    # @ClassCallRunTime
    def __call__(self, result):
        if result['keypoint'] in ['waistband', 'shoulder', 'head_point']:
            # Euclidean distance between the left/right garment keypoints.
            distance_clo = math.sqrt(
                (int(result['clothes_keypoint'][result['keypoint'] + '_left'][0]) - int(result['clothes_keypoint'][result['keypoint'] + '_right'][0])) ** 2
                +
                (int(result['clothes_keypoint'][result['keypoint'] + '_left'][1]) - int(result['clothes_keypoint'][result['keypoint'] + '_right'][1])) ** 2)

            # Horizontal distance between the same keypoints on the body
            # figure; the "+ 1" keeps the radicand strictly positive.
            distance_bdy = math.sqrt((int(result['body_point_test'][result['keypoint'] + '_left'][0]) - int(result['body_point_test'][result['keypoint'] + '_right'][0])) ** 2 + 1)
            if distance_clo == 0:
                # Degenerate garment keypoints: fall back to no scaling.
                result['scale'] = 1
            else:
                result['scale'] = distance_bdy / distance_clo
        elif result['keypoint'] == 'toe':
            # Shoes: body reference is the foot-length segment (x0,y0,x1,y1).
            distance_bdy = math.sqrt(
                (int(result['body_point_test']['foot_length'][0]) - int(result['body_point_test']['foot_length'][2])) ** 2
                +
                (int(result['body_point_test']['foot_length'][1]) - int(result['body_point_test']['foot_length'][3])) ** 2
            )

            # Garment reference: bounding-box width of the largest edge
            # contour found in the grayscale shoe image.
            Blur = cv2.GaussianBlur(result['gray'], (3, 3), 0)
            Edge = cv2.Canny(Blur, 10, 200)
            Edge = cv2.dilate(Edge, None)
            Edge = cv2.erode(Edge, None)
            Contour, _ = cv2.findContours(Edge, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            Contours = sorted(Contour, key=cv2.contourArea, reverse=True)

            Max_contour = Contours[0]  # NOTE(review): IndexError if no contour is detected
            x, y, w, h = cv2.boundingRect(Max_contour)
            width = w
            distance_clo = width
            result['scale'] = distance_bdy / distance_clo
        elif result['keypoint'] == 'hand_point':
            # Bags: scale was precomputed upstream.
            result['scale'] = result['scale_bag']
        elif result['keypoint'] == 'ear_point':
            # Earrings: scale was precomputed upstream.
            result['scale'] = result['scale_earrings']
        return result
|
||||
@@ -1,71 +0,0 @@
|
||||
import logging
|
||||
import os
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from app.core.config import SEG_CACHE_PATH
|
||||
from app.service.utils.decorator import ClassCallRunTime
|
||||
from app.service.utils.oss_client import oss_get_image
|
||||
from ..builder import PIPELINES
|
||||
from ...utils.design_ensemble import get_seg_result
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
@PIPELINES.register_module()
class Segmentation(object):
    """Pipeline stage that produces result['front_mask'], result['back_mask']
    and their union result['mask'] (all uint8, 255 = covered).

    Two sources are supported: a pre-rendered red/green mask image on OSS
    (red = front, green = back), or an inference result from the
    segmentation model, cached on local disk under SEG_CACHE_PATH.
    """

    @ClassCallRunTime
    def __call__(self, result):
        if "seg_mask_url" in result.keys() and result['seg_mask_url'] != "":
            # Pre-rendered mask: download, resize to the working image shape.
            seg_mask = oss_get_image(bucket=result['seg_mask_url'].split('/')[0], object_name=result['seg_mask_url'][result['seg_mask_url'].find('/') + 1:], data_type="cv2")
            seg_mask = cv2.resize(seg_mask, (result['img_shape'][1], result['img_shape'][0]), interpolation=cv2.INTER_NEAREST)
            # Convert color space to RGB (OpenCV default is BGR).
            image_rgb = cv2.cvtColor(seg_mask, cv2.COLOR_BGR2RGB)

            r, g, b = cv2.split(image_rgb)
            # Red-dominant pixels -> front piece, green-dominant -> back piece.
            red_mask = r > g
            green_mask = g > r

            result['front_mask'] = np.array(red_mask, dtype=np.uint8) * 255
            result['back_mask'] = np.array(green_mask, dtype=np.uint8) * 255
            result['mask'] = result['front_mask'] + result['back_mask']
        else:
            # Check the local seg cache first ("_" is a hit/miss flag).
            _, seg_result = self.load_seg_result(result["image_id"])
            result['seg_result'] = seg_result
            if not _:
                # Cache miss: run inference and persist the result.
                # NOTE(review): result['seg_result'] was already set to None
                # above and is not refreshed here — confirm downstream usage.
                seg_result = get_seg_result(result["image_id"], result['image'])[0]
                self.save_seg_result(seg_result, result['image_id'])
            # Class 1.0 -> front piece, class 2.0 -> back piece.
            temp_front = seg_result == 1.0
            result['front_mask'] = (255 * (temp_front + 0).astype(np.uint8))
            temp_back = seg_result == 2.0
            result['back_mask'] = (255 * (temp_back + 0).astype(np.uint8))
            result['mask'] = result['front_mask'] + result['back_mask']
        return result

    @staticmethod
    def save_seg_result(seg_result, image_id):
        """Persist a segmentation array to the local .npy cache."""
        file_path = f"{SEG_CACHE_PATH}{image_id}.npy"
        try:
            np.save(file_path, seg_result)
            logger.debug(f"保存成功 :{os.path.abspath(file_path)}")
        except Exception as e:
            logger.error(f"保存失败: {e}")

    @staticmethod
    def load_seg_result(image_id):
        """Load a cached segmentation array.

        Returns:
            (True, array) on a cache hit, (False, None) otherwise.
        """
        file_path = f"{SEG_CACHE_PATH}{image_id}.npy"
        try:
            seg_result = np.load(file_path)
            return True, seg_result
        except FileNotFoundError:
            # Cache miss is expected; stay quiet.
            return False, None
        except Exception as e:
            logger.error(f"加载失败: {e}")
            return False, None
|
||||
@@ -1,79 +0,0 @@
|
||||
import io
|
||||
import logging
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from cv2 import cvtColor, COLOR_BGR2RGBA
|
||||
|
||||
from app.core.config import AIDA_CLOTHING
|
||||
from app.service.utils.generate_uuid import generate_uuid
|
||||
from app.service.utils.oss_client import oss_upload_image
|
||||
from ..builder import PIPELINES
|
||||
from ...utils.conversion_image import rgb_to_rgba
|
||||
from ...utils.upload_image import upload_png_mask
|
||||
|
||||
|
||||
@PIPELINES.register_module()
class Split(object):
    """
    Split image into front and back layer according to the segmentation result
    """

    # @ClassCallRunTime
    # KNet
    def __call__(self, result):
        try:
            if result['name'] in ('outwear', 'dress', 'blouse', 'skirt', 'trousers', 'tops', 'bottoms'):
                front_mask = result['front_mask']
                back_mask = result['back_mask']
                # RGBA garment image, scaled by body scale * resize factors.
                rgba_image = rgb_to_rgba(result['final_image'], front_mask + back_mask)
                new_size = (int(rgba_image.shape[1] * result["scale"] * result["resize_scale"][0]), int(rgba_image.shape[0] * result["scale"] * result["resize_scale"][1]))
                rgba_image = cv2.resize(rgba_image, new_size)
                # Front layer: keep only pixels covered by the front mask.
                result_front_image = np.zeros_like(rgba_image)
                front_mask = cv2.resize(front_mask, new_size)
                result_front_image[front_mask != 0] = rgba_image[front_mask != 0]
                result_front_image_pil = Image.fromarray(cvtColor(result_front_image, COLOR_BGR2RGBA))
                result['front_image'], result["front_image_url"], _ = upload_png_mask(result_front_image_pil, f'{generate_uuid()}', mask=None)

                # Visualization mask: front piece drawn red, back piece green.
                height, width = front_mask.shape
                mask_image = np.zeros((height, width, 3))
                mask_image[front_mask != 0] = [0, 0, 255]

                if result["name"] in ('blouse', 'dress', 'outwear', 'tops'):
                    # Upper-body garments also get a separate back layer.
                    result_back_image = np.zeros_like(rgba_image)
                    back_mask = cv2.resize(back_mask, new_size)
                    result_back_image[back_mask != 0] = rgba_image[back_mask != 0]
                    result_back_image_pil = Image.fromarray(cvtColor(result_back_image, COLOR_BGR2RGBA))
                    result['back_image'], result["back_image_url"], _ = upload_png_mask(result_back_image_pil, f'{generate_uuid()}', mask=None)
                    mask_image[back_mask != 0] = [0, 255, 0]

                    # Upload the combined front+back visualization mask.
                    rbga_mask = rgb_to_rgba(mask_image, front_mask + back_mask)
                    mask_pil = Image.fromarray(cvtColor(rbga_mask.astype(np.uint8), COLOR_BGR2RGBA))
                    image_data = io.BytesIO()
                    mask_pil.save(image_data, format='PNG')
                    image_data.seek(0)
                    image_bytes = image_data.read()
                    req = oss_upload_image(bucket=AIDA_CLOTHING, object_name=f"mask/mask_{generate_uuid()}.png", image_bytes=image_bytes)
                    result['mask_url'] = req.bucket_name + "/" + req.object_name
                else:
                    # Lower-body garments: front mask only, no back layer.
                    rbga_mask = rgb_to_rgba(mask_image, front_mask)
                    mask_pil = Image.fromarray(cvtColor(rbga_mask.astype(np.uint8), COLOR_BGR2RGBA))
                    image_data = io.BytesIO()
                    mask_pil.save(image_data, format='PNG')
                    image_data.seek(0)
                    image_bytes = image_data.read()
                    req = oss_upload_image(bucket=AIDA_CLOTHING, object_name=f"mask/mask_{generate_uuid()}.png", image_bytes=image_bytes)
                    result['mask_url'] = req.bucket_name + "/" + req.object_name
                    result['back_image'] = None
                    result["back_image_url"] = None
            # Middle (print/pattern) layer, built for every garment type.
            result_pattern_image_rgba = rgb_to_rgba(result['pattern_image'], result['mask'])
            result_pattern_image_pil = Image.fromarray(cvtColor(result_pattern_image_rgba, COLOR_BGR2RGBA))
            result['pattern_image'], result['pattern_image_url'], _ = upload_png_mask(result_pattern_image_pil, f'{generate_uuid()}')
            return result
        except Exception as e:
            # NOTE(review): swallowing the exception makes this stage return
            # None, which the next pipeline stage will trip over — consider
            # re-raising after logging.
            logging.warning(f"split runtime exception : {e} image_id : {result['image_id']}")
|
||||
@@ -1,121 +0,0 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from .builder import ITEMS
|
||||
from .clothing import Clothing
|
||||
from ..utils.conversion_image import rgb_to_rgba
|
||||
from ..utils.upload_image import upload_png_mask
|
||||
from ...utils.generate_uuid import generate_uuid
|
||||
|
||||
|
||||
@ITEMS.register_module()
class Shoes(Clothing):
    """Shoe item: runs the shoe pipeline, cuts the segmentation mask into a
    left and a right shoe and registers one layer per shoe.
    """

    # TODO location of shoes has little mismatch
    def __init__(self, **kwargs):
        pipeline = [
            dict(type='LoadImageFromFile', path=kwargs['path'], color=kwargs['color']),
            dict(type='KeypointDetection'),
            dict(type='ContourDetection'),
            dict(type='Painting'),
            dict(type='Scaling'),
            dict(type='Split'),
            # dict(type='ImageShow', key=['image', 'mask', 'pattern_image']),
        ]
        kwargs.update(pipeline=pipeline)
        super(Shoes, self).__init__(**kwargs)

    def organize(self, layer):
        """Split the processed image into two shoes and insert both layers."""
        left_shoe_mask, right_shoe_mask = self.cut()

        # NOTE(review): the key 'sacle' (sic) is presumably consumed downstream
        # under this misspelled name; renaming requires touching consumers.
        left_layer = dict(name=f'{type(self).__name__.lower()}_left',
                          image=self.result['shoes_left'],
                          image_url=self.result['left_image_url'],
                          mask_url=self.result['left_mask_url'],
                          sacle=self.result['scale'],
                          clothes_keypoint=self.result['clothes_keypoint'],
                          position=self.calculate_start_point(self.result['keypoint'],
                                                              self.result['scale'],
                                                              self.result['clothes_keypoint'],
                                                              self.result['body_point'],
                                                              'left'))
        layer.insert(left_layer)

        right_layer = dict(name=f'{type(self).__name__.lower()}_right',
                           image=self.result['shoes_right'],
                           image_url=self.result['right_image_url'],
                           mask_url=self.result['right_mask_url'],
                           sacle=self.result['scale'],
                           clothes_keypoint=self.result['clothes_keypoint'],
                           position=self.calculate_start_point(self.result['keypoint'],
                                                               self.result['scale'],
                                                               self.result['clothes_keypoint'],
                                                               self.result['body_point'],
                                                               'right'))

        layer.insert(right_layer)

    def cut(self):
        """
        Cut shoes mask into two pieces
        Returns:
        """
        # Take the two largest contours (one per shoe) and order them
        # left-to-right by bounding-box x.
        contour, _ = cv2.findContours(self.result['mask'], cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        contours = sorted(contour, key=cv2.contourArea, reverse=True)

        bounding_boxes = [cv2.boundingRect(c) for c in contours[:2]]
        (contours, bounding_boxes) = zip(*sorted(zip(contours[:2], bounding_boxes), key=lambda x: x[1][0], reverse=False))

        # Left shoe: simplified polygon -> filled mask -> slight blur to
        # soften the edge.
        epsilon_left = 0.001 * cv2.arcLength(contours[0], True)

        approx_left = cv2.approxPolyDP(contours[0], epsilon_left, True)
        mask_left = np.zeros(self.result['final_image'].shape[:2], np.uint8)
        cv2.drawContours(mask_left, [approx_left], -1, 255, -1)
        item_mask_left = cv2.GaussianBlur(mask_left, (5, 5), 0)

        # NOTE(review): this rgb_to_rgba is called with three arguments
        # (size, image, mask) — verify it is the variant imported here and
        # not the two-argument helper of the same name elsewhere.
        rgba_image = rgb_to_rgba((self.result['final_image'].shape[0], self.result['final_image'].shape[1]), self.result['final_image'], item_mask_left)
        result_image = np.zeros_like(rgba_image)
        result_image[self.result['front_mask'] != 0] = rgba_image[self.result['front_mask'] != 0]
        result_left_image_pil = Image.fromarray(result_image, 'RGBA')
        result_left_image_pil = result_left_image_pil.resize((int(result_left_image_pil.width * self.result["scale"]), int(result_left_image_pil.height * self.result["scale"])), Image.LANCZOS)
        self.result['shoes_left'], self.result["left_image_url"], self.result["left_mask_url"] = upload_png_mask(result_left_image_pil, f"{generate_uuid()}")

        # Right shoe: same processing on the second contour.
        epsilon_right = 0.001 * cv2.arcLength(contours[1], True)
        approx_right = cv2.approxPolyDP(contours[1], epsilon_right, True)
        mask_right = np.zeros(self.result['final_image'].shape[:2], np.uint8)
        cv2.drawContours(mask_right, [approx_right], -1, 255, -1)
        item_mask_right = cv2.GaussianBlur(mask_right, (5, 5), 0)

        rgba_image = rgb_to_rgba((self.result['final_image'].shape[0], self.result['final_image'].shape[1]), self.result['final_image'], item_mask_right)
        result_image = np.zeros_like(rgba_image)
        result_image[self.result['front_mask'] != 0] = rgba_image[self.result['front_mask'] != 0]
        result_right_image_pil = Image.fromarray(result_image, 'RGBA')
        result_right_image_pil = result_right_image_pil.resize((int(result_right_image_pil.width * self.result["scale"]), int(result_right_image_pil.height * self.result["scale"])), Image.LANCZOS)
        self.result['shoes_right'], self.result["right_image_url"], self.result["right_mask_url"] = upload_png_mask(result_right_image_pil, f"{generate_uuid()}")

        return item_mask_left, item_mask_right

    @staticmethod
    def calculate_start_point(keypoint_type, scale, clothes_point, body_point, location):
        """
        left shoes align left
        right shoes align right
        Args:
            keypoint_type: string, "toe"
            scale: float
            clothes_point: dict{'left': [x1, y1, z1], 'right': [x2, y2, z2]}
            body_point: dict, containing keypoint data of body figure
            location: string, indicates whether the start point belongs to right or left shoe

        Returns:
            start_point: tuple (x', y')
                x' = y_body - y1 * scale
                y' = x_body - x1 * scale
        """
        if location not in ['left', 'right']:
            raise KeyError(f'location value must be left or right but got {location}')
        side_indicator = f'{keypoint_type}_{location}'
        # clothes_point values are "x_y"-encoded strings; split and scale them.
        start_point = (body_point[side_indicator][1] - int(int(clothes_point[side_indicator].split("_")[1]) * scale),
                       body_point[side_indicator][0] - int(int(clothes_point[side_indicator].split("_")[0]) * scale))
        return start_point
|
||||
@@ -1,46 +0,0 @@
|
||||
from .builder import ITEMS
|
||||
from .clothing import Clothing
|
||||
|
||||
|
||||
@ITEMS.register_module()
class Top(Clothing):
    """Upper-body garment item; installs the default processing pipeline
    when the caller does not supply one.
    """

    def __init__(self, pipeline, **kwargs):
        if pipeline is None:
            # Default stage order: load -> keypoints -> segmentation ->
            # painting -> print painting -> scaling -> front/back split.
            pipeline = [
                dict(type='LoadImageFromFile', path=kwargs['path'], color=kwargs['color'], print_dict=kwargs['print']),
                dict(type='KeypointDetection'),
                dict(type='Segmentation'),
                dict(type='Painting', painting_flag=True),
                dict(type='PrintPainting', print_flag=True),
                dict(type='Scaling'),
                dict(type='Split'),
            ]
        kwargs.update(pipeline=pipeline)
        super().__init__(**kwargs)
|
||||
|
||||
|
||||
@ITEMS.register_module()
class Blouse(Top):
    """Women's blouse; reuses the Top pipeline unchanged."""

    def __init__(self, pipeline=None, **kwargs):
        super().__init__(pipeline, **kwargs)


@ITEMS.register_module()
class Outwear(Top):
    """Outerwear (coats/jackets); reuses the Top pipeline unchanged."""

    def __init__(self, pipeline=None, **kwargs):
        super().__init__(pipeline, **kwargs)


@ITEMS.register_module()
class Dress(Top):
    """Dress; reuses the Top pipeline unchanged."""

    def __init__(self, pipeline=None, **kwargs):
        super().__init__(pipeline, **kwargs)


# Men's clothing
@ITEMS.register_module()
class Tops(Top):
    """Men's tops; reuses the Top pipeline unchanged."""

    def __init__(self, pipeline=None, **kwargs):
        super().__init__(pipeline, **kwargs)
|
||||
@@ -1,197 +0,0 @@
|
||||
import concurrent.futures
|
||||
import io
|
||||
|
||||
import cv2
|
||||
|
||||
from app.core.config import PRIORITY_DICT
|
||||
from app.service.design.core.layer import Layer
|
||||
from app.service.design.items import build_item
|
||||
from app.service.design.utils.redis_utils import Redis
|
||||
from app.service.design.utils.synthesis_item import synthesis, synthesis_single
|
||||
from app.service.utils.decorator import RunTime
|
||||
from app.service.utils.oss_client import oss_upload_image
|
||||
|
||||
|
||||
def process_item(item, layers):
    """Run one item's pipeline and register its layers.

    Returns the mannequin body-image size when *item* is the mannequin,
    otherwise None.
    """
    item.process()
    item.organize(layers)
    result = item.result
    return result['body_image'].size if result['name'] == "mannequin" else None
|
||||
|
||||
|
||||
def update_progress(process_id, total):
    """Advance the Redis-backed progress counter for *process_id*.

    A single object (total == 1) jumps straight to 100; otherwise the
    counter is incremented by int(100 / total) per completed object and
    clamped to 99 once it has gone past 100.

    Returns the value that was stored BEFORE this update (may be None).
    """
    cache = Redis()
    current = cache.read(key=process_id)
    if total == 1:
        new_value = 100
    elif not current:
        # First update for this process id.
        new_value = int(100 / total)
    elif int(current) <= 100:
        new_value = int(current) + int(100 / total)
    else:
        # Counter overshot: hold just below completion until finalized.
        new_value = 99
    cache.write(key=process_id, value=new_value)
    return current
|
||||
|
||||
|
||||
def final_progress(process_id):
    """Force the progress counter to 100% and return its previous value."""
    cache = Redis()
    previous = cache.read(key=process_id)
    cache.write(key=process_id, value=100)
    return previous
|
||||
|
||||
|
||||
@RunTime
def generate(request_data):
    """Run the full design-generation flow for one request.

    Each object in the request is processed on its own worker thread;
    per-object layer responses are collected keyed by the object's index,
    and the shared progress key is forced to 100% at the end.
    """
    return_response = {}
    return_png_mask = []
    request_data = request_data.dict()  # pydantic model -> plain dict
    # NOTE(review): `assert` disappears under `python -O`; raising would be safer.
    assert "process_id" in request_data.keys(), "Need process_id parameters"

    objects = request_data['objects']
    # insert_keypoint_cache(objects)
    process_id = request_data['process_id']
    with concurrent.futures.ThreadPoolExecutor() as executor:
        # Submit one task per object; map each future back to its index.
        futures = {executor.submit(process_object, cfg, process_id, len(objects)): obj for obj, cfg in enumerate(objects)}
        # Collect results as they complete (order is not guaranteed).
        for future in concurrent.futures.as_completed(futures):
            obj = futures[future]
            return_response[obj] = future.result()[0]
            return_png_mask.extend(future.result()[1])
    # upload_results = process_images(return_png_mask)
    final_progress(process_id)
    return return_response
|
||||
|
||||
|
||||
def process_object(cfg, process_id, total):
    """Process one request object in either 'overall' or 'single' mode.

    overall: build every configured item, run its pipeline, sort the
        produced layers by priority and synthesize the composite image.
    single:  process only the item named by 'switch_category' and return
        its front/back layers plus a simple front/back composite.

    Returns:
        (items_response, uploaded_images); uploaded_images stays empty
        because batch uploading is currently disabled.
    """
    uploaded_images = []
    basic_info = cfg.get('basic')
    items_response = {
        'layers': []
    }
    if cfg.get('basic')['single_overall'] == 'overall':
        basic_info['debug'] = False
        items = [build_item(x, default_args=basic_info) for x in cfg.get('items')]
        layers = Layer()
        body_size = None
        futures = []
        # NOTE(review): `futures` is REASSIGNED to a one-element list each
        # iteration, so only the last item's return value is inspected below;
        # the mannequin size is lost unless the mannequin is the last item.
        for item in items:
            futures = [process_item(item, layers)]
        for future in futures:
            if future is not None:
                body_size = future
        # Custom ordering vs. the global PRIORITY_DICT ordering.
        if basic_info.get('layer_order', False):
            layers = sorted(layers.layer, key=lambda s: s.get("priority", float('inf')))
        else:
            layers = sorted(layers.layer, key=lambda x: PRIORITY_DICT.get(x['name'], float('inf')))
        layers, new_size = update_base_size_priority(layers, body_size)
        # Composite all layers into the final image.
        items_response['synthesis_url'] = synthesis(layers, new_size, basic_info)

        for lay in layers:
            items_response['layers'].append({
                'image_category': lay['name'],
                'position': lay['position'],
                'priority': lay.get("priority", None),
                'resize_scale': lay['resize_scale'] if "resize_scale" in lay.keys() else None,
                'image_size': lay['image'] if lay['image'] is None else lay['image'].size,
                'gradient_string': lay['gradient_string'] if 'gradient_string' in lay.keys() else "",
                'mask_url': lay['mask_url'],
                'image_url': lay['image_url'] if 'image_url' in lay.keys() else None,
                'pattern_image_url': lay['pattern_image_url'] if 'pattern_image_url' in lay.keys() else None,
            })
    elif cfg.get('basic')['single_overall'] == 'single':
        assert cfg.get('basic')['switch_category'] in [x['type'] for x in cfg.get('items')], "Lack of switch_category parameters "
        basic_info['debug'] = False
        for item in cfg.get('items'):
            if item['type'] == cfg.get('basic')['switch_category']:
                item = build_item(item, default_args=cfg.get('basic'))
                item.process()
                # NOTE(review): the "_front" entry reports back_image.size and
                # the "_back" entry reports front_image.size — this looks
                # swapped; confirm before relying on image_size.
                items_response['layers'].append({
                    'image_category': f"{item.result['name']}_front",
                    'image_size': item.result['back_image'].size if item.result['back_image'] else None,
                    'position': None,
                    'priority': 0,
                    'image_url': item.result['front_image_url'],
                    'mask_url': item.result['mask_url'],
                    "gradient_string": item.result['gradient_string'] if 'gradient_string' in item.result.keys() else "",
                    'pattern_image_url': item.result['pattern_image_url'] if 'pattern_image_url' in item.result.keys() else None,
                })
                items_response['layers'].append({
                    'image_category': f"{item.result['name']}_back",
                    'image_size': item.result['front_image'].size if item.result['front_image'] else None,
                    'position': None,
                    'priority': 0,
                    'image_url': item.result['back_image_url'],
                    'mask_url': item.result['mask_url'],
                    "gradient_string": item.result['gradient_string'] if 'gradient_string' in item.result.keys() else "",
                    'pattern_image_url': item.result['pattern_image_url'] if 'pattern_image_url' in item.result.keys() else None,
                })
                items_response['synthesis_url'] = synthesis_single(item.result['front_image'], item.result['back_image'])
                break
    update_progress(process_id, total)
    return items_response, uploaded_images
|
||||
|
||||
|
||||
@RunTime
def process_images(images):
    """Upload all layer images concurrently.

    Returns the uploaded OSS URLs in the same order as *images*.
    """
    with concurrent.futures.ThreadPoolExecutor() as pool:
        return list(pool.map(upload_images, images))
|
||||
|
||||
|
||||
# @RunTime
def upload_images(image_obj):
    """Upload a single layer image described by *image_obj* to OSS.

    image_obj: dict with
        'image_url'  — "bucket/object" style destination path,
        'image_type' — 'image', 'pattern_image', or anything else for masks,
        'image_obj'  — PIL image for the first two types, OpenCV ndarray for masks.

    Returns the original 'image_url' string.
    """
    bucket_name = image_obj['image_url'].split("/", 1)[0]
    object_name = image_obj['image_url'].split("/", 1)[1]
    if image_obj['image_type'] == 'image' or image_obj['image_type'] == 'pattern_image':
        # PIL image: serialize to PNG bytes in memory, then upload.
        image_data = io.BytesIO()
        image_obj['image_obj'].save(image_data, format='PNG')
        image_data.seek(0)
        image_bytes = image_data.read()
        req = oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes)
        return image_obj['image_url']
    else:
        mask_inverted = cv2.bitwise_not(image_obj['image_obj'])
        # Convert the 3-channel mask to 4 channels: white stays opaque,
        # black becomes fully transparent.
        rgba_image = cv2.cvtColor(mask_inverted, cv2.COLOR_BGR2BGRA)
        rgba_image[rgba_image[:, :, 0] == 0] = [0, 0, 0, 0]
        req = oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=cv2.imencode('.png', rgba_image)[1])
        return image_obj['image_url']
|
||||
|
||||
|
||||
def update_base_size_priority(layers, size):
    """Compute a canvas size that horizontally fits all layers and shift
    every layer so the leftmost one starts at horizontal offset 0.

    Args:
        layers: list of layer dicts; each has 'position' (a pair whose
            index-1 entry is the horizontal offset) and 'image' (an object
            with a ``.width`` attribute, or None for image-less layers).
        size: unused; kept so the existing call sites keep working.

    Returns:
        (layers, (new_width, new_height)) — the same list, mutated in place
        with an 'adaptive_position' key added per layer; the height is the
        fixed 700-pixel canvas used by the synthesis step.
    """
    # Leftmost horizontal offset across ALL layers (image-less ones included).
    min_x = min(info['position'][1] for info in layers)
    right_edges = [
        info['position'][1] + info['image'].width
        for info in layers
        if info['image'] is not None
    ]
    # Guard: if no layer carries an image, fall back to a zero-width canvas
    # instead of letting max() raise on an empty sequence.
    max_x = max(right_edges) if right_edges else min_x
    new_width = max_x - min_x
    new_height = 700  # fixed canvas height consumed by the synthesis step
    # Shift every layer left so the canvas starts at horizontal offset 0.
    for info in layers:
        info['adaptive_position'] = (info['position'][0], info['position'][1] - min_x)
    return layers, (new_width, new_height)
|
||||
@@ -1,31 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: UTF-8 -*-
|
||||
"""
|
||||
@Project :trinity_client
|
||||
@File :conversion_image.py
|
||||
@Author :周成融
|
||||
@Date :2023/8/21 10:40:29
|
||||
@detail :
|
||||
"""
|
||||
import numpy as np
|
||||
|
||||
|
||||
# def rgb_to_rgba(rgb_size, rgb_image, mask):
|
||||
# alpha_channel = np.full(rgb_size, 255, dtype=np.uint8)
|
||||
# # 创建四通道的结果图像
|
||||
# rgba_image = np.dstack((rgb_image, alpha_channel))
|
||||
# alpha_channel = np.where(mask > 0, 255, 0)
|
||||
# # 更新RGBA图像的透明度通道
|
||||
# rgba_image[:, :, 3] = alpha_channel
|
||||
# return rgba_image
|
||||
|
||||
def rgb_to_rgba(rgb_image, mask):
    """Attach an alpha channel derived from *mask* to a 3-channel image.

    Pixels where mask > 0 become fully opaque (255); all other pixels
    become fully transparent (0).
    """
    alpha = (mask > 0).astype(np.uint8) * 255
    return np.dstack((rgb_image, alpha))
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # NOTE(review): leftover debug stub — open("") raises immediately
    # (FileNotFoundError on an empty path); remove or point at a real file.
    image = open("")
|
||||
@@ -1,143 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: UTF-8 -*-
|
||||
"""
|
||||
@Project :trinity_client
|
||||
@File :design_ensemble.py
|
||||
@Author :周成融
|
||||
@Date :2023/8/16 19:36:21
|
||||
@detail :发起请求 获取推理结果
|
||||
"""
|
||||
import logging
|
||||
|
||||
import cv2
|
||||
import mmcv
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import tritonclient.http as httpclient
|
||||
|
||||
from app.core.config import *
|
||||
|
||||
"""
|
||||
keypoint
|
||||
预处理 推理 后处理
|
||||
"""
|
||||
|
||||
|
||||
def keypoint_preprocess(img_path):
    """Resize to 256x256, normalize, and NCHW-batch an image for the
    keypoint model.

    Returns:
        (preprocessed NCHW float array,
         (w_scale, h_scale) mapping original coordinates to the 256x256 input).
    """
    img = mmcv.imread(img_path)
    img_scale = (256, 256)
    h, w = img.shape[:2]
    img = cv2.resize(img, img_scale)
    w_scale = img_scale[0] / w
    h_scale = img_scale[1] / h
    # ImageNet mean/std normalization; BGR -> RGB.
    img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
    # HWC -> CHW and add the batch dimension.
    preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
    return preprocessed_img, (w_scale, h_scale)
|
||||
|
||||
|
||||
# @ RunTime
def get_keypoint_result(image, site):
    """Run keypoint inference on *image* via the Triton inference server.

    Args:
        image: image path or array accepted by the preprocessing step.
        site: body-region tag (e.g. "up"); selects the model
            ``keypoint_{site}_ocrnet_hr18``.

    Returns:
        ndarray of keypoint coordinates mapped back to original-image
        space, or None when preprocessing/inference fails (logged).
    """
    keypoint_result = None
    try:
        image, scale_factor = keypoint_preprocess(image)
        client = httpclient.InferenceServerClient(url=DESIGN_MODEL_URL)
        transformed_img = image.astype(np.float32)
        # Tensor names are fixed constants — no f-string interpolation needed.
        inputs = [httpclient.InferInput("input", transformed_img.shape, datatype="FP32")]
        inputs[0].set_data_from_numpy(transformed_img, binary_data=True)
        outputs = [httpclient.InferRequestedOutput("output", binary_data=True)]
        results = client.infer(model_name=f"keypoint_{site}_ocrnet_hr18", inputs=inputs, outputs=outputs)
        inference_output = torch.from_numpy(results.as_numpy('output'))
        keypoint_result = keypoint_postprocess(inference_output, scale_factor)
    except Exception as e:
        # Boundary: log and return None rather than propagate to the pipeline.
        logging.warning(f"get_keypoint_result : {e}")
    return keypoint_result
|
||||
|
||||
|
||||
def keypoint_postprocess(output, scale_factor):
    """Convert per-channel heatmaps into image-space keypoint coordinates.

    For every channel the argmax of the flattened heatmap is turned into a
    (row, col) pair, rescaled by the inverse preprocessing scale factors
    (infinite entries zeroed), multiplied by the model's stride of 4 and
    rounded up.
    """
    batch, channels = output.size(0), output.size(1)
    width = output.size(3)
    flat_argmax = torch.argmax(output.view(batch, channels, -1), dim=2).unsqueeze(dim=2)
    coords = torch.cat((flat_argmax / width, flat_argmax % width), dim=2).numpy()
    inv_scale = [1 / s for s in scale_factor[::-1]]
    scale_matrix = np.diag(inv_scale)
    # Zero out infinities produced by a zero scale factor.
    scale_matrix[np.isinf(scale_matrix)] = 0
    return np.ceil(np.dot(coords, scale_matrix) * 4)
|
||||
|
||||
|
||||
"""
|
||||
seg
|
||||
预处理 推理 后处理
|
||||
"""
|
||||
|
||||
|
||||
# KNet
def seg_preprocess(img_path):
    """Clamp an image to at most 1024 px per side, normalize, and
    NCHW-batch it for the segmentation model.

    Returns:
        (preprocessed NCHW float array, original (h, w) shape).
    """
    img = mmcv.imread(img_path)
    ori_shape = img.shape[:2]
    img_scale_w, img_scale_h = ori_shape
    if ori_shape[0] > 1024:
        img_scale_w = 1024
    if ori_shape[1] > 1024:
        img_scale_h = 1024
    # If either side exceeds 1024, resize that side down to 1024.
    if ori_shape != (img_scale_w, img_scale_h):
        # mmcv.imresize(img, img_scale_h, img_scale_w)  # old code — h and w were swapped; kept as a cautionary note
        img = cv2.resize(img, (img_scale_h, img_scale_w))
    # ImageNet mean/std normalization; BGR -> RGB.
    img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
    # HWC -> CHW and add the batch dimension.
    preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
    return preprocessed_img, ori_shape
|
||||
|
||||
|
||||
# @ RunTime
def get_seg_result(image_id, image):
    """Run garment segmentation on *image* via the Triton inference server.

    The raw model output is upsampled back to the original image shape by
    seg_postprocess before being returned.
    """
    image, ori_shape = seg_preprocess(image)
    client = httpclient.InferenceServerClient(url=f"{DESIGN_MODEL_URL}")
    transformed_img = image.astype(np.float32)
    # Input tensors (names come from the SEGMENTATION config).
    inputs = [
        httpclient.InferInput(SEGMENTATION['input'], transformed_img.shape, datatype="FP32")
    ]
    inputs[0].set_data_from_numpy(transformed_img, binary_data=True)
    # Requested output tensors.
    outputs = [
        httpclient.InferRequestedOutput(SEGMENTATION['output'], binary_data=True),
    ]
    # Run inference.
    results = client.infer(model_name=SEGMENTATION['new_model_name'], inputs=inputs, outputs=outputs)
    # Fetch the raw output and resize it to the original image shape.
    inference_output1 = results.as_numpy(SEGMENTATION['output'])
    seg_result = seg_postprocess(int(image_id), inference_output1, ori_shape)
    return seg_result
|
||||
|
||||
|
||||
# no cache
def seg_postprocess(image_id, output, ori_shape):
    """Upsample raw segmentation output back to the original image size.

    Args:
        image_id: unused; kept for call-site compatibility.
        output: NCHW array of raw model output.
        ori_shape: (h, w) target size.

    Returns:
        numpy array of shape (C, h, w) — the first (and only) batch entry.
    """
    logits = torch.tensor(output).float()
    resized = F.interpolate(logits, size=ori_shape, mode='bilinear', align_corners=False)
    return resized.cpu().numpy()[0]
|
||||
|
||||
|
||||
def key_point_show(image_path, key_point_result=None):
    """Debug helper: draw keypoints on the image and display it in a window.

    Blocks until a key is pressed; requires a GUI-capable OpenCV build.
    """
    img = cv2.imread(image_path)
    points_list = key_point_result
    point_size = 1
    point_color = (0, 0, 255)  # BGR (red)
    thickness = 4  # can be 0, 4 or 8
    for point in points_list:
        # Points arrive as (row, col); cv2.circle expects (x, y), hence the reversal.
        cv2.circle(img, point[::-1], point_size, point_color, thickness)
    cv2.imshow("0", img)
    cv2.waitKey(0)
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Manual smoke test against the Triton server; needs the sample image on disk.
    image = cv2.imread("9070101c-e5be-49b5-9602-4113a968969b.png")
    a = get_keypoint_result(image, "up")
    new_list = []
    # NOTE(review): this prints the BUILTIN `list` type, not any result —
    # almost certainly meant `print(a)` or `print(new_list)`.
    print(list)
    for i in a[0]:
        new_list.append((int(i[0]), int(i[1])))
    key_point_show("9070101c-e5be-49b5-9602-4113a968969b.png", new_list)
    # a = get_seg_result(1, image)
    print(a)
|
||||
@@ -1,99 +0,0 @@
|
||||
import redis
|
||||
|
||||
from app.core.config import REDIS_HOST, REDIS_PORT
|
||||
|
||||
|
||||
class Redis(object):
    """Thin wrapper around redis-py with plain-key and hash helpers.

    Every operation opens a fresh StrictRedis connection to database 0;
    keys written through :meth:`write` default to a 100-second TTL.
    """

    @staticmethod
    def _get_r():
        # New connection per call; always database 0.
        return redis.StrictRedis(REDIS_HOST, REDIS_PORT, 0)

    @classmethod
    def write(cls, key, value, expire=None):
        """Set *key* to *value* with a TTL (defaults to 100 seconds)."""
        ttl = expire if expire else 100
        cls._get_r().set(key, value, ex=ttl)

    @classmethod
    def read(cls, key):
        """Return the value for *key* decoded as UTF-8, or None when absent."""
        raw = cls._get_r().get(key)
        return raw.decode('utf-8') if raw else raw

    @classmethod
    def hset(cls, name, key, value):
        """Set one field of hash *name*."""
        cls._get_r().hset(name, key, value)

    @classmethod
    def hget(cls, name, key):
        """Return one field of hash *name* decoded as UTF-8, or None."""
        raw = cls._get_r().hget(name, key)
        return raw.decode('utf-8') if raw else raw

    @classmethod
    def hgetall(cls, name):
        """Return every field/value of hash *name* (raw bytes)."""
        return cls._get_r().hgetall(name)

    @classmethod
    def delete(cls, *names):
        """Delete one or more keys."""
        cls._get_r().delete(*names)

    @classmethod
    def hdel(cls, name, key):
        """Delete one field from hash *name*."""
        cls._get_r().hdel(name, key)

    @classmethod
    def expire(cls, name, expire=None):
        """(Re)set the TTL of *name* (defaults to 100 seconds)."""
        ttl = expire if expire else 100
        cls._get_r().expire(name, ttl)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
redis_client = Redis()
|
||||
# print(redis_client.write(key="1230", value=0))
|
||||
redis_client.write(key="1230", value=10)
|
||||
# print(redis_client.read(key="1230"))
|
||||
@@ -1,181 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: UTF-8 -*-
|
||||
"""
|
||||
@Project :trinity_client
|
||||
@File :synthesis_item.py
|
||||
@Author :周成融
|
||||
@Date :2023/8/26 14:13:04
|
||||
@detail :
|
||||
"""
|
||||
import io
|
||||
import logging
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from app.service.utils.generate_uuid import generate_uuid
|
||||
from app.service.utils.oss_client import oss_upload_image
|
||||
|
||||
|
||||
def positioning(all_mask_shape, mask_shape, offset):
|
||||
all_start = 0
|
||||
all_end = 0
|
||||
mask_start = 0
|
||||
mask_end = 0
|
||||
if offset == 0:
|
||||
all_start = 0
|
||||
all_end = min(all_mask_shape, mask_shape)
|
||||
|
||||
mask_start = 0
|
||||
mask_end = min(all_mask_shape, mask_shape)
|
||||
elif offset > 0:
|
||||
all_start = min(offset, all_mask_shape)
|
||||
all_end = min(offset + mask_shape, all_mask_shape)
|
||||
|
||||
mask_start = 0
|
||||
mask_end = 0 if offset > all_mask_shape else min(all_mask_shape - offset, mask_shape)
|
||||
elif offset < 0:
|
||||
if abs(offset) > mask_shape:
|
||||
all_start = 0
|
||||
all_end = 0
|
||||
else:
|
||||
all_start = 0
|
||||
if mask_shape - abs(offset) > all_mask_shape:
|
||||
all_end = min(mask_shape - abs(offset), all_mask_shape)
|
||||
else:
|
||||
all_end = mask_shape - abs(offset)
|
||||
|
||||
if abs(offset) > mask_shape:
|
||||
mask_start = mask_shape
|
||||
mask_end = mask_shape
|
||||
else:
|
||||
mask_start = abs(offset)
|
||||
if mask_shape - abs(offset) >= all_mask_shape:
|
||||
mask_end = all_mask_shape + abs(offset)
|
||||
else:
|
||||
mask_end = mask_shape
|
||||
return all_start, all_end, mask_start, mask_end
|
||||
|
||||
|
||||
# @RunTime
|
||||
def synthesis(data, size, basic_info):
|
||||
# 创建底图
|
||||
base_image = Image.new('RGBA', size, (0, 0, 0, 0))
|
||||
try:
|
||||
all_mask_shape = (size[1], size[0])
|
||||
body_mask = None
|
||||
for d in data:
|
||||
if d['name'] == 'body':
|
||||
# 创建一个新的宽高透明图像, 把模特贴上去获取mask
|
||||
transparent_image = Image.new("RGBA", size, (0, 0, 0, 0))
|
||||
transparent_image.paste(d['image'], (d['adaptive_position'][1], d['adaptive_position'][0]), d['image']) # 此处可变数组会被paste篡改值,所以使用下标获取position
|
||||
body_mask = np.array(transparent_image.split()[3])
|
||||
|
||||
# 根据新的坐标获取新的肩点
|
||||
left_shoulder = [x + y for x, y in zip(basic_info['body_point_test']['shoulder_left'], [d['adaptive_position'][1], d['adaptive_position'][0]])]
|
||||
right_shoulder = [x + y for x, y in zip(basic_info['body_point_test']['shoulder_right'], [d['adaptive_position'][1], d['adaptive_position'][0]])]
|
||||
body_mask[:min(left_shoulder[1], right_shoulder[1]), left_shoulder[0]:right_shoulder[0]] = 255
|
||||
_, binary_body_mask = cv2.threshold(body_mask, 127, 255, cv2.THRESH_BINARY)
|
||||
top_outer_mask = np.array(binary_body_mask)
|
||||
bottom_outer_mask = np.array(binary_body_mask)
|
||||
|
||||
top = True
|
||||
bottom = True
|
||||
i = len(data)
|
||||
while i:
|
||||
i -= 1
|
||||
if top and data[i]['name'] in ["blouse_front", "outwear_front", "dress_front", "tops_front"]:
|
||||
top = False
|
||||
mask_shape = data[i]['mask'].shape
|
||||
y_offset, x_offset = data[i]['adaptive_position']
|
||||
# 初始化叠加区域的起始和结束位置
|
||||
all_y_start, all_y_end, mask_y_start, mask_y_end = positioning(all_mask_shape=all_mask_shape[0], mask_shape=mask_shape[0], offset=y_offset)
|
||||
all_x_start, all_x_end, mask_x_start, mask_x_end = positioning(all_mask_shape=all_mask_shape[1], mask_shape=mask_shape[1], offset=x_offset)
|
||||
# 将叠加区域赋值为相应的像素值
|
||||
_, sketch_mask = cv2.threshold(data[i]['mask'], 127, 255, cv2.THRESH_BINARY)
|
||||
background = np.zeros_like(top_outer_mask)
|
||||
background[all_y_start:all_y_end, all_x_start:all_x_end] = sketch_mask[mask_y_start:mask_y_end, mask_x_start:mask_x_end]
|
||||
top_outer_mask = background + top_outer_mask
|
||||
elif bottom and data[i]['name'] in ["trousers_front", "skirt_front", "bottoms_front", "dress_front"]:
|
||||
bottom = False
|
||||
mask_shape = data[i]['mask'].shape
|
||||
y_offset, x_offset = data[i]['adaptive_position']
|
||||
# 初始化叠加区域的起始和结束位置
|
||||
all_y_start, all_y_end, mask_y_start, mask_y_end = positioning(all_mask_shape=all_mask_shape[0], mask_shape=mask_shape[0], offset=y_offset)
|
||||
all_x_start, all_x_end, mask_x_start, mask_x_end = positioning(all_mask_shape=all_mask_shape[1], mask_shape=mask_shape[1], offset=x_offset)
|
||||
# 将叠加区域赋值为相应的像素值
|
||||
_, sketch_mask = cv2.threshold(data[i]['mask'], 127, 255, cv2.THRESH_BINARY)
|
||||
background = np.zeros_like(top_outer_mask)
|
||||
background[all_y_start:all_y_end, all_x_start:all_x_end] = sketch_mask[mask_y_start:mask_y_end, mask_x_start:mask_x_end]
|
||||
bottom_outer_mask = background + bottom_outer_mask
|
||||
elif bottom is False and top is False:
|
||||
break
|
||||
|
||||
all_mask = cv2.bitwise_or(top_outer_mask, bottom_outer_mask)
|
||||
|
||||
for layer in data:
|
||||
if layer['image'] is not None:
|
||||
if layer['name'] != "body":
|
||||
test_image = Image.new('RGBA', size, (0, 0, 0, 0))
|
||||
test_image.paste(layer['image'], (layer['adaptive_position'][1], layer['adaptive_position'][0]), layer['image'])
|
||||
mask_data = np.where(all_mask > 0, 255, 0).astype(np.uint8)
|
||||
mask_alpha = Image.fromarray(mask_data)
|
||||
cropped_image = Image.composite(test_image, Image.new("RGBA", test_image.size, (255, 255, 255, 0)), mask_alpha)
|
||||
base_image.paste(test_image, (0, 0), cropped_image) # test_image 已经按照坐标贴到最大宽值的图片上 坐着这里坐标为00
|
||||
else:
|
||||
base_image.paste(layer['image'], (layer['adaptive_position'][1], layer['adaptive_position'][0]), layer['image'])
|
||||
|
||||
result_image = base_image
|
||||
|
||||
image_data = io.BytesIO()
|
||||
result_image.save(image_data, format='PNG')
|
||||
image_data.seek(0)
|
||||
|
||||
# oss upload
|
||||
image_bytes = image_data.read()
|
||||
bucket_name = "aida-results"
|
||||
object_name = f'result_{generate_uuid()}.png'
|
||||
req = oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes)
|
||||
return f"{bucket_name}/{object_name}"
|
||||
# return f"aida-results/{minio_client.put_object('aida-results', f'result_{generate_uuid()}.png', io.BytesIO(image_bytes), len(image_bytes), content_type='image/png').object_name}"
|
||||
|
||||
# object_name = f'result_{generate_uuid()}.png'
|
||||
# response = s3.put_object(Bucket="aida-results", Key=object_name, Body=data, ContentType='image/png')
|
||||
# object_url = f"aida-results/{object_name}"
|
||||
# if response['ResponseMetadata']['HTTPStatusCode'] == 200:
|
||||
# return object_url
|
||||
# else:
|
||||
# return ""
|
||||
|
||||
except Exception as e:
|
||||
logging.warning(f"synthesis runtime exception : {e}")
|
||||
|
||||
|
||||
def synthesis_single(front_image, back_image):
|
||||
result_image = None
|
||||
if front_image:
|
||||
result_image = front_image
|
||||
if back_image:
|
||||
result_image.paste(back_image, (0, 0), back_image)
|
||||
|
||||
# with io.BytesIO() as output:
|
||||
# result_image.save(output, format='PNG')
|
||||
# data = output.getvalue()
|
||||
# object_name = f'result_{generate_uuid()}.png'
|
||||
# response = s3.put_object(Bucket="aida-results", Key=object_name, Body=data, ContentType='image/png')
|
||||
# object_url = f"aida-results/{object_name}"
|
||||
# if response['ResponseMetadata']['HTTPStatusCode'] == 200:
|
||||
# return object_url
|
||||
# else:
|
||||
# return ""
|
||||
image_data = io.BytesIO()
|
||||
result_image.save(image_data, format='PNG')
|
||||
image_data.seek(0)
|
||||
image_bytes = image_data.read()
|
||||
# return f"aida-results/{minio_client.put_object('aida-results', f'result_{generate_uuid()}.png', io.BytesIO(image_bytes), len(image_bytes), content_type='image/png').object_name}"
|
||||
# oss upload
|
||||
bucket_name = 'aida-results'
|
||||
object_name = f'result_{generate_uuid()}.png'
|
||||
req = oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes)
|
||||
return f"{bucket_name}/{object_name}"
|
||||
@@ -4,7 +4,7 @@ import threading
|
||||
from celery import Celery
|
||||
from minio import Minio
|
||||
|
||||
from app.core.config import *
|
||||
from app.core.config import settings
|
||||
from app.service.design_batch.item import BodyItem, TopItem, BottomItem, OthersItem
|
||||
from app.service.design_batch.utils.MQ import publish_status
|
||||
from app.service.design_batch.utils.organize import organize_body, organize_clothing, organize_others
|
||||
@@ -12,12 +12,12 @@ from app.service.design_batch.utils.save_json import oss_upload_json
|
||||
from app.service.design_batch.utils.synthesis_item import update_base_size_priority, synthesis, synthesis_single
|
||||
|
||||
id_lock = threading.Lock()
|
||||
celery_app = Celery('tasks', broker=f'amqp://rabbit:123456@18.167.251.121:5672//', backend='rpc://', BROKER_CONNECTION_RETRY_ON_STARTUP=True)
|
||||
celery_app = Celery('tasks', broker=f'amqp://{settings.MQ_USERNAME}:{settings.MQ_PASSWORD}@{settings.MQ_HOST}:{settings.MQ_PORT}//', backend='rpc://')
|
||||
celery_app.conf.worker_log_format = '%(asctime)s %(filename)s [line:%(lineno)d] %(levelname)s %(message)s'
|
||||
celery_app.conf.worker_hijack_root_logger = False
|
||||
logging.getLogger('pika').setLevel(logging.WARNING)
|
||||
logger = logging.getLogger()
|
||||
minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||
minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)
|
||||
|
||||
print("start")
|
||||
|
||||
@@ -51,10 +51,12 @@ def process_layer(item, layers):
|
||||
front_layer, back_layer = organize_others(item)
|
||||
layers.append(front_layer)
|
||||
layers.append(back_layer)
|
||||
return None
|
||||
else:
|
||||
front_layer, back_layer = organize_clothing(item)
|
||||
layers.append(front_layer)
|
||||
layers.append(back_layer)
|
||||
return None
|
||||
|
||||
|
||||
@celery_app.task
|
||||
@@ -76,12 +78,11 @@ def batch_design(objects_data, tasks_id, json_name):
|
||||
for item in object['items']:
|
||||
item_results.append(process_item(item, basic))
|
||||
layers = []
|
||||
body_size = None
|
||||
for item in item_results:
|
||||
body_size = process_layer(item, layers)
|
||||
process_layer(item, layers)
|
||||
layers = sorted(layers, key=lambda s: s.get("priority", float('inf')))
|
||||
|
||||
layers, new_size = update_base_size_priority(layers, body_size)
|
||||
layers, new_size = update_base_size_priority(layers)
|
||||
|
||||
for lay in layers:
|
||||
items_response['layers'].append({
|
||||
|
||||
@@ -18,11 +18,11 @@ class BackPerspective:
|
||||
result['back_perspective_url'] = file_path
|
||||
return result
|
||||
else:
|
||||
seg_result = get_seg_result("1", result['image'])[0]
|
||||
seg_result = get_seg_result(result['image'])[0]
|
||||
elif result['name'] in ['blouse', 'outwear', 'dress', 'tops']:
|
||||
seg_result = result['seg_result']
|
||||
else:
|
||||
seg_result = get_seg_result("1", result['image'])[0]
|
||||
seg_result = get_seg_result(result['image'])[0]
|
||||
|
||||
m = self.thicken_contours_and_display(seg_result, thickness=10, color=(0, 0, 0))
|
||||
back_sketch = result['image'].copy()
|
||||
@@ -34,7 +34,8 @@ class BackPerspective:
|
||||
result['back_perspective_url'] = f"{resp.bucket_name}/{resp.object_name}"
|
||||
return result
|
||||
|
||||
def thicken_contours_and_display(self, mask, thickness=10, color=(0, 0, 0)):
|
||||
@staticmethod
|
||||
def thicken_contours_and_display(mask, thickness=10, color=(0, 0, 0)):
|
||||
mask = mask.astype(np.uint8) * 255
|
||||
# 查找轮廓
|
||||
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||
@@ -48,9 +49,9 @@ class BackPerspective:
|
||||
# 在空白图像上绘制白色的轮廓
|
||||
cv2.drawContours(blank, [contour], -1, 255, thickness=thick)
|
||||
# 找到轮廓的中心(可以用重心等方法近似)
|
||||
M = cv2.moments(contour)
|
||||
cx = int(M['m10'] / M['m00'])
|
||||
cy = int(M['m01'] / M['m00'])
|
||||
m = cv2.moments(contour)
|
||||
cx = int(m['m10'] / m['m00'])
|
||||
cy = int(m['m01'] / m['m00'])
|
||||
# 进行距离变换,离中心越近的值越小
|
||||
dist_transform = cv2.distanceTransform(255 - blank, cv2.DIST_L2, 5)
|
||||
# 根据距离变换的值来决定是否保留像素,离中心近的像素更容易被保留
|
||||
|
||||
@@ -79,9 +79,9 @@ class Color:
|
||||
def get_pattern(single_color):
|
||||
if single_color is None:
|
||||
raise False
|
||||
R, G, B = single_color.split(' ')
|
||||
r, g, b = single_color.split(' ')
|
||||
pattern = np.zeros([1, 1, 3], np.uint8)
|
||||
pattern[0, 0, 0] = int(B)
|
||||
pattern[0, 0, 1] = int(G)
|
||||
pattern[0, 0, 2] = int(R)
|
||||
pattern[0, 0, 0] = int(b)
|
||||
pattern[0, 0, 1] = int(g)
|
||||
pattern[0, 0, 2] = int(r)
|
||||
return pattern
|
||||
|
||||
@@ -3,7 +3,7 @@ import logging
|
||||
import numpy as np
|
||||
from pymilvus import MilvusClient
|
||||
|
||||
from app.core.config import *
|
||||
from app.core.config import KEYPOINT_RESULT_TABLE_FIELD_SET, MILVUS_TABLE_KEYPOINT, settings
|
||||
from app.service.design_fast.utils.design_ensemble import get_keypoint_result
|
||||
from app.service.utils.decorator import ClassCallRunTime, RunTime
|
||||
|
||||
@@ -21,12 +21,12 @@ class KeyPoint:
|
||||
def __call__(self, result):
|
||||
if result['name'] in ['blouse', 'skirt', 'dress', 'outwear', 'trousers', 'tops', 'bottoms']: # 查询是否有数据 且类别相同 相同则直接读 不同则推理后更新
|
||||
# result['clothes_keypoint'] = self.infer_keypoint_result(result)
|
||||
site = 'up' if result['name'] in ['blouse', 'outwear', 'dress', 'tops'] else 'down'
|
||||
# 'up' if result['name'] in ['blouse', 'outwear', 'dress', 'tops'] else 'down'
|
||||
# keypoint_cache = search_keypoint_cache(result["image_id"], site)
|
||||
# keypoint_cache = self.keypoint_cache(result, site)
|
||||
keypoint_cache = False
|
||||
# 取消向量查询 直接过模型推理
|
||||
if keypoint_cache is False:
|
||||
if not keypoint_cache:
|
||||
keypoint_infer_result, site = self.infer_keypoint_result(result)
|
||||
result['clothes_keypoint'] = self.save_keypoint_cache(result["image_id"], keypoint_infer_result, site)
|
||||
else:
|
||||
@@ -55,8 +55,8 @@ class KeyPoint:
|
||||
}
|
||||
]
|
||||
try:
|
||||
client = MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS)
|
||||
res = client.upsert(collection_name=MILVUS_TABLE_KEYPOINT, data=data)
|
||||
client = MilvusClient(uri=settings.MILVUS_URL, token=settings.MILVUS_TOKEN, db_name=settings.MILVUS_ALIAS)
|
||||
client.upsert(collection_name=MILVUS_TABLE_KEYPOINT, data=data)
|
||||
client.close()
|
||||
return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
|
||||
except Exception as e:
|
||||
@@ -79,7 +79,7 @@ class KeyPoint:
|
||||
]
|
||||
|
||||
try:
|
||||
client = MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS)
|
||||
client = MilvusClient(uri=settings.MILVUS_URL, token=settings.MILVUS_TOKEN, db_name=settings.MILVUS_ALIAS)
|
||||
client.upsert(
|
||||
collection_name=MILVUS_TABLE_KEYPOINT,
|
||||
data=data
|
||||
@@ -92,7 +92,7 @@ class KeyPoint:
|
||||
@RunTime
|
||||
def keypoint_cache(self, result, site):
|
||||
try:
|
||||
client = MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS)
|
||||
client = MilvusClient(uri=settings.MILVUS_URL, token=settings.MILVUS_TOKEN, db_name=settings.MILVUS_ALIAS)
|
||||
keypoint_id = result['image_id']
|
||||
res = client.query(
|
||||
collection_name=MILVUS_TABLE_KEYPOINT,
|
||||
|
||||
@@ -1,9 +1,6 @@
|
||||
import io
|
||||
import logging
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from app.service.utils.new_oss_client import oss_get_image
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ from app.service.utils.new_oss_client import oss_get_image
|
||||
|
||||
class PrintPainting:
|
||||
def __init__(self, minio_client):
|
||||
self.random_seed = None
|
||||
self.minio_client = minio_client
|
||||
|
||||
def __call__(self, result):
|
||||
@@ -408,7 +409,7 @@ class PrintPainting:
|
||||
change_mask = print_mask[start_h: length_h, start_w: length_w]
|
||||
# get real part into change mask
|
||||
_, change_mask = cv2.threshold(change_mask, 220, 255, cv2.THRESH_BINARY)
|
||||
mask = cv2.bitwise_not(painting_dict['mask_inv_print'])
|
||||
cv2.bitwise_not(painting_dict['mask_inv_print'])
|
||||
img_fg[start_h:start_h + painting_dict['dim_print_h'], start_w:start_w + painting_dict['dim_print_w'], :] = change_region
|
||||
|
||||
clothes_mask_print = cv2.bitwise_not(print_mask)
|
||||
|
||||
@@ -4,7 +4,7 @@ import os
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from app.core.config import SEG_CACHE_PATH
|
||||
from app.core.config import settings
|
||||
from app.service.design_fast.utils.design_ensemble import get_seg_result
|
||||
from app.service.utils.decorator import ClassCallRunTime
|
||||
from app.service.utils.new_oss_client import oss_get_image
|
||||
@@ -36,11 +36,11 @@ class Segmentation:
|
||||
# preview 过模型 不缓存
|
||||
if "preview_submit" in result.keys() and result['preview_submit'] == "preview":
|
||||
# 推理获得seg 结果
|
||||
seg_result = get_seg_result(result["image_id"], result['image'])
|
||||
seg_result = get_seg_result(result['image'])
|
||||
# submit 过模型 缓存
|
||||
elif "preview_submit" in result.keys() and result['preview_submit'] == "submit":
|
||||
# 推理获得seg 结果
|
||||
seg_result = get_seg_result(result["image_id"], result['image'])
|
||||
seg_result = get_seg_result(result['image'])
|
||||
self.save_seg_result(seg_result, result['image_id'])
|
||||
# null 正常流程 加载本地缓存 无缓存则过模型
|
||||
else:
|
||||
@@ -49,7 +49,7 @@ class Segmentation:
|
||||
# 判断缓存和实际图片size是否相同
|
||||
if not _ or result["image"].shape[:2] != seg_result.shape:
|
||||
# 推理获得seg 结果
|
||||
seg_result = get_seg_result(result["image_id"], result['image'])
|
||||
seg_result = get_seg_result(result['image'])
|
||||
self.save_seg_result(seg_result, result['image_id'])
|
||||
result['seg_result'] = seg_result
|
||||
|
||||
@@ -63,7 +63,7 @@ class Segmentation:
|
||||
|
||||
@staticmethod
|
||||
def save_seg_result(seg_result, image_id):
|
||||
file_path = f"{SEG_CACHE_PATH}{image_id}.npy"
|
||||
file_path = f"{settings.SEG_CACHE_PATH}{image_id}.npy"
|
||||
try:
|
||||
np.save(file_path, seg_result)
|
||||
logger.debug(f"保存成功 :{os.path.abspath(file_path)}")
|
||||
@@ -72,7 +72,7 @@ class Segmentation:
|
||||
|
||||
@staticmethod
|
||||
def load_seg_result(image_id):
|
||||
file_path = f"{SEG_CACHE_PATH}{image_id}.npy"
|
||||
file_path = f"{settings.SEG_CACHE_PATH}{image_id}.npy"
|
||||
# logger.info(f"load seg file name is :{SEG_CACHE_PATH}{image_id}.npy")
|
||||
try:
|
||||
seg_result = np.load(file_path)
|
||||
|
||||
@@ -4,9 +4,7 @@ import logging
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from cv2 import cvtColor, COLOR_BGR2RGBA
|
||||
|
||||
from app.core.config import AIDA_CLOTHING
|
||||
from app.service.design_fast.utils.conversion_image import rgb_to_rgba
|
||||
from app.service.design_fast.utils.transparent import sketch_to_transparent
|
||||
from app.service.design_fast.utils.upload_image import upload_png_mask
|
||||
@@ -40,7 +38,7 @@ class Split(object):
|
||||
result_front_image = np.zeros_like(rgba_image)
|
||||
front_mask = cv2.resize(front_mask, new_size)
|
||||
result_front_image[front_mask != 0] = rgba_image[front_mask != 0]
|
||||
result_front_image_pil = Image.fromarray(cvtColor(result_front_image, COLOR_BGR2RGBA))
|
||||
result_front_image_pil = Image.fromarray(cv2.cvtColor(result_front_image, cv2.COLOR_BGR2RGBA))
|
||||
if 'transparent' in result.keys():
|
||||
# 用户自选区域transparent
|
||||
transparent = result['transparent']
|
||||
@@ -98,21 +96,21 @@ class Split(object):
|
||||
result_back_image = np.zeros_like(rgba_image)
|
||||
back_mask = cv2.resize(back_mask, new_size)
|
||||
result_back_image[back_mask != 0] = rgba_image[back_mask != 0]
|
||||
result_back_image_pil = Image.fromarray(cvtColor(result_back_image, COLOR_BGR2RGBA))
|
||||
result_back_image_pil = Image.fromarray(cv2.cvtColor(result_back_image, cv2.COLOR_BGR2RGBA))
|
||||
result['back_image'], result["back_image_url"], _ = upload_png_mask(self.minio_client, result_back_image_pil, f'{generate_uuid()}', mask=None)
|
||||
mask_image[back_mask != 0] = [0, 255, 0]
|
||||
|
||||
rbga_mask = rgb_to_rgba(mask_image, front_mask + back_mask)
|
||||
mask_pil = Image.fromarray(cvtColor(rbga_mask.astype(np.uint8), COLOR_BGR2RGBA))
|
||||
mask_pil = Image.fromarray(cv2.cvtColor(rbga_mask.astype(np.uint8), cv2.COLOR_BGR2RGBA))
|
||||
image_data = io.BytesIO()
|
||||
mask_pil.save(image_data, format='PNG')
|
||||
image_data.seek(0)
|
||||
image_bytes = image_data.read()
|
||||
req = oss_upload_image(oss_client=self.minio_client, bucket=AIDA_CLOTHING, object_name=f"mask/mask_{generate_uuid()}.png", image_bytes=image_bytes)
|
||||
req = oss_upload_image(oss_client=self.minio_client, bucket="aida-clothing", object_name=f"mask/mask_{generate_uuid()}.png", image_bytes=image_bytes)
|
||||
result['mask_url'] = req.bucket_name + "/" + req.object_name
|
||||
# 创建中间图层
|
||||
result_pattern_image_rgba = rgb_to_rgba(result['pattern_image'], result['mask'])
|
||||
result_pattern_image_pil = Image.fromarray(cvtColor(result_pattern_image_rgba, COLOR_BGR2RGBA))
|
||||
result_pattern_image_pil = Image.fromarray(cv2.cvtColor(result_pattern_image_rgba, cv2.COLOR_BGR2RGBA))
|
||||
result['pattern_image'], result['pattern_image_url'], _ = upload_png_mask(self.minio_client, result_pattern_image_pil, f'{generate_uuid()}')
|
||||
return result
|
||||
except Exception as e:
|
||||
|
||||
@@ -2,16 +2,17 @@ import json
|
||||
|
||||
import pika
|
||||
|
||||
from app.core.config import RABBITMQ_PARAMS, BATCH_DESIGN_RABBITMQ_QUEUES
|
||||
from app.core.config import settings
|
||||
from app.core.rabbit_mq_config import RABBITMQ_PARAMS
|
||||
|
||||
|
||||
def publish_status(task_id, progress, result):
|
||||
connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
|
||||
channel = connection.channel()
|
||||
channel.queue_declare(queue=BATCH_DESIGN_RABBITMQ_QUEUES, durable=True)
|
||||
channel.queue_declare(queue=settings.BATCH_DESIGN_RABBITMQ_QUEUES, durable=True)
|
||||
message = {'task_id': task_id, 'progress': progress, "result": result}
|
||||
channel.basic_publish(exchange='',
|
||||
routing_key=BATCH_DESIGN_RABBITMQ_QUEUES,
|
||||
routing_key=settings.BATCH_DESIGN_RABBITMQ_QUEUES,
|
||||
body=json.dumps(message),
|
||||
properties=pika.BasicProperties(
|
||||
delivery_mode=2,
|
||||
|
||||
@@ -16,7 +16,7 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
import tritonclient.http as httpclient
|
||||
|
||||
from app.core.config import *
|
||||
from app.core.config import DESIGN_MODEL_URL, DESIGN_MODEL_NAME
|
||||
|
||||
"""
|
||||
keypoint
|
||||
@@ -91,29 +91,29 @@ def seg_preprocess(img_path):
|
||||
|
||||
|
||||
# @ RunTime
|
||||
def get_seg_result(image_id, image):
|
||||
def get_seg_result(image):
|
||||
image, ori_shape = seg_preprocess(image)
|
||||
client = httpclient.InferenceServerClient(url=f"{DESIGN_MODEL_URL}")
|
||||
transformed_img = image.astype(np.float32)
|
||||
# 输入集
|
||||
inputs = [
|
||||
httpclient.InferInput(SEGMENTATION['input'], transformed_img.shape, datatype="FP32")
|
||||
httpclient.InferInput(DESIGN_MODEL_NAME, transformed_img.shape, datatype="FP32")
|
||||
]
|
||||
inputs[0].set_data_from_numpy(transformed_img, binary_data=True)
|
||||
# 输出集
|
||||
outputs = [
|
||||
httpclient.InferRequestedOutput(SEGMENTATION['output'], binary_data=True),
|
||||
httpclient.InferRequestedOutput("seg_input__0", binary_data=True),
|
||||
]
|
||||
results = client.infer(model_name=SEGMENTATION['new_model_name'], inputs=inputs, outputs=outputs)
|
||||
results = client.infer(model_name=DESIGN_MODEL_NAME, inputs=inputs, outputs=outputs)
|
||||
# 推理
|
||||
# 取结果
|
||||
inference_output1 = results.as_numpy(SEGMENTATION['output'])
|
||||
seg_result = seg_postprocess(int(image_id), inference_output1, ori_shape)
|
||||
inference_output1 = results.as_numpy("seg_input__0")
|
||||
seg_result = seg_postprocess(inference_output1, ori_shape)
|
||||
return seg_result
|
||||
|
||||
|
||||
# no cache
|
||||
def seg_postprocess(image_id, output, ori_shape):
|
||||
def seg_postprocess(output, ori_shape):
|
||||
seg_logit = F.interpolate(torch.tensor(output).float(), size=ori_shape, scale_factor=None, mode='bilinear', align_corners=False)
|
||||
seg_pred = seg_logit.cpu().numpy()
|
||||
return seg_pred[0]
|
||||
|
||||
@@ -98,6 +98,8 @@ def calculate_start_point(keypoint_type, scale, clothes_point, body_point, offse
|
||||
"""
|
||||
Align left
|
||||
Args:
|
||||
offset:
|
||||
resize_scale:
|
||||
keypoint_type: string, "waistband" | "shoulder" | "ear_point"
|
||||
scale: float
|
||||
clothes_point: dict{'left': [x1, y1, z1], 'right': [x2, y2, z2]}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import logging
|
||||
|
||||
from app.service.design_fast.utils.redis_utils import Redis
|
||||
from app.service.utils.redis_utils import Redis
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -1,99 +0,0 @@
|
||||
import redis
|
||||
|
||||
from app.core.config import REDIS_HOST, REDIS_PORT
|
||||
|
||||
|
||||
class Redis(object):
|
||||
"""
|
||||
redis数据库操作
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _get_r():
|
||||
host = REDIS_HOST
|
||||
port = REDIS_PORT
|
||||
db = 0
|
||||
r = redis.StrictRedis(host, port, db)
|
||||
return r
|
||||
|
||||
@classmethod
|
||||
def write(cls, key, value, expire=None):
|
||||
"""
|
||||
写入键值对
|
||||
"""
|
||||
# 判断是否有过期时间,没有就设置默认值
|
||||
if expire:
|
||||
expire_in_seconds = expire
|
||||
else:
|
||||
expire_in_seconds = 100
|
||||
r = cls._get_r()
|
||||
r.set(key, value, ex=expire_in_seconds)
|
||||
|
||||
@classmethod
|
||||
def read(cls, key):
|
||||
"""
|
||||
读取键值对内容
|
||||
"""
|
||||
r = cls._get_r()
|
||||
value = r.get(key)
|
||||
return value.decode('utf-8') if value else value
|
||||
|
||||
@classmethod
|
||||
def hset(cls, name, key, value):
|
||||
"""
|
||||
写入hash表
|
||||
"""
|
||||
r = cls._get_r()
|
||||
r.hset(name, key, value)
|
||||
|
||||
@classmethod
|
||||
def hget(cls, name, key):
|
||||
"""
|
||||
读取指定hash表的键值
|
||||
"""
|
||||
r = cls._get_r()
|
||||
value = r.hget(name, key)
|
||||
return value.decode('utf-8') if value else value
|
||||
|
||||
@classmethod
|
||||
def hgetall(cls, name):
|
||||
"""
|
||||
获取指定hash表所有的值
|
||||
"""
|
||||
r = cls._get_r()
|
||||
return r.hgetall(name)
|
||||
|
||||
@classmethod
|
||||
def delete(cls, *names):
|
||||
"""
|
||||
删除一个或者多个
|
||||
"""
|
||||
r = cls._get_r()
|
||||
r.delete(*names)
|
||||
|
||||
@classmethod
|
||||
def hdel(cls, name, key):
|
||||
"""
|
||||
删除指定hash表的键值
|
||||
"""
|
||||
r = cls._get_r()
|
||||
r.hdel(name, key)
|
||||
|
||||
@classmethod
|
||||
def expire(cls, name, expire=None):
|
||||
"""
|
||||
设置过期时间
|
||||
"""
|
||||
if expire:
|
||||
expire_in_seconds = expire
|
||||
else:
|
||||
expire_in_seconds = 100
|
||||
r = cls._get_r()
|
||||
r.expire(name, expire_in_seconds)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
redis_client = Redis()
|
||||
# print(redis_client.write(key="1230", value=0))
|
||||
redis_client.write(key="1230", value=10)
|
||||
# print(redis_client.read(key="1230"))
|
||||
@@ -13,9 +13,12 @@ import logging
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from minio import Minio
|
||||
from app.core.config import settings
|
||||
from app.service.utils.generate_uuid import generate_uuid
|
||||
from app.service.utils.oss_client import oss_upload_image
|
||||
from app.service.utils.new_oss_client import oss_upload_image
|
||||
|
||||
minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)
|
||||
|
||||
|
||||
def positioning(all_mask_shape, mask_shape, offset):
|
||||
@@ -136,7 +139,7 @@ def synthesis(data, size, basic_info):
|
||||
image_bytes = image_data.read()
|
||||
bucket_name = "aida-results"
|
||||
object_name = f'result_{generate_uuid()}.png'
|
||||
req = oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes)
|
||||
oss_upload_image(oss_client=minio_client, bucket=bucket_name, object_name=object_name, image_bytes=image_bytes)
|
||||
return f"{bucket_name}/{object_name}"
|
||||
# return f"aida-results/{minio_client.put_object('aida-results', f'result_{generate_uuid()}.png', io.BytesIO(image_bytes), len(image_bytes), content_type='image/png').object_name}"
|
||||
|
||||
@@ -177,11 +180,11 @@ def synthesis_single(front_image, back_image):
|
||||
# oss upload
|
||||
bucket_name = 'aida-results'
|
||||
object_name = f'result_{generate_uuid()}.png'
|
||||
req = oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes)
|
||||
oss_upload_image(oss_client=minio_client, bucket=bucket_name, object_name=object_name, image_bytes=image_bytes)
|
||||
return f"{bucket_name}/{object_name}"
|
||||
|
||||
|
||||
def update_base_size_priority(layers, size):
|
||||
def update_base_size_priority(layers):
|
||||
# 计算透明背景图片的宽度
|
||||
min_x = min(info['position'][1] for info in layers)
|
||||
x_list = []
|
||||
|
||||
@@ -12,7 +12,6 @@ import logging
|
||||
|
||||
import cv2
|
||||
|
||||
from app.core.config import *
|
||||
from app.service.utils.new_oss_client import oss_upload_image
|
||||
|
||||
|
||||
@@ -25,15 +24,15 @@ def upload_png_mask(minio_client, front_image, object_name, mask=None):
|
||||
# 将掩模的3通道转换为4通道,白色部分不透明,黑色部分透明
|
||||
rgba_image = cv2.cvtColor(mask_inverted, cv2.COLOR_BGR2BGRA)
|
||||
rgba_image[rgba_image[:, :, 0] == 0] = [0, 0, 0, 0]
|
||||
req = oss_upload_image(oss_client=minio_client, bucket=AIDA_CLOTHING, object_name=f"mask/mask_{object_name}.png", image_bytes=cv2.imencode('.png', rgba_image)[1])
|
||||
mask_url = f"{AIDA_CLOTHING}/mask/mask_{object_name}.png"
|
||||
req = oss_upload_image(oss_client=minio_client, bucket="aida-clothing", object_name=f"mask/mask_{object_name}.png", image_bytes=cv2.imencode('.png', rgba_image)[1])
|
||||
mask_url = f"aida-clothing/mask/mask_{object_name}.png"
|
||||
|
||||
image_data = io.BytesIO()
|
||||
front_image.save(image_data, format='PNG')
|
||||
image_data.seek(0)
|
||||
image_bytes = image_data.read()
|
||||
req = oss_upload_image(oss_client=minio_client, bucket=AIDA_CLOTHING, object_name=f"image/image_{object_name}.png", image_bytes=image_bytes)
|
||||
image_url = f"{AIDA_CLOTHING}/image/image_{object_name}.png"
|
||||
req = oss_upload_image(oss_client=minio_client, bucket="aida-clothing", object_name=f"image/image_{object_name}.png", image_bytes=image_bytes)
|
||||
image_url = f"aida-clothing/image/image_{object_name}.png"
|
||||
return front_image, image_url, mask_url
|
||||
except Exception as e:
|
||||
logging.warning(f"upload_png_mask runtime exception : {e}")
|
||||
|
||||
@@ -5,7 +5,7 @@ import time
|
||||
import requests
|
||||
from minio import Minio
|
||||
|
||||
from app.core.config import *
|
||||
from app.core.config import settings
|
||||
from app.service.design_fast.item import BodyItem, TopItem, BottomItem, OthersItem
|
||||
from app.service.design_fast.utils.organize import organize_body, organize_clothing, organize_others
|
||||
from app.service.design_fast.utils.progress import final_progress, update_progress
|
||||
@@ -16,7 +16,7 @@ id_lock = threading.Lock()
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||
minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)
|
||||
|
||||
|
||||
def process_item(item, basic):
|
||||
@@ -48,10 +48,12 @@ def process_layer(item, layers):
|
||||
front_layer, back_layer = organize_others(item)
|
||||
layers.append(front_layer)
|
||||
layers.append(back_layer)
|
||||
return None
|
||||
else:
|
||||
front_layer, back_layer = organize_clothing(item)
|
||||
layers.append(front_layer)
|
||||
layers.append(back_layer)
|
||||
return None
|
||||
|
||||
|
||||
@RunTime
|
||||
@@ -73,12 +75,11 @@ def design_generate(request_data):
|
||||
for item in object['items']:
|
||||
item_results.append(process_item(item, basic))
|
||||
layers = []
|
||||
body_size = None
|
||||
for item in item_results:
|
||||
body_size = process_layer(item, layers)
|
||||
process_layer(item, layers)
|
||||
layers = sorted(layers, key=lambda s: s.get("priority", float('inf')))
|
||||
|
||||
layers, new_size = update_base_size_priority(layers, body_size)
|
||||
layers, new_size = update_base_size_priority(layers)
|
||||
# pattern_overall_image_url 、 pattern_print_image_url
|
||||
for lay in layers:
|
||||
items_response['layers'].append({
|
||||
@@ -149,7 +150,7 @@ def design_generate_v2(request_data):
|
||||
request_id = request_data.requestId
|
||||
threads = []
|
||||
|
||||
def process_object(step, object, callback_url):
|
||||
def process_object(object, callback_url):
|
||||
basic = object['basic']
|
||||
items_response = {
|
||||
'layers': [],
|
||||
@@ -161,12 +162,11 @@ def design_generate_v2(request_data):
|
||||
for item in object['items']:
|
||||
item_results.append(process_item(item, basic))
|
||||
layers = []
|
||||
body_size = None
|
||||
for item in item_results:
|
||||
body_size = process_layer(item, layers)
|
||||
process_layer(item, layers)
|
||||
layers = sorted(layers, key=lambda s: s.get("priority", float('inf')))
|
||||
|
||||
layers, new_size = update_base_size_priority(layers, body_size)
|
||||
layers, new_size = update_base_size_priority(layers)
|
||||
|
||||
for lay in layers:
|
||||
items_response['layers'].append({
|
||||
@@ -229,7 +229,7 @@ def design_generate_v2(request_data):
|
||||
logger.info(response.text)
|
||||
|
||||
for step, object in enumerate(objects_data):
|
||||
t = threading.Thread(target=process_object, args=(step, object, callback_url))
|
||||
t = threading.Thread(target=process_object, args=(object, callback_url))
|
||||
threads.append(t)
|
||||
t.start()
|
||||
|
||||
|
||||
@@ -1,13 +1,18 @@
|
||||
import io
|
||||
|
||||
from app.service.utils.oss_client import oss_get_image, oss_upload_image
|
||||
from minio import Minio
|
||||
from app.core.config import settings
|
||||
|
||||
from app.service.utils.new_oss_client import oss_get_image, oss_upload_image
|
||||
|
||||
minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)
|
||||
|
||||
|
||||
def model_transpose(image_path):
|
||||
bucket = image_path.split("/", 1)[0]
|
||||
object_name = image_path.split("/", 1)[1]
|
||||
new_object_name = f'{object_name[:object_name.rfind(".")]}.png'
|
||||
image = oss_get_image(bucket=bucket, object_name=object_name, data_type="PIL")
|
||||
image = oss_get_image(oss_client=minio_client, bucket=bucket, object_name=object_name, data_type="PIL")
|
||||
image = image.convert("RGBA")
|
||||
data = image.getdata()
|
||||
#
|
||||
@@ -23,6 +28,6 @@ def model_transpose(image_path):
|
||||
image.save(image_data, format='PNG')
|
||||
image_data.seek(0)
|
||||
image_bytes = image_data.read()
|
||||
oss_upload_image(bucket=bucket, object_name=new_object_name, image_bytes=image_bytes)
|
||||
oss_upload_image(oss_client=minio_client, bucket=bucket, object_name=new_object_name, image_bytes=image_bytes)
|
||||
image_path = f"{bucket}/{new_object_name}"
|
||||
return image_path
|
||||
@@ -18,11 +18,11 @@ class BackPerspective:
|
||||
result['back_perspective_url'] = file_path
|
||||
return result
|
||||
else:
|
||||
seg_result = get_seg_result("1", result['image'])[0]
|
||||
seg_result = get_seg_result(result['image'])[0]
|
||||
elif result['name'] in ['blouse', 'outwear', 'dress', 'tops']:
|
||||
seg_result = result['seg_result']
|
||||
else:
|
||||
seg_result = get_seg_result("1", result['image'])[0]
|
||||
seg_result = get_seg_result(result['image'])[0]
|
||||
|
||||
m = self.thicken_contours_and_display(seg_result, thickness=10, color=(0, 0, 0))
|
||||
back_sketch = result['image'].copy()
|
||||
@@ -34,7 +34,8 @@ class BackPerspective:
|
||||
result['back_perspective_url'] = f"{resp.bucket_name}/{resp.object_name}"
|
||||
return result
|
||||
|
||||
def thicken_contours_and_display(self, mask, thickness=10, color=(0, 0, 0)):
|
||||
@staticmethod
|
||||
def thicken_contours_and_display(mask, thickness=10, color=(0, 0, 0)):
|
||||
mask = mask.astype(np.uint8) * 255
|
||||
# 查找轮廓
|
||||
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||
@@ -48,9 +49,9 @@ class BackPerspective:
|
||||
# 在空白图像上绘制白色的轮廓
|
||||
cv2.drawContours(blank, [contour], -1, 255, thickness=thick)
|
||||
# 找到轮廓的中心(可以用重心等方法近似)
|
||||
M = cv2.moments(contour)
|
||||
cx = int(M['m10'] / M['m00'])
|
||||
cy = int(M['m01'] / M['m00'])
|
||||
m = cv2.moments(contour)
|
||||
# cx = int(m['m10'] / m['m00'])
|
||||
# cy = int(m['m01'] / m['m00'])
|
||||
# 进行距离变换,离中心越近的值越小
|
||||
dist_transform = cv2.distanceTransform(255 - blank, cv2.DIST_L2, 5)
|
||||
# 根据距离变换的值来决定是否保留像素,离中心近的像素更容易被保留
|
||||
|
||||
@@ -81,9 +81,9 @@ class Color:
|
||||
def get_pattern(single_color):
|
||||
if single_color is None:
|
||||
raise False
|
||||
R, G, B = single_color.split(' ')
|
||||
r, g, b = single_color.split(' ')
|
||||
pattern = np.zeros([1, 1, 3], np.uint8)
|
||||
pattern[0, 0, 0] = int(B)
|
||||
pattern[0, 0, 1] = int(G)
|
||||
pattern[0, 0, 2] = int(R)
|
||||
pattern[0, 0, 0] = int(b)
|
||||
pattern[0, 0, 1] = int(g)
|
||||
pattern[0, 0, 2] = int(r)
|
||||
return pattern
|
||||
|
||||
@@ -3,7 +3,7 @@ import logging
|
||||
import numpy as np
|
||||
from pymilvus import MilvusClient
|
||||
|
||||
from app.core.config import *
|
||||
from app.core.config import KEYPOINT_RESULT_TABLE_FIELD_SET, MILVUS_TABLE_KEYPOINT, settings
|
||||
from app.service.design_fast.utils.design_ensemble import get_keypoint_result
|
||||
from app.service.utils.decorator import ClassCallRunTime, RunTime
|
||||
|
||||
@@ -21,12 +21,12 @@ class KeyPoint:
|
||||
def __call__(self, result):
|
||||
if result['name'] in ['blouse', 'skirt', 'dress', 'outwear', 'trousers', 'tops', 'bottoms']: # 查询是否有数据 且类别相同 相同则直接读 不同则推理后更新
|
||||
# result['clothes_keypoint'] = self.infer_keypoint_result(result)
|
||||
site = 'up' if result['name'] in ['blouse', 'outwear', 'dress', 'tops'] else 'down'
|
||||
# 'up' if result['name'] in ['blouse', 'outwear', 'dress', 'tops'] else 'down'
|
||||
# keypoint_cache = search_keypoint_cache(result["image_id"], site)
|
||||
# keypoint_cache = self.keypoint_cache(result, site)
|
||||
keypoint_cache = False
|
||||
# 取消向量查询 直接过模型推理
|
||||
if keypoint_cache is False:
|
||||
if not keypoint_cache:
|
||||
keypoint_infer_result, site = self.infer_keypoint_result(result)
|
||||
result['clothes_keypoint'] = self.save_keypoint_cache(result["image_id"], keypoint_infer_result, site)
|
||||
else:
|
||||
@@ -55,8 +55,8 @@ class KeyPoint:
|
||||
}
|
||||
]
|
||||
try:
|
||||
client = MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS)
|
||||
res = client.upsert(collection_name=MILVUS_TABLE_KEYPOINT, data=data)
|
||||
client = MilvusClient(uri=settings.MILVUS_URL, token=settings.MILVUS_TOKEN, db_name=settings.MILVUS_ALIAS)
|
||||
client.upsert(collection_name=MILVUS_TABLE_KEYPOINT, data=data)
|
||||
client.close()
|
||||
return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
|
||||
except Exception as e:
|
||||
@@ -79,7 +79,7 @@ class KeyPoint:
|
||||
]
|
||||
|
||||
try:
|
||||
client = MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS)
|
||||
client = MilvusClient(uri=settings.MILVUS_URL, token=settings.MILVUS_TOKEN, db_name=settings.MILVUS_ALIAS)
|
||||
client.upsert(
|
||||
collection_name=MILVUS_TABLE_KEYPOINT,
|
||||
data=data
|
||||
@@ -92,7 +92,7 @@ class KeyPoint:
|
||||
@RunTime
|
||||
def keypoint_cache(self, result, site):
|
||||
try:
|
||||
client = MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS)
|
||||
client = MilvusClient(uri=settings.MILVUS_URL, token=settings.MILVUS_TOKEN, db_name=settings.MILVUS_ALIAS)
|
||||
keypoint_id = result['image_id']
|
||||
res = client.query(
|
||||
collection_name=MILVUS_TABLE_KEYPOINT,
|
||||
|
||||
@@ -1,10 +1,7 @@
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
|
||||
from skimage.morphology import skeletonize
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from app.service.utils.new_oss_client import oss_get_image
|
||||
|
||||
@@ -47,26 +44,32 @@ class LoadImage:
|
||||
# else:
|
||||
# result['gray'] = cv2.cvtColor(result['image'], cv2.COLOR_BGR2GRAY)
|
||||
|
||||
result['gray'] = self.get_lines(cv2.cvtColor(result['image'], cv2.COLOR_BGR2GRAY), result['path'])
|
||||
result['gray'] = self.get_lines(cv2.cvtColor(result['image'], cv2.COLOR_BGR2GRAY))
|
||||
result['keypoint'] = self.get_keypoint(result['name'])
|
||||
result['img_shape'] = result['image'].shape
|
||||
result['ori_shape'] = result['image'].shape
|
||||
return result
|
||||
|
||||
def get_lines(self, img, path):
|
||||
@staticmethod
|
||||
def get_lines(img):
|
||||
binary = cv2.adaptiveThreshold(img, 255,
|
||||
cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
|
||||
cv2.THRESH_BINARY_INV,
|
||||
25, 10)
|
||||
binary_bool = binary > 0
|
||||
skeleton = skeletonize(binary_bool, method='zhang')
|
||||
mask = skeleton
|
||||
result = np.ones_like(img) * 255
|
||||
result[mask] = img[mask]
|
||||
|
||||
# 步骤2:细化边缘(可选,让线条更干净)
|
||||
# kernel = np.ones((1, 1), np.uint8)
|
||||
# clean = cv2.morphologyEx(binary, cv2.MORPH_OPEN, kernel)
|
||||
|
||||
thinned = cv2.ximgproc.thinning(binary, thinningType=cv2.ximgproc.THINNING_ZHANGSUEN) # thinning算法细化线条
|
||||
mask = thinned > 0
|
||||
result = np.ones_like(img) * 255
|
||||
result[mask] = img[mask]
|
||||
# thinned = cv2.ximgproc.thinning(binary, thinningType=cv2.ximgproc.THINNING_ZHANGSUEN) # thinning算法细化线条
|
||||
# mask = thinned > 0
|
||||
# result = np.ones_like(img) * 255
|
||||
# result[mask] = img[mask]
|
||||
|
||||
# 步骤3:反转回 白底黑线
|
||||
# lines = cv2.bitwise_not(thinned)
|
||||
|
||||
@@ -9,6 +9,7 @@ from app.service.utils.new_oss_client import oss_get_image
|
||||
|
||||
class NoSegPrintPainting:
|
||||
def __init__(self, minio_client):
|
||||
self.random_seed = random.randint(0, 1000)
|
||||
self.minio_client = minio_client
|
||||
|
||||
def __call__(self, result):
|
||||
@@ -174,7 +175,6 @@ class NoSegPrintPainting:
|
||||
dim_max = max(painting_dict['dim_image_h'], painting_dict['dim_image_w'])
|
||||
dim_pattern = (int(dim_max * print_['scale'] / 5), int(dim_max * print_['scale'] / 5))
|
||||
if not is_single:
|
||||
self.random_seed = random.randint(0, 1000)
|
||||
# 如果print 模式为overall 且 有角度的话 , 组合的print为正方形,方便裁剪
|
||||
if "print_angle_list" in print_dict.keys() and print_dict['print_angle_list'][0] != 0:
|
||||
painting_dict['mask_inv_print'] = self.tile_image(single_mask_inv_print, dim_pattern, print_['scale'], dim_max, dim_max, painting_dict['location'], trigger=True)
|
||||
@@ -244,7 +244,7 @@ class NoSegPrintPainting:
|
||||
change_mask = print_mask[start_h: length_h, start_w: length_w]
|
||||
# get real part into change mask
|
||||
_, change_mask = cv2.threshold(change_mask, 220, 255, cv2.THRESH_BINARY)
|
||||
mask = cv2.bitwise_not(painting_dict['mask_inv_print'])
|
||||
cv2.bitwise_not(painting_dict['mask_inv_print'])
|
||||
img_fg[start_h:start_h + painting_dict['dim_print_h'], start_w:start_w + painting_dict['dim_print_w'], :] = change_region
|
||||
|
||||
clothes_mask_print = cv2.bitwise_not(print_mask)
|
||||
|
||||
@@ -9,6 +9,7 @@ from app.service.utils.new_oss_client import oss_get_image
|
||||
|
||||
class PrintPainting:
|
||||
def __init__(self, minio_client):
|
||||
self.random_seed = None
|
||||
self.minio_client = minio_client
|
||||
|
||||
def __call__(self, result):
|
||||
@@ -416,7 +417,7 @@ class PrintPainting:
|
||||
change_mask = print_mask[start_h: length_h, start_w: length_w]
|
||||
# get real part into change mask
|
||||
_, change_mask = cv2.threshold(change_mask, 220, 255, cv2.THRESH_BINARY)
|
||||
mask = cv2.bitwise_not(painting_dict['mask_inv_print'])
|
||||
cv2.bitwise_not(painting_dict['mask_inv_print'])
|
||||
img_fg[start_h:start_h + painting_dict['dim_print_h'], start_w:start_w + painting_dict['dim_print_w'], :] = change_region
|
||||
|
||||
clothes_mask_print = cv2.bitwise_not(print_mask)
|
||||
|
||||
@@ -4,7 +4,7 @@ import os
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from app.core.config import SEG_CACHE_PATH
|
||||
from app.core.config import settings
|
||||
from app.service.design_fast.utils.design_ensemble import get_seg_result
|
||||
from app.service.utils.decorator import ClassCallRunTime
|
||||
from app.service.utils.new_oss_client import oss_get_image
|
||||
@@ -36,11 +36,11 @@ class Segmentation:
|
||||
# preview 过模型 不缓存
|
||||
if "preview_submit" in result.keys() and result['preview_submit'] == "preview":
|
||||
# 推理获得seg 结果
|
||||
seg_result = get_seg_result(result["image_id"], result['image'])
|
||||
seg_result = get_seg_result(result['image'])
|
||||
# submit 过模型 缓存
|
||||
elif "preview_submit" in result.keys() and result['preview_submit'] == "submit":
|
||||
# 推理获得seg 结果
|
||||
seg_result = get_seg_result(result["image_id"], result['image'])
|
||||
seg_result = get_seg_result(result['image'])
|
||||
self.save_seg_result(seg_result, result['image_id'])
|
||||
# null 正常流程 加载本地缓存 无缓存则过模型
|
||||
else:
|
||||
@@ -49,7 +49,7 @@ class Segmentation:
|
||||
# 判断缓存和实际图片size是否相同
|
||||
if not _ or result["image"].shape[:2] != seg_result.shape:
|
||||
# 推理获得seg 结果
|
||||
seg_result = get_seg_result(result["image_id"], result['image'])
|
||||
seg_result = get_seg_result(result['image'])
|
||||
self.save_seg_result(seg_result, result['image_id'])
|
||||
result['seg_result'] = seg_result
|
||||
|
||||
@@ -63,7 +63,7 @@ class Segmentation:
|
||||
|
||||
@staticmethod
|
||||
def save_seg_result(seg_result, image_id):
|
||||
file_path = f"{SEG_CACHE_PATH}{image_id}.npy"
|
||||
file_path = f"{settings.SEG_CACHE_PATH}{image_id}.npy"
|
||||
try:
|
||||
np.save(file_path, seg_result)
|
||||
logger.debug(f"保存成功 :{os.path.abspath(file_path)}")
|
||||
@@ -72,7 +72,7 @@ class Segmentation:
|
||||
|
||||
@staticmethod
|
||||
def load_seg_result(image_id):
|
||||
file_path = f"{SEG_CACHE_PATH}{image_id}.npy"
|
||||
file_path = f"{settings.SEG_CACHE_PATH}{image_id}.npy"
|
||||
# logger.info(f"load seg file name is :{SEG_CACHE_PATH}{image_id}.npy")
|
||||
try:
|
||||
seg_result = np.load(file_path)
|
||||
|
||||
@@ -4,9 +4,7 @@ import logging
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from cv2 import cvtColor, COLOR_BGR2RGBA
|
||||
|
||||
from app.core.config import AIDA_CLOTHING
|
||||
from app.service.design_fast.utils.conversion_image import rgb_to_rgba
|
||||
from app.service.design_fast.utils.transparent import sketch_to_transparent
|
||||
from app.service.design_fast.utils.upload_image import upload_png_mask
|
||||
@@ -41,7 +39,7 @@ class Split(object):
|
||||
result_front_image = np.zeros_like(rgba_image)
|
||||
front_mask = cv2.resize(front_mask, new_size, interpolation=cv2.INTER_AREA)
|
||||
result_front_image[front_mask != 0] = rgba_image[front_mask != 0]
|
||||
result_front_image_pil = Image.fromarray(cvtColor(result_front_image, COLOR_BGR2RGBA))
|
||||
result_front_image_pil = Image.fromarray(cv2.cvtColor(result_front_image, cv2.COLOR_BGR2RGBA))
|
||||
if 'transparent' in result.keys():
|
||||
# 用户自选区域transparent
|
||||
transparent = result['transparent']
|
||||
@@ -106,26 +104,27 @@ class Split(object):
|
||||
result_back_image = np.zeros_like(rgba_image)
|
||||
back_mask = cv2.resize(back_mask, new_size, interpolation=cv2.INTER_AREA)
|
||||
result_back_image[back_mask != 0] = rgba_image[back_mask != 0]
|
||||
result_back_image_pil = Image.fromarray(cvtColor(result_back_image, COLOR_BGR2RGBA))
|
||||
result_back_image_pil = Image.fromarray(cv2.cvtColor(result_back_image, cv2.COLOR_BGR2RGBA))
|
||||
result['back_image'], result["back_image_url"], _ = upload_png_mask(self.minio_client, result_back_image_pil, f'{generate_uuid()}', mask=None)
|
||||
|
||||
# mask_image[back_mask != 0] = [0, 255, 0]
|
||||
mask_image[ori_back_mask != 0] = [0, 255, 0]
|
||||
|
||||
rbga_mask = rgb_to_rgba(mask_image, ori_front_mask + ori_back_mask)
|
||||
mask_pil = Image.fromarray(cvtColor(rbga_mask.astype(np.uint8), COLOR_BGR2RGBA))
|
||||
mask_pil = Image.fromarray(cv2.cvtColor(rbga_mask.astype(np.uint8), cv2.COLOR_BGR2RGBA))
|
||||
image_data = io.BytesIO()
|
||||
mask_pil.save(image_data, format='PNG')
|
||||
image_data.seek(0)
|
||||
image_bytes = image_data.read()
|
||||
req = oss_upload_image(oss_client=self.minio_client, bucket=AIDA_CLOTHING, object_name=f"mask/mask_{generate_uuid()}.png", image_bytes=image_bytes)
|
||||
req = oss_upload_image(oss_client=self.minio_client, bucket="aida-clothing", object_name=f"mask/mask_{generate_uuid()}.png", image_bytes=image_bytes)
|
||||
result['mask_url'] = req.bucket_name + "/" + req.object_name
|
||||
|
||||
else:
|
||||
ori_front_mask, ori_back_mask = None, None
|
||||
# 创建中间图层(未分割图层) 1.color + overall_print 2.color + overall_print + print
|
||||
result_pattern_overall_image_pil = Image.fromarray(cvtColor(rgb_to_rgba(result['no_seg_sketch_overall'], ori_front_mask + ori_back_mask), COLOR_BGR2RGBA))
|
||||
result_pattern_overall_image_pil = Image.fromarray(cv2.cvtColor(rgb_to_rgba(result['no_seg_sketch_overall'], ori_front_mask + ori_back_mask), cv2.COLOR_BGR2RGBA))
|
||||
result['pattern_overall_image'], result['pattern_overall_image_url'], _ = upload_png_mask(self.minio_client, result_pattern_overall_image_pil, f'{generate_uuid()}')
|
||||
|
||||
result_pattern_print_image_pil = Image.fromarray(cvtColor(rgb_to_rgba(result['no_seg_sketch_print'], ori_front_mask + ori_back_mask), COLOR_BGR2RGBA))
|
||||
result_pattern_print_image_pil = Image.fromarray(cv2.cvtColor(rgb_to_rgba(result['no_seg_sketch_print'], ori_front_mask + ori_back_mask), cv2.COLOR_BGR2RGBA))
|
||||
result['pattern_print_image'], result['pattern_print_image_url'], _ = upload_png_mask(self.minio_client, result_pattern_print_image_pil, f'{generate_uuid()}')
|
||||
return result
|
||||
except Exception as e:
|
||||
|
||||
@@ -15,7 +15,7 @@ import numpy as np
|
||||
import torch
|
||||
import tritonclient.http as httpclient
|
||||
|
||||
from app.core.config import *
|
||||
from app.core.config import DESIGN_MODEL_URL, DESIGN_MODEL_NAME
|
||||
|
||||
"""
|
||||
keypoint
|
||||
@@ -98,29 +98,29 @@ def seg_preprocess(img_path):
|
||||
|
||||
|
||||
# @ RunTime
|
||||
def get_seg_result(image_id, image):
|
||||
def get_seg_result(image):
|
||||
image, ori_shape = seg_preprocess(image)
|
||||
client = httpclient.InferenceServerClient(url=f"{DESIGN_MODEL_URL}")
|
||||
client = httpclient.InferenceServerClient(url=DESIGN_MODEL_URL)
|
||||
transformed_img = image.astype(np.float32)
|
||||
# 输入集
|
||||
inputs = [
|
||||
httpclient.InferInput(SEGMENTATION['input'], transformed_img.shape, datatype="FP32")
|
||||
httpclient.InferInput("seg_input__0", transformed_img.shape, datatype="FP32")
|
||||
]
|
||||
inputs[0].set_data_from_numpy(transformed_img, binary_data=True)
|
||||
# 输出集
|
||||
outputs = [
|
||||
httpclient.InferRequestedOutput(SEGMENTATION['output'], binary_data=True),
|
||||
httpclient.InferRequestedOutput("seg_output__0", binary_data=True),
|
||||
]
|
||||
results = client.infer(model_name=SEGMENTATION['new_model_name'], inputs=inputs, outputs=outputs)
|
||||
results = client.infer(model_name=DESIGN_MODEL_NAME, inputs=inputs, outputs=outputs)
|
||||
# 推理
|
||||
# 取结果
|
||||
inference_output1 = results.as_numpy(SEGMENTATION['output'])
|
||||
seg_result = seg_postprocess(int(image_id), inference_output1, ori_shape)
|
||||
inference_output1 = results.as_numpy("seg_output__0")
|
||||
seg_result = seg_postprocess(inference_output1, ori_shape)
|
||||
return seg_result
|
||||
|
||||
|
||||
# no cache
|
||||
def seg_postprocess(image_id, output, ori_shape):
|
||||
def seg_postprocess(output, ori_shape):
|
||||
seg_logit = cv2.resize(output[0][0].astype(np.uint8), (ori_shape[1] + 50, ori_shape[0] + 50))
|
||||
seg_logit = seg_logit[25: - 25, 25: - 25]
|
||||
return seg_logit
|
||||
|
||||
@@ -112,6 +112,8 @@ def calculate_start_point(keypoint_type, scale, clothes_point, body_point, offse
|
||||
"""
|
||||
Align left
|
||||
Args:
|
||||
offset:
|
||||
resize_scale:
|
||||
keypoint_type: string, "waistband" | "shoulder" | "ear_point"
|
||||
scale: float
|
||||
clothes_point: dict{'left': [x1, y1, z1], 'right': [x2, y2, z2]}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import logging
|
||||
|
||||
from app.service.design_fast.utils.redis_utils import Redis
|
||||
from app.service.utils.redis_utils import Redis
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -1,99 +0,0 @@
|
||||
import redis
|
||||
|
||||
from app.core.config import REDIS_HOST, REDIS_PORT
|
||||
|
||||
|
||||
class Redis(object):
|
||||
"""
|
||||
redis数据库操作
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _get_r():
|
||||
host = REDIS_HOST
|
||||
port = REDIS_PORT
|
||||
db = 0
|
||||
r = redis.StrictRedis(host, port, db)
|
||||
return r
|
||||
|
||||
@classmethod
|
||||
def write(cls, key, value, expire=None):
|
||||
"""
|
||||
写入键值对
|
||||
"""
|
||||
# 判断是否有过期时间,没有就设置默认值
|
||||
if expire:
|
||||
expire_in_seconds = expire
|
||||
else:
|
||||
expire_in_seconds = 100
|
||||
r = cls._get_r()
|
||||
r.set(key, value, ex=expire_in_seconds)
|
||||
|
||||
@classmethod
|
||||
def read(cls, key):
|
||||
"""
|
||||
读取键值对内容
|
||||
"""
|
||||
r = cls._get_r()
|
||||
value = r.get(key)
|
||||
return value.decode('utf-8') if value else value
|
||||
|
||||
@classmethod
|
||||
def hset(cls, name, key, value):
|
||||
"""
|
||||
写入hash表
|
||||
"""
|
||||
r = cls._get_r()
|
||||
r.hset(name, key, value)
|
||||
|
||||
@classmethod
|
||||
def hget(cls, name, key):
|
||||
"""
|
||||
读取指定hash表的键值
|
||||
"""
|
||||
r = cls._get_r()
|
||||
value = r.hget(name, key)
|
||||
return value.decode('utf-8') if value else value
|
||||
|
||||
@classmethod
|
||||
def hgetall(cls, name):
|
||||
"""
|
||||
获取指定hash表所有的值
|
||||
"""
|
||||
r = cls._get_r()
|
||||
return r.hgetall(name)
|
||||
|
||||
@classmethod
|
||||
def delete(cls, *names):
|
||||
"""
|
||||
删除一个或者多个
|
||||
"""
|
||||
r = cls._get_r()
|
||||
r.delete(*names)
|
||||
|
||||
@classmethod
|
||||
def hdel(cls, name, key):
|
||||
"""
|
||||
删除指定hash表的键值
|
||||
"""
|
||||
r = cls._get_r()
|
||||
r.hdel(name, key)
|
||||
|
||||
@classmethod
|
||||
def expire(cls, name, expire=None):
|
||||
"""
|
||||
设置过期时间
|
||||
"""
|
||||
if expire:
|
||||
expire_in_seconds = expire
|
||||
else:
|
||||
expire_in_seconds = 100
|
||||
r = cls._get_r()
|
||||
r.expire(name, expire_in_seconds)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
redis_client = Redis()
|
||||
# print(redis_client.write(key="1230", value=0))
|
||||
redis_client.write(key="1230", value=10)
|
||||
# print(redis_client.read(key="1230"))
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user