import io
import logging
from io import BytesIO

import cv2
import numpy as np
import torch
import tritonclient.http as httpclient
from minio import Minio

from app.core.config import MINIO_IP, MINIO_ACCESS, MINIO_SECRET, MINIO_SECURE, MINIO_PORT
from app.service.utils.decorator import RunTime
from app.service.utils.generate_uuid import generate_uuid

logger = logging.getLogger(__name__)

# Defaults preserve the values that used to be hard-coded; override them
# through the SuperResolution constructor.
DEFAULT_TRITON_URL = "10.1.1.150:7000"
DEFAULT_MODEL_NAME = "super_resolution"
DEFAULT_OUTPUT_BUCKET = "test"


class SuperResolution:
    """Run a Triton-served super-resolution model on images stored in MinIO.

    The flow is: download an image from MinIO, convert it to a float32 NCHW
    RGB tensor in [0, 1], send it to the Triton model, convert the model
    output back to an 8-bit BGR image, and upload it to MinIO as a JPEG.
    """

    def __init__(self, triton_url=DEFAULT_TRITON_URL, model_name=DEFAULT_MODEL_NAME,
                 output_bucket=DEFAULT_OUTPUT_BUCKET):
        """Create the Triton HTTP client and the MinIO client.

        Args:
            triton_url: ``host:port`` of the Triton HTTP endpoint.
            model_name: Name of the Triton model to run.
            output_bucket: MinIO bucket that receives the upscaled images.
        """
        self.triton_client = httpclient.InferenceServerClient(url=triton_url)
        self.model_name = model_name
        self.output_bucket = output_bucket
        self.minio_client = Minio(
            f"{MINIO_IP}:{MINIO_PORT}",
            access_key=MINIO_ACCESS,
            secret_key=MINIO_SECRET,
            secure=MINIO_SECURE,
        )

    @staticmethod
    def _split_object_path(image_url):
        """Split ``"<bucket>/<object name>"`` into ``(bucket, object_name)``.

        Raises:
            ValueError: if either part is missing.
        """
        bucket, sep, object_name = image_url.partition("/")
        if not sep or not bucket or not object_name:
            raise ValueError(f"expected '<bucket>/<object>', got {image_url!r}")
        return bucket, object_name

    @staticmethod
    def _preprocess(img):
        """Turn an HxWxC BGR float image into a contiguous 1xCxHxW float32 RGB tensor."""
        if img.shape[2] != 1:
            img = img[:, :, [2, 1, 0]]  # BGR (OpenCV order) -> RGB
        chw = np.transpose(img, (2, 0, 1))
        # Add the batch axis. Make the buffer C-contiguous so the raw bytes
        # sent to Triton are in the declared NCHW order.
        return np.ascontiguousarray(chw[np.newaxis], dtype=np.float32)

    @staticmethod
    def _postprocess(raw):
        """Turn the raw model output into an 8-bit image that OpenCV can encode.

        A 3-D (CHW, RGB) result becomes HWC BGR. A 2-D result is treated as
        single-channel. Values are clipped to [0, 1] before scaling to 0..255.
        """
        out = np.clip(np.squeeze(raw).astype(np.float32), 0.0, 1.0)
        if out.ndim == 3:
            out = np.transpose(out[[2, 1, 0], :, :], (1, 2, 0))  # CHW-RGB -> HWC-BGR
        return np.ascontiguousarray((out * 255.0).round().astype(np.uint8))

    @RunTime
    def read_image(self, image_url):
        """Download ``"<bucket>/<object name>"`` from MinIO and decode it.

        Returns:
            An HxWx3 float32 array in OpenCV BGR channel order, scaled to [0, 1].

        Raises:
            ValueError: if the path is malformed or the bytes cannot be decoded.
        """
        bucket, object_name = self._split_object_path(image_url)
        response = self.minio_client.get_object(bucket, object_name)
        try:
            raw = response.data
        finally:
            # The MinIO SDK requires closing the response and releasing the
            # connection back to the pool; otherwise connections leak.
            response.close()
            response.release_conn()
        buf = np.frombuffer(raw, np.uint8)  # view the raw bytes as 8-bit unsigned ints
        img = cv2.imdecode(buf, cv2.IMREAD_COLOR)  # decode
        if img is None:
            raise ValueError(f"could not decode image at {image_url!r}")
        return img.astype(np.float32) / 255.

    @RunTime
    def sr_result(self, image_url, sr_xn):
        """Upscale the image at ``image_url`` and return the uploaded result's path.

        Args:
            image_url: Source image as ``"<bucket>/<object name>"``.
            sr_xn: Requested upscale factor. NOTE(review): currently unused —
                the scale appears to be fixed by the deployed Triton model;
                confirm, then either validate it or drop it.

        Returns:
            ``"<output bucket>/<uuid>.jpg"`` on success. Returns ``None`` if
            the upload fails (see ``upload_img_sr``).

        Raises:
            RuntimeError: if the model response has no ``output`` tensor.
        """
        sample = self._preprocess(self.read_image(image_url))

        inputs = [httpclient.InferInput("input", sample.shape, datatype="FP32")]
        inputs[0].set_data_from_numpy(sample, binary_data=True)
        results = self.triton_client.infer(model_name=self.model_name, inputs=inputs)

        raw = results.as_numpy("output")
        if raw is None:
            raise RuntimeError(f"model {self.model_name!r} returned no 'output' tensor")

        output = self._postprocess(raw)
        return self.upload_img_sr(output, generate_uuid())

    def upload_img_sr(self, image, object_name):
        """Encode ``image`` as JPEG and upload it as ``<object_name>.jpg``.

        Returns:
            ``"<output bucket>/<object name>.jpg"`` on success. Returns
            ``None`` on any failure: the error is logged with its traceback
            rather than raised, so the upload stays best-effort for callers.
        """
        try:
            ok, encoded = cv2.imencode('.jpg', image)
            if not ok:
                raise ValueError("cv2.imencode failed to encode image as JPEG")
            image_bytes = encoded.tobytes()
            result = self.minio_client.put_object(
                self.output_bucket,
                f"{object_name}.jpg",
                io.BytesIO(image_bytes),
                len(image_bytes),
                content_type="image/jpeg",  # must match the '.jpg' encoding above
            )
            return f"{self.output_bucket}/{result.object_name}"
        except Exception as e:
            logger.warning("upload_img_sr runtime exception : %s", e, exc_info=True)
            return None


if __name__ == '__main__':
    service = SuperResolution()
    result_url = service.sr_result("test/128_image/11.png", 4)
    print(result_url)