Skip to content

Commit

Permalink
ADD: greater_is_better property (#921)
Browse files Browse the repository at this point in the history
  • Loading branch information
martins0n authored Sep 7, 2022
1 parent 7bff536 commit 18a064a
Show file tree
Hide file tree
Showing 7 changed files with 123 additions and 2 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased
### Added
-
- Add `greater_is_better` property for Metric ([#921](https://github.com/tinkoff-ai/etna/pull/921))
-
-
-
Expand Down
44 changes: 43 additions & 1 deletion etna/metrics/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from abc import ABC
from abc import abstractmethod
from enum import Enum
from typing import Callable
from typing import Dict
from typing import Optional
from typing import Union

import numpy as np
Expand All @@ -24,7 +27,46 @@ def _missing_(cls, value):
)


class Metric(BaseMixin):
class AbstractMetric(ABC):
"""Abstract class for metric."""

@abstractmethod
def __call__(self, y_true: TSDataset, y_pred: TSDataset) -> Union[float, Dict[str, float]]:
"""
Compute metric's value with ``y_true`` and ``y_pred``.
Notes
-----
Note that if ``y_true`` and ``y_pred`` are not sorted Metric will sort it anyway
Parameters
----------
y_true:
dataset with true time series values
y_pred:
dataset with predicted time series values
Returns
-------
:
metric's value aggregated over segments or not (depends on mode)
"""
pass

@property
@abstractmethod
def name(self) -> str:
"""Metric name."""
pass

@property
@abstractmethod
def greater_is_better(self) -> Optional[bool]:
"""Whether higher metric value is better."""
pass


class Metric(AbstractMetric, BaseMixin):
"""
Base class for all the multi-segment metrics.
Expand Down
10 changes: 10 additions & 0 deletions etna/metrics/intervals_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,11 @@ def __call__(self, y_true: TSDataset, y_pred: TSDataset) -> Union[float, Dict[st
metrics = self._aggregate_metrics(metrics_per_segment)
return metrics

@property
def greater_is_better(self) -> None:
"""Whether higher metric value is better."""
return None


class Width(Metric, _QuantileMetricMixin):
"""Mean width of prediction intervals.
Expand Down Expand Up @@ -148,5 +153,10 @@ def __call__(self, y_true: TSDataset, y_pred: TSDataset) -> Union[float, Dict[st
metrics = self._aggregate_metrics(metrics_per_segment)
return metrics

@property
def greater_is_better(self) -> bool:
"""Whether higher metric value is better."""
return False


__all__ = ["Coverage", "Width"]
40 changes: 40 additions & 0 deletions etna/metrics/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,11 @@ def __init__(self, mode: str = MetricAggregationMode.per_segment, **kwargs):
"""
super().__init__(mode=mode, metric_fn=mae, **kwargs)

@property
def greater_is_better(self) -> bool:
"""Whether higher metric value is better."""
return False


class MSE(Metric):
"""Mean squared error metric with multi-segment computation support.
Expand All @@ -57,6 +62,11 @@ def __init__(self, mode: str = MetricAggregationMode.per_segment, **kwargs):
"""
super().__init__(mode=mode, metric_fn=mse, **kwargs)

@property
def greater_is_better(self) -> bool:
"""Whether higher metric value is better."""
return False


class R2(Metric):
"""Coefficient of determination metric with multi-segment computation support.
Expand All @@ -80,6 +90,11 @@ def __init__(self, mode: str = MetricAggregationMode.per_segment, **kwargs):
"""
super().__init__(mode=mode, metric_fn=r2_score, **kwargs)

@property
def greater_is_better(self) -> bool:
"""Whether higher metric value is better."""
return True


class MAPE(Metric):
"""Mean absolute percentage error metric with multi-segment computation support.
Expand All @@ -104,6 +119,11 @@ def __init__(self, mode: str = MetricAggregationMode.per_segment, **kwargs):
"""
super().__init__(mode=mode, metric_fn=mape, **kwargs)

@property
def greater_is_better(self) -> bool:
"""Whether higher metric value is better."""
return False


class SMAPE(Metric):
"""Symmetric mean absolute percentage error metric with multi-segment computation support.
Expand All @@ -128,6 +148,11 @@ def __init__(self, mode: str = MetricAggregationMode.per_segment, **kwargs):
"""
super().__init__(mode=mode, metric_fn=smape, **kwargs)

@property
def greater_is_better(self) -> bool:
"""Whether higher metric value is better."""
return False


class MedAE(Metric):
"""Median absolute error metric with multi-segment computation support.
Expand All @@ -152,6 +177,11 @@ def __init__(self, mode: str = MetricAggregationMode.per_segment, **kwargs):
"""
super().__init__(mode=mode, metric_fn=medae, **kwargs)

@property
def greater_is_better(self) -> bool:
"""Whether higher metric value is better."""
return False


class MSLE(Metric):
"""Mean squared logarithmic error metric with multi-segment computation support.
Expand All @@ -177,6 +207,11 @@ def __init__(self, mode: str = MetricAggregationMode.per_segment, **kwargs):
"""
super().__init__(mode=mode, metric_fn=msle, **kwargs)

@property
def greater_is_better(self) -> bool:
"""Whether higher metric value is better."""
return False


class Sign(Metric):
"""Sign error metric with multi-segment computation support.
Expand All @@ -201,5 +236,10 @@ def __init__(self, mode: str = MetricAggregationMode.per_segment, **kwargs):
"""
super().__init__(mode=mode, metric_fn=sign, **kwargs)

@property
def greater_is_better(self) -> None:
"""Whether higher metric value is better."""
return None


__all__ = ["MAE", "MSE", "R2", "MSLE", "MAPE", "SMAPE", "MedAE", "Sign"]
7 changes: 7 additions & 0 deletions tests/test_metrics/test_intervals_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,10 @@ def test_using_not_presented_quantiles(metric, tsdataset_with_zero_width_quantil
ts_train, ts_test = tsdataset_with_zero_width_quantiles
with pytest.raises(AssertionError, match="Quantile .* is not presented in tsdataset."):
_ = metric(ts_train, ts_test)


@pytest.mark.parametrize(
"metric, greater_is_better", ((Coverage(quantiles=(0.1, 0.3)), None), (Width(quantiles=(0.1, 0.3)), False))
)
def test_metrics_greater_is_better(metric, greater_is_better):
assert metric.greater_is_better == greater_is_better
18 changes: 18 additions & 0 deletions tests/test_metrics/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,24 @@ def test_metrics_values(metric_class, metric_fn, train_test_dfs):
assert value == true_metric_value


@pytest.mark.parametrize(
"metric, greater_is_better",
(
(MAE(), False),
(MSE(), False),
(MedAE(), False),
(MSLE(), False),
(MAPE(), False),
(SMAPE(), False),
(R2(), True),
(Sign(), None),
(DummyMetric(), False),
),
)
def test_metrics_greater_is_better(metric, greater_is_better):
assert metric.greater_is_better == greater_is_better


def test_multiple_calls():
"""Check that metric works correctly in case of multiple call."""
timerange = pd.DataFrame({"timestamp": pd.date_range("2020-01-01", periods=10, freq="1D")})
Expand Down
4 changes: 4 additions & 0 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,7 @@ def __init__(self, mode: str = MetricAggregationMode.per_segment, alpha: float =
@property
def name(self) -> str:
return self.__repr__()

@property
def greater_is_better(self) -> bool:
return False

1 comment on commit 18a064a

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please sign in to comment.