diff --git a/python/ray/train/examples/pytorch/tune_cifar_torch_pbt_example.py b/python/ray/train/examples/pytorch/tune_cifar_torch_pbt_example.py index 46ea6ab3947a..90846eb84824 100644 --- a/python/ray/train/examples/pytorch/tune_cifar_torch_pbt_example.py +++ b/python/ray/train/examples/pytorch/tune_cifar_torch_pbt_example.py @@ -69,7 +69,6 @@ def train_func(config): epochs = config.get("epochs", 3) model = resnet18() - model = train.torch.prepare_model(model) # Create optimizer. optimizer_config = { @@ -98,6 +97,8 @@ def train_func(config): checkpoint_epoch = checkpoint_dict["epoch"] starting_epoch = checkpoint_epoch + 1 + model = train.torch.prepare_model(model) + # Load in training and validation data. transform_train = transforms.Compose( [ diff --git a/python/ray/train/torch/torch_trainer.py b/python/ray/train/torch/torch_trainer.py index 8323d0ad5f8c..a21a6a93adb0 100644 --- a/python/ray/train/torch/torch_trainer.py +++ b/python/ray/train/torch/torch_trainer.py @@ -93,6 +93,40 @@ def train_loop_per_worker(): To save a model to use for the ``TorchPredictor``, you must save it under the "model" kwarg in ``Checkpoint`` passed to ``session.report()``. + .. note:: + When you wrap the ``model`` with ``prepare_model``, the keys of its + ``state_dict`` are prefixed by ``module.``. For example, + ``layer1.0.bn1.bias`` becomes ``module.layer1.0.bn1.bias``. + However, when saving ``model`` through ``session.report()`` + all ``module.`` prefixes are stripped. + As a result, when you load from a saved checkpoint, make sure that + you first load ``state_dict`` to the model + before calling ``prepare_model``. + Otherwise, you will run into errors like + ``Error(s) in loading state_dict for DistributedDataParallel: + Missing key(s) in state_dict: "module.conv1.weight", ...``. See snippet below. + + .. 
testcode:: + + from torchvision.models import resnet18 + from ray.air import session + from ray.air.checkpoint import Checkpoint + import ray.train as train + + def train_func(): + ... + model = resnet18() + model = train.torch.prepare_model(model) + for epoch in range(3): + ... + ckpt = Checkpoint.from_dict({ + "epoch": epoch, + "model": model.state_dict(), + # "model": model.module.state_dict(), + # ** The above two are equivalent ** + }) + session.report({"foo": "bar"}, checkpoint=ckpt) + Example: .. testcode::