From 206e847694cba414dc4664e4ae02b20e10e3f25d Mon Sep 17 00:00:00 2001 From: Balaji Veeramani Date: Mon, 19 Sep 2022 18:19:26 -0500 Subject: [PATCH] [Datasets] Add `Dataset.default_batch_format` (#28434) Participants in the PyTorch UX study couldn't understand how the "native" batch format works. This PR introduces a method Dataset.native_batch_format that tells users exactly what the native batch format is, so users don't have to guess. --- doc/source/data/api/dataset.rst | 19 ++++--- python/ray/data/dataset.py | 78 +++++++++++++++++++++++++++ python/ray/data/tests/test_dataset.py | 12 +++++ 3 files changed, 101 insertions(+), 8 deletions(-) diff --git a/doc/source/data/api/dataset.rst b/doc/source/data/api/dataset.rst index e5622063b528..69777c953e90 100644 --- a/doc/source/data/api/dataset.rst +++ b/doc/source/data/api/dataset.rst @@ -104,6 +104,7 @@ Dataset API ray.data.Dataset.count ray.data.Dataset.schema + ray.data.Dataset.default_batch_format ray.data.Dataset.num_blocks ray.data.Dataset.size_bytes ray.data.Dataset.input_files @@ -256,6 +257,8 @@ Inspecting Metadata .. automethod:: ray.data.Dataset.schema +.. automethod:: ray.data.Dataset.default_batch_format + .. automethod:: ray.data.Dataset.num_blocks .. automethod:: ray.data.Dataset.size_bytes @@ -263,23 +266,23 @@ Inspecting Metadata .. automethod:: ray.data.Dataset.input_files .. automethod:: ray.data.Dataset.stats - + .. automethod:: ray.data.Dataset.get_internal_block_refs Execution --------- - + .. automethod:: ray.data.Dataset.fully_executed - + .. automethod:: ray.data.Dataset.is_fully_executed - + .. automethod:: ray.data.Dataset.lazy Serialization ------------- - + .. automethod:: ray.data.Dataset.has_serializable_lineage - + .. automethod:: ray.data.Dataset.serialize_lineage - -.. automethod:: ray.data.Dataset.deserialize_lineage \ No newline at end of file + +.. automethod:: ray.data.Dataset.deserialize_lineage diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index e214bca8585d..14b3e89078d6 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -14,6 +14,7 @@ Iterable, Iterator, List, + Type, Optional, Tuple, Union, @@ -39,6 +40,7 @@ from ray.data._internal.lazy_block_list import LazyBlockList from ray.data._internal.output_buffer import BlockOutputBuffer from ray.data._internal.util import _estimate_available_parallelism +from ray.data._internal.pandas_block import PandasBlockSchema from ray.data._internal.plan import ( ExecutionPlan, OneToOneStage, @@ -3572,6 +3574,82 @@ def _divide(self, block_idx: int) -> ("Dataset[T]", "Dataset[T]"): ) return l_ds, r_ds + def default_batch_format(self) -> Type: + """Return this dataset's default batch format. + + The default batch format describes what batches of data look like. To learn more + about batch formats, read + :ref:`writing user-defined functions `. + + Example: + + If your dataset represents a list of Python objects, then the default batch + format is ``list``. + + >>> ds = ray.data.range(100) + >>> ds # doctest: +SKIP + Dataset(num_blocks=20, num_rows=100, schema=) + >>> ds.default_batch_format() + + >>> next(ds.iter_batches(batch_size=4)) + [0, 1, 2, 3] + + If your dataset contains a single ``TensorDtype`` or ``ArrowTensorType`` + column named ``__value__`` (as created by :func:`ray.data.from_numpy`), then + the default batch format is ``np.ndarray``. For more information on tensor + datasets, read the :ref:`tensor support guide `. + + >>> ds = ray.data.range_tensor(100) + >>> ds # doctest: +SKIP + Dataset(num_blocks=20, num_rows=100, schema={__value__: ArrowTensorType(shape=(1,), dtype=int64)}) + >>> ds.default_batch_format() + + >>> next(ds.iter_batches(batch_size=4)) + array([[0], + [1], + [2], + [3]]) + + If your dataset represents tabular data and doesn't only consist of a + ``__value__`` tensor column (such as is created by + :meth:`ray.data.from_numpy`), then the default batch format is + ``pd.DataFrame``. + + >>> import pandas as pd + >>> df = pd.DataFrame({"foo": ["a", "b"], "bar": [0, 1]}) + >>> ds = ray.data.from_pandas(df) + >>> ds # doctest: +SKIP + Dataset(num_blocks=1, num_rows=2, schema={foo: object, bar: int64}) + >>> ds.default_batch_format() + + >>> next(ds.iter_batches(batch_size=4)) + foo bar + 0 a 0 + 1 b 1 + + .. seealso:: + + :meth:`~Dataset.map_batches` + Call this function to transform batches of data. + + :meth:`~Dataset.iter_batches` + Call this function to iterate over batches of data. + + """ # noqa: E501 + import pandas as pd + import pyarrow as pa + + schema = self.schema() + assert isinstance(schema, (type, PandasBlockSchema, pa.Schema)) + + if isinstance(schema, type): + return list + + if isinstance(schema, (PandasBlockSchema, pa.Schema)): + if schema.names == [VALUE_COL_NAME]: + return np.ndarray + return pd.DataFrame + def _dataset_format(self) -> str: """Determine the format of the dataset. Possible values are: "arrow", "pandas", "simple". diff --git a/python/ray/data/tests/test_dataset.py b/python/ray/data/tests/test_dataset.py index 8bbbfd407b3c..e143ded95611 100644 --- a/python/ray/data/tests/test_dataset.py +++ b/python/ray/data/tests/test_dataset.py @@ -4912,6 +4912,18 @@ def f(x): ), "Number of actors is out of the expected bound" +def test_default_batch_format(shutdown_only): + ds = ray.data.range(100) + assert ds.default_batch_format() == list + + ds = ray.data.range_tensor(100) + assert ds.default_batch_format() == np.ndarray + + df = pd.DataFrame({"foo": ["a", "b"], "bar": [0, 1]}) + ds = ray.data.from_pandas(df) + assert ds.default_batch_format() == pd.DataFrame + + if __name__ == "__main__": import sys