feat(新功能): clothing seg

fix(修复bug):
docs(文档变更):
refactor(重构):
test(增加测试):
This commit is contained in:
zchengrong
2025-04-11 17:14:59 +08:00
parent 3593f0d431
commit f83a202b20
4 changed files with 217 additions and 2 deletions

View File

@@ -0,0 +1,51 @@
import json
import logging
from fastapi import APIRouter, HTTPException
from app.schemas.response_template import ResponseModel
from app.schemas.clothing_seg import ClothingSegModel
from app.service.clothing_seg.service import ClothingSeg
router = APIRouter()
logger = logging.getLogger()
@router.post("/clothing_seg")
def clothing_seg(request_item: ClothingSegModel):
"""
创建一个具有以下参数的请求体:
- **user_id**: 用户id
- **image_data**: 图片数据
{
"image_url": "test/clothing_seg/dress.jpg",
"image_type": "product"
}
示例参数:
{
"user_id": 89,
"image_data": [
{
"image_url": "test/clothing_seg/dress.jpg",
"image_type": "sketch"
},
{
"image_url": "test/clothing_seg/skirt_559.jpg",
"image_type": "sketch"
},
{
"image_url": "test/clothing_seg/10144613.jpg",
"image_type": "product"
}
]
}
"""
try:
logger.info(f"clothing_seg request item is : @@@@@@:{json.dumps(request_item.dict())}")
server = ClothingSeg(request_item)
result_url = server.get_result()
except Exception as e:
logger.warning(f"clothing_seg Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))
return ResponseModel(data=result_url)

View File

@@ -1,6 +1,6 @@
from fastapi import APIRouter
from app.api import api_agent_generate_image, api_recommendation
from app.api import api_agent_generate_image
from app.api import api_attribute_retrieve, api_query_image
from app.api import api_brand_dna
from app.api import api_brighten
@@ -12,6 +12,7 @@ from app.api import api_image2sketch
from app.api import api_mannequins_edit
from app.api import api_pose_transform
from app.api import api_prompt_generation
from app.api import api_clothing_seg
from app.api import api_super_resolution
from app.api import api_test
@@ -29,7 +30,8 @@ router.include_router(api_image2sketch.router, tags=['api_image2sketch'], prefix
router.include_router(api_brighten.router, tags=['api_brighten'], prefix="/api")
router.include_router(api_query_image.router, tags=['api_query_image'], prefix="/api")
router.include_router(api_brand_dna.router, tags=['api_brand_dna'], prefix="/api")
router.include_router(api_recommendation.router, tags=['api_recommendation'], prefix="/api")
# router.include_router(api_recommendation.router, tags=['api_recommendation'], prefix="/api")
router.include_router(api_mannequins_edit.router, tags=['api_mannequins_edit'], prefix="/api")
router.include_router(api_agent_generate_image.router, tags=['api_agent_generate_image'], prefix="/api")
router.include_router(api_pose_transform.router, tags=['api_pose_transform'], prefix="/api")
router.include_router(api_clothing_seg.router, tags=['api_clothing_seg'], prefix="/api")

View File

@@ -0,0 +1,6 @@
from pydantic import BaseModel
class ClothingSegModel(BaseModel):
user_id: str
image_data: list[dict]

View File

@@ -0,0 +1,156 @@
import io
import time
from pprint import pprint
import cv2
import numpy as np
import tritonclient.grpc as grpcclient
from PIL import Image
from minio import Minio
from tritonclient.utils import np_to_triton_dtype
from app.core.config import *
from app.schemas.clothing_seg import ClothingSegModel
from app.service.design_fast.utils.design_ensemble import get_seg_result
from app.service.utils.decorator import RunTime
from app.service.utils.generate_uuid import generate_uuid
from app.service.utils.new_oss_client import oss_get_image, oss_upload_image
minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
class ClothingSeg:
def __init__(self, request_data):
self.image_data = request_data.image_data
self.user_id = request_data.user_id
self.triton_client = grpcclient.InferenceServerClient(url="10.1.1.243:10071")
@RunTime
def get_result(self):
self.read_image()
self.clothing_seg()
self.upload_image()
for data in self.image_data:
del data["image"]
del data["clothing"]
return self.image_data
@RunTime
def upload_image(self):
for data in self.image_data:
data["clothing_url"] = []
for clothing in data["clothing"]:
object_name = f"{self.user_id}/clothing_seg/{generate_uuid()}.png"
image_data = io.BytesIO()
clothing.save(image_data, format="PNG")
image_data.seek(0)
image_bytes = image_data.read()
oss_upload_image(oss_client=minio_client, bucket="aida-users", object_name=object_name, image_bytes=image_bytes)
data["clothing_url"].append(f"aida-users/{object_name}")
@RunTime
def read_image(self):
for data in self.image_data:
url = data["image_url"]
image = oss_get_image(oss_client=minio_client, bucket=url.split("/", 1)[0], object_name=url.split("/", 1)[1], data_type="cv2")
data["image"] = image
@RunTime
def clothing_seg(self):
for data in self.image_data:
image_type = data["image_type"]
image = data["image"]
clothing_result = []
if image_type == "sketch":
seg_mask = get_seg_result(1, image)
temp = seg_mask != 0.0
mask = (255 * (temp + 0).astype(np.uint8))
x_min, y_min, x_max, y_max = get_bounding_box(mask)
cropped_mask = mask[y_min:y_max + 1, x_min:x_max + 1]
cropped_image = image[y_min:y_max + 1, x_min:x_max + 1]
h, w = cropped_image.shape[:2]
mask_pil = Image.fromarray(cropped_mask).convert("L")
image_pil = Image.fromarray(cv2.cvtColor(cropped_image, cv2.COLOR_BGR2RGB))
transparent_image = Image.new("RGBA", (w, h), (0, 0, 0, 0))
transparent_image.paste(image_pil, (0, 0), mask=mask_pil)
clothing_result.append(transparent_image)
else:
input_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
input0_data = [input_image.astype(np.float32)] * 1
input0_data = np.array(input0_data, dtype=np.float32)
inputs = [
grpcclient.InferInput(
"INPUT0", input0_data.shape, np_to_triton_dtype(input0_data.dtype)
),
]
inputs[0].set_data_from_numpy(input0_data)
outputs = [
grpcclient.InferRequestedOutput("OUTPUT0"),
grpcclient.InferRequestedOutput("OUTPUT1"),
]
response = self.triton_client.infer("seg_clothing", inputs, request_id=str(1), outputs=outputs)
output0_data = response.as_numpy("OUTPUT0")
cv2.imwrite("output02.png", output0_data * 100)
output1_data = response.as_numpy("OUTPUT1")
for alpha in output1_data:
x_min, y_min, x_max, y_max = get_bounding_box(alpha)
cropped_mask = alpha[y_min:y_max + 1, x_min:x_max + 1]
cropped_image = image[y_min:y_max + 1, x_min:x_max + 1]
h, w = cropped_image.shape[:2]
mask_pil = Image.fromarray(cropped_mask).convert("L")
image_pil = Image.fromarray(cv2.cvtColor(cropped_image, cv2.COLOR_BGR2RGB))
transparent_image = Image.new("RGBA", (w, h), (0, 0, 0, 0))
transparent_image.paste(image_pil, (0, 0), mask=mask_pil)
clothing_result.append(transparent_image)
data["clothing"] = clothing_result
@RunTime
def get_bounding_box(mask):
"""
从仅包含 0 和 1 的掩码图像中获取边界框。
:param mask: 输入的掩码图像,二维 numpy 数组,元素为 0 或 1
:return: 边界框坐标 (x_min, y_min, x_max, y_max)
"""
# 找到所有值不为 0 的像素的坐标
rows, cols = np.where(mask != 0)
if len(rows) == 0 or len(cols) == 0:
# 如果没有找到不为 0 的像素,返回全 0 的边界框
return (0, 0, 0, 0)
# 计算边界框的坐标
x_min = np.min(cols)
y_min = np.min(rows)
x_max = np.max(cols)
y_max = np.max(rows)
return (x_min, y_min, x_max, y_max)
if __name__ == "__main__":
request_data = ClothingSegModel(
user_id=89,
image_data=[
{
"image_url": "test/clothing_seg/dress.jpg",
"image_type": "sketch"
},
{
"image_url": "test/clothing_seg/skirt_559.jpg",
"image_type": "sketch"
},
{
"image_url": "test/clothing_seg/10144613.jpg",
"image_type": "product"
}
]
)
start_time = time.time()
server = ClothingSeg(request_data)
pprint(server.get_result())
print(time.time() - start_time)