From 8ab6972df1d0a54e810b1822ae61b489b0cbb440 Mon Sep 17 00:00:00 2001 From: rawanmahdi Date: Wed, 28 Jun 2023 10:13:24 -0400 Subject: [PATCH 1/3] feat: add sklearn classification report as default supported metric --- pytorch_tabnet/metrics.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/pytorch_tabnet/metrics.py b/pytorch_tabnet/metrics.py index e8ad8181..d3b0fcc4 100644 --- a/pytorch_tabnet/metrics.py +++ b/pytorch_tabnet/metrics.py @@ -9,6 +9,7 @@ log_loss, balanced_accuracy_score, mean_squared_log_error, + classification_report ) import torch @@ -402,6 +403,32 @@ def __call__(self, y_true, y_score): y_score = np.clip(y_score, a_min=0, a_max=None) return np.sqrt(mean_squared_log_error(y_true, y_score)) +class ClassificationReport(Metric): + """ + Classification Report: Precision, Recall and F1 scores. + """ + + def __init__(self): + self._name = "classification_report" + self._maximize = False + + def __call__(self, y_true, y_score): + """ + Compute precision, recall and F1 scores of predictions for each target class. + + Parameters + ---------- + y_true : np.ndarray + Target matrix or vector + y_score : np.ndarray + Score matrix or vector + + Returns + ------- + float + AUC of predictions vs targets. + """ + return classification_report(y_true, y_score[:, 1]) class UnsupervisedMetric(Metric): """ From 8b96d460842c20b46a25817bb725fc6e9d8dc832 Mon Sep 17 00:00:00 2001 From: rawanmahdi Date: Wed, 28 Jun 2023 10:28:07 -0400 Subject: [PATCH 2/3] feat: update docs for lassifcationReport class --- pytorch_tabnet/metrics.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/pytorch_tabnet/metrics.py b/pytorch_tabnet/metrics.py index d3b0fcc4..93bb542a 100644 --- a/pytorch_tabnet/metrics.py +++ b/pytorch_tabnet/metrics.py @@ -406,6 +406,8 @@ def __call__(self, y_true, y_score): class ClassificationReport(Metric): """ Classification Report: Precision, Recall and F1 scores. + Scikit-implementation: + https://scikit-learn.org/stable/modules/generated/sklearn.metrics.classification_report.html """ def __init__(self): @@ -425,10 +427,10 @@ def __call__(self, y_true, y_score): Returns ------- - float - AUC of predictions vs targets. + str + table of precision, recall, and f1 score as well as supports """ - return classification_report(y_true, y_score[:, 1]) + return classification_report(y_true, y_score) class UnsupervisedMetric(Metric): """ From 0a2d975c3d42b1bece5ad589920c927281fcce43 Mon Sep 17 00:00:00 2001 From: rawanmahdi Date: Thu, 6 Jul 2023 15:24:37 -0400 Subject: [PATCH 3/3] feat: fixing spacing to pass lint --- pytorch_tabnet/metrics.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pytorch_tabnet/metrics.py b/pytorch_tabnet/metrics.py index 93bb542a..c9a263ec 100644 --- a/pytorch_tabnet/metrics.py +++ b/pytorch_tabnet/metrics.py @@ -403,6 +403,7 @@ def __call__(self, y_true, y_score): y_score = np.clip(y_score, a_min=0, a_max=None) return np.sqrt(mean_squared_log_error(y_true, y_score)) + class ClassificationReport(Metric): """ Classification Report: Precision, Recall and F1 scores. @@ -432,6 +433,7 @@ def __call__(self, y_true, y_score): """ return classification_report(y_true, y_score) + class UnsupervisedMetric(Metric): """ Unsupervised metric