Skip to content

Commit

Permalink
[data] Make sure the tf and tensor iteration work in dataset pipeline (
Browse files Browse the repository at this point in the history
…ray-project#34248)

* Revert "[Datasets] Revert "Enable streaming executor by default (ray-project#32493)" (ray-project#33485)"

This reverts commit 5c79954.

* make sure tf and tensor iteration in datapipeline work

* Fix

* fix

* fix

* fix

* feedback

* feedback

* fix

Signed-off-by: elliottower <[email protected]>
  • Loading branch information
jianoaix authored and elliottower committed Apr 22, 2023
1 parent 5b50270 commit 46012cc
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 15 deletions.
1 change: 1 addition & 0 deletions python/ray/data/_internal/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -874,6 +874,7 @@ def __init__(self, *, max_history: int = 3):
"iter_get_s": Timer(),
"iter_next_batch_s": Timer(),
"iter_format_batch_s": Timer(),
"iter_collate_batch_s": Timer(),
"iter_user_s": Timer(),
"iter_total_s": Timer(),
}
Expand Down
19 changes: 10 additions & 9 deletions python/ray/data/dataset_iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,13 @@
from ray.data.dataset import TensorFlowTensorBatchType


def _is_tensor_dataset(schema) -> bool:
"""Return ``True`` if this is an iterator over a tensor dataset."""
if schema is None or isinstance(schema, type):
return False
return _is_tensor_schema(schema.names)


@PublicAPI(stability="beta")
class DataIterator(abc.ABC):
"""An iterator for reading items from a :class:`~Dataset` or
Expand Down Expand Up @@ -728,13 +735,14 @@ def to_tf(
except ImportError:
raise ValueError("tensorflow must be installed!")

if self._is_tensor_dataset():
schema = self.schema()

if _is_tensor_dataset(schema):
raise NotImplementedError(
"`to_tf` doesn't support single-column tensor datasets. Call the "
"more-flexible `iter_batches` instead."
)

schema = self.schema()
if isinstance(schema, type):
raise NotImplementedError(
"`to_tf` doesn't support simple datasets. Call `map_batches` and "
Expand Down Expand Up @@ -819,13 +827,6 @@ def iter_epochs(self, max_epoch: int = -1) -> None:
"iter_torch_batches(), or to_tf()."
)

def _is_tensor_dataset(self) -> bool:
"""Return ``True`` if this is an iterator over a tensor dataset."""
schema = self.schema()
if schema is None or isinstance(schema, type):
return False
return _is_tensor_schema(schema.names)


# Backwards compatibility alias.
DatasetIterator = DataIterator
10 changes: 6 additions & 4 deletions python/ray/data/dataset_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,8 @@ def gen_rows() -> Iterator[Union[T, TableRow]]:
def iter_batches(
self,
*,
prefetch_batches: int = 1,
# Deprecated.
prefetch_blocks: int = 0,
batch_size: Optional[int] = 256,
batch_format: Optional[str] = "default",
Expand Down Expand Up @@ -1071,7 +1073,7 @@ def iter_tf_batches(
"""Call
:py:meth:`Dataset.iter_tf_batches <ray.data.Dataset.iter_tf_batches>`
over the stream of output batches from the pipeline."""
return Dataset.iter_tf_batches(
return DataIterator.iter_tf_batches(
self,
prefetch_blocks=prefetch_blocks,
batch_size=batch_size,
Expand All @@ -1097,7 +1099,7 @@ def iter_torch_batches(
"""Call
:py:meth:`Dataset.iter_torch_batches <ray.data.Dataset.iter_torch_batches>`
over the stream of output batches from the pipeline."""
return Dataset.iter_torch_batches(
return DataIterator.iter_torch_batches(
self,
prefetch_blocks=prefetch_blocks,
batch_size=batch_size,
Expand All @@ -1122,7 +1124,7 @@ def to_tf(
) -> "tf.data.Dataset":
"""Call :py:meth:`Dataset.to_tf <ray.data.Dataset.to_tf>` over the stream of
output batches from the pipeline"""
return Dataset.to_tf(
return DataIterator.to_tf(
self,
feature_columns=feature_columns,
label_columns=label_columns,
Expand Down Expand Up @@ -1152,7 +1154,7 @@ def to_torch(
) -> "torch.utils.data.IterableDataset":
"""Call :py:meth:`Dataset.to_torch <ray.data.Dataset.to_torch>` over the stream
of output batches from the pipeline"""
return Dataset.to_torch(
return DataIterator.to_torch(
self,
label_column=label_column,
feature_columns=feature_columns,
Expand Down
26 changes: 24 additions & 2 deletions python/ray/data/tests/test_dataset_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,9 +394,31 @@ def test_iter_batches_basic(ray_start_regular_shared):


def test_to_torch(ray_start_regular_shared):
pipe = ray.data.range(10, parallelism=10).window(blocks_per_window=2)
pipe = ray.data.range(10, parallelism=10).window(blocks_per_window=2).repeat(2)
batches = list(pipe.to_torch(batch_size=None))
assert len(batches) == 10
assert len(batches) == 20


def test_to_tf(ray_start_regular_shared):
ds = ray.data.range_tensor(10, shape=(1, 1, 1), parallelism=10)
ds = ds.add_column("label", lambda x: 1)
pipe = ds.window(blocks_per_window=2).repeat(2)
batches = list(
pipe.to_tf(feature_columns="__value__", label_columns="label", batch_size=None)
)
assert len(batches) == 20


def test_iter_torch_batches(ray_start_regular_shared):
pipe = ray.data.range(10).repeat(2)
batches = list(pipe.iter_torch_batches(batch_size=1))
assert len(batches) == 20


def test_iter_tf_batches(ray_start_regular_shared):
pipe = ray.data.range(10).repeat(2)
batches = list(pipe.iter_tf_batches(batch_size=1))
assert len(batches) == 20


def test_iter_batches_batch_across_windows(ray_start_regular_shared):
Expand Down

0 comments on commit 46012cc

Please sign in to comment.