Add objects GC in dataset iterator (ray-project#34030)
* Revert "[Datasets] Revert "Enable streaming executor by default (ray-project#32493)" (ray-project#33485)"

This reverts commit 5c79954.

* Add objects GC in dataset iterator

* test it

* more tests

* fix comment

* add a little more memory as it's close to the limit and may make the test flaky

* feedback

Signed-off-by: elliottower <[email protected]>
jianoaix authored and elliottower committed Apr 22, 2023
1 parent 7382073 commit d015068
Showing 6 changed files with 109 additions and 8 deletions.
@@ -25,12 +25,12 @@ def __repr__(self) -> str:
     def _to_block_iterator(
         self,
     ) -> Tuple[
-        Iterator[Tuple[ObjectRef[Block], BlockMetadata]], Optional[DatasetStats]
+        Iterator[Tuple[ObjectRef[Block], BlockMetadata]], Optional[DatasetStats], bool
     ]:
         ds = self._base_dataset
         block_iterator, stats, executor = ds._plan.execute_to_iterator()
         ds._current_executor = executor
-        return block_iterator, stats
+        return block_iterator, stats, False
 
     def stats(self) -> str:
         return self._base_dataset.stats()
@@ -31,15 +31,28 @@ def _get_next_dataset(self) -> "DatasetPipeline":
     def _to_block_iterator(
         self,
     ) -> Tuple[
-        Iterator[Tuple[ObjectRef[Block], BlockMetadata]], Optional[DatasetStats]
+        Iterator[Tuple[ObjectRef[Block], BlockMetadata]], Optional[DatasetStats], bool
     ]:
         epoch_pipeline = self._get_next_dataset()
 
+        # Peek at the first dataset in the pipeline to see if its blocks are
+        # owned by the consumer. If so, the blocks can safely be cleared
+        # eagerly after use, because the memory is not shared across
+        # consumers; this improves memory efficiency.
+        if epoch_pipeline._first_dataset is not None:
+            blocks_owned_by_consumer = (
+                epoch_pipeline._first_dataset._plan.execute()._owned_by_consumer
+            )
+        else:
+            blocks_owned_by_consumer = (
+                epoch_pipeline._peek()._plan.execute()._owned_by_consumer
+            )
+
         def block_iter():
             for ds in epoch_pipeline.iter_datasets():
                 yield from ds._plan.execute().iter_blocks_with_metadata()
 
-        return block_iter(), None
+        return block_iter(), None, blocks_owned_by_consumer
 
     def iter_batches(
         self,
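The ownership check above is what makes eager object GC safe: blocks produced solely for this consumer can be dropped from the object store as soon as they have been read, while shared blocks must be left alone. A minimal sketch of that downstream behavior, assuming Ray's private free() helper from ray._private.internal_api; the function below is a hypothetical stand-in for illustration, not the actual iter_batches plumbing:

import ray
from ray._private.internal_api import free

def iterate_blocks_with_gc(block_refs, blocks_owned_by_consumer: bool):
    # Hypothetical sketch: yield each materialized block, then release the
    # object-store copy only when this consumer exclusively owns it.
    for block_ref, metadata in block_refs:
        yield ray.get(block_ref)
        if blocks_owned_by_consumer:
            # Safe to free eagerly: no other consumer references this block.
            free([block_ref], local_only=False)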
@@ -79,7 +79,7 @@ def __init__(
     def _to_block_iterator(
         self,
     ) -> Tuple[
-        Iterator[Tuple[ObjectRef[Block], BlockMetadata]], Optional[DatasetStats]
+        Iterator[Tuple[ObjectRef[Block], BlockMetadata]], Optional[DatasetStats], bool
     ]:
         def gen_blocks() -> Iterator[Tuple[ObjectRef[Block], BlockMetadata]]:
             cur_epoch = ray.get(
@@ -100,7 +100,7 @@ def gen_blocks() -> Iterator[Tuple[ObjectRef[Block], BlockMetadata]]:
             )
             yield block_ref
 
-        return gen_blocks(), None
+        return gen_blocks(), None, False
 
     def stats(self) -> str:
         """Implements DatasetIterator."""
8 changes: 6 additions & 2 deletions python/ray/data/dataset_iterator.py
@@ -68,14 +68,16 @@ class DatasetIterator(abc.ABC):
     def _to_block_iterator(
         self,
     ) -> Tuple[
-        Iterator[Tuple[ObjectRef[Block], BlockMetadata]], Optional[DatasetStats]
+        Iterator[Tuple[ObjectRef[Block], BlockMetadata]], Optional[DatasetStats], bool
     ]:
         """Returns the iterator to use for `iter_batches`.
         Returns:
             A tuple. The first item of the tuple is an iterator over pairs of Block
             object references and their corresponding metadata. The second item of the
             tuple is a DatasetStats object used for recording stats during iteration.
+            The third item is a boolean indicating if the blocks can be safely cleared
+            after use.
         """
         raise NotImplementedError
 
@@ -153,7 +155,7 @@ def iter_batches(
 
         time_start = time.perf_counter()
 
-        block_iterator, stats = self._to_block_iterator()
+        block_iterator, stats, blocks_owned_by_consumer = self._to_block_iterator()
         if use_legacy:
             # Legacy iter_batches does not use metadata.
             def drop_metadata(block_iterator):
@@ -164,6 +166,7 @@ def drop_metadata(block_iterator):
                 drop_metadata(block_iterator),
                 stats=stats,
                 prefetch_blocks=prefetch_blocks,
+                clear_block_after_read=blocks_owned_by_consumer,
                 batch_size=batch_size,
                 batch_format=batch_format,
                 drop_last=drop_last,
@@ -175,6 +178,7 @@ def drop_metadata(block_iterator):
             yield from iter_batches(
                 block_iterator,
                 stats=stats,
+                clear_block_after_read=blocks_owned_by_consumer,
                 batch_size=batch_size,
                 batch_format=batch_format,
                 drop_last=drop_last,
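Taken together, each iterator implementation now reports consumer ownership through this third tuple element, and iter_batches forwards it as clear_block_after_read. A minimal sketch of the updated contract, assuming a Ray build where Dataset.iterator() returns a DatasetIterator; _to_block_iterator is private API and is called here purely to illustrate the tuple shape:

import ray

ds = ray.data.range(8)
it = ds.iterator()

# Per the first diff above, the base dataset iterator reports False:
# its blocks may be iterated again, so they must not be eagerly cleared.
block_iterator, stats, blocks_owned_by_consumer = it._to_block_iterator()
assert blocks_owned_by_consumer is False
for block_ref, metadata in block_iterator:
    print(metadata.num_rows)  # BlockMetadata carries per-block row counts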
6 changes: 6 additions & 0 deletions python/ray/data/tests/test_dataset_pipeline.py
@@ -393,6 +393,12 @@ def test_iter_batches_basic(ray_start_regular_shared):
     assert all(len(e) == 1 for e in batches)
 
 
+def test_to_torch(ray_start_regular_shared):
+    pipe = ray.data.range(10, parallelism=10).window(blocks_per_window=2)
+    batches = list(pipe.to_torch(batch_size=None))
+    assert len(batches) == 10
+
+
 def test_iter_batches_batch_across_windows(ray_start_regular_shared):
     # 3 windows, each containing 3 blocks, each containing 3 rows.
     pipe = ray.data.range(27, parallelism=9).window(blocks_per_window=3)
78 changes: 78 additions & 0 deletions python/ray/data/tests/test_object_gc.py
@@ -16,6 +16,52 @@ def check_no_spill(ctx, pipe):
     assert "Spilled" not in meminfo, meminfo
 
 
+def check_to_torch_no_spill(ctx, pipe):
+    # Run up to 10 epochs of the pipeline to stress test that
+    # no spilling will happen.
+    max_epoch = 10
+    for p in pipe.iter_epochs(max_epoch):
+        for _ in p.to_torch(batch_size=None):
+            pass
+    meminfo = memory_summary(ctx.address_info["address"], stats_only=True)
+    assert "Spilled" not in meminfo, meminfo
+
+
+def check_iter_torch_batches_no_spill(ctx, pipe):
+    # Run up to 10 epochs of the pipeline to stress test that
+    # no spilling will happen.
+    max_epoch = 10
+    for p in pipe.iter_epochs(max_epoch):
+        for _ in p.iter_torch_batches(batch_size=None):
+            pass
+    meminfo = memory_summary(ctx.address_info["address"], stats_only=True)
+    assert "Spilled" not in meminfo, meminfo
+
+
+def check_to_tf_no_spill(ctx, pipe):
+    # Run up to 10 epochs of the pipeline to stress test that
+    # no spilling will happen.
+    max_epoch = 10
+    for p in pipe.iter_epochs(max_epoch):
+        for _ in p.to_tf(
+            feature_columns="__value__", label_columns="label", batch_size=None
+        ):
+            pass
+    meminfo = memory_summary(ctx.address_info["address"], stats_only=True)
+    assert "Spilled" not in meminfo, meminfo
+
+
+def check_iter_tf_batches_no_spill(ctx, pipe):
+    # Run up to 10 epochs of the pipeline to stress test that
+    # no spilling will happen.
+    max_epoch = 10
+    for p in pipe.iter_epochs(max_epoch):
+        for _ in p.iter_tf_batches():
+            pass
+    meminfo = memory_summary(ctx.address_info["address"], stats_only=True)
+    assert "Spilled" not in meminfo, meminfo
+
+
 def test_iter_batches_no_spilling_upon_no_transformation(shutdown_only):
     # The object store is about 300MB.
     ctx = ray.init(num_cpus=1, object_store_memory=300e6)
@@ -26,6 +72,38 @@ def test_iter_batches_no_spilling_upon_no_transformation(shutdown_only):
     check_no_spill(ctx, ds.window(blocks_per_window=20))
 
 
+def test_torch_iteration(shutdown_only):
+    # The object store is about 400MB.
+    ctx = ray.init(num_cpus=1, object_store_memory=400e6)
+    # The size of the dataset is 500 * (80 * 80 * 4) * 8B, about 100MB.
+    ds = ray.data.range_tensor(500, shape=(80, 80, 4), parallelism=100)
+
+    # to_torch
+    check_to_torch_no_spill(ctx, ds.repeat())
+    check_to_torch_no_spill(ctx, ds.window(blocks_per_window=20))
+    # iter_torch_batches
+    check_iter_torch_batches_no_spill(ctx, ds.repeat())
+    check_iter_torch_batches_no_spill(ctx, ds.window(blocks_per_window=20))
+
+
+def test_tf_iteration(shutdown_only):
+    # The object store is about 800MB.
+    ctx = ray.init(num_cpus=1, object_store_memory=800e6)
+    # The size of the dataset is 500 * (80 * 80 * 4) * 8B, about 100MB.
+    ds = ray.data.range_tensor(500, shape=(80, 80, 4), parallelism=100).add_column(
+        "label", lambda x: 1
+    )
+
+    # to_tf
+    check_to_tf_no_spill(ctx, ds.repeat().map(lambda x: x))
+    check_to_tf_no_spill(ctx, ds.window(blocks_per_window=20).map(lambda x: x))
+    # iter_tf_batches
+    check_iter_tf_batches_no_spill(ctx, ds.repeat().map(lambda x: x))
+    check_iter_tf_batches_no_spill(
+        ctx, ds.window(blocks_per_window=20).map(lambda x: x)
+    )
+
+
 def test_iter_batches_no_spilling_upon_rewindow(shutdown_only):
     # The object store is about 300MB.
     ctx = ray.init(num_cpus=1, object_store_memory=300e6)
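The size comments in these tests follow directly from the tensor shape; a quick sanity check of the arithmetic (assuming range_tensor's 8-byte integer elements):

import math

num_rows, shape, bytes_per_element = 500, (80, 80, 4), 8
size_bytes = num_rows * math.prod(shape) * bytes_per_element
assert size_bytes == 102_400_000  # ~100MB, well under the 300-800MB object stores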
