[Train] Strip "module." from state dict #30705

Merged
@@ -791,13 +791,11 @@
" from ray.air import Checkpoint\n",
"\n",
" checkpoint = Checkpoint.from_dict(\n",
" dict(epoch=t, model=model.module.state_dict())\n",
" dict(epoch=t, model=model.state_dict())\n",
" )\n",
" session.report(dict(loss=test_loss), checkpoint=checkpoint)\n",
"```\n",
"\n",
"Note that the `model.module` part is needed because the model gets wrapped in `torch.nn.DistributedDataParallel` by `train.torch.prepare_model`.\n",
"\n",
"### Move the data loader to the training function\n",
"\n",
"You may have noticed a warning: `Warning: The actor TrainTrainable is very large (52 MiB). Check that its definition is not implicitly capturing a large array or other object in scope. Tip: use ray.put() to put large objects in the Ray object store.`.\n",
@@ -861,7 +859,7 @@
" train_epoch(train_dataloader, model, loss_fn, optimizer)\n",
" test_loss = test_epoch(test_dataloader, model, loss_fn)\n",
" checkpoint = Checkpoint.from_dict(\n",
" dict(epoch=t, model=model.module.state_dict())\n",
" dict(epoch=t, model=model.state_dict())\n",
" )\n",
" session.report(dict(loss=test_loss), checkpoint=checkpoint)\n",
"\n",
@@ -1302,7 +1300,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.9.12 ('.venv': venv)",
"display_name": "Python 3.8.10 ('venv': venv)",
"language": "python",
"name": "python3"
},
@@ -1316,11 +1314,11 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.12"
"version": "3.8.10"
},
"vscode": {
"interpreter": {
"hash": "a658351b4133f922c5967ed6133cfc05c9f16c53a5161e5843ace3f528fccaf5"
"hash": "3c0d54d489a08ae47a06eae2fd00ff032d6cddb527c382959b7b2575f6a8167f"
}
}
},
10 changes: 5 additions & 5 deletions doc/source/ray-air/examples/torch_image_example.ipynb
@@ -200,7 +200,7 @@
"\n",
"train_dataset = train_dataset.map_batches(convert_batch_to_numpy)\n",
"test_dataset = test_dataset.map_batches(convert_batch_to_numpy)"
]
]
},
{
"cell_type": "code",
@@ -328,7 +328,7 @@
" running_loss = 0.0\n",
"\n",
" metrics = dict(running_loss=running_loss)\n",
" checkpoint = TorchCheckpoint.from_state_dict(model.module.state_dict())\n",
" checkpoint = TorchCheckpoint.from_state_dict(model.state_dict())\n",
" session.report(metrics, checkpoint=checkpoint)"
]
},
@@ -810,7 +810,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.10.8 ('.venv': venv)",
"display_name": "Python 3.8.10 ('venv': venv)",
"language": "python",
"name": "python3"
},
@@ -824,11 +824,11 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.8"
"version": "3.8.10"
},
"vscode": {
"interpreter": {
"hash": "c704e19737f24b51bc631dadcac7a7e356bb35d1c5cd7766248d8a6946059909"
"hash": "3c0d54d489a08ae47a06eae2fd00ff032d6cddb527c382959b7b2575f6a8167f"
}
}
},
@@ -439,8 +439,6 @@
"from torch.optim import SGD\n",
"from torch.nn import CrossEntropyLoss\n",
"\n",
"from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present\n",
"\n",
"def train_loop_per_worker(config: dict):\n",
" num_epochs = config[\"num_epochs\"]\n",
" learning_rate = config[\"learning_rate\"]\n",
@@ -488,7 +486,6 @@
"\n",
" # Checkpoint model after every epoch.\n",
" state_dict = model.state_dict()\n",
" consume_prefix_in_state_dict_if_present(state_dict, \"module.\")\n",
" checkpoint = Checkpoint.from_dict(dict(model=state_dict))\n",
" session.report({\"loss\": running_loss}, checkpoint=checkpoint)"
]
@@ -23,7 +23,6 @@
import torch
import torch.nn as nn
import torch.optim as optim
- from torch.nn.parallel import DistributedDataParallel

import ray
from ray import train
@@ -455,8 +454,7 @@ def train_func(config):
)

# Checkpoint model.
- module = net.module if isinstance(net, DistributedDataParallel) else net
- checkpoint = Checkpoint.from_dict(dict(model=module.state_dict()))
+ checkpoint = Checkpoint.from_dict(dict(model=net.state_dict()))

# Record and log stats.
print(f"session report on {session.get_world_rank()}")
8 changes: 0 additions & 8 deletions doc/source/train/dl_guide.rst
@@ -527,7 +527,6 @@ appropriately in distributed training.

import torch
import torch.nn as nn
- from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
from torch.optim import Adam
import numpy as np

@@ -552,12 +551,7 @@
optimizer.zero_grad()
loss.backward()
optimizer.step()
- # To fetch non-DDP state_dict
- # w/o DDP: model.state_dict()
- # w/ DDP: model.module.state_dict()
- # See: https://github.com/ray-project/ray/issues/20915
state_dict = model.state_dict()
- consume_prefix_in_state_dict_if_present(state_dict, "module.")
checkpoint = Checkpoint.from_dict(
dict(epoch=epoch, model_weights=state_dict)
)
@@ -714,7 +708,6 @@ Checkpoints can be loaded into the training function in 2 steps:

import torch
import torch.nn as nn
- from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
from torch.optim import Adam
import numpy as np

@@ -751,7 +744,6 @@ Checkpoints can be loaded into the training function in 2 steps:
loss.backward()
optimizer.step()
state_dict = model.state_dict()
- consume_prefix_in_state_dict_if_present(state_dict, "module.")
checkpoint = Checkpoint.from_dict(
dict(epoch=epoch, model_weights=state_dict)
)
44 changes: 43 additions & 1 deletion python/ray/air/_internal/torch_utils.py
@@ -1,5 +1,5 @@
import warnings
- from typing import Dict, List, Optional, Union
+ from typing import Dict, List, Optional, Union, Any

import numpy as np
import pandas as pd
@@ -227,3 +227,45 @@ def contains_tensor(obj):
if contains_tensor(v):
return True
return False


# Not present in torch<=1.7.0
# Adapted from https://github.com/pytorch/pytorch/blob/\
# c18da597e0bb1c1aecc97c77a73fed1849057fa4/torch/nn/modules/utils.py
def consume_prefix_in_state_dict_if_present_not_in_place(
state_dict: Dict[str, Any], prefix: str
) -> Dict[str, Any]:
"""Strip the prefix in state_dict, if any and return a new dict.

Adapted from https://github.com/pytorch/pytorch/blob/\
c18da597e0bb1c1aecc97c77a73fed1849057fa4/torch/nn/modules/utils.py
The original method modified the dict in-place.

Args:
state_dict: a state-dict to be loaded to the model.
prefix: prefix.

"""
copied = False

for key in state_dict:
if key.startswith(prefix):
newkey = key[len(prefix) :]
if not copied:
# We are doing shallow copies here, so the performance
# impact should be negligible anyway, but this is
# a simple optimization.
state_dict = state_dict.copy()
copied = True
state_dict[newkey] = state_dict.pop(key)

if "_metadata" in state_dict:
state_dict["_metadata"] = state_dict["_metadata"].copy()
metadata = state_dict["_metadata"]
for key in metadata:
if len(key) == 0:
continue
newkey = key[len(prefix) :]
metadata[newkey] = metadata.pop(key)

return state_dict
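A minimal usage sketch of the new helper (illustrative only, not part of the diff; a plain dict stands in for a real state dict):

```python
from ray.air._internal.torch_utils import (
    consume_prefix_in_state_dict_if_present_not_in_place,
)

# Keys as they would appear after wrapping a model in DistributedDataParallel.
ddp_state = {"module.weight": 1, "module.bias": 2}

stripped = consume_prefix_in_state_dict_if_present_not_in_place(ddp_state, "module.")

assert stripped == {"weight": 1, "bias": 2}
# Unlike the upstream in-place helper, the input dict is left untouched.
assert "module.weight" in ddp_state
```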
18 changes: 14 additions & 4 deletions python/ray/train/tests/test_torch_trainer.py
@@ -61,9 +61,12 @@ def train_func(config):
trainer.fit()


- def test_torch_e2e(ray_start_4_cpus):
+ @pytest.mark.parametrize("prepare_model", (True, False))
+ def test_torch_e2e(ray_start_4_cpus, prepare_model):
def train_func():
model = torch.nn.Linear(3, 1)
+ if prepare_model:
+ model = train.torch.prepare_model(model)
session.report({}, checkpoint=TorchCheckpoint.from_model(model))

scaling_config = ScalingConfig(num_workers=2)
@@ -83,10 +86,15 @@ def train_func():
assert predictions.count() == 3


- def test_torch_e2e_state_dict(ray_start_4_cpus):
+ @pytest.mark.parametrize("prepare_model", (True, False))
+ def test_torch_e2e_state_dict(ray_start_4_cpus, prepare_model):
def train_func():
- model = torch.nn.Linear(3, 1).state_dict()
- session.report({}, checkpoint=TorchCheckpoint.from_state_dict(model))
+ model = torch.nn.Linear(3, 1)
+ if prepare_model:
+ model = train.torch.prepare_model(model)
+ session.report(
+ {}, checkpoint=TorchCheckpoint.from_state_dict(model.state_dict())
+ )

scaling_config = ScalingConfig(num_workers=2)
trainer = TorchTrainer(
@@ -111,6 +119,8 @@ def train_func():
assert predictions.count() == 3


+ # We can't really test for prepare_model here as we can't detect what the user
+ # has saved without loading (and thus triggering the exception anyway)
Contributor: For my understanding, can you elaborate on why prepare_model causes this test to fail?

Yard1 (Member Author): prepare_model will wrap the model in DDP. If the user doesn't manually unwrap it before saving, DDP will throw an exception after being loaded.

Contributor: Sorry, I guess I mean more: why is it not going through the _encode_dict path?

Yard1 (Member Author, Dec 7, 2022): If a checkpoint is created from a directory, we aren't really able to detect what's actually in the files without deserializing them in the first place (which would not only add overhead but also cause the error anyway), and we can't apply _encode_dict on already-serialized data.

Contributor: Then, for this dir checkpoint, why does it get deserialized in the first place?

Yard1 (Member Author): Well, we don't have a native way of supporting torch models from files (as mentioned by the TODO in this test). Therefore, the test implements its own predictor. Using dir checkpoints with torch is not what we want users to do right now, but the purpose of this test is to make sure that it works regardless.

Yard1 (Member Author): We can add prepare_model here, but we'd have to unwrap the model before saving anyway, meaning we wouldn't really test anything extra here.
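To make the failure mode described above concrete, here is an illustrative sketch (not part of the PR): a state dict saved from a DDP-wrapped model carries a "module." prefix on every key, so loading it back into the bare model fails.

```python
import torch

model = torch.nn.Linear(3, 1)
# Simulate the key prefix that torch.nn.parallel.DistributedDataParallel
# (and hence train.torch.prepare_model with more than one worker) adds.
ddp_style_state = {f"module.{k}": v for k, v in model.state_dict().items()}

try:
    torch.nn.Linear(3, 1).load_state_dict(ddp_style_state)  # strict=True by default
except RuntimeError as err:
    # Missing key(s): "weight", "bias"; unexpected key(s): "module.weight", ...
    print(err)
```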

def test_torch_e2e_dir(ray_start_4_cpus, tmpdir):
def train_func():
model = torch.nn.Linear(3, 1)
19 changes: 16 additions & 3 deletions python/ray/train/torch/torch_checkpoint.py
@@ -4,11 +4,16 @@
import torch
import warnings

+ from torch.nn import Module
+
import ray.cloudpickle
from ray.air.checkpoint import Checkpoint
from ray.air.constants import MODEL_KEY, PREPROCESSOR_KEY
from ray.train.data_parallel_trainer import _load_checkpoint_dict
- from ray.air._internal.torch_utils import load_torch_model
+ from ray.air._internal.torch_utils import (
+ load_torch_model,
+ consume_prefix_in_state_dict_if_present_not_in_place,
+ )
from ray.util.annotations import PublicAPI

if TYPE_CHECKING:
@@ -28,11 +33,19 @@ class TorchCheckpoint(Checkpoint):
# Special encoding logic to avoid serialization errors with torch.
def _encode_data_dict(self, data_dict: dict) -> dict:
"""Encode data_dict using torch.save."""
- from torch.nn.parallel import DistributedDataParallel
-
for k, v in data_dict.items():
- if isinstance(v, DistributedDataParallel) and hasattr(v, "module"):
+ # Only check for attribute as we want to support
+ # DDP, FSDP and any future approaches
+ if isinstance(v, Module) and hasattr(v, "module"):
data_dict[k] = v.module
+ elif isinstance(v, dict):
+ # We could limit this only to the MODEL_KEY, but we'd
+ # miss any extra user-specified keys. This should be a
+ # noop with anything but DDP/FSDP module state dicts.
+ data_dict[k] = consume_prefix_in_state_dict_if_present_not_in_place(
+ v, "module."
+ )

# Convert the checkpoint dict to bytes, so that any GPU tensors that
# are in the checkpoint dict can be properly deserialized on the
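Taken together, the user-facing effect of this change is that a training function can report the state dict of a prepared model directly. A hedged sketch, mirroring the parametrized test above (it assumes the function runs inside a TorchTrainer worker):

```python
import torch
from ray import train
from ray.air import session
from ray.train.torch import TorchCheckpoint


def train_func():
    # prepare_model may wrap the model in DistributedDataParallel,
    # which prefixes state-dict keys with "module.".
    model = train.torch.prepare_model(torch.nn.Linear(3, 1))
    # No manual model.module.state_dict() or prefix stripping needed anymore:
    # the checkpoint encoding strips the "module." prefix.
    session.report(
        {}, checkpoint=TorchCheckpoint.from_state_dict(model.state_dict())
    )
```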