[RAY AIR][DOC][TorchTrainer] Rewrote the TorchTrainer code snippet as a working example #30492

Merged
merged 11 commits on Nov 28, 2022
101 changes: 72 additions & 29 deletions python/ray/train/torch/torch_trainer.py
@@ -16,20 +16,21 @@ class TorchTrainer(DataParallelTrainer):
"""A Trainer for data parallel PyTorch training.

This Trainer runs the function ``train_loop_per_worker`` on multiple Ray
Actors. These actors already have the necessary torch process group already
Actors. These actors already have the necessary torch process group
configured for distributed PyTorch training.

The ``train_loop_per_worker`` function is expected to take in either 0 or 1
Contributor: In the paragraph above this, it says "already" twice in the sentence -- it would be great to also fix this :)

Contributor Author: good catch. Fixed

arguments:

.. code-block:: python
.. testcode::

def train_loop_per_worker():
...

.. code-block:: python
.. testcode::

def train_loop_per_worker(config: Dict):
from typing import Dict, Any
def train_loop_per_worker(config: Dict[str, Any]):
...

If ``train_loop_per_worker`` accepts an argument, then
@@ -43,34 +44,34 @@ def train_loop_per_worker(config: Dict):
``session.get_dataset_shard(...)`` will return the entire Dataset.
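
A minimal sketch of this behavior (the ``lr`` and ``num_epochs`` keys and the
worker count below are illustrative, not part of the original docstring):

.. code-block:: python

    from ray.train.torch import TorchTrainer
    from ray.air.config import ScalingConfig

    def train_loop_per_worker(config):
        # ``config`` is the dict passed as ``train_loop_config`` below.
        lr = config["lr"]
        num_epochs = config["num_epochs"]
        ...

    trainer = TorchTrainer(
        train_loop_per_worker=train_loop_per_worker,
        train_loop_config={"lr": 0.01, "num_epochs": 5},
        scaling_config=ScalingConfig(num_workers=2),
    )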

Inside the ``train_loop_per_worker`` function, you can use any of the
Contributor: Ideally there would also be an example for the above paragraph somewhere, we can feel free to do that in another PR.

Contributor: (You can discard this, I saw the usage is already shown in the example below -- maybe add "(see example below)".)

:ref:`Ray AIR session methods <air-session-ref>`.
:ref:`Ray AIR session methods <air-session-ref>`. See full example code below.

.. code-block:: python
.. testcode::

def train_loop_per_worker():
# Report intermediate results for callbacks or logging and
# checkpoint data.
session.report(...)

# Returns dict of last saved checkpoint.
# Get dict of last saved checkpoint.
session.get_checkpoint()

# Returns the Ray Dataset shard for the given key.
# Session returns the Ray Dataset shard for the given key.
session.get_dataset_shard("my_dataset")

# Returns the total number of workers executing training.
# Get the total number of workers executing training.
session.get_world_size()

# Returns the rank of this worker.
# Get the rank of this worker.
session.get_world_rank()

# Returns the rank of the worker on the current node.
# Get the rank of the worker on the current node.
session.get_local_rank()

You can also use any of the Torch specific function utils,
such as :func:`ray.train.torch.get_device` and :func:`ray.train.torch.prepare_model`

.. code-block:: python
.. testcode::

def train_loop_per_worker():
# Prepares model for distributed training by wrapping in
@@ -83,7 +84,7 @@ def train_loop_per_worker():
# `session.get_dataset_shard(...).iter_torch_batches(...)`
train.torch.prepare_data_loader(...)

# Returns the current torch device.
# Get the current torch device.
train.torch.get_device()
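
As a hedged sketch, ``prepare_data_loader`` can be applied to a standard torch
``DataLoader`` (the toy tensors here are made up for illustration):

.. code-block:: python

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from ray import train

    def train_loop_per_worker():
        # Toy dataset; in practice this comes from your own data pipeline.
        dataset = TensorDataset(torch.randn(8, 1), torch.randn(8, 1))
        data_loader = DataLoader(dataset, batch_size=4)

        # Adds a DistributedSampler and moves batches to the right device.
        data_loader = train.torch.prepare_data_loader(data_loader)
        device = train.torch.get_device()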

Any returns from the ``train_loop_per_worker`` will be discarded and not
@@ -93,7 +94,8 @@ def train_loop_per_worker():
"model" kwarg in ``Checkpoint`` passed to ``session.report()``.

Example:
.. code-block:: python

.. testcode::

import torch
import torch.nn as nn
@@ -103,12 +105,17 @@ def train_loop_per_worker():
from ray.air import session, Checkpoint
from ray.train.torch import TorchTrainer
from ray.air.config import ScalingConfig
from ray.air.config import RunConfig
from ray.air.config import CheckpointConfig

# Define NN layer architecture, epochs, and number of workers
input_size = 1
layer_size = 15
layer_size = 32
output_size = 1
num_epochs = 3
num_epochs = 200
num_workers = 3

# Define your network structure
class NeuralNetwork(nn.Module):
def __init__(self):
super(NeuralNetwork, self).__init__()
@@ -119,46 +126,82 @@ def __init__(self):
def forward(self, input):
return self.layer2(self.relu(self.layer1(input)))
Contributor: I would either keep the ReLU layer here or have only one linear layer -- composing two linear layers doesn't do anything and it would likely be confusing to users :)

Contributor Author: Keeping ReLU does not make sense. Why add non-linearity to a linear data relationship? With ReLU the model does not converge; it oscillates like a seesaw. Having two linear layers is not uncommon. We can add a comment noting that you can also use a single layer if the relationship between your data and the outcome (target) is linear.
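
# Illustrative alternative (sketch, not from the original example): as noted
# in the comment above, a single linear layer is enough to fit y = 2 * x + 1.
# It reuses the same ``input_size`` and ``output_size`` defined earlier.
class LinearModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(input_size, output_size)

    def forward(self, input):
        return self.layer(input)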


# Define your train worker loop
def train_loop_per_worker():

# Fetch training set from the session
dataset_shard = session.get_dataset_shard("train")
model = NeuralNetwork()
loss_fn = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Define the loss function and optimizer, and prepare the model for
# training. prepare_model moves the model to the right device and
# wraps it for distributed execution.
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(),
lr=0.01,
weight_decay=0.01)
model = train.torch.prepare_model(model)

# Iterate over epochs and batches
for epoch in range(num_epochs):
for batches in dataset_shard.iter_torch_batches(
batch_size=32, dtypes=torch.float
):
for batches in dataset_shard.iter_torch_batches(batch_size=32,
dtypes=torch.float):

# Unsqueeze to add a feature dimension: [batch_size] -> [batch_size, 1]
inputs, labels = torch.unsqueeze(batches["x"], 1), batches["y"]
output = model(inputs)
loss = loss_fn(output, labels)

# Make the output shape match the labels
loss = loss_fn(output.squeeze(), labels)

# Zero out grads, do backward, and update optimizer
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(f"epoch: {epoch}, loss: {loss.item()}")

session.report(
{},
checkpoint=Checkpoint.from_dict(
dict(epoch=epoch, model=model.state_dict()
),
# Print the loss every 20 epochs
if epoch % 20 == 0:
print(f"epoch: {epoch}/{num_epochs}, loss: {loss:.3f}")

# Report and record metrics, checkpoint model at end of each
# epoch
session.report({"loss": loss.item(), "epoch": epoch},
Contributor: This is confusing since epoch is both here and below, @amogkam can you recommend how to do this? Most users will follow the example, so we should make sure we do this well :)

Contributor Author: One reports the loss per epoch as a metric; the other is there for the per-epoch checkpoint. It is nice to have the metrics per epoch. If @amogkam feels strongly that we should not include "epoch" in the metrics we report, I can remove it.

checkpoint=Checkpoint.from_dict(
dict(epoch=epoch, model=model.state_dict()))
)

torch.manual_seed(42)
train_dataset = ray.data.from_items(
[{"x": x, "y": 2 * x + 1} for x in range(200)]
)
scaling_config = ScalingConfig(num_workers=3)

# Define scaling and run configs
# If using GPUs, use the below scaling config instead.
# scaling_config = ScalingConfig(num_workers=3, use_gpu=True)
scaling_config = ScalingConfig(num_workers=num_workers)
run_config = RunConfig(checkpoint_config=CheckpointConfig(num_to_keep=1))

trainer = TorchTrainer(
train_loop_per_worker=train_loop_per_worker,
scaling_config=scaling_config,
run_config=run_config,
datasets={"train": train_dataset})

result = trainer.fit()

best_checkpoint_loss = result.metrics['loss']

# Assert that the final reported loss is at most 0.09
assert best_checkpoint_loss <= 0.09

.. testoutput::
:hide:
:options: +ELLIPSIS

...

Args:

train_loop_per_worker: The training function to execute.
This can either take in no arguments or a ``config`` dict.
train_loop_config: Configurations to pass into