import io
import os
import urllib.request  # Must be imported like this; a bare `import urllib` does not expose `urllib.request`.
import uuid
from typing import Optional, List

import cv2
import litserve as ls
import numpy as np
import torch
from PIL import Image, ImageDraw
from minio import Minio
from pydantic import BaseModel, Field
from fastapi import Response  # FastAPI's Response class

from config import settings
from segment_anything import SamPredictor, sam_model_registry
from utils.minio_client import oss_get_image, oss_upload_image

# Module-level MinIO client, used both to fetch source images and to upload results.
minio_client = Minio(settings.MINIO_URL,
                     access_key=settings.MINIO_ACCESS,
                     secret_key=settings.MINIO_SECRET,
                     secure=settings.MINIO_SECURE)


class SAMRequest(BaseModel):
    """Request body for one Segment Anything inference call.

    ``type`` selects the prompt mode: ``"point"`` uses ``points``/``labels``,
    ``"box"`` uses ``box``. Which optional fields are required depends on
    ``type``; that check is done in ``SimpleLitAPI.predict``.
    """
    user_id: int
    image_path: str = Field(..., description="图片路径,必填字段")
    type: str = Field(..., description="推理类型,必填字段")
    # Point prompts as [[x, y], ...]; used when type == "point".
    points: Optional[List[List[float]]] = None
    # One label per point (SAM convention: 1 = foreground, 0 = background).
    labels: Optional[List[int]] = None
    # Box prompt [x0, y0, x1, y1] (SAM's XYXY format); used when type == "box".
    box: Optional[List[int]] = None


class SimpleLitAPI(ls.LitAPI):
    """LitServe API: segment a MinIO-hosted image with SAM and upload the
    cut-out (pixels outside the mask made transparent) back to MinIO as PNG."""

    # Model weights source/location and SAM variant. Override in a subclass
    # to serve a different checkpoint.
    MODEL_URL = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
    SAM_CHECKPOINT = "checkpoint/sam_vit_h_4b8939.pth"
    MODEL_TYPE = "vit_h"
    # Bucket that receives the generated PNGs.
    OUTPUT_BUCKET = "aida-users"

    def setup(self, device):
        """Initialise the SAM model, downloading the weights first if needed.

        Args:
            device: Device string LitServe assigned to this worker
                (e.g. ``"cuda:0"``). If empty, falls back to CUDA when
                available, else CPU.
        """
        self._ensure_checkpoint()
        # Honour the device LitServe assigned: with several GPUs each worker
        # gets its own device, and hard-coding "cuda" would load every model
        # copy onto GPU 0.
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        print(self.device)
        self.sam = sam_model_registry[self.MODEL_TYPE](checkpoint=self.SAM_CHECKPOINT)
        self.sam.to(device=self.device)
        self.predictor = SamPredictor(self.sam)

    def _ensure_checkpoint(self):
        """Download the checkpoint to ``SAM_CHECKPOINT`` if it is not already there."""
        if os.path.isfile(self.SAM_CHECKPOINT):
            return
        # exist_ok: the directory may already exist even though the weight
        # file does not (e.g. after a failed download).
        os.makedirs(os.path.dirname(self.SAM_CHECKPOINT) or ".", exist_ok=True)
        print("正在下载权重文件,请稍候...")
        # Download under a temporary name and rename only on success, so an
        # interrupted download never leaves a truncated file that later
        # startups would accept as a valid checkpoint.
        tmp_path = self.SAM_CHECKPOINT + ".part"
        urllib.request.urlretrieve(self.MODEL_URL, tmp_path)
        os.replace(tmp_path, self.SAM_CHECKPOINT)
        print("下载完成。")

    def decode_request(self, request: SAMRequest):
        """Pass the validated pydantic request straight through to ``predict``."""
        return request

    @staticmethod
    def _validate_prompt(request):
        """Raise ValueError unless ``request`` carries the prompt its ``type`` needs."""
        if request.type == "point":
            if not request.points or not request.labels:
                raise ValueError("request.type 为 'point' 时,points 和 labels 为必填字段")
            if len(request.points) != len(request.labels):
                raise ValueError(
                    f"points 与 labels 数量不一致:{len(request.points)} != {len(request.labels)}"
                )
        elif request.type == "box":
            if request.box is None or len(request.box) != 4:
                raise ValueError(
                    f"request.type 为 'box' 时,box 必须为 [x0, y0, x1, y1],实际传入:{request.box}"
                )
        else:
            raise ValueError(
                f"request.type 参数错误!合法值为 'point' 或 'box',实际传入:{request.type}"
            )

    def predict(self, request):
        """Run SAM on ``request.image_path`` and upload the masked cut-out.

        Returns:
            ``{"output": "<bucket>/<object>"}`` identifying the uploaded PNG.

        Raises:
            ValueError: ``request.type`` is not ``"point"``/``"box"``, or the
                prompt fields that type requires are missing or malformed.
            RuntimeError: the result could not be encoded as PNG.
        """
        # Validate first so bad requests fail before the image download and
        # the expensive embedding computation.
        self._validate_prompt(request)

        # Load the image; the code treats it as a 3-channel BGR array.
        image = oss_get_image(
            oss_client=minio_client, path=request.image_path, data_type="cv2")
        self.predictor.set_image(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))

        if request.type == "point":
            masks, _, _ = self.predictor.predict(
                point_coords=np.array(request.points),
                point_labels=np.array(request.labels),
                multimask_output=False,
            )
        else:  # "box" -- guaranteed by _validate_prompt
            masks, _, _ = self.predictor.predict(
                box=np.array(request.box),
                multimask_output=False,
            )
        mask = masks[0]  # multimask_output=False -> take the single mask

        # Add an alpha channel (BGR -> BGRA, the order cv2.imencode expects)
        # and use the mask as alpha so everything outside it is transparent.
        bgra = cv2.cvtColor(image, cv2.COLOR_BGR2BGRA)
        bgra[:, :, 3] = mask.astype(np.uint8) * 255
        ok, png = cv2.imencode('.png', bgra)
        if not ok:
            raise RuntimeError("PNG encoding of the segmentation result failed")

        object_name = f"{request.user_id}/seg_anything/{uuid.uuid4()}"
        req = oss_upload_image(
            oss_client=minio_client,
            bucket=self.OUTPUT_BUCKET,
            object_name=f"{object_name}.png",
            image_bytes=png,
        )
        return {"output": f"{req.bucket_name}/{req.object_name}"}


if __name__ == "__main__":
    api = SimpleLitAPI()
    server = ls.LitServer(api, accelerator="cuda")
    server.run(port=8777)