59 lines
2.0 KiB
Python
59 lines
2.0 KiB
Python
|
|
# export_encoder.py
|
|||
|
|
"""
|
|||
|
|
导出 SAM 的 Image Encoder 为 ONNX。
|
|||
|
|
默认导出 vit_h,输入为固定尺寸 1024x1024(静态形状,不设置 dynamic axes)。
|
|||
|
|
必须通过 --checkpoint 指定权重文件;其余设置可用命令行参数 --model-type / --out / --opset / --img-size / --batch-size 修改。
|
|||
|
|
"""
|
|||
|
|
import argparse
|
|||
|
|
import torch
|
|||
|
|
import os
|
|||
|
|
from segment_anything import sam_model_registry
|
|||
|
|
|
|||
|
|
def main() -> None:
    """Export the image encoder of a SAM checkpoint to an ONNX file.

    Parses command-line options, loads the requested SAM model through
    ``sam_model_registry`` and traces ``sam.image_encoder`` with a random
    float32 input of shape ``(batch_size, 3, img_size, img_size)``. The graph
    is exported with static shapes (no dynamic axes), which is simpler to
    consume from fixed-shape runtimes such as TensorRT / Triton.

    Command-line options:
        --model-type: key of ``sam_model_registry`` (default ``vit_h``).
        --checkpoint: path to the SAM checkpoint file (required, must exist).
        --out: destination ``.onnx`` path (default ``encoder.onnx``); missing
            parent directories are created.
        --opset: ONNX opset version passed to ``torch.onnx.export``
            (default 17).
        --img-size: square input side length. Defaults to the encoder's own
            integer ``img_size`` attribute when present, otherwise 1024.
        --batch-size: fixed batch dimension of the exported graph
            (default 1, must be >= 1).

    Invalid options terminate the process through ``parser.error`` (exit
    status 2) before the model is loaded.
    """
    parser = argparse.ArgumentParser(
        description="Export the SAM image encoder to ONNX."
    )
    parser.add_argument(
        "--model-type",
        type=str,
        default="vit_h",
        # Reject unknown model types up front instead of a bare KeyError later.
        choices=sorted(sam_model_registry),
    )
    parser.add_argument("--checkpoint", type=str, required=True)
    parser.add_argument("--out", type=str, default="encoder.onnx")
    parser.add_argument("--opset", type=int, default=17)
    parser.add_argument(
        "--img-size",
        type=int,
        default=None,
        help="input side length (default: the encoder's img_size, else 1024)",
    )
    parser.add_argument("--batch-size", type=int, default=1)
    args = parser.parse_args()

    # Validate cheap things before the (slow) model load.
    if not os.path.isfile(args.checkpoint):
        parser.error(f"checkpoint not found: {args.checkpoint}")
    if args.batch_size < 1:
        parser.error(f"--batch-size must be >= 1, got {args.batch_size}")
    if args.img_size is not None and args.img_size < 1:
        parser.error(f"--img-size must be >= 1, got {args.img_size}")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"device: {device}")

    print("Loading SAM model (this may take a while)...")
    sam = sam_model_registry[args.model_type](checkpoint=args.checkpoint)
    sam.to(device)
    sam.eval()  # recursively puts every submodule, incl. the encoder, in eval mode

    # The part of SAM being exported is sam.image_encoder.
    image_encoder = sam.image_encoder

    # The encoder may declare the input size it was built for; only trust an
    # integer value (forks could store something else under that name).
    encoder_size = getattr(image_encoder, "img_size", None)
    if not isinstance(encoder_size, int):
        encoder_size = None

    if args.img_size is None:
        img_size = encoder_size if encoder_size is not None else 1024
    else:
        img_size = args.img_size
        if encoder_size is not None and img_size != encoder_size:
            # NOTE(review): the stock ViT encoder adds a fixed-size positional
            # embedding, so a different size likely fails during tracing.
            print(
                f"WARNING: --img-size {img_size} differs from the encoder's "
                f"img_size {encoder_size}; export will likely fail."
            )

    # Random values are sufficient for tracing. At inference time the encoder
    # expects images prepared like segment_anything's Sam.preprocess
    # (pixel mean/std normalization, padded to a square of img_size), not
    # plain [0, 1] images -- make the runtime preprocessing match.
    dummy_input = torch.randn(
        args.batch_size, 3, img_size, img_size, device=device, dtype=torch.float32
    )

    # Make sure the destination directory exists before the exporter writes.
    os.makedirs(os.path.dirname(os.path.abspath(args.out)), exist_ok=True)

    print(f"Exporting encoder to {args.out} ...")
    # NOTE(review): vit_h encoder weights in float32 are likely above the 2 GB
    # protobuf limit; recent torch versions then write the weights as external
    # data files next to --out. Keep those files together with the .onnx.
    # image_encoder(dummy_input) is traced directly with a single tensor
    # argument; encoders with extra forward() arguments would need a wrapper.
    with torch.no_grad():  # no autograd state needs to be recorded while tracing
        torch.onnx.export(
            image_encoder,
            dummy_input,
            args.out,
            export_params=True,
            opset_version=args.opset,
            do_constant_folding=True,
            input_names=["input_image"],
            output_names=["image_embeddings"],
            dynamic_axes=None,  # static shapes: easier for TensorRT / Triton
        )
    print("Export done.")
|
|||
|
|
|
|||
|
|
if __name__ == "__main__":
    # Script entry point: run the exporter and hand main()'s return value to
    # SystemExit (None maps to exit status 0, same as falling off the end).
    raise SystemExit(main())
|