Skip to content

Commit

Permalink
spacing and orientation; revise transforms cropping and zooming
Browse files Browse the repository at this point in the history
update input validation
  • Loading branch information
wyli committed Mar 9, 2020
1 parent 7304182 commit 1996d12
Show file tree
Hide file tree
Showing 11 changed files with 595 additions and 90 deletions.
8 changes: 5 additions & 3 deletions monai/data/nifti_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import nibabel as nib

import numpy as np
from torch.utils.data import Dataset
from torch.utils.data._utils.collate import np_str_obj_array_pattern
from monai.utils.module import export

from monai.data.utils import correct_nifti_header_if_necessary
from monai.transforms.compose import Randomizable
from monai.utils.module import export


def load_nifti(filename_or_obj, as_closest_canonical=False, image_only=True, dtype=None):
Expand All @@ -38,6 +39,7 @@ def load_nifti(filename_or_obj, as_closest_canonical=False, image_only=True, dty
"""

img = nib.load(filename_or_obj)
img = correct_nifti_header_if_necessary(img)

header = dict(img.header)
header['filename_or_obj'] = filename_or_obj
Expand Down
62 changes: 62 additions & 0 deletions monai/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
import math
from itertools import starmap, product
from torch.utils.data._utils.collate import default_collate
Expand Down Expand Up @@ -191,3 +192,64 @@ def list_data_collate(batch):
elem = batch[0]
data = [i for k in batch for i in k] if isinstance(elem, list) else batch
return default_collate(data)


def correct_nifti_header_if_necessary(img_nii):
"""
check nifti object header's format, update the header if needed.
in the updated image pixdim matches the affine.
Args:
img (nifti image object)
"""
dim = img_nii.header['dim'][0]
if dim >= 5:
return img_nii # do nothing for high-dimensional array
# check that affine matches zooms
pixdim = np.asarray(img_nii.header.get_zooms())[:dim]
norm_affine = np.sqrt(np.sum(np.square(img_nii.affine[:dim, :dim]), 0))
if np.allclose(pixdim, norm_affine):
return img_nii
if hasattr(img_nii, 'get_sform'):
return rectify_header_sform_qform(img_nii)
return img_nii


def rectify_header_sform_qform(img_nii):
"""
Look at the sform and qform of the nifti object and correct it if any
incompatibilities with pixel dimensions
Adapted from https://github.com/NifTK/NiftyNet/blob/v0.6.0/niftynet/io/misc_io.py
"""
d = img_nii.header['dim'][0]
pixdim = np.asarray(img_nii.header.get_zooms())[:d]
sform, qform = img_nii.get_sform(), img_nii.get_qform()
norm_sform = np.sqrt(np.sum(np.square(sform[:d, :d]), 0))
norm_qform = np.sqrt(np.sum(np.square(qform[:d, :d]), 0))
sform_mismatch = not np.allclose(norm_sform, pixdim)
qform_mismatch = not np.allclose(norm_qform, pixdim)

if img_nii.header['sform_code'] != 0:
if not sform_mismatch:
return img_nii
if not qform_mismatch:
img_nii.set_sform(img_nii.get_qform())
return img_nii
if img_nii.header['qform_code'] != 0:
if not qform_mismatch:
return img_nii
if not sform_mismatch:
img_nii.set_qform(img_nii.get_sform())
return img_nii

norm_affine = np.sqrt(np.sum(np.square(img_nii.affine[:, :3]), 0))
to_divide = np.tile(np.expand_dims(np.append(norm_affine, 1), axis=1), [1, 4])
pixdim = np.append(pixdim, [1.] * (4 - len(pixdim)))
to_multiply = np.tile(np.expand_dims(pixdim, axis=1), [1, 4])
affine = img_nii.affine / to_divide.T * to_multiply.T
warnings.warn('Modifying image affine from {} to {}'.format(img_nii.affine, affine))

img_nii.set_sform(affine)
img_nii.set_qform(affine)
return img_nii
104 changes: 86 additions & 18 deletions monai/transforms/composables.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@
import monai
from monai.data.utils import get_random_patch, get_valid_patch_size
from monai.transforms.compose import Randomizable, Transform
from monai.transforms.transforms import Rotate90, SpatialCrop, AddChannel
from monai.utils.misc import ensure_tuple
from monai.transforms.transforms import (AddChannel, Orientation, Rotate90, Spacing, SpatialCrop)
from monai.transforms.utils import generate_pos_neg_label_crop_centers
from monai.utils.aliases import alias
from monai.utils.misc import ensure_tuple

export = monai.utils.export("monai.transforms")

Expand Down Expand Up @@ -54,6 +55,85 @@ def __init__(self, keys):


@export
@alias('SpacingD', 'SpacingDict')
class Spacingd(MapTransform):
"""
dictionary-based wrapper of :class: `monai.transforms.transforms.Spacing`.
"""

def __init__(self, keys, affine_key, pixdim, interp_order=2, keep_shape=False, output_key='spacing'):
"""
Args:
affine_key (hashable): the key to the original affine.
The affine will be used to compute input data's pixdim.
pixdim (sequence of floats): output voxel spacing.
interp_order (int or sequence of ints): int: the same interpolation order
for all data indexed by `self,keys`; sequence of ints, should
correspond to an interpolation order for each data item indexed
by `self.keys` respectively.
keep_shape (bool): whether to maintain the original spatial shape
after resampling. Defaults to False.
output_key (hashable): key to be added to the output dictionary to track
the pixdim status.
"""
MapTransform.__init__(self, keys)
self.affine_key = affine_key
self.spacing_transform = Spacing(pixdim, keep_shape=keep_shape)
interp_order = ensure_tuple(interp_order)
self.interp_order = interp_order \
if len(interp_order) == len(self.keys) else interp_order * len(self.keys)
print(self.interp_order)
self.output_key = output_key

def __call__(self, data):
d = dict(data)
affine = d[self.affine_key]
original_pixdim, new_pixdim = None, None
for key, interp in zip(self.keys, self.interp_order):
d[key], original_pixdim, new_pixdim = self.spacing_transform(d[key], affine, interp_order=interp)
d[self.output_key] = {'original_pixdim': original_pixdim, 'current_pixdim': new_pixdim}
return d


@export
@alias('OrientationD', 'OrientationDict')
class Orientationd(MapTransform):
"""
dictionary-based wrapper of :class: `monai.transforms.transforms.Orientation`.
"""

def __init__(self, keys, affine_key, axcodes, labels=None, output_key='orientation'):
"""
Args:
affine_key (hashable): the key to the original affine.
The affine will be used to compute input data's orientation.
axcodes (N elements sequence): for spatial ND input's orientation.
e.g. axcodes='RAS' represents 3D orientation:
(Left, Right), (Posterior, Anterior), (Inferior, Superior).
default orientation labels options are: 'L' and 'R' for the first dimension,
'P' and 'A' for the second, 'I' and 'S' for the third.
labels : optional, None or sequence of (2,) sequences
(2,) sequences are labels for (beginning, end) of output axis.
see: ``nibabel.orientations.ornt2axcodes``.
"""
MapTransform.__init__(self, keys)
self.affine_key = affine_key
self.orientation_transform = Orientation(axcodes=axcodes, labels=labels)
self.output_key = output_key

def __call__(self, data):
d = dict(data)
affine = d[self.affine_key]
original_ornt, new_ornt = None, None
for key in self.keys:
d[key], original_ornt, new_ornt = self.orientation_transform(d[key], affine)
d[self.output_key] = {'original_ornt': original_ornt, 'current_ornt': new_ornt}
return d


@export
@alias('Rotate90D', 'Rotate90Dict')
class Rotate90d(MapTransform):
"""
dictionary-based wrapper of Rotate90.
Expand All @@ -79,6 +159,7 @@ def __call__(self, data):


@export
@alias('UniformRandomPatchD', 'UniformRandomPatchD')
class UniformRandomPatchd(Randomizable, MapTransform):
"""
Selects a patch of the given size chosen at a uniformly random position in the image.
Expand Down Expand Up @@ -106,6 +187,7 @@ def __call__(self, data):


@export
@alias('RandRotate90D', 'RandRotate90Dict')
class RandRotate90d(Randomizable, MapTransform):
"""
With probability `prob`, input arrays are rotated by 90 degrees
Expand Down Expand Up @@ -150,6 +232,7 @@ def __call__(self, data):


@export
@alias('AddChannelD', 'AddChannelDict')
class AddChanneld(MapTransform):
"""
dictionary-based wrapper of AddChannel.
Expand All @@ -172,6 +255,7 @@ def __call__(self, data):


@export
@alias('RandCropByPosNegLabelD', 'RandCropByPosNegLabelDict')
class RandCropByPosNegLabeld(Randomizable, MapTransform):
"""
Crop random fixed sized regions with the center being a foreground or background voxel
Expand Down Expand Up @@ -224,19 +308,3 @@ def __call__(self, data):
results[i][key] = data[key]

return results


# if __name__ == "__main__":
# import numpy as np
# data = {
# 'img': np.array((1, 2, 3, 4)).reshape((1, 2, 2)),
# 'seg': np.array((1, 2, 3, 4)).reshape((1, 2, 2)),
# 'affine': 3,
# 'dtype': 4,
# 'unused': 5,
# }
# rotator = RandRotate90d(keys=['img', 'seg'], prob=0.8)
# # rotator.set_random_state(1234)
# data_result = rotator(data)
# print(data_result.keys())
# print(data_result['img'], data_result['seg'])
Loading

0 comments on commit 1996d12

Please sign in to comment.