#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project :trinity_client
@File    :service_att_recognition.py
@Author  :周成融
@Date    :2023/7/26 12:01:05
@detail  : Relight-image generation service. Sends an image plus text prompts to a
           Triton model over gRPC, tracks the task status in Redis and, outside
           DEBUG mode, publishes the final status to a RabbitMQ queue.
"""
import io
import json
import logging
import time

import cv2
import redis
import tritonclient.grpc as grpcclient
import numpy as np
from PIL import Image, ImageOps
from minio import Minio
from tritonclient.utils import np_to_triton_dtype

from app.core.config import *
from app.schemas.generate_image import GenerateRelightImageModel
from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image
from app.service.utils.oss_client import oss_get_image

logger = logging.getLogger()

# Lifetime (seconds) of every task-status key written to Redis. Every write must
# pass it again, because a plain Redis SET discards any existing TTL.
_TASK_STATUS_TTL = 600
# get_result polls Redis at most _POLL_ATTEMPTS times, _POLL_INTERVAL s apart (60 s total).
_POLL_ATTEMPTS = 600
_POLL_INTERVAL = 0.1


class GenerateRelightImage:
    """Run one relight-image inference task and track its status in Redis."""

    def __init__(self, request_data):
        """
        Prepare clients, fetch the source image and register the task as PENDING.

        :param request_data: request object exposing ``prompt``, ``image_url``
            (``"<bucket>/<object name>"``) and ``tasks_id`` (the text after the
            last ``-`` is used as the user id).
        """
        if DEBUG is False:
            # pika is only needed to publish results; importing it here keeps the
            # module importable in DEBUG setups that do not have it installed.
            import pika
            self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
            self.channel = self.connection.channel()
        self.grpc_client = grpcclient.InferenceServerClient(url=GRI_MODEL_URL)
        self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB,
                                              decode_responses=True)
        self.category = "relight_image"
        self.batch_size = 1
        self.prompt = request_data.prompt
        self.seed = "1"
        self.negative_prompt = 'lowres, bad anatomy, bad hands, cropped, worst quality'
        self.direction = "Right Light"
        self.image_url = request_data.image_url
        # The first path segment is the bucket; the rest is the object name.
        self.image = oss_get_image(bucket=self.image_url.split('/')[0],
                                   object_name=self.image_url[self.image_url.find('/') + 1:],
                                   data_type="cv2")
        self.tasks_id = request_data.tasks_id
        self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
        self.gen_product_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending",
                                 'image_url': ''}
        self._save_status()

    def _save_status(self):
        """Write the local status dict to Redis, (re)applying the key TTL."""
        self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data), ex=_TASK_STATUS_TTL)

    def _mark_failure(self, error):
        """Record FAILURE with ``str(error)`` as the message."""
        self.gen_product_data['status'] = "FAILURE"
        self.gen_product_data['message'] = str(error)
        self._save_status()

    def callback(self, result, error):
        """
        Completion hook passed to Triton ``async_infer``; stores the outcome in Redis.

        NOTE(review): this is presumably invoked on a tritonclient/gRPC worker
        thread, where a raised exception would never reach get_result(), so
        failures are recorded as FAILURE instead of being raised.
        """
        if error:
            # Our own ctx.cancel() (after REVOKED or a timeout) also ends up here;
            # keep the status already recorded instead of overwriting it.
            if self.read_tasks_status()[0].get('status') in ("REVOKED", "FAILURE"):
                return
            self._mark_failure(error)
            return
        try:
            # Convert the model's output array into a PIL image (singleton dims squeezed).
            image = result.as_numpy("generated_inpaint_image")
            image_result = Image.fromarray(np.squeeze(image.astype(np.uint8)))
            image_url = upload_SDXL_image(image_result, user_id=self.user_id, category=f"{self.category}",
                                          file_name=f"{self.tasks_id}.png")
        except Exception as e:
            logger.exception(f"relight task {self.tasks_id}: failed to process inference result")
            self._mark_failure(e)
            return
        self.gen_product_data['status'] = "SUCCESS"
        self.gen_product_data['message'] = "success"
        self.gen_product_data['image_url'] = str(image_url)
        self._save_status()

    def read_tasks_status(self):
        """
        Return the task status as ``(dict, raw JSON string)``.

        Falls back to the locally cached status when the Redis key is missing
        (for example, after it expired).
        """
        status_data = self.redis_client.get(self.tasks_id)
        if status_data is None:
            status_data = json.dumps(self.gen_product_data)
        return json.loads(status_data), status_data

    def _build_inputs(self):
        """Build the Triton inputs: prompt, negative_prompt, input_image, seed, direction."""
        prompts = [self.prompt] * self.batch_size
        image = cv2.cvtColor(self.image, cv2.COLOR_BGR2RGB)
        # cv2.resize takes (width, height), so the image becomes 768 rows x 512 columns.
        image = cv2.resize(image, (512, 768))
        images = [image.astype(np.uint8)] * self.batch_size
        seeds = [self.seed] * self.batch_size
        negative_prompts = [self.negative_prompt] * self.batch_size
        directions = [self.direction] * self.batch_size

        # NOTE: these fixed reshapes only hold for batch_size == 1.
        text_obj = np.array(prompts, dtype="object").reshape((1))
        image_obj = np.array(images, dtype=np.uint8).reshape((768, 512, 3))
        na_text_obj = np.array(negative_prompts, dtype="object").reshape((1))
        seed_obj = np.array(seeds, dtype="object").reshape((1))
        direction_obj = np.array(directions, dtype="object").reshape((1))

        input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype))
        input_image = grpcclient.InferInput("input_image", image_obj.shape, "UINT8")
        input_natext = grpcclient.InferInput("negative_prompt", na_text_obj.shape,
                                             np_to_triton_dtype(na_text_obj.dtype))
        input_seed = grpcclient.InferInput("seed", seed_obj.shape, np_to_triton_dtype(seed_obj.dtype))
        input_direction = grpcclient.InferInput("direction", direction_obj.shape,
                                                np_to_triton_dtype(direction_obj.dtype))
        input_text.set_data_from_numpy(text_obj)
        input_image.set_data_from_numpy(image_obj)
        input_natext.set_data_from_numpy(na_text_obj)
        input_seed.set_data_from_numpy(seed_obj)
        input_direction.set_data_from_numpy(direction_obj)
        return [input_text, input_natext, input_image, input_seed, input_direction]

    def get_result(self):
        """
        Submit the inference request and wait up to 60 s for a terminal status.

        Polls Redis for SUCCESS / FAILURE / REVOKED, cancelling the request on
        FAILURE or REVOKED. If the task is still PENDING when polling ends, it is
        marked FAILURE and cancelled. Outside DEBUG mode the final status is
        always published to ``GRI_RABBITMQ_QUEUES``.

        :return: the final status dict read back from Redis.
        :raises Exception: any error raised while preparing or submitting the
            request (the original exception is chained as ``__cause__``).
        """
        try:
            inputs = self._build_inputs()
            ctx = self.grpc_client.async_infer(model_name=GRI_MODEL_NAME, inputs=inputs, callback=self.callback)
            for _ in range(_POLL_ATTEMPTS):
                status = self.read_tasks_status()[0]['status']
                if status in ["REVOKED", "FAILURE"]:
                    ctx.cancel()
                    break
                if status == "SUCCESS":
                    break
                time.sleep(_POLL_INTERVAL)

            gen_product_data, _ = self.read_tasks_status()
            if gen_product_data['status'] == "PENDING":
                # Timed out: record the failure first, so the cancellation error
                # delivered to the callback does not replace this message.
                self._mark_failure(
                    f"timed out after {_POLL_ATTEMPTS * _POLL_INTERVAL:.0f}s waiting for the inference result")
                ctx.cancel()
                gen_product_data, _ = self.read_tasks_status()
            return gen_product_data
        except Exception as e:
            self._mark_failure(e)
            raise Exception(str(e)) from e
        finally:
            dict_gen_product_data, str_gen_product_data = self.read_tasks_status()
            if DEBUG is False:
                self.channel.basic_publish(exchange='', routing_key=GRI_RABBITMQ_QUEUES, body=str_gen_product_data)
            logger.info(f" [x] Sent to: {GRI_RABBITMQ_QUEUES} data:@@@@ {json.dumps(dict_gen_product_data, indent=4)}")


def infer_cancel(tasks_id):
    """
    Mark a task as REVOKED in Redis; a running get_result() sees it and cancels the request.

    :param tasks_id: id of the task to revoke.
    :return: the status dict written to Redis.
    """
    redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
    # 'image_url' keeps the payload shape consistent with the other statuses;
    # 'data' is kept for existing consumers.
    data = {'tasks_id': tasks_id, 'status': 'REVOKED', 'message': "revoked", 'data': 'revoked', 'image_url': ''}
    redis_client.set(tasks_id, json.dumps(data), ex=_TASK_STATUS_TTL)
    return data


if __name__ == '__main__':
    rd = GenerateRelightImageModel(
        tasks_id="123-89",
        # prompt="beautiful woman, detailed face, sunshine, outdoor, warm atmosphere",
        prompt="Colorful black",
        image_url='aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png'
    )
    server = GenerateRelightImage(rd)
    print(server.get_result())