Skip to content

Commit

Permalink
fix: multiclass prediction mapper
Browse files Browse the repository at this point in the history
  • Loading branch information
Optimox committed Feb 3, 2020
1 parent 123932a commit 2317c5c
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 12 deletions.
32 changes: 26 additions & 6 deletions census_example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -184,25 +184,45 @@
") "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Predictions"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Deprecated : best model is automatically loaded at end of fit\n",
"# clf.load_best_model()\n",
"# To get final results you may need to use a mapping for classes \n",
"# as you are allowed to use targets like [\"yes\", \"no\", \"maybe\", \"I don't know\"]\n",
"\n",
"preds = clf.predict_proba(X_test)\n",
"\n",
"y_true = y_test\n",
"preds_mapper = { idx : class_name for idx, class_name in enumerate(clf.classes_)}\n",
"\n",
"test_auc = roc_auc_score(y_score=preds[:,1], y_true=y_true)\n",
"preds = clf.predict_proba(X_test)\n",
"y_pred = np.vectorize(preds_mapper.get)(preds[:,1])\n",
"test_auc = roc_auc_score(y_score=preds[:,1], y_true=y_test)\n",
"\n",
"print(f\"BEST VALID SCORE FOR {dataset_name} : {clf.best_cost}\")\n",
"print(f\"FINAL TEST SCORE FOR {dataset_name} : {test_auc}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# or you can simply use the predict method\n",
"\n",
"y_pred = clf.predict(X_test)\n",
"test_auc = roc_auc_score(y_score=preds[:,1], y_true=y_test)\n",
"print(f\"FINAL TEST SCORE FOR {dataset_name} : {test_auc}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down
29 changes: 27 additions & 2 deletions forest_example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -239,22 +239,47 @@
") "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Predictions\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# To get final results you may need to use a mapping for classes \n",
"# as you are allowed to use targets like [\"yes\", \"no\", \"maybe\", \"I don't know\"]\n",
"\n",
"preds_mapper = { idx : class_name for idx, class_name in enumerate(clf.classes_)}\n",
"\n",
"preds = clf.predict_proba(X_test)\n",
"\n",
"y_true = y_test\n",
"y_pred = np.vectorize(preds_mapper.get)(np.argmax(preds, axis=1))\n",
"\n",
"test_acc = accuracy_score(y_pred=np.argmax(preds, axis=1), y_true=y_true)\n",
"test_acc = accuracy_score(y_pred=y_pred, y_true=y_test)\n",
"\n",
"print(f\"BEST VALID SCORE FOR {dataset_name} : {clf.best_cost}\")\n",
"print(f\"FINAL TEST SCORE FOR {dataset_name} : {test_acc}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# or you can simply use the predict method\n",
"\n",
"y_pred = clf.predict(X_test)\n",
"test_acc = accuracy_score(y_pred=y_pred, y_true=y_test)\n",
"print(f\"FINAL TEST SCORE FOR {dataset_name} : {test_acc}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down
9 changes: 5 additions & 4 deletions pytorch_tabnet/tab_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,8 @@ def update_fit_params(self, X_train, y_train, X_valid, y_valid, loss_fn,
self.classes_ = train_labels
self.target_mapper = {class_label: index
for index, class_label in enumerate(self.classes_)}

self.preds_mapper = {index: class_label
for index, class_label in enumerate(self.classes_)}
self.weights = weights
self.updated_weights = self.weight_updater(self.weights)

Expand Down Expand Up @@ -619,7 +620,7 @@ def predict(self, X):
Returns
-------
predictions: np.array
Predictions of the regression problem or the last class
Predictions of the most probable class
"""
self.network.eval()
dataloader = DataLoader(PredictDataset(X),
Expand All @@ -636,7 +637,7 @@ def predict(self, X):
else:
res = np.hstack([res, predictions])

return res
return np.vectorize(self.preds_mapper.get)(res)

def predict_proba(self, X):
"""
Expand Down Expand Up @@ -881,7 +882,7 @@ def predict(self, X):
Returns
-------
predictions: np.array
Predictions of the regression problem or the last class
Predictions of the regression problem
"""
self.network.eval()
dataloader = DataLoader(PredictDataset(X),
Expand Down

0 comments on commit 2317c5c

Please sign in to comment.