From ffed41996e67f645893aa58f686dedeb7d13b1e0 Mon Sep 17 00:00:00 2001 From: martins0n <33594071+martins0n@users.noreply.github.com> Date: Fri, 22 Jul 2022 13:35:50 +0300 Subject: [PATCH 1/6] [BUG] Teach `models.nn` to work with in-sample predictions and out-sample predictions correclty Fixes #787 --- etna/models/nn/deepar.py | 17 +++++++++++++++++ etna/models/nn/tft.py | 18 ++++++++++++++++++ 2 files changed, 35 insertions(+) diff --git a/etna/models/nn/deepar.py b/etna/models/nn/deepar.py index 67f0e47ea..777b803fe 100644 --- a/etna/models/nn/deepar.py +++ b/etna/models/nn/deepar.py @@ -101,6 +101,8 @@ def __init__( self.quantiles_kwargs = quantiles_kwargs if quantiles_kwargs is not None else dict() self.model: Optional[Union[LightningModule, DeepAR]] = None self.trainer: Optional[pl.Trainer] = None + self._last_train_timestamp = None + self._freq = None def _from_dataset(self, ts_dataset: TimeSeriesDataSet) -> LightningModule: """ @@ -145,6 +147,8 @@ def fit(self, ts: TSDataset) -> "DeepARModel": ------- DeepARModel """ + self._last_train_timestamp = ts.df.index[-1] + self._freq = ts.freq pf_transform = self._get_pf_transform(ts) self.model = self._from_dataset(pf_transform.pf_dataset_train) @@ -186,6 +190,19 @@ def forecast( TSDataset TSDataset with predictions. """ + if ts.index[0] < self._last_train_timestamp: + raise NotImplementedError( + "It is not possible to make in-sample predictions with DeepAR model! " + "In-sample predictions aren't supported by current implementation." + ) + elif ts.index[0] != pd.date_range(self._last_train_timestamp, periods=2, freq=self._freq)[-1]: + raise NotImplementedError( + "You can only forecast from the next point after the last one in the training dataset: " + f"last train timestamp: {self._last_train_timestamp}, first test timestamp is {ts.index[0]}" + ) + else: + pass + pf_transform = self._get_pf_transform(ts) if pf_transform.pf_dataset_predict is None: raise ValueError( diff --git a/etna/models/nn/tft.py b/etna/models/nn/tft.py index 0f82ea4fb..24a83bf75 100644 --- a/etna/models/nn/tft.py +++ b/etna/models/nn/tft.py @@ -108,6 +108,8 @@ def __init__( self.quantiles_kwargs = quantiles_kwargs if quantiles_kwargs is not None else dict() self.model: Optional[Union[LightningModule, TemporalFusionTransformer]] = None self.trainer: Optional[pl.Trainer] = None + self._last_train_timestamp = None + self._freq = None def _from_dataset(self, ts_dataset: TimeSeriesDataSet) -> LightningModule: """ @@ -152,6 +154,8 @@ def fit(self, ts: TSDataset) -> "TFTModel": ------- TFTModel """ + self._last_train_timestamp = ts.df.index[-1] + self._freq = ts.freq pf_transform = self._get_pf_transform(ts) self.model = self._from_dataset(pf_transform.pf_dataset_train) @@ -193,6 +197,20 @@ def forecast( TSDataset TSDataset with predictions. """ + + if ts.index[0] < self._last_train_timestamp: + raise NotImplementedError( + "It is not possible to make in-sample predictions with TFT model! " + "In-sample predictions aren't supported by current implementation." + ) + elif ts.index[0] != pd.date_range(self._last_train_timestamp, periods=2, freq=self._freq)[-1]: + raise NotImplementedError( + "You can only forecast from the next point after the last one in the training dataset: " + f"last train timestamp: {self._last_train_timestamp}, first test timestamp is {ts.index[0]}" + ) + else: + pass + pf_transform = self._get_pf_transform(ts) if pf_transform.pf_dataset_predict is None: raise ValueError( From f6fa731edccbdb23e70d2f2f7f18e1e0ad3e6668 Mon Sep 17 00:00:00 2001 From: martins0n <33594071+martins0n@users.noreply.github.com> Date: Fri, 22 Jul 2022 15:27:05 +0300 Subject: [PATCH 2/6] FIX: tests --- etna/models/nn/deepar.py | 8 +-- etna/models/nn/tft.py | 9 ++- tests/test_models/nn/test_tft.py | 2 + tests/test_models/test_inference.py | 103 +++++++++++++++------------- 4 files changed, 64 insertions(+), 58 deletions(-) diff --git a/etna/models/nn/deepar.py b/etna/models/nn/deepar.py index 777b803fe..5930236ca 100644 --- a/etna/models/nn/deepar.py +++ b/etna/models/nn/deepar.py @@ -102,7 +102,7 @@ def __init__( self.model: Optional[Union[LightningModule, DeepAR]] = None self.trainer: Optional[pl.Trainer] = None self._last_train_timestamp = None - self._freq = None + self._freq: Optional[str] = None def _from_dataset(self, ts_dataset: TimeSeriesDataSet) -> LightningModule: """ @@ -202,7 +202,7 @@ def forecast( ) else: pass - + pf_transform = self._get_pf_transform(ts) if pf_transform.pf_dataset_predict is None: raise ValueError( @@ -214,7 +214,7 @@ def forecast( predicts = self.model.predict(prediction_dataloader).numpy() # type: ignore # shape (segments, encoder_length) - ts.loc[:, pd.IndexSlice[:, "target"]] = predicts.T[-len(ts.df) :] + ts.loc[:, pd.IndexSlice[:, "target"]] = predicts.T[: len(ts.df)] if prediction_interval: quantiles_predicts = self.model.predict( # type: ignore @@ -232,7 +232,7 @@ def forecast( segments = ts.segments quantile_columns = [f"target_{quantile:.4g}" for quantile in quantiles] columns = pd.MultiIndex.from_product([segments, quantile_columns]) - quantiles_df = pd.DataFrame(quantiles_predicts, columns=columns, index=df.index) + quantiles_df = pd.DataFrame(quantiles_predicts[: len(df)], columns=columns, index=df.index) df = pd.concat((df, quantiles_df), axis=1) df = df.sort_index(axis=1) ts.df = df diff --git a/etna/models/nn/tft.py b/etna/models/nn/tft.py index 24a83bf75..9c626b7e1 100644 --- a/etna/models/nn/tft.py +++ b/etna/models/nn/tft.py @@ -109,7 +109,7 @@ def __init__( self.model: Optional[Union[LightningModule, TemporalFusionTransformer]] = None self.trainer: Optional[pl.Trainer] = None self._last_train_timestamp = None - self._freq = None + self._freq: Optional[str] = None def _from_dataset(self, ts_dataset: TimeSeriesDataSet) -> LightningModule: """ @@ -197,7 +197,6 @@ def forecast( TSDataset TSDataset with predictions. """ - if ts.index[0] < self._last_train_timestamp: raise NotImplementedError( "It is not possible to make in-sample predictions with TFT model! " @@ -210,7 +209,7 @@ def forecast( ) else: pass - + pf_transform = self._get_pf_transform(ts) if pf_transform.pf_dataset_predict is None: raise ValueError( @@ -222,7 +221,7 @@ def forecast( predicts = self.model.predict(prediction_dataloader).numpy() # type: ignore # shape (segments, encoder_length) - ts.loc[:, pd.IndexSlice[:, "target"]] = predicts.T[-len(ts.df) :] + ts.loc[:, pd.IndexSlice[:, "target"]] = predicts.T[: len(ts.df)] if prediction_interval: if not isinstance(self.loss, QuantileLoss): @@ -265,7 +264,7 @@ def forecast( segments = ts.segments quantile_columns = [f"target_{quantile:.4g}" for quantile in quantiles] columns = pd.MultiIndex.from_product([segments, quantile_columns]) - quantiles_df = pd.DataFrame(quantiles_predicts, columns=columns, index=df.index) + quantiles_df = pd.DataFrame(quantiles_predicts[: len(df)], columns=columns, index=df.index) df = pd.concat((df, quantiles_df), axis=1) df = df.sort_index(axis=1) ts.df = df diff --git a/tests/test_models/nn/test_tft.py b/tests/test_models/nn/test_tft.py index 4e31b17c7..5b6ee4601 100644 --- a/tests/test_models/nn/test_tft.py +++ b/tests/test_models/nn/test_tft.py @@ -1,3 +1,4 @@ +import pandas as pd import pytest from etna.datasets.tsdataset import TSDataset @@ -114,6 +115,7 @@ def test_forecast_without_make_future(weekly_period_df): model = TFTModel(max_epochs=1) model.fit(ts) + ts.df.index = ts.df.index + pd.Timedelta(days=len(ts.df)) with pytest.raises(ValueError, match="The future is not generated!"): _ = model.forecast(ts=ts) diff --git a/tests/test_models/test_inference.py b/tests/test_models/test_inference.py index 99e162d8b..87c453a4b 100644 --- a/tests/test_models/test_inference.py +++ b/tests/test_models/test_inference.py @@ -69,14 +69,20 @@ def _test_forecast_out_sample_prefix(ts, model, transforms): # fitting ts.fit_transform(transforms) model.fit(ts) - # forecasting full forecast_full_ts = ts.make_future(5) + + import torch # TODO: remove after fix at issue-802 + + torch.manual_seed(11) + model.forecast(forecast_full_ts) # forecasting only prefix forecast_prefix_ts = ts.make_future(5) forecast_prefix_ts.df = forecast_prefix_ts.df.iloc[:-2] + + torch.manual_seed(11) # TODO: remove after fix at issue-802 model.forecast(forecast_prefix_ts) # checking @@ -168,6 +174,17 @@ def test_forecast_in_sample_full(model, transforms, example_tsds): (LinearMultiSegmentModel(), [LagTransform(in_column="target", lags=[2, 3])]), (ElasticPerSegmentModel(), [LagTransform(in_column="target", lags=[2, 3])]), (ElasticMultiSegmentModel(), [LagTransform(in_column="target", lags=[2, 3])]), + ], +) +def test_forecast_in_sample_full_failed(model, transforms, example_tsds): + _test_forecast_in_sample_full(example_tsds, model, transforms) + + +@pytest.mark.parametrize( + "model, transforms", + [ + (BATSModel(use_trend=True), []), + (TBATSModel(use_trend=True), []), ( DeepARModel(max_epochs=1, learning_rate=[0.01]), [ @@ -196,17 +213,6 @@ def test_forecast_in_sample_full(model, transforms, example_tsds): ), ], ) -def test_forecast_in_sample_full_failed(model, transforms, example_tsds): - _test_forecast_in_sample_full(example_tsds, model, transforms) - - -@pytest.mark.parametrize( - "model, transforms", - [ - (BATSModel(use_trend=True), []), - (TBATSModel(use_trend=True), []), - ], -) def test_forecast_in_sample_full_not_implemented(model, transforms, example_tsds): with pytest.raises(NotImplementedError, match="It is not possible to make in-sample predictions"): _test_forecast_in_sample_full(example_tsds, model, transforms) @@ -236,7 +242,6 @@ def test_forecast_in_sample_suffix(model, transforms, example_tsds): _test_forecast_in_sample_suffix(example_tsds, model, transforms) -@pytest.mark.xfail(strict=True) @pytest.mark.parametrize( "model, transforms", [ @@ -304,16 +309,6 @@ def test_forecast_in_sample_suffix_not_implemented(model, transforms, example_ts (NaiveModel(lag=3), []), (BATSModel(use_trend=True), []), (TBATSModel(use_trend=True), []), - ], -) -def test_forecast_out_sample_prefix(model, transforms, example_tsds): - _test_forecast_out_sample_prefix(example_tsds, model, transforms) - - -@pytest.mark.xfail(strict=True) -@pytest.mark.parametrize( - "model, transforms", - [ ( DeepARModel(max_epochs=5, learning_rate=[0.01]), [ @@ -342,7 +337,7 @@ def test_forecast_out_sample_prefix(model, transforms, example_tsds): ), ], ) -def test_forecast_out_sample_prefix_failed(model, transforms, example_tsds): +def test_forecast_out_sample_prefix(model, transforms, example_tsds): _test_forecast_out_sample_prefix(example_tsds, model, transforms) @@ -362,6 +357,15 @@ def test_forecast_out_sample_prefix_failed(model, transforms, example_tsds): (SimpleExpSmoothingModel(), []), (BATSModel(use_trend=True), []), (TBATSModel(use_trend=True), []), + ], +) +def test_forecast_out_sample_suffix(model, transforms, example_tsds): + _test_forecast_out_sample_suffix(example_tsds, model, transforms) + + +@pytest.mark.parametrize( + "model, transforms", + [ ( TFTModel(max_epochs=1, learning_rate=[0.01]), [ @@ -376,20 +380,6 @@ def test_forecast_out_sample_prefix_failed(model, transforms, example_tsds): ) ], ), - ], -) -def test_forecast_out_sample_suffix(model, transforms, example_tsds): - _test_forecast_out_sample_suffix(example_tsds, model, transforms) - - -@pytest.mark.xfail(strict=True) -@pytest.mark.parametrize( - "model, transforms", - [ - (AutoARIMAModel(), []), - (MovingAverageModel(window=3), []), - (SeasonalMovingAverageModel(), []), - (NaiveModel(lag=3), []), ( DeepARModel(max_epochs=5, learning_rate=[0.01]), [ @@ -404,6 +394,21 @@ def test_forecast_out_sample_suffix(model, transforms, example_tsds): ), ], ) +def test_forecast_out_sample_suffix_not_implemented(model, transforms, example_tsds): + with pytest.raises(NotImplementedError, match="You can only forecast from the next point after the last one"): + _test_forecast_out_sample_suffix(example_tsds, model, transforms) + + +@pytest.mark.xfail(strict=True) +@pytest.mark.parametrize( + "model, transforms", + [ + (AutoARIMAModel(), []), + (MovingAverageModel(window=3), []), + (SeasonalMovingAverageModel(), []), + (NaiveModel(lag=3), []), + ], +) def test_forecast_out_sample_suffix_failed(model, transforms, example_tsds): _test_forecast_out_sample_suffix(example_tsds, model, transforms) @@ -433,6 +438,17 @@ def test_forecast_mixed_in_out_sample(model, transforms, example_tsds): [ (SARIMAXModel(), []), (AutoARIMAModel(), []), + ], +) +def test_forecast_mixed_in_out_sample_failed(model, transforms, example_tsds): + _test_forecast_mixed_in_out_sample(example_tsds, model, transforms) + + +@pytest.mark.parametrize( + "model, transforms", + [ + (BATSModel(use_trend=True), []), + (TBATSModel(use_trend=True), []), ( DeepARModel(max_epochs=5, learning_rate=[0.01]), [ @@ -461,17 +477,6 @@ def test_forecast_mixed_in_out_sample(model, transforms, example_tsds): ), ], ) -def test_forecast_mixed_in_out_sample_failed(model, transforms, example_tsds): - _test_forecast_mixed_in_out_sample(example_tsds, model, transforms) - - -@pytest.mark.parametrize( - "model, transforms", - [ - (BATSModel(use_trend=True), []), - (TBATSModel(use_trend=True), []), - ], -) def test_forecast_mixed_in_out_sample_not_implemented(model, transforms, example_tsds): with pytest.raises(NotImplementedError, match="It is not possible to make in-sample predictions"): _test_forecast_mixed_in_out_sample(example_tsds, model, transforms) From 882f2b9fea9824e766f5fbf42802ea43238a7c92 Mon Sep 17 00:00:00 2001 From: martins0n <33594071+martins0n@users.noreply.github.com> Date: Fri, 22 Jul 2022 15:31:00 +0300 Subject: [PATCH 3/6] FIX: changelog --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 00a791d37..242200f85 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -40,7 +40,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Add `known_future` parameter to CLI ([#758](https://github.com/tinkoff-ai/etna/pull/758)) - FutureWarning: The frame.append method is deprecated. Use pandas.concat instead ([#764](https://github.com/tinkoff-ai/etna/pull/764)) - Correct ordering if multi-index in backtest ([#771](https://github.com/tinkoff-ai/etna/pull/771)) -- +- Raise errors in models.nn if they can't make in-sample and some cases out-sample predictions ([#813](https://github.com/tinkoff-ai/etna/pull/813)) - - - From 2cc0ddd9e933829432a4cb4de68b8e3f27f95dd8 Mon Sep 17 00:00:00 2001 From: martins0n <33594071+martins0n@users.noreply.github.com> Date: Fri, 22 Jul 2022 15:46:02 +0300 Subject: [PATCH 4/6] FIX: test --- tests/test_models/test_inference.py | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/tests/test_models/test_inference.py b/tests/test_models/test_inference.py index 87c453a4b..d1f3a979b 100644 --- a/tests/test_models/test_inference.py +++ b/tests/test_models/test_inference.py @@ -245,6 +245,8 @@ def test_forecast_in_sample_suffix(model, transforms, example_tsds): @pytest.mark.parametrize( "model, transforms", [ + (BATSModel(use_trend=True), []), + (TBATSModel(use_trend=True), []), ( DeepARModel(max_epochs=1, learning_rate=[0.01]), [ @@ -273,17 +275,6 @@ def test_forecast_in_sample_suffix(model, transforms, example_tsds): ), ], ) -def test_forecast_in_sample_suffix_failed(model, transforms, example_tsds): - _test_forecast_in_sample_suffix(example_tsds, model, transforms) - - -@pytest.mark.parametrize( - "model, transforms", - [ - (BATSModel(use_trend=True), []), - (TBATSModel(use_trend=True), []), - ], -) def test_forecast_in_sample_suffix_not_implemented(model, transforms, example_tsds): with pytest.raises(NotImplementedError, match="It is not possible to make in-sample predictions"): _test_forecast_in_sample_suffix(example_tsds, model, transforms) From f7e97aad8446f4ee3bf82d71a8846db13ed1e39b Mon Sep 17 00:00:00 2001 From: martins0n <33594071+martins0n@users.noreply.github.com> Date: Fri, 22 Jul 2022 16:13:37 +0300 Subject: [PATCH 5/6] FIX: deepar test --- tests/test_models/nn/test_deepar.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_models/nn/test_deepar.py b/tests/test_models/nn/test_deepar.py index 04acb90ce..63eb2a823 100644 --- a/tests/test_models/nn/test_deepar.py +++ b/tests/test_models/nn/test_deepar.py @@ -119,6 +119,7 @@ def test_forecast_without_make_future(weekly_period_df): model = DeepARModel(max_epochs=1) model.fit(ts) + ts.df.index = ts.df.index + pd.Timedelta(days=len(ts.df)) with pytest.raises(ValueError, match="The future is not generated!"): _ = model.forecast(ts=ts) From c549ec09d2d1192d3b8c2500f2d7ad007cae8012 Mon Sep 17 00:00:00 2001 From: martins0n <33594071+martins0n@users.noreply.github.com> Date: Fri, 22 Jul 2022 17:54:43 +0300 Subject: [PATCH 6/6] FIX: <= --- etna/models/nn/deepar.py | 2 +- etna/models/nn/tft.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/etna/models/nn/deepar.py b/etna/models/nn/deepar.py index 5930236ca..0f04380d7 100644 --- a/etna/models/nn/deepar.py +++ b/etna/models/nn/deepar.py @@ -190,7 +190,7 @@ def forecast( TSDataset TSDataset with predictions. """ - if ts.index[0] < self._last_train_timestamp: + if ts.index[0] <= self._last_train_timestamp: raise NotImplementedError( "It is not possible to make in-sample predictions with DeepAR model! " "In-sample predictions aren't supported by current implementation." diff --git a/etna/models/nn/tft.py b/etna/models/nn/tft.py index 9c626b7e1..b3ff52b09 100644 --- a/etna/models/nn/tft.py +++ b/etna/models/nn/tft.py @@ -197,7 +197,7 @@ def forecast( TSDataset TSDataset with predictions. """ - if ts.index[0] < self._last_train_timestamp: + if ts.index[0] <= self._last_train_timestamp: raise NotImplementedError( "It is not possible to make in-sample predictions with TFT model! " "In-sample predictions aren't supported by current implementation."