IterableDataset sharding logic needs improvement #6594

Open
rwightman opened this issue Jan 15, 2024 · 1 comment

Comments

rwightman commented Jan 15, 2024

Describe the bug

The sharding of IterableDatasets across distributed and dataloader worker processes appears problematic, with significant performance traps and inconsistencies between how distributed train processes and dataloader worker processes are handled.

Splitting across num_workers (dataloader worker processes per train process) and world_size (distributed training processes) is handled inconsistently:

  • worker split:

    if self._is_main_process() and ex_iterable.n_shards < worker_info.num_workers:
        logger.warning(
            f"Too many dataloader workers: {worker_info.num_workers} (max is dataset.n_shards={ex_iterable.n_shards}). "
            f"Stopping {worker_info.num_workers - ex_iterable.n_shards} dataloader workers."
        )
        logger.info(
            f"To parallelize data loading, we give each process some shards (or data sources) to process. "
            f"Therefore it's unnecessary to have a number of workers greater than dataset.n_shards={ex_iterable.n_shards}. "
            f"To enable more parallelism, please split the dataset in more files than {ex_iterable.n_shards}."
        )
    # split workload
    _log_prefix = f"node#{self._distributed.rank} " if self._distributed else ""
    shards_indices = self._ex_iterable.split_shard_indices_by_worker(worker_info.id, worker_info.num_workers)
    if shards_indices:
        logger.debug(
            f"{_log_prefix}dataloader worker#{worker_info.id}, ': Starting to iterate over {len(shards_indices)}/{ex_iterable.n_shards} shards."
        )
        ex_iterable = ex_iterable.shard_data_sources(worker_id=worker_info.id, num_workers=worker_info.num_workers)
  • distributed split:

    if self._distributed:
        rank = self._distributed.rank
        world_size = self._distributed.world_size
        if ex_iterable.n_shards % world_size == 0:
            if self._is_main_process():
                n_shards_per_node = ex_iterable.n_shards // world_size
                plural = "s" if n_shards_per_node > 1 else ""
                logger.info(
                    f"Assigning {n_shards_per_node} shard{plural} (or data source{plural}) of the dataset to each node."
                )
            ex_iterable = ex_iterable.shard_data_sources(rank, world_size)
        else:
            if self._is_main_process():
                logger.info(
                    f"Assigning 1 out of {world_size} examples of the dataset to each node. The others are skipped during the iteration."
                )
                logger.info(
                    f"It is more optimized to distribute the dataset shards (or data sources) across nodes. "
                    f"You can do that by using a dataset with number of shards that is a factor of world_size={world_size}. "
                    f"The current dataset has {ex_iterable.n_shards} which is not a factor of {world_size}"
                )
            ex_iterable = StepExamplesIterable(ex_iterable, step=world_size, offset=rank)

In the case of the distributed split, there is a modulus check that flips between two very different behaviours. Why is this different from splitting across the dataloader workers? For IterableDatasets the DataLoader worker processes are independent, so whether it's workers within one train process or across a distributed world, the shards should be distributed the same way: across world_size * num_workers independent workers in either case.
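
For illustration, a minimal sketch of the kind of uniform split I mean (not the datasets implementation; shards_for_worker is a hypothetical helper): treat every (rank, dataloader worker) pair as one global worker and interleave the shard indices across all of them.

    # A sketch, not the `datasets` implementation: every (rank, worker) pair is one
    # global worker, and shard indices are interleaved across all of them.
    def shards_for_worker(n_shards: int, rank: int, world_size: int,
                          worker_id: int, num_workers: int) -> list[int]:
        global_worker = rank * num_workers + worker_id   # 0 .. world_size * num_workers - 1
        total_workers = world_size * num_workers
        return list(range(global_worker, n_shards, total_workers))

    # Example: 10 shards, world_size=2, num_workers=2 -> 4 global workers.
    # rank 0 / worker 0 gets shards [0, 4, 8]; rank 1 / worker 1 gets [3, 7].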

Further, the fallback case when the n_shards % world_size == 0 check fails is a rather extreme change. I argue it is not desirable to do that implicitly; it should be an explicit option for specific scenarios (i.e. reliable validation). A train scenario would likely be much better handled with improved wrapping / stopping behaviour, to e.g. also fix #6437. Changing from stepping over shards to stepping over samples means that every single process reads ALL of the shards. This was never an intended default for sharded training: shards gain their performance advantage in large scale distributed training by explicitly avoiding the need for every process to overlap in the data it reads. By default, only the data allocated to each process via its assigned shards should be read in each pass of the dataset.

Using a large scale CLIP example: some of the larger datasets have 10-20k shards across 100+ TB of data. Training with 1000 GPUs, 20k % 1000 == 0, so each process reads only its assigned shards and the job reads ~100 terabytes per epoch in total; drop one GPU node so the check becomes 20k % 992, and every process falls back to stepping over samples, each reading the whole dataset, for ~100 petabytes per epoch in total.
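
Back-of-the-envelope for those two code paths (my arithmetic, using the figures above):

    n_shards, dataset_tb = 20_000, 100        # ~20k shards, ~100 TB of data

    world_size = 1_000                        # 20_000 % 1_000 == 0 -> shard split
    read_tb = dataset_tb                      # shards are disjoint, ~100 TB read per epoch

    world_size = 992                          # drop one 8-GPU node; 20_000 % 992 != 0
    read_tb = dataset_tb * world_size         # StepExamplesIterable fallback: every rank
                                              # reads everything, ~99,200 TB ≈ 100 PB per epoch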

The 'step over samples' case might be worth the overhead in specific validation scenarios where at-least-once / at-most-once guarantees on the samples seen matter more, and where validation does not make up a significant portion of train time or is run at smaller world sizes outside of train.

Steps to reproduce the bug

N/A

Expected behavior

Given an iterable dataset with N shards, to split across workers (a rough sketch follows the list):

  • shuffle shards (same seed across all train processes)
  • step shard iterator across distributed processes
  • step shard iterator across dataloader worker processes
  • shuffle samples in every worker via a shuffle buffer (different seed in each worker, but ideally controllable, e.g. based on base seed + worker id + epoch)
  • end up with a (possibly uneven) number of shards per worker, but each shard only ever accessed by one worker per pass (epoch)
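
A rough sketch of that scheme on top of torch.utils.data.IterableDataset (my illustration, not a proposed patch; load_shard is a hypothetical per-shard reader, and the step-4 shuffle buffer is omitted for brevity):

    import torch
    import torch.utils.data as data

    class ShardedIterable(data.IterableDataset):
        def __init__(self, shard_paths, rank, world_size, seed=0):
            self.shard_paths = list(shard_paths)
            self.rank = rank
            self.world_size = world_size
            self.seed = seed
            self.epoch = 0

        def set_epoch(self, epoch):
            self.epoch = epoch

        def __iter__(self):
            # 1) shuffle shards with the same seed (and epoch) on every train process
            g = torch.Generator().manual_seed(self.seed + self.epoch)
            order = torch.randperm(len(self.shard_paths), generator=g).tolist()

            # 2) + 3) step the shuffled shard list across rank and dataloader worker
            info = data.get_worker_info()
            worker_id = info.id if info is not None else 0
            num_workers = info.num_workers if info is not None else 1
            global_worker = self.rank * num_workers + worker_id
            total_workers = self.world_size * num_workers
            my_shards = order[global_worker::total_workers]  # possibly uneven, never shared

            # 4) a per-worker shuffle buffer (seeded with base seed + worker id + epoch)
            #    would wrap this loop; omitted here
            for shard_idx in my_shards:
                yield from load_shard(self.shard_paths[shard_idx])  # hypothetical reader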

Environment info

N/A

@JohnHerry

I do not know whether this is the same problem as mine. I think num_workers should be either the number of dataloader processes for the dataloader mapped to one card, or the total number of processes across all cards.
But when I set num_workers larger than the number of training split files, it reports num_workers > n_shards and kills the extra workers; as a result, only n_shards workers are left, where n_shards = total file count / total card count.
Does that mean num_workers should be the number of processes on one card? OK, I lowered num_workers and treated it as the number of loader processes for one card, but data loading is still very slow: it seems only num_workers dataloader processes are working, not num_workers * n_cards as I expected.
So how should I set these parameters to get good dataloading?
