feat(新功能): clothing seg
fix(修复bug): docs(文档变更): refactor(重构): test(增加测试):
This commit is contained in:
156
app/service/clothing_seg/service.py
Normal file
156
app/service/clothing_seg/service.py
Normal file
@@ -0,0 +1,156 @@
|
||||
import io
|
||||
import time
|
||||
from pprint import pprint
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import tritonclient.grpc as grpcclient
|
||||
from PIL import Image
|
||||
from minio import Minio
|
||||
from tritonclient.utils import np_to_triton_dtype
|
||||
|
||||
from app.core.config import *
|
||||
from app.schemas.clothing_seg import ClothingSegModel
|
||||
from app.service.design_fast.utils.design_ensemble import get_seg_result
|
||||
from app.service.utils.decorator import RunTime
|
||||
from app.service.utils.generate_uuid import generate_uuid
|
||||
from app.service.utils.new_oss_client import oss_get_image, oss_upload_image
|
||||
|
||||
minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||
|
||||
|
||||
class ClothingSeg:
|
||||
def __init__(self, request_data):
|
||||
self.image_data = request_data.image_data
|
||||
self.user_id = request_data.user_id
|
||||
self.triton_client = grpcclient.InferenceServerClient(url="10.1.1.243:10071")
|
||||
|
||||
@RunTime
|
||||
def get_result(self):
|
||||
self.read_image()
|
||||
self.clothing_seg()
|
||||
self.upload_image()
|
||||
for data in self.image_data:
|
||||
del data["image"]
|
||||
del data["clothing"]
|
||||
|
||||
return self.image_data
|
||||
|
||||
@RunTime
|
||||
def upload_image(self):
|
||||
for data in self.image_data:
|
||||
data["clothing_url"] = []
|
||||
for clothing in data["clothing"]:
|
||||
object_name = f"{self.user_id}/clothing_seg/{generate_uuid()}.png"
|
||||
image_data = io.BytesIO()
|
||||
clothing.save(image_data, format="PNG")
|
||||
image_data.seek(0)
|
||||
image_bytes = image_data.read()
|
||||
oss_upload_image(oss_client=minio_client, bucket="aida-users", object_name=object_name, image_bytes=image_bytes)
|
||||
data["clothing_url"].append(f"aida-users/{object_name}")
|
||||
|
||||
@RunTime
|
||||
def read_image(self):
|
||||
for data in self.image_data:
|
||||
url = data["image_url"]
|
||||
image = oss_get_image(oss_client=minio_client, bucket=url.split("/", 1)[0], object_name=url.split("/", 1)[1], data_type="cv2")
|
||||
data["image"] = image
|
||||
|
||||
@RunTime
|
||||
def clothing_seg(self):
|
||||
for data in self.image_data:
|
||||
image_type = data["image_type"]
|
||||
image = data["image"]
|
||||
clothing_result = []
|
||||
if image_type == "sketch":
|
||||
seg_mask = get_seg_result(1, image)
|
||||
temp = seg_mask != 0.0
|
||||
mask = (255 * (temp + 0).astype(np.uint8))
|
||||
x_min, y_min, x_max, y_max = get_bounding_box(mask)
|
||||
cropped_mask = mask[y_min:y_max + 1, x_min:x_max + 1]
|
||||
cropped_image = image[y_min:y_max + 1, x_min:x_max + 1]
|
||||
h, w = cropped_image.shape[:2]
|
||||
mask_pil = Image.fromarray(cropped_mask).convert("L")
|
||||
image_pil = Image.fromarray(cv2.cvtColor(cropped_image, cv2.COLOR_BGR2RGB))
|
||||
transparent_image = Image.new("RGBA", (w, h), (0, 0, 0, 0))
|
||||
transparent_image.paste(image_pil, (0, 0), mask=mask_pil)
|
||||
clothing_result.append(transparent_image)
|
||||
else:
|
||||
input_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||
input0_data = [input_image.astype(np.float32)] * 1
|
||||
input0_data = np.array(input0_data, dtype=np.float32)
|
||||
inputs = [
|
||||
grpcclient.InferInput(
|
||||
"INPUT0", input0_data.shape, np_to_triton_dtype(input0_data.dtype)
|
||||
),
|
||||
]
|
||||
|
||||
inputs[0].set_data_from_numpy(input0_data)
|
||||
|
||||
outputs = [
|
||||
grpcclient.InferRequestedOutput("OUTPUT0"),
|
||||
grpcclient.InferRequestedOutput("OUTPUT1"),
|
||||
]
|
||||
response = self.triton_client.infer("seg_clothing", inputs, request_id=str(1), outputs=outputs)
|
||||
output0_data = response.as_numpy("OUTPUT0")
|
||||
cv2.imwrite("output02.png", output0_data * 100)
|
||||
output1_data = response.as_numpy("OUTPUT1")
|
||||
for alpha in output1_data:
|
||||
x_min, y_min, x_max, y_max = get_bounding_box(alpha)
|
||||
cropped_mask = alpha[y_min:y_max + 1, x_min:x_max + 1]
|
||||
cropped_image = image[y_min:y_max + 1, x_min:x_max + 1]
|
||||
h, w = cropped_image.shape[:2]
|
||||
mask_pil = Image.fromarray(cropped_mask).convert("L")
|
||||
image_pil = Image.fromarray(cv2.cvtColor(cropped_image, cv2.COLOR_BGR2RGB))
|
||||
transparent_image = Image.new("RGBA", (w, h), (0, 0, 0, 0))
|
||||
transparent_image.paste(image_pil, (0, 0), mask=mask_pil)
|
||||
clothing_result.append(transparent_image)
|
||||
data["clothing"] = clothing_result
|
||||
|
||||
|
||||
@RunTime
|
||||
def get_bounding_box(mask):
|
||||
"""
|
||||
从仅包含 0 和 1 的掩码图像中获取边界框。
|
||||
|
||||
:param mask: 输入的掩码图像,二维 numpy 数组,元素为 0 或 1
|
||||
:return: 边界框坐标 (x_min, y_min, x_max, y_max)
|
||||
"""
|
||||
# 找到所有值不为 0 的像素的坐标
|
||||
rows, cols = np.where(mask != 0)
|
||||
|
||||
if len(rows) == 0 or len(cols) == 0:
|
||||
# 如果没有找到不为 0 的像素,返回全 0 的边界框
|
||||
return (0, 0, 0, 0)
|
||||
|
||||
# 计算边界框的坐标
|
||||
x_min = np.min(cols)
|
||||
y_min = np.min(rows)
|
||||
x_max = np.max(cols)
|
||||
y_max = np.max(rows)
|
||||
|
||||
return (x_min, y_min, x_max, y_max)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
request_data = ClothingSegModel(
|
||||
user_id=89,
|
||||
image_data=[
|
||||
{
|
||||
"image_url": "test/clothing_seg/dress.jpg",
|
||||
"image_type": "sketch"
|
||||
},
|
||||
{
|
||||
"image_url": "test/clothing_seg/skirt_559.jpg",
|
||||
"image_type": "sketch"
|
||||
},
|
||||
{
|
||||
"image_url": "test/clothing_seg/10144613.jpg",
|
||||
"image_type": "product"
|
||||
}
|
||||
]
|
||||
)
|
||||
start_time = time.time()
|
||||
server = ClothingSeg(request_data)
|
||||
pprint(server.get_result())
|
||||
print(time.time() - start_time)
|
||||
Reference in New Issue
Block a user