1
This commit is contained in:
34
main.py
34
main.py
@@ -1,13 +1,15 @@
|
||||
import io
|
||||
import os
|
||||
import urllib.request # 必须这样写,不能只 import urllib
|
||||
from typing import Optional, List
|
||||
|
||||
import cv2
|
||||
import litserve as ls
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image, ImageDraw
|
||||
from minio import Minio
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
from fastapi import Response # 导入 FastAPI 的 Response
|
||||
|
||||
from config import settings
|
||||
@@ -18,9 +20,11 @@ minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secre
|
||||
|
||||
|
||||
class SAMRequest(BaseModel):
|
||||
image_path: str
|
||||
points: list[list[float]]
|
||||
labels: list[int]
|
||||
image_path: str = Field(..., description="图片路径,必填字段")
|
||||
type: str = Field(..., description="推理类型,必填字段")
|
||||
points: Optional[List[List[float]]] = None
|
||||
labels: Optional[List[int]] = None
|
||||
box: Optional[List[int]] = None
|
||||
|
||||
|
||||
class SimpleLitAPI(ls.LitAPI):
|
||||
@@ -64,11 +68,23 @@ class SimpleLitAPI(ls.LitAPI):
|
||||
input_points = np.array(request.points)
|
||||
input_labels = np.array(request.labels)
|
||||
|
||||
masks, scores, logits = self.predictor.predict(
|
||||
point_coords=input_points,
|
||||
point_labels=input_labels,
|
||||
multimask_output=False
|
||||
)
|
||||
if request.type == "point":
|
||||
masks, scores, logits = self.predictor.predict(
|
||||
point_coords=input_points,
|
||||
point_labels=input_labels,
|
||||
multimask_output=False
|
||||
)
|
||||
elif request.type == "box":
|
||||
box = np.array(request.box)
|
||||
masks, scores, logits = self.predictor.predict(
|
||||
box=box,
|
||||
multimask_output=False
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"request.type 参数错误!合法值为 'point' 或 'box',实际传入:{request.type}"
|
||||
)
|
||||
|
||||
mask = masks[0] # 获取第一个掩码
|
||||
|
||||
image = Image.fromarray(image)
|
||||
|
||||
@@ -170,7 +170,7 @@ class SAMBoxSegmenter:
|
||||
|
||||
# 准备框坐标
|
||||
box = np.array(self.current_box)
|
||||
|
||||
print(self.current_box)
|
||||
# 调用SAM生成掩码
|
||||
masks, _, _ = self.predictor.predict(
|
||||
box=box,
|
||||
|
||||
Reference in New Issue
Block a user