diff --git a/census_example.ipynb b/census_example.ipynb index d820e97f..ed6472ec 100755 --- a/census_example.ipynb +++ b/census_example.ipynb @@ -262,6 +262,60 @@ "print(f\"FINAL TEST SCORE FOR {dataset_name} : {test_auc}\")" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Save and load Model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# save state dict\n", + "clf.save_model('test.pt')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# define new model with basic parameters and load state dict weights\n", + "loaded_clf = TabNetClassifier(input_dim=14,\n", + " output_dim=2,\n", + " cat_idxs=cat_idxs,\n", + " cat_dims=cat_dims,\n", + " cat_emb_dim=1,\n", + " mask_type='entmax')\n", + "loaded_clf.load_model('test.pt')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "loaded_preds = loaded_clf.predict_proba(X_test)\n", + "loaded_test_auc = roc_auc_score(y_score=loaded_preds[:,1], y_true=y_test)\n", + "\n", + "print(f\"FINAL TEST SCORE FOR {dataset_name} : {loaded_test_auc}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "assert(test_auc == loaded_test_auc)" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -385,7 +439,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.6" + "version": "3.7.5" } }, "nbformat": 4, diff --git a/forest_example.ipynb b/forest_example.ipynb index 81f7412e..aea5ccf9 100644 --- a/forest_example.ipynb +++ b/forest_example.ipynb @@ -314,6 +314,63 @@ "print(f\"FINAL TEST SCORE FOR {dataset_name} : {test_acc}\")" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Save and load Model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": 
[], + "source": [ + "# save state dict\n", + "clf.save_model('test.pt')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# define new model with basic parameters and load state dict weights\n", + "loaded_clf = TabNetClassifier(input_dim=54,\n", + " output_dim=7,\n", + " n_d=64, n_a=64, n_steps=5,\n", + " gamma=1.5, n_independent=2, n_shared=2,\n", + " cat_idxs=cat_idxs,\n", + " cat_dims=cat_dims,\n", + " cat_emb_dim=1)\n", + "loaded_clf.load_model('test.pt')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "loaded_preds = loaded_clf.predict_proba(X_test)\n", + "loaded_y_pred = np.vectorize(preds_mapper.get)(np.argmax(loaded_preds, axis=1))\n", + "\n", + "loaded_test_acc = accuracy_score(y_pred=loaded_y_pred, y_true=y_test)\n", + "\n", + "print(f\"FINAL TEST SCORE FOR {dataset_name} : {loaded_test_acc}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "assert(test_acc == loaded_test_acc)" + ] + }, { "cell_type": "markdown", "metadata": {}, diff --git a/pytorch_tabnet/tab_model.py b/pytorch_tabnet/tab_model.py index 0471fad8..c3c0aa4f 100755 --- a/pytorch_tabnet/tab_model.py +++ b/pytorch_tabnet/tab_model.py @@ -12,7 +12,7 @@ create_explain_matrix) from sklearn.base import BaseEstimator from torch.utils.data import DataLoader -import copy +from copy import deepcopy class TabModel(BaseEstimator): @@ -23,8 +23,9 @@ def __init__(self, n_d=8, n_a=8, n_steps=3, gamma=1.3, cat_idxs=[], cat_dims=[], optimizer_fn=torch.optim.Adam, optimizer_params=dict(lr=2e-2), scheduler_params=None, scheduler_fn=None, - device_name='auto', - mask_type="sparsemax"): + mask_type="sparsemax", + input_dim=None, output_dim=None, + device_name='auto'): """ Class for TabNet model Parameters @@ -53,6 +54,11 @@ def __init__(self, n_d=8, n_a=8, n_steps=3, gamma=1.3, cat_idxs=[], cat_dims=[], 
def init_network(
    self,
    input_dim,
    output_dim,
    n_d,
    n_a,
    n_steps,
    gamma,
    cat_idxs,
    cat_dims,
    cat_emb_dim,
    n_independent,
    n_shared,
    epsilon,
    virtual_batch_size,
    momentum,
    device_name,
    mask_type
):
    """
    Build the TabNet network and its explain matrix.

    Every argument is forwarded unchanged to ``tab_network.TabNet``
    (``input_dim`` and ``output_dim`` positionally, the rest by keyword).
    The resulting module is moved to ``self.device`` and stored as
    ``self.network``; ``self.reducing_matrix`` is then derived from the
    network's own attributes through ``create_explain_matrix``.

    Returns
    -------
    None
        Results are stored on the instance (``network``,
        ``reducing_matrix``).
    """
    # Keyword hyper-parameters handed straight to the TabNet constructor.
    tabnet_kwargs = dict(
        n_d=n_d,
        n_a=n_a,
        n_steps=n_steps,
        gamma=gamma,
        cat_idxs=cat_idxs,
        cat_dims=cat_dims,
        cat_emb_dim=cat_emb_dim,
        n_independent=n_independent,
        n_shared=n_shared,
        epsilon=epsilon,
        virtual_batch_size=virtual_batch_size,
        momentum=momentum,
        device_name=device_name,
        mask_type=mask_type,
    )
    built = tab_network.TabNet(input_dim, output_dim, **tabnet_kwargs)
    self.network = built.to(self.device)

    # Read the dimensions back from the stored network rather than from the
    # arguments, exactly as the explain matrix expects them.
    net = self.network
    self.reducing_matrix = create_explain_matrix(
        net.input_dim,
        net.cat_emb_dim,
        net.cat_idxs,
        net.post_embed_dim,
    )
def save_model(self, path):
    """
    Save the network weights to disk.

    Only ``self.network.state_dict()`` is written (via ``torch.save``); the
    constructor hyper-parameters are not part of the file.

    Parameters
    ----------
    path : str
        Destination passed to ``torch.save``.
    """
    torch.save(self.network.state_dict(), path)


def load_model(self, path):
    """
    Rebuild the network from this instance's parameters and load weights.

    A fresh network is created with ``init_network`` from the attributes
    stored on ``self``, the state dict found at ``path`` is loaded into it,
    and the network is switched to evaluation mode.

    Parameters
    ----------
    path : str
        Location of a state dict previously written by ``save_model``.
    """
    self.init_network(
        input_dim=self.input_dim,
        output_dim=self.output_dim,
        n_d=self.n_d,
        n_a=self.n_a,
        n_steps=self.n_steps,
        gamma=self.gamma,
        cat_idxs=self.cat_idxs,
        cat_dims=self.cat_dims,
        cat_emb_dim=self.cat_emb_dim,
        n_independent=self.n_independent,
        n_shared=self.n_shared,
        epsilon=self.epsilon,
        # Honour an already-configured virtual batch size; 1024 stays the
        # fallback for instances where the attribute is absent.
        # NOTE(review): the attribute appears to be set only during fit --
        # confirm against __init__.
        virtual_batch_size=getattr(self, "virtual_batch_size", 1024),
        momentum=self.momentum,
        device_name=self.device_name,
        mask_type=self.mask_type
    )
    # Remap the stored tensors onto the device the network was just built on,
    # so a checkpoint written on a GPU can be restored on a CPU-only machine.
    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    state_dict = torch.load(path, map_location=self.device)
    self.network.load_state_dict(state_dict)
    self.network.eval()