91 lines
2.9 KiB
Python
91 lines
2.9 KiB
Python
import io
|
|
import os
|
|
import urllib.request # 必须这样写,不能只 import urllib
|
|
import cv2
|
|
import litserve as ls
|
|
import numpy as np
|
|
import torch
|
|
from PIL import Image, ImageDraw
|
|
from minio import Minio
|
|
from pydantic import BaseModel
|
|
from fastapi import Response # 导入 FastAPI 的 Response
|
|
|
|
from config import settings
|
|
from segment_anything import SamPredictor, sam_model_registry
|
|
from utils.minio_client import oss_get_image, oss_upload_image
|
|
|
|
# Module-wide MinIO client; predict() uses it both to read the input image and
# to upload the masked PNG result.
minio_client = Minio(
    settings.MINIO_URL,
    access_key=settings.MINIO_ACCESS,
    secret_key=settings.MINIO_SECRET,
    secure=settings.MINIO_SECURE,
)
|
|
|
|
|
|
class SAMRequest(BaseModel):
    """Request body for a point-prompted SAM segmentation call."""

    # Object path of the source image in MinIO (forwarded to oss_get_image as `path`).
    image_path: str
    # Prompt point coordinates (forwarded to SamPredictor.predict as `point_coords`).
    points: list[list[float]]
    # Prompt label for each point (forwarded to SamPredictor.predict as `point_labels`).
    labels: list[int]
|
|
|
|
|
|
class SimpleLitAPI(ls.LitAPI):
    """LitServe API running Segment Anything (SAM) with point prompts.

    Each request names an image stored in MinIO plus prompt points and labels.
    The predicted mask is written back to MinIO as a PNG whose alpha channel is
    the mask, and the uploaded object's "<bucket>/<object>" path is returned.
    """

    # Model/checkpoint settings; override in a subclass to use another SAM variant.
    MODEL_URL = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
    CHECKPOINT_PATH = "checkpoint/sam_vit_h_4b8939.pth"
    MODEL_TYPE = "vit_h"
    # Bucket that receives the masked PNG outputs.
    OUTPUT_BUCKET = "test"

    def setup(self, device):
        """Fetch the SAM checkpoint if missing and load the predictor.

        Args:
            device: Device passed in by LitServe for this worker. If falsy, the
                model goes to "cuda" when available, otherwise "cpu".
        """
        self._ensure_checkpoint(self.CHECKPOINT_PATH, self.MODEL_URL)

        # Honour the device LitServe assigns to this worker; ignoring it would put
        # every worker's model on the same default GPU.
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        print(self.device)

        self.sam = sam_model_registry[self.MODEL_TYPE](checkpoint=self.CHECKPOINT_PATH)
        self.sam.to(device=self.device)
        self.predictor = SamPredictor(self.sam)

    @staticmethod
    def _ensure_checkpoint(path, url):
        """Download `url` to `path` unless a file already exists there."""
        if os.path.isfile(path):
            return

        directory = os.path.dirname(path)
        if directory:
            # exist_ok: the directory can already exist even though the file does not.
            os.makedirs(directory, exist_ok=True)

        print("正在下载权重文件,请稍候...")
        # Download under a temporary name and rename only on success, so an
        # interrupted download never leaves a truncated file that passes
        # isfile() on the next start.
        partial = path + ".part"
        urllib.request.urlretrieve(url, partial)
        os.replace(partial, path)
        print("下载完成。")

    def decode_request(self, request: SAMRequest):
        """Validate the parsed request and hand it to `predict` unchanged.

        Raises:
            HTTPException: 400 when the prompt points/labels are malformed, which
                would otherwise fail deep inside SAM as a server error.
        """
        from fastapi import HTTPException

        if not request.points:
            raise HTTPException(status_code=400, detail="points must not be empty")
        if any(len(point) != 2 for point in request.points):
            raise HTTPException(status_code=400, detail="each point must be [x, y]")
        if len(request.labels) != len(request.points):
            raise HTTPException(
                status_code=400, detail="labels must contain one entry per point"
            )
        return request

    def predict(self, request):
        """Segment the requested image and upload it as a masked PNG.

        Args:
            request: The `SAMRequest` returned by `decode_request`.

        Returns:
            `{"output": "<bucket>/<object name>"}` of the uploaded PNG.

        Raises:
            RuntimeError: If the masked image cannot be PNG-encoded.
        """
        import uuid

        # Converted BGR->RGB below, so the "cv2" loader presumably returns a
        # 3-channel BGR uint8 array — TODO confirm against utils.minio_client.
        image = oss_get_image(
            oss_client=minio_client,
            path=request.image_path,
            data_type="cv2",
        )
        self.predictor.set_image(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))

        masks, _scores, _logits = self.predictor.predict(
            point_coords=np.array(request.points),
            point_labels=np.array(request.labels),
            multimask_output=False,
        )
        mask = masks[0]  # take the first mask

        # Keep the original pixels and use the mask as alpha (255 inside, 0
        # outside). The array stays in B,G,R,A order, which is what cv2.imencode expects.
        bgra = cv2.cvtColor(image, cv2.COLOR_BGR2BGRA)
        bgra[:, :, 3] = mask.astype(np.uint8) * 255
        ok, png = cv2.imencode(".png", bgra)
        if not ok:
            raise RuntimeError("failed to encode segmentation result as PNG")

        # Unique object name per request so concurrent or back-to-back requests
        # don't overwrite each other's result before the client fetches it.
        result = oss_upload_image(
            oss_client=minio_client,
            bucket=self.OUTPUT_BUCKET,
            object_name=f"{uuid.uuid4().hex}.png",
            image_bytes=png,
        )
        return {"output": f"{result.bucket_name}/{result.object_name}"}
|
|
|
|
|
|
if __name__ == "__main__":
    api = SimpleLitAPI()
    # Choose the accelerator that actually exists on this host; hard-coding
    # "cuda" prevents the server from starting on CPU-only machines.
    accelerator = "cuda" if torch.cuda.is_available() else "cpu"
    server = ls.LitServer(api, accelerator=accelerator)
    server.run(port=8777)
|