Skip to content

Commit

Permalink
Merge pull request #249 from nipreps/fix/dwi_experiment_improvements
Browse files Browse the repository at this point in the history
ENH: Improve CLI of the dwi/error experiment
  • Loading branch information
oesteban authored Oct 29, 2024
2 parents 26ef287 + 4a2d147 commit 0c2e6c0
Showing 1 changed file with 78 additions and 30 deletions.
108 changes: 78 additions & 30 deletions scripts/dwi_gp_estimation_error_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@

import numpy as np
import pandas as pd
from sklearn.model_selection import RepeatedKFold, cross_val_score
from sklearn.model_selection import KFold, RepeatedKFold, cross_val_predict, cross_val_score

from eddymotion.model._sklearn import (
EddyMotionGPR,
Expand Down Expand Up @@ -92,22 +92,52 @@ def _build_arg_parser() -> argparse.ArgumentParser:
description=__doc__, formatter_class=argparse.RawTextHelpFormatter
)
parser.add_argument(
"hsph_dirs",
"--bval-shell",
help="Shell b-value",
type=float,
default=1000,
)
parser.add_argument("--S0", help="S0 value", type=float, default=100)
parser.add_argument(
"--hsph-dirs",
help="Number of diffusion gradient-encoding directions in the half sphere",
type=int,
default=60,
)
parser.add_argument("bval_shell", help="Shell b-value", type=int)
parser.add_argument("S0", help="S0 value", type=float)
parser.add_argument(
"error_data_fname",
"--output-scores",
help="Filename of TSV file containing the data to plot",
type=Path,
default=Path() / "scores.tsv",
)
parser.add_argument(
"-n",
"--n-voxels",
help="Number of diffusion gradient-encoding directions in the half sphere",
type=int,
default=100,
)
parser.add_argument(
"--write-inputs",
help="Filename of NIfTI file containing the generated DWI signal",
type=Path,
default=None,
)
parser.add_argument(
"--output-predicted",
help="Filename of NIfTI file containing the predicted DWI signal",
type=Path,
default=None,
)
parser.add_argument("--evals", help="Eigenvalues of the tensor", nargs="+", type=float)
parser.add_argument("--snr", help="Signal to noise ratio", type=float)
parser.add_argument("--repeats", help="Number of repeats", type=int, default=5)
parser.add_argument(
"--kfold", help="Number of directions to leave out/predict", nargs="+", type=int
"--kfold",
help="Number of folds in repeated-k-fold cross-validation",
nargs="+",
type=int,
default=None,
)
return parser

Expand All @@ -134,28 +164,36 @@ def main() -> None:
parser = _build_arg_parser()
args = _parse_args(parser)

n_voxels = 100

data, gtab = testsims.simulate_voxels(
args.S0,
args.hsph_dirs,
bval_shell=args.bval_shell,
snr=args.snr,
n_voxels=n_voxels,
n_voxels=args.n_voxels,
evals=args.evals,
seed=None,
)

# Save the generated signal and gradient table
if args.write_inputs:
testsims.serialize_dmri(
data,
gtab,
args.write_inputs,
args.write_inputs.with_suffix(".bval"),
args.write_inputs.with_suffix(".bvec"),
)

X = gtab[~gtab.b0s_mask].bvecs
y = data[:, ~gtab.b0s_mask]

snr_str = args.snr if args.snr is not None else "None"

a = 1.15
lambda_s = 120
alpha = 100
alpha = 1
gpr = EddyMotionGPR(
kernel=SphericalKriging(a=a, lambda_s=lambda_s),
kernel=SphericalKriging(beta_a=a, beta_l=lambda_s),
alpha=alpha,
optimizer=None,
# optimizer="Nelder-Mead",
Expand All @@ -164,25 +202,35 @@ def main() -> None:
# max_iter=2e5,
)

# Use Scikit-learn cross validation
scores = defaultdict(list, {})
for n in args.kfold:
for i in range(args.repeats):
cv_scores = -1.0 * cross_validate(X, y.T, n, np.max(args.kfold) // n, gpr)
scores["rmse"] += cv_scores.tolist()
scores["repeat"] += [i] * len(cv_scores)
scores["n_folds"] += [n] * len(cv_scores)
scores["bval"] += [args.bval_shell] * len(cv_scores)
scores["snr"] += [snr_str] * len(cv_scores)

print(f"Finished {n}-fold cross-validation")

scores_df = pd.DataFrame(scores)
scores_df.to_csv(args.error_data_fname, sep="\t", index=None, na_rep="n/a")

grouped = scores_df.groupby(["n_folds"])
print(grouped[["rmse"]].mean())
print(grouped[["rmse"]].std())
if args.kfold:
# Use Scikit-learn cross validation
scores = defaultdict(list, {})
for n in args.kfold:
for i in range(args.repeats):
cv_scores = -1.0 * cross_validate(X, y.T, n, gpr)
scores["rmse"] += cv_scores.tolist()
scores["repeat"] += [i] * len(cv_scores)
scores["n_folds"] += [n] * len(cv_scores)
scores["bval"] += [args.bval-shell] * len(cv_scores)
scores["snr"] += [snr_str] * len(cv_scores)

print(f"Finished {n}-fold cross-validation")

scores_df = pd.DataFrame(scores)
scores_df.to_csv(args.output_scores, sep="\t", index=None, na_rep="n/a")

grouped = scores_df.groupby(["n_folds"])
print(grouped[["rmse"]].mean())
print(grouped[["rmse"]].std())
else:
gpr.fit(X, y.T)
print(gpr.kernel_)

if args.output_predicted:
cv = KFold(n_splits=3, shuffle=False, random_state=None)
predictions = cross_val_predict(gpr, X, y.T, cv=cv)

testsims.serialize_dwi(predictions.T, args.output_predicted)


if __name__ == "__main__":
Expand Down

0 comments on commit 0c2e6c0

Please sign in to comment.