Skip to content

Commit

Permalink
Fix numpy validation tests
Browse files Browse the repository at this point in the history
  • Loading branch information
samet-akcay committed Sep 19, 2024
1 parent 8116506 commit 92b96fd
Showing 1 changed file with 21 additions and 7 deletions.
28 changes: 21 additions & 7 deletions tests/unit/data/validators/numpy/test_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,26 +132,40 @@ def test_validate_image_valid_single_channel(self) -> None:
def test_validate_gt_label_valid(self) -> None:
"""Test validation of valid ground truth labels."""
labels = np.array([0, 1])
validated_labels = self.validator.validate_gt_label(labels)
batch_size = 2
validated_labels = self.validator.validate_gt_label(labels, batch_size=batch_size)
assert isinstance(validated_labels, np.ndarray)
assert validated_labels.dtype == bool
assert np.array_equal(validated_labels, np.array([False, True]))

def test_validate_gt_label_none(self) -> None:
"""Test validation of None ground truth labels."""
assert self.validator.validate_gt_label(None) is None
assert self.validator.validate_gt_label(None, batch_size=2) is None

def test_validate_gt_label_invalid_type(self) -> None:
"""Test validation of ground truth labels with invalid type."""
with pytest.raises(TypeError, match="Ground truth label must be an integer or a numpy.ndarray"):
# Test with batch_size provided
# This test case no longer raises an error
validated_labels = self.validator.validate_gt_label(["0", "1"], batch_size=2)
assert validated_labels is not None
assert isinstance(validated_labels, np.ndarray)
assert validated_labels.dtype == bool
assert np.array_equal(validated_labels, np.array([False, True]))

# Test without batch_size
with pytest.raises(TypeError):
self.validator.validate_gt_label(["0", "1"])

def test_validate_gt_label_invalid_dimensions(self) -> None:
"""Test validation of ground truth labels with invalid dimensions."""
with pytest.raises(ValueError, match="Ground truth label must be 1-dimensional"):
self.validator.validate_gt_label(np.array([[0, 1], [1, 0]]))
with pytest.raises(ValueError, match="Ground truth label batch must be 1-dimensional, got shape \\(2, 2\\)"):
self.validator.validate_gt_label(np.array([[0, 1], [1, 0]]), batch_size=2)

def test_validate_gt_label_invalid_dtype(self) -> None:
"""Test validation of ground truth labels with invalid dtype."""
with pytest.raises(TypeError, match="Ground truth label must be boolean or integer"):
self.validator.validate_gt_label(np.array([0.5, 1.5]))
# Test that float labels are converted to boolean
labels = np.array([0.5, 1.5])
validated_labels = self.validator.validate_gt_label(labels, batch_size=2)
assert isinstance(validated_labels, np.ndarray)
assert validated_labels.dtype == bool
assert np.array_equal(validated_labels, np.array([True, True]))

0 comments on commit 92b96fd

Please sign in to comment.