Files
aida_seg_anything/scripts/amg_point.py
2026-01-08 14:35:23 +08:00

159 lines
5.8 KiB
Python
Executable File
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import argparse
import tkinter as tk
from tkinter import messagebox

import numpy as np
import torch
from PIL import Image, ImageDraw, ImageTk
from segment_anything import SamPredictor, sam_model_registry
class SAMPointSegmenter:
    """Interactive point-prompt segmentation GUI built on Segment Anything (SAM).

    Left-clicks add foreground points (label 1), right-clicks add background
    points (label 0). After every click the predictor is run on all points
    collected so far and the resulting mask is drawn over the image as a
    semi-transparent overlay, together with markers for the clicked points.
    """

    # Overlay colour (RGB) and opacity (0-255) used for the predicted mask.
    MASK_COLOR = (0, 255, 0)
    MASK_ALPHA = 128
    # Marker colours: foreground points are green, background points are red.
    FOREGROUND_COLOR = (0, 255, 0)
    BACKGROUND_COLOR = (255, 0, 0)

    def __init__(self, image_path, sam_checkpoint, model_type="vit_b"):
        """Load the image and SAM model, then build the Tk window.

        Args:
            image_path: Path of the image to segment; it is converted to RGB.
            sam_checkpoint: Path of the SAM checkpoint file.
            model_type: Key into ``sam_model_registry``; it must match the
                checkpoint (e.g. "vit_h" for a ``sam_vit_h_*.pth`` file).
        """
        # Load the image; a pristine copy is kept so every redraw starts clean.
        with Image.open(image_path) as img:
            self.image = img.convert("RGB")
        self.original_image = self.image.copy()
        self.result_image = self.image.copy()

        # Initialise the SAM model, on the GPU when one is available.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
        self.sam.to(device=self.device)
        self.predictor = SamPredictor(self.sam)

        # Hand the image to the predictor once; each click only calls predict().
        self.image_np = np.array(self.image)
        self.predictor.set_image(self.image_np)

        # Clicked points as (x, y, label); label 1 = inside the object, 0 = outside.
        self.points = []
        self.mask = None

        self._build_gui()

    def _build_gui(self):
        """Create the Tk root window, image canvas, control buttons and help label."""
        self.root = tk.Tk()
        self.root.title("SAM 打点分割工具")

        # Canvas exactly the size of the image (no focus-highlight ring), so
        # canvas click coordinates are image pixel coordinates.
        self.tk_image = ImageTk.PhotoImage(image=self.image)
        self.canvas = tk.Canvas(
            self.root,
            width=self.image.width,
            height=self.image.height,
            highlightthickness=0,
        )
        self.canvas.create_image(0, 0, image=self.tk_image, anchor=tk.NW)
        self.canvas.pack()

        # Left click adds a foreground point, right click a background point.
        self.canvas.bind("<Button-1>", self.add_foreground_point)
        self.canvas.bind("<Button-3>", self.add_background_point)

        # Control buttons.
        self.controls_frame = tk.Frame(self.root)
        self.controls_frame.pack(fill=tk.X, padx=5, pady=5)
        self.clear_btn = tk.Button(self.controls_frame, text="清除所有点", command=self.clear_points)
        self.clear_btn.pack(side=tk.LEFT, padx=5)
        self.save_btn = tk.Button(self.controls_frame, text="保存结果", command=self.save_result)
        self.save_btn.pack(side=tk.LEFT, padx=5)
        self.quit_btn = tk.Button(self.controls_frame, text="退出", command=self.root.quit)
        self.quit_btn.pack(side=tk.RIGHT, padx=5)

        # Usage hint.
        self.info_label = tk.Label(
            self.root,
            text="操作说明: 左键点击添加前景点(绿色), 右键点击添加背景点(红色)",
            fg="blue",
        )
        self.info_label.pack(pady=5)

    def add_foreground_point(self, event):
        """Add a foreground (inside-the-object) point at the clicked position."""
        self._add_point(event.x, event.y, 1)

    def add_background_point(self, event):
        """Add a background (outside-the-object) point at the clicked position."""
        self._add_point(event.x, event.y, 0)

    def _add_point(self, x, y, label):
        """Record a labelled point and refresh the segmentation.

        Defensive: coordinates outside the image bounds are ignored instead
        of being sent to the predictor.
        """
        if not (0 <= x < self.image.width and 0 <= y < self.image.height):
            return
        self.points.append((x, y, label))
        self.update_segmentation()

    @staticmethod
    def _points_to_arrays(points):
        """Split ``[(x, y, label), ...]`` into coordinate and label arrays.

        Returns:
            ``(coords, labels)`` with shapes ``(N, 2)`` and ``(N,)``; an
            empty input yields arrays of shape ``(0, 2)`` and ``(0,)``.
        """
        coords = np.array([(x, y) for x, y, _ in points]).reshape(-1, 2)
        labels = np.array([label for _, _, label in points])
        return coords, labels

    @staticmethod
    def _mask_to_rgba(mask, color=(0, 255, 0), alpha=128):
        """Build an RGBA overlay array from a binary mask.

        Args:
            mask: 2-D array; truthy entries mark the segmented region.
            color: RGB colour painted over the region.
            alpha: Opacity (0-255) of the region. All other pixels get
                alpha 0, so compositing leaves them unchanged.

        Returns:
            ``uint8`` array of shape ``(H, W, 4)``.
        """
        region = np.asarray(mask).astype(bool)
        rgba = np.zeros(region.shape + (4,), dtype=np.uint8)
        rgba[..., :3] = color
        # Set the opacity directly. The previous ImageDraw.bitmap approach
        # multiplied the 128 matte by the 128 fill alpha (~64 effective).
        rgba[region, 3] = alpha
        return rgba

    def _draw_points(self, image):
        """Draw every clicked point onto ``image`` in place.

        Each point is a filled dot (foreground/background colour) surrounded
        by a white ring so it stays visible on any background.
        """
        draw = ImageDraw.Draw(image)
        for x, y, label in self.points:
            color = self.FOREGROUND_COLOR if label == 1 else self.BACKGROUND_COLOR
            draw.ellipse([(x - 5, y - 5), (x + 5, y + 5)], fill=color)
            draw.ellipse([(x - 7, y - 7), (x + 7, y + 7)], outline=(255, 255, 255), width=2)

    def update_segmentation(self):
        """Re-run SAM on the current points and redraw the overlay and markers.

        With no points, the display is reset to the original image.
        """
        self.result_image = self.original_image.copy()
        if not self.points:
            self.update_display()
            return

        point_coords, point_labels = self._points_to_arrays(self.points)
        masks, _, _ = self.predictor.predict(
            point_coords=point_coords,
            point_labels=point_labels,
            multimask_output=False,
        )
        self.mask = masks[0]

        # Composite the semi-transparent mask overlay onto the original image.
        overlay = Image.fromarray(self._mask_to_rgba(self.mask, self.MASK_COLOR, self.MASK_ALPHA))
        self.result_image = Image.alpha_composite(
            self.result_image.convert("RGBA"),
            overlay,
        ).convert("RGB")

        self._draw_points(self.result_image)
        self.update_display()

    def update_display(self):
        """Show ``self.result_image`` on the canvas."""
        # Kept on self so the PhotoImage is not garbage-collected while shown.
        self.tk_image = ImageTk.PhotoImage(image=self.result_image)
        self.canvas.delete("all")
        self.canvas.create_image(0, 0, image=self.tk_image, anchor=tk.NW)

    def clear_points(self):
        """Remove all points and the mask, and show the original image."""
        self.points = []
        self.mask = None
        self.result_image = self.original_image.copy()
        self.update_display()

    def save_result(self, output_path="segmentation_result.jpg"):
        """Save the currently displayed result (image, overlay and markers).

        Args:
            output_path: Destination file; Pillow picks the format from the
                extension. The "保存结果" button calls this with the default,
                ``segmentation_result.jpg`` in the working directory.

        Success or failure is reported to the user in a message box.
        """
        try:
            self.result_image.save(output_path)
        except Exception as e:  # GUI callback boundary: report, don't crash.
            messagebox.showerror("错误", f"保存失败: {str(e)}")
        else:
            messagebox.showinfo("成功", f"分割结果已保存为 {output_path}")

    def run(self):
        """Enter the Tk main loop; returns when the window's main loop is quit."""
        self.root.mainloop()
def parse_args(argv=None):
    """Parse command-line options for the segmentation tool.

    The defaults are the author's local paths; override them on the command
    line instead of editing the script.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        ``argparse.Namespace`` with ``image``, ``checkpoint`` and ``model_type``.
    """
    parser = argparse.ArgumentParser(description="SAM point-prompt segmentation tool")
    parser.add_argument(
        "--image",
        default="/workspace/PycharmProjects/segment-anything/scripts/ae976103-d7ec-4eed-b5d1-3e5f04d8be26.jpg",
        help="path of the image to segment",
    )
    parser.add_argument(
        "--checkpoint",
        default="/workspace/PycharmProjects/segment-anything/checkpoint/sam_vit_h_4b8939.pth",
        help="path of the SAM model checkpoint",
    )
    parser.add_argument(
        "--model-type",
        default="vit_h",
        # Keys of segment_anything's sam_model_registry.
        choices=("default", "vit_h", "vit_l", "vit_b"),
        help="SAM backbone; must match the checkpoint",
    )
    return parser.parse_args(argv)


if __name__ == "__main__":
    args = parse_args()
    # Build the segmenter and start the GUI.
    segmenter = SAMPointSegmenter(args.image, args.checkpoint, args.model_type)
    segmenter.run()