#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project :trinity_client
@File    :service_att_recognition.py
@Author  :周成融
@Date    :2023/7/26 12:01:05
@detail  :Submit image-generation requests to a Triton model over gRPC and
          track each task's status in Redis (and RabbitMQ outside DEBUG).
"""
import json
import logging
import time
from io import BytesIO

import cv2
import minio
import redis
import tritonclient.grpc as grpcclient
import numpy as np
from minio import Minio
from tritonclient.utils import np_to_triton_dtype

from app.core.config import *
from app.schemas.generate_image import GenerateImageModel
from app.service.generate_image.utils.adjust_contrast import adjust_contrast
from app.service.generate_image.utils.image_processing import (
    remove_background,
    stain_detection,
    generate_category_recognition,
    autoLevels,
    luminance_adjust,
    face_detect_pic,
)
from app.service.generate_image.utils.upload_sd_image import upload_png_sd, upload_stain_png_sd
from app.service.utils.oss_client import get_image

logger = logging.getLogger()


class GenerateImage:
    """Run one image-generation task against the Triton model.

    The constructor registers the task in Redis with status ``PENDING``.
    ``get_result`` submits the request asynchronously, polls Redis until the
    gRPC callback (or ``infer_cancel``) records a terminal status, and, when
    ``DEBUG`` is False, publishes the final status to RabbitMQ.
    """

    # Lifetime of the Redis status key, re-applied on every write because a
    # plain Redis SET discards the key's previous TTL.
    STATUS_TTL_SECONDS = 600
    # Polling budget used by get_result: MAX_POLLS * POLL_INTERVAL_SECONDS.
    MAX_POLLS = 600
    POLL_INTERVAL_SECONDS = 0.1
    # Edge length (pixels) of the square image sent to the model.
    IMAGE_SIZE = 1024

    def __init__(self, request_data):
        """Open the clients and register the task as ``PENDING`` in Redis.

        Args:
            request_data: Request object; this class reads its ``mode``,
                ``image_url``, ``prompt``, ``tasks_id``, ``category`` and
                ``gender`` attributes.
        """
        if DEBUG is False:
            # Imported explicitly so this branch does not depend on the name
            # leaking through ``from app.core.config import *``.
            import pika

            # NOTE(review): this connection is never closed by this class.
            self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
            self.channel = self.connection.channel()
        self.grpc_client = grpcclient.InferenceServerClient(url=GI_MODEL_URL)
        self.redis_client = redis.StrictRedis(
            host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True
        )

        if request_data.mode == "img2img":
            # cv2 loads images as BGR while PIL loads RGB; get_image converts.
            self.image = self.get_image(request_data.image_url)
        else:
            # No source image in this mode: the image input is filled with noise.
            self.image = self._random_image()
        self.prompt = request_data.prompt

        self.tasks_id = request_data.tasks_id
        # The user id is whatever follows the last '-' in the task id (the
        # whole task id when it contains no '-').
        self.user_id = self.tasks_id.rpartition('-')[2]
        self.mode = request_data.mode
        self.batch_size = 1
        self.category = request_data.category
        self.index = 0
        self.gender = request_data.gender
        self.generate_data = {
            'tasks_id': self.tasks_id,
            'status': 'PENDING',
            'message': "pending",
            'image_url': '',
            'category': '',
        }
        self._save_status()

    def _random_image(self):
        """Return a random uint8 noise image of the model's input size."""
        size = self.IMAGE_SIZE
        return np.random.randint(0, 256, (size, size, 3), dtype=np.uint8)

    def _save_status(self):
        """Write ``self.generate_data`` to Redis as JSON and refresh its TTL."""
        self.redis_client.set(
            self.tasks_id,
            json.dumps(self.generate_data),
            ex=self.STATUS_TTL_SECONDS,
        )

    def get_image(self, image_url):
        """Load the source image, convert BGR to RGB and resize it.

        Args:
            image_url: Object name passed to the storage helper ``get_image``.

        Returns:
            The resized RGB image, or a random noise image when the storage
            lookup raises ``minio.error.S3Error``.
        """
        size = self.IMAGE_SIZE
        try:
            # Inside this method the bare name resolves to the module-level
            # ``get_image`` imported from oss_client, not to this method.
            image_bgr = get_image(object_name=image_url, data_type="cv2")
            image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
            return cv2.resize(image_rgb, (size, size))
        except minio.error.S3Error:
            logger.warning("could not fetch %s from object storage; using a random image", image_url)
            return self._random_image()

    def _postprocess_sketch(self, image_bgr):
        """Run the sketch-only pipeline on a generated image.

        Adjusts levels and luminance, removes the background, runs stain
        detection and category recognition, and stores the recognised
        category in ``self.generate_data``.

        Returns:
            ``(passed_stain_check, image)``: the flag returned by
            ``stain_detection`` and the image returned by
            ``generate_category_recognition``.
        """
        levels_img = autoLevels(image_bgr, 1)
        luminance = luminance_adjust(0.3, levels_img)
        remove_bg_image = remove_background(luminance)
        # NOTE: face detection (face_detect_pic) is currently disabled here.
        # Per the original inline comments a truthy flag means NO stain was
        # found; confirm against stain_detection.
        passed_stain_check, _ = stain_detection(
            remove_bg_image, self.user_id, self.category, self.tasks_id
        )
        category, _scores, recognised_image = generate_category_recognition(
            image=remove_bg_image, gender=self.gender
        )
        self.generate_data['category'] = str(category)
        return passed_stain_check, recognised_image

    def callback(self, result, error):
        """Handle completion of the async inference and record the outcome.

        Passed as ``callback`` to ``async_infer``. On ``error``, or when
        post-processing/upload raises, the task is marked ``FAILURE`` so that
        ``get_result`` stops polling; otherwise it is marked ``SUCCESS`` with
        the resulting image URL.

        Args:
            result: Inference result exposing ``as_numpy``.
            error: Truthy when the inference failed.
        """
        if error:
            self.generate_data['status'] = "FAILURE"
            self.generate_data['message'] = str(error)
            self._save_status()
            return

        try:
            # Model output is RGB; the cv2-based helpers below work on BGR.
            image = result.as_numpy("generated_image")
            image_result = cv2.cvtColor(np.squeeze(image.astype(np.uint8)), cv2.COLOR_RGB2BGR)

            passed_stain_check = True
            if self.category == "sketch":
                passed_stain_check, image_result = self._postprocess_sketch(image_result)

            if passed_stain_check:
                image_url = upload_png_sd(
                    image_result,
                    user_id=self.user_id,
                    category=f"{self.category}",
                    object_name=f"{self.tasks_id}.png",
                )
            else:
                # Stain detected: report the system placeholder image instead
                # of uploading the generated one.
                image_url = GI_SYS_IMAGE_URL
        except Exception as exc:
            # Without this the status would stay PENDING and get_result would
            # poll until its timeout with no indication of what went wrong.
            logger.exception("post-processing failed for task %s", self.tasks_id)
            self.generate_data['status'] = "FAILURE"
            self.generate_data['message'] = str(exc)
            self._save_status()
            return

        self.generate_data['status'] = "SUCCESS"
        self.generate_data['message'] = "success"
        self.generate_data['image_url'] = str(image_url)
        self._save_status()

    def read_tasks_status(self):
        """Read the task status from Redis.

        Returns:
            ``(status_dict, status_json)``. When the key is missing from
            Redis, the in-memory ``self.generate_data`` is returned instead of
            failing on ``json.loads(None)``.
        """
        status_data = self.redis_client.get(self.tasks_id)
        if status_data is None:
            status_data = json.dumps(self.generate_data)
        return json.loads(status_data), status_data

    def infer(self, inputs):
        """Submit ``inputs`` asynchronously; ``self.callback`` gets the outcome.

        Returns:
            Whatever ``async_infer`` returns; ``get_result`` calls
            ``cancel()`` on it.
        """
        return self.grpc_client.async_infer(
            model_name=GI_MODEL_NAME,
            inputs=inputs,
            callback=self.callback,
        )

    def get_result(self):
        """Run the inference and wait for a terminal status.

        Polls Redis up to ``MAX_POLLS`` times, ``POLL_INTERVAL_SECONDS``
        apart. ``REVOKED``/``FAILURE`` cancel the in-flight request;
        ``SUCCESS`` ends the wait. Whatever happens, the latest status is
        published to RabbitMQ (when ``DEBUG`` is False) and logged.

        Returns:
            The last status dict read from Redis.

        Raises:
            Exception: Any failure, re-raised with the original as its cause
                after the task has been marked ``FAILURE``.
        """
        try:
            size = self.IMAGE_SIZE
            prompts = [self.prompt] * self.batch_size
            modes = [self.mode] * self.batch_size
            images = [self.image.astype(np.float16)] * self.batch_size

            text_obj = np.array(prompts, dtype="object").reshape((-1, 1))
            mode_obj = np.array(modes, dtype="object").reshape((-1, 1))
            image_obj = np.array(images, dtype=np.float16).reshape((-1, size, size, 3))

            input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype))
            input_image = grpcclient.InferInput("input_image", image_obj.shape, "FP16")
            input_mode = grpcclient.InferInput("mode", mode_obj.shape, np_to_triton_dtype(mode_obj.dtype))

            input_text.set_data_from_numpy(text_obj)
            input_image.set_data_from_numpy(image_obj)
            input_mode.set_data_from_numpy(mode_obj)

            ctx = self.infer([input_text, input_image, input_mode])

            generate_data = None
            for _ in range(self.MAX_POLLS):
                generate_data, _raw = self.read_tasks_status()
                status = generate_data['status']
                if status in ("REVOKED", "FAILURE"):
                    ctx.cancel()
                    break
                if status == "SUCCESS":
                    break
                time.sleep(self.POLL_INTERVAL_SECONDS)
            # NOTE(review): when the polling budget is exhausted the request
            # is left running and the status returned is still PENDING.
            return generate_data
        except Exception as e:
            self.generate_data['status'] = "FAILURE"
            self.generate_data['message'] = str(e)
            self._save_status()
            raise Exception(str(e)) from e
        finally:
            dict_generate_data, str_generate_data = self.read_tasks_status()
            if DEBUG is False:
                self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=str_generate_data)
            logger.info(" [x] Sent %s", json.dumps(dict_generate_data, indent=4))


def infer_cancel(tasks_id):
    """Mark a task as ``REVOKED`` in Redis so a polling ``get_result`` stops.

    Args:
        tasks_id: Redis key of the task to revoke.

    Returns:
        The status dict that was written.
    """
    redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
    data = {'tasks_id': tasks_id, 'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'}
    # Written with a TTL so revoked entries do not stay in Redis forever.
    redis_client.set(tasks_id, json.dumps(data), ex=GenerateImage.STATUS_TTL_SECONDS)
    return data


if __name__ == '__main__':
    rd = GenerateImageModel(
        tasks_id="123-89",
        prompt='skeleton sitting by the side of a river looking soulful, concert poster, 4k, artistic',
        image_url="",
        mode='txt2img',
        category="test",
        gender="male",
    )
    server = GenerateImage(rd)
    print(server.get_result())