From 6a8e478a7fd5897b0347f85aef2595373355954e Mon Sep 17 00:00:00 2001 From: Maxim Zherelo <60392282+brsnw250@users.noreply.github.com> Date: Wed, 29 Mar 2023 09:29:42 +0300 Subject: [PATCH] Implement in-sample predictions in BATS/TBATS (#1181) --- CHANGELOG.md | 12 +- etna/models/tbats.py | 122 ++++++- .../test_inference/test_predict.py | 12 +- tests/test_models/test_tbats.py | 323 +++++++++++------- 4 files changed, 336 insertions(+), 133 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f21c643d3..b96f1f4b9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,16 +19,18 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `ChangePointsLevelTransform` and base classes `PerIntervalModel`, `BaseChangePointsModelAdapter` for per-interval transforms ([#998](https://github.com/tinkoff-ai/etna/pull/998)) - Method `set_params` to change parameters of ETNA objects ([#1102](https://github.com/tinkoff-ai/etna/pull/1102)) - Function `plot_forecast_decomposition` ([#1129](https://github.com/tinkoff-ai/etna/pull/1129)) -- Method `forecast_components` for forecast decomposition in `_TBATSAdapter` ([#1125](https://github.com/tinkoff-ai/etna/issues/1125)) -- Methods `forecast_components` and `predict_components` for forecast decomposition in `_CatBoostAdapter` ([#1135](https://github.com/tinkoff-ai/etna/issues/1135)) -- Methods `forecast_components` and `predict_components` for forecast decomposition in `_HoltWintersAdapter ` ([#1146](https://github.com/tinkoff-ai/etna/issues/1146)) -- Methods `predict_components` for forecast decomposition in `_ProphetAdapter` ([#1161](https://github.com/tinkoff-ai/etna/issues/1161)) -- Methods `forecast_components` and `predict_components` for forecast decomposition in `_SARIMAXAdapter` and `_AutoARIMAAdapter` ([#1149](https://github.com/tinkoff-ai/etna/issues/1149)) +- Method `forecast_components` for forecast decomposition in `_TBATSAdapter` ([#1133](https://github.com/tinkoff-ai/etna/pull/1133)) +- 
Methods `forecast_components` and `predict_components` for forecast decomposition in `_CatBoostAdapter` ([#1148](https://github.com/tinkoff-ai/etna/pull/1148)) +- Methods `forecast_components` and `predict_components` for forecast decomposition in `_HoltWintersAdapter` ([#1162](https://github.com/tinkoff-ai/etna/pull/1162)) +- Method `predict_components` for forecast decomposition in `_ProphetAdapter` ([#1172](https://github.com/tinkoff-ai/etna/pull/1172)) +- Methods `forecast_components` and `predict_components` for forecast decomposition in `_SARIMAXAdapter` and `_AutoARIMAAdapter` ([#1174](https://github.com/tinkoff-ai/etna/pull/1174)) - Add `refit` parameter into `backtest` ([#1159](https://github.com/tinkoff-ai/etna/pull/1159)) - Add `stride` parameter into `backtest` ([#1165](https://github.com/tinkoff-ai/etna/pull/1165)) - Add optional parameter `ts` into `forecast` method of pipelines ([#1071](https://github.com/tinkoff-ai/etna/pull/1071)) - Add tests on `transform` method of transforms on subset of segments, on new segments, on future with gap ([#1094](https://github.com/tinkoff-ai/etna/pull/1094)) - Add tests on `inverse_transform` method of transforms on subset of segments, on new segments, on future with gap ([#1127](https://github.com/tinkoff-ai/etna/pull/1127)) +- In-sample prediction for `BATSModel` and `TBATSModel` ([#1181](https://github.com/tinkoff-ai/etna/pull/1181)) +- Method `predict_components` for forecast decomposition in `_TBATSAdapter` ([#1181](https://github.com/tinkoff-ai/etna/pull/1181)) - ### Changed - Add optional `features` parameter in the signature of `TSDataset.to_pandas`, `TSDataset.to_flatten` ([#809](https://github.com/tinkoff-ai/etna/pull/809)) diff --git a/etna/models/tbats.py b/etna/models/tbats.py index 32977a2a9..a65449aa9 100644 --- a/etna/models/tbats.py +++ b/etna/models/tbats.py @@ -22,6 +22,7 @@ class _TBATSAdapter(BaseAdapter): def __init__(self, model: Estimator): self._model = model self._fitted_model: 
Optional[Model] = None + self._first_train_timestamp = None self._last_train_timestamp = None self._freq = None @@ -32,6 +33,7 @@ def fit(self, df: pd.DataFrame, regressors: Iterable[str]): target = df["target"] self._fitted_model = self._model.fit(target) + self._first_train_timestamp = df["timestamp"].min() self._last_train_timestamp = df["timestamp"].max() self._freq = freq @@ -65,7 +67,37 @@ def forecast(self, df: pd.DataFrame, prediction_interval: bool, quantiles: Itera return y_pred def predict(self, df: pd.DataFrame, prediction_interval: bool, quantiles: Iterable[float]) -> pd.DataFrame: - raise NotImplementedError("Method predict isn't currently implemented!") + if self._fitted_model is None or self._freq is None: + raise ValueError("Model is not fitted! Fit the model before calling predict method!") + + train_timestamp = pd.date_range( + start=str(self._first_train_timestamp), end=str(self._last_train_timestamp), freq=self._freq + ) + + if not (set(df["timestamp"]) <= set(train_timestamp)): + raise NotImplementedError("Method predict isn't currently implemented for out-of-sample prediction!") + + y_pred = pd.DataFrame() + y_pred["target"] = self._fitted_model.y_hat + y_pred["timestamp"] = train_timestamp + + if prediction_interval: + for quantile in quantiles: + confidence_intervals = self._fitted_model._calculate_confidence_intervals( + y_pred["target"].values, quantile + ) + + if quantile < 1 / 2: + y_pred[f"target_{quantile:.4g}"] = confidence_intervals["lower_bound"] + else: + y_pred[f"target_{quantile:.4g}"] = confidence_intervals["upper_bound"] + + # selecting time points from provided dataframe + y_pred.set_index("timestamp", inplace=True) + y_pred = y_pred.loc[df["timestamp"]] + y_pred.reset_index(drop=True, inplace=True) + + return y_pred def get_model(self) -> Model: """Get internal :py:class:`tbats.tbats.Model` model that was fitted inside etna class. 
@@ -114,7 +146,31 @@ def predict_components(self, df: pd.DataFrame) -> pd.DataFrame: : dataframe with prediction components """ - raise NotImplementedError("Prediction decomposition isn't currently implemented!") + if self._fitted_model is None or self._freq is None: + raise ValueError("Model is not fitted! Fit the model before estimating forecast components!") + + train_timestamp = pd.date_range( + start=str(self._first_train_timestamp), end=str(self._last_train_timestamp), freq=self._freq + ) + + if not (set(df["timestamp"]) <= set(train_timestamp)): + raise NotImplementedError( + "Method predict_components isn't currently implemented for out-of-sample prediction!" + ) + + self._check_components() + + raw_components = self._decompose_predict() + components = self._process_components(raw_components=raw_components) + + # selecting time points from provided dataframe + components["timestamp"] = train_timestamp + + components.set_index("timestamp", inplace=True) + components = components.loc[df["timestamp"]] + components.reset_index(drop=True, inplace=True) + + return components def _get_steps_to_forecast(self, df: pd.DataFrame) -> int: if self._freq is None: @@ -157,6 +213,16 @@ def _check_components(self): if len(not_fitted_components) > 0: warn(f"Following components are not fitted: {', '.join(not_fitted_components)}!") + def _rescale_components(self, raw_components: np.ndarray) -> np.ndarray: + """Rescale components when Box-Cox transform used.""" + if self._fitted_model is None: + raise ValueError("Fitted model is not set!") + + transformed_pred = np.sum(raw_components, axis=1) + pred = self._fitted_model._inv_boxcox(transformed_pred) + components = raw_components * pred[..., np.newaxis] / transformed_pred[..., np.newaxis] + return components + def _decompose_forecast(self, horizon: int) -> np.ndarray: """Estimate raw forecast components.""" if self._fitted_model is None: @@ -175,9 +241,33 @@ def _decompose_forecast(self, horizon: int) -> np.ndarray: 
raw_components = np.stack(components, axis=0) if model.params.components.use_box_cox: - transformed_pred = np.sum(raw_components, axis=1) - pred = model._inv_boxcox(transformed_pred) - raw_components = raw_components * pred[..., np.newaxis] / transformed_pred[..., np.newaxis] + raw_components = self._rescale_components(raw_components) + + return raw_components + + def _decompose_predict(self) -> np.ndarray: + """Estimate raw prediction components.""" + if self._fitted_model is None: + raise ValueError("Fitted model is not set!") + + model = self._fitted_model + state_matrix = model.matrix.make_F_matrix() + component_weights = model.matrix.make_w_vector() + error_weights = model.matrix.make_g_vector() + + steps = len(model.y) + state = model.params.x0 + weighted_error = model.resid_boxcox[..., np.newaxis] * error_weights[np.newaxis] + + components = [] + for t in range(steps): + components.append(component_weights * state) + state = state_matrix @ state + weighted_error[t] + + raw_components = np.stack(components, axis=0) + + if model.params.components.use_box_cox: + raw_components = self._rescale_components(raw_components) return raw_components @@ -223,13 +313,21 @@ def _process_components(self, raw_components: np.ndarray) -> pd.DataFrame: raw_components[:, component_idx : component_idx + p + q], axis=1 ) - return pd.DataFrame(data=named_components) + return pd.DataFrame(data=named_components).add_prefix("target_component_") class BATSModel( PerSegmentModelMixin, PredictionIntervalContextIgnorantModelMixin, PredictionIntervalContextIgnorantAbstractModel ): - """Class for holding segment interval BATS model.""" + """Class for holding segment interval BATS model. + + Notes + ----- + This model supports in-sample and out-of-sample prediction decomposition. + Prediction components for BATS model are: local level, trend, seasonality and ARMA component. + In-sample and out-of-sample decomposition components are estimated directly from the fitted model parameters. 
+ Box-Cox transform is supported with proportional rescaling of the components. + """ def __init__( self, @@ -298,7 +396,15 @@ def __init__( class TBATSModel( PerSegmentModelMixin, PredictionIntervalContextIgnorantModelMixin, PredictionIntervalContextIgnorantAbstractModel ): - """Class for holding segment interval TBATS model.""" + """Class for holding segment interval TBATS model. + + Notes + ----- + This model supports in-sample and out-of-sample prediction decomposition. + Prediction components for TBATS model are: local level, trend, seasonality and ARMA component. + In-sample and out-of-sample decomposition components are estimated directly from the fitted model parameters. + Box-Cox transform is supported with proportional rescaling of the components. + """ def __init__( self, diff --git a/tests/test_models/test_inference/test_predict.py b/tests/test_models/test_inference/test_predict.py index 1a91de593..c0232d232 100644 --- a/tests/test_models/test_inference/test_predict.py +++ b/tests/test_models/test_inference/test_predict.py @@ -60,6 +60,8 @@ class TestPredictInSampleFull: (HoltModel(), []), (HoltWintersModel(), []), (SimpleExpSmoothingModel(), []), + (BATSModel(use_trend=True), []), + (TBATSModel(use_trend=True), []), ], ) def test_predict_in_sample_full(self, model, transforms, example_tsds): @@ -95,8 +97,6 @@ def test_predict_in_sample_full_failed_not_enough_context(self, model, transform @pytest.mark.parametrize( "model, transforms", [ - (BATSModel(use_trend=True), []), - (TBATSModel(use_trend=True), []), ( DeepARModel( dataset_builder=PytorchForecastingDatasetBuilder( @@ -171,6 +171,8 @@ class TestPredictInSampleSuffix: (NaiveModel(lag=3), []), (SeasonalMovingAverageModel(), []), (DeadlineMovingAverageModel(window=1), []), + (BATSModel(use_trend=True), []), + (TBATSModel(use_trend=True), []), ], ) def test_predict_in_sample_suffix(self, model, transforms, example_tsds): @@ -180,8 +182,6 @@ def test_predict_in_sample_suffix(self, model, transforms, example_tsds): 
@pytest.mark.parametrize( "model, transforms", [ - (BATSModel(use_trend=True), []), - (TBATSModel(use_trend=True), []), ( DeepARModel( dataset_builder=PytorchForecastingDatasetBuilder( @@ -714,6 +714,8 @@ def _test_predict_subset_segments(self, ts, model, transforms, segments, num_ski (SeasonalMovingAverageModel(), []), (NaiveModel(lag=3), []), (DeadlineMovingAverageModel(window=1), []), + (BATSModel(use_trend=True), []), + (TBATSModel(use_trend=True), []), ], ) def test_predict_subset_segments(self, model, transforms, example_tsds): @@ -723,8 +725,6 @@ def test_predict_subset_segments(self, model, transforms, example_tsds): @pytest.mark.parametrize( "model, transforms", [ - (BATSModel(use_trend=True), []), - (TBATSModel(use_trend=True), []), ( DeepARModel( dataset_builder=PytorchForecastingDatasetBuilder( diff --git a/tests/test_models/test_tbats.py b/tests/test_models/test_tbats.py index f4e06463d..e6495ff48 100644 --- a/tests/test_models/test_tbats.py +++ b/tests/test_models/test_tbats.py @@ -1,3 +1,6 @@ +from copy import deepcopy +from unittest.mock import Mock + import numpy as np import pandas as pd import pytest @@ -5,10 +8,10 @@ from etna.datasets import TSDataset from etna.metrics import MAE from etna.models.tbats import BATS +from etna.models.tbats import TBATS from etna.models.tbats import BATSModel from etna.models.tbats import TBATSModel from etna.models.tbats import _TBATSAdapter -from etna.transforms import LagTransform from tests.test_models.test_linear_model import linear_segments_by_parameters from tests.test_models.utils import assert_model_equals_loaded_original @@ -45,9 +48,8 @@ def sinusoid_ts(): @pytest.fixture() -def periodic_ts(): - horizon = 14 - periods = 100 +def periodic_dfs(): + periods = 50 t = np.arange(periods) # data from https://pypi.org/project/tbats/ @@ -58,20 +60,29 @@ def periodic_ts(): + 20 ) - ts_1 = pd.DataFrame( + df = pd.DataFrame( { - "segment": ["segment_1"] * periods, "timestamp": pd.date_range(start="1/1/2018", 
periods=periods), "target": y, } ) - ts_2 = pd.DataFrame( - { - "segment": ["segment_2"] * periods, - "timestamp": pd.date_range(start="1/1/2018", periods=periods), - "target": 2 * y, - } - ) + + return df.iloc[:40], df.iloc[40:] + + +@pytest.fixture() +def periodic_ts(periodic_dfs): + horizon = 10 + + df = pd.concat(periodic_dfs, axis=0).reset_index(drop=True) + + ts_1 = df.copy() + ts_1["segment"] = "segment_1" + + ts_2 = df.copy() + ts_2["segment"] = "segment_2" + ts_2["target"] *= 2 + df = pd.concat((ts_1, ts_2)) df = TSDataset.to_dataset(df) ts = TSDataset(df, freq="D") @@ -121,45 +132,48 @@ def test_repr(model_class, model_class_repr): @pytest.mark.parametrize("model", (TBATSModel(), BATSModel())) -def test_not_fitted(model, linear_segments_ts_unique): - train, test = linear_segments_ts_unique - to_forecast = train.make_future(3) +@pytest.mark.parametrize("method", ("forecast", "predict")) +def test_not_fitted(model, method, linear_segments_ts_unique): + method_to_call = getattr(model, method) with pytest.raises(ValueError, match="model is not fitted!"): - model.forecast(to_forecast) + method_to_call(ts=Mock()) @pytest.mark.long_2 @pytest.mark.parametrize("model", [TBATSModel(), BATSModel()]) -def test_format(model, new_format_df): - df = new_format_df - ts = TSDataset(df, "1d") - lags = LagTransform(lags=[3, 4, 5], in_column="target") - ts.fit_transform([lags]) - model.fit(ts) - future_ts = ts.make_future(3, transforms=[lags]) - model.forecast(future_ts) - future_ts.inverse_transform([lags]) - assert not future_ts.isnull().values.any() - - -@pytest.mark.long_2 -@pytest.mark.parametrize("model", [TBATSModel(), BATSModel()]) -def test_dummy(model, sinusoid_ts): +@pytest.mark.parametrize("method, use_future", (("predict", False), ("forecast", True))) +def test_dummy(model, method, use_future, sinusoid_ts): train, test = sinusoid_ts model.fit(train) - future_ts = train.make_future(14) - y_pred = model.forecast(future_ts) + + if use_future: + pred_ts = 
train.make_future(14) + y_true = test + else: + pred_ts = deepcopy(train) + y_true = train + + method_to_call = getattr(model, method) + y_pred = method_to_call(ts=pred_ts) + metric = MAE("macro") - value_metric = metric(y_pred, test) + value_metric = metric(y_true, y_pred) assert value_metric < 0.33 @pytest.mark.long_2 @pytest.mark.parametrize("model", [TBATSModel(), BATSModel()]) -def test_prediction_interval(model, example_tsds): +@pytest.mark.parametrize("method, use_future", (("predict", False), ("forecast", True))) +def test_prediction_interval(model, method, use_future, example_tsds): model.fit(example_tsds) - future_ts = example_tsds.make_future(3) - forecast = model.forecast(future_ts, prediction_interval=True, quantiles=[0.025, 0.975]) + if use_future: + pred_ts = example_tsds.make_future(3) + else: + pred_ts = deepcopy(example_tsds) + + method_to_call = getattr(model, method) + forecast = method_to_call(ts=pred_ts, prediction_interval=True, quantiles=[0.025, 0.975]) + for segment in forecast.segments: segment_slice = forecast[:, segment, :][segment] assert {"target_0.025", "target_0.975", "target"}.issubset(segment_slice.columns) @@ -172,141 +186,194 @@ def test_save_load(model, example_tsds): assert_model_equals_loaded_original(model=model, ts=example_tsds, transforms=[], horizon=3) -def test_forecast_decompose_not_fitted(small_periodic_ts): +@pytest.mark.parametrize("method", ("predict_components", "forecast_components")) +def test_decompose_not_fitted(small_periodic_ts, method): model = _TBATSAdapter(model=BATS()) + method_to_call = getattr(model, method) with pytest.raises(ValueError, match="Model is not fitted!"): - model.forecast_components(df=small_periodic_ts.df) + method_to_call(df=small_periodic_ts.df) -def test_predict_components_not_implemented(small_periodic_ts): - model = _TBATSAdapter(model=BATS()) +@pytest.mark.parametrize( + "estimator", + (BATS, TBATS), +) +def test_decompose_forecast_output_format(periodic_dfs, estimator): + _, train = 
periodic_dfs + horizon = 3 - with pytest.raises(NotImplementedError, match="Prediction decomposition isn't currently implemented!"): - model.predict_components(df=small_periodic_ts.df) + model = _TBATSAdapter(model=estimator()) + model.fit(train, []) + + components = model._decompose_forecast(horizon=horizon) + assert isinstance(components, np.ndarray) + assert components.shape[0] == horizon @pytest.mark.parametrize( "estimator", ( - BATSModel, - TBATSModel, + BATS, + TBATS, ), ) -def test_decompose_forecast_output_format(small_periodic_ts, estimator): - horizon = 3 - model = estimator() - model.fit(small_periodic_ts) +def test_decompose_predict_output_format(periodic_dfs, estimator): + _, train = periodic_dfs + model = _TBATSAdapter(model=estimator()) + model.fit(train, []) - components = model._models["segment_1"]._decompose_forecast(horizon=horizon) + components = model._decompose_predict() assert isinstance(components, np.ndarray) - assert components.shape[0] == horizon + assert components.shape[0] == len(train) @pytest.mark.parametrize( "estimator", ( - BATSModel, - TBATSModel, + BATS, + TBATS, ), ) -def test_named_components_output_format(small_periodic_ts, estimator): +def test_named_components_output_format(periodic_dfs, estimator): + _, train = periodic_dfs horizon = 3 - model = estimator() - model.fit(small_periodic_ts) - segment_model = model._models["segment_1"] - components = segment_model._decompose_forecast(horizon=horizon) - components = segment_model._process_components(raw_components=components) + model = _TBATSAdapter(model=estimator()) + model.fit(train, []) + + components = model._decompose_forecast(horizon=horizon) + components = model._process_components(raw_components=components) assert isinstance(components, pd.DataFrame) assert len(components) == horizon +@pytest.mark.parametrize( + "train_slice,decompose_slice", ((slice(5, 20), slice(None, 20)), (slice(5, 10), slice(10, 20))) +) +def 
test_predict_components_out_of_sample_error(periodic_dfs, train_slice, decompose_slice): + train, _ = periodic_dfs + + model = _TBATSAdapter(model=BATS()) + model.fit(train.iloc[train_slice], []) + with pytest.raises(NotImplementedError, match="isn't currently implemented for out-of-sample prediction"): + model.predict_components(df=train.iloc[decompose_slice]) + + @pytest.mark.long_1 @pytest.mark.parametrize( "estimator,params,components_names", ( ( - BATSModel, - {"use_box_cox": False, "use_trend": True, "use_arma_errors": True, "seasonal_periods": [7, 14]}, - {"local_level", "trend", "arma(p=1,q=1)", "seasonal(s=7)", "seasonal(s=14)"}, + BATS, + {"use_box_cox": True, "use_trend": True, "use_arma_errors": True, "seasonal_periods": [7, 14]}, + { + "target_component_local_level", + "target_component_trend", + "target_component_arma(p=1,q=1)", + "target_component_seasonal(s=7)", + "target_component_seasonal(s=14)", + }, ), ( - TBATSModel, + TBATS, {"use_box_cox": False, "use_trend": True, "use_arma_errors": False, "seasonal_periods": [7, 14]}, - {"local_level", "trend", "seasonal(s=7.0)", "seasonal(s=14.0)"}, + { + "target_component_local_level", + "target_component_trend", + "target_component_seasonal(s=7.0)", + "target_component_seasonal(s=14.0)", + }, ), ), ) -def test_components_names(periodic_ts, estimator, params, components_names): - train, test = periodic_ts - model = estimator(**params) - model.fit(train) +@pytest.mark.parametrize( + "method,use_future", + ( + ("predict_components", False), + ("forecast_components", True), + ), +) +def test_components_names(periodic_dfs, estimator, params, components_names, method, use_future): + train, test = periodic_dfs + model = _TBATSAdapter(model=estimator(**params)) + model.fit(train, []) - future = train.make_future(3).to_pandas(flatten=True) + pred_df = test if use_future else train - for segment in test.columns.get_level_values("segment"): - components_df = model._models[segment].forecast_components(df=future) - 
assert set(components_df.columns) == components_names + method_to_call = getattr(model, method) + components_df = method_to_call(df=pred_df) + assert set(components_df.columns) == components_names @pytest.mark.parametrize( "estimator", ( - BATSModel, - TBATSModel, + BATS, + TBATS, ), ) -def test_seasonal_components_not_fitted(small_periodic_ts, estimator): - model = estimator(seasonal_periods=[7, 14], use_arma_errors=False) - model.fit(small_periodic_ts) +@pytest.mark.parametrize("method,use_future", (("predict_components", False), ("forecast_components", True))) +def test_seasonal_components_not_fitted(periodic_dfs, estimator, method, use_future): + train, test = periodic_dfs + model = _TBATSAdapter(model=estimator(seasonal_periods=[7, 14], use_arma_errors=False)) + model.fit(train, []) - future = small_periodic_ts.make_future(3).to_pandas(flatten=True) - segment_model = model._models["segment_1"] - segment_model._fitted_model.params.components.seasonal_periods = [] + model._fitted_model.params.components.seasonal_periods = [] + pred_df = test if use_future else train + + method_to_call = getattr(model, method) with pytest.warns(Warning, match=f"Following components are not fitted: Seasonal!"): - segment_model.forecast_components(df=future) + method_to_call(df=pred_df) @pytest.mark.parametrize( "estimator", ( - BATSModel, - TBATSModel, + BATS, + TBATS, ), ) -def test_arma_component_not_fitted(small_periodic_ts, estimator): - model = estimator(use_arma_errors=True, seasonal_periods=[]) - model.fit(small_periodic_ts) +@pytest.mark.parametrize("method,use_future", (("predict_components", False), ("forecast_components", True))) +def test_arma_component_not_fitted(periodic_dfs, estimator, method, use_future): + train, test = periodic_dfs + + model = _TBATSAdapter(model=estimator(use_arma_errors=True)) + model.fit(train, []) - future = small_periodic_ts.make_future(3).to_pandas(flatten=True) - segment_model = model._models["segment_1"] - 
segment_model._fitted_model.params.components.use_arma_errors = False + model._fitted_model.params.components.use_arma_errors = False + pred_df = test if use_future else train + + method_to_call = getattr(model, method) with pytest.warns(Warning, match=f"Following components are not fitted: ARMA!"): - segment_model.forecast_components(df=future) + method_to_call(df=pred_df) @pytest.mark.parametrize( "estimator", ( - BATSModel, - TBATSModel, + BATS, + TBATS, ), ) -def test_arma_w_seasonal_components_not_fitted(small_periodic_ts, estimator): - model = estimator(use_arma_errors=True, seasonal_periods=[2, 3]) - model.fit(small_periodic_ts) +@pytest.mark.parametrize("method,use_future", (("predict_components", False), ("forecast_components", True))) +def test_arma_with_seasonal_components_not_fitted(periodic_dfs, estimator, method, use_future): + train, test = periodic_dfs + + model = _TBATSAdapter(model=estimator(use_arma_errors=True, seasonal_periods=[2, 3], use_box_cox=False)) + model.fit(train, []) - future = small_periodic_ts.make_future(3).to_pandas(flatten=True) - segment_model = model._models["segment_1"] - segment_model._fitted_model.params.components.use_arma_errors = False - segment_model._fitted_model.params.components.seasonal_periods = [] + model._fitted_model.params.components.use_arma_errors = False + model._fitted_model.params.components.seasonal_periods = [] + pred_df = test if use_future else train + + method_to_call = getattr(model, method) with pytest.warns(Warning, match=f"Following components are not fitted: Seasonal, ARMA!"): - segment_model.forecast_components(df=future) + method_to_call(df=pred_df) @pytest.mark.long_1 @@ -314,8 +381,8 @@ def test_arma_w_seasonal_components_not_fitted(small_periodic_ts, estimator): @pytest.mark.parametrize( "estimator", ( - BATSModel, - TBATSModel, + BATS, + TBATS, ), ) @pytest.mark.parametrize( @@ -342,16 +409,44 @@ def test_arma_w_seasonal_components_not_fitted(small_periodic_ts, estimator): }, ), ) -def 
test_forecast_decompose_sum_up_to_target(periodic_ts, estimator, params): - train, test = periodic_ts +@pytest.mark.parametrize("method,use_future", (("predict_components", False), ("forecast_components", True))) +def test_forecast_decompose_sum_up_to_target(periodic_dfs, estimator, params, method, use_future): + train, test = periodic_dfs - horizon = 14 - model = estimator(**params) - model.fit(train) - future_ts = train.make_future(horizon) - y_pred = model.forecast(future_ts) + model = _TBATSAdapter(model=estimator(**params)) + model.fit(train, []) + + if use_future: + pred_df = test + y_pred = model.forecast(test, prediction_interval=False, quantiles=[]) + + else: + pred_df = train + y_pred = model.predict(train, prediction_interval=False, quantiles=[]) + + method_to_call = getattr(model, method) + components = method_to_call(df=pred_df) + + y_hat_pred = np.sum(components.values, axis=1) + np.testing.assert_allclose(y_hat_pred, np.squeeze(y_pred.values)) + + +@pytest.mark.parametrize( + "estimator", + ( + BATS, + TBATS, + ), +) +def test_predict_decompose_on_subset(periodic_dfs, estimator): + train, _ = periodic_dfs + sub_train = train.iloc[5:] + + model = _TBATSAdapter(model=estimator()) + model.fit(train, []) + + y_pred = model.predict(df=sub_train, prediction_interval=False, quantiles=[]) + components = model.predict_components(df=sub_train) - for segment in y_pred.columns.get_level_values("segment"): - components = model._models[segment].forecast_components(df=future_ts.to_pandas(flatten=True)) - y_hat_pred = np.sum(components.values, axis=1) - np.testing.assert_allclose(y_hat_pred, y_pred[:, segment, "target"].values) + y_hat_pred = np.sum(components.values, axis=1) + np.testing.assert_allclose(y_hat_pred, np.squeeze(y_pred.values))