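# Residue from the repository web viewer, kept as comments so the module parses:
# Files
# AiDA_Python/app/service/super_resolution/service.py
# zcr 18024a2d70
# All checks were successful
# git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
#   (分支构建部署 = "branch build & deploy")
# feat : 代码梳理 移除所有敏感密钥 通过环境变量方式配置
#   (= "feat: code cleanup; remove all hard-coded secrets and configure them via environment variables")
# 2025-12-30 16:49:08 +08:00
#
# 141 lines
# 6.8 KiB
# Python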

import json
import logging
import time
import cv2
import minio.error
import numpy as np
import pika
import redis
import torch
import tritonclient.grpc as grpcclient
from minio import Minio
from app.core.config import settings, SR_TRITON_URL, SR_RABBITMQ_QUEUES, SR_MODEL_NAME
from app.core.rabbit_mq_config import RABBITMQ_PARAMS
from app.schemas.super_resolution import SuperResolutionModel
from app.service.utils.new_oss_client import oss_get_image, oss_upload_image
# NOTE: no name is passed, so this is the *root* logger; records from this
# module are emitted on it directly rather than on a per-module logger.
logger = logging.getLogger()
# Module-level MinIO client, created once at import time from `settings`;
# `SuperResolution.upload_img_sr` uploads output images through it.
minio_client = Minio(
    settings.MINIO_URL,
    access_key=settings.MINIO_ACCESS,
    secret_key=settings.MINIO_SECRET,
    secure=settings.MINIO_SECURE,
)
class SuperResolution:
    """Run one super-resolution job on Triton and report its progress.

    Lifecycle of a job:

    1. ``__init__`` writes a ``PENDING`` status record for ``tasks_id`` to Redis
       (expiring after ``_STATUS_TTL_SECONDS``) and opens a RabbitMQ channel.
    2. ``sr_result`` downloads the source image, submits an asynchronous Triton
       inference request and polls the Redis record until it reaches a
       terminal state (``SUCCESS`` / ``FAILURE`` / ``REVOKED``) or times out.
    3. ``callback`` is invoked by the Triton client when the request finishes.
       On success it uploads the output, publishes a ``SUCCESS`` message to
       ``SR_RABBITMQ_QUEUES`` and stores it in Redis; on any failure it stores
       a ``FAILURE`` record, which the ``sr_result`` poll loop then publishes.

    ``infer_cancel`` (module level) cancels a job by overwriting its Redis
    record with ``REVOKED``.
    """

    # Lifetime (seconds) of the initial PENDING status record in Redis.
    _STATUS_TTL_SECONDS = 600
    # Bucket that receives the upscaled output images.
    _OUTPUT_BUCKET = "aida-users"

    def __init__(self, data):
        """Register the task as PENDING and open the clients it needs.

        :param data: request object exposing ``sr_tasks_id``, ``sr_image_url``
            (``"<bucket>/<object name>"``) and ``sr_xn`` (scale factor);
            presumably a ``SuperResolutionModel``.
        """
        self.triton_client = grpcclient.InferenceServerClient(url=SR_TRITON_URL)
        self.redis_client = redis.StrictRedis(host=settings.REDIS_HOST, port=settings.REDIS_PORT, db=settings.REDIS_DB, decode_responses=True)
        self.tasks_id = data.sr_tasks_id
        # The user id is the suffix after the last '-' of the task id
        # (the whole id when it contains no '-').
        self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
        self.sr_image_url = data.sr_image_url
        self.sr_xn = data.sr_xn
        # One SET with EX, so the record can never exist without its TTL.
        self.redis_client.set(
            self.tasks_id,
            json.dumps({'status': 'PENDING', 'message': "pending", 'data': ''}),
            ex=self._STATUS_TTL_SECONDS,
        )
        # NOTE(review): this class never closes the connection, and `callback`
        # publishes on it from the Triton client's completion callback
        # (presumably a gRPC thread); pika's BlockingConnection is not
        # thread-safe — confirm the intended lifecycle.
        self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
        self.channel = self.connection.channel()

    def _publish(self, body):
        """Publish *body* (a JSON string) to the super-resolution result queue."""
        self.channel.basic_publish(exchange='', routing_key=SR_RABBITMQ_QUEUES, body=body)
        logger.info(f" [x] Sent {body}")

    def read_image(self):
        """Download the source image and return it as float32 scaled to [0, 1].

        :return: the array produced by ``oss_get_image(..., data_type="cv2")``
            divided by 255 (presumably HxWxC in BGR order — TODO confirm).
        :raises ValueError: if ``sr_image_url`` has no ``"/"`` separating
            bucket and object name.
        :raises FileNotFoundError: if object storage reports an S3 error; an
            ``ERROR`` message for the task is published first.
        """
        bucket, sep, object_name = self.sr_image_url.partition("/")
        if not sep:
            raise ValueError(f"sr_image_url must look like '<bucket>/<object>', got {self.sr_image_url!r}")
        try:
            img = oss_get_image(bucket=bucket, object_name=object_name, data_type="cv2")
        except minio.error.S3Error as e:
            self._publish(json.dumps({'tasks_id': self.tasks_id, 'status': 'ERROR', 'message': f'{e}'}))
            raise FileNotFoundError(f"Image '{object_name}' not found in bucket '{bucket}'") from e
        return img.astype(np.float32) / 255.

    def read_tasks_status(self):
        """Return the task's status record from Redis as a dict.

        A missing record (expired or deleted) is reported as a ``FAILURE``
        record instead of crashing in ``json.loads(None)``.
        """
        raw = self.redis_client.get(self.tasks_id)
        if raw is None:
            status_data = {'tasks_id': self.tasks_id, 'status': 'FAILURE', 'message': 'task status not found', 'data': ''}
        else:
            status_data = json.loads(raw)
        logger.info(f"{self.tasks_id} ===> {status_data}")
        return status_data

    def _prepare_input(self, image):
        """Turn an HxW(xC) float image into a contiguous 1xCxHxW float32 batch."""
        if self.sr_xn == 2:
            # Presumably the model upscales by a fixed factor and a x2 job is
            # produced by halving the input first — TODO confirm with the model.
            # cv2.resize takes dsize as (width, height), not (rows, cols).
            new_size = (image.shape[1] // self.sr_xn, image.shape[0] // self.sr_xn)
            image = cv2.resize(image, new_size)
        if image.ndim == 2:
            # cv2.resize returns a 2-D array for single-channel input;
            # restore the channel axis.
            image = image[:, :, np.newaxis]
        if image.shape[2] != 1:
            image = image[:, :, [2, 1, 0]]  # BGR -> RGB (any 4th channel is dropped)
        chw = np.transpose(image, (2, 0, 1))
        return np.ascontiguousarray(chw[np.newaxis], dtype=np.float32)

    def sr_result(self, timeout=60, poll_interval=1):
        """Submit the inference request and wait for the task to settle.

        Polls the Redis status record every ``poll_interval`` seconds until it
        is ``SUCCESS``, ``FAILURE`` or ``REVOKED`` or ``timeout`` seconds have
        elapsed. On ``FAILURE``/``REVOKED`` the inference request is cancelled
        and the record (with ``tasks_id`` added) is published to the result
        queue. On timeout the request is left running, so ``callback`` may
        still publish a result later.

        :return: the task's final status record (a dict, not a URL).
        :raises FileNotFoundError: propagated from ``read_image``.
        """
        sample = self._prepare_input(self.read_image())
        inputs = [grpcclient.InferInput("input", sample.shape, datatype="FP32")]
        inputs[0].set_data_from_numpy(sample)
        ctx = self.triton_client.async_infer(
            model_name=SR_MODEL_NAME,
            inputs=inputs,
            callback=self.callback,
        )
        deadline = time.monotonic() + timeout
        while time.monotonic() < deadline:
            status_data = self.read_tasks_status()
            if status_data['status'] in ("REVOKED", "FAILURE"):
                ctx.cancel()
                # Stored FAILURE/REVOKED records may lack the task id; always
                # include it so queue consumers can correlate the message.
                self._publish(json.dumps({**status_data, 'tasks_id': self.tasks_id}))
                break
            if status_data['status'] == "SUCCESS":
                break
            time.sleep(poll_interval)
        else:
            logger.warning(f"{self.tasks_id}: no terminal status after {timeout}s; inference left running")
        return self.read_tasks_status()

    def upload_img_sr(self, image):
        """Encode *image* as JPEG and upload it to the output bucket.

        :return: ``"aida-users/<user_id>/sr/output/<tasks_id>.jpg"``, or
            ``None`` when encoding or uploading fails (the error is logged).
        """
        object_name = f'{self.user_id}/sr/output/{self.tasks_id}.jpg'
        try:
            ok, encoded = cv2.imencode('.jpg', image)
            if not ok:
                logger.warning("upload_img_sr: cv2.imencode could not encode the image as JPEG")
                return None
            oss_upload_image(oss_client=minio_client, bucket=self._OUTPUT_BUCKET, object_name=object_name, image_bytes=encoded.tobytes())
        except Exception as e:
            logger.warning(f"upload_img_sr runtime exception : {e}", exc_info=True)
            return None
        return f"{self._OUTPUT_BUCKET}/{object_name}"

    def _store_failure(self, message):
        """Store a FAILURE status record for this task unless it was revoked."""
        current = self.redis_client.get(self.tasks_id)
        if current is not None and json.loads(current).get('status') == 'REVOKED':
            # sr_result cancels the request after a revoke, which presumably
            # ends in an error callback; keep the REVOKED status visible.
            return
        self.redis_client.set(self.tasks_id, json.dumps({'tasks_id': self.tasks_id, 'status': 'FAILURE', 'message': message, 'data': message}))

    @staticmethod
    def _postprocess(sr_output):
        """Convert one model output with values in [0, 1] to an 8-bit image.

        Size-1 axes are squeezed, values are clipped to [0, 1], and a 3-D
        CHW-RGB result is reordered to HWC-BGR for OpenCV.
        """
        output = np.clip(np.squeeze(sr_output).astype(np.float32), 0.0, 1.0)
        if output.ndim == 3:
            output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))  # CHW-RGB to HWC-BGR
        return (output * 255.0).round().astype(np.uint8)

    def callback(self, result, error):
        """Handle completion of the Triton request submitted by ``sr_result``.

        :param result: Triton inference result (read via ``as_numpy("output")``).
        :param error: error reported by the client, falsy on success.
        """
        if error:
            self._store_failure(f"{error}")
            return
        try:
            output = self._postprocess(result.as_numpy("output")[0])
            output_url = self.upload_img_sr(output)
        except Exception as e:
            # An exception escaping this callback would be lost in the client's
            # callback machinery and leave the task PENDING until the timeout.
            logger.exception(f"{self.tasks_id}: super-resolution post-processing failed")
            self._store_failure(f"{e}")
            return
        if output_url is None:
            self._store_failure("failed to upload the super-resolution result")
            return
        sr_data = json.dumps({'tasks_id': self.tasks_id, 'status': 'SUCCESS', 'message': 'success', 'data': f'{output_url}'})
        # Publish before storing SUCCESS: sr_result stops polling once it sees it.
        self._publish(sr_data)
        # NOTE(review): SET without EX drops the TTL set in __init__, so the
        # SUCCESS record persists indefinitely — confirm that is intended.
        self.redis_client.set(self.tasks_id, sr_data)
def infer_cancel(tasks_id):
    """Mark a super-resolution task as revoked.

    Overwrites the task's Redis status record with ``REVOKED``. A running
    ``SuperResolution.sr_result`` poll loop picks this up, cancels the Triton
    request and publishes the stored record to the result queue.

    :param tasks_id: id of the task to cancel.
    :return: acknowledgement dict
        ``{'tasks': tasks_id, 'status': 'REVOKED', 'message': 'revoked', 'data': 'revoked'}``.
    """
    redis_client = redis.StrictRedis(host=settings.REDIS_HOST, port=settings.REDIS_PORT, db=settings.REDIS_DB, decode_responses=True)
    # sr_result publishes this record as-is, so it carries the task id for
    # queue consumers. The 600 s TTL matches the initial PENDING record, so
    # revoking an unknown id does not leave a permanent key behind.
    sr_data = json.dumps({'tasks_id': tasks_id, 'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'})
    redis_client.set(tasks_id, sr_data, ex=600)
    return {'tasks': tasks_id, 'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'}
if __name__ == '__main__':
    # Manual smoke test: needs reachable Triton, Redis, RabbitMQ and object storage.
    # Configure logging so the INFO-level " [x] Sent ..." messages are visible.
    logging.basicConfig(level=logging.INFO)
    request_data = SuperResolutionModel(sr_image_url="aida-users/83/print/b77bf4ca-6ca2-44a1-9040-505f359a974c-3-83.png", sr_xn=2, sr_tasks_id="12341556")
    service = SuperResolution(request_data)
    # sr_result returns the task's final status record (a dict), not a URL.
    final_status = service.sr_result()
    print(final_status)