#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project :trinity_client
@File    :design_ensemble.py
@Author  :周成融
@Date    :2023/8/16 19:36:21
@detail  :Send requests to the inference server and return the results.
"""
import logging

import cv2
import mmcv
import numpy as np
import torch
import torch.nn.functional as F
import tritonclient.http as httpclient

from app.core.config import *  # noqa: F401,F403  (provides DESIGN_MODEL_URL, SEGMENTATION)

# Per-channel normalisation constants shared by both preprocessors.
_IMG_MEAN = np.array([123.675, 116.28, 103.53])
_IMG_STD = np.array([58.395, 57.12, 57.375])

# Segmentation inputs are never resized above this many pixels per side.
_SEG_MAX_SIDE = 1024


def _infer_single(model_name, input_name, output_name, batch):
    """Send one FP32 tensor to a Triton model and return one output array.

    Args:
        model_name: Name of the model on the server at ``DESIGN_MODEL_URL``.
        input_name: Name of the model's input tensor.
        output_name: Name of the output tensor to fetch.
        batch: ``np.float32`` array to send as the input.

    Returns:
        Whatever ``InferResult.as_numpy(output_name)`` yields for the request.
    """
    client = httpclient.InferenceServerClient(url=f"{DESIGN_MODEL_URL}")
    try:
        infer_input = httpclient.InferInput(input_name, batch.shape, datatype="FP32")
        infer_input.set_data_from_numpy(batch, binary_data=True)
        requested = httpclient.InferRequestedOutput(output_name, binary_data=True)
        result = client.infer(model_name=model_name,
                              inputs=[infer_input],
                              outputs=[requested])
        return result.as_numpy(output_name)
    finally:
        # A client is created per call, so release its connections
        # deterministically instead of waiting for garbage collection.
        client.close()


# ---------------------------------------------------------------------------
# Keypoints: preprocessing, inference, post-processing
# ---------------------------------------------------------------------------

def keypoint_preprocess(img_path):
    """Load an image and turn it into a 1x3x256x256 normalised batch.

    Args:
        img_path: Anything ``mmcv.imread`` accepts (the script entry point in
            this file passes an already-decoded image array).

    Returns:
        ``(batch, (w_scale, h_scale))`` where the scales are the resize
        factors reported by ``mmcv.imresize``.
    """
    img = mmcv.imread(img_path)
    img, w_scale, h_scale = mmcv.imresize(img, (256, 256), return_scale=True)
    img = mmcv.imnormalize(img, mean=_IMG_MEAN, std=_IMG_STD, to_rgb=True)
    # HWC -> CHW, then add the batch axis.
    batch = np.expand_dims(img.transpose(2, 0, 1), axis=0)
    return batch, (w_scale, h_scale)


# @ RunTime
def get_keypoint_result(image, site):
    """Run the ``keypoint_{site}_ocrnet_hr18`` model on an image.

    This is best-effort: any exception raised while preprocessing, calling
    the server or post-processing is logged as a warning and swallowed.

    Args:
        image: Input forwarded to :func:`keypoint_preprocess`.
        site: Inserted into the model name.

    Returns:
        The array produced by :func:`keypoint_postprocess`, or ``None`` when
        any step failed.
    """
    keypoint_result = None
    try:
        batch, scale_factor = keypoint_preprocess(image)
        heatmaps = _infer_single(f"keypoint_{site}_ocrnet_hr18",
                                 "input", "output",
                                 batch.astype(np.float32))
        keypoint_result = keypoint_postprocess(torch.from_numpy(heatmaps),
                                               scale_factor)
    except Exception as e:
        logging.warning("get_keypoint_result : %s", e)
    return keypoint_result


def keypoint_postprocess(output, scale_factor):
    """Convert heatmaps to (row, col) peak coordinates in the source image.

    Args:
        output: 4-D tensor indexed as (batch, keypoint, row, col).
        scale_factor: ``(w_scale, h_scale)`` from :func:`keypoint_preprocess`.

    Returns:
        Float array of shape (batch, keypoint, 2) holding ``ceil``-rounded
        ``(row, col)`` positions.
    """
    batch, num_points, _, width = output.shape
    flat_idx = torch.argmax(output.view(batch, num_points, -1), dim=2).unsqueeze(dim=2)
    # Unravel the flat argmax. The row must use floor division: plain ``/``
    # is true division and would leak ``col / width`` into the row value.
    rows = torch.div(flat_idx, width, rounding_mode="floor")
    cols = flat_idx % width
    coords = torch.cat((rows, cols), dim=2).numpy()

    # Coordinates are (row, col), so the inverse scales go in (h, w) order.
    inverse_scale = [1 / x for x in scale_factor[::-1]]
    scale_matrix = np.diag(inverse_scale)
    scale_matrix[np.isinf(scale_matrix)] = 0
    # The factor 4 presumably maps heatmap cells back to network-input
    # pixels (heatmap stride) -- TODO confirm against the deployed model.
    return np.ceil(np.dot(coords, scale_matrix) * 4)


# ---------------------------------------------------------------------------
# Segmentation (KNet): preprocessing, inference, post-processing
# ---------------------------------------------------------------------------

def seg_preprocess(img_path):
    """Load an image and build the normalised batch for segmentation.

    Each side longer than 1024 pixels is reduced to 1024; shorter sides are
    left at their original length.

    Args:
        img_path: Anything ``mmcv.imread`` accepts.

    Returns:
        ``(batch, ori_shape)`` where ``ori_shape`` is the original
        ``(height, width)``.
    """
    img = mmcv.imread(img_path)
    ori_shape = img.shape[:2]
    height, width = ori_shape
    # mmcv.imresize takes its target size as (w, h), the reverse of
    # ndarray.shape, so the two sides must not be passed in shape order.
    target_size = (min(width, _SEG_MAX_SIDE), min(height, _SEG_MAX_SIDE))
    img = mmcv.imresize(img, target_size)
    img = mmcv.imnormalize(img, mean=_IMG_MEAN, std=_IMG_STD, to_rgb=True)
    # HWC -> CHW, then add the batch axis.
    batch = np.expand_dims(img.transpose(2, 0, 1), axis=0)
    return batch, ori_shape


# @ RunTime
def get_seg_result(image_id, image):
    """Run the segmentation model named by ``SEGMENTATION['new_model_name']``.

    Unlike :func:`get_keypoint_result`, failures are not caught here and
    propagate to the caller.

    Args:
        image_id: Converted with ``int()`` and handed to
            :func:`seg_postprocess`.
        image: Input forwarded to :func:`seg_preprocess`.

    Returns:
        The array produced by :func:`seg_postprocess`.
    """
    batch, ori_shape = seg_preprocess(image)
    logits = _infer_single(SEGMENTATION['new_model_name'],
                           SEGMENTATION['input'],
                           SEGMENTATION['output'],
                           batch.astype(np.float32))
    return seg_postprocess(int(image_id), logits, ori_shape)


# no cache
def seg_postprocess(image_id, output, ori_shape):
    """Bilinearly resize the model output back to the original image size.

    Args:
        image_id: Not used by this function.
        output: Array-like model output with batch and channel axes leading.
        ori_shape: Target ``(height, width)``.

    Returns:
        The first batch element of the resized output as a numpy array.
    """
    seg_logit = F.interpolate(torch.tensor(output).float(),
                              size=ori_shape,
                              scale_factor=None,
                              mode='bilinear',
                              align_corners=False)
    return seg_logit.cpu().numpy()[0]


def key_point_show(image_path, key_point_result=None):
    """Draw keypoints on an image and show it until a key is pressed.

    Args:
        image_path: Path of the image to load with ``cv2.imread``.
        key_point_result: Iterable of ``(row, col)`` integer tuples, or
            ``None`` to show the image without any points.
    """
    img = cv2.imread(image_path)
    if key_point_result is None:
        key_point_result = ()
    point_size = 1
    point_color = (0, 0, 255)  # BGR
    thickness = 4
    for point in key_point_result:
        # Points are stored (row, col); cv2.circle wants (x, y).
        cv2.circle(img, point[::-1], point_size, point_color, thickness)
    cv2.imshow("0", img)
    cv2.waitKey(0)


if __name__ == '__main__':
    demo_path = "./14162b58-f259-4833-98cb-89b9b496b251.jfif"
    image = cv2.imread(demo_path)
    a = get_keypoint_result(image, "up")
    if a is None:
        # get_keypoint_result has already logged the cause as a warning.
        raise SystemExit("keypoint inference failed")
    new_list = [(int(i[0]), int(i[1])) for i in a[0]]
    key_point_show(demo_path, new_list)
    # To try segmentation instead: a = get_seg_result(1, image)
    print(a)