240 lines
8.4 KiB
Python
Executable File
240 lines
8.4 KiB
Python
Executable File
import numpy as np
|
||
import torch
|
||
from PIL import Image, ImageDraw, ImageTk
|
||
import tkinter as tk
|
||
from tkinter import messagebox
|
||
from segment_anything import SamPredictor, sam_model_registry
|
||
|
||
|
||
class SAMBoxSegmenter:
    """Interactive Tkinter tool: draw a box on an image and segment it with SAM.

    The user drags a rectangle with the left mouse button; when the button is
    released the box is passed to a ``SamPredictor`` as a box prompt and the
    returned mask is drawn over the image as a green overlay.
    """

    def __init__(self, image_path, sam_checkpoint, model_type="vit_h"):
        """Load the image and the SAM model, then build the GUI.

        Args:
            image_path: Path of the image to open (converted to RGB).
            sam_checkpoint: Path of the SAM checkpoint file.
            model_type: Key into ``sam_model_registry`` selecting the model
                builder; it has to correspond to the checkpoint.
        """
        # Load the image; keep a pristine copy and a copy that holds the result.
        self.image = Image.open(image_path).convert("RGB")
        self.original_image = self.image.copy()
        self.result_image = self.image.copy()

        # Initialise the SAM model on the GPU when one is available.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(self.device)
        self.sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
        self.sam.to(device=self.device)
        self.predictor = SamPredictor(self.sam)

        # Hand the image to the predictor once, up front.
        self.image_np = np.array(self.image)
        self.predictor.set_image(self.image_np)

        # Box-selection state.
        self.drawing = False
        self.start_x = 0
        self.start_y = 0
        self.current_box = None
        self.mask = None

        # Build the GUI.
        self.root = tk.Tk()
        self.root.title("SAM 手绘框分割工具")

        # Canvas showing the image. The PhotoImage is stored on self so that a
        # reference to it stays alive while it is displayed.
        self.tk_image = ImageTk.PhotoImage(image=self.image)
        self.canvas = tk.Canvas(
            self.root, width=self.image.width, height=self.image.height
        )
        self.canvas.create_image(0, 0, image=self.tk_image, anchor=tk.NW)
        self.canvas.pack()

        # Mouse bindings: press, drag, release.
        self.canvas.bind("<ButtonPress-1>", self.on_mouse_down)
        self.canvas.bind("<B1-Motion>", self.on_mouse_drag)
        self.canvas.bind("<ButtonRelease-1>", self.on_mouse_up)

        # Control buttons.
        self.controls_frame = tk.Frame(self.root)
        self.controls_frame.pack(fill=tk.X, padx=5, pady=5)

        self.clear_btn = tk.Button(
            self.controls_frame, text="清除框选", command=self.clear_box
        )
        self.clear_btn.pack(side=tk.LEFT, padx=5)

        self.save_btn = tk.Button(
            self.controls_frame, text="保存结果", command=self.save_result
        )
        self.save_btn.pack(side=tk.LEFT, padx=5)

        self.quit_btn = tk.Button(
            self.controls_frame, text="退出", command=self.root.quit
        )
        self.quit_btn.pack(side=tk.RIGHT, padx=5)

        # Usage hint shown below the buttons.
        self.info_label = tk.Label(
            self.root,
            text="操作说明: 按住鼠标左键拖动绘制框选区域,松开后自动分割",
            fg="blue"
        )
        self.info_label.pack(pady=5)

    def on_mouse_down(self, event):
        """Record the starting corner of a new box when the button is pressed."""
        self.drawing = True
        self.start_x = event.x
        self.start_y = event.y

    def on_mouse_drag(self, event):
        """Redraw the in-progress selection as a dashed red box while dragging."""
        if not self.drawing:
            return

        # Draw on a throwaway copy of the pristine image so that earlier
        # previews and results do not accumulate.
        temp_image = self.original_image.copy()
        draw = ImageDraw.Draw(temp_image)

        x1, y1 = self.start_x, self.start_y
        x2, y2 = event.x, event.y

        # Each edge is drawn with the hand-rolled dashed-line helper:
        # 5 pixels drawn, 5 pixels skipped.
        dash_pattern = [5, 5]
        self.draw_dashed_line(draw, x1, y1, x2, y1, dash_pattern, "red", 2)  # top
        self.draw_dashed_line(draw, x2, y1, x2, y2, dash_pattern, "red", 2)  # right
        self.draw_dashed_line(draw, x2, y2, x1, y2, dash_pattern, "red", 2)  # bottom
        self.draw_dashed_line(draw, x1, y2, x1, y1, dash_pattern, "red", 2)  # left

        # Refresh the canvas with the preview.
        self.tk_image = ImageTk.PhotoImage(image=temp_image)
        self.canvas.delete("all")
        self.canvas.create_image(0, 0, image=self.tk_image, anchor=tk.NW)

    def draw_dashed_line(self, draw, x1, y1, x2, y2, dash_pattern, color, width):
        """Draw a dashed line from (x1, y1) to (x2, y2) using ``draw.line``.

        Args:
            draw: Object exposing ``line(xy, fill=..., width=...)``.
            x1, y1: Start point.
            x2, y2: End point.
            dash_pattern: Sequence of segment lengths, alternating between
                drawn and skipped, starting with a drawn segment. The pattern
                repeats until the end of the line is reached.
            color: Passed to ``draw.line`` as ``fill``.
            width: Passed to ``draw.line`` as ``width``.

        A solid line is drawn instead when the line is shorter than one pixel
        or when the pattern cannot be walked (empty, containing a negative
        length, or all zeros); walking such a pattern would either fail on the
        first lookup or never reach the end of the line.
        """
        dx = x2 - x1
        dy = y2 - y1
        length = (dx ** 2 + dy ** 2) ** 0.5

        unusable_pattern = (
            not dash_pattern
            or min(dash_pattern) < 0
            or max(dash_pattern) <= 0
        )
        if length < 1 or unusable_pattern:
            draw.line([(x1, y1), (x2, y2)], fill=color, width=width)
            return

        # Unit vector along the line.
        ux = dx / length
        uy = dy / length

        current_pos = 0
        dash_on = True  # the pattern starts with a drawn segment
        pattern_index = 0
        segment_length = dash_pattern[pattern_index]

        while current_pos < length:
            # Clip the current segment to the end of the line.
            end_pos = min(current_pos + segment_length, length)

            if dash_on:
                x_start = x1 + ux * current_pos
                y_start = y1 + uy * current_pos
                x_end = x1 + ux * end_pos
                y_end = y1 + uy * end_pos
                draw.line(
                    [(x_start, y_start), (x_end, y_end)], fill=color, width=width
                )

            # Switch between drawn/skipped and advance through the pattern.
            dash_on = not dash_on
            pattern_index = (pattern_index + 1) % len(dash_pattern)
            segment_length = dash_pattern[pattern_index]
            current_pos = end_pos

    def on_mouse_up(self, event):
        """Finalise the box on button release and run the segmentation.

        Both corners are clamped to the image, because the pointer can leave
        the canvas during a drag and report coordinates outside of it. A box
        with zero width or height (for example a click without a drag) is
        ignored: the previous box and result are kept and simply redisplayed.
        """
        if not self.drawing:
            return

        self.drawing = False

        max_x = self.original_image.width - 1
        max_y = self.original_image.height - 1
        left, right = sorted(
            min(max(x, 0), max_x) for x in (self.start_x, event.x)
        )
        top, bottom = sorted(
            min(max(y, 0), max_y) for y in (self.start_y, event.y)
        )

        if left == right or top == bottom:
            # Nothing to segment; drop any dashed preview left on the canvas.
            self.update_display()
            return

        # Stored as [left, top, right, bottom].
        self.current_box = [left, top, right, bottom]

        self.update_segmentation()

    def update_segmentation(self):
        """Recompute the mask for the current box and redraw the result.

        Without a current box the unmodified image is shown.
        """
        print('update_segmentation')
        if not self.current_box:
            self.result_image = self.original_image.copy()
            self.update_display()
            return

        box = np.array(self.current_box)

        # Ask SAM for a single mask for the box prompt.
        masks, _, _ = self.predictor.predict(
            box=box,
            multimask_output=False
        )
        self.mask = masks[0]

        # Start again from the pristine image.
        self.result_image = self.original_image.copy()

        # Turn the mask into an 8-bit image holding 0 or 128.
        # NOTE(review): assumes the mask is a 2-D array matching the image
        # size - confirm against SamPredictor.predict.
        mask_array = self.mask.astype(np.uint8) * 128
        mask_image = Image.fromarray(mask_array, mode="L")

        # Paint green through the mask onto a fully transparent layer.
        green_mask = Image.new("RGBA", self.result_image.size, (0, 0, 0, 0))
        green_draw = ImageDraw.Draw(green_mask)
        green_draw.bitmap((0, 0), mask_image, fill=(0, 255, 0, 128))

        # Composite the overlay onto the image.
        self.result_image = Image.alpha_composite(
            self.result_image.convert("RGBA"),
            green_mask
        ).convert("RGB")

        # Draw the final box as a solid red rectangle.
        draw = ImageDraw.Draw(self.result_image)
        draw.rectangle(self.current_box, outline="red", width=2)

        self.update_display()

    def update_display(self):
        """Show ``self.result_image`` on the canvas."""
        self.tk_image = ImageTk.PhotoImage(image=self.result_image)
        self.canvas.delete("all")
        self.canvas.create_image(0, 0, image=self.tk_image, anchor=tk.NW)

    def clear_box(self):
        """Discard the current box and mask and show the original image."""
        self.current_box = None
        self.mask = None
        self.result_image = self.original_image.copy()
        self.update_display()

    def save_result(self):
        """Save the result image to ``box_segmentation_result.jpg``.

        The outcome is reported through a message box; a failure to save is
        shown to the user rather than raised.
        """
        try:
            self.result_image.save("box_segmentation_result.jpg")
            messagebox.showinfo("成功", "分割结果已保存为 box_segmentation_result.jpg")
        except Exception as e:
            # GUI callback boundary: report any failure instead of raising.
            messagebox.showerror("错误", f"保存失败: {str(e)}")

    def run(self):
        """Enter the Tk main loop."""
        self.root.mainloop()
|
||
|
||
|
||
# Default configuration. Each value can be overridden on the command line,
# so the script no longer has to be edited to point at another image or model.
IMAGE_PATH = "ae976103-d7ec-4eed-b5d1-3e5f04d8be26.jpg"
SAM_CHECKPOINT = "/mnt/data/workspace/Code/aida_seg_anything/checkpoint/sam_vit_h_4b8939.pth"
MODEL_TYPE = "vit_h"  # must correspond to the checkpoint


def parse_args(argv=None):
    """Parse the command-line options of the script.

    Args:
        argv: Argument list to parse; ``None`` makes argparse read
            ``sys.argv[1:]``.

    Returns:
        Namespace with ``image``, ``checkpoint`` and ``model_type``; options
        that are not given fall back to the module-level defaults.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Draw a box on an image and segment it with SAM."
    )
    parser.add_argument("--image", default=IMAGE_PATH,
                        help="path of the image to segment")
    parser.add_argument("--checkpoint", default=SAM_CHECKPOINT,
                        help="path of the SAM checkpoint file")
    parser.add_argument("--model-type", default=MODEL_TYPE,
                        help="SAM model type matching the checkpoint")
    return parser.parse_args(argv)


if __name__ == "__main__":
    args = parse_args()

    # Create the segmenter and hand control to the GUI.
    segmenter = SAMBoxSegmenter(args.image, args.checkpoint, args.model_type)
    segmenter.run()
|