Files
AiDA_Python/app/service/super_resolution/service.py
zhouchengrong 69132570aa 1
2024-03-20 11:44:15 +08:00

68 lines
2.5 KiB
Python

import io
import logging
from io import BytesIO
import cv2
import numpy as np
import torch
import tritonclient.http as httpclient
from minio import Minio
from app.core.config import MINIO_IP, MINIO_ACCESS, MINIO_SECRET, MINIO_SECURE, MINIO_PORT
from app.service.utils.decorator import RunTime
from app.service.utils.generate_uuid import generate_uuid
# Module-level logger named after this module (instead of the root logger), so
# records can be filtered/configured per module while still propagating to root.
logger = logging.getLogger(__name__)
class SuperResolution:
    """Super-resolve images stored in MinIO using a model served by Triton.

    The input image is downloaded from MinIO and sent to a Triton Inference
    Server model. The model output is JPEG-encoded and uploaded back to MinIO.
    """

    def __init__(self, triton_url="10.1.1.150:7000", model_name="super_resolution"):
        """Create the Triton HTTP client and the MinIO client.

        Args:
            triton_url: ``host:port`` of the Triton HTTP endpoint. The default
                is the address that was previously hard-coded here.
            model_name: Name of the Triton model invoked by ``sr_result``.
        """
        self.triton_client = httpclient.InferenceServerClient(url=triton_url)
        self.model_name = model_name
        self.minio_client = Minio(
            f"{MINIO_IP}:{MINIO_PORT}",
            access_key=MINIO_ACCESS,
            secret_key=MINIO_SECRET,
            secure=MINIO_SECURE)

    @RunTime
    def read_image(self, image_url):
        """Download an image from MinIO and decode it.

        Args:
            image_url: ``"<bucket>/<object name>"``. Only the first ``/``
                separates the bucket from the object name.

        Returns:
            float32 array of shape (H, W, 3), BGR channel order, scaled to
            [0, 1]. ``cv2.IMREAD_COLOR`` always yields three channels.

        Raises:
            ValueError: if ``image_url`` has no ``/`` separator, or the
                downloaded bytes cannot be decoded as an image.
        """
        bucket, sep, object_name = image_url.partition("/")
        if not sep:
            raise ValueError(
                f"image_url must look like '<bucket>/<object>', got {image_url!r}")

        response = self.minio_client.get_object(bucket, object_name)
        try:
            data = response.data
        finally:
            # The MinIO SDK requires closing the response and releasing its
            # connection back to the pool; otherwise each call leaks one.
            response.close()
            response.release_conn()

        buffer = np.frombuffer(data, np.uint8)  # view the raw bytes as uint8
        img = cv2.imdecode(buffer, cv2.IMREAD_COLOR)  # decode; None on failure
        if img is None:
            raise ValueError(f"could not decode image data from {image_url!r}")
        return img.astype(np.float32) / 255.

    @RunTime
    def sr_result(self, image_url, sr_xn):
        """Super-resolve the image at ``image_url`` and upload the result.

        Args:
            image_url: ``"<bucket>/<object name>"`` of the source image.
            sr_xn: Requested upscale factor. It is currently NOT used by this
                method; presumably the scale is fixed by the deployed Triton
                model (TODO confirm).

        Returns:
            ``"<bucket>/<object name>"`` of the uploaded JPEG, or ``None`` if
            the upload failed (see ``upload_img_sr``).

        Raises:
            RuntimeError: if the Triton response lacks the ``output`` tensor.
        """
        bgr = self.read_image(image_url)
        # BGR -> RGB (single-channel input is passed through), HWC -> CHW, then
        # add the batch axis. Triton serializes the array with tobytes().
        rgb = bgr if bgr.shape[2] == 1 else bgr[:, :, [2, 1, 0]]
        batch = np.transpose(rgb, (2, 0, 1))[np.newaxis].astype(np.float32, copy=False)

        infer_input = httpclient.InferInput("input", list(batch.shape), datatype="FP32")
        infer_input.set_data_from_numpy(batch, binary_data=True)
        results = self.triton_client.infer(model_name=self.model_name, inputs=[infer_input])

        raw = results.as_numpy("output")
        if raw is None:
            raise RuntimeError(
                f"Triton model {self.model_name!r} returned no 'output' tensor")
        # as_numpy returns a read-only view over the HTTP response buffer, so
        # clamp out-of-place (np.clip allocates) rather than in place.
        output = np.clip(np.squeeze(raw).astype(np.float32, copy=False), 0.0, 1.0)
        if output.ndim == 3:
            output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))  # CHW-RGB -> HWC-BGR
        output = np.round(output * 255.0).astype(np.uint8)
        return self.upload_img_sr(output, generate_uuid())

    def upload_img_sr(self, image, object_name, bucket="test"):
        """JPEG-encode ``image`` and upload it to MinIO as ``<object_name>.jpg``.

        This is best-effort: any failure is logged and ``None`` is returned
        instead of raising.

        Args:
            image: Image array accepted by ``cv2.imencode`` (uint8 here).
            object_name: Object name without extension.
            bucket: Destination bucket. Defaults to ``"test"``, as before.

        Returns:
            ``"<bucket>/<object name>"`` on success, otherwise ``None``.
        """
        try:
            ok, encoded = cv2.imencode('.jpg', image)
            if not ok:
                logger.warning("upload_img_sr: cv2.imencode failed to encode image as JPEG")
                return None
            payload = encoded.tobytes()
            result = self.minio_client.put_object(
                bucket,
                f"{object_name}.jpg",
                io.BytesIO(payload),
                len(payload),
                content_type='image/jpeg')  # matches the .jpg encoding above
            return f"{bucket}/{result.object_name}"
        except Exception as e:
            logger.warning(f"upload_img_sr runtime exception : {e}", exc_info=True)
            return None
if __name__ == '__main__':
    # Manual smoke run: super-resolve one sample image stored in MinIO and
    # print the object path of the uploaded result.
    sr_service = SuperResolution()
    print(sr_service.sr_result("test/128_image/11.png", 4))