diff --git a/src/train_egnn.py b/src/train_egnn.py index 9a76193..b23ccab 100644 --- a/src/train_egnn.py +++ b/src/train_egnn.py @@ -909,8 +909,9 @@ def train_model(model, train_loader, val_loader, num_epochs, learning_rate, devi train_loss = train_one_epoch(model, train_loader, optimizer, device, use_pointnet, log_interval, beta) # Validate every few epochs (e.g., every 5 epochs) - if (epoch + 1) % 5 == 0: - val_loss = validate(model, val_loader, device, use_pointnet) + if (epoch + 1) % 1 == 0: + val_loss, val_pose_loss, val_corr_loss = validate(model, val_loader, device, use_pointnet) + print(val_loss) print(f'Epoch {epoch + 1}/{num_epochs} - Training Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}') else: print(f'Epoch {epoch + 1}/{num_epochs} - Training Loss: {train_loss:.4f}')