diff --git a/smriprep/cli/run.py b/smriprep/cli/run.py index dfd55755d0..95efef1e1f 100644 --- a/smriprep/cli/run.py +++ b/smriprep/cli/run.py @@ -177,6 +177,11 @@ def get_parser(): action='store_true', help='treat dataset as longitudinal - may increase runtime', ) + g_conf.add_argument( + '--standardize-with-T2w', + action='store_true', + help='treat dataset as longitudinal - may increase runtime', + ) # ANTs options g_ants = parser.add_argument_group('Specific options for ANTs registrations') @@ -629,6 +634,7 @@ def build_workflow(opts, retval): fs_no_resume=opts.fs_no_resume, layout=layout, longitudinal=opts.longitudinal, + standardize_with_T2w=opts.standardize_with_T2w, low_mem=opts.low_mem, msm_sulc=opts.msm_sulc, omp_nthreads=omp_nthreads, diff --git a/smriprep/interfaces/templateflow.py b/smriprep/interfaces/templateflow.py index 6a466d7d09..71b4bb037d 100644 --- a/smriprep/interfaces/templateflow.py +++ b/smriprep/interfaces/templateflow.py @@ -39,17 +39,19 @@ class _TemplateFlowSelectInputSpec(BaseInterfaceInputSpec): - template = traits.Str('MNI152NLin2009cAsym', mandatory=True, desc='Template ID') + template = traits.Str(mandatory=True, desc='Template ID') atlas = InputMultiObject(traits.Str, desc='Specify an atlas') cohort = InputMultiObject(traits.Either(traits.Str, traits.Int), desc='Specify a cohort') resolution = InputMultiObject(traits.Int, desc='Specify a template resolution index') template_spec = traits.DictStrAny( {'atlas': None, 'cohort': None}, usedefault=True, desc='Template specifications' ) + get_T2w = traits.Bool(False, usedefault=True, desc='Get the T2w if available') class _TemplateFlowSelectOutputSpec(TraitedSpec): t1w_file = File(exists=True, desc='T1w template') + t2w_file = File(exists=True, desc='T2w template') brain_mask = File(exists=True, desc="Template's brain mask") @@ -57,7 +59,9 @@ class TemplateFlowSelect(SimpleInterface): """ Select TemplateFlow elements. - >>> select = TemplateFlowSelect(resolution=1) + Examples + -------- + >>> select = TemplateFlowSelect(resolution=1, get_T2w=True) >>> select.inputs.template = 'MNI152NLin2009cAsym' >>> result = select.run() >>> result.outputs.t1w_file # doctest: +ELLIPSIS @@ -66,6 +70,9 @@ class TemplateFlowSelect(SimpleInterface): >>> result.outputs.brain_mask # doctest: +ELLIPSIS '.../tpl-MNI152NLin2009cAsym_res-01_desc-brain_mask.nii.gz' + >>> result.outputs.t2w_file # doctest: +ELLIPSIS + '.../tpl-MNI152NLin2009cAsym_res-01_T2w.nii.gz' + >>> select = TemplateFlowSelect() >>> select.inputs.template = 'MNIPediatricAsym' >>> select.inputs.template_spec = {'cohort': 5, 'resolution': 1} @@ -94,6 +101,9 @@ class TemplateFlowSelect(SimpleInterface): >>> result.outputs.t1w_file # doctest: +ELLIPSIS '.../tpl-MNI305_T1w.nii.gz' + >>> bool(result.outputs.t2w_file) + False + """ input_spec = _TemplateFlowSelectInputSpec @@ -108,8 +118,14 @@ def _run_interface(self, runtime): if isdefined(self.inputs.cohort): specs['cohort'] = self.inputs.cohort - files = fetch_template_files(self.inputs.template, specs) + files = fetch_template_files( + self.inputs.template, + specs, + get_T2w=self.inputs.get_T2w, + ) self._results['t1w_file'] = files['t1w'] + if self.inputs.get_T2w and 't2w' in files: + self._results['t2w_file'] = files['t2w'] self._results['brain_mask'] = files['mask'] return runtime @@ -167,6 +183,7 @@ def fetch_template_files( template: str, specs: dict | None = None, sloppy: bool = False, + get_T2w: bool = False, ) -> dict: if specs is None: specs = {} @@ -203,6 +220,8 @@ def fetch_template_files( files = {} files['t1w'] = tf.get(name[0], desc=None, suffix='T1w', **specs) + if get_T2w and (t2w := tf.get(name[0], desc=None, suffix='T2w', **specs)): + files['t2w'] = t2w files['mask'] = tf.get(name[0], desc='brain', suffix='mask', **specs) or tf.get( name[0], label='brain', suffix='mask', **specs ) diff --git a/smriprep/workflows/anatomical.py b/smriprep/workflows/anatomical.py index 974d2c57a4..8088e33f20 100644 --- a/smriprep/workflows/anatomical.py +++ b/smriprep/workflows/anatomical.py @@ -114,6 +114,7 @@ def init_anat_preproc_wf( name: str = 'anat_preproc_wf', skull_strip_fixed_seed: bool = False, fs_no_resume: bool = False, + norm_add_T2w: bool = False, ): """ Stage the anatomical preprocessing steps of *sMRIPrep*. @@ -190,6 +191,9 @@ def init_anat_preproc_wf( EXPERT: Import pre-computed FreeSurfer reconstruction without resuming. The user is responsible for ensuring that all necessary files are present. (default: ``False``). + norm_add_T2w : :obj:`bool` + Use T2w as a moving image channel in the spatial normalization to template + space(s). Inputs ------ @@ -283,6 +287,7 @@ def init_anat_preproc_wf( omp_nthreads=omp_nthreads, skull_strip_fixed_seed=skull_strip_fixed_seed, fs_no_resume=fs_no_resume, + norm_add_T2w=norm_add_T2w, ) template_iterator_wf = init_template_iterator_wf(spaces=spaces, sloppy=sloppy) ds_std_volumes_wf = init_ds_anat_volumes_wf( @@ -461,6 +466,7 @@ def init_anat_fit_wf( name='anat_fit_wf', skull_strip_fixed_seed: bool = False, fs_no_resume: bool = False, + norm_add_T2w: bool = False, ): """ Stage the anatomical preprocessing steps of *sMRIPrep*. @@ -541,6 +547,9 @@ def init_anat_fit_wf( Do not use a random seed for skull-stripping - will ensure run-to-run replicability when used with --omp-nthreads 1 (default: ``False``). + norm_add_T2w : :obj:`bool` + Use T2w as a moving image channel in the spatial normalization to template + space(s). Inputs ------ @@ -994,6 +1003,7 @@ def init_anat_fit_wf( sloppy=sloppy, omp_nthreads=omp_nthreads, templates=templates, + use_T2w=norm_add_T2w and t2w, ) ds_template_registration_wf = init_ds_template_registration_wf( output_dir=output_dir, image_type='T1w' @@ -1002,7 +1012,6 @@ def init_anat_fit_wf( # fmt:off workflow.connect([ (inputnode, register_template_wf, [('roi', 'inputnode.lesion_mask')]), - (t1w_buffer, register_template_wf, [('t1w_preproc', 'inputnode.moving_image')]), (refined_buffer, register_template_wf, [('t1w_mask', 'inputnode.moving_mask')]), (sourcefile_buffer, ds_template_registration_wf, [ ('source_files', 'inputnode.source_files') @@ -1137,6 +1146,12 @@ def init_anat_fit_wf( image_type='T2w', name='t2w_template_wf', ) + register_template_wf = init_register_template_wf( + sloppy=sloppy, + omp_nthreads=omp_nthreads, + templates=templates, + use_T2w=norm_add_T2w and t2w, + ) bbreg = pe.Node( fs.BBRegister( contrast_type='t2', @@ -1166,6 +1181,8 @@ def init_anat_fit_wf( ) ds_t2w_preproc.inputs.SkullStripped = False + merge_t2w = pe.Node(niu.Merge(2), name='merge_t2w', run_without_submitting=True) + workflow.connect([ (inputnode, t2w_template_wf, [('t2w', 'inputnode.anat_files')]), (t2w_template_wf, bbreg, [('outputnode.anat_ref', 'source_file')]), @@ -1182,10 +1199,34 @@ def init_anat_fit_wf( (inputnode, ds_t2w_preproc, [('t2w', 'source_file')]), (t2w_resample, ds_t2w_preproc, [('output_image', 'in_file')]), (ds_t2w_preproc, outputnode, [('out_file', 't2w_preproc')]), + (t1w_buffer, merge_t2w, [('t1w_preproc', 'in1')]), + (t2w_resample, merge_t2w, [('output_image', 'in2')]), + (merge_t2w, register_template_wf, [('out', 'inputnode.moving_image')]), ]) # fmt:skip elif not t2w: + register_template_wf = init_register_template_wf( + sloppy=sloppy, + omp_nthreads=omp_nthreads, + templates=templates, + use_T2w=norm_add_T2w and t2w, + ) + workflow.connect([ + (t1w_buffer, register_template_wf, [('t1w_preproc', 'inputnode.moving_image')]), + ]) LOGGER.info('ANAT No T2w images provided - skipping Stage 7') else: + register_template_wf = init_register_template_wf( + sloppy=sloppy, + omp_nthreads=omp_nthreads, + templates=templates, + use_T2w=norm_add_T2w and t2w, + ) + merge_t2w = pe.Node(niu.Merge(2), name='merge_t2w', run_without_submitting=True) + workflow.connect([ + (t1w_buffer, merge_t2w, [('t1w_preproc', 'in1')]), + (inputnode, merge_t2w, [('t2w', 'in2')]), + (merge_t2w, register_template_wf, [('out', 'inputnode.moving_image')]), + ]) # fmt:skip LOGGER.info('ANAT Found preprocessed T2w - skipping Stage 7') # Stages 8-10: Surface conversion and registration diff --git a/smriprep/workflows/base.py b/smriprep/workflows/base.py index 730fe1f656..c42aa82cb9 100644 --- a/smriprep/workflows/base.py +++ b/smriprep/workflows/base.py @@ -63,6 +63,7 @@ def init_smriprep_wf( work_dir, bids_filters, cifti_output, + standardize_with_T2w ): """ Create the execution graph of *sMRIPrep*, with a sub-workflow for each subject. @@ -156,6 +157,9 @@ def init_smriprep_wf( bids_filters : dict Provides finer specification of the pipeline input files through pybids entities filters. A dict with the following structure {:{:,...},...} + standardize_with_T2w : :obj:`bool` + Use T2w as a moving image channel in the spatial normalization to template + space(s). """ smriprep_wf = Workflow(name='smriprep_wf') @@ -196,6 +200,7 @@ def init_smriprep_wf( subject_id=subject_id, bids_filters=bids_filters, cifti_output=cifti_output, + standardize_with_T2w=standardize_with_T2w, ) single_subject_wf.config['execution']['crashdump_dir'] = os.path.join( @@ -233,6 +238,7 @@ def init_single_subject_wf( subject_id, bids_filters, cifti_output, + standardize_with_T2w, ): """ Create a single subject workflow. @@ -324,6 +330,9 @@ def init_single_subject_wf( bids_filters : dict Provides finer specification of the pipeline input files through pybids entities filters. A dict with the following structure {:{:,...},...} + standardize_with_T2w : :obj:`bool` + Use T2w as a moving image channel in the spatial normalization to template + space(s). Inputs ------ @@ -441,6 +450,7 @@ def init_single_subject_wf( skull_strip_template=skull_strip_template, spaces=spaces, cifti_output=cifti_output, + norm_add_T2w=standardize_with_T2w, ) # fmt:off diff --git a/smriprep/workflows/fit/registration.py b/smriprep/workflows/fit/registration.py index fc78575721..a216aebf89 100644 --- a/smriprep/workflows/fit/registration.py +++ b/smriprep/workflows/fit/registration.py @@ -41,6 +41,7 @@ def init_register_template_wf( sloppy, omp_nthreads, templates, + use_T2w=False, name='register_template_wf', ): """ @@ -171,8 +172,9 @@ def init_register_template_wf( ) # With the improvements from nipreps/niworkflows#342 this truncation is now necessary - trunc_mov = pe.Node( - ants.ImageMath(operation='TruncateImageIntensity', op2='0.01 0.999 256'), + trunc_mov = pe.MapNode( + ants.ImageMath(operation='TruncateImageIntensity', op2='0.01 0.999 255'), + iterfield='op1', name='trunc_mov', ) @@ -192,10 +194,19 @@ def init_register_template_wf( run_without_submitting=True, ) + include_t2w = pe.Node( + niu.Function(function=_include_t2w, output_names=['moving_image', 'get_T2w']), + name='include_t2w', + run_without_submitting=True, + ) + include_t2w.inputs.use_T2w = use_T2w + # fmt:off workflow.connect([ (inputnode, split_desc, [('template', 'template')]), - (inputnode, trunc_mov, [('moving_image', 'op1')]), + (inputnode, include_t2w, [('moving_image', 'moving_image')]), + (include_t2w, tf_select, [('get_T2w', 'get_T2w')]), + (include_t2w, trunc_mov, [('moving_image', 'op1')]), (inputnode, registration, [ ('moving_mask', 'moving_mask'), ('lesion_mask', 'lesion_mask')]), @@ -243,3 +254,12 @@ def _fmt_cohort(template, spec): if cohort is not None: template = f'{template}:cohort-{cohort}' return template, spec + + +def _include_t2w(moving_image, use_T2w=False): + islist = isinstance(moving_image, list) + if not use_T2w: + return moving_image[0] if islist else moving_image, False + + return moving_image, islist and len(moving_image) > 1 +