Skip to content

Commit

Permalink
Add native prediction intervals for TFTModel (#770)
Browse files Browse the repository at this point in the history
  • Loading branch information
Mr-Geekman authored Jun 24, 2022
1 parent e3445c9 commit 5487dc5
Show file tree
Hide file tree
Showing 5 changed files with 146 additions and 8 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
-
-
- Make native prediction intervals for DeepAR ([#761](https://github.com/tinkoff-ai/etna/pull/761))
-
- Make native prediction intervals for TFTModel ([#770](https://github.com/tinkoff-ai/etna/pull/770))
-
-
### Fixed
Expand Down
13 changes: 11 additions & 2 deletions etna/models/nn/deepar.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import pytorch_lightning as pl
from pytorch_forecasting.data import TimeSeriesDataSet
from pytorch_forecasting.metrics import DistributionLoss
from pytorch_forecasting.metrics import NormalDistributionLoss
from pytorch_forecasting.models import DeepAR
from pytorch_lightning import LightningModule

Expand All @@ -45,8 +46,9 @@ def __init__(
hidden_size: int = 10,
rnn_layers: int = 2,
dropout: float = 0.1,
loss: Optional[DistributionLoss] = None,
loss: Optional["DistributionLoss"] = None,
trainer_kwargs: Optional[Dict[str, Any]] = None,
quantiles_kwargs: Optional[Dict[str, Any]] = None,
):
"""
Initialize DeepAR wrapper.
Expand Down Expand Up @@ -79,7 +81,11 @@ def __init__(
Defaults to :py:class:`pytorch_forecasting.metrics.NormalDistributionLoss`.
trainer_kwargs:
Additional arguments for pytorch_lightning Trainer.
quantiles_kwargs:
Additional arguments for computing quantiles, look at ``to_quantiles()`` method for your loss.
"""
if loss is None:
loss = NormalDistributionLoss()
self.max_epochs = max_epochs
self.gpus = gpus
self.gradient_clip_val = gradient_clip_val
Expand All @@ -92,6 +98,7 @@ def __init__(
self.dropout = dropout
self.loss = loss
self.trainer_kwargs = trainer_kwargs if trainer_kwargs is not None else dict()
self.quantiles_kwargs = quantiles_kwargs if quantiles_kwargs is not None else dict()
self.model: Optional[Union[LightningModule, DeepAR]] = None
self.trainer: Optional[pl.Trainer] = None

Expand Down Expand Up @@ -194,7 +201,9 @@ def forecast(

if prediction_interval:
quantiles_predicts = self.model.predict( # type: ignore
prediction_dataloader, mode="quantiles", mode_kwargs={"quantiles": quantiles}
prediction_dataloader,
mode="quantiles",
mode_kwargs={"quantiles": quantiles, **self.quantiles_kwargs},
).numpy()
# shape (segments, encoder_length, len(quantiles))
quantiles_predicts = quantiles_predicts.transpose((1, 0, 2))
Expand Down
77 changes: 73 additions & 4 deletions etna/models/nn/tft.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import warnings
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Sequence
from typing import Union

import pandas as pd
Expand All @@ -10,18 +12,21 @@
from etna.datasets.tsdataset import TSDataset
from etna.loggers import tslogger
from etna.models.base import Model
from etna.models.base import PredictIntervalAbstractModel
from etna.models.base import log_decorator
from etna.models.nn.utils import _DeepCopyMixin
from etna.transforms import PytorchForecastingTransform

if SETTINGS.torch_required:
import pytorch_lightning as pl
from pytorch_forecasting.data import TimeSeriesDataSet
from pytorch_forecasting.metrics import MultiHorizonMetric
from pytorch_forecasting.metrics import QuantileLoss
from pytorch_forecasting.models import TemporalFusionTransformer
from pytorch_lightning import LightningModule


class TFTModel(Model, _DeepCopyMixin):
class TFTModel(Model, PredictIntervalAbstractModel, _DeepCopyMixin):
"""Wrapper for :py:class:`pytorch_forecasting.models.temporal_fusion_transformer.TemporalFusionTransformer`.
Notes
Expand All @@ -43,7 +48,9 @@ def __init__(
attention_head_size: int = 4,
dropout: float = 0.1,
hidden_continuous_size: int = 8,
loss: "MultiHorizonMetric" = None,
trainer_kwargs: Optional[Dict[str, Any]] = None,
quantiles_kwargs: Optional[Dict[str, Any]] = None,
*args,
**kwargs,
):
Expand Down Expand Up @@ -74,9 +81,16 @@ def __init__(
Dropout rate.
hidden_continuous_size:
Hidden size for processing continuous variables.
loss:
Loss function taking prediction and targets.
Defaults to :py:class:`pytorch_forecasting.metrics.QuantileLoss`.
trainer_kwargs:
Additional arguments for pytorch_lightning Trainer.
quantiles_kwargs:
Additional arguments for computing quantiles, look at ``to_quantiles()`` method for your loss.
"""
if loss is None:
loss = QuantileLoss()
self.max_epochs = max_epochs
self.gpus = gpus
self.gradient_clip_val = gradient_clip_val
Expand All @@ -89,7 +103,9 @@ def __init__(
self.attention_head_size = attention_head_size
self.dropout = dropout
self.hidden_continuous_size = hidden_continuous_size
self.loss = loss
self.trainer_kwargs = trainer_kwargs if trainer_kwargs is not None else dict()
self.quantiles_kwargs = quantiles_kwargs if quantiles_kwargs is not None else dict()
self.model: Optional[Union[LightningModule, TemporalFusionTransformer]] = None
self.trainer: Optional[pl.Trainer] = None

Expand All @@ -109,6 +125,7 @@ def _from_dataset(self, ts_dataset: TimeSeriesDataSet) -> LightningModule:
attention_head_size=self.attention_head_size,
dropout=self.dropout,
hidden_continuous_size=self.hidden_continuous_size,
loss=self.loss,
)

@staticmethod
Expand Down Expand Up @@ -156,14 +173,20 @@ def fit(self, ts: TSDataset) -> "TFTModel":
return self

@log_decorator
def forecast(self, ts: TSDataset) -> TSDataset:
def forecast(
self, ts: TSDataset, prediction_interval: bool = False, quantiles: Sequence[float] = (0.025, 0.975)
) -> TSDataset:
"""
Predict future.
Parameters
----------
ts:
TSDataset to forecast.
Dataset with features
prediction_interval:
If True returns prediction interval for forecast
quantiles:
Levels of prediction distribution. By default 2.5% and 97.5% are taken to form a 95% prediction interval
Returns
-------
Expand All @@ -181,7 +204,53 @@ def forecast(self, ts: TSDataset) -> TSDataset:

predicts = self.model.predict(prediction_dataloader).numpy() # type: ignore
# shape (segments, encoder_length)

ts.loc[:, pd.IndexSlice[:, "target"]] = predicts.T[-len(ts.df) :]

if prediction_interval:
if not isinstance(self.loss, QuantileLoss):
warnings.warn(
"Quantiles can't be computed because TFTModel supports this only if QunatileLoss is chosen"
)
else:
quantiles_predicts = self.model.predict( # type: ignore
prediction_dataloader,
mode="quantiles",
mode_kwargs={"quantiles": quantiles, **self.quantiles_kwargs},
).numpy()
# shape (segments, encoder_length, len(quantiles))

loss_quantiles = self.loss.quantiles
computed_quantiles_indices = []
computed_quantiles = []
not_computed_quantiles = []
for quantile in quantiles:
if quantile in loss_quantiles:
computed_quantiles.append(quantile)
computed_quantiles_indices.append(loss_quantiles.index(quantile))
else:
not_computed_quantiles.append(quantile)

if not_computed_quantiles:
warnings.warn(
f"Quantiles: {not_computed_quantiles} can't be computed because loss wasn't fitted on them"
)

quantiles_predicts = quantiles_predicts[:, :, computed_quantiles_indices]
quantiles = computed_quantiles

quantiles_predicts = quantiles_predicts.transpose((1, 0, 2))
# shape (encoder_length, segments, len(quantiles))
quantiles_predicts = quantiles_predicts.reshape(quantiles_predicts.shape[0], -1)
# shape (encoder_length, segments * len(quantiles))

df = ts.df
segments = ts.segments
quantile_columns = [f"target_{quantile:.4g}" for quantile in quantiles]
columns = pd.MultiIndex.from_product([segments, quantile_columns])
quantiles_df = pd.DataFrame(quantiles_predicts, columns=columns, index=df.index)
df = pd.concat((df, quantiles_df), axis=1)
df = df.sort_index(axis=1)
ts.df = df

ts.inverse_transform()
return ts
2 changes: 1 addition & 1 deletion tests/test_models/nn/test_deepar.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def test_prediction_interval_run_infuture(example_tsds):
example_tsds.fit_transform([transform])
model = DeepARModel(max_epochs=2, learning_rate=[0.01], gpus=0, batch_size=64)
model.fit(example_tsds)
future = example_tsds.make_future(10)
future = example_tsds.make_future(horizon)
forecast = model.forecast(future, prediction_interval=True, quantiles=[0.025, 0.975])
for segment in forecast.segments:
segment_slice = forecast[:, segment, :][segment]
Expand Down
60 changes: 60 additions & 0 deletions tests/test_models/nn/test_tft.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,3 +128,63 @@ def test_forecast_without_make_future(weekly_period_df):
model.fit(ts)
with pytest.raises(ValueError, match="The future is not generated!"):
_ = model.forecast(ts=ts)


def _get_default_transform(horizon: int):
    """Build the PytorchForecastingTransform shared by the TFT interval tests."""
    return PytorchForecastingTransform(
        min_encoder_length=21,
        max_encoder_length=21,
        max_prediction_length=horizon,
        time_varying_known_reals=["time_idx"],
        time_varying_unknown_reals=["target"],
        static_categoricals=["segment"],
        target_normalizer=None,
    )


def test_prediction_interval_run_infuture(example_tsds):
    """Check that TFTModel forecasts with prediction intervals produce ordered quantile columns."""
    horizon = 10
    example_tsds.fit_transform([_get_default_transform(horizon)])
    model = TFTModel(max_epochs=8, learning_rate=[0.1], gpus=0, batch_size=64)
    model.fit(example_tsds)
    future = example_tsds.make_future(horizon)
    forecast = model.forecast(future, prediction_interval=True, quantiles=[0.02, 0.98])
    for segment in forecast.segments:
        df_segment = forecast[:, segment, :][segment]
        assert {"target_0.02", "target_0.98", "target"}.issubset(df_segment.columns)
        lower = df_segment["target_0.02"]
        upper = df_segment["target_0.98"]
        # interval bounds must bracket the point forecast: lower <= target <= upper
        assert (upper - lower >= 0).all()
        assert (df_segment["target"] - lower >= 0).all()
        assert (upper - df_segment["target"] >= 0).all()


def test_prediction_interval_run_infuture_warning_not_found_quantiles(example_tsds):
    """Check that quantiles the loss wasn't fitted on are skipped with a warning.

    QuantileLoss is fitted on its default quantile grid; 0.4 is not in it, so the
    model should warn and omit the ``target_0.4`` column while still returning
    the computable quantiles.
    """
    horizon = 10
    transform = _get_default_transform(horizon)
    example_tsds.fit_transform([transform])
    model = TFTModel(max_epochs=2, learning_rate=[0.1], gpus=0, batch_size=64)
    model.fit(example_tsds)
    future = example_tsds.make_future(horizon)
    # Use a raw string for the regex: "\[" is an invalid escape sequence in a
    # plain string literal (DeprecationWarning, SyntaxWarning since Python 3.12).
    with pytest.warns(UserWarning, match=r"Quantiles: \[0.4\] can't be computed"):
        forecast = model.forecast(future, prediction_interval=True, quantiles=[0.02, 0.4, 0.98])
    for segment in forecast.segments:
        segment_slice = forecast[:, segment, :][segment]
        assert {"target_0.02", "target_0.98", "target"}.issubset(segment_slice.columns)
        assert {"target_0.4"}.isdisjoint(segment_slice.columns)


def test_prediction_interval_run_infuture_warning_loss(example_tsds):
    """Check that a non-quantile loss makes interval forecasting warn and add no quantile columns."""
    from pytorch_forecasting.metrics import MAE as MAEPF

    horizon = 10
    example_tsds.fit_transform([_get_default_transform(horizon)])
    model = TFTModel(max_epochs=2, learning_rate=[0.1], gpus=0, batch_size=64, loss=MAEPF())
    model.fit(example_tsds)
    future = example_tsds.make_future(horizon)
    with pytest.warns(UserWarning, match="Quantiles can't be computed"):
        forecast = model.forecast(future, prediction_interval=True, quantiles=[0.02, 0.98])
    for segment in forecast.segments:
        columns = forecast[:, segment, :][segment].columns
        # only the point forecast is present; no interval columns were added
        assert {"target"}.issubset(columns)
        assert {"target_0.02", "target_0.98"}.isdisjoint(columns)

1 comment on commit 5487dc5

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please sign in to comment.