
Add predict method to pipelines #954

Merged
merged 4 commits on Sep 26, 2022
Changes from 3 commits
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -14,7 +14,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `DeadlineMovingAverageModel` ([#827](https://github.com/tinkoff-ai/etna/pull/827))
- `DirectEnsemble` ([#824](https://github.com/tinkoff-ai/etna/pull/824))
-
-
- Add `predict` method to pipelines ([#954](https://github.com/tinkoff-ai/etna/pull/954))
- Implement predict method in `SARIMAXModel`, `AutoARIMAModel`, `SeasonalMovingAverageModel`, `DeadlineMovingAverageModel` ([#948](https://github.com/tinkoff-ai/etna/pull/948))
- Make `SeasonalMovingAverageModel` and `DeadlineMovingAverageModel` to work with context ([#917](https://github.com/tinkoff-ai/etna/pull/917))
-
129 changes: 129 additions & 0 deletions etna/pipeline/base.py
@@ -150,6 +150,48 @@ def forecast(
"""
pass

    @abstractmethod
    def predict(
        self,
        start_timestamp: Optional[pd.Timestamp] = None,
        end_timestamp: Optional[pd.Timestamp] = None,
        prediction_interval: bool = False,
        quantiles: Sequence[float] = (0.025, 0.975),
    ) -> TSDataset:
        """Make in-sample predictions in a given range.

        Currently, in the situation when segments start with different timestamps,
        we only guarantee correct behavior for ``start_timestamp`` >= the beginning of all segments.

        Parameters
        ----------
        start_timestamp:
            First timestamp of the prediction range to return; should be >= the first timestamp of ``self.ts``;
            it is expected that the beginning of each segment is <= ``start_timestamp``;
            if not set, the first timestamp at which all segments are present is taken.
        end_timestamp:
            Last timestamp of the prediction range to return; should be <= the last timestamp of ``self.ts``;
            if not set, the last timestamp of ``self.ts`` is taken.
        prediction_interval:
            If True, prediction intervals are returned along with the predictions.
        quantiles:
            Levels of the prediction distribution. By default, the 2.5% and 97.5% quantiles are taken
            to form a 95% prediction interval.

        Returns
        -------
        :
            Dataset with predictions in the ``[start_timestamp, end_timestamp]`` range.

        Raises
        ------
        ValueError:
            Value of ``end_timestamp`` is less than ``start_timestamp``.
        ValueError:
            Value of ``start_timestamp`` goes before the point where each segment started.
        ValueError:
            Value of ``end_timestamp`` goes after the last timestamp of the dataset.
        """
        pass

    @abstractmethod
    def backtest(
        self,
@@ -277,6 +319,93 @@ def forecast(
        )
        return predictions

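    # Note: ``predict`` below validates its arguments and then delegates to ``_predict``;
    # concrete pipelines are expected to override ``_predict`` with the actual in-sample prediction logic.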
    def _predict(
        self,
        start_timestamp: pd.Timestamp,
        end_timestamp: pd.Timestamp,
        prediction_interval: bool,
        quantiles: Sequence[float],
    ) -> TSDataset:
        raise NotImplementedError()

    def predict(
        self,
        start_timestamp: Optional[pd.Timestamp] = None,
        end_timestamp: Optional[pd.Timestamp] = None,
        prediction_interval: bool = False,
        quantiles: Sequence[float] = (0.025, 0.975),
    ) -> TSDataset:
        """Make in-sample predictions in a given range.

        Currently, in the situation when segments start with different timestamps,
        we only guarantee correct behavior for ``start_timestamp`` >= the beginning of all segments.

        Parameters
        ----------
        start_timestamp:
            First timestamp of the prediction range to return; should be >= the first timestamp of ``self.ts``;
            it is expected that the beginning of each segment is <= ``start_timestamp``;
            if not set, the first timestamp at which all segments are present is taken.
        end_timestamp:
            Last timestamp of the prediction range to return; should be <= the last timestamp of ``self.ts``;
            if not set, the last timestamp of ``self.ts`` is taken.
        prediction_interval:
            If True, prediction intervals are returned along with the predictions.
        quantiles:
            Levels of the prediction distribution. By default, the 2.5% and 97.5% quantiles are taken
            to form a 95% prediction interval.

        Returns
        -------
        :
            Dataset with predictions in the ``[start_timestamp, end_timestamp]`` range.

        Raises
        ------
        ValueError:
            Pipeline wasn't fitted.
        ValueError:
            Value of ``end_timestamp`` is less than ``start_timestamp``.
        ValueError:
            Value of ``start_timestamp`` goes before the point where each segment started.
        ValueError:
            Value of ``end_timestamp`` goes after the last timestamp of the dataset.
        """
        # check that the pipeline is fitted and the dataset is present
        if self.ts is None:
            raise ValueError(
                f"{self.__class__.__name__} is not fitted! Fit the {self.__class__.__name__} "
                f"before calling predict method."
            )

        # check timestamps: predictions can start no earlier than the latest segment beginning
        min_timestamp = self.ts.describe()["start_timestamp"].max()
        max_timestamp = self.ts.index[-1]

        if start_timestamp is None:
            start_timestamp = min_timestamp
        if end_timestamp is None:
            end_timestamp = max_timestamp

        if start_timestamp < min_timestamp:
            raise ValueError("Value of start_timestamp is less than beginning of some segments!")
        if end_timestamp > max_timestamp:
            raise ValueError("Value of end_timestamp is more than ending of dataset!")

        if start_timestamp > end_timestamp:
            raise ValueError("Value of end_timestamp is less than start_timestamp!")

        # check quantiles
        self._validate_quantiles(quantiles=quantiles)

        # make prediction
        prediction = self._predict(
            start_timestamp=start_timestamp,
            end_timestamp=end_timestamp,
            prediction_interval=prediction_interval,
            quantiles=quantiles,
        )
        return prediction

    def _init_backtest(self):
        self._folds: Optional[Dict[int, Any]] = None
        self._fold_column = "fold_number"
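For context, a minimal sketch of the intended call pattern (illustrative only: the model choice is arbitrary, interval support depends on the model, and ``BasePipeline._predict`` above is still a stub, so this assumes a pipeline class that implements ``_predict``):

import pandas as pd

from etna.datasets import TSDataset
from etna.datasets import generate_ar_df
from etna.models import NaiveModel
from etna.pipeline import Pipeline

# build a small daily dataset and fit a pipeline on it
df = TSDataset.to_dataset(generate_ar_df(start_time="2020-01-01", periods=50))
ts = TSDataset(df=df, freq="D")
pipeline = Pipeline(model=NaiveModel(), horizon=5)
pipeline.fit(ts)

# request in-sample predictions for a sub-range of the training period,
# with a 95% prediction interval formed from the default quantiles
predictions = pipeline.predict(
    start_timestamp=pd.Timestamp("2020-01-10"),
    end_timestamp=pd.Timestamp("2020-01-20"),
    prediction_interval=True,
    quantiles=(0.025, 0.975),
)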
152 changes: 152 additions & 0 deletions tests/test_pipeline/test_base.py
@@ -0,0 +1,152 @@
from typing import Sequence
from unittest.mock import MagicMock

import numpy as np
import pandas as pd
import pytest

from etna.datasets import TSDataset
from etna.datasets import generate_ar_df
from etna.pipeline.base import BasePipeline


@pytest.fixture
def ts_with_different_beginnings(example_tsds):
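    # make segments start at different timestamps: blank out the first 5 points of the first segment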
    df = example_tsds.to_pandas()
    df.iloc[:5, 0] = np.NaN
    return TSDataset(df=df, freq="D")


class DummyPipeline(BasePipeline):
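    """Minimal pipeline stub for testing: ``fit`` stores the dataset and ``_predict`` returns it unchanged."""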
    def __init__(self, horizon: int):
        super().__init__(horizon=horizon)

    def fit(self, ts: TSDataset):
        self.ts = ts
        return self

    def _forecast(self):
        return None

    def _predict(
        self,
        start_timestamp: pd.Timestamp,
        end_timestamp: pd.Timestamp,
        prediction_interval: bool,
        quantiles: Sequence[float],
    ) -> TSDataset:
        return self.ts


@pytest.mark.parametrize("quantiles", [(0.025,), (0.975,), (0.025, 0.975)])
@pytest.mark.parametrize("prediction_interval", [False, True])
@pytest.mark.parametrize(
    "start_timestamp, end_timestamp",
    [
        (pd.Timestamp("2020-01-01"), pd.Timestamp("2020-01-05")),
        (pd.Timestamp("2020-01-10"), pd.Timestamp("2020-01-15")),
    ],
)
@pytest.mark.parametrize(
    "ts", [TSDataset(df=TSDataset.to_dataset(generate_ar_df(start_time="2020-01-01", periods=50)), freq="D")]
)
def test_predict_pass_params(ts, start_timestamp, end_timestamp, prediction_interval, quantiles):
    pipeline = DummyPipeline(horizon=5)
    mock = MagicMock()
    pipeline._predict = mock

    pipeline.fit(ts)
    _ = pipeline.predict(
        start_timestamp=start_timestamp,
        end_timestamp=end_timestamp,
        prediction_interval=prediction_interval,
        quantiles=quantiles,
    )

    mock.assert_called_once_with(
        start_timestamp=start_timestamp,
        end_timestamp=end_timestamp,
        prediction_interval=prediction_interval,
        quantiles=quantiles,
    )


def test_predict_fail_not_fitted():
    pipeline = DummyPipeline(horizon=5)
    with pytest.raises(ValueError, match="DummyPipeline is not fitted"):
        _ = pipeline.predict()


@pytest.mark.parametrize("ts_name", ["example_tsds", "ts_with_different_beginnings"])
def test_predict_use_ts_timestamps(ts_name, request):
    ts = request.getfixturevalue(ts_name)
    pipeline = DummyPipeline(horizon=5)
    mock = MagicMock()
    pipeline._predict = mock

    pipeline.fit(ts)
    _ = pipeline.predict()

    expected_start_timestamp = ts.describe()["start_timestamp"].max()
    expected_end_timestamp = ts.index.max()

    mock.assert_called_once_with(
        start_timestamp=expected_start_timestamp,
        end_timestamp=expected_end_timestamp,
        prediction_interval=False,
        quantiles=(0.025, 0.975),
    )


def test_predict_fail_early_start(example_tsds):
    pipeline = DummyPipeline(horizon=5)
    pipeline.fit(example_tsds)
    start_timestamp = example_tsds.index[0] - pd.DateOffset(days=5)

    with pytest.raises(ValueError, match="Value of start_timestamp is less than beginning of some segments"):
        _ = pipeline.predict(start_timestamp=start_timestamp)


def test_predict_fail_late_end(example_tsds):
    pipeline = DummyPipeline(horizon=5)

    pipeline.fit(example_tsds)
    end_timestamp = example_tsds.index[-1] + pd.DateOffset(days=5)

    with pytest.raises(ValueError, match="Value of end_timestamp is more than ending of dataset"):
        _ = pipeline.predict(end_timestamp=end_timestamp)


def test_predict_fail_start_later_than_end(example_tsds):
    pipeline = DummyPipeline(horizon=5)

    pipeline.fit(example_tsds)
    start_timestamp = example_tsds.index[2]
    end_timestamp = example_tsds.index[0]

    with pytest.raises(ValueError, match="Value of end_timestamp is less than start_timestamp"):
        _ = pipeline.predict(start_timestamp=start_timestamp, end_timestamp=end_timestamp)


@pytest.mark.parametrize("quantiles", [(0.025,), (0.975,), (0.025, 0.975)])
def test_predict_validate_quantiles(quantiles, example_tsds):
    pipeline = DummyPipeline(horizon=5)
    mock = MagicMock()
    pipeline._validate_quantiles = mock

    pipeline.fit(example_tsds)
    _ = pipeline.predict(prediction_interval=True, quantiles=quantiles)

    mock.assert_called_once_with(quantiles=quantiles)


def test_predict_return_private_predict(example_tsds):
    pipeline = DummyPipeline(horizon=5)
    mock = MagicMock()
    pipeline._predict = mock
    expected_result = mock.return_value

    pipeline.fit(example_tsds)
    returned_result = pipeline.predict()

    assert returned_result == expected_result