diff --git a/ros2_sam/ros2_sam/sam.py b/ros2_sam/ros2_sam/sam.py index d8eaa7a..802a97c 100644 --- a/ros2_sam/ros2_sam/sam.py +++ b/ros2_sam/ros2_sam/sam.py @@ -40,6 +40,9 @@ def __del__(self): def segment(self, img, points, point_labels, boxes=None, multimask=True): self._predictor.set_image(img) + if len(points) == 0: + points = None + point_labels = None return self._predictor.predict( point_coords=points, point_labels=point_labels,