Files
AiDA_Python/app/service/design/items/pipelines/segmentation.py
2024-07-19 15:17:54 +08:00

47 lines
1.4 KiB
Python

import os
import numpy as np
from app.core.config import SEG_CACHE_PATH
from app.service.utils.decorator import ClassCallRunTime
from ..builder import PIPELINES
from ...utils.design_ensemble import get_seg_result
@PIPELINES.register_module()
class Segmentation(object):
    """Pipeline step that computes a segmentation result and caches it on disk as ``.npy``.

    Cache files are located at ``f"{SEG_CACHE_PATH}{image_id}.npy"`` (``SEG_CACHE_PATH``
    is used as a plain string prefix, so it is expected to end with a path separator).
    """

    def __init__(self, device='cpu', show=False, debug=None):
        """Store step configuration.

        Args:
            device: Stored as ``self.device``; not read by any method in this class.
            show: Stored as ``self.show``; not read by any method in this class.
            debug: Stored as ``self.debug``; not read by any method in this class.
        """
        self.show = show
        self.device = device
        self.debug = debug

    # @ClassCallRunTime
    def __call__(self, result):
        """Run segmentation for ``result["image_id"]`` unless a cached result exists.

        Args:
            result: Pipeline dict. Must contain ``"image_id"``; on a cache miss it must
                also contain ``"image"``, which is passed to ``get_seg_result``.

        Returns:
            The same ``result`` dict. On a cache miss, ``result['seg_result']`` is set to
            the return value of ``get_seg_result``, and its element ``[0]`` is written to
            the cache.
        """
        found, _cached_seg = self.load_seg_result(result["image_id"])
        if not found:
            result['seg_result'] = get_seg_result(result["image_id"], result['image'])
            # Only element [0] of get_seg_result's return value is persisted.
            self.save_seg_result(result['seg_result'][0], result['image_id'])
        # NOTE(review): on a cache hit the loaded array is discarded and
        # result['seg_result'] is NOT set. Confirm downstream steps read the cache
        # themselves. Otherwise the hit path must populate result['seg_result'] in the
        # same shape get_seg_result returns. Only element [0] is cached, so the full
        # value cannot be reconstructed here without knowing that shape.
        return result

    @staticmethod
    def save_seg_result(seg_result, image_id):
        """Write ``seg_result`` to the cache file for ``image_id`` with ``np.save``.

        Best-effort: any exception is caught and printed, never raised.

        Args:
            seg_result: Value passed to ``np.save`` (array-like).
            image_id: Used verbatim to build the cache file name.
        """
        # NOTE(review): image_id is interpolated directly into a filesystem path; if it
        # can come from untrusted input, values containing separators or '..' escape
        # the cache directory.
        file_path = f"{SEG_CACHE_PATH}{image_id}.npy"
        try:
            # np.save does not create missing parent directories; without this the
            # first save into a fresh cache location fails and the cache never fills.
            cache_dir = os.path.dirname(file_path)
            if cache_dir:
                os.makedirs(cache_dir, exist_ok=True)
            np.save(file_path, seg_result)
            print("保存成功", os.path.abspath(file_path))
        except Exception as e:
            print(f"保存失败: {e}")

    @staticmethod
    def load_seg_result(image_id):
        """Load the cached result for ``image_id`` with ``np.load``.

        Args:
            image_id: Used verbatim to build the cache file name.

        Returns:
            ``(True, array)`` when the cache file loads successfully, otherwise
            ``(False, None)``. A missing file and any other load error are both
            printed and reported as a miss rather than raised.
        """
        file_path = f"{SEG_CACHE_PATH}{image_id}.npy"
        try:
            seg_result = np.load(file_path)
            return True, seg_result
        except FileNotFoundError:
            # Missing file is the normal cache-miss path.
            print("文件不存在")
            return False, None
        except Exception as e:
            # Corrupt/unreadable files are treated as a miss so the caller recomputes.
            print(f"加载失败: {e}")
            return False, None