Files
aida_seg_anything/main.py
2026-03-27 14:47:25 +08:00

111 lines
3.6 KiB
Python

import os
import urllib.request
from typing import Optional, List
import cv2
import litserve as ls
import numpy as np
import torch
from PIL import Image
from minio import Minio
from pydantic import BaseModel, Field
from config import settings
from segment_anything import SamPredictor, sam_model_registry
from utils.minio_client import oss_get_image, oss_upload_image
# Module-wide MinIO client; predict() below uses it both to read the source
# image and to upload the segmented result.
minio_client = Minio(
    settings.MINIO_URL,
    access_key=settings.MINIO_ACCESS,
    secret_key=settings.MINIO_SECRET,
    secure=settings.MINIO_SECURE,
)
class SAMRequest(BaseModel):
    """Request body for the SAM segmentation endpoint.

    ``type`` selects the prompt kind: ``"point"`` uses ``points`` and
    ``labels``; ``"box"`` uses ``box``.  The segmented result is written to
    ``bucket``/``object_name``.
    """

    # Destination bucket for the result image.
    bucket: str = Field(...)
    # Destination object key for the result image.
    object_name: str = Field(...)
    image_path: str = Field(..., description="图片路径,必填字段")
    type: str = Field(..., description="推理类型,必填字段")
    # Point prompts, used when type == "point"; forwarded to SAM as
    # point_coords (SAM documents these as N x 2 (x, y) pixel coordinates).
    points: Optional[List[List[float]]] = None
    # One label per point, used when type == "point"; forwarded as point_labels.
    labels: Optional[List[int]] = None
    # Box prompt, used when type == "box" (SAM documents XYXY pixel coords).
    # Floats are accepted so sub-pixel boxes validate; ints still validate.
    box: Optional[List[float]] = None
class SimpleLitAPI(ls.LitAPI):
    """LitServe API that segments a MinIO-hosted image with Segment Anything.

    A request names a source image (``image_path``) plus a prompt: points
    with labels, or a box.  The first predicted mask becomes the alpha
    channel of the source image, which is then uploaded as a PNG to
    ``bucket``/``object_name``.
    """

    # SAM checkpoint source, local cache path, and model registry key.  Kept
    # as class attributes so a subclass can switch the SAM variant without
    # rewriting setup().
    MODEL_URL = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
    SAM_CHECKPOINT = "checkpoint/sam_vit_h_4b8939.pth"
    MODEL_TYPE = "vit_h"

    def setup(self, device):
        """Fetch the checkpoint if missing, then load SAM and its predictor.

        Args:
            device: Device string LitServe assigns to this worker (presumably
                e.g. ``"cuda:0"``).  When empty, falls back to ``"cuda"`` if
                available, else ``"cpu"``.
        """
        self._ensure_checkpoint(self.SAM_CHECKPOINT, self.MODEL_URL)
        # Honour the worker's assigned device so multi-GPU deployments spread
        # workers across cards instead of piling every worker onto cuda:0.
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        print(self.device)
        self.sam = sam_model_registry[self.MODEL_TYPE](checkpoint=self.SAM_CHECKPOINT)
        self.sam.to(device=self.device)
        self.predictor = SamPredictor(self.sam)

    @staticmethod
    def _ensure_checkpoint(path, url):
        """Download ``url`` to ``path`` unless ``path`` is already a file."""
        if os.path.isfile(path):
            return
        directory = os.path.dirname(path)
        if directory:
            # exist_ok: the directory may already exist even though the
            # checkpoint file itself does not.
            os.makedirs(directory, exist_ok=True)
        print("正在下载权重文件,请稍候...")
        # Download under a temporary name and rename only on success, so an
        # interrupted download never leaves a truncated file that would pass
        # the isfile() check on the next start.
        tmp_path = path + ".part"
        try:
            urllib.request.urlretrieve(url, tmp_path)
            os.replace(tmp_path, path)
        finally:
            # After a successful replace the temp file is gone; otherwise
            # remove whatever partial download was left behind.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
        print("下载完成。")

    def decode_request(self, request: SAMRequest):
        """Pass the parsed ``SAMRequest`` through unchanged."""
        return request

    @staticmethod
    def _build_prompt(request):
        """Turn the request's prompt fields into ``SamPredictor.predict`` kwargs.

        Raises:
            ValueError: ``type`` is not ``"point"``/``"box"``, or the fields
                that prompt kind needs are missing.
        """
        if request.type == "point":
            if request.points is None or request.labels is None:
                raise ValueError("type 'point' requires both 'points' and 'labels'")
            return {
                "point_coords": np.array(request.points),
                "point_labels": np.array(request.labels),
            }
        if request.type == "box":
            if request.box is None:
                raise ValueError("type 'box' requires 'box'")
            return {"box": np.array(request.box)}
        raise ValueError(
            f"request.type 参数错误!合法值为 'point''box',实际传入:{request.type}"
        )

    def predict(self, request):
        """Segment the requested image and upload it with the mask as alpha.

        Args:
            request: A ``SAMRequest``.

        Returns:
            ``{"output": "<bucket_name>/<object_name>"}`` built from the upload
            result.

        Raises:
            ValueError: Invalid prompt, unreadable image, or PNG encoding
                failure.
        """
        # Validate before doing any network or model work.
        prompt = self._build_prompt(request)

        image = oss_get_image(
            oss_client=minio_client,
            path=request.image_path,
            data_type="cv2")
        # NOTE(review): assumes oss_get_image may return None on failure,
        # as cv2.imdecode does; confirm against utils.minio_client.
        if image is None:
            raise ValueError(f"could not load image {request.image_path!r}")

        # The image is treated as BGR (OpenCV order); SAM wants RGB.
        self.predictor.set_image(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        masks, _scores, _logits = self.predictor.predict(
            **prompt,
            multimask_output=False,
        )
        mask = masks[0]  # the single mask produced with multimask_output=False

        # Keep the original BGR pixels and use the mask as the alpha channel;
        # cv2.imencode interprets 4-channel input as BGRA.
        bgra = cv2.cvtColor(image, cv2.COLOR_BGR2BGRA)
        bgra[:, :, 3] = mask.astype(np.uint8) * 255
        ok, png = cv2.imencode('.png', bgra)
        if not ok:
            raise ValueError("failed to encode segmentation result as PNG")

        result = oss_upload_image(
            oss_client=minio_client,
            bucket=request.bucket,
            object_name=request.object_name,
            image_bytes=png,
        )
        return {"output": f"{result.bucket_name}/{result.object_name}"}
if __name__ == "__main__":
    # Mirror setup()'s fallback: serve on CUDA when present, otherwise CPU,
    # rather than refusing to start on machines without a GPU.
    accelerator = "cuda" if torch.cuda.is_available() else "cpu"
    api = SimpleLitAPI()
    server = ls.LitServer(api, accelerator=accelerator)
    server.run(port=8777)