"""Super-resolution service.

Reads an image from MinIO, submits it to a Triton model asynchronously,
tracks the task state in Redis and reports the outcome over RabbitMQ.
"""
import io
import json
import logging
import time

import cv2
import minio.error
import numpy as np
import redis
import torch
import tritonclient.grpc as grpcclient
from minio import Minio

# NOTE(review): `pika` is used below but is not imported explicitly in this
# file; presumably it is re-exported by the star import from app.core.config.
# Confirm, and prefer an explicit `import pika` if it is not.
from app.core.config import *
from app.schemas.super_resolution import SuperResolutionModel
from app.service.utils.decorator import RunTime

logger = logging.getLogger()

# Lifetime of a task-status key in Redis, in seconds.
_STATUS_TTL_SECONDS = 600
# Number of one-second polls `sr_result` performs before giving up waiting.
_POLL_TIMEOUT_SECONDS = 60


class SuperResolution:
    """Run one super-resolution task end to end.

    Task state is stored in Redis under the task id as a JSON document with
    the keys ``status``, ``message`` and ``data``; results and errors are also
    published to the ``SR_RABBITMQ_QUEUES`` queue.
    """

    def __init__(self, data):
        """Create the service clients and register the task as PENDING.

        Args:
            data: request object exposing ``sr_tasks_id``, ``sr_image_url``
                and ``sr_xn``.
        """
        self.triton_client = grpcclient.InferenceServerClient(url=SR_TRITON_URL)
        self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB,
                                              decode_responses=True)
        self.tasks_id = data.sr_tasks_id
        # Everything after the last '-' of the task id; the whole id when it
        # contains no '-' (rfind returns -1, so the slice starts at 0).
        self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
        self.sr_image_url = data.sr_image_url
        self.sr_xn = data.sr_xn
        self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET,
                                  secure=MINIO_SECURE)
        # Set value and expiry in a single command so the key can never be
        # left behind without a TTL.
        self._store_status(json.dumps({'status': 'PENDING', 'message': "pending", 'data': ''}))
        self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
        self.channel = self.connection.channel()

    def close(self):
        """Release the RabbitMQ connection and the Triton client (best effort)."""
        for resource in (self.connection, self.triton_client):
            try:
                resource.close()
            except Exception as exc:  # cleanup must not mask the real result
                logger.warning(f"failed to close {type(resource).__name__}: {exc}")

    def _store_status(self, payload):
        """Write the JSON status ``payload`` for this task, with the standard TTL."""
        self.redis_client.set(self.tasks_id, payload, ex=_STATUS_TTL_SECONDS)

    def _publish(self, body):
        """Publish ``body`` (a JSON string) to the result queue and log it."""
        # noinspection PyTypeChecker
        self.channel.basic_publish(exchange='', routing_key=SR_RABBITMQ_QUEUES, body=body)
        logger.info(f" [x] Sent {body}")

    def _store_failure(self, reason):
        """Record a FAILURE status for this task in Redis."""
        self._store_status(json.dumps({'tasks_id': self.tasks_id, 'status': 'FAILURE',
                                       'message': f"{reason}", 'data': f"{reason}"}))

    # @RunTime  (timing decorator, currently disabled)
    def read_image(self):
        """Download and decode the source image.

        ``sr_image_url`` is interpreted as ``"<bucket>/<object path>"``.

        Returns:
            The decoded 3-channel image as float32, scaled to [0, 1].

        Raises:
            ValueError: if the URL has no bucket/object parts or the bytes
                cannot be decoded as an image.
            FileNotFoundError: if MinIO reports an S3 error; an ERROR message
                is published to RabbitMQ first.
        """
        bucket, _, object_name = self.sr_image_url.partition("/")
        if not (bucket and object_name):
            raise ValueError(f"Invalid image url '{self.sr_image_url}', expected '<bucket>/<object>'")

        try:
            response = self.minio_client.get_object(bucket, object_name)
        except minio.error.S3Error as e:
            sr_data = json.dumps({'tasks_id': self.tasks_id, 'status': 'ERROR', 'message': f'{e}'})
            self._publish(sr_data)
            raise FileNotFoundError(f"Image '{object_name}' not found in bucket '{bucket}'") from e

        try:
            # Reinterpret the raw bytes as uint8 so OpenCV can decode them.
            raw = np.frombuffer(response.data, np.uint8)
        finally:
            # The MinIO response holds a pooled HTTP connection until released.
            response.close()
            response.release_conn()

        decoded = cv2.imdecode(raw, cv2.IMREAD_COLOR)
        if decoded is None:
            raise ValueError(f"Object '{object_name}' in bucket '{bucket}' is not a decodable image")
        return decoded.astype(np.float32) / 255.

    def read_tasks_status(self):
        """Return the task's status document from Redis as a dict.

        Raises:
            TypeError: if the key does not exist (``json.loads`` receives None).
        """
        status_data = json.loads(self.redis_client.get(self.tasks_id))
        logger.info(f"{self.tasks_id} ===> {status_data}")
        return status_data

    # @RunTime  (timing decorator, currently disabled)
    def infer(self, inputs):
        """Submit ``inputs`` to the Triton model asynchronously.

        ``self.callback`` is invoked by the client with the result or error.
        Returns whatever ``async_infer`` returns; ``sr_result`` calls
        ``cancel()`` on it.
        """
        return self.triton_client.async_infer(
            model_name=SR_MODEL_NAME,
            inputs=inputs,
            callback=self.callback
        )

    # @RunTime  (timing decorator, currently disabled)
    def sr_result(self):
        """Run the task and return its final status document.

        Polls Redis once per second for up to ``_POLL_TIMEOUT_SECONDS``. If the
        task is REVOKED or FAILURE the request is cancelled and that status is
        published to RabbitMQ. If the wait times out, the status currently in
        Redis is returned unchanged.
        """
        sample = self.read_image()
        if self.sr_xn == 2:
            height, width = sample.shape[:2]
            # cv2.resize expects dsize as (width, height), i.e. (cols, rows).
            dsize = (width // self.sr_xn, height // self.sr_xn)
            sample = cv2.resize(sample, dsize)
            logger.debug(f"resized input to (width, height)={dsize}")

        # HWC-BGR -> CHW-RGB (channel order is left alone for 1-channel data).
        sample = np.transpose(sample if sample.shape[2] == 1 else sample[:, :, [2, 1, 0]], (2, 0, 1))
        # Add the leading batch dimension.
        sample = torch.from_numpy(sample).float().unsqueeze(0).numpy()

        inputs = [grpcclient.InferInput("input", sample.shape, datatype="FP32")]
        inputs[0].set_data_from_numpy(sample)

        ctx = self.infer(inputs)
        for _ in range(_POLL_TIMEOUT_SECONDS):
            generate_data = self.read_tasks_status()
            state = generate_data['status']
            if state in ("REVOKED", "FAILURE"):
                ctx.cancel()
                # Make sure the consumer can tell which task this refers to.
                generate_data.setdefault('tasks_id', self.tasks_id)
                self._publish(json.dumps(generate_data))
                break
            if state == "SUCCESS":
                break
            time.sleep(1)
        return self.read_tasks_status()

    def upload_img_sr(self, image):
        """Encode ``image`` as JPEG and upload it to ``SR_MINIO_BUCKET``.

        Returns:
            ``"<bucket>/<object name>"`` of the stored object, or ``None`` when
            encoding or upload fails (the failure is logged).
        """
        try:
            encoded_ok, encoded = cv2.imencode('.jpg', image)
            if not encoded_ok:
                raise ValueError("cv2.imencode could not encode the image as JPEG")
            image_bytes = encoded.tobytes()
            res = self.minio_client.put_object(f'{SR_MINIO_BUCKET}',
                                               f'{self.user_id}/sr/output/{self.tasks_id}.jpg',
                                               io.BytesIO(image_bytes), len(image_bytes),
                                               content_type='image/jpeg')
            # Use the bucket the object was actually written to.
            return f"{res.bucket_name}/{res.object_name}"
        except Exception as e:
            logger.warning(f"upload_img_sr runtime exception : {e}", exc_info=True)
            return None

    def callback(self, result, error):
        """Handle completion of the asynchronous inference request.

        On error, or when post-processing/upload/publish fails, a FAILURE
        status is written to Redis. On success the output image is uploaded
        and a SUCCESS message is published to RabbitMQ and stored in Redis.

        NOTE(review): this is passed to ``async_infer`` as the completion
        callback and uses ``self.channel``; if the client calls it from
        another thread, confirm that sharing the pika channel is safe.
        """
        if error:
            logger.error(f"inference failed for {self.tasks_id}: {error}")
            self._store_failure(error)
            return

        try:
            sr_output = result.as_numpy("output")[0]
            sr_output = torch.tensor(sr_output)
            output = sr_output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
            if output.ndim == 3:
                output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))  # CHW-RGB to HWC-BGR
            output = (output * 255.0).round().astype(np.uint8)

            output_url = self.upload_img_sr(output)
            if output_url is None:
                raise RuntimeError("failed to upload the super-resolution output image")

            sr_data = json.dumps({'tasks_id': self.tasks_id, 'status': 'SUCCESS', 'message': 'success',
                                  'data': f'{output_url}'})
            self._publish(sr_data)
            self._store_status(sr_data)
        except Exception as exc:
            # Without this the task would stay PENDING until the poll loop in
            # sr_result runs out.
            logger.exception(f"post-processing failed for {self.tasks_id}")
            self._store_failure(exc)


def infer_cancel(tasks_id):
    """Mark the task ``tasks_id`` as REVOKED in Redis.

    A running ``SuperResolution.sr_result`` loop for the same task picks this
    status up and cancels its inference request.

    Returns:
        A dict describing the revocation (keys ``tasks``, ``status``,
        ``message``, ``data``).
    """
    redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
    data = {'tasks': tasks_id, 'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'}
    sr_data = json.dumps({'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'})
    # Keep a TTL on the key so revoked tasks do not accumulate forever.
    redis_client.set(tasks_id, sr_data, ex=_STATUS_TTL_SECONDS)
    return data


if __name__ == '__main__':
    request_data = SuperResolutionModel(sr_image_url="test/512_image/15.png", sr_xn=2, sr_tasks_id="123")
    service = SuperResolution(request_data)
    try:
        result_url = service.sr_result()
    finally:
        service.close()