fitst commit
This commit is contained in:
58
scripts/export_encoder.py
Executable file
58
scripts/export_encoder.py
Executable file
@@ -0,0 +1,58 @@
|
||||
# export_encoder.py
|
||||
"""
|
||||
导出 SAM 的 Image Encoder 为 ONNX。
|
||||
假设你使用的是 vit_h 并且想用固定输入 1024x1024。
|
||||
如果要换 model_type 或 checkpoint,使用命令行参数 --model-type/--checkpoint/--out
|
||||
"""
|
||||
import argparse
|
||||
import torch
|
||||
import os
|
||||
from segment_anything import sam_model_registry
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--model-type", type=str, default="vit_h")
|
||||
parser.add_argument("--checkpoint", type=str, required=True)
|
||||
parser.add_argument("--out", type=str, default="encoder.onnx")
|
||||
parser.add_argument("--opset", type=int, default=17)
|
||||
parser.add_argument("--img-size", type=int, default=1024)
|
||||
parser.add_argument("--batch-size", type=int, default=1)
|
||||
args = parser.parse_args()
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
print(f"device: {device}")
|
||||
print("Loading SAM model (this may take a while)...")
|
||||
sam = sam_model_registry[args.model_type](checkpoint=args.checkpoint)
|
||||
sam.to(device)
|
||||
sam.eval()
|
||||
|
||||
# SAM 的 encoder 是 sam.image_encoder
|
||||
image_encoder = sam.image_encoder
|
||||
image_encoder.eval()
|
||||
|
||||
# 构建 dummy 输入,注意 SAM 的预处理通常是 [0,1] float32, 3xHxW
|
||||
bs = args.batch_size
|
||||
H = W = args.img_size
|
||||
dummy_input = torch.randn(bs, 3, H, W, device=device, dtype=torch.float32)
|
||||
|
||||
# ONNX 导出。注意:根据具体实现,encoder.forward 可能需要额外参数;
|
||||
# 此处采用直接调用 image_encoder(dummy_input) 的方式
|
||||
input_names = ["input_image"]
|
||||
output_names = ["image_embeddings"]
|
||||
|
||||
print(f"Exporting encoder to {args.out} ...")
|
||||
torch.onnx.export(
|
||||
image_encoder,
|
||||
dummy_input,
|
||||
args.out,
|
||||
export_params=True,
|
||||
opset_version=args.opset,
|
||||
do_constant_folding=True,
|
||||
input_names=input_names,
|
||||
output_names=output_names,
|
||||
dynamic_axes=None # 固定尺寸,更容易用于 TensorRT / Triton
|
||||
)
|
||||
print("Export done.")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user