[Data] Don't drop first dataset when peeking DatasetPipeline #31513

Merged · 16 commits · Jan 18, 2023
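For context, the behavior this PR targets can be summarized with a short example. This is a sketch based on the tests added below, assuming the `ray.data` DatasetPipeline API of this release:

```python
import ray

# A pipeline of three windows (two blocks each).
pipe = ray.data.range(6, parallelism=6).window(blocks_per_window=2)

# schema() peeks at the pipeline, which executes the first window.
assert pipe.schema() == int

# With this change, iterating afterwards still yields every window;
# previously the peeked first window could be dropped from the iteration,
# which is the issue this PR fixes.
total = sum(ds.count() for ds in pipe.iter_datasets())
assert total == 6
```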
34 changes: 27 additions & 7 deletions python/ray/data/_internal/stats.py
@@ -424,13 +424,23 @@ def __init__(self, *, max_history: int = 3):
         self.wait_time_s = []
 
         # Iteration stats, filled out if the user iterates over the pipeline.
-        self.iter_ds_wait_s: Timer = Timer()
-        self.iter_wait_s: Timer = Timer()
-        self.iter_get_s: Timer = Timer()
-        self.iter_next_batch_s: Timer = Timer()
-        self.iter_format_batch_s: Timer = Timer()
-        self.iter_user_s: Timer = Timer()
-        self.iter_total_s: Timer = Timer()
+        self._iter_stats = {
+            "iter_ds_wait_s": Timer(),
+            "iter_wait_s": Timer(),
+            "iter_get_s": Timer(),
+            "iter_next_batch_s": Timer(),
+            "iter_format_batch_s": Timer(),
+            "iter_user_s": Timer(),
+            "iter_total_s": Timer(),
+        }
+
+    # Make iteration stats also accessible via attributes.
+    def __getattr__(self, name):
+        if name == "_iter_stats":
+            raise AttributeError
+        if name in self._iter_stats:
+            return self._iter_stats[name]
+        raise AttributeError
 
     def add(self, stats: DatasetStats) -> None:
         """Called to add stats for a newly computed window."""
@@ -439,6 +449,16 @@ def add(self, stats: DatasetStats) -> None:
             self.history_buffer.pop(0)
         self.count += 1
 
+    def add_pipeline_stats(self, other_stats: "DatasetPipelineStats"):
"""Add the provided pipeline stats to the current stats."""
for _, dataset_stats in other_stats.history_buffer:
self.add(dataset_stats)

self.wait_time_s.extend(other_stats.wait_time_s)

for stat_name, timer in self._iter_stats.items():
timer.add(other_stats._iter_stats[stat_name].get())

def _summarize_iter(self) -> str:
out = ""
if (
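The timers now live in a dict, with `__getattr__` preserving the old attribute-style access (`stats.iter_total_s` and friends), which is what lets `add_pipeline_stats` merge two stat objects key by key. A self-contained sketch of the pattern — the `Timer` here is a simplified stand-in, not Ray's class:

```python
class Timer:
    """Simplified stand-in for the Timer used in ray.data._internal.stats."""

    def __init__(self) -> None:
        self._total = 0.0

    def add(self, value: float) -> None:
        self._total += value

    def get(self) -> float:
        return self._total


class PipelineStats:
    def __init__(self) -> None:
        self._iter_stats = {"iter_total_s": Timer(), "iter_wait_s": Timer()}

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails; guard against
        # recursing before _iter_stats exists (e.g. during unpickling).
        if name == "_iter_stats":
            raise AttributeError(name)
        if name in self._iter_stats:
            return self._iter_stats[name]
        raise AttributeError(name)

    def add_pipeline_stats(self, other: "PipelineStats") -> None:
        # Fold the other object's accumulated times into ours, key by key.
        for stat_name, timer in self._iter_stats.items():
            timer.add(other._iter_stats[stat_name].get())


main, peeked = PipelineStats(), PipelineStats()
peeked.iter_total_s.add(1.5)      # attribute access goes through __getattr__
main.add_pipeline_stats(peeked)
assert main.iter_total_s.get() == 1.5
```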
2 changes: 1 addition & 1 deletion python/ray/data/dataset.py
@@ -3589,7 +3589,7 @@ def __init__(self, blocks):
                 self._blocks = blocks
                 self._i = 0
 
-            def __next__(self) -> "Dataset[T]":
+            def __next__(self) -> Callable[[], "Dataset[T]"]:
                 if times and self._i >= times:
                     raise StopIteration
                 epoch = self._i
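The annotation fix in dataset.py reflects the pipeline contract: the base iterable yields zero-argument factories, and a window is only materialized when the executor calls one. A plain-Python sketch of that contract (the names here are illustrative, not Ray internals):

```python
from typing import Callable, Iterator, List


def window_factories(n: int) -> Iterator[Callable[[], List[int]]]:
    """Yield zero-arg factories; no window is computed until a factory is called."""
    for i in range(n):
        yield lambda i=i: list(range(i * 2, i * 2 + 2))  # default arg binds i per factory


factories = window_factories(3)
first_factory = next(factories)  # cheap: only the callable is taken
print(first_factory())           # the window is produced here -> [0, 1]
```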
70 changes: 53 additions & 17 deletions python/ray/data/dataset_pipeline.py
@@ -79,7 +79,7 @@ def __init__(
         self,
         base_iterable: Iterable[Callable[[], Dataset[T]]],
         stages: List[Callable[[Dataset[Any]], Dataset[Any]]] = None,
-        length: int = None,
+        length: Optional[int] = None,
         progress_bars: bool = progress_bar._enabled,
         _executed: List[bool] = None,
     ):
@@ -98,8 +98,8 @@ def __init__(
         # Whether the pipeline execution has started.
         # This variable is shared across all pipelines descending from this.
         self._executed = _executed or [False]
-        self._dataset_iter = None
-        self._first_dataset = None
+        self._first_dataset: Optional[Dataset] = None
+        self._remaining_datasets_iter: Optional[Iterator[Callable[[], Dataset]]] = None
         self._schema = None
         self._stats = DatasetPipelineStats()
 
@@ -481,7 +481,7 @@ def __init__(self, original_iter):
                 # This is calculated later.
                 self._max_i = None
 
-            def __next__(self) -> Dataset[T]:
+            def __next__(self) -> Callable[[], Dataset[T]]:
                 # Still going through the original pipeline.
                 if self._original_iter:
                     try:
Expand Down Expand Up @@ -517,11 +517,11 @@ def gen():
raise StopIteration

class RepeatIterable:
def __init__(self, original_iter):
self._original_iter = original_iter
def __init__(self, original_iterable):
self._original_iterable = original_iterable

def __iter__(self):
return RepeatIterator(self._original_iter)
return RepeatIterator(iter(self._original_iterable))

if not times:
length = float("inf")
@@ -531,7 +531,7 @@ def __iter__(self):
             length = None
 
         return DatasetPipeline(
-            RepeatIterable(iter(self._base_iterable)),
+            RepeatIterable(self._base_iterable),
             stages=self._stages.copy(),
             length=length,
         )
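`RepeatIterable` now stores the iterable and builds a fresh iterator on every `__iter__`, rather than holding one pre-built iterator that all callers would share. Presumably this matters because the new peek path calls `iter(self._base_iterable)` itself, separately from the executor. A Ray-free illustration of the difference:

```python
class Repeat:
    def __init__(self, iterable):
        self._iterable = iterable      # keep the iterable, not an iterator

    def __iter__(self):
        return iter(self._iterable)    # fresh iterator per caller


windows = ["w0", "w1", "w2"]
r = Repeat(windows)
assert list(r) == ["w0", "w1", "w2"]
assert list(r) == ["w0", "w1", "w2"]   # a second pass still sees everything

# Holding a single iterator instead would make callers share (and exhaust) it:
shared = iter(windows)
assert list(shared) == ["w0", "w1", "w2"]
assert list(shared) == []              # already exhausted
```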
@@ -1150,12 +1150,35 @@ def iter_datasets(self) -> Iterator[Dataset[T]]:
         if self._executed[0]:
             raise RuntimeError("Pipeline cannot be read multiple times.")
         self._executed[0] = True
-        if self._first_dataset is None:
-            self._peek()
-        iter = itertools.chain([self._first_dataset], self._dataset_iter)
-        self._first_dataset = None
-        self._dataset_iter = None
-        return iter
+
+        self._optimize_stages()
+
+        # If the first dataset has already been executed (via a peek operation), then
+        # we don't re-execute the first dataset when iterating through the pipeline.
+        # We re-use the saved _first_dataset and _remaining_dataset_iter.
+        if self._first_dataset is not None:
+
+            class _IterableWrapper(Iterable):
[Review comment from a contributor on `_IterableWrapper`] Is this wrapping needed, since an iterator itself has `__iter__` to return itself?

"""Wrapper that takes an iterator and converts it to an
iterable."""

def __init__(self, base_iterator):
self.base_iterator = base_iterator

def __iter__(self):
return self.base_iterator

# Update the base iterable to skip the first dataset.
# It is ok to update the base iterable here since
# the pipeline can never be executed again.
self._base_iterable = _IterableWrapper(self._remaining_datasets_iter)

iter = itertools.chain([self._first_dataset], PipelineExecutor(self))
self._first_dataset = None
self._remaining_datasets_iter = None
return iter
else:
return PipelineExecutor(self)

@DeveloperAPI
def foreach_window(
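When a window was already produced by a peek, `iter_datasets` stitches it back in front of the untouched remainder: the saved iterator is wrapped so it can serve as the pipeline's `_base_iterable`, and `itertools.chain` prepends the cached first dataset. A condensed, Ray-free sketch of that resume step (a single-use wrapper is enough here because the pipeline can only be executed once):

```python
import itertools
from typing import Iterable, Iterator


class IteratorAsIterable(Iterable):
    """Expose an existing iterator through the Iterable interface (single use)."""

    def __init__(self, base_iterator: Iterator):
        self.base_iterator = base_iterator

    def __iter__(self) -> Iterator:
        return self.base_iterator


windows = iter(["w0", "w1", "w2"])
first = next(windows)                         # the window consumed by the peek

remaining = IteratorAsIterable(windows)       # stands in for the new _base_iterable
resumed = itertools.chain([first], iter(remaining))
assert list(resumed) == ["w0", "w1", "w2"]    # nothing dropped, nothing re-executed
```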
@@ -1171,6 +1194,7 @@ def foreach_window(
         """
         if self._executed[0]:
             raise RuntimeError("Pipeline cannot be read multiple times.")
+
         return DatasetPipeline(
             self._base_iterable,
             self._stages + [fn],
@@ -1259,9 +1283,21 @@ def add_stage(ds, stage):
 
     def _peek(self) -> Dataset[T]:
         if self._first_dataset is None:
-            self._optimize_stages()
-            self._dataset_iter = PipelineExecutor(self)
-            self._first_dataset = next(self._dataset_iter)
+            dataset_iter = iter(self._base_iterable)
+            first_dataset_gen = next(dataset_iter)
+            peek_pipe = DatasetPipeline(
+                base_iterable=[first_dataset_gen],
+                stages=self._stages.copy(),
+                length=1,
+                progress_bars=True,
+            )
+            # Cache the executed _first_dataset.
+            self._first_dataset = next(peek_pipe.iter_datasets())
+            self._remaining_datasets_iter = dataset_iter
+
+            # Store the stats from the peek pipeline.
+            self._stats.add_pipeline_stats(peek_pipe._stats)
+
         return self._first_dataset
 
     def _write_each_dataset(self, write_fn: Callable[[Dataset[T]], None]) -> None:
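Taken together, `_peek` now pulls only the first factory off the base iterable, runs it through a throwaway length-1 pipeline (so the configured stages still apply), caches the resulting dataset plus the untouched remainder of the iterator, and merges the throwaway pipeline's stats back in via `add_pipeline_stats`; `iter_datasets` later resumes from that cached state. A simplified, non-Ray rendering of the control flow — the class below mirrors the PR's logic but is a stand-in, not Ray's implementation:

```python
import itertools
from typing import Callable, Iterator, List, Optional


class MiniPipeline:
    """Toy pipeline: windows come from zero-arg factories, as in DatasetPipeline."""

    def __init__(self, factories):
        self._base_iterable = factories
        self._first_window: Optional[List[int]] = None
        self._remaining_iter: Optional[Iterator[Callable[[], List[int]]]] = None

    def peek(self) -> List[int]:
        if self._first_window is None:
            it = iter(self._base_iterable)
            first_factory = next(it)               # take only the first factory
            self._first_window = first_factory()   # execute just that window
            self._remaining_iter = it              # keep the rest untouched
        return self._first_window

    def iter_windows(self) -> Iterator[List[int]]:
        if self._first_window is not None:
            # Resume: cached first window first, then the remaining factories.
            resumed = itertools.chain(
                [self._first_window], (f() for f in self._remaining_iter)
            )
            self._first_window, self._remaining_iter = None, None
            return resumed
        return (f() for f in iter(self._base_iterable))


pipe = MiniPipeline([lambda i=i: [i, i + 1] for i in range(0, 6, 2)])
assert pipe.peek() == [0, 1]                                   # schema()-style peek
assert list(pipe.iter_windows()) == [[0, 1], [2, 3], [4, 5]]   # first window kept
```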
24 changes: 23 additions & 1 deletion python/ray/data/tests/test_dataset_pipeline.py
@@ -471,6 +471,24 @@ def test_schema_peek(ray_start_regular_shared):
     assert pipe.schema() is None
 
 
+def test_schema_after_repeat(ray_start_regular_shared):
+    pipe = ray.data.range(6, parallelism=6).window(blocks_per_window=2).repeat(2)
+    assert pipe.schema() == int
+    output = []
+    for ds in pipe.iter_datasets():
+        output.extend(ds.take())
+    assert sorted(output) == sorted(list(range(6)) * 2)
+
+    pipe = ray.data.range(6, parallelism=6).window(blocks_per_window=2).repeat(2)
+    assert pipe.schema() == int
+    # Test that operations still work after peek.
+    pipe = pipe.map_batches(lambda batch: batch)
+    output = []
+    for ds in pipe.iter_datasets():
+        output.extend(ds.take())
+    assert sorted(output) == sorted(list(range(6)) * 2)
+
+
 def test_split(ray_start_regular_shared):
     pipe = ray.data.range(3).map(lambda x: x + 1).repeat(10)
 
@@ -757,7 +775,11 @@ def __next__(self):
             return lambda: ds
 
     p1 = ray.data.range(10).repeat()
-    p2 = DatasetPipeline.from_iterable(Iterable(p1.iter_datasets()))
+    # Start the pipeline.
+    data_iter = p1.iter_datasets()
+    next(data_iter)
+
+    p2 = DatasetPipeline.from_iterable(Iterable(data_iter))
     with pytest.raises(RuntimeError) as error:
         p2.split(2)
     assert "PipelineExecutor is not serializable once it has started" in str(error)