Skip to content

Commit

Permalink
[train] add a utility function to turn off TF autosharding (#21887)
Browse files Browse the repository at this point in the history
This PR adds a utility function, ``prepare_dataset_shard``, that turns off TensorFlow dataset autosharding (since Ray Train has already sharded the dataset across workers), as a temporary solution.

Closes #19324.
  • Loading branch information
jwyyy authored Jan 29, 2022
1 parent fe1bf02 commit eb8adc6
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 16 deletions.
10 changes: 9 additions & 1 deletion doc/source/train/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -173,4 +173,12 @@ train.torch.prepare_data_loader
train.torch.get_device
~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: ray.train.torch.get_device
.. autofunction:: ray.train.torch.get_device

TensorFlow Training Function Utilities
--------------------------------------

train.tensorflow.prepare_dataset_shard
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: ray.train.tensorflow.prepare_dataset_shard
10 changes: 5 additions & 5 deletions doc/source/train/user_guide.rst
Original file line number Diff line number Diff line change
Expand Up @@ -945,14 +945,14 @@ To get started, pass in a Ray Dataset (or multiple) into ``Trainer.run``. Undern
already sharded.

.. code-block:: python
:emphasize-lines: 1, 6
from ray.train.tensorflow import prepare_dataset_shard
def train_func():
...
tf_dataset = ray.train.get_dataset_shard().to_tf()
options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = \
tf.data.experimental.AutoShardPolicy.OFF
tf_dataset = tf_dataset.with_options(options)
tf_dataset = ray.train.get_dataset_shard().to_tf(...)
tf_dataset = prepare_dataset_shard(tf_dataset)
**Simple Dataset Example**
Expand Down
11 changes: 1 addition & 10 deletions python/ray/train/examples/tensorflow_linear_dataset_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from ray.data import Dataset
from ray.data.dataset_pipeline import DatasetPipeline
from ray.train import Trainer
from ray.train.tensorflow import prepare_dataset_shard


class TrainReportCallback(Callback):
Expand All @@ -31,16 +32,6 @@ def get_dataset(a, b, size) -> Dataset:
return dataset_pipeline


def prepare_dataset_shard(dataset_shard: tf.data.Dataset):
    """Return ``dataset_shard`` with TensorFlow autosharding turned off.

    The dataset has already been sharded across workers, so TensorFlow
    must not attempt to shard it a second time.
    """
    shard_options = tf.data.Options()
    shard_options.experimental_distribute.auto_shard_policy = (
        tf.data.experimental.AutoShardPolicy.OFF
    )
    return dataset_shard.with_options(shard_options)


def build_and_compile_model(config):
model = tf.keras.Sequential([
tf.keras.layers.InputLayer(input_shape=(1, )),
Expand Down
22 changes: 22 additions & 0 deletions python/ray/train/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from ray.train.worker_group import WorkerGroup
from ray.util import PublicAPI

import tensorflow as tf

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -79,3 +81,23 @@ def handle_failure(self, worker_group: WorkerGroup,
worker_group.execute(shutdown_session)
worker_group.add_workers(len(failed_worker_indexes))
self.on_start(worker_group, backend_config)


@PublicAPI(stability="beta")
def prepare_dataset_shard(tf_dataset_shard: tf.data.Dataset) -> tf.data.Dataset:
    """Disable TensorFlow autosharding on a per-worker dataset shard.

    This should be used on a TensorFlow ``Dataset`` created by calling
    ``to_tf()`` on a ``ray.data.Dataset`` returned by
    ``ray.train.get_dataset_shard()``, since the dataset has already been
    sharded across the workers.

    Args:
        tf_dataset_shard (tf.data.Dataset): A TensorFlow Dataset.

    Returns:
        A TensorFlow Dataset with autosharding turned off.
    """
    options = tf.data.Options()
    # The shard is already partitioned per worker, so tell tf.distribute
    # not to shard it again.
    options.experimental_distribute.auto_shard_policy = \
        tf.data.experimental.AutoShardPolicy.OFF
    return tf_dataset_shard.with_options(options)

0 comments on commit eb8adc6

Please sign in to comment.