diff --git a/doc/source/train/api.rst b/doc/source/train/api.rst
index 470e22919172..9dbf26385a45 100644
--- a/doc/source/train/api.rst
+++ b/doc/source/train/api.rst
@@ -173,4 +173,12 @@ train.torch.prepare_data_loader
 train.torch.get_device
 ~~~~~~~~~~~~~~~~~~~~~~
 
-.. autofunction:: ray.train.torch.get_device
\ No newline at end of file
+.. autofunction:: ray.train.torch.get_device
+
+TensorFlow Training Function Utilities
+--------------------------------------
+
+train.tensorflow.prepare_dataset_shard
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: ray.train.tensorflow.prepare_dataset_shard
\ No newline at end of file
diff --git a/doc/source/train/user_guide.rst b/doc/source/train/user_guide.rst
index cfe8bc1a24a6..a9c4cc291a9f 100644
--- a/doc/source/train/user_guide.rst
+++ b/doc/source/train/user_guide.rst
@@ -945,14 +945,14 @@ To get started, pass in a Ray Dataset (or multiple) into ``Trainer.run``. Undern
     already sharded.
 
 .. code-block:: python
+    :emphasize-lines: 1, 6
+
+    from ray.train.tensorflow import prepare_dataset_shard
 
     def train_func():
         ...
-        tf_dataset = ray.train.get_dataset_shard().to_tf()
-        options = tf.data.Options()
-        options.experimental_distribute.auto_shard_policy = \
-            tf.data.experimental.AutoShardPolicy.OFF
-        tf_dataset = tf_dataset.with_options(options)
+        tf_dataset = ray.train.get_dataset_shard().to_tf(...)
+        tf_dataset = prepare_dataset_shard(tf_dataset)
 
 **Simple Dataset Example**
 
diff --git a/python/ray/train/examples/tensorflow_linear_dataset_example.py b/python/ray/train/examples/tensorflow_linear_dataset_example.py
index 2ae1f5c24d66..8f1871cc09f0 100644
--- a/python/ray/train/examples/tensorflow_linear_dataset_example.py
+++ b/python/ray/train/examples/tensorflow_linear_dataset_example.py
@@ -8,6 +8,7 @@
 from ray.data import Dataset
 from ray.data.dataset_pipeline import DatasetPipeline
 from ray.train import Trainer
+from ray.train.tensorflow import prepare_dataset_shard
 
 
 class TrainReportCallback(Callback):
@@ -31,16 +32,6 @@ def get_dataset(a, b, size) -> Dataset:
     return dataset_pipeline
 
 
-def prepare_dataset_shard(dataset_shard: tf.data.Dataset):
-    # Disable Tensorflow autosharding since the dataset has already been
-    # sharded.
-    options = tf.data.Options()
-    options.experimental_distribute.auto_shard_policy = \
-        tf.data.experimental.AutoShardPolicy.OFF
-    dataset = dataset_shard.with_options(options)
-    return dataset
-
-
 def build_and_compile_model(config):
     model = tf.keras.Sequential([
         tf.keras.layers.InputLayer(input_shape=(1, )),
diff --git a/python/ray/train/tensorflow.py b/python/ray/train/tensorflow.py
index e2f05c739808..56c9be880a32 100644
--- a/python/ray/train/tensorflow.py
+++ b/python/ray/train/tensorflow.py
@@ -11,6 +11,8 @@
 from ray.train.worker_group import WorkerGroup
 from ray.util import PublicAPI
 
+import tensorflow as tf
+
 logger = logging.getLogger(__name__)
 
 
@@ -79,3 +81,23 @@ def handle_failure(self, worker_group: WorkerGroup,
         worker_group.execute(shutdown_session)
         worker_group.add_workers(len(failed_worker_indexes))
         self.on_start(worker_group, backend_config)
+
+
+@PublicAPI(stability="beta")
+def prepare_dataset_shard(tf_dataset_shard: tf.data.Dataset):
+    """A utility function that disables TensorFlow autosharding.
+
+    This should be used on a TensorFlow ``Dataset`` created by calling ``to_tf()``
+    on a ``ray.data.Dataset`` returned by ``ray.train.get_dataset_shard()`` since
+    the dataset has already been sharded across the workers.
+
+    Args:
+        tf_dataset_shard (tf.data.Dataset): A TensorFlow Dataset.
+
+    Returns:
+        A TensorFlow Dataset with autosharding turned off.
+    """
+    options = tf.data.Options()
+    options.experimental_distribute.auto_shard_policy = \
+        tf.data.experimental.AutoShardPolicy.OFF
+    return tf_dataset_shard.with_options(options)
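
For context, a minimal sketch of how the new ``prepare_dataset_shard`` utility might be used inside a training function. The ``to_tf()`` keyword arguments, label column, tensor shapes, and batch size below are illustrative assumptions, not something specified by this patch, and may differ across Ray versions:

.. code-block:: python

    import tensorflow as tf

    from ray import train
    from ray.train.tensorflow import prepare_dataset_shard


    def train_func(config):
        batch_size = config.get("batch_size", 32)

        # Each worker receives its own shard of the Ray Dataset passed
        # to Trainer.run().
        dataset_shard = train.get_dataset_shard()

        # Convert the shard to a tf.data.Dataset. The label column, output
        # signature, and batch size are hypothetical placeholders.
        tf_dataset = dataset_shard.to_tf(
            label_column="y",
            output_signature=(
                tf.TensorSpec(shape=(None, 1), dtype=tf.float32),
                tf.TensorSpec(shape=(None,), dtype=tf.float32)),
            batch_size=batch_size)

        # Disable TensorFlow autosharding; the data is already sharded
        # per worker, so TensorFlow should not split it again.
        tf_dataset = prepare_dataset_shard(tf_dataset)
        ...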