fitst commit
This commit is contained in:
158
scripts/amg_point.py
Executable file
158
scripts/amg_point.py
Executable file
@@ -0,0 +1,158 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image, ImageDraw, ImageTk
|
||||
import tkinter as tk
|
||||
from tkinter import messagebox
|
||||
from segment_anything import SamPredictor, sam_model_registry
|
||||
|
||||
|
||||
class SAMPointSegmenter:
|
||||
def __init__(self, image_path, sam_checkpoint, model_type="vit_b"):
|
||||
# 加载图像
|
||||
self.image = Image.open(image_path).convert("RGB")
|
||||
self.original_image = self.image.copy()
|
||||
self.result_image = self.image.copy()
|
||||
|
||||
# 初始化SAM模型
|
||||
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
self.sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
|
||||
self.sam.to(device=self.device)
|
||||
self.predictor = SamPredictor(self.sam)
|
||||
|
||||
# 准备图像供SAM使用
|
||||
self.image_np = np.array(self.image)
|
||||
self.predictor.set_image(self.image_np)
|
||||
|
||||
# 存储点击的点 [(x, y, label), ...],label=1表示目标内,0表示目标外
|
||||
self.points = []
|
||||
self.mask = None
|
||||
|
||||
# 创建GUI
|
||||
self.root = tk.Tk()
|
||||
self.root.title("SAM 打点分割工具")
|
||||
|
||||
# 创建画布
|
||||
self.tk_image = ImageTk.PhotoImage(image=self.image)
|
||||
self.canvas = tk.Canvas(self.root, width=self.image.width, height=self.image.height)
|
||||
self.canvas.create_image(0, 0, image=self.tk_image, anchor=tk.NW)
|
||||
self.canvas.pack()
|
||||
|
||||
# 绑定鼠标事件
|
||||
self.canvas.bind("<Button-1>", self.add_foreground_point) # 左键点击添加前景点
|
||||
self.canvas.bind("<Button-3>", self.add_background_point) # 右键点击添加背景点
|
||||
|
||||
# 创建按钮
|
||||
self.controls_frame = tk.Frame(self.root)
|
||||
self.controls_frame.pack(fill=tk.X, padx=5, pady=5)
|
||||
|
||||
self.clear_btn = tk.Button(self.controls_frame, text="清除所有点", command=self.clear_points)
|
||||
self.clear_btn.pack(side=tk.LEFT, padx=5)
|
||||
|
||||
self.save_btn = tk.Button(self.controls_frame, text="保存结果", command=self.save_result)
|
||||
self.save_btn.pack(side=tk.LEFT, padx=5)
|
||||
|
||||
self.quit_btn = tk.Button(self.controls_frame, text="退出", command=self.root.quit)
|
||||
self.quit_btn.pack(side=tk.RIGHT, padx=5)
|
||||
|
||||
# 添加说明标签
|
||||
self.info_label = tk.Label(
|
||||
self.root,
|
||||
text="操作说明: 左键点击添加前景点(绿色), 右键点击添加背景点(红色)",
|
||||
fg="blue"
|
||||
)
|
||||
self.info_label.pack(pady=5)
|
||||
|
||||
def add_foreground_point(self, event):
|
||||
"""添加前景点(目标内)"""
|
||||
self.points.append((event.x, event.y, 1))
|
||||
self.update_segmentation()
|
||||
|
||||
def add_background_point(self, event):
|
||||
"""添加背景点(目标外)"""
|
||||
self.points.append((event.x, event.y, 0))
|
||||
self.update_segmentation()
|
||||
|
||||
def update_segmentation(self):
|
||||
"""根据当前点更新分割结果"""
|
||||
# 复制原图
|
||||
self.result_image = self.original_image.copy()
|
||||
draw = ImageDraw.Draw(self.result_image)
|
||||
|
||||
if not self.points:
|
||||
self.update_display()
|
||||
return
|
||||
|
||||
# 准备点和标签
|
||||
point_coords = np.array([(x, y) for x, y, label in self.points])
|
||||
point_labels = np.array([label for x, y, label in self.points])
|
||||
|
||||
# 调用SAM生成掩码
|
||||
masks, _, _ = self.predictor.predict(
|
||||
point_coords=point_coords,
|
||||
point_labels=point_labels,
|
||||
multimask_output=False
|
||||
)
|
||||
self.mask = masks[0]
|
||||
|
||||
# 创建半透明的掩码叠加层
|
||||
mask_array = self.mask.astype(np.uint8) * 128 # 0或128(半透明)
|
||||
mask_image = Image.fromarray(mask_array, mode="L")
|
||||
|
||||
# 创建绿色的掩码图像
|
||||
green_mask = Image.new("RGBA", self.result_image.size, (0, 255, 0, 0))
|
||||
green_draw = ImageDraw.Draw(green_mask)
|
||||
green_draw.bitmap((0, 0), mask_image, fill=(0, 255, 0, 128)) # 绿色半透明
|
||||
|
||||
# 叠加掩码到原图
|
||||
self.result_image = Image.alpha_composite(
|
||||
self.result_image.convert("RGBA"),
|
||||
green_mask
|
||||
).convert("RGB")
|
||||
|
||||
# 绘制所有点
|
||||
draw = ImageDraw.Draw(self.result_image)
|
||||
for x, y, label in self.points:
|
||||
# 前景点为绿色,背景点为红色
|
||||
color = (0, 255, 0) if label == 1 else (255, 0, 0)
|
||||
# 绘制点(内部实心,外部白色边框)
|
||||
draw.ellipse([(x - 5, y - 5), (x + 5, y + 5)], fill=color)
|
||||
draw.ellipse([(x - 7, y - 7), (x + 7, y + 7)], outline=(255, 255, 255), width=2)
|
||||
|
||||
self.update_display()
|
||||
|
||||
def update_display(self):
|
||||
"""更新画布显示"""
|
||||
self.tk_image = ImageTk.PhotoImage(image=self.result_image)
|
||||
self.canvas.delete("all")
|
||||
self.canvas.create_image(0, 0, image=self.tk_image, anchor=tk.NW)
|
||||
|
||||
def clear_points(self):
|
||||
"""清除所有点和掩码"""
|
||||
self.points = []
|
||||
self.mask = None
|
||||
self.result_image = self.original_image.copy()
|
||||
self.update_display()
|
||||
|
||||
def save_result(self):
|
||||
"""保存分割结果"""
|
||||
try:
|
||||
self.result_image.save("segmentation_result.jpg")
|
||||
messagebox.showinfo("成功", "分割结果已保存为 segmentation_result.jpg")
|
||||
except Exception as e:
|
||||
messagebox.showerror("错误", f"保存失败: {str(e)}")
|
||||
|
||||
def run(self):
|
||||
"""运行GUI主循环"""
|
||||
self.root.mainloop()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 配置参数
|
||||
IMAGE_PATH = "/workspace/PycharmProjects/segment-anything/scripts/ae976103-d7ec-4eed-b5d1-3e5f04d8be26.jpg" # 替换为你的图像路径
|
||||
SAM_CHECKPOINT = "/workspace/PycharmProjects/segment-anything/checkpoint/sam_vit_h_4b8939.pth" # 替换为你的SAM模型路径
|
||||
|
||||
MODEL_TYPE = "vit_h" # 模型类型,与checkpoint对应
|
||||
|
||||
# 创建并运行分割器
|
||||
segmenter = SAMPointSegmenter(IMAGE_PATH, SAM_CHECKPOINT, MODEL_TYPE)
|
||||
segmenter.run()
|
||||
Reference in New Issue
Block a user