1
This commit is contained in:
4
.gitignore
vendored
4
.gitignore
vendored
@@ -39,4 +39,6 @@ _darcs
|
|||||||
# demo
|
# demo
|
||||||
**/node_modules
|
**/node_modules
|
||||||
yarn.lock
|
yarn.lock
|
||||||
package-lock.json
|
package-lock.json
|
||||||
|
*.env
|
||||||
|
*.png
|
||||||
40
Dockerfile
Normal file
40
Dockerfile
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
FROM ghcr.io/astral-sh/uv:latest AS uv_bin
|
||||||
|
FROM nvidia/cuda:12.4.1-base-ubuntu22.04
|
||||||
|
|
||||||
|
# 1. 基础环境配置
|
||||||
|
ENV UV_LINK_MODE=copy \
|
||||||
|
UV_COMPILE_BYTECODE=1 \
|
||||||
|
PYTHONUNBUFFERED=1 \
|
||||||
|
UV_PROJECT_ENVIRONMENT=/app/.venv
|
||||||
|
|
||||||
|
COPY --from=uv_bin /uv /uvx /bin/
|
||||||
|
|
||||||
|
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||||
|
wget \
|
||||||
|
libcurl4-openssl-dev \
|
||||||
|
build-essential \
|
||||||
|
libgl1 \
|
||||||
|
libglib2.0-0 \
|
||||||
|
ca-certificates \
|
||||||
|
git \
|
||||||
|
&& apt-get clean \
|
||||||
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
COPY pyproject.toml uv.lock ./
|
||||||
|
|
||||||
|
ENV UV_COMPILE_BYTECODE=0
|
||||||
|
|
||||||
|
RUN uv sync --frozen --no-dev --no-install-project --python 3.10
|
||||||
|
|
||||||
|
# 4. 拷贝项目文件并安装项目本身
|
||||||
|
COPY . .
|
||||||
|
RUN uv sync --frozen --no-dev --python 3.10
|
||||||
|
|
||||||
|
ENV PATH="/app/.venv/bin:$PATH"
|
||||||
|
|
||||||
|
EXPOSE 8000
|
||||||
|
CMD ["uv", "run","-m","main"]
|
||||||
|
#CMD ["tail", "-f","/dev/null"]
|
||||||
|
|
||||||
7
client.py
Normal file
7
client.py
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
# This file is auto-generated by LitServe.
|
||||||
|
# Disable auto-generation by setting `generate_client_file=False` in `LitServer.run()`.
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
response = requests.post("http://127.0.0.1:8777/predict", json={"input": 4.0})
|
||||||
|
print(f"Status: {response.status_code}\nResponse:\n {response.text}")
|
||||||
30
config.py
Normal file
30
config.py
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||||
|
from pydantic import Field
|
||||||
|
|
||||||
|
|
||||||
|
# ⚠️ 注意: 您需要安装 pydantic-settings: pip install pydantic-settings
|
||||||
|
|
||||||
|
|
||||||
|
class Settings(BaseSettings):
|
||||||
|
"""
|
||||||
|
应用配置类。Pydantic Settings 会自动从环境变量和 .env 文件中加载这些值。
|
||||||
|
"""
|
||||||
|
model_config = SettingsConfigDict(
|
||||||
|
env_file='.env',
|
||||||
|
env_file_encoding='utf-8',
|
||||||
|
extra='ignore' # 忽略环境变量中多余的键
|
||||||
|
)
|
||||||
|
# 启动端口
|
||||||
|
SERVE_PROD: int = Field(default=8000, description='')
|
||||||
|
# minio配置
|
||||||
|
MINIO_URL: str = Field(default="", description="URL")
|
||||||
|
MINIO_ACCESS: str = Field(default="", description="ACCESS")
|
||||||
|
MINIO_SECRET: str = Field(default="", description="SECRET")
|
||||||
|
MINIO_SECURE: bool = Field(default=True, description="SECRET")
|
||||||
|
MINIO_LC_DATA_PATH: str = Field(default="", description="图片数据路径")
|
||||||
|
|
||||||
|
|
||||||
|
# 创建配置实例,供应用其他部分使用
|
||||||
|
settings = Settings()
|
||||||
19
docker-compose.yml
Normal file
19
docker-compose.yml
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
services:
|
||||||
|
aida_seg_anything:
|
||||||
|
build:
|
||||||
|
context: .
|
||||||
|
dockerfile: Dockerfile
|
||||||
|
working_dir: /app
|
||||||
|
volumes:
|
||||||
|
- ./:/app
|
||||||
|
- /etc/localtime:/etc/localtime:ro
|
||||||
|
ports:
|
||||||
|
- "10070:8777"
|
||||||
|
deploy:
|
||||||
|
resources:
|
||||||
|
reservations:
|
||||||
|
devices:
|
||||||
|
# 告诉 Docker 使用所有可用的 NVIDIA GPU
|
||||||
|
- driver: nvidia
|
||||||
|
device_ids: [ '0' ]
|
||||||
|
capabilities: [ gpu ]
|
||||||
90
main.py
Normal file
90
main.py
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
import io
|
||||||
|
import os
|
||||||
|
import urllib.request # 必须这样写,不能只 import urllib
|
||||||
|
import cv2
|
||||||
|
import litserve as ls
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from PIL import Image, ImageDraw
|
||||||
|
from minio import Minio
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from fastapi import Response # 导入 FastAPI 的 Response
|
||||||
|
|
||||||
|
from config import settings
|
||||||
|
from segment_anything import SamPredictor, sam_model_registry
|
||||||
|
from utils.minio_client import oss_get_image, oss_upload_image
|
||||||
|
|
||||||
|
minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)
|
||||||
|
|
||||||
|
|
||||||
|
class SAMRequest(BaseModel):
|
||||||
|
image_path: str
|
||||||
|
points: list[list[float]]
|
||||||
|
labels: list[int]
|
||||||
|
|
||||||
|
|
||||||
|
class SimpleLitAPI(ls.LitAPI):
|
||||||
|
|
||||||
|
# class SimpleLitAPI():
|
||||||
|
def setup(self, device):
|
||||||
|
# def __init__(self, device, sam_checkpoint, model_type="vit_h"):
|
||||||
|
# 初始化SAM模型
|
||||||
|
model_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
|
||||||
|
sam_checkpoint = "checkpoint/sam_vit_h_4b8939.pth"
|
||||||
|
model_type = "vit_h"
|
||||||
|
|
||||||
|
# 自动化下载检查
|
||||||
|
if not os.path.exists(sam_checkpoint):
|
||||||
|
os.makedirs(os.path.dirname(sam_checkpoint))
|
||||||
|
|
||||||
|
if not os.path.isfile(sam_checkpoint):
|
||||||
|
print("正在下载权重文件,请稍候...")
|
||||||
|
urllib.request.urlretrieve(model_url, sam_checkpoint)
|
||||||
|
print("下载完成。")
|
||||||
|
|
||||||
|
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
print(self.device)
|
||||||
|
self.sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
|
||||||
|
self.sam.to(device=self.device)
|
||||||
|
self.predictor = SamPredictor(self.sam)
|
||||||
|
|
||||||
|
def decode_request(self, request: SAMRequest):
|
||||||
|
return request
|
||||||
|
|
||||||
|
def predict(self, request):
|
||||||
|
# 加载图像
|
||||||
|
image = oss_get_image(
|
||||||
|
oss_client=minio_client,
|
||||||
|
path=request.image_path,
|
||||||
|
data_type="cv2")
|
||||||
|
|
||||||
|
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||||
|
self.predictor.set_image(image_rgb)
|
||||||
|
|
||||||
|
input_points = np.array(request.points)
|
||||||
|
input_labels = np.array(request.labels)
|
||||||
|
|
||||||
|
masks, scores, logits = self.predictor.predict(
|
||||||
|
point_coords=input_points,
|
||||||
|
point_labels=input_labels,
|
||||||
|
multimask_output=False
|
||||||
|
)
|
||||||
|
mask = masks[0] # 获取第一个掩码
|
||||||
|
|
||||||
|
image = Image.fromarray(image)
|
||||||
|
rgba_image = image.convert("RGBA")
|
||||||
|
rgba_np = np.array(rgba_image)
|
||||||
|
rgba_np[:, :, 3] = mask.astype(np.uint8) * 255
|
||||||
|
req = oss_upload_image(
|
||||||
|
oss_client=minio_client,
|
||||||
|
bucket="test",
|
||||||
|
object_name=f"test.png",
|
||||||
|
image_bytes=cv2.imencode('.png', rgba_np)[1]
|
||||||
|
)
|
||||||
|
return {"output": f"{req.bucket_name}/{req.object_name}"}
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
api = SimpleLitAPI()
|
||||||
|
server = ls.LitServer(api, accelerator="cuda")
|
||||||
|
server.run(port=8777)
|
||||||
19
pyproject.toml
Normal file
19
pyproject.toml
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
[project]
|
||||||
|
name = "aida-seg-anything"
|
||||||
|
version = "0.1.0"
|
||||||
|
description = "Add your description here"
|
||||||
|
requires-python = ">=3.10"
|
||||||
|
dependencies = [
|
||||||
|
"litserve>=0.2.17",
|
||||||
|
"minio>=7.2.20",
|
||||||
|
"numpy>=2.2.6",
|
||||||
|
"opencv-python>=4.11.0.86",
|
||||||
|
"pydantic-settings==2.11.0",
|
||||||
|
"requests>=2.32.5",
|
||||||
|
"segment-anything",
|
||||||
|
"torch>=2.9.1",
|
||||||
|
"torchvision>=0.24.1",
|
||||||
|
]
|
||||||
|
|
||||||
|
[tool.uv.sources]
|
||||||
|
segment-anything = { git = "https://github.com/facebookresearch/segment-anything.git" }
|
||||||
@@ -15,6 +15,7 @@ class SAMBoxSegmenter:
|
|||||||
|
|
||||||
# 初始化SAM模型
|
# 初始化SAM模型
|
||||||
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
print(self.device)
|
||||||
self.sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
|
self.sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
|
||||||
self.sam.to(device=self.device)
|
self.sam.to(device=self.device)
|
||||||
self.predictor = SamPredictor(self.sam)
|
self.predictor = SamPredictor(self.sam)
|
||||||
@@ -229,8 +230,8 @@ class SAMBoxSegmenter:
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# 配置参数
|
# 配置参数
|
||||||
IMAGE_PATH = "/workspace/PycharmProjects/segment-anything/scripts/ae976103-d7ec-4eed-b5d1-3e5f04d8be26.jpg" # 替换为你的图像路径
|
IMAGE_PATH = "ae976103-d7ec-4eed-b5d1-3e5f04d8be26.jpg" # 替换为你的图像路径
|
||||||
SAM_CHECKPOINT = "/workspace/PycharmProjects/segment-anything/checkpoint/sam_vit_h_4b8939.pth" # 替换为你的SAM模型路径
|
SAM_CHECKPOINT = "/mnt/data/workspace/Code/aida_seg_anything/checkpoint/sam_vit_h_4b8939.pth" # 替换为你的SAM模型路径
|
||||||
MODEL_TYPE = "vit_h" # 模型类型,与checkpoint对应
|
MODEL_TYPE = "vit_h" # 模型类型,与checkpoint对应
|
||||||
|
|
||||||
# 创建并运行分割器
|
# 创建并运行分割器
|
||||||
|
|||||||
@@ -148,8 +148,8 @@ class SAMPointSegmenter:
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# 配置参数
|
# 配置参数
|
||||||
IMAGE_PATH = "/workspace/PycharmProjects/segment-anything/scripts/ae976103-d7ec-4eed-b5d1-3e5f04d8be26.jpg" # 替换为你的图像路径
|
IMAGE_PATH = "/mnt/data/workspace/Code/aida_seg_anything/scripts/ae976103-d7ec-4eed-b5d1-3e5f04d8be26.jpg" # 替换为你的图像路径
|
||||||
SAM_CHECKPOINT = "/workspace/PycharmProjects/segment-anything/checkpoint/sam_vit_h_4b8939.pth" # 替换为你的SAM模型路径
|
SAM_CHECKPOINT = "/mnt/data/workspace/Code/aida_seg_anything/checkpoint/sam_vit_h_4b8939.pth" # 替换为你的SAM模型路径
|
||||||
|
|
||||||
MODEL_TYPE = "vit_h" # 模型类型,与checkpoint对应
|
MODEL_TYPE = "vit_h" # 模型类型,与checkpoint对应
|
||||||
|
|
||||||
|
|||||||
@@ -249,7 +249,7 @@ class SAMPointBoxSegmenter:
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# 配置路径
|
# 配置路径
|
||||||
IMAGE_PATH = "ae976103-d7ec-4eed-b5d1-3e5f04d8be26.jpg" # 替换为你的图像路径
|
IMAGE_PATH = "ae976103-d7ec-4eed-b5d1-3e5f04d8be26.jpg" # 替换为你的图像路径
|
||||||
SAM_CHECKPOINT = "/home/alab/PycharmProjects/segment-anything/checkpoint/sam_vit_h_4b8939.pth" # 替换为你的SAM模型路径
|
SAM_CHECKPOINT = "/mnt/data/workspace/Code/aida_seg_anything/checkpoint/sam_vit_h_4b8939.pth" # 替换为你的SAM模型路径
|
||||||
MODEL_TYPE = "vit_h" # 模型类型,与checkpoint对应
|
MODEL_TYPE = "vit_h" # 模型类型,与checkpoint对应
|
||||||
|
|
||||||
# 运行工具
|
# 运行工具
|
||||||
|
|||||||
32
test/click_get_point.py
Normal file
32
test/click_get_point.py
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
import tkinter as tk
|
||||||
|
from PIL import Image, ImageTk
|
||||||
|
|
||||||
|
|
||||||
|
def on_click(event):
|
||||||
|
# 打印格式直接就是你代码里需要的 [x, y]
|
||||||
|
print(f"[{event.x}, {event.y}]")
|
||||||
|
|
||||||
|
|
||||||
|
def start_picker(image_path):
|
||||||
|
root = tk.Tk()
|
||||||
|
root.title("点击获取坐标")
|
||||||
|
|
||||||
|
|
||||||
|
img = Image.open(image_path)
|
||||||
|
tk_img = ImageTk.PhotoImage(img)
|
||||||
|
|
||||||
|
canvas = tk.Canvas(root, width=img.width, height=img.height)
|
||||||
|
canvas.pack()
|
||||||
|
canvas.create_image(0, 0, anchor=tk.NW, image=tk_img)
|
||||||
|
|
||||||
|
# 绑定左键点击
|
||||||
|
canvas.bind("<Button-1>", on_click)
|
||||||
|
|
||||||
|
print("程序已启动。点击图片,坐标会显示在下方控制台:")
|
||||||
|
root.mainloop()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# 填入你的图片路径
|
||||||
|
IMAGE_PATH = "/mnt/data/workspace/Code/aida_seg_anything/utils/1767859286.231138.png"
|
||||||
|
start_picker(IMAGE_PATH)
|
||||||
88
utils/minio_client.py
Normal file
88
utils/minio_client.py
Normal file
@@ -0,0 +1,88 @@
|
|||||||
|
import io
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from io import BytesIO
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
import urllib3
|
||||||
|
from PIL import Image
|
||||||
|
from minio import Minio
|
||||||
|
|
||||||
|
from config import settings
|
||||||
|
|
||||||
|
minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)
|
||||||
|
|
||||||
|
|
||||||
|
# 自定义 Retry 类
|
||||||
|
class CustomRetry(urllib3.Retry):
|
||||||
|
def increment(self, method=None, url=None, response=None, error=None, **kwargs):
|
||||||
|
# 调用父类的 increment 方法
|
||||||
|
new_retry = super(CustomRetry, self).increment(method, url, response, error, **kwargs)
|
||||||
|
# 打印重试信息
|
||||||
|
logger.info(f"重试连接: {method} {url},错误: {error},重试次数: {self.total - new_retry.total}")
|
||||||
|
return new_retry
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger()
|
||||||
|
timeout = urllib3.Timeout(connect=1, read=10.0) # 连接超时 5 秒,读取超时 10 秒
|
||||||
|
http_client = urllib3.PoolManager(
|
||||||
|
num_pools=10, # 设置连接池大小
|
||||||
|
maxsize=10,
|
||||||
|
timeout=timeout,
|
||||||
|
cert_reqs='CERT_REQUIRED', # 需要证书验证
|
||||||
|
retries=CustomRetry(
|
||||||
|
total=5,
|
||||||
|
backoff_factor=0.2,
|
||||||
|
status_forcelist=[500, 502, 503, 504],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# 获取图片
|
||||||
|
def oss_get_image(oss_client, path, data_type):
|
||||||
|
# cv2 默认全通道读取
|
||||||
|
bucket = path.split("/", 1)[0]
|
||||||
|
object_name = path.split("/", 1)[1]
|
||||||
|
image_object = None
|
||||||
|
try:
|
||||||
|
image_data = oss_client.get_object(bucket_name=bucket, object_name=object_name)
|
||||||
|
if data_type == "cv2":
|
||||||
|
image_bytes = image_data.read()
|
||||||
|
image_array = np.frombuffer(image_bytes, np.uint8) # 转成8位无符号整型
|
||||||
|
image_object = cv2.imdecode(image_array, cv2.IMREAD_COLOR)
|
||||||
|
if image_object.dtype == np.uint16:
|
||||||
|
image_object = (image_object / 256).astype('uint8')
|
||||||
|
else:
|
||||||
|
data_bytes = BytesIO(image_data.read())
|
||||||
|
image_object = Image.open(data_bytes)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"获取图片出现异常 ######: {e}")
|
||||||
|
return image_object
|
||||||
|
|
||||||
|
|
||||||
|
def oss_upload_image(oss_client, bucket, object_name, image_bytes):
|
||||||
|
req = None
|
||||||
|
try:
|
||||||
|
req = oss_client.put_object(bucket_name=bucket, object_name=object_name, data=io.BytesIO(image_bytes), length=len(image_bytes), content_type='image/png')
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"上传图片出现异常 ######: {e}")
|
||||||
|
return req
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
# url = "lanecarford/lc_stylist_agent_outfit_items/string/7fed1c7b-9efd-41fa-a335-182c310ea611.jpg"
|
||||||
|
# url = "lanecarford/lc_stylist_agent_outfit_items/string/5de155d0-56a6-43e8-a2f1-7538fce86220.jpg"
|
||||||
|
# url = "lanecarford/lc_stylist_agent_outfit_items/string/1cd1803c-5f51-4961-a4f2-2acd3e0d8294.jpg"
|
||||||
|
url = ["aida-users/89/sketch/4e8fe37d-7068-400a-ac94-c01647fa5f6f.png", ]
|
||||||
|
read_type = "1"
|
||||||
|
for id, i in enumerate(url):
|
||||||
|
img = oss_get_image(minio_client, i, read_type)
|
||||||
|
img = oss_get_image(oss_client=minio_client, path=i, data_type=read_type)
|
||||||
|
if read_type == "cv2":
|
||||||
|
cv2.imshow("", img)
|
||||||
|
cv2.waitKey(0)
|
||||||
|
else:
|
||||||
|
img.show()
|
||||||
|
img.save(f"{time.time()}.png")
|
||||||
Reference in New Issue
Block a user