1
This commit is contained in:
4
.gitignore
vendored
4
.gitignore
vendored
@@ -39,4 +39,6 @@ _darcs
|
||||
# demo
|
||||
**/node_modules
|
||||
yarn.lock
|
||||
package-lock.json
|
||||
package-lock.json
|
||||
*.env
|
||||
*.png
|
||||
40
Dockerfile
Normal file
40
Dockerfile
Normal file
@@ -0,0 +1,40 @@
|
||||
FROM ghcr.io/astral-sh/uv:latest AS uv_bin
|
||||
FROM nvidia/cuda:12.4.1-base-ubuntu22.04
|
||||
|
||||
# 1. 基础环境配置
|
||||
ENV UV_LINK_MODE=copy \
|
||||
UV_COMPILE_BYTECODE=1 \
|
||||
PYTHONUNBUFFERED=1 \
|
||||
UV_PROJECT_ENVIRONMENT=/app/.venv
|
||||
|
||||
COPY --from=uv_bin /uv /uvx /bin/
|
||||
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
wget \
|
||||
libcurl4-openssl-dev \
|
||||
build-essential \
|
||||
libgl1 \
|
||||
libglib2.0-0 \
|
||||
ca-certificates \
|
||||
git \
|
||||
&& apt-get clean \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
COPY pyproject.toml uv.lock ./
|
||||
|
||||
ENV UV_COMPILE_BYTECODE=0
|
||||
|
||||
RUN uv sync --frozen --no-dev --no-install-project --python 3.10
|
||||
|
||||
# 4. 拷贝项目文件并安装项目本身
|
||||
COPY . .
|
||||
RUN uv sync --frozen --no-dev --python 3.10
|
||||
|
||||
ENV PATH="/app/.venv/bin:$PATH"
|
||||
|
||||
EXPOSE 8000
|
||||
CMD ["uv", "run","-m","main"]
|
||||
#CMD ["tail", "-f","/dev/null"]
|
||||
|
||||
7
client.py
Normal file
7
client.py
Normal file
@@ -0,0 +1,7 @@
|
||||
# This file is auto-generated by LitServe.
|
||||
# Disable auto-generation by setting `generate_client_file=False` in `LitServer.run()`.
|
||||
|
||||
import requests
|
||||
|
||||
response = requests.post("http://127.0.0.1:8777/predict", json={"input": 4.0})
|
||||
print(f"Status: {response.status_code}\nResponse:\n {response.text}")
|
||||
30
config.py
Normal file
30
config.py
Normal file
@@ -0,0 +1,30 @@
|
||||
import os
|
||||
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
from pydantic import Field
|
||||
|
||||
|
||||
# ⚠️ 注意: 您需要安装 pydantic-settings: pip install pydantic-settings
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
"""
|
||||
应用配置类。Pydantic Settings 会自动从环境变量和 .env 文件中加载这些值。
|
||||
"""
|
||||
model_config = SettingsConfigDict(
|
||||
env_file='.env',
|
||||
env_file_encoding='utf-8',
|
||||
extra='ignore' # 忽略环境变量中多余的键
|
||||
)
|
||||
# 启动端口
|
||||
SERVE_PROD: int = Field(default=8000, description='')
|
||||
# minio配置
|
||||
MINIO_URL: str = Field(default="", description="URL")
|
||||
MINIO_ACCESS: str = Field(default="", description="ACCESS")
|
||||
MINIO_SECRET: str = Field(default="", description="SECRET")
|
||||
MINIO_SECURE: bool = Field(default=True, description="SECRET")
|
||||
MINIO_LC_DATA_PATH: str = Field(default="", description="图片数据路径")
|
||||
|
||||
|
||||
# 创建配置实例,供应用其他部分使用
|
||||
settings = Settings()
|
||||
19
docker-compose.yml
Normal file
19
docker-compose.yml
Normal file
@@ -0,0 +1,19 @@
|
||||
services:
|
||||
aida_seg_anything:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: Dockerfile
|
||||
working_dir: /app
|
||||
volumes:
|
||||
- ./:/app
|
||||
- /etc/localtime:/etc/localtime:ro
|
||||
ports:
|
||||
- "10070:8777"
|
||||
deploy:
|
||||
resources:
|
||||
reservations:
|
||||
devices:
|
||||
# 告诉 Docker 使用所有可用的 NVIDIA GPU
|
||||
- driver: nvidia
|
||||
device_ids: [ '0' ]
|
||||
capabilities: [ gpu ]
|
||||
90
main.py
Normal file
90
main.py
Normal file
@@ -0,0 +1,90 @@
|
||||
import io
|
||||
import os
|
||||
import urllib.request # 必须这样写,不能只 import urllib
|
||||
import cv2
|
||||
import litserve as ls
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image, ImageDraw
|
||||
from minio import Minio
|
||||
from pydantic import BaseModel
|
||||
from fastapi import Response # 导入 FastAPI 的 Response
|
||||
|
||||
from config import settings
|
||||
from segment_anything import SamPredictor, sam_model_registry
|
||||
from utils.minio_client import oss_get_image, oss_upload_image
|
||||
|
||||
minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)
|
||||
|
||||
|
||||
class SAMRequest(BaseModel):
|
||||
image_path: str
|
||||
points: list[list[float]]
|
||||
labels: list[int]
|
||||
|
||||
|
||||
class SimpleLitAPI(ls.LitAPI):
|
||||
|
||||
# class SimpleLitAPI():
|
||||
def setup(self, device):
|
||||
# def __init__(self, device, sam_checkpoint, model_type="vit_h"):
|
||||
# 初始化SAM模型
|
||||
model_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
|
||||
sam_checkpoint = "checkpoint/sam_vit_h_4b8939.pth"
|
||||
model_type = "vit_h"
|
||||
|
||||
# 自动化下载检查
|
||||
if not os.path.exists(sam_checkpoint):
|
||||
os.makedirs(os.path.dirname(sam_checkpoint))
|
||||
|
||||
if not os.path.isfile(sam_checkpoint):
|
||||
print("正在下载权重文件,请稍候...")
|
||||
urllib.request.urlretrieve(model_url, sam_checkpoint)
|
||||
print("下载完成。")
|
||||
|
||||
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
print(self.device)
|
||||
self.sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
|
||||
self.sam.to(device=self.device)
|
||||
self.predictor = SamPredictor(self.sam)
|
||||
|
||||
def decode_request(self, request: SAMRequest):
|
||||
return request
|
||||
|
||||
def predict(self, request):
|
||||
# 加载图像
|
||||
image = oss_get_image(
|
||||
oss_client=minio_client,
|
||||
path=request.image_path,
|
||||
data_type="cv2")
|
||||
|
||||
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||
self.predictor.set_image(image_rgb)
|
||||
|
||||
input_points = np.array(request.points)
|
||||
input_labels = np.array(request.labels)
|
||||
|
||||
masks, scores, logits = self.predictor.predict(
|
||||
point_coords=input_points,
|
||||
point_labels=input_labels,
|
||||
multimask_output=False
|
||||
)
|
||||
mask = masks[0] # 获取第一个掩码
|
||||
|
||||
image = Image.fromarray(image)
|
||||
rgba_image = image.convert("RGBA")
|
||||
rgba_np = np.array(rgba_image)
|
||||
rgba_np[:, :, 3] = mask.astype(np.uint8) * 255
|
||||
req = oss_upload_image(
|
||||
oss_client=minio_client,
|
||||
bucket="test",
|
||||
object_name=f"test.png",
|
||||
image_bytes=cv2.imencode('.png', rgba_np)[1]
|
||||
)
|
||||
return {"output": f"{req.bucket_name}/{req.object_name}"}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
api = SimpleLitAPI()
|
||||
server = ls.LitServer(api, accelerator="cuda")
|
||||
server.run(port=8777)
|
||||
19
pyproject.toml
Normal file
19
pyproject.toml
Normal file
@@ -0,0 +1,19 @@
|
||||
[project]
|
||||
name = "aida-seg-anything"
|
||||
version = "0.1.0"
|
||||
description = "Add your description here"
|
||||
requires-python = ">=3.10"
|
||||
dependencies = [
|
||||
"litserve>=0.2.17",
|
||||
"minio>=7.2.20",
|
||||
"numpy>=2.2.6",
|
||||
"opencv-python>=4.11.0.86",
|
||||
"pydantic-settings==2.11.0",
|
||||
"requests>=2.32.5",
|
||||
"segment-anything",
|
||||
"torch>=2.9.1",
|
||||
"torchvision>=0.24.1",
|
||||
]
|
||||
|
||||
[tool.uv.sources]
|
||||
segment-anything = { git = "https://github.com/facebookresearch/segment-anything.git" }
|
||||
@@ -15,6 +15,7 @@ class SAMBoxSegmenter:
|
||||
|
||||
# 初始化SAM模型
|
||||
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
print(self.device)
|
||||
self.sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
|
||||
self.sam.to(device=self.device)
|
||||
self.predictor = SamPredictor(self.sam)
|
||||
@@ -229,8 +230,8 @@ class SAMBoxSegmenter:
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 配置参数
|
||||
IMAGE_PATH = "/workspace/PycharmProjects/segment-anything/scripts/ae976103-d7ec-4eed-b5d1-3e5f04d8be26.jpg" # 替换为你的图像路径
|
||||
SAM_CHECKPOINT = "/workspace/PycharmProjects/segment-anything/checkpoint/sam_vit_h_4b8939.pth" # 替换为你的SAM模型路径
|
||||
IMAGE_PATH = "ae976103-d7ec-4eed-b5d1-3e5f04d8be26.jpg" # 替换为你的图像路径
|
||||
SAM_CHECKPOINT = "/mnt/data/workspace/Code/aida_seg_anything/checkpoint/sam_vit_h_4b8939.pth" # 替换为你的SAM模型路径
|
||||
MODEL_TYPE = "vit_h" # 模型类型,与checkpoint对应
|
||||
|
||||
# 创建并运行分割器
|
||||
|
||||
@@ -148,8 +148,8 @@ class SAMPointSegmenter:
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 配置参数
|
||||
IMAGE_PATH = "/workspace/PycharmProjects/segment-anything/scripts/ae976103-d7ec-4eed-b5d1-3e5f04d8be26.jpg" # 替换为你的图像路径
|
||||
SAM_CHECKPOINT = "/workspace/PycharmProjects/segment-anything/checkpoint/sam_vit_h_4b8939.pth" # 替换为你的SAM模型路径
|
||||
IMAGE_PATH = "/mnt/data/workspace/Code/aida_seg_anything/scripts/ae976103-d7ec-4eed-b5d1-3e5f04d8be26.jpg" # 替换为你的图像路径
|
||||
SAM_CHECKPOINT = "/mnt/data/workspace/Code/aida_seg_anything/checkpoint/sam_vit_h_4b8939.pth" # 替换为你的SAM模型路径
|
||||
|
||||
MODEL_TYPE = "vit_h" # 模型类型,与checkpoint对应
|
||||
|
||||
|
||||
@@ -249,7 +249,7 @@ class SAMPointBoxSegmenter:
|
||||
if __name__ == "__main__":
|
||||
# 配置路径
|
||||
IMAGE_PATH = "ae976103-d7ec-4eed-b5d1-3e5f04d8be26.jpg" # 替换为你的图像路径
|
||||
SAM_CHECKPOINT = "/home/alab/PycharmProjects/segment-anything/checkpoint/sam_vit_h_4b8939.pth" # 替换为你的SAM模型路径
|
||||
SAM_CHECKPOINT = "/mnt/data/workspace/Code/aida_seg_anything/checkpoint/sam_vit_h_4b8939.pth" # 替换为你的SAM模型路径
|
||||
MODEL_TYPE = "vit_h" # 模型类型,与checkpoint对应
|
||||
|
||||
# 运行工具
|
||||
|
||||
32
test/click_get_point.py
Normal file
32
test/click_get_point.py
Normal file
@@ -0,0 +1,32 @@
|
||||
import tkinter as tk
|
||||
from PIL import Image, ImageTk
|
||||
|
||||
|
||||
def on_click(event):
|
||||
# 打印格式直接就是你代码里需要的 [x, y]
|
||||
print(f"[{event.x}, {event.y}]")
|
||||
|
||||
|
||||
def start_picker(image_path):
|
||||
root = tk.Tk()
|
||||
root.title("点击获取坐标")
|
||||
|
||||
|
||||
img = Image.open(image_path)
|
||||
tk_img = ImageTk.PhotoImage(img)
|
||||
|
||||
canvas = tk.Canvas(root, width=img.width, height=img.height)
|
||||
canvas.pack()
|
||||
canvas.create_image(0, 0, anchor=tk.NW, image=tk_img)
|
||||
|
||||
# 绑定左键点击
|
||||
canvas.bind("<Button-1>", on_click)
|
||||
|
||||
print("程序已启动。点击图片,坐标会显示在下方控制台:")
|
||||
root.mainloop()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 填入你的图片路径
|
||||
IMAGE_PATH = "/mnt/data/workspace/Code/aida_seg_anything/utils/1767859286.231138.png"
|
||||
start_picker(IMAGE_PATH)
|
||||
88
utils/minio_client.py
Normal file
88
utils/minio_client.py
Normal file
@@ -0,0 +1,88 @@
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from io import BytesIO
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import urllib3
|
||||
from PIL import Image
|
||||
from minio import Minio
|
||||
|
||||
from config import settings
|
||||
|
||||
minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)
|
||||
|
||||
|
||||
# 自定义 Retry 类
|
||||
class CustomRetry(urllib3.Retry):
|
||||
def increment(self, method=None, url=None, response=None, error=None, **kwargs):
|
||||
# 调用父类的 increment 方法
|
||||
new_retry = super(CustomRetry, self).increment(method, url, response, error, **kwargs)
|
||||
# 打印重试信息
|
||||
logger.info(f"重试连接: {method} {url},错误: {error},重试次数: {self.total - new_retry.total}")
|
||||
return new_retry
|
||||
|
||||
|
||||
logger = logging.getLogger()
|
||||
timeout = urllib3.Timeout(connect=1, read=10.0) # 连接超时 5 秒,读取超时 10 秒
|
||||
http_client = urllib3.PoolManager(
|
||||
num_pools=10, # 设置连接池大小
|
||||
maxsize=10,
|
||||
timeout=timeout,
|
||||
cert_reqs='CERT_REQUIRED', # 需要证书验证
|
||||
retries=CustomRetry(
|
||||
total=5,
|
||||
backoff_factor=0.2,
|
||||
status_forcelist=[500, 502, 503, 504],
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# 获取图片
|
||||
def oss_get_image(oss_client, path, data_type):
|
||||
# cv2 默认全通道读取
|
||||
bucket = path.split("/", 1)[0]
|
||||
object_name = path.split("/", 1)[1]
|
||||
image_object = None
|
||||
try:
|
||||
image_data = oss_client.get_object(bucket_name=bucket, object_name=object_name)
|
||||
if data_type == "cv2":
|
||||
image_bytes = image_data.read()
|
||||
image_array = np.frombuffer(image_bytes, np.uint8) # 转成8位无符号整型
|
||||
image_object = cv2.imdecode(image_array, cv2.IMREAD_COLOR)
|
||||
if image_object.dtype == np.uint16:
|
||||
image_object = (image_object / 256).astype('uint8')
|
||||
else:
|
||||
data_bytes = BytesIO(image_data.read())
|
||||
image_object = Image.open(data_bytes)
|
||||
except Exception as e:
|
||||
logger.warning(f"获取图片出现异常 ######: {e}")
|
||||
return image_object
|
||||
|
||||
|
||||
def oss_upload_image(oss_client, bucket, object_name, image_bytes):
|
||||
req = None
|
||||
try:
|
||||
req = oss_client.put_object(bucket_name=bucket, object_name=object_name, data=io.BytesIO(image_bytes), length=len(image_bytes), content_type='image/png')
|
||||
except Exception as e:
|
||||
logger.warning(f"上传图片出现异常 ######: {e}")
|
||||
return req
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# url = "lanecarford/lc_stylist_agent_outfit_items/string/7fed1c7b-9efd-41fa-a335-182c310ea611.jpg"
|
||||
# url = "lanecarford/lc_stylist_agent_outfit_items/string/5de155d0-56a6-43e8-a2f1-7538fce86220.jpg"
|
||||
# url = "lanecarford/lc_stylist_agent_outfit_items/string/1cd1803c-5f51-4961-a4f2-2acd3e0d8294.jpg"
|
||||
url = ["aida-users/89/sketch/4e8fe37d-7068-400a-ac94-c01647fa5f6f.png", ]
|
||||
read_type = "1"
|
||||
for id, i in enumerate(url):
|
||||
img = oss_get_image(minio_client, i, read_type)
|
||||
img = oss_get_image(oss_client=minio_client, path=i, data_type=read_type)
|
||||
if read_type == "cv2":
|
||||
cv2.imshow("", img)
|
||||
cv2.waitKey(0)
|
||||
else:
|
||||
img.show()
|
||||
img.save(f"{time.time()}.png")
|
||||
Reference in New Issue
Block a user