Merge branch 'local' into develop
# Conflicts: # app/core/config.py
This commit is contained in:
@@ -2,9 +2,10 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
from fastapi import APIRouter, HTTPException
|
from fastapi import APIRouter, HTTPException
|
||||||
|
|
||||||
|
from app.core.config import DEBUG
|
||||||
from app.schemas.attribute_retrieve import *
|
from app.schemas.attribute_retrieve import *
|
||||||
from app.schemas.response_template import ResponseModel
|
from app.schemas.response_template import ResponseModel
|
||||||
from app.service.attribute.config import const
|
from app.service.attribute.config import const, local_debug_const
|
||||||
from app.service.attribute.service_att_recognition import AttributeRecognition
|
from app.service.attribute.service_att_recognition import AttributeRecognition
|
||||||
from app.service.attribute.service_category_recognition import CategoryRecognition
|
from app.service.attribute.service_category_recognition import CategoryRecognition
|
||||||
|
|
||||||
@@ -17,13 +18,16 @@ logger = logging.getLogger()
|
|||||||
def attribute_recognition(request_item: list[AttributeRecognitionModel]):
|
def attribute_recognition(request_item: list[AttributeRecognitionModel]):
|
||||||
try:
|
try:
|
||||||
logger.info(f"attribute_recognition request item is : @@@@@@:{request_item}")
|
logger.info(f"attribute_recognition request item is : @@@@@@:{request_item}")
|
||||||
service = AttributeRecognition(const=const, request_data=request_item)
|
if DEBUG:
|
||||||
|
service = AttributeRecognition(const=local_debug_const, request_data=request_item)
|
||||||
|
else:
|
||||||
|
service = AttributeRecognition(const=const, request_data=request_item)
|
||||||
data = service.get_result()
|
data = service.get_result()
|
||||||
logger.info(f"attribute_recognition response @@@@@@:{json.dumps(data, indent=4)}")
|
logger.info(f"attribute_recognition response @@@@@@:{json.dumps(data, indent=4)}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"attribute_recognition Run Exception @@@@@@:{e}")
|
logger.warning(f"attribute_recognition Run Exception @@@@@@:{e}")
|
||||||
raise HTTPException(status_code=404, detail=str(e))
|
raise HTTPException(status_code=404, detail=str(e))
|
||||||
return ResponseModel(data=data)
|
return ResponseModel(data={"list": data})
|
||||||
|
|
||||||
|
|
||||||
# 类别识别
|
# 类别识别
|
||||||
|
|||||||
@@ -4,9 +4,10 @@ import time
|
|||||||
|
|
||||||
from fastapi import APIRouter, HTTPException
|
from fastapi import APIRouter, HTTPException
|
||||||
|
|
||||||
from app.schemas.design import DesignModel
|
from app.schemas.design import DesignModel, DesignProgressModel
|
||||||
from app.schemas.response_template import ResponseModel
|
from app.schemas.response_template import ResponseModel
|
||||||
from app.service.design.service import generate
|
from app.service.design.service import generate
|
||||||
|
from app.service.design.utils.redis_utils import Redis
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
logger = logging.getLogger()
|
logger = logging.getLogger()
|
||||||
@@ -22,3 +23,19 @@ def design(request_data: DesignModel):
|
|||||||
logger.warning(f"design Run Exception @@@@@@:{e}")
|
logger.warning(f"design Run Exception @@@@@@:{e}")
|
||||||
raise HTTPException(status_code=404, detail=str(e))
|
raise HTTPException(status_code=404, detail=str(e))
|
||||||
return ResponseModel(data=data)
|
return ResponseModel(data=data)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post('/get_progress')
|
||||||
|
def get_progress(request_data: DesignProgressModel):
|
||||||
|
try:
|
||||||
|
logger.info(f"get_progress request item is : @@@@@@:{request_data}")
|
||||||
|
process_id = request_data.process_id
|
||||||
|
r = Redis()
|
||||||
|
data = r.read(key=process_id)
|
||||||
|
if data is None:
|
||||||
|
raise ValueError("The progress must be numbers ")
|
||||||
|
logging.info(f"get_progress process_id @@@@@@ : {process_id} , progress : {data}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"get_progress Run Exception @@@@@@:{e}")
|
||||||
|
raise HTTPException(status_code=404, detail=str(e))
|
||||||
|
return ResponseModel(data=data)
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import logging
|
import logging
|
||||||
from fastapi import APIRouter
|
from fastapi import APIRouter
|
||||||
from app.core.config import SR_RABBITMQ_QUEUES, GI_RABBITMQ_QUEUES, GPI_RABBITMQ_QUEUES
|
from app.core.config import SR_RABBITMQ_QUEUES, GI_RABBITMQ_QUEUES, GPI_RABBITMQ_QUEUES, GRI_RABBITMQ_QUEUES, OSS
|
||||||
from fastapi import FastAPI, HTTPException
|
from fastapi import FastAPI, HTTPException
|
||||||
|
|
||||||
from app.schemas.response_template import ResponseModel
|
from app.schemas.response_template import ResponseModel
|
||||||
@@ -15,6 +15,8 @@ def test(id: int):
|
|||||||
"SR_RABBITMQ_QUEUES message": SR_RABBITMQ_QUEUES,
|
"SR_RABBITMQ_QUEUES message": SR_RABBITMQ_QUEUES,
|
||||||
"GI_RABBITMQ_QUEUES": GI_RABBITMQ_QUEUES,
|
"GI_RABBITMQ_QUEUES": GI_RABBITMQ_QUEUES,
|
||||||
"GPI_RABBITMQ_QUEUES": GPI_RABBITMQ_QUEUES,
|
"GPI_RABBITMQ_QUEUES": GPI_RABBITMQ_QUEUES,
|
||||||
|
"GRI_RABBITMQ_QUEUES": GRI_RABBITMQ_QUEUES,
|
||||||
|
"local_oss_server": OSS
|
||||||
}
|
}
|
||||||
logger.info(data)
|
logger.info(data)
|
||||||
if id == 1:
|
if id == 1:
|
||||||
|
|||||||
@@ -19,15 +19,16 @@ class Settings(BaseSettings):
|
|||||||
LOGGING_CONFIG_FILE = os.path.join(BASE_DIR, 'logging_env.py')
|
LOGGING_CONFIG_FILE = os.path.join(BASE_DIR, 'logging_env.py')
|
||||||
|
|
||||||
|
|
||||||
|
OSS = "minio"
|
||||||
DEBUG = False
|
DEBUG = False
|
||||||
if DEBUG:
|
if DEBUG:
|
||||||
LOGS_PATH = "logs/"
|
LOGS_PATH = "logs/"
|
||||||
CATEGORY_PATH = "service/attribute/config/descriptor/category/category_dis.csv"
|
CATEGORY_PATH = "service/attribute/config/descriptor/category/category_dis.csv"
|
||||||
FACE_CLASSIFIER = "service/generate_image/utils/haarcascade_frontalface_alt.xml"
|
# FACE_CLASSIFIER = "service/generate_image/utils/haarcascade_frontalface_alt.xml"
|
||||||
else:
|
else:
|
||||||
LOGS_PATH = "app/logs/"
|
LOGS_PATH = "app/logs/"
|
||||||
CATEGORY_PATH = "app/service/attribute/config/descriptor/category/category_dis.csv"
|
CATEGORY_PATH = "app/service/attribute/config/descriptor/category/category_dis.csv"
|
||||||
FACE_CLASSIFIER = 'app/service/generate_image/utils/haarcascade_frontalface_alt.xml'
|
# FACE_CLASSIFIER = 'app/service/generate_image/utils/haarcascade_frontalface_alt.xml'
|
||||||
|
|
||||||
# RABBITMQ_ENV = "" # 生产环境
|
# RABBITMQ_ENV = "" # 生产环境
|
||||||
# RABBITMQ_ENV = "-dev" # 开发环境
|
# RABBITMQ_ENV = "-dev" # 开发环境
|
||||||
@@ -47,7 +48,7 @@ S3_AWS_SECRET_ACCESS_KEY = "LNIwFFB27/QedtZ+Q/viVUoX9F5x1DbuM8N0DkD8"
|
|||||||
S3_REGION_NAME = "ap-east-1"
|
S3_REGION_NAME = "ap-east-1"
|
||||||
|
|
||||||
# redis 配置
|
# redis 配置
|
||||||
REDIS_HOST = "10.1.1.240"
|
REDIS_HOST = "10.1.1.150"
|
||||||
REDIS_PORT = "6379"
|
REDIS_PORT = "6379"
|
||||||
REDIS_DB = "2"
|
REDIS_DB = "2"
|
||||||
|
|
||||||
@@ -60,9 +61,9 @@ RABBITMQ_PARAMS = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
# milvus 配置
|
# milvus 配置
|
||||||
MILVUS_DB_HOST = "10.1.1.240"
|
MILVUS_URL = "http://10.1.1.240:19530"
|
||||||
|
MILVUS_TOKEN = "root:Milvus"
|
||||||
MILVUS_ALIAS = "default"
|
MILVUS_ALIAS = "default"
|
||||||
MILVUS_PORT = "19530"
|
|
||||||
MILVUS_TABLE_KEYPOINT = "keypoint_cache"
|
MILVUS_TABLE_KEYPOINT = "keypoint_cache"
|
||||||
MILVUS_TABLE_SEG = "seg_cache"
|
MILVUS_TABLE_SEG = "seg_cache"
|
||||||
|
|
||||||
@@ -123,8 +124,8 @@ GPI_MODEL_NAME = 'diffusion_ensemble_all'
|
|||||||
GPI_MODEL_URL = '10.1.1.240:10041'
|
GPI_MODEL_URL = '10.1.1.240:10041'
|
||||||
|
|
||||||
# Generate Single Logo service config
|
# Generate Single Logo service config
|
||||||
GRI_RABBITMQ_QUEUES = os.getenv("GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES", f"Relight{RABBITMQ_ENV}")
|
GRI_RABBITMQ_QUEUES = os.getenv("GEN_RELIGHT_IMAGE_RABBITMQ_QUEUES", f"Relight{RABBITMQ_ENV}")
|
||||||
GRI_MODEL_NAME = 'stable_diffusion_1_5'
|
GRI_MODEL_NAME = 'diffusion_relight_ensemble'
|
||||||
GRI_MODEL_URL = '10.1.1.240:10051'
|
GRI_MODEL_URL = '10.1.1.240:10051'
|
||||||
|
|
||||||
# SEG service config
|
# SEG service config
|
||||||
|
|||||||
@@ -48,3 +48,7 @@ from pydantic import BaseModel
|
|||||||
class DesignModel(BaseModel):
|
class DesignModel(BaseModel):
|
||||||
objects: list[dict]
|
objects: list[dict]
|
||||||
process_id: str
|
process_id: str
|
||||||
|
|
||||||
|
|
||||||
|
class DesignProgressModel(BaseModel):
|
||||||
|
process_id: str
|
||||||
|
|||||||
@@ -1,13 +1,13 @@
|
|||||||
top_description_list = ['service/attribute/config/descriptor/top/length.csv',
|
top_description_list = ['app/service/attribute/config/descriptor/top/length.csv',
|
||||||
'service/attribute/config/descriptor/top/type.csv',
|
'app/service/attribute/config/descriptor/top/type.csv',
|
||||||
'service/attribute/config/descriptor/top/sleeve_length.csv',
|
'app/service/attribute/config/descriptor/top/sleeve_length.csv',
|
||||||
'service/attribute/config/descriptor/top/sleeve_shape.csv',
|
'app/service/attribute/config/descriptor/top/sleeve_shape.csv',
|
||||||
'service/attribute/config/descriptor/top/sleeve_shoulder.csv',
|
'app/service/attribute/config/descriptor/top/sleeve_shoulder.csv',
|
||||||
'service/attribute/config/descriptor/top/neckline.csv',
|
'app/service/attribute/config/descriptor/top/neckline.csv',
|
||||||
'service/attribute/config/descriptor/top/design.csv',
|
'app/service/attribute/config/descriptor/top/design.csv',
|
||||||
'service/attribute/config/descriptor/top/opening_type.csv',
|
'app/service/attribute/config/descriptor/top/opening_type.csv',
|
||||||
'service/attribute/config/descriptor/top/silhouette.csv',
|
'app/service/attribute/config/descriptor/top/silhouette.csv',
|
||||||
'service/attribute/config/descriptor/top/collar.csv']
|
'app/service/attribute/config/descriptor/top/collar.csv']
|
||||||
|
|
||||||
top_model_list = ['attr_retrieve_T_length',
|
top_model_list = ['attr_retrieve_T_length',
|
||||||
'attr_retrieve_T_type',
|
'attr_retrieve_T_type',
|
||||||
@@ -22,11 +22,11 @@ top_model_list = ['attr_retrieve_T_length',
|
|||||||
]
|
]
|
||||||
|
|
||||||
bottom_description_list = [
|
bottom_description_list = [
|
||||||
'service/attribute/config/descriptor/bottom/subtype.csv',
|
'app/service/attribute/config/descriptor/bottom/subtype.csv',
|
||||||
'service/attribute/config/descriptor/bottom/length.csv',
|
'app/service/attribute/config/descriptor/bottom/length.csv',
|
||||||
'service/attribute/config/descriptor/bottom/silhouette.csv',
|
'app/service/attribute/config/descriptor/bottom/silhouette.csv',
|
||||||
'service/attribute/config/descriptor/bottom/opening_type.csv',
|
'app/service/attribute/config/descriptor/bottom/opening_type.csv',
|
||||||
'service/attribute/config/descriptor/bottom/design.csv']
|
'app/service/attribute/config/descriptor/bottom/design.csv']
|
||||||
|
|
||||||
bottom_model_list = [
|
bottom_model_list = [
|
||||||
'attr_retrieve_B_subtype',
|
'attr_retrieve_B_subtype',
|
||||||
@@ -35,14 +35,14 @@ bottom_model_list = [
|
|||||||
'attr_recong_B_optype',
|
'attr_recong_B_optype',
|
||||||
'attr_retrieve_B_design']
|
'attr_retrieve_B_design']
|
||||||
|
|
||||||
outwear_description_list = ['service/attribute/config/descriptor/outwear/length.csv',
|
outwear_description_list = ['app/service/attribute/config/descriptor/outwear/length.csv',
|
||||||
'service/attribute/config/descriptor/outwear/sleeve_length.csv',
|
'app/service/attribute/config/descriptor/outwear/sleeve_length.csv',
|
||||||
'service/attribute/config/descriptor/outwear/sleeve_shape.csv',
|
'app/service/attribute/config/descriptor/outwear/sleeve_shape.csv',
|
||||||
'service/attribute/config/descriptor/outwear/sleeve_shoulder.csv',
|
'app/service/attribute/config/descriptor/outwear/sleeve_shoulder.csv',
|
||||||
'service/attribute/config/descriptor/outwear/collar.csv',
|
'app/service/attribute/config/descriptor/outwear/collar.csv',
|
||||||
'service/attribute/config/descriptor/outwear/design.csv',
|
'app/service/attribute/config/descriptor/outwear/design.csv',
|
||||||
'service/attribute/config/descriptor/outwear/opening_type.csv',
|
'app/service/attribute/config/descriptor/outwear/opening_type.csv',
|
||||||
'service/attribute/config/descriptor/outwear/silhouette.csv', ]
|
'app/service/attribute/config/descriptor/outwear/silhouette.csv', ]
|
||||||
|
|
||||||
outwear_model_list = ['attr_recong_O_length',
|
outwear_model_list = ['attr_recong_O_length',
|
||||||
'attr_retrieve_O_sleeve_length',
|
'attr_retrieve_O_sleeve_length',
|
||||||
@@ -53,15 +53,15 @@ outwear_model_list = ['attr_recong_O_length',
|
|||||||
'attr_recong_O_optype',
|
'attr_recong_O_optype',
|
||||||
'attr_retrieve_O_silhouette']
|
'attr_retrieve_O_silhouette']
|
||||||
|
|
||||||
dress_description_list = [ # 'service/attribute/config/descriptor/dress/D_length.csv',
|
dress_description_list = [ # 'app/service/attribute/config/descriptor/dress/D_length.csv',
|
||||||
'service/attribute/config/descriptor/dress/sleeve_length.csv',
|
'app/service/attribute/config/descriptor/dress/sleeve_length.csv',
|
||||||
'service/attribute/config/descriptor/dress/sleeve_shape.csv',
|
'app/service/attribute/config/descriptor/dress/sleeve_shape.csv',
|
||||||
# 'service/attribute/config/descriptor/dress/D_sleeve_shoulder.csv',
|
# 'app/service/attribute/config/descriptor/dress/D_sleeve_shoulder.csv',
|
||||||
'service/attribute/config/descriptor/dress/neckline.csv',
|
'app/service/attribute/config/descriptor/dress/neckline.csv',
|
||||||
'service/attribute/config/descriptor/dress/collar.csv',
|
'app/service/attribute/config/descriptor/dress/collar.csv',
|
||||||
'service/attribute/config/descriptor/dress/design.csv',
|
'app/service/attribute/config/descriptor/dress/design.csv',
|
||||||
'service/attribute/config/descriptor/dress/silhouette.csv',
|
'app/service/attribute/config/descriptor/dress/silhouette.csv',
|
||||||
'service/attribute/config/descriptor/dress/type.csv']
|
'app/service/attribute/config/descriptor/dress/type.csv']
|
||||||
|
|
||||||
dress_model_list = [ # 'attr_recong_D_length',
|
dress_model_list = [ # 'attr_recong_D_length',
|
||||||
'attr_retrieve_D_sleeve_length',
|
'attr_retrieve_D_sleeve_length',
|
||||||
|
|||||||
@@ -11,12 +11,12 @@ from minio import Minio
|
|||||||
import tritonclient.http as httpclient
|
import tritonclient.http as httpclient
|
||||||
from app.core.config import *
|
from app.core.config import *
|
||||||
from app.schemas.attribute_retrieve import AttributeRecognitionModel
|
from app.schemas.attribute_retrieve import AttributeRecognitionModel
|
||||||
|
from app.service.utils.oss_client import oss_get_image
|
||||||
|
|
||||||
|
|
||||||
class AttributeRecognition:
|
class AttributeRecognition:
|
||||||
def __init__(self, const, request_data):
|
def __init__(self, const, request_data):
|
||||||
self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
# self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||||
logging.info("实例化完成")
|
|
||||||
self.request_data = []
|
self.request_data = []
|
||||||
for i, sketch in enumerate(request_data):
|
for i, sketch in enumerate(request_data):
|
||||||
self.request_data.append(
|
self.request_data.append(
|
||||||
@@ -97,9 +97,10 @@ class AttributeRecognition:
|
|||||||
return res
|
return res
|
||||||
|
|
||||||
def get_image(self, url):
|
def get_image(self, url):
|
||||||
response = self.minio_client.get_object(url.split("/", 1)[0], url.split("/", 1)[1])
|
# response = self.minio_client.get_object(url.split("/", 1)[0], url.split("/", 1)[1])
|
||||||
img = np.frombuffer(response.data, np.uint8) # 转成8位无符号整型
|
# img = np.frombuffer(response.data, np.uint8) # 转成8位无符号整型
|
||||||
img = cv2.imdecode(img, cv2.IMREAD_COLOR) # 解码
|
# img = cv2.imdecode(img, cv2.IMREAD_COLOR) #
|
||||||
|
img = oss_get_image(bucket=url.split("/", 1)[0], object_name=url.split("/", 1)[1], data_type="cv2")
|
||||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||||
return img
|
return img
|
||||||
|
|
||||||
|
|||||||
@@ -18,12 +18,13 @@ import torch
|
|||||||
|
|
||||||
from app.core.config import *
|
from app.core.config import *
|
||||||
from app.schemas.attribute_retrieve import CategoryRecognitionModel
|
from app.schemas.attribute_retrieve import CategoryRecognitionModel
|
||||||
|
from app.service.utils.oss_client import oss_get_image
|
||||||
|
|
||||||
|
|
||||||
class CategoryRecognition:
|
class CategoryRecognition:
|
||||||
def __init__(self, request_data):
|
def __init__(self, request_data):
|
||||||
self.attr_type = pd.read_csv(CATEGORY_PATH)
|
self.attr_type = pd.read_csv(CATEGORY_PATH)
|
||||||
self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
# self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||||
self.request_data = []
|
self.request_data = []
|
||||||
self.triton_client = httpclient.InferenceServerClient(url=ATT_TRITON_URL)
|
self.triton_client = httpclient.InferenceServerClient(url=ATT_TRITON_URL)
|
||||||
for sketch in request_data:
|
for sketch in request_data:
|
||||||
@@ -51,9 +52,10 @@ class CategoryRecognition:
|
|||||||
def get_image(self, url):
|
def get_image(self, url):
|
||||||
# Get data of an object.
|
# Get data of an object.
|
||||||
# Read data from response.
|
# Read data from response.
|
||||||
response = self.minio_client.get_object(url.split("/", 1)[0], url.split("/", 1)[1])
|
# response = self.minio_client.get_object(url.split("/", 1)[0], url.split("/", 1)[1])
|
||||||
img = np.frombuffer(response.data, np.uint8) # 转成8位无符号整型
|
# img = np.frombuffer(response.data, np.uint8) # 转成8位无符号整型
|
||||||
img = cv2.imdecode(img, cv2.IMREAD_COLOR) # 解码
|
# img = cv2.imdecode(img, cv2.IMREAD_COLOR) # 解码
|
||||||
|
img = oss_get_image(bucket=url.split("/", 1)[0], object_name=url.split("/", 1)[1], data_type="cv2")
|
||||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||||
return img
|
return img
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from pymilvus import MilvusClient
|
from pymilvus import MilvusClient
|
||||||
|
|
||||||
@@ -14,17 +15,17 @@ class KeypointDetection(object):
|
|||||||
path here: abstract path
|
path here: abstract path
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
# def __init__(self):
|
||||||
self.client = MilvusClient(
|
# self.client = MilvusClient(
|
||||||
uri="http://10.1.1.240:19530",
|
# uri="http://10.1.1.240:19530",
|
||||||
token="root:Milvus",
|
# token="root:Milvus",
|
||||||
db_name=MILVUS_ALIAS
|
# db_name=MILVUS_ALIAS
|
||||||
)
|
# )
|
||||||
|
|
||||||
def __del__(self):
|
# def __del__(self):
|
||||||
# start_time = time.time()
|
# start_time = time.time()
|
||||||
self.client.close()
|
# self.client.close()
|
||||||
# print(f"client close time : {time.time() - start_time}")
|
# print(f"client close time : {time.time() - start_time}")
|
||||||
|
|
||||||
# @ RunTime
|
# @ RunTime
|
||||||
def __call__(self, result):
|
def __call__(self, result):
|
||||||
@@ -55,7 +56,7 @@ class KeypointDetection(object):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
# @ RunTime
|
# @ RunTime
|
||||||
def save_keypoint_cache(keypoint_id, cache, site, KEYPOINT_RESULT_TABLE_FIELD_SET=None):
|
def save_keypoint_cache(keypoint_id, cache, site):
|
||||||
if site == "down":
|
if site == "down":
|
||||||
zeros = np.zeros(20, dtype=int)
|
zeros = np.zeros(20, dtype=int)
|
||||||
result = np.concatenate([zeros, cache.flatten()])
|
result = np.concatenate([zeros, cache.flatten()])
|
||||||
@@ -69,24 +70,16 @@ class KeypointDetection(object):
|
|||||||
"keypoint_vector": result.tolist()
|
"keypoint_vector": result.tolist()
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
client = MilvusClient(
|
|
||||||
uri="http://10.1.1.240:19530",
|
|
||||||
token="root:Milvus",
|
|
||||||
db_name=MILVUS_ALIAS
|
|
||||||
)
|
|
||||||
try:
|
try:
|
||||||
start_time = time.time()
|
client = MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS)
|
||||||
res = client.upsert(
|
# start_time = time.time()
|
||||||
collection_name=MILVUS_TABLE_KEYPOINT,
|
res = client.upsert(collection_name=MILVUS_TABLE_KEYPOINT, data=data)
|
||||||
data=data,
|
|
||||||
)
|
|
||||||
# logging.info(f"save keypoint time : {time.time() - start_time}")
|
# logging.info(f"save keypoint time : {time.time() - start_time}")
|
||||||
|
client.close()
|
||||||
return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
|
return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.info(f"save keypoint cache milvus error : {e}")
|
logging.info(f"save keypoint cache milvus error : {e}")
|
||||||
return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
|
return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
|
||||||
finally:
|
|
||||||
client.close()
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def update_keypoint_cache(keypoint_id, infer_result, search_result, site):
|
def update_keypoint_cache(keypoint_id, infer_result, search_result, site):
|
||||||
@@ -102,12 +95,9 @@ class KeypointDetection(object):
|
|||||||
"keypoint_vector": result.tolist()
|
"keypoint_vector": result.tolist()
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
client = MilvusClient(
|
|
||||||
uri="http://10.1.1.240:19530",
|
|
||||||
token="root:Milvus",
|
|
||||||
db_name=MILVUS_ALIAS
|
|
||||||
)
|
|
||||||
try:
|
try:
|
||||||
|
client = MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS)
|
||||||
# connections.connect(alias=MILVUS_ALIAS, host=MILVUS_DB_HOST, port=MILVUS_PORT)
|
# connections.connect(alias=MILVUS_ALIAS, host=MILVUS_DB_HOST, port=MILVUS_PORT)
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
# collection = Collection(MILVUS_TABLE_KEYPOINT) # Get an existing collection.
|
# collection = Collection(MILVUS_TABLE_KEYPOINT) # Get an existing collection.
|
||||||
@@ -125,8 +115,9 @@ class KeypointDetection(object):
|
|||||||
# @ RunTime
|
# @ RunTime
|
||||||
def keypoint_cache(self, result, site):
|
def keypoint_cache(self, result, site):
|
||||||
try:
|
try:
|
||||||
|
client = MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS)
|
||||||
keypoint_id = result['image_id']
|
keypoint_id = result['image_id']
|
||||||
res = self.client.query(
|
res = client.query(
|
||||||
collection_name=MILVUS_TABLE_KEYPOINT,
|
collection_name=MILVUS_TABLE_KEYPOINT,
|
||||||
# ids=[keypoint_id],
|
# ids=[keypoint_id],
|
||||||
filter=f"keypoint_id == {keypoint_id}",
|
filter=f"keypoint_id == {keypoint_id}",
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
import io
|
import io
|
||||||
import logging
|
import logging
|
||||||
import time
|
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -8,6 +7,7 @@ from PIL import Image
|
|||||||
from minio import Minio
|
from minio import Minio
|
||||||
|
|
||||||
from app.core.config import *
|
from app.core.config import *
|
||||||
|
from app.service.utils.oss_client import oss_get_image
|
||||||
from ..builder import PIPELINES
|
from ..builder import PIPELINES
|
||||||
|
|
||||||
|
|
||||||
@@ -70,11 +70,7 @@ class LoadImageFromFile(object):
|
|||||||
class LoadBodyImageFromFile(object):
|
class LoadBodyImageFromFile(object):
|
||||||
def __init__(self, body_path):
|
def __init__(self, body_path):
|
||||||
self.body_path = body_path
|
self.body_path = body_path
|
||||||
self.minioClient = Minio(
|
# self.minioClient = Minio(f"{MINIO_URL}", access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||||
f"{MINIO_URL}",
|
|
||||||
access_key=MINIO_ACCESS,
|
|
||||||
secret_key=MINIO_SECRET,
|
|
||||||
secure=MINIO_SECURE)
|
|
||||||
|
|
||||||
# response = self.minioClient.get_object("aida-mannequins", "model_1693218345.2714431.png")
|
# response = self.minioClient.get_object("aida-mannequins", "model_1693218345.2714431.png")
|
||||||
|
|
||||||
@@ -82,33 +78,33 @@ class LoadBodyImageFromFile(object):
|
|||||||
def __call__(self, result):
|
def __call__(self, result):
|
||||||
result["image_url"] = result['body_path'] = self.body_path
|
result["image_url"] = result['body_path'] = self.body_path
|
||||||
result["name"] = "mannequin"
|
result["name"] = "mannequin"
|
||||||
if not result['image_url'].lower().endswith(".png"):
|
# if not result['image_url'].lower().endswith(".png"):
|
||||||
logging.info(1)
|
# bucket = self.body_path.split("/", 1)[0]
|
||||||
bucket = self.body_path.split("/", 1)[0]
|
# object_name = self.body_path.split("/", 1)[1]
|
||||||
object_name = self.body_path.split("/", 1)[1]
|
# new_object_name = f'{object_name[:object_name.rfind(".")]}.png'
|
||||||
new_object_name = f'{object_name[:object_name.rfind(".")]}.png'
|
# image = self.minioClient.get_object(bucket, object_name)
|
||||||
image = self.minioClient.get_object(bucket, object_name)
|
# image = Image.open(io.BytesIO(image.data))
|
||||||
image = Image.open(io.BytesIO(image.data))
|
# image = image.convert("RGBA")
|
||||||
image = image.convert("RGBA")
|
# data = image.getdata()
|
||||||
data = image.getdata()
|
# #
|
||||||
#
|
# new_data = []
|
||||||
new_data = []
|
# for item in data:
|
||||||
for item in data:
|
# if item[0] >= 230 and item[1] >= 230 and item[2] >= 230:
|
||||||
if item[0] >= 230 and item[1] >= 230 and item[2] >= 230:
|
# new_data.append((255, 255, 255, 0))
|
||||||
new_data.append((255, 255, 255, 0))
|
# else:
|
||||||
else:
|
# new_data.append(item)
|
||||||
new_data.append(item)
|
# image.putdata(new_data)
|
||||||
image.putdata(new_data)
|
# image_data = io.BytesIO()
|
||||||
image_data = io.BytesIO()
|
# image.save(image_data, format='PNG')
|
||||||
image.save(image_data, format='PNG')
|
# image_data.seek(0)
|
||||||
image_data.seek(0)
|
# image_bytes = image_data.read()
|
||||||
image_bytes = image_data.read()
|
# image_path = f"{bucket}/{self.minioClient.put_object(bucket, new_object_name, io.BytesIO(image_bytes), len(image_bytes), content_type='image/png').object_name}"
|
||||||
image_path = f"{bucket}/{self.minioClient.put_object(bucket, new_object_name, io.BytesIO(image_bytes), len(image_bytes), content_type='image/png').object_name}"
|
# self.body_path = image_path
|
||||||
self.body_path = image_path
|
# result["image_url"] = result['body_path'] = self.body_path
|
||||||
result["image_url"] = result['body_path'] = self.body_path
|
# response = self.minioClient.get_object(self.body_path.split("/", 1)[0], self.body_path.split("/", 1)[1])
|
||||||
response = self.minioClient.get_object(self.body_path.split("/", 1)[0], self.body_path.split("/", 1)[1])
|
|
||||||
# put_image_time = time.time()
|
# put_image_time = time.time()
|
||||||
result['body_image'] = Image.open(io.BytesIO(response.read()))
|
# result['body_image'] = Image.open(io.BytesIO(response.read()))
|
||||||
|
result['body_image'] = oss_get_image(bucket=self.body_path.split("/", 1)[0], object_name=self.body_path.split("/", 1)[1], data_type="PIL")
|
||||||
# logging.info(f"Image.open time is : {time.time() - put_image_time}")
|
# logging.info(f"Image.open time is : {time.time() - put_image_time}")
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|||||||
@@ -1,19 +1,16 @@
|
|||||||
import random
|
import random
|
||||||
from io import BytesIO
|
|
||||||
# import boto3
|
# import boto3
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from minio import Minio
|
|
||||||
|
|
||||||
from app.core.config import *
|
from app.service.utils.oss_client import oss_get_image
|
||||||
from ..builder import PIPELINES
|
from ..builder import PIPELINES
|
||||||
|
|
||||||
minio_client = Minio(
|
|
||||||
MINIO_URL,
|
# minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||||
access_key=MINIO_ACCESS,
|
|
||||||
secret_key=MINIO_SECRET,
|
|
||||||
secure=MINIO_SECURE)
|
|
||||||
|
|
||||||
# s3 = boto3.client('s3', aws_access_key_id=S3_ACCESS_KEY, aws_secret_access_key=S3_AWS_SECRET_ACCESS_KEY, region_name=S3_REGION_NAME)
|
# s3 = boto3.client('s3', aws_access_key_id=S3_ACCESS_KEY, aws_secret_access_key=S3_AWS_SECRET_ACCESS_KEY, region_name=S3_REGION_NAME)
|
||||||
|
|
||||||
@@ -56,17 +53,18 @@ class Painting(object):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_gradient(bucket_name, object_name):
|
def get_gradient(bucket_name, object_name):
|
||||||
image_data = minio_client.get_object(bucket_name, object_name)
|
# image_data = minio_client.get_object(bucket_name, object_name)
|
||||||
# image_data = s3.get_object(Bucket=bucket_name, Key=object_name)['Body']
|
# image_data = s3.get_object(Bucket=bucket_name, Key=object_name)['Body']
|
||||||
|
|
||||||
# 从数据流中读取图像
|
# 从数据流中读取图像
|
||||||
image_bytes = image_data.read()
|
# image_bytes = image_data.read()
|
||||||
|
|
||||||
# 将图像数据转换为numpy数组
|
# 将图像数据转换为numpy数组
|
||||||
image_array = np.asarray(bytearray(image_bytes), dtype=np.uint8)
|
# image_array = np.asarray(bytearray(image_bytes), dtype=np.uint8)
|
||||||
|
|
||||||
# 使用OpenCV解码图像数组
|
# 使用OpenCV解码图像数组
|
||||||
image = cv2.imdecode(image_array, cv2.IMREAD_COLOR)
|
# image = cv2.imdecode(image_array, cv2.IMREAD_COLOR)
|
||||||
|
image = oss_get_image(bucket=bucket_name, object_name=object_name, data_type="cv2")
|
||||||
return image
|
return image
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -494,16 +492,20 @@ class PrintPainting(object):
|
|||||||
if not 'IfSingle' in print_dict.keys():
|
if not 'IfSingle' in print_dict.keys():
|
||||||
print_dict['IfSingle'] = False
|
print_dict['IfSingle'] = False
|
||||||
|
|
||||||
data = minio_client.get_object(print_dict['print_path_list'][0].split("/", 1)[0], print_dict['print_path_list'][0].split("/", 1)[1])
|
# data = minio_client.get_object(print_dict['print_path_list'][0].split("/", 1)[0], print_dict['print_path_list'][0].split("/", 1)[1])
|
||||||
# data = s3.get_object(Bucket=print_dict['print_path_list'][0].split("/", 1)[0], Key=print_dict['print_path_list'][0].split("/", 1)[1])['Body']
|
# data_bytes = BytesIO(data.read())
|
||||||
|
# image = Image.open(data_bytes)
|
||||||
|
# image_mode = image.mode
|
||||||
|
|
||||||
data_bytes = BytesIO(data.read())
|
bucket_name = print_dict['print_path_list'][0].split("/", 1)[0]
|
||||||
image = Image.open(data_bytes)
|
object_name = print_dict['print_path_list'][0].split("/", 1)[1]
|
||||||
image_mode = image.mode
|
image = oss_get_image(bucket=bucket_name, object_name=object_name, data_type="cv2")
|
||||||
# 判断图片格式,如果是RGBA 则贴在一张纯白图片上 防止透明转黑
|
# 判断图片格式,如果是RGBA 则贴在一张纯白图片上 防止透明转黑
|
||||||
if image_mode == "RGBA":
|
if image.shape[2] == 4:
|
||||||
new_background = Image.new('RGB', image.size, (255, 255, 255))
|
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA)
|
||||||
new_background.paste(image, mask=image.split()[3])
|
image_pil = Image.fromarray(image_rgb)
|
||||||
|
new_background = Image.new('RGB', image_pil.size, (255, 255, 255))
|
||||||
|
new_background.paste(image_pil, mask=image.split()[3])
|
||||||
image = new_background
|
image = new_background
|
||||||
print_dict['image'] = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
|
print_dict['image'] = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
|
||||||
|
|
||||||
@@ -577,21 +579,30 @@ class PrintPainting(object):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def read_image(image_url):
|
def read_image(image_url):
|
||||||
data = minio_client.get_object(image_url.split("/", 1)[0], image_url.split("/", 1)[1])
|
image = oss_get_image(bucket=image_url.split("/", 1)[0], object_name=image_url.split("/", 1)[1], data_type="cv2")
|
||||||
# data = s3.get_object(Bucket=image_url.split("/", 1)[0], Key=image_url.split("/", 1)[1])['Body']
|
if image.shape[2] == 4:
|
||||||
|
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA)
|
||||||
data_bytes = BytesIO(data.read())
|
image = Image.fromarray(image_rgb)
|
||||||
image = Image.open(data_bytes)
|
image_mode = "RGBA"
|
||||||
image_mode = image.mode
|
else:
|
||||||
# 判断图片格式,如果是RGBA 则贴在一张纯白图片上 防止透明转黑
|
image_mode = "RGB"
|
||||||
if image_mode == "RGBA":
|
|
||||||
# new_background = Image.new('RGB', image.size, (255, 255, 255))
|
|
||||||
# new_background.paste(image, mask=image.split()[3])
|
|
||||||
# image = new_background
|
|
||||||
return image, image_mode
|
|
||||||
image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
|
|
||||||
return image, image_mode
|
return image, image_mode
|
||||||
|
|
||||||
|
# data = minio_client.get_object(image_url.split("/", 1)[0], image_url.split("/", 1)[1])
|
||||||
|
# # data = s3.get_object(Bucket=image_url.split("/", 1)[0], Key=image_url.split("/", 1)[1])['Body']
|
||||||
|
#
|
||||||
|
# data_bytes = BytesIO(data.read())
|
||||||
|
# image = Image.open(data_bytes)
|
||||||
|
# image_mode = image.mode
|
||||||
|
# # 判断图片格式,如果是RGBA 则贴在一张纯白图片上 防止透明转黑
|
||||||
|
# if image_mode == "RGBA":
|
||||||
|
# # new_background = Image.new('RGB', image.size, (255, 255, 255))
|
||||||
|
# # new_background.paste(image, mask=image.split()[3])
|
||||||
|
# # image = new_background
|
||||||
|
# return image, image_mode
|
||||||
|
# image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
|
||||||
|
# return image, "RGB"
|
||||||
|
|
||||||
# @staticmethod
|
# @staticmethod
|
||||||
# def read_image(image_url):
|
# def read_image(image_url):
|
||||||
# response = requests.get(image_url)
|
# response = requests.get(image_url)
|
||||||
|
|||||||
@@ -41,7 +41,7 @@ class Split(object):
|
|||||||
else:
|
else:
|
||||||
back_mask = result['back_mask']
|
back_mask = result['back_mask']
|
||||||
|
|
||||||
rgba_image = rgb_to_rgba((result['final_image'].shape[0], result['final_image'].shape[1]), result['final_image'], result['mask'])
|
rgba_image = rgb_to_rgba((result['final_image'].shape[0], result['final_image'].shape[1]), re4sult['final_image'], result['mask'])
|
||||||
result_front_image = np.zeros_like(rgba_image)
|
result_front_image = np.zeros_like(rgba_image)
|
||||||
result_front_image[front_mask != 0] = rgba_image[front_mask != 0]
|
result_front_image[front_mask != 0] = rgba_image[front_mask != 0]
|
||||||
|
|
||||||
|
|||||||
@@ -13,15 +13,12 @@ import io
|
|||||||
|
|
||||||
from app.core.config import *
|
from app.core.config import *
|
||||||
from app.service.design.utils.design_ensemble import get_keypoint_result
|
from app.service.design.utils.design_ensemble import get_keypoint_result
|
||||||
|
from app.service.utils.oss_client import oss_get_image, oss_upload_image
|
||||||
|
|
||||||
|
|
||||||
class DesignPreprocessing:
|
class DesignPreprocessing:
|
||||||
def __init__(self):
|
# def __init__(self):
|
||||||
self.minio_client = Minio(
|
# self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||||
MINIO_URL,
|
|
||||||
access_key=MINIO_ACCESS,
|
|
||||||
secret_key=MINIO_SECRET,
|
|
||||||
secure=MINIO_SECURE)
|
|
||||||
|
|
||||||
# @ RunTime
|
# @ RunTime
|
||||||
def pipeline(self, image_list):
|
def pipeline(self, image_list):
|
||||||
@@ -51,8 +48,9 @@ class DesignPreprocessing:
|
|||||||
|
|
||||||
def read_image(self, image_list):
|
def read_image(self, image_list):
|
||||||
for obj in image_list:
|
for obj in image_list:
|
||||||
file = self.minio_client.get_object(obj['image_url'].split("/", 1)[0], obj['image_url'].split("/", 1)[1]).data
|
# file = self.minio_client.get_object(obj['image_url'].split("/", 1)[0], obj['image_url'].split("/", 1)[1]).data
|
||||||
image = cv2.imdecode(np.frombuffer(file, np.uint8), 1)
|
# image = cv2.imdecode(np.frombuffer(file, np.uint8), 1)
|
||||||
|
image = oss_get_image(bucket=obj['image_url'].split("/", 1)[0], object_name=obj['image_url'].split("/", 1)[1], data_type="cv2")
|
||||||
if len(image.shape) == 2:
|
if len(image.shape) == 2:
|
||||||
image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
|
image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
|
||||||
elif image.shape[2] == 4: # 如果是四通道 mask
|
elif image.shape[2] == 4: # 如果是四通道 mask
|
||||||
@@ -125,7 +123,10 @@ class DesignPreprocessing:
|
|||||||
try:
|
try:
|
||||||
# 覆盖到minio
|
# 覆盖到minio
|
||||||
image_bytes = cv2.imencode(".jpg", item['obj'])[1].tobytes()
|
image_bytes = cv2.imencode(".jpg", item['obj'])[1].tobytes()
|
||||||
self.minio_client.put_object(item['image_url'].split("/", 1)[0], item['image_url'].split("/", 1)[1], io.BytesIO(image_bytes), len(image_bytes), content_type="image/jpeg", )
|
# self.minio_client.put_object(item['image_url'].split("/", 1)[0], item['image_url'].split("/", 1)[1], io.BytesIO(image_bytes), len(image_bytes), content_type="image/jpeg", )
|
||||||
|
bucket_name = item['image_url'].split("/", 1)[0]
|
||||||
|
object_name = item['image_url'].split("/", 1)[1]
|
||||||
|
oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes)
|
||||||
print(f"Object '{item['image_url'].split('/', 1)[1]}' overwritten successfully.")
|
print(f"Object '{item['image_url'].split('/', 1)[1]}' overwritten successfully.")
|
||||||
except ResponseError as err:
|
except ResponseError as err:
|
||||||
print(f"Error: {err}")
|
print(f"Error: {err}")
|
||||||
@@ -165,36 +166,76 @@ class DesignPreprocessing:
|
|||||||
# @ RunTime
|
# @ RunTime
|
||||||
def composing_image(self, image_list):
|
def composing_image(self, image_list):
|
||||||
for image in image_list:
|
for image in image_list:
|
||||||
if image['site'] == 'down':
|
''' 比例相同 整合上下装代码'''
|
||||||
image_width = image['obj'].shape[1]
|
image_width = image['obj'].shape[1]
|
||||||
waist_width = image['keypoint_result']['waistband_right'][1] - image['keypoint_result']['waistband_left'][1]
|
waist_width = image['keypoint_result']['waistband_right'][1] - image['keypoint_result']['waistband_left'][1]
|
||||||
scale = 0.4
|
scale = 0.4
|
||||||
if waist_width / scale >= image['obj'].shape[1]:
|
if waist_width / scale >= image_width:
|
||||||
add_width = int((waist_width / scale - image_width) / 2)
|
add_width = int((waist_width / scale - image_width) / 2)
|
||||||
ret = cv2.copyMakeBorder(image['obj'], 0, 0, add_width, add_width, cv2.BORDER_CONSTANT, value=(256, 256, 256))
|
ret = cv2.copyMakeBorder(image['obj'], 0, 0, add_width, add_width, cv2.BORDER_CONSTANT, value=(256, 256, 256))
|
||||||
if IF_DEBUG_SHOW:
|
if IF_DEBUG_SHOW:
|
||||||
cv2.imshow("composing_image", ret)
|
cv2.imshow("composing_image", ret)
|
||||||
cv2.waitKey(0)
|
cv2.waitKey(0)
|
||||||
image_bytes = cv2.imencode(".jpg", ret)[1].tobytes()
|
image_bytes = cv2.imencode(".jpg", ret)[1].tobytes()
|
||||||
image['show_image_url'] = f"{image['image_url'].split('/', 1)[0]}/{self.minio_client.put_object(image['image_url'].split('/', 1)[0], image['image_url'].split('/', 1)[1].replace('.', '-show.'), io.BytesIO(image_bytes), len(image_bytes), content_type='image/jpeg').object_name}"
|
# image['show_image_url'] = f"{image['image_url'].split('/', 1)[0]}/{self.minio_client.put_object(image['image_url'].split('/', 1)[0], image['image_url'].split('/', 1)[1].replace('.', '-show.'), io.BytesIO(image_bytes), len(image_bytes), content_type='image/jpeg').object_name}"
|
||||||
else:
|
bucket_name = image['image_url'].split('/', 1)[0]
|
||||||
image_bytes = cv2.imencode(".jpg", image['obj'])[1].tobytes()
|
object_name = image['image_url'].split('/', 1)[1]
|
||||||
image['show_image_url'] = f"{image['image_url'].split('/', 1)[0]}/{self.minio_client.put_object(image['image_url'].split('/', 1)[0], image['image_url'].split('/', 1)[1].replace('.', '-show.'), io.BytesIO(image_bytes), len(image_bytes), content_type='image/jpeg').object_name}"
|
oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes)
|
||||||
|
image['show_image_url'] = f"{bucket_name}/{object_name}"
|
||||||
else:
|
else:
|
||||||
scale = 0.4
|
image_bytes = cv2.imencode(".jpg", image['obj'])[1].tobytes()
|
||||||
image_width = image['obj'].shape[1]
|
# image['show_image_url'] = f"{image['image_url'].split('/', 1)[0]}/{self.minio_client.put_object(image['image_url'].split('/', 1)[0], image['image_url'].split('/', 1)[1].replace('.', '-show.'), io.BytesIO(image_bytes), len(image_bytes), content_type='image/jpeg').object_name}"
|
||||||
waist_width = image['keypoint_result']['armpit_right'][1] - image['keypoint_result']['armpit_left'][1]
|
bucket_name = image['image_url'].split('/', 1)[0]
|
||||||
if waist_width / scale >= image_width:
|
object_name = image['image_url'].split('/', 1)[1]
|
||||||
add_width = int((waist_width / scale - image_width) / 2)
|
oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes)
|
||||||
ret = cv2.copyMakeBorder(image['obj'], 0, 0, add_width, add_width, cv2.BORDER_CONSTANT, value=(256, 256, 256))
|
image['show_image_url'] = f"{bucket_name}/{object_name}"
|
||||||
if IF_DEBUG_SHOW:
|
|
||||||
cv2.imshow("composing_image", ret)
|
# if image['site'] == 'down':
|
||||||
cv2.waitKey(0)
|
# image_width = image['obj'].shape[1]
|
||||||
image_bytes = cv2.imencode(".jpg", ret)[1].tobytes()
|
# waist_width = image['keypoint_result']['waistband_right'][1] - image['keypoint_result']['waistband_left'][1]
|
||||||
image['show_image_url'] = f"{image['image_url'].split('/', 1)[0]}/{self.minio_client.put_object(image['image_url'].split('/', 1)[0], image['image_url'].split('/', 1)[1].replace('.', '-show.'), io.BytesIO(image_bytes), len(image_bytes), content_type='image/jpeg').object_name}"
|
# scale = 0.4
|
||||||
else:
|
# if waist_width / scale >= image_width:
|
||||||
image_bytes = cv2.imencode(".jpg", image['obj'])[1].tobytes()
|
# add_width = int((waist_width / scale - image_width) / 2)
|
||||||
image['show_image_url'] = f"{image['image_url'].split('/', 1)[0]}/{self.minio_client.put_object(image['image_url'].split('/', 1)[0], image['image_url'].split('/', 1)[1].replace('.', '-show.'), io.BytesIO(image_bytes), len(image_bytes), content_type='image/jpeg').object_name}"
|
# ret = cv2.copyMakeBorder(image['obj'], 0, 0, add_width, add_width, cv2.BORDER_CONSTANT, value=(256, 256, 256))
|
||||||
|
# if IF_DEBUG_SHOW:
|
||||||
|
# cv2.imshow("composing_image", ret)
|
||||||
|
# cv2.waitKey(0)
|
||||||
|
# image_bytes = cv2.imencode(".jpg", ret)[1].tobytes()
|
||||||
|
# # image['show_image_url'] = f"{image['image_url'].split('/', 1)[0]}/{self.minio_client.put_object(image['image_url'].split('/', 1)[0], image['image_url'].split('/', 1)[1].replace('.', '-show.'), io.BytesIO(image_bytes), len(image_bytes), content_type='image/jpeg').object_name}"
|
||||||
|
# bucket_name = image['image_url'].split('/', 1)[0]
|
||||||
|
# object_name = image['image_url'].split('/', 1)[1]
|
||||||
|
# oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes)
|
||||||
|
# image['show_image_url'] = f"{bucket_name}/{object_name}"
|
||||||
|
# else:
|
||||||
|
# image_bytes = cv2.imencode(".jpg", image['obj'])[1].tobytes()
|
||||||
|
# # image['show_image_url'] = f"{image['image_url'].split('/', 1)[0]}/{self.minio_client.put_object(image['image_url'].split('/', 1)[0], image['image_url'].split('/', 1)[1].replace('.', '-show.'), io.BytesIO(image_bytes), len(image_bytes), content_type='image/jpeg').object_name}"
|
||||||
|
# bucket_name = image['image_url'].split('/', 1)[0]
|
||||||
|
# object_name = image['image_url'].split('/', 1)[1]
|
||||||
|
# oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes)
|
||||||
|
# image['show_image_url'] = f"{bucket_name}/{object_name}"
|
||||||
|
# else:
|
||||||
|
# image_width = image['obj'].shape[1]
|
||||||
|
# waist_width = image['keypoint_result']['waistband_right'][1] - image['keypoint_result']['waistband_left'][1]
|
||||||
|
# scale = 0.4
|
||||||
|
# if waist_width / scale >= image_width:
|
||||||
|
# add_width = int((waist_width / scale - image_width) / 2)
|
||||||
|
# ret = cv2.copyMakeBorder(image['obj'], 0, 0, add_width, add_width, cv2.BORDER_CONSTANT, value=(256, 256, 256))
|
||||||
|
# if IF_DEBUG_SHOW:
|
||||||
|
# cv2.imshow("composing_image", ret)
|
||||||
|
# cv2.waitKey(0)
|
||||||
|
# image_bytes = cv2.imencode(".jpg", ret)[1].tobytes()
|
||||||
|
# # image['show_image_url'] = f"{image['image_url'].split('/', 1)[0]}/{self.minio_client.put_object(image['image_url'].split('/', 1)[0], image['image_url'].split('/', 1)[1].replace('.', '-show.'), io.BytesIO(image_bytes), len(image_bytes), content_type='image/jpeg').object_name}"
|
||||||
|
# bucket_name = image['image_url'].split('/', 1)[0]
|
||||||
|
# object_name = image['image_url'].split('/', 1)[1]
|
||||||
|
# oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes)
|
||||||
|
# image['show_image_url'] = f"{bucket_name}/{object_name}"
|
||||||
|
# else:
|
||||||
|
# image_bytes = cv2.imencode(".jpg", image['obj'])[1].tobytes()
|
||||||
|
# # image['show_image_url'] = f"{image['image_url'].split('/', 1)[0]}/{self.minio_client.put_object(image['image_url'].split('/', 1)[0], image['image_url'].split('/', 1)[1].replace('.', '-show.'), io.BytesIO(image_bytes), len(image_bytes), content_type='image/jpeg').object_name}"
|
||||||
|
# bucket_name = image['image_url'].split('/', 1)[0]
|
||||||
|
# object_name = image['image_url'].split('/', 1)[1]
|
||||||
|
# oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes)
|
||||||
|
# image['show_image_url'] = f"{bucket_name}/{object_name}"
|
||||||
return image_list
|
return image_list
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
@@ -10,21 +10,17 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from io import BytesIO
|
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import minio
|
import minio
|
||||||
import redis
|
import redis
|
||||||
import tritonclient.grpc as grpcclient
|
import tritonclient.grpc as grpcclient
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from minio import Minio
|
|
||||||
from tritonclient.utils import np_to_triton_dtype
|
from tritonclient.utils import np_to_triton_dtype
|
||||||
|
|
||||||
from app.core.config import *
|
from app.core.config import *
|
||||||
from app.schemas.generate_image import GenerateImageModel
|
from app.schemas.generate_image import GenerateImageModel
|
||||||
from app.service.generate_image.utils.adjust_contrast import adjust_contrast
|
|
||||||
from app.service.generate_image.utils.image_processing import remove_background, stain_detection, generate_category_recognition, autoLevels, luminance_adjust, face_detect_pic
|
from app.service.generate_image.utils.image_processing import remove_background, stain_detection, generate_category_recognition, autoLevels, luminance_adjust, face_detect_pic
|
||||||
from app.service.generate_image.utils.upload_sd_image import upload_png_sd, upload_stain_png_sd
|
from app.service.generate_image.utils.upload_sd_image import upload_png_sd
|
||||||
|
from app.service.utils.oss_client import oss_get_image
|
||||||
|
|
||||||
logger = logging.getLogger()
|
logger = logging.getLogger()
|
||||||
|
|
||||||
@@ -36,7 +32,7 @@ class GenerateImage:
|
|||||||
self.channel = self.connection.channel()
|
self.channel = self.connection.channel()
|
||||||
# self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
|
# self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
|
||||||
# self.channel = self.connection.channel()
|
# self.channel = self.connection.channel()
|
||||||
self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
# self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||||
self.grpc_client = grpcclient.InferenceServerClient(url=GI_MODEL_URL)
|
self.grpc_client = grpcclient.InferenceServerClient(url=GI_MODEL_URL)
|
||||||
self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
|
self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
|
||||||
if request_data.mode == "img2img":
|
if request_data.mode == "img2img":
|
||||||
@@ -63,10 +59,13 @@ class GenerateImage:
|
|||||||
# Read data from response.
|
# Read data from response.
|
||||||
# read image use cv2
|
# read image use cv2
|
||||||
try:
|
try:
|
||||||
response = self.minio_client.get_object(image_url.split('/')[0], image_url[image_url.find('/') + 1:])
|
# response = self.minio_client.get_object(image_url.split('/')[0], image_url[image_url.find('/') + 1:])
|
||||||
image_file = BytesIO(response.data)
|
# image_file = BytesIO(response.data)
|
||||||
image_array = np.asarray(bytearray(image_file.read()), dtype=np.uint8)
|
# image_array = np.asarray(bytearray(image_file.read()), dtype=np.uint8)
|
||||||
image_cv2 = cv2.imdecode(image_array, cv2.IMREAD_COLOR)
|
# image_cv2 = cv2.imdecode(image_array, cv2.IMREAD_COLOR)
|
||||||
|
# image_rbg = cv2.cvtColor(image_cv2, cv2.COLOR_BGR2RGB)
|
||||||
|
|
||||||
|
image_cv2 = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url, data_type="cv2")
|
||||||
image_rbg = cv2.cvtColor(image_cv2, cv2.COLOR_BGR2RGB)
|
image_rbg = cv2.cvtColor(image_cv2, cv2.COLOR_BGR2RGB)
|
||||||
image = cv2.resize(image_rbg, (1024, 1024))
|
image = cv2.resize(image_rbg, (1024, 1024))
|
||||||
except minio.error.S3Error:
|
except minio.error.S3Error:
|
||||||
@@ -189,7 +188,8 @@ if __name__ == '__main__':
|
|||||||
prompt='skeleton sitting by the side of a river looking soulful, concert poster, 4k, artistic',
|
prompt='skeleton sitting by the side of a river looking soulful, concert poster, 4k, artistic',
|
||||||
image_url="",
|
image_url="",
|
||||||
mode='txt2img',
|
mode='txt2img',
|
||||||
category="test"
|
category="test",
|
||||||
|
gender="male"
|
||||||
)
|
)
|
||||||
server = GenerateImage(rd)
|
server = GenerateImage(rd)
|
||||||
print(server.get_result())
|
print(server.get_result())
|
||||||
@@ -18,10 +18,10 @@ import numpy as np
|
|||||||
from PIL import Image, ImageOps
|
from PIL import Image, ImageOps
|
||||||
from minio import Minio
|
from minio import Minio
|
||||||
from tritonclient.utils import np_to_triton_dtype
|
from tritonclient.utils import np_to_triton_dtype
|
||||||
|
|
||||||
from app.core.config import *
|
from app.core.config import *
|
||||||
from app.schemas.generate_image import GenerateImageModel
|
from app.schemas.generate_image import GenerateProductImageModel
|
||||||
from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image
|
from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image
|
||||||
|
from app.service.utils.oss_client import oss_get_image
|
||||||
|
|
||||||
logger = logging.getLogger()
|
logger = logging.getLogger()
|
||||||
|
|
||||||
@@ -33,69 +33,29 @@ class GenerateProductImage:
|
|||||||
self.channel = self.connection.channel()
|
self.channel = self.connection.channel()
|
||||||
# self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
|
# self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
|
||||||
# self.channel = self.connection.channel()
|
# self.channel = self.connection.channel()
|
||||||
self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
# self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||||
self.grpc_client = grpcclient.InferenceServerClient(url=GPI_MODEL_URL)
|
self.grpc_client = grpcclient.InferenceServerClient(url=GPI_MODEL_URL)
|
||||||
self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
|
self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
|
||||||
self.category = "product_image"
|
self.category = "product_image"
|
||||||
self.batch_size = 1
|
self.batch_size = 1
|
||||||
self.prompt = request_data.prompt
|
self.prompt = request_data.prompt
|
||||||
# TODO aida design 结果图背景改为白色
|
self.image, self.image_size = pre_processing_image(request_data.image_url)
|
||||||
self.image, self.image_size = self.get_image(request_data.image_url)
|
|
||||||
# TODO image 填充并resize成512*768
|
|
||||||
|
|
||||||
self.tasks_id = request_data.tasks_id
|
self.tasks_id = request_data.tasks_id
|
||||||
self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
|
self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
|
||||||
self.gen_product_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': ''}
|
self.gen_product_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': ''}
|
||||||
self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data))
|
self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data))
|
||||||
self.redis_client.expire(self.tasks_id, 600)
|
self.redis_client.expire(self.tasks_id, 600)
|
||||||
|
|
||||||
def get_image(self, image_url):
|
|
||||||
response = self.minio_client.get_object(image_url.split('/')[0], image_url[image_url.find('/') + 1:])
|
|
||||||
image_bytes = io.BytesIO(response.read())
|
|
||||||
|
|
||||||
# 转换为PIL图像对象
|
|
||||||
image = Image.open(image_bytes)
|
|
||||||
target_height = 768
|
|
||||||
target_width = 512
|
|
||||||
|
|
||||||
aspect_ratio = image.width / image.height
|
|
||||||
new_width = int(target_height * aspect_ratio)
|
|
||||||
|
|
||||||
resized_image = image.resize((new_width, target_height))
|
|
||||||
left = (target_width - resized_image.width) // 2
|
|
||||||
top = (target_height - resized_image.height) // 2
|
|
||||||
right = target_width - resized_image.width - left
|
|
||||||
bottom = target_height - resized_image.height - top
|
|
||||||
image = ImageOps.expand(resized_image, (left, top, right, bottom), fill="white")
|
|
||||||
image_size = image.size
|
|
||||||
if image.mode in ('RGBA', 'LA') or (image.mode == 'P' and 'transparency' in image.info):
|
|
||||||
# 创建白色背景
|
|
||||||
background = Image.new("RGB", image.size, (255, 255, 255))
|
|
||||||
# 将图片粘贴到白色背景上
|
|
||||||
background.paste(image, mask=image.split()[3])
|
|
||||||
image = np.array(background)
|
|
||||||
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
|
||||||
|
|
||||||
# image_file = BytesIO(response.data)
|
|
||||||
# image_array = np.asarray(bytearray(image_file.read()), dtype=np.uint8)
|
|
||||||
# image_cv2 = cv2.imdecode(image_array, cv2.IMREAD_COLOR)
|
|
||||||
# image = cv2.cvtColor(image_cv2, cv2.COLOR_BGR2RGB)
|
|
||||||
# image = cv2.resize(image_rbg, (1024, 1024))
|
|
||||||
return image, image_size
|
|
||||||
|
|
||||||
def callback(self, result, error):
|
def callback(self, result, error):
|
||||||
if error:
|
if error:
|
||||||
self.gen_product_data['status'] = "FAILURE"
|
self.gen_product_data['status'] = "FAILURE"
|
||||||
self.gen_product_data['message'] = str(error)
|
self.gen_product_data['message'] = str(error)
|
||||||
# self.gen_product_data['data'] = str(error)
|
|
||||||
self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data))
|
self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data))
|
||||||
else:
|
else:
|
||||||
# pil图像转成numpy数组
|
# pil图像转成numpy数组
|
||||||
image = result.as_numpy("generated_inpaint_image")
|
image = result.as_numpy("generated_inpaint_image")
|
||||||
image_result = Image.fromarray(np.squeeze(image.astype(np.uint8))).resize(self.image_size)
|
image_result = Image.fromarray(np.squeeze(image.astype(np.uint8))).resize(self.image_size)
|
||||||
|
image_url = upload_SDXL_image(image_result, user_id=self.user_id, category=f"{self.category}", file_name=f"{self.tasks_id}.png")
|
||||||
image_url = upload_SDXL_image(image_result, user_id=self.user_id, category=f"{self.category}", object_name=f"{self.tasks_id}.png")
|
|
||||||
# logger.info(f"upload image SUCCESS : {image_url}")
|
|
||||||
self.gen_product_data['status'] = "SUCCESS"
|
self.gen_product_data['status'] = "SUCCESS"
|
||||||
self.gen_product_data['message'] = "success"
|
self.gen_product_data['message'] = "success"
|
||||||
self.gen_product_data['image_url'] = str(image_url)
|
self.gen_product_data['image_url'] = str(image_url)
|
||||||
@@ -105,13 +65,6 @@ class GenerateProductImage:
|
|||||||
status_data = self.redis_client.get(self.tasks_id)
|
status_data = self.redis_client.get(self.tasks_id)
|
||||||
return json.loads(status_data), status_data
|
return json.loads(status_data), status_data
|
||||||
|
|
||||||
def infer(self, inputs):
|
|
||||||
return self.grpc_client.async_infer(
|
|
||||||
model_name=GPI_MODEL_NAME,
|
|
||||||
inputs=inputs,
|
|
||||||
callback=self.callback
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_result(self):
|
def get_result(self):
|
||||||
try:
|
try:
|
||||||
prompts = [self.prompt] * self.batch_size
|
prompts = [self.prompt] * self.batch_size
|
||||||
@@ -129,11 +82,10 @@ class GenerateProductImage:
|
|||||||
input_image.set_data_from_numpy(image_obj)
|
input_image.set_data_from_numpy(image_obj)
|
||||||
inputs = [input_text, input_image]
|
inputs = [input_text, input_image]
|
||||||
|
|
||||||
ctx = self.infer(inputs)
|
ctx = self.grpc_client.async_infer(model_name=GPI_MODEL_NAME, inputs=inputs, callback=self.callback)
|
||||||
time_out = 600
|
time_out = 600
|
||||||
while time_out > 0:
|
while time_out > 0:
|
||||||
gen_product_data, _ = self.read_tasks_status()
|
gen_product_data, _ = self.read_tasks_status()
|
||||||
# logger.info(gen_product_data)
|
|
||||||
if gen_product_data['status'] in ["REVOKED", "FAILURE"]:
|
if gen_product_data['status'] in ["REVOKED", "FAILURE"]:
|
||||||
ctx.cancel()
|
ctx.cancel()
|
||||||
break
|
break
|
||||||
@@ -141,7 +93,6 @@ class GenerateProductImage:
|
|||||||
break
|
break
|
||||||
time_out -= 1
|
time_out -= 1
|
||||||
time.sleep(0.1)
|
time.sleep(0.1)
|
||||||
# logger.info(time_out, gen_product_data)
|
|
||||||
gen_product_data, _ = self.read_tasks_status()
|
gen_product_data, _ = self.read_tasks_status()
|
||||||
return gen_product_data
|
return gen_product_data
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -153,7 +104,6 @@ class GenerateProductImage:
|
|||||||
dict_gen_product_data, str_gen_product_data = self.read_tasks_status()
|
dict_gen_product_data, str_gen_product_data = self.read_tasks_status()
|
||||||
if DEBUG is False:
|
if DEBUG is False:
|
||||||
self.channel.basic_publish(exchange='', routing_key=GPI_RABBITMQ_QUEUES, body=str_gen_product_data)
|
self.channel.basic_publish(exchange='', routing_key=GPI_RABBITMQ_QUEUES, body=str_gen_product_data)
|
||||||
# self.channel.basic_publish(exchange='', routing_key=GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES, body=str_gen_product_data)
|
|
||||||
logger.info(f" [x] Sent to: {GPI_RABBITMQ_QUEUES} data:@@@@ {json.dumps(dict_gen_product_data, indent=4)}")
|
logger.info(f" [x] Sent to: {GPI_RABBITMQ_QUEUES} data:@@@@ {json.dumps(dict_gen_product_data, indent=4)}")
|
||||||
|
|
||||||
|
|
||||||
@@ -165,11 +115,37 @@ def infer_cancel(tasks_id):
|
|||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def pre_processing_image(image_url):
|
||||||
|
image = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url[image_url.find('/') + 1:], data_type="PIL")
|
||||||
|
|
||||||
|
# resize 图片内尺寸 并贴到768-512的纯白图像上
|
||||||
|
target_height = 768
|
||||||
|
target_width = 512
|
||||||
|
aspect_ratio = image.width / image.height
|
||||||
|
new_width = int(target_height * aspect_ratio)
|
||||||
|
resized_image = image.resize((new_width, target_height))
|
||||||
|
left = (target_width - resized_image.width) // 2
|
||||||
|
top = (target_height - resized_image.height) // 2
|
||||||
|
right = target_width - resized_image.width - left
|
||||||
|
bottom = target_height - resized_image.height - top
|
||||||
|
image = ImageOps.expand(resized_image, (left, top, right, bottom), fill="white")
|
||||||
|
image_size = image.size
|
||||||
|
if image.mode in ('RGBA', 'LA') or (image.mode == 'P' and 'transparency' in image.info):
|
||||||
|
# 创建白色背景
|
||||||
|
background = Image.new("RGB", image.size, (255, 255, 255))
|
||||||
|
# 将图片粘贴到白色背景上
|
||||||
|
background.paste(image, mask=image.split()[3])
|
||||||
|
image = np.array(background)
|
||||||
|
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||||
|
return image, image_size
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
rd = GenerateImageModel(
|
rd = GenerateProductImageModel(
|
||||||
tasks_id="123-89",
|
tasks_id="123-89",
|
||||||
prompt="best quality, masterpiece. detailed, high-res, simple background, studio photography, extremely detailed, updo, detailed face, face, close-up, HDR, UHD, 8K realistic, Highly detailed, simple background, Studio lighting",
|
prompt="",
|
||||||
image_url="aida-results/result_067f2f7e-21ba-11ef-8cf5-0242ac170002.png",
|
# prompt=" the best quality, masterpiece. detailed, high-res, simple background, studio photography, extremely detailed, updo, detailed face, face, close-up, HDR, UHD, 8K realistic, Highly detailed, simple background, Studio lighting",
|
||||||
|
image_url="aida-results/result_00097282-ebb2-11ee-a822-b48351119060.png",
|
||||||
)
|
)
|
||||||
server = GenerateProductImage(rd)
|
server = GenerateProductImage(rd)
|
||||||
print(server.get_result())
|
print(server.get_result())
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ from tritonclient.utils import np_to_triton_dtype
|
|||||||
from app.core.config import *
|
from app.core.config import *
|
||||||
from app.schemas.generate_image import GenerateRelightImageModel
|
from app.schemas.generate_image import GenerateRelightImageModel
|
||||||
from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image
|
from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image
|
||||||
|
from app.service.utils.oss_client import oss_get_image
|
||||||
|
|
||||||
logger = logging.getLogger()
|
logger = logging.getLogger()
|
||||||
|
|
||||||
@@ -31,71 +32,34 @@ class GenerateRelightImage:
|
|||||||
if DEBUG is False:
|
if DEBUG is False:
|
||||||
self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
|
self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
|
||||||
self.channel = self.connection.channel()
|
self.channel = self.connection.channel()
|
||||||
self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
# self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||||
self.grpc_client = grpcclient.InferenceServerClient(url=GRI_MODEL_URL)
|
self.grpc_client = grpcclient.InferenceServerClient(url=GRI_MODEL_URL)
|
||||||
self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
|
self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
|
||||||
self.category = "relight_image"
|
self.category = "relight_image"
|
||||||
self.batch_size = 1
|
self.batch_size = 1
|
||||||
self.prompt = request_data.prompt
|
self.prompt = request_data.prompt
|
||||||
self.seed = "12345"
|
self.seed = "1"
|
||||||
# TODO aida design 结果图背景改为白色
|
self.negative_prompt = 'lowres, bad anatomy, bad hands, cropped, worst quality'
|
||||||
# self.image, self.image_size = self.get_image(request_data.image_url)
|
self.direction = "Right Light"
|
||||||
self.image = request_data.image_url
|
self.image_url = request_data.image_url
|
||||||
# TODO image 填充并resize成512*768
|
self.image = oss_get_image(bucket=self.image_url.split('/')[0], object_name=self.image_url[self.image_url.find('/') + 1:], data_type="cv2")
|
||||||
|
|
||||||
self.tasks_id = request_data.tasks_id
|
self.tasks_id = request_data.tasks_id
|
||||||
self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
|
self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
|
||||||
self.gen_product_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': ''}
|
self.gen_product_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': ''}
|
||||||
self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data))
|
self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data))
|
||||||
self.redis_client.expire(self.tasks_id, 600)
|
self.redis_client.expire(self.tasks_id, 600)
|
||||||
|
|
||||||
def get_image(self, image_url):
|
|
||||||
response = self.minio_client.get_object(image_url.split('/')[0], image_url[image_url.find('/') + 1:])
|
|
||||||
image_bytes = io.BytesIO(response.read())
|
|
||||||
|
|
||||||
# 转换为PIL图像对象
|
|
||||||
image = Image.open(image_bytes)
|
|
||||||
target_height = 768
|
|
||||||
target_width = 512
|
|
||||||
|
|
||||||
aspect_ratio = image.width / image.height
|
|
||||||
new_width = int(target_height * aspect_ratio)
|
|
||||||
|
|
||||||
resized_image = image.resize((new_width, target_height))
|
|
||||||
left = (target_width - resized_image.width) // 2
|
|
||||||
top = (target_height - resized_image.height) // 2
|
|
||||||
right = target_width - resized_image.width - left
|
|
||||||
bottom = target_height - resized_image.height - top
|
|
||||||
image = ImageOps.expand(resized_image, (left, top, right, bottom), fill="white")
|
|
||||||
image_size = image.size
|
|
||||||
if image.mode in ('RGBA', 'LA') or (image.mode == 'P' and 'transparency' in image.info):
|
|
||||||
# 创建白色背景
|
|
||||||
background = Image.new("RGB", image.size, (255, 255, 255))
|
|
||||||
# 将图片粘贴到白色背景上
|
|
||||||
background.paste(image, mask=image.split()[3])
|
|
||||||
image = np.array(background)
|
|
||||||
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
|
||||||
|
|
||||||
# image_file = BytesIO(response.data)
|
|
||||||
# image_array = np.asarray(bytearray(image_file.read()), dtype=np.uint8)
|
|
||||||
# image_cv2 = cv2.imdecode(image_array, cv2.IMREAD_COLOR)
|
|
||||||
# image = cv2.cvtColor(image_cv2, cv2.COLOR_BGR2RGB)
|
|
||||||
# image = cv2.resize(image_rbg, (1024, 1024))
|
|
||||||
return image, image_size
|
|
||||||
|
|
||||||
def callback(self, result, error):
|
def callback(self, result, error):
|
||||||
if error:
|
if error:
|
||||||
self.gen_product_data['status'] = "FAILURE"
|
self.gen_product_data['status'] = "FAILURE"
|
||||||
self.gen_product_data['message'] = str(error)
|
self.gen_product_data['message'] = str(error)
|
||||||
# self.gen_product_data['data'] = str(error)
|
|
||||||
self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data))
|
self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data))
|
||||||
else:
|
else:
|
||||||
# pil图像转成numpy数组
|
# pil图像转成numpy数组
|
||||||
image = result.as_numpy("generated_inpaint_image")
|
image = result.as_numpy("generated_inpaint_image")
|
||||||
image_result = Image.fromarray(np.squeeze(image.astype(np.uint8))).resize(self.image_size)
|
image_result = Image.fromarray(np.squeeze(image.astype(np.uint8)))
|
||||||
|
|
||||||
image_url = upload_SDXL_image(image_result, user_id=self.user_id, category=f"{self.category}", object_name=f"{self.tasks_id}.png")
|
image_url = upload_SDXL_image(image_result, user_id=self.user_id, category=f"{self.category}", file_name=f"{self.tasks_id}.png")
|
||||||
# logger.info(f"upload image SUCCESS : {image_url}")
|
|
||||||
self.gen_product_data['status'] = "SUCCESS"
|
self.gen_product_data['status'] = "SUCCESS"
|
||||||
self.gen_product_data['message'] = "success"
|
self.gen_product_data['message'] = "success"
|
||||||
self.gen_product_data['image_url'] = str(image_url)
|
self.gen_product_data['image_url'] = str(image_url)
|
||||||
@@ -105,62 +69,40 @@ class GenerateRelightImage:
|
|||||||
status_data = self.redis_client.get(self.tasks_id)
|
status_data = self.redis_client.get(self.tasks_id)
|
||||||
return json.loads(status_data), status_data
|
return json.loads(status_data), status_data
|
||||||
|
|
||||||
def infer(self, inputs):
|
|
||||||
return self.grpc_client.async_infer(
|
|
||||||
model_name=GRI_MODEL_NAME,
|
|
||||||
inputs=inputs,
|
|
||||||
callback=self.callback
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_result(self):
|
def get_result(self):
|
||||||
try:
|
try:
|
||||||
direction = "Right Light"
|
|
||||||
negative_prompt = 'lowres, bad anatomy, bad hands, cropped, worst quality'
|
|
||||||
self.prompt = 'beautiful woman, detailed face, sunshine, outdoor, warm atmosphere'
|
|
||||||
prompts = [self.prompt] * self.batch_size
|
prompts = [self.prompt] * self.batch_size
|
||||||
text_obj = np.array(prompts, dtype="object").reshape((-1, 1))
|
image = cv2.cvtColor(self.image, cv2.COLOR_BGR2RGB)
|
||||||
input_text = grpcclient.InferInput(
|
image = cv2.resize(image, (512, 768))
|
||||||
"prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype)
|
images = [image.astype(np.uint8)] * self.batch_size
|
||||||
)
|
seeds = [self.seed] * self.batch_size
|
||||||
|
nagetive_prompts = [self.negative_prompt] * self.batch_size
|
||||||
|
directions = [self.direction] * self.batch_size
|
||||||
|
|
||||||
|
text_obj = np.array(prompts, dtype="object").reshape((1))
|
||||||
|
image_obj = np.array(images, dtype=np.uint8).reshape((768, 512, 3))
|
||||||
|
na_text_obj = np.array(nagetive_prompts, dtype="object").reshape((1))
|
||||||
|
seed_obj = np.array(seeds, dtype="object").reshape((1))
|
||||||
|
direction_obj = np.array(directions, dtype="object").reshape((1))
|
||||||
|
|
||||||
|
input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype))
|
||||||
|
input_image = grpcclient.InferInput("input_image", image_obj.shape, "UINT8")
|
||||||
|
input_natext = grpcclient.InferInput("negative_prompt", na_text_obj.shape, np_to_triton_dtype(na_text_obj.dtype))
|
||||||
|
input_seed = grpcclient.InferInput("seed", seed_obj.shape, np_to_triton_dtype(seed_obj.dtype))
|
||||||
|
input_direction = grpcclient.InferInput("direction", direction_obj.shape, np_to_triton_dtype(direction_obj.dtype))
|
||||||
|
|
||||||
input_text.set_data_from_numpy(text_obj)
|
input_text.set_data_from_numpy(text_obj)
|
||||||
|
input_image.set_data_from_numpy(image_obj)
|
||||||
|
input_natext.set_data_from_numpy(na_text_obj)
|
||||||
|
input_seed.set_data_from_numpy(seed_obj)
|
||||||
|
input_direction.set_data_from_numpy(direction_obj)
|
||||||
|
|
||||||
negative_prompts = [negative_prompt] * self.batch_size
|
inputs = [input_text, input_natext, input_image, input_seed, input_direction]
|
||||||
text_obj_neg = np.array(negative_prompts, dtype="object").reshape((-1, 1))
|
|
||||||
input_text_neg = grpcclient.InferInput(
|
|
||||||
"negative_prompt", text_obj_neg.shape, np_to_triton_dtype(text_obj_neg.dtype)
|
|
||||||
)
|
|
||||||
input_text_neg.set_data_from_numpy(text_obj_neg)
|
|
||||||
|
|
||||||
seed = np.array(self.seed, dtype="object").reshape((-1, 1))
|
ctx = self.grpc_client.async_infer(model_name=GRI_MODEL_NAME, inputs=inputs, callback=self.callback)
|
||||||
input_seed = grpcclient.InferInput(
|
|
||||||
"seed", seed.shape, np_to_triton_dtype(seed.dtype)
|
|
||||||
)
|
|
||||||
input_seed.set_data_from_numpy(seed)
|
|
||||||
|
|
||||||
input_images = [self.image] * self.batch_size
|
|
||||||
text_obj_images = np.array(input_images, dtype="object").reshape((-1, 1))
|
|
||||||
input_input_images = grpcclient.InferInput(
|
|
||||||
"input_image", text_obj_images.shape, np_to_triton_dtype(text_obj_images.dtype)
|
|
||||||
)
|
|
||||||
input_input_images.set_data_from_numpy(text_obj_images)
|
|
||||||
|
|
||||||
directions = [direction] * self.batch_size
|
|
||||||
text_obj_directions = np.array(directions, dtype="object").reshape((-1, 1))
|
|
||||||
input_directions = grpcclient.InferInput(
|
|
||||||
"direction", text_obj_directions.shape, np_to_triton_dtype(text_obj_directions.dtype)
|
|
||||||
)
|
|
||||||
input_directions.set_data_from_numpy(text_obj_directions)
|
|
||||||
|
|
||||||
output_img = grpcclient.InferRequestedOutput("generated_image")
|
|
||||||
request_start = time.time()
|
|
||||||
|
|
||||||
inputs = [input_text, input_text_neg, input_input_images, input_seed, input_directions]
|
|
||||||
|
|
||||||
ctx = self.infer(inputs)
|
|
||||||
time_out = 600
|
time_out = 600
|
||||||
while time_out > 0:
|
while time_out > 0:
|
||||||
gen_product_data, _ = self.read_tasks_status()
|
gen_product_data, _ = self.read_tasks_status()
|
||||||
# logger.info(gen_product_data)
|
|
||||||
if gen_product_data['status'] in ["REVOKED", "FAILURE"]:
|
if gen_product_data['status'] in ["REVOKED", "FAILURE"]:
|
||||||
ctx.cancel()
|
ctx.cancel()
|
||||||
break
|
break
|
||||||
@@ -168,7 +110,6 @@ class GenerateRelightImage:
|
|||||||
break
|
break
|
||||||
time_out -= 1
|
time_out -= 1
|
||||||
time.sleep(0.1)
|
time.sleep(0.1)
|
||||||
# logger.info(time_out, gen_product_data)
|
|
||||||
gen_product_data, _ = self.read_tasks_status()
|
gen_product_data, _ = self.read_tasks_status()
|
||||||
return gen_product_data
|
return gen_product_data
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -179,9 +120,8 @@ class GenerateRelightImage:
|
|||||||
finally:
|
finally:
|
||||||
dict_gen_product_data, str_gen_product_data = self.read_tasks_status()
|
dict_gen_product_data, str_gen_product_data = self.read_tasks_status()
|
||||||
if DEBUG is False:
|
if DEBUG is False:
|
||||||
self.channel.basic_publish(exchange='', routing_key=GPI_RABBITMQ_QUEUES, body=str_gen_product_data)
|
self.channel.basic_publish(exchange='', routing_key=GRI_RABBITMQ_QUEUES, body=str_gen_product_data)
|
||||||
# self.channel.basic_publish(exchange='', routing_key=GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES, body=str_gen_product_data)
|
logger.info(f" [x] Sent to: {GRI_RABBITMQ_QUEUES} data:@@@@ {json.dumps(dict_gen_product_data, indent=4)}")
|
||||||
logger.info(f" [x] Sent to: {GPI_RABBITMQ_QUEUES} data:@@@@ {json.dumps(dict_gen_product_data, indent=4)}")
|
|
||||||
|
|
||||||
|
|
||||||
def infer_cancel(tasks_id):
|
def infer_cancel(tasks_id):
|
||||||
@@ -195,8 +135,9 @@ def infer_cancel(tasks_id):
|
|||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
rd = GenerateRelightImageModel(
|
rd = GenerateRelightImageModel(
|
||||||
tasks_id="123-89",
|
tasks_id="123-89",
|
||||||
prompt="beautiful woman, detailed face, sunshine, outdoor, warm atmosphere",
|
# prompt="beautiful woman, detailed face, sunshine, outdoor, warm atmosphere",
|
||||||
image_url="/workspace/i3.png",
|
prompt="Colorful black",
|
||||||
|
image_url='aida-users/89/product_image/123-89.png'
|
||||||
)
|
)
|
||||||
server = GenerateRelightImage(rd)
|
server = GenerateRelightImage(rd)
|
||||||
print(server.get_result())
|
print(server.get_result())
|
||||||
|
|||||||
@@ -31,8 +31,6 @@ class GenerateSingleLogoImage:
|
|||||||
if DEBUG is False:
|
if DEBUG is False:
|
||||||
self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
|
self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
|
||||||
self.channel = self.connection.channel()
|
self.channel = self.connection.channel()
|
||||||
# self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
|
|
||||||
# self.channel = self.connection.channel()
|
|
||||||
self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||||
self.grpc_client = grpcclient.InferenceServerClient(url=GSL_MODEL_URL)
|
self.grpc_client = grpcclient.InferenceServerClient(url=GSL_MODEL_URL)
|
||||||
self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
|
self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
|
||||||
@@ -51,23 +49,15 @@ class GenerateSingleLogoImage:
|
|||||||
status_data = self.redis_client.get(self.tasks_id)
|
status_data = self.redis_client.get(self.tasks_id)
|
||||||
return json.loads(status_data), status_data
|
return json.loads(status_data), status_data
|
||||||
|
|
||||||
def infer(self, inputs):
|
|
||||||
return self.grpc_client.async_infer(
|
|
||||||
model_name=GSL_MODEL_NAME,
|
|
||||||
inputs=inputs,
|
|
||||||
callback=self.callback
|
|
||||||
)
|
|
||||||
|
|
||||||
def callback(self, result, error):
|
def callback(self, result, error):
|
||||||
if error:
|
if error:
|
||||||
self.gen_single_logo_data['status'] = "FAILURE"
|
self.gen_single_logo_data['status'] = "FAILURE"
|
||||||
self.gen_single_logo_data['message'] = str(error)
|
self.gen_single_logo_data['message'] = str(error)
|
||||||
# self.generate_data['data'] = str(error)
|
|
||||||
self.redis_client.set(self.tasks_id, json.dumps(self.gen_single_logo_data))
|
self.redis_client.set(self.tasks_id, json.dumps(self.gen_single_logo_data))
|
||||||
else:
|
else:
|
||||||
image = result.as_numpy("generated_image")
|
image = result.as_numpy("generated_image")
|
||||||
image_result = Image.fromarray(np.squeeze(image.astype(np.uint8)))
|
image_result = Image.fromarray(np.squeeze(image.astype(np.uint8)))
|
||||||
image_url = upload_SDXL_image(image_result, user_id=self.user_id, category=f"{self.category}", object_name=f"{self.tasks_id}.png")
|
image_url = upload_SDXL_image(image_result, user_id=self.user_id, category=f"{self.category}", file_name=f"{self.tasks_id}.png")
|
||||||
self.gen_single_logo_data['status'] = "SUCCESS"
|
self.gen_single_logo_data['status'] = "SUCCESS"
|
||||||
self.gen_single_logo_data['message'] = "success"
|
self.gen_single_logo_data['message'] = "success"
|
||||||
self.gen_single_logo_data['image_url'] = str(image_url)
|
self.gen_single_logo_data['image_url'] = str(image_url)
|
||||||
@@ -81,25 +71,19 @@ class GenerateSingleLogoImage:
|
|||||||
input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype))
|
input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype))
|
||||||
input_text.set_data_from_numpy(text_obj)
|
input_text.set_data_from_numpy(text_obj)
|
||||||
|
|
||||||
# negative_prompts
|
|
||||||
text_obj_neg = np.array(self.negative_prompts, dtype="object").reshape((-1, 1))
|
text_obj_neg = np.array(self.negative_prompts, dtype="object").reshape((-1, 1))
|
||||||
# print('text obj neg: ', text_obj_neg)
|
|
||||||
input_text_neg = grpcclient.InferInput("negative_prompt", text_obj_neg.shape, np_to_triton_dtype(text_obj_neg.dtype))
|
input_text_neg = grpcclient.InferInput("negative_prompt", text_obj_neg.shape, np_to_triton_dtype(text_obj_neg.dtype))
|
||||||
input_text_neg.set_data_from_numpy(text_obj_neg)
|
input_text_neg.set_data_from_numpy(text_obj_neg)
|
||||||
|
|
||||||
# seed
|
|
||||||
seed = np.array(self.seed, dtype="object").reshape((-1, 1))
|
seed = np.array(self.seed, dtype="object").reshape((-1, 1))
|
||||||
input_seed = grpcclient.InferInput("seed", seed.shape, np_to_triton_dtype(seed.dtype))
|
input_seed = grpcclient.InferInput("seed", seed.shape, np_to_triton_dtype(seed.dtype))
|
||||||
input_seed.set_data_from_numpy(seed)
|
input_seed.set_data_from_numpy(seed)
|
||||||
|
|
||||||
inputs = [input_text, input_text_neg, input_seed]
|
inputs = [input_text, input_text_neg, input_seed]
|
||||||
|
ctx = self.grpc_client.async_infer(model_name=GSL_MODEL_NAME, inputs=inputs, callback=self.callback)
|
||||||
ctx = self.infer(inputs)
|
|
||||||
time_out = 600
|
time_out = 600
|
||||||
generate_data = None
|
generate_data = None
|
||||||
while time_out > 0:
|
while time_out > 0:
|
||||||
generate_data, _ = self.read_tasks_status()
|
generate_data, _ = self.read_tasks_status()
|
||||||
# logger.info(generate_data)
|
|
||||||
if generate_data['status'] in ["REVOKED", "FAILURE"]:
|
if generate_data['status'] in ["REVOKED", "FAILURE"]:
|
||||||
ctx.cancel()
|
ctx.cancel()
|
||||||
break
|
break
|
||||||
@@ -107,7 +91,6 @@ class GenerateSingleLogoImage:
|
|||||||
break
|
break
|
||||||
time_out -= 1
|
time_out -= 1
|
||||||
time.sleep(0.1)
|
time.sleep(0.1)
|
||||||
# logger.info(time_out, generate_data)
|
|
||||||
return generate_data
|
return generate_data
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise Exception(str(e))
|
raise Exception(str(e))
|
||||||
@@ -115,7 +98,6 @@ class GenerateSingleLogoImage:
|
|||||||
dict_generate_data, str_generate_data = self.read_tasks_status()
|
dict_generate_data, str_generate_data = self.read_tasks_status()
|
||||||
if DEBUG is False:
|
if DEBUG is False:
|
||||||
self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=str_generate_data)
|
self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=str_generate_data)
|
||||||
# self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=str_generate_data)
|
|
||||||
logger.info(f" [x] Sent {json.dumps(dict_generate_data, indent=4)}")
|
logger.info(f" [x] Sent {json.dumps(dict_generate_data, indent=4)}")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -16,8 +16,11 @@ from PIL import Image
|
|||||||
from minio import Minio
|
from minio import Minio
|
||||||
|
|
||||||
from app.core.config import *
|
from app.core.config import *
|
||||||
|
from app.service.utils.oss_client import oss_upload_image
|
||||||
|
|
||||||
minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||||
|
|
||||||
|
|
||||||
# s3 = boto3.client('s3', aws_access_key_id=S3_ACCESS_KEY, aws_secret_access_key=S3_AWS_SECRET_ACCESS_KEY, region_name=S3_REGION_NAME)
|
# s3 = boto3.client('s3', aws_access_key_id=S3_ACCESS_KEY, aws_secret_access_key=S3_AWS_SECRET_ACCESS_KEY, region_name=S3_REGION_NAME)
|
||||||
|
|
||||||
|
|
||||||
@@ -34,36 +37,34 @@ minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET
|
|||||||
# except Exception as e:
|
# except Exception as e:
|
||||||
# print(f'上传到 S3 失败: {e}')
|
# print(f'上传到 S3 失败: {e}')
|
||||||
|
|
||||||
def upload_SDXL_image(image, user_id, category, object_name):
|
def upload_SDXL_image(image, user_id, category, file_name):
|
||||||
try:
|
try:
|
||||||
image_data = io.BytesIO()
|
image_data = io.BytesIO()
|
||||||
image.save(image_data, format='PNG')
|
image.save(image_data, format='PNG')
|
||||||
image_data.seek(0)
|
image_data.seek(0)
|
||||||
image_bytes = image_data.read()
|
image_bytes = image_data.read()
|
||||||
minio_req = minio_client.put_object(
|
|
||||||
GI_MINIO_BUCKET,
|
# minio_req = minio_client.put_object(
|
||||||
f'{user_id}/{category}/{object_name}',
|
# GI_MINIO_BUCKET,
|
||||||
io.BytesIO(image_bytes),
|
# f'{user_id}/{category}/{file_name}',
|
||||||
len(image_bytes),
|
# io.BytesIO(image_bytes),
|
||||||
content_type='image/jpeg'
|
# len(image_bytes),
|
||||||
)
|
# content_type='image/jpeg'
|
||||||
image_url = f"aida-users/{minio_req.object_name}"
|
# )
|
||||||
|
object_name = f'{user_id}/{category}/{file_name}'
|
||||||
|
req = oss_upload_image(bucket=GI_MINIO_BUCKET, object_name=object_name, image_bytes=image_bytes)
|
||||||
|
image_url = f"aida-users/{object_name}"
|
||||||
return image_url
|
return image_url
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.warning(f"upload_png_mask runtime exception : {e}")
|
logging.warning(f"upload_png_mask runtime exception : {e}")
|
||||||
|
|
||||||
|
|
||||||
def upload_png_sd(image, user_id, category, object_name):
|
def upload_png_sd(image, user_id, category, file_name):
|
||||||
try:
|
try:
|
||||||
_, img_byte_array = cv2.imencode('.jpg', image)
|
_, img_byte_array = cv2.imencode('.jpg', image)
|
||||||
minio_req = minio_client.put_object(
|
object_name = f'{user_id}/{category}/{file_name}'
|
||||||
GI_MINIO_BUCKET,
|
req = oss_upload_image(bucket=GI_MINIO_BUCKET, object_name=object_name, image_bytes=img_byte_array)
|
||||||
f'{user_id}/{category}/{object_name}',
|
image_url = f"aida-users/{object_name}"
|
||||||
io.BytesIO(img_byte_array),
|
|
||||||
len(img_byte_array),
|
|
||||||
content_type='image/jpeg'
|
|
||||||
)
|
|
||||||
image_url = f"aida-users/{minio_req.object_name}"
|
|
||||||
return image_url
|
return image_url
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.warning(f"upload_png_mask runtime exception : {e}")
|
logging.warning(f"upload_png_mask runtime exception : {e}")
|
||||||
|
|||||||
@@ -1,17 +1,15 @@
|
|||||||
import io
|
import json
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
import minio.error
|
|
||||||
import redis
|
|
||||||
import json
|
|
||||||
import cv2
|
import cv2
|
||||||
|
import minio.error
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import redis
|
||||||
import torch
|
import torch
|
||||||
import tritonclient.grpc as grpcclient
|
import tritonclient.grpc as grpcclient
|
||||||
from minio import Minio
|
|
||||||
from app.core.config import *
|
from app.core.config import *
|
||||||
from app.schemas.super_resolution import SuperResolutionModel
|
from app.schemas.super_resolution import SuperResolutionModel
|
||||||
from app.service.utils.decorator import RunTime
|
from app.service.utils.oss_client import oss_get_image, oss_upload_image
|
||||||
|
|
||||||
logger = logging.getLogger()
|
logger = logging.getLogger()
|
||||||
|
|
||||||
@@ -24,7 +22,7 @@ class SuperResolution:
|
|||||||
self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
|
self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
|
||||||
self.sr_image_url = data.sr_image_url
|
self.sr_image_url = data.sr_image_url
|
||||||
self.sr_xn = data.sr_xn
|
self.sr_xn = data.sr_xn
|
||||||
self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
# self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||||
self.redis_client.set(self.tasks_id, json.dumps({'status': 'PENDING', 'message': "pending", 'data': ''}))
|
self.redis_client.set(self.tasks_id, json.dumps({'status': 'PENDING', 'message': "pending", 'data': ''}))
|
||||||
self.redis_client.expire(self.tasks_id, 600)
|
self.redis_client.expire(self.tasks_id, 600)
|
||||||
self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
|
self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
|
||||||
@@ -33,16 +31,25 @@ class SuperResolution:
|
|||||||
# @RunTime
|
# @RunTime
|
||||||
def read_image(self):
|
def read_image(self):
|
||||||
try:
|
try:
|
||||||
image_data = self.minio_client.get_object(self.sr_image_url.split("/", 1)[0], self.sr_image_url.split("/", 1)[1])
|
img = oss_get_image(bucket=self.sr_image_url.split("/", 1)[0], object_name=self.sr_image_url.split("/", 1)[1], data_type="cv2")
|
||||||
except minio.error.S3Error as e:
|
except minio.error.S3Error as e:
|
||||||
sr_data = json.dumps({'tasks_id': self.tasks_id, 'status': 'ERROR', 'message': f'{e}'})
|
sr_data = json.dumps({'tasks_id': self.tasks_id, 'status': 'ERROR', 'message': f'{e}'})
|
||||||
self.channel.basic_publish(exchange='', routing_key=SR_RABBITMQ_QUEUES, body=sr_data)
|
self.channel.basic_publish(exchange='', routing_key=SR_RABBITMQ_QUEUES, body=sr_data)
|
||||||
logger.info(f" [x] Sent {sr_data}")
|
logger.info(f" [x] Sent {sr_data}")
|
||||||
raise FileNotFoundError(f"Image '{self.sr_image_url.split('/', 1)[1]}' not found in bucket '{self.sr_image_url.split('/', 1)[0]}'")
|
raise FileNotFoundError(f"Image '{self.sr_image_url.split('/', 1)[1]}' not found in bucket '{self.sr_image_url.split('/', 1)[0]}'")
|
||||||
img = np.frombuffer(image_data.data, np.uint8) # 转成8位无符号整型
|
|
||||||
img = cv2.imdecode(img, cv2.IMREAD_COLOR).astype(np.float32) / 255. # 解码
|
|
||||||
return img
|
return img
|
||||||
|
|
||||||
|
# try:
|
||||||
|
# image_data = self.minio_client.get_object(self.sr_image_url.split("/", 1)[0], self.sr_image_url.split("/", 1)[1])
|
||||||
|
# except minio.error.S3Error as e:
|
||||||
|
# sr_data = json.dumps({'tasks_id': self.tasks_id, 'status': 'ERROR', 'message': f'{e}'})
|
||||||
|
# self.channel.basic_publish(exchange='', routing_key=SR_RABBITMQ_QUEUES, body=sr_data)
|
||||||
|
# logger.info(f" [x] Sent {sr_data}")
|
||||||
|
# raise FileNotFoundError(f"Image '{self.sr_image_url.split('/', 1)[1]}' not found in bucket '{self.sr_image_url.split('/', 1)[0]}'")
|
||||||
|
# img = np.frombuffer(image_data.data, np.uint8) # 转成8位无符号整型
|
||||||
|
# img = cv2.imdecode(img, cv2.IMREAD_COLOR).astype(np.float32) / 255. # 解码
|
||||||
|
# return img
|
||||||
|
|
||||||
def read_tasks_status(self):
|
def read_tasks_status(self):
|
||||||
status_data = json.loads(self.redis_client.get(self.tasks_id))
|
status_data = json.loads(self.redis_client.get(self.tasks_id))
|
||||||
logging.info(f"{self.tasks_id} ===> {status_data}")
|
logging.info(f"{self.tasks_id} ===> {status_data}")
|
||||||
@@ -101,8 +108,10 @@ class SuperResolution:
|
|||||||
def upload_img_sr(self, image):
|
def upload_img_sr(self, image):
|
||||||
try:
|
try:
|
||||||
image_bytes = cv2.imencode('.jpg', image)[1].tobytes()
|
image_bytes = cv2.imencode('.jpg', image)[1].tobytes()
|
||||||
res = self.minio_client.put_object(f'{SR_MINIO_BUCKET}', f'{self.user_id}/sr/output/{self.tasks_id}.jpg', io.BytesIO(image_bytes), len(image_bytes), content_type='image/png')
|
# res = self.minio_client.put_object(f'{SR_MINIO_BUCKET}', f'{self.user_id}/sr/output/{self.tasks_id}.jpg', io.BytesIO(image_bytes), len(image_bytes), content_type='image/png')
|
||||||
image_url = f"aida-users/{res.object_name}"
|
object_name = f'{self.user_id}/sr/output/{self.tasks_id}.jpg'
|
||||||
|
oss_upload_image(bucket=SR_MINIO_BUCKET, object_name=object_name, image_bytes=image_bytes)
|
||||||
|
image_url = f"{SR_MINIO_BUCKET}/{object_name}"
|
||||||
return image_url
|
return image_url
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"upload_png_mask runtime exception : {e}")
|
logger.warning(f"upload_png_mask runtime exception : {e}")
|
||||||
|
|||||||
70
app/service/utils/oss_client.py
Normal file
70
app/service/utils/oss_client.py
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
import io
|
||||||
|
import logging
|
||||||
|
from io import BytesIO
|
||||||
|
|
||||||
|
import boto3
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
from minio import Minio
|
||||||
|
|
||||||
|
from app.core.config import *
|
||||||
|
|
||||||
|
logger = logging.getLogger()
|
||||||
|
|
||||||
|
|
||||||
|
# 获取图片
|
||||||
|
def oss_get_image(bucket, object_name, data_type):
|
||||||
|
# cv2 默认全通道读取
|
||||||
|
image_object = None
|
||||||
|
try:
|
||||||
|
if OSS == "minio":
|
||||||
|
oss_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||||
|
image_data = oss_client.get_object(bucket_name=bucket, object_name=object_name)
|
||||||
|
else:
|
||||||
|
oss_client = boto3.client('s3', aws_access_key_id=S3_ACCESS_KEY, aws_secret_access_key=S3_AWS_SECRET_ACCESS_KEY, region_name=S3_REGION_NAME)
|
||||||
|
image_data = oss_client.get_object(Bucket=bucket, Key=object_name)['Body']
|
||||||
|
if data_type == "cv2":
|
||||||
|
image_bytes = image_data.read()
|
||||||
|
image_array = np.frombuffer(image_bytes, np.uint8) # 转成8位无符号整型
|
||||||
|
image_object = cv2.imdecode(image_array, cv2.IMREAD_UNCHANGED)
|
||||||
|
else:
|
||||||
|
data_bytes = BytesIO(image_data.read())
|
||||||
|
image_object = Image.open(data_bytes)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"{OSS} | 获取图片出现异常 ######: {e}")
|
||||||
|
return image_object
|
||||||
|
|
||||||
|
|
||||||
|
def oss_upload_image(bucket, object_name, image_bytes):
|
||||||
|
req = None
|
||||||
|
try:
|
||||||
|
if OSS == "minio":
|
||||||
|
oss_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||||
|
req = oss_client.put_object(bucket_name=bucket, object_name=object_name, data=io.BytesIO(image_bytes), length=len(image_bytes), content_type='image/png')
|
||||||
|
else:
|
||||||
|
oss_client = boto3.client('s3', aws_access_key_id=S3_ACCESS_KEY, aws_secret_access_key=S3_AWS_SECRET_ACCESS_KEY, region_name=S3_REGION_NAME)
|
||||||
|
req = oss_client.put_object(Bucket=AIDA_CLOTHING, Key=object_name, Body=image_bytes, ContentType='image/png')
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"{OSS} | 上传图片出现异常 ######: {e}")
|
||||||
|
return req
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
# url = "aida-results/result_0002186a-e631-11ee-86a6-b48351119060.png"
|
||||||
|
# url = "aida-collection-element/11523/Moodboard/f60af0d2-94c2-48f9-90ff-74b8e8a481b5.jpg"
|
||||||
|
# url = "aida-sys-image/images/female/outwear/0628000054.jpg"
|
||||||
|
# url = "aida-users/89/product_image/string-89.png"
|
||||||
|
url = "aida-users/89/single_logo/123-89.png"
|
||||||
|
# url = 'aida-users/89/relight_image/123-89.png'
|
||||||
|
# url = 'aida-users/89/relight_image/123-89.png'
|
||||||
|
# url = 'aida-users/89/relight_image/123-89.png'
|
||||||
|
# url = "aida-users/89/sketchboard/female/Dress/e6724ab7-8d3f-4677-abe0-c3e42ab7af85.jpeg"
|
||||||
|
read_type = "cv2"
|
||||||
|
if read_type == "cv2":
|
||||||
|
img = oss_get_image(bucket=url.split('/')[0], object_name=url[url.find('/') + 1:], data_type=read_type)
|
||||||
|
cv2.imshow("", img)
|
||||||
|
cv2.waitKey(0)
|
||||||
|
else:
|
||||||
|
img = oss_get_image(bucket=url.split('/')[0], object_name=url[url.find('/') + 1:], data_type=read_type)
|
||||||
|
img.show()
|
||||||
Reference in New Issue
Block a user