diff --git a/python/ray/train/torch/torch_checkpoint.py b/python/ray/train/torch/torch_checkpoint.py index 80a89c99709e..f6160ad74c81 100644 --- a/python/ray/train/torch/torch_checkpoint.py +++ b/python/ray/train/torch/torch_checkpoint.py @@ -47,7 +47,7 @@ def from_state_dict( >>> import torch >>> >>> model = torch.nn.Linear(1, 1) - >>> checkpoint = TorchCheckpoint.from_model(model.state_dict()) + >>> checkpoint = TorchCheckpoint.from_state_dict(model.state_dict()) To load the state dictionary, call :meth:`~ray.train.torch.TorchCheckpoint.get_model`.