diff --git a/clinica/pipelines/dwi_connectome/dwi_connectome_utils.py b/clinica/pipelines/dwi_connectome/dwi_connectome_utils.py
index b5c1f0491..69a0bd154 100644
--- a/clinica/pipelines/dwi_connectome/dwi_connectome_utils.py
+++ b/clinica/pipelines/dwi_connectome/dwi_connectome_utils.py
@@ -1,115 +1,101 @@
-def get_luts():
-    import os
+from pathlib import Path
 
-    from clinica.utils.exceptions import ClinicaException
 
-    try:
-        # For aparc+aseg.mgz file:
-        default = os.path.join(os.environ["FREESURFER_HOME"], "FreeSurferColorLUT.txt")
-        # For aparc.a2009s+aseg.mgz file:
-        a2009s = os.path.join(os.environ["FREESURFER_HOME"], "FreeSurferColorLUT.txt")
+def get_luts() -> list:
+    from pathlib import Path
 
-        # TODO: Add custom Lausanne2008 LUTs here.
-    except KeyError:
-        raise ClinicaException("Could not find FREESURFER_HOME environment variable.")
-    return [default, a2009s]
+    from clinica.utils.check_dependency import check_environment_variable
 
+    freesurfer_home = Path(check_environment_variable("FREESURFER_HOME", "Freesurfer"))
 
-def get_conversion_luts_offline():
-    # TODO: use this function if no internet connect found in client (need to upload files to clinica repository)
-    return
+    return [
+        str(freesurfer_home / "FreeSurferColorLUT.txt"),
+        str(freesurfer_home / "FreeSurferColorLUT.txt"),
+    ]
 
 
-def get_conversion_luts():
-    from os import pardir
-    from os.path import abspath, dirname, join
+def get_conversion_luts() -> list:
     from pathlib import Path
 
-    from clinica.utils.inputs import RemoteFileStructure, fetch_file
-    from clinica.utils.stream import cprint
-
-    root = dirname(abspath(join(abspath(__file__), pardir, pardir)))
-
-    path_to_mappings = Path(root) / "resources" / "mappings"
-
-    url_mrtrix = "https://raw.githubusercontent.com/MRtrix3/mrtrix3/master/share/mrtrix3/labelconvert/"
-
-    fs_default = RemoteFileStructure(
-        filename="fs_default.txt",
-        url=url_mrtrix,
-        checksum="6ee07088915fdbcf52b05147ddae86e5fcaf3efc63db5b0ba8f361637dfa11ef",
+    path_to_mappings = (
+        Path(__file__).resolve().parent.parent.parent / "resources" / "mappings"
     )
+    resulting_paths = []
+    for filename in ("fs_default.txt", "fs_a2009s.txt"):
+        file_path = path_to_mappings / filename
+        if not file_path.is_file():
+            file_path = _download_mrtrix3_file(filename, path_to_mappings)
+        resulting_paths.append(str(file_path))
+    return resulting_paths
 
-    fs_a2009s = RemoteFileStructure(
-        filename="fs_a2009s.txt",
-        url=url_mrtrix,
-        checksum="b472f09cfe92ac0b6694fb6b00a87baf15dd269566e4a92b8a151ff1080bf170",
-    )
 
-    ref_fs_default = path_to_mappings / Path(fs_default.filename)
-    ref_fs_a2009 = path_to_mappings / Path(fs_a2009s.filename)
+def _download_mrtrix3_file(filename: str, path_to_mappings: Path) -> str:
+    from clinica.utils.inputs import RemoteFileStructure, fetch_file
+    from clinica.utils.stream import cprint
 
-    if not (ref_fs_default.is_file()):
-        try:
-            ref_fs_default = fetch_file(fs_default, path_to_mappings)
-        except IOError as err:
-            cprint(
-                msg=f"Unable to download required MRTRIX mapping (fs_default.txt) for processing: {err}",
-                lvl="error",
-            )
-    if not (ref_fs_a2009.is_file()):
-        try:
-            ref_fs_a2009 = fetch_file(fs_a2009s, path_to_mappings)
-        except IOError as err:
-            cprint(
-                msg=f"Unable to download required MRTRIX mapping (fs_a2009s.txt) for processing: {err}",
-                lvl="error",
-            )
+    try:
+        return fetch_file(
+            RemoteFileStructure(
+                filename=filename,
+                url="https://raw.githubusercontent.com/MRtrix3/mrtrix3/master/share/mrtrix3/labelconvert/",
+                checksum=_get_checksum_for_filename(filename),
+            ),
+            str(path_to_mappings),
+        )
+    except IOError as err:
+        error_msg = f"Unable to download required MRTRIX mapping ({filename}) for processing: {err}"
+        cprint(msg=error_msg, lvl="error")
+        raise IOError(error_msg)
 
-    return [ref_fs_default, ref_fs_a2009]
 
+def _get_checksum_for_filename(filename: str) -> str:
+    if filename == "fs_default.txt":
+        return "a8d561694887a1ca8d9df223aa5ef861b6c79d43ce9ed93835b9ce8aadc331b1"
+    if filename == "fs_a2009s.txt":
+        return "40b0d4d77bde7e1d265439347af5b30cc973748c1a88d203d7044cb35b3863e1"
+    raise ValueError(f"File name {filename} is not supported.")
 
-def get_containers(subjects, sessions):
-    import os
+
+def get_containers(subjects: list, sessions: list) -> list:
+    from pathlib import Path
 
     return [
-        os.path.join("subjects", subjects[i], sessions[i], "dwi")
-        for i in range(len(subjects))
+        str(Path("subjects") / subject / session / "dwi")
+        for subject, session in zip(subjects, sessions)
    ]
 
 
-def get_caps_filenames(dwi_file: str):
+def get_caps_filenames(dwi_file: str) -> tuple:
     import re
 
-    m = re.search(r"/(sub-[a-zA-Z0-9]+_ses-[a-zA-Z0-9]+.*)_preproc", dwi_file)
-    if not m:
-        raise ValueError(
-            f"Input filename {dwi_file} is not in a CAPS compliant format."
+    error_msg = f"Input filename {dwi_file} is not in a CAPS compliant format."
+    if (
+        m := re.search(
+            r"/(sub-[a-zA-Z0-9]+_ses-[a-zA-Z0-9]+.*_desc-preproc*)_dwi", dwi_file
         )
+    ) is None:
+        raise ValueError(error_msg)
     source_file_caps = m.group(1)
-
-    m = re.search(
-        r"/(sub-[a-zA-Z0-9]+_ses-[a-zA-Z0-9]+.*)_space-[a-zA-Z0-9]+_preproc", dwi_file
-    )
-    if not m:
-        raise ValueError(
-            f"Input filename {dwi_file} is not in a CAPS compliant format."
+    if (
+        m := re.search(
+            r"/(sub-[a-zA-Z0-9]+_ses-[a-zA-Z0-9]+.*)_space-[a-zA-Z0-9]+_desc-preproc_dwi",
+            dwi_file,
         )
+    ) is None:
+        raise ValueError(error_msg)
     source_file_bids = m.group(1)
 
     response = f"{source_file_caps}_model-CSD_responseFunction.txt"
     fod = f"{source_file_caps}_model-CSD_diffmodel.nii.gz"
     tracts = f"{source_file_caps}_model-CSD_tractography.tck"
     nodes = [
-        f"{source_file_caps}_atlas-desikan_parcellation.nii.gz",
-        f"{source_file_caps}_atlas-destrieux_parcellation.nii.gz",
+        f"{source_file_caps}_atlas-{atlas}_parcellation.nii.gz"
+        for atlas in ("desikan", "destrieux")
     ]
-    # TODO: Add custom Lausanne2008 node files here.
     connectomes = [
-        f"{source_file_bids}_model-CSD_atlas-desikan_connectivity.tsv",
-        f"{source_file_bids}_model-CSD_atlas-destrieux_connectivity.tsv",
+        f"{source_file_bids}_model-CSD_atlas-{atlas}_connectivity.tsv"
+        for atlas in ("desikan", "destrieux")
     ]
-    # TODO: Add custom Lausanne2008 connectome files here.
 
     return response, fod, tracts, nodes, connectomes
diff --git a/test/unittests/pipelines/dwi_connectome/test_dwi_connectome_utils.py b/test/unittests/pipelines/dwi_connectome/test_dwi_connectome_utils.py
new file mode 100644
index 000000000..35c6fab2b
--- /dev/null
+++ b/test/unittests/pipelines/dwi_connectome/test_dwi_connectome_utils.py
@@ -0,0 +1,154 @@
+import pytest
+
+
+def test_get_luts(mocker):
+    from clinica.pipelines.dwi_connectome.dwi_connectome_utils import get_luts
+
+    mocked_freesurfer_home = "/Applications/freesurfer/7.2.0"
+    mocker.patch(
+        "clinica.utils.check_dependency.check_environment_variable",
+        return_value=mocked_freesurfer_home,
+    )
+    assert get_luts() == [f"{mocked_freesurfer_home}/FreeSurferColorLUT.txt"] * 2
+
+
+@pytest.mark.parametrize(
+    "filename,expected_checksum",
+    [
+        (
+            "fs_default.txt",
+            "a8d561694887a1ca8d9df223aa5ef861b6c79d43ce9ed93835b9ce8aadc331b1",
+        ),
+        (
+            "fs_a2009s.txt",
+            "40b0d4d77bde7e1d265439347af5b30cc973748c1a88d203d7044cb35b3863e1",
+        ),
+    ],
+)
+def test_get_checksum_for_filename(filename, expected_checksum):
+    from clinica.pipelines.dwi_connectome.dwi_connectome_utils import (
+        _get_checksum_for_filename,
+    )
+
+    assert _get_checksum_for_filename(filename) == expected_checksum
+
+
+def test_get_checksum_for_filename_error():
+    from clinica.pipelines.dwi_connectome.dwi_connectome_utils import (
+        _get_checksum_for_filename,
+    )
+
+    with pytest.raises(ValueError, match="File name foo.txt is not supported."):
+        _get_checksum_for_filename("foo.txt")
+
+
+@pytest.mark.parametrize(
+    "filename,expected_length", [("fs_default.txt", 112), ("fs_a2009s.txt", 192)]
+)
+def test_download_mrtrix3_file(tmp_path, filename, expected_length):
+    """Atm this test needs an internet connection to download the files.
+
+    TODO: Use mocking in the fetch_file function to remove this necessity.
+    """
+    from clinica.pipelines.dwi_connectome.dwi_connectome_utils import (
+        _download_mrtrix3_file,
+    )
+
+    _download_mrtrix3_file(filename, tmp_path)
+
+    assert [f.name for f in tmp_path.iterdir()] == [filename]
+    assert len((tmp_path / filename).read_text().split("\n")) == expected_length
+
+
+def test_download_mrtrix3_file_error(tmp_path, mocker):
+    import re
+
+    from clinica.pipelines.dwi_connectome.dwi_connectome_utils import (
+        _download_mrtrix3_file,
+    )
+
+    mocker.patch(
+        "clinica.pipelines.dwi_connectome.dwi_connectome_utils._get_checksum_for_filename",
+        return_value="foo",
+    )
+    mocker.patch("clinica.utils.inputs.fetch_file", side_effect=IOError)
+
+    with pytest.raises(
+        IOError,
+        match=re.escape(
+            "Unable to download required MRTRIX mapping (foo.txt) for processing"
+        ),
+    ):
+        _download_mrtrix3_file("foo.txt", tmp_path)
+
+
+def test_get_conversion_luts():
+    from pathlib import Path
+
+    from clinica.pipelines.dwi_connectome.dwi_connectome_utils import (
+        get_conversion_luts,
+    )
+
+    luts = [Path(_) for _ in get_conversion_luts()]
+
+    assert [p.name for p in luts] == ["fs_default.txt", "fs_a2009s.txt"]
+    assert all([p.is_file() for p in luts])
+
+
+@pytest.mark.parametrize(
+    "filename",
+    [
+        "foo.txt",
+        "dwi.nii.gz",
+        "sub-01_ses-M000_dwi.nii.gz",
+        "sub-01_ses-M000_preproc.nii.gz",
+        "sub-01_ses-M000_space-T1w_preproc.nii.gz",
+        "sub-01_ses-M000_space-b0_preproc.nii.gz",
+    ],
+)
+def test_get_caps_filenames_error(tmp_path, filename):
+    from clinica.pipelines.dwi_connectome.dwi_connectome_utils import get_caps_filenames
+
+    with pytest.raises(ValueError, match="is not in a CAPS compliant format."):
+        get_caps_filenames(str(tmp_path / filename))
+
+
+def test_get_caps_filenames(tmp_path):
+    from clinica.pipelines.dwi_connectome.dwi_connectome_utils import get_caps_filenames
+
+    dwi_caps = tmp_path / "dwi" / "preprocessing"
+    dwi_caps.mkdir(parents=True)
+
+    assert get_caps_filenames(
+        str(dwi_caps / "sub-01_ses-M000_space-b0_desc-preproc_dwi.nii.gz")
+    ) == (
+        "sub-01_ses-M000_space-b0_desc-preproc_model-CSD_responseFunction.txt",
+        "sub-01_ses-M000_space-b0_desc-preproc_model-CSD_diffmodel.nii.gz",
+        "sub-01_ses-M000_space-b0_desc-preproc_model-CSD_tractography.tck",
+        [
+            "sub-01_ses-M000_space-b0_desc-preproc_atlas-desikan_parcellation.nii.gz",
+            "sub-01_ses-M000_space-b0_desc-preproc_atlas-destrieux_parcellation.nii.gz",
+        ],
+        [
+            "sub-01_ses-M000_model-CSD_atlas-desikan_connectivity.tsv",
+            "sub-01_ses-M000_model-CSD_atlas-destrieux_connectivity.tsv",
+        ],
+    )
+
+
+@pytest.mark.parametrize(
+    "subjects,sessions,expected",
+    [
+        ([], [], []),
+        (["foo"], ["bar"], ["subjects/foo/bar/dwi"]),
+        (
+            ["sub-01", "sub-02"],
+            ["ses-M000", "ses-M006"],
+            ["subjects/sub-01/ses-M000/dwi", "subjects/sub-02/ses-M006/dwi"],
+        ),
+    ],
+)
+def test_get_containers(subjects, sessions, expected):
+    from clinica.pipelines.dwi_connectome.dwi_connectome_utils import get_containers
+
+    assert get_containers(subjects, sessions) == expected
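
Reviewer note (not part of the patch): the docstring TODO in test_download_mrtrix3_file asks for mocking fetch_file so that test no longer needs the network. A minimal sketch of what that could look like, assuming pytest-mock (already used in this module) and patching the same "clinica.utils.inputs.fetch_file" target as test_download_mrtrix3_file_error; the test name and the placeholder file content below are illustrative only:

    def test_download_mrtrix3_file_offline(tmp_path, mocker):
        from clinica.pipelines.dwi_connectome.dwi_connectome_utils import (
            _download_mrtrix3_file,
        )

        expected = tmp_path / "fs_default.txt"

        def fake_fetch_file(remote_file, destination):
            # Write a tiny placeholder locally instead of downloading from GitHub.
            expected.write_text("0  Unknown  0  0  0  0\n")
            return str(expected)

        # _download_mrtrix3_file imports fetch_file from clinica.utils.inputs at
        # call time, so patching it there keeps the test fully offline.
        mocker.patch("clinica.utils.inputs.fetch_file", side_effect=fake_fetch_file)

        assert _download_mrtrix3_file("fs_default.txt", tmp_path) == str(expected)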