1
This commit is contained in:
90
main.py
Normal file
90
main.py
Normal file
@@ -0,0 +1,90 @@
|
||||
import io
|
||||
import os
|
||||
import urllib.request # 必须这样写,不能只 import urllib
|
||||
import cv2
|
||||
import litserve as ls
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image, ImageDraw
|
||||
from minio import Minio
|
||||
from pydantic import BaseModel
|
||||
from fastapi import Response # 导入 FastAPI 的 Response
|
||||
|
||||
from config import settings
|
||||
from segment_anything import SamPredictor, sam_model_registry
|
||||
from utils.minio_client import oss_get_image, oss_upload_image
|
||||
|
||||
minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)
|
||||
|
||||
|
||||
class SAMRequest(BaseModel):
|
||||
image_path: str
|
||||
points: list[list[float]]
|
||||
labels: list[int]
|
||||
|
||||
|
||||
class SimpleLitAPI(ls.LitAPI):
|
||||
|
||||
# class SimpleLitAPI():
|
||||
def setup(self, device):
|
||||
# def __init__(self, device, sam_checkpoint, model_type="vit_h"):
|
||||
# 初始化SAM模型
|
||||
model_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
|
||||
sam_checkpoint = "checkpoint/sam_vit_h_4b8939.pth"
|
||||
model_type = "vit_h"
|
||||
|
||||
# 自动化下载检查
|
||||
if not os.path.exists(sam_checkpoint):
|
||||
os.makedirs(os.path.dirname(sam_checkpoint))
|
||||
|
||||
if not os.path.isfile(sam_checkpoint):
|
||||
print("正在下载权重文件,请稍候...")
|
||||
urllib.request.urlretrieve(model_url, sam_checkpoint)
|
||||
print("下载完成。")
|
||||
|
||||
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
print(self.device)
|
||||
self.sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
|
||||
self.sam.to(device=self.device)
|
||||
self.predictor = SamPredictor(self.sam)
|
||||
|
||||
def decode_request(self, request: SAMRequest):
|
||||
return request
|
||||
|
||||
def predict(self, request):
|
||||
# 加载图像
|
||||
image = oss_get_image(
|
||||
oss_client=minio_client,
|
||||
path=request.image_path,
|
||||
data_type="cv2")
|
||||
|
||||
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||
self.predictor.set_image(image_rgb)
|
||||
|
||||
input_points = np.array(request.points)
|
||||
input_labels = np.array(request.labels)
|
||||
|
||||
masks, scores, logits = self.predictor.predict(
|
||||
point_coords=input_points,
|
||||
point_labels=input_labels,
|
||||
multimask_output=False
|
||||
)
|
||||
mask = masks[0] # 获取第一个掩码
|
||||
|
||||
image = Image.fromarray(image)
|
||||
rgba_image = image.convert("RGBA")
|
||||
rgba_np = np.array(rgba_image)
|
||||
rgba_np[:, :, 3] = mask.astype(np.uint8) * 255
|
||||
req = oss_upload_image(
|
||||
oss_client=minio_client,
|
||||
bucket="test",
|
||||
object_name=f"test.png",
|
||||
image_bytes=cv2.imencode('.png', rgba_np)[1]
|
||||
)
|
||||
return {"output": f"{req.bucket_name}/{req.object_name}"}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
api = SimpleLitAPI()
|
||||
server = ls.LitServer(api, accelerator="cuda")
|
||||
server.run(port=8777)
|
||||
Reference in New Issue
Block a user