Skip to content

Commit

Permalink
Add params_to_tune for SeasonalMovingAverageModel, `MovingAverage…
Browse files Browse the repository at this point in the history
…Model`, `NaiveModel` and `DeadlineMovingAverageModel` (#1208)
  • Loading branch information
Mr-Geekman authored Apr 6, 2023
1 parent 1e51920 commit db8cd3a
Show file tree
Hide file tree
Showing 7 changed files with 89 additions and 7 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Add default `params_to_tune` for `ProphetModel` ([#1203](https://github.com/tinkoff-ai/etna/pull/1203))
- Add default `params_to_tune` for `SARIMAXModel`, change default parameters for the model ([#1206](https://github.com/tinkoff-ai/etna/pull/1206))
- Add default `params_to_tune` for linear models ([#1204](https://github.com/tinkoff-ai/etna/pull/1204))
- Add default `params_to_tune` for `SeasonalMovingAverageModel`, `MovingAverageModel`, `NaiveModel` and `DeadlineMovingAverageModel` ([#1208](https://github.com/tinkoff-ai/etna/pull/1208))
### Fixed
- Fix bug in `GaleShapleyFeatureSelectionTransform` with wrong number of remaining features ([#1110](https://github.com/tinkoff-ai/etna/pull/1110))
- `ProphetModel` fails with additional seasonality set ([#1157](https://github.com/tinkoff-ai/etna/pull/1157))
Expand Down
2 changes: 1 addition & 1 deletion etna/clustering/hierarchical/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from etna.datasets import TSDataset


class ClusteringLinkageMode(Enum):
class ClusteringLinkageMode(str, Enum):
"""Modes allowed for clustering distance computation."""

ward = "ward"
Expand Down
34 changes: 29 additions & 5 deletions etna/models/deadline_ma.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,22 @@
import warnings
from enum import Enum
from typing import Dict
from typing import Optional

import numpy as np
import pandas as pd
from typing_extensions import assert_never

from etna import SETTINGS
from etna.datasets import TSDataset
from etna.models.base import NonPredictionIntervalContextRequiredAbstractModel

if SETTINGS.auto_required:
from optuna.distributions import BaseDistribution
from optuna.distributions import IntUniformDistribution

class SeasonalityMode(Enum):

class SeasonalityMode(str, Enum):
"""Enum for seasonality mode for DeadlineMovingAverageModel."""

month = "month"
Expand Down Expand Up @@ -37,7 +44,7 @@ def __init__(self, window: int = 3, seasonality: str = "month"):
window:
Number of values taken for forecast for each point.
seasonality:
Only allowed monthly or annual seasonality.
Only allowed values are "month" and "year".
"""
self.window = window
self.seasonality = SeasonalityMode(seasonality)
Expand All @@ -59,6 +66,8 @@ def context_size(self) -> int:
cur_value = 366
elif self.seasonality is SeasonalityMode.month:
cur_value = 31
else:
assert_never(self.seasonality)

if self._freq == "H":
cur_value *= 24
Expand Down Expand Up @@ -145,9 +154,10 @@ def _get_context_beginning(

if seasonality is SeasonalityMode.month:
first_index = future_timestamps[0] - pd.DateOffset(months=window)

elif seasonality is SeasonalityMode.year:
first_index = future_timestamps[0] - pd.DateOffset(years=window)
else:
assert_never(seasonality)

if first_index < history_timestamps[0]:
raise ValueError(
Expand All @@ -165,10 +175,12 @@ def _make_predictions(
end_idx = len(result_template)
for i in range(start_idx, end_idx):
for w in range(1, self.window + 1):
if self.seasonality == SeasonalityMode.month:
if self.seasonality is SeasonalityMode.month:
prev_date = result_template.index[i] - pd.DateOffset(months=w)
elif self.seasonality == SeasonalityMode.year:
elif self.seasonality is SeasonalityMode.year:
prev_date = result_template.index[i] - pd.DateOffset(years=w)
else:
assert_never(self.seasonality)

result_template.loc[index[i]] += context.loc[prev_date]

Expand Down Expand Up @@ -301,5 +313,17 @@ def predict(self, ts: TSDataset, prediction_size: int, return_components: bool =
ts.df = new_df
return ts

def params_to_tune(self) -> Dict[str, "BaseDistribution"]:
"""Get default grid for tuning hyperparameters.
This grid doesn't tune ``seasonality`` parameter. It expected to be set by the user.
Returns
-------
:
Grid to tune.
"""
return {"window": IntUniformDistribution(low=1, high=10, step=1)}


__all__ = ["DeadlineMovingAverageModel"]
18 changes: 18 additions & 0 deletions etna/models/naive.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
from typing import Dict

from etna import SETTINGS
from etna.models.seasonal_ma import SeasonalMovingAverageModel

if SETTINGS.auto_required:
from optuna.distributions import BaseDistribution


class NaiveModel(SeasonalMovingAverageModel):
"""Naive model predicts t-th value of series with its (t - lag) value.
Expand All @@ -22,5 +28,17 @@ def __init__(self, lag: int = 1):
self.lag = lag
super().__init__(window=1, seasonality=lag)

def params_to_tune(self) -> Dict[str, "BaseDistribution"]:
"""Get default grid for tuning hyperparameters.
This grid is empty.
Returns
-------
:
Grid to tune.
"""
return {}


__all__ = ["NaiveModel"]
18 changes: 18 additions & 0 deletions etna/models/seasonal_ma.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
import warnings
from typing import Dict

import numpy as np
import pandas as pd

from etna import SETTINGS
from etna.datasets import TSDataset
from etna.models.base import NonPredictionIntervalContextRequiredAbstractModel

if SETTINGS.auto_required:
from optuna.distributions import BaseDistribution
from optuna.distributions import IntUniformDistribution


class SeasonalMovingAverageModel(
NonPredictionIntervalContextRequiredAbstractModel,
Expand Down Expand Up @@ -191,5 +197,17 @@ def predict(self, ts: TSDataset, prediction_size: int, return_components: bool =
ts.df = new_df
return ts

def params_to_tune(self) -> Dict[str, "BaseDistribution"]:
"""Get default grid for tuning hyperparameters.
This grid doesn't tune ``seasonality`` parameter. It expected to be set by the user.
Returns
-------
:
Grid to tune.
"""
return {"window": IntUniformDistribution(low=1, high=10, step=1)}


__all__ = ["SeasonalMovingAverageModel"]
2 changes: 1 addition & 1 deletion etna/pipeline/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
Timestamp = Union[str, pd.Timestamp]


class CrossValidationMode(Enum):
class CrossValidationMode(str, Enum):
"""Enum for different cross-validation modes."""

expand = "expand"
Expand Down
21 changes: 21 additions & 0 deletions tests/test_models/test_simple_models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np
import pandas as pd
import pytest
from optuna.samplers import RandomSampler

from etna.datasets import TSDataset
from etna.datasets import generate_ar_df
Expand Down Expand Up @@ -729,3 +730,23 @@ def test_deadline_model_forecast_correct_with_big_horizons(two_month_ts):
)
def test_save_load(model, example_tsds):
assert_model_equals_loaded_original(model=model, ts=example_tsds, transforms=[], horizon=3)


@pytest.mark.parametrize(
"model, expected_length",
[
(NaiveModel(), 0),
(MovingAverageModel(), 1),
(SeasonalMovingAverageModel(), 1),
(DeadlineMovingAverageModel(), 1),
],
)
def test_params_to_tune(model, expected_length):
grid = model.params_to_tune()
# we need sampler to get a value from distribution
sampler = RandomSampler()

assert len(grid) == expected_length
for name, distribution in grid.items():
value = sampler.sample_independent(study=None, trial=None, param_name=name, param_distribution=distribution)
_ = model.set_params(**{name: value})

0 comments on commit db8cd3a

Please sign in to comment.