超分新增发布rabbitmq消息
This commit is contained in:
@@ -1,14 +1,30 @@
|
|||||||
from fastapi import APIRouter
|
import json
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from fastapi import APIRouter, BackgroundTasks
|
||||||
|
|
||||||
from app.schemas.super_resolution import SuperResolutionModel
|
from app.schemas.super_resolution import SuperResolutionModel
|
||||||
from app.service.super_resolution.service import SuperResolution
|
from app.service.super_resolution.service import SuperResolution, infer_cancel
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
logger = logging.getLogger()
|
||||||
|
|
||||||
|
|
||||||
@router.post("super_resolution")
|
@router.post("super_resolution")
|
||||||
def super_resolution(request_item: SuperResolutionModel):
|
def super_resolution(request_item: SuperResolutionModel, background_tasks: BackgroundTasks):
|
||||||
service = SuperResolution()
|
try:
|
||||||
sr_result_url = service.sr_result(request_item.sr_image_url, request_item.sr_xn)
|
service = SuperResolution(request_item)
|
||||||
response = {"sr_result_url": sr_result_url}
|
background_tasks.add_task(service.sr_result)
|
||||||
return {"code": 200, "message": "ok", "data": response}
|
code = 200
|
||||||
|
message = "access"
|
||||||
|
except Exception as e:
|
||||||
|
code = 000
|
||||||
|
message = e
|
||||||
|
logger.warning(e)
|
||||||
|
return {"code": code, "message": message}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("sr_cancel/{tasks_id}>")
|
||||||
|
def super_resolution(tasks_id):
|
||||||
|
result = infer_cancel(tasks_id)
|
||||||
|
return {"code": 200, "message": result['message'], "data": result['data']}
|
||||||
|
|||||||
@@ -1,4 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
|
import pika
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from pydantic import BaseSettings
|
from pydantic import BaseSettings
|
||||||
|
|
||||||
@@ -47,7 +49,25 @@ MINIO_IP = "www.minio.aida.com.hk"
|
|||||||
MINIO_PORT = 9000
|
MINIO_PORT = 9000
|
||||||
MINIO_ACCESS = 'vXKFLSJkYeEq2DrSZvkB'
|
MINIO_ACCESS = 'vXKFLSJkYeEq2DrSZvkB'
|
||||||
MINIO_SECRET = 'uKTZT3x7C43WvPN9QTc99DiRkwddWZrG9Uh3JVlR'
|
MINIO_SECRET = 'uKTZT3x7C43WvPN9QTc99DiRkwddWZrG9Uh3JVlR'
|
||||||
|
|
||||||
|
# redis 配置
|
||||||
|
REDIS_HOST = "10.1.1.240"
|
||||||
|
REDIS_PORT = "6379"
|
||||||
|
REDIS_DB = "2"
|
||||||
|
|
||||||
MINIO_SECURE = True
|
MINIO_SECURE = True
|
||||||
# input = 'preprocess_img/input_x2' # 这个值需要被函数参数覆盖
|
# input = 'preprocess_img/input_x2' # 这个值需要被函数参数覆盖
|
||||||
# output = '/path/to/output' # 这个值将被函数参数覆盖
|
# output = '/path/to/output' # 这个值将被函数参数覆盖
|
||||||
LOGS_PATH = "app/logs/errors.log"
|
# LOGS_PATH = "app/logs/errors.log"
|
||||||
|
LOGS_PATH = "logs/errors.log"
|
||||||
|
|
||||||
|
SR_MODEL_NAME = "super_resolution"
|
||||||
|
|
||||||
|
# rabbitmq config
|
||||||
|
|
||||||
|
RABBITMQ_PARAMS = {
|
||||||
|
"host": "18.167.251.121",
|
||||||
|
"port": 5672,
|
||||||
|
"credentials": pika.credentials.PlainCredentials(username='rabbit', password='123456'),
|
||||||
|
"virtual_host": "/"
|
||||||
|
}
|
||||||
|
|||||||
2554
app/logs/errors.log
2554
app/logs/errors.log
File diff suppressed because it is too large
Load Diff
@@ -4,3 +4,4 @@ from pydantic import BaseModel
|
|||||||
class SuperResolutionModel(BaseModel):
|
class SuperResolutionModel(BaseModel):
|
||||||
sr_image_url: str
|
sr_image_url: str
|
||||||
sr_xn: int
|
sr_xn: int
|
||||||
|
sr_tasks_id: str
|
||||||
|
|||||||
@@ -1,14 +1,23 @@
|
|||||||
import io
|
import io
|
||||||
import logging
|
import logging
|
||||||
|
import time
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
|
|
||||||
|
import minio.error
|
||||||
|
import pika
|
||||||
|
import redis
|
||||||
|
import json
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import tritonclient.http as httpclient
|
import tritonclient.http as httpclient
|
||||||
|
import tritonclient.grpc as grpcclient
|
||||||
|
|
||||||
|
from PIL import Image
|
||||||
from minio import Minio
|
from minio import Minio
|
||||||
|
|
||||||
from app.core.config import MINIO_IP, MINIO_ACCESS, MINIO_SECRET, MINIO_SECURE, MINIO_PORT
|
from app.core.config import MINIO_IP, MINIO_ACCESS, MINIO_SECRET, MINIO_SECURE, MINIO_PORT, REDIS_HOST, REDIS_PORT, REDIS_DB, SR_MODEL_NAME, RABBITMQ_PARAMS
|
||||||
|
from app.schemas.super_resolution import SuperResolutionModel
|
||||||
|
|
||||||
from app.service.utils.decorator import RunTime
|
from app.service.utils.decorator import RunTime
|
||||||
from app.service.utils.generate_uuid import generate_uuid
|
from app.service.utils.generate_uuid import generate_uuid
|
||||||
@@ -17,39 +26,75 @@ logger = logging.getLogger()
|
|||||||
|
|
||||||
|
|
||||||
class SuperResolution:
|
class SuperResolution:
|
||||||
def __init__(self):
|
def __init__(self, data):
|
||||||
self.triton_client = httpclient.InferenceServerClient(url=f"10.1.1.150:7000")
|
self.triton_client = httpclient.InferenceServerClient(url=f"10.1.1.150:7000")
|
||||||
|
self.triton_client = grpcclient.InferenceServerClient(url=f"10.1.1.150:7001")
|
||||||
|
self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
|
||||||
|
self.tasks_id = data.sr_tasks_id
|
||||||
|
self.sr_image_url = data.sr_image_url
|
||||||
|
self.sr_xn = data.sr_xn
|
||||||
self.minio_client = Minio(
|
self.minio_client = Minio(
|
||||||
f"{MINIO_IP}:{MINIO_PORT}",
|
f"{MINIO_IP}:{MINIO_PORT}",
|
||||||
access_key=MINIO_ACCESS,
|
access_key=MINIO_ACCESS,
|
||||||
secret_key=MINIO_SECRET,
|
secret_key=MINIO_SECRET,
|
||||||
secure=MINIO_SECURE)
|
secure=MINIO_SECURE)
|
||||||
|
self.redis_client.set(self.tasks_id, json.dumps({'status': 'PENDING', 'message': "pending", 'data': ''}))
|
||||||
|
|
||||||
@RunTime
|
@RunTime
|
||||||
def read_image(self, image_url):
|
def read_image(self):
|
||||||
image_data = self.minio_client.get_object(image_url.split("/", 1)[0], image_url.split("/", 1)[1])
|
try:
|
||||||
|
image_data = self.minio_client.get_object(self.sr_image_url.split("/", 1)[0], self.sr_image_url.split("/", 1)[1])
|
||||||
|
except minio.error.S3Error as e:
|
||||||
|
sr_data = json.dumps({'tasks_id': self.tasks_id, 'status': 'ERROR', 'message': f'{e}'})
|
||||||
|
publish_message(sr_data)
|
||||||
|
raise FileNotFoundError(f"Image '{self.sr_image_url.split('/', 1)[1]}' not found in bucket '{self.sr_image_url.split('/', 1)[0]}'")
|
||||||
img = np.frombuffer(image_data.data, np.uint8) # 转成8位无符号整型
|
img = np.frombuffer(image_data.data, np.uint8) # 转成8位无符号整型
|
||||||
img = cv2.imdecode(img, cv2.IMREAD_COLOR).astype(np.float32) / 255. # 解码
|
img = cv2.imdecode(img, cv2.IMREAD_COLOR).astype(np.float32) / 255. # 解码
|
||||||
return img
|
return img
|
||||||
|
|
||||||
|
def read_tasks_status(self):
|
||||||
|
status_data = json.loads(self.redis_client.get(self.tasks_id))
|
||||||
|
logging.info(f"{self.tasks_id} ===> {status_data}")
|
||||||
|
return status_data
|
||||||
|
|
||||||
@RunTime
|
@RunTime
|
||||||
def sr_result(self, image_url, sr_xn):
|
def infer(self, inputs):
|
||||||
sample = self.read_image(image_url)
|
return self.triton_client.async_infer(
|
||||||
|
model_name=SR_MODEL_NAME,
|
||||||
|
inputs=inputs,
|
||||||
|
callback=self.callback
|
||||||
|
)
|
||||||
|
|
||||||
|
@RunTime
|
||||||
|
def sr_result(self):
|
||||||
|
sample = self.read_image()
|
||||||
sample = np.transpose(sample if sample.shape[2] == 1 else sample[:, :, [2, 1, 0]], (2, 0, 1))
|
sample = np.transpose(sample if sample.shape[2] == 1 else sample[:, :, [2, 1, 0]], (2, 0, 1))
|
||||||
sample = torch.from_numpy(sample).float().unsqueeze(0).numpy()
|
sample = torch.from_numpy(sample).float().unsqueeze(0).numpy()
|
||||||
inputs = [
|
inputs = [
|
||||||
httpclient.InferInput("input", sample.shape, datatype="FP32")
|
grpcclient.InferInput("input", sample.shape, datatype="FP32")
|
||||||
]
|
]
|
||||||
inputs[0].set_data_from_numpy(sample, binary_data=True)
|
inputs[0].set_data_from_numpy(sample
|
||||||
results = self.triton_client.infer(model_name="super_resolution", inputs=inputs)
|
# , binary_data=True
|
||||||
|
)
|
||||||
|
|
||||||
sr_output = torch.from_numpy(results.as_numpy(f"output"))
|
ctx = self.infer(inputs)
|
||||||
output = sr_output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
|
time_out = 120
|
||||||
if output.ndim == 3:
|
while self.read_tasks_status()['status'] == "PENDING" and time_out > 0:
|
||||||
output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0)) # CHW-RGB to HCW-BGR
|
if self.read_tasks_status()['status'] == "REVOKED":
|
||||||
output = (output * 255.0).round().astype(np.uint8)
|
ctx.cancel()
|
||||||
output_url = self.upload_img_sr(output, generate_uuid())
|
time_out -= 1
|
||||||
return output_url
|
time.sleep(1)
|
||||||
|
return self.read_tasks_status()
|
||||||
|
|
||||||
|
# results = self.triton_client.infer(model_name=SR_MODEL_NAME, inputs=inputs)
|
||||||
|
|
||||||
|
# sr_output = torch.from_numpy(results.as_numpy(f"output"))
|
||||||
|
# output = sr_output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
|
||||||
|
# if output.ndim == 3:
|
||||||
|
# output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0)) # CHW-RGB to HCW-BGR
|
||||||
|
# output = (output * 255.0).round().astype(np.uint8)
|
||||||
|
# output_url = self.upload_img_sr(output, generate_uuid())
|
||||||
|
# return output_url
|
||||||
|
|
||||||
def upload_img_sr(self, image, object_name):
|
def upload_img_sr(self, image, object_name):
|
||||||
try:
|
try:
|
||||||
@@ -60,8 +105,43 @@ class SuperResolution:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"upload_png_mask runtime exception : {e}")
|
logger.warning(f"upload_png_mask runtime exception : {e}")
|
||||||
|
|
||||||
|
def callback(self, result, error):
|
||||||
|
if error:
|
||||||
|
print(error)
|
||||||
|
sr_info_data = json.dumps({'status': 'FAILURE', 'message': f"{error}", 'data': f"{error}"})
|
||||||
|
self.redis_client.set(self.tasks_id, sr_info_data)
|
||||||
|
else:
|
||||||
|
sr_output = result.as_numpy("output")[0]
|
||||||
|
sr_output = torch.tensor(sr_output)
|
||||||
|
output = sr_output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
|
||||||
|
if output.ndim == 3:
|
||||||
|
output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0)) # CHW-RGB to HCW-BGR
|
||||||
|
output = (output * 255.0).round().astype(np.uint8)
|
||||||
|
output_url = self.upload_img_sr(output, generate_uuid())
|
||||||
|
sr_data = json.dumps({'tasks_id': self.tasks_id, 'status': 'SUCCESS', 'message': 'success', 'data': f'{output_url}'})
|
||||||
|
publish_message(sr_data)
|
||||||
|
self.redis_client.set(self.tasks_id, sr_data)
|
||||||
|
|
||||||
|
|
||||||
|
def infer_cancel(tasks_id):
|
||||||
|
redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
|
||||||
|
data = {'tasks': tasks_id, 'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'}
|
||||||
|
sr_data = json.dumps({'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'})
|
||||||
|
publish_message(sr_data)
|
||||||
|
redis_client.set(tasks_id, sr_data)
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def publish_message(sr_data):
|
||||||
|
connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
|
||||||
|
channel = connection.channel()
|
||||||
|
# 发布消息,并设置回调函数
|
||||||
|
channel.basic_publish(exchange='', routing_key='SuperResolution-local', body=sr_data)
|
||||||
|
logger.info(f" [x] Sent {sr_data}")
|
||||||
|
connection.close()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
service = SuperResolution()
|
request_data = SuperResolutionModel(sr_image_url="test/512_image/15.png", sr_xn=2, sr_tasks_id="123")
|
||||||
result_url = service.sr_result("test/128_image/11.png", 4)
|
service = SuperResolution(request_data)
|
||||||
print(result_url)
|
result_url = service.sr_result()
|
||||||
|
|||||||
@@ -4,25 +4,28 @@ import cv2
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import tritonclient.http as httpclient
|
import tritonclient.http as httpclient
|
||||||
|
import tritonclient.grpc as grpcclient
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
triton_client = httpclient.InferenceServerClient(url=f"10.1.1.150:7000")
|
triton_client = grpcclient.InferenceServerClient(url=f"10.1.1.150:7001")
|
||||||
|
|
||||||
sample = cv2.imread("comic2.png", cv2.IMREAD_COLOR).astype(np.float32) / 255.
|
sample = cv2.imread("1709713346.806274.png", cv2.IMREAD_COLOR).astype(np.float32) / 255.
|
||||||
sample = np.transpose(sample if sample.shape[2] == 1 else sample[:, :, [2, 1, 0]], (2, 0, 1))
|
sample = np.transpose(sample if sample.shape[2] == 1 else sample[:, :, [2, 1, 0]], (2, 0, 1))
|
||||||
sample = torch.from_numpy(sample).float().unsqueeze(0).numpy()
|
sample = torch.from_numpy(sample).float().unsqueeze(0).numpy()
|
||||||
inputs = [
|
inputs = [
|
||||||
httpclient.InferInput("input", sample.shape, datatype="FP32")
|
grpcclient.InferInput("input", sample.shape, datatype="FP32")
|
||||||
]
|
]
|
||||||
inputs[0].set_data_from_numpy(sample, binary_data=True)
|
inputs[0].set_data_from_numpy(sample
|
||||||
|
# , binary_data=True
|
||||||
|
)
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
results = triton_client.infer(model_name="super_resolution", inputs=inputs)
|
results = triton_client.async_infer(model_name="super_resolution", inputs=inputs)
|
||||||
print(time.time() - start_time)
|
print(time.time() - start_time)
|
||||||
sr_output = torch.from_numpy(results.as_numpy(f"output"))
|
sr_output = torch.from_numpy(results.as_numpy(f"output"))
|
||||||
output = sr_output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
|
output = sr_output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
|
||||||
if output.ndim == 3:
|
if output.ndim == 3:
|
||||||
output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0)) # CHW-RGB to HCW-BGR
|
output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0)) # CHW-RGB to HCW-BGR
|
||||||
output = (output * 255.0).round().astype(np.uint8)
|
output = (output * 255.0).round().astype(np.uint8)
|
||||||
# cv2.imshow("", output)
|
cv2.imshow("", output)
|
||||||
# cv2.waitKey(0)
|
cv2.waitKey(0)
|
||||||
cv2.imwrite("comic3.png", output)
|
|
||||||
|
|||||||
BIN
requirements.txt
BIN
requirements.txt
Binary file not shown.
Reference in New Issue
Block a user