This commit is contained in:
zcr
2026-01-13 12:01:20 +08:00
parent 8fa66f327f
commit d1d15f7980
2 changed files with 26 additions and 10 deletions

34
main.py
View File

@@ -1,13 +1,15 @@
import io
import os
import urllib.request # 必须这样写,不能只 import urllib
from typing import Optional, List
import cv2
import litserve as ls
import numpy as np
import torch
from PIL import Image, ImageDraw
from minio import Minio
from pydantic import BaseModel
from pydantic import BaseModel, Field
from fastapi import Response # 导入 FastAPI 的 Response
from config import settings
@@ -18,9 +20,11 @@ minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secre
class SAMRequest(BaseModel):
image_path: str
points: list[list[float]]
labels: list[int]
image_path: str = Field(..., description="图片路径,必填字段")
type: str = Field(..., description="推理类型,必填字段")
points: Optional[List[List[float]]] = None
labels: Optional[List[int]] = None
box: Optional[List[int]] = None
class SimpleLitAPI(ls.LitAPI):
@@ -64,11 +68,23 @@ class SimpleLitAPI(ls.LitAPI):
input_points = np.array(request.points)
input_labels = np.array(request.labels)
masks, scores, logits = self.predictor.predict(
point_coords=input_points,
point_labels=input_labels,
multimask_output=False
)
if request.type == "point":
masks, scores, logits = self.predictor.predict(
point_coords=input_points,
point_labels=input_labels,
multimask_output=False
)
elif request.type == "box":
box = np.array(request.box)
masks, scores, logits = self.predictor.predict(
box=box,
multimask_output=False
)
else:
raise ValueError(
f"request.type 参数错误!合法值为 'point''box',实际传入:{request.type}"
)
mask = masks[0] # 获取第一个掩码
image = Image.fromarray(image)

View File

@@ -170,7 +170,7 @@ class SAMBoxSegmenter:
# 准备框坐标
box = np.array(self.current_box)
print(self.current_box)
# 调用SAM生成掩码
masks, _, _ = self.predictor.predict(
box=box,