[release] fix pytorch pbt failure test. #31791

Merged
@@ -69,7 +69,6 @@ def train_func(config):
epochs = config.get("epochs", 3)

model = resnet18()
- model = train.torch.prepare_model(model)

# Create optimizer.
optimizer_config = {
@@ -98,6 +97,8 @@ def train_func(config):
checkpoint_epoch = checkpoint_dict["epoch"]
starting_epoch = checkpoint_epoch + 1

+ model = train.torch.prepare_model(model)

# Load in training and validation data.
transform_train = transforms.Compose(
[
34 changes: 34 additions & 0 deletions python/ray/train/torch/torch_trainer.py
@@ -93,6 +93,40 @@ def train_loop_per_worker():
To save a model to use for the ``TorchPredictor``, you must save it under the
"model" kwarg in ``Checkpoint`` passed to ``session.report()``.

.. note::
    When you wrap the ``model`` with ``prepare_model``, the keys of its
    ``state_dict`` are prefixed by ``module.``. For example,
    ``layer1.0.bn1.bias`` becomes ``module.layer1.0.bn1.bias``.
    However, when you save the ``model`` through ``session.report()``,
    all ``module.`` prefixes are stripped.
    As a result, when you load from a saved checkpoint, make sure that
    you first load the ``state_dict`` into the model
    before calling ``prepare_model``.
    Otherwise, you will run into errors like
    ``Error(s) in loading state_dict for DistributedDataParallel:
    Missing key(s) in state_dict: "module.conv1.weight", ...``. See the snippet below.

.. testcode::

    from torchvision.models import resnet18
    from ray.air import session
    from ray.air.checkpoint import Checkpoint
    import ray.train as train

    def train_func():
        ...
        model = resnet18()
        model = train.torch.prepare_model(model)
        for epoch in range(3):
            ...
            ckpt = Checkpoint.from_dict({
                "epoch": epoch,
                "model": model.state_dict(),
                # "model": model.module.state_dict(),
                # ** The above two are equivalent **
            })
            session.report({"foo": "bar"}, ckpt)
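
The key mismatch described in the note can be illustrated without Ray or PyTorch. The following is a minimal sketch that uses plain dicts in place of real ``state_dict`` objects; ``wrap_keys`` and ``strip_module_prefix`` are hypothetical helpers written for this illustration and are not part of the Ray or PyTorch APIs. It shows why a checkpoint with stripped keys matches the unwrapped model but not the wrapped one.

```python
# Illustrative only: plain dicts stand in for real PyTorch state_dicts.
PREFIX = "module."


def wrap_keys(state_dict):
    """Mimic the parameter names a wrapped model reports (hypothetical helper)."""
    return {PREFIX + key: value for key, value in state_dict.items()}


def strip_module_prefix(state_dict):
    """Drop a leading 'module.' from every key (hypothetical helper)."""
    return {
        (key[len(PREFIX):] if key.startswith(PREFIX) else key): value
        for key, value in state_dict.items()
    }


# Keys of the unwrapped model, with dummy values.
plain = {"conv1.weight": 0.1, "layer1.0.bn1.bias": 0.2}

# Keys as seen after wrapping, and as stored in the checkpoint per the note.
wrapped = wrap_keys(plain)
saved = strip_module_prefix(wrapped)

# Keys the target model expects but the checkpoint does not provide.
missing_if_wrapped_first = sorted(set(wrapped) - set(saved))
missing_if_loaded_first = sorted(set(plain) - set(saved))
```

Under these assumptions, ``missing_if_wrapped_first`` lists the ``module.``-prefixed names, which mirrors the "Missing key(s)" error quoted in the note, while ``missing_if_loaded_first`` is empty because the saved keys line up with the unwrapped model.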

Example:

.. testcode::