Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
ericl committed Jan 29, 2022
1 parent eb8adc6 commit a564aac
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 5 deletions.
15 changes: 13 additions & 2 deletions python/ray/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2289,7 +2289,7 @@ def __next__(self) -> "Dataset[T]":
raise StopIteration
self._ds._set_epoch(self._i)
self._i += 1
return lambda: self._ds
return lambda: self._ds.force_reads()

class Iterable:
def __init__(self, ds: "Dataset[T]"):
Expand Down Expand Up @@ -2366,7 +2366,8 @@ def __next__(self) -> "Dataset[T]":
blocks = self._splits.pop(0)

def gen():
return Dataset(blocks, self._epoch, outer_stats)
ds = Dataset(blocks, self._epoch, outer_stats)
return ds.force_reads()

return gen

Expand Down Expand Up @@ -2395,6 +2396,16 @@ def get_internal_block_refs(self) -> List[ObjectRef[Block]]:
"""
return self._blocks.get_blocks()

@DeveloperAPI
def force_reads(self) -> "Dataset[T]":
"""Force full evaluation of the blocks of this dataset.
This can be used to read all blocks into memory. By default, Datasets
doesn't read blocks from the datasource until the first transform.
"""
self.get_internal_block_refs()
return self

@DeveloperAPI
def stats(self) -> str:
"""Returns a string containing execution timing information."""
Expand Down
6 changes: 3 additions & 3 deletions python/ray/data/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,10 @@ def _read_stream(self, f: "pa.NativeFile", path: str, **reader_args):
@pytest.mark.parametrize("pipelined", [False, True])
def test_basic_actors(shutdown_only, pipelined):
ray.init(num_cpus=2)
ds = ray.data.range(5)
ds = ray.data.range(20)
ds = maybe_pipeline(ds, pipelined)
assert sorted(ds.map(lambda x: x + 1,
compute="actors").take()) == [1, 2, 3, 4, 5]
assert sorted(ds.map(lambda x: x + 1, compute="actors").take()) == range(
1, 20)


@pytest.mark.parametrize("pipelined", [False, True])
Expand Down

0 comments on commit a564aac

Please sign in to comment.