fitst commit
BIN
scripts/076ec5a6-0d99-4867-b38d-94c026000691-3-89.jpeg
Executable file
|
After Width: | Height: | Size: 104 KiB |
BIN
scripts/ae976103-d7ec-4eed-b5d1-3e5f04d8be26.jpg
Executable file
|
After Width: | Height: | Size: 108 KiB |
238
scripts/amg.py
Executable file
@@ -0,0 +1,238 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import cv2 # type: ignore
|
||||
|
||||
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Dict, List
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description=(
|
||||
"Runs automatic mask generation on an input image or directory of images, "
|
||||
"and outputs masks as either PNGs or COCO-style RLEs. Requires open-cv, "
|
||||
"as well as pycocotools if saving in RLE format."
|
||||
)
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--input",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to either a single input image or folder of images.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
type=str,
|
||||
required=True,
|
||||
help=(
|
||||
"Path to the directory where masks will be output. Output will be either a folder "
|
||||
"of PNGs per image or a single json with COCO-style masks."
|
||||
),
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--model-type",
|
||||
type=str,
|
||||
required=True,
|
||||
help="The type of model to load, in ['default', 'vit_h', 'vit_l', 'vit_b']",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--checkpoint",
|
||||
type=str,
|
||||
required=True,
|
||||
help="The path to the SAM checkpoint to use for mask generation.",
|
||||
)
|
||||
|
||||
parser.add_argument("--device", type=str, default="cuda", help="The device to run generation on.")
|
||||
|
||||
parser.add_argument(
|
||||
"--convert-to-rle",
|
||||
action="store_true",
|
||||
help=(
|
||||
"Save masks as COCO RLEs in a single json instead of as a folder of PNGs. "
|
||||
"Requires pycocotools."
|
||||
),
|
||||
)
|
||||
|
||||
amg_settings = parser.add_argument_group("AMG Settings")
|
||||
|
||||
amg_settings.add_argument(
|
||||
"--points-per-side",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Generate masks by sampling a grid over the image with this many points to a side.",
|
||||
)
|
||||
|
||||
amg_settings.add_argument(
|
||||
"--points-per-batch",
|
||||
type=int,
|
||||
default=None,
|
||||
help="How many input points to process simultaneously in one batch.",
|
||||
)
|
||||
|
||||
amg_settings.add_argument(
|
||||
"--pred-iou-thresh",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Exclude masks with a predicted score from the model that is lower than this threshold.",
|
||||
)
|
||||
|
||||
amg_settings.add_argument(
|
||||
"--stability-score-thresh",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Exclude masks with a stability score lower than this threshold.",
|
||||
)
|
||||
|
||||
amg_settings.add_argument(
|
||||
"--stability-score-offset",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Larger values perturb the mask more when measuring stability score.",
|
||||
)
|
||||
|
||||
amg_settings.add_argument(
|
||||
"--box-nms-thresh",
|
||||
type=float,
|
||||
default=None,
|
||||
help="The overlap threshold for excluding a duplicate mask.",
|
||||
)
|
||||
|
||||
amg_settings.add_argument(
|
||||
"--crop-n-layers",
|
||||
type=int,
|
||||
default=None,
|
||||
help=(
|
||||
"If >0, mask generation is run on smaller crops of the image to generate more masks. "
|
||||
"The value sets how many different scales to crop at."
|
||||
),
|
||||
)
|
||||
|
||||
amg_settings.add_argument(
|
||||
"--crop-nms-thresh",
|
||||
type=float,
|
||||
default=None,
|
||||
help="The overlap threshold for excluding duplicate masks across different crops.",
|
||||
)
|
||||
|
||||
amg_settings.add_argument(
|
||||
"--crop-overlap-ratio",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Larger numbers mean image crops will overlap more.",
|
||||
)
|
||||
|
||||
amg_settings.add_argument(
|
||||
"--crop-n-points-downscale-factor",
|
||||
type=int,
|
||||
default=None,
|
||||
help="The number of points-per-side in each layer of crop is reduced by this factor.",
|
||||
)
|
||||
|
||||
amg_settings.add_argument(
|
||||
"--min-mask-region-area",
|
||||
type=int,
|
||||
default=None,
|
||||
help=(
|
||||
"Disconnected mask regions or holes with area smaller than this value "
|
||||
"in pixels are removed by postprocessing."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def write_masks_to_folder(masks: List[Dict[str, Any]], path: str) -> None:
|
||||
header = "id,area,bbox_x0,bbox_y0,bbox_w,bbox_h,point_input_x,point_input_y,predicted_iou,stability_score,crop_box_x0,crop_box_y0,crop_box_w,crop_box_h" # noqa
|
||||
metadata = [header]
|
||||
for i, mask_data in enumerate(masks):
|
||||
mask = mask_data["segmentation"]
|
||||
filename = f"{i}.png"
|
||||
cv2.imwrite(os.path.join(path, filename), mask * 255)
|
||||
mask_metadata = [
|
||||
str(i),
|
||||
str(mask_data["area"]),
|
||||
*[str(x) for x in mask_data["bbox"]],
|
||||
*[str(x) for x in mask_data["point_coords"][0]],
|
||||
str(mask_data["predicted_iou"]),
|
||||
str(mask_data["stability_score"]),
|
||||
*[str(x) for x in mask_data["crop_box"]],
|
||||
]
|
||||
row = ",".join(mask_metadata)
|
||||
metadata.append(row)
|
||||
metadata_path = os.path.join(path, "metadata.csv")
|
||||
with open(metadata_path, "w") as f:
|
||||
f.write("\n".join(metadata))
|
||||
|
||||
return
|
||||
|
||||
|
||||
def get_amg_kwargs(args):
|
||||
amg_kwargs = {
|
||||
"points_per_side": args.points_per_side,
|
||||
"points_per_batch": args.points_per_batch,
|
||||
"pred_iou_thresh": args.pred_iou_thresh,
|
||||
"stability_score_thresh": args.stability_score_thresh,
|
||||
"stability_score_offset": args.stability_score_offset,
|
||||
"box_nms_thresh": args.box_nms_thresh,
|
||||
"crop_n_layers": args.crop_n_layers,
|
||||
"crop_nms_thresh": args.crop_nms_thresh,
|
||||
"crop_overlap_ratio": args.crop_overlap_ratio,
|
||||
"crop_n_points_downscale_factor": args.crop_n_points_downscale_factor,
|
||||
"min_mask_region_area": args.min_mask_region_area,
|
||||
}
|
||||
amg_kwargs = {k: v for k, v in amg_kwargs.items() if v is not None}
|
||||
return amg_kwargs
|
||||
|
||||
|
||||
def main(args: argparse.Namespace) -> None:
|
||||
print("Loading model...")
|
||||
sam = sam_model_registry[args.model_type](checkpoint=args.checkpoint)
|
||||
_ = sam.to(device=args.device)
|
||||
output_mode = "coco_rle" if args.convert_to_rle else "binary_mask"
|
||||
amg_kwargs = get_amg_kwargs(args)
|
||||
generator = SamAutomaticMaskGenerator(sam, output_mode=output_mode, **amg_kwargs)
|
||||
|
||||
if not os.path.isdir(args.input):
|
||||
targets = [args.input]
|
||||
else:
|
||||
targets = [
|
||||
f for f in os.listdir(args.input) if not os.path.isdir(os.path.join(args.input, f))
|
||||
]
|
||||
targets = [os.path.join(args.input, f) for f in targets]
|
||||
|
||||
os.makedirs(args.output, exist_ok=True)
|
||||
|
||||
for t in targets:
|
||||
print(f"Processing '{t}'...")
|
||||
image = cv2.imread(t)
|
||||
if image is None:
|
||||
print(f"Could not load '{t}' as an image, skipping...")
|
||||
continue
|
||||
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||
|
||||
masks = generator.generate(image)
|
||||
|
||||
base = os.path.basename(t)
|
||||
base = os.path.splitext(base)[0]
|
||||
save_base = os.path.join(args.output, base)
|
||||
if output_mode == "binary_mask":
|
||||
os.makedirs(save_base, exist_ok=False)
|
||||
write_masks_to_folder(masks, save_base)
|
||||
else:
|
||||
save_file = save_base + ".json"
|
||||
with open(save_file, "w") as f:
|
||||
json.dump(masks, f)
|
||||
print("Done!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
238
scripts/amg_box.py
Executable file
@@ -0,0 +1,238 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image, ImageDraw, ImageTk
|
||||
import tkinter as tk
|
||||
from tkinter import messagebox
|
||||
from segment_anything import SamPredictor, sam_model_registry
|
||||
|
||||
|
||||
class SAMBoxSegmenter:
|
||||
def __init__(self, image_path, sam_checkpoint, model_type="vit_h"):
|
||||
# 加载图像
|
||||
self.image = Image.open(image_path).convert("RGB")
|
||||
self.original_image = self.image.copy()
|
||||
self.result_image = self.image.copy()
|
||||
|
||||
# 初始化SAM模型
|
||||
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
self.sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
|
||||
self.sam.to(device=self.device)
|
||||
self.predictor = SamPredictor(self.sam)
|
||||
|
||||
# 准备图像供SAM使用
|
||||
self.image_np = np.array(self.image)
|
||||
self.predictor.set_image(self.image_np)
|
||||
|
||||
# 框选相关变量
|
||||
self.drawing = False
|
||||
self.start_x = 0
|
||||
self.start_y = 0
|
||||
self.current_box = None
|
||||
self.mask = None
|
||||
|
||||
# 创建GUI
|
||||
self.root = tk.Tk()
|
||||
self.root.title("SAM 手绘框分割工具")
|
||||
|
||||
# 创建画布
|
||||
self.tk_image = ImageTk.PhotoImage(image=self.image)
|
||||
self.canvas = tk.Canvas(self.root, width=self.image.width, height=self.image.height)
|
||||
self.canvas.create_image(0, 0, image=self.tk_image, anchor=tk.NW)
|
||||
self.canvas.pack()
|
||||
|
||||
# 绑定鼠标事件
|
||||
self.canvas.bind("<ButtonPress-1>", self.on_mouse_down) # 鼠标按下
|
||||
self.canvas.bind("<B1-Motion>", self.on_mouse_drag) # 鼠标拖动
|
||||
self.canvas.bind("<ButtonRelease-1>", self.on_mouse_up) # 鼠标释放
|
||||
|
||||
# 创建按钮
|
||||
self.controls_frame = tk.Frame(self.root)
|
||||
self.controls_frame.pack(fill=tk.X, padx=5, pady=5)
|
||||
|
||||
self.clear_btn = tk.Button(self.controls_frame, text="清除框选", command=self.clear_box)
|
||||
self.clear_btn.pack(side=tk.LEFT, padx=5)
|
||||
|
||||
self.save_btn = tk.Button(self.controls_frame, text="保存结果", command=self.save_result)
|
||||
self.save_btn.pack(side=tk.LEFT, padx=5)
|
||||
|
||||
self.quit_btn = tk.Button(self.controls_frame, text="退出", command=self.root.quit)
|
||||
self.quit_btn.pack(side=tk.RIGHT, padx=5)
|
||||
|
||||
# 添加说明标签
|
||||
self.info_label = tk.Label(
|
||||
self.root,
|
||||
text="操作说明: 按住鼠标左键拖动绘制框选区域,松开后自动分割",
|
||||
fg="blue"
|
||||
)
|
||||
self.info_label.pack(pady=5)
|
||||
|
||||
def on_mouse_down(self, event):
|
||||
"""鼠标按下时记录起点"""
|
||||
self.drawing = True
|
||||
self.start_x = event.x
|
||||
self.start_y = event.y
|
||||
|
||||
def on_mouse_drag(self, event):
|
||||
"""鼠标拖动时实时绘制框选(使用兼容方式实现虚线框)"""
|
||||
if not self.drawing:
|
||||
return
|
||||
|
||||
# 临时复制原图用于绘制动态框
|
||||
temp_image = self.original_image.copy()
|
||||
draw = ImageDraw.Draw(temp_image)
|
||||
|
||||
# 计算框选坐标
|
||||
x1, y1 = self.start_x, self.start_y
|
||||
x2, y2 = event.x, event.y
|
||||
|
||||
# 绘制虚线框的替代方法(兼容所有Pillow版本)
|
||||
# 绘制四条边,每条边用虚线模式
|
||||
dash_pattern = [5, 5] # 虚线样式:5像素实线,5像素空白
|
||||
self.draw_dashed_line(draw, x1, y1, x2, y1, dash_pattern, "red", 2) # 上边
|
||||
self.draw_dashed_line(draw, x2, y1, x2, y2, dash_pattern, "red", 2) # 右边
|
||||
self.draw_dashed_line(draw, x2, y2, x1, y2, dash_pattern, "red", 2) # 下边
|
||||
self.draw_dashed_line(draw, x1, y2, x1, y1, dash_pattern, "red", 2) # 左边
|
||||
|
||||
# 更新显示
|
||||
self.tk_image = ImageTk.PhotoImage(image=temp_image)
|
||||
self.canvas.delete("all")
|
||||
self.canvas.create_image(0, 0, image=self.tk_image, anchor=tk.NW)
|
||||
|
||||
def draw_dashed_line(self, draw, x1, y1, x2, y2, dash_pattern, color, width):
|
||||
"""绘制虚线的工具函数,兼容所有Pillow版本"""
|
||||
# 计算线段长度和方向
|
||||
dx = x2 - x1
|
||||
dy = y2 - y1
|
||||
length = (dx ** 2 + dy ** 2) ** 0.5
|
||||
|
||||
# 如果线段太短,直接画实线
|
||||
if length < 1:
|
||||
draw.line([(x1, y1), (x2, y2)], fill=color, width=width)
|
||||
return
|
||||
|
||||
# 计算单位向量
|
||||
ux = dx / length
|
||||
uy = dy / length
|
||||
|
||||
# 绘制虚线
|
||||
current_pos = 0
|
||||
dash_on = True # 开始时是实线部分
|
||||
pattern_index = 0
|
||||
segment_length = dash_pattern[pattern_index]
|
||||
|
||||
while current_pos < length:
|
||||
# 计算当前段的结束位置
|
||||
end_pos = current_pos + segment_length
|
||||
if end_pos > length:
|
||||
end_pos = length
|
||||
|
||||
# 计算像素坐标
|
||||
x_start = x1 + ux * current_pos
|
||||
y_start = y1 + uy * current_pos
|
||||
x_end = x1 + ux * end_pos
|
||||
y_end = y1 + uy * end_pos
|
||||
|
||||
# 如果是实线部分,绘制线段
|
||||
if dash_on:
|
||||
draw.line([(x_start, y_start), (x_end, y_end)], fill=color, width=width)
|
||||
|
||||
# 切换状态并更新参数
|
||||
dash_on = not dash_on
|
||||
pattern_index = (pattern_index + 1) % len(dash_pattern)
|
||||
segment_length = dash_pattern[pattern_index]
|
||||
current_pos = end_pos
|
||||
|
||||
def on_mouse_up(self, event):
|
||||
"""鼠标释放时确定框选并生成掩码"""
|
||||
if not self.drawing:
|
||||
return
|
||||
|
||||
self.drawing = False
|
||||
# 确保坐标正确(左上角到右下角)
|
||||
self.current_box = [
|
||||
min(self.start_x, event.x),
|
||||
min(self.start_y, event.y),
|
||||
max(self.start_x, event.x),
|
||||
max(self.start_y, event.y)
|
||||
]
|
||||
|
||||
# 生成分割掩码
|
||||
self.update_segmentation()
|
||||
|
||||
def update_segmentation(self):
|
||||
print('update_segmentation')
|
||||
"""根据当前框选更新分割结果"""
|
||||
if not self.current_box:
|
||||
self.result_image = self.original_image.copy()
|
||||
self.update_display()
|
||||
return
|
||||
|
||||
# 准备框坐标
|
||||
box = np.array(self.current_box)
|
||||
|
||||
# 调用SAM生成掩码
|
||||
masks, _, _ = self.predictor.predict(
|
||||
box=box,
|
||||
multimask_output=False
|
||||
)
|
||||
self.mask = masks[0]
|
||||
|
||||
# 复制原图用于绘制结果
|
||||
self.result_image = self.original_image.copy()
|
||||
|
||||
# 创建半透明的掩码叠加层
|
||||
mask_array = self.mask.astype(np.uint8) * 128 # 0或128(半透明)
|
||||
mask_image = Image.fromarray(mask_array, mode="L")
|
||||
|
||||
# 创建绿色的掩码图像
|
||||
green_mask = Image.new("RGBA", self.result_image.size, (0, 0, 0, 0))
|
||||
green_draw = ImageDraw.Draw(green_mask)
|
||||
green_draw.bitmap((0, 0), mask_image, fill=(0, 255, 0, 128)) # 绿色半透明
|
||||
|
||||
# 叠加掩码到原图
|
||||
self.result_image = Image.alpha_composite(
|
||||
self.result_image.convert("RGBA"),
|
||||
green_mask
|
||||
).convert("RGB")
|
||||
|
||||
# 绘制最终框选(红色实线)
|
||||
draw = ImageDraw.Draw(self.result_image)
|
||||
draw.rectangle(self.current_box, outline="red", width=2)
|
||||
|
||||
self.update_display()
|
||||
|
||||
def update_display(self):
|
||||
"""更新画布显示"""
|
||||
self.tk_image = ImageTk.PhotoImage(image=self.result_image)
|
||||
self.canvas.delete("all")
|
||||
self.canvas.create_image(0, 0, image=self.tk_image, anchor=tk.NW)
|
||||
|
||||
def clear_box(self):
|
||||
"""清除框选和掩码"""
|
||||
self.current_box = None
|
||||
self.mask = None
|
||||
self.result_image = self.original_image.copy()
|
||||
self.update_display()
|
||||
|
||||
def save_result(self):
|
||||
"""保存分割结果"""
|
||||
try:
|
||||
self.result_image.save("box_segmentation_result.jpg")
|
||||
messagebox.showinfo("成功", "分割结果已保存为 box_segmentation_result.jpg")
|
||||
except Exception as e:
|
||||
messagebox.showerror("错误", f"保存失败: {str(e)}")
|
||||
|
||||
def run(self):
|
||||
"""运行GUI主循环"""
|
||||
self.root.mainloop()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 配置参数
|
||||
IMAGE_PATH = "/workspace/PycharmProjects/segment-anything/scripts/ae976103-d7ec-4eed-b5d1-3e5f04d8be26.jpg" # 替换为你的图像路径
|
||||
SAM_CHECKPOINT = "/workspace/PycharmProjects/segment-anything/checkpoint/sam_vit_h_4b8939.pth" # 替换为你的SAM模型路径
|
||||
MODEL_TYPE = "vit_h" # 模型类型,与checkpoint对应
|
||||
|
||||
# 创建并运行分割器
|
||||
segmenter = SAMBoxSegmenter(IMAGE_PATH, SAM_CHECKPOINT, MODEL_TYPE)
|
||||
segmenter.run()
|
||||
158
scripts/amg_point.py
Executable file
@@ -0,0 +1,158 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image, ImageDraw, ImageTk
|
||||
import tkinter as tk
|
||||
from tkinter import messagebox
|
||||
from segment_anything import SamPredictor, sam_model_registry
|
||||
|
||||
|
||||
class SAMPointSegmenter:
|
||||
def __init__(self, image_path, sam_checkpoint, model_type="vit_b"):
|
||||
# 加载图像
|
||||
self.image = Image.open(image_path).convert("RGB")
|
||||
self.original_image = self.image.copy()
|
||||
self.result_image = self.image.copy()
|
||||
|
||||
# 初始化SAM模型
|
||||
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
self.sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
|
||||
self.sam.to(device=self.device)
|
||||
self.predictor = SamPredictor(self.sam)
|
||||
|
||||
# 准备图像供SAM使用
|
||||
self.image_np = np.array(self.image)
|
||||
self.predictor.set_image(self.image_np)
|
||||
|
||||
# 存储点击的点 [(x, y, label), ...],label=1表示目标内,0表示目标外
|
||||
self.points = []
|
||||
self.mask = None
|
||||
|
||||
# 创建GUI
|
||||
self.root = tk.Tk()
|
||||
self.root.title("SAM 打点分割工具")
|
||||
|
||||
# 创建画布
|
||||
self.tk_image = ImageTk.PhotoImage(image=self.image)
|
||||
self.canvas = tk.Canvas(self.root, width=self.image.width, height=self.image.height)
|
||||
self.canvas.create_image(0, 0, image=self.tk_image, anchor=tk.NW)
|
||||
self.canvas.pack()
|
||||
|
||||
# 绑定鼠标事件
|
||||
self.canvas.bind("<Button-1>", self.add_foreground_point) # 左键点击添加前景点
|
||||
self.canvas.bind("<Button-3>", self.add_background_point) # 右键点击添加背景点
|
||||
|
||||
# 创建按钮
|
||||
self.controls_frame = tk.Frame(self.root)
|
||||
self.controls_frame.pack(fill=tk.X, padx=5, pady=5)
|
||||
|
||||
self.clear_btn = tk.Button(self.controls_frame, text="清除所有点", command=self.clear_points)
|
||||
self.clear_btn.pack(side=tk.LEFT, padx=5)
|
||||
|
||||
self.save_btn = tk.Button(self.controls_frame, text="保存结果", command=self.save_result)
|
||||
self.save_btn.pack(side=tk.LEFT, padx=5)
|
||||
|
||||
self.quit_btn = tk.Button(self.controls_frame, text="退出", command=self.root.quit)
|
||||
self.quit_btn.pack(side=tk.RIGHT, padx=5)
|
||||
|
||||
# 添加说明标签
|
||||
self.info_label = tk.Label(
|
||||
self.root,
|
||||
text="操作说明: 左键点击添加前景点(绿色), 右键点击添加背景点(红色)",
|
||||
fg="blue"
|
||||
)
|
||||
self.info_label.pack(pady=5)
|
||||
|
||||
def add_foreground_point(self, event):
|
||||
"""添加前景点(目标内)"""
|
||||
self.points.append((event.x, event.y, 1))
|
||||
self.update_segmentation()
|
||||
|
||||
def add_background_point(self, event):
|
||||
"""添加背景点(目标外)"""
|
||||
self.points.append((event.x, event.y, 0))
|
||||
self.update_segmentation()
|
||||
|
||||
def update_segmentation(self):
|
||||
"""根据当前点更新分割结果"""
|
||||
# 复制原图
|
||||
self.result_image = self.original_image.copy()
|
||||
draw = ImageDraw.Draw(self.result_image)
|
||||
|
||||
if not self.points:
|
||||
self.update_display()
|
||||
return
|
||||
|
||||
# 准备点和标签
|
||||
point_coords = np.array([(x, y) for x, y, label in self.points])
|
||||
point_labels = np.array([label for x, y, label in self.points])
|
||||
|
||||
# 调用SAM生成掩码
|
||||
masks, _, _ = self.predictor.predict(
|
||||
point_coords=point_coords,
|
||||
point_labels=point_labels,
|
||||
multimask_output=False
|
||||
)
|
||||
self.mask = masks[0]
|
||||
|
||||
# 创建半透明的掩码叠加层
|
||||
mask_array = self.mask.astype(np.uint8) * 128 # 0或128(半透明)
|
||||
mask_image = Image.fromarray(mask_array, mode="L")
|
||||
|
||||
# 创建绿色的掩码图像
|
||||
green_mask = Image.new("RGBA", self.result_image.size, (0, 255, 0, 0))
|
||||
green_draw = ImageDraw.Draw(green_mask)
|
||||
green_draw.bitmap((0, 0), mask_image, fill=(0, 255, 0, 128)) # 绿色半透明
|
||||
|
||||
# 叠加掩码到原图
|
||||
self.result_image = Image.alpha_composite(
|
||||
self.result_image.convert("RGBA"),
|
||||
green_mask
|
||||
).convert("RGB")
|
||||
|
||||
# 绘制所有点
|
||||
draw = ImageDraw.Draw(self.result_image)
|
||||
for x, y, label in self.points:
|
||||
# 前景点为绿色,背景点为红色
|
||||
color = (0, 255, 0) if label == 1 else (255, 0, 0)
|
||||
# 绘制点(内部实心,外部白色边框)
|
||||
draw.ellipse([(x - 5, y - 5), (x + 5, y + 5)], fill=color)
|
||||
draw.ellipse([(x - 7, y - 7), (x + 7, y + 7)], outline=(255, 255, 255), width=2)
|
||||
|
||||
self.update_display()
|
||||
|
||||
def update_display(self):
|
||||
"""更新画布显示"""
|
||||
self.tk_image = ImageTk.PhotoImage(image=self.result_image)
|
||||
self.canvas.delete("all")
|
||||
self.canvas.create_image(0, 0, image=self.tk_image, anchor=tk.NW)
|
||||
|
||||
def clear_points(self):
|
||||
"""清除所有点和掩码"""
|
||||
self.points = []
|
||||
self.mask = None
|
||||
self.result_image = self.original_image.copy()
|
||||
self.update_display()
|
||||
|
||||
def save_result(self):
|
||||
"""保存分割结果"""
|
||||
try:
|
||||
self.result_image.save("segmentation_result.jpg")
|
||||
messagebox.showinfo("成功", "分割结果已保存为 segmentation_result.jpg")
|
||||
except Exception as e:
|
||||
messagebox.showerror("错误", f"保存失败: {str(e)}")
|
||||
|
||||
def run(self):
|
||||
"""运行GUI主循环"""
|
||||
self.root.mainloop()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 配置参数
|
||||
IMAGE_PATH = "/workspace/PycharmProjects/segment-anything/scripts/ae976103-d7ec-4eed-b5d1-3e5f04d8be26.jpg" # 替换为你的图像路径
|
||||
SAM_CHECKPOINT = "/workspace/PycharmProjects/segment-anything/checkpoint/sam_vit_h_4b8939.pth" # 替换为你的SAM模型路径
|
||||
|
||||
MODEL_TYPE = "vit_h" # 模型类型,与checkpoint对应
|
||||
|
||||
# 创建并运行分割器
|
||||
segmenter = SAMPointSegmenter(IMAGE_PATH, SAM_CHECKPOINT, MODEL_TYPE)
|
||||
segmenter.run()
|
||||
BIN
scripts/box_segmentation_result.jpg
Executable file
|
After Width: | Height: | Size: 54 KiB |
BIN
scripts/combined_segmentation_result.jpg
Executable file
|
After Width: | Height: | Size: 52 KiB |
BIN
scripts/dog_and_girl.jpeg
Executable file
|
After Width: | Height: | Size: 485 KiB |
58
scripts/export_encoder.py
Executable file
@@ -0,0 +1,58 @@
|
||||
# export_encoder.py
|
||||
"""
|
||||
导出 SAM 的 Image Encoder 为 ONNX。
|
||||
假设你使用的是 vit_h 并且想用固定输入 1024x1024。
|
||||
如果要换 model_type 或 checkpoint,使用命令行参数 --model-type/--checkpoint/--out
|
||||
"""
|
||||
import argparse
|
||||
import torch
|
||||
import os
|
||||
from segment_anything import sam_model_registry
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--model-type", type=str, default="vit_h")
|
||||
parser.add_argument("--checkpoint", type=str, required=True)
|
||||
parser.add_argument("--out", type=str, default="encoder.onnx")
|
||||
parser.add_argument("--opset", type=int, default=17)
|
||||
parser.add_argument("--img-size", type=int, default=1024)
|
||||
parser.add_argument("--batch-size", type=int, default=1)
|
||||
args = parser.parse_args()
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
print(f"device: {device}")
|
||||
print("Loading SAM model (this may take a while)...")
|
||||
sam = sam_model_registry[args.model_type](checkpoint=args.checkpoint)
|
||||
sam.to(device)
|
||||
sam.eval()
|
||||
|
||||
# SAM 的 encoder 是 sam.image_encoder
|
||||
image_encoder = sam.image_encoder
|
||||
image_encoder.eval()
|
||||
|
||||
# 构建 dummy 输入,注意 SAM 的预处理通常是 [0,1] float32, 3xHxW
|
||||
bs = args.batch_size
|
||||
H = W = args.img_size
|
||||
dummy_input = torch.randn(bs, 3, H, W, device=device, dtype=torch.float32)
|
||||
|
||||
# ONNX 导出。注意:根据具体实现,encoder.forward 可能需要额外参数;
|
||||
# 此处采用直接调用 image_encoder(dummy_input) 的方式
|
||||
input_names = ["input_image"]
|
||||
output_names = ["image_embeddings"]
|
||||
|
||||
print(f"Exporting encoder to {args.out} ...")
|
||||
torch.onnx.export(
|
||||
image_encoder,
|
||||
dummy_input,
|
||||
args.out,
|
||||
export_params=True,
|
||||
opset_version=args.opset,
|
||||
do_constant_folding=True,
|
||||
input_names=input_names,
|
||||
output_names=output_names,
|
||||
dynamic_axes=None # 固定尺寸,更容易用于 TensorRT / Triton
|
||||
)
|
||||
print("Export done.")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
201
scripts/export_onnx_model.py
Executable file
@@ -0,0 +1,201 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
|
||||
from segment_anything import sam_model_registry
|
||||
from segment_anything.utils.onnx import SamOnnxModel
|
||||
|
||||
import argparse
|
||||
import warnings
|
||||
|
||||
try:
|
||||
import onnxruntime # type: ignore
|
||||
|
||||
onnxruntime_exists = True
|
||||
except ImportError:
|
||||
onnxruntime_exists = False
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Export the SAM prompt encoder and mask decoder to an ONNX model."
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--checkpoint", type=str, required=True, help="The path to the SAM model checkpoint."
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--output", type=str, required=True, help="The filename to save the ONNX model to."
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--model-type",
|
||||
type=str,
|
||||
required=True,
|
||||
help="In ['default', 'vit_h', 'vit_l', 'vit_b']. Which type of SAM model to export.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--return-single-mask",
|
||||
action="store_true",
|
||||
help=(
|
||||
"If true, the exported ONNX model will only return the best mask, "
|
||||
"instead of returning multiple masks. For high resolution images "
|
||||
"this can improve runtime when upscaling masks is expensive."
|
||||
),
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--opset",
|
||||
type=int,
|
||||
default=17,
|
||||
help="The ONNX opset version to use. Must be >=11",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--quantize-out",
|
||||
type=str,
|
||||
default=None,
|
||||
help=(
|
||||
"If set, will quantize the model and save it with this name. "
|
||||
"Quantization is performed with quantize_dynamic from onnxruntime.quantization.quantize."
|
||||
),
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--gelu-approximate",
|
||||
action="store_true",
|
||||
help=(
|
||||
"Replace GELU operations with approximations using tanh. Useful "
|
||||
"for some runtimes that have slow or unimplemented erf ops, used in GELU."
|
||||
),
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-stability-score",
|
||||
action="store_true",
|
||||
help=(
|
||||
"Replaces the model's predicted mask quality score with the stability "
|
||||
"score calculated on the low resolution masks using an offset of 1.0. "
|
||||
),
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--return-extra-metrics",
|
||||
action="store_true",
|
||||
help=(
|
||||
"The model will return five results: (masks, scores, stability_scores, "
|
||||
"areas, low_res_logits) instead of the usual three. This can be "
|
||||
"significantly slower for high resolution outputs."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def run_export(
|
||||
model_type: str,
|
||||
checkpoint: str,
|
||||
output: str,
|
||||
opset: int,
|
||||
return_single_mask: bool,
|
||||
gelu_approximate: bool = False,
|
||||
use_stability_score: bool = False,
|
||||
return_extra_metrics=False,
|
||||
):
|
||||
print("Loading model...")
|
||||
sam = sam_model_registry[model_type](checkpoint=checkpoint)
|
||||
|
||||
onnx_model = SamOnnxModel(
|
||||
model=sam,
|
||||
return_single_mask=return_single_mask,
|
||||
use_stability_score=use_stability_score,
|
||||
return_extra_metrics=return_extra_metrics,
|
||||
)
|
||||
|
||||
if gelu_approximate:
|
||||
for n, m in onnx_model.named_modules():
|
||||
if isinstance(m, torch.nn.GELU):
|
||||
m.approximate = "tanh"
|
||||
|
||||
dynamic_axes = {
|
||||
"point_coords": {1: "num_points"},
|
||||
"point_labels": {1: "num_points"},
|
||||
}
|
||||
|
||||
embed_dim = sam.prompt_encoder.embed_dim
|
||||
embed_size = sam.prompt_encoder.image_embedding_size
|
||||
mask_input_size = [4 * x for x in embed_size]
|
||||
dummy_inputs = {
|
||||
"image_embeddings": torch.randn(1, embed_dim, *embed_size, dtype=torch.float),
|
||||
"point_coords": torch.randint(low=0, high=1024, size=(1, 5, 2), dtype=torch.float),
|
||||
"point_labels": torch.randint(low=0, high=4, size=(1, 5), dtype=torch.float),
|
||||
"mask_input": torch.randn(1, 1, *mask_input_size, dtype=torch.float),
|
||||
"has_mask_input": torch.tensor([1], dtype=torch.float),
|
||||
"orig_im_size": torch.tensor([1500, 2250], dtype=torch.float),
|
||||
}
|
||||
|
||||
_ = onnx_model(**dummy_inputs)
|
||||
|
||||
output_names = ["masks", "iou_predictions", "low_res_masks"]
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
|
||||
warnings.filterwarnings("ignore", category=UserWarning)
|
||||
with open(output, "wb") as f:
|
||||
print(f"Exporting onnx model to {output}...")
|
||||
torch.onnx.export(
|
||||
onnx_model,
|
||||
tuple(dummy_inputs.values()),
|
||||
f,
|
||||
export_params=True,
|
||||
verbose=False,
|
||||
opset_version=opset,
|
||||
do_constant_folding=True,
|
||||
input_names=list(dummy_inputs.keys()),
|
||||
output_names=output_names,
|
||||
dynamic_axes=dynamic_axes,
|
||||
)
|
||||
|
||||
if onnxruntime_exists:
|
||||
ort_inputs = {k: to_numpy(v) for k, v in dummy_inputs.items()}
|
||||
# set cpu provider default
|
||||
providers = ["CPUExecutionProvider"]
|
||||
ort_session = onnxruntime.InferenceSession(output, providers=providers)
|
||||
_ = ort_session.run(None, ort_inputs)
|
||||
print("Model has successfully been run with ONNXRuntime.")
|
||||
|
||||
|
||||
def to_numpy(tensor):
|
||||
return tensor.cpu().numpy()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
run_export(
|
||||
model_type=args.model_type,
|
||||
checkpoint=args.checkpoint,
|
||||
output=args.output,
|
||||
opset=args.opset,
|
||||
return_single_mask=args.return_single_mask,
|
||||
gelu_approximate=args.gelu_approximate,
|
||||
use_stability_score=args.use_stability_score,
|
||||
return_extra_metrics=args.return_extra_metrics,
|
||||
)
|
||||
|
||||
if args.quantize_out is not None:
|
||||
assert onnxruntime_exists, "onnxruntime is required to quantize the model."
|
||||
from onnxruntime.quantization import QuantType # type: ignore
|
||||
from onnxruntime.quantization.quantize import quantize_dynamic # type: ignore
|
||||
|
||||
print(f"Quantizing model and writing to {args.quantize_out}...")
|
||||
quantize_dynamic(
|
||||
model_input=args.output,
|
||||
model_output=args.quantize_out,
|
||||
optimize_model=True,
|
||||
per_channel=False,
|
||||
reduce_range=False,
|
||||
weight_type=QuantType.QUInt8,
|
||||
)
|
||||
print("Done!")
|
||||
BIN
scripts/mask_visualization.jpg
Normal file
|
After Width: | Height: | Size: 92 KiB |
BIN
scripts/segmentation_result.jpg
Executable file
|
After Width: | Height: | Size: 52 KiB |
257
scripts/smg_box_point.py
Executable file
@@ -0,0 +1,257 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image, ImageDraw, ImageTk
|
||||
import tkinter as tk
|
||||
from tkinter import messagebox
|
||||
from segment_anything import SamPredictor, sam_model_registry
|
||||
|
||||
|
||||
class SAMPointBoxSegmenter:
|
||||
def __init__(self, image_path, sam_checkpoint, model_type="vit_h"):
|
||||
# 加载图像
|
||||
self.image = Image.open(image_path).convert("RGB")
|
||||
self.original_image = self.image.copy()
|
||||
self.result_image = self.image.copy()
|
||||
|
||||
# 初始化SAM模型
|
||||
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
self.sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
|
||||
self.sam.to(device=self.device)
|
||||
self.predictor = SamPredictor(self.sam)
|
||||
|
||||
# 准备图像供SAM使用
|
||||
self.image_np = np.array(self.image)
|
||||
self.predictor.set_image(self.image_np)
|
||||
|
||||
# 交互状态变量
|
||||
self.drawing_box = False # 是否正在绘制框
|
||||
self.start_x, self.start_y = 0, 0 # 框起点
|
||||
self.current_box = None # 当前框坐标 [x1, y1, x2, y2]
|
||||
self.points = [] # 点坐标及标签 [(x, y, label), ...],label=1前景,0背景
|
||||
self.mask = None # 生成的掩码
|
||||
|
||||
# 创建GUI
|
||||
self.root = tk.Tk()
|
||||
self.root.title("SAM 点框组合分割工具")
|
||||
|
||||
# 创建画布
|
||||
self.tk_image = ImageTk.PhotoImage(image=self.image)
|
||||
self.canvas = tk.Canvas(self.root, width=self.image.width, height=self.image.height)
|
||||
self.canvas.create_image(0, 0, image=self.tk_image, anchor=tk.NW)
|
||||
self.canvas.pack()
|
||||
|
||||
# 绑定鼠标事件
|
||||
self.canvas.bind("<ButtonPress-1>", self.on_left_down) # 左键:开始画框或加点
|
||||
self.canvas.bind("<B1-Motion>", self.on_left_drag) # 左键拖动:画框
|
||||
self.canvas.bind("<ButtonRelease-1>", self.on_left_up) # 左键释放:确认框
|
||||
self.canvas.bind("<Button-3>", self.add_background_point) # 右键:添加背景点
|
||||
|
||||
# 控制按钮
|
||||
self.controls_frame = tk.Frame(self.root)
|
||||
self.controls_frame.pack(fill=tk.X, padx=5, pady=5)
|
||||
|
||||
self.clear_box_btn = tk.Button(self.controls_frame, text="清除框", command=self.clear_box)
|
||||
self.clear_box_btn.pack(side=tk.LEFT, padx=5)
|
||||
|
||||
self.clear_points_btn = tk.Button(self.controls_frame, text="清除点", command=self.clear_points)
|
||||
self.clear_points_btn.pack(side=tk.LEFT, padx=5)
|
||||
|
||||
self.clear_all_btn = tk.Button(self.controls_frame, text="清除全部", command=self.clear_all)
|
||||
self.clear_all_btn.pack(side=tk.LEFT, padx=5)
|
||||
|
||||
self.save_btn = tk.Button(self.controls_frame, text="保存结果", command=self.save_result)
|
||||
self.save_btn.pack(side=tk.LEFT, padx=5)
|
||||
|
||||
self.quit_btn = tk.Button(self.controls_frame, text="退出", command=self.root.quit)
|
||||
self.quit_btn.pack(side=tk.RIGHT, padx=5)
|
||||
|
||||
# 操作说明
|
||||
self.info_label = tk.Label(
|
||||
self.root,
|
||||
text="操作说明: 左键拖动画框 | 左键点击加前景点(绿) | 右键点击加背景点(红) | 框和点可组合使用",
|
||||
fg="blue",
|
||||
wraplength=600 # 自动换行
|
||||
)
|
||||
self.info_label.pack(pady=5)
|
||||
|
||||
# ------------------------------ 鼠标事件处理 ------------------------------
|
||||
def on_left_down(self, event):
|
||||
"""左键按下:判断是开始画框还是添加前景点"""
|
||||
# 若已有框且点击位置不在框起点附近,则视为添加前景点
|
||||
if self.current_box is not None:
|
||||
self.points.append((event.x, event.y, 1)) # 前景点(绿色)
|
||||
self.update_segmentation()
|
||||
return
|
||||
|
||||
# 否则视为开始画框
|
||||
self.drawing_box = True
|
||||
self.start_x, self.start_y = event.x, event.y
|
||||
|
||||
def on_left_drag(self, event):
|
||||
"""左键拖动:绘制动态框(仅当正在画框时)"""
|
||||
if not self.drawing_box:
|
||||
return
|
||||
|
||||
# 临时图像上绘制虚线框
|
||||
temp_img = self.original_image.copy()
|
||||
draw = ImageDraw.Draw(temp_img)
|
||||
# 绘制四条虚线边(兼容所有Pillow版本)
|
||||
x1, y1 = self.start_x, self.start_y
|
||||
x2, y2 = event.x, event.y
|
||||
self.draw_dashed_line(draw, x1, y1, x2, y1, [5, 5], "red", 2) # 上边
|
||||
self.draw_dashed_line(draw, x2, y1, x2, y2, [5, 5], "red", 2) # 右边
|
||||
self.draw_dashed_line(draw, x2, y2, x1, y2, [5, 5], "red", 2) # 下边
|
||||
self.draw_dashed_line(draw, x1, y2, x1, y1, [5, 5], "red", 2) # 左边
|
||||
|
||||
# 叠加已有点
|
||||
self._draw_points(draw)
|
||||
|
||||
# 更新显示
|
||||
self.tk_image = ImageTk.PhotoImage(image=temp_img)
|
||||
self.canvas.delete("all")
|
||||
self.canvas.create_image(0, 0, image=self.tk_image, anchor=tk.NW)
|
||||
|
||||
def on_left_up(self, event):
|
||||
"""左键释放:确认框选"""
|
||||
if not self.drawing_box:
|
||||
return
|
||||
|
||||
self.drawing_box = False
|
||||
# 确定框坐标(左上角到右下角)
|
||||
self.current_box = [
|
||||
min(self.start_x, event.x),
|
||||
min(self.start_y, event.y),
|
||||
max(self.start_x, event.x),
|
||||
max(self.start_y, event.y)
|
||||
]
|
||||
self.update_segmentation()
|
||||
|
||||
def add_background_point(self, event):
|
||||
"""右键点击:添加背景点(红色)"""
|
||||
self.points.append((event.x, event.y, 0))
|
||||
self.update_segmentation()
|
||||
|
||||
# ------------------------------ 辅助函数 ------------------------------
|
||||
def draw_dashed_line(self, draw, x1, y1, x2, y2, dash_pattern, color, width):
|
||||
"""绘制虚线(兼容所有Pillow版本)"""
|
||||
dx = x2 - x1
|
||||
dy = y2 - y1
|
||||
length = (dx ** 2 + dy ** 2) ** 0.5
|
||||
if length < 1:
|
||||
draw.line([(x1, y1), (x2, y2)], fill=color, width=width)
|
||||
return
|
||||
ux, uy = dx / length, dy / length # 单位向量
|
||||
current_pos = 0
|
||||
dash_on = True
|
||||
pattern_idx = 0
|
||||
seg_len = dash_pattern[pattern_idx]
|
||||
|
||||
while current_pos < length:
|
||||
end_pos = min(current_pos + seg_len, length)
|
||||
x_start, y_start = x1 + ux * current_pos, y1 + uy * current_pos
|
||||
x_end, y_end = x1 + ux * end_pos, y1 + uy * end_pos
|
||||
if dash_on:
|
||||
draw.line([(x_start, y_start), (x_end, y_end)], fill=color, width=width)
|
||||
dash_on = not dash_on
|
||||
pattern_idx = (pattern_idx + 1) % len(dash_pattern)
|
||||
seg_len = dash_pattern[pattern_idx]
|
||||
current_pos = end_pos
|
||||
|
||||
def _draw_points(self, draw):
|
||||
"""在图像上绘制所有点"""
|
||||
for x, y, label in self.points:
|
||||
color = (0, 255, 0) if label == 1 else (255, 0, 0) # 绿=前景,红=背景
|
||||
draw.ellipse([(x - 5, y - 5), (x + 5, y + 5)], fill=color) # 内部实心
|
||||
draw.ellipse([(x - 7, y - 7), (x + 7, y + 7)], outline=(255, 255, 255), width=2) # 白色边框
|
||||
|
||||
# ------------------------------ 分割逻辑 ------------------------------
|
||||
def update_segmentation(self):
|
||||
"""结合框和点生成掩码"""
|
||||
self.result_image = self.original_image.copy()
|
||||
draw = ImageDraw.Draw(self.result_image)
|
||||
|
||||
# 准备提示信息(框和点)
|
||||
box = np.array(self.current_box) if self.current_box else None
|
||||
point_coords = np.array([(x, y) for x, y, label in self.points]) if self.points else None
|
||||
point_labels = np.array([label for x, y, label in self.points]) if self.points else None
|
||||
|
||||
# 若没有任何提示,直接显示原图
|
||||
if box is None and point_coords is None:
|
||||
self._draw_points(draw)
|
||||
self.update_display()
|
||||
return
|
||||
|
||||
# 调用SAM生成掩码(同时传入框和点)
|
||||
masks, _, _ = self.predictor.predict(
|
||||
box=box,
|
||||
point_coords=point_coords,
|
||||
point_labels=point_labels,
|
||||
multimask_output=False
|
||||
)
|
||||
self.mask = masks[0]
|
||||
|
||||
# 叠加掩码(半透明绿色)
|
||||
mask_array = self.mask.astype(np.uint8) * 128 # 0或128(半透明)
|
||||
mask_img = Image.fromarray(mask_array, mode="L")
|
||||
green_mask = Image.new("RGBA", self.result_image.size, (0, 0, 0, 0))
|
||||
green_draw = ImageDraw.Draw(green_mask)
|
||||
green_draw.bitmap((0, 0), mask_img, fill=(0, 255, 0, 128)) # 绿色半透明
|
||||
self.result_image = Image.alpha_composite(
|
||||
self.result_image.convert("RGBA"),
|
||||
green_mask
|
||||
).convert("RGB")
|
||||
|
||||
# 绘制框和点
|
||||
if self.current_box:
|
||||
draw.rectangle(self.current_box, outline="red", width=2) # 实线框
|
||||
self._draw_points(draw)
|
||||
|
||||
self.update_display()
|
||||
|
||||
# ------------------------------ 控制函数 ------------------------------
|
||||
def update_display(self):
|
||||
"""更新画布显示"""
|
||||
self.tk_image = ImageTk.PhotoImage(image=self.result_image)
|
||||
self.canvas.delete("all")
|
||||
self.canvas.create_image(0, 0, image=self.tk_image, anchor=tk.NW)
|
||||
|
||||
def clear_box(self):
|
||||
"""清除当前框"""
|
||||
self.current_box = None
|
||||
self.update_segmentation()
|
||||
|
||||
def clear_points(self):
|
||||
"""清除所有点"""
|
||||
self.points = []
|
||||
self.update_segmentation()
|
||||
|
||||
def clear_all(self):
|
||||
"""清除框和点"""
|
||||
self.current_box = None
|
||||
self.points = []
|
||||
self.mask = None
|
||||
self.result_image = self.original_image.copy()
|
||||
self.update_display()
|
||||
|
||||
def save_result(self):
|
||||
"""保存结果"""
|
||||
try:
|
||||
self.result_image.save("combined_segmentation_result.jpg")
|
||||
messagebox.showinfo("成功", "结果已保存为 combined_segmentation_result.jpg")
|
||||
except Exception as e:
|
||||
messagebox.showerror("错误", f"保存失败: {str(e)}")
|
||||
|
||||
def run(self):
|
||||
"""启动GUI"""
|
||||
self.root.mainloop()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 配置路径
|
||||
IMAGE_PATH = "ae976103-d7ec-4eed-b5d1-3e5f04d8be26.jpg" # 替换为你的图像路径
|
||||
SAM_CHECKPOINT = "/home/alab/PycharmProjects/segment-anything/checkpoint/sam_vit_h_4b8939.pth" # 替换为你的SAM模型路径
|
||||
MODEL_TYPE = "vit_h" # 模型类型,与checkpoint对应
|
||||
|
||||
# 运行工具
|
||||
segmenter = SAMPointBoxSegmenter(IMAGE_PATH, SAM_CHECKPOINT, MODEL_TYPE)
|
||||
segmenter.run()
|
||||
42
scripts/test.py
Executable file
@@ -0,0 +1,42 @@
|
||||
import time
|
||||
import random
|
||||
import cv2
|
||||
import numpy as np
|
||||
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
|
||||
|
||||
# 初始化SAM模型
|
||||
sam = sam_model_registry["vit_h"](checkpoint="/workspace/PycharmProjects/segment-anything/checkpoint/sam_vit_h_4b8939.pth")
|
||||
mask_generator = SamAutomaticMaskGenerator(sam)
|
||||
|
||||
# 读取并转换图像格式
|
||||
image = cv2.imread("ae976103-d7ec-4eed-b5d1-3e5f04d8be26.jpg")
|
||||
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # SAM需要RGB格式
|
||||
image_copy = image.copy() # 用于绘制掩码的原图副本(BGR格式,适配OpenCV)
|
||||
|
||||
# 生成掩码并计时
|
||||
start_time = time.time()
|
||||
masks = mask_generator.generate(image_rgb)
|
||||
print(f"掩码生成耗时: {time.time() - start_time:.2f} 秒")
|
||||
print(f"共生成 {len(masks)} 个掩码")
|
||||
|
||||
# 定义颜色生成函数(随机RGB颜色,适配OpenCV的BGR格式)
|
||||
def get_random_color():
|
||||
return (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
|
||||
|
||||
# 遍历所有掩码并绘制到图像上
|
||||
for mask in masks:
|
||||
# 获取掩码的二进制数组(True/False)
|
||||
mask_array = mask["segmentation"]
|
||||
# 生成随机颜色
|
||||
color = get_random_color()
|
||||
# 将掩码区域绘制到图像副本上(半透明效果)
|
||||
image_copy[mask_array] = image_copy[mask_array] * 0.5 + np.array(color) * 0.5
|
||||
|
||||
# 保存结果图像
|
||||
cv2.imwrite("mask_visualization.jpg", image_copy)
|
||||
|
||||
# 显示结果(如果是桌面环境)
|
||||
cv2.namedWindow("SAM Mask Visualization", cv2.WINDOW_NORMAL)
|
||||
cv2.imshow("SAM Mask Visualization", image_copy)
|
||||
cv2.waitKey(0)
|
||||
cv2.destroyAllWindows()
|
||||