Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support iter_epochs for Datasets #19217

Merged
merged 19 commits into from
Oct 12, 2021
46 changes: 45 additions & 1 deletion doc/source/data/dataset-pipeline.rst
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,40 @@ You can also create a DatasetPipeline from a custom iterator over dataset creato
splits = ray.data.range(1000, parallelism=200).split(20)
pipe = DatasetPipeline.from_iterable([lambda s=s: s for s in splits])
Handling Epochs
~~~~~~~~~~~~~~~

It's common in ML training to want to divide data ingest into epochs, or repetitions over the original source dataset. DatasetPipeline provides a convenient ``.iter_epochs()`` method that can be used to split up the pipeline into epoch-delimited pipeline segments. Epochs are defined by the last call to ``.repeat()`` in a pipeline, for example:

.. code-block:: python
pipe = ray.data.range(5).repeat(3).random_shuffle_each_window()
for i, epoch in enumerate(pipe.iter_epochs()):
print("Epoch {}".format(i))
for row in epoch.iter_rows():
print(row)
# ->
# Epoch 0
# 2
# 1
# 3
# 4
# 0
# Epoch 1
# 3
# 4
# 0
# 2
# 1
# Epoch 2
# 3
# 2
# 4
# 1
# 0
Note that while epochs commonly consist of a single window, they can also contain multiple windows if ``.window()`` is used or there are multiple ``.repeat()`` calls.

Per-Window Transformations
~~~~~~~~~~~~~~~~~~~~~~~~~~

Expand All @@ -79,12 +113,14 @@ While most Dataset operations are per-row (e.g., map, filter), some operations a
# Example of randomly shuffling each window of a pipeline.
ray.data.range(5).repeat(2).random_shuffle_each_window().show_windows()
# ->
# ------ Epoch 0 ------
# === Window 0 ===
# 4
# 3
# 1
# 0
# 2
# ------ Epoch 1 ------
# === Window 1 ===
# 2
# 1
Expand All @@ -99,12 +135,14 @@ You can also apply arbitrary transformations to each window using ``DatasetPipel
# Equivalent transformation using .foreach_window()
ray.data.range(5).repeat(2).foreach_window(lambda w: w.random_shuffle()).show_windows()
# ->
# ------ Epoch 0 ------
# === Window 0 ===
# 1
# 0
# 4
# 2
# 3
# ------ Epoch 1 ------
# === Window 1 ===
# 4
# 2
Expand Down Expand Up @@ -256,6 +294,7 @@ Sometimes, you may want to change the structure of an existing pipeline. For exa
.repeat(2) \
.show_windows()
# ->
# ------ Epoch 0 ------
# === Window 0 ===
# 0
# 1
Expand All @@ -264,6 +303,7 @@ Sometimes, you may want to change the structure of an existing pipeline. For exa
# 3
# === Window 2 ===
# 4
# ------ Epoch 1 ------
# === Window 3 ===
# 0
# 1
Expand All @@ -273,18 +313,22 @@ Sometimes, you may want to change the structure of an existing pipeline. For exa
# === Window 5 ===
# 4
# Repeat followed by window.
# Repeat followed by window. Note that epoch 1 contains some leftover
# data from the tail end of epoch 0, since re-windowing can merge windows
# across epochs.
ray.data.range(5) \
.repeat(2) \
.rewindow(blocks_per_window=2) \
.show_windows()
# ->
# ------ Epoch 0 ------
# === Window 0 ===
# 0
# 1
# === Window 1 ===
# 2
# 3
# ------ Epoch 1 ------
# === Window 2 ===
# 4
# 0
Expand Down
77 changes: 57 additions & 20 deletions python/ray/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@

logger = logging.getLogger(__name__)

# Whether we have warned of Datasets containing multiple epochs of data.
_epoch_warned = False


@PublicAPI(stability="beta")
class Dataset(Generic[T]):
Expand All @@ -63,14 +66,15 @@ class Dataset(Generic[T]):
and simple repartition, but currently not aggregations and joins.
"""

def __init__(self, blocks: BlockList[T]):
def __init__(self, blocks: BlockList[T], epoch: int):
"""Construct a Dataset (internal API).
The constructor is not part of the Dataset API. Use the ``ray.data.*``
read methods to construct a dataset.
"""
self._blocks: BlockList[T] = blocks
self._uuid = uuid4().hex
self._epoch = epoch
ericl marked this conversation as resolved.
Show resolved Hide resolved
assert isinstance(self._blocks, BlockList), self._blocks

def map(self,
Expand Down Expand Up @@ -125,7 +129,9 @@ def transform(block: Block) -> Block:

compute = get_compute(compute)

return Dataset(compute.apply(transform, ray_remote_args, self._blocks))
return Dataset(
compute.apply(transform, ray_remote_args, self._blocks),
self._epoch)

def map_batches(self,
fn: Union[CallableClass, Callable[[BatchType], BatchType]],
Expand Down Expand Up @@ -219,7 +225,9 @@ def transform(block: Block) -> Block:

compute = get_compute(compute)

return Dataset(compute.apply(transform, ray_remote_args, self._blocks))
return Dataset(
compute.apply(transform, ray_remote_args, self._blocks),
self._epoch)

def flat_map(self,
fn: Union[CallableClass, Callable[[T], Iterable[U]]],
Expand Down Expand Up @@ -257,7 +265,9 @@ def transform(block: Block) -> Block:

compute = get_compute(compute)

return Dataset(compute.apply(transform, ray_remote_args, self._blocks))
return Dataset(
compute.apply(transform, ray_remote_args, self._blocks),
self._epoch)

def filter(self,
fn: Union[CallableClass, Callable[[T], bool]],
Expand Down Expand Up @@ -295,7 +305,9 @@ def transform(block: Block) -> Block:

compute = get_compute(compute)

return Dataset(compute.apply(transform, ray_remote_args, self._blocks))
return Dataset(
compute.apply(transform, ray_remote_args, self._blocks),
self._epoch)

def repartition(self, num_blocks: int) -> "Dataset[T]":
"""Repartition the dataset into exactly this number of blocks.
Expand All @@ -316,7 +328,7 @@ def repartition(self, num_blocks: int) -> "Dataset[T]":
"""

new_blocks = simple_shuffle(self._blocks, num_blocks)
return Dataset(new_blocks)
return Dataset(new_blocks, self._epoch)

def random_shuffle(
self,
Expand Down Expand Up @@ -360,7 +372,7 @@ def random_shuffle(
random_shuffle=True,
random_seed=seed,
_spread_resource_prefix=_spread_resource_prefix)
return Dataset(new_blocks)
return Dataset(new_blocks, self._epoch)

def _move_blocks(self):
blocks = self._blocks.copy()
Expand Down Expand Up @@ -557,7 +569,8 @@ def equalize(splits: List[Dataset[T]],
return equalize([
Dataset(
BlockList(
list(blocks), [metadata_mapping[b] for b in blocks]))
list(blocks), [metadata_mapping[b]
for b in blocks]), self._epoch)
for blocks in np.array_split(block_refs, n)
if not equal or len(blocks) > 0
], n)
Expand Down Expand Up @@ -657,7 +670,7 @@ def build_node_id_by_actor(actors: List[Any]) -> Dict[Any, str]:
BlockList(
allocation_per_actor[actor],
[metadata_mapping[b]
for b in allocation_per_actor[actor]]))
for b in allocation_per_actor[actor]]), self._epoch)
for actor in locality_hints
], n)

Expand Down Expand Up @@ -733,7 +746,18 @@ def union(self, *other: List["Dataset[T]"]) -> "Dataset[T]":
metadata.extend(bl._metadata)
blocks.extend(bl._blocks)

return Dataset(LazyBlockList(calls, metadata, blocks))
epochs = [ds._get_epoch() for ds in datasets]
max_epoch = max(*epochs)
if len(set(epochs)) > 1:
global _epoch_warned
if not _epoch_warned:
logger.warning(
"Dataset contains data from multiple epochs: {}, "
"likely due to a `rewindow()` call. The higher epoch "
"number {} will be used. This warning will not "
"be shown again.".format(set(epochs), max_epoch))
_epoch_warned = True
return Dataset(LazyBlockList(calls, metadata, blocks), max_epoch)

def sort(self,
key: Union[None, str, List[str], Callable[[T], Any]] = None,
Expand Down Expand Up @@ -772,7 +796,7 @@ def sort(self,
# Handle empty dataset.
if self.num_blocks() == 0:
return self
return Dataset(sort_impl(self._blocks, key, descending))
return Dataset(sort_impl(self._blocks, key, descending), self._epoch)

def zip(self, other: "Dataset[U]") -> "Dataset[(T, U)]":
"""Zip this dataset with the elements of another.
Expand Down Expand Up @@ -823,7 +847,7 @@ def do_zip(block1: Block, block2: Block) -> (Block, BlockMetadata):

# TODO(ekl) it might be nice to have a progress bar here.
metadata = ray.get(metadata)
return Dataset(BlockList(blocks, metadata))
return Dataset(BlockList(blocks, metadata), self._epoch)

def limit(self, limit: int) -> "Dataset[T]":
"""Limit the dataset to the first number of records specified.
Expand Down Expand Up @@ -1623,6 +1647,9 @@ def repeat(self, times: int = None) -> "DatasetPipeline[T]":
Transformations done on the returned pipeline are evaluated on each
loop of the pipeline over the base dataset.
Note that every repeat of the dataset is considered an "epoch" for
the purposes of ``DatasetPipeline.iter_epochs()``.
Examples:
>>> # Infinite pipeline of numbers [0, 5)
>>> ray.data.range(5).repeat().take()
Expand Down Expand Up @@ -1653,6 +1680,7 @@ def __init__(self, ds: "Dataset[T]"):
def __next__(self) -> "Dataset[T]":
if times and self._i >= times:
raise StopIteration
self._ds._set_epoch(self._i)
self._i += 1
return lambda: self._ds

Expand Down Expand Up @@ -1719,8 +1747,9 @@ def window(self, *, blocks_per_window: int = 10) -> "DatasetPipeline[T]":
from ray.data.dataset_pipeline import DatasetPipeline

class Iterator:
def __init__(self, splits):
def __init__(self, splits, epoch):
self._splits = splits.copy()
self._epoch = epoch

def __next__(self) -> "Dataset[T]":
if not self._splits:
Expand All @@ -1729,18 +1758,19 @@ def __next__(self) -> "Dataset[T]":
blocks = self._splits.pop(0)

def gen():
return Dataset(blocks)
return Dataset(blocks, self._epoch)

return gen

class Iterable:
def __init__(self, blocks):
def __init__(self, blocks, epoch):
self._splits = blocks.split(split_size=blocks_per_window)
self._epoch = epoch

def __iter__(self):
return Iterator(self._splits)
return Iterator(self._splits, self._epoch)

it = Iterable(self._blocks)
it = Iterable(self._blocks, self._epoch)
return DatasetPipeline(it, length=len(it._splits))

@DeveloperAPI
Expand Down Expand Up @@ -1791,16 +1821,17 @@ def _split(self, index: int,
right_metadata.append(ray.get(m1))
count += num_rows

left = Dataset(BlockList(left_blocks, left_metadata))
left = Dataset(BlockList(left_blocks, left_metadata), self._epoch)
if return_right_half:
right = Dataset(BlockList(right_blocks, right_metadata))
right = Dataset(
BlockList(right_blocks, right_metadata), self._epoch)
else:
right = None
return left, right

def _divide(self, block_idx: int) -> ("Dataset[T]", "Dataset[T]"):
left, right = self._blocks.divide(block_idx)
return Dataset(left), Dataset(right)
return Dataset(left, self._epoch), Dataset(right, self._epoch)

def __repr__(self) -> str:
schema = self.schema()
Expand Down Expand Up @@ -1840,6 +1871,12 @@ def _get_uuid(self) -> str:
def _set_uuid(self, uuid: str) -> None:
self._uuid = uuid

def _get_epoch(self) -> int:
return self._epoch

def _set_epoch(self, epoch: int) -> None:
self._epoch = epoch


def _get_num_rows(block: Block) -> int:
block = BlockAccessor.for_block(block)
Expand Down
Loading