Files
AiDA_Python/app/service/clothing_seg/service.py
zchengrong fed9d27bf5 feat(新功能): 优化clothing seg
fix(修复bug):
docs(文档变更):
refactor(重构):
test(增加测试):
2025-04-14 15:11:29 +08:00

158 lines
6.1 KiB
Python

import io
import time
from pprint import pprint
import cv2
import numpy as np
import tritonclient.grpc as grpcclient
from PIL import Image
from minio import Minio
from tritonclient.utils import np_to_triton_dtype
from app.core.config import *
from app.schemas.clothing_seg import ClothingSegModel
from app.service.design_fast.utils.design_ensemble import get_seg_result
from app.service.utils.decorator import RunTime
from app.service.utils.generate_uuid import generate_uuid
from app.service.utils.new_oss_client import oss_get_image, oss_upload_image
# Module-wide MinIO client shared by ClothingSeg for downloading the source
# images and uploading the segmented cut-outs. Connection settings are the
# MINIO_* names pulled in by the star import from app.core.config.
minio_client = Minio(
    MINIO_URL,
    access_key=MINIO_ACCESS,
    secret_key=MINIO_SECRET,
    secure=MINIO_SECURE,
)
class ClothingSeg:
    """Cut garments out of user images and upload them as transparent PNGs.

    Every entry of ``request_data.image_data`` must provide ``image_url``
    (``"<bucket>/<object name>"``) and ``image_type``. ``"sketch"`` images are
    segmented with ``get_seg_result`` (one cut-out per image); every other type
    is sent to the ``seg_clothing`` Triton model, producing one cut-out per
    entry along the first axis of its ``OUTPUT1`` tensor.
    """

    # Bucket the cut-outs are written to; also the prefix of returned URLs.
    _UPLOAD_BUCKET = "aida-users"

    def __init__(self, request_data, triton_url="10.1.1.243:10071"):
        """Store the request payload and open a gRPC client to Triton.

        :param request_data: object exposing ``image_data`` (list of per-image
            dicts) and ``user_id`` (used as the upload path prefix).
        :param triton_url: ``host:port`` of the Triton gRPC endpoint. The
            default is the address this service previously hard-coded.
        """
        self.image_data = request_data.image_data
        self.user_id = request_data.user_id
        self.triton_client = grpcclient.InferenceServerClient(url=triton_url)

    @RunTime
    def get_result(self):
        """Download, segment and upload every image, then return ``image_data``.

        The intermediate ``image`` (decoded source) and ``clothing`` (PIL
        cut-outs) keys are removed from each dict before returning, so each
        entry ends up with its original fields plus ``clothing_url``.
        """
        self.read_image()
        self.clothing_seg()
        self.upload_image()
        for data in self.image_data:
            # Drop the in-memory intermediates; presumably so the response only
            # carries serialisable fields — confirm against the API schema.
            del data["image"]
            del data["clothing"]
        return self.image_data

    @RunTime
    def upload_image(self):
        """Upload every cut-out as a PNG and record its ``bucket/object`` path.

        Objects are stored as ``<user_id>/clothing_seg/<uuid>.png``; the
        resulting paths are appended, in cut-out order, to ``clothing_url``.
        """
        bucket = self._UPLOAD_BUCKET
        for data in self.image_data:
            data["clothing_url"] = []
            for clothing in data["clothing"]:
                object_name = f"{self.user_id}/clothing_seg/{generate_uuid()}.png"
                buffer = io.BytesIO()
                clothing.save(buffer, format="PNG")
                oss_upload_image(
                    oss_client=minio_client,
                    bucket=bucket,
                    object_name=object_name,
                    image_bytes=buffer.getvalue(),
                )
                data["clothing_url"].append(f"{bucket}/{object_name}")

    @RunTime
    def read_image(self):
        """Download each ``image_url`` and store the decoded array under ``image``.

        :raises ValueError: if an ``image_url`` is not of the form
            ``"<bucket>/<object name>"``.
        """
        for data in self.image_data:
            url = data["image_url"]
            # The first path segment is the bucket; the rest is the object key.
            bucket, sep, object_name = url.partition("/")
            if not sep or not object_name:
                raise ValueError(
                    f"image_url must look like '<bucket>/<object>', got {url!r}"
                )
            data["image"] = oss_get_image(
                oss_client=minio_client,
                bucket=bucket,
                object_name=object_name,
                data_type="cv2",
            )

    @RunTime
    def clothing_seg(self):
        """Segment every image and store its RGBA cut-outs under ``clothing``."""
        for data in self.image_data:
            image = data["image"]
            if data["image_type"] == "sketch":
                masks = [self._sketch_mask(image)]
            else:
                masks = self._product_masks(image)
            data["clothing"] = [self._cut_out(image, mask) for mask in masks]

    @staticmethod
    def _sketch_mask(image):
        """Return a uint8 0/255 mask of the pixels ``get_seg_result`` marks non-zero."""
        seg_mask = get_seg_result(1, image)
        return (seg_mask != 0).astype(np.uint8) * 255

    def _product_masks(self, image):
        """Run the ``seg_clothing`` Triton model and return its masks at image size."""
        # The model receives the RGB image as float32 with a leading batch axis of 1.
        batch = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32)[np.newaxis]
        infer_input = grpcclient.InferInput(
            "INPUT0", batch.shape, np_to_triton_dtype(batch.dtype)
        )
        infer_input.set_data_from_numpy(batch)
        response = self.triton_client.infer(
            "seg_clothing",
            [infer_input],
            request_id="1",
            outputs=[grpcclient.InferRequestedOutput("OUTPUT1")],
        )
        height, width = image.shape[:2]
        # NOTE(review): OUTPUT1 is iterated over its first axis, presumably one
        # alpha mask per garment — confirm against the model config. Cubic
        # resizing can introduce small non-zero ringing values that widen the
        # bounding box computed later.
        return [
            cv2.resize(alpha, (width, height), interpolation=cv2.INTER_CUBIC)
            for alpha in response.as_numpy("OUTPUT1")
        ]

    @staticmethod
    def _cut_out(image, mask):
        """Crop ``image`` to ``mask``'s non-zero bounding box as an RGBA image.

        ``image`` is converted with COLOR_BGR2RGB, so it is expected in BGR
        order. ``mask`` becomes an 8-bit ("L") paste mask; pixels outside it
        stay fully transparent.
        NOTE(review): an all-zero mask yields a 1x1 transparent image, because
        get_bounding_box then returns (0, 0, 0, 0). If ``mask`` is float data in
        [0, 1], convert("L") leaves it nearly transparent — confirm the scale.
        """
        x_min, y_min, x_max, y_max = get_bounding_box(mask)
        cropped_mask = mask[y_min:y_max + 1, x_min:x_max + 1]
        cropped_image = image[y_min:y_max + 1, x_min:x_max + 1]
        height, width = cropped_image.shape[:2]
        mask_pil = Image.fromarray(cropped_mask).convert("L")
        image_pil = Image.fromarray(cv2.cvtColor(cropped_image, cv2.COLOR_BGR2RGB))
        cut_out = Image.new("RGBA", (width, height), (0, 0, 0, 0))
        cut_out.paste(image_pil, (0, 0), mask=mask_pil)
        return cut_out
@RunTime
def get_bounding_box(mask):
    """Return the tight bounding box of the non-zero pixels of a 2-D mask.

    Any non-zero value counts as foreground, so 0/1, 0/255 and soft (float)
    masks are all accepted.

    :param mask: 2-D numpy array.
    :return: ``(x_min, y_min, x_max, y_max)`` as inclusive pixel indices, i.e.
        crop with ``[y_min:y_max + 1, x_min:x_max + 1]``. If the mask has no
        non-zero pixel, ``(0, 0, 0, 0)`` is returned.
    :raises ValueError: if ``mask`` is not 2-D.
    """
    foreground = np.asarray(mask) != 0
    if foreground.ndim != 2:
        raise ValueError(f"mask must be 2-D, got shape {foreground.shape}")
    # Project the mask onto each axis instead of materialising the coordinates
    # of every foreground pixel.
    row_hits = np.flatnonzero(foreground.any(axis=1))
    if row_hits.size == 0:
        # Empty mask: keep the historical all-zero sentinel.
        return 0, 0, 0, 0
    col_hits = np.flatnonzero(foreground.any(axis=0))
    return int(col_hits[0]), int(row_hits[0]), int(col_hits[-1]), int(row_hits[-1])
if __name__ == "__main__":
    # Manual smoke test: segments a product photo stored in MinIO and uploads
    # the cut-outs for user 89. Needs network access to MinIO and Triton.
    # Sketch samples that can be swapped in (image_type "sketch"):
    #   test/clothing_seg/dress.jpg, test/clothing_seg/skirt_559.jpg
    test_data = ClothingSegModel(
        user_id=89,
        image_data=[
            {
                "image_url": "test/clothing_seg/10144613.jpg",
                "image_type": "product",
            },
        ],
    )
    # perf_counter is monotonic, unlike time.time, so it suits elapsed time.
    start_time = time.perf_counter()
    server = ClothingSeg(test_data)
    pprint(server.get_result())
    print(time.perf_counter() - start_time)