This commit is contained in:
zcr
2026-01-08 17:02:33 +08:00
parent b57e61f2f7
commit aab2697ea7
13 changed files with 2113 additions and 6 deletions

4
.gitignore vendored
View File

@@ -39,4 +39,6 @@ _darcs
# demo
**/node_modules
yarn.lock
package-lock.json
package-lock.json
*.env
*.png

40
Dockerfile Normal file
View File

@@ -0,0 +1,40 @@
# Stage used only as a source for the uv / uvx binaries.
FROM ghcr.io/astral-sh/uv:latest AS uv_bin
FROM nvidia/cuda:12.4.1-base-ubuntu22.04

# 1. Base environment configuration
ENV UV_LINK_MODE=copy \
    UV_COMPILE_BYTECODE=1 \
    PYTHONUNBUFFERED=1 \
    UV_PROJECT_ENVIRONMENT=/app/.venv

COPY --from=uv_bin /uv /uvx /bin/

# 2. System packages (build tools, git for the git-sourced dependency,
#    libgl1 / libglib2.0-0 presumably for OpenCV -- confirm)
RUN apt-get update && apt-get install -y --no-install-recommends \
        wget \
        libcurl4-openssl-dev \
        build-essential \
        libgl1 \
        libglib2.0-0 \
        ca-certificates \
        git \
    && apt-get clean \
    && rm -rf /var/lib/apt/lists/*

WORKDIR /app

# 3. Install dependencies only, so this layer is reused until
#    pyproject.toml / uv.lock change
COPY pyproject.toml uv.lock ./
# NOTE: this overrides the UV_COMPILE_BYTECODE=1 set above for all later steps.
ENV UV_COMPILE_BYTECODE=0
RUN uv sync --frozen --no-dev --no-install-project --python 3.10

# 4. Copy the project files and install the project itself
COPY . .
RUN uv sync --frozen --no-dev --python 3.10

ENV PATH="/app/.venv/bin:$PATH"

# main.py starts the server on port 8777 (docker-compose.yml maps 10070:8777),
# so document that port rather than 8000.
EXPOSE 8777

CMD ["uv", "run","-m","main"]
# Debugging alternative: keep the container alive without starting the server.
#CMD ["tail", "-f","/dev/null"]

7
client.py Normal file
View File

@@ -0,0 +1,7 @@
# This file was originally auto-generated by LitServe.
# Disable auto-generation by setting `generate_client_file=False` in `LitServer.run()`.
# The request body was adjusted by hand: the generated `{"input": 4.0}` payload does not
# match the `SAMRequest` model (image_path / points / labels) that main.py decodes.
import requests

# Replace the placeholder values with a real object path and prompt points before running.
payload = {
    # "<bucket>/<object name>"; the server splits it at the first "/".
    "image_path": "bucket-name/path/to/image.png",
    # One [x, y] pair per prompt point.
    "points": [[100.0, 100.0]],
    # One label per point (in SAM, 1 marks a foreground point and 0 a background point).
    "labels": [1],
}
# A timeout keeps the client from hanging forever if the server is unreachable.
response = requests.post("http://127.0.0.1:8777/predict", json=payload, timeout=120)
print(f"Status: {response.status_code}\nResponse:\n {response.text}")

30
config.py Normal file
View File

@@ -0,0 +1,30 @@
import os
from pydantic_settings import BaseSettings, SettingsConfigDict
from pydantic import Field
# ⚠️ 注意: 您需要安装 pydantic-settings: pip install pydantic-settings
class Settings(BaseSettings):
    """Application configuration.

    pydantic-settings loads these values from environment variables and from
    the ``.env`` file; the defaults below apply when a key is absent.
    """

    model_config = SettingsConfigDict(
        env_file='.env',
        env_file_encoding='utf-8',
        extra='ignore',  # ignore extra keys found in the environment / .env file
    )

    # Service start-up port.
    # NOTE(review): the name looks like a typo of SERVE_PORT; it is kept because
    # renaming would change the environment variable that is read. main.py in this
    # commit hard-codes port 8777 and does not read this value -- confirm intent.
    SERVE_PROD: int = Field(default=8000, description='Service start-up port')

    # MinIO connection settings
    MINIO_URL: str = Field(default="", description="MinIO endpoint passed to the Minio client")
    MINIO_ACCESS: str = Field(default="", description="MinIO access key")
    MINIO_SECRET: str = Field(default="", description="MinIO secret key")
    MINIO_SECURE: bool = Field(default=True, description="Value passed as `secure` to the Minio client")
    MINIO_LC_DATA_PATH: str = Field(default="", description="Image data path")


# Shared settings instance imported by the rest of the application.
settings = Settings()

19
docker-compose.yml Normal file
View File

@@ -0,0 +1,19 @@
services:
  aida_seg_anything:
    build:
      context: .
      dockerfile: Dockerfile
    working_dir: /app
    volumes:
      # Bind-mount the project directory over /app.
      - ./:/app
      # Anonymous volume: without it the bind mount above hides the virtualenv
      # that the image build created at /app/.venv.
      - /app/.venv
      - /etc/localtime:/etc/localtime:ro
    ports:
      # host 10070 -> container 8777 (the port main.py serves on)
      - "10070:8777"
    deploy:
      resources:
        reservations:
          devices:
            # Reserve the NVIDIA GPU with device id 0 only (not all GPUs).
            - driver: nvidia
              device_ids: [ '0' ]
              capabilities: [ gpu ]

90
main.py Normal file
View File

@@ -0,0 +1,90 @@
import io
import os
import urllib.request # 必须这样写,不能只 import urllib
import cv2
import litserve as ls
import numpy as np
import torch
from PIL import Image, ImageDraw
from minio import Minio
from pydantic import BaseModel
from fastapi import Response # 导入 FastAPI 的 Response
from config import settings
from segment_anything import SamPredictor, sam_model_registry
from utils.minio_client import oss_get_image, oss_upload_image
# Module-level MinIO client used by SimpleLitAPI.predict; connection parameters
# come from the application settings.
minio_client = Minio(
    settings.MINIO_URL,
    access_key=settings.MINIO_ACCESS,
    secret_key=settings.MINIO_SECRET,
    secure=settings.MINIO_SECURE,
)
class SAMRequest(BaseModel):
    """Request body accepted by SimpleLitAPI (see decode_request / predict)."""

    # Object path handed to oss_get_image, which splits it at the first "/"
    # into bucket and object name.
    image_path: str
    # Prompt points; converted to a NumPy array and passed to
    # SamPredictor.predict as point_coords.
    points: list[list[float]]
    # Point labels; converted to a NumPy array and passed to
    # SamPredictor.predict as point_labels.
    labels: list[int]
class SimpleLitAPI(ls.LitAPI):
    """LitServe API that segments an image with SAM from point prompts.

    The image is read from MinIO, a single mask is predicted, and the image
    with that mask as its alpha channel is uploaded back to MinIO as a PNG.
    """

    def setup(self, device):
        """Ensure the SAM checkpoint exists locally and build the predictor.

        Args:
            device: Device value supplied by LitServe. It is not used here; the
                model is placed on CUDA when available and on CPU otherwise.
        """
        model_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
        sam_checkpoint = "checkpoint/sam_vit_h_4b8939.pth"
        model_type = "vit_h"

        # Download the weights on first start.
        if not os.path.isfile(sam_checkpoint):
            # exist_ok: the directory may already exist while the weights file
            # is missing, which used to raise FileExistsError.
            os.makedirs(os.path.dirname(sam_checkpoint), exist_ok=True)
            print("正在下载权重文件,请稍候...")
            # Download to a temporary name and rename afterwards, so that an
            # interrupted download never leaves a truncated file at the final
            # path (it would pass the isfile() check on the next start).
            partial_path = sam_checkpoint + ".part"
            try:
                urllib.request.urlretrieve(model_url, partial_path)
                os.replace(partial_path, sam_checkpoint)
            finally:
                if os.path.exists(partial_path):
                    os.remove(partial_path)
            print("下载完成。")

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(self.device)
        self.sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
        self.sam.to(device=self.device)
        self.predictor = SamPredictor(self.sam)

    def decode_request(self, request: SAMRequest):
        """Return the already-validated request model unchanged."""
        return request

    def predict(self, request):
        """Segment the requested image and upload the cut-out.

        Args:
            request: A SAMRequest with ``image_path``, ``points`` and ``labels``.

        Returns:
            ``{"output": "<bucket>/<object name>"}`` of the uploaded PNG.

        Raises:
            ValueError: If the image cannot be loaded, or points and labels
                differ in length.
            RuntimeError: If PNG encoding or the upload fails.
        """
        # Load the image (BGR, as decoded by OpenCV). oss_get_image returns
        # None on failure instead of raising.
        image = oss_get_image(
            oss_client=minio_client,
            path=request.image_path,
            data_type="cv2",
        )
        if image is None:
            raise ValueError(f"Could not load image from '{request.image_path}'")
        if len(request.points) != len(request.labels):
            raise ValueError(
                f"points and labels must have the same length, "
                f"got {len(request.points)} and {len(request.labels)}"
            )

        self.predictor.set_image(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        masks, _scores, _logits = self.predictor.predict(
            point_coords=np.array(request.points),
            point_labels=np.array(request.labels),
            multimask_output=False,
        )
        mask = masks[0]  # multimask_output=False -> use the first mask

        # Build a BGRA image (the channel order cv2.imencode expects) whose
        # alpha channel is 255 inside the mask and 0 outside.
        bgra = cv2.cvtColor(image, cv2.COLOR_BGR2BGRA)
        bgra[:, :, 3] = mask.astype(np.uint8) * 255

        encoded_ok, encoded = cv2.imencode('.png', bgra)
        if not encoded_ok:
            raise RuntimeError("Failed to encode the segmentation result as PNG")

        # NOTE(review): bucket and object name are fixed, so every request
        # overwrites the same object -- confirm this is intended.
        req = oss_upload_image(
            oss_client=minio_client,
            bucket="test",
            object_name="test.png",
            image_bytes=encoded.tobytes(),
        )
        # oss_upload_image returns None when the upload failed.
        if req is None:
            raise RuntimeError("Failed to upload the segmentation result")
        return {"output": f"{req.bucket_name}/{req.object_name}"}
if __name__ == "__main__":
    # Serve SimpleLitAPI through LitServe on port 8777 with the CUDA accelerator.
    server = ls.LitServer(SimpleLitAPI(), accelerator="cuda")
    server.run(port=8777)

19
pyproject.toml Normal file
View File

@@ -0,0 +1,19 @@
[project]
name = "aida-seg-anything"
version = "0.1.0"
description = "LitServe service that runs Segment Anything (SAM) point-prompt segmentation on images stored in MinIO"
requires-python = ">=3.10"
# NOTE(review): the code also imports PIL, fastapi, pydantic and urllib3 directly;
# they are presumably installed transitively -- confirm against uv.lock before
# relying on that.
dependencies = [
    "litserve>=0.2.17",
    "minio>=7.2.20",
    "numpy>=2.2.6",
    "opencv-python>=4.11.0.86",
    "pydantic-settings==2.11.0",
    "requests>=2.32.5",
    "segment-anything",
    "torch>=2.9.1",
    "torchvision>=0.24.1",
]

# segment-anything is installed from the upstream Git repository.
[tool.uv.sources]
segment-anything = { git = "https://github.com/facebookresearch/segment-anything.git" }

View File

@@ -15,6 +15,7 @@ class SAMBoxSegmenter:
# 初始化SAM模型
self.device = "cuda" if torch.cuda.is_available() else "cpu"
print(self.device)
self.sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
self.sam.to(device=self.device)
self.predictor = SamPredictor(self.sam)
@@ -229,8 +230,8 @@ class SAMBoxSegmenter:
if __name__ == "__main__":
# 配置参数
IMAGE_PATH = "/workspace/PycharmProjects/segment-anything/scripts/ae976103-d7ec-4eed-b5d1-3e5f04d8be26.jpg" # 替换为你的图像路径
SAM_CHECKPOINT = "/workspace/PycharmProjects/segment-anything/checkpoint/sam_vit_h_4b8939.pth" # 替换为你的SAM模型路径
IMAGE_PATH = "ae976103-d7ec-4eed-b5d1-3e5f04d8be26.jpg" # 替换为你的图像路径
SAM_CHECKPOINT = "/mnt/data/workspace/Code/aida_seg_anything/checkpoint/sam_vit_h_4b8939.pth" # 替换为你的SAM模型路径
MODEL_TYPE = "vit_h" # 模型类型与checkpoint对应
# 创建并运行分割器

View File

@@ -148,8 +148,8 @@ class SAMPointSegmenter:
if __name__ == "__main__":
# 配置参数
IMAGE_PATH = "/workspace/PycharmProjects/segment-anything/scripts/ae976103-d7ec-4eed-b5d1-3e5f04d8be26.jpg" # 替换为你的图像路径
SAM_CHECKPOINT = "/workspace/PycharmProjects/segment-anything/checkpoint/sam_vit_h_4b8939.pth" # 替换为你的SAM模型路径
IMAGE_PATH = "/mnt/data/workspace/Code/aida_seg_anything/scripts/ae976103-d7ec-4eed-b5d1-3e5f04d8be26.jpg" # 替换为你的图像路径
SAM_CHECKPOINT = "/mnt/data/workspace/Code/aida_seg_anything/checkpoint/sam_vit_h_4b8939.pth" # 替换为你的SAM模型路径
MODEL_TYPE = "vit_h" # 模型类型与checkpoint对应

View File

@@ -249,7 +249,7 @@ class SAMPointBoxSegmenter:
if __name__ == "__main__":
# 配置路径
IMAGE_PATH = "ae976103-d7ec-4eed-b5d1-3e5f04d8be26.jpg" # 替换为你的图像路径
SAM_CHECKPOINT = "/home/alab/PycharmProjects/segment-anything/checkpoint/sam_vit_h_4b8939.pth" # 替换为你的SAM模型路径
SAM_CHECKPOINT = "/mnt/data/workspace/Code/aida_seg_anything/checkpoint/sam_vit_h_4b8939.pth" # 替换为你的SAM模型路径
MODEL_TYPE = "vit_h" # 模型类型与checkpoint对应
# 运行工具

32
test/click_get_point.py Normal file
View File

@@ -0,0 +1,32 @@
import tkinter as tk
from PIL import Image, ImageTk
def on_click(event):
    """Print the position of a click event as ``[x, y]``."""
    # The output format matches the point literal used in the calling code.
    x, y = event.x, event.y
    print(f"[{x}, {y}]")
def start_picker(image_path):
    """Show *image_path* in a Tk window and print the position of each left click.

    Blocks until the window is closed.
    """
    window = tk.Tk()
    window.title("点击获取坐标")

    picture = Image.open(image_path)
    # Bound to a local so the photo stays referenced while mainloop() runs.
    photo = ImageTk.PhotoImage(picture)

    # The canvas is sized to the image and the image is anchored at its top-left corner.
    canvas = tk.Canvas(window, width=picture.width, height=picture.height)
    canvas.pack()
    canvas.create_image(0, 0, anchor=tk.NW, image=photo)

    # Left mouse button -> on_click
    canvas.bind("<Button-1>", on_click)

    print("程序已启动。点击图片,坐标会显示在下方控制台:")
    window.mainloop()
if __name__ == "__main__":
    import sys

    # Image to open: the first command-line argument when given, otherwise the
    # previously hard-coded path (kept as the default for backward compatibility).
    DEFAULT_IMAGE_PATH = "/mnt/data/workspace/Code/aida_seg_anything/utils/1767859286.231138.png"
    IMAGE_PATH = sys.argv[1] if len(sys.argv) > 1 else DEFAULT_IMAGE_PATH
    start_picker(IMAGE_PATH)

88
utils/minio_client.py Normal file
View File

@@ -0,0 +1,88 @@
import io
import json
import logging
import time
from io import BytesIO
import cv2
import numpy as np
import urllib3
from PIL import Image
from minio import Minio
from config import settings
# Module-level logger (named after the module instead of using the root logger).
logger = logging.getLogger(__name__)


# Custom Retry class
class CustomRetry(urllib3.Retry):
    """urllib3 ``Retry`` that logs every retry it schedules."""

    def increment(self, method=None, url=None, response=None, error=None, **kwargs):
        """Delegate to ``Retry.increment`` and log the retry.

        Returns the new retry object produced by the parent class; whatever the
        parent raises propagates before anything is logged.
        """
        new_retry = super().increment(method, url, response, error, **kwargs)
        # `history` gains one entry per increment, so its length is the number
        # of retries so far. The previous `self.total - new_retry.total` was
        # always 1 and failed when `total` is None.
        logger.info(f"重试连接: {method} {url},错误: {error},重试次数: {len(new_retry.history)}")
        return new_retry


# Connect timeout 1 second, read timeout 10 seconds.
timeout = urllib3.Timeout(connect=1, read=10.0)
http_client = urllib3.PoolManager(
    num_pools=10,  # number of connection pools to keep
    maxsize=10,  # connections kept per pool
    timeout=timeout,
    cert_reqs='CERT_REQUIRED',  # require certificate verification
    retries=CustomRetry(
        total=5,
        backoff_factor=0.2,
        status_forcelist=[500, 502, 503, 504],
    ),
)

# Created after http_client so the timeout / retry configuration above is
# actually used; previously the pool manager was built but never passed on.
minio_client = Minio(
    settings.MINIO_URL,
    access_key=settings.MINIO_ACCESS,
    secret_key=settings.MINIO_SECRET,
    secure=settings.MINIO_SECURE,
    http_client=http_client,
)
# Fetch an image
def oss_get_image(oss_client, path, data_type):
    """Download an image from object storage and decode it.

    Args:
        oss_client: Client exposing ``get_object(bucket_name=..., object_name=...)``.
        path: ``"<bucket>/<object name>"``; split at the first ``"/"``.
        data_type: ``"cv2"`` decodes with OpenCV using ``cv2.IMREAD_COLOR``;
            any other value opens the data with ``PIL.Image.open``.

    Returns:
        The decoded image, or ``None`` when the path is malformed or the
        download / decoding fails (the failure is logged as a warning).
    """
    bucket, separator, object_name = path.partition("/")
    if not separator:
        # A path without "/" used to raise IndexError before the try block.
        logger.warning(f"获取图片出现异常 ######: invalid path {path!r}")
        return None

    image_object = None
    response = None
    try:
        response = oss_client.get_object(bucket_name=bucket, object_name=object_name)
        image_bytes = response.read()
        if data_type == "cv2":
            image_array = np.frombuffer(image_bytes, np.uint8)  # view as unsigned 8-bit
            image_object = cv2.imdecode(image_array, cv2.IMREAD_COLOR)
            if image_object is None:
                logger.warning(f"获取图片出现异常 ######: cannot decode {path!r}")
            elif image_object.dtype == np.uint16:
                # Scale 16-bit data down to 8-bit.
                image_object = (image_object / 256).astype('uint8')
        else:
            image_object = Image.open(BytesIO(image_bytes))
    except Exception as e:
        # Deliberate best effort: callers receive None on any failure.
        logger.warning(f"获取图片出现异常 ######: {e}")
    finally:
        # The payload is fully read above, so close the response and hand the
        # connection back to the pool.
        if response is not None:
            response.close()
            response.release_conn()
    return image_object
def oss_upload_image(oss_client, bucket, object_name, image_bytes):
    """Upload PNG data to object storage.

    Args:
        oss_client: Client exposing ``put_object(...)``.
        bucket: Target bucket name.
        object_name: Target object name.
        image_bytes: PNG data as any bytes-like object (``bytes``,
            ``bytearray``, a NumPy ``uint8`` array, ...).

    Returns:
        Whatever ``put_object`` returns, or ``None`` when the upload fails
        (the failure is logged as a warning).
    """
    req = None
    try:
        # Normalise to bytes first: len() of a multi-dimensional buffer is its
        # first dimension, not the number of bytes to upload.
        payload = bytes(memoryview(image_bytes))
        req = oss_client.put_object(
            bucket_name=bucket,
            object_name=object_name,
            data=io.BytesIO(payload),
            length=len(payload),
            content_type='image/png',
        )
    except Exception as e:
        # Deliberate best effort: callers receive None on any failure.
        logger.warning(f"上传图片出现异常 ######: {e}")
    return req
if __name__ == '__main__':
    # Manual smoke test: fetch each sample object and display it.
    # Other sample objects:
    # "lanecarford/lc_stylist_agent_outfit_items/string/7fed1c7b-9efd-41fa-a335-182c310ea611.jpg"
    # "lanecarford/lc_stylist_agent_outfit_items/string/5de155d0-56a6-43e8-a2f1-7538fce86220.jpg"
    # "lanecarford/lc_stylist_agent_outfit_items/string/1cd1803c-5f51-4961-a4f2-2acd3e0d8294.jpg"
    sample_paths = ["aida-users/89/sketch/4e8fe37d-7068-400a-ac94-c01647fa5f6f.png", ]
    # "cv2" selects the OpenCV branch of oss_get_image; any other value selects PIL.
    read_type = "pil"
    for sample_path in sample_paths:
        # Fetch once (the image used to be downloaded twice per iteration).
        img = oss_get_image(oss_client=minio_client, path=sample_path, data_type=read_type)
        if img is None:
            # oss_get_image already logged the reason.
            print(f"failed to load {sample_path}")
            continue
        if read_type == "cv2":
            cv2.imshow("", img)
            cv2.waitKey(0)
        else:
            img.show()
            img.save(f"{time.time()}.png")

1779
uv.lock generated Normal file

File diff suppressed because it is too large Load Diff