# Files
#
# 136 lines
# 6.4 KiB
# Python

import json
import logging
import time
import cv2
import minio.error
import numpy as np
import redis
import torch
import tritonclient.grpc as grpcclient
from app.core.config import *
from app.schemas.super_resolution import SuperResolutionModel
from app.service.utils.oss_client import oss_get_image, oss_upload_image
logger = logging.getLogger()
class SuperResolution:
    """Run one super-resolution task against a Triton model and report its status.

    The task status lives in Redis under ``tasks_id`` as a JSON object with
    ``tasks_id`` / ``status`` / ``message`` / ``data`` keys. Terminal statuses
    are published to the ``SR_RABBITMQ_QUEUES`` RabbitMQ queue.

    NOTE(review): ``pika`` and the upper-case settings are not imported
    explicitly in this module; they presumably come from
    ``from app.core.config import *`` -- confirm.
    """

    # Lifetime (seconds) of the Redis status key. It is re-applied on every
    # write because a plain Redis SET discards any existing TTL.
    STATUS_TTL = 600

    def __init__(self, data):
        """Create the clients, record the task as PENDING and open a RabbitMQ channel.

        :param data: request object exposing ``sr_tasks_id``, ``sr_image_url``
            and ``sr_xn`` (e.g. ``SuperResolutionModel``).
        """
        self.triton_client = grpcclient.InferenceServerClient(url=SR_TRITON_URL)
        self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
        self.tasks_id = data.sr_tasks_id
        # The user id is the suffix after the last '-' (the whole id if there is none).
        self.user_id = self.tasks_id.rpartition('-')[2]
        self.sr_image_url = data.sr_image_url
        self.sr_xn = data.sr_xn
        self._store_status(self._status_payload('PENDING', "pending", ''))
        self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
        self.channel = self.connection.channel()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()

    def close(self):
        """Close the RabbitMQ connection opened in ``__init__``; safe to call twice."""
        if self.connection.is_open:
            self.connection.close()

    def _status_payload(self, status, message, data):
        """Build the status dict stored in Redis and published for this task."""
        return {'tasks_id': self.tasks_id, 'status': status, 'message': message, 'data': data}

    def _store_status(self, payload):
        """Write *payload* as this task's Redis status and refresh its TTL atomically."""
        self.redis_client.set(self.tasks_id, json.dumps(payload), ex=self.STATUS_TTL)

    def _publish(self, payload):
        """Publish *payload* to the SR result queue.

        Call this only from the thread that owns ``self.connection``:
        pika's BlockingConnection is not thread-safe.
        """
        body = json.dumps(payload)
        self.channel.basic_publish(exchange='', routing_key=SR_RABBITMQ_QUEUES, body=body)
        logger.info(f" [x] Sent {body}")

    def _report_read_error(self, message):
        """Store and publish an ERROR status for a failed image read."""
        payload = self._status_payload('ERROR', message, '')
        self._store_status(payload)
        self._publish(payload)

    def read_image(self):
        """Download the source image and return it as float32 scaled to [0, 1].

        ``sr_image_url`` has the form ``"<bucket>/<object name>"``. On failure,
        an ERROR status is stored in Redis and published before the exception
        propagates.

        :raises ValueError: if ``sr_image_url`` contains no ``/``.
        :raises FileNotFoundError: if the object cannot be fetched or decoded.
        """
        bucket, sep, object_name = self.sr_image_url.partition("/")
        if not sep:
            self._report_read_error(f"invalid image url '{self.sr_image_url}'")
            raise ValueError(f"sr_image_url must look like '<bucket>/<object>', got '{self.sr_image_url}'")
        try:
            img = oss_get_image(bucket=bucket, object_name=object_name, data_type="cv2")
        except minio.error.S3Error as e:
            self._report_read_error(f'{e}')
            raise FileNotFoundError(f"Image '{object_name}' not found in bucket '{bucket}'") from e
        if img is None:
            # NOTE(review): assumes oss_get_image may return None when decoding
            # fails (as cv2.imdecode does) -- confirm.
            self._report_read_error(f"failed to decode image '{object_name}'")
            raise FileNotFoundError(f"Image '{object_name}' in bucket '{bucket}' could not be decoded")
        return img.astype(np.float32) / 255.

    def read_tasks_status(self):
        """Return this task's status dict as currently stored in Redis.

        :raises KeyError: if no status is stored (for example, the key expired).
        """
        raw = self.redis_client.get(self.tasks_id)
        if raw is None:
            raise KeyError(f"no status stored in redis for task '{self.tasks_id}'")
        status_data = json.loads(raw)
        logger.info(f"{self.tasks_id} ===> {status_data}")
        return status_data

    def _preprocess(self, sample):
        """Convert an HWC OpenCV-style (BGR) float image into a 1xCxHxW RGB float32 batch.

        When ``sr_xn == 2``, the image is first downscaled by 2. This is
        presumably because the model upscales by a larger fixed factor -- TODO confirm.
        """
        if self.sr_xn == 2:
            # cv2.resize takes dsize as (width, height), i.e. (cols, rows).
            new_size = (sample.shape[1] // self.sr_xn, sample.shape[0] // self.sr_xn)
            sample = cv2.resize(sample, new_size)
        if sample.shape[2] != 1:
            sample = sample[:, :, [2, 1, 0]]  # BGR -> RGB
        return np.ascontiguousarray(np.transpose(sample, (2, 0, 1))[np.newaxis], dtype=np.float32)

    def sr_result(self, time_out=60):
        """Run inference for this task and wait until it reaches a terminal status.

        The request is sent asynchronously. :meth:`callback` writes the outcome
        to Redis, and this method polls Redis about once per second. The
        terminal status (SUCCESS / FAILURE / REVOKED) is published to RabbitMQ
        from this thread. If no terminal status appears within *time_out*
        seconds, the request is cancelled and a FAILURE ("timeout") status is
        stored and published.

        :param time_out: maximum number of seconds to wait (default 60).
        :returns: the task's status dict as read from Redis at the end.
        """
        sample = self._preprocess(self.read_image())
        inputs = [grpcclient.InferInput("input", sample.shape, datatype="FP32")]
        inputs[0].set_data_from_numpy(sample)
        ctx = self.triton_client.async_infer(
            model_name=SR_MODEL_NAME,
            inputs=inputs,
            callback=self.callback,
        )

        for remaining in range(time_out, -1, -1):
            status_data = self.read_tasks_status()
            status = status_data.get('status')
            if status in ("REVOKED", "FAILURE"):
                ctx.cancel()
                self._publish({'tasks_id': self.tasks_id, **status_data})
                break
            if status == "SUCCESS":
                self._publish({'tasks_id': self.tasks_id, **status_data})
                break
            if remaining:
                # Unlike time.sleep, connection.sleep keeps servicing AMQP heartbeats.
                self.connection.sleep(1)
        else:
            # Cancel first so a cancellation-triggered callback is less likely
            # to overwrite the timeout status written below.
            ctx.cancel()
            timeout_data = self._status_payload('FAILURE', 'timeout', f'no result within {time_out}s')
            self._store_status(timeout_data)
            self._publish(timeout_data)
        return self.read_tasks_status()

    def upload_img_sr(self, image):
        """JPEG-encode *image* and upload it to ``SR_MINIO_BUCKET``.

        :param image: uint8 image array accepted by ``cv2.imencode``.
        :returns: ``"<bucket>/<object name>"`` of the upload, or ``None`` if
            encoding or uploading failed (the error is logged).
        """
        object_name = f'{self.user_id}/sr/output/{self.tasks_id}.jpg'
        try:
            ok, encoded = cv2.imencode('.jpg', image)
            if not ok:
                raise ValueError("cv2.imencode failed to encode the image as JPEG")
            oss_upload_image(bucket=SR_MINIO_BUCKET, object_name=object_name, image_bytes=encoded.tobytes())
        except Exception as e:
            logger.warning(f"upload_img_sr runtime exception : {e}", exc_info=True)
            return None
        return f"{SR_MINIO_BUCKET}/{object_name}"

    def callback(self, result, error):
        """Triton completion callback: post-process the output and store the status in Redis.

        The Triton client invokes this when the request completes, possibly on
        one of its own worker threads. For that reason it only writes to Redis;
        publishing to RabbitMQ is left to :meth:`sr_result`, because pika's
        BlockingConnection is not thread-safe.
        """
        try:
            if error:
                self._store_status(self._status_payload('FAILURE', f"{error}", f"{error}"))
                return
            output = np.squeeze(result.as_numpy("output")[0]).astype(np.float32)
            output = np.clip(output, 0.0, 1.0)
            if output.ndim == 3:
                output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))  # CHW-RGB -> HWC-BGR
            output = (output * 255.0).round().astype(np.uint8)
            output_url = self.upload_img_sr(output)
            if output_url is None:
                # upload_img_sr already logged the cause; do not report SUCCESS.
                self._store_status(self._status_payload('FAILURE', 'upload failed', ''))
                return
            self._store_status(self._status_payload('SUCCESS', 'success', f'{output_url}'))
        except Exception as e:
            # An exception escaping this callback would never reach sr_result.
            # Record it so the poll loop stops instead of waiting for the timeout.
            logger.exception(f"super-resolution callback failed for task {self.tasks_id}")
            self._store_status(self._status_payload('FAILURE', f"{e}", f"{e}"))
def infer_cancel(tasks_id):
    """Mark a super-resolution task as REVOKED in Redis.

    A running ``SuperResolution.sr_result`` poll loop sees the REVOKED status
    on its next poll, cancels the Triton request and publishes the status.

    :param tasks_id: id of the task to revoke (also its Redis status key).
    :returns: a summary dict for the caller (keyed ``'tasks'``, kept as-is
        for existing callers).
    """
    # Same lifetime that SuperResolution.__init__ gives a new task's status key.
    status_ttl = 600
    redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
    data = {'tasks': tasks_id, 'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'}
    # Include tasks_id so the status published from this entry can be correlated.
    # Set a TTL, because a plain SET would leave the key without any expiry.
    sr_data = json.dumps({'tasks_id': tasks_id, 'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'})
    redis_client.set(tasks_id, sr_data, ex=status_ttl)
    return data
if __name__ == '__main__':
    # Manual smoke test: super-resolve one stored image (x2) and print the final task status.
    request_data = SuperResolutionModel(
        sr_image_url="aida-users/83/print/b77bf4ca-6ca2-44a1-9040-505f359a974c-3-83.png",
        sr_xn=2,
        sr_tasks_id="12341556",
    )
    service = SuperResolution(request_data)
    try:
        final_status = service.sr_result()
        print(final_status)
    finally:
        # Release the RabbitMQ connection opened in SuperResolution.__init__.
        if service.connection.is_open:
            service.connection.close()