158 lines
6.1 KiB
Python
158 lines
6.1 KiB
Python
import io
|
|
import time
|
|
from pprint import pprint
|
|
|
|
import cv2
|
|
import numpy as np
|
|
import tritonclient.grpc as grpcclient
|
|
from PIL import Image
|
|
from minio import Minio
|
|
from tritonclient.utils import np_to_triton_dtype
|
|
|
|
from app.core.config import *
|
|
from app.schemas.clothing_seg import ClothingSegModel
|
|
from app.service.design_fast.utils.design_ensemble import get_seg_result
|
|
from app.service.utils.decorator import RunTime
|
|
from app.service.utils.generate_uuid import generate_uuid
|
|
from app.service.utils.new_oss_client import oss_get_image, oss_upload_image
|
|
|
|
# Module-wide MinIO client shared by the download and upload helpers below.
# NOTE(review): the MINIO_* settings are presumably provided by the star
# import from app.core.config - confirm.
minio_client = Minio(
    MINIO_URL,
    access_key=MINIO_ACCESS,
    secret_key=MINIO_SECRET,
    secure=MINIO_SECURE,
)
|
|
|
|
|
|
class ClothingSeg:
    """Cut garments out of request images and upload the cut-outs as PNGs.

    The pipeline run by :meth:`get_result` is: download every image
    referenced by the request, segment it (``get_seg_result`` for
    ``"sketch"`` images, the Triton model for every other image type),
    crop each mask to its bounding box, paste the pixels onto a transparent
    canvas and upload the result to object storage.
    """

    # Defaults identical to the values that used to be hard-coded inline.
    TRITON_URL = "10.1.1.243:10071"
    TRITON_MODEL = "seg_clothing"
    UPLOAD_BUCKET = "aida-users"

    def __init__(self, request_data, triton_url=None):
        """Store the request payload and open the Triton gRPC client.

        :param request_data: object exposing ``image_data`` and ``user_id``.
            Each ``image_data`` entry is used as a mutable mapping holding
            at least ``"image_url"`` and ``"image_type"``.
        :param triton_url: optional ``host:port`` of the Triton server;
            falls back to :attr:`TRITON_URL` when omitted.
        """
        self.image_data = request_data.image_data
        self.user_id = request_data.user_id
        self.triton_client = grpcclient.InferenceServerClient(url=triton_url or self.TRITON_URL)

    @RunTime
    def get_result(self):
        """Run the whole pipeline and return the enriched ``image_data``.

        :return: ``self.image_data``; every entry gains a ``"clothing_url"``
            list, while the temporary ``"image"`` and ``"clothing"`` entries
            are removed again before returning.
        """
        self.read_image()
        self.clothing_seg()
        self.upload_image()
        for data in self.image_data:
            # Drop the in-memory images so only plain request data is returned.
            del data["image"]
            del data["clothing"]
        return self.image_data

    @RunTime
    def upload_image(self):
        """Upload every cut-out as a PNG and record its storage path.

        Each path is appended to ``data["clothing_url"]`` in the form
        ``"<bucket>/<object_name>"``.
        """
        for data in self.image_data:
            data["clothing_url"] = []
            for clothing in data["clothing"]:
                object_name = f"{self.user_id}/clothing_seg/{generate_uuid()}.png"
                buffer = io.BytesIO()
                clothing.save(buffer, format="PNG")
                # getvalue() returns the full buffer regardless of the cursor position.
                oss_upload_image(
                    oss_client=minio_client,
                    bucket=self.UPLOAD_BUCKET,
                    object_name=object_name,
                    image_bytes=buffer.getvalue(),
                )
                data["clothing_url"].append(f"{self.UPLOAD_BUCKET}/{object_name}")

    @RunTime
    def read_image(self):
        """Download each ``"image_url"`` and store the result under ``"image"``.

        The URL has the form ``"<bucket>/<object_name>"``; it is split on the
        first ``/`` only.

        :raises IndexError: if an ``"image_url"`` contains no ``/``.
        """
        for data in self.image_data:
            parts = data["image_url"].split("/", 1)
            bucket, object_name = parts[0], parts[1]
            data["image"] = oss_get_image(
                oss_client=minio_client,
                bucket=bucket,
                object_name=object_name,
                data_type="cv2",
            )

    @RunTime
    def clothing_seg(self):
        """Segment every image and store the cut-outs under ``"clothing"``.

        ``"sketch"`` images produce exactly one cut-out; every other image
        type produces one cut-out per mask returned by the Triton model.
        """
        for data in self.image_data:
            image = data["image"]
            if data["image_type"] == "sketch":
                masks = [self._sketch_mask(image)]
            else:
                masks = self._product_masks(image)
            data["clothing"] = [self._cut_out(image, mask) for mask in masks]

    @staticmethod
    def _sketch_mask(image):
        """Return a uint8 mask (0 or 255) built from ``get_seg_result``."""
        # Only the first three channels are passed to the segmentation helper.
        seg_mask = get_seg_result(1, image[:, :, :3])
        return 255 * (seg_mask != 0.0).astype(np.uint8)

    def _product_masks(self, image):
        """Query the Triton model and return its masks resized to the image.

        NOTE(review): OUTPUT1 is iterated along its first axis and every item
        is treated as one 2-D mask - presumably one mask per garment; confirm
        against the model configuration.
        """
        rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # The model input is a float32 batch holding this single image.
        batch = np.array([rgb], dtype=np.float32)

        infer_input = grpcclient.InferInput("INPUT0", batch.shape, np_to_triton_dtype(batch.dtype))
        infer_input.set_data_from_numpy(batch)
        # Only OUTPUT1 is requested from the model.
        requested = [grpcclient.InferRequestedOutput("OUTPUT1")]

        response = self.triton_client.infer(
            self.TRITON_MODEL, [infer_input], request_id="1", outputs=requested
        )

        height, width = image.shape[:2]
        return [
            cv2.resize(alpha, (width, height), interpolation=cv2.INTER_CUBIC)
            for alpha in response.as_numpy("OUTPUT1")
        ]

    @staticmethod
    def _cut_out(image, mask):
        """Crop ``image`` to the mask's bounding box on a transparent canvas.

        :param image: image array as returned by ``oss_get_image``; it is
            converted with ``COLOR_BGR2RGB`` before pasting.
        :param mask: 2-D mask with the same height and width as ``image``.
        :return: RGBA ``PIL.Image`` of the cropped region; the cropped mask,
            converted to mode ``"L"``, is used as the paste mask.

        An empty mask makes ``get_bounding_box`` return ``(0, 0, 0, 0)``, so
        the result is then a 1x1 image rather than an error.
        """
        x_min, y_min, x_max, y_max = get_bounding_box(mask)
        # The bounding box is inclusive, hence the +1 on the upper bounds.
        cropped_mask = mask[y_min:y_max + 1, x_min:x_max + 1]
        cropped_image = image[y_min:y_max + 1, x_min:x_max + 1]
        h, w = cropped_image.shape[:2]

        mask_pil = Image.fromarray(cropped_mask).convert("L")
        image_pil = Image.fromarray(cv2.cvtColor(cropped_image, cv2.COLOR_BGR2RGB))
        canvas = Image.new("RGBA", (w, h), (0, 0, 0, 0))
        canvas.paste(image_pil, (0, 0), mask=mask_pil)
        return canvas
|
|
|
|
|
|
@RunTime
def get_bounding_box(mask):
    """Return the tight bounding box around the non-zero pixels of a mask.

    :param mask: two-dimensional numpy array; every element different from 0
        counts as foreground
    :return: ``(x_min, y_min, x_max, y_max)`` as inclusive pixel indices, or
        ``(0, 0, 0, 0)`` when the mask contains no non-zero element
    """
    # Row and column indices of every foreground pixel.
    row_idx, col_idx = np.nonzero(mask != 0)

    # Both index arrays always have the same length, so one check suffices.
    if row_idx.size == 0:
        return 0, 0, 0, 0

    return col_idx.min(), row_idx.min(), col_idx.max(), row_idx.max()
|
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke test: run the pipeline on one sample image and print the
    # result followed by the elapsed wall-clock time in seconds.
    # Other sample inputs (image_type "sketch"):
    #   test/clothing_seg/dress.jpg
    #   test/clothing_seg/skirt_559.jpg
    sample_request = ClothingSegModel(
        user_id=89,
        image_data=[
            {"image_url": "test/clothing_seg/10144613.jpg", "image_type": "product"},
        ],
    )
    started_at = time.time()
    pprint(ClothingSeg(sample_request).get_result())
    print(time.time() - started_at)
|