Skip to content

Change default aggregation method in metric #684

Merged
merged 2 commits into from
May 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed
- Add missed `forecast_params` in forecast CLI method ([#671](https://github.com/tinkoff-ai/etna/pull/671))
-
- Add `_per_segment_average` method to the Metric class ([#684](https://github.com/tinkoff-ai/etna/pull/684))
-
-
-
Expand Down
23 changes: 17 additions & 6 deletions etna/metrics/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,6 @@ def _missing_(cls, value):
)


def identity(x):
return x


class Metric(BaseMixin):
"""
Base class for all the multi-segment metrics.
Expand Down Expand Up @@ -64,7 +60,7 @@ def __init__(self, metric_fn: Callable[..., float], mode: str = MetricAggregatio
if MetricAggregationMode(mode) == MetricAggregationMode.macro:
self._aggregate_metrics = self._macro_average
elif MetricAggregationMode(mode) == MetricAggregationMode.per_segment:
self._aggregate_metrics = identity
self._aggregate_metrics = self._per_segment_average
self.mode = mode

@property
Expand Down Expand Up @@ -134,7 +130,7 @@ def _validate_timestamp_columns(timestamp_true: pd.Series, timestamp_pred: pd.Se
raise ValueError("y_true and y_pred have different timestamps")

@staticmethod
def _macro_average(metrics_per_segments: Dict[str, float]) -> float:
def _macro_average(metrics_per_segments: Dict[str, float]) -> Union[float, Dict[str, float]]:
julia-shenshina marked this conversation as resolved.
Show resolved Hide resolved
"""
Compute macro averaging of metrics over segment.

Expand All @@ -148,6 +144,21 @@ def _macro_average(metrics_per_segments: Dict[str, float]) -> float:
"""
return np.mean(list(metrics_per_segments.values())).item()

@staticmethod
def _per_segment_average(metrics_per_segments: Dict[str, float]) -> Union[float, Dict[str, float]]:
"""
Compute per-segment averaging of metrics over segment.

Parameters
----------
metrics_per_segments: dict of {segment: metric_value} for segments to aggregate

Returns
-------
aggregated dict of metric
"""
return metrics_per_segments

def _log_start(self):
"""Log metric computation."""
tslogger.log(f"Metric {self.__repr__()} is calculated on dataset")
Expand Down