Skip to content

Add method for determining the size of the context #855

Merged
merged 2 commits into from
Aug 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions etna/models/autoarima.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ class AutoARIMAModel(PerSegmentPredictionIntervalModel):
We use :py:class:`pmdarima.arima.arima.ARIMA`.
"""

context_size = 0

def __init__(
self,
**kwargs,
Expand Down
11 changes: 11 additions & 0 deletions etna/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,12 @@ def _forecast_segment(model, segment: Union[str, List[str]], ts: TSDataset) -> p
class FitAbstractModel(ABC):
"""Interface for model with fit method."""

@property
@abstractmethod
def context_size(self) -> int:
    """Upper bound on the context size of the model.

    Abstract read-only property: every concrete model has to provide its own value.

    Returns
    -------
    :
        Upper bound on the context size as an integer.
    """

@abstractmethod
def fit(self, ts: TSDataset) -> "FitAbstractModel":
"""Fit model.
Expand Down Expand Up @@ -672,6 +678,11 @@ def __init__(
self.trainer_params = {} if trainer_params is None else trainer_params
self.split_params = {} if split_params is None else split_params

@property
def context_size(self) -> int:
    """Context size of the model.

    Returns
    -------
    :
        The value stored in ``encoder_length``.
    """
    encoder_length = self.encoder_length
    return encoder_length

@log_decorator
def fit(self, ts: TSDataset) -> "DeepBaseModel":
"""Fit model.
Expand Down
8 changes: 8 additions & 0 deletions etna/models/catboost.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,8 @@ class CatBoostPerSegmentModel(PerSegmentModel):
2020-04-16 8.00 6.00 2.00 0.00
"""

context_size = 0

def __init__(
self,
iterations: Optional[int] = None,
Expand Down Expand Up @@ -253,6 +255,8 @@ class CatBoostMultiSegmentModel(MultiSegmentModel):
2020-04-16 8.00 6.00 2.00 0.00
"""

context_size = 0

def __init__(
self,
iterations: Optional[int] = None,
Expand Down Expand Up @@ -383,6 +387,8 @@ class CatBoostModelPerSegment(CatBoostPerSegmentModel):
2020-04-16 8.00 6.00 2.00 0.00
"""

context_size = 0

def __init__(
self,
iterations: Optional[int] = None,
Expand Down Expand Up @@ -512,6 +518,8 @@ class CatBoostModelMultiSegment(CatBoostMultiSegmentModel):
2020-04-16 8.00 6.00 2.00 0.00
"""

context_size = 0

def __init__(
self,
iterations: Optional[int] = None,
Expand Down
27 changes: 27 additions & 0 deletions etna/models/deadline_ma.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def __init__(self, window: int = 3, seasonality: str = "month"):
self.window = window
self.seasonality = SeasonalityMode(seasonality)
self.freqs_available = {"H", "D"}
self._freq = None

def fit(self, df: pd.DataFrame, regressors: List[str]) -> "_DeadlineMovingAverageModel":
"""
Expand Down Expand Up @@ -92,6 +93,7 @@ def fit(self, df: pd.DataFrame, regressors: List[str]) -> "_DeadlineMovingAverag
self.series = targets.loc[timestamps >= first_index]
self.timestamps = timestamps.loc[timestamps >= first_index]
self.shift = len(self.series)
self._freq = freq

return self

Expand Down Expand Up @@ -134,6 +136,24 @@ def predict(self, df: pd.DataFrame) -> np.ndarray:

return res[-len(df) :]

@property
def context_size(self) -> int:
    """Upper bound to context size of the model.

    The bound is the longest possible seasonal period (366 for yearly seasonality,
    31 for monthly seasonality), multiplied by 24 when the fitted frequency is ``"H"``,
    and then multiplied by ``window``.

    Returns
    -------
    :
        Upper bound on the context size, expressed in timestamps of the fitted frequency.

    Raises
    ------
    ValueError:
        If the model has not been fitted yet (the frequency is unknown),
        or if the seasonality mode is not one of the handled values.
    """
    # The frequency is only known after ``fit``; without it the bound cannot be computed.
    if self._freq is None:
        raise ValueError("Model is not fitted! Fit the model before trying the find out context size!")

    if self.seasonality is SeasonalityMode.year:
        max_period_days = 366
    elif self.seasonality is SeasonalityMode.month:
        max_period_days = 31
    else:
        # Fail loudly instead of multiplying ``None`` and raising an opaque TypeError.
        raise ValueError(f"Unsupported seasonality for context size computation: {self.seasonality!r}")

    # Hourly data has 24 timestamps per day; any other frequency keeps the per-day count.
    timestamps_per_day = 24 if self._freq == "H" else 1

    return max_period_days * timestamps_per_day * self.window


class DeadlineMovingAverageModel(PerSegmentModel):
"""Moving average model that uses exact previous dates to predict."""
Expand All @@ -155,6 +175,13 @@ def __init__(self, window: int = 3, seasonality: str = "month"):
base_model=_DeadlineMovingAverageModel(window=window, seasonality=seasonality)
)

@property
def context_size(self) -> int:
    """Upper bound to context size of the model.

    The value is read from the first internal model returned by ``get_model``.

    Returns
    -------
    :
        ``context_size`` of the first internal model.
    """
    return next(iter(self.get_model().values())).context_size

def get_model(self) -> Dict[str, "DeadlineMovingAverageModel"]:
"""Get internal model.

Expand Down
6 changes: 6 additions & 0 deletions etna/models/holt_winters.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,8 @@ class HoltWintersModel(PerSegmentModel):
We use :py:class:`statsmodels.tsa.holtwinters.ExponentialSmoothing` model from statsmodels package.
"""

context_size = 0

def __init__(
self,
trend: Optional[str] = None,
Expand Down Expand Up @@ -482,6 +484,8 @@ class HoltModel(HoltWintersModel):
as a restricted version of :py:class:`~statsmodels.tsa.holtwinters.ExponentialSmoothing` model.
"""

context_size = 0

def __init__(
self,
exponential: bool = False,
Expand Down Expand Up @@ -579,6 +583,8 @@ class SimpleExpSmoothingModel(HoltWintersModel):
as a restricted version of :py:class:`~statsmodels.tsa.holtwinters.ExponentialSmoothing` model.
"""

context_size = 0

def __init__(
self,
initialization_method: str = "estimated",
Expand Down
2 changes: 2 additions & 0 deletions etna/models/nn/deepar.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ class DeepARModel(MultiSegmentPredictionIntervalModel, _DeepCopyMixin):
It`s not right pattern of using Transforms and TSDataset.
"""

context_size = 0

def __init__(
self,
batch_size: int = 64,
Expand Down
2 changes: 2 additions & 0 deletions etna/models/nn/tft.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ class TFTModel(MultiSegmentPredictionIntervalModel, _DeepCopyMixin):
It`s not right pattern of using Transforms and TSDataset.
"""

context_size = 0

def __init__(
self,
max_epochs: int = 10,
Expand Down
2 changes: 2 additions & 0 deletions etna/models/prophet.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,8 @@ class ProphetModel(PerSegmentPredictionIntervalModel):
2020-04-16 8.00 6.00 2.00 0.00
"""

context_size = 0

def __init__(
self,
growth: str = "linear",
Expand Down
2 changes: 2 additions & 0 deletions etna/models/sarimax.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,8 @@ class SARIMAXModel(PerSegmentPredictionIntervalModel):
future.
"""

context_size = 0

def __init__(
self,
order: Tuple[int, int, int] = (2, 1, 0),
Expand Down
5 changes: 5 additions & 0 deletions etna/models/seasonal_ma.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,11 @@ def __init__(self, window: int = 5, seasonality: int = 7):
base_model=_SeasonalMovingAverageModel(window=window, seasonality=seasonality)
)

@property
def context_size(self) -> int:
    """Context size of the model.

    Returns
    -------
    :
        Product of ``window`` and ``seasonality``.
    """
    window, seasonality = self.window, self.seasonality
    return window * seasonality

def get_model(self) -> Dict[str, "SeasonalMovingAverageModel"]:
"""Get internal model.

Expand Down
4 changes: 4 additions & 0 deletions etna/models/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ def get_model(self) -> RegressorMixin:
class SklearnPerSegmentModel(PerSegmentModel):
"""Class for holding per segment Sklearn model."""

context_size = 0

def __init__(self, regressor: RegressorMixin):
"""
Create instance of SklearnPerSegmentModel with given parameters.
Expand All @@ -90,6 +92,8 @@ def __init__(self, regressor: RegressorMixin):
class SklearnMultiSegmentModel(MultiSegmentModel):
"""Class for holding Sklearn model for all segments."""

context_size = 0

def __init__(self, regressor: RegressorMixin):
"""
Create instance of SklearnMultiSegmentModel with given parameters.
Expand Down
4 changes: 4 additions & 0 deletions etna/models/tbats.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ def get_model(self) -> Estimator:
class BATSModel(PerSegmentPredictionIntervalModel):
"""Class for holding segment interval BATS model."""

context_size = 0

def __init__(
self,
use_box_cox: Optional[bool] = None,
Expand Down Expand Up @@ -142,6 +144,8 @@ def __init__(
class TBATSModel(PerSegmentPredictionIntervalModel):
"""Class for holding segment interval TBATS model."""

context_size = 0

def __init__(
self,
use_box_cox: Optional[bool] = None,
Expand Down
11 changes: 11 additions & 0 deletions tests/test_models/nn/test_rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,14 @@ def test_rnn_make_samples(example_df):
assert first_sample["decoder_target"].shape == (decoder_length, 1)
np.testing.assert_equal(example_df[["target"]].iloc[: encoder_length - 1], first_sample["encoder_real"])
np.testing.assert_equal(example_df[["target"]].iloc[1:encoder_length], second_sample["encoder_real"])


@pytest.mark.parametrize("encoder_length", [1, 2, 10])
def test_context_size(encoder_length):
    """Check that ``RNNModel.context_size`` equals the configured ``encoder_length``."""
    # The decoder length is set equal to the encoder length.
    rnn = RNNModel(
        input_size=1,
        encoder_length=encoder_length,
        decoder_length=encoder_length,
        trainer_params={"max_epochs": 100},
    )

    assert rnn.context_size == encoder_length
47 changes: 47 additions & 0 deletions tests/test_models/test_simple_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,53 @@ def test_deadline_moving_average_forecaster_correct(df):
assert np.all(res.values == answer.values)


@pytest.mark.parametrize(
    "model",
    [
        SeasonalMovingAverageModel(window=1, seasonality=1),
        SeasonalMovingAverageModel(window=3, seasonality=1),
        SeasonalMovingAverageModel(window=1, seasonality=3),
        SeasonalMovingAverageModel(window=3, seasonality=7),
        MovingAverageModel(window=3),
        NaiveModel(lag=5),
    ],
)
def test_context_size_seasonal_ma(model):
    """Check that ``context_size`` equals ``window * seasonality`` for each parametrized model."""
    assert model.context_size == model.window * model.seasonality


@pytest.mark.parametrize(
    "model, freq, expected_context_size",
    [
        (DeadlineMovingAverageModel(window=1, seasonality="month"), "D", 31),
        (DeadlineMovingAverageModel(window=3, seasonality="month"), "D", 3 * 31),
        (DeadlineMovingAverageModel(window=1, seasonality="year"), "D", 366),
        (DeadlineMovingAverageModel(window=3, seasonality="year"), "D", 3 * 366),
        (DeadlineMovingAverageModel(window=1, seasonality="month"), "H", 31 * 24),
        (DeadlineMovingAverageModel(window=3, seasonality="month"), "H", 3 * 31 * 24),
        (DeadlineMovingAverageModel(window=1, seasonality="year"), "H", 366 * 24),
        (DeadlineMovingAverageModel(window=3, seasonality="year"), "H", 3 * 366 * 24),
    ],
)
def test_context_size_deadline_ma(model, freq, expected_context_size):
    """Check ``context_size`` of a fitted ``DeadlineMovingAverageModel`` for each window/seasonality/freq combination."""
    # Single-segment series with a constant target, slightly longer than the expected context.
    timestamps = pd.date_range(start="2020-01-01", periods=expected_context_size + 10, freq=freq)
    flat_df = pd.DataFrame({"timestamp": timestamps, "segment": "segment_0", "target": 1})
    ts = TSDataset(df=TSDataset.to_dataset(flat_df), freq=freq)

    # The context size is only available after fitting.
    model.fit(ts)

    assert model.context_size == expected_context_size


@pytest.mark.parametrize(
"etna_model_class",
(SeasonalMovingAverageModel, MovingAverageModel, NaiveModel, DeadlineMovingAverageModel),
Expand Down