From fd3616fc097f3e9cb894e950a2add95f30dd144e Mon Sep 17 00:00:00 2001
From: amogkam
Date: Fri, 6 Jan 2023 19:37:18 -0800
Subject: [PATCH 01/14] fix

Signed-off-by: amogkam
---
 python/ray/data/dataset.py                  |  2 +-
 python/ray/data/dataset_pipeline.py         | 28 ++++++++++---------
 .../ray/data/tests/test_dataset_pipeline.py | 18 ++++++++++++
 3 files changed, 34 insertions(+), 14 deletions(-)

diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py
index e39fea4fe204..b3e5795ad716 100644
--- a/python/ray/data/dataset.py
+++ b/python/ray/data/dataset.py
@@ -3589,7 +3589,7 @@ def __init__(self, blocks):
                 self._blocks = blocks
                 self._i = 0
 
-            def __next__(self) -> "Dataset[T]":
+            def __next__(self) -> Callable[[], "Dataset[T]"]:
                 if times and self._i >= times:
                     raise StopIteration
                 epoch = self._i
diff --git a/python/ray/data/dataset_pipeline.py b/python/ray/data/dataset_pipeline.py
index 5dcc9c8dd368..54ac02be8590 100644
--- a/python/ray/data/dataset_pipeline.py
+++ b/python/ray/data/dataset_pipeline.py
@@ -1,4 +1,3 @@
-import itertools
 import logging
 import sys
 import time
@@ -98,7 +97,6 @@ def __init__(
         # Whether the pipeline execution has started.
         # This variable is shared across all pipelines descending from this.
         self._executed = _executed or [False]
-        self._dataset_iter = None
         self._first_dataset = None
         self._schema = None
         self._stats = DatasetPipelineStats()
@@ -481,7 +479,7 @@ def __init__(self, original_iter):
                 # This is calculated later.
                 self._max_i = None
 
-            def __next__(self) -> Dataset[T]:
+            def __next__(self) -> Callable[[], Dataset[T]]:
                 # Still going through the original pipeline.
                 if self._original_iter:
                     try:
@@ -1150,12 +1148,8 @@ def iter_datasets(self) -> Iterator[Dataset[T]]:
         if self._executed[0]:
             raise RuntimeError("Pipeline cannot be read multiple times.")
         self._executed[0] = True
-        if self._first_dataset is None:
-            self._peek()
-        iter = itertools.chain([self._first_dataset], self._dataset_iter)
-        self._first_dataset = None
-        self._dataset_iter = None
-        return iter
+        self._optimize_stages()
+        return PipelineExecutor(self)
 
     @DeveloperAPI
     def foreach_window(
@@ -1171,13 +1165,15 @@ def foreach_window(
         """
         if self._executed[0]:
             raise RuntimeError("Pipeline cannot be read multiple times.")
-        return DatasetPipeline(
+
+        pipe = DatasetPipeline(
            self._base_iterable,
            self._stages + [fn],
            self._length,
            self._progress_bars,
            _executed=self._executed,
        )
+        return pipe
 
     def stats(self, exclude_first_window: bool = True) -> str:
         """Returns a string containing execution timing information.
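Note: the hunks above make iter_datasets() always build a fresh PipelineExecutor, while the hunk below reworks _peek() to execute only the first window through a throwaway one-window pipeline. The peek-then-iterate pattern that this series converges on can be sketched roughly as follows. This is a simplified toy model with hypothetical names, not the real Ray API; windows are modeled as zero-argument callables, as in the PR:

    import itertools
    from typing import Callable, Iterator, List

    class WindowedPipeline:
        """Toy stand-in for DatasetPipeline (hypothetical)."""

        def __init__(self, window_fns: List[Callable[[], list]]):
            self._window_fns = window_fns
            self._first_window = None   # cached result of peek()
            self._remaining = None      # iterator over unexecuted window generators

        def peek(self) -> list:
            # Execute only the first window; leave the rest untouched.
            if self._first_window is None:
                gen_iter = iter(self._window_fns)
                self._first_window = next(gen_iter)()
                self._remaining = gen_iter
            return self._first_window

        def iter_windows(self) -> Iterator[list]:
            if self._first_window is not None:
                # Re-use the peeked window instead of recomputing it.
                first, self._first_window = self._first_window, None
                return itertools.chain([first], (fn() for fn in self._remaining))
            return (fn() for fn in self._window_fns)

    pipe = WindowedPipeline([lambda: [0, 1], lambda: [2, 3]])
    assert pipe.peek() == [0, 1]
    assert list(pipe.iter_windows()) == [[0, 1], [2, 3]]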
@@ -1259,9 +1255,15 @@ def add_stage(ds, stage):
 
     def _peek(self) -> Dataset[T]:
         if self._first_dataset is None:
-            self._optimize_stages()
-            self._dataset_iter = PipelineExecutor(self)
-            self._first_dataset = next(self._dataset_iter)
+            first_dataset_gen = next(iter(self._base_iterable))
+            peek_pipe = DatasetPipeline(
+                base_iterable=[first_dataset_gen],
+                stages=self._stages.copy(),
+                length=1,
+                progress_bars=False,
+            )
+            self._first_dataset = next(peek_pipe.iter_datasets())
+
         return self._first_dataset
 
     def _write_each_dataset(self, write_fn: Callable[[Dataset[T]], None]) -> None:
diff --git a/python/ray/data/tests/test_dataset_pipeline.py b/python/ray/data/tests/test_dataset_pipeline.py
index 7f3895ca3544..a32fce2a6576 100644
--- a/python/ray/data/tests/test_dataset_pipeline.py
+++ b/python/ray/data/tests/test_dataset_pipeline.py
@@ -471,6 +471,24 @@ def test_schema_peek(ray_start_regular_shared):
     assert pipe.schema() is None
 
 
+def test_schema_after_repeat(ray_start_regular_shared):
+    pipe = ray.data.range(6, parallelism=6).window(blocks_per_window=2).repeat(2)
+    assert pipe.schema() == int
+    output = []
+    for ds in pipe.iter_datasets():
+        output.extend(dataset.take())
+    assert output.sort() == (list(range(6)) * 2).sort()
+
+    pipe = ray.data.range(6, parallelism=6).window(blocks_per_window=2).repeat(2)
+    assert pipe.schema() == int
+    # Test that operations still work after peek.
+    pipe = pipe.map_batches(lambda batch: batch)
+    output = []
+    for ds in pipe.iter_datasets():
+        output.extend(dataset.take())
+    assert output.sort() == (list(range(6)) * 2).sort()
+
+
 def test_split(ray_start_regular_shared):
     pipe = ray.data.range(3).map(lambda x: x + 1).repeat(10)
 

From 94a7681757e4a4a4e9e4d96001f0c0cb57385dcb Mon Sep 17 00:00:00 2001
From: amogkam
Date: Mon, 9 Jan 2023 13:49:48 -0800
Subject: [PATCH 02/14] update

Signed-off-by: amogkam
---
 python/ray/data/dataset_pipeline.py         | 53 ++++++++++++++-----
 .../ray/data/tests/test_dataset_pipeline.py |  8 +--
 2 files changed, 45 insertions(+), 16 deletions(-)

diff --git a/python/ray/data/dataset_pipeline.py b/python/ray/data/dataset_pipeline.py
index 54ac02be8590..f2358defeb50 100644
--- a/python/ray/data/dataset_pipeline.py
+++ b/python/ray/data/dataset_pipeline.py
@@ -1,3 +1,4 @@
+import itertools
 import logging
 import sys
 import time
@@ -97,7 +98,8 @@ def __init__(
         # Whether the pipeline execution has started.
         # This variable is shared across all pipelines descending from this.
         self._executed = _executed or [False]
-        self._first_dataset = None
+        self._remaining_dataset_iter: Iterator[Callable[[], Dataset]] = None
+        self._first_dataset: Dataset = None
         self._schema = None
         self._stats = DatasetPipelineStats()
@@ -515,11 +517,11 @@ def gen():
             raise StopIteration
 
         class RepeatIterable:
-            def __init__(self, original_iter):
-                self._original_iter = original_iter
+            def __init__(self, original_iterable):
+                self._original_iterable = original_iterable
 
             def __iter__(self):
-                return RepeatIterator(self._original_iter)
+                return RepeatIterator(iter(self._original_iterable))
 
         if not times:
             length = float("inf")
@@ -529,7 +531,7 @@ def __iter__(self):
             length = None
 
         return DatasetPipeline(
-            RepeatIterable(iter(self._base_iterable)),
+            RepeatIterable(self._base_iterable),
             stages=self._stages.copy(),
             length=length,
         )
@@ -1148,8 +1150,35 @@ def iter_datasets(self) -> Iterator[Dataset[T]]:
         if self._executed[0]:
             raise RuntimeError("Pipeline cannot be read multiple times.")
         self._executed[0] = True
-        self._optimize_stages()
-        return PipelineExecutor(self)
+
+        # Peek has already been executed so we use the cached first dataset and
+        # remaining_dataset_iter.
+        if self._first_dataset is not None:
+
+            class _IterableWrapper(Iterable):
+                """Wrapper that takes an iterator and converts it to an iterable."""
+
+                def __init__(self, base_iterator):
+                    self.base_iterator = base_iterator
+
+                def __iter__(self):
+                    return self.base_iterator
+
+            remaining_pipeline = DatasetPipeline(
+                base_iterable=_IterableWrapper(self._remaining_dataset_iter),
+                stages=self._stages.copy(),
+                length=self._length - 1,
+                progress_bars=True,
+            )
+            iter = itertools.chain(
+                [self._first_dataset], remaining_pipeline.iter_datasets()
+            )
+            self._first_dataset = None
+            self._remaining_dataset_iter = None
+            return iter
+        else:
+            self._optimize_stages()
+            return PipelineExecutor(self)
 
     @DeveloperAPI
     def foreach_window(
@@ -1166,14 +1195,13 @@ def foreach_window(
         if self._executed[0]:
             raise RuntimeError("Pipeline cannot be read multiple times.")
 
-        pipe = DatasetPipeline(
+        return DatasetPipeline(
             self._base_iterable,
             self._stages + [fn],
             self._length,
             self._progress_bars,
             _executed=self._executed,
         )
-        return pipe
 
     def stats(self, exclude_first_window: bool = True) -> str:
         """Returns a string containing execution timing information.
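Note: the `_IterableWrapper` introduced in the hunk above is the standard trick for handing an already-started iterator to an API that expects an iterable. A standalone sketch of the idea (hypothetical simplified class, not the exact one from the diff):

    from typing import Iterator

    class IterableWrapper:
        """Adapts an iterator to the Iterable protocol. Single-pass only:
        every call to __iter__ returns the same underlying iterator."""

        def __init__(self, base_iterator: Iterator):
            self.base_iterator = base_iterator

        def __iter__(self) -> Iterator:
            return self.base_iterator

    wrapped = IterableWrapper(iter([1, 2, 3]))
    assert list(wrapped) == [1, 2, 3]
    assert list(wrapped) == []  # already exhausted: safe only for one pass

The single-pass behavior is acceptable here because a DatasetPipeline may only be executed once; the shared `_executed` flag guards against re-reads.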
@@ -1255,15 +1283,16 @@ def add_stage(ds, stage):
 
     def _peek(self) -> Dataset[T]:
         if self._first_dataset is None:
-            first_dataset_gen = next(iter(self._base_iterable))
+            dataset_iter = iter(self._base_iterable)
+            first_dataset_gen = next(dataset_iter)
             peek_pipe = DatasetPipeline(
                 base_iterable=[first_dataset_gen],
                 stages=self._stages.copy(),
                 length=1,
-                progress_bars=False,
+                progress_bars=True,
             )
             self._first_dataset = next(peek_pipe.iter_datasets())
-
+            self._remaining_dataset_iter = dataset_iter
         return self._first_dataset
 
     def _write_each_dataset(self, write_fn: Callable[[Dataset[T]], None]) -> None:
diff --git a/python/ray/data/tests/test_dataset_pipeline.py b/python/ray/data/tests/test_dataset_pipeline.py
index a32fce2a6576..869c5e3632a4 100644
--- a/python/ray/data/tests/test_dataset_pipeline.py
+++ b/python/ray/data/tests/test_dataset_pipeline.py
@@ -476,8 +476,8 @@ def test_schema_after_repeat(ray_start_regular_shared):
     assert pipe.schema() == int
     output = []
     for ds in pipe.iter_datasets():
-        output.extend(dataset.take())
-    assert output.sort() == (list(range(6)) * 2).sort()
+        output.extend(ds.take())
+    assert sorted(output) == sorted(list(range(6)) * 2)
 
     pipe = ray.data.range(6, parallelism=6).window(blocks_per_window=2).repeat(2)
     assert pipe.schema() == int
@@ -485,8 +485,8 @@ def test_schema_after_repeat(ray_start_regular_shared):
     pipe = pipe.map_batches(lambda batch: batch)
     output = []
     for ds in pipe.iter_datasets():
-        output.extend(dataset.take())
-    assert output.sort() == (list(range(6)) * 2).sort()
+        output.extend(ds.take())
+    assert sorted(output) == sorted(list(range(6)) * 2)
 
 
 def test_split(ray_start_regular_shared):

From 1b42df2b78db1dd4c09be5e5ce51ba8c196cad02 Mon Sep 17 00:00:00 2001
From: amogkam
Date: Mon, 9 Jan 2023 13:53:23 -0800
Subject: [PATCH 03/14] add comments

Signed-off-by: amogkam
---
 python/ray/data/dataset_pipeline.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/python/ray/data/dataset_pipeline.py b/python/ray/data/dataset_pipeline.py
index f2358defeb50..a5df0969dae2 100644
--- a/python/ray/data/dataset_pipeline.py
+++ b/python/ray/data/dataset_pipeline.py
@@ -1151,8 +1151,9 @@ def iter_datasets(self) -> Iterator[Dataset[T]]:
             raise RuntimeError("Pipeline cannot be read multiple times.")
         self._executed[0] = True
 
-        # Peek has already been executed so we use the cached first dataset and
-        # remaining_dataset_iter.
+        # If the first dataset has already been executed (via a peek operation), then
+        # we don't re-execute the first dataset when iterating through the pipeline.
+        # We re-use the saved _first_dataset and _remaining_dataset_iter
         if self._first_dataset is not None:
 
             class _IterableWrapper(Iterable):
@@ -1291,6 +1292,8 @@ def _peek(self) -> Dataset[T]:
                 length=1,
                 progress_bars=True,
             )
+            # Cache the executed _first_dataset and store an iterator
+            # for the remaining dataset generators.
             self._first_dataset = next(peek_pipe.iter_datasets())
             self._remaining_dataset_iter = dataset_iter
         return self._first_dataset

From b6ed7f5185f5579c225d971ea19082e947ba3b84 Mon Sep 17 00:00:00 2001
From: amogkam
Date: Mon, 9 Jan 2023 14:09:55 -0800
Subject: [PATCH 04/14] fix

Signed-off-by: amogkam
---
 python/ray/data/dataset_pipeline.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/python/ray/data/dataset_pipeline.py b/python/ray/data/dataset_pipeline.py
index a5df0969dae2..9e95fa43f6dd 100644
--- a/python/ray/data/dataset_pipeline.py
+++ b/python/ray/data/dataset_pipeline.py
@@ -79,7 +79,7 @@ def __init__(
         self,
         base_iterable: Iterable[Callable[[], Dataset[T]]],
         stages: List[Callable[[Dataset[Any]], Dataset[Any]]] = None,
-        length: int = None,
+        length: Optional[int] = None,
         progress_bars: bool = progress_bar._enabled,
         _executed: List[bool] = None,
     ):
@@ -1168,7 +1168,7 @@ def __iter__(self):
             remaining_pipeline = DatasetPipeline(
                 base_iterable=_IterableWrapper(self._remaining_dataset_iter),
                 stages=self._stages.copy(),
-                length=self._length - 1,
+                length=(self._length - 1) if self._length else None,
                 progress_bars=True,
             )

From 3f2ec868f903dbc863b146309ca3a9a78308c88e Mon Sep 17 00:00:00 2001
From: amogkam
Date: Tue, 10 Jan 2023 17:08:12 -0800
Subject: [PATCH 05/14] handle stats

Signed-off-by: amogkam
---
 python/ray/data/_internal/stats.py  | 33 +++++++++++++++++++-----
 python/ray/data/dataset_pipeline.py | 40 +++++++++++++++--------------
 2 files changed, 47 insertions(+), 26 deletions(-)

diff --git a/python/ray/data/_internal/stats.py b/python/ray/data/_internal/stats.py
index 95ca2a994b12..ddd9b98fa2e0 100644
--- a/python/ray/data/_internal/stats.py
+++ b/python/ray/data/_internal/stats.py
@@ -424,13 +424,22 @@ def __init__(self, *, max_history: int = 3):
         self.wait_time_s = []
 
         # Iteration stats, filled out if the user iterates over the pipeline.
-        self.iter_ds_wait_s: Timer = Timer()
-        self.iter_wait_s: Timer = Timer()
-        self.iter_get_s: Timer = Timer()
-        self.iter_next_batch_s: Timer = Timer()
-        self.iter_format_batch_s: Timer = Timer()
-        self.iter_user_s: Timer = Timer()
-        self.iter_total_s: Timer = Timer()
+        self._iter_stats = {
+            "iter_ds_wait_s": Timer(),
+            "iter_wait_s": Timer(),
+            "iter_get_s": Timer(),
+            "iter_next_batch_s": Timer(),
+            "iter_format_batch_s": Timer(),
+            "iter_user_s": Timer(),
+            "iter_total_s": Timer(),
+        }
+
+    # Make iteration stats also accessible via attributes.
+    def __getattr__(self, name):
+        if name in self._iter_stats:
+            return self._iter_stats[name]
+
+        raise AttributeError
 
     def add(self, stats: DatasetStats) -> None:
         """Called to add stats for a newly computed window."""
@@ -439,6 +448,16 @@ def add(self, stats: DatasetStats) -> None:
             self.history_buffer.pop(0)
         self.count += 1
 
+    def add_pipeline(self, other_stats: "DatasetPipelineStats"):
+        """Add the provided pipeline stats to the current stats."""
+        for _, dataset_stats in other_stats.history_buffer:
+            self.add(dataset_stats)
+
+        self.wait_time_s.extend(other_stats.wait_time_s)
+
+        for stat_name, timer in self._iter_stats:
+            timer.add(other_stats._iter_stats[stat_name].get())
+
     def _summarize_iter(self) -> str:
         out = ""
         if (
diff --git a/python/ray/data/dataset_pipeline.py b/python/ray/data/dataset_pipeline.py
index 9e95fa43f6dd..73a2138c94f9 100644
--- a/python/ray/data/dataset_pipeline.py
+++ b/python/ray/data/dataset_pipeline.py
@@ -98,7 +98,6 @@ def __init__(
         # Whether the pipeline execution has started.
         # This variable is shared across all pipelines descending from this.
         self._executed = _executed or [False]
-        self._remaining_dataset_iter: Iterator[Callable[[], Dataset]] = None
         self._first_dataset: Dataset = None
         self._schema = None
         self._stats = DatasetPipelineStats()
@@ -1151,34 +1150,35 @@ def iter_datasets(self) -> Iterator[Dataset[T]]:
             raise RuntimeError("Pipeline cannot be read multiple times.")
         self._executed[0] = True
 
+        self._optimize_stages()
+
         # If the first dataset has already been executed (via a peek operation), then
         # we don't re-execute the first dataset when iterating through the pipeline.
-        # We re-use the saved _first_dataset and _remaining_dataset_iter
+        # We re-use the saved _first_dataset.
         if self._first_dataset is not None:
 
             class _IterableWrapper(Iterable):
-                """Wrapper that takes an iterator and converts it to an iterable."""
+                """Wrapper that takes an iterator and converts it to an
+                iterable with the first dataset skipped."""
 
-                def __init__(self, base_iterator):
-                    self.base_iterator = base_iterator
+                def __init__(self, base_iterable):
+                    # Skip the first dataset since it's already been peeked.
+                    self.base_iterator: Iterator = itertools.islice(
+                        self._base_iterator, start=1
+                    )
 
                 def __iter__(self):
                     return self.base_iterator
 
-            remaining_pipeline = DatasetPipeline(
-                base_iterable=_IterableWrapper(self._remaining_dataset_iter),
-                stages=self._stages.copy(),
-                length=(self._length - 1) if self._length else None,
-                progress_bars=True,
-            )
-            iter = itertools.chain(
-                [self._first_dataset], remaining_pipeline.iter_datasets()
-            )
+            # Update the base iterable to skip the first dataset.
+            # It is ok to update the base iterable here since
+            # the dataset can never be # executed again.
+            self._base_iterable = _IterableWrapper(self._base_iterable)
+
+            iter = itertools.chain([self._first_dataset], PipelineExecutor(self))
             self._first_dataset = None
-            self._remaining_dataset_iter = None
             return iter
         else:
-            self._optimize_stages()
             return PipelineExecutor(self)
@@ -1292,8 +1292,9 @@ def _peek(self) -> Dataset[T]:
                 length=1,
                 progress_bars=True,
             )
-            # Cache the executed _first_dataset and store an iterator
-            # for the remaining dataset generators.
+            # Cache the executed _first_dataset.
             self._first_dataset = next(peek_pipe.iter_datasets())
-            self._remaining_dataset_iter = dataset_iter
+
+            # Store the stats from the peek pipeline.
+            self._stats.add_pipeline(peek_pipe._stats)
         return self._first_dataset
 
     def _write_each_dataset(self, write_fn: Callable[[Dataset[T]], None]) -> None:

From 6ebbcc6d9b04a31337d941088d6f1d9a368563a3 Mon Sep 17 00:00:00 2001
From: amogkam
Date: Tue, 10 Jan 2023 17:10:13 -0800
Subject: [PATCH 06/14] update name

Signed-off-by: amogkam
---
 python/ray/data/_internal/stats.py  | 2 +-
 python/ray/data/dataset_pipeline.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/python/ray/data/_internal/stats.py b/python/ray/data/_internal/stats.py
index ddd9b98fa2e0..6640900ac8b5 100644
--- a/python/ray/data/_internal/stats.py
+++ b/python/ray/data/_internal/stats.py
@@ -448,7 +448,7 @@ def add(self, stats: DatasetStats) -> None:
             self.history_buffer.pop(0)
         self.count += 1
 
-    def add_pipeline(self, other_stats: "DatasetPipelineStats"):
+    def add_pipeline_stats(self, other_stats: "DatasetPipelineStats"):
         """Add the provided pipeline stats to the current stats."""
         for _, dataset_stats in other_stats.history_buffer:
             self.add(dataset_stats)
diff --git a/python/ray/data/dataset_pipeline.py b/python/ray/data/dataset_pipeline.py
index 73a2138c94f9..d972674c08d5 100644
--- a/python/ray/data/dataset_pipeline.py
+++ b/python/ray/data/dataset_pipeline.py
@@ -1296,7 +1296,7 @@ def _peek(self) -> Dataset[T]:
             self._first_dataset = next(peek_pipe.iter_datasets())
 
             # Store the stats from the peek pipeline.
-            self._stats.add_pipeline(peek_pipe._stats)
+            self._stats.add_pipeline_stats(peek_pipe._stats)
         return self._first_dataset

From ab9200d65361e68490dc9bae27f97c0f6d934285 Mon Sep 17 00:00:00 2001
From: amogkam
Date: Tue, 10 Jan 2023 18:33:35 -0800
Subject: [PATCH 07/14] infinite recursion

Signed-off-by: amogkam
---
 python/ray/data/_internal/stats.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/python/ray/data/_internal/stats.py b/python/ray/data/_internal/stats.py
index 6640900ac8b5..8837f153d840 100644
--- a/python/ray/data/_internal/stats.py
+++ b/python/ray/data/_internal/stats.py
@@ -436,9 +436,10 @@ def __init__(self, *, max_history: int = 3):
 
     # Make iteration stats also accessible via attributes.
     def __getattr__(self, name):
+        if name == "_iter_stats":
+            raise AttributeError
         if name in self._iter_stats:
             return self._iter_stats[name]
-
         raise AttributeError
 
     def add(self, stats: DatasetStats) -> None:

From f341b24c56dcc40d33059d68645ec323d88044b7 Mon Sep 17 00:00:00 2001
From: amogkam
Date: Wed, 11 Jan 2023 12:02:19 -0800
Subject: [PATCH 08/14] fix

Signed-off-by: amogkam
---
 python/ray/data/_internal/stats.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/ray/data/_internal/stats.py b/python/ray/data/_internal/stats.py
index 8837f153d840..daa037db07af 100644
--- a/python/ray/data/_internal/stats.py
+++ b/python/ray/data/_internal/stats.py
@@ -456,7 +456,7 @@ def add_pipeline_stats(self, other_stats: "DatasetPipelineStats"):
 
         self.wait_time_s.extend(other_stats.wait_time_s)
 
-        for stat_name, timer in self._iter_stats:
+        for stat_name, timer in self._iter_stats.items():
             timer.add(other_stats._iter_stats[stat_name].get())
 
     def _summarize_iter(self) -> str:

From baf10085c828b4dd3927c8d3c7d799c9fa99c0ed Mon Sep 17 00:00:00 2001
From: amogkam
Date: Wed, 11 Jan 2023 12:33:39 -0800
Subject: [PATCH 09/14] fix

Signed-off-by: amogkam
---
 python/ray/data/dataset_pipeline.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/python/ray/data/dataset_pipeline.py b/python/ray/data/dataset_pipeline.py
index d972674c08d5..9e20f0c028c4 100644
--- a/python/ray/data/dataset_pipeline.py
+++ b/python/ray/data/dataset_pipeline.py
@@ -1164,7 +1164,7 @@ class _IterableWrapper(Iterable):
                 def __init__(self, base_iterable):
                     # Skip the first dataset since it's already been peeked.
                     self.base_iterator: Iterator = itertools.islice(
-                        self._base_iterator, start=1
+                        base_iterable, start=1
                     )
 
                 def __iter__(self):
@@ -1172,7 +1172,7 @@ def __iter__(self):
 
             # Update the base iterable to skip the first dataset.
             # It is ok to update the base iterable here since
-            # the dataset can never be # executed again.
+            # the pipeline can never be executed again.
             self._base_iterable = _IterableWrapper(self._base_iterable)

From 26a405a0bbd99dac617708f0261a63ded01c4ebd Mon Sep 17 00:00:00 2001
From: amogkam
Date: Wed, 11 Jan 2023 14:39:43 -0800
Subject: [PATCH 10/14] fix

Signed-off-by: amogkam
---
 python/ray/data/dataset_pipeline.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/ray/data/dataset_pipeline.py b/python/ray/data/dataset_pipeline.py
index 9e20f0c028c4..51fc82733d97 100644
--- a/python/ray/data/dataset_pipeline.py
+++ b/python/ray/data/dataset_pipeline.py
@@ -1164,7 +1164,7 @@ class _IterableWrapper(Iterable):
                 def __init__(self, base_iterable):
                     # Skip the first dataset since it's already been peeked.
                     self.base_iterator: Iterator = itertools.islice(
-                        base_iterable, start=1
+                        base_iterable, 1, None
                     )
 
                 def __iter__(self):

From ce8cd1f644d1ce553e06800120216ddc3e309591 Mon Sep 17 00:00:00 2001
From: amogkam
Date: Wed, 11 Jan 2023 17:29:16 -0800
Subject: [PATCH 11/14] update test

Signed-off-by: amogkam
---
 python/ray/data/tests/test_dataset_pipeline.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/python/ray/data/tests/test_dataset_pipeline.py b/python/ray/data/tests/test_dataset_pipeline.py
index 869c5e3632a4..20264e65260d 100644
--- a/python/ray/data/tests/test_dataset_pipeline.py
+++ b/python/ray/data/tests/test_dataset_pipeline.py
@@ -775,7 +775,11 @@ def __next__(self):
             return lambda: ds
 
     p1 = ray.data.range(10).repeat()
-    p2 = DatasetPipeline.from_iterable(Iterable(p1.iter_datasets()))
+    # Start the pipeline.
+    data_iter = p1.iter_datasets()
+    next(data_iter)
+
+    p2 = DatasetPipeline.from_iterable(Iterable(data_iter))
     with pytest.raises(RuntimeError) as error:
         p2.split(2)
     assert "PipelineExecutor is not serializable once it has started" in str(error)

From 2a9111f3a4f761a19d808c40da0bb51de4b82ab8 Mon Sep 17 00:00:00 2001
From: amogkam
Date: Wed, 11 Jan 2023 18:31:37 -0800
Subject: [PATCH 12/14] update

Signed-off-by: amogkam
---
 python/ray/data/dataset_pipeline.py | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/python/ray/data/dataset_pipeline.py b/python/ray/data/dataset_pipeline.py
index 51fc82733d97..cb10107b9489 100644
--- a/python/ray/data/dataset_pipeline.py
+++ b/python/ray/data/dataset_pipeline.py
@@ -98,7 +98,8 @@ def __init__(
         # Whether the pipeline execution has started.
         # This variable is shared across all pipelines descending from this.
         self._executed = _executed or [False]
-        self._first_dataset: Dataset = None
+        self._first_dataset: Optional[Dataset] = None
+        self._remaining_datasets_iter: Optional[Iterator[Callable[[], Dataset]]] = None
         self._schema = None
         self._stats = DatasetPipelineStats()
@@ -1154,18 +1155,15 @@ def iter_datasets(self) -> Iterator[Dataset[T]]:
 
         # If the first dataset has already been executed (via a peek operation), then
        # we don't re-execute the first dataset when iterating through the pipeline.
-        # We re-use the saved _first_dataset.
+        # We re-use the saved _first_dataset and _remaining_dataset_iter.
         if self._first_dataset is not None:
 
             class _IterableWrapper(Iterable):
                 """Wrapper that takes an iterator and converts it to an
                 iterable with the first dataset skipped."""
 
-                def __init__(self, base_iterable):
-                    # Skip the first dataset since it's already been peeked.
-                    self.base_iterator: Iterator = itertools.islice(
-                        base_iterable, 1, None
-                    )
+                def __init__(self, base_iterator):
+                    self.base_iterator = base_iterator
 
                 def __iter__(self):
                     return self.base_iterator
@@ -1173,10 +1171,11 @@ def __iter__(self):
             # Update the base iterable to skip the first dataset.
             # It is ok to update the base iterable here since
             # the pipeline can never be executed again.
-            self._base_iterable = _IterableWrapper(self._base_iterable)
+            self._base_iterable = _IterableWrapper(self._remaining_datasets_iter)
 
             iter = itertools.chain([self._first_dataset], PipelineExecutor(self))
             self._first_dataset = None
+            self._remaining_datasets_iter = None
             return iter
         else:
             return PipelineExecutor(self)
@@ -1294,6 +1293,7 @@ def _peek(self) -> Dataset[T]:
             )
             # Cache the executed _first_dataset.
             self._first_dataset = next(peek_pipe.iter_datasets())
+            self._remaining_datasets_iter = dataset_iter
 
             # Store the stats from the peek pipeline.
             self._stats.add_pipeline_stats(peek_pipe._stats)
 
         return self._first_dataset
 
     def _write_each_dataset(self, write_fn: Callable[[Dataset[T]], None]) -> None:

From 9c0161d3cc6e3c7a6644bdc5c0a08ceb9d35a3da Mon Sep 17 00:00:00 2001
From: amogkam
Date: Wed, 11 Jan 2023 19:54:29 -0800
Subject: [PATCH 13/14] comment

Signed-off-by: amogkam
---
 python/ray/data/dataset_pipeline.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/ray/data/dataset_pipeline.py b/python/ray/data/dataset_pipeline.py
index cb10107b9489..6ef06bc9812e 100644
--- a/python/ray/data/dataset_pipeline.py
+++ b/python/ray/data/dataset_pipeline.py
@@ -1160,7 +1160,7 @@ def iter_datasets(self) -> Iterator[Dataset[T]]:
 
             class _IterableWrapper(Iterable):
                 """Wrapper that takes an iterator and converts it to an
-                iterable with the first dataset skipped."""
+                iterable."""
 
                 def __init__(self, base_iterator):
                     self.base_iterator = base_iterator

From 8a648ef4cd1a23cb226f51b573b73c701b87a5e7 Mon Sep 17 00:00:00 2001
From: amogkam
Date: Thu, 12 Jan 2023 17:18:17 -0800
Subject: [PATCH 14/14] comment

Signed-off-by: amogkam
---
 python/ray/data/_internal/stats.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/python/ray/data/_internal/stats.py b/python/ray/data/_internal/stats.py
index daa037db07af..28cc7d2041ee 100644
--- a/python/ray/data/_internal/stats.py
+++ b/python/ray/data/_internal/stats.py
@@ -450,7 +450,11 @@ def add(self, stats: DatasetStats) -> None:
         self.count += 1
 
     def add_pipeline_stats(self, other_stats: "DatasetPipelineStats"):
-        """Add the provided pipeline stats to the current stats."""
+        """Add the provided pipeline stats to the current stats.
+
+        `other_stats` should cover a set of windows disjoint from
+        the current stats.
+        """
         for _, dataset_stats in other_stats.history_buffer:
             self.add(dataset_stats)
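Note: PATCH 05 and PATCH 07 together implement attribute delegation with a recursion guard. `__getattr__` is only invoked when normal attribute lookup fails, so if the backing dict itself is absent (for example while an instance is being unpickled, before its state has been restored), looking it up inside `__getattr__` would re-enter `__getattr__` and recurse forever. A standalone sketch of the pattern, using a hypothetical class rather than the real DatasetPipelineStats:

    class Stats:
        def __init__(self):
            self._iter_stats = {"iter_wait_s": 0.0, "iter_user_s": 0.0}

        def __getattr__(self, name):
            # Guard first: without this, a missing _iter_stats attribute
            # would trigger __getattr__("_iter_stats") and recurse infinitely.
            if name == "_iter_stats":
                raise AttributeError(name)
            if name in self._iter_stats:
                return self._iter_stats[name]
            raise AttributeError(name)

    s = Stats()
    assert s.iter_wait_s == 0.0   # delegated to the backing dict
    try:
        s.missing_stat
    except AttributeError:
        pass                      # unknown names still fail normally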