Files
AiDA_Python/app/service/super_resolution/service.py
2024-03-21 11:12:01 +08:00

148 lines
6.0 KiB
Python

import io
import logging
import time
from io import BytesIO
import minio.error
import pika
import redis
import json
import cv2
import numpy as np
import torch
import tritonclient.http as httpclient
import tritonclient.grpc as grpcclient
from PIL import Image
from minio import Minio
from app.core.config import MINIO_IP, MINIO_ACCESS, MINIO_SECRET, MINIO_SECURE, MINIO_PORT, REDIS_HOST, REDIS_PORT, REDIS_DB, SR_MODEL_NAME, RABBITMQ_PARAMS
from app.schemas.super_resolution import SuperResolutionModel
from app.service.utils.decorator import RunTime
from app.service.utils.generate_uuid import generate_uuid
# Module-level logger named after this module; records still propagate to the
# root logger's handlers, but levels can now be tuned per module.
logger = logging.getLogger(__name__)
class SuperResolution:
    """Run one super-resolution task against a Triton inference server.

    The task's state is kept in Redis under ``data.sr_tasks_id`` as a JSON
    string with ``status``/``message``/``data`` keys (initially PENDING).
    The input image is read from MinIO and inference is submitted
    asynchronously over gRPC. :meth:`callback` uploads the result back to
    MinIO, then records and publishes the final status.
    """

    def __init__(self, data):
        """Create the service clients and register the task as PENDING in Redis.

        :param data: request object exposing ``sr_tasks_id``, ``sr_image_url``
            (``"<bucket>/<object>"``) and ``sr_xn``; presumably a
            ``SuperResolutionModel`` (as in the ``__main__`` smoke run).
        """
        # Only the gRPC client is needed: ``sr_result`` relies on the call
        # context returned by the gRPC ``async_infer`` to ``cancel()`` requests.
        # NOTE(review): the Triton address is hard-coded; consider moving it to
        # app.core.config alongside the other service endpoints.
        self.triton_client = grpcclient.InferenceServerClient(url="10.1.1.150:7001")
        self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
        self.tasks_id = data.sr_tasks_id
        self.sr_image_url = data.sr_image_url
        self.sr_xn = data.sr_xn  # NOTE(review): not used by this class.
        self.minio_client = Minio(
            f"{MINIO_IP}:{MINIO_PORT}",
            access_key=MINIO_ACCESS,
            secret_key=MINIO_SECRET,
            secure=MINIO_SECURE)
        self.redis_client.set(self.tasks_id, json.dumps({'status': 'PENDING', 'message': "pending", 'data': ''}))

    def _split_image_url(self):
        """Split ``sr_image_url`` (``"<bucket>/<object>"``) into ``(bucket, object_name)``."""
        bucket, sep, object_name = self.sr_image_url.partition("/")
        if not sep or not object_name:
            raise ValueError(f"sr_image_url must look like '<bucket>/<object>', got '{self.sr_image_url}'")
        return bucket, object_name

    @RunTime
    def read_image(self):
        """Download the task image from MinIO and decode it.

        :returns: 3-channel BGR float32 array scaled to [0, 1]
            (``cv2.IMREAD_COLOR`` decode divided by 255).
        :raises FileNotFoundError: if MinIO reports an S3 error for the object;
            an ERROR payload is then stored in Redis and published first.
        :raises ValueError: if the URL is malformed or the bytes are not a
            decodable image.
        """
        bucket, object_name = self._split_image_url()
        try:
            response = self.minio_client.get_object(bucket, object_name)
        except minio.error.S3Error as e:
            sr_data = json.dumps({'tasks_id': self.tasks_id, 'status': 'ERROR', 'message': f'{e}'})
            # Record the failure in Redis too, so status pollers do not keep
            # seeing PENDING for a task that can never run.
            self.redis_client.set(self.tasks_id, sr_data)
            publish_message(sr_data)
            raise FileNotFoundError(f"Image '{object_name}' not found in bucket '{bucket}'") from e
        try:
            raw = response.data
        finally:
            # The MinIO response holds a pooled HTTP connection; release it.
            response.close()
            response.release_conn()
        img = cv2.imdecode(np.frombuffer(raw, np.uint8), cv2.IMREAD_COLOR)
        if img is None:
            raise ValueError(f"Object '{self.sr_image_url}' could not be decoded as an image")
        return img.astype(np.float32) / 255.

    def read_tasks_status(self):
        """Return the task's status dict as currently stored in Redis.

        :raises LookupError: if the task key no longer exists in Redis.
        """
        raw = self.redis_client.get(self.tasks_id)
        if raw is None:
            raise LookupError(f"No status stored in Redis for task '{self.tasks_id}'")
        status_data = json.loads(raw)
        logger.info("%s ===> %s", self.tasks_id, status_data)
        return status_data

    @RunTime
    def infer(self, inputs):
        """Submit ``inputs`` to ``SR_MODEL_NAME`` asynchronously.

        The outcome is delivered to :meth:`callback`. The returned call context
        is used by :meth:`sr_result` to ``cancel()`` a revoked request.
        """
        return self.triton_client.async_infer(
            model_name=SR_MODEL_NAME,
            inputs=inputs,
            callback=self.callback,
        )

    @RunTime
    def sr_result(self):
        """Run the task end to end and return its latest status dict.

        Submits inference, then polls Redis once per second for up to
        ``time_out`` seconds while the status is PENDING. If the task is
        observed as REVOKED, the in-flight request is cancelled.
        """
        sample = self.read_image()
        # HWC-BGR -> CHW-RGB, then prepend a batch axis: (1, C, H, W) float32.
        if sample.shape[2] != 1:
            sample = sample[:, :, [2, 1, 0]]
        sample = np.ascontiguousarray(np.transpose(sample, (2, 0, 1))[np.newaxis], dtype=np.float32)
        inputs = [grpcclient.InferInput("input", sample.shape, datatype="FP32")]
        inputs[0].set_data_from_numpy(sample)
        ctx = self.infer(inputs)

        time_out = 120
        status = self.read_tasks_status()
        while status['status'] == "PENDING" and time_out > 0:
            time.sleep(1)
            time_out -= 1
            status = self.read_tasks_status()
        # Checked after the loop: a REVOKED status ends the loop, so testing it
        # inside the PENDING-only loop body would (almost) never fire.
        if status['status'] == "REVOKED":
            ctx.cancel()
        # NOTE(review): on timeout the request is left running and the status
        # stays PENDING; confirm whether it should be cancelled as well.
        return status

    def upload_img_sr(self, image, object_name):
        """JPEG-encode ``image`` and upload it to the ``test`` bucket.

        :returns: ``"test/<object_name>.jpg"`` on success, or ``None`` if
            encoding or upload fails (the failure is logged).
        """
        try:
            ok, encoded = cv2.imencode('.jpg', image)
            if not ok:
                raise ValueError("cv2.imencode failed to encode the image as JPEG")
            image_bytes = encoded.tobytes()
            # Content type must match the '.jpg' encoding above.
            result = self.minio_client.put_object(
                'test', f'{object_name}.jpg', io.BytesIO(image_bytes), len(image_bytes), content_type='image/jpeg')
            return f"test/{result.object_name}"
        except Exception as e:
            logger.warning(f"upload_img_sr runtime exception : {e}")
            return None

    def _report_failure(self, error):
        """Store a FAILURE payload for this task in Redis and publish it."""
        sr_info_data = json.dumps(
            {'tasks_id': self.tasks_id, 'status': 'FAILURE', 'message': f"{error}", 'data': f"{error}"})
        self.redis_client.set(self.tasks_id, sr_info_data)
        publish_message(sr_info_data)

    def callback(self, result, error):
        """Handle completion of the asynchronous Triton request.

        On success the first batch element of ``output`` is clamped to [0, 1],
        converted CHW-RGB -> HWC-BGR uint8, and uploaded. A SUCCESS payload
        with the image path is then stored in Redis and published. If Triton
        reports an error, or post-processing/upload fails, a FAILURE payload is
        stored and published instead, so the task does not stay PENDING.
        """
        if error:
            logger.error("Triton inference failed for task %s: %s", self.tasks_id, error)
            self._report_failure(error)
            return
        try:
            output = np.squeeze(result.as_numpy("output")[0]).astype(np.float32)
            output = np.clip(output, 0, 1)
            if output.ndim == 3:
                output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))  # CHW-RGB to HWC-BGR
            output = (output * 255.0).round().astype(np.uint8)
            output_url = self.upload_img_sr(output, generate_uuid())
            if output_url is None:
                raise RuntimeError("failed to upload super-resolution result")
        except Exception as e:
            logger.exception("Post-processing failed for task %s", self.tasks_id)
            self._report_failure(e)
            return
        sr_data = json.dumps({'tasks_id': self.tasks_id, 'status': 'SUCCESS', 'message': 'success', 'data': f'{output_url}'})
        # Update Redis before publishing so a broker failure cannot leave the
        # task PENDING.
        self.redis_client.set(self.tasks_id, sr_data)
        publish_message(sr_data)
def infer_cancel(tasks_id):
    """Mark a super-resolution task as REVOKED and announce it.

    The REVOKED payload is written to Redis under ``tasks_id`` (the key the
    service polls for task status) and then published via ``publish_message``.

    :param tasks_id: id / Redis key of the task to revoke.
    :returns: summary dict ``{'tasks': tasks_id, 'status': 'REVOKED',
        'message': 'revoked', 'data': 'revoked'}``.
    """
    redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
    data = {'tasks': tasks_id, 'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'}
    # Carry the task id so queue consumers can tell which task was revoked
    # (the SUCCESS/ERROR payloads carry it under the same key).
    sr_data = json.dumps({'tasks_id': tasks_id, 'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'})
    # Write Redis first: that is what the running task observes, and a broker
    # failure while publishing must not stop the revocation from taking effect.
    redis_client.set(tasks_id, sr_data)
    publish_message(sr_data)
    return data
def publish_message(sr_data, routing_key='SuperResolution-local'):
    """Publish ``sr_data`` to RabbitMQ through the default exchange.

    A new blocking connection is opened for each call and is always closed,
    even if publishing fails.

    :param sr_data: message body (callers in this module pass JSON strings).
    :param routing_key: destination queue name; defaults to
        ``'SuperResolution-local'``.
    """
    connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
    try:
        channel = connection.channel()
        # The default exchange ('') routes directly to the queue named by
        # routing_key. NOTE(review): the queue is not declared here; presumably
        # the consumer declares it — confirm, otherwise messages are unroutable.
        channel.basic_publish(exchange='', routing_key=routing_key, body=sr_data)
        logger.info(" [x] Sent %s", sr_data)
    finally:
        # Guard so a close on an already-dropped connection does not mask the
        # original publishing error.
        if connection.is_open:
            connection.close()
if __name__ == '__main__':
    # Manual smoke run: needs the MinIO, Redis, RabbitMQ and Triton services
    # used by this module to be reachable.
    request_data = SuperResolutionModel(sr_image_url="test/512_image/15.png", sr_xn=2, sr_tasks_id="123")
    service = SuperResolution(request_data)
    # sr_result() returns the task's final status dict (not a URL); show it.
    task_status = service.sr_result()
    print(task_status)