[RAY AIR][DOC][TorchTrainer] Rewrote the TorchTrainer code snippet as a working example #30492
@@ -16,20 +16,21 @@ class TorchTrainer(DataParallelTrainer):
    """A Trainer for data parallel PyTorch training.

    This Trainer runs the function ``train_loop_per_worker`` on multiple Ray
-   Actors. These actors already have the necessary torch process group already
+   Actors. These actors already have the necessary torch process group
    configured for distributed PyTorch training.

    The ``train_loop_per_worker`` function is expected to take in either 0 or 1
    arguments:

-   .. code-block:: python
+   .. testcode::

        def train_loop_per_worker():
            ...

-   .. code-block:: python
+   .. testcode::

-       def train_loop_per_worker(config: Dict):
+       from typing import Dict, Any
+       def train_loop_per_worker(config: Dict[str, Any]):
            ...

    If ``train_loop_per_worker`` accepts an argument, then
@@ -43,34 +44,34 @@ def train_loop_per_worker(config: Dict):
    ``session.get_dataset_shard(...)`` will return the entire Dataset.
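(Editor's sketch, not part of the diff: the review comment below asks for an example of the paragraph above. Roughly, the dict passed as ``train_loop_config`` arrives as the ``config`` argument, and datasets registered with the trainer are fetched via ``session.get_dataset_shard``; the "train" key and the ``num_epochs`` entry here are illustrative assumptions.)

    from typing import Any, Dict

    from ray.air import session

    def train_loop_per_worker(config: Dict[str, Any]):
        # Values passed as `train_loop_config=...` to the TorchTrainer
        # show up here as `config`.
        num_epochs = config["num_epochs"]

        # Fetch the dataset registered under the "train" key (a shard of it,
        # or the whole Dataset, depending on the dataset configuration).
        dataset_shard = session.get_dataset_shard("train")
        ...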

    Inside the ``train_loop_per_worker`` function, you can use any of the
-   :ref:`Ray AIR session methods <air-session-ref>`.
+   :ref:`Ray AIR session methods <air-session-ref>`. See full example code below.

Review comment: Ideally there would also be an example for the above paragraph somewhere, we can feel free to do that in another PR.
Review comment: (You can discard this, I saw the usage is already shown in the example below -- maybe add

-   .. code-block:: python
+   .. testcode::

        def train_loop_per_worker():
            # Report intermediate results for callbacks or logging and
            # checkpoint data.
            session.report(...)

-           # Returns dict of last saved checkpoint.
+           # Get dict of last saved checkpoint.
            session.get_checkpoint()

-           # Returns the Ray Dataset shard for the given key.
+           # Session returns the Ray Dataset shard for the given key.
            session.get_dataset_shard("my_dataset")

-           # Returns the total number of workers executing training.
+           # Get the total number of workers executing training.
            session.get_world_size()

-           # Returns the rank of this worker.
+           # Get the rank of this worker.
            session.get_world_rank()

-           # Returns the rank of the worker on the current node.
+           # Get the rank of the worker on the current node.
            session.get_local_rank()

    You can also use any of the Torch specific function utils,
    such as :func:`ray.train.torch.get_device` and :func:`ray.train.torch.prepare_model`

-   .. code-block:: python
+   .. testcode::

        def train_loop_per_worker():
            # Prepares model for distributed training by wrapping in
@@ -83,7 +84,7 @@ def train_loop_per_worker():
            # `session.get_dataset_shard(...).iter_torch_batches(...)`
            train.torch.prepare_data_loader(...)

-           # Returns the current torch device.
+           # Get the current torch device.
            train.torch.get_device()

    Any returns from the ``train_loop_per_worker`` will be discarded and not
@@ -93,7 +94,8 @@ def train_loop_per_worker():
    "model" kwarg in ``Checkpoint`` passed to ``session.report()``.

    Example:
-       .. code-block:: python
+
+       .. testcode::

            import torch
            import torch.nn as nn
@@ -103,12 +105,17 @@ def train_loop_per_worker():
            from ray.air import session, Checkpoint
            from ray.train.torch import TorchTrainer
            from ray.air.config import ScalingConfig
+           from ray.air.config import RunConfig
+           from ray.air.config import CheckpointConfig

+           # Define NN layers architecture, epochs, and number of workers
            input_size = 1
-           layer_size = 15
+           layer_size = 32
            output_size = 1
-           num_epochs = 3
+           num_epochs = 200
+           num_workers = 3

+           # Define your network structure
            class NeuralNetwork(nn.Module):
                def __init__(self):
                    super(NeuralNetwork, self).__init__()
@@ -119,46 +126,82 @@ def __init__(self):
                def forward(self, input):
                    return self.layer2(self.relu(self.layer1(input)))

Review comment: I would either keep the ReLU layer here or have only one linear layer -- composing two linear layers doesn't do anything and it would likely be confusing to users :)
Reply: Keeping ReLU does not make sense. Why add non-linearity to a linear data relationship? With ReLU the model does not converge; it goes back and forth like a seesaw. Having two linear layers is not uncommon. We can put in a comment that you can also use one layer if the relationship between your data and the outcome (target) is linear.
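(Editor's sketch, not part of the diff: the single-linear-layer alternative suggested in the comment above could look roughly like this for the ``y = 2 * x + 1`` data used later in the example.)

    import torch.nn as nn

    class LinearModel(nn.Module):
        def __init__(self, input_size: int = 1, output_size: int = 1):
            super().__init__()
            # One linear layer is enough when the target is a linear
            # function of the input; no activation is needed.
            self.layer = nn.Linear(input_size, output_size)

        def forward(self, input):
            return self.layer(input)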
+           # Define your train worker loop
            def train_loop_per_worker():

+               # Fetch training set from the session
                dataset_shard = session.get_dataset_shard("train")
                model = NeuralNetwork()
-               loss_fn = nn.MSELoss()
-               optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

+               # Loss function, optimizer, prepare model for training.
+               # This moves the data and prepares model for distributed
+               # execution
+               loss_fn = nn.MSELoss()
+               optimizer = torch.optim.Adam(model.parameters(),
+                                            lr=0.01,
+                                            weight_decay=0.01)
                model = train.torch.prepare_model(model)

+               # Iterate over epochs and batches
                for epoch in range(num_epochs):
-                   for batches in dataset_shard.iter_torch_batches(
-                       batch_size=32, dtypes=torch.float
-                   ):
+                   for batches in dataset_shard.iter_torch_batches(batch_size=32,
+                                                                   dtypes=torch.float):

+                       # Add batch or unsqueeze as an additional dimension [32, x]
                        inputs, labels = torch.unsqueeze(batches["x"], 1), batches["y"]
                        output = model(inputs)
-                       loss = loss_fn(output, labels)

+                       # Make output shape the same as the labels
+                       loss = loss_fn(output.squeeze(), labels)

+                       # Zero out grads, do backward, and update optimizer
                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()
-                       print(f"epoch: {epoch}, loss: {loss.item()}")

-                   session.report(
-                       {},
-                       checkpoint=Checkpoint.from_dict(
-                           dict(epoch=epoch, model=model.state_dict())
-                       ),
-                   )

+                   # Print what's happening with loss every 20 epochs
+                   if epoch % 20 == 0:
+                       print(f"epoch: {epoch}/{num_epochs}, loss: {loss:.3f}")

+                   # Report and record metrics, checkpoint model at end of each
+                   # epoch
+                   session.report({"loss": loss.item(), "epoch": epoch},

Review comment: This is confusing since
Reply: One is reporting the loss per epoch as metrics; the other is there for the checkpoint per epoch. It is nice to have those metrics per epoch. If @amogkam feels strongly that we should not include "epoch" in the metrics to report, then I can remove it.
+                                  checkpoint=Checkpoint.from_dict(
+                                      dict(epoch=epoch, model=model.state_dict()))
+                   )
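                    # (Editor's note, not part of the diff: if the "epoch" key were
                    # dropped from the reported metrics, as the review above discusses,
                    # the call would reduce to
                    #     session.report({"loss": loss.item()},
                    #                    checkpoint=Checkpoint.from_dict(
                    #                        dict(epoch=epoch, model=model.state_dict())))
                    # with the epoch still recorded in the checkpoint itself.)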
            torch.manual_seed(42)
            train_dataset = ray.data.from_items(
                [{"x": x, "y": 2 * x + 1} for x in range(200)]
            )
-           scaling_config = ScalingConfig(num_workers=3)

+           # Define scaling and run configs
+           # If using GPUs, use the below scaling config instead.
+           # scaling_config = ScalingConfig(num_workers=3, use_gpu=True)
+           scaling_config = ScalingConfig(num_workers=num_workers)
+           run_config = RunConfig(checkpoint_config=CheckpointConfig(num_to_keep=1))

            trainer = TorchTrainer(
                train_loop_per_worker=train_loop_per_worker,
                scaling_config=scaling_config,
+               run_config=run_config,
                datasets={"train": train_dataset})

            result = trainer.fit()

+           best_checkpoint_loss = result.metrics['loss']

+           # Assert loss is less than 0.09
+           assert best_checkpoint_loss <= 0.09

+       .. testoutput::
+           :hide:
+           :options: +ELLIPSIS

+           ...
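(Editor's note, not part of the diff: because the checkpoint above is built with ``Checkpoint.from_dict``, the result of ``trainer.fit()`` can be inspected roughly as follows; this is a sketch against the ``ray.air`` Checkpoint dict API, not code from the PR.)

    # Sketch: pull the recorded epoch and model weights back out of the result.
    checkpoint_data = result.checkpoint.to_dict()
    print(checkpoint_data["epoch"])
    state_dict = checkpoint_data["model"]
    # Note: if the model was wrapped by train.torch.prepare_model, these
    # state_dict keys may carry a DistributedDataParallel "module." prefix.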
    Args:

        train_loop_per_worker: The training function to execute.
            This can either take in no arguments or a ``config`` dict.
        train_loop_config: Configurations to pass into
Review comment: In the paragraph above this, it says "already" twice in the sentence -- it would be great to also fix this :)
Reply: Good catch. Fixed.