diff --git a/app/service/clothing_seg/service.py b/app/service/clothing_seg/service.py index 46cc444..6d2b4d0 100644 --- a/app/service/clothing_seg/service.py +++ b/app/service/clothing_seg/service.py @@ -62,7 +62,11 @@ class ClothingSeg: image = data["image"] clothing_result = [] if image_type == "sketch": - seg_mask = get_seg_result(1, image[:, :, :3]) + if len(image.shape) == 2: + image = np.stack([image] * 3, axis=-1) + seg_mask = get_seg_result(1, image[:, :, :3]) + else: + seg_mask = get_seg_result(1, image[:, :, :3]) temp = seg_mask != 0.0 mask = (255 * (temp + 0).astype(np.uint8)) x_min, y_min, x_max, y_max = get_bounding_box(mask)