From a564aac47df38032ee6372083355745c11e03912 Mon Sep 17 00:00:00 2001
From: Eric Liang
Date: Fri, 28 Jan 2022 18:32:59 -0800
Subject: [PATCH] fix

---
Review notes (ignored by `git am`; not part of the commit message):
 * What the patch does: adds Dataset.force_reads(), which calls
   get_internal_block_refs() and returns the dataset itself, and makes the
   callables returned by the two __next__ implementations touched below call
   it before handing the dataset out.
 * Fixed in review: test_basic_actors compared a sorted list against a
   `range` object, which never compares equal to a list in Python 3, and the
   bound stopped at 19 although range(20) mapped through `x + 1` yields
   1..20. The expectation is now list(range(1, 21)) (assumes take() returns
   all 20 rows -- confirm against its default limit).
 * Because that hunk changed, the post-image blob id on the test file's
   `index` line no longer describes the patched file.
 * Line breaks and indentation were reconstructed from the hunk line counts;
   verify context whitespace against the target tree before applying.

 python/ray/data/dataset.py            | 15 +++++++++++++--
 python/ray/data/tests/test_dataset.py |  6 +++---
 2 files changed, 16 insertions(+), 5 deletions(-)

diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py
index 8f6a9b97f751..e220c07f3168 100644
--- a/python/ray/data/dataset.py
+++ b/python/ray/data/dataset.py
@@ -2289,7 +2289,7 @@ def __next__(self) -> "Dataset[T]":
                     raise StopIteration
                 self._ds._set_epoch(self._i)
                 self._i += 1
-                return lambda: self._ds
+                return lambda: self._ds.force_reads()

         class Iterable:
             def __init__(self, ds: "Dataset[T]"):
@@ -2366,7 +2366,8 @@ def __next__(self) -> "Dataset[T]":
                 blocks = self._splits.pop(0)

                 def gen():
-                    return Dataset(blocks, self._epoch, outer_stats)
+                    ds = Dataset(blocks, self._epoch, outer_stats)
+                    return ds.force_reads()

                 return gen

@@ -2395,6 +2396,16 @@ def get_internal_block_refs(self) -> List[ObjectRef[Block]]:
         """
         return self._blocks.get_blocks()

+    @DeveloperAPI
+    def force_reads(self) -> "Dataset[T]":
+        """Force full evaluation of the blocks of this dataset.
+
+        This can be used to read all blocks into memory. By default, Datasets
+        doesn't read blocks from the datasource until the first transform.
+        """
+        self.get_internal_block_refs()
+        return self
+
     @DeveloperAPI
     def stats(self) -> str:
         """Returns a string containing execution timing information."""
diff --git a/python/ray/data/tests/test_dataset.py b/python/ray/data/tests/test_dataset.py
index eea7961dbcba..9d38f9b6ab46 100644
--- a/python/ray/data/tests/test_dataset.py
+++ b/python/ray/data/tests/test_dataset.py
@@ -76,10 +76,10 @@ def _read_stream(self, f: "pa.NativeFile", path: str, **reader_args):
 @pytest.mark.parametrize("pipelined", [False, True])
 def test_basic_actors(shutdown_only, pipelined):
     ray.init(num_cpus=2)
-    ds = ray.data.range(5)
+    ds = ray.data.range(20)
     ds = maybe_pipeline(ds, pipelined)
-    assert sorted(ds.map(lambda x: x + 1,
-                         compute="actors").take()) == [1, 2, 3, 4, 5]
+    assert sorted(ds.map(lambda x: x + 1, compute="actors").take()) == list(
+        range(1, 21))


 @pytest.mark.parametrize("pipelined", [False, True])