[Data] Don't drop first dataset when peeking DatasetPipeline (ray-project#31513)

Signed-off-by: amogkam [email protected]

Closes ray-project#31505.

When peeking a DatasetPipeline, for example via .schema(), the first dataset in the base iterator is consumed. Then, when new operations such as map_batches are chained onto the pipeline, the dataset that was peeked is lost.

In this PR, we change the implementation of peek so that it does not consume the base iterable; instead, it builds a new iterable consisting of just the first dataset and keeps the rest of the base iterator around for later iteration.
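To illustrate the idea, here is a minimal, self-contained sketch in plain Python (not Ray's actual API; TinyPipeline, peek, and iter_windows are hypothetical names used only for this example): peek executes just the first window from its own one-element iterable and stashes the remaining iterator, so later iteration can chain the cached window back in front instead of dropping it.

from itertools import chain
from typing import Callable, Iterator, List, Optional


class TinyPipeline:
    """Hypothetical stand-in for a windowed pipeline of dataset factories."""

    def __init__(self, factories: List[Callable[[], list]]):
        self._factories = factories          # callables that each produce one "dataset" window
        self._first: Optional[list] = None   # cached result of peeking the first window
        self._rest: Optional[Iterator[Callable[[], list]]] = None

    def peek(self) -> list:
        # Consume only the first factory; keep the untouched remainder of the iterator.
        if self._first is None:
            it = iter(self._factories)
            self._first = next(it)()
            self._rest = it
        return self._first

    def iter_windows(self) -> Iterator[list]:
        # Re-attach the peeked window in front of the remaining ones.
        if self._first is not None:
            first, rest = self._first, self._rest
            self._first = self._rest = None
            return chain([first], (f() for f in rest))
        return (f() for f in self._factories)


pipe = TinyPipeline([lambda: [0, 1], lambda: [2, 3]])
assert pipe.peek() == [0, 1]                         # peeking no longer drops the first window
assert list(pipe.iter_windows()) == [[0, 1], [2, 3]]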

Signed-off-by: Andrea Pisoni <[email protected]>
amogkam authored and andreapiso committed Jan 22, 2023
1 parent f025fd1 commit b5796ec
Showing 4 changed files with 108 additions and 26 deletions.
38 changes: 31 additions & 7 deletions python/ray/data/_internal/stats.py
@@ -644,13 +644,23 @@ def __init__(self, *, max_history: int = 3):
self.wait_time_s = []

# Iteration stats, filled out if the user iterates over the pipeline.
self.iter_ds_wait_s: Timer = Timer()
self.iter_wait_s: Timer = Timer()
self.iter_get_s: Timer = Timer()
self.iter_next_batch_s: Timer = Timer()
self.iter_format_batch_s: Timer = Timer()
self.iter_user_s: Timer = Timer()
self.iter_total_s: Timer = Timer()
self._iter_stats = {
"iter_ds_wait_s": Timer(),
"iter_wait_s": Timer(),
"iter_get_s": Timer(),
"iter_next_batch_s": Timer(),
"iter_format_batch_s": Timer(),
"iter_user_s": Timer(),
"iter_total_s": Timer(),
}

# Make iteration stats also accessible via attributes.
def __getattr__(self, name):
if name == "_iter_stats":
raise AttributeError
if name in self._iter_stats:
return self._iter_stats[name]
raise AttributeError

def add(self, stats: DatasetStats) -> None:
"""Called to add stats for a newly computed window."""
@@ -659,6 +669,20 @@ def add(self, stats: DatasetStats) -> None:
self.history_buffer.pop(0)
self.count += 1

def add_pipeline_stats(self, other_stats: "DatasetPipelineStats"):
"""Add the provided pipeline stats to the current stats.
`other_stats` should cover a set of windows disjoint from those in the current stats.
"""
for _, dataset_stats in other_stats.history_buffer:
self.add(dataset_stats)

self.wait_time_s.extend(other_stats.wait_time_s)

for stat_name, timer in self._iter_stats.items():
timer.add(other_stats._iter_stats[stat_name].get())

def _summarize_iter(self) -> str:
out = ""
if (
2 changes: 1 addition & 1 deletion python/ray/data/dataset.py
@@ -3622,7 +3622,7 @@ def __init__(self, blocks):
self._blocks = blocks
self._i = 0

def __next__(self) -> "Dataset[T]":
def __next__(self) -> Callable[[], "Dataset[T]"]:
if times and self._i >= times:
raise StopIteration
epoch = self._i
70 changes: 53 additions & 17 deletions python/ray/data/dataset_pipeline.py
@@ -82,7 +82,7 @@ def __init__(
self,
base_iterable: Iterable[Callable[[], Dataset[T]]],
stages: List[Callable[[Dataset[Any]], Dataset[Any]]] = None,
length: int = None,
length: Optional[int] = None,
progress_bars: bool = progress_bar._enabled,
_executed: List[bool] = None,
):
@@ -101,8 +101,8 @@ def __init__(self):
# Whether the pipeline execution has started.
# This variable is shared across all pipelines descending from this.
self._executed = _executed or [False]
self._dataset_iter = None
self._first_dataset = None
self._first_dataset: Optional[Dataset] = None
self._remaining_datasets_iter: Optional[Iterator[Callable[[], Dataset]]] = None
self._schema = None
self._stats = DatasetPipelineStats()

@@ -504,7 +504,7 @@ def __init__(self, original_iter):
# This is calculated later.
self._max_i = None

def __next__(self) -> Dataset[T]:
def __next__(self) -> Callable[[], Dataset[T]]:
# Still going through the original pipeline.
if self._original_iter:
try:
@@ -540,11 +540,11 @@ def gen():
raise StopIteration

class RepeatIterable:
def __init__(self, original_iter):
self._original_iter = original_iter
def __init__(self, original_iterable):
self._original_iterable = original_iterable

def __iter__(self):
return RepeatIterator(self._original_iter)
return RepeatIterator(iter(self._original_iterable))

if not times:
length = float("inf")
@@ -554,7 +554,7 @@ def __iter__(self):
length = None

return DatasetPipeline(
RepeatIterable(iter(self._base_iterable)),
RepeatIterable(self._base_iterable),
stages=self._stages.copy(),
length=length,
)
@@ -1180,12 +1180,35 @@ def iter_datasets(self) -> Iterator[Dataset[T]]:
if self._executed[0]:
raise RuntimeError("Pipeline cannot be read multiple times.")
self._executed[0] = True
if self._first_dataset is None:
self._peek()
iter = itertools.chain([self._first_dataset], self._dataset_iter)
self._first_dataset = None
self._dataset_iter = None
return iter

self._optimize_stages()

# If the first dataset has already been executed (via a peek operation), then
# we don't re-execute the first dataset when iterating through the pipeline.
# We re-use the saved _first_dataset and _remaining_datasets_iter.
if self._first_dataset is not None:

class _IterableWrapper(Iterable):
"""Wrapper that takes an iterator and converts it to an
iterable."""

def __init__(self, base_iterator):
self.base_iterator = base_iterator

def __iter__(self):
return self.base_iterator

# Update the base iterable to skip the first dataset.
# It is ok to update the base iterable here since
# the pipeline can never be executed again.
self._base_iterable = _IterableWrapper(self._remaining_datasets_iter)

iter = itertools.chain([self._first_dataset], PipelineExecutor(self))
self._first_dataset = None
self._remaining_datasets_iter = None
return iter
else:
return PipelineExecutor(self)

@DeveloperAPI
def foreach_window(
@@ -1201,6 +1224,7 @@ def foreach_window(
"""
if self._executed[0]:
raise RuntimeError("Pipeline cannot be read multiple times.")

return DatasetPipeline(
self._base_iterable,
self._stages + [fn],
@@ -1289,9 +1313,21 @@ def add_stage(ds, stage):

def _peek(self) -> Dataset[T]:
if self._first_dataset is None:
self._optimize_stages()
self._dataset_iter = PipelineExecutor(self)
self._first_dataset = next(self._dataset_iter)
dataset_iter = iter(self._base_iterable)
first_dataset_gen = next(dataset_iter)
peek_pipe = DatasetPipeline(
base_iterable=[first_dataset_gen],
stages=self._stages.copy(),
length=1,
progress_bars=True,
)
# Cache the executed _first_dataset.
self._first_dataset = next(peek_pipe.iter_datasets())
self._remaining_datasets_iter = dataset_iter

# Store the stats from the peek pipeline.
self._stats.add_pipeline_stats(peek_pipe._stats)

return self._first_dataset

def _write_each_dataset(self, write_fn: Callable[[Dataset[T]], None]) -> None:
24 changes: 23 additions & 1 deletion python/ray/data/tests/test_dataset_pipeline.py
@@ -471,6 +471,24 @@ def test_schema_peek(ray_start_regular_shared):
assert pipe.schema() is None


def test_schema_after_repeat(ray_start_regular_shared):
pipe = ray.data.range(6, parallelism=6).window(blocks_per_window=2).repeat(2)
assert pipe.schema() == int
output = []
for ds in pipe.iter_datasets():
output.extend(ds.take())
assert sorted(output) == sorted(list(range(6)) * 2)

pipe = ray.data.range(6, parallelism=6).window(blocks_per_window=2).repeat(2)
assert pipe.schema() == int
# Test that operations still work after peek.
pipe = pipe.map_batches(lambda batch: batch)
output = []
for ds in pipe.iter_datasets():
output.extend(ds.take())
assert sorted(output) == sorted(list(range(6)) * 2)


def test_split(ray_start_regular_shared):
pipe = ray.data.range(3).map(lambda x: x + 1).repeat(10)

@@ -757,7 +775,11 @@ def __next__(self):
return lambda: ds

p1 = ray.data.range(10).repeat()
p2 = DatasetPipeline.from_iterable(Iterable(p1.iter_datasets()))
# Start the pipeline.
data_iter = p1.iter_datasets()
next(data_iter)

p2 = DatasetPipeline.from_iterable(Iterable(data_iter))
with pytest.raises(RuntimeError) as error:
p2.split(2)
assert "PipelineExecutor is not serializable once it has started" in str(error)
