[Datasets] Add .iter_torch_batches() and .iter_tf_batches() APIs. (#26689)

This PR adds the .iter_torch_batches() and .iter_tf_batches() convenience APIs, which take care of ML framework tensor conversion, build on the narrow tensor waist of the .iter_batches() call (the "numpy" batch format), and unify batch formats around two options: a single tensor for simple/pure-tensor/single-column datasets, and a dictionary of tensors for multi-column datasets.
clarkzinzow authored Jul 22, 2022
1 parent 1fd2913 commit a29baf9
Showing 14 changed files with 436 additions and 110 deletions.
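For orientation, a minimal usage sketch of the two new iterators (illustrative only, not part of the diff; assumes a local Ray runtime with torch installed):

import ray

# Single-column dataset: .iter_torch_batches() yields one torch.Tensor per batch.
ds = ray.data.range(12)
for batch in ds.iter_torch_batches(batch_size=4):
    print(batch.shape)  # torch.Size([4, 1]) per the docstring example below

# Multi-column dataset: batches arrive as Dict[str, torch.Tensor].
ds = ray.data.from_items([{"x": i, "y": 2 * i} for i in range(8)])
for batch in ds.iter_torch_batches(batch_size=4):
    print(sorted(batch))  # ['x', 'y']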
4 changes: 3 additions & 1 deletion doc/source/ray-air/doc_code/predictors.py
@@ -16,7 +16,9 @@
def build_model() -> tf.keras.Model:
    model = tf.keras.Sequential(
        [
            tf.keras.layers.InputLayer(input_shape=(1,)),
            tf.keras.layers.InputLayer(input_shape=()),
            # Add feature dimension, expanding (batch_size,) to (batch_size, 1).
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(1),
        ]
    )
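The InputLayer + Flatten change above is worth a quick shape check; a standalone sketch (illustrative, assumes TF 2.x, where Flatten() adds a feature axis to rank-1 batches):

import numpy as np
import tensorflow as tf

# Scalar features now enter with shape (batch_size,); Flatten() adds the
# missing feature axis, producing (batch_size, 1) for the Dense layer.
model = tf.keras.Sequential(
    [
        tf.keras.layers.InputLayer(input_shape=()),
        tf.keras.layers.Flatten(),
    ]
)
print(model(tf.constant(np.arange(4.0))).shape)  # (4, 1)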
4 changes: 3 additions & 1 deletion doc/source/ray-air/doc_code/tf_starter.py
@@ -26,7 +26,9 @@
def build_model() -> tf.keras.Model:
    model = tf.keras.Sequential(
        [
            tf.keras.layers.InputLayer(input_shape=(1,)),
            tf.keras.layers.InputLayer(input_shape=()),
            # Add feature dimension, expanding (batch_size,) to (batch_size, 1).
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(10),
            tf.keras.layers.Dense(1),
        ]
39 changes: 38 additions & 1 deletion python/ray/air/_internal/tensorflow_utils.py
@@ -1,4 +1,4 @@
from typing import Optional
from typing import Optional, Union, Dict

import numpy as np
import pandas as pd
@@ -81,3 +81,40 @@ def tensorize(series):
return tf.expand_dims(concatenated_tensor, axis=1)

return concatenated_tensor


def convert_ndarray_batch_to_tf_tensor_batch(
    ndarrays: Union[np.ndarray, Dict[str, np.ndarray]],
    dtypes: Optional[Union[tf.dtypes.DType, Dict[str, tf.dtypes.DType]]] = None,
) -> Union[tf.Tensor, Dict[str, tf.Tensor]]:
    """Convert a NumPy ndarray batch to a TensorFlow Tensor batch.

    Args:
        ndarrays: A (dict of) NumPy ndarray(s) that we wish to convert to a
            TensorFlow Tensor.
        dtypes: A (dict of) TensorFlow dtype(s) for the created tensor; if None,
            the dtype will be inferred from the NumPy ndarray data.

    Returns:
        A (dict of) TensorFlow Tensor(s).
    """
    if isinstance(ndarrays, np.ndarray):
        # Single-tensor case.
        if isinstance(dtypes, dict):
            if len(dtypes) != 1:
                raise ValueError(
                    "When constructing a single-tensor batch, only a single dtype "
                    f"should be given, instead got: {dtypes}"
                )
            dtypes = next(iter(dtypes.values()))
        batch = tf.convert_to_tensor(ndarrays, dtype=dtypes)
    else:
        # Multi-tensor case.
        batch = {
            col_name: tf.convert_to_tensor(
                col_ndarray,
                dtype=dtypes[col_name] if isinstance(dtypes, dict) else dtypes,
            )
            for col_name, col_ndarray in ndarrays.items()
        }

    return batch
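A quick sketch of the converter's two return shapes (illustrative inputs, not part of the diff):

import numpy as np
import tensorflow as tf

# Single ndarray in -> single tf.Tensor out.
single = convert_ndarray_batch_to_tf_tensor_batch(np.zeros((4, 2)))
assert isinstance(single, tf.Tensor)

# Dict of ndarrays in -> dict of tf.Tensors out; dtypes may be given per column.
multi = convert_ndarray_batch_to_tf_tensor_batch(
    {"x": np.zeros(4, dtype=np.float32), "y": np.ones(4, dtype=np.int64)},
    dtypes={"x": tf.float32, "y": tf.int64},
)
assert isinstance(multi["x"], tf.Tensor)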
41 changes: 41 additions & 0 deletions python/ray/air/_internal/torch_utils.py
@@ -1,5 +1,6 @@
from typing import Dict, List, Optional, Union

import numpy as np
import pandas as pd
import torch

@@ -101,6 +102,46 @@ def get_tensor_for_columns(columns, dtype):
return get_tensor_for_columns(columns=columns, dtype=column_dtypes)


def convert_ndarray_batch_to_torch_tensor_batch(
    ndarrays: Union[np.ndarray, Dict[str, np.ndarray]],
    dtypes: Optional[Union[torch.dtype, Dict[str, torch.dtype]]] = None,
    device: Optional[str] = None,
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
    """Convert a NumPy ndarray batch to a Torch Tensor batch.

    Args:
        ndarrays: A (dict of) NumPy ndarray(s) that we wish to convert to a Torch
            Tensor.
        dtypes: A (dict of) Torch dtype(s) for the created tensor; if None, the
            dtype will be inferred from the NumPy ndarray data.
        device: The device on which the tensor(s) should be placed; if None, the
            Torch tensor(s) will be constructed on the CPU.

    Returns:
        A (dict of) Torch Tensor(s).
    """
    if isinstance(ndarrays, np.ndarray):
        # Single-tensor case.
        if isinstance(dtypes, dict):
            if len(dtypes) != 1:
                raise ValueError(
                    "When constructing a single-tensor batch, only a single dtype "
                    f"should be given, instead got: {dtypes}"
                )
            dtypes = next(iter(dtypes.values()))
        batch = torch.as_tensor(ndarrays, dtype=dtypes, device=device)
    else:
        # Multi-tensor case.
        batch = {
            col_name: torch.as_tensor(
                col_ndarray,
                dtype=dtypes[col_name] if isinstance(dtypes, dict) else dtypes,
                device=device,
            )
            for col_name, col_ndarray in ndarrays.items()
        }

    return batch


def load_torch_model(
saved_model: Union[torch.nn.Module, Dict],
model_definition: Optional[torch.nn.Module] = None,
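The Torch converter mirrors the TensorFlow one, with an extra device argument; a matching sketch (illustrative inputs, not part of the diff):

import numpy as np
import torch

# Single ndarray in -> single torch.Tensor out; dtype and device are optional.
single = convert_ndarray_batch_to_torch_tensor_batch(
    np.zeros((4, 2), dtype=np.float32), dtypes=torch.float32, device="cpu"
)
assert isinstance(single, torch.Tensor)

# Dict of ndarrays in -> dict of torch.Tensors out.
multi = convert_ndarray_batch_to_torch_tensor_batch({"x": np.zeros(4), "y": np.ones(4)})
assert isinstance(multi["x"], torch.Tensor)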
@@ -26,7 +26,9 @@ def get_dataset(a=5, b=10, size=1000) -> Dataset:
def build_model() -> tf.keras.Model:
    model = tf.keras.Sequential(
        [
            tf.keras.layers.InputLayer(input_shape=(1,)),
            tf.keras.layers.InputLayer(input_shape=()),
            # Add feature dimension, expanding (batch_size,) to (batch_size, 1).
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(10),
            tf.keras.layers.Dense(1),
        ]
@@ -72,10 +74,11 @@ def train_func(config: dict):
def train_tensorflow_linear(num_workers: int = 2, use_gpu: bool = False) -> Result:
    dataset_pipeline = get_dataset()
    config = {"lr": 1e-3, "batch_size": 32, "epochs": 4}
    scaling_config = ScalingConfig(num_workers=num_workers, use_gpu=use_gpu)
    trainer = TensorflowTrainer(
        train_loop_per_worker=train_func,
        train_loop_config=config,
        scaling_config=ScalingConfig(num_workers=num_workers, use_gpu=use_gpu),
        scaling_config=scaling_config,
        datasets={"train": dataset_pipeline},
    )
    results = trainer.fit()
185 changes: 154 additions & 31 deletions python/ray/data/dataset.py
@@ -119,6 +119,9 @@
"tf.TypeSpec", List["tf.TypeSpec"], Dict[str, "tf.TypeSpec"]
]

TorchTensorBatchType = Union["torch.Tensor", Dict[str, "torch.Tensor"]]
TensorFlowTensorBatchType = Union["tf.Tensor", Dict[str, "tf.Tensor"]]


@PublicAPI
class Dataset(Generic[T]):
@@ -2377,19 +2380,11 @@ def iter_batches(
local_shuffle_buffer_size: If non-None, the data will be randomly shuffled
using a local in-memory shuffle buffer, and this value will serve as the
minimum number of rows that must be in the local in-memory shuffle
buffer in order to yield a batch. This is a light-weight alternative to
the global `.random_shuffle()` operation; this shuffle will be less
random but will be faster and less resource-intensive. This buffer size
must be greater than or equal to ``batch_size``, and therefore
``batch_size`` must also be specified when using local shuffling.
When there are no more rows to be added to the buffer, the number of
rows in the buffer *will* decrease below this value while yielding
the remaining batches, and the final batch may have less than
``batch_size`` rows. Increasing this will improve the randomness of
the shuffle but will increase CPU memory utilization and the latency
to the first batch. The CPU memory utilization ceiling is the max of
the prefetch buffer size (controlled by ``prefetch_blocks``) and
this shuffle buffer size.
buffer in order to yield a batch. When there are no more rows to add to
the buffer, the remaining rows in the buffer will be drained. This
buffer size must be greater than or equal to ``batch_size``, and
therefore ``batch_size`` must also be specified when using local
shuffling.
local_shuffle_seed: The seed to use for the local random shuffle.
Returns:
@@ -2413,6 +2408,142 @@ def iter_batches(

stats.iter_total_s.add(time.perf_counter() - time_start)
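To make the local-shuffle semantics above concrete, a hedged sketch of the call (illustrative values; parameters are those documented in the iter_batches docstring):

import ray

ds = ray.data.range(100)
# Rows are shuffled through an in-memory buffer that is kept at >= 32 rows
# until the input is exhausted, then drained; the buffer size must be >=
# batch_size, so batch_size must be specified.
for batch in ds.iter_batches(
    batch_size=8,
    local_shuffle_buffer_size=32,
    local_shuffle_seed=42,
):
    ...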

def iter_torch_batches(
    self,
    *,
    prefetch_blocks: int = 0,
    batch_size: Optional[int] = None,
    dtypes: Optional[Union["torch.dtype", Dict[str, "torch.dtype"]]] = None,
    device: Optional[str] = None,
    drop_last: bool = False,
    local_shuffle_buffer_size: Optional[int] = None,
    local_shuffle_seed: Optional[int] = None,
) -> Iterator[TorchTensorBatchType]:
    """Return a local batched iterator of Torch Tensors over the dataset.

    This iterator will yield single-tensor batches if the underlying dataset
    consists of a single column; otherwise, it will yield a dictionary of
    column-tensors. If looking for more flexibility in the tensor conversion
    (e.g. casting dtypes) or the batch format, try using `.iter_batches`
    directly, which is a lower-level API.

    Examples:
        >>> import ray
        >>> for batch in ray.data.range( # doctest: +SKIP
        ...     12,
        ... ).iter_torch_batches(batch_size=4):
        ...     print(batch.shape) # doctest: +SKIP
        torch.Size([4, 1])
        torch.Size([4, 1])
        torch.Size([4, 1])

    Time complexity: O(1)

    Args:
        prefetch_blocks: The number of blocks to prefetch ahead of the
            current block during the scan.
        batch_size: Record batch size, or None to let the system pick.
        dtypes: The Torch dtype(s) for the created tensor(s); if None, the dtype
            will be inferred from the tensor data.
        device: The device on which the tensor should be placed; if None, the
            Torch tensor will be constructed on the CPU.
        drop_last: Whether to drop the last batch if it's incomplete.
        local_shuffle_buffer_size: If non-None, the data will be randomly shuffled
            using a local in-memory shuffle buffer, and this value will serve as
            the minimum number of rows that must be in the local in-memory
            shuffle buffer in order to yield a batch. When there are no more rows
            to add to the buffer, the remaining rows in the buffer will be
            drained. This buffer size must be greater than or equal to
            ``batch_size``, and therefore ``batch_size`` must also be specified
            when using local shuffling.
        local_shuffle_seed: The seed to use for the local random shuffle.

    Returns:
        An iterator over Torch Tensor batches.
    """
    from ray.air._internal.torch_utils import (
        convert_ndarray_batch_to_torch_tensor_batch,
    )

    for batch in self.iter_batches(
        prefetch_blocks=prefetch_blocks,
        batch_size=batch_size,
        batch_format="numpy",
        drop_last=drop_last,
        local_shuffle_buffer_size=local_shuffle_buffer_size,
        local_shuffle_seed=local_shuffle_seed,
    ):
        yield convert_ndarray_batch_to_torch_tensor_batch(
            batch,
            dtypes=dtypes,
            device=device,
        )
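The docstring example above covers the single-tensor case; for multi-column data the iterator yields a dictionary of column tensors, and per-column dtypes can be requested (an illustrative sketch, not part of the diff):

import ray
import torch

ds = ray.data.from_items([{"x": i, "y": 2 * i} for i in range(8)])
for batch in ds.iter_torch_batches(
    batch_size=4, dtypes={"x": torch.float32, "y": torch.float32}
):
    # Each batch is a Dict[str, torch.Tensor] keyed by column name.
    print(batch["x"].dtype, batch["y"].shape)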

def iter_tf_batches(
    self,
    *,
    prefetch_blocks: int = 0,
    batch_size: Optional[int] = None,
    dtypes: Optional[Union["tf.dtypes.DType", Dict[str, "tf.dtypes.DType"]]] = None,
    drop_last: bool = False,
    local_shuffle_buffer_size: Optional[int] = None,
    local_shuffle_seed: Optional[int] = None,
) -> Iterator[TensorFlowTensorBatchType]:
    """Return a local batched iterator of TensorFlow Tensors over the dataset.

    This iterator will yield single-tensor batches if the underlying dataset
    consists of a single column; otherwise, it will yield a dictionary of
    column-tensors. If looking for more flexibility in the tensor conversion
    (e.g. casting dtypes) or the batch format, try using `.to_tf`, which has a
    declarative API for tensor casting and batch formatting, or use
    `.iter_batches` directly, which is a lower-level API.

    Examples:
        >>> import ray
        >>> for batch in ray.data.range( # doctest: +SKIP
        ...     12,
        ... ).iter_tf_batches(batch_size=4):
        ...     print(batch.shape) # doctest: +SKIP
        (4, 1)
        (4, 1)
        (4, 1)

    Time complexity: O(1)

    Args:
        prefetch_blocks: The number of blocks to prefetch ahead of the
            current block during the scan.
        batch_size: Record batch size, or None to let the system pick.
        dtypes: The TensorFlow dtype(s) for the created tensor(s); if None, the
            dtype will be inferred from the tensor data.
        drop_last: Whether to drop the last batch if it's incomplete.
        local_shuffle_buffer_size: If non-None, the data will be randomly shuffled
            using a local in-memory shuffle buffer, and this value will serve as
            the minimum number of rows that must be in the local in-memory
            shuffle buffer in order to yield a batch. When there are no more rows
            to add to the buffer, the remaining rows in the buffer will be
            drained. This buffer size must be greater than or equal to
            ``batch_size``, and therefore ``batch_size`` must also be specified
            when using local shuffling.
        local_shuffle_seed: The seed to use for the local random shuffle.

    Returns:
        An iterator over TensorFlow Tensor batches.
    """
    from ray.air._internal.tensorflow_utils import (
        convert_ndarray_batch_to_tf_tensor_batch,
    )

    for batch in self.iter_batches(
        prefetch_blocks=prefetch_blocks,
        batch_size=batch_size,
        batch_format="numpy",
        drop_last=drop_last,
        local_shuffle_buffer_size=local_shuffle_buffer_size,
        local_shuffle_seed=local_shuffle_seed,
    ):
        yield convert_ndarray_batch_to_tf_tensor_batch(batch, dtypes=dtypes)
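And a matching sketch for the TensorFlow iterator, passing a single dtype for the whole batch (illustrative, not part of the diff; assumes TF is installed and that the integer range data converts to the requested float dtype):

import ray
import tensorflow as tf

ds = ray.data.range(12)
for batch in ds.iter_tf_batches(batch_size=4, dtypes=tf.float32):
    print(batch.shape)  # (4, 1) per the docstring example above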

def to_torch(
self,
*,
@@ -2500,15 +2631,11 @@ def to_torch(
local_shuffle_buffer_size: If non-None, the data will be randomly shuffled
using a local in-memory shuffle buffer, and this value will serve as the
minimum number of rows that must be in the local in-memory shuffle
buffer in order to yield a batch. This is a light-weight alternative to
the global `.random_shuffle()` operation; this shuffle will be less
random but will be faster and less resource-intensive. This buffer size
must be greater than or equal to ``batch_size``, and therefore
``batch_size`` must also be specified when using local shuffling.
Increasing this will improve the randomness of the shuffle but will
increase CPU memory utilization and the latency to the first batch. The
CPU memory utilization ceiling is the max of the prefetch buffer size
(controlled by ``prefetch_blocks``) and this shuffle buffer size.
buffer in order to yield a batch. When there are no more rows to add to
the buffer, the remaining rows in the buffer will be drained. This
buffer size must be greater than or equal to ``batch_size``, and
therefore ``batch_size`` must also be specified when using local
shuffling.
local_shuffle_seed: The seed to use for the local random shuffle.
unsqueeze_label_tensor: If set to True, the label tensor
will be unsqueezed (reshaped to (N, 1)). Otherwise, it will
@@ -2681,15 +2808,11 @@ def to_tf(
local_shuffle_buffer_size: If non-None, the data will be randomly shuffled
using a local in-memory shuffle buffer, and this value will serve as the
minimum number of rows that must be in the local in-memory shuffle
buffer in order to yield a batch. This is a light-weight alternative to
the global `.random_shuffle()` operation; this shuffle will be less
random but will be faster and less resource-intensive. This buffer size
must be greater than or equal to ``batch_size``, and therefore
``batch_size`` must also be specified when using local shuffling.
Increasing this will improve the randomness of the shuffle but will
increase CPU memory utilization and the latency to the first batch. The
CPU memory utilization ceiling is the max of the prefetch buffer size
(controlled by ``prefetch_blocks``) and this shuffle buffer size.
buffer in order to yield a batch. When there are no more rows to add to
the buffer, the remaining rows in the buffer will be drained. This
buffer size must be greater than or equal to ``batch_size``, and
therefore ``batch_size`` must also be specified when using local
shuffling.
local_shuffle_seed: The seed to use for the local random shuffle.
Returns: