258 lines
10 KiB
Python
258 lines
10 KiB
Python
|
|
import numpy as np
|
|||
|
|
import torch
|
|||
|
|
from PIL import Image, ImageDraw, ImageTk
|
|||
|
|
import tkinter as tk
|
|||
|
|
from tkinter import messagebox
|
|||
|
|
from segment_anything import SamPredictor, sam_model_registry
|
|||
|
|
|
|||
|
|
|
|||
|
|
class SAMPointBoxSegmenter:
    """Interactive Tk tool that segments one image with SAM from box + point prompts.

    Left-button gestures are classified on release:
      * drag  -> draw (or replace) the prompt box
      * click -> add a foreground point (green)
    A right click adds a background point (red). The box and all points are
    passed together to a single SAM prediction, and the resulting mask is
    overlaid in translucent green.
    """

    # Max pointer travel (pixels, per axis) between press and release for a
    # left-button gesture to count as a click (foreground point) instead of a
    # box drag. Without this, a plain click produced a zero-area box.
    CLICK_TOLERANCE = 3

    # Alpha (0-255) of the green mask overlay.
    MASK_ALPHA = 128

    def __init__(self, image_path, sam_checkpoint, model_type="vit_h"):
        """Load the image and the SAM model, then build the GUI.

        Args:
            image_path: path of the image to segment (converted to RGB).
            sam_checkpoint: path of the SAM checkpoint file.
            model_type: key into ``sam_model_registry``; must match the checkpoint.
        """
        # Load the image; the context manager closes the underlying file handle.
        with Image.open(image_path) as img:
            self.image = img.convert("RGB")
        self.original_image = self.image.copy()
        self.result_image = self.image.copy()

        # Initialise the SAM model (GPU when available).
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
        self.sam.to(device=self.device)
        self.predictor = SamPredictor(self.sam)

        # Hand the image to SAM once; every later predict() call reuses it.
        self.image_np = np.array(self.image)
        self.predictor.set_image(self.image_np)

        # Interaction state
        self.drawing_box = False  # True while a left-button gesture is in progress
        self.start_x, self.start_y = 0, 0  # position where the left button was pressed
        self.current_box = None  # current box [x1, y1, x2, y2], or None
        self.points = []  # [(x, y, label), ...]; label 1 = foreground, 0 = background
        self.mask = None  # last mask returned by SAM, or None

        # Build the GUI
        self.root = tk.Tk()
        self.root.title("SAM 点框组合分割工具")

        # Canvas sized to the image so event coordinates map to pixel coordinates
        self.tk_image = ImageTk.PhotoImage(image=self.image)
        self.canvas = tk.Canvas(self.root, width=self.image.width, height=self.image.height)
        self.canvas.create_image(0, 0, image=self.tk_image, anchor=tk.NW)
        self.canvas.pack()

        # Mouse bindings
        self.canvas.bind("<ButtonPress-1>", self.on_left_down)  # left press: start a gesture
        self.canvas.bind("<B1-Motion>", self.on_left_drag)  # left drag: preview the box
        self.canvas.bind("<ButtonRelease-1>", self.on_left_up)  # left release: point or box
        self.canvas.bind("<Button-3>", self.add_background_point)  # right click: background point

        # Control buttons
        self.controls_frame = tk.Frame(self.root)
        self.controls_frame.pack(fill=tk.X, padx=5, pady=5)

        self.clear_box_btn = tk.Button(self.controls_frame, text="清除框", command=self.clear_box)
        self.clear_box_btn.pack(side=tk.LEFT, padx=5)

        self.clear_points_btn = tk.Button(self.controls_frame, text="清除点", command=self.clear_points)
        self.clear_points_btn.pack(side=tk.LEFT, padx=5)

        self.clear_all_btn = tk.Button(self.controls_frame, text="清除全部", command=self.clear_all)
        self.clear_all_btn.pack(side=tk.LEFT, padx=5)

        self.save_btn = tk.Button(self.controls_frame, text="保存结果", command=self.save_result)
        self.save_btn.pack(side=tk.LEFT, padx=5)

        self.quit_btn = tk.Button(self.controls_frame, text="退出", command=self.root.quit)
        self.quit_btn.pack(side=tk.RIGHT, padx=5)

        # Usage hint
        self.info_label = tk.Label(
            self.root,
            text="操作说明: 左键拖动画框 | 左键点击加前景点(绿) | 右键点击加背景点(红) | 框和点可组合使用",
            fg="blue",
            wraplength=600,  # wrap long text
        )
        self.info_label.pack(pady=5)

    # ------------------------------ mouse event handlers ------------------------------

    def on_left_down(self, event):
        """Left press: remember where the gesture started.

        Whether the gesture becomes a box or a foreground point is decided in
        ``on_left_up`` from how far the pointer travelled.
        """
        self.drawing_box = True
        self.start_x, self.start_y = self._clamp_to_image(event.x, event.y)

    def on_left_drag(self, event):
        """Left drag: preview the box being drawn as a dashed red outline."""
        if not self.drawing_box:
            return

        x1, y1 = self.start_x, self.start_y
        x2, y2 = self._clamp_to_image(event.x, event.y)

        # Still within click tolerance: show the current result, no preview box.
        if self._is_click(x1, y1, x2, y2, self.CLICK_TOLERANCE):
            self.update_display()
            return

        temp_img = self.original_image.copy()
        draw = ImageDraw.Draw(temp_img)
        # Four dashed edges: top, right, bottom, left.
        edges = ((x1, y1, x2, y1), (x2, y1, x2, y2), (x2, y2, x1, y2), (x1, y2, x1, y1))
        for ax, ay, bx, by in edges:
            self.draw_dashed_line(draw, ax, ay, bx, by, [5, 5], "red", 2)

        # Keep existing points visible while dragging.
        self._draw_points(draw)
        self._show(temp_img)

    def on_left_up(self, event):
        """Left release: add a foreground point (click) or confirm the box (drag)."""
        if not self.drawing_box:
            return
        self.drawing_box = False

        x, y = self._clamp_to_image(event.x, event.y)
        if self._is_click(self.start_x, self.start_y, x, y, self.CLICK_TOLERANCE):
            # Barely moved: a click, recorded as a foreground point at the press position.
            self.points.append((self.start_x, self.start_y, 1))
        else:
            # A real drag: (re)place the box, normalised to top-left / bottom-right.
            self.current_box = self._normalize_box(self.start_x, self.start_y, x, y)
        self.update_segmentation()

    def add_background_point(self, event):
        """Right click: add a background point (drawn in red)."""
        x, y = self._clamp_to_image(event.x, event.y)
        self.points.append((x, y, 0))
        self.update_segmentation()

    # ------------------------------ helpers ------------------------------

    @staticmethod
    def _clamp(value, upper):
        """Clamp ``value`` into the inclusive range [0, upper]."""
        return min(max(value, 0), upper)

    def _clamp_to_image(self, x, y):
        """Clamp canvas coordinates onto the image's pixel grid."""
        return (self._clamp(x, self.image.width - 1),
                self._clamp(y, self.image.height - 1))

    @staticmethod
    def _is_click(x0, y0, x1, y1, tolerance):
        """Return True when (x1, y1) is within ``tolerance`` px of (x0, y0) on both axes."""
        return abs(x1 - x0) <= tolerance and abs(y1 - y0) <= tolerance

    @staticmethod
    def _normalize_box(x0, y0, x1, y1):
        """Return [left, top, right, bottom] for two opposite box corners."""
        return [min(x0, x1), min(y0, y1), max(x0, x1), max(y0, y1)]

    def draw_dashed_line(self, draw, x1, y1, x2, y2, dash_pattern, color, width):
        """Draw a dashed line from (x1, y1) to (x2, y2) using plain ``draw.line`` calls.

        ``dash_pattern`` alternates "on" and "off" segment lengths in pixels
        (e.g. ``[5, 5]``). Lines shorter than 1 px are drawn solid.
        """
        dx = x2 - x1
        dy = y2 - y1
        length = (dx ** 2 + dy ** 2) ** 0.5
        if length < 1:
            draw.line([(x1, y1), (x2, y2)], fill=color, width=width)
            return
        ux, uy = dx / length, dy / length  # unit direction vector
        current_pos = 0
        dash_on = True
        pattern_idx = 0
        seg_len = dash_pattern[pattern_idx]

        while current_pos < length:
            end_pos = min(current_pos + seg_len, length)
            x_start, y_start = x1 + ux * current_pos, y1 + uy * current_pos
            x_end, y_end = x1 + ux * end_pos, y1 + uy * end_pos
            if dash_on:
                draw.line([(x_start, y_start), (x_end, y_end)], fill=color, width=width)
            dash_on = not dash_on
            pattern_idx = (pattern_idx + 1) % len(dash_pattern)
            seg_len = dash_pattern[pattern_idx]
            current_pos = end_pos

    def _draw_points(self, draw):
        """Draw every prompt point: green = foreground, red = background, white ring."""
        for x, y, label in self.points:
            color = (0, 255, 0) if label == 1 else (255, 0, 0)
            draw.ellipse([(x - 5, y - 5), (x + 5, y + 5)], fill=color)  # filled centre
            draw.ellipse([(x - 7, y - 7), (x + 7, y + 7)], outline=(255, 255, 255), width=2)  # white ring

    def _overlay_mask(self, image, mask):
        """Return ``image`` (RGB) with ``mask`` blended on top in translucent green.

        The mask itself becomes the alpha channel of a solid green layer, so
        masked pixels get exactly ``MASK_ALPHA`` opacity and others stay untouched.
        """
        alpha = Image.fromarray(mask.astype(np.uint8) * self.MASK_ALPHA)
        overlay = Image.new("RGBA", image.size, (0, 255, 0, 0))
        overlay.putalpha(alpha)
        return Image.alpha_composite(image.convert("RGBA"), overlay).convert("RGB")

    # ------------------------------ segmentation ------------------------------

    def update_segmentation(self):
        """Run SAM on the current box and points, then redraw overlay and prompts."""
        self.result_image = self.original_image.copy()

        # Build the prompts (box and/or points).
        box = np.array(self.current_box) if self.current_box is not None else None
        if self.points:
            point_coords = np.array([(x, y) for x, y, _ in self.points])
            point_labels = np.array([label for _, _, label in self.points])
        else:
            point_coords = point_labels = None

        if box is None and point_coords is None:
            # No prompt at all: forget any stale mask and show the plain image.
            self.mask = None
        else:
            # Single SAM call with box and points combined.
            masks, _, _ = self.predictor.predict(
                point_coords=point_coords,
                point_labels=point_labels,
                box=box,
                multimask_output=False,
            )
            self.mask = masks[0]
            self.result_image = self._overlay_mask(self.result_image, self.mask)

        # Create the Draw only AFTER compositing: _overlay_mask returns a new
        # image, so a Draw created earlier would paint onto a discarded copy.
        draw = ImageDraw.Draw(self.result_image)
        if self.current_box is not None:
            draw.rectangle(self.current_box, outline="red", width=2)  # solid box outline
        self._draw_points(draw)

        self.update_display()

    # ------------------------------ controls ------------------------------

    def _show(self, img):
        """Replace the canvas contents with ``img``."""
        # Keep the PhotoImage referenced on self; Tkinter does not hold a
        # reference, and a garbage-collected PhotoImage displays as blank.
        self.tk_image = ImageTk.PhotoImage(image=img)
        self.canvas.delete("all")
        self.canvas.create_image(0, 0, image=self.tk_image, anchor=tk.NW)

    def update_display(self):
        """Show the current result image on the canvas."""
        self._show(self.result_image)

    def clear_box(self):
        """Remove the current box and re-run segmentation."""
        self.current_box = None
        self.update_segmentation()

    def clear_points(self):
        """Remove all points and re-run segmentation."""
        self.points = []
        self.update_segmentation()

    def clear_all(self):
        """Remove box, points and mask, and show the original image."""
        self.current_box = None
        self.points = []
        self.mask = None
        self.result_image = self.original_image.copy()
        self.update_display()

    def save_result(self):
        """Save the displayed result to ``combined_segmentation_result.jpg``."""
        try:
            self.result_image.save("combined_segmentation_result.jpg")
            messagebox.showinfo("成功", "结果已保存为 combined_segmentation_result.jpg")
        except Exception as e:
            # UI boundary: report any save failure to the user instead of crashing.
            messagebox.showerror("错误", f"保存失败: {str(e)}")

    def run(self):
        """Start the Tk main loop."""
        self.root.mainloop()
|
|||
|
|
|
|||
|
|
|
|||
|
|
if __name__ == "__main__":
    import argparse

    # Default paths; override them on the command line instead of editing the file.
    IMAGE_PATH = "ae976103-d7ec-4eed-b5d1-3e5f04d8be26.jpg"  # replace with your image path
    SAM_CHECKPOINT = "/home/alab/PycharmProjects/segment-anything/checkpoint/sam_vit_h_4b8939.pth"  # replace with your SAM checkpoint path
    MODEL_TYPE = "vit_h"  # model type; must match the checkpoint

    parser = argparse.ArgumentParser(description="Interactive SAM box + point segmentation tool")
    parser.add_argument("--image", default=IMAGE_PATH, help="path of the image to segment")
    parser.add_argument("--checkpoint", default=SAM_CHECKPOINT, help="path of the SAM checkpoint file")
    parser.add_argument(
        "--model-type",
        default=MODEL_TYPE,
        choices=sorted(sam_model_registry),
        help="SAM model type; must match the checkpoint",
    )
    args = parser.parse_args()

    # Run the tool
    segmenter = SAMPointBoxSegmenter(args.image, args.checkpoint, args.model_type)
    segmenter.run()
|