From b5796eca5ab071a8e1f22fa58938871380370bab Mon Sep 17 00:00:00 2001
From: Amog Kamsetty
Date: Tue, 17 Jan 2023 17:29:27 -0800
Subject: [PATCH] [Data] Don't drop first dataset when peeking `DatasetPipeline` (#31513)

Signed-off-by: amogkam <amogkamsetty@yahoo.com>

Closes #31505.

When peeking a DatasetPipeline, for example via .schema(), the first
dataset in the base iterator is consumed. Then, when new operations such
as map_batches are chained onto the pipeline, the dataset that was peeked
is lost.

In this PR, we change the implementation of peek so that it does not
consume the base iterable, but instead creates a new iterable consisting
of just the first dataset.
---
 python/ray/data/_internal/stats.py           | 38 ++++++++--
 python/ray/data/dataset.py                   |  2 +-
 python/ray/data/dataset_pipeline.py          | 70 ++++++++++++++----
 .../ray/data/tests/test_dataset_pipeline.py  | 24 ++++++-
 4 files changed, 108 insertions(+), 26 deletions(-)

diff --git a/python/ray/data/_internal/stats.py b/python/ray/data/_internal/stats.py
index 193814aadf30..c17477fb0469 100644
--- a/python/ray/data/_internal/stats.py
+++ b/python/ray/data/_internal/stats.py
@@ -644,13 +644,23 @@ def __init__(self, *, max_history: int = 3):
         self.wait_time_s = []
 
         # Iteration stats, filled out if the user iterates over the pipeline.
-        self.iter_ds_wait_s: Timer = Timer()
-        self.iter_wait_s: Timer = Timer()
-        self.iter_get_s: Timer = Timer()
-        self.iter_next_batch_s: Timer = Timer()
-        self.iter_format_batch_s: Timer = Timer()
-        self.iter_user_s: Timer = Timer()
-        self.iter_total_s: Timer = Timer()
+        self._iter_stats = {
+            "iter_ds_wait_s": Timer(),
+            "iter_wait_s": Timer(),
+            "iter_get_s": Timer(),
+            "iter_next_batch_s": Timer(),
+            "iter_format_batch_s": Timer(),
+            "iter_user_s": Timer(),
+            "iter_total_s": Timer(),
+        }
+
+    # Make iteration stats also accessible via attributes.
+    def __getattr__(self, name):
+        if name == "_iter_stats":
+            raise AttributeError
+        if name in self._iter_stats:
+            return self._iter_stats[name]
+        raise AttributeError
 
     def add(self, stats: DatasetStats) -> None:
         """Called to add stats for a newly computed window."""
@@ -659,6 +669,20 @@ def add(self, stats: DatasetStats) -> None:
             self.history_buffer.pop(0)
         self.count += 1
 
+    def add_pipeline_stats(self, other_stats: "DatasetPipelineStats"):
+        """Add the provided pipeline stats to the current stats.
+
+        `other_stats` should cover a set of windows disjoint
+        from the current stats.
+ """ + for _, dataset_stats in other_stats.history_buffer: + self.add(dataset_stats) + + self.wait_time_s.extend(other_stats.wait_time_s) + + for stat_name, timer in self._iter_stats.items(): + timer.add(other_stats._iter_stats[stat_name].get()) + def _summarize_iter(self) -> str: out = "" if ( diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index 55f8ab1bffa7..961d318f83fb 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -3622,7 +3622,7 @@ def __init__(self, blocks): self._blocks = blocks self._i = 0 - def __next__(self) -> "Dataset[T]": + def __next__(self) -> Callable[[], "Dataset[T]"]: if times and self._i >= times: raise StopIteration epoch = self._i diff --git a/python/ray/data/dataset_pipeline.py b/python/ray/data/dataset_pipeline.py index 2bc43497d7e0..f0d586931696 100644 --- a/python/ray/data/dataset_pipeline.py +++ b/python/ray/data/dataset_pipeline.py @@ -82,7 +82,7 @@ def __init__( self, base_iterable: Iterable[Callable[[], Dataset[T]]], stages: List[Callable[[Dataset[Any]], Dataset[Any]]] = None, - length: int = None, + length: Optional[int] = None, progress_bars: bool = progress_bar._enabled, _executed: List[bool] = None, ): @@ -101,8 +101,8 @@ def __init__( # Whether the pipeline execution has started. # This variable is shared across all pipelines descending from this. self._executed = _executed or [False] - self._dataset_iter = None - self._first_dataset = None + self._first_dataset: Optional[Dataset] = None + self._remaining_datasets_iter: Optional[Iterator[Callable[[], Dataset]]] = None self._schema = None self._stats = DatasetPipelineStats() @@ -504,7 +504,7 @@ def __init__(self, original_iter): # This is calculated later. self._max_i = None - def __next__(self) -> Dataset[T]: + def __next__(self) -> Callable[[], Dataset[T]]: # Still going through the original pipeline. if self._original_iter: try: @@ -540,11 +540,11 @@ def gen(): raise StopIteration class RepeatIterable: - def __init__(self, original_iter): - self._original_iter = original_iter + def __init__(self, original_iterable): + self._original_iterable = original_iterable def __iter__(self): - return RepeatIterator(self._original_iter) + return RepeatIterator(iter(self._original_iterable)) if not times: length = float("inf") @@ -554,7 +554,7 @@ def __iter__(self): length = None return DatasetPipeline( - RepeatIterable(iter(self._base_iterable)), + RepeatIterable(self._base_iterable), stages=self._stages.copy(), length=length, ) @@ -1180,12 +1180,35 @@ def iter_datasets(self) -> Iterator[Dataset[T]]: if self._executed[0]: raise RuntimeError("Pipeline cannot be read multiple times.") self._executed[0] = True - if self._first_dataset is None: - self._peek() - iter = itertools.chain([self._first_dataset], self._dataset_iter) - self._first_dataset = None - self._dataset_iter = None - return iter + + self._optimize_stages() + + # If the first dataset has already been executed (via a peek operation), then + # we don't re-execute the first dataset when iterating through the pipeline. + # We re-use the saved _first_dataset and _remaining_dataset_iter. + if self._first_dataset is not None: + + class _IterableWrapper(Iterable): + """Wrapper that takes an iterator and converts it to an + iterable.""" + + def __init__(self, base_iterator): + self.base_iterator = base_iterator + + def __iter__(self): + return self.base_iterator + + # Update the base iterable to skip the first dataset. 
+ # It is ok to update the base iterable here since + # the pipeline can never be executed again. + self._base_iterable = _IterableWrapper(self._remaining_datasets_iter) + + iter = itertools.chain([self._first_dataset], PipelineExecutor(self)) + self._first_dataset = None + self._remaining_datasets_iter = None + return iter + else: + return PipelineExecutor(self) @DeveloperAPI def foreach_window( @@ -1201,6 +1224,7 @@ def foreach_window( """ if self._executed[0]: raise RuntimeError("Pipeline cannot be read multiple times.") + return DatasetPipeline( self._base_iterable, self._stages + [fn], @@ -1289,9 +1313,21 @@ def add_stage(ds, stage): def _peek(self) -> Dataset[T]: if self._first_dataset is None: - self._optimize_stages() - self._dataset_iter = PipelineExecutor(self) - self._first_dataset = next(self._dataset_iter) + dataset_iter = iter(self._base_iterable) + first_dataset_gen = next(dataset_iter) + peek_pipe = DatasetPipeline( + base_iterable=[first_dataset_gen], + stages=self._stages.copy(), + length=1, + progress_bars=True, + ) + # Cache the executed _first_dataset. + self._first_dataset = next(peek_pipe.iter_datasets()) + self._remaining_datasets_iter = dataset_iter + + # Store the stats from the peek pipeline. + self._stats.add_pipeline_stats(peek_pipe._stats) + return self._first_dataset def _write_each_dataset(self, write_fn: Callable[[Dataset[T]], None]) -> None: diff --git a/python/ray/data/tests/test_dataset_pipeline.py b/python/ray/data/tests/test_dataset_pipeline.py index 7f3895ca3544..20264e65260d 100644 --- a/python/ray/data/tests/test_dataset_pipeline.py +++ b/python/ray/data/tests/test_dataset_pipeline.py @@ -471,6 +471,24 @@ def test_schema_peek(ray_start_regular_shared): assert pipe.schema() is None +def test_schema_after_repeat(ray_start_regular_shared): + pipe = ray.data.range(6, parallelism=6).window(blocks_per_window=2).repeat(2) + assert pipe.schema() == int + output = [] + for ds in pipe.iter_datasets(): + output.extend(ds.take()) + assert sorted(output) == sorted(list(range(6)) * 2) + + pipe = ray.data.range(6, parallelism=6).window(blocks_per_window=2).repeat(2) + assert pipe.schema() == int + # Test that operations still work after peek. + pipe = pipe.map_batches(lambda batch: batch) + output = [] + for ds in pipe.iter_datasets(): + output.extend(ds.take()) + assert sorted(output) == sorted(list(range(6)) * 2) + + def test_split(ray_start_regular_shared): pipe = ray.data.range(3).map(lambda x: x + 1).repeat(10) @@ -757,7 +775,11 @@ def __next__(self): return lambda: ds p1 = ray.data.range(10).repeat() - p2 = DatasetPipeline.from_iterable(Iterable(p1.iter_datasets())) + # Start the pipeline. + data_iter = p1.iter_datasets() + next(data_iter) + + p2 = DatasetPipeline.from_iterable(Iterable(data_iter)) with pytest.raises(RuntimeError) as error: p2.split(2) assert "PipelineExecutor is not serializable once it has started" in str(error)
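
For readers who want to reproduce the behavior this patch guards against, the following is a minimal usage sketch modeled directly on the new test_schema_after_repeat test above. It assumes a Ray build that includes this change; the assertions mirror the test rather than documenting new guarantees.

    import ray

    # Peek the schema, then keep building the pipeline. Before this patch, the
    # peeked first window was consumed and silently dropped from later iteration.
    pipe = ray.data.range(6, parallelism=6).window(blocks_per_window=2).repeat(2)
    assert pipe.schema() == int

    # Chaining another stage after the peek must still see every window.
    pipe = pipe.map_batches(lambda batch: batch)

    output = []
    for ds in pipe.iter_datasets():
        output.extend(ds.take())
    assert sorted(output) == sorted(list(range(6)) * 2)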
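
The core idea of the dataset_pipeline.py change, stripped of Ray specifics, is a peek that does not consume the underlying iterator: cache the first materialized item and later chain it back in front of the remainder. The sketch below uses a hypothetical PeekableSource over zero-argument factories; the class and method names are illustrative, not part of the Ray API.

    import itertools
    from typing import Iterator, Optional

    class PeekableSource:
        def __init__(self, factories):
            self._base_iterable = factories             # iterable of zero-arg factories
            self._first: Optional[object] = None        # cached result of the peek
            self._remaining: Optional[Iterator] = None  # iterator past the peeked item

        def peek(self):
            # Materialize only the first item; leave the rest un-consumed.
            if self._first is None:
                it = iter(self._base_iterable)
                self._first = next(it)()
                self._remaining = it
            return self._first

        def iter_items(self):
            # If peek() already ran, re-use its cached result instead of recomputing it.
            if self._first is not None:
                first, rest = self._first, self._remaining
                self._first, self._remaining = None, None
                return itertools.chain([first], (factory() for factory in rest))
            return (factory() for factory in self._base_iterable)

    source = PeekableSource([lambda: [0, 1], lambda: [2, 3]])
    assert source.peek() == [0, 1]
    assert list(source.iter_items()) == [[0, 1], [2, 3]]  # the peeked window is not dropped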
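
The stats.py change moves the per-iteration timers into a dict while keeping attribute-style access through __getattr__. The self-contained sketch below illustrates that delegation pattern with a simplified stand-in timer; it is not the real Timer class from ray.data._internal.stats.

    class _Timer:
        # Minimal stand-in for the internal Timer: accumulates seconds.
        def __init__(self):
            self._total = 0.0

        def add(self, value: float) -> None:
            self._total += value

        def get(self) -> float:
            return self._total

    class _StatsLike:
        def __init__(self):
            # Timers live in a dict so they can be merged generically...
            self._iter_stats = {"iter_wait_s": _Timer(), "iter_total_s": _Timer()}

        def __getattr__(self, name):
            # Guard against recursion if lookup happens before __init__ has run
            # (e.g. during unpickling).
            if name == "_iter_stats":
                raise AttributeError(name)
            # ...while still being readable as plain attributes, e.g. stats.iter_wait_s.
            if name in self._iter_stats:
                return self._iter_stats[name]
            raise AttributeError(name)

    stats = _StatsLike()
    stats.iter_wait_s.add(1.5)
    assert stats.iter_wait_s.get() == 1.5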