[Datasets] Provide more efficient + intuitive block clearing semantics for different execution modes #24127

Merged · 2 commits · Apr 30, 2022
27 changes: 17 additions & 10 deletions python/ray/data/dataset.py
@@ -133,8 +133,7 @@ def __init__(
         self._lazy = lazy

         if not lazy:
-            # TODO(ekl) we should clear inputs once we have full lineage recorded.
-            self._plan.execute(clear_input_blocks=False)
+            self._plan.execute(allow_clear_input_blocks=False)

     @staticmethod
     def copy(dataset: "Dataset[T]") -> "Dataset[T]":
@@ -2667,7 +2666,7 @@ def repeat(self, times: Optional[int] = None) -> "DatasetPipeline[T]":

         ctx = DatasetContext.get_current()
         if self._plan.is_read_stage() and ctx.optimize_fuse_read_stages:
-            blocks, _ = self._plan._get_source_blocks()
+            blocks, _, _ = self._plan._get_source_blocks_and_stages()
             blocks.clear()
             blocks, outer_stats, read_stage = _rewrite_read_stage(blocks)
         else:
@@ -2786,7 +2785,7 @@ def window(

         ctx = DatasetContext.get_current()
         if self._plan.is_read_stage() and ctx.optimize_fuse_read_stages:
-            blocks, _ = self._plan._get_source_blocks()
+            blocks, _, _ = self._plan._get_source_blocks_and_stages()
             blocks.clear()
             blocks, outer_stats, read_stage = _rewrite_read_stage(blocks)
         else:
@@ -2807,7 +2806,7 @@ def __next__(self) -> "Dataset[T]":

                 def gen():
                     ds = Dataset(
-                        ExecutionPlan(blocks, outer_stats), self._epoch, lazy=False
+                        ExecutionPlan(blocks, outer_stats), self._epoch, lazy=True
                     )
                     return ds
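Note on this hunk: the per-window dataset returned by gen() is now constructed lazily (lazy=True), so nothing executes when the window is created; execution is instead forced by pipeline_stage in python/ray/data/impl/pipeline_executor.py below, pairing lazy construction with eager in-place materialization.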

@@ -2861,19 +2860,27 @@ def __iter__(self):
         )
         return pipe

-    def fully_executed(self) -> "Dataset[T]":
+    def fully_executed(self, preserve_original: bool = True) -> "Dataset[T]":
         """Force full evaluation of the blocks of this dataset.

         This can be used to read all blocks into memory. By default, Datasets
         doesn't read blocks from the datasource until the first transform.

+        Args:
+            preserve_original: Whether the original unexecuted dataset should be
[Inline review thread on the new `preserve_original` docstring]

Contributor: What is the "original dataset" for a dataset? (I think users may have the same question when they read this API.)

Contributor Author: Ah, I thought that when calling `d2 = ds.fully_executed(preserve_original=False)`, "original" referring to `ds` would be obvious. 🤔

Contributor: I see. I had a chain of operations in mind; in that case it is unclear which of them should be thought of as the "original" (maybe the first one? maybe each intermediate one? or others?).

+                preserved. If False, this function will mutate the original dataset,
+                which can more efficiently reclaim memory.
+
         Returns:
             A Dataset with all blocks fully materialized in memory.
         """
-        plan = self._plan.deep_copy(preserve_uuid=True)
-        plan.execute(force_read=True)
-        ds = Dataset(plan, self._epoch, lazy=False)
-        ds._set_uuid(self._get_uuid())
+        if preserve_original:
+            plan = self._plan.deep_copy(preserve_uuid=True)
+            ds = Dataset(plan, self._epoch, self._lazy)
+            ds._set_uuid(self._get_uuid())
+        else:
+            ds = self
+        ds._plan.execute(force_read=True)
         return ds

     def is_fully_executed(self) -> bool:
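To make the new parameter concrete, here is a minimal usage sketch. It assumes a Ray build that includes this PR; the dataset contents and variable names are illustrative only.

```python
import ray

ds = ray.data.range(1000)

# Default: preserve the original. A deep copy of the plan is executed and
# returned, so `ds` itself stays unexecuted and remains fully usable.
materialized = ds.fully_executed()

# Opt out of preservation: `ds` is executed and mutated in place, which lets
# the plan eagerly reclaim input-block memory during execution.
same = ds.fully_executed(preserve_original=False)
assert same is ds
```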
6 changes: 5 additions & 1 deletion python/ray/data/impl/block_list.py
@@ -34,9 +34,13 @@ def clear(self) -> None:
         """Erase references to the tasks tracked by the BlockList."""
         self._blocks = None

+    def is_cleared(self) -> bool:
+        """Whether this BlockList has been cleared."""
+        return self._blocks is None
+
     def _check_if_cleared(self) -> None:
         """Raise an error if this BlockList has been previously cleared."""
-        if self._blocks is None:
+        if self.is_cleared():
             raise ValueError(
                 "This Dataset's blocks have been moved, which means that you "
                 "can no longer use this Dataset."
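For intuition about the clear()/is_cleared() contract, here is a self-contained toy that mirrors the semantics above; MockBlockList is illustrative and not a Ray class.

```python
from typing import List, Optional


class MockBlockList:
    """Toy stand-in mirroring BlockList's clearing contract."""

    def __init__(self, blocks: List[int]):
        self._blocks: Optional[List[int]] = blocks

    def clear(self) -> None:
        # Erase references to the tracked blocks so they can be reclaimed.
        self._blocks = None

    def is_cleared(self) -> bool:
        return self._blocks is None

    def _check_if_cleared(self) -> None:
        # Eager block lists cannot be recomputed, so access after clear fails.
        if self.is_cleared():
            raise ValueError("This Dataset's blocks have been moved.")


bl = MockBlockList([1, 2, 3])
assert not bl.is_cleared()
bl.clear()
assert bl.is_cleared()
```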
4 changes: 3 additions & 1 deletion python/ray/data/impl/lazy_block_list.py
@@ -134,6 +134,9 @@ def clear(self):
         ]
         self._cached_metadata = [None for _ in self._cached_metadata]

+    def is_cleared(self) -> bool:
+        return all(ref is None for ref in self._block_partition_refs)
+
     def _check_if_cleared(self):
         pass  # LazyBlockList can always be re-computed.

@@ -158,7 +161,6 @@ def split(self, split_size: int) -> List["LazyBlockList"]:

     # Note: does not force execution prior to splitting.
     def split_by_bytes(self, bytes_per_split: int) -> List["BlockList"]:
-        self._check_if_cleared()
         output = []
         cur_tasks, cur_blocks, cur_blocks_meta = [], [], []
         cur_size = 0
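The lazy counterpart differs in both respects, sketched here with another illustrative toy: a lazy list counts as cleared only when every partition ref has been dropped, and access after clearing never raises because the blocks can be recomputed from the datasource (which is also why split_by_bytes no longer calls _check_if_cleared).

```python
from typing import List, Optional


class MockLazyBlockList:
    """Toy stand-in: cleared lazy blocks can always be recomputed."""

    def __init__(self, num_partitions: int):
        # Stand-ins for object refs to computed block partitions.
        self._block_partition_refs: List[Optional[str]] = [
            f"ref-{i}" for i in range(num_partitions)
        ]

    def clear(self) -> None:
        self._block_partition_refs = [None] * len(self._block_partition_refs)

    def is_cleared(self) -> bool:
        return all(ref is None for ref in self._block_partition_refs)

    def _check_if_cleared(self) -> None:
        pass  # Never raises: partitions can be recomputed on demand.


lazy = MockLazyBlockList(4)
lazy.clear()
assert lazy.is_cleared()  # Cleared, yet still usable, unlike MockBlockList.
```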
2 changes: 1 addition & 1 deletion python/ray/data/impl/pipeline_executor.py
@@ -19,7 +19,7 @@ def pipeline_stage(fn: Callable[[], Dataset[T]]) -> Dataset[T]:
     # Force eager evaluation of all blocks in the pipeline stage. This
     # prevents resource deadlocks due to overlapping stage execution (e.g.,
     # task -> actor stage).
-    return fn().fully_executed()
+    return fn().fully_executed(preserve_original=False)


 class PipelineExecutor:
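The preserve_original=False switch here is the pipeline-side payoff of this PR: the dataset returned by fn() is a stage-local intermediate that no caller reuses, so executing it in place lets its input blocks be cleared as the stage runs instead of being kept alive by a preserved copy.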
89 changes: 70 additions & 19 deletions python/ray/data/impl/plan.py
@@ -47,6 +47,12 @@ def fuse(self, other: "Stage") -> "Stage":
         """Fuse this stage with a compatible stage."""
         raise NotImplementedError

+    def __repr__(self):
+        return f'{type(self).__name__}("{self.name}")'
+
+    def __str__(self):
+        return repr(self)
+

 class ExecutionPlan:
     """A lazy execution plan for a Dataset."""
@@ -224,21 +230,29 @@ def meta_count(self) -> Optional[int]:
         return None

     def execute(
-        self, clear_input_blocks: bool = True, force_read: bool = False
+        self,
+        allow_clear_input_blocks: bool = True,
+        force_read: bool = False,
     ) -> BlockList:
         """Execute this plan.

         Args:
-            clear_input_blocks: Whether to assume ownership of the input blocks,
-                allowing them to be dropped from memory during execution.
+            allow_clear_input_blocks: Whether we should try to clear the input
+                blocks for each stage.
             force_read: Whether to force the read stage to fully execute.

         Returns:
             The blocks of the output dataset.
         """
         if not self.has_computed_output():
             blocks, stats, stages = self._optimize()
-            for stage in stages:
+            for stage_idx, stage in enumerate(stages):
+                if allow_clear_input_blocks:
+                    clear_input_blocks = self._should_clear_input_blocks(
+                        blocks, stage_idx
+                    )
+                else:
+                    clear_input_blocks = False
                 stats_builder = stats.child_builder(stage.name)
                 blocks, stage_info = stage(blocks, clear_input_blocks)
                 if stage_info:
@@ -275,13 +289,34 @@ def stats(self) -> DatasetStats:
         self.execute()
         return self._snapshot_stats

+    def _should_clear_input_blocks(
+        self,
+        blocks: BlockList,
+        stage_idx: int,
+    ):
+        """Whether the provided blocks should be cleared when passed into the stage.
+
+        Args:
+            blocks: The blocks that we may want to clear.
+            stage_idx: The position of the stage in the optimized after-snapshot chain.
+        """
+        if stage_idx != 0 or self._stages_before_snapshot:
+            # Not the first stage, so always clear the stage's input blocks.
+            return True
+        elif isinstance(blocks, LazyBlockList):
+            # Always clear lazy input blocks, since they can be recomputed.
+            return True
+        else:
+            # Otherwise we have non-lazy input blocks that are the source of this
+            # execution plan, so we don't clear them.
+            return False

     def _optimize(self) -> Tuple[BlockList, DatasetStats, List[Stage]]:
         """Apply stage fusion optimizations, returning an updated source block list and
         associated stats, and a set of optimized stages.
         """
         context = DatasetContext.get_current()
-        blocks, stats = self._get_source_blocks()
-        stages = self._stages_after_snapshot.copy()
+        blocks, stats, stages = self._get_source_blocks_and_stages()
         if context.optimize_fuse_stages:
             if context.optimize_fuse_read_stages:
                 # If using a lazy datasource, rewrite read stage into one-to-one stage

@@ -293,20 +328,32 @@ def _optimize(self) -> Tuple[BlockList, DatasetStats, List[Stage]]:
             self._last_optimized_stages = stages
         return blocks, stats, stages
-    def _get_source_blocks(self) -> Tuple[BlockList, DatasetStats]:
-        """Get the source blocks (and corresponding stats) for plan execution.
+    def _get_source_blocks_and_stages(
+        self,
+    ) -> Tuple[BlockList, DatasetStats, List[Stage]]:
+        """Get the source blocks, corresponding stats, and the stages for plan
+        execution.

-        If a computed snapshot exists, return the snapshot blocks and stats; otherwise,
-        return the input blocks and stats that the plan was created with.
+        If a computed snapshot exists and has not been cleared, return the snapshot
+        blocks and stats; otherwise, return the input blocks and stats that the plan
+        was created with.
         """
+        stages = self._stages_after_snapshot.copy()
         if self._snapshot_blocks is not None:
-            # If snapshot exists, we only have to execute the plan from the
-            # snapshot.
-            blocks = self._snapshot_blocks
-            stats = self._snapshot_stats
-            # Unlink the snapshot blocks from the plan so we can eagerly reclaim the
-            # snapshot block memory after the first stage is done executing.
-            self._snapshot_blocks = None
+            if not self._snapshot_blocks.is_cleared():
+                # If snapshot exists, we only have to execute the plan from the
+                # snapshot.
+                blocks = self._snapshot_blocks
+                stats = self._snapshot_stats
+                # Unlink the snapshot blocks from the plan so we can eagerly reclaim
+                # the snapshot block memory after the first stage is done executing.
+                self._snapshot_blocks = None
+            else:
+                # Snapshot exists but has been cleared, so we need to recompute from
+                # the source (input blocks).
+                blocks = self._in_blocks
+                stats = self._in_stats
+                stages = self._stages_before_snapshot + self._stages_after_snapshot
         else:
             # If no snapshot exists, we have to execute the full plan from the
             # beginning.

@@ -317,7 +364,7 @@ def _get_source_blocks(self) -> Tuple[BlockList, DatasetStats]:
             # can eagerly reclaim the input block memory after the first stage is
             # done executing.
             self._in_blocks = None
-        return blocks, stats
+        return blocks, stats, stages
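The stages bookkeeping is the subtle part of this function: when the snapshot's blocks have been reclaimed, the stages that originally produced the snapshot must be replayed as well. A toy trace with hypothetical stage names:

```python
# Hypothetical plan state: two stages ran before the snapshot was taken,
# and one stage was added afterwards.
stages_before_snapshot = ["read", "map_batches"]
stages_after_snapshot = ["filter"]
snapshot_is_cleared = True

if not snapshot_is_cleared:
    # Normal path: resume from the snapshot; only trailing stages run.
    stages_to_run = list(stages_after_snapshot)
else:
    # The snapshot's blocks were reclaimed, so replay the whole chain
    # starting from the original input blocks.
    stages_to_run = stages_before_snapshot + stages_after_snapshot

assert stages_to_run == ["read", "map_batches", "filter"]
```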

     def has_lazy_input(self) -> bool:
         """Return whether this plan has lazy input blocks."""

@@ -335,7 +382,11 @@ def has_computed_output(self) -> bool:
         """Whether this plan has a computed snapshot for the final stage, i.e. for the
         output of this plan.
         """
-        return self._snapshot_blocks is not None and not self._stages_after_snapshot
+        return (
+            self._snapshot_blocks is not None
+            and not self._stages_after_snapshot
+            and not self._snapshot_blocks.is_cleared()
+        )


 class OneToOneStage(Stage):