Files
aida_seg_anything/scripts/test.py
2026-01-08 14:35:23 +08:00

42 lines
1.6 KiB
Python
Executable File
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import os
import random
import sys
import time

import cv2
import numpy as np
import torch
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
# Initialize the SAM model (ViT-H backbone) from a local checkpoint.
sam = sam_model_registry["vit_h"](checkpoint="/workspace/PycharmProjects/segment-anything/checkpoint/sam_vit_h_4b8939.pth")
# Run on the GPU when one is available; without an explicit .to(...) the model
# stays on the device it was built on (CPU by default in PyTorch), which makes
# the timed generation below far slower.
device = "cuda" if torch.cuda.is_available() else "cpu"
sam.to(device=device)
mask_generator = SamAutomaticMaskGenerator(sam)

# Read the input image. cv2.imread returns None (it does not raise) when the
# file is missing or cannot be decoded, so fail early with a clear message
# instead of a confusing error from cvtColor.
image_path = "ae976103-d7ec-4eed-b5d1-3e5f04d8be26.jpg"
image = cv2.imread(image_path)
if image is None:
    raise FileNotFoundError(f"Could not read image: {image_path}")
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # SAM expects RGB input
image_copy = image.copy()  # BGR copy of the original to draw masks on (OpenCV order)

# Generate masks and time the call (perf_counter is the clock meant for durations).
start_time = time.perf_counter()
masks = mask_generator.generate(image_rgb)
print(f"掩码生成耗时: {time.perf_counter() - start_time:.2f}")
print(f"共生成 {len(masks)} 个掩码")
def get_random_color():
    """Return a random color as a 3-tuple of ints, each in the range [0, 255].

    Every channel is drawn independently with ``random.randint``, so the tuple
    is equally usable as RGB or as OpenCV's BGR channel order.
    """
    return tuple(random.randint(0, 255) for _channel in range(3))
# Overlay every mask on the BGR copy with a random color at 50% opacity.
# Masks are drawn largest-first (by the "area" entry of SAM's output dicts) so
# small masks, blended last, are not washed out by a later large mask over them.
for mask in sorted(masks, key=lambda m: m["area"], reverse=True):
    mask_array = mask["segmentation"]  # boolean array: True on the mask's pixels
    color = np.array(get_random_color())
    blended = image_copy[mask_array] * 0.5 + color * 0.5
    # The blend is float64 while the image is uint8; cast explicitly
    # (values stay within [0, 255], so truncation is the only effect).
    image_copy[mask_array] = blended.astype(np.uint8)

# Save the visualization. cv2.imwrite signals failure through its return value
# rather than raising, so check it.
output_path = "mask_visualization.jpg"
if not cv2.imwrite(output_path, image_copy):
    raise OSError(f"Failed to write {output_path}")

# Show the result only when a display is plausibly available. On a headless
# Linux host, headless OpenCV builds raise cv2.error and GUI builds may abort
# the whole process when no X11/Wayland display exists, so check beforehand.
has_display = not sys.platform.startswith("linux") or bool(
    os.environ.get("DISPLAY") or os.environ.get("WAYLAND_DISPLAY")
)
if has_display:
    try:
        cv2.namedWindow("SAM Mask Visualization", cv2.WINDOW_NORMAL)
        cv2.imshow("SAM Mask Visualization", image_copy)
        cv2.waitKey(0)
        cv2.destroyAllWindows()
    except cv2.error as exc:
        print(f"无法显示窗口 ({exc}), 结果已保存至 {output_path}")
else:
    print(f"未检测到图形界面, 结果已保存至 {output_path}")