#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project :trinity_client
@File    :service.py
@Author  :周成融
@Date    :2023/7/26 12:01:05
@detail  :Client for the Triton image-generation model. Each task's status is
          kept in Redis under its ``tasks_id``; generated images are uploaded
          with ``upload_png_sd`` and the outcome is published to RabbitMQ.
"""
import json
import logging
import random
import time

import cv2
import minio
import numpy as np
import pika
import redis
import tritonclient
import tritonclient.grpc as grpc_client
from minio import Minio
from PIL import Image

from app.core.config import *
from app.schemas.generate_image import GenerateImageModel
from app.service.generate_image.utils.remove_background import remove_background
from app.service.generate_image.utils.upload_sd_image import upload_png_sd
from app.service.utils.decorator import RunTime
from app.service.utils.generate_uuid import generate_uuid

logger = logging.getLogger()


class GenerateImage:
    """Run one image-generation request against Triton and track its status.

    The constructor records the task as ``PENDING`` in Redis. ``get_result``
    submits the inference asynchronously, then polls Redis once per second
    until the task is SUCCESS, FAILURE or REVOKED, or until
    ``RESULT_TIMEOUT`` polls have elapsed.
    """

    # Number of one-second status polls get_result performs before giving up waiting.
    RESULT_TIMEOUT = 60

    def __init__(self, request_data):
        """Open service connections and mark the task as PENDING in Redis.

        Args:
            request_data: request object (e.g. ``GenerateImageModel``) exposing
                ``tasks_id``, ``image_url``, ``user_id``, ``content``,
                ``category``, ``mode`` and ``version``.
        """
        self.tasks_id = request_data.tasks_id
        self.image_url = request_data.image_url
        self.user_id = request_data.user_id
        self.content = request_data.content
        self.category = request_data.category
        self.model_name = f"{self.category}{GI_MODEL_NAME}"
        self.mode = request_data.mode
        self.version = request_data.version
        self.triton_client = grpc_client.InferenceServerClient(url=f"{GI_MODEL_URL}")
        self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB,
                                              decode_responses=True)
        self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
        self.channel = self.connection.channel()
        self.minio_client = Minio(
            f"{MINIO_IP}:{MINIO_PORT}",
            access_key=MINIO_ACCESS,
            secret_key=MINIO_SECRET,
            secure=MINIO_SECURE)
        self.samples = 4  # number of images to generate (sent as the SAMPLES input)
        self.steps = 24
        self.guidance_scale = 7
        self.seed = random.randint(0, 2000000000)
        self.batch_size = 1
        self.generate_data = json.dumps({'status': 'PENDING', 'message': "pending", 'data': ''})
        self.redis_client.set(self.tasks_id, self.generate_data)

    def __del__(self):
        """Best-effort close of the Redis, Triton and RabbitMQ connections.

        Uses ``getattr`` because ``__init__`` may have failed before every
        attribute was assigned, and closes each resource independently so one
        failure does not leave the others open.
        """
        for attr in ("redis_client", "triton_client", "connection"):
            resource = getattr(self, attr, None)
            if resource is None:
                continue
            try:
                resource.close()
            except Exception:  # a finalizer must never raise; just record it
                logger.debug("failed to close %s", attr, exc_info=True)

    @staticmethod
    def image_grid(imgs, rows, cols):
        """Paste ``rows * cols`` equally sized PIL images into one RGB grid image.

        Images are placed row-major; the cell size is taken from ``imgs[0]``.

        Raises:
            ValueError: if ``len(imgs) != rows * cols``.
        """
        if len(imgs) != rows * cols:
            raise ValueError(f"expected {rows * cols} images for a {rows}x{cols} grid, got {len(imgs)}")
        w, h = imgs[0].size
        grid = Image.new('RGB', size=(cols * w, rows * h))
        for i, img in enumerate(imgs):
            row, col = divmod(i, cols)
            grid.paste(img, box=(col * w, row * h))
        return grid

    @staticmethod
    def preprocess_image(image, category, size=512):
        """Make ``image`` square and resize it to ``size`` x ``size``.

        * ``print`` / ``moodboard``: centre-crop to the shorter side.
        * ``sketch``: centre the image on a white square canvas sized to the
          longer side (nothing is cropped).

        Args:
            image: ``(H, W, 3)`` array; the sketch branch assumes 3 channels.
            category: one of ``"print"``, ``"moodboard"``, ``"sketch"``.
            size: output edge length in pixels (default 512).

        Raises:
            ValueError: for any other category.
        """
        height, width, _ = image.shape
        if category in ("print", "moodboard"):
            square_size = min(height, width)
            start_x = (width - square_size) // 2
            start_y = (height - square_size) // 2
            cropped = image[start_y: start_y + square_size, start_x: start_x + square_size]
            resized_image = cv2.resize(cropped, (size, size))
        elif category == "sketch":
            # Pad onto a "bigger" white square so the whole sketch is kept.
            max_dimension = max(height, width)
            square_image = np.full((max_dimension, max_dimension, 3), 255, dtype=np.uint8)
            start_h = (max_dimension - height) // 2
            start_w = (max_dimension - width) // 2
            square_image[start_h:start_h + height, start_w:start_w + width] = image
            resized_image = cv2.resize(square_image, (size, size))
        else:
            raise ValueError(f"wrong category {category}, only in moodboard, print and sketch!")
        return resized_image

    def get_image(self):
        """Download ``image_url`` from MinIO and return it preprocessed, in RGB order.

        ``image_url`` is split at its first ``/`` into bucket and object name.
        If MinIO raises ``S3Error`` (e.g. missing object), a ``(512, 512, 3)``
        float64 array of standard-normal noise is returned instead.

        Raises:
            ValueError: if the downloaded bytes cannot be decoded as an image,
                or the category is not supported by ``preprocess_image``.
        """
        bucket_name = self.image_url.split('/')[0]
        object_name = self.image_url[self.image_url.find('/') + 1:]
        response = None
        try:
            response = self.minio_client.get_object(bucket_name, object_name)
            raw = response.data
        except minio.error.S3Error:
            logger.warning("task %s: could not fetch %s, using random noise as input image",
                           self.tasks_id, self.image_url)
            return np.random.randn(512, 512, 3)
        finally:
            # Return the pooled HTTP connection to MinIO's connection pool.
            if response is not None:
                response.close()
                response.release_conn()

        buffer = np.frombuffer(raw, np.uint8)  # view the raw bytes as 8-bit unsigned ints
        img = cv2.imdecode(buffer, cv2.IMREAD_COLOR)  # decode to a BGR image
        if img is None:
            raise ValueError(f"could not decode image {self.image_url}")
        img = self.preprocess_image(img, self.category)
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    def _set_failure(self, error):
        """Record the task as FAILURE in Redis, using ``error``'s text as message and data."""
        generate_data = json.dumps({'status': 'FAILURE', 'message': f"{error}", 'data': f"{error}"})
        self.redis_client.set(self.tasks_id, generate_data)

    def _upload_result_images(self, result):
        """Convert the model's ``IMAGES`` output to PNGs, upload them and return their URLs."""
        images = result.as_numpy("IMAGES")
        if images.ndim == 3:
            images = images[None, ...]
        # Presumably the model emits floats in [0, 1] -- TODO confirm. Clip before the
        # uint8 cast so slightly out-of-range values do not wrap around.
        images = np.clip((images * 255).round(), 0, 255).astype("uint8")
        url_list = []
        for i, array in enumerate(images):
            image = Image.fromarray(array)
            if self.category == "sketch":
                image = remove_background(np.asarray(image))
            url_list.append(upload_png_sd(image,
                                          user_id=self.user_id,
                                          category=f"{self.category}",
                                          object_name=f"{generate_uuid()}_{i}.png"))
        return url_list

    def callback(self, result, error):
        """Completion callback passed to Triton's ``async_infer``.

        On error, or if post-processing/upload raises, the task is marked
        FAILURE in Redis; ``get_result`` then publishes that status. On
        success, the uploaded URLs are published to RabbitMQ and the task is
        marked SUCCESS.

        NOTE(review): this is invoked on a client callback thread while
        ``get_result`` may also use ``self.channel``; pika's BlockingConnection
        is not thread-safe -- confirm this cannot overlap in practice.
        """
        if error:
            self._set_failure(error)
            return
        try:
            url_list = self._upload_result_images(result)
        except Exception as exc:  # thread boundary: record failure instead of leaving PENDING
            logger.exception("task %s: post-processing/upload failed", self.tasks_id)
            self._set_failure(exc)
            return
        generate_data = json.dumps({'status': 'SUCCESS', 'message': 'success', 'data': f'{url_list}'})
        self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=generate_data)
        logger.info(f" [x] Sent {generate_data}")
        self.redis_client.set(self.tasks_id, generate_data)

    def read_tasks_status(self):
        """Return the task's status dict as stored in Redis (JSON-decoded).

        NOTE(review): if the key is missing, ``get`` returns ``None`` and
        ``json.loads`` raises ``TypeError``.
        """
        status_data = json.loads(self.redis_client.get(self.tasks_id))
        logger.info("%s ===> %s", self.tasks_id, status_data)
        return status_data

    def _build_inputs(self, image):
        """Build the Triton input tensors, in the order the request sends them."""
        def batched(name, value, datatype, dtype):
            # Every scalar input is repeated batch_size times, as PROMPT always was.
            tensor = grpc_client.InferInput(name, (self.batch_size,), datatype)
            tensor.set_data_from_numpy(np.asarray([value] * self.batch_size, dtype=dtype))
            return tensor

        input_images_in = grpc_client.InferInput("INPUT_IMAGES", image.shape, "FP16")
        input_images_in.set_data_from_numpy(image.astype(np.float16))
        return [
            batched("PROMPT", self.content, "BYTES", object),
            batched("SAMPLES", self.samples, "INT32", np.int32),
            batched("STEPS", self.steps, "INT32", np.int32),
            batched("GUIDANCE_SCALE", self.guidance_scale, "FP32", np.float32),
            batched("SEED", self.seed, "INT64", np.int64),
            input_images_in,
            batched("MODE", self.mode, "INT32", np.int32),
        ]

    @RunTime
    def get_result(self):
        """Submit the inference and wait (up to ``RESULT_TIMEOUT`` s) for it to finish.

        On REVOKED or FAILURE, the in-flight request is cancelled and that
        status is published to RabbitMQ. If the timeout expires, the request
        is left running and the current (likely PENDING) status is returned.

        Returns:
            The task's status dict as read from Redis.
        """
        # Model metadata/config are fetched but unused; presumably this is a check that
        # the model/version is served before submitting -- TODO confirm intent.
        self.triton_client.get_model_metadata(model_name=self.model_name, model_version=self.version)
        self.triton_client.get_model_config(model_name=self.model_name, model_version=self.version)
        inputs = self._build_inputs(self.get_image())
        outputs = [grpc_client.InferRequestedOutput(name="IMAGES")]

        @RunTime
        def infer():
            return self.triton_client.async_infer(
                model_name=self.model_name,
                model_version=self.version,
                inputs=inputs,
                outputs=outputs,
                callback=self.callback
            )

        ctx = infer()
        for _ in range(self.RESULT_TIMEOUT):
            generate_data = self.read_tasks_status()
            if generate_data['status'] in ["REVOKED", "FAILURE"]:
                ctx.cancel()
                self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES,
                                           body=json.dumps(generate_data))
                logger.info(f" [x] Sent {generate_data}")
                break
            if generate_data['status'] == "SUCCESS":
                break
            time.sleep(1)
        return self.read_tasks_status()


def infer_cancel(tasks_id):
    """Mark ``tasks_id`` as REVOKED in Redis and return the stored status dict.

    A running ``GenerateImage.get_result`` polling this task will then cancel
    its inference and publish the REVOKED status.
    """
    data = {'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'}
    redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
    try:
        redis_client.set(tasks_id, json.dumps(data))
    finally:
        redis_client.close()
    return data


if __name__ == '__main__':
    # Manual smoke test; needs the configured Triton, Redis, RabbitMQ and MinIO services.
    rd = GenerateImageModel(
        mode=1,
        content='a blouse',
        gender='',
        user_id=89,
        image_url='test/微信图片_20231206133428.jpg',
        category='sketch',
        version='1',
        tasks_id='123456'
    )
    server = GenerateImage(rd)
    server.get_result()
    # To revoke the task instead: print(infer_cancel(123456))