# NOTE(review): this span of the source is a whitespace-mangled unified diff
# (CHANGELOG hunk + additions to ``etna/pipeline/base.py``).  Reconstructed
# below, in valid Python, is the new in-sample prediction machinery the patch
# adds to the pipeline base class.  Each function is a method of the pipeline
# class (``self`` is the pipeline instance); the abstract-interface declaration
# of ``predict`` carried the same signature and docstring and is not repeated
# here to avoid a duplicate definition at one scope.


def _predict(
    self,
    start_timestamp: pd.Timestamp,
    end_timestamp: pd.Timestamp,
    prediction_interval: bool,
    quantiles: Sequence[float],
) -> TSDataset:
    # Subclass hook: concrete pipelines implement the actual in-sample
    # prediction here; the base class only validates inputs and delegates.
    raise NotImplementedError()


def predict(
    self,
    start_timestamp: Optional[pd.Timestamp] = None,
    end_timestamp: Optional[pd.Timestamp] = None,
    prediction_interval: bool = False,
    quantiles: Sequence[float] = (0.025, 0.975),
) -> TSDataset:
    """Make in-sample predictions in a given range.

    Currently, in situation when segments start with different timestamps
    we only guarantee to work with ``start_timestamp`` >= beginning of all segments.

    Parameters
    ----------
    start_timestamp:
        First timestamp of prediction range to return, should be >= than first timestamp in ``self.ts``;
        expected that beginning of each segment <= ``start_timestamp``;
        if isn't set the first timestamp where each segment began is taken.
    end_timestamp:
        Last timestamp of prediction range to return; if isn't set the last timestamp of ``self.ts`` is taken.
        Expected that value is <= the last timestamp of ``self.ts``.
    prediction_interval:
        If True returns prediction interval for forecast.
    quantiles:
        Levels of prediction distribution. By default 2.5% and 97.5% taken to form a 95% prediction interval.

    Returns
    -------
    :
        Dataset with predictions in ``[start_timestamp, end_timestamp]`` range.

    Raises
    ------
    ValueError:
        Pipeline wasn't fitted.
    ValueError:
        Value of ``end_timestamp`` is less than ``start_timestamp``.
    ValueError:
        Value of ``start_timestamp`` goes before point where each segment started.
    ValueError:
        Value of ``end_timestamp`` goes after the last timestamp.
    """
    # ``fit`` stores the training dataset on ``self.ts``; without it there is
    # nothing to make in-sample predictions on.
    if self.ts is None:
        raise ValueError(
            f"{self.__class__.__name__} is not fitted! Fit the {self.__class__.__name__} "
            f"before calling predict method."
        )

    # Segments may begin at different points; the guaranteed prediction range
    # starts at the latest segment beginning and ends at the dataset's last
    # timestamp.  (``describe()`` is assumed to expose a ``start_timestamp``
    # column per segment — project API, confirm against TSDataset.)
    min_timestamp = self.ts.describe()["start_timestamp"].max()
    max_timestamp = self.ts.index[-1]

    # Fill in open ends of the range with the widest valid bounds.
    if start_timestamp is None:
        start_timestamp = min_timestamp
    if end_timestamp is None:
        end_timestamp = max_timestamp

    if start_timestamp < min_timestamp:
        raise ValueError("Value of start_timestamp is less than beginning of some segments!")
    if end_timestamp > max_timestamp:
        raise ValueError("Value of end_timestamp is more than ending of dataset!")

    if start_timestamp > end_timestamp:
        raise ValueError("Value of end_timestamp is less than start_timestamp!")

    # Validate quantile levels regardless of ``prediction_interval`` so bad
    # arguments fail fast.
    self._validate_quantiles(quantiles=quantiles)

    # Delegate the actual computation to the subclass hook.
    return self._predict(
        start_timestamp=start_timestamp,
        end_timestamp=end_timestamp,
        prediction_interval=prediction_interval,
        quantiles=quantiles,
    )


def _init_backtest(self):
    # Reset lazily-populated backtest state.
    self._folds: Optional[Dict[int, Any]] = None
    self._fold_column = "fold_number"
# NOTE(review): reconstructed from the whitespace-mangled diff adding the new
# file ``tests/test_pipeline/test_base.py``.  ``DummyPipeline`` begins just
# before this span in the corrupted source; the complete class is restored here.

from typing import Sequence
from unittest.mock import MagicMock

import pandas as pd
import pytest

from etna.datasets import TSDataset
from etna.datasets import generate_ar_df
from etna.pipeline.base import BasePipeline


class DummyPipeline(BasePipeline):
    """Minimal pipeline for testing ``BasePipeline.predict``:
    ``fit`` stores the dataset and ``_predict`` echoes it back unchanged."""

    def __init__(self, horizon: int):
        super().__init__(horizon=horizon)

    def fit(self, ts: TSDataset):
        self.ts = ts
        return self

    def _forecast(self):
        return None

    def _predict(
        self,
        start_timestamp: pd.Timestamp,
        end_timestamp: pd.Timestamp,
        prediction_interval: bool,
        quantiles: Sequence[float],
    ) -> TSDataset:
        return self.ts


@pytest.mark.parametrize("quantiles", [(0.025,), (0.975,), (0.025, 0.975)])
@pytest.mark.parametrize("prediction_interval", [False, True])
@pytest.mark.parametrize(
    "start_timestamp, end_timestamp",
    [
        (pd.Timestamp("2020-01-01"), pd.Timestamp("2020-01-05")),
        (pd.Timestamp("2020-01-10"), pd.Timestamp("2020-01-15")),
    ],
)
@pytest.mark.parametrize(
    "ts", [TSDataset(df=TSDataset.to_dataset(generate_ar_df(start_time="2020-01-01", periods=50)), freq="D")]
)
def test_predict_pass_params(ts, start_timestamp, end_timestamp, prediction_interval, quantiles):
    """``predict`` forwards all of its arguments to ``_predict`` unchanged."""
    pipeline = DummyPipeline(horizon=5)
    mock = MagicMock()
    pipeline._predict = mock

    pipeline.fit(ts)
    _ = pipeline.predict(
        start_timestamp=start_timestamp,
        end_timestamp=end_timestamp,
        prediction_interval=prediction_interval,
        quantiles=quantiles,
    )

    mock.assert_called_once_with(
        start_timestamp=start_timestamp,
        end_timestamp=end_timestamp,
        prediction_interval=prediction_interval,
        quantiles=quantiles,
    )


def test_predict_fail_not_fitted():
    """Calling ``predict`` before ``fit`` raises an informative ValueError."""
    pipeline = DummyPipeline(horizon=5)
    with pytest.raises(ValueError, match="DummyPipeline is not fitted"):
        _ = pipeline.predict()


@pytest.mark.parametrize("ts_name", ["example_tsds", "ts_with_different_series_length"])
def test_predict_use_ts_timestamps(ts_name, request):
    """Default range: latest segment beginning through the last timestamp."""
    ts = request.getfixturevalue(ts_name)
    pipeline = DummyPipeline(horizon=5)
    mock = MagicMock()
    pipeline._predict = mock

    pipeline.fit(ts)
    _ = pipeline.predict()

    expected_start_timestamp = ts.describe()["start_timestamp"].max()
    expected_end_timestamp = ts.index.max()

    mock.assert_called_once_with(
        start_timestamp=expected_start_timestamp,
        end_timestamp=expected_end_timestamp,
        prediction_interval=False,
        quantiles=(0.025, 0.975),
    )


def test_predict_fail_early_start(example_tsds):
    """A start before the dataset's beginning is rejected."""
    pipeline = DummyPipeline(horizon=5)
    pipeline.fit(example_tsds)
    start_timestamp = example_tsds.index[0] - pd.DateOffset(days=5)

    with pytest.raises(ValueError, match="Value of start_timestamp is less than beginning of some segments"):
        _ = pipeline.predict(start_timestamp=start_timestamp)


def test_predict_fail_late_end(example_tsds):
    """An end past the dataset's last timestamp is rejected."""
    pipeline = DummyPipeline(horizon=5)

    pipeline.fit(example_tsds)
    end_timestamp = example_tsds.index[-1] + pd.DateOffset(days=5)

    with pytest.raises(ValueError, match="Value of end_timestamp is more than ending of dataset"):
        _ = pipeline.predict(end_timestamp=end_timestamp)


def test_predict_fail_start_later_than_end(example_tsds):
    """An inverted range (start after end) is rejected."""
    pipeline = DummyPipeline(horizon=5)

    pipeline.fit(example_tsds)
    start_timestamp = example_tsds.index[2]
    end_timestamp = example_tsds.index[0]

    with pytest.raises(ValueError, match="Value of end_timestamp is less than start_timestamp"):
        _ = pipeline.predict(start_timestamp=start_timestamp, end_timestamp=end_timestamp)


@pytest.mark.parametrize("quantiles", [(0.025,), (0.975,), (0.025, 0.975)])
def test_predict_validate_quantiles(quantiles, example_tsds):
    """``predict`` delegates quantile validation to ``_validate_quantiles``."""
    pipeline = DummyPipeline(horizon=5)
    mock = MagicMock()
    pipeline._validate_quantiles = mock

    pipeline.fit(example_tsds)
    _ = pipeline.predict(prediction_interval=True, quantiles=quantiles)

    mock.assert_called_once_with(quantiles=quantiles)


def test_predict_return_private_predict(example_tsds):
    """``predict`` returns exactly what ``_predict`` produced."""
    pipeline = DummyPipeline(horizon=5)
    mock = MagicMock()
    pipeline._predict = mock
    expected_result = mock.return_value

    pipeline.fit(example_tsds)
    returned_result = pipeline.predict()

    assert returned_result == expected_result