Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Add T2w as an additional channel in spatial normalization #452

Draft
wants to merge 11 commits into
base: master
Choose a base branch
from
6 changes: 6 additions & 0 deletions smriprep/cli/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,11 @@ def get_parser():
action='store_true',
help='treat dataset as longitudinal - may increase runtime',
)
g_conf.add_argument(
'--standardize-with-T2w',
action='store_true',
help='treat dataset as longitudinal - may increase runtime',
)

# ANTs options
g_ants = parser.add_argument_group('Specific options for ANTs registrations')
Expand Down Expand Up @@ -629,6 +634,7 @@ def build_workflow(opts, retval):
fs_no_resume=opts.fs_no_resume,
layout=layout,
longitudinal=opts.longitudinal,
standardize_with_T2w=opts.standardize_with_T2w,
low_mem=opts.low_mem,
msm_sulc=opts.msm_sulc,
omp_nthreads=omp_nthreads,
Expand Down
25 changes: 22 additions & 3 deletions smriprep/interfaces/templateflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,25 +39,29 @@


class _TemplateFlowSelectInputSpec(BaseInterfaceInputSpec):
template = traits.Str('MNI152NLin2009cAsym', mandatory=True, desc='Template ID')
template = traits.Str(mandatory=True, desc='Template ID')
atlas = InputMultiObject(traits.Str, desc='Specify an atlas')
cohort = InputMultiObject(traits.Either(traits.Str, traits.Int), desc='Specify a cohort')
resolution = InputMultiObject(traits.Int, desc='Specify a template resolution index')
template_spec = traits.DictStrAny(
{'atlas': None, 'cohort': None}, usedefault=True, desc='Template specifications'
)
get_T2w = traits.Bool(False, usedefault=True, desc='Get the T2w if available')


class _TemplateFlowSelectOutputSpec(TraitedSpec):
t1w_file = File(exists=True, desc='T1w template')
t2w_file = File(exists=True, desc='T2w template')
brain_mask = File(exists=True, desc="Template's brain mask")


class TemplateFlowSelect(SimpleInterface):
"""
Select TemplateFlow elements.

>>> select = TemplateFlowSelect(resolution=1)
Examples
--------
>>> select = TemplateFlowSelect(resolution=1, get_T2w=True)
>>> select.inputs.template = 'MNI152NLin2009cAsym'
>>> result = select.run()
>>> result.outputs.t1w_file # doctest: +ELLIPSIS
Expand All @@ -66,6 +70,9 @@ class TemplateFlowSelect(SimpleInterface):
>>> result.outputs.brain_mask # doctest: +ELLIPSIS
'.../tpl-MNI152NLin2009cAsym_res-01_desc-brain_mask.nii.gz'

>>> result.outputs.t2w_file # doctest: +ELLIPSIS
'.../tpl-MNI152NLin2009cAsym_res-01_T2w.nii.gz'

>>> select = TemplateFlowSelect()
>>> select.inputs.template = 'MNIPediatricAsym'
>>> select.inputs.template_spec = {'cohort': 5, 'resolution': 1}
Expand Down Expand Up @@ -94,6 +101,9 @@ class TemplateFlowSelect(SimpleInterface):
>>> result.outputs.t1w_file # doctest: +ELLIPSIS
'.../tpl-MNI305_T1w.nii.gz'

>>> bool(result.outputs.t2w_file)
False

"""

input_spec = _TemplateFlowSelectInputSpec
Expand All @@ -108,8 +118,14 @@ def _run_interface(self, runtime):
if isdefined(self.inputs.cohort):
specs['cohort'] = self.inputs.cohort

files = fetch_template_files(self.inputs.template, specs)
files = fetch_template_files(
self.inputs.template,
specs,
get_T2w=self.inputs.get_T2w,
)
self._results['t1w_file'] = files['t1w']
if self.inputs.get_T2w and 't2w' in files:
self._results['t2w_file'] = files['t2w']
self._results['brain_mask'] = files['mask']
return runtime

Expand Down Expand Up @@ -167,6 +183,7 @@ def fetch_template_files(
template: str,
specs: dict | None = None,
sloppy: bool = False,
get_T2w: bool = False,
) -> dict:
if specs is None:
specs = {}
Expand Down Expand Up @@ -203,6 +220,8 @@ def fetch_template_files(

files = {}
files['t1w'] = tf.get(name[0], desc=None, suffix='T1w', **specs)
if get_T2w and (t2w := tf.get(name[0], desc=None, suffix='T2w', **specs)):
files['t2w'] = t2w
files['mask'] = tf.get(name[0], desc='brain', suffix='mask', **specs) or tf.get(
name[0], label='brain', suffix='mask', **specs
)
Expand Down
43 changes: 42 additions & 1 deletion smriprep/workflows/anatomical.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ def init_anat_preproc_wf(
name: str = 'anat_preproc_wf',
skull_strip_fixed_seed: bool = False,
fs_no_resume: bool = False,
norm_add_T2w: bool = False,
):
"""
Stage the anatomical preprocessing steps of *sMRIPrep*.
Expand Down Expand Up @@ -190,6 +191,9 @@ def init_anat_preproc_wf(
EXPERT: Import pre-computed FreeSurfer reconstruction without resuming.
The user is responsible for ensuring that all necessary files are present.
(default: ``False``).
norm_add_T2w : :obj:`bool`
Use T2w as a moving image channel in the spatial normalization to template
space(s).

Inputs
------
Expand Down Expand Up @@ -283,6 +287,7 @@ def init_anat_preproc_wf(
omp_nthreads=omp_nthreads,
skull_strip_fixed_seed=skull_strip_fixed_seed,
fs_no_resume=fs_no_resume,
norm_add_T2w=norm_add_T2w,
)
template_iterator_wf = init_template_iterator_wf(spaces=spaces, sloppy=sloppy)
ds_std_volumes_wf = init_ds_anat_volumes_wf(
Expand Down Expand Up @@ -461,6 +466,7 @@ def init_anat_fit_wf(
name='anat_fit_wf',
skull_strip_fixed_seed: bool = False,
fs_no_resume: bool = False,
norm_add_T2w: bool = False,
):
"""
Stage the anatomical preprocessing steps of *sMRIPrep*.
Expand Down Expand Up @@ -541,6 +547,9 @@ def init_anat_fit_wf(
Do not use a random seed for skull-stripping - will ensure
run-to-run replicability when used with --omp-nthreads 1
(default: ``False``).
norm_add_T2w : :obj:`bool`
Use T2w as a moving image channel in the spatial normalization to template
space(s).

Inputs
------
Expand Down Expand Up @@ -994,6 +1003,7 @@ def init_anat_fit_wf(
sloppy=sloppy,
omp_nthreads=omp_nthreads,
templates=templates,
use_T2w=norm_add_T2w and t2w,
)
ds_template_registration_wf = init_ds_template_registration_wf(
output_dir=output_dir, image_type='T1w'
Expand All @@ -1002,7 +1012,6 @@ def init_anat_fit_wf(
# fmt:off
workflow.connect([
(inputnode, register_template_wf, [('roi', 'inputnode.lesion_mask')]),
(t1w_buffer, register_template_wf, [('t1w_preproc', 'inputnode.moving_image')]),
(refined_buffer, register_template_wf, [('t1w_mask', 'inputnode.moving_mask')]),
(sourcefile_buffer, ds_template_registration_wf, [
('source_files', 'inputnode.source_files')
Expand Down Expand Up @@ -1137,6 +1146,12 @@ def init_anat_fit_wf(
image_type='T2w',
name='t2w_template_wf',
)
register_template_wf = init_register_template_wf(
sloppy=sloppy,
omp_nthreads=omp_nthreads,
templates=templates,
use_T2w=norm_add_T2w and t2w,
)
bbreg = pe.Node(
fs.BBRegister(
contrast_type='t2',
Expand Down Expand Up @@ -1166,6 +1181,8 @@ def init_anat_fit_wf(
)
ds_t2w_preproc.inputs.SkullStripped = False

merge_t2w = pe.Node(niu.Merge(2), name='merge_t2w', run_without_submitting=True)

workflow.connect([
(inputnode, t2w_template_wf, [('t2w', 'inputnode.anat_files')]),
(t2w_template_wf, bbreg, [('outputnode.anat_ref', 'source_file')]),
Expand All @@ -1182,10 +1199,34 @@ def init_anat_fit_wf(
(inputnode, ds_t2w_preproc, [('t2w', 'source_file')]),
(t2w_resample, ds_t2w_preproc, [('output_image', 'in_file')]),
(ds_t2w_preproc, outputnode, [('out_file', 't2w_preproc')]),
(t1w_buffer, merge_t2w, [('t1w_preproc', 'in1')]),
(t2w_resample, merge_t2w, [('output_image', 'in2')]),
(merge_t2w, register_template_wf, [('out', 'inputnode.moving_image')]),
]) # fmt:skip
elif not t2w:
register_template_wf = init_register_template_wf(
sloppy=sloppy,
omp_nthreads=omp_nthreads,
templates=templates,
use_T2w=norm_add_T2w and t2w,
)
workflow.connect([
(t1w_buffer, register_template_wf, [('t1w_preproc', 'inputnode.moving_image')]),
])
LOGGER.info('ANAT No T2w images provided - skipping Stage 7')
else:
register_template_wf = init_register_template_wf(
sloppy=sloppy,
omp_nthreads=omp_nthreads,
templates=templates,
use_T2w=norm_add_T2w and t2w,
)
merge_t2w = pe.Node(niu.Merge(2), name='merge_t2w', run_without_submitting=True)
workflow.connect([
(t1w_buffer, merge_t2w, [('t1w_preproc', 'in1')]),
(inputnode, merge_t2w, [('t2w', 'in2')]),
(merge_t2w, register_template_wf, [('out', 'inputnode.moving_image')]),
]) # fmt:skip
LOGGER.info('ANAT Found preprocessed T2w - skipping Stage 7')

# Stages 8-10: Surface conversion and registration
Expand Down
10 changes: 10 additions & 0 deletions smriprep/workflows/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def init_smriprep_wf(
work_dir,
bids_filters,
cifti_output,
standardize_with_T2w
):
"""
Create the execution graph of *sMRIPrep*, with a sub-workflow for each subject.
Expand Down Expand Up @@ -156,6 +157,9 @@ def init_smriprep_wf(
bids_filters : dict
Provides finer specification of the pipeline input files through pybids entities filters.
A dict with the following structure {<suffix>:{<entity>:<filter>,...},...}
standardize_with_T2w : :obj:`bool`
Use T2w as a moving image channel in the spatial normalization to template
space(s).

"""
smriprep_wf = Workflow(name='smriprep_wf')
Expand Down Expand Up @@ -196,6 +200,7 @@ def init_smriprep_wf(
subject_id=subject_id,
bids_filters=bids_filters,
cifti_output=cifti_output,
standardize_with_T2w=standardize_with_T2w,
)

single_subject_wf.config['execution']['crashdump_dir'] = os.path.join(
Expand Down Expand Up @@ -233,6 +238,7 @@ def init_single_subject_wf(
subject_id,
bids_filters,
cifti_output,
standardize_with_T2w,
):
"""
Create a single subject workflow.
Expand Down Expand Up @@ -324,6 +330,9 @@ def init_single_subject_wf(
bids_filters : dict
Provides finer specification of the pipeline input files through pybids entities filters.
A dict with the following structure {<suffix>:{<entity>:<filter>,...},...}
standardize_with_T2w : :obj:`bool`
Use T2w as a moving image channel in the spatial normalization to template
space(s).

Inputs
------
Expand Down Expand Up @@ -441,6 +450,7 @@ def init_single_subject_wf(
skull_strip_template=skull_strip_template,
spaces=spaces,
cifti_output=cifti_output,
norm_add_T2w=standardize_with_T2w,
)

# fmt:off
Expand Down
26 changes: 23 additions & 3 deletions smriprep/workflows/fit/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def init_register_template_wf(
sloppy,
omp_nthreads,
templates,
use_T2w=False,
name='register_template_wf',
):
"""
Expand Down Expand Up @@ -171,8 +172,9 @@ def init_register_template_wf(
)

# With the improvements from nipreps/niworkflows#342 this truncation is now necessary
trunc_mov = pe.Node(
ants.ImageMath(operation='TruncateImageIntensity', op2='0.01 0.999 256'),
trunc_mov = pe.MapNode(
ants.ImageMath(operation='TruncateImageIntensity', op2='0.01 0.999 255'),
iterfield='op1',
name='trunc_mov',
)

Expand All @@ -192,10 +194,19 @@ def init_register_template_wf(
run_without_submitting=True,
)

include_t2w = pe.Node(
niu.Function(function=_include_t2w, output_names=['moving_image', 'get_T2w']),
name='include_t2w',
run_without_submitting=True,
)
include_t2w.inputs.use_T2w = use_T2w

# fmt:off
workflow.connect([
(inputnode, split_desc, [('template', 'template')]),
(inputnode, trunc_mov, [('moving_image', 'op1')]),
(inputnode, include_t2w, [('moving_image', 'moving_image')]),
(include_t2w, tf_select, [('get_T2w', 'get_T2w')]),
(include_t2w, trunc_mov, [('moving_image', 'op1')]),
(inputnode, registration, [
('moving_mask', 'moving_mask'),
('lesion_mask', 'lesion_mask')]),
Expand Down Expand Up @@ -243,3 +254,12 @@ def _fmt_cohort(template, spec):
if cohort is not None:
template = f'{template}:cohort-{cohort}'
return template, spec


def _include_t2w(moving_image, use_T2w=False):
islist = isinstance(moving_image, list)
if not use_T2w:
return moving_image[0] if islist else moving_image, False

return moving_image, islist and len(moving_image) > 1