diff --git a/doc/source/ray-air/examples/convert_existing_pytorch_code_to_ray_air.ipynb b/doc/source/ray-air/examples/convert_existing_pytorch_code_to_ray_air.ipynb
index effd3302b67b..e458fd3f53c0 100644
--- a/doc/source/ray-air/examples/convert_existing_pytorch_code_to_ray_air.ipynb
+++ b/doc/source/ray-air/examples/convert_existing_pytorch_code_to_ray_air.ipynb
@@ -791,13 +791,11 @@
     "        from ray.air import Checkpoint\n",
     "\n",
     "        checkpoint = Checkpoint.from_dict(\n",
-    "            dict(epoch=t, model=model.module.state_dict())\n",
+    "            dict(epoch=t, model=model.state_dict())\n",
     "        )\n",
     "        session.report(dict(loss=test_loss), checkpoint=checkpoint)\n",
     "```\n",
     "\n",
-    "Note that the `model.module` part is needed because the model gets wrapped in `torch.nn.DistributedDataParallel` by `train.torch.prepare_model`.\n",
-    "\n",
     "### Move the data loader to the training function\n",
     "\n",
     "You may have noticed a warning: `Warning: The actor TrainTrainable is very large (52 MiB). Check that its definition is not implicitly capturing a large array or other object in scope. Tip: use ray.put() to put large objects in the Ray object store.`.\n",
@@ -861,7 +859,7 @@
     "        train_epoch(train_dataloader, model, loss_fn, optimizer)\n",
     "        test_loss = test_epoch(test_dataloader, model, loss_fn)\n",
     "        checkpoint = Checkpoint.from_dict(\n",
-    "            dict(epoch=t, model=model.module.state_dict())\n",
+    "            dict(epoch=t, model=model.state_dict())\n",
     "        )\n",
     "        session.report(dict(loss=test_loss), checkpoint=checkpoint)\n",
     "\n",
@@ -1302,7 +1300,7 @@
  ],
  "metadata": {
   "kernelspec": {
-   "display_name": "Python 3.9.12 ('.venv': venv)",
+   "display_name": "Python 3.8.10 ('venv': venv)",
    "language": "python",
    "name": "python3"
   },
@@ -1316,11 +1314,11 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.9.12"
+   "version": "3.8.10"
   },
   "vscode": {
    "interpreter": {
-    "hash": "a658351b4133f922c5967ed6133cfc05c9f16c53a5161e5843ace3f528fccaf5"
+    "hash": "3c0d54d489a08ae47a06eae2fd00ff032d6cddb527c382959b7b2575f6a8167f"
   }
  }
 },
diff --git a/doc/source/ray-air/examples/torch_image_example.ipynb b/doc/source/ray-air/examples/torch_image_example.ipynb
index 7c8039b539cb..64f1ba04ca59 100644
--- a/doc/source/ray-air/examples/torch_image_example.ipynb
+++ b/doc/source/ray-air/examples/torch_image_example.ipynb
@@ -200,7 +200,7 @@
     "\n",
     "train_dataset = train_dataset.map_batches(convert_batch_to_numpy)\n",
     "test_dataset = test_dataset.map_batches(convert_batch_to_numpy)"
-   ]
+   ]
  },
  {
   "cell_type": "code",
@@ -328,7 +328,7 @@
     "                running_loss = 0.0\n",
     "\n",
     "        metrics = dict(running_loss=running_loss)\n",
-    "        checkpoint = TorchCheckpoint.from_state_dict(model.module.state_dict())\n",
+    "        checkpoint = TorchCheckpoint.from_state_dict(model.state_dict())\n",
     "        session.report(metrics, checkpoint=checkpoint)"
   ]
  },
@@ -810,7 +810,7 @@
  ],
  "metadata": {
   "kernelspec": {
-   "display_name": "Python 3.10.8 ('.venv': venv)",
+   "display_name": "Python 3.8.10 ('venv': venv)",
    "language": "python",
    "name": "python3"
   },
@@ -824,11 +824,11 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.10.8"
+   "version": "3.8.10"
   },
   "vscode": {
    "interpreter": {
-    "hash": "c704e19737f24b51bc631dadcac7a7e356bb35d1c5cd7766248d8a6946059909"
+    "hash": "3c0d54d489a08ae47a06eae2fd00ff032d6cddb527c382959b7b2575f6a8167f"
   }
  }
 },
diff --git a/doc/source/ray-air/examples/torch_incremental_learning.ipynb b/doc/source/ray-air/examples/torch_incremental_learning.ipynb
index 96f83f02e3fe..e959eced3463 100644
--- a/doc/source/ray-air/examples/torch_incremental_learning.ipynb
+++ b/doc/source/ray-air/examples/torch_incremental_learning.ipynb
@@ -439,8 +439,6 @@
     "from torch.optim import SGD\n",
     "from torch.nn import CrossEntropyLoss\n",
     "\n",
-    "from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present\n",
-    "\n",
     "def train_loop_per_worker(config: dict):\n",
     "    num_epochs = config[\"num_epochs\"]\n",
     "    learning_rate = config[\"learning_rate\"]\n",
@@ -488,7 +486,6 @@
     "\n",
     "        # Checkpoint model after every epoch.\n",
     "        state_dict = model.state_dict()\n",
-    "        consume_prefix_in_state_dict_if_present(state_dict, \"module.\")\n",
     "        checkpoint = Checkpoint.from_dict(dict(model=state_dict))\n",
     "        session.report({\"loss\": running_loss}, checkpoint=checkpoint)"
   ]
diff --git a/doc/source/ray-core/_examples/datasets_train/datasets_train.py b/doc/source/ray-core/_examples/datasets_train/datasets_train.py
index 0f195612b65d..111e4d81a8dd 100644
--- a/doc/source/ray-core/_examples/datasets_train/datasets_train.py
+++ b/doc/source/ray-core/_examples/datasets_train/datasets_train.py
@@ -23,7 +23,6 @@
 import torch
 import torch.nn as nn
 import torch.optim as optim
-from torch.nn.parallel import DistributedDataParallel

 import ray
 from ray import train
@@ -455,8 +454,7 @@ def train_func(config):
         )

         # Checkpoint model.
-        module = net.module if isinstance(net, DistributedDataParallel) else net
-        checkpoint = Checkpoint.from_dict(dict(model=module.state_dict()))
+        checkpoint = Checkpoint.from_dict(dict(model=net.state_dict()))

         # Record and log stats.
         print(f"session report on {session.get_world_rank()}")
diff --git a/doc/source/train/dl_guide.rst b/doc/source/train/dl_guide.rst
index 00ca8a8df6a1..ef31e82acc1b 100644
--- a/doc/source/train/dl_guide.rst
+++ b/doc/source/train/dl_guide.rst
@@ -527,7 +527,6 @@ appropriately in distributed training.

     import torch
     import torch.nn as nn
-    from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
     from torch.optim import Adam
     import numpy as np

@@ -552,12 +551,7 @@
             optimizer.zero_grad()
             loss.backward()
             optimizer.step()
-            # To fetch non-DDP state_dict
-            # w/o DDP: model.state_dict()
-            # w/ DDP: model.module.state_dict()
-            # See: https://github.com/ray-project/ray/issues/20915
             state_dict = model.state_dict()
-            consume_prefix_in_state_dict_if_present(state_dict, "module.")
             checkpoint = Checkpoint.from_dict(
                 dict(epoch=epoch, model_weights=state_dict)
             )
@@ -714,7 +708,6 @@ Checkpoints can be loaded into the training function in 2 steps:

     import torch
     import torch.nn as nn
-    from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
     from torch.optim import Adam
     import numpy as np

@@ -751,7 +744,6 @@
             loss.backward()
             optimizer.step()
             state_dict = model.state_dict()
-            consume_prefix_in_state_dict_if_present(state_dict, "module.")
             checkpoint = Checkpoint.from_dict(
                 dict(epoch=epoch, model_weights=state_dict)
             )
diff --git a/python/ray/air/_internal/torch_utils.py b/python/ray/air/_internal/torch_utils.py
index 5f4a0925575b..15be5b1d5d34 100644
--- a/python/ray/air/_internal/torch_utils.py
+++ b/python/ray/air/_internal/torch_utils.py
@@ -1,5 +1,5 @@
 import warnings
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Optional, Union, Any

 import numpy as np
 import pandas as pd
@@ -227,3 +227,45 @@ def contains_tensor(obj):
             if contains_tensor(v):
                 return True
     return False
+
+
+# Not present in torch<=1.7.0
+# Adapted from https://github.com/pytorch/pytorch/blob/\
+# c18da597e0bb1c1aecc97c77a73fed1849057fa4/torch/nn/modules/utils.py
+def consume_prefix_in_state_dict_if_present_not_in_place(
+    state_dict: Dict[str, Any], prefix: str
+) -> Dict[str, Any]:
+    """Strip the prefix in state_dict, if present, and return a new dict.
+
+    Adapted from https://github.com/pytorch/pytorch/blob/\
+c18da597e0bb1c1aecc97c77a73fed1849057fa4/torch/nn/modules/utils.py
+    The original method modifies the dict in-place.
+
+    Args:
+        state_dict: a state-dict to be loaded to the model.
+        prefix: the prefix to strip from the keys (e.g. ``"module."``).
+
+    """
+    copied = False
+
+    for key in state_dict:
+        if key.startswith(prefix):
+            newkey = key[len(prefix) :]
+            if not copied:
+                # We are doing shallow copies here, so the performance
+                # impact should be negligible anyway, but this is
+                # a simple optimization.
+                state_dict = state_dict.copy()
+                copied = True
+            state_dict[newkey] = state_dict.pop(key)
+
+    if "_metadata" in state_dict:
+        state_dict["_metadata"] = state_dict["_metadata"].copy()
+        metadata = state_dict["_metadata"]
+        for key in metadata:
+            if len(key) == 0:
+                continue
+            newkey = key[len(prefix) :]
+            metadata[newkey] = metadata.pop(key)
+
+    return state_dict
diff --git a/python/ray/train/tests/test_torch_trainer.py b/python/ray/train/tests/test_torch_trainer.py
index 7297d64e8dc3..b7692e7ef446 100644
--- a/python/ray/train/tests/test_torch_trainer.py
+++ b/python/ray/train/tests/test_torch_trainer.py
@@ -61,9 +61,12 @@ def train_func(config):
     trainer.fit()


-def test_torch_e2e(ray_start_4_cpus):
+@pytest.mark.parametrize("prepare_model", (True, False))
+def test_torch_e2e(ray_start_4_cpus, prepare_model):
     def train_func():
         model = torch.nn.Linear(3, 1)
+        if prepare_model:
+            model = train.torch.prepare_model(model)
         session.report({}, checkpoint=TorchCheckpoint.from_model(model))

     scaling_config = ScalingConfig(num_workers=2)
@@ -83,10 +86,15 @@ def train_func():
     assert predictions.count() == 3


-def test_torch_e2e_state_dict(ray_start_4_cpus):
+@pytest.mark.parametrize("prepare_model", (True, False))
+def test_torch_e2e_state_dict(ray_start_4_cpus, prepare_model):
     def train_func():
-        model = torch.nn.Linear(3, 1).state_dict()
-        session.report({}, checkpoint=TorchCheckpoint.from_state_dict(model))
+        model = torch.nn.Linear(3, 1)
+        if prepare_model:
+            model = train.torch.prepare_model(model)
+        session.report(
+            {}, checkpoint=TorchCheckpoint.from_state_dict(model.state_dict())
+        )

     scaling_config = ScalingConfig(num_workers=2)
     trainer = TorchTrainer(
@@ -111,6 +119,8 @@ def train_func():
     assert predictions.count() == 3


+# We can't really test for prepare_model here as we can't detect what the user
+# has saved without loading (and thus triggering the exception anyway)
 def test_torch_e2e_dir(ray_start_4_cpus, tmpdir):
     def train_func():
         model = torch.nn.Linear(3, 1)
diff --git a/python/ray/train/torch/torch_checkpoint.py b/python/ray/train/torch/torch_checkpoint.py
index 5b87ab1edb8c..68091a1c5196 100644
--- a/python/ray/train/torch/torch_checkpoint.py
+++ b/python/ray/train/torch/torch_checkpoint.py
@@ -4,11 +4,16 @@
 import torch
 import warnings

+from torch.nn import Module
+
 import ray.cloudpickle
 from ray.air.checkpoint import Checkpoint
 from ray.air.constants import MODEL_KEY, PREPROCESSOR_KEY
 from ray.train.data_parallel_trainer import _load_checkpoint_dict
-from ray.air._internal.torch_utils import load_torch_model
+from ray.air._internal.torch_utils import (
+    load_torch_model,
+    consume_prefix_in_state_dict_if_present_not_in_place,
+)
 from ray.util.annotations import PublicAPI

 if TYPE_CHECKING:
@@ -28,11 +33,19 @@ class TorchCheckpoint(Checkpoint):
     # Special encoding logic to avoid serialization errors with torch.
     def _encode_data_dict(self, data_dict: dict) -> dict:
         """Encode data_dict using torch.save."""
-        from torch.nn.parallel import DistributedDataParallel

         for k, v in data_dict.items():
-            if isinstance(v, DistributedDataParallel) and hasattr(v, "module"):
+            # Only check for attribute as we want to support
+            # DDP, FSDP and any future approaches
+            if isinstance(v, Module) and hasattr(v, "module"):
                 data_dict[k] = v.module
+            elif isinstance(v, dict):
+                # We could limit this only to the MODEL_KEY, but we'd
+                # miss any extra user-specified keys. This should be a
+                # noop with anything but DDP/FSDP module state dicts.
+                data_dict[k] = consume_prefix_in_state_dict_if_present_not_in_place(
+                    v, "module."
+                )

         # Convert the checkpoint dict to bytes, so that any GPU tensors that
         # are in the checkpoint dict can be properly deserialized on the
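
Editor's note, outside the patch: a minimal sketch of how the helper added in python/ray/air/_internal/torch_utils.py is expected to behave once this change is applied. The state-dict keys and values below are invented for illustration; only the import path and function name come from the patch above.

    from ray.air._internal.torch_utils import (
        consume_prefix_in_state_dict_if_present_not_in_place,
    )

    # A DDP/FSDP-style state dict whose keys carry the "module." prefix.
    wrapped = {"module.linear.weight": [0.0], "module.linear.bias": [0.0]}

    unwrapped = consume_prefix_in_state_dict_if_present_not_in_place(wrapped, "module.")
    assert sorted(unwrapped) == ["linear.bias", "linear.weight"]

    # Unlike torch's in-place consume_prefix_in_state_dict_if_present,
    # the input dict is left untouched.
    assert sorted(wrapped) == ["module.linear.bias", "module.linear.weight"]

    # A plain (non-wrapped) state dict passes through as a no-op.
    plain = {"linear.weight": [0.0]}
    assert consume_prefix_in_state_dict_if_present_not_in_place(plain, "module.") == plain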