Skip to content

[BUG] Raise errors in models.nn if they can't make in-sample and some cases out-sample predictions #813

Merged
merged 6 commits into from
Jul 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Add `known_future` parameter to CLI ([#758](https://github.com/tinkoff-ai/etna/pull/758))
- FutureWarning: The frame.append method is deprecated. Use pandas.concat instead ([#764](https://github.com/tinkoff-ai/etna/pull/764))
- Correct ordering if multi-index in backtest ([#771](https://github.com/tinkoff-ai/etna/pull/771))
-
- Raise errors in models.nn if they can't make in-sample and some cases out-sample predictions ([#813](https://github.com/tinkoff-ai/etna/pull/813))
-
-
-
Expand Down
21 changes: 19 additions & 2 deletions etna/models/nn/deepar.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,8 @@ def __init__(
self.quantiles_kwargs = quantiles_kwargs if quantiles_kwargs is not None else dict()
self.model: Optional[Union[LightningModule, DeepAR]] = None
self.trainer: Optional[pl.Trainer] = None
self._last_train_timestamp = None
self._freq: Optional[str] = None

def _from_dataset(self, ts_dataset: TimeSeriesDataSet) -> LightningModule:
"""
Expand Down Expand Up @@ -145,6 +147,8 @@ def fit(self, ts: TSDataset) -> "DeepARModel":
-------
DeepARModel
"""
self._last_train_timestamp = ts.df.index[-1]
self._freq = ts.freq
pf_transform = self._get_pf_transform(ts)
self.model = self._from_dataset(pf_transform.pf_dataset_train)

Expand Down Expand Up @@ -186,6 +190,19 @@ def forecast(
TSDataset
TSDataset with predictions.
"""
if ts.index[0] <= self._last_train_timestamp:
raise NotImplementedError(
"It is not possible to make in-sample predictions with DeepAR model! "
"In-sample predictions aren't supported by current implementation."
)
elif ts.index[0] != pd.date_range(self._last_train_timestamp, periods=2, freq=self._freq)[-1]:
raise NotImplementedError(
"You can only forecast from the next point after the last one in the training dataset: "
f"last train timestamp: {self._last_train_timestamp}, first test timestamp is {ts.index[0]}"
)
else:
pass

pf_transform = self._get_pf_transform(ts)
if pf_transform.pf_dataset_predict is None:
raise ValueError(
Expand All @@ -197,7 +214,7 @@ def forecast(

predicts = self.model.predict(prediction_dataloader).numpy() # type: ignore
# shape (segments, encoder_length)
ts.loc[:, pd.IndexSlice[:, "target"]] = predicts.T[-len(ts.df) :]
ts.loc[:, pd.IndexSlice[:, "target"]] = predicts.T[: len(ts.df)]

if prediction_interval:
quantiles_predicts = self.model.predict( # type: ignore
Expand All @@ -215,7 +232,7 @@ def forecast(
segments = ts.segments
quantile_columns = [f"target_{quantile:.4g}" for quantile in quantiles]
columns = pd.MultiIndex.from_product([segments, quantile_columns])
quantiles_df = pd.DataFrame(quantiles_predicts, columns=columns, index=df.index)
quantiles_df = pd.DataFrame(quantiles_predicts[: len(df)], columns=columns, index=df.index)
df = pd.concat((df, quantiles_df), axis=1)
df = df.sort_index(axis=1)
ts.df = df
Expand Down
21 changes: 19 additions & 2 deletions etna/models/nn/tft.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,8 @@ def __init__(
self.quantiles_kwargs = quantiles_kwargs if quantiles_kwargs is not None else dict()
self.model: Optional[Union[LightningModule, TemporalFusionTransformer]] = None
self.trainer: Optional[pl.Trainer] = None
self._last_train_timestamp = None
self._freq: Optional[str] = None

def _from_dataset(self, ts_dataset: TimeSeriesDataSet) -> LightningModule:
"""
Expand Down Expand Up @@ -152,6 +154,8 @@ def fit(self, ts: TSDataset) -> "TFTModel":
-------
TFTModel
"""
self._last_train_timestamp = ts.df.index[-1]
self._freq = ts.freq
pf_transform = self._get_pf_transform(ts)
self.model = self._from_dataset(pf_transform.pf_dataset_train)

Expand Down Expand Up @@ -193,6 +197,19 @@ def forecast(
TSDataset
TSDataset with predictions.
"""
if ts.index[0] <= self._last_train_timestamp:
raise NotImplementedError(
"It is not possible to make in-sample predictions with TFT model! "
"In-sample predictions aren't supported by current implementation."
)
elif ts.index[0] != pd.date_range(self._last_train_timestamp, periods=2, freq=self._freq)[-1]:
raise NotImplementedError(
"You can only forecast from the next point after the last one in the training dataset: "
f"last train timestamp: {self._last_train_timestamp}, first test timestamp is {ts.index[0]}"
)
else:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need this else block?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd like to have all possible switches explicitly

pass

pf_transform = self._get_pf_transform(ts)
if pf_transform.pf_dataset_predict is None:
raise ValueError(
Expand All @@ -204,7 +221,7 @@ def forecast(

predicts = self.model.predict(prediction_dataloader).numpy() # type: ignore
# shape (segments, encoder_length)
ts.loc[:, pd.IndexSlice[:, "target"]] = predicts.T[-len(ts.df) :]
ts.loc[:, pd.IndexSlice[:, "target"]] = predicts.T[: len(ts.df)]

if prediction_interval:
if not isinstance(self.loss, QuantileLoss):
Expand Down Expand Up @@ -247,7 +264,7 @@ def forecast(
segments = ts.segments
quantile_columns = [f"target_{quantile:.4g}" for quantile in quantiles]
columns = pd.MultiIndex.from_product([segments, quantile_columns])
quantiles_df = pd.DataFrame(quantiles_predicts, columns=columns, index=df.index)
quantiles_df = pd.DataFrame(quantiles_predicts[: len(df)], columns=columns, index=df.index)
df = pd.concat((df, quantiles_df), axis=1)
df = df.sort_index(axis=1)
ts.df = df
Expand Down
1 change: 1 addition & 0 deletions tests/test_models/nn/test_deepar.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ def test_forecast_without_make_future(weekly_period_df):

model = DeepARModel(max_epochs=1)
model.fit(ts)
ts.df.index = ts.df.index + pd.Timedelta(days=len(ts.df))
with pytest.raises(ValueError, match="The future is not generated!"):
_ = model.forecast(ts=ts)

Expand Down
2 changes: 2 additions & 0 deletions tests/test_models/nn/test_tft.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import pandas as pd
import pytest

from etna.datasets.tsdataset import TSDataset
Expand Down Expand Up @@ -114,6 +115,7 @@ def test_forecast_without_make_future(weekly_period_df):

model = TFTModel(max_epochs=1)
model.fit(ts)
ts.df.index = ts.df.index + pd.Timedelta(days=len(ts.df))
with pytest.raises(ValueError, match="The future is not generated!"):
_ = model.forecast(ts=ts)

Expand Down
116 changes: 56 additions & 60 deletions tests/test_models/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,14 +69,20 @@ def _test_forecast_out_sample_prefix(ts, model, transforms):
# fitting
ts.fit_transform(transforms)
model.fit(ts)

# forecasting full
forecast_full_ts = ts.make_future(5)

import torch # TODO: remove after fix at issue-802

torch.manual_seed(11)

model.forecast(forecast_full_ts)

# forecasting only prefix
forecast_prefix_ts = ts.make_future(5)
forecast_prefix_ts.df = forecast_prefix_ts.df.iloc[:-2]

torch.manual_seed(11) # TODO: remove after fix at issue-802
model.forecast(forecast_prefix_ts)

# checking
Expand Down Expand Up @@ -168,6 +174,17 @@ def test_forecast_in_sample_full(model, transforms, example_tsds):
(LinearMultiSegmentModel(), [LagTransform(in_column="target", lags=[2, 3])]),
(ElasticPerSegmentModel(), [LagTransform(in_column="target", lags=[2, 3])]),
(ElasticMultiSegmentModel(), [LagTransform(in_column="target", lags=[2, 3])]),
],
)
def test_forecast_in_sample_full_failed(model, transforms, example_tsds):
_test_forecast_in_sample_full(example_tsds, model, transforms)


@pytest.mark.parametrize(
"model, transforms",
[
(BATSModel(use_trend=True), []),
(TBATSModel(use_trend=True), []),
(
DeepARModel(max_epochs=1, learning_rate=[0.01]),
[
Expand Down Expand Up @@ -196,17 +213,6 @@ def test_forecast_in_sample_full(model, transforms, example_tsds):
),
],
)
def test_forecast_in_sample_full_failed(model, transforms, example_tsds):
_test_forecast_in_sample_full(example_tsds, model, transforms)


@pytest.mark.parametrize(
"model, transforms",
[
(BATSModel(use_trend=True), []),
(TBATSModel(use_trend=True), []),
],
)
def test_forecast_in_sample_full_not_implemented(model, transforms, example_tsds):
with pytest.raises(NotImplementedError, match="It is not possible to make in-sample predictions"):
_test_forecast_in_sample_full(example_tsds, model, transforms)
Expand Down Expand Up @@ -236,10 +242,11 @@ def test_forecast_in_sample_suffix(model, transforms, example_tsds):
_test_forecast_in_sample_suffix(example_tsds, model, transforms)


@pytest.mark.xfail(strict=True)
@pytest.mark.parametrize(
"model, transforms",
[
(BATSModel(use_trend=True), []),
(TBATSModel(use_trend=True), []),
(
DeepARModel(max_epochs=1, learning_rate=[0.01]),
[
Expand Down Expand Up @@ -268,17 +275,6 @@ def test_forecast_in_sample_suffix(model, transforms, example_tsds):
),
],
)
def test_forecast_in_sample_suffix_failed(model, transforms, example_tsds):
_test_forecast_in_sample_suffix(example_tsds, model, transforms)


@pytest.mark.parametrize(
"model, transforms",
[
(BATSModel(use_trend=True), []),
(TBATSModel(use_trend=True), []),
],
)
def test_forecast_in_sample_suffix_not_implemented(model, transforms, example_tsds):
with pytest.raises(NotImplementedError, match="It is not possible to make in-sample predictions"):
_test_forecast_in_sample_suffix(example_tsds, model, transforms)
Expand All @@ -304,16 +300,6 @@ def test_forecast_in_sample_suffix_not_implemented(model, transforms, example_ts
(NaiveModel(lag=3), []),
(BATSModel(use_trend=True), []),
(TBATSModel(use_trend=True), []),
],
)
def test_forecast_out_sample_prefix(model, transforms, example_tsds):
_test_forecast_out_sample_prefix(example_tsds, model, transforms)


@pytest.mark.xfail(strict=True)
@pytest.mark.parametrize(
"model, transforms",
[
(
DeepARModel(max_epochs=5, learning_rate=[0.01]),
[
Expand Down Expand Up @@ -342,7 +328,7 @@ def test_forecast_out_sample_prefix(model, transforms, example_tsds):
),
],
)
def test_forecast_out_sample_prefix_failed(model, transforms, example_tsds):
def test_forecast_out_sample_prefix(model, transforms, example_tsds):
_test_forecast_out_sample_prefix(example_tsds, model, transforms)


Expand All @@ -362,6 +348,15 @@ def test_forecast_out_sample_prefix_failed(model, transforms, example_tsds):
(SimpleExpSmoothingModel(), []),
(BATSModel(use_trend=True), []),
(TBATSModel(use_trend=True), []),
],
)
def test_forecast_out_sample_suffix(model, transforms, example_tsds):
_test_forecast_out_sample_suffix(example_tsds, model, transforms)


@pytest.mark.parametrize(
"model, transforms",
[
(
TFTModel(max_epochs=1, learning_rate=[0.01]),
[
Expand All @@ -376,20 +371,6 @@ def test_forecast_out_sample_prefix_failed(model, transforms, example_tsds):
)
],
),
],
)
def test_forecast_out_sample_suffix(model, transforms, example_tsds):
_test_forecast_out_sample_suffix(example_tsds, model, transforms)


@pytest.mark.xfail(strict=True)
@pytest.mark.parametrize(
"model, transforms",
[
(AutoARIMAModel(), []),
(MovingAverageModel(window=3), []),
(SeasonalMovingAverageModel(), []),
(NaiveModel(lag=3), []),
(
DeepARModel(max_epochs=5, learning_rate=[0.01]),
[
Expand All @@ -404,6 +385,21 @@ def test_forecast_out_sample_suffix(model, transforms, example_tsds):
),
],
)
def test_forecast_out_sample_suffix_not_implemented(model, transforms, example_tsds):
with pytest.raises(NotImplementedError, match="You can only forecast from the next point after the last one"):
_test_forecast_out_sample_suffix(example_tsds, model, transforms)


@pytest.mark.xfail(strict=True)
@pytest.mark.parametrize(
"model, transforms",
[
(AutoARIMAModel(), []),
(MovingAverageModel(window=3), []),
(SeasonalMovingAverageModel(), []),
(NaiveModel(lag=3), []),
],
)
def test_forecast_out_sample_suffix_failed(model, transforms, example_tsds):
_test_forecast_out_sample_suffix(example_tsds, model, transforms)

Expand Down Expand Up @@ -433,6 +429,17 @@ def test_forecast_mixed_in_out_sample(model, transforms, example_tsds):
[
(SARIMAXModel(), []),
(AutoARIMAModel(), []),
],
)
def test_forecast_mixed_in_out_sample_failed(model, transforms, example_tsds):
_test_forecast_mixed_in_out_sample(example_tsds, model, transforms)


@pytest.mark.parametrize(
"model, transforms",
[
(BATSModel(use_trend=True), []),
(TBATSModel(use_trend=True), []),
(
DeepARModel(max_epochs=5, learning_rate=[0.01]),
[
Expand Down Expand Up @@ -461,17 +468,6 @@ def test_forecast_mixed_in_out_sample(model, transforms, example_tsds):
),
],
)
def test_forecast_mixed_in_out_sample_failed(model, transforms, example_tsds):
_test_forecast_mixed_in_out_sample(example_tsds, model, transforms)


@pytest.mark.parametrize(
"model, transforms",
[
(BATSModel(use_trend=True), []),
(TBATSModel(use_trend=True), []),
],
)
def test_forecast_mixed_in_out_sample_not_implemented(model, transforms, example_tsds):
with pytest.raises(NotImplementedError, match="It is not possible to make in-sample predictions"):
_test_forecast_mixed_in_out_sample(example_tsds, model, transforms)