# Source: AiDA_Python/app/service/super_resolution/service.py
# (148 lines, 6.0 KiB, Python — recovered from a repository file-viewer export)
import io
import json
import logging
import time
from io import BytesIO

import cv2
import minio.error
import numpy as np
import pika
import redis
import torch
import tritonclient.grpc as grpcclient
import tritonclient.http as httpclient
from minio import Minio
from PIL import Image

from app.core.config import MINIO_IP, MINIO_ACCESS, MINIO_SECRET, MINIO_SECURE, MINIO_PORT, REDIS_HOST, REDIS_PORT, REDIS_DB, SR_MODEL_NAME, RABBITMQ_PARAMS, RABBITMQ_QUEUES
from app.schemas.super_resolution import SuperResolutionModel
from app.service.utils.decorator import RunTime
from app.service.utils.generate_uuid import generate_uuid
logger = logging.getLogger()
class SuperResolution:
    """One super-resolution task: fetch an image from MinIO, run async
    inference on a Triton server, and report status via Redis and RabbitMQ.

    Status lifecycle (stored as JSON under ``tasks_id`` in Redis):
    PENDING -> SUCCESS | FAILURE, or REVOKED when cancelled externally
    via :func:`infer_cancel`.
    """

    def __init__(self, data):
        """Create the backend clients and mark the task PENDING in Redis.

        Args:
            data: request object exposing ``sr_tasks_id``, ``sr_image_url``
                ("<bucket>/<object-path>") and ``sr_xn`` (scale factor) —
                see ``SuperResolutionModel``.
        """
        # Triton gRPC endpoint. The original also built an HTTP client here
        # and immediately overwrote it — dead code, removed.
        # NOTE(review): address is hard-coded; consider moving it to
        # app.core.config alongside the other endpoints.
        self.triton_client = grpcclient.InferenceServerClient(url="10.1.1.150:7001")
        self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
        self.tasks_id = data.sr_tasks_id
        self.sr_image_url = data.sr_image_url
        self.sr_xn = data.sr_xn
        self.minio_client = Minio(
            f"{MINIO_IP}:{MINIO_PORT}",
            access_key=MINIO_ACCESS,
            secret_key=MINIO_SECRET,
            secure=MINIO_SECURE)
        # Publish the initial state so pollers see PENDING immediately.
        self.redis_client.set(self.tasks_id, json.dumps({'status': 'PENDING', 'message': "pending", 'data': ''}))

    @RunTime
    def read_image(self):
        """Fetch the source image from MinIO and decode it.

        Returns:
            np.ndarray: float32 BGR image scaled to [0, 1].

        Raises:
            FileNotFoundError: when the object is missing; an ERROR message
                is published to RabbitMQ before raising.
        """
        # sr_image_url is "<bucket>/<object-path>"; split on the first slash only.
        bucket, object_name = self.sr_image_url.split("/", 1)
        try:
            image_data = self.minio_client.get_object(bucket, object_name)
        except minio.error.S3Error as e:
            sr_data = json.dumps({'tasks_id': self.tasks_id, 'status': 'ERROR', 'message': f'{e}'})
            publish_message(sr_data)
            raise FileNotFoundError(f"Image '{object_name}' not found in bucket '{bucket}'") from e
        img = np.frombuffer(image_data.data, np.uint8)  # raw bytes -> uint8 buffer
        img = cv2.imdecode(img, cv2.IMREAD_COLOR).astype(np.float32) / 255.  # decode + normalise
        return img

    def read_tasks_status(self):
        """Return the current status dict for this task from Redis."""
        status_data = json.loads(self.redis_client.get(self.tasks_id))
        # Use the module logger (original called the root `logging.info`).
        logger.info(f"{self.tasks_id} ===> {status_data}")
        return status_data

    @RunTime
    def infer(self, inputs):
        """Submit an async inference request; the result lands in callback()."""
        return self.triton_client.async_infer(
            model_name=SR_MODEL_NAME,
            inputs=inputs,
            callback=self.callback
        )

    @RunTime
    def sr_result(self):
        """Run the pipeline: read image, infer asynchronously, poll for the outcome.

        Polls Redis once per second for up to 120 s; cancels the in-flight
        request if the task is revoked.

        Returns:
            dict: the final status payload stored in Redis.
        """
        sample = self.read_image()
        # HWC-BGR -> CHW-RGB; single-channel images pass through unchanged.
        sample = np.transpose(sample if sample.shape[2] == 1 else sample[:, :, [2, 1, 0]], (2, 0, 1))
        sample = torch.from_numpy(sample).float().unsqueeze(0).numpy()  # add batch dim
        inputs = [
            grpcclient.InferInput("input", sample.shape, datatype="FP32")
        ]
        inputs[0].set_data_from_numpy(sample)
        ctx = self.infer(inputs)
        time_out = 120
        while time_out > 0:
            # One Redis read per tick (the original read twice per iteration).
            status = self.read_tasks_status()['status']
            if status == "REVOKED":
                # BUG FIX: the original tested REVOKED inside a
                # `while status == "PENDING"` loop, so cancellation never ran.
                ctx.cancel()
                break
            if status != "PENDING":  # SUCCESS / FAILURE set by callback()
                break
            time_out -= 1
            time.sleep(1)
        return self.read_tasks_status()

    def upload_img_sr(self, image, object_name):
        """JPEG-encode ``image`` and upload it to the 'test' bucket (best effort).

        Returns:
            str | None: "<bucket>/<object>" URL, or None if the upload failed.
        """
        try:
            image_bytes = cv2.imencode('.jpg', image)[1].tobytes()
            result = self.minio_client.put_object(
                'test', f'{object_name}.jpg', io.BytesIO(image_bytes),
                len(image_bytes),
                content_type='image/jpeg')  # was 'image/png', but the payload is JPEG
            return f"test/{result.object_name}"
        except Exception as e:
            # Best-effort: log and return None rather than failing the task here.
            logger.warning(f"upload_img_sr runtime exception : {e}")

    def callback(self, result, error):
        """Triton async-infer callback: persist the outcome and notify consumers."""
        if error:
            logger.error(error)  # original used print()
            sr_info_data = json.dumps({'status': 'FAILURE', 'message': f"{error}", 'data': f"{error}"})
            self.redis_client.set(self.tasks_id, sr_info_data)
        else:
            sr_output = torch.tensor(result.as_numpy("output")[0])
            output = sr_output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
            if output.ndim == 3:
                output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))  # CHW-RGB to HWC-BGR
            output = (output * 255.0).round().astype(np.uint8)
            output_url = self.upload_img_sr(output, generate_uuid())
            sr_data = json.dumps({'tasks_id': self.tasks_id, 'status': 'SUCCESS', 'message': 'success', 'data': f'{output_url}'})
            publish_message(sr_data)
            self.redis_client.set(self.tasks_id, sr_data)
def infer_cancel(tasks_id):
    """Mark a task as REVOKED in Redis and broadcast the revocation.

    ``SuperResolution.sr_result`` polls Redis and cancels the in-flight
    Triton request once it observes the REVOKED status.

    Args:
        tasks_id: identifier of the task to revoke.

    Returns:
        dict: summary of the revocation for the caller.
    """
    redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
    data = {'tasks': tasks_id, 'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'}
    # Include tasks_id so consumers can tell which task was revoked —
    # consistent with the SUCCESS/ERROR messages published elsewhere in
    # this module (the original omitted it).
    sr_data = json.dumps({'tasks_id': tasks_id, 'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'})
    publish_message(sr_data)
    redis_client.set(tasks_id, sr_data)
    return data
def publish_message(sr_data):
    """Publish a task-status message to the configured RabbitMQ queue.

    A fresh connection is opened per call and always closed, even when
    publishing raises (the original leaked the connection on error).

    Args:
        sr_data: JSON string payload.
    """
    connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
    try:
        channel = connection.channel()
        # Default ('') exchange routes directly to the queue named by routing_key.
        channel.basic_publish(exchange='', routing_key=RABBITMQ_QUEUES, body=sr_data)
        logger.info(f" [x] Sent {sr_data}")
    finally:
        connection.close()
if __name__ == '__main__':
    # Ad-hoc smoke test: run a single super-resolution task end to end.
    # Note: sr_result() returns the final status dict stored in Redis.
    request_data = SuperResolutionModel(sr_image_url="test/512_image/15.png", sr_xn=2, sr_tasks_id="123")
    final_status = SuperResolution(request_data).sr_result()