[AIR] Replace train. with session. #26303

Merged: 99 commits, Jul 7, 2022

Commits
b39a864
Use new Train API for examples
Yard1 Jun 13, 2022
b31399e
Fix FailureConfig not being a dataclass
Yard1 Jun 14, 2022
5cc9229
Fix errors
Yard1 Jun 14, 2022
baaf8a5
Merge branch 'master' into use_new_train_api
Yard1 Jun 14, 2022
5230218
Fix
Yard1 Jun 14, 2022
ef4a3fc
Fix link
Yard1 Jun 14, 2022
f5cfe62
Fix simple example
Yard1 Jun 14, 2022
468f7e8
train loop utils
Yard1 Jun 14, 2022
4ef6302
Remove tensorboard example
Yard1 Jun 14, 2022
5db3c14
PBT test update
Yard1 Jun 14, 2022
cb805f2
WIP
Yard1 Jun 14, 2022
2f69e37
Do not use pipeline
Yard1 Jun 15, 2022
0d8eeb4
Remove callback test
Yard1 Jun 15, 2022
4a3103e
Examples tests
Yard1 Jun 15, 2022
f7f3ea8
Move tests
Yard1 Jun 15, 2022
50ca40b
Fixture fix
Yard1 Jun 15, 2022
1872f73
Merge branch 'master' into use_new_train_api
Yard1 Jun 16, 2022
10d88d3
Merge branch 'master' into use_new_train_api
Yard1 Jun 16, 2022
20b7075
CI fixes
Yard1 Jun 16, 2022
c3b7d42
Fix
Yard1 Jun 16, 2022
33f8fd1
Merge branch 'master' into use_new_train_api
Yard1 Jun 16, 2022
37b8182
Apply suggestions from code review
Yard1 Jun 16, 2022
6f8d7e0
Fix tracked checkpoint error
Yard1 Jun 16, 2022
85cb1a7
CI fixes
Yard1 Jun 16, 2022
86a71d6
Add checkpoint configuration to `RunConfig`
Yard1 Jun 20, 2022
41eb780
Add `best_checkpoint` and `dataframe` to `Result`
Yard1 Jun 20, 2022
eb2eb67
Tests, fixes
Yard1 Jun 20, 2022
024932e
Result grid tweaks
Yard1 Jun 20, 2022
abf2cdc
Extend
Yard1 Jun 20, 2022
1f1d28b
Merge branch 'ray-project:master' into more_checkpoint_configurability
Yard1 Jun 20, 2022
563bc33
Update result_grid.py
Yard1 Jun 21, 2022
d0261be
Fix
Yard1 Jun 21, 2022
56df493
Lint
Yard1 Jun 21, 2022
ef0c75a
Lint
Yard1 Jun 21, 2022
3464c93
WIP
Yard1 Jun 21, 2022
ee87c12
Renaming
Yard1 Jun 21, 2022
fe9d68e
Merge branch 'more_checkpoint_configurability' into use_new_train_api
Yard1 Jun 21, 2022
b10fe1e
Improve test coverage
Yard1 Jun 21, 2022
4dbccca
Simplify
Yard1 Jun 21, 2022
27e531c
Docstring tweak
Yard1 Jun 21, 2022
7d1abfe
Remove docstring
Yard1 Jun 21, 2022
b0dd3ba
Fix
Yard1 Jun 21, 2022
1c2e4b1
Merge branch 'more_checkpoint_configurability' into use_new_train_api
Yard1 Jun 21, 2022
5b226ab
Tweak docstring
Yard1 Jun 21, 2022
65ce1d3
Fix
Yard1 Jun 21, 2022
555f705
Merge branch 'more_checkpoint_configurability' into use_new_train_api
Yard1 Jun 21, 2022
1e1fbea
Use CheckpointStrategy
Yard1 Jun 22, 2022
3aa277d
Merge branch 'master' into more_checkpoint_configurability
Yard1 Jun 22, 2022
e19d40f
Fix
Yard1 Jun 22, 2022
5cbb15f
Merge branch 'master' into more_checkpoint_configurability
Yard1 Jun 24, 2022
fd96174
dataframe -> metrics_dataframe
Yard1 Jun 24, 2022
8d5f1b3
CheckpointStrategy -> CheckpointConfig
Yard1 Jun 24, 2022
0482bce
Missed this
Yard1 Jun 24, 2022
207d8d1
Merge branch 'more_checkpoint_configurability' into use_new_train_api
Yard1 Jun 24, 2022
0cb579d
Update test_result_grid.py
Yard1 Jun 24, 2022
7ade7e4
Fix
Yard1 Jun 24, 2022
0937dc8
Apply feeedback from code review
Yard1 Jun 24, 2022
49ffb18
Merge branch 'more_checkpoint_configurability' into use_new_train_api
Yard1 Jun 24, 2022
b993627
Fix lint
Yard1 Jun 24, 2022
9244b8e
Merge branch 'more_checkpoint_configurability' into use_new_train_api
Yard1 Jun 24, 2022
ed870bd
Update python/ray/train/__init__.py
Yard1 Jun 24, 2022
ad90782
Merge branch 'master' into more_checkpoint_configurability
Yard1 Jun 27, 2022
c777bb5
Merge branch 'master' into use_new_train_api
Yard1 Jun 27, 2022
77305b2
Merge branch 'more_checkpoint_configurability' into use_new_train_api
Yard1 Jun 27, 2022
a4fd532
Fix CI
Yard1 Jun 27, 2022
d0ae2ba
Use warnings.warn
Yard1 Jun 28, 2022
d44f750
Make method privat
Yard1 Jun 28, 2022
c9d3380
Update python/ray/util/ml_utils/checkpoint_manager.py
Yard1 Jun 28, 2022
5c0a753
Update checkpoint_manager.py
Yard1 Jun 28, 2022
19108f4
Merge branch 'more_checkpoint_configurability' into use_new_train_api
Yard1 Jun 29, 2022
44f62e0
Merge branch 'master' into use_new_train_api
Yard1 Jun 29, 2022
c7b783b
Fix test
Yard1 Jun 29, 2022
2e9ec66
Rename files
Yard1 Jun 30, 2022
2bf89d2
Use keras callback
Yard1 Jun 30, 2022
375790e
Revert docstring changes
Yard1 Jun 30, 2022
de5103e
Merge branch 'master' into use_new_train_api
Yard1 Jun 30, 2022
baaaf47
Rename example files in docs
Yard1 Jun 30, 2022
d931a50
Merge branch 'master' into use_new_train_api
Yard1 Jun 30, 2022
691ce99
Add legacy tests
Yard1 Jun 30, 2022
b407873
Merge branch 'master' into use_new_train_api
Yard1 Jul 5, 2022
d9122c3
Switch to session in train code
Yard1 Jul 5, 2022
17366ce
Update docs
Yard1 Jul 5, 2022
cc7d066
Fix horovod test
Yard1 Jul 5, 2022
6cf961e
Fix CI
Yard1 Jul 5, 2022
330c36b
Fix CI
Yard1 Jul 5, 2022
30c9ab8
Fix CI
Yard1 Jul 5, 2022
2c7611c
Merge branch 'ray-project:master' into use_new_train_api
Yard1 Jul 6, 2022
618f7b7
Merge branch 'master' into replace_train_with_session
Yard1 Jul 6, 2022
d0affbc
Fix tests
Yard1 Jul 6, 2022
587ad56
Add todo
Yard1 Jul 6, 2022
0b05727
Merge branch 'master' into use_new_train_api
Yard1 Jul 6, 2022
139f44d
Use `trial_logdir` instead
Yard1 Jul 6, 2022
3a4d3f3
Fix
Yard1 Jul 6, 2022
a064f96
Merge branch 'ray-project:master' into use_new_train_api
Yard1 Jul 7, 2022
302d336
Merge branch 'ray-project:master' into use_new_train_api
Yard1 Jul 7, 2022
2ea93d7
Only print metrics
Yard1 Jul 7, 2022
f0d3beb
Merge branch 'master' into use_new_train_api
Yard1 Jul 7, 2022
5c2a078
Merge branch 'use_new_train_api' into replace_train_with_session
Yard1 Jul 7, 2022
25a066f
Merge branch 'master' into replace_train_with_session
Yard1 Jul 7, 2022
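
The file diffs that follow all apply one mechanical rewrite: free-standing `ray.train` utilities become calls on the Ray AIR session, and checkpoints now travel with metric reports. A minimal before-and-after sketch of that pattern, based on the replacements visible in the diffs below (the loop body itself is illustrative, not code from the PR):

```python
from ray.air import session, Checkpoint


def train_loop_per_worker():
    # Old API (removed by this PR): train.get_dataset_shard(...),
    # train.world_size(), train.report(**metrics), train.save_checkpoint(**state).
    shard = session.get_dataset_shard("train")  # was: train.get_dataset_shard("train")
    world_size = session.get_world_size()       # was: train.world_size()

    # `shard` would feed the actual training loop here; metrics and the
    # checkpoint now go through a single report call instead of two:
    checkpoint = Checkpoint.from_dict({"model": {"placeholder": float(world_size)}})
    session.report({"loss": 0.0}, checkpoint=checkpoint)
```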
Files changed
8 changes: 4 additions & 4 deletions doc/source/ray-air/doc_code/air_ingest.py
@@ -86,15 +86,15 @@
 
 # __config_4__
 import ray
-from ray import train
+from ray.air import session
 from ray.data import Dataset
 from ray.train.torch import TorchTrainer
 from ray.air.config import DatasetConfig
 
 
 def train_loop_per_worker():
     # By default, bulk loading is used and returns a Dataset object.
-    data_shard: Dataset = train.get_dataset_shard("train")
+    data_shard: Dataset = session.get_dataset_shard("train")
 
     # Manually iterate over the data 10 times (10 epochs).
     for _ in range(10):
@@ -117,15 +117,15 @@ def train_loop_per_worker():
 
 # __config_5__
 import ray
-from ray import train
+from ray.air import session
 from ray.data import DatasetPipeline
 from ray.train.torch import TorchTrainer
 from ray.air.config import DatasetConfig
 
 
 def train_loop_per_worker():
     # A DatasetPipeline object is returned when `use_stream_api` is set.
-    data_shard: DatasetPipeline = train.get_dataset_shard("train")
+    data_shard: DatasetPipeline = session.get_dataset_shard("train")
 
     # Use iter_epochs(10) to iterate over 10 epochs of data.
     for epoch in data_shard.iter_epochs(10):
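
For context, `session.get_dataset_shard("train")` only resolves when the trainer is launched with a matching `"train"` key in its `datasets=` argument. A runnable sketch under that assumption, using toy data and the 2022-era AIR API shown in this PR (the row-counting loop is illustrative):

```python
import ray
from ray.air import session
from ray.air.config import ScalingConfig
from ray.train.torch import TorchTrainer


def train_loop_per_worker():
    # The "train" key below is what session.get_dataset_shard("train") looks up.
    shard = session.get_dataset_shard("train")
    rows = sum(len(batch) for batch in shard.iter_batches(batch_format="pandas"))
    session.report({"rows_seen": rows})


train_ds = ray.data.from_items([{"x": i, "y": 2 * i} for i in range(8)])
trainer = TorchTrainer(
    train_loop_per_worker=train_loop_per_worker,
    scaling_config=ScalingConfig(num_workers=2),
    datasets={"train": train_ds},
)
result = trainer.fit()
print(result.metrics["rows_seen"])
```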
9 changes: 5 additions & 4 deletions doc/source/ray-air/doc_code/pytorch_starter.py
@@ -29,6 +29,7 @@
 from torch import nn
 from torch.utils.data import DataLoader
 import ray.train as train
+from ray.air import session
 from ray.train.torch import TorchTrainer
 
 # Define model
@@ -52,7 +53,7 @@ def forward(self, x):
 
 
 def train_epoch(dataloader, model, loss_fn, optimizer):
-    size = len(dataloader.dataset) // train.world_size()
+    size = len(dataloader.dataset) // session.get_world_size()
     model.train()
     for batch, (X, y) in enumerate(dataloader):
         # Compute prediction error
@@ -70,7 +71,7 @@ def train_epoch(dataloader, model, loss_fn, optimizer):
 
 
 def validate_epoch(dataloader, model, loss_fn):
-    size = len(dataloader.dataset) // train.world_size()
+    size = len(dataloader.dataset) // session.get_world_size()
     num_batches = len(dataloader)
     model.eval()
     test_loss, correct = 0, 0
@@ -94,7 +95,7 @@ def train_func(config):
     lr = config["lr"]
     epochs = config["epochs"]
 
-    worker_batch_size = batch_size // train.world_size()
+    worker_batch_size = batch_size // session.get_world_size()
 
     # Create data loaders.
     train_dataloader = DataLoader(training_data, batch_size=worker_batch_size)
@@ -113,7 +114,7 @@ def train_func(config):
     for _ in range(epochs):
         train_epoch(train_dataloader, model, loss_fn, optimizer)
         loss = validate_epoch(test_dataloader, model, loss_fn)
-        train.report(loss=loss)
+        session.report(dict(loss=loss))
 
 
 num_workers = 2
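
One subtlety in the diff above: `train.report()` took metrics as keyword arguments, while `session.report()` takes a single metrics dict, which is why `train.report(loss=loss)` becomes `session.report(dict(loss=loss))`. A minimal sketch of the new call shape (the value is a placeholder):

```python
from ray.air import session


def train_func():
    loss = 0.123
    # was: train.report(loss=loss), i.e. keyword metrics
    session.report({"loss": loss})  # one dict, plus an optional checkpoint=
```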
16 changes: 4 additions & 12 deletions doc/source/ray-air/doc_code/tf_starter.py
@@ -15,9 +15,9 @@
 
 # __air_tf_train_start__
 import tensorflow as tf
-from tensorflow.keras.callbacks import Callback
 
-import ray.train as train
+from ray.air import session
+from ray.air.callbacks.keras import Callback
 from ray.train.tensorflow import prepare_dataset_shard
 from ray.train.tensorflow import TensorflowTrainer
 
@@ -33,12 +33,6 @@ def build_model() -> tf.keras.Model:
     return model
 
 
-class TrainCheckpointReportCallback(Callback):
-    def on_epoch_end(self, epoch, logs=None):
-        train.save_checkpoint(**{"model": self.model.get_weights()})
-        train.report(**logs)
-
-
 def train_func(config: dict):
     batch_size = config.get("batch_size", 64)
     epochs = config.get("epochs", 3)
@@ -53,7 +47,7 @@ def train_func(config: dict):
         metrics=[tf.keras.metrics.mean_squared_error],
     )
 
-    dataset = train.get_dataset_shard("train")
+    dataset = session.get_dataset_shard("train")
 
     results = []
     for _ in range(epochs):
@@ -67,9 +61,7 @@ def train_func(config: dict):
                 batch_size=batch_size,
             )
         )
-        history = multi_worker_model.fit(
-            tf_dataset, callbacks=[TrainCheckpointReportCallback()]
-        )
+        history = multi_worker_model.fit(tf_dataset, callbacks=[Callback()])
         results.append(history.history)
     return results
 
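
The hand-rolled `TrainCheckpointReportCallback` above is replaced by the bundled `ray.air.callbacks.keras.Callback`. Roughly speaking, that callback automates what the deleted class did, but through the session API; a hand-rolled equivalent might look like the following (an illustrative sketch, not the library's actual implementation):

```python
from tensorflow import keras

from ray.air import session, Checkpoint


class SessionReportCallback(keras.callbacks.Callback):
    # Illustrative stand-in for ray.air.callbacks.keras.Callback.
    def on_epoch_end(self, epoch, logs=None):
        # One call replaces the old train.save_checkpoint / train.report pair.
        session.report(
            dict(logs or {}),
            checkpoint=Checkpoint.from_dict({"model": self.model.get_weights()}),
        )
```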
@@ -674,10 +674,11 @@
 "\n",
 "To facilitate this, we only need a few changes to the code:\n",
 "\n",
-"1. We import Ray Train:\n",
+"1. We import Ray Train and Ray AIR Session:\n",
 "\n",
 "```python\n",
 "import ray.train as train\n",
+"from ray.air import session\n",
 "```\n",
 "\n",
 "\n",
@@ -693,7 +694,7 @@
 "3. We dynamically adjust the worker batch size according to the number of workers:\n",
 "\n",
 "```python\n",
-"    batch_size_per_worker = batch_size // train.world_size()\n",
+"    batch_size_per_worker = batch_size // session.get_world_size()\n",
 "```\n",
 "\n",
 "4. We prepare the data loader for distributed data sharding:\n",
@@ -716,13 +717,13 @@
 "\n",
 "```python\n",
 "    test_loss = test(test_dataloader, model, loss_fn)\n",
-"    train.report(loss=test_loss)\n",
+"    session.report(dict(loss=test_loss))\n",
 "```\n",
 "\n",
 "7. In the `train_epoch()` and `test_epoch()` functions we divide the `size` by the world size:\n",
 "\n",
 "```python\n",
-"    size = len(dataloader.dataset) // train.world_size() # Divide by word size\n",
+"    size = len(dataloader.dataset) // session.get_world_size() # Divide by word size\n",
 "```\n",
 "\n",
 "8. In the `train_epoch()` function we can get rid of the device mapping. Ray Train does this for us:\n",
@@ -745,7 +746,7 @@
 "outputs": [],
 "source": [
 "def train_epoch(dataloader, model, loss_fn, optimizer):\n",
-"    size = len(dataloader.dataset) // train.world_size() # Divide by word size\n",
+"    size = len(dataloader.dataset) // session.get_world_size() # Divide by word size\n",
 "    model.train()\n",
 "    for batch, (X, y) in enumerate(dataloader):\n",
 "        # We don't need this anymore! Ray Train does this automatically:\n",
@@ -781,7 +782,7 @@
 "outputs": [],
 "source": [
 "def test_epoch(dataloader, model, loss_fn):\n",
-"    size = len(dataloader.dataset) // train.world_size() # Divide by word size\n",
+"    size = len(dataloader.dataset) // session.get_world_size() # Divide by word size\n",
 "    num_batches = len(dataloader)\n",
 "    model.eval()\n",
 "    test_loss, correct = 0, 0\n",
@@ -821,14 +822,14 @@
 ],
 "source": [
 "import ray.train as train\n",
-"\n",
+"from ray.air import session\n",
 "\n",
 "def train_func(config: dict):\n",
 "    batch_size = config[\"batch_size\"]\n",
 "    lr = config[\"lr\"]\n",
 "    epochs = config[\"epochs\"]\n",
 "    \n",
-"    batch_size_per_worker = batch_size // train.world_size()\n",
+"    batch_size_per_worker = batch_size // session.get_world_size()\n",
 "    \n",
 "    # Create data loaders.\n",
 "    train_dataloader = DataLoader(training_data, batch_size=batch_size_per_worker)\n",
@@ -846,7 +847,7 @@
 "    for t in range(epochs):\n",
 "        train_epoch(train_dataloader, model, loss_fn, optimizer)\n",
 "        test_loss = test_epoch(test_dataloader, model, loss_fn)\n",
-"        train.report(loss=test_loss)\n",
+"        session.report(dict(loss=test_loss))\n",
 "\n",
 "    print(\"Done!\")"
 ]
@@ -1062,10 +1063,15 @@
 "metadata": {},
 "source": [
 "### Enabling checkpointing to retrieve the model\n",
-"Enabling checkpointing is pretty easy - we just need to call the `train.save_checkpoint()` API and pass the model state to it:\n",
+"Enabling checkpointing is pretty easy - we just need to pass a `Checkpoint` object with the model state to the `session.report()` API.\n",
 "\n",
 "```python\n",
-"    train.save_checkpoint(epoch=t, model=model.module.state_dict())\n",
+"    from ray.air import Checkpoint\n",
+"\n",
+"    checkpoint = Checkpoint.from_dict(\n",
+"        dict(epoch=t, model=model.module.state_dict())\n",
+"    )\n",
+"    session.report(dict(loss=test_loss), checkpoint=checkpoint)\n",
 "```\n",
 "\n",
 "Note that the `model.module` part is needed because the model gets wrapped in `torch.nn.DistributedDataParallel` by `train.torch.prepare_model`.\n",
@@ -1086,6 +1092,8 @@
 "metadata": {},
 "outputs": [],
 "source": [
+"from ray.air import Checkpoint\n",
+"\n",
 "def load_data():\n",
 "    # Download training data from open datasets.\n",
 "    training_data = datasets.FashionMNIST(\n",
@@ -1110,7 +1118,7 @@
 "    lr = config[\"lr\"]\n",
 "    epochs = config[\"epochs\"]\n",
 "    \n",
-"    batch_size_per_worker = batch_size // train.world_size()\n",
+"    batch_size_per_worker = batch_size // session.get_world_size()\n",
 "    \n",
 "    training_data, test_data = load_data() # <- this is new!\n",
 "    \n",
@@ -1130,8 +1138,10 @@
 "    for t in range(epochs):\n",
 "        train_epoch(train_dataloader, model, loss_fn, optimizer)\n",
 "        test_loss = test_epoch(test_dataloader, model, loss_fn)\n",
-"        train.save_checkpoint(epoch=t, model=model.module.state_dict()) # <- this is new!\n",
-"        train.report(loss=test_loss)\n",
+"        checkpoint = Checkpoint.from_dict(\n",
+"            dict(epoch=t, model=model.module.state_dict())\n",
+"        )\n",
+"        session.report(dict(loss=test_loss), checkpoint=checkpoint)\n",
 "\n",
 "    print(\"Done!\")"
 ]
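
Since checkpoints are now plain `Checkpoint` objects attached to reports, the notebook's saved state can be read back on the driver from the result of `trainer.fit()`. A small helper sketch, assuming the dict layout used above with `epoch` and `model` keys:

```python
from ray.air import Checkpoint


def load_reported_state(checkpoint: Checkpoint):
    # Mirrors Checkpoint.from_dict(dict(epoch=t, model=model.module.state_dict()))
    # from the training loop above; e.g. pass trainer.fit().checkpoint here.
    data = checkpoint.to_dict()
    return data["epoch"], data["model"]
```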
12 changes: 8 additions & 4 deletions doc/source/ray-air/examples/tfx_tabular_train_to_serve.ipynb
@@ -619,12 +619,11 @@
 },
 "outputs": [],
 "source": [
-"from ray import train\n",
+"from ray.air import session, Checkpoint\n",
 "from ray.train.tensorflow import prepare_dataset_shard\n",
-"from ray.tune.integration.keras import TuneReportCallback\n",
 "\n",
 "def train_loop_per_worker():\n",
-"    dataset_shard = train.get_dataset_shard(\"train\")\n",
+"    dataset_shard = session.get_dataset_shard(\"train\")\n",
 "\n",
 "    strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()\n",
 "    with strategy.scope():\n",
@@ -653,7 +652,12 @@
 "\n",
 "    model.fit(tf_dataset, verbose=0)\n",
 "    # This saves checkpoint in a way that can be used by Ray Serve coherently.\n",
-"    train.save_checkpoint(epoch=epoch, model=model.get_weights())"
+"    session.report(\n",
+"        {},\n",
+"        checkpoint=Checkpoint.from_dict(\n",
+"            dict(epoch=epoch, model=model.get_weights())\n",
+"        ),\n",
+"    )"
 ]
 },
 {
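
Note the `session.report({}, checkpoint=...)` shape in this diff: checkpoints can only be saved through a report, so a checkpoint-only save still passes a metrics dict, here an empty one. A minimal sketch of that pattern (the weights value is a placeholder):

```python
from ray.air import session, Checkpoint


def train_loop_per_worker():
    weights = []  # stand-in for model.get_weights()
    # Checkpoint-only report: no scalar metrics this epoch, so pass {}.
    session.report({}, checkpoint=Checkpoint.from_dict(dict(epoch=0, model=weights)))
```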
12 changes: 8 additions & 4 deletions doc/source/ray-air/examples/torch_image_example.ipynb
@@ -253,8 +253,8 @@
 "\n",
 "`train_loop_per_worker` contains regular PyTorch code with a few notable exceptions:\n",
 "* We wrap our model with {py:func}`train.torch.prepare_model <ray.train.torch.prepare_model>`.\n",
-"* We call {py:func}`train.get_dataset_shard <ray.train.get_dataset_shard>` and {py:meth}`Dataset.to_torch <ray.data.Dataset.to_torch>` to convert a subset of our training data to a Torch dataset.\n",
-"* We save model state using {py:func}`train.save_checkpoint <ray.train.save_checkpoint>`."
+"* We call {py:func}`session.get_dataset_shard <ray.air.session.get_dataset_shard>` and {py:meth}`Dataset.to_torch <ray.data.Dataset.to_torch>` to convert a subset of our training data to a Torch dataset.\n",
+"* We save model state using {py:func}`session.report <ray.air.session.report>`."
 ]
 },
 {
@@ -265,6 +265,7 @@
 "outputs": [],
 "source": [
 "from ray import train\n",
+"from ray.air import session, Checkpoint\n",
 "import torch.optim as optim\n",
 "\n",
 "\n",
@@ -274,7 +275,7 @@
 "    criterion = nn.CrossEntropyLoss()\n",
 "    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)\n",
 "\n",
-"    train_dataset_shard: torch.utils.data.Dataset = train.get_dataset_shard(\"train\").to_torch(\n",
+"    train_dataset_shard: torch.utils.data.Dataset = session.get_dataset_shard(\"train\").to_torch(\n",
 "        feature_columns=[\"image\"],\n",
 "        label_column=\"label\",\n",
 "        batch_size=config[\"batch_size\"],\n",
@@ -303,7 +304,10 @@
 "                print(f\"[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}\")\n",
 "                running_loss = 0.0\n",
 "\n",
-"    train.save_checkpoint(model=model.module.state_dict())"
+"    session.report(\n",
+"        dict(running_loss=running_loss),\n",
+"        checkpoint=Checkpoint.from_dict(dict(model=model.module.state_dict())),\n",
+"    )"
 ]
 },
 {