From f646204dc0f38365850a7555e29deb6df3ee19f4 Mon Sep 17 00:00:00 2001
From: xwjiang2010
Date: Thu, 19 Jan 2023 15:42:08 -0800
Subject: [PATCH 1/4] [release] fix pytorch pbt failure test.

The model is already wrapped in DDP. Need to access the inner model
through `model.module`.

Signed-off-by: xwjiang2010
---
 .../ray/train/examples/pytorch/tune_cifar_torch_pbt_example.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/ray/train/examples/pytorch/tune_cifar_torch_pbt_example.py b/python/ray/train/examples/pytorch/tune_cifar_torch_pbt_example.py
index 46ea6ab3947a..d8e8655b42b4 100644
--- a/python/ray/train/examples/pytorch/tune_cifar_torch_pbt_example.py
+++ b/python/ray/train/examples/pytorch/tune_cifar_torch_pbt_example.py
@@ -146,7 +146,7 @@ def train_func(config):
         checkpoint = Checkpoint.from_dict(
             {
                 "epoch": epoch,
-                "model": model.state_dict(),
+                "model": model.module.state_dict(),
                 "optimizer_state_dict": optimizer.state_dict(),
             }
         )

From f0f4559998ce57d01491f7e2289e7a085f69bdfa Mon Sep 17 00:00:00 2001
From: xwjiang2010
Date: Mon, 23 Jan 2023 10:26:07 -0800
Subject: [PATCH 2/4] fix test.

Signed-off-by: xwjiang2010
---
 .../pytorch/tune_cifar_torch_pbt_example.py   |  5 +--
 python/ray/train/torch/torch_trainer.py       | 34 +++++++++++++++++++
 2 files changed, 37 insertions(+), 2 deletions(-)

diff --git a/python/ray/train/examples/pytorch/tune_cifar_torch_pbt_example.py b/python/ray/train/examples/pytorch/tune_cifar_torch_pbt_example.py
index d8e8655b42b4..90846eb84824 100644
--- a/python/ray/train/examples/pytorch/tune_cifar_torch_pbt_example.py
+++ b/python/ray/train/examples/pytorch/tune_cifar_torch_pbt_example.py
@@ -69,7 +69,6 @@ def train_func(config):
     epochs = config.get("epochs", 3)
 
     model = resnet18()
-    model = train.torch.prepare_model(model)
 
     # Create optimizer.
     optimizer_config = {
@@ -98,6 +97,8 @@ def train_func(config):
            checkpoint_epoch = checkpoint_dict["epoch"]
            starting_epoch = checkpoint_epoch + 1
 
+    model = train.torch.prepare_model(model)
+
     # Load in training and validation data.
     transform_train = transforms.Compose(
         [
@@ -146,7 +147,7 @@ def train_func(config):
         checkpoint = Checkpoint.from_dict(
             {
                 "epoch": epoch,
-                "model": model.module.state_dict(),
+                "model": model.state_dict(),
                 "optimizer_state_dict": optimizer.state_dict(),
             }
         )
diff --git a/python/ray/train/torch/torch_trainer.py b/python/ray/train/torch/torch_trainer.py
index 8323d0ad5f8c..63cfc5334d8a 100644
--- a/python/ray/train/torch/torch_trainer.py
+++ b/python/ray/train/torch/torch_trainer.py
@@ -93,6 +93,40 @@ def train_loop_per_worker():
     To save a model to use for the ``TorchPredictor``, you must save it under
     the "model" kwarg in ``Checkpoint`` passed to ``session.report()``.
 
+    .. note::
+        If you wrap your ``model`` with ``prepare_model``, saving ``model`` to
+        the session is equivalent to saving ``model.module`` to the session.
+        When you load from a saved checkpoint, make sure that you first
+        load the ``state_dict`` into the model before calling ``prepare_model``.
+        Otherwise, you will run into errors like
+        ``Error(s) in loading state_dict for DistributedDataParallel:
+        Missing key(s) in state_dict: "module.conv1.weight", ...``. See snippet below.
+
+        .. testcode::
+
+            from torchvision.models import resnet18
+            from ray.air import session
+            from ray.air.checkpoint import Checkpoint
+            import ray.train as train
+
+            def train_func():
+                ...
+                model = resnet18()
+                # The following wraps model with DDP.
+                # An effect of that is in `state_dict`, keys are prefixed by
+                # `module.`. For example:
+                # `layer1.0.bn1.bias` becomes `module.layer1.0.bn1.bias`.
+                model = train.torch.prepare_model(model)
+                for epoch in range(3):
+                    ...
+                ckpt = Checkpoint.from_dict(
+                    "epoch": epoch,
+                    "model": model.state_dict(),
+                    # "model": model.module.state_dict(),
+                    # ** The above two are equivalent **
+                )
+                session.report({"foo": "bar"}, checkpoint=ckpt)
+
     Example:
 
     .. testcode::

From c8bcfcda6684601bb0094fab57270a9ec63f2400 Mon Sep 17 00:00:00 2001
From: xwjiang2010
Date: Mon, 23 Jan 2023 13:35:29 -0800
Subject: [PATCH 3/4] fix doc string.

Signed-off-by: xwjiang2010
---
 python/ray/train/torch/torch_trainer.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/python/ray/train/torch/torch_trainer.py b/python/ray/train/torch/torch_trainer.py
index 63cfc5334d8a..e8fc2b25ccde 100644
--- a/python/ray/train/torch/torch_trainer.py
+++ b/python/ray/train/torch/torch_trainer.py
@@ -119,12 +119,12 @@ def train_func():
                 model = train.torch.prepare_model(model)
                 for epoch in range(3):
                     ...
-                ckpt = Checkpoint.from_dict(
+                ckpt = Checkpoint.from_dict({
                     "epoch": epoch,
                     "model": model.state_dict(),
                     # "model": model.module.state_dict(),
                     # ** The above two are equivalent **
-                )
+                })
                 session.report({"foo": "bar"}, checkpoint=ckpt)
 
     Example:

From 116b1673e244a615b4745497382ba97aacf4a314 Mon Sep 17 00:00:00 2001
From: xwjiang2010
Date: Mon, 23 Jan 2023 15:12:44 -0800
Subject: [PATCH 4/4] address comments

Signed-off-by: xwjiang2010
---
 python/ray/train/torch/torch_trainer.py | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/python/ray/train/torch/torch_trainer.py b/python/ray/train/torch/torch_trainer.py
index e8fc2b25ccde..a21a6a93adb0 100644
--- a/python/ray/train/torch/torch_trainer.py
+++ b/python/ray/train/torch/torch_trainer.py
@@ -94,10 +94,14 @@ def train_loop_per_worker():
     the "model" kwarg in ``Checkpoint`` passed to ``session.report()``.
 
     .. note::
-        If you wrap your ``model`` with ``prepare_model``, saving ``model`` to
-        the session is equivalent to saving ``model.module`` to the session.
-        When you load from a saved checkpoint, make sure that you first
-        load the ``state_dict`` into the model before calling ``prepare_model``.
+        When you wrap the ``model`` with ``prepare_model``, the keys of its
+        ``state_dict`` are prefixed by ``module.``. For example,
+        ``layer1.0.bn1.bias`` becomes ``module.layer1.0.bn1.bias``.
+        However, when you save ``model`` through ``session.report()``,
+        all ``module.`` prefixes are stripped.
+        As a result, when you load from a saved checkpoint, make sure that
+        you first load the ``state_dict`` into the model
+        before calling ``prepare_model``.
         Otherwise, you will run into errors like
         ``Error(s) in loading state_dict for DistributedDataParallel:
         Missing key(s) in state_dict: "module.conv1.weight", ...``. See snippet below.
@@ -112,10 +116,6 @@ def train_func():
             def train_func():
                 ...
                 model = resnet18()
-                # The following wraps model with DDP.
-                # An effect of that is in `state_dict`, keys are prefixed by
-                # `module.`. For example:
-                # `layer1.0.bn1.bias` becomes `module.layer1.0.bn1.bias`.
                 model = train.torch.prepare_model(model)
                 for epoch in range(3):
                     ...
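
For reference, the load-then-prepare ordering that these patches converge on,
written out as one minimal sketch. It mirrors the fixed example file
(the `train_func(config)` signature and the "model" checkpoint key come from
that example) and assumes the AIR session API for restoring
(`session.get_checkpoint()` / `Checkpoint.to_dict()`); the training loop and
config handling are elided:

    from torchvision.models import resnet18

    import ray.train as train
    from ray.air import session


    def train_func(config):
        model = resnet18()

        # Load the checkpoint into the bare model first. The reported
        # state_dict carries no `module.` prefixes (session.report strips
        # them), so its keys only match before the model is wrapped in DDP.
        checkpoint = session.get_checkpoint()
        if checkpoint:
            model.load_state_dict(checkpoint.to_dict()["model"])

        # Wrap with DDP only after restoring. From here on, the keys of
        # `model.state_dict()` are prefixed with `module.`.
        model = train.torch.prepare_model(model)

        ...  # training loop; report checkpoints via session.report()

This ordering matters for PBT in particular: trials restart from other
trials' checkpoints whenever they exploit, so the restore path runs on every
perturbation, not only on failure recovery.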