[Datasets] Add Dataset.default_batch_format (#28434)
Participants in the PyTorch UX study couldn't understand how the "native" batch format works. This PR introduces a method, Dataset.default_batch_format, that tells users exactly what the default batch format is, so they don't have to guess.
bveeramani authored Sep 19, 2022
1 parent c8bfd1a commit 206e847
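
A minimal sketch of the usage this commit enables, based on the docstring added below; the dataset contents are illustrative and the printed values assume the defaults shown in the new doctests:

import ray

ds = ray.data.range_tensor(100)

# The new method reports the batch type that iter_batches() and map_batches()
# produce by default for this dataset.
print(ds.default_batch_format())                  # <class 'numpy.ndarray'>
print(type(next(ds.iter_batches(batch_size=4))))  # <class 'numpy.ndarray'>
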
Showing 3 changed files with 101 additions and 8 deletions.
19 changes: 11 additions & 8 deletions doc/source/data/api/dataset.rst
@@ -104,6 +104,7 @@ Dataset API

ray.data.Dataset.count
ray.data.Dataset.schema
ray.data.Dataset.default_batch_format
ray.data.Dataset.num_blocks
ray.data.Dataset.size_bytes
ray.data.Dataset.input_files
@@ -256,30 +257,32 @@ Inspecting Metadata

.. automethod:: ray.data.Dataset.schema

.. automethod:: ray.data.Dataset.default_batch_format

.. automethod:: ray.data.Dataset.num_blocks

.. automethod:: ray.data.Dataset.size_bytes

.. automethod:: ray.data.Dataset.input_files

.. automethod:: ray.data.Dataset.stats

.. automethod:: ray.data.Dataset.get_internal_block_refs

Execution
---------

.. automethod:: ray.data.Dataset.fully_executed

.. automethod:: ray.data.Dataset.is_fully_executed

.. automethod:: ray.data.Dataset.lazy

Serialization
-------------

.. automethod:: ray.data.Dataset.has_serializable_lineage

.. automethod:: ray.data.Dataset.serialize_lineage

.. automethod:: ray.data.Dataset.deserialize_lineage
78 changes: 78 additions & 0 deletions python/ray/data/dataset.py
@@ -14,6 +14,7 @@
Iterable,
Iterator,
List,
Type,
Optional,
Tuple,
Union,
@@ -39,6 +40,7 @@
from ray.data._internal.lazy_block_list import LazyBlockList
from ray.data._internal.output_buffer import BlockOutputBuffer
from ray.data._internal.util import _estimate_available_parallelism
from ray.data._internal.pandas_block import PandasBlockSchema
from ray.data._internal.plan import (
ExecutionPlan,
OneToOneStage,
@@ -3572,6 +3574,82 @@ def _divide(self, block_idx: int) -> ("Dataset[T]", "Dataset[T]"):
)
return l_ds, r_ds

def default_batch_format(self) -> Type:
"""Return this dataset's default batch format.
The default batch format describes what batches of data look like. To learn more
about batch formats, read
:ref:`writing user-defined functions <transform_datasets_writing_udfs>`.
Example:
If your dataset represents a list of Python objects, then the default batch
format is ``list``.
>>> ds = ray.data.range(100)
>>> ds # doctest: +SKIP
Dataset(num_blocks=20, num_rows=100, schema=<class 'int'>)
>>> ds.default_batch_format()
<class 'list'>
>>> next(ds.iter_batches(batch_size=4))
[0, 1, 2, 3]
If your dataset contains a single ``TensorDtype`` or ``ArrowTensorType``
column named ``__value__`` (as created by :func:`ray.data.from_numpy`), then
the default batch format is ``np.ndarray``. For more information on tensor
datasets, read the :ref:`tensor support guide <datasets_tensor_support>`.
>>> ds = ray.data.range_tensor(100)
>>> ds # doctest: +SKIP
Dataset(num_blocks=20, num_rows=100, schema={__value__: ArrowTensorType(shape=(1,), dtype=int64)})
>>> ds.default_batch_format()
<class 'numpy.ndarray'>
>>> next(ds.iter_batches(batch_size=4))
array([[0],
[1],
[2],
[3]])
If your dataset represents tabular data and doesn't only consist of a
``__value__`` tensor column (such as is created by
:meth:`ray.data.from_numpy`), then the default batch format is
``pd.DataFrame``.
>>> import pandas as pd
>>> df = pd.DataFrame({"foo": ["a", "b"], "bar": [0, 1]})
>>> ds = ray.data.from_pandas(df)
>>> ds # doctest: +SKIP
Dataset(num_blocks=1, num_rows=2, schema={foo: object, bar: int64})
>>> ds.default_batch_format()
<class 'pandas.core.frame.DataFrame'>
>>> next(ds.iter_batches(batch_size=4))
foo bar
0 a 0
1 b 1
.. seealso::
:meth:`~Dataset.map_batches`
Call this function to transform batches of data.
:meth:`~Dataset.iter_batches`
Call this function to iterate over batches of data.
""" # noqa: E501
import pandas as pd
import pyarrow as pa

schema = self.schema()
assert isinstance(schema, (type, PandasBlockSchema, pa.Schema))

if isinstance(schema, type):
return list

if isinstance(schema, (PandasBlockSchema, pa.Schema)):
if schema.names == [VALUE_COL_NAME]:
return np.ndarray
return pd.DataFrame

def _dataset_format(self) -> str:
"""Determine the format of the dataset. Possible values are: "arrow",
"pandas", "simple".
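
For readers skimming the diff above, a hedged sketch of how a caller might dispatch on the three possible return values (list, np.ndarray, and pd.DataFrame); the double() helper and its lambdas are illustrative and not part of this PR:

import numpy as np
import pandas as pd
import ray
from ray.data import Dataset

def double(ds: Dataset) -> Dataset:
    # Pick a UDF whose batch type matches the dataset's default batch format;
    # with no explicit batch_format, map_batches() delivers batches in that format.
    fmt = ds.default_batch_format()
    if fmt is list:
        return ds.map_batches(lambda batch: [x * 2 for x in batch])
    if fmt is np.ndarray:
        return ds.map_batches(lambda batch: batch * 2)
    assert fmt is pd.DataFrame
    return ds.map_batches(lambda batch: batch * 2)

print(double(ray.data.range(8)).take(4))  # [0, 2, 4, 6]
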
12 changes: 12 additions & 0 deletions python/ray/data/tests/test_dataset.py
@@ -4912,6 +4912,18 @@ def f(x):
), "Number of actors is out of the expected bound"


def test_default_batch_format(shutdown_only):
ds = ray.data.range(100)
assert ds.default_batch_format() == list

ds = ray.data.range_tensor(100)
assert ds.default_batch_format() == np.ndarray

df = pd.DataFrame({"foo": ["a", "b"], "bar": [0, 1]})
ds = ray.data.from_pandas(df)
assert ds.default_batch_format() == pd.DataFrame


if __name__ == "__main__":
import sys

