import io
import logging
import time
from io import BytesIO
import minio.error
import pika
import redis
import json
import cv2
import numpy as np
import torch
import tritonclient.http as httpclient
import tritonclient.grpc as grpcclient
from PIL import Image
from minio import Minio
from app.core.config import MINIO_IP, MINIO_ACCESS, MINIO_SECRET, MINIO_SECURE, MINIO_PORT, REDIS_HOST, REDIS_PORT, REDIS_DB, SR_MODEL_NAME, RABBITMQ_PARAMS, RABBITMQ_QUEUES
from app.schemas.super_resolution import SuperResolutionModel
from app.service.utils.decorator import RunTime
from app.service.utils.generate_uuid import generate_uuid

logger = logging.getLogger(__name__)

# TODO(review): Triton endpoint is hard-coded here; it should come from
# app.core.config like the MinIO/Redis/RabbitMQ settings do.
TRITON_GRPC_URL = "10.1.1.150:7001"


class SuperResolution:
    """Asynchronous super-resolution task runner.

    Reads the source image from MinIO, submits it to a Triton inference
    server over gRPC, and reports progress/results through Redis (task
    status) and RabbitMQ (completion messages).
    """

    def __init__(self, data):
        """Initialize clients and mark the task as PENDING in Redis.

        Args:
            data: request model with ``sr_tasks_id``, ``sr_image_url``
                (``"<bucket>/<object>"``) and ``sr_xn`` attributes.
        """
        # NOTE: the original code also built an HTTP Triton client and
        # immediately overwrote it; only the gRPC client is ever used
        # (infer() / grpcclient.InferInput below), so just build that one.
        self.triton_client = grpcclient.InferenceServerClient(url=TRITON_GRPC_URL)
        self.redis_client = redis.StrictRedis(
            host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
        self.tasks_id = data.sr_tasks_id
        self.sr_image_url = data.sr_image_url
        self.sr_xn = data.sr_xn  # requested upscale factor (not used locally; model-side)
        self.minio_client = Minio(
            f"{MINIO_IP}:{MINIO_PORT}",
            access_key=MINIO_ACCESS,
            secret_key=MINIO_SECRET,
            secure=MINIO_SECURE)
        self.redis_client.set(
            self.tasks_id,
            json.dumps({'status': 'PENDING', 'message': "pending", 'data': ''}))

    @RunTime
    def read_image(self):
        """Download and decode the source image.

        Returns:
            np.ndarray: float32 BGR image scaled to [0, 1].

        Raises:
            FileNotFoundError: if the object is missing from the bucket
                (an ERROR message is published to RabbitMQ first).
            ValueError: if the downloaded bytes cannot be decoded as an image.
        """
        bucket, object_name = self.sr_image_url.split("/", 1)
        try:
            response = self.minio_client.get_object(bucket, object_name)
        except minio.error.S3Error as e:
            sr_data = json.dumps(
                {'tasks_id': self.tasks_id, 'status': 'ERROR', 'message': f'{e}'})
            publish_message(sr_data)
            raise FileNotFoundError(
                f"Image '{object_name}' not found in bucket '{bucket}'") from e
        try:
            raw = np.frombuffer(response.data, np.uint8)  # raw bytes -> uint8 array
        finally:
            # MinIO responses hold a pooled HTTP connection; release it explicitly
            # (the original code leaked it).
            response.close()
            response.release_conn()
        img = cv2.imdecode(raw, cv2.IMREAD_COLOR)  # decode to HWC-BGR
        if img is None:
            # imdecode returns None on corrupt data; fail loudly instead of
            # crashing later with an AttributeError.
            raise ValueError(f"Could not decode image '{object_name}'")
        return img.astype(np.float32) / 255.

    def read_tasks_status(self):
        """Return the task's current status dict from Redis."""
        status_data = json.loads(self.redis_client.get(self.tasks_id))
        logger.info(f"{self.tasks_id} ===> {status_data}")
        return status_data

    @RunTime
    def infer(self, inputs):
        """Submit an async inference request; results arrive via callback()."""
        return self.triton_client.async_infer(
            model_name=SR_MODEL_NAME,
            inputs=inputs,
            callback=self.callback
        )

    @RunTime
    def sr_result(self):
        """Run the full task: preprocess, infer, poll for completion.

        Polls Redis for up to 120 seconds; if the task is revoked while
        pending, the in-flight inference request is cancelled.

        Returns:
            dict: the final task status stored in Redis.
        """
        sample = self.read_image()
        # BGR -> RGB for 3-channel images, then HWC -> CHW and add a batch dim.
        sample = np.transpose(
            sample if sample.shape[2] == 1 else sample[:, :, [2, 1, 0]], (2, 0, 1))
        sample = np.expand_dims(sample, 0)  # already float32 from read_image()

        inputs = [grpcclient.InferInput("input", sample.shape, datatype="FP32")]
        inputs[0].set_data_from_numpy(sample)

        ctx = self.infer(inputs)
        remaining = 120  # seconds to wait before giving up
        while remaining > 0:
            # Read the status once per iteration. The original read it twice
            # per iteration and its REVOKED branch was unreachable because the
            # loop condition already required status == "PENDING".
            status = self.read_tasks_status()['status']
            if status == "REVOKED":
                ctx.cancel()  # task cancelled externally via infer_cancel()
                break
            if status != "PENDING":
                break  # callback already recorded SUCCESS or FAILURE
            remaining -= 1
            time.sleep(1)
        return self.read_tasks_status()

    def upload_img_sr(self, image, object_name):
        """Encode `image` as JPEG and upload it to the 'test' bucket.

        Returns:
            str | None: ``"test/<object_name>.jpg"`` on success, None on failure
                (the error is logged, not raised — best-effort by design).
        """
        try:
            ok, encoded = cv2.imencode('.jpg', image)
            if not ok:
                raise ValueError("cv2.imencode('.jpg', ...) failed")
            image_bytes = encoded.tobytes()
            # content_type matches the JPEG encoding (was wrongly 'image/png').
            result = self.minio_client.put_object(
                'test', f'{object_name}.jpg', io.BytesIO(image_bytes),
                len(image_bytes), content_type='image/jpeg')
            return f"test/{result.object_name}"
        except Exception as e:
            # Fixed copy-pasted message that said 'upload_png_mask'.
            logger.warning(f"upload_img_sr runtime exception : {e}")

    def callback(self, result, error):
        """Triton async-infer completion hook.

        On success: postprocess the output tensor, upload the JPEG, publish a
        SUCCESS message and store it in Redis. On error: store FAILURE.
        """
        if error:
            logger.error(f"inference failed: {error}")  # was print()
            sr_info_data = json.dumps(
                {'status': 'FAILURE', 'message': f"{error}", 'data': f"{error}"})
            self.redis_client.set(self.tasks_id, sr_info_data)
            return
        sr_output = torch.tensor(result.as_numpy("output")[0])
        output = sr_output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
        if output.ndim == 3:
            output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))  # CHW-RGB to HWC-BGR
        output = (output * 255.0).round().astype(np.uint8)
        output_url = self.upload_img_sr(output, generate_uuid())
        sr_data = json.dumps(
            {'tasks_id': self.tasks_id, 'status': 'SUCCESS',
             'message': 'success', 'data': f'{output_url}'})
        publish_message(sr_data)
        self.redis_client.set(self.tasks_id, sr_data)


def infer_cancel(tasks_id):
    """Mark task `tasks_id` as revoked: publish a REVOKED message and store it.

    The running sr_result() poll loop observes the Redis status and cancels
    the in-flight Triton request.
    """
    redis_client = redis.StrictRedis(
        host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
    data = {'tasks': tasks_id, 'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'}
    # Include tasks_id in the published payload for consistency with the
    # ERROR/SUCCESS messages (the original omitted it here only).
    sr_data = json.dumps(
        {'tasks_id': tasks_id, 'status': 'REVOKED',
         'message': "revoked", 'data': 'revoked'})
    publish_message(sr_data)
    redis_client.set(tasks_id, sr_data)
    return data


def publish_message(sr_data):
    """Publish `sr_data` to the configured RabbitMQ queue on a fresh connection."""
    connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
    try:
        channel = connection.channel()
        channel.basic_publish(exchange='', routing_key=RABBITMQ_QUEUES, body=sr_data)
        logger.info(f" [x] Sent {sr_data}")
    finally:
        # Always close, even if basic_publish raises (original leaked here).
        connection.close()


if __name__ == '__main__':
    request_data = SuperResolutionModel(
        sr_image_url="test/512_image/15.png", sr_xn=2, sr_tasks_id="123")
    service = SuperResolution(request_data)
    result_url = service.sr_result()