feat(新功能): 优化clothing seg

fix(修复bug):
docs(文档变更):
refactor(重构):
test(增加测试):
This commit is contained in:
zchengrong
2025-04-14 15:11:29 +08:00
parent 5a93673b52
commit fed9d27bf5

View File

@@ -33,7 +33,6 @@ class ClothingSeg:
for data in self.image_data:
del data["image"]
del data["clothing"]
return self.image_data
@RunTime
@@ -88,14 +87,16 @@ class ClothingSeg:
inputs[0].set_data_from_numpy(input0_data)
outputs = [
grpcclient.InferRequestedOutput("OUTPUT0"),
# grpcclient.InferRequestedOutput("OUTPUT0"),
grpcclient.InferRequestedOutput("OUTPUT1"),
]
response = self.triton_client.infer("seg_clothing", inputs, request_id=str(1), outputs=outputs)
output0_data = response.as_numpy("OUTPUT0")
cv2.imwrite("output02.png", output0_data * 100)
# output0_data = response.as_numpy("OUTPUT0")
# cv2.imwrite("output02.png", output0_data * 100)
output1_data = response.as_numpy("OUTPUT1")
for alpha in output1_data:
alpha = cv2.resize(alpha, (image.shape[1], image.shape[0]), interpolation=cv2.INTER_CUBIC)
x_min, y_min, x_max, y_max = get_bounding_box(alpha)
cropped_mask = alpha[y_min:y_max + 1, x_min:x_max + 1]
cropped_image = image[y_min:y_max + 1, x_min:x_max + 1]
@@ -121,7 +122,7 @@ def get_bounding_box(mask):
if len(rows) == 0 or len(cols) == 0:
# 如果没有找到不为 0 的像素,返回全 0 的边界框
return (0, 0, 0, 0)
return 0, 0, 0, 0
# 计算边界框的坐标
x_min = np.min(cols)
@@ -129,21 +130,21 @@ def get_bounding_box(mask):
x_max = np.max(cols)
y_max = np.max(rows)
return (x_min, y_min, x_max, y_max)
return x_min, y_min, x_max, y_max
if __name__ == "__main__":
request_data = ClothingSegModel(
test_data = ClothingSegModel(
user_id=89,
image_data=[
{
"image_url": "test/clothing_seg/dress.jpg",
"image_type": "sketch"
},
{
"image_url": "test/clothing_seg/skirt_559.jpg",
"image_type": "sketch"
},
# {
# "image_url": "test/clothing_seg/dress.jpg",
# "image_type": "sketch"
# },
# {
# "image_url": "test/clothing_seg/skirt_559.jpg",
# "image_type": "sketch"
# },
{
"image_url": "test/clothing_seg/10144613.jpg",
"image_type": "product"
@@ -151,6 +152,6 @@ if __name__ == "__main__":
]
)
start_time = time.time()
server = ClothingSeg(request_data)
server = ClothingSeg(test_data)
pprint(server.get_result())
print(time.time() - start_time)