From 270af4b78243e8b84b24145ebe0f72cd3ffc1913 Mon Sep 17 00:00:00 2001 From: Oscar Esteban Date: Tue, 11 May 2021 18:03:06 +0200 Subject: [PATCH 1/3] ENH: Sophistication of ``AverageDWModel`` This PR completes #50 in the implementation of the ideas stated within #41. The PR also enables the new ``AverageDWModel`` to be actually used by the estimator. References: #41. Follows-up: #50. --- eddymotion/model.py | 50 ++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 47 insertions(+), 3 deletions(-) diff --git a/eddymotion/model.py b/eddymotion/model.py index 053c9516..9cfdbbe8 100644 --- a/eddymotion/model.py +++ b/eddymotion/model.py @@ -36,6 +36,9 @@ def init(gtab, model="DTI", **kwargs): if model.lower() in ("s0", "b0"): return TrivialB0Model(gtab=gtab, S0=kwargs.pop("S0")) + if model.lower() in ("avg", "average", "mean"): + return AverageDWModel(gtab=gtab, **kwargs) + # Generate a GradientTable object for DIPY gtab = _rasb2dipy(gtab) param = {} @@ -101,15 +104,56 @@ def predict(self, gradient, **kwargs): class AverageDWModel: """A trivial model that returns an average map.""" - __slots__ = ("_data", "_gtab") + __slots__ = ("_data", "_gtab", "_th_low", "_th_high", "_bias", "_stat") def __init__(self, gtab, **kwargs): - """Implement object initialization.""" + r""" + Implement object initialization. + + Parameters + ---------- + gtab : :obj:`~numpy.ndarray` + An :math:`N \times 4` table, where rows (*N*) are diffusion gradients and + columns are b-vector components and corresponding b-value, respectively. + b_low : :obj:`~numbers.Number` + A lower bound for the b-value corresponding to the diffusion weighted images + that will be averaged. + b_high : :obj:`~numbers.Number` + A lower bound for the b-value corresponding to the diffusion weighted images + that will be averaged. + bias : :obj:`bool` + Whether the overall distribution of each diffusion weighted image will be + standardized and centered around the global 75th percentile. + stat : :obj:`str` + Whether the summary statistic to apply is ``"mean"`` or ``"median"``. + + """ self._gtab = gtab + self._th_low = kwargs.get("b_low", 50) + self._th_high = kwargs.get("b_high", self._gtab[3, ...].max() + 1) + self._bias = kwargs.get("bias", True) + self._stat = kwargs.get("stat", "median") def fit(self, data, **kwargs): """Calculate the average.""" - self._data = data[..., self._gtab[..., 3] > 50].mean(-1) + # Select the interval of b-values for which DWIs will be averages + b_mask = (self._gtab[3, ...] > self._th_low) & ( + self._gtab[3, ...] < self._th_high + ) + shells = data[..., b_mask] + + # Regress out global signal differences + if self._bias: + centers = np.median(shells, axis=(0, 1, 2)) + reference = np.percentile(centers[centers >= 1.0], 75) + centers[centers < 1.0] = reference + drift = reference / centers + shells = shells * drift + + # Select the summary statistic + avg_func = np.mean if self._stat == "median" else np.mean + # Calculate the average + self._data = avg_func(shells, axis=-1) def predict(self, gradient, **kwargs): """Return the average map.""" From e30e9f43aa1c0c6faa17d1b0649822d332a55ab0 Mon Sep 17 00:00:00 2001 From: Oscar Esteban Date: Tue, 11 May 2021 18:17:24 +0200 Subject: [PATCH 2/3] enh+sty: use faster registration params from b0 and run black --- eddymotion/estimator.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/eddymotion/estimator.py b/eddymotion/estimator.py index f12631f1..f3dad8d2 100644 --- a/eddymotion/estimator.py +++ b/eddymotion/estimator.py @@ -54,7 +54,11 @@ def fit( """ align_kwargs = align_kwargs or {} - reg_target_type = "dwi" if model.lower() not in ("b0", "s0") else "b0" + reg_target_type = ( + "dwi" + if model.lower() not in ("b0", "s0", "avg", "average", "mean") + else "b0" + ) if seed or seed == 0: np.random.seed(20210324 if seed is True else seed) @@ -99,7 +103,13 @@ def fit( moving = tmpdir / "moving.nii.gz" fixed = tmpdir / "fixed.nii.gz" _to_nifti(data_test[0], dwdata.affine, moving) - _to_nifti(predicted, dwdata.affine, fixed, clip=reg_target_type == "dwi") + _to_nifti( + predicted, + dwdata.affine, + fixed, + clip=reg_target_type == "dwi", + ) + registration = Registration( terminal_output="file", from_file=pkg_fn( @@ -174,9 +184,7 @@ def _advanced_clip( return data -def _to_nifti( - data, affine, filename, clip=True -): +def _to_nifti(data, affine, filename, clip=True): data = np.squeeze(data) if clip: data = _advanced_clip(data) From 0e43407a920d305ac4cb744f1a98352031f04da0 Mon Sep 17 00:00:00 2001 From: Oscar Esteban Date: Thu, 27 May 2021 15:25:53 +0200 Subject: [PATCH 3/3] Apply suggestions from code review Co-authored-by: celprov <77437752+celprov@users.noreply.github.com> --- eddymotion/model.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/eddymotion/model.py b/eddymotion/model.py index 9cfdbbe8..b3ebe36a 100644 --- a/eddymotion/model.py +++ b/eddymotion/model.py @@ -115,11 +115,11 @@ def __init__(self, gtab, **kwargs): gtab : :obj:`~numpy.ndarray` An :math:`N \times 4` table, where rows (*N*) are diffusion gradients and columns are b-vector components and corresponding b-value, respectively. - b_low : :obj:`~numbers.Number` + th_low : :obj:`~numbers.Number` A lower bound for the b-value corresponding to the diffusion weighted images that will be averaged. - b_high : :obj:`~numbers.Number` - A lower bound for the b-value corresponding to the diffusion weighted images + th_high : :obj:`~numbers.Number` + An upper bound for the b-value corresponding to the diffusion weighted images that will be averaged. bias : :obj:`bool` Whether the overall distribution of each diffusion weighted image will be @@ -129,16 +129,16 @@ def __init__(self, gtab, **kwargs): """ self._gtab = gtab - self._th_low = kwargs.get("b_low", 50) - self._th_high = kwargs.get("b_high", self._gtab[3, ...].max() + 1) + self._th_low = kwargs.get("th_low", 50) + self._th_high = kwargs.get("th_high", self._gtab[3, ...].max()) self._bias = kwargs.get("bias", True) self._stat = kwargs.get("stat", "median") def fit(self, data, **kwargs): """Calculate the average.""" - # Select the interval of b-values for which DWIs will be averages - b_mask = (self._gtab[3, ...] > self._th_low) & ( - self._gtab[3, ...] < self._th_high + # Select the interval of b-values for which DWIs will be averaged + b_mask = (self._gtab[3, ...] >= self._th_low) & ( + self._gtab[3, ...] <= self._th_high ) shells = data[..., b_mask] @@ -151,7 +151,7 @@ def fit(self, data, **kwargs): shells = shells * drift # Select the summary statistic - avg_func = np.mean if self._stat == "median" else np.mean + avg_func = np.median if self._stat == "median" else np.mean # Calculate the average self._data = avg_func(shells, axis=-1)