diff --git a/python/ray/train/torch/torch_trainer.py b/python/ray/train/torch/torch_trainer.py
index 270d130badb9..8323d0ad5f8c 100644
--- a/python/ray/train/torch/torch_trainer.py
+++ b/python/ray/train/torch/torch_trainer.py
@@ -16,20 +16,21 @@ class TorchTrainer(DataParallelTrainer):
     """A Trainer for data parallel PyTorch training.
 
     This Trainer runs the function ``train_loop_per_worker`` on multiple Ray
-    Actors. These actors already have the necessary torch process group already
+    Actors. These actors already have the necessary torch process group
     configured for distributed PyTorch training.
 
     The ``train_loop_per_worker`` function is expected to take in either 0 or 1
     arguments:
 
-    .. code-block:: python
+    .. testcode::
 
         def train_loop_per_worker():
             ...
 
-    .. code-block:: python
+    .. testcode::
 
-        def train_loop_per_worker(config: Dict):
+        from typing import Dict, Any
+        def train_loop_per_worker(config: Dict[str, Any]):
             ...
 
     If ``train_loop_per_worker`` accepts an argument, then
@@ -43,34 +44,34 @@ def train_loop_per_worker(config: Dict):
     ``session.get_dataset_shard(...)`` will return the the entire Dataset.
 
     Inside the ``train_loop_per_worker`` function, you can use any of the
-    :ref:`Ray AIR session methods <air-session-ref>`.
+    :ref:`Ray AIR session methods <air-session-ref>`. See full example code below.
 
-    .. code-block:: python
+    .. testcode::
 
         def train_loop_per_worker():
             # Report intermediate results for callbacks or logging and
             # checkpoint data.
             session.report(...)
 
-            # Returns dict of last saved checkpoint.
+            # Get dict of last saved checkpoint.
             session.get_checkpoint()
 
-            # Returns the Ray Dataset shard for the given key.
+            # Get the Ray Dataset shard for the given key.
             session.get_dataset_shard("my_dataset")
 
-            # Returns the total number of workers executing training.
+            # Get the total number of workers executing training.
             session.get_world_size()
 
-            # Returns the rank of this worker.
+            # Get the rank of this worker.
             session.get_world_rank()
 
-            # Returns the rank of the worker on the current node.
+            # Get the rank of the worker on the current node.
             session.get_local_rank()
 
     You can also use any of the Torch specific function utils,
     such as :func:`ray.train.torch.get_device` and
     :func:`ray.train.torch.prepare_model`
 
-    .. code-block:: python
+    .. testcode::
 
         def train_loop_per_worker():
             # Prepares model for distribted training by wrapping in
@@ -83,7 +84,7 @@ def train_loop_per_worker():
             # `DistributedDataParallel` and moving to correct device.
             train.torch.prepare_model(...)
 
             # Configures the dataloader for distributed training by adding a
             # `DistributedSampler`.
             # You should NOT use this if you are doing
             # `session.get_dataset_shard(...).iter_torch_batches(...)`
             train.torch.prepare_data_loader(...)
 
-            # Returns the current torch device.
+            # Get the current torch device.
             train.torch.get_device()
 
     Any returns from the ``train_loop_per_worker`` will be discarded and not
@@ -93,7 +94,8 @@ def train_loop_per_worker():
     "model" kwarg in ``Checkpoint`` passed to ``session.report()``.
 
     Example:
-        .. code-block:: python
+
+        .. testcode::
 
             import torch
             import torch.nn as nn
 
@@ -103,12 +105,17 @@ def train_loop_per_worker():
             from ray.air import session, Checkpoint
             from ray.train.torch import TorchTrainer
             from ray.air.config import ScalingConfig
+            from ray.air.config import RunConfig
+            from ray.air.config import CheckpointConfig
 
+            # Define NN layer architecture, epochs, and number of workers
             input_size = 1
-            layer_size = 15
+            layer_size = 32
             output_size = 1
-            num_epochs = 3
+            num_epochs = 200
+            num_workers = 3
 
+            # Define your network structure
             class NeuralNetwork(nn.Module):
                 def __init__(self):
                     super(NeuralNetwork, self).__init__()
@@ -119,46 +126,82 @@ def __init__(self):
                 def forward(self, input):
                     return self.layer2(self.relu(self.layer1(input)))
 
+            # Define your train worker loop
             def train_loop_per_worker():
+
+                # Fetch the training dataset shard from the session
                 dataset_shard = session.get_dataset_shard("train")
                 model = NeuralNetwork()
-                loss_fn = nn.MSELoss()
-                optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
+                # Define the loss function and optimizer, then prepare the
+                # model: this moves it to the right device and wraps it for
+                # distributed execution
+                loss_fn = nn.MSELoss()
+                optimizer = torch.optim.Adam(model.parameters(),
+                                             lr=0.01,
+                                             weight_decay=0.01)
                 model = train.torch.prepare_model(model)
 
+                # Iterate over epochs and batches
                 for epoch in range(num_epochs):
-                    for batches in dataset_shard.iter_torch_batches(
-                        batch_size=32, dtypes=torch.float
-                    ):
+                    for batches in dataset_shard.iter_torch_batches(batch_size=32,
+                                                                    dtypes=torch.float):
+
+                        # Unsqueeze to add a feature dimension: [32] -> [32, 1]
                         inputs, labels = torch.unsqueeze(batches["x"], 1), batches["y"]
                         output = model(inputs)
-                        loss = loss_fn(output, labels)
+
+                        # Make the output shape the same as the labels
+                        loss = loss_fn(output.squeeze(), labels)
+
+                        # Zero out grads, do backward, and update the optimizer
                         optimizer.zero_grad()
                         loss.backward()
                         optimizer.step()
-                        print(f"epoch: {epoch}, loss: {loss.item()}")
 
-                    session.report(
-                        {},
-                        checkpoint=Checkpoint.from_dict(
-                            dict(epoch=epoch, model=model.state_dict()
-                        ),
+                        # Print the loss every 20 epochs
+                        if epoch % 20 == 0:
+                            print(f"epoch: {epoch}/{num_epochs}, loss: {loss:.3f}")
+
+                    # Report and record metrics, and checkpoint the model at
+                    # the end of each epoch
+                    session.report({"loss": loss.item(), "epoch": epoch},
+                                   checkpoint=Checkpoint.from_dict(
+                                       dict(epoch=epoch, model=model.state_dict()))
                     )
 
+            torch.manual_seed(42)
             train_dataset = ray.data.from_items(
                 [{"x": x, "y": 2 * x + 1} for x in range(200)]
             )
-            scaling_config = ScalingConfig(num_workers=3)
+
+            # Define scaling and run configs
             # If using GPUs, use the below scaling config instead.
             # scaling_config = ScalingConfig(num_workers=3, use_gpu=True)
+            scaling_config = ScalingConfig(num_workers=num_workers)
+            run_config = RunConfig(checkpoint_config=CheckpointConfig(num_to_keep=1))
+
             trainer = TorchTrainer(
                 train_loop_per_worker=train_loop_per_worker,
                 scaling_config=scaling_config,
+                run_config=run_config,
                 datasets={"train": train_dataset})
+
             result = trainer.fit()
+            best_checkpoint_loss = result.metrics['loss']
+
+            # Assert the final loss is at most 0.09
+            assert best_checkpoint_loss <= 0.09
 
+        .. testoutput::
+            :hide:
+            :options: +ELLIPSIS
+
+            ...
+
     Args:
+
         train_loop_per_worker: The training function to execute.
             This can either take in no arguments or a ``config`` dict.
         train_loop_config: Configurations to pass into
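Below is a minimal, hypothetical sketch (not part of the diff above) of how the dict-style checkpoint reported by this example's training loop could be consumed once ``trainer.fit()`` returns. It assumes ``result`` and ``NeuralNetwork`` from the docstring example; the ``module.``-prefix handling is an assumption about how ``prepare_model`` may wrap the model in ``DistributedDataParallel`` when running with multiple workers.

.. code-block:: python

    import torch

    # `result` is the ray.air.result.Result returned by trainer.fit() in the
    # example above; `NeuralNetwork` is the model class defined there.
    checkpoint_data = result.checkpoint.to_dict()

    # The loop saved the prepared model's state_dict, so keys may carry a
    # "module." prefix from DistributedDataParallel; strip it if present
    # (an assumption about prepare_model's wrapping behavior).
    state_dict = {
        key[len("module."):] if key.startswith("module.") else key: value
        for key, value in checkpoint_data["model"].items()
    }

    model = NeuralNetwork()
    model.load_state_dict(state_dict)
    model.eval()

    # Sanity check: the example fits y = 2x + 1, so the prediction for
    # x = 10.0 should be close to 21.
    with torch.no_grad():
        print(model(torch.tensor([[10.0]])))

With ``CheckpointConfig(num_to_keep=1)`` in the run config, only the most recent checkpoint is retained, so ``result.checkpoint`` here is the one reported in the final epoch.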