Fix uneven batches in distributed dataloading (#237)
This was a big one!
Co-authored-by: tchaton <[email protected]>
awaelchli authored Jul 19, 2024
1 parent 7bad26f commit c58b673
Showing 10 changed files with 502 additions and 308 deletions.
src/litdata/processing/data_processor.py (2 changes: 1 addition & 1 deletion)
@@ -255,7 +255,7 @@ def _upload_fn(upload_queue: Queue, remove_queue: Queue, cache_dir: str, output_
output_filepath = remove_uuid_from_filename(output_filepath) # remove unique id from checkpoints

os.makedirs(os.path.dirname(output_filepath), exist_ok=True)
shutil.move(local_filepath, output_filepath)
shutil.copy(local_filepath, output_filepath)
else:
raise ValueError(f"The provided {output_dir.path} isn't supported.")

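For context on the one-line change above: `shutil.move` removes the source file after the transfer, while `shutil.copy` leaves the local file in place, so the locally cached chunk presumably remains usable after being mirrored to `output_dir`. A minimal standard-library sketch of the difference, unrelated to litdata itself:

```python
import os
import shutil
import tempfile

# Plain standard-library illustration; nothing here is litdata-specific.
with tempfile.TemporaryDirectory() as tmp:
    src = os.path.join(tmp, "chunk.bin")
    with open(src, "wb") as f:
        f.write(b"data")

    shutil.copy(src, os.path.join(tmp, "copied.bin"))
    assert os.path.exists(src)        # copy leaves the local file in place

    shutil.move(src, os.path.join(tmp, "moved.bin"))
    assert not os.path.exists(src)    # move removes the source
```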
src/litdata/streaming/combined.py (11 changes: 11 additions & 0 deletions)
@@ -118,6 +118,17 @@ def set_shuffle(self, shuffle: bool) -> None:
for dataset in self._datasets:
dataset.set_shuffle(shuffle)

def set_batch_size(self, batch_size: int) -> None:
"""Set the current batch size to the datasets."""
self.batch_size = batch_size
for dataset in self._datasets:
dataset.set_batch_size(batch_size)

def set_num_workers(self, num_workers: int) -> None:
"""Set the current number of workers to the datasets."""
for dataset in self._datasets:
dataset.set_num_workers(num_workers)

def set_drop_last(self, drop_last: bool) -> None:
"""Set the current drop_last to the datasets."""
for dataset in self._datasets:
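A minimal end-to-end sketch of how the new setters get exercised, assuming litdata's public `optimize` / `StreamingDataset` / `CombinedStreamingDataset` / `StreamingDataLoader` API; the sample function and output directories below are invented for illustration:

```python
import litdata as ld

def make_sample(index: int) -> dict:
    # Invented toy sample; any serializable structure works.
    return {"index": index}

if __name__ == "__main__":
    # Build two tiny optimized datasets locally (hypothetical output dirs).
    ld.optimize(fn=make_sample, inputs=list(range(100)), output_dir="data_a", chunk_bytes="64MB")
    ld.optimize(fn=make_sample, inputs=list(range(50)), output_dir="data_b", chunk_bytes="64MB")

    combined = ld.CombinedStreamingDataset(
        [ld.StreamingDataset("data_a"), ld.StreamingDataset("data_b")]
    )

    # StreamingDataLoader.__init__ (see dataloader.py below) now calls
    # combined.set_batch_size(4) and combined.set_num_workers(2); the setters added
    # in this file forward both values to every wrapped StreamingDataset.
    loader = ld.StreamingDataLoader(combined, batch_size=4, num_workers=2)
    for _batch in loader:
        pass
```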
src/litdata/streaming/dataloader.py (5 changes: 4 additions & 1 deletion)
@@ -555,7 +555,7 @@ def __init__(
profile_dir: Optional[str] = None,
prefetch_factor: Optional[int] = None,
shuffle: Optional[bool] = None,
drop_last: Optional[bool] = False,
drop_last: Optional[bool] = None,
collate_fn: Optional[Callable] = None,
**kwargs: Any,
) -> None: # pyright: ignore
@@ -571,6 +571,9 @@ def __init__(
if drop_last is not None:
dataset.set_drop_last(drop_last)

dataset.set_batch_size(batch_size)
dataset.set_num_workers(num_workers)

shuffle = None

if profile_batches and not _VIZ_TRACKER_AVAILABLE:
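One detail worth a toy illustration (invented, not from this commit): `num_workers=0` in `torch.utils.data.DataLoader` means samples are loaded in the main process, so the dataset still has exactly one consumer; `set_num_workers` in dataset.py below therefore normalizes with `num_workers or 1`. The hypothetical helper here only mirrors that expression:

```python
def effective_workers(num_workers: int) -> int:
    # Hypothetical helper mirroring StreamingDataset.set_num_workers(num_workers or 1).
    return num_workers or 1

assert effective_workers(0) == 1   # main-process loading still counts as one worker
assert effective_workers(4) == 4
```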
src/litdata/streaming/dataset.py (123 changes: 64 additions & 59 deletions)
@@ -32,7 +32,10 @@
from litdata.utilities.dataset_utilities import _should_replace_path, _try_create_cache_dir, subsample_streaming_dataset
from litdata.utilities.encryption import Encryption
from litdata.utilities.env import _DistributedEnv, _is_in_dataloader_worker, _WorkerEnv
from litdata.utilities.shuffle import _find_chunks_per_ranks_on_which_to_skip_deletion
from litdata.utilities.shuffle import (
_find_chunks_per_workers_on_which_to_skip_deletion,
_map_node_worker_rank_to_chunk_indexes_to_not_delete,
)

logger = Logger(__name__)

@@ -120,8 +123,10 @@ def __init__(
self.shuffler: Optional[Shuffle] = None
self.serializers = serializers
self._state_dict: Optional[Dict[str, Any]] = None
self.num_workers: Optional[int] = None
self.batch_size: Optional[int] = None
# Has slightly different meaning in the context of the dataset
# We consider `num_workers = 0` from `torch.utils.DataLoader` still as 1 worker (the main process)
self.num_workers: int = 1
self.batch_size: int = 1
self._encryption = encryption

def set_shuffle(self, shuffle: bool) -> None:
@@ -179,7 +184,13 @@ def _create_shuffler(self, cache: Cache) -> Shuffle:
return FullShuffle(cache, seed, drop_last) if self.shuffle else NoShuffle(cache, seed, drop_last)

def __len__(self) -> int:
return self.get_len(1, 1)
return self.get_len(self.num_workers, self.batch_size if self.batch_size else 1)

def set_batch_size(self, batch_size: int) -> None:
self.batch_size = batch_size

def set_num_workers(self, num_workers: int) -> None:
self.num_workers = num_workers or 1

def get_len(self, num_workers: int, batch_size: int) -> int:
self.num_workers = num_workers
@@ -205,35 +216,46 @@ def __iter__(self) -> "StreamingDataset":
state: Dict[str, Any] = self._state_dict
self.current_epoch = state["current_epoch"]

chunks_per_replica, intervals_per_replica = self.shuffler.get_chunks_and_intervals_per_ranks(
self.distributed_env, self.worker_env.world_size, self.batch_size or 1, self.current_epoch
workers_chunks, workers_intervals = self.shuffler.get_chunks_and_intervals_per_workers(
self.distributed_env, self.worker_env.world_size, self.batch_size, self.current_epoch
)
chunks_replica = chunks_per_replica[self.distributed_env.global_rank % self.distributed_env.world_size]
intervals_replica = intervals_per_replica[self.distributed_env.global_rank % self.distributed_env.world_size]

worker_rank = self.distributed_env.global_rank * self.worker_env.world_size + self.worker_env.rank
self.worker_chunks = workers_chunks[worker_rank]
self.worker_intervals = workers_intervals[worker_rank]

# The max number of samples to return from `__next__` (in worker)
self.stop_length = sum(interval[2] - interval[1] for interval in self.worker_intervals)

# Handle restart
if self._state_dict:
self._resume(chunks_replica, intervals_replica)
self._resume(workers_chunks, workers_intervals)
else:
# Find the chunks shared across multiple ranks.
# For each shared chunk, find the rank to use the chunk last and prevent deletion
# for the other ranks.
chunks_indexes_skip_deletion = _find_chunks_per_ranks_on_which_to_skip_deletion(
self.worker_env.world_size, chunks_per_replica, intervals_per_replica
# Find the chunks shared across all workers of the current node.
# For each shared chunk, find the rank and worker to use the chunk last and prevent
# premature deletion for the other workers.
node_size = self.distributed_env.world_size // self.distributed_env.num_nodes
first_rank_this_node = (self.distributed_env.global_rank // node_size) * node_size
num_workers_per_node = node_size * self.num_workers
worker_start = first_rank_this_node * num_workers_per_node
worker_end = worker_start + num_workers_per_node
local_rank = self.distributed_env.global_rank % node_size

chunks_indexes_skip_deletion = _find_chunks_per_workers_on_which_to_skip_deletion(
self.num_workers,
self.batch_size,
workers_chunks[worker_start:worker_end],
workers_intervals[worker_start:worker_end],
)
if self.distributed_env.global_rank in chunks_indexes_skip_deletion:
self.cache._reader.config.skip_chunk_indexes_deletion = chunks_indexes_skip_deletion[
self.distributed_env.global_rank
]

workers_chunks, workers_intervals = _associate_chunks_to_workers(
self.worker_env,
chunks_per_replica[self.distributed_env.global_rank],
intervals_per_replica[self.distributed_env.global_rank],
worker_node_rank_to_chunk_indexes = _map_node_worker_rank_to_chunk_indexes_to_not_delete(
chunks_indexes_skip_deletion
)

self.worker_chunks = workers_chunks[self.worker_env.rank]
self.worker_intervals = workers_intervals[self.worker_env.rank]
worker_rank_local_node = local_rank * self.num_workers + self.worker_env.rank
if worker_rank_local_node in worker_node_rank_to_chunk_indexes:
self.cache._reader.config.skip_chunk_indexes_deletion = worker_node_rank_to_chunk_indexes[
worker_rank_local_node
]

self.num_chunks = len(self.worker_chunks)
self.current_indexes = []
Expand All @@ -246,7 +268,7 @@ def __iter__(self) -> "StreamingDataset":

return self

def _resume(self, chunks_replica: List[int], intervals_replica: List[Any]) -> None:
def _resume(self, workers_chunks: List[List[int]], workers_intervals: List[Any]) -> None:
assert self._state_dict
assert self.worker_env
assert self.shuffler
@@ -259,17 +281,22 @@ def _resume(self, chunks_replica: List[int], intervals_replica: List[Any]) -> No
# TODO: Implement elastic sampling where the number of workers, ranks can change.
num_samples_yielded = self._state_dict["num_samples_yielded"]

worker_start = self.distributed_env.global_rank * num_workers
worker_end = worker_start + num_workers

# replay sampling from each worker / chunks using the batch size
workers_chunks, workers_intervals = _associate_chunks_to_workers(
self.worker_env, chunks_replica, intervals_replica
)
indexes = _replay_sampling(num_samples_yielded, batch_size, num_workers)
chunks_index, indexes = _replay_chunks_sampling(workers_intervals, indexes)
chunks_index, indexes = _replay_chunks_sampling(
workers_intervals={i: workers_intervals[i] for i in range(worker_start, worker_end)},
indexes=indexes,
)

# select the chunks and intervals associated to this worker
worker_rank = self.worker_env.rank
worker_rank = self.distributed_env.global_rank * self.worker_env.world_size + self.worker_env.rank
worker_local_rank = self.worker_env.rank

self.num_chunks = len(workers_intervals[worker_rank])
self.chunk_index = chunks_index[worker_rank]
self.chunk_index = chunks_index[worker_local_rank]
self.worker_chunks = workers_chunks[worker_rank]
self.worker_intervals = workers_intervals[worker_rank]

@@ -281,10 +308,10 @@ def _resume(self, chunks_replica: List[int], intervals_replica: List[Any]) -> No
current_indexes = self.shuffler(current_indexes, self.num_chunks, self.current_epoch, self.chunk_index)

# skip any indexes already consumed
current_indexes = current_indexes[indexes[worker_rank] :]
current_indexes = current_indexes[indexes[worker_local_rank] :]
self.current_indexes = current_indexes

self.global_index = num_samples_yielded
self.global_index = indexes[worker_local_rank]

# bump the chunk_index
self.chunk_index += 1
Expand All @@ -305,7 +332,7 @@ def __getitem__(self, index: Union[ChunkedIndex, int]) -> Any:

def __next__(self) -> Any:
# Prevent to create more batch on a given process
if self.global_index >= len(self):
if self.global_index >= self.stop_length:
self.current_epoch += 1
raise StopIteration

@@ -454,8 +481,8 @@ def reset(self) -> None:
"random_state": None,
"shuffler": None,
"_state_dict": None,
"num_workers": None,
"batch_size": None,
"num_workers": 1,
"batch_size": 1,
}

for prop, value in default_properties.items():
@@ -470,28 +497,6 @@ def is_integer(value: str) -> bool:
return False


def _associate_chunks_to_workers(
worker_env: _WorkerEnv, chunks_replica: List[int], intervals_replica: List[Any]
) -> Any:
workers_chunks = {}
workers_intervals = {}

for worker_idx in range(worker_env.world_size):
worker_chunks = []
worker_intervals = []
for i, (chunk_index, chunk_interval) in enumerate(zip(chunks_replica, intervals_replica)):
if i % worker_env.world_size != worker_idx:
continue

worker_chunks.append(chunk_index)
worker_intervals.append(chunk_interval)

workers_chunks[worker_idx] = worker_chunks
workers_intervals[worker_idx] = worker_intervals

return workers_chunks, workers_intervals


def _replay_sampling(num_samples_yielded: int, batch_size: int, num_workers: int) -> Dict[int, int]:
"""This function replays the sampling from the dataloader."""
divisible_num_batches_yielded = num_samples_yielded // (num_workers * batch_size)
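A toy walkthrough of the flat (rank, worker) indexing used above, with invented numbers: `get_chunks_and_intervals_per_workers` returns one chunk list per (rank, worker) pair, and each dataloader worker selects its own entry via `worker_rank = global_rank * num_workers + worker_id`; the node-level bookkeeping for skip-deletion then only inspects entries belonging to ranks on the same node.

```python
# All numbers invented; this only illustrates the indexing convention.
world_size, num_nodes, num_workers = 4, 2, 2      # 2 nodes x 2 ranks per node x 2 workers
node_size = world_size // num_nodes               # ranks per node

for global_rank in range(world_size):
    first_rank_this_node = (global_rank // node_size) * node_size
    for worker_id in range(num_workers):
        worker_rank = global_rank * num_workers + worker_id
        print(
            f"rank {global_rank} / worker {worker_id} -> workers_chunks[{worker_rank}] "
            f"(node spans ranks {first_rank_this_node}..{first_rank_this_node + node_size - 1})"
        )
```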
src/litdata/streaming/shuffle.py (46 changes: 22 additions & 24 deletions)
@@ -19,7 +19,10 @@

from litdata.streaming import Cache
from litdata.utilities.env import _DistributedEnv
from litdata.utilities.shuffle import _associate_chunks_and_internals_to_ranks, _intra_node_chunk_shuffle
from litdata.utilities.shuffle import (
_associate_chunks_and_intervals_to_workers,
_intra_node_chunk_shuffle,
)


class Shuffle(ABC):
@@ -32,23 +35,19 @@ def __init__(self, cache: Cache, seed: int, drop_last: bool):

@lru_cache(maxsize=10)
def get_len(self, distributed_env: _DistributedEnv, num_workers: int, batch_size: int, current_epoch: int) -> int:
_, intervals_per_ranks = self.get_chunks_and_intervals_per_ranks(
_, workers_intervals = self.get_chunks_and_intervals_per_workers(
distributed_env, num_workers, batch_size, current_epoch
)

if self.drop_last:
items_per_process = [
sum((interval[2] - interval[1]) for interval in intervals) for intervals in intervals_per_ranks
]
# Validate each processes gets the exact number of elements
if len(items_per_process) > 1:
assert all(items_per_process[0] == items_to_process for items_to_process in items_per_process[:1])
return items_per_process[0]

return sum((interval[2] - interval[1]) for interval in intervals_per_ranks[distributed_env.global_rank])
worker_start = distributed_env.global_rank * num_workers
worker_end = worker_start + num_workers
return sum(
(interval[2] - interval[1])
for intervals in workers_intervals[worker_start:worker_end]
for interval in intervals
)

@abstractmethod
def get_chunks_and_intervals_per_ranks(
def get_chunks_and_intervals_per_workers(
self, distributed_env: _DistributedEnv, num_workers: int, batch_size: int, current_epoch: int
) -> Any:
pass
Expand All @@ -63,19 +62,18 @@ class NoShuffle(Shuffle):
is True."""

@lru_cache(maxsize=10)
def get_chunks_and_intervals_per_ranks(
def get_chunks_and_intervals_per_workers(
self, distributed_env: _DistributedEnv, num_workers: int, batch_size: int, current_epoch: int
) -> Any:
# 1. Get the intervals
chunk_intervals = self.cache.get_chunk_intervals()
indexes = range(len(chunk_intervals))

# 2. Compute the items budget of each rank
chunks_per_ranks, intervals_per_ranks = _associate_chunks_and_internals_to_ranks(
workers_chunks, workers_intervals = _associate_chunks_and_intervals_to_workers(
distributed_env, indexes, chunk_intervals, self.drop_last, num_workers, batch_size
)

return chunks_per_ranks, intervals_per_ranks
return workers_chunks, workers_intervals

def __call__(self, array: np.ndarray, num_chunks: int, current_epoch: int, chunk_index: int) -> List[int]:
return array.tolist()
@@ -100,7 +98,7 @@ class FullShuffle(Shuffle):
"""

@lru_cache(maxsize=10)
def get_chunks_and_intervals_per_ranks(
def get_chunks_and_intervals_per_workers(
self, distributed_env: _DistributedEnv, num_workers: int, batch_size: int, current_epoch: int
) -> Any:
# 1. Get the intervals
@@ -120,24 +118,24 @@ def get_chunks_and_intervals_per_ranks(
shuffled_chunk_intervals = np.asarray(chunk_intervals)[shuffled_indexes].tolist()

# 3. Compute the items budget of each rank
chunks_per_ranks, intervals_per_ranks = _associate_chunks_and_internals_to_ranks(
workers_chunks, workers_intervals = _associate_chunks_and_intervals_to_workers(
distributed_env, shuffled_indexes, shuffled_chunk_intervals, self.drop_last, num_workers, batch_size
)

# For the first epoch, no need of further shuffling
if current_epoch == 1 or distributed_env.num_nodes == 1:
return chunks_per_ranks, intervals_per_ranks
return workers_chunks, workers_intervals

# Perform shuffle within the nodes to avoid cache miss.
# Note: It is possible for the overlapping chunks to change due to the changing order.
shuffled_indexes = _intra_node_chunk_shuffle(distributed_env, chunks_per_ranks, self.seed, current_epoch)
shuffled_indexes = _intra_node_chunk_shuffle(distributed_env, workers_chunks, self.seed, current_epoch)
shuffled_chunk_intervals = np.asarray(chunk_intervals)[shuffled_indexes].tolist()

chunks_per_ranks, intervals_per_ranks = _associate_chunks_and_internals_to_ranks(
workers_chunks, workers_intervals = _associate_chunks_and_intervals_to_workers(
distributed_env, shuffled_indexes, shuffled_chunk_intervals, self.drop_last, num_workers, batch_size
)

return chunks_per_ranks, intervals_per_ranks
return workers_chunks, workers_intervals

def __call__(self, array: np.ndarray, num_chunks: int, current_epoch: int, chunk_index: int) -> List[int]:
return np.random.RandomState([self.seed, num_chunks * current_epoch, chunk_index]).permutation(array).tolist()
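A toy calculation of the reworked `get_len` above, with invented intervals following the `interval[2] - interval[1]` convention: each rank sums the sizes of its own `num_workers` consecutive entries of `workers_intervals`, so the reported length reflects that rank's actual worker assignments rather than an evenly divided estimate.

```python
# Invented data: one chunk per worker; interval[1] and interval[2] bound the samples to read.
workers_intervals = [
    [(0, 0, 10, 10)], [(1, 0, 10, 10)],   # rank 0: workers 0 and 1
    [(2, 0, 10, 10)], [(3, 0, 7, 7)],     # rank 1: workers 0 and 1 (shorter last chunk)
]
num_workers = 2
for global_rank in range(2):
    worker_start = global_rank * num_workers
    worker_end = worker_start + num_workers
    length = sum(
        interval[2] - interval[1]
        for intervals in workers_intervals[worker_start:worker_end]
        for interval in intervals
    )
    print(f"rank {global_rank}: {length} samples")   # prints 20, then 17
```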