From 5f825c7dea81563b0556ffe0720db3770ce25183 Mon Sep 17 00:00:00 2001
From: Gokul Gunasekaran
Date: Tue, 21 May 2024 12:56:13 -0700
Subject: [PATCH] store state only for dp rank

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
---
 ...taset_checkpoint.py => test_checkpoint.py} | 20 +++++------
 torchtitan/checkpoint.py                      | 33 ++-----------------
 torchtitan/datasets/hf_datasets.py            | 32 ++++++++++++++++--
 3 files changed, 41 insertions(+), 44 deletions(-)
 rename test/datasets/{test_dataset_checkpoint.py => test_checkpoint.py} (74%)

diff --git a/test/datasets/test_dataset_checkpoint.py b/test/datasets/test_checkpoint.py
similarity index 74%
rename from test/datasets/test_dataset_checkpoint.py
rename to test/datasets/test_checkpoint.py
index e4be71d6..6f04dd23 100644
--- a/test/datasets/test_dataset_checkpoint.py
+++ b/test/datasets/test_checkpoint.py
@@ -5,12 +5,11 @@
 # LICENSE file in the root directory of this source tree.
 
 import torch
-from torchtitan.checkpoint import DataLoaderWrapper
 from torchtitan.datasets.hf_datasets import build_hf_data_loader
 from torchtitan.datasets.tokenizer import create_tokenizer
 
 
-class TestDatasetCheckpoint:
+class TestCheckpoint:
     def test_c4_resumption(self):
         dataset_name = "c4_mini"
         dataset_path = "./torchtitan/datasets/c4_mini"
@@ -19,32 +18,32 @@ def test_c4_resumption(self):
         world_size = 4
         rank = 0
 
-        dl_wrapper = self._create_dataloader_wrapper(
+        dl = self._build_dataloader(
             dataset_name, dataset_path, batch_size, seq_len, world_size, rank
         )
 
-        it = iter(dl_wrapper.dataloader)
+        it = iter(dl)
         for _ in range(250):
             next(it)
-        state = dl_wrapper.state_dict()
+        state = dl.state_dict()
         expected_input_ids, expected_labels = next(it)
 
         # Create new dataloader, restore checkpoint, and check if next data yielded is the same as above
-        dl_wrapper = self._create_dataloader_wrapper(
+        dl = self._build_dataloader(
             dataset_name, dataset_path, batch_size, seq_len, world_size, rank
         )
-        dl_wrapper.load_state_dict(state)
-        input_ids, labels = next(iter(dl_wrapper.dataloader))
+        dl.load_state_dict(state)
+        input_ids, labels = next(iter(dl))
 
         assert torch.equal(input_ids, expected_input_ids)
         assert torch.equal(labels, expected_labels)
 
-    def _create_dataloader_wrapper(
+    def _build_dataloader(
         self, dataset_name, dataset_path, batch_size, seq_len, world_size, rank
     ):
         tokenizer_type = "tiktoken"
         tokenizer = create_tokenizer("tiktoken", "./test/assets/test_tiktoken.model")
-        dataloader = build_hf_data_loader(
+        return build_hf_data_loader(
             dataset_name=dataset_name,
             dataset_path=dataset_path,
             tokenizer=tokenizer,
@@ -53,4 +52,3 @@ def _create_dataloader_wrapper(
             world_size=4,
             rank=0,
         )
-        return DataLoaderWrapper(dataloader)
diff --git a/torchtitan/checkpoint.py b/torchtitan/checkpoint.py
index 61dce0c2..692a3183 100644
--- a/torchtitan/checkpoint.py
+++ b/torchtitan/checkpoint.py
@@ -6,7 +6,6 @@
 
 import enum
 import os
-import pickle
 import re
 import time
 from multiprocessing import get_context
@@ -23,9 +22,8 @@
     set_optimizer_state_dict,
 )
 from torch.distributed.checkpoint.stateful import Stateful
-from torch.utils.data import DataLoader
-from torchdata.stateful_dataloader import StatefulDataLoader
 from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
+from torch.utils.data import DataLoader
 from torchtitan.logging_utils import init_logger, logger
 
 
@@ -63,33 +61,6 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
         set_optimizer_state_dict(self.model, self.optim, optim_state_dict=state_dict)
 
 
-class DataLoaderWrapper(Stateful):
-    def __init__(self, dataloader: DataLoader) -> None:
-        self.dataloader = dataloader
-        # Use global rank for now even though dataloader state could be same across dp groups
-        self.rank_id = str(
-            dist.get_rank() if (dist.is_available() and dist.is_initialized()) else 0
-        )
-
-    def state_dict(self) -> Dict[str, Any]:
-        if isinstance(self.dataloader, StatefulDataLoader):
-            return {self.rank_id: pickle.dumps(self.dataloader.state_dict())}
-        return {}
-
-    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
-        if isinstance(self.dataloader, StatefulDataLoader):
-            # State is empty
-            if not state_dict:
-                return
-
-            if self.rank_id not in state_dict:
-                logger.warning(f"DataLoader state is empty for rank {self.rank_id}. ")
-                return
-
-            # Load state for the current rank
-            self.dataloader.load_state_dict(pickle.loads(state_dict[self.rank_id]))
-
-
 class Terminate:
     pass
 
@@ -149,7 +120,7 @@ def __init__(
                 "model": ModelWrapper(model),
                 "optimizer": OptimizerWrapper(model, optimizer),
                 "lr_scheduler": lr_scheduler,
-                "dataloader": DataLoaderWrapper(dataloader),
+                "dataloader": dataloader,
             }
         )
 
diff --git a/torchtitan/datasets/hf_datasets.py b/torchtitan/datasets/hf_datasets.py
index a98c9467..c8e115c5 100644
--- a/torchtitan/datasets/hf_datasets.py
+++ b/torchtitan/datasets/hf_datasets.py
@@ -4,7 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import List, Optional
+import pickle
+from typing import Any, Dict, List, Optional
 
 import torch
 from torch.distributed.checkpoint.stateful import Stateful
@@ -159,6 +160,33 @@ def state_dict(self):
         return {"token_buffer": self._all_tokens, "sample_idx": self._sample_idx}
 
 
+class DpAwareDataLoader(StatefulDataLoader):
+    """
+    A wrapper around the StatefulDataLoader that ensures that the state is stored only once for DP ranks.
+    """
+
+    def __init__(self, dp_rank: int, hf_ds: IterableDataset, batch_size: int):
+        super().__init__(hf_ds, batch_size)
+        self._dp_rank = dp_rank
+        self._rank_id = f"dp_rank_{dp_rank}"
+
+    def state_dict(self) -> Dict[str, Any]:
+        # Store state only for dp rank to avoid replicating the same state across other dimensions
+        return {self._rank_id: pickle.dumps(super().state_dict())}
+
+    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+        # State being empty is valid, don't log a warning
+        if not state_dict:
+            return
+
+        if self._rank_id not in state_dict:
+            logger.warning(
+                f"DataLoader state is empty for dp rank {self._dp_rank}, expected key {self._rank_id}."
+            )
+            return
+        super().load_state_dict(pickle.loads(state_dict[self._rank_id]))
+
+
 def build_hf_data_loader(
     dataset_name: str,
     dataset_path: Optional[str],
@@ -173,4 +201,4 @@ def build_hf_data_loader(
         dataset_name, dataset_path, tokenizer, seq_len, world_size, rank, infinite
     )
 
-    return StatefulDataLoader(hf_ds, batch_size=batch_size)
+    return DpAwareDataLoader(rank, hf_ds, batch_size=batch_size)
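
For reference, below is a minimal, self-contained sketch (not part of the patch) of the resumption contract that `DpAwareDataLoader` and the renamed test rely on. It assumes torchdata's `StatefulDataLoader`; `CountingDataset` is a hypothetical stand-in for the tokenized `HuggingFaceDataset` built in `torchtitan/datasets/hf_datasets.py`.

```python
import pickle
from typing import Any, Dict

import torch
from torch.utils.data import IterableDataset
from torchdata.stateful_dataloader import StatefulDataLoader


class CountingDataset(IterableDataset):
    """Toy stateful dataset: yields consecutive integers and can resume mid-stream."""

    def __init__(self, length: int = 10_000) -> None:
        self._length = length
        self._idx = 0

    def __iter__(self):
        while self._idx < self._length:
            yield self._idx
            self._idx += 1

    def state_dict(self) -> Dict[str, Any]:
        return {"idx": self._idx}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self._idx = state_dict["idx"]


class DpAwareDataLoader(StatefulDataLoader):
    """Same idea as the class added to hf_datasets.py: key the loader state by DP rank."""

    def __init__(self, dp_rank: int, dataset: IterableDataset, batch_size: int) -> None:
        super().__init__(dataset, batch_size=batch_size)
        self._rank_id = f"dp_rank_{dp_rank}"

    def state_dict(self) -> Dict[str, Any]:
        # One entry per DP rank; peers in the same DP group would write the same key.
        return {self._rank_id: pickle.dumps(super().state_dict())}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        if not state_dict or self._rank_id not in state_dict:
            return
        super().load_state_dict(pickle.loads(state_dict[self._rank_id]))


if __name__ == "__main__":
    dl = DpAwareDataLoader(dp_rank=0, dataset=CountingDataset(), batch_size=4)
    it = iter(dl)
    for _ in range(10):
        next(it)
    state = dl.state_dict()  # {"dp_rank_0": <pickled StatefulDataLoader state>}
    expected = next(it)

    # A fresh loader restored from `state` should yield the same next batch.
    resumed = DpAwareDataLoader(dp_rank=0, dataset=CountingDataset(), batch_size=4)
    resumed.load_state_dict(state)
    assert torch.equal(next(iter(resumed)), expected)
```

Keying the state by `dp_rank_{rank}` rather than by global rank means that ranks consuming the same data-parallel shard produce identical entries, so the checkpoint keeps one copy of the dataloader state per DP rank instead of one per global rank.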