From 5487dc5e3eab2c1ee87044e2639fea38f30a6627 Mon Sep 17 00:00:00 2001 From: Mr-Geekman <36005824+Mr-Geekman@users.noreply.github.com> Date: Fri, 24 Jun 2022 19:58:53 +0300 Subject: [PATCH] Add native prediction intervals for `TFTModel` (#770) --- CHANGELOG.md | 2 +- etna/models/nn/deepar.py | 13 ++++- etna/models/nn/tft.py | 77 +++++++++++++++++++++++++++-- tests/test_models/nn/test_deepar.py | 2 +- tests/test_models/nn/test_tft.py | 60 ++++++++++++++++++++++ 5 files changed, 146 insertions(+), 8 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ec5931814..bbf8edff8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -32,7 +32,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - - - Make native prediction intervals for DeepAR ([#761](https://github.com/tinkoff-ai/etna/pull/761)) -- +- Make native prediction intervals for TFTModel ([#770](https://github.com/tinkoff-ai/etna/pull/770)) - - ### Fixed diff --git a/etna/models/nn/deepar.py b/etna/models/nn/deepar.py index e2276b2aa..67f0e47ea 100644 --- a/etna/models/nn/deepar.py +++ b/etna/models/nn/deepar.py @@ -20,6 +20,7 @@ import pytorch_lightning as pl from pytorch_forecasting.data import TimeSeriesDataSet from pytorch_forecasting.metrics import DistributionLoss + from pytorch_forecasting.metrics import NormalDistributionLoss from pytorch_forecasting.models import DeepAR from pytorch_lightning import LightningModule @@ -45,8 +46,9 @@ def __init__( hidden_size: int = 10, rnn_layers: int = 2, dropout: float = 0.1, - loss: Optional[DistributionLoss] = None, + loss: Optional["DistributionLoss"] = None, trainer_kwargs: Optional[Dict[str, Any]] = None, + quantiles_kwargs: Optional[Dict[str, Any]] = None, ): """ Initialize DeepAR wrapper. @@ -79,7 +81,11 @@ def __init__( Defaults to :py:class:`pytorch_forecasting.metrics.NormalDistributionLoss`. trainer_kwargs: Additional arguments for pytorch_lightning Trainer. 
+ quantiles_kwargs: + Additional arguments for computing quantiles, look at ``to_quantiles()`` method for your loss. """ + if loss is None: + loss = NormalDistributionLoss() self.max_epochs = max_epochs self.gpus = gpus self.gradient_clip_val = gradient_clip_val @@ -92,6 +98,7 @@ def __init__( self.dropout = dropout self.loss = loss self.trainer_kwargs = trainer_kwargs if trainer_kwargs is not None else dict() + self.quantiles_kwargs = quantiles_kwargs if quantiles_kwargs is not None else dict() self.model: Optional[Union[LightningModule, DeepAR]] = None self.trainer: Optional[pl.Trainer] = None @@ -194,7 +201,9 @@ def forecast( if prediction_interval: quantiles_predicts = self.model.predict( # type: ignore - prediction_dataloader, mode="quantiles", mode_kwargs={"quantiles": quantiles} + prediction_dataloader, + mode="quantiles", + mode_kwargs={"quantiles": quantiles, **self.quantiles_kwargs}, ).numpy() # shape (segments, encoder_length, len(quantiles)) quantiles_predicts = quantiles_predicts.transpose((1, 0, 2)) diff --git a/etna/models/nn/tft.py b/etna/models/nn/tft.py index 283e98c5e..0f82ea4fb 100644 --- a/etna/models/nn/tft.py +++ b/etna/models/nn/tft.py @@ -1,7 +1,9 @@ +import warnings from typing import Any from typing import Dict from typing import List from typing import Optional +from typing import Sequence from typing import Union import pandas as pd @@ -10,6 +12,7 @@ from etna.datasets.tsdataset import TSDataset from etna.loggers import tslogger from etna.models.base import Model +from etna.models.base import PredictIntervalAbstractModel from etna.models.base import log_decorator from etna.models.nn.utils import _DeepCopyMixin from etna.transforms import PytorchForecastingTransform @@ -17,11 +20,13 @@ if SETTINGS.torch_required: import pytorch_lightning as pl from pytorch_forecasting.data import TimeSeriesDataSet + from pytorch_forecasting.metrics import MultiHorizonMetric + from pytorch_forecasting.metrics import QuantileLoss from 
pytorch_forecasting.models import TemporalFusionTransformer from pytorch_lightning import LightningModule -class TFTModel(Model, _DeepCopyMixin): +class TFTModel(Model, PredictIntervalAbstractModel, _DeepCopyMixin): """Wrapper for :py:class:`pytorch_forecasting.models.temporal_fusion_transformer.TemporalFusionTransformer`. Notes @@ -43,7 +48,9 @@ def __init__( attention_head_size: int = 4, dropout: float = 0.1, hidden_continuous_size: int = 8, + loss: "MultiHorizonMetric" = None, trainer_kwargs: Optional[Dict[str, Any]] = None, + quantiles_kwargs: Optional[Dict[str, Any]] = None, *args, **kwargs, ): @@ -74,9 +81,16 @@ def __init__( Dropout rate. hidden_continuous_size: Hidden size for processing continuous variables. + loss: + Loss function taking prediction and targets. + Defaults to :py:class:`pytorch_forecasting.metrics.QuantileLoss`. trainer_kwargs: Additional arguments for pytorch_lightning Trainer. + quantiles_kwargs: + Additional arguments for computing quantiles, look at ``to_quantiles()`` method for your loss. 
""" + if loss is None: + loss = QuantileLoss() self.max_epochs = max_epochs self.gpus = gpus self.gradient_clip_val = gradient_clip_val @@ -89,7 +103,9 @@ def __init__( self.attention_head_size = attention_head_size self.dropout = dropout self.hidden_continuous_size = hidden_continuous_size + self.loss = loss self.trainer_kwargs = trainer_kwargs if trainer_kwargs is not None else dict() + self.quantiles_kwargs = quantiles_kwargs if quantiles_kwargs is not None else dict() self.model: Optional[Union[LightningModule, TemporalFusionTransformer]] = None self.trainer: Optional[pl.Trainer] = None @@ -109,6 +125,7 @@ def _from_dataset(self, ts_dataset: TimeSeriesDataSet) -> LightningModule: attention_head_size=self.attention_head_size, dropout=self.dropout, hidden_continuous_size=self.hidden_continuous_size, + loss=self.loss, ) @staticmethod @@ -156,14 +173,20 @@ def fit(self, ts: TSDataset) -> "TFTModel": return self @log_decorator - def forecast(self, ts: TSDataset) -> TSDataset: + def forecast( + self, ts: TSDataset, prediction_interval: bool = False, quantiles: Sequence[float] = (0.025, 0.975) + ) -> TSDataset: """ Predict future. Parameters ---------- ts: - TSDataset to forecast. + Dataset with features + prediction_interval: + If True returns prediction interval for forecast + quantiles: + Levels of prediction distribution. 
By default 2.5% and 97.5% are taken to form a 95% prediction interval Returns ------- TSDataset TSDataset with predictions. """ if self.model is None: raise ValueError("The model is not fitted yet, use fit() method to fit the model") pf_transform = self._get_pf_transform(ts) if pf_transform.pf_dataset_predict is None: raise ValueError( "The future is not generated! Generate future using TSDataset make_future before calling forecast method!" ) prediction_dataloader = pf_transform.pf_dataset_predict.to_dataloader(train=False, batch_size=len(ts.segments)) predicts = self.model.predict(prediction_dataloader).numpy() # type: ignore # shape (segments, encoder_length) - ts.loc[:, pd.IndexSlice[:, "target"]] = predicts.T[-len(ts.df) :] + + if prediction_interval: + if not isinstance(self.loss, QuantileLoss): + warnings.warn( + "Quantiles can't be computed because TFTModel supports this only if QuantileLoss is chosen" + ) + else: + quantiles_predicts = self.model.predict( # type: ignore + prediction_dataloader, + mode="quantiles", + mode_kwargs={"quantiles": quantiles, **self.quantiles_kwargs}, + ).numpy() + # shape (segments, encoder_length, len(quantiles)) + + loss_quantiles = self.loss.quantiles + computed_quantiles_indices = [] + computed_quantiles = [] + not_computed_quantiles = [] + for quantile in quantiles: + if quantile in loss_quantiles: + computed_quantiles.append(quantile) + computed_quantiles_indices.append(loss_quantiles.index(quantile)) + else: + not_computed_quantiles.append(quantile) + + if not_computed_quantiles: + warnings.warn( + f"Quantiles: {not_computed_quantiles} can't be computed because loss wasn't fitted on them" + ) + + quantiles_predicts = quantiles_predicts[:, :, computed_quantiles_indices] + quantiles = computed_quantiles + + quantiles_predicts = quantiles_predicts.transpose((1, 0, 2)) + # shape (encoder_length, segments, len(quantiles)) + quantiles_predicts = quantiles_predicts.reshape(quantiles_predicts.shape[0], -1) + # shape (encoder_length, segments * len(quantiles)) + + df = ts.df + segments = ts.segments + quantile_columns = [f"target_{quantile:.4g}" for quantile in quantiles] + columns = pd.MultiIndex.from_product([segments, quantile_columns]) + quantiles_df = pd.DataFrame(quantiles_predicts, columns=columns, index=df.index) + df = pd.concat((df, quantiles_df), axis=1) + df = df.sort_index(axis=1) + ts.df = df + 
ts.inverse_transform() return ts diff --git a/tests/test_models/nn/test_deepar.py b/tests/test_models/nn/test_deepar.py index 6906f8235..9a0ef1588 100644 --- a/tests/test_models/nn/test_deepar.py +++ b/tests/test_models/nn/test_deepar.py @@ -166,7 +166,7 @@ def test_prediction_interval_run_infuture(example_tsds): example_tsds.fit_transform([transform]) model = DeepARModel(max_epochs=2, learning_rate=[0.01], gpus=0, batch_size=64) model.fit(example_tsds) - future = example_tsds.make_future(10) + future = example_tsds.make_future(horizon) forecast = model.forecast(future, prediction_interval=True, quantiles=[0.025, 0.975]) for segment in forecast.segments: segment_slice = forecast[:, segment, :][segment] diff --git a/tests/test_models/nn/test_tft.py b/tests/test_models/nn/test_tft.py index 14c8f309a..d1949b987 100644 --- a/tests/test_models/nn/test_tft.py +++ b/tests/test_models/nn/test_tft.py @@ -128,3 +128,63 @@ def test_forecast_without_make_future(weekly_period_df): model.fit(ts) with pytest.raises(ValueError, match="The future is not generated!"): _ = model.forecast(ts=ts) + + +def _get_default_transform(horizon: int): + return PytorchForecastingTransform( + max_encoder_length=21, + min_encoder_length=21, + max_prediction_length=horizon, + time_varying_known_reals=["time_idx"], + time_varying_unknown_reals=["target"], + static_categoricals=["segment"], + target_normalizer=None, + ) + + +def test_prediction_interval_run_infuture(example_tsds): + horizon = 10 + transform = _get_default_transform(horizon) + example_tsds.fit_transform([transform]) + model = TFTModel(max_epochs=8, learning_rate=[0.1], gpus=0, batch_size=64) + model.fit(example_tsds) + future = example_tsds.make_future(horizon) + forecast = model.forecast(future, prediction_interval=True, quantiles=[0.02, 0.98]) + for segment in forecast.segments: + segment_slice = forecast[:, segment, :][segment] + assert {"target_0.02", "target_0.98", "target"}.issubset(segment_slice.columns) + assert 
(segment_slice["target_0.98"] - segment_slice["target_0.02"] >= 0).all() + assert (segment_slice["target"] - segment_slice["target_0.02"] >= 0).all() + assert (segment_slice["target_0.98"] - segment_slice["target"] >= 0).all() + + +def test_prediction_interval_run_infuture_warning_not_found_quantiles(example_tsds): + horizon = 10 + transform = _get_default_transform(horizon) + example_tsds.fit_transform([transform]) + model = TFTModel(max_epochs=2, learning_rate=[0.1], gpus=0, batch_size=64) + model.fit(example_tsds) + future = example_tsds.make_future(horizon) + with pytest.warns(UserWarning, match="Quantiles: \[0.4\] can't be computed"): + forecast = model.forecast(future, prediction_interval=True, quantiles=[0.02, 0.4, 0.98]) + for segment in forecast.segments: + segment_slice = forecast[:, segment, :][segment] + assert {"target_0.02", "target_0.98", "target"}.issubset(segment_slice.columns) + assert {"target_0.4"}.isdisjoint(segment_slice.columns) + + +def test_prediction_interval_run_infuture_warning_loss(example_tsds): + from pytorch_forecasting.metrics import MAE as MAEPF + + horizon = 10 + transform = _get_default_transform(horizon) + example_tsds.fit_transform([transform]) + model = TFTModel(max_epochs=2, learning_rate=[0.1], gpus=0, batch_size=64, loss=MAEPF()) + model.fit(example_tsds) + future = example_tsds.make_future(horizon) + with pytest.warns(UserWarning, match="Quantiles can't be computed"): + forecast = model.forecast(future, prediction_interval=True, quantiles=[0.02, 0.98]) + for segment in forecast.segments: + segment_slice = forecast[:, segment, :][segment] + assert {"target"}.issubset(segment_slice.columns) + assert {"target_0.02", "target_0.98"}.isdisjoint(segment_slice.columns)