From b949f1927c867169293b40cd0d1f1b6b82e7bd38 Mon Sep 17 00:00:00 2001
From: Gokul
Date: Tue, 21 May 2024 16:05:08 -0700
Subject: [PATCH] Use stateful dataloader to checkpoint data iteration order and token buffer (#279)

Summary:
Use the stateful_dataloader from torchdata
(https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader) to
store the token buffer and the data iteration order. This adds a dependency on
the torchdata nightly build (>= 20240426). Also make sure the dataloader state
is stored under a different key per rank.

Test Plan:
Tested locally by first running 30 steps (checkpointing every 5 steps) and
capturing all the loss values, then deleting the last 3 checkpoints and
re-running the training: the loss values from steps 16-30 match those from the
first run. Note that this requires changes in train.py to enable a
deterministic run.
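For reference, a minimal sketch of the save/restore round trip the new
dataloader relies on (illustration only, not code from this patch; the list
dataset is a stand-in for HuggingFaceDataset, and it assumes a single-process,
non-shuffled StatefulDataLoader from the torchdata nightly):

```python
# Illustrative only -- not part of this patch. Assumes torchdata nightly >= 20240426.
from torchdata.stateful_dataloader import StatefulDataLoader

dataset = list(range(100))  # stand-in dataset; torchtitan uses HuggingFaceDataset

dl = StatefulDataLoader(dataset, batch_size=2)
it = iter(dl)
for _ in range(10):
    next(it)                 # consume the first 10 batches
state = dl.state_dict()      # capture the iteration position
expected = next(it)          # the 11th batch

# A fresh loader restored from the saved state resumes at the same batch.
resumed = StatefulDataLoader(dataset, batch_size=2)
resumed.load_state_dict(state)
assert next(iter(resumed)).tolist() == expected.tolist()
```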
Reviewers: @tianyu-l

Subscribers: @andrewkho

Tasks:

Tags:
---
 .ci/docker/requirements.txt                   |  2 +-
 .../workflows/integration_test_periodic.yaml  |  1 +
 .github/workflows/unit_test_4gpu.yaml         |  1 +
 .github/workflows/unit_test_cpu.yaml          |  1 +
 README.md                                     |  1 +
 pyproject.toml                                |  2 +-
 test/__init__.py                              |  5 ++
 test/datasets/__init__.py                     |  5 ++
 test/datasets/test_checkpoint.py              | 54 +++++++++++++
 torchtitan/checkpoint.py                      |  3 +
 torchtitan/datasets/hf_datasets.py            | 81 ++++++++++++++++---
 train.py                                      |  1 +
 12 files changed, 145 insertions(+), 12 deletions(-)
 create mode 100644 test/datasets/__init__.py
 create mode 100644 test/datasets/test_checkpoint.py

diff --git a/.ci/docker/requirements.txt b/.ci/docker/requirements.txt
index b82120a6..bb21293b 100644
--- a/.ci/docker/requirements.txt
+++ b/.ci/docker/requirements.txt
@@ -1,5 +1,5 @@
 torch >= 2.2.0.dev
-datasets
+datasets >= 2.19.0
 tomli >= 1.1.0 ; python_version < "3.11"
 tensorboard
 sentencepiece
diff --git a/.github/workflows/integration_test_periodic.yaml b/.github/workflows/integration_test_periodic.yaml
index bc717cd1..488fc4da 100644
--- a/.github/workflows/integration_test_periodic.yaml
+++ b/.github/workflows/integration_test_periodic.yaml
@@ -34,6 +34,7 @@ jobs:
       - name: Install dependencies
         run: |
           pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
+          pip install --pre torchdata --index-url https://download.pytorch.org/whl/nightly
           python -m pip install -r requirements.txt
           python -m pip install -r dev-requirements.txt
       - name: Run test_runner.py
diff --git a/.github/workflows/unit_test_4gpu.yaml b/.github/workflows/unit_test_4gpu.yaml
index 5759349d..6f052868 100644
--- a/.github/workflows/unit_test_4gpu.yaml
+++ b/.github/workflows/unit_test_4gpu.yaml
@@ -31,5 +31,6 @@ jobs:
         pip config --user set global.progress_bar off
         python -m pip install --force-reinstall --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
+        python -m pip install --pre torchdata --index-url https://download.pytorch.org/whl/nightly/
         mkdir artifacts-to-be-uploaded
         python ./test_runner.py artifacts-to-be-uploaded
diff --git a/.github/workflows/unit_test_cpu.yaml b/.github/workflows/unit_test_cpu.yaml
index dd318dbb..2482bd51 100644
--- a/.github/workflows/unit_test_cpu.yaml
+++ b/.github/workflows/unit_test_cpu.yaml
@@ -25,4 +25,5 @@ jobs:
         pip config --user set global.progress_bar off
         pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
+        pip install --pre torchdata --index-url https://download.pytorch.org/whl/nightly
         pytest test --cov=. --cov-report=xml --durations=20 -vv
diff --git a/README.md b/README.md
index 21634d0b..a8d1fcc4 100644
--- a/README.md
+++ b/README.md
@@ -52,6 +52,7 @@ git clone https://github.com/pytorch/torchtitan
 cd torchtitan
 pip install -r requirements.txt
 pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 # or cu118
+pip3 install --pre torchdata --index-url https://download.pytorch.org/whl/nightly
 ```

 ### Downloading a tokenizer
diff --git a/pyproject.toml b/pyproject.toml
index 2a8f9557..a5c1b72f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -11,7 +11,7 @@ authors = [
 keywords = ["pytorch", "training", "llm"]
 dependencies = [
     # Hugging Face integrations
-    "datasets",
+    "datasets>=2.19.0",

     # Tokenization
     "blobfile",
diff --git a/test/__init__.py b/test/__init__.py
index e69de29b..2e41cd71 100644
--- a/test/__init__.py
+++ b/test/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/test/datasets/__init__.py b/test/datasets/__init__.py
new file mode 100644
index 00000000..2e41cd71
--- /dev/null
+++ b/test/datasets/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/test/datasets/test_checkpoint.py b/test/datasets/test_checkpoint.py
new file mode 100644
index 00000000..6f04dd23
--- /dev/null
+++ b/test/datasets/test_checkpoint.py
@@ -0,0 +1,54 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from torchtitan.datasets.hf_datasets import build_hf_data_loader
+from torchtitan.datasets.tokenizer import create_tokenizer
+
+
+class TestCheckpoint:
+    def test_c4_resumption(self):
+        dataset_name = "c4_mini"
+        dataset_path = "./torchtitan/datasets/c4_mini"
+        batch_size = 1
+        seq_len = 1024
+        world_size = 4
+        rank = 0
+
+        dl = self._build_dataloader(
+            dataset_name, dataset_path, batch_size, seq_len, world_size, rank
+        )
+
+        it = iter(dl)
+        for _ in range(250):
+            next(it)
+        state = dl.state_dict()
+        expected_input_ids, expected_labels = next(it)
+
+        # Create a new dataloader, restore the checkpoint, and check that the next batch matches the one above
+        dl = self._build_dataloader(
+            dataset_name, dataset_path, batch_size, seq_len, world_size, rank
+        )
+        dl.load_state_dict(state)
+        input_ids, labels = next(iter(dl))
+
+        assert torch.equal(input_ids, expected_input_ids)
+        assert torch.equal(labels, expected_labels)
+
+    def _build_dataloader(
+        self, dataset_name, dataset_path, batch_size, seq_len, world_size, rank
+    ):
+        tokenizer_type = "tiktoken"
+        tokenizer = create_tokenizer(tokenizer_type, "./test/assets/test_tiktoken.model")
+        return build_hf_data_loader(
+            dataset_name=dataset_name,
+            dataset_path=dataset_path,
+            tokenizer=tokenizer,
+            batch_size=batch_size,
+            seq_len=seq_len,
+            world_size=world_size,
+            rank=rank,
+        )
diff --git a/torchtitan/checkpoint.py b/torchtitan/checkpoint.py
index 81bdf592..fb7c41c8 100644
--- a/torchtitan/checkpoint.py
+++ b/torchtitan/checkpoint.py
@@ -22,6 +22,7 @@
     set_optimizer_state_dict,
 )
 from torch.distributed.checkpoint.stateful import Stateful
+from torch.utils.data import DataLoader

 from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
 from torchtitan.logging_utils import init_logger, logger
@@ -103,6 +104,7 @@ def __init__(
         model: nn.Module,
         optimizer: torch.optim.Optimizer,
         lr_scheduler: torch.optim.lr_scheduler.LRScheduler,
+        dataloader: DataLoader,
         states: Dict[str, Any],
         job_config: JobConfig,
     ) -> None:
@@ -118,6 +120,7 @@ def __init__(
                 "model": ModelWrapper(model),
                 "optimizer": OptimizerWrapper(model, optimizer),
                 "lr_scheduler": lr_scheduler,
+                "dataloader": dataloader,
             }
         )
diff --git a/torchtitan/datasets/hf_datasets.py b/torchtitan/datasets/hf_datasets.py
index f6d09faa..d0306663 100644
--- a/torchtitan/datasets/hf_datasets.py
+++ b/torchtitan/datasets/hf_datasets.py
@@ -4,10 +4,13 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-from typing import List, Optional
+import pickle
+from typing import Any, Dict, List, Optional

 import torch
-from torch.utils.data import DataLoader, IterableDataset
+from torch.distributed.checkpoint.stateful import Stateful
+from torch.utils.data import IterableDataset
+from torchdata.stateful_dataloader import StatefulDataLoader

 from torchtitan.datasets.tokenizer import Tokenizer
 from torchtitan.logging_utils import logger
@@ -23,7 +26,7 @@
 }


-class HuggingFaceDataset(IterableDataset):
+class HuggingFaceDataset(IterableDataset, Stateful):
     """PyTorch Representation of the HuggingFace Dataset.

     Args:
@@ -99,32 +102,90 @@ def __init__(
         self.seq_len = seq_len
         self.infinite = infinite

+        # variables for checkpointing
+        self._sample_idx = 0
+        self._all_tokens: List[int] = []
+
     def __iter__(self):
         max_buffer_token_len = 1 + self.seq_len
-        all_tokens: List[int] = []

         while True:
-            for sample in iter(self._data):
+            for sample in self._get_data_iter():
                 sample_text = sample["text"]
                 sample_tokens = self._tokenizer.encode(sample_text, bos=True, eos=True)
-                all_tokens.extend(sample_tokens)
+                self._all_tokens.extend(sample_tokens)
+                self._sample_idx += 1

-                while len(all_tokens) >= max_buffer_token_len:
-                    x = torch.LongTensor(all_tokens[:max_buffer_token_len])
+                while len(self._all_tokens) >= max_buffer_token_len:
+                    x = torch.LongTensor(self._all_tokens[:max_buffer_token_len])
                     # update tokens to the remaining tokens
-                    all_tokens = all_tokens[max_buffer_token_len:]
+                    self._all_tokens = self._all_tokens[max_buffer_token_len:]
                     input = x[:-1]
                     label = x[1:]
                     yield input, label
+
             if not self.infinite:
                 logger.warning(f"Dataset {self.dataset_name} has run out of data.")
                 break
             else:
+                # Reset offset for the next iteration
+                self._sample_idx = 0
                 logger.warning(
                     f"Dataset {self.dataset_name} is being re-looped. "
                     "Loss related metrics might be misleading."
                 )

+    def _get_data_iter(self):
+        if self._sample_idx == 0:
+            return iter(self._data)
+
+        # Skip samples
+        if isinstance(self._data, IterableDataset):
+            it = iter(self._data)
+            # Naively iterate through the samples as skip may not be supported
+            for _ in range(self._sample_idx):
+                next(it)
+            return it
+
+        # As skipping to the end throws an error in case of map-style dataset, return an empty iterator
+        if self._sample_idx == len(self._data):
+            return iter([])
+        return iter(self._data.skip(self._sample_idx))
+
+    def load_state_dict(self, state_dict):
+        self._sample_idx = state_dict["sample_idx"]
+        self._all_tokens = state_dict["token_buffer"]
+
+    def state_dict(self):
+        return {"token_buffer": self._all_tokens, "sample_idx": self._sample_idx}
+
+
+class DPAwareDataLoader(StatefulDataLoader, Stateful):
+    """
+    A wrapper around the StatefulDataLoader that ensures that the state is stored only once per DP rank.
+    """
+
+    def __init__(self, dp_rank: int, hf_ds: IterableDataset, batch_size: int):
+        super().__init__(hf_ds, batch_size)
+        self._dp_rank = dp_rank
+        self._rank_id = f"dp_rank_{dp_rank}"
+
+    def state_dict(self) -> Dict[str, Any]:
+        # Store state only for dp rank to avoid replicating the same state across other dimensions
+        return {self._rank_id: pickle.dumps(super().state_dict())}
+
+    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+        # State being empty is valid, don't log a warning
+        if not state_dict:
+            return
+
+        if self._rank_id not in state_dict:
+            logger.warning(
+                f"DataLoader state is empty for dp rank {self._dp_rank}, expected key {self._rank_id}."
+            )
+            return
+        super().load_state_dict(pickle.loads(state_dict[self._rank_id]))
+

 def build_hf_data_loader(
     dataset_name: str,
@@ -140,4 +201,4 @@ def build_hf_data_loader(
         dataset_name, dataset_path, tokenizer, seq_len, world_size, rank, infinite
     )

-    return DataLoader(hf_ds, batch_size=batch_size)
+    return DPAwareDataLoader(rank, hf_ds, batch_size=batch_size)
diff --git a/train.py b/train.py
index 318c7174..a0bb337e 100644
--- a/train.py
+++ b/train.py
@@ -245,6 +245,7 @@ def loss_fn(pred, labels):
         model=model,
         optimizer=optimizer,
         lr_scheduler=scheduler,
+        dataloader=data_loader,
         states={"train_state": train_state},
         job_config=job_config,
     )