From 4f873d099be5fca843f4a8d0b01413cac43a2538 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Fri, 24 Nov 2023 11:57:53 -0800 Subject: [PATCH 1/3] Add ability to restart on new epoch You can set the epoch via the option `--epoch=[INTEGER]`. This automatically handles changing the data order each epoch by setting the data seed to `seed + epoch`. So `--epoch` is the only flag you need to set when restarting on a new epoch. Everything else in the config can stay the same. Note that we count epochs starting from 0. So to start the 2nd epoch you would add the flag `--epoch=1`. --- olmo/config.py | 5 +++++ olmo/data/__init__.py | 2 +- olmo/train.py | 43 ++++++++++++++++++++++++++++--------------- scripts/train.py | 1 + 4 files changed, 35 insertions(+), 16 deletions(-) diff --git a/olmo/config.py b/olmo/config.py index 4bed7b7bb..6f4abb50e 100644 --- a/olmo/config.py +++ b/olmo/config.py @@ -659,6 +659,11 @@ class TrainConfig(BaseConfig): Used to seed all initial RNG states. """ + epoch: int = 0 + """ + Increment this when starting a new epoch. + """ + dry_run: bool = False """ If ``True``, don't actually train. diff --git a/olmo/data/__init__.py b/olmo/data/__init__.py index 1c4a5a884..c78fd4b48 100644 --- a/olmo/data/__init__.py +++ b/olmo/data/__init__.py @@ -93,7 +93,7 @@ def build_train_dataloader(train_config: TrainConfig) -> DataLoader: IterableDataset( dataset, # type: ignore train_config.global_train_batch_size, - seed=train_config.seed, + seed=train_config.seed + train_config.epoch, shuffle=True, drop_last=train_config.data.drop_last, max_examples=train_config.global_train_batch_size * train_config.max_duration, diff --git a/olmo/train.py b/olmo/train.py index a68cd5599..4524ac7a3 100644 --- a/olmo/train.py +++ b/olmo/train.py @@ -101,12 +101,15 @@ class Trainer: train_loader: DataLoader device: torch.device evaluators: List[Evaluator] + epoch: int = 0 global_step: int = 0 global_data_step: int = 0 """This is now redundant since adding 'global_train_examples_seen'.""" + global_train_examples_seen_this_epoch: int = 0 + """Tracks the global number of training examples seen in the current epoch for the purpose of restoring + the data loader position on restarts.""" global_train_examples_seen: int = 0 - """Tracks the global number of training examples seen for the purpose of restoring the dataset - position on restarts.""" + """Tracks the global number of training examples throughout training.""" global_train_tokens_seen: int = 0 """Tracks the global total number of tokens trained on.""" checkpoints: List[Path] = field(default_factory=list) @@ -118,8 +121,10 @@ class Trainer: def trainer_state_dict(self) -> Dict[str, Any]: return { + "epoch": self.epoch, "global_step": self.global_step, "global_data_step": self.global_data_step, + "global_train_examples_seen_this_epoch": self.global_train_examples_seen_this_epoch, "global_train_examples_seen": self.global_train_examples_seen, "global_train_tokens_seen": self.global_train_tokens_seen, "world_size": get_world_size(), @@ -147,40 +152,47 @@ def load_trainer_state_dict(self, state_dict: Dict[str, Any]) -> None: ] # Dataset / dataloader position. 
+ checkpoint_epoch = state_dict.get("epoch", 0) self.global_step = state_dict["global_step"] self.global_data_step = state_dict["global_data_step"] self.global_train_examples_seen = state_dict.get( # newer addition "global_train_examples_seen", self.global_data_step * self.cfg.global_train_batch_size ) + self.global_train_examples_seen_this_epoch = state_dict.get( + "global_train_examples_seen_this_epoch", + self.global_train_examples_seen, + ) self.global_train_tokens_seen = state_dict.get( # newer addition "global_train_tokens_seen", self.global_data_step * self.cfg.global_train_batch_size * self.cfg.model.max_sequence_length, ) + if not self.cfg.restore_dataloader: + self.epoch = 0 self.global_data_step = 0 self.global_train_examples_seen = 0 self.global_train_tokens_seen = 0 - elif self.cfg.fast_forward_batches: + self.global_train_examples_seen_this_epoch = 0 + elif checkpoint_epoch != self.epoch: + log.info(f"Starting new epoch (epoch = {self.epoch})") + self.global_train_examples_seen_this_epoch = 0 + + if self.cfg.fast_forward_batches: + log.info(f"Fast-forwarding data loader by {self.cfg.fast_forward_batches:,d} steps") self.global_data_step += self.cfg.fast_forward_batches # Technically we don't "see" these batches that we fast-forward through, but we use # this variable to update the position of the dataset so we need to include them here. self.global_train_examples_seen += self.cfg.fast_forward_batches * self.cfg.global_train_batch_size + self.global_train_examples_seen_this_epoch += ( + self.cfg.fast_forward_batches * self.cfg.global_train_batch_size + ) # NOTE: on the other hand we don't add anything to 'self.global_train_tokens_seen' here because # that variable is meant to track the actual number of tokens trained on. - if self.global_data_step > 0: - if self.global_data_step > self.global_step: - log.info( - f"Fast-forwarding data loader to step {self.global_step:,d}+{self.global_data_step-self.global_step:,d} " - f"({self.global_train_examples_seen:,d} examples)" - ) - else: - log.info( - f"Fast-forwarding data loader to step {self.global_data_step:,d} " - f"({self.global_train_examples_seen:,d} examples)" - ) + if self.global_train_examples_seen_this_epoch > 0: assert isinstance(self.train_loader.dataset, IterableDataset) - self.train_loader.dataset.start_index = self.global_train_examples_seen + log.info(f"Data loader will start at instance index {self.global_train_examples_seen_this_epoch:,d}") + self.train_loader.dataset.start_index = self.global_train_examples_seen_this_epoch # Reset learning rate and weight decay to the values from the config, not the checkpoint. log.info("Resetting learning rate...") @@ -790,6 +802,7 @@ def on_trace_ready(p): global_batch_size = batch_size * get_world_size() # assumes batch size equal across ranks self.global_step += 1 self.global_data_step += 1 + self.global_train_examples_seen_this_epoch += global_batch_size self.global_train_examples_seen += global_batch_size self.global_train_tokens_seen += global_batch_size * seq_len speed_monitor.batch_start( diff --git a/scripts/train.py b/scripts/train.py index 951664388..6d3b3652f 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -165,6 +165,7 @@ def dummy_init_fn(module: torch.nn.Module) -> None: # Consolidate components into `Trainer` object. 
with Trainer( cfg=cfg, + epoch=cfg.epoch, model=olmo_model, fsdp_model=fsdp_model, optim=optim, From 4d6e61ca00f05673b8c639a3c50cd8a5fe777a10 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Mon, 27 Nov 2023 10:31:31 -0800 Subject: [PATCH 2/3] Remove redundant `Trainer` fields `global_train_examples_seen` and `global_data_step` no longer needed --- olmo/train.py | 27 ++++++++------------------- 1 file changed, 8 insertions(+), 19 deletions(-) diff --git a/olmo/train.py b/olmo/train.py index 4524ac7a3..fb1f57dbe 100644 --- a/olmo/train.py +++ b/olmo/train.py @@ -103,13 +103,9 @@ class Trainer: evaluators: List[Evaluator] epoch: int = 0 global_step: int = 0 - global_data_step: int = 0 - """This is now redundant since adding 'global_train_examples_seen'.""" global_train_examples_seen_this_epoch: int = 0 """Tracks the global number of training examples seen in the current epoch for the purpose of restoring the data loader position on restarts.""" - global_train_examples_seen: int = 0 - """Tracks the global number of training examples throughout training.""" global_train_tokens_seen: int = 0 """Tracks the global total number of tokens trained on.""" checkpoints: List[Path] = field(default_factory=list) @@ -123,9 +119,7 @@ def trainer_state_dict(self) -> Dict[str, Any]: return { "epoch": self.epoch, "global_step": self.global_step, - "global_data_step": self.global_data_step, "global_train_examples_seen_this_epoch": self.global_train_examples_seen_this_epoch, - "global_train_examples_seen": self.global_train_examples_seen, "global_train_tokens_seen": self.global_train_tokens_seen, "world_size": get_world_size(), "checkpoints": self.checkpoints, @@ -154,23 +148,22 @@ def load_trainer_state_dict(self, state_dict: Dict[str, Any]) -> None: # Dataset / dataloader position. checkpoint_epoch = state_dict.get("epoch", 0) self.global_step = state_dict["global_step"] - self.global_data_step = state_dict["global_data_step"] - self.global_train_examples_seen = state_dict.get( # newer addition - "global_train_examples_seen", self.global_data_step * self.cfg.global_train_batch_size - ) self.global_train_examples_seen_this_epoch = state_dict.get( "global_train_examples_seen_this_epoch", - self.global_train_examples_seen, + state_dict.get( # for backwards compatibility + "global_train_examples_seen", + state_dict.get("global_data_step", 0) * self.cfg.global_train_batch_size, + ), ) - self.global_train_tokens_seen = state_dict.get( # newer addition + self.global_train_tokens_seen = state_dict.get( "global_train_tokens_seen", - self.global_data_step * self.cfg.global_train_batch_size * self.cfg.model.max_sequence_length, + state_dict.get("global_data_step", 0) # for backwards compatibility + * self.cfg.global_train_batch_size + * self.cfg.model.max_sequence_length, ) if not self.cfg.restore_dataloader: self.epoch = 0 - self.global_data_step = 0 - self.global_train_examples_seen = 0 self.global_train_tokens_seen = 0 self.global_train_examples_seen_this_epoch = 0 elif checkpoint_epoch != self.epoch: @@ -179,10 +172,8 @@ def load_trainer_state_dict(self, state_dict: Dict[str, Any]) -> None: if self.cfg.fast_forward_batches: log.info(f"Fast-forwarding data loader by {self.cfg.fast_forward_batches:,d} steps") - self.global_data_step += self.cfg.fast_forward_batches # Technically we don't "see" these batches that we fast-forward through, but we use # this variable to update the position of the dataset so we need to include them here. 
- self.global_train_examples_seen += self.cfg.fast_forward_batches * self.cfg.global_train_batch_size self.global_train_examples_seen_this_epoch += ( self.cfg.fast_forward_batches * self.cfg.global_train_batch_size ) @@ -801,9 +792,7 @@ def on_trace_ready(p): assert batch_size == self.cfg.device_train_batch_size global_batch_size = batch_size * get_world_size() # assumes batch size equal across ranks self.global_step += 1 - self.global_data_step += 1 self.global_train_examples_seen_this_epoch += global_batch_size - self.global_train_examples_seen += global_batch_size self.global_train_tokens_seen += global_batch_size * seq_len speed_monitor.batch_start( self.global_train_tokens_seen, From b8ca94d1ed3472e434bb42a0e8d9774d828c5d1c Mon Sep 17 00:00:00 2001 From: epwalsh Date: Mon, 27 Nov 2023 12:57:15 -0800 Subject: [PATCH 3/3] default to global_step --- olmo/train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/olmo/train.py b/olmo/train.py index fb1f57dbe..f2db74dbe 100644 --- a/olmo/train.py +++ b/olmo/train.py @@ -152,12 +152,12 @@ def load_trainer_state_dict(self, state_dict: Dict[str, Any]) -> None: "global_train_examples_seen_this_epoch", state_dict.get( # for backwards compatibility "global_train_examples_seen", - state_dict.get("global_data_step", 0) * self.cfg.global_train_batch_size, + state_dict.get("global_data_step", self.global_step) * self.cfg.global_train_batch_size, ), ) self.global_train_tokens_seen = state_dict.get( "global_train_tokens_seen", - state_dict.get("global_data_step", 0) # for backwards compatibility + state_dict.get("global_data_step", self.global_step) # for backwards compatibility * self.cfg.global_train_batch_size * self.cfg.model.max_sequence_length, )
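
Notes on the behavior introduced above (illustrative sketches, not part of the patches).

First, a minimal sketch of the per-epoch data ordering described in the first commit message, using only the standard library: the data loader is seeded with `seed + epoch`, so each epoch gets its own deterministic shuffle of the same examples, and restarting with `--epoch=1` both reshuffles the data and (because `global_train_examples_seen_this_epoch` is reset when the checkpoint's epoch differs from the configured one) starts the loader at index 0 of the new order. The helper name `epoch_order` is hypothetical.

    import random

    def epoch_order(num_examples: int, seed: int, epoch: int) -> list:
        # Mirrors `seed=train_config.seed + train_config.epoch` in
        # build_train_dataloader: same examples, a different deterministic
        # shuffle for each epoch.
        order = list(range(num_examples))
        random.Random(seed + epoch).shuffle(order)
        return order

    # Within one epoch the order is stable across restarts, so the dataset's
    # `start_index` can safely fast-forward into it; across epochs it changes.
    print(epoch_order(8, seed=6198, epoch=0))
    print(epoch_order(8, seed=6198, epoch=1))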
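
Second, a similar sketch of the backwards-compatible restore chain after the second and third commits, with a plain dict standing in for the trainer state and a hypothetical function name: it prefers the new per-epoch counter, falls back to the removed `global_train_examples_seen` and `global_data_step` fields, and when even `global_data_step` is absent derives the count from `global_step`, as in the final patch.

    from typing import Any, Dict

    def examples_seen_this_epoch(state_dict: Dict[str, Any], global_train_batch_size: int) -> int:
        # Fallback chain from `load_trainer_state_dict`:
        #   1. new checkpoints store the per-epoch counter directly;
        #   2. older checkpoints stored a cumulative example count;
        #   3. the oldest stored only `global_data_step` (or not even that, in
        #      which case `global_step` is used) and the count is reconstructed
        #      by multiplying with the global batch size.
        global_step = state_dict["global_step"]
        return state_dict.get(
            "global_train_examples_seen_this_epoch",
            state_dict.get(
                "global_train_examples_seen",
                state_dict.get("global_data_step", global_step) * global_train_batch_size,
            ),
        )

    # An old-style checkpoint that predates all of the example counters:
    old_state = {"global_step": 1000}
    print(examples_seen_this_epoch(old_state, global_train_batch_size=2048))  # 2048000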