Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[data] randomize_block_order() not compatible with stage fusion #26090

Merged
merged 1 commit into from
Jun 30, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions python/ray/data/_internal/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,8 @@ def _optimize(self) -> Tuple[BlockList, DatasetStats, List[Stage]]:
"""
context = DatasetContext.get_current()
blocks, stats, stages = self._get_source_blocks_and_stages()
if context.optimize_reorder_stages:
stages = _reorder_stages(stages)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ericl If you rewrite the read stage before you reorder stages, won't the read->map_batches fusion work without the lazy vs defer_execution hack?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, because if you call .randomize_blocks() eagerly it will materialize the read blocks, so it's too late. We have to force it to be lazy as a special case.

I also need this for auto_repartition() anyway, so I think this is a sensible thing... and if we move to lazy by default it will go away.

if context.optimize_fuse_stages:
if context.optimize_fuse_read_stages:
# If using a lazy datasource, rewrite read stage into one-to-one stage
Expand Down Expand Up @@ -692,6 +694,18 @@ def __call__(
return blocks, stage_info


class RandomizeBlocksStage(AllToAllStage):
    """All-to-all stage named ``randomize_block_order``.

    The stage function delegates to ``block_list.randomize_block_order(seed)``
    and leaves an empty block list untouched.

    Args:
        seed: Value forwarded unchanged to
            ``block_list.randomize_block_order``. May be ``None``.
    """

    def __init__(self, seed: Optional[int]):
        def do_randomize_block_order(block_list, *_):
            # executed_num_blocks() is a blocking call (per the original
            # note); an empty block list is returned unchanged.
            if block_list.executed_num_blocks() == 0:
                return block_list, {}
            # The second tuple element is always an empty dict here
            # (presumably the stage info -- confirm against AllToAllStage).
            return block_list.randomize_block_order(seed), {}

        super().__init__("randomize_block_order", None, do_randomize_block_order)


def _rewrite_read_stages(
blocks: BlockList,
stats: DatasetStats,
Expand Down Expand Up @@ -741,6 +755,36 @@ def block_fn(read_fn: Callable[[], Iterator[Block]]) -> Iterator[Block]:
return block_list, stats, stage


def _reorder_stages(stages: List[Stage]) -> List[Stage]:
    """Push block-randomization stages as late as possible in the plan.

    Only ``RandomizeBlocksStage`` instances are moved (issue #26057). Each one
    is held back until the next non-randomize ``AllToAllStage`` -- which acts
    as a barrier -- or until the end of the stage list, whichever comes first.
    All other stages keep their relative order, so stages that were separated
    only by a randomize stage become adjacent and can be fused.

    Args:
        stages: Stages to try to reorder.

    Returns:
        A new list containing the same stages in their reordered positions.
    """

    reordered: List[Stage] = []
    pending: List[RandomizeBlocksStage] = []

    for stage in stages:
        # This check must come first: RandomizeBlocksStage is itself an
        # AllToAllStage and must be deferred, not treated as a barrier.
        if isinstance(stage, RandomizeBlocksStage):
            pending.append(stage)
            continue

        # Any other all-to-all stage is a barrier: emit the deferred
        # randomize stages before it so they do not move past it.
        if isinstance(stage, AllToAllStage):
            reordered += pending
            pending = []
        reordered.append(stage)

    # Randomize stages that never hit a barrier go to the very end.
    return reordered + pending
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice! This is a good one-off.



def _fuse_one_to_one_stages(stages: List[Stage]) -> List[Stage]:
"""Fuses compatible one-to-one stages.

Expand Down
6 changes: 6 additions & 0 deletions python/ray/data/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@
# Whether to enable stage-fusion optimizations for dataset pipelines.
DEFAULT_OPTIMIZE_FUSE_STAGES = True

# Whether to enable stage-reorder optimizations for dataset pipelines.
DEFAULT_OPTIMIZE_REORDER_STAGES = True

# Whether to furthermore fuse read stages. When this is enabled, data will also be
# re-read from the base dataset in each repetition of a DatasetPipeline.
DEFAULT_OPTIMIZE_FUSE_READ_STAGES = True
Expand Down Expand Up @@ -62,6 +65,7 @@ def __init__(
optimize_fuse_stages: bool,
optimize_fuse_read_stages: bool,
optimize_fuse_shuffle_stages: bool,
optimize_reorder_stages: bool,
actor_prefetcher_enabled: bool,
use_push_based_shuffle: bool,
pipeline_push_based_shuffle_reduce_tasks: bool,
Expand All @@ -76,6 +80,7 @@ def __init__(
self.optimize_fuse_stages = optimize_fuse_stages
self.optimize_fuse_read_stages = optimize_fuse_read_stages
self.optimize_fuse_shuffle_stages = optimize_fuse_shuffle_stages
self.optimize_reorder_stages = optimize_reorder_stages
self.actor_prefetcher_enabled = actor_prefetcher_enabled
self.use_push_based_shuffle = use_push_based_shuffle
self.pipeline_push_based_shuffle_reduce_tasks = (
Expand Down Expand Up @@ -104,6 +109,7 @@ def get_current() -> "DatasetContext":
optimize_fuse_stages=DEFAULT_OPTIMIZE_FUSE_STAGES,
optimize_fuse_read_stages=DEFAULT_OPTIMIZE_FUSE_READ_STAGES,
optimize_fuse_shuffle_stages=DEFAULT_OPTIMIZE_FUSE_SHUFFLE_STAGES,
optimize_reorder_stages=DEFAULT_OPTIMIZE_REORDER_STAGES,
actor_prefetcher_enabled=DEFAULT_ACTOR_PREFETCHER_ENABLED,
use_push_based_shuffle=DEFAULT_USE_PUSH_BASED_SHUFFLE,
# NOTE(swang): We have to pipeline reduce tasks right now
Expand Down
32 changes: 12 additions & 20 deletions python/ray/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,12 @@
from ray.data._internal.fast_repartition import fast_repartition
from ray.data._internal.lazy_block_list import LazyBlockList
from ray.data._internal.output_buffer import BlockOutputBuffer
from ray.data._internal.plan import AllToAllStage, ExecutionPlan, OneToOneStage
from ray.data._internal.plan import (
AllToAllStage,
ExecutionPlan,
OneToOneStage,
RandomizeBlocksStage,
)
from ray.data._internal.progress_bar import ProgressBar
from ray.data._internal.remote_fn import cached_remote_fn
from ray.data._internal.shuffle_and_partition import (
Expand Down Expand Up @@ -169,6 +174,8 @@ def __init__(
plan: ExecutionPlan,
epoch: int,
lazy: bool,
*,
defer_execution: bool = False,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had a hard time understanding why there's a defer_execution parameter when there is already a lazy parameter. It seems to muddy the execution semantics even more.

):
"""Construct a Dataset (internal API).

Expand All @@ -183,7 +190,7 @@ def __init__(
self._epoch = epoch
self._lazy = lazy

if not lazy:
if not lazy and not defer_execution:
self._plan.execute(allow_clear_input_blocks=False)

@staticmethod
Expand Down Expand Up @@ -808,26 +815,11 @@ def randomize_block_order(
based on system randomness.

Returns:
The shuffled dataset.
The block-shuffled dataset.
"""

def do_randomize_block_order(block_list, *_):
num_blocks = block_list.executed_num_blocks() # Blocking.
if num_blocks == 0:
return block_list, {}

randomized_block_list = block_list.randomize_block_order(seed)

return randomized_block_list, {}

plan = self._plan.with_stage(
AllToAllStage(
"randomize_block_order",
None,
do_randomize_block_order,
)
)
return Dataset(plan, self._epoch, self._lazy)
plan = self._plan.with_stage(RandomizeBlocksStage(seed))
return Dataset(plan, self._epoch, self._lazy, defer_execution=True)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, having both lazy and defer_execution is a bit odd. Why can't this follow the self._lazy semantics? Is there an issue with executing this stage eagerly?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. I can't remember why I did this... removed it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, it breaks lazy read->map_batches() fusion, that's why. I reverted this since it broke a unit test.

Copy link
Contributor

@clarkzinzow clarkzinzow Jun 30, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm do you know why it breaks that?


def random_sample(
self, fraction: float, *, seed: Optional[int] = None
Expand Down
36 changes: 35 additions & 1 deletion python/ray/data/tests/test_optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import ray
from ray._private.internal_api import memory_summary
from ray.data import Dataset
from ray.data.block import BlockMetadata
from ray.data.context import DatasetContext
from ray.data.datasource import Datasource, ReadTask
Expand Down Expand Up @@ -48,7 +49,14 @@ def expect_stages(pipe, num_stages_expected, stage_names):
for name in stage_names:
name = " " + name + ":"
assert name in stats, (name, stats)
assert len(pipe._optimized_stages) == num_stages_expected, pipe._optimized_stages
if isinstance(pipe, Dataset):
assert (
len(pipe._plan._stages_before_snapshot) == num_stages_expected
), pipe._plan._stages_before_snapshot
else:
assert (
len(pipe._optimized_stages) == num_stages_expected
), pipe._optimized_stages


def test_memory_sanity(shutdown_only):
Expand Down Expand Up @@ -294,6 +302,32 @@ def test_stage_linking(ray_start_regular_shared):
_assert_has_stages(ds._plan._last_optimized_stages, ["read->map"])


def test_optimize_reorder(ray_start_regular_shared):
    """Check the stage names produced when stage reordering is enabled."""
    ctx = DatasetContext.get_current()
    ctx.optimize_fuse_stages = True
    ctx.optimize_fuse_read_stages = True
    ctx.optimize_reorder_stages = True

    # No barrier after the randomize stage: the expected stats show the read
    # fused with map_batches and randomize_block_order reported separately.
    fused_ds = ray.data.range(10).randomize_block_order().map_batches(lambda x: x)
    expect_stages(
        fused_ds,
        2,
        ["read->map_batches", "randomize_block_order"],
    )

    # A repartition follows the randomize stage: every stage is expected to
    # appear under its own name, with no fused names.
    barrier_ds = (
        ray.data.range(10)
        .randomize_block_order()
        .repartition(10)
        .map_batches(lambda x: x)
    )
    expect_stages(
        barrier_ds,
        3,
        ["read", "randomize_block_order", "repartition", "map_batches"],
    )


def test_optimize_fuse(ray_start_regular_shared):
context = DatasetContext.get_current()

Expand Down