-
Notifications
You must be signed in to change notification settings - Fork 16
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
FIX: Miscelaneous revisions to make tests execute properly #84
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -240,7 +240,7 @@ def predict(self, gradient, **kwargs): | |
class AverageDWModel: | ||
"""A trivial model that returns an average map.""" | ||
|
||
__slots__ = ("_data", "_gtab", "_th_low", "_th_high", "_bias", "_stat") | ||
__slots__ = ("_data", "_th_low", "_th_high", "_bias", "_stat") | ||
|
||
def __init__(self, gtab, **kwargs): | ||
r""" | ||
|
@@ -264,18 +264,19 @@ def __init__(self, gtab, **kwargs): | |
Whether the summary statistic to apply is ``"mean"`` or ``"median"``. | ||
|
||
""" | ||
self._gtab = gtab | ||
self._th_low = kwargs.get("th_low", 50) | ||
self._th_high = kwargs.get("th_high", self._gtab[3, ...].max()) | ||
self._th_high = kwargs.get("th_high", 10000) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I know this is just a toy model, but we might want to document the default values for these, in case they get inherited into other places. |
||
self._bias = kwargs.get("bias", True) | ||
self._stat = kwargs.get("stat", "median") | ||
|
||
def fit(self, data, **kwargs): | ||
"""Calculate the average.""" | ||
gtab = kwargs.pop("gtab", None) | ||
# Select the interval of b-values for which DWIs will be averaged | ||
b_mask = (self._gtab[3, ...] >= self._th_low) & ( | ||
self._gtab[3, ...] <= self._th_high | ||
) | ||
b_mask = ( | ||
(gtab[3] >= self._th_low) | ||
& (gtab[3] <= self._th_high) | ||
) if gtab is not None else np.ones((data.shape[-1], ), dtype=bool) | ||
shells = data[..., b_mask] | ||
|
||
# Regress out global signal differences | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -47,50 +47,51 @@ def test_trivial_model(): | |
def test_average_model(): | ||
"""Check the implementation of the average DW model.""" | ||
|
||
data = np.ones((100, 100, 100, 6), dtype=float) | ||
|
||
gtab = np.array( | ||
[ | ||
[0, 0, 0, 0], | ||
[-0.31, 0.933, 0.785, 25], | ||
[0.25, 0.565, 0.21, 500], | ||
[-0.861, -0.464, 0.564, 1000], | ||
[0.307, -0.766, 0.677, 1000], | ||
[0.736, 0.013, 0.774, 2000], | ||
[0.736, 0.013, 0.774, 1300], | ||
] | ||
) | ||
|
||
gtab_w25 = gtab[1:, :] | ||
gtab_1000 = gtab[2:3, :] | ||
gtab_2000 = gtab[2:, :] | ||
data *= gtab[:, -1] | ||
|
||
tmodel_mean = model.AverageDWModel(gtab=gtab, bias=False, stat="mean") | ||
tmodel_median = model.AverageDWModel(gtab=gtab, bias=False, stat="median") | ||
tmodel_1000 = model.AverageDWModel( | ||
gtab=gtab, bias=False, th_high=1000, th_low=1000 | ||
gtab=gtab, bias=False, th_high=1000, th_low=900 | ||
) | ||
tmodel_2000 = model.AverageDWModel( | ||
gtab=gtab, bias=False, th_high=2000, th_low=1000 | ||
gtab=gtab, bias=False, th_high=2000, th_low=900, stat="mean", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ha! I just noticed that the default statistic for the "mean" model is the median. Isn't that a bit confusing? |
||
) | ||
|
||
# Verify that fit function returns nothing | ||
assert tmodel_mean.fit() is None | ||
assert tmodel_mean.fit(data[..., 1:], gtab=gtab[1:].T) is None | ||
|
||
tmodel_median.fit() | ||
tmodel_1000.fit() | ||
tmodel_2000.fit() | ||
tmodel_median.fit(data[..., 1:], gtab=gtab[1:].T) | ||
tmodel_1000.fit(data[..., 1:], gtab=gtab[1:].T) | ||
tmodel_2000.fit(data[..., 1:], gtab=gtab[1:].T) | ||
|
||
# Verify that the right statistics is applied and that the model discard b-values < 50 | ||
assert np.all(tmodel_mean.predict() == np.mean(gtab_w25[:, :2], axis=0)) | ||
assert np.all(tmodel_median.predict() == np.median(gtab_w25[:, :2], axis=0)) | ||
assert np.all(tmodel_mean.predict([0, 0, 0]) == 950) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Isn't the input to |
||
assert np.all(tmodel_median.predict([0, 0, 0]) == 1000) | ||
|
||
# Verify that the threshold for b-value selection works as expected | ||
assert np.all(tmodel_1000.predict() == np.median(gtab_1000[:, :2], axis=0)) | ||
assert np.all(tmodel_2000.predict() == np.median(gtab_2000[:, :2], axis=0)) | ||
assert np.all(tmodel_1000.predict([0, 0, 0]) == 1000) | ||
assert np.all(tmodel_2000.predict([0, 0, 0]) == 1100) | ||
|
||
|
||
def test_two_initialisations(pkg_datadir): | ||
def test_two_initialisations(datadir): | ||
"""Check that the two different initialisations result in the same models""" | ||
|
||
# Load test data | ||
dmri_dataset = DWI.from_filename(pkg_datadir / "dwi.h5") | ||
dmri_dataset = DWI.from_filename(datadir / "dwi.h5") | ||
|
||
# Split data into test and train set | ||
data_train, data_test = dmri_dataset.logo_split(10) | ||
|
@@ -101,10 +102,10 @@ def test_two_initialisations(pkg_datadir): | |
S0=dmri_dataset.bzero, | ||
th_low=100, | ||
th_high=1000, | ||
bias=True, | ||
bias=False, | ||
stat="mean", | ||
) | ||
model1.fit(data_train[0]) | ||
model1.fit(data_train[0], gtab=data_train[1]) | ||
predicted1 = model1.predict(data_test[1]) | ||
|
||
# Initialisation via ModelFactory | ||
|
@@ -114,10 +115,10 @@ def test_two_initialisations(pkg_datadir): | |
S0=dmri_dataset.bzero, | ||
th_low=100, | ||
th_high=1000, | ||
bias=True, | ||
bias=False, | ||
stat="mean", | ||
) | ||
model2.fit(data_train[0]) | ||
model2.fit(data_train[0], gtab=data_train[1]) | ||
predicted2 = model2.predict(data_test[1]) | ||
|
||
assert predicted1 == predicted2 | ||
assert np.all(predicted1 == predicted2) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you want to also remove the
gtab
input here?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If yes, don't forget to remove it in the docstring as well.