238 lines
9.5 KiB
Python
238 lines
9.5 KiB
Python
#!/usr/bin/env python
|
||
# -*- coding: UTF-8 -*-
|
||
"""
|
||
@Project :trinity_client
|
||
@File :service_att_recognition.py
|
||
@Author :周成融
|
||
@Date :2023/7/26 12:01:05
|
||
@detail :
|
||
"""
|
||
import json
|
||
import logging
|
||
|
||
import minio
|
||
import numpy as np
|
||
import random
|
||
import redis
|
||
import tritonclient
|
||
import tritonclient.grpc as grpc_client
|
||
from minio import Minio
|
||
import cv2
|
||
from PIL import Image
|
||
import time
|
||
from app.core.config import *
|
||
from app.schemas.generate_image import GenerateImageModel
|
||
from app.service.generate_image.utils.remove_background import remove_background
|
||
from app.service.generate_image.utils.upload_sd_image import upload_png_sd
|
||
from app.service.utils.decorator import RunTime
|
||
from app.service.utils.generate_uuid import generate_uuid
|
||
|
||
logger = logging.getLogger()
|
||
|
||
|
||
class GenerateImage:
    """Run one image-generation request against a Triton model and track its status.

    Construction registers the task in Redis as ``PENDING``. ``get_result`` submits an
    asynchronous inference request and polls Redis for the outcome; ``callback`` uploads
    the generated images and records ``SUCCESS`` (or ``FAILURE``) in Redis. Terminal
    statuses are published to the RabbitMQ queue ``GI_RABBITMQ_QUEUES``.
    """

    def __init__(self, request_data):
        """Copy request fields, open the service clients and mark the task PENDING.

        :param request_data: request object (e.g. ``GenerateImageModel``) exposing
            ``tasks_id``, ``model``, ``request_count``, ``prompt``, ``image``, ``mode``,
            ``image_url``, ``user_id``, ``content``, ``category`` and ``version``.
        """
        self.tasks_id = request_data.tasks_id
        self.model = request_data.model
        self.request_count = request_data.request_count
        self.prompt = request_data.prompt
        self.image = request_data.image
        self.mode = request_data.mode
        self.image_url = request_data.image_url
        self.user_id = request_data.user_id
        self.content = request_data.content
        self.category = request_data.category
        # One Triton model per category: "<category><GI_MODEL_NAME>".
        self.model_name = f"{self.category}{GI_MODEL_NAME}"
        self.version = request_data.version

        # NOTE(review): url="1" looks like a placeholder/redacted Triton endpoint — confirm.
        self.triton_client = grpc_client.InferenceServerClient(url="1")
        self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
        # NOTE(review): ``pika`` is not imported explicitly in this module; it presumably
        # arrives via ``from app.core.config import *`` — confirm.
        self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
        self.channel = self.connection.channel()
        self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)

        # Fixed sampling parameters sent to the model.
        self.samples = 4  # number of images to generate
        self.steps = 24
        self.guidance_scale = 7
        self.seed = random.randint(0, 2000000000)
        # Every Triton input in get_result carries exactly one element, so the batch
        # size is pinned to 1; request_data.batch_size is intentionally not used.
        self.batch_size = 1

        self.generate_data = json.dumps({'status': 'PENDING', 'message': "pending", 'data': ''})
        self.redis_client.set(self.tasks_id, self.generate_data)

    def __del__(self):
        """Best-effort release of the Redis, Triton and RabbitMQ connections.

        Tolerates a partially constructed instance (``__init__`` failed before every
        client existed) and keeps closing the remaining clients if one ``close()``
        raises, because exceptions cannot usefully escape ``__del__``.
        """
        for attr in ("redis_client", "triton_client", "connection"):
            client = getattr(self, attr, None)
            if client is None:
                continue
            try:
                client.close()
            except Exception:  # teardown is best-effort; keep closing the others
                pass

    @staticmethod
    def image_grid(imgs, rows, cols):
        """Tile ``rows * cols`` PIL images into one RGB image, filling row by row.

        Every tile is placed using the size of ``imgs[0]``, so all images are assumed
        to share that size.

        :raises ValueError: if ``imgs`` is empty or ``len(imgs) != rows * cols``.
        """
        if len(imgs) != rows * cols:
            raise ValueError(f"expected {rows * cols} images for a {rows}x{cols} grid, got {len(imgs)}")
        if not imgs:
            raise ValueError("image_grid needs at least one image")

        w, h = imgs[0].size
        grid = Image.new('RGB', size=(cols * w, rows * h))
        for i, img in enumerate(imgs):
            grid.paste(img, box=(i % cols * w, i // cols * h))
        return grid

    @staticmethod
    def preprocess_image(image, category):
        """Make a decoded image square and resize it to 512x512.

        * ``"print"`` / ``"moodboard"``: centre-crop to the largest square, then resize.
        * ``"sketch"``: centre the image on a white square canvas (padding rather than
          cropping, so no part of the sketch is cut away), then resize.

        :param image: array of shape ``(height, width, channels)`` as produced by
            ``get_image`` (``cv2.imdecode``, uint8); the sketch path needs 3 channels.
        :raises ValueError: for any other category.
        """
        height, width, _ = image.shape

        if category in ("print", "moodboard"):
            square_size = min(height, width)
            start_x = (width - square_size) // 2
            start_y = (height - square_size) // 2
            cropped = image[start_y: start_y + square_size, start_x: start_x + square_size]
            return cv2.resize(cropped, (512, 512))

        if category == "sketch":
            max_dimension = max(height, width)
            square_image = np.ones((max_dimension, max_dimension, 3), dtype=np.uint8) * 255
            start_h = (max_dimension - height) // 2
            start_w = (max_dimension - width) // 2
            square_image[start_h:start_h + height, start_w:start_w + width] = image
            return cv2.resize(square_image, (512, 512))

        raise ValueError(f"wrong category {category}, only in moodboard, print and sketch!")

    def get_image(self):
        """Download ``self.image_url`` from MinIO and return a 512x512 RGB array.

        ``image_url`` is ``"<bucket>/<object name>"``. If MinIO raises ``S3Error``,
        a 512x512x3 float array of standard-normal noise is returned instead.

        :raises ValueError: if the downloaded bytes cannot be decoded as an image, or
            ``self.category`` is not accepted by ``preprocess_image``.
        """
        bucket_name = self.image_url.split('/')[0]
        object_name = self.image_url[self.image_url.find('/') + 1:]
        try:
            response = self.minio_client.get_object(bucket_name, object_name)
        except minio.error.S3Error as exc:
            logger.warning(f"{self.tasks_id}: cannot fetch {self.image_url} ({exc}); using random noise input")
            return np.random.randn(512, 512, 3)
        try:
            data = response.data
        finally:
            # minio-py hands back a pooled HTTP response; it must be closed and released.
            response.close()
            response.release_conn()

        buffer = np.frombuffer(data, np.uint8)  # raw bytes as unsigned 8-bit ints
        img = cv2.imdecode(buffer, cv2.IMREAD_COLOR)  # decode to a BGR colour image
        if img is None:
            raise ValueError(f"could not decode image data from {self.image_url!r}")
        img = self.preprocess_image(img, self.category)
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    def _record_failure(self, error):
        """Store a FAILURE status carrying ``error``'s text under this task's Redis key."""
        generate_data = json.dumps({'status': 'FAILURE', 'message': f"{error}", 'data': f"{error}"})
        self.redis_client.set(self.tasks_id, generate_data)

    def _upload_images(self, images):
        """Convert the model's ``IMAGES`` output to PNGs, upload them and return their URLs.

        Values are scaled by 255 before the uint8 cast, so they are presumably in [0, 1].
        Sketch results additionally go through ``remove_background`` before upload.
        """
        if images.ndim == 3:  # a single image without a batch axis
            images = images[None, ...]
        images = (images * 255).round().astype("uint8")

        url_list = []
        for i, array in enumerate(images):
            image = Image.fromarray(array)
            if self.category == "sketch":
                image = remove_background(np.asarray(image))
            url_list.append(upload_png_sd(image, user_id=self.user_id, category=f"{self.category}", object_name=f"{generate_uuid()}_{i}.png", ))
        return url_list

    def callback(self, result, error):
        """Completion handler passed to Triton's ``async_infer``.

        On ``error``, record FAILURE in Redis. Otherwise upload the images, publish the
        SUCCESS payload to RabbitMQ and record it in Redis.

        NOTE(review): this runs on the Triton client's callback thread while
        ``get_result`` polls on the caller's thread; pika's ``BlockingConnection`` is
        not thread-safe, so confirm both threads never publish concurrently.
        """
        if error:
            self._record_failure(error)
            return

        try:
            url_list = self._upload_images(result.as_numpy("IMAGES"))
            generate_data = json.dumps({'status': 'SUCCESS', 'message': 'success', 'data': f'{url_list}'})
            self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=generate_data)
            logger.info(f" [x] Sent {generate_data}")
            self.redis_client.set(self.tasks_id, generate_data)
        except Exception as exc:
            # get_result only observes Redis, so an exception lost on this thread would
            # leave the task PENDING until its polling loop times out.
            logger.exception(f"{self.tasks_id}: post-processing of inference result failed")
            self._record_failure(exc)

    def read_tasks_status(self):
        """Return the task's status dict stored in Redis under ``self.tasks_id``.

        Raises ``TypeError`` if the key no longer exists (``json.loads(None)``).
        """
        status_data = json.loads(self.redis_client.get(self.tasks_id))
        logger.info(f"{self.tasks_id} ===> {status_data}")
        return status_data

    def get_result(self):
        """Submit the inference request and poll Redis for up to 60 s for its outcome.

        REVOKED (written by ``infer_cancel``) or FAILURE statuses cancel the in-flight
        request and are published to RabbitMQ; SUCCESS is published by ``callback``.

        :return: the task's latest status dict (still ``PENDING`` after a timeout).
        """
        # Results are unused; presumably a sanity check that the model/version is served.
        self.triton_client.get_model_metadata(model_name=self.model_name, model_version=self.version)
        self.triton_client.get_model_config(model_name=self.model_name, model_version=self.version)

        image = self.get_image()

        # Input placeholders (one element each, see self.batch_size).
        prompt_in = grpc_client.InferInput(name="PROMPT", shape=(self.batch_size,), datatype="BYTES")
        samples_in = grpc_client.InferInput("SAMPLES", (self.batch_size,), "INT32")
        steps_in = grpc_client.InferInput("STEPS", (self.batch_size,), "INT32")
        guidance_scale_in = grpc_client.InferInput("GUIDANCE_SCALE", (self.batch_size,), "FP32")
        seed_in = grpc_client.InferInput("SEED", (self.batch_size,), "INT64")
        input_images_in = grpc_client.InferInput("INPUT_IMAGES", image.shape, "FP16")
        mode_in = grpc_client.InferInput("MODE", (self.batch_size,), "INT32")
        images_out = grpc_client.InferRequestedOutput(name="IMAGES")

        # Input data
        prompt_in.set_data_from_numpy(np.asarray([self.content] * self.batch_size, dtype=object))
        samples_in.set_data_from_numpy(np.asarray([self.samples], dtype=np.int32))
        steps_in.set_data_from_numpy(np.asarray([self.steps], dtype=np.int32))
        guidance_scale_in.set_data_from_numpy(np.asarray([self.guidance_scale], dtype=np.float32))
        seed_in.set_data_from_numpy(np.asarray([self.seed], dtype=np.int64))
        input_images_in.set_data_from_numpy(image.astype(np.float16))
        mode_in.set_data_from_numpy(np.asarray([self.mode], dtype=np.int32))

        ctx = self.triton_client.async_infer(
            model_name=self.model_name,
            model_version=self.version,
            inputs=[prompt_in, samples_in, steps_in, guidance_scale_in, seed_in, input_images_in, mode_in],
            outputs=[images_out],
            callback=self.callback,
        )

        # Check the Redis status once per second, for at most 60 checks.
        for _ in range(60):
            generate_data = self.read_tasks_status()
            if generate_data['status'] in ["REVOKED", "FAILURE"]:
                ctx.cancel()
                self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=json.dumps(generate_data))
                logger.info(f" [x] Sent {generate_data}")
                break
            if generate_data['status'] == "SUCCESS":
                break
            time.sleep(1)
        return self.read_tasks_status()
|
||
|
||
|
||
def infer_cancel(tasks_id):
    """Mark a generation task as REVOKED in Redis.

    A ``GenerateImage.get_result`` poll that is still running for this task sees the
    REVOKED status, cancels its Triton request and publishes the status to RabbitMQ.

    :param tasks_id: Redis key of the task (the request's ``tasks_id``).
    :return: the status dict that was written.
    """
    data = {'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'}
    redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
    try:
        redis_client.set(tasks_id, json.dumps(data))
    finally:
        # The client is created per call, so release its connections right away.
        redis_client.close()
    return data
|
||
|
||
|
||
if __name__ == '__main__':
    # Manual smoke run: generate from a sketch stored in MinIO using the configured
    # Triton / Redis / RabbitMQ / MinIO services.
    demo_fields = {
        'mode': 1,
        'content': 'a blouse',
        'gender': '',
        'user_id': 89,
        'image_url': 'test/微信图片_20231206133428.jpg',
        'category': 'sketch',
        'version': '1',
        'tasks_id': '123456',
    }
    generator = GenerateImage(GenerateImageModel(**demo_fields))
    generator.get_result()
|