#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project :trinity_client
@File    :service_att_recognition.py
@Author  :周成融
@Date    :2023/7/26 12:01:05
@detail  :Relight-image generation client. Sends a prompt, an input image reference,
          a seed and a light direction to a Triton model over gRPC, tracks the task
          status in Redis and (when DEBUG is False) publishes the final status to
          RabbitMQ.
"""
import io
import json
import logging
import time

import cv2
import redis
import tritonclient.grpc as grpcclient
import numpy as np
from PIL import Image, ImageOps
from minio import Minio
from tritonclient.utils import np_to_triton_dtype

from app.core.config import *
from app.schemas.generate_image import GenerateRelightImageModel
from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image

logger = logging.getLogger()


class GenerateRelightImage:
    """Submit one relight-image inference request and track its status in Redis.

    The status record stored under ``tasks_id`` is a JSON object with the keys
    ``tasks_id``, ``status`` (PENDING / SUCCESS / FAILURE / REVOKED), ``message``
    and ``image_url``.
    """

    # Seconds a status record lives in Redis; refreshed on every status write.
    STATUS_TTL = 600
    # Upper bound (seconds) get_result() waits for the model before giving up.
    RESULT_TIMEOUT = 60
    # Delay (seconds) between Redis status polls in get_result().
    POLL_INTERVAL = 0.1
    # Light direction and negative prompt sent with every request.
    DIRECTION = "Right Light"
    NEGATIVE_PROMPT = 'lowres, bad anatomy, bad hands, cropped, worst quality'
    # Name of the model output read in callback().
    # NOTE(review): the original code also built an (unused) requested output named
    # "generated_image"; confirm which name the deployed model actually exposes.
    OUTPUT_NAME = "generated_inpaint_image"

    def __init__(self, request_data):
        """Open client connections and write the initial PENDING status to Redis.

        :param request_data: request object providing ``prompt``, ``image_url`` and
            ``tasks_id`` (e.g. GenerateRelightImageModel). The text after the last
            '-' in ``tasks_id`` is used as the user id.
        """
        if DEBUG is False:
            # pika was referenced without an import; presumably it is re-exported by
            # app.core.config. Importing it here makes the dependency explicit.
            import pika
            self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
            self.channel = self.connection.channel()
            # NOTE(review): this connection is never closed by this class.
        self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
        self.grpc_client = grpcclient.InferenceServerClient(url=GRI_MODEL_URL)
        self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)

        self.category = "relight_image"
        self.batch_size = 1
        self.prompt = request_data.prompt
        self.seed = "12345"
        # TODO: aida design - change the result image background to white
        # self.image, self.image_size = self.get_image(request_data.image_url)
        # The image reference string is forwarded to the model as-is.
        self.image = request_data.image_url
        # Size the generated image is resized to in callback(). It stays None (no
        # resize) unless get_image() is used to preprocess the input.
        self.image_size = None
        # TODO: pad the image and resize it to 512x768
        self.tasks_id = request_data.tasks_id
        self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
        self.gen_product_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': ''}
        self._save_status()

    def _save_status(self):
        """Write ``gen_product_data`` to Redis and (re)set its TTL.

        A plain SET discards any existing TTL, so the expiry is passed on every write.
        """
        self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data), ex=self.STATUS_TTL)

    def _mark_failure(self, error):
        """Record a FAILURE status whose message is ``str(error)``."""
        self.gen_product_data['status'] = "FAILURE"
        self.gen_product_data['message'] = str(error)
        self._save_status()

    def get_image(self, image_url, target_size=(512, 768)):
        """Download an image from MinIO and letterbox it onto a white canvas.

        :param image_url: "<bucket>/<object name>". The text before the first '/'
            is the bucket.
        :param target_size: (width, height) of the output canvas; defaults to 512x768.
        :return: (image array, (width, height)). The array is HxWx3 uint8 and has
            been passed through cv2.COLOR_BGR2RGB (see the note below).
        """
        response = self.minio_client.get_object(image_url.split('/')[0], image_url[image_url.find('/') + 1:])
        try:
            image_bytes = io.BytesIO(response.read())
        finally:
            # Release the HTTP connection back to the pool (required by the MinIO client).
            response.close()
            response.release_conn()
        image = Image.open(image_bytes)

        target_width, target_height = target_size
        # Scale to fit inside the target while keeping the aspect ratio, so wide
        # images are padded rather than cropped.
        if image.width * target_height <= target_width * image.height:
            new_size = (max(1, int(target_height * (image.width / image.height))), target_height)
        else:
            new_size = (target_width, max(1, int(target_width * (image.height / image.width))))
        # Converting to RGBA first handles RGB, L, LA and P images uniformly. The
        # alpha band then composites the image onto a white background.
        resized = image.convert("RGBA").resize(new_size)
        canvas = Image.new("RGB", (target_width, target_height), (255, 255, 255))
        offset = ((target_width - new_size[0]) // 2, (target_height - new_size[1]) // 2)
        canvas.paste(resized, offset, mask=resized)
        image_size = canvas.size

        # NOTE(review): np.array() of a PIL RGB image is already in RGB order, so this
        # swap actually yields BGR. It is kept unchanged; confirm what the model expects.
        image = cv2.cvtColor(np.array(canvas), cv2.COLOR_BGR2RGB)
        return image, image_size

    def callback(self, result, error):
        """Completion handler passed to ``async_infer``; records the outcome in Redis.

        On success, the model output is converted to an image and resized to
        ``self.image_size`` when one is set. The image is then uploaded via
        upload_SDXL_image, and the status becomes SUCCESS with the uploaded URL.
        Any error becomes FAILURE, whether Triton reports it or it is raised while
        post-processing. An existing REVOKED status is left intact.

        NOTE(review): this presumably runs on a tritonclient/gRPC background thread.
        An uncaught exception there would be lost, and get_result() would keep
        polling a PENDING status; hence the catch-all below.
        """
        current, _ = self.read_tasks_status()
        if current.get('status') == "REVOKED":
            # get_result() cancels the request on REVOKED, which also invokes this
            # callback with an error; keep the user's revocation.
            return
        if error:
            self._mark_failure(error)
            return
        try:
            output = result.as_numpy(self.OUTPUT_NAME)
            if output is None:
                raise ValueError(f"model response has no '{self.OUTPUT_NAME}' output")
            # Convert the numpy array into a PIL image.
            image_result = Image.fromarray(np.squeeze(output.astype(np.uint8)))
            if self.image_size is not None:
                image_result = image_result.resize(self.image_size)
            image_url = upload_SDXL_image(image_result, user_id=self.user_id, category=f"{self.category}",
                                          object_name=f"{self.tasks_id}.png")
        except Exception as exc:
            logger.exception("relight post-processing failed for task %s", self.tasks_id)
            self._mark_failure(exc)
            return
        self.gen_product_data['status'] = "SUCCESS"
        self.gen_product_data['message'] = "success"
        self.gen_product_data['image_url'] = str(image_url)
        self._save_status()

    def read_tasks_status(self):
        """Return the task status as ``(dict, json string)``.

        Falls back to the in-memory ``gen_product_data`` when the Redis key is missing
        (e.g. expired), instead of failing on ``json.loads(None)``.
        """
        status_data = self.redis_client.get(self.tasks_id)
        if status_data is None:
            status_data = json.dumps(self.gen_product_data)
        return json.loads(status_data), status_data

    def infer(self, inputs):
        """Start an asynchronous Triton inference; ``callback`` handles the result.

        :return: the call context returned by ``async_infer`` (used to ``cancel()``).
        """
        return self.grpc_client.async_infer(model_name=GRI_MODEL_NAME, inputs=inputs, callback=self.callback)

    @staticmethod
    def _bytes_input(name, values):
        """Wrap a list of strings as a Triton input of shape (len(values), 1)."""
        array = np.array(values, dtype="object").reshape((-1, 1))
        tensor = grpcclient.InferInput(name, array.shape, np_to_triton_dtype(array.dtype))
        tensor.set_data_from_numpy(array)
        return tensor

    def _handle_timeout(self, ctx):
        """Cancel a request that exceeded RESULT_TIMEOUT and mark it FAILURE if still PENDING."""
        ctx.cancel()
        if self.read_tasks_status()[0]['status'] == "PENDING":
            self._mark_failure("timeout")

    def _publish_status(self):
        """Publish the current status JSON to GPI_RABBITMQ_QUEUES (skipped when DEBUG is True)."""
        dict_gen_product_data, str_gen_product_data = self.read_tasks_status()
        if DEBUG is False:
            self.channel.basic_publish(exchange='', routing_key=GPI_RABBITMQ_QUEUES, body=str_gen_product_data)
            logger.info(f" [x] Sent to: {GPI_RABBITMQ_QUEUES} data:@@@@ {json.dumps(dict_gen_product_data, indent=4)}")

    def get_result(self):
        """Run the inference, wait for it to finish, and return the final status dict.

        Polls Redis every POLL_INTERVAL seconds. Polling stops when the status
        becomes SUCCESS, FAILURE or REVOKED; the latter two cancel the request.
        If RESULT_TIMEOUT elapses first, the request is cancelled and a still-PENDING
        task is marked FAILURE ("timeout"). Whatever the outcome, the final status
        is published when DEBUG is False.

        :raises Exception: wrapping any error raised while submitting or polling; the
            task is marked FAILURE first.
        """
        try:
            inputs = [
                self._bytes_input("prompt", [self.prompt] * self.batch_size),
                self._bytes_input("negative_prompt", [self.NEGATIVE_PROMPT] * self.batch_size),
                self._bytes_input("input_image", [self.image] * self.batch_size),
                self._bytes_input("seed", [self.seed] * self.batch_size),
                self._bytes_input("direction", [self.DIRECTION] * self.batch_size),
            ]
            ctx = self.infer(inputs)

            deadline = time.monotonic() + self.RESULT_TIMEOUT
            while True:
                status = self.read_tasks_status()[0]['status']
                if status in ("REVOKED", "FAILURE"):
                    ctx.cancel()
                    break
                if status == "SUCCESS":
                    break
                if time.monotonic() >= deadline:
                    self._handle_timeout(ctx)
                    break
                time.sleep(self.POLL_INTERVAL)

            gen_product_data, _ = self.read_tasks_status()
            return gen_product_data
        except Exception as e:
            self._mark_failure(e)
            raise Exception(str(e)) from e
        finally:
            self._publish_status()


def infer_cancel(tasks_id):
    """Mark a task REVOKED in Redis so a polling get_result() cancels its inference.

    :param tasks_id: id of the task to revoke.
    :return: the status dict that was written.
    """
    redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
    # 'image_url' keeps the record schema consistent with the other status writes;
    # 'data' is kept for existing consumers.
    data = {'tasks_id': tasks_id, 'status': 'REVOKED', 'message': "revoked", 'image_url': '', 'data': 'revoked'}
    redis_client.set(tasks_id, json.dumps(data), ex=GenerateRelightImage.STATUS_TTL)
    return data


if __name__ == '__main__':
    rd = GenerateRelightImageModel(
        tasks_id="123-89",
        prompt="beautiful woman, detailed face, sunshine, outdoor, warm atmosphere",
        image_url="/workspace/i3.png",
    )
    server = GenerateRelightImage(rd)
    print(server.get_result())