Files
aida_seg_anything/scripts/export_encoder.py
2026-01-08 14:35:23 +08:00

59 lines
2.0 KiB
Python
Executable File
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
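# export_encoder.py
"""Export the Segment Anything (SAM) image encoder to ONNX.

The defaults assume the vit_h backbone and a fixed 1024x1024 input.
Select the backbone with --model-type, the weights with --checkpoint
(required), and the output file with --out.
"""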
import argparse
import torch
import os
from segment_anything import sam_model_registry
def main(argv=None) -> None:
    """Export SAM's image encoder (``sam.image_encoder``) to an ONNX file.

    The graph is traced with a fixed input shape (``dynamic_axes=None``),
    which keeps it straightforward to build TensorRT engines or serve it
    from Triton.

    Args:
        argv: Optional list of command-line arguments. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, i.e. normal CLI use.

    Exits via ``argparse`` (status 2) on invalid arguments, including an
    ``--img-size`` that does not match the encoder's configured input size.
    """
    parser = argparse.ArgumentParser(
        description="Export the SAM image encoder to ONNX with a fixed input shape."
    )
    parser.add_argument(
        "--model-type",
        type=str,
        default="vit_h",
        # Validate against the registry up front instead of failing later
        # with a bare KeyError from the lookup below.
        choices=sorted(sam_model_registry),
        help="SAM backbone to build (default: %(default)s).",
    )
    parser.add_argument(
        "--checkpoint",
        type=str,
        required=True,
        help="Path to the SAM checkpoint matching --model-type.",
    )
    parser.add_argument(
        "--out",
        type=str,
        default="encoder.onnx",
        help="Output ONNX path (default: %(default)s).",
    )
    parser.add_argument(
        "--opset",
        type=int,
        default=17,
        help="ONNX opset version (default: %(default)s).",
    )
    parser.add_argument(
        "--img-size",
        type=int,
        default=1024,
        help="Square input side length; must equal the encoder's built-in "
        "input size (default: %(default)s).",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=1,
        help="Fixed batch dimension baked into the graph (default: %(default)s).",
    )
    args = parser.parse_args(argv)

    if args.img_size < 1:
        parser.error("--img-size must be a positive integer")
    if args.batch_size < 1:
        parser.error("--batch-size must be a positive integer")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"device: {device}")

    print("Loading SAM model (this may take a while)...")
    sam = sam_model_registry[args.model_type](checkpoint=args.checkpoint)
    sam.to(device)
    sam.eval()

    # Only the image encoder is exported; the prompt encoder and mask decoder
    # are not part of this graph.
    image_encoder = sam.image_encoder
    image_encoder.eval()

    # SAM's ViT encoder is built for one input resolution (its absolute
    # positional embedding has a fixed grid). A different size would fail
    # deep inside tracing, so reject it here with a clear message.
    expected_size = getattr(image_encoder, "img_size", None)
    if expected_size is not None and args.img_size != expected_size:
        parser.error(
            f"--img-size {args.img_size} does not match the encoder's "
            f"configured input size {expected_size}"
        )

    # torch.onnx.export cannot create missing directories; make sure the
    # parent of --out exists.
    out_dir = os.path.dirname(os.path.abspath(args.out))
    os.makedirs(out_dir, exist_ok=True)

    # Dummy tensor used only for tracing: (batch, 3, H, W) float32. Its values
    # are irrelevant to the exported graph. At inference time the encoder
    # expects what segment_anything's Sam.preprocess produces: pixels
    # normalized by pixel_mean/pixel_std and zero-padded to img_size x img_size.
    # It does NOT expect [0, 1] images.
    dummy_input = torch.randn(
        args.batch_size,
        3,
        args.img_size,
        args.img_size,
        device=device,
        dtype=torch.float32,
    )

    input_names = ["input_image"]
    output_names = ["image_embeddings"]

    # NOTE: vit_h's encoder weights are larger than ONNX's 2 GB protobuf limit.
    # Depending on the PyTorch version, the exporter may then write weights as
    # external data files next to --out, so consider giving --out its own
    # directory.
    print(f"Exporting encoder to {args.out} ...")
    with torch.no_grad():  # tracing needs no autograd state; saves memory
        torch.onnx.export(
            image_encoder,
            dummy_input,
            args.out,
            export_params=True,
            opset_version=args.opset,
            do_constant_folding=True,
            input_names=input_names,
            output_names=output_names,
            dynamic_axes=None,  # fixed shapes: simplest for TensorRT / Triton
        )
    print("Export done.")


if __name__ == "__main__":
    main()