Skip to content

Commit

Permalink
[AIR] Add TorchCheckpoint.from_state_dict (#27970)
Browse files Browse the repository at this point in the history
PyTorch recommends saving state dictionaries instead of modules, but we don't support any way to do this.

Signed-off-by: Balaji Veeramani [email protected]
  • Loading branch information
bveeramani authored Aug 30, 2022
1 parent 8a30606 commit dad98dc
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 12 deletions.
8 changes: 8 additions & 0 deletions python/ray/train/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,14 @@ py_test(
deps = [":train_lib"]
)

# Bazel target that runs the TorchCheckpoint unit tests in tests/test_torch_checkpoint.py.
py_test(
name = "test_torch_checkpoint",
size = "small",
srcs = ["tests/test_torch_checkpoint.py"],
tags = ["team:ml", "exclusive", "ray_air"],
deps = [":train_lib"]
)

py_test(
name = "test_torch_predictor",
size = "small",
Expand Down
19 changes: 19 additions & 0 deletions python/ray/train/tests/test_torch_checkpoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import torch

from ray.train.torch import TorchCheckpoint


def test_from_state_dict():
    """Round-trip a state dict through ``TorchCheckpoint`` and check it is unchanged."""
    model = torch.nn.Linear(1, 1)
    expected_state_dict = model.state_dict()

    checkpoint = TorchCheckpoint.from_state_dict(expected_state_dict)
    # ``get_model`` is given a fresh module of the same architecture to receive
    # the stored state dictionary.
    actual_state_dict = checkpoint.get_model(torch.nn.Linear(1, 1)).state_dict()

    # Compare entry by entry with ``torch.equal``. A plain ``dict == dict`` relies
    # on ``bool(tensor == tensor)``, which raises for tensors with more than one
    # element and therefore only works by accident for a 1x1 linear layer.
    assert list(actual_state_dict) == list(expected_state_dict)
    for name, expected_tensor in expected_state_dict.items():
        assert torch.equal(actual_state_dict[name], expected_tensor), name


if __name__ == "__main__":
    # Allow running this test module directly as a script: hand the file to
    # pytest (verbose, stop at first failure) and propagate its exit status.
    import sys

    import pytest

    pytest_args = ["-v", "-x", __file__]
    exit_code = pytest.main(pytest_args)
    sys.exit(exit_code)
66 changes: 54 additions & 12 deletions python/ray/train/torch/torch_checkpoint.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Any, Dict, Optional

import torch

Expand All @@ -14,29 +14,72 @@

@PublicAPI(stability="beta")
class TorchCheckpoint(Checkpoint):
"""A :py:class:`~ray.air.checkpoint.Checkpoint` with Torch-specific
functionality.
"""A :class:`~ray.air.checkpoint.Checkpoint` with Torch-specific functionality.
Create this from a generic :py:class:`~ray.air.checkpoint.Checkpoint` by calling
Create this from a generic :class:`~ray.air.checkpoint.Checkpoint` by calling
``TorchCheckpoint.from_checkpoint(ckpt)``.
"""

@classmethod
def from_state_dict(
    cls,
    state_dict: Dict[str, Any],
    *,
    preprocessor: Optional["Preprocessor"] = None,
) -> "TorchCheckpoint":
    """Create a :class:`~ray.air.checkpoint.Checkpoint` that stores a model state
    dictionary.

    .. tip::

        This is the recommended method for creating
        :class:`TorchCheckpoints<TorchCheckpoint>`.

    Args:
        state_dict: The model state dictionary to store in the checkpoint.
        preprocessor: A fitted preprocessor to be applied before inference.

    Returns:
        A :class:`TorchCheckpoint` containing the specified state dictionary.

    Examples:
        >>> from ray.train.torch import TorchCheckpoint
        >>> import torch
        >>>
        >>> model = torch.nn.Linear(1, 1)
        >>> checkpoint = TorchCheckpoint.from_state_dict(model.state_dict())

        To load the state dictionary, call
        :meth:`~ray.train.torch.TorchCheckpoint.get_model`.

        >>> checkpoint.get_model(torch.nn.Linear(1, 1))
        Linear(in_features=1, out_features=1, bias=True)
    """
    # The state dict is stored under the same key that ``from_model`` uses for a
    # full module.
    return cls.from_dict({PREPROCESSOR_KEY: preprocessor, MODEL_KEY: state_dict})

@classmethod
def from_model(
cls,
model: torch.nn.Module,
*,
preprocessor: Optional["Preprocessor"] = None,
) -> "TorchCheckpoint":
"""Create a :py:class:`~ray.air.checkpoint.Checkpoint` that stores a Torch
model.
"""Create a :class:`~ray.air.checkpoint.Checkpoint` that stores a Torch model.
.. note::
PyTorch recommends storing state dictionaries. To create a
:class:`TorchCheckpoint` from a state dictionary, call
:meth:`~ray.train.torch.TorchCheckpoint.from_state_dict`. To learn more
about state dictionaries, read
`Saving and Loading Models <https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict>`_.
Args:
model: The Torch model to store in the checkpoint.
preprocessor: A fitted preprocessor to be applied before inference.
Returns:
An :py:class:`TorchCheckpoint` containing the specified model.
A :class:`TorchCheckpoint` containing the specified model.
Examples:
>>> from ray.train.torch import TorchCheckpoint
Expand All @@ -45,15 +88,14 @@ def from_model(
>>> model = torch.nn.Identity()
>>> checkpoint = TorchCheckpoint.from_model(model)
You can use a :py:class:`TorchCheckpoint` to create an
:py:class:`~ray.train.torch.TorchPredictor` and preform inference.
You can use a :class:`TorchCheckpoint` to create an
:class:`~ray.train.torch.TorchPredictor` and perform inference.
>>> from ray.train.torch import TorchPredictor
>>>
>>> predictor = TorchPredictor.from_checkpoint(checkpoint)
"""
checkpoint = cls.from_dict({PREPROCESSOR_KEY: preprocessor, MODEL_KEY: model})
return checkpoint
""" # noqa: E501
return cls.from_dict({PREPROCESSOR_KEY: preprocessor, MODEL_KEY: model})

def get_model(self, model: Optional[torch.nn.Module] = None) -> torch.nn.Module:
"""Retrieve the model stored in this checkpoint.
Expand Down

0 comments on commit dad98dc

Please sign in to comment.