Files
AiDA_Python/app/service/generate_image/service.py

176 lines
7.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
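"""
@Project trinity_client
@File service.py
@Author :周成融
@Date 2023/7/26 12:01:05
@detail Image-generation service. Sends prompt/mode/image requests to a Triton
        model over gRPC, tracks each task's status as JSON in Redis, and (when
        DEBUG is False) publishes the final status to a RabbitMQ queue.
"""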
import json
import logging
import time
from io import BytesIO
import cv2
import minio
import redis
import tritonclient.grpc as grpcclient
import numpy as np
from minio import Minio
from tritonclient.utils import np_to_triton_dtype
from app.core.config import *
from app.schemas.generate_image import GenerateImageModel
from app.service.generate_image.utils.image_processing import remove_background, stain_detection
from app.service.generate_image.utils.upload_sd_image import upload_png_sd
# Module-level logger named after this module. Records still propagate to the root
# logger's handlers, but can now be filtered/configured per module.
logger = logging.getLogger(__name__)
class GenerateImage:
    """Run one image-generation request against the Triton model and track its status.

    The task status record is a dict with keys ``tasks_id``, ``status``, ``message`` and
    ``data``, stored as JSON in Redis under ``tasks_id``. ``status`` starts as ``PENDING``
    and ends as ``SUCCESS``, ``FAILURE`` or ``REVOKED`` (the last one is written by
    ``infer_cancel``). When ``DEBUG is False`` the final record is also published to the
    RabbitMQ queue ``GI_RABBITMQ_QUEUES``.

    NOTE(review): a plain Redis ``SET`` discards an existing TTL, so the expiry set in
    ``__init__`` is lost on the first status update and finished records never expire.
    Confirm whether that is intended.
    NOTE(review): the RabbitMQ connection opened in ``__init__`` is never closed.
    """

    IMAGE_SIZE = 1024      # height/width the input image is resized to and reshaped as
    STATUS_TTL = 600       # seconds; expiry of the initial PENDING record in Redis
    RESULT_TIMEOUT = 60.0  # seconds get_result waits (the old loop was 600 polls x 0.1 s)
    POLL_INTERVAL = 0.1    # seconds between Redis polls in get_result
    _TERMINAL_STATUSES = frozenset({"SUCCESS", "FAILURE", "REVOKED"})

    def __init__(self, request_data):
        """Open the service clients and register the task as PENDING in Redis.

        :param request_data: request object (a ``GenerateImageModel``); ``tasks_id``,
            ``prompt``, ``mode``, ``category`` and, for ``img2img``, ``image_url`` are read.
        """
        if DEBUG is False:
            self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
            self.channel = self.connection.channel()
        self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
        self.grpc_client = grpcclient.InferenceServerClient(url=GI_MODEL_URL)
        self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
        if request_data.mode == "img2img":
            self.image = self.get_image(request_data.image_url)
        else:
            # The model always receives an input image; other modes send random noise.
            self.image = self._random_image()
        self.prompt = request_data.prompt
        self.tasks_id = request_data.tasks_id
        # The user id is the suffix after the last '-' of the task id.
        self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
        self.mode = request_data.mode
        self.batch_size = 1
        self.category = request_data.category
        self.index = 0  # not read anywhere in this class
        self.generate_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'data': ''}
        self.redis_client.set(self.tasks_id, json.dumps(self.generate_data), ex=self.STATUS_TTL)

    @classmethod
    def _random_image(cls):
        """Return a random uint8 image of shape (IMAGE_SIZE, IMAGE_SIZE, 3)."""
        return np.random.randint(0, 256, (cls.IMAGE_SIZE, cls.IMAGE_SIZE, 3), dtype=np.uint8)

    def get_image(self, image_url):
        """Download an image from MinIO and resize it to IMAGE_SIZE x IMAGE_SIZE.

        :param image_url: ``"<bucket>/<object name>"``; the text before the first '/' is
            the bucket and everything after it is the object name.
        :return: uint8 array of shape (IMAGE_SIZE, IMAGE_SIZE, 3) decoded with
            ``cv2.IMREAD_COLOR``, or random noise of that shape when MinIO raises
            ``S3Error`` (for example a missing bucket or object).
        """
        bucket_name = image_url.split('/')[0]
        object_name = image_url[image_url.find('/') + 1:]
        try:
            response = self.minio_client.get_object(bucket_name, object_name)
        except minio.error.S3Error as exc:
            logger.warning("could not fetch %s from MinIO (%s); using random noise", image_url, exc)
            return self._random_image()
        try:
            payload = response.data
        finally:
            # The MinIO client requires the response to be closed and its
            # connection released back to the pool after use.
            response.close()
            response.release_conn()
        encoded = np.asarray(bytearray(payload), dtype=np.uint8)
        image_cv2 = cv2.imdecode(encoded, cv2.IMREAD_COLOR)
        return cv2.resize(image_cv2, (self.IMAGE_SIZE, self.IMAGE_SIZE))

    def _save_status(self, status, message, data):
        """Update ``self.generate_data`` in place and store it as JSON in Redis."""
        self.generate_data['status'] = status
        self.generate_data['message'] = message
        self.generate_data['data'] = data
        self.redis_client.set(self.tasks_id, json.dumps(self.generate_data))

    def _stored_status(self):
        """Return the ``status`` field currently stored in Redis, or None if the key is missing."""
        raw = self.redis_client.get(self.tasks_id)
        return None if raw is None else json.loads(raw).get('status')

    def _postprocess(self, result):
        """Turn the model output into the string stored as the task's ``data``.

        For the ``sketch`` category the background is removed and the image is checked
        for stains. A clean image is uploaded with ``upload_png_sd`` and its URL is
        returned. For a stained sketch, ``GI_SYS_IMAGE_URL`` is returned instead.
        """
        image_result = result.as_numpy("generated_image")[0]
        # Per the original inline comments, a truthy first value from stain_detection
        # means "no stain found". Non-sketch images are never checked.
        is_clean = True
        if self.category == "sketch":
            without_background = remove_background(np.asarray(image_result))
            is_clean, image_result = stain_detection(without_background)
        if not is_clean:
            return str(GI_SYS_IMAGE_URL)
        image_url = upload_png_sd(image_result, user_id=self.user_id, category=f"{self.category}", object_name=f"{self.tasks_id}.png")
        return str(image_url)

    def callback(self, result, error):
        """Triton completion callback: record the outcome of the inference in Redis.

        Called by ``tritonclient`` with either ``result`` or ``error`` set. Nothing is
        written if the stored status is already terminal. That happens when the task
        was revoked or ``get_result`` gave up. Cancelling a request still completes it,
        so this callback may fire with an error afterwards, and that error must not
        overwrite the outcome already recorded.
        """
        if self._stored_status() in self._TERMINAL_STATUSES:
            return
        if error:
            self._save_status("FAILURE", str(error), str(error))
            return
        try:
            data = self._postprocess(result)
        except Exception as exc:
            # Without this the exception escapes into tritonclient's callback machinery
            # and the task stays PENDING until get_result times out.
            logger.exception("post-processing failed for task %s", self.tasks_id)
            self._save_status("FAILURE", "failure", str(exc))
            return
        self._save_status("SUCCESS", "success", data)

    def read_tasks_status(self):
        """Return the task's status record from Redis as ``(dict, raw JSON string)``.

        Raises ``TypeError`` if the key is missing (e.g. expired), because ``json.loads``
        then receives ``None``.
        """
        status_data = self.redis_client.get(self.tasks_id)
        return json.loads(status_data), status_data

    def infer(self, inputs):
        """Submit ``inputs`` asynchronously to model ``GI_MODEL_NAME``; ``callback`` receives the outcome.

        :return: the handle from ``async_infer``; ``get_result`` calls its ``cancel()``.
        """
        return self.grpc_client.async_infer(
            model_name=GI_MODEL_NAME,
            inputs=inputs,
            callback=self.callback
        )

    def _build_inputs(self):
        """Build the ``prompt``, ``input_image`` and ``mode`` Triton inputs for a batch of ``batch_size``."""
        text_obj = np.array([self.prompt] * self.batch_size, dtype="object").reshape((-1, 1))
        mode_obj = np.array([self.mode] * self.batch_size, dtype="object").reshape((-1, 1))
        images = [self.image.astype(np.float16)] * self.batch_size
        image_obj = np.array(images, dtype=np.float16).reshape((-1, self.IMAGE_SIZE, self.IMAGE_SIZE, 3))
        input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype))
        input_image = grpcclient.InferInput("input_image", image_obj.shape, "FP16")
        input_mode = grpcclient.InferInput("mode", mode_obj.shape, np_to_triton_dtype(mode_obj.dtype))
        input_text.set_data_from_numpy(text_obj)
        input_image.set_data_from_numpy(image_obj)
        input_mode.set_data_from_numpy(mode_obj)
        return [input_text, input_image, input_mode]

    def _wait_for_result(self, ctx):
        """Poll Redis until the task reaches a terminal status or RESULT_TIMEOUT elapses.

        :param ctx: handle returned by ``infer``. Its ``cancel()`` is called when the
            task is revoked, has failed, or times out.
        :return: the last status record read from Redis.
        """
        deadline = time.monotonic() + self.RESULT_TIMEOUT
        while True:
            generate_data, _ = self.read_tasks_status()
            if generate_data['status'] in ("REVOKED", "FAILURE"):
                ctx.cancel()
                return generate_data
            if generate_data['status'] == "SUCCESS":
                return generate_data
            if time.monotonic() >= deadline:
                break
            time.sleep(self.POLL_INTERVAL)
        # Timed out. Stop the request so a late result cannot change the status after
        # it has been published, then record the timeout unless the callback already
        # stored a terminal status.
        ctx.cancel()
        generate_data, _ = self.read_tasks_status()
        if generate_data['status'] not in self._TERMINAL_STATUSES:
            self._save_status("FAILURE", "timeout", f"no result within {self.RESULT_TIMEOUT:g} seconds")
            generate_data, _ = self.read_tasks_status()
        return generate_data

    def get_result(self):
        """Submit the inference request and wait for a terminal status.

        Polls Redis every POLL_INTERVAL seconds for up to RESULT_TIMEOUT seconds. On
        REVOKED or FAILURE the in-flight request is cancelled. On timeout it is
        cancelled and the task is marked FAILURE. Whatever the outcome, the final record
        is read back from Redis, logged, and published to RabbitMQ when
        ``DEBUG is False``.

        :return: the final status record dict.
        :raises Exception: any exception raised before a record is returned is stored
            as FAILURE and then re-raised unchanged.
        """
        try:
            ctx = self.infer(self._build_inputs())
            return self._wait_for_result(ctx)
        except Exception as e:
            self._save_status("FAILURE", "failure", str(e))
            raise
        finally:
            dict_generate_data, str_generate_data = self.read_tasks_status()
            if DEBUG is False:
                self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=str_generate_data)
            logger.info(" [x] Sent %s", json.dumps(dict_generate_data, indent=4))
def infer_cancel(tasks_id):
    """Mark a task as revoked in Redis and return the record that was written.

    A running ``GenerateImage.get_result`` for the same ``tasks_id`` polls this key and
    cancels its inference request once it reads ``REVOKED``.

    :param tasks_id: Redis key of the task.
    :return: the status dict stored (as JSON) under ``tasks_id``.
    """
    client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
    revoked_record = dict(tasks_id=tasks_id, status='REVOKED', message="revoked", data='revoked')
    client.set(tasks_id, json.dumps(revoked_record))
    return revoked_record
if __name__ == '__main__':
    # Manual smoke test: submit a txt2img request and print the final status record.
    demo_request = GenerateImageModel(
        tasks_id="123-89",
        prompt='skeleton sitting by the side of a river looking soulful, concert poster, 4k, artistic',
        image_url="",
        mode='txt2img',
        category="test",
    )
    final_status = GenerateImage(demo_request).get_result()
    print(final_status)