Files
AiDA_Python/app/service/generate_image/service.py
2024-04-15 18:26:48 +08:00

230 lines
9.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project trinity_client
@File service.py
@Author :周成融
@Date 2023/7/26 12:01:05
@detail
"""
import json
import logging
import numpy as np
import random
import redis
import tritonclient
import tritonclient.grpc as grpc_client
from minio import Minio
import cv2
from PIL import Image
import time
from app.core.config import *
from app.schemas.generate_image import GenerateImageModel
from app.service.generate_image.utils.remove_background import remove_background
from app.service.generate_image.utils.upload_sd_image import upload_png_sd
from app.service.utils.decorator import RunTime
from app.service.utils.generate_uuid import generate_uuid
# Module-level logger. Calling getLogger without a name returns the root logger,
# so records go wherever the application configured root logging.
logger = logging.getLogger(name=None)
class GenerateImage:
    """Run one image-generation request against a Triton model and track its status.

    Lifecycle:
    * ``__init__`` opens Triton gRPC, Redis, RabbitMQ and MinIO clients and writes a
      ``PENDING`` status JSON under ``tasks_id`` in Redis.
    * ``get_result`` loads and preprocesses the source image, submits an async
      inference request and polls Redis once per second (up to 60 times) until the
      task reaches a terminal status.
    * ``callback`` is passed to ``async_infer``; it uploads the generated images and
      writes ``SUCCESS`` or ``FAILURE`` to Redis.
    """

    def __init__(self, request_data):
        """Store request fields, open service clients and mark the task PENDING.

        Args:
            request_data: request object exposing ``tasks_id``, ``image_url``,
                ``user_id``, ``content``, ``category``, ``mode`` and ``version``
                (e.g. ``GenerateImageModel``).
        """
        self.tasks_id = request_data.tasks_id
        self.image_url = request_data.image_url
        self.user_id = request_data.user_id
        self.content = request_data.content
        self.category = request_data.category
        # The Triton model name is derived from the category: "<category><GI_MODEL_NAME>".
        self.model_name = f"{self.category}{GI_MODEL_NAME}"
        self.mode = request_data.mode
        self.version = request_data.version
        self.triton_client = grpc_client.InferenceServerClient(url=f"{GI_MODEL_URL}")
        self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
        # NOTE(review): `pika` is not imported in this file; it presumably arrives via
        # `from app.core.config import *` — confirm, or import it explicitly.
        self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
        self.channel = self.connection.channel()
        self.minio_client = Minio(
            f"{MINIO_IP}:{MINIO_PORT}",
            access_key=MINIO_ACCESS,
            secret_key=MINIO_SECRET,
            secure=MINIO_SECURE)
        self.samples = 4  # number of images to generate (sent as the SAMPLES input)
        self.steps = 24
        self.guidance_scale = 7
        self.seed = random.randint(0, 2000000000)
        self.batch_size = 1
        self.generate_data = json.dumps({'status': 'PENDING', 'message': "pending", 'data': ''})
        self.redis_client.set(self.tasks_id, self.generate_data)

    def __del__(self):
        """Close the Redis, Triton and RabbitMQ connections.

        If ``__init__`` raised before every client was created, the missing
        attributes are skipped. Without this, the finalizer would raise
        ``AttributeError``.
        """
        for attr in ("redis_client", "triton_client", "connection"):
            client = getattr(self, attr, None)
            if client is not None:
                client.close()

    @staticmethod
    def image_grid(imgs, rows, cols):
        """Paste ``rows * cols`` equally sized PIL images into one RGB grid image.

        The tile size is taken from the first image. Images are laid out row by row.
        """
        assert len(imgs) == rows * cols
        w, h = imgs[0].size
        grid = Image.new('RGB', size=(cols * w, rows * h))
        for i, img in enumerate(imgs):
            grid.paste(img, box=(i % cols * w, i // cols * h))
        return grid

    @staticmethod
    def preprocess_image(image, category):
        """Make ``image`` square and resize it to 512x512 according to ``category``.

        * ``print`` / ``moodboard``: centre-crop to the largest square, then resize.
        * ``sketch``: centre the image on a white square canvas whose side is the
          longer edge (nothing is cropped), then resize.

        Args:
            image: HxWx3 array; ``get_image`` passes the uint8 output of
                ``cv2.imdecode(..., cv2.IMREAD_COLOR)``.
            category: one of ``"print"``, ``"moodboard"``, ``"sketch"``.

        Raises:
            ValueError: for any other category.
        """
        height, width, _ = image.shape
        if category in ("print", "moodboard"):
            side = min(height, width)
            top = (height - side) // 2
            left = (width - side) // 2
            square = image[top:top + side, left:left + side]
        elif category == "sketch":
            side = max(height, width)
            square = np.ones((side, side, 3), dtype=np.uint8) * 255
            top = (side - height) // 2
            left = (side - width) // 2
            square[top:top + height, left:left + width] = image
        else:
            raise ValueError(f"wrong category {category}, only in moodboard, print and sketch!")
        return cv2.resize(square, (512, 512))

    def get_image(self):
        """Download ``self.image_url`` from MinIO and return it as a 512x512 RGB array.

        ``image_url`` is read as ``<bucket>/<object name>``, split at the first '/'.
        If downloading, decoding or preprocessing fails, the error is logged and a
        512x512x3 array of standard-normal noise is returned instead. This
        fallback lets inference proceed.
        """
        response = None
        try:
            response = self.minio_client.get_object(
                self.image_url.split('/')[0], self.image_url[self.image_url.find('/') + 1:])
            img = np.frombuffer(response.data, np.uint8)  # raw bytes as uint8
            img = cv2.imdecode(img, cv2.IMREAD_COLOR)  # BGR image; None if undecodable
            img = self.preprocess_image(img, self.category)
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        except Exception:
            logger.exception("failed to load image %s for task %s; using random noise",
                             self.image_url, self.tasks_id)
            img = np.random.randn(512, 512, 3)
        finally:
            # The MinIO response holds a pooled HTTP connection until it is closed
            # and released.
            if response is not None:
                response.close()
                response.release_conn()
        return img

    def _set_failure(self, error):
        """Store a FAILURE status carrying ``str(error)`` under this task's Redis key."""
        generate_data = json.dumps({'status': 'FAILURE', 'message': f"{error}", 'data': f"{error}"})
        self.redis_client.set(self.tasks_id, generate_data)

    def _upload_results(self, result):
        """Convert the ``IMAGES`` output to PNGs, upload them and return their URLs."""
        images = result.as_numpy("IMAGES")
        if images.ndim == 3:  # single image without a batch axis -> add one
            images = images[None, ...]
        # The output is scaled by 255, so it is expected to lie in [0, 1]. Clip first
        # so values slightly outside that range saturate instead of wrapping in uint8.
        images = (np.clip(images, 0, 1) * 255).round().astype("uint8")
        url_list = []
        for i, array in enumerate(images):
            image = Image.fromarray(array)
            if self.category == "sketch":
                image = remove_background(np.asarray(image))
            url_list.append(upload_png_sd(image, user_id=self.user_id, category=f"{self.category}",
                                          object_name=f"{generate_uuid()}_{i}.png", ))
        return url_list

    def callback(self, result, error):
        """Handle completion of the async Triton request started by ``get_result``.

        On ``error``, writes a FAILURE status to Redis. Otherwise uploads the
        generated images and publishes a SUCCESS message to ``GI_RABBITMQ_QUEUES``.
        The message's ``data`` is the ``str()`` of the URL list. The same message
        is then stored in Redis. Any exception during post-processing is recorded
        as FAILURE, so the polling loop in ``get_result`` stops instead of waiting
        out its timeout.
        """
        # NOTE(review): this presumably runs on a tritonclient/gRPC worker thread,
        # while self.channel was created on the caller's thread. pika's
        # BlockingConnection is not thread-safe — confirm the access is never concurrent.
        if error:
            self._set_failure(error)
            return
        try:
            url_list = self._upload_results(result)
            generate_data = json.dumps({'status': 'SUCCESS', 'message': 'success', 'data': f'{url_list}'})
            self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=generate_data)
            logger.info(f" [x] Sent {generate_data}")
            self.redis_client.set(self.tasks_id, generate_data)
        except Exception as exc:
            logger.exception("post-processing failed for task %s", self.tasks_id)
            self._set_failure(exc)

    def read_tasks_status(self):
        """Read, log and return this task's status dict from Redis."""
        status_data = json.loads(self.redis_client.get(self.tasks_id))
        logging.info(f"{self.tasks_id} ===> {status_data}")
        return status_data

    def _build_inputs(self, image):
        """Create the Triton input tensors for ``image`` and the requested ``IMAGES`` output."""
        prompt_in = grpc_client.InferInput(name="PROMPT", shape=(self.batch_size,), datatype="BYTES")
        samples_in = grpc_client.InferInput("SAMPLES", (self.batch_size,), "INT32")
        steps_in = grpc_client.InferInput("STEPS", (self.batch_size,), "INT32")
        guidance_scale_in = grpc_client.InferInput("GUIDANCE_SCALE", (self.batch_size,), "FP32")
        seed_in = grpc_client.InferInput("SEED", (self.batch_size,), "INT64")
        input_images_in = grpc_client.InferInput("INPUT_IMAGES", image.shape, "FP16")
        mode_in = grpc_client.InferInput("MODE", (self.batch_size,), "INT32")
        images_out = grpc_client.InferRequestedOutput(name="IMAGES")

        prompt_in.set_data_from_numpy(np.asarray([self.content] * self.batch_size, dtype=object))
        samples_in.set_data_from_numpy(np.asarray([self.samples], dtype=np.int32))
        steps_in.set_data_from_numpy(np.asarray([self.steps], dtype=np.int32))
        guidance_scale_in.set_data_from_numpy(np.asarray([self.guidance_scale], dtype=np.float32))
        seed_in.set_data_from_numpy(np.asarray([self.seed], dtype=np.int64))
        input_images_in.set_data_from_numpy(image.astype(np.float16))
        mode_in.set_data_from_numpy(np.asarray([self.mode], dtype=np.int32))

        inputs = [prompt_in, samples_in, steps_in, guidance_scale_in, seed_in, input_images_in, mode_in]
        return inputs, [images_out]

    @RunTime
    def get_result(self):
        """Run inference for this task and wait for a terminal status.

        Polls Redis once per second, up to 60 times. On REVOKED (e.g. written by
        ``infer_cancel``) or FAILURE, the in-flight request is cancelled and the
        status is published to ``GI_RABBITMQ_QUEUES``. On SUCCESS, ``callback`` has
        already published.

        Returns:
            The task's status dict read from Redis after the wait. It is still
            ``PENDING`` if the model has not finished by the timeout. The request
            is not cancelled in that case, so ``callback`` may complete it later.
        """
        # Results are discarded; presumably these calls make a missing model or
        # version fail fast before any image work is done.
        self.triton_client.get_model_metadata(model_name=self.model_name, model_version=self.version)
        self.triton_client.get_model_config(model_name=self.model_name, model_version=self.version)
        image = self.get_image()
        inputs, outputs = self._build_inputs(image)

        @RunTime
        def infer():
            return self.triton_client.async_infer(
                model_name=self.model_name,
                model_version=self.version,
                inputs=inputs,
                outputs=outputs,
                callback=self.callback,
            )

        ctx = infer()
        time_out = 60
        while time_out > 0:
            generate_data = self.read_tasks_status()
            if generate_data['status'] in ["REVOKED", "FAILURE"]:
                ctx.cancel()
                self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=json.dumps(generate_data))
                logger.info(f" [x] Sent {generate_data}")
                break
            elif generate_data['status'] == "SUCCESS":
                break
            time_out -= 1
            time.sleep(1)
        return self.read_tasks_status()
def infer_cancel(tasks_id, redis_client=None):
    """Mark a generation task as REVOKED in Redis and return the status dict.

    ``GenerateImage.get_result`` polls this key. When it sees ``REVOKED``, it
    cancels the in-flight inference and publishes the status.

    Args:
        tasks_id: Redis key of the task.
        redis_client: optional existing client to reuse (anything with ``set``).
            When omitted, a client is created from the REDIS_* settings and closed
            before returning. A client passed in by the caller is left open.

    Returns:
        ``{'status': 'REVOKED', 'message': 'revoked', 'data': 'revoked'}``
    """
    data = {'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'}
    owns_client = redis_client is None
    if owns_client:
        redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
    try:
        redis_client.set(tasks_id, json.dumps(data))
    finally:
        # Only close the client this function created.
        if owns_client:
            redis_client.close()
    return data
if __name__ == '__main__':
    # Manual smoke test: run one "sketch" generation end to end against the
    # configured MinIO / Triton / Redis / RabbitMQ services.
    # To revoke a running task from another process, call infer_cancel with its
    # tasks_id (here '123456').
    demo_fields = dict(
        mode=1,
        content='a blouse',
        gender='',
        user_id=89,
        image_url='test/微信图片_20231206133428.jpg',
        category='sketch',
        version='1',
        tasks_id='123456',
    )
    request_data = GenerateImageModel(**demo_fields)
    server = GenerateImage(request_data)
    server.get_result()