Files
aida_seg_anything/scripts/amg_box.py
2026-01-08 17:02:33 +08:00

240 lines
8.4 KiB
Python
Executable File
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import numpy as np
import torch
from PIL import Image, ImageDraw, ImageTk
import tkinter as tk
from tkinter import messagebox
from segment_anything import SamPredictor, sam_model_registry
class SAMBoxSegmenter:
    """Tk GUI that segments an image with SAM from a user-drawn bounding box.

    The user drags a rectangle on the canvas with the left mouse button; on
    release the box is passed to ``SamPredictor.predict`` and the returned mask
    is shown as a semi-transparent green overlay with a red box outline.
    """

    def __init__(self, image_path, sam_checkpoint, model_type="vit_h"):
        """Load the image and SAM model, then build the Tk window.

        Args:
            image_path: Path of the image to segment (converted to RGB).
            sam_checkpoint: Path of the SAM checkpoint file.
            model_type: Key into ``sam_model_registry``; must match the checkpoint.
        """
        # Load the image; original_image stays untouched and is the base of every redraw.
        self.image = Image.open(image_path).convert("RGB")
        self.original_image = self.image.copy()
        self.result_image = self.image.copy()

        # Initialise the SAM model, on the GPU when one is available.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(self.device)
        self.sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
        self.sam.to(device=self.device)
        self.predictor = SamPredictor(self.sam)

        # Hand the image to the predictor once; every box prompt below reuses it.
        self.image_np = np.array(self.image)
        self.predictor.set_image(self.image_np)

        # Box-drawing state.
        self.drawing = False
        self.start_x = 0
        self.start_y = 0
        self.current_box = None  # [x0, y0, x1, y1] in image pixels, or None
        self.mask = None

        # Build the GUI.
        self.root = tk.Tk()
        self.root.title("SAM 手绘框分割工具")

        # The canvas is exactly image-sized, so canvas coordinates are image pixels.
        self.tk_image = ImageTk.PhotoImage(image=self.image)
        self.canvas = tk.Canvas(self.root, width=self.image.width, height=self.image.height)
        self.canvas.create_image(0, 0, image=self.tk_image, anchor=tk.NW)
        self.canvas.pack()

        # Mouse bindings: press starts a box, drag previews it, release segments.
        self.canvas.bind("<ButtonPress-1>", self.on_mouse_down)
        self.canvas.bind("<B1-Motion>", self.on_mouse_drag)
        self.canvas.bind("<ButtonRelease-1>", self.on_mouse_up)

        # Control buttons.
        self.controls_frame = tk.Frame(self.root)
        self.controls_frame.pack(fill=tk.X, padx=5, pady=5)
        self.clear_btn = tk.Button(self.controls_frame, text="清除框选", command=self.clear_box)
        self.clear_btn.pack(side=tk.LEFT, padx=5)
        self.save_btn = tk.Button(self.controls_frame, text="保存结果", command=self.save_result)
        self.save_btn.pack(side=tk.LEFT, padx=5)
        self.quit_btn = tk.Button(self.controls_frame, text="退出", command=self.root.quit)
        self.quit_btn.pack(side=tk.RIGHT, padx=5)

        # Usage hint label.
        self.info_label = tk.Label(
            self.root,
            text="操作说明: 按住鼠标左键拖动绘制框选区域,松开后自动分割",
            fg="blue",
        )
        self.info_label.pack(pady=5)

    def _show_image(self, image):
        """Replace the canvas contents with ``image``."""
        # Keep the PhotoImage on self so it is not garbage-collected while shown.
        self.tk_image = ImageTk.PhotoImage(image=image)
        self.canvas.delete("all")
        self.canvas.create_image(0, 0, image=self.tk_image, anchor=tk.NW)

    def on_mouse_down(self, event):
        """Record the drag start point."""
        self.drawing = True
        self.start_x = event.x
        self.start_y = event.y

    def on_mouse_drag(self, event):
        """Draw a live dashed preview of the box while the mouse is dragged."""
        if not self.drawing:
            return
        # Preview on a fresh copy of the original so old previews do not accumulate.
        preview = self.original_image.copy()
        draw = ImageDraw.Draw(preview)
        x1, y1 = self.start_x, self.start_y
        x2, y2 = event.x, event.y
        dash_pattern = [5, 5]  # 5 px drawn, 5 px gap
        self.draw_dashed_line(draw, x1, y1, x2, y1, dash_pattern, "red", 2)  # top
        self.draw_dashed_line(draw, x2, y1, x2, y2, dash_pattern, "red", 2)  # right
        self.draw_dashed_line(draw, x2, y2, x1, y2, dash_pattern, "red", 2)  # bottom
        self.draw_dashed_line(draw, x1, y2, x1, y1, dash_pattern, "red", 2)  # left
        self._show_image(preview)

    def draw_dashed_line(self, draw, x1, y1, x2, y2, dash_pattern, color, width):
        """Draw a dashed line from (x1, y1) to (x2, y2) using only ``draw.line``.

        Args:
            draw: Object with a Pillow ``ImageDraw.line``-compatible ``line`` method.
            x1, y1, x2, y2: Line end points.
            dash_pattern: Alternating on/off segment lengths, starting with "on".
            color: Line colour passed through to ``draw.line``.
            width: Line width passed through to ``draw.line``.

        Raises:
            ValueError: if ``dash_pattern`` is empty, has a negative entry, or sums
                to zero (any of which would otherwise never advance along the line).
        """
        dx = x2 - x1
        dy = y2 - y1
        length = (dx ** 2 + dy ** 2) ** 0.5

        # Lines shorter than a pixel are simply drawn solid.
        if length < 1:
            draw.line([(x1, y1), (x2, y2)], fill=color, width=width)
            return

        if not dash_pattern or any(seg < 0 for seg in dash_pattern) or sum(dash_pattern) <= 0:
            raise ValueError(
                "dash_pattern must hold non-negative lengths with a positive total"
            )

        # Unit direction vector along the line.
        ux = dx / length
        uy = dy / length

        current_pos = 0
        dash_on = True  # the pattern starts with a drawn segment
        pattern_index = 0
        while current_pos < length:
            end_pos = min(current_pos + dash_pattern[pattern_index], length)
            if dash_on:
                draw.line(
                    [
                        (x1 + ux * current_pos, y1 + uy * current_pos),
                        (x1 + ux * end_pos, y1 + uy * end_pos),
                    ],
                    fill=color,
                    width=width,
                )
            dash_on = not dash_on
            pattern_index = (pattern_index + 1) % len(dash_pattern)
            current_pos = end_pos

    def _clamped_box(self, x_a, y_a, x_b, y_b, min_size=2):
        """Order two drag corners into ``[x0, y0, x1, y1]`` clipped to the image.

        Drag events can report coordinates outside the canvas (e.g. negative),
        so both corners are clipped to valid pixel indices first.

        Returns:
            The ordered box, or None when the clipped box is narrower or shorter
            than ``min_size`` pixels (e.g. a plain click), which is not a usable prompt.
        """
        max_x = self.image.width - 1
        max_y = self.image.height - 1
        x0, x1 = sorted(min(max(x, 0), max_x) for x in (x_a, x_b))
        y0, y1 = sorted(min(max(y, 0), max_y) for y in (y_a, y_b))
        if x1 - x0 < min_size or y1 - y0 < min_size:
            return None
        return [x0, y0, x1, y1]

    def on_mouse_up(self, event):
        """Finish the drag: store the box and run segmentation on it."""
        if not self.drawing:
            return
        self.drawing = False
        box = self._clamped_box(self.start_x, self.start_y, event.x, event.y)
        if box is None:
            # No usable box: erase the dashed preview and keep the previous result.
            self.update_display()
            return
        self.current_box = box
        self.update_segmentation()

    @staticmethod
    def _blend_mask(rgb, mask, color=(0, 255, 0), alpha=128 / 255):
        """Return a uint8 copy of ``rgb`` with ``color`` alpha-blended where ``mask`` is true.

        Args:
            rgb: (H, W, 3) image array.
            mask: (H, W) array, truthy where the overlay is applied.
            color: Overlay RGB colour.
            alpha: Overlay opacity in [0, 1].
        """
        out = np.array(rgb, dtype=np.uint8, copy=True)
        mask = np.asarray(mask, dtype=bool)
        # Only masked pixels are promoted to float, keeping memory proportional to the mask.
        selected = out[mask].astype(np.float64)
        ink = np.asarray(color, dtype=np.float64)
        blended = selected * (1.0 - alpha) + ink * alpha
        out[mask] = np.clip(np.rint(blended), 0, 255).astype(np.uint8)
        return out

    def update_segmentation(self):
        """Run SAM on the current box and redraw the mask overlay and box outline."""
        if not self.current_box:
            self.result_image = self.original_image.copy()
            self.update_display()
            return

        masks, _, _ = self.predictor.predict(
            box=np.array(self.current_box),
            multimask_output=False,
        )
        # With multimask_output=False only the first mask is used.
        self.mask = masks[0]

        # Blend the green overlay at the intended ~50% opacity. (Previously the mask
        # was scaled to 128 AND the fill alpha was 128, halving the opacity twice.)
        overlay = self._blend_mask(np.asarray(self.original_image), self.mask)
        self.result_image = Image.fromarray(overlay)

        # Final box outline, solid red.
        draw = ImageDraw.Draw(self.result_image)
        draw.rectangle(self.current_box, outline="red", width=2)
        self.update_display()

    def update_display(self):
        """Show the current result image on the canvas."""
        self._show_image(self.result_image)

    def clear_box(self):
        """Discard the box and mask and show the original image again."""
        self.current_box = None
        self.mask = None
        self.result_image = self.original_image.copy()
        self.update_display()

    def save_result(self):
        """Save the displayed result to box_segmentation_result.jpg in the working directory."""
        try:
            self.result_image.save("box_segmentation_result.jpg")
            messagebox.showinfo("成功", "分割结果已保存为 box_segmentation_result.jpg")
        except Exception as e:
            # GUI boundary: report any save failure to the user instead of crashing.
            messagebox.showerror("错误", f"保存失败: {str(e)}")

    def run(self):
        """Enter the Tk main loop (returns when the window's quit button is used)."""
        self.root.mainloop()
def parse_args(argv=None):
    """Parse command-line options for the box segmentation tool.

    Each option defaults to the value this script previously hard-coded, so
    running it with no arguments behaves exactly as before.

    Args:
        argv: Argument list to parse; None means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``image``, ``checkpoint`` and ``model_type``.
    """
    import argparse

    parser = argparse.ArgumentParser(description="Box-prompted SAM segmentation GUI.")
    parser.add_argument(
        "--image",
        default="ae976103-d7ec-4eed-b5d1-3e5f04d8be26.jpg",
        help="Path of the image to segment.",
    )
    parser.add_argument(
        "--checkpoint",
        default="/mnt/data/workspace/Code/aida_seg_anything/checkpoint/sam_vit_h_4b8939.pth",
        help="Path of the SAM checkpoint file.",
    )
    parser.add_argument(
        "--model-type",
        default="vit_h",
        help="SAM model type; must match the checkpoint.",
    )
    return parser.parse_args(argv)


if __name__ == "__main__":
    args = parse_args()
    # Build and run the segmenter GUI.
    segmenter = SAMBoxSegmenter(args.image, args.checkpoint, args.model_type)
    segmenter.run()