diff --git a/scripts/dwi_gp_estimation_analysis_plot.py b/scripts/dwi_gp_estimation_analysis_plot.py new file mode 100644 index 00000000..9c8400c9 --- /dev/null +++ b/scripts/dwi_gp_estimation_analysis_plot.py @@ -0,0 +1,160 @@ +# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- +# vi: set ft=python sts=4 ts=4 sw=4 et: +# +# Copyright The NiPreps Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# We support and encourage derived works from this project, please read +# about our expectations at +# +# https://www.nipreps.org/community/licensing/ +# + +""" +Plot the RMSE (mean and std dev) and prediction surface from the predicted DWI +signal estimated using Gaussian processes k-fold cross-validation. +""" + +from __future__ import annotations + +import argparse +from pathlib import Path + +import matplotlib.pyplot as plt +import nibabel as nib +import numpy as np +import pandas as pd +from dipy.core.gradients import gradient_table +from dipy.io import read_bvals_bvecs + +from eddymotion.viz.signals import plot_error, plot_prediction_surface + + +def _build_arg_parser() -> argparse.ArgumentParser: + """ + Build argument parser for command-line interface. + + Returns + ------- + :obj:`~argparse.ArgumentParser` + Argument parser for the script. 
+ + """ + parser = argparse.ArgumentParser( + description=__doc__, formatter_class=argparse.RawTextHelpFormatter + ) + parser.add_argument( + "error_data_fname", + help="Filename of TSV file containing the error data to plot", + type=Path, + ) + parser.add_argument( + "dwi_gt_data_fname", + help="Filename of NIfTI file containing the ground truth DWI signal", + type=Path, + ) + parser.add_argument( + "bval_data_fname", + help="Filename of b-val file containing the diffusion-encoding gradient b-vals", + type=Path, + ) + parser.add_argument( + "bvec_data_fname", + help="Filename of b-vecs file containing the diffusion-encoding gradient b-vecs", + type=Path, + ) + parser.add_argument( + "dwi_pred_data_fname", + help="Filename of NIfTI file containing the predicted DWI signal", + type=Path, + ) + parser.add_argument( + "error_plot_fname", + help="Filename of SVG file where the error plot will be saved", + type=Path, + ) + parser.add_argument( + "signal_surface_plot_fname", + help="Filename of SVG file where the predicted signal plot will be saved", + type=Path, + ) + return parser + + +def _parse_args(parser: argparse.ArgumentParser) -> argparse.Namespace: + """ + Parse command-line arguments. + + Parameters + ---------- + parser : :obj:`~argparse.ArgumentParser` + Argument parser for the script. + + Returns + ------- + :obj:`~argparse.Namespace` + Parsed arguments. 
+ """ + return parser.parse_args() + + +def main() -> None: + """Main function for running the experiment and plotting the results.""" + parser = _build_arg_parser() + args = _parse_args(parser) + + df = pd.read_csv(args.error_data_fname, sep="\t", keep_default_na=False, na_values="n/a") + + # Plot the prediction error + kfolds = sorted(np.unique(df["n_folds"].values)) + snr = np.unique(df["snr"].values).item() + rmse_data = [df.groupby("n_folds").get_group(k)["rmse"].values for k in kfolds] + axis = 1 + mean = np.mean(rmse_data, axis=axis) + std_dev = np.std(rmse_data, axis=axis) + xlabel = "k" + ylabel = "RMSE" + title = f"Gaussian process estimation\n(SNR={snr})" + fig = plot_error(kfolds, mean, std_dev, xlabel, ylabel, title) + fig.savefig(args.error_plot_fname) + plt.close(fig) + + # Plot the predicted DWI signal at a single voxel + + # Load the dMRI data + signal = nib.load(args.dwi_gt_data_fname).get_fdata() + y_pred = nib.load(args.dwi_pred_data_fname).get_fdata() + + bvals, bvecs = read_bvals_bvecs(str(args.bval_data_fname), str(args.bvec_data_fname)) + gtab = gradient_table(bvals, bvecs) + + # Pick one voxel randomly + rng = np.random.default_rng(1234) + idx = rng.integers(0, signal.shape[0], size=1).item() + + title = "GP model signal prediction" + fig, _, _ = plot_prediction_surface( + signal[idx, ~gtab.b0s_mask], + y_pred[idx], + signal[idx, gtab.b0s_mask].item(), + gtab[~gtab.b0s_mask].bvecs, + gtab[~gtab.b0s_mask].bvecs, + title, + "gray", + ) + fig.savefig(args.signal_surface_plot_fname, format="svg") + + +if __name__ == "__main__": + main() diff --git a/scripts/dwi_estimation_error_analysis.py b/scripts/dwi_gp_estimation_error_analysis.py similarity index 65% rename from scripts/dwi_estimation_error_analysis.py rename to scripts/dwi_gp_estimation_error_analysis.py index d6d739e1..2d3dd483 100644 --- a/scripts/dwi_estimation_error_analysis.py +++ b/scripts/dwi_gp_estimation_error_analysis.py @@ -30,11 +30,11 @@ import argparse from collections 
import defaultdict +from pathlib import Path -# import nibabel as nib import numpy as np import pandas as pd -from sklearn.model_selection import RepeatedKFold, cross_val_score +from sklearn.model_selection import KFold, RepeatedKFold, cross_val_predict, cross_val_score from eddymotion.model._sklearn import ( EddyMotionGPR, @@ -47,24 +47,21 @@ def cross_validate( X: np.ndarray, y: np.ndarray, cv: int, + gpr: EddyMotionGPR, ) -> dict[int, list[tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]]]: """ Perform the experiment by estimating the dMRI signal using a Gaussian process model. Parameters ---------- - gtab : :obj:`~dipy.core.gradients.gradient_table` - Gradient table. - S0 : :obj:`float` - S0 value. - evals1 : :obj:`~numpy.ndarray` - Eigenvalues of the tensor. - evecs : :obj:`~numpy.ndarray` - Eigenvectors of the tensor. - snr : :obj:`float` - Signal-to-noise ratio. + X : :obj:`~numpy.ndarray` + Diffusion-encoding gradient vectors. + y : :obj:`~numpy.ndarray` + DWI signal. cv : :obj:`int` number of folds + gpr : :obj:`~eddymotion.model._sklearn.EddyMotionGPR` + The eddymotion Gaussian process regressor object. Returns ------- @@ -72,14 +69,9 @@ def cross_validate( Data for the predicted signal and its error. 
""" - gpm = EddyMotionGPR( - kernel=SphericalKriging(a=1.15, lambda_s=120), - alpha=100, - optimizer=None, - ) rkf = RepeatedKFold(n_splits=cv, n_repeats=120 // cv) - scores = cross_val_score(gpm, X, y, scoring="neg_root_mean_squared_error", cv=rkf) + scores = cross_val_score(gpr, X, y, scoring="neg_root_mean_squared_error", cv=rkf) return scores @@ -103,7 +95,32 @@ def _build_arg_parser() -> argparse.ArgumentParser: ) parser.add_argument("bval_shell", help="Shell b-value", type=float) parser.add_argument("S0", help="S0 value", type=float) - parser.add_argument("--evals1", help="Eigenvalues of the tensor", nargs="+", type=float) + parser.add_argument( + "error_data_fname", + help="Filename of TSV file containing the data to plot", + type=Path, + ) + parser.add_argument( + "dwi_gt_data_fname", + help="Filename of NIfTI file containing the generated DWI signal", + type=Path, + ) + parser.add_argument( + "bval_data_fname", + help="Filename of b-val file containing the diffusion-encoding gradient b-vals", + type=Path, + ) + parser.add_argument( + "bvec_data_fname", + help="Filename of b-vecs file containing the diffusion-encoding gradient b-vecs", + type=Path, + ) + parser.add_argument( + "dwi_pred_data_fname", + help="Filename of NIfTI file containing the predicted DWI signal", + type=Path, + ) + parser.add_argument("--evals", help="Eigenvalues of the tensor", nargs="+", type=float) parser.add_argument("--snr", help="Signal to noise ratio", type=float) parser.add_argument("--repeats", help="Number of repeats", type=int, default=5) parser.add_argument( @@ -134,37 +151,60 @@ def main() -> None: parser = _build_arg_parser() args = _parse_args(parser) + n_voxels = 100 + data, gtab = testsims.simulate_voxels( args.S0, - args.evals1, args.hsph_dirs, bval_shell=args.bval_shell, snr=args.snr, - n_voxels=100, + n_voxels=n_voxels, + evals=args.evals, seed=None, ) + # Save the generated signal and gradient table + testsims.serialize_dmri( + data, gtab, args.dwi_gt_data_fname, 
args.bval_data_fname, args.bvec_data_fname + ) + X = gtab[~gtab.b0s_mask].bvecs y = data[:, ~gtab.b0s_mask] + snr_str = args.snr if args.snr is not None else "None" + + a = 1.15 + lambda_s = 120 + alpha = 100 + gpr = EddyMotionGPR( + kernel=SphericalKriging(a=a, lambda_s=lambda_s), + alpha=alpha, + optimizer=None, + ) + # Use Scikit-learn cross validation scores = defaultdict(list, {}) for n in args.kfold: for i in range(args.repeats): - cv_scores = -1.0 * cross_validate(X, y.T, n) + cv_scores = -1.0 * cross_validate(X, y.T, n, gpr) scores["rmse"] += cv_scores.tolist() scores["repeat"] += [i] * len(cv_scores) scores["n_folds"] += [n] * len(cv_scores) + scores["snr"] += [snr_str] * len(cv_scores) print(f"Finished {n}-fold cross-validation") scores_df = pd.DataFrame(scores) - scores_df.to_csv("cv_scores.tsv", sep="\t", index=None, na_rep="n/a") + scores_df.to_csv(args.error_data_fname, sep="\t", index=None, na_rep="n/a") grouped = scores_df.groupby(["n_folds"]) print(grouped[["rmse"]].mean()) print(grouped[["rmse"]].std()) + cv = KFold(n_splits=3, shuffle=False, random_state=None) + predictions = cross_val_predict(gpr, X, y.T, cv=cv) + testsims.serialize_dwi(predictions.T, args.dwi_pred_data_fname) + if __name__ == "__main__": main() diff --git a/src/eddymotion/testing/simulations.py b/src/eddymotion/testing/simulations.py index b849d3bb..98fd8d20 100644 --- a/src/eddymotion/testing/simulations.py +++ b/src/eddymotion/testing/simulations.py @@ -24,13 +24,17 @@ from __future__ import annotations -# import nibabel as nib +import nibabel as nib import numpy as np from dipy.core.geometry import sphere2cart from dipy.core.gradients import gradient_table from dipy.core.sphere import HemiSphere, Sphere, disperse_charges from dipy.sims.voxel import all_tensor_evecs, single_tensor +# Bounds defined following Canales-Rodriguez, NIMG 184 2019, https://doi.org/10.1016/j.neuroimage.2018.08.071 +BOUNDS_LAMBDA1: tuple[float, float] = (1.4e-3, 1.8e-3) +BOUNDS_LAMBDA23: tuple[float, 
float] = (0.1e-3, 0.5e-3) + def add_b0(bvals: np.ndarray, bvecs: np.ndarray) -> tuple[np.ndarray, np.ndarray]: """ @@ -197,31 +201,138 @@ def get_query_vectors( return gtab.bvecs[idx], np.where(idx)[0] -def single_fiber_voxel(gtab, S0, evals, theta=0, phi=0, snr=20): +def single_fiber_voxel(gtab, S0, evals, rng, theta=0, phi=0, snr=20): # create eigenvectors for a single fiber evecs = create_single_fiber_evecs(theta=theta, phi=phi) # Generate some data - return single_tensor(gtab, S0=S0, evals=evals, evecs=evecs, snr=snr) + return single_tensor(gtab, S0=S0, evals=evals, evecs=evecs, snr=snr, rng=rng) -def simulate_voxels(S0, evals, hsph_dirs, bval_shell=1000, snr=20, n_voxels=1, seed=None): - # Create a gradient table for a single-shell - gtab = create_single_shell_gradient_table(hsph_dirs, bval_shell) +def create_random_polar_angles(size, rng): + """Create polar angles drawn from a uniform distribution.""" + + return zip( + rng.uniform(0, np.pi, size=size), + rng.uniform(0, 2.0 * np.pi, size=size), + strict=True, + ) - rng = np.random.default_rng(seed) - angles = zip( - rng.uniform(0, np.pi, size=n_voxels), - rng.uniform(0, 2.0 * np.pi, size=n_voxels), - strict=False, +def create_random_diffusivity_eigenvalues(size, rng): + r"""Create DTI model diffusion tensor eigenvalues ($\lambda_{1}, + \lambda_{2}, \lambda_{3}$) drawn from a uniform distribution.""" + + # lambda_2 = lambda_3 following Canales-Rodriguez, NIMG 184 2019, + # https://doi.org/10.1016/j.neuroimage.2018.08.071 + return zip( + rng.uniform(*BOUNDS_LAMBDA1, size=size), + *[rng.uniform(*BOUNDS_LAMBDA23, size=size)] * 2, + strict=True, ) + +def group_values(values, group_size): + return np.asarray([values[i : i + group_size] for i in range(0, len(values), group_size)]) + + +def simulate_one_fiber_multivoxel(gtab, S0, snr, n_voxels, rng, evals=None): + """Create a single-fiber multi-voxel DWI signal.""" + + angles = create_random_polar_angles(n_voxels, rng) + if evals is None: + _evals = 
create_random_diffusivity_eigenvalues(n_voxels, rng) + else: + _evals = group_values(evals, 3) + if _evals.shape[0] == 1 and n_voxels != 1: + _evals = np.repeat(_evals, n_voxels, axis=0) + signal = np.vstack( [ - single_fiber_voxel(gtab, S0, evals, theta=theta, phi=phi, snr=snr) - for theta, phi in angles + single_fiber_voxel(gtab, S0, _eignvls, rng, theta=theta, phi=phi, snr=snr) + for (theta, phi), _eignvls in zip(angles, _evals, strict=True) ] ) + return signal + + +def simulate_voxels(S0, hsph_dirs, bval_shell=1000, snr=20, n_voxels=1, evals=None, seed=None): + # Create a gradient table for a single-shell + gtab = create_single_shell_gradient_table(hsph_dirs, bval_shell) + + rng = np.random.default_rng(seed) + + signal = simulate_one_fiber_multivoxel(gtab, S0, snr, n_voxels, rng, evals=evals) + return signal, gtab + + +def serialize_dwi(dwi_data, dwi_data_fname, affine: np.ndarray | None = None): + """Serialize DWI data. + + Parameters + ---------- + dwi_data : :obj:`~numpy.ndarray` + DWI data. + dwi_data_fname : :obj:`str` + Filename of NIfTI file to save the DWI signal. + affine : :obj:`~numpy.ndarray`, optional + Affine matrix. If ``None`` an identity affine matrix is used. + """ + + if affine is None: + affine = np.eye(4) + + dwi_img = nib.Nifti1Image(dwi_data, affine=affine) + nib.save(dwi_img, dwi_data_fname) + + +def serialize_gtab(gtab, bval_data_fname, bvec_data_fname): + """Serialize dMRI gradient-encoding table data into a pair of b-vals and + b-vecs files. + + Parameters + ---------- + gtab : :obj:`~dipy.core.gradients.gradient_table` + Gradient table. + bval_data_fname : :obj:`str` + Filename of text file to save the diffusion-encoding gradient b-vals. + bvec_data_fname : :obj:`str` + Filename of text file to save the diffusion-encoding gradient b-vecs. 
+ """ + + fmt = "%d" + np.savetxt(bval_data_fname, gtab.bvals, newline=" ", fmt=fmt) + fmt = "%.3f" + np.savetxt(bvec_data_fname, gtab.bvecs.T, fmt=fmt) + + +def serialize_dmri( + dwi_data, + gtab, + dwi_data_fname, + bval_data_fname, + bvec_data_fname, + affine: np.ndarray | None = None, +): + """Serialize dMRI data. + + Parameters + ---------- + dwi_data : :obj:`~numpy.ndarray` + DWI data. + gtab : :obj:`~dipy.core.gradients.gradient_table` + Gradient table. + dwi_data_fname : :obj:`str` + Filename of NIfTI file to save the DWI signal. + bval_data_fname : :obj:`str` + Filename of text file to save the diffusion-encoding gradient b-vals. + bvec_data_fname : :obj:`str` + Filename of text file to save the diffusion-encoding gradient b-vecs. + affine : :obj:`~numpy.ndarray`, optional + Affine matrix. If ``None`` an identity affine matrix is used. + """ + + serialize_dwi(dwi_data, dwi_data_fname, affine=affine) + serialize_gtab(gtab, bval_data_fname, bvec_data_fname) diff --git a/src/eddymotion/viz/signals.py b/src/eddymotion/viz/signals.py index 60c3933c..c763f1f2 100644 --- a/src/eddymotion/viz/signals.py +++ b/src/eddymotion/viz/signals.py @@ -25,11 +25,19 @@ import matplotlib.gridspec as gridspec import numpy as np from matplotlib import pyplot as plt +from scipy.spatial import ConvexHull, KDTree from scipy.stats import pearsonr def plot_error( - kfolds: list[int], mean: np.ndarray, std_dev: np.ndarray, xlabel: str, ylabel: str, title: str + kfolds: list[int], + mean: np.ndarray, + std_dev: np.ndarray, + xlabel: str, + ylabel: str, + title: str, + color: str = "orange", + figsize: tuple[int, int] = (19.2, 10.8), ) -> plt.Figure: """ Plot the error and standard deviation. @@ -48,6 +56,10 @@ def plot_error( Y-axis label. title : :obj:`str` Plot title. + color : :obj:`str`, optional + Plot color. + figsize : :obj:`tuple`, optional + Figure size (width, height) in inches. Returns ------- @@ -55,10 +67,10 @@ def plot_error( Matplotlib figure object. 
""" - fig, ax = plt.subplots() - ax.plot(kfolds, mean, c="orange") - ax.fill_between(kfolds, mean - std_dev, mean + std_dev, alpha=0.5, color="orange") - ax.scatter(kfolds, mean, c="orange") + fig, ax = plt.subplots(figsize=figsize) + ax.plot(kfolds, mean, c=color) + ax.fill_between(kfolds, mean - std_dev, mean + std_dev, alpha=0.5, color=color) + ax.scatter(kfolds, mean, c=color) ax.set_xlabel(xlabel) ax.set_ylabel(ylabel) ax.set_xticks(kfolds) @@ -112,3 +124,96 @@ def plot_correlation(x, y, title): fig.tight_layout() return fig, r + + +def calculate_sphere_pts(points, center): + """Calculate the location of each point when it is expanded out to the sphere.""" + + kdtree = KDTree(points) # tree of nearest points + # d is an array of distances, i is an array of indices + d, i = kdtree.query(center, points.shape[0]) + sphere_pts = np.zeros(points.shape, dtype=float) + + radius = np.amax(d) + for p in range(points.shape[0]): + sphere_pts[p] = points[i[p]] * radius / d[p] + # points and the indices for where they were in the original lists + return sphere_pts, i + + +def compute_dmri_convex_hull(s, dirs, mask=None): + """Compute the convex hull of the dMRI signal s.""" + + if mask is None: + mask = np.ones(len(dirs), dtype=bool) + + # Scale the original sampling directions by the corresponding signal values + scaled_bvecs = dirs[mask] * np.asarray(s)[:, np.newaxis] + + # Create the data for the convex hull: project the scaled vectors to a + # sphere + sphere_pts, sphere_idx = calculate_sphere_pts(scaled_bvecs, [0, 0, 0]) + + # Create the convex hull: find the right ordering of vertices for the + # triangles: ConvexHull finds the simplices of the points on the outside of + # the data set + hull = ConvexHull(sphere_pts) + triang_idx = hull.simplices # returns the list of indices for each triangle + + return scaled_bvecs, sphere_idx, triang_idx + + +def plot_surface(scaled_vecs, sphere_idx, triang_idx, title, cmap): + """Plot a surface.""" + + fig = plt.figure() + ax = 
fig.add_subplot(111, projection="3d") + + ax.scatter3D( + scaled_vecs[:, 0], scaled_vecs[:, 1], scaled_vecs[:, 2], s=2, c="black", alpha=1.0 + ) + + surface = ax.plot_trisurf( + scaled_vecs[sphere_idx, 0], + scaled_vecs[sphere_idx, 1], + scaled_vecs[sphere_idx, 2], + triangles=triang_idx, + cmap=cmap, + alpha=0.6, + ) + + ax.view_init(10, 45) + ax.set_aspect("equal", adjustable="box") + ax.set_title(title) + + return fig, ax, surface + + +def plot_signal_data(y, ax): + """Plot the data provided as a scatter plot""" + + ax.scatter( + y[:, 0], y[:, 1], y[:, 2], color="red", marker="*", alpha=0.8, s=5, label="Original points" + ) + + +def plot_prediction_surface(y, y_pred, S0, y_dirs, y_pred_dirs, title, cmap): + """Plot the prediction surface obtained by computing the convex hull of the + predicted signal data, and plot the true data as a scatter plot.""" + + # Scale the original sampling directions by the corresponding signal values + y_bvecs = y_dirs * np.asarray(y)[:, np.newaxis] + + # Compute the convex hull + y_pred_bvecs, sphere_idx, triang_idx = compute_dmri_convex_hull(y_pred, y_pred_dirs) + + # Plot the surface + fig, ax, surface = plot_surface(y_pred_bvecs, sphere_idx, triang_idx, title, cmap) + + # Add the underlying signal to the plot + # plot_signal_data(y_bvecs/S0, ax) + plot_signal_data(y_bvecs, ax) + + fig.tight_layout() + + return fig, ax, surface