feat generate 污点图保存

This commit is contained in:
zhouchengrong
2024-05-17 18:22:25 +08:00
parent 18eb4735ea
commit 91af50f863

View File

@@ -1,4 +1,6 @@
import logging
import time
import mmcv
import numpy as np
import torch
@@ -15,7 +17,7 @@ logger = logging.getLogger()
def seg_preprocess(img_path):
img = mmcv.imread(img_path)
ori_shape = img.shape[:2]
img_scale = (224, 224)
img_scale = ori_shape
scale_factor = []
img, x, y = mmcv.imresize(img, img_scale, return_scale=True)
scale_factor.append(x)
@@ -61,6 +63,26 @@ def get_contours(image):
return Contour
# def seg_infer_image(image_obj):
# image, ori_shape = seg_preprocess(image_obj)
# client = httpclient.InferenceServerClient(url=f"{SEG_MODEL_URL}")
# transformed_img = image.astype(np.float32)
# # 输入集
# inputs = [
# httpclient.InferInput(SEGMENTATION['input'], transformed_img.shape, datatype="FP32")
# ]
# inputs[0].set_data_from_numpy(transformed_img, binary_data=True)
# # 输出集
# outputs = [
# httpclient.InferRequestedOutput(SEGMENTATION['output'], binary_data=True),
# ]
# results = client.infer(model_name=SEGMENTATION['name'], inputs=inputs, outputs=outputs)
# # 推理
# # 取结果
# inference_output1 = torch.from_numpy(results.as_numpy(SEGMENTATION['output']))
# seg_result = seg_postprocess(inference_output1, ori_shape)
# return seg_result
def seg_infer_image(image_obj):
image, ori_shape = seg_preprocess(image_obj)
client = httpclient.InferenceServerClient(url=f"{SEG_MODEL_URL}")
@@ -74,21 +96,29 @@ def seg_infer_image(image_obj):
outputs = [
httpclient.InferRequestedOutput(SEGMENTATION['output'], binary_data=True),
]
results = client.infer(model_name=SEGMENTATION['name'], inputs=inputs, outputs=outputs)
start_time = time.time()
results = client.infer(model_name=SEGMENTATION['new_model_name'], inputs=inputs, outputs=outputs)
print(f"KNet infer time is :{time.time() - start_time}")
# 推理
# 取结果
inference_output1 = torch.from_numpy(results.as_numpy(SEGMENTATION['output']))
inference_output1 = results.as_numpy(SEGMENTATION['output'])
seg_result = seg_postprocess(inference_output1, ori_shape)
return seg_result
# def seg_postprocess(output, ori_shape):
# seg_logit = F.interpolate(output, size=ori_shape, scale_factor=None, mode='bilinear', align_corners=False)
# seg_logit = F.softmax(seg_logit, dim=1)
# seg_pred = seg_logit.argmax(dim=1)
# seg_pred = seg_pred.cpu().numpy()
# return seg_pred
# KNet
def seg_postprocess(output, ori_shape):
seg_logit = F.interpolate(output, size=ori_shape, scale_factor=None, mode='bilinear', align_corners=False)
seg_logit = F.softmax(seg_logit, dim=1)
seg_pred = seg_logit.argmax(dim=1)
seg_pred = seg_pred.cpu().numpy()
return seg_pred
# seg_logit = F.interpolate(torch.tensor(output).float(), size=ori_shape, scale_factor=None, mode='bilinear', align_corners=False)
# seg_logit = F.softmax(seg_logit, dim=1)
# seg_pred = seg_logit.argmax(dim=1)
# seg_pred = output.cpu().numpy()
return output[0]
def remove_background(image):
image_obj, mask = get_mask(image)