42 lines
1.6 KiB
Python
42 lines
1.6 KiB
Python
|
|
import time
|
|||
|
|
import random
|
|||
|
|
import cv2
|
|||
|
|
import numpy as np
|
|||
|
|
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
|
|||
|
|
|
|||
|
|
# 初始化SAM模型
|
|||
|
|
sam = sam_model_registry["vit_h"](checkpoint="/workspace/PycharmProjects/segment-anything/checkpoint/sam_vit_h_4b8939.pth")
|
|||
|
|
mask_generator = SamAutomaticMaskGenerator(sam)
|
|||
|
|
|
|||
|
|
# 读取并转换图像格式
|
|||
|
|
image = cv2.imread("ae976103-d7ec-4eed-b5d1-3e5f04d8be26.jpg")
|
|||
|
|
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # SAM需要RGB格式
|
|||
|
|
image_copy = image.copy() # 用于绘制掩码的原图副本(BGR格式,适配OpenCV)
|
|||
|
|
|
|||
|
|
# 生成掩码并计时
|
|||
|
|
start_time = time.time()
|
|||
|
|
masks = mask_generator.generate(image_rgb)
|
|||
|
|
print(f"掩码生成耗时: {time.time() - start_time:.2f} 秒")
|
|||
|
|
print(f"共生成 {len(masks)} 个掩码")
|
|||
|
|
|
|||
|
|
# 定义颜色生成函数(随机RGB颜色,适配OpenCV的BGR格式)
|
|||
|
|
def get_random_color():
|
|||
|
|
return (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
|
|||
|
|
|
|||
|
|
# 遍历所有掩码并绘制到图像上
|
|||
|
|
for mask in masks:
|
|||
|
|
# 获取掩码的二进制数组(True/False)
|
|||
|
|
mask_array = mask["segmentation"]
|
|||
|
|
# 生成随机颜色
|
|||
|
|
color = get_random_color()
|
|||
|
|
# 将掩码区域绘制到图像副本上(半透明效果)
|
|||
|
|
image_copy[mask_array] = image_copy[mask_array] * 0.5 + np.array(color) * 0.5
|
|||
|
|
|
|||
|
|
# 保存结果图像
|
|||
|
|
cv2.imwrite("mask_visualization.jpg", image_copy)
|
|||
|
|
|
|||
|
|
# 显示结果(如果是桌面环境)
|
|||
|
|
cv2.namedWindow("SAM Mask Visualization", cv2.WINDOW_NORMAL)
|
|||
|
|
cv2.imshow("SAM Mask Visualization", image_copy)
|
|||
|
|
cv2.waitKey(0)
|
|||
|
|
cv2.destroyAllWindows()
|