feat
fix design pipeline 剔除边缘检测任务,直接用分割
This commit is contained in:
@@ -9,18 +9,22 @@ from ...utils.design_ensemble import get_seg_result
|
||||
|
||||
@PIPELINES.register_module()
|
||||
class Segmentation(object):
|
||||
def __init__(self, device='cpu', show=False, debug=None):
|
||||
self.show = show
|
||||
self.device = device
|
||||
self.debug = debug
|
||||
|
||||
# @ClassCallRunTime
|
||||
def __call__(self, result):
|
||||
# 本地查询seg 缓存是否存在
|
||||
_, seg_result = self.load_seg_result(result["image_id"])
|
||||
result['seg_result'] = seg_result
|
||||
if not _:
|
||||
result['seg_result'] = get_seg_result(result["image_id"], result['image'])
|
||||
self.save_seg_result(result['seg_result'][0], result['image_id'])
|
||||
# 推理获得seg 结果
|
||||
seg_result = get_seg_result(result["image_id"], result['image'])[0]
|
||||
self.save_seg_result(seg_result, result['image_id'])
|
||||
# 处理前片后片
|
||||
temp_front = seg_result == 1.0
|
||||
result['front_mask'] = (255 * (temp_front + 0).astype(np.uint8))
|
||||
temp_back = seg_result == 2.0
|
||||
result['back_mask'] = (255 * (temp_back + 0).astype(np.uint8))
|
||||
result['mask'] = result['front_mask'] + result['back_mask']
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
|
||||
Reference in New Issue
Block a user