#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project :trinity_client
@File    :service_att_recognition.py
@Author  :Zhou Chengrong (周成融)
@Date    :2023/7/26 12:01:05
@detail  : Client for a Triton-served image-generation model. Sends a text
           prompt, encodes the returned image as JPEG, uploads it to object
           storage and returns a time-limited presigned download URL.
"""
import logging
from datetime import timedelta

import cv2
import numpy as np
import tritonclient.grpc as grpcclient
from minio import Minio
from tritonclient.utils import np_to_triton_dtype

from app.core.config import *
from app.service.utils.oss_client import oss_upload_image

logger = logging.getLogger(__name__)


class GenerateImage:
    """Text-to-image client backed by a Triton Inference Server model.

    The model endpoint/name come from ``GI_MODEL_URL`` / ``GI_MODEL_NAME`` and
    the MinIO connection from the ``MINIO_*`` settings, all provided by
    ``app.core.config``.
    """

    # Side length of the placeholder input image; the ``input_image`` tensor is
    # shaped (batch, _IMAGE_SIZE, _IMAGE_SIZE, 3).
    _IMAGE_SIZE = 1024

    def __init__(self):
        self.grpc_client = grpcclient.InferenceServerClient(url=GI_MODEL_URL)
        # The model takes an ``input_image`` tensor even in txt2img mode; random
        # noise fills that slot (presumably ignored in this mode - TODO confirm
        # against the model config).
        self.image = np.random.randint(
            0, 256, (self._IMAGE_SIZE, self._IMAGE_SIZE, 3), dtype=np.uint8
        )
        self.batch_size = 1
        self.mode = 'txt2img'
        self.minio_client = Minio(
            MINIO_URL,
            access_key=MINIO_ACCESS,
            secret_key=MINIO_SECRET,
            secure=MINIO_SECURE,
        )

    @staticmethod
    def _make_input(name, array):
        """Wrap a numpy array as a Triton gRPC ``InferInput`` called *name*."""
        infer_input = grpcclient.InferInput(
            name, array.shape, np_to_triton_dtype(array.dtype)
        )
        infer_input.set_data_from_numpy(array)
        return infer_input

    def get_result(self, prompt, bucket='test', object_name='test.jpg',
                   expires=timedelta(hours=2)):
        """Generate an image for *prompt*, upload it and return a presigned URL.

        Args:
            prompt: Text prompt sent to the model as the ``prompt`` input.
            bucket: Bucket the JPEG is uploaded to and signed from.
            object_name: Object key for the JPEG. NOTE: the default is a fixed
                key, so each call overwrites the previous image; pass a unique
                name when calls can overlap.
            expires: Lifetime of the presigned GET URL.

        Returns:
            A presigned GET URL for the uploaded JPEG.

        Raises:
            RuntimeError: If the response has no ``generated_image`` output or
                JPEG encoding fails.
        """
        size = self._IMAGE_SIZE
        text_obj = np.array([prompt] * self.batch_size, dtype="object").reshape((-1, 1))
        mode_obj = np.array([self.mode] * self.batch_size, dtype="object").reshape((-1, 1))
        image_obj = np.array(
            [self.image.astype(np.float16)] * self.batch_size, dtype=np.float16
        ).reshape((-1, size, size, 3))

        inputs = [
            self._make_input("prompt", text_obj),
            self._make_input("input_image", image_obj),
            self._make_input("mode", mode_obj),
        ]
        result = self.grpc_client.infer(model_name=GI_MODEL_NAME, inputs=inputs)

        image = result.as_numpy("generated_image")
        if image is None:
            raise RuntimeError("model response has no 'generated_image' output")
        # Clip before casting so out-of-range values saturate instead of wrapping
        # around in uint8. Assumes the model emits pixels on a 0-255 scale, as
        # the original uint8 cast implied (TODO confirm).
        pixels = np.clip(image, 0, 255).astype(np.uint8)
        # Model output is treated as RGB; OpenCV encodes BGR.
        image_result = cv2.cvtColor(np.squeeze(pixels), cv2.COLOR_RGB2BGR)
        ok, img_byte_array = cv2.imencode('.jpg', image_result)
        if not ok:
            raise RuntimeError("failed to encode generated image as JPEG")

        # NOTE(review): oss_upload_image's return value is not inspected; check
        # whether it reports failure instead of raising.
        oss_upload_image(bucket=bucket, object_name=object_name,
                         image_bytes=img_byte_array)
        return self.minio_client.get_presigned_url(
            "GET",
            bucket,
            object_name,
            expires=expires,
        )


if __name__ == '__main__':
    server = GenerateImage()
    print(server.get_result("rabbit"))