feat generate 污点图保存
This commit is contained in:
@@ -1,4 +1,6 @@
|
|||||||
import logging
|
import logging
|
||||||
|
import time
|
||||||
|
|
||||||
import mmcv
|
import mmcv
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@@ -15,7 +17,7 @@ logger = logging.getLogger()
|
|||||||
def seg_preprocess(img_path):
|
def seg_preprocess(img_path):
|
||||||
img = mmcv.imread(img_path)
|
img = mmcv.imread(img_path)
|
||||||
ori_shape = img.shape[:2]
|
ori_shape = img.shape[:2]
|
||||||
img_scale = (224, 224)
|
img_scale = ori_shape
|
||||||
scale_factor = []
|
scale_factor = []
|
||||||
img, x, y = mmcv.imresize(img, img_scale, return_scale=True)
|
img, x, y = mmcv.imresize(img, img_scale, return_scale=True)
|
||||||
scale_factor.append(x)
|
scale_factor.append(x)
|
||||||
@@ -61,6 +63,26 @@ def get_contours(image):
|
|||||||
return Contour
|
return Contour
|
||||||
|
|
||||||
|
|
||||||
|
# def seg_infer_image(image_obj):
|
||||||
|
# image, ori_shape = seg_preprocess(image_obj)
|
||||||
|
# client = httpclient.InferenceServerClient(url=f"{SEG_MODEL_URL}")
|
||||||
|
# transformed_img = image.astype(np.float32)
|
||||||
|
# # 输入集
|
||||||
|
# inputs = [
|
||||||
|
# httpclient.InferInput(SEGMENTATION['input'], transformed_img.shape, datatype="FP32")
|
||||||
|
# ]
|
||||||
|
# inputs[0].set_data_from_numpy(transformed_img, binary_data=True)
|
||||||
|
# # 输出集
|
||||||
|
# outputs = [
|
||||||
|
# httpclient.InferRequestedOutput(SEGMENTATION['output'], binary_data=True),
|
||||||
|
# ]
|
||||||
|
# results = client.infer(model_name=SEGMENTATION['name'], inputs=inputs, outputs=outputs)
|
||||||
|
# # 推理
|
||||||
|
# # 取结果
|
||||||
|
# inference_output1 = torch.from_numpy(results.as_numpy(SEGMENTATION['output']))
|
||||||
|
# seg_result = seg_postprocess(inference_output1, ori_shape)
|
||||||
|
# return seg_result
|
||||||
|
|
||||||
def seg_infer_image(image_obj):
|
def seg_infer_image(image_obj):
|
||||||
image, ori_shape = seg_preprocess(image_obj)
|
image, ori_shape = seg_preprocess(image_obj)
|
||||||
client = httpclient.InferenceServerClient(url=f"{SEG_MODEL_URL}")
|
client = httpclient.InferenceServerClient(url=f"{SEG_MODEL_URL}")
|
||||||
@@ -74,21 +96,29 @@ def seg_infer_image(image_obj):
|
|||||||
outputs = [
|
outputs = [
|
||||||
httpclient.InferRequestedOutput(SEGMENTATION['output'], binary_data=True),
|
httpclient.InferRequestedOutput(SEGMENTATION['output'], binary_data=True),
|
||||||
]
|
]
|
||||||
results = client.infer(model_name=SEGMENTATION['name'], inputs=inputs, outputs=outputs)
|
start_time = time.time()
|
||||||
|
results = client.infer(model_name=SEGMENTATION['new_model_name'], inputs=inputs, outputs=outputs)
|
||||||
|
print(f"KNet infer time is :{time.time() - start_time}")
|
||||||
# 推理
|
# 推理
|
||||||
# 取结果
|
# 取结果
|
||||||
inference_output1 = torch.from_numpy(results.as_numpy(SEGMENTATION['output']))
|
inference_output1 = results.as_numpy(SEGMENTATION['output'])
|
||||||
seg_result = seg_postprocess(inference_output1, ori_shape)
|
seg_result = seg_postprocess(inference_output1, ori_shape)
|
||||||
return seg_result
|
return seg_result
|
||||||
|
|
||||||
|
# def seg_postprocess(output, ori_shape):
|
||||||
|
# seg_logit = F.interpolate(output, size=ori_shape, scale_factor=None, mode='bilinear', align_corners=False)
|
||||||
|
# seg_logit = F.softmax(seg_logit, dim=1)
|
||||||
|
# seg_pred = seg_logit.argmax(dim=1)
|
||||||
|
# seg_pred = seg_pred.cpu().numpy()
|
||||||
|
# return seg_pred
|
||||||
|
|
||||||
|
# KNet
|
||||||
def seg_postprocess(output, ori_shape):
|
def seg_postprocess(output, ori_shape):
|
||||||
seg_logit = F.interpolate(output, size=ori_shape, scale_factor=None, mode='bilinear', align_corners=False)
|
# seg_logit = F.interpolate(torch.tensor(output).float(), size=ori_shape, scale_factor=None, mode='bilinear', align_corners=False)
|
||||||
seg_logit = F.softmax(seg_logit, dim=1)
|
# seg_logit = F.softmax(seg_logit, dim=1)
|
||||||
seg_pred = seg_logit.argmax(dim=1)
|
# seg_pred = seg_logit.argmax(dim=1)
|
||||||
seg_pred = seg_pred.cpu().numpy()
|
# seg_pred = output.cpu().numpy()
|
||||||
return seg_pred
|
return output[0]
|
||||||
|
|
||||||
|
|
||||||
def remove_background(image):
|
def remove_background(image):
|
||||||
image_obj, mask = get_mask(image)
|
image_obj, mask = get_mask(image)
|
||||||
|
|||||||
Reference in New Issue
Block a user