Fix consumed_samples which was off by one batch
Signed-off-by: Olivier Delalleau <[email protected]>
odelalleau committed Aug 12, 2023
1 parent ada4fe5 commit e6fda0f
Showing 5 changed files with 8 additions and 10 deletions.

@@ -532,6 +532,10 @@ def compute_consumed_samples(self, steps_since_resume=0):
         )
         return int(consumed_samples)
 
+    def _compute_consumed_samples_after_training_step(self):
+        # Add +1 to account for the current batch, which is not counted yet in `trainer.global_step`.
+        return self.compute_consumed_samples(self.trainer.global_step + 1 - self.init_global_step)
+
     def _extract_consumed_samples_from_ckpt(self, ckpt_path):
         try:
             init_consumed_samples = int(float(re.findall(r"consumed_samples\=([0-9]+.[0-9]+)", ckpt_path)[0]))
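
The new helper above encodes the reasoning behind the fix: when `training_step` runs, PyTorch Lightning has not yet incremented `trainer.global_step`, so deriving consumed samples from `global_step - init_global_step` lags behind by exactly one global batch. Below is a minimal, hypothetical sketch of that arithmetic; the `global_batch_size` attribute and the simplified `compute_consumed_samples` are assumptions for illustration, not the repository's actual implementation (which derives the batch size from its parallelism configuration).

from types import SimpleNamespace

# Hypothetical, simplified sketch of the off-by-one; not the actual NeMo code.
class ConsumedSamplesSketch:
    def __init__(self, trainer, global_batch_size, init_consumed_samples=0, init_global_step=0):
        self.trainer = trainer
        self.global_batch_size = global_batch_size  # assumed: one global batch consumed per optimizer step
        self.init_consumed_samples = init_consumed_samples
        self.init_global_step = init_global_step

    def compute_consumed_samples(self, steps_since_resume=0):
        # Consumed samples grow by one global batch per completed optimizer step.
        return int(self.init_consumed_samples + steps_since_resume * self.global_batch_size)

    def _compute_consumed_samples_after_training_step(self):
        # Inside training_step, trainer.global_step does not yet count the batch
        # currently being processed, so +1 accounts for it.
        return self.compute_consumed_samples(self.trainer.global_step + 1 - self.init_global_step)

# During the very first training step, global_step is still 0:
trainer = SimpleNamespace(global_step=0)
model = ConsumedSamplesSketch(trainer, global_batch_size=256)
print(model.compute_consumed_samples(trainer.global_step))    # 0   <- old logging, one batch behind
print(model._compute_consumed_samples_after_training_step())  # 256 <- fixed logging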

@@ -382,10 +382,7 @@ def training_step(self, dataloader_iter, batch_idx):
         self.log('lr', lr, batch_size=1)
         self.log('global_step', self.trainer.global_step, prog_bar=True, batch_size=1)
         self.log(
-            'consumed_samples',
-            self.compute_consumed_samples(self.trainer.global_step - self.init_global_step),
-            prog_bar=True,
-            batch_size=1,
+            'consumed_samples', self._compute_consumed_samples_after_training_step(), prog_bar=True, batch_size=1,
         )
 
         return loss_mean[0]

@@ -627,7 +627,7 @@ def training_step(self, dataloader_iter, batch_idx):
             'global_step', self.trainer.global_step, prog_bar=True, rank_zero_only=True, batch_size=1,
         )
 
-        consumed_samples = self.compute_consumed_samples(self.trainer.global_step - self.init_global_step)
+        consumed_samples = self._compute_consumed_samples_after_training_step()
         # TODO: make sure compute_consumed_samples works for pipeline parallelism
         self.log(
             'consumed_samples', consumed_samples, prog_bar=True, rank_zero_only=True, batch_size=1,

@@ -402,7 +402,7 @@ def training_step(self, dataloader_iter, batch_idx):
         # TODO: make sure compute_consumed_samples works for pipeline parallelism
         self.log(
             'consumed_samples',
-            self.compute_consumed_samples(self.trainer.global_step - self.init_global_step),
+            self._compute_consumed_samples_after_training_step(),
             prog_bar=True,
             rank_zero_only=True,
             batch_size=1,

@@ -300,10 +300,7 @@ def training_step(self, batch, batch_idx):
         self.log('lr', lr, batch_size=1)
         self.log('global_step', self.trainer.global_step, prog_bar=True, batch_size=1)
         self.log(
-            'consumed_samples',
-            self.compute_consumed_samples(self.trainer.global_step - self.init_global_step),
-            prog_bar=True,
-            batch_size=1,
+            'consumed_samples', self._compute_consumed_samples_after_training_step(), prog_bar=True, batch_size=1,
         )
         self._reduced_loss_buffer = []
         return lm_loss