Commit
fix scheduler steps without optimization
helpmefindaname authored and Benedikt Fuchs committed Aug 7, 2023
1 parent 91861bd commit abcc759
Showing 2 changed files with 7 additions and 1 deletion.
5 changes: 4 additions & 1 deletion flair/trainers/plugins/functional/linear_scheduler.py
@@ -61,12 +61,15 @@ def before_training_epoch(self, **kw):
         self.previous_learning_rate = self.current_learning_rate
 
     @TrainerPlugin.hook
-    def after_training_batch(self, **kw):
+    def after_training_batch(self, optimizer_was_run: bool, **kw):
         """Do the scheduler step if one-cycle or linear decay.
         :param kw:
         :return:
         """
+        # skip if no optimization has happened.
+        if not optimizer_was_run:
+            return
         self.scheduler.step()
         self.store_learning_rate()

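For context, the plugin-side change means the learning-rate scheduler only advances when the batch actually produced an optimizer update. Below is a minimal, self-contained sketch of that pattern; `SchedulerStepper` and its fields are illustrative stand-ins, not Flair's actual `LinearSchedulerPlugin` implementation.

```python
# Minimal sketch of the hook pattern above (illustrative, not Flair's plugin API):
# the scheduler steps only when the optimizer really ran, so the LR schedule
# stays aligned with actual parameter updates.
from torch.optim.lr_scheduler import LinearLR


class SchedulerStepper:
    """Illustrative stand-in for the after_training_batch hook."""

    def __init__(self, scheduler: LinearLR) -> None:
        self.scheduler = scheduler
        self.learning_rates: list[float] = []

    def after_training_batch(self, optimizer_was_run: bool, **kw) -> None:
        # Skip if no optimization has happened (e.g. AMP skipped the step).
        if not optimizer_was_run:
            return
        self.scheduler.step()
        self.learning_rates.append(self.scheduler.get_last_lr()[0])
```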
3 changes: 3 additions & 0 deletions flair/trainers/trainer.py
@@ -595,8 +595,11 @@ def train_custom(
     # do the optimizer step
     scaler.unscale_(self.optimizer)
     torch.nn.utils.clip_grad_norm_(self.model.parameters(), 5.0)
+    scale_before = scaler.get_scale()
     scaler.step(self.optimizer)
     scaler.update()
+    scale_after = scaler.get_scale()
+    batch_kw["optimizer_was_run"] = scale_before <= scale_after
 
     if batch_train_samples > 0:
         train_loss = batch_train_loss / batch_train_samples

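The trainer-side change relies on how `torch.cuda.amp.GradScaler` behaves: when it finds inf/NaN gradients it silently skips `optimizer.step()` and lowers the scale in `update()`, so an unchanged or grown scale means the step really ran. The following standalone sketch shows that detection in isolation; the function name and signature are assumptions for illustration, not Flair's API.

```python
import torch


def amp_optimizer_step(loss: torch.Tensor,
                       model: torch.nn.Module,
                       optimizer: torch.optim.Optimizer,
                       scaler: torch.cuda.amp.GradScaler) -> bool:
    """Run one AMP optimizer step and report whether it was actually applied."""
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)  # unscale so clipping sees the true gradients
    torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)

    scale_before = scaler.get_scale()
    scaler.step(optimizer)   # silently skipped on inf/NaN gradients
    scaler.update()          # lowers the scale if the step was skipped
    scale_after = scaler.get_scale()

    optimizer.zero_grad()
    # An unchanged (or grown) scale means the optimizer step was applied.
    return scale_before <= scale_after
```

A training loop can forward the returned flag to its after-batch hooks, exactly as `batch_kw["optimizer_was_run"]` does above, so schedulers and logging only advance on real optimization steps.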