diff --git a/python/ray/data/_internal/stats.py b/python/ray/data/_internal/stats.py index 79bedc02ad6e..b2f401e49d3a 100644 --- a/python/ray/data/_internal/stats.py +++ b/python/ray/data/_internal/stats.py @@ -786,6 +786,7 @@ def __init__(self, *, max_history: int = 3): "iter_get_s": Timer(), "iter_next_batch_s": Timer(), "iter_format_batch_s": Timer(), + "iter_collate_batch_s": Timer(), "iter_user_s": Timer(), "iter_total_s": Timer(), } diff --git a/python/ray/data/dataset_iterator.py b/python/ray/data/dataset_iterator.py index b43c498c1466..cce950249659 100644 --- a/python/ray/data/dataset_iterator.py +++ b/python/ray/data/dataset_iterator.py @@ -31,6 +31,13 @@ from ray.data.dataset import TensorFlowTensorBatchType +def _is_tensor_dataset(schema) -> bool: + """Return ``True`` if this is an iterator over a tensor dataset.""" + if schema is None or isinstance(schema, type): + return False + return _is_tensor_schema(schema.names) + + @PublicAPI(stability="beta") class DatasetIterator(abc.ABC): """An iterator for reading items from a :class:`~Dataset` or @@ -728,13 +735,14 @@ def to_tf( except ImportError: raise ValueError("tensorflow must be installed!") - if self._is_tensor_dataset(): + schema = self.schema() + + if _is_tensor_dataset(schema): raise NotImplementedError( "`to_tf` doesn't support single-column tensor datasets. Call the " "more-flexible `iter_batches` instead." ) - schema = self.schema() if isinstance(schema, type): raise NotImplementedError( "`to_tf` doesn't support simple datasets. Call `map_batches` and " @@ -818,10 +826,3 @@ def iter_epochs(self, max_epoch: int = -1) -> None: "To iterate over one epoch of data, use iter_batches(), " "iter_torch_batches(), or to_tf()." ) - - def _is_tensor_dataset(self) -> bool: - """Return ``True`` if this is an iterator over a tensor dataset.""" - schema = self.schema() - if schema is None or isinstance(schema, type): - return False - return _is_tensor_schema(schema.names) diff --git a/python/ray/data/dataset_pipeline.py b/python/ray/data/dataset_pipeline.py index c02062b3d2c9..2a3a6c405a3b 100644 --- a/python/ray/data/dataset_pipeline.py +++ b/python/ray/data/dataset_pipeline.py @@ -169,6 +169,8 @@ def gen_rows() -> Iterator[Union[T, TableRow]]: def iter_batches( self, *, + prefetch_batches: int = 1, + # Deprecated. prefetch_blocks: int = 0, batch_size: Optional[int] = 256, batch_format: Optional[str] = "default", @@ -1071,7 +1073,7 @@ def iter_tf_batches( """Call :py:meth:`Dataset.iter_tf_batches ` over the stream of output batches from the pipeline.""" - return Dataset.iter_tf_batches( + return DatasetIterator.iter_tf_batches( self, prefetch_blocks=prefetch_blocks, batch_size=batch_size, @@ -1097,7 +1099,7 @@ def iter_torch_batches( """Call :py:meth:`Dataset.iter_torch_batches ` over the stream of output batches from the pipeline.""" - return Dataset.iter_torch_batches( + return DatasetIterator.iter_torch_batches( self, prefetch_blocks=prefetch_blocks, batch_size=batch_size, @@ -1122,7 +1124,7 @@ def to_tf( ) -> "tf.data.Dataset": """Call :py:meth:`Dataset.to_tf ` over the stream of output batches from the pipeline""" - return Dataset.to_tf( + return DatasetIterator.to_tf( self, feature_columns=feature_columns, label_columns=label_columns, @@ -1152,7 +1154,7 @@ def to_torch( ) -> "torch.utils.data.IterableDataset": """Call :py:meth:`Dataset.to_torch ` over the stream of output batches from the pipeline""" - return Dataset.to_torch( + return DatasetIterator.to_torch( self, label_column=label_column, feature_columns=feature_columns, diff --git a/python/ray/data/tests/test_dataset_pipeline.py b/python/ray/data/tests/test_dataset_pipeline.py index c36a2262e2d0..914253bf3279 100644 --- a/python/ray/data/tests/test_dataset_pipeline.py +++ b/python/ray/data/tests/test_dataset_pipeline.py @@ -394,9 +394,31 @@ def test_iter_batches_basic(ray_start_regular_shared): def test_to_torch(ray_start_regular_shared): - pipe = ray.data.range(10, parallelism=10).window(blocks_per_window=2) + pipe = ray.data.range(10, parallelism=10).window(blocks_per_window=2).repeat(2) batches = list(pipe.to_torch(batch_size=None)) - assert len(batches) == 10 + assert len(batches) == 20 + + +def test_to_tf(ray_start_regular_shared): + ds = ray.data.range_tensor(10, shape=(1, 1, 1), parallelism=10) + ds = ds.add_column("label", lambda x: 1) + pipe = ds.window(blocks_per_window=2).repeat(2) + batches = list( + pipe.to_tf(feature_columns="__value__", label_columns="label", batch_size=None) + ) + assert len(batches) == 20 + + +def test_iter_torch_batches(ray_start_regular_shared): + pipe = ray.data.range(10).repeat(2) + batches = list(pipe.iter_torch_batches(batch_size=1)) + assert len(batches) == 20 + + +def test_iter_tf_batches(ray_start_regular_shared): + pipe = ray.data.range(10).repeat(2) + batches = list(pipe.iter_tf_batches(batch_size=1)) + assert len(batches) == 20 def test_iter_batches_batch_across_windows(ray_start_regular_shared):