Skip to content

Add params_to_tune for SARIMAXModel model #1206

Merged
merged 6 commits into from
Apr 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Remove version python-3.7 from `pyproject.toml`, update lock ([#1183](https://github.com/tinkoff-ai/etna/pull/1183))
- Add default `params_to_tune` for catboost models ([#1185](https://github.com/tinkoff-ai/etna/pull/1185))
- Add default `params_to_tune` for `ProphetModel` ([#1203](https://github.com/tinkoff-ai/etna/pull/1203))
- Add default `params_to_tune` for `SARIMAXModel`, change default parameters for the model ([#1206](https://github.com/tinkoff-ai/etna/pull/1206))
- Add default `params_to_tune` for linear models ([#1204](https://github.com/tinkoff-ai/etna/pull/1204))
### Fixed
- Fix bug in `GaleShapleyFeatureSelectionTransform` with wrong number of remaining features ([#1110](https://github.com/tinkoff-ai/etna/pull/1110))
Expand Down
49 changes: 43 additions & 6 deletions etna/models/sarimax.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import warnings
from abc import abstractmethod
from datetime import datetime
from typing import Dict
from typing import List
from typing import Optional
from typing import Sequence
Expand All @@ -13,13 +14,19 @@
from statsmodels.tsa.statespace.sarimax import SARIMAXResultsWrapper
from statsmodels.tsa.statespace.simulation_smoother import SimulationSmoother

from etna import SETTINGS
from etna.libs.pmdarima_utils import seasonal_prediction_with_confidence
from etna.models.base import BaseAdapter
from etna.models.base import PredictionIntervalContextIgnorantAbstractModel
from etna.models.mixins import PerSegmentModelMixin
from etna.models.mixins import PredictionIntervalContextIgnorantModelMixin
from etna.models.utils import determine_num_steps

if SETTINGS.auto_required:
from optuna.distributions import BaseDistribution
from optuna.distributions import CategoricalDistribution
from optuna.distributions import IntUniformDistribution

warnings.filterwarnings(
message="No frequency information was provided, so inferred frequency .* will be used",
action="ignore",
Expand Down Expand Up @@ -374,9 +381,9 @@ class _SARIMAXAdapter(_SARIMAXBaseAdapter):

def __init__(
self,
order: Tuple[int, int, int] = (2, 1, 0),
seasonal_order: Tuple[int, int, int, int] = (1, 1, 0, 12),
trend: Optional[str] = "c",
order: Tuple[int, int, int] = (1, 0, 0),
alex-hse-repository marked this conversation as resolved.
Show resolved Hide resolved
seasonal_order: Tuple[int, int, int, int] = (0, 0, 0, 0),
trend: Optional[str] = None,
measurement_error: bool = False,
time_varying_regression: bool = False,
mle_regression: bool = True,
Expand Down Expand Up @@ -552,9 +559,9 @@ class SARIMAXModel(

def __init__(
self,
order: Tuple[int, int, int] = (2, 1, 0),
seasonal_order: Tuple[int, int, int, int] = (1, 1, 0, 12),
trend: Optional[str] = "c",
order: Tuple[int, int, int] = (1, 0, 0),
seasonal_order: Tuple[int, int, int, int] = (0, 0, 0, 0),
trend: Optional[str] = None,
measurement_error: bool = False,
time_varying_regression: bool = False,
mle_regression: bool = True,
Expand Down Expand Up @@ -698,3 +705,33 @@ def __init__(
**self.kwargs,
)
)

def params_to_tune(self) -> Dict[str, "BaseDistribution"]:
    """Build the default search space for hyperparameter tuning.

    The seasonal period ``seasonal_order.s`` is never tuned; it is read as set by the user.
    If it equals zero, only ``order`` and ``trend`` are tuned and ``order.0``/``order.2`` go up to 6.
    Otherwise ``order.0``/``order.2`` go up to ``s - 1`` and the first three components
    of ``seasonal_order`` are added to the grid.

    Returns
    -------
    :
        Mapping from parameter name to the distribution it is sampled from.
    """
    num_periods = self.seasonal_order[3]
    is_seasonal = num_periods != 0
    # upper bound shared by the AR (order.0) and MA (order.2) components
    max_lag = num_periods - 1 if is_seasonal else 6

    grid: Dict[str, "BaseDistribution"] = {
        "order.0": IntUniformDistribution(low=1, high=max_lag, step=1),
        "order.1": IntUniformDistribution(low=1, high=2, step=1),
        "order.2": IntUniformDistribution(low=1, high=max_lag, step=1),
    }
    if is_seasonal:
        grid["seasonal_order.0"] = IntUniformDistribution(low=0, high=2, step=1)
        grid["seasonal_order.1"] = IntUniformDistribution(low=0, high=1, step=1)
        grid["seasonal_order.2"] = IntUniformDistribution(low=0, high=1, step=1)
    # "trend" is inserted last to keep the same key order in both variants of the grid
    grid["trend"] = CategoricalDistribution(["n", "c", "t", "ct"])
    return grid
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,11 @@ def test_get_anomalies_prediction_interval_interface(outliers_tsds, model, in_co
0.95,
{"1": [np.datetime64("2021-01-11")], "2": [np.datetime64("2021-01-09"), np.datetime64("2021-01-27")]},
),
(SARIMAXModel, 0.999, {"1": [], "2": [np.datetime64("2021-01-27")]}),
(
SARIMAXModel,
0.999,
{"1": [np.datetime64("2021-01-11")], "2": [np.datetime64("2021-01-09"), np.datetime64("2021-01-27")]},
),
),
)
def test_get_anomalies_prediction_interval_values(outliers_tsds, model, interval_width, true_anomalies, in_column):
Expand Down
17 changes: 16 additions & 1 deletion tests/test_models/test_sarimax_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import numpy as np
import pytest
from optuna.samplers import RandomSampler
from statsmodels.tsa.statespace.sarimax import SARIMAXResultsWrapper

from etna.models import SARIMAXModel
Expand Down Expand Up @@ -169,7 +170,7 @@ def test_decomposition_hamiltonian_repr_error(dfs_w_exog, components_method_name
)
@pytest.mark.parametrize("trend", (None, "t"))
def test_components_names(dfs_w_exog, regressors, regressors_components, trend, components_method_name, in_sample):
expected_components = regressors_components + ["target_component_sarima"]
expected_components = regressors_components + ["target_component_arima"]
alex-hse-repository marked this conversation as resolved.
Show resolved Hide resolved

train, test = dfs_w_exog
pred_df = train if in_sample else test
Expand Down Expand Up @@ -236,3 +237,17 @@ def test_components_sum_up_to_target(
components = components_method(df=pred_df)

np.testing.assert_allclose(np.sum(components.values, axis=1), np.squeeze(pred))


@pytest.mark.parametrize(
    "model",
    [
        SARIMAXModel(seasonal_order=(0, 0, 0, 0)),
        SARIMAXModel(seasonal_order=(0, 0, 0, 7)),
    ],
)
def test_params_to_tune(model):
    """Check that each default tuning distribution yields a value ``set_params`` can take."""
    # a sampler is required to draw a concrete value out of a distribution
    random_sampler = RandomSampler()
    search_space = model.params_to_tune()

    assert len(search_space) > 0
    for param_name, param_distribution in search_space.items():
        sampled_value = random_sampler.sample_independent(
            study=None,
            trial=None,
            param_name=param_name,
            param_distribution=param_distribution,
        )
        # the call must not raise; the returned model is not inspected
        model.set_params(**{param_name: sampled_value})