2024-04-15 18:07:25 +08:00
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project : trinity_client
2024-04-16 15:51:03 +08:00
@File : service_att_recognition . py
2024-04-15 18:07:25 +08:00
@Author : 周成融
@Date : 2023 / 7 / 26 12 : 01 : 05
@detail :
"""
import json
import logging
2024-04-16 15:51:03 +08:00
import time
2024-06-25 16:58:17 +08:00
2024-04-16 16:36:17 +08:00
import cv2
import minio
2024-06-25 16:58:17 +08:00
import numpy as np
2024-04-15 18:07:25 +08:00
import redis
2024-04-16 15:51:03 +08:00
import tritonclient . grpc as grpcclient
from tritonclient . utils import np_to_triton_dtype
2024-06-25 16:58:17 +08:00
2024-04-15 18:07:25 +08:00
from app . core . config import *
from app . schemas . generate_image import GenerateImageModel
2024-06-25 16:58:17 +08:00
from app . service . generate_image . utils . image_processing import remove_background , stain_detection , generate_category_recognition , autoLevels , luminance_adjust
2025-06-24 16:58:05 +08:00
from app . service . generate_image . utils . mq import publish_status
2024-06-21 17:13:39 +08:00
from app . service . generate_image . utils . upload_sd_image import upload_png_sd
from app . service . utils . oss_client import oss_get_image
2024-04-15 18:07:25 +08:00
# Module-scoped logger (named after this module) instead of the root logger,
# so output can be filtered/configured per module; records still propagate to root handlers.
logger = logging.getLogger(__name__)
class GenerateImage:
    """Run one Triton image-generation request and track its status in Redis.

    The task status dict (``tasks_id``, ``status``, ``message``, ``image_url``,
    ``category``) is stored as JSON in Redis under ``tasks_id``. When the
    request finishes, fails or times out, the final status is sent via
    ``publish_status`` to ``GI_RABBITMQ_QUEUES`` unless ``DEBUG`` is set.
    """

    # Seconds a task-status key lives in Redis after each write.
    STATUS_TTL_SECONDS = 600
    # get_result polls Redis at most this many times, this far apart (~60 s total).
    MAX_STATUS_POLLS = 600
    POLL_INTERVAL_SECONDS = 0.1
    # Side length of the square 3-channel image sent to the model.
    IMAGE_SIZE = 1024

    def __init__(self, request_data):
        """Create the Triton/Redis clients, prepare the input image and register the task.

        :param request_data: request object exposing ``version``, ``mode``,
            ``image_url``, ``prompt``, ``tasks_id``, ``category`` and ``gender``
            (presumably a ``GenerateImageModel`` — confirm against callers).
        """
        self.version = request_data.version
        model_url = FAST_GI_MODEL_URL if request_data.version == "fast" else GI_MODEL_URL
        self.grpc_client = grpcclient.InferenceServerClient(url=model_url)
        self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)

        if request_data.mode == "img2img":
            # Start from the user's image (converted from cv2's BGR to RGB in get_image).
            self.image = self.get_image(request_data.image_url)
        else:
            # get_result always sends an input_image, so other modes feed random noise.
            self.image = np.random.randint(0, 256, (self.IMAGE_SIZE, self.IMAGE_SIZE, 3), dtype=np.uint8)

        self.prompt = request_data.prompt
        self.tasks_id = request_data.tasks_id
        # User id is the suffix after the last '-' of tasks_id (the whole id if there is none).
        self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
        self.mode = request_data.mode
        self.batch_size = 1
        self.category = request_data.category
        if self.category == "sketch":
            self.prompt = f"{self.category}, {self.prompt}"
        self.index = 0
        self.gender = request_data.gender
        self.generate_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': '', 'category': ''}
        self._save_status()

    def _save_status(self):
        """Write ``self.generate_data`` to Redis and (re)arm the key's expiry.

        A plain Redis ``SET`` discards any existing TTL, so every write must
        pass ``ex`` or the key would never expire.
        """
        self.redis_client.set(self.tasks_id, json.dumps(self.generate_data), ex=self.STATUS_TTL_SECONDS)

    def get_image(self, image_url):
        """Download ``image_url`` ("<bucket>/<object name>") and return it resized to 1024x1024, in RGB order.

        Falls back to random noise of the same shape when the object store
        raises ``minio.error.S3Error``.
        """
        try:
            # data_type="cv2" yields an OpenCV (BGR) image; the model is fed RGB.
            image_bgr = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url[image_url.find('/') + 1:], data_type="cv2")
            image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
            return cv2.resize(image_rgb, (self.IMAGE_SIZE, self.IMAGE_SIZE))
        except minio.error.S3Error:
            return np.random.randint(0, 256, (self.IMAGE_SIZE, self.IMAGE_SIZE, 3), dtype=np.uint8)

    def _postprocess(self, image_result):
        """Apply the category-specific clean-up and return ``(is_clean, image)``.

        Only ``sketch`` requests are post-processed. Output of the ``fast``
        model additionally gets levels/brightness correction, background
        removal and stain detection before category recognition. The
        recognised category is stored in ``generate_data['category']``.
        """
        is_clean = True
        if self.category != "sketch":
            return is_clean, image_result

        source = image_result
        if self.version == "fast":
            cutoff = 1
            levels_img = autoLevels(image_result, cutoff)
            luminance = luminance_adjust(0.3, levels_img)
            source = remove_background(luminance)
            # A truthy first value is treated as "no stain found" (upload the image).
            # NOTE(review): the image returned by stain_detection is discarded; the one
            # returned by generate_category_recognition below is what gets uploaded.
            is_clean, _ = stain_detection(source, self.user_id, self.category, self.tasks_id)

        category, _scores, processed = generate_category_recognition(image=source, gender=self.gender)
        self.generate_data['category'] = str(category)
        return is_clean, processed

    def callback(self, result, error):
        """Completion callback passed to ``async_infer``: post-process the image and record the outcome.

        tritonclient invokes this, not ``get_result``. Any exception escaping
        here would be lost and the task would stay PENDING until
        ``get_result`` times out. Failures are therefore written to the task
        status instead of propagating.
        """
        if error:
            self.generate_data['status'] = "FAILURE"
            self.generate_data['message'] = str(error)
            self._save_status()
            return

        try:
            # The model returns RGB; convert to OpenCV's BGR order for post-processing/upload.
            image = result.as_numpy("generated_image")
            image_result = cv2.cvtColor(np.squeeze(image.astype(np.uint8)), cv2.COLOR_RGB2BGR)
            is_clean, image_result = self._postprocess(image_result)
            if is_clean:
                image_url = upload_png_sd(image_result, user_id=self.user_id, category=f"{self.category}", file_name=f"{self.tasks_id}.png")
                self.generate_data['image_url'] = str(image_url)
            else:
                # Stain detected: still report success, but point at the fallback system image.
                self.generate_data['image_url'] = str(GI_SYS_IMAGE_URL)
            self.generate_data['status'] = "SUCCESS"
            self.generate_data['message'] = "success"
        except Exception as exc:
            logging.getLogger(__name__).exception("post-processing failed for task %s", self.tasks_id)
            self.generate_data['status'] = "FAILURE"
            self.generate_data['message'] = str(exc)
        self._save_status()

    def read_tasks_status(self):
        """Return ``(status_dict, status_json)`` for this task.

        Reads the JSON from Redis. If the key is missing (expired or evicted),
        falls back to the in-memory ``generate_data``, which ``callback``
        keeps in sync.
        """
        status_data = self.redis_client.get(self.tasks_id)
        if status_data is None:
            status_data = json.dumps(self.generate_data)
        return json.loads(status_data), status_data

    def get_result(self):
        """Submit the inference request and wait for it to finish.

        Polls the Redis status up to ``MAX_STATUS_POLLS`` times,
        ``POLL_INTERVAL_SECONDS`` apart. A ``REVOKED`` (set by
        ``infer_cancel``) or ``FAILURE`` status cancels the in-flight request.
        On timeout the last-read (still PENDING) status is returned. In every
        case the final status is published to ``GI_RABBITMQ_QUEUES`` unless
        ``DEBUG`` is set.

        :return: the last status dict read.
        :raises Exception: if building, submitting or polling fails. The task is
            first marked FAILURE, and the original error is chained as the cause.
        """
        try:
            prompts = [self.prompt] * self.batch_size
            modes = [self.mode] * self.batch_size
            images = [self.image.astype(np.float16)] * self.batch_size
            text_obj = np.array(prompts, dtype="object").reshape((-1, 1))
            mode_obj = np.array(modes, dtype="object").reshape((-1, 1))
            image_obj = np.array(images, dtype=np.float16).reshape((-1, self.IMAGE_SIZE, self.IMAGE_SIZE, 3))

            inputs = []
            for input_name, data in (("prompt", text_obj), ("input_image", image_obj), ("mode", mode_obj)):
                infer_input = grpcclient.InferInput(input_name, data.shape, np_to_triton_dtype(data.dtype))
                infer_input.set_data_from_numpy(data)
                inputs.append(infer_input)

            model_name = FAST_GI_MODEL_NAME if self.version == "fast" else GI_MODEL_NAME
            ctx = self.grpc_client.async_infer(model_name=model_name, inputs=inputs, callback=self.callback, priority=1)

            generate_data = None
            for _ in range(self.MAX_STATUS_POLLS):
                generate_data, _raw = self.read_tasks_status()
                if generate_data['status'] in ("REVOKED", "FAILURE"):
                    ctx.cancel()
                    break
                if generate_data['status'] == "SUCCESS":
                    break
                time.sleep(self.POLL_INTERVAL_SECONDS)
            return generate_data
        except Exception as e:
            self.generate_data['status'] = "FAILURE"
            self.generate_data['message'] = str(e)
            self._save_status()
            raise Exception(str(e)) from e
        finally:
            _status, raw_status = self.read_tasks_status()
            if not DEBUG:
                publish_status(raw_status, GI_RABBITMQ_QUEUES)
2024-04-15 18:07:25 +08:00
def infer_cancel(tasks_id):
    """Mark task ``tasks_id`` as revoked in Redis and return the status dict written.

    ``GenerateImage.get_result`` polls this key and cancels the in-flight
    inference when it reads ``REVOKED``. The key is written with a 600 s
    expiry, matching the one set when a task is registered. A plain ``SET``
    would otherwise strip any existing TTL and leave the key in Redis forever.
    """
    redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
    # NOTE(review): this payload carries 'data' instead of the 'image_url'/'category'
    # keys other status writers use — confirm consumers don't rely on those keys.
    data = {'tasks_id': tasks_id, 'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'}
    redis_client.set(tasks_id, json.dumps(data), ex=600)
    return data
if __name__ == '__main__':
    # Manual smoke run: one txt2img request against the non-fast ("high") model,
    # then print the final task status.
    demo_request = GenerateImageModel(
        tasks_id="123-89",
        prompt="Women's clothing ,dress,technical drawing style, clean line art, no shading, no texture, flat sketch, no human body, no face, centered composition, pure white background, single garmentsingle garment only, front flat view",
        image_url="aida-collection-element/87/Printboard/842c09cf-7297-42d9-9e6e-9c17d4a13cb5.jpg",
        mode='txt2img',
        category="test",
        gender="male",
        version="high",
    )
    generator = GenerateImage(demo_request)
    final_status = generator.get_result()
    print(final_status)