Skip to content

Commit

Permalink
Apply suggestions from code review
Browse files Browse the repository at this point in the history
Co-authored-by: celprov <[email protected]>
  • Loading branch information
oesteban and celprov authored May 27, 2021
1 parent c44f704 commit e0b4582
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions eddymotion/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,11 +115,11 @@ def __init__(self, gtab, **kwargs):
gtab : :obj:`~numpy.ndarray`
An :math:`N \times 4` table, where rows (*N*) are diffusion gradients and
columns are b-vector components and corresponding b-value, respectively.
b_low : :obj:`~numbers.Number`
th_low : :obj:`~numbers.Number`
A lower bound for the b-value corresponding to the diffusion weighted images
that will be averaged.
b_high : :obj:`~numbers.Number`
A lower bound for the b-value corresponding to the diffusion weighted images
th_high : :obj:`~numbers.Number`
An upper bound for the b-value corresponding to the diffusion weighted images
that will be averaged.
bias : :obj:`bool`
Whether the overall distribution of each diffusion weighted image will be
Expand All @@ -129,16 +129,16 @@ def __init__(self, gtab, **kwargs):
"""
self._gtab = gtab
self._th_low = kwargs.get("b_low", 50)
self._th_high = kwargs.get("b_high", self._gtab[3, ...].max() + 1)
self._th_low = kwargs.get("th_low", 50)
self._th_high = kwargs.get("th_high", self._gtab[3, ...].max())
self._bias = kwargs.get("bias", True)
self._stat = kwargs.get("stat", "median")

def fit(self, data, **kwargs):
"""Calculate the average."""
# Select the interval of b-values for which DWIs will be averages
b_mask = (self._gtab[3, ...] > self._th_low) & (
self._gtab[3, ...] < self._th_high
# Select the interval of b-values for which DWIs will be averaged
b_mask = (self._gtab[3, ...] >= self._th_low) & (
self._gtab[3, ...] <= self._th_high
)
shells = data[..., b_mask]

Expand All @@ -151,7 +151,7 @@ def fit(self, data, **kwargs):
shells = shells * drift

# Select the summary statistic
avg_func = np.mean if self._stat == "median" else np.mean
avg_func = np.median if self._stat == "median" else np.mean
# Calculate the average
self._data = avg_func(shells, axis=-1)

Expand Down

0 comments on commit e0b4582

Please sign in to comment.