Files
aida_seg_anything/scripts/smg_box_point.py
2026-01-08 17:02:33 +08:00

258 lines
10 KiB
Python
Executable File
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import numpy as np
import torch
from PIL import Image, ImageDraw, ImageTk
import tkinter as tk
from tkinter import messagebox
from segment_anything import SamPredictor, sam_model_registry
class SAMPointBoxSegmenter:
    """Interactive Tk tool that segments one image with SAM from box and point prompts.

    Left-drag draws a box, a left click adds a foreground point, and a right
    click adds a background point. The box and all points are passed to
    ``SamPredictor.predict`` together; the resulting mask is shown as a
    semi-transparent green overlay with the prompts drawn on top.
    """

    # Maximum press-to-release movement (pixels, per axis) for a left press
    # with no existing box to count as a click (foreground point) rather than
    # a box drag.
    CLICK_TOLERANCE = 3

    def __init__(self, image_path, sam_checkpoint, model_type="vit_h"):
        """Load the image and SAM model, prepare the predictor, and build the GUI.

        Args:
            image_path: Path of the image to segment; it is converted to RGB.
            sam_checkpoint: Path of the SAM checkpoint file.
            model_type: Key into ``sam_model_registry``; must match the checkpoint.
        """
        # Load the image; original_image stays untouched so overlays can be
        # redrawn from scratch on every update.
        self.image = Image.open(image_path).convert("RGB")
        self.original_image = self.image.copy()
        self.result_image = self.image.copy()

        # Build SAM on the GPU when one is available, otherwise on the CPU.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
        self.sam.to(device=self.device)
        self.predictor = SamPredictor(self.sam)

        # set_image is called once here; update_segmentation only calls predict().
        self.image_np = np.array(self.image)
        self.predictor.set_image(self.image_np)

        # Interaction state.
        self.drawing_box = False            # True while a box drag is in progress
        self.start_x, self.start_y = 0, 0   # press position of the current drag
        self.current_box = None             # confirmed box [x1, y1, x2, y2] or None
        self.points = []                    # [(x, y, label)]; label 1 = foreground, 0 = background
        self.mask = None                    # last predicted mask, or None without prompts

        self._build_gui()

    def _build_gui(self):
        """Create the window, image canvas, mouse bindings, buttons and help label."""
        self.root = tk.Tk()
        self.root.title("SAM 点框组合分割工具")

        # highlightthickness=0 makes the canvas widget exactly image-sized, so
        # Tk's focus-highlight ring is not drawn over the image edges.
        self.tk_image = ImageTk.PhotoImage(image=self.image)
        self.canvas = tk.Canvas(
            self.root,
            width=self.image.width,
            height=self.image.height,
            highlightthickness=0,
        )
        self.canvas.create_image(0, 0, image=self.tk_image, anchor=tk.NW)
        self.canvas.pack()

        # Mouse bindings: left press/drag/release for boxes and foreground
        # points, right click for background points.
        self.canvas.bind("<ButtonPress-1>", self.on_left_down)
        self.canvas.bind("<B1-Motion>", self.on_left_drag)
        self.canvas.bind("<ButtonRelease-1>", self.on_left_up)
        self.canvas.bind("<Button-3>", self.add_background_point)

        # Control buttons.
        self.controls_frame = tk.Frame(self.root)
        self.controls_frame.pack(fill=tk.X, padx=5, pady=5)
        self.clear_box_btn = tk.Button(self.controls_frame, text="清除框", command=self.clear_box)
        self.clear_box_btn.pack(side=tk.LEFT, padx=5)
        self.clear_points_btn = tk.Button(self.controls_frame, text="清除点", command=self.clear_points)
        self.clear_points_btn.pack(side=tk.LEFT, padx=5)
        self.clear_all_btn = tk.Button(self.controls_frame, text="清除全部", command=self.clear_all)
        self.clear_all_btn.pack(side=tk.LEFT, padx=5)
        self.save_btn = tk.Button(self.controls_frame, text="保存结果", command=self.save_result)
        self.save_btn.pack(side=tk.LEFT, padx=5)
        self.quit_btn = tk.Button(self.controls_frame, text="退出", command=self.root.quit)
        self.quit_btn.pack(side=tk.RIGHT, padx=5)

        # Usage instructions (wrapped to 600 px).
        self.info_label = tk.Label(
            self.root,
            text="操作说明: 左键拖动画框 | 左键点击加前景点(绿) | 右键点击加背景点(红) | 框和点可组合使用",
            fg="blue",
            wraplength=600,
        )
        self.info_label.pack(pady=5)

    # ------------------------------ mouse event handlers ------------------------------
    def _clamp(self, x, y):
        """Clamp canvas coordinates to valid pixel indices of the image."""
        return (
            min(max(x, 0), self.image.width - 1),
            min(max(y, 0), self.image.height - 1),
        )

    def on_left_down(self, event):
        """Left press: add a foreground point if a box exists, otherwise start a box drag."""
        if self.current_box is not None:
            x, y = self._clamp(event.x, event.y)
            self.points.append((x, y, 1))  # foreground point (drawn green)
            self.update_segmentation()
            return
        self.drawing_box = True
        self.start_x, self.start_y = self._clamp(event.x, event.y)

    def on_left_drag(self, event):
        """Left drag: preview the box being drawn as a dashed red rectangle."""
        if not self.drawing_box:
            return
        preview = self.original_image.copy()
        draw = ImageDraw.Draw(preview)
        x1, y1 = self.start_x, self.start_y
        # Clamp so the preview matches the box that on_left_up will confirm.
        x2, y2 = self._clamp(event.x, event.y)
        edges = (
            (x1, y1, x2, y1),  # top
            (x2, y1, x2, y2),  # right
            (x2, y2, x1, y2),  # bottom
            (x1, y2, x1, y1),  # left
        )
        for ax, ay, bx, by in edges:
            self.draw_dashed_line(draw, ax, ay, bx, by, [5, 5], "red", 2)
        self._draw_points(draw)
        self._show(preview)

    def on_left_up(self, event):
        """Left release: confirm the dragged box, or add a foreground point for a plain click."""
        if not self.drawing_box:
            return
        self.drawing_box = False
        x0, y0 = self._clamp(self.start_x, self.start_y)
        x1, y1 = self._clamp(event.x, event.y)
        if abs(x1 - x0) <= self.CLICK_TOLERANCE and abs(y1 - y0) <= self.CLICK_TOLERANCE:
            # A press/release with (almost) no movement would yield a zero-area
            # box; treat it as a foreground point, as the on-screen help says.
            self.points.append((x0, y0, 1))
        else:
            # Normalise to top-left / bottom-right corners.
            self.current_box = [min(x0, x1), min(y0, y1), max(x0, x1), max(y0, y1)]
        self.update_segmentation()

    def add_background_point(self, event):
        """Right click: add a background point (label 0, drawn red)."""
        x, y = self._clamp(event.x, event.y)
        self.points.append((x, y, 0))
        self.update_segmentation()

    # ------------------------------ helpers ------------------------------
    def draw_dashed_line(self, draw, x1, y1, x2, y2, dash_pattern, color, width):
        """Draw a dashed line from (x1, y1) to (x2, y2) using only ``draw.line``.

        Relying solely on ``line`` keeps this independent of the Pillow version.

        Args:
            draw: An ``ImageDraw.Draw`` (or any object with a compatible ``line``).
            dash_pattern: Alternating on/off segment lengths in pixels, starting "on".
            color: Passed to ``draw.line`` as ``fill``.
            width: Passed to ``draw.line`` as ``width``.

        Raises:
            ValueError: If ``dash_pattern`` is empty, contains a negative length,
                or sums to zero; such patterns would never advance along the line.
        """
        if not dash_pattern or min(dash_pattern) < 0 or sum(dash_pattern) <= 0:
            raise ValueError("dash_pattern must hold non-negative lengths with a positive total")
        dx = x2 - x1
        dy = y2 - y1
        length = (dx ** 2 + dy ** 2) ** 0.5
        if length < 1:
            # Too short to dash: draw it solid.
            draw.line([(x1, y1), (x2, y2)], fill=color, width=width)
            return
        ux, uy = dx / length, dy / length  # unit direction vector
        current_pos = 0
        dash_on = True
        pattern_idx = 0
        seg_len = dash_pattern[pattern_idx]
        while current_pos < length:
            end_pos = min(current_pos + seg_len, length)
            if dash_on:
                start = (x1 + ux * current_pos, y1 + uy * current_pos)
                end = (x1 + ux * end_pos, y1 + uy * end_pos)
                draw.line([start, end], fill=color, width=width)
            dash_on = not dash_on
            pattern_idx = (pattern_idx + 1) % len(dash_pattern)
            seg_len = dash_pattern[pattern_idx]
            current_pos = end_pos

    def _draw_points(self, draw):
        """Draw every prompt point: filled green (foreground) or red (background) with a white ring."""
        for x, y, label in self.points:
            color = (0, 255, 0) if label == 1 else (255, 0, 0)
            draw.ellipse([(x - 5, y - 5), (x + 5, y + 5)], fill=color)
            draw.ellipse([(x - 7, y - 7), (x + 7, y + 7)], outline=(255, 255, 255), width=2)

    @staticmethod
    def _blend_mask(image_np, mask, color=(0, 255, 0), alpha=128):
        """Return a copy of an RGB image with ``color`` alpha-blended over the masked pixels.

        Args:
            image_np: ``(H, W, 3)`` uint8 RGB array; it is not modified.
            mask: ``(H, W)`` array, truthy where the overlay should be applied.
            color: Overlay RGB colour.
            alpha: Overlay opacity in 0..255 (128 is about half).

        Returns:
            A new ``(H, W, 3)`` uint8 array.
        """
        out = image_np.copy()
        selected = np.asarray(mask, dtype=bool)
        alpha = int(alpha)
        base = image_np[selected].astype(np.int32)
        tint = np.asarray(color, dtype=np.int32)
        # Integer "over" compositing with rounding: (base*(255-a) + tint*a) / 255.
        out[selected] = ((base * (255 - alpha) + tint * alpha + 127) // 255).astype(np.uint8)
        return out

    # ------------------------------ segmentation ------------------------------
    def update_segmentation(self):
        """Run SAM with the current box and points and redraw the displayed result.

        Without any prompt the mask is cleared and the original image is shown.
        Otherwise the single predicted mask is blended in green, and the box
        (solid red) and the points are drawn on top.
        """
        box = np.array(self.current_box) if self.current_box is not None else None
        if self.points:
            point_coords = np.array([(x, y) for x, y, _ in self.points])
            point_labels = np.array([label for _, _, label in self.points])
        else:
            point_coords = point_labels = None

        if box is None and point_coords is None:
            self.mask = None
            self.result_image = self.original_image.copy()
        else:
            # Box and points are passed together in one prediction.
            masks, _, _ = self.predictor.predict(
                box=box,
                point_coords=point_coords,
                point_labels=point_labels,
                multimask_output=False,
            )
            self.mask = masks[0]
            self.result_image = Image.fromarray(self._blend_mask(self.image_np, self.mask))

        # Create the drawing context only now, on the image that is actually
        # displayed; a context created before the overlay would draw on a
        # discarded copy.
        draw = ImageDraw.Draw(self.result_image)
        if self.current_box is not None:
            draw.rectangle(self.current_box, outline="red", width=2)
        self._draw_points(draw)
        self.update_display()

    # ------------------------------ controls ------------------------------
    def _show(self, image):
        """Replace the canvas contents with ``image``."""
        # Keep the PhotoImage on self so it stays referenced while displayed.
        self.tk_image = ImageTk.PhotoImage(image=image)
        self.canvas.delete("all")
        self.canvas.create_image(0, 0, image=self.tk_image, anchor=tk.NW)

    def update_display(self):
        """Show the current result image on the canvas."""
        self._show(self.result_image)

    def clear_box(self):
        """Remove the box and re-segment with the remaining points."""
        self.current_box = None
        self.update_segmentation()

    def clear_points(self):
        """Remove all points and re-segment with the remaining box."""
        self.points = []
        self.update_segmentation()

    def clear_all(self):
        """Remove the box, all points and the mask, and show the original image."""
        self.current_box = None
        self.points = []
        self.mask = None
        self.result_image = self.original_image.copy()
        self.update_display()

    def save_result(self, path="combined_segmentation_result.jpg"):
        """Save the displayed result image to ``path`` and report the outcome in a dialog.

        Args:
            path: Output file; the format is chosen by Pillow from the extension.
        """
        try:
            self.result_image.save(path)
        except Exception as e:  # GUI boundary: report any save failure instead of raising
            messagebox.showerror("错误", f"保存失败: {str(e)}")
        else:
            messagebox.showinfo("成功", f"结果已保存为 {path}")

    def run(self):
        """Enter the Tk main loop (blocks until the window quits)."""
        self.root.mainloop()
if __name__ == "__main__":
    import argparse

    # Command-line entry point. The defaults are the original hard-coded
    # configuration, so running without arguments behaves as before.
    parser = argparse.ArgumentParser(
        description="Interactive SAM segmentation with combined box and point prompts."
    )
    parser.add_argument(
        "--image",
        default="ae976103-d7ec-4eed-b5d1-3e5f04d8be26.jpg",
        help="path of the image to segment",
    )
    parser.add_argument(
        "--checkpoint",
        default="/mnt/data/workspace/Code/aida_seg_anything/checkpoint/sam_vit_h_4b8939.pth",
        help="path of the SAM checkpoint file",
    )
    parser.add_argument(
        "--model-type",
        default="vit_h",
        help="SAM model type; must match the checkpoint",
    )
    args = parser.parse_args()

    segmenter = SAMPointBoxSegmenter(args.image, args.checkpoint, args.model_type)
    segmenter.run()