Skip to content

Add predict method to pipelines #954

Merged
merged 4 commits into from
Sep 26, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `DeadlineMovingAverageModel` ([#827](https://github.com/tinkoff-ai/etna/pull/827))
- `DirectEnsemble` ([#824](https://github.com/tinkoff-ai/etna/pull/824))
-
-
- Add `predict` method to pipelines ([#954](https://github.com/tinkoff-ai/etna/pull/954))
- Implement predict method in `SARIMAXModel`, `AutoARIMAModel`, `SeasonalMovingAverageModel`, `DeadlineMovingAverageModel` ([#948](https://github.com/tinkoff-ai/etna/pull/948))
- Make `SeasonalMovingAverageModel` and `DeadlineMovingAverageModel` to work with context ([#917](https://github.com/tinkoff-ai/etna/pull/917))
-
Expand Down
129 changes: 129 additions & 0 deletions etna/pipeline/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,48 @@ def forecast(
"""
pass

@abstractmethod
def predict(
    self,
    start_timestamp: Optional[pd.Timestamp] = None,
    end_timestamp: Optional[pd.Timestamp] = None,
    prediction_interval: bool = False,
    quantiles: Sequence[float] = (0.025, 0.975),
) -> TSDataset:
    """Make in-sample predictions on the fitted dataset in a given range.

    If segments start at different timestamps, correct behaviour is only
    guaranteed for ``start_timestamp`` not earlier than the beginning of
    every segment.

    Parameters
    ----------
    start_timestamp:
        First timestamp of the returned prediction range; should be >= the first
        timestamp in ``self.ts``, and the beginning of each segment is expected
        to be <= ``start_timestamp``. Defaults to the first timestamp at which
        every segment has already started.
    end_timestamp:
        Last timestamp of the returned prediction range; expected to be <= the
        last timestamp of ``self.ts``. Defaults to the last timestamp of ``self.ts``.
    prediction_interval:
        If True, a prediction interval is returned alongside the forecast.
    quantiles:
        Levels of the prediction distribution; the defaults (2.5% and 97.5%)
        form a 95% prediction interval.

    Returns
    -------
    :
        Dataset with predictions in the ``[start_timestamp, end_timestamp]`` range.

    Raises
    ------
    ValueError:
        Value of ``end_timestamp`` is less than ``start_timestamp``.
    ValueError:
        Value of ``start_timestamp`` goes before the point where each segment started.
    ValueError:
        Value of ``end_timestamp`` goes after the last timestamp.
    """

@abstractmethod
def backtest(
self,
Expand Down Expand Up @@ -277,6 +319,93 @@ def forecast(
)
return predictions

def _predict(
    self,
    start_timestamp: pd.Timestamp,
    end_timestamp: pd.Timestamp,
    prediction_interval: bool,
    quantiles: Sequence[float],
) -> TSDataset:
    """Hook implementing the actual in-sample prediction; subclasses must override."""
    raise NotImplementedError()

def predict(
    self,
    start_timestamp: Optional[pd.Timestamp] = None,
    end_timestamp: Optional[pd.Timestamp] = None,
    prediction_interval: bool = False,
    quantiles: Sequence[float] = (0.025, 0.975),
) -> TSDataset:
    """Make in-sample predictions in a given range.

    Currently, in situation when segments start with different timestamps
    we only guarantee to work with ``start_timestamp`` >= beginning of all segments.

    Parameters
    ----------
    start_timestamp:
        First timestamp of prediction range to return, should be >= than first timestamp in ``self.ts``;
        expected that beginning of each segment <= ``start_timestamp``;
        if isn't set the first timestamp where each segment began is taken.
    end_timestamp:
        Last timestamp of prediction range to return; if isn't set the last timestamp of ``self.ts`` is taken.
        Expected that value is <= ``self.ts``.
    prediction_interval:
        If True returns prediction interval for forecast.
    quantiles:
        Levels of prediction distribution. By default 2.5% and 97.5% taken to form a 95% prediction interval.

    Returns
    -------
    :
        Dataset with predictions in ``[start_timestamp, end_timestamp]`` range.

    Raises
    ------
    ValueError:
        Pipeline wasn't fitted.
    ValueError:
        Value of ``end_timestamp`` is less than ``start_timestamp``.
    ValueError:
        Value of ``start_timestamp`` goes before point where each segment started.
    ValueError:
        Value of ``end_timestamp`` goes after the last timestamp.
    """
    # check presence of the dataset: ``fit`` is expected to set ``self.ts``
    if self.ts is None:
        raise ValueError(
            f"{self.__class__.__name__} is not fitted! Fit the {self.__class__.__name__} "
            f"before calling predict method."
        )

    # check timestamps: the earliest allowed start is the latest start among all segments,
    # the latest allowed end is the last timestamp of the dataset index
    min_timestamp = self.ts.describe()["start_timestamp"].max()
    max_timestamp = self.ts.index[-1]

    # fall back to the widest valid range when boundaries aren't given
    if start_timestamp is None:
        start_timestamp = min_timestamp
    if end_timestamp is None:
        end_timestamp = max_timestamp

    if start_timestamp < min_timestamp:
        raise ValueError("Value of start_timestamp is less than beginning of some segments!")
    if end_timestamp > max_timestamp:
        raise ValueError("Value of end_timestamp is more than ending of dataset!")

    if start_timestamp > end_timestamp:
        raise ValueError("Value of end_timestamp is less than start_timestamp!")

    # check quantiles (e.g. that levels lie in (0, 1)); raises on invalid values
    self._validate_quantiles(quantiles=quantiles)

    # make prediction: delegate the actual work to the subclass hook
    prediction = self._predict(
        start_timestamp=start_timestamp,
        end_timestamp=end_timestamp,
        prediction_interval=prediction_interval,
        quantiles=quantiles,
    )
    return prediction

def _init_backtest(self):
self._folds: Optional[Dict[int, Any]] = None
self._fold_column = "fold_number"
Expand Down
144 changes: 144 additions & 0 deletions tests/test_pipeline/test_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
from typing import Sequence
from unittest.mock import MagicMock

import pandas as pd
import pytest

from etna.datasets import TSDataset
from etna.datasets import generate_ar_df
from etna.pipeline.base import BasePipeline


class DummyPipeline(BasePipeline):
    """Minimal concrete ``BasePipeline`` used to exercise the base ``predict`` logic."""

    def __init__(self, horizon: int):
        super().__init__(horizon=horizon)

    def fit(self, ts: TSDataset):
        # No real training: storing the dataset is enough for predict() checks.
        self.ts = ts
        return self

    def _forecast(self):
        # Forecasting isn't under test here; satisfy the abstract interface.
        return None

    def _predict(
        self,
        start_timestamp: pd.Timestamp,
        end_timestamp: pd.Timestamp,
        prediction_interval: bool,
        quantiles: Sequence[float],
    ) -> TSDataset:
        # Echo the fitted dataset back; tests usually replace this with a mock
        # and only inspect the call arguments.
        return self.ts


@pytest.mark.parametrize("quantiles", [(0.025,), (0.975,), (0.025, 0.975)])
@pytest.mark.parametrize("prediction_interval", [False, True])
@pytest.mark.parametrize(
    "start_timestamp, end_timestamp",
    [
        (pd.Timestamp("2020-01-01"), pd.Timestamp("2020-01-05")),
        (pd.Timestamp("2020-01-10"), pd.Timestamp("2020-01-15")),
    ],
)
@pytest.mark.parametrize(
    "ts", [TSDataset(df=TSDataset.to_dataset(generate_ar_df(start_time="2020-01-01", periods=50)), freq="D")]
)
def test_predict_pass_params(ts, start_timestamp, end_timestamp, prediction_interval, quantiles):
    """Check that ``predict`` forwards all its arguments to ``_predict`` unchanged."""
    pipeline = DummyPipeline(horizon=5).fit(ts)
    predict_mock = MagicMock()
    pipeline._predict = predict_mock

    pipeline.predict(
        start_timestamp=start_timestamp,
        end_timestamp=end_timestamp,
        prediction_interval=prediction_interval,
        quantiles=quantiles,
    )

    predict_mock.assert_called_once_with(
        start_timestamp=start_timestamp,
        end_timestamp=end_timestamp,
        prediction_interval=prediction_interval,
        quantiles=quantiles,
    )


def test_predict_fail_not_fitted():
    """Calling ``predict`` before ``fit`` must raise an informative ValueError."""
    unfitted_pipeline = DummyPipeline(horizon=5)
    with pytest.raises(ValueError, match="DummyPipeline is not fitted"):
        unfitted_pipeline.predict()


@pytest.mark.parametrize("ts_name", ["example_tsds", "ts_with_different_series_length"])
def test_predict_use_ts_timestamps(ts_name, request):
    """Check that omitted boundaries default to the dataset's own timestamp range."""
    dataset = request.getfixturevalue(ts_name)
    pipeline = DummyPipeline(horizon=5).fit(dataset)
    predict_mock = MagicMock()
    pipeline._predict = predict_mock

    pipeline.predict()

    predict_mock.assert_called_once_with(
        start_timestamp=dataset.describe()["start_timestamp"].max(),
        end_timestamp=dataset.index.max(),
        prediction_interval=False,
        quantiles=(0.025, 0.975),
    )


def test_predict_fail_early_start(example_tsds):
    """A start before the beginning of any segment must raise."""
    pipeline = DummyPipeline(horizon=5).fit(example_tsds)
    too_early = example_tsds.index[0] - pd.DateOffset(days=5)

    with pytest.raises(ValueError, match="Value of start_timestamp is less than beginning of some segments"):
        pipeline.predict(start_timestamp=too_early)


def test_predict_fail_late_end(example_tsds):
    """An end after the last timestamp of the dataset must raise."""
    pipeline = DummyPipeline(horizon=5).fit(example_tsds)
    too_late = example_tsds.index[-1] + pd.DateOffset(days=5)

    with pytest.raises(ValueError, match="Value of end_timestamp is more than ending of dataset"):
        pipeline.predict(end_timestamp=too_late)


def test_predict_fail_start_later_than_end(example_tsds):
    """A range whose end precedes its start must raise."""
    pipeline = DummyPipeline(horizon=5).fit(example_tsds)

    with pytest.raises(ValueError, match="Value of end_timestamp is less than start_timestamp"):
        pipeline.predict(start_timestamp=example_tsds.index[2], end_timestamp=example_tsds.index[0])


@pytest.mark.parametrize("quantiles", [(0.025,), (0.975,), (0.025, 0.975)])
def test_predict_validate_quantiles(quantiles, example_tsds):
    """Check that ``predict`` delegates quantile checking to ``_validate_quantiles``."""
    validate_mock = MagicMock()
    pipeline = DummyPipeline(horizon=5).fit(example_tsds)
    pipeline._validate_quantiles = validate_mock

    pipeline.predict(prediction_interval=True, quantiles=quantiles)

    validate_mock.assert_called_once_with(quantiles=quantiles)


def test_predict_return_private_predict(example_tsds):
    """``predict`` must return exactly the object produced by ``_predict``."""
    pipeline = DummyPipeline(horizon=5)
    mock = MagicMock()
    pipeline._predict = mock
    expected_result = mock.return_value

    pipeline.fit(example_tsds)
    returned_result = pipeline.predict()

    # Identity check instead of `==`: the contract is pass-through of the very
    # same object, and relying on MagicMock's __eq__ makes the assertion weaker
    # than intended.
    assert returned_result is expected_result