Skip to content

Commit

Permalink
Merge pull request #84 from nipreps/fix/revise-tests
Browse files Browse the repository at this point in the history
FIX: Miscelaneous revisions to make tests execute properly
  • Loading branch information
arokem authored Aug 22, 2022
2 parents 96627b2 + c896c68 commit b0ffc33
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 47 deletions.
13 changes: 7 additions & 6 deletions src/eddymotion/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ def predict(self, gradient, **kwargs):
class AverageDWModel:
"""A trivial model that returns an average map."""

__slots__ = ("_data", "_gtab", "_th_low", "_th_high", "_bias", "_stat")
__slots__ = ("_data", "_th_low", "_th_high", "_bias", "_stat")

def __init__(self, gtab, **kwargs):
r"""
Expand All @@ -264,18 +264,19 @@ def __init__(self, gtab, **kwargs):
Whether the summary statistic to apply is ``"mean"`` or ``"median"``.
"""
self._gtab = gtab
self._th_low = kwargs.get("th_low", 50)
self._th_high = kwargs.get("th_high", self._gtab[3, ...].max())
self._th_high = kwargs.get("th_high", 10000)
self._bias = kwargs.get("bias", True)
self._stat = kwargs.get("stat", "median")

def fit(self, data, **kwargs):
"""Calculate the average."""
gtab = kwargs.pop("gtab", None)
# Select the interval of b-values for which DWIs will be averaged
b_mask = (self._gtab[3, ...] >= self._th_low) & (
self._gtab[3, ...] <= self._th_high
)
b_mask = (
(gtab[3] >= self._th_low)
& (gtab[3] <= self._th_high)
) if gtab is not None else np.ones((data.shape[-1], ), dtype=bool)
shells = data[..., b_mask]

# Regress out global signal differences
Expand Down
11 changes: 2 additions & 9 deletions test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,14 @@
import numpy
import pytest

test_data_env = os.getenv("TEST_DATA_HOME", str(Path.home() / "eddy-tests"))
test_data_env = os.getenv("TEST_DATA_HOME", str(Path.home() / "eddymotion-tests"))
test_output_dir = os.getenv("TEST_OUTPUT_DIR")
test_workdir = os.getenv("TEST_WORK_DIR")
data_dir = Path(__file__).parent / "tests" / "data"


def pytest_report_header(config):
return f"""\
TEST_DATA_HOME={test_data_env}s
TEST_DATA_HOME={test_data_env}.
TEST_OUTPUT_DIR={test_output_dir or '<unset> (output files will be discarded)'}.
TEST_WORK_DIR={test_workdir or '<unset> (intermediate files will be discarded)'}.
"""
Expand All @@ -61,9 +60,3 @@ def outdir():
def datadir():
"""Return a data path outside the package's structure (i.e., large datasets)."""
return Path(test_data_env)


@pytest.fixture
def pkg_datadir():
"""Return a data path inside the package's structure (i.e., small, empty files)."""
return data_dir
4 changes: 2 additions & 2 deletions test/test_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,14 @@
@pytest.mark.parametrize("t_x", [0.0, 1.0])
@pytest.mark.parametrize("t_y", [0.0, 1.0])
@pytest.mark.parametrize("t_z", [0.0, 1.0])
def test_ANTs_config_b0(pkg_datadir, tmp_path, r_x, r_y, r_z, t_x, t_y, t_z):
def test_ANTs_config_b0(datadir, tmp_path, r_x, r_y, r_z, t_x, t_y, t_z):
"""Check that the registration parameters for b=0
gives a good estimate of known affine"""

fixed = tmp_path / "b0.nii.gz"
moving = tmp_path / "moving.nii.gz"

dwdata = DWI.from_filename(pkg_datadir / "dwi.h5")
dwdata = DWI.from_filename(datadir / "dwi.h5")
b0nii = nb.Nifti1Image(dwdata.bzero, dwdata.affine, None)
b0nii.to_filename(fixed)

Expand Down
6 changes: 3 additions & 3 deletions test/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,14 @@
from eddymotion.estimator import EddyMotionEstimator


def test_proximity_estimator_trivial_model(pkg_datadir, tmp_path):
def test_proximity_estimator_trivial_model(datadir, tmp_path):
"""Check the proximity of transforms estimated by the estimator with a trivial B0 model."""

dwdata = DWI.from_filename(pkg_datadir / "dwi.h5")
dwdata = DWI.from_filename(datadir / "dwi.h5")
b0nii = nb.Nifti1Image(dwdata.bzero, dwdata.affine, None)

xfms = nt.linear.load(
pkg_datadir / "head-motion-parameters.aff12.1D",
datadir / "head-motion-parameters.aff12.1D",
fmt="afni",
)
xfms.reference = b0nii
Expand Down
43 changes: 22 additions & 21 deletions test/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,50 +47,51 @@ def test_trivial_model():
def test_average_model():
"""Check the implementation of the average DW model."""

data = np.ones((100, 100, 100, 6), dtype=float)

gtab = np.array(
[
[0, 0, 0, 0],
[-0.31, 0.933, 0.785, 25],
[0.25, 0.565, 0.21, 500],
[-0.861, -0.464, 0.564, 1000],
[0.307, -0.766, 0.677, 1000],
[0.736, 0.013, 0.774, 2000],
[0.736, 0.013, 0.774, 1300],
]
)

gtab_w25 = gtab[1:, :]
gtab_1000 = gtab[2:3, :]
gtab_2000 = gtab[2:, :]
data *= gtab[:, -1]

tmodel_mean = model.AverageDWModel(gtab=gtab, bias=False, stat="mean")
tmodel_median = model.AverageDWModel(gtab=gtab, bias=False, stat="median")
tmodel_1000 = model.AverageDWModel(
gtab=gtab, bias=False, th_high=1000, th_low=1000
gtab=gtab, bias=False, th_high=1000, th_low=900
)
tmodel_2000 = model.AverageDWModel(
gtab=gtab, bias=False, th_high=2000, th_low=1000
gtab=gtab, bias=False, th_high=2000, th_low=900, stat="mean",
)

# Verify that fit function returns nothing
assert tmodel_mean.fit() is None
assert tmodel_mean.fit(data[..., 1:], gtab=gtab[1:].T) is None

tmodel_median.fit()
tmodel_1000.fit()
tmodel_2000.fit()
tmodel_median.fit(data[..., 1:], gtab=gtab[1:].T)
tmodel_1000.fit(data[..., 1:], gtab=gtab[1:].T)
tmodel_2000.fit(data[..., 1:], gtab=gtab[1:].T)

# Verify that the right statistics is applied and that the model discard b-values < 50
assert np.all(tmodel_mean.predict() == np.mean(gtab_w25[:, :2], axis=0))
assert np.all(tmodel_median.predict() == np.median(gtab_w25[:, :2], axis=0))
assert np.all(tmodel_mean.predict([0, 0, 0]) == 950)
assert np.all(tmodel_median.predict([0, 0, 0]) == 1000)

# Verify that the threshold for b-value selection works as expected
assert np.all(tmodel_1000.predict() == np.median(gtab_1000[:, :2], axis=0))
assert np.all(tmodel_2000.predict() == np.median(gtab_2000[:, :2], axis=0))
assert np.all(tmodel_1000.predict([0, 0, 0]) == 1000)
assert np.all(tmodel_2000.predict([0, 0, 0]) == 1100)


def test_two_initialisations(pkg_datadir):
def test_two_initialisations(datadir):
"""Check that the two different initialisations result in the same models"""

# Load test data
dmri_dataset = DWI.from_filename(pkg_datadir / "dwi.h5")
dmri_dataset = DWI.from_filename(datadir / "dwi.h5")

# Split data into test and train set
data_train, data_test = dmri_dataset.logo_split(10)
Expand All @@ -101,10 +102,10 @@ def test_two_initialisations(pkg_datadir):
S0=dmri_dataset.bzero,
th_low=100,
th_high=1000,
bias=True,
bias=False,
stat="mean",
)
model1.fit(data_train[0])
model1.fit(data_train[0], gtab=data_train[1])
predicted1 = model1.predict(data_test[1])

# Initialisation via ModelFactory
Expand All @@ -114,10 +115,10 @@ def test_two_initialisations(pkg_datadir):
S0=dmri_dataset.bzero,
th_low=100,
th_high=1000,
bias=True,
bias=False,
stat="mean",
)
model2.fit(data_train[0])
model2.fit(data_train[0], gtab=data_train[1])
predicted2 = model2.predict(data_test[1])

assert predicted1 == predicted2
assert np.all(predicted1 == predicted2)
6 changes: 0 additions & 6 deletions test/test_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
"""Test _version.py."""
import sys
from importlib import reload
from pkg_resources import get_distribution

import eddymotion

Expand All @@ -37,8 +36,3 @@ class _version:
monkeypatch.setitem(sys.modules, "eddymotion._version", _version)
reload(eddymotion)
assert eddymotion.__version__ == "10.0.0"


def test_version_scm1():
"""Retrieve the version via pkg_resources."""
assert get_distribution("eddymotion").version != "0.0"

0 comments on commit b0ffc33

Please sign in to comment.