Skip to content

Commit

Permalink
feat: update notebooks for new model format
Browse files Browse the repository at this point in the history
  • Loading branch information
eduardocarvp committed Dec 3, 2019
1 parent b9d14c4 commit 43e2693
Show file tree
Hide file tree
Showing 4 changed files with 548 additions and 201 deletions.
124 changes: 25 additions & 99 deletions census_example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
"metadata": {},
"outputs": [],
"source": [
"from pytorch_tabnet import tab_network\n",
"from pytorch_tabnet.tab_model import Model\n",
"from pytorch_tabnet.tab_model import TabNetClassifier\n",
"\n",
"import torch\n",
"from sklearn.preprocessing import LabelEncoder\n",
Expand Down Expand Up @@ -126,9 +125,7 @@
"\n",
"cat_idxs = [ i for i, f in enumerate(features) if f in categorical_columns]\n",
"\n",
"cat_dims = [ categorical_dims[f] for i, f in enumerate(features) if f in categorical_columns]\n",
"\n",
"train[target] = train[target].astype(int)"
"cat_dims = [ categorical_dims[f] for i, f in enumerate(features) if f in categorical_columns]\n"
]
},
{
Expand All @@ -144,65 +141,7 @@
"metadata": {},
"outputs": [],
"source": [
"num_workers= 5\n",
"LR = 2e-2\n",
"batch_size = 1024 #64\n",
"mini_batch_size = 128\n",
"device = 'cuda' if torch.cuda.is_available() else 'cpu'"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"network_params = {\"input_dim\" : len(features),\n",
" \"n_d\" : 8,\n",
" \"n_a\" : 8,\n",
" \"n_independent\": 2,\n",
" \"n_shared\": 2,\n",
" \"n_steps\": 3,\n",
" \"gamma\": 1.3,\n",
" \"output_dim\" : 2,\n",
" \"momentum\": 0.1,\n",
" \"cat_idxs\":cat_idxs,\n",
" \"cat_dims\": cat_dims,\n",
" \"cat_emb_dim\": 1,\n",
" \"virtual_batch_size\": mini_batch_size,\n",
"}\n",
"\n",
"description = f\"test_TabNet_LR_{LR}_BS_{batch_size}_DS_{dataset_name}\"\n",
"description += f\"_miniBS_{mini_batch_size}\"\n",
"description += f\"_nd_{network_params['n_d']}\"\n",
"description += f\"_na_{network_params['n_a']}\"\n",
"description += f\"_nsteps_{network_params['n_steps']}\"\n",
"description += f\"_gamma_{network_params['gamma']}\"\n",
"description += f\"_momentum_{network_params['momentum']}\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"my_scheduler = torch.optim.lr_scheduler.StepLR\n",
"scheduler_params = {\"gamma\": 0.9,\n",
" \"step_size\": 20}\n",
"\n",
"training_params = {\"model_name\": description,\n",
" \"lambda_sparse\": 1e-3,\n",
" \"lr\":LR,\n",
" \"patience\": 200,\n",
" \"optimizer_fn\":torch.optim.Adam,\n",
" \"scheduler_fn\": my_scheduler,\n",
" \"scheduler_params\":scheduler_params,\n",
" \"max_epochs\": 1000,\n",
" \"batch_size\": batch_size,\n",
" \"clip_value\": 0.5,\n",
" \"device\":device\n",
" }"
"clf = TabNetClassifier()"
]
},
{
Expand All @@ -218,36 +157,29 @@
"metadata": {},
"outputs": [],
"source": [
"X_train = train.iloc[train_indices][features].values\n",
"y_train = train.iloc[train_indices][target].values\n",
"X_train = train[features].values[train_indices]\n",
"y_train = train[target].values[train_indices]\n",
"\n",
"X_valid = train.iloc[valid_indices][features].values\n",
"y_valid = train.iloc[valid_indices][target].values\n",
"X_valid = train[features].values[valid_indices]\n",
"y_valid = train[target].values[valid_indices]\n",
"\n",
"X_test = train.iloc[test_indices][features].values\n",
"y_test = train.iloc[test_indices][target].values"
"X_test = train[features].values[test_indices]\n",
"y_test = train[target].values[test_indices]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
"scrolled": true
},
"outputs": [],
"source": [
"network = tab_network.TabNet\n",
"model = Model()\n",
"\n",
"\n",
"model.def_network(network, **network_params)\n",
"model.set_params(**training_params)\n",
"\n",
"model.fit(\n",
"clf.fit(\n",
" X_train=X_train, y_train=y_train,\n",
" X_valid=X_valid, y_valid=y_valid,\n",
" balanced=False, #True,\n",
" weights=None, #{0: 1, 1:10}\n",
" max_epochs=1000, patience=50,\n",
" batch_size=1024, virtual_batch_size=128\n",
") "
]
},
Expand All @@ -257,15 +189,16 @@
"metadata": {},
"outputs": [],
"source": [
"model.load_best_model()\n",
"# Deprecated : best model is automatically loaded at end of fit\n",
"# clf.load_best_model()\n",
"\n",
"preds = model.predict_proba(X_test)\n",
"preds = clf.predict_proba(X_test)\n",
"\n",
"y_true = y_test\n",
"\n",
"test_auc = roc_auc_score(y_score=preds[:,1], y_true=y_true)\n",
"\n",
"print(f\"BEST VALID SCORE FOR {dataset_name} : {model.best_cost}\")\n",
"print(f\"BEST VALID SCORE FOR {dataset_name} : {clf.best_cost}\")\n",
"print(f\"FINAL TEST SCORE FOR {dataset_name} : {test_auc}\")"
]
},
Expand All @@ -282,7 +215,7 @@
"metadata": {},
"outputs": [],
"source": [
"explain_matrix, masks = model.explain(X_test)"
"explain_matrix, masks = clf.explain(X_test)"
]
},
{
Expand All @@ -301,9 +234,9 @@
"metadata": {},
"outputs": [],
"source": [
"fig, axs = plt.subplots(1, network_params['n_steps'])\n",
"fig, axs = plt.subplots(1, 3, figsize=(20,20))\n",
"\n",
"for i in range(network_params['n_steps']):\n",
"for i in range(3):\n",
" axs[i].imshow(masks[i][:50])\n",
" axs[i].set_title(f\"mask {i}\")\n"
]
Expand All @@ -325,7 +258,7 @@
"source": [
"from xgboost import XGBClassifier\n",
"\n",
"clf = XGBClassifier(max_depth=8,\n",
"clf_xgb = XGBClassifier(max_depth=8,\n",
" learning_rate=0.1,\n",
" n_estimators=1000,\n",
" verbosity=0,\n",
Expand All @@ -348,7 +281,7 @@
" random_state=0,\n",
" seed=None,)\n",
"\n",
"clf.fit(X_train, y_train,\n",
"clf_xgb.fit(X_train, y_train,\n",
" eval_set=[(X_valid, y_valid)],\n",
" early_stopping_rounds=40,\n",
" verbose=10)"
Expand All @@ -360,21 +293,14 @@
"metadata": {},
"outputs": [],
"source": [
"preds = np.array(clf.predict_proba(X_valid))\n",
"preds = np.array(clf_xgb.predict_proba(X_valid))\n",
"valid_auc = roc_auc_score(y_score=preds[:,1], y_true=y_valid)\n",
"print(valid_auc)\n",
"\n",
"preds = np.array(clf.predict_proba(X_test))\n",
"preds = np.array(clf_xgb.predict_proba(X_test))\n",
"test_auc = roc_auc_score(y_score=preds[:,1], y_true=y_test)\n",
"print(test_auc)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand All @@ -393,7 +319,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.4"
"version": "3.7.5"
}
},
"nbformat": 4,
Expand Down
Loading

0 comments on commit 43e2693

Please sign in to comment.