Skip to content

Commit

Permalink
[FIX] ROC AUC for multi-class classification (#482)
Browse files Browse the repository at this point in the history
* fixed cut mix

* remove unnecessary comment

* change all_supported_metrics

* fix roc_auc for multiclass

* remove unnecessary code
  • Loading branch information
ravinkohli authored Oct 17, 2022
1 parent d29d11b commit 873df9a
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 7 deletions.
5 changes: 2 additions & 3 deletions autoPyTorch/pipeline/components/setup/network/base_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,15 +56,14 @@ def fit(self, X: Dict[str, Any], y: Any = None) -> autoPyTorchTrainingComponent:

self.network = torch.nn.Sequential(X['network_embedding'], X['network_backbone'], X['network_head'])

if STRING_TO_TASK_TYPES[X['dataset_properties']['task_type']] in CLASSIFICATION_TASKS:
self.network = torch.nn.Sequential(self.network, nn.Softmax(dim=1))
# Properly set the network training device
if self.device is None:
self.device = get_device_from_fit_dictionary(X)

self.to(self.device)

if STRING_TO_TASK_TYPES[X['dataset_properties']['task_type']] in CLASSIFICATION_TASKS:
self.final_activation = nn.Softmax(dim=1)

self.is_fitted_ = True

return self
Expand Down
2 changes: 1 addition & 1 deletion autoPyTorch/pipeline/components/training/metrics/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def __call__(
Score function applied to prediction of estimator on X.
"""
y_type = type_of_target(y_true)
if y_type not in ("binary", "multilabel-indicator"):
if y_type not in ("binary", "multilabel-indicator") and self.name != 'roc_auc':
raise ValueError("{0} format is not supported".format(y_type))

if y_type == "binary":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@


# Score functions that need decision values
roc_auc = make_metric('roc_auc', sklearn.metrics.roc_auc_score, needs_threshold=True)
roc_auc = make_metric('roc_auc', sklearn.metrics.roc_auc_score, needs_threshold=True, multi_class= 'ovo')
average_precision = make_metric('average_precision',
sklearn.metrics.average_precision_score,
needs_threshold=True)
Expand Down
4 changes: 2 additions & 2 deletions autoPyTorch/pipeline/components/training/metrics/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,8 @@ def get_metrics(dataset_properties: Dict[str, Any],
if names is not None:
for name in names:
if name not in supported_metrics.keys():
raise ValueError("Invalid name entered for task {}, currently "
"supported metrics for task include {}".format(dataset_properties['task_type'],
raise ValueError("Invalid name {} entered for task {}, currently "
"supported metrics for task include {}".format(name, dataset_properties['task_type'],
list(supported_metrics.keys())))
else:
metric = supported_metrics[name]
Expand Down

0 comments on commit 873df9a

Please sign in to comment.