146 lines
4.9 KiB
Python
146 lines
4.9 KiB
Python
#!/usr/bin/env python
|
||
# -*- coding: UTF-8 -*-
|
||
"""
|
||
@Project :trinity_client
|
||
@File :design_ensemble.py
|
||
@Author :周成融
|
||
@Date :2023/8/16 19:36:21
|
||
@detail :发起请求 获取推理结果
|
||
"""
|
||
import logging
|
||
|
||
import cv2
|
||
import mmcv
|
||
import numpy as np
|
||
import torch
|
||
import tritonclient.http as httpclient
|
||
|
||
from app.core.config import *
|
||
|
||
"""
|
||
keypoint
|
||
预处理 推理 后处理
|
||
"""
|
||
|
||
|
||
def keypoint_preprocess(img_path):
    """Load an image and build the keypoint model's input batch.

    The image is resized to a fixed 256x256, normalized per channel with the
    mean/std below (``to_rgb=True`` swaps BGR to RGB first), and transposed
    from HxWxC to a 1xCxHxW batch.

    Args:
        img_path: anything ``mmcv.imread`` accepts -- a file path, or an
            already-decoded image array (``__main__`` passes the result of
            ``cv2.imread``).

    Returns:
        ``(batch, (w_scale, h_scale))``: ``batch`` has shape (1, C, 256, 256);
        the scales are resized-size / original-size along x and y.
    """
    target_w, target_h = 256, 256
    source = mmcv.imread(img_path)
    src_h, src_w = source.shape[:2]
    # Computed from the original size so callers can map model coordinates
    # back onto the source image.
    scales = (target_w / src_w, target_h / src_h)
    resized = cv2.resize(source, (target_w, target_h))
    normalized = mmcv.imnormalize(
        resized,
        mean=np.array([123.675, 116.28, 103.53]),
        std=np.array([58.395, 57.12, 57.375]),
        to_rgb=True,
    )
    chw = normalized.transpose(2, 0, 1)
    return chw[np.newaxis, ...], scales
|
||
|
||
|
||
# @ RunTime
# Inference
def get_keypoint_result(image, site):
    """Run keypoint inference for one image on the Triton server.

    Args:
        image: anything ``keypoint_preprocess`` accepts (path or image array).
        site: selects the served model ``keypoint_{site}_ocrnet_hr18``
            (``__main__`` uses ``"up"``).

    Returns:
        The array produced by ``keypoint_postprocess``, or ``None`` if any
        step fails. Failures are logged as warnings with their traceback and
        are never raised to the caller.
    """
    keypoint_result = None
    try:
        batch, scale_factor = keypoint_preprocess(image)
        transformed_img = batch.astype(np.float32)
        client = httpclient.InferenceServerClient(url=DESIGN_MODEL_URL)
        try:
            inputs = [httpclient.InferInput("input", transformed_img.shape, datatype="FP32")]
            inputs[0].set_data_from_numpy(transformed_img, binary_data=True)
            outputs = [httpclient.InferRequestedOutput("output", binary_data=True)]
            results = client.infer(model_name=f"keypoint_{site}_ocrnet_hr18", inputs=inputs, outputs=outputs)
            raw_output = results.as_numpy("output")
        finally:
            # Release the client's HTTP connections deterministically rather
            # than leaving it to garbage collection.
            client.close()
        inference_output = torch.from_numpy(raw_output)
        keypoint_result = keypoint_postprocess(inference_output, scale_factor)
    except Exception as e:
        # Deliberate best-effort boundary: log (with traceback) and return None.
        logging.warning(f"get_keypoint_result : {e}", exc_info=True)
    return keypoint_result
|
||
|
||
|
||
def keypoint_postprocess(output, scale_factor, heatmap_stride=4):
    """Decode per-keypoint score maps into pixel coordinates on the source image.

    For every (batch, keypoint) map the arg-max location is taken, split into
    an integer ``(row, col)`` pair, multiplied by the inverse of the
    preprocessing scale factors and by ``heatmap_stride``, and rounded up.

    Args:
        output: torch tensor of shape (N, K, H, W); the arg-max is taken over
            the last two dimensions.
        scale_factor: ``(w_scale, h_scale)`` as returned by
            ``keypoint_preprocess``. A scale whose inverse is infinite maps
            that axis to 0.
        heatmap_stride: factor from map coordinates to model-input pixels.
            It defaults to the previously hard-coded 4 (presumably 256-px
            input / 64-px map -- confirm against the served model).

    Returns:
        float ndarray of shape (N, K, 2) holding ``(y, x)`` per keypoint.
    """
    width = output.size(3)
    # Flatten each HxW map and take the arg-max position: shape (N, K, 1).
    flat_indices = torch.argmax(output.flatten(start_dim=2), dim=2).unsqueeze(dim=2)
    # Floor division is required: plain `/` on integer tensors is true
    # division in PyTorch and would leak col/W into the row coordinate.
    rows = torch.div(flat_indices, width, rounding_mode="floor")
    cols = flat_indices % width
    segment_result = torch.cat((rows, cols), dim=2).cpu().numpy()
    # scale_factor is (w, h); reverse it to match the (row, col) order.
    inv_scale = [1 / x for x in scale_factor[::-1]]
    scale_matrix = np.diag(inv_scale)
    scale_matrix[np.isinf(scale_matrix)] = 0
    return np.ceil(np.dot(segment_result, scale_matrix) * heatmap_stride)
|
||
|
||
|
||
"""
|
||
seg
|
||
预处理 推理 后处理
|
||
"""
|
||
|
||
|
||
# KNet
def seg_preprocess(img_path):
    """Build the segmentation model's input batch from an image.

    A side longer than 1024 px is clamped to 1024. Height and width are
    clamped independently, so the aspect ratio may change. The image then
    gets a 25 px white border on every side and is transposed from HxWxC to
    a 1xCxHxW batch.

    No mean/std normalization is applied here; presumably the served model
    handles it -- confirm.

    Args:
        img_path: anything ``mmcv.imread`` accepts (path or image array).

    Returns:
        ``(batch, ori_shape)``, where ``ori_shape`` is the ``(height, width)``
        of the loaded image before any resizing or padding.
    """
    max_side = 1024
    border = 25
    img = mmcv.imread(img_path)
    ori_shape = img.shape[:2]
    src_rows, src_cols = ori_shape
    dst_rows = max_side if src_rows > max_side else src_rows
    dst_cols = max_side if src_cols > max_side else src_cols
    if ori_shape != (dst_rows, dst_cols):
        # cv2.resize takes the target size as (width, height) == (cols, rows).
        img = cv2.resize(img, (dst_cols, dst_rows))
    padded = cv2.copyMakeBorder(
        img, border, border, border, border, cv2.BORDER_CONSTANT, value=[255, 255, 255]
    )
    return np.expand_dims(padded.transpose(2, 0, 1), axis=0), ori_shape
|
||
|
||
|
||
# @ RunTime
def get_seg_result(image_id, image):
    """Run segmentation inference for one image on the Triton server.

    The input/output tensor names and the model name come from the
    ``SEGMENTATION`` config mapping.

    Args:
        image_id: converted with ``int()`` and forwarded to ``seg_postprocess``.
        image: anything ``seg_preprocess`` accepts (path or image array).

    Returns:
        Whatever ``seg_postprocess`` returns for the model output.

    Raises:
        Errors from preprocessing, the Triton client, or ``int(image_id)``
        propagate to the caller; this function does not swallow them.
    """
    batch, ori_shape = seg_preprocess(image)
    client = httpclient.InferenceServerClient(url=f"{DESIGN_MODEL_URL}")
    try:
        transformed_img = batch.astype(np.float32)
        # Input set
        inputs = [
            httpclient.InferInput(SEGMENTATION['input'], transformed_img.shape, datatype="FP32")
        ]
        inputs[0].set_data_from_numpy(transformed_img, binary_data=True)
        # Output set
        outputs = [
            httpclient.InferRequestedOutput(SEGMENTATION['output'], binary_data=True),
        ]
        # Inference
        results = client.infer(model_name=SEGMENTATION['new_model_name'], inputs=inputs, outputs=outputs)
        # Fetch the result
        inference_output1 = results.as_numpy(SEGMENTATION['output'])
    finally:
        # Release the client's HTTP connections deterministically rather than
        # leaving it to garbage collection.
        client.close()
    return seg_postprocess(int(image_id), inference_output1, ori_shape)
|
||
|
||
|
||
# no cache
def seg_postprocess(image_id, output, ori_shape):
    """Map the segmentation output back onto the original image grid.

    The first 2-D map, ``output[0][0]``, is cast to uint8 and resized to the
    padded original size (``ori_shape`` plus 25 px on each side, matching
    the border added in ``seg_preprocess``). The border is then cropped off.

    Args:
        image_id: not used in this function.
        output: model output array. The cast to uint8 suggests it holds
            per-pixel class indices, shape (N, 1, H, W) -- confirm against
            the served model.
        ori_shape: ``(height, width)`` of the original image.

    Returns:
        uint8 array of shape ``ori_shape``.
    """
    pad = 25
    height, width = ori_shape[:2]
    label_map = output[0][0].astype(np.uint8)
    # Nearest-neighbour keeps values discrete; cv2's default bilinear
    # interpolation would invent intermediate labels along class boundaries
    # whenever the image was downscaled in preprocessing.
    resized = cv2.resize(
        label_map, (width + 2 * pad, height + 2 * pad), interpolation=cv2.INTER_NEAREST
    )
    return resized[pad:-pad, pad:-pad]
|
||
|
||
|
||
def key_point_show(image_path, key_point_result=None):
    """Draw keypoints on an image and show it in a blocking OpenCV window.

    This is a debug helper. It waits for a key press before returning.

    Args:
        image_path: path of the image to load with ``cv2.imread``.
        key_point_result: iterable of points in ``(y, x)`` order (as built
            in ``__main__`` from ``get_keypoint_result``). Each point is
            reversed to OpenCV's ``(x, y)`` order and cast to int. ``None``
            draws nothing.

    Raises:
        FileNotFoundError: if ``cv2.imread`` cannot read ``image_path``.
    """
    img = cv2.imread(image_path)
    if img is None:
        raise FileNotFoundError(f"cannot read image: {image_path}")
    if key_point_result is None:
        key_point_result = []
    point_size = 1  # circle radius in pixels
    point_color = (0, 0, 255)  # BGR, i.e. red
    thickness = 4  # outline thickness in pixels (negative fills the circle)
    for point in key_point_result:
        # cv2.circle requires an integer (x, y) center.
        center = tuple(int(v) for v in point[::-1])
        cv2.circle(img, center, point_size, point_color, thickness)
    cv2.imshow("0", img)
    cv2.waitKey(0)
|
||
|
||
|
||
if __name__ == '__main__':
    # Manual smoke test: run keypoint inference on a sample image and show the points.
    sample_path = "9070101c-e5be-49b5-9602-4113a968969b.png"
    image = cv2.imread(sample_path)
    a = get_keypoint_result(image, "up")
    if a is None:
        # get_keypoint_result logs the failure and returns None instead of raising.
        raise SystemExit("keypoint inference failed; see the log for details")
    new_list = [(int(p[0]), int(p[1])) for p in a[0]]
    key_point_show(sample_path, new_list)
    # a = get_seg_result(1, image)
    print(a)
|