Skip to content

Commit

Permalink
Add test for dataset with nans
Browse files Browse the repository at this point in the history
  • Loading branch information
sfmig committed Aug 28, 2024
1 parent ffbcb53 commit 9ba34a9
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 0 deletions.
2 changes: 2 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,8 @@ def valid_poses_dataset(valid_position_array, request):
@pytest.fixture
def valid_poses_dataset_with_nan(valid_poses_dataset):
"""Return a valid pose tracks dataset with NaN values."""
# Sets position for all keypoints in individual ind1 to NaN
# at timepoints 3, 7, 8
valid_poses_dataset.position.loc[
{"individuals": "ind1", "time": [3, 7, 8]}
] = np.nan
Expand Down
44 changes: 44 additions & 0 deletions tests/test_unit/test_kinematics.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,50 @@ def test_kinematics_uniform_linear_motion(
)


@pytest.mark.parametrize(
"valid_dataset_with_nan",
[
"valid_poses_dataset_with_nan",
"valid_bboxes_dataset_with_nan",
],
)
@pytest.mark.parametrize(
"kinematic_variable, expected_nans_per_individual",
[
("displacement", {0: 5, 1: 0}),
("velocity", {0: 6, 1: 0}),
("acceleration", {0: 7, 1: 0}),
],
)
def test_kinematics_with_dataset_with_nans(
valid_dataset_with_nan,
kinematic_variable,
expected_nans_per_individual,
helpers,
request,
):
# compute kinematic array
valid_dataset = request.getfixturevalue(valid_dataset_with_nan)
position = valid_dataset.position
kinematic_array = getattr(kinematics, f"compute_{kinematic_variable}")(
position
)

# compute n nans in kinematic array per individual
n_nans_kinematics_per_indiv = {
i: helpers.count_nans(kinematic_array.isel(individuals=i))
for i in range(valid_dataset.dims["individuals"])
}

# check number of nans per indiv is as expected in kinematic array
for i in range(valid_dataset.dims["individuals"]):
assert n_nans_kinematics_per_indiv[i] == (
expected_nans_per_individual[i]
* valid_dataset.dims["space"]
* valid_dataset.dims.get("keypoints", 1)
)


@pytest.mark.parametrize(
"invalid_dataset, expected_exception",
[
Expand Down

0 comments on commit 9ba34a9

Please sign in to comment.