feat : 代码梳理 移除所有敏感密钥 通过环境变量方式配置
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped

This commit is contained in:
zcr
2025-12-30 16:49:08 +08:00
parent 1be716e414
commit 18024a2d70
167 changed files with 5283 additions and 10464 deletions

4
.gitignore vendored
View File

@@ -148,4 +148,6 @@ app/logs/*
*.pickle
*.csv
*.avi
*.json
*.json
*.env*
config.backup.py

View File

@@ -23,11 +23,11 @@
$ pip install mmcv==1.4.2 -f https://download.openmmlab.com/mmcv/dist/cu117/torch1.13/index.html
2. 启动服务器
1. 启动服务器
$ uvicorn app.main:app --host 0.0.0.0 --port 8000
3. 打开 http://127.0.0.1:8000/docs
2. 打开 http://127.0.0.1:8000/docs
Docker 部署
---------------

View File

@@ -2,8 +2,7 @@ import json
import logging
from fastapi import APIRouter, HTTPException
from app.core.config import DEBUG
from app.core.config import settings
from app.schemas.attribute_retrieve import *
from app.schemas.response_template import ResponseModel
from app.service.attribute.config import const, local_debug_const
@@ -35,13 +34,13 @@ def attribute_recognition(request_item: list[AttributeRecognitionModel]):
"""
try:
for item in request_item:
logger.debug(f"attribute_recognition request item is : @@@@@@:{json.dumps(item.dict())}")
if DEBUG:
logger.info(f"attribute_recognition request item is : @@@@@@:{json.dumps(item.dict(), indent=4)}")
if settings.DEBUG:
service = AttributeRecognition(const=local_debug_const, request_data=request_item)
else:
service = AttributeRecognition(const=const, request_data=request_item)
data = service.get_result()
logger.debug(f"attribute_recognition response @@@@@@:{json.dumps(data)}")
logger.info(f"attribute_recognition response @@@@@@:{json.dumps(data, indent=4)}")
except Exception as e:
logger.warning(f"attribute_recognition Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))
@@ -67,10 +66,10 @@ def category_recognition(request_item: list[CategoryRecognitionModel]):
"""
try:
for item in request_item:
logger.info(f"category_recognition request item is : @@@@@@:{json.dumps(item.dict())}")
logger.info(f"category_recognition request item is : @@@@@@:{json.dumps(item.dict(), indent=4)}")
service = CategoryRecognition(request_data=request_item)
data = service.get_result()
logger.info(f"category_recognition response @@@@@@:{json.dumps(data)}")
logger.info(f"category_recognition response @@@@@@:{json.dumps(data, indent=4)}")
except Exception as e:
logger.warning(f"category_recognition Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))

View File

@@ -26,7 +26,7 @@ def seg_product(request_item: BrandDnaModel):
}
"""
try:
logger.info(f"brand dna request item is : @@@@@@:{json.dumps(request_item.dict())}")
logger.info(f"brand dna request item is : @@@@@@:{json.dumps(request_item.dict(),indent=4)}")
service = BrandDna(request_item)
result_url = service.get_result()
except Exception as e:
@@ -36,7 +36,7 @@ def seg_product(request_item: BrandDnaModel):
@router.post("/GenerateBrand")
def GenerateBrand(request_data: GenerateBrandModel):
def generate_brand(request_data: GenerateBrandModel):
"""
通过prompt 生成 brand name brand slogan , brand logo。
创建一个具有以下参数的请求体:

View File

@@ -1,34 +1,25 @@
import io
import logging
import sys
import time
from typing import List
from collections import defaultdict
import numpy as np
from apscheduler.schedulers.background import BackgroundScheduler
from apscheduler.triggers.cron import CronTrigger
from fastapi import HTTPException, APIRouter
import pymysql
from app.core.config import DB_CONFIG, TABLE_CATEGORIES, RECOMMEND_PATH_PREFIX
from minio import Minio
import torch
from torchvision import models, transforms
from PIL import Image
import os
import sys
from collections import defaultdict
import numpy as np
import pymysql
import torch
from PIL import Image
from fastapi import HTTPException, APIRouter
from fastapi.responses import JSONResponse
from minio import Minio
from torchvision import models, transforms
from app.core.mysql_config import DB_CONFIG
from app.core.new_config import settings
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')
logger = logging.getLogger()
router = APIRouter()
# MinIO 配置
minio_client = Minio(
"www.minio.aida.com.hk:12024",
access_key="admin",
secret_key="Aidlab123123!",
secure=True
)
minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)
transform = transforms.Compose([
transforms.Resize((224, 224)),
@@ -67,8 +58,8 @@ def extract_feature_vector_from_resnet(sketch_path: str) -> np.ndarray:
# 预加载
BRAND_FEATURES = np.load(f'{RECOMMEND_PATH_PREFIX}brand_feature.npy', allow_pickle=True).item()
SYSTEM_FEATURES = np.load(f'{RECOMMEND_PATH_PREFIX}sketch_feature_dict.npy', allow_pickle=True).item()
BRAND_FEATURES = np.load(f'{settings.RECOMMEND_PATH_PREFIX}brand_feature.npy', allow_pickle=True).item()
SYSTEM_FEATURES = np.load(f'{settings.RECOMMEND_PATH_PREFIX}sketch_feature_dict.npy', allow_pickle=True).item()
def save_sketch_to_iid():
@@ -76,11 +67,11 @@ def save_sketch_to_iid():
sketch_path: iid
for iid, sketch_path in enumerate(SYSTEM_FEATURES.keys(), start=1)
}
np.save(f"{RECOMMEND_PATH_PREFIX}sketch_to_iid.npy", sketch_to_iid)
np.save(f"{settings.RECOMMEND_PATH_PREFIX}sketch_to_iid.npy", sketch_to_iid)
def load_sketch_to_iid():
path = f"{RECOMMEND_PATH_PREFIX}sketch_to_iid.npy"
path = f"{settings.RECOMMEND_PATH_PREFIX}sketch_to_iid.npy"
if os.path.exists(path):
return np.load(path, allow_pickle=True).item()
save_sketch_to_iid()
@@ -90,7 +81,7 @@ def load_sketch_to_iid():
sketch_to_iid = load_sketch_to_iid()
def getNewCategory(gender: str, sketch_category: str) -> str:
def get_new_category(gender: str, sketch_category: str) -> str:
return f"{gender.lower()}_{sketch_category.lower()}"
@@ -103,8 +94,8 @@ def get_category_from_path(path: str) -> str:
def load_brand_matrix():
"""单独加载 brand_matrix 和 brand_index_map"""
mat_path = f"{RECOMMEND_PATH_PREFIX}brand_matrix.npy"
idx_path = f"{RECOMMEND_PATH_PREFIX}brand_index_map.npy"
mat_path = f"{settings.RECOMMEND_PATH_PREFIX}brand_matrix.npy"
idx_path = f"{settings.RECOMMEND_PATH_PREFIX}brand_index_map.npy"
try:
matrix = np.load(mat_path)
index_map = np.load(idx_path, allow_pickle=True).item()
@@ -113,11 +104,19 @@ def load_brand_matrix():
index_map = {}
return matrix, index_map
def cosine_similarity(vec1, vec2):
"""计算余弦相似度(增加零值处理)"""
norm = np.linalg.norm(vec1) * np.linalg.norm(vec2)
return np.dot(vec1, vec2) / (norm + 1e-10) if norm != 0 else 0.0
def getNewCategory(gender, sketch_category):
print(gender)
print(sketch_category)
return "None"
def calculate_brand_matrix(sketch_data, brand_id: int) -> np.ndarray:
# 1. 收集品牌-分类-特征
brand_feature = defaultdict(lambda: defaultdict(list))
@@ -164,11 +163,11 @@ def calculate_brand_matrix(sketch_data, brand_id: int) -> np.ndarray:
brand_matrix[row_idx, sketch_index[iid]] = cos_sim
# 7. 持久化
np.save(f"{RECOMMEND_PATH_PREFIX}brand_feature_matrix.npy", brand_matrix)
np.save(f"{RECOMMEND_PATH_PREFIX}brand_index_map.npy", brand_index_map)
np.save(f"{settings.RECOMMEND_PATH_PREFIX}brand_feature_matrix.npy", brand_matrix)
np.save(f"{settings.RECOMMEND_PATH_PREFIX}brand_index_map.npy", brand_index_map)
# 返回该品牌对应行
return brand_matrix[row_idx:row_idx+1]
return brand_matrix[row_idx:row_idx + 1]
@router.get("/brand_dna_initialize/{brand_id}")
@@ -178,14 +177,12 @@ async def brand_dna_initialize(brand_id: int):
conn = pymysql.connect(**DB_CONFIG)
cursor = conn.cursor()
cursor.execute("""
SELECT id, img_url, gender, category
FROM product_image_attribute
WHERE library_id IN (
SELECT library_id
FROM brand_rel_library
WHERE brand_id = %s
)
""", (brand_id,))
SELECT id, img_url, gender, category
FROM product_image_attribute
WHERE library_id IN (SELECT library_id
FROM brand_rel_library
WHERE brand_id = %s)
""", (brand_id,))
sketch_data = cursor.fetchall()
# 触发计算并持久化,若内部出错会抛异常

View File

@@ -5,10 +5,11 @@ import time
from PIL import ImageEnhance
from fastapi import APIRouter, HTTPException
from minio import Minio
from app.core.config import settings
from app.schemas.brighten import BrightenModel
from app.schemas.response_template import ResponseModel
from app.service.utils.oss_client import oss_get_image, oss_upload_image
from app.service.utils.new_oss_client import oss_get_image, oss_upload_image
router = APIRouter()
logger = logging.getLogger()
@@ -20,6 +21,9 @@ def increase_brightness(img, factor):
return bright_img
minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)
@router.post("/brighten")
async def brighten(request_item: BrightenModel):
"""
@@ -35,14 +39,14 @@ async def brighten(request_item: BrightenModel):
"""
try:
start_time = time.time()
logger.info(f"brighten request item is : @@@@@@:{json.dumps(request_item.dict())}")
image = oss_get_image(bucket=request_item.image_url.split('/')[0], object_name=request_item.image_url[request_item.image_url.find('/') + 1:], data_type="PIL")
logger.info(f"brighten request item is : @@@@@@:{json.dumps(request_item.dict(),indent=4)}")
image = oss_get_image(oss_client=minio_client, bucket=request_item.image_url.split('/')[0], object_name=request_item.image_url[request_item.image_url.find('/') + 1:], data_type="PIL")
new_image = increase_brightness(image, request_item.brighten_value)
image_data = io.BytesIO()
new_image.save(image_data, format='PNG')
image_data.seek(0)
image_bytes = image_data.read()
req = oss_upload_image(bucket=request_item.image_url.split('/')[0], object_name=request_item.image_url[request_item.image_url.find('/') + 1:], image_bytes=image_bytes)
req = oss_upload_image(oss_client=minio_client, bucket=request_item.image_url.split('/')[0], object_name=request_item.image_url[request_item.image_url.find('/') + 1:], image_bytes=image_bytes)
brighten_url = f"{req.bucket_name}/{req.object_name}"
logger.info(f"run time is : {time.time() - start_time}")
except Exception as e:

View File

@@ -30,9 +30,9 @@ def chat_robot(request_data: ChatRobotModel):
}
"""
try:
logger.info(f"chat_robot request item is : @@@@@@:{json.dumps(request_data.dict())}")
logger.info(f"chat_robot request item is : @@@@@@:{json.dumps(request_data.dict(),indent=4)}")
data = chat(post_data=request_data)
logger.info(f"chat_robot response @@@@@@:{json.dumps(data)}")
logger.info(f"chat_robot response @@@@@@:{json.dumps(data, indent=4)}")
except Exception as e:
logger.warning(f"chat_robot Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))

View File

@@ -42,7 +42,7 @@ def clothing_seg(request_item: ClothingSegModel):
}
"""
try:
logger.info(f"clothing_seg request item is : @@@@@@:{json.dumps(request_item.dict())}")
logger.info(f"clothing_seg request item is : @@@@@@:{json.dumps(request_item.dict(),indent=4)}")
server = ClothingSeg(request_item)
result_url = server.get_result()
except Exception as e:

View File

@@ -1,203 +1,201 @@
import json
import logging
import os
from fastapi import APIRouter, HTTPException, UploadFile, File, Form, BackgroundTasks
from fastapi import APIRouter, HTTPException, BackgroundTasks
from app.schemas.design import DesignModel, DesignProgressModel, ModelProgressModel, DBGConfigModel, DesignStreamModel
from app.schemas.design import DesignModel, ModelProgressModel, DesignStreamModel
from app.schemas.response_template import ResponseModel
from app.service.design.model_process_service import model_transpose
from app.service.design_batch.service import start_design_batch_generate
from app.service.design_fast.design_generate import design_generate, design_generate_v2
from app.service.design_fast.utils.redis_utils import Redis
from app.service.design_fast.model_process_service import model_transpose
router = APIRouter()
logger = logging.getLogger()
@router.post("/design")
def design(request_data: DesignModel, background_tasks: BackgroundTasks):
def design(request_data: DesignModel):
"""
objects.items.transparent:
"transparent":{
"mask_url":"test/transparent_test/transparent_mask.png",
"scale":0.1
},
mask_url 为空"" -> 单件衣服透明
mask_url 非空"mask_url" -> 区域透明
objects.items.transparent:
"transparent":{
"mask_url":"test/transparent_test/transparent_mask.png",
"scale":0.1
},
mask_url 为空"" -> 单件衣服透明
mask_url 非空"mask_url" -> 区域透明
创建一个具有以下参数的请求体:
示例参数:
{
"objects": [
{
"basic": {
"body_point_test": {
"waistband_right": [
200,
241
],
"hand_point_right": [
223,
297
],
"waistband_left": [
112,
241
],
"hand_point_left": [
92,
305
],
"shoulder_left": [
99,
116
],
"shoulder_right": [
215,
116
]
},
"layer_order": true,
"scale_bag": 0.7,
"scale_earrings": 0.16,
"self_template": true,
"single_overall": "overall",
"switch_category": ""
},
"items": [
{
"businessId": 270372,
"color": "30 28 28",
"image_id": 69780,
"offset": [
0,
0
],
"path": "aida-sys-image/images/female/trousers/0825000630.jpg",
"print": {
"element": {
"element_angle_list": [],
"element_path_list": [],
"element_scale_list": [],
"location": []
},
"overall": {
"location": [],
"print_angle_list": [],
"print_path_list": [],
"print_scale_list": []
},
"single": {
"location": [],
"print_angle_list": [],
"print_path_list": [],
"print_scale_list": []
}
创建一个具有以下参数的请求体:
示例参数:
{
"objects": [
{
"basic": {
"body_point_test": {
"waistband_right": [
203,
249
],
"hand_point_right": [
229,
343
],
"waistband_left": [
119,
248
],
"hand_point_left": [
97,
343
],
"shoulder_left": [
108,
107
],
"shoulder_right": [
212,
107
]
},
"priority": 10,
"resize_scale": [
1.0,
1.0
],
"type": "Trousers"
"layer_order": true,
"preview_submit": "submit",
"scale_bag": 0.7,
"scale_earrings": 0.16,
"self_template": true,
"single_overall": "overall",
"switch_category": ""
},
{
"businessId": 270373,
"color": "30 28 28",
"image_id": 98243,
"offset": [
0,
0
],
"path": "aida-sys-image/images/female/blouse/0902003811.jpg",
"print": {
"element": {
"element_angle_list": [],
"element_path_list": [],
"element_scale_list": [],
"location": []
"items": [
{
"businessId": 2377945,
"color": "209 196 171",
"image_id": 189410,
"offset": [
0,
0
],
"path": "aida-collection-element/89/Sketchboard/53d38bd5-f77b-4034-ada2-45f1e2ebe00c.png",
"print": {
"element": {
"element_angle_list": [],
"element_path_list": [],
"element_scale_list": [],
"location": []
},
"overall": {
"location": [],
"print_angle_list": [],
"print_path_list": [],
"print_scale_list": []
},
"single": {
"location": [],
"print_angle_list": [],
"print_path_list": [],
"print_scale_list": []
}
},
"overall": {
"location": [],
"print_angle_list": [],
"print_path_list": [],
"print_scale_list": []
},
"single": {
"location": [],
"print_angle_list": [],
"print_path_list": [],
"print_scale_list": []
}
"priority": 12,
"resize_scale": [
1.0,
1.0
],
"seg_mask_url": "aida-clothing/mask/mask_8e96ddb0-e466-11f0-8de2-0242ac130002.png",
"type": "Outwear"
},
"priority": 11,
"resize_scale": [
1.0,
1.0
],
"type": "Blouse"
},
{
"businessId": 270374,
"color": "172 68 68",
"image_id": 98244,
"offset": [
0,
0
],
"path": "aida-sys-image/images/female/outwear/0825000410.jpg",
"print": {
"element": {
"element_angle_list": [],
"element_path_list": [],
"element_scale_list": [],
"location": []
{
"businessId": 2377946,
"color": "122 152 139",
"image_id": 81868,
"offset": [
0,
0
],
"path": "aida-sys-image/images/female/blouse/0825001443.jpg",
"print": {
"element": {
"element_angle_list": [],
"element_path_list": [],
"element_scale_list": [],
"location": []
},
"overall": {
"location": [],
"print_angle_list": [],
"print_path_list": [],
"print_scale_list": []
},
"single": {
"location": [],
"print_angle_list": [],
"print_path_list": [],
"print_scale_list": []
}
},
"overall": {
"location": [],
"print_angle_list": [],
"print_path_list": [],
"print_scale_list": []
"priority": 11,
"resize_scale": [
1.0,
1.0
],
"seg_mask_url": "aida-clothing/mask/mask_8f0fab78-e466-11f0-8de2-0242ac130002.png",
"type": "Blouse"
},
{
"businessId": 2377947,
"color": "111 78 63",
"gradient": "aida-gradient/517c3a4d-aed7-4423-aa99-7b60d3577df1.png",
"image_id": 116494,
"offset": [
0,
0
],
"path": "aida-sys-image/images/female/skirt/0825000219.jpg",
"print": {
"element": {
"element_angle_list": [],
"element_path_list": [],
"element_scale_list": [],
"location": []
},
"overall": {
"location": [],
"print_angle_list": [],
"print_path_list": [],
"print_scale_list": []
},
"single": {
"location": [],
"print_angle_list": [],
"print_path_list": [],
"print_scale_list": []
}
},
"single": {
"location": [],
"print_angle_list": [],
"print_path_list": [],
"print_scale_list": []
}
"priority": 10,
"resize_scale": [
1.0,
1.0
],
"seg_mask_url": "aida-clothing/mask/mask_8f6191fe-e466-11f0-8de2-0242ac130002.png",
"type": "Skirt"
},
"priority": 12,
"resize_scale": [
1.0,
1.0
],
"transparent":{
"mask_url":"test/transparent_test/transparent_mask.png",
"scale":0.1
},
"type": "Outwear"
},
{
"body_path": "aida-sys-image/models/female/5bdfe7ca-64eb-44e4-b03d-8e517520c795.png",
"image_id": 96090,
"type": "Body"
}
]
}
],
"process_id": "83"
}
"""
# logger.info(f"design request item is : @@@@@@:{json.dumps(request_data.dict())}")
{
"body_path": "aida-sys-image/models/female/2e4815b9-1191-419d-94ed-5771239ca4a5.png",
"image_id": 67277,
"type": "Body"
}
]
}
],
"process_id": "89"
}
"""
# logger.info(f"design request item is : @@@@@@:{json.dumps(request_data.dict(),indent=4)}")
# data = generate(request_data=request_data)
# logger.info(f"design response @@@@@@:{json.dumps(data)}")
# logger.info(f"design response @@@@@@:{json.dumps(data, indent=4)}")
#
try:
logger.info(f"design request item is : @@@@@@:{json.dumps(request_data.dict())}")
logger.info(f"design request item is : @@@@@@:{json.dumps(request_data.dict(), indent=4)}")
data = design_generate(request_data=request_data)
logger.info(f"design response @@@@@@:{json.dumps(data)}")
logger.info(f"design response @@@@@@:{json.dumps(data, indent=4)}")
except Exception as e:
logger.warning(f"design Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))
@@ -215,47 +213,48 @@ async def design_v2(request_data: DesignStreamModel, background_tasks: Backgroun
"basic": {
"body_point_test": {
"waistband_right": [
200,
241
203,
249
],
"hand_point_right": [
223,
297
229,
343
],
"waistband_left": [
112,
241
119,
248
],
"hand_point_left": [
92,
305
97,
343
],
"shoulder_left": [
99,
116
108,
107
],
"relation_type": "System",
"shoulder_right": [
215,
116
]
212,
107
],
"relation_id": 1020356
},
"layer_order": true,
"layer_order": false,
"scale_bag": 0.7,
"scale_earrings": 0.16,
"self_template": true,
"self_template": false,
"single_overall": "overall",
"switch_category": ""
},
"items": [
{
"businessId": 270372,
"color": "30 28 28",
"image_id": 69780,
"color": "209 196 171",
"image_id": 84093,
"offset": [
0,
0
1,
1
],
"path": "aida-sys-image/images/female/trousers/0825000630.jpg",
"path": "aida-users/89/sketchboard/female/Outwear/0943d209-7ce0-408c-bc61-83f15da94138.png",
"print": {
"element": {
"element_angle_list": [],
@@ -264,10 +263,23 @@ async def design_v2(request_data: DesignStreamModel, background_tasks: Backgroun
"location": []
},
"overall": {
"location": [],
"print_angle_list": [],
"location": [
[
0.0,
0.0
]
],
"print_angle_list": [
0.0,
0.0
],
"print_path_list": [],
"print_scale_list": []
"print_scale_list": [
[
0.0,
0.0
]
]
},
"single": {
"location": [],
@@ -276,22 +288,20 @@ async def design_v2(request_data: DesignStreamModel, background_tasks: Backgroun
"print_scale_list": []
}
},
"priority": 10,
"resize_scale": [
1.0,
1.0
],
"type": "Trousers"
"type": "Outwear"
},
{
"businessId": 270373,
"color": "30 28 28",
"image_id": 98243,
"color": "63 71 73",
"image_id": 100496,
"offset": [
0,
0
1,
1
],
"path": "aida-sys-image/images/female/blouse/0902003811.jpg",
"path": "aida-sys-image/images/female/blouse/0628001684.jpg",
"print": {
"element": {
"element_angle_list": [],
@@ -300,10 +310,23 @@ async def design_v2(request_data: DesignStreamModel, background_tasks: Backgroun
"location": []
},
"overall": {
"location": [],
"print_angle_list": [],
"location": [
[
0.0,
0.0
]
],
"print_angle_list": [
0.0,
0.0
],
"print_path_list": [],
"print_scale_list": []
"print_scale_list": [
[
0.0,
0.0
]
]
},
"single": {
"location": [],
@@ -312,7 +335,6 @@ async def design_v2(request_data: DesignStreamModel, background_tasks: Backgroun
"print_scale_list": []
}
},
"priority": 11,
"resize_scale": [
1.0,
1.0
@@ -320,14 +342,14 @@ async def design_v2(request_data: DesignStreamModel, background_tasks: Backgroun
"type": "Blouse"
},
{
"businessId": 270374,
"color": "172 68 68",
"image_id": 98244,
"color": "111 78 63",
"gradient": "aida-gradient/f69b98e8-4248-4f7a-98a2-21bac41bf3e0.png",
"image_id": 92193,
"offset": [
0,
0
1,
1
],
"path": "aida-sys-image/images/female/outwear/0825000410.jpg",
"path": "aida-sys-image/images/female/trousers/0825001160.jpg",
"print": {
"element": {
"element_angle_list": [],
@@ -336,10 +358,23 @@ async def design_v2(request_data: DesignStreamModel, background_tasks: Backgroun
"location": []
},
"overall": {
"location": [],
"print_angle_list": [],
"location": [
[
0.0,
0.0
]
],
"print_angle_list": [
0.0,
0.0
],
"print_path_list": [],
"print_scale_list": []
"print_scale_list": [
[
0.0,
0.0
]
]
},
"single": {
"location": [],
@@ -348,31 +383,37 @@ async def design_v2(request_data: DesignStreamModel, background_tasks: Backgroun
"print_scale_list": []
}
},
"priority": 12,
"resize_scale": [
1.0,
1.0
],
"transparent":{
"mask_url":"test/transparent_test/transparent_mask.png",
"scale":0.1
},
"type": "Outwear"
"type": "Trousers"
},
{
"body_path": "aida-sys-image/models/female/5bdfe7ca-64eb-44e4-b03d-8e517520c795.png",
"image_id": 96090,
"body_path": "aida-sys-image/models/female/2e4815b9-1191-419d-94ed-5771239ca4a5.png",
"image_id": 67277,
"offset": [
1,
1
],
"resize_scale": [
1.0,
1.0
],
"type": "Body"
}
]
],
"objectSign": "65830966"
}
],
"process_id": "83"
"process_id": "4802946666428422",
"requestId": "1d1e7641-0d62-4da2-adc0-b4404910723c",
"callback_url": "https://api.aida.com.hk/api/third/party/receiveDesignResults"
}
"""
try:
# 异步
logger.info(f"generate_image request item is : @@@@@@:{json.dumps(request_data.dict())}")
logger.info(f"generate_image request item is : @@@@@@:{json.dumps(request_data.dict(), indent=4)}")
background_tasks.add_task(design_generate_v2, request_data)
except Exception as e:
logger.warning(f"design Run Exception @@@@@@:{e}")
@@ -380,30 +421,30 @@ async def design_v2(request_data: DesignStreamModel, background_tasks: Backgroun
return ResponseModel()
@router.post('/get_progress')
def get_progress(request_data: DesignProgressModel):
"""
获取design 进度
创建一个具有以下参数的请求体:
- **process_id**: 进度id
示例参数:
{
"process_id": "6878547032381675"
}
"""
try:
logger.info(f"get_progress request item is : @@@@@@:{json.dumps(request_data.dict())}")
process_id = request_data.process_id
r = Redis()
data = r.read(key=process_id)
if data is None:
raise ValueError(f"No progress ID: {process_id}")
logging.info(f"get_progress process_id @@@@@@ : {process_id} , progress : {json.dumps(data)}")
except Exception as e:
logger.warning(f"get_progress Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))
return ResponseModel(data=data)
# @router.post('/get_progress')
# def get_progress(request_data: DesignProgressModel):
# """
# 获取design 进度
# 创建一个具有以下参数的请求体:
# - **process_id**: 进度id
#
# 示例参数:
# {
# "process_id": "6878547032381675"
# }
# """
# try:
# logger.info(f"get_progress request item is : @@@@@@:{json.dumps(request_data.dict(), indent=4)}")
# process_id = request_data.process_id
# r = Redis()
# data = r.read(key=process_id)
# if data is None:
# raise ValueError(f"No progress ID: {process_id}")
# logging.info(f"get_progress process_id @@@@@@ : {process_id} , progress : {json.dumps(data, indent=4)}")
# except Exception as e:
# logger.warning(f"get_progress Run Exception @@@@@@:{e}")
# raise HTTPException(status_code=404, detail=str(e))
# return ResponseModel(data=data)
@router.post('/model_process')
@@ -419,44 +460,42 @@ def model_process(request_data: ModelProgressModel):
}
"""
try:
logger.info(f"model_process request item is : @@@@@@:{json.dumps(request_data.dict())}")
logger.info(f"model_process request item is : @@@@@@:{json.dumps(request_data.dict(), indent=4)}")
data = model_transpose(image_path=request_data.model_path)
logger.info(f"model_process response @@@@@@:{json.dumps(data)}")
logger.info(f"model_process response @@@@@@:{json.dumps(data, indent=4)}")
except Exception as e:
logger.warning(f"model_process Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))
return ResponseModel(data=data)
# ##############################################################
@router.post("/design_batch_generate")
async def design_batch(file: UploadFile = File(...),
tasks_id: str = Form(...),
user_id: str = Form(...),
file_name: str = Form(...),
total: int = Form(...)
):
dbg_config = DBGConfigModel(
tasks_id=tasks_id,
user_id=user_id,
file_name=file_name,
total=total
)
contents = await file.read()
file_name = file.filename
await save_request_file(contents, file_name)
return await start_design_batch_generate(dbg_config, contents)
async def save_request_file(contents, file_name):
# 创建保存文件的目录(如果不存在)
save_dir = os.path.join(os.getcwd(), "service/design_batch", "request_data")
if not os.path.exists(save_dir):
os.makedirs(save_dir)
# 处理文件
file_path = os.path.join(save_dir, file_name)
with open(file_path, "wb") as f:
f.write(contents)
"""design 批量处理 停用"""
# @router.post("/design_batch_generate")
# async def design_batch(file: UploadFile = File(...),
# tasks_id: str = Form(...),
# user_id: str = Form(...),
# file_name: str = Form(...),
# total: int = Form(...)
# ):
# dbg_config = DBGConfigModel(
# tasks_id=tasks_id,
# user_id=user_id,
# file_name=file_name,
# total=total
# )
# contents = await file.read()
# file_name = file.filename
# await save_request_file(contents, file_name)
# return await start_design_batch_generate(dbg_config, contents)
#
#
# async def save_request_file(contents, file_name):
# # 创建保存文件的目录(如果不存在)
# save_dir = os.path.join(os.getcwd(), "service/design_batch", "request_data")
# if not os.path.exists(save_dir):
# os.makedirs(save_dir)
# # 处理文件
# file_path = os.path.join(save_dir, file_name)
# with open(file_path, "wb") as f:
# f.write(contents)

View File

@@ -30,10 +30,10 @@ def design_pre_processing(request_data: DesignPreProcessingModel):
}
"""
try:
logger.info(f"design_pre_processing request item is : @@@@@@:{json.dumps(request_data.dict())}")
logger.info(f"design_pre_processing request item is : @@@@@@:{json.dumps(request_data.dict(), indent=4)}")
server = DesignPreprocessing()
data = server.pipeline(image_list=request_data.sketches)
logger.info(f"design response @@@@@@:{json.dumps(data)}")
logger.info(f"design response @@@@@@:{json.dumps(data, indent=4)}")
except Exception as e:
logger.warning(f"design Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))

View File

@@ -33,18 +33,30 @@ def generate_image(request_item: GenerateImageModel, background_tasks: Backgroun
- **version**: 使用模型版本 fast 或者 high
示例参数:
1. txt 2 img
{
"tasks_id": "123-89",
"prompt": "skeleton sitting by the side of a river looking soulful, concert poster, 4k, artistic",
"image_url": "aida-collection-element/87/Printboard/842c09cf-7297-42d9-9e6e-9c17d4a13cb5.jpg",
"mode": "img2img",
"category": "sketch",
"gender": "male",
"version": "fast"
"tasks_id": "bd2cf809-24bc-49a6-91c9-193c6272a52e-2-89",
"prompt": "a single item of sketch of dress, 4k, white background",
"image_url": "",
"mode": "txt2img",
"category": "sketch",
"gender": "Female",
"version": "fast"
}
2. img 2 img
{
"tasks_id": "b861d4fa-5ae3-4a30-9c7a-7ba6bb9aa37b-1-89",
"prompt": "a single item of sketch of dress, 4k, white background",
"image_url": "aida-collection-element/89/Sketchboard/548da3a2-834f-49a7-b52c-e729c5ab5062.png",
"mode": "img2img",
"category": "sketch",
"gender": "Female",
"version": "fast"
}
"""
try:
logger.info(f"generate_image request item is : @@@@@@:{json.dumps(request_item.dict())}")
logger.info(f"generate_image request item is : @@@@@@:{json.dumps(request_item.dict(), indent=4)}")
service = GenerateImage(request_item)
background_tasks.add_task(service.get_result)
except Exception as e:
@@ -65,42 +77,41 @@ def generate_image(tasks_id: str):
return ResponseModel(data=data['data'])
'''multi view'''
'''multi view 停用'''
# @router.post("/generate_multi_view")
# def generate_multi_view(request_item: GenerateMultiViewModel, background_tasks: BackgroundTasks):
# """
# 创建一个具有以下参数的请求体:
# - **tasks_id**: 任务id 用于取消生成任务和获取生成结果
# - **image_url**: 前视角图的输入minio或S3 url 地址
#
# 示例参数:
# {
# "tasks_id": "123-89",
# "image_url": "aida-collection-element/87/Printboard/842c09cf-7297-42d9-9e6e-9c17d4a13cb5.jpg"
# }
# """
# try:
# logger.info(f"generate_multi_view request item is : @@@@@@:{json.dumps(request_item.dict(),indent=4)}")
# service = GenerateMultiView(request_item)
# background_tasks.add_task(service.get_result)
# except Exception as e:
# logger.warning(f"generate_multi_view Run Exception @@@@@@:{e}")
# raise HTTPException(status_code=404, detail=str(e))
# return ResponseModel()
@router.post("/generate_multi_view")
def generate_multi_view(request_item: GenerateMultiViewModel, background_tasks: BackgroundTasks):
"""
创建一个具有以下参数的请求体:
- **tasks_id**: 任务id 用于取消生成任务和获取生成结果
- **image_url**: 前视角图的输入minio或S3 url 地址
示例参数:
{
"tasks_id": "123-89",
"image_url": "aida-collection-element/87/Printboard/842c09cf-7297-42d9-9e6e-9c17d4a13cb5.jpg"
}
"""
try:
logger.info(f"generate_multi_view request item is : @@@@@@:{json.dumps(request_item.dict())}")
service = GenerateMultiView(request_item)
background_tasks.add_task(service.get_result)
except Exception as e:
logger.warning(f"generate_multi_view Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))
return ResponseModel()
@router.get("/generate_multi_view_cancel/{tasks_id}")
def generate_multi_view(tasks_id: str):
try:
logger.info(f"generate_cancel request item is : @@@@@@:{tasks_id}")
data = generate_multi_view_cancel(tasks_id)
logger.info(f"generate_cancel response @@@@@@:{data}")
except Exception as e:
logger.warning(f"generate_cancel Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))
return ResponseModel(data=data['data'])
# @router.get("/generate_multi_view_cancel/{tasks_id}")
# def generate_multi_view(tasks_id: str):
# try:
# logger.info(f"generate_cancel request item is : @@@@@@:{tasks_id}")
# data = generate_multi_view_cancel(tasks_id)
# logger.info(f"generate_cancel response @@@@@@:{data}")
# except Exception as e:
# logger.warning(f"generate_cancel Run Exception @@@@@@:{e}")
# raise HTTPException(status_code=404, detail=str(e))
# return ResponseModel(data=data['data'])
'''single logo'''
@@ -122,7 +133,7 @@ def generate_single_logo(request_item: GenerateSingleLogoImageModel, background_
}
"""
try:
logger.info(f"generate_single_logo request item is : @@@@@@:{json.dumps(request_item.dict())}")
logger.info(f"generate_single_logo request item is : @@@@@@:{json.dumps(request_item.dict(), indent=4)}")
service = GenerateSingleLogoImage(request_item)
background_tasks.add_task(service.get_result)
except Exception as e:
@@ -167,7 +178,7 @@ def generate_product_image(request_item: GenerateProductImageModel, background_t
}
"""
try:
logger.info(f"generate_product_image request item is : @@@@@@:{json.dumps(request_item.dict())}")
logger.info(f"generate_product_image request item is : @@@@@@:{json.dumps(request_item.dict(),indent=4)}")
service = GenerateProductImage(request_item)
background_tasks.add_task(service.get_result)
except Exception as e:
@@ -188,166 +199,164 @@ def generate_product_image(tasks_id: str):
return ResponseModel(data=data['data'])
'''relight image'''
'''relight image 停用'''
# @router.post("/generate_relight_image")
# def generate_relight_image(request_item: GenerateRelightImageModel, background_tasks: BackgroundTasks):
# """
# 创建一个具有以下参数的请求体:
# - **tasks_id**: 任务id 用于取消生成任务和获取生成结果
# - **prompt**: 想要生成图片的描述词
# - **image_url**: 被生成图片的S3或minio url地址
# - **direction**: 光源方向 Right Light Left Light Top Light Bottom Light
# - **product_type**: 输入single item 还是 overall item
#
#
# 示例参数:
# {
# "tasks_id": "123-89",
# "prompt": "beautiful woman, detailed face, sunshine, outdoor, warm atmosphere",
# "image_url": "aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png",
# "direction": "Right Light",
# "product_type": "overall"
# }
# """
# try:
# logger.info(f"generate_relight_image request item is : @@@@@@:{json.dumps(request_item.dict(),indent=4)}")
# service = GenerateRelightImage(request_item)
# background_tasks.add_task(service.get_result)
# except Exception as e:
# logger.warning(f"generate_relight_image Run Exception @@@@@@:{e}")
# raise HTTPException(status_code=404, detail=str(e))
# return ResponseModel()
#
#
# @router.get("/generate_relight_image_cancel_cancel/{tasks_id}")
# def generate_relight_image(tasks_id: str):
# try:
# logger.info(f"generate_relight_image_cancel_cancel request item is : @@@@@@:{tasks_id}")
# data = generate_relight_image_cancel(tasks_id)
# logger.info(f"generate_relight_image_cancel_cancel response @@@@@@:{data}")
# except Exception as e:
# logger.warning(f"generate_relight_image_cancel_cancel Run Exception @@@@@@:{e}")
# raise HTTPException(status_code=404, detail=str(e))
# return ResponseModel(data=data['data'])
@router.post("/generate_relight_image")
def generate_relight_image(request_item: GenerateRelightImageModel, background_tasks: BackgroundTasks):
"""
创建一个具有以下参数的请求体:
- **tasks_id**: 任务id 用于取消生成任务和获取生成结果
- **prompt**: 想要生成图片的描述词
- **image_url**: 被生成图片的S3或minio url地址
- **direction**: 光源方向 Right Light Left Light Top Light Bottom Light
- **product_type**: 输入single item 还是 overall item
"""batch generate img 停用"""
# @router.post("/batch_generate_product_image")
# async def batch_generate_product(request_batch_item: BatchGenerateProductImageModel):
# """
# 创建一个具有以下参数的请求体:
# - **tasks_id**: 任务id 用于获取生成结果
# - **prompt**: 想要生成图片的描述词
# - **image_url**: 被生成图片的S3或minio url地址
# - **image_strength**: 生成强度,越低越接近原图
# - **product_type**: 输入single item 还是 overall item
# - **batch_size**: 批生成数量
#
#
# 示例参数:
# {
# "tasks_id": "123-89",
# "prompt": "the best quality, masterpiece. detailed, high-res, simple background, studio photography, extremely detailed, updo, detailed face, face, close-up, HDR, UHD, 8K realistic, Highly detailed, simple background, Studio lighting",
# "image_url": "aida-results/result_00097282-ebb2-11ee-a822-b48351119060.png",
# "image_strength": 0.8,
# "product_type": "overall",
# "batch_size": 1
# }
# """
# return await start_product_batch_generate(request_batch_item)
#
#
# @router.post("/batch_generate_relight_image")
# async def batch_generate_relight(request_batch_item: BatchGenerateRelightImageModel):
# """
# 创建一个具有以下参数的请求体:
# - **tasks_id**: 任务id 用于获取生成结果
# - **prompt**: 想要生成图片的描述词
# - **image_url**: 被生成图片的S3或minio url地址
# - **direction**: 光源方向 Right Light Left Light Top Light Bottom Light
# - **product_type**: 输入single item 还是 overall item
# - **batch_size**: 批生成数量
#
#
# 示例参数:
# {
# "tasks_id": "123-89",
# "prompt": "beautiful woman, detailed face, sunshine, outdoor, warm atmosphere",
# "image_url": "aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png",
# "direction": "Right Light",
# "product_type": "overall",
# "batch_size": 1
# }
# """
# return await start_relight_batch_generate(request_batch_item)
示例参数:
{
"tasks_id": "123-89",
"prompt": "beautiful woman, detailed face, sunshine, outdoor, warm atmosphere",
"image_url": "aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png",
"direction": "Right Light",
"product_type": "overall"
}
"""
try:
logger.info(f"generate_relight_image request item is : @@@@@@:{json.dumps(request_item.dict())}")
service = GenerateRelightImage(request_item)
background_tasks.add_task(service.get_result)
except Exception as e:
logger.warning(f"generate_relight_image Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))
return ResponseModel()
@router.get("/generate_relight_image_cancel_cancel/{tasks_id}")
def generate_relight_image(tasks_id: str):
try:
logger.info(f"generate_relight_image_cancel_cancel request item is : @@@@@@:{tasks_id}")
data = generate_relight_image_cancel(tasks_id)
logger.info(f"generate_relight_image_cancel_cancel response @@@@@@:{data}")
except Exception as e:
logger.warning(f"generate_relight_image_cancel_cancel Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))
return ResponseModel(data=data['data'])
"""batch generate img"""
@router.post("/batch_generate_product_image")
async def batch_generate_product(request_batch_item: BatchGenerateProductImageModel):
"""
创建一个具有以下参数的请求体:
- **tasks_id**: 任务id 用于获取生成结果
- **prompt**: 想要生成图片的描述词
- **image_url**: 被生成图片的S3或minio url地址
- **image_strength**: 生成强度,越低越接近原图
- **product_type**: 输入single item 还是 overall item
- **batch_size**: 批生成数量
示例参数:
{
"tasks_id": "123-89",
"prompt": "the best quality, masterpiece. detailed, high-res, simple background, studio photography, extremely detailed, updo, detailed face, face, close-up, HDR, UHD, 8K realistic, Highly detailed, simple background, Studio lighting",
"image_url": "aida-results/result_00097282-ebb2-11ee-a822-b48351119060.png",
"image_strength": 0.8,
"product_type": "overall",
"batch_size": 1
}
"""
return await start_product_batch_generate(request_batch_item)
@router.post("/batch_generate_relight_image")
async def batch_generate_relight(request_batch_item: BatchGenerateRelightImageModel):
"""
创建一个具有以下参数的请求体:
- **tasks_id**: 任务id 用于获取生成结果
- **prompt**: 想要生成图片的描述词
- **image_url**: 被生成图片的S3或minio url地址
- **direction**: 光源方向 Right Light Left Light Top Light Bottom Light
- **product_type**: 输入single item 还是 overall item
- **batch_size**: 批生成数量
示例参数:
{
"tasks_id": "123-89",
"prompt": "beautiful woman, detailed face, sunshine, outdoor, warm atmosphere",
"image_url": "aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png",
"direction": "Right Light",
"product_type": "overall",
"batch_size": 1
}
"""
return await start_relight_batch_generate(request_batch_item)
@router.post("/batch_generate_pose_transform_image")
async def batch_generate_pose_transform(request_batch_item: BatchPoseTransformModel):
"""
创建一个具有以下参数的请求体:
- **tasks_id**: 任务id 用于取消生成任务和获取生成结果
- **image_url**: 被生成图片的S3或minio url地址
- **pose_id**: 1
- **batch_size**: 批生成数量
示例参数:
{
"tasks_id": "123-89",
"image_url": "aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png",
"pose_id": "1",
"batch_size": 1
}
"""
return await start_pose_transform_batch_generate(request_batch_item)
"""agent tool"""
@router.post("/agent_tool_generate_image")
def agent_tool_generate_image(request_item: AgentTollGenerateImageModel, background_tasks: BackgroundTasks):
"""
创建一个具有以下参数的请求体:
- **prompt**: 想要生成图片的描述词
- **category**: 生成图片的类别sketch print 等等
- **gender**: 生成sketch专用服装类别
- **version**: 使用模型版本 fast 或者 high
- **size**: 生成数量
- **version**: 使用模型版本 fast 或者 high
示例参数:
{
"prompt": "a single item of sketch of Wabi-sabi, skirt, tiered, 4k, white background",
"category": "sketch",
"gender": "male",
"size":2,
"version":"high"
}
"""
try:
logger.info(f"agent_tool_generate_image request item is : @@@@@@:{request_item.dict()}")
request_data = request_item.dict()
service = AgentToolGenerateImage(request_data['version'])
image_url_list, clothing_category_list = service.get_result(
prompt=request_data['prompt'],
size=request_data['size'],
version=request_data['version'],
category=request_data['category'],
gender=request_data['gender']
)
data = {
"image_url_list": image_url_list,
"clothing_category_list": clothing_category_list
}
logger.info(f"agent_tool_generate_image response item is : @@@@@@:{data}")
except Exception as e:
logger.warning(f"agent_tool_generate_image Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))
return ResponseModel(data=data)
# @router.post("/batch_generate_pose_transform_image")
# async def batch_generate_pose_transform(request_batch_item: BatchPoseTransformModel):
# """
# 创建一个具有以下参数的请求体:
# - **tasks_id**: 任务id 用于取消生成任务和获取生成结果
# - **image_url**: 被生成图片的S3或minio url地址
# - **pose_id**: 1
# - **batch_size**: 批生成数量
#
#
# 示例参数:
# {
# "tasks_id": "123-89",
# "image_url": "aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png",
# "pose_id": "1",
# "batch_size": 1
# }
# """
# return await start_pose_transform_batch_generate(request_batch_item)
#
#
# """agent tool"""
#
#
# @router.post("/agent_tool_generate_image")
# def agent_tool_generate_image(request_item: AgentTollGenerateImageModel):
# """
# 创建一个具有以下参数的请求体:
# - **prompt**: 想要生成图片的描述词
# - **category**: 生成图片的类别sketch print 等等
# - **gender**: 生成sketch专用服装类别
# - **version**: 使用模型版本 fast 或者 high
# - **size**: 生成数量
# - **version**: 使用模型版本 fast 或者 high
#
#
# 示例参数:
# {
# "prompt": "a single item of sketch of Wabi-sabi, skirt, tiered, 4k, white background",
# "category": "sketch",
# "gender": "male",
# "size":2,
# "version":"high"
# }
# """
# try:
# logger.info(f"agent_tool_generate_image request item is : @@@@@@:{request_item.dict()}")
# request_data = request_item.dict()
# service = AgentToolGenerateImage(request_data['version'])
# image_url_list, clothing_category_list = service.get_result(
# prompt=request_data['prompt'],
# size=request_data['size'],
# version=request_data['version'],
# category=request_data['category'],
# gender=request_data['gender']
# )
# data = {
# "image_url_list": image_url_list,
# "clothing_category_list": clothing_category_list
# }
# logger.info(f"agent_tool_generate_image response item is : @@@@@@:{data}")
# except Exception as e:
# logger.warning(f"agent_tool_generate_image Run Exception @@@@@@:{e}")
# raise HTTPException(status_code=404, detail=str(e))
# return ResponseModel(data=data)

View File

@@ -14,22 +14,22 @@ logger = logging.getLogger()
@router.post("/image2sketch")
def image2sketch(request_item: Image2SketchModel):
"""
创建一个具有以下参数的请求体:
- **image_url**: 提取图片url
- **default_style**: 原始、 1、2、3、4、5
- **sketch_bucket**: sketch保存的bucket
- **sketch_name**: sketch保存的object name
创建一个具有以下参数的请求体:
- **image_url**: 提取图片url
- **default_style**: 原始、 1、2、3、4、5
- **sketch_bucket**: sketch保存的bucket
- **sketch_name**: sketch保存的object name
示例参数:
{
"image_url": "test/image2sketch/real_Dress_3200fecdc83d0c556c2bd96aedbd7fbf.jpg_Img.jpg",
"default_style": 0,
"sketch_bucket": "test",
"sketch_name": "image2sketch/area_fill_img.png"
}
"""
示例参数:
{
"image_url": "test/image2sketch/real_Dress_3200fecdc83d0c556c2bd96aedbd7fbf.jpg_Img.jpg",
"default_style": 0,
"sketch_bucket": "test",
"sketch_name": "image2sketch/area_fill_img.png"
}
"""
try:
logger.info(f"image2sketch request item is : @@@@@@:{json.dumps(request_item.dict())}")
logger.info(f"image2sketch request item is : @@@@@@:{json.dumps(request_item.dict(),indent=4)}")
service = LineArtService(request_item)
result_url = service.get_result()
except Exception as e:

View File

@@ -1,116 +0,0 @@
import logging
import sys
from typing import Optional
from fastapi import APIRouter, HTTPException, Query
from concurrent.futures import ThreadPoolExecutor
import threading
from app.schemas.response_template import ResponseModel
from app.service.recommendation_system.import_sys_sketch_to_milvus import main as import_main
logger = logging.getLogger()
router = APIRouter()
# 使用线程池执行器来运行长时间任务
executor = ThreadPoolExecutor(max_workers=1)
# 用于跟踪任务状态
task_status = {"running": False}
def run_import_task(batch_size: int, retry_times: int, limit: Optional[int], offset: int, skip_create_collection: bool):
"""在后台线程中运行导入任务"""
original_argv = None
try:
task_status["running"] = True
# 保存原始 sys.argv
original_argv = sys.argv.copy()
# 模拟命令行参数
sys.argv = [
"import_sys_sketch_to_milvus.py",
"--batch-size", str(batch_size),
"--retry-times", str(retry_times),
]
if limit is not None:
sys.argv.extend(["--limit", str(limit)])
if offset > 0:
sys.argv.extend(["--offset", str(offset)])
if skip_create_collection:
sys.argv.append("--skip-create-collection")
import_main()
task_status["running"] = False
logger.info("导入任务完成")
except Exception as e:
task_status["running"] = False
logger.error(f"导入任务失败: {e}", exc_info=True)
raise
finally:
# 恢复原始 sys.argv
if original_argv is not None:
sys.argv = original_argv
@router.post("/import-sys-sketch", response_model=ResponseModel)
async def import_sys_sketch(
batch_size: int = Query(1000, description="批量处理大小默认1000"),
retry_times: int = Query(3, description="失败重试次数默认3"),
limit: Optional[int] = Query(None, description="限制处理数量(用于测试,默认:不限制)"),
offset: int = Query(0, description="起始偏移量默认0"),
skip_create_collection: bool = Query(False, description="跳过创建集合(如果集合已存在)"),
):
"""
从 t_sys_file 导入系统图向量到 Milvus
该接口会异步执行导入任务,任务在后台运行。
"""
try:
# 检查是否有任务正在运行
if task_status["running"]:
raise HTTPException(
status_code=409,
detail="已有导入任务正在运行,请等待完成后再试"
)
# 在后台线程中执行任务
executor.submit(
run_import_task,
batch_size,
retry_times,
limit,
offset,
skip_create_collection
)
return ResponseModel(
code=200,
msg="导入任务已启动,正在后台执行",
data={
"status": "started",
"batch_size": batch_size,
"retry_times": retry_times,
"limit": limit,
"offset": offset,
"skip_create_collection": skip_create_collection
}
)
except HTTPException:
raise
except Exception as e:
logger.error(f"启动导入任务失败: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"启动导入任务失败: {str(e)}")
@router.get("/import-sys-sketch/status", response_model=ResponseModel)
async def get_import_status():
"""
获取导入任务状态
"""
return ResponseModel(
code=200,
msg="OK",
data={
"running": task_status["running"]
}
)

View File

@@ -35,10 +35,10 @@ def mannequins_edit(request_data: MannequinModel):
}**
"""
try:
logger.info(f"mannequins_edit request item is : @@@@@@:{json.dumps(request_data.dict())}")
logger.info(f"mannequins_edit request item is : @@@@@@:{json.dumps(request_data.dict(),indent=4)}")
service = MannequinEditService(request_data)
data = service()
logger.info(f"mannequins_edit response @@@@@@:{json.dumps(data)}")
logger.info(f"mannequins_edit response @@@@@@:{json.dumps(data, indent=4)}")
except Exception as e:
logger.warning(f"mannequins_edit Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))

View File

@@ -4,55 +4,55 @@ import logging
import requests
from fastapi import APIRouter, BackgroundTasks, HTTPException
from app.core.config import COMFYUI_SERVER_ADDRESS
from app.core.config import settings
from app.schemas.comfyui_i2v import ComfyuiI2VModel, ComfyuiFLF2VModel
from app.schemas.pose_transform import PoseTransformModel
from app.schemas.response_template import ResponseModel
from app.service.comfyui_I2V.flf2v_server import ComfyUIServerFLF2V
from app.service.comfyui_I2V.i2v_server import ComfyUIServerI2V
from app.service.comfyui_I2V.pose2v_server import ComfyUIServerPose2V
from app.service.generate_image.service_pose_transform import PoseTransformService, infer_cancel as pose_transform_infer_cancel
router = APIRouter()
logger = logging.getLogger()
"""停用"""
@router.post("/pose_transform")
def pose_transform(request_item: PoseTransformModel, background_tasks: BackgroundTasks):
"""
创建一个具有以下参数的请求体:
- **tasks_id**: 任务id 用于取消生成任务和获取生成结果
- **image_url**: 被生成图片的S3或minio url地址
- **pose_id**: 1
# @router.post("/pose_transform")
# def pose_transform(request_item: PoseTransformModel, background_tasks: BackgroundTasks):
# """
# 创建一个具有以下参数的请求体:
# - **tasks_id**: 任务id 用于取消生成任务和获取生成结果
# - **image_url**: 被生成图片的S3或minio url地址
# - **pose_id**: 1
#
#
# 示例参数:
# {
# "tasks_id": "123-89",
# "image_url": "aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png",
# "pose_id": "1"
# }
# """
# try:
# logger.info(f"pose_transform request item is : @@@@@@:{json.dumps(request_item.dict(),indent=4)}")
# service = PoseTransformService(request_item)
# background_tasks.add_task(service.get_result)
# except Exception as e:
# logger.warning(f"pose_transform Run Exception @@@@@@:{e}")
# raise HTTPException(status_code=404, detail=str(e))
# return ResponseModel()
示例参数:
{
"tasks_id": "123-89",
"image_url": "aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png",
"pose_id": "1"
}
"""
try:
logger.info(f"pose_transform request item is : @@@@@@:{json.dumps(request_item.dict())}")
service = PoseTransformService(request_item)
background_tasks.add_task(service.get_result)
except Exception as e:
logger.warning(f"pose_transform Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))
return ResponseModel()
@router.get("/pose_transform_cancel/{tasks_id}")
def pose_transform_cancel(tasks_id: str):
try:
logger.info(f"pose_transform_cancel request item is : @@@@@@:{tasks_id}")
data = pose_transform_infer_cancel(tasks_id)
logger.info(f"pose_transform_cancel response @@@@@@:{data}")
except Exception as e:
logger.warning(f"pose_transform_cancel Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))
return ResponseModel(data=data['data'])
# @router.get("/pose_transform_cancel/{tasks_id}")
# def pose_transform_cancel(tasks_id: str):
# try:
# logger.info(f"pose_transform_cancel request item is : @@@@@@:{tasks_id}")
# data = pose_transform_infer_cancel(tasks_id)
# logger.info(f"pose_transform_cancel response @@@@@@:{data}")
# except Exception as e:
# logger.warning(f"pose_transform_cancel Run Exception @@@@@@:{e}")
# raise HTTPException(status_code=404, detail=str(e))
# return ResponseModel(data=data['data'])
"""
@@ -77,7 +77,7 @@ def comfyui_image_pose_2_video(request_item: PoseTransformModel, background_task
}
"""
try:
logger.info(f"image_pose_2_video request item is : @@@@@@:{json.dumps(request_item.dict())}")
logger.info(f"image_pose_2_video request item is : @@@@@@:{json.dumps(request_item.dict(),indent=4)}")
service = ComfyUIServerPose2V(request_item)
background_tasks.add_task(service.get_result)
except Exception as e:
@@ -107,7 +107,7 @@ def comfyui_image_2_video(request_item: ComfyuiI2VModel, background_tasks: Backg
}
"""
try:
logger.info(f"image_2_video request item is : @@@@@@:{json.dumps(request_item.dict())}")
logger.info(f"image_2_video request item is : @@@@@@:{json.dumps(request_item.dict(),indent=4)}")
service = ComfyUIServerI2V(request_item)
background_tasks.add_task(service.get_result)
except Exception as e:
@@ -139,7 +139,7 @@ def comfyui_flf_2_video(request_item: ComfyuiFLF2VModel, background_tasks: Backg
}
"""
try:
logger.info(f"flf_2_video request item is : @@@@@@:{json.dumps(request_item.dict())}")
logger.info(f"flf_2_video request item is : @@@@@@:{json.dumps(request_item.dict(),indent=4)}")
service = ComfyUIServerFLF2V(request_item)
background_tasks.add_task(service.get_result)
except Exception as e:
@@ -153,7 +153,7 @@ def comfyui_i_2_video_cancel(tasks_id: str):
try:
logger.info(f"comfyui_i_2_video_cancel request item is : @@@@@@:{tasks_id}")
response = requests.post(
f"http://{COMFYUI_SERVER_ADDRESS}/interrupt",
f"http://{settings.COMFYUI_SERVER_ADDRESS}/interrupt",
json={"prompt_id": tasks_id}
)
data = {}

View File

@@ -1,85 +0,0 @@
import logging
from fastapi import APIRouter, HTTPException
from concurrent.futures import ThreadPoolExecutor
from app.schemas.response_template import ResponseModel
from app.service.recommendation_system.precompute import run_precompute
logger = logging.getLogger()
router = APIRouter()
# 使用线程池执行器来运行长时间任务
executor = ThreadPoolExecutor(max_workers=1)
# 用于跟踪任务状态
task_status = {"running": False}
def run_precompute_task():
"""在后台线程中运行预计算任务"""
try:
task_status["running"] = True
logger.info("开始执行预计算任务...")
run_precompute()
task_status["running"] = False
logger.info("预计算任务完成")
except Exception as e:
task_status["running"] = False
logger.error(f"预计算任务失败: {e}", exc_info=True)
raise
@router.post("/precompute", response_model=ResponseModel)
async def precompute():
"""
运行预计算任务
该接口会异步执行预计算任务,包括:
1. 优化数据库表结构
2. 历史数据迁移
3. 初始用户偏好向量生成
任务在后台运行。
"""
try:
# 检查是否有任务正在运行
if task_status["running"]:
raise HTTPException(
status_code=409,
detail="已有预计算任务正在运行,请等待完成后再试"
)
# 在后台线程中执行任务
executor.submit(run_precompute_task)
return ResponseModel(
code=200,
msg="预计算任务已启动,正在后台执行",
data={
"status": "started",
"tasks": [
"优化数据库表结构",
"历史数据迁移",
"初始用户偏好向量生成"
]
}
)
except HTTPException:
raise
except Exception as e:
logger.error(f"启动预计算任务失败: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"启动预计算任务失败: {str(e)}")
@router.get("/precompute/status", response_model=ResponseModel)
async def get_precompute_status():
"""
获取预计算任务状态
"""
return ResponseModel(
code=200,
msg="OK",
data={
"running": task_status["running"]
}
)

View File

@@ -1,13 +1,10 @@
import json
import logging
import time
from fastapi import APIRouter, HTTPException
from app.schemas.prompt_generation import PromptGenerationImageModel, ImageRequest
from app.schemas.response_template import ResponseModel
from app.service.prompt_generation.chatgpt_for_translation import get_translation_from_llama3, \
get_prompt_from_image
from app.service.prompt_generation.chatgpt_for_translation import get_translation_from_llama3, get_prompt_from_image
router = APIRouter()
logger = logging.getLogger()
@@ -34,19 +31,19 @@ def prompt_generation(request_data: PromptGenerationImageModel):
raise HTTPException(status_code=404, detail=str(e))
return ResponseModel(data=data)
@router.post("/img2prompt")
def get_prompt_from_img(img: ImageRequest):
"""
自动识别图片并输出为prompt
:param img: 图片的minio地址
:return: 图片的文字描述
"""
text = ("Please describe the clothing in the image and provide a line art description of the outfit. "
"The description should allow for the reconstruction of the corresponding line art based on the details "
"given.")
logger.info(f"get_prompt_from_img request item is : @@@@@@:{img}")
description = get_prompt_from_image(img, text)
logger.info(f"生成的图片描述 response @@@@@@:{description}")
return description
# 停用
# @router.post("/img2prompt")
# def get_prompt_from_img(img: ImageRequest):
# """
# 自动识别图片并输出为prompt
#
# :param img: 图片的minio地址
# :return: 图片的文字描述
# """
# text = ("Please describe the clothing in the image and provide a line art description of the outfit. "
# "The description should allow for the reconstruction of the corresponding line art based on the details "
# "given.")
# logger.info(f"get_prompt_from_img request item is : @@@@@@:{img}")
# description = get_prompt_from_image(img, text)
# logger.info(f"生成的图片描述 response @@@@@@:{description}")
# return description

View File

@@ -26,9 +26,9 @@ def query_image(request_data: QueryImageModel):
}
"""
try:
logger.info(f"query_image request item is : @@@@@@:{json.dumps(request_data.dict())}")
logger.info(f"query_image request item is : @@@@@@:{json.dumps(request_data.dict(), indent=4)}")
data = query(request_data.gender, request_data.content)
logger.info(f"query_image response @@@@@@:{json.dumps(data)}")
logger.info(f"query_image response @@@@@@:{json.dumps(data, indent=4)}")
except Exception as e:
logger.warning(f"query_image Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))

View File

@@ -1,175 +1,206 @@
import io
import logging
import math
import sys
from typing import List, Optional
from fastapi import HTTPException, APIRouter, Query
from apscheduler.schedulers.background import BackgroundScheduler
import time
from typing import List
from app.service.recommendation_system.recommendation_api import get_recommendations as get_new_recommendations
from app.service.recommendation_system.incremental_listener import start_background_listener
from app.service.recommendation_system.milvus_client import create_collection
import numpy as np
from apscheduler.schedulers.background import BackgroundScheduler
from apscheduler.triggers.cron import CronTrigger
from fastapi import HTTPException, APIRouter
from app.service.recommend.service import load_resources, matrix_data
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')
logger = logging.getLogger()
router = APIRouter()
# ========== 旧版推荐接口(基于 npy 矩阵,已废弃)==========
# @router.get("/recommend/{user_id}/{category}/{num_recommendations}/{brand_id}/{brand_scale}", response_model=List[str])
# async def get_recommendations(user_id: int, category: str, brand_id: int, brand_scale: float, num_recommendations: int = 10):
# """
# :param user_id: 4
# :param category: female_skirt
# :param num_recommendations: 1
# :return:
# [
# "aida-sys-image/images/female/skirt/903000017.jpg"
# ]
# """
# try:
# start_time = time.time()
# cache_key = (user_id, category)
# # === 新增:用户存在性检查 ===
# user_exists_inter = user_id in matrix_data["user_index_interaction"]
# user_exists_feat = user_id in matrix_data["user_index_feature"]
#
# # 任一矩阵不存在用户则返回随机推荐
# if not (user_exists_inter and user_exists_feat):
# logger.info(f"用户 {user_id} 数据不完整,触发随机推荐")
# return get_random_recommendations(category, num_recommendations)
#
# # 检查缓存
# if cache_key in matrix_data["cached_scores"]:
# processed_inter, processed_feat = matrix_data["cached_scores"][cache_key]
# valid_sketch_idxs_inter = matrix_data["cached_valid_idxs"][cache_key]
# else:
# # 实时计算逻辑(同原代码)
# user_idx_inter = matrix_data["user_index_interaction"].get(user_id)
# user_idx_feature = matrix_data["user_index_feature"].get(user_id)
#
# category_iids = matrix_data["category_to_iids"].get(category, [])
# valid_sketch_idxs_inter = [
# idx for iid, idx in matrix_data["sketch_index_interaction"].items()
# if iid in category_iids
# ]
#
# # 处理交互分数
# raw_inter_scores = []
# if user_idx_inter is not None and valid_sketch_idxs_inter:
# raw_inter_scores = matrix_data["interaction_matrix"][user_idx_inter, valid_sketch_idxs_inter]
# processed_inter = raw_inter_scores * 0.7
#
# # 处理特征分数
# valid_sketch_idxs_feature = [
# idx for iid, idx in matrix_data["sketch_index_feature"].items()
# if iid in category_iids
# ]
# raw_feat_scores = []
# if user_idx_feature is not None and valid_sketch_idxs_feature:
# raw_feat_scores = matrix_data["feature_matrix"][user_idx_feature, valid_sketch_idxs_feature]
# raw_feat_scores = (raw_feat_scores - np.min(raw_feat_scores)) / (
# np.max(raw_feat_scores) - np.min(raw_feat_scores) + 1e-8)
# processed_feat = raw_feat_scores
# else:
# processed_feat = np.array([])
#
# # 更新缓存
# matrix_data["cached_scores"][cache_key] = (processed_inter, processed_feat)
# matrix_data["cached_valid_idxs"][cache_key] = valid_sketch_idxs_inter
#
# # 合并分数
# if brand_id is not None:
# brand_idx_feature = matrix_data["brand_index_map"].get(brand_id)
#
# brand_feat_valid = (
# matrix_data["brand_feature_matrix"].size > 0 and # 矩阵非空
# brand_idx_feature is not None and
# valid_sketch_idxs_feature # 有可用索引
# )
#
# if brand_feat_valid:
# raw_brand_feat_scores = matrix_data["brand_feature_matrix"][
# brand_idx_feature, valid_sketch_idxs_feature
# ]
# raw_brand_feat_scores = (raw_brand_feat_scores - np.min(raw_brand_feat_scores)) / (
# np.max(raw_brand_feat_scores) - np.min(raw_brand_feat_scores) + 1e-8
# )
# processed_brand_feat = raw_brand_feat_scores
#
# # 如果 processed_feat 是空的,替换为全 0避免 shape 不一致
# if processed_feat.size == 0:
# processed_feat = np.zeros_like(processed_brand_feat)
#
# final_scores = processed_inter + 0.3 * (
# (1 - brand_scale) * processed_feat + brand_scale * processed_brand_feat
# )
# else:
# # brand 信息不可用
# final_scores = processed_inter + 0.3 * processed_feat if processed_feat.size > 0 else processed_inter
# else:
# final_scores = processed_inter + 0.3 * processed_feat if processed_feat.size > 0 else processed_inter
#
# valid_sketch_idxs = matrix_data["cached_valid_idxs"][cache_key]
#
# # 概率采样
# scores = np.array(final_scores)
#
# # 调整后的概率转换带温度控制的softmax
# def calibrated_softmax(scores, temperature=1.0):
# scores = scores / temperature
# scale = scores - max(scores)
# exps = np.exp(scale)
# return exps / np.sum(exps)
#
# probs = calibrated_softmax(scores, 0.09)
#
# chosen_indices = np.random.choice(
# len(valid_sketch_idxs),
# size=min(num_recommendations, len(valid_sketch_idxs)),
# p=probs,
# replace=False
# )
# recommendations = [matrix_data["iid_to_sketch"][valid_sketch_idxs[idx]] for idx in chosen_indices]
#
# logger.info(f"推荐生成完成,耗时: {time.time() - start_time:.2f}秒")
# return recommendations
# except Exception as e:
# logger.error(f"推荐失败: {str(e)}", exc_info=True)
# raise HTTPException(status_code=500, detail=str(e))
# @router.on_event("startup")
@router.on_event("startup")
async def startup_event():
"""启动时初始化增量监听任务"""
try:
# 确保 Milvus 集合已创建(若已存在则直接返回)
try:
create_collection()
except Exception as exc:
logger.error("Milvus 集合创建/检查失败,不影响服务继续启动: %s", exc, exc_info=True)
# 配置定时任务
scheduler = BackgroundScheduler()
start_background_listener(scheduler)
scheduler.start()
logger.info("增量监听定时任务已启动")
except Exception as e:
logger.error(f"启动增量监听任务失败: {e}", exc_info=True)
# 初始加载
load_resources()
# 配置定时任务
scheduler = BackgroundScheduler()
scheduler.add_job(
load_resources,
trigger=CronTrigger(hour=0, minute=30),
name="每日资源刷新"
)
scheduler.start()
logger.info("定时任务已启动")
@router.get("/recommend/{user_id}/{category}", response_model=List[str])
async def recommend(
user_id: int,
category: str,
style: Optional[str] = Query(
None,
description="风格样式(可选):若传入,则在利用分支对同 style 的候选进行加分",
),
):
"""新版推荐接口Milvus + Redis 偏好向量)。"""
def softmax(scores):
max_score = max(scores)
exp_scores = [math.exp(s - max_score) for s in scores]
sum_exp = sum(exp_scores)
return [s / sum_exp for s in exp_scores]
# def get_random_recommendations(category: str, num: int) -> List[str]:
# """根据预加载热度向量推荐(冷启动)"""
# try:
# heat_data = matrix_data.get("heat_data", {})
#
# if category not in heat_data:
# raise ValueError(f"热度数据缺少类别 {category},使用随机推荐")
#
# heat_dict = heat_data[category] # {url: score}
# urls = list(heat_dict.keys())
# scores = list(heat_dict.values())
#
# if not urls:
# raise ValueError("该类别下无热度记录,使用随机推荐")
#
# probs = softmax(scores)
# sample_size = min(num, len(urls))
# sampled_urls = random.choices(urls, weights=probs, k=sample_size)
#
# return sampled_urls
#
# except Exception as e:
# # 回退:完全随机推荐
# all_iids = list(matrix_data["iid_to_sketch"].keys())
# category_iids = matrix_data["category_to_iids"].get(category, all_iids)
# sample_size = min(num, len(category_iids))
# sampled = np.random.choice(category_iids, size=sample_size, replace=False)
# return [matrix_data["iid_to_sketch"][iid] for iid in sampled]
def get_random_recommendations(category: str, num: int) -> List[str]:
"""全品类随机推荐"""
all_iids = list(matrix_data["iid_to_sketch"].keys())
# 优先从当前品类选择
category_iids = matrix_data["category_to_iids"].get(category, all_iids)
# 确保不超出实际数量
sample_size = min(num, len(category_iids))
sampled = np.random.choice(category_iids, size=sample_size, replace=False)
return [matrix_data["iid_to_sketch"][iid] for iid in sampled]
@router.get("/recommend/{user_id}/{category}/{num_recommendations}/{brand_id}/{brand_scale}", response_model=List[str])
async def get_recommendations(user_id: int, category: str, brand_id: int, brand_scale: float, num_recommendations: int = 10):
"""
@param user_id: 4
@param category: female_skirt
@param num_recommendations: 1
@return:
[
"aida-sys-image/images/female/skirt/903000017.jpg"
]
"""
try:
results = get_new_recommendations(user_id, category, style)
path = results[0] if results else ""
return [path]
logger.info(f"user_id:{user_id}-----category:{category}-----brand_id:{brand_id}-----brand_scale:{brand_scale}-----num_recommendations:{num_recommendations}")
start_time = time.time()
cache_key = (user_id, category)
# === 新增:用户存在性检查 ===
user_exists_inter = user_id in matrix_data["user_index_interaction"]
user_exists_feat = user_id in matrix_data["user_index_feature"]
# 任一矩阵不存在用户则返回随机推荐
if not (user_exists_inter and user_exists_feat):
logger.info(f"用户 {user_id} 数据不完整,触发随机推荐")
return get_random_recommendations(category, num_recommendations)
# 检查缓存
if cache_key in matrix_data["cached_scores"]:
processed_inter, processed_feat = matrix_data["cached_scores"][cache_key]
valid_sketch_idxs_inter = matrix_data["cached_valid_idxs"][cache_key]
else:
# 实时计算逻辑(同原代码)
user_idx_inter = matrix_data["user_index_interaction"].get(user_id)
user_idx_feature = matrix_data["user_index_feature"].get(user_id)
category_iids = matrix_data["category_to_iids"].get(category, [])
valid_sketch_idxs_inter = [
idx for iid, idx in matrix_data["sketch_index_interaction"].items()
if iid in category_iids
]
# 处理交互分数
raw_inter_scores = []
if user_idx_inter is not None and valid_sketch_idxs_inter:
raw_inter_scores = matrix_data["interaction_matrix"][user_idx_inter, valid_sketch_idxs_inter]
processed_inter = raw_inter_scores * 0.7
# 处理特征分数
valid_sketch_idxs_feature = [
idx for iid, idx in matrix_data["sketch_index_feature"].items()
if iid in category_iids
]
raw_feat_scores = []
if user_idx_feature is not None and valid_sketch_idxs_feature:
raw_feat_scores = matrix_data["feature_matrix"][user_idx_feature, valid_sketch_idxs_feature]
raw_feat_scores = (raw_feat_scores - np.min(raw_feat_scores)) / (
np.max(raw_feat_scores) - np.min(raw_feat_scores) + 1e-8)
processed_feat = raw_feat_scores
else:
processed_feat = np.array([])
# 更新缓存
matrix_data["cached_scores"][cache_key] = (processed_inter, processed_feat)
matrix_data["cached_valid_idxs"][cache_key] = valid_sketch_idxs_inter
# 合并分数
if brand_id is not None:
brand_idx_feature = matrix_data["brand_index_map"].get(brand_id)
brand_feat_valid = (
matrix_data["brand_feature_matrix"].size > 0 and # 矩阵非空
brand_idx_feature is not None and
valid_sketch_idxs_feature # 有可用索引
)
if brand_feat_valid:
raw_brand_feat_scores = matrix_data["brand_feature_matrix"][
brand_idx_feature, valid_sketch_idxs_feature
]
raw_brand_feat_scores = (raw_brand_feat_scores - np.min(raw_brand_feat_scores)) / (
np.max(raw_brand_feat_scores) - np.min(raw_brand_feat_scores) + 1e-8
)
processed_brand_feat = raw_brand_feat_scores
# 如果 processed_feat 是空的,替换为全 0避免 shape 不一致
if processed_feat.size == 0:
processed_feat = np.zeros_like(processed_brand_feat)
final_scores = processed_inter + 0.3 * (
(1 - brand_scale) * processed_feat + brand_scale * processed_brand_feat
)
else:
# brand 信息不可用
final_scores = processed_inter + 0.3 * processed_feat if processed_feat.size > 0 else processed_inter
else:
final_scores = processed_inter + 0.3 * processed_feat if processed_feat.size > 0 else processed_inter
valid_sketch_idxs = matrix_data["cached_valid_idxs"][cache_key]
# 概率采样
scores = np.array(final_scores)
# 调整后的概率转换带温度控制的softmax
def calibrated_softmax(scores, temperature=1.0):
scores = scores / temperature
scale = scores - max(scores)
exps = np.exp(scale)
return exps / np.sum(exps)
probs = calibrated_softmax(scores, 0.09)
chosen_indices = np.random.choice(
len(valid_sketch_idxs),
size=min(num_recommendations, len(valid_sketch_idxs)),
p=probs,
replace=False
)
recommendations = [matrix_data["iid_to_sketch"][valid_sketch_idxs[idx]] for idx in chosen_indices]
logger.info(f"推荐生成完成,耗时: {time.time() - start_time:.2f}")
return recommendations
except Exception as e:
logger.error("新版推荐接口失败 [user=%s, category=%s]: %s", user_id, category, e, exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
logger.error(f"推荐失败: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail=str(e))

View File

@@ -1,42 +1,40 @@
from fastapi import APIRouter
from app.api import api_attribute_retrieve, api_query_image
from app.api import api_attribute_retrieve
from app.api import api_brand_dna
from app.api import api_brighten
from app.api import api_chat_robot
from app.api import api_clothing_seg
from app.api import api_design
from app.api import api_design_pre_processing
from app.api import api_extraction_project_info
from app.api import api_generate_image
from app.api import api_image2sketch
from app.api import api_import_sys_sketch
from app.api import api_mannequins_edit
from app.api import api_pose_transform
from app.api import api_precompute
from app.api import api_prompt_generation
from app.api import api_recommendation
from app.api import api_super_resolution
from app.api import api_test
router = APIRouter()
router.include_router(api_test.router, tags=["test"], prefix="/test")
router.include_router(api_super_resolution.router, tags=["super_resolution"], prefix="/api")
router.include_router(api_generate_image.router, tags=["generate_image"], prefix="/api")
router.include_router(api_attribute_retrieve.router, tags=["attribute_retrieve"], prefix="/api")
router.include_router(api_design.router, tags=['design'], prefix="/api")
router.include_router(api_chat_robot.router, tags=['chat_robot'], prefix="/api")
router.include_router(api_prompt_generation.router, tags=['prompt_generation'], prefix="/api")
router.include_router(api_design_pre_processing.router, tags=['design_pre_processing'], prefix="/api")
router.include_router(api_image2sketch.router, tags=['api_image2sketch'], prefix="/api")
router.include_router(api_brighten.router, tags=['api_brighten'], prefix="/api")
router.include_router(api_query_image.router, tags=['api_query_image'], prefix="/api")
router.include_router(api_brand_dna.router, tags=['api_brand_dna'], prefix="/api")
router.include_router(api_recommendation.router, tags=['api_recommendation'], prefix="/api")
router.include_router(api_mannequins_edit.router, tags=['api_mannequins_edit'], prefix="/api")
router.include_router(api_pose_transform.router, tags=['api_pose_transform'], prefix="/api")
router.include_router(api_clothing_seg.router, tags=['api_clothing_seg'], prefix="/api")
router.include_router(api_extraction_project_info.router, tags=['api_extraction_project_info'], prefix="/api")
router.include_router(api_import_sys_sketch.router, tags=['api_import_sys_sketch'], prefix="/api")
router.include_router(api_precompute.router, tags=['api_precompute'], prefix="/api")
"""停用"""
# from app.api import api_chat_robot
# from app.api import api_query_image
# from app.api import api_brighten
# from app.api import api_extraction_project_info
# from app.api import api_image2sketch
# from app.api import api_super_resolution
# router.include_router(api_image2sketch.router, tags=['api_image2sketch'], prefix="/api")
# router.include_router(api_query_image.router, tags=['api_query_image'], prefix="/api")
# router.include_router(api_chat_robot.router, tags=['chat_robot'], prefix="/api")
# router.include_router(api_super_resolution.router, tags=["super_resolution"], prefix="/api")
# router.include_router(api_brighten.router, tags=['api_brighten'], prefix="/api")
# router.include_router(api_extraction_project_info.router, tags=['api_extraction_project_info'], prefix="/api")

View File

@@ -27,7 +27,7 @@ def super_resolution(request_item: SuperResolutionModel, background_tasks: Backg
}
"""
try:
logger.info(f"super_resolution request item is : @@@@@@:{json.dumps(request_item.dict())}")
logger.info(f"super_resolution request item is : @@@@@@:{json.dumps(request_item.dict(),indent=4)}")
service = SuperResolution(request_item)
background_tasks.add_task(service.sr_result)
except Exception as e:

View File

@@ -4,8 +4,7 @@ import logging
from fastapi import APIRouter
from fastapi import HTTPException
from app.core.config import SR_RABBITMQ_QUEUES, GI_RABBITMQ_QUEUES, GPI_RABBITMQ_QUEUES, GRI_RABBITMQ_QUEUES, OSS, JAVA_STREAM_API_URL, GMV_RABBITMQ_QUEUES, SLOGAN_RABBITMQ_QUEUES, GEN_SINGLE_LOGO_RABBITMQ_QUEUES, PS_RABBITMQ_QUEUES, BATCH_GPI_RABBITMQ_QUEUES, BATCH_GRI_RABBITMQ_QUEUES, \
BATCH_PS_RABBITMQ_QUEUES, RABBITMQ_ENV
from app.core.config import settings, SR_RABBITMQ_QUEUES, GMV_RABBITMQ_QUEUES, PS_RABBITMQ_QUEUES, SLOGAN_RABBITMQ_QUEUES, GI_RABBITMQ_QUEUES, GPI_RABBITMQ_QUEUES, GRI_RABBITMQ_QUEUES, BATCH_GPI_RABBITMQ_QUEUES, BATCH_GRI_RABBITMQ_QUEUES, BATCH_PS_RABBITMQ_QUEUES
from app.schemas.response_template import ResponseModel
logger = logging.getLogger()
@@ -15,9 +14,9 @@ router = APIRouter()
@router.get("{id}")
def test(id: int):
data = {
"RABBITMQ_ENV":RABBITMQ_ENV,
"超分 SR_RABBITMQ_QUEUES": SR_RABBITMQ_QUEUES,
"多视角 GMV_RABBITMQ_QUEUES": GMV_RABBITMQ_QUEUES,
"RABBITMQ_ENV": settings.SERVE_ENV,
# "超分 SR_RABBITMQ_QUEUES": SR_RABBITMQ_QUEUES,
# "多视角 GMV_RABBITMQ_QUEUES": GMV_RABBITMQ_QUEUES,
"pose transform PS_RABBITMQ_QUEUES": PS_RABBITMQ_QUEUES,
"logan SLOGAN_RABBITMQ_QUEUES": SLOGAN_RABBITMQ_QUEUES,
"image and single logo GI_RABBITMQ_QUEUES": GI_RABBITMQ_QUEUES,
@@ -29,10 +28,9 @@ def test(id: int):
"batch relight BATCH_GRI_RABBITMQ_QUEUES": BATCH_GRI_RABBITMQ_QUEUES,
"batch pose transform BATCH_PS_RABBITMQ_QUEUES": BATCH_PS_RABBITMQ_QUEUES,
"JAVA_STREAM_API_URL": JAVA_STREAM_API_URL,
"local_oss_server": OSS
"JAVA_STREAM_API_URL": settings.JAVA_STREAM_API_URL,
}
logger.info(json.dumps(data))
logger.info(json.dumps(data, ensure_ascii=False, indent=4))
if id == 1:
raise HTTPException(status_code=404, detail="Item not found")

235
app/core/config.backup.py Normal file
View File

@@ -0,0 +1,235 @@
import os
import pika
from dotenv import load_dotenv
from pydantic import BaseSettings
BASE_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../'))
load_dotenv(os.path.join(BASE_DIR, '.env'))
class Settings(BaseSettings):
PROJECT_NAME: str = 'FASTAPI BASE'
SECRET_KEY: str = ''
API_PREFIX: str = ''
BACKEND_CORS_ORIGINS: list[str] = ['*']
DATABASE_URL: str = ''
ACCESS_TOKEN_EXPIRE_SECONDS: int = 60 * 60 * 24 * 7 # Token expired after 7 days
SECURITY_ALGORITHM: str = 'HS256'
LOGGING_CONFIG_FILE: str = os.path.join(BASE_DIR, 'logging_env.py')
OSS = "minio"
DEBUG = False
if DEBUG:
LOGS_PATH = "logs/"
CATEGORY_PATH = "service/attribute/config/descriptor/category/category_dis.csv"
SEG_CACHE_PATH = "../seg_cache/"
POSE_TRANSFORM_VIDEO_PATH = "../pose_transform_video/"
RECOMMEND_PATH_PREFIX = "service/recommend/"
CHROMADB_PATH = "./chromadb/"
else:
LOGS_PATH = "app/logs/"
CATEGORY_PATH = "app/service/attribute/config/descriptor/category/category_dis.csv"
SEG_CACHE_PATH = "/seg_cache/"
POSE_TRANSFORM_VIDEO_PATH = "/pose_transform_video/"
RECOMMEND_PATH_PREFIX = "app/service/recommend/"
CHROMADB_PATH = "/chromadb/"
# RABBITMQ_ENV = "" # 生产环境
RABBITMQ_ENV = os.getenv("RABBITMQ_ENV", "-dev")
# RABBITMQ_ENV = "-local" # 本地测试环境
if RABBITMQ_ENV == "-dev":
JAVA_STREAM_API_URL = f"https://develop.api.aida.com.hk/api/third/party/receiveDesignResults"
elif RABBITMQ_ENV == "-prod":
JAVA_STREAM_API_URL = f"https://api.aida.com.hk/api/third/party/receiveDesignResults"
settings = Settings()
# minio 配置
MINIO_URL = "www.minio-api.aida.com.hk"
MINIO_ACCESS = 'vXKFLSJkYeEq2DrSZvkB'
MINIO_SECRET = 'uKTZT3x7C43WvPN9QTc99DiRkwddWZrG9Uh3JVlR'
MINIO_SECURE = True
# S3 配置
S3_ACCESS_KEY = "AKIAVD3OJIMF6UJFLSHZ"
S3_AWS_SECRET_ACCESS_KEY = "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"
S3_REGION_NAME = "ap-east-1"
# redis 配置
REDIS_HOST = "10.1.1.240"
REDIS_PORT = "6379"
REDIS_DB = "2"
# rabbitmq config
RABBITMQ_PARAMS = {
"host": "18.167.251.121",
"port": 5672,
"credentials": pika.credentials.PlainCredentials(username='rabbit', password='123456'),
"virtual_host": "/"
}
# milvus 配置
MILVUS_URL = "http://10.1.1.240:19530"
MILVUS_TOKEN = "root:Milvus"
MILVUS_ALIAS = "default"
MILVUS_TABLE_KEYPOINT = "keypoint_cache_2"
MILVUS_TABLE_SEG = "seg_cache"
# Mysql 配置
DB_HOST = '18.167.251.121' # 数据库主机地址
# DB_PORT = int( 33006)
DB_PORT = 33008 # 数据库端口
DB_USERNAME = 'aida_con_python' # 数据库用户名
DB_PASSWORD = '123456' # 数据库密码
DB_NAME = 'aida' # 数据库库名
# openai
os.environ['SERPAPI_API_KEY'] = "a793513017b0718db7966207c31703d280d12435c982f1e67bbcbffa52e7632c"
OPENAI_STREAM = True
BUFFER_THRESHOLD = 6 # must be even number
SINGLE_TOKEN_THRESHOLD = 200
TOKEN_THRESHOLD = 600
OPENAI_TEMPERATURE = 0
# OPENAI_API_KEY = "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"
OPENAI_API_KEY = "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"
OPENAI_MODEL = "gpt-3.5-turbo-0613"
OPENAI_MODEL_LIST = {"gpt-3.5-turbo-0613",
"gpt-3.5-turbo-16k-0613",
"gpt-4-0314",
"gpt-4-32k-0314",
"gpt-4-0613",
"gpt-4-32k-0613", }
# SR service config
SR_MODEL_NAME = "super_resolution"
SR_TRITON_URL = "10.1.1.240:10031"
SR_MINIO_BUCKET = "aida-users"
SR_RABBITMQ_QUEUES = f"SuperResolution{RABBITMQ_ENV}"
# GenerateImage service config
FAST_GI_MODEL_URL = '10.1.1.243:10011'
FAST_GI_MODEL_NAME = 'stable_diffusion_xl'
GI_MODEL_URL = '10.1.1.240:10061'
GI_MODEL_NAME = 'flux'
GMV_MODEL_URL = '10.1.1.243:10081'
GMV_MODEL_NAME = 'multi_view'
GMV_RABBITMQ_QUEUES = f"GenerateMultiView{RABBITMQ_ENV}"
GI_MINIO_BUCKET = "aida-users"
GI_RABBITMQ_QUEUES = f"GenerateImage{RABBITMQ_ENV}"
GI_SYS_IMAGE_URL = "aida-sys-image/generate_image/white_image.jpg"
# SLOGAN service config
SLOGAN_RABBITMQ_QUEUES = f"Slogan{RABBITMQ_ENV}"
# Generate Single Logo service config
GSL_MODEL_URL = '10.1.1.243:10041'
GSL_MINIO_BUCKET = "aida-users"
GSL_MODEL_NAME = 'stable_diffusion_xl_transparent'
GEN_SINGLE_LOGO_RABBITMQ_QUEUES = f"GenSingleLogo{RABBITMQ_ENV}"
# Generate Product service config
# GPI_RABBITMQ_QUEUES = os.getenv("GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES", f"ToProductImage{RABBITMQ_ENV}")
# GPI_MODEL_NAME_OVERALL = 'sdxl_ensemble_all'
# GPI_MODEL_URL = '10.1.1.243:10051'
# Generate Product service config 旧版product img 模型
GPI_RABBITMQ_QUEUES = f"ToProductImage{RABBITMQ_ENV}"
BATCH_GPI_RABBITMQ_QUEUES = f"BatchToProductImage{RABBITMQ_ENV}"
GPI_MODEL_NAME_OVERALL = 'diffusion_ensemble_all'
GPI_MODEL_NAME_SINGLE = 'stable_diffusion_1_5_cnet'
GPI_MODEL_URL = '10.1.1.243:10051'
# Generate Single Logo service config
GRI_RABBITMQ_QUEUES = f"Relight{RABBITMQ_ENV}"
BATCH_GRI_RABBITMQ_QUEUES = f"BatchRelight{RABBITMQ_ENV}"
GRI_MODEL_NAME_OVERALL = 'diffusion_relight_ensemble'
GRI_MODEL_NAME_SINGLE = 'stable_diffusion_1_5_relight'
GRI_MODEL_URL = '10.1.1.240:10051'
# Pose Transform service config
PS_RABBITMQ_QUEUES = f"PoseTransform{RABBITMQ_ENV}"
BATCH_PS_RABBITMQ_QUEUES = f"BatchPoseTransform{RABBITMQ_ENV}"
PT_MODEL_URL = '10.1.1.243:10061'
# SEG service config
SEGMENTATION = {
"new_model_name": "seg_knet",
"name": "seg_ocrnet_hr18",
"input": "seg_input__0",
"output": "seg_output__0",
}
# ollama config
OLLAMA_URL = "http://10.1.1.240:11434/api/embeddings"
# design batch
BATCH_DESIGN_RABBITMQ_QUEUES = f"DesignBatch{RABBITMQ_ENV}"
# DESIGN config
DESIGN_MODEL_URL = '10.1.1.240:10000'
AIDA_CLOTHING = "aida-clothing"
KEYPOINT_RESULT_TABLE_FIELD_SET = ('neckline_left', 'neckline_right', 'shoulder_left', 'shoulder_right', 'armpit_left', 'armpit_right',
'cuff_left_in', 'cuff_left_out', 'cuff_right_in', 'cuff_right_out', 'waistband_left', 'waistband_right')
# DESIGN 预处理
IF_DEBUG_SHOW = False
# 优先级
PRIORITY_DICT = {
'earring_front': 99,
'bag_front': 98,
'hairstyle_front': 97,
'outwear_front': 20,
'tops_front': 19,
'dress_front': 18,
'blouse_front': 17,
'skirt_front': 16,
'trousers_front': 15,
'bottoms_front': 14,
'shoes_right': 1,
'shoes_left': 1,
'body': 0,
'bottoms_back': -14,
'trousers_back': -15,
'skirt_back': -16,
'blouse_back': -17,
'dress_back': -18,
'tops_back': -19,
'outwear_back': -20,
'hairstyle_back': -97,
'bag_back': -98,
'earring_back': -99,
}
QWEN_API_KEY = "sk-f31c29e61ac2498ba5e307aaa6dc10e0"
DB_CONFIG = {
"host": "18.167.251.121",
"port": 3306,
"user": "root",
"password": "QWa998345",
"database": "aida",
"charset": "utf8mb4"
}
TABLE_CATEGORIES = {
"female_dress": "female/dress",
"female_outwear": "female/outwear",
"female_trousers": "female/trousers",
"female_skirt": "female/skirt",
"female_blouse": "female/blouse",
"male_tops": "male/tops",
"male_bottoms": "male/bottoms",
"male_outwear": "male/outwear"
}
# --- ComfyUI 配置信息 ---
COMFYUI_SERVER_ADDRESS = "10.1.2.227:8080" # 替换为您的 ComfyUI 服务器地址

View File

@@ -1,188 +1,91 @@
import os
import pika
from dotenv import load_dotenv
from pydantic import BaseSettings
BASE_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../'))
load_dotenv(os.path.join(BASE_DIR, '.env'))
from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings):
PROJECT_NAME: str = 'FASTAPI BASE'
SECRET_KEY: str = ''
API_PREFIX: str = ''
BACKEND_CORS_ORIGINS: list[str] = ['*']
DATABASE_URL: str = ''
ACCESS_TOKEN_EXPIRE_SECONDS: int = 60 * 60 * 24 * 7 # Token expired after 7 days
SECURITY_ALGORITHM: str = 'HS256'
LOGGING_CONFIG_FILE: str = os.path.join(BASE_DIR, 'logging_env.py')
"""
应用配置类。Pydantic Settings 会自动从环境变量和 .env 文件中加载这些值。
"""
model_config = SettingsConfigDict(
env_file='.env',
env_file_encoding='utf-8',
# extra='ignore' # 忽略环境变量中多余的键
)
# --- 服务端口配置信息 ---
PORT: int = Field(default=8001, description="")
# --- 服务环境 配置信息 ---
SERVE_ENV: str = Field(default='', description="")
# --- 开发状态 配置信息 ---
DEBUG: bool = Field(default=False, description="")
# --- 千问api 配置信息 ---
QWEN_API_KEY: str = Field(default="", description="")
# --- ComfyUI 配置信息 ---
COMFYUI_SERVER_ADDRESS: str = Field(default='', description="")
OSS = "minio"
DEBUG = False
if DEBUG:
LOGS_PATH = "logs/"
CATEGORY_PATH = "service/attribute/config/descriptor/category/category_dis.csv"
SEG_CACHE_PATH = "../seg_cache/"
POSE_TRANSFORM_VIDEO_PATH = "../pose_transform_video/"
RECOMMEND_PATH_PREFIX = "service/recommend/"
CHROMADB_PATH = "./chromadb/"
else:
LOGS_PATH = "app/logs/"
CATEGORY_PATH = "app/service/attribute/config/descriptor/category/category_dis.csv"
SEG_CACHE_PATH = "/seg_cache/"
POSE_TRANSFORM_VIDEO_PATH = "/pose_transform_video/"
RECOMMEND_PATH_PREFIX = "app/service/recommend/"
CHROMADB_PATH = "/chromadb/"
# --- minio 配置信息 ---
MINIO_URL: str = Field(default='', description="")
MINIO_ACCESS: str = Field(default='', description="")
MINIO_SECRET: str = Field(default='', description="")
MINIO_SECURE: bool = Field(default=True, description="")
# RABBITMQ_ENV = "" # 生产环境
RABBITMQ_ENV = os.getenv("RABBITMQ_ENV", "-dev")
# RABBITMQ_ENV = "-local" # 本地测试环境
# --- redis 配置信息 ---
REDIS_HOST: str = Field(default='', description="")
REDIS_PORT: str = Field(default='', description="")
REDIS_DB: int = Field(default=0, description="")
# --- mysql 配置信息 ---
MYSQL_HOST: str = Field(default='', description="")
MYSQL_PORT: str = Field(default='', description="")
MYSQL_USER: str = Field(default='', description="")
MYSQL_PASSWORD: str = Field(default='', description="")
MYSQL_DB: str = Field(default='', description="")
MYSQL_CHARSET: str = Field(default='utf8mb4', description="")
# --- rabbit-mq 配置信息 ---
MQ_HOST: str = Field(default='', description="")
MQ_PORT: str = Field(default='', description="")
MQ_USERNAME: str = Field(default='', description="")
MQ_PASSWORD: str = Field(default='', description="")
MQ_VIRTUAL_HOST: str = Field(default='/', description="")
MQ_ENV: str = Field(default='', description="")
# --- milvus 配置信息 ---
MILVUS_URL: str = Field(default='', description="")
MILVUS_TOKEN: str = Field(default='', description="")
MILVUS_ALIAS: str = Field(default='', description="")
# --- ollama 配置信息 ---
CHROMADB_PATH: str = Field(default='', description="")
# --- ollama 配置信息 ---
OLLAMA_URL: str = Field(default='', description="")
# --- Design Callback Java 接口 ---
JAVA_STREAM_API_URL: str = Field(default='', description="")
# --- 其他配置信息 以下均为Docker容器内配置---
LOGS_PATH: str = Field(default="/logs/", description="")
CATEGORY_PATH: str = Field(default="/app/service/attribute/config/descriptor/category/category_dis.csv", description="")
SEG_CACHE_PATH: str = Field(default="/seg_cache/", description="")
RECOMMEND_PATH_PREFIX: str = Field(default="/app/service/recommend/", description="")
if RABBITMQ_ENV == "-dev":
JAVA_STREAM_API_URL = f"https://develop.api.aida.com.hk/api/third/party/receiveDesignResults"
elif RABBITMQ_ENV == "-prod":
JAVA_STREAM_API_URL = f"https://api.aida.com.hk/api/third/party/receiveDesignResults"
settings = Settings()
# minio 配置
MINIO_URL = "www.minio-api.aida.com.hk"
MINIO_ACCESS = 'vXKFLSJkYeEq2DrSZvkB'
MINIO_SECRET = 'uKTZT3x7C43WvPN9QTc99DiRkwddWZrG9Uh3JVlR'
MINIO_SECURE = True
# S3 配置
S3_ACCESS_KEY = "AKIAVD3OJIMF6UJFLSHZ"
S3_AWS_SECRET_ACCESS_KEY = "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"
S3_REGION_NAME = "ap-east-1"
# redis 配置
REDIS_HOST = "10.1.1.240"
REDIS_PORT = "6379"
REDIS_DB = "2"
# rabbitmq config
RABBITMQ_PARAMS = {
"host": "18.167.251.121",
"port": 5672,
"credentials": pika.credentials.PlainCredentials(username='rabbit', password='123456'),
"virtual_host": "/"
"""Design 服务"""
# 推荐服装类别映射
TABLE_CATEGORIES = {
"female_dress": "female/dress",
"female_outwear": "female/outwear",
"female_trousers": "female/trousers",
"female_skirt": "female/skirt",
"female_blouse": "female/blouse",
"male_tops": "male/tops",
"male_bottoms": "male/bottoms",
"male_outwear": "male/outwear"
}
# milvus 配置
MILVUS_URL = "http://10.1.1.240:19530"
MILVUS_TOKEN = "root:Milvus"
MILVUS_ALIAS = "default"
MILVUS_TABLE_KEYPOINT = "keypoint_cache_2"
MILVUS_TABLE_SEG = "seg_cache"
# Mysql 配置
DB_HOST = '18.167.251.121' # 数据库主机地址
# DB_PORT = int( 33006)
DB_PORT = 33008 # 数据库端口
DB_USERNAME = 'aida_con' # 数据库用户名
DB_PASSWORD = '123456' # 数据库密码
DB_NAME = 'aida_back' # 数据库库名
# openai
os.environ['SERPAPI_API_KEY'] = "a793513017b0718db7966207c31703d280d12435c982f1e67bbcbffa52e7632c"
OPENAI_STREAM = True
BUFFER_THRESHOLD = 6 # must be even number
SINGLE_TOKEN_THRESHOLD = 200
TOKEN_THRESHOLD = 600
OPENAI_TEMPERATURE = 0
# OPENAI_API_KEY = "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"
OPENAI_API_KEY = "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"
OPENAI_MODEL = "gpt-3.5-turbo-0613"
OPENAI_MODEL_LIST = {"gpt-3.5-turbo-0613",
"gpt-3.5-turbo-16k-0613",
"gpt-4-0314",
"gpt-4-32k-0314",
"gpt-4-0613",
"gpt-4-32k-0613", }
# SR service config
SR_MODEL_NAME = "super_resolution"
SR_TRITON_URL = "10.1.1.240:10031"
SR_MINIO_BUCKET = "aida-users"
SR_RABBITMQ_QUEUES = f"SuperResolution{RABBITMQ_ENV}"
# GenerateImage service config
FAST_GI_MODEL_URL = '10.1.1.243:10011'
FAST_GI_MODEL_NAME = 'stable_diffusion_xl'
GI_MODEL_URL = '10.1.1.240:10061'
GI_MODEL_NAME = 'flux'
GMV_MODEL_URL = '10.1.1.243:10081'
GMV_MODEL_NAME = 'multi_view'
GMV_RABBITMQ_QUEUES = f"GenerateMultiView{RABBITMQ_ENV}"
GI_MINIO_BUCKET = "aida-users"
GI_RABBITMQ_QUEUES = f"GenerateImage{RABBITMQ_ENV}"
GI_SYS_IMAGE_URL = "aida-sys-image/generate_image/white_image.jpg"
# SLOGAN service config
SLOGAN_RABBITMQ_QUEUES = f"Slogan{RABBITMQ_ENV}"
# Generate Single Logo service config
GSL_MODEL_URL = '10.1.1.243:10041'
GSL_MINIO_BUCKET = "aida-users"
GSL_MODEL_NAME = 'stable_diffusion_xl_transparent'
GEN_SINGLE_LOGO_RABBITMQ_QUEUES = f"GenSingleLogo{RABBITMQ_ENV}"
# Generate Product service config
# GPI_RABBITMQ_QUEUES = os.getenv("GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES", f"ToProductImage{RABBITMQ_ENV}")
# GPI_MODEL_NAME_OVERALL = 'sdxl_ensemble_all'
# GPI_MODEL_URL = '10.1.1.243:10051'
# Generate Product service config 旧版product img 模型
GPI_RABBITMQ_QUEUES = f"ToProductImage{RABBITMQ_ENV}"
BATCH_GPI_RABBITMQ_QUEUES = f"BatchToProductImage{RABBITMQ_ENV}"
GPI_MODEL_NAME_OVERALL = 'diffusion_ensemble_all'
GPI_MODEL_NAME_SINGLE = 'stable_diffusion_1_5_cnet'
GPI_MODEL_URL = '10.1.1.243:10051'
# Generate Single Logo service config
GRI_RABBITMQ_QUEUES = f"Relight{RABBITMQ_ENV}"
BATCH_GRI_RABBITMQ_QUEUES = f"BatchRelight{RABBITMQ_ENV}"
GRI_MODEL_NAME_OVERALL = 'diffusion_relight_ensemble'
GRI_MODEL_NAME_SINGLE = 'stable_diffusion_1_5_relight'
GRI_MODEL_URL = '10.1.1.240:10051'
# Pose Transform service config
PS_RABBITMQ_QUEUES = f"PoseTransform{RABBITMQ_ENV}"
BATCH_PS_RABBITMQ_QUEUES = f"BatchPoseTransform{RABBITMQ_ENV}"
PT_MODEL_URL = '10.1.1.243:10061'
# SEG service config
SEGMENTATION = {
"new_model_name": "seg_knet",
"name": "seg_ocrnet_hr18",
"input": "seg_input__0",
"output": "seg_output__0",
}
# ollama config
OLLAMA_URL = "http://10.1.1.240:11434/api/embeddings"
# design batch
BATCH_DESIGN_RABBITMQ_QUEUES = f"DesignBatch{RABBITMQ_ENV}"
# DESIGN config
DESIGN_MODEL_URL = '10.1.1.240:10000'
AIDA_CLOTHING = "aida-clothing"
KEYPOINT_RESULT_TABLE_FIELD_SET = ('neckline_left', 'neckline_right', 'shoulder_left', 'shoulder_right', 'armpit_left', 'armpit_right',
'cuff_left_in', 'cuff_left_out', 'cuff_right_in', 'cuff_right_out', 'waistband_left', 'waistband_right')
# DESIGN 预处理
IF_DEBUG_SHOW = False
# 优先级
# Design前后排优先级
PRIORITY_DICT = {
'earring_front': 99,
'bag_front': 98,
@@ -208,28 +111,71 @@ PRIORITY_DICT = {
'bag_back': -98,
'earring_back': -99,
}
# Design 关键点字段
KEYPOINT_RESULT_TABLE_FIELD_SET = ('neckline_left', 'neckline_right', 'shoulder_left', 'shoulder_right', 'armpit_left', 'armpit_right', 'cuff_left_in', 'cuff_left_out', 'cuff_right_in', 'cuff_right_out', 'waistband_left', 'waistband_right')
# milvus配置信息
MILVUS_TABLE_KEYPOINT = "keypoint_cache_2"
QWEN_API_KEY = "sk-f31c29e61ac2498ba5e307aaa6dc10e0"
# ollama 地址
OLLAMA_URL = "http://10.1.1.240:11434/api/embeddings"
DB_CONFIG = {
"host": "18.167.251.121",
"port": 3306,
"user": "root",
"password": "QWa998345",
"database": "aida",
"charset": "utf8mb4"
}
"""Triton Server Config"""
# Design
DESIGN_MODEL_URL = '10.1.1.240:10000'
DESIGN_MODEL_NAME = 'seg_knet'
# Generate Image
GI_MODEL_URL = '10.1.1.240:10061'
GI_MODEL_NAME = 'flux'
# Generate Single Logo
GSL_MODEL_URL = '10.1.1.243:10041'
GSL_MODEL_NAME = 'stable_diffusion_xl_transparent'
# Generate Product (整套和单品)
GPI_MODEL_URL = '10.1.1.243:10051'
GPI_MODEL_NAME_OVERALL = 'diffusion_ensemble_all'
GPI_MODEL_NAME_SINGLE = 'stable_diffusion_1_5_cnet'
TABLE_CATEGORIES = {
"female_dress": "female/dress",
"female_outwear": "female/outwear",
"female_trousers": "female/trousers",
"female_skirt": "female/skirt",
"female_blouse": "female/blouse",
"male_tops": "male/tops",
"male_bottoms": "male/bottoms",
"male_outwear": "male/outwear"
}
# 以下停用中...*************
# 多视角生成
GMV_MODEL_URL = '10.1.1.243:10081'
GMV_MODEL_NAME = 'multi_view'
# 超分
SR_MODEL_NAME = "super_resolution"
SR_TRITON_URL = "10.1.1.240:10031"
# 打光
GRI_MODEL_URL = '10.1.1.240:10051'
GRI_MODEL_NAME_OVERALL = 'diffusion_relight_ensemble'
GRI_MODEL_NAME_SINGLE = 'stable_diffusion_1_5_relight'
# agent 图片生成
FAST_GI_MODEL_URL = '10.1.1.243:10011'
FAST_GI_MODEL_NAME = 'stable_diffusion_xl'
# 图转视频 triton版
PT_MODEL_URL = '10.1.1.243:10061'
# --- ComfyUI 配置信息 ---
COMFYUI_SERVER_ADDRESS = "10.1.2.227:8080" # 替换为您的 ComfyUI 服务器地址
# *************
"""MQ 队列信息"""
# 生成图片 moodboard printboard sketchboard
GI_RABBITMQ_QUEUES = f"GenerateImage-{settings.SERVE_ENV}"
# 生成slogan
SLOGAN_RABBITMQ_QUEUES = f"Slogan-{settings.SERVE_ENV}"
# 转产品图
GPI_RABBITMQ_QUEUES = f"ToProductImage-{settings.SERVE_ENV}"
# 产品图转视频
PS_RABBITMQ_QUEUES = f"PoseTransform-{settings.SERVE_ENV}"
# 以下停用中...*************
# 产品图打光
GRI_RABBITMQ_QUEUES = f"Relight-{settings.SERVE_ENV}"
# 超分
SR_RABBITMQ_QUEUES = f"SuperResolution-{settings.SERVE_ENV}"
# 生成多视图
GMV_RABBITMQ_QUEUES = f"GenerateMultiView-{settings.SERVE_ENV}"
# 批量转产品图
BATCH_GPI_RABBITMQ_QUEUES = f"BatchToProductImage-{settings.SERVE_ENV}"
# 批量打光
BATCH_GRI_RABBITMQ_QUEUES = f"BatchRelight-{settings.SERVE_ENV}"
# 批量图片转视频
BATCH_PS_RABBITMQ_QUEUES = f"BatchPoseTransform-{settings.SERVE_ENV}"
# 批量design
BATCH_DESIGN_RABBITMQ_QUEUES = f"DesignBatch-{settings.SERVE_ENV}"
# *************

10
app/core/mysql_config.py Normal file
View File

@@ -0,0 +1,10 @@
from app.core.config import settings
DB_CONFIG = {
"host": settings.MYSQL_HOST,
"port": settings.MYSQL_PORT,
"user": settings.MYSQL_USER,
"password": settings.MYSQL_PASSWORD,
"database": settings.MYSQL_DB,
"charset": settings.MYSQL_CHARSET,
}

View File

@@ -0,0 +1,10 @@
# rabbitmq config
import pika
from app.core.config import settings
RABBITMQ_PARAMS = {
"host": settings.MQ_HOST,
"port": settings.MQ_PORT,
"credentials": pika.credentials.PlainCredentials(username=settings.MQ_USERNAME, password=settings.MQ_PASSWORD),
"virtual_host": settings.MQ_VIRTUAL_HOST,
}

View File

@@ -79,12 +79,8 @@
}
]
}
],
"process_id": "87",
"tasks_id": ,
"tasks_id": ""
}
//用 openai jsonl
//

View File

@@ -1,10 +1,17 @@
# 1. 这里的顺序至关重要!必须在最顶端
import sys
try:
import asyncore
except ImportError:
import pyasyncore
sys.modules['asyncore'] = pyasyncore
import logging.config
import uvicorn
from apscheduler.schedulers.background import BackgroundScheduler
from apscheduler.triggers.cron import CronTrigger
from fastapi import FastAPI
from fastapi import HTTPException, Request
from fastapi import HTTPException
from fastapi.responses import JSONResponse
from app.api.api_route import router
@@ -12,19 +19,22 @@ from app.core.config import settings
from app.core.record_api_count import count_api_calls
from app.schemas.response_template import ResponseModel
from logging_env import LOGGER_CONFIG_DICT
from dotenv import load_dotenv
from starlette.middleware.cors import CORSMiddleware
logging.config.dictConfig(LOGGER_CONFIG_DICT)
logging.getLogger("pika").setLevel(logging.WARNING)
from starlette.middleware.cors import CORSMiddleware
logger = logging.getLogger(__name__)
load_dotenv()
def get_application() -> FastAPI:
application = FastAPI(
title=settings.PROJECT_NAME, docs_url="/docs", redoc_url='/re-docs',
openapi_url=f"{settings.API_PREFIX}/openapi.json",
docs_url="/docs",
redoc_url='/re-docs',
openapi_url=f"/openapi.json",
description='''
Base frame with FastAPI
- Super Resolution API
@@ -33,13 +43,13 @@ def get_application() -> FastAPI:
)
application.add_middleware(
CORSMiddleware,
allow_origins=[str(origin) for origin in settings.BACKEND_CORS_ORIGINS],
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
application.middleware("http")(count_api_calls)
application.include_router(router=router, prefix=settings.API_PREFIX)
application.include_router(router=router)
return application
@@ -47,14 +57,12 @@ app = get_application()
@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException):
async def http_exception_handler(exc: HTTPException):
return JSONResponse(
status_code=exc.status_code,
content=ResponseModel(code=exc.status_code, msg=exc.detail, data=exc.detail).dict()
)
if __name__ == '__main__':
uvicorn.run(app, host="0.0.0.0", port=8000)
uvicorn.run(app, host="0.0.0.0", port=settings.PORT)

View File

@@ -1,22 +1,24 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
import logging
from pprint import pprint
import torch
import cv2
import mmcv
import numpy as np
import pandas as pd
from minio import Minio
import torch
import tritonclient.http as httpclient
from app.core.config import *
from minio import Minio
from app.core.config import settings, DESIGN_MODEL_URL
from app.schemas.attribute_retrieve import AttributeRecognitionModel
from app.service.utils.oss_client import oss_get_image
from app.service.utils.new_oss_client import oss_get_image
minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)
class AttributeRecognition:
def __init__(self, const, request_data):
# self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
self.request_data = []
for i, sketch in enumerate(request_data):
self.request_data.append(
@@ -96,11 +98,12 @@ class AttributeRecognition:
res = {**dict1, **dict2}
return res
def get_image(self, url):
@staticmethod
def get_image(url):
# response = self.minio_client.get_object(url.split("/", 1)[0], url.split("/", 1)[1])
# img = np.frombuffer(response.data, np.uint8) # 转成8位无符号整型
# img = cv2.imdecode(img, cv2.IMREAD_COLOR) #
img = oss_get_image(bucket=url.split("/", 1)[0], object_name=url.split("/", 1)[1], data_type="cv2")
img = oss_get_image(oss_client=minio_client, bucket=url.split("/", 1)[0], object_name=url.split("/", 1)[1], data_type="cv2")
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
return img

View File

@@ -7,24 +7,25 @@
@Date 2023/9/16 18:31:08
@detail
"""
from minio import Minio
from skimage import transform
import cv2
import mmcv
import numpy as np
import pandas as pd
from minio import Minio
import tritonclient.http as httpclient
import torch
from app.core.config import *
from app.core.config import settings, DESIGN_MODEL_URL
from app.schemas.attribute_retrieve import CategoryRecognitionModel
from app.service.utils.oss_client import oss_get_image
from app.service.utils.new_oss_client import oss_get_image
minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)
class CategoryRecognition:
def __init__(self, request_data):
self.attr_type = pd.read_csv(CATEGORY_PATH)
# self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
self.attr_type = pd.read_csv(settings.CATEGORY_PATH)
self.request_data = []
self.triton_client = httpclient.InferenceServerClient(url=DESIGN_MODEL_URL)
for sketch in request_data:
@@ -46,13 +47,14 @@ class CategoryRecognition:
preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
return preprocessed_img
def get_image(self, url):
@staticmethod
def get_image(url):
# Get data of an object.
# Read data from response.
# response = self.minio_client.get_object(url.split("/", 1)[0], url.split("/", 1)[1])
# img = np.frombuffer(response.data, np.uint8) # 转成8位无符号整型
# img = cv2.imdecode(img, cv2.IMREAD_COLOR) # 解码
img = oss_get_image(bucket=url.split("/", 1)[0], object_name=url.split("/", 1)[1], data_type="cv2")
img = oss_get_image(oss_client=minio_client, bucket=url.split("/", 1)[0], object_name=url.split("/", 1)[1], data_type="cv2")
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
return img
@@ -68,7 +70,7 @@ class CategoryRecognition:
colattr = list(self.attr_type['labelName'])
task = self.attr_type['taskName'][0]
# self.attr_type['taskName'][0]
maxsc = np.max(scores[0][:5])
indexs = np.argwhere(scores == maxsc)[:, 1]

View File

@@ -9,15 +9,16 @@ import torch.nn.functional as F
import tritonclient.http as httpclient
from minio import Minio
from app.core.config import MINIO_URL, MINIO_ACCESS, MINIO_SECRET, MINIO_SECURE, DESIGN_MODEL_URL, CATEGORY_PATH
from app.core.config import DESIGN_MODEL_URL
from app.core.config import settings
from app.schemas.brand_dna import BrandDnaModel
from app.service.attribute.config import local_debug_const, const
from app.service.attribute.config import const
from app.service.utils.generate_uuid import generate_uuid
from app.service.utils.new_oss_client import oss_upload_image, oss_get_image
logger = logging.getLogger()
minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)
minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
logger = logging.getLogger()
class BrandDna:
@@ -25,7 +26,7 @@ class BrandDna:
self.sketch_bucket = "test"
self.image_url = request_item.image_url
self.is_brand_dna = request_item.is_brand_dna
self.attr_type = pd.read_csv(CATEGORY_PATH)
self.attr_type = pd.read_csv(settings.CATEGORY_PATH)
# self.attr_type = pd.read_csv(r"E:\workspace\trinity_client_aida\app\service\attribute\config\descriptor\category\category_dis.csv")
self.att_client = httpclient.InferenceServerClient(url=DESIGN_MODEL_URL)
self.seg_client = httpclient.InferenceServerClient(url='10.1.1.243:30000')

View File

@@ -3,23 +3,25 @@ import logging
import cv2
import numpy as np
import tritonclient.grpc as grpcclient
from langchain.output_parsers import ResponseSchema, StructuredOutputParser
from langchain_classic.output_parsers import ResponseSchema, StructuredOutputParser
from langchain_community.chat_models import ChatTongyi
from langchain_core.prompts import PromptTemplate
# from langchain_openai import ChatOpenAI
from minio import Minio
from tritonclient.utils import np_to_triton_dtype
from app.core.config import GI_MODEL_URL, MINIO_URL, MINIO_ACCESS, MINIO_SECRET, MINIO_SECURE, GI_MODEL_NAME
from app.core.config import GI_MODEL_URL, GI_MODEL_NAME
from app.schemas.brand_dna import GenerateBrandModel
from app.service.utils.generate_uuid import generate_uuid
from app.service.utils.new_oss_client import oss_upload_image
from app.core.config import settings
class GenerateBrandInfo:
def __init__(self, request_data):
# minio client init
self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
self.generate_logo_prompt = None
self.minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)
# user info init
self.user_id = request_data.user_id
@@ -55,7 +57,7 @@ class GenerateBrandInfo:
return self.result_data
def llm_generate_brand_info(self):
output = self.model(self._input.to_messages())
output = self.model.invoke(self._input.to_messages())
brand_data = self.output_parser.parse(output.content)
self.result_data = brand_data
self.generate_logo_prompt = brand_data['brand_logo_prompt']
@@ -87,8 +89,8 @@ class GenerateBrandInfo:
def upload_logo_image(self, image, object_name):
try:
_, img_byte_array = cv2.imencode('.jpg', image)
object_name = f'{self.user_id}/{self.category}/{object_name}'
req = oss_upload_image(oss_client=self.minio_client, bucket="aida-users", object_name=object_name, image_bytes=img_byte_array)
object_name = f'{self.user_id}/{self.category}/{object_name}.jpg'
oss_upload_image(oss_client=self.minio_client, bucket="aida-users", object_name=object_name, image_bytes=img_byte_array)
image_url = f"aida-users/{object_name}"
return image_url
except Exception as e:

View File

@@ -1,32 +0,0 @@
from dotenv import load_dotenv
from langchain.output_parsers import StructuredOutputParser, ResponseSchema
from langchain_core.prompts import PromptTemplate
from langchain_openai import ChatOpenAI
# 加载.env文件的环境变量
load_dotenv()
# 创建一个大语言模型model指定了大语言模型的种类
model = ChatOpenAI(model="qwen2.5-14b-instruct")
# 想要接收的响应模式
response_schemas = [
ResponseSchema(name="brand_name", description="Brand name."),
ResponseSchema(name="brand_slogan", description="Brand slogan."),
ResponseSchema(name="brand_logo_prompt", description="prompt required for brand logo generation.")
]
output_parser = StructuredOutputParser.from_response_schemas(response_schemas)
format_instructions = output_parser.get_format_instructions()
prompt = PromptTemplate(
template="你是一个时装品牌的设计师。根据用户输入提取出brand namebrand slogan,brand logo 描述。如果没有以上内容需要你根据用户输入随意发挥。随后根据brand logo 描述生成一个prompt这个prompt用于生成模型.\n{format_instructions}\n{question}",
input_variables=["question"],
partial_variables={"format_instructions": format_instructions}
)
_input = prompt.format_prompt(question="brand name: cat home")
output = model(_input.to_messages())
brand_data = output_parser.parse(output.content)
def generate_logo(bucket_name, object_name, prompt):
pass

View File

@@ -3,27 +3,20 @@ import json
import logging
from typing import Any, Dict, List, Optional, Union, Tuple
from langchain.agents import AgentExecutor
from langchain.callbacks.manager import Callbacks, CallbackManager
from langchain.load.dump import dumpd
from langchain.schema import RUN_KEY, RunInfo
from langchain_classic.agents import AgentExecutor
from langchain_classic.schema import RUN_KEY
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.callbacks import Callbacks, CallbackManager
from langchain_core.load import dumpd
from langchain_core.outputs import RunInfo
class CustomAgentExecutor(AgentExecutor):
def __call__(
self,
inputs: Union[Dict[str, Any], Any],
return_only_outputs: bool = False,
callbacks: Callbacks = None,
session_key: str = "",
*,
tags: Optional[List[str]] = None,
include_run_info: bool = False,
) -> Dict[str, Any]:
def __call__(self, inputs: Union[Dict[str, Any], Any], return_only_outputs: bool = False, callbacks: Callbacks = None, session_key: str = "", *, tags: Optional[List[str]] = None, include_run_info: bool = False, **kwargs) -> Dict[str, Any]:
"""Run the logic of this chain and add to output if desired.
Args:
**kwargs:
inputs: Dictionary of inputs, or single input if chain expects
only one param.
return_only_outputs: boolean for whether to return only outputs in the
@@ -72,7 +65,7 @@ class CustomAgentExecutor(AgentExecutor):
"""Validate and prep outputs."""
self._validate_outputs(outputs)
if self.memory is not None and outputs['need_record']:
self.memory.save_context(inputs, outputs, session_key)
self.memory.save_context(inputs, outputs)
if return_only_outputs:
return outputs
else:
@@ -95,7 +88,7 @@ class CustomAgentExecutor(AgentExecutor):
)
inputs = {list(_input_keys)[0]: inputs}
if self.memory is not None:
external_context = self.memory.load_memory_variables(inputs, session_key)
external_context = self.memory.load_memory_variables(inputs)
inputs = dict(inputs, **external_context)
self._validate_inputs(inputs)
return inputs
@@ -119,7 +112,8 @@ class CustomAgentExecutor(AgentExecutor):
{return_value_key: observation},
"",
)
except:
except Exception as e:
print(e)
pass
# Invalid tools won't be in the map, so we return False.

View File

@@ -1,26 +1,15 @@
import json
import re
from dataclasses import dataclass
from json import JSONDecodeError
from typing import List, Tuple, Any, Union
from dataclasses import dataclass
from langchain.callbacks.manager import Callbacks
from langchain.agents import (
OpenAIFunctionsAgent,
)
from langchain.schema import (
AgentAction,
AgentFinish,
BaseMessage,
OutputParserException
)
from langchain.schema.messages import (
AIMessage,
FunctionMessage
)
from langchain.tools import BaseTool, StructuredTool
# from langchain.tools.convert_to_openai import FunctionDescription
from langchain.utils.openai_functions import FunctionDescription
from langchain_classic.agents import OpenAIFunctionsAgent
from langchain_community.utils.ernie_functions import FunctionDescription
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.callbacks import Callbacks
from langchain_core.exceptions import OutputParserException
from langchain_core.messages import BaseMessage, AIMessage, FunctionMessage
from langchain_core.tools import BaseTool
@dataclass
@@ -76,7 +65,6 @@ def _create_function_message(
content = observation
return FunctionMessage(
name=agent_action.tool,
content=content,
)
@@ -177,6 +165,7 @@ class ConversationalFunctionsAgent(OpenAIFunctionsAgent):
into it.
Args:
callbacks:
intermediate_steps: Steps the LLM has taken to date, along with observations
**kwargs: User inputs.
**kwargs: Including user's input string

View File

@@ -2,18 +2,16 @@
from typing import Any, Dict
from langchain_community.callbacks.openai_info import OpenAICallbackHandler
from langchain.schema import LLMResult
from langchain_community.callbacks.openai_info import standardize_model_name, MODEL_COST_PER_1K_TOKENS, \
get_openai_token_cost_for_model
# from langchain.callbacks.openai_info import standardize_model_name, MODEL_COST_PER_1K_TOKENS, get_openai_token_cost_for_model
from langchain_core.outputs import LLMResult
class OpenAITokenRecordCallbackHandler(OpenAICallbackHandler):
need_record: bool = True
response_type: str = "string"
"""Callback Handler that tracks OpenAI info and write to redis after agent finish"""
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Collect token usage."""
if response.llm_output is None:
@@ -22,7 +20,7 @@ class OpenAITokenRecordCallbackHandler(OpenAICallbackHandler):
if "token_usage" not in response.llm_output:
return None
if "function_call" in response.generations[0][0].message.additional_kwargs:
if response.generations[0][0].message.additional_kwargs["function_call"]["name"] in ["sql_db_query", "sql_db_schema","tutorial_tool"]:
if response.generations[0][0].message.additional_kwargs["function_call"]["name"] in ["sql_db_query", "sql_db_schema", "tutorial_tool"]:
self.need_record = False
if response.generations[0][0].message.additional_kwargs["function_call"]["name"] == "sql_db_query":
self.response_type = "image"
@@ -39,6 +37,7 @@ class OpenAITokenRecordCallbackHandler(OpenAICallbackHandler):
self.total_tokens += token_usage.get("total_tokens", 0)
self.prompt_tokens += prompt_tokens
self.completion_tokens += completion_tokens
return None
def on_chain_end(self, outputs: Dict, **kwargs: Any) -> None:
"""Write token usage to redis."""

View File

@@ -44,12 +44,17 @@ class CustomDatabase(SQLDatabase):
final_str = "\n\n".join(tables)
return final_str
def run(self, command: str, fetch: str = "all") -> str:
def run(self, command: str, fetch: str = "all", **kwargs) -> str:
"""Execute a SQL command and return a string representing the results.
If the statement returns rows, a string of the results is returned.
If the statement returns no rows, an empty string is returned.
Args:
command:
fetch:
**kwargs:
"""
with self._engine.begin() as connection:
if self._schema is not None:

View File

@@ -1,15 +1,15 @@
import json
import logging
from langchain.agents import Tool
from langchain.callbacks import FileCallbackHandler
from langchain.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate, MessagesPlaceholder
from langchain.schema import SystemMessage, AIMessage
from langchain.utilities import SerpAPIWrapper
from langchain_community.utilities import SerpAPIWrapper
from langchain_core.callbacks import FileCallbackHandler
from langchain_core.messages import SystemMessage, AIMessage
from langchain_core.prompts import MessagesPlaceholder, HumanMessagePromptTemplate, ChatPromptTemplate
from langchain_core.tools import Tool
from langchain_community.chat_models import ChatTongyi
from loguru import logger
from app.core.config import *
from app.core.config import settings
from app.service.chat_robot.script.agents import CustomAgentExecutor, ConversationalFunctionsAgent
from app.service.chat_robot.script.database import CustomDatabase
from app.service.chat_robot.script.memory import UserConversationBufferWindowMemory
@@ -30,10 +30,10 @@ log_handler = FileCallbackHandler(logfile)
# # callbacks=[OpenAICallbackHandler()]
# )
llm = ChatTongyi(api_key=QWEN_API_KEY)
llm = ChatTongyi(api_key=settings.QWEN_API_KEY)
search = SerpAPIWrapper()
db = CustomDatabase.from_uri(f'mysql+pymysql://{DB_USERNAME}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/attribute_retrieval_V3',
db = CustomDatabase.from_uri(f'mysql+pymysql://{settings.DB_USERNAME}:{settings.DB_PASSWORD}@{settings.DB_HOST}:{settings.DB_PORT}/attribute_retrieval_V3',
include_tables=['female_top', 'female_skirt', 'female_pants', 'female_dress',
'female_outwear', 'male_bottom', 'male_top', 'male_outwear'],
engine_args={"pool_recycle": 7200})
@@ -43,11 +43,11 @@ tools = [
description="Can be used to perform Internet searches",
func=search.run
),
QuerySQLDataBaseTool(db=db, return_direct=False),
QuerySQLDataBaseTool(db=db),
InfoSQLDatabaseTool(db=db),
ListSQLDatabaseTool(db=db),
# QuerySQLCheckerTool(db=db, llm=OpenAI(temperature=0, openai_api_key=OPENAI_API_KEY)),
QuerySQLCheckerTool(db=db, llm=ChatTongyi(temperature=0, api_key=QWEN_API_KEY)),
QuerySQLCheckerTool(db=db, llm=ChatTongyi(api_key=settings.QWEN_API_KEY)),
# Tool(
# name="tutorial_tool",
# description="Utilize this tool to retrieve specific statements related to user guidance tutorials."
@@ -133,5 +133,5 @@ def chat(post_data):
'completion_tokens': final_outputs['completion_tokens'],
'response_type': final_outputs["response_type"]
}
logging.info(json.dumps(api_response))
logging.info(json.dumps(api_response, indent=4))
return api_response

View File

@@ -3,13 +3,12 @@ from typing import Any, Dict, List, Tuple
import json
import redis
from langchain_classic.memory.chat_memory import BaseChatMemory
from langchain_classic.memory.utils import get_prompt_input_key
from langchain_core.messages import messages_from_dict, get_buffer_string, BaseMessage, HumanMessage, AIMessage, message_to_dict
from redis import Redis
from langchain.memory.chat_memory import BaseChatMemory
from langchain.schema.messages import BaseMessage, get_buffer_string, HumanMessage, AIMessage
from langchain.schema.messages import _message_to_dict, messages_from_dict
from langchain.memory.utils import get_prompt_input_key
from app.core.config import *
from app.core.config import settings
class UserConversationBufferWindowMemory(BaseChatMemory):
@@ -24,8 +23,8 @@ class UserConversationBufferWindowMemory(BaseChatMemory):
@classmethod
def from_redis(
cls,
host: str = REDIS_HOST,
port: int = REDIS_PORT,
host: str = settings.REDIS_HOST,
port: int = settings.REDIS_PORT,
db: int = 3,
**kwargs
):
@@ -79,7 +78,7 @@ class UserConversationBufferWindowMemory(BaseChatMemory):
return inputs[prompt_input_key], outputs[output_key]
def add_message(self, key: str, message: BaseMessage) -> None:
self.redis_client.lpush(key, json.dumps(_message_to_dict(message)))
self.redis_client.lpush(key, json.dumps(message_to_dict(message)))
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str], key: str = "") -> None:
"""Save context from this conversation to buffer."""

View File

@@ -5,10 +5,10 @@ from dashscope import Generation
from retry import retry
from urllib3.exceptions import NewConnectionError
from app.core.config import *
from app.core.config import settings
from app.service.chat_robot.script.callbacks.qwen_callback_handler import QWenCallbackHandler
from app.service.chat_robot.script.database import CustomDatabase
from app.service.chat_robot.script.prompt import FASHION_CHAT_BOT_PREFIX, TOOLS_FUNCTIONS_SUFFIX, TUTORIAL_TOOL_RETURN, \
from app.service.chat_robot.script.prompt import TOOLS_FUNCTIONS_SUFFIX, TUTORIAL_TOOL_RETURN, \
GET_LANGUAGE_PREFIX, FASHION_CHAT_BOT_PREFIX_TEMP
from app.service.search_image_with_text.service import query
@@ -149,7 +149,7 @@ tools = [
}
]
db = CustomDatabase.from_uri(f'mysql+pymysql://{DB_USERNAME}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/attribute_retrieval_V3',
db = CustomDatabase.from_uri(f'mysql+pymysql://{settings.MYSQL_USER}:{settings.MYSQL_PASSWORD}@{settings.MYSQL_HOST}:{settings.MYSQL_PORT}/attribute_retrieval_V3',
include_tables=['female_top', 'female_skirt', 'female_pants', 'female_dress',
'female_outwear', 'male_bottom', 'male_top', 'male_outwear'],
engine_args={"pool_recycle": 7200})
@@ -159,7 +159,7 @@ qwen = QWenCallbackHandler()
def search_from_internet(message):
response = Generation.call(
model='qwen-turbo',
api_key=QWEN_API_KEY,
api_key=settings.QWEN_API_KEY,
messages=message,
prompt='The output must be in English.Keep the final result under 200 words.'
# tools=tools,
@@ -190,7 +190,7 @@ def get_image_from_vector_db(gender, content):
def get_response(messages):
response = Generation.call(
model='qwen-max',
api_key=QWEN_API_KEY,
api_key=settings.QWEN_API_KEY,
messages=messages,
tools=tools,
# seed=random.randint(1, 10000), # 设置随机数种子seed如果没有设置则随机数种子默认为1234
@@ -203,7 +203,7 @@ def get_response(messages):
def get_assistant_response(messages):
response = Generation.call(
model='qwen-max',
api_key=QWEN_API_KEY,
api_key=settings.QWEN_API_KEY,
messages=messages,
# seed=random.randint(1, 10000), # 设置随机数种子seed如果没有设置则随机数种子默认为1234
result_format='message', # 将输出设置为message形式
@@ -212,8 +212,10 @@ def get_assistant_response(messages):
return response
global tool_info
def call_with_messages(message):
global tool_info
user_input = message
print('\n')
@@ -241,7 +243,7 @@ def call_with_messages(message):
response_type = "chat"
while flag and count <= 3:
first_response = get_response(messages)
first_response = get_response
assistant_output = first_response.output.choices[0].message
QWenCallbackHandler.on_llm_end(qwen, first_response.usage)
print(f"\n大模型第 {count} 轮输出信息:{first_response}\n")
@@ -260,7 +262,7 @@ def call_with_messages(message):
]
tool_info['content'] = search_from_internet(message)
flag = False
result_content = tool_info['content'].output.text
result_content = tool_info['content']
# 如果模型选择的工具是get_database_table
# elif assistant_output.tool_calls[0]['function']['name'] == 'get_database_table':
# tool_info = {"name": "get_database_table", "role": "tool", 'content': get_database_table()}

View File

@@ -2,21 +2,15 @@
"""Tools for interacting with a SQL database."""
from typing import Any, Dict, Optional, Type
from pydantic import BaseModel, Extra, Field, root_validator
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import (
AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun,
)
from langchain.chains.llm import LLMChain
from langchain.prompts import PromptTemplate
from langchain_community.tools.sql_database.prompt import QUERY_CHECKER
from langchain_community.tools.sql_database.tool import _QuerySQLCheckerToolInput
# from langchain.sql_database import SQLDatabase
from langchain_community.utilities import SQLDatabase
from langchain.tools.base import BaseTool
from langchain_community.tools.sql_database.prompt import QUERY_CHECKER
from langchain_community.tools.sql_database.tool import QuerySQLCheckerTool, _QuerySQLCheckerToolInput
from langchain_core.callbacks import CallbackManagerForToolRun, AsyncCallbackManagerForToolRun
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import PromptTemplate
from langchain_core.tools import BaseTool
from pydantic import BaseModel, Extra, Field, root_validator
class BaseSQLDatabaseTool(BaseModel):
@@ -62,7 +56,7 @@ class QuerySQLDataBaseTool(BaseSQLDatabaseTool, BaseTool):
"LIMIT 1'"
"Example Input 2: 'SELECT img_name FROM top WHERE sleeve_length = 'Long' AND type = 'Blouse' "
"order by rand() LIMIT 2'"
)
)
def _run(
self,
@@ -95,9 +89,9 @@ class InfoSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool):
"Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables."
"There are eight tables covering eight fashion categories: female_top, female_pants, female_dress,"
"female_skirt, female_outwear, male_bottom, male_top, and male_outwear."
"Example Input: 'female_outwear, male_top'"
)
)
def _run(
self,
@@ -183,11 +177,11 @@ class QuerySQLCheckerTool(BaseSQLDatabaseTool, BaseTool):
args_schema: Type[BaseModel] = _QuerySQLCheckerToolInput
@root_validator(pre=True)
def initialize_llm_chain(cls, values: Dict[str, Any]) -> Dict[str, Any]:
def initialize_llm_chain(self, values: Dict[str, Any]) -> Dict[str, Any]:
if "llm_chain" not in values:
# from langchain.chains.llm import LLMChain
llm = values.get("llm") # type: ignore[arg-type]
llm = values.get("llm") # type: ignore[arg-type]
prompt = PromptTemplate(
template=QUERY_CHECKER, input_variables=["dialect", "query"]
)

View File

@@ -1,6 +1,6 @@
from typing import Any
from langchain.tools.base import BaseTool
from langchain_core.tools import BaseTool
from app.service.chat_robot.script.prompt import TUTORIAL_TOOL_RETURN

View File

@@ -9,14 +9,14 @@ from PIL import Image
from minio import Minio
from tritonclient.utils import np_to_triton_dtype
from app.core.config import *
from app.core.config import settings
from app.schemas.clothing_seg import ClothingSegModel
from app.service.design_fast.utils.design_ensemble import get_seg_result
from app.service.utils.decorator import RunTime
from app.service.utils.generate_uuid import generate_uuid
from app.service.utils.new_oss_client import oss_get_image, oss_upload_image
minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)
class ClothingSeg:
@@ -64,9 +64,9 @@ class ClothingSeg:
if image_type == "sketch":
if len(image.shape) == 2:
image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
seg_mask = get_seg_result(1, image[:, :, :3])
seg_mask = get_seg_result(image[:, :, :3])
else:
seg_mask = get_seg_result(1, image[:, :, :3])
seg_mask = get_seg_result(image[:, :, :3])
temp = seg_mask != 0.0
mask = (255 * (temp + 0).astype(np.uint8))
x_min, y_min, x_max, y_max = get_bounding_box(mask)

View File

@@ -12,7 +12,8 @@ from PIL import Image
from minio import Minio, S3Error
from moviepy.video.io.VideoFileClip import VideoFileClip
from app.core.config import MINIO_URL, MINIO_ACCESS, MINIO_SECRET, MINIO_SECURE, COMFYUI_SERVER_ADDRESS, PS_RABBITMQ_QUEUES, DEBUG
from app.core.config import PS_RABBITMQ_QUEUES
from app.core.config import settings
from app.schemas.comfyui_i2v import ComfyuiFLF2VModel
from app.service.generate_image.utils.mq import publish_status
@@ -305,13 +306,14 @@ workflow_json = {
class ComfyUIServerFLF2V:
def __init__(self, request_data):
self.pose_transform_data = None
self.start_image_url = request_data.start_image_url
self.end_image_url = request_data.end_image_url
self.prompt = request_data.prompt
self.tasks_id = request_data.tasks_id
self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
self.server_status_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'gif_url': '', 'video_url': '', 'image_url': ''}
self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
self.minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)
def get_result(self):
workflow_json['6']['inputs']['text'] = self.prompt
@@ -341,7 +343,7 @@ class ComfyUIServerFLF2V:
# 1. 提交任务
prompt_response = self.queue_prompt(workflow_json, self.tasks_id)
if not prompt_response:
return
return None
prompt_id = prompt_response.get("prompt_id")
logger.info(f" 任务已提交Prompt ID: {prompt_id}")
@@ -361,6 +363,7 @@ class ComfyUIServerFLF2V:
}
logger.info(file_list)
return self.process_and_upload_comfyui_video(filename=file_list['filename'], subfolder=file_list['subfolder'], prompt_id=prompt_response['prompt_id']), prompt_id
return None
def download_from_minio_in_memory(self, image_url):
bucket = image_url.split('/')[0]
@@ -391,8 +394,9 @@ class ComfyUIServerFLF2V:
logger.error(f"❌ MinIO 下载过程中发生未知错误: {e}")
return None, None
def upload_in_memory_file_to_comfyui(self, in_memory_file, filename):
upload_url = f"http://{COMFYUI_SERVER_ADDRESS}/upload/image"
@staticmethod
def upload_in_memory_file_to_comfyui(in_memory_file, filename):
upload_url = f"http://{settings.COMFYUI_SERVER_ADDRESS}/upload/image"
data = {
"overwrite": "true",
@@ -430,7 +434,7 @@ class ComfyUIServerFLF2V:
# 1. 从 ComfyUI 获取视频二进制数据
mp4_bytes = self.get_comfyui_video_bytes(filename, subfolder)
if not mp4_bytes:
return
return None
# 2. 准备进行视频处理
# moviepy 不支持直接使用 bytes需要将 bytes 写入一个 BytesIO 或临时文件
@@ -518,7 +522,7 @@ class ComfyUIServerFLF2V:
self.pose_transform_data = {'tasks_id': self.tasks_id, 'status': 'SUCCESS', 'message': "success", 'gif_url': f'aida-users/{GIF_OBJECT}', 'video_url': f'aida-users/{MP4_OBJECT}', 'image_url': f'aida-users/{FRAME_OBJECT}'}
# 推送消息
if not DEBUG:
if not settings.DEBUG:
publish_status(json.dumps(self.pose_transform_data), PS_RABBITMQ_QUEUES)
logger.info(
f" [x] Sent to {PS_RABBITMQ_QUEUES} data@@@@ {json.dumps(self.pose_transform_data, indent=4)}")
@@ -530,13 +534,14 @@ class ComfyUIServerFLF2V:
return None
# --- 辅助函数:提交任务到队列 ---
def queue_prompt(self, prompt, client_id):
@staticmethod
def queue_prompt(prompt, client_id):
"""向 ComfyUI 提交工作流提示。"""
p = {"prompt": prompt, "client_id": client_id, "prompt_id": client_id}
data = json.dumps(p).encode('utf-8')
# 提交任务到 /prompt 端点
response = requests.post(f"http://{COMFYUI_SERVER_ADDRESS}/prompt", data=data)
response = requests.post(f"http://{settings.COMFYUI_SERVER_ADDRESS}/prompt", data=data)
# print(f"-------------{response.text}")
# print(f"------------{client_id}")
@@ -547,9 +552,10 @@ class ComfyUIServerFLF2V:
logger.warning(response.text)
return None
def poll_history(self, prompt_id, interval_seconds=5):
@staticmethod
def poll_history(prompt_id, interval_seconds=5):
"""步骤 2: 轮询 /history/{prompt_id} 检查任务是否完成"""
url = f"http://{COMFYUI_SERVER_ADDRESS}/history/{prompt_id}"
url = f"http://{settings.COMFYUI_SERVER_ADDRESS}/history/{prompt_id}"
logger.info(f"⏳ 开始轮询状态 (间隔 {interval_seconds} 秒)...")
@@ -574,7 +580,8 @@ class ComfyUIServerFLF2V:
logger.info(f"⚠️ 轮询时发生错误: {e}")
pass
def get_comfyui_video_bytes(self, filename: str, subfolder: str, file_type: str = "output"):
@staticmethod
def get_comfyui_video_bytes(filename: str, subfolder: str, file_type: str = "output"):
"""
从 ComfyUI 的 /view 端点获取视频文件的二进制数据。
@@ -586,7 +593,7 @@ class ComfyUIServerFLF2V:
返回:
- 视频文件的二进制内容 (bytes) 或 None。
"""
url = f"http://{COMFYUI_SERVER_ADDRESS}/view"
url = f"http://{settings.COMFYUI_SERVER_ADDRESS}/view"
params = {
"filename": filename,
"subfolder": subfolder,

View File

@@ -12,8 +12,8 @@ from PIL import Image
from minio import Minio, S3Error
from moviepy.video.io.VideoFileClip import VideoFileClip
from app.core.config import MINIO_URL, MINIO_ACCESS, MINIO_SECRET, MINIO_SECURE, COMFYUI_SERVER_ADDRESS, PS_RABBITMQ_QUEUES, DEBUG
from app.schemas.comfyui_i2v import ComfyuiPose2VModel, ComfyuiI2VModel
from app.core.config import PS_RABBITMQ_QUEUES, settings
from app.schemas.comfyui_i2v import ComfyuiI2VModel
from app.service.generate_image.utils.mq import publish_status
logger = logging.getLogger()
@@ -293,13 +293,14 @@ workflow_json = {
class ComfyUIServerI2V:
def __init__(self, request_data):
self.pose_transform_data = None
self.image_url = request_data.image_url
self.prompt = request_data.prompt
self.tasks_id = request_data.tasks_id
self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
self.server_status_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'gif_url': '', 'video_url': '', 'image_url': ''}
self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
self.minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)
def get_result(self):
workflow_json['93']['inputs']['text'] = self.prompt
@@ -319,7 +320,7 @@ class ComfyUIServerI2V:
# 1. 提交任务
prompt_response = self.queue_prompt(workflow_json, self.tasks_id)
if not prompt_response:
return
return None
prompt_id = prompt_response.get("prompt_id")
logger.info(f" 任务已提交Prompt ID: {prompt_id}")
outputs = self.poll_history(prompt_id)
@@ -339,6 +340,7 @@ class ComfyUIServerI2V:
}
logger.info(file_list)
return self.process_and_upload_comfyui_video(filename=file_list['filename'], subfolder=file_list['subfolder'], prompt_id=prompt_response['prompt_id']), prompt_id
return None
def download_from_minio_in_memory(self, image_url):
bucket = image_url.split('/')[0]
@@ -369,8 +371,9 @@ class ComfyUIServerI2V:
logger.error(f"❌ MinIO 下载过程中发生未知错误: {e}")
return None, None
def upload_in_memory_file_to_comfyui(self, in_memory_file, filename):
upload_url = f"http://{COMFYUI_SERVER_ADDRESS}/upload/image"
@staticmethod
def upload_in_memory_file_to_comfyui(in_memory_file, filename):
upload_url = f"http://{settings.COMFYUI_SERVER_ADDRESS}/upload/image"
data = {
"overwrite": "true",
@@ -408,7 +411,7 @@ class ComfyUIServerI2V:
# 1. 从 ComfyUI 获取视频二进制数据
mp4_bytes = self.get_comfyui_video_bytes(filename, subfolder)
if not mp4_bytes:
return
return None
# 2. 准备进行视频处理
# moviepy 不支持直接使用 bytes需要将 bytes 写入一个 BytesIO 或临时文件
@@ -496,7 +499,7 @@ class ComfyUIServerI2V:
self.pose_transform_data = {'tasks_id': self.tasks_id, 'status': 'SUCCESS', 'message': "success", 'gif_url': f'aida-users/{GIF_OBJECT}', 'video_url': f'aida-users/{MP4_OBJECT}', 'image_url': f'aida-users/{FRAME_OBJECT}'}
# 推送消息
if not DEBUG:
if not settings.DEBUG:
publish_status(json.dumps(self.pose_transform_data), PS_RABBITMQ_QUEUES)
logger.info(
f" [x] Sent to {PS_RABBITMQ_QUEUES} data@@@@ {json.dumps(self.pose_transform_data, indent=4)}")
@@ -508,13 +511,14 @@ class ComfyUIServerI2V:
return None
# --- 辅助函数:提交任务到队列 ---
def queue_prompt(self, prompt, client_id):
@staticmethod
def queue_prompt(prompt, client_id):
"""向 ComfyUI 提交工作流提示。"""
p = {"prompt": prompt, "client_id": client_id, "prompt_id": client_id}
data = json.dumps(p).encode('utf-8')
# 提交任务到 /prompt 端点
response = requests.post(f"http://{COMFYUI_SERVER_ADDRESS}/prompt", data=data)
response = requests.post(f"http://{settings.COMFYUI_SERVER_ADDRESS}/prompt", data=data)
# print(f"-------------{response.text}")
# print(f"------------{client_id}")
@@ -525,9 +529,10 @@ class ComfyUIServerI2V:
logger.warning(response.text)
return None
def poll_history(self, prompt_id, interval_seconds=5):
@staticmethod
def poll_history(prompt_id, interval_seconds=5):
"""步骤 2: 轮询 /history/{prompt_id} 检查任务是否完成"""
url = f"http://{COMFYUI_SERVER_ADDRESS}/history/{prompt_id}"
url = f"http://{settings.COMFYUI_SERVER_ADDRESS}/history/{prompt_id}"
logger.info(f"⏳ 开始轮询状态 (间隔 {interval_seconds} 秒)...")
@@ -552,7 +557,8 @@ class ComfyUIServerI2V:
logger.info(f"⚠️ 轮询时发生错误: {e}")
pass
def get_comfyui_video_bytes(self, filename: str, subfolder: str, file_type: str = "output"):
@staticmethod
def get_comfyui_video_bytes(filename: str, subfolder: str, file_type: str = "output"):
"""
从 ComfyUI 的 /view 端点获取视频文件的二进制数据。
@@ -564,7 +570,7 @@ class ComfyUIServerI2V:
返回:
- 视频文件的二进制内容 (bytes) 或 None。
"""
url = f"http://{COMFYUI_SERVER_ADDRESS}/view"
url = f"http://{settings.COMFYUI_SERVER_ADDRESS}/view"
params = {
"filename": filename,
"subfolder": subfolder,

View File

@@ -13,7 +13,7 @@ from PIL import Image
from minio import Minio, S3Error
from moviepy.video.io.VideoFileClip import VideoFileClip
from app.core.config import REDIS_HOST, REDIS_PORT, REDIS_DB, MINIO_URL, MINIO_ACCESS, MINIO_SECRET, MINIO_SECURE, COMFYUI_SERVER_ADDRESS, PS_RABBITMQ_QUEUES, DEBUG
from app.core.config import settings
from app.schemas.comfyui_i2v import ComfyuiPose2VModel
from app.service.generate_image.utils.mq import publish_status
@@ -371,11 +371,11 @@ class ComfyUIServerPose2V:
self.pose_num = request_data.pose_id
self.tasks_id = request_data.tasks_id
self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
self.redis_client = redis.StrictRedis(host=settings.REDIS_HOST, port=settings.REDIS_PORT, db=settings.REDIS_DB, decode_responses=True)
self.pose_transform_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'gif_url': '', 'video_url': '', 'image_url': ''}
self.redis_client.set(self.tasks_id, json.dumps(self.pose_transform_data))
self.redis_client.expire(self.tasks_id, 600)
self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
self.minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)
def get_result(self):
workflow_json['174']['inputs']['file'] = video_map[self.pose_num]
@@ -389,7 +389,7 @@ class ComfyUIServerPose2V:
# 1. 提交任务
prompt_response = self.queue_prompt(workflow_json, self.tasks_id)
if not prompt_response:
return
return None
prompt_id = prompt_response.get("prompt_id")
logger.info(f" 任务已提交Prompt ID: {prompt_id}")
@@ -411,6 +411,7 @@ class ComfyUIServerPose2V:
}
logger.info(file_list)
return self.process_and_upload_comfyui_video(filename=file_list['filename'], subfolder=file_list['subfolder'], prompt_id=prompt_response['prompt_id']), prompt_id
return None
def read_tasks_status(self):
status_data = self.redis_client.get(self.tasks_id)
@@ -492,8 +493,9 @@ class ComfyUIServerPose2V:
except Exception as e:
logger.error(f"❌ 发生未知错误: {e}")
def upload_in_memory_file_to_comfyui(self, in_memory_file, filename):
upload_url = f"http://{COMFYUI_SERVER_ADDRESS}/upload/image"
@staticmethod
def upload_in_memory_file_to_comfyui(in_memory_file, filename):
upload_url = f"http://{settings.COMFYUI_SERVER_ADDRESS}/upload/image"
data = {
"overwrite": "true",
@@ -531,7 +533,7 @@ class ComfyUIServerPose2V:
# 1. 从 ComfyUI 获取视频二进制数据
mp4_bytes = self.get_comfyui_video_bytes(filename, subfolder)
if not mp4_bytes:
return
return None
# 2. 准备进行视频处理
# moviepy 不支持直接使用 bytes需要将 bytes 写入一个 BytesIO 或临时文件
@@ -619,10 +621,10 @@ class ComfyUIServerPose2V:
self.pose_transform_data = {'tasks_id': self.tasks_id, 'status': 'SUCCESS', 'message': "success", 'gif_url': f'aida-users/{GIF_OBJECT}', 'video_url': f'aida-users/{MP4_OBJECT}', 'image_url': f'aida-users/{FRAME_OBJECT}'}
# 推送消息
if not DEBUG:
publish_status(json.dumps(self.pose_transform_data), PS_RABBITMQ_QUEUES)
if not settings.DEBUG:
publish_status(json.dumps(self.pose_transform_data), settings.COMFYUI_SERVER_ADDRESS)
logger.info(
f" [x] Sent to {PS_RABBITMQ_QUEUES} data@@@@ {json.dumps(self.pose_transform_data, indent=4)}")
f" [x] Sent to {settings.COMFYUI_SERVER_ADDRESS} data@@@@ {json.dumps(self.pose_transform_data, indent=4)}")
return "\n🎉 所有任务完成!"
@@ -631,13 +633,15 @@ class ComfyUIServerPose2V:
return None
# --- 辅助函数:提交任务到队列 ---
def queue_prompt(self, prompt, client_id):
@staticmethod
def queue_prompt(prompt, client_id):
"""向 ComfyUI 提交工作流提示。"""
p = {"prompt": prompt, "client_id": client_id, "prompt_id": client_id}
data = json.dumps(p).encode('utf-8')
# 提交任务到 /prompt 端点
response = requests.post(f"http://{COMFYUI_SERVER_ADDRESS}/prompt", data=data)
# noinspection HttpUrlsUsage
response = requests.post(f"http://{settings.COMFYUI_SERVER_ADDRESS}/prompt", data=data)
# print(f"-------------{response.text}")
# print(f"------------{client_id}")
@@ -648,9 +652,10 @@ class ComfyUIServerPose2V:
logger.warning(response.text)
return None
def poll_history(self, prompt_id, interval_seconds=5):
@staticmethod
def poll_history(prompt_id, interval_seconds=5):
"""步骤 2: 轮询 /history/{prompt_id} 检查任务是否完成"""
url = f"http://{COMFYUI_SERVER_ADDRESS}/history/{prompt_id}"
url = f"http://{settings.COMFYUI_SERVER_ADDRESS}/history/{prompt_id}"
logger.info(f"⏳ 开始轮询状态 (间隔 {interval_seconds} 秒)...")
@@ -675,7 +680,8 @@ class ComfyUIServerPose2V:
logger.info(f"⚠️ 轮询时发生错误: {e}")
pass
def get_comfyui_video_bytes(self, filename: str, subfolder: str, file_type: str = "output"):
@staticmethod
def get_comfyui_video_bytes(filename: str, subfolder: str, file_type: str = "output"):
"""
从 ComfyUI 的 /view 端点获取视频文件的二进制数据。
@@ -687,7 +693,7 @@ class ComfyUIServerPose2V:
返回:
- 视频文件的二进制内容 (bytes) 或 None。
"""
url = f"http://{COMFYUI_SERVER_ADDRESS}/view"
url = f"http://{settings.COMFYUI_SERVER_ADDRESS}/view"
params = {
"filename": filename,
"subfolder": subfolder,

View File

@@ -1,116 +0,0 @@
import logging
import numpy as np
import cv2
from matplotlib import pyplot as plt
from PIL import Image
def show(img, win_name="temp"):
cv2.imshow(win_name, img)
cv2.waitKey(0)
def crop(img):
mid_point_h, mid_point_w = int(img.shape[0] / 2 + 30), int(img.shape[1] / 2)
img_roi = img[mid_point_h - 520: mid_point_h + 520, mid_point_w - 340: mid_point_w + 340]
return img_roi
class Layer(object):
def __init__(self):
self._layer = []
@property
def layer(self):
return self._layer
def insert(self, layer_instance):
if layer_instance['name'] == 'body':
self._body = layer_instance
self._layer.append(layer_instance)
def sort(self, priority):
self._layer.sort(key=lambda x: priority[x['name']])
# def merge(self, cfg):
# """
# opencv shape order (height, width, channel)
# image coordinate system:
# |------------->x (width)
# |
# |
# |
# y (height)
# Returns:
#
#
# """
# base_image = Image.new('RGBA', self._layer[1]['image'].size, (0, 0, 0, 0))
# for layer in self._layer:
# y, x = layer['position']
# base_image.paste(layer['image'], (x, y), layer['image'])
# # base_image.show()
#
# for x in self._layer:
# if np.all(x['mask'] == 0):
# continue
# # obtain region of interest about roi(roi) and item-image(roi_image, roi_mask)
# roi, roi_mask, roi_image, signal = self.get_roi(dst=dst, image=x)
# temp_bg = np.expand_dims(cv2.bitwise_not(roi_mask), axis=2).repeat(3, axis=2)
# tmp1 = (roi * (temp_bg / 255)).astype(np.uint8)
# temp_fg = np.expand_dims(roi_mask, axis=2).repeat(3, axis=2)
# tmp2 = (roi_image * (temp_fg / 255)).astype(np.uint8)
#
# roi[:] = cv2.add(tmp1, tmp2)
# # show(cv2.resize(dst, (int(dst.shape[1] * 0.5), int(dst.shape[0] * 0.5)), interpolation=cv2.INTER_AREA),
# # win_name=x.get('name'))
# # crop image and get the central part
# if cfg.get('basic')['self_template'] == False:
# dst_roi = crop(dst)
# else:
# dst_roi = dst
# return dst_roi, signal
#
# @staticmethod
# def get_roi(dst, image):
# signal = False
# dst_y, dst_x = dst.shape[:2]
# roi_height, roi_width = image['mask'].shape
# roi_y0, roi_x0 = image['position']
#
# if roi_y0 < 0:
# roi_yin = 0
# mask_yin = -roi_y0
# signal = True
# else:
# roi_yin = roi_y0
# mask_yin = 0
# if roi_y0 + roi_height > dst_y:
# roi_yout = dst_y
# mask_yout = dst_y - roi_y0
# signal = True
# else:
# roi_yout = roi_height + roi_y0
# mask_yout = roi_height
# # x part
# if roi_x0 < 0:
# roi_xin = 0
# mask_xin = -roi_x0
# signal = True
# else:
# roi_xin = roi_x0
# mask_xin = 0
# if roi_x0 + roi_width > dst_x:
# roi_xout = dst_x
# mask_xout = dst_x - roi_x0
# signal = True
# else:
# roi_xout = roi_width + roi_x0
# mask_xout = roi_width
#
# roi = dst[roi_yin: roi_yout, roi_xin: roi_xout]
# roi_mask = image['mask'][mask_yin: mask_yout, mask_xin: mask_xout]
# roi_image = image['image'][mask_yin: mask_yout, mask_xin: mask_xout]
# return roi, roi_mask, roi_image, signal

View File

@@ -1,45 +0,0 @@
class Priority(object):
"""Item layer priority levels.
"""
def __init__(self, item_list):
self._priority = dict(
earring_front=99,
bag_front=98,
hairstyle_front=97,
outwear_front=20,
bottoms_front=19,
dress_front=18,
blouse_front=17,
skirt_front=16,
trousers_front=15,
tops_front=14,
shoes_right=1,
shoes_left=1,
body=0,
tops_back=-14,
trousers_back=-15,
skirt_back=-16,
blouse_back=-17,
dress_back=-18,
bottoms_back=-19,
outwear_back=-20,
hairstyle_back=-97,
bag_back=-98,
earring_back=-99,
)
self.clothing_start_num = 10
if not isinstance(item_list, list):
raise ValueError('item_list must be a list!')
for cate in item_list:
cate = cate.lower()
if cate not in ('outwear', 'dress', 'blouse', 'skirt', 'trousers', 'tops', 'bottoms'):
raise ValueError(f'Item type error. Cannot recognize {cate}')
for i, cate in enumerate(item_list):
cate = cate.lower()
self._priority[f'{cate}_front'] = self.clothing_start_num - i
self._priority[f'{cate}_back'] = -(self.clothing_start_num - i)
@property
def priority(self):
return self._priority

View File

@@ -1,16 +0,0 @@
from .builder import ITEMS, build_item
from .clothing import Clothing # 4.0 sec
from .body import Body
from .top import Top, Blouse, Outwear, Dress
from .bottom import Bottom, Trousers, Skirt
from .shoes import Shoes
from .bag import Bag
from .others import Hairstyle, Earring
__all__ = [
'ITEMS', 'build_item',
'Clothing', 'Body',
'Top', 'Blouse', 'Outwear', 'Dress',
'Bottom', 'Trousers', 'Skirt',
'Shoes', 'Bag', 'Hairstyle', 'Earring'
]

View File

@@ -1,45 +0,0 @@
import random
from .builder import ITEMS
from .clothing import Clothing
@ITEMS.register_module()
class Bag(Clothing):
def __init__(self, **kwargs):
pipeline = [
dict(type='LoadImageFromFile', path=kwargs['path'], color=kwargs['color']),
dict(type='KeypointDetection'),
dict(type='ContourDetection'),
dict(type='Painting'),
dict(type='Scaling'),
dict(type='Split'),
# dict(type='ImageShow', key=['image', 'mask', 'pattern_image']),
]
kwargs.update(pipeline=pipeline)
super(Bag, self).__init__(**kwargs)
@staticmethod
def calculate_start_point(keypoint_type, scale, clothes_point, body_point):
"""
align left
Args:
keypoint_type: string, "hand_point"
scale: float
clothes_point: dict{'left': [x1, y1, z1], 'right': [x2, y2, z2]}
body_point: dict, containing keypoint data of body figure
Returns:
start_point: tuple (y', x')
x' = y_body - y1 * scale
y' = x_body - x1 * scale
"""
location = random.choice(seq=['left', 'right'])
if location == 'left':
side_indicator = f'{keypoint_type}_left'
else:
side_indicator = f'{keypoint_type}_right'
# clothes_point = {k: tuple(map(lambda x: int(scale * x), v[0: 2])) for k, v in clothes_point.items()}
start_point = (body_point[side_indicator][1] - int(int(clothes_point[keypoint_type].split("_")[1]) * scale),
body_point[side_indicator][0] - int(int(clothes_point[keypoint_type].split("_")[0]) * scale))
return start_point

View File

@@ -1,36 +0,0 @@
import cv2
from .builder import ITEMS
from .pipelines import Compose
@ITEMS.register_module()
class Body(object):
def __init__(self, **kwargs):
pipeline = [
dict(type='LoadBodyImageFromFile', body_path=kwargs['body_path']),
# dict(type='ImageShow', key=['body_image', "body_mask"])
]
self.pipeline = Compose(pipeline)
self.result = dict()
def process(self):
self.pipeline(self.result)
pass
def organize(self, layer):
body_layer = dict(priority=0,
name=type(self).__name__.lower(),
image=self.result['body_image'],
image_url=self.result['image_url'],
mask_image=None,
mask_url=None,
sacle=1,
# mask=self.result['body_mask'],
position=(0, 0))
layer.insert(body_layer)
@staticmethod
def show(img):
cv2.imshow('', img)
cv2.waitKey(0)

View File

@@ -1,39 +0,0 @@
from .builder import ITEMS
from .clothing import Clothing
@ITEMS.register_module()
class Bottom(Clothing):
def __init__(self, pipeline, **kwargs):
if pipeline is None:
pipeline = [
dict(type='LoadImageFromFile', path=kwargs['path'], color=kwargs['color'], print_dict=kwargs['print']),
dict(type='KeypointDetection'),
dict(type='ContourDetection'),
# dict(type='Segmentation'),
dict(type='Painting', painting_flag=True),
dict(type='PrintPainting', print_flag=True),
dict(type='Scaling'),
dict(type='Split'),
# dict(type='ImageShow', key=['image', 'mask', 'pattern_image', 'print_image']),
]
kwargs.update(pipeline=pipeline)
super(Bottom, self).__init__(**kwargs)
@ITEMS.register_module()
class Trousers(Bottom):
def __init__(self, pipeline=None, **kwargs):
super(Trousers, self).__init__(pipeline, **kwargs)
@ITEMS.register_module()
class Skirt(Bottom):
def __init__(self, pipeline=None, **kwargs):
super(Skirt, self).__init__(pipeline, **kwargs)
@ITEMS.register_module()
class Bottoms(Bottom):
def __init__(self, pipeline=None, **kwargs):
super(Bottoms, self).__init__(pipeline, **kwargs)

View File

@@ -1,9 +0,0 @@
from mmcv.utils import Registry, build_from_cfg
ITEMS = Registry('item')
PIPELINES = Registry('pipeline')
def build_item(cfg, default_args=None):
item = build_from_cfg(cfg, ITEMS, default_args)
return item

View File

@@ -1,100 +0,0 @@
import cv2
from app.core.config import PRIORITY_DICT
from .builder import ITEMS
from .pipelines import Compose
@ITEMS.register_module()
class Clothing(object):
def __init__(self, pipeline, **kwargs):
self.pipeline = Compose(pipeline)
self.result = dict(name=type(self).__name__.lower(), **kwargs)
def process(self):
self.pipeline(self.result)
def apply_scale(self, img):
scale = self.result['scale']
height, width = img.shape[0: 2]
if len(img.shape) > 2:
height, width = img.shape[0: 2]
scaled_img = cv2.resize(img, (int(width * scale), int(height * scale)), interpolation=cv2.INTER_AREA)
return scaled_img
def organize(self, layer):
start_point = self.calculate_start_point(self.result['keypoint'], self.result['scale'], self.result['clothes_keypoint'], self.result['body_point_test'], self.result["offset"], self.result["resize_scale"])
front_layer = dict(priority=self.result.get("priority", None) if self.result.get("layer_order", False) else PRIORITY_DICT.get(f'{type(self).__name__.lower()}_front', None),
name=f'{type(self).__name__.lower()}_front',
image=self.result["front_image"],
# mask_image=self.result['front_mask_image'],
image_url=self.result['front_image_url'],
mask_url=self.result['mask_url'],
sacle=self.result['scale'],
clothes_keypoint=self.result['clothes_keypoint'],
position=start_point,
resize_scale=self.result["resize_scale"],
mask=cv2.resize(self.result['mask'], self.result["front_image"].size),
gradient_string=self.result['gradient_string'] if 'gradient_string' in self.result.keys() else "",
pattern_image_url=self.result['pattern_image_url'],
pattern_image=self.result['pattern_image']
)
layer.insert(front_layer)
back_layer = dict(priority=-self.result.get("priority", 0) if self.result.get("layer_order", False) else PRIORITY_DICT.get(f'{type(self).__name__.lower()}_back', None),
name=f'{type(self).__name__.lower()}_back',
image=self.result["back_image"],
# mask_image=self.result['back_mask_image'],
image_url=self.result['back_image_url'],
mask_url=self.result['mask_url'],
sacle=self.result['scale'],
clothes_keypoint=self.result['clothes_keypoint'],
position=start_point,
resize_scale=self.result["resize_scale"],
mask=cv2.resize(self.result['mask'], self.result["front_image"].size),
gradient_string=self.result['gradient_string'] if 'gradient_string' in self.result.keys() else "",
pattern_image_url=self.result['pattern_image_url'],
)
layer.insert(back_layer)
@staticmethod
def calculate_start_point(keypoint_type, scale, clothes_point, body_point, offset, resize_scale):
"""
Align left
Args:
keypoint_type: string, "waistband" | "shoulder" | "ear_point"
scale: float
clothes_point: dict{'left': [x1, y1, z1], 'right': [x2, y2, z2]}
body_point: dict, containing keypoint data of body figure
Returns:
start_point: tuple (x', y')
x' = y_body - y1 * scale + offset
y' = x_body - x1 * scale + offset
"""
side_indicator = f'{keypoint_type}_left'
# if keypoint_type == "ear_point":
# start_point = (body_point[side_indicator][1] - int(int(clothes_point[side_indicator].split("_")[1]) * scale),
# body_point[side_indicator][0] - int(int(clothes_point[side_indicator].split("_")[0]) * scale))
# else:
# start_point = (
# int(body_point[side_indicator][1] + offset[1] - int(clothes_point[side_indicator].split("_")[0]) * scale), # y
# int(body_point[side_indicator][0] + offset[0] - int(clothes_point[side_indicator].split("_")[1]) * scale) # x
# )
# milvus_DB_keypoint_cache:
start_point = (
int(body_point[side_indicator][1] + offset[1] - int(clothes_point[side_indicator][0]) * scale), # y
int(body_point[side_indicator][0] + offset[0] - int(clothes_point[side_indicator][1]) * scale) # x
)
# start_point = (
# int(body_point[side_indicator][1] + offset[1] - int(clothes_point[side_indicator].split("_")[0]) * scale), # y
# int(body_point[side_indicator][0] + offset[0] - int(clothes_point[side_indicator].split("_")[1]) * scale) # x
# )
return start_point

View File

@@ -1,59 +0,0 @@
from .builder import ITEMS
from .clothing import Clothing
@ITEMS.register_module()
class Hairstyle(Clothing):
def __init__(self, **kwargs):
pipeline = [
dict(type='LoadImageFromFile', path=kwargs['path']),
dict(type='KeypointDetection'),
dict(type='ContourDetection'),
dict(type='Painting'),
dict(type='Scaling'),
dict(type='Split'),
# dict(type='ImageShow', key=['image', 'mask', 'pattern_image']),
]
kwargs.update(pipeline=pipeline)
super(Hairstyle, self).__init__(**kwargs)
@staticmethod
def calculate_start_point(keypoint_type, scale, clothes_point, body_point):
"""
align up
Args:
keypoint_type: string, "head_point"
scale: float
clothes_point: dict{'left': [x1, y1, z1], 'right': [x2, y2, z2]}
body_point: dict, containing keypoint data of body figure
Returns:
start_point: tuple (x', y')
x' = y_body - y1 * scale
y' = x_body - x1 * scale
"""
side_indicator = f'{keypoint_type}_up'
# clothes_point = {k: tuple(map(lambda x: int(scale * x), v[0: 2])) for k, v in clothes_point.items()}
# logging.info(clothes_point[side_indicator])
start_point = (
int(body_point[side_indicator][1] - int(clothes_point[side_indicator].split("_")[1] * scale)),
int(body_point[side_indicator][0] - int(clothes_point[side_indicator].split("_")[0] * scale))
)
return start_point
@ITEMS.register_module()
class Earring(Clothing):
def __init__(self, **kwargs):
pipeline = [
dict(type='LoadImageFromFile', path=kwargs['path']),
dict(type='KeypointDetection'),
dict(type='ContourDetection'),
dict(type='Painting'),
dict(type='Scaling'),
dict(type='Split'),
# dict(type='ImageShow', key=['image', 'mask', 'pattern_image']),
]
kwargs.update(pipeline=pipeline)
super(Earring, self).__init__(**kwargs)

View File

@@ -1,19 +0,0 @@
from .compose import Compose
from .loading import LoadImageFromFile, LoadBodyImageFromFile, ImageShow
from .keypoints import KeypointDetection
from .segmentation import Segmentation
from .painting import Painting, PrintPainting
from .scale import Scaling
from .contour_detection import ContourDetection
from .split import Split
__all__ = [
'Compose',
'LoadImageFromFile', 'LoadBodyImageFromFile', 'ImageShow',
'KeypointDetection',
'Segmentation',
'Painting', 'PrintPainting',
'Scaling',
'ContourDetection',
'split',
]

View File

@@ -1,36 +0,0 @@
import collections
from mmcv.utils import build_from_cfg
from ..builder import PIPELINES
@PIPELINES.register_module()
class Compose(object):
def __init__(self, transforms):
assert isinstance(transforms, collections.abc.Sequence)
self.transforms = []
for transform in transforms:
if isinstance(transform, dict):
transform = build_from_cfg(transform, PIPELINES)
self.transforms.append(transform)
elif callable(transform):
self.transforms.append(transform)
else:
raise TypeError('transform must be callable or a dict')
def __call__(self, data):
"""Call function to apply transforms sequentially.
Args:
data (dict): A result dict contains the data to transform.
Returns:
dict: Transformed data.
"""
for t in self.transforms:
data = t(data)
if data is None:
return None
return data

View File

@@ -1,59 +0,0 @@
import cv2
import numpy as np
from ..builder import PIPELINES
@PIPELINES.register_module()
class ContourDetection(object):
def __init__(self):
# logging.info("ContourDetection run ")
pass
# @ RunTime
def __call__(self, result):
# shoe diff
if result['name'] == 'shoes':
Contour = self.get_contours(result['image'])
Mask = np.zeros(result['image'].shape[:2], np.uint8)
for i in range(2):
Max_contour = Contour[i]
Epsilon = 0.001 * cv2.arcLength(Max_contour, True)
Approx = cv2.approxPolyDP(Max_contour, Epsilon, True)
cv2.drawContours(Mask, [Approx], -1, 255, -1)
if result['pre_mask'] is None:
result['mask'] = Mask
else:
result['mask'] = cv2.bitwise_and(Mask, result['pre_mask'])
else:
Contour = self.get_contours(result['image'])
Mask = np.zeros(result['image'].shape[:2], np.uint8)
if len(Contour):
Max_contour = Contour[0]
Epsilon = 0.001 * cv2.arcLength(Max_contour, True)
Approx = cv2.approxPolyDP(Max_contour, Epsilon, True)
cv2.drawContours(Mask, [Approx], -1, 255, -1)
else:
Mask = np.ones(result['image'].shape[:2], np.uint8) * 255
# TODO 修复部分图片出现透明的情况 下版本上线
# img2gray = cv2.cvtColor(result['image'], cv2.COLOR_BGR2GRAY)
# ret, Mask = cv2.threshold(img2gray, 126, 255, cv2.THRESH_BINARY)
# Mask = cv2.bitwise_not(Mask)
if result['pre_mask'] is None:
result['mask'] = Mask
else:
result['mask'] = cv2.bitwise_and(Mask, result['pre_mask'])
result['front_mask'] = result['mask']
result['back_mask'] = result['mask']
return result
@staticmethod
def get_contours(image):
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
Edge = cv2.Canny(gray, 10, 150)
kernel = np.ones((5, 5), np.uint8)
Edge = cv2.dilate(Edge, kernel=kernel, iterations=1)
Edge = cv2.erode(Edge, kernel=kernel, iterations=1)
Contour, _ = cv2.findContours(Edge, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
Contour = sorted(Contour, key=cv2.contourArea, reverse=True)
return Contour

View File

@@ -1,140 +0,0 @@
import logging
import time
import numpy as np
from pymilvus import MilvusClient
from app.core.config import *
from app.service.utils.decorator import RunTime, ClassCallRunTime
from ..builder import PIPELINES
from ...utils.design_ensemble import get_keypoint_result
@PIPELINES.register_module()
class KeypointDetection(object):
"""
path here: abstract path
"""
# def __init__(self):
# self.client = MilvusClient(
# uri="http://10.1.1.240:19530",
# token="root:Milvus",
# db_name=MILVUS_ALIAS
# )
# def __del__(self):
# start_time = time.time()
# self.client.close()
# print(f"client close time : {time.time() - start_time}")
# @ClassCallRunTime
def __call__(self, result):
# logging.info("KeypointDetection run ")
if result['name'] in ['blouse', 'skirt', 'dress', 'outwear', 'trousers', 'tops', 'bottoms']: # 查询是否有数据 且类别相同 相同则直接读 不同则推理后更新
# result['clothes_keypoint'] = self.infer_keypoint_result(result)
site = 'up' if result['name'] in ['blouse', 'outwear', 'dress', 'tops'] else 'down'
# keypoint_cache = search_keypoint_cache(result["image_id"], site)
keypoint_cache = self.keypoint_cache(result, site)
# 取消向量查询 直接过模型推理
# keypoint_cache = False
if keypoint_cache is False:
keypoint_infer_result, site = self.infer_keypoint_result(result)
result['clothes_keypoint'] = self.save_keypoint_cache(result["image_id"], keypoint_infer_result, site)
else:
result['clothes_keypoint'] = keypoint_cache
return result
@staticmethod
def infer_keypoint_result(result):
site = 'up' if result['name'] in ['blouse', 'outwear', 'dress', 'tops'] else 'down'
start_time = time.time()
keypoint_infer_result = get_keypoint_result(result["image"], site) # 推理结果
# logging.info(f"infer keypoint time : {time.time() - start_time}")
return keypoint_infer_result, site
@staticmethod
# @ RunTime
def save_keypoint_cache(keypoint_id, cache, site):
if site == "down":
zeros = np.zeros(20, dtype=int)
result = np.concatenate([zeros, cache.flatten()])
else:
zeros = np.zeros(4, dtype=int)
result = np.concatenate([cache.flatten(), zeros])
# 取消向量保存 直接拿结果
data = [
{"keypoint_id": keypoint_id,
"keypoint_site": site,
"keypoint_vector": result.tolist()
}
]
try:
client = MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS)
# start_time = time.time()
res = client.upsert(collection_name=MILVUS_TABLE_KEYPOINT, data=data)
# logging.info(f"save keypoint time : {time.time() - start_time}")
client.close()
return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
except Exception as e:
logging.info(f"save keypoint cache milvus error : {e}")
return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
@staticmethod
def update_keypoint_cache(keypoint_id, infer_result, search_result, site):
if site == "up":
# 需要的是up 即推理出来的是up 那么查询的就是down
result = np.concatenate([infer_result.flatten(), search_result[-4:]])
else:
# 需要的是down 即推理出来的是down 那么查询的就是up
result = np.concatenate([search_result[:20], infer_result.flatten()])
data = [
{"keypoint_id": keypoint_id,
"keypoint_site": "all",
"keypoint_vector": result.tolist()
}
]
try:
client = MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS)
# connections.connect(alias=MILVUS_ALIAS, host=MILVUS_DB_HOST, port=MILVUS_PORT)
start_time = time.time()
# collection = Collection(MILVUS_TABLE_KEYPOINT) # Get an existing collection.
# mr = collection.upsert(data)
client.upsert(
collection_name=MILVUS_TABLE_KEYPOINT,
data=data
)
# logging.info(f"save keypoint time : {time.time() - start_time}")
return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
except Exception as e:
logging.info(f"save keypoint cache milvus error : {e}")
return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
# @ RunTime
def keypoint_cache(self, result, site):
try:
client = MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS)
keypoint_id = result['image_id']
res = client.query(
collection_name=MILVUS_TABLE_KEYPOINT,
# ids=[keypoint_id],
filter=f"keypoint_id == {keypoint_id}",
output_fields=['keypoint_vector', 'keypoint_site']
)
if len(res) == 0:
# 没有结果 直接推理拿结果 并保存
keypoint_infer_result, site = self.infer_keypoint_result(result)
return self.save_keypoint_cache(result['image_id'], keypoint_infer_result, site)
elif res[0]["keypoint_site"] == "all" or res[0]["keypoint_site"] == site:
# 需要的类型和查询的类型一致或者查询的类型为all 则直接返回查询的结果
return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, np.array(res[0]['keypoint_vector']).astype(int).reshape(12, 2).tolist()))
elif res[0]["keypoint_site"] != site:
# 需要的类型和查询到的不一致则更新类型为all
keypoint_infer_result, site = self.infer_keypoint_result(result)
return self.update_keypoint_cache(result["image_id"], keypoint_infer_result, res[0]['keypoint_vector'], site)
except Exception as e:
logging.info(f"search keypoint cache milvus error {e}")
return False

View File

@@ -1,134 +0,0 @@
import cv2
from app.service.utils.oss_client import oss_get_image
from ..builder import PIPELINES
@PIPELINES.register_module()
class LoadImageFromFile(object):
def __init__(self, path, color=None, print_dict=None):
self.path = path
self.color = color
self.print_dict = print_dict
# self.minio_client = Minio(f"{MINIO_URL}", access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
# @ClassCallRunTime
def __call__(self, result):
result['image'], result['pre_mask'] = self.read_image(self.path)
result['gray'] = cv2.cvtColor(result['image'], cv2.COLOR_BGR2GRAY)
result['keypoint'] = self.get_keypoint(result['name'])
result['path'] = self.path
result['img_shape'] = result['image'].shape
result['ori_shape'] = result['image'].shape
result['color'] = self.color if self.color is not None else None
result['print_dict'] = self.print_dict
return result
@staticmethod
def get_keypoint(name):
if name == 'blouse' or name == 'outwear' or name == 'dress' or name == 'tops':
keypoint = 'shoulder'
elif name == 'trousers' or name == 'skirt' or name == 'bottoms':
keypoint = 'waistband'
elif name == 'bag':
keypoint = 'hand_point'
elif name == 'shoes':
keypoint = 'toe'
elif name == 'hairstyle':
keypoint = 'head_point'
elif name == 'earring':
keypoint = 'ear_point'
else:
raise KeyError(f"{name} does not belong to item category list: blouse, outwear, dress, trousers, skirt, "
f"bag, shoes, hairstyle, earring.")
return keypoint
@staticmethod
def read_image(image_path):
image_mask = None
image = oss_get_image(bucket=image_path.split("/", 1)[0], object_name=image_path.split("/", 1)[1], data_type="cv2")
if len(image.shape) == 2:
image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
if image.shape[2] == 4: # 如果是四通道 mask
image_mask = image[:, :, 3]
image = image[:, :, :3]
if image.shape[:2] <= (50, 50):
# 计算新尺寸
new_size = (image.shape[1] * 2, image.shape[0] * 2)
# 调整大小
image = cv2.resize(image, new_size, interpolation=cv2.INTER_LINEAR)
return image, image_mask
@PIPELINES.register_module()
class LoadBodyImageFromFile(object):
def __init__(self, body_path):
self.body_path = body_path
# self.minioClient = Minio(f"{MINIO_URL}", access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
# response = self.minioClient.get_object("aida-mannequins", "model_1693218345.2714431.png")
# @ RunTime
def __call__(self, result):
result["image_url"] = result['body_path'] = self.body_path
result["name"] = "mannequin"
# if not result['image_url'].lower().endswith(".png"):
# bucket = self.body_path.split("/", 1)[0]
# object_name = self.body_path.split("/", 1)[1]
# new_object_name = f'{object_name[:object_name.rfind(".")]}.png'
# image = self.minioClient.get_object(bucket, object_name)
# image = Image.open(io.BytesIO(image.data))
# image = image.convert("RGBA")
# data = image.getdata()
# #
# new_data = []
# for item in data:
# if item[0] >= 230 and item[1] >= 230 and item[2] >= 230:
# new_data.append((255, 255, 255, 0))
# else:
# new_data.append(item)
# image.putdata(new_data)
# image_data = io.BytesIO()
# image.save(image_data, format='PNG')
# image_data.seek(0)
# image_bytes = image_data.read()
# image_path = f"{bucket}/{self.minioClient.put_object(bucket, new_object_name, io.BytesIO(image_bytes), len(image_bytes), content_type='image/png').object_name}"
# self.body_path = image_path
# result["image_url"] = result['body_path'] = self.body_path
# response = self.minioClient.get_object(self.body_path.split("/", 1)[0], self.body_path.split("/", 1)[1])
# put_image_time = time.time()
# result['body_image'] = Image.open(io.BytesIO(response.read()))
result['body_image'] = oss_get_image(bucket=self.body_path.split("/", 1)[0], object_name=self.body_path.split("/", 1)[1], data_type="PIL")
# logging.info(f"Image.open time is : {time.time() - put_image_time}")
return result
@PIPELINES.register_module()
class ImageShow(object):
def __init__(self, key):
self.key = key
# @ RunTime
def __call__(self, result):
import matplotlib.pyplot as plt
if isinstance(self.key, list):
for key in self.key:
plt.imshow(result[key])
plt.title(key)
plt.show()
elif isinstance(self.key, str):
img = self._resize_img(result[self.key])
cv2.imshow(self.key, img)
cv2.waitKey(0)
else:
raise TypeError(f'key should be string but got type {type(self.key)}.')
return result
@staticmethod
def _resize_img(img):
shape = img.shape
if shape[0] > 400 or shape[1] > 400:
ratio = min(400 / shape[0], 400 / shape[1])
img = cv2.resize(img, (int(ratio * shape[1]), int(ratio * shape[0])))
return img

View File

@@ -1,605 +0,0 @@
import logging
import random
import cv2
import numpy as np
from PIL import Image
from app.service.utils.oss_client import oss_get_image
from ..builder import PIPELINES
logger = logging.getLogger()
@PIPELINES.register_module()
class Painting(object):
def __init__(self, painting_flag=True):
self.painting_flag = painting_flag
# @ClassCallRunTime
def __call__(self, result):
if result['name'] not in ['hairstyle', 'earring'] and self.painting_flag and result['color'] != 'none':
dim_image_h, dim_image_w = result['image'].shape[0:2]
if "gradient" in result.keys() and result['gradient'] != "":
bucket_name = result['gradient'].split('/')[0]
object_name = result['gradient'][result['gradient'].find('/') + 1:]
pattern = self.get_gradient(bucket_name=bucket_name, object_name=object_name)
resize_pattern = cv2.resize(pattern, (dim_image_w, dim_image_h), interpolation=cv2.INTER_AREA)
else:
pattern = self.get_pattern(result['color'])
resize_pattern = cv2.resize(pattern, (dim_image_w, dim_image_h), interpolation=cv2.INTER_AREA)
closed_mo = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
gray_mo = np.expand_dims(result['gray'], axis=2).repeat(3, axis=2)
get_image_fir = resize_pattern * (closed_mo / 255) * (gray_mo / 255)
result['pattern_image'] = get_image_fir.astype(np.uint8)
result['final_image'] = result['pattern_image']
canvas = np.full_like(result['final_image'], 255)
temp_bg = np.expand_dims(cv2.bitwise_not(result['mask']), axis=2).repeat(3, axis=2)
tmp1 = (canvas * (temp_bg / 255)).astype(np.uint8)
temp_fg = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
tmp2 = (result['final_image'] * (temp_fg / 255)).astype(np.uint8)
result['single_image'] = cv2.add(tmp1, tmp2)
result['alpha'] = 100 / 255.0
else:
closed_mo = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
get_image_fir = result['image'] * (closed_mo / 255)
result['pattern_image'] = get_image_fir.astype(np.uint8)
result['final_image'] = result['pattern_image']
return result
@staticmethod
def get_gradient(bucket_name, object_name):
# image_data = minio_client.get_object(bucket_name, object_name)
# image_data = s3.get_object(Bucket=bucket_name, Key=object_name)['Body']
# 从数据流中读取图像
# image_bytes = image_data.read()
# 将图像数据转换为numpy数组
# image_array = np.asarray(bytearray(image_bytes), dtype=np.uint8)
# 使用OpenCV解码图像数组
# image = cv2.imdecode(image_array, cv2.IMREAD_COLOR)
image = oss_get_image(bucket=bucket_name, object_name=object_name, data_type="cv2")
if image.shape[2] == 4:
image = cv2.cvtColor(image, cv2.COLOR_BGRA2BGR)
return image
@staticmethod
def crop_image(image, image_size_h, image_size_w):
x_offset = np.random.randint(low=0, high=int(image_size_h / 5) - 6)
y_offset = np.random.randint(low=0, high=int(image_size_w / 5) - 6)
image = image[x_offset: x_offset + image_size_h, y_offset: y_offset + image_size_w, :]
return image
@staticmethod
def get_pattern(single_color):
if single_color is None:
raise False
R, G, B = single_color.split(' ')
pattern = np.zeros([1, 1, 3], np.uint8)
pattern[0, 0, 0] = int(B)
pattern[0, 0, 1] = int(G)
pattern[0, 0, 2] = int(R)
return pattern
@PIPELINES.register_module()
class PrintPainting(object):
def __init__(self, print_flag=True):
self.print_flag = print_flag
# @ClassCallRunTime
def __call__(self, result):
single_print = result['print']['single']
overall_print = result['print']['overall']
element_print = result['print']['element']
result['single_image'] = None
result['print_image'] = None
if overall_print['print_path_list']:
painting_dict = {'dim_image_h': result['pattern_image'].shape[0], 'dim_image_w': result['pattern_image'].shape[1]}
result['print_image'] = result['pattern_image']
if "print_angle_list" in overall_print.keys() and overall_print['print_angle_list'][0] != 0:
painting_dict = self.painting_collection(painting_dict, overall_print, print_trigger=True)
painting_dict['tile_print'] = self.rotate_crop_image(img=painting_dict['tile_print'], angle=-overall_print['print_angle_list'][0], crop=True)
painting_dict['mask_inv_print'] = self.rotate_crop_image(img=painting_dict['mask_inv_print'], angle=-overall_print['print_angle_list'][0], crop=True)
# resize 到sketch大小
painting_dict['tile_print'] = self.resize_and_crop(img=painting_dict['tile_print'], target_width=painting_dict['dim_image_w'], target_height=painting_dict['dim_image_h'])
painting_dict['mask_inv_print'] = self.resize_and_crop(img=painting_dict['mask_inv_print'], target_width=painting_dict['dim_image_w'], target_height=painting_dict['dim_image_h'])
else:
painting_dict = self.painting_collection(painting_dict, overall_print, print_trigger=True, is_single=False)
result['print_image'] = self.printpaint(result, painting_dict, print_=True)
result['single_image'] = result['final_image'] = result['pattern_image'] = result['print_image']
if single_print['print_path_list']:
print_background = np.zeros((result['pattern_image'].shape[0], result['pattern_image'].shape[1], 3), dtype=np.uint8)
mask_background = np.zeros((result['pattern_image'].shape[0], result['pattern_image'].shape[1], 3), dtype=np.uint8)
for i in range(len(single_print['print_path_list'])):
image, image_mode = self.read_image(single_print['print_path_list'][i])
if image_mode == "RGBA":
new_size = (int(image.width * single_print['print_scale_list'][i]), int(image.height * single_print['print_scale_list'][i]))
mask = image.split()[3]
resized_source = image.resize(new_size)
resized_source_mask = mask.resize(new_size)
rotated_resized_source = resized_source.rotate(-single_print['print_angle_list'][i])
rotated_resized_source_mask = resized_source_mask.rotate(-single_print['print_angle_list'][i])
source_image_pil = Image.fromarray(cv2.cvtColor(print_background, cv2.COLOR_BGR2RGB))
source_image_pil_mask = Image.fromarray(cv2.cvtColor(mask_background, cv2.COLOR_BGR2RGB))
source_image_pil.paste(rotated_resized_source, (int(single_print['location'][i][0]), int(single_print['location'][i][1])), rotated_resized_source)
source_image_pil_mask.paste(rotated_resized_source_mask, (int(single_print['location'][i][0]), int(single_print['location'][i][1])), rotated_resized_source_mask)
print_background = cv2.cvtColor(np.array(source_image_pil), cv2.COLOR_RGBA2BGR)
mask_background = cv2.cvtColor(np.array(source_image_pil_mask), cv2.COLOR_RGBA2BGR)
ret, mask_background = cv2.threshold(mask_background, 124, 255, cv2.THRESH_BINARY)
else:
mask = self.get_mask_inv(image)
mask = np.expand_dims(mask, axis=2)
mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR)
mask = cv2.bitwise_not(mask)
# 旋转后的坐标需要重新算
rotate_mask, _ = self.img_rotate(mask, single_print['print_angle_list'][i], single_print['print_scale_list'][i])
rotate_image, rotated_new_size = self.img_rotate(image, single_print['print_angle_list'][i], single_print['print_scale_list'][i])
# x, y = int(result['print']['location'][i][0] - rotated_new_size[0] - (rotate_mask.shape[0] - image.shape[0]) / 2), int(result['print']['location'][i][1] - rotated_new_size[1] - (rotate_mask.shape[1] - image.shape[1]) / 2)
x, y = int(single_print['location'][i][0] - rotated_new_size[0]), int(single_print['location'][i][1] - rotated_new_size[1])
image_x = print_background.shape[1]
image_y = print_background.shape[0]
print_x = rotate_image.shape[1]
print_y = rotate_image.shape[0]
# 有bug
# if x + print_x > image_x:
# rotate_image = rotate_image[:, :x + print_x - image_x]
# rotate_mask = rotate_mask[:, :x + print_x - image_x]
# #
# if y + print_y > image_y:
# rotate_image = rotate_image[:y + print_y - image_y]
# rotate_mask = rotate_mask[:y + print_y - image_y]
# 不能是并行
# 当前第一轮的if 108以及115是判断有没有过下界和右界。第二轮的是判断左上有没有超出。 如果这个样子的话先裁了右边再左移region就会有问题
# 先挪 再判断 最后裁剪
# 如果print旋转了 或者 print贴边了 则需要判断 判断左界和上界是否小于0
if x <= 0:
rotate_image = rotate_image[:, -x:]
rotate_mask = rotate_mask[:, -x:]
start_x = x = 0
else:
start_x = x
if y <= 0:
rotate_image = rotate_image[-y:, :]
rotate_mask = rotate_mask[-y:, :]
start_y = y = 0
else:
start_y = y
# ------------------
# 如果print-size大于image-size 则需要裁剪print
if x + print_x > image_x:
rotate_image = rotate_image[:, :image_x - x]
rotate_mask = rotate_mask[:, :image_x - x]
if y + print_y > image_y:
rotate_image = rotate_image[:image_y - y, :]
rotate_mask = rotate_mask[:image_y - y, :]
# mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]] = cv2.bitwise_xor(mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]], rotate_mask)
# print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = cv2.add(print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]], rotate_image)
# mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]] = rotate_mask
# print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = rotate_image
mask_background = self.stack_prin(mask_background, result['pattern_image'], rotate_mask, start_y, y, start_x, x)
print_background = self.stack_prin(print_background, result['pattern_image'], rotate_image, start_y, y, start_x, x)
# gray_image = cv2.cvtColor(mask_background, cv2.COLOR_BGR2GRAY)
# print_background = cv2.bitwise_and(print_background, print_background, mask=gray_image)
print_mask = cv2.bitwise_and(result['mask'], cv2.cvtColor(mask_background, cv2.COLOR_BGR2GRAY))
img_fg = cv2.bitwise_or(print_background, print_background, mask=print_mask)
img_bg = cv2.bitwise_and(result['pattern_image'], result['pattern_image'], mask=cv2.bitwise_not(print_mask))
mask_mo = np.expand_dims(print_mask, axis=2).repeat(3, axis=2)
gray_mo = np.expand_dims(result['gray'], axis=2).repeat(3, axis=2)
img_fg = (img_fg * (mask_mo / 255) * (gray_mo / 255)).astype(np.uint8)
result['final_image'] = cv2.add(img_bg, img_fg)
canvas = np.full_like(result['final_image'], 255)
temp_bg = np.expand_dims(cv2.bitwise_not(result['mask']), axis=2).repeat(3, axis=2)
tmp1 = (canvas * (temp_bg / 255)).astype(np.uint8)
temp_fg = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
tmp2 = (result['final_image'] * (temp_fg / 255)).astype(np.uint8)
result['single_image'] = cv2.add(tmp1, tmp2)
if element_print['element_path_list']:
print_background = np.zeros((result['final_image'].shape[0], result['final_image'].shape[1], 3), dtype=np.uint8)
mask_background = np.zeros((result['final_image'].shape[0], result['final_image'].shape[1], 3), dtype=np.uint8)
for i in range(len(element_print['element_path_list'])):
image, image_mode = self.read_image(element_print['element_path_list'][i])
if image_mode == "RGBA":
new_size = (int(image.width * element_print['element_scale_list'][i]), int(image.height * element_print['element_scale_list'][i]))
mask = image.split()[3]
resized_source = image.resize(new_size)
resized_source_mask = mask.resize(new_size)
rotated_resized_source = resized_source.rotate(-element_print['element_angle_list'][i])
rotated_resized_source_mask = resized_source_mask.rotate(-element_print['element_angle_list'][i])
source_image_pil = Image.fromarray(cv2.cvtColor(print_background, cv2.COLOR_BGR2RGB))
source_image_pil_mask = Image.fromarray(cv2.cvtColor(mask_background, cv2.COLOR_BGR2RGB))
source_image_pil.paste(rotated_resized_source, (int(element_print['location'][i][0]), int(element_print['location'][i][1])), rotated_resized_source)
source_image_pil_mask.paste(rotated_resized_source_mask, (int(element_print['location'][i][0]), int(element_print['location'][i][1])), rotated_resized_source_mask)
print_background = cv2.cvtColor(np.array(source_image_pil), cv2.COLOR_RGBA2BGR)
mask_background = cv2.cvtColor(np.array(source_image_pil_mask), cv2.COLOR_RGBA2BGR)
else:
mask = self.get_mask_inv(image)
mask = np.expand_dims(mask, axis=2)
mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR)
mask = cv2.bitwise_not(mask)
# 旋转后的坐标需要重新算
rotate_mask, _ = self.img_rotate(mask, element_print['element_angle_list'][i], element_print['element_scale_list'][i])
rotate_image, rotated_new_size = self.img_rotate(image, element_print['element_angle_list'][i], element_print['element_scale_list'][i])
# x, y = int(result['print']['location'][i][0] - rotated_new_size[0] - (rotate_mask.shape[0] - image.shape[0]) / 2), int(result['print']['location'][i][1] - rotated_new_size[1] - (rotate_mask.shape[1] - image.shape[1]) / 2)
x, y = int(element_print['location'][i][0] - rotated_new_size[0]), int(element_print['location'][i][1] - rotated_new_size[1])
image_x = print_background.shape[1]
image_y = print_background.shape[0]
print_x = rotate_image.shape[1]
print_y = rotate_image.shape[0]
# 有bug
# if x + print_x > image_x:
# rotate_image = rotate_image[:, :x + print_x - image_x]
# rotate_mask = rotate_mask[:, :x + print_x - image_x]
# #
# if y + print_y > image_y:
# rotate_image = rotate_image[:y + print_y - image_y]
# rotate_mask = rotate_mask[:y + print_y - image_y]
# 不能是并行
# 当前第一轮的if 108以及115是判断有没有过下界和右界。第二轮的是判断左上有没有超出。 如果这个样子的话先裁了右边再左移region就会有问题
# 先挪 再判断 最后裁剪
# 如果print旋转了 或者 print贴边了 则需要判断 判断左界和上界是否小于0
if x <= 0:
rotate_image = rotate_image[:, -x:]
rotate_mask = rotate_mask[:, -x:]
start_x = x = 0
else:
start_x = x
if y <= 0:
rotate_image = rotate_image[-y:, :]
rotate_mask = rotate_mask[-y:, :]
start_y = y = 0
else:
start_y = y
# ------------------
# 如果print-size大于image-size 则需要裁剪print
if x + print_x > image_x:
rotate_image = rotate_image[:, :image_x - x]
rotate_mask = rotate_mask[:, :image_x - x]
if y + print_y > image_y:
rotate_image = rotate_image[:image_y - y, :]
rotate_mask = rotate_mask[:image_y - y, :]
# mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]] = cv2.bitwise_xor(mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]], rotate_mask)
# print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = cv2.add(print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]], rotate_image)
# mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]] = rotate_mask
# print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = rotate_image
mask_background = self.stack_prin(mask_background, result['pattern_image'], rotate_mask, start_y, y, start_x, x)
print_background = self.stack_prin(print_background, result['pattern_image'], rotate_image, start_y, y, start_x, x)
# gray_image = cv2.cvtColor(mask_background, cv2.COLOR_BGR2GRAY)
# print_background = cv2.bitwise_and(print_background, print_background, mask=gray_image)
print_mask = cv2.bitwise_and(result['mask'], cv2.cvtColor(mask_background, cv2.COLOR_BGR2GRAY))
img_fg = cv2.bitwise_or(print_background, print_background, mask=print_mask)
# TODO element 丢失信息
three_channel_image = cv2.merge([cv2.bitwise_not(print_mask), cv2.bitwise_not(print_mask), cv2.bitwise_not(print_mask)])
img_bg = cv2.bitwise_and(result['final_image'], three_channel_image)
# mask_mo = np.expand_dims(print_mask, axis=2).repeat(3, axis=2)
# gray_mo = np.expand_dims(result['gray'], axis=2).repeat(3, axis=2)
# img_fg = (img_fg * (mask_mo / 255) * (gray_mo / 255)).astype(np.uint8)
result['final_image'] = cv2.add(img_bg, img_fg)
canvas = np.full_like(result['final_image'], 255)
temp_bg = np.expand_dims(cv2.bitwise_not(result['mask']), axis=2).repeat(3, axis=2)
tmp1 = (canvas * (temp_bg / 255)).astype(np.uint8)
temp_fg = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
tmp2 = (result['final_image'] * (temp_fg / 255)).astype(np.uint8)
result['single_image'] = cv2.add(tmp1, tmp2)
return result
@staticmethod
def stack_prin(print_background, pattern_image, rotate_image, start_y, y, start_x, x):
temp_print = np.zeros((pattern_image.shape[0], pattern_image.shape[1], 3), dtype=np.uint8)
temp_print[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = rotate_image
img2gray = cv2.cvtColor(temp_print, cv2.COLOR_BGR2GRAY)
ret, mask_ = cv2.threshold(img2gray, 1, 255, cv2.THRESH_BINARY)
mask_inv = cv2.bitwise_not(mask_)
img1_bg = cv2.bitwise_and(print_background, print_background, mask=mask_inv)
img2_fg = cv2.bitwise_and(temp_print, temp_print, mask=mask_)
print_background = img1_bg + img2_fg
return print_background
def painting_collection(self, painting_dict, print_dict, print_trigger=False, is_single=False):
if print_trigger:
print_ = self.get_print(print_dict)
painting_dict['Trigger'] = not is_single
painting_dict['location'] = print_['location']
single_mask_inv_print = self.get_mask_inv(print_['image'])
dim_max = max(painting_dict['dim_image_h'], painting_dict['dim_image_w'])
dim_pattern = (int(dim_max * print_['scale'] / 5), int(dim_max * print_['scale'] / 5))
if not is_single:
self.random_seed = random.randint(0, 1000)
# 如果print 模式为overall 且 有角度的话 组合的print为正方形方便裁剪
if "print_angle_list" in print_dict.keys() and print_dict['print_angle_list'][0] != 0:
painting_dict['mask_inv_print'] = self.tile_image(single_mask_inv_print, dim_pattern, print_['scale'], dim_max, dim_max, painting_dict['location'], trigger=True)
painting_dict['tile_print'] = self.tile_image(print_['image'], dim_pattern, print_['scale'], dim_max, dim_max, painting_dict['location'], trigger=True)
else:
painting_dict['mask_inv_print'] = self.tile_image(single_mask_inv_print, dim_pattern, print_['scale'], painting_dict['dim_image_h'], painting_dict['dim_image_w'], painting_dict['location'], trigger=True)
painting_dict['tile_print'] = self.tile_image(print_['image'], dim_pattern, print_['scale'], painting_dict['dim_image_h'], painting_dict['dim_image_w'], painting_dict['location'], trigger=True)
else:
painting_dict['mask_inv_print'] = self.tile_image(single_mask_inv_print, dim_pattern, print_['scale'], painting_dict['dim_image_h'], painting_dict['dim_image_w'], painting_dict['location'])
painting_dict['tile_print'] = self.tile_image(print_['image'], dim_pattern, print_['scale'], painting_dict['dim_image_h'], painting_dict['dim_image_w'], painting_dict['location'])
painting_dict['dim_print_h'], painting_dict['dim_print_w'] = dim_pattern
return painting_dict
def tile_image(self, pattern, dim, scale, dim_image_h, dim_image_w, location, trigger=False):
tile = None
if not trigger:
tile = cv2.resize(pattern, dim, interpolation=cv2.INTER_AREA)
else:
resize_pattern = cv2.resize(pattern, dim, interpolation=cv2.INTER_AREA)
if len(pattern.shape) == 2:
tile = np.tile(resize_pattern, (int((5 + 1) / scale) + 4, int((5 + 1) / scale) + 4))
if len(pattern.shape) == 3:
tile = np.tile(resize_pattern, (int((5 + 1) / scale) + 4, int((5 + 1) / scale) + 4, 1))
tile = self.crop_image(tile, dim_image_h, dim_image_w, location, resize_pattern.shape)
return tile
def get_mask_inv(self, print_):
if print_[0][0][0] == 255 and print_[0][0][1] == 255 and print_[0][0][2] == 255:
bg_color = cv2.cvtColor(print_, cv2.COLOR_BGR2LAB)[0][0]
print_tile = cv2.cvtColor(print_, cv2.COLOR_BGR2LAB)
bg_l, bg_a, bg_b = bg_color[0], bg_color[1], bg_color[2]
bg_L_high, bg_L_low = self.get_low_high_lab(bg_l, L=True)
bg_a_high, bg_a_low = self.get_low_high_lab(bg_a)
bg_b_high, bg_b_low = self.get_low_high_lab(bg_b)
lower = np.array([bg_L_low, bg_a_low, bg_b_low])
upper = np.array([bg_L_high, bg_a_high, bg_b_high])
mask_inv = cv2.inRange(print_tile, lower, upper)
return mask_inv
else:
# bg_color = cv2.cvtColor(print_, cv2.COLOR_BGR2LAB)[0][0]
# print_tile = cv2.cvtColor(print_, cv2.COLOR_BGR2LAB)
# bg_l, bg_a, bg_b = bg_color[0], bg_color[1], bg_color[2]
# bg_L_high, bg_L_low = self.get_low_high_lab(bg_l, L=True)
# bg_a_high, bg_a_low = self.get_low_high_lab(bg_a)
# bg_b_high, bg_b_low = self.get_low_high_lab(bg_b)
# lower = np.array([bg_L_low, bg_a_low, bg_b_low])
# upper = np.array([bg_L_high, bg_a_high, bg_b_high])
# print_tile = cv2.cvtColor(print_, cv2.COLOR_BGR2LAB)
# mask_inv = cv2.cvtColor(print_tile, cv2.COLOR_BGR2GRAY)
# mask_inv = cv2.cvtColor(print_, cv2.COLOR_BGR2GRAY)
mask_inv = np.zeros(print_.shape[:2], dtype=np.uint8)
return mask_inv
@staticmethod
def printpaint(result, painting_dict, print_=False):
if print_ and painting_dict['Trigger']:
print_mask = cv2.bitwise_and(result['mask'], cv2.bitwise_not(painting_dict['mask_inv_print']))
img_fg = cv2.bitwise_and(painting_dict['tile_print'], painting_dict['tile_print'], mask=print_mask)
else:
print_mask = result['mask']
img_fg = result['final_image']
if print_ and not painting_dict['Trigger']:
index_ = None
try:
index_ = len(painting_dict['location'])
except:
assert f'there must be parameter of location if choose IfSingle'
for i in range(index_):
start_h, start_w = int(painting_dict['location'][i][1]), int(painting_dict['location'][i][0])
length_h = min(start_h + painting_dict['dim_print_h'], img_fg.shape[0])
length_w = min(start_w + painting_dict['dim_print_w'], img_fg.shape[1])
change_region = img_fg[start_h: length_h, start_w: length_w, :]
# problem in change_mask
change_mask = print_mask[start_h: length_h, start_w: length_w]
# get real part into change mask
_, change_mask = cv2.threshold(change_mask, 220, 255, cv2.THRESH_BINARY)
mask = cv2.bitwise_not(painting_dict['mask_inv_print'])
img_fg[start_h:start_h + painting_dict['dim_print_h'], start_w:start_w + painting_dict['dim_print_w'], :] = change_region
clothes_mask_print = cv2.bitwise_not(print_mask)
img_bg = cv2.bitwise_and(result['pattern_image'], result['pattern_image'], mask=clothes_mask_print)
mask_mo = np.expand_dims(print_mask, axis=2).repeat(3, axis=2)
gray_mo = np.expand_dims(result['gray'], axis=2).repeat(3, axis=2)
img_fg = (img_fg * (mask_mo / 255) * (gray_mo / 255)).astype(np.uint8)
print_image = cv2.add(img_bg, img_fg)
return print_image
@staticmethod
def get_print(print_dict):
if 'print_scale_list' not in print_dict.keys() or print_dict['print_scale_list'][0] < 0.3:
print_dict['scale'] = 0.3
else:
print_dict['scale'] = print_dict['print_scale_list'][0]
bucket_name = print_dict['print_path_list'][0].split("/", 1)[0]
object_name = print_dict['print_path_list'][0].split("/", 1)[1]
image = oss_get_image(bucket=bucket_name, object_name=object_name, data_type="PIL")
# 判断图片格式如果是RGBA 则贴在一张纯白图片上 防止透明转黑
if image.mode == "RGBA":
new_background = Image.new('RGB', image.size, (255, 255, 255))
new_background.paste(image, mask=image.split()[3])
image = new_background
print_dict['image'] = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
return print_dict
def crop_image(self, image, image_size_h, image_size_w, location, print_shape):
print_w = print_shape[1]
print_h = print_shape[0]
random.seed(self.random_seed)
# logging.info(f'overall print location : {location}')
# x_offset = random.randint(0, image.shape[0] - image_size_h)
# y_offset = random.randint(0, image.shape[1] - image_size_w)
# 1.拿到偏移量后和resize后的print宽高取余 得到真正偏移量
x_offset = print_w - int(location[0][1] % print_w)
y_offset = print_w - int(location[0][0] % print_h)
# y_offset = int(location[0][0])
# x_offset = int(location[0][1])
if len(image.shape) == 2:
image = image[x_offset: x_offset + image_size_h, y_offset: y_offset + image_size_w]
elif len(image.shape) == 3:
image = image[x_offset: x_offset + image_size_h, y_offset: y_offset + image_size_w, :]
return image
@staticmethod
def get_low_high_lab(Lab_value, L=False):
if L:
high = Lab_value + 30 if Lab_value + 30 < 255 else 255
low = Lab_value - 30 if Lab_value - 30 > 0 else 0
else:
high = Lab_value + 30 if Lab_value + 30 < 255 else 255
low = Lab_value - 30 if Lab_value - 30 > 0 else 0
return high, low
@staticmethod
def img_rotate(image, angel, scale):
"""顺时针旋转图像任意角度
Args:
image (np.array): [原始图像]
angel (float): [逆时针旋转的角度]
Returns:
[array]: [旋转后的图像]
"""
h, w = image.shape[:2]
center = (w // 2, h // 2)
# if type(angel) is not int:
# angel = 0
M = cv2.getRotationMatrix2D(center, -angel, scale)
# 调整旋转后的图像长宽
rotated_h = int((w * np.abs(M[0, 1]) + (h * np.abs(M[0, 0]))))
rotated_w = int((h * np.abs(M[0, 1]) + (w * np.abs(M[0, 0]))))
M[0, 2] += (rotated_w - w) // 2
M[1, 2] += (rotated_h - h) // 2
# 旋转图像
rotated_img = cv2.warpAffine(image, M, (rotated_w, rotated_h))
return rotated_img, ((rotated_img.shape[1] - image.shape[1] * scale) // 2, (rotated_img.shape[0] - image.shape[0] * scale) // 2)
# return rotated_img, (0, 0)
@staticmethod
def rotate_crop_image(img, angle, crop):
"""
angle: 旋转的角度
crop: 是否需要进行裁剪,布尔向量
"""
crop_image = lambda img, x0, y0, w, h: img[y0:y0 + h, x0:x0 + w]
w, h = img.shape[:2]
# 旋转角度的周期是360°
angle %= 360
# 计算仿射变换矩阵
M_rotation = cv2.getRotationMatrix2D((w / 2, h / 2), angle, 1)
# 得到旋转后的图像
img_rotated = cv2.warpAffine(img, M_rotation, (w, h))
# 如果需要去除黑边
if crop:
# 裁剪角度的等效周期是180°
angle_crop = angle % 180
if angle > 90:
angle_crop = 180 - angle_crop
# 转化角度为弧度
theta = angle_crop * np.pi / 180
# 计算高宽比
hw_ratio = float(h) / float(w)
# 计算裁剪边长系数的分子项
tan_theta = np.tan(theta)
numerator = np.cos(theta) + np.sin(theta) * np.tan(theta)
# 计算分母中和高宽比相关的项
r = hw_ratio if h > w else 1 / hw_ratio
# 计算分母项
denominator = r * tan_theta + 1
# 最终的边长系数
crop_mult = numerator / denominator
# 得到裁剪区域
w_crop = int(crop_mult * w)
h_crop = int(crop_mult * h)
x0 = int((w - w_crop) / 2)
y0 = int((h - h_crop) / 2)
img_rotated = crop_image(img_rotated, x0, y0, w_crop, h_crop)
return img_rotated
@staticmethod
def read_image(image_url):
image = oss_get_image(bucket=image_url.split("/", 1)[0], object_name=image_url.split("/", 1)[1], data_type="cv2")
if image.shape[2] == 4:
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA)
image = Image.fromarray(image_rgb)
image_mode = "RGBA"
else:
image_mode = "RGB"
return image, image_mode
@staticmethod
def resize_and_crop(img, target_width, target_height):
# 获取原始图像的尺寸
original_height, original_width = img.shape[:2]
# 计算目标尺寸的宽高比
target_ratio = target_width / target_height
# 计算原始图像的宽高比
original_ratio = original_width / original_height
# 调整尺寸
if original_ratio > target_ratio:
# 原始图像更宽按高度resize然后裁剪宽度
new_height = target_height
new_width = int(original_width * (target_height / original_height))
resized_img = cv2.resize(img, (new_width, new_height))
# 裁剪宽度
start_x = (new_width - target_width) // 2
cropped_img = resized_img[:, start_x:start_x + target_width]
else:
# 原始图像更高按宽度resize然后裁剪高度
new_width = target_width
new_height = int(original_height * (target_width / original_width))
resized_img = cv2.resize(img, (new_width, new_height))
# 裁剪高度
start_y = (new_height - target_height) // 2
cropped_img = resized_img[start_y:start_y + target_height, :]
return cropped_img

View File

@@ -1,57 +0,0 @@
import math
import cv2
from app.service.utils.decorator import ClassCallRunTime
from ..builder import PIPELINES
@PIPELINES.register_module()
class Scaling(object):
def __init__(self):
pass
# @ClassCallRunTime
def __call__(self, result):
if result['keypoint'] in ['waistband', 'shoulder', 'head_point']:
# milvus_db_keypoint_cache
distance_clo = math.sqrt(
(int(result['clothes_keypoint'][result['keypoint'] + '_left'][0]) - int(result['clothes_keypoint'][result['keypoint'] + '_right'][0])) ** 2
+
(int(result['clothes_keypoint'][result['keypoint'] + '_left'][1]) - int(result['clothes_keypoint'][result['keypoint'] + '_right'][1])) ** 2)
distance_bdy = math.sqrt((int(result['body_point_test'][result['keypoint'] + '_left'][0]) - int(result['body_point_test'][result['keypoint'] + '_right'][0])) ** 2 + 1)
# distance_clo = math.sqrt(
# (int(result['clothes_keypoint'][result['keypoint'] + '_left'].split("_")[0]) - int(result['clothes_keypoint'][result['keypoint'] + '_right'].split("_")[0])) ** 2
# +
# (int(result['clothes_keypoint'][result['keypoint'] + '_left'].split("_")[1]) - int(result['clothes_keypoint'][result['keypoint'] + '_right'].split("_")[1])) ** 2)
#
# distance_bdy = math.sqrt((int(result['body_point_test'][result['keypoint'] + '_left'][0]) - int(result['body_point_test'][result['keypoint'] + '_right'][0])) ** 2 + 1)
if distance_clo == 0:
result['scale'] = 1
else:
result['scale'] = distance_bdy / distance_clo
elif result['keypoint'] == 'toe':
distance_bdy = math.sqrt(
(int(result['body_point_test']['foot_length'][0]) - int(result['body_point_test']['foot_length'][2])) ** 2
+
(int(result['body_point_test']['foot_length'][1]) - int(result['body_point_test']['foot_length'][3])) ** 2
)
Blur = cv2.GaussianBlur(result['gray'], (3, 3), 0)
Edge = cv2.Canny(Blur, 10, 200)
Edge = cv2.dilate(Edge, None)
Edge = cv2.erode(Edge, None)
Contour, _ = cv2.findContours(Edge, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
Contours = sorted(Contour, key=cv2.contourArea, reverse=True)
Max_contour = Contours[0]
x, y, w, h = cv2.boundingRect(Max_contour)
width = w
distance_clo = width
result['scale'] = distance_bdy / distance_clo
elif result['keypoint'] == 'hand_point':
result['scale'] = result['scale_bag']
elif result['keypoint'] == 'ear_point':
result['scale'] = result['scale_earrings']
return result

View File

@@ -1,71 +0,0 @@
import logging
import os
import cv2
import numpy as np
from app.core.config import SEG_CACHE_PATH
from app.service.utils.decorator import ClassCallRunTime
from app.service.utils.oss_client import oss_get_image
from ..builder import PIPELINES
from ...utils.design_ensemble import get_seg_result
logger = logging.getLogger()
@PIPELINES.register_module()
class Segmentation(object):
@ClassCallRunTime
def __call__(self, result):
if "seg_mask_url" in result.keys() and result['seg_mask_url'] != "":
seg_mask = oss_get_image(bucket=result['seg_mask_url'].split('/')[0], object_name=result['seg_mask_url'][result['seg_mask_url'].find('/') + 1:], data_type="cv2")
seg_mask = cv2.resize(seg_mask, (result['img_shape'][1], result['img_shape'][0]), interpolation=cv2.INTER_NEAREST)
# 转换颜色空间为 RGBOpenCV 默认是 BGR
image_rgb = cv2.cvtColor(seg_mask, cv2.COLOR_BGR2RGB)
r, g, b = cv2.split(image_rgb)
red_mask = r > g
green_mask = g > r
# 创建红色和绿色掩码
result['front_mask'] = np.array(red_mask, dtype=np.uint8) * 255
result['back_mask'] = np.array(green_mask, dtype=np.uint8) * 255
result['mask'] = result['front_mask'] + result['back_mask']
else:
# 本地查询seg 缓存是否存在
_, seg_result = self.load_seg_result(result["image_id"])
result['seg_result'] = seg_result
if not _:
# 推理获得seg 结果
seg_result = get_seg_result(result["image_id"], result['image'])[0]
self.save_seg_result(seg_result, result['image_id'])
# 处理前片后片
temp_front = seg_result == 1.0
result['front_mask'] = (255 * (temp_front + 0).astype(np.uint8))
temp_back = seg_result == 2.0
result['back_mask'] = (255 * (temp_back + 0).astype(np.uint8))
result['mask'] = result['front_mask'] + result['back_mask']
return result
@staticmethod
def save_seg_result(seg_result, image_id):
file_path = f"{SEG_CACHE_PATH}{image_id}.npy"
try:
np.save(file_path, seg_result)
logger.debug(f"保存成功 {os.path.abspath(file_path)}")
except Exception as e:
logger.error(f"保存失败: {e}")
@staticmethod
def load_seg_result(image_id):
file_path = f"{SEG_CACHE_PATH}{image_id}.npy"
try:
seg_result = np.load(file_path)
return True, seg_result
except FileNotFoundError:
# logger.warning("文件不存在")
return False, None
except Exception as e:
logger.error(f"加载失败: {e}")
return False, None

View File

@@ -1,79 +0,0 @@
import io
import logging
import cv2
import numpy as np
from PIL import Image
from cv2 import cvtColor, COLOR_BGR2RGBA
from app.core.config import AIDA_CLOTHING
from app.service.utils.generate_uuid import generate_uuid
from app.service.utils.oss_client import oss_upload_image
from ..builder import PIPELINES
from ...utils.conversion_image import rgb_to_rgba
from ...utils.upload_image import upload_png_mask
@PIPELINES.register_module()
class Split(object):
"""
Split image into front and back layer according to the segmentation result
"""
# @ClassCallRunTime
# KNet
def __call__(self, result):
try:
if result['name'] in ('outwear', 'dress', 'blouse', 'skirt', 'trousers', 'tops', 'bottoms'):
front_mask = result['front_mask']
back_mask = result['back_mask']
rgba_image = rgb_to_rgba(result['final_image'], front_mask + back_mask)
new_size = (int(rgba_image.shape[1] * result["scale"] * result["resize_scale"][0]), int(rgba_image.shape[0] * result["scale"] * result["resize_scale"][1]))
rgba_image = cv2.resize(rgba_image, new_size)
result_front_image = np.zeros_like(rgba_image)
front_mask = cv2.resize(front_mask, new_size)
result_front_image[front_mask != 0] = rgba_image[front_mask != 0]
result_front_image_pil = Image.fromarray(cvtColor(result_front_image, COLOR_BGR2RGBA))
result['front_image'], result["front_image_url"], _ = upload_png_mask(result_front_image_pil, f'{generate_uuid()}', mask=None)
height, width = front_mask.shape
mask_image = np.zeros((height, width, 3))
mask_image[front_mask != 0] = [0, 0, 255]
if result["name"] in ('blouse', 'dress', 'outwear', 'tops'):
result_back_image = np.zeros_like(rgba_image)
back_mask = cv2.resize(back_mask, new_size)
result_back_image[back_mask != 0] = rgba_image[back_mask != 0]
result_back_image_pil = Image.fromarray(cvtColor(result_back_image, COLOR_BGR2RGBA))
result['back_image'], result["back_image_url"], _ = upload_png_mask(result_back_image_pil, f'{generate_uuid()}', mask=None)
mask_image[back_mask != 0] = [0, 255, 0]
rbga_mask = rgb_to_rgba(mask_image, front_mask + back_mask)
mask_pil = Image.fromarray(cvtColor(rbga_mask.astype(np.uint8), COLOR_BGR2RGBA))
image_data = io.BytesIO()
mask_pil.save(image_data, format='PNG')
image_data.seek(0)
image_bytes = image_data.read()
req = oss_upload_image(bucket=AIDA_CLOTHING, object_name=f"mask/mask_{generate_uuid()}.png", image_bytes=image_bytes)
result['mask_url'] = req.bucket_name + "/" + req.object_name
else:
rbga_mask = rgb_to_rgba(mask_image, front_mask)
mask_pil = Image.fromarray(cvtColor(rbga_mask.astype(np.uint8), COLOR_BGR2RGBA))
image_data = io.BytesIO()
mask_pil.save(image_data, format='PNG')
image_data.seek(0)
image_bytes = image_data.read()
req = oss_upload_image(bucket=AIDA_CLOTHING, object_name=f"mask/mask_{generate_uuid()}.png", image_bytes=image_bytes)
result['mask_url'] = req.bucket_name + "/" + req.object_name
result['back_image'] = None
result["back_image_url"] = None
# result["back_mask_url"] = None
# result['back_mask_image'] = None
# 创建中间图层
result_pattern_image_rgba = rgb_to_rgba(result['pattern_image'], result['mask'])
result_pattern_image_pil = Image.fromarray(cvtColor(result_pattern_image_rgba, COLOR_BGR2RGBA))
result['pattern_image'], result['pattern_image_url'], _ = upload_png_mask(result_pattern_image_pil, f'{generate_uuid()}')
return result
except Exception as e:
logging.warning(f"split runtime exception : {e} image_id : {result['image_id']}")

View File

@@ -1,121 +0,0 @@
import cv2
import numpy as np
from PIL import Image
from .builder import ITEMS
from .clothing import Clothing
from ..utils.conversion_image import rgb_to_rgba
from ..utils.upload_image import upload_png_mask
from ...utils.generate_uuid import generate_uuid
@ITEMS.register_module()
class Shoes(Clothing):
# TODO location of shoes has little mismatch
def __init__(self, **kwargs):
pipeline = [
dict(type='LoadImageFromFile', path=kwargs['path'], color=kwargs['color']),
dict(type='KeypointDetection'),
dict(type='ContourDetection'),
dict(type='Painting'),
dict(type='Scaling'),
dict(type='Split'),
# dict(type='ImageShow', key=['image', 'mask', 'pattern_image']),
]
kwargs.update(pipeline=pipeline)
super(Shoes, self).__init__(**kwargs)
def organize(self, layer):
left_shoe_mask, right_shoe_mask = self.cut()
left_layer = dict(name=f'{type(self).__name__.lower()}_left',
image=self.result['shoes_left'],
image_url=self.result['left_image_url'],
mask_url=self.result['left_mask_url'],
sacle=self.result['scale'],
clothes_keypoint=self.result['clothes_keypoint'],
position=self.calculate_start_point(self.result['keypoint'],
self.result['scale'],
self.result['clothes_keypoint'],
self.result['body_point'],
'left'))
layer.insert(left_layer)
right_layer = dict(name=f'{type(self).__name__.lower()}_right',
image=self.result['shoes_right'],
image_url=self.result['right_image_url'],
mask_url=self.result['right_mask_url'],
sacle=self.result['scale'],
clothes_keypoint=self.result['clothes_keypoint'],
position=self.calculate_start_point(self.result['keypoint'],
self.result['scale'],
self.result['clothes_keypoint'],
self.result['body_point'],
'right'))
layer.insert(right_layer)
def cut(self):
"""
Cut shoes mask into two pieces
Returns:
"""
contour, _ = cv2.findContours(self.result['mask'], cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
contours = sorted(contour, key=cv2.contourArea, reverse=True)
bounding_boxes = [cv2.boundingRect(c) for c in contours[:2]]
(contours, bounding_boxes) = zip(*sorted(zip(contours[:2], bounding_boxes), key=lambda x: x[1][0], reverse=False))
epsilon_left = 0.001 * cv2.arcLength(contours[0], True)
approx_left = cv2.approxPolyDP(contours[0], epsilon_left, True)
mask_left = np.zeros(self.result['final_image'].shape[:2], np.uint8)
cv2.drawContours(mask_left, [approx_left], -1, 255, -1)
item_mask_left = cv2.GaussianBlur(mask_left, (5, 5), 0)
rgba_image = rgb_to_rgba((self.result['final_image'].shape[0], self.result['final_image'].shape[1]), self.result['final_image'], item_mask_left)
result_image = np.zeros_like(rgba_image)
result_image[self.result['front_mask'] != 0] = rgba_image[self.result['front_mask'] != 0]
result_left_image_pil = Image.fromarray(result_image, 'RGBA')
result_left_image_pil = result_left_image_pil.resize((int(result_left_image_pil.width * self.result["scale"]), int(result_left_image_pil.height * self.result["scale"])), Image.LANCZOS)
self.result['shoes_left'], self.result["left_image_url"], self.result["left_mask_url"] = upload_png_mask(result_left_image_pil, f"{generate_uuid()}")
epsilon_right = 0.001 * cv2.arcLength(contours[1], True)
approx_right = cv2.approxPolyDP(contours[1], epsilon_right, True)
mask_right = np.zeros(self.result['final_image'].shape[:2], np.uint8)
cv2.drawContours(mask_right, [approx_right], -1, 255, -1)
item_mask_right = cv2.GaussianBlur(mask_right, (5, 5), 0)
rgba_image = rgb_to_rgba((self.result['final_image'].shape[0], self.result['final_image'].shape[1]), self.result['final_image'], item_mask_right)
result_image = np.zeros_like(rgba_image)
result_image[self.result['front_mask'] != 0] = rgba_image[self.result['front_mask'] != 0]
result_right_image_pil = Image.fromarray(result_image, 'RGBA')
result_right_image_pil = result_right_image_pil.resize((int(result_right_image_pil.width * self.result["scale"]), int(result_right_image_pil.height * self.result["scale"])), Image.LANCZOS)
self.result['shoes_right'], self.result["right_image_url"], self.result["right_mask_url"] = upload_png_mask(result_right_image_pil, f"{generate_uuid()}")
return item_mask_left, item_mask_right
@staticmethod
def calculate_start_point(keypoint_type, scale, clothes_point, body_point, location):
"""
left shoes align left
right shoes align right
Args:
keypoint_type: string, "toe"
scale: float
clothes_point: dict{'left': [x1, y1, z1], 'right': [x2, y2, z2]}
body_point: dict, containing keypoint data of body figure
location: string, indicates whether the start point belongs to right or left shoe
Returns:
start_point: tuple (x', y')
x' = y_body - y1 * scale
y' = x_body - x1 * scale
"""
if location not in ['left', 'right']:
raise KeyError(f'location value must be left or right but got {location}')
side_indicator = f'{keypoint_type}_{location}'
# clothes_point = {k: tuple(map(lambda x: int(scale * x), v[0: 2])) for k, v in clothes_point.items()}
start_point = (body_point[side_indicator][1] - int(int(clothes_point[side_indicator].split("_")[1]) * scale),
body_point[side_indicator][0] - int(int(clothes_point[side_indicator].split("_")[0]) * scale))
return start_point

View File

@@ -1,46 +0,0 @@
from .builder import ITEMS
from .clothing import Clothing
@ITEMS.register_module()
class Top(Clothing):
def __init__(self, pipeline, **kwargs):
if pipeline is None:
pipeline = [
dict(type='LoadImageFromFile', path=kwargs['path'], color=kwargs['color'], print_dict=kwargs['print']),
dict(type='KeypointDetection'),
# dict(type='ContourDetection'),
dict(type='Segmentation'),
dict(type='Painting', painting_flag=True),
dict(type='PrintPainting', print_flag=True),
# dict(type='ImageShow', key=['image', 'mask', 'seg_visualize', 'pattern_image']),
dict(type='Scaling'),
dict(type='Split'),
]
kwargs.update(pipeline=pipeline)
super(Top, self).__init__(**kwargs)
@ITEMS.register_module()
class Blouse(Top):
def __init__(self, pipeline=None, **kwargs):
super(Blouse, self).__init__(pipeline, **kwargs)
@ITEMS.register_module()
class Outwear(Top):
def __init__(self, pipeline=None, **kwargs):
super(Outwear, self).__init__(pipeline, **kwargs)
@ITEMS.register_module()
class Dress(Top):
def __init__(self, pipeline=None, **kwargs):
super(Dress, self).__init__(pipeline, **kwargs)
# Men's clothing
@ITEMS.register_module()
class Tops(Top):
def __init__(self, pipeline=None, **kwargs):
super(Tops, self).__init__(pipeline, **kwargs)

View File

@@ -1,197 +0,0 @@
import concurrent.futures
import io
import cv2
from app.core.config import PRIORITY_DICT
from app.service.design.core.layer import Layer
from app.service.design.items import build_item
from app.service.design.utils.redis_utils import Redis
from app.service.design.utils.synthesis_item import synthesis, synthesis_single
from app.service.utils.decorator import RunTime
from app.service.utils.oss_client import oss_upload_image
def process_item(item, layers):
# logging.info("process running.........")
item.process()
item.organize(layers)
if item.result['name'] == "mannequin":
return item.result['body_image'].size
def update_progress(process_id, total):
r = Redis()
progress = r.read(key=process_id)
if progress and total != 1:
if int(progress) <= 100:
r.write(key=process_id, value=int(progress) + int(100 / total))
else:
r.write(key=process_id, value=99)
return progress
elif total == 1:
r.write(key=process_id, value=100)
return progress
else:
r.write(key=process_id, value=int(100 / total))
return progress
def final_progress(process_id):
r = Redis()
progress = r.read(key=process_id)
r.write(key=process_id, value=100)
return progress
@RunTime
def generate(request_data):
return_response = {}
return_png_mask = []
request_data = request_data.dict()
assert "process_id" in request_data.keys(), "Need process_id parameters"
objects = request_data['objects']
# insert_keypoint_cache(objects)
process_id = request_data['process_id']
with concurrent.futures.ThreadPoolExecutor() as executor:
# 提交每个对象的处理任务
futures = {executor.submit(process_object, cfg, process_id, len(objects)): obj for obj, cfg in enumerate(objects)}
# 获取处理结果
for future in concurrent.futures.as_completed(futures):
obj = futures[future]
return_response[obj] = future.result()[0]
return_png_mask.extend(future.result()[1])
# upload_results = process_images(return_png_mask)
final_progress(process_id)
return return_response
def process_object(cfg, process_id, total):
uploaded_images = []
basic_info = cfg.get('basic')
items_response = {
'layers': []
}
if cfg.get('basic')['single_overall'] == 'overall':
basic_info['debug'] = False
items = [build_item(x, default_args=basic_info) for x in cfg.get('items')]
layers = Layer()
body_size = None
futures = []
for item in items:
futures = [process_item(item, layers)]
for future in futures:
if future is not None:
body_size = future
# 是否自定义排序
if basic_info.get('layer_order', False):
layers = sorted(layers.layer, key=lambda s: s.get("priority", float('inf')))
else:
layers = sorted(layers.layer, key=lambda x: PRIORITY_DICT.get(x['name'], float('inf')))
# 上传所有图片
# for layer in layers:
# if 'image' in layer.keys() and layer['image'] is not None:
# uploaded_images.append({'image_obj': layer['image'], 'image_url': layer['image_url'], 'image_type': 'image'})
# if 'pattern_image' in layer.keys() and layer['pattern_image'] is not None:
# uploaded_images.append({'image_obj': layer['pattern_image'], 'image_url': layer['pattern_image_url'], 'image_type': 'pattern_image'})
# if 'mask' in layer.keys() and layer['mask'] is not None and layer['mask_url'] is not None:
# uploaded_images.append({'image_obj': layer['mask'], 'image_url': layer['mask_url'], 'image_type': 'mask'})
layers, new_size = update_base_size_priority(layers, body_size)
# 合成
items_response['synthesis_url'] = synthesis(layers, new_size, basic_info)
for lay in layers:
items_response['layers'].append({
'image_category': lay['name'],
'position': lay['position'],
'priority': lay.get("priority", None),
'resize_scale': lay['resize_scale'] if "resize_scale" in lay.keys() else None,
'image_size': lay['image'] if lay['image'] is None else lay['image'].size,
'gradient_string': lay['gradient_string'] if 'gradient_string' in lay.keys() else "",
'mask_url': lay['mask_url'],
'image_url': lay['image_url'] if 'image_url' in lay.keys() else None,
'pattern_image_url': lay['pattern_image_url'] if 'pattern_image_url' in lay.keys() else None,
# 'image': lay['image'],
# 'mask_image': lay['mask_image'],
})
elif cfg.get('basic')['single_overall'] == 'single':
assert cfg.get('basic')['switch_category'] in [x['type'] for x in cfg.get('items')], "Lack of switch_category parameters "
basic_info['debug'] = False
for item in cfg.get('items'):
if item['type'] == cfg.get('basic')['switch_category']:
item = build_item(item, default_args=cfg.get('basic'))
item.process()
items_response['layers'].append({
'image_category': f"{item.result['name']}_front",
'image_size': item.result['back_image'].size if item.result['back_image'] else None,
'position': None,
'priority': 0,
'image_url': item.result['front_image_url'],
'mask_url': item.result['mask_url'],
"gradient_string": item.result['gradient_string'] if 'gradient_string' in item.result.keys() else "",
'pattern_image_url': item.result['pattern_image_url'] if 'pattern_image_url' in item.result.keys() else None,
})
items_response['layers'].append({
'image_category': f"{item.result['name']}_back",
'image_size': item.result['front_image'].size if item.result['front_image'] else None,
'position': None,
'priority': 0,
'image_url': item.result['back_image_url'],
'mask_url': item.result['mask_url'],
"gradient_string": item.result['gradient_string'] if 'gradient_string' in item.result.keys() else "",
'pattern_image_url': item.result['pattern_image_url'] if 'pattern_image_url' in item.result.keys() else None,
})
items_response['synthesis_url'] = synthesis_single(item.result['front_image'], item.result['back_image'])
break
update_progress(process_id, total)
return items_response, uploaded_images
@RunTime
def process_images(images):
with concurrent.futures.ThreadPoolExecutor() as executor:
results = list(executor.map(upload_images, images))
# results = []
# for image in images:
# results.append(upload_images(image))
return results
# @RunTime
def upload_images(image_obj):
bucket_name = image_obj['image_url'].split("/", 1)[0]
object_name = image_obj['image_url'].split("/", 1)[1]
if image_obj['image_type'] == 'image' or image_obj['image_type'] == 'pattern_image':
image_data = io.BytesIO()
image_obj['image_obj'].save(image_data, format='PNG')
image_data.seek(0)
image_bytes = image_data.read()
req = oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes)
return image_obj['image_url']
else:
mask_inverted = cv2.bitwise_not(image_obj['image_obj'])
# 将掩模的3通道转换为4通道白色部分不透明黑色部分透明
rgba_image = cv2.cvtColor(mask_inverted, cv2.COLOR_BGR2BGRA)
rgba_image[rgba_image[:, :, 0] == 0] = [0, 0, 0, 0]
req = oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=cv2.imencode('.png', rgba_image)[1])
return image_obj['image_url']
def update_base_size_priority(layers, size):
# 计算透明背景图片的宽度
min_x = min(info['position'][1] for info in layers)
x_list = []
for info in layers:
if info['image'] is not None:
x_list.append(info['position'][1] + info['image'].width)
max_x = max(x_list)
new_width = max_x - min_x
new_height = 700
# 更新坐标
for info in layers:
info['adaptive_position'] = (info['position'][0], info['position'][1] - min_x)
return layers, (new_width, new_height)

View File

@@ -1,31 +0,0 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project trinity_client
@File conversion_image.py
@Author :周成融
@Date 2023/8/21 10:40:29
@detail
"""
import numpy as np
# def rgb_to_rgba(rgb_size, rgb_image, mask):
# alpha_channel = np.full(rgb_size, 255, dtype=np.uint8)
# # 创建四通道的结果图像
# rgba_image = np.dstack((rgb_image, alpha_channel))
# alpha_channel = np.where(mask > 0, 255, 0)
# # 更新RGBA图像的透明度通道
# rgba_image[:, :, 3] = alpha_channel
# return rgba_image
def rgb_to_rgba(rgb_image, mask):
# 创建全透明的alpha通道
alpha_channel = np.where(mask > 0, 255, 0).astype(np.uint8)
# 合并RGB图像和alpha通道
rgba_image = np.dstack((rgb_image, alpha_channel))
return rgba_image
if __name__ == '__main__':
image = open("")

View File

@@ -1,143 +0,0 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project trinity_client
@File design_ensemble.py
@Author :周成融
@Date 2023/8/16 19:36:21
@detail :发起请求 获取推理结果
"""
import logging
import cv2
import mmcv
import numpy as np
import torch
import torch.nn.functional as F
import tritonclient.http as httpclient
from app.core.config import *
"""
keypoint
预处理 推理 后处理
"""
def keypoint_preprocess(img_path):
img = mmcv.imread(img_path)
img_scale = (256, 256)
h, w = img.shape[:2]
img = cv2.resize(img, img_scale)
w_scale = img_scale[0] / w
h_scale = img_scale[1] / h
img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
return preprocessed_img, (w_scale, h_scale)
# @ RunTime
# 推理
def get_keypoint_result(image, site):
keypoint_result = None
try:
image, scale_factor = keypoint_preprocess(image)
client = httpclient.InferenceServerClient(url=DESIGN_MODEL_URL)
transformed_img = image.astype(np.float32)
inputs = [httpclient.InferInput(f"input", transformed_img.shape, datatype="FP32")]
inputs[0].set_data_from_numpy(transformed_img, binary_data=True)
outputs = [httpclient.InferRequestedOutput(f"output", binary_data=True)]
results = client.infer(model_name=f"keypoint_{site}_ocrnet_hr18", inputs=inputs, outputs=outputs)
inference_output = torch.from_numpy(results.as_numpy(f'output'))
keypoint_result = keypoint_postprocess(inference_output, scale_factor)
except Exception as e:
logging.warning(f"get_keypoint_result : {e}")
return keypoint_result
def keypoint_postprocess(output, scale_factor):
max_indices = torch.argmax(output.view(output.size(0), output.size(1), -1), dim=2).unsqueeze(dim=2)
max_coords = torch.cat((max_indices / output.size(3), max_indices % output.size(3)), dim=2)
segment_result = max_coords.numpy()
scale_factor = [1 / x for x in scale_factor[::-1]]
scale_matrix = np.diag(scale_factor)
nan = np.isinf(scale_matrix)
scale_matrix[nan] = 0
return np.ceil(np.dot(segment_result, scale_matrix) * 4)
"""
seg
预处理 推理 后处理
"""
# KNet
def seg_preprocess(img_path):
img = mmcv.imread(img_path)
ori_shape = img.shape[:2]
img_scale_w, img_scale_h = ori_shape
if ori_shape[0] > 1024:
img_scale_w = 1024
if ori_shape[1] > 1024:
img_scale_h = 1024
# 如果图片size任意一边 大于 1024 则会resize 成1024
if ori_shape != (img_scale_w, img_scale_h):
# mmcv.imresize(img, img_scale_h, img_scale_w) # 老代码 引以为戒!哈哈哈~ h和w写反了
img = cv2.resize(img, (img_scale_h, img_scale_w))
img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
return preprocessed_img, ori_shape
# @ RunTime
def get_seg_result(image_id, image):
image, ori_shape = seg_preprocess(image)
client = httpclient.InferenceServerClient(url=f"{DESIGN_MODEL_URL}")
transformed_img = image.astype(np.float32)
# 输入集
inputs = [
httpclient.InferInput(SEGMENTATION['input'], transformed_img.shape, datatype="FP32")
]
inputs[0].set_data_from_numpy(transformed_img, binary_data=True)
# 输出集
outputs = [
httpclient.InferRequestedOutput(SEGMENTATION['output'], binary_data=True),
]
results = client.infer(model_name=SEGMENTATION['new_model_name'], inputs=inputs, outputs=outputs)
# 推理
# 取结果
inference_output1 = results.as_numpy(SEGMENTATION['output'])
seg_result = seg_postprocess(int(image_id), inference_output1, ori_shape)
return seg_result
# no cache
def seg_postprocess(image_id, output, ori_shape):
seg_logit = F.interpolate(torch.tensor(output).float(), size=ori_shape, scale_factor=None, mode='bilinear', align_corners=False)
seg_pred = seg_logit.cpu().numpy()
return seg_pred[0]
def key_point_show(image_path, key_point_result=None):
img = cv2.imread(image_path)
points_list = key_point_result
point_size = 1
point_color = (0, 0, 255) # BGR
thickness = 4 # 可以为 0 、4、8
for point in points_list:
cv2.circle(img, point[::-1], point_size, point_color, thickness)
cv2.imshow("0", img)
cv2.waitKey(0)
if __name__ == '__main__':
image = cv2.imread("9070101c-e5be-49b5-9602-4113a968969b.png")
a = get_keypoint_result(image, "up")
new_list = []
print(list)
for i in a[0]:
new_list.append((int(i[0]), int(i[1])))
key_point_show("9070101c-e5be-49b5-9602-4113a968969b.png", new_list)
# a = get_seg_result(1, image)
print(a)

View File

@@ -1,99 +0,0 @@
import redis
from app.core.config import REDIS_HOST, REDIS_PORT
class Redis(object):
"""
redis数据库操作
"""
@staticmethod
def _get_r():
host = REDIS_HOST
port = REDIS_PORT
db = 0
r = redis.StrictRedis(host, port, db)
return r
@classmethod
def write(cls, key, value, expire=None):
"""
写入键值对
"""
# 判断是否有过期时间,没有就设置默认值
if expire:
expire_in_seconds = expire
else:
expire_in_seconds = 100
r = cls._get_r()
r.set(key, value, ex=expire_in_seconds)
@classmethod
def read(cls, key):
"""
读取键值对内容
"""
r = cls._get_r()
value = r.get(key)
return value.decode('utf-8') if value else value
@classmethod
def hset(cls, name, key, value):
"""
写入hash表
"""
r = cls._get_r()
r.hset(name, key, value)
@classmethod
def hget(cls, name, key):
"""
读取指定hash表的键值
"""
r = cls._get_r()
value = r.hget(name, key)
return value.decode('utf-8') if value else value
@classmethod
def hgetall(cls, name):
"""
获取指定hash表所有的值
"""
r = cls._get_r()
return r.hgetall(name)
@classmethod
def delete(cls, *names):
"""
删除一个或者多个
"""
r = cls._get_r()
r.delete(*names)
@classmethod
def hdel(cls, name, key):
"""
删除指定hash表的键值
"""
r = cls._get_r()
r.hdel(name, key)
@classmethod
def expire(cls, name, expire=None):
"""
设置过期时间
"""
if expire:
expire_in_seconds = expire
else:
expire_in_seconds = 100
r = cls._get_r()
r.expire(name, expire_in_seconds)
if __name__ == '__main__':
redis_client = Redis()
# print(redis_client.write(key="1230", value=0))
redis_client.write(key="1230", value=10)
# print(redis_client.read(key="1230"))

View File

@@ -1,181 +0,0 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project trinity_client
@File synthesis_item.py
@Author :周成融
@Date 2023/8/26 14:13:04
@detail
"""
import io
import logging
import cv2
import numpy as np
from PIL import Image
from app.service.utils.generate_uuid import generate_uuid
from app.service.utils.oss_client import oss_upload_image
def positioning(all_mask_shape, mask_shape, offset):
all_start = 0
all_end = 0
mask_start = 0
mask_end = 0
if offset == 0:
all_start = 0
all_end = min(all_mask_shape, mask_shape)
mask_start = 0
mask_end = min(all_mask_shape, mask_shape)
elif offset > 0:
all_start = min(offset, all_mask_shape)
all_end = min(offset + mask_shape, all_mask_shape)
mask_start = 0
mask_end = 0 if offset > all_mask_shape else min(all_mask_shape - offset, mask_shape)
elif offset < 0:
if abs(offset) > mask_shape:
all_start = 0
all_end = 0
else:
all_start = 0
if mask_shape - abs(offset) > all_mask_shape:
all_end = min(mask_shape - abs(offset), all_mask_shape)
else:
all_end = mask_shape - abs(offset)
if abs(offset) > mask_shape:
mask_start = mask_shape
mask_end = mask_shape
else:
mask_start = abs(offset)
if mask_shape - abs(offset) >= all_mask_shape:
mask_end = all_mask_shape + abs(offset)
else:
mask_end = mask_shape
return all_start, all_end, mask_start, mask_end
# @RunTime
def synthesis(data, size, basic_info):
# 创建底图
base_image = Image.new('RGBA', size, (0, 0, 0, 0))
try:
all_mask_shape = (size[1], size[0])
body_mask = None
for d in data:
if d['name'] == 'body':
# 创建一个新的宽高透明图像, 把模特贴上去获取mask
transparent_image = Image.new("RGBA", size, (0, 0, 0, 0))
transparent_image.paste(d['image'], (d['adaptive_position'][1], d['adaptive_position'][0]), d['image']) # 此处可变数组会被paste篡改值所以使用下标获取position
body_mask = np.array(transparent_image.split()[3])
# 根据新的坐标获取新的肩点
left_shoulder = [x + y for x, y in zip(basic_info['body_point_test']['shoulder_left'], [d['adaptive_position'][1], d['adaptive_position'][0]])]
right_shoulder = [x + y for x, y in zip(basic_info['body_point_test']['shoulder_right'], [d['adaptive_position'][1], d['adaptive_position'][0]])]
body_mask[:min(left_shoulder[1], right_shoulder[1]), left_shoulder[0]:right_shoulder[0]] = 255
_, binary_body_mask = cv2.threshold(body_mask, 127, 255, cv2.THRESH_BINARY)
top_outer_mask = np.array(binary_body_mask)
bottom_outer_mask = np.array(binary_body_mask)
top = True
bottom = True
i = len(data)
while i:
i -= 1
if top and data[i]['name'] in ["blouse_front", "outwear_front", "dress_front", "tops_front"]:
top = False
mask_shape = data[i]['mask'].shape
y_offset, x_offset = data[i]['adaptive_position']
# 初始化叠加区域的起始和结束位置
all_y_start, all_y_end, mask_y_start, mask_y_end = positioning(all_mask_shape=all_mask_shape[0], mask_shape=mask_shape[0], offset=y_offset)
all_x_start, all_x_end, mask_x_start, mask_x_end = positioning(all_mask_shape=all_mask_shape[1], mask_shape=mask_shape[1], offset=x_offset)
# 将叠加区域赋值为相应的像素值
_, sketch_mask = cv2.threshold(data[i]['mask'], 127, 255, cv2.THRESH_BINARY)
background = np.zeros_like(top_outer_mask)
background[all_y_start:all_y_end, all_x_start:all_x_end] = sketch_mask[mask_y_start:mask_y_end, mask_x_start:mask_x_end]
top_outer_mask = background + top_outer_mask
elif bottom and data[i]['name'] in ["trousers_front", "skirt_front", "bottoms_front", "dress_front"]:
bottom = False
mask_shape = data[i]['mask'].shape
y_offset, x_offset = data[i]['adaptive_position']
# 初始化叠加区域的起始和结束位置
all_y_start, all_y_end, mask_y_start, mask_y_end = positioning(all_mask_shape=all_mask_shape[0], mask_shape=mask_shape[0], offset=y_offset)
all_x_start, all_x_end, mask_x_start, mask_x_end = positioning(all_mask_shape=all_mask_shape[1], mask_shape=mask_shape[1], offset=x_offset)
# 将叠加区域赋值为相应的像素值
_, sketch_mask = cv2.threshold(data[i]['mask'], 127, 255, cv2.THRESH_BINARY)
background = np.zeros_like(top_outer_mask)
background[all_y_start:all_y_end, all_x_start:all_x_end] = sketch_mask[mask_y_start:mask_y_end, mask_x_start:mask_x_end]
bottom_outer_mask = background + bottom_outer_mask
elif bottom is False and top is False:
break
all_mask = cv2.bitwise_or(top_outer_mask, bottom_outer_mask)
for layer in data:
if layer['image'] is not None:
if layer['name'] != "body":
test_image = Image.new('RGBA', size, (0, 0, 0, 0))
test_image.paste(layer['image'], (layer['adaptive_position'][1], layer['adaptive_position'][0]), layer['image'])
mask_data = np.where(all_mask > 0, 255, 0).astype(np.uint8)
mask_alpha = Image.fromarray(mask_data)
cropped_image = Image.composite(test_image, Image.new("RGBA", test_image.size, (255, 255, 255, 0)), mask_alpha)
base_image.paste(test_image, (0, 0), cropped_image) # test_image 已经按照坐标贴到最大宽值的图片上 坐着这里坐标为00
else:
base_image.paste(layer['image'], (layer['adaptive_position'][1], layer['adaptive_position'][0]), layer['image'])
result_image = base_image
image_data = io.BytesIO()
result_image.save(image_data, format='PNG')
image_data.seek(0)
# oss upload
image_bytes = image_data.read()
bucket_name = "aida-results"
object_name = f'result_{generate_uuid()}.png'
req = oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes)
return f"{bucket_name}/{object_name}"
# return f"aida-results/{minio_client.put_object('aida-results', f'result_{generate_uuid()}.png', io.BytesIO(image_bytes), len(image_bytes), content_type='image/png').object_name}"
# object_name = f'result_{generate_uuid()}.png'
# response = s3.put_object(Bucket="aida-results", Key=object_name, Body=data, ContentType='image/png')
# object_url = f"aida-results/{object_name}"
# if response['ResponseMetadata']['HTTPStatusCode'] == 200:
# return object_url
# else:
# return ""
except Exception as e:
logging.warning(f"synthesis runtime exception : {e}")
def synthesis_single(front_image, back_image):
result_image = None
if front_image:
result_image = front_image
if back_image:
result_image.paste(back_image, (0, 0), back_image)
# with io.BytesIO() as output:
# result_image.save(output, format='PNG')
# data = output.getvalue()
# object_name = f'result_{generate_uuid()}.png'
# response = s3.put_object(Bucket="aida-results", Key=object_name, Body=data, ContentType='image/png')
# object_url = f"aida-results/{object_name}"
# if response['ResponseMetadata']['HTTPStatusCode'] == 200:
# return object_url
# else:
# return ""
image_data = io.BytesIO()
result_image.save(image_data, format='PNG')
image_data.seek(0)
image_bytes = image_data.read()
# return f"aida-results/{minio_client.put_object('aida-results', f'result_{generate_uuid()}.png', io.BytesIO(image_bytes), len(image_bytes), content_type='image/png').object_name}"
# oss upload
bucket_name = 'aida-results'
object_name = f'result_{generate_uuid()}.png'
req = oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes)
return f"{bucket_name}/{object_name}"

View File

@@ -4,7 +4,7 @@ import threading
from celery import Celery
from minio import Minio
from app.core.config import *
from app.core.config import settings
from app.service.design_batch.item import BodyItem, TopItem, BottomItem, OthersItem
from app.service.design_batch.utils.MQ import publish_status
from app.service.design_batch.utils.organize import organize_body, organize_clothing, organize_others
@@ -12,12 +12,12 @@ from app.service.design_batch.utils.save_json import oss_upload_json
from app.service.design_batch.utils.synthesis_item import update_base_size_priority, synthesis, synthesis_single
id_lock = threading.Lock()
celery_app = Celery('tasks', broker=f'amqp://rabbit:123456@18.167.251.121:5672//', backend='rpc://', BROKER_CONNECTION_RETRY_ON_STARTUP=True)
celery_app = Celery('tasks', broker=f'amqp://{settings.MQ_USERNAME}:{settings.MQ_PASSWORD}@{settings.MQ_HOST}:{settings.MQ_PORT}//', backend='rpc://')
celery_app.conf.worker_log_format = '%(asctime)s %(filename)s [line:%(lineno)d] %(levelname)s %(message)s'
celery_app.conf.worker_hijack_root_logger = False
logging.getLogger('pika').setLevel(logging.WARNING)
logger = logging.getLogger()
minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)
print("start")
@@ -51,10 +51,12 @@ def process_layer(item, layers):
front_layer, back_layer = organize_others(item)
layers.append(front_layer)
layers.append(back_layer)
return None
else:
front_layer, back_layer = organize_clothing(item)
layers.append(front_layer)
layers.append(back_layer)
return None
@celery_app.task
@@ -76,12 +78,11 @@ def batch_design(objects_data, tasks_id, json_name):
for item in object['items']:
item_results.append(process_item(item, basic))
layers = []
body_size = None
for item in item_results:
body_size = process_layer(item, layers)
process_layer(item, layers)
layers = sorted(layers, key=lambda s: s.get("priority", float('inf')))
layers, new_size = update_base_size_priority(layers, body_size)
layers, new_size = update_base_size_priority(layers)
for lay in layers:
items_response['layers'].append({

View File

@@ -18,11 +18,11 @@ class BackPerspective:
result['back_perspective_url'] = file_path
return result
else:
seg_result = get_seg_result("1", result['image'])[0]
seg_result = get_seg_result(result['image'])[0]
elif result['name'] in ['blouse', 'outwear', 'dress', 'tops']:
seg_result = result['seg_result']
else:
seg_result = get_seg_result("1", result['image'])[0]
seg_result = get_seg_result(result['image'])[0]
m = self.thicken_contours_and_display(seg_result, thickness=10, color=(0, 0, 0))
back_sketch = result['image'].copy()
@@ -34,7 +34,8 @@ class BackPerspective:
result['back_perspective_url'] = f"{resp.bucket_name}/{resp.object_name}"
return result
def thicken_contours_and_display(self, mask, thickness=10, color=(0, 0, 0)):
@staticmethod
def thicken_contours_and_display(mask, thickness=10, color=(0, 0, 0)):
mask = mask.astype(np.uint8) * 255
# 查找轮廓
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
@@ -48,9 +49,9 @@ class BackPerspective:
# 在空白图像上绘制白色的轮廓
cv2.drawContours(blank, [contour], -1, 255, thickness=thick)
# 找到轮廓的中心(可以用重心等方法近似)
M = cv2.moments(contour)
cx = int(M['m10'] / M['m00'])
cy = int(M['m01'] / M['m00'])
m = cv2.moments(contour)
cx = int(m['m10'] / m['m00'])
cy = int(m['m01'] / m['m00'])
# 进行距离变换,离中心越近的值越小
dist_transform = cv2.distanceTransform(255 - blank, cv2.DIST_L2, 5)
# 根据距离变换的值来决定是否保留像素,离中心近的像素更容易被保留

View File

@@ -79,9 +79,9 @@ class Color:
def get_pattern(single_color):
if single_color is None:
raise False
R, G, B = single_color.split(' ')
r, g, b = single_color.split(' ')
pattern = np.zeros([1, 1, 3], np.uint8)
pattern[0, 0, 0] = int(B)
pattern[0, 0, 1] = int(G)
pattern[0, 0, 2] = int(R)
pattern[0, 0, 0] = int(b)
pattern[0, 0, 1] = int(g)
pattern[0, 0, 2] = int(r)
return pattern

View File

@@ -3,7 +3,7 @@ import logging
import numpy as np
from pymilvus import MilvusClient
from app.core.config import *
from app.core.config import KEYPOINT_RESULT_TABLE_FIELD_SET, MILVUS_TABLE_KEYPOINT, settings
from app.service.design_fast.utils.design_ensemble import get_keypoint_result
from app.service.utils.decorator import ClassCallRunTime, RunTime
@@ -21,12 +21,12 @@ class KeyPoint:
def __call__(self, result):
if result['name'] in ['blouse', 'skirt', 'dress', 'outwear', 'trousers', 'tops', 'bottoms']: # 查询是否有数据 且类别相同 相同则直接读 不同则推理后更新
# result['clothes_keypoint'] = self.infer_keypoint_result(result)
site = 'up' if result['name'] in ['blouse', 'outwear', 'dress', 'tops'] else 'down'
# 'up' if result['name'] in ['blouse', 'outwear', 'dress', 'tops'] else 'down'
# keypoint_cache = search_keypoint_cache(result["image_id"], site)
# keypoint_cache = self.keypoint_cache(result, site)
keypoint_cache = False
# 取消向量查询 直接过模型推理
if keypoint_cache is False:
if not keypoint_cache:
keypoint_infer_result, site = self.infer_keypoint_result(result)
result['clothes_keypoint'] = self.save_keypoint_cache(result["image_id"], keypoint_infer_result, site)
else:
@@ -55,8 +55,8 @@ class KeyPoint:
}
]
try:
client = MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS)
res = client.upsert(collection_name=MILVUS_TABLE_KEYPOINT, data=data)
client = MilvusClient(uri=settings.MILVUS_URL, token=settings.MILVUS_TOKEN, db_name=settings.MILVUS_ALIAS)
client.upsert(collection_name=MILVUS_TABLE_KEYPOINT, data=data)
client.close()
return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
except Exception as e:
@@ -79,7 +79,7 @@ class KeyPoint:
]
try:
client = MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS)
client = MilvusClient(uri=settings.MILVUS_URL, token=settings.MILVUS_TOKEN, db_name=settings.MILVUS_ALIAS)
client.upsert(
collection_name=MILVUS_TABLE_KEYPOINT,
data=data
@@ -92,7 +92,7 @@ class KeyPoint:
@RunTime
def keypoint_cache(self, result, site):
try:
client = MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS)
client = MilvusClient(uri=settings.MILVUS_URL, token=settings.MILVUS_TOKEN, db_name=settings.MILVUS_ALIAS)
keypoint_id = result['image_id']
res = client.query(
collection_name=MILVUS_TABLE_KEYPOINT,

View File

@@ -1,9 +1,6 @@
import io
import logging
import cv2
import numpy as np
from PIL import Image
from app.service.utils.new_oss_client import oss_get_image

View File

@@ -9,6 +9,7 @@ from app.service.utils.new_oss_client import oss_get_image
class PrintPainting:
def __init__(self, minio_client):
self.random_seed = None
self.minio_client = minio_client
def __call__(self, result):
@@ -408,7 +409,7 @@ class PrintPainting:
change_mask = print_mask[start_h: length_h, start_w: length_w]
# get real part into change mask
_, change_mask = cv2.threshold(change_mask, 220, 255, cv2.THRESH_BINARY)
mask = cv2.bitwise_not(painting_dict['mask_inv_print'])
cv2.bitwise_not(painting_dict['mask_inv_print'])
img_fg[start_h:start_h + painting_dict['dim_print_h'], start_w:start_w + painting_dict['dim_print_w'], :] = change_region
clothes_mask_print = cv2.bitwise_not(print_mask)

View File

@@ -4,7 +4,7 @@ import os
import cv2
import numpy as np
from app.core.config import SEG_CACHE_PATH
from app.core.config import settings
from app.service.design_fast.utils.design_ensemble import get_seg_result
from app.service.utils.decorator import ClassCallRunTime
from app.service.utils.new_oss_client import oss_get_image
@@ -36,11 +36,11 @@ class Segmentation:
# preview 过模型 不缓存
if "preview_submit" in result.keys() and result['preview_submit'] == "preview":
# 推理获得seg 结果
seg_result = get_seg_result(result["image_id"], result['image'])
seg_result = get_seg_result(result['image'])
# submit 过模型 缓存
elif "preview_submit" in result.keys() and result['preview_submit'] == "submit":
# 推理获得seg 结果
seg_result = get_seg_result(result["image_id"], result['image'])
seg_result = get_seg_result(result['image'])
self.save_seg_result(seg_result, result['image_id'])
# null 正常流程 加载本地缓存 无缓存则过模型
else:
@@ -49,7 +49,7 @@ class Segmentation:
# 判断缓存和实际图片size是否相同
if not _ or result["image"].shape[:2] != seg_result.shape:
# 推理获得seg 结果
seg_result = get_seg_result(result["image_id"], result['image'])
seg_result = get_seg_result(result['image'])
self.save_seg_result(seg_result, result['image_id'])
result['seg_result'] = seg_result
@@ -63,7 +63,7 @@ class Segmentation:
@staticmethod
def save_seg_result(seg_result, image_id):
file_path = f"{SEG_CACHE_PATH}{image_id}.npy"
file_path = f"{settings.SEG_CACHE_PATH}{image_id}.npy"
try:
np.save(file_path, seg_result)
logger.debug(f"保存成功 {os.path.abspath(file_path)}")
@@ -72,7 +72,7 @@ class Segmentation:
@staticmethod
def load_seg_result(image_id):
file_path = f"{SEG_CACHE_PATH}{image_id}.npy"
file_path = f"{settings.SEG_CACHE_PATH}{image_id}.npy"
# logger.info(f"load seg file name is :{SEG_CACHE_PATH}{image_id}.npy")
try:
seg_result = np.load(file_path)

View File

@@ -4,9 +4,7 @@ import logging
import cv2
import numpy as np
from PIL import Image
from cv2 import cvtColor, COLOR_BGR2RGBA
from app.core.config import AIDA_CLOTHING
from app.service.design_fast.utils.conversion_image import rgb_to_rgba
from app.service.design_fast.utils.transparent import sketch_to_transparent
from app.service.design_fast.utils.upload_image import upload_png_mask
@@ -40,7 +38,7 @@ class Split(object):
result_front_image = np.zeros_like(rgba_image)
front_mask = cv2.resize(front_mask, new_size)
result_front_image[front_mask != 0] = rgba_image[front_mask != 0]
result_front_image_pil = Image.fromarray(cvtColor(result_front_image, COLOR_BGR2RGBA))
result_front_image_pil = Image.fromarray(cv2.cvtColor(result_front_image, cv2.COLOR_BGR2RGBA))
if 'transparent' in result.keys():
# 用户自选区域transparent
transparent = result['transparent']
@@ -98,21 +96,21 @@ class Split(object):
result_back_image = np.zeros_like(rgba_image)
back_mask = cv2.resize(back_mask, new_size)
result_back_image[back_mask != 0] = rgba_image[back_mask != 0]
result_back_image_pil = Image.fromarray(cvtColor(result_back_image, COLOR_BGR2RGBA))
result_back_image_pil = Image.fromarray(cv2.cvtColor(result_back_image, cv2.COLOR_BGR2RGBA))
result['back_image'], result["back_image_url"], _ = upload_png_mask(self.minio_client, result_back_image_pil, f'{generate_uuid()}', mask=None)
mask_image[back_mask != 0] = [0, 255, 0]
rbga_mask = rgb_to_rgba(mask_image, front_mask + back_mask)
mask_pil = Image.fromarray(cvtColor(rbga_mask.astype(np.uint8), COLOR_BGR2RGBA))
mask_pil = Image.fromarray(cv2.cvtColor(rbga_mask.astype(np.uint8), cv2.COLOR_BGR2RGBA))
image_data = io.BytesIO()
mask_pil.save(image_data, format='PNG')
image_data.seek(0)
image_bytes = image_data.read()
req = oss_upload_image(oss_client=self.minio_client, bucket=AIDA_CLOTHING, object_name=f"mask/mask_{generate_uuid()}.png", image_bytes=image_bytes)
req = oss_upload_image(oss_client=self.minio_client, bucket="aida-clothing", object_name=f"mask/mask_{generate_uuid()}.png", image_bytes=image_bytes)
result['mask_url'] = req.bucket_name + "/" + req.object_name
# 创建中间图层
result_pattern_image_rgba = rgb_to_rgba(result['pattern_image'], result['mask'])
result_pattern_image_pil = Image.fromarray(cvtColor(result_pattern_image_rgba, COLOR_BGR2RGBA))
result_pattern_image_pil = Image.fromarray(cv2.cvtColor(result_pattern_image_rgba, cv2.COLOR_BGR2RGBA))
result['pattern_image'], result['pattern_image_url'], _ = upload_png_mask(self.minio_client, result_pattern_image_pil, f'{generate_uuid()}')
return result
except Exception as e:

View File

@@ -2,16 +2,17 @@ import json
import pika
from app.core.config import RABBITMQ_PARAMS, BATCH_DESIGN_RABBITMQ_QUEUES
from app.core.config import settings
from app.core.rabbit_mq_config import RABBITMQ_PARAMS
def publish_status(task_id, progress, result):
connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
channel = connection.channel()
channel.queue_declare(queue=BATCH_DESIGN_RABBITMQ_QUEUES, durable=True)
channel.queue_declare(queue=settings.BATCH_DESIGN_RABBITMQ_QUEUES, durable=True)
message = {'task_id': task_id, 'progress': progress, "result": result}
channel.basic_publish(exchange='',
routing_key=BATCH_DESIGN_RABBITMQ_QUEUES,
routing_key=settings.BATCH_DESIGN_RABBITMQ_QUEUES,
body=json.dumps(message),
properties=pika.BasicProperties(
delivery_mode=2,

View File

@@ -16,7 +16,7 @@ import torch
import torch.nn.functional as F
import tritonclient.http as httpclient
from app.core.config import *
from app.core.config import DESIGN_MODEL_URL, DESIGN_MODEL_NAME
"""
keypoint
@@ -91,29 +91,29 @@ def seg_preprocess(img_path):
# @ RunTime
def get_seg_result(image_id, image):
def get_seg_result(image):
image, ori_shape = seg_preprocess(image)
client = httpclient.InferenceServerClient(url=f"{DESIGN_MODEL_URL}")
transformed_img = image.astype(np.float32)
# 输入集
inputs = [
httpclient.InferInput(SEGMENTATION['input'], transformed_img.shape, datatype="FP32")
httpclient.InferInput(DESIGN_MODEL_NAME, transformed_img.shape, datatype="FP32")
]
inputs[0].set_data_from_numpy(transformed_img, binary_data=True)
# 输出集
outputs = [
httpclient.InferRequestedOutput(SEGMENTATION['output'], binary_data=True),
httpclient.InferRequestedOutput("seg_input__0", binary_data=True),
]
results = client.infer(model_name=SEGMENTATION['new_model_name'], inputs=inputs, outputs=outputs)
results = client.infer(model_name=DESIGN_MODEL_NAME, inputs=inputs, outputs=outputs)
# 推理
# 取结果
inference_output1 = results.as_numpy(SEGMENTATION['output'])
seg_result = seg_postprocess(int(image_id), inference_output1, ori_shape)
inference_output1 = results.as_numpy("seg_input__0")
seg_result = seg_postprocess(inference_output1, ori_shape)
return seg_result
# no cache
def seg_postprocess(image_id, output, ori_shape):
def seg_postprocess(output, ori_shape):
seg_logit = F.interpolate(torch.tensor(output).float(), size=ori_shape, scale_factor=None, mode='bilinear', align_corners=False)
seg_pred = seg_logit.cpu().numpy()
return seg_pred[0]

View File

@@ -98,6 +98,8 @@ def calculate_start_point(keypoint_type, scale, clothes_point, body_point, offse
"""
Align left
Args:
offset:
resize_scale:
keypoint_type: string, "waistband" | "shoulder" | "ear_point"
scale: float
clothes_point: dict{'left': [x1, y1, z1], 'right': [x2, y2, z2]}

View File

@@ -1,6 +1,6 @@
import logging
from app.service.design_fast.utils.redis_utils import Redis
from app.service.utils.redis_utils import Redis
logger = logging.getLogger(__name__)

View File

@@ -1,99 +0,0 @@
import redis
from app.core.config import REDIS_HOST, REDIS_PORT
class Redis(object):
"""
redis数据库操作
"""
@staticmethod
def _get_r():
host = REDIS_HOST
port = REDIS_PORT
db = 0
r = redis.StrictRedis(host, port, db)
return r
@classmethod
def write(cls, key, value, expire=None):
"""
写入键值对
"""
# 判断是否有过期时间,没有就设置默认值
if expire:
expire_in_seconds = expire
else:
expire_in_seconds = 100
r = cls._get_r()
r.set(key, value, ex=expire_in_seconds)
@classmethod
def read(cls, key):
"""
读取键值对内容
"""
r = cls._get_r()
value = r.get(key)
return value.decode('utf-8') if value else value
@classmethod
def hset(cls, name, key, value):
"""
写入hash表
"""
r = cls._get_r()
r.hset(name, key, value)
@classmethod
def hget(cls, name, key):
"""
读取指定hash表的键值
"""
r = cls._get_r()
value = r.hget(name, key)
return value.decode('utf-8') if value else value
@classmethod
def hgetall(cls, name):
"""
获取指定hash表所有的值
"""
r = cls._get_r()
return r.hgetall(name)
@classmethod
def delete(cls, *names):
"""
删除一个或者多个
"""
r = cls._get_r()
r.delete(*names)
@classmethod
def hdel(cls, name, key):
"""
删除指定hash表的键值
"""
r = cls._get_r()
r.hdel(name, key)
@classmethod
def expire(cls, name, expire=None):
"""
设置过期时间
"""
if expire:
expire_in_seconds = expire
else:
expire_in_seconds = 100
r = cls._get_r()
r.expire(name, expire_in_seconds)
if __name__ == '__main__':
redis_client = Redis()
# print(redis_client.write(key="1230", value=0))
redis_client.write(key="1230", value=10)
# print(redis_client.read(key="1230"))

View File

@@ -13,9 +13,12 @@ import logging
import cv2
import numpy as np
from PIL import Image
from minio import Minio
from app.core.config import settings
from app.service.utils.generate_uuid import generate_uuid
from app.service.utils.oss_client import oss_upload_image
from app.service.utils.new_oss_client import oss_upload_image
minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)
def positioning(all_mask_shape, mask_shape, offset):
@@ -136,7 +139,7 @@ def synthesis(data, size, basic_info):
image_bytes = image_data.read()
bucket_name = "aida-results"
object_name = f'result_{generate_uuid()}.png'
req = oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes)
oss_upload_image(oss_client=minio_client, bucket=bucket_name, object_name=object_name, image_bytes=image_bytes)
return f"{bucket_name}/{object_name}"
# return f"aida-results/{minio_client.put_object('aida-results', f'result_{generate_uuid()}.png', io.BytesIO(image_bytes), len(image_bytes), content_type='image/png').object_name}"
@@ -177,11 +180,11 @@ def synthesis_single(front_image, back_image):
# oss upload
bucket_name = 'aida-results'
object_name = f'result_{generate_uuid()}.png'
req = oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes)
oss_upload_image(oss_client=minio_client, bucket=bucket_name, object_name=object_name, image_bytes=image_bytes)
return f"{bucket_name}/{object_name}"
def update_base_size_priority(layers, size):
def update_base_size_priority(layers):
# 计算透明背景图片的宽度
min_x = min(info['position'][1] for info in layers)
x_list = []

View File

@@ -12,7 +12,6 @@ import logging
import cv2
from app.core.config import *
from app.service.utils.new_oss_client import oss_upload_image
@@ -25,15 +24,15 @@ def upload_png_mask(minio_client, front_image, object_name, mask=None):
# 将掩模的3通道转换为4通道白色部分不透明黑色部分透明
rgba_image = cv2.cvtColor(mask_inverted, cv2.COLOR_BGR2BGRA)
rgba_image[rgba_image[:, :, 0] == 0] = [0, 0, 0, 0]
req = oss_upload_image(oss_client=minio_client, bucket=AIDA_CLOTHING, object_name=f"mask/mask_{object_name}.png", image_bytes=cv2.imencode('.png', rgba_image)[1])
mask_url = f"{AIDA_CLOTHING}/mask/mask_{object_name}.png"
req = oss_upload_image(oss_client=minio_client, bucket="aida-clothing", object_name=f"mask/mask_{object_name}.png", image_bytes=cv2.imencode('.png', rgba_image)[1])
mask_url = f"aida-clothing/mask/mask_{object_name}.png"
image_data = io.BytesIO()
front_image.save(image_data, format='PNG')
image_data.seek(0)
image_bytes = image_data.read()
req = oss_upload_image(oss_client=minio_client, bucket=AIDA_CLOTHING, object_name=f"image/image_{object_name}.png", image_bytes=image_bytes)
image_url = f"{AIDA_CLOTHING}/image/image_{object_name}.png"
req = oss_upload_image(oss_client=minio_client, bucket="aida-clothing", object_name=f"image/image_{object_name}.png", image_bytes=image_bytes)
image_url = f"aida-clothing/image/image_{object_name}.png"
return front_image, image_url, mask_url
except Exception as e:
logging.warning(f"upload_png_mask runtime exception : {e}")

View File

@@ -5,7 +5,7 @@ import time
import requests
from minio import Minio
from app.core.config import *
from app.core.config import settings
from app.service.design_fast.item import BodyItem, TopItem, BottomItem, OthersItem
from app.service.design_fast.utils.organize import organize_body, organize_clothing, organize_others
from app.service.design_fast.utils.progress import final_progress, update_progress
@@ -16,7 +16,7 @@ id_lock = threading.Lock()
logger = logging.getLogger()
minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)
def process_item(item, basic):
@@ -48,10 +48,12 @@ def process_layer(item, layers):
front_layer, back_layer = organize_others(item)
layers.append(front_layer)
layers.append(back_layer)
return None
else:
front_layer, back_layer = organize_clothing(item)
layers.append(front_layer)
layers.append(back_layer)
return None
@RunTime
@@ -73,12 +75,11 @@ def design_generate(request_data):
for item in object['items']:
item_results.append(process_item(item, basic))
layers = []
body_size = None
for item in item_results:
body_size = process_layer(item, layers)
process_layer(item, layers)
layers = sorted(layers, key=lambda s: s.get("priority", float('inf')))
layers, new_size = update_base_size_priority(layers, body_size)
layers, new_size = update_base_size_priority(layers)
# pattern_overall_image_url 、 pattern_print_image_url
for lay in layers:
items_response['layers'].append({
@@ -149,7 +150,7 @@ def design_generate_v2(request_data):
request_id = request_data.requestId
threads = []
def process_object(step, object, callback_url):
def process_object(object, callback_url):
basic = object['basic']
items_response = {
'layers': [],
@@ -161,12 +162,11 @@ def design_generate_v2(request_data):
for item in object['items']:
item_results.append(process_item(item, basic))
layers = []
body_size = None
for item in item_results:
body_size = process_layer(item, layers)
process_layer(item, layers)
layers = sorted(layers, key=lambda s: s.get("priority", float('inf')))
layers, new_size = update_base_size_priority(layers, body_size)
layers, new_size = update_base_size_priority(layers)
for lay in layers:
items_response['layers'].append({
@@ -229,7 +229,7 @@ def design_generate_v2(request_data):
logger.info(response.text)
for step, object in enumerate(objects_data):
t = threading.Thread(target=process_object, args=(step, object, callback_url))
t = threading.Thread(target=process_object, args=(object, callback_url))
threads.append(t)
t.start()

View File

@@ -1,13 +1,18 @@
import io
from app.service.utils.oss_client import oss_get_image, oss_upload_image
from minio import Minio
from app.core.config import settings
from app.service.utils.new_oss_client import oss_get_image, oss_upload_image
minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)
def model_transpose(image_path):
bucket = image_path.split("/", 1)[0]
object_name = image_path.split("/", 1)[1]
new_object_name = f'{object_name[:object_name.rfind(".")]}.png'
image = oss_get_image(bucket=bucket, object_name=object_name, data_type="PIL")
image = oss_get_image(oss_client=minio_client, bucket=bucket, object_name=object_name, data_type="PIL")
image = image.convert("RGBA")
data = image.getdata()
#
@@ -23,6 +28,6 @@ def model_transpose(image_path):
image.save(image_data, format='PNG')
image_data.seek(0)
image_bytes = image_data.read()
oss_upload_image(bucket=bucket, object_name=new_object_name, image_bytes=image_bytes)
oss_upload_image(oss_client=minio_client, bucket=bucket, object_name=new_object_name, image_bytes=image_bytes)
image_path = f"{bucket}/{new_object_name}"
return image_path

View File

@@ -18,11 +18,11 @@ class BackPerspective:
result['back_perspective_url'] = file_path
return result
else:
seg_result = get_seg_result("1", result['image'])[0]
seg_result = get_seg_result(result['image'])[0]
elif result['name'] in ['blouse', 'outwear', 'dress', 'tops']:
seg_result = result['seg_result']
else:
seg_result = get_seg_result("1", result['image'])[0]
seg_result = get_seg_result(result['image'])[0]
m = self.thicken_contours_and_display(seg_result, thickness=10, color=(0, 0, 0))
back_sketch = result['image'].copy()
@@ -34,7 +34,8 @@ class BackPerspective:
result['back_perspective_url'] = f"{resp.bucket_name}/{resp.object_name}"
return result
def thicken_contours_and_display(self, mask, thickness=10, color=(0, 0, 0)):
@staticmethod
def thicken_contours_and_display(mask, thickness=10, color=(0, 0, 0)):
mask = mask.astype(np.uint8) * 255
# 查找轮廓
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
@@ -48,9 +49,9 @@ class BackPerspective:
# 在空白图像上绘制白色的轮廓
cv2.drawContours(blank, [contour], -1, 255, thickness=thick)
# 找到轮廓的中心(可以用重心等方法近似)
M = cv2.moments(contour)
cx = int(M['m10'] / M['m00'])
cy = int(M['m01'] / M['m00'])
m = cv2.moments(contour)
# cx = int(m['m10'] / m['m00'])
# cy = int(m['m01'] / m['m00'])
# 进行距离变换,离中心越近的值越小
dist_transform = cv2.distanceTransform(255 - blank, cv2.DIST_L2, 5)
# 根据距离变换的值来决定是否保留像素,离中心近的像素更容易被保留

View File

@@ -81,9 +81,9 @@ class Color:
def get_pattern(single_color):
if single_color is None:
raise False
R, G, B = single_color.split(' ')
r, g, b = single_color.split(' ')
pattern = np.zeros([1, 1, 3], np.uint8)
pattern[0, 0, 0] = int(B)
pattern[0, 0, 1] = int(G)
pattern[0, 0, 2] = int(R)
pattern[0, 0, 0] = int(b)
pattern[0, 0, 1] = int(g)
pattern[0, 0, 2] = int(r)
return pattern

View File

@@ -3,7 +3,7 @@ import logging
import numpy as np
from pymilvus import MilvusClient
from app.core.config import *
from app.core.config import KEYPOINT_RESULT_TABLE_FIELD_SET, MILVUS_TABLE_KEYPOINT, settings
from app.service.design_fast.utils.design_ensemble import get_keypoint_result
from app.service.utils.decorator import ClassCallRunTime, RunTime
@@ -21,12 +21,12 @@ class KeyPoint:
def __call__(self, result):
if result['name'] in ['blouse', 'skirt', 'dress', 'outwear', 'trousers', 'tops', 'bottoms']: # 查询是否有数据 且类别相同 相同则直接读 不同则推理后更新
# result['clothes_keypoint'] = self.infer_keypoint_result(result)
site = 'up' if result['name'] in ['blouse', 'outwear', 'dress', 'tops'] else 'down'
# 'up' if result['name'] in ['blouse', 'outwear', 'dress', 'tops'] else 'down'
# keypoint_cache = search_keypoint_cache(result["image_id"], site)
# keypoint_cache = self.keypoint_cache(result, site)
keypoint_cache = False
# 取消向量查询 直接过模型推理
if keypoint_cache is False:
if not keypoint_cache:
keypoint_infer_result, site = self.infer_keypoint_result(result)
result['clothes_keypoint'] = self.save_keypoint_cache(result["image_id"], keypoint_infer_result, site)
else:
@@ -55,8 +55,8 @@ class KeyPoint:
}
]
try:
client = MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS)
res = client.upsert(collection_name=MILVUS_TABLE_KEYPOINT, data=data)
client = MilvusClient(uri=settings.MILVUS_URL, token=settings.MILVUS_TOKEN, db_name=settings.MILVUS_ALIAS)
client.upsert(collection_name=MILVUS_TABLE_KEYPOINT, data=data)
client.close()
return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
except Exception as e:
@@ -79,7 +79,7 @@ class KeyPoint:
]
try:
client = MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS)
client = MilvusClient(uri=settings.MILVUS_URL, token=settings.MILVUS_TOKEN, db_name=settings.MILVUS_ALIAS)
client.upsert(
collection_name=MILVUS_TABLE_KEYPOINT,
data=data
@@ -92,7 +92,7 @@ class KeyPoint:
@RunTime
def keypoint_cache(self, result, site):
try:
client = MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS)
client = MilvusClient(uri=settings.MILVUS_URL, token=settings.MILVUS_TOKEN, db_name=settings.MILVUS_ALIAS)
keypoint_id = result['image_id']
res = client.query(
collection_name=MILVUS_TABLE_KEYPOINT,

View File

@@ -1,10 +1,7 @@
import io
import logging
import os
from skimage.morphology import skeletonize
import cv2
import numpy as np
from PIL import Image
from app.service.utils.new_oss_client import oss_get_image
@@ -47,26 +44,32 @@ class LoadImage:
# else:
# result['gray'] = cv2.cvtColor(result['image'], cv2.COLOR_BGR2GRAY)
result['gray'] = self.get_lines(cv2.cvtColor(result['image'], cv2.COLOR_BGR2GRAY), result['path'])
result['gray'] = self.get_lines(cv2.cvtColor(result['image'], cv2.COLOR_BGR2GRAY))
result['keypoint'] = self.get_keypoint(result['name'])
result['img_shape'] = result['image'].shape
result['ori_shape'] = result['image'].shape
return result
def get_lines(self, img, path):
@staticmethod
def get_lines(img):
binary = cv2.adaptiveThreshold(img, 255,
cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
cv2.THRESH_BINARY_INV,
25, 10)
binary_bool = binary > 0
skeleton = skeletonize(binary_bool, method='zhang')
mask = skeleton
result = np.ones_like(img) * 255
result[mask] = img[mask]
# 步骤2细化边缘可选让线条更干净
# kernel = np.ones((1, 1), np.uint8)
# clean = cv2.morphologyEx(binary, cv2.MORPH_OPEN, kernel)
thinned = cv2.ximgproc.thinning(binary, thinningType=cv2.ximgproc.THINNING_ZHANGSUEN) # thinning算法细化线条
mask = thinned > 0
result = np.ones_like(img) * 255
result[mask] = img[mask]
# thinned = cv2.ximgproc.thinning(binary, thinningType=cv2.ximgproc.THINNING_ZHANGSUEN) # thinning算法细化线条
# mask = thinned > 0
# result = np.ones_like(img) * 255
# result[mask] = img[mask]
# 步骤3反转回 白底黑线
# lines = cv2.bitwise_not(thinned)

View File

@@ -9,6 +9,7 @@ from app.service.utils.new_oss_client import oss_get_image
class NoSegPrintPainting:
def __init__(self, minio_client):
self.random_seed = random.randint(0, 1000)
self.minio_client = minio_client
def __call__(self, result):
@@ -174,7 +175,6 @@ class NoSegPrintPainting:
dim_max = max(painting_dict['dim_image_h'], painting_dict['dim_image_w'])
dim_pattern = (int(dim_max * print_['scale'] / 5), int(dim_max * print_['scale'] / 5))
if not is_single:
self.random_seed = random.randint(0, 1000)
# 如果print 模式为overall 且 有角度的话 组合的print为正方形方便裁剪
if "print_angle_list" in print_dict.keys() and print_dict['print_angle_list'][0] != 0:
painting_dict['mask_inv_print'] = self.tile_image(single_mask_inv_print, dim_pattern, print_['scale'], dim_max, dim_max, painting_dict['location'], trigger=True)
@@ -244,7 +244,7 @@ class NoSegPrintPainting:
change_mask = print_mask[start_h: length_h, start_w: length_w]
# get real part into change mask
_, change_mask = cv2.threshold(change_mask, 220, 255, cv2.THRESH_BINARY)
mask = cv2.bitwise_not(painting_dict['mask_inv_print'])
cv2.bitwise_not(painting_dict['mask_inv_print'])
img_fg[start_h:start_h + painting_dict['dim_print_h'], start_w:start_w + painting_dict['dim_print_w'], :] = change_region
clothes_mask_print = cv2.bitwise_not(print_mask)

View File

@@ -9,6 +9,7 @@ from app.service.utils.new_oss_client import oss_get_image
class PrintPainting:
def __init__(self, minio_client):
self.random_seed = None
self.minio_client = minio_client
def __call__(self, result):
@@ -416,7 +417,7 @@ class PrintPainting:
change_mask = print_mask[start_h: length_h, start_w: length_w]
# get real part into change mask
_, change_mask = cv2.threshold(change_mask, 220, 255, cv2.THRESH_BINARY)
mask = cv2.bitwise_not(painting_dict['mask_inv_print'])
cv2.bitwise_not(painting_dict['mask_inv_print'])
img_fg[start_h:start_h + painting_dict['dim_print_h'], start_w:start_w + painting_dict['dim_print_w'], :] = change_region
clothes_mask_print = cv2.bitwise_not(print_mask)

View File

@@ -4,7 +4,7 @@ import os
import cv2
import numpy as np
from app.core.config import SEG_CACHE_PATH
from app.core.config import settings
from app.service.design_fast.utils.design_ensemble import get_seg_result
from app.service.utils.decorator import ClassCallRunTime
from app.service.utils.new_oss_client import oss_get_image
@@ -36,11 +36,11 @@ class Segmentation:
# preview 过模型 不缓存
if "preview_submit" in result.keys() and result['preview_submit'] == "preview":
# 推理获得seg 结果
seg_result = get_seg_result(result["image_id"], result['image'])
seg_result = get_seg_result(result['image'])
# submit 过模型 缓存
elif "preview_submit" in result.keys() and result['preview_submit'] == "submit":
# 推理获得seg 结果
seg_result = get_seg_result(result["image_id"], result['image'])
seg_result = get_seg_result(result['image'])
self.save_seg_result(seg_result, result['image_id'])
# null 正常流程 加载本地缓存 无缓存则过模型
else:
@@ -49,7 +49,7 @@ class Segmentation:
# 判断缓存和实际图片size是否相同
if not _ or result["image"].shape[:2] != seg_result.shape:
# 推理获得seg 结果
seg_result = get_seg_result(result["image_id"], result['image'])
seg_result = get_seg_result(result['image'])
self.save_seg_result(seg_result, result['image_id'])
result['seg_result'] = seg_result
@@ -63,7 +63,7 @@ class Segmentation:
@staticmethod
def save_seg_result(seg_result, image_id):
file_path = f"{SEG_CACHE_PATH}{image_id}.npy"
file_path = f"{settings.SEG_CACHE_PATH}{image_id}.npy"
try:
np.save(file_path, seg_result)
logger.debug(f"保存成功 {os.path.abspath(file_path)}")
@@ -72,7 +72,7 @@ class Segmentation:
@staticmethod
def load_seg_result(image_id):
file_path = f"{SEG_CACHE_PATH}{image_id}.npy"
file_path = f"{settings.SEG_CACHE_PATH}{image_id}.npy"
# logger.info(f"load seg file name is :{SEG_CACHE_PATH}{image_id}.npy")
try:
seg_result = np.load(file_path)

View File

@@ -4,9 +4,7 @@ import logging
import cv2
import numpy as np
from PIL import Image
from cv2 import cvtColor, COLOR_BGR2RGBA
from app.core.config import AIDA_CLOTHING
from app.service.design_fast.utils.conversion_image import rgb_to_rgba
from app.service.design_fast.utils.transparent import sketch_to_transparent
from app.service.design_fast.utils.upload_image import upload_png_mask
@@ -41,7 +39,7 @@ class Split(object):
result_front_image = np.zeros_like(rgba_image)
front_mask = cv2.resize(front_mask, new_size, interpolation=cv2.INTER_AREA)
result_front_image[front_mask != 0] = rgba_image[front_mask != 0]
result_front_image_pil = Image.fromarray(cvtColor(result_front_image, COLOR_BGR2RGBA))
result_front_image_pil = Image.fromarray(cv2.cvtColor(result_front_image, cv2.COLOR_BGR2RGBA))
if 'transparent' in result.keys():
# 用户自选区域transparent
transparent = result['transparent']
@@ -106,26 +104,27 @@ class Split(object):
result_back_image = np.zeros_like(rgba_image)
back_mask = cv2.resize(back_mask, new_size, interpolation=cv2.INTER_AREA)
result_back_image[back_mask != 0] = rgba_image[back_mask != 0]
result_back_image_pil = Image.fromarray(cvtColor(result_back_image, COLOR_BGR2RGBA))
result_back_image_pil = Image.fromarray(cv2.cvtColor(result_back_image, cv2.COLOR_BGR2RGBA))
result['back_image'], result["back_image_url"], _ = upload_png_mask(self.minio_client, result_back_image_pil, f'{generate_uuid()}', mask=None)
# mask_image[back_mask != 0] = [0, 255, 0]
mask_image[ori_back_mask != 0] = [0, 255, 0]
rbga_mask = rgb_to_rgba(mask_image, ori_front_mask + ori_back_mask)
mask_pil = Image.fromarray(cvtColor(rbga_mask.astype(np.uint8), COLOR_BGR2RGBA))
mask_pil = Image.fromarray(cv2.cvtColor(rbga_mask.astype(np.uint8), cv2.COLOR_BGR2RGBA))
image_data = io.BytesIO()
mask_pil.save(image_data, format='PNG')
image_data.seek(0)
image_bytes = image_data.read()
req = oss_upload_image(oss_client=self.minio_client, bucket=AIDA_CLOTHING, object_name=f"mask/mask_{generate_uuid()}.png", image_bytes=image_bytes)
req = oss_upload_image(oss_client=self.minio_client, bucket="aida-clothing", object_name=f"mask/mask_{generate_uuid()}.png", image_bytes=image_bytes)
result['mask_url'] = req.bucket_name + "/" + req.object_name
else:
ori_front_mask, ori_back_mask = None, None
# 创建中间图层(未分割图层) 1.color + overall_print 2.color + overall_print + print
result_pattern_overall_image_pil = Image.fromarray(cvtColor(rgb_to_rgba(result['no_seg_sketch_overall'], ori_front_mask + ori_back_mask), COLOR_BGR2RGBA))
result_pattern_overall_image_pil = Image.fromarray(cv2.cvtColor(rgb_to_rgba(result['no_seg_sketch_overall'], ori_front_mask + ori_back_mask), cv2.COLOR_BGR2RGBA))
result['pattern_overall_image'], result['pattern_overall_image_url'], _ = upload_png_mask(self.minio_client, result_pattern_overall_image_pil, f'{generate_uuid()}')
result_pattern_print_image_pil = Image.fromarray(cvtColor(rgb_to_rgba(result['no_seg_sketch_print'], ori_front_mask + ori_back_mask), COLOR_BGR2RGBA))
result_pattern_print_image_pil = Image.fromarray(cv2.cvtColor(rgb_to_rgba(result['no_seg_sketch_print'], ori_front_mask + ori_back_mask), cv2.COLOR_BGR2RGBA))
result['pattern_print_image'], result['pattern_print_image_url'], _ = upload_png_mask(self.minio_client, result_pattern_print_image_pil, f'{generate_uuid()}')
return result
except Exception as e:

View File

@@ -15,7 +15,7 @@ import numpy as np
import torch
import tritonclient.http as httpclient
from app.core.config import *
from app.core.config import DESIGN_MODEL_URL, DESIGN_MODEL_NAME
"""
keypoint
@@ -98,29 +98,29 @@ def seg_preprocess(img_path):
# @ RunTime
def get_seg_result(image_id, image):
def get_seg_result(image):
image, ori_shape = seg_preprocess(image)
client = httpclient.InferenceServerClient(url=f"{DESIGN_MODEL_URL}")
client = httpclient.InferenceServerClient(url=DESIGN_MODEL_URL)
transformed_img = image.astype(np.float32)
# 输入集
inputs = [
httpclient.InferInput(SEGMENTATION['input'], transformed_img.shape, datatype="FP32")
httpclient.InferInput("seg_input__0", transformed_img.shape, datatype="FP32")
]
inputs[0].set_data_from_numpy(transformed_img, binary_data=True)
# 输出集
outputs = [
httpclient.InferRequestedOutput(SEGMENTATION['output'], binary_data=True),
httpclient.InferRequestedOutput("seg_output__0", binary_data=True),
]
results = client.infer(model_name=SEGMENTATION['new_model_name'], inputs=inputs, outputs=outputs)
results = client.infer(model_name=DESIGN_MODEL_NAME, inputs=inputs, outputs=outputs)
# 推理
# 取结果
inference_output1 = results.as_numpy(SEGMENTATION['output'])
seg_result = seg_postprocess(int(image_id), inference_output1, ori_shape)
inference_output1 = results.as_numpy("seg_output__0")
seg_result = seg_postprocess(inference_output1, ori_shape)
return seg_result
# no cache
def seg_postprocess(image_id, output, ori_shape):
def seg_postprocess(output, ori_shape):
seg_logit = cv2.resize(output[0][0].astype(np.uint8), (ori_shape[1] + 50, ori_shape[0] + 50))
seg_logit = seg_logit[25: - 25, 25: - 25]
return seg_logit

View File

@@ -112,6 +112,8 @@ def calculate_start_point(keypoint_type, scale, clothes_point, body_point, offse
"""
Align left
Args:
offset:
resize_scale:
keypoint_type: string, "waistband" | "shoulder" | "ear_point"
scale: float
clothes_point: dict{'left': [x1, y1, z1], 'right': [x2, y2, z2]}

View File

@@ -1,6 +1,6 @@
import logging
from app.service.design_fast.utils.redis_utils import Redis
from app.service.utils.redis_utils import Redis
logger = logging.getLogger(__name__)

View File

@@ -1,99 +0,0 @@
import redis
from app.core.config import REDIS_HOST, REDIS_PORT
class Redis(object):
"""
redis数据库操作
"""
@staticmethod
def _get_r():
host = REDIS_HOST
port = REDIS_PORT
db = 0
r = redis.StrictRedis(host, port, db)
return r
@classmethod
def write(cls, key, value, expire=None):
"""
写入键值对
"""
# 判断是否有过期时间,没有就设置默认值
if expire:
expire_in_seconds = expire
else:
expire_in_seconds = 100
r = cls._get_r()
r.set(key, value, ex=expire_in_seconds)
@classmethod
def read(cls, key):
"""
读取键值对内容
"""
r = cls._get_r()
value = r.get(key)
return value.decode('utf-8') if value else value
@classmethod
def hset(cls, name, key, value):
"""
写入hash表
"""
r = cls._get_r()
r.hset(name, key, value)
@classmethod
def hget(cls, name, key):
"""
读取指定hash表的键值
"""
r = cls._get_r()
value = r.hget(name, key)
return value.decode('utf-8') if value else value
@classmethod
def hgetall(cls, name):
"""
获取指定hash表所有的值
"""
r = cls._get_r()
return r.hgetall(name)
@classmethod
def delete(cls, *names):
"""
删除一个或者多个
"""
r = cls._get_r()
r.delete(*names)
@classmethod
def hdel(cls, name, key):
"""
删除指定hash表的键值
"""
r = cls._get_r()
r.hdel(name, key)
@classmethod
def expire(cls, name, expire=None):
"""
设置过期时间
"""
if expire:
expire_in_seconds = expire
else:
expire_in_seconds = 100
r = cls._get_r()
r.expire(name, expire_in_seconds)
if __name__ == '__main__':
redis_client = Redis()
# print(redis_client.write(key="1230", value=0))
redis_client.write(key="1230", value=10)
# print(redis_client.read(key="1230"))

Some files were not shown because too many files have changed in this diff Show More