feat(新功能): 优化clothing seg

fix(修复bug):
docs(文档变更):
refactor(重构):
test(增加测试):
This commit is contained in:
zchengrong
2025-04-14 15:11:29 +08:00
parent 5a93673b52
commit fed9d27bf5

View File

@@ -33,7 +33,6 @@ class ClothingSeg:
for data in self.image_data: for data in self.image_data:
del data["image"] del data["image"]
del data["clothing"] del data["clothing"]
return self.image_data return self.image_data
@RunTime @RunTime
@@ -88,14 +87,16 @@ class ClothingSeg:
inputs[0].set_data_from_numpy(input0_data) inputs[0].set_data_from_numpy(input0_data)
outputs = [ outputs = [
grpcclient.InferRequestedOutput("OUTPUT0"), # grpcclient.InferRequestedOutput("OUTPUT0"),
grpcclient.InferRequestedOutput("OUTPUT1"), grpcclient.InferRequestedOutput("OUTPUT1"),
] ]
response = self.triton_client.infer("seg_clothing", inputs, request_id=str(1), outputs=outputs) response = self.triton_client.infer("seg_clothing", inputs, request_id=str(1), outputs=outputs)
output0_data = response.as_numpy("OUTPUT0") # output0_data = response.as_numpy("OUTPUT0")
cv2.imwrite("output02.png", output0_data * 100) # cv2.imwrite("output02.png", output0_data * 100)
output1_data = response.as_numpy("OUTPUT1") output1_data = response.as_numpy("OUTPUT1")
for alpha in output1_data: for alpha in output1_data:
alpha = cv2.resize(alpha, (image.shape[1], image.shape[0]), interpolation=cv2.INTER_CUBIC)
x_min, y_min, x_max, y_max = get_bounding_box(alpha) x_min, y_min, x_max, y_max = get_bounding_box(alpha)
cropped_mask = alpha[y_min:y_max + 1, x_min:x_max + 1] cropped_mask = alpha[y_min:y_max + 1, x_min:x_max + 1]
cropped_image = image[y_min:y_max + 1, x_min:x_max + 1] cropped_image = image[y_min:y_max + 1, x_min:x_max + 1]
@@ -121,7 +122,7 @@ def get_bounding_box(mask):
if len(rows) == 0 or len(cols) == 0: if len(rows) == 0 or len(cols) == 0:
# 如果没有找到不为 0 的像素,返回全 0 的边界框 # 如果没有找到不为 0 的像素,返回全 0 的边界框
return (0, 0, 0, 0) return 0, 0, 0, 0
# 计算边界框的坐标 # 计算边界框的坐标
x_min = np.min(cols) x_min = np.min(cols)
@@ -129,21 +130,21 @@ def get_bounding_box(mask):
x_max = np.max(cols) x_max = np.max(cols)
y_max = np.max(rows) y_max = np.max(rows)
return (x_min, y_min, x_max, y_max) return x_min, y_min, x_max, y_max
if __name__ == "__main__": if __name__ == "__main__":
request_data = ClothingSegModel( test_data = ClothingSegModel(
user_id=89, user_id=89,
image_data=[ image_data=[
{ # {
"image_url": "test/clothing_seg/dress.jpg", # "image_url": "test/clothing_seg/dress.jpg",
"image_type": "sketch" # "image_type": "sketch"
}, # },
{ # {
"image_url": "test/clothing_seg/skirt_559.jpg", # "image_url": "test/clothing_seg/skirt_559.jpg",
"image_type": "sketch" # "image_type": "sketch"
}, # },
{ {
"image_url": "test/clothing_seg/10144613.jpg", "image_url": "test/clothing_seg/10144613.jpg",
"image_type": "product" "image_type": "product"
@@ -151,6 +152,6 @@ if __name__ == "__main__":
] ]
) )
start_time = time.time() start_time = time.time()
server = ClothingSeg(request_data) server = ClothingSeg(test_data)
pprint(server.get_result()) pprint(server.get_result())
print(time.time() - start_time) print(time.time() - start_time)