From 1996d12900ed2e023592f757dd9908f4b5cfedc8 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Fri, 6 Mar 2020 22:11:01 +0000 Subject: [PATCH] spacing and orientation; revise transforms cropping and zooming update input validation --- monai/data/nifti_reader.py | 8 +- monai/data/utils.py | 62 +++++++++ monai/transforms/composables.py | 104 +++++++++++--- monai/transforms/transforms.py | 238 +++++++++++++++++++++++--------- tests/test_header_correct.py | 36 +++++ tests/test_orientation.py | 38 +++++ tests/test_orientationd.py | 55 ++++++++ tests/test_spacing.py | 47 +++++++ tests/test_spacingd.py | 63 +++++++++ tests/test_spatial_crop.py | 20 ++- tests/test_zoom.py | 14 +- 11 files changed, 595 insertions(+), 90 deletions(-) create mode 100644 tests/test_header_correct.py create mode 100644 tests/test_orientation.py create mode 100644 tests/test_orientationd.py create mode 100644 tests/test_spacing.py create mode 100644 tests/test_spacingd.py diff --git a/monai/data/nifti_reader.py b/monai/data/nifti_reader.py index 753691de208..1e497c5ce1b 100644 --- a/monai/data/nifti_reader.py +++ b/monai/data/nifti_reader.py @@ -9,13 +9,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -import numpy as np import nibabel as nib - +import numpy as np from torch.utils.data import Dataset from torch.utils.data._utils.collate import np_str_obj_array_pattern -from monai.utils.module import export + +from monai.data.utils import correct_nifti_header_if_necessary from monai.transforms.compose import Randomizable +from monai.utils.module import export def load_nifti(filename_or_obj, as_closest_canonical=False, image_only=True, dtype=None): @@ -38,6 +39,7 @@ def load_nifti(filename_or_obj, as_closest_canonical=False, image_only=True, dty """ img = nib.load(filename_or_obj) + img = correct_nifti_header_if_necessary(img) header = dict(img.header) header['filename_or_obj'] = filename_or_obj diff --git a/monai/data/utils.py b/monai/data/utils.py index 1e7de42141f..81f9ac8c567 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -9,6 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import warnings import math from itertools import starmap, product from torch.utils.data._utils.collate import default_collate @@ -191,3 +192,64 @@ def list_data_collate(batch): elem = batch[0] data = [i for k in batch for i in k] if isinstance(elem, list) else batch return default_collate(data) + + +def correct_nifti_header_if_necessary(img_nii): + """ + check nifti object header's format, update the header if needed. + in the updated image pixdim matches the affine. 
+
+    Args:
+        img_nii (nifti image object): the input nifti image to be checked.
+    """
+    dim = img_nii.header['dim'][0]
+    if dim >= 5:
+        return img_nii  # do nothing for high-dimensional array
+    # check that affine matches zooms
+    pixdim = np.asarray(img_nii.header.get_zooms())[:dim]
+    norm_affine = np.sqrt(np.sum(np.square(img_nii.affine[:dim, :dim]), 0))
+    if np.allclose(pixdim, norm_affine):
+        return img_nii
+    if hasattr(img_nii, 'get_sform'):
+        return rectify_header_sform_qform(img_nii)
+    return img_nii
+
+
+def rectify_header_sform_qform(img_nii):
+    """
+    Look at the sform and qform of the nifti object and correct it if there are
+    any incompatibilities with the pixel dimensions.
+
+    Adapted from https://github.com/NifTK/NiftyNet/blob/v0.6.0/niftynet/io/misc_io.py
+    """
+    d = img_nii.header['dim'][0]
+    pixdim = np.asarray(img_nii.header.get_zooms())[:d]
+    sform, qform = img_nii.get_sform(), img_nii.get_qform()
+    norm_sform = np.sqrt(np.sum(np.square(sform[:d, :d]), 0))
+    norm_qform = np.sqrt(np.sum(np.square(qform[:d, :d]), 0))
+    sform_mismatch = not np.allclose(norm_sform, pixdim)
+    qform_mismatch = not np.allclose(norm_qform, pixdim)
+
+    if img_nii.header['sform_code'] != 0:
+        if not sform_mismatch:
+            return img_nii
+        if not qform_mismatch:
+            img_nii.set_sform(img_nii.get_qform())
+            return img_nii
+    if img_nii.header['qform_code'] != 0:
+        if not qform_mismatch:
+            return img_nii
+        if not sform_mismatch:
+            img_nii.set_qform(img_nii.get_sform())
+            return img_nii
+
+    norm_affine = np.sqrt(np.sum(np.square(img_nii.affine[:, :3]), 0))
+    to_divide = np.tile(np.expand_dims(np.append(norm_affine, 1), axis=1), [1, 4])
+    pixdim = np.append(pixdim, [1.] * (4 - len(pixdim)))
+    to_multiply = np.tile(np.expand_dims(pixdim, axis=1), [1, 4])
+    affine = img_nii.affine / to_divide.T * to_multiply.T
+    warnings.warn('Modifying image affine from {} to {}'.format(img_nii.affine, affine))
+
+    img_nii.set_sform(affine)
+    img_nii.set_qform(affine)
+    return img_nii
diff --git a/monai/transforms/composables.py b/monai/transforms/composables.py
index d177bc98e14..4a6b06c922d 100644
--- a/monai/transforms/composables.py
+++ b/monai/transforms/composables.py
@@ -18,9 +18,10 @@
 import monai
 from monai.data.utils import get_random_patch, get_valid_patch_size
 from monai.transforms.compose import Randomizable, Transform
-from monai.transforms.transforms import Rotate90, SpatialCrop, AddChannel
-from monai.utils.misc import ensure_tuple
+from monai.transforms.transforms import (AddChannel, Orientation, Rotate90, Spacing, SpatialCrop)
 from monai.transforms.utils import generate_pos_neg_label_crop_centers
+from monai.utils.aliases import alias
+from monai.utils.misc import ensure_tuple
 
 export = monai.utils.export("monai.transforms")
 
@@ -54,6 +55,84 @@ def __init__(self, keys):
 
 
 @export
+@alias('SpacingD', 'SpacingDict')
+class Spacingd(MapTransform):
+    """
+    dictionary-based wrapper of :class:`monai.transforms.transforms.Spacing`.
+    """
+
+    def __init__(self, keys, affine_key, pixdim, interp_order=2, keep_shape=False, output_key='spacing'):
+        """
+        Args:
+            affine_key (hashable): the key to the original affine.
+                The affine will be used to compute input data's pixdim.
+            pixdim (sequence of floats): output voxel spacing.
+            interp_order (int or sequence of ints): if an int, the same interpolation
+                order is used for all data indexed by `self.keys`; if a sequence of ints,
+                each value specifies the interpolation order of the corresponding data
+                item indexed by `self.keys` respectively.
+            keep_shape (bool): whether to maintain the original spatial shape
+                after resampling. Defaults to False.
+            output_key (hashable): key to be added to the output dictionary to track
+                the pixdim status.
+
+        """
+        MapTransform.__init__(self, keys)
+        self.affine_key = affine_key
+        self.spacing_transform = Spacing(pixdim, keep_shape=keep_shape)
+        interp_order = ensure_tuple(interp_order)
+        self.interp_order = interp_order \
+            if len(interp_order) == len(self.keys) else interp_order * len(self.keys)
+        self.output_key = output_key
+
+    def __call__(self, data):
+        d = dict(data)
+        affine = d[self.affine_key]
+        original_pixdim, new_pixdim = None, None
+        for key, interp in zip(self.keys, self.interp_order):
+            d[key], original_pixdim, new_pixdim = self.spacing_transform(d[key], affine, interp_order=interp)
+        d[self.output_key] = {'original_pixdim': original_pixdim, 'current_pixdim': new_pixdim}
+        return d
+
+
+@export
+@alias('OrientationD', 'OrientationDict')
+class Orientationd(MapTransform):
+    """
+    dictionary-based wrapper of :class:`monai.transforms.transforms.Orientation`.
+    """
+
+    def __init__(self, keys, affine_key, axcodes, labels=None, output_key='orientation'):
+        """
+        Args:
+            affine_key (hashable): the key to the original affine.
+                The affine will be used to compute input data's orientation.
+            axcodes (N elements sequence): for spatial ND input's orientation.
+                e.g. axcodes='RAS' represents 3D orientation:
+                (Left, Right), (Posterior, Anterior), (Inferior, Superior).
+                default orientation labels options are: 'L' and 'R' for the first dimension,
+                'P' and 'A' for the second, 'I' and 'S' for the third.
+            labels : optional, None or sequence of (2,) sequences
+                (2,) sequences are labels for (beginning, end) of output axis.
+                see: ``nibabel.orientations.ornt2axcodes``.
+        """
+        MapTransform.__init__(self, keys)
+        self.affine_key = affine_key
+        self.orientation_transform = Orientation(axcodes=axcodes, labels=labels)
+        self.output_key = output_key
+
+    def __call__(self, data):
+        d = dict(data)
+        affine = d[self.affine_key]
+        original_ornt, new_ornt = None, None
+        for key in self.keys:
+            d[key], original_ornt, new_ornt = self.orientation_transform(d[key], affine)
+        d[self.output_key] = {'original_ornt': original_ornt, 'current_ornt': new_ornt}
+        return d
+
+
+@export
+@alias('Rotate90D', 'Rotate90Dict')
 class Rotate90d(MapTransform):
     """
     dictionary-based wrapper of Rotate90.
@@ -79,6 +159,7 @@ def __call__(self, data):
 
 
 @export
+@alias('UniformRandomPatchD', 'UniformRandomPatchDict')
 class UniformRandomPatchd(Randomizable, MapTransform):
     """
     Selects a patch of the given size chosen at a uniformly random position in the image.
@@ -106,6 +187,7 @@ def __call__(self, data):
 
 
 @export
+@alias('RandRotate90D', 'RandRotate90Dict')
 class RandRotate90d(Randomizable, MapTransform):
     """
     With probability `prob`, input arrays are rotated by 90 degrees
@@ -150,6 +232,7 @@ def __call__(self, data):
 
 
 @export
+@alias('AddChannelD', 'AddChannelDict')
 class AddChanneld(MapTransform):
     """
     dictionary-based wrapper of AddChannel.
@@ -172,6 +255,7 @@ def __call__(self, data):
 
 
 @export
+@alias('RandCropByPosNegLabelD', 'RandCropByPosNegLabelDict')
 class RandCropByPosNegLabeld(Randomizable, MapTransform):
     """
     Crop random fixed sized regions with the center being a foreground or background voxel
@@ -224,19 +308,3 @@ def __call__(self, data):
                     results[i][key] = data[key]
 
         return results
-
-
-# if __name__ == "__main__":
-#     import numpy as np
-#     data = {
-#         'img': np.array((1, 2, 3, 4)).reshape((1, 2, 2)),
-#         'seg': np.array((1, 2, 3, 4)).reshape((1, 2, 2)),
-#         'affine': 3,
-#         'dtype': 4,
-#         'unused': 5,
-#     }
-#     rotator = RandRotate90d(keys=['img', 'seg'], prob=0.8)
-#     # rotator.set_random_state(1234)
-#     data_result = rotator(data)
-#     print(data_result.keys())
-#     print(data_result['img'], data_result['seg'])
diff --git a/monai/transforms/transforms.py b/monai/transforms/transforms.py
index bec727e8bb6..296440b000b 100644
--- a/monai/transforms/transforms.py
+++ b/monai/transforms/transforms.py
@@ -13,10 +13,11 @@
 https://github.com/Project-MONAI/MONAI/wiki/MONAI_Design
 """
 
+import nibabel as nib
 import numpy as np
+import scipy.ndimage
 import torch
 from skimage.transform import resize
-import scipy.ndimage
 
 import monai
 from monai.data.utils import get_random_patch, get_valid_patch_size
@@ -29,6 +30,110 @@
 export = monai.utils.export("monai.transforms")
 
 
+@export
+class Spacing:
+    """
+    Resample the input image into the specified `pixdim`.
+    """
+
+    def __init__(self, pixdim, keep_shape=False):
+        """
+        Args:
+            pixdim (sequence of floats): output voxel spacing.
+            keep_shape (bool): whether to maintain the original spatial shape
+                after resampling. Defaults to False.
+        """
+        self.pixdim = pixdim
+        self.keep_shape = keep_shape
+        self.original_pixdim = pixdim
+
+    def __call__(self, data_array, original_affine=None, original_pixdim=None, interp_order=1):
+        """
+        Args:
+            data_array (ndarray): in shape (num_channels, H[, W, ...]).
+            original_affine (4x4 matrix): original affine.
+            original_pixdim (sequence of floats): original voxel spacing.
+            interp_order (int): the order of the spline interpolation, default is 1.
+                The order has to be in the range 0-5.
+                https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.zoom.html
+        Returns:
+            resampled array (in spacing: `self.pixdim`), original pixdim, current pixdim.
+        """
+        if original_affine is None and original_pixdim is None:
+            raise ValueError('please provide either original_affine or original_pixdim.')
+        spatial_rank = data_array.ndim - 1
+        if original_affine is not None:
+            affine = np.array(original_affine, dtype=np.float64, copy=True)
+            if not affine.shape == (4, 4):
+                raise ValueError('`original_affine` must be 4 x 4.')
+            original_pixdim = np.sqrt(np.sum(np.square(affine[:spatial_rank, :spatial_rank]), 1))
+
+        inp_d = np.asarray(original_pixdim)[:spatial_rank]
+        if inp_d.size < spatial_rank:
+            inp_d = np.append(inp_d, [1.] * (spatial_rank - inp_d.size))
+        out_d = np.asarray(self.pixdim)[:spatial_rank]
+        if out_d.size < spatial_rank:
+            out_d = np.append(out_d, [1.] * (spatial_rank - out_d.size))
+
+        self.original_pixdim, self.pixdim = inp_d, out_d
+        scale = inp_d / out_d
+        if not np.isfinite(scale).all():
+            raise ValueError('Unknown pixdims: source {}, target {}'.format(inp_d, out_d))
+        zoom_ = monai.transforms.Zoom(scale, order=interp_order, mode='nearest', keep_size=self.keep_shape)
+        return zoom_(data_array), self.original_pixdim, self.pixdim
+
+
+@export
+class Orientation:
+    """
+    Change the input image's orientation into the specified orientation based on `axcodes`.
+ """ + + def __init__(self, axcodes, labels=None): + """ + Args: + axcodes (N elements sequence): for spatial ND input's orientation. + e.g. axcodes='RAS' represents 3D orientation: + (Left, Right), (Posterior, Anterior), (Inferior, Superior). + default orientation labels options are: 'L' and 'R' for the first dimension, + 'P' and 'A' for the second, 'I' and 'S' for the third. + labels : optional, None or sequence of (2,) sequences + (2,) sequences are labels for (beginning, end) of output axis. + see: ``nibabel.orientations.ornt2axcodes``. + """ + self.axcodes = axcodes + self.labels = labels + + def __call__(self, data_array, original_affine=None, original_axcodes=None): + """ + if `original_affine` is provided, the orientation is computed from the affine. + + Args: + data_array (ndarray): in shape (num_channels, H[, W, ...]). + original_affine (4x4 matrix): original affine. + original_axcodes (N elements sequence): for spatial ND input's orientation. + Returns: + data_array (reoriented in `self.axcodes`), original axcodes, current axcodes. + """ + if original_affine is None and original_axcodes is None: + raise ValueError('please provide either original_affine or original_axcodes.') + spatial_rank = len(data_array.shape) - 1 + if original_affine is not None: + affine = np.array(original_affine, dtype=np.float64, copy=True) + if not affine.shape == (4, 4): + raise ValueError('`original_affine` must be 4 x 4.') + original_axcodes = nib.aff2axcodes(original_affine, labels=self.labels) + original_axcodes = original_axcodes[:spatial_rank] + self.axcodes = self.axcodes[:spatial_rank] + src = nib.orientations.axcodes2ornt(original_axcodes, labels=self.labels) + dst = nib.orientations.axcodes2ornt(self.axcodes) + spatial_ornt = nib.orientations.ornt_transform(src, dst) + spatial_ornt[:, 1] += 1 # skip channel dim + ornt = np.concatenate([np.array([[0, 1]]), spatial_ornt]) + data_array = nib.orientations.apply_orientation(data_array, ornt) + return data_array, original_axcodes, self.axcodes + + @export class AddChannel: """ @@ -124,8 +229,7 @@ class Resize: """ def __init__(self, output_shape, order=1, mode='reflect', cval=0, - clip=True, preserve_range=True, - anti_aliasing=True, anti_aliasing_sigma=None): + clip=True, preserve_range=True, anti_aliasing=True, anti_aliasing_sigma=None): assert isinstance(order, int), "order must be integer." self.output_shape = output_shape self.order = order @@ -148,7 +252,7 @@ def __call__(self, img): class Rotate: """ Rotates an input image by given angle. Uses scipy.ndimage.rotate. For more details, see - http://lagrange.univ-lyon1.fr/docs/scipy/0.17.1/generated/scipy.ndimage.rotate.html. + https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.rotate.html Args: angle (float): Rotation angle in degrees. @@ -163,8 +267,7 @@ class Rotate: prefiter (bool): Apply spline_filter before interpolation. Default: True. 
""" - def __init__(self, angle, axes=(1, 2), reshape=True, order=1, - mode='constant', cval=0, prefilter=True): + def __init__(self, angle, axes=(1, 2), reshape=True, order=1, mode='constant', cval=0, prefilter=True): self.angle = angle self.reshape = reshape self.order = order @@ -175,8 +278,7 @@ def __init__(self, angle, axes=(1, 2), reshape=True, order=1, def __call__(self, img): return scipy.ndimage.rotate(img, self.angle, self.axes, - reshape=self.reshape, order=self.order, - mode=self.mode, cval=self.cval, + reshape=self.reshape, order=self.order, mode=self.mode, cval=self.cval, prefilter=self.prefilter) @@ -186,8 +288,9 @@ class Zoom: For details, please see https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.zoom.html. Args: - zoom (float or sequence): The zoom factor along the axes. If a float, zoom is the same for each axis. - If a sequence, zoom should contain one value for each axis. + zoom (float or sequence): The zoom factor along the spatial axes. + If a float, zoom is the same for each spatial axis. + If a sequence, zoom should contain one value for each spatial axis. order (int): order of interpolation. Default=3. mode (str): Determines how input is extended beyond boundaries. Default is 'constant'. cval (scalar, optional): Value to fill past edges. Default is 0. @@ -195,6 +298,7 @@ class Zoom: 'wrap' and 'reflect'. Defaults to cpu for these cases or if cupyx not found. keep_size (bool): Should keep original size (pad if needed). """ + def __init__(self, zoom, order=3, mode='constant', cval=0, prefilter=True, use_gpu=False, keep_size=False): assert isinstance(order, int), "Order must be integer." self.zoom = zoom @@ -205,41 +309,59 @@ def __init__(self, zoom, order=3, mode='constant', cval=0, prefilter=True, use_g self.use_gpu = use_gpu self.keep_size = keep_size - def __call__(self, img): - zoomed = None if self.use_gpu: try: - import cupy from cupyx.scipy.ndimage import zoom as zoom_gpu - zoomed_gpu = zoom_gpu(cupy.array(img), zoom=self.zoom, order=self.order, - mode=self.mode, cval=self.cval, prefilter=self.prefilter) - zoomed = cupy.asnumpy(zoomed_gpu) - except ModuleNotFoundError: + self._zoom = zoom_gpu + except ImportError: print('For GPU zoom, please install cupy. Defaulting to cpu.') - except NotImplementedError: - print("Defaulting to CPU. cupyx doesn't support order > 1 and modes 'wrap' or 'reflect'.") - - if zoomed is None: - zoomed = scipy.ndimage.zoom(img, zoom=self.zoom, order=self.order, - mode=self.mode, cval=self.cval, prefilter=self.prefilter) - - # Crops to original size or pads. 
-        if self.keep_size:
-            shape = img.shape
-            pad_vec = [[0, 0]] * len(shape)
-            crop_vec = list(zoomed.shape)
-            for d in range(len(shape)):
-                if zoomed.shape[d] > shape[d]:
-                    crop_vec[d] = shape[d]
-                elif zoomed.shape[d] < shape[d]:
-                    # pad_vec[d] = [0, shape[d] - zoomed.shape[d]]
-                    pad_h = (float(shape[d]) - float(zoomed.shape[d])) / 2
-                    pad_vec[d] = [int(np.floor(pad_h)), int(np.ceil(pad_h))]
-            zoomed = zoomed[0:crop_vec[0], 0:crop_vec[1], 0:crop_vec[2]]
-            zoomed = np.pad(zoomed, pad_vec, mode='constant', constant_values=self.cval)
-
-        return zoomed
+                self._zoom = scipy.ndimage.zoom
+                self.use_gpu = False
+        else:
+            self._zoom = scipy.ndimage.zoom
+
+    def __call__(self, img):
+        """
+        Args:
+            img (ndarray): channel-first array, must be in the shape (num_channels, H[, W, ...]).
+        """
+        zoomed = []
+        if self.use_gpu:
+            import cupy
+            for channel in cupy.array(img):
+                zoom_channel = self._zoom(channel,
+                                          zoom=self.zoom,
+                                          order=self.order,
+                                          mode=self.mode,
+                                          cval=self.cval,
+                                          prefilter=self.prefilter)
+                zoomed.append(cupy.asnumpy(zoom_channel))
+        else:
+            for channel in img:
+                zoomed.append(
+                    self._zoom(channel,
+                               zoom=self.zoom,
+                               order=self.order,
+                               mode=self.mode,
+                               cval=self.cval,
+                               prefilter=self.prefilter))
+        zoomed = np.stack(zoomed)
+
+        if not self.keep_size or np.allclose(img.shape, zoomed.shape):
+            return zoomed
+
+        pad_vec = [[0, 0]] * len(img.shape)
+        slice_vec = [slice(None)] * len(img.shape)
+        for idx, (od, zd) in enumerate(zip(img.shape, zoomed.shape)):
+            diff = od - zd
+            half = abs(diff) // 2
+            if diff > 0:  # need padding
+                pad_vec[idx] = [half, diff - half]
+            elif diff < 0:  # need slicing
+                slice_vec[idx] = slice(half, half + od)
+        zoomed = np.pad(zoomed, pad_vec, mode='constant', constant_values=self.cval)
+        return zoomed[tuple(slice_vec)]
 
 
 @export
@@ -393,7 +515,7 @@ def __call__(self, img):
 @export
 class SpatialCrop:
     """General purpose cropper to produce sub-volume region of interest (ROI).
-    It can support to crop 1, 2 or 3 dimensions spatial data.
+    It can crop ND spatial (channel-first) data.
     Either a center and size must be provided, or alternatively if center and size
     are not provided, the start and end coordinates of the ROI must be provided.
     The sub-volume must sit the within original image.
@@ -410,37 +532,27 @@ def __init__(self, roi_center=None, roi_size=None, roi_start=None, roi_end=None)
             roi_end (list or tuple): voxel coordinates for end of the crop ROI.
         """
         if roi_center is not None and roi_size is not None:
-            assert isinstance(roi_center, (list, tuple)), 'roi_center must be list or tuple.'
-            assert isinstance(roi_size, (list, tuple)), 'roi_size must be list or tuple.'
-            assert all(x > 0 for x in roi_center), 'all elements of roi_center must be positive.'
-            assert all(x > 0 for x in roi_size), 'all elements of roi_size must be positive.'
             roi_center = np.asarray(roi_center, dtype=np.uint16)
             roi_size = np.asarray(roi_size, dtype=np.uint16)
            self.roi_start = np.subtract(roi_center, np.floor_divide(roi_size, 2))
            self.roi_end = np.add(self.roi_start, roi_size)
         else:
             assert roi_start is not None and roi_end is not None, 'roi_start and roi_end must be provided.'
-            assert isinstance(roi_start, (list, tuple)), 'roi_start must be list or tuple.'
-            assert isinstance(roi_end, (list, tuple)), 'roi_end must be list or tuple.'
-            assert all(x >= 0 for x in roi_start), 'all elements of roi_start must be greater than or equal to 0.'
-            assert all(x > 0 for x in roi_end), 'all elements of roi_end must be positive.'
- self.roi_start = roi_start - self.roi_end = roi_end + self.roi_start = np.asarray(roi_start, dtype=np.uint16) + self.roi_end = np.asarray(roi_end, dtype=np.uint16) + + assert np.all(self.roi_start >= 0), 'all elements of roi_start must be greater than or equal to 0.' + assert np.all(self.roi_end > 0), 'all elements of roi_end must be positive.' + assert np.all(self.roi_end >= self.roi_start), 'invalid roi range.' def __call__(self, img): max_end = img.shape[1:] - assert (np.subtract(max_end, self.roi_start) >= 0).all(), 'roi start out of image space.' - assert (np.subtract(max_end, self.roi_end) >= 0).all(), 'roi end out of image space.' - assert (np.subtract(self.roi_end, self.roi_start) >= 0).all(), 'invalid roi range.' - if len(self.roi_start) == 1: - data = img[:, self.roi_start[0]:self.roi_end[0]].copy() - elif len(self.roi_start) == 2: - data = img[:, self.roi_start[0]:self.roi_end[0], self.roi_start[1]:self.roi_end[1]].copy() - elif len(self.roi_start) == 3: - data = img[:, self.roi_start[0]:self.roi_end[0], self.roi_start[1]:self.roi_end[1], - self.roi_start[2]:self.roi_end[2]].copy() - else: - raise ValueError('unsupported image shape.') + sd = min(len(self.roi_start), len(max_end)) + assert np.all(max_end[:sd] >= self.roi_start[:sd]), 'roi start out of image space.' + assert np.all(max_end[:sd] >= self.roi_end[:sd]), 'roi end out of image space.' + + slices = [slice(None)] + [slice(s, e) for s, e in zip(self.roi_start[:sd], self.roi_end[:sd])] + data = img[tuple(slices)].copy() return data diff --git a/tests/test_header_correct.py b/tests/test_header_correct.py new file mode 100644 index 00000000000..681dc2cb1bc --- /dev/null +++ b/tests/test_header_correct.py @@ -0,0 +1,36 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import nibabel as nib +import numpy as np + +from monai.data.utils import correct_nifti_header_if_necessary + + +class TestCorrection(unittest.TestCase): + + def test_correct(self): + test_img = nib.Nifti1Image(np.zeros((1, 2, 3)), np.eye(4)) + test_img.header.set_zooms((100, 100, 100)) + test_img = correct_nifti_header_if_necessary(test_img) + np.testing.assert_allclose( + test_img.affine, np.array([[100., 0., 0., 0.], [0., 100., 0., 0.], [0., 0., 100., 0.], [0., 0., 0., 1.]])) + + def test_correcting(self): + test_img = nib.Nifti1Image(np.zeros((1, 2, 3)), np.eye(4) * 20.) + test_img = correct_nifti_header_if_necessary(test_img) + np.testing.assert_allclose( + test_img.affine, np.array([[20., 0., 0., 0.], [0., 20., 0., 0.], [0., 0., 20., 0.], [0., 0., 0., 20.]])) + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_orientation.py b/tests/test_orientation.py new file mode 100644 index 00000000000..eded01834d5 --- /dev/null +++ b/tests/test_orientation.py @@ -0,0 +1,38 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.transforms.transforms import Orientation + +TEST_CASES = [ + [{'axcodes': 'RAS'}, + np.ones((2, 10, 15, 20)), {'original_axcodes': 'ALS'}, (2, 15, 10, 20)], + [{'axcodes': 'AL'}, + np.ones((2, 10, 15)), {'original_axcodes': 'AR'}, (2, 10, 15)], + [{'axcodes': 'L'}, + np.ones((2, 10)), {'original_axcodes': 'R'}, (2, 10)], +] + + +class OrientationTestCase(unittest.TestCase): + + @parameterized.expand(TEST_CASES) + def test_ornt(self, init_param, img, data_param, expected_shape): + res = Orientation(**init_param)(img, **data_param) + np.testing.assert_allclose(res[0].shape, expected_shape) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_orientationd.py b/tests/test_orientationd.py new file mode 100644 index 00000000000..5b48a0edd65 --- /dev/null +++ b/tests/test_orientationd.py @@ -0,0 +1,55 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import numpy as np + +from monai.transforms.composables import Orientationd + + +class OrientationdTestCase(unittest.TestCase): + + def test_orntd(self): + data = {'seg': np.ones((2, 1, 2, 3)), 'affine': np.eye(4)} + ornt = Orientationd(keys='seg', affine_key='affine', axcodes='RAS') + res = ornt(data) + np.testing.assert_allclose(res['seg'].shape, (2, 1, 2, 3)) + self.assertEqual(res['orientation']['original_ornt'], ('R', 'A', 'S')) + self.assertEqual(res['orientation']['current_ornt'], 'RAS') + + def test_orntd_3d(self): + data = {'seg': np.ones((2, 1, 2, 3)), 'img': np.ones((2, 1, 2, 3)), 'affine': np.eye(4)} + ornt = Orientationd(keys=('img', 'seg'), affine_key='affine', axcodes='PLI') + res = ornt(data) + np.testing.assert_allclose(res['img'].shape, (2, 2, 1, 3)) + self.assertEqual(res['orientation']['original_ornt'], ('R', 'A', 'S')) + self.assertEqual(res['orientation']['current_ornt'], 'PLI') + + def test_orntd_2d(self): + data = {'seg': np.ones((2, 1, 3)), 'img': np.ones((2, 1, 3)), 'affine': np.eye(4)} + ornt = Orientationd(keys=('img', 'seg'), affine_key='affine', axcodes='PLI') + res = ornt(data) + np.testing.assert_allclose(res['img'].shape, (2, 3, 1)) + self.assertEqual(res['orientation']['original_ornt'], ('R', 'A')) + self.assertEqual(res['orientation']['current_ornt'], 'PL') + + def test_orntd_1d(self): + data = {'seg': np.ones((2, 3)), 'img': np.ones((2, 3)), 'affine': np.eye(4)} + ornt = Orientationd(keys=('img', 'seg'), affine_key='affine', axcodes='L') + res = ornt(data) + np.testing.assert_allclose(res['img'].shape, (2, 3)) + self.assertEqual(res['orientation']['original_ornt'], ('R',)) + self.assertEqual(res['orientation']['current_ornt'], 'L') + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_spacing.py b/tests/test_spacing.py new file mode 100644 index 00000000000..f02d827bdbd --- /dev/null +++ b/tests/test_spacing.py @@ -0,0 +1,47 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.transforms.transforms import Spacing + +TEST_CASES = [ + [{'pixdim': (1.0, 2.0, 1.5)}, + np.ones((2, 10, 15, 20)), {'original_pixdim': (0.5, 0.5, 1.0)}, (2, 5, 4, 13)], + [{'pixdim': (1.0, 2.0, 1.5), 'keep_shape': True}, + np.ones((1, 2, 1, 2)), {'original_pixdim': (0.5, 0.5, 1.0)}, (1, 2, 1, 2)], + [{'pixdim': (1.0, 0.2, 1.5), 'keep_shape': False}, + np.ones((1, 2, 1, 2)), {'original_affine': np.eye(4)}, (1, 2, 5, 1)], + [{'pixdim': (1.0, 2.0), 'keep_shape': True}, + np.ones((3, 2, 2)), {'original_pixdim': (1.5, 0.5)}, (3, 2, 2)], + [{'pixdim': (1.0, 0.2), 'keep_shape': False}, + np.ones((5, 2, 1)), {'original_pixdim': (1.5, 0.5)}, (5, 3, 2)], + [{'pixdim': (1.0,), 'keep_shape': False}, + np.ones((1, 2)), {'original_pixdim': (1.5,), 'interp_order': 0}, (1, 3)], +] + + +class SpacingTestCase(unittest.TestCase): + + @parameterized.expand(TEST_CASES) + def test_spacing(self, init_param, img, data_param, expected_shape): + res = Spacing(**init_param)(img, **data_param) + np.testing.assert_allclose(res[0].shape, expected_shape) + if 'original_pixdim' in data_param: + np.testing.assert_allclose(res[1], data_param['original_pixdim']) + np.testing.assert_allclose(res[2], init_param['pixdim']) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_spacingd.py b/tests/test_spacingd.py new file mode 100644 index 00000000000..c71d6eb4463 --- /dev/null +++ b/tests/test_spacingd.py @@ -0,0 +1,63 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import numpy as np + +from monai.transforms.composables import Spacingd + + +class SpacingDTestCase(unittest.TestCase): + + def test_spacingd_3d(self): + data = {'image': np.ones((2, 10, 15, 20)), 'affine': np.eye(4)} + spacing = Spacingd(keys='image', affine_key='affine', pixdim=(1, 2, 1.4)) + res = spacing(data) + np.testing.assert_allclose(res['image'].shape, (2, 10, 8, 14)) + np.testing.assert_allclose(res['spacing']['current_pixdim'], (1, 2, 1.4)) + np.testing.assert_allclose(res['spacing']['original_pixdim'], (1, 1, 1)) + + def test_spacingd_2d(self): + data = {'image': np.ones((2, 10, 20)), 'affine': np.eye(4)} + spacing = Spacingd(keys='image', affine_key='affine', pixdim=(1, 2, 1.4)) + res = spacing(data) + np.testing.assert_allclose(res['image'].shape, (2, 10, 10)) + np.testing.assert_allclose(res['spacing']['current_pixdim'], (1, 2)) + np.testing.assert_allclose(res['spacing']['original_pixdim'], (1, 1)) + + def test_spacingd_1d(self): + data = {'image': np.ones((2, 10)), 'affine': np.eye(4)} + spacing = Spacingd(keys='image', affine_key='affine', pixdim=(0.2,)) + res = spacing(data) + np.testing.assert_allclose(res['image'].shape, (2, 50)) + np.testing.assert_allclose(res['spacing']['current_pixdim'], (0.2,)) + np.testing.assert_allclose(res['spacing']['original_pixdim'], (1,)) + + def test_interp_all(self): + data = {'image': np.ones((2, 10)), 'seg': np.ones((2, 10)), 'affine': np.eye(4)} + spacing = Spacingd(keys=('image', 'seg'), affine_key='affine', interp_order=0, pixdim=(0.2,)) + res = spacing(data) + np.testing.assert_allclose(res['image'].shape, (2, 50)) + np.testing.assert_allclose(res['spacing']['current_pixdim'], (0.2,)) + np.testing.assert_allclose(res['spacing']['original_pixdim'], (1,)) + + def test_interp_sep(self): + data = {'image': np.ones((2, 10)), 'seg': np.ones((2, 10)), 'affine': np.eye(4)} + spacing = Spacingd(keys=('image', 'seg'), affine_key='affine', interp_order=(2, 0), pixdim=(0.2,)) + res = spacing(data) + np.testing.assert_allclose(res['image'].shape, (2, 50)) + np.testing.assert_allclose(res['spacing']['current_pixdim'], (0.2,)) + np.testing.assert_allclose(res['spacing']['original_pixdim'], (1,)) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_spatial_crop.py b/tests/test_spatial_crop.py index 2a3c2e7f9cb..a9ae910f3b5 100644 --- a/tests/test_spatial_crop.py +++ b/tests/test_spatial_crop.py @@ -32,9 +32,27 @@ (3, 2, 2, 2), ] +TEST_CASE_3 = [ + { + 'roi_start': [0, 0], + 'roi_end': [2, 2] + }, + np.random.randint(0, 2, size=[3, 3, 3, 3]), + (3, 2, 2, 3), +] + +TEST_CASE_4 = [ + { + 'roi_start': [0, 0, 0, 0, 0], + 'roi_end': [2, 2, 2, 2, 2] + }, + np.random.randint(0, 2, size=[3, 3, 3, 3]), + (3, 2, 2, 2), +] + class TestSpatialCrop(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4]) def test_shape(self, input_param, input_data, expected_shape): result = SpatialCrop(**input_param)(input_data) self.assertTupleEqual(result.shape, expected_shape) diff --git a/tests/test_zoom.py b/tests/test_zoom.py index f445895bfcf..874e587a98f 100644 --- a/tests/test_zoom.py +++ b/tests/test_zoom.py @@ -29,9 +29,9 @@ class ZoomTest(NumpyImageTestCase2D): (0.8, 1, 'reflect', 0, False, False, False) ]) def test_correct_results(self, zoom, order, mode, cval, prefilter, use_gpu, keep_size): - zoom_fn = Zoom(zoom=zoom, order=order, mode=mode, cval=cval, + zoom_fn = Zoom(zoom=zoom, order=order, mode=mode, cval=cval, 
prefilter=prefilter, use_gpu=use_gpu, keep_size=keep_size) - zoomed = zoom_fn(self.imt) + zoomed = zoom_fn(self.imt[0]) expected = zoom_scipy(self.imt, zoom=zoom, mode=mode, order=order, cval=cval, prefilter=prefilter) self.assertTrue(np.allclose(expected, zoomed)) @@ -43,15 +43,19 @@ def test_gpu_zoom(self, _, zoom, order, mode, cval, prefilter): if importlib.util.find_spec('cupy'): zoom_fn = Zoom(zoom=zoom, order=order, mode=mode, cval=cval, prefilter=prefilter, use_gpu=True, keep_size=False) - zoomed = zoom_fn(self.imt) + zoomed = zoom_fn(self.imt[0]) expected = zoom_scipy(self.imt, zoom=zoom, mode=mode, order=order, cval=cval, prefilter=prefilter) self.assertTrue(np.allclose(expected, zoomed)) def test_keep_size(self): zoom_fn = Zoom(zoom=0.6, keep_size=True) - zoomed = zoom_fn(self.imt) - self.assertTrue(np.array_equal(zoomed.shape, self.imt.shape)) + zoomed = zoom_fn(self.imt[0]) + self.assertTrue(np.array_equal(zoomed.shape, self.imt.shape[1:])) + + zoom_fn = Zoom(zoom=1.3, keep_size=True) + zoomed = zoom_fn(self.imt[0]) + self.assertTrue(np.array_equal(zoomed.shape, self.imt.shape[1:])) @parameterized.expand([ ("no_zoom", None, 1, TypeError),
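
Usage note (illustrative only, not part of the diff): a minimal sketch of how the array transforms and dictionary wrappers introduced above might be chained, based on their docstrings and the accompanying tests (tests/test_spacingd.py, tests/test_orientationd.py). The sample shapes, the 'image'/'label' keys, and the 2mm affine below are made up for illustration.

    import numpy as np

    from monai.transforms.composables import Orientationd, Spacingd
    from monai.transforms.transforms import Orientation, Spacing

    # array-style transforms: channel-first data plus the original affine
    data = np.ones((1, 10, 15, 20))          # (num_channels, H, W, D)
    affine = np.diag([2.0, 2.0, 2.0, 1.0])   # hypothetical 2mm isotropic spacing

    resampled, orig_pixdim, new_pixdim = Spacing(pixdim=(1.0, 1.0, 1.0))(data, original_affine=affine)
    reoriented, orig_axcodes, new_axcodes = Orientation(axcodes='RAS')(resampled, original_affine=affine)

    # dictionary-style transforms: `keys` selects the entries to process, `affine_key`
    # names the entry holding the original affine, and the pixdim/orientation status is
    # tracked under the default output keys 'spacing' and 'orientation'
    sample = {'image': np.ones((1, 10, 15, 20)), 'label': np.ones((1, 10, 15, 20)), 'affine': affine}
    sample = Spacingd(keys=('image', 'label'), affine_key='affine', pixdim=(1.0, 1.0, 1.0), interp_order=(2, 0))(sample)
    sample = Orientationd(keys=('image', 'label'), affine_key='affine', axcodes='RAS')(sample)
    print(sample['image'].shape, sample['spacing'], sample['orientation'])

Passing interp_order=(2, 0) follows the pattern in tests/test_spacingd.py: spline interpolation for the image and nearest-neighbour (order 0) for the label.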