Skip to content

Commit

Permalink
Merge pull request #449 from ACEsuit/multihead_validation_print
Browse files Browse the repository at this point in the history
validation loss printed during training is not actually for the specific head
  • Loading branch information
ilyes319 authored Jun 6, 2024
2 parents a92dba2 + 9616a41 commit 14ec4ec
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions mace/tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def train(
)
valid_loss += valid_loss_head
valid_err_log(
valid_loss, eval_metrics, logger, log_errors, None, valid_loader_name
valid_loss_head, eval_metrics, logger, log_errors, None, valid_loader_name
)

while epoch < max_num_epochs:
Expand Down Expand Up @@ -224,7 +224,7 @@ def train(
)
valid_loss += valid_loss_head
valid_err_log(
valid_loss,
valid_loss_head,
eval_metrics,
logger,
log_errors,
Expand Down

0 comments on commit 14ec4ec

Please sign in to comment.