# export_encoder.py """ 导出 SAM 的 Image Encoder 为 ONNX。 假设你使用的是 vit_h 并且想用固定输入 1024x1024。 如果要换 model_type 或 checkpoint,使用命令行参数 --model-type/--checkpoint/--out """ import argparse import torch import os from segment_anything import sam_model_registry def main(): parser = argparse.ArgumentParser() parser.add_argument("--model-type", type=str, default="vit_h") parser.add_argument("--checkpoint", type=str, required=True) parser.add_argument("--out", type=str, default="encoder.onnx") parser.add_argument("--opset", type=int, default=17) parser.add_argument("--img-size", type=int, default=1024) parser.add_argument("--batch-size", type=int, default=1) args = parser.parse_args() device = "cuda" if torch.cuda.is_available() else "cpu" print(f"device: {device}") print("Loading SAM model (this may take a while)...") sam = sam_model_registry[args.model_type](checkpoint=args.checkpoint) sam.to(device) sam.eval() # SAM 的 encoder 是 sam.image_encoder image_encoder = sam.image_encoder image_encoder.eval() # 构建 dummy 输入,注意 SAM 的预处理通常是 [0,1] float32, 3xHxW bs = args.batch_size H = W = args.img_size dummy_input = torch.randn(bs, 3, H, W, device=device, dtype=torch.float32) # ONNX 导出。注意:根据具体实现,encoder.forward 可能需要额外参数; # 此处采用直接调用 image_encoder(dummy_input) 的方式 input_names = ["input_image"] output_names = ["image_embeddings"] print(f"Exporting encoder to {args.out} ...") torch.onnx.export( image_encoder, dummy_input, args.out, export_params=True, opset_version=args.opset, do_constant_folding=True, input_names=input_names, output_names=output_names, dynamic_axes=None # 固定尺寸,更容易用于 TensorRT / Triton ) print("Export done.") if __name__ == "__main__": main()