#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project :trinity_client
@File    :service_att_recognition.py
@Author  :周成融
@Date    :2023/7/26 12:01:05
@detail  :Submit image-generation requests to a Triton model over gRPC,
          keep the task status in Redis and publish the final status
          to a RabbitMQ queue.
"""
import json
import logging
import time

import redis
import tritonclient.grpc as grpcclient
import numpy as np
from minio import Minio
from tritonclient.utils import np_to_triton_dtype

# NOTE(review): `pika` and the upper-case settings used below (MINIO_*,
# REDIS_*, GI_*, RABBITMQ_PARAMS) are never imported by name; presumably
# they are all provided by this star import -- confirm, and import `pika`
# explicitly if the config module does not export it.
from app.core.config import *
from app.schemas.generate_image import GenerateImageModel
from app.service.generate_image.utils.upload_sd_image import upload_png_sd
from app.service.utils.generate_uuid import generate_uuid

logger = logging.getLogger()


class GenerateImage:
    """Run one image-generation task and track its status.

    The task status is stored in Redis under ``tasks_id`` as a JSON object
    with the keys ``status``, ``message`` and ``data``. Statuses handled by
    this class are ``PENDING``, ``SUCCESS``, ``FAILURE`` and ``REVOKED``.
    """

    def __init__(self, request_data):
        """Open the service connections and copy the request fields.

        Args:
            request_data: Request object exposing ``mode``, ``tasks_id``,
                ``prompt`` and ``category`` (and ``image`` when ``mode`` is
                not ``"txt2img"``).
        """
        self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
        self.grpc_client = grpcclient.InferenceServerClient(url=GI_MODEL_URL)
        self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
        self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
        self.channel = self.connection.channel()

        if request_data.mode == "txt2img":
            # The model input always carries an image tensor, so text-only
            # requests are given a random 1024x1024 RGB placeholder.
            self.image = np.random.randint(0, 256, (1024, 1024, 3), dtype=np.uint8)
        else:
            # NOTE(review): get_result() calls ``.astype`` on this value, so
            # it is assumed to be a numpy array -- confirm against the schema.
            self.image = request_data.image

        self.tasks_id = request_data.tasks_id
        # Everything after the last '-' of the task id (the whole id when it
        # contains no '-').
        self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
        self.prompt = request_data.prompt
        self.mode = request_data.mode
        self.batch_size = 1
        self.category = request_data.category
        self.index = 0

    def __del__(self):
        """Close the Redis, gRPC and RabbitMQ connections, best effort.

        ``__init__`` may have failed before every attribute was assigned, and
        one failing ``close()`` must not prevent the remaining ones.
        """
        for attr in ("redis_client", "grpc_client", "connection"):
            client = getattr(self, attr, None)
            if client is None:
                continue
            try:
                client.close()
            except Exception:
                logger.warning("Failed to close %s", attr, exc_info=True)

    def __call__(self, *args, **kwargs):
        """Mark the task as PENDING in Redis."""
        self.generate_data = json.dumps({'status': 'PENDING', 'message': "pending", 'data': ''})
        self.redis_client.set(self.tasks_id, self.generate_data)

    def _store_status(self, status, message, data):
        """Serialize a status record and store it in Redis under the task id."""
        payload = json.dumps({'status': status, 'message': message, 'data': data})
        self.redis_client.set(self.tasks_id, payload)

    def _publish(self, generate_data):
        """Publish a status record to the result queue and log it."""
        self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=json.dumps(generate_data))
        logger.info(f" [x] Sent {json.dumps(generate_data, indent=4)}")

    def callback(self, result, error):
        """Completion callback passed to ``async_infer``.

        Stores FAILURE when the inference reports an error or when the
        generated image cannot be extracted/uploaded; otherwise uploads the
        first generated image and stores SUCCESS with the returned URL.

        Args:
            result: Inference result; read via ``as_numpy("generated_image")``.
            error: Error reported by the inference call, falsy on success.
        """
        if error:
            self._store_status('FAILURE', f"{error}", f"{error}")
            return

        try:
            image_result = result.as_numpy("generated_image")[0]
            image_url = upload_png_sd(
                image_result,
                user_id=self.user_id,
                category=f"{self.category}",
                object_name=f"{self.tasks_id}.png",
            )
        except Exception as exc:
            # Without this the task would stay PENDING until the poll loop in
            # get_result() gives up, hiding the real cause.
            logger.exception("Failed to store generated image for task %s", self.tasks_id)
            self._store_status('FAILURE', f"{exc}", f"{exc}")
            return

        self._store_status('SUCCESS', 'success', f'{image_url}')

    def read_tasks_status(self):
        """Return the task's status record from Redis as a dict.

        A missing key (status never written, or the key was removed) is
        reported as PENDING instead of failing inside ``json.loads``.
        """
        raw = self.redis_client.get(self.tasks_id)
        if raw is None:
            return {'status': 'PENDING', 'message': "pending", 'data': ''}
        return json.loads(raw)

    def infer(self, inputs):
        """Start an asynchronous inference and return the value of ``async_infer``."""
        return self.grpc_client.async_infer(
            model_name=GI_MODEL_NAME,
            inputs=inputs,
            callback=self.callback,
        )

    def get_result(self):
        """Submit the request, poll Redis for the outcome and return the status.

        Polls up to 60 times with a 0.1 s pause (about 6 seconds). On
        REVOKED or FAILURE the in-flight request is cancelled; on any final
        status the record is published to the result queue. If the polling
        budget runs out, nothing is published or cancelled.

        Returns:
            dict: The status record read from Redis after polling ends.
        """
        prompts = [self.prompt] * self.batch_size
        modes = [self.mode] * self.batch_size
        images = [self.image.astype(np.float16)] * self.batch_size

        text_obj = np.array(prompts, dtype="object").reshape((-1, 1))
        mode_obj = np.array(modes, dtype="object").reshape((-1, 1))
        image_obj = np.array(images, dtype=np.float16).reshape((-1, 1024, 1024, 3))

        input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype))
        input_image = grpcclient.InferInput("input_image", image_obj.shape, "FP16")
        # Use the mode array's own dtype (the original borrowed text_obj's,
        # which only worked because both are object arrays).
        input_mode = grpcclient.InferInput("mode", mode_obj.shape, np_to_triton_dtype(mode_obj.dtype))

        input_text.set_data_from_numpy(text_obj)
        input_image.set_data_from_numpy(image_obj)
        input_mode.set_data_from_numpy(mode_obj)

        inputs = [input_text, input_image, input_mode]
        ctx = self.infer(inputs)

        time_out = 60
        while time_out > 0:
            generate_data = self.read_tasks_status()
            status = generate_data['status']
            if status in ("REVOKED", "FAILURE", "SUCCESS"):
                if status != "SUCCESS":
                    ctx.cancel()
                self._publish(generate_data)
                break
            time_out -= 1
            time.sleep(0.1)
        return self.read_tasks_status()


def infer_cancel(tasks_id):
    """Mark a task as REVOKED in Redis.

    Args:
        tasks_id: Redis key of the task to revoke.

    Returns:
        dict: The REVOKED status record that was written.
    """
    data = {'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'}
    redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
    try:
        redis_client.set(tasks_id, json.dumps(data))
    finally:
        # This client is created per call, so release it here.
        redis_client.close()
    return data


if __name__ == '__main__':
    rd = GenerateImageModel(
        tasks_id="123-89",
        prompt='skeleton sitting by the side of a river looking soulful, concert poster, 4k, artistic',
        image_url="",
        mode='txt2img',
        category="test"
    )
    server = GenerateImage(rd)
    print(server.get_result())