超分新增发布rabbitmq消息

This commit is contained in:
zhouchengrong
2024-03-21 11:12:01 +08:00
parent 124fd5cdeb
commit fa52ea9102
7 changed files with 2709 additions and 35 deletions

View File

@@ -1,14 +1,30 @@
from fastapi import APIRouter
import json
import logging
from fastapi import APIRouter, BackgroundTasks
from app.schemas.super_resolution import SuperResolutionModel
from app.service.super_resolution.service import SuperResolution
from app.service.super_resolution.service import SuperResolution, infer_cancel
router = APIRouter()
logger = logging.getLogger()
@router.post("super_resolution")
def super_resolution(request_item: SuperResolutionModel):
service = SuperResolution()
sr_result_url = service.sr_result(request_item.sr_image_url, request_item.sr_xn)
response = {"sr_result_url": sr_result_url}
return {"code": 200, "message": "ok", "data": response}
def super_resolution(request_item: SuperResolutionModel, background_tasks: BackgroundTasks):
try:
service = SuperResolution(request_item)
background_tasks.add_task(service.sr_result)
code = 200
message = "access"
except Exception as e:
code = 000
message = e
logger.warning(e)
return {"code": code, "message": message}
@router.get("sr_cancel/{tasks_id}>")
def super_resolution(tasks_id):
result = infer_cancel(tasks_id)
return {"code": 200, "message": result['message'], "data": result['data']}

View File

@@ -1,4 +1,6 @@
import os
import pika
from dotenv import load_dotenv
from pydantic import BaseSettings
@@ -47,7 +49,25 @@ MINIO_IP = "www.minio.aida.com.hk"
MINIO_PORT = 9000
MINIO_ACCESS = 'vXKFLSJkYeEq2DrSZvkB'
MINIO_SECRET = 'uKTZT3x7C43WvPN9QTc99DiRkwddWZrG9Uh3JVlR'
# redis 配置
REDIS_HOST = "10.1.1.240"
REDIS_PORT = "6379"
REDIS_DB = "2"
MINIO_SECURE = True
# input = 'preprocess_img/input_x2' # 这个值需要被函数参数覆盖
# output = '/path/to/output' # 这个值将被函数参数覆盖
LOGS_PATH = "app/logs/errors.log"
# LOGS_PATH = "app/logs/errors.log"
LOGS_PATH = "logs/errors.log"
SR_MODEL_NAME = "super_resolution"
# rabbitmq config
RABBITMQ_PARAMS = {
"host": "18.167.251.121",
"port": 5672,
"credentials": pika.credentials.PlainCredentials(username='rabbit', password='123456'),
"virtual_host": "/"
}

File diff suppressed because it is too large Load Diff

View File

@@ -4,3 +4,4 @@ from pydantic import BaseModel
class SuperResolutionModel(BaseModel):
sr_image_url: str
sr_xn: int
sr_tasks_id: str

View File

@@ -1,14 +1,23 @@
import io
import logging
import time
from io import BytesIO
import minio.error
import pika
import redis
import json
import cv2
import numpy as np
import torch
import tritonclient.http as httpclient
import tritonclient.grpc as grpcclient
from PIL import Image
from minio import Minio
from app.core.config import MINIO_IP, MINIO_ACCESS, MINIO_SECRET, MINIO_SECURE, MINIO_PORT
from app.core.config import MINIO_IP, MINIO_ACCESS, MINIO_SECRET, MINIO_SECURE, MINIO_PORT, REDIS_HOST, REDIS_PORT, REDIS_DB, SR_MODEL_NAME, RABBITMQ_PARAMS
from app.schemas.super_resolution import SuperResolutionModel
from app.service.utils.decorator import RunTime
from app.service.utils.generate_uuid import generate_uuid
@@ -17,39 +26,75 @@ logger = logging.getLogger()
class SuperResolution:
def __init__(self):
def __init__(self, data):
self.triton_client = httpclient.InferenceServerClient(url=f"10.1.1.150:7000")
self.triton_client = grpcclient.InferenceServerClient(url=f"10.1.1.150:7001")
self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
self.tasks_id = data.sr_tasks_id
self.sr_image_url = data.sr_image_url
self.sr_xn = data.sr_xn
self.minio_client = Minio(
f"{MINIO_IP}:{MINIO_PORT}",
access_key=MINIO_ACCESS,
secret_key=MINIO_SECRET,
secure=MINIO_SECURE)
self.redis_client.set(self.tasks_id, json.dumps({'status': 'PENDING', 'message': "pending", 'data': ''}))
@RunTime
def read_image(self, image_url):
image_data = self.minio_client.get_object(image_url.split("/", 1)[0], image_url.split("/", 1)[1])
def read_image(self):
try:
image_data = self.minio_client.get_object(self.sr_image_url.split("/", 1)[0], self.sr_image_url.split("/", 1)[1])
except minio.error.S3Error as e:
sr_data = json.dumps({'tasks_id': self.tasks_id, 'status': 'ERROR', 'message': f'{e}'})
publish_message(sr_data)
raise FileNotFoundError(f"Image '{self.sr_image_url.split('/', 1)[1]}' not found in bucket '{self.sr_image_url.split('/', 1)[0]}'")
img = np.frombuffer(image_data.data, np.uint8) # 转成8位无符号整型
img = cv2.imdecode(img, cv2.IMREAD_COLOR).astype(np.float32) / 255. # 解码
return img
def read_tasks_status(self):
status_data = json.loads(self.redis_client.get(self.tasks_id))
logging.info(f"{self.tasks_id} ===> {status_data}")
return status_data
@RunTime
def sr_result(self, image_url, sr_xn):
sample = self.read_image(image_url)
def infer(self, inputs):
return self.triton_client.async_infer(
model_name=SR_MODEL_NAME,
inputs=inputs,
callback=self.callback
)
@RunTime
def sr_result(self):
sample = self.read_image()
sample = np.transpose(sample if sample.shape[2] == 1 else sample[:, :, [2, 1, 0]], (2, 0, 1))
sample = torch.from_numpy(sample).float().unsqueeze(0).numpy()
inputs = [
httpclient.InferInput("input", sample.shape, datatype="FP32")
grpcclient.InferInput("input", sample.shape, datatype="FP32")
]
inputs[0].set_data_from_numpy(sample, binary_data=True)
results = self.triton_client.infer(model_name="super_resolution", inputs=inputs)
inputs[0].set_data_from_numpy(sample
# , binary_data=True
)
sr_output = torch.from_numpy(results.as_numpy(f"output"))
output = sr_output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
if output.ndim == 3:
output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0)) # CHW-RGB to HCW-BGR
output = (output * 255.0).round().astype(np.uint8)
output_url = self.upload_img_sr(output, generate_uuid())
return output_url
ctx = self.infer(inputs)
time_out = 120
while self.read_tasks_status()['status'] == "PENDING" and time_out > 0:
if self.read_tasks_status()['status'] == "REVOKED":
ctx.cancel()
time_out -= 1
time.sleep(1)
return self.read_tasks_status()
# results = self.triton_client.infer(model_name=SR_MODEL_NAME, inputs=inputs)
# sr_output = torch.from_numpy(results.as_numpy(f"output"))
# output = sr_output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
# if output.ndim == 3:
# output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0)) # CHW-RGB to HCW-BGR
# output = (output * 255.0).round().astype(np.uint8)
# output_url = self.upload_img_sr(output, generate_uuid())
# return output_url
def upload_img_sr(self, image, object_name):
try:
@@ -60,8 +105,43 @@ class SuperResolution:
except Exception as e:
logger.warning(f"upload_png_mask runtime exception : {e}")
def callback(self, result, error):
if error:
print(error)
sr_info_data = json.dumps({'status': 'FAILURE', 'message': f"{error}", 'data': f"{error}"})
self.redis_client.set(self.tasks_id, sr_info_data)
else:
sr_output = result.as_numpy("output")[0]
sr_output = torch.tensor(sr_output)
output = sr_output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
if output.ndim == 3:
output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0)) # CHW-RGB to HCW-BGR
output = (output * 255.0).round().astype(np.uint8)
output_url = self.upload_img_sr(output, generate_uuid())
sr_data = json.dumps({'tasks_id': self.tasks_id, 'status': 'SUCCESS', 'message': 'success', 'data': f'{output_url}'})
publish_message(sr_data)
self.redis_client.set(self.tasks_id, sr_data)
def infer_cancel(tasks_id):
redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
data = {'tasks': tasks_id, 'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'}
sr_data = json.dumps({'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'})
publish_message(sr_data)
redis_client.set(tasks_id, sr_data)
return data
def publish_message(sr_data):
connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
channel = connection.channel()
# 发布消息,并设置回调函数
channel.basic_publish(exchange='', routing_key='SuperResolution-local', body=sr_data)
logger.info(f" [x] Sent {sr_data}")
connection.close()
if __name__ == '__main__':
service = SuperResolution()
result_url = service.sr_result("test/128_image/11.png", 4)
print(result_url)
request_data = SuperResolutionModel(sr_image_url="test/512_image/15.png", sr_xn=2, sr_tasks_id="123")
service = SuperResolution(request_data)
result_url = service.sr_result()

View File

@@ -4,25 +4,28 @@ import cv2
import numpy as np
import torch
import tritonclient.http as httpclient
import tritonclient.grpc as grpcclient
from PIL import Image
triton_client = httpclient.InferenceServerClient(url=f"10.1.1.150:7000")
triton_client = grpcclient.InferenceServerClient(url=f"10.1.1.150:7001")
sample = cv2.imread("comic2.png", cv2.IMREAD_COLOR).astype(np.float32) / 255.
sample = cv2.imread("1709713346.806274.png", cv2.IMREAD_COLOR).astype(np.float32) / 255.
sample = np.transpose(sample if sample.shape[2] == 1 else sample[:, :, [2, 1, 0]], (2, 0, 1))
sample = torch.from_numpy(sample).float().unsqueeze(0).numpy()
inputs = [
httpclient.InferInput("input", sample.shape, datatype="FP32")
grpcclient.InferInput("input", sample.shape, datatype="FP32")
]
inputs[0].set_data_from_numpy(sample, binary_data=True)
inputs[0].set_data_from_numpy(sample
# , binary_data=True
)
start_time = time.time()
results = triton_client.infer(model_name="super_resolution", inputs=inputs)
results = triton_client.async_infer(model_name="super_resolution", inputs=inputs)
print(time.time() - start_time)
sr_output = torch.from_numpy(results.as_numpy(f"output"))
output = sr_output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
if output.ndim == 3:
output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0)) # CHW-RGB to HCW-BGR
output = (output * 255.0).round().astype(np.uint8)
# cv2.imshow("", output)
# cv2.waitKey(0)
cv2.imwrite("comic3.png", output)
cv2.imshow("", output)
cv2.waitKey(0)