Change default metrics of CategoricalOutput to retrieval-metrics
marcromeyn committed Jun 30, 2023
1 parent 8c225e0 commit 4281058
Showing 2 changed files with 54 additions and 23 deletions.
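
For context, a minimal sketch (not part of the commit) of what the new defaults look like from user code. The `mm` alias follows the updated test below, and constructing a CategoricalOutput without a schema is assumed to be acceptable for illustration, since the default metrics are resolved in `__init__` independently of the schema:

import merlin.models.torch as mm

# With no explicit `metrics`, __init__ now falls back to
# create_retrieval_metrics(DEFAULT_METRICS_CLS, DEFAULT_K),
# i.e. the four torchmetrics retrieval metrics at k=5.
output = mm.CategoricalOutput()

print(sorted(type(m).__name__ for m in output.metrics))
# ['RetrievalHitRate', 'RetrievalNormalizedDCG', 'RetrievalPrecision', 'RetrievalRecall']
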
70 changes: 50 additions & 20 deletions merlin/models/torch/outputs/classification.py
@@ -13,11 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import List, Optional, Sequence, Union
from typing import List, Optional, Sequence, Type, Union

import torch
import torchmetrics as tm
from torch import nn
from torchmetrics import AUROC, Accuracy, AveragePrecision, Metric, Precision, Recall

import merlin.dtypes as md
from merlin.models.torch import schema
@@ -40,13 +40,13 @@ class BinaryOutput(ModelOutput):
"""

DEFAULT_LOSS_CLS = nn.BCEWithLogitsLoss
DEFAULT_METRICS_CLS = (Accuracy, AUROC, Precision, Recall)
DEFAULT_METRICS_CLS = (tm.Accuracy, tm.AUROC, tm.Precision, tm.Recall)

def __init__(
self,
schema: Optional[ColumnSchema] = None,
loss: Optional[nn.Module] = None,
metrics: Sequence[Metric] = (),
metrics: Sequence[tm.Metric] = (),
):
"""Initializes a BinaryOutput object."""
super().__init__(
@@ -113,16 +113,25 @@ class CategoricalOutput(ModelOutput):
by default 1.0
"""

DEFAULT_LOSS_CLS = nn.CrossEntropyLoss
DEFAULT_METRICS_CLS = (
tm.RetrievalHitRate,
tm.RetrievalNormalizedDCG,
tm.RetrievalPrecision,
tm.RetrievalRecall,
)
DEFAULT_K = (5,)

def __init__(
self,
schema: Optional[Union[ColumnSchema, Schema]] = None,
loss: nn.Module = nn.CrossEntropyLoss(),
metrics: Optional[Sequence[Metric]] = None,
loss: Optional[nn.Module] = None,
metrics: Optional[Sequence[tm.Metric]] = None,
logits_temperature: float = 1.0,
):
super().__init__(
loss=loss,
metrics=metrics or [],
loss=loss or self.DEFAULT_LOSS_CLS(),
metrics=metrics or create_retrieval_metrics(self.DEFAULT_METRICS_CLS, self.DEFAULT_K),
logits_temperature=logits_temperature,
)

@@ -135,7 +144,7 @@ def with_weight_tying(
block: nn.Module,
selection: Optional[schema.Selection] = None,
loss: nn.Module = nn.CrossEntropyLoss(),
metrics: Optional[Sequence[Metric]] = None,
metrics: Optional[Sequence[tm.Metric]] = None,
logits_temperature: float = 1.0,
) -> "CategoricalOutput":
self = cls(loss=loss, metrics=metrics, logits_temperature=logits_temperature)
@@ -177,17 +186,6 @@ def setup_schema(self, target: Optional[Union[ColumnSchema, Schema]]):
to_call = CategoricalTarget(target)
self.num_classes = to_call.num_classes
self.prepend(to_call)
if not self.metrics:
self.metrics = self.default_metrics(self.num_classes)

@classmethod
def default_metrics(cls, num_classes: int) -> List[Metric]:
"""Returns the default metrics used for multi-class classification."""
return [
AveragePrecision(task="multiclass", num_classes=num_classes),
Precision(task="multiclass", num_classes=num_classes),
Recall(task="multiclass", num_classes=num_classes),
]

@classmethod
def schema_selection(cls, schema: Schema) -> Schema:
@@ -413,3 +411,35 @@ def categorical_output_schema(target: ColumnSchema, num_classes: int) -> Schema:
)

return Schema([_target])


def create_retrieval_metrics(
metrics: Sequence[Type[tm.Metric]], ks: Sequence[int]
) -> List[tm.Metric]:
"""
Create a list of retrieval metrics given metric types and a list of integers.
For each integer in `ks`, a metric is created for each type in `metrics`.
Parameters
----------
metrics : Sequence[Type[tm.Metric]]
The types of metrics to create. Each type should be a callable that
accepts a single integer parameter `k` to instantiate a new metric.
ks : Sequence[int]
A list of integers to use as the `k` parameter when creating each metric.
Returns
-------
List[tm.Metric]
A list of metrics. The length of the list is equal to the product of
the lengths of `metrics` and `ks`. The metrics are ordered first by
the values in `ks`, then by the order in `metrics`.
"""

outputs = []

for k in ks:
for metric in metrics:
outputs.append(metric(k=k))

return outputs
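
As a rough illustration (not part of the commit) of how the helper above expands metric classes over cut-off values, assuming the installed torchmetrics accepts `k=` as used in the diff:

import torchmetrics as tm

from merlin.models.torch.outputs.classification import create_retrieval_metrics

# Ordered first by k, then by metric type:
# RetrievalPrecision(k=5), RetrievalRecall(k=5),
# RetrievalPrecision(k=10), RetrievalRecall(k=10)
metrics = create_retrieval_metrics(
    metrics=(tm.RetrievalPrecision, tm.RetrievalRecall),
    ks=(5, 10),
)
assert len(metrics) == 4
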
7 changes: 4 additions & 3 deletions tests/unit/torch/outputs/test_classification.py
@@ -100,9 +100,10 @@ def test_init(self):
assert isinstance(categorical_output, mm.CategoricalOutput)
assert isinstance(categorical_output.loss, nn.CrossEntropyLoss)
assert sorted(m.__class__.__name__ for m in categorical_output.metrics) == [
"MulticlassAveragePrecision",
"MulticlassPrecision",
"MulticlassRecall",
"RetrievalHitRate",
"RetrievalNormalizedDCG",
"RetrievalPrecision",
"RetrievalRecall",
]
output_schema = categorical_output[0].output_schema.first
assert output_schema.dtype == md.float32