diff --git a/eddymotion/estimator.py b/eddymotion/estimator.py index f12631f1..f3dad8d2 100644 --- a/eddymotion/estimator.py +++ b/eddymotion/estimator.py @@ -54,7 +54,11 @@ def fit( """ align_kwargs = align_kwargs or {} - reg_target_type = "dwi" if model.lower() not in ("b0", "s0") else "b0" + reg_target_type = ( + "dwi" + if model.lower() not in ("b0", "s0", "avg", "average", "mean") + else "b0" + ) if seed or seed == 0: np.random.seed(20210324 if seed is True else seed) @@ -99,7 +103,13 @@ def fit( moving = tmpdir / "moving.nii.gz" fixed = tmpdir / "fixed.nii.gz" _to_nifti(data_test[0], dwdata.affine, moving) - _to_nifti(predicted, dwdata.affine, fixed, clip=reg_target_type == "dwi") + _to_nifti( + predicted, + dwdata.affine, + fixed, + clip=reg_target_type == "dwi", + ) + registration = Registration( terminal_output="file", from_file=pkg_fn( @@ -174,9 +184,7 @@ def _advanced_clip( return data -def _to_nifti( - data, affine, filename, clip=True -): +def _to_nifti(data, affine, filename, clip=True): data = np.squeeze(data) if clip: data = _advanced_clip(data) diff --git a/eddymotion/model.py b/eddymotion/model.py index 053c9516..b3ebe36a 100644 --- a/eddymotion/model.py +++ b/eddymotion/model.py @@ -36,6 +36,9 @@ def init(gtab, model="DTI", **kwargs): if model.lower() in ("s0", "b0"): return TrivialB0Model(gtab=gtab, S0=kwargs.pop("S0")) + if model.lower() in ("avg", "average", "mean"): + return AverageDWModel(gtab=gtab, **kwargs) + # Generate a GradientTable object for DIPY gtab = _rasb2dipy(gtab) param = {} @@ -101,15 +104,56 @@ def predict(self, gradient, **kwargs): class AverageDWModel: """A trivial model that returns an average map.""" - __slots__ = ("_data", "_gtab") + __slots__ = ("_data", "_gtab", "_th_low", "_th_high", "_bias", "_stat") def __init__(self, gtab, **kwargs): - """Implement object initialization.""" + r""" + Implement object initialization. + + Parameters + ---------- + gtab : :obj:`~numpy.ndarray` + An :math:`N \times 4` table, where rows (*N*) are diffusion gradients and + columns are b-vector components and corresponding b-value, respectively. + th_low : :obj:`~numbers.Number` + A lower bound for the b-value corresponding to the diffusion weighted images + that will be averaged. + th_high : :obj:`~numbers.Number` + An upper bound for the b-value corresponding to the diffusion weighted images + that will be averaged. + bias : :obj:`bool` + Whether the overall distribution of each diffusion weighted image will be + standardized and centered around the global 75th percentile. + stat : :obj:`str` + Whether the summary statistic to apply is ``"mean"`` or ``"median"``. + + """ self._gtab = gtab + self._th_low = kwargs.get("th_low", 50) + self._th_high = kwargs.get("th_high", self._gtab[3, ...].max()) + self._bias = kwargs.get("bias", True) + self._stat = kwargs.get("stat", "median") def fit(self, data, **kwargs): """Calculate the average.""" - self._data = data[..., self._gtab[..., 3] > 50].mean(-1) + # Select the interval of b-values for which DWIs will be averaged + b_mask = (self._gtab[3, ...] >= self._th_low) & ( + self._gtab[3, ...] <= self._th_high + ) + shells = data[..., b_mask] + + # Regress out global signal differences + if self._bias: + centers = np.median(shells, axis=(0, 1, 2)) + reference = np.percentile(centers[centers >= 1.0], 75) + centers[centers < 1.0] = reference + drift = reference / centers + shells = shells * drift + + # Select the summary statistic + avg_func = np.median if self._stat == "median" else np.mean + # Calculate the average + self._data = avg_func(shells, axis=-1) def predict(self, gradient, **kwargs): """Return the average map."""