"""Interactive box-prompt segmentation tool built on Segment Anything (SAM).

The user drags a rectangle over the image with the left mouse button; when the
button is released the rectangle is passed to SAM as a box prompt and the
predicted mask is overlaid on the image in translucent green.
"""

import tkinter as tk
from tkinter import messagebox

import numpy as np
import torch
from PIL import Image, ImageDraw, ImageTk
from segment_anything import SamPredictor, sam_model_registry


class SAMBoxSegmenter:
    """Tkinter window that segments the region inside a hand-drawn box.

    Args:
        image_path: Path of the image to load; it is converted to RGB.
        sam_checkpoint: Path of the SAM checkpoint file.
        model_type: Key into ``sam_model_registry``; must match the checkpoint.
    """

    def __init__(self, image_path, sam_checkpoint, model_type="vit_h"):
        # Load the image and keep pristine / result copies.
        self.image = Image.open(image_path).convert("RGB")
        self.original_image = self.image.copy()
        self.result_image = self.image.copy()

        # Initialise the SAM model on the GPU when one is available.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(self.device)
        self.sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
        self.sam.to(device=self.device)
        self.predictor = SamPredictor(self.sam)

        # Hand the image to the predictor once; later predictions reuse it.
        self.image_np = np.array(self.image)
        self.predictor.set_image(self.image_np)

        # Box-selection state.
        self.drawing = False
        self.start_x = 0
        self.start_y = 0
        self.current_box = None
        self.mask = None

        # Build the GUI.
        self.root = tk.Tk()
        self.root.title("SAM 手绘框分割工具")

        # Canvas showing the image. The PhotoImage is stored on self because
        # Tk does not keep its own reference to it.
        self.tk_image = ImageTk.PhotoImage(image=self.image)
        self.canvas = tk.Canvas(
            self.root, width=self.image.width, height=self.image.height
        )
        self.canvas.create_image(0, 0, image=self.tk_image, anchor=tk.NW)
        self.canvas.pack()

        # Mouse bindings. An empty event sequence is rejected by Tk, so the
        # left-button press / drag / release events are named explicitly.
        self.canvas.bind("<ButtonPress-1>", self.on_mouse_down)
        self.canvas.bind("<B1-Motion>", self.on_mouse_drag)
        self.canvas.bind("<ButtonRelease-1>", self.on_mouse_up)

        # Control buttons.
        self.controls_frame = tk.Frame(self.root)
        self.controls_frame.pack(fill=tk.X, padx=5, pady=5)

        self.clear_btn = tk.Button(
            self.controls_frame, text="清除框选", command=self.clear_box
        )
        self.clear_btn.pack(side=tk.LEFT, padx=5)

        self.save_btn = tk.Button(
            self.controls_frame, text="保存结果", command=self.save_result
        )
        self.save_btn.pack(side=tk.LEFT, padx=5)

        self.quit_btn = tk.Button(
            self.controls_frame, text="退出", command=self.root.quit
        )
        self.quit_btn.pack(side=tk.RIGHT, padx=5)

        # Usage hint.
        self.info_label = tk.Label(
            self.root,
            text="操作说明: 按住鼠标左键拖动绘制框选区域,松开后自动分割",
            fg="blue",
        )
        self.info_label.pack(pady=5)

    def _clamp_to_image(self, x, y):
        """Return ``(x, y)`` limited to valid pixel coordinates of the image."""
        max_x = self.original_image.width - 1
        max_y = self.original_image.height - 1
        return min(max(x, 0), max_x), min(max(y, 0), max_y)

    def on_mouse_down(self, event):
        """Record the starting corner of the box when the button is pressed."""
        self.drawing = True
        self.start_x = event.x
        self.start_y = event.y

    def on_mouse_drag(self, event):
        """Redraw the image with a dashed red preview box while dragging."""
        if not self.drawing:
            return

        # Draw on a throw-away copy so the original stays clean.
        temp_image = self.original_image.copy()
        draw = ImageDraw.Draw(temp_image)

        # Clamp both corners so the preview matches the box that
        # on_mouse_up will eventually use.
        x1, y1 = self._clamp_to_image(self.start_x, self.start_y)
        x2, y2 = self._clamp_to_image(event.x, event.y)

        # Four dashed edges: 5 px drawn, 5 px skipped.
        dash_pattern = [5, 5]
        self.draw_dashed_line(draw, x1, y1, x2, y1, dash_pattern, "red", 2)  # top
        self.draw_dashed_line(draw, x2, y1, x2, y2, dash_pattern, "red", 2)  # right
        self.draw_dashed_line(draw, x2, y2, x1, y2, dash_pattern, "red", 2)  # bottom
        self.draw_dashed_line(draw, x1, y2, x1, y1, dash_pattern, "red", 2)  # left

        # Refresh the canvas.
        self.tk_image = ImageTk.PhotoImage(image=temp_image)
        self.canvas.delete("all")
        self.canvas.create_image(0, 0, image=self.tk_image, anchor=tk.NW)

    def draw_dashed_line(self, draw, x1, y1, x2, y2, dash_pattern, color, width):
        """Draw a dashed line from ``(x1, y1)`` to ``(x2, y2)``.

        Args:
            draw: Object with a ``line(points, fill=..., width=...)`` method,
                such as the result of ``ImageDraw.Draw``.
            x1, y1: Start point.
            x2, y2: End point.
            dash_pattern: Alternating segment lengths, starting with a drawn
                segment (for example ``[5, 5]``).
            color: Line colour passed through as ``fill``.
            width: Line width passed through as ``width``.

        A solid line is drawn when the line is shorter than one pixel, or when
        the pattern is empty or contains a non-positive length (such a pattern
        could otherwise keep the loop from ever advancing).
        """
        dx = x2 - x1
        dy = y2 - y1
        length = (dx ** 2 + dy ** 2) ** 0.5

        unusable_pattern = not dash_pattern or any(
            segment <= 0 for segment in dash_pattern
        )
        if length < 1 or unusable_pattern:
            draw.line([(x1, y1), (x2, y2)], fill=color, width=width)
            return

        # Unit direction vector.
        ux = dx / length
        uy = dy / length

        current_pos = 0
        dash_on = True  # the pattern starts with a drawn segment
        pattern_index = 0

        while current_pos < length:
            # Never run past the end of the line.
            end_pos = min(current_pos + dash_pattern[pattern_index], length)

            if dash_on:
                start_point = (x1 + ux * current_pos, y1 + uy * current_pos)
                end_point = (x1 + ux * end_pos, y1 + uy * end_pos)
                draw.line([start_point, end_point], fill=color, width=width)

            # Toggle between dash and gap and move to the next pattern entry.
            dash_on = not dash_on
            pattern_index = (pattern_index + 1) % len(dash_pattern)
            current_pos = end_pos

    def on_mouse_up(self, event):
        """Finalise the box on button release and run the segmentation."""
        if not self.drawing:
            return

        self.drawing = False

        # The pointer may have left the canvas during the drag; keep the box
        # inside the image.
        start_x, start_y = self._clamp_to_image(self.start_x, self.start_y)
        end_x, end_y = self._clamp_to_image(event.x, event.y)

        # Normalise to (left, top, right, bottom).
        left, right = sorted((start_x, end_x))
        top, bottom = sorted((start_y, end_y))

        # A click without a drag gives a zero-area box, which is not a useful
        # prompt; drop it and show the previous result again.
        if left == right or top == bottom:
            self.update_display()
            return

        self.current_box = [left, top, right, bottom]
        self.update_segmentation()

    def update_segmentation(self):
        """Recompute the mask for the current box and refresh the display."""
        print('update_segmentation')

        if not self.current_box:
            self.result_image = self.original_image.copy()
            self.update_display()
            return

        # Ask SAM for a single mask for the box prompt.
        box = np.array(self.current_box)
        masks, _, _ = self.predictor.predict(
            box=box,
            multimask_output=False
        )
        self.mask = masks[0]

        # 0 outside the mask, 128 inside; used as the paint mask below.
        # The image mode is inferred ("L") from the 2-D uint8 array.
        mask_array = self.mask.astype(np.uint8) * 128
        mask_image = Image.fromarray(mask_array)

        # Paint translucent green through the mask onto a transparent layer.
        green_mask = Image.new("RGBA", self.original_image.size, (0, 0, 0, 0))
        green_draw = ImageDraw.Draw(green_mask)
        green_draw.bitmap((0, 0), mask_image, fill=(0, 255, 0, 128))

        # Composite the layer over a fresh copy of the original image.
        self.result_image = Image.alpha_composite(
            self.original_image.convert("RGBA"),
            green_mask
        ).convert("RGB")

        # Outline the final box with a solid red rectangle.
        draw = ImageDraw.Draw(self.result_image)
        draw.rectangle(self.current_box, outline="red", width=2)

        self.update_display()

    def update_display(self):
        """Show ``self.result_image`` on the canvas."""
        self.tk_image = ImageTk.PhotoImage(image=self.result_image)
        self.canvas.delete("all")
        self.canvas.create_image(0, 0, image=self.tk_image, anchor=tk.NW)

    def clear_box(self):
        """Discard the current box and mask and show the original image."""
        self.current_box = None
        self.mask = None
        self.result_image = self.original_image.copy()
        self.update_display()

    def save_result(self):
        """Save the result image to ``box_segmentation_result.jpg``.

        Success or failure is reported to the user in a message box.
        """
        try:
            self.result_image.save("box_segmentation_result.jpg")
        except Exception as e:  # GUI boundary: report instead of crashing
            messagebox.showerror("错误", f"保存失败: {e}")
        else:
            messagebox.showinfo("成功", "分割结果已保存为 box_segmentation_result.jpg")

    def run(self):
        """Enter the Tk main loop."""
        self.root.mainloop()


if __name__ == "__main__":
    # Configuration.
    IMAGE_PATH = "ae976103-d7ec-4eed-b5d1-3e5f04d8be26.jpg"  # replace with your image path
    SAM_CHECKPOINT = "/mnt/data/workspace/Code/aida_seg_anything/checkpoint/sam_vit_h_4b8939.pth"  # replace with your SAM checkpoint path
    MODEL_TYPE = "vit_h"  # model type; must match the checkpoint

    # Create and run the segmenter.
    segmenter = SAMBoxSegmenter(IMAGE_PATH, SAM_CHECKPOINT, MODEL_TYPE)
    segmenter.run()