Files
aida_seg_anything/scripts/test.py

42 lines
1.6 KiB
Python
Raw Normal View History

2026-01-08 14:35:23 +08:00
import time
import random
import cv2
import numpy as np
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
# 初始化SAM模型
sam = sam_model_registry["vit_h"](checkpoint="/workspace/PycharmProjects/segment-anything/checkpoint/sam_vit_h_4b8939.pth")
mask_generator = SamAutomaticMaskGenerator(sam)
# 读取并转换图像格式
image = cv2.imread("ae976103-d7ec-4eed-b5d1-3e5f04d8be26.jpg")
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # SAM需要RGB格式
image_copy = image.copy() # 用于绘制掩码的原图副本BGR格式适配OpenCV
# 生成掩码并计时
start_time = time.time()
masks = mask_generator.generate(image_rgb)
print(f"掩码生成耗时: {time.time() - start_time:.2f}")
print(f"共生成 {len(masks)} 个掩码")
# 定义颜色生成函数随机RGB颜色适配OpenCV的BGR格式
def get_random_color():
return (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
# 遍历所有掩码并绘制到图像上
for mask in masks:
# 获取掩码的二进制数组True/False
mask_array = mask["segmentation"]
# 生成随机颜色
color = get_random_color()
# 将掩码区域绘制到图像副本上(半透明效果)
image_copy[mask_array] = image_copy[mask_array] * 0.5 + np.array(color) * 0.5
# 保存结果图像
cv2.imwrite("mask_visualization.jpg", image_copy)
# 显示结果(如果是桌面环境)
cv2.namedWindow("SAM Mask Visualization", cv2.WINDOW_NORMAL)
cv2.imshow("SAM Mask Visualization", image_copy)
cv2.waitKey(0)
cv2.destroyAllWindows()