diff --git a/main.py b/main.py index 06b627f..0df5e42 100644 --- a/main.py +++ b/main.py @@ -1,13 +1,15 @@ import io import os import urllib.request # 必须这样写,不能只 import urllib +from typing import Optional, List + import cv2 import litserve as ls import numpy as np import torch from PIL import Image, ImageDraw from minio import Minio -from pydantic import BaseModel +from pydantic import BaseModel, Field from fastapi import Response # 导入 FastAPI 的 Response from config import settings @@ -18,9 +20,11 @@ minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secre class SAMRequest(BaseModel): - image_path: str - points: list[list[float]] - labels: list[int] + image_path: str = Field(..., description="图片路径,必填字段") + type: str = Field(..., description="推理类型,必填字段") + points: Optional[List[List[float]]] = None + labels: Optional[List[int]] = None + box: Optional[List[int]] = None class SimpleLitAPI(ls.LitAPI): @@ -64,11 +68,23 @@ class SimpleLitAPI(ls.LitAPI): input_points = np.array(request.points) input_labels = np.array(request.labels) - masks, scores, logits = self.predictor.predict( - point_coords=input_points, - point_labels=input_labels, - multimask_output=False - ) + if request.type == "point": + masks, scores, logits = self.predictor.predict( + point_coords=input_points, + point_labels=input_labels, + multimask_output=False + ) + elif request.type == "box": + box = np.array(request.box) + masks, scores, logits = self.predictor.predict( + box=box, + multimask_output=False + ) + else: + raise ValueError( + f"request.type 参数错误!合法值为 'point' 或 'box',实际传入:{request.type}" + ) + mask = masks[0] # 获取第一个掩码 image = Image.fromarray(image) diff --git a/scripts/amg_box.py b/scripts/amg_box.py index f9a654c..20609ac 100755 --- a/scripts/amg_box.py +++ b/scripts/amg_box.py @@ -170,7 +170,7 @@ class SAMBoxSegmenter: # 准备框坐标 box = np.array(self.current_box) - + print(self.current_box) # 调用SAM生成掩码 masks, _, _ = self.predictor.predict( box=box,