"""Super-resolution task runner.

Downloads an image from object storage, sends it to a Triton model for
super-resolution, uploads the result, and reports task state through Redis
(status key per task) and RabbitMQ (result queue).
"""
import json
import logging
import time

import cv2
import minio.error
import numpy as np
import pika
import redis
import torch
import tritonclient.grpc as grpcclient
from minio import Minio

from app.core.config import settings, SR_TRITON_URL, SR_RABBITMQ_QUEUES, SR_MODEL_NAME
from app.core.rabbit_mq_config import RABBITMQ_PARAMS
from app.schemas.super_resolution import SuperResolutionModel
from app.service.utils.new_oss_client import oss_get_image, oss_upload_image

logger = logging.getLogger()

minio_client = Minio(
    settings.MINIO_URL,
    access_key=settings.MINIO_ACCESS,
    secret_key=settings.MINIO_SECRET,
    secure=settings.MINIO_SECURE,
)

# Lifetime of the per-task status key in Redis, in seconds.
TASK_TTL_SECONDS = 600
# Maximum number of one-second polls sr_result() waits for the inference.
INFER_TIMEOUT_SECONDS = 60
# Bucket that receives the super-resolved output images.
OUTPUT_BUCKET = "aida-users"


class SuperResolution:
    """Run one super-resolution task and report its progress.

    The constructor opens a Triton gRPC client, a Redis client and a RabbitMQ
    connection, and marks the task as ``PENDING`` in Redis.  Call
    :meth:`sr_result` to execute the task and :meth:`close` when done.
    """

    def __init__(self, data: SuperResolutionModel):
        """Store the request fields and register the task as PENDING.

        Args:
            data: request carrying ``sr_tasks_id``, ``sr_image_url``
                (``"<bucket>/<object name>"``) and ``sr_xn``.
        """
        self.triton_client = grpcclient.InferenceServerClient(url=SR_TRITON_URL)
        self.redis_client = redis.StrictRedis(
            host=settings.REDIS_HOST,
            port=settings.REDIS_PORT,
            db=settings.REDIS_DB,
            decode_responses=True,
        )
        self.tasks_id = data.sr_tasks_id
        # Everything after the last '-' (the whole id when there is no '-').
        self.user_id = self.tasks_id.rpartition('-')[2]
        self.sr_image_url = data.sr_image_url
        self.sr_xn = data.sr_xn
        # Single SET with EX so the value and its expiry are written atomically.
        self.redis_client.set(
            self.tasks_id,
            json.dumps({'status': 'PENDING', 'message': "pending", 'data': ''}),
            ex=TASK_TTL_SECONDS,
        )
        self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
        self.channel = self.connection.channel()

    def close(self):
        """Release the RabbitMQ connection and the Triton client."""
        if self.connection.is_open:
            self.connection.close()
        self.triton_client.close()

    def _publish(self, body):
        """Send ``body`` (a JSON string) to the result queue and log it."""
        # NOTE(review): callback() is presumably invoked by the Triton client
        # off the thread that created this channel, and pika's
        # BlockingConnection is documented as not thread-safe -- confirm
        # whether publishing needs to be serialised.
        self.channel.basic_publish(exchange='', routing_key=SR_RABBITMQ_QUEUES, body=body)
        logger.info(f" [x] Sent {body}")

    def read_image(self):
        """Fetch the source image and scale it to float32 in [0, 1].

        Returns:
            The image returned by ``oss_get_image`` converted to ``float32``
            and divided by 255.

        Raises:
            FileNotFoundError: the object store reported an S3 error; an
                ``ERROR`` message is published to the result queue first.
        """
        url_parts = self.sr_image_url.split("/", 1)
        bucket, object_name = url_parts[0], url_parts[1]
        try:
            img = oss_get_image(bucket=bucket, object_name=object_name, data_type="cv2")
            img = img.astype(np.float32) / 255.
        except minio.error.S3Error as e:
            self._publish(json.dumps({'tasks_id': self.tasks_id, 'status': 'ERROR', 'message': f'{e}'}))
            raise FileNotFoundError(f"Image '{object_name}' not found in bucket '{bucket}'") from e
        return img

    def read_tasks_status(self):
        """Return the task's current status dict as stored in Redis."""
        status_data = json.loads(self.redis_client.get(self.tasks_id))
        logger.info(f"{self.tasks_id} ===> {status_data}")
        return status_data

    def sr_result(self):
        """Run the inference and wait for it to finish, fail or be revoked.

        The image is sent to Triton asynchronously; the Redis status key is
        then polled once per second for up to ``INFER_TIMEOUT_SECONDS``.

        Returns:
            The status dict read from Redis after polling stops (still
            ``PENDING`` if the inference did not finish in time).
        """
        sample = self.read_image()
        if self.sr_xn == 2:
            height, width = sample.shape[:2]
            # cv2.resize takes dsize as (width, height), not (rows, cols).
            sample = cv2.resize(sample, (width // self.sr_xn, height // self.sr_xn))
        if sample.ndim == 2:
            # cv2.resize drops a trailing single-channel axis; restore it.
            sample = sample[:, :, np.newaxis]
        if sample.shape[2] != 1:
            sample = sample[:, :, [2, 1, 0]]  # reverse the first three channels (BGR -> RGB)
        # HWC -> 1xCxHxW float32, contiguous for serialisation.
        sample = np.ascontiguousarray(np.transpose(sample, (2, 0, 1)), dtype=np.float32)[np.newaxis]

        infer_input = grpcclient.InferInput("input", sample.shape, datatype="FP32")
        infer_input.set_data_from_numpy(sample)
        ctx = self.triton_client.async_infer(
            model_name=SR_MODEL_NAME,
            inputs=[infer_input],
            callback=self.callback,
        )

        for _ in range(INFER_TIMEOUT_SECONDS):
            status = self.read_tasks_status()
            state = status['status']
            if state in ("REVOKED", "FAILURE"):
                ctx.cancel()
                # The Redis payload has no task id; add it so the consumer can
                # correlate the message, as the SUCCESS/ERROR messages do.
                self._publish(json.dumps({'tasks_id': self.tasks_id, **status}))
                break
            if state == "SUCCESS":
                break
            time.sleep(1)
        return self.read_tasks_status()

    def upload_img_sr(self, image):
        """Encode ``image`` as JPEG and upload it to the output bucket.

        Returns:
            ``"<bucket>/<object name>"`` on success, ``None`` if encoding or
            uploading failed (the failure is logged as a warning).
        """
        object_name = f'{self.user_id}/sr/output/{self.tasks_id}.jpg'
        try:
            encoded_ok, buffer = cv2.imencode('.jpg', image)
            if not encoded_ok:
                raise ValueError("cv2.imencode could not encode the image as JPEG")
            oss_upload_image(
                oss_client=minio_client,
                bucket=OUTPUT_BUCKET,
                object_name=object_name,
                image_bytes=buffer.tobytes(),
            )
        except Exception as e:
            logger.warning(f"upload_img_sr runtime exception : {e}")
            return None
        return f"{OUTPUT_BUCKET}/{object_name}"

    def _store_failure(self, error):
        """Record a FAILURE status for this task in Redis."""
        failure = json.dumps({'status': 'FAILURE', 'message': f"{error}", 'data': f"{error}"})
        self.redis_client.set(self.tasks_id, failure, ex=TASK_TTL_SECONDS)

    def callback(self, result, error):
        """Handle the asynchronous Triton response.

        On success the output is converted to an 8-bit image, uploaded, and a
        ``SUCCESS`` message is published and stored in Redis.  Any error --
        reported by Triton, or raised while post-processing or uploading --
        is stored in Redis as ``FAILURE`` so the polling loop in
        :meth:`sr_result` can report it instead of waiting for the timeout.
        """
        if error:
            self._store_failure(error)
            return
        try:
            sr_output = result.as_numpy("output")[0]
            output = np.clip(np.squeeze(sr_output).astype(np.float32), 0, 1)
            if output.ndim == 3:
                # CHW-RGB to HWC-BGR
                output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))
            output = (output * 255.0).round().astype(np.uint8)

            output_url = self.upload_img_sr(output)
            if output_url is None:
                # upload_img_sr signals failure with None; do not report SUCCESS.
                raise RuntimeError("failed to upload the super-resolution result")

            sr_data = json.dumps({'tasks_id': self.tasks_id, 'status': 'SUCCESS',
                                  'message': 'success', 'data': f'{output_url}'})
            self._publish(sr_data)
            # Plain SET would discard the key's TTL; keep the status expiring.
            self.redis_client.set(self.tasks_id, sr_data, ex=TASK_TTL_SECONDS)
        except Exception as exc:
            logger.exception(f"{self.tasks_id} super-resolution post-processing failed")
            self._store_failure(exc)


def infer_cancel(tasks_id):
    """Mark task ``tasks_id`` as REVOKED in Redis.

    Returns:
        A dict with the task id (under ``'tasks'``) and the revoked status.
    """
    redis_client = redis.StrictRedis(
        host=settings.REDIS_HOST,
        port=settings.REDIS_PORT,
        db=settings.REDIS_DB,
        decode_responses=True,
    )
    data = {'tasks': tasks_id, 'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'}
    sr_data = json.dumps({'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'})
    # Expire the flag like every other status write so it cannot linger forever.
    redis_client.set(tasks_id, sr_data, ex=TASK_TTL_SECONDS)
    return data


if __name__ == '__main__':
    request_data = SuperResolutionModel(
        sr_image_url="aida-users/83/print/b77bf4ca-6ca2-44a1-9040-505f359a974c-3-83.png",
        sr_xn=2,
        sr_tasks_id="12341556",
    )
    service = SuperResolution(request_data)
    try:
        result_url = service.sr_result()
    finally:
        service.close()