Files
aida_seg_anything/main.py

111 lines
3.7 KiB
Python
Raw Normal View History

2026-01-08 17:02:33 +08:00
import io
import os
import urllib.request # 必须这样写,不能只 import urllib
2026-01-13 12:39:51 +08:00
import uuid
2026-01-13 12:01:20 +08:00
from typing import Optional, List
2026-01-08 17:02:33 +08:00
import cv2
import litserve as ls
import numpy as np
import torch
from PIL import Image, ImageDraw
from minio import Minio
2026-01-13 12:01:20 +08:00
from pydantic import BaseModel, Field
2026-01-08 17:02:33 +08:00
from fastapi import Response # 导入 FastAPI 的 Response
from config import settings
from segment_anything import SamPredictor, sam_model_registry
from utils.minio_client import oss_get_image, oss_upload_image
# Module-level MinIO client, built once from the project settings and reused
# for every image download and upload performed by this service.
minio_client = Minio(
    settings.MINIO_URL,
    access_key=settings.MINIO_ACCESS,
    secret_key=settings.MINIO_SECRET,
    secure=settings.MINIO_SECURE,
)
# Request payload for one segmentation call.
#
# Documented with comments rather than a class docstring so that the schema
# pydantic generates for this model is left exactly as it was.
class SAMRequest(BaseModel):
    # Becomes the leading path segment of the uploaded result object.
    user_id: int
    # Path of the source image, handed to the object-storage download helper.
    image_path: str = Field(..., description="图片路径,必填字段")
    # Prompt mode; the predictor handles "point" and "box".
    type: str = Field(..., description="推理类型,必填字段")
    # Prompt points for "point" mode.
    # NOTE(review): presumably [x, y] pixel coordinates - confirm with callers.
    points: Optional[List[List[float]]] = Field(default=None)
    # One label per prompt point.
    # NOTE(review): presumably 1 = foreground, 0 = background (SAM convention).
    labels: Optional[List[int]] = Field(default=None)
    # Prompt box for "box" mode.
    # NOTE(review): presumably [x0, y0, x1, y1] in pixels - confirm with callers.
    box: Optional[List[int]] = Field(default=None)
2026-01-08 17:02:33 +08:00
class SimpleLitAPI(ls.LitAPI):
    """LitServe API that cuts an object out of an image with Segment Anything.

    For each request the source image is fetched from object storage, a single
    mask is predicted from either point prompts or a box prompt, the mask is
    written into the alpha channel of the image, and the resulting PNG is
    uploaded back to object storage.
    """

    def setup(self, device):
        """Ensure the SAM ViT-H checkpoint is on disk and build the predictor.

        Args:
            device: Device string supplied by LitServe.
                NOTE(review): it is currently ignored; the model is placed on
                "cuda" whenever CUDA is available and on "cpu" otherwise.
                Confirm whether multi-GPU placement should honour it.
        """
        model_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
        sam_checkpoint = "checkpoint/sam_vit_h_4b8939.pth"
        model_type = "vit_h"

        if not os.path.isfile(sam_checkpoint):
            # exist_ok=True: the directory may already exist while the weight
            # file itself is missing (for example after a manual clean-up).
            os.makedirs(os.path.dirname(sam_checkpoint), exist_ok=True)
            print("正在下载权重文件,请稍候...")
            # Download to a side file and rename afterwards so an interrupted
            # download never leaves a truncated checkpoint at the final path.
            partial_path = sam_checkpoint + ".part"
            urllib.request.urlretrieve(model_url, partial_path)
            os.replace(partial_path, sam_checkpoint)
            print("下载完成。")

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(self.device)
        self.sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
        self.sam.to(device=self.device)
        self.predictor = SamPredictor(self.sam)

    def decode_request(self, request: SAMRequest):
        """Return the already-validated request model unchanged."""
        return request

    def predict(self, request):
        """Segment the requested image and upload the cut-out as a PNG.

        Args:
            request: Object exposing ``user_id``, ``image_path``, ``type``,
                ``points``, ``labels`` and ``box`` (see ``SAMRequest``).

        Returns:
            ``{"output": "<bucket>/<object name>"}`` for the uploaded PNG.

        Raises:
            ValueError: If ``type`` is neither "point" nor "box", if the
                prompt fields required by the chosen type are missing or
                inconsistent, or if the source image could not be loaded.
            RuntimeError: If the result cannot be encoded as PNG.
        """
        # Validate the prompt before downloading the image and computing its
        # embedding, which are by far the most expensive steps.
        prompt_kwargs = self._build_prompt(request)

        image = oss_get_image(
            oss_client=minio_client,
            path=request.image_path,
            data_type="cv2",
        )
        # NOTE(review): presumably the helper decodes with OpenCV, which
        # yields None for a missing or corrupt file - confirm.
        if image is None:
            raise ValueError(f"无法加载图片:{request.image_path}")

        # SAM expects RGB input; the image is treated as OpenCV-style BGR.
        self.predictor.set_image(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        masks, _, _ = self.predictor.predict(multimask_output=False, **prompt_kwargs)
        mask = masks[0]

        # PIL does not reorder channels, so feeding it the BGR array and
        # converting to "RGBA" produces B, G, R, A - exactly the channel
        # order cv2.imencode expects for a 4-channel PNG.
        bgra = np.array(Image.fromarray(image).convert("RGBA"))
        bgra[:, :, 3] = mask.astype(np.uint8) * 255

        encoded_ok, png_buffer = cv2.imencode(".png", bgra)
        if not encoded_ok:
            raise RuntimeError("PNG 编码失败")

        object_name = f"{request.user_id}/seg_anything/{uuid.uuid4()}"
        req = oss_upload_image(
            oss_client=minio_client,
            bucket="aida-users",
            object_name=f"{object_name}.png",
            image_bytes=png_buffer,
        )
        return {"output": f"{req.bucket_name}/{req.object_name}"}

    @staticmethod
    def _build_prompt(request):
        """Translate the request into keyword arguments for the SAM predictor."""
        if request.type == "point":
            if not request.points or not request.labels:
                raise ValueError("type 为 'point' 时必须提供非空的 points 和 labels")
            if len(request.points) != len(request.labels):
                raise ValueError("points 与 labels 的数量必须一致")
            return {
                "point_coords": np.array(request.points),
                "point_labels": np.array(request.labels),
            }
        if request.type == "box":
            if request.box is None or len(request.box) != 4:
                raise ValueError("type 为 'box' 时必须提供包含 4 个数值的 box")
            return {"box": np.array(request.box)}
        raise ValueError(
            f"request.type 参数错误!合法值为 'point' 或 'box',实际传入:{request.type}"
        )
if __name__ == "__main__":
    # Script entry point: serve the segmentation API on port 8777.
    # NOTE(review): the accelerator is hard-coded to "cuda" even though
    # SimpleLitAPI.setup can fall back to CPU - confirm this is intended.
    ls.LitServer(SimpleLitAPI(), accelerator="cuda").run(port=8777)