From d020facccbd3afe979fce68c24703dcda47234f6 Mon Sep 17 00:00:00 2001 From: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Date: Fri, 19 Jul 2024 17:10:36 +0100 Subject: [PATCH] Merge genaidev Into Dev (#7886) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fixes #6676 . ### Description This merges the Generative Models code into dev. Everything has been checked by the generative team, tests all pass, and the changes that have been done recently are integrated. This is ready to merge. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [x] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [x] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Mark Graham Signed-off-by: KumoLiu Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Signed-off-by: Mark Graham Signed-off-by: vgrau98 Signed-off-by: vgrau98 <35843843+vgrau98@users.noreply.github.com> Signed-off-by: Wenqi Li Signed-off-by: dongy Signed-off-by: myron Signed-off-by: kaibo Signed-off-by: monai-bot Signed-off-by: elitap Signed-off-by: Felix Schnabel Signed-off-by: YanxuanLiu Signed-off-by: ytl0623 Signed-off-by: Dženan Zukić Signed-off-by: Ishan Dutta Signed-off-by: dependabot[bot] Signed-off-by: heyufan1995 Signed-off-by: binliu Signed-off-by: axel.vlaminck Signed-off-by: Ibrahim Hadzic Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> Signed-off-by: Timothy Baker Signed-off-by: Mathijs de Boer Signed-off-by: Fabian Klopfer Signed-off-by: Lucas Robinet Signed-off-by: Lucas Robinet <67736918+Lucas-rbnt@users.noreply.github.com> Signed-off-by: chaoliu Signed-off-by: cxlcl Signed-off-by: chaoliu Signed-off-by: Suraj Pai Signed-off-by: Juan Pablo de la Cruz Gutiérrez Signed-off-by: John Zielke Signed-off-by: Mingxin Zheng Signed-off-by: Vladimir Chernyi <57420464+scalyvladimir@users.noreply.github.com> Signed-off-by: Yiheng Wang Signed-off-by: Szabolcs Botond Lorincz Molnar Signed-off-by: Lucas Robinet Signed-off-by: Mingxin Signed-off-by: Han Wang Signed-off-by: Konstantin Sukharev Signed-off-by: Ben Murray Signed-off-by: Matthew Vine <32849887+MattTheCuber@users.noreply.github.com> Signed-off-by: Peter Kaplinsky Signed-off-by: Simon Jensen <61684806+simojens@users.noreply.github.com> Signed-off-by: NabJa Signed-off-by: virginiafdez Signed-off-by: Eric Kerfoot Signed-off-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Co-authored-by: Mark Graham Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Co-authored-by: KumoLiu Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: vgrau98 <35843843+vgrau98@users.noreply.github.com> Co-authored-by: Wenqi Li <831580+wyli@users.noreply.github.com> Co-authored-by: Dong Yang Co-authored-by: myron Co-authored-by: Kaibo Tang <99367900+kvttt@users.noreply.github.com> Co-authored-by: monai-bot <64792179+monai-bot@users.noreply.github.com> Co-authored-by: elitap Co-authored-by: Felix Schnabel Co-authored-by: YanxuanLiu <104543031+YanxuanLiu@users.noreply.github.com> Co-authored-by: ytl0623 Co-authored-by: Dženan Zukić Co-authored-by: Ishan Dutta Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Kaibo Tang Co-authored-by: Yufan He <59374597+heyufan1995@users.noreply.github.com> Co-authored-by: binliunls <107988372+binliunls@users.noreply.github.com> Co-authored-by: Ben Murray Co-authored-by: axel.vlaminck Co-authored-by: Mingxin Zheng <18563433+mingxin-zheng@users.noreply.github.com> Co-authored-by: Ibrahim Hadzic Co-authored-by: Dr. Behrooz Hashemian <3968947+drbeh@users.noreply.github.com> Co-authored-by: Timothy J. Baker <62781117+tim-the-baker@users.noreply.github.com> Co-authored-by: Mathijs de Boer <8137653+MathijsdeBoer@users.noreply.github.com> Co-authored-by: Mathijs de Boer Co-authored-by: Fabian Klopfer Co-authored-by: Yiheng Wang <68361391+yiheng-wang-nv@users.noreply.github.com> Co-authored-by: Lucas Robinet <67736918+Lucas-rbnt@users.noreply.github.com> Co-authored-by: Lucas Robinet Co-authored-by: cxlcl Co-authored-by: Suraj Pai Co-authored-by: Juampa <1523654+juampatronics@users.noreply.github.com> Co-authored-by: johnzielke Co-authored-by: Vladimir Chernyi <57420464+scalyvladimir@users.noreply.github.com> Co-authored-by: Lőrincz-Molnár Szabolcs-Botond Co-authored-by: Nic Ma Co-authored-by: Lucas Robinet Co-authored-by: Han Wang Co-authored-by: Konstantin Sukharev <50718389+k-sukharev@users.noreply.github.com> Co-authored-by: Matthew Vine <32849887+MattTheCuber@users.noreply.github.com> Co-authored-by: Pkaps25 <43655728+Pkaps25@users.noreply.github.com> Co-authored-by: Peter Kaplinsky Co-authored-by: Simon Jensen <61684806+simojens@users.noreply.github.com> Co-authored-by: NabJa <32510324+NabJa@users.noreply.github.com> Co-authored-by: Virginia Fernandez <61539159+virginiafdez@users.noreply.github.com> Co-authored-by: virginiafdez Co-authored-by: Yu <146002968+Yu0610@users.noreply.github.com> --- docs/source/engines.rst | 5 + docs/source/inferers.rst | 23 + docs/source/utils.rst | 5 + monai/apps/detection/utils/anchor_utils.py | 4 +- monai/apps/pathology/transforms/post/array.py | 1 + monai/bundle/utils.py | 1 + monai/data/dataset_summary.py | 1 + monai/data/utils.py | 15 +- monai/engines/__init__.py | 4 +- monai/engines/trainer.py | 283 ++- monai/engines/utils.py | 77 +- monai/inferers/__init__.py | 5 + monai/inferers/inferer.py | 1280 ++++++++++- monai/networks/blocks/__init__.py | 3 + monai/networks/blocks/attention_utils.py | 128 ++ monai/networks/blocks/crossattention.py | 166 ++ monai/networks/blocks/rel_pos_embedding.py | 56 + monai/networks/blocks/selfattention.py | 77 +- monai/networks/blocks/spade_norm.py | 95 + monai/networks/blocks/spatialattention.py | 82 + monai/networks/blocks/transformerblock.py | 28 +- monai/networks/blocks/upsample.py | 32 +- monai/networks/layers/__init__.py | 3 +- monai/networks/layers/factories.py | 13 +- monai/networks/layers/utils.py | 15 +- monai/networks/layers/vector_quantizer.py | 233 ++ monai/networks/nets/__init__.py | 9 + monai/networks/nets/autoencoderkl.py | 702 ++++++ monai/networks/nets/controlnet.py | 465 ++++ monai/networks/nets/diffusion_model_unet.py | 1913 +++++++++++++++++ monai/networks/nets/patchgan_discriminator.py | 230 ++ monai/networks/nets/quicknat.py | 2 + monai/networks/nets/spade_autoencoderkl.py | 480 +++++ .../nets/spade_diffusion_model_unet.py | 934 ++++++++ monai/networks/nets/spade_network.py | 435 ++++ monai/networks/nets/swin_unetr.py | 7 +- monai/networks/nets/transformer.py | 157 ++ monai/networks/nets/vqvae.py | 472 ++++ monai/networks/schedulers/__init__.py | 17 + monai/networks/schedulers/ddim.py | 294 +++ monai/networks/schedulers/ddpm.py | 250 +++ monai/networks/schedulers/pndm.py | 316 +++ monai/networks/schedulers/scheduler.py | 205 ++ monai/networks/utils.py | 22 + monai/transforms/regularization/array.py | 4 + .../transforms/utils_create_transform_ims.py | 6 +- monai/utils/__init__.py | 1 + monai/utils/misc.py | 5 +- monai/utils/ordering.py | 207 ++ tests/hvd_evenly_divisible_all_gather.py | 8 +- tests/min_tests.py | 1 + tests/test_autoencoderkl.py | 337 +++ tests/test_controlnet.py | 215 ++ tests/test_controlnet_inferers.py | 1310 +++++++++++ tests/test_crossattention.py | 131 ++ tests/test_diffusion_inferer.py | 236 ++ tests/test_diffusion_model_unet.py | 585 +++++ tests/test_ensure_channel_first.py | 7 +- tests/test_ensure_channel_firstd.py | 7 +- .../test_evenly_divisible_all_gather_dist.py | 8 +- tests/test_handler_metrics_saver_dist.py | 4 +- tests/test_hilbert_transform.py | 123 +- tests/test_integration_unet_2d.py | 1 + .../test_integration_workflows_adversarial.py | 173 ++ tests/test_latent_diffusion_inferer.py | 824 +++++++ tests/test_ordering.py | 289 +++ tests/test_patch_gan_dicriminator.py | 179 ++ tests/test_prepare_batch_diffusion.py | 104 + tests/test_reg_loss_integration.py | 3 + tests/test_scheduler_ddim.py | 83 + tests/test_scheduler_ddpm.py | 104 + tests/test_scheduler_pndm.py | 108 + tests/test_selfattention.py | 42 +- tests/test_spade_autoencoderkl.py | 295 +++ tests/test_spade_diffusion_model_unet.py | 574 +++++ tests/test_spade_vaegan.py | 140 ++ tests/test_spatialattention.py | 55 + tests/test_synthetic.py | 2 +- tests/test_transformer.py | 109 + tests/test_transformerblock.py | 29 +- tests/test_vector_quantizer.py | 89 + tests/test_vis_cam.py | 3 + tests/test_vis_gradcam.py | 2 + tests/test_vqvae.py | 274 +++ tests/test_vqvaetransformer_inferer.py | 295 +++ tests/testing_data/data_config.json | 20 + 86 files changed, 16359 insertions(+), 178 deletions(-) create mode 100644 monai/networks/blocks/attention_utils.py create mode 100644 monai/networks/blocks/crossattention.py create mode 100644 monai/networks/blocks/rel_pos_embedding.py create mode 100644 monai/networks/blocks/spade_norm.py create mode 100644 monai/networks/blocks/spatialattention.py create mode 100644 monai/networks/layers/vector_quantizer.py create mode 100644 monai/networks/nets/autoencoderkl.py create mode 100644 monai/networks/nets/controlnet.py create mode 100644 monai/networks/nets/diffusion_model_unet.py create mode 100644 monai/networks/nets/patchgan_discriminator.py create mode 100644 monai/networks/nets/spade_autoencoderkl.py create mode 100644 monai/networks/nets/spade_diffusion_model_unet.py create mode 100644 monai/networks/nets/spade_network.py create mode 100644 monai/networks/nets/transformer.py create mode 100644 monai/networks/nets/vqvae.py create mode 100644 monai/networks/schedulers/__init__.py create mode 100644 monai/networks/schedulers/ddim.py create mode 100644 monai/networks/schedulers/ddpm.py create mode 100644 monai/networks/schedulers/pndm.py create mode 100644 monai/networks/schedulers/scheduler.py create mode 100644 monai/utils/ordering.py create mode 100644 tests/test_autoencoderkl.py create mode 100644 tests/test_controlnet.py create mode 100644 tests/test_controlnet_inferers.py create mode 100644 tests/test_crossattention.py create mode 100644 tests/test_diffusion_inferer.py create mode 100644 tests/test_diffusion_model_unet.py create mode 100644 tests/test_integration_workflows_adversarial.py create mode 100644 tests/test_latent_diffusion_inferer.py create mode 100644 tests/test_ordering.py create mode 100644 tests/test_patch_gan_dicriminator.py create mode 100644 tests/test_prepare_batch_diffusion.py create mode 100644 tests/test_scheduler_ddim.py create mode 100644 tests/test_scheduler_ddpm.py create mode 100644 tests/test_scheduler_pndm.py create mode 100644 tests/test_spade_autoencoderkl.py create mode 100644 tests/test_spade_diffusion_model_unet.py create mode 100644 tests/test_spade_vaegan.py create mode 100644 tests/test_spatialattention.py create mode 100644 tests/test_transformer.py create mode 100644 tests/test_vector_quantizer.py create mode 100644 tests/test_vqvae.py create mode 100644 tests/test_vqvaetransformer_inferer.py diff --git a/docs/source/engines.rst b/docs/source/engines.rst index afb2682822..a015c7b2a3 100644 --- a/docs/source/engines.rst +++ b/docs/source/engines.rst @@ -30,6 +30,11 @@ Workflows .. autoclass:: GanTrainer :members: +`AdversarialTrainer` +~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: AdversarialTrainer + :members: + `Evaluator` ~~~~~~~~~~~ .. autoclass:: Evaluator diff --git a/docs/source/inferers.rst b/docs/source/inferers.rst index 33f9e14d83..326f56e96c 100644 --- a/docs/source/inferers.rst +++ b/docs/source/inferers.rst @@ -49,6 +49,29 @@ Inferers :members: :special-members: __call__ +`DiffusionInferer` +~~~~~~~~~~~~~~~~~~ +.. autoclass:: DiffusionInferer + :members: + :special-members: __call__ + +`LatentDiffusionInferer` +~~~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: LatentDiffusionInferer + :members: + :special-members: __call__ + +`ControlNetDiffusionInferer` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: ControlNetDiffusionInferer + :members: + :special-members: __call__ + +`ControlNetLatentDiffusionInferer` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: ControlNetLatentDiffusionInferer + :members: + :special-members: __call__ Splitters --------- diff --git a/docs/source/utils.rst b/docs/source/utils.rst index 527247799f..fef671e1f8 100644 --- a/docs/source/utils.rst +++ b/docs/source/utils.rst @@ -81,3 +81,8 @@ Component store --------------- .. autoclass:: monai.utils.component_store.ComponentStore :members: + +Ordering +-------- +.. automodule:: monai.utils.ordering + :members: diff --git a/monai/apps/detection/utils/anchor_utils.py b/monai/apps/detection/utils/anchor_utils.py index 283169b653..cbde3ebae9 100644 --- a/monai/apps/detection/utils/anchor_utils.py +++ b/monai/apps/detection/utils/anchor_utils.py @@ -189,7 +189,7 @@ def generate_anchors( w_ratios = 1 / area_scale h_ratios = area_scale # if 3d, w:h:d = 1:aspect_ratios[:,0]:aspect_ratios[:,1] - elif self.spatial_dims == 3: + else: area_scale = torch.pow(aspect_ratios_t[:, 0] * aspect_ratios_t[:, 1], 1 / 3.0) w_ratios = 1 / area_scale h_ratios = aspect_ratios_t[:, 0] / area_scale @@ -199,7 +199,7 @@ def generate_anchors( hs = (h_ratios[:, None] * scales_t[None, :]).view(-1) if self.spatial_dims == 2: base_anchors = torch.stack([-ws, -hs, ws, hs], dim=1) / 2.0 - elif self.spatial_dims == 3: + else: # elif self.spatial_dims == 3: ds = (d_ratios[:, None] * scales_t[None, :]).view(-1) base_anchors = torch.stack([-ws, -hs, -ds, ws, hs, ds], dim=1) / 2.0 diff --git a/monai/apps/pathology/transforms/post/array.py b/monai/apps/pathology/transforms/post/array.py index 42ca385fa0..0aa8e14655 100644 --- a/monai/apps/pathology/transforms/post/array.py +++ b/monai/apps/pathology/transforms/post/array.py @@ -379,6 +379,7 @@ def _generate_contour_coord(self, current: np.ndarray, previous: np.ndarray) -> """ p_delta = (current[0] - previous[0], current[1] - previous[1]) + row, col = -1, -1 if p_delta in ((0.0, 1.0), (0.5, 0.5), (1.0, 0.0)): row = int(current[0] + 0.5) diff --git a/monai/bundle/utils.py b/monai/bundle/utils.py index a0f39d236f..0f17422ba5 100644 --- a/monai/bundle/utils.py +++ b/monai/bundle/utils.py @@ -221,6 +221,7 @@ def load_bundle_config(bundle_path: str, *config_names: str, **load_kw_args: Any raise ValueError(f"Cannot find config file '{full_cname}'") ardata = archive.read(full_cname) + cdata = {} if full_cname.lower().endswith("json"): cdata = json.loads(ardata, **load_kw_args) diff --git a/monai/data/dataset_summary.py b/monai/data/dataset_summary.py index 769ae33b46..5b9e32afca 100644 --- a/monai/data/dataset_summary.py +++ b/monai/data/dataset_summary.py @@ -84,6 +84,7 @@ def collect_meta_data(self): """ for data in self.data_loader: + meta_dict = {} if isinstance(data[self.image_key], MetaTensor): meta_dict = data[self.image_key].meta elif self.meta_key in data: diff --git a/monai/data/utils.py b/monai/data/utils.py index 585f02ec9e..7a08300abb 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -53,10 +53,6 @@ pytorch_after, ) -if pytorch_after(1, 13): - # import private code for reuse purposes, comment in case things break in the future - from torch.utils.data._utils.collate import collate_tensor_fn, default_collate_fn_map - pd, _ = optional_import("pandas") DataFrame, _ = optional_import("pandas", name="DataFrame") nib, _ = optional_import("nibabel") @@ -454,8 +450,13 @@ def collate_meta_tensor_fn(batch, *, collate_fn_map=None): Collate a sequence of meta tensor into a single batched metatensor. This is called by `collage_meta_tensor` and so should not be used as a collate function directly in dataloaders. """ - collate_fn = collate_tensor_fn if pytorch_after(1, 13) else default_collate - collated = collate_fn(batch) # type: ignore + if pytorch_after(1, 13): + from torch.utils.data._utils.collate import collate_tensor_fn # imported here for pylint/mypy issues + + collated = collate_tensor_fn(batch) + else: + collated = default_collate(batch) + meta_dicts = [i.meta or TraceKeys.NONE for i in batch] common_ = set.intersection(*[set(d.keys()) for d in meta_dicts if isinstance(d, dict)]) if common_: @@ -496,6 +497,8 @@ def list_data_collate(batch: Sequence): if pytorch_after(1, 13): # needs to go here to avoid circular import + from torch.utils.data._utils.collate import default_collate_fn_map + from monai.data.meta_tensor import MetaTensor default_collate_fn_map.update({MetaTensor: collate_meta_tensor_fn}) diff --git a/monai/engines/__init__.py b/monai/engines/__init__.py index d8dc51f620..93cc40e292 100644 --- a/monai/engines/__init__.py +++ b/monai/engines/__init__.py @@ -12,12 +12,14 @@ from __future__ import annotations from .evaluator import EnsembleEvaluator, Evaluator, SupervisedEvaluator -from .trainer import GanTrainer, SupervisedTrainer, Trainer +from .trainer import AdversarialTrainer, GanTrainer, SupervisedTrainer, Trainer from .utils import ( + DiffusionPrepareBatch, IterationEvents, PrepareBatch, PrepareBatchDefault, PrepareBatchExtraInput, + VPredictionPrepareBatch, default_make_latent, default_metric_cmp_fn, default_prepare_batch, diff --git a/monai/engines/trainer.py b/monai/engines/trainer.py index f1513ea73b..c1364fe015 100644 --- a/monai/engines/trainer.py +++ b/monai/engines/trainer.py @@ -24,7 +24,7 @@ from monai.engines.workflow import Workflow from monai.inferers import Inferer, SimpleInferer from monai.transforms import Transform -from monai.utils import GanKeys, min_version, optional_import +from monai.utils import AdversarialIterationEvents, AdversarialKeys, GanKeys, min_version, optional_import from monai.utils.enums import CommonKeys as Keys from monai.utils.enums import EngineStatsKeys as ESKeys from monai.utils.module import pytorch_after @@ -37,7 +37,7 @@ Metric, _ = optional_import("ignite.metrics", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Metric") EventEnum, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "EventEnum") -__all__ = ["Trainer", "SupervisedTrainer", "GanTrainer"] +__all__ = ["Trainer", "SupervisedTrainer", "GanTrainer", "AdversarialTrainer"] class Trainer(Workflow): @@ -471,3 +471,282 @@ def _iteration( GanKeys.GLOSS: g_loss.item(), GanKeys.DLOSS: d_total_loss.item(), } + + +class AdversarialTrainer(Trainer): + """ + Standard supervised training workflow for adversarial loss enabled neural networks. + + Args: + device: an object representing the device on which to run. + max_epochs: the total epoch number for engine to run. + train_data_loader: Core ignite engines uses `DataLoader` for training loop batchdata. + g_network: ''generator'' (G) network architecture. + g_optimizer: G optimizer function. + g_loss_function: G loss function for adversarial training. + recon_loss_function: G loss function for reconstructions. + d_network: discriminator (D) network architecture. + d_optimizer: D optimizer function. + d_loss_function: D loss function for adversarial training.. + epoch_length: number of iterations for one epoch, default to `len(train_data_loader)`. + non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously with respect to + the host. For other cases, this argument has no effect. + prepare_batch: function to parse image and label for current iteration. + iteration_update: the callable function for every iteration, expect to accept `engine` and `batchdata` as input + parameters. if not provided, use `self._iteration()` instead. + g_inferer: inference method to execute G model forward. Defaults to ``SimpleInferer()``. + d_inferer: inference method to execute D model forward. Defaults to ``SimpleInferer()``. + postprocessing: execute additional transformation for the model output data. Typically, several Tensor based + transforms composed by `Compose`. Defaults to None + key_train_metric: compute metric when every iteration completed, and save average value to engine.state.metrics + when epoch completed. key_train_metric is the main metric to compare and save the checkpoint into files. + additional_metrics: more Ignite metrics that also attach to Ignite Engine. + metric_cmp_fn: function to compare current key metric with previous best key metric value, it must accept 2 args + (current_metric, previous_best) and return a bool result: if `True`, will update 'best_metric` and + `best_metric_epoch` with current metric and epoch, default to `greater than`. + train_handlers: every handler is a set of Ignite Event-Handlers, must have `attach` function, like: + CheckpointHandler, StatsHandler, etc. + amp: whether to enable auto-mixed-precision training, default is False. + event_names: additional custom ignite events that will register to the engine. + new events can be a list of str or `ignite.engine.events.EventEnum`. + event_to_attr: a dictionary to map an event to a state attribute, then add to `engine.state`. + for more details, check: https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html + #ignite.engine.engine.Engine.register_events. + decollate: whether to decollate the batch-first data to a list of data after model computation, recommend + `decollate=True` when `postprocessing` uses components from `monai.transforms`. default to `True`. + optim_set_to_none: when calling `optimizer.zero_grad()`, instead of setting to zero, set the grads to None. + more details: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html. + to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for + `device`, `non_blocking`. + amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details: + https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast. + """ + + def __init__( + self, + device: torch.device | str, + max_epochs: int, + train_data_loader: Iterable | DataLoader, + g_network: torch.nn.Module, + g_optimizer: Optimizer, + g_loss_function: Callable, + recon_loss_function: Callable, + d_network: torch.nn.Module, + d_optimizer: Optimizer, + d_loss_function: Callable, + epoch_length: int | None = None, + non_blocking: bool = False, + prepare_batch: Callable = default_prepare_batch, + iteration_update: Callable | None = None, + g_inferer: Inferer | None = None, + d_inferer: Inferer | None = None, + postprocessing: Transform | None = None, + key_train_metric: dict[str, Metric] | None = None, + additional_metrics: dict[str, Metric] | None = None, + metric_cmp_fn: Callable = default_metric_cmp_fn, + train_handlers: Sequence | None = None, + amp: bool = False, + event_names: list[str | EventEnum | type[EventEnum]] | None = None, + event_to_attr: dict | None = None, + decollate: bool = True, + optim_set_to_none: bool = False, + to_kwargs: dict | None = None, + amp_kwargs: dict | None = None, + ): + super().__init__( + device=device, + max_epochs=max_epochs, + data_loader=train_data_loader, + epoch_length=epoch_length, + non_blocking=non_blocking, + prepare_batch=prepare_batch, + iteration_update=iteration_update, + postprocessing=postprocessing, + key_metric=key_train_metric, + additional_metrics=additional_metrics, + metric_cmp_fn=metric_cmp_fn, + handlers=train_handlers, + amp=amp, + event_names=event_names, + event_to_attr=event_to_attr, + decollate=decollate, + to_kwargs=to_kwargs, + amp_kwargs=amp_kwargs, + ) + + self.register_events(*AdversarialIterationEvents) + + self.state.g_network = g_network + self.state.g_optimizer = g_optimizer + self.state.g_loss_function = g_loss_function + self.state.recon_loss_function = recon_loss_function + + self.state.d_network = d_network + self.state.d_optimizer = d_optimizer + self.state.d_loss_function = d_loss_function + + self.g_inferer = SimpleInferer() if g_inferer is None else g_inferer + self.d_inferer = SimpleInferer() if d_inferer is None else d_inferer + + self.state.g_scaler = torch.cuda.amp.GradScaler() if self.amp else None + self.state.d_scaler = torch.cuda.amp.GradScaler() if self.amp else None + + self.optim_set_to_none = optim_set_to_none + self._complete_state_dict_user_keys() + + def _complete_state_dict_user_keys(self) -> None: + """ + This method appends to the _state_dict_user_keys AdversarialTrainer's elements that are required for + checkpoint saving. + + Follows the example found at: + https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html#ignite.engine.engine.Engine.state_dict + """ + self._state_dict_user_keys.extend( + ["g_network", "g_optimizer", "d_network", "d_optimizer", "g_scaler", "d_scaler"] + ) + + g_loss_state_dict = getattr(self.state.g_loss_function, "state_dict", None) + if callable(g_loss_state_dict): + self._state_dict_user_keys.append("g_loss_function") + + d_loss_state_dict = getattr(self.state.d_loss_function, "state_dict", None) + if callable(d_loss_state_dict): + self._state_dict_user_keys.append("d_loss_function") + + recon_loss_state_dict = getattr(self.state.recon_loss_function, "state_dict", None) + if callable(recon_loss_state_dict): + self._state_dict_user_keys.append("recon_loss_function") + + def _iteration( + self, engine: AdversarialTrainer, batchdata: dict[str, torch.Tensor] + ) -> dict[str, torch.Tensor | int | float | bool]: + """ + Callback function for the Adversarial Training processing logic of 1 iteration in Ignite Engine. + Return below items in a dictionary: + - IMAGE: image Tensor data for model input, already moved to device. + - LABEL: label Tensor data corresponding to the image, already moved to device. In case of Unsupervised + Learning this is equal to IMAGE. + - PRED: prediction result of model. + - LOSS: loss value computed by loss functions of the generator (reconstruction and adversarial summed up). + - AdversarialKeys.REALS: real images from the batch. Are the same as IMAGE. + - AdversarialKeys.FAKES: fake images generated by the generator. Are the same as PRED. + - AdversarialKeys.REAL_LOGITS: logits of the discriminator for the real images. + - AdversarialKeys.FAKE_LOGITS: logits of the discriminator for the fake images. + - AdversarialKeys.RECONSTRUCTION_LOSS: loss value computed by the reconstruction loss function. + - AdversarialKeys.GENERATOR_LOSS: loss value computed by the generator loss function. It is the + discriminator loss for the fake images. That is backpropagated through the generator only. + - AdversarialKeys.DISCRIMINATOR_LOSS: loss value computed by the discriminator loss function. It is the + discriminator loss for the real images and the fake images. That is backpropagated through the + discriminator only. + + Args: + engine: `AdversarialTrainer` to execute operation for an iteration. + batchdata: input data for this iteration, usually can be dictionary or tuple of Tensor data. + + Raises: + ValueError: must provide batch data for current iteration. + + """ + + if batchdata is None: + raise ValueError("Must provide batch data for current iteration.") + batch = engine.prepare_batch(batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs) + + if len(batch) == 2: + inputs, targets = batch + args: tuple = () + kwargs: dict = {} + else: + inputs, targets, args, kwargs = batch + + engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets, AdversarialKeys.REALS: inputs} + + def _compute_generator_loss() -> None: + engine.state.output[AdversarialKeys.FAKES] = engine.g_inferer( + inputs, engine.state.g_network, *args, **kwargs + ) + engine.state.output[Keys.PRED] = engine.state.output[AdversarialKeys.FAKES] + engine.fire_event(AdversarialIterationEvents.GENERATOR_FORWARD_COMPLETED) + + engine.state.output[AdversarialKeys.FAKE_LOGITS] = engine.d_inferer( + engine.state.output[AdversarialKeys.FAKES].float().contiguous(), engine.state.d_network, *args, **kwargs + ) + engine.fire_event(AdversarialIterationEvents.GENERATOR_DISCRIMINATOR_FORWARD_COMPLETED) + + engine.state.output[AdversarialKeys.RECONSTRUCTION_LOSS] = engine.state.recon_loss_function( + engine.state.output[AdversarialKeys.FAKES], targets + ).mean() + engine.fire_event(AdversarialIterationEvents.RECONSTRUCTION_LOSS_COMPLETED) + + engine.state.output[AdversarialKeys.GENERATOR_LOSS] = engine.state.g_loss_function( + engine.state.output[AdversarialKeys.FAKE_LOGITS] + ).mean() + engine.fire_event(AdversarialIterationEvents.GENERATOR_LOSS_COMPLETED) + + # Train Generator + engine.state.g_network.train() + engine.state.g_optimizer.zero_grad(set_to_none=engine.optim_set_to_none) + + if engine.amp and engine.state.g_scaler is not None: + with torch.cuda.amp.autocast(**engine.amp_kwargs): + _compute_generator_loss() + + engine.state.output[Keys.LOSS] = ( + engine.state.output[AdversarialKeys.RECONSTRUCTION_LOSS] + + engine.state.output[AdversarialKeys.GENERATOR_LOSS] + ) + engine.state.g_scaler.scale(engine.state.output[Keys.LOSS]).backward() + engine.fire_event(AdversarialIterationEvents.GENERATOR_BACKWARD_COMPLETED) + engine.state.g_scaler.step(engine.state.g_optimizer) + engine.state.g_scaler.update() + else: + _compute_generator_loss() + ( + engine.state.output[AdversarialKeys.RECONSTRUCTION_LOSS] + + engine.state.output[AdversarialKeys.GENERATOR_LOSS] + ).backward() + engine.fire_event(AdversarialIterationEvents.GENERATOR_BACKWARD_COMPLETED) + engine.state.g_optimizer.step() + engine.fire_event(AdversarialIterationEvents.GENERATOR_MODEL_COMPLETED) + + def _compute_discriminator_loss() -> None: + engine.state.output[AdversarialKeys.REAL_LOGITS] = engine.d_inferer( + engine.state.output[AdversarialKeys.REALS].contiguous().detach(), + engine.state.d_network, + *args, + **kwargs, + ) + engine.fire_event(AdversarialIterationEvents.DISCRIMINATOR_REALS_FORWARD_COMPLETED) + + engine.state.output[AdversarialKeys.FAKE_LOGITS] = engine.d_inferer( + engine.state.output[AdversarialKeys.FAKES].contiguous().detach(), + engine.state.d_network, + *args, + **kwargs, + ) + engine.fire_event(AdversarialIterationEvents.DISCRIMINATOR_FAKES_FORWARD_COMPLETED) + + engine.state.output[AdversarialKeys.DISCRIMINATOR_LOSS] = engine.state.d_loss_function( + engine.state.output[AdversarialKeys.REAL_LOGITS], engine.state.output[AdversarialKeys.FAKE_LOGITS] + ).mean() + engine.fire_event(AdversarialIterationEvents.DISCRIMINATOR_LOSS_COMPLETED) + + # Train Discriminator + engine.state.d_network.train() + engine.state.d_network.zero_grad(set_to_none=engine.optim_set_to_none) + + if engine.amp and engine.state.d_scaler is not None: + with torch.cuda.amp.autocast(**engine.amp_kwargs): + _compute_discriminator_loss() + + engine.state.d_scaler.scale(engine.state.output[AdversarialKeys.DISCRIMINATOR_LOSS]).backward() + engine.fire_event(AdversarialIterationEvents.DISCRIMINATOR_BACKWARD_COMPLETED) + engine.state.d_scaler.step(engine.state.d_optimizer) + engine.state.d_scaler.update() + else: + _compute_discriminator_loss() + engine.state.output[AdversarialKeys.DISCRIMINATOR_LOSS].backward() + engine.state.d_optimizer.step() + + return engine.state.output diff --git a/monai/engines/utils.py b/monai/engines/utils.py index 02c718cd14..5339d6965a 100644 --- a/monai/engines/utils.py +++ b/monai/engines/utils.py @@ -13,9 +13,10 @@ from abc import ABC, abstractmethod from collections.abc import Callable, Sequence -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any, Mapping, cast import torch +import torch.nn as nn from monai.config import IgniteInfo from monai.transforms import apply_transform @@ -36,6 +37,8 @@ "PrepareBatch", "PrepareBatchDefault", "PrepareBatchExtraInput", + "DiffusionPrepareBatch", + "VPredictionPrepareBatch", "default_make_latent", "engine_apply_transform", "default_metric_cmp_fn", @@ -238,6 +241,78 @@ def _get_data(key: str) -> torch.Tensor: return cast(torch.Tensor, image), cast(torch.Tensor, label), tuple(args_), kwargs_ +class DiffusionPrepareBatch(PrepareBatch): + """ + This class is used as a callable for the `prepare_batch` parameter of engine classes for diffusion training. + + Assuming a supervised training process, it will generate a noise field using `get_noise` for an input image, and + return the image and noise field as the image/target pair plus the noise field the kwargs under the key "noise". + This assumes the inferer being used in conjunction with this class expects a "noise" parameter to be provided. + + If the `condition_name` is provided, this must refer to a key in the input dictionary containing the condition + field to be passed to the inferer. This will appear in the keyword arguments under the key "condition". + + """ + + def __init__(self, num_train_timesteps: int, condition_name: str | None = None) -> None: + self.condition_name = condition_name + self.num_train_timesteps = num_train_timesteps + + def get_noise(self, images: torch.Tensor) -> torch.Tensor: + """Returns the noise tensor for input tensor `images`, override this for different noise distributions.""" + return torch.randn_like(images) + + def get_timesteps(self, images: torch.Tensor) -> torch.Tensor: + """Get a timestep, by default this is a random integer between 0 and `self.num_train_timesteps`.""" + return torch.randint(0, self.num_train_timesteps, (images.shape[0],), device=images.device).long() + + def get_target(self, images: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor: + """Return the target for the loss function, this is the `noise` value by default.""" + return noise + + def __call__( + self, + batchdata: dict[str, torch.Tensor], + device: str | torch.device | None = None, + non_blocking: bool = False, + **kwargs: Any, + ) -> tuple[torch.Tensor, torch.Tensor, tuple, dict]: + images, _ = default_prepare_batch(batchdata, device, non_blocking, **kwargs) + noise = self.get_noise(images).to(device, non_blocking=non_blocking, **kwargs) + timesteps = self.get_timesteps(images).to(device, non_blocking=non_blocking, **kwargs) + + target = self.get_target(images, noise, timesteps).to(device, non_blocking=non_blocking, **kwargs) + infer_kwargs = {"noise": noise, "timesteps": timesteps} + + if self.condition_name is not None and isinstance(batchdata, Mapping): + infer_kwargs["condition"] = batchdata[self.condition_name].to(device, non_blocking=non_blocking, **kwargs) + + # return input, target, arguments, and keyword arguments where noise is the target and also a keyword value + return images, target, (), infer_kwargs + + +class VPredictionPrepareBatch(DiffusionPrepareBatch): + """ + This class is used as a callable for the `prepare_batch` parameter of engine classes for diffusion training. + + Assuming a supervised training process, it will generate a noise field using `get_noise` for an input image, and + from this compute the velocity using the provided scheduler. This value is used as the target in place of the + noise field itself although the noise is field is in the kwargs under the key "noise". This assumes the inferer + being used in conjunction with this class expects a "noise" parameter to be provided. + + If the `condition_name` is provided, this must refer to a key in the input dictionary containing the condition + field to be passed to the inferer. This will appear in the keyword arguments under the key "condition". + + """ + + def __init__(self, scheduler: nn.Module, num_train_timesteps: int, condition_name: str | None = None) -> None: + super().__init__(num_train_timesteps=num_train_timesteps, condition_name=condition_name) + self.scheduler = scheduler + + def get_target(self, images, noise, timesteps): + return self.scheduler.get_velocity(images, noise, timesteps) + + def default_make_latent( num_latents: int, latent_size: int, diff --git a/monai/inferers/__init__.py b/monai/inferers/__init__.py index 960380bfb8..fc78b9f7c4 100644 --- a/monai/inferers/__init__.py +++ b/monai/inferers/__init__.py @@ -12,13 +12,18 @@ from __future__ import annotations from .inferer import ( + ControlNetDiffusionInferer, + ControlNetLatentDiffusionInferer, + DiffusionInferer, Inferer, + LatentDiffusionInferer, PatchInferer, SaliencyInferer, SimpleInferer, SliceInferer, SlidingWindowInferer, SlidingWindowInfererAdapt, + VQVAETransformerInferer, ) from .merger import AvgMerger, Merger, ZarrAvgMerger from .splitter import SlidingWindowSplitter, Splitter, WSISlidingWindowSplitter diff --git a/monai/inferers/inferer.py b/monai/inferers/inferer.py index 0b4199938d..769b6cc0e7 100644 --- a/monai/inferers/inferer.py +++ b/monai/inferers/inferer.py @@ -11,24 +11,41 @@ from __future__ import annotations +import math import warnings from abc import ABC, abstractmethod from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence +from functools import partial from pydoc import locate from typing import Any import torch import torch.nn as nn +import torch.nn.functional as F from monai.apps.utils import get_logger +from monai.data import decollate_batch from monai.data.meta_tensor import MetaTensor from monai.data.thread_buffer import ThreadBuffer from monai.inferers.merger import AvgMerger, Merger from monai.inferers.splitter import Splitter from monai.inferers.utils import compute_importance_map, sliding_window_inference -from monai.utils import BlendMode, PatchKeys, PytorchPadMode, ensure_tuple, optional_import +from monai.networks.nets import ( + VQVAE, + AutoencoderKL, + ControlNet, + DecoderOnlyTransformer, + DiffusionModelUNet, + SPADEAutoencoderKL, + SPADEDiffusionModelUNet, +) +from monai.networks.schedulers import Scheduler +from monai.transforms import CenterSpatialCrop, SpatialPad +from monai.utils import BlendMode, Ordering, PatchKeys, PytorchPadMode, ensure_tuple, optional_import from monai.visualize import CAM, GradCAM, GradCAMpp +tqdm, has_tqdm = optional_import("tqdm", name="tqdm") + logger = get_logger(__name__) __all__ = [ @@ -752,3 +769,1264 @@ def network_wrapper( return out return tuple(out_i.unsqueeze(dim=self.spatial_dim + 2) for out_i in out) + + +class DiffusionInferer(Inferer): + """ + DiffusionInferer takes a trained diffusion model and a scheduler and can be used to perform a signal forward pass + for a training iteration, and sample from the model. + + Args: + scheduler: diffusion scheduler. + """ + + def __init__(self, scheduler: Scheduler) -> None: # type: ignore[override] + super().__init__() + + self.scheduler = scheduler + + def __call__( # type: ignore[override] + self, + inputs: torch.Tensor, + diffusion_model: DiffusionModelUNet, + noise: torch.Tensor, + timesteps: torch.Tensor, + condition: torch.Tensor | None = None, + mode: str = "crossattn", + seg: torch.Tensor | None = None, + ) -> torch.Tensor: + """ + Implements the forward pass for a supervised training iteration. + + Args: + inputs: Input image to which noise is added. + diffusion_model: diffusion model. + noise: random noise, of the same shape as the input. + timesteps: random timesteps. + condition: Conditioning for network input. + mode: Conditioning mode for the network. + seg: if model is instance of SPADEDiffusionModelUnet, segmentation must be + provided on the forward (for SPADE-like AE or SPADE-like DM) + """ + if mode not in ["crossattn", "concat"]: + raise NotImplementedError(f"{mode} condition is not supported") + + noisy_image: torch.Tensor = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps) + if mode == "concat": + if condition is None: + raise ValueError("Conditioning is required for concat condition") + else: + noisy_image = torch.cat([noisy_image, condition], dim=1) + condition = None + diffusion_model = ( + partial(diffusion_model, seg=seg) + if isinstance(diffusion_model, SPADEDiffusionModelUNet) + else diffusion_model + ) + prediction: torch.Tensor = diffusion_model(x=noisy_image, timesteps=timesteps, context=condition) + + return prediction + + @torch.no_grad() + def sample( + self, + input_noise: torch.Tensor, + diffusion_model: DiffusionModelUNet, + scheduler: Scheduler | None = None, + save_intermediates: bool | None = False, + intermediate_steps: int | None = 100, + conditioning: torch.Tensor | None = None, + mode: str = "crossattn", + verbose: bool = True, + seg: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: + """ + Args: + input_noise: random noise, of the same shape as the desired sample. + diffusion_model: model to sample from. + scheduler: diffusion scheduler. If none provided will use the class attribute scheduler + save_intermediates: whether to return intermediates along the sampling change + intermediate_steps: if save_intermediates is True, saves every n steps + conditioning: Conditioning for network input. + mode: Conditioning mode for the network. + verbose: if true, prints the progression bar of the sampling process. + seg: if diffusion model is instance of SPADEDiffusionModel, segmentation must be provided. + """ + if mode not in ["crossattn", "concat"]: + raise NotImplementedError(f"{mode} condition is not supported") + if mode == "concat" and conditioning is None: + raise ValueError("Conditioning must be supplied for if condition mode is concat.") + if not scheduler: + scheduler = self.scheduler + image = input_noise + if verbose and has_tqdm: + progress_bar = tqdm(scheduler.timesteps) + else: + progress_bar = iter(scheduler.timesteps) + intermediates = [] + for t in progress_bar: + # 1. predict noise model_output + diffusion_model = ( + partial(diffusion_model, seg=seg) + if isinstance(diffusion_model, SPADEDiffusionModelUNet) + else diffusion_model + ) + if mode == "concat" and conditioning is not None: + model_input = torch.cat([image, conditioning], dim=1) + model_output = diffusion_model( + model_input, timesteps=torch.Tensor((t,)).to(input_noise.device), context=None + ) + else: + model_output = diffusion_model( + image, timesteps=torch.Tensor((t,)).to(input_noise.device), context=conditioning + ) + + # 2. compute previous image: x_t -> x_t-1 + image, _ = scheduler.step(model_output, t, image) + if save_intermediates and t % intermediate_steps == 0: + intermediates.append(image) + if save_intermediates: + return image, intermediates + else: + return image + + @torch.no_grad() + def get_likelihood( + self, + inputs: torch.Tensor, + diffusion_model: DiffusionModelUNet, + scheduler: Scheduler | None = None, + save_intermediates: bool | None = False, + conditioning: torch.Tensor | None = None, + mode: str = "crossattn", + original_input_range: tuple = (0, 255), + scaled_input_range: tuple = (0, 1), + verbose: bool = True, + seg: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: + """ + Computes the log-likelihoods for an input. + + Args: + inputs: input images, NxCxHxW[xD] + diffusion_model: model to compute likelihood from + scheduler: diffusion scheduler. If none provided will use the class attribute scheduler. + save_intermediates: save the intermediate spatial KL maps + conditioning: Conditioning for network input. + mode: Conditioning mode for the network. + original_input_range: the [min,max] intensity range of the input data before any scaling was applied. + scaled_input_range: the [min,max] intensity range of the input data after scaling. + verbose: if true, prints the progression bar of the sampling process. + seg: if diffusion model is instance of SPADEDiffusionModel, segmentation must be provided. + """ + + if not scheduler: + scheduler = self.scheduler + if scheduler._get_name() != "DDPMScheduler": + raise NotImplementedError( + f"Likelihood computation is only compatible with DDPMScheduler," + f" you are using {scheduler._get_name()}" + ) + if mode not in ["crossattn", "concat"]: + raise NotImplementedError(f"{mode} condition is not supported") + if mode == "concat" and conditioning is None: + raise ValueError("Conditioning must be supplied for if condition mode is concat.") + if verbose and has_tqdm: + progress_bar = tqdm(scheduler.timesteps) + else: + progress_bar = iter(scheduler.timesteps) + intermediates = [] + noise = torch.randn_like(inputs).to(inputs.device) + total_kl = torch.zeros(inputs.shape[0]).to(inputs.device) + for t in progress_bar: + timesteps = torch.full(inputs.shape[:1], t, device=inputs.device).long() + noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps) + diffusion_model = ( + partial(diffusion_model, seg=seg) + if isinstance(diffusion_model, SPADEDiffusionModelUNet) + else diffusion_model + ) + if mode == "concat" and conditioning is not None: + noisy_image = torch.cat([noisy_image, conditioning], dim=1) + model_output = diffusion_model(noisy_image, timesteps=timesteps, context=None) + else: + model_output = diffusion_model(x=noisy_image, timesteps=timesteps, context=conditioning) + + # get the model's predicted mean, and variance if it is predicted + if model_output.shape[1] == inputs.shape[1] * 2 and scheduler.variance_type in ["learned", "learned_range"]: + model_output, predicted_variance = torch.split(model_output, inputs.shape[1], dim=1) + else: + predicted_variance = None + + # 1. compute alphas, betas + alpha_prod_t = scheduler.alphas_cumprod[t] + alpha_prod_t_prev = scheduler.alphas_cumprod[t - 1] if t > 0 else scheduler.one + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + # 2. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf + if scheduler.prediction_type == "epsilon": + pred_original_sample = (noisy_image - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + elif scheduler.prediction_type == "sample": + pred_original_sample = model_output + elif scheduler.prediction_type == "v_prediction": + pred_original_sample = (alpha_prod_t**0.5) * noisy_image - (beta_prod_t**0.5) * model_output + # 3. Clip "predicted x_0" + if scheduler.clip_sample: + pred_original_sample = torch.clamp(pred_original_sample, -1, 1) + + # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t + # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf + pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * scheduler.betas[t]) / beta_prod_t + current_sample_coeff = scheduler.alphas[t] ** (0.5) * beta_prod_t_prev / beta_prod_t + + # 5. Compute predicted previous sample µ_t + # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf + predicted_mean = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * noisy_image + + # get the posterior mean and variance + posterior_mean = scheduler._get_mean(timestep=t, x_0=inputs, x_t=noisy_image) + posterior_variance = scheduler._get_variance(timestep=t, predicted_variance=predicted_variance) + + log_posterior_variance = torch.log(posterior_variance) + log_predicted_variance = torch.log(predicted_variance) if predicted_variance else log_posterior_variance + + if t == 0: + # compute -log p(x_0|x_1) + kl = -self._get_decoder_log_likelihood( + inputs=inputs, + means=predicted_mean, + log_scales=0.5 * log_predicted_variance, + original_input_range=original_input_range, + scaled_input_range=scaled_input_range, + ) + else: + # compute kl between two normals + kl = 0.5 * ( + -1.0 + + log_predicted_variance + - log_posterior_variance + + torch.exp(log_posterior_variance - log_predicted_variance) + + ((posterior_mean - predicted_mean) ** 2) * torch.exp(-log_predicted_variance) + ) + total_kl += kl.view(kl.shape[0], -1).mean(dim=1) + if save_intermediates: + intermediates.append(kl.cpu()) + + if save_intermediates: + return total_kl, intermediates + else: + return total_kl + + def _approx_standard_normal_cdf(self, x): + """ + A fast approximation of the cumulative distribution function of the + standard normal. Code adapted from https://github.com/openai/improved-diffusion. + """ + + return 0.5 * ( + 1.0 + torch.tanh(torch.sqrt(torch.Tensor([2.0 / math.pi]).to(x.device)) * (x + 0.044715 * torch.pow(x, 3))) + ) + + def _get_decoder_log_likelihood( + self, + inputs: torch.Tensor, + means: torch.Tensor, + log_scales: torch.Tensor, + original_input_range: tuple = (0, 255), + scaled_input_range: tuple = (0, 1), + ) -> torch.Tensor: + """ + Compute the log-likelihood of a Gaussian distribution discretizing to a + given image. Code adapted from https://github.com/openai/improved-diffusion. + + Args: + input: the target images. It is assumed that this was uint8 values, + rescaled to the range [-1, 1]. + means: the Gaussian mean Tensor. + log_scales: the Gaussian log stddev Tensor. + original_input_range: the [min,max] intensity range of the input data before any scaling was applied. + scaled_input_range: the [min,max] intensity range of the input data after scaling. + """ + if inputs.shape != means.shape: + raise ValueError(f"Inputs and means must have the same shape, got {inputs.shape} and {means.shape}") + bin_width = (scaled_input_range[1] - scaled_input_range[0]) / ( + original_input_range[1] - original_input_range[0] + ) + centered_x = inputs - means + inv_stdv = torch.exp(-log_scales) + plus_in = inv_stdv * (centered_x + bin_width / 2) + cdf_plus = self._approx_standard_normal_cdf(plus_in) + min_in = inv_stdv * (centered_x - bin_width / 2) + cdf_min = self._approx_standard_normal_cdf(min_in) + log_cdf_plus = torch.log(cdf_plus.clamp(min=1e-12)) + log_one_minus_cdf_min = torch.log((1.0 - cdf_min).clamp(min=1e-12)) + cdf_delta = cdf_plus - cdf_min + log_probs = torch.where( + inputs < -0.999, + log_cdf_plus, + torch.where(inputs > 0.999, log_one_minus_cdf_min, torch.log(cdf_delta.clamp(min=1e-12))), + ) + return log_probs + + +class LatentDiffusionInferer(DiffusionInferer): + """ + LatentDiffusionInferer takes a stage 1 model (VQVAE or AutoencoderKL), diffusion model, and a scheduler, and can + be used to perform a signal forward pass for a training iteration, and sample from the model. + + Args: + scheduler: a scheduler to be used in combination with `unet` to denoise the encoded image latents. + scale_factor: scale factor to multiply the values of the latent representation before processing it by the + second stage. + ldm_latent_shape: desired spatial latent space shape. Used if there is a difference in the autoencoder model's latent shape. + autoencoder_latent_shape: autoencoder_latent_shape: autoencoder spatial latent space shape. Used if there is a + difference between the autoencoder's latent shape and the DM shape. + """ + + def __init__( + self, + scheduler: Scheduler, + scale_factor: float = 1.0, + ldm_latent_shape: list | None = None, + autoencoder_latent_shape: list | None = None, + ) -> None: + super().__init__(scheduler=scheduler) + self.scale_factor = scale_factor + if (ldm_latent_shape is None) ^ (autoencoder_latent_shape is None): + raise ValueError("If ldm_latent_shape is None, autoencoder_latent_shape must be None, and vice versa.") + self.ldm_latent_shape = ldm_latent_shape + self.autoencoder_latent_shape = autoencoder_latent_shape + if self.ldm_latent_shape is not None and self.autoencoder_latent_shape is not None: + self.ldm_resizer = SpatialPad(spatial_size=self.ldm_latent_shape) + self.autoencoder_resizer = CenterSpatialCrop(roi_size=self.autoencoder_latent_shape) + + def __call__( # type: ignore[override] + self, + inputs: torch.Tensor, + autoencoder_model: AutoencoderKL | VQVAE, + diffusion_model: DiffusionModelUNet, + noise: torch.Tensor, + timesteps: torch.Tensor, + condition: torch.Tensor | None = None, + mode: str = "crossattn", + seg: torch.Tensor | None = None, + ) -> torch.Tensor: + """ + Implements the forward pass for a supervised training iteration. + + Args: + inputs: input image to which the latent representation will be extracted and noise is added. + autoencoder_model: first stage model. + diffusion_model: diffusion model. + noise: random noise, of the same shape as the latent representation. + timesteps: random timesteps. + condition: conditioning for network input. + mode: Conditioning mode for the network. + seg: if diffusion model is instance of SPADEDiffusionModel, segmentation must be provided. + """ + with torch.no_grad(): + latent = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor + + if self.ldm_latent_shape is not None: + latent = torch.stack([self.ldm_resizer(i) for i in decollate_batch(latent)], 0) + + prediction: torch.Tensor = super().__call__( + inputs=latent, + diffusion_model=diffusion_model, + noise=noise, + timesteps=timesteps, + condition=condition, + mode=mode, + seg=seg, + ) + return prediction + + @torch.no_grad() + def sample( # type: ignore[override] + self, + input_noise: torch.Tensor, + autoencoder_model: AutoencoderKL | VQVAE, + diffusion_model: DiffusionModelUNet, + scheduler: Scheduler | None = None, + save_intermediates: bool | None = False, + intermediate_steps: int | None = 100, + conditioning: torch.Tensor | None = None, + mode: str = "crossattn", + verbose: bool = True, + seg: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: + """ + Args: + input_noise: random noise, of the same shape as the desired latent representation. + autoencoder_model: first stage model. + diffusion_model: model to sample from. + scheduler: diffusion scheduler. If none provided will use the class attribute scheduler. + save_intermediates: whether to return intermediates along the sampling change + intermediate_steps: if save_intermediates is True, saves every n steps + conditioning: Conditioning for network input. + mode: Conditioning mode for the network. + verbose: if true, prints the progression bar of the sampling process. + seg: if diffusion model is instance of SPADEDiffusionModel, or autoencoder_model + is instance of SPADEAutoencoderKL, segmentation must be provided. + """ + + if ( + isinstance(autoencoder_model, SPADEAutoencoderKL) + and isinstance(diffusion_model, SPADEDiffusionModelUNet) + and autoencoder_model.decoder.label_nc != diffusion_model.label_nc + ): + raise ValueError( + f"If both autoencoder_model and diffusion_model implement SPADE, the number of semantic" + f"labels for each must be compatible, but got {autoencoder_model.decoder.label_nc} and" + f"{diffusion_model.label_nc}" + ) + + outputs = super().sample( + input_noise=input_noise, + diffusion_model=diffusion_model, + scheduler=scheduler, + save_intermediates=save_intermediates, + intermediate_steps=intermediate_steps, + conditioning=conditioning, + mode=mode, + verbose=verbose, + seg=seg, + ) + + if save_intermediates: + latent, latent_intermediates = outputs + else: + latent = outputs + + if self.autoencoder_latent_shape is not None: + latent = torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(latent)], 0) + latent_intermediates = [ + torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0) for l in latent_intermediates + ] + + decode = autoencoder_model.decode_stage_2_outputs + if isinstance(autoencoder_model, SPADEAutoencoderKL): + decode = partial(autoencoder_model.decode_stage_2_outputs, seg=seg) + image = decode(latent / self.scale_factor) + + if save_intermediates: + intermediates = [] + for latent_intermediate in latent_intermediates: + decode = autoencoder_model.decode_stage_2_outputs + if isinstance(autoencoder_model, SPADEAutoencoderKL): + decode = partial(autoencoder_model.decode_stage_2_outputs, seg=seg) + intermediates.append(decode(latent_intermediate / self.scale_factor)) + return image, intermediates + + else: + return image + + @torch.no_grad() + def get_likelihood( # type: ignore[override] + self, + inputs: torch.Tensor, + autoencoder_model: AutoencoderKL | VQVAE, + diffusion_model: DiffusionModelUNet, + scheduler: Scheduler | None = None, + save_intermediates: bool | None = False, + conditioning: torch.Tensor | None = None, + mode: str = "crossattn", + original_input_range: tuple | None = (0, 255), + scaled_input_range: tuple | None = (0, 1), + verbose: bool = True, + resample_latent_likelihoods: bool = False, + resample_interpolation_mode: str = "nearest", + seg: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: + """ + Computes the log-likelihoods of the latent representations of the input. + + Args: + inputs: input images, NxCxHxW[xD] + autoencoder_model: first stage model. + diffusion_model: model to compute likelihood from + scheduler: diffusion scheduler. If none provided will use the class attribute scheduler + save_intermediates: save the intermediate spatial KL maps + conditioning: Conditioning for network input. + mode: Conditioning mode for the network. + original_input_range: the [min,max] intensity range of the input data before any scaling was applied. + scaled_input_range: the [min,max] intensity range of the input data after scaling. + verbose: if true, prints the progression bar of the sampling process. + resample_latent_likelihoods: if true, resamples the intermediate likelihood maps to have the same spatial + dimension as the input images. + resample_interpolation_mode: if use resample_latent_likelihoods, select interpolation 'nearest', 'bilinear', + or 'trilinear; + seg: if diffusion model is instance of SPADEDiffusionModel, or autoencoder_model + is instance of SPADEAutoencoderKL, segmentation must be provided. + """ + if resample_latent_likelihoods and resample_interpolation_mode not in ("nearest", "bilinear", "trilinear"): + raise ValueError( + f"resample_interpolation mode should be either nearest, bilinear, or trilinear, got {resample_interpolation_mode}" + ) + latents = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor + + if self.ldm_latent_shape is not None: + latents = torch.stack([self.ldm_resizer(i) for i in decollate_batch(latents)], 0) + + outputs = super().get_likelihood( + inputs=latents, + diffusion_model=diffusion_model, + scheduler=scheduler, + save_intermediates=save_intermediates, + conditioning=conditioning, + mode=mode, + verbose=verbose, + seg=seg, + ) + + if save_intermediates and resample_latent_likelihoods: + intermediates = outputs[1] + resizer = nn.Upsample(size=inputs.shape[2:], mode=resample_interpolation_mode) + intermediates = [resizer(x) for x in intermediates] + outputs = (outputs[0], intermediates) + return outputs + + +class ControlNetDiffusionInferer(DiffusionInferer): + """ + ControlNetDiffusionInferer takes a trained diffusion model and a scheduler and can be used to perform a signal + forward pass for a training iteration, and sample from the model, supporting ControlNet-based conditioning. + + Args: + scheduler: diffusion scheduler. + """ + + def __init__(self, scheduler: Scheduler) -> None: + Inferer.__init__(self) + self.scheduler = scheduler + + def __call__( # type: ignore[override] + self, + inputs: torch.Tensor, + diffusion_model: DiffusionModelUNet, + controlnet: ControlNet, + noise: torch.Tensor, + timesteps: torch.Tensor, + cn_cond: torch.Tensor, + condition: torch.Tensor | None = None, + mode: str = "crossattn", + seg: torch.Tensor | None = None, + ) -> torch.Tensor: + """ + Implements the forward pass for a supervised training iteration. + + Args: + inputs: Input image to which noise is added. + diffusion_model: diffusion model. + controlnet: controlnet sub-network. + noise: random noise, of the same shape as the input. + timesteps: random timesteps. + cn_cond: conditioning image for the ControlNet. + condition: Conditioning for network input. + mode: Conditioning mode for the network. + seg: if model is instance of SPADEDiffusionModelUnet, segmentation must be + provided on the forward (for SPADE-like AE or SPADE-like DM) + """ + if mode not in ["crossattn", "concat"]: + raise NotImplementedError(f"{mode} condition is not supported") + + noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps) + down_block_res_samples, mid_block_res_sample = controlnet( + x=noisy_image, timesteps=timesteps, controlnet_cond=cn_cond + ) + if mode == "concat" and condition is not None: + noisy_image = torch.cat([noisy_image, condition], dim=1) + condition = None + + diffuse = diffusion_model + if isinstance(diffusion_model, SPADEDiffusionModelUNet): + diffuse = partial(diffusion_model, seg=seg) + + prediction: torch.Tensor = diffuse( + x=noisy_image, + timesteps=timesteps, + context=condition, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + ) + + return prediction + + @torch.no_grad() + def sample( # type: ignore[override] + self, + input_noise: torch.Tensor, + diffusion_model: DiffusionModelUNet, + controlnet: ControlNet, + cn_cond: torch.Tensor, + scheduler: Scheduler | None = None, + save_intermediates: bool | None = False, + intermediate_steps: int | None = 100, + conditioning: torch.Tensor | None = None, + mode: str = "crossattn", + verbose: bool = True, + seg: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: + """ + Args: + input_noise: random noise, of the same shape as the desired sample. + diffusion_model: model to sample from. + controlnet: controlnet sub-network. + cn_cond: conditioning image for the ControlNet. + scheduler: diffusion scheduler. If none provided will use the class attribute scheduler + save_intermediates: whether to return intermediates along the sampling change + intermediate_steps: if save_intermediates is True, saves every n steps + conditioning: Conditioning for network input. + mode: Conditioning mode for the network. + verbose: if true, prints the progression bar of the sampling process. + seg: if diffusion model is instance of SPADEDiffusionModel, segmentation must be provided. + """ + if mode not in ["crossattn", "concat"]: + raise NotImplementedError(f"{mode} condition is not supported") + + if not scheduler: + scheduler = self.scheduler + image = input_noise + if verbose and has_tqdm: + progress_bar = tqdm(scheduler.timesteps) + else: + progress_bar = iter(scheduler.timesteps) + intermediates = [] + for t in progress_bar: + # 1. ControlNet forward + down_block_res_samples, mid_block_res_sample = controlnet( + x=image, timesteps=torch.Tensor((t,)).to(input_noise.device), controlnet_cond=cn_cond + ) + # 2. predict noise model_output + diffuse = diffusion_model + if isinstance(diffusion_model, SPADEDiffusionModelUNet): + diffuse = partial(diffusion_model, seg=seg) + + if mode == "concat" and conditioning is not None: + model_input = torch.cat([image, conditioning], dim=1) + model_output = diffuse( + model_input, + timesteps=torch.Tensor((t,)).to(input_noise.device), + context=None, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + ) + else: + model_output = diffuse( + image, + timesteps=torch.Tensor((t,)).to(input_noise.device), + context=conditioning, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + ) + + # 3. compute previous image: x_t -> x_t-1 + image, _ = scheduler.step(model_output, t, image) + if save_intermediates and t % intermediate_steps == 0: + intermediates.append(image) + if save_intermediates: + return image, intermediates + else: + return image + + @torch.no_grad() + def get_likelihood( # type: ignore[override] + self, + inputs: torch.Tensor, + diffusion_model: DiffusionModelUNet, + controlnet: ControlNet, + cn_cond: torch.Tensor, + scheduler: Scheduler | None = None, + save_intermediates: bool | None = False, + conditioning: torch.Tensor | None = None, + mode: str = "crossattn", + original_input_range: tuple = (0, 255), + scaled_input_range: tuple = (0, 1), + verbose: bool = True, + seg: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: + """ + Computes the log-likelihoods for an input. + + Args: + inputs: input images, NxCxHxW[xD] + diffusion_model: model to compute likelihood from + controlnet: controlnet sub-network. + cn_cond: conditioning image for the ControlNet. + scheduler: diffusion scheduler. If none provided will use the class attribute scheduler. + save_intermediates: save the intermediate spatial KL maps + conditioning: Conditioning for network input. + mode: Conditioning mode for the network. + original_input_range: the [min,max] intensity range of the input data before any scaling was applied. + scaled_input_range: the [min,max] intensity range of the input data after scaling. + verbose: if true, prints the progression bar of the sampling process. + seg: if diffusion model is instance of SPADEDiffusionModel, segmentation must be provided. + """ + + if not scheduler: + scheduler = self.scheduler + if scheduler._get_name() != "DDPMScheduler": + raise NotImplementedError( + f"Likelihood computation is only compatible with DDPMScheduler," + f" you are using {scheduler._get_name()}" + ) + if mode not in ["crossattn", "concat"]: + raise NotImplementedError(f"{mode} condition is not supported") + if verbose and has_tqdm: + progress_bar = tqdm(scheduler.timesteps) + else: + progress_bar = iter(scheduler.timesteps) + intermediates = [] + noise = torch.randn_like(inputs).to(inputs.device) + total_kl = torch.zeros(inputs.shape[0]).to(inputs.device) + for t in progress_bar: + timesteps = torch.full(inputs.shape[:1], t, device=inputs.device).long() + noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps) + down_block_res_samples, mid_block_res_sample = controlnet( + x=noisy_image, timesteps=torch.Tensor((t,)).to(inputs.device), controlnet_cond=cn_cond + ) + + diffuse = diffusion_model + if isinstance(diffusion_model, SPADEDiffusionModelUNet): + diffuse = partial(diffusion_model, seg=seg) + + if mode == "concat" and conditioning is not None: + noisy_image = torch.cat([noisy_image, conditioning], dim=1) + model_output = diffuse( + noisy_image, + timesteps=timesteps, + context=None, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + ) + else: + model_output = diffuse( + x=noisy_image, + timesteps=timesteps, + context=conditioning, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + ) + # get the model's predicted mean, and variance if it is predicted + if model_output.shape[1] == inputs.shape[1] * 2 and scheduler.variance_type in ["learned", "learned_range"]: + model_output, predicted_variance = torch.split(model_output, inputs.shape[1], dim=1) + else: + predicted_variance = None + + # 1. compute alphas, betas + alpha_prod_t = scheduler.alphas_cumprod[t] + alpha_prod_t_prev = scheduler.alphas_cumprod[t - 1] if t > 0 else scheduler.one + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + # 2. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf + if scheduler.prediction_type == "epsilon": + pred_original_sample = (noisy_image - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + elif scheduler.prediction_type == "sample": + pred_original_sample = model_output + elif scheduler.prediction_type == "v_prediction": + pred_original_sample = (alpha_prod_t**0.5) * noisy_image - (beta_prod_t**0.5) * model_output + # 3. Clip "predicted x_0" + if scheduler.clip_sample: + pred_original_sample = torch.clamp(pred_original_sample, -1, 1) + + # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t + # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf + pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * scheduler.betas[t]) / beta_prod_t + current_sample_coeff = scheduler.alphas[t] ** (0.5) * beta_prod_t_prev / beta_prod_t + + # 5. Compute predicted previous sample µ_t + # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf + predicted_mean = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * noisy_image + + # get the posterior mean and variance + posterior_mean = scheduler._get_mean(timestep=t, x_0=inputs, x_t=noisy_image) + posterior_variance = scheduler._get_variance(timestep=t, predicted_variance=predicted_variance) + + log_posterior_variance = torch.log(posterior_variance) + log_predicted_variance = torch.log(predicted_variance) if predicted_variance else log_posterior_variance + + if t == 0: + # compute -log p(x_0|x_1) + kl = -super()._get_decoder_log_likelihood( + inputs=inputs, + means=predicted_mean, + log_scales=0.5 * log_predicted_variance, + original_input_range=original_input_range, + scaled_input_range=scaled_input_range, + ) + else: + # compute kl between two normals + kl = 0.5 * ( + -1.0 + + log_predicted_variance + - log_posterior_variance + + torch.exp(log_posterior_variance - log_predicted_variance) + + ((posterior_mean - predicted_mean) ** 2) * torch.exp(-log_predicted_variance) + ) + total_kl += kl.view(kl.shape[0], -1).mean(dim=1) + if save_intermediates: + intermediates.append(kl.cpu()) + + if save_intermediates: + return total_kl, intermediates + else: + return total_kl + + +class ControlNetLatentDiffusionInferer(ControlNetDiffusionInferer): + """ + ControlNetLatentDiffusionInferer takes a stage 1 model (VQVAE or AutoencoderKL), diffusion model, controlnet, + and a scheduler, and can be used to perform a signal forward pass for a training iteration, and sample from + the model. + + Args: + scheduler: a scheduler to be used in combination with `unet` to denoise the encoded image latents. + scale_factor: scale factor to multiply the values of the latent representation before processing it by the + second stage. + ldm_latent_shape: desired spatial latent space shape. Used if there is a difference in the autoencoder model's latent shape. + autoencoder_latent_shape: autoencoder_latent_shape: autoencoder spatial latent space shape. Used if there is a + difference between the autoencoder's latent shape and the DM shape. + """ + + def __init__( + self, + scheduler: Scheduler, + scale_factor: float = 1.0, + ldm_latent_shape: list | None = None, + autoencoder_latent_shape: list | None = None, + ) -> None: + super().__init__(scheduler=scheduler) + self.scale_factor = scale_factor + if (ldm_latent_shape is None) ^ (autoencoder_latent_shape is None): + raise ValueError("If ldm_latent_shape is None, autoencoder_latent_shape must be None" "and vice versa.") + self.ldm_latent_shape = ldm_latent_shape + self.autoencoder_latent_shape = autoencoder_latent_shape + if self.ldm_latent_shape is not None and self.autoencoder_latent_shape is not None: + self.ldm_resizer = SpatialPad(spatial_size=self.ldm_latent_shape) + self.autoencoder_resizer = CenterSpatialCrop(roi_size=self.autoencoder_latent_shape) + + def __call__( # type: ignore[override] + self, + inputs: torch.Tensor, + autoencoder_model: AutoencoderKL | VQVAE, + diffusion_model: DiffusionModelUNet, + controlnet: ControlNet, + noise: torch.Tensor, + timesteps: torch.Tensor, + cn_cond: torch.Tensor, + condition: torch.Tensor | None = None, + mode: str = "crossattn", + seg: torch.Tensor | None = None, + ) -> torch.Tensor: + """ + Implements the forward pass for a supervised training iteration. + + Args: + inputs: input image to which the latent representation will be extracted and noise is added. + autoencoder_model: first stage model. + diffusion_model: diffusion model. + controlnet: instance of ControlNet model + noise: random noise, of the same shape as the latent representation. + timesteps: random timesteps. + cn_cond: conditioning tensor for the ControlNet network + condition: conditioning for network input. + mode: Conditioning mode for the network. + seg: if diffusion model is instance of SPADEDiffusionModel, segmentation must be provided. + """ + with torch.no_grad(): + latent = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor + + if self.ldm_latent_shape is not None: + latent = torch.stack([self.ldm_resizer(i) for i in decollate_batch(latent)], 0) + + if cn_cond.shape[2:] != latent.shape[2:]: + cn_cond = F.interpolate(cn_cond, latent.shape[2:]) + + prediction = super().__call__( + inputs=latent, + diffusion_model=diffusion_model, + controlnet=controlnet, + noise=noise, + timesteps=timesteps, + cn_cond=cn_cond, + condition=condition, + mode=mode, + seg=seg, + ) + + return prediction + + @torch.no_grad() + def sample( # type: ignore[override] + self, + input_noise: torch.Tensor, + autoencoder_model: AutoencoderKL | VQVAE, + diffusion_model: DiffusionModelUNet, + controlnet: ControlNet, + cn_cond: torch.Tensor, + scheduler: Scheduler | None = None, + save_intermediates: bool | None = False, + intermediate_steps: int | None = 100, + conditioning: torch.Tensor | None = None, + mode: str = "crossattn", + verbose: bool = True, + seg: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: + """ + Args: + input_noise: random noise, of the same shape as the desired latent representation. + autoencoder_model: first stage model. + diffusion_model: model to sample from. + controlnet: instance of ControlNet model. + cn_cond: conditioning tensor for the ControlNet network. + scheduler: diffusion scheduler. If none provided will use the class attribute scheduler. + save_intermediates: whether to return intermediates along the sampling change + intermediate_steps: if save_intermediates is True, saves every n steps + conditioning: Conditioning for network input. + mode: Conditioning mode for the network. + verbose: if true, prints the progression bar of the sampling process. + seg: if diffusion model is instance of SPADEDiffusionModel, or autoencoder_model + is instance of SPADEAutoencoderKL, segmentation must be provided. + """ + + if ( + isinstance(autoencoder_model, SPADEAutoencoderKL) + and isinstance(diffusion_model, SPADEDiffusionModelUNet) + and autoencoder_model.decoder.label_nc != diffusion_model.label_nc + ): + raise ValueError( + "If both autoencoder_model and diffusion_model implement SPADE, the number of semantic" + "labels for each must be compatible. Got {autoencoder_model.decoder.label_nc} and {diffusion_model.label_nc}" + ) + + if cn_cond.shape[2:] != input_noise.shape[2:]: + cn_cond = F.interpolate(cn_cond, input_noise.shape[2:]) + + outputs = super().sample( + input_noise=input_noise, + diffusion_model=diffusion_model, + controlnet=controlnet, + cn_cond=cn_cond, + scheduler=scheduler, + save_intermediates=save_intermediates, + intermediate_steps=intermediate_steps, + conditioning=conditioning, + mode=mode, + verbose=verbose, + seg=seg, + ) + + if save_intermediates: + latent, latent_intermediates = outputs + else: + latent = outputs + + if self.autoencoder_latent_shape is not None: + latent = torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(latent)], 0) + latent_intermediates = [ + torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0) for l in latent_intermediates + ] + + decode = autoencoder_model.decode_stage_2_outputs + if isinstance(autoencoder_model, SPADEAutoencoderKL): + decode = partial(autoencoder_model.decode_stage_2_outputs, seg=seg) + + image = decode(latent / self.scale_factor) + + if save_intermediates: + intermediates = [] + for latent_intermediate in latent_intermediates: + decode = autoencoder_model.decode_stage_2_outputs + if isinstance(autoencoder_model, SPADEAutoencoderKL): + decode = partial(autoencoder_model.decode_stage_2_outputs, seg=seg) + intermediates.append(decode(latent_intermediate / self.scale_factor)) + return image, intermediates + + else: + return image + + @torch.no_grad() + def get_likelihood( # type: ignore[override] + self, + inputs: torch.Tensor, + autoencoder_model: AutoencoderKL | VQVAE, + diffusion_model: DiffusionModelUNet, + controlnet: ControlNet, + cn_cond: torch.Tensor, + scheduler: Scheduler | None = None, + save_intermediates: bool | None = False, + conditioning: torch.Tensor | None = None, + mode: str = "crossattn", + original_input_range: tuple | None = (0, 255), + scaled_input_range: tuple | None = (0, 1), + verbose: bool = True, + resample_latent_likelihoods: bool = False, + resample_interpolation_mode: str = "nearest", + seg: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: + """ + Computes the log-likelihoods of the latent representations of the input. + + Args: + inputs: input images, NxCxHxW[xD] + autoencoder_model: first stage model. + diffusion_model: model to compute likelihood from + controlnet: instance of ControlNet model. + cn_cond: conditioning tensor for the ControlNet network. + scheduler: diffusion scheduler. If none provided will use the class attribute scheduler + save_intermediates: save the intermediate spatial KL maps + conditioning: Conditioning for network input. + mode: Conditioning mode for the network. + original_input_range: the [min,max] intensity range of the input data before any scaling was applied. + scaled_input_range: the [min,max] intensity range of the input data after scaling. + verbose: if true, prints the progression bar of the sampling process. + resample_latent_likelihoods: if true, resamples the intermediate likelihood maps to have the same spatial + dimension as the input images. + resample_interpolation_mode: if use resample_latent_likelihoods, select interpolation 'nearest', 'bilinear', + or 'trilinear; + seg: if diffusion model is instance of SPADEDiffusionModel, or autoencoder_model + is instance of SPADEAutoencoderKL, segmentation must be provided. + """ + if resample_latent_likelihoods and resample_interpolation_mode not in ("nearest", "bilinear", "trilinear"): + raise ValueError( + f"resample_interpolation mode should be either nearest, bilinear, or trilinear, got {resample_interpolation_mode}" + ) + + latents = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor + + if cn_cond.shape[2:] != latents.shape[2:]: + cn_cond = F.interpolate(cn_cond, latents.shape[2:]) + + if self.ldm_latent_shape is not None: + latents = torch.stack([self.ldm_resizer(i) for i in decollate_batch(latents)], 0) + + outputs = super().get_likelihood( + inputs=latents, + diffusion_model=diffusion_model, + controlnet=controlnet, + cn_cond=cn_cond, + scheduler=scheduler, + save_intermediates=save_intermediates, + conditioning=conditioning, + mode=mode, + verbose=verbose, + seg=seg, + ) + + if save_intermediates and resample_latent_likelihoods: + intermediates = outputs[1] + resizer = nn.Upsample(size=inputs.shape[2:], mode=resample_interpolation_mode) + intermediates = [resizer(x) for x in intermediates] + outputs = (outputs[0], intermediates) + return outputs + + +class VQVAETransformerInferer(nn.Module): + """ + Class to perform inference with a VQVAE + Transformer model. + """ + + def __init__(self) -> None: + Inferer.__init__(self) + + def __call__( + self, + inputs: torch.Tensor, + vqvae_model: VQVAE, + transformer_model: DecoderOnlyTransformer, + ordering: Ordering, + condition: torch.Tensor | None = None, + return_latent: bool = False, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor, tuple]: + """ + Implements the forward pass for a supervised training iteration. + + Args: + inputs: input image to which the latent representation will be extracted. + vqvae_model: first stage model. + transformer_model: autoregressive transformer model. + ordering: ordering of the quantised latent representation. + return_latent: also return latent sequence and spatial dim of the latent. + condition: conditioning for network input. + """ + with torch.no_grad(): + latent = vqvae_model.index_quantize(inputs) + + latent_spatial_dim = tuple(latent.shape[1:]) + latent = latent.reshape(latent.shape[0], -1) + latent = latent[:, ordering.get_sequence_ordering()] + + # get the targets for the loss + target = latent.clone() + # Use the value from vqvae_model's num_embeddings as the starting token, the "Begin Of Sentence" (BOS) token. + # Note the transformer_model must have vqvae_model.num_embeddings + 1 defined as num_tokens. + latent = F.pad(latent, (1, 0), "constant", vqvae_model.num_embeddings) + # crop the last token as we do not need the probability of the token that follows it + latent = latent[:, :-1] + latent = latent.long() + + # train on a part of the sequence if it is longer than max_seq_length + seq_len = latent.shape[1] + max_seq_len = transformer_model.max_seq_len + if max_seq_len < seq_len: + start = int(torch.randint(low=0, high=seq_len + 1 - max_seq_len, size=(1,)).item()) + else: + start = 0 + prediction: torch.Tensor = transformer_model(x=latent[:, start : start + max_seq_len], context=condition) + if return_latent: + return prediction, target[:, start : start + max_seq_len], latent_spatial_dim + else: + return prediction + + @torch.no_grad() + def sample( + self, + latent_spatial_dim: tuple[int, int, int] | tuple[int, int], + starting_tokens: torch.Tensor, + vqvae_model: VQVAE, + transformer_model: DecoderOnlyTransformer, + ordering: Ordering, + conditioning: torch.Tensor | None = None, + temperature: float = 1.0, + top_k: int | None = None, + verbose: bool = True, + ) -> torch.Tensor: + """ + Sampling function for the VQVAE + Transformer model. + + Args: + latent_spatial_dim: shape of the sampled image. + starting_tokens: starting tokens for the sampling. It must be vqvae_model.num_embeddings value. + vqvae_model: first stage model. + transformer_model: model to sample from. + conditioning: Conditioning for network input. + temperature: temperature for sampling. + top_k: top k sampling. + verbose: if true, prints the progression bar of the sampling process. + """ + seq_len = math.prod(latent_spatial_dim) + + if verbose and has_tqdm: + progress_bar = tqdm(range(seq_len)) + else: + progress_bar = iter(range(seq_len)) + + latent_seq = starting_tokens.long() + for _ in progress_bar: + # if the sequence context is growing too long we must crop it at block_size + if latent_seq.size(1) <= transformer_model.max_seq_len: + idx_cond = latent_seq + else: + idx_cond = latent_seq[:, -transformer_model.max_seq_len :] + + # forward the model to get the logits for the index in the sequence + logits = transformer_model(x=idx_cond, context=conditioning) + # pluck the logits at the final step and scale by desired temperature + logits = logits[:, -1, :] / temperature + # optionally crop the logits to only the top k options + if top_k is not None: + v, _ = torch.topk(logits, min(top_k, logits.size(-1))) + logits[logits < v[:, [-1]]] = -float("Inf") + # apply softmax to convert logits to (normalized) probabilities + probs = F.softmax(logits, dim=-1) + # remove the chance to be sampled the BOS token + probs[:, vqvae_model.num_embeddings] = 0 + # sample from the distribution + idx_next = torch.multinomial(probs, num_samples=1) + # append sampled index to the running sequence and continue + latent_seq = torch.cat((latent_seq, idx_next), dim=1) + + latent_seq = latent_seq[:, 1:] + latent_seq = latent_seq[:, ordering.get_revert_sequence_ordering()] + latent = latent_seq.reshape((starting_tokens.shape[0],) + latent_spatial_dim) + + return vqvae_model.decode_samples(latent) + + @torch.no_grad() + def get_likelihood( + self, + inputs: torch.Tensor, + vqvae_model: VQVAE, + transformer_model: DecoderOnlyTransformer, + ordering: Ordering, + condition: torch.Tensor | None = None, + resample_latent_likelihoods: bool = False, + resample_interpolation_mode: str = "nearest", + verbose: bool = False, + ) -> torch.Tensor: + """ + Computes the log-likelihoods of the latent representations of the input. + + Args: + inputs: input images, NxCxHxW[xD] + vqvae_model: first stage model. + transformer_model: autoregressive transformer model. + ordering: ordering of the quantised latent representation. + condition: conditioning for network input. + resample_latent_likelihoods: if true, resamples the intermediate likelihood maps to have the same spatial + dimension as the input images. + resample_interpolation_mode: if use resample_latent_likelihoods, select interpolation 'nearest', 'bilinear', + or 'trilinear; + verbose: if true, prints the progression bar of the sampling process. + + """ + if resample_latent_likelihoods and resample_interpolation_mode not in ("nearest", "bilinear", "trilinear"): + raise ValueError( + f"resample_interpolation mode should be either nearest, bilinear, or trilinear, got {resample_interpolation_mode}" + ) + + with torch.no_grad(): + latent = vqvae_model.index_quantize(inputs) + + latent_spatial_dim = tuple(latent.shape[1:]) + latent = latent.reshape(latent.shape[0], -1) + latent = latent[:, ordering.get_sequence_ordering()] + seq_len = math.prod(latent_spatial_dim) + + # Use the value from vqvae_model's num_embeddings as the starting token, the "Begin Of Sentence" (BOS) token. + # Note the transformer_model must have vqvae_model.num_embeddings + 1 defined as num_tokens. + latent = F.pad(latent, (1, 0), "constant", vqvae_model.num_embeddings) + latent = latent.long() + + # get the first batch, up to max_seq_length, efficiently + logits = transformer_model(x=latent[:, : transformer_model.max_seq_len], context=condition) + probs = F.softmax(logits, dim=-1) + # target token for each set of logits is the next token along + target = latent[:, 1:] + probs = torch.gather(probs, 2, target[:, : transformer_model.max_seq_len].unsqueeze(2)).squeeze(2) + + # if we have not covered the full sequence we continue with inefficient looping + if probs.shape[1] < target.shape[1]: + if verbose and has_tqdm: + progress_bar = tqdm(range(transformer_model.max_seq_len, seq_len)) + else: + progress_bar = iter(range(transformer_model.max_seq_len, seq_len)) + + for i in progress_bar: + idx_cond = latent[:, i + 1 - transformer_model.max_seq_len : i + 1] + # forward the model to get the logits for the index in the sequence + logits = transformer_model(x=idx_cond, context=condition) + # pluck the logits at the final step + logits = logits[:, -1, :] + # apply softmax to convert logits to (normalized) probabilities + p = F.softmax(logits, dim=-1) + # select correct values and append + p = torch.gather(p, 1, target[:, i].unsqueeze(1)) + + probs = torch.cat((probs, p), dim=1) + + # convert to log-likelihood + probs = torch.log(probs) + + # reshape + probs = probs[:, ordering.get_revert_sequence_ordering()] + probs_reshaped = probs.reshape((inputs.shape[0],) + latent_spatial_dim) + if resample_latent_likelihoods: + resizer = nn.Upsample(size=inputs.shape[2:], mode=resample_interpolation_mode) + probs_reshaped = resizer(probs_reshaped[:, None, ...]) + + return probs_reshaped diff --git a/monai/networks/blocks/__init__.py b/monai/networks/blocks/__init__.py index e67cb3376f..47abc4a1c4 100644 --- a/monai/networks/blocks/__init__.py +++ b/monai/networks/blocks/__init__.py @@ -17,6 +17,7 @@ from .backbone_fpn_utils import BackboneWithFPN from .convolutions import Convolution, ResidualUnit from .crf import CRF +from .crossattention import CrossAttentionBlock from .denseblock import ConvDenseBlock, DenseBlock from .dints_block import ActiConvNormBlock, FactorizedIncreaseBlock, FactorizedReduceBlock, P3DActiConvNormBlock from .downsample import MaxAvgPool @@ -30,6 +31,8 @@ from .regunet_block import RegistrationDownSampleBlock, RegistrationExtractionBlock, RegistrationResidualConvBlock from .segresnet_block import ResBlock from .selfattention import SABlock +from .spade_norm import SPADE +from .spatialattention import SpatialAttentionBlock from .squeeze_and_excitation import ( ChannelSELayer, ResidualSELayer, diff --git a/monai/networks/blocks/attention_utils.py b/monai/networks/blocks/attention_utils.py new file mode 100644 index 0000000000..8c9002a16e --- /dev/null +++ b/monai/networks/blocks/attention_utils.py @@ -0,0 +1,128 @@ +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Tuple + +import torch +import torch.nn.functional as F +from torch import nn + + +def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: + """ + Get relative positional embeddings according to the relative positions of + query and key sizes. + + Args: + q_size (int): size of query q. + k_size (int): size of key k. + rel_pos (Tensor): relative position embeddings (L, C). + + Returns: + Extracted positional embeddings according to relative positions. + """ + rel_pos_resized: torch.Tensor = torch.Tensor() + max_rel_dist = int(2 * max(q_size, k_size) - 1) + # Interpolate rel pos if needed. + if rel_pos.shape[0] != max_rel_dist: + # Interpolate rel pos. + rel_pos_resized = F.interpolate( + rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), size=max_rel_dist, mode="linear" + ) + rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) + else: + rel_pos_resized = rel_pos + + # Scale the coords with short length if shapes for q and k are different. + q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) + k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) + relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) + + return rel_pos_resized[relative_coords.long()] + + +def add_decomposed_rel_pos( + attn: torch.Tensor, q: torch.Tensor, rel_pos_lst: nn.ParameterList, q_size: Tuple, k_size: Tuple +) -> torch.Tensor: + r""" + Calculate decomposed Relative Positional Embeddings from mvitv2 implementation: + https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py + + Only 2D and 3D are supported. + + Encoding the relative position of tokens in the attention matrix: tokens spaced a distance + `d` apart will have the same embedding value (unlike absolute positional embedding). + + .. math:: + Attn_{logits}(Q, K) = (QK^{T} + E_{rel})*scale + + where + + .. math:: + E_{ij}^{(rel)} = Q_{i}.R_{p(i), p(j)} + + with :math:`R_{p(i), p(j)} \in R^{dim}` and :math:`p(i), p(j)`, + respectively spatial positions of element :math:`i` and :math:`j` + + When using "decomposed" relative positional embedding, positional embedding is defined ("decomposed") as follow: + + .. math:: + R_{p(i), p(j)} = R^{d1}_{d1(i), d1(j)} + ... + R^{dn}_{dn(i), dn(j)} + + with :math:`n = 1...dim` + + Decomposed relative positional embedding reduces the complexity from :math:`\mathcal{O}(d1*...*dn)` to + :math:`\mathcal{O}(d1+...+dn)` compared with classical relative positional embedding. + + Args: + attn (Tensor): attention map. + q (Tensor): query q in the attention layer with shape (B, s_dim_1 * ... * s_dim_n, C). + rel_pos_lst (ParameterList): relative position embeddings for each axis: rel_pos_lst[n] for nth axis. + q_size (Tuple): spatial sequence size of query q with (q_dim_1, ..., q_dim_n). + k_size (Tuple): spatial sequence size of key k with (k_dim_1, ..., k_dim_n). + + Returns: + attn (Tensor): attention logits with added relative positional embeddings. + """ + rh = get_rel_pos(q_size[0], k_size[0], rel_pos_lst[0]) + rw = get_rel_pos(q_size[1], k_size[1], rel_pos_lst[1]) + + batch, _, dim = q.shape + + if len(rel_pos_lst) == 2: + q_h, q_w = q_size[:2] + k_h, k_w = k_size[:2] + r_q = q.reshape(batch, q_h, q_w, dim) + rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, rh) + rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, rw) + + attn = (attn.view(batch, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]).view( + batch, q_h * q_w, k_h * k_w + ) + elif len(rel_pos_lst) == 3: + q_h, q_w, q_d = q_size[:3] + k_h, k_w, k_d = k_size[:3] + + rd = get_rel_pos(q_d, k_d, rel_pos_lst[2]) + + r_q = q.reshape(batch, q_h, q_w, q_d, dim) + rel_h = torch.einsum("bhwdc,hkc->bhwdk", r_q, rh) + rel_w = torch.einsum("bhwdc,wkc->bhwdk", r_q, rw) + rel_d = torch.einsum("bhwdc,wkc->bhwdk", r_q, rd) + + attn = ( + attn.view(batch, q_h, q_w, q_d, k_h, k_w, k_d) + + rel_h[:, :, :, :, None, None] + + rel_w[:, :, :, None, :, None] + + rel_d[:, :, :, None, None, :] + ).view(batch, q_h * q_w * q_d, k_h * k_w * k_d) + + return attn diff --git a/monai/networks/blocks/crossattention.py b/monai/networks/blocks/crossattention.py new file mode 100644 index 0000000000..dc1d5d388e --- /dev/null +++ b/monai/networks/blocks/crossattention.py @@ -0,0 +1,166 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Optional, Tuple + +import torch +import torch.nn as nn + +from monai.networks.layers.utils import get_rel_pos_embedding_layer +from monai.utils import optional_import + +Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange") + + +class CrossAttentionBlock(nn.Module): + """ + A cross-attention block, based on: "Dosovitskiy et al., + An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale " + One can setup relative positional embedding as described in + """ + + def __init__( + self, + hidden_size: int, + num_heads: int, + dropout_rate: float = 0.0, + hidden_input_size: int | None = None, + context_input_size: int | None = None, + dim_head: int | None = None, + qkv_bias: bool = False, + save_attn: bool = False, + causal: bool = False, + sequence_length: int | None = None, + rel_pos_embedding: Optional[str] = None, + input_size: Optional[Tuple] = None, + attention_dtype: Optional[torch.dtype] = None, + ) -> None: + """ + Args: + hidden_size (int): dimension of hidden layer. + num_heads (int): number of attention heads. + dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0. + hidden_input_size (int, optional): dimension of the input tensor. Defaults to hidden_size. + context_input_size (int, optional): dimension of the context tensor. Defaults to hidden_size. + dim_head (int, optional): dimension of each head. Defaults to hidden_size // num_heads. + qkv_bias (bool, optional): bias term for the qkv linear layer. Defaults to False. + save_attn (bool, optional): to make accessible the attention matrix. Defaults to False. + causal: whether to use causal attention. + sequence_length: if causal is True, it is necessary to specify the sequence length. + rel_pos_embedding (str, optional): Add relative positional embeddings to the attention map. + For now only "decomposed" is supported (see https://arxiv.org/abs/2112.01526). 2D and 3D are supported. + input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative + positional parameter size. + attention_dtype: cast attention operations to this dtype. + """ + + super().__init__() + + if not (0 <= dropout_rate <= 1): + raise ValueError("dropout_rate should be between 0 and 1.") + + if dim_head: + inner_size = num_heads * dim_head + self.head_dim = dim_head + else: + if hidden_size % num_heads != 0: + raise ValueError("hidden size should be divisible by num_heads.") + inner_size = hidden_size + self.head_dim = hidden_size // num_heads + + if causal and sequence_length is None: + raise ValueError("sequence_length is necessary for causal attention.") + + self.num_heads = num_heads + self.hidden_input_size = hidden_input_size if hidden_input_size else hidden_size + self.context_input_size = context_input_size if context_input_size else hidden_size + self.out_proj = nn.Linear(inner_size, self.hidden_input_size) + # key, query, value projections + self.to_q = nn.Linear(self.hidden_input_size, inner_size, bias=qkv_bias) + self.to_k = nn.Linear(self.context_input_size, inner_size, bias=qkv_bias) + self.to_v = nn.Linear(self.context_input_size, inner_size, bias=qkv_bias) + self.input_rearrange = Rearrange("b h (l d) -> b l h d", l=num_heads) + + self.out_rearrange = Rearrange("b h l d -> b l (h d)") + self.drop_output = nn.Dropout(dropout_rate) + self.drop_weights = nn.Dropout(dropout_rate) + + self.scale = self.head_dim**-0.5 + self.save_attn = save_attn + self.attention_dtype = attention_dtype + + self.causal = causal + self.sequence_length = sequence_length + + if causal and sequence_length is not None: + # causal mask to ensure that attention is only applied to the left in the input sequence + self.register_buffer( + "causal_mask", + torch.tril(torch.ones(sequence_length, sequence_length)).view(1, 1, sequence_length, sequence_length), + ) + self.causal_mask: torch.Tensor + + self.att_mat = torch.Tensor() + self.rel_positional_embedding = ( + get_rel_pos_embedding_layer(rel_pos_embedding, input_size, self.head_dim, self.num_heads) + if rel_pos_embedding is not None + else None + ) + self.input_size = input_size + + def forward(self, x: torch.Tensor, context: torch.Tensor | None = None): + """ + Args: + x (torch.Tensor): input tensor. B x (s_dim_1 * ... * s_dim_n) x C + context (torch.Tensor, optional): context tensor. B x (s_dim_1 * ... * s_dim_n) x C + + Return: + torch.Tensor: B x (s_dim_1 * ... * s_dim_n) x C + """ + # calculate query, key, values for all heads in batch and move head forward to be the batch dim + b, t, c = x.size() # batch size, sequence length, embedding dimensionality (hidden_size) + + q = self.to_q(x) + kv = context if context is not None else x + _, kv_t, _ = kv.size() + k = self.to_k(kv) + v = self.to_v(kv) + + if self.attention_dtype is not None: + q = q.to(self.attention_dtype) + k = k.to(self.attention_dtype) + + q = q.view(b, t, self.num_heads, c // self.num_heads).transpose(1, 2) # (b, nh, t, hs) + k = k.view(b, kv_t, self.num_heads, c // self.num_heads).transpose(1, 2) # (b, nh, kv_t, hs) + v = v.view(b, kv_t, self.num_heads, c // self.num_heads).transpose(1, 2) # (b, nh, kv_t, hs) + att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale + + # apply relative positional embedding if defined + att_mat = self.rel_positional_embedding(x, att_mat, q) if self.rel_positional_embedding is not None else att_mat + + if self.causal: + att_mat = att_mat.masked_fill(self.causal_mask[:, :, :t, :kv_t] == 0, float("-inf")) + + att_mat = att_mat.softmax(dim=-1) + + if self.save_attn: + # no gradients and new tensor; + # https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html + self.att_mat = att_mat.detach() + + att_mat = self.drop_weights(att_mat) + x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v) + x = self.out_rearrange(x) + x = self.out_proj(x) + x = self.drop_output(x) + return x diff --git a/monai/networks/blocks/rel_pos_embedding.py b/monai/networks/blocks/rel_pos_embedding.py new file mode 100644 index 0000000000..e53e5841b0 --- /dev/null +++ b/monai/networks/blocks/rel_pos_embedding.py @@ -0,0 +1,56 @@ +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Iterable, Tuple + +import torch +from torch import nn + +from monai.networks.blocks.attention_utils import add_decomposed_rel_pos +from monai.utils.misc import ensure_tuple_size + + +class DecomposedRelativePosEmbedding(nn.Module): + def __init__(self, s_input_dims: Tuple[int, int] | Tuple[int, int, int], c_dim: int, num_heads: int) -> None: + """ + Args: + s_input_dims (Tuple): input spatial dimension. (H, W) or (H, W, D) + c_dim (int): channel dimension + num_heads(int): number of attention heads + """ + super().__init__() + + # validate inputs + if not isinstance(s_input_dims, Iterable) or len(s_input_dims) not in [2, 3]: + raise ValueError("s_input_dims must be set as follows: (H, W) or (H, W, D)") + + self.s_input_dims = s_input_dims + self.c_dim = c_dim + self.num_heads = num_heads + self.rel_pos_arr = nn.ParameterList( + [nn.Parameter(torch.zeros(2 * dim_input_size - 1, c_dim)) for dim_input_size in s_input_dims] + ) + + def forward(self, x: torch.Tensor, att_mat: torch.Tensor, q: torch.Tensor) -> torch.Tensor: + """""" + batch = x.shape[0] + h, w, d = ensure_tuple_size(self.s_input_dims, 3, 1) + + att_mat = add_decomposed_rel_pos( + att_mat.contiguous().view(batch * self.num_heads, h * w * d, h * w * d), + q.contiguous().view(batch * self.num_heads, h * w * d, -1), + self.rel_pos_arr, + (h, w) if d == 1 else (h, w, d), + (h, w) if d == 1 else (h, w, d), + ) + + att_mat = att_mat.reshape(batch, self.num_heads, h * w * d, h * w * d) + return att_mat diff --git a/monai/networks/blocks/selfattention.py b/monai/networks/blocks/selfattention.py index 7b410b1a7c..9905e7d036 100644 --- a/monai/networks/blocks/selfattention.py +++ b/monai/networks/blocks/selfattention.py @@ -11,9 +11,12 @@ from __future__ import annotations +from typing import Optional, Tuple + import torch import torch.nn as nn +from monai.networks.layers.utils import get_rel_pos_embedding_layer from monai.utils import optional_import Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange") @@ -33,6 +36,12 @@ def __init__( qkv_bias: bool = False, save_attn: bool = False, dim_head: int | None = None, + hidden_input_size: int | None = None, + causal: bool = False, + sequence_length: int | None = None, + rel_pos_embedding: Optional[str] = None, + input_size: Optional[Tuple] = None, + attention_dtype: Optional[torch.dtype] = None, ) -> None: """ Args: @@ -42,6 +51,14 @@ def __init__( qkv_bias (bool, optional): bias term for the qkv linear layer. Defaults to False. save_attn (bool, optional): to make accessible the attention matrix. Defaults to False. dim_head (int, optional): dimension of each head. Defaults to hidden_size // num_heads. + hidden_input_size (int, optional): dimension of the input tensor. Defaults to hidden_size. + causal: whether to use causal attention (see https://arxiv.org/abs/1706.03762). + sequence_length: if causal is True, it is necessary to specify the sequence length. + rel_pos_embedding (str, optional): Add relative positional embeddings to the attention map. + For now only "decomposed" is supported (see https://arxiv.org/abs/2112.01526). 2D and 3D are supported. + input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative + positional parameter size. + attention_dtype: cast attention operations to this dtype. """ @@ -53,12 +70,23 @@ def __init__( if hidden_size % num_heads != 0: raise ValueError("hidden size should be divisible by num_heads.") + if dim_head: + self.inner_dim = num_heads * dim_head + self.dim_head = dim_head + else: + if hidden_size % num_heads != 0: + raise ValueError("hidden size should be divisible by num_heads.") + self.inner_dim = hidden_size + self.dim_head = hidden_size // num_heads + + if causal and sequence_length is None: + raise ValueError("sequence_length is necessary for causal attention.") + self.num_heads = num_heads - self.dim_head = hidden_size // num_heads if dim_head is None else dim_head - self.inner_dim = self.dim_head * num_heads + self.hidden_input_size = hidden_input_size if hidden_input_size else hidden_size + self.out_proj = nn.Linear(self.inner_dim, self.hidden_input_size) - self.out_proj = nn.Linear(self.inner_dim, hidden_size) - self.qkv = nn.Linear(hidden_size, self.inner_dim * 3, bias=qkv_bias) + self.qkv = nn.Linear(self.hidden_input_size, self.inner_dim * 3, bias=qkv_bias) self.input_rearrange = Rearrange("b h (qkv l d) -> qkv b l h d", qkv=3, l=num_heads) self.out_rearrange = Rearrange("b h l d -> b l (h d)") self.drop_output = nn.Dropout(dropout_rate) @@ -66,11 +94,50 @@ def __init__( self.scale = self.dim_head**-0.5 self.save_attn = save_attn self.att_mat = torch.Tensor() + self.attention_dtype = attention_dtype + self.causal = causal + self.sequence_length = sequence_length + + if causal and sequence_length is not None: + # causal mask to ensure that attention is only applied to the left in the input sequence + self.register_buffer( + "causal_mask", + torch.tril(torch.ones(sequence_length, sequence_length)).view(1, 1, sequence_length, sequence_length), + ) + self.causal_mask: torch.Tensor + + self.rel_positional_embedding = ( + get_rel_pos_embedding_layer(rel_pos_embedding, input_size, self.dim_head, self.num_heads) + if rel_pos_embedding is not None + else None + ) + self.input_size = input_size def forward(self, x): + """ + Args: + x (torch.Tensor): input tensor. B x (s_dim_1 * ... * s_dim_n) x C + + Return: + torch.Tensor: B x (s_dim_1 * ... * s_dim_n) x C + """ output = self.input_rearrange(self.qkv(x)) q, k, v = output[0], output[1], output[2] - att_mat = (torch.einsum("blxd,blyd->blxy", q, k) * self.scale).softmax(dim=-1) + + if self.attention_dtype is not None: + q = q.to(self.attention_dtype) + k = k.to(self.attention_dtype) + + att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale + + # apply relative positional embedding if defined + att_mat = self.rel_positional_embedding(x, att_mat, q) if self.rel_positional_embedding is not None else att_mat + + if self.causal: + att_mat = att_mat.masked_fill(self.causal_mask[:, :, : x.shape[1], : x.shape[1]] == 0, float("-inf")) + + att_mat = att_mat.softmax(dim=-1) + if self.save_attn: # no gradients and new tensor; # https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html diff --git a/monai/networks/blocks/spade_norm.py b/monai/networks/blocks/spade_norm.py new file mode 100644 index 0000000000..343dfa9ec0 --- /dev/null +++ b/monai/networks/blocks/spade_norm.py @@ -0,0 +1,95 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from monai.networks.blocks import Convolution +from monai.networks.layers.utils import get_norm_layer + + +class SPADE(nn.Module): + """ + Spatially Adaptive Normalization (SPADE) block, allowing for normalization of activations conditioned on a + semantic map. This block is used in SPADE-based image-to-image translation models, as described in + Semantic Image Synthesis with Spatially-Adaptive Normalization (https://arxiv.org/abs/1903.07291). + + Args: + label_nc: number of semantic labels + norm_nc: number of output channels + kernel_size: kernel size + spatial_dims: number of spatial dimensions + hidden_channels: number of channels in the intermediate gamma and beta layers + norm: type of base normalisation used before applying the SPADE normalisation + norm_params: parameters for the base normalisation + """ + + def __init__( + self, + label_nc: int, + norm_nc: int, + kernel_size: int = 3, + spatial_dims: int = 2, + hidden_channels: int = 64, + norm: str | tuple = "INSTANCE", + norm_params: dict | None = None, + ) -> None: + super().__init__() + + if norm_params is None: + norm_params = {} + if len(norm_params) != 0: + norm = (norm, norm_params) + self.param_free_norm = get_norm_layer(norm, spatial_dims=spatial_dims, channels=norm_nc) + self.mlp_shared = Convolution( + spatial_dims=spatial_dims, + in_channels=label_nc, + out_channels=hidden_channels, + kernel_size=kernel_size, + norm=None, + act="LEAKYRELU", + ) + self.mlp_gamma = Convolution( + spatial_dims=spatial_dims, + in_channels=hidden_channels, + out_channels=norm_nc, + kernel_size=kernel_size, + act=None, + ) + self.mlp_beta = Convolution( + spatial_dims=spatial_dims, + in_channels=hidden_channels, + out_channels=norm_nc, + kernel_size=kernel_size, + act=None, + ) + + def forward(self, x: torch.Tensor, segmap: torch.Tensor) -> torch.Tensor: + """ + Args: + x: input tensor with shape (B, C, [spatial-dimensions]) where C is the number of semantic channels. + segmap: input segmentation map (B, C, [spatial-dimensions]) where C is the number of semantic channels. + The map will be interpolated to the dimension of x internally. + """ + + # Part 1. generate parameter-free normalized activations + normalized = self.param_free_norm(x.contiguous()) + + # Part 2. produce scaling and bias conditioned on semantic map + segmap = F.interpolate(segmap, size=x.size()[2:], mode="nearest") + actv = self.mlp_shared(segmap) + gamma = self.mlp_gamma(actv) + beta = self.mlp_beta(actv) + out: torch.Tensor = normalized * (1 + gamma) + beta + return out diff --git a/monai/networks/blocks/spatialattention.py b/monai/networks/blocks/spatialattention.py new file mode 100644 index 0000000000..75319853d9 --- /dev/null +++ b/monai/networks/blocks/spatialattention.py @@ -0,0 +1,82 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Optional + +import torch +import torch.nn as nn + +from monai.networks.blocks import SABlock +from monai.utils import optional_import + +Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange") + + +class SpatialAttentionBlock(nn.Module): + """Perform spatial self-attention on the input tensor. + + The input tensor is reshaped to B x (x_dim * y_dim [ * z_dim]) x C, where C is the number of channels, and then + self-attention is performed on the reshaped tensor. The output tensor is reshaped back to the original shape. + + Args: + spatial_dims: number of spatial dimensions, could be 1, 2, or 3. + num_channels: number of input channels. Must be divisible by num_head_channels. + num_head_channels: number of channels per head. + attention_dtype: cast attention operations to this dtype. + + """ + + def __init__( + self, + spatial_dims: int, + num_channels: int, + num_head_channels: int | None = None, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + attention_dtype: Optional[torch.dtype] = None, + ) -> None: + super().__init__() + + self.spatial_dims = spatial_dims + self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=num_channels, eps=norm_eps, affine=True) + # check num_head_channels is divisible by num_channels + if num_head_channels is not None and num_channels % num_head_channels != 0: + raise ValueError("num_channels must be divisible by num_head_channels") + num_heads = num_channels // num_head_channels if num_head_channels is not None else 1 + self.attn = SABlock( + hidden_size=num_channels, num_heads=num_heads, qkv_bias=True, attention_dtype=attention_dtype + ) + + def forward(self, x: torch.Tensor): + residual = x + + if self.spatial_dims == 1: + h = x.shape[2] + rearrange_input = Rearrange("b c h -> b h c") + rearrange_output = Rearrange("b h c -> b c h", h=h) + if self.spatial_dims == 2: + h, w = x.shape[2], x.shape[3] + rearrange_input = Rearrange("b c h w -> b (h w) c") + rearrange_output = Rearrange("b (h w) c -> b c h w", h=h, w=w) + else: + h, w, d = x.shape[2], x.shape[3], x.shape[4] + rearrange_input = Rearrange("b c h w d -> b (h w d) c") + rearrange_output = Rearrange("b (h w d) c -> b c h w d", h=h, w=w, d=d) + + x = self.norm(x) + x = rearrange_input(x) # B x (x_dim * y_dim [ * z_dim]) x C + + x = self.attn(x) + x = rearrange_output(x) # B x x C x x_dim * y_dim * [z_dim] + x = x + residual + return x diff --git a/monai/networks/blocks/transformerblock.py b/monai/networks/blocks/transformerblock.py index ddf959dad2..2458902cba 100644 --- a/monai/networks/blocks/transformerblock.py +++ b/monai/networks/blocks/transformerblock.py @@ -11,10 +11,10 @@ from __future__ import annotations +import torch import torch.nn as nn -from monai.networks.blocks.mlp import MLPBlock -from monai.networks.blocks.selfattention import SABlock +from monai.networks.blocks import CrossAttentionBlock, MLPBlock, SABlock class TransformerBlock(nn.Module): @@ -31,6 +31,9 @@ def __init__( dropout_rate: float = 0.0, qkv_bias: bool = False, save_attn: bool = False, + causal: bool = False, + sequence_length: int | None = None, + with_cross_attention: bool = False, ) -> None: """ Args: @@ -53,10 +56,27 @@ def __init__( self.mlp = MLPBlock(hidden_size, mlp_dim, dropout_rate) self.norm1 = nn.LayerNorm(hidden_size) - self.attn = SABlock(hidden_size, num_heads, dropout_rate, qkv_bias, save_attn) + self.attn = SABlock( + hidden_size, + num_heads, + dropout_rate, + qkv_bias=qkv_bias, + save_attn=save_attn, + causal=causal, + sequence_length=sequence_length, + ) self.norm2 = nn.LayerNorm(hidden_size) + self.with_cross_attention = with_cross_attention - def forward(self, x): + if self.with_cross_attention: + self.norm_cross_attn = nn.LayerNorm(hidden_size) + self.cross_attn = CrossAttentionBlock( + hidden_size=hidden_size, num_heads=num_heads, dropout_rate=dropout_rate, qkv_bias=qkv_bias, causal=False + ) + + def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor: x = x + self.attn(self.norm1(x)) + if self.with_cross_attention: + x = x + self.cross_attn(self.norm_cross_attn(x), context=context) x = x + self.mlp(self.norm2(x)) return x diff --git a/monai/networks/blocks/upsample.py b/monai/networks/blocks/upsample.py index dee9966919..50fd39a70b 100644 --- a/monai/networks/blocks/upsample.py +++ b/monai/networks/blocks/upsample.py @@ -17,8 +17,8 @@ import torch.nn as nn from monai.networks.layers.factories import Conv, Pad, Pool -from monai.networks.utils import icnr_init, pixelshuffle -from monai.utils import InterpolateMode, UpsampleMode, ensure_tuple_rep, look_up_option +from monai.networks.utils import CastTempType, icnr_init, pixelshuffle +from monai.utils import InterpolateMode, UpsampleMode, ensure_tuple_rep, look_up_option, pytorch_after __all__ = ["Upsample", "UpSample", "SubpixelUpsample", "Subpixelupsample", "SubpixelUpSample"] @@ -50,6 +50,7 @@ def __init__( size: tuple[int] | int | None = None, mode: UpsampleMode | str = UpsampleMode.DECONV, pre_conv: nn.Module | str | None = "default", + post_conv: nn.Module | None = None, interp_mode: str = InterpolateMode.LINEAR, align_corners: bool | None = True, bias: bool = True, @@ -71,6 +72,7 @@ def __init__( pre_conv: a conv block applied before upsampling. Defaults to "default". When ``conv_block`` is ``"default"``, one reserved conv layer will be utilized when Only used in the "nontrainable" or "pixelshuffle" mode. + post_conv: a conv block applied after upsampling. Defaults to None. Only used in the "nontrainable" mode. interp_mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``} Only used in the "nontrainable" mode. If ends with ``"linear"`` will use ``spatial dims`` to determine the correct interpolation. @@ -154,15 +156,25 @@ def __init__( linear_mode = [InterpolateMode.LINEAR, InterpolateMode.BILINEAR, InterpolateMode.TRILINEAR] if interp_mode in linear_mode: # choose mode based on dimensions interp_mode = linear_mode[spatial_dims - 1] - self.add_module( - "upsample_non_trainable", - nn.Upsample( - size=size, - scale_factor=None if size else scale_factor_, - mode=interp_mode.value, - align_corners=align_corners, - ), + + upsample = nn.Upsample( + size=size, + scale_factor=None if size else scale_factor_, + mode=interp_mode.value, + align_corners=align_corners, ) + + # Cast to float32 as 'upsample_nearest2d_out_frame' op does not support bfloat16 + # https://github.com/pytorch/pytorch/issues/86679. This issue is solved in PyTorch 2.1 + if pytorch_after(major=2, minor=1): + self.add_module("upsample_non_trainable", upsample) + else: + self.add_module( + "upsample_non_trainable", + CastTempType(initial_type=torch.bfloat16, temporary_type=torch.float32, submodule=upsample), + ) + if post_conv: + self.add_module("postconv", post_conv) elif up_mode == UpsampleMode.PIXELSHUFFLE: self.add_module( "pixelshuffle", diff --git a/monai/networks/layers/__init__.py b/monai/networks/layers/__init__.py index 3a6e4aa554..48c10270b1 100644 --- a/monai/networks/layers/__init__.py +++ b/monai/networks/layers/__init__.py @@ -14,7 +14,7 @@ from .conjugate_gradient import ConjugateGradient from .convutils import calculate_out_shape, gaussian_1d, polyval, same_padding, stride_minus_kernel_padding from .drop_path import DropPath -from .factories import Act, Conv, Dropout, LayerFactory, Norm, Pad, Pool, split_args +from .factories import Act, Conv, Dropout, LayerFactory, Norm, Pad, Pool, RelPosEmbedding, split_args from .filtering import BilateralFilter, PHLFilter, TrainableBilateralFilter, TrainableJointBilateralFilter from .gmm import GaussianMixtureModel from .simplelayers import ( @@ -38,4 +38,5 @@ ) from .spatial_transforms import AffineTransform, grid_count, grid_grad, grid_pull, grid_push from .utils import get_act_layer, get_dropout_layer, get_norm_layer, get_pool_layer +from .vector_quantizer import EMAQuantizer, VectorQuantizer from .weight_init import _no_grad_trunc_normal_, trunc_normal_ diff --git a/monai/networks/layers/factories.py b/monai/networks/layers/factories.py index 4fc2c16f73..29b72a4f37 100644 --- a/monai/networks/layers/factories.py +++ b/monai/networks/layers/factories.py @@ -70,7 +70,7 @@ def use_factory(fact_args): from monai.networks.utils import has_nvfuser_instance_norm from monai.utils import ComponentStore, look_up_option, optional_import -__all__ = ["LayerFactory", "Dropout", "Norm", "Act", "Conv", "Pool", "Pad", "split_args"] +__all__ = ["LayerFactory", "Dropout", "Norm", "Act", "Conv", "Pool", "Pad", "RelPosEmbedding", "split_args"] class LayerFactory(ComponentStore): @@ -201,6 +201,10 @@ def split_args(args): Conv = LayerFactory(name="Convolution layers", description="Factory for creating convolution layers.") Pool = LayerFactory(name="Pooling layers", description="Factory for creating pooling layers.") Pad = LayerFactory(name="Padding layers", description="Factory for creating padding layers.") +RelPosEmbedding = LayerFactory( + name="Relative positional embedding layers", + description="Factory for creating relative positional embedding factory", +) @Dropout.factory_function("dropout") @@ -468,3 +472,10 @@ def constant_pad_factory(dim: int) -> type[nn.ConstantPad1d | nn.ConstantPad2d | """ types = (nn.ConstantPad1d, nn.ConstantPad2d, nn.ConstantPad3d) return types[dim - 1] + + +@RelPosEmbedding.factory_function("decomposed") +def decomposed_rel_pos_embedding() -> type[nn.Module]: + from monai.networks.blocks.rel_pos_embedding import DecomposedRelativePosEmbedding + + return DecomposedRelativePosEmbedding diff --git a/monai/networks/layers/utils.py b/monai/networks/layers/utils.py index ace1af27b6..8676f74638 100644 --- a/monai/networks/layers/utils.py +++ b/monai/networks/layers/utils.py @@ -11,9 +11,11 @@ from __future__ import annotations +from typing import Optional + import torch.nn -from monai.networks.layers.factories import Act, Dropout, Norm, Pool, split_args +from monai.networks.layers.factories import Act, Dropout, Norm, Pool, RelPosEmbedding, split_args from monai.utils import has_option __all__ = ["get_norm_layer", "get_act_layer", "get_dropout_layer", "get_pool_layer"] @@ -124,3 +126,14 @@ def get_pool_layer(name: tuple | str, spatial_dims: int | None = 1): pool_name, pool_args = split_args(name) pool_type = Pool[pool_name, spatial_dims] return pool_type(**pool_args) + + +def get_rel_pos_embedding_layer(name: tuple | str, s_input_dims: Optional[tuple], c_dim: int, num_heads: int): + embedding_name, embedding_args = split_args(name) + embedding_type = RelPosEmbedding[embedding_name] + # create a dictionary with the default values which can be overridden by embedding_args + kw_args = {"s_input_dims": s_input_dims, "c_dim": c_dim, "num_heads": num_heads, **embedding_args} + # filter out unused argument names + kw_args = {k: v for k, v in kw_args.items() if has_option(embedding_type, k)} + + return embedding_type(**kw_args) diff --git a/monai/networks/layers/vector_quantizer.py b/monai/networks/layers/vector_quantizer.py new file mode 100644 index 0000000000..9c354e1009 --- /dev/null +++ b/monai/networks/layers/vector_quantizer.py @@ -0,0 +1,233 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Sequence, Tuple + +import torch +from torch import nn + +__all__ = ["VectorQuantizer", "EMAQuantizer"] + + +class EMAQuantizer(nn.Module): + """ + Vector Quantization module using Exponential Moving Average (EMA) to learn the codebook parameters based on Neural + Discrete Representation Learning by Oord et al. (https://arxiv.org/abs/1711.00937) and the official implementation + that can be found at https://github.com/deepmind/sonnet/blob/v2/sonnet/src/nets/vqvae.py#L148 and commit + 58d9a2746493717a7c9252938da7efa6006f3739. + + This module is not compatible with TorchScript while working in a Distributed Data Parallelism Module. This is due + to lack of TorchScript support for torch.distributed module as per https://github.com/pytorch/pytorch/issues/41353 + on 22/10/2022. If you want to TorchScript your model, please turn set `ddp_sync` to False. + + Args: + spatial_dims: number of spatial dimensions of the input. + num_embeddings: number of atomic elements in the codebook. + embedding_dim: number of channels of the input and atomic elements. + commitment_cost: scaling factor of the MSE loss between input and its quantized version. Defaults to 0.25. + decay: EMA decay. Defaults to 0.99. + epsilon: epsilon value. Defaults to 1e-5. + embedding_init: initialization method for the codebook. Defaults to "normal". + ddp_sync: whether to synchronize the codebook across processes. Defaults to True. + """ + + def __init__( + self, + spatial_dims: int, + num_embeddings: int, + embedding_dim: int, + commitment_cost: float = 0.25, + decay: float = 0.99, + epsilon: float = 1e-5, + embedding_init: str = "normal", + ddp_sync: bool = True, + ): + super().__init__() + self.spatial_dims: int = spatial_dims + self.embedding_dim: int = embedding_dim + self.num_embeddings: int = num_embeddings + + assert self.spatial_dims in [2, 3], ValueError( + f"EMAQuantizer only supports 4D and 5D tensor inputs but received spatial dims {spatial_dims}." + ) + + self.embedding: torch.nn.Embedding = torch.nn.Embedding(self.num_embeddings, self.embedding_dim) + if embedding_init == "normal": + # Initialization is passed since the default one is normal inside the nn.Embedding + pass + elif embedding_init == "kaiming_uniform": + torch.nn.init.kaiming_uniform_(self.embedding.weight.data, mode="fan_in", nonlinearity="linear") + self.embedding.weight.requires_grad = False + + self.commitment_cost: float = commitment_cost + + self.register_buffer("ema_cluster_size", torch.zeros(self.num_embeddings)) + self.register_buffer("ema_w", self.embedding.weight.data.clone()) + # declare types for mypy + self.ema_cluster_size: torch.Tensor + self.ema_w: torch.Tensor + self.decay: float = decay + self.epsilon: float = epsilon + + self.ddp_sync: bool = ddp_sync + + # Precalculating required permutation shapes + self.flatten_permutation = [0] + list(range(2, self.spatial_dims + 2)) + [1] + self.quantization_permutation: Sequence[int] = [0, self.spatial_dims + 1] + list( + range(1, self.spatial_dims + 1) + ) + + def quantize(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Given an input it projects it to the quantized space and returns additional tensors needed for EMA loss. + + Args: + inputs: Encoding space tensors of shape [B, C, H, W, D]. + + Returns: + torch.Tensor: Flatten version of the input of shape [B*H*W*D, C]. + torch.Tensor: One-hot representation of the quantization indices of shape [B*H*W*D, self.num_embeddings]. + torch.Tensor: Quantization indices of shape [B,H,W,D,1] + + """ + with torch.cuda.amp.autocast(enabled=False): + encoding_indices_view = list(inputs.shape) + del encoding_indices_view[1] + + inputs = inputs.float() + + # Converting to channel last format + flat_input = inputs.permute(self.flatten_permutation).contiguous().view(-1, self.embedding_dim) + + # Calculate Euclidean distances + distances = ( + (flat_input**2).sum(dim=1, keepdim=True) + + (self.embedding.weight.t() ** 2).sum(dim=0, keepdim=True) + - 2 * torch.mm(flat_input, self.embedding.weight.t()) + ) + + # Mapping distances to indexes + encoding_indices = torch.max(-distances, dim=1)[1] + encodings = torch.nn.functional.one_hot(encoding_indices, self.num_embeddings).float() + + # Quantize and reshape + encoding_indices = encoding_indices.view(encoding_indices_view) + + return flat_input, encodings, encoding_indices + + def embed(self, embedding_indices: torch.Tensor) -> torch.Tensor: + """ + Given encoding indices of shape [B,D,H,W,1] embeds them in the quantized space + [B, D, H, W, self.embedding_dim] and reshapes them to [B, self.embedding_dim, D, H, W] to be fed to the + decoder. + + Args: + embedding_indices: Tensor in channel last format which holds indices referencing atomic + elements from self.embedding + + Returns: + torch.Tensor: Quantize space representation of encoding_indices in channel first format. + """ + with torch.cuda.amp.autocast(enabled=False): + embedding: torch.Tensor = ( + self.embedding(embedding_indices).permute(self.quantization_permutation).contiguous() + ) + return embedding + + def distributed_synchronization(self, encodings_sum: torch.Tensor, dw: torch.Tensor) -> None: + """ + TorchScript does not support torch.distributed.all_reduce. This function is a bypassing trick based on the + example: https://pytorch.org/docs/stable/generated/torch.jit.unused.html#torch.jit.unused + + Args: + encodings_sum: The summation of one hot representation of what encoding was used for each + position. + dw: The multiplication of the one hot representation of what encoding was used for each + position with the flattened input. + + Returns: + None + """ + if self.ddp_sync and torch.distributed.is_initialized(): + torch.distributed.all_reduce(tensor=encodings_sum, op=torch.distributed.ReduceOp.SUM) + torch.distributed.all_reduce(tensor=dw, op=torch.distributed.ReduceOp.SUM) + else: + pass + + def forward(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + flat_input, encodings, encoding_indices = self.quantize(inputs) + quantized = self.embed(encoding_indices) + + # Use EMA to update the embedding vectors + if self.training: + with torch.no_grad(): + encodings_sum = encodings.sum(0) + dw = torch.mm(encodings.t(), flat_input) + + if self.ddp_sync: + self.distributed_synchronization(encodings_sum, dw) + + self.ema_cluster_size.data.mul_(self.decay).add_(torch.mul(encodings_sum, 1 - self.decay)) + + # Laplace smoothing of the cluster size + n = self.ema_cluster_size.sum() + weights = (self.ema_cluster_size + self.epsilon) / (n + self.num_embeddings * self.epsilon) * n + self.ema_w.data.mul_(self.decay).add_(torch.mul(dw, 1 - self.decay)) + self.embedding.weight.data.copy_(self.ema_w / weights.unsqueeze(1)) + + # Encoding Loss + loss = self.commitment_cost * torch.nn.functional.mse_loss(quantized.detach(), inputs) + + # Straight Through Estimator + quantized = inputs + (quantized - inputs).detach() + + return quantized, loss, encoding_indices + + +class VectorQuantizer(torch.nn.Module): + """ + Vector Quantization wrapper that is needed as a workaround for the AMP to isolate the non fp16 compatible parts of + the quantization in their own class. + + Args: + quantizer (torch.nn.Module): Quantizer module that needs to return its quantized representation, loss and index + based quantized representation. + """ + + def __init__(self, quantizer: EMAQuantizer): + super().__init__() + + self.quantizer: EMAQuantizer = quantizer + + self.perplexity: torch.Tensor = torch.rand(1) + + def forward(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + quantized, loss, encoding_indices = self.quantizer(inputs) + # Perplexity calculations + avg_probs = ( + torch.histc(encoding_indices.float(), bins=self.quantizer.num_embeddings, max=self.quantizer.num_embeddings) + .float() + .div(encoding_indices.numel()) + ) + + self.perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10))) + + return loss, quantized + + def embed(self, embedding_indices: torch.Tensor) -> torch.Tensor: + return self.quantizer.embed(embedding_indices=embedding_indices) + + def quantize(self, encodings: torch.Tensor) -> torch.Tensor: + output = self.quantizer(encodings) + encoding_indices: torch.Tensor = output[2] + return encoding_indices diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py index de5d1adc7e..c777fe6442 100644 --- a/monai/networks/nets/__init__.py +++ b/monai/networks/nets/__init__.py @@ -14,9 +14,11 @@ from .ahnet import AHnet, Ahnet, AHNet from .attentionunet import AttentionUnet from .autoencoder import AutoEncoder +from .autoencoderkl import AutoencoderKL from .basic_unet import BasicUNet, BasicUnet, Basicunet, basicunet from .basic_unetplusplus import BasicUNetPlusPlus, BasicUnetPlusPlus, BasicunetPlusPlus, basicunetplusplus from .classifier import Classifier, Critic, Discriminator +from .controlnet import ControlNet from .daf3d import DAF3D from .densenet import ( DenseNet, @@ -34,6 +36,7 @@ densenet201, densenet264, ) +from .diffusion_model_unet import DiffusionModelUNet from .dints import DiNTS, TopologyConstruction, TopologyInstance, TopologySearch from .dynunet import DynUNet, DynUnet, Dynunet from .efficientnet import ( @@ -52,6 +55,7 @@ from .hovernet import Hovernet, HoVernet, HoVerNet, HoverNet from .milmodel import MILModel from .netadapter import NetAdapter +from .patchgan_discriminator import MultiScalePatchDiscriminator, PatchDiscriminator from .quicknat import Quicknat from .regressor import Regressor from .regunet import GlobalNet, LocalNet, RegUNet @@ -104,9 +108,13 @@ seresnext50, seresnext101, ) +from .spade_autoencoderkl import SPADEAutoencoderKL +from .spade_diffusion_model_unet import SPADEDiffusionModelUNet +from .spade_network import SPADENet from .swin_unetr import PatchMerging, PatchMergingV2, SwinUNETR from .torchvision_fc import TorchVisionFCModel from .transchex import BertAttention, BertMixedLayer, BertOutput, BertPreTrainedModel, MultiModal, Pooler, Transchex +from .transformer import DecoderOnlyTransformer from .unet import UNet, Unet from .unetr import UNETR from .varautoencoder import VarAutoEncoder @@ -114,3 +122,4 @@ from .vitautoenc import ViTAutoEnc from .vnet import VNet from .voxelmorph import VoxelMorph, VoxelMorphUNet +from .vqvae import VQVAE diff --git a/monai/networks/nets/autoencoderkl.py b/monai/networks/nets/autoencoderkl.py new file mode 100644 index 0000000000..35d80e0565 --- /dev/null +++ b/monai/networks/nets/autoencoderkl.py @@ -0,0 +1,702 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from collections.abc import Sequence +from typing import List + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from monai.networks.blocks import Convolution, SpatialAttentionBlock, Upsample +from monai.utils import ensure_tuple_rep, optional_import + +Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange") + +__all__ = ["AutoencoderKL"] + + +class AsymmetricPad(nn.Module): + """ + Pad the input tensor asymmetrically along every spatial dimension. + + Args: + spatial_dims: number of spatial dimensions, could be 1, 2, or 3. + """ + + def __init__(self, spatial_dims: int) -> None: + super().__init__() + self.pad = (0, 1) * spatial_dims + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = nn.functional.pad(x, self.pad, mode="constant", value=0.0) + return x + + +class AEKLDownsample(nn.Module): + """ + Convolution-based downsampling layer. + + Args: + spatial_dims: number of spatial dimensions (1D, 2D, 3D). + in_channels: number of input channels. + """ + + def __init__(self, spatial_dims: int, in_channels: int) -> None: + super().__init__() + self.pad = AsymmetricPad(spatial_dims=spatial_dims) + + self.conv = Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=in_channels, + strides=2, + kernel_size=3, + padding=0, + conv_only=True, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.pad(x) + x = self.conv(x) + return x + + +class AEKLResBlock(nn.Module): + """ + Residual block consisting of a cascade of 2 convolutions + activation + normalisation block, and a + residual connection between input and output. + + Args: + spatial_dims: number of spatial dimensions, could be 1, 2, or 3. + in_channels: input channels to the layer. + norm_num_groups: number of groups involved for the group normalisation layer. Ensure that your number of + channels is divisible by this number. + norm_eps: epsilon for the normalisation. + out_channels: number of output channels. + """ + + def __init__( + self, spatial_dims: int, in_channels: int, norm_num_groups: int, norm_eps: float, out_channels: int + ) -> None: + super().__init__() + self.in_channels = in_channels + self.out_channels = in_channels if out_channels is None else out_channels + + self.norm1 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=norm_eps, affine=True) + self.conv1 = Convolution( + spatial_dims=spatial_dims, + in_channels=self.in_channels, + out_channels=self.out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + self.norm2 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=out_channels, eps=norm_eps, affine=True) + self.conv2 = Convolution( + spatial_dims=spatial_dims, + in_channels=self.out_channels, + out_channels=self.out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + + self.nin_shortcut: nn.Module + if self.in_channels != self.out_channels: + self.nin_shortcut = Convolution( + spatial_dims=spatial_dims, + in_channels=self.in_channels, + out_channels=self.out_channels, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + else: + self.nin_shortcut = nn.Identity() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h = x + h = self.norm1(h) + h = F.silu(h) + h = self.conv1(h) + + h = self.norm2(h) + h = F.silu(h) + h = self.conv2(h) + + x = self.nin_shortcut(x) + + return x + h + + +class Encoder(nn.Module): + """ + Convolutional cascade that downsamples the image into a spatial latent space. + + Args: + spatial_dims: number of spatial dimensions, could be 1, 2, or 3. + in_channels: number of input channels. + channels: sequence of block output channels. + out_channels: number of channels in the bottom layer (latent space) of the autoencoder. + num_res_blocks: number of residual blocks (see _ResBlock) per level. + norm_num_groups: number of groups for the GroupNorm layers, num_channels must be divisible by this number. + norm_eps: epsilon for the normalization. + attention_levels: indicate which level from num_channels contain an attention block. + with_nonlocal_attn: if True use non-local attention block. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + channels: Sequence[int], + out_channels: int, + num_res_blocks: Sequence[int], + norm_num_groups: int, + norm_eps: float, + attention_levels: Sequence[bool], + with_nonlocal_attn: bool = True, + ) -> None: + super().__init__() + self.spatial_dims = spatial_dims + self.in_channels = in_channels + self.channels = channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.norm_num_groups = norm_num_groups + self.norm_eps = norm_eps + self.attention_levels = attention_levels + + blocks: List[nn.Module] = [] + # Initial convolution + blocks.append( + Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=channels[0], + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + ) + + # Residual and downsampling blocks + output_channel = channels[0] + for i in range(len(channels)): + input_channel = output_channel + output_channel = channels[i] + is_final_block = i == len(channels) - 1 + + for _ in range(self.num_res_blocks[i]): + blocks.append( + AEKLResBlock( + spatial_dims=spatial_dims, + in_channels=input_channel, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + out_channels=output_channel, + ) + ) + input_channel = output_channel + if attention_levels[i]: + blocks.append( + SpatialAttentionBlock( + spatial_dims=spatial_dims, + num_channels=input_channel, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + ) + + if not is_final_block: + blocks.append(AEKLDownsample(spatial_dims=spatial_dims, in_channels=input_channel)) + # Non-local attention block + if with_nonlocal_attn is True: + blocks.append( + AEKLResBlock( + spatial_dims=spatial_dims, + in_channels=channels[-1], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + out_channels=channels[-1], + ) + ) + + blocks.append( + SpatialAttentionBlock( + spatial_dims=spatial_dims, + num_channels=channels[-1], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + ) + blocks.append( + AEKLResBlock( + spatial_dims=spatial_dims, + in_channels=channels[-1], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + out_channels=channels[-1], + ) + ) + # Normalise and convert to latent size + blocks.append(nn.GroupNorm(num_groups=norm_num_groups, num_channels=channels[-1], eps=norm_eps, affine=True)) + blocks.append( + Convolution( + spatial_dims=self.spatial_dims, + in_channels=channels[-1], + out_channels=out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + ) + + self.blocks = nn.ModuleList(blocks) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for block in self.blocks: + x = block(x) + return x + + +class Decoder(nn.Module): + """ + Convolutional cascade upsampling from a spatial latent space into an image space. + + Args: + spatial_dims: number of spatial dimensions, could be 1, 2, or 3. + channels: sequence of block output channels. + in_channels: number of channels in the bottom layer (latent space) of the autoencoder. + out_channels: number of output channels. + num_res_blocks: number of residual blocks (see _ResBlock) per level. + norm_num_groups: number of groups for the GroupNorm layers, num_channels must be divisible by this number. + norm_eps: epsilon for the normalization. + attention_levels: indicate which level from num_channels contain an attention block. + with_nonlocal_attn: if True use non-local attention block. + use_convtranspose: if True, use ConvTranspose to upsample feature maps in decoder. + """ + + def __init__( + self, + spatial_dims: int, + channels: Sequence[int], + in_channels: int, + out_channels: int, + num_res_blocks: Sequence[int], + norm_num_groups: int, + norm_eps: float, + attention_levels: Sequence[bool], + with_nonlocal_attn: bool = True, + use_convtranspose: bool = False, + ) -> None: + super().__init__() + self.spatial_dims = spatial_dims + self.channels = channels + self.in_channels = in_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.norm_num_groups = norm_num_groups + self.norm_eps = norm_eps + self.attention_levels = attention_levels + + reversed_block_out_channels = list(reversed(channels)) + + blocks: List[nn.Module] = [] + + # Initial convolution + blocks.append( + Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=reversed_block_out_channels[0], + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + ) + + # Non-local attention block + if with_nonlocal_attn is True: + blocks.append( + AEKLResBlock( + spatial_dims=spatial_dims, + in_channels=reversed_block_out_channels[0], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + out_channels=reversed_block_out_channels[0], + ) + ) + blocks.append( + SpatialAttentionBlock( + spatial_dims=spatial_dims, + num_channels=reversed_block_out_channels[0], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + ) + blocks.append( + AEKLResBlock( + spatial_dims=spatial_dims, + in_channels=reversed_block_out_channels[0], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + out_channels=reversed_block_out_channels[0], + ) + ) + + reversed_attention_levels = list(reversed(attention_levels)) + reversed_num_res_blocks = list(reversed(num_res_blocks)) + block_out_ch = reversed_block_out_channels[0] + for i in range(len(reversed_block_out_channels)): + block_in_ch = block_out_ch + block_out_ch = reversed_block_out_channels[i] + is_final_block = i == len(channels) - 1 + + for _ in range(reversed_num_res_blocks[i]): + blocks.append( + AEKLResBlock( + spatial_dims=spatial_dims, + in_channels=block_in_ch, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + out_channels=block_out_ch, + ) + ) + block_in_ch = block_out_ch + + if reversed_attention_levels[i]: + blocks.append( + SpatialAttentionBlock( + spatial_dims=spatial_dims, + num_channels=block_in_ch, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + ) + + if not is_final_block: + if use_convtranspose: + blocks.append( + Upsample( + spatial_dims=spatial_dims, mode="deconv", in_channels=block_in_ch, out_channels=block_in_ch + ) + ) + else: + post_conv = Convolution( + spatial_dims=spatial_dims, + in_channels=block_in_ch, + out_channels=block_in_ch, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + blocks.append( + Upsample( + spatial_dims=spatial_dims, + mode="nontrainable", + in_channels=block_in_ch, + out_channels=block_in_ch, + interp_mode="nearest", + scale_factor=2.0, + post_conv=post_conv, + align_corners=None, + ) + ) + + blocks.append(nn.GroupNorm(num_groups=norm_num_groups, num_channels=block_in_ch, eps=norm_eps, affine=True)) + blocks.append( + Convolution( + spatial_dims=spatial_dims, + in_channels=block_in_ch, + out_channels=out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + ) + + self.blocks = nn.ModuleList(blocks) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for block in self.blocks: + x = block(x) + return x + + +class AutoencoderKL(nn.Module): + """ + Autoencoder model with KL-regularized latent space based on + Rombach et al. "High-Resolution Image Synthesis with Latent Diffusion Models" https://arxiv.org/abs/2112.10752 + and Pinaya et al. "Brain Imaging Generation with Latent Diffusion Models" https://arxiv.org/abs/2209.07162 + + Args: + spatial_dims: number of spatial dimensions, could be 1, 2, or 3. + in_channels: number of input channels. + out_channels: number of output channels. + num_res_blocks: number of residual blocks (see _ResBlock) per level. + channels: number of output channels for each block. + attention_levels: sequence of levels to add attention. + latent_channels: latent embedding dimension. + norm_num_groups: number of groups for the GroupNorm layers, num_channels must be divisible by this number. + norm_eps: epsilon for the normalization. + with_encoder_nonlocal_attn: if True use non-local attention block in the encoder. + with_decoder_nonlocal_attn: if True use non-local attention block in the decoder. + use_checkpoint: if True, use activation checkpoint to save memory. + use_convtranspose: if True, use ConvTranspose to upsample feature maps in decoder. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int = 1, + out_channels: int = 1, + num_res_blocks: Sequence[int] | int = (2, 2, 2, 2), + channels: Sequence[int] = (32, 64, 64, 64), + attention_levels: Sequence[bool] = (False, False, True, True), + latent_channels: int = 3, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + with_encoder_nonlocal_attn: bool = True, + with_decoder_nonlocal_attn: bool = True, + use_checkpoint: bool = False, + use_convtranspose: bool = False, + ) -> None: + super().__init__() + + # All number of channels should be multiple of num_groups + if any((out_channel % norm_num_groups) != 0 for out_channel in channels): + raise ValueError("AutoencoderKL expects all num_channels being multiple of norm_num_groups") + + if len(channels) != len(attention_levels): + raise ValueError("AutoencoderKL expects num_channels being same size of attention_levels") + + if isinstance(num_res_blocks, int): + num_res_blocks = ensure_tuple_rep(num_res_blocks, len(channels)) + + if len(num_res_blocks) != len(channels): + raise ValueError( + "`num_res_blocks` should be a single integer or a tuple of integers with the same length as " + "`num_channels`." + ) + + self.encoder = Encoder( + spatial_dims=spatial_dims, + in_channels=in_channels, + channels=channels, + out_channels=latent_channels, + num_res_blocks=num_res_blocks, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + attention_levels=attention_levels, + with_nonlocal_attn=with_encoder_nonlocal_attn, + ) + self.decoder = Decoder( + spatial_dims=spatial_dims, + channels=channels, + in_channels=latent_channels, + out_channels=out_channels, + num_res_blocks=num_res_blocks, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + attention_levels=attention_levels, + with_nonlocal_attn=with_decoder_nonlocal_attn, + use_convtranspose=use_convtranspose, + ) + self.quant_conv_mu = Convolution( + spatial_dims=spatial_dims, + in_channels=latent_channels, + out_channels=latent_channels, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + self.quant_conv_log_sigma = Convolution( + spatial_dims=spatial_dims, + in_channels=latent_channels, + out_channels=latent_channels, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + self.post_quant_conv = Convolution( + spatial_dims=spatial_dims, + in_channels=latent_channels, + out_channels=latent_channels, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + self.latent_channels = latent_channels + self.use_checkpoint = use_checkpoint + + def encode(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """ + Forwards an image through the spatial encoder, obtaining the latent mean and sigma representations. + + Args: + x: BxCx[SPATIAL DIMS] tensor + + """ + if self.use_checkpoint: + h = torch.utils.checkpoint.checkpoint(self.encoder, x, use_reentrant=False) + else: + h = self.encoder(x) + + z_mu = self.quant_conv_mu(h) + z_log_var = self.quant_conv_log_sigma(h) + z_log_var = torch.clamp(z_log_var, -30.0, 20.0) + z_sigma = torch.exp(z_log_var / 2) + + return z_mu, z_sigma + + def sampling(self, z_mu: torch.Tensor, z_sigma: torch.Tensor) -> torch.Tensor: + """ + From the mean and sigma representations resulting of encoding an image through the latent space, + obtains a noise sample resulting from sampling gaussian noise, multiplying by the variance (sigma) and + adding the mean. + + Args: + z_mu: Bx[Z_CHANNELS]x[LATENT SPACE SIZE] mean vector obtained by the encoder when you encode an image + z_sigma: Bx[Z_CHANNELS]x[LATENT SPACE SIZE] variance vector obtained by the encoder when you encode an image + + Returns: + sample of shape Bx[Z_CHANNELS]x[LATENT SPACE SIZE] + """ + eps = torch.randn_like(z_sigma) + z_vae = z_mu + eps * z_sigma + return z_vae + + def reconstruct(self, x: torch.Tensor) -> torch.Tensor: + """ + Encodes and decodes an input image. + + Args: + x: BxCx[SPATIAL DIMENSIONS] tensor. + + Returns: + reconstructed image, of the same shape as input + """ + z_mu, _ = self.encode(x) + reconstruction = self.decode(z_mu) + return reconstruction + + def decode(self, z: torch.Tensor) -> torch.Tensor: + """ + Based on a latent space sample, forwards it through the Decoder. + + Args: + z: Bx[Z_CHANNELS]x[LATENT SPACE SHAPE] + + Returns: + decoded image tensor + """ + z = self.post_quant_conv(z) + dec: torch.Tensor + if self.use_checkpoint: + dec = torch.utils.checkpoint.checkpoint(self.decoder, z, use_reentrant=False) + else: + dec = self.decoder(z) + return dec + + def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + z_mu, z_sigma = self.encode(x) + z = self.sampling(z_mu, z_sigma) + reconstruction = self.decode(z) + return reconstruction, z_mu, z_sigma + + def encode_stage_2_inputs(self, x: torch.Tensor) -> torch.Tensor: + z_mu, z_sigma = self.encode(x) + z = self.sampling(z_mu, z_sigma) + return z + + def decode_stage_2_outputs(self, z: torch.Tensor) -> torch.Tensor: + image = self.decode(z) + return image + + def load_old_state_dict(self, old_state_dict: dict, verbose=False) -> None: + """ + Load a state dict from an AutoencoderKL trained with [MONAI Generative](https://github.com/Project-MONAI/GenerativeModels). + + Args: + old_state_dict: state dict from the old AutoencoderKL model. + """ + + new_state_dict = self.state_dict() + # if all keys match, just load the state dict + if all(k in new_state_dict for k in old_state_dict): + print("All keys match, loading state dict.") + self.load_state_dict(old_state_dict) + return + + if verbose: + # print all new_state_dict keys that are not in old_state_dict + for k in new_state_dict: + if k not in old_state_dict: + print(f"key {k} not found in old state dict") + # and vice versa + print("----------------------------------------------") + for k in old_state_dict: + if k not in new_state_dict: + print(f"key {k} not found in new state dict") + + # copy over all matching keys + for k in new_state_dict: + if k in old_state_dict: + new_state_dict[k] = old_state_dict[k] + + # fix the attention blocks + attention_blocks = [k.replace(".attn.qkv.weight", "") for k in new_state_dict if "attn.qkv.weight" in k] + for block in attention_blocks: + new_state_dict[f"{block}.attn.qkv.weight"] = torch.cat( + [ + old_state_dict[f"{block}.to_q.weight"], + old_state_dict[f"{block}.to_k.weight"], + old_state_dict[f"{block}.to_v.weight"], + ], + dim=0, + ) + new_state_dict[f"{block}.attn.qkv.bias"] = torch.cat( + [ + old_state_dict[f"{block}.to_q.bias"], + old_state_dict[f"{block}.to_k.bias"], + old_state_dict[f"{block}.to_v.bias"], + ], + dim=0, + ) + # old version did not have a projection so set these to the identity + new_state_dict[f"{block}.attn.out_proj.weight"] = torch.eye( + new_state_dict[f"{block}.attn.out_proj.weight"].shape[0] + ) + new_state_dict[f"{block}.attn.out_proj.bias"] = torch.zeros( + new_state_dict[f"{block}.attn.out_proj.bias"].shape + ) + + # fix the upsample conv blocks which were renamed postconv + for k in new_state_dict: + if "postconv" in k: + old_name = k.replace("postconv", "conv") + new_state_dict[k] = old_state_dict[old_name] + self.load_state_dict(new_state_dict) diff --git a/monai/networks/nets/controlnet.py b/monai/networks/nets/controlnet.py new file mode 100644 index 0000000000..ed3654733d --- /dev/null +++ b/monai/networks/nets/controlnet.py @@ -0,0 +1,465 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ========================================================================= +# Adapted from https://github.com/huggingface/diffusers +# which has the following license: +# https://github.com/huggingface/diffusers/blob/main/LICENSE +# +# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========================================================================= + +from __future__ import annotations + +from collections.abc import Sequence + +import torch +from torch import nn + +from monai.networks.blocks import Convolution +from monai.networks.nets.diffusion_model_unet import get_down_block, get_mid_block, get_timestep_embedding +from monai.utils import ensure_tuple_rep + + +class ControlNetConditioningEmbedding(nn.Module): + """ + Network to encode the conditioning into a latent space. + """ + + def __init__(self, spatial_dims: int, in_channels: int, out_channels: int, channels: Sequence[int]): + super().__init__() + + self.conv_in = Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=channels[0], + strides=1, + kernel_size=3, + padding=1, + adn_ordering="A", + act="SWISH", + ) + + self.blocks = nn.ModuleList([]) + + for i in range(len(channels) - 1): + channel_in = channels[i] + channel_out = channels[i + 1] + self.blocks.append( + Convolution( + spatial_dims=spatial_dims, + in_channels=channel_in, + out_channels=channel_in, + strides=1, + kernel_size=3, + padding=1, + adn_ordering="A", + act="SWISH", + ) + ) + + self.blocks.append( + Convolution( + spatial_dims=spatial_dims, + in_channels=channel_in, + out_channels=channel_out, + strides=2, + kernel_size=3, + padding=1, + adn_ordering="A", + act="SWISH", + ) + ) + + self.conv_out = zero_module( + Convolution( + spatial_dims=spatial_dims, + in_channels=channels[-1], + out_channels=out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + ) + + def forward(self, conditioning): + embedding = self.conv_in(conditioning) + + for block in self.blocks: + embedding = block(embedding) + + embedding = self.conv_out(embedding) + + return embedding + + +def zero_module(module): + for p in module.parameters(): + nn.init.zeros_(p) + return module + + +class ControlNet(nn.Module): + """ + Control network for diffusion models based on Zhang and Agrawala "Adding Conditional Control to Text-to-Image + Diffusion Models" (https://arxiv.org/abs/2302.05543) + + Args: + spatial_dims: number of spatial dimensions. + in_channels: number of input channels. + num_res_blocks: number of residual blocks (see ResnetBlock) per level. + channels: tuple of block output channels. + attention_levels: list of levels to add attention. + norm_num_groups: number of groups for the normalization. + norm_eps: epsilon for the normalization. + resblock_updown: if True use residual blocks for up/downsampling. + num_head_channels: number of channels in each attention head. + with_conditioning: if True add spatial transformers to perform conditioning. + transformer_num_layers: number of layers of Transformer blocks to use. + cross_attention_dim: number of context dimensions to use. + num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds` + classes. + upcast_attention: if True, upcast attention operations to full precision. + conditioning_embedding_in_channels: number of input channels for the conditioning embedding. + conditioning_embedding_num_channels: number of channels for the blocks in the conditioning embedding. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + num_res_blocks: Sequence[int] | int = (2, 2, 2, 2), + channels: Sequence[int] = (32, 64, 64, 64), + attention_levels: Sequence[bool] = (False, False, True, True), + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + resblock_updown: bool = False, + num_head_channels: int | Sequence[int] = 8, + with_conditioning: bool = False, + transformer_num_layers: int = 1, + cross_attention_dim: int | None = None, + num_class_embeds: int | None = None, + upcast_attention: bool = False, + conditioning_embedding_in_channels: int = 1, + conditioning_embedding_num_channels: Sequence[int] = (16, 32, 96, 256), + ) -> None: + super().__init__() + if with_conditioning is True and cross_attention_dim is None: + raise ValueError( + "DiffusionModelUNet expects dimension of the cross-attention conditioning (cross_attention_dim) " + "to be specified when with_conditioning=True." + ) + if cross_attention_dim is not None and with_conditioning is False: + raise ValueError( + "DiffusionModelUNet expects with_conditioning=True when specifying the cross_attention_dim." + ) + + # All number of channels should be multiple of num_groups + if any((out_channel % norm_num_groups) != 0 for out_channel in channels): + raise ValueError( + f"DiffusionModelUNet expects all channels to be a multiple of norm_num_groups, but got" + f" channels={channels} and norm_num_groups={norm_num_groups}" + ) + + if len(channels) != len(attention_levels): + raise ValueError( + f"DiffusionModelUNet expects channels to have the same length as attention_levels, but got " + f"channels={channels} and attention_levels={attention_levels}" + ) + + if isinstance(num_head_channels, int): + num_head_channels = ensure_tuple_rep(num_head_channels, len(attention_levels)) + + if len(num_head_channels) != len(attention_levels): + raise ValueError( + f"num_head_channels should have the same length as attention_levels, but got channels={channels} and " + f"attention_levels={attention_levels} . For the i levels without attention," + " i.e. `attention_level[i]=False`, the num_head_channels[i] will be ignored." + ) + + if isinstance(num_res_blocks, int): + num_res_blocks = ensure_tuple_rep(num_res_blocks, len(channels)) + + if len(num_res_blocks) != len(channels): + raise ValueError( + f"`num_res_blocks` should be a single integer or a tuple of integers with the same length as " + f"`num_channels`, but got num_res_blocks={num_res_blocks} and channels={channels}." + ) + + self.in_channels = in_channels + self.block_out_channels = channels + self.num_res_blocks = num_res_blocks + self.attention_levels = attention_levels + self.num_head_channels = num_head_channels + self.with_conditioning = with_conditioning + + # input + self.conv_in = Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=channels[0], + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + + # time + time_embed_dim = channels[0] * 4 + self.time_embed = nn.Sequential( + nn.Linear(channels[0], time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, time_embed_dim) + ) + + # class embedding + self.num_class_embeds = num_class_embeds + if num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + + # control net conditioning embedding + self.controlnet_cond_embedding = ControlNetConditioningEmbedding( + spatial_dims=spatial_dims, + in_channels=conditioning_embedding_in_channels, + channels=conditioning_embedding_num_channels, + out_channels=channels[0], + ) + + # down + self.down_blocks = nn.ModuleList([]) + self.controlnet_down_blocks = nn.ModuleList([]) + output_channel = channels[0] + + controlnet_block = Convolution( + spatial_dims=spatial_dims, + in_channels=output_channel, + out_channels=output_channel, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + controlnet_block = zero_module(controlnet_block.conv) + self.controlnet_down_blocks.append(controlnet_block) + + for i in range(len(channels)): + input_channel = output_channel + output_channel = channels[i] + is_final_block = i == len(channels) - 1 + + down_block = get_down_block( + spatial_dims=spatial_dims, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + num_res_blocks=num_res_blocks[i], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_downsample=not is_final_block, + resblock_updown=resblock_updown, + with_attn=(attention_levels[i] and not with_conditioning), + with_cross_attn=(attention_levels[i] and with_conditioning), + num_head_channels=num_head_channels[i], + transformer_num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + ) + + self.down_blocks.append(down_block) + + for _ in range(num_res_blocks[i]): + controlnet_block = Convolution( + spatial_dims=spatial_dims, + in_channels=output_channel, + out_channels=output_channel, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + controlnet_block = zero_module(controlnet_block) + self.controlnet_down_blocks.append(controlnet_block) + # + if not is_final_block: + controlnet_block = Convolution( + spatial_dims=spatial_dims, + in_channels=output_channel, + out_channels=output_channel, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + controlnet_block = zero_module(controlnet_block) + self.controlnet_down_blocks.append(controlnet_block) + + # mid + mid_block_channel = channels[-1] + + self.middle_block = get_mid_block( + spatial_dims=spatial_dims, + in_channels=mid_block_channel, + temb_channels=time_embed_dim, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + with_conditioning=with_conditioning, + num_head_channels=num_head_channels[-1], + transformer_num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + ) + + controlnet_block = Convolution( + spatial_dims=spatial_dims, + in_channels=output_channel, + out_channels=output_channel, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + controlnet_block = zero_module(controlnet_block) + self.controlnet_mid_block = controlnet_block + + def forward( + self, + x: torch.Tensor, + timesteps: torch.Tensor, + controlnet_cond: torch.Tensor, + conditioning_scale: float = 1.0, + context: torch.Tensor | None = None, + class_labels: torch.Tensor | None = None, + ) -> tuple[list[torch.Tensor], torch.Tensor]: + """ + Args: + x: input tensor (N, C, H, W, [D]). + timesteps: timestep tensor (N,). + controlnet_cond: controlnet conditioning tensor (N, C, H, W, [D]) + conditioning_scale: conditioning scale. + context: context tensor (N, 1, cross_attention_dim), where cross_attention_dim is specified in the model init. + class_labels: context tensor (N, ). + """ + # 1. time + t_emb = get_timestep_embedding(timesteps, self.block_out_channels[0]) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=x.dtype) + emb = self.time_embed(t_emb) + + # 2. class + if self.num_class_embeds is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + class_emb = self.class_embedding(class_labels) + class_emb = class_emb.to(dtype=x.dtype) + emb = emb + class_emb + + # 3. initial convolution + h = self.conv_in(x) + + controlnet_cond = self.controlnet_cond_embedding(controlnet_cond) + + h += controlnet_cond + + # 4. down + if context is not None and self.with_conditioning is False: + raise ValueError("model should have with_conditioning = True if context is provided") + down_block_res_samples: list[torch.Tensor] = [h] + for downsample_block in self.down_blocks: + h, res_samples = downsample_block(hidden_states=h, temb=emb, context=context) + for residual in res_samples: + down_block_res_samples.append(residual) + + # 5. mid + h = self.middle_block(hidden_states=h, temb=emb, context=context) + + # 6. Control net blocks + controlnet_down_block_res_samples = [] + + for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks): + down_block_res_sample = controlnet_block(down_block_res_sample) + controlnet_down_block_res_samples.append(down_block_res_sample) + + down_block_res_samples = controlnet_down_block_res_samples + + mid_block_res_sample: torch.Tensor = self.controlnet_mid_block(h) + + # 6. scaling + down_block_res_samples = [h * conditioning_scale for h in down_block_res_samples] + mid_block_res_sample *= conditioning_scale + + return down_block_res_samples, mid_block_res_sample + + def load_old_state_dict(self, old_state_dict: dict, verbose=False) -> None: + """ + Load a state dict from a ControlNet trained with + [MONAI Generative](https://github.com/Project-MONAI/GenerativeModels). + + Args: + old_state_dict: state dict from the old ControlNet model. + """ + + new_state_dict = self.state_dict() + # if all keys match, just load the state dict + if all(k in new_state_dict for k in old_state_dict): + print("All keys match, loading state dict.") + self.load_state_dict(old_state_dict) + return + + if verbose: + # print all new_state_dict keys that are not in old_state_dict + for k in new_state_dict: + if k not in old_state_dict: + print(f"key {k} not found in old state dict") + # and vice versa + print("----------------------------------------------") + for k in old_state_dict: + if k not in new_state_dict: + print(f"key {k} not found in new state dict") + + # copy over all matching keys + for k in new_state_dict: + if k in old_state_dict: + new_state_dict[k] = old_state_dict[k] + + # fix the attention blocks + attention_blocks = [k.replace(".attn1.qkv.weight", "") for k in new_state_dict if "attn1.qkv.weight" in k] + for block in attention_blocks: + new_state_dict[f"{block}.attn1.qkv.weight"] = torch.cat( + [ + old_state_dict[f"{block}.attn1.to_q.weight"], + old_state_dict[f"{block}.attn1.to_k.weight"], + old_state_dict[f"{block}.attn1.to_v.weight"], + ], + dim=0, + ) + + # projection + new_state_dict[f"{block}.attn1.out_proj.weight"] = old_state_dict[f"{block}.attn1.to_out.0.weight"] + new_state_dict[f"{block}.attn1.out_proj.bias"] = old_state_dict[f"{block}.attn1.to_out.0.bias"] + + new_state_dict[f"{block}.attn2.out_proj.weight"] = old_state_dict[f"{block}.attn2.to_out.0.weight"] + new_state_dict[f"{block}.attn2.out_proj.bias"] = old_state_dict[f"{block}.attn2.to_out.0.bias"] + + self.load_state_dict(new_state_dict) diff --git a/monai/networks/nets/diffusion_model_unet.py b/monai/networks/nets/diffusion_model_unet.py new file mode 100644 index 0000000000..8a9ac859a3 --- /dev/null +++ b/monai/networks/nets/diffusion_model_unet.py @@ -0,0 +1,1913 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ========================================================================= +# Adapted from https://github.com/huggingface/diffusers +# which has the following license: +# https://github.com/huggingface/diffusers/blob/main/LICENSE +# +# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========================================================================= + +from __future__ import annotations + +import math +from collections.abc import Sequence + +import torch +from torch import nn + +from monai.networks.blocks import Convolution, CrossAttentionBlock, MLPBlock, SABlock, SpatialAttentionBlock, Upsample +from monai.networks.layers.factories import Pool +from monai.utils import ensure_tuple_rep, optional_import + +Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange") + +__all__ = ["DiffusionModelUNet"] + + +def zero_module(module: nn.Module) -> nn.Module: + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +class DiffusionUNetTransformerBlock(nn.Module): + """ + A Transformer block that allows for the input dimension to differ from the hidden dimension. + + Args: + num_channels: number of channels in the input and output. + num_attention_heads: number of heads to use for multi-head attention. + num_head_channels: number of channels in each attention head. + dropout: dropout probability to use. + cross_attention_dim: size of the context vector for cross attention. + upcast_attention: if True, upcast attention operations to full precision. + + """ + + def __init__( + self, + num_channels: int, + num_attention_heads: int, + num_head_channels: int, + dropout: float = 0.0, + cross_attention_dim: int | None = None, + upcast_attention: bool = False, + ) -> None: + super().__init__() + self.attn1 = SABlock( + hidden_size=num_attention_heads * num_head_channels, + hidden_input_size=num_channels, + num_heads=num_attention_heads, + dim_head=num_head_channels, + dropout_rate=dropout, + attention_dtype=torch.float if upcast_attention else None, + ) + self.ff = MLPBlock(hidden_size=num_channels, mlp_dim=num_channels * 4, act="GEGLU", dropout_rate=dropout) + self.attn2 = CrossAttentionBlock( + hidden_size=num_attention_heads * num_head_channels, + num_heads=num_attention_heads, + hidden_input_size=num_channels, + context_input_size=cross_attention_dim, + dim_head=num_head_channels, + dropout_rate=dropout, + attention_dtype=torch.float if upcast_attention else None, + ) + self.norm1 = nn.LayerNorm(num_channels) + self.norm2 = nn.LayerNorm(num_channels) + self.norm3 = nn.LayerNorm(num_channels) + + def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor: + # 1. Self-Attention + x = self.attn1(self.norm1(x)) + x + + # 2. Cross-Attention + x = self.attn2(self.norm2(x), context=context) + x + + # 3. Feed-forward + x = self.ff(self.norm3(x)) + x + return x + + +class SpatialTransformer(nn.Module): + """ + Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply + standard transformer action. Finally, reshape to image. + + Args: + spatial_dims: number of spatial dimensions. + in_channels: number of channels in the input and output. + num_attention_heads: number of heads to use for multi-head attention. + num_head_channels: number of channels in each attention head. + num_layers: number of layers of Transformer blocks to use. + dropout: dropout probability to use. + norm_num_groups: number of groups for the normalization. + norm_eps: epsilon for the normalization. + cross_attention_dim: number of context dimensions to use. + upcast_attention: if True, upcast attention operations to full precision. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + num_attention_heads: int, + num_head_channels: int, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + cross_attention_dim: int | None = None, + upcast_attention: bool = False, + ) -> None: + super().__init__() + self.spatial_dims = spatial_dims + self.in_channels = in_channels + inner_dim = num_attention_heads * num_head_channels + + self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=norm_eps, affine=True) + + self.proj_in = Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=inner_dim, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + + self.transformer_blocks = nn.ModuleList( + [ + DiffusionUNetTransformerBlock( + num_channels=inner_dim, + num_attention_heads=num_attention_heads, + num_head_channels=num_head_channels, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + ) + for _ in range(num_layers) + ] + ) + + self.proj_out = zero_module( + Convolution( + spatial_dims=spatial_dims, + in_channels=inner_dim, + out_channels=in_channels, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + ) + + def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor: + # note: if no context is given, cross-attention defaults to self-attention + batch = channel = height = width = depth = -1 + if self.spatial_dims == 2: + batch, channel, height, width = x.shape + if self.spatial_dims == 3: + batch, channel, height, width, depth = x.shape + + residual = x + x = self.norm(x) + x = self.proj_in(x) + + inner_dim = x.shape[1] + + if self.spatial_dims == 2: + x = x.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) + if self.spatial_dims == 3: + x = x.permute(0, 2, 3, 4, 1).reshape(batch, height * width * depth, inner_dim) + + for block in self.transformer_blocks: + x = block(x, context=context) + + if self.spatial_dims == 2: + x = x.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() + if self.spatial_dims == 3: + x = x.reshape(batch, height, width, depth, inner_dim).permute(0, 4, 1, 2, 3).contiguous() + + x = self.proj_out(x) + return x + residual + + +def get_timestep_embedding(timesteps: torch.Tensor, embedding_dim: int, max_period: int = 10000) -> torch.Tensor: + """ + Create sinusoidal timestep embeddings following the implementation in Ho et al. "Denoising Diffusion Probabilistic + Models" https://arxiv.org/abs/2006.11239. + + Args: + timesteps: a 1-D Tensor of N indices, one per batch element. + embedding_dim: the dimension of the output. + max_period: controls the minimum frequency of the embeddings. + """ + if timesteps.ndim != 1: + raise ValueError("Timesteps should be a 1d-array") + + half_dim = embedding_dim // 2 + exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device) + freqs = torch.exp(exponent / half_dim) + + args = timesteps[:, None].float() * freqs[None, :] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + + # zero pad + if embedding_dim % 2 == 1: + embedding = torch.nn.functional.pad(embedding, (0, 1, 0, 0)) + + return embedding + + +class DiffusionUnetDownsample(nn.Module): + """ + Downsampling layer. + + Args: + spatial_dims: number of spatial dimensions. + num_channels: number of input channels. + use_conv: if True uses Convolution instead of Pool average to perform downsampling. In case that use_conv is + False, the number of output channels must be the same as the number of input channels. + out_channels: number of output channels. + padding: controls the amount of implicit zero-paddings on both sides for padding number of points + for each dimension. + """ + + def __init__( + self, spatial_dims: int, num_channels: int, use_conv: bool, out_channels: int | None = None, padding: int = 1 + ) -> None: + super().__init__() + self.num_channels = num_channels + self.out_channels = out_channels or num_channels + self.use_conv = use_conv + if use_conv: + self.op = Convolution( + spatial_dims=spatial_dims, + in_channels=self.num_channels, + out_channels=self.out_channels, + strides=2, + kernel_size=3, + padding=padding, + conv_only=True, + ) + else: + if self.num_channels != self.out_channels: + raise ValueError("num_channels and out_channels must be equal when use_conv=False") + self.op = Pool[Pool.AVG, spatial_dims](kernel_size=2, stride=2) + + def forward(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Tensor: + del emb + if x.shape[1] != self.num_channels: + raise ValueError( + f"Input number of channels ({x.shape[1]}) is not equal to expected number of channels " + f"({self.num_channels})" + ) + output: torch.Tensor = self.op(x) + return output + + +class WrappedUpsample(Upsample): + """ + Wraps MONAI upsample block to allow for calling with timestep embeddings. + """ + + def forward(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Tensor: + del emb + upsampled: torch.Tensor = super().forward(x) + return upsampled + + +class DiffusionUNetResnetBlock(nn.Module): + """ + Residual block with timestep conditioning. + + Args: + spatial_dims: The number of spatial dimensions. + in_channels: number of input channels. + temb_channels: number of timestep embedding channels. + out_channels: number of output channels. + up: if True, performs upsampling. + down: if True, performs downsampling. + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + temb_channels: int, + out_channels: int | None = None, + up: bool = False, + down: bool = False, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + ) -> None: + super().__init__() + self.spatial_dims = spatial_dims + self.channels = in_channels + self.emb_channels = temb_channels + self.out_channels = out_channels or in_channels + self.up = up + self.down = down + + self.norm1 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=norm_eps, affine=True) + self.nonlinearity = nn.SiLU() + self.conv1 = Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=self.out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + + self.upsample = self.downsample = None + if self.up: + self.upsample = WrappedUpsample( + spatial_dims=spatial_dims, + mode="nontrainable", + in_channels=in_channels, + out_channels=in_channels, + interp_mode="nearest", + scale_factor=2.0, + align_corners=None, + ) + elif down: + self.downsample = DiffusionUnetDownsample(spatial_dims, in_channels, use_conv=False) + + self.time_emb_proj = nn.Linear(temb_channels, self.out_channels) + + self.norm2 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=self.out_channels, eps=norm_eps, affine=True) + self.conv2 = zero_module( + Convolution( + spatial_dims=spatial_dims, + in_channels=self.out_channels, + out_channels=self.out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + ) + self.skip_connection: nn.Module + if self.out_channels == in_channels: + self.skip_connection = nn.Identity() + else: + self.skip_connection = Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=self.out_channels, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + + def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor: + h = x + h = self.norm1(h) + h = self.nonlinearity(h) + + if self.upsample is not None: + x = self.upsample(x) + h = self.upsample(h) + elif self.downsample is not None: + x = self.downsample(x) + h = self.downsample(h) + + h = self.conv1(h) + + if self.spatial_dims == 2: + temb = self.time_emb_proj(self.nonlinearity(emb))[:, :, None, None] + else: + temb = self.time_emb_proj(self.nonlinearity(emb))[:, :, None, None, None] + h = h + temb + + h = self.norm2(h) + h = self.nonlinearity(h) + h = self.conv2(h) + output: torch.Tensor = self.skip_connection(x) + h + return output + + +class DownBlock(nn.Module): + """ + Unet's down block containing resnet and downsamplers blocks. + + Args: + spatial_dims: The number of spatial dimensions. + in_channels: number of input channels. + out_channels: number of output channels. + temb_channels: number of timestep embedding channels. + num_res_blocks: number of residual blocks. + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + add_downsample: if True add downsample block. + resblock_updown: if True use residual blocks for downsampling. + downsample_padding: padding used in the downsampling block. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + temb_channels: int, + num_res_blocks: int = 1, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + add_downsample: bool = True, + resblock_updown: bool = False, + downsample_padding: int = 1, + ) -> None: + super().__init__() + self.resblock_updown = resblock_updown + + resnets = [] + + for i in range(num_res_blocks): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + DiffusionUNetResnetBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsampler: nn.Module | None + if resblock_updown: + self.downsampler = DiffusionUNetResnetBlock( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + down=True, + ) + else: + self.downsampler = DiffusionUnetDownsample( + spatial_dims=spatial_dims, + num_channels=out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + ) + else: + self.downsampler = None + + def forward( + self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None + ) -> tuple[torch.Tensor, list[torch.Tensor]]: + del context + output_states = [] + + for resnet in self.resnets: + hidden_states = resnet(hidden_states, temb) + output_states.append(hidden_states) + + if self.downsampler is not None: + hidden_states = self.downsampler(hidden_states, temb) + output_states.append(hidden_states) + + return hidden_states, output_states + + +class AttnDownBlock(nn.Module): + """ + Unet's down block containing resnet, downsamplers and self-attention blocks. + + Args: + spatial_dims: The number of spatial dimensions. + in_channels: number of input channels. + out_channels: number of output channels. + temb_channels: number of timestep embedding channels. + num_res_blocks: number of residual blocks. + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + add_downsample: if True add downsample block. + resblock_updown: if True use residual blocks for downsampling. + downsample_padding: padding used in the downsampling block. + num_head_channels: number of channels in each attention head. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + temb_channels: int, + num_res_blocks: int = 1, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + add_downsample: bool = True, + resblock_updown: bool = False, + downsample_padding: int = 1, + num_head_channels: int = 1, + ) -> None: + super().__init__() + self.resblock_updown = resblock_updown + + resnets = [] + attentions = [] + + for i in range(num_res_blocks): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + DiffusionUNetResnetBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + ) + attentions.append( + SpatialAttentionBlock( + spatial_dims=spatial_dims, + num_channels=out_channels, + num_head_channels=num_head_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + self.downsampler: nn.Module | None + if add_downsample: + if resblock_updown: + self.downsampler = DiffusionUNetResnetBlock( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + down=True, + ) + else: + self.downsampler = DiffusionUnetDownsample( + spatial_dims=spatial_dims, + num_channels=out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + ) + else: + self.downsampler = None + + def forward( + self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None + ) -> tuple[torch.Tensor, list[torch.Tensor]]: + del context + output_states = [] + + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states).contiguous() + output_states.append(hidden_states) + + if self.downsampler is not None: + hidden_states = self.downsampler(hidden_states, temb) + output_states.append(hidden_states) + + return hidden_states, output_states + + +class CrossAttnDownBlock(nn.Module): + """ + Unet's down block containing resnet, downsamplers and cross-attention blocks. + + Args: + spatial_dims: number of spatial dimensions. + in_channels: number of input channels. + out_channels: number of output channels. + temb_channels: number of timestep embedding channels. + num_res_blocks: number of residual blocks. + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + add_downsample: if True add downsample block. + resblock_updown: if True use residual blocks for downsampling. + downsample_padding: padding used in the downsampling block. + num_head_channels: number of channels in each attention head. + transformer_num_layers: number of layers of Transformer blocks to use. + cross_attention_dim: number of context dimensions to use. + upcast_attention: if True, upcast attention operations to full precision. + dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + temb_channels: int, + num_res_blocks: int = 1, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + add_downsample: bool = True, + resblock_updown: bool = False, + downsample_padding: int = 1, + num_head_channels: int = 1, + transformer_num_layers: int = 1, + cross_attention_dim: int | None = None, + upcast_attention: bool = False, + dropout_cattn: float = 0.0, + ) -> None: + super().__init__() + self.resblock_updown = resblock_updown + + resnets = [] + attentions = [] + + for i in range(num_res_blocks): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + DiffusionUNetResnetBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + ) + + attentions.append( + SpatialTransformer( + spatial_dims=spatial_dims, + in_channels=out_channels, + num_attention_heads=out_channels // num_head_channels, + num_head_channels=num_head_channels, + num_layers=transformer_num_layers, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + dropout=dropout_cattn, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + self.downsampler: nn.Module | None + if add_downsample: + if resblock_updown: + self.downsampler = DiffusionUNetResnetBlock( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + down=True, + ) + else: + self.downsampler = DiffusionUnetDownsample( + spatial_dims=spatial_dims, + num_channels=out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + ) + else: + self.downsampler = None + + def forward( + self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None + ) -> tuple[torch.Tensor, list[torch.Tensor]]: + output_states = [] + + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, context=context).contiguous() + output_states.append(hidden_states) + + if self.downsampler is not None: + hidden_states = self.downsampler(hidden_states, temb) + output_states.append(hidden_states) + + return hidden_states, output_states + + +class AttnMidBlock(nn.Module): + """ + Unet's mid block containing resnet and self-attention blocks. + + Args: + spatial_dims: The number of spatial dimensions. + in_channels: number of input channels. + temb_channels: number of timestep embedding channels. + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + num_head_channels: number of channels in each attention head. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + temb_channels: int, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + num_head_channels: int = 1, + ) -> None: + super().__init__() + + self.resnet_1 = DiffusionUNetResnetBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + self.attention = SpatialAttentionBlock( + spatial_dims=spatial_dims, + num_channels=in_channels, + num_head_channels=num_head_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + + self.resnet_2 = DiffusionUNetResnetBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + + def forward( + self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None + ) -> torch.Tensor: + del context + hidden_states = self.resnet_1(hidden_states, temb) + hidden_states = self.attention(hidden_states).contiguous() + hidden_states = self.resnet_2(hidden_states, temb) + + return hidden_states + + +class CrossAttnMidBlock(nn.Module): + """ + Unet's mid block containing resnet and cross-attention blocks. + + Args: + spatial_dims: The number of spatial dimensions. + in_channels: number of input channels. + temb_channels: number of timestep embedding channels + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + num_head_channels: number of channels in each attention head. + transformer_num_layers: number of layers of Transformer blocks to use. + cross_attention_dim: number of context dimensions to use. + upcast_attention: if True, upcast attention operations to full precision. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + temb_channels: int, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + num_head_channels: int = 1, + transformer_num_layers: int = 1, + cross_attention_dim: int | None = None, + upcast_attention: bool = False, + dropout_cattn: float = 0.0, + ) -> None: + super().__init__() + + self.resnet_1 = DiffusionUNetResnetBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + self.attention = SpatialTransformer( + spatial_dims=spatial_dims, + in_channels=in_channels, + num_attention_heads=in_channels // num_head_channels, + num_head_channels=num_head_channels, + num_layers=transformer_num_layers, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + dropout=dropout_cattn, + ) + self.resnet_2 = DiffusionUNetResnetBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + + def forward( + self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None + ) -> torch.Tensor: + hidden_states = self.resnet_1(hidden_states, temb) + hidden_states = self.attention(hidden_states, context=context) + hidden_states = self.resnet_2(hidden_states, temb) + + return hidden_states + + +class UpBlock(nn.Module): + """ + Unet's up block containing resnet and upsamplers blocks. + + Args: + spatial_dims: The number of spatial dimensions. + in_channels: number of input channels. + prev_output_channel: number of channels from residual connection. + out_channels: number of output channels. + temb_channels: number of timestep embedding channels. + num_res_blocks: number of residual blocks. + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + add_upsample: if True add downsample block. + resblock_updown: if True use residual blocks for upsampling. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + num_res_blocks: int = 1, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + add_upsample: bool = True, + resblock_updown: bool = False, + ) -> None: + super().__init__() + self.resblock_updown = resblock_updown + resnets = [] + + for i in range(num_res_blocks): + res_skip_channels = in_channels if (i == num_res_blocks - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + DiffusionUNetResnetBlock( + spatial_dims=spatial_dims, + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + self.upsampler: nn.Module | None + if add_upsample: + if resblock_updown: + self.upsampler = DiffusionUNetResnetBlock( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + up=True, + ) + else: + post_conv = Convolution( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + self.upsampler = WrappedUpsample( + spatial_dims=spatial_dims, + mode="nontrainable", + in_channels=out_channels, + out_channels=out_channels, + interp_mode="nearest", + scale_factor=2.0, + post_conv=post_conv, + align_corners=None, + ) + + else: + self.upsampler = None + + def forward( + self, + hidden_states: torch.Tensor, + res_hidden_states_list: list[torch.Tensor], + temb: torch.Tensor, + context: torch.Tensor | None = None, + ) -> torch.Tensor: + del context + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_list[-1] + res_hidden_states_list = res_hidden_states_list[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + hidden_states = resnet(hidden_states, temb) + + if self.upsampler is not None: + hidden_states = self.upsampler(hidden_states, temb) + + return hidden_states + + +class AttnUpBlock(nn.Module): + """ + Unet's up block containing resnet, upsamplers, and self-attention blocks. + + Args: + spatial_dims: The number of spatial dimensions. + in_channels: number of input channels. + prev_output_channel: number of channels from residual connection. + out_channels: number of output channels. + temb_channels: number of timestep embedding channels. + num_res_blocks: number of residual blocks. + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + add_upsample: if True add downsample block. + resblock_updown: if True use residual blocks for upsampling. + num_head_channels: number of channels in each attention head. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + num_res_blocks: int = 1, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + add_upsample: bool = True, + resblock_updown: bool = False, + num_head_channels: int = 1, + ) -> None: + super().__init__() + self.resblock_updown = resblock_updown + + resnets = [] + attentions = [] + + for i in range(num_res_blocks): + res_skip_channels = in_channels if (i == num_res_blocks - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + DiffusionUNetResnetBlock( + spatial_dims=spatial_dims, + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + ) + attentions.append( + SpatialAttentionBlock( + spatial_dims=spatial_dims, + num_channels=out_channels, + num_head_channels=num_head_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.attentions = nn.ModuleList(attentions) + + self.upsampler: nn.Module | None + if add_upsample: + if resblock_updown: + self.upsampler = DiffusionUNetResnetBlock( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + up=True, + ) + else: + + post_conv = Convolution( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + self.upsampler = WrappedUpsample( + spatial_dims=spatial_dims, + mode="nontrainable", + in_channels=out_channels, + out_channels=out_channels, + interp_mode="nearest", + scale_factor=2.0, + post_conv=post_conv, + align_corners=None, + ) + else: + self.upsampler = None + + def forward( + self, + hidden_states: torch.Tensor, + res_hidden_states_list: list[torch.Tensor], + temb: torch.Tensor, + context: torch.Tensor | None = None, + ) -> torch.Tensor: + del context + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_list[-1] + res_hidden_states_list = res_hidden_states_list[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states).contiguous() + + if self.upsampler is not None: + hidden_states = self.upsampler(hidden_states, temb) + + return hidden_states + + +class CrossAttnUpBlock(nn.Module): + """ + Unet's up block containing resnet, upsamplers, and self-attention blocks. + + Args: + spatial_dims: The number of spatial dimensions. + in_channels: number of input channels. + prev_output_channel: number of channels from residual connection. + out_channels: number of output channels. + temb_channels: number of timestep embedding channels. + num_res_blocks: number of residual blocks. + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + add_upsample: if True add downsample block. + resblock_updown: if True use residual blocks for upsampling. + num_head_channels: number of channels in each attention head. + transformer_num_layers: number of layers of Transformer blocks to use. + cross_attention_dim: number of context dimensions to use. + upcast_attention: if True, upcast attention operations to full precision. + dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + num_res_blocks: int = 1, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + add_upsample: bool = True, + resblock_updown: bool = False, + num_head_channels: int = 1, + transformer_num_layers: int = 1, + cross_attention_dim: int | None = None, + upcast_attention: bool = False, + dropout_cattn: float = 0.0, + ) -> None: + super().__init__() + self.resblock_updown = resblock_updown + + resnets = [] + attentions = [] + + for i in range(num_res_blocks): + res_skip_channels = in_channels if (i == num_res_blocks - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + DiffusionUNetResnetBlock( + spatial_dims=spatial_dims, + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + ) + attentions.append( + SpatialTransformer( + spatial_dims=spatial_dims, + in_channels=out_channels, + num_attention_heads=out_channels // num_head_channels, + num_head_channels=num_head_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + dropout=dropout_cattn, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + self.upsampler: nn.Module | None + if add_upsample: + if resblock_updown: + self.upsampler = DiffusionUNetResnetBlock( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + up=True, + ) + else: + + post_conv = Convolution( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + self.upsampler = WrappedUpsample( + spatial_dims=spatial_dims, + mode="nontrainable", + in_channels=out_channels, + out_channels=out_channels, + interp_mode="nearest", + scale_factor=2.0, + post_conv=post_conv, + align_corners=None, + ) + else: + self.upsampler = None + + def forward( + self, + hidden_states: torch.Tensor, + res_hidden_states_list: list[torch.Tensor], + temb: torch.Tensor, + context: torch.Tensor | None = None, + ) -> torch.Tensor: + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_list[-1] + res_hidden_states_list = res_hidden_states_list[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, context=context) + + if self.upsampler is not None: + hidden_states = self.upsampler(hidden_states, temb) + + return hidden_states + + +def get_down_block( + spatial_dims: int, + in_channels: int, + out_channels: int, + temb_channels: int, + num_res_blocks: int, + norm_num_groups: int, + norm_eps: float, + add_downsample: bool, + resblock_updown: bool, + with_attn: bool, + with_cross_attn: bool, + num_head_channels: int, + transformer_num_layers: int, + cross_attention_dim: int | None, + upcast_attention: bool = False, + dropout_cattn: float = 0.0, +) -> nn.Module: + if with_attn: + return AttnDownBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + num_res_blocks=num_res_blocks, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_downsample=add_downsample, + resblock_updown=resblock_updown, + num_head_channels=num_head_channels, + ) + elif with_cross_attn: + return CrossAttnDownBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + num_res_blocks=num_res_blocks, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_downsample=add_downsample, + resblock_updown=resblock_updown, + num_head_channels=num_head_channels, + transformer_num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + dropout_cattn=dropout_cattn, + ) + else: + return DownBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + num_res_blocks=num_res_blocks, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_downsample=add_downsample, + resblock_updown=resblock_updown, + ) + + +def get_mid_block( + spatial_dims: int, + in_channels: int, + temb_channels: int, + norm_num_groups: int, + norm_eps: float, + with_conditioning: bool, + num_head_channels: int, + transformer_num_layers: int, + cross_attention_dim: int | None, + upcast_attention: bool = False, + dropout_cattn: float = 0.0, +) -> nn.Module: + if with_conditioning: + return CrossAttnMidBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + num_head_channels=num_head_channels, + transformer_num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + dropout_cattn=dropout_cattn, + ) + else: + return AttnMidBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + num_head_channels=num_head_channels, + ) + + +def get_up_block( + spatial_dims: int, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + num_res_blocks: int, + norm_num_groups: int, + norm_eps: float, + add_upsample: bool, + resblock_updown: bool, + with_attn: bool, + with_cross_attn: bool, + num_head_channels: int, + transformer_num_layers: int, + cross_attention_dim: int | None, + upcast_attention: bool = False, + dropout_cattn: float = 0.0, +) -> nn.Module: + if with_attn: + return AttnUpBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + prev_output_channel=prev_output_channel, + out_channels=out_channels, + temb_channels=temb_channels, + num_res_blocks=num_res_blocks, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_upsample=add_upsample, + resblock_updown=resblock_updown, + num_head_channels=num_head_channels, + ) + elif with_cross_attn: + return CrossAttnUpBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + prev_output_channel=prev_output_channel, + out_channels=out_channels, + temb_channels=temb_channels, + num_res_blocks=num_res_blocks, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_upsample=add_upsample, + resblock_updown=resblock_updown, + num_head_channels=num_head_channels, + transformer_num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + dropout_cattn=dropout_cattn, + ) + else: + return UpBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + prev_output_channel=prev_output_channel, + out_channels=out_channels, + temb_channels=temb_channels, + num_res_blocks=num_res_blocks, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_upsample=add_upsample, + resblock_updown=resblock_updown, + ) + + +class DiffusionModelUNet(nn.Module): + """ + Unet network with timestep embedding and attention mechanisms for conditioning based on + Rombach et al. "High-Resolution Image Synthesis with Latent Diffusion Models" https://arxiv.org/abs/2112.10752 + and Pinaya et al. "Brain Imaging Generation with Latent Diffusion Models" https://arxiv.org/abs/2209.07162 + + Args: + spatial_dims: number of spatial dimensions. + in_channels: number of input channels. + out_channels: number of output channels. + num_res_blocks: number of residual blocks (see _ResnetBlock) per level. + channels: tuple of block output channels. + attention_levels: list of levels to add attention. + norm_num_groups: number of groups for the normalization. + norm_eps: epsilon for the normalization. + resblock_updown: if True use residual blocks for up/downsampling. + num_head_channels: number of channels in each attention head. + with_conditioning: if True add spatial transformers to perform conditioning. + transformer_num_layers: number of layers of Transformer blocks to use. + cross_attention_dim: number of context dimensions to use. + num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds` + classes. + upcast_attention: if True, upcast attention operations to full precision. + dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + num_res_blocks: Sequence[int] | int = (2, 2, 2, 2), + channels: Sequence[int] = (32, 64, 64, 64), + attention_levels: Sequence[bool] = (False, False, True, True), + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + resblock_updown: bool = False, + num_head_channels: int | Sequence[int] = 8, + with_conditioning: bool = False, + transformer_num_layers: int = 1, + cross_attention_dim: int | None = None, + num_class_embeds: int | None = None, + upcast_attention: bool = False, + dropout_cattn: float = 0.0, + ) -> None: + super().__init__() + if with_conditioning is True and cross_attention_dim is None: + raise ValueError( + "DiffusionModelUNet expects dimension of the cross-attention conditioning (cross_attention_dim) " + "when using with_conditioning." + ) + if cross_attention_dim is not None and with_conditioning is False: + raise ValueError( + "DiffusionModelUNet expects with_conditioning=True when specifying the cross_attention_dim." + ) + if dropout_cattn > 1.0 or dropout_cattn < 0.0: + raise ValueError("Dropout cannot be negative or >1.0!") + + # All number of channels should be multiple of num_groups + if any((out_channel % norm_num_groups) != 0 for out_channel in channels): + raise ValueError("DiffusionModelUNet expects all num_channels being multiple of norm_num_groups") + + if len(channels) != len(attention_levels): + raise ValueError("DiffusionModelUNet expects num_channels being same size of attention_levels") + + if isinstance(num_head_channels, int): + num_head_channels = ensure_tuple_rep(num_head_channels, len(attention_levels)) + + if len(num_head_channels) != len(attention_levels): + raise ValueError( + "num_head_channels should have the same length as attention_levels. For the i levels without attention," + " i.e. `attention_level[i]=False`, the num_head_channels[i] will be ignored." + ) + + if isinstance(num_res_blocks, int): + num_res_blocks = ensure_tuple_rep(num_res_blocks, len(channels)) + + if len(num_res_blocks) != len(channels): + raise ValueError( + "`num_res_blocks` should be a single integer or a tuple of integers with the same length as " + "`num_channels`." + ) + + self.in_channels = in_channels + self.block_out_channels = channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_levels = attention_levels + self.num_head_channels = num_head_channels + self.with_conditioning = with_conditioning + + # input + self.conv_in = Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=channels[0], + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + + # time + time_embed_dim = channels[0] * 4 + self.time_embed = nn.Sequential( + nn.Linear(channels[0], time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, time_embed_dim) + ) + + # class embedding + self.num_class_embeds = num_class_embeds + if num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + + # down + self.down_blocks = nn.ModuleList([]) + output_channel = channels[0] + for i in range(len(channels)): + input_channel = output_channel + output_channel = channels[i] + is_final_block = i == len(channels) - 1 + + down_block = get_down_block( + spatial_dims=spatial_dims, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + num_res_blocks=num_res_blocks[i], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_downsample=not is_final_block, + resblock_updown=resblock_updown, + with_attn=(attention_levels[i] and not with_conditioning), + with_cross_attn=(attention_levels[i] and with_conditioning), + num_head_channels=num_head_channels[i], + transformer_num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + dropout_cattn=dropout_cattn, + ) + + self.down_blocks.append(down_block) + + # mid + self.middle_block = get_mid_block( + spatial_dims=spatial_dims, + in_channels=channels[-1], + temb_channels=time_embed_dim, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + with_conditioning=with_conditioning, + num_head_channels=num_head_channels[-1], + transformer_num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + dropout_cattn=dropout_cattn, + ) + + # up + self.up_blocks = nn.ModuleList([]) + reversed_block_out_channels = list(reversed(channels)) + reversed_num_res_blocks = list(reversed(num_res_blocks)) + reversed_attention_levels = list(reversed(attention_levels)) + reversed_num_head_channels = list(reversed(num_head_channels)) + output_channel = reversed_block_out_channels[0] + for i in range(len(reversed_block_out_channels)): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(channels) - 1)] + + is_final_block = i == len(channels) - 1 + + up_block = get_up_block( + spatial_dims=spatial_dims, + in_channels=input_channel, + prev_output_channel=prev_output_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + num_res_blocks=reversed_num_res_blocks[i] + 1, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_upsample=not is_final_block, + resblock_updown=resblock_updown, + with_attn=(reversed_attention_levels[i] and not with_conditioning), + with_cross_attn=(reversed_attention_levels[i] and with_conditioning), + num_head_channels=reversed_num_head_channels[i], + transformer_num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + dropout_cattn=dropout_cattn, + ) + + self.up_blocks.append(up_block) + + # out + self.out = nn.Sequential( + nn.GroupNorm(num_groups=norm_num_groups, num_channels=channels[0], eps=norm_eps, affine=True), + nn.SiLU(), + zero_module( + Convolution( + spatial_dims=spatial_dims, + in_channels=channels[0], + out_channels=out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + ), + ) + + def forward( + self, + x: torch.Tensor, + timesteps: torch.Tensor, + context: torch.Tensor | None = None, + class_labels: torch.Tensor | None = None, + down_block_additional_residuals: tuple[torch.Tensor] | None = None, + mid_block_additional_residual: torch.Tensor | None = None, + ) -> torch.Tensor: + """ + Args: + x: input tensor (N, C, SpatialDims). + timesteps: timestep tensor (N,). + context: context tensor (N, 1, ContextDim). + class_labels: context tensor (N, ). + down_block_additional_residuals: additional residual tensors for down blocks (N, C, FeatureMapsDims). + mid_block_additional_residual: additional residual tensor for mid block (N, C, FeatureMapsDims). + """ + # 1. time + t_emb = get_timestep_embedding(timesteps, self.block_out_channels[0]) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=x.dtype) + emb = self.time_embed(t_emb) + + # 2. class + if self.num_class_embeds is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + class_emb = self.class_embedding(class_labels) + class_emb = class_emb.to(dtype=x.dtype) + emb = emb + class_emb + + # 3. initial convolution + h = self.conv_in(x) + + # 4. down + if context is not None and self.with_conditioning is False: + raise ValueError("model should have with_conditioning = True if context is provided") + down_block_res_samples: list[torch.Tensor] = [h] + for downsample_block in self.down_blocks: + h, res_samples = downsample_block(hidden_states=h, temb=emb, context=context) + for residual in res_samples: + down_block_res_samples.append(residual) + + # Additional residual conections for Controlnets + if down_block_additional_residuals is not None: + new_down_block_res_samples: list[torch.Tensor] = [] + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals + ): + down_block_res_sample = down_block_res_sample + down_block_additional_residual + new_down_block_res_samples += [down_block_res_sample] + + down_block_res_samples = new_down_block_res_samples + + # 5. mid + h = self.middle_block(hidden_states=h, temb=emb, context=context) + + # Additional residual conections for Controlnets + if mid_block_additional_residual is not None: + h = h + mid_block_additional_residual + + # 6. up + for upsample_block in self.up_blocks: + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + h = upsample_block(hidden_states=h, res_hidden_states_list=res_samples, temb=emb, context=context) + + # 7. output block + output: torch.Tensor = self.out(h) + + return output + + def load_old_state_dict(self, old_state_dict: dict, verbose=False) -> None: + """ + Load a state dict from a DiffusionModelUNet trained with + [MONAI Generative](https://github.com/Project-MONAI/GenerativeModels). + + Args: + old_state_dict: state dict from the old DecoderOnlyTransformer model. + """ + + new_state_dict = self.state_dict() + # if all keys match, just load the state dict + if all(k in new_state_dict for k in old_state_dict): + print("All keys match, loading state dict.") + self.load_state_dict(old_state_dict) + return + + if verbose: + # print all new_state_dict keys that are not in old_state_dict + for k in new_state_dict: + if k not in old_state_dict: + print(f"key {k} not found in old state dict") + # and vice versa + print("----------------------------------------------") + for k in old_state_dict: + if k not in new_state_dict: + print(f"key {k} not found in new state dict") + + # copy over all matching keys + for k in new_state_dict: + if k in old_state_dict: + new_state_dict[k] = old_state_dict[k] + + # fix the attention blocks + attention_blocks = [k.replace(".attn1.qkv.weight", "") for k in new_state_dict if "attn1.qkv.weight" in k] + for block in attention_blocks: + new_state_dict[f"{block}.attn1.qkv.weight"] = torch.cat( + [ + old_state_dict[f"{block}.attn1.to_q.weight"], + old_state_dict[f"{block}.attn1.to_k.weight"], + old_state_dict[f"{block}.attn1.to_v.weight"], + ], + dim=0, + ) + + # projection + new_state_dict[f"{block}.attn1.out_proj.weight"] = old_state_dict[f"{block}.attn1.to_out.0.weight"] + new_state_dict[f"{block}.attn1.out_proj.bias"] = old_state_dict[f"{block}.attn1.to_out.0.bias"] + + new_state_dict[f"{block}.attn2.out_proj.weight"] = old_state_dict[f"{block}.attn2.to_out.0.weight"] + new_state_dict[f"{block}.attn2.out_proj.bias"] = old_state_dict[f"{block}.attn2.to_out.0.bias"] + # fix the upsample conv blocks which were renamed postconv + for k in new_state_dict: + if "postconv" in k: + old_name = k.replace("postconv", "conv") + new_state_dict[k] = old_state_dict[old_name] + self.load_state_dict(new_state_dict) + + +class DiffusionModelEncoder(nn.Module): + """ + Classification Network based on the Encoder of the Diffusion Model, followed by fully connected layers. This network is based on + Wolleb et al. "Diffusion Models for Medical Anomaly Detection" (https://arxiv.org/abs/2203.04306). + + Args: + spatial_dims: number of spatial dimensions. + in_channels: number of input channels. + out_channels: number of output channels. + num_res_blocks: number of residual blocks (see _ResnetBlock) per level. + channels: tuple of block output channels. + attention_levels: list of levels to add attention. + norm_num_groups: number of groups for the normalization. + norm_eps: epsilon for the normalization. + resblock_updown: if True use residual blocks for downsampling. + num_head_channels: number of channels in each attention head. + with_conditioning: if True add spatial transformers to perform conditioning. + transformer_num_layers: number of layers of Transformer blocks to use. + cross_attention_dim: number of context dimensions to use. + num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds` classes. + upcast_attention: if True, upcast attention operations to full precision. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + num_res_blocks: Sequence[int] | int = (2, 2, 2, 2), + channels: Sequence[int] = (32, 64, 64, 64), + attention_levels: Sequence[bool] = (False, False, True, True), + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + resblock_updown: bool = False, + num_head_channels: int | Sequence[int] = 8, + with_conditioning: bool = False, + transformer_num_layers: int = 1, + cross_attention_dim: int | None = None, + num_class_embeds: int | None = None, + upcast_attention: bool = False, + ) -> None: + super().__init__() + if with_conditioning is True and cross_attention_dim is None: + raise ValueError( + "DiffusionModelEncoder expects dimension of the cross-attention conditioning (cross_attention_dim) " + "when using with_conditioning." + ) + if cross_attention_dim is not None and with_conditioning is False: + raise ValueError( + "DiffusionModelEncoder expects with_conditioning=True when specifying the cross_attention_dim." + ) + + # All number of channels should be multiple of num_groups + if any((out_channel % norm_num_groups) != 0 for out_channel in channels): + raise ValueError("DiffusionModelEncoder expects all num_channels being multiple of norm_num_groups") + if len(channels) != len(attention_levels): + raise ValueError("DiffusionModelEncoder expects num_channels being same size of attention_levels") + + if isinstance(num_head_channels, int): + num_head_channels = ensure_tuple_rep(num_head_channels, len(attention_levels)) + + if isinstance(num_res_blocks, int): + num_res_blocks = ensure_tuple_rep(num_res_blocks, len(channels)) + + if len(num_head_channels) != len(attention_levels): + raise ValueError( + "num_head_channels should have the same length as attention_levels. For the i levels without attention," + " i.e. `attention_level[i]=False`, the num_head_channels[i] will be ignored." + ) + + self.in_channels = in_channels + self.block_out_channels = channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_levels = attention_levels + self.num_head_channels = num_head_channels + self.with_conditioning = with_conditioning + + # input + self.conv_in = Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=channels[0], + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + + # time + time_embed_dim = channels[0] * 4 + self.time_embed = nn.Sequential( + nn.Linear(channels[0], time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, time_embed_dim) + ) + + # class embedding + self.num_class_embeds = num_class_embeds + if num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + + # down + self.down_blocks = nn.ModuleList([]) + output_channel = channels[0] + for i in range(len(channels)): + input_channel = output_channel + output_channel = channels[i] + is_final_block = i == len(channels) # - 1 + + down_block = get_down_block( + spatial_dims=spatial_dims, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + num_res_blocks=num_res_blocks[i], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_downsample=not is_final_block, + resblock_updown=resblock_updown, + with_attn=(attention_levels[i] and not with_conditioning), + with_cross_attn=(attention_levels[i] and with_conditioning), + num_head_channels=num_head_channels[i], + transformer_num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + ) + + self.down_blocks.append(down_block) + + self.out = nn.Sequential(nn.Linear(4096, 512), nn.ReLU(), nn.Dropout(0.1), nn.Linear(512, self.out_channels)) + + def forward( + self, + x: torch.Tensor, + timesteps: torch.Tensor, + context: torch.Tensor | None = None, + class_labels: torch.Tensor | None = None, + ) -> torch.Tensor: + """ + Args: + x: input tensor (N, C, SpatialDims). + timesteps: timestep tensor (N,). + context: context tensor (N, 1, ContextDim). + class_labels: context tensor (N, ). + """ + # 1. time + t_emb = get_timestep_embedding(timesteps, self.block_out_channels[0]) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=x.dtype) + emb = self.time_embed(t_emb) + + # 2. class + if self.num_class_embeds is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + class_emb = self.class_embedding(class_labels) + class_emb = class_emb.to(dtype=x.dtype) + emb = emb + class_emb + + # 3. initial convolution + h = self.conv_in(x) + + # 4. down + if context is not None and self.with_conditioning is False: + raise ValueError("model should have with_conditioning = True if context is provided") + for downsample_block in self.down_blocks: + h, _ = downsample_block(hidden_states=h, temb=emb, context=context) + + h = h.reshape(h.shape[0], -1) + output: torch.Tensor = self.out(h) + + return output diff --git a/monai/networks/nets/patchgan_discriminator.py b/monai/networks/nets/patchgan_discriminator.py new file mode 100644 index 0000000000..74da917694 --- /dev/null +++ b/monai/networks/nets/patchgan_discriminator.py @@ -0,0 +1,230 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from collections.abc import Sequence + +import torch +import torch.nn as nn + +from monai.networks.blocks import Convolution +from monai.networks.layers import Act +from monai.networks.utils import normal_init + + +class MultiScalePatchDiscriminator(nn.Sequential): + """ + Multi-scale Patch-GAN discriminator based on Pix2PixHD: + High-Resolution Image Synthesis and Semantic Manipulation with Conditional GANs (https://arxiv.org/abs/1711.11585) + + The Multi-scale discriminator made up of several PatchGAN discriminators, that process the images + at different spatial scales. + + Args: + num_d: number of discriminators + num_layers_d: number of Convolution layers (Conv + activation + normalisation + [dropout]) in the first + discriminator. Each subsequent discriminator has one additional layer, meaning the output size is halved. + spatial_dims: number of spatial dimensions (1D, 2D etc.) + channels: number of filters in the first convolutional layer (doubled for each subsequent layer) + in_channels: number of input channels + out_channels: number of output channels in each discriminator + kernel_size: kernel size of the convolution layers + activation: activation layer type + norm: normalisation type + bias: introduction of layer bias + dropout: probability of dropout applied, defaults to 0. + minimum_size_im: minimum spatial size of the input image. Introduced to make sure the architecture + requested isn't going to downsample the input image beyond value of 1. + last_conv_kernel_size: kernel size of the last convolutional layer. + """ + + def __init__( + self, + num_d: int, + num_layers_d: int, + spatial_dims: int, + channels: int, + in_channels: int, + out_channels: int = 1, + kernel_size: int = 4, + activation: str | tuple = (Act.LEAKYRELU, {"negative_slope": 0.2}), + norm: str | tuple = "BATCH", + bias: bool = False, + dropout: float | tuple = 0.0, + minimum_size_im: int = 256, + last_conv_kernel_size: int = 1, + ) -> None: + super().__init__() + self.num_d = num_d + self.num_layers_d = num_layers_d + self.num_channels = channels + self.padding = tuple([int((kernel_size - 1) / 2)] * spatial_dims) + for i_ in range(self.num_d): + num_layers_d_i = self.num_layers_d * (i_ + 1) + output_size = float(minimum_size_im) / (2**num_layers_d_i) + if output_size < 1: + raise AssertionError( + f"Your image size is too small to take in up to {i_} discriminators with num_layers = {num_layers_d_i}." + "Please reduce num_layers, reduce num_D or enter bigger images." + ) + subnet_d = PatchDiscriminator( + spatial_dims=spatial_dims, + channels=self.num_channels, + in_channels=in_channels, + out_channels=out_channels, + num_layers_d=num_layers_d_i, + kernel_size=kernel_size, + activation=activation, + norm=norm, + bias=bias, + padding=self.padding, + dropout=dropout, + last_conv_kernel_size=last_conv_kernel_size, + ) + + self.add_module("discriminator_%d" % i_, subnet_d) + + def forward(self, i: torch.Tensor) -> tuple[list[torch.Tensor], list[list[torch.Tensor]]]: + """ + Args: + i: Input tensor + + Returns: + list of outputs and another list of lists with the intermediate features + of each discriminator. + """ + + out: list[torch.Tensor] = [] + intermediate_features: list[list[torch.Tensor]] = [] + for disc in self.children(): + out_d: list[torch.Tensor] = disc(i) + out.append(out_d[-1]) + intermediate_features.append(out_d[:-1]) + + return out, intermediate_features + + +class PatchDiscriminator(nn.Sequential): + """ + Patch-GAN discriminator based on Pix2PixHD: + High-Resolution Image Synthesis and Semantic Manipulation with Conditional GANs (https://arxiv.org/abs/1711.11585) + + + Args: + spatial_dims: number of spatial dimensions (1D, 2D etc.) + channels: number of filters in the first convolutional layer (doubled for each subsequent layer) + in_channels: number of input channels + out_channels: number of output channels + num_layers_d: number of Convolution layers (Conv + activation + normalisation + [dropout]) in the discriminator. + kernel_size: kernel size of the convolution layers + act: activation type and arguments. Defaults to LeakyReLU. + norm: feature normalization type and arguments. Defaults to batch norm. + bias: whether to have a bias term in convolution blocks. Defaults to False. + padding: padding to be applied to the convolutional layers + dropout: proportion of dropout applied, defaults to 0. + last_conv_kernel_size: kernel size of the last convolutional layer. + """ + + def __init__( + self, + spatial_dims: int, + channels: int, + in_channels: int, + out_channels: int = 1, + num_layers_d: int = 3, + kernel_size: int = 4, + activation: str | tuple = (Act.LEAKYRELU, {"negative_slope": 0.2}), + norm: str | tuple = "BATCH", + bias: bool = False, + padding: int | Sequence[int] = 1, + dropout: float | tuple = 0.0, + last_conv_kernel_size: int | None = None, + ) -> None: + super().__init__() + self.num_layers_d = num_layers_d + self.num_channels = channels + if last_conv_kernel_size is None: + last_conv_kernel_size = kernel_size + + self.add_module( + "initial_conv", + Convolution( + spatial_dims=spatial_dims, + kernel_size=kernel_size, + in_channels=in_channels, + out_channels=channels, + act=activation, + bias=True, + norm=None, + dropout=dropout, + padding=padding, + strides=2, + ), + ) + + input_channels = channels + output_channels = channels * 2 + + # Initial Layer + for l_ in range(self.num_layers_d): + if l_ == self.num_layers_d - 1: + stride = 1 + else: + stride = 2 + layer = Convolution( + spatial_dims=spatial_dims, + kernel_size=kernel_size, + in_channels=input_channels, + out_channels=output_channels, + act=activation, + bias=bias, + norm=norm, + dropout=dropout, + padding=padding, + strides=stride, + ) + self.add_module("%d" % l_, layer) + input_channels = output_channels + output_channels = output_channels * 2 + + # Final layer + self.add_module( + "final_conv", + Convolution( + spatial_dims=spatial_dims, + kernel_size=last_conv_kernel_size, + in_channels=input_channels, + out_channels=out_channels, + bias=True, + conv_only=True, + padding=int((last_conv_kernel_size - 1) / 2), + dropout=0.0, + strides=1, + ), + ) + + self.apply(normal_init) + + def forward(self, x: torch.Tensor) -> list[torch.Tensor]: + """ + Args: + x: input tensor + + Returns: + list of intermediate features, with the last element being the output. + """ + out = [x] + for submodel in self.children(): + intermediate_output = submodel(out[-1]) + out.append(intermediate_output) + + return out[1:] diff --git a/monai/networks/nets/quicknat.py b/monai/networks/nets/quicknat.py index cbcccf24d7..bbc4e7e490 100644 --- a/monai/networks/nets/quicknat.py +++ b/monai/networks/nets/quicknat.py @@ -168,6 +168,8 @@ def _get_layer(self, in_channels, out_channels, dilation): def forward(self, input, _): i = 0 result = input + result1 = input # this will not stay this value, needed here for pylint/mypy + for l in self.children(): # ignoring the max (un-)pool and droupout already added in the initial initialization step if isinstance(l, (nn.MaxPool2d, nn.MaxUnpool2d, nn.Dropout2d)): diff --git a/monai/networks/nets/spade_autoencoderkl.py b/monai/networks/nets/spade_autoencoderkl.py new file mode 100644 index 0000000000..294b121c94 --- /dev/null +++ b/monai/networks/nets/spade_autoencoderkl.py @@ -0,0 +1,480 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from collections.abc import Sequence + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from monai.networks.blocks import Convolution, SpatialAttentionBlock, Upsample +from monai.networks.blocks.spade_norm import SPADE +from monai.networks.nets.autoencoderkl import Encoder +from monai.utils import ensure_tuple_rep + +__all__ = ["SPADEAutoencoderKL"] + + +class SPADEResBlock(nn.Module): + """ + Residual block consisting of a cascade of 2 convolutions + activation + normalisation block, and a + residual connection between input and output. + Enables SPADE normalisation for semantic conditioning (Park et. al (2019): https://github.com/NVlabs/SPADE) + + Args: + spatial_dims: number of spatial dimensions (1D, 2D, 3D). + in_channels: input channels to the layer. + norm_num_groups: number of groups involved for the group normalisation layer. Ensure that your number of + channels is divisible by this number. + norm_eps: epsilon for the normalisation. + out_channels: number of output channels. + label_nc: number of semantic channels for SPADE normalisation + spade_intermediate_channels: number of intermediate channels for SPADE block layer + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + norm_num_groups: int, + norm_eps: float, + out_channels: int, + label_nc: int, + spade_intermediate_channels: int, + ) -> None: + super().__init__() + self.in_channels = in_channels + self.out_channels = in_channels if out_channels is None else out_channels + self.norm1 = SPADE( + label_nc=label_nc, + norm_nc=in_channels, + norm="GROUP", + norm_params={"num_groups": norm_num_groups, "affine": False}, + hidden_channels=spade_intermediate_channels, + kernel_size=3, + spatial_dims=spatial_dims, + ) + self.conv1 = Convolution( + spatial_dims=spatial_dims, + in_channels=self.in_channels, + out_channels=self.out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + self.norm2 = SPADE( + label_nc=label_nc, + norm_nc=out_channels, + norm="GROUP", + norm_params={"num_groups": norm_num_groups, "affine": False}, + hidden_channels=spade_intermediate_channels, + kernel_size=3, + spatial_dims=spatial_dims, + ) + self.conv2 = Convolution( + spatial_dims=spatial_dims, + in_channels=self.out_channels, + out_channels=self.out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + + self.nin_shortcut: nn.Module + if self.in_channels != self.out_channels: + self.nin_shortcut = Convolution( + spatial_dims=spatial_dims, + in_channels=self.in_channels, + out_channels=self.out_channels, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + else: + self.nin_shortcut = nn.Identity() + + def forward(self, x: torch.Tensor, seg: torch.Tensor) -> torch.Tensor: + h = x + h = self.norm1(h, seg) + h = F.silu(h) + h = self.conv1(h) + h = self.norm2(h, seg) + h = F.silu(h) + h = self.conv2(h) + + x = self.nin_shortcut(x) + + return x + h + + +class SPADEDecoder(nn.Module): + """ + Convolutional cascade upsampling from a spatial latent space into an image space. + Enables SPADE normalisation for semantic conditioning (Park et. al (2019): https://github.com/NVlabs/SPADE) + + Args: + spatial_dims: number of spatial dimensions (1D, 2D, 3D). + channels: sequence of block output channels. + in_channels: number of channels in the bottom layer (latent space) of the autoencoder. + out_channels: number of output channels. + num_res_blocks: number of residual blocks (see ResBlock) per level. + norm_num_groups: number of groups for the GroupNorm layers, channels must be divisible by this number. + norm_eps: epsilon for the normalization. + attention_levels: indicate which level from channels contain an attention block. + label_nc: number of semantic channels for SPADE normalisation. + with_nonlocal_attn: if True use non-local attention block. + spade_intermediate_channels: number of intermediate channels for SPADE block layer. + """ + + def __init__( + self, + spatial_dims: int, + channels: Sequence[int], + in_channels: int, + out_channels: int, + num_res_blocks: Sequence[int], + norm_num_groups: int, + norm_eps: float, + attention_levels: Sequence[bool], + label_nc: int, + with_nonlocal_attn: bool = True, + spade_intermediate_channels: int = 128, + ) -> None: + super().__init__() + self.spatial_dims = spatial_dims + self.channels = channels + self.in_channels = in_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.norm_num_groups = norm_num_groups + self.norm_eps = norm_eps + self.attention_levels = attention_levels + self.label_nc = label_nc + + reversed_block_out_channels = list(reversed(channels)) + + blocks: list[nn.Module] = [] + + # Initial convolution + blocks.append( + Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=reversed_block_out_channels[0], + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + ) + + # Non-local attention block + if with_nonlocal_attn is True: + blocks.append( + SPADEResBlock( + spatial_dims=spatial_dims, + in_channels=reversed_block_out_channels[0], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + out_channels=reversed_block_out_channels[0], + label_nc=label_nc, + spade_intermediate_channels=spade_intermediate_channels, + ) + ) + blocks.append( + SpatialAttentionBlock( + spatial_dims=spatial_dims, + num_channels=reversed_block_out_channels[0], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + ) + blocks.append( + SPADEResBlock( + spatial_dims=spatial_dims, + in_channels=reversed_block_out_channels[0], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + out_channels=reversed_block_out_channels[0], + label_nc=label_nc, + spade_intermediate_channels=spade_intermediate_channels, + ) + ) + + reversed_attention_levels = list(reversed(attention_levels)) + reversed_num_res_blocks = list(reversed(num_res_blocks)) + block_out_ch = reversed_block_out_channels[0] + for i in range(len(reversed_block_out_channels)): + block_in_ch = block_out_ch + block_out_ch = reversed_block_out_channels[i] + is_final_block = i == len(channels) - 1 + + for _ in range(reversed_num_res_blocks[i]): + blocks.append( + SPADEResBlock( + spatial_dims=spatial_dims, + in_channels=block_in_ch, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + out_channels=block_out_ch, + label_nc=label_nc, + spade_intermediate_channels=spade_intermediate_channels, + ) + ) + block_in_ch = block_out_ch + + if reversed_attention_levels[i]: + blocks.append( + SpatialAttentionBlock( + spatial_dims=spatial_dims, + num_channels=block_in_ch, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + ) + + if not is_final_block: + post_conv = Convolution( + spatial_dims=spatial_dims, + in_channels=block_in_ch, + out_channels=block_in_ch, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + blocks.append( + Upsample( + spatial_dims=spatial_dims, + mode="nontrainable", + in_channels=block_in_ch, + out_channels=block_in_ch, + interp_mode="nearest", + scale_factor=2.0, + post_conv=post_conv, + align_corners=None, + ) + ) + + blocks.append(nn.GroupNorm(num_groups=norm_num_groups, num_channels=block_in_ch, eps=norm_eps, affine=True)) + blocks.append( + Convolution( + spatial_dims=spatial_dims, + in_channels=block_in_ch, + out_channels=out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + ) + + self.blocks = nn.ModuleList(blocks) + + def forward(self, x: torch.Tensor, seg: torch.Tensor) -> torch.Tensor: + for block in self.blocks: + if isinstance(block, SPADEResBlock): + x = block(x, seg) + else: + x = block(x) + return x + + +class SPADEAutoencoderKL(nn.Module): + """ + Autoencoder model with KL-regularized latent space based on + Rombach et al. "High-Resolution Image Synthesis with Latent Diffusion Models" https://arxiv.org/abs/2112.10752 + and Pinaya et al. "Brain Imaging Generation with Latent Diffusion Models" https://arxiv.org/abs/2209.07162 + Enables SPADE normalisation for semantic conditioning (Park et. al (2019): https://github.com/NVlabs/SPADE) + + Args: + spatial_dims: number of spatial dimensions (1D, 2D, 3D). + label_nc: number of semantic channels for SPADE normalisation. + in_channels: number of input channels. + out_channels: number of output channels. + num_res_blocks: number of residual blocks (see ResBlock) per level. + channels: sequence of block output channels. + attention_levels: sequence of levels to add attention. + latent_channels: latent embedding dimension. + norm_num_groups: number of groups for the GroupNorm layers, channels must be divisible by this number. + norm_eps: epsilon for the normalization. + with_encoder_nonlocal_attn: if True use non-local attention block in the encoder. + with_decoder_nonlocal_attn: if True use non-local attention block in the decoder. + spade_intermediate_channels: number of intermediate channels for SPADE block layer. + """ + + def __init__( + self, + spatial_dims: int, + label_nc: int, + in_channels: int = 1, + out_channels: int = 1, + num_res_blocks: Sequence[int] | int = (2, 2, 2, 2), + channels: Sequence[int] = (32, 64, 64, 64), + attention_levels: Sequence[bool] = (False, False, True, True), + latent_channels: int = 3, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + with_encoder_nonlocal_attn: bool = True, + with_decoder_nonlocal_attn: bool = True, + spade_intermediate_channels: int = 128, + ) -> None: + super().__init__() + + # All number of channels should be multiple of num_groups + if any((out_channel % norm_num_groups) != 0 for out_channel in channels): + raise ValueError("SPADEAutoencoderKL expects all channels being multiple of norm_num_groups") + + if len(channels) != len(attention_levels): + raise ValueError("SPADEAutoencoderKL expects channels being same size of attention_levels") + + if isinstance(num_res_blocks, int): + num_res_blocks = ensure_tuple_rep(num_res_blocks, len(channels)) + + if len(num_res_blocks) != len(channels): + raise ValueError( + "`num_res_blocks` should be a single integer or a tuple of integers with the same length as " + "`channels`." + ) + + self.encoder = Encoder( + spatial_dims=spatial_dims, + in_channels=in_channels, + channels=channels, + out_channels=latent_channels, + num_res_blocks=num_res_blocks, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + attention_levels=attention_levels, + with_nonlocal_attn=with_encoder_nonlocal_attn, + ) + self.decoder = SPADEDecoder( + spatial_dims=spatial_dims, + channels=channels, + in_channels=latent_channels, + out_channels=out_channels, + num_res_blocks=num_res_blocks, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + attention_levels=attention_levels, + label_nc=label_nc, + with_nonlocal_attn=with_decoder_nonlocal_attn, + spade_intermediate_channels=spade_intermediate_channels, + ) + self.quant_conv_mu = Convolution( + spatial_dims=spatial_dims, + in_channels=latent_channels, + out_channels=latent_channels, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + self.quant_conv_log_sigma = Convolution( + spatial_dims=spatial_dims, + in_channels=latent_channels, + out_channels=latent_channels, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + self.post_quant_conv = Convolution( + spatial_dims=spatial_dims, + in_channels=latent_channels, + out_channels=latent_channels, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + self.latent_channels = latent_channels + + def encode(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """ + Forwards an image through the spatial encoder, obtaining the latent mean and sigma representations. + + Args: + x: BxCx[SPATIAL DIMS] tensor + + """ + h = self.encoder(x) + z_mu = self.quant_conv_mu(h) + z_log_var = self.quant_conv_log_sigma(h) + z_log_var = torch.clamp(z_log_var, -30.0, 20.0) + z_sigma = torch.exp(z_log_var / 2) + + return z_mu, z_sigma + + def sampling(self, z_mu: torch.Tensor, z_sigma: torch.Tensor) -> torch.Tensor: + """ + From the mean and sigma representations resulting of encoding an image through the latent space, + obtains a noise sample resulting from sampling gaussian noise, multiplying by the variance (sigma) and + adding the mean. + + Args: + z_mu: Bx[Z_CHANNELS]x[LATENT SPACE SIZE] mean vector obtained by the encoder when you encode an image + z_sigma: Bx[Z_CHANNELS]x[LATENT SPACE SIZE] variance vector obtained by the encoder when you encode an image + + Returns: + sample of shape Bx[Z_CHANNELS]x[LATENT SPACE SIZE] + """ + eps = torch.randn_like(z_sigma) + z_vae = z_mu + eps * z_sigma + return z_vae + + def reconstruct(self, x: torch.Tensor, seg: torch.Tensor) -> torch.Tensor: + """ + Encodes and decodes an input image. + + Args: + x: BxCx[SPATIAL DIMENSIONS] tensor. + seg: Bx[LABEL_NC]x[SPATIAL DIMENSIONS] tensor of segmentations for SPADE norm. + Returns: + reconstructed image, of the same shape as input + """ + z_mu, _ = self.encode(x) + reconstruction = self.decode(z_mu, seg) + return reconstruction + + def decode(self, z: torch.Tensor, seg: torch.Tensor) -> torch.Tensor: + """ + Based on a latent space sample, forwards it through the Decoder. + + Args: + z: Bx[Z_CHANNELS]x[LATENT SPACE SHAPE] + seg: Bx[LABEL_NC]x[SPATIAL DIMENSIONS] tensor of segmentations for SPADE norm. + Returns: + decoded image tensor + """ + z = self.post_quant_conv(z) + dec: torch.Tensor = self.decoder(z, seg) + return dec + + def forward(self, x: torch.Tensor, seg: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + z_mu, z_sigma = self.encode(x) + z = self.sampling(z_mu, z_sigma) + reconstruction = self.decode(z, seg) + return reconstruction, z_mu, z_sigma + + def encode_stage_2_inputs(self, x: torch.Tensor) -> torch.Tensor: + z_mu, z_sigma = self.encode(x) + z = self.sampling(z_mu, z_sigma) + return z + + def decode_stage_2_outputs(self, z: torch.Tensor, seg: torch.Tensor) -> torch.Tensor: + image = self.decode(z, seg) + return image diff --git a/monai/networks/nets/spade_diffusion_model_unet.py b/monai/networks/nets/spade_diffusion_model_unet.py new file mode 100644 index 0000000000..75d1687df3 --- /dev/null +++ b/monai/networks/nets/spade_diffusion_model_unet.py @@ -0,0 +1,934 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ========================================================================= +# Adapted from https://github.com/huggingface/diffusers +# which has the following license: +# https://github.com/huggingface/diffusers/blob/main/LICENSE +# +# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========================================================================= + +from __future__ import annotations + +from collections.abc import Sequence + +import torch +from torch import nn + +from monai.networks.blocks import Convolution, SpatialAttentionBlock +from monai.networks.blocks.spade_norm import SPADE +from monai.networks.nets.diffusion_model_unet import ( + DiffusionUnetDownsample, + DiffusionUNetResnetBlock, + SpatialTransformer, + WrappedUpsample, + get_down_block, + get_mid_block, + get_timestep_embedding, + zero_module, +) +from monai.utils import ensure_tuple_rep + +__all__ = ["SPADEDiffusionModelUNet"] + + +class SPADEDiffResBlock(nn.Module): + """ + Residual block with timestep conditioning and SPADE norm. + Enables SPADE normalisation for semantic conditioning (Park et. al (2019): https://github.com/NVlabs/SPADE) + + Args: + spatial_dims: The number of spatial dimensions. + in_channels: number of input channels. + temb_channels: number of timestep embedding channels. + label_nc: number of semantic channels for SPADE normalisation. + out_channels: number of output channels. + up: if True, performs upsampling. + down: if True, performs downsampling. + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + spade_intermediate_channels: number of intermediate channels for SPADE block layer + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + temb_channels: int, + label_nc: int, + out_channels: int | None = None, + up: bool = False, + down: bool = False, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + spade_intermediate_channels: int = 128, + ) -> None: + super().__init__() + self.spatial_dims = spatial_dims + self.channels = in_channels + self.emb_channels = temb_channels + self.out_channels = out_channels or in_channels + self.up = up + self.down = down + + self.norm1 = SPADE( + label_nc=label_nc, + norm_nc=in_channels, + norm="GROUP", + norm_params={"num_groups": norm_num_groups, "eps": norm_eps, "affine": True}, + hidden_channels=spade_intermediate_channels, + kernel_size=3, + spatial_dims=spatial_dims, + ) + + self.nonlinearity = nn.SiLU() + self.conv1 = Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=self.out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + + self.upsample = self.downsample = None + if self.up: + self.upsample = WrappedUpsample( + spatial_dims=spatial_dims, + mode="nontrainable", + in_channels=in_channels, + out_channels=in_channels, + interp_mode="nearest", + scale_factor=2.0, + align_corners=None, + ) + elif down: + self.downsample = DiffusionUnetDownsample(spatial_dims, in_channels, use_conv=False) + + self.time_emb_proj = nn.Linear(temb_channels, self.out_channels) + + self.norm2 = SPADE( + label_nc=label_nc, + norm_nc=self.out_channels, + norm="GROUP", + norm_params={"num_groups": norm_num_groups, "eps": norm_eps, "affine": True}, + hidden_channels=spade_intermediate_channels, + kernel_size=3, + spatial_dims=spatial_dims, + ) + self.conv2 = zero_module( + Convolution( + spatial_dims=spatial_dims, + in_channels=self.out_channels, + out_channels=self.out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + ) + self.skip_connection: nn.Module + + if self.out_channels == in_channels: + self.skip_connection = nn.Identity() + else: + self.skip_connection = Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=self.out_channels, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + + def forward(self, x: torch.Tensor, emb: torch.Tensor, seg: torch.Tensor) -> torch.Tensor: + h = x + h = self.norm1(h, seg) + h = self.nonlinearity(h) + + if self.upsample is not None: + x = self.upsample(x) + h = self.upsample(h) + elif self.downsample is not None: + x = self.downsample(x) + h = self.downsample(h) + + h = self.conv1(h) + + if self.spatial_dims == 2: + temb = self.time_emb_proj(self.nonlinearity(emb))[:, :, None, None] + else: + temb = self.time_emb_proj(self.nonlinearity(emb))[:, :, None, None, None] + h = h + temb + + h = self.norm2(h, seg) + h = self.nonlinearity(h) + h = self.conv2(h) + output: torch.Tensor = self.skip_connection(x) + h + return output + + +class SPADEUpBlock(nn.Module): + """ + Unet's up block containing resnet and upsamplers blocks. + Enables SPADE normalisation for semantic conditioning (Park et. al (2019): https://github.com/NVlabs/SPADE) + + Args: + spatial_dims: The number of spatial dimensions. + in_channels: number of input channels. + prev_output_channel: number of channels from residual connection. + out_channels: number of output channels. + temb_channels: number of timestep embedding channels. + label_nc: number of semantic channels for SPADE normalisation. + num_res_blocks: number of residual blocks. + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + add_upsample: if True add downsample block. + resblock_updown: if True use residual blocks for upsampling. + spade_intermediate_channels: number of intermediate channels for SPADE block layer. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + label_nc: int, + num_res_blocks: int = 1, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + add_upsample: bool = True, + resblock_updown: bool = False, + spade_intermediate_channels: int = 128, + ) -> None: + super().__init__() + self.resblock_updown = resblock_updown + resnets = [] + + for i in range(num_res_blocks): + res_skip_channels = in_channels if (i == num_res_blocks - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + SPADEDiffResBlock( + spatial_dims=spatial_dims, + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + label_nc=label_nc, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + spade_intermediate_channels=spade_intermediate_channels, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + self.upsampler: nn.Module | None + if add_upsample: + if resblock_updown: + self.upsampler = DiffusionUNetResnetBlock( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + up=True, + ) + else: + post_conv = Convolution( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + self.upsampler = WrappedUpsample( + spatial_dims=spatial_dims, + mode="nontrainable", + in_channels=out_channels, + out_channels=out_channels, + interp_mode="nearest", + scale_factor=2.0, + post_conv=post_conv, + align_corners=None, + ) + else: + self.upsampler = None + + def forward( + self, + hidden_states: torch.Tensor, + res_hidden_states_list: list[torch.Tensor], + temb: torch.Tensor, + seg: torch.Tensor, + context: torch.Tensor | None = None, + ) -> torch.Tensor: + del context + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_list[-1] + res_hidden_states_list = res_hidden_states_list[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + hidden_states = resnet(hidden_states, temb, seg) + + if self.upsampler is not None: + hidden_states = self.upsampler(hidden_states, temb) + + return hidden_states + + +class SPADEAttnUpBlock(nn.Module): + """ + Unet's up block containing resnet, upsamplers, and self-attention blocks. + Enables SPADE normalisation for semantic conditioning (Park et. al (2019): https://github.com/NVlabs/SPADE) + + Args: + spatial_dims: The number of spatial dimensions. + in_channels: number of input channels. + prev_output_channel: number of channels from residual connection. + out_channels: number of output channels. + temb_channels: number of timestep embedding channels. + label_nc: number of semantic channels for SPADE normalisation + num_res_blocks: number of residual blocks. + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + add_upsample: if True add downsample block. + resblock_updown: if True use residual blocks for upsampling. + num_head_channels: number of channels in each attention head. + spade_intermediate_channels: number of intermediate channels for SPADE block layer + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + label_nc: int, + num_res_blocks: int = 1, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + add_upsample: bool = True, + resblock_updown: bool = False, + num_head_channels: int = 1, + spade_intermediate_channels: int = 128, + ) -> None: + super().__init__() + self.resblock_updown = resblock_updown + resnets = [] + attentions = [] + + for i in range(num_res_blocks): + res_skip_channels = in_channels if (i == num_res_blocks - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + SPADEDiffResBlock( + spatial_dims=spatial_dims, + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + label_nc=label_nc, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + spade_intermediate_channels=spade_intermediate_channels, + ) + ) + attentions.append( + SpatialAttentionBlock( + spatial_dims=spatial_dims, + num_channels=out_channels, + num_head_channels=num_head_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.attentions = nn.ModuleList(attentions) + + self.upsampler: nn.Module | None + if add_upsample: + if resblock_updown: + self.upsampler = DiffusionUNetResnetBlock( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + up=True, + ) + else: + post_conv = Convolution( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + self.upsampler = WrappedUpsample( + spatial_dims=spatial_dims, + mode="nontrainable", + in_channels=out_channels, + out_channels=out_channels, + interp_mode="nearest", + scale_factor=2.0, + post_conv=post_conv, + align_corners=None, + ) + else: + self.upsampler = None + + def forward( + self, + hidden_states: torch.Tensor, + res_hidden_states_list: list[torch.Tensor], + temb: torch.Tensor, + seg: torch.Tensor, + context: torch.Tensor | None = None, + ) -> torch.Tensor: + del context + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_list[-1] + res_hidden_states_list = res_hidden_states_list[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + hidden_states = resnet(hidden_states, temb, seg) + hidden_states = attn(hidden_states).contiguous() + + if self.upsampler is not None: + hidden_states = self.upsampler(hidden_states, temb) + + return hidden_states + + +class SPADECrossAttnUpBlock(nn.Module): + """ + Unet's up block containing resnet, upsamplers, and self-attention blocks. + Enables SPADE normalisation for semantic conditioning (Park et. al (2019): https://github.com/NVlabs/SPADE) + + Args: + spatial_dims: The number of spatial dimensions. + in_channels: number of input channels. + prev_output_channel: number of channels from residual connection. + out_channels: number of output channels. + temb_channels: number of timestep embedding channels. + label_nc: number of semantic channels for SPADE normalisation. + num_res_blocks: number of residual blocks. + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + add_upsample: if True add downsample block. + resblock_updown: if True use residual blocks for upsampling. + num_head_channels: number of channels in each attention head. + transformer_num_layers: number of layers of Transformer blocks to use. + cross_attention_dim: number of context dimensions to use. + upcast_attention: if True, upcast attention operations to full precision. + spade_intermediate_channels: number of intermediate channels for SPADE block layer. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + label_nc: int, + num_res_blocks: int = 1, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + add_upsample: bool = True, + resblock_updown: bool = False, + num_head_channels: int = 1, + transformer_num_layers: int = 1, + cross_attention_dim: int | None = None, + upcast_attention: bool = False, + spade_intermediate_channels: int = 128, + ) -> None: + super().__init__() + self.resblock_updown = resblock_updown + resnets = [] + attentions = [] + + for i in range(num_res_blocks): + res_skip_channels = in_channels if (i == num_res_blocks - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + SPADEDiffResBlock( + spatial_dims=spatial_dims, + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + label_nc=label_nc, + spade_intermediate_channels=spade_intermediate_channels, + ) + ) + attentions.append( + SpatialTransformer( + spatial_dims=spatial_dims, + in_channels=out_channels, + num_attention_heads=out_channels // num_head_channels, + num_head_channels=num_head_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + self.upsampler: nn.Module | None + if add_upsample: + if resblock_updown: + self.upsampler = DiffusionUNetResnetBlock( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + up=True, + ) + else: + post_conv = Convolution( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + self.upsampler = WrappedUpsample( + spatial_dims=spatial_dims, + mode="nontrainable", + in_channels=out_channels, + out_channels=out_channels, + interp_mode="nearest", + scale_factor=2.0, + post_conv=post_conv, + align_corners=None, + ) + else: + self.upsampler = None + + def forward( + self, + hidden_states: torch.Tensor, + res_hidden_states_list: list[torch.Tensor], + temb: torch.Tensor, + seg: torch.Tensor | None = None, + context: torch.Tensor | None = None, + ) -> torch.Tensor: + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_list[-1] + res_hidden_states_list = res_hidden_states_list[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + hidden_states = resnet(hidden_states, temb, seg) + hidden_states = attn(hidden_states, context=context).contiguous() + + if self.upsampler is not None: + hidden_states = self.upsampler(hidden_states, temb) + + return hidden_states + + +def get_spade_up_block( + spatial_dims: int, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + num_res_blocks: int, + norm_num_groups: int, + norm_eps: float, + add_upsample: bool, + resblock_updown: bool, + with_attn: bool, + with_cross_attn: bool, + num_head_channels: int, + transformer_num_layers: int, + label_nc: int, + cross_attention_dim: int | None, + upcast_attention: bool = False, + spade_intermediate_channels: int = 128, +) -> nn.Module: + if with_attn: + return SPADEAttnUpBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + prev_output_channel=prev_output_channel, + out_channels=out_channels, + temb_channels=temb_channels, + label_nc=label_nc, + num_res_blocks=num_res_blocks, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_upsample=add_upsample, + resblock_updown=resblock_updown, + num_head_channels=num_head_channels, + spade_intermediate_channels=spade_intermediate_channels, + ) + elif with_cross_attn: + return SPADECrossAttnUpBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + prev_output_channel=prev_output_channel, + out_channels=out_channels, + temb_channels=temb_channels, + label_nc=label_nc, + num_res_blocks=num_res_blocks, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_upsample=add_upsample, + resblock_updown=resblock_updown, + num_head_channels=num_head_channels, + transformer_num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + spade_intermediate_channels=spade_intermediate_channels, + ) + else: + return SPADEUpBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + prev_output_channel=prev_output_channel, + out_channels=out_channels, + temb_channels=temb_channels, + label_nc=label_nc, + num_res_blocks=num_res_blocks, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_upsample=add_upsample, + resblock_updown=resblock_updown, + spade_intermediate_channels=spade_intermediate_channels, + ) + + +class SPADEDiffusionModelUNet(nn.Module): + """ + UNet network with timestep embedding and attention mechanisms for conditioning, with added SPADE normalization for + semantic conditioning (Park et.al (2019): https://github.com/NVlabs/SPADE). An example tutorial can be found at + https://github.com/Project-MONAI/GenerativeModels/tree/main/tutorials/generative/2d_spade_ldm + + Args: + spatial_dims: number of spatial dimensions. + in_channels: number of input channels. + out_channels: number of output channels. + label_nc: number of semantic channels for SPADE normalisation. + num_res_blocks: number of residual blocks (see ResnetBlock) per level. + channels: tuple of block output channels. + attention_levels: list of levels to add attention. + norm_num_groups: number of groups for the normalization. + norm_eps: epsilon for the normalization. + resblock_updown: if True use residual blocks for up/downsampling. + num_head_channels: number of channels in each attention head. + with_conditioning: if True add spatial transformers to perform conditioning. + transformer_num_layers: number of layers of Transformer blocks to use. + cross_attention_dim: number of context dimensions to use. + num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds` + classes. + upcast_attention: if True, upcast attention operations to full precision. + spade_intermediate_channels: number of intermediate channels for SPADE block layer + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + label_nc: int, + num_res_blocks: Sequence[int] | int = (2, 2, 2, 2), + channels: Sequence[int] = (32, 64, 64, 64), + attention_levels: Sequence[bool] = (False, False, True, True), + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + resblock_updown: bool = False, + num_head_channels: int | Sequence[int] = 8, + with_conditioning: bool = False, + transformer_num_layers: int = 1, + cross_attention_dim: int | None = None, + num_class_embeds: int | None = None, + upcast_attention: bool = False, + spade_intermediate_channels: int = 128, + ) -> None: + super().__init__() + if with_conditioning is True and cross_attention_dim is None: + raise ValueError( + "SPADEDiffusionModelUNet expects dimension of the cross-attention conditioning (cross_attention_dim) " + "when using with_conditioning." + ) + if cross_attention_dim is not None and with_conditioning is False: + raise ValueError( + "SPADEDiffusionModelUNet expects with_conditioning=True when specifying the cross_attention_dim." + ) + + # All number of channels should be multiple of num_groups + if any((out_channel % norm_num_groups) != 0 for out_channel in channels): + raise ValueError("SPADEDiffusionModelUNet expects all num_channels being multiple of norm_num_groups") + + if len(channels) != len(attention_levels): + raise ValueError("SPADEDiffusionModelUNet expects num_channels being same size of attention_levels") + + if isinstance(num_head_channels, int): + num_head_channels = ensure_tuple_rep(num_head_channels, len(attention_levels)) + + if len(num_head_channels) != len(attention_levels): + raise ValueError( + "num_head_channels should have the same length as attention_levels. For the i levels without attention," + " i.e. `attention_level[i]=False`, the num_head_channels[i] will be ignored." + ) + + if isinstance(num_res_blocks, int): + num_res_blocks = ensure_tuple_rep(num_res_blocks, len(channels)) + + if len(num_res_blocks) != len(channels): + raise ValueError( + "`num_res_blocks` should be a single integer or a tuple of integers with the same length as " + "`num_channels`." + ) + + self.in_channels = in_channels + self.block_out_channels = channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_levels = attention_levels + self.num_head_channels = num_head_channels + self.with_conditioning = with_conditioning + self.label_nc = label_nc + + # input + self.conv_in = Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=channels[0], + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + + # time + time_embed_dim = channels[0] * 4 + self.time_embed = nn.Sequential( + nn.Linear(channels[0], time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, time_embed_dim) + ) + + # class embedding + self.num_class_embeds = num_class_embeds + if num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + + # down + self.down_blocks = nn.ModuleList([]) + output_channel = channels[0] + for i in range(len(channels)): + input_channel = output_channel + output_channel = channels[i] + is_final_block = i == len(channels) - 1 + + down_block = get_down_block( + spatial_dims=spatial_dims, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + num_res_blocks=num_res_blocks[i], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_downsample=not is_final_block, + resblock_updown=resblock_updown, + with_attn=(attention_levels[i] and not with_conditioning), + with_cross_attn=(attention_levels[i] and with_conditioning), + num_head_channels=num_head_channels[i], + transformer_num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + ) + + self.down_blocks.append(down_block) + + # mid + self.middle_block = get_mid_block( + spatial_dims=spatial_dims, + in_channels=channels[-1], + temb_channels=time_embed_dim, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + with_conditioning=with_conditioning, + num_head_channels=num_head_channels[-1], + transformer_num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + ) + + # up + self.up_blocks = nn.ModuleList([]) + reversed_block_out_channels = list(reversed(channels)) + reversed_num_res_blocks = list(reversed(num_res_blocks)) + reversed_attention_levels = list(reversed(attention_levels)) + reversed_num_head_channels = list(reversed(num_head_channels)) + output_channel = reversed_block_out_channels[0] + for i in range(len(reversed_block_out_channels)): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(channels) - 1)] + + is_final_block = i == len(channels) - 1 + + up_block = get_spade_up_block( + spatial_dims=spatial_dims, + in_channels=input_channel, + prev_output_channel=prev_output_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + num_res_blocks=reversed_num_res_blocks[i] + 1, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_upsample=not is_final_block, + resblock_updown=resblock_updown, + with_attn=(reversed_attention_levels[i] and not with_conditioning), + with_cross_attn=(reversed_attention_levels[i] and with_conditioning), + num_head_channels=reversed_num_head_channels[i], + transformer_num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + label_nc=label_nc, + spade_intermediate_channels=spade_intermediate_channels, + ) + + self.up_blocks.append(up_block) + + # out + self.out = nn.Sequential( + nn.GroupNorm(num_groups=norm_num_groups, num_channels=channels[0], eps=norm_eps, affine=True), + nn.SiLU(), + zero_module( + Convolution( + spatial_dims=spatial_dims, + in_channels=channels[0], + out_channels=out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + ), + ) + + def forward( + self, + x: torch.Tensor, + timesteps: torch.Tensor, + seg: torch.Tensor, + context: torch.Tensor | None = None, + class_labels: torch.Tensor | None = None, + down_block_additional_residuals: tuple[torch.Tensor] | None = None, + mid_block_additional_residual: torch.Tensor | None = None, + ) -> torch.Tensor: + """ + Args: + x: input tensor (N, C, SpatialDims). + timesteps: timestep tensor (N,). + seg: Bx[LABEL_NC]x[SPATIAL DIMENSIONS] tensor of segmentations for SPADE norm. + context: context tensor (N, 1, ContextDim). + class_labels: context tensor (N, ). + down_block_additional_residuals: additional residual tensors for down blocks (N, C, FeatureMapsDims). + mid_block_additional_residual: additional residual tensor for mid block (N, C, FeatureMapsDims). + """ + # 1. time + t_emb = get_timestep_embedding(timesteps, self.block_out_channels[0]) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=x.dtype) + emb = self.time_embed(t_emb) + + # 2. class + if self.num_class_embeds is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + class_emb = self.class_embedding(class_labels) + class_emb = class_emb.to(dtype=x.dtype) + emb = emb + class_emb + + # 3. initial convolution + h = self.conv_in(x) + + # 4. down + if context is not None and self.with_conditioning is False: + raise ValueError("model should have with_conditioning = True if context is provided") + down_block_res_samples: list[torch.Tensor] = [h] + for downsample_block in self.down_blocks: + h, res_samples = downsample_block(hidden_states=h, temb=emb, context=context) + for residual in res_samples: + down_block_res_samples.append(residual) + + # Additional residual conections for Controlnets + if down_block_additional_residuals is not None: + new_down_block_res_samples: list[torch.Tensor] = [h] + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals + ): + down_block_res_sample = down_block_res_sample + down_block_additional_residual + new_down_block_res_samples.append(down_block_res_sample) + + down_block_res_samples = new_down_block_res_samples + + # 5. mid + h = self.middle_block(hidden_states=h, temb=emb, context=context) + + # Additional residual conections for Controlnets + if mid_block_additional_residual is not None: + h = h + mid_block_additional_residual + + # 6. up + for upsample_block in self.up_blocks: + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + h = upsample_block(hidden_states=h, res_hidden_states_list=res_samples, seg=seg, temb=emb, context=context) + + # 7. output block + output: torch.Tensor = self.out(h) + + return output diff --git a/monai/networks/nets/spade_network.py b/monai/networks/nets/spade_network.py new file mode 100644 index 0000000000..9164541f27 --- /dev/null +++ b/monai/networks/nets/spade_network.py @@ -0,0 +1,435 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Sequence + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from monai.networks.blocks import Convolution +from monai.networks.blocks.spade_norm import SPADE +from monai.networks.layers import Act +from monai.networks.layers.utils import get_act_layer +from monai.utils.enums import StrEnum + +__all__ = ["SPADENet"] + + +class UpsamplingModes(StrEnum): + bicubic = "bicubic" + nearest = "nearest" + bilinear = "bilinear" + + +class SPADENetResBlock(nn.Module): + """ + Creates a Residual Block with SPADE normalisation. + + Args: + spatial_dims: number of spatial dimensions + in_channels: number of input channels + out_channels: number of output channels + label_nc: number of semantic channels that will be taken into account in SPADE normalisation blocks + spade_intermediate_channels: number of intermediate channels in the middle conv. layers in SPADE normalisation blocks + norm: base normalisation type used on top of SPADE + kernel_size: convolutional kernel size + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + label_nc: int, + spade_intermediate_channels: int = 128, + norm: str | tuple = "INSTANCE", + act: str | tuple = (Act.LEAKYRELU, {"negative_slope": 0.2}), + kernel_size: int = 3, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.int_channels = min(self.in_channels, self.out_channels) + self.learned_shortcut = self.in_channels != self.out_channels + self.conv_0 = Convolution( + spatial_dims=spatial_dims, in_channels=self.in_channels, out_channels=self.int_channels, act=None, norm=None + ) + self.conv_1 = Convolution( + spatial_dims=spatial_dims, + in_channels=self.int_channels, + out_channels=self.out_channels, + act=None, + norm=None, + ) + self.activation = get_act_layer(act) + self.norm_0 = SPADE( + label_nc=label_nc, + norm_nc=self.in_channels, + kernel_size=kernel_size, + spatial_dims=spatial_dims, + hidden_channels=spade_intermediate_channels, + norm=norm, + ) + self.norm_1 = SPADE( + label_nc=label_nc, + norm_nc=self.int_channels, + kernel_size=kernel_size, + spatial_dims=spatial_dims, + hidden_channels=spade_intermediate_channels, + norm=norm, + ) + + if self.learned_shortcut: + self.conv_s = Convolution( + spatial_dims=spatial_dims, + in_channels=self.in_channels, + out_channels=self.out_channels, + act=None, + norm=None, + kernel_size=1, + ) + self.norm_s = SPADE( + label_nc=label_nc, + norm_nc=self.in_channels, + kernel_size=kernel_size, + spatial_dims=spatial_dims, + hidden_channels=spade_intermediate_channels, + norm=norm, + ) + + def forward(self, x, seg): + x_s = self.shortcut(x, seg) + dx = self.conv_0(self.activation(self.norm_0(x, seg))) + dx = self.conv_1(self.activation(self.norm_1(dx, seg))) + out = x_s + dx + return out + + def shortcut(self, x, seg): + if self.learned_shortcut: + x_s = self.conv_s(self.norm_s(x, seg)) + else: + x_s = x + return x_s + + +class SPADEEncoder(nn.Module): + """ + Encoding branch of a VAE compatible with a SPADE-like generator + + Args: + spatial_dims: number of spatial dimensions + in_channels: number of input channels + z_dim: latent space dimension of the VAE containing the image sytle information + channels: number of output after each downsampling block + input_shape: spatial input shape of the tensor, necessary to do the reshaping after the linear layers + of the autoencoder (HxWx[D]) + kernel_size: convolutional kernel size + norm: normalisation layer type + act: activation type + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + z_dim: int, + channels: Sequence[int], + input_shape: Sequence[int], + kernel_size: int = 3, + norm: str | tuple = "INSTANCE", + act: str | tuple = (Act.LEAKYRELU, {"negative_slope": 0.2}), + ): + super().__init__() + self.in_channels = in_channels + self.z_dim = z_dim + self.channels = channels + if len(input_shape) != spatial_dims: + raise ValueError("Length of parameter input shape must match spatial_dims; got %s" % (input_shape)) + for s_ind, s_ in enumerate(input_shape): + if s_ / (2 ** len(channels)) != s_ // (2 ** len(channels)): + raise ValueError( + "Each dimension of your input must be divisible by 2 ** (autoencoder depth)." + "The shape in position %d, %d is not divisible by %d. " % (s_ind, s_, len(channels)) + ) + self.input_shape = input_shape + self.latent_spatial_shape = [s_ // (2 ** len(self.channels)) for s_ in self.input_shape] + blocks = [] + ch_init = self.in_channels + for _, ch_value in enumerate(channels): + blocks.append( + Convolution( + spatial_dims=spatial_dims, + in_channels=ch_init, + out_channels=ch_value, + strides=2, + kernel_size=kernel_size, + norm=norm, + act=act, + ) + ) + ch_init = ch_value + + self.blocks = nn.ModuleList(blocks) + self.fc_mu = nn.Linear( + in_features=np.prod(self.latent_spatial_shape) * self.channels[-1], out_features=self.z_dim + ) + self.fc_var = nn.Linear( + in_features=np.prod(self.latent_spatial_shape) * self.channels[-1], out_features=self.z_dim + ) + + def forward(self, x): + for block in self.blocks: + x = block(x) + x = x.view(x.size(0), -1) + mu = self.fc_mu(x) + logvar = self.fc_var(x) + return mu, logvar + + def encode(self, x): + for block in self.blocks: + x = block(x) + x = x.view(x.size(0), -1) + mu = self.fc_mu(x) + logvar = self.fc_var(x) + return self.reparameterize(mu, logvar) + + def reparameterize(self, mu, logvar): + std = torch.exp(0.5 * logvar) + eps = torch.randn_like(std) + return eps.mul(std) + mu + + +class SPADEDecoder(nn.Module): + """ + Decoder branch of a SPADE-like generator. It can be used independently, without an encoding branch, + behaving like a GAN, or coupled to a SPADE encoder. + + Args: + label_nc: number of semantic labels + spatial_dims: number of spatial dimensions + out_channels: number of output channels + label_nc: number of semantic channels used for the SPADE normalisation blocks + input_shape: spatial input shape of the tensor, necessary to do the reshaping after the linear layers + channels: number of output after each downsampling block + z_dim: latent space dimension of the VAE containing the image sytle information (None if encoder is not used) + is_vae: whether the decoder is going to be coupled to an autoencoder or not (true: yes, false: no) + spade_intermediate_channels: number of channels in the intermediate layers of the SPADE normalisation blocks + norm: base normalisation type + act: activation layer type + last_act: activation layer type for the last layer of the network (can differ from previous) + kernel_size: convolutional kernel size + upsampling_mode: upsampling mode (nearest, bilinear etc.) + """ + + def __init__( + self, + spatial_dims: int, + out_channels: int, + label_nc: int, + input_shape: Sequence[int], + channels: list[int], + z_dim: int | None = None, + is_vae: bool = True, + spade_intermediate_channels: int = 128, + norm: str | tuple = "INSTANCE", + act: str | tuple = (Act.LEAKYRELU, {"negative_slope": 0.2}), + last_act: str | tuple | None = (Act.LEAKYRELU, {"negative_slope": 0.2}), + kernel_size: int = 3, + upsampling_mode: str = UpsamplingModes.nearest.value, + ): + super().__init__() + self.is_vae = is_vae + self.out_channels = out_channels + self.label_nc = label_nc + self.num_channels = channels + if len(input_shape) != spatial_dims: + raise ValueError("Length of parameter input shape must match spatial_dims; got %s" % (input_shape)) + for s_ind, s_ in enumerate(input_shape): + if s_ / (2 ** len(channels)) != s_ // (2 ** len(channels)): + raise ValueError( + "Each dimension of your input must be divisible by 2 ** (autoencoder depth)." + "The shape in position %d, %d is not divisible by %d. " % (s_ind, s_, len(channels)) + ) + self.latent_spatial_shape = [s_ // (2 ** len(self.num_channels)) for s_ in input_shape] + + if not self.is_vae: + self.conv_init = Convolution( + spatial_dims=spatial_dims, in_channels=label_nc, out_channels=channels[0], kernel_size=kernel_size + ) + elif self.is_vae and z_dim is None: + raise ValueError( + "If the network is used in VAE-GAN mode, parameter z_dim " + "(number of latent channels in the VAE) must be populated." + ) + else: + self.fc = nn.Linear(z_dim, np.prod(self.latent_spatial_shape) * channels[0]) + + self.z_dim = z_dim + blocks = [] + channels.append(self.out_channels) + self.upsampling = torch.nn.Upsample(scale_factor=2, mode=upsampling_mode) + for ch_ind, ch_value in enumerate(channels[:-1]): + blocks.append( + SPADENetResBlock( + spatial_dims=spatial_dims, + in_channels=ch_value, + out_channels=channels[ch_ind + 1], + label_nc=label_nc, + spade_intermediate_channels=spade_intermediate_channels, + norm=norm, + kernel_size=kernel_size, + act=act, + ) + ) + + self.blocks = torch.nn.ModuleList(blocks) + self.last_conv = Convolution( + spatial_dims=spatial_dims, + in_channels=channels[-1], + out_channels=out_channels, + padding=(kernel_size - 1) // 2, + kernel_size=kernel_size, + norm=None, + act=last_act, + ) + + def forward(self, seg, z: torch.Tensor | None = None): + """ + Args: + seg: input BxCxHxW[xD] semantic map on which the output is conditioned on + z: latent vector output by the encoder if self.is_vae is True. When is_vae is + False, z is a random noise vector. + + Returns: + + """ + if not self.is_vae: + x = F.interpolate(seg, size=tuple(self.latent_spatial_shape)) + x = self.conv_init(x) + else: + if ( + z is None and self.z_dim is not None + ): # Even though this network is a VAE (self.is_vae), you should be able to sample from noise as well. + z = torch.randn(seg.size(0), self.z_dim, dtype=torch.float32, device=seg.get_device()) + x = self.fc(z) + x = x.view(*[-1, self.num_channels[0]] + self.latent_spatial_shape) + + for res_block in self.blocks: + x = res_block(x, seg) + x = self.upsampling(x) + + x = self.last_conv(x) + return x + + +class SPADENet(nn.Module): + """ + SPADE Network, implemented based on the code by Park, T et al. in + "Semantic Image Synthesis with Spatially-Adaptive Normalization" + (https://github.com/NVlabs/SPADE) + + Args: + spatial_dims: number of spatial dimensions + in_channels: number of input channels + out_channels: number of output channels + label_nc: number of semantic channels used for the SPADE normalisation blocks + input_shape: spatial input shape of the tensor, necessary to do the reshaping after the linear layers + channels: number of output after each downsampling block + z_dim: latent space dimension of the VAE containing the image sytle information (None if encoder is not used) + is_vae: whether the decoder is going to be coupled to an autoencoder (true) or not (false) + spade_intermediate_channels: number of channels in the intermediate layers of the SPADE normalisation blocks + norm: base normalisation type + act: activation layer type + last_act: activation layer type for the last layer of the network (can differ from previous) + kernel_size: convolutional kernel size + upsampling_mode: upsampling mode (nearest, bilinear etc.) + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + label_nc: int, + input_shape: Sequence[int], + channels: list[int], + z_dim: int | None = None, + is_vae: bool = True, + spade_intermediate_channels: int = 128, + norm: str | tuple = "INSTANCE", + act: str | tuple = (Act.LEAKYRELU, {"negative_slope": 0.2}), + last_act: str | tuple | None = (Act.LEAKYRELU, {"negative_slope": 0.2}), + kernel_size: int = 3, + upsampling_mode: str = UpsamplingModes.nearest.value, + ): + super().__init__() + self.is_vae = is_vae + self.in_channels = in_channels + self.out_channels = out_channels + self.channels = channels + self.label_nc = label_nc + self.input_shape = input_shape + + if self.is_vae: + if z_dim is None: + ValueError("The latent space dimension mapped by parameter z_dim cannot be None is is_vae is True.") + else: + self.encoder = SPADEEncoder( + spatial_dims=spatial_dims, + in_channels=in_channels, + z_dim=z_dim, + channels=channels, + input_shape=input_shape, + kernel_size=kernel_size, + norm=norm, + act=act, + ) + + decoder_channels = channels + decoder_channels.reverse() + + self.decoder = SPADEDecoder( + spatial_dims=spatial_dims, + out_channels=out_channels, + label_nc=label_nc, + input_shape=input_shape, + channels=decoder_channels, + z_dim=z_dim, + is_vae=is_vae, + spade_intermediate_channels=spade_intermediate_channels, + norm=norm, + act=act, + last_act=last_act, + kernel_size=kernel_size, + upsampling_mode=upsampling_mode, + ) + + def forward(self, seg: torch.Tensor, x: torch.Tensor | None = None): + z = None + if self.is_vae: + z_mu, z_logvar = self.encoder(x) + z = self.encoder.reparameterize(z_mu, z_logvar) + return self.decoder(seg, z), z_mu, z_logvar + else: + return (self.decoder(seg, z),) + + def encode(self, x: torch.Tensor): + if self.is_vae: + return self.encoder.encode(x) + else: + return None + + def decode(self, seg: torch.Tensor, z: torch.Tensor | None = None): + return self.decoder(seg, z) diff --git a/monai/networks/nets/swin_unetr.py b/monai/networks/nets/swin_unetr.py index 6f96dfd291..3900c866b3 100644 --- a/monai/networks/nets/swin_unetr.py +++ b/monai/networks/nets/swin_unetr.py @@ -347,7 +347,7 @@ def window_partition(x, window_size): x: input tensor. window_size: local window size. """ - x_shape = x.size() + x_shape = x.size() # length 4 or 5 only if len(x_shape) == 5: b, d, h, w, c = x_shape x = x.view( @@ -363,10 +363,11 @@ def window_partition(x, window_size): windows = ( x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(-1, window_size[0] * window_size[1] * window_size[2], c) ) - elif len(x_shape) == 4: + else: # if len(x_shape) == 4: b, h, w, c = x.shape x = x.view(b, h // window_size[0], window_size[0], w // window_size[1], window_size[1], c) windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0] * window_size[1], c) + return windows @@ -613,7 +614,7 @@ def forward_part1(self, x, mask_matrix): _, dp, hp, wp, _ = x.shape dims = [b, dp, hp, wp] - elif len(x_shape) == 4: + else: # elif len(x_shape) == 4 b, h, w, c = x.shape window_size, shift_size = get_window_size((h, w), self.window_size, self.shift_size) pad_l = pad_t = 0 diff --git a/monai/networks/nets/transformer.py b/monai/networks/nets/transformer.py new file mode 100644 index 0000000000..1af725abda --- /dev/null +++ b/monai/networks/nets/transformer.py @@ -0,0 +1,157 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import torch +import torch.nn as nn + +from monai.networks.blocks import TransformerBlock + +__all__ = ["DecoderOnlyTransformer"] + + +class AbsolutePositionalEmbedding(nn.Module): + """Absolute positional embedding. + + Args: + max_seq_len: Maximum sequence length. + embedding_dim: Dimensionality of the embedding. + """ + + def __init__(self, max_seq_len: int, embedding_dim: int) -> None: + super().__init__() + self.max_seq_len = max_seq_len + self.embedding_dim = embedding_dim + self.embedding = nn.Embedding(max_seq_len, embedding_dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + batch_size, seq_len = x.size() + positions = torch.arange(seq_len, device=x.device).repeat(batch_size, 1) + embedding: torch.Tensor = self.embedding(positions) + return embedding + + +class DecoderOnlyTransformer(nn.Module): + """Decoder-only (Autoregressive) Transformer model. + + Args: + num_tokens: Number of tokens in the vocabulary. + max_seq_len: Maximum sequence length. + attn_layers_dim: Dimensionality of the attention layers. + attn_layers_depth: Number of attention layers. + attn_layers_heads: Number of attention heads. + with_cross_attention: Whether to use cross attention for conditioning. + embedding_dropout_rate: Dropout rate for the embedding. + """ + + def __init__( + self, + num_tokens: int, + max_seq_len: int, + attn_layers_dim: int, + attn_layers_depth: int, + attn_layers_heads: int, + with_cross_attention: bool = False, + embedding_dropout_rate: float = 0.0, + ) -> None: + super().__init__() + self.num_tokens = num_tokens + self.max_seq_len = max_seq_len + self.attn_layers_dim = attn_layers_dim + self.attn_layers_depth = attn_layers_depth + self.attn_layers_heads = attn_layers_heads + self.with_cross_attention = with_cross_attention + + self.token_embeddings = nn.Embedding(num_tokens, attn_layers_dim) + self.position_embeddings = AbsolutePositionalEmbedding(max_seq_len=max_seq_len, embedding_dim=attn_layers_dim) + self.embedding_dropout = nn.Dropout(embedding_dropout_rate) + + self.blocks = nn.ModuleList( + [ + TransformerBlock( + hidden_size=attn_layers_dim, + mlp_dim=attn_layers_dim * 4, + num_heads=attn_layers_heads, + dropout_rate=0.0, + qkv_bias=False, + causal=True, + sequence_length=max_seq_len, + with_cross_attention=with_cross_attention, + ) + for _ in range(attn_layers_depth) + ] + ) + + self.to_logits = nn.Linear(attn_layers_dim, num_tokens) + + def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor: + tok_emb = self.token_embeddings(x) + pos_emb = self.position_embeddings(x) + x = self.embedding_dropout(tok_emb + pos_emb) + + for block in self.blocks: + x = block(x, context=context) + logits: torch.Tensor = self.to_logits(x) + return logits + + def load_old_state_dict(self, old_state_dict: dict, verbose=False) -> None: + """ + Load a state dict from a DecoderOnlyTransformer trained with + [MONAI Generative](https://github.com/Project-MONAI/GenerativeModels). + + Args: + old_state_dict: state dict from the old DecoderOnlyTransformer model. + """ + + new_state_dict = self.state_dict() + # if all keys match, just load the state dict + if all(k in new_state_dict for k in old_state_dict): + print("All keys match, loading state dict.") + self.load_state_dict(old_state_dict) + return + + if verbose: + # print all new_state_dict keys that are not in old_state_dict + for k in new_state_dict: + if k not in old_state_dict: + print(f"key {k} not found in old state dict") + # and vice versa + print("----------------------------------------------") + for k in old_state_dict: + if k not in new_state_dict: + print(f"key {k} not found in new state dict") + + # copy over all matching keys + for k in new_state_dict: + if k in old_state_dict: + new_state_dict[k] = old_state_dict[k] + + # fix the attention blocks + attention_blocks = [k.replace(".attn.qkv.weight", "") for k in new_state_dict if "attn.qkv.weight" in k] + for block in attention_blocks: + new_state_dict[f"{block}.attn.qkv.weight"] = torch.cat( + [ + old_state_dict[f"{block}.attn.to_q.weight"], + old_state_dict[f"{block}.attn.to_k.weight"], + old_state_dict[f"{block}.attn.to_v.weight"], + ], + dim=0, + ) + + # fix the renamed norm blocks first norm2 -> norm_cross_attention , norm3 -> norm2 + for k in old_state_dict: + if "norm2" in k: + new_state_dict[k.replace("norm2", "norm_cross_attn")] = old_state_dict[k] + if "norm3" in k: + new_state_dict[k.replace("norm3", "norm2")] = old_state_dict[k] + + self.load_state_dict(new_state_dict) diff --git a/monai/networks/nets/vqvae.py b/monai/networks/nets/vqvae.py new file mode 100644 index 0000000000..f198bfbb2b --- /dev/null +++ b/monai/networks/nets/vqvae.py @@ -0,0 +1,472 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from collections.abc import Sequence +from typing import Tuple + +import torch +import torch.nn as nn + +from monai.networks.blocks import Convolution +from monai.networks.layers import Act +from monai.networks.layers.vector_quantizer import EMAQuantizer, VectorQuantizer +from monai.utils import ensure_tuple_rep + +__all__ = ["VQVAE"] + + +class VQVAEResidualUnit(nn.Module): + """ + Implementation of the ResidualLayer used in the VQVAE network as originally used in Morphology-preserving + Autoregressive 3D Generative Modelling of the Brain by Tudosiu et al. (https://arxiv.org/pdf/2209.03177.pdf). + + The original implementation that can be found at + https://github.com/AmigoLab/SynthAnatomy/blob/main/src/networks/vqvae/baseline.py#L150. + + Args: + spatial_dims: number of spatial spatial_dims of the input data. + in_channels: number of input channels. + num_res_channels: number of channels in the residual layers. + act: activation type and arguments. Defaults to RELU. + dropout: dropout ratio. Defaults to no dropout. + bias: whether to have a bias term. Defaults to True. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + num_res_channels: int, + act: tuple | str | None = Act.RELU, + dropout: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + + self.spatial_dims = spatial_dims + self.in_channels = in_channels + self.num_res_channels = num_res_channels + self.act = act + self.dropout = dropout + self.bias = bias + + self.conv1 = Convolution( + spatial_dims=self.spatial_dims, + in_channels=self.in_channels, + out_channels=self.num_res_channels, + adn_ordering="DA", + act=self.act, + dropout=self.dropout, + bias=self.bias, + ) + + self.conv2 = Convolution( + spatial_dims=self.spatial_dims, + in_channels=self.num_res_channels, + out_channels=self.in_channels, + bias=self.bias, + conv_only=True, + ) + + def forward(self, x): + return torch.nn.functional.relu(x + self.conv2(self.conv1(x)), True) + + +class Encoder(nn.Module): + """ + Encoder module for VQ-VAE. + + Args: + spatial_dims: number of spatial spatial_dims. + in_channels: number of input channels. + out_channels: number of channels in the latent space (embedding_dim). + channels: sequence containing the number of channels at each level of the encoder. + num_res_layers: number of sequential residual layers at each level. + num_res_channels: number of channels in the residual layers at each level. + downsample_parameters: A Tuple of Tuples for defining the downsampling convolutions. Each Tuple should hold the + following information stride (int), kernel_size (int), dilation (int) and padding (int). + dropout: dropout ratio. + act: activation type and arguments. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + channels: Sequence[int], + num_res_layers: int, + num_res_channels: Sequence[int], + downsample_parameters: Sequence[Tuple[int, int, int, int]], + dropout: float, + act: tuple | str | None, + ) -> None: + super().__init__() + self.spatial_dims = spatial_dims + self.in_channels = in_channels + self.out_channels = out_channels + self.channels = channels + self.num_res_layers = num_res_layers + self.num_res_channels = num_res_channels + self.downsample_parameters = downsample_parameters + self.dropout = dropout + self.act = act + + blocks: list[nn.Module] = [] + + for i in range(len(self.channels)): + blocks.append( + Convolution( + spatial_dims=self.spatial_dims, + in_channels=self.in_channels if i == 0 else self.channels[i - 1], + out_channels=self.channels[i], + strides=self.downsample_parameters[i][0], + kernel_size=self.downsample_parameters[i][1], + adn_ordering="DA", + act=self.act, + dropout=None if i == 0 else self.dropout, + dropout_dim=1, + dilation=self.downsample_parameters[i][2], + padding=self.downsample_parameters[i][3], + ) + ) + + for _ in range(self.num_res_layers): + blocks.append( + VQVAEResidualUnit( + spatial_dims=self.spatial_dims, + in_channels=self.channels[i], + num_res_channels=self.num_res_channels[i], + act=self.act, + dropout=self.dropout, + ) + ) + + blocks.append( + Convolution( + spatial_dims=self.spatial_dims, + in_channels=self.channels[len(self.channels) - 1], + out_channels=self.out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + ) + + self.blocks = nn.ModuleList(blocks) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for block in self.blocks: + x = block(x) + return x + + +class Decoder(nn.Module): + """ + Decoder module for VQ-VAE. + + Args: + spatial_dims: number of spatial spatial_dims. + in_channels: number of channels in the latent space (embedding_dim). + out_channels: number of output channels. + channels: sequence containing the number of channels at each level of the decoder. + num_res_layers: number of sequential residual layers at each level. + num_res_channels: number of channels in the residual layers at each level. + upsample_parameters: A Tuple of Tuples for defining the upsampling convolutions. Each Tuple should hold the + following information stride (int), kernel_size (int), dilation (int), padding (int), output_padding (int). + dropout: dropout ratio. + act: activation type and arguments. + output_act: activation type and arguments for the output. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + channels: Sequence[int], + num_res_layers: int, + num_res_channels: Sequence[int], + upsample_parameters: Sequence[Tuple[int, int, int, int, int]], + dropout: float, + act: tuple | str | None, + output_act: tuple | str | None, + ) -> None: + super().__init__() + self.spatial_dims = spatial_dims + self.in_channels = in_channels + self.out_channels = out_channels + self.channels = channels + self.num_res_layers = num_res_layers + self.num_res_channels = num_res_channels + self.upsample_parameters = upsample_parameters + self.dropout = dropout + self.act = act + self.output_act = output_act + + reversed_num_channels = list(reversed(self.channels)) + + blocks: list[nn.Module] = [] + blocks.append( + Convolution( + spatial_dims=self.spatial_dims, + in_channels=self.in_channels, + out_channels=reversed_num_channels[0], + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + ) + + reversed_num_res_channels = list(reversed(self.num_res_channels)) + for i in range(len(self.channels)): + for _ in range(self.num_res_layers): + blocks.append( + VQVAEResidualUnit( + spatial_dims=self.spatial_dims, + in_channels=reversed_num_channels[i], + num_res_channels=reversed_num_res_channels[i], + act=self.act, + dropout=self.dropout, + ) + ) + + blocks.append( + Convolution( + spatial_dims=self.spatial_dims, + in_channels=reversed_num_channels[i], + out_channels=self.out_channels if i == len(self.channels) - 1 else reversed_num_channels[i + 1], + strides=self.upsample_parameters[i][0], + kernel_size=self.upsample_parameters[i][1], + adn_ordering="DA", + act=self.act, + dropout=self.dropout if i != len(self.channels) - 1 else None, + norm=None, + dilation=self.upsample_parameters[i][2], + conv_only=i == len(self.channels) - 1, + is_transposed=True, + padding=self.upsample_parameters[i][3], + output_padding=self.upsample_parameters[i][4], + ) + ) + + if self.output_act: + blocks.append(Act[self.output_act]()) + + self.blocks = nn.ModuleList(blocks) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for block in self.blocks: + x = block(x) + return x + + +class VQVAE(nn.Module): + """ + Vector-Quantised Variational Autoencoder (VQ-VAE) used in Morphology-preserving Autoregressive 3D Generative + Modelling of the Brain by Tudosiu et al. (https://arxiv.org/pdf/2209.03177.pdf) + + The original implementation can be found at + https://github.com/AmigoLab/SynthAnatomy/blob/main/src/networks/vqvae/baseline.py#L163/ + + Args: + spatial_dims: number of spatial spatial_dims. + in_channels: number of input channels. + out_channels: number of output channels. + downsample_parameters: A Tuple of Tuples for defining the downsampling convolutions. Each Tuple should hold the + following information stride (int), kernel_size (int), dilation (int) and padding (int). + upsample_parameters: A Tuple of Tuples for defining the upsampling convolutions. Each Tuple should hold the + following information stride (int), kernel_size (int), dilation (int), padding (int), output_padding (int). + num_res_layers: number of sequential residual layers at each level. + channels: number of channels at each level. + num_res_channels: number of channels in the residual layers at each level. + num_embeddings: VectorQuantization number of atomic elements in the codebook. + embedding_dim: VectorQuantization number of channels of the input and atomic elements. + commitment_cost: VectorQuantization commitment_cost. + decay: VectorQuantization decay. + epsilon: VectorQuantization epsilon. + act: activation type and arguments. + dropout: dropout ratio. + output_act: activation type and arguments for the output. + ddp_sync: whether to synchronize the codebook across processes. + use_checkpointing if True, use activation checkpointing to save memory. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + channels: Sequence[int] = (96, 96, 192), + num_res_layers: int = 3, + num_res_channels: Sequence[int] | int = (96, 96, 192), + downsample_parameters: Sequence[Tuple[int, int, int, int]] | Tuple[int, int, int, int] = ( + (2, 4, 1, 1), + (2, 4, 1, 1), + (2, 4, 1, 1), + ), + upsample_parameters: Sequence[Tuple[int, int, int, int, int]] | Tuple[int, int, int, int, int] = ( + (2, 4, 1, 1, 0), + (2, 4, 1, 1, 0), + (2, 4, 1, 1, 0), + ), + num_embeddings: int = 32, + embedding_dim: int = 64, + embedding_init: str = "normal", + commitment_cost: float = 0.25, + decay: float = 0.5, + epsilon: float = 1e-5, + dropout: float = 0.0, + act: tuple | str | None = Act.RELU, + output_act: tuple | str | None = None, + ddp_sync: bool = True, + use_checkpointing: bool = False, + ): + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.spatial_dims = spatial_dims + self.channels = channels + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + self.use_checkpointing = use_checkpointing + + if isinstance(num_res_channels, int): + num_res_channels = ensure_tuple_rep(num_res_channels, len(channels)) + + if len(num_res_channels) != len(channels): + raise ValueError( + "`num_res_channels` should be a single integer or a tuple of integers with the same length as " + "`num_channls`." + ) + if all(isinstance(values, int) for values in upsample_parameters): + upsample_parameters_tuple: Sequence = (upsample_parameters,) * len(channels) + else: + upsample_parameters_tuple = upsample_parameters + + if all(isinstance(values, int) for values in downsample_parameters): + downsample_parameters_tuple: Sequence = (downsample_parameters,) * len(channels) + else: + downsample_parameters_tuple = downsample_parameters + + if not all(all(isinstance(value, int) for value in sub_item) for sub_item in downsample_parameters_tuple): + raise ValueError("`downsample_parameters` should be a single tuple of integer or a tuple of tuples.") + + # check if downsample_parameters is a tuple of ints or a tuple of tuples of ints + if not all(all(isinstance(value, int) for value in sub_item) for sub_item in upsample_parameters_tuple): + raise ValueError("`upsample_parameters` should be a single tuple of integer or a tuple of tuples.") + + for parameter in downsample_parameters_tuple: + if len(parameter) != 4: + raise ValueError("`downsample_parameters` should be a tuple of tuples with 4 integers.") + + for parameter in upsample_parameters_tuple: + if len(parameter) != 5: + raise ValueError("`upsample_parameters` should be a tuple of tuples with 5 integers.") + + if len(downsample_parameters_tuple) != len(channels): + raise ValueError( + "`downsample_parameters` should be a tuple of tuples with the same length as `num_channels`." + ) + + if len(upsample_parameters_tuple) != len(channels): + raise ValueError( + "`upsample_parameters` should be a tuple of tuples with the same length as `num_channels`." + ) + + self.num_res_layers = num_res_layers + self.num_res_channels = num_res_channels + + self.encoder = Encoder( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=embedding_dim, + channels=channels, + num_res_layers=num_res_layers, + num_res_channels=num_res_channels, + downsample_parameters=downsample_parameters_tuple, + dropout=dropout, + act=act, + ) + + self.decoder = Decoder( + spatial_dims=spatial_dims, + in_channels=embedding_dim, + out_channels=out_channels, + channels=channels, + num_res_layers=num_res_layers, + num_res_channels=num_res_channels, + upsample_parameters=upsample_parameters_tuple, + dropout=dropout, + act=act, + output_act=output_act, + ) + + self.quantizer = VectorQuantizer( + quantizer=EMAQuantizer( + spatial_dims=spatial_dims, + num_embeddings=num_embeddings, + embedding_dim=embedding_dim, + commitment_cost=commitment_cost, + decay=decay, + epsilon=epsilon, + embedding_init=embedding_init, + ddp_sync=ddp_sync, + ) + ) + + def encode(self, images: torch.Tensor) -> torch.Tensor: + output: torch.Tensor + if self.use_checkpointing: + output = torch.utils.checkpoint.checkpoint(self.encoder, images, use_reentrant=False) + else: + output = self.encoder(images) + return output + + def quantize(self, encodings: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + x_loss, x = self.quantizer(encodings) + return x, x_loss + + def decode(self, quantizations: torch.Tensor) -> torch.Tensor: + output: torch.Tensor + + if self.use_checkpointing: + output = torch.utils.checkpoint.checkpoint(self.decoder, quantizations, use_reentrant=False) + else: + output = self.decoder(quantizations) + return output + + def index_quantize(self, images: torch.Tensor) -> torch.Tensor: + return self.quantizer.quantize(self.encode(images=images)) + + def decode_samples(self, embedding_indices: torch.Tensor) -> torch.Tensor: + return self.decode(self.quantizer.embed(embedding_indices)) + + def forward(self, images: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + quantizations, quantization_losses = self.quantize(self.encode(images)) + reconstruction = self.decode(quantizations) + + return reconstruction, quantization_losses + + def encode_stage_2_inputs(self, x: torch.Tensor) -> torch.Tensor: + z = self.encode(x) + e, _ = self.quantize(z) + return e + + def decode_stage_2_outputs(self, z: torch.Tensor) -> torch.Tensor: + e, _ = self.quantize(z) + image = self.decode(e) + return image diff --git a/monai/networks/schedulers/__init__.py b/monai/networks/schedulers/__init__.py new file mode 100644 index 0000000000..29e9020d65 --- /dev/null +++ b/monai/networks/schedulers/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from .ddim import DDIMScheduler +from .ddpm import DDPMScheduler +from .pndm import PNDMScheduler +from .scheduler import NoiseSchedules, Scheduler diff --git a/monai/networks/schedulers/ddim.py b/monai/networks/schedulers/ddim.py new file mode 100644 index 0000000000..2a0121d063 --- /dev/null +++ b/monai/networks/schedulers/ddim.py @@ -0,0 +1,294 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ========================================================================= +# Adapted from https://github.com/huggingface/diffusers +# which has the following license: +# https://github.com/huggingface/diffusers/blob/main/LICENSE +# +# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========================================================================= + +from __future__ import annotations + +import numpy as np +import torch + +from .ddpm import DDPMPredictionType +from .scheduler import Scheduler + +DDIMPredictionType = DDPMPredictionType + + +class DDIMScheduler(Scheduler): + """ + Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising + diffusion probabilistic models (DDPMs) with non-Markovian guidance. Based on: Song et al. "Denoising Diffusion + Implicit Models" https://arxiv.org/abs/2010.02502 + + Args: + num_train_timesteps: number of diffusion steps used to train the model. + schedule: member of NoiseSchedules, name of noise schedule function in component store + clip_sample: option to clip predicted sample between -1 and 1 for numerical stability. + set_alpha_to_one: each diffusion step uses the value of alphas product at that step and at the previous one. + For the final step there is no previous alpha. When this option is `True` the previous alpha product is + fixed to `1`, otherwise it uses the value of alpha at step 0. + steps_offset: an offset added to the inference steps. You can use a combination of `steps_offset=1` and + `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in + stable diffusion. + prediction_type: member of DDPMPredictionType + clip_sample_min: minimum clipping value when clip_sample equals True + clip_sample_max: maximum clipping value when clip_sample equals True + schedule_args: arguments to pass to the schedule function + + """ + + def __init__( + self, + num_train_timesteps: int = 1000, + schedule: str = "linear_beta", + clip_sample: bool = True, + set_alpha_to_one: bool = True, + steps_offset: int = 0, + prediction_type: str = DDIMPredictionType.EPSILON, + clip_sample_min: float = -1.0, + clip_sample_max: float = 1.0, + **schedule_args, + ) -> None: + super().__init__(num_train_timesteps, schedule, **schedule_args) + + if prediction_type not in DDIMPredictionType.__members__.values(): + raise ValueError("Argument `prediction_type` must be a member of DDIMPredictionType") + + self.prediction_type = prediction_type + + # At every step in ddim, we are looking into the previous alphas_cumprod + # For the final step, there is no previous alphas_cumprod because we are already at 0 + # `set_alpha_to_one` decides whether we set this parameter simply to one or + # whether we use the final alpha of the "non-previous" one. + self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + self.timesteps = torch.from_numpy(np.arange(0, self.num_train_timesteps)[::-1].astype(np.int64)) + + self.clip_sample = clip_sample + self.clip_sample_values = [clip_sample_min, clip_sample_max] + self.steps_offset = steps_offset + + # default the number of inference timesteps to the number of train steps + self.num_inference_steps: int + self.set_timesteps(self.num_train_timesteps) + + def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None) -> None: + """ + Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + num_inference_steps: number of diffusion steps used when generating samples with a pre-trained model. + device: target device to put the data. + """ + if num_inference_steps > self.num_train_timesteps: + raise ValueError( + f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.num_train_timesteps`:" + f" {self.num_train_timesteps} as the unet model trained with this scheduler can only handle" + f" maximal {self.num_train_timesteps} timesteps." + ) + + self.num_inference_steps = num_inference_steps + step_ratio = self.num_train_timesteps // self.num_inference_steps + if self.steps_offset >= step_ratio: + raise ValueError( + f"`steps_offset`: {self.steps_offset} cannot be greater than or equal to " + f"`num_train_timesteps // num_inference_steps : {step_ratio}` as this will cause timesteps to exceed" + f" the max train timestep." + ) + + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) + self.timesteps = torch.from_numpy(timesteps).to(device) + self.timesteps += self.steps_offset + + def _get_variance(self, timestep: int, prev_timestep: int) -> torch.Tensor: + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + variance: torch.Tensor = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev) + + return variance + + def step( + self, + model_output: torch.Tensor, + timestep: int, + sample: torch.Tensor, + eta: float = 0.0, + generator: torch.Generator | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output: direct output from learned diffusion model. + timestep: current discrete timestep in the diffusion chain. + sample: current instance of sample being created by diffusion process. + eta: weight of noise for added noise in diffusion step. + generator: random number generator. + + Returns: + pred_prev_sample: Predicted previous sample + pred_original_sample: Predicted original sample + """ + # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf + # Ideally, read DDIM paper in-detail understanding + + # Notation ( -> + # - model_output -> e_theta(x_t, t) + # - pred_original_sample -> f_theta(x_t, t) or x_0 + # - std_dev_t -> sigma_t + # - eta -> η + # - pred_sample_direction -> "direction pointing to x_t" + # - pred_prev_sample -> "x_t-1" + + # 1. get previous step value (=t-1) + prev_timestep = timestep - self.num_train_timesteps // self.num_inference_steps + + # 2. compute alphas, betas + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + + beta_prod_t = 1 - alpha_prod_t + + # predefinitions satisfy pylint/mypy, these values won't be ultimately used + pred_original_sample = sample + pred_epsilon = model_output + + # 3. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + if self.prediction_type == DDIMPredictionType.EPSILON: + pred_original_sample = (sample - (beta_prod_t**0.5) * model_output) / (alpha_prod_t**0.5) + pred_epsilon = model_output + elif self.prediction_type == DDIMPredictionType.SAMPLE: + pred_original_sample = model_output + pred_epsilon = (sample - (alpha_prod_t**0.5) * pred_original_sample) / (beta_prod_t**0.5) + elif self.prediction_type == DDIMPredictionType.V_PREDICTION: + pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output + pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample + + # 4. Clip "predicted x_0" + if self.clip_sample: + pred_original_sample = torch.clamp( + pred_original_sample, self.clip_sample_values[0], self.clip_sample_values[1] + ) + + # 5. compute variance: "sigma_t(η)" -> see formula (16) + # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) + variance = self._get_variance(timestep, prev_timestep) + std_dev_t = eta * variance**0.5 + + # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** 0.5 * pred_epsilon + + # 7. compute x_t-1 without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + pred_prev_sample = alpha_prod_t_prev**0.5 * pred_original_sample + pred_sample_direction + + if eta > 0: + # randn_like does not support generator https://github.com/pytorch/pytorch/issues/27072 + device: torch.device = torch.device(model_output.device if torch.is_tensor(model_output) else "cpu") + noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator).to(device) + variance = self._get_variance(timestep, prev_timestep) ** 0.5 * eta * noise + + pred_prev_sample = pred_prev_sample + variance + + return pred_prev_sample, pred_original_sample + + def reversed_step( + self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Predict the sample at the next timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output: direct output from learned diffusion model. + timestep: current discrete timestep in the diffusion chain. + sample: current instance of sample being created by diffusion process. + + Returns: + pred_prev_sample: Predicted previous sample + pred_original_sample: Predicted original sample + """ + # See Appendix F at https://arxiv.org/pdf/2105.05233.pdf, or Equation (6) in https://arxiv.org/pdf/2203.04306.pdf + + # Notation ( -> + # - model_output -> e_theta(x_t, t) + # - pred_original_sample -> f_theta(x_t, t) or x_0 + # - std_dev_t -> sigma_t + # - eta -> η + # - pred_sample_direction -> "direction pointing to x_t" + # - pred_post_sample -> "x_t+1" + + # 1. get previous step value (=t+1) + prev_timestep = timestep + self.num_train_timesteps // self.num_inference_steps + + # 2. compute alphas, betas at timestep t+1 + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + + beta_prod_t = 1 - alpha_prod_t + + # predefinitions satisfy pylint/mypy, these values won't be ultimately used + pred_original_sample = sample + pred_epsilon = model_output + + # 3. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + + if self.prediction_type == DDIMPredictionType.EPSILON: + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + pred_epsilon = model_output + elif self.prediction_type == DDIMPredictionType.SAMPLE: + pred_original_sample = model_output + pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) + elif self.prediction_type == DDIMPredictionType.V_PREDICTION: + pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output + pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample + + # 4. Clip "predicted x_0" + if self.clip_sample: + pred_original_sample = torch.clamp( + pred_original_sample, self.clip_sample_values[0], self.clip_sample_values[1] + ) + + # 5. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + pred_sample_direction = (1 - alpha_prod_t_prev) ** (0.5) * pred_epsilon + + # 6. compute x_t+1 without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + pred_post_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction + + return pred_post_sample, pred_original_sample diff --git a/monai/networks/schedulers/ddpm.py b/monai/networks/schedulers/ddpm.py new file mode 100644 index 0000000000..93ad833031 --- /dev/null +++ b/monai/networks/schedulers/ddpm.py @@ -0,0 +1,250 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ========================================================================= +# Adapted from https://github.com/huggingface/diffusers +# which has the following license: +# https://github.com/huggingface/diffusers/blob/main/LICENSE +# +# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========================================================================= + +from __future__ import annotations + +import numpy as np +import torch + +from monai.utils import StrEnum + +from .scheduler import Scheduler + + +class DDPMVarianceType(StrEnum): + """ + Valid names for DDPM Scheduler's `variance_type` argument. Options to clip the variance used when adding noise + to the denoised sample. + """ + + FIXED_SMALL = "fixed_small" + FIXED_LARGE = "fixed_large" + LEARNED = "learned" + LEARNED_RANGE = "learned_range" + + +class DDPMPredictionType(StrEnum): + """ + Set of valid prediction type names for the DDPM scheduler's `prediction_type` argument. + + epsilon: predicting the noise of the diffusion process + sample: directly predicting the noisy sample + v_prediction: velocity prediction, see section 2.4 https://imagen.research.google/video/paper.pdf + """ + + EPSILON = "epsilon" + SAMPLE = "sample" + V_PREDICTION = "v_prediction" + + +class DDPMScheduler(Scheduler): + """ + Denoising diffusion probabilistic models (DDPMs) explores the connections between denoising score matching and + Langevin dynamics sampling. Based on: Ho et al., "Denoising Diffusion Probabilistic Models" + https://arxiv.org/abs/2006.11239 + + Args: + num_train_timesteps: number of diffusion steps used to train the model. + schedule: member of NoiseSchedules, name of noise schedule function in component store + variance_type: member of DDPMVarianceType + clip_sample: option to clip predicted sample between -1 and 1 for numerical stability. + prediction_type: member of DDPMPredictionType + clip_sample_min: minimum clipping value when clip_sample equals True + clip_sample_max: maximum clipping value when clip_sample equals True + schedule_args: arguments to pass to the schedule function + """ + + def __init__( + self, + num_train_timesteps: int = 1000, + schedule: str = "linear_beta", + variance_type: str = DDPMVarianceType.FIXED_SMALL, + clip_sample: bool = True, + prediction_type: str = DDPMPredictionType.EPSILON, + clip_sample_min: float = -1.0, + clip_sample_max: float = 1.0, + **schedule_args, + ) -> None: + super().__init__(num_train_timesteps, schedule, **schedule_args) + + if variance_type not in DDPMVarianceType.__members__.values(): + raise ValueError("Argument `variance_type` must be a member of `DDPMVarianceType`") + + if prediction_type not in DDPMPredictionType.__members__.values(): + raise ValueError("Argument `prediction_type` must be a member of `DDPMPredictionType`") + + self.clip_sample = clip_sample + self.clip_sample_values = [clip_sample_min, clip_sample_max] + self.variance_type = variance_type + self.prediction_type = prediction_type + + def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None) -> None: + """ + Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + num_inference_steps: number of diffusion steps used when generating samples with a pre-trained model. + device: target device to put the data. + """ + if num_inference_steps > self.num_train_timesteps: + raise ValueError( + f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.num_train_timesteps`:" + f" {self.num_train_timesteps} as the unet model trained with this scheduler can only handle" + f" maximal {self.num_train_timesteps} timesteps." + ) + + self.num_inference_steps = num_inference_steps + step_ratio = self.num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].astype(np.int64) + self.timesteps = torch.from_numpy(timesteps).to(device) + + def _get_mean(self, timestep: int, x_0: torch.Tensor, x_t: torch.Tensor) -> torch.Tensor: + """ + Compute the mean of the posterior at timestep t. + + Args: + timestep: current timestep. + x0: the noise-free input. + x_t: the input noised to timestep t. + + Returns: + Returns the mean + """ + # these attributes are used for calculating the posterior, q(x_{t-1}|x_t,x_0), + # (see formula (5-7) from https://arxiv.org/pdf/2006.11239.pdf) + alpha_t = self.alphas[timestep] + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[timestep - 1] if timestep > 0 else self.one + + x_0_coefficient = alpha_prod_t_prev.sqrt() * self.betas[timestep] / (1 - alpha_prod_t) + x_t_coefficient = alpha_t.sqrt() * (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) + + mean: torch.Tensor = x_0_coefficient * x_0 + x_t_coefficient * x_t + + return mean + + def _get_variance(self, timestep: int, predicted_variance: torch.Tensor | None = None) -> torch.Tensor: + """ + Compute the variance of the posterior at timestep t. + + Args: + timestep: current timestep. + predicted_variance: variance predicted by the model. + + Returns: + Returns the variance + """ + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[timestep - 1] if timestep > 0 else self.one + + # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf) + # and sample from it to get previous sample + # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample + variance: torch.Tensor = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.betas[timestep] + # hacks - were probably added for training stability + if self.variance_type == DDPMVarianceType.FIXED_SMALL: + variance = torch.clamp(variance, min=1e-20) + elif self.variance_type == DDPMVarianceType.FIXED_LARGE: + variance = self.betas[timestep] + elif self.variance_type == DDPMVarianceType.LEARNED and predicted_variance is not None: + return predicted_variance + elif self.variance_type == DDPMVarianceType.LEARNED_RANGE and predicted_variance is not None: + min_log = variance + max_log = self.betas[timestep] + frac = (predicted_variance + 1) / 2 + variance = frac * max_log + (1 - frac) * min_log + + return variance + + def step( + self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor, generator: torch.Generator | None = None + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output: direct output from learned diffusion model. + timestep: current discrete timestep in the diffusion chain. + sample: current instance of sample being created by diffusion process. + generator: random number generator. + + Returns: + pred_prev_sample: Predicted previous sample + """ + if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]: + model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1) + else: + predicted_variance = None + + # 1. compute alphas, betas + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[timestep - 1] if timestep > 0 else self.one + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + # 2. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf + if self.prediction_type == DDPMPredictionType.EPSILON: + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + elif self.prediction_type == DDPMPredictionType.SAMPLE: + pred_original_sample = model_output + elif self.prediction_type == DDPMPredictionType.V_PREDICTION: + pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output + + # 3. Clip "predicted x_0" + if self.clip_sample: + pred_original_sample = torch.clamp( + pred_original_sample, self.clip_sample_values[0], self.clip_sample_values[1] + ) + + # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t + # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf + pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * self.betas[timestep]) / beta_prod_t + current_sample_coeff = self.alphas[timestep] ** (0.5) * beta_prod_t_prev / beta_prod_t + + # 5. Compute predicted previous sample µ_t + # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf + pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample + + # 6. Add noise + variance = 0 + if timestep > 0: + noise = torch.randn( + model_output.size(), dtype=model_output.dtype, layout=model_output.layout, generator=generator + ).to(model_output.device) + variance = (self._get_variance(timestep, predicted_variance=predicted_variance) ** 0.5) * noise + + pred_prev_sample = pred_prev_sample + variance + + return pred_prev_sample, pred_original_sample diff --git a/monai/networks/schedulers/pndm.py b/monai/networks/schedulers/pndm.py new file mode 100644 index 0000000000..c0728bbdff --- /dev/null +++ b/monai/networks/schedulers/pndm.py @@ -0,0 +1,316 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ========================================================================= +# Adapted from https://github.com/huggingface/diffusers +# which has the following license: +# https://github.com/huggingface/diffusers/blob/main/LICENSE +# +# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========================================================================= + +from __future__ import annotations + +from typing import Any + +import numpy as np +import torch + +from monai.utils import StrEnum + +from .scheduler import Scheduler + + +class PNDMPredictionType(StrEnum): + """ + Set of valid prediction type names for the PNDM scheduler's `prediction_type` argument. + + epsilon: predicting the noise of the diffusion process + v_prediction: velocity prediction, see section 2.4 https://imagen.research.google/video/paper.pdf + """ + + EPSILON = "epsilon" + V_PREDICTION = "v_prediction" + + +class PNDMScheduler(Scheduler): + """ + Pseudo numerical methods for diffusion models (PNDM) proposes using more advanced ODE integration techniques, + namely Runge-Kutta method and a linear multi-step method. Based on: Liu et al., + "Pseudo Numerical Methods for Diffusion Models on Manifolds" https://arxiv.org/abs/2202.09778 + + Args: + num_train_timesteps: number of diffusion steps used to train the model. + schedule: member of NoiseSchedules, name of noise schedule function in component store + skip_prk_steps: + allows the scheduler to skip the Runge-Kutta steps that are defined in the original paper as being required + before plms step. + set_alpha_to_one: + each diffusion step uses the value of alphas product at that step and at the previous one. For the final + step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`, + otherwise it uses the value of alpha at step 0. + prediction_type: member of DDPMPredictionType + steps_offset: + an offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in + stable diffusion. + schedule_args: arguments to pass to the schedule function + """ + + def __init__( + self, + num_train_timesteps: int = 1000, + schedule: str = "linear_beta", + skip_prk_steps: bool = False, + set_alpha_to_one: bool = False, + prediction_type: str = PNDMPredictionType.EPSILON, + steps_offset: int = 0, + **schedule_args, + ) -> None: + super().__init__(num_train_timesteps, schedule, **schedule_args) + + if prediction_type not in PNDMPredictionType.__members__.values(): + raise ValueError("Argument `prediction_type` must be a member of PNDMPredictionType") + + self.prediction_type = prediction_type + + self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + # For now we only support F-PNDM, i.e. the runge-kutta method + # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf + # mainly at formula (9), (12), (13) and the Algorithm 2. + self.pndm_order = 4 + + self.skip_prk_steps = skip_prk_steps + self.steps_offset = steps_offset + + # running values + self.cur_model_output = torch.Tensor() + self.counter = 0 + self.cur_sample = torch.Tensor() + self.ets: list = [] + + # default the number of inference timesteps to the number of train steps + self.set_timesteps(num_train_timesteps) + + def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None) -> None: + """ + Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + num_inference_steps: number of diffusion steps used when generating samples with a pre-trained model. + device: target device to put the data. + """ + if num_inference_steps > self.num_train_timesteps: + raise ValueError( + f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.num_train_timesteps`:" + f" {self.num_train_timesteps} as the unet model trained with this scheduler can only handle" + f" maximal {self.num_train_timesteps} timesteps." + ) + + self.num_inference_steps = num_inference_steps + step_ratio = self.num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + self._timesteps = (np.arange(0, num_inference_steps) * step_ratio).round().astype(np.int64) + self._timesteps += self.steps_offset + + if self.skip_prk_steps: + # for some models like stable diffusion the prk steps can/should be skipped to + # produce better results. When using PNDM with `self.skip_prk_steps` the implementation + # is based on crowsonkb's PLMS sampler implementation: https://github.com/CompVis/latent-diffusion/pull/51 + self.prk_timesteps = np.array([]) + self.plms_timesteps = self._timesteps[::-1] + + else: + prk_timesteps = np.array(self._timesteps[-self.pndm_order :]).repeat(2) + np.tile( + np.array([0, self.num_train_timesteps // num_inference_steps // 2]), self.pndm_order + ) + self.prk_timesteps = (prk_timesteps[:-1].repeat(2)[1:-1])[::-1].copy() + self.plms_timesteps = self._timesteps[:-3][ + ::-1 + ].copy() # we copy to avoid having negative strides which are not supported by torch.from_numpy + + timesteps = np.concatenate([self.prk_timesteps, self.plms_timesteps]).astype(np.int64) + self.timesteps = torch.from_numpy(timesteps).to(device) + # update num_inference_steps - necessary if we use prk steps + self.num_inference_steps = len(self.timesteps) + + self.ets = [] + self.counter = 0 + + def step(self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor) -> tuple[torch.Tensor, Any]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + This function calls `step_prk()` or `step_plms()` depending on the internal variable `counter`. + + Args: + model_output: direct output from learned diffusion model. + timestep: current discrete timestep in the diffusion chain. + sample: current instance of sample being created by diffusion process. + Returns: + pred_prev_sample: Predicted previous sample + """ + # return a tuple for consistency with samplers that return (previous pred, original sample pred) + + if self.counter < len(self.prk_timesteps) and not self.skip_prk_steps: + return self.step_prk(model_output=model_output, timestep=timestep, sample=sample), None + else: + return self.step_plms(model_output=model_output, timestep=timestep, sample=sample), None + + def step_prk(self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor) -> torch.Tensor: + """ + Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the + solution to the differential equation. + + Args: + model_output: direct output from learned diffusion model. + timestep: current discrete timestep in the diffusion chain. + sample: current instance of sample being created by diffusion process. + + Returns: + pred_prev_sample: Predicted previous sample + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + diff_to_prev = 0 if self.counter % 2 else self.num_train_timesteps // self.num_inference_steps // 2 + prev_timestep = timestep - diff_to_prev + timestep = self.prk_timesteps[self.counter // 4 * 4] + + if self.counter % 4 == 0: + self.cur_model_output = 1 / 6 * model_output + self.ets.append(model_output) + self.cur_sample = sample + elif (self.counter - 1) % 4 == 0: + self.cur_model_output += 1 / 3 * model_output + elif (self.counter - 2) % 4 == 0: + self.cur_model_output += 1 / 3 * model_output + elif (self.counter - 3) % 4 == 0: + model_output = self.cur_model_output + 1 / 6 * model_output + self.cur_model_output = torch.Tensor() + + # cur_sample should not be an empty torch.Tensor() + cur_sample = self.cur_sample if self.cur_sample.numel() != 0 else sample + + prev_sample: torch.Tensor = self._get_prev_sample(cur_sample, timestep, prev_timestep, model_output) + self.counter += 1 + + return prev_sample + + def step_plms(self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor) -> Any: + """ + Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple + times to approximate the solution. + + Args: + model_output: direct output from learned diffusion model. + timestep: current discrete timestep in the diffusion chain. + sample: current instance of sample being created by diffusion process. + + Returns: + pred_prev_sample: Predicted previous sample + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + if not self.skip_prk_steps and len(self.ets) < 3: + raise ValueError( + f"{self.__class__} can only be run AFTER scheduler has been run " + "in 'prk' mode for at least 12 iterations " + ) + + prev_timestep = timestep - self.num_train_timesteps // self.num_inference_steps + + if self.counter != 1: + self.ets = self.ets[-3:] + self.ets.append(model_output) + else: + prev_timestep = timestep + timestep = timestep + self.num_train_timesteps // self.num_inference_steps + + if len(self.ets) == 1 and self.counter == 0: + model_output = model_output + self.cur_sample = sample + elif len(self.ets) == 1 and self.counter == 1: + model_output = (model_output + self.ets[-1]) / 2 + sample = self.cur_sample + self.cur_sample = torch.Tensor() + elif len(self.ets) == 2: + model_output = (3 * self.ets[-1] - self.ets[-2]) / 2 + elif len(self.ets) == 3: + model_output = (23 * self.ets[-1] - 16 * self.ets[-2] + 5 * self.ets[-3]) / 12 + else: + model_output = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] + 37 * self.ets[-3] - 9 * self.ets[-4]) + + prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output) + self.counter += 1 + + return prev_sample + + def _get_prev_sample(self, sample: torch.Tensor, timestep: int, prev_timestep: int, model_output: torch.Tensor): + # See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf + # this function computes x_(t−δ) using the formula of (9) + # Note that x_t needs to be added to both sides of the equation + + # Notation ( -> + # alpha_prod_t -> α_t + # alpha_prod_t_prev -> α_(t−δ) + # beta_prod_t -> (1 - α_t) + # beta_prod_t_prev -> (1 - α_(t−δ)) + # sample -> x_t + # model_output -> e_θ(x_t, t) + # prev_sample -> x_(t−δ) + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + if self.prediction_type == PNDMPredictionType.V_PREDICTION: + model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample + + # corresponds to (α_(t−δ) - α_t) divided by + # denominator of x_t in formula (9) and plus 1 + # Note: (α_(t−δ) - α_t) / (sqrt(α_t) * (sqrt(α_(t−δ)) + sqr(α_t))) = + # sqrt(α_(t−δ)) / sqrt(α_t)) + sample_coeff = (alpha_prod_t_prev / alpha_prod_t) ** (0.5) + + # corresponds to denominator of e_θ(x_t, t) in formula (9) + model_output_denom_coeff = alpha_prod_t * beta_prod_t_prev ** (0.5) + ( + alpha_prod_t * beta_prod_t * alpha_prod_t_prev + ) ** (0.5) + + # full formula (9) + prev_sample = ( + sample_coeff * sample - (alpha_prod_t_prev - alpha_prod_t) * model_output / model_output_denom_coeff + ) + + return prev_sample diff --git a/monai/networks/schedulers/scheduler.py b/monai/networks/schedulers/scheduler.py new file mode 100644 index 0000000000..acdccc60de --- /dev/null +++ b/monai/networks/schedulers/scheduler.py @@ -0,0 +1,205 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ========================================================================= +# Adapted from https://github.com/huggingface/diffusers +# which has the following license: +# https://github.com/huggingface/diffusers/blob/main/LICENSE +# +# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========================================================================= + + +from __future__ import annotations + +import torch +import torch.nn as nn + +from monai.utils import ComponentStore, unsqueeze_right + +NoiseSchedules = ComponentStore("NoiseSchedules", "Functions to generate noise schedules") + + +@NoiseSchedules.add_def("linear_beta", "Linear beta schedule") +def _linear_beta(num_train_timesteps: int, beta_start: float = 1e-4, beta_end: float = 2e-2): + """ + Linear beta noise schedule function. + + Args: + num_train_timesteps: number of timesteps + beta_start: start of beta range, default 1e-4 + beta_end: end of beta range, default 2e-2 + + Returns: + betas: beta schedule tensor + """ + return torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + + +@NoiseSchedules.add_def("scaled_linear_beta", "Scaled linear beta schedule") +def _scaled_linear_beta(num_train_timesteps: int, beta_start: float = 1e-4, beta_end: float = 2e-2): + """ + Scaled linear beta noise schedule function. + + Args: + num_train_timesteps: number of timesteps + beta_start: start of beta range, default 1e-4 + beta_end: end of beta range, default 2e-2 + + Returns: + betas: beta schedule tensor + """ + return torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + + +@NoiseSchedules.add_def("sigmoid_beta", "Sigmoid beta schedule") +def _sigmoid_beta(num_train_timesteps: int, beta_start: float = 1e-4, beta_end: float = 2e-2, sig_range: float = 6): + """ + Sigmoid beta noise schedule function. + + Args: + num_train_timesteps: number of timesteps + beta_start: start of beta range, default 1e-4 + beta_end: end of beta range, default 2e-2 + sig_range: pos/neg range of sigmoid input, default 6 + + Returns: + betas: beta schedule tensor + """ + betas = torch.linspace(-sig_range, sig_range, num_train_timesteps) + return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start + + +@NoiseSchedules.add_def("cosine", "Cosine schedule") +def _cosine_beta(num_train_timesteps: int, s: float = 8e-3): + """ + Cosine noise schedule, see https://arxiv.org/abs/2102.09672 + + Args: + num_train_timesteps: number of timesteps + s: smoothing factor, default 8e-3 (see referenced paper) + + Returns: + (betas, alphas, alpha_cumprod) values + """ + x = torch.linspace(0, num_train_timesteps, num_train_timesteps + 1) + alphas_cumprod = torch.cos(((x / num_train_timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2 + alphas_cumprod /= alphas_cumprod[0].item() + alphas = torch.clip(alphas_cumprod[1:] / alphas_cumprod[:-1], 0.0001, 0.9999) + betas = 1.0 - alphas + return betas, alphas, alphas_cumprod[:-1] + + +class Scheduler(nn.Module): + """ + Base class for other schedulers based on a noise schedule function. + + This class is meant as the base for other schedulers which implement their own way of sampling or stepping. Here + the class defines beta, alpha, and alpha_cumprod values from a noise schedule function named with `schedule`, + which is the name of a component in NoiseSchedules. These components must all be callables which return either + the beta schedule alone or a triple containing (betas, alphas, alphas_cumprod) values. New schedule functions + can be provided by using the NoiseSchedules.add_def, for example: + + .. code-block:: python + + from monai.networks.schedulers import NoiseSchedules, DDPMScheduler + + @NoiseSchedules.add_def("my_beta_schedule", "Some description of your function") + def _beta_function(num_train_timesteps, beta_start=1e-4, beta_end=2e-2): + return torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + + scheduler = DDPMScheduler(num_train_timesteps=1000, schedule="my_beta_schedule") + + All such functions should have an initial positional integer argument `num_train_timesteps` stating the number of + timesteps the schedule is for, otherwise any other arguments can be given which will be passed by keyword through + the constructor's `schedule_args` value. To see what noise functions are available, print the object NoiseSchedules + to get a listing of stored objects with their docstring descriptions. + + Note: in previous versions of the schedulers the argument `schedule_beta` was used to state the beta schedule + type, this now replaced with `schedule` and most names used with the previous argument now have "_beta" appended + to them, eg. 'schedule_beta="linear"' -> 'schedule="linear_beta"'. The `beta_start` and `beta_end` arguments are + still used for some schedules but these are provided as keyword arguments now. + + Args: + num_train_timesteps: number of diffusion steps used to train the model. + schedule: member of NoiseSchedules, + a named function returning the beta tensor or (betas, alphas, alphas_cumprod) triple + schedule_args: arguments to pass to the schedule function + """ + + def __init__(self, num_train_timesteps: int = 1000, schedule: str = "linear_beta", **schedule_args) -> None: + super().__init__() + schedule_args["num_train_timesteps"] = num_train_timesteps + noise_sched = NoiseSchedules[schedule](**schedule_args) + + # set betas, alphas, alphas_cumprod based off return value from noise function + if isinstance(noise_sched, tuple): + self.betas, self.alphas, self.alphas_cumprod = noise_sched + else: + self.betas = noise_sched + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + + self.num_train_timesteps = num_train_timesteps + self.one = torch.tensor(1.0) + + # settable values + self.num_inference_steps: int | None = None + self.timesteps = torch.arange(num_train_timesteps - 1, -1, -1) + + def add_noise(self, original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor: + """ + Add noise to the original samples. + + Args: + original_samples: original samples + noise: noise to add to samples + timesteps: timesteps tensor indicating the timestep to be computed for each sample. + + Returns: + noisy_samples: sample with added noise + """ + # Make sure alphas_cumprod and timestep have same device and dtype as original_samples + self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) + timesteps = timesteps.to(original_samples.device) + + sqrt_alpha_cumprod: torch.Tensor = unsqueeze_right(self.alphas_cumprod[timesteps] ** 0.5, original_samples.ndim) + sqrt_one_minus_alpha_prod: torch.Tensor = unsqueeze_right( + (1 - self.alphas_cumprod[timesteps]) ** 0.5, original_samples.ndim + ) + + noisy_samples = sqrt_alpha_cumprod * original_samples + sqrt_one_minus_alpha_prod * noise + return noisy_samples + + def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor: + # Make sure alphas_cumprod and timestep have same device and dtype as sample + self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype) + timesteps = timesteps.to(sample.device) + + sqrt_alpha_prod: torch.Tensor = unsqueeze_right(self.alphas_cumprod[timesteps] ** 0.5, sample.ndim) + sqrt_one_minus_alpha_prod: torch.Tensor = unsqueeze_right( + (1 - self.alphas_cumprod[timesteps]) ** 0.5, sample.ndim + ) + + velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample + return velocity diff --git a/monai/networks/utils.py b/monai/networks/utils.py index 152911f443..6a97434215 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -42,6 +42,7 @@ "predict_segmentation", "normalize_transform", "to_norm_affine", + "CastTempType", "normal_init", "icnr_init", "pixelshuffle", @@ -1167,3 +1168,24 @@ def freeze_layers(model: nn.Module, freeze_vars=None, exclude_vars=None): warnings.warn(f"The exclude_vars includes {param}, but requires_grad is False, change it to True.") logger.info(f"{len(frozen_keys)} of {len(src_dict)} variables frozen.") + + +class CastTempType(nn.Module): + """ + Cast the input tensor to a temporary type before applying the submodule, and then cast it back to the initial type. + """ + + def __init__(self, initial_type, temporary_type, submodule): + super().__init__() + self.initial_type = initial_type + self.temporary_type = temporary_type + self.submodule = submodule + + def forward(self, x): + dtype = x.dtype + if dtype == self.initial_type: + x = x.to(self.temporary_type) + x = self.submodule(x) + if dtype == self.initial_type: + x = x.to(self.initial_type) + return x diff --git a/monai/transforms/regularization/array.py b/monai/transforms/regularization/array.py index a7436bda84..4bf6cff649 100644 --- a/monai/transforms/regularization/array.py +++ b/monai/transforms/regularization/array.py @@ -87,12 +87,14 @@ def apply(self, data: torch.Tensor): def __call__(self, data: torch.Tensor, labels: torch.Tensor | None = None, randomize=True): data_t = convert_to_tensor(data, track_meta=get_track_meta()) + labels_t = data_t # will not stay this value, needed to satisfy pylint/mypy if labels is not None: labels_t = convert_to_tensor(labels, track_meta=get_track_meta()) if randomize: self.randomize() if labels is None: return convert_to_dst_type(self.apply(data_t), dst=data)[0] + return ( convert_to_dst_type(self.apply(data_t), dst=data)[0], convert_to_dst_type(self.apply(labels_t), dst=labels)[0], @@ -149,11 +151,13 @@ def apply_on_labels(self, labels: torch.Tensor): def __call__(self, data: torch.Tensor, labels: torch.Tensor | None = None, randomize=True): data_t = convert_to_tensor(data, track_meta=get_track_meta()) + augmented_label = None if labels is not None: labels_t = convert_to_tensor(labels, track_meta=get_track_meta()) if randomize: self.randomize(data) augmented = convert_to_dst_type(self.apply(data_t), dst=data)[0] + if labels is not None: augmented_label = convert_to_dst_type(self.apply(labels_t), dst=labels)[0] return (augmented, augmented_label) if labels is not None else augmented diff --git a/monai/transforms/utils_create_transform_ims.py b/monai/transforms/utils_create_transform_ims.py index 4b5990abd3..a29fd4dbf9 100644 --- a/monai/transforms/utils_create_transform_ims.py +++ b/monai/transforms/utils_create_transform_ims.py @@ -269,11 +269,9 @@ def update_docstring(code_path, transform_name): def pre_process_data(data, ndim, is_map, is_post): - """If transform requires 2D data, then convert to 2D""" + """If transform requires 2D data, then convert to 2D by selecting the middle of the last dimension.""" if ndim == 2: - for k in keys: - data[k] = data[k][..., data[k].shape[-1] // 2] - + data = {k: v[..., v.shape[-1] // 2] for k, v in data.items()} if is_map: return data return data[CommonKeys.LABEL] if is_post else data[CommonKeys.IMAGE] diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index 2c32eb2cf4..03fa1ceed1 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -126,6 +126,7 @@ version_leq, ) from .nvtx import Range +from .ordering import Ordering from .profiling import ( PerfContext, ProfileHandler, diff --git a/monai/utils/misc.py b/monai/utils/misc.py index e8a46ecc61..40370ca2c6 100644 --- a/monai/utils/misc.py +++ b/monai/utils/misc.py @@ -118,6 +118,7 @@ def star_zip_with(op, *vals): T = TypeVar("T") +NT = TypeVar("NT", np.ndarray, torch.Tensor) @overload @@ -907,11 +908,11 @@ def is_sqrt(num: Sequence[int] | int) -> bool: return ensure_tuple(ret) == num -def unsqueeze_right(arr: NdarrayOrTensor, ndim: int) -> NdarrayOrTensor: +def unsqueeze_right(arr: NT, ndim: int) -> NT: """Append 1-sized dimensions to `arr` to create a result with `ndim` dimensions.""" return arr[(...,) + (None,) * (ndim - arr.ndim)] -def unsqueeze_left(arr: NdarrayOrTensor, ndim: int) -> NdarrayOrTensor: +def unsqueeze_left(arr: NT, ndim: int) -> NT: """Prepend 1-sized dimensions to `arr` to create a result with `ndim` dimensions.""" return arr[(None,) * (ndim - arr.ndim)] diff --git a/monai/utils/ordering.py b/monai/utils/ordering.py new file mode 100644 index 0000000000..1be61f98ab --- /dev/null +++ b/monai/utils/ordering.py @@ -0,0 +1,207 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import numpy as np + +from monai.utils.enums import OrderingTransformations, OrderingType + + +class Ordering: + """ + Ordering class that projects a 2D or 3D image into a 1D sequence. It also allows the image to be transformed with + one of the following transformations: + Reflection (see np.flip for more details). + Transposition (see np.transpose for more details). + 90-degree rotation (see np.rot90 for more details). + + The transformations are applied in the order specified by the transformation_order parameter. + + Args: + ordering_type: The ordering type. One of the following: + - 'raster_scan': The image is projected into a 1D sequence by scanning the image from left to right and from + top to bottom. Also called a row major ordering. + - 's_curve': The image is projected into a 1D sequence by scanning the image in a circular snake like + pattern from top left towards right gowing in a spiral towards the center. + - random': The image is projected into a 1D sequence by randomly shuffling the image. + spatial_dims: The number of spatial dimensions of the image. + dimensions: The dimensions of the image. + reflected_spatial_dims: A tuple of booleans indicating whether to reflect the image along each spatial dimension. + transpositions_axes: A tuple of tuples indicating the axes to transpose the image along. + rot90_axes: A tuple of tuples indicating the axes to rotate the image along. + transformation_order: The order in which to apply the transformations. + """ + + def __init__( + self, + ordering_type: str, + spatial_dims: int, + dimensions: tuple[int, int, int] | tuple[int, int, int, int], + reflected_spatial_dims: tuple[bool, bool] | None = None, + transpositions_axes: tuple[tuple[int, int], ...] | tuple[tuple[int, int, int], ...] | None = None, + rot90_axes: tuple[tuple[int, int], ...] | None = None, + transformation_order: tuple[str, ...] = ( + OrderingTransformations.TRANSPOSE.value, + OrderingTransformations.ROTATE_90.value, + OrderingTransformations.REFLECT.value, + ), + ) -> None: + super().__init__() + self.ordering_type = ordering_type + + if self.ordering_type not in list(OrderingType): + raise ValueError( + f"ordering_type must be one of the following {list(OrderingType)}, but got {self.ordering_type}." + ) + + self.spatial_dims = spatial_dims + self.dimensions = dimensions + + if len(dimensions) != self.spatial_dims + 1: + raise ValueError(f"dimensions must be of length {self.spatial_dims + 1}, but got {len(dimensions)}.") + + self.reflected_spatial_dims = reflected_spatial_dims + self.transpositions_axes = transpositions_axes + self.rot90_axes = rot90_axes + if len(set(transformation_order)) != len(transformation_order): + raise ValueError(f"No duplicates are allowed. Received {transformation_order}.") + + for transformation in transformation_order: + if transformation not in list(OrderingTransformations): + raise ValueError( + f"Valid transformations are {list(OrderingTransformations)} but received {transformation}." + ) + self.transformation_order = transformation_order + + self.template = self._create_template() + self._sequence_ordering = self._create_ordering() + self._revert_sequence_ordering = np.argsort(self._sequence_ordering) + + def __call__(self, x: np.ndarray) -> np.ndarray: + x = x[self._sequence_ordering] + + return x + + def get_sequence_ordering(self) -> np.ndarray: + return self._sequence_ordering + + def get_revert_sequence_ordering(self) -> np.ndarray: + return self._revert_sequence_ordering + + def _create_ordering(self) -> np.ndarray: + self.template = self._transform_template() + order = self._order_template(template=self.template) + + return order + + def _create_template(self) -> np.ndarray: + spatial_dimensions = self.dimensions[1:] + template = np.arange(np.prod(spatial_dimensions)).reshape(*spatial_dimensions) + + return template + + def _transform_template(self) -> np.ndarray: + for transformation in self.transformation_order: + if transformation == OrderingTransformations.TRANSPOSE.value: + self.template = self._transpose_template(template=self.template) + elif transformation == OrderingTransformations.ROTATE_90.value: + self.template = self._rot90_template(template=self.template) + elif transformation == OrderingTransformations.REFLECT.value: + self.template = self._flip_template(template=self.template) + + return self.template + + def _transpose_template(self, template: np.ndarray) -> np.ndarray: + if self.transpositions_axes is not None: + for axes in self.transpositions_axes: + template = np.transpose(template, axes=axes) + + return template + + def _flip_template(self, template: np.ndarray) -> np.ndarray: + if self.reflected_spatial_dims is not None: + for axis, to_reflect in enumerate(self.reflected_spatial_dims): + template = np.flip(template, axis=axis) if to_reflect else template + + return template + + def _rot90_template(self, template: np.ndarray) -> np.ndarray: + if self.rot90_axes is not None: + for axes in self.rot90_axes: + template = np.rot90(template, axes=axes) + + return template + + def _order_template(self, template: np.ndarray) -> np.ndarray: + depths = None + if self.spatial_dims == 2: + rows, columns = template.shape[0], template.shape[1] + else: + rows, columns, depths = (template.shape[0], template.shape[1], template.shape[2]) + + sequence = eval(f"self.{self.ordering_type}_idx")(rows, columns, depths) + + ordering = np.array([template[tuple(e)] for e in sequence]) + + return ordering + + @staticmethod + def raster_scan_idx(rows: int, cols: int, depths: int | None = None) -> np.ndarray: + idx: list[tuple] = [] + + for r in range(rows): + for c in range(cols): + if depths is not None: + for d in range(depths): + idx.append((r, c, d)) + else: + idx.append((r, c)) + + idx_np = np.array(idx) + + return idx_np + + @staticmethod + def s_curve_idx(rows: int, cols: int, depths: int | None = None) -> np.ndarray: + idx: list[tuple] = [] + + for r in range(rows): + col_idx = range(cols) if r % 2 == 0 else range(cols - 1, -1, -1) + for c in col_idx: + if depths: + depth_idx = range(depths) if c % 2 == 0 else range(depths - 1, -1, -1) + + for d in depth_idx: + idx.append((r, c, d)) + else: + idx.append((r, c)) + + idx_np = np.array(idx) + + return idx_np + + @staticmethod + def random_idx(rows: int, cols: int, depths: int | None = None) -> np.ndarray: + idx: list[tuple] = [] + + for r in range(rows): + for c in range(cols): + if depths: + for d in range(depths): + idx.append((r, c, d)) + else: + idx.append((r, c)) + + idx_np = np.array(idx) + np.random.shuffle(idx_np) + + return idx_np diff --git a/tests/hvd_evenly_divisible_all_gather.py b/tests/hvd_evenly_divisible_all_gather.py index 78c6ca06bc..732ad13b83 100644 --- a/tests/hvd_evenly_divisible_all_gather.py +++ b/tests/hvd_evenly_divisible_all_gather.py @@ -30,10 +30,10 @@ def test_data(self): self._run() def _run(self): - if hvd.rank() == 0: - data1 = torch.tensor([[1, 2], [3, 4]]) - data2 = torch.tensor([[1.0, 2.0]]) - data3 = torch.tensor(7) + # if hvd.rank() == 0: + data1 = torch.tensor([[1, 2], [3, 4]]) + data2 = torch.tensor([[1.0, 2.0]]) + data3 = torch.tensor(7) if hvd.rank() == 1: data1 = torch.tensor([[5, 6]]) diff --git a/tests/min_tests.py b/tests/min_tests.py index 8128bb7b84..3a143df84b 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -154,6 +154,7 @@ def run_testsuit(): "test_plot_2d_or_3d_image", "test_png_rw", "test_prepare_batch_default", + "test_prepare_batch_diffusion", "test_prepare_batch_extra_input", "test_prepare_batch_hovernet", "test_rand_grid_patch", diff --git a/tests/test_autoencoderkl.py b/tests/test_autoencoderkl.py new file mode 100644 index 0000000000..d15cb79084 --- /dev/null +++ b/tests/test_autoencoderkl.py @@ -0,0 +1,337 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import os +import tempfile +import unittest +from unittest import skipUnless + +import torch +from parameterized import parameterized + +from monai.apps import download_url +from monai.networks import eval_mode +from monai.networks.nets import AutoencoderKL +from monai.utils import optional_import +from tests.utils import SkipIfBeforePyTorchVersion, skip_if_downloading_fails, testing_data_config + +tqdm, has_tqdm = optional_import("tqdm", name="tqdm") +_, has_einops = optional_import("einops") + +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + +CASES_NO_ATTENTION = [ + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, False), + "num_res_blocks": 1, + "norm_num_groups": 4, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, + }, + (1, 1, 16, 16), + (1, 1, 16, 16), + (1, 4, 4, 4), + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, False), + "num_res_blocks": 1, + "norm_num_groups": 4, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, + }, + (1, 1, 16, 16, 16), + (1, 1, 16, 16, 16), + (1, 4, 4, 4, 4), + ], +] + +CASES_ATTENTION = [ + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, False), + "num_res_blocks": 1, + "norm_num_groups": 4, + }, + (1, 1, 16, 16), + (1, 1, 16, 16), + (1, 4, 4, 4), + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, False), + "num_res_blocks": (1, 1, 2), + "norm_num_groups": 4, + }, + (1, 1, 16, 16), + (1, 1, 16, 16), + (1, 4, 4, 4), + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, False), + "num_res_blocks": 1, + "norm_num_groups": 4, + }, + (1, 1, 16, 16), + (1, 1, 16, 16), + (1, 4, 4, 4), + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, False), + "num_res_blocks": 1, + "norm_num_groups": 4, + "with_encoder_nonlocal_attn": False, + }, + (1, 1, 16, 16), + (1, 1, 16, 16), + (1, 4, 4, 4), + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, True), + "num_res_blocks": 1, + "norm_num_groups": 4, + }, + (1, 1, 16, 16), + (1, 1, 16, 16), + (1, 4, 4, 4), + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, True), + "num_res_blocks": 1, + "norm_num_groups": 4, + }, + (1, 1, 16, 16, 16), + (1, 1, 16, 16, 16), + (1, 4, 4, 4, 4), + ], +] + +if has_einops: + CASES = CASES_NO_ATTENTION + CASES_ATTENTION +else: + CASES = CASES_NO_ATTENTION + + +class TestAutoEncoderKL(unittest.TestCase): + @parameterized.expand(CASES) + def test_shape(self, input_param, input_shape, expected_shape, expected_latent_shape): + net = AutoencoderKL(**input_param).to(device) + with eval_mode(net): + result = net.forward(torch.randn(input_shape).to(device)) + self.assertEqual(result[0].shape, expected_shape) + self.assertEqual(result[1].shape, expected_latent_shape) + self.assertEqual(result[2].shape, expected_latent_shape) + + @parameterized.expand(CASES) + @SkipIfBeforePyTorchVersion((1, 11)) + def test_shape_with_convtranspose_and_checkpointing( + self, input_param, input_shape, expected_shape, expected_latent_shape + ): + input_param = input_param.copy() + input_param.update({"use_checkpoint": True, "use_convtranspose": True}) + net = AutoencoderKL(**input_param).to(device) + with eval_mode(net): + result = net.forward(torch.randn(input_shape).to(device)) + self.assertEqual(result[0].shape, expected_shape) + self.assertEqual(result[1].shape, expected_latent_shape) + self.assertEqual(result[2].shape, expected_latent_shape) + + def test_model_channels_not_multiple_of_norm_num_group(self): + with self.assertRaises(ValueError): + AutoencoderKL( + spatial_dims=2, + in_channels=1, + out_channels=1, + channels=(24, 24, 24), + attention_levels=(False, False, False), + latent_channels=8, + num_res_blocks=1, + norm_num_groups=16, + ) + + def test_model_num_channels_not_same_size_of_attention_levels(self): + with self.assertRaises(ValueError): + AutoencoderKL( + spatial_dims=2, + in_channels=1, + out_channels=1, + channels=(24, 24, 24), + attention_levels=(False, False), + latent_channels=8, + num_res_blocks=1, + norm_num_groups=16, + ) + + def test_model_num_channels_not_same_size_of_num_res_blocks(self): + with self.assertRaises(ValueError): + AutoencoderKL( + spatial_dims=2, + in_channels=1, + out_channels=1, + channels=(24, 24, 24), + attention_levels=(False, False, False), + latent_channels=8, + num_res_blocks=(8, 8), + norm_num_groups=16, + ) + + def test_shape_reconstruction(self): + input_param, input_shape, expected_shape, _ = CASES[0] + net = AutoencoderKL(**input_param).to(device) + with eval_mode(net): + result = net.reconstruct(torch.randn(input_shape).to(device)) + self.assertEqual(result.shape, expected_shape) + + @SkipIfBeforePyTorchVersion((1, 11)) + def test_shape_reconstruction_with_convtranspose_and_checkpointing(self): + input_param, input_shape, expected_shape, _ = CASES[0] + input_param = input_param.copy() + input_param.update({"use_checkpoint": True, "use_convtranspose": True}) + net = AutoencoderKL(**input_param).to(device) + with eval_mode(net): + result = net.reconstruct(torch.randn(input_shape).to(device)) + self.assertEqual(result.shape, expected_shape) + + def test_shape_encode(self): + input_param, input_shape, _, expected_latent_shape = CASES[0] + net = AutoencoderKL(**input_param).to(device) + with eval_mode(net): + result = net.encode(torch.randn(input_shape).to(device)) + self.assertEqual(result[0].shape, expected_latent_shape) + self.assertEqual(result[1].shape, expected_latent_shape) + + @SkipIfBeforePyTorchVersion((1, 11)) + def test_shape_encode_with_convtranspose_and_checkpointing(self): + input_param, input_shape, _, expected_latent_shape = CASES[0] + input_param = input_param.copy() + input_param.update({"use_checkpoint": True, "use_convtranspose": True}) + net = AutoencoderKL(**input_param).to(device) + with eval_mode(net): + result = net.encode(torch.randn(input_shape).to(device)) + self.assertEqual(result[0].shape, expected_latent_shape) + self.assertEqual(result[1].shape, expected_latent_shape) + + def test_shape_sampling(self): + input_param, _, _, expected_latent_shape = CASES[0] + net = AutoencoderKL(**input_param).to(device) + with eval_mode(net): + result = net.sampling( + torch.randn(expected_latent_shape).to(device), torch.randn(expected_latent_shape).to(device) + ) + self.assertEqual(result.shape, expected_latent_shape) + + @SkipIfBeforePyTorchVersion((1, 11)) + def test_shape_sampling_convtranspose_and_checkpointing(self): + input_param, _, _, expected_latent_shape = CASES[0] + input_param = input_param.copy() + input_param.update({"use_checkpoint": True, "use_convtranspose": True}) + net = AutoencoderKL(**input_param).to(device) + with eval_mode(net): + result = net.sampling( + torch.randn(expected_latent_shape).to(device), torch.randn(expected_latent_shape).to(device) + ) + self.assertEqual(result.shape, expected_latent_shape) + + def test_shape_decode(self): + input_param, expected_input_shape, _, latent_shape = CASES[0] + net = AutoencoderKL(**input_param).to(device) + with eval_mode(net): + result = net.decode(torch.randn(latent_shape).to(device)) + self.assertEqual(result.shape, expected_input_shape) + + @SkipIfBeforePyTorchVersion((1, 11)) + def test_shape_decode_convtranspose_and_checkpointing(self): + input_param, expected_input_shape, _, latent_shape = CASES[0] + input_param = input_param.copy() + input_param.update({"use_checkpoint": True, "use_convtranspose": True}) + net = AutoencoderKL(**input_param).to(device) + with eval_mode(net): + result = net.decode(torch.randn(latent_shape).to(device)) + self.assertEqual(result.shape, expected_input_shape) + + @skipUnless(has_einops, "Requires einops") + def test_compatibility_with_monai_generative(self): + # test loading weights from a model saved in MONAI Generative, version 0.2.3 + with skip_if_downloading_fails(): + net = AutoencoderKL( + spatial_dims=2, + in_channels=1, + out_channels=1, + channels=(4, 4, 4), + latent_channels=4, + attention_levels=(False, False, True), + num_res_blocks=1, + norm_num_groups=4, + ).to(device) + + tmpdir = tempfile.mkdtemp() + key = "autoencoderkl_monai_generative_weights" + url = testing_data_config("models", key, "url") + hash_type = testing_data_config("models", key, "hash_type") + hash_val = testing_data_config("models", key, "hash_val") + filename = "autoencoderkl_monai_generative_weights.pt" + + weight_path = os.path.join(tmpdir, filename) + download_url(url=url, filepath=weight_path, hash_val=hash_val, hash_type=hash_type) + + net.load_old_state_dict(torch.load(weight_path), verbose=False) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_controlnet.py b/tests/test_controlnet.py new file mode 100644 index 0000000000..4746c7ce22 --- /dev/null +++ b/tests/test_controlnet.py @@ -0,0 +1,215 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import os +import tempfile +import unittest +from unittest import skipUnless + +import torch +from parameterized import parameterized + +from monai.apps import download_url +from monai.networks import eval_mode +from monai.networks.nets.controlnet import ControlNet +from monai.utils import optional_import +from tests.utils import skip_if_downloading_fails, testing_data_config + +_, has_einops = optional_import("einops") +UNCOND_CASES_2D = [ + [ + { + "spatial_dims": 2, + "in_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, False), + "norm_num_groups": 8, + }, + (1, 8, 4, 4), + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, False), + "norm_num_groups": 8, + "resblock_updown": True, + }, + (1, 8, 4, 4), + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "num_res_blocks": 1, + "channels": (4, 4, 4), + "attention_levels": (False, False, True), + "num_head_channels": 4, + "norm_num_groups": 4, + }, + (1, 4, 4, 4), + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 8, + "norm_num_groups": 8, + "resblock_updown": True, + }, + (1, 8, 4, 4), + ], +] + +UNCOND_CASES_3D = [ + [ + { + "spatial_dims": 3, + "in_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, False), + "norm_num_groups": 8, + }, + (1, 8, 4, 4, 4), + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "num_res_blocks": 1, + "channels": (4, 4, 4), + "num_head_channels": 4, + "attention_levels": (False, False, False), + "norm_num_groups": 4, + "resblock_updown": True, + }, + (1, 4, 4, 4, 4), + ], +] + +COND_CASES_2D = [ + [ + { + "spatial_dims": 2, + "in_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, False), + "norm_num_groups": 8, + "with_conditioning": True, + "transformer_num_layers": 1, + "cross_attention_dim": 3, + }, + (1, 8, 4, 4), + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, False), + "norm_num_groups": 8, + "with_conditioning": True, + "transformer_num_layers": 1, + "cross_attention_dim": 3, + "resblock_updown": True, + }, + (1, 8, 4, 4), + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, False), + "norm_num_groups": 8, + "with_conditioning": True, + "transformer_num_layers": 1, + "cross_attention_dim": 3, + "upcast_attention": True, + }, + (1, 8, 4, 4), + ], +] + + +class TestControlNet(unittest.TestCase): + @parameterized.expand(UNCOND_CASES_2D + UNCOND_CASES_3D) + @skipUnless(has_einops, "Requires einops") + def test_shape_unconditioned_models(self, input_param, expected_output_shape): + input_param["conditioning_embedding_in_channels"] = input_param["in_channels"] + input_param["conditioning_embedding_num_channels"] = (input_param["channels"][0],) + net = ControlNet(**input_param) + with eval_mode(net): + x = torch.rand((1, 1) + (16,) * input_param["spatial_dims"]) + timesteps = torch.randint(0, 1000, (1,)).long() + controlnet_cond = torch.rand((1, 1) + (16,) * input_param["spatial_dims"]) + result = net.forward(x, timesteps=timesteps, controlnet_cond=controlnet_cond) + self.assertEqual(len(result[0]), 2 * len(input_param["channels"])) + self.assertEqual(result[1].shape, expected_output_shape) + + @parameterized.expand(COND_CASES_2D) + @skipUnless(has_einops, "Requires einops") + def test_shape_conditioned_models(self, input_param, expected_output_shape): + input_param["conditioning_embedding_in_channels"] = input_param["in_channels"] + input_param["conditioning_embedding_num_channels"] = (input_param["channels"][0],) + net = ControlNet(**input_param) + with eval_mode(net): + x = torch.rand((1, 1) + (16,) * input_param["spatial_dims"]) + timesteps = torch.randint(0, 1000, (1,)).long() + controlnet_cond = torch.rand((1, 1) + (16,) * input_param["spatial_dims"]) + result = net.forward(x, timesteps=timesteps, controlnet_cond=controlnet_cond, context=torch.rand((1, 1, 3))) + self.assertEqual(len(result[0]), 2 * len(input_param["channels"])) + self.assertEqual(result[1].shape, expected_output_shape) + + @skipUnless(has_einops, "Requires einops") + def test_compatibility_with_monai_generative(self): + # test loading weights from a model saved in MONAI Generative, version 0.2.3 + with skip_if_downloading_fails(): + net = ControlNet( + spatial_dims=2, + in_channels=1, + num_res_blocks=1, + channels=(8, 8, 8), + attention_levels=(False, False, True), + norm_num_groups=8, + with_conditioning=True, + transformer_num_layers=1, + cross_attention_dim=3, + resblock_updown=True, + ) + + tmpdir = tempfile.mkdtemp() + key = "controlnet_monai_generative_weights" + url = testing_data_config("models", key, "url") + hash_type = testing_data_config("models", key, "hash_type") + hash_val = testing_data_config("models", key, "hash_val") + filename = "controlnet_monai_generative_weights.pt" + + weight_path = os.path.join(tmpdir, filename) + download_url(url=url, filepath=weight_path, hash_val=hash_val, hash_type=hash_type) + + net.load_old_state_dict(torch.load(weight_path), verbose=False) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_controlnet_inferers.py b/tests/test_controlnet_inferers.py new file mode 100644 index 0000000000..e3b0aeb5a2 --- /dev/null +++ b/tests/test_controlnet_inferers.py @@ -0,0 +1,1310 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest +from unittest import skipUnless + +import torch +from parameterized import parameterized + +from monai.inferers import ControlNetDiffusionInferer, ControlNetLatentDiffusionInferer +from monai.networks.nets import ( + VQVAE, + AutoencoderKL, + ControlNet, + DiffusionModelUNet, + SPADEAutoencoderKL, + SPADEDiffusionModelUNet, +) +from monai.networks.schedulers import DDIMScheduler, DDPMScheduler +from monai.utils import optional_import + +_, has_scipy = optional_import("scipy") +_, has_einops = optional_import("einops") + + +CNDM_TEST_CASES = [ + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": [8], + "norm_num_groups": 8, + "attention_levels": [True], + "num_res_blocks": 1, + "num_head_channels": 8, + }, + { + "spatial_dims": 2, + "in_channels": 1, + "channels": [8], + "attention_levels": [True], + "norm_num_groups": 8, + "num_res_blocks": 1, + "num_head_channels": 8, + "conditioning_embedding_num_channels": [16], + "conditioning_embedding_in_channels": 1, + }, + (2, 1, 8, 8), + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "channels": [8], + "norm_num_groups": 8, + "attention_levels": [True], + "num_res_blocks": 1, + "num_head_channels": 8, + }, + { + "spatial_dims": 3, + "in_channels": 1, + "channels": [8], + "attention_levels": [True], + "num_res_blocks": 1, + "norm_num_groups": 8, + "num_head_channels": 8, + "conditioning_embedding_num_channels": [16], + "conditioning_embedding_in_channels": 1, + }, + (2, 1, 8, 8, 8), + ], +] +LATENT_CNDM_TEST_CASES = [ + [ + "AutoencoderKL", + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4), + "latent_channels": 3, + "attention_levels": [False, False], + "num_res_blocks": 1, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, + "norm_num_groups": 4, + }, + "DiffusionModelUNet", + { + "spatial_dims": 2, + "in_channels": 3, + "out_channels": 3, + "channels": [4, 4], + "norm_num_groups": 4, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 4, + }, + { + "spatial_dims": 2, + "in_channels": 3, + "channels": [4, 4], + "attention_levels": [False, False], + "num_res_blocks": 1, + "norm_num_groups": 4, + "num_head_channels": 4, + "conditioning_embedding_num_channels": [16], + "conditioning_embedding_in_channels": 1, + }, + (1, 1, 8, 8), + (1, 3, 4, 4), + ], + [ + "VQVAE", + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": [4, 4], + "num_res_layers": 1, + "num_res_channels": [4, 4], + "downsample_parameters": ((2, 4, 1, 1), (2, 4, 1, 1)), + "upsample_parameters": ((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)), + "num_embeddings": 16, + "embedding_dim": 3, + }, + "DiffusionModelUNet", + { + "spatial_dims": 2, + "in_channels": 3, + "out_channels": 3, + "channels": [8, 8], + "norm_num_groups": 8, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 8, + }, + { + "spatial_dims": 2, + "in_channels": 3, + "channels": [8, 8], + "attention_levels": [False, False], + "num_res_blocks": 1, + "norm_num_groups": 8, + "num_head_channels": 8, + "conditioning_embedding_num_channels": [16], + "conditioning_embedding_in_channels": 1, + }, + (1, 1, 16, 16), + (1, 3, 4, 4), + ], + [ + "VQVAE", + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "channels": [4, 4], + "num_res_layers": 1, + "num_res_channels": [4, 4], + "downsample_parameters": ((2, 4, 1, 1), (2, 4, 1, 1)), + "upsample_parameters": ((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)), + "num_embeddings": 16, + "embedding_dim": 3, + }, + "DiffusionModelUNet", + { + "spatial_dims": 3, + "in_channels": 3, + "out_channels": 3, + "channels": [8, 8], + "norm_num_groups": 8, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 8, + }, + { + "spatial_dims": 3, + "in_channels": 3, + "channels": [8, 8], + "attention_levels": [False, False], + "num_res_blocks": 1, + "norm_num_groups": 8, + "num_head_channels": 8, + "conditioning_embedding_num_channels": [16], + "conditioning_embedding_in_channels": 1, + }, + (1, 1, 16, 16, 16), + (1, 3, 4, 4, 4), + ], +] +LATENT_CNDM_TEST_CASES_DIFF_SHAPES = [ + [ + "AutoencoderKL", + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4), + "latent_channels": 3, + "attention_levels": [False, False], + "num_res_blocks": 1, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, + "norm_num_groups": 4, + }, + "DiffusionModelUNet", + { + "spatial_dims": 2, + "in_channels": 3, + "out_channels": 3, + "channels": [4, 4], + "norm_num_groups": 4, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 4, + }, + { + "spatial_dims": 2, + "in_channels": 3, + "channels": [4, 4], + "attention_levels": [False, False], + "num_res_blocks": 1, + "norm_num_groups": 4, + "num_head_channels": 4, + "conditioning_embedding_num_channels": [16], + "conditioning_embedding_in_channels": 1, + }, + (1, 1, 12, 12), + (1, 3, 8, 8), + ], + [ + "VQVAE", + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": [4, 4], + "num_res_layers": 1, + "num_res_channels": [4, 4], + "downsample_parameters": ((2, 4, 1, 1), (2, 4, 1, 1)), + "upsample_parameters": ((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)), + "num_embeddings": 16, + "embedding_dim": 3, + }, + "DiffusionModelUNet", + { + "spatial_dims": 2, + "in_channels": 3, + "out_channels": 3, + "channels": [8, 8], + "norm_num_groups": 8, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 8, + }, + { + "spatial_dims": 2, + "in_channels": 3, + "channels": [8, 8], + "attention_levels": [False, False], + "num_res_blocks": 1, + "norm_num_groups": 8, + "num_head_channels": 8, + "conditioning_embedding_num_channels": [16], + "conditioning_embedding_in_channels": 1, + }, + (1, 1, 12, 12), + (1, 3, 8, 8), + ], + [ + "VQVAE", + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "channels": [4, 4], + "num_res_layers": 1, + "num_res_channels": [4, 4], + "downsample_parameters": ((2, 4, 1, 1), (2, 4, 1, 1)), + "upsample_parameters": ((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)), + "num_embeddings": 16, + "embedding_dim": 3, + }, + "DiffusionModelUNet", + { + "spatial_dims": 3, + "in_channels": 3, + "out_channels": 3, + "channels": [8, 8], + "norm_num_groups": 8, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 8, + }, + { + "spatial_dims": 3, + "in_channels": 3, + "channels": [8, 8], + "attention_levels": [False, False], + "num_res_blocks": 1, + "norm_num_groups": 8, + "num_head_channels": 8, + "conditioning_embedding_num_channels": [16], + "conditioning_embedding_in_channels": 1, + }, + (1, 1, 12, 12, 12), + (1, 3, 8, 8, 8), + ], + [ + "SPADEAutoencoderKL", + { + "spatial_dims": 2, + "label_nc": 3, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4), + "latent_channels": 3, + "attention_levels": [False, False], + "num_res_blocks": 1, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, + "norm_num_groups": 4, + }, + "DiffusionModelUNet", + { + "spatial_dims": 2, + "in_channels": 3, + "out_channels": 3, + "channels": [4, 4], + "norm_num_groups": 4, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 4, + }, + { + "spatial_dims": 2, + "in_channels": 3, + "channels": [4, 4], + "attention_levels": [False, False], + "num_res_blocks": 1, + "norm_num_groups": 4, + "num_head_channels": 4, + "conditioning_embedding_num_channels": [16], + "conditioning_embedding_in_channels": 1, + }, + (1, 1, 8, 8), + (1, 3, 4, 4), + ], + [ + "AutoencoderKL", + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4), + "latent_channels": 3, + "attention_levels": [False, False], + "num_res_blocks": 1, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, + "norm_num_groups": 4, + }, + "SPADEDiffusionModelUNet", + { + "spatial_dims": 2, + "label_nc": 3, + "in_channels": 3, + "out_channels": 3, + "channels": [4, 4], + "norm_num_groups": 4, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 4, + }, + { + "spatial_dims": 2, + "in_channels": 3, + "channels": [4, 4], + "attention_levels": [False, False], + "num_res_blocks": 1, + "norm_num_groups": 4, + "num_head_channels": 4, + "conditioning_embedding_num_channels": [16], + "conditioning_embedding_in_channels": 1, + }, + (1, 1, 8, 8), + (1, 3, 4, 4), + ], + [ + "SPADEAutoencoderKL", + { + "spatial_dims": 2, + "label_nc": 3, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4), + "latent_channels": 3, + "attention_levels": [False, False], + "num_res_blocks": 1, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, + "norm_num_groups": 4, + }, + "SPADEDiffusionModelUNet", + { + "spatial_dims": 2, + "label_nc": 3, + "in_channels": 3, + "out_channels": 3, + "channels": [4, 4], + "norm_num_groups": 4, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 4, + }, + { + "spatial_dims": 2, + "in_channels": 3, + "channels": [4, 4], + "attention_levels": [False, False], + "num_res_blocks": 1, + "norm_num_groups": 4, + "num_head_channels": 4, + "conditioning_embedding_num_channels": [16], + "conditioning_embedding_in_channels": 1, + }, + (1, 1, 8, 8), + (1, 3, 4, 4), + ], +] + + +class ControlNetTestDiffusionSamplingInferer(unittest.TestCase): + @parameterized.expand(CNDM_TEST_CASES) + @skipUnless(has_einops, "Requires einops") + def test_call(self, model_params, controlnet_params, input_shape): + model = DiffusionModelUNet(**model_params) + controlnet = ControlNet(**controlnet_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + controlnet.to(device) + controlnet.eval() + input = torch.randn(input_shape).to(device) + mask = torch.randn(input_shape).to(device) + noise = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = ControlNetDiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long() + sample = inferer( + inputs=input, noise=noise, diffusion_model=model, controlnet=controlnet, timesteps=timesteps, cn_cond=mask + ) + self.assertEqual(sample.shape, input_shape) + + @parameterized.expand(CNDM_TEST_CASES) + @skipUnless(has_einops, "Requires einops") + def test_sample_intermediates(self, model_params, controlnet_params, input_shape): + model = DiffusionModelUNet(**model_params) + controlnet = ControlNet(**controlnet_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + controlnet.to(device) + controlnet.eval() + noise = torch.randn(input_shape).to(device) + mask = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = ControlNetDiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + sample, intermediates = inferer.sample( + input_noise=noise, + diffusion_model=model, + scheduler=scheduler, + controlnet=controlnet, + cn_cond=mask, + save_intermediates=True, + intermediate_steps=1, + ) + self.assertEqual(len(intermediates), 10) + + @parameterized.expand(CNDM_TEST_CASES) + @skipUnless(has_einops, "Requires einops") + def test_ddpm_sampler(self, model_params, controlnet_params, input_shape): + model = DiffusionModelUNet(**model_params) + controlnet = ControlNet(**controlnet_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + controlnet.to(device) + controlnet.eval() + mask = torch.randn(input_shape).to(device) + noise = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=1000) + inferer = ControlNetDiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + sample, intermediates = inferer.sample( + input_noise=noise, + diffusion_model=model, + scheduler=scheduler, + controlnet=controlnet, + cn_cond=mask, + save_intermediates=True, + intermediate_steps=1, + ) + self.assertEqual(len(intermediates), 10) + + @parameterized.expand(CNDM_TEST_CASES) + @skipUnless(has_einops, "Requires einops") + def test_ddim_sampler(self, model_params, controlnet_params, input_shape): + model = DiffusionModelUNet(**model_params) + controlnet = ControlNet(**controlnet_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + controlnet.to(device) + controlnet.eval() + mask = torch.randn(input_shape).to(device) + noise = torch.randn(input_shape).to(device) + scheduler = DDIMScheduler(num_train_timesteps=1000) + inferer = ControlNetDiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + sample, intermediates = inferer.sample( + input_noise=noise, + diffusion_model=model, + scheduler=scheduler, + controlnet=controlnet, + cn_cond=mask, + save_intermediates=True, + intermediate_steps=1, + ) + self.assertEqual(len(intermediates), 10) + + @parameterized.expand(CNDM_TEST_CASES) + @skipUnless(has_einops, "Requires einops") + def test_sampler_conditioned(self, model_params, controlnet_params, input_shape): + model_params["with_conditioning"] = True + model_params["cross_attention_dim"] = 3 + model = DiffusionModelUNet(**model_params) + controlnet = ControlNet(**controlnet_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + controlnet.to(device) + controlnet.eval() + mask = torch.randn(input_shape).to(device) + noise = torch.randn(input_shape).to(device) + scheduler = DDIMScheduler(num_train_timesteps=1000) + inferer = ControlNetDiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + conditioning = torch.randn([input_shape[0], 1, 3]).to(device) + sample, intermediates = inferer.sample( + input_noise=noise, + diffusion_model=model, + controlnet=controlnet, + cn_cond=mask, + scheduler=scheduler, + save_intermediates=True, + intermediate_steps=1, + conditioning=conditioning, + ) + self.assertEqual(len(intermediates), 10) + + @parameterized.expand(CNDM_TEST_CASES) + @skipUnless(has_einops, "Requires einops") + def test_get_likelihood(self, model_params, controlnet_params, input_shape): + model = DiffusionModelUNet(**model_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + controlnet = ControlNet(**controlnet_params) + controlnet.to(device) + controlnet.eval() + input = torch.randn(input_shape).to(device) + mask = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = ControlNetDiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + likelihood, intermediates = inferer.get_likelihood( + inputs=input, + diffusion_model=model, + scheduler=scheduler, + controlnet=controlnet, + cn_cond=mask, + save_intermediates=True, + ) + self.assertEqual(intermediates[0].shape, input.shape) + self.assertEqual(likelihood.shape[0], input.shape[0]) + + @unittest.skipUnless(has_scipy, "Requires scipy library.") + def test_normal_cdf(self): + from scipy.stats import norm + + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = ControlNetDiffusionInferer(scheduler=scheduler) + x = torch.linspace(-10, 10, 20) + cdf_approx = inferer._approx_standard_normal_cdf(x) + cdf_true = norm.cdf(x) + torch.testing.assert_allclose(cdf_approx, cdf_true, atol=1e-3, rtol=1e-5) + + @parameterized.expand(CNDM_TEST_CASES) + @skipUnless(has_einops, "Requires einops") + def test_sampler_conditioned_concat(self, model_params, controlnet_params, input_shape): + # copy the model_params dict to prevent from modifying test cases + model_params = model_params.copy() + n_concat_channel = 2 + model_params["in_channels"] = model_params["in_channels"] + n_concat_channel + model_params["cross_attention_dim"] = None + model_params["with_conditioning"] = False + model = DiffusionModelUNet(**model_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + controlnet = ControlNet(**controlnet_params) + controlnet.to(device) + controlnet.eval() + noise = torch.randn(input_shape).to(device) + mask = torch.randn(input_shape).to(device) + conditioning_shape = list(input_shape) + conditioning_shape[1] = n_concat_channel + conditioning = torch.randn(conditioning_shape).to(device) + scheduler = DDIMScheduler(num_train_timesteps=1000) + inferer = ControlNetDiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + sample, intermediates = inferer.sample( + input_noise=noise, + diffusion_model=model, + controlnet=controlnet, + cn_cond=mask, + scheduler=scheduler, + save_intermediates=True, + intermediate_steps=1, + conditioning=conditioning, + mode="concat", + ) + self.assertEqual(len(intermediates), 10) + + +class LatentControlNetTestDiffusionSamplingInferer(unittest.TestCase): + @parameterized.expand(LATENT_CNDM_TEST_CASES) + @skipUnless(has_einops, "Requires einops") + def test_prediction_shape( + self, + ae_model_type, + autoencoder_params, + dm_model_type, + stage_2_params, + controlnet_params, + input_shape, + latent_shape, + ): + stage_1 = None + + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if ae_model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + if dm_model_type == "SPADEDiffusionModelUNet": + stage_2 = SPADEDiffusionModelUNet(**stage_2_params) + else: + stage_2 = DiffusionModelUNet(**stage_2_params) + controlnet = ControlNet(**controlnet_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + controlnet.to(device) + stage_1.eval() + stage_2.eval() + controlnet.eval() + + input = torch.randn(input_shape).to(device) + mask = torch.randn(input_shape).to(device) + noise = torch.randn(latent_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = ControlNetLatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long() + + if dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + prediction = inferer( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + controlnet=controlnet, + cn_cond=mask, + seg=input_seg, + noise=noise, + timesteps=timesteps, + ) + else: + prediction = inferer( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + noise=noise, + timesteps=timesteps, + controlnet=controlnet, + cn_cond=mask, + ) + self.assertEqual(prediction.shape, latent_shape) + + @parameterized.expand(LATENT_CNDM_TEST_CASES) + @skipUnless(has_einops, "Requires einops") + def test_sample_shape( + self, + ae_model_type, + autoencoder_params, + dm_model_type, + stage_2_params, + controlnet_params, + input_shape, + latent_shape, + ): + stage_1 = None + + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if ae_model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + if dm_model_type == "SPADEDiffusionModelUNet": + stage_2 = SPADEDiffusionModelUNet(**stage_2_params) + else: + stage_2 = DiffusionModelUNet(**stage_2_params) + controlnet = ControlNet(**controlnet_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + controlnet.to(device) + stage_1.eval() + stage_2.eval() + controlnet.eval() + + noise = torch.randn(latent_shape).to(device) + mask = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = ControlNetLatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + + if ae_model_type == "SPADEAutoencoderKL" or dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + sample = inferer.sample( + input_noise=noise, + autoencoder_model=stage_1, + diffusion_model=stage_2, + controlnet=controlnet, + cn_cond=mask, + scheduler=scheduler, + seg=input_seg, + ) + else: + sample = inferer.sample( + input_noise=noise, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + controlnet=controlnet, + cn_cond=mask, + ) + self.assertEqual(sample.shape, input_shape) + + @parameterized.expand(LATENT_CNDM_TEST_CASES) + @skipUnless(has_einops, "Requires einops") + def test_sample_intermediates( + self, + ae_model_type, + autoencoder_params, + dm_model_type, + stage_2_params, + controlnet_params, + input_shape, + latent_shape, + ): + stage_1 = None + + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if ae_model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + if ae_model_type == "SPADEAutoencoderKL": + stage_1 = SPADEAutoencoderKL(**autoencoder_params) + if dm_model_type == "SPADEDiffusionModelUNet": + stage_2 = SPADEDiffusionModelUNet(**stage_2_params) + else: + stage_2 = DiffusionModelUNet(**stage_2_params) + controlnet = ControlNet(**controlnet_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + controlnet.to(device) + stage_1.eval() + stage_2.eval() + controlnet.eval() + + noise = torch.randn(latent_shape).to(device) + mask = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = ControlNetLatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + + if ae_model_type == "SPADEAutoencoderKL" or dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + sample = inferer.sample( + input_noise=noise, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + seg=input_seg, + controlnet=controlnet, + cn_cond=mask, + ) + + # TODO: this isn't correct, should the above produce intermediates as well? + # This test has always passed so is this branch not being used? + intermediates = None + else: + sample, intermediates = inferer.sample( + input_noise=noise, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + save_intermediates=True, + intermediate_steps=1, + controlnet=controlnet, + cn_cond=mask, + ) + + self.assertEqual(len(intermediates), 10) + self.assertEqual(intermediates[0].shape, input_shape) + + @parameterized.expand(LATENT_CNDM_TEST_CASES) + @skipUnless(has_einops, "Requires einops") + def test_get_likelihoods( + self, + ae_model_type, + autoencoder_params, + dm_model_type, + stage_2_params, + controlnet_params, + input_shape, + latent_shape, + ): + stage_1 = None + + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if ae_model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + if ae_model_type == "SPADEAutoencoderKL": + stage_1 = SPADEAutoencoderKL(**autoencoder_params) + if dm_model_type == "SPADEDiffusionModelUNet": + stage_2 = SPADEDiffusionModelUNet(**stage_2_params) + else: + stage_2 = DiffusionModelUNet(**stage_2_params) + controlnet = ControlNet(**controlnet_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + controlnet.to(device) + stage_1.eval() + stage_2.eval() + controlnet.eval() + + input = torch.randn(input_shape).to(device) + mask = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = ControlNetLatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + + if dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + sample, intermediates = inferer.get_likelihood( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + controlnet=controlnet, + cn_cond=mask, + scheduler=scheduler, + save_intermediates=True, + seg=input_seg, + ) + else: + sample, intermediates = inferer.get_likelihood( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + controlnet=controlnet, + cn_cond=mask, + save_intermediates=True, + ) + self.assertEqual(len(intermediates), 10) + self.assertEqual(intermediates[0].shape, latent_shape) + + @parameterized.expand(LATENT_CNDM_TEST_CASES) + @skipUnless(has_einops, "Requires einops") + def test_resample_likelihoods( + self, + ae_model_type, + autoencoder_params, + dm_model_type, + stage_2_params, + controlnet_params, + input_shape, + latent_shape, + ): + stage_1 = None + + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if ae_model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + if ae_model_type == "SPADEAutoencoderKL": + stage_1 = SPADEAutoencoderKL(**autoencoder_params) + if dm_model_type == "SPADEDiffusionModelUNet": + stage_2 = SPADEDiffusionModelUNet(**stage_2_params) + else: + stage_2 = DiffusionModelUNet(**stage_2_params) + controlnet = ControlNet(**controlnet_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + controlnet.to(device) + stage_1.eval() + stage_2.eval() + controlnet.eval() + + input = torch.randn(input_shape).to(device) + mask = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = ControlNetLatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + + if dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + sample, intermediates = inferer.get_likelihood( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + controlnet=controlnet, + cn_cond=mask, + save_intermediates=True, + resample_latent_likelihoods=True, + seg=input_seg, + ) + else: + sample, intermediates = inferer.get_likelihood( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + controlnet=controlnet, + cn_cond=mask, + save_intermediates=True, + resample_latent_likelihoods=True, + ) + self.assertEqual(len(intermediates), 10) + self.assertEqual(intermediates[0].shape[2:], input_shape[2:]) + + @parameterized.expand(LATENT_CNDM_TEST_CASES) + @skipUnless(has_einops, "Requires einops") + def test_prediction_shape_conditioned_concat( + self, + ae_model_type, + autoencoder_params, + dm_model_type, + stage_2_params, + controlnet_params, + input_shape, + latent_shape, + ): + stage_1 = None + + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if ae_model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + if ae_model_type == "SPADEAutoencoderKL": + stage_1 = SPADEAutoencoderKL(**autoencoder_params) + stage_2_params = stage_2_params.copy() + n_concat_channel = 3 + stage_2_params["in_channels"] = stage_2_params["in_channels"] + n_concat_channel + if dm_model_type == "SPADEDiffusionModelUNet": + stage_2 = SPADEDiffusionModelUNet(**stage_2_params) + else: + stage_2 = DiffusionModelUNet(**stage_2_params) + controlnet = ControlNet(**controlnet_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + controlnet.to(device) + stage_1.eval() + stage_2.eval() + controlnet.eval() + + input = torch.randn(input_shape).to(device) + mask = torch.randn(input_shape).to(device) + noise = torch.randn(latent_shape).to(device) + conditioning_shape = list(latent_shape) + conditioning_shape[1] = n_concat_channel + conditioning = torch.randn(conditioning_shape).to(device) + + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = ControlNetLatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + + timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long() + + if dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + prediction = inferer( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + noise=noise, + controlnet=controlnet, + cn_cond=mask, + timesteps=timesteps, + condition=conditioning, + mode="concat", + seg=input_seg, + ) + else: + prediction = inferer( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + noise=noise, + controlnet=controlnet, + cn_cond=mask, + timesteps=timesteps, + condition=conditioning, + mode="concat", + ) + self.assertEqual(prediction.shape, latent_shape) + + @parameterized.expand(LATENT_CNDM_TEST_CASES) + @skipUnless(has_einops, "Requires einops") + def test_sample_shape_conditioned_concat( + self, + ae_model_type, + autoencoder_params, + dm_model_type, + stage_2_params, + controlnet_params, + input_shape, + latent_shape, + ): + stage_1 = None + + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if ae_model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + if ae_model_type == "SPADEAutoencoderKL": + stage_1 = SPADEAutoencoderKL(**autoencoder_params) + stage_2_params = stage_2_params.copy() + n_concat_channel = 3 + stage_2_params["in_channels"] = stage_2_params["in_channels"] + n_concat_channel + if dm_model_type == "SPADEDiffusionModelUNet": + stage_2 = SPADEDiffusionModelUNet(**stage_2_params) + else: + stage_2 = DiffusionModelUNet(**stage_2_params) + controlnet = ControlNet(**controlnet_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + controlnet.to(device) + stage_1.eval() + stage_2.eval() + controlnet.eval() + + noise = torch.randn(latent_shape).to(device) + mask = torch.randn(input_shape).to(device) + conditioning_shape = list(latent_shape) + conditioning_shape[1] = n_concat_channel + conditioning = torch.randn(conditioning_shape).to(device) + + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = ControlNetLatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + + if dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + sample = inferer.sample( + input_noise=noise, + autoencoder_model=stage_1, + diffusion_model=stage_2, + controlnet=controlnet, + cn_cond=mask, + scheduler=scheduler, + conditioning=conditioning, + mode="concat", + seg=input_seg, + ) + else: + sample = inferer.sample( + input_noise=noise, + autoencoder_model=stage_1, + diffusion_model=stage_2, + controlnet=controlnet, + cn_cond=mask, + scheduler=scheduler, + conditioning=conditioning, + mode="concat", + ) + self.assertEqual(sample.shape, input_shape) + + @parameterized.expand(LATENT_CNDM_TEST_CASES_DIFF_SHAPES) + @skipUnless(has_einops, "Requires einops") + def test_sample_shape_different_latents( + self, + ae_model_type, + autoencoder_params, + dm_model_type, + stage_2_params, + controlnet_params, + input_shape, + latent_shape, + ): + stage_1 = None + + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if ae_model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + if ae_model_type == "SPADEAutoencoderKL": + stage_1 = SPADEAutoencoderKL(**autoencoder_params) + if dm_model_type == "SPADEDiffusionModelUNet": + stage_2 = SPADEDiffusionModelUNet(**stage_2_params) + else: + stage_2 = DiffusionModelUNet(**stage_2_params) + controlnet = ControlNet(**controlnet_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + controlnet.to(device) + stage_1.eval() + stage_2.eval() + controlnet.eval() + + input = torch.randn(input_shape).to(device) + noise = torch.randn(latent_shape).to(device) + mask = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + # We infer the VAE shape + autoencoder_latent_shape = [i // (2 ** (len(autoencoder_params["channels"]) - 1)) for i in input_shape[2:]] + inferer = ControlNetLatentDiffusionInferer( + scheduler=scheduler, + scale_factor=1.0, + ldm_latent_shape=list(latent_shape[2:]), + autoencoder_latent_shape=autoencoder_latent_shape, + ) + scheduler.set_timesteps(num_inference_steps=10) + + timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long() + + if dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + prediction = inferer( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + controlnet=controlnet, + cn_cond=mask, + noise=noise, + timesteps=timesteps, + seg=input_seg, + ) + else: + prediction = inferer( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + noise=noise, + controlnet=controlnet, + cn_cond=mask, + timesteps=timesteps, + ) + self.assertEqual(prediction.shape, latent_shape) + + @skipUnless(has_einops, "Requires einops") + def test_incompatible_spade_setup(self): + stage_1 = SPADEAutoencoderKL( + spatial_dims=2, + label_nc=6, + in_channels=1, + out_channels=1, + channels=(4, 4), + latent_channels=3, + attention_levels=[False, False], + num_res_blocks=1, + with_encoder_nonlocal_attn=False, + with_decoder_nonlocal_attn=False, + norm_num_groups=4, + ) + stage_2 = SPADEDiffusionModelUNet( + spatial_dims=2, + label_nc=3, + in_channels=3, + out_channels=3, + channels=[4, 4], + norm_num_groups=4, + attention_levels=[False, False], + num_res_blocks=1, + num_head_channels=4, + ) + controlnet = ControlNet( + spatial_dims=2, + in_channels=1, + channels=[4, 4], + norm_num_groups=4, + attention_levels=[False, False], + num_res_blocks=1, + num_head_channels=4, + conditioning_embedding_num_channels=[16], + ) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + controlnet.to(device) + controlnet.to(device) + stage_1.eval() + stage_2.eval() + controlnet.eval() + noise = torch.randn((1, 3, 4, 4)).to(device) + mask = torch.randn((1, 1, 4, 4)).to(device) + input_seg = torch.randn((1, 3, 8, 8)).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = ControlNetLatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + + with self.assertRaises(ValueError): + _ = inferer.sample( + input_noise=noise, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + controlnet=controlnet, + cn_cond=mask, + seg=input_seg, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_crossattention.py b/tests/test_crossattention.py new file mode 100644 index 0000000000..4ab0ab1823 --- /dev/null +++ b/tests/test_crossattention.py @@ -0,0 +1,131 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest +from unittest import skipUnless + +import numpy as np +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.blocks.crossattention import CrossAttentionBlock +from monai.networks.layers.factories import RelPosEmbedding +from monai.utils import optional_import + +einops, has_einops = optional_import("einops") + +TEST_CASE_CABLOCK = [] +for dropout_rate in np.linspace(0, 1, 4): + for hidden_size in [360, 480, 600, 768]: + for num_heads in [4, 6, 8, 12]: + for rel_pos_embedding in [None, RelPosEmbedding.DECOMPOSED]: + for input_size in [(16, 32), (8, 8, 8)]: + test_case = [ + { + "hidden_size": hidden_size, + "num_heads": num_heads, + "dropout_rate": dropout_rate, + "rel_pos_embedding": rel_pos_embedding, + "input_size": input_size, + }, + (2, 512, hidden_size), + (2, 512, hidden_size), + ] + TEST_CASE_CABLOCK.append(test_case) + + +class TestResBlock(unittest.TestCase): + + @parameterized.expand(TEST_CASE_CABLOCK) + @skipUnless(has_einops, "Requires einops") + def test_shape(self, input_param, input_shape, expected_shape): + net = CrossAttentionBlock(**input_param) + with eval_mode(net): + result = net(torch.randn(input_shape), context=torch.randn(2, 512, input_param["hidden_size"])) + self.assertEqual(result.shape, expected_shape) + + def test_ill_arg(self): + with self.assertRaises(ValueError): + CrossAttentionBlock(hidden_size=128, num_heads=12, dropout_rate=6.0) + + with self.assertRaises(ValueError): + CrossAttentionBlock(hidden_size=620, num_heads=8, dropout_rate=0.4) + + @skipUnless(has_einops, "Requires einops") + def test_attention_dim_not_multiple_of_heads(self): + with self.assertRaises(ValueError): + CrossAttentionBlock(hidden_size=128, num_heads=3, dropout_rate=0.1) + + @skipUnless(has_einops, "Requires einops") + def test_inner_dim_different(self): + CrossAttentionBlock(hidden_size=128, num_heads=4, dropout_rate=0.1, dim_head=30) + + def test_causal_no_sequence_length(self): + with self.assertRaises(ValueError): + CrossAttentionBlock(hidden_size=128, num_heads=4, dropout_rate=0.1, causal=True) + + @skipUnless(has_einops, "Requires einops") + def test_causal(self): + block = CrossAttentionBlock( + hidden_size=128, num_heads=1, dropout_rate=0.1, causal=True, sequence_length=16, save_attn=True + ) + input_shape = (1, 16, 128) + block(torch.randn(input_shape)) + # check upper triangular part of the attention matrix is zero + assert torch.triu(block.att_mat, diagonal=1).sum() == 0 + + @skipUnless(has_einops, "Requires einops") + def test_context_input(self): + block = CrossAttentionBlock( + hidden_size=128, num_heads=1, dropout_rate=0.1, causal=True, sequence_length=16, context_input_size=12 + ) + input_shape = (1, 16, 128) + block(torch.randn(input_shape), context=torch.randn(1, 3, 12)) + + @skipUnless(has_einops, "Requires einops") + def test_context_wrong_input_size(self): + block = CrossAttentionBlock( + hidden_size=128, num_heads=1, dropout_rate=0.1, causal=True, sequence_length=16, context_input_size=12 + ) + input_shape = (1, 16, 128) + with self.assertRaises(RuntimeError): + block(torch.randn(input_shape), context=torch.randn(1, 3, 24)) + + @skipUnless(has_einops, "Requires einops") + def test_access_attn_matrix(self): + # input format + hidden_size = 128 + num_heads = 2 + dropout_rate = 0 + input_shape = (2, 256, hidden_size) + + # be not able to access the matrix + no_matrix_acess_blk = CrossAttentionBlock( + hidden_size=hidden_size, num_heads=num_heads, dropout_rate=dropout_rate + ) + no_matrix_acess_blk(torch.randn(input_shape)) + assert isinstance(no_matrix_acess_blk.att_mat, torch.Tensor) + # no of elements is zero + assert no_matrix_acess_blk.att_mat.nelement() == 0 + + # be able to acess the attention matrix + matrix_acess_blk = CrossAttentionBlock( + hidden_size=hidden_size, num_heads=num_heads, dropout_rate=dropout_rate, save_attn=True + ) + matrix_acess_blk(torch.randn(input_shape)) + assert matrix_acess_blk.att_mat.shape == (input_shape[0], input_shape[0], input_shape[1], input_shape[1]) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_diffusion_inferer.py b/tests/test_diffusion_inferer.py new file mode 100644 index 0000000000..7f37025d3c --- /dev/null +++ b/tests/test_diffusion_inferer.py @@ -0,0 +1,236 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest +from unittest import skipUnless + +import torch +from parameterized import parameterized + +from monai.inferers import DiffusionInferer +from monai.networks.nets import DiffusionModelUNet +from monai.networks.schedulers import DDIMScheduler, DDPMScheduler +from monai.utils import optional_import + +_, has_scipy = optional_import("scipy") +_, has_einops = optional_import("einops") + +TEST_CASES = [ + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": [8], + "norm_num_groups": 8, + "attention_levels": [True], + "num_res_blocks": 1, + "num_head_channels": 8, + }, + (2, 1, 8, 8), + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "channels": [8], + "norm_num_groups": 8, + "attention_levels": [True], + "num_res_blocks": 1, + "num_head_channels": 8, + }, + (2, 1, 8, 8, 8), + ], +] + + +class TestDiffusionSamplingInferer(unittest.TestCase): + @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") + def test_call(self, model_params, input_shape): + model = DiffusionModelUNet(**model_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + input = torch.randn(input_shape).to(device) + noise = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = DiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long() + sample = inferer(inputs=input, noise=noise, diffusion_model=model, timesteps=timesteps) + self.assertEqual(sample.shape, input_shape) + + @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") + def test_sample_intermediates(self, model_params, input_shape): + model = DiffusionModelUNet(**model_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + noise = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = DiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + sample, intermediates = inferer.sample( + input_noise=noise, diffusion_model=model, scheduler=scheduler, save_intermediates=True, intermediate_steps=1 + ) + self.assertEqual(len(intermediates), 10) + + @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") + def test_ddpm_sampler(self, model_params, input_shape): + model = DiffusionModelUNet(**model_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + noise = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=1000) + inferer = DiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + sample, intermediates = inferer.sample( + input_noise=noise, diffusion_model=model, scheduler=scheduler, save_intermediates=True, intermediate_steps=1 + ) + self.assertEqual(len(intermediates), 10) + + @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") + def test_ddim_sampler(self, model_params, input_shape): + model = DiffusionModelUNet(**model_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + noise = torch.randn(input_shape).to(device) + scheduler = DDIMScheduler(num_train_timesteps=1000) + inferer = DiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + sample, intermediates = inferer.sample( + input_noise=noise, diffusion_model=model, scheduler=scheduler, save_intermediates=True, intermediate_steps=1 + ) + self.assertEqual(len(intermediates), 10) + + @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") + def test_sampler_conditioned(self, model_params, input_shape): + model_params["with_conditioning"] = True + model_params["cross_attention_dim"] = 3 + model = DiffusionModelUNet(**model_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + noise = torch.randn(input_shape).to(device) + scheduler = DDIMScheduler(num_train_timesteps=1000) + inferer = DiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + conditioning = torch.randn([input_shape[0], 1, 3]).to(device) + sample, intermediates = inferer.sample( + input_noise=noise, + diffusion_model=model, + scheduler=scheduler, + save_intermediates=True, + intermediate_steps=1, + conditioning=conditioning, + ) + self.assertEqual(len(intermediates), 10) + + @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") + def test_get_likelihood(self, model_params, input_shape): + model = DiffusionModelUNet(**model_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + input = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = DiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + likelihood, intermediates = inferer.get_likelihood( + inputs=input, diffusion_model=model, scheduler=scheduler, save_intermediates=True + ) + self.assertEqual(intermediates[0].shape, input.shape) + self.assertEqual(likelihood.shape[0], input.shape[0]) + + @unittest.skipUnless(has_scipy, "Requires scipy library.") + def test_normal_cdf(self): + from scipy.stats import norm + + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = DiffusionInferer(scheduler=scheduler) + + x = torch.linspace(-10, 10, 20) + cdf_approx = inferer._approx_standard_normal_cdf(x) + cdf_true = norm.cdf(x) + torch.testing.assert_allclose(cdf_approx, cdf_true, atol=1e-3, rtol=1e-5) + + @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") + def test_sampler_conditioned_concat(self, model_params, input_shape): + # copy the model_params dict to prevent from modifying test cases + model_params = model_params.copy() + n_concat_channel = 2 + model_params["in_channels"] = model_params["in_channels"] + n_concat_channel + model_params["cross_attention_dim"] = None + model_params["with_conditioning"] = False + model = DiffusionModelUNet(**model_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + noise = torch.randn(input_shape).to(device) + conditioning_shape = list(input_shape) + conditioning_shape[1] = n_concat_channel + conditioning = torch.randn(conditioning_shape).to(device) + scheduler = DDIMScheduler(num_train_timesteps=1000) + inferer = DiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + sample, intermediates = inferer.sample( + input_noise=noise, + diffusion_model=model, + scheduler=scheduler, + save_intermediates=True, + intermediate_steps=1, + conditioning=conditioning, + mode="concat", + ) + self.assertEqual(len(intermediates), 10) + + @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") + def test_call_conditioned_concat(self, model_params, input_shape): + # copy the model_params dict to prevent from modifying test cases + model_params = model_params.copy() + n_concat_channel = 2 + model_params["in_channels"] = model_params["in_channels"] + n_concat_channel + model_params["cross_attention_dim"] = None + model_params["with_conditioning"] = False + model = DiffusionModelUNet(**model_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + input = torch.randn(input_shape).to(device) + noise = torch.randn(input_shape).to(device) + conditioning_shape = list(input_shape) + conditioning_shape[1] = n_concat_channel + conditioning = torch.randn(conditioning_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = DiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long() + sample = inferer( + inputs=input, noise=noise, diffusion_model=model, timesteps=timesteps, condition=conditioning, mode="concat" + ) + self.assertEqual(sample.shape, input_shape) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_diffusion_model_unet.py b/tests/test_diffusion_model_unet.py new file mode 100644 index 0000000000..7f764d85de --- /dev/null +++ b/tests/test_diffusion_model_unet.py @@ -0,0 +1,585 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import os +import tempfile +import unittest +from unittest import skipUnless + +import torch +from parameterized import parameterized + +from monai.apps import download_url +from monai.networks import eval_mode +from monai.networks.nets import DiffusionModelUNet +from monai.utils import optional_import +from tests.utils import skip_if_downloading_fails, testing_data_config + +_, has_einops = optional_import("einops") + +UNCOND_CASES_2D = [ + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, False), + "norm_num_groups": 8, + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": (1, 1, 2), + "channels": (8, 8, 8), + "attention_levels": (False, False, False), + "norm_num_groups": 8, + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, False), + "norm_num_groups": 8, + "resblock_updown": True, + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 8, + "norm_num_groups": 8, + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 8, + "norm_num_groups": 8, + "resblock_updown": True, + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 4, + "norm_num_groups": 8, + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, True, True), + "num_head_channels": (0, 2, 4), + "norm_num_groups": 8, + } + ], +] + +UNCOND_CASES_3D = [ + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, False), + "norm_num_groups": 8, + } + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, False), + "norm_num_groups": 8, + "resblock_updown": True, + } + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 8, + "norm_num_groups": 8, + } + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 8, + "norm_num_groups": 8, + "resblock_updown": True, + } + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 4, + "norm_num_groups": 8, + } + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": (0, 0, 4), + "norm_num_groups": 8, + } + ], +] + +COND_CASES_2D = [ + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 4, + "norm_num_groups": 8, + "with_conditioning": True, + "transformer_num_layers": 1, + "cross_attention_dim": 3, + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 4, + "norm_num_groups": 8, + "with_conditioning": True, + "transformer_num_layers": 1, + "cross_attention_dim": 3, + "resblock_updown": True, + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 4, + "norm_num_groups": 8, + "with_conditioning": True, + "transformer_num_layers": 1, + "cross_attention_dim": 3, + "upcast_attention": True, + } + ], +] + +DROPOUT_OK = [ + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 4, + "norm_num_groups": 8, + "with_conditioning": True, + "transformer_num_layers": 1, + "cross_attention_dim": 3, + "dropout_cattn": 0.25, + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 4, + "norm_num_groups": 8, + "with_conditioning": True, + "transformer_num_layers": 1, + "cross_attention_dim": 3, + } + ], +] + +DROPOUT_WRONG = [ + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 4, + "norm_num_groups": 8, + "with_conditioning": True, + "transformer_num_layers": 1, + "cross_attention_dim": 3, + "dropout_cattn": 3.0, + } + ] +] + + +class TestDiffusionModelUNet2D(unittest.TestCase): + @parameterized.expand(UNCOND_CASES_2D) + @skipUnless(has_einops, "Requires einops") + def test_shape_unconditioned_models(self, input_param): + net = DiffusionModelUNet(**input_param) + with eval_mode(net): + result = net.forward(torch.rand((1, 1, 16, 16)), torch.randint(0, 1000, (1,)).long()) + self.assertEqual(result.shape, (1, 1, 16, 16)) + + @skipUnless(has_einops, "Requires einops") + def test_timestep_with_wrong_shape(self): + net = DiffusionModelUNet( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_res_blocks=1, + channels=(8, 8, 8), + attention_levels=(False, False, False), + norm_num_groups=8, + ) + with self.assertRaises(ValueError): + with eval_mode(net): + net.forward(torch.rand((1, 1, 16, 16)), torch.randint(0, 1000, (1, 1)).long()) + + @skipUnless(has_einops, "Requires einops") + def test_shape_with_different_in_channel_out_channel(self): + in_channels = 6 + out_channels = 3 + net = DiffusionModelUNet( + spatial_dims=2, + in_channels=in_channels, + out_channels=out_channels, + num_res_blocks=1, + channels=(8, 8, 8), + attention_levels=(False, False, False), + norm_num_groups=8, + ) + with eval_mode(net): + result = net.forward(torch.rand((1, in_channels, 16, 16)), torch.randint(0, 1000, (1,)).long()) + self.assertEqual(result.shape, (1, out_channels, 16, 16)) + + def test_model_channels_not_multiple_of_norm_num_group(self): + with self.assertRaises(ValueError): + DiffusionModelUNet( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_res_blocks=1, + channels=(8, 8, 12), + attention_levels=(False, False, False), + norm_num_groups=8, + ) + + def test_attention_levels_with_different_length_num_head_channels(self): + with self.assertRaises(ValueError): + DiffusionModelUNet( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_res_blocks=1, + channels=(8, 8, 8), + attention_levels=(False, False, False), + num_head_channels=(0, 2), + norm_num_groups=8, + ) + + def test_num_res_blocks_with_different_length_channels(self): + with self.assertRaises(ValueError): + DiffusionModelUNet( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_res_blocks=(1, 1), + channels=(8, 8, 8), + attention_levels=(False, False, False), + norm_num_groups=8, + ) + + @skipUnless(has_einops, "Requires einops") + def test_shape_conditioned_models(self): + net = DiffusionModelUNet( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_res_blocks=1, + channels=(8, 8, 8), + attention_levels=(False, False, True), + with_conditioning=True, + transformer_num_layers=1, + cross_attention_dim=3, + norm_num_groups=8, + num_head_channels=8, + ) + with eval_mode(net): + result = net.forward( + x=torch.rand((1, 1, 16, 32)), + timesteps=torch.randint(0, 1000, (1,)).long(), + context=torch.rand((1, 1, 3)), + ) + self.assertEqual(result.shape, (1, 1, 16, 32)) + + def test_with_conditioning_cross_attention_dim_none(self): + with self.assertRaises(ValueError): + DiffusionModelUNet( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_res_blocks=1, + channels=(8, 8, 8), + attention_levels=(False, False, True), + with_conditioning=True, + transformer_num_layers=1, + cross_attention_dim=None, + norm_num_groups=8, + ) + + @skipUnless(has_einops, "Requires einops") + def test_context_with_conditioning_none(self): + net = DiffusionModelUNet( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_res_blocks=1, + channels=(8, 8, 8), + attention_levels=(False, False, True), + with_conditioning=False, + transformer_num_layers=1, + norm_num_groups=8, + ) + + with self.assertRaises(ValueError): + with eval_mode(net): + net.forward( + x=torch.rand((1, 1, 16, 32)), + timesteps=torch.randint(0, 1000, (1,)).long(), + context=torch.rand((1, 1, 3)), + ) + + @skipUnless(has_einops, "Requires einops") + def test_shape_conditioned_models_class_conditioning(self): + net = DiffusionModelUNet( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_res_blocks=1, + channels=(8, 8, 8), + attention_levels=(False, False, True), + norm_num_groups=8, + num_head_channels=8, + num_class_embeds=2, + ) + with eval_mode(net): + result = net.forward( + x=torch.rand((1, 1, 16, 32)), + timesteps=torch.randint(0, 1000, (1,)).long(), + class_labels=torch.randint(0, 2, (1,)).long(), + ) + self.assertEqual(result.shape, (1, 1, 16, 32)) + + @skipUnless(has_einops, "Requires einops") + def test_conditioned_models_no_class_labels(self): + net = DiffusionModelUNet( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_res_blocks=1, + channels=(8, 8, 8), + attention_levels=(False, False, True), + norm_num_groups=8, + num_head_channels=8, + num_class_embeds=2, + ) + + with self.assertRaises(ValueError): + net.forward(x=torch.rand((1, 1, 16, 32)), timesteps=torch.randint(0, 1000, (1,)).long()) + + @skipUnless(has_einops, "Requires einops") + def test_model_channels_not_same_size_of_attention_levels(self): + with self.assertRaises(ValueError): + DiffusionModelUNet( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_res_blocks=1, + channels=(8, 8, 8), + attention_levels=(False, False), + norm_num_groups=8, + num_head_channels=8, + num_class_embeds=2, + ) + + @parameterized.expand(COND_CASES_2D) + @skipUnless(has_einops, "Requires einops") + def test_conditioned_2d_models_shape(self, input_param): + net = DiffusionModelUNet(**input_param) + with eval_mode(net): + result = net.forward(torch.rand((1, 1, 16, 16)), torch.randint(0, 1000, (1,)).long(), torch.rand((1, 1, 3))) + self.assertEqual(result.shape, (1, 1, 16, 16)) + + +class TestDiffusionModelUNet3D(unittest.TestCase): + @parameterized.expand(UNCOND_CASES_3D) + @skipUnless(has_einops, "Requires einops") + def test_shape_unconditioned_models(self, input_param): + net = DiffusionModelUNet(**input_param) + with eval_mode(net): + result = net.forward(torch.rand((1, 1, 16, 16, 16)), torch.randint(0, 1000, (1,)).long()) + self.assertEqual(result.shape, (1, 1, 16, 16, 16)) + + @skipUnless(has_einops, "Requires einops") + def test_shape_with_different_in_channel_out_channel(self): + in_channels = 6 + out_channels = 3 + net = DiffusionModelUNet( + spatial_dims=3, + in_channels=in_channels, + out_channels=out_channels, + num_res_blocks=1, + channels=(8, 8, 8), + attention_levels=(False, False, True), + norm_num_groups=4, + ) + with eval_mode(net): + result = net.forward(torch.rand((1, in_channels, 16, 16, 16)), torch.randint(0, 1000, (1,)).long()) + self.assertEqual(result.shape, (1, out_channels, 16, 16, 16)) + + @skipUnless(has_einops, "Requires einops") + def test_shape_conditioned_models(self): + net = DiffusionModelUNet( + spatial_dims=3, + in_channels=1, + out_channels=1, + num_res_blocks=1, + channels=(16, 16, 16), + attention_levels=(False, False, True), + norm_num_groups=16, + with_conditioning=True, + transformer_num_layers=1, + cross_attention_dim=3, + ) + with eval_mode(net): + result = net.forward( + x=torch.rand((1, 1, 16, 16, 16)), + timesteps=torch.randint(0, 1000, (1,)).long(), + context=torch.rand((1, 1, 3)), + ) + self.assertEqual(result.shape, (1, 1, 16, 16, 16)) + + # Test dropout specification for cross-attention blocks + @parameterized.expand(DROPOUT_WRONG) + def test_wrong_dropout(self, input_param): + with self.assertRaises(ValueError): + _ = DiffusionModelUNet(**input_param) + + @parameterized.expand(DROPOUT_OK) + @skipUnless(has_einops, "Requires einops") + def test_right_dropout(self, input_param): + _ = DiffusionModelUNet(**input_param) + + @skipUnless(has_einops, "Requires einops") + def test_compatibility_with_monai_generative(self): + # test loading weights from a model saved in MONAI Generative, version 0.2.3 + with skip_if_downloading_fails(): + net = DiffusionModelUNet( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_res_blocks=1, + channels=(8, 8, 8), + attention_levels=(False, False, True), + with_conditioning=True, + cross_attention_dim=3, + transformer_num_layers=1, + norm_num_groups=8, + ) + + tmpdir = tempfile.mkdtemp() + key = "diffusion_model_unet_monai_generative_weights" + url = testing_data_config("models", key, "url") + hash_type = testing_data_config("models", key, "hash_type") + hash_val = testing_data_config("models", key, "hash_val") + filename = "diffusion_model_unet_monai_generative_weights.pt" + + weight_path = os.path.join(tmpdir, filename) + download_url(url=url, filepath=weight_path, hash_val=hash_val, hash_type=hash_type) + + net.load_old_state_dict(torch.load(weight_path), verbose=False) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_ensure_channel_first.py b/tests/test_ensure_channel_first.py index 0c9ad5869e..fe046a4cdf 100644 --- a/tests/test_ensure_channel_first.py +++ b/tests/test_ensure_channel_first.py @@ -50,9 +50,10 @@ class TestEnsureChannelFirst(unittest.TestCase): @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6]) @unittest.skipUnless(has_itk, "itk not installed") def test_load_nifti(self, input_param, filenames, original_channel_dim): - if original_channel_dim is None: - test_image = np.random.rand(8, 8, 8) - elif original_channel_dim == -1: + # if original_channel_dim is None + test_image = np.random.rand(8, 8, 8) + + if original_channel_dim == -1: test_image = np.random.rand(8, 8, 8, 1) with tempfile.TemporaryDirectory() as tempdir: diff --git a/tests/test_ensure_channel_firstd.py b/tests/test_ensure_channel_firstd.py index 63a437894b..e9effad951 100644 --- a/tests/test_ensure_channel_firstd.py +++ b/tests/test_ensure_channel_firstd.py @@ -35,9 +35,10 @@ class TestEnsureChannelFirstd(unittest.TestCase): @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) def test_load_nifti(self, input_param, filenames, original_channel_dim): - if original_channel_dim is None: - test_image = np.random.rand(8, 8, 8) - elif original_channel_dim == -1: + # if original_channel_dim is None: + test_image = np.random.rand(8, 8, 8) + + if original_channel_dim == -1: test_image = np.random.rand(8, 8, 8, 1) with tempfile.TemporaryDirectory() as tempdir: diff --git a/tests/test_evenly_divisible_all_gather_dist.py b/tests/test_evenly_divisible_all_gather_dist.py index d6d26c7e23..f1d45ba48f 100644 --- a/tests/test_evenly_divisible_all_gather_dist.py +++ b/tests/test_evenly_divisible_all_gather_dist.py @@ -27,10 +27,10 @@ def test_data(self): self._run() def _run(self): - if dist.get_rank() == 0: - data1 = torch.tensor([[1, 2], [3, 4]]) - data2 = torch.tensor([[1.0, 2.0]]) - data3 = torch.tensor(7) + # if dist.get_rank() == 0 + data1 = torch.tensor([[1, 2], [3, 4]]) + data2 = torch.tensor([[1.0, 2.0]]) + data3 = torch.tensor(7) if dist.get_rank() == 1: data1 = torch.tensor([[5, 6]]) diff --git a/tests/test_handler_metrics_saver_dist.py b/tests/test_handler_metrics_saver_dist.py index 46c9ad27d7..2e12b08aa9 100644 --- a/tests/test_handler_metrics_saver_dist.py +++ b/tests/test_handler_metrics_saver_dist.py @@ -51,8 +51,10 @@ def _val_func(engine, batch): engine = Engine(_val_func) + # define here to ensure symbol always exists regardless of the following if conditions + data = [{PostFix.meta("image"): {"filename_or_obj": [fnames[0]]}}] + if my_rank == 0: - data = [{PostFix.meta("image"): {"filename_or_obj": [fnames[0]]}}] @engine.on(Events.EPOCH_COMPLETED) def _save_metrics0(engine): diff --git a/tests/test_hilbert_transform.py b/tests/test_hilbert_transform.py index 879a74969d..b91ba3f6b7 100644 --- a/tests/test_hilbert_transform.py +++ b/tests/test_hilbert_transform.py @@ -19,11 +19,11 @@ from monai.networks.layers import HilbertTransform from monai.utils import OptionalImportError -from tests.utils import SkipIfModule, SkipIfNoModule, skip_if_no_cuda +from tests.utils import SkipIfModule, SkipIfNoModule def create_expected_numpy_output(input_datum, **kwargs): - x = np.fft.fft(input_datum.cpu().numpy() if input_datum.device.type == "cuda" else input_datum.numpy(), **kwargs) + x = np.fft.fft(input_datum.cpu().numpy(), **kwargs) f = np.fft.fftfreq(x.shape[kwargs["axis"]]) u = np.heaviside(f, 0.5) new_dims_before = kwargs["axis"] @@ -44,19 +44,15 @@ def create_expected_numpy_output(input_datum, **kwargs): # CPU TEST DATA cpu_input_data = {} -cpu_input_data["1D"] = torch.as_tensor(hann_windowed_sine, device=cpu).unsqueeze(0).unsqueeze(0) -cpu_input_data["2D"] = ( - torch.as_tensor(np.stack([hann_windowed_sine] * 10, axis=1), device=cpu).unsqueeze(0).unsqueeze(0) -) -cpu_input_data["3D"] = ( - torch.as_tensor(np.stack([np.stack([hann_windowed_sine] * 10, axis=1)] * 10, axis=2), device=cpu) - .unsqueeze(0) - .unsqueeze(0) -) -cpu_input_data["1D 2CH"] = torch.as_tensor(np.stack([hann_windowed_sine] * 10, axis=1), device=cpu).unsqueeze(0) +cpu_input_data["1D"] = torch.as_tensor(hann_windowed_sine, device=cpu)[None, None] +cpu_input_data["2D"] = torch.as_tensor(np.stack([hann_windowed_sine] * 10, axis=1), device=cpu)[None, None] +cpu_input_data["3D"] = torch.as_tensor( + np.stack([np.stack([hann_windowed_sine] * 10, axis=1)] * 10, axis=2), device=cpu +)[None, None] +cpu_input_data["1D 2CH"] = torch.as_tensor(np.stack([hann_windowed_sine] * 10, axis=1), device=cpu)[None] cpu_input_data["2D 2CH"] = torch.as_tensor( np.stack([np.stack([hann_windowed_sine] * 10, axis=1)] * 10, axis=2), device=cpu -).unsqueeze(0) +)[None] # SINGLE-CHANNEL CPU VALUE TESTS @@ -97,64 +93,21 @@ def create_expected_numpy_output(input_datum, **kwargs): 1e-5, # absolute tolerance ] +TEST_CASES_CPU = [ + TEST_CASE_1D_SINE_CPU, + TEST_CASE_2D_SINE_CPU, + TEST_CASE_3D_SINE_CPU, + TEST_CASE_1D_2CH_SINE_CPU, + TEST_CASE_2D_2CH_SINE_CPU, +] + # GPU TEST DATA if torch.cuda.is_available(): gpu = torch.device("cuda") - - gpu_input_data = {} - gpu_input_data["1D"] = torch.as_tensor(hann_windowed_sine, device=gpu).unsqueeze(0).unsqueeze(0) - gpu_input_data["2D"] = ( - torch.as_tensor(np.stack([hann_windowed_sine] * 10, axis=1), device=gpu).unsqueeze(0).unsqueeze(0) - ) - gpu_input_data["3D"] = ( - torch.as_tensor(np.stack([np.stack([hann_windowed_sine] * 10, axis=1)] * 10, axis=2), device=gpu) - .unsqueeze(0) - .unsqueeze(0) - ) - gpu_input_data["1D 2CH"] = torch.as_tensor(np.stack([hann_windowed_sine] * 10, axis=1), device=gpu).unsqueeze(0) - gpu_input_data["2D 2CH"] = torch.as_tensor( - np.stack([np.stack([hann_windowed_sine] * 10, axis=1)] * 10, axis=2), device=gpu - ).unsqueeze(0) - - # SINGLE CHANNEL GPU VALUE TESTS - - TEST_CASE_1D_SINE_GPU = [ - {}, # args (empty, so use default) - gpu_input_data["1D"], # Input data: Random 1D signal - create_expected_numpy_output(gpu_input_data["1D"], axis=2), # Expected output: FFT of signal - 1e-5, # absolute tolerance - ] - - TEST_CASE_2D_SINE_GPU = [ - {}, # args (empty, so use default) - gpu_input_data["2D"], # Input data: Random 1D signal - create_expected_numpy_output(gpu_input_data["2D"], axis=2), # Expected output: FFT of signal - 1e-5, # absolute tolerance - ] - - TEST_CASE_3D_SINE_GPU = [ - {}, # args (empty, so use default) - gpu_input_data["3D"], # Input data: Random 1D signal - create_expected_numpy_output(gpu_input_data["3D"], axis=2), # Expected output: FFT of signal - 1e-5, # absolute tolerance - ] - - # MULTICHANNEL GPU VALUE TESTS, PROCESS ALONG FIRST SPATIAL AXIS - - TEST_CASE_1D_2CH_SINE_GPU = [ - {}, # args (empty, so use default) - gpu_input_data["1D 2CH"], # Input data: Random 1D signal - create_expected_numpy_output(gpu_input_data["1D 2CH"], axis=2), - 1e-5, # absolute tolerance - ] - - TEST_CASE_2D_2CH_SINE_GPU = [ - {}, # args (empty, so use default) - gpu_input_data["2D 2CH"], # Input data: Random 1D signal - create_expected_numpy_output(gpu_input_data["2D 2CH"], axis=2), - 1e-5, # absolute tolerance - ] + TEST_CASES_GPU = [[args, image.to(gpu), exp_data, atol] for args, image, exp_data, atol in TEST_CASES_CPU] +else: + TEST_CASES_GPU = [] # TESTS CHECKING PADDING, AXIS SELECTION ETC ARE COVERED BY test_detect_envelope.py @@ -162,42 +115,10 @@ def create_expected_numpy_output(input_datum, **kwargs): @SkipIfNoModule("torch.fft") class TestHilbertTransformCPU(unittest.TestCase): - @parameterized.expand( - [ - TEST_CASE_1D_SINE_CPU, - TEST_CASE_2D_SINE_CPU, - TEST_CASE_3D_SINE_CPU, - TEST_CASE_1D_2CH_SINE_CPU, - TEST_CASE_2D_2CH_SINE_CPU, - ] - ) - def test_value(self, arguments, image, expected_data, atol): - result = HilbertTransform(**arguments)(image) - result = result.squeeze(0).squeeze(0).numpy() - np.testing.assert_allclose(result, expected_data.squeeze(), atol=atol) - - -@skip_if_no_cuda -@SkipIfNoModule("torch.fft") -class TestHilbertTransformGPU(unittest.TestCase): - - @parameterized.expand( - ( - [] - if not torch.cuda.is_available() - else [ - TEST_CASE_1D_SINE_GPU, - TEST_CASE_2D_SINE_GPU, - TEST_CASE_3D_SINE_GPU, - TEST_CASE_1D_2CH_SINE_GPU, - TEST_CASE_2D_2CH_SINE_GPU, - ] - ), - skip_on_empty=True, - ) + @parameterized.expand(TEST_CASES_CPU + TEST_CASES_GPU) def test_value(self, arguments, image, expected_data, atol): result = HilbertTransform(**arguments)(image) - result = result.squeeze(0).squeeze(0).cpu().numpy() + result = np.squeeze(result.cpu().numpy()) np.testing.assert_allclose(result, expected_data.squeeze(), atol=atol) diff --git a/tests/test_integration_unet_2d.py b/tests/test_integration_unet_2d.py index 918190775c..3b40682de0 100644 --- a/tests/test_integration_unet_2d.py +++ b/tests/test_integration_unet_2d.py @@ -35,6 +35,7 @@ def __getitem__(self, _unused_id): def __len__(self): return train_steps + net = None if net_name == "basicunet": net = BasicUNet(spatial_dims=2, in_channels=1, out_channels=1, features=(4, 8, 8, 16, 16, 32)) elif net_name == "unet": diff --git a/tests/test_integration_workflows_adversarial.py b/tests/test_integration_workflows_adversarial.py new file mode 100644 index 0000000000..f323fc9917 --- /dev/null +++ b/tests/test_integration_workflows_adversarial.py @@ -0,0 +1,173 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import os +import shutil +import tempfile +import unittest +from glob import glob + +import numpy as np +import torch + +import monai +from monai.data import create_test_image_2d +from monai.engines import AdversarialTrainer +from monai.handlers import CheckpointSaver, StatsHandler, TensorBoardStatsHandler +from monai.networks.nets import AutoEncoder, Discriminator +from monai.transforms import Compose, EnsureChannelFirstd, LoadImaged, RandFlipd, ScaleIntensityd +from monai.utils import AdversarialKeys as Keys +from monai.utils import CommonKeys, optional_import, set_determinism +from tests.utils import DistTestCase, TimedCall, skip_if_quick + +nib, has_nibabel = optional_import("nibabel") + + +def run_training_test(root_dir, device="cuda:0"): + learning_rate = 2e-4 + real_label = 1 + fake_label = 0 + + real_images = sorted(glob(os.path.join(root_dir, "img*.nii.gz"))) + train_files = [{CommonKeys.IMAGE: img, CommonKeys.LABEL: img} for img in zip(real_images)] + + # prepare real data + train_transforms = Compose( + [ + LoadImaged(keys=[CommonKeys.IMAGE, CommonKeys.LABEL]), + EnsureChannelFirstd(keys=[CommonKeys.IMAGE, CommonKeys.LABEL], channel_dim=2), + ScaleIntensityd(keys=[CommonKeys.IMAGE]), + RandFlipd(keys=[CommonKeys.IMAGE, CommonKeys.LABEL], prob=0.5), + ] + ) + train_ds = monai.data.CacheDataset(data=train_files, transform=train_transforms, cache_rate=0.5) + train_loader = monai.data.DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4) + + # Create Discriminator + discriminator_net = Discriminator( + in_shape=(1, 64, 64), channels=(8, 16, 32, 64, 1), strides=(2, 2, 2, 2, 1), num_res_units=1, kernel_size=5 + ).to(device) + discriminator_opt = torch.optim.Adam(discriminator_net.parameters(), learning_rate) + discriminator_loss_criterion = torch.nn.BCELoss() + + def discriminator_loss(real_logits, fake_logits): + real_target = real_logits.new_full((real_logits.shape[0], 1), real_label) + fake_target = fake_logits.new_full((fake_logits.shape[0], 1), fake_label) + real_loss = discriminator_loss_criterion(real_logits, real_target) + fake_loss = discriminator_loss_criterion(fake_logits.detach(), fake_target) + return torch.div(torch.add(real_loss, fake_loss), 2) + + # Create Generator + generator_network = AutoEncoder( + spatial_dims=2, + in_channels=1, + out_channels=1, + channels=(8, 16, 32, 64), + strides=(2, 2, 2, 2), + num_res_units=1, + num_inter_units=1, + ) + generator_network = generator_network.to(device) + generator_optimiser = torch.optim.Adam(generator_network.parameters(), learning_rate) + generator_loss_criterion = torch.nn.MSELoss() + + def reconstruction_loss(recon_images, real_images): + return generator_loss_criterion(recon_images, real_images) + + def generator_loss(fake_logits): + fake_target = fake_logits.new_full((fake_logits.shape[0], 1), real_label) + recon_loss = discriminator_loss_criterion(fake_logits.detach(), fake_target) + return recon_loss + + key_train_metric = None + + train_handlers = [ + StatsHandler( + name="training_loss", + output_transform=lambda x: { + Keys.RECONSTRUCTION_LOSS: x[Keys.RECONSTRUCTION_LOSS], + Keys.DISCRIMINATOR_LOSS: x[Keys.DISCRIMINATOR_LOSS], + Keys.GENERATOR_LOSS: x[Keys.GENERATOR_LOSS], + }, + ), + TensorBoardStatsHandler( + log_dir=root_dir, + tag_name="training_loss", + output_transform=lambda x: { + Keys.RECONSTRUCTION_LOSS: x[Keys.RECONSTRUCTION_LOSS], + Keys.DISCRIMINATOR_LOSS: x[Keys.DISCRIMINATOR_LOSS], + Keys.GENERATOR_LOSS: x[Keys.GENERATOR_LOSS], + }, + ), + CheckpointSaver( + save_dir=root_dir, + save_dict={"g_net": generator_network, "d_net": discriminator_net}, + save_interval=2, + epoch_level=True, + ), + ] + + num_epochs = 5 + + trainer = AdversarialTrainer( + device=device, + max_epochs=num_epochs, + train_data_loader=train_loader, + g_network=generator_network, + g_optimizer=generator_optimiser, + g_loss_function=generator_loss, + recon_loss_function=reconstruction_loss, + d_network=discriminator_net, + d_optimizer=discriminator_opt, + d_loss_function=discriminator_loss, + non_blocking=True, + key_train_metric=key_train_metric, + train_handlers=train_handlers, + ) + trainer.run() + + return trainer.state + + +@skip_if_quick +@unittest.skipUnless(has_nibabel, "Requires nibabel library.") +class IntegrationWorkflowsAdversarialTrainer(DistTestCase): + def setUp(self): + set_determinism(seed=0) + + self.data_dir = tempfile.mkdtemp() + for i in range(40): + im, _ = create_test_image_2d(64, 64, num_objs=3, rad_max=14, num_seg_classes=1, channel_dim=-1) + n = nib.Nifti1Image(im, np.eye(4)) + nib.save(n, os.path.join(self.data_dir, f"img{i:d}.nii.gz")) + + self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu:0") + monai.config.print_config() + + def tearDown(self): + set_determinism(seed=None) + shutil.rmtree(self.data_dir) + + @TimedCall(seconds=200, daemon=False) + def test_training(self): + torch.manual_seed(0) + + finish_state = run_training_test(self.data_dir, device=self.device) + + # Assert AdversarialTrainer training finished + self.assertEqual(finish_state.iteration, 100) + self.assertEqual(finish_state.epoch, 5) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_latent_diffusion_inferer.py b/tests/test_latent_diffusion_inferer.py new file mode 100644 index 0000000000..2e04ad6c5c --- /dev/null +++ b/tests/test_latent_diffusion_inferer.py @@ -0,0 +1,824 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest +from unittest import skipUnless + +import torch +from parameterized import parameterized + +from monai.inferers import LatentDiffusionInferer +from monai.networks.nets import VQVAE, AutoencoderKL, DiffusionModelUNet, SPADEAutoencoderKL, SPADEDiffusionModelUNet +from monai.networks.schedulers import DDPMScheduler +from monai.utils import optional_import + +_, has_einops = optional_import("einops") +TEST_CASES = [ + [ + "AutoencoderKL", + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4), + "latent_channels": 3, + "attention_levels": [False, False], + "num_res_blocks": 1, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, + "norm_num_groups": 4, + }, + "DiffusionModelUNet", + { + "spatial_dims": 2, + "in_channels": 3, + "out_channels": 3, + "channels": [4, 4], + "norm_num_groups": 4, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 4, + }, + (1, 1, 8, 8), + (1, 3, 4, 4), + ], + [ + "VQVAE", + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": [4, 4], + "num_res_layers": 1, + "num_res_channels": [4, 4], + "downsample_parameters": ((2, 4, 1, 1), (2, 4, 1, 1)), + "upsample_parameters": ((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)), + "num_embeddings": 16, + "embedding_dim": 3, + }, + "DiffusionModelUNet", + { + "spatial_dims": 2, + "in_channels": 3, + "out_channels": 3, + "channels": [8, 8], + "norm_num_groups": 8, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 8, + }, + (1, 1, 16, 16), + (1, 3, 4, 4), + ], + [ + "VQVAE", + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "channels": [4, 4], + "num_res_layers": 1, + "num_res_channels": [4, 4], + "downsample_parameters": ((2, 4, 1, 1), (2, 4, 1, 1)), + "upsample_parameters": ((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)), + "num_embeddings": 16, + "embedding_dim": 3, + }, + "DiffusionModelUNet", + { + "spatial_dims": 3, + "in_channels": 3, + "out_channels": 3, + "channels": [8, 8], + "norm_num_groups": 8, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 8, + }, + (1, 1, 16, 16, 16), + (1, 3, 4, 4, 4), + ], + [ + "AutoencoderKL", + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4), + "latent_channels": 3, + "attention_levels": [False, False], + "num_res_blocks": 1, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, + "norm_num_groups": 4, + }, + "SPADEDiffusionModelUNet", + { + "spatial_dims": 2, + "label_nc": 3, + "in_channels": 3, + "out_channels": 3, + "channels": [4, 4], + "norm_num_groups": 4, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 4, + }, + (1, 1, 8, 8), + (1, 3, 4, 4), + ], +] +TEST_CASES_DIFF_SHAPES = [ + [ + "AutoencoderKL", + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4), + "latent_channels": 3, + "attention_levels": [False, False], + "num_res_blocks": 1, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, + "norm_num_groups": 4, + }, + "DiffusionModelUNet", + { + "spatial_dims": 2, + "in_channels": 3, + "out_channels": 3, + "channels": [4, 4], + "norm_num_groups": 4, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 4, + }, + (1, 1, 12, 12), + (1, 3, 8, 8), + ], + [ + "VQVAE", + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": [4, 4], + "num_res_layers": 1, + "num_res_channels": [4, 4], + "downsample_parameters": ((2, 4, 1, 1), (2, 4, 1, 1)), + "upsample_parameters": ((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)), + "num_embeddings": 16, + "embedding_dim": 3, + }, + "DiffusionModelUNet", + { + "spatial_dims": 2, + "in_channels": 3, + "out_channels": 3, + "channels": [8, 8], + "norm_num_groups": 8, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 8, + }, + (1, 1, 12, 12), + (1, 3, 8, 8), + ], + [ + "VQVAE", + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "channels": [4, 4], + "num_res_layers": 1, + "num_res_channels": [4, 4], + "downsample_parameters": ((2, 4, 1, 1), (2, 4, 1, 1)), + "upsample_parameters": ((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)), + "num_embeddings": 16, + "embedding_dim": 3, + }, + "DiffusionModelUNet", + { + "spatial_dims": 3, + "in_channels": 3, + "out_channels": 3, + "channels": [8, 8], + "norm_num_groups": 8, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 8, + }, + (1, 1, 12, 12, 12), + (1, 3, 8, 8, 8), + ], + [ + "SPADEAutoencoderKL", + { + "spatial_dims": 2, + "label_nc": 3, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4), + "latent_channels": 3, + "attention_levels": [False, False], + "num_res_blocks": 1, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, + "norm_num_groups": 4, + }, + "DiffusionModelUNet", + { + "spatial_dims": 2, + "in_channels": 3, + "out_channels": 3, + "channels": [4, 4], + "norm_num_groups": 4, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 4, + }, + (1, 1, 8, 8), + (1, 3, 4, 4), + ], + [ + "AutoencoderKL", + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4), + "latent_channels": 3, + "attention_levels": [False, False], + "num_res_blocks": 1, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, + "norm_num_groups": 4, + }, + "SPADEDiffusionModelUNet", + { + "spatial_dims": 2, + "label_nc": 3, + "in_channels": 3, + "out_channels": 3, + "channels": [4, 4], + "norm_num_groups": 4, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 4, + }, + (1, 1, 8, 8), + (1, 3, 4, 4), + ], + [ + "SPADEAutoencoderKL", + { + "spatial_dims": 2, + "label_nc": 3, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4), + "latent_channels": 3, + "attention_levels": [False, False], + "num_res_blocks": 1, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, + "norm_num_groups": 4, + }, + "SPADEDiffusionModelUNet", + { + "spatial_dims": 2, + "label_nc": 3, + "in_channels": 3, + "out_channels": 3, + "channels": [4, 4], + "norm_num_groups": 4, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 4, + }, + (1, 1, 8, 8), + (1, 3, 4, 4), + ], +] + + +class TestDiffusionSamplingInferer(unittest.TestCase): + @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") + def test_prediction_shape( + self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape + ): + stage_1 = None + + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if ae_model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + if dm_model_type == "SPADEDiffusionModelUNet": + stage_2 = SPADEDiffusionModelUNet(**stage_2_params) + else: + stage_2 = DiffusionModelUNet(**stage_2_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + + input = torch.randn(input_shape).to(device) + noise = torch.randn(latent_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long() + + if dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + prediction = inferer( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + seg=input_seg, + noise=noise, + timesteps=timesteps, + ) + else: + prediction = inferer( + inputs=input, autoencoder_model=stage_1, diffusion_model=stage_2, noise=noise, timesteps=timesteps + ) + self.assertEqual(prediction.shape, latent_shape) + + @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") + def test_sample_shape( + self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape + ): + stage_1 = None + + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if ae_model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + if dm_model_type == "SPADEDiffusionModelUNet": + stage_2 = SPADEDiffusionModelUNet(**stage_2_params) + else: + stage_2 = DiffusionModelUNet(**stage_2_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + + noise = torch.randn(latent_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + + if ae_model_type == "SPADEAutoencoderKL" or dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + sample = inferer.sample( + input_noise=noise, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + seg=input_seg, + ) + else: + sample = inferer.sample( + input_noise=noise, autoencoder_model=stage_1, diffusion_model=stage_2, scheduler=scheduler + ) + self.assertEqual(sample.shape, input_shape) + + @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") + def test_sample_intermediates( + self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape + ): + stage_1 = None + + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if ae_model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + if ae_model_type == "SPADEAutoencoderKL": + stage_1 = SPADEAutoencoderKL(**autoencoder_params) + if dm_model_type == "SPADEDiffusionModelUNet": + stage_2 = SPADEDiffusionModelUNet(**stage_2_params) + else: + stage_2 = DiffusionModelUNet(**stage_2_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + + noise = torch.randn(latent_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + + if ae_model_type == "SPADEAutoencoderKL" or dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + sample, intermediates = inferer.sample( + input_noise=noise, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + seg=input_seg, + save_intermediates=True, + intermediate_steps=1, + ) + else: + sample, intermediates = inferer.sample( + input_noise=noise, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + save_intermediates=True, + intermediate_steps=1, + ) + self.assertEqual(len(intermediates), 10) + self.assertEqual(intermediates[0].shape, input_shape) + + @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") + def test_get_likelihoods( + self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape + ): + stage_1 = None + + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if ae_model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + if ae_model_type == "SPADEAutoencoderKL": + stage_1 = SPADEAutoencoderKL(**autoencoder_params) + if dm_model_type == "SPADEDiffusionModelUNet": + stage_2 = SPADEDiffusionModelUNet(**stage_2_params) + else: + stage_2 = DiffusionModelUNet(**stage_2_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + + input = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + + if dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + sample, intermediates = inferer.get_likelihood( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + save_intermediates=True, + seg=input_seg, + ) + else: + sample, intermediates = inferer.get_likelihood( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + save_intermediates=True, + ) + self.assertEqual(len(intermediates), 10) + self.assertEqual(intermediates[0].shape, latent_shape) + + @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") + def test_resample_likelihoods( + self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape + ): + stage_1 = None + + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if ae_model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + if ae_model_type == "SPADEAutoencoderKL": + stage_1 = SPADEAutoencoderKL(**autoencoder_params) + if dm_model_type == "SPADEDiffusionModelUNet": + stage_2 = SPADEDiffusionModelUNet(**stage_2_params) + else: + stage_2 = DiffusionModelUNet(**stage_2_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + + input = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + + if dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + sample, intermediates = inferer.get_likelihood( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + save_intermediates=True, + resample_latent_likelihoods=True, + seg=input_seg, + ) + else: + sample, intermediates = inferer.get_likelihood( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + save_intermediates=True, + resample_latent_likelihoods=True, + ) + self.assertEqual(len(intermediates), 10) + self.assertEqual(intermediates[0].shape[2:], input_shape[2:]) + + @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") + def test_prediction_shape_conditioned_concat( + self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape + ): + stage_1 = None + + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if ae_model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + if ae_model_type == "SPADEAutoencoderKL": + stage_1 = SPADEAutoencoderKL(**autoencoder_params) + stage_2_params = stage_2_params.copy() + n_concat_channel = 3 + stage_2_params["in_channels"] = stage_2_params["in_channels"] + n_concat_channel + if dm_model_type == "SPADEDiffusionModelUNet": + stage_2 = SPADEDiffusionModelUNet(**stage_2_params) + else: + stage_2 = DiffusionModelUNet(**stage_2_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + + input = torch.randn(input_shape).to(device) + noise = torch.randn(latent_shape).to(device) + conditioning_shape = list(latent_shape) + conditioning_shape[1] = n_concat_channel + conditioning = torch.randn(conditioning_shape).to(device) + + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + + timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long() + + if dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + prediction = inferer( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + noise=noise, + timesteps=timesteps, + condition=conditioning, + mode="concat", + seg=input_seg, + ) + else: + prediction = inferer( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + noise=noise, + timesteps=timesteps, + condition=conditioning, + mode="concat", + ) + self.assertEqual(prediction.shape, latent_shape) + + @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") + def test_sample_shape_conditioned_concat( + self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape + ): + stage_1 = None + + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if ae_model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + if ae_model_type == "SPADEAutoencoderKL": + stage_1 = SPADEAutoencoderKL(**autoencoder_params) + stage_2_params = stage_2_params.copy() + n_concat_channel = 3 + stage_2_params["in_channels"] = stage_2_params["in_channels"] + n_concat_channel + if dm_model_type == "SPADEDiffusionModelUNet": + stage_2 = SPADEDiffusionModelUNet(**stage_2_params) + else: + stage_2 = DiffusionModelUNet(**stage_2_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + + noise = torch.randn(latent_shape).to(device) + conditioning_shape = list(latent_shape) + conditioning_shape[1] = n_concat_channel + conditioning = torch.randn(conditioning_shape).to(device) + + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + + if dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + sample = inferer.sample( + input_noise=noise, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + conditioning=conditioning, + mode="concat", + seg=input_seg, + ) + else: + sample = inferer.sample( + input_noise=noise, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + conditioning=conditioning, + mode="concat", + ) + self.assertEqual(sample.shape, input_shape) + + @parameterized.expand(TEST_CASES_DIFF_SHAPES) + @skipUnless(has_einops, "Requires einops") + def test_sample_shape_different_latents( + self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape + ): + stage_1 = None + + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if ae_model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + if ae_model_type == "SPADEAutoencoderKL": + stage_1 = SPADEAutoencoderKL(**autoencoder_params) + if dm_model_type == "SPADEDiffusionModelUNet": + stage_2 = SPADEDiffusionModelUNet(**stage_2_params) + else: + stage_2 = DiffusionModelUNet(**stage_2_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + + input = torch.randn(input_shape).to(device) + noise = torch.randn(latent_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + # We infer the VAE shape + autoencoder_latent_shape = [i // (2 ** (len(autoencoder_params["channels"]) - 1)) for i in input_shape[2:]] + inferer = LatentDiffusionInferer( + scheduler=scheduler, + scale_factor=1.0, + ldm_latent_shape=list(latent_shape[2:]), + autoencoder_latent_shape=autoencoder_latent_shape, + ) + scheduler.set_timesteps(num_inference_steps=10) + + timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long() + + if dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + prediction = inferer( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + noise=noise, + timesteps=timesteps, + seg=input_seg, + ) + else: + prediction = inferer( + inputs=input, autoencoder_model=stage_1, diffusion_model=stage_2, noise=noise, timesteps=timesteps + ) + self.assertEqual(prediction.shape, latent_shape) + + @skipUnless(has_einops, "Requires einops") + def test_incompatible_spade_setup(self): + stage_1 = SPADEAutoencoderKL( + spatial_dims=2, + label_nc=6, + in_channels=1, + out_channels=1, + channels=(4, 4), + latent_channels=3, + attention_levels=[False, False], + num_res_blocks=1, + with_encoder_nonlocal_attn=False, + with_decoder_nonlocal_attn=False, + norm_num_groups=4, + ) + stage_2 = SPADEDiffusionModelUNet( + spatial_dims=2, + label_nc=3, + in_channels=3, + out_channels=3, + channels=[4, 4], + norm_num_groups=4, + attention_levels=[False, False], + num_res_blocks=1, + num_head_channels=4, + ) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + noise = torch.randn((1, 3, 4, 4)).to(device) + input_seg = torch.randn((1, 3, 8, 8)).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + + with self.assertRaises(ValueError): + _ = inferer.sample( + input_noise=noise, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + seg=input_seg, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_ordering.py b/tests/test_ordering.py new file mode 100644 index 0000000000..e6b235e179 --- /dev/null +++ b/tests/test_ordering.py @@ -0,0 +1,289 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.utils.enums import OrderingTransformations, OrderingType +from monai.utils.ordering import Ordering + +TEST_2D_NON_RANDOM = [ + [ + { + "ordering_type": OrderingType.RASTER_SCAN, + "spatial_dims": 2, + "dimensions": (1, 2, 2), + "reflected_spatial_dims": (), + "transpositions_axes": (), + "rot90_axes": (), + "transformation_order": ( + OrderingTransformations.TRANSPOSE.value, + OrderingTransformations.ROTATE_90.value, + OrderingTransformations.REFLECT.value, + ), + }, + [0, 1, 2, 3], + ], + [ + { + "ordering_type": OrderingType.S_CURVE, + "spatial_dims": 2, + "dimensions": (1, 2, 2), + "reflected_spatial_dims": (), + "transpositions_axes": (), + "rot90_axes": (), + "transformation_order": ( + OrderingTransformations.TRANSPOSE.value, + OrderingTransformations.ROTATE_90.value, + OrderingTransformations.REFLECT.value, + ), + }, + [0, 1, 3, 2], + ], + [ + { + "ordering_type": OrderingType.RASTER_SCAN, + "spatial_dims": 2, + "dimensions": (1, 2, 2), + "reflected_spatial_dims": (True, False), + "transpositions_axes": (), + "rot90_axes": (), + "transformation_order": ( + OrderingTransformations.TRANSPOSE.value, + OrderingTransformations.ROTATE_90.value, + OrderingTransformations.REFLECT.value, + ), + }, + [2, 3, 0, 1], + ], + [ + { + "ordering_type": OrderingType.S_CURVE, + "spatial_dims": 2, + "dimensions": (1, 2, 2), + "reflected_spatial_dims": (True, False), + "transpositions_axes": (), + "rot90_axes": (), + "transformation_order": ( + OrderingTransformations.TRANSPOSE.value, + OrderingTransformations.ROTATE_90.value, + OrderingTransformations.REFLECT.value, + ), + }, + [2, 3, 1, 0], + ], + [ + { + "ordering_type": OrderingType.RASTER_SCAN, + "spatial_dims": 2, + "dimensions": (1, 2, 2), + "reflected_spatial_dims": (), + "transpositions_axes": ((1, 0),), + "rot90_axes": (), + "transformation_order": ( + OrderingTransformations.TRANSPOSE.value, + OrderingTransformations.ROTATE_90.value, + OrderingTransformations.REFLECT.value, + ), + }, + [0, 2, 1, 3], + ], + [ + { + "ordering_type": OrderingType.S_CURVE, + "spatial_dims": 2, + "dimensions": (1, 2, 2), + "reflected_spatial_dims": (), + "transpositions_axes": ((1, 0),), + "rot90_axes": (), + "transformation_order": ( + OrderingTransformations.TRANSPOSE.value, + OrderingTransformations.ROTATE_90.value, + OrderingTransformations.REFLECT.value, + ), + }, + [0, 2, 3, 1], + ], + [ + { + "ordering_type": OrderingType.RASTER_SCAN, + "spatial_dims": 2, + "dimensions": (1, 2, 2), + "reflected_spatial_dims": (), + "transpositions_axes": (), + "rot90_axes": ((0, 1),), + "transformation_order": ( + OrderingTransformations.TRANSPOSE.value, + OrderingTransformations.ROTATE_90.value, + OrderingTransformations.REFLECT.value, + ), + }, + [1, 3, 0, 2], + ], + [ + { + "ordering_type": OrderingType.S_CURVE, + "spatial_dims": 2, + "dimensions": (1, 2, 2), + "reflected_spatial_dims": (), + "transpositions_axes": (), + "rot90_axes": ((0, 1),), + "transformation_order": ( + OrderingTransformations.TRANSPOSE.value, + OrderingTransformations.ROTATE_90.value, + OrderingTransformations.REFLECT.value, + ), + }, + [1, 3, 2, 0], + ], + [ + { + "ordering_type": OrderingType.RASTER_SCAN, + "spatial_dims": 2, + "dimensions": (1, 2, 2), + "reflected_spatial_dims": (True, False), + "transpositions_axes": ((1, 0),), + "rot90_axes": ((0, 1),), + "transformation_order": ( + OrderingTransformations.TRANSPOSE.value, + OrderingTransformations.ROTATE_90.value, + OrderingTransformations.REFLECT.value, + ), + }, + [0, 1, 2, 3], + ], + [ + { + "ordering_type": OrderingType.S_CURVE, + "spatial_dims": 2, + "dimensions": (1, 2, 2), + "reflected_spatial_dims": (True, False), + "transpositions_axes": ((1, 0),), + "rot90_axes": ((0, 1),), + "transformation_order": ( + OrderingTransformations.TRANSPOSE.value, + OrderingTransformations.ROTATE_90.value, + OrderingTransformations.REFLECT.value, + ), + }, + [0, 1, 3, 2], + ], +] + + +TEST_3D = [ + [ + { + "ordering_type": OrderingType.RASTER_SCAN, + "spatial_dims": 3, + "dimensions": (1, 2, 2, 2), + "reflected_spatial_dims": (), + "transpositions_axes": (), + "rot90_axes": (), + "transformation_order": ( + OrderingTransformations.TRANSPOSE.value, + OrderingTransformations.ROTATE_90.value, + OrderingTransformations.REFLECT.value, + ), + }, + [0, 1, 2, 3, 4, 5, 6, 7], + ] +] + +TEST_ORDERING_TYPE_FAILURE = [ + [ + { + "ordering_type": "hilbert", + "spatial_dims": 2, + "dimensions": (1, 2, 2), + "reflected_spatial_dims": (True, False), + "transpositions_axes": ((1, 0),), + "rot90_axes": ((0, 1),), + "transformation_order": ( + OrderingTransformations.TRANSPOSE.value, + OrderingTransformations.ROTATE_90.value, + OrderingTransformations.REFLECT.value, + ), + } + ] +] + +TEST_ORDERING_TRANSFORMATION_FAILURE = [ + [ + { + "ordering_type": OrderingType.S_CURVE, + "spatial_dims": 2, + "dimensions": (1, 2, 2), + "reflected_spatial_dims": (True, False), + "transpositions_axes": ((1, 0),), + "rot90_axes": ((0, 1),), + "transformation_order": ( + OrderingTransformations.TRANSPOSE.value, + OrderingTransformations.ROTATE_90.value, + "flip", + ), + } + ] +] + +TEST_REVERT = [ + [ + { + "ordering_type": OrderingType.S_CURVE, + "spatial_dims": 2, + "dimensions": (1, 2, 2), + "reflected_spatial_dims": (True, False), + "transpositions_axes": (), + "rot90_axes": (), + "transformation_order": ( + OrderingTransformations.TRANSPOSE.value, + OrderingTransformations.ROTATE_90.value, + OrderingTransformations.REFLECT.value, + ), + } + ] +] + + +class TestOrdering(unittest.TestCase): + @parameterized.expand(TEST_2D_NON_RANDOM + TEST_3D) + def test_ordering(self, input_param, expected_sequence_ordering): + ordering = Ordering(**input_param) + self.assertTrue(np.array_equal(ordering.get_sequence_ordering(), expected_sequence_ordering, equal_nan=True)) + + @parameterized.expand(TEST_ORDERING_TYPE_FAILURE) + def test_ordering_type_failure(self, input_param): + with self.assertRaises(ValueError): + Ordering(**input_param) + + @parameterized.expand(TEST_ORDERING_TRANSFORMATION_FAILURE) + def test_ordering_transformation_failure(self, input_param): + with self.assertRaises(ValueError): + Ordering(**input_param) + + @parameterized.expand(TEST_REVERT) + def test_revert(self, input_param): + sequence = np.random.randint(0, 100, size=input_param["dimensions"]).flatten() + + ordering = Ordering(**input_param) + + reverted_sequence = sequence[ordering.get_sequence_ordering()] + reverted_sequence = reverted_sequence[ordering.get_revert_sequence_ordering()] + + self.assertTrue(np.array_equal(sequence, reverted_sequence, equal_nan=True)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_patch_gan_dicriminator.py b/tests/test_patch_gan_dicriminator.py new file mode 100644 index 0000000000..c19898e70d --- /dev/null +++ b/tests/test_patch_gan_dicriminator.py @@ -0,0 +1,179 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.nets import MultiScalePatchDiscriminator, PatchDiscriminator +from tests.utils import test_script_save + +TEST_PATCHGAN = [ + [ + { + "num_layers_d": 3, + "spatial_dims": 2, + "channels": 8, + "in_channels": 3, + "out_channels": 1, + "kernel_size": 3, + "activation": "LEAKYRELU", + "norm": "instance", + "bias": False, + "dropout": 0.1, + }, + torch.rand([1, 3, 256, 512]), + (1, 8, 128, 256), + (1, 1, 32, 64), + ], + [ + { + "num_layers_d": 3, + "spatial_dims": 3, + "channels": 8, + "in_channels": 3, + "out_channels": 1, + "kernel_size": 3, + "activation": "LEAKYRELU", + "norm": "instance", + "bias": False, + "dropout": 0.1, + }, + torch.rand([1, 3, 256, 512, 256]), + (1, 8, 128, 256, 128), + (1, 1, 32, 64, 32), + ], +] + +TEST_MULTISCALE_PATCHGAN = [ + [ + { + "num_d": 2, + "num_layers_d": 3, + "spatial_dims": 2, + "channels": 8, + "in_channels": 3, + "out_channels": 1, + "kernel_size": 3, + "activation": "LEAKYRELU", + "norm": "instance", + "bias": False, + "dropout": 0.1, + "minimum_size_im": 256, + }, + torch.rand([1, 3, 256, 512]), + [(1, 1, 32, 64), (1, 1, 4, 8)], + [4, 7], + ], + [ + { + "num_d": 2, + "num_layers_d": 3, + "spatial_dims": 3, + "channels": 8, + "in_channels": 3, + "out_channels": 1, + "kernel_size": 3, + "activation": "LEAKYRELU", + "norm": "instance", + "bias": False, + "dropout": 0.1, + "minimum_size_im": 256, + }, + torch.rand([1, 3, 256, 512, 256]), + [(1, 1, 32, 64, 32), (1, 1, 4, 8, 4)], + [4, 7], + ], +] +TEST_TOO_SMALL_SIZE = [ + { + "num_d": 2, + "num_layers_d": 6, + "spatial_dims": 2, + "channels": 8, + "in_channels": 3, + "out_channels": 1, + "kernel_size": 3, + "activation": "LEAKYRELU", + "norm": "instance", + "bias": False, + "dropout": 0.1, + "minimum_size_im": 256, + } +] + + +class TestPatchGAN(unittest.TestCase): + @parameterized.expand(TEST_PATCHGAN) + def test_shape(self, input_param, input_data, expected_shape_feature, expected_shape_output): + net = PatchDiscriminator(**input_param) + with eval_mode(net): + result = net.forward(input_data) + self.assertEqual(tuple(result[0].shape), expected_shape_feature) + self.assertEqual(tuple(result[-1].shape), expected_shape_output) + + def test_script(self): + net = PatchDiscriminator( + num_layers_d=3, + spatial_dims=2, + channels=8, + in_channels=3, + out_channels=1, + kernel_size=3, + activation="LEAKYRELU", + norm="instance", + bias=False, + dropout=0.1, + ) + i = torch.rand([1, 3, 256, 512]) + test_script_save(net, i) + + +class TestMultiscalePatchGAN(unittest.TestCase): + @parameterized.expand(TEST_MULTISCALE_PATCHGAN) + def test_shape(self, input_param, input_data, expected_shape, features_lengths=None): + net = MultiScalePatchDiscriminator(**input_param) + with eval_mode(net): + result, features = net.forward(input_data) + for r_ind, r in enumerate(result): + self.assertEqual(tuple(r.shape), expected_shape[r_ind]) + for o_d_ind, o_d in enumerate(features): + self.assertEqual(len(o_d), features_lengths[o_d_ind]) + + def test_too_small_shape(self): + with self.assertRaises(AssertionError): + MultiScalePatchDiscriminator(**TEST_TOO_SMALL_SIZE[0]) + + def test_script(self): + net = MultiScalePatchDiscriminator( + num_d=2, + num_layers_d=3, + spatial_dims=2, + channels=8, + in_channels=3, + out_channels=1, + kernel_size=3, + activation="LEAKYRELU", + norm="instance", + bias=False, + dropout=0.1, + minimum_size_im=256, + ) + i = torch.rand([1, 3, 256, 512]) + test_script_save(net, i) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_prepare_batch_diffusion.py b/tests/test_prepare_batch_diffusion.py new file mode 100644 index 0000000000..d969c06368 --- /dev/null +++ b/tests/test_prepare_batch_diffusion.py @@ -0,0 +1,104 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.engines import SupervisedEvaluator +from monai.engines.utils import DiffusionPrepareBatch +from monai.inferers import DiffusionInferer +from monai.networks.nets import DiffusionModelUNet +from monai.networks.schedulers import DDPMScheduler + +TEST_CASES = [ + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": [8], + "norm_num_groups": 8, + "attention_levels": [True], + "num_res_blocks": 1, + "num_head_channels": 8, + }, + (2, 1, 8, 8), + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "channels": [8], + "norm_num_groups": 8, + "attention_levels": [True], + "num_res_blocks": 1, + "num_head_channels": 8, + }, + (2, 1, 8, 8, 8), + ], +] + + +class TestPrepareBatchDiffusion(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_output_sizes(self, input_args, image_size): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + dataloader = [{"image": torch.randn(image_size).to(device)}] + scheduler = DDPMScheduler(num_train_timesteps=20) + inferer = DiffusionInferer(scheduler=scheduler) + network = DiffusionModelUNet(**input_args).to(device) + evaluator = SupervisedEvaluator( + device=device, + val_data_loader=dataloader, + epoch_length=1, + network=network, + inferer=inferer, + non_blocking=True, + prepare_batch=DiffusionPrepareBatch(num_train_timesteps=20), + decollate=False, + ) + evaluator.run() + output = evaluator.state.output + # check shapes are the same + self.assertEqual(output["pred"].shape, image_size) + self.assertEqual(output["label"].shape, output["image"].shape) + + @parameterized.expand(TEST_CASES) + def test_conditioning(self, input_args, image_size): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + dataloader = [{"image": torch.randn(image_size).to(device), "context": torch.randn((2, 4, 3)).to(device)}] + scheduler = DDPMScheduler(num_train_timesteps=20) + inferer = DiffusionInferer(scheduler=scheduler) + network = DiffusionModelUNet(**input_args, with_conditioning=True, cross_attention_dim=3).to(device) + evaluator = SupervisedEvaluator( + device=device, + val_data_loader=dataloader, + epoch_length=1, + network=network, + inferer=inferer, + non_blocking=True, + prepare_batch=DiffusionPrepareBatch(num_train_timesteps=20, condition_name="context"), + decollate=False, + ) + evaluator.run() + output = evaluator.state.output + # check shapes are the same + self.assertEqual(output["pred"].shape, image_size) + self.assertEqual(output["label"].shape, output["image"].shape) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_reg_loss_integration.py b/tests/test_reg_loss_integration.py index 1fb81689e6..8afc2da6ad 100644 --- a/tests/test_reg_loss_integration.py +++ b/tests/test_reg_loss_integration.py @@ -83,6 +83,9 @@ def forward(self, x): # initialize a SGD optimizer optimizer = optim.Adam(net.parameters(), lr=learning_rate) + # declare first for pylint + init_loss = None + # train the network for it in range(max_iter): # set the gradient to zero diff --git a/tests/test_scheduler_ddim.py b/tests/test_scheduler_ddim.py new file mode 100644 index 0000000000..1a8f8cab67 --- /dev/null +++ b/tests/test_scheduler_ddim.py @@ -0,0 +1,83 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks.schedulers import DDIMScheduler +from tests.utils import assert_allclose + +TEST_2D_CASE = [] +for beta_schedule in ["linear_beta", "scaled_linear_beta"]: + TEST_2D_CASE.append([{"schedule": beta_schedule}, (2, 6, 16, 16), (2, 6, 16, 16)]) + +TEST_3D_CASE = [] +for beta_schedule in ["linear_beta", "scaled_linear_beta"]: + TEST_3D_CASE.append([{"schedule": beta_schedule}, (2, 6, 16, 16, 16), (2, 6, 16, 16, 16)]) + +TEST_CASES = TEST_2D_CASE + TEST_3D_CASE + +TEST_FULl_LOOP = [ + [{"schedule": "linear_beta"}, (1, 1, 2, 2), torch.Tensor([[[[-0.9579, -0.6457], [0.4684, -0.9694]]]])] +] + + +class TestDDPMScheduler(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_add_noise(self, input_param, input_shape, expected_shape): + scheduler = DDIMScheduler(**input_param) + scheduler.set_timesteps(num_inference_steps=100) + original_sample = torch.zeros(input_shape) + noise = torch.randn_like(original_sample) + timesteps = torch.randint(0, scheduler.num_train_timesteps, (original_sample.shape[0],)).long() + + noisy = scheduler.add_noise(original_samples=original_sample, noise=noise, timesteps=timesteps) + self.assertEqual(noisy.shape, expected_shape) + + @parameterized.expand(TEST_CASES) + def test_step_shape(self, input_param, input_shape, expected_shape): + scheduler = DDIMScheduler(**input_param) + scheduler.set_timesteps(num_inference_steps=100) + model_output = torch.randn(input_shape) + sample = torch.randn(input_shape) + output_step = scheduler.step(model_output=model_output, timestep=500, sample=sample) + self.assertEqual(output_step[0].shape, expected_shape) + self.assertEqual(output_step[1].shape, expected_shape) + + @parameterized.expand(TEST_FULl_LOOP) + def test_full_timestep_loop(self, input_param, input_shape, expected_output): + scheduler = DDIMScheduler(**input_param) + scheduler.set_timesteps(50) + torch.manual_seed(42) + model_output = torch.randn(input_shape) + sample = torch.randn(input_shape) + for t in range(50): + sample, _ = scheduler.step(model_output=model_output, timestep=t, sample=sample) + assert_allclose(sample, expected_output, rtol=1e-3, atol=1e-3) + + def test_set_timesteps(self): + scheduler = DDIMScheduler(num_train_timesteps=1000) + scheduler.set_timesteps(num_inference_steps=100) + self.assertEqual(scheduler.num_inference_steps, 100) + self.assertEqual(len(scheduler.timesteps), 100) + + def test_set_timesteps_with_num_inference_steps_bigger_than_num_train_timesteps(self): + scheduler = DDIMScheduler(num_train_timesteps=1000) + with self.assertRaises(ValueError): + scheduler.set_timesteps(num_inference_steps=2000) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_scheduler_ddpm.py b/tests/test_scheduler_ddpm.py new file mode 100644 index 0000000000..f0447aded2 --- /dev/null +++ b/tests/test_scheduler_ddpm.py @@ -0,0 +1,104 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks.schedulers import DDPMScheduler +from tests.utils import assert_allclose + +TEST_2D_CASE = [] +for beta_schedule in ["linear_beta", "scaled_linear_beta"]: + for variance_type in ["fixed_small", "fixed_large"]: + TEST_2D_CASE.append( + [{"schedule": beta_schedule, "variance_type": variance_type}, (2, 6, 16, 16), (2, 6, 16, 16)] + ) + +TEST_3D_CASE = [] +for beta_schedule in ["linear_beta", "scaled_linear_beta"]: + for variance_type in ["fixed_small", "fixed_large"]: + TEST_3D_CASE.append( + [{"schedule": beta_schedule, "variance_type": variance_type}, (2, 6, 16, 16, 16), (2, 6, 16, 16, 16)] + ) + +TEST_CASES = TEST_2D_CASE + TEST_3D_CASE + +TEST_FULl_LOOP = [ + [{"schedule": "linear_beta"}, (1, 1, 2, 2), torch.Tensor([[[[-1.0153, -0.3218], [0.8454, -0.7870]]]])] +] + + +class TestDDPMScheduler(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_add_noise(self, input_param, input_shape, expected_shape): + scheduler = DDPMScheduler(**input_param) + original_sample = torch.zeros(input_shape) + noise = torch.randn_like(original_sample) + timesteps = torch.randint(0, scheduler.num_train_timesteps, (original_sample.shape[0],)).long() + + noisy = scheduler.add_noise(original_samples=original_sample, noise=noise, timesteps=timesteps) + self.assertEqual(noisy.shape, expected_shape) + + @parameterized.expand(TEST_CASES) + def test_step_shape(self, input_param, input_shape, expected_shape): + scheduler = DDPMScheduler(**input_param) + model_output = torch.randn(input_shape) + sample = torch.randn(input_shape) + output_step = scheduler.step(model_output=model_output, timestep=500, sample=sample) + self.assertEqual(output_step[0].shape, expected_shape) + self.assertEqual(output_step[1].shape, expected_shape) + + @parameterized.expand(TEST_FULl_LOOP) + def test_full_timestep_loop(self, input_param, input_shape, expected_output): + scheduler = DDPMScheduler(**input_param) + scheduler.set_timesteps(50) + torch.manual_seed(42) + model_output = torch.randn(input_shape) + sample = torch.randn(input_shape) + for t in range(50): + sample, _ = scheduler.step(model_output=model_output, timestep=t, sample=sample) + assert_allclose(sample, expected_output, rtol=1e-3, atol=1e-3) + + @parameterized.expand(TEST_CASES) + def test_get_velocity_shape(self, input_param, input_shape, expected_shape): + scheduler = DDPMScheduler(**input_param) + sample = torch.randn(input_shape) + timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],)).long() + velocity = scheduler.get_velocity(sample=sample, noise=sample, timesteps=timesteps) + self.assertEqual(velocity.shape, expected_shape) + + def test_step_learned(self): + for variance_type in ["learned", "learned_range"]: + scheduler = DDPMScheduler(variance_type=variance_type) + model_output = torch.randn(2, 6, 16, 16) + sample = torch.randn(2, 3, 16, 16) + output_step = scheduler.step(model_output=model_output, timestep=500, sample=sample) + self.assertEqual(output_step[0].shape, sample.shape) + self.assertEqual(output_step[1].shape, sample.shape) + + def test_set_timesteps(self): + scheduler = DDPMScheduler(num_train_timesteps=1000) + scheduler.set_timesteps(num_inference_steps=100) + self.assertEqual(scheduler.num_inference_steps, 100) + self.assertEqual(len(scheduler.timesteps), 100) + + def test_set_timesteps_with_num_inference_steps_bigger_than_num_train_timesteps(self): + scheduler = DDPMScheduler(num_train_timesteps=1000) + with self.assertRaises(ValueError): + scheduler.set_timesteps(num_inference_steps=2000) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_scheduler_pndm.py b/tests/test_scheduler_pndm.py new file mode 100644 index 0000000000..69e5e403f5 --- /dev/null +++ b/tests/test_scheduler_pndm.py @@ -0,0 +1,108 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks.schedulers import PNDMScheduler +from tests.utils import assert_allclose + +TEST_2D_CASE = [] +for beta_schedule in ["linear_beta", "scaled_linear_beta"]: + TEST_2D_CASE.append([{"schedule": beta_schedule}, (2, 6, 16, 16), (2, 6, 16, 16)]) + +TEST_3D_CASE = [] +for beta_schedule in ["linear_beta", "scaled_linear_beta"]: + TEST_3D_CASE.append([{"schedule": beta_schedule}, (2, 6, 16, 16, 16), (2, 6, 16, 16, 16)]) + +TEST_CASES = TEST_2D_CASE + TEST_3D_CASE + +TEST_FULl_LOOP = [ + [ + {"schedule": "linear_beta"}, + (1, 1, 2, 2), + torch.Tensor([[[[-2123055.2500, -459014.2812], [2863438.0000, -1263401.7500]]]]), + ] +] + + +class TestDDPMScheduler(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_add_noise(self, input_param, input_shape, expected_shape): + scheduler = PNDMScheduler(**input_param) + original_sample = torch.zeros(input_shape) + noise = torch.randn_like(original_sample) + timesteps = torch.randint(0, scheduler.num_train_timesteps, (original_sample.shape[0],)).long() + noisy = scheduler.add_noise(original_samples=original_sample, noise=noise, timesteps=timesteps) + self.assertEqual(noisy.shape, expected_shape) + + @parameterized.expand(TEST_CASES) + def test_step_shape(self, input_param, input_shape, expected_shape): + scheduler = PNDMScheduler(**input_param) + scheduler.set_timesteps(600) + model_output = torch.randn(input_shape) + sample = torch.randn(input_shape) + output_step = scheduler.step(model_output=model_output, timestep=500, sample=sample) + self.assertEqual(output_step[0].shape, expected_shape) + self.assertEqual(output_step[1], None) + + @parameterized.expand(TEST_FULl_LOOP) + def test_full_timestep_loop(self, input_param, input_shape, expected_output): + scheduler = PNDMScheduler(**input_param) + scheduler.set_timesteps(50) + torch.manual_seed(42) + model_output = torch.randn(input_shape) + sample = torch.randn(input_shape) + for t in range(50): + sample, _ = scheduler.step(model_output=model_output, timestep=t, sample=sample) + assert_allclose(sample, expected_output, rtol=1e-3, atol=1e-3) + + @parameterized.expand(TEST_FULl_LOOP) + def test_timestep_two_loops(self, input_param, input_shape, expected_output): + scheduler = PNDMScheduler(**input_param) + scheduler.set_timesteps(50) + torch.manual_seed(42) + model_output = torch.randn(input_shape) + sample = torch.randn(input_shape) + for t in range(50): + sample, _ = scheduler.step(model_output=model_output, timestep=t, sample=sample) + torch.manual_seed(42) + model_output2 = torch.randn(input_shape) + sample2 = torch.randn(input_shape) + scheduler.set_timesteps(50) + for t in range(50): + sample2, _ = scheduler.step(model_output=model_output2, timestep=t, sample=sample2) + assert_allclose(sample, sample2, rtol=1e-3, atol=1e-3) + + def test_set_timesteps(self): + scheduler = PNDMScheduler(num_train_timesteps=1000, skip_prk_steps=True) + scheduler.set_timesteps(num_inference_steps=100) + self.assertEqual(scheduler.num_inference_steps, 100) + self.assertEqual(len(scheduler.timesteps), 100) + + def test_set_timesteps_prk(self): + scheduler = PNDMScheduler(num_train_timesteps=1000, skip_prk_steps=False) + scheduler.set_timesteps(num_inference_steps=100) + self.assertEqual(scheduler.num_inference_steps, 109) + self.assertEqual(len(scheduler.timesteps), 109) + + def test_set_timesteps_with_num_inference_steps_bigger_than_num_train_timesteps(self): + scheduler = PNDMScheduler(num_train_timesteps=1000) + with self.assertRaises(ValueError): + scheduler.set_timesteps(num_inference_steps=2000) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_selfattention.py b/tests/test_selfattention.py index 0ebed84159..d069d6aa30 100644 --- a/tests/test_selfattention.py +++ b/tests/test_selfattention.py @@ -20,6 +20,7 @@ from monai.networks import eval_mode from monai.networks.blocks.selfattention import SABlock +from monai.networks.layers.factories import RelPosEmbedding from monai.utils import optional_import einops, has_einops = optional_import("einops") @@ -28,12 +29,20 @@ for dropout_rate in np.linspace(0, 1, 4): for hidden_size in [360, 480, 600, 768]: for num_heads in [4, 6, 8, 12]: - test_case = [ - {"hidden_size": hidden_size, "num_heads": num_heads, "dropout_rate": dropout_rate}, - (2, 512, hidden_size), - (2, 512, hidden_size), - ] - TEST_CASE_SABLOCK.append(test_case) + for rel_pos_embedding in [None, RelPosEmbedding.DECOMPOSED]: + for input_size in [(16, 32), (8, 8, 8)]: + test_case = [ + { + "hidden_size": hidden_size, + "num_heads": num_heads, + "dropout_rate": dropout_rate, + "rel_pos_embedding": rel_pos_embedding, + "input_size": input_size, + }, + (2, 512, hidden_size), + (2, 512, hidden_size), + ] + TEST_CASE_SABLOCK.append(test_case) class TestResBlock(unittest.TestCase): @@ -53,6 +62,27 @@ def test_ill_arg(self): with self.assertRaises(ValueError): SABlock(hidden_size=620, num_heads=8, dropout_rate=0.4) + def test_attention_dim_not_multiple_of_heads(self): + with self.assertRaises(ValueError): + SABlock(hidden_size=128, num_heads=3, dropout_rate=0.1) + + @skipUnless(has_einops, "Requires einops") + def test_inner_dim_different(self): + SABlock(hidden_size=128, num_heads=4, dropout_rate=0.1, dim_head=30) + + def test_causal_no_sequence_length(self): + with self.assertRaises(ValueError): + SABlock(hidden_size=128, num_heads=4, dropout_rate=0.1, causal=True) + + @skipUnless(has_einops, "Requires einops") + def test_causal(self): + block = SABlock(hidden_size=128, num_heads=1, dropout_rate=0.1, causal=True, sequence_length=16, save_attn=True) + input_shape = (1, 16, 128) + block(torch.randn(input_shape)) + # check upper triangular part of the attention matrix is zero + assert torch.triu(block.att_mat, diagonal=1).sum() == 0 + + @skipUnless(has_einops, "Requires einops") def test_access_attn_matrix(self): # input format hidden_size = 128 diff --git a/tests/test_spade_autoencoderkl.py b/tests/test_spade_autoencoderkl.py new file mode 100644 index 0000000000..9353ceedc2 --- /dev/null +++ b/tests/test_spade_autoencoderkl.py @@ -0,0 +1,295 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest +from unittest import skipUnless + +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.nets import SPADEAutoencoderKL +from monai.utils import optional_import + +einops, has_einops = optional_import("einops") + +CASES_NO_ATTENTION = [ + [ + { + "spatial_dims": 2, + "label_nc": 3, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, False), + "num_res_blocks": 1, + "norm_num_groups": 4, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, + }, + (1, 1, 16, 16), + (1, 3, 16, 16), + (1, 1, 16, 16), + (1, 4, 4, 4), + ], + [ + { + "spatial_dims": 3, + "label_nc": 3, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, False), + "num_res_blocks": 1, + "norm_num_groups": 4, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, + }, + (1, 1, 16, 16, 16), + (1, 3, 16, 16, 16), + (1, 1, 16, 16, 16), + (1, 4, 4, 4, 4), + ], +] + +CASES_ATTENTION = [ + [ + { + "spatial_dims": 2, + "label_nc": 3, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, False), + "num_res_blocks": 1, + "norm_num_groups": 4, + }, + (1, 1, 16, 16), + (1, 3, 16, 16), + (1, 1, 16, 16), + (1, 4, 4, 4), + ], + [ + { + "spatial_dims": 2, + "label_nc": 3, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, False), + "num_res_blocks": (1, 1, 2), + "norm_num_groups": 4, + }, + (1, 1, 16, 16), + (1, 3, 16, 16), + (1, 1, 16, 16), + (1, 4, 4, 4), + ], + [ + { + "spatial_dims": 2, + "label_nc": 3, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, False), + "num_res_blocks": 1, + "norm_num_groups": 4, + }, + (1, 1, 16, 16), + (1, 3, 16, 16), + (1, 1, 16, 16), + (1, 4, 4, 4), + ], + [ + { + "spatial_dims": 2, + "label_nc": 3, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, True), + "num_res_blocks": 1, + "norm_num_groups": 4, + }, + (1, 1, 16, 16), + (1, 3, 16, 16), + (1, 1, 16, 16), + (1, 4, 4, 4), + ], + [ + { + "spatial_dims": 2, + "label_nc": 3, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, False), + "num_res_blocks": 1, + "norm_num_groups": 4, + "with_encoder_nonlocal_attn": False, + }, + (1, 1, 16, 16), + (1, 3, 16, 16), + (1, 1, 16, 16), + (1, 4, 4, 4), + ], + [ + { + "spatial_dims": 3, + "label_nc": 3, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, True), + "num_res_blocks": 1, + "norm_num_groups": 4, + }, + (1, 1, 16, 16, 16), + (1, 3, 16, 16, 16), + (1, 1, 16, 16, 16), + (1, 4, 4, 4, 4), + ], + [ + { + "spatial_dims": 2, + "label_nc": 3, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, True), + "num_res_blocks": 1, + "norm_num_groups": 4, + "spade_intermediate_channels": 32, + }, + (1, 1, 16, 16), + (1, 3, 16, 16), + (1, 1, 16, 16), + (1, 4, 4, 4), + ], +] + +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + +if has_einops: + CASES = CASES_ATTENTION + CASES_NO_ATTENTION +else: + CASES = CASES_NO_ATTENTION + + +class TestSPADEAutoEncoderKL(unittest.TestCase): + @parameterized.expand(CASES) + def test_shape(self, input_param, input_shape, input_seg, expected_shape, expected_latent_shape): + net = SPADEAutoencoderKL(**input_param).to(device) + with eval_mode(net): + result = net.forward(torch.randn(input_shape).to(device), torch.randn(input_seg).to(device)) + self.assertEqual(result[0].shape, expected_shape) + self.assertEqual(result[1].shape, expected_latent_shape) + + @skipUnless(has_einops, "Requires einops") + def test_model_channels_not_multiple_of_norm_num_group(self): + with self.assertRaises(ValueError): + SPADEAutoencoderKL( + spatial_dims=2, + label_nc=3, + in_channels=1, + out_channels=1, + channels=(24, 24, 24), + attention_levels=(False, False, False), + latent_channels=8, + num_res_blocks=1, + norm_num_groups=16, + ) + + @skipUnless(has_einops, "Requires einops") + def test_model_channels_not_same_size_of_attention_levels(self): + with self.assertRaises(ValueError): + SPADEAutoencoderKL( + spatial_dims=2, + label_nc=3, + in_channels=1, + out_channels=1, + channels=(24, 24, 24), + attention_levels=(False, False), + latent_channels=8, + num_res_blocks=1, + norm_num_groups=16, + ) + + @skipUnless(has_einops, "Requires einops") + def test_model_channels_not_same_size_of_num_res_blocks(self): + with self.assertRaises(ValueError): + SPADEAutoencoderKL( + spatial_dims=2, + label_nc=3, + in_channels=1, + out_channels=1, + channels=(24, 24, 24), + attention_levels=(False, False, False), + latent_channels=8, + num_res_blocks=(8, 8), + norm_num_groups=16, + ) + + def test_shape_encode(self): + input_param, input_shape, _, _, expected_latent_shape = CASES[0] + net = SPADEAutoencoderKL(**input_param).to(device) + with eval_mode(net): + result = net.encode(torch.randn(input_shape).to(device)) + self.assertEqual(result[0].shape, expected_latent_shape) + self.assertEqual(result[1].shape, expected_latent_shape) + + def test_shape_sampling(self): + input_param, _, _, _, expected_latent_shape = CASES[0] + net = SPADEAutoencoderKL(**input_param).to(device) + with eval_mode(net): + result = net.sampling( + torch.randn(expected_latent_shape).to(device), torch.randn(expected_latent_shape).to(device) + ) + self.assertEqual(result.shape, expected_latent_shape) + + def test_shape_decode(self): + input_param, _, input_seg_shape, expected_input_shape, latent_shape = CASES[0] + net = SPADEAutoencoderKL(**input_param).to(device) + with eval_mode(net): + result = net.decode(torch.randn(latent_shape).to(device), torch.randn(input_seg_shape).to(device)) + self.assertEqual(result.shape, expected_input_shape) + + @skipUnless(has_einops, "Requires einops") + def test_wrong_shape_decode(self): + net = SPADEAutoencoderKL( + spatial_dims=2, + label_nc=3, + in_channels=1, + out_channels=1, + channels=(4, 4, 4), + latent_channels=4, + attention_levels=(False, False, False), + num_res_blocks=1, + norm_num_groups=4, + ) + with self.assertRaises(RuntimeError): + _ = net.decode(torch.randn((1, 1, 16, 16)).to(device), torch.randn((1, 6, 16, 16)).to(device)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_spade_diffusion_model_unet.py b/tests/test_spade_diffusion_model_unet.py new file mode 100644 index 0000000000..481705f56f --- /dev/null +++ b/tests/test_spade_diffusion_model_unet.py @@ -0,0 +1,574 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest +from unittest import skipUnless + +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.nets import SPADEDiffusionModelUNet +from monai.utils import optional_import + +einops, has_einops = optional_import("einops") +UNCOND_CASES_2D = [ + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, False), + "norm_num_groups": 8, + "label_nc": 3, + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": (1, 1, 2), + "channels": (8, 8, 8), + "attention_levels": (False, False, False), + "norm_num_groups": 8, + "label_nc": 3, + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, False), + "norm_num_groups": 8, + "resblock_updown": True, + "label_nc": 3, + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 8, + "norm_num_groups": 8, + "label_nc": 3, + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 8, + "norm_num_groups": 8, + "resblock_updown": True, + "label_nc": 3, + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 4, + "norm_num_groups": 8, + "label_nc": 3, + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, True, True), + "num_head_channels": (0, 2, 4), + "norm_num_groups": 8, + "label_nc": 3, + } + ], +] + +UNCOND_CASES_3D = [ + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, False), + "norm_num_groups": 8, + "label_nc": 3, + "spade_intermediate_channels": 256, + } + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, False), + "norm_num_groups": 8, + "label_nc": 3, + } + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, False), + "norm_num_groups": 8, + "resblock_updown": True, + "label_nc": 3, + } + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 8, + "norm_num_groups": 8, + "label_nc": 3, + } + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 8, + "norm_num_groups": 8, + "resblock_updown": True, + "label_nc": 3, + } + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 4, + "norm_num_groups": 8, + "label_nc": 3, + } + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": (0, 0, 4), + "norm_num_groups": 8, + "label_nc": 3, + } + ], +] + +COND_CASES_2D = [ + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 4, + "norm_num_groups": 8, + "with_conditioning": True, + "transformer_num_layers": 1, + "cross_attention_dim": 3, + "label_nc": 3, + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 4, + "norm_num_groups": 8, + "with_conditioning": True, + "transformer_num_layers": 1, + "cross_attention_dim": 3, + "resblock_updown": True, + "label_nc": 3, + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 4, + "norm_num_groups": 8, + "with_conditioning": True, + "transformer_num_layers": 1, + "cross_attention_dim": 3, + "upcast_attention": True, + "label_nc": 3, + } + ], +] + + +class TestSPADEDiffusionModelUNet2D(unittest.TestCase): + @parameterized.expand(UNCOND_CASES_2D) + @skipUnless(has_einops, "Requires einops") + def test_shape_unconditioned_models(self, input_param): + net = SPADEDiffusionModelUNet(**input_param) + with eval_mode(net): + result = net.forward( + torch.rand((1, 1, 16, 16)), + torch.randint(0, 1000, (1,)).long(), + torch.rand((1, input_param["label_nc"], 16, 16)), + ) + self.assertEqual(result.shape, (1, 1, 16, 16)) + + @skipUnless(has_einops, "Requires einops") + def test_timestep_with_wrong_shape(self): + net = SPADEDiffusionModelUNet( + spatial_dims=2, + label_nc=3, + in_channels=1, + out_channels=1, + num_res_blocks=1, + channels=(8, 8, 8), + attention_levels=(False, False, False), + norm_num_groups=8, + ) + with self.assertRaises(ValueError): + with eval_mode(net): + net.forward( + torch.rand((1, 1, 16, 16)), torch.randint(0, 1000, (1, 1)).long(), torch.rand((1, 3, 16, 16)) + ) + + @skipUnless(has_einops, "Requires einops") + def test_label_with_wrong_shape(self): + net = SPADEDiffusionModelUNet( + spatial_dims=2, + label_nc=3, + in_channels=1, + out_channels=1, + num_res_blocks=1, + channels=(8, 8, 8), + attention_levels=(False, False, False), + norm_num_groups=8, + ) + with self.assertRaises(RuntimeError): + with eval_mode(net): + net.forward(torch.rand((1, 1, 16, 16)), torch.randint(0, 1000, (1,)).long(), torch.rand((1, 6, 16, 16))) + + @skipUnless(has_einops, "Requires einops") + def test_shape_with_different_in_channel_out_channel(self): + in_channels = 6 + out_channels = 3 + net = SPADEDiffusionModelUNet( + spatial_dims=2, + label_nc=3, + in_channels=in_channels, + out_channels=out_channels, + num_res_blocks=1, + channels=(8, 8, 8), + attention_levels=(False, False, False), + norm_num_groups=8, + ) + with eval_mode(net): + result = net.forward( + torch.rand((1, in_channels, 16, 16)), torch.randint(0, 1000, (1,)).long(), torch.rand((1, 3, 16, 16)) + ) + self.assertEqual(result.shape, (1, out_channels, 16, 16)) + + def test_model_channels_not_multiple_of_norm_num_group(self): + with self.assertRaises(ValueError): + SPADEDiffusionModelUNet( + spatial_dims=2, + label_nc=3, + in_channels=1, + out_channels=1, + num_res_blocks=1, + channels=(8, 8, 12), + attention_levels=(False, False, False), + norm_num_groups=8, + ) + + def test_attention_levels_with_different_length_num_head_channels(self): + with self.assertRaises(ValueError): + SPADEDiffusionModelUNet( + spatial_dims=2, + label_nc=3, + in_channels=1, + out_channels=1, + num_res_blocks=1, + channels=(8, 8, 8), + attention_levels=(False, False, False), + num_head_channels=(0, 2), + norm_num_groups=8, + ) + + def test_num_res_blocks_with_different_length_channels(self): + with self.assertRaises(ValueError): + SPADEDiffusionModelUNet( + spatial_dims=2, + label_nc=3, + in_channels=1, + out_channels=1, + num_res_blocks=(1, 1), + channels=(8, 8, 8), + attention_levels=(False, False, False), + norm_num_groups=8, + ) + + @skipUnless(has_einops, "Requires einops") + def test_shape_conditioned_models(self): + net = SPADEDiffusionModelUNet( + spatial_dims=2, + label_nc=3, + in_channels=1, + out_channels=1, + num_res_blocks=1, + channels=(8, 8, 8), + attention_levels=(False, False, True), + with_conditioning=True, + transformer_num_layers=1, + cross_attention_dim=3, + norm_num_groups=8, + num_head_channels=8, + ) + with eval_mode(net): + result = net.forward( + x=torch.rand((1, 1, 16, 32)), + timesteps=torch.randint(0, 1000, (1,)).long(), + seg=torch.rand((1, 3, 16, 32)), + context=torch.rand((1, 1, 3)), + ) + self.assertEqual(result.shape, (1, 1, 16, 32)) + + @skipUnless(has_einops, "Requires einops") + def test_with_conditioning_cross_attention_dim_none(self): + with self.assertRaises(ValueError): + SPADEDiffusionModelUNet( + spatial_dims=2, + label_nc=3, + in_channels=1, + out_channels=1, + num_res_blocks=1, + channels=(8, 8, 8), + attention_levels=(False, False, True), + with_conditioning=True, + transformer_num_layers=1, + cross_attention_dim=None, + norm_num_groups=8, + ) + + @skipUnless(has_einops, "Requires einops") + def test_context_with_conditioning_none(self): + net = SPADEDiffusionModelUNet( + spatial_dims=2, + label_nc=3, + in_channels=1, + out_channels=1, + num_res_blocks=1, + channels=(8, 8, 8), + attention_levels=(False, False, True), + with_conditioning=False, + transformer_num_layers=1, + norm_num_groups=8, + ) + + with self.assertRaises(ValueError): + with eval_mode(net): + net.forward( + x=torch.rand((1, 1, 16, 32)), + timesteps=torch.randint(0, 1000, (1,)).long(), + seg=torch.rand((1, 3, 16, 32)), + context=torch.rand((1, 1, 3)), + ) + + @skipUnless(has_einops, "Requires einops") + def test_shape_conditioned_models_class_conditioning(self): + net = SPADEDiffusionModelUNet( + spatial_dims=2, + label_nc=3, + in_channels=1, + out_channels=1, + num_res_blocks=1, + channels=(8, 8, 8), + attention_levels=(False, False, True), + norm_num_groups=8, + num_head_channels=8, + num_class_embeds=2, + ) + with eval_mode(net): + result = net.forward( + x=torch.rand((1, 1, 16, 32)), + timesteps=torch.randint(0, 1000, (1,)).long(), + seg=torch.rand((1, 3, 16, 32)), + class_labels=torch.randint(0, 2, (1,)).long(), + ) + self.assertEqual(result.shape, (1, 1, 16, 32)) + + @skipUnless(has_einops, "Requires einops") + def test_conditioned_models_no_class_labels(self): + net = SPADEDiffusionModelUNet( + spatial_dims=2, + label_nc=3, + in_channels=1, + out_channels=1, + num_res_blocks=1, + channels=(8, 8, 8), + attention_levels=(False, False, True), + norm_num_groups=8, + num_head_channels=8, + num_class_embeds=2, + ) + + with self.assertRaises(ValueError): + net.forward( + x=torch.rand((1, 1, 16, 32)), + timesteps=torch.randint(0, 1000, (1,)).long(), + seg=torch.rand((1, 3, 16, 32)), + ) + + def test_model_channels_not_same_size_of_attention_levels(self): + with self.assertRaises(ValueError): + SPADEDiffusionModelUNet( + spatial_dims=2, + label_nc=3, + in_channels=1, + out_channels=1, + num_res_blocks=1, + channels=(8, 8, 8), + attention_levels=(False, False), + norm_num_groups=8, + num_head_channels=8, + num_class_embeds=2, + ) + + @parameterized.expand(COND_CASES_2D) + @skipUnless(has_einops, "Requires einops") + def test_conditioned_2d_models_shape(self, input_param): + net = SPADEDiffusionModelUNet(**input_param) + with eval_mode(net): + result = net.forward( + torch.rand((1, 1, 16, 16)), + torch.randint(0, 1000, (1,)).long(), + torch.rand((1, input_param["label_nc"], 16, 16)), + torch.rand((1, 1, 3)), + ) + self.assertEqual(result.shape, (1, 1, 16, 16)) + + +class TestDiffusionModelUNet3D(unittest.TestCase): + @parameterized.expand(UNCOND_CASES_3D) + @skipUnless(has_einops, "Requires einops") + def test_shape_unconditioned_models(self, input_param): + net = SPADEDiffusionModelUNet(**input_param) + with eval_mode(net): + result = net.forward( + torch.rand((1, 1, 16, 16, 16)), + torch.randint(0, 1000, (1,)).long(), + torch.rand((1, input_param["label_nc"], 16, 16, 16)), + ) + self.assertEqual(result.shape, (1, 1, 16, 16, 16)) + + @skipUnless(has_einops, "Requires einops") + def test_shape_with_different_in_channel_out_channel(self): + in_channels = 6 + out_channels = 3 + net = SPADEDiffusionModelUNet( + spatial_dims=3, + label_nc=3, + in_channels=in_channels, + out_channels=out_channels, + num_res_blocks=1, + channels=(8, 8, 8), + attention_levels=(False, False, True), + norm_num_groups=4, + ) + with eval_mode(net): + result = net.forward( + torch.rand((1, in_channels, 16, 16, 16)), + torch.randint(0, 1000, (1,)).long(), + torch.rand((1, 3, 16, 16, 16)), + ) + self.assertEqual(result.shape, (1, out_channels, 16, 16, 16)) + + @skipUnless(has_einops, "Requires einops") + def test_shape_conditioned_models(self): + net = SPADEDiffusionModelUNet( + spatial_dims=3, + label_nc=3, + in_channels=1, + out_channels=1, + num_res_blocks=1, + channels=(16, 16, 16), + attention_levels=(False, False, True), + norm_num_groups=16, + with_conditioning=True, + transformer_num_layers=1, + cross_attention_dim=3, + ) + with eval_mode(net): + result = net.forward( + x=torch.rand((1, 1, 16, 16, 16)), + timesteps=torch.randint(0, 1000, (1,)).long(), + seg=torch.rand((1, 3, 16, 16, 16)), + context=torch.rand((1, 1, 3)), + ) + self.assertEqual(result.shape, (1, 1, 16, 16, 16)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_spade_vaegan.py b/tests/test_spade_vaegan.py new file mode 100644 index 0000000000..3fdb9b74cb --- /dev/null +++ b/tests/test_spade_vaegan.py @@ -0,0 +1,140 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.nets import SPADENet + +CASE_2D = [ + [[2, 1, 1, 3, [64, 64], [16, 32, 64, 128], 16, True]], + [[2, 1, 1, 3, [64, 64], [16, 32, 64, 128], None, False]], +] +CASE_3D = [ + [[3, 1, 1, 3, [64, 64, 64], [16, 32, 64, 128], 16, True]], + [[3, 1, 1, 3, [64, 64, 64], [16, 32, 64, 128], None, False]], +] + + +def create_semantic_data(shape: list, semantic_regions: int): + """ + To create semantic and image mock inputs for the network. + Args: + shape: input shape + semantic_regions: number of semantic region + Returns: + """ + out_label = torch.zeros(shape) + out_image = torch.zeros(shape) + torch.randn(shape) * 0.01 + for i in range(1, semantic_regions): + shape_square = [i // np.random.choice(list(range(2, i // 2))) for i in shape] + start_point = [np.random.choice(list(range(shape[ind] - shape_square[ind]))) for ind, i in enumerate(shape)] + if len(shape) == 2: + out_label[ + start_point[0] : (start_point[0] + shape_square[0]), start_point[1] : (start_point[1] + shape_square[1]) + ] = i + base_intensity = torch.ones(shape_square) * np.random.randn() + out_image[ + start_point[0] : (start_point[0] + shape_square[0]), start_point[1] : (start_point[1] + shape_square[1]) + ] = (base_intensity + torch.randn(shape_square) * 0.1) + elif len(shape) == 3: + out_label[ + start_point[0] : (start_point[0] + shape_square[0]), + start_point[1] : (start_point[1] + shape_square[1]), + start_point[2] : (start_point[2] + shape_square[2]), + ] = i + base_intensity = torch.ones(shape_square) * np.random.randn() + out_image[ + start_point[0] : (start_point[0] + shape_square[0]), + start_point[1] : (start_point[1] + shape_square[1]), + start_point[2] : (start_point[2] + shape_square[2]), + ] = (base_intensity + torch.randn(shape_square) * 0.1) + else: + ValueError("Supports only 2D and 3D tensors") + + # One hot encode label + out_label_ = torch.zeros([semantic_regions] + list(out_label.shape)) + for ch in range(semantic_regions): + out_label_[ch, ...] = out_label == ch + + return out_label_.unsqueeze(0), out_image.unsqueeze(0).unsqueeze(0) + + +class TestSpadeNet(unittest.TestCase): + @parameterized.expand(CASE_2D) + def test_forward_2d(self, input_param): + """ + Check that forward method is called correctly and output shape matches. + """ + net = SPADENet(*input_param) + in_label, in_image = create_semantic_data(input_param[4], input_param[3]) + with eval_mode(net): + if not net.is_vae: + out = net(in_label, in_image) + out = out[0] + else: + out, z_mu, z_logvar = net(in_label, in_image) + self.assertTrue(torch.all(torch.isfinite(z_mu))) + self.assertTrue(torch.all(torch.isfinite(z_logvar))) + + self.assertTrue(torch.all(torch.isfinite(out))) + self.assertEqual(list(out.shape), [1, 1, 64, 64]) + + @parameterized.expand(CASE_2D) + def test_encoder_decoder(self, input_param): + """ + Check that forward method is called correctly and output shape matches. + """ + net = SPADENet(*input_param) + in_label, in_image = create_semantic_data(input_param[4], input_param[3]) + with eval_mode(net): + out_z = net.encode(in_image) + if net.is_vae: + self.assertEqual(list(out_z.shape), [1, 16]) + else: + self.assertEqual(out_z, None) + out_i = net.decode(in_label, out_z) + self.assertEqual(list(out_i.shape), [1, 1, 64, 64]) + + @parameterized.expand(CASE_3D) + def test_forward_3d(self, input_param): + """ + Check that forward method is called correctly and output shape matches. + """ + net = SPADENet(*input_param) + in_label, in_image = create_semantic_data(input_param[4], input_param[3]) + with eval_mode(net): + if net.is_vae: + out, z_mu, z_logvar = net(in_label, in_image) + self.assertTrue(torch.all(torch.isfinite(z_mu))) + self.assertTrue(torch.all(torch.isfinite(z_logvar))) + else: + out = net(in_label, in_image) + out = out[0] + self.assertTrue(torch.all(torch.isfinite(out))) + self.assertEqual(list(out.shape), [1, 1, 64, 64, 64]) + + def test_shape_wrong(self): + """ + We input an input shape that isn't divisible by 2**(n downstream steps) + """ + with self.assertRaises(ValueError): + _ = SPADENet(1, 1, 8, [16, 16], [16, 32, 64, 128], 16, True) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_spatialattention.py b/tests/test_spatialattention.py new file mode 100644 index 0000000000..70b78263c5 --- /dev/null +++ b/tests/test_spatialattention.py @@ -0,0 +1,55 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest +from unittest import skipUnless + +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.blocks.spatialattention import SpatialAttentionBlock +from monai.utils import optional_import + +einops, has_einops = optional_import("einops") + +TEST_CASES = [ + [ + {"spatial_dims": 2, "num_channels": 128, "num_head_channels": 32, "norm_num_groups": 32, "norm_eps": 1e-6}, + (1, 128, 32, 32), + (1, 128, 32, 32), + ], + [ + {"spatial_dims": 3, "num_channels": 16, "num_head_channels": 8, "norm_num_groups": 8, "norm_eps": 1e-6}, + (1, 16, 8, 8, 8), + (1, 16, 8, 8, 8), + ], +] + + +class TestBlock(unittest.TestCase): + @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") + def test_shape(self, input_param, input_shape, expected_shape): + net = SpatialAttentionBlock(**input_param) + with eval_mode(net): + result = net(torch.randn(input_shape)) + self.assertEqual(result.shape, expected_shape) + + def test_attention_dim_not_multiple_of_heads(self): + with self.assertRaises(ValueError): + SpatialAttentionBlock(spatial_dims=2, num_channels=128, num_head_channels=33) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_synthetic.py b/tests/test_synthetic.py index 7db3c3e77a..4ab2144568 100644 --- a/tests/test_synthetic.py +++ b/tests/test_synthetic.py @@ -47,7 +47,7 @@ def test_create_test_image(self, dim, input_param, expected_img, expected_seg, e set_determinism(seed=0) if dim == 2: img, seg = create_test_image_2d(**input_param) - elif dim == 3: + else: # dim == 3 img, seg = create_test_image_3d(**input_param) self.assertEqual(img.shape, expected_shape) self.assertEqual(seg.max(), expected_max_cls) diff --git a/tests/test_transformer.py b/tests/test_transformer.py new file mode 100644 index 0000000000..b371809d47 --- /dev/null +++ b/tests/test_transformer.py @@ -0,0 +1,109 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import os +import tempfile +import unittest +from unittest import skipUnless + +import numpy as np +import torch +from parameterized import parameterized + +from monai.apps import download_url +from monai.networks import eval_mode +from monai.networks.nets import DecoderOnlyTransformer +from monai.utils import optional_import +from tests.utils import skip_if_downloading_fails, testing_data_config + +_, has_einops = optional_import("einops") +TEST_CASES = [] +for dropout_rate in np.linspace(0, 1, 2): + for attention_layer_dim in [360, 480, 600, 768]: + for num_heads in [4, 6, 8, 12]: + TEST_CASES.append( + [ + { + "num_tokens": 10, + "max_seq_len": 16, + "attn_layers_dim": attention_layer_dim, + "attn_layers_depth": 2, + "attn_layers_heads": num_heads, + "embedding_dropout_rate": dropout_rate, + } + ] + ) + + +class TestDecoderOnlyTransformer(unittest.TestCase): + @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") + def test_unconditioned_models(self, input_param): + net = DecoderOnlyTransformer(**input_param) + with eval_mode(net): + net.forward(torch.randint(0, 10, (1, 16))) + + @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") + def test_conditioned_models(self, input_param): + net = DecoderOnlyTransformer(**input_param, with_cross_attention=True) + with eval_mode(net): + net.forward(torch.randint(0, 10, (1, 16)), context=torch.randn(1, 3, input_param["attn_layers_dim"])) + + def test_attention_dim_not_multiple_of_heads(self): + with self.assertRaises(ValueError): + DecoderOnlyTransformer( + num_tokens=10, max_seq_len=16, attn_layers_dim=8, attn_layers_depth=2, attn_layers_heads=3 + ) + + @skipUnless(has_einops, "Requires einops") + def test_dropout_rate_negative(self): + + with self.assertRaises(ValueError): + DecoderOnlyTransformer( + num_tokens=10, + max_seq_len=16, + attn_layers_dim=8, + attn_layers_depth=2, + attn_layers_heads=2, + embedding_dropout_rate=-1, + ) + + @skipUnless(has_einops, "Requires einops") + def test_compatibility_with_monai_generative(self): + # test loading weights from a model saved in MONAI Generative, version 0.2.3 + with skip_if_downloading_fails(): + net = DecoderOnlyTransformer( + num_tokens=10, + max_seq_len=16, + attn_layers_dim=8, + attn_layers_depth=2, + attn_layers_heads=2, + with_cross_attention=True, + embedding_dropout_rate=0, + ) + + tmpdir = tempfile.mkdtemp() + key = "decoder_only_transformer_monai_generative_weights" + url = testing_data_config("models", key, "url") + hash_type = testing_data_config("models", key, "hash_type") + hash_val = testing_data_config("models", key, "hash_val") + filename = "decoder_only_transformer_monai_generative_weights.pt" + weight_path = os.path.join(tmpdir, filename) + download_url(url=url, filepath=weight_path, hash_val=hash_val, hash_type=hash_type) + + net.load_old_state_dict(torch.load(weight_path), verbose=False) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_transformerblock.py b/tests/test_transformerblock.py index 5a8dbba83c..a850cc6f74 100644 --- a/tests/test_transformerblock.py +++ b/tests/test_transformerblock.py @@ -12,6 +12,7 @@ from __future__ import annotations import unittest +from unittest import skipUnless import numpy as np import torch @@ -19,28 +20,33 @@ from monai.networks import eval_mode from monai.networks.blocks.transformerblock import TransformerBlock +from monai.utils import optional_import +einops, has_einops = optional_import("einops") TEST_CASE_TRANSFORMERBLOCK = [] for dropout_rate in np.linspace(0, 1, 4): for hidden_size in [360, 480, 600, 768]: for num_heads in [4, 8, 12]: for mlp_dim in [1024, 3072]: - test_case = [ - { - "hidden_size": hidden_size, - "num_heads": num_heads, - "mlp_dim": mlp_dim, - "dropout_rate": dropout_rate, - }, - (2, 512, hidden_size), - (2, 512, hidden_size), - ] - TEST_CASE_TRANSFORMERBLOCK.append(test_case) + for cross_attention in [False, True]: + test_case = [ + { + "hidden_size": hidden_size, + "num_heads": num_heads, + "mlp_dim": mlp_dim, + "dropout_rate": dropout_rate, + "with_cross_attention": cross_attention, + }, + (2, 512, hidden_size), + (2, 512, hidden_size), + ] + TEST_CASE_TRANSFORMERBLOCK.append(test_case) class TestTransformerBlock(unittest.TestCase): @parameterized.expand(TEST_CASE_TRANSFORMERBLOCK) + @skipUnless(has_einops, "Requires einops") def test_shape(self, input_param, input_shape, expected_shape): net = TransformerBlock(**input_param) with eval_mode(net): @@ -54,6 +60,7 @@ def test_ill_arg(self): with self.assertRaises(ValueError): TransformerBlock(hidden_size=622, num_heads=8, mlp_dim=3072, dropout_rate=0.4) + @skipUnless(has_einops, "Requires einops") def test_access_attn_matrix(self): # input format hidden_size = 128 diff --git a/tests/test_vector_quantizer.py b/tests/test_vector_quantizer.py new file mode 100644 index 0000000000..43533d0377 --- /dev/null +++ b/tests/test_vector_quantizer.py @@ -0,0 +1,89 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest +from math import prod + +import torch +from parameterized import parameterized + +from monai.networks.layers import EMAQuantizer, VectorQuantizer + +TEST_CASES = [ + [{"spatial_dims": 2, "num_embeddings": 16, "embedding_dim": 8}, (1, 8, 4, 4), (1, 4, 4)], + [{"spatial_dims": 3, "num_embeddings": 16, "embedding_dim": 8}, (1, 8, 4, 4, 4), (1, 4, 4, 4)], +] + + +class TestEMA(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_ema_shape(self, input_param, input_shape, output_shape): + layer = EMAQuantizer(**input_param) + x = torch.randn(input_shape) + layer = layer.train() + outputs = layer(x) + self.assertEqual(outputs[0].shape, input_shape) + self.assertEqual(outputs[2].shape, output_shape) + + layer = layer.eval() + outputs = layer(x) + self.assertEqual(outputs[0].shape, input_shape) + self.assertEqual(outputs[2].shape, output_shape) + + @parameterized.expand(TEST_CASES) + def test_ema_quantize(self, input_param, input_shape, output_shape): + layer = EMAQuantizer(**input_param) + x = torch.randn(input_shape) + outputs = layer.quantize(x) + self.assertEqual(outputs[0].shape, (prod(input_shape[2:]), input_shape[1])) # (HxW[xD], C) + self.assertEqual(outputs[1].shape, (prod(input_shape[2:]), input_param["num_embeddings"])) # (HxW[xD], E) + self.assertEqual(outputs[2].shape, (input_shape[0],) + input_shape[2:]) # (1, H, W, [D]) + + def test_ema(self): + layer = EMAQuantizer(spatial_dims=2, num_embeddings=2, embedding_dim=2, epsilon=0, decay=0) + original_weight_0 = layer.embedding.weight[0].clone() + original_weight_1 = layer.embedding.weight[1].clone() + x_0 = original_weight_0 + x_0 = x_0.unsqueeze(0).unsqueeze(-1).unsqueeze(-1) + x_0 = x_0.repeat(1, 1, 1, 2) + 0.001 + + x_1 = original_weight_1 + x_1 = x_1.unsqueeze(0).unsqueeze(-1).unsqueeze(-1) + x_1 = x_1.repeat(1, 1, 1, 2) + + x = torch.cat([x_0, x_1], dim=0) + layer = layer.train() + _ = layer(x) + + self.assertTrue(all(layer.embedding.weight[0] != original_weight_0)) + self.assertTrue(all(layer.embedding.weight[1] == original_weight_1)) + + +class TestVectorQuantizer(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_vector_quantizer_shape(self, input_param, input_shape, output_shape): + layer = VectorQuantizer(EMAQuantizer(**input_param)) + x = torch.randn(input_shape) + outputs = layer(x) + self.assertEqual(outputs[1].shape, input_shape) + + @parameterized.expand(TEST_CASES) + def test_vector_quantizer_quantize(self, input_param, input_shape, output_shape): + layer = VectorQuantizer(EMAQuantizer(**input_param)) + x = torch.randn(input_shape) + outputs = layer.quantize(x) + self.assertEqual(outputs.shape, output_shape) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_vis_cam.py b/tests/test_vis_cam.py index b641599af2..68b12de2f8 100644 --- a/tests/test_vis_cam.py +++ b/tests/test_vis_cam.py @@ -70,6 +70,8 @@ class TestClassActivationMap(unittest.TestCase): @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) def test_shape(self, input_data, expected_shape): + model = None + if input_data["model"] == "densenet2d": model = DenseNet121(spatial_dims=2, in_channels=1, out_channels=3) if input_data["model"] == "densenet3d": @@ -80,6 +82,7 @@ def test_shape(self, input_data, expected_shape): model = SEResNet50(spatial_dims=2, in_channels=3, num_classes=4) if input_data["model"] == "senet3d": model = SEResNet50(spatial_dims=3, in_channels=3, num_classes=4) + device = "cuda:0" if torch.cuda.is_available() else "cpu" model.to(device) model.eval() diff --git a/tests/test_vis_gradcam.py b/tests/test_vis_gradcam.py index 325b74b3ce..f77d916a5b 100644 --- a/tests/test_vis_gradcam.py +++ b/tests/test_vis_gradcam.py @@ -153,6 +153,8 @@ class TestGradientClassActivationMap(unittest.TestCase): @parameterized.expand(TESTS) def test_shape(self, cam_class, input_data, expected_shape): + model = None + if input_data["model"] == "densenet2d": model = DenseNet121(spatial_dims=2, in_channels=1, out_channels=3) elif input_data["model"] == "densenet2d_bin": diff --git a/tests/test_vqvae.py b/tests/test_vqvae.py new file mode 100644 index 0000000000..4916dc2faa --- /dev/null +++ b/tests/test_vqvae.py @@ -0,0 +1,274 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.nets.vqvae import VQVAE +from tests.utils import SkipIfBeforePyTorchVersion + +TEST_CASES = [ + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4), + "num_res_layers": 1, + "num_res_channels": (4, 4), + "downsample_parameters": ((2, 4, 1, 1),) * 2, + "upsample_parameters": ((2, 4, 1, 1, 0),) * 2, + "num_embeddings": 8, + "embedding_dim": 8, + }, + (1, 1, 8, 8), + (1, 1, 8, 8), + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4), + "num_res_layers": 1, + "num_res_channels": 4, + "downsample_parameters": ((2, 4, 1, 1),) * 2, + "upsample_parameters": ((2, 4, 1, 1, 0),) * 2, + "num_embeddings": 8, + "embedding_dim": 8, + }, + (1, 1, 8, 8, 8), + (1, 1, 8, 8, 8), + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4), + "num_res_layers": 1, + "num_res_channels": (4, 4), + "downsample_parameters": (2, 4, 1, 1), + "upsample_parameters": ((2, 4, 1, 1, 0),) * 2, + "num_embeddings": 8, + "embedding_dim": 8, + }, + (1, 1, 8, 8), + (1, 1, 8, 8), + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4), + "num_res_layers": 1, + "num_res_channels": (4, 4), + "downsample_parameters": ((2, 4, 1, 1),) * 2, + "upsample_parameters": (2, 4, 1, 1, 0), + "num_embeddings": 8, + "embedding_dim": 8, + }, + (1, 1, 8, 8, 8), + (1, 1, 8, 8, 8), + ], +] + +TEST_LATENT_SHAPE = { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "downsample_parameters": ((2, 4, 1, 1),) * 2, + "upsample_parameters": ((2, 4, 1, 1, 0),) * 2, + "num_res_layers": 1, + "channels": (8, 8), + "num_res_channels": (8, 8), + "num_embeddings": 16, + "embedding_dim": 8, +} + + +class TestVQVAE(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_shape(self, input_param, input_shape, expected_shape): + device = "cuda" if torch.cuda.is_available() else "cpu" + + net = VQVAE(**input_param).to(device) + + with eval_mode(net): + result, _ = net(torch.randn(input_shape).to(device)) + + self.assertEqual(result.shape, expected_shape) + + @parameterized.expand(TEST_CASES) + @SkipIfBeforePyTorchVersion((1, 11)) + def test_shape_with_checkpoint(self, input_param, input_shape, expected_shape): + device = "cuda" if torch.cuda.is_available() else "cpu" + input_param = input_param.copy() + input_param.update({"use_checkpointing": True}) + + net = VQVAE(**input_param).to(device) + + with eval_mode(net): + result, _ = net(torch.randn(input_shape).to(device)) + + self.assertEqual(result.shape, expected_shape) + + # Removed this test case since TorchScript currently does not support activation checkpoint. + # def test_script(self): + # net = VQVAE( + # spatial_dims=2, + # in_channels=1, + # out_channels=1, + # downsample_parameters=((2, 4, 1, 1),) * 2, + # upsample_parameters=((2, 4, 1, 1, 0),) * 2, + # num_res_layers=1, + # channels=(8, 8), + # num_res_channels=(8, 8), + # num_embeddings=16, + # embedding_dim=8, + # ddp_sync=False, + # ) + # test_data = torch.randn(1, 1, 16, 16) + # test_script_save(net, test_data) + + def test_channels_not_same_size_of_num_res_channels(self): + with self.assertRaises(ValueError): + VQVAE( + spatial_dims=2, + in_channels=1, + out_channels=1, + channels=(16, 16), + num_res_channels=(16, 16, 16), + downsample_parameters=((2, 4, 1, 1),) * 2, + upsample_parameters=((2, 4, 1, 1, 0),) * 2, + ) + + def test_channels_not_same_size_of_downsample_parameters(self): + with self.assertRaises(ValueError): + VQVAE( + spatial_dims=2, + in_channels=1, + out_channels=1, + channels=(16, 16), + num_res_channels=(16, 16), + downsample_parameters=((2, 4, 1, 1),) * 3, + upsample_parameters=((2, 4, 1, 1, 0),) * 2, + ) + + def test_channels_not_same_size_of_upsample_parameters(self): + with self.assertRaises(ValueError): + VQVAE( + spatial_dims=2, + in_channels=1, + out_channels=1, + channels=(16, 16), + num_res_channels=(16, 16), + downsample_parameters=((2, 4, 1, 1),) * 2, + upsample_parameters=((2, 4, 1, 1, 0),) * 3, + ) + + def test_downsample_parameters_not_sequence_or_int(self): + with self.assertRaises(ValueError): + VQVAE( + spatial_dims=2, + in_channels=1, + out_channels=1, + channels=(16, 16), + num_res_channels=(16, 16), + downsample_parameters=(("test", 4, 1, 1),) * 2, + upsample_parameters=((2, 4, 1, 1, 0),) * 2, + ) + + def test_upsample_parameters_not_sequence_or_int(self): + with self.assertRaises(ValueError): + VQVAE( + spatial_dims=2, + in_channels=1, + out_channels=1, + channels=(16, 16), + num_res_channels=(16, 16), + downsample_parameters=((2, 4, 1, 1),) * 2, + upsample_parameters=(("test", 4, 1, 1, 0),) * 2, + ) + + def test_downsample_parameter_length_different_4(self): + with self.assertRaises(ValueError): + VQVAE( + spatial_dims=2, + in_channels=1, + out_channels=1, + channels=(16, 16), + num_res_channels=(16, 16), + downsample_parameters=((2, 4, 1),) * 3, + upsample_parameters=((2, 4, 1, 1, 0),) * 2, + ) + + def test_upsample_parameter_length_different_5(self): + with self.assertRaises(ValueError): + VQVAE( + spatial_dims=2, + in_channels=1, + out_channels=1, + channels=(16, 16), + num_res_channels=(16, 16, 16), + downsample_parameters=((2, 4, 1, 1),) * 2, + upsample_parameters=((2, 4, 1, 1, 0, 1),) * 3, + ) + + def test_encode_shape(self): + device = "cuda" if torch.cuda.is_available() else "cpu" + + net = VQVAE(**TEST_LATENT_SHAPE).to(device) + + with eval_mode(net): + latent = net.encode(torch.randn(1, 1, 32, 32).to(device)) + + self.assertEqual(latent.shape, (1, 8, 8, 8)) + + def test_index_quantize_shape(self): + device = "cuda" if torch.cuda.is_available() else "cpu" + + net = VQVAE(**TEST_LATENT_SHAPE).to(device) + + with eval_mode(net): + latent = net.index_quantize(torch.randn(1, 1, 32, 32).to(device)) + + self.assertEqual(latent.shape, (1, 8, 8)) + + def test_decode_shape(self): + device = "cuda" if torch.cuda.is_available() else "cpu" + + net = VQVAE(**TEST_LATENT_SHAPE).to(device) + + with eval_mode(net): + latent = net.decode(torch.randn(1, 8, 8, 8).to(device)) + + self.assertEqual(latent.shape, (1, 1, 32, 32)) + + def test_decode_samples_shape(self): + device = "cuda" if torch.cuda.is_available() else "cpu" + + net = VQVAE(**TEST_LATENT_SHAPE).to(device) + + with eval_mode(net): + latent = net.decode_samples(torch.randint(low=0, high=16, size=(1, 8, 8)).to(device)) + + self.assertEqual(latent.shape, (1, 1, 32, 32)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_vqvaetransformer_inferer.py b/tests/test_vqvaetransformer_inferer.py new file mode 100644 index 0000000000..36b715f588 --- /dev/null +++ b/tests/test_vqvaetransformer_inferer.py @@ -0,0 +1,295 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest +from unittest import skipUnless + +import torch +from parameterized import parameterized + +from monai.inferers import VQVAETransformerInferer +from monai.networks.nets import VQVAE, DecoderOnlyTransformer +from monai.utils import optional_import +from monai.utils.ordering import Ordering, OrderingType + +einops, has_einops = optional_import("einops") +TEST_CASES = [ + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": (8, 8), + "num_res_channels": (8, 8), + "downsample_parameters": ((2, 4, 1, 1),) * 2, + "upsample_parameters": ((2, 4, 1, 1, 0),) * 2, + "num_res_layers": 1, + "num_embeddings": 16, + "embedding_dim": 8, + }, + { + "num_tokens": 16 + 1, + "max_seq_len": 4, + "attn_layers_dim": 4, + "attn_layers_depth": 2, + "attn_layers_heads": 1, + "with_cross_attention": False, + }, + {"ordering_type": OrderingType.RASTER_SCAN.value, "spatial_dims": 2, "dimensions": (2, 2, 2)}, + (2, 1, 8, 8), + (2, 4, 17), + (2, 2, 2), + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "channels": (8, 8), + "num_res_channels": (8, 8), + "downsample_parameters": ((2, 4, 1, 1),) * 2, + "upsample_parameters": ((2, 4, 1, 1, 0),) * 2, + "num_res_layers": 1, + "num_embeddings": 16, + "embedding_dim": 8, + }, + { + "num_tokens": 16 + 1, + "max_seq_len": 8, + "attn_layers_dim": 4, + "attn_layers_depth": 2, + "attn_layers_heads": 1, + "with_cross_attention": False, + }, + {"ordering_type": OrderingType.RASTER_SCAN.value, "spatial_dims": 3, "dimensions": (2, 2, 2, 2)}, + (2, 1, 8, 8, 8), + (2, 8, 17), + (2, 2, 2, 2), + ], +] + + +class TestVQVAETransformerInferer(unittest.TestCase): + @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") + def test_prediction_shape( + self, stage_1_params, stage_2_params, ordering_params, input_shape, logits_shape, latent_shape + ): + stage_1 = VQVAE(**stage_1_params) + stage_2 = DecoderOnlyTransformer(**stage_2_params) + ordering = Ordering(**ordering_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + + input = torch.randn(input_shape).to(device) + + inferer = VQVAETransformerInferer() + prediction = inferer(inputs=input, vqvae_model=stage_1, transformer_model=stage_2, ordering=ordering) + self.assertEqual(prediction.shape, logits_shape) + + @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") + def test_prediction_shape_shorter_sequence( + self, stage_1_params, stage_2_params, ordering_params, input_shape, logits_shape, latent_shape + ): + stage_1 = VQVAE(**stage_1_params) + max_seq_len = 3 + stage_2_params_shorter = dict(stage_2_params) + stage_2_params_shorter["max_seq_len"] = max_seq_len + stage_2 = DecoderOnlyTransformer(**stage_2_params_shorter) + ordering = Ordering(**ordering_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + + input = torch.randn(input_shape).to(device) + + inferer = VQVAETransformerInferer() + prediction = inferer(inputs=input, vqvae_model=stage_1, transformer_model=stage_2, ordering=ordering) + cropped_logits_shape = (logits_shape[0], max_seq_len, logits_shape[2]) + self.assertEqual(prediction.shape, cropped_logits_shape) + + @skipUnless(has_einops, "Requires einops") + def test_sample(self): + + stage_1 = VQVAE( + spatial_dims=2, + in_channels=1, + out_channels=1, + channels=(8, 8), + num_res_channels=(8, 8), + downsample_parameters=((2, 4, 1, 1),) * 2, + upsample_parameters=((2, 4, 1, 1, 0),) * 2, + num_res_layers=1, + num_embeddings=16, + embedding_dim=8, + ) + stage_2 = DecoderOnlyTransformer( + num_tokens=16 + 1, + max_seq_len=4, + attn_layers_dim=4, + attn_layers_depth=2, + attn_layers_heads=1, + with_cross_attention=False, + ) + ordering = Ordering(ordering_type=OrderingType.RASTER_SCAN.value, spatial_dims=2, dimensions=(2, 2, 2)) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + + inferer = VQVAETransformerInferer() + + starting_token = 16 # from stage_1 num_embeddings + + sample = inferer.sample( + latent_spatial_dim=(2, 2), + starting_tokens=starting_token * torch.ones((2, 1), device=device), + vqvae_model=stage_1, + transformer_model=stage_2, + ordering=ordering, + ) + self.assertEqual(sample.shape, (2, 1, 8, 8)) + + @skipUnless(has_einops, "Requires einops") + def test_sample_shorter_sequence(self): + stage_1 = VQVAE( + spatial_dims=2, + in_channels=1, + out_channels=1, + channels=(8, 8), + num_res_channels=(8, 8), + downsample_parameters=((2, 4, 1, 1),) * 2, + upsample_parameters=((2, 4, 1, 1, 0),) * 2, + num_res_layers=1, + num_embeddings=16, + embedding_dim=8, + ) + stage_2 = DecoderOnlyTransformer( + num_tokens=16 + 1, + max_seq_len=2, + attn_layers_dim=4, + attn_layers_depth=2, + attn_layers_heads=1, + with_cross_attention=False, + ) + ordering = Ordering(ordering_type=OrderingType.RASTER_SCAN.value, spatial_dims=2, dimensions=(2, 2, 2)) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + + inferer = VQVAETransformerInferer() + + starting_token = 16 # from stage_1 num_embeddings + + sample = inferer.sample( + latent_spatial_dim=(2, 2), + starting_tokens=starting_token * torch.ones((2, 1), device=device), + vqvae_model=stage_1, + transformer_model=stage_2, + ordering=ordering, + ) + self.assertEqual(sample.shape, (2, 1, 8, 8)) + + @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") + def test_get_likelihood( + self, stage_1_params, stage_2_params, ordering_params, input_shape, logits_shape, latent_shape + ): + stage_1 = VQVAE(**stage_1_params) + stage_2 = DecoderOnlyTransformer(**stage_2_params) + ordering = Ordering(**ordering_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + + input = torch.randn(input_shape).to(device) + + inferer = VQVAETransformerInferer() + likelihood = inferer.get_likelihood( + inputs=input, vqvae_model=stage_1, transformer_model=stage_2, ordering=ordering + ) + self.assertEqual(likelihood.shape, latent_shape) + + @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") + def test_get_likelihood_shorter_sequence( + self, stage_1_params, stage_2_params, ordering_params, input_shape, logits_shape, latent_shape + ): + stage_1 = VQVAE(**stage_1_params) + max_seq_len = 3 + stage_2_params_shorter = dict(stage_2_params) + stage_2_params_shorter["max_seq_len"] = max_seq_len + stage_2 = DecoderOnlyTransformer(**stage_2_params_shorter) + ordering = Ordering(**ordering_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + + input = torch.randn(input_shape).to(device) + + inferer = VQVAETransformerInferer() + likelihood = inferer.get_likelihood( + inputs=input, vqvae_model=stage_1, transformer_model=stage_2, ordering=ordering + ) + self.assertEqual(likelihood.shape, latent_shape) + + @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") + def test_get_likelihood_resampling( + self, stage_1_params, stage_2_params, ordering_params, input_shape, logits_shape, latent_shape + ): + stage_1 = VQVAE(**stage_1_params) + stage_2 = DecoderOnlyTransformer(**stage_2_params) + ordering = Ordering(**ordering_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + + input = torch.randn(input_shape).to(device) + + inferer = VQVAETransformerInferer() + likelihood = inferer.get_likelihood( + inputs=input, + vqvae_model=stage_1, + transformer_model=stage_2, + ordering=ordering, + resample_latent_likelihoods=True, + resample_interpolation_mode="nearest", + ) + self.assertEqual(likelihood.shape, input_shape) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/testing_data/data_config.json b/tests/testing_data/data_config.json index a570c787ba..8b1d2868b7 100644 --- a/tests/testing_data/data_config.json +++ b/tests/testing_data/data_config.json @@ -138,6 +138,26 @@ "url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/ssl_pretrained_weights.pth", "hash_type": "sha256", "hash_val": "c3564f40a6a051d3753a6d8fae5cc8eaf21ce8d82a9a3baf80748d15664055e8" + }, + "decoder_only_transformer_monai_generative_weights": { + "url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/decoder_only_transformer.pth", + "hash_type": "sha256", + "hash_val": "f93de37d64d77cf91f3bde95cdf93d161aee800074c89a92aff9d5699120ec0d" + }, + "diffusion_model_unet_monai_generative_weights": { + "url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/diffusion_model_unet.pth", + "hash_type": "sha256", + "hash_val": "0d2171b386902f5b4fd3e967b4024f63e353694ca45091b114970019d045beee" + }, + "autoencoderkl_monai_generative_weights": { + "url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/autoencoderkl.pth", + "hash_type": "sha256", + "hash_val": "6e02c9540c51b16b9ba98b5c0c75d6b84b430afe9a3237df1d67a520f8d34184" + }, + "controlnet_monai_generative_weights": { + "url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/controlnet.pth", + "hash_type": "sha256", + "hash_val": "cd100d0c69f47569ae5b4b7df653a1cb19f5e02eff1630db3210e2646fb1ab2e" } }, "configs": {