@@ -7,38 +7,42 @@
@Date : 2023/7/26 12:01:05
@detail :
"""
import io
import json
import logging
import time
from io import BytesIO
import cv2
import minio
import redis
import tritonclient . grpc as grpcclient
import numpy as np
from PIL import Image , ImageOps
from minio import Minio
from tritonclient . utils import np_to_triton_dtype
from app . core . config import *
from app . schemas . generate_image import GenerateImageModel
from app . service . generate_image . utils . adjust_contrast import adjust_contrast
from app . service . generate_image . utils . image_processing import remove_background , stain_detection , generate_category_recognition , autoLevels , luminance_adjust , face_detect_pic
from app . service . generate_image . utils . upload_sd_image import upload_png_sd , upload_stain_png_sd
from app . service . generate_image . utils . upload_sd_image import upload_SDXL_image
logger = logging . getLogger ( )
class GenerateProductImage :
def __init__ ( self , request_data ) :
# if DEBUG is False:
# self. connection = pika. BlockingConnection(pika. ConnectionParameters(** RABBITMQ_PARAMS) )
# self. channel = self. connection. channel( )
self. connection = pika . BlockingConnection( pika . ConnectionParameters( * * RABBITMQ_PARAMS) )
self. channel = self . connection. channel( )
if DEBUG is False :
self. connection = pika . BlockingConnection( pika . ConnectionParameters( * * RABBITMQ_PARAMS) )
self. channel = self . connection. channel( )
# self. connection = pika. BlockingConnection(pika. ConnectionParameters(** RABBITMQ_PARAMS) )
# self. channel = self. connection. channel( )
self . minio_client = Minio ( MINIO_URL , access_key = MINIO_ACCESS , secret_key = MINIO_SECRET , secure = MINIO_SECURE )
self . grpc_client = grpcclient . InferenceServerClient ( url = GI_MODEL_URL )
self . grpc_client = grpcclient . InferenceServerClient ( url = GP I_MODEL_URL )
self . redis_client = redis . StrictRedis ( host = REDIS_HOST , port = REDIS_PORT , db = REDIS_DB , decode_responses = True )
self . category = " product_image "
self . batch_size = 1
self . prompt = request_data . prompt
# TODO aida design 结果图背景改为白色
self . image , self . image_size = self . get_image ( request_data . image_url )
# TODO image 填充并resize成512*768
self . tasks_id = request_data . tasks_id
self . user_id = self . tasks_id [ self . tasks_id . rfind ( ' - ' ) + 1 : ]
self . gen_product_data = { ' tasks_id ' : self . tasks_id , ' status ' : ' PENDING ' , ' message ' : " pending " , ' image_url ' : ' ' }
@@ -46,63 +50,56 @@ class GenerateProductImage:
self . redis_client . expire ( self . tasks_id , 600 )
def get_image ( self , image_url ) :
# Get data of an object.
# Read data from response.
# read image use cv2
try :
response = self . minio_client . get_object ( image_url . split ( ' / ' ) [ 0 ] , image_url [ image_url . find ( ' / ' ) + 1 : ] )
image_file = BytesIO ( response . data )
image_array = np . asarray ( bytearray ( image_file . read ( ) ) , dtype = np . uint8 )
image_cv2 = cv2 . imdecode ( image_array , cv2 . IMREAD_COLOR )
image_rbg = cv2 . cvtColor ( image_cv2 , cv2 . COLOR_BGR2RGB )
image = cv2 . resize ( image_rbg , ( 1024 , 1024 ) )
except minio . error . S3Error :
image = np . random . randint ( 0 , 256 , ( 1024 , 1024 , 3 ) , dtype = np . uint8 )
return image
response = self . minio_client . get_object ( image_url . split ( ' / ' ) [ 0 ] , image_url [ image_url . find ( ' / ' ) + 1 : ] )
image_bytes = io . BytesIO ( response . read ( ) )
# 转换为PIL图像对象
image = Image . open ( image_bytes )
target_height = 768
target_width = 512
aspect_ratio = image . width / image . height
new_width = int ( target_height * aspect_ratio )
resized_ image = image . resize ( ( new_width , target_height ) )
left = ( target_width - resized_image . width ) / / 2
top = ( target_height - resized_image . height ) / / 2
right = target_width - resized_image . width - left
bottom = target_height - resized_image . height - top
image = ImageOps . expand ( resized_image , ( left , top , right , bottom ) , fill = " white " )
image_size = image . size
if image . mode in ( ' RGBA ' , ' LA ' ) or ( image . mode == ' P ' and ' transparency ' in image . info ) :
# 创建白色背景
background = Image . new ( " RGB " , image . size , ( 255 , 255 , 255 ) )
# 将图片粘贴到白色背景上
background . paste ( image , mask = image . split ( ) [ 3 ] )
image = np . array ( background )
image = cv2 . cvtColor ( image , cv2 . COLOR_BGR2RGB )
# image_file = BytesIO(response.data)
# image_array = np.asarray(bytearray(image_file.read()), dtype=np.uint8)
# image_cv2 = cv2.imdecode(image_array, cv2.IMREAD_COLOR)
# image = cv2.cvtColor(image_cv2, cv2.COLOR_BGR2RGB)
# image = cv2.resize(image_rbg, (1024, 1024))
return image , image_size
def callback ( self , result , error ) :
if error :
self . generate _data [ ' status ' ] = " FAILURE "
self . generate _data [ ' message ' ] = str ( error )
# self.generate _data['data'] = str(error)
self . redis_client . set ( self . tasks_id , json . dumps ( self . generate _data ) )
self . gen_product _data [ ' status ' ] = " FAILURE "
self . gen_product _data [ ' message ' ] = str ( error )
# self.gen_product _data['data'] = str(error)
self . redis_client . set ( self . tasks_id , json . dumps ( self . gen_product _data ) )
else :
# pil图像转成numpy数组
image = result . as_numpy ( " generated_image " )
image_result = cv2 . cvtColor ( np . squeeze ( image . astype ( np . uint8 ) ) , cv2 . COLOR_RGB2BGR )
is_smudge = True
if self . category == " sketch " :
# 色阶调整
cutoff = 1
levels_img = autoLevels ( image_result , cutoff )
# 亮度调整
luminance = luminance_adjust ( 0.3 , levels_img )
# 去背景
remove_bg_image = remove_background ( luminance )
# 人脸检测
if face_detect_pic ( remove_bg_image , self . user_id , self . category , self . tasks_id ) > 0 :
is_smudge = False
else :
# 污点/
is_smudge , not_smudge_image = stain_detection ( remove_bg_image , self . user_id , self . category , self . tasks_id )
# 类型识别
category , scores , not_smudge_image = generate_category_recognition ( image = remove_bg_image , gender = self . gender )
self . generate_data [ ' category ' ] = str ( category )
image_result = not_smudge_image
if is_smudge : # 无污点
# image_result = adjust_contrast(image_result)
image_url = upload_png_sd ( image_result , user_id = self . user_id , category = f " { self . category } " , object_name = f " { self . tasks_id } .png " )
# logger.info(f"upload image SUCCESS : {image_url}")
self . generate_data [ ' status ' ] = " SUCCESS "
self . generate_data [ ' message ' ] = " success "
self . generate_data [ ' image_url ' ] = str ( image_url )
self . redis_client . set ( self . tasks_id , json . dumps ( self . generate_data ) )
else : # 有污点 保存图片到本地 测试用
self . generate_data [ ' status ' ] = " SUCCESS "
self . generate_data [ ' message ' ] = " success "
self . generate_data [ ' image_url ' ] = str ( GI_SYS_IMAGE_URL )
self . redis_client . set ( self . tasks_id , json . dumps ( self . generate_data ) )
# logger.info(f"stain_detection result : {self.generate_data}")
image = result . as_numpy ( " generated_inpaint_ image " )
image_result = Image . fromarray ( np . squeeze ( image . astype ( np . uint8 ) ) ) . resize ( self . image_size )
image_url = upload_SDXL_image ( image_result , user_id = self . user_id , category = f " { self . category } " , object_name = f " { self . tasks_id } .png " )
# logger.info(f"upload image SUCCESS : {image_url}")
self . gen_product_data [ ' status ' ] = " SUCCESS "
self . gen_product_data [ ' message ' ] = " success "
self . gen_product_data [ ' image_url ' ] = str ( image_url )
self . redis_client . set ( self . tasks_id , json . dumps ( self . gen_product_data ) )
def read_tasks_status ( self ) :
status_data = self . redis_client . get ( self . tasks_id )
@@ -110,46 +107,43 @@ class GenerateProductImage:
def infer ( self , inputs ) :
return self . grpc_client . async_infer (
model_name = GI_MODEL_NAME ,
model_name = GP I_MODEL_NAME ,
inputs = inputs ,
callback = self . callback
)
def get_result ( self ) :
try :
# prompts = [self.prompt] * self. batch_size
# modes = [self.mode] * self.batch_size
# images = [self.image.astype(np.float16)] * self.batch_size
#
# text_obj = np.array(prompts, dtype="object").reshape((-1, 1))
# mode_obj = np.array(modes, dtype="object").reshape((-1, 1) )
# image_obj = np.array(images, dtype=np.float16).reshape((-1, 1024, 1024, 3) )
#
# input_text = grpcclient. InferInput("prompt", text_obj.shape, np_to_triton_dtype( text_obj.dtype) )
# input_image = grpcclient. InferInput(" input_image", image_obj.shape, "FP16" )
# input_mode = grpcclient.InferInput("mode", mode_obj.shape, np_to_triton_dtype(text_obj.dtype))
#
# input_text. set_data_from_numpy(text _obj)
# input_image.set_data_from_numpy(image_obj)
# input_mode.set_data_from_numpy(mode_obj)
#
# inputs = [input_text, input_image, input_mode]
# ctx = self.infer(inputs)
# time_out = 600
# generate_data = None
# while time_out > 0 :
# generate_data, _ = self.read_tasks_status( )
# # logger.info(generate_data)
# if generate_data['status'] in ["REVOKED", "FAILURE"] :
# ctx.cancel()
# break
# elif generate_data['status'] == "SUCCESS":
# break
# time_out -= 1
# time.sleep(0.1)
# # logger.info(time_out, generate_data)
generate_data , _ = self . read_tasks_status ( )
return generate_data
prompts = [ self . prompt ] * self . batch_size
self . image = cv2 . cvtColor ( self . image , cv2 . COLOR_BGR2RGB )
self . image = cv2 . resize ( self . image , ( 512 , 768 ) )
images = [ self . image . astype ( np . uint8 ) ] * self . batch_size
text_obj = np . array ( prompts , dtype = " object " ) . reshape ( 1 )
image_obj = np . array ( images , dtype = np . uint8 ) . reshape ( ( 768 , 512 , 3 ) )
input_text = grpcclient. InferInput( " prompt " , text_obj . shape , np_to_triton_dtype( text_obj. dtype ) )
input_image = grpcclient. InferInput( " input_image" , image_obj . shape , " UINT8 " )
input_text . set_data_from_numpy ( text_obj )
input_image . set_data_from_numpy( image _obj )
inputs = [ input_text , input_image ]
ctx = self . infer ( inputs )
time_out = 600
while time_out > 0 :
gen_product_data , _ = self . read_tasks_status ( )
# logger.info(gen_product_data)
if gen_product_data [ ' status ' ] in [ " REVOKED " , " FAILURE " ] :
ctx . cancel ( )
break
elif gen_product_data [ ' status ' ] == " SUCCESS " :
break
time_out - = 1
time . sleep ( 0.1 )
# logger.info(time_out, gen_product_data)
gen_product_data , _ = self . read_tasks_status ( )
return gen_product_data
except Exception as e :
self . gen_product_data [ ' status ' ] = " FAILURE "
self . gen_product_data [ ' message ' ] = str ( e )
@@ -157,25 +151,25 @@ class GenerateProductImage:
raise Exception ( str ( e ) )
finally :
dict_gen_product_data , str_gen_product_data = self . read_tasks_status ( )
# if DEBUG is False:
# self. channel. basic_publish( exchange='', routing_key= GI_RABBITMQ_QUEUES, body=str_generate _data)
self. channel. basic_publish( exchange= ' ' , routing_key= GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES, body = str_gen_product_data )
if DEBUG is False :
self. channel. basic_publish( exchange= ' ' , routing_key= GI_RABBITMQ_QUEUES, body = str_gen_product _data )
# self. channel. basic_publish( exchange='', routing_key= GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES, body= str_gen_product_data)
logger . info ( f " [x] Sent { json . dumps ( dict_gen_product_data , indent = 4 ) } " )
def infer_cancel ( tasks_id ) :
redis_client = redis . StrictRedis ( host = REDIS_HOST , port = REDIS_PORT , db = REDIS_DB , decode_responses = True )
data = { ' tasks_id ' : tasks_id , ' status ' : ' REVOKED ' , ' message ' : " revoked " , ' data ' : ' revoked ' }
generate _data = json . dumps ( data )
redis_client . set ( tasks_id , generate _data )
gen_product _data = json . dumps ( data )
redis_client . set ( tasks_id , gen_product _data )
return data
if __name__ == ' __main__ ' :
rd = GenerateImageModel (
tasks_id = " 123-89 " ,
prompt = ' skeleton sitting by the side of a river looking soulful, concert poster, 4k, artistic ' ,
image_url = " " ,
prompt = " best quality, masterpiece. detailed, high-res, simple background, studio photography, extremely detailed, updo, detailed face, face, close-up, HDR, UHD, 8K realistic, Highly detailed, simple background, Studio lighting " ,
image_url = " aida-results/result_067f2f7e-21ba-11ef-8cf5-0242ac170002.png ",
)
server = GenerateImage ( rd )
server = GenerateProduct Image ( rd )
print ( server . get_result ( ) )