[Datasets] [Tensor Story - 1/2] Automatically provide tensor views to UDFs and infer tensor blocks for pure-tensor datasets. #24812

Merged
57 changes: 46 additions & 11 deletions doc/source/data/dataset-tensor-support.rst
@@ -15,22 +15,57 @@ Automatic conversion between the Pandas and Arrow extension types/arrays keeps t
Single-column tensor datasets
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The most basic case is when a dataset only has a single column, which is of tensor type. This kind of dataset can be created with ``.range_tensor()``, and can be read from and written to ``.npy`` files. Here are some examples:
The most basic case is when a dataset only has a single column, which is of tensor
type. This kind of dataset can be:

.. code-block:: python
* created with :func:`range_tensor() <ray.data.range_tensor>`
or :func:`from_numpy() <ray.data.from_numpy>`,
* transformed with NumPy UDFs via
:meth:`ds.map_batches() <ray.data.Dataset.map_batches>`,
* consumed with :meth:`ds.iter_rows() <ray.data.Dataset.iter_rows>` and
:meth:`ds.iter_batches() <ray.data.Dataset.iter_batches>`, and
* read from and written to ``.npy`` files.

# Create a Dataset of tensor-typed values.
ds = ray.data.range_tensor(10000, shape=(3, 5))
# -> Dataset(num_blocks=200, num_rows=10000,
# schema={value: <ArrowTensorType: shape=(3, 5), dtype=int64>})
Here is an end-to-end example:

# Save to storage.
ds.write_numpy("/tmp/tensor_out", column="value")
.. code-block:: python
Contributor: Use literalinclude?

Contributor Author: I matched the existing inline code block style to keep the PR focused and the diff small; I'm hoping to port all of the code examples in this feature guide as part of a Working with Tensors feature guide overhaul.


# Read from storage.
# Create a synthetic pure-tensor Dataset.
ds = ray.data.range_tensor(10, shape=(3, 5))
# -> Dataset(num_blocks=10, num_rows=10,
# schema={__value__: <ArrowTensorType: shape=(3, 5), dtype=int64>})

# Create a pure-tensor Dataset from an existing NumPy ndarray.
arr = np.arange(10 * 3 * 5).reshape((10, 3, 5))
ds = ray.data.from_numpy(arr)
# -> Dataset(num_blocks=1, num_rows=10,
# schema={__value__: <ArrowTensorType: shape=(3, 5), dtype=int64>})

# Transform the tensors. Datasets will automatically unpack the single-column Arrow
# table into a NumPy ndarray, provide that ndarray to your UDF, and then repack it
# into a single-column Arrow table; this will be a zero-copy conversion in both
Contributor: Do you mean this ndarray is a view on the Arrow table, rather than a new memory instance? If not, I don't think this is actually zero-copy.

Contributor Author (@clarkzinzow, May 18, 2022): It's zero-copy on the underlying array buffers; the conversion only involves switching to a different view on top of those array buffers (our single-tensor-column Arrow table vs. a NumPy ndarray).

Contributor: SG!
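To make the zero-copy point above concrete, here is a minimal sketch using plain PyArrow and NumPy (not Ray's TensorArray extension); it relies only on standard pyarrow.Array.to_numpy behavior:

import numpy as np
import pyarrow as pa

# A plain Arrow array and a NumPy ndarray view over the same buffer.
pa_arr = pa.array(np.arange(10, dtype=np.int64))
np_view = pa_arr.to_numpy(zero_copy_only=True)  # raises if a copy would be required

# Both objects reference the same underlying memory; nothing was copied.
assert np.shares_memory(np_view, np.frombuffer(pa_arr.buffers()[1], dtype=np.int64))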

# cases.
ds = ds.map_batches(lambda arr: arr / arr.max())
# -> Dataset(num_blocks=1, num_rows=10,
# schema={__value__: <ArrowTensorType: shape=(3, 5), dtype=double>})

# Consume the tensor. This will yield the underlying (3, 5) ndarrays.
for arr in ds.iter_rows():
assert isinstance(arr, np.ndarray)
assert arr.shape == (3, 5)

# Consume the tensor in batches.
for arr in ds.iter_batches(batch_size=2):
assert isinstance(arr, np.ndarray)
assert arr.shape == (2, 3, 5)

# Save to storage. This will write out the blocks of the tensor column as NPY files.
ds.write_numpy("/tmp/tensor_out")

# Read back from storage.
ray.data.read_numpy("/tmp/tensor_out")
# -> Dataset(num_blocks=200, num_rows=?,
# schema={value: <ArrowTensorType: shape=(3, 5), dtype=int64>})
# -> Dataset(num_blocks=1, num_rows=?,
# schema={__value__: <ArrowTensorType: shape=(3, 5), dtype=double>})

Reading existing serialized tensor columns
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
26 changes: 23 additions & 3 deletions python/ray/data/block.py
@@ -3,6 +3,7 @@
from typing import (
TypeVar,
List,
Dict,
Generic,
Iterator,
Tuple,
@@ -82,6 +83,10 @@ def _validate_key_fn(ds: "Dataset", key: KeyFn) -> None:
# ``SimpleBlockAccessor`` and ``ArrowBlockAccessor``.
Block = Union[List[T], "pyarrow.Table", "pandas.DataFrame", bytes]

# User-facing data batch type. This is the data type for data that is supplied to and
# returned from batch UDFs.
DataBatch = Union[Block, np.ndarray]

# A list of block references pending computation by a single task. For example,
# this may be the output of a task reading a file.
BlockPartition = List[Tuple[ObjectRef[Block], "BlockMetadata"]]
@@ -210,11 +215,13 @@ def to_pandas(self) -> "pandas.DataFrame":
"""Convert this block into a Pandas dataframe."""
raise NotImplementedError

def to_numpy(self, column: str = None) -> np.ndarray:
"""Convert this block (or column of block) into a NumPy ndarray.
def to_numpy(
self, columns: Optional[Union[str, List[str]]] = None
) -> Union[np.ndarray, Dict[str, np.ndarray]]:
"""Convert this block (or columns of block) into a NumPy ndarray.

Args:
column: Name of column to convert, or None.
columns: Name of columns to convert, or None if converting all columns.
"""
raise NotImplementedError

@@ -226,6 +233,10 @@ def to_block(self) -> Block:
"""Return the base block that this accessor wraps."""
raise NotImplementedError

def to_native(self) -> Block:
"""Return the native data format for this accessor."""
return self.to_block()

def size_bytes(self) -> int:
"""Return the approximate size in bytes of this block."""
raise NotImplementedError
@@ -255,6 +266,15 @@ def builder() -> "BlockBuilder[T]":
"""Create a builder for this block type."""
raise NotImplementedError

@staticmethod
def batch_to_block(batch: DataBatch) -> Block:
"""Create a block from user-facing data formats."""
if isinstance(batch, np.ndarray):
from ray.data.impl.arrow_block import ArrowBlockAccessor

return ArrowBlockAccessor.numpy_to_block(batch)
return batch

@staticmethod
def for_block(block: Block) -> "BlockAccessor[T]":
"""Create a block accessor for the given block."""
40 changes: 23 additions & 17 deletions python/ray/data/dataset.py
@@ -68,6 +68,7 @@
from ray.data.row import TableRow
from ray.data.aggregate import AggregateFn, Sum, Max, Min, Mean, Std
from ray.data.random_access_dataset import RandomAccessDataset
from ray.data.impl.table_block import VALUE_COL_NAME
from ray.data.impl.remote_fn import cached_remote_fn
from ray.data.impl.block_batching import batch_blocks, BatchType
from ray.data.impl.plan import ExecutionPlan, OneToOneStage, AllToAllStage
@@ -198,8 +199,8 @@ def map(

def transform(block: Block) -> Iterable[Block]:
DatasetContext._set_current(context)
block = BlockAccessor.for_block(block)
output_buffer = BlockOutputBuffer(None, context.target_max_block_size)
block = BlockAccessor.for_block(block)
for row in block.iter_rows():
output_buffer.add(fn(row))
if output_buffer.has_next():
@@ -224,6 +225,9 @@ def map_batches(
) -> "Dataset[Any]":
"""Apply the given function to batches of records of this dataset.

The format of the data batch provided to ``fn`` can be controlled via the
``batch_format`` argument, and the output of the UDF can be any batch type.

This is a blocking operation.

Examples:
@@ -270,10 +274,9 @@ def map_batches(
blocks as batches. Defaults to a system-chosen batch size.
compute: The compute strategy, either "tasks" (default) to use Ray
tasks, or ActorPoolStrategy(min, max) to use an autoscaling actor pool.
batch_format: Specify "native" to use the native block format
(promotes Arrow to pandas), "pandas" to select
``pandas.DataFrame`` as the batch format,
or "pyarrow" to select ``pyarrow.Table``.
batch_format: Specify "native" to use the native block format (promotes
tables to Pandas and tensors to NumPy), "pandas" to select
``pandas.DataFrame``, or "pyarrow" to select ``pyarrow.Table``.
ray_remote_args: Additional resource requirements to request from
ray (e.g., num_gpus=1 to request GPUs for the map tasks).
"""
@@ -302,9 +305,7 @@ def transform(block: Block) -> Iterable[Block]:
# bug where we include the entire base view on serialization.
view = block.slice(start, end, copy=batch_size is not None)
if batch_format == "native":
# Always promote Arrow blocks to pandas for consistency.
if isinstance(view, pa.Table) or isinstance(view, bytes):
view = BlockAccessor.for_block(view).to_pandas()
view = BlockAccessor.for_block(view).to_native()
elif batch_format == "pandas":
view = BlockAccessor.for_block(view).to_pandas()
elif batch_format == "pyarrow":
@@ -319,6 +320,7 @@
if not (
isinstance(applied, list)
or isinstance(applied, pa.Table)
or isinstance(applied, np.ndarray)
or isinstance(applied, pd.core.frame.DataFrame)
):
raise ValueError(
@@ -328,7 +330,7 @@
"The return type must be either list, "
"pandas.DataFrame, or pyarrow.Table"
)
output_buffer.add_block(applied)
output_buffer.add_batch(applied)
if output_buffer.has_next():
yield output_buffer.next()

@@ -667,6 +669,8 @@ def process_batch(batch):
)
if isinstance(batch, pd.DataFrame):
return batch.sample(frac=fraction)
if isinstance(batch, np.ndarray):
return np.array([row for row in batch if random.random() <= fraction])
raise ValueError(f"Unsupported batch type: {type(batch)}")

return self.map_batches(process_batch)
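As an aside, the ndarray branch above could also be written in vectorized form; a sketch, not part of this PR:

if isinstance(batch, np.ndarray):
    # Keep each row independently with probability `fraction`.
    return batch[np.random.rand(batch.shape[0]) <= fraction]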
@@ -2037,7 +2041,7 @@ def write_numpy(
self,
path: str,
*,
column: str = "value",
column: str = VALUE_COL_NAME,
filesystem: Optional["pyarrow.fs.FileSystem"] = None,
try_create_dir: bool = True,
arrow_open_stream_args: Optional[Dict[str, Any]] = None,
@@ -2065,7 +2069,8 @@ def write_numpy(
path: The path to the destination root directory, where npy
files will be written to.
column: The name of the table column that contains the tensor to
be written. This defaults to "value".
be written. The default is ``"__value__"``, the column name that
Datasets uses for storing tensors in single-column tables.
filesystem: The filesystem implementation to write to.
try_create_dir: Try to create all directories in destination path
if True. Does nothing if all directories already exist.
@@ -2212,10 +2217,10 @@ def iter_batches(
current block during the scan.
batch_size: Record batch size, or None to let the system pick.
batch_format: The format in which to return each batch.
Specify "native" to use the current block format (promoting
Arrow to pandas automatically), "pandas" to
select ``pandas.DataFrame`` or "pyarrow" to select
``pyarrow.Table``. Default is "native".
Specify "native" to use the native block format (promoting
tables to Pandas and tensors to NumPy), "pandas" to select
``pandas.DataFrame``, or "pyarrow" to select ``pyarrow.Table``. Default
is "native".
drop_last: Whether to drop the last batch if it's incomplete.

Returns:
@@ -2737,8 +2742,9 @@ def to_numpy_refs(
Time complexity: O(dataset size / parallelism)

Args:
column: The name of the column to convert to numpy, or None to
specify the entire row. Required for Arrow tables.
column: The name of the column to convert to numpy, or None to specify the
entire row. If not specified for Arrow or Pandas blocks, each returned
future will represent a dict of column ndarrays.

Returns:
A list of remote NumPy ndarrays created from this dataset.
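A brief usage sketch of the behavior described above (assuming a small Pandas-backed dataset; names are illustrative):

import pandas as pd
import ray

ds = ray.data.from_pandas([pd.DataFrame({"a": [1, 2, 3], "b": [4.0, 5.0, 6.0]})])
refs = ds.to_numpy_refs()  # no column specified
first_block = ray.get(refs[0])
# Per the docstring above, table blocks yield a dict of column ndarrays,
# e.g. {"a": array([1, 2, 3]), "b": array([4., 5., 6.])}.
assert isinstance(first_block, dict)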
21 changes: 7 additions & 14 deletions python/ray/data/datasource/datasource.py
@@ -192,14 +192,11 @@ def make_block(start: int, count: int) -> Block:
elif block_format == "tensor":
import pyarrow as pa

tensor = TensorArray(
np.ones(tensor_shape, dtype=np.int64)
* np.expand_dims(
np.arange(start, start + count),
tuple(range(1, 1 + len(tensor_shape))),
)
tensor = np.ones(tensor_shape, dtype=np.int64) * np.expand_dims(
np.arange(start, start + count),
tuple(range(1, 1 + len(tensor_shape))),
)
return pa.Table.from_pydict({"value": tensor})
return BlockAccessor.batch_to_block(tensor)
else:
return list(builtins.range(start, start + count))

@@ -213,16 +210,12 @@ def make_block(start: int, count: int) -> Block:
schema = pa.Table.from_pydict({"value": [0]}).schema
elif block_format == "tensor":
_check_pyarrow_version()
from ray.data.extensions import TensorArray
import pyarrow as pa

tensor = TensorArray(
np.ones(tensor_shape, dtype=np.int64)
* np.expand_dims(
np.arange(0, 10), tuple(range(1, 1 + len(tensor_shape)))
)
tensor = np.ones(tensor_shape, dtype=np.int64) * np.expand_dims(
np.arange(0, 10), tuple(range(1, 1 + len(tensor_shape)))
)
schema = pa.Table.from_pydict({"value": tensor}).schema
schema = BlockAccessor.batch_to_block(tensor).schema
elif block_format == "list":
schema = int
else:
7 changes: 1 addition & 6 deletions python/ray/data/datasource/numpy_datasource.py
@@ -24,18 +24,13 @@ class NumpyDatasource(FileBasedDatasource):
"""

def _read_file(self, f: "pyarrow.NativeFile", path: str, **reader_args):
from ray.data.extensions import TensorArray
import pyarrow as pa

# TODO(ekl) Ideally numpy can read directly from the file, but it
# seems like it requires the file to be seekable.
buf = BytesIO()
data = f.readall()
buf.write(data)
buf.seek(0)
return pa.Table.from_pydict(
{"value": TensorArray(np.load(buf, allow_pickle=True))}
)
return BlockAccessor.batch_to_block(np.load(buf, allow_pickle=True))

def _write_block(
self,