"""Super-resolution inference service.

Fetches an input image from object storage, submits it to a Triton inference
server, tracks the task state in Redis and publishes results to RabbitMQ.
"""
import json
import logging
import time

import cv2
import minio.error
import numpy as np
import redis
import torch
import tritonclient.grpc as grpcclient

# NOTE(review): the star import is expected to supply the SR_*, REDIS_* and
# RABBITMQ_PARAMS settings used below. `pika` is also used in this module
# without an explicit import, so it presumably arrives through this star
# import as well -- confirm, and prefer explicit imports.
from app.core.config import *
from app.schemas.super_resolution import SuperResolutionModel
from app.service.utils.oss_client import oss_get_image, oss_upload_image

logger = logging.getLogger()

# Lifetime (seconds) of the initial PENDING status entry in Redis.
_TASK_TTL_SECONDS = 600
# sr_result() polls the task status this many times, sleeping between polls.
_POLL_ATTEMPTS = 60
_POLL_INTERVAL_SECONDS = 1


class SuperResolution:
    """Run a single super-resolution task.

    The task status is stored in Redis under the task id as a JSON document
    with ``status`` / ``message`` / ``data`` keys. Inference is submitted
    asynchronously; ``callback`` stores the outcome in Redis and
    ``sr_result`` polls that entry until it reaches a terminal state.

    NOTE(review): ``callback`` is passed to ``async_infer`` and uses
    ``self.channel``; if Triton invokes it on another thread, the pika
    channel is shared across threads -- confirm this is safe for the
    connection type in use.
    """

    def __init__(self, data):
        """Create the Triton/Redis/RabbitMQ clients and mark the task PENDING.

        Args:
            data: Request object exposing ``sr_tasks_id``, ``sr_image_url``
                and ``sr_xn`` attributes.
        """
        self.triton_client = grpcclient.InferenceServerClient(url=SR_TRITON_URL)
        self.redis_client = redis.StrictRedis(
            host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True
        )
        self.tasks_id = data.sr_tasks_id
        # The text after the last '-' of the task id is used as the user id
        # (the whole id when it contains no '-').
        self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
        self.sr_image_url = data.sr_image_url
        self.sr_xn = data.sr_xn
        # Write the value and its expiry in one command so the key can never
        # be left without a TTL between two separate calls.
        self.redis_client.set(
            self.tasks_id,
            json.dumps({'status': 'PENDING', 'message': "pending", 'data': ''}),
            ex=_TASK_TTL_SECONDS,
        )
        self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
        self.channel = self.connection.channel()

    def read_image(self):
        """Fetch the input image from object storage.

        ``sr_image_url`` is split at its first '/' into bucket and object
        name.

        Returns:
            Whatever ``oss_get_image`` returns for ``data_type="cv2"``.

        Raises:
            FileNotFoundError: If the storage layer raises ``S3Error``; an
                ERROR message is published to the result queue first.
            IndexError: If ``sr_image_url`` contains no '/'.
        """
        parts = self.sr_image_url.split("/", 1)
        bucket, object_name = parts[0], parts[1]
        try:
            return oss_get_image(bucket=bucket, object_name=object_name, data_type="cv2")
        except minio.error.S3Error as e:
            sr_data = json.dumps({'tasks_id': self.tasks_id, 'status': 'ERROR', 'message': f'{e}'})
            self.channel.basic_publish(exchange='', routing_key=SR_RABBITMQ_QUEUES, body=sr_data)
            logger.info(f" [x] Sent {sr_data}")
            # Chain the storage error so the root cause stays in the traceback.
            raise FileNotFoundError(
                f"Image '{object_name}' not found in bucket '{bucket}'"
            ) from e

    def read_tasks_status(self):
        """Return the task's current status document from Redis.

        Returns:
            The decoded JSON value stored under the task id.

        Raises:
            TypeError: If the key is missing (``json.loads`` receives None).
        """
        status_data = json.loads(self.redis_client.get(self.tasks_id))
        logger.info(f"{self.tasks_id} ===> {status_data}")
        return status_data

    def infer(self, inputs):
        """Submit an asynchronous inference request to Triton.

        Args:
            inputs: List of ``grpcclient.InferInput`` objects.

        Returns:
            The object returned by ``async_infer``; ``sr_result`` calls
            ``cancel()`` on it when the task is revoked or fails.
        """
        return self.triton_client.async_infer(
            model_name=SR_MODEL_NAME,
            inputs=inputs,
            callback=self.callback
        )

    def sr_result(self):
        """Run the task and return its final status document.

        Loads the image, optionally downsizes it, submits it for inference,
        then polls Redis once per second (up to 60 times) until the status is
        SUCCESS, REVOKED or FAILURE. On REVOKED/FAILURE the request is
        cancelled and the status is published to the result queue.

        Returns:
            The status document read from Redis after polling ends; still
            PENDING if no terminal state was reached in time.
        """
        sample = self.read_image()
        if self.sr_xn == 2:
            height, width = sample.shape[:2]
            # cv2.resize expects dsize as (width, height), the reverse of
            # numpy's (rows, cols) shape order.
            new_size = (width // self.sr_xn, height // self.sr_xn)
            sample = cv2.resize(sample, new_size)
            logger.info(f"resized input to {new_size}")
            if sample.ndim == 2:
                # cv2.resize drops the channel axis of an (H, W, 1) image;
                # restore it so the channel handling below still applies.
                sample = sample[:, :, np.newaxis]

        # Reverse the channel order for multi-channel images, then move the
        # channel axis first (HWC -> CHW).
        channels_last = sample if sample.shape[2] == 1 else sample[:, :, [2, 1, 0]]
        sample = np.transpose(channels_last, (2, 0, 1))
        # Cast to float32 and add a leading batch axis.
        sample = torch.from_numpy(sample).float().unsqueeze(0).numpy()

        inputs = [grpcclient.InferInput("input", sample.shape, datatype="FP32")]
        inputs[0].set_data_from_numpy(sample)
        ctx = self.infer(inputs)

        for _ in range(_POLL_ATTEMPTS):
            generate_data = self.read_tasks_status()
            status = generate_data['status']
            if status in ("REVOKED", "FAILURE"):
                ctx.cancel()
                # noinspection PyTypeChecker
                self.channel.basic_publish(exchange='', routing_key=SR_RABBITMQ_QUEUES,
                                           body=json.dumps(generate_data))
                logger.info(f" [x] Sent {generate_data}")
                break
            if status == "SUCCESS":
                break
            time.sleep(_POLL_INTERVAL_SECONDS)
        return self.read_tasks_status()

    def upload_img_sr(self, image):
        """Encode ``image`` as JPEG and upload it to the output bucket.

        Args:
            image: Image array accepted by ``cv2.imencode``.

        Returns:
            ``"<bucket>/<object_name>"`` on success, or ``None`` if encoding
            or uploading raised (the exception is logged, not propagated).
        """
        try:
            image_bytes = cv2.imencode('.jpg', image)[1].tobytes()
            object_name = f'{self.user_id}/sr/output/{self.tasks_id}.jpg'
            oss_upload_image(bucket=SR_MINIO_BUCKET, object_name=object_name, image_bytes=image_bytes)
            return f"{SR_MINIO_BUCKET}/{object_name}"
        except Exception as e:
            # Best effort: the caller treats None as an upload failure.
            logger.warning(f"upload_img_sr runtime exception : {e}")
            return None

    def _store_failure(self, message):
        """Record a FAILURE status for this task in Redis."""
        failure = json.dumps({'status': 'FAILURE', 'message': f"{message}", 'data': f"{message}"})
        self.redis_client.set(self.tasks_id, failure)

    def callback(self, result, error):
        """Handle completion of the asynchronous inference request.

        On error, or when the result image cannot be uploaded, a FAILURE
        status is stored in Redis (``sr_result`` publishes it). Otherwise
        the output is converted to an 8-bit image, uploaded, and a SUCCESS
        message is published and stored.

        NOTE(review): these ``set`` calls carry no expiry, so they replace
        the TTL applied in ``__init__`` -- confirm that is intended.

        Args:
            result: Inference result exposing ``as_numpy``.
            error: Error object, or a falsy value on success.
        """
        if error:
            logger.error(f"{error}")
            self._store_failure(error)
            return

        sr_output = torch.tensor(result.as_numpy("output")[0])
        output = sr_output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
        if output.ndim == 3:
            # Reverse the channel order and move channels last (CHW -> HWC).
            output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))
        # Scale the [0, 1] values to 8-bit.
        output = (output * 255.0).round().astype(np.uint8)

        output_url = self.upload_img_sr(output)
        if output_url is None:
            # Do not report SUCCESS with a literal "None" URL.
            self._store_failure("failed to upload super-resolution result")
            return

        sr_data = json.dumps({'tasks_id': self.tasks_id, 'status': 'SUCCESS',
                              'message': 'success', 'data': f'{output_url}'})
        self.channel.basic_publish(exchange='', routing_key=SR_RABBITMQ_QUEUES, body=sr_data)
        logger.info(f" [x] Sent {sr_data}")
        self.redis_client.set(self.tasks_id, sr_data)


def infer_cancel(tasks_id):
    """Mark a task as REVOKED in Redis.

    Args:
        tasks_id: Redis key of the task to revoke.

    Returns:
        A dict with the task id under ``'tasks'`` plus the REVOKED status
        fields. NOTE(review): other messages in this module use the key
        ``'tasks_id'``; kept as ``'tasks'`` here for caller compatibility.
    """
    redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
    revoked = {'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'}
    redis_client.set(tasks_id, json.dumps(revoked))
    return {'tasks': tasks_id, **revoked}


if __name__ == '__main__':
    request_data = SuperResolutionModel(sr_image_url="test/512_image/15.png", sr_xn=2, sr_tasks_id="123")
    service = SuperResolution(request_data)
    task_status = service.sr_result()