#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project :trinity_client
@File    :service_att_recognition.py
@Author  :周成融
@Date    :2023/7/26 12:01:05
@detail  : Image generation (txt2img / img2img) on a Triton inference server.
           Task status is kept in Redis under the task id and, outside DEBUG
           mode, the final status is published to a RabbitMQ queue.
"""
import json
import logging
import time
from io import BytesIO

import cv2
import minio
import redis
import tritonclient.grpc as grpcclient
import numpy as np
from minio import Minio
from tritonclient.utils import np_to_triton_dtype

from app.core.config import *
from app.schemas.generate_image import GenerateImageModel
from app.service.generate_image.utils.adjust_contrast import adjust_contrast
from app.service.generate_image.utils.image_processing import (
    remove_background,
    stain_detection,
    generate_category_recognition,
)
from app.service.generate_image.utils.upload_sd_image import upload_png_sd

logger = logging.getLogger()

# Lifetime (seconds) of a task-status record in Redis. It is re-applied on
# every write because a plain Redis SET discards any TTL already on the key.
_TASK_STATUS_TTL = 600


class GenerateImage:
    """Run one image-generation task on Triton and track its status in Redis.

    The status record stored under ``tasks_id`` is a JSON object with the keys
    ``tasks_id``, ``status`` (PENDING / SUCCESS / FAILURE; ``infer_cancel``
    writes REVOKED), ``message``, ``image_url`` and ``category``.
    """

    def __init__(self, request_data):
        """Open the service clients, load the input image and register the task as PENDING.

        :param request_data: request object (presumably a ``GenerateImageModel``)
            providing ``mode``, ``prompt``, ``image_url``, ``tasks_id``,
            ``category`` and ``gender``.
        """
        # NOTE(review): ``pika`` is not imported in this file; it presumably
        # arrives through ``from app.core.config import *`` -- confirm.
        # NOTE(review): neither this RabbitMQ connection nor the gRPC client
        # below is ever closed in this file.
        if DEBUG is False:
            self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
            self.channel = self.connection.channel()
        self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
        self.grpc_client = grpcclient.InferenceServerClient(url=GI_MODEL_URL)
        self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
        if request_data.mode == "img2img":
            self.image = self.get_image(request_data.image_url)
        else:
            # get_result always sends an ``input_image`` tensor, so modes other
            # than img2img get random noise of the same shape.
            self.image = self._random_image()
        self.prompt = request_data.prompt
        self.tasks_id = request_data.tasks_id
        # The user id is whatever follows the last '-' in the task id.
        self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
        self.mode = request_data.mode
        self.batch_size = 1
        self.category = request_data.category
        self.index = 0
        self.gender = request_data.gender
        self.generate_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': '',
                              'category': ''}
        self._save_status()

    @staticmethod
    def _random_image():
        """Return a 1024x1024x3 uint8 image of uniform random noise."""
        return np.random.randint(0, 256, (1024, 1024, 3), dtype=np.uint8)

    def _save_status(self):
        """Write ``self.generate_data`` to Redis under ``tasks_id`` with a fresh TTL."""
        self.redis_client.set(self.tasks_id, json.dumps(self.generate_data), ex=_TASK_STATUS_TTL)

    def get_image(self, image_url):
        """Fetch ``image_url`` ("<bucket>/<object name>") from MinIO as a 1024x1024 BGR image.

        Falls back to random noise (and logs a warning) when MinIO reports an
        S3 error or the downloaded bytes cannot be decoded as an image.
        """
        bucket_name = image_url.split('/')[0]
        object_name = image_url[image_url.find('/') + 1:]
        try:
            response = self.minio_client.get_object(bucket_name, object_name)
        except minio.error.S3Error:
            logger.warning("could not fetch %s from MinIO; using a random image", image_url)
            return self._random_image()
        try:
            raw = response.data
        finally:
            # get_object returns a streaming response that must be closed and
            # its connection released back to the pool.
            response.close()
            response.release_conn()
        image = cv2.imdecode(np.frombuffer(raw, dtype=np.uint8), cv2.IMREAD_COLOR)
        if image is None:
            # imdecode returns None for undecodable data; cv2.resize would raise.
            logger.warning("could not decode %s as an image; using a random image", image_url)
            return self._random_image()
        return cv2.resize(image, (1024, 1024))

    def callback(self, result, error):
        """Completion callback given to ``async_infer``; records the task outcome in Redis.

        :param result: the Triton inference result (used only when ``error`` is falsy).
        :param error: the inference error, or a falsy value on success.
        """
        if error:
            self.generate_data['status'] = "FAILURE"
            self.generate_data['message'] = str(error)
            self._save_status()
            return
        try:
            self._finish_success(result)
        except Exception as e:
            # This is invoked by the gRPC client, not by get_result, so an
            # exception here never reaches get_result's handler. Record it;
            # otherwise the task would stay PENDING until polling times out.
            logger.exception("post-processing failed for task %s", self.tasks_id)
            self.generate_data['status'] = "FAILURE"
            self.generate_data['message'] = str(e)
            self._save_status()

    def _finish_success(self, result):
        """Post-process the generated image, upload it if it passes the stain check, and record SUCCESS."""
        image = result.as_numpy("generated_image")
        # Cast to uint8, drop size-1 axes (batch_size is 1), swap BGR <-> RGB.
        image_result = cv2.cvtColor(np.squeeze(image.astype(np.uint8)), cv2.COLOR_BGR2RGB)
        # Only sketches are checked; every other category is always uploaded.
        passed_stain_check = True
        if self.category == "sketch":
            remove_bg_image = remove_background(np.asarray(image_result))
            # Per the original comments, a truthy first value means "no stain
            # found" -- TODO confirm against stain_detection.
            passed_stain_check, _ = stain_detection(remove_bg_image)
            category, _scores, image_result = generate_category_recognition(image=remove_bg_image, gender=self.gender)
            self.generate_data['category'] = str(category)
        if passed_stain_check:
            image_url = upload_png_sd(image_result, user_id=self.user_id, category=f"{self.category}",
                                      object_name=f"{self.tasks_id}.png")
        else:
            # Stained result: report GI_SYS_IMAGE_URL (presumably a fallback image) instead of uploading.
            image_url = GI_SYS_IMAGE_URL
        self.generate_data['status'] = "SUCCESS"
        self.generate_data['message'] = "success"
        self.generate_data['image_url'] = str(image_url)
        self._save_status()

    def read_tasks_status(self):
        """Return the task's Redis record as ``(parsed dict, raw JSON string)``.

        Raises ``TypeError`` (from ``json.loads(None)``) if the key is missing or expired.
        """
        status_data = self.redis_client.get(self.tasks_id)
        return json.loads(status_data), status_data

    def infer(self, inputs):
        """Start an asynchronous inference on ``GI_MODEL_NAME``; returns the handle ``get_result`` cancels."""
        return self.grpc_client.async_infer(
            model_name=GI_MODEL_NAME, inputs=inputs, callback=self.callback
        )

    def _build_inputs(self):
        """Build the ``prompt``, ``input_image`` and ``mode`` Triton inputs, ``batch_size`` rows each."""
        prompts = [self.prompt] * self.batch_size
        modes = [self.mode] * self.batch_size
        images = [self.image.astype(np.float16)] * self.batch_size
        text_obj = np.array(prompts, dtype="object").reshape((-1, 1))
        mode_obj = np.array(modes, dtype="object").reshape((-1, 1))
        image_obj = np.array(images, dtype=np.float16).reshape((-1, 1024, 1024, 3))
        input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype))
        input_image = grpcclient.InferInput("input_image", image_obj.shape, "FP16")
        input_mode = grpcclient.InferInput("mode", mode_obj.shape, np_to_triton_dtype(mode_obj.dtype))
        input_text.set_data_from_numpy(text_obj)
        input_image.set_data_from_numpy(image_obj)
        input_mode.set_data_from_numpy(mode_obj)
        return [input_text, input_image, input_mode]

    def get_result(self):
        """Submit the task and poll Redis until it succeeds, fails, is revoked, or polling runs out.

        Polls at most 600 times, 0.1 s apart (about 60 s). On REVOKED or
        FAILURE the in-flight request is cancelled. Whatever happens, the
        final Redis record is published to ``GI_RABBITMQ_QUEUES`` (when DEBUG
        is False) and logged.

        :return: the last status dict read from Redis (still PENDING on timeout).
        :raises Exception: wrapping any error raised while submitting or polling;
            the task is recorded as FAILURE first.
        """
        try:
            ctx = self.infer(self._build_inputs())
            generate_data = None
            for _attempt in range(600):
                generate_data = self.read_tasks_status()[0]
                if generate_data['status'] in ["REVOKED", "FAILURE"]:
                    ctx.cancel()
                    break
                if generate_data['status'] == "SUCCESS":
                    break
                time.sleep(0.1)
            else:
                logger.warning("task %s still %s after polling timed out",
                               self.tasks_id, generate_data['status'])
            return generate_data
        except Exception as e:
            self.generate_data['status'] = "FAILURE"
            self.generate_data['message'] = str(e)
            self._save_status()
            raise Exception(str(e)) from e
        finally:
            dict_generate_data, str_generate_data = self.read_tasks_status()
            if DEBUG is False:
                self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=str_generate_data)
            logger.info(f" [x] Sent {json.dumps(dict_generate_data, indent=4)}")


def infer_cancel(tasks_id):
    """Mark ``tasks_id`` as REVOKED in Redis and return the record written.

    A running ``GenerateImage.get_result`` sees REVOKED on its next poll and
    cancels the Triton request.
    NOTE(review): a callback finishing after this write overwrites the REVOKED record.
    """
    redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
    data = {'tasks_id': tasks_id, 'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'}
    generate_data = json.dumps(data)
    # ex= keeps the record from living forever (SET alone drops any existing TTL).
    redis_client.set(tasks_id, generate_data, ex=_TASK_STATUS_TTL)
    return data


if __name__ == '__main__':
    rd = GenerateImageModel(
        tasks_id="123-89",
        prompt='skeleton sitting by the side of a river looking soulful, concert poster, 4k, artistic',
        image_url="",
        mode='txt2img',
        category="test"
    )
    server = GenerateImage(rd)
    print(server.get_result())