This commit is contained in:
zhouchengrong
2024-03-20 11:44:15 +08:00
commit 69132570aa
29 changed files with 815 additions and 0 deletions

View File

@@ -0,0 +1,67 @@
import io
import logging
from io import BytesIO
import cv2
import numpy as np
import torch
import tritonclient.http as httpclient
from minio import Minio
from app.core.config import MINIO_IP, MINIO_ACCESS, MINIO_SECRET, MINIO_SECURE, MINIO_PORT
from app.service.utils.decorator import RunTime
from app.service.utils.generate_uuid import generate_uuid
logger = logging.getLogger()
class SuperResolution:
def __init__(self):
self.triton_client = httpclient.InferenceServerClient(url=f"10.1.1.150:7000")
self.minio_client = Minio(
f"{MINIO_IP}:{MINIO_PORT}",
access_key=MINIO_ACCESS,
secret_key=MINIO_SECRET,
secure=MINIO_SECURE)
@RunTime
def read_image(self, image_url):
image_data = self.minio_client.get_object(image_url.split("/", 1)[0], image_url.split("/", 1)[1])
img = np.frombuffer(image_data.data, np.uint8) # 转成8位无符号整型
img = cv2.imdecode(img, cv2.IMREAD_COLOR).astype(np.float32) / 255. # 解码
return img
@RunTime
def sr_result(self, image_url, sr_xn):
sample = self.read_image(image_url)
sample = np.transpose(sample if sample.shape[2] == 1 else sample[:, :, [2, 1, 0]], (2, 0, 1))
sample = torch.from_numpy(sample).float().unsqueeze(0).numpy()
inputs = [
httpclient.InferInput("input", sample.shape, datatype="FP32")
]
inputs[0].set_data_from_numpy(sample, binary_data=True)
results = self.triton_client.infer(model_name="super_resolution", inputs=inputs)
sr_output = torch.from_numpy(results.as_numpy(f"output"))
output = sr_output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
if output.ndim == 3:
output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0)) # CHW-RGB to HCW-BGR
output = (output * 255.0).round().astype(np.uint8)
output_url = self.upload_img_sr(output, generate_uuid())
return output_url
def upload_img_sr(self, image, object_name):
try:
image_bytes = cv2.imencode('.jpg', image)[1].tobytes()
image_url = f"test/{self.minio_client.put_object(f'test', f'{object_name}.jpg', io.BytesIO(image_bytes), len(image_bytes), content_type='image/png').object_name}"
return image_url
except Exception as e:
logger.warning(f"upload_png_mask runtime exception : {e}")
if __name__ == '__main__':
service = SuperResolution()
result_url = service.sr_result("test/128_image/11.png", 4)
print(result_url)

View File

@@ -0,0 +1,28 @@
import time
import cv2
import numpy as np
import torch
import tritonclient.http as httpclient
from PIL import Image
triton_client = httpclient.InferenceServerClient(url=f"10.1.1.150:7000")
sample = cv2.imread("comic2.png", cv2.IMREAD_COLOR).astype(np.float32) / 255.
sample = np.transpose(sample if sample.shape[2] == 1 else sample[:, :, [2, 1, 0]], (2, 0, 1))
sample = torch.from_numpy(sample).float().unsqueeze(0).numpy()
inputs = [
httpclient.InferInput("input", sample.shape, datatype="FP32")
]
inputs[0].set_data_from_numpy(sample, binary_data=True)
start_time = time.time()
results = triton_client.infer(model_name="super_resolution", inputs=inputs)
print(time.time() - start_time)
sr_output = torch.from_numpy(results.as_numpy(f"output"))
output = sr_output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
if output.ndim == 3:
output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0)) # CHW-RGB to HCW-BGR
output = (output * 255.0).round().astype(np.uint8)
# cv2.imshow("", output)
# cv2.waitKey(0)
cv2.imwrite("comic3.png", output)

View File

@@ -0,0 +1,14 @@
import time
import logging
def RunTime(func):
def wrapper(*args, **kwargs):
t1 = time.time()
res = func(*args, **kwargs)
t2 = time.time()
if t2 - t1 > 0.05:
logging.info(f"function{func.__name__}】,runtime{str(t2 - t1)}】s")
return res
return wrapper

View File

@@ -0,0 +1,10 @@
import threading
import uuid
id_lock = threading.Lock()
def generate_uuid():
with id_lock:
unique_id = str(uuid.uuid1())
return unique_id