diff --git a/mace/tools/train.py b/mace/tools/train.py index 87a958f9..32231acf 100644 --- a/mace/tools/train.py +++ b/mace/tools/train.py @@ -244,12 +244,12 @@ def train( if valid_loss >= lowest_loss: patience_counter += 1 - if patience_counter >= patience and (swa.start is not None and epoch < swa.start): + if patience_counter >= patience and (swa is not None and epoch < swa.start): logging.info( f"Stopping optimization after {patience_counter} epochs without improvement and starting swa" ) epoch = swa.start - elif patience_counter >= patience and (swa.start is None or epoch >= swa.start): + elif patience_counter >= patience and (swa is None or epoch >= swa.start): logging.info( f"Stopping optimization after {patience_counter} epochs without improvement" )