fitst commit

This commit is contained in:
2026-01-08 14:35:23 +08:00
commit b57e61f2f7
152 changed files with 7675 additions and 0 deletions

Binary file not shown.

After

Width:  |  Height:  |  Size: 104 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 108 KiB

238
scripts/amg.py Executable file
View File

@@ -0,0 +1,238 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import cv2 # type: ignore
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
import argparse
import json
import os
from typing import Any, Dict, List
parser = argparse.ArgumentParser(
description=(
"Runs automatic mask generation on an input image or directory of images, "
"and outputs masks as either PNGs or COCO-style RLEs. Requires open-cv, "
"as well as pycocotools if saving in RLE format."
)
)
parser.add_argument(
"--input",
type=str,
required=True,
help="Path to either a single input image or folder of images.",
)
parser.add_argument(
"--output",
type=str,
required=True,
help=(
"Path to the directory where masks will be output. Output will be either a folder "
"of PNGs per image or a single json with COCO-style masks."
),
)
parser.add_argument(
"--model-type",
type=str,
required=True,
help="The type of model to load, in ['default', 'vit_h', 'vit_l', 'vit_b']",
)
parser.add_argument(
"--checkpoint",
type=str,
required=True,
help="The path to the SAM checkpoint to use for mask generation.",
)
parser.add_argument("--device", type=str, default="cuda", help="The device to run generation on.")
parser.add_argument(
"--convert-to-rle",
action="store_true",
help=(
"Save masks as COCO RLEs in a single json instead of as a folder of PNGs. "
"Requires pycocotools."
),
)
amg_settings = parser.add_argument_group("AMG Settings")
amg_settings.add_argument(
"--points-per-side",
type=int,
default=None,
help="Generate masks by sampling a grid over the image with this many points to a side.",
)
amg_settings.add_argument(
"--points-per-batch",
type=int,
default=None,
help="How many input points to process simultaneously in one batch.",
)
amg_settings.add_argument(
"--pred-iou-thresh",
type=float,
default=None,
help="Exclude masks with a predicted score from the model that is lower than this threshold.",
)
amg_settings.add_argument(
"--stability-score-thresh",
type=float,
default=None,
help="Exclude masks with a stability score lower than this threshold.",
)
amg_settings.add_argument(
"--stability-score-offset",
type=float,
default=None,
help="Larger values perturb the mask more when measuring stability score.",
)
amg_settings.add_argument(
"--box-nms-thresh",
type=float,
default=None,
help="The overlap threshold for excluding a duplicate mask.",
)
amg_settings.add_argument(
"--crop-n-layers",
type=int,
default=None,
help=(
"If >0, mask generation is run on smaller crops of the image to generate more masks. "
"The value sets how many different scales to crop at."
),
)
amg_settings.add_argument(
"--crop-nms-thresh",
type=float,
default=None,
help="The overlap threshold for excluding duplicate masks across different crops.",
)
amg_settings.add_argument(
"--crop-overlap-ratio",
type=int,
default=None,
help="Larger numbers mean image crops will overlap more.",
)
amg_settings.add_argument(
"--crop-n-points-downscale-factor",
type=int,
default=None,
help="The number of points-per-side in each layer of crop is reduced by this factor.",
)
amg_settings.add_argument(
"--min-mask-region-area",
type=int,
default=None,
help=(
"Disconnected mask regions or holes with area smaller than this value "
"in pixels are removed by postprocessing."
),
)
def write_masks_to_folder(masks: List[Dict[str, Any]], path: str) -> None:
header = "id,area,bbox_x0,bbox_y0,bbox_w,bbox_h,point_input_x,point_input_y,predicted_iou,stability_score,crop_box_x0,crop_box_y0,crop_box_w,crop_box_h" # noqa
metadata = [header]
for i, mask_data in enumerate(masks):
mask = mask_data["segmentation"]
filename = f"{i}.png"
cv2.imwrite(os.path.join(path, filename), mask * 255)
mask_metadata = [
str(i),
str(mask_data["area"]),
*[str(x) for x in mask_data["bbox"]],
*[str(x) for x in mask_data["point_coords"][0]],
str(mask_data["predicted_iou"]),
str(mask_data["stability_score"]),
*[str(x) for x in mask_data["crop_box"]],
]
row = ",".join(mask_metadata)
metadata.append(row)
metadata_path = os.path.join(path, "metadata.csv")
with open(metadata_path, "w") as f:
f.write("\n".join(metadata))
return
def get_amg_kwargs(args):
amg_kwargs = {
"points_per_side": args.points_per_side,
"points_per_batch": args.points_per_batch,
"pred_iou_thresh": args.pred_iou_thresh,
"stability_score_thresh": args.stability_score_thresh,
"stability_score_offset": args.stability_score_offset,
"box_nms_thresh": args.box_nms_thresh,
"crop_n_layers": args.crop_n_layers,
"crop_nms_thresh": args.crop_nms_thresh,
"crop_overlap_ratio": args.crop_overlap_ratio,
"crop_n_points_downscale_factor": args.crop_n_points_downscale_factor,
"min_mask_region_area": args.min_mask_region_area,
}
amg_kwargs = {k: v for k, v in amg_kwargs.items() if v is not None}
return amg_kwargs
def main(args: argparse.Namespace) -> None:
print("Loading model...")
sam = sam_model_registry[args.model_type](checkpoint=args.checkpoint)
_ = sam.to(device=args.device)
output_mode = "coco_rle" if args.convert_to_rle else "binary_mask"
amg_kwargs = get_amg_kwargs(args)
generator = SamAutomaticMaskGenerator(sam, output_mode=output_mode, **amg_kwargs)
if not os.path.isdir(args.input):
targets = [args.input]
else:
targets = [
f for f in os.listdir(args.input) if not os.path.isdir(os.path.join(args.input, f))
]
targets = [os.path.join(args.input, f) for f in targets]
os.makedirs(args.output, exist_ok=True)
for t in targets:
print(f"Processing '{t}'...")
image = cv2.imread(t)
if image is None:
print(f"Could not load '{t}' as an image, skipping...")
continue
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
masks = generator.generate(image)
base = os.path.basename(t)
base = os.path.splitext(base)[0]
save_base = os.path.join(args.output, base)
if output_mode == "binary_mask":
os.makedirs(save_base, exist_ok=False)
write_masks_to_folder(masks, save_base)
else:
save_file = save_base + ".json"
with open(save_file, "w") as f:
json.dump(masks, f)
print("Done!")
if __name__ == "__main__":
args = parser.parse_args()
main(args)

238
scripts/amg_box.py Executable file
View File

@@ -0,0 +1,238 @@
import numpy as np
import torch
from PIL import Image, ImageDraw, ImageTk
import tkinter as tk
from tkinter import messagebox
from segment_anything import SamPredictor, sam_model_registry
class SAMBoxSegmenter:
def __init__(self, image_path, sam_checkpoint, model_type="vit_h"):
# 加载图像
self.image = Image.open(image_path).convert("RGB")
self.original_image = self.image.copy()
self.result_image = self.image.copy()
# 初始化SAM模型
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
self.sam.to(device=self.device)
self.predictor = SamPredictor(self.sam)
# 准备图像供SAM使用
self.image_np = np.array(self.image)
self.predictor.set_image(self.image_np)
# 框选相关变量
self.drawing = False
self.start_x = 0
self.start_y = 0
self.current_box = None
self.mask = None
# 创建GUI
self.root = tk.Tk()
self.root.title("SAM 手绘框分割工具")
# 创建画布
self.tk_image = ImageTk.PhotoImage(image=self.image)
self.canvas = tk.Canvas(self.root, width=self.image.width, height=self.image.height)
self.canvas.create_image(0, 0, image=self.tk_image, anchor=tk.NW)
self.canvas.pack()
# 绑定鼠标事件
self.canvas.bind("<ButtonPress-1>", self.on_mouse_down) # 鼠标按下
self.canvas.bind("<B1-Motion>", self.on_mouse_drag) # 鼠标拖动
self.canvas.bind("<ButtonRelease-1>", self.on_mouse_up) # 鼠标释放
# 创建按钮
self.controls_frame = tk.Frame(self.root)
self.controls_frame.pack(fill=tk.X, padx=5, pady=5)
self.clear_btn = tk.Button(self.controls_frame, text="清除框选", command=self.clear_box)
self.clear_btn.pack(side=tk.LEFT, padx=5)
self.save_btn = tk.Button(self.controls_frame, text="保存结果", command=self.save_result)
self.save_btn.pack(side=tk.LEFT, padx=5)
self.quit_btn = tk.Button(self.controls_frame, text="退出", command=self.root.quit)
self.quit_btn.pack(side=tk.RIGHT, padx=5)
# 添加说明标签
self.info_label = tk.Label(
self.root,
text="操作说明: 按住鼠标左键拖动绘制框选区域,松开后自动分割",
fg="blue"
)
self.info_label.pack(pady=5)
def on_mouse_down(self, event):
"""鼠标按下时记录起点"""
self.drawing = True
self.start_x = event.x
self.start_y = event.y
def on_mouse_drag(self, event):
"""鼠标拖动时实时绘制框选(使用兼容方式实现虚线框)"""
if not self.drawing:
return
# 临时复制原图用于绘制动态框
temp_image = self.original_image.copy()
draw = ImageDraw.Draw(temp_image)
# 计算框选坐标
x1, y1 = self.start_x, self.start_y
x2, y2 = event.x, event.y
# 绘制虚线框的替代方法兼容所有Pillow版本
# 绘制四条边,每条边用虚线模式
dash_pattern = [5, 5] # 虚线样式5像素实线5像素空白
self.draw_dashed_line(draw, x1, y1, x2, y1, dash_pattern, "red", 2) # 上边
self.draw_dashed_line(draw, x2, y1, x2, y2, dash_pattern, "red", 2) # 右边
self.draw_dashed_line(draw, x2, y2, x1, y2, dash_pattern, "red", 2) # 下边
self.draw_dashed_line(draw, x1, y2, x1, y1, dash_pattern, "red", 2) # 左边
# 更新显示
self.tk_image = ImageTk.PhotoImage(image=temp_image)
self.canvas.delete("all")
self.canvas.create_image(0, 0, image=self.tk_image, anchor=tk.NW)
def draw_dashed_line(self, draw, x1, y1, x2, y2, dash_pattern, color, width):
"""绘制虚线的工具函数兼容所有Pillow版本"""
# 计算线段长度和方向
dx = x2 - x1
dy = y2 - y1
length = (dx ** 2 + dy ** 2) ** 0.5
# 如果线段太短,直接画实线
if length < 1:
draw.line([(x1, y1), (x2, y2)], fill=color, width=width)
return
# 计算单位向量
ux = dx / length
uy = dy / length
# 绘制虚线
current_pos = 0
dash_on = True # 开始时是实线部分
pattern_index = 0
segment_length = dash_pattern[pattern_index]
while current_pos < length:
# 计算当前段的结束位置
end_pos = current_pos + segment_length
if end_pos > length:
end_pos = length
# 计算像素坐标
x_start = x1 + ux * current_pos
y_start = y1 + uy * current_pos
x_end = x1 + ux * end_pos
y_end = y1 + uy * end_pos
# 如果是实线部分,绘制线段
if dash_on:
draw.line([(x_start, y_start), (x_end, y_end)], fill=color, width=width)
# 切换状态并更新参数
dash_on = not dash_on
pattern_index = (pattern_index + 1) % len(dash_pattern)
segment_length = dash_pattern[pattern_index]
current_pos = end_pos
def on_mouse_up(self, event):
"""鼠标释放时确定框选并生成掩码"""
if not self.drawing:
return
self.drawing = False
# 确保坐标正确(左上角到右下角)
self.current_box = [
min(self.start_x, event.x),
min(self.start_y, event.y),
max(self.start_x, event.x),
max(self.start_y, event.y)
]
# 生成分割掩码
self.update_segmentation()
def update_segmentation(self):
print('update_segmentation')
"""根据当前框选更新分割结果"""
if not self.current_box:
self.result_image = self.original_image.copy()
self.update_display()
return
# 准备框坐标
box = np.array(self.current_box)
# 调用SAM生成掩码
masks, _, _ = self.predictor.predict(
box=box,
multimask_output=False
)
self.mask = masks[0]
# 复制原图用于绘制结果
self.result_image = self.original_image.copy()
# 创建半透明的掩码叠加层
mask_array = self.mask.astype(np.uint8) * 128 # 0或128半透明
mask_image = Image.fromarray(mask_array, mode="L")
# 创建绿色的掩码图像
green_mask = Image.new("RGBA", self.result_image.size, (0, 0, 0, 0))
green_draw = ImageDraw.Draw(green_mask)
green_draw.bitmap((0, 0), mask_image, fill=(0, 255, 0, 128)) # 绿色半透明
# 叠加掩码到原图
self.result_image = Image.alpha_composite(
self.result_image.convert("RGBA"),
green_mask
).convert("RGB")
# 绘制最终框选(红色实线)
draw = ImageDraw.Draw(self.result_image)
draw.rectangle(self.current_box, outline="red", width=2)
self.update_display()
def update_display(self):
"""更新画布显示"""
self.tk_image = ImageTk.PhotoImage(image=self.result_image)
self.canvas.delete("all")
self.canvas.create_image(0, 0, image=self.tk_image, anchor=tk.NW)
def clear_box(self):
"""清除框选和掩码"""
self.current_box = None
self.mask = None
self.result_image = self.original_image.copy()
self.update_display()
def save_result(self):
"""保存分割结果"""
try:
self.result_image.save("box_segmentation_result.jpg")
messagebox.showinfo("成功", "分割结果已保存为 box_segmentation_result.jpg")
except Exception as e:
messagebox.showerror("错误", f"保存失败: {str(e)}")
def run(self):
"""运行GUI主循环"""
self.root.mainloop()
if __name__ == "__main__":
# 配置参数
IMAGE_PATH = "/workspace/PycharmProjects/segment-anything/scripts/ae976103-d7ec-4eed-b5d1-3e5f04d8be26.jpg" # 替换为你的图像路径
SAM_CHECKPOINT = "/workspace/PycharmProjects/segment-anything/checkpoint/sam_vit_h_4b8939.pth" # 替换为你的SAM模型路径
MODEL_TYPE = "vit_h" # 模型类型与checkpoint对应
# 创建并运行分割器
segmenter = SAMBoxSegmenter(IMAGE_PATH, SAM_CHECKPOINT, MODEL_TYPE)
segmenter.run()

158
scripts/amg_point.py Executable file
View File

@@ -0,0 +1,158 @@
import numpy as np
import torch
from PIL import Image, ImageDraw, ImageTk
import tkinter as tk
from tkinter import messagebox
from segment_anything import SamPredictor, sam_model_registry
class SAMPointSegmenter:
def __init__(self, image_path, sam_checkpoint, model_type="vit_b"):
# 加载图像
self.image = Image.open(image_path).convert("RGB")
self.original_image = self.image.copy()
self.result_image = self.image.copy()
# 初始化SAM模型
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
self.sam.to(device=self.device)
self.predictor = SamPredictor(self.sam)
# 准备图像供SAM使用
self.image_np = np.array(self.image)
self.predictor.set_image(self.image_np)
# 存储点击的点 [(x, y, label), ...]label=1表示目标内0表示目标外
self.points = []
self.mask = None
# 创建GUI
self.root = tk.Tk()
self.root.title("SAM 打点分割工具")
# 创建画布
self.tk_image = ImageTk.PhotoImage(image=self.image)
self.canvas = tk.Canvas(self.root, width=self.image.width, height=self.image.height)
self.canvas.create_image(0, 0, image=self.tk_image, anchor=tk.NW)
self.canvas.pack()
# 绑定鼠标事件
self.canvas.bind("<Button-1>", self.add_foreground_point) # 左键点击添加前景点
self.canvas.bind("<Button-3>", self.add_background_point) # 右键点击添加背景点
# 创建按钮
self.controls_frame = tk.Frame(self.root)
self.controls_frame.pack(fill=tk.X, padx=5, pady=5)
self.clear_btn = tk.Button(self.controls_frame, text="清除所有点", command=self.clear_points)
self.clear_btn.pack(side=tk.LEFT, padx=5)
self.save_btn = tk.Button(self.controls_frame, text="保存结果", command=self.save_result)
self.save_btn.pack(side=tk.LEFT, padx=5)
self.quit_btn = tk.Button(self.controls_frame, text="退出", command=self.root.quit)
self.quit_btn.pack(side=tk.RIGHT, padx=5)
# 添加说明标签
self.info_label = tk.Label(
self.root,
text="操作说明: 左键点击添加前景点(绿色), 右键点击添加背景点(红色)",
fg="blue"
)
self.info_label.pack(pady=5)
def add_foreground_point(self, event):
"""添加前景点(目标内)"""
self.points.append((event.x, event.y, 1))
self.update_segmentation()
def add_background_point(self, event):
"""添加背景点(目标外)"""
self.points.append((event.x, event.y, 0))
self.update_segmentation()
def update_segmentation(self):
"""根据当前点更新分割结果"""
# 复制原图
self.result_image = self.original_image.copy()
draw = ImageDraw.Draw(self.result_image)
if not self.points:
self.update_display()
return
# 准备点和标签
point_coords = np.array([(x, y) for x, y, label in self.points])
point_labels = np.array([label for x, y, label in self.points])
# 调用SAM生成掩码
masks, _, _ = self.predictor.predict(
point_coords=point_coords,
point_labels=point_labels,
multimask_output=False
)
self.mask = masks[0]
# 创建半透明的掩码叠加层
mask_array = self.mask.astype(np.uint8) * 128 # 0或128半透明
mask_image = Image.fromarray(mask_array, mode="L")
# 创建绿色的掩码图像
green_mask = Image.new("RGBA", self.result_image.size, (0, 255, 0, 0))
green_draw = ImageDraw.Draw(green_mask)
green_draw.bitmap((0, 0), mask_image, fill=(0, 255, 0, 128)) # 绿色半透明
# 叠加掩码到原图
self.result_image = Image.alpha_composite(
self.result_image.convert("RGBA"),
green_mask
).convert("RGB")
# 绘制所有点
draw = ImageDraw.Draw(self.result_image)
for x, y, label in self.points:
# 前景点为绿色,背景点为红色
color = (0, 255, 0) if label == 1 else (255, 0, 0)
# 绘制点(内部实心,外部白色边框)
draw.ellipse([(x - 5, y - 5), (x + 5, y + 5)], fill=color)
draw.ellipse([(x - 7, y - 7), (x + 7, y + 7)], outline=(255, 255, 255), width=2)
self.update_display()
def update_display(self):
"""更新画布显示"""
self.tk_image = ImageTk.PhotoImage(image=self.result_image)
self.canvas.delete("all")
self.canvas.create_image(0, 0, image=self.tk_image, anchor=tk.NW)
def clear_points(self):
"""清除所有点和掩码"""
self.points = []
self.mask = None
self.result_image = self.original_image.copy()
self.update_display()
def save_result(self):
"""保存分割结果"""
try:
self.result_image.save("segmentation_result.jpg")
messagebox.showinfo("成功", "分割结果已保存为 segmentation_result.jpg")
except Exception as e:
messagebox.showerror("错误", f"保存失败: {str(e)}")
def run(self):
"""运行GUI主循环"""
self.root.mainloop()
if __name__ == "__main__":
# 配置参数
IMAGE_PATH = "/workspace/PycharmProjects/segment-anything/scripts/ae976103-d7ec-4eed-b5d1-3e5f04d8be26.jpg" # 替换为你的图像路径
SAM_CHECKPOINT = "/workspace/PycharmProjects/segment-anything/checkpoint/sam_vit_h_4b8939.pth" # 替换为你的SAM模型路径
MODEL_TYPE = "vit_h" # 模型类型与checkpoint对应
# 创建并运行分割器
segmenter = SAMPointSegmenter(IMAGE_PATH, SAM_CHECKPOINT, MODEL_TYPE)
segmenter.run()

Binary file not shown.

After

Width:  |  Height:  |  Size: 54 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 52 KiB

BIN
scripts/dog_and_girl.jpeg Executable file

Binary file not shown.

After

Width:  |  Height:  |  Size: 485 KiB

58
scripts/export_encoder.py Executable file
View File

@@ -0,0 +1,58 @@
# export_encoder.py
"""
导出 SAM 的 Image Encoder 为 ONNX。
假设你使用的是 vit_h 并且想用固定输入 1024x1024。
如果要换 model_type 或 checkpoint使用命令行参数 --model-type/--checkpoint/--out
"""
import argparse
import torch
import os
from segment_anything import sam_model_registry
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--model-type", type=str, default="vit_h")
parser.add_argument("--checkpoint", type=str, required=True)
parser.add_argument("--out", type=str, default="encoder.onnx")
parser.add_argument("--opset", type=int, default=17)
parser.add_argument("--img-size", type=int, default=1024)
parser.add_argument("--batch-size", type=int, default=1)
args = parser.parse_args()
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"device: {device}")
print("Loading SAM model (this may take a while)...")
sam = sam_model_registry[args.model_type](checkpoint=args.checkpoint)
sam.to(device)
sam.eval()
# SAM 的 encoder 是 sam.image_encoder
image_encoder = sam.image_encoder
image_encoder.eval()
# 构建 dummy 输入,注意 SAM 的预处理通常是 [0,1] float32, 3xHxW
bs = args.batch_size
H = W = args.img_size
dummy_input = torch.randn(bs, 3, H, W, device=device, dtype=torch.float32)
# ONNX 导出。注意根据具体实现encoder.forward 可能需要额外参数;
# 此处采用直接调用 image_encoder(dummy_input) 的方式
input_names = ["input_image"]
output_names = ["image_embeddings"]
print(f"Exporting encoder to {args.out} ...")
torch.onnx.export(
image_encoder,
dummy_input,
args.out,
export_params=True,
opset_version=args.opset,
do_constant_folding=True,
input_names=input_names,
output_names=output_names,
dynamic_axes=None # 固定尺寸,更容易用于 TensorRT / Triton
)
print("Export done.")
if __name__ == "__main__":
main()

201
scripts/export_onnx_model.py Executable file
View File

@@ -0,0 +1,201 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
from segment_anything import sam_model_registry
from segment_anything.utils.onnx import SamOnnxModel
import argparse
import warnings
try:
import onnxruntime # type: ignore
onnxruntime_exists = True
except ImportError:
onnxruntime_exists = False
parser = argparse.ArgumentParser(
description="Export the SAM prompt encoder and mask decoder to an ONNX model."
)
parser.add_argument(
"--checkpoint", type=str, required=True, help="The path to the SAM model checkpoint."
)
parser.add_argument(
"--output", type=str, required=True, help="The filename to save the ONNX model to."
)
parser.add_argument(
"--model-type",
type=str,
required=True,
help="In ['default', 'vit_h', 'vit_l', 'vit_b']. Which type of SAM model to export.",
)
parser.add_argument(
"--return-single-mask",
action="store_true",
help=(
"If true, the exported ONNX model will only return the best mask, "
"instead of returning multiple masks. For high resolution images "
"this can improve runtime when upscaling masks is expensive."
),
)
parser.add_argument(
"--opset",
type=int,
default=17,
help="The ONNX opset version to use. Must be >=11",
)
parser.add_argument(
"--quantize-out",
type=str,
default=None,
help=(
"If set, will quantize the model and save it with this name. "
"Quantization is performed with quantize_dynamic from onnxruntime.quantization.quantize."
),
)
parser.add_argument(
"--gelu-approximate",
action="store_true",
help=(
"Replace GELU operations with approximations using tanh. Useful "
"for some runtimes that have slow or unimplemented erf ops, used in GELU."
),
)
parser.add_argument(
"--use-stability-score",
action="store_true",
help=(
"Replaces the model's predicted mask quality score with the stability "
"score calculated on the low resolution masks using an offset of 1.0. "
),
)
parser.add_argument(
"--return-extra-metrics",
action="store_true",
help=(
"The model will return five results: (masks, scores, stability_scores, "
"areas, low_res_logits) instead of the usual three. This can be "
"significantly slower for high resolution outputs."
),
)
def run_export(
model_type: str,
checkpoint: str,
output: str,
opset: int,
return_single_mask: bool,
gelu_approximate: bool = False,
use_stability_score: bool = False,
return_extra_metrics=False,
):
print("Loading model...")
sam = sam_model_registry[model_type](checkpoint=checkpoint)
onnx_model = SamOnnxModel(
model=sam,
return_single_mask=return_single_mask,
use_stability_score=use_stability_score,
return_extra_metrics=return_extra_metrics,
)
if gelu_approximate:
for n, m in onnx_model.named_modules():
if isinstance(m, torch.nn.GELU):
m.approximate = "tanh"
dynamic_axes = {
"point_coords": {1: "num_points"},
"point_labels": {1: "num_points"},
}
embed_dim = sam.prompt_encoder.embed_dim
embed_size = sam.prompt_encoder.image_embedding_size
mask_input_size = [4 * x for x in embed_size]
dummy_inputs = {
"image_embeddings": torch.randn(1, embed_dim, *embed_size, dtype=torch.float),
"point_coords": torch.randint(low=0, high=1024, size=(1, 5, 2), dtype=torch.float),
"point_labels": torch.randint(low=0, high=4, size=(1, 5), dtype=torch.float),
"mask_input": torch.randn(1, 1, *mask_input_size, dtype=torch.float),
"has_mask_input": torch.tensor([1], dtype=torch.float),
"orig_im_size": torch.tensor([1500, 2250], dtype=torch.float),
}
_ = onnx_model(**dummy_inputs)
output_names = ["masks", "iou_predictions", "low_res_masks"]
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
warnings.filterwarnings("ignore", category=UserWarning)
with open(output, "wb") as f:
print(f"Exporting onnx model to {output}...")
torch.onnx.export(
onnx_model,
tuple(dummy_inputs.values()),
f,
export_params=True,
verbose=False,
opset_version=opset,
do_constant_folding=True,
input_names=list(dummy_inputs.keys()),
output_names=output_names,
dynamic_axes=dynamic_axes,
)
if onnxruntime_exists:
ort_inputs = {k: to_numpy(v) for k, v in dummy_inputs.items()}
# set cpu provider default
providers = ["CPUExecutionProvider"]
ort_session = onnxruntime.InferenceSession(output, providers=providers)
_ = ort_session.run(None, ort_inputs)
print("Model has successfully been run with ONNXRuntime.")
def to_numpy(tensor):
return tensor.cpu().numpy()
if __name__ == "__main__":
args = parser.parse_args()
run_export(
model_type=args.model_type,
checkpoint=args.checkpoint,
output=args.output,
opset=args.opset,
return_single_mask=args.return_single_mask,
gelu_approximate=args.gelu_approximate,
use_stability_score=args.use_stability_score,
return_extra_metrics=args.return_extra_metrics,
)
if args.quantize_out is not None:
assert onnxruntime_exists, "onnxruntime is required to quantize the model."
from onnxruntime.quantization import QuantType # type: ignore
from onnxruntime.quantization.quantize import quantize_dynamic # type: ignore
print(f"Quantizing model and writing to {args.quantize_out}...")
quantize_dynamic(
model_input=args.output,
model_output=args.quantize_out,
optimize_model=True,
per_channel=False,
reduce_range=False,
weight_type=QuantType.QUInt8,
)
print("Done!")

Binary file not shown.

After

Width:  |  Height:  |  Size: 92 KiB

BIN
scripts/segmentation_result.jpg Executable file

Binary file not shown.

After

Width:  |  Height:  |  Size: 52 KiB

257
scripts/smg_box_point.py Executable file
View File

@@ -0,0 +1,257 @@
import numpy as np
import torch
from PIL import Image, ImageDraw, ImageTk
import tkinter as tk
from tkinter import messagebox
from segment_anything import SamPredictor, sam_model_registry
class SAMPointBoxSegmenter:
def __init__(self, image_path, sam_checkpoint, model_type="vit_h"):
# 加载图像
self.image = Image.open(image_path).convert("RGB")
self.original_image = self.image.copy()
self.result_image = self.image.copy()
# 初始化SAM模型
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
self.sam.to(device=self.device)
self.predictor = SamPredictor(self.sam)
# 准备图像供SAM使用
self.image_np = np.array(self.image)
self.predictor.set_image(self.image_np)
# 交互状态变量
self.drawing_box = False # 是否正在绘制框
self.start_x, self.start_y = 0, 0 # 框起点
self.current_box = None # 当前框坐标 [x1, y1, x2, y2]
self.points = [] # 点坐标及标签 [(x, y, label), ...]label=1前景0背景
self.mask = None # 生成的掩码
# 创建GUI
self.root = tk.Tk()
self.root.title("SAM 点框组合分割工具")
# 创建画布
self.tk_image = ImageTk.PhotoImage(image=self.image)
self.canvas = tk.Canvas(self.root, width=self.image.width, height=self.image.height)
self.canvas.create_image(0, 0, image=self.tk_image, anchor=tk.NW)
self.canvas.pack()
# 绑定鼠标事件
self.canvas.bind("<ButtonPress-1>", self.on_left_down) # 左键:开始画框或加点
self.canvas.bind("<B1-Motion>", self.on_left_drag) # 左键拖动:画框
self.canvas.bind("<ButtonRelease-1>", self.on_left_up) # 左键释放:确认框
self.canvas.bind("<Button-3>", self.add_background_point) # 右键:添加背景点
# 控制按钮
self.controls_frame = tk.Frame(self.root)
self.controls_frame.pack(fill=tk.X, padx=5, pady=5)
self.clear_box_btn = tk.Button(self.controls_frame, text="清除框", command=self.clear_box)
self.clear_box_btn.pack(side=tk.LEFT, padx=5)
self.clear_points_btn = tk.Button(self.controls_frame, text="清除点", command=self.clear_points)
self.clear_points_btn.pack(side=tk.LEFT, padx=5)
self.clear_all_btn = tk.Button(self.controls_frame, text="清除全部", command=self.clear_all)
self.clear_all_btn.pack(side=tk.LEFT, padx=5)
self.save_btn = tk.Button(self.controls_frame, text="保存结果", command=self.save_result)
self.save_btn.pack(side=tk.LEFT, padx=5)
self.quit_btn = tk.Button(self.controls_frame, text="退出", command=self.root.quit)
self.quit_btn.pack(side=tk.RIGHT, padx=5)
# 操作说明
self.info_label = tk.Label(
self.root,
text="操作说明: 左键拖动画框 | 左键点击加前景点(绿) | 右键点击加背景点(红) | 框和点可组合使用",
fg="blue",
wraplength=600 # 自动换行
)
self.info_label.pack(pady=5)
# ------------------------------ 鼠标事件处理 ------------------------------
def on_left_down(self, event):
"""左键按下:判断是开始画框还是添加前景点"""
# 若已有框且点击位置不在框起点附近,则视为添加前景点
if self.current_box is not None:
self.points.append((event.x, event.y, 1)) # 前景点(绿色)
self.update_segmentation()
return
# 否则视为开始画框
self.drawing_box = True
self.start_x, self.start_y = event.x, event.y
def on_left_drag(self, event):
"""左键拖动:绘制动态框(仅当正在画框时)"""
if not self.drawing_box:
return
# 临时图像上绘制虚线框
temp_img = self.original_image.copy()
draw = ImageDraw.Draw(temp_img)
# 绘制四条虚线边兼容所有Pillow版本
x1, y1 = self.start_x, self.start_y
x2, y2 = event.x, event.y
self.draw_dashed_line(draw, x1, y1, x2, y1, [5, 5], "red", 2) # 上边
self.draw_dashed_line(draw, x2, y1, x2, y2, [5, 5], "red", 2) # 右边
self.draw_dashed_line(draw, x2, y2, x1, y2, [5, 5], "red", 2) # 下边
self.draw_dashed_line(draw, x1, y2, x1, y1, [5, 5], "red", 2) # 左边
# 叠加已有点
self._draw_points(draw)
# 更新显示
self.tk_image = ImageTk.PhotoImage(image=temp_img)
self.canvas.delete("all")
self.canvas.create_image(0, 0, image=self.tk_image, anchor=tk.NW)
def on_left_up(self, event):
"""左键释放:确认框选"""
if not self.drawing_box:
return
self.drawing_box = False
# 确定框坐标(左上角到右下角)
self.current_box = [
min(self.start_x, event.x),
min(self.start_y, event.y),
max(self.start_x, event.x),
max(self.start_y, event.y)
]
self.update_segmentation()
def add_background_point(self, event):
"""右键点击:添加背景点(红色)"""
self.points.append((event.x, event.y, 0))
self.update_segmentation()
# ------------------------------ 辅助函数 ------------------------------
def draw_dashed_line(self, draw, x1, y1, x2, y2, dash_pattern, color, width):
"""绘制虚线兼容所有Pillow版本"""
dx = x2 - x1
dy = y2 - y1
length = (dx ** 2 + dy ** 2) ** 0.5
if length < 1:
draw.line([(x1, y1), (x2, y2)], fill=color, width=width)
return
ux, uy = dx / length, dy / length # 单位向量
current_pos = 0
dash_on = True
pattern_idx = 0
seg_len = dash_pattern[pattern_idx]
while current_pos < length:
end_pos = min(current_pos + seg_len, length)
x_start, y_start = x1 + ux * current_pos, y1 + uy * current_pos
x_end, y_end = x1 + ux * end_pos, y1 + uy * end_pos
if dash_on:
draw.line([(x_start, y_start), (x_end, y_end)], fill=color, width=width)
dash_on = not dash_on
pattern_idx = (pattern_idx + 1) % len(dash_pattern)
seg_len = dash_pattern[pattern_idx]
current_pos = end_pos
def _draw_points(self, draw):
"""在图像上绘制所有点"""
for x, y, label in self.points:
color = (0, 255, 0) if label == 1 else (255, 0, 0) # 绿=前景,红=背景
draw.ellipse([(x - 5, y - 5), (x + 5, y + 5)], fill=color) # 内部实心
draw.ellipse([(x - 7, y - 7), (x + 7, y + 7)], outline=(255, 255, 255), width=2) # 白色边框
# ------------------------------ 分割逻辑 ------------------------------
def update_segmentation(self):
"""结合框和点生成掩码"""
self.result_image = self.original_image.copy()
draw = ImageDraw.Draw(self.result_image)
# 准备提示信息(框和点)
box = np.array(self.current_box) if self.current_box else None
point_coords = np.array([(x, y) for x, y, label in self.points]) if self.points else None
point_labels = np.array([label for x, y, label in self.points]) if self.points else None
# 若没有任何提示,直接显示原图
if box is None and point_coords is None:
self._draw_points(draw)
self.update_display()
return
# 调用SAM生成掩码同时传入框和点
masks, _, _ = self.predictor.predict(
box=box,
point_coords=point_coords,
point_labels=point_labels,
multimask_output=False
)
self.mask = masks[0]
# 叠加掩码(半透明绿色)
mask_array = self.mask.astype(np.uint8) * 128 # 0或128半透明
mask_img = Image.fromarray(mask_array, mode="L")
green_mask = Image.new("RGBA", self.result_image.size, (0, 0, 0, 0))
green_draw = ImageDraw.Draw(green_mask)
green_draw.bitmap((0, 0), mask_img, fill=(0, 255, 0, 128)) # 绿色半透明
self.result_image = Image.alpha_composite(
self.result_image.convert("RGBA"),
green_mask
).convert("RGB")
# 绘制框和点
if self.current_box:
draw.rectangle(self.current_box, outline="red", width=2) # 实线框
self._draw_points(draw)
self.update_display()
# ------------------------------ 控制函数 ------------------------------
def update_display(self):
"""更新画布显示"""
self.tk_image = ImageTk.PhotoImage(image=self.result_image)
self.canvas.delete("all")
self.canvas.create_image(0, 0, image=self.tk_image, anchor=tk.NW)
def clear_box(self):
"""清除当前框"""
self.current_box = None
self.update_segmentation()
def clear_points(self):
"""清除所有点"""
self.points = []
self.update_segmentation()
def clear_all(self):
"""清除框和点"""
self.current_box = None
self.points = []
self.mask = None
self.result_image = self.original_image.copy()
self.update_display()
def save_result(self):
"""保存结果"""
try:
self.result_image.save("combined_segmentation_result.jpg")
messagebox.showinfo("成功", "结果已保存为 combined_segmentation_result.jpg")
except Exception as e:
messagebox.showerror("错误", f"保存失败: {str(e)}")
def run(self):
"""启动GUI"""
self.root.mainloop()
if __name__ == "__main__":
# 配置路径
IMAGE_PATH = "ae976103-d7ec-4eed-b5d1-3e5f04d8be26.jpg" # 替换为你的图像路径
SAM_CHECKPOINT = "/home/alab/PycharmProjects/segment-anything/checkpoint/sam_vit_h_4b8939.pth" # 替换为你的SAM模型路径
MODEL_TYPE = "vit_h" # 模型类型与checkpoint对应
# 运行工具
segmenter = SAMPointBoxSegmenter(IMAGE_PATH, SAM_CHECKPOINT, MODEL_TYPE)
segmenter.run()

42
scripts/test.py Executable file
View File

@@ -0,0 +1,42 @@
import time
import random
import cv2
import numpy as np
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
# 初始化SAM模型
sam = sam_model_registry["vit_h"](checkpoint="/workspace/PycharmProjects/segment-anything/checkpoint/sam_vit_h_4b8939.pth")
mask_generator = SamAutomaticMaskGenerator(sam)
# 读取并转换图像格式
image = cv2.imread("ae976103-d7ec-4eed-b5d1-3e5f04d8be26.jpg")
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # SAM需要RGB格式
image_copy = image.copy() # 用于绘制掩码的原图副本BGR格式适配OpenCV
# 生成掩码并计时
start_time = time.time()
masks = mask_generator.generate(image_rgb)
print(f"掩码生成耗时: {time.time() - start_time:.2f}")
print(f"共生成 {len(masks)} 个掩码")
# 定义颜色生成函数随机RGB颜色适配OpenCV的BGR格式
def get_random_color():
return (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
# 遍历所有掩码并绘制到图像上
for mask in masks:
# 获取掩码的二进制数组True/False
mask_array = mask["segmentation"]
# 生成随机颜色
color = get_random_color()
# 将掩码区域绘制到图像副本上(半透明效果)
image_copy[mask_array] = image_copy[mask_array] * 0.5 + np.array(color) * 0.5
# 保存结果图像
cv2.imwrite("mask_visualization.jpg", image_copy)
# 显示结果(如果是桌面环境)
cv2.namedWindow("SAM Mask Visualization", cv2.WINDOW_NORMAL)
cv2.imshow("SAM Mask Visualization", image_copy)
cv2.waitKey(0)
cv2.destroyAllWindows()