Files
aida_seg_anything/scripts/export_encoder.py

59 lines
2.0 KiB
Python
Raw Normal View History

2026-01-08 14:35:23 +08:00
# export_encoder.py
"""
导出 SAM Image Encoder ONNX
假设你使用的是 vit_h 并且想用固定输入 1024x1024
如果要换 model_type checkpoint使用命令行参数 --model-type/--checkpoint/--out
"""
import argparse
import torch
import os
from segment_anything import sam_model_registry
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--model-type", type=str, default="vit_h")
parser.add_argument("--checkpoint", type=str, required=True)
parser.add_argument("--out", type=str, default="encoder.onnx")
parser.add_argument("--opset", type=int, default=17)
parser.add_argument("--img-size", type=int, default=1024)
parser.add_argument("--batch-size", type=int, default=1)
args = parser.parse_args()
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"device: {device}")
print("Loading SAM model (this may take a while)...")
sam = sam_model_registry[args.model_type](checkpoint=args.checkpoint)
sam.to(device)
sam.eval()
# SAM 的 encoder 是 sam.image_encoder
image_encoder = sam.image_encoder
image_encoder.eval()
# 构建 dummy 输入,注意 SAM 的预处理通常是 [0,1] float32, 3xHxW
bs = args.batch_size
H = W = args.img_size
dummy_input = torch.randn(bs, 3, H, W, device=device, dtype=torch.float32)
# ONNX 导出。注意根据具体实现encoder.forward 可能需要额外参数;
# 此处采用直接调用 image_encoder(dummy_input) 的方式
input_names = ["input_image"]
output_names = ["image_embeddings"]
print(f"Exporting encoder to {args.out} ...")
torch.onnx.export(
image_encoder,
dummy_input,
args.out,
export_params=True,
opset_version=args.opset,
do_constant_folding=True,
input_names=input_names,
output_names=output_names,
dynamic_axes=None # 固定尺寸,更容易用于 TensorRT / Triton
)
print("Export done.")
if __name__ == "__main__":
main()