#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
Client-side driver for image generation on a Triton inference server.

@Project :trinity_client
@File    :service_att_recognition.py
@Author  :周成融
@Date    :2023/7/26 12:01:05
@detail  :Submits a prompt (and optionally a source image) to Triton, tracks
          the task status in Redis and publishes the final status to RabbitMQ.
"""
import json
import logging
import time
from io import BytesIO

import cv2
import minio
import numpy as np
# Imported explicitly: the code below uses ``pika`` directly and must not rely
# on the star import from ``app.core.config`` happening to re-export it.
import pika
import redis
import tritonclient.grpc as grpcclient
from minio import Minio
from tritonclient.utils import np_to_triton_dtype

from app.core.config import *
from app.schemas.generate_image import GenerateImageModel
from app.service.generate_image.utils.upload_sd_image import upload_png_sd

logger = logging.getLogger()

# Side length (pixels) of the square image tensor sent as "input_image".
_IMAGE_SIZE = 1024
# Lifetime (seconds) of the initial PENDING status entry in Redis.
_STATUS_TTL_SECONDS = 600


class GenerateImage:
    """Run one image-generation task and report its status.

    The instance owns a MinIO client, a Triton gRPC client, a Redis client and
    a RabbitMQ connection/channel. Status documents stored in Redis under
    ``tasks_id`` are JSON objects with the keys ``status``, ``message`` and
    ``data``.
    """

    def __init__(self, request_data):
        """Open all service connections and prepare the model inputs.

        Args:
            request_data: Request object exposing ``mode``, ``prompt``,
                ``tasks_id``, ``category`` and (for ``"img2img"``)
                ``image_url``.
        """
        self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET,
                                  secure=MINIO_SECURE)
        self.grpc_client = grpcclient.InferenceServerClient(url=GI_MODEL_URL)
        self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB,
                                              decode_responses=True)
        self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
        self.channel = self.connection.channel()

        if request_data.mode == "img2img":
            self.image = self.get_image(request_data.image_url)
        else:
            # Any mode other than img2img still sends an image tensor; it is
            # filled with random pixels.
            self.image = self._random_image()
        self.prompt = request_data.prompt

        self.tasks_id = request_data.tasks_id
        # Everything after the last '-' of the task id (the whole id when it
        # contains no '-').
        self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
        self.mode = request_data.mode
        self.batch_size = 1
        self.category = request_data.category
        self.index = 0

    def __del__(self):
        """Close the Redis, Triton and RabbitMQ connections (best effort)."""
        # __init__ may have failed before every attribute was assigned, and one
        # failing close() must not prevent the remaining ones from running.
        for attr in ("redis_client", "grpc_client", "connection"):
            resource = getattr(self, attr, None)
            if resource is None:
                continue
            try:
                resource.close()
            except Exception:  # destructor boundary: never propagate
                logger.debug("failed to close %s", attr, exc_info=True)

    @staticmethod
    def _random_image() -> np.ndarray:
        """Return a random uint8 image of shape (_IMAGE_SIZE, _IMAGE_SIZE, 3)."""
        return np.random.randint(0, 256, (_IMAGE_SIZE, _IMAGE_SIZE, 3), dtype=np.uint8)

    def get_image(self, image_url) -> np.ndarray:
        """Fetch and decode the source image stored in MinIO.

        Args:
            image_url: ``"<bucket>/<object path>"``; the text before the first
                ``/`` is the bucket, the rest is the object name.

        Returns:
            A uint8 colour image of shape (_IMAGE_SIZE, _IMAGE_SIZE, 3). A
            random image is returned when the URL has no bucket/object part,
            the object cannot be fetched (``S3Error``), or the bytes cannot be
            decoded as an image.
        """
        bucket_name, _, object_name = image_url.partition('/')
        if not bucket_name or not object_name:
            logger.warning("invalid image_url %r, using a random image", image_url)
            return self._random_image()

        response = None
        try:
            response = self.minio_client.get_object(bucket_name, object_name)
            raw = response.data
        except minio.error.S3Error:
            logger.warning("cannot fetch %r from MinIO, using a random image", image_url)
            return self._random_image()
        finally:
            # The MinIO SDK requires the response to be closed and its
            # connection released back to the pool.
            if response is not None:
                response.close()
                response.release_conn()

        # cv2.imdecode raises on an empty buffer and returns None on
        # undecodable data; treat both like a missing object.
        image_cv2 = None
        if raw:
            image_cv2 = cv2.imdecode(np.frombuffer(raw, dtype=np.uint8), cv2.IMREAD_COLOR)
        if image_cv2 is None:
            logger.warning("cannot decode %r as an image, using a random image", image_url)
            return self._random_image()

        # get_result() reshapes to (-1, _IMAGE_SIZE, _IMAGE_SIZE, 3); any other
        # size would fail or scramble the tensor, so normalise it here.
        if image_cv2.shape[:2] != (_IMAGE_SIZE, _IMAGE_SIZE):
            image_cv2 = cv2.resize(image_cv2, (_IMAGE_SIZE, _IMAGE_SIZE))
        return image_cv2

    def __call__(self, *args, **kwargs):
        """Store a PENDING status for this task in Redis with a TTL."""
        self.generate_data = json.dumps({'status': 'PENDING', 'message': "pending", 'data': ''})
        # A single SET with ``ex`` writes the value and its TTL atomically.
        self.redis_client.set(self.tasks_id, self.generate_data, ex=_STATUS_TTL_SECONDS)

    def callback(self, result, error):
        """Completion callback passed to ``async_infer``.

        Writes a FAILURE status when ``error`` is set or when reading/uploading
        the generated image fails; otherwise uploads the first generated image
        and writes a SUCCESS status whose ``data`` is the uploaded image URL.

        Args:
            result: Inference result exposing ``as_numpy``.
            error: Error reported by the client, falsy on success.
        """
        if error:
            payload = {'status': 'FAILURE', 'message': f"{error}", 'data': f"{error}"}
        else:
            try:
                image_result = result.as_numpy("generated_image")[0]
                image_url = upload_png_sd(image_result, user_id=self.user_id,
                                          category=f"{self.category}",
                                          object_name=f"{self.tasks_id}.png")
            except Exception as exc:
                # Without this the exception would be lost on the callback
                # thread and the poller would only see a timeout.
                logger.exception("post-processing of task %s failed", self.tasks_id)
                payload = {'status': 'FAILURE', 'message': f"{exc}", 'data': f"{exc}"}
            else:
                payload = {'status': 'SUCCESS', 'message': 'success', 'data': f'{image_url}'}
        # NOTE(review): this write carries no TTL, so the final status stays in
        # Redis indefinitely — confirm whether that is intended.
        self.redis_client.set(self.tasks_id, json.dumps(payload))

    def read_tasks_status(self) -> dict:
        """Return the task's status document from Redis.

        Returns:
            The decoded status dict. When no entry exists for the task, a
            PENDING document is returned instead of failing on ``None``.
        """
        raw_status = self.redis_client.get(self.tasks_id)
        if raw_status is None:
            return {'status': 'PENDING', 'message': "pending", 'data': ''}
        return json.loads(raw_status)

    def infer(self, inputs):
        """Submit ``inputs`` to the model asynchronously.

        Returns:
            Whatever ``async_infer`` returns; ``get_result`` calls ``cancel()``
            on it.
        """
        return self.grpc_client.async_infer(
            model_name=GI_MODEL_NAME,
            inputs=inputs,
            callback=self.callback,
        )

    def _build_inputs(self):
        """Build the "prompt", "input_image" and "mode" Triton inputs."""
        text_obj = np.array([self.prompt] * self.batch_size, dtype="object").reshape((-1, 1))
        mode_obj = np.array([self.mode] * self.batch_size, dtype="object").reshape((-1, 1))
        images = [self.image.astype(np.float16)] * self.batch_size
        image_obj = np.array(images, dtype=np.float16).reshape((-1, _IMAGE_SIZE, _IMAGE_SIZE, 3))

        input_text = grpcclient.InferInput("prompt", text_obj.shape,
                                           np_to_triton_dtype(text_obj.dtype))
        input_image = grpcclient.InferInput("input_image", image_obj.shape, "FP16")
        # Derive the dtype from the array that is actually sent for this input.
        input_mode = grpcclient.InferInput("mode", mode_obj.shape,
                                           np_to_triton_dtype(mode_obj.dtype))

        input_text.set_data_from_numpy(text_obj)
        input_image.set_data_from_numpy(image_obj)
        input_mode.set_data_from_numpy(mode_obj)
        return [input_text, input_image, input_mode]

    def _publish_status(self, status):
        """Publish ``status`` as JSON to the RabbitMQ queue and log it."""
        self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES,
                                   body=json.dumps(status))
        logger.info(" [x] Sent %s", json.dumps(status, indent=4))

    def get_result(self, max_polls=60, poll_interval=0.1) -> dict:
        """Run inference and wait for a terminal task status.

        Polls Redis up to ``max_polls`` times. On REVOKED or FAILURE the
        in-flight request is cancelled; on REVOKED, FAILURE or SUCCESS the
        status is published to RabbitMQ and polling stops.

        Args:
            max_polls: Maximum number of status checks.
            poll_interval: Seconds to sleep between checks.

        Returns:
            The status document read from Redis after polling ends. If no
            terminal status appeared in time this is whatever is stored then.
        """
        ctx = self.infer(self._build_inputs())

        # NOTE(review): when the polls run out, the request is neither
        # cancelled nor published — confirm this is intended.
        for _ in range(max_polls):
            status = self.read_tasks_status()
            state = status['status']
            if state in ("REVOKED", "FAILURE", "SUCCESS"):
                if state != "SUCCESS":
                    ctx.cancel()
                self._publish_status(status)
                break
            time.sleep(poll_interval)
        return self.read_tasks_status()


def infer_cancel(tasks_id):
    """Mark task ``tasks_id`` as REVOKED in Redis.

    Args:
        tasks_id: Redis key of the task to revoke.

    Returns:
        The REVOKED status dict that was stored.
    """
    data = {'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'}
    redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB,
                                     decode_responses=True)
    try:
        redis_client.set(tasks_id, json.dumps(data))
    finally:
        # This client is created per call, so release it here.
        redis_client.close()
    return data


if __name__ == '__main__':
    rd = GenerateImageModel(
        tasks_id="123-89",
        prompt='skeleton sitting by the side of a river looking soulful, concert poster, 4k, artistic',
        image_url="",
        mode='txt2img',
        category="test"
    )
    server = GenerateImage(rd)
    print(server.get_result())