1
This commit is contained in:
34
main.py
34
main.py
@@ -1,13 +1,15 @@
|
|||||||
import io
|
import io
|
||||||
import os
|
import os
|
||||||
import urllib.request # 必须这样写,不能只 import urllib
|
import urllib.request # 必须这样写,不能只 import urllib
|
||||||
|
from typing import Optional, List
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import litserve as ls
|
import litserve as ls
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from PIL import Image, ImageDraw
|
from PIL import Image, ImageDraw
|
||||||
from minio import Minio
|
from minio import Minio
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel, Field
|
||||||
from fastapi import Response # 导入 FastAPI 的 Response
|
from fastapi import Response # 导入 FastAPI 的 Response
|
||||||
|
|
||||||
from config import settings
|
from config import settings
|
||||||
@@ -18,9 +20,11 @@ minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secre
|
|||||||
|
|
||||||
|
|
||||||
class SAMRequest(BaseModel):
|
class SAMRequest(BaseModel):
|
||||||
image_path: str
|
image_path: str = Field(..., description="图片路径,必填字段")
|
||||||
points: list[list[float]]
|
type: str = Field(..., description="推理类型,必填字段")
|
||||||
labels: list[int]
|
points: Optional[List[List[float]]] = None
|
||||||
|
labels: Optional[List[int]] = None
|
||||||
|
box: Optional[List[int]] = None
|
||||||
|
|
||||||
|
|
||||||
class SimpleLitAPI(ls.LitAPI):
|
class SimpleLitAPI(ls.LitAPI):
|
||||||
@@ -64,11 +68,23 @@ class SimpleLitAPI(ls.LitAPI):
|
|||||||
input_points = np.array(request.points)
|
input_points = np.array(request.points)
|
||||||
input_labels = np.array(request.labels)
|
input_labels = np.array(request.labels)
|
||||||
|
|
||||||
masks, scores, logits = self.predictor.predict(
|
if request.type == "point":
|
||||||
point_coords=input_points,
|
masks, scores, logits = self.predictor.predict(
|
||||||
point_labels=input_labels,
|
point_coords=input_points,
|
||||||
multimask_output=False
|
point_labels=input_labels,
|
||||||
)
|
multimask_output=False
|
||||||
|
)
|
||||||
|
elif request.type == "box":
|
||||||
|
box = np.array(request.box)
|
||||||
|
masks, scores, logits = self.predictor.predict(
|
||||||
|
box=box,
|
||||||
|
multimask_output=False
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"request.type 参数错误!合法值为 'point' 或 'box',实际传入:{request.type}"
|
||||||
|
)
|
||||||
|
|
||||||
mask = masks[0] # 获取第一个掩码
|
mask = masks[0] # 获取第一个掩码
|
||||||
|
|
||||||
image = Image.fromarray(image)
|
image = Image.fromarray(image)
|
||||||
|
|||||||
@@ -170,7 +170,7 @@ class SAMBoxSegmenter:
|
|||||||
|
|
||||||
# 准备框坐标
|
# 准备框坐标
|
||||||
box = np.array(self.current_box)
|
box = np.array(self.current_box)
|
||||||
|
print(self.current_box)
|
||||||
# 调用SAM生成掩码
|
# 调用SAM生成掩码
|
||||||
masks, _, _ = self.predictor.predict(
|
masks, _, _ = self.predictor.predict(
|
||||||
box=box,
|
box=box,
|
||||||
|
|||||||
Reference in New Issue
Block a user