29 lines
1.0 KiB
Python
29 lines
1.0 KiB
Python
|
|
import time
|
||
|
|
|
||
|
|
import cv2
|
||
|
|
import numpy as np
|
||
|
|
import torch
|
||
|
|
import tritonclient.http as httpclient
|
||
|
|
from PIL import Image
|
||
|
|
|
||
|
|
# Client for a Triton "super_resolution" model: reads an image, sends it to
# the server over HTTP, prints the inference round-trip time, and writes the
# upscaled result to disk.

TRITON_URL = "10.1.1.150:7000"
MODEL_NAME = "super_resolution"
INPUT_PATH = "comic2.png"
OUTPUT_PATH = "comic3.png"


def preprocess(image_bgr):
    """Turn an OpenCV image into a ``(1, C, H, W)`` float32 batch in [0, 1].

    Args:
        image_bgr: 8-bit image as returned by ``cv2.imread`` -- ``(H, W, 3)``
            BGR, ``(H, W, 1)``, or ``(H, W)`` single-channel.

    Returns:
        C-contiguous float32 array of shape ``(1, C, H, W)``. Multi-channel
        input is reordered from BGR to RGB; single-channel input is kept.
    """
    img = np.asarray(image_bgr, dtype=np.float32) / 255.0
    if img.ndim == 2:
        # Single-channel image without an explicit channel axis.
        img = img[:, :, np.newaxis]
    if img.shape[2] != 1:
        img = img[:, :, [2, 1, 0]]  # BGR -> RGB
    chw = np.transpose(img, (2, 0, 1))  # HWC -> CHW
    # Add the batch axis; make the buffer contiguous for wire serialization.
    return np.ascontiguousarray(chw[np.newaxis])


def postprocess(output):
    """Convert the model output into a uint8 image suitable for ``cv2.imwrite``.

    Args:
        output: model output laid out as ``(1, C, H, W)`` or ``(C, H, W)``,
            RGB channel order, values nominally in [0, 1].

    Returns:
        ``(H, W, 3)`` BGR uint8 image for multi-channel output, or ``(H, W)``
        uint8 for single-channel output. Values are clipped to [0, 1] before
        scaling to [0, 255].
    """
    out = np.asarray(output, dtype=np.float32)
    if out.ndim == 4:
        out = out[0]  # drop the batch axis (a single image is sent)
    if out.ndim == 3 and out.shape[0] == 1:
        out = out[0]  # single channel -> 2-D grayscale image
    out = np.clip(out, 0.0, 1.0)
    if out.ndim == 3:
        out = np.transpose(out[[2, 1, 0], :, :], (1, 2, 0))  # CHW-RGB -> HWC-BGR
    return (out * 255.0).round().astype(np.uint8)


def main(image_path=INPUT_PATH, output_path=OUTPUT_PATH,
         url=TRITON_URL, model_name=MODEL_NAME):
    """Run one image through the Triton model and save the result.

    Prints the wall-clock duration of the ``infer`` call in seconds.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``image_path``.
        KeyError: if the inference response has no ``"output"`` tensor.
        OSError: if OpenCV fails to write ``output_path``.
    """
    image = cv2.imread(image_path, cv2.IMREAD_COLOR)
    if image is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"could not read image: {image_path}")
    sample = preprocess(image)

    client = httpclient.InferenceServerClient(url=url)
    try:
        inputs = [httpclient.InferInput("input", sample.shape, datatype="FP32")]
        inputs[0].set_data_from_numpy(sample, binary_data=True)

        # perf_counter is monotonic and high-resolution, unlike time.time().
        start = time.perf_counter()
        results = client.infer(model_name=model_name, inputs=inputs)
        print(time.perf_counter() - start)
    finally:
        client.close()

    raw = results.as_numpy("output")
    if raw is None:
        raise KeyError("model response has no 'output' tensor")
    output = postprocess(raw)

    # cv2.imwrite reports failure via its return value, not an exception.
    if not cv2.imwrite(output_path, output):
        raise OSError(f"could not write image: {output_path}")


if __name__ == "__main__":
    main()
|