[Train] Strip "module." from state dict (ray-project#30705)
This PR adds logic to automatically strip the "module." prefix from a user-saved state dict in TorchCheckpoint. That prefix is present if a user obtains the state dict directly from a DistributedDataParallel module. We already unwrap the underlying module when a user saves the model object itself, so this change simply makes the two code paths consistent.

This PR also edits our examples to remove the places where this stripping was done in the example itself. That pattern caused issues when train.torch.prepare_model was used with num_workers=1 (e.g. on Google Colab), because the model was not wrapped in DistributedDataParallel and therefore had no .module attribute.
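
For illustration, a minimal sketch of the user-facing effect (a hypothetical training loop; the session/checkpoint/prepare_model APIs are the ones used in the examples below, everything else is elided):

    import torch
    from ray.air import session
    from ray.train.torch import TorchCheckpoint, prepare_model

    def train_loop_per_worker(config: dict):
        model = prepare_model(torch.nn.Linear(3, 1))
        # ... training loop elided ...
        # With num_workers > 1, prepare_model wraps the model in
        # DistributedDataParallel, so the state dict keys get a "module."
        # prefix. After this change, TorchCheckpoint strips that prefix
        # automatically, so the same call works whether or not the model was
        # wrapped; no manual model.module.state_dict() is needed.
        checkpoint = TorchCheckpoint.from_state_dict(model.state_dict())
        session.report({"loss": 0.0}, checkpoint=checkpoint)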

Signed-off-by: Antoni Baum <[email protected]>
Signed-off-by: tmynn <[email protected]>
Yard1 authored and tamohannes committed Jan 25, 2023
1 parent 0bf3e93 commit c8dd801
Showing 8 changed files with 84 additions and 34 deletions.
@@ -791,13 +791,11 @@
" from ray.air import Checkpoint\n",
"\n",
" checkpoint = Checkpoint.from_dict(\n",
" dict(epoch=t, model=model.module.state_dict())\n",
" dict(epoch=t, model=model.state_dict())\n",
" )\n",
" session.report(dict(loss=test_loss), checkpoint=checkpoint)\n",
"```\n",
"\n",
"Note that the `model.module` part is needed because the model gets wrapped in `torch.nn.DistributedDataParallel` by `train.torch.prepare_model`.\n",
"\n",
"### Move the data loader to the training function\n",
"\n",
"You may have noticed a warning: `Warning: The actor TrainTrainable is very large (52 MiB). Check that its definition is not implicitly capturing a large array or other object in scope. Tip: use ray.put() to put large objects in the Ray object store.`.\n",
@@ -861,7 +859,7 @@
" train_epoch(train_dataloader, model, loss_fn, optimizer)\n",
" test_loss = test_epoch(test_dataloader, model, loss_fn)\n",
" checkpoint = Checkpoint.from_dict(\n",
" dict(epoch=t, model=model.module.state_dict())\n",
" dict(epoch=t, model=model.state_dict())\n",
" )\n",
" session.report(dict(loss=test_loss), checkpoint=checkpoint)\n",
"\n",
@@ -1302,7 +1300,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.9.12 ('.venv': venv)",
"display_name": "Python 3.8.10 ('venv': venv)",
"language": "python",
"name": "python3"
},
@@ -1316,11 +1314,11 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.12"
"version": "3.8.10"
},
"vscode": {
"interpreter": {
"hash": "a658351b4133f922c5967ed6133cfc05c9f16c53a5161e5843ace3f528fccaf5"
"hash": "3c0d54d489a08ae47a06eae2fd00ff032d6cddb527c382959b7b2575f6a8167f"
}
}
},
10 changes: 5 additions & 5 deletions doc/source/ray-air/examples/torch_image_example.ipynb
@@ -200,7 +200,7 @@
"\n",
"train_dataset = train_dataset.map_batches(convert_batch_to_numpy)\n",
"test_dataset = test_dataset.map_batches(convert_batch_to_numpy)"
]
]
},
{
"cell_type": "code",
@@ -328,7 +328,7 @@
" running_loss = 0.0\n",
"\n",
" metrics = dict(running_loss=running_loss)\n",
" checkpoint = TorchCheckpoint.from_state_dict(model.module.state_dict())\n",
" checkpoint = TorchCheckpoint.from_state_dict(model.state_dict())\n",
" session.report(metrics, checkpoint=checkpoint)"
]
},
@@ -810,7 +810,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.10.8 ('.venv': venv)",
"display_name": "Python 3.8.10 ('venv': venv)",
"language": "python",
"name": "python3"
},
@@ -824,11 +824,11 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.8"
"version": "3.8.10"
},
"vscode": {
"interpreter": {
"hash": "c704e19737f24b51bc631dadcac7a7e356bb35d1c5cd7766248d8a6946059909"
"hash": "3c0d54d489a08ae47a06eae2fd00ff032d6cddb527c382959b7b2575f6a8167f"
}
}
},
3 changes: 0 additions & 3 deletions doc/source/ray-air/examples/torch_incremental_learning.ipynb
@@ -439,8 +439,6 @@
"from torch.optim import SGD\n",
"from torch.nn import CrossEntropyLoss\n",
"\n",
"from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present\n",
"\n",
"def train_loop_per_worker(config: dict):\n",
" num_epochs = config[\"num_epochs\"]\n",
" learning_rate = config[\"learning_rate\"]\n",
@@ -488,7 +486,6 @@
"\n",
" # Checkpoint model after every epoch.\n",
" state_dict = model.state_dict()\n",
" consume_prefix_in_state_dict_if_present(state_dict, \"module.\")\n",
" checkpoint = Checkpoint.from_dict(dict(model=state_dict))\n",
" session.report({\"loss\": running_loss}, checkpoint=checkpoint)"
]
@@ -23,7 +23,6 @@
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel

import ray
from ray import train
@@ -455,8 +454,7 @@ def train_func(config):
)

# Checkpoint model.
module = net.module if isinstance(net, DistributedDataParallel) else net
checkpoint = Checkpoint.from_dict(dict(model=module.state_dict()))
checkpoint = Checkpoint.from_dict(dict(model=net.state_dict()))

# Record and log stats.
print(f"session report on {session.get_world_rank()}")
8 changes: 0 additions & 8 deletions doc/source/train/dl_guide.rst
@@ -527,7 +527,6 @@ appropriately in distributed training.
import torch
import torch.nn as nn
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
from torch.optim import Adam
import numpy as np
@@ -552,12 +551,7 @@ appropriately in distributed training.
optimizer.zero_grad()
loss.backward()
optimizer.step()
# To fetch non-DDP state_dict
# w/o DDP: model.state_dict()
# w/ DDP: model.module.state_dict()
# See: https://github.com/ray-project/ray/issues/20915
state_dict = model.state_dict()
consume_prefix_in_state_dict_if_present(state_dict, "module.")
checkpoint = Checkpoint.from_dict(
dict(epoch=epoch, model_weights=state_dict)
)
@@ -714,7 +708,6 @@ Checkpoints can be loaded into the training function in 2 steps:
import torch
import torch.nn as nn
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
from torch.optim import Adam
import numpy as np
@@ -751,7 +744,6 @@ Checkpoints can be loaded into the training function in 2 steps:
loss.backward()
optimizer.step()
state_dict = model.state_dict()
consume_prefix_in_state_dict_if_present(state_dict, "module.")
checkpoint = Checkpoint.from_dict(
dict(epoch=epoch, model_weights=state_dict)
)
44 changes: 43 additions & 1 deletion python/ray/air/_internal/torch_utils.py
@@ -1,5 +1,5 @@
import warnings
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional, Union, Any

import numpy as np
import pandas as pd
@@ -227,3 +227,45 @@ def contains_tensor(obj):
if contains_tensor(v):
return True
return False


# Not present in torch<=1.7.0
# Adapted from https://github.com/pytorch/pytorch/blob/\
# c18da597e0bb1c1aecc97c77a73fed1849057fa4/torch/nn/modules/utils.py
def consume_prefix_in_state_dict_if_present_not_in_place(
state_dict: Dict[str, Any], prefix: str
) -> Dict[str, Any]:
"""Strip the prefix in state_dict, if any and return a new dict.
Adapted from https://github.com/pytorch/pytorch/blob/\
c18da597e0bb1c1aecc97c77a73fed1849057fa4/torch/nn/modules/utils.py
The original method modified the dict in-place.
Args:
state_dict: a state-dict to be loaded to the model.
prefix: prefix.
"""
copied = False

for key in state_dict:
if key.startswith(prefix):
newkey = key[len(prefix) :]
if not copied:
# We are doing shallow copies here, so the performance
# impact should be negligible anyway, but this is
# a simple optimization.
state_dict = state_dict.copy()
copied = True
state_dict[newkey] = state_dict.pop(key)

if "_metadata" in state_dict:
state_dict["_metadata"] = state_dict["_metadata"].copy()
metadata = state_dict["_metadata"]
for key in metadata:
if len(key) == 0:
continue
newkey = key[len(prefix) :]
metadata[newkey] = metadata.pop(key)

return state_dict
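
For reference, a small hypothetical usage of the new helper (not part of the diff): unlike the upstream in-place version, it returns a new dict and leaves the input untouched.

    from ray.air._internal.torch_utils import (
        consume_prefix_in_state_dict_if_present_not_in_place,
    )

    ddp_state_dict = {"module.weight": 1, "module.bias": 2}
    stripped = consume_prefix_in_state_dict_if_present_not_in_place(
        ddp_state_dict, "module."
    )
    assert stripped == {"weight": 1, "bias": 2}
    # The original (DDP-style) dict is not modified.
    assert "module.weight" in ddp_state_dict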
18 changes: 14 additions & 4 deletions python/ray/train/tests/test_torch_trainer.py
@@ -61,9 +61,12 @@ def train_func(config):
trainer.fit()


def test_torch_e2e(ray_start_4_cpus):
@pytest.mark.parametrize("prepare_model", (True, False))
def test_torch_e2e(ray_start_4_cpus, prepare_model):
def train_func():
model = torch.nn.Linear(3, 1)
if prepare_model:
model = train.torch.prepare_model(model)
session.report({}, checkpoint=TorchCheckpoint.from_model(model))

scaling_config = ScalingConfig(num_workers=2)
@@ -83,10 +86,15 @@ def train_func():
assert predictions.count() == 3


def test_torch_e2e_state_dict(ray_start_4_cpus):
@pytest.mark.parametrize("prepare_model", (True, False))
def test_torch_e2e_state_dict(ray_start_4_cpus, prepare_model):
def train_func():
model = torch.nn.Linear(3, 1).state_dict()
session.report({}, checkpoint=TorchCheckpoint.from_state_dict(model))
model = torch.nn.Linear(3, 1)
if prepare_model:
model = train.torch.prepare_model(model)
session.report(
{}, checkpoint=TorchCheckpoint.from_state_dict(model.state_dict())
)

scaling_config = ScalingConfig(num_workers=2)
trainer = TorchTrainer(
@@ -111,6 +119,8 @@ def train_func():
assert predictions.count() == 3


# We can't really test for prepare_model here as we can't detect what the user
# has saved without loading (and thus triggering the exception anyway)
def test_torch_e2e_dir(ray_start_4_cpus, tmpdir):
def train_func():
model = torch.nn.Linear(3, 1)
19 changes: 16 additions & 3 deletions python/ray/train/torch/torch_checkpoint.py
@@ -4,11 +4,16 @@
import torch
import warnings

from torch.nn import Module

import ray.cloudpickle
from ray.air.checkpoint import Checkpoint
from ray.air.constants import MODEL_KEY, PREPROCESSOR_KEY
from ray.train.data_parallel_trainer import _load_checkpoint_dict
from ray.air._internal.torch_utils import load_torch_model
from ray.air._internal.torch_utils import (
load_torch_model,
consume_prefix_in_state_dict_if_present_not_in_place,
)
from ray.util.annotations import PublicAPI

if TYPE_CHECKING:
@@ -28,11 +33,19 @@ class TorchCheckpoint(Checkpoint):
# Special encoding logic to avoid serialization errors with torch.
def _encode_data_dict(self, data_dict: dict) -> dict:
"""Encode data_dict using torch.save."""
from torch.nn.parallel import DistributedDataParallel

for k, v in data_dict.items():
if isinstance(v, DistributedDataParallel) and hasattr(v, "module"):
# Only check for attribute as we want to support
# DDP, FSDP and any future approaches
if isinstance(v, Module) and hasattr(v, "module"):
data_dict[k] = v.module
elif isinstance(v, dict):
# We could limit this only to the MODEL_KEY, but we'd
# miss any extra user-specified keys. This should be a
# noop with anything but DDP/FSDP module state dicts.
data_dict[k] = consume_prefix_in_state_dict_if_present_not_in_place(
v, "module."
)

# Convert the checkpoint dict to bytes, so that any GPU tensors that
# are in the checkpoint dict can be properly deserialized on the
