diff --git a/nwbinspector/checks/time_series.py b/nwbinspector/checks/time_series.py index 3e99cb71..0d44f38a 100644 --- a/nwbinspector/checks/time_series.py +++ b/nwbinspector/checks/time_series.py @@ -92,3 +92,35 @@ def check_resolution(time_series: TimeSeries): return InspectorMessage( message=f"'resolution' should use -1.0 or NaN for unknown instead of {time_series.resolution}." ) + + +@register_check(importance=Importance.BEST_PRACTICE_SUGGESTION, neurodata_type=TimeSeries) +def check_rows_not_nan(time_series: TimeSeries, nelems=200): + """Check that each row of a TimeSeries has at least one non-NaN piece of data.""" + n_dims = len(time_series.data.shape) + if n_dims > 2: + yield + + spanning_by = np.ceil(time_series.shape[0] / nelems).astype(int) if nelems else None + if n_dims == 1: + if nelems is not None and not all(np.isnan(time_series.data[:nelems])): + yield + if all(np.isnan(time_series[slice(0, None, spanning_by)]).flatten()): + yield InspectorMessage( + message=( + "This TimeSeries appears to contain NaN data at each frame. " + "Consider removing this object from the file." + ) + ) + elif n_dims == 2: + for col in range(time_series.data.shape[1]): + if nelems is not None and not all(np.isnan(time_series.data[:nelems, col]).flatten()): + continue + + if all(np.isnan(time_series[slice(0, None, spanning_by), col]).flatten()): + yield InspectorMessage( + message=( + f"Column index {col} of this TimeSeries appears to contain NaN data at each frame. " + "Consider removing this column from the TimeSeries." + ) + ) diff --git a/tests/unit_tests/test_time_series.py b/tests/unit_tests/test_time_series.py index 3428d566..6caf3263 100644 --- a/tests/unit_tests/test_time_series.py +++ b/tests/unit_tests/test_time_series.py @@ -4,6 +4,7 @@ import numpy as np import pynwb import pytest +from hdmf.testing import TestCase from nwbinspector import ( InspectorMessage, @@ -14,6 +15,7 @@ check_timestamps_ascending, check_missing_unit, check_resolution, + check_rows_not_nan, ) from nwbinspector.utils import get_package_version, robust_s3_read @@ -214,3 +216,75 @@ def test_check_resolution_fail(): object_name="test", location="/", ) + + +def test_check_rows_not_nan_1d_pass(): + data = np.zeros(shape=400) + data[0] = np.nan + time_series = pynwb.TimeSeries(name="test_time_series", unit="", data=data, rate=1.0) + assert check_rows_not_nan(time_series) is None + + +def test_check_rows_not_nan_1d_fail(): + data = np.zeros(shape=400) + data[:] = np.nan + time_series = pynwb.TimeSeries(name="test", unit="test", data=data, rate=1.0) + assert check_rows_not_nan(time_series) == [ + InspectorMessage( + message=( + "This TimeSeries appears to contain NaN data at each frame. " + "Consider removing this object from the file." + ), + importance=Importance.BEST_PRACTICE_SUGGESTION, + check_function_name="check_rows_not_nan", + object_type="TimeSeries", + object_name="test", + location="/", + ) + ] + + +def test_check_rows_not_nan_2d_pass(): + data = np.zeros(shape=(2, 3)) + data[0, 0] = np.nan + time_series = pynwb.TimeSeries(name="test_time_series", unit="", data=data, rate=1.0) + assert check_rows_not_nan(time_series) is None + + +def test_check_rows_not_nan_2d_fail(): + data = np.zeros(shape=(2, 5)) + data[:, [1, 4]] = np.nan + time_series = pynwb.TimeSeries(name="test", unit="test", data=data, rate=1.0) + assert check_rows_not_nan(time_series) == [ + InspectorMessage( + message=( + "Column index 1 of this TimeSeries appears to contain NaN data at each frame. " + "Consider removing this column from the TimeSeries." + ), + importance=Importance.BEST_PRACTICE_SUGGESTION, + check_function_name="check_rows_not_nan", + object_type="TimeSeries", + object_name="test", + location="/", + file_path=None, + ), + InspectorMessage( + message=( + "Column index 4 of this TimeSeries appears to contain NaN data at each frame. " + "Consider removing this column from the TimeSeries." + ), + importance=Importance.BEST_PRACTICE_SUGGESTION, + check_function_name="check_rows_not_nan", + object_type="TimeSeries", + object_name="test", + location="/", + file_path=None, + ), + ] + + +def test_check_rows_not_nan_higher_dim_skip(): + data = np.empty(shape=(2, 3, 4)) + data[:, 0, 0] = np.nan + time_series = pynwb.TimeSeries(name="test", unit="test", data=data, rate=1.0) + assert check_rows_not_nan(time_series) is None