Files
AiDA_Python/app/service/design_fast/utils/design_ensemble.py
zhouchengrong 6a12dcba57 feat(新功能):
fix(修复bug): design 分割预处理新增25padding,后处理取消插值处理
docs(文档变更):
refactor(重构):
test(增加测试):
2025-01-13 19:44:51 +08:00

146 lines
4.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project trinity_client
@File design_ensemble.py
@Author :周成融
@Date 2023/8/16 19:36:21
@detail :Send inference requests to the Triton server and fetch the results
"""
import logging
import cv2
import mmcv
import numpy as np
import torch
import tritonclient.http as httpclient
from app.core.config import *
"""
keypoint
Preprocessing, inference, postprocessing
"""
def keypoint_preprocess(img_path):
    """Load an image and build the input batch for the keypoint network.

    The image is resized to a fixed 256x256 (aspect ratio is not preserved),
    normalised with the mean/std constants below while being converted from
    BGR to RGB, and laid out channel-first with a leading batch axis.

    Args:
        img_path: image path or already-decoded image array (anything
            ``mmcv.imread`` accepts).

    Returns:
        tuple ``(batch, (w_scale, h_scale))`` where ``batch`` has shape
        ``(1, 3, 256, 256)`` and the scales are ``256 / original_width`` and
        ``256 / original_height``.
    """
    target_w, target_h = 256, 256
    image = mmcv.imread(img_path)
    src_h, src_w = image.shape[:2]
    scales = (target_w / src_w, target_h / src_h)

    resized = cv2.resize(image, (target_w, target_h))
    normalized = mmcv.imnormalize(
        resized,
        mean=np.array([123.675, 116.28, 103.53]),
        std=np.array([58.395, 57.12, 57.375]),
        to_rgb=True,
    )
    # HWC -> CHW, then prepend the batch dimension.
    batch = normalized.transpose(2, 0, 1)[np.newaxis, ...]
    return batch, scales
# @ RunTime
# Inference
def get_keypoint_result(image, site):
    """Run keypoint detection for one image on the Triton server.

    Best effort: any failure (preprocessing, network, inference,
    postprocessing) is logged with its traceback and ``None`` is returned.

    Args:
        image: image path or decoded image array (anything
            ``keypoint_preprocess`` / ``mmcv.imread`` accepts).
        site: suffix selecting the model ``keypoint_{site}_ocrnet_hr18``.

    Returns:
        The array produced by ``keypoint_postprocess``, or ``None`` on failure.
    """
    keypoint_result = None
    client = None
    try:
        batch, scale_factor = keypoint_preprocess(image)
        client = httpclient.InferenceServerClient(url=DESIGN_MODEL_URL)
        transformed_img = batch.astype(np.float32)
        inputs = [httpclient.InferInput("input", transformed_img.shape, datatype="FP32")]
        inputs[0].set_data_from_numpy(transformed_img, binary_data=True)
        outputs = [httpclient.InferRequestedOutput("output", binary_data=True)]
        results = client.infer(model_name=f"keypoint_{site}_ocrnet_hr18", inputs=inputs, outputs=outputs)
        inference_output = torch.from_numpy(results.as_numpy("output"))
        keypoint_result = keypoint_postprocess(inference_output, scale_factor)
    except Exception as e:
        # Keep the original message and level, but also record the traceback.
        logging.warning(f"get_keypoint_result : {e}", exc_info=True)
    finally:
        # A new client (with its own connection pool) is created per call;
        # release it so repeated calls do not leak connections.
        if client is not None:
            client.close()
    return keypoint_result
def keypoint_postprocess(output, scale_factor, stride=4):
    """Decode keypoint heatmaps into (row, col) coordinates in the original image.

    For every heatmap the arg-max location is taken. It is then rescaled by the
    inverse of the preprocessing scale factors and multiplied by ``stride``,
    and the result is rounded up with ``ceil``.

    Args:
        output: 4-D tensor; assumed to be (N, K, H, W) heatmaps (the code reads
            sizes 0, 1 and 3 and flattens everything from axis 2 on).
        scale_factor: ``(w_scale, h_scale)`` as returned by
            ``keypoint_preprocess`` (network input size / original size).
        stride: factor between heatmap and network-input resolution. The
            default 4 matches the previously hard-coded value (presumably
            256-px input vs 64-px heatmap; confirm against the model).

    Returns:
        float64 array of shape (N, K, 2) with (row, col) per keypoint.
    """
    n, k, width = output.size(0), output.size(1), output.size(3)
    flat_idx = torch.argmax(output.reshape(n, k, -1), dim=2)
    # Floor division: ``/`` on integer tensors is true division in modern
    # PyTorch and would leak ``col / width`` into the row coordinate.
    rows = torch.div(flat_idx, width, rounding_mode="floor")
    cols = flat_idx % width
    coords = torch.stack((rows, cols), dim=2).numpy()

    # (w_scale, h_scale) reversed -> per-axis factors for (row, col).
    inv_scale = [1 / x for x in scale_factor[::-1]]
    scale_matrix = np.diag(inv_scale)
    scale_matrix[np.isinf(scale_matrix)] = 0
    return np.ceil(np.dot(coords, scale_matrix) * stride)
"""
seg
Preprocessing, inference, postprocessing
"""
# KNet
def seg_preprocess(img_path):
    """Load an image and build the input batch for the segmentation model.

    Each side longer than 1024 px is clamped to 1024 independently, so the
    aspect ratio is not preserved. A 25 px white border is then added on every
    side. No mean/std normalisation is applied here.

    Args:
        img_path: image path or decoded image array (anything
            ``mmcv.imread`` accepts).

    Returns:
        tuple ``(batch, ori_shape)`` where ``batch`` is the padded image in
        (1, C, H, W) layout and ``ori_shape`` is the original ``(height, width)``.
    """
    max_side = 1024
    border = 25

    image = mmcv.imread(img_path)
    ori_shape = image.shape[:2]
    src_h, src_w = ori_shape
    dst_h = min(src_h, max_side)
    dst_w = min(src_w, max_side)
    if (dst_h, dst_w) != ori_shape:
        # cv2.resize expects dsize as (width, height).
        image = cv2.resize(image, (dst_w, dst_h))

    padded = cv2.copyMakeBorder(
        image, border, border, border, border,
        cv2.BORDER_CONSTANT, value=[255, 255, 255],
    )
    # HWC -> CHW with a leading batch axis.
    batch = np.expand_dims(padded.transpose(2, 0, 1), axis=0)
    return batch, ori_shape
# @ RunTime
def get_seg_result(image_id, image):
    """Run segmentation for one image on the Triton server.

    Args:
        image_id: identifier converted with ``int()`` and forwarded to
            ``seg_postprocess``.
        image: image path or decoded image array (anything
            ``seg_preprocess`` accepts).

    Returns:
        The label map produced by ``seg_postprocess``.

    Raises:
        Any error from preprocessing, the Triton client or postprocessing is
        propagated to the caller.
    """
    batch, ori_shape = seg_preprocess(image)
    transformed_img = batch.astype(np.float32)
    client = httpclient.InferenceServerClient(url=f"{DESIGN_MODEL_URL}")
    try:
        inputs = [
            httpclient.InferInput(SEGMENTATION['input'], transformed_img.shape, datatype="FP32")
        ]
        inputs[0].set_data_from_numpy(transformed_img, binary_data=True)
        outputs = [
            httpclient.InferRequestedOutput(SEGMENTATION['output'], binary_data=True),
        ]
        results = client.infer(model_name=SEGMENTATION['new_model_name'], inputs=inputs, outputs=outputs)
        inference_output = results.as_numpy(SEGMENTATION['output'])
    finally:
        # The client is created per call; release its connections even on error.
        client.close()
    return seg_postprocess(int(image_id), inference_output, ori_shape)
# no cache
def seg_postprocess(image_id, output, ori_shape, pad=25, max_side=1024):
    """Turn the raw segmentation output into a label map at the original size.

    ``seg_preprocess`` clamps each side to ``max_side`` and then adds a
    ``pad``-pixel border. The border is therefore removed at the model-output
    resolution (scaled accordingly) *before* resizing back to ``ori_shape``.
    Cropping a fixed ``pad`` after upscaling is only correct when no clamping
    happened. Nearest-neighbour sampling is used so that class ids are never
    blended into non-existent intermediate values.

    Args:
        image_id: unused; kept for interface compatibility.
        output: inference output; ``output[0][0]`` is taken as a 2-D map,
            presumably of integer class ids.
        ori_shape: ``(height, width)`` of the original image.
        pad: border width added by ``seg_preprocess`` on each side.
        max_side: per-side size limit applied by ``seg_preprocess``.

    Returns:
        uint8 array of shape ``(height, width)``.

    Raises:
        ValueError: if removing the padding leaves an empty map.
    """
    label_map = np.asarray(output[0][0]).astype(np.uint8)
    ori_h, ori_w = int(ori_shape[0]), int(ori_shape[1])
    out_h, out_w = label_map.shape[:2]

    # Size of the padded tensor that was actually sent to the model.
    in_h = min(ori_h, max_side) + 2 * pad
    in_w = min(ori_w, max_side) + 2 * pad
    # Padding width expressed at the model-output resolution.
    pad_y = int(round(pad * out_h / in_h))
    pad_x = int(round(pad * out_w / in_w))
    cropped = label_map[pad_y: out_h - pad_y, pad_x: out_w - pad_x]
    crop_h, crop_w = cropped.shape[:2]
    if crop_h == 0 or crop_w == 0:
        raise ValueError(f"segmentation output {label_map.shape} is too small for padding {pad}")

    # Nearest-neighbour resize: source index = floor(dst_index * src / dst).
    rows = (np.arange(ori_h) * crop_h) // ori_h
    cols = (np.arange(ori_w) * crop_w) // ori_w
    return cropped[rows[:, None], cols]
def key_point_show(image_path, key_point_result=None):
    """Debug helper: draw keypoints on an image and show it in an OpenCV window.

    Blocks until a key is pressed, then closes the window.

    Args:
        image_path: path of the image to draw on.
        key_point_result: iterable of ``(row, col)`` points (ints, floats or
            numpy scalars); ``None`` draws nothing.

    Raises:
        FileNotFoundError: if the image cannot be read.
    """
    img = cv2.imread(image_path)
    if img is None:
        raise FileNotFoundError(f"cannot read image: {image_path}")
    if key_point_result is None:
        key_point_result = []
    point_size = 1
    point_color = (0, 0, 255)  # BGR
    thickness = 4  # outline thickness in pixels
    for point in key_point_result:
        # Points are (row, col); OpenCV wants an integer (x, y) centre.
        center = (int(point[1]), int(point[0]))
        cv2.circle(img, center, point_size, point_color, thickness)
    cv2.imshow("0", img)
    cv2.waitKey(0)
    cv2.destroyWindow("0")
if __name__ == '__main__':
    # Manual smoke test: keypoint inference on a local sample image, then display.
    sample_path = "9070101c-e5be-49b5-9602-4113a968969b.png"
    image = cv2.imread(sample_path)
    keypoints = get_keypoint_result(image, "up")
    if keypoints is None:
        # get_keypoint_result already logged the failure.
        print("keypoint inference failed")
    else:
        new_list = [(int(row), int(col)) for row, col in keypoints[0]]
        print(new_list)
        key_point_show(sample_path, new_list)
    # For segmentation instead: keypoints = get_seg_result(1, image)
    print(keypoints)