Skip to content

Commit

Permalink
[data] Capture the context when the dataset is first created (ray-pro…
Browse files Browse the repository at this point in the history
…ject#35239)

Signed-off-by: e428265 <[email protected]>
  • Loading branch information
ericl authored and arvind-chandra committed Aug 31, 2023
1 parent 455f655 commit 4b07da4
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 7 deletions.
2 changes: 0 additions & 2 deletions python/ray/data/_internal/iterator/iterator_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from ray.types import ObjectRef
from ray.data.block import Block, BlockMetadata
from ray.data.context import DataContext
from ray.data.iterator import DataIterator
from ray.data._internal.stats import DatasetStats

Expand All @@ -17,7 +16,6 @@ def __init__(
base_dataset: "Dataset",
):
self._base_dataset = base_dataset
self._base_context = DataContext.get_current()

def __repr__(self) -> str:
return f"DataIterator({self._base_dataset})"
Expand Down
19 changes: 14 additions & 5 deletions python/ray/data/_internal/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,10 @@ def __init__(

self._run_by_consumer = run_by_consumer

# Snapshot the current context, so that the config of Datasets is always
# determined by the config at the time it was created.
self._context = copy.deepcopy(DataContext.get_current())

def __repr__(self) -> str:
return (
f"ExecutionPlan("
Expand Down Expand Up @@ -483,7 +487,9 @@ def execute_to_iterator(
Tuple of iterator over output blocks and the executor.
"""

ctx = DataContext.get_current()
# Always used the saved context for execution.
ctx = self._context

if not ctx.use_streaming_executor or self.has_computed_output():
return (
self.execute(
Expand Down Expand Up @@ -532,7 +538,10 @@ def execute(
Returns:
The blocks of the output dataset.
"""
context = DataContext.get_current()

# Always used the saved context for execution.
context = self._context

if not ray.available_resources().get("CPU"):
if log_once("cpu_warning"):
logger.get_logger().warning(
Expand Down Expand Up @@ -672,7 +681,7 @@ def _optimize(self) -> Tuple[BlockList, DatasetStats, List[Stage]]:
"""Apply stage fusion optimizations, returning an updated source block list and
associated stats, and a set of optimized stages.
"""
context = DataContext.get_current()
context = self._context
blocks, stats, stages = self._get_source_blocks_and_stages()
if context.optimize_reorder_stages:
stages = _reorder_stages(stages)
Expand Down Expand Up @@ -728,7 +737,7 @@ def is_read_stage_equivalent(self) -> bool:
"""Return whether this plan can be executed as only a read stage."""
from ray.data._internal.stage_impl import RandomizeBlocksStage

context = DataContext.get_current()
context = self._context
remaining_stages = self._stages_after_snapshot
if (
context.optimize_fuse_stages
Expand Down Expand Up @@ -764,7 +773,7 @@ def _run_with_new_execution_backend(self) -> bool:
# - Read only: handle with legacy backend
# - Read->randomize_block_order: handle with new backend
# Note that both are considered read equivalent, hence this extra check.
context = DataContext.get_current()
context = self._context
trailing_randomize_block_order_stage = (
self._stages_after_snapshot
and len(self._stages_after_snapshot) == 1
Expand Down
6 changes: 6 additions & 0 deletions python/ray/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4094,6 +4094,12 @@ def deserialize_lineage(serialized_ds: bytes) -> "Dataset":
"""
return pickle.loads(serialized_ds)

@property
@DeveloperAPI
def context(self) -> DataContext:
"""Return the DataContext used to create this Dataset."""
return self._plan._context

def _divide(self, block_idx: int) -> ("Dataset", "Dataset"):
block_list = self._plan.execute()
left, right = block_list.divide(block_idx)
Expand Down
30 changes: 30 additions & 0 deletions python/ray/data/tests/test_context_propagation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,36 @@
from ray._private.test_utils import run_string_as_driver


def test_context_saved_when_dataset_created(ray_start_regular_shared):
ctx = DataContext.get_current()
d1 = ray.data.range(10)
d2 = ray.data.range(10)
assert ctx.eager_free
assert d1.context.eager_free
assert d2.context.eager_free

d1.context.eager_free = False
assert not d1.context.eager_free
assert d2.context.eager_free
assert ctx.eager_free

@ray.remote(num_cpus=0)
def check(d1, d2):
assert not d1.context.eager_free
assert d2.context.eager_free

ray.get(check.remote(d1, d2))

@ray.remote(num_cpus=0)
def check2(d):
d.take()

d1.context.execution_options.resource_limits.cpu = 0.1
with pytest.raises(ValueError):
ray.get(check2.remote(d1))
ray.get(check2.remote(d2))


def test_read(ray_start_regular_shared):
class CustomDatasource(Datasource):
def prepare_read(self, parallelism: int):
Expand Down

0 comments on commit 4b07da4

Please sign in to comment.