diff --git a/app/service/clothing_seg/service.py b/app/service/clothing_seg/service.py index a0f3640..7894bff 100644 --- a/app/service/clothing_seg/service.py +++ b/app/service/clothing_seg/service.py @@ -33,7 +33,6 @@ class ClothingSeg: for data in self.image_data: del data["image"] del data["clothing"] - return self.image_data @RunTime @@ -88,14 +87,16 @@ class ClothingSeg: inputs[0].set_data_from_numpy(input0_data) outputs = [ - grpcclient.InferRequestedOutput("OUTPUT0"), + # grpcclient.InferRequestedOutput("OUTPUT0"), grpcclient.InferRequestedOutput("OUTPUT1"), ] + response = self.triton_client.infer("seg_clothing", inputs, request_id=str(1), outputs=outputs) - output0_data = response.as_numpy("OUTPUT0") - cv2.imwrite("output02.png", output0_data * 100) + # output0_data = response.as_numpy("OUTPUT0") + # cv2.imwrite("output02.png", output0_data * 100) output1_data = response.as_numpy("OUTPUT1") for alpha in output1_data: + alpha = cv2.resize(alpha, (image.shape[1], image.shape[0]), interpolation=cv2.INTER_CUBIC) x_min, y_min, x_max, y_max = get_bounding_box(alpha) cropped_mask = alpha[y_min:y_max + 1, x_min:x_max + 1] cropped_image = image[y_min:y_max + 1, x_min:x_max + 1] @@ -121,7 +122,7 @@ def get_bounding_box(mask): if len(rows) == 0 or len(cols) == 0: # 如果没有找到不为 0 的像素,返回全 0 的边界框 - return (0, 0, 0, 0) + return 0, 0, 0, 0 # 计算边界框的坐标 x_min = np.min(cols) @@ -129,21 +130,21 @@ def get_bounding_box(mask): x_max = np.max(cols) y_max = np.max(rows) - return (x_min, y_min, x_max, y_max) + return x_min, y_min, x_max, y_max if __name__ == "__main__": - request_data = ClothingSegModel( + test_data = ClothingSegModel( user_id=89, image_data=[ - { - "image_url": "test/clothing_seg/dress.jpg", - "image_type": "sketch" - }, - { - "image_url": "test/clothing_seg/skirt_559.jpg", - "image_type": "sketch" - }, + # { + # "image_url": "test/clothing_seg/dress.jpg", + # "image_type": "sketch" + # }, + # { + # "image_url": "test/clothing_seg/skirt_559.jpg", + # "image_type": "sketch" + # }, { "image_url": "test/clothing_seg/10144613.jpg", "image_type": "product" @@ -151,6 +152,6 @@ if __name__ == "__main__": ] ) start_time = time.time() - server = ClothingSeg(request_data) + server = ClothingSeg(test_data) pprint(server.get_result()) print(time.time() - start_time)