diff --git a/scripts/dwi_gp_estimation_error_analysis.py b/scripts/dwi_gp_estimation_error_analysis.py index f1f2e733..d9dd9f8c 100644 --- a/scripts/dwi_gp_estimation_error_analysis.py +++ b/scripts/dwi_gp_estimation_error_analysis.py @@ -34,7 +34,7 @@ import numpy as np import pandas as pd -from sklearn.model_selection import RepeatedKFold, cross_val_score +from sklearn.model_selection import KFold, RepeatedKFold, cross_val_predict, cross_val_score from eddymotion.model._sklearn import ( EddyMotionGPR, @@ -92,22 +92,52 @@ def _build_arg_parser() -> argparse.ArgumentParser: description=__doc__, formatter_class=argparse.RawTextHelpFormatter ) parser.add_argument( - "hsph_dirs", + "--bval-shell", + help="Shell b-value", + type=float, + default=1000, + ) + parser.add_argument("--S0", help="S0 value", type=float, default=100) + parser.add_argument( + "--hsph-dirs", help="Number of diffusion gradient-encoding directions in the half sphere", type=int, + default=60, ) - parser.add_argument("bval_shell", help="Shell b-value", type=int) - parser.add_argument("S0", help="S0 value", type=float) parser.add_argument( - "error_data_fname", + "--output-scores", help="Filename of TSV file containing the data to plot", type=Path, + default=Path() / "scores.tsv", + ) + parser.add_argument( + "-n", + "--n-voxels", + help="Number of diffusion gradient-encoding directions in the half sphere", + type=int, + default=100, + ) + parser.add_argument( + "--write-inputs", + help="Filename of NIfTI file containing the generated DWI signal", + type=Path, + default=None, + ) + parser.add_argument( + "--output-predicted", + help="Filename of NIfTI file containing the predicted DWI signal", + type=Path, + default=None, ) parser.add_argument("--evals", help="Eigenvalues of the tensor", nargs="+", type=float) parser.add_argument("--snr", help="Signal to noise ratio", type=float) parser.add_argument("--repeats", help="Number of repeats", type=int, default=5) parser.add_argument( - "--kfold", help="Number of directions to leave out/predict", nargs="+", type=int + "--kfold", + help="Number of folds in repeated-k-fold cross-validation", + nargs="+", + type=int, + default=None, ) return parser @@ -134,18 +164,26 @@ def main() -> None: parser = _build_arg_parser() args = _parse_args(parser) - n_voxels = 100 - data, gtab = testsims.simulate_voxels( args.S0, args.hsph_dirs, bval_shell=args.bval_shell, snr=args.snr, - n_voxels=n_voxels, + n_voxels=args.n_voxels, evals=args.evals, seed=None, ) + # Save the generated signal and gradient table + if args.write_inputs: + testsims.serialize_dmri( + data, + gtab, + args.write_inputs, + args.write_inputs.with_suffix(".bval"), + args.write_inputs.with_suffix(".bvec"), + ) + X = gtab[~gtab.b0s_mask].bvecs y = data[:, ~gtab.b0s_mask] @@ -153,9 +191,9 @@ def main() -> None: a = 1.15 lambda_s = 120 - alpha = 100 + alpha = 1 gpr = EddyMotionGPR( - kernel=SphericalKriging(a=a, lambda_s=lambda_s), + kernel=SphericalKriging(beta_a=a, beta_l=lambda_s), alpha=alpha, optimizer=None, # optimizer="Nelder-Mead", @@ -164,25 +202,35 @@ def main() -> None: # max_iter=2e5, ) - # Use Scikit-learn cross validation - scores = defaultdict(list, {}) - for n in args.kfold: - for i in range(args.repeats): - cv_scores = -1.0 * cross_validate(X, y.T, n, np.max(args.kfold) // n, gpr) - scores["rmse"] += cv_scores.tolist() - scores["repeat"] += [i] * len(cv_scores) - scores["n_folds"] += [n] * len(cv_scores) - scores["bval"] += [args.bval_shell] * len(cv_scores) - scores["snr"] += [snr_str] * len(cv_scores) - - print(f"Finished {n}-fold cross-validation") - - scores_df = pd.DataFrame(scores) - scores_df.to_csv(args.error_data_fname, sep="\t", index=None, na_rep="n/a") - - grouped = scores_df.groupby(["n_folds"]) - print(grouped[["rmse"]].mean()) - print(grouped[["rmse"]].std()) + if args.kfold: + # Use Scikit-learn cross validation + scores = defaultdict(list, {}) + for n in args.kfold: + for i in range(args.repeats): + cv_scores = -1.0 * cross_validate(X, y.T, n, gpr) + scores["rmse"] += cv_scores.tolist() + scores["repeat"] += [i] * len(cv_scores) + scores["n_folds"] += [n] * len(cv_scores) + scores["bval"] += [args.bval-shell] * len(cv_scores) + scores["snr"] += [snr_str] * len(cv_scores) + + print(f"Finished {n}-fold cross-validation") + + scores_df = pd.DataFrame(scores) + scores_df.to_csv(args.output_scores, sep="\t", index=None, na_rep="n/a") + + grouped = scores_df.groupby(["n_folds"]) + print(grouped[["rmse"]].mean()) + print(grouped[["rmse"]].std()) + else: + gpr.fit(X, y.T) + print(gpr.kernel_) + + if args.output_predicted: + cv = KFold(n_splits=3, shuffle=False, random_state=None) + predictions = cross_val_predict(gpr, X, y.T, cv=cv) + + testsims.serialize_dwi(predictions.T, args.output_predicted) if __name__ == "__main__":