Commit
feat: save and load tabnet models
eduardocarvp authored and Optimox committed Jun 28, 2020
1 parent 65b0b88 commit 9d2d8ae
Showing 3 changed files with 208 additions and 21 deletions.
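In a nutshell, this commit adds save_model/load_model to TabModel and demonstrates them in the two example notebooks. The sketch below only condenses the census notebook cells added further down: it assumes clf is the TabNetClassifier already fitted earlier in that notebook (with cat_idxs, cat_dims, and X_test in scope), and the import path is assumed from the pytorch_tabnet/tab_model.py module edited below.

    from pytorch_tabnet.tab_model import TabNetClassifier  # assumed import path

    # save only the trained network weights (a torch state dict)
    clf.save_model('test.pt')

    # rebuild a classifier with the same architecture parameters, then load the weights
    loaded_clf = TabNetClassifier(input_dim=14,
                                  output_dim=2,
                                  cat_idxs=cat_idxs,
                                  cat_dims=cat_dims,
                                  cat_emb_dim=1,
                                  mask_type='entmax')
    loaded_clf.load_model('test.pt')

    # the reloaded model should reproduce the original predictions
    loaded_preds = loaded_clf.predict_proba(X_test)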
56 changes: 55 additions & 1 deletion census_example.ipynb
@@ -262,6 +262,60 @@
"print(f\"FINAL TEST SCORE FOR {dataset_name} : {test_auc}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Save and load Model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# save state dict\n",
"clf.save_model('test.pt')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# define new model with basic parameters and load state dict weights\n",
"loaded_clf = TabNetClassifier(input_dim=14,\n",
" output_dim=2,\n",
" cat_idxs=cat_idxs,\n",
" cat_dims=cat_dims,\n",
" cat_emb_dim=1,\n",
" mask_type='entmax')\n",
"loaded_clf.load_model('test.pt')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"loaded_preds = loaded_clf.predict_proba(X_test)\n",
"loaded_test_auc = roc_auc_score(y_score=loaded_preds[:,1], y_true=y_test)\n",
"\n",
"print(f\"FINAL TEST SCORE FOR {dataset_name} : {loaded_test_auc}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"assert(test_auc == loaded_test_auc)"
]
},
{
"cell_type": "markdown",
"metadata": {},
@@ -385,7 +439,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
"version": "3.7.5"
}
},
"nbformat": 4,
57 changes: 57 additions & 0 deletions forest_example.ipynb
@@ -314,6 +314,63 @@
"print(f\"FINAL TEST SCORE FOR {dataset_name} : {test_acc}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Save and load Model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# save state dict\n",
"clf.save_model('test.pt')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# define new model with basic parameters and load state dict weights\n",
"loaded_clf = TabNetClassifier(input_dim=54,\n",
" output_dim=7,\n",
" n_d=64, n_a=64, n_steps=5,\n",
" gamma=1.5, n_independent=2, n_shared=2,\n",
" cat_idxs=cat_idxs,\n",
" cat_dims=cat_dims,\n",
" cat_emb_dim=1)\n",
"loaded_clf.load_model('test.pt')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"loaded_preds = loaded_clf.predict_proba(X_test)\n",
"loaded_y_pred = np.vectorize(preds_mapper.get)(np.argmax(loaded_preds, axis=1))\n",
"\n",
"loaded_test_acc = accuracy_score(y_pred=loaded_y_pred, y_true=y_test)\n",
"\n",
"print(f\"FINAL TEST SCORE FOR {dataset_name} : {loaded_test_acc}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"assert(test_acc == loaded_test_acc)"
]
},
{
"cell_type": "markdown",
"metadata": {},
116 changes: 96 additions & 20 deletions pytorch_tabnet/tab_model.py
@@ -12,7 +12,7 @@
create_explain_matrix)
from sklearn.base import BaseEstimator
from torch.utils.data import DataLoader
import copy
from copy import deepcopy


class TabModel(BaseEstimator):
@@ -23,8 +23,9 @@ def __init__(self, n_d=8, n_a=8, n_steps=3, gamma=1.3, cat_idxs=[], cat_dims=[],
optimizer_fn=torch.optim.Adam,
optimizer_params=dict(lr=2e-2),
scheduler_params=None, scheduler_fn=None,
device_name='auto',
mask_type="sparsemax"):
mask_type="sparsemax",
input_dim=None, output_dim=None,
device_name='auto'):
""" Class for TabNet model
Parameters
@@ -53,6 +54,11 @@ def __init__(self, n_d=8, n_a=8, n_steps=3, gamma=1.3, cat_idxs=[], cat_dims=[],
self.scheduler_params = scheduler_params
self.scheduler_fn = scheduler_fn
self.mask_type = mask_type
self.input_dim = input_dim
self.output_dim = output_dim

self.batch_size = 1024

self.seed = seed
torch.manual_seed(self.seed)
# Defining device
@@ -76,6 +82,49 @@ def construct_loaders(X_train, y_train, X_valid, y_valid,
"""
raise NotImplementedError('users must define construct_loaders to use this base class')

def init_network(
self,
input_dim,
output_dim,
n_d,
n_a,
n_steps,
gamma,
cat_idxs,
cat_dims,
cat_emb_dim,
n_independent,
n_shared,
epsilon,
virtual_batch_size,
momentum,
device_name,
mask_type
):
self.network = tab_network.TabNet(
input_dim,
output_dim,
n_d=n_d,
n_a=n_a,
n_steps=n_steps,
gamma=gamma,
cat_idxs=cat_idxs,
cat_dims=cat_dims,
cat_emb_dim=cat_emb_dim,
n_independent=n_independent,
n_shared=n_shared,
epsilon=epsilon,
virtual_batch_size=virtual_batch_size,
momentum=momentum,
device_name=device_name,
mask_type=mask_type).to(self.device)

self.reducing_matrix = create_explain_matrix(
self.network.input_dim,
self.network.cat_emb_dim,
self.network.cat_idxs,
self.network.post_embed_dim)

def fit(self, X_train, y_train, X_valid=None, y_valid=None, loss_fn=None,
weights=0, max_epochs=100, patience=10, batch_size=1024,
virtual_batch_size=128, num_workers=0, drop_last=False):
@@ -125,22 +174,24 @@ def fit(self, X_train, y_train, X_valid=None, y_valid=None, loss_fn=None,
self.num_workers,
self.drop_last)

self.network = tab_network.TabNet(self.input_dim, self.output_dim,
n_d=self.n_d, n_a=self.n_d,
n_steps=self.n_steps, gamma=self.gamma,
cat_idxs=self.cat_idxs, cat_dims=self.cat_dims,
cat_emb_dim=self.cat_emb_dim,
n_independent=self.n_independent, n_shared=self.n_shared,
epsilon=self.epsilon,
virtual_batch_size=self.virtual_batch_size,
momentum=self.momentum,
device_name=self.device_name,
mask_type=self.mask_type).to(self.device)

self.reducing_matrix = create_explain_matrix(self.network.input_dim,
self.network.cat_emb_dim,
self.network.cat_idxs,
self.network.post_embed_dim)
self.init_network(
input_dim=self.input_dim,
output_dim=self.output_dim,
n_d=self.n_d,
n_a=self.n_a,
n_steps=self.n_steps,
gamma=self.gamma,
cat_idxs=self.cat_idxs,
cat_dims=self.cat_dims,
cat_emb_dim=self.cat_emb_dim,
n_independent=self.n_independent,
n_shared=self.n_shared,
epsilon=self.epsilon,
virtual_batch_size=self.virtual_batch_size,
momentum=self.momentum,
device_name=self.device_name,
mask_type=self.mask_type
)

self.optimizer = self.optimizer_fn(self.network.parameters(),
**self.optimizer_params)
@@ -183,7 +234,7 @@ def fit(self, X_train, y_train, X_valid=None, y_valid=None, loss_fn=None,
self.best_cost = stopping_loss
self.patience_counter = 0
# Saving model
self.best_network = copy.deepcopy(self.network)
self.best_network = deepcopy(self.network)
else:
self.patience_counter += 1

@@ -217,6 +268,31 @@ def fit(self, X_train, y_train, X_valid=None, y_valid=None, loss_fn=None,
# compute feature importance once the best model is defined
self._compute_feature_importances(train_dataloader)

def save_model(self, path):
torch.save(self.network.state_dict(), path)

def load_model(self, path):
self.init_network(
input_dim=self.input_dim,
output_dim=self.output_dim,
n_d=self.n_d,
n_a=self.n_a,
n_steps=self.n_steps,
gamma=self.gamma,
cat_idxs=self.cat_idxs,
cat_dims=self.cat_dims,
cat_emb_dim=self.cat_emb_dim,
n_independent=self.n_independent,
n_shared=self.n_shared,
epsilon=self.epsilon,
virtual_batch_size=1024,
momentum=self.momentum,
device_name=self.device_name,
mask_type=self.mask_type
)
self.network.load_state_dict(torch.load(path))
self.network.eval()

def fit_epoch(self, train_dataloader, valid_dataloader):
"""
Evaluates and updates network for one epoch.
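For clarity, the two new methods are thin wrappers around PyTorch's state-dict serialization; the sketch below restates what the diff above adds, with clf standing in for any fitted TabModel subclass. Because only the weights are persisted, load_model first rebuilds self.network via init_network from the parameters stored on the estimator, so the loading object must be constructed with matching architecture arguments (input_dim, output_dim, cat_idxs, cat_dims, ...); note also that load_model passes a fixed virtual_batch_size of 1024 when re-initializing the network.

    import torch

    # save_model: persist only the network weights
    torch.save(clf.network.state_dict(), 'test.pt')

    # load_model: after re-initializing self.network via init_network,
    # restore the weights and switch to inference mode
    clf.network.load_state_dict(torch.load('test.pt'))
    clf.network.eval()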
