Skip to content

Commit

Permalink
Merge pull request nipreps#51 from oesteban/enh/avg-model-improvements
Browse files Browse the repository at this point in the history
ENH: Sophistication of ``AverageDWModel``
  • Loading branch information
oesteban authored May 27, 2021
2 parents 71e33c1 + e0b4582 commit 096d14a
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 8 deletions.
18 changes: 13 additions & 5 deletions eddymotion/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,11 @@ def fit(
"""
align_kwargs = align_kwargs or {}
reg_target_type = "dwi" if model.lower() not in ("b0", "s0") else "b0"
reg_target_type = (
"dwi"
if model.lower() not in ("b0", "s0", "avg", "average", "mean")
else "b0"
)

if seed or seed == 0:
np.random.seed(20210324 if seed is True else seed)
Expand Down Expand Up @@ -99,7 +103,13 @@ def fit(
moving = tmpdir / "moving.nii.gz"
fixed = tmpdir / "fixed.nii.gz"
_to_nifti(data_test[0], dwdata.affine, moving)
_to_nifti(predicted, dwdata.affine, fixed, clip=reg_target_type == "dwi")
_to_nifti(
predicted,
dwdata.affine,
fixed,
clip=reg_target_type == "dwi",
)

registration = Registration(
terminal_output="file",
from_file=pkg_fn(
Expand Down Expand Up @@ -174,9 +184,7 @@ def _advanced_clip(
return data


def _to_nifti(
data, affine, filename, clip=True
):
def _to_nifti(data, affine, filename, clip=True):
data = np.squeeze(data)
if clip:
data = _advanced_clip(data)
Expand Down
50 changes: 47 additions & 3 deletions eddymotion/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ def init(gtab, model="DTI", **kwargs):
if model.lower() in ("s0", "b0"):
return TrivialB0Model(gtab=gtab, S0=kwargs.pop("S0"))

if model.lower() in ("avg", "average", "mean"):
return AverageDWModel(gtab=gtab, **kwargs)

# Generate a GradientTable object for DIPY
gtab = _rasb2dipy(gtab)
param = {}
Expand Down Expand Up @@ -101,15 +104,56 @@ def predict(self, gradient, **kwargs):
class AverageDWModel:
"""A trivial model that returns an average map."""

__slots__ = ("_data", "_gtab")
__slots__ = ("_data", "_gtab", "_th_low", "_th_high", "_bias", "_stat")

def __init__(self, gtab, **kwargs):
"""Implement object initialization."""
r"""
Implement object initialization.
Parameters
----------
gtab : :obj:`~numpy.ndarray`
An :math:`N \times 4` table, where rows (*N*) are diffusion gradients and
columns are b-vector components and corresponding b-value, respectively.
th_low : :obj:`~numbers.Number`
A lower bound for the b-value corresponding to the diffusion weighted images
that will be averaged.
th_high : :obj:`~numbers.Number`
An upper bound for the b-value corresponding to the diffusion weighted images
that will be averaged.
bias : :obj:`bool`
Whether the overall distribution of each diffusion weighted image will be
standardized and centered around the global 75th percentile.
stat : :obj:`str`
Whether the summary statistic to apply is ``"mean"`` or ``"median"``.
"""
self._gtab = gtab
self._th_low = kwargs.get("th_low", 50)
self._th_high = kwargs.get("th_high", self._gtab[3, ...].max())
self._bias = kwargs.get("bias", True)
self._stat = kwargs.get("stat", "median")

def fit(self, data, **kwargs):
"""Calculate the average."""
self._data = data[..., self._gtab[..., 3] > 50].mean(-1)
# Select the interval of b-values for which DWIs will be averaged
b_mask = (self._gtab[3, ...] >= self._th_low) & (
self._gtab[3, ...] <= self._th_high
)
shells = data[..., b_mask]

# Regress out global signal differences
if self._bias:
centers = np.median(shells, axis=(0, 1, 2))
reference = np.percentile(centers[centers >= 1.0], 75)
centers[centers < 1.0] = reference
drift = reference / centers
shells = shells * drift

# Select the summary statistic
avg_func = np.median if self._stat == "median" else np.mean
# Calculate the average
self._data = avg_func(shells, axis=-1)

def predict(self, gradient, **kwargs):
"""Return the average map."""
Expand Down

0 comments on commit 096d14a

Please sign in to comment.