2024-09-12 10:05:38 +08:00
|
|
|
|
#!/usr/bin/env python
|
|
|
|
|
|
# -*- coding: UTF-8 -*-
|
|
|
|
|
|
"""
|
|
|
|
|
|
@Project :trinity_client
|
|
|
|
|
|
@File :design_ensemble.py
|
|
|
|
|
|
@Author :周成融
|
|
|
|
|
|
@Date :2023/8/16 19:36:21
|
|
|
|
|
|
@detail :发起请求 获取推理结果 (send requests to the Triton inference server and obtain the inference results)
|
|
|
|
|
|
"""
|
|
|
|
|
|
import logging
|
|
|
|
|
|
|
|
|
|
|
|
import cv2
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
|
import torch
|
|
|
|
|
|
import torch.nn.functional as F
|
|
|
|
|
|
import tritonclient.http as httpclient
|
|
|
|
|
|
|
2025-12-30 16:49:08 +08:00
|
|
|
|
from app.core.config import DESIGN_MODEL_URL, DESIGN_MODEL_NAME
|
2026-02-10 11:17:31 +08:00
|
|
|
|
from app.service.utils.image_normalize import my_imnormalize
|
2024-09-12 10:05:38 +08:00
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
keypoint
|
|
|
|
|
|
预处理 推理 后处理 (preprocessing, inference, post-processing)
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def keypoint_preprocess(img_path):
    """Resize and normalize an image for the keypoint model.

    Despite its name, ``img_path`` is the decoded image array itself: it is
    indexed via ``.shape`` and handed straight to ``cv2.resize``.

    Returns:
        A tuple ``(batch, (w_scale, h_scale))``:

        - ``batch`` is the normalized image moved to channel-first order with
          a leading batch axis of size 1.
        - ``w_scale`` / ``h_scale`` are ``256 / original_width`` and
          ``256 / original_height``, i.e. the factors mapping original
          coordinates onto the 256x256 model input.
    """
    target_w, target_h = 256, 256
    orig_h, orig_w = img_path.shape[:2]

    # cv2.resize takes dsize as (width, height).
    resized = cv2.resize(img_path, (target_w, target_h))
    scales = (target_w / orig_w, target_h / orig_h)

    normalized = my_imnormalize(
        resized,
        mean=np.array([123.675, 116.28, 103.53]),
        std=np.array([58.395, 57.12, 57.375]),
        to_rgb=True,
    )
    # HWC -> CHW, then prepend a batch axis of size 1.
    batched = normalized.transpose(2, 0, 1)[np.newaxis, ...]
    return batched, scales
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# @ RunTime
# Inference
def get_keypoint_result(image, site):
    """Run the keypoint model ``keypoint_{site}_ocrnet_hr18`` on ``image`` via Triton.

    This is best-effort: any failure (preprocessing, connection, inference,
    decoding) is logged with its traceback, and ``None`` is returned instead of
    raising.

    Args:
        image: decoded image array, passed unchanged to ``keypoint_preprocess``.
        site: substituted into the model name ``keypoint_{site}_ocrnet_hr18``.

    Returns:
        The array produced by ``keypoint_postprocess``, or ``None`` on failure.
    """
    keypoint_result = None
    try:
        preprocessed_img, scale_factor = keypoint_preprocess(image)
        transformed_img = preprocessed_img.astype(np.float32)

        inputs = [httpclient.InferInput("input", transformed_img.shape, datatype="FP32")]
        inputs[0].set_data_from_numpy(transformed_img, binary_data=True)
        outputs = [httpclient.InferRequestedOutput("output", binary_data=True)]

        client = httpclient.InferenceServerClient(url=DESIGN_MODEL_URL)
        try:
            results = client.infer(
                model_name=f"keypoint_{site}_ocrnet_hr18",
                inputs=inputs,
                outputs=outputs,
            )
            raw_output = results.as_numpy("output")
        finally:
            # Release the client's connections deterministically instead of
            # leaving one open client per call; an error here is handled by
            # the outer except like any other failure.
            client.close()

        keypoint_result = keypoint_postprocess(torch.from_numpy(raw_output), scale_factor)
    except Exception as e:
        # Keep the best-effort contract, but record the traceback so failures
        # are diagnosable rather than reduced to a one-line message.
        logging.warning(f"get_keypoint_result : {e}", exc_info=True)
    return keypoint_result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def keypoint_postprocess(output, scale_factor, stride=4):
    """Decode keypoint heatmaps into (y, x) coordinates in original-image pixels.

    Args:
        output: heatmap tensor read as (batch, keypoints, height, width);
            ``size(0)``, ``size(1)`` and ``size(3)`` are used in that sense.
        scale_factor: ``(w_scale, h_scale)`` as returned by
            ``keypoint_preprocess`` (model-input size / original size).
        stride: ratio between the model-input resolution and the heatmap
            resolution. Defaults to 4, the value previously hard-coded.

    Returns:
        Float ndarray of shape (batch, keypoints, 2) with ``(y, x)`` per
        keypoint (the argmax location), scaled back to the original image and
        rounded up with ``np.ceil``.
    """
    batch, num_kp, heat_w = output.size(0), output.size(1), output.size(3)
    flat_idx = torch.argmax(output.reshape(batch, num_kp, -1), dim=2).unsqueeze(dim=2)

    # Recover integer (row, col) from the flattened index. Floor division is
    # required: ``/`` on an integer tensor is true division in modern PyTorch
    # and would leak ``col / width`` into the row coordinate.
    rows = torch.div(flat_idx, heat_w, rounding_mode="floor")
    cols = flat_idx % heat_w
    coords = torch.cat((rows, cols), dim=2).numpy()

    # scale_factor is (w, h); reverse it to (h, w) so it lines up with (y, x).
    inv_scale = [1 / s for s in scale_factor[::-1]]
    scale_matrix = np.diag(inv_scale)
    # Defensive: numpy-scalar factors of 0 yield inf (plain floats would raise).
    inf_mask = np.isinf(scale_matrix)
    scale_matrix[inf_mask] = 0

    return np.ceil(np.dot(coords, scale_matrix) * stride)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
seg
|
|
|
|
|
|
预处理 推理 后处理 (preprocessing, inference, post-processing)
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# KNet
def seg_preprocess(img_path):
    """Prepare an image for the segmentation model.

    ``img_path`` is the decoded image array itself, not a file path. Height and
    width are each clamped independently to at most 1024 pixels, so the aspect
    ratio is not preserved. Images already within the limit are not resized.

    Returns:
        A tuple ``(batch, ori_shape)``:

        - ``batch`` is the normalized image in channel-first order with a
          leading batch axis of size 1.
        - ``ori_shape`` is the original ``(height, width)``.
    """
    max_side = 1024
    ori_shape = img_path.shape[:2]
    orig_h, orig_w = ori_shape

    new_h = min(orig_h, max_side)
    new_w = min(orig_w, max_side)

    img = img_path
    if ori_shape != (new_h, new_w):
        # cv2.resize takes dsize as (width, height) -- mixing these up was a
        # past bug in this function.
        img = cv2.resize(img, (new_w, new_h))

    img = my_imnormalize(
        img,
        mean=np.array([123.675, 116.28, 103.53]),
        std=np.array([58.395, 57.12, 57.375]),
        to_rgb=True,
    )
    # HWC -> CHW, then prepend a batch axis of size 1.
    batched = np.expand_dims(img.transpose(2, 0, 1), axis=0)
    return batched, ori_shape
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# @ RunTime
def get_seg_result(image):
    """Run the segmentation model ``DESIGN_MODEL_NAME`` on ``image`` via Triton.

    Unlike ``get_keypoint_result``, errors are not caught and propagate to the
    caller.

    Args:
        image: decoded image array, passed unchanged to ``seg_preprocess``.

    Returns:
        The array produced by ``seg_postprocess``: logits resized back to the
        image's original (height, width).
    """
    preprocessed_img, ori_shape = seg_preprocess(image)
    transformed_img = preprocessed_img.astype(np.float32)

    # NOTE(review): the input tensor is named after DESIGN_MODEL_NAME and the
    # requested output is "seg_input__0"; confirm both against the deployed
    # model's config -- they look unusual but are kept as-is.
    inputs = [httpclient.InferInput(DESIGN_MODEL_NAME, transformed_img.shape, datatype="FP32")]
    inputs[0].set_data_from_numpy(transformed_img, binary_data=True)
    outputs = [httpclient.InferRequestedOutput("seg_input__0", binary_data=True)]

    client = httpclient.InferenceServerClient(url=str(DESIGN_MODEL_URL))
    try:
        results = client.infer(model_name=DESIGN_MODEL_NAME, inputs=inputs, outputs=outputs)
        inference_output = results.as_numpy("seg_input__0")
    finally:
        # Close the per-call client so its connections are released even
        # when inference raises.
        client.close()

    return seg_postprocess(inference_output, ori_shape)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# no cache
def seg_postprocess(output, ori_shape):
    """Bilinearly resize raw segmentation logits to ``ori_shape``.

    Args:
        output: array-like logits. ``F.interpolate`` in bilinear mode requires
            a 4-D (N, C, H, W) input.
        ori_shape: target ``(height, width)``.

    Returns:
        The first batch element of the resized logits as a float32 ndarray.
    """
    logits = torch.tensor(output).float()
    resized = F.interpolate(
        logits,
        size=ori_shape,
        scale_factor=None,
        mode='bilinear',
        align_corners=False,
    )
    as_array = resized.cpu().numpy()
    return as_array[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def key_point_show(image_path, key_point_result=None):
    """Debug helper: draw keypoints on an image and show it in an OpenCV window.

    Blocks until a key is pressed, then closes the window.

    Args:
        image_path: path of the image to load with ``cv2.imread``.
        key_point_result: iterable of ``(y, x)`` points (row first, as decoded
            by ``keypoint_postprocess``). ``None`` (the default) shows the
            image without any points.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``image_path``.
    """
    img = cv2.imread(image_path)
    if img is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise FileNotFoundError(f"could not read image: {image_path}")

    if key_point_result is None:
        key_point_result = ()

    point_size = 1
    point_color = (0, 0, 255)  # BGR, i.e. red
    thickness = 4

    for point in key_point_result:
        y, x = point[0], point[1]
        # cv2.circle wants an integer (x, y) center, so swap the (y, x) pair.
        cv2.circle(img, (int(x), int(y)), point_size, point_color, thickness)

    cv2.imshow("0", img)
    cv2.waitKey(0)
    cv2.destroyWindow("0")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Manual smoke test against a running Triton server: run the "up" keypoint
    # model on a local sample image and display the decoded points.
    sample_path = "9070101c-e5be-49b5-9602-4113a968969b.png"
    image = cv2.imread(sample_path)

    a = get_keypoint_result(image, "up")
    print(a)
    if a is None:
        # get_keypoint_result logs the failure and returns None instead of raising.
        raise SystemExit("keypoint inference failed; see the log for details")

    # Points of the first (only) batch element, as integer (y, x) pairs.
    new_list = [(int(p[0]), int(p[1])) for p in a[0]]
    key_point_show(sample_path, new_list)

    # Segmentation alternative: seg = get_seg_result(image)
|