From d6fbf9012ad2a60f0ac4e2b801d258a16250d74c Mon Sep 17 00:00:00 2001
From: Optimox
Date: Fri, 12 Jun 2020 17:57:51 +0200
Subject: [PATCH] fix: verbosity with schedulers

---
 census_example.ipynb        | 15 ++++++++++++++-
 pytorch_tabnet/tab_model.py | 31 +++++++++++++++++--------------
 2 files changed, 31 insertions(+), 15 deletions(-)

diff --git a/census_example.ipynb b/census_example.ipynb
index 6bea6ff4..fcaa1c24 100755
--- a/census_example.ipynb
+++ b/census_example.ipynb
@@ -151,7 +151,10 @@
     " cat_dims=cat_dims,\n",
     " cat_emb_dim=1,\n",
     " optimizer_fn=torch.optim.Adam,\n",
-    " optimizer_params=dict(lr=2e-2))"
+    " optimizer_params=dict(lr=2e-2),\n",
+    " scheduler_params={\"step_size\":50, # how to use learning rate scheduler\n",
+    " \"gamma\":0.9},\n",
+    " scheduler_fn=torch.optim.lr_scheduler.StepLR)"
    ]
   },
   {
@@ -227,6 +230,16 @@
     "plt.plot([-x for x in clf.history['valid']['metric']])"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# plot learning rates\n",
+    "plt.plot([x for x in clf.history['train']['lr']])"
+   ]
+  },
   {
    "cell_type": "markdown",
    "metadata": {},
diff --git a/pytorch_tabnet/tab_model.py b/pytorch_tabnet/tab_model.py
index cef61eb8..0abcbe04 100755
--- a/pytorch_tabnet/tab_model.py
+++ b/pytorch_tabnet/tab_model.py
@@ -148,11 +148,11 @@ def fit(self, X_train, y_train, X_valid=None, y_valid=None, loss_fn=None,
         else:
             self.scheduler = None

-        losses_train = []
-        losses_valid = []
-
-        metrics_train = []
-        metrics_valid = []
+        self.losses_train = []
+        self.losses_valid = []
+        self.learning_rates = []
+        self.metrics_train = []
+        self.metrics_valid = []

         if self.verbose > 0:
             print("Will train until validation stopping metric",
@@ -165,13 +165,16 @@ def fit(self, X_train, y_train, X_valid=None, y_valid=None, loss_fn=None,
         while (self.epoch < self.max_epochs and
                self.patience_counter < self.patience):
             starting_time = time.time()
+            # updates learning rate history
+            self.learning_rates.append(self.optimizer.param_groups[-1]["lr"])
+
             fit_metrics = self.fit_epoch(train_dataloader, valid_dataloader)

             # leaving it here, may be used for callbacks later
-            losses_train.append(fit_metrics['train']['loss_avg'])
-            losses_valid.append(fit_metrics['valid']['total_loss'])
-            metrics_train.append(fit_metrics['train']['stopping_loss'])
-            metrics_valid.append(fit_metrics['valid']['stopping_loss'])
+            self.losses_train.append(fit_metrics['train']['loss_avg'])
+            self.losses_valid.append(fit_metrics['valid']['total_loss'])
+            self.metrics_train.append(fit_metrics['train']['stopping_loss'])
+            self.metrics_valid.append(fit_metrics['valid']['stopping_loss'])

             stopping_loss = fit_metrics['valid']['stopping_loss']
             if stopping_loss < self.best_cost:
@@ -201,10 +204,11 @@ def fit(self, X_train, y_train, X_valid=None, y_valid=None, loss_fn=None,
             print(f"Training done in {total_time:.3f} seconds.")
             print('---------------------------------------')

-        self.history = {"train": {"loss": losses_train,
-                                  "metric": metrics_train},
-                        "valid": {"loss": losses_valid,
-                                  "metric": metrics_valid}}
+        self.history = {"train": {"loss": self.losses_train,
+                                  "metric": self.metrics_train,
+                                  "lr": self.learning_rates},
+                        "valid": {"loss": self.losses_valid,
+                                  "metric": self.metrics_valid}}

         # load best models post training
         self.load_best_model()
@@ -767,7 +771,6 @@ def train_epoch(self, train_loader):

         if self.scheduler is not None:
             self.scheduler.step()
-            print("Current learning rate: ", self.optimizer.param_groups[-1]["lr"])
         return epoch_metrics

     def train_batch(self, data, targets):
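
Usage note (not part of the patch): a minimal sketch of how the scheduler options added in census_example.ipynb are passed to TabNetClassifier, and how the per-epoch learning-rate history recorded by this patch can be inspected after training. The dummy arrays are illustrative placeholders; only the constructor/fit arguments and the history key shown in the diff are taken from the patch.

# Minimal sketch, assuming the pytorch-tabnet API as it appears in this patch.
import numpy as np
import torch
import matplotlib.pyplot as plt
from pytorch_tabnet.tab_model import TabNetClassifier

# Placeholder data just to make the example self-contained.
X_train, y_train = np.random.rand(256, 10), np.random.randint(0, 2, 256)
X_valid, y_valid = np.random.rand(64, 10), np.random.randint(0, 2, 64)

clf = TabNetClassifier(
    optimizer_fn=torch.optim.Adam,
    optimizer_params=dict(lr=2e-2),
    scheduler_fn=torch.optim.lr_scheduler.StepLR,      # decay the LR every `step_size` epochs
    scheduler_params={"step_size": 50, "gamma": 0.9},  # kwargs forwarded to StepLR
)
clf.fit(X_train, y_train, X_valid, y_valid)

# With this patch, the learning rate used at each epoch is stored in the
# training history instead of being printed every epoch:
plt.plot(clf.history['train']['lr'])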