feat(新功能): clothing seg
fix(修复bug): docs(文档变更): refactor(重构): test(增加测试):
This commit is contained in:
51
app/api/api_clothing_seg.py
Normal file
51
app/api/api_clothing_seg.py
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
import json
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from fastapi import APIRouter, HTTPException
|
||||||
|
|
||||||
|
from app.schemas.response_template import ResponseModel
|
||||||
|
from app.schemas.clothing_seg import ClothingSegModel
|
||||||
|
from app.service.clothing_seg.service import ClothingSeg
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
logger = logging.getLogger()
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/clothing_seg")
|
||||||
|
def clothing_seg(request_item: ClothingSegModel):
|
||||||
|
"""
|
||||||
|
创建一个具有以下参数的请求体:
|
||||||
|
- **user_id**: 用户id
|
||||||
|
- **image_data**: 图片数据
|
||||||
|
{
|
||||||
|
"image_url": "test/clothing_seg/dress.jpg",
|
||||||
|
"image_type": "product"
|
||||||
|
}
|
||||||
|
|
||||||
|
示例参数:
|
||||||
|
{
|
||||||
|
"user_id": 89,
|
||||||
|
"image_data": [
|
||||||
|
{
|
||||||
|
"image_url": "test/clothing_seg/dress.jpg",
|
||||||
|
"image_type": "sketch"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"image_url": "test/clothing_seg/skirt_559.jpg",
|
||||||
|
"image_type": "sketch"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"image_url": "test/clothing_seg/10144613.jpg",
|
||||||
|
"image_type": "product"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
logger.info(f"clothing_seg request item is : @@@@@@:{json.dumps(request_item.dict())}")
|
||||||
|
server = ClothingSeg(request_item)
|
||||||
|
result_url = server.get_result()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"clothing_seg Run Exception @@@@@@:{e}")
|
||||||
|
raise HTTPException(status_code=404, detail=str(e))
|
||||||
|
return ResponseModel(data=result_url)
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
from fastapi import APIRouter
|
from fastapi import APIRouter
|
||||||
|
|
||||||
from app.api import api_agent_generate_image, api_recommendation
|
from app.api import api_agent_generate_image
|
||||||
from app.api import api_attribute_retrieve, api_query_image
|
from app.api import api_attribute_retrieve, api_query_image
|
||||||
from app.api import api_brand_dna
|
from app.api import api_brand_dna
|
||||||
from app.api import api_brighten
|
from app.api import api_brighten
|
||||||
@@ -12,6 +12,7 @@ from app.api import api_image2sketch
|
|||||||
from app.api import api_mannequins_edit
|
from app.api import api_mannequins_edit
|
||||||
from app.api import api_pose_transform
|
from app.api import api_pose_transform
|
||||||
from app.api import api_prompt_generation
|
from app.api import api_prompt_generation
|
||||||
|
from app.api import api_clothing_seg
|
||||||
from app.api import api_super_resolution
|
from app.api import api_super_resolution
|
||||||
from app.api import api_test
|
from app.api import api_test
|
||||||
|
|
||||||
@@ -29,7 +30,8 @@ router.include_router(api_image2sketch.router, tags=['api_image2sketch'], prefix
|
|||||||
router.include_router(api_brighten.router, tags=['api_brighten'], prefix="/api")
|
router.include_router(api_brighten.router, tags=['api_brighten'], prefix="/api")
|
||||||
router.include_router(api_query_image.router, tags=['api_query_image'], prefix="/api")
|
router.include_router(api_query_image.router, tags=['api_query_image'], prefix="/api")
|
||||||
router.include_router(api_brand_dna.router, tags=['api_brand_dna'], prefix="/api")
|
router.include_router(api_brand_dna.router, tags=['api_brand_dna'], prefix="/api")
|
||||||
router.include_router(api_recommendation.router, tags=['api_recommendation'], prefix="/api")
|
# router.include_router(api_recommendation.router, tags=['api_recommendation'], prefix="/api")
|
||||||
router.include_router(api_mannequins_edit.router, tags=['api_mannequins_edit'], prefix="/api")
|
router.include_router(api_mannequins_edit.router, tags=['api_mannequins_edit'], prefix="/api")
|
||||||
router.include_router(api_agent_generate_image.router, tags=['api_agent_generate_image'], prefix="/api")
|
router.include_router(api_agent_generate_image.router, tags=['api_agent_generate_image'], prefix="/api")
|
||||||
router.include_router(api_pose_transform.router, tags=['api_pose_transform'], prefix="/api")
|
router.include_router(api_pose_transform.router, tags=['api_pose_transform'], prefix="/api")
|
||||||
|
router.include_router(api_clothing_seg.router, tags=['api_clothing_seg'], prefix="/api")
|
||||||
|
|||||||
6
app/schemas/clothing_seg.py
Normal file
6
app/schemas/clothing_seg.py
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class ClothingSegModel(BaseModel):
|
||||||
|
user_id: str
|
||||||
|
image_data: list[dict]
|
||||||
156
app/service/clothing_seg/service.py
Normal file
156
app/service/clothing_seg/service.py
Normal file
@@ -0,0 +1,156 @@
|
|||||||
|
import io
|
||||||
|
import time
|
||||||
|
from pprint import pprint
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
import tritonclient.grpc as grpcclient
|
||||||
|
from PIL import Image
|
||||||
|
from minio import Minio
|
||||||
|
from tritonclient.utils import np_to_triton_dtype
|
||||||
|
|
||||||
|
from app.core.config import *
|
||||||
|
from app.schemas.clothing_seg import ClothingSegModel
|
||||||
|
from app.service.design_fast.utils.design_ensemble import get_seg_result
|
||||||
|
from app.service.utils.decorator import RunTime
|
||||||
|
from app.service.utils.generate_uuid import generate_uuid
|
||||||
|
from app.service.utils.new_oss_client import oss_get_image, oss_upload_image
|
||||||
|
|
||||||
|
minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||||
|
|
||||||
|
|
||||||
|
class ClothingSeg:
|
||||||
|
def __init__(self, request_data):
|
||||||
|
self.image_data = request_data.image_data
|
||||||
|
self.user_id = request_data.user_id
|
||||||
|
self.triton_client = grpcclient.InferenceServerClient(url="10.1.1.243:10071")
|
||||||
|
|
||||||
|
@RunTime
|
||||||
|
def get_result(self):
|
||||||
|
self.read_image()
|
||||||
|
self.clothing_seg()
|
||||||
|
self.upload_image()
|
||||||
|
for data in self.image_data:
|
||||||
|
del data["image"]
|
||||||
|
del data["clothing"]
|
||||||
|
|
||||||
|
return self.image_data
|
||||||
|
|
||||||
|
@RunTime
|
||||||
|
def upload_image(self):
|
||||||
|
for data in self.image_data:
|
||||||
|
data["clothing_url"] = []
|
||||||
|
for clothing in data["clothing"]:
|
||||||
|
object_name = f"{self.user_id}/clothing_seg/{generate_uuid()}.png"
|
||||||
|
image_data = io.BytesIO()
|
||||||
|
clothing.save(image_data, format="PNG")
|
||||||
|
image_data.seek(0)
|
||||||
|
image_bytes = image_data.read()
|
||||||
|
oss_upload_image(oss_client=minio_client, bucket="aida-users", object_name=object_name, image_bytes=image_bytes)
|
||||||
|
data["clothing_url"].append(f"aida-users/{object_name}")
|
||||||
|
|
||||||
|
@RunTime
|
||||||
|
def read_image(self):
|
||||||
|
for data in self.image_data:
|
||||||
|
url = data["image_url"]
|
||||||
|
image = oss_get_image(oss_client=minio_client, bucket=url.split("/", 1)[0], object_name=url.split("/", 1)[1], data_type="cv2")
|
||||||
|
data["image"] = image
|
||||||
|
|
||||||
|
@RunTime
|
||||||
|
def clothing_seg(self):
|
||||||
|
for data in self.image_data:
|
||||||
|
image_type = data["image_type"]
|
||||||
|
image = data["image"]
|
||||||
|
clothing_result = []
|
||||||
|
if image_type == "sketch":
|
||||||
|
seg_mask = get_seg_result(1, image)
|
||||||
|
temp = seg_mask != 0.0
|
||||||
|
mask = (255 * (temp + 0).astype(np.uint8))
|
||||||
|
x_min, y_min, x_max, y_max = get_bounding_box(mask)
|
||||||
|
cropped_mask = mask[y_min:y_max + 1, x_min:x_max + 1]
|
||||||
|
cropped_image = image[y_min:y_max + 1, x_min:x_max + 1]
|
||||||
|
h, w = cropped_image.shape[:2]
|
||||||
|
mask_pil = Image.fromarray(cropped_mask).convert("L")
|
||||||
|
image_pil = Image.fromarray(cv2.cvtColor(cropped_image, cv2.COLOR_BGR2RGB))
|
||||||
|
transparent_image = Image.new("RGBA", (w, h), (0, 0, 0, 0))
|
||||||
|
transparent_image.paste(image_pil, (0, 0), mask=mask_pil)
|
||||||
|
clothing_result.append(transparent_image)
|
||||||
|
else:
|
||||||
|
input_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||||
|
input0_data = [input_image.astype(np.float32)] * 1
|
||||||
|
input0_data = np.array(input0_data, dtype=np.float32)
|
||||||
|
inputs = [
|
||||||
|
grpcclient.InferInput(
|
||||||
|
"INPUT0", input0_data.shape, np_to_triton_dtype(input0_data.dtype)
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
inputs[0].set_data_from_numpy(input0_data)
|
||||||
|
|
||||||
|
outputs = [
|
||||||
|
grpcclient.InferRequestedOutput("OUTPUT0"),
|
||||||
|
grpcclient.InferRequestedOutput("OUTPUT1"),
|
||||||
|
]
|
||||||
|
response = self.triton_client.infer("seg_clothing", inputs, request_id=str(1), outputs=outputs)
|
||||||
|
output0_data = response.as_numpy("OUTPUT0")
|
||||||
|
cv2.imwrite("output02.png", output0_data * 100)
|
||||||
|
output1_data = response.as_numpy("OUTPUT1")
|
||||||
|
for alpha in output1_data:
|
||||||
|
x_min, y_min, x_max, y_max = get_bounding_box(alpha)
|
||||||
|
cropped_mask = alpha[y_min:y_max + 1, x_min:x_max + 1]
|
||||||
|
cropped_image = image[y_min:y_max + 1, x_min:x_max + 1]
|
||||||
|
h, w = cropped_image.shape[:2]
|
||||||
|
mask_pil = Image.fromarray(cropped_mask).convert("L")
|
||||||
|
image_pil = Image.fromarray(cv2.cvtColor(cropped_image, cv2.COLOR_BGR2RGB))
|
||||||
|
transparent_image = Image.new("RGBA", (w, h), (0, 0, 0, 0))
|
||||||
|
transparent_image.paste(image_pil, (0, 0), mask=mask_pil)
|
||||||
|
clothing_result.append(transparent_image)
|
||||||
|
data["clothing"] = clothing_result
|
||||||
|
|
||||||
|
|
||||||
|
@RunTime
|
||||||
|
def get_bounding_box(mask):
|
||||||
|
"""
|
||||||
|
从仅包含 0 和 1 的掩码图像中获取边界框。
|
||||||
|
|
||||||
|
:param mask: 输入的掩码图像,二维 numpy 数组,元素为 0 或 1
|
||||||
|
:return: 边界框坐标 (x_min, y_min, x_max, y_max)
|
||||||
|
"""
|
||||||
|
# 找到所有值不为 0 的像素的坐标
|
||||||
|
rows, cols = np.where(mask != 0)
|
||||||
|
|
||||||
|
if len(rows) == 0 or len(cols) == 0:
|
||||||
|
# 如果没有找到不为 0 的像素,返回全 0 的边界框
|
||||||
|
return (0, 0, 0, 0)
|
||||||
|
|
||||||
|
# 计算边界框的坐标
|
||||||
|
x_min = np.min(cols)
|
||||||
|
y_min = np.min(rows)
|
||||||
|
x_max = np.max(cols)
|
||||||
|
y_max = np.max(rows)
|
||||||
|
|
||||||
|
return (x_min, y_min, x_max, y_max)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
request_data = ClothingSegModel(
|
||||||
|
user_id=89,
|
||||||
|
image_data=[
|
||||||
|
{
|
||||||
|
"image_url": "test/clothing_seg/dress.jpg",
|
||||||
|
"image_type": "sketch"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"image_url": "test/clothing_seg/skirt_559.jpg",
|
||||||
|
"image_type": "sketch"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"image_url": "test/clothing_seg/10144613.jpg",
|
||||||
|
"image_type": "product"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
)
|
||||||
|
start_time = time.time()
|
||||||
|
server = ClothingSeg(request_data)
|
||||||
|
pprint(server.get_result())
|
||||||
|
print(time.time() - start_time)
|
||||||
Reference in New Issue
Block a user