Add IterableDataset.shard() #7252

Merged: 8 commits, Oct 25, 2024
8 changes: 4 additions & 4 deletions docs/source/about_mapstyle_vs_iterable.mdx
@@ -139,12 +139,12 @@ But using a shuffle buffer is not enough to provide a satisfactory shuffling for
```python
# Stream from the internet
my_iterable_dataset = load_dataset("deepmind/code_contests", split="train", streaming=True)
-my_iterable_dataset.n_shards # 39
+my_iterable_dataset.num_shards # 39

# Stream from local files
data_files = {"train": [f"path/to/data_{i}.csv" for i in range(1024)]}
my_iterable_dataset = load_dataset("csv", data_files=data_files, split="train", streaming=True)
-my_iterable_dataset.n_shards # 1024
+my_iterable_dataset.num_shards # 1024

# From a generator function
def my_generator(n, sources):
@@ -154,7 +154,7 @@ def my_generator(n, sources):

gen_kwargs = {"n": 10, "sources": [f"path/to/data_{i}" for i in range(1024)]}
my_iterable_dataset = IterableDataset.from_generator(my_generator, gen_kwargs=gen_kwargs)
-my_iterable_dataset.n_shards # 1024
+my_iterable_dataset.num_shards # 1024
```

## Speed differences
@@ -242,5 +242,5 @@ my_iterable_dataset = my_dataset.to_iterable_dataset()
If you want to shuffle your dataset or [use it with a PyTorch DataLoader](./use_with_pytorch#stream-data), we recommend generating a sharded [`IterableDataset`]:
```python
my_iterable_dataset = my_dataset.to_iterable_dataset(num_shards=1024)
-my_iterable_dataset.n_shards # 1024
+my_iterable_dataset.num_shards # 1024
```
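Not part of the diff above, but a minimal sketch of how a sharded iterable dataset is typically consumed afterwards (assuming PyTorch is installed and `my_dataset` is the map-style `Dataset` from the surrounding docs):

```python
from torch.utils.data import DataLoader

# Sharding first lets shuffle() reorder whole shards and lets each DataLoader
# worker stream a distinct subset of the 1024 shards.
my_iterable_dataset = my_dataset.to_iterable_dataset(num_shards=1024)
shuffled = my_iterable_dataset.shuffle(seed=42, buffer_size=1000)
dataloader = DataLoader(shuffled, batch_size=32, num_workers=4)
```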
1 change: 1 addition & 0 deletions docs/source/package_reference/main_classes.mdx
@@ -171,6 +171,7 @@ The base class [`IterableDataset`] implements an iterable Dataset backed by pyth
- batch
- skip
- take
- shard
- load_state_dict
- state_dict
- info
29 changes: 29 additions & 0 deletions docs/source/stream.mdx
@@ -136,6 +136,35 @@ You can split your dataset one of two ways:

<a id='interleave_datasets'></a>


### Shard

🤗 Datasets supports sharding to divide a very large dataset into a predefined number of chunks. Specify the `num_shards` parameter in [`~IterableDataset.shard`] to determine the number of shards to split the dataset into. You'll also need to provide the shard you want to return with the `index` parameter.

For example, the [amazon_polarity](https://huggingface.co/datasets/amazon_polarity) dataset has 4 shards (in this case they are 4 Parquet files):

```py
>>> from datasets import load_dataset
>>> dataset = load_dataset("amazon_polarity", split="train", streaming=True)
>>> print(dataset)
IterableDataset({
features: ['label', 'title', 'content'],
num_shards: 4
})
```

After sharding the dataset into two chunks, the first chunk contains only 2 of the original 4 file shards:

```py
>>> dataset.shard(num_shards=2, index=0)
IterableDataset({
features: ['label', 'title', 'content'],
num_shards: 2
})
```
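As an editor's sketch (not part of the PR's docs), each chunk can then be consumed independently, for example one per worker process:

```py
# Iterate over both chunks of the streamed dataset. Each shard() call is cheap:
# it only selects a subset of the underlying file shards (here 2 of the 4 Parquet files).
for index in range(2):
    chunk = dataset.shard(num_shards=2, index=index)
    print(index, next(iter(chunk))["title"])
```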

If your dataset has `dataset.num_shards == 1`, you should chunk it using [`IterableDataset.skip`] and [`IterableDataset.take`] instead.
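For example, a rough sketch of manual chunking with `skip`/`take`, assuming a known (or estimated) stream length of 100 examples:

```py
# Hedged sketch for the num_shards == 1 case: split the stream into two halves.
n_examples = 100  # assumed size of the stream
first_half = dataset.take(n_examples // 2)
second_half = dataset.skip(n_examples // 2)
```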

## Interleave

[`interleave_datasets`] can combine an [`IterableDataset`] with other datasets. The combined dataset returns alternating examples from each of the original datasets.
4 changes: 2 additions & 2 deletions docs/source/use_with_pytorch.mdx
@@ -216,7 +216,7 @@ If the dataset is split in several shards (i.e. if the dataset consists of multi

```py
>>> my_iterable_dataset = load_dataset("deepmind/code_contests", streaming=True, split="train")
->>> my_iterable_dataset.n_shards
+>>> my_iterable_dataset.num_shards
39
>>> dataloader = DataLoader(my_iterable_dataset, batch_size=32, num_workers=4)
```
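A hedged aside (not in the original page): because shards are distributed over the workers, a worker count that divides `num_shards` keeps the per-worker load even. For the 39-shard dataset above, 3 workers is a natural choice:

```py
>>> my_iterable_dataset.num_shards % 4  # 39 shards do not divide evenly across 4 workers
3
>>> dataloader = DataLoader(my_iterable_dataset, batch_size=32, num_workers=3)  # 13 shards per worker
```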
@@ -259,7 +259,7 @@ Each node is assigned a chunk of data, e.g. rank 0 is given the first chunk of t

For iterable datasets:

-If the dataset has a number of shards that is a factor of `world_size` (i.e. if `dataset.n_shards % world_size == 0`),
+If the dataset has a number of shards that is a multiple of `world_size` (i.e. if `dataset.num_shards % world_size == 0`),
then the shards are evenly assigned across the nodes, which is the most optimized.
Otherwise, each node keeps 1 example out of `world_size`, skipping the other examples.
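Not part of the diff, but a minimal sketch of the setup this paragraph describes (assuming `torch.distributed` has already been initialized):

```py
>>> import torch.distributed as dist
>>> from datasets.distributed import split_dataset_by_node
>>> my_iterable_dataset = load_dataset("deepmind/code_contests", streaming=True, split="train")
>>> my_iterable_dataset.num_shards  # 39, so a world_size of 3 splits the shards evenly
39
>>> node_dataset = split_dataset_by_node(my_iterable_dataset, rank=dist.get_rank(), world_size=dist.get_world_size())
```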

24 changes: 12 additions & 12 deletions src/datasets/arrow_dataset.py
@@ -303,7 +303,7 @@ def _get_output_signature(
tf_dtype = tf.float32
np_dtype = np.float32
elif np_arrays[0].dtype.kind == "U": # Unicode strings
-np_dtype = np.unicode_
+np_dtype = np.str_
tf_dtype = tf.string
else:
raise RuntimeError(
@@ -4630,40 +4630,40 @@ def shard(
self,
num_shards: int,
index: int,
-contiguous: bool = False,
+contiguous: bool = True,
keep_in_memory: bool = False,
indices_cache_file_name: Optional[str] = None,
writer_batch_size: Optional[int] = 1000,
) -> "Dataset":
"""Return the `index`-nth shard from dataset split into `num_shards` pieces.

-This shards deterministically. `dset.shard(n, i)` will contain all elements of dset whose
-index mod `n = i`.
+This shards deterministically. `dataset.shard(n, i)` splits the dataset into contiguous chunks,
+so it can be easily concatenated back together after processing. If `len(dataset) % n == l`, then the
+first `l` shards each have length `(len(dataset) // n) + 1`, and the remaining shards have length `(len(dataset) // n)`.
+`datasets.concatenate_datasets([dataset.shard(n, i) for i in range(n)])` returns a dataset with the same order as the original.

-`dset.shard(n, i, contiguous=True)` will instead split dset into contiguous chunks,
-so it can be easily concatenated back together after processing. If `n % i == l`, then the
-first `l` shards will have length `(n // i) + 1`, and the remaining shards will have length `(n // i)`.
-`datasets.concatenate([dset.shard(n, i, contiguous=True) for i in range(n)])` will return
-a dataset with the same order as the original.
+Note: `n` should be less than or equal to the number of elements in the dataset, `len(dataset)`.

+On the other hand, `dataset.shard(n, i, contiguous=False)` contains all elements of the dataset whose index mod `n = i`.

Be sure to shard before using any randomizing operator (such as `shuffle`).
It is best if the shard operator is used early in the dataset pipeline.


Args:
num_shards (`int`):
How many shards to split the dataset into.
index (`int`):
Which shard to select and return.
-contiguous: (`bool`, defaults to `False`):
+contiguous: (`bool`, defaults to `True`):
Whether to select contiguous blocks of indices for shards.
keep_in_memory (`bool`, defaults to `False`):
Keep the dataset in memory instead of writing it to a cache file.
indices_cache_file_name (`str`, *optional*):
Provide the name of a path for the cache file. It is used to store the
indices of each shard instead of the automatically generated cache file name.
writer_batch_size (`int`, defaults to `1000`):
-Number of rows per write operation for the cache file writer.
-This only concerns the indices mapping.
+Number of indices per write operation for the cache file writer.
+This value is a good trade-off between memory usage during processing and processing speed.
+A higher value makes the processing do fewer lookups, a lower value consumes less temporary memory while running `map`.
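A quick illustration of the two behaviours described in this docstring (an editor's sketch, not code from the PR):

```python
from datasets import Dataset, concatenate_datasets

ds = Dataset.from_dict({"id": list(range(10))})

# contiguous=True (the new default): contiguous chunks. 10 % 3 == 1, so the first
# shard gets 4 rows and the remaining two get 3 rows each.
shards = [ds.shard(num_shards=3, index=i) for i in range(3)]
print([s["id"] for s in shards])  # [[0, 1, 2, 3], [4, 5, 6], [7, 8, 9]]
assert concatenate_datasets(shards)["id"] == list(range(10))

# contiguous=False: keep the rows whose index mod num_shards == index.
print(ds.shard(num_shards=3, index=0, contiguous=False)["id"])  # [0, 3, 6, 9]
```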

2 changes: 1 addition & 1 deletion src/datasets/distributed.py
@@ -18,7 +18,7 @@ def split_dataset_by_node(dataset: DatasetType, rank: int, world_size: int) -> D

For iterable datasets:

-If the dataset has a number of shards that is a factor of `world_size` (i.e. if `dataset.n_shards % world_size == 0`),
+If the dataset has a number of shards that is a multiple of `world_size` (i.e. if `dataset.num_shards % world_size == 0`),
then the shards are evenly assigned across the nodes, which is the most optimized.
Otherwise, each node keeps 1 example out of `world_size`, skipping the other examples.
