From 7b81911828e98e703c841a531a9c2766cd86b20f Mon Sep 17 00:00:00 2001 From: "d.a.bunin" Date: Thu, 25 Aug 2022 17:38:48 +0300 Subject: [PATCH 01/15] Unify interface to work with prediction intervals --- etna/models/autoarima.py | 6 +- etna/models/base.py | 231 ++++++++++++++++++++++---------------- etna/models/nn/deepar.py | 20 +++- etna/models/nn/tft.py | 20 +++- etna/models/prophet.py | 5 +- etna/models/sarimax.py | 5 +- etna/models/tbats.py | 7 +- etna/pipeline/pipeline.py | 4 +- 8 files changed, 185 insertions(+), 113 deletions(-) diff --git a/etna/models/autoarima.py b/etna/models/autoarima.py index 2d8aa37c8..a592a9bb7 100644 --- a/etna/models/autoarima.py +++ b/etna/models/autoarima.py @@ -5,7 +5,9 @@ from statsmodels.tools.sm_exceptions import ValueWarning from statsmodels.tsa.statespace.sarimax import SARIMAXResultsWrapper -from etna.models.base import PerSegmentPredictionIntervalModel +# from etna.models.base import PerSegmentPredictionIntervalModel +from etna.models.base import PerSegmentModel +from etna.models.base import PredictionIntervalInterface from etna.models.sarimax import _SARIMAXBaseAdapter warnings.filterwarnings( @@ -48,7 +50,7 @@ def _get_fit_results(self, endog: pd.Series, exog: pd.DataFrame) -> SARIMAXResul return model.arima_res_ -class AutoARIMAModel(PerSegmentPredictionIntervalModel): +class AutoARIMAModel(PerSegmentModel, PredictionIntervalInterface): """ Class for holding auto arima model. diff --git a/etna/models/base.py b/etna/models/base.py index 1e31238dd..f7134dcd1 100644 --- a/etna/models/base.py +++ b/etna/models/base.py @@ -151,43 +151,42 @@ def get_model(self) -> Union[Any, Dict[str, Any]]: pass +class PredictionIntervalInterface(ABC): + """Interface that is used to select classes that support prediction intervals.""" + + pass + + class ForecastAbstractModel(ABC): """Interface for model with forecast method.""" - @abstractmethod - def forecast(self, ts: TSDataset) -> TSDataset: - """Make predictions. + def _extract_prediction_interval_params(self, **kwargs) -> Dict[str, Any]: + extracted_params = {} - Parameters - ---------- - ts: - Dataset with features + if isinstance(self, PredictionIntervalInterface): + prediction_interval = kwargs.get("prediction_interval", False) + extracted_params["prediction_interval"] = prediction_interval - Returns - ------- - : - Dataset with predictions - """ - pass + quantiles = kwargs.get("quantiles", (0.025, 0.975)) + extracted_params["quantiles"] = quantiles + else: + if "prediction_interval" in kwargs or "quantiles" in kwargs: + raise NotImplementedError(f"Model {self.__class__} doesn't support prediction intervals!") + return extracted_params -class PredictIntervalAbstractModel(ABC): - """Interface for model with forecast method that creates prediction interval.""" + def _extract_prediction_size_params(self, **kwargs): + extracted_params = {} + return extracted_params @abstractmethod - def forecast( - self, ts: TSDataset, prediction_interval: bool = False, quantiles: Sequence[float] = (0.025, 0.975) - ) -> TSDataset: + def forecast(self, ts: TSDataset, **kwargs) -> TSDataset: """Make predictions. Parameters ---------- ts: Dataset with features - prediction_interval: - If True returns prediction interval for forecast - quantiles: - Levels of prediction distribution. 
By default 2.5% and 97.5% are taken to form a 95% prediction interval Returns ------- @@ -197,6 +196,32 @@ def forecast( pass +# class PredictIntervalAbstractModel(ABC): +# """Interface for model with forecast method that creates prediction interval.""" +# +# @abstractmethod +# def forecast( +# self, ts: TSDataset, prediction_interval: bool = False, quantiles: Sequence[float] = (0.025, 0.975) +# ) -> TSDataset: +# """Make predictions. +# +# Parameters +# ---------- +# ts: +# Dataset with features +# prediction_interval: +# If True returns prediction interval for forecast +# quantiles: +# Levels of prediction distribution. By default 2.5% and 97.5% are taken to form a 95% prediction interval +# +# Returns +# ------- +# : +# Dataset with predictions +# """ +# pass + + class PerSegmentBaseModel(FitAbstractModel, BaseMixin): """Base class for holding specific models for per-segment prediction.""" @@ -301,21 +326,27 @@ def __init__(self, base_model: Any): super().__init__(base_model=base_model) @log_decorator - def forecast(self, ts: TSDataset) -> TSDataset: + def forecast(self, ts: TSDataset, **kwargs) -> TSDataset: """Make predictions. Parameters ---------- ts: Dataframe with features + Returns ------- : Dataset with predictions """ + forecast_params = {} + forecast_params.update(self._extract_prediction_interval_params(**kwargs)) + forecast_params.update(self._extract_prediction_size_params(**kwargs)) + result_list = list() for segment, model in self._get_model().items(): - segment_predict = self._forecast_segment(model=model, segment=segment, ts=ts) + segment_predict = self._forecast_segment(model=model, segment=segment, ts=ts, **forecast_params) + result_list.append(segment_predict) result_df = pd.concat(result_list, ignore_index=True) @@ -330,57 +361,57 @@ def forecast(self, ts: TSDataset) -> TSDataset: return ts -class PerSegmentPredictionIntervalModel(PerSegmentBaseModel, PredictIntervalAbstractModel): - """Class for holding specific models for per-segment prediction which are able to build prediction intervals.""" - - def __init__(self, base_model: Any): - """ - Init PerSegmentPredictionIntervalModel. - - Parameters - ---------- - base_model: - Internal model which will be used to forecast segments, expected to have fit/predict interface - """ - super().__init__(base_model=base_model) - - @log_decorator - def forecast( - self, ts: TSDataset, prediction_interval: bool = False, quantiles: Sequence[float] = (0.025, 0.975) - ) -> TSDataset: - """Make predictions. - - Parameters - ---------- - ts: - Dataset with features - prediction_interval: - If True returns prediction interval for forecast - quantiles: - Levels of prediction distribution. 
By default 2.5% and 97.5% are taken to form a 95% prediction interval - - Returns - ------- - : - Dataset with predictions - """ - result_list = list() - for segment, model in self._get_model().items(): - segment_predict = self._forecast_segment( - model=model, segment=segment, ts=ts, prediction_interval=prediction_interval, quantiles=quantiles - ) - result_list.append(segment_predict) - - result_df = pd.concat(result_list, ignore_index=True) - result_df = result_df.set_index(["timestamp", "segment"]) - df = ts.to_pandas(flatten=True) - df = df.set_index(["timestamp", "segment"]) - df = df.combine_first(result_df).reset_index() - - df = TSDataset.to_dataset(df) - ts.df = df - ts.inverse_transform() - return ts +# class PerSegmentPredictionIntervalModel(PerSegmentBaseModel, PredictIntervalAbstractModel): +# """Class for holding specific models for per-segment prediction which are able to build prediction intervals.""" +# +# def __init__(self, base_model: Any): +# """ +# Init PerSegmentPredictionIntervalModel. +# +# Parameters +# ---------- +# base_model: +# Internal model which will be used to forecast segments, expected to have fit/predict interface +# """ +# super().__init__(base_model=base_model) +# +# @log_decorator +# def forecast( +# self, ts: TSDataset, prediction_interval: bool = False, quantiles: Sequence[float] = (0.025, 0.975) +# ) -> TSDataset: +# """Make predictions. +# +# Parameters +# ---------- +# ts: +# Dataset with features +# prediction_interval: +# If True returns prediction interval for forecast +# quantiles: +# Levels of prediction distribution. By default 2.5% and 97.5% are taken to form a 95% prediction interval +# +# Returns +# ------- +# : +# Dataset with predictions +# """ +# result_list = list() +# for segment, model in self._get_model().items(): +# segment_predict = self._forecast_segment( +# model=model, segment=segment, ts=ts, prediction_interval=prediction_interval, quantiles=quantiles +# ) +# result_list.append(segment_predict) +# +# result_df = pd.concat(result_list, ignore_index=True) +# result_df = result_df.set_index(["timestamp", "segment"]) +# df = ts.to_pandas(flatten=True) +# df = df.set_index(["timestamp", "segment"]) +# df = df.combine_first(result_df).reset_index() +# +# df = TSDataset.to_dataset(df) +# ts.df = df +# ts.inverse_transform() +# return ts class MultiSegmentModel(FitAbstractModel, ForecastAbstractModel, BaseMixin): @@ -418,7 +449,7 @@ def fit(self, ts: TSDataset) -> "MultiSegmentModel": return self @log_decorator - def forecast(self, ts: TSDataset) -> TSDataset: + def forecast(self, ts: TSDataset, **kwargs) -> TSDataset: """Make predictions. 
Parameters @@ -431,9 +462,13 @@ def forecast(self, ts: TSDataset) -> TSDataset: : Dataset with predictions """ + forecast_params = {} + forecast_params.update(self._extract_prediction_interval_params(**kwargs)) + forecast_params.update(self._extract_prediction_size_params(**kwargs)) + horizon = len(ts.df) x = ts.to_pandas(flatten=True).drop(["segment"], axis=1) - y = self._base_model.predict(x).reshape(-1, horizon).T + y = self._base_model.predict(x, **forecast_params).reshape(-1, horizon).T ts.loc[:, pd.IndexSlice[:, "target"]] = y ts.inverse_transform() return ts @@ -831,31 +866,31 @@ def get_model(self) -> "DeepBaseNet": return self.net -class MultiSegmentPredictionIntervalModel(FitAbstractModel, PredictIntervalAbstractModel, BaseMixin): - """Class for holding specific models for multi-segment prediction which are able to build prediction intervals.""" - - def __init__(self): - """Init MultiSegmentPredictionIntervalModel.""" - self.model = None - - def get_model(self) -> Any: - """Get internal model that is used inside etna class. - - Internal model is a model that is used inside etna to forecast segments, - e.g. :py:class:`catboost.CatBoostRegressor` or :py:class:`sklearn.linear_model.Ridge`. - - Returns - ------- - : - Internal model - """ - return self.model +# class MultiSegmentPredictionIntervalModel(FitAbstractModel, PredictIntervalAbstractModel, BaseMixin): +# """Class for holding specific models for multi-segment prediction which are able to build prediction intervals.""" +# +# def __init__(self): +# """Init MultiSegmentPredictionIntervalModel.""" +# self.model = None +# +# def get_model(self) -> Any: +# """Get internal model that is used inside etna class. +# +# Internal model is a model that is used inside etna to forecast segments, +# e.g. :py:class:`catboost.CatBoostRegressor` or :py:class:`sklearn.linear_model.Ridge`. +# +# Returns +# ------- +# : +# Internal model +# """ +# return self.model BaseModel = Union[ PerSegmentModel, - PerSegmentPredictionIntervalModel, + # PerSegmentPredictionIntervalModel, MultiSegmentModel, DeepBaseModel, - MultiSegmentPredictionIntervalModel, + # MultiSegmentPredictionIntervalModel, ] diff --git a/etna/models/nn/deepar.py b/etna/models/nn/deepar.py index c4f6f6677..a2b887d91 100644 --- a/etna/models/nn/deepar.py +++ b/etna/models/nn/deepar.py @@ -8,9 +8,12 @@ import pandas as pd from etna import SETTINGS +from etna.core.mixins import BaseMixin from etna.datasets.tsdataset import TSDataset from etna.loggers import tslogger -from etna.models.base import MultiSegmentPredictionIntervalModel +from etna.models.base import FitAbstractModel +from etna.models.base import ForecastAbstractModel +from etna.models.base import PredictionIntervalInterface from etna.models.base import log_decorator from etna.models.nn.utils import _DeepCopyMixin from etna.transforms import PytorchForecastingTransform @@ -24,7 +27,7 @@ from pytorch_lightning import LightningModule -class DeepARModel(MultiSegmentPredictionIntervalModel, _DeepCopyMixin): +class DeepARModel(FitAbstractModel, ForecastAbstractModel, PredictionIntervalInterface, BaseMixin, _DeepCopyMixin): """Wrapper for :py:class:`pytorch_forecasting.models.deepar.DeepAR`. Notes @@ -241,3 +244,16 @@ def forecast( ts.inverse_transform() return ts + + def get_model(self) -> Any: + """Get internal model that is used inside etna class. + + Internal model is a model that is used inside etna to forecast segments, + e.g. :py:class:`catboost.CatBoostRegressor` or :py:class:`sklearn.linear_model.Ridge`. 
+ + Returns + ------- + : + Internal model + """ + return self.model diff --git a/etna/models/nn/tft.py b/etna/models/nn/tft.py index ca0d1ae79..cc55c93de 100644 --- a/etna/models/nn/tft.py +++ b/etna/models/nn/tft.py @@ -9,9 +9,12 @@ import pandas as pd from etna import SETTINGS +from etna.core.mixins import BaseMixin from etna.datasets.tsdataset import TSDataset from etna.loggers import tslogger -from etna.models.base import MultiSegmentPredictionIntervalModel +from etna.models.base import FitAbstractModel +from etna.models.base import ForecastAbstractModel +from etna.models.base import PredictionIntervalInterface from etna.models.base import log_decorator from etna.models.nn.utils import _DeepCopyMixin from etna.transforms import PytorchForecastingTransform @@ -25,7 +28,7 @@ from pytorch_lightning import LightningModule -class TFTModel(MultiSegmentPredictionIntervalModel, _DeepCopyMixin): +class TFTModel(FitAbstractModel, ForecastAbstractModel, PredictionIntervalInterface, BaseMixin, _DeepCopyMixin): """Wrapper for :py:class:`pytorch_forecasting.models.temporal_fusion_transformer.TemporalFusionTransformer`. Notes @@ -273,3 +276,16 @@ def forecast( ts.inverse_transform() return ts + + def get_model(self) -> Any: + """Get internal model that is used inside etna class. + + Internal model is a model that is used inside etna to forecast segments, + e.g. :py:class:`catboost.CatBoostRegressor` or :py:class:`sklearn.linear_model.Ridge`. + + Returns + ------- + : + Internal model + """ + return self.model diff --git a/etna/models/prophet.py b/etna/models/prophet.py index d4c5731cc..4ce6a024e 100644 --- a/etna/models/prophet.py +++ b/etna/models/prophet.py @@ -10,7 +10,8 @@ from etna import SETTINGS from etna.models.base import BaseAdapter -from etna.models.base import PerSegmentPredictionIntervalModel +from etna.models.base import PerSegmentModel +from etna.models.base import PredictionIntervalInterface if SETTINGS.prophet_required: from prophet import Prophet @@ -153,7 +154,7 @@ def get_model(self) -> Prophet: return self.model -class ProphetModel(PerSegmentPredictionIntervalModel): +class ProphetModel(PerSegmentModel, PredictionIntervalInterface): """Class for holding Prophet model. Notes diff --git a/etna/models/sarimax.py b/etna/models/sarimax.py index 62c60e7eb..29470c2bc 100644 --- a/etna/models/sarimax.py +++ b/etna/models/sarimax.py @@ -13,7 +13,8 @@ from etna.libs.pmdarima_utils import seasonal_prediction_with_confidence from etna.models.base import BaseAdapter -from etna.models.base import PerSegmentPredictionIntervalModel +from etna.models.base import PerSegmentModel +from etna.models.base import PredictionIntervalInterface from etna.models.utils import determine_num_steps warnings.filterwarnings( @@ -352,7 +353,7 @@ def _get_fit_results(self, endog: pd.Series, exog: pd.DataFrame): return result -class SARIMAXModel(PerSegmentPredictionIntervalModel): +class SARIMAXModel(PerSegmentModel, PredictionIntervalInterface): """ Class for holding Sarimax model. 
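The hunks above swap the old PerSegmentPredictionIntervalModel base class for a PredictionIntervalInterface marker that models opt into through multiple inheritance, while _extract_prediction_interval_params pulls the interval arguments out of **kwargs and rejects them for models that did not opt in. A minimal sketch of that dispatch pattern; DummyPlainModel, DummyIntervalModel and the module-level helper are hypothetical stand-ins, not code from this patch:

from abc import ABC


class PredictionIntervalInterface(ABC):
    """Marker: subclasses support prediction intervals."""


class DummyPlainModel:
    def forecast(self, ts, **kwargs):
        # A real model would fill `ts` with point forecasts here.
        return ts


class DummyIntervalModel(DummyPlainModel, PredictionIntervalInterface):
    pass


def extract_prediction_interval_params(model, **kwargs):
    # Simplified, module-level version of the kwargs extraction added above.
    if isinstance(model, PredictionIntervalInterface):
        return {
            "prediction_interval": kwargs.get("prediction_interval", False),
            "quantiles": kwargs.get("quantiles", (0.025, 0.975)),
        }
    if "prediction_interval" in kwargs or "quantiles" in kwargs:
        raise NotImplementedError(f"Model {type(model).__name__} doesn't support prediction intervals!")
    return {}


print(extract_prediction_interval_params(DummyIntervalModel(), prediction_interval=True))
print(extract_prediction_interval_params(DummyPlainModel()))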
diff --git a/etna/models/tbats.py b/etna/models/tbats.py index 9f1a22cd7..4b4c09d84 100644 --- a/etna/models/tbats.py +++ b/etna/models/tbats.py @@ -10,7 +10,8 @@ from tbats.tbats.Model import Model from etna.models.base import BaseAdapter -from etna.models.base import PerSegmentPredictionIntervalModel +from etna.models.base import PerSegmentModel +from etna.models.base import PredictionIntervalInterface from etna.models.utils import determine_num_steps @@ -72,7 +73,7 @@ def get_model(self) -> Estimator: return self.model -class BATSModel(PerSegmentPredictionIntervalModel): +class BATSModel(PerSegmentModel, PredictionIntervalInterface): """Class for holding segment interval BATS model.""" context_size = 0 @@ -141,7 +142,7 @@ def __init__( super().__init__(base_model=_TBATSAdapter(self.model)) -class TBATSModel(PerSegmentPredictionIntervalModel): +class TBATSModel(PerSegmentModel, PredictionIntervalInterface): """Class for holding segment interval TBATS model.""" context_size = 0 diff --git a/etna/pipeline/pipeline.py b/etna/pipeline/pipeline.py index bb22bd5e4..2c7d95670 100644 --- a/etna/pipeline/pipeline.py +++ b/etna/pipeline/pipeline.py @@ -3,7 +3,7 @@ from etna.datasets import TSDataset from etna.models.base import BaseModel from etna.models.base import DeepBaseModel -from etna.models.base import PredictIntervalAbstractModel +from etna.models.base import PredictionIntervalInterface from etna.pipeline.base import BasePipeline from etna.transforms.base import Transform @@ -89,7 +89,7 @@ def forecast( self._validate_quantiles(quantiles=quantiles) self._validate_backtest_n_folds(n_folds=n_folds) - if prediction_interval and isinstance(self.model, PredictIntervalAbstractModel): + if prediction_interval and isinstance(self.model, PredictionIntervalInterface): future = self.ts.make_future(self.horizon) predictions = self.model.forecast(ts=future, prediction_interval=prediction_interval, quantiles=quantiles) else: From d55764605931fea421979eea4f8653e75684c15a Mon Sep 17 00:00:00 2001 From: "d.a.bunin" Date: Thu, 25 Aug 2022 18:06:43 +0300 Subject: [PATCH 02/15] Add ContextRequiredInterface --- etna/models/base.py | 137 +++++++------------------------ etna/pipeline/pipeline.py | 12 ++- tests/test_models/nn/test_rnn.py | 2 +- 3 files changed, 38 insertions(+), 113 deletions(-) diff --git a/etna/models/base.py b/etna/models/base.py index f7134dcd1..3ee6dcb2f 100644 --- a/etna/models/base.py +++ b/etna/models/base.py @@ -152,7 +152,13 @@ def get_model(self) -> Union[Any, Dict[str, Any]]: class PredictionIntervalInterface(ABC): - """Interface that is used to select classes that support prediction intervals.""" + """Interface that is used to mark classes that support prediction intervals.""" + + pass + + +class ContextRequiredInterface(ABC): + """Interface that is used to mark classes that need context for prediction.""" pass @@ -177,6 +183,20 @@ def _extract_prediction_interval_params(self, **kwargs) -> Dict[str, Any]: def _extract_prediction_size_params(self, **kwargs): extracted_params = {} + + if isinstance(self, ContextRequiredInterface): + prediction_size = kwargs.get("prediction_size") + if prediction_size is None: + raise ValueError(f"Parameter prediction_size is required for {self.__class__} model!") + + if not isinstance(prediction_size, int) or prediction_size <= 0: + raise ValueError(f"Parameter prediction_size should be positive integer!") + + extracted_params = prediction_size + else: + if "prediction_size" in kwargs: + raise NotImplementedError(f"Model {self.__class__} doesn't 
support prediction_size parameter!") + return extracted_params @abstractmethod @@ -196,32 +216,6 @@ def forecast(self, ts: TSDataset, **kwargs) -> TSDataset: pass -# class PredictIntervalAbstractModel(ABC): -# """Interface for model with forecast method that creates prediction interval.""" -# -# @abstractmethod -# def forecast( -# self, ts: TSDataset, prediction_interval: bool = False, quantiles: Sequence[float] = (0.025, 0.975) -# ) -> TSDataset: -# """Make predictions. -# -# Parameters -# ---------- -# ts: -# Dataset with features -# prediction_interval: -# If True returns prediction interval for forecast -# quantiles: -# Levels of prediction distribution. By default 2.5% and 97.5% are taken to form a 95% prediction interval -# -# Returns -# ------- -# : -# Dataset with predictions -# """ -# pass - - class PerSegmentBaseModel(FitAbstractModel, BaseMixin): """Base class for holding specific models for per-segment prediction.""" @@ -361,59 +355,6 @@ def forecast(self, ts: TSDataset, **kwargs) -> TSDataset: return ts -# class PerSegmentPredictionIntervalModel(PerSegmentBaseModel, PredictIntervalAbstractModel): -# """Class for holding specific models for per-segment prediction which are able to build prediction intervals.""" -# -# def __init__(self, base_model: Any): -# """ -# Init PerSegmentPredictionIntervalModel. -# -# Parameters -# ---------- -# base_model: -# Internal model which will be used to forecast segments, expected to have fit/predict interface -# """ -# super().__init__(base_model=base_model) -# -# @log_decorator -# def forecast( -# self, ts: TSDataset, prediction_interval: bool = False, quantiles: Sequence[float] = (0.025, 0.975) -# ) -> TSDataset: -# """Make predictions. -# -# Parameters -# ---------- -# ts: -# Dataset with features -# prediction_interval: -# If True returns prediction interval for forecast -# quantiles: -# Levels of prediction distribution. By default 2.5% and 97.5% are taken to form a 95% prediction interval -# -# Returns -# ------- -# : -# Dataset with predictions -# """ -# result_list = list() -# for segment, model in self._get_model().items(): -# segment_predict = self._forecast_segment( -# model=model, segment=segment, ts=ts, prediction_interval=prediction_interval, quantiles=quantiles -# ) -# result_list.append(segment_predict) -# -# result_df = pd.concat(result_list, ignore_index=True) -# result_df = result_df.set_index(["timestamp", "segment"]) -# df = ts.to_pandas(flatten=True) -# df = df.set_index(["timestamp", "segment"]) -# df = df.combine_first(result_df).reset_index() -# -# df = TSDataset.to_dataset(df) -# ts.df = df -# ts.inverse_transform() -# return ts - - class MultiSegmentModel(FitAbstractModel, ForecastAbstractModel, BaseMixin): """Class for holding specific models for per-segment prediction.""" @@ -654,7 +595,7 @@ def validation_step(self, batch: dict, *args, **kwargs): # type: ignore return loss -class DeepBaseModel(FitAbstractModel, DeepBaseAbstractModel, BaseMixin): +class DeepBaseModel(FitAbstractModel, DeepBaseAbstractModel, ContextRequiredInterface, BaseMixin): """Class for partially implemented interfaces for holding deep models.""" def __init__( @@ -825,15 +766,16 @@ def raw_predict(self, torch_dataset: "Dataset") -> Dict[Tuple[str, str], np.ndar return predictions_dict @log_decorator - def forecast(self, ts: "TSDataset", horizon: int) -> "TSDataset": + def forecast(self, ts: "TSDataset", prediction_size: int) -> "TSDataset": """Make predictions. 
Parameters ---------- ts: Dataset with features and expected decoder length for context - horizon: - Horizon to predict for + prediction_size: + Number of last timestamps to leave after making prediction. + Previous timestamps will be used as a context. Returns ------- @@ -847,9 +789,9 @@ def forecast(self, ts: "TSDataset", horizon: int) -> "TSDataset": dropna=False, ) predictions = self.raw_predict(test_dataset) - future_ts = ts.tsdataset_idx_slice(start_idx=self.encoder_length, end_idx=self.encoder_length + horizon) + future_ts = ts.tsdataset_idx_slice(start_idx=self.encoder_length, end_idx=self.encoder_length + prediction_size) for (segment, feature_nm), value in predictions.items(): - future_ts.df.loc[:, pd.IndexSlice[segment, feature_nm]] = value[:horizon, :] + future_ts.df.loc[:, pd.IndexSlice[segment, feature_nm]] = value[:prediction_size, :] future_ts.inverse_transform() @@ -866,31 +808,8 @@ def get_model(self) -> "DeepBaseNet": return self.net -# class MultiSegmentPredictionIntervalModel(FitAbstractModel, PredictIntervalAbstractModel, BaseMixin): -# """Class for holding specific models for multi-segment prediction which are able to build prediction intervals.""" -# -# def __init__(self): -# """Init MultiSegmentPredictionIntervalModel.""" -# self.model = None -# -# def get_model(self) -> Any: -# """Get internal model that is used inside etna class. -# -# Internal model is a model that is used inside etna to forecast segments, -# e.g. :py:class:`catboost.CatBoostRegressor` or :py:class:`sklearn.linear_model.Ridge`. -# -# Returns -# ------- -# : -# Internal model -# """ -# return self.model - - BaseModel = Union[ PerSegmentModel, - # PerSegmentPredictionIntervalModel, MultiSegmentModel, DeepBaseModel, - # MultiSegmentPredictionIntervalModel, ] diff --git a/etna/pipeline/pipeline.py b/etna/pipeline/pipeline.py index 2c7d95670..9d3540f81 100644 --- a/etna/pipeline/pipeline.py +++ b/etna/pipeline/pipeline.py @@ -2,6 +2,7 @@ from etna.datasets import TSDataset from etna.models.base import BaseModel +from etna.models.base import ContextRequiredInterface from etna.models.base import DeepBaseModel from etna.models.base import PredictionIntervalInterface from etna.pipeline.base import BasePipeline @@ -54,9 +55,14 @@ def _forecast(self) -> TSDataset: if self.ts is None: raise ValueError("Something went wrong, ts is None!") - if isinstance(self.model, DeepBaseModel): - future = self.ts.make_future(future_steps=self.model.decoder_length, tail_steps=self.model.encoder_length) - predictions = self.model.forecast(ts=future, horizon=self.horizon) + if isinstance(self.model, ContextRequiredInterface): + if isinstance(self.model, DeepBaseModel): + future = self.ts.make_future( + future_steps=self.model.decoder_length, tail_steps=self.model.encoder_length + ) + else: + future = self.ts.make_future(future_steps=self.horizon, tail_steps=self.model.context_size) + predictions = self.model.forecast(ts=future, prediction_size=self.horizon) else: future = self.ts.make_future(self.horizon) predictions = self.model.forecast(ts=future) diff --git a/tests/test_models/nn/test_rnn.py b/tests/test_models/nn/test_rnn.py index 85c8072c7..e97d1adfe 100644 --- a/tests/test_models/nn/test_rnn.py +++ b/tests/test_models/nn/test_rnn.py @@ -29,7 +29,7 @@ def test_rnn_model_run_weekly_overfit_with_scaler(ts_dataset_weekly_function_wit ) future = ts_train.make_future(decoder_length, encoder_length) model.fit(ts_train) - future = model.forecast(future, horizon=horizon) + future = model.forecast(future, 
prediction_size=horizon) mae = MAE("macro") assert mae(ts_test, future) < 0.06 From 12c901f96128fd30a1da52d4a85349b02f43192c Mon Sep 17 00:00:00 2001 From: "d.a.bunin" Date: Thu, 25 Aug 2022 18:54:05 +0300 Subject: [PATCH 03/15] Remove context_size from models that doesn't require it --- etna/models/autoarima.py | 6 ++---- etna/models/base.py | 22 ++++++++++------------ etna/models/catboost.py | 8 -------- etna/models/holt_winters.py | 6 ------ etna/models/nn/deepar.py | 6 ++---- etna/models/nn/tft.py | 6 ++---- etna/models/prophet.py | 6 ++---- etna/models/sarimax.py | 6 ++---- etna/models/sklearn.py | 4 ---- etna/models/tbats.py | 10 +++------- etna/pipeline/pipeline.py | 8 ++++---- tests/test_models/test_base.py | 4 ++-- 12 files changed, 29 insertions(+), 63 deletions(-) diff --git a/etna/models/autoarima.py b/etna/models/autoarima.py index a592a9bb7..ac8e718c3 100644 --- a/etna/models/autoarima.py +++ b/etna/models/autoarima.py @@ -7,7 +7,7 @@ # from etna.models.base import PerSegmentPredictionIntervalModel from etna.models.base import PerSegmentModel -from etna.models.base import PredictionIntervalInterface +from etna.models.base import PredictionIntervalAbstractModel from etna.models.sarimax import _SARIMAXBaseAdapter warnings.filterwarnings( @@ -50,7 +50,7 @@ def _get_fit_results(self, endog: pd.Series, exog: pd.DataFrame) -> SARIMAXResul return model.arima_res_ -class AutoARIMAModel(PerSegmentModel, PredictionIntervalInterface): +class AutoARIMAModel(PerSegmentModel, PredictionIntervalAbstractModel): """ Class for holding auto arima model. @@ -59,8 +59,6 @@ class AutoARIMAModel(PerSegmentModel, PredictionIntervalInterface): We use :py:class:`pmdarima.arima.arima.ARIMA`. """ - context_size = 0 - def __init__( self, **kwargs, diff --git a/etna/models/base.py b/etna/models/base.py index 3ee6dcb2f..ab9afa0ea 100644 --- a/etna/models/base.py +++ b/etna/models/base.py @@ -109,12 +109,6 @@ def _forecast_segment(model, segment: Union[str, List[str]], ts: TSDataset) -> p class FitAbstractModel(ABC): """Interface for model with fit method.""" - @property - @abstractmethod - def context_size(self) -> int: - """Upper bound to context size of the model.""" - pass - @abstractmethod def fit(self, ts: TSDataset) -> "FitAbstractModel": """Fit model. @@ -151,16 +145,20 @@ def get_model(self) -> Union[Any, Dict[str, Any]]: pass -class PredictionIntervalInterface(ABC): +class PredictionIntervalAbstractModel(ABC): """Interface that is used to mark classes that support prediction intervals.""" pass -class ContextRequiredInterface(ABC): +class ContextRequiredAbstractModel(ABC): """Interface that is used to mark classes that need context for prediction.""" - pass + @property + @abstractmethod + def context_size(self) -> int: + """Context size of the model. 
Determines how many history points do we ask to pass to the model.""" + pass class ForecastAbstractModel(ABC): @@ -169,7 +167,7 @@ class ForecastAbstractModel(ABC): def _extract_prediction_interval_params(self, **kwargs) -> Dict[str, Any]: extracted_params = {} - if isinstance(self, PredictionIntervalInterface): + if isinstance(self, PredictionIntervalAbstractModel): prediction_interval = kwargs.get("prediction_interval", False) extracted_params["prediction_interval"] = prediction_interval @@ -184,7 +182,7 @@ def _extract_prediction_interval_params(self, **kwargs) -> Dict[str, Any]: def _extract_prediction_size_params(self, **kwargs): extracted_params = {} - if isinstance(self, ContextRequiredInterface): + if isinstance(self, ContextRequiredAbstractModel): prediction_size = kwargs.get("prediction_size") if prediction_size is None: raise ValueError(f"Parameter prediction_size is required for {self.__class__} model!") @@ -595,7 +593,7 @@ def validation_step(self, batch: dict, *args, **kwargs): # type: ignore return loss -class DeepBaseModel(FitAbstractModel, DeepBaseAbstractModel, ContextRequiredInterface, BaseMixin): +class DeepBaseModel(FitAbstractModel, DeepBaseAbstractModel, ContextRequiredAbstractModel, BaseMixin): """Class for partially implemented interfaces for holding deep models.""" def __init__( diff --git a/etna/models/catboost.py b/etna/models/catboost.py index 5b4f10137..19337979b 100644 --- a/etna/models/catboost.py +++ b/etna/models/catboost.py @@ -131,8 +131,6 @@ class CatBoostPerSegmentModel(PerSegmentModel): 2020-04-16 8.00 6.00 2.00 0.00 """ - context_size = 0 - def __init__( self, iterations: Optional[int] = None, @@ -255,8 +253,6 @@ class CatBoostMultiSegmentModel(MultiSegmentModel): 2020-04-16 8.00 6.00 2.00 0.00 """ - context_size = 0 - def __init__( self, iterations: Optional[int] = None, @@ -387,8 +383,6 @@ class CatBoostModelPerSegment(CatBoostPerSegmentModel): 2020-04-16 8.00 6.00 2.00 0.00 """ - context_size = 0 - def __init__( self, iterations: Optional[int] = None, @@ -518,8 +512,6 @@ class CatBoostModelMultiSegment(CatBoostMultiSegmentModel): 2020-04-16 8.00 6.00 2.00 0.00 """ - context_size = 0 - def __init__( self, iterations: Optional[int] = None, diff --git a/etna/models/holt_winters.py b/etna/models/holt_winters.py index 11c520d8e..c219d610d 100644 --- a/etna/models/holt_winters.py +++ b/etna/models/holt_winters.py @@ -285,8 +285,6 @@ class HoltWintersModel(PerSegmentModel): We use :py:class:`statsmodels.tsa.holtwinters.ExponentialSmoothing` model from statsmodels package. """ - context_size = 0 - def __init__( self, trend: Optional[str] = None, @@ -484,8 +482,6 @@ class HoltModel(HoltWintersModel): as a restricted version of :py:class:`~statsmodels.tsa.holtwinters.ExponentialSmoothing` model. """ - context_size = 0 - def __init__( self, exponential: bool = False, @@ -583,8 +579,6 @@ class SimpleExpSmoothingModel(HoltWintersModel): as a restricted version of :py:class:`~statsmodels.tsa.holtwinters.ExponentialSmoothing` model. 
""" - context_size = 0 - def __init__( self, initialization_method: str = "estimated", diff --git a/etna/models/nn/deepar.py b/etna/models/nn/deepar.py index a2b887d91..f182164e1 100644 --- a/etna/models/nn/deepar.py +++ b/etna/models/nn/deepar.py @@ -13,7 +13,7 @@ from etna.loggers import tslogger from etna.models.base import FitAbstractModel from etna.models.base import ForecastAbstractModel -from etna.models.base import PredictionIntervalInterface +from etna.models.base import PredictionIntervalAbstractModel from etna.models.base import log_decorator from etna.models.nn.utils import _DeepCopyMixin from etna.transforms import PytorchForecastingTransform @@ -27,7 +27,7 @@ from pytorch_lightning import LightningModule -class DeepARModel(FitAbstractModel, ForecastAbstractModel, PredictionIntervalInterface, BaseMixin, _DeepCopyMixin): +class DeepARModel(FitAbstractModel, ForecastAbstractModel, PredictionIntervalAbstractModel, BaseMixin, _DeepCopyMixin): """Wrapper for :py:class:`pytorch_forecasting.models.deepar.DeepAR`. Notes @@ -36,8 +36,6 @@ class DeepARModel(FitAbstractModel, ForecastAbstractModel, PredictionIntervalInt It`s not right pattern of using Transforms and TSDataset. """ - context_size = 0 - def __init__( self, batch_size: int = 64, diff --git a/etna/models/nn/tft.py b/etna/models/nn/tft.py index cc55c93de..c984bd546 100644 --- a/etna/models/nn/tft.py +++ b/etna/models/nn/tft.py @@ -14,7 +14,7 @@ from etna.loggers import tslogger from etna.models.base import FitAbstractModel from etna.models.base import ForecastAbstractModel -from etna.models.base import PredictionIntervalInterface +from etna.models.base import PredictionIntervalAbstractModel from etna.models.base import log_decorator from etna.models.nn.utils import _DeepCopyMixin from etna.transforms import PytorchForecastingTransform @@ -28,7 +28,7 @@ from pytorch_lightning import LightningModule -class TFTModel(FitAbstractModel, ForecastAbstractModel, PredictionIntervalInterface, BaseMixin, _DeepCopyMixin): +class TFTModel(FitAbstractModel, ForecastAbstractModel, PredictionIntervalAbstractModel, BaseMixin, _DeepCopyMixin): """Wrapper for :py:class:`pytorch_forecasting.models.temporal_fusion_transformer.TemporalFusionTransformer`. Notes @@ -37,8 +37,6 @@ class TFTModel(FitAbstractModel, ForecastAbstractModel, PredictionIntervalInterf It`s not right pattern of using Transforms and TSDataset. """ - context_size = 0 - def __init__( self, max_epochs: int = 10, diff --git a/etna/models/prophet.py b/etna/models/prophet.py index 4ce6a024e..fc2ea6d92 100644 --- a/etna/models/prophet.py +++ b/etna/models/prophet.py @@ -11,7 +11,7 @@ from etna import SETTINGS from etna.models.base import BaseAdapter from etna.models.base import PerSegmentModel -from etna.models.base import PredictionIntervalInterface +from etna.models.base import PredictionIntervalAbstractModel if SETTINGS.prophet_required: from prophet import Prophet @@ -154,7 +154,7 @@ def get_model(self) -> Prophet: return self.model -class ProphetModel(PerSegmentModel, PredictionIntervalInterface): +class ProphetModel(PerSegmentModel, PredictionIntervalAbstractModel): """Class for holding Prophet model. 
Notes @@ -199,8 +199,6 @@ class ProphetModel(PerSegmentModel, PredictionIntervalInterface): 2020-04-16 8.00 6.00 2.00 0.00 """ - context_size = 0 - def __init__( self, growth: str = "linear", diff --git a/etna/models/sarimax.py b/etna/models/sarimax.py index 29470c2bc..d09192af1 100644 --- a/etna/models/sarimax.py +++ b/etna/models/sarimax.py @@ -14,7 +14,7 @@ from etna.libs.pmdarima_utils import seasonal_prediction_with_confidence from etna.models.base import BaseAdapter from etna.models.base import PerSegmentModel -from etna.models.base import PredictionIntervalInterface +from etna.models.base import PredictionIntervalAbstractModel from etna.models.utils import determine_num_steps warnings.filterwarnings( @@ -353,7 +353,7 @@ def _get_fit_results(self, endog: pd.Series, exog: pd.DataFrame): return result -class SARIMAXModel(PerSegmentModel, PredictionIntervalInterface): +class SARIMAXModel(PerSegmentModel, PredictionIntervalAbstractModel): """ Class for holding Sarimax model. @@ -365,8 +365,6 @@ class SARIMAXModel(PerSegmentModel, PredictionIntervalInterface): future. """ - context_size = 0 - def __init__( self, order: Tuple[int, int, int] = (2, 1, 0), diff --git a/etna/models/sklearn.py b/etna/models/sklearn.py index ea2efc90f..04d777d01 100644 --- a/etna/models/sklearn.py +++ b/etna/models/sklearn.py @@ -75,8 +75,6 @@ def get_model(self) -> RegressorMixin: class SklearnPerSegmentModel(PerSegmentModel): """Class for holding per segment Sklearn model.""" - context_size = 0 - def __init__(self, regressor: RegressorMixin): """ Create instance of SklearnPerSegmentModel with given parameters. @@ -92,8 +90,6 @@ def __init__(self, regressor: RegressorMixin): class SklearnMultiSegmentModel(MultiSegmentModel): """Class for holding Sklearn model for all segments.""" - context_size = 0 - def __init__(self, regressor: RegressorMixin): """ Create instance of SklearnMultiSegmentModel with given parameters. 
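With context_size = 0 stripped from the models above, only context-required models keep a context size, declared as an abstract property on ContextRequiredAbstractModel, and prediction_size is validated before it is forwarded. A rough sketch of that contract; ToyContextModel and validate_prediction_size are hypothetical stand-ins, and the interface is re-declared here in simplified form rather than imported from etna.models.base:

from abc import ABC, abstractmethod


class ContextRequiredAbstractModel(ABC):
    # Simplified re-declaration of the interface from the base.py hunk above.
    @property
    @abstractmethod
    def context_size(self) -> int:
        """How many history points the model asks for as context."""


class ToyContextModel(ContextRequiredAbstractModel):
    # Hypothetical model that looks 7 points back.
    @property
    def context_size(self) -> int:
        return 7


def validate_prediction_size(prediction_size) -> int:
    # Mirrors the check inside _extract_prediction_size_params.
    if not isinstance(prediction_size, int) or prediction_size <= 0:
        raise ValueError("Parameter prediction_size should be a positive integer!")
    return prediction_size


model = ToyContextModel()
horizon = validate_prediction_size(14)
# A pipeline would then build the future frame with extra history, roughly
# ts.make_future(future_steps=horizon, tail_steps=model.context_size).
print(model.context_size, horizon)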
diff --git a/etna/models/tbats.py b/etna/models/tbats.py index 4b4c09d84..cf70e65fa 100644 --- a/etna/models/tbats.py +++ b/etna/models/tbats.py @@ -11,7 +11,7 @@ from etna.models.base import BaseAdapter from etna.models.base import PerSegmentModel -from etna.models.base import PredictionIntervalInterface +from etna.models.base import PredictionIntervalAbstractModel from etna.models.utils import determine_num_steps @@ -73,11 +73,9 @@ def get_model(self) -> Estimator: return self.model -class BATSModel(PerSegmentModel, PredictionIntervalInterface): +class BATSModel(PerSegmentModel, PredictionIntervalAbstractModel): """Class for holding segment interval BATS model.""" - context_size = 0 - def __init__( self, use_box_cox: Optional[bool] = None, @@ -142,11 +140,9 @@ def __init__( super().__init__(base_model=_TBATSAdapter(self.model)) -class TBATSModel(PerSegmentModel, PredictionIntervalInterface): +class TBATSModel(PerSegmentModel, PredictionIntervalAbstractModel): """Class for holding segment interval TBATS model.""" - context_size = 0 - def __init__( self, use_box_cox: Optional[bool] = None, diff --git a/etna/pipeline/pipeline.py b/etna/pipeline/pipeline.py index 9d3540f81..bf6b1c697 100644 --- a/etna/pipeline/pipeline.py +++ b/etna/pipeline/pipeline.py @@ -2,9 +2,9 @@ from etna.datasets import TSDataset from etna.models.base import BaseModel -from etna.models.base import ContextRequiredInterface +from etna.models.base import ContextRequiredAbstractModel from etna.models.base import DeepBaseModel -from etna.models.base import PredictionIntervalInterface +from etna.models.base import PredictionIntervalAbstractModel from etna.pipeline.base import BasePipeline from etna.transforms.base import Transform @@ -55,7 +55,7 @@ def _forecast(self) -> TSDataset: if self.ts is None: raise ValueError("Something went wrong, ts is None!") - if isinstance(self.model, ContextRequiredInterface): + if isinstance(self.model, ContextRequiredAbstractModel): if isinstance(self.model, DeepBaseModel): future = self.ts.make_future( future_steps=self.model.decoder_length, tail_steps=self.model.encoder_length @@ -95,7 +95,7 @@ def forecast( self._validate_quantiles(quantiles=quantiles) self._validate_backtest_n_folds(n_folds=n_folds) - if prediction_interval and isinstance(self.model, PredictionIntervalInterface): + if prediction_interval and isinstance(self.model, PredictionIntervalAbstractModel): future = self.ts.make_future(self.horizon) predictions = self.model.forecast(ts=future, prediction_interval=prediction_interval, quantiles=quantiles) else: diff --git a/tests/test_models/test_base.py b/tests/test_models/test_base.py index 45dc2e451..9ae93dffc 100644 --- a/tests/test_models/test_base.py +++ b/tests/test_models/test_base.py @@ -141,7 +141,7 @@ def test_deep_base_model_raw_predict_call(dataloader, deep_base_model_mock): def test_deep_base_model_forecast_inverse_transform_call_check(deep_base_model_mock): ts = MagicMock() horizon = 7 - DeepBaseModel.forecast(self=deep_base_model_mock, ts=ts, horizon=horizon) + DeepBaseModel.forecast(self=deep_base_model_mock, ts=ts, prediction_size=horizon) ts.tsdataset_idx_slice.return_value.inverse_transform.assert_called_once() @@ -156,7 +156,7 @@ def test_deep_base_model_forecast_loop(simple_df, deep_base_model_mock): ts_after_tsdataset_idx_slice.df = simple_df.df.iloc[-horizon:] ts.tsdataset_idx_slice.return_value = ts_after_tsdataset_idx_slice - future = DeepBaseModel.forecast(self=deep_base_model_mock, ts=ts, horizon=horizon) + future = 
DeepBaseModel.forecast(self=deep_base_model_mock, ts=ts, prediction_size=horizon) np.testing.assert_allclose( future.df.loc[:, pd.IndexSlice["A", "target"]], raw_predict[("A", "target")][:horizon, 0] ) From e84068af0ade500a1ba2e894682c33686aff08f2 Mon Sep 17 00:00:00 2001 From: "d.a.bunin" Date: Thu, 25 Aug 2022 18:57:56 +0300 Subject: [PATCH 04/15] Fix class name in errors --- etna/models/base.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/etna/models/base.py b/etna/models/base.py index ab9afa0ea..a8127e97d 100644 --- a/etna/models/base.py +++ b/etna/models/base.py @@ -166,6 +166,7 @@ class ForecastAbstractModel(ABC): def _extract_prediction_interval_params(self, **kwargs) -> Dict[str, Any]: extracted_params = {} + class_name = self.__class__.__name__ if isinstance(self, PredictionIntervalAbstractModel): prediction_interval = kwargs.get("prediction_interval", False) @@ -175,17 +176,18 @@ def _extract_prediction_interval_params(self, **kwargs) -> Dict[str, Any]: extracted_params["quantiles"] = quantiles else: if "prediction_interval" in kwargs or "quantiles" in kwargs: - raise NotImplementedError(f"Model {self.__class__} doesn't support prediction intervals!") + raise NotImplementedError(f"Model {class_name} doesn't support prediction intervals!") return extracted_params def _extract_prediction_size_params(self, **kwargs): extracted_params = {} + class_name = self.__class__.__name__ if isinstance(self, ContextRequiredAbstractModel): prediction_size = kwargs.get("prediction_size") if prediction_size is None: - raise ValueError(f"Parameter prediction_size is required for {self.__class__} model!") + raise ValueError(f"Parameter prediction_size is required for {class_name} model!") if not isinstance(prediction_size, int) or prediction_size <= 0: raise ValueError(f"Parameter prediction_size should be positive integer!") @@ -193,7 +195,7 @@ def _extract_prediction_size_params(self, **kwargs): extracted_params = prediction_size else: if "prediction_size" in kwargs: - raise NotImplementedError(f"Model {self.__class__} doesn't support prediction_size parameter!") + raise NotImplementedError(f"Model {class_name} doesn't support prediction_size parameter!") return extracted_params From 0b6f22c1e5fa59de444a5a54094ddb3f91722778 Mon Sep 17 00:00:00 2001 From: "d.a.bunin" Date: Thu, 25 Aug 2022 19:34:42 +0300 Subject: [PATCH 05/15] Remove commented import --- etna/models/autoarima.py | 1 - 1 file changed, 1 deletion(-) diff --git a/etna/models/autoarima.py b/etna/models/autoarima.py index ac8e718c3..d94e32777 100644 --- a/etna/models/autoarima.py +++ b/etna/models/autoarima.py @@ -5,7 +5,6 @@ from statsmodels.tools.sm_exceptions import ValueWarning from statsmodels.tsa.statespace.sarimax import SARIMAXResultsWrapper -# from etna.models.base import PerSegmentPredictionIntervalModel from etna.models.base import PerSegmentModel from etna.models.base import PredictionIntervalAbstractModel from etna.models.sarimax import _SARIMAXBaseAdapter From 2c724f4bed63f943de00ef52d40f06b110ef9e1e Mon Sep 17 00:00:00 2001 From: "d.a.bunin" Date: Mon, 29 Aug 2022 12:57:42 +0300 Subject: [PATCH 06/15] Change the hierarchy of models.base --- etna/models/autoarima.py | 4 +- etna/models/base.py | 207 +++++++++++++---------- etna/models/catboost.py | 5 +- etna/models/deadline_ma.py | 3 +- etna/models/holt_winters.py | 3 +- etna/models/nn/deepar.py | 9 +- etna/models/nn/tft.py | 9 +- etna/models/prophet.py | 4 +- etna/models/sarimax.py | 4 +- etna/models/seasonal_ma.py | 3 +- etna/models/sklearn.py 
| 5 +- etna/models/tbats.py | 6 +- etna/pipeline/autoregressive_pipeline.py | 43 +++-- etna/pipeline/pipeline.py | 16 +- tests/test_models/test_base.py | 4 +- 15 files changed, 189 insertions(+), 136 deletions(-) diff --git a/etna/models/autoarima.py b/etna/models/autoarima.py index d94e32777..15725f36f 100644 --- a/etna/models/autoarima.py +++ b/etna/models/autoarima.py @@ -6,7 +6,7 @@ from statsmodels.tsa.statespace.sarimax import SARIMAXResultsWrapper from etna.models.base import PerSegmentModel -from etna.models.base import PredictionIntervalAbstractModel +from etna.models.base import PredictionIntervalContextIgnorantAbstractModel from etna.models.sarimax import _SARIMAXBaseAdapter warnings.filterwarnings( @@ -49,7 +49,7 @@ def _get_fit_results(self, endog: pd.Series, exog: pd.DataFrame) -> SARIMAXResul return model.arima_res_ -class AutoARIMAModel(PerSegmentModel, PredictionIntervalAbstractModel): +class AutoARIMAModel(PerSegmentModel, PredictionIntervalContextIgnorantAbstractModel): """ Class for holding auto arima model. diff --git a/etna/models/base.py b/etna/models/base.py index a8127e97d..077b75540 100644 --- a/etna/models/base.py +++ b/etna/models/base.py @@ -106,11 +106,11 @@ def _forecast_segment(model, segment: Union[str, List[str]], ts: TSDataset) -> p return segment_predict -class FitAbstractModel(ABC): +class AbstractModel(ABC, BaseMixin): """Interface for model with fit method.""" @abstractmethod - def fit(self, ts: TSDataset) -> "FitAbstractModel": + def fit(self, ts: TSDataset) -> "AbstractModel": """Fit model. Parameters @@ -125,6 +125,10 @@ def fit(self, ts: TSDataset) -> "FitAbstractModel": """ pass + @abstractmethod + def _forecast(self, ts: TSDataset, **kwargs): + pass + @abstractmethod def get_model(self) -> Union[Any, Dict[str, Any]]: """Get internal model/models that are used inside etna class. @@ -145,14 +149,26 @@ def get_model(self) -> Union[Any, Dict[str, Any]]: pass -class PredictionIntervalAbstractModel(ABC): - """Interface that is used to mark classes that support prediction intervals.""" +class NonPredictionIntervalAbstractModel(AbstractModel): + """Interface for models that don't support prediction intervals.""" + + pass + + +class PredictionIntervalAbstractModel(AbstractModel): + """Interface for models that support prediction intervals.""" pass -class ContextRequiredAbstractModel(ABC): - """Interface that is used to mark classes that need context for prediction.""" +class ContextIgnorantAbstractModel(AbstractModel): + """Interface for models that don't need context for prediction.""" + + pass + + +class ContextRequiredAbstractModel(AbstractModel): + """Interface for models that need context for prediction.""" @property @abstractmethod @@ -161,67 +177,116 @@ def context_size(self) -> int: pass -class ForecastAbstractModel(ABC): - """Interface for model with forecast method.""" +class NonPredictionIntervalContextIgnorantAbstractModel( + NonPredictionIntervalAbstractModel, ContextIgnorantAbstractModel +): + """Interface for models that don't support prediction intervals and don't need context for prediction.""" - def _extract_prediction_interval_params(self, **kwargs) -> Dict[str, Any]: - extracted_params = {} - class_name = self.__class__.__name__ + def forecast(self, ts: TSDataset) -> TSDataset: + """Make predictions. 
- if isinstance(self, PredictionIntervalAbstractModel): - prediction_interval = kwargs.get("prediction_interval", False) - extracted_params["prediction_interval"] = prediction_interval + Parameters + ---------- + ts: + Dataset with features - quantiles = kwargs.get("quantiles", (0.025, 0.975)) - extracted_params["quantiles"] = quantiles - else: - if "prediction_interval" in kwargs or "quantiles" in kwargs: - raise NotImplementedError(f"Model {class_name} doesn't support prediction intervals!") + Returns + ------- + : + Dataset with predictions + """ + return self._forecast(ts=ts) - return extracted_params - def _extract_prediction_size_params(self, **kwargs): - extracted_params = {} - class_name = self.__class__.__name__ +class NonPredictionIntervalContextRequiredAbstractModel( + NonPredictionIntervalAbstractModel, ContextRequiredAbstractModel +): + """Interface for models that don't support prediction intervals and need context for prediction.""" - if isinstance(self, ContextRequiredAbstractModel): - prediction_size = kwargs.get("prediction_size") - if prediction_size is None: - raise ValueError(f"Parameter prediction_size is required for {class_name} model!") + def forecast(self, ts: TSDataset, prediction_size: int) -> TSDataset: + """Make predictions. - if not isinstance(prediction_size, int) or prediction_size <= 0: - raise ValueError(f"Parameter prediction_size should be positive integer!") + Parameters + ---------- + ts: + Dataset with features + prediction_size: + Number of last timestamps to leave after making prediction. + Previous timestamps will be used as a context for models that require it. - extracted_params = prediction_size - else: - if "prediction_size" in kwargs: - raise NotImplementedError(f"Model {class_name} doesn't support prediction_size parameter!") + Returns + ------- + : + Dataset with predictions + """ + return self._forecast(ts=ts, prediction_size=prediction_size) - return extracted_params - @abstractmethod - def forecast(self, ts: TSDataset, **kwargs) -> TSDataset: +class PredictionIntervalContextIgnorantAbstractModel(NonPredictionIntervalAbstractModel, ContextIgnorantAbstractModel): + """Interface for models that support prediction intervals and don't need context for prediction.""" + + def forecast( + self, ts: TSDataset, prediction_interval: bool = False, quantiles: Sequence[float] = (0.025, 0.975) + ) -> TSDataset: """Make predictions. Parameters ---------- ts: Dataset with features + prediction_interval: + If True returns prediction interval for forecast + quantiles: + Levels of prediction distribution. By default 2.5% and 97.5% are taken to form a 95% prediction interval Returns ------- : Dataset with predictions """ - pass + return self._forecast(ts=ts, prediction_interval=prediction_interval, quantiles=quantiles) + +class PredictionIntervalContextRequiredAbstractModel(NonPredictionIntervalAbstractModel, ContextRequiredAbstractModel): + """Interface for models that support prediction intervals and need context for prediction.""" -class PerSegmentBaseModel(FitAbstractModel, BaseMixin): + def forecast( + self, + ts: TSDataset, + prediction_size: int, + prediction_interval: bool = False, + quantiles: Sequence[float] = (0.025, 0.975), + ) -> TSDataset: + """Make predictions. + + Parameters + ---------- + ts: + Dataset with features + prediction_size: + Number of last timestamps to leave after making prediction. + Previous timestamps will be used as a context for models that require it. 
+ prediction_interval: + If True returns prediction interval for forecast + quantiles: + Levels of prediction distribution. By default 2.5% and 97.5% are taken to form a 95% prediction interval + + Returns + ------- + : + Dataset with predictions + """ + return self._forecast( + ts=ts, prediction_size=prediction_size, prediction_interval=prediction_interval, quantiles=quantiles + ) + + +class PerSegmentModel(AbstractModel): """Base class for holding specific models for per-segment prediction.""" def __init__(self, base_model: Any): """ - Init PerSegmentBaseModel. + Init PerSegmentModel. Parameters ---------- @@ -232,7 +297,7 @@ def __init__(self, base_model: Any): self._models: Optional[Dict[str, Any]] = None @log_decorator - def fit(self, ts: TSDataset) -> "PerSegmentBaseModel": + def fit(self, ts: TSDataset) -> "PerSegmentModel": """Fit model. Parameters @@ -304,23 +369,8 @@ def _forecast_segment(model: Any, segment: str, ts: TSDataset, *args, **kwargs) segment_predict["timestamp"] = dates return segment_predict - -class PerSegmentModel(PerSegmentBaseModel, ForecastAbstractModel): - """Class for holding specific models for per-segment prediction.""" - - def __init__(self, base_model: Any): - """ - Init PerSegmentBaseModel. - - Parameters - ---------- - base_model: - Internal model which will be used to forecast segments, expected to have fit/predict interface - """ - super().__init__(base_model=base_model) - @log_decorator - def forecast(self, ts: TSDataset, **kwargs) -> TSDataset: + def _forecast(self, ts: TSDataset, **kwargs) -> TSDataset: """Make predictions. Parameters @@ -333,13 +383,9 @@ def forecast(self, ts: TSDataset, **kwargs) -> TSDataset: : Dataset with predictions """ - forecast_params = {} - forecast_params.update(self._extract_prediction_interval_params(**kwargs)) - forecast_params.update(self._extract_prediction_size_params(**kwargs)) - result_list = list() for segment, model in self._get_model().items(): - segment_predict = self._forecast_segment(model=model, segment=segment, ts=ts, **forecast_params) + segment_predict = self._forecast_segment(model=model, segment=segment, ts=ts, **kwargs) result_list.append(segment_predict) @@ -355,7 +401,7 @@ def forecast(self, ts: TSDataset, **kwargs) -> TSDataset: return ts -class MultiSegmentModel(FitAbstractModel, ForecastAbstractModel, BaseMixin): +class MultiSegmentModel(AbstractModel): """Class for holding specific models for per-segment prediction.""" def __init__(self, base_model: Any): @@ -390,7 +436,7 @@ def fit(self, ts: TSDataset) -> "MultiSegmentModel": return self @log_decorator - def forecast(self, ts: TSDataset, **kwargs) -> TSDataset: + def _forecast(self, ts: TSDataset, **kwargs) -> TSDataset: """Make predictions. 
Parameters @@ -403,13 +449,9 @@ def forecast(self, ts: TSDataset, **kwargs) -> TSDataset: : Dataset with predictions """ - forecast_params = {} - forecast_params.update(self._extract_prediction_interval_params(**kwargs)) - forecast_params.update(self._extract_prediction_size_params(**kwargs)) - horizon = len(ts.df) x = ts.to_pandas(flatten=True).drop(["segment"], axis=1) - y = self._base_model.predict(x, **forecast_params).reshape(-1, horizon).T + y = self._base_model.predict(x, **kwargs).reshape(-1, horizon).T ts.loc[:, pd.IndexSlice[:, "target"]] = y ts.inverse_transform() return ts @@ -491,24 +533,6 @@ def step(self, batch: dict, *args, **kwargs) -> Tuple["torch.Tensor", "torch.Ten class DeepBaseAbstractModel(ABC): """Interface for holding class of etna native deep models.""" - @abstractmethod - def forecast(self, ts: TSDataset, horizon: int) -> TSDataset: - """Make predictions. - - Parameters - ---------- - ts: - Dataset with features and expected decoder length for context - horizon: - Horizon to predict for - - Returns - ------- - : - Dataset with predictions - """ - pass - @abstractmethod def raw_fit(self, torch_dataset: "Dataset") -> "DeepBaseAbstractModel": """Fit model with torch like Dataset. @@ -595,7 +619,7 @@ def validation_step(self, batch: dict, *args, **kwargs): # type: ignore return loss -class DeepBaseModel(FitAbstractModel, DeepBaseAbstractModel, ContextRequiredAbstractModel, BaseMixin): +class DeepBaseModel(DeepBaseAbstractModel, NonPredictionIntervalContextRequiredAbstractModel): """Class for partially implemented interfaces for holding deep models.""" def __init__( @@ -766,7 +790,7 @@ def raw_predict(self, torch_dataset: "Dataset") -> Dict[Tuple[str, str], np.ndar return predictions_dict @log_decorator - def forecast(self, ts: "TSDataset", prediction_size: int) -> "TSDataset": + def _forecast(self, ts: "TSDataset", prediction_size: int) -> "TSDataset": """Make predictions. Parameters @@ -809,7 +833,8 @@ def get_model(self) -> "DeepBaseNet": BaseModel = Union[ - PerSegmentModel, - MultiSegmentModel, - DeepBaseModel, + NonPredictionIntervalContextIgnorantAbstractModel, + NonPredictionIntervalContextRequiredAbstractModel, + PredictionIntervalContextIgnorantAbstractModel, + PredictionIntervalContextRequiredAbstractModel, ] diff --git a/etna/models/catboost.py b/etna/models/catboost.py index 19337979b..49146f0aa 100644 --- a/etna/models/catboost.py +++ b/etna/models/catboost.py @@ -9,6 +9,7 @@ from etna.models.base import BaseAdapter from etna.models.base import MultiSegmentModel +from etna.models.base import NonPredictionIntervalContextIgnorantAbstractModel from etna.models.base import PerSegmentModel @@ -88,7 +89,7 @@ def get_model(self) -> CatBoostRegressor: return self.model -class CatBoostPerSegmentModel(PerSegmentModel): +class CatBoostPerSegmentModel(PerSegmentModel, NonPredictionIntervalContextIgnorantAbstractModel): """Class for holding per segment Catboost model. Examples @@ -210,7 +211,7 @@ def __init__( ) -class CatBoostMultiSegmentModel(MultiSegmentModel): +class CatBoostMultiSegmentModel(MultiSegmentModel, NonPredictionIntervalContextIgnorantAbstractModel): """Class for holding Catboost model for all segments. 
Examples diff --git a/etna/models/deadline_ma.py b/etna/models/deadline_ma.py index 989cb7b1d..16b08dd89 100644 --- a/etna/models/deadline_ma.py +++ b/etna/models/deadline_ma.py @@ -6,6 +6,7 @@ import numpy as np import pandas as pd +from etna.models.base import NonPredictionIntervalContextIgnorantAbstractModel from etna.models.base import PerSegmentModel @@ -155,7 +156,7 @@ def context_size(self) -> int: return cur_value -class DeadlineMovingAverageModel(PerSegmentModel): +class DeadlineMovingAverageModel(PerSegmentModel, NonPredictionIntervalContextIgnorantAbstractModel): """Moving average model that uses exact previous dates to predict.""" def __init__(self, window: int = 3, seasonality: str = "month"): diff --git a/etna/models/holt_winters.py b/etna/models/holt_winters.py index c219d610d..b0961aea5 100644 --- a/etna/models/holt_winters.py +++ b/etna/models/holt_winters.py @@ -13,6 +13,7 @@ from statsmodels.tsa.holtwinters import HoltWintersResults from etna.models.base import BaseAdapter +from etna.models.base import NonPredictionIntervalContextIgnorantAbstractModel from etna.models.base import PerSegmentModel @@ -276,7 +277,7 @@ def get_model(self) -> ExponentialSmoothing: return self._model -class HoltWintersModel(PerSegmentModel): +class HoltWintersModel(PerSegmentModel, NonPredictionIntervalContextIgnorantAbstractModel): """ Holt-Winters' etna model. diff --git a/etna/models/nn/deepar.py b/etna/models/nn/deepar.py index f182164e1..a40a1accb 100644 --- a/etna/models/nn/deepar.py +++ b/etna/models/nn/deepar.py @@ -8,12 +8,9 @@ import pandas as pd from etna import SETTINGS -from etna.core.mixins import BaseMixin from etna.datasets.tsdataset import TSDataset from etna.loggers import tslogger -from etna.models.base import FitAbstractModel -from etna.models.base import ForecastAbstractModel -from etna.models.base import PredictionIntervalAbstractModel +from etna.models.base import PredictionIntervalContextIgnorantAbstractModel from etna.models.base import log_decorator from etna.models.nn.utils import _DeepCopyMixin from etna.transforms import PytorchForecastingTransform @@ -27,7 +24,7 @@ from pytorch_lightning import LightningModule -class DeepARModel(FitAbstractModel, ForecastAbstractModel, PredictionIntervalAbstractModel, BaseMixin, _DeepCopyMixin): +class DeepARModel(PredictionIntervalContextIgnorantAbstractModel, _DeepCopyMixin): """Wrapper for :py:class:`pytorch_forecasting.models.deepar.DeepAR`. 
Notes @@ -173,7 +170,7 @@ def fit(self, ts: TSDataset) -> "DeepARModel": return self @log_decorator - def forecast( + def _forecast( self, ts: TSDataset, prediction_interval: bool = False, quantiles: Sequence[float] = (0.025, 0.975) ) -> TSDataset: """ diff --git a/etna/models/nn/tft.py b/etna/models/nn/tft.py index c984bd546..6e2870c71 100644 --- a/etna/models/nn/tft.py +++ b/etna/models/nn/tft.py @@ -9,12 +9,9 @@ import pandas as pd from etna import SETTINGS -from etna.core.mixins import BaseMixin from etna.datasets.tsdataset import TSDataset from etna.loggers import tslogger -from etna.models.base import FitAbstractModel -from etna.models.base import ForecastAbstractModel -from etna.models.base import PredictionIntervalAbstractModel +from etna.models.base import PredictionIntervalContextIgnorantAbstractModel from etna.models.base import log_decorator from etna.models.nn.utils import _DeepCopyMixin from etna.transforms import PytorchForecastingTransform @@ -28,7 +25,7 @@ from pytorch_lightning import LightningModule -class TFTModel(FitAbstractModel, ForecastAbstractModel, PredictionIntervalAbstractModel, BaseMixin, _DeepCopyMixin): +class TFTModel(PredictionIntervalContextIgnorantAbstractModel, _DeepCopyMixin): """Wrapper for :py:class:`pytorch_forecasting.models.temporal_fusion_transformer.TemporalFusionTransformer`. Notes @@ -180,7 +177,7 @@ def fit(self, ts: TSDataset) -> "TFTModel": return self @log_decorator - def forecast( + def _forecast( self, ts: TSDataset, prediction_interval: bool = False, quantiles: Sequence[float] = (0.025, 0.975) ) -> TSDataset: """ diff --git a/etna/models/prophet.py b/etna/models/prophet.py index fc2ea6d92..f628aed8a 100644 --- a/etna/models/prophet.py +++ b/etna/models/prophet.py @@ -11,7 +11,7 @@ from etna import SETTINGS from etna.models.base import BaseAdapter from etna.models.base import PerSegmentModel -from etna.models.base import PredictionIntervalAbstractModel +from etna.models.base import PredictionIntervalContextIgnorantAbstractModel if SETTINGS.prophet_required: from prophet import Prophet @@ -154,7 +154,7 @@ def get_model(self) -> Prophet: return self.model -class ProphetModel(PerSegmentModel, PredictionIntervalAbstractModel): +class ProphetModel(PerSegmentModel, PredictionIntervalContextIgnorantAbstractModel): """Class for holding Prophet model. Notes diff --git a/etna/models/sarimax.py b/etna/models/sarimax.py index d09192af1..d05c6082f 100644 --- a/etna/models/sarimax.py +++ b/etna/models/sarimax.py @@ -14,7 +14,7 @@ from etna.libs.pmdarima_utils import seasonal_prediction_with_confidence from etna.models.base import BaseAdapter from etna.models.base import PerSegmentModel -from etna.models.base import PredictionIntervalAbstractModel +from etna.models.base import PredictionIntervalContextIgnorantAbstractModel from etna.models.utils import determine_num_steps warnings.filterwarnings( @@ -353,7 +353,7 @@ def _get_fit_results(self, endog: pd.Series, exog: pd.DataFrame): return result -class SARIMAXModel(PerSegmentModel, PredictionIntervalAbstractModel): +class SARIMAXModel(PerSegmentModel, PredictionIntervalContextIgnorantAbstractModel): """ Class for holding Sarimax model. 
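To make the unified interface concrete, here is a minimal usage sketch (an illustration, not part of the diffs). Every model that derives from PredictionIntervalContextIgnorantAbstractModel exposes the same forecast(ts, prediction_interval, quantiles) signature, so Prophet, SARIMAX and the other interval-capable models converted in this patch are queried for intervals in the same way. The toy dataframe, the 7-step horizon and the use of TSDataset.to_dataset as the long-to-wide helper are assumptions made for the example, not taken from this patch.

import pandas as pd

from etna.datasets import TSDataset
from etna.models import ProphetModel

# toy long-format data: timestamp/segment/target is the usual etna layout
df = pd.DataFrame(
    {
        "timestamp": pd.date_range("2021-01-01", periods=100, freq="D"),
        "segment": "segment_a",
        "target": [float(i) for i in range(100)],
    }
)
ts = TSDataset(df=TSDataset.to_dataset(df), freq="D")  # to_dataset assumed as the long-to-wide helper

model = ProphetModel()
model.fit(ts)
future = ts.make_future(future_steps=7)
# the same call shape works for any interval-capable, context-ignorant model;
# (0.025, 0.975) is the documented default and forms a 95% prediction interval
forecast = model.forecast(ts=future, prediction_interval=True, quantiles=(0.025, 0.975))

This is also the call shape that Pipeline.forecast relies on further down in this series when it dispatches on the model type.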
diff --git a/etna/models/seasonal_ma.py b/etna/models/seasonal_ma.py index d8cea12ad..4d69553be 100644 --- a/etna/models/seasonal_ma.py +++ b/etna/models/seasonal_ma.py @@ -5,6 +5,7 @@ import numpy as np import pandas as pd +from etna.models.base import NonPredictionIntervalContextIgnorantAbstractModel from etna.models.base import PerSegmentModel @@ -92,7 +93,7 @@ def predict(self, df: pd.DataFrame) -> np.ndarray: return y_pred -class SeasonalMovingAverageModel(PerSegmentModel): +class SeasonalMovingAverageModel(PerSegmentModel, NonPredictionIntervalContextIgnorantAbstractModel): """ Seasonal moving average. diff --git a/etna/models/sklearn.py b/etna/models/sklearn.py index 04d777d01..d1b902e93 100644 --- a/etna/models/sklearn.py +++ b/etna/models/sklearn.py @@ -7,6 +7,7 @@ from etna.models.base import BaseAdapter from etna.models.base import MultiSegmentModel +from etna.models.base import NonPredictionIntervalContextIgnorantAbstractModel from etna.models.base import PerSegmentModel @@ -72,7 +73,7 @@ def get_model(self) -> RegressorMixin: return self.model -class SklearnPerSegmentModel(PerSegmentModel): +class SklearnPerSegmentModel(PerSegmentModel, NonPredictionIntervalContextIgnorantAbstractModel): """Class for holding per segment Sklearn model.""" def __init__(self, regressor: RegressorMixin): @@ -87,7 +88,7 @@ def __init__(self, regressor: RegressorMixin): super().__init__(base_model=_SklearnAdapter(regressor=regressor)) -class SklearnMultiSegmentModel(MultiSegmentModel): +class SklearnMultiSegmentModel(MultiSegmentModel, NonPredictionIntervalContextIgnorantAbstractModel): """Class for holding Sklearn model for all segments.""" def __init__(self, regressor: RegressorMixin): diff --git a/etna/models/tbats.py b/etna/models/tbats.py index cf70e65fa..90569f58b 100644 --- a/etna/models/tbats.py +++ b/etna/models/tbats.py @@ -11,7 +11,7 @@ from etna.models.base import BaseAdapter from etna.models.base import PerSegmentModel -from etna.models.base import PredictionIntervalAbstractModel +from etna.models.base import PredictionIntervalContextIgnorantAbstractModel from etna.models.utils import determine_num_steps @@ -73,7 +73,7 @@ def get_model(self) -> Estimator: return self.model -class BATSModel(PerSegmentModel, PredictionIntervalAbstractModel): +class BATSModel(PerSegmentModel, PredictionIntervalContextIgnorantAbstractModel): """Class for holding segment interval BATS model.""" def __init__( @@ -140,7 +140,7 @@ def __init__( super().__init__(base_model=_TBATSAdapter(self.model)) -class TBATSModel(PerSegmentModel, PredictionIntervalAbstractModel): +class TBATSModel(PerSegmentModel, PredictionIntervalContextIgnorantAbstractModel): """Class for holding segment interval TBATS model.""" def __init__( diff --git a/etna/pipeline/autoregressive_pipeline.py b/etna/pipeline/autoregressive_pipeline.py index 002836bc1..c3ad582d9 100644 --- a/etna/pipeline/autoregressive_pipeline.py +++ b/etna/pipeline/autoregressive_pipeline.py @@ -5,6 +5,8 @@ from etna.datasets import TSDataset from etna.models.base import BaseModel +from etna.models.base import NonPredictionIntervalContextRequiredAbstractModel +from etna.models.base import PredictionIntervalContextRequiredAbstractModel from etna.pipeline.base import BasePipeline from etna.transforms import Transform @@ -121,17 +123,36 @@ def _forecast(self) -> TSDataset: ) # manually set transforms in current_ts, otherwise make_future won't know about them current_ts.transforms = self.transforms - with warnings.catch_warnings(): - warnings.filterwarnings( - 
message="TSDataset freq can't be inferred", - action="ignore", - ) - warnings.filterwarnings( - message="You probably set wrong freq.", - action="ignore", - ) - current_ts_forecast = current_ts.make_future(current_step) - current_ts_future = self.model.forecast(current_ts_forecast) + + if isinstance(self.model, NonPredictionIntervalContextRequiredAbstractModel) or isinstance( + self.model, PredictionIntervalContextRequiredAbstractModel + ): + with warnings.catch_warnings(): + warnings.filterwarnings( + message="TSDataset freq can't be inferred", + action="ignore", + ) + warnings.filterwarnings( + message="You probably set wrong freq.", + action="ignore", + ) + current_ts_forecast = current_ts.make_future( + future_steps=current_step, tail_steps=self.model.context_size + ) + current_ts_future = self.model.forecast(current_ts_forecast, prediction_size=current_step) + else: + with warnings.catch_warnings(): + warnings.filterwarnings( + message="TSDataset freq can't be inferred", + action="ignore", + ) + warnings.filterwarnings( + message="You probably set wrong freq.", + action="ignore", + ) + current_ts_forecast = current_ts.make_future(future_steps=current_step) + current_ts_future = self.model.forecast(current_ts_forecast) + prediction_df = prediction_df.combine_first(current_ts_future.to_pandas()[prediction_df.columns]) # construct dataset and add all features diff --git a/etna/pipeline/pipeline.py b/etna/pipeline/pipeline.py index bf6b1c697..ed4955407 100644 --- a/etna/pipeline/pipeline.py +++ b/etna/pipeline/pipeline.py @@ -2,9 +2,10 @@ from etna.datasets import TSDataset from etna.models.base import BaseModel -from etna.models.base import ContextRequiredAbstractModel from etna.models.base import DeepBaseModel -from etna.models.base import PredictionIntervalAbstractModel +from etna.models.base import NonPredictionIntervalContextRequiredAbstractModel +from etna.models.base import PredictionIntervalContextIgnorantAbstractModel +from etna.models.base import PredictionIntervalContextRequiredAbstractModel from etna.pipeline.base import BasePipeline from etna.transforms.base import Transform @@ -55,7 +56,9 @@ def _forecast(self) -> TSDataset: if self.ts is None: raise ValueError("Something went wrong, ts is None!") - if isinstance(self.model, ContextRequiredAbstractModel): + if isinstance(self.model, NonPredictionIntervalContextRequiredAbstractModel) or isinstance( + self.model, PredictionIntervalContextRequiredAbstractModel + ): if isinstance(self.model, DeepBaseModel): future = self.ts.make_future( future_steps=self.model.decoder_length, tail_steps=self.model.encoder_length @@ -95,9 +98,14 @@ def forecast( self._validate_quantiles(quantiles=quantiles) self._validate_backtest_n_folds(n_folds=n_folds) - if prediction_interval and isinstance(self.model, PredictionIntervalAbstractModel): + if prediction_interval and isinstance(self.model, PredictionIntervalContextIgnorantAbstractModel): future = self.ts.make_future(self.horizon) predictions = self.model.forecast(ts=future, prediction_interval=prediction_interval, quantiles=quantiles) + elif prediction_interval and isinstance(self.model, PredictionIntervalContextRequiredAbstractModel): + future = self.ts.make_future(future_steps=self.horizon, tail_steps=self.model.context_size) + predictions = self.model.forecast( + ts=future, prediction_size=self.horizon, prediction_interval=prediction_interval, quantiles=quantiles + ) else: predictions = super().forecast( prediction_interval=prediction_interval, quantiles=quantiles, n_folds=n_folds diff --git 
a/tests/test_models/test_base.py b/tests/test_models/test_base.py index 9ae93dffc..d920eaa86 100644 --- a/tests/test_models/test_base.py +++ b/tests/test_models/test_base.py @@ -141,7 +141,7 @@ def test_deep_base_model_raw_predict_call(dataloader, deep_base_model_mock): def test_deep_base_model_forecast_inverse_transform_call_check(deep_base_model_mock): ts = MagicMock() horizon = 7 - DeepBaseModel.forecast(self=deep_base_model_mock, ts=ts, prediction_size=horizon) + DeepBaseModel._forecast(self=deep_base_model_mock, ts=ts, prediction_size=horizon) ts.tsdataset_idx_slice.return_value.inverse_transform.assert_called_once() @@ -156,7 +156,7 @@ def test_deep_base_model_forecast_loop(simple_df, deep_base_model_mock): ts_after_tsdataset_idx_slice.df = simple_df.df.iloc[-horizon:] ts.tsdataset_idx_slice.return_value = ts_after_tsdataset_idx_slice - future = DeepBaseModel.forecast(self=deep_base_model_mock, ts=ts, prediction_size=horizon) + future = DeepBaseModel._forecast(self=deep_base_model_mock, ts=ts, prediction_size=horizon) np.testing.assert_allclose( future.df.loc[:, pd.IndexSlice["A", "target"]], raw_predict[("A", "target")][:horizon, 0] ) From 46828f50b569c8367439d024a8b9fa7e736e5299 Mon Sep 17 00:00:00 2001 From: "d.a.bunin" Date: Wed, 31 Aug 2022 13:16:12 +0300 Subject: [PATCH 07/15] Add mixins, rearrange code --- etna/models/__init__.py | 2 +- etna/models/autoarima.py | 7 +- etna/models/base.py | 160 +++++++++++++++++++++++++-------- etna/models/catboost.py | 17 +++- etna/models/deadline_ma.py | 11 ++- etna/models/holt_winters.py | 9 +- etna/models/nn/deepar.py | 6 +- etna/models/nn/tft.py | 8 +- etna/models/prophet.py | 7 +- etna/models/sarimax.py | 7 +- etna/models/seasonal_ma.py | 11 ++- etna/models/sklearn.py | 17 +++- etna/models/tbats.py | 11 ++- tests/test_models/test_base.py | 4 +- 14 files changed, 206 insertions(+), 71 deletions(-) diff --git a/etna/models/__init__.py b/etna/models/__init__.py index 17a804329..2b2dcb9b4 100644 --- a/etna/models/__init__.py +++ b/etna/models/__init__.py @@ -3,7 +3,7 @@ from etna.models.base import BaseAdapter from etna.models.base import BaseModel from etna.models.base import Model -from etna.models.base import PerSegmentModel +from etna.models.base import PerSegmentModelMixin from etna.models.catboost import CatBoostModelMultiSegment from etna.models.catboost import CatBoostModelPerSegment from etna.models.catboost import CatBoostMultiSegmentModel diff --git a/etna/models/autoarima.py b/etna/models/autoarima.py index 15725f36f..97270d687 100644 --- a/etna/models/autoarima.py +++ b/etna/models/autoarima.py @@ -5,8 +5,9 @@ from statsmodels.tools.sm_exceptions import ValueWarning from statsmodels.tsa.statespace.sarimax import SARIMAXResultsWrapper -from etna.models.base import PerSegmentModel +from etna.models.base import PerSegmentModelMixin from etna.models.base import PredictionIntervalContextIgnorantAbstractModel +from etna.models.base import PredictionIntervalContextIgnorantModelMixin from etna.models.sarimax import _SARIMAXBaseAdapter warnings.filterwarnings( @@ -49,7 +50,9 @@ def _get_fit_results(self, endog: pd.Series, exog: pd.DataFrame) -> SARIMAXResul return model.arima_res_ -class AutoARIMAModel(PerSegmentModel, PredictionIntervalContextIgnorantAbstractModel): +class AutoARIMAModel( + PerSegmentModelMixin, PredictionIntervalContextIgnorantModelMixin, PredictionIntervalContextIgnorantAbstractModel +): """ Class for holding auto arima model. 
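The AutoARIMAModel definition above shows the composition pattern this commit introduces: a concrete model is assembled from a per-segment mixin (the fit/_forecast plumbing around one adapter per segment), a forecast-signature mixin (the public forecast() that forwards its arguments to _forecast()), and the matching abstract interface that the base.py diff below defines. A hypothetical sketch of that composition follows; the _MeanAdapter and MeanPerSegmentModel names and their internals are invented for illustration and only assume the fit/predict interface that the per-segment mixin expects from its base_model.

import numpy as np
import pandas as pd

from etna.models.base import PerSegmentModelMixin
from etna.models.base import PredictionIntervalContextIgnorantAbstractModel
from etna.models.base import PredictionIntervalContextIgnorantModelMixin


class _MeanAdapter:
    """Toy per-segment model: predicts the training mean (illustrative only)."""

    def fit(self, df: pd.DataFrame, regressors=None, **kwargs) -> "_MeanAdapter":
        self._mean = df["target"].mean()
        return self

    def predict(self, df: pd.DataFrame, **kwargs) -> np.ndarray:
        # any prediction_interval/quantiles kwargs forwarded by the mixins land here;
        # this toy adapter simply ignores them
        return np.full(len(df), self._mean)

    def get_model(self) -> "_MeanAdapter":
        return self


class MeanPerSegmentModel(
    PerSegmentModelMixin,
    PredictionIntervalContextIgnorantModelMixin,
    PredictionIntervalContextIgnorantAbstractModel,
):
    """fit()/forecast() come from the mixins; the abstract base only pins down the public contract."""

    def __init__(self):
        super().__init__(base_model=_MeanAdapter())

The same three-way split is applied to ProphetModel, SARIMAXModel and the (T)BATS models later in this commit, with the context-required and non-interval variants swapping in the corresponding mixin and abstract class.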
diff --git a/etna/models/base.py b/etna/models/base.py index 077b75540..c0d727a6d 100644 --- a/etna/models/base.py +++ b/etna/models/base.py @@ -109,6 +109,12 @@ def _forecast_segment(model, segment: Union[str, List[str]], ts: TSDataset) -> p class AbstractModel(ABC, BaseMixin): """Interface for model with fit method.""" + @property + @abstractmethod + def context_size(self) -> int: + """Context size of the model. Determines how many history points do we ask to pass to the model.""" + pass + @abstractmethod def fit(self, ts: TSDataset) -> "AbstractModel": """Fit model. @@ -125,10 +131,6 @@ def fit(self, ts: TSDataset) -> "AbstractModel": """ pass - @abstractmethod - def _forecast(self, ts: TSDataset, **kwargs): - pass - @abstractmethod def get_model(self) -> Union[Any, Dict[str, Any]]: """Get internal model/models that are used inside etna class. @@ -149,38 +151,122 @@ def get_model(self) -> Union[Any, Dict[str, Any]]: pass -class NonPredictionIntervalAbstractModel(AbstractModel): - """Interface for models that don't support prediction intervals.""" +class NonPredictionIntervalContextIgnorantAbstractModel(AbstractModel): + """Interface for models that don't support prediction intervals and don't need context for prediction.""" - pass + context_size: int = 0 + @abstractmethod + def forecast(self, ts: TSDataset) -> TSDataset: + """Make predictions. -class PredictionIntervalAbstractModel(AbstractModel): - """Interface for models that support prediction intervals.""" + Parameters + ---------- + ts: + Dataset with features + + Returns + ------- + : + Dataset with predictions + """ + pass + + +class NonPredictionIntervalContextRequiredAbstractModel(AbstractModel): + """Interface for models that don't support prediction intervals and need context for prediction.""" - pass + @abstractmethod + def forecast(self, ts: TSDataset, prediction_size: int) -> TSDataset: + """Make predictions. + Parameters + ---------- + ts: + Dataset with features + prediction_size: + Number of last timestamps to leave after making prediction. + Previous timestamps will be used as a context for models that require it. -class ContextIgnorantAbstractModel(AbstractModel): - """Interface for models that don't need context for prediction.""" + Returns + ------- + : + Dataset with predictions + """ + pass - pass +class PredictionIntervalContextIgnorantAbstractModel(AbstractModel): + """Interface for models that support prediction intervals and don't need context for prediction.""" -class ContextRequiredAbstractModel(AbstractModel): - """Interface for models that need context for prediction.""" + context_size: int = 0 - @property @abstractmethod - def context_size(self) -> int: - """Context size of the model. Determines how many history points do we ask to pass to the model.""" + def forecast( + self, ts: TSDataset, prediction_interval: bool = False, quantiles: Sequence[float] = (0.025, 0.975) + ) -> TSDataset: + """Make predictions. + + Parameters + ---------- + ts: + Dataset with features + prediction_interval: + If True returns prediction interval for forecast + quantiles: + Levels of prediction distribution. 
By default 2.5% and 97.5% are taken to form a 95% prediction interval + + Returns + ------- + : + Dataset with predictions + """ pass -class NonPredictionIntervalContextIgnorantAbstractModel( - NonPredictionIntervalAbstractModel, ContextIgnorantAbstractModel -): - """Interface for models that don't support prediction intervals and don't need context for prediction.""" +class PredictionIntervalContextRequiredAbstractModel(AbstractModel): + """Interface for models that support prediction intervals and need context for prediction.""" + + @abstractmethod + def forecast( + self, + ts: TSDataset, + prediction_size: int, + prediction_interval: bool = False, + quantiles: Sequence[float] = (0.025, 0.975), + ) -> TSDataset: + """Make predictions. + + Parameters + ---------- + ts: + Dataset with features + prediction_size: + Number of last timestamps to leave after making prediction. + Previous timestamps will be used as a context for models that require it. + prediction_interval: + If True returns prediction interval for forecast + quantiles: + Levels of prediction distribution. By default 2.5% and 97.5% are taken to form a 95% prediction interval + + Returns + ------- + : + Dataset with predictions + """ + pass + + +class ModelForecastMixin(ABC): + """Base for mixins.""" + + @abstractmethod + def _forecast(self, **kwargs) -> TSDataset: + pass + + +class NonPredictionIntervalContextIgnorantModelMixin(ModelForecastMixin): + """Mixin for models that don't support prediction intervals and don't need context for prediction.""" def forecast(self, ts: TSDataset) -> TSDataset: """Make predictions. @@ -198,10 +284,8 @@ def forecast(self, ts: TSDataset) -> TSDataset: return self._forecast(ts=ts) -class NonPredictionIntervalContextRequiredAbstractModel( - NonPredictionIntervalAbstractModel, ContextRequiredAbstractModel -): - """Interface for models that don't support prediction intervals and need context for prediction.""" +class NonPredictionIntervalContextRequiredModelMixin(ModelForecastMixin): + """Mixin for models that don't support prediction intervals and need context for prediction.""" def forecast(self, ts: TSDataset, prediction_size: int) -> TSDataset: """Make predictions. 
@@ -222,8 +306,8 @@ def forecast(self, ts: TSDataset, prediction_size: int) -> TSDataset: return self._forecast(ts=ts, prediction_size=prediction_size) -class PredictionIntervalContextIgnorantAbstractModel(NonPredictionIntervalAbstractModel, ContextIgnorantAbstractModel): - """Interface for models that support prediction intervals and don't need context for prediction.""" +class PredictionIntervalContextIgnorantModelMixin(ModelForecastMixin): + """Mixin for models that support prediction intervals and don't need context for prediction.""" def forecast( self, ts: TSDataset, prediction_interval: bool = False, quantiles: Sequence[float] = (0.025, 0.975) @@ -247,8 +331,8 @@ def forecast( return self._forecast(ts=ts, prediction_interval=prediction_interval, quantiles=quantiles) -class PredictionIntervalContextRequiredAbstractModel(NonPredictionIntervalAbstractModel, ContextRequiredAbstractModel): - """Interface for models that support prediction intervals and need context for prediction.""" +class PredictionIntervalContextRequiredModelMixin(ModelForecastMixin): + """Mixin for models that support prediction intervals and need context for prediction.""" def forecast( self, @@ -281,12 +365,12 @@ def forecast( ) -class PerSegmentModel(AbstractModel): - """Base class for holding specific models for per-segment prediction.""" +class PerSegmentModelMixin: + """Mixin for holding methods for per-segment prediction.""" def __init__(self, base_model: Any): """ - Init PerSegmentModel. + Init PerSegmentModelMixin. Parameters ---------- @@ -297,7 +381,7 @@ def __init__(self, base_model: Any): self._models: Optional[Dict[str, Any]] = None @log_decorator - def fit(self, ts: TSDataset) -> "PerSegmentModel": + def fit(self, ts: TSDataset) -> "PerSegmentModelMixin": """Fit model. Parameters @@ -401,8 +485,8 @@ def _forecast(self, ts: TSDataset, **kwargs) -> TSDataset: return ts -class MultiSegmentModel(AbstractModel): - """Class for holding specific models for per-segment prediction.""" +class MultiSegmentModelMixin: + """Mixin for holding methods for multi-segment prediction.""" def __init__(self, base_model: Any): """ @@ -416,7 +500,7 @@ def __init__(self, base_model: Any): self._base_model = base_model @log_decorator - def fit(self, ts: TSDataset) -> "MultiSegmentModel": + def fit(self, ts: TSDataset) -> "MultiSegmentModelMixin": """Fit model. Parameters @@ -790,7 +874,7 @@ def raw_predict(self, torch_dataset: "Dataset") -> Dict[Tuple[str, str], np.ndar return predictions_dict @log_decorator - def _forecast(self, ts: "TSDataset", prediction_size: int) -> "TSDataset": + def forecast(self, ts: "TSDataset", prediction_size: int) -> "TSDataset": """Make predictions. 
Parameters diff --git a/etna/models/catboost.py b/etna/models/catboost.py index 49146f0aa..2e2dd6050 100644 --- a/etna/models/catboost.py +++ b/etna/models/catboost.py @@ -8,9 +8,10 @@ from deprecated import deprecated from etna.models.base import BaseAdapter -from etna.models.base import MultiSegmentModel +from etna.models.base import MultiSegmentModelMixin from etna.models.base import NonPredictionIntervalContextIgnorantAbstractModel -from etna.models.base import PerSegmentModel +from etna.models.base import NonPredictionIntervalContextIgnorantModelMixin +from etna.models.base import PerSegmentModelMixin class _CatBoostAdapter(BaseAdapter): @@ -89,7 +90,11 @@ def get_model(self) -> CatBoostRegressor: return self.model -class CatBoostPerSegmentModel(PerSegmentModel, NonPredictionIntervalContextIgnorantAbstractModel): +class CatBoostPerSegmentModel( + PerSegmentModelMixin, + NonPredictionIntervalContextIgnorantModelMixin, + NonPredictionIntervalContextIgnorantAbstractModel, +): """Class for holding per segment Catboost model. Examples @@ -211,7 +216,11 @@ def __init__( ) -class CatBoostMultiSegmentModel(MultiSegmentModel, NonPredictionIntervalContextIgnorantAbstractModel): +class CatBoostMultiSegmentModel( + MultiSegmentModelMixin, + NonPredictionIntervalContextIgnorantModelMixin, + NonPredictionIntervalContextIgnorantAbstractModel, +): """Class for holding Catboost model for all segments. Examples diff --git a/etna/models/deadline_ma.py b/etna/models/deadline_ma.py index 16b08dd89..fe29f174d 100644 --- a/etna/models/deadline_ma.py +++ b/etna/models/deadline_ma.py @@ -7,7 +7,8 @@ import pandas as pd from etna.models.base import NonPredictionIntervalContextIgnorantAbstractModel -from etna.models.base import PerSegmentModel +from etna.models.base import NonPredictionIntervalContextIgnorantModelMixin +from etna.models.base import PerSegmentModelMixin class SeasonalityMode(Enum): @@ -156,7 +157,11 @@ def context_size(self) -> int: return cur_value -class DeadlineMovingAverageModel(PerSegmentModel, NonPredictionIntervalContextIgnorantAbstractModel): +class DeadlineMovingAverageModel( + PerSegmentModelMixin, + NonPredictionIntervalContextIgnorantModelMixin, + NonPredictionIntervalContextIgnorantAbstractModel, +): """Moving average model that uses exact previous dates to predict.""" def __init__(self, window: int = 3, seasonality: str = "month"): @@ -177,7 +182,7 @@ def __init__(self, window: int = 3, seasonality: str = "month"): ) @property - def context_size(self) -> int: + def context_size(self) -> int: # type: ignore """Upper bound to context size of the model.""" models = self.get_model() model = next(iter(models.values())) diff --git a/etna/models/holt_winters.py b/etna/models/holt_winters.py index b0961aea5..2405b4080 100644 --- a/etna/models/holt_winters.py +++ b/etna/models/holt_winters.py @@ -14,7 +14,8 @@ from etna.models.base import BaseAdapter from etna.models.base import NonPredictionIntervalContextIgnorantAbstractModel -from etna.models.base import PerSegmentModel +from etna.models.base import NonPredictionIntervalContextIgnorantModelMixin +from etna.models.base import PerSegmentModelMixin class _HoltWintersAdapter(BaseAdapter): @@ -277,7 +278,11 @@ def get_model(self) -> ExponentialSmoothing: return self._model -class HoltWintersModel(PerSegmentModel, NonPredictionIntervalContextIgnorantAbstractModel): +class HoltWintersModel( + PerSegmentModelMixin, + NonPredictionIntervalContextIgnorantModelMixin, + NonPredictionIntervalContextIgnorantAbstractModel, +): """ Holt-Winters' etna 
model. diff --git a/etna/models/nn/deepar.py b/etna/models/nn/deepar.py index a40a1accb..5f4fa2897 100644 --- a/etna/models/nn/deepar.py +++ b/etna/models/nn/deepar.py @@ -24,7 +24,7 @@ from pytorch_lightning import LightningModule -class DeepARModel(PredictionIntervalContextIgnorantAbstractModel, _DeepCopyMixin): +class DeepARModel(_DeepCopyMixin, PredictionIntervalContextIgnorantAbstractModel): """Wrapper for :py:class:`pytorch_forecasting.models.deepar.DeepAR`. Notes @@ -33,6 +33,8 @@ class DeepARModel(PredictionIntervalContextIgnorantAbstractModel, _DeepCopyMixin It`s not right pattern of using Transforms and TSDataset. """ + context_size = 0 + def __init__( self, batch_size: int = 64, @@ -170,7 +172,7 @@ def fit(self, ts: TSDataset) -> "DeepARModel": return self @log_decorator - def _forecast( + def forecast( self, ts: TSDataset, prediction_interval: bool = False, quantiles: Sequence[float] = (0.025, 0.975) ) -> TSDataset: """ diff --git a/etna/models/nn/tft.py b/etna/models/nn/tft.py index 6e2870c71..da45494b3 100644 --- a/etna/models/nn/tft.py +++ b/etna/models/nn/tft.py @@ -25,7 +25,7 @@ from pytorch_lightning import LightningModule -class TFTModel(PredictionIntervalContextIgnorantAbstractModel, _DeepCopyMixin): +class TFTModel(_DeepCopyMixin, PredictionIntervalContextIgnorantAbstractModel): """Wrapper for :py:class:`pytorch_forecasting.models.temporal_fusion_transformer.TemporalFusionTransformer`. Notes @@ -34,6 +34,8 @@ class TFTModel(PredictionIntervalContextIgnorantAbstractModel, _DeepCopyMixin): It`s not right pattern of using Transforms and TSDataset. """ + context_size = 0 + def __init__( self, max_epochs: int = 10, @@ -88,7 +90,7 @@ def __init__( quantiles_kwargs: Additional arguments for computing quantiles, look at ``to_quantiles()`` method for your loss. """ - super().__init__() + # super().__init__() if loss is None: loss = QuantileLoss() self.max_epochs = max_epochs @@ -177,7 +179,7 @@ def fit(self, ts: TSDataset) -> "TFTModel": return self @log_decorator - def _forecast( + def forecast( self, ts: TSDataset, prediction_interval: bool = False, quantiles: Sequence[float] = (0.025, 0.975) ) -> TSDataset: """ diff --git a/etna/models/prophet.py b/etna/models/prophet.py index f628aed8a..2354bda28 100644 --- a/etna/models/prophet.py +++ b/etna/models/prophet.py @@ -10,8 +10,9 @@ from etna import SETTINGS from etna.models.base import BaseAdapter -from etna.models.base import PerSegmentModel +from etna.models.base import PerSegmentModelMixin from etna.models.base import PredictionIntervalContextIgnorantAbstractModel +from etna.models.base import PredictionIntervalContextIgnorantModelMixin if SETTINGS.prophet_required: from prophet import Prophet @@ -154,7 +155,9 @@ def get_model(self) -> Prophet: return self.model -class ProphetModel(PerSegmentModel, PredictionIntervalContextIgnorantAbstractModel): +class ProphetModel( + PerSegmentModelMixin, PredictionIntervalContextIgnorantModelMixin, PredictionIntervalContextIgnorantAbstractModel +): """Class for holding Prophet model. 
Notes diff --git a/etna/models/sarimax.py b/etna/models/sarimax.py index d05c6082f..7fc5c6cec 100644 --- a/etna/models/sarimax.py +++ b/etna/models/sarimax.py @@ -13,8 +13,9 @@ from etna.libs.pmdarima_utils import seasonal_prediction_with_confidence from etna.models.base import BaseAdapter -from etna.models.base import PerSegmentModel +from etna.models.base import PerSegmentModelMixin from etna.models.base import PredictionIntervalContextIgnorantAbstractModel +from etna.models.base import PredictionIntervalContextIgnorantModelMixin from etna.models.utils import determine_num_steps warnings.filterwarnings( @@ -353,7 +354,9 @@ def _get_fit_results(self, endog: pd.Series, exog: pd.DataFrame): return result -class SARIMAXModel(PerSegmentModel, PredictionIntervalContextIgnorantAbstractModel): +class SARIMAXModel( + PerSegmentModelMixin, PredictionIntervalContextIgnorantModelMixin, PredictionIntervalContextIgnorantAbstractModel +): """ Class for holding Sarimax model. diff --git a/etna/models/seasonal_ma.py b/etna/models/seasonal_ma.py index 4d69553be..4bea401d4 100644 --- a/etna/models/seasonal_ma.py +++ b/etna/models/seasonal_ma.py @@ -6,7 +6,8 @@ import pandas as pd from etna.models.base import NonPredictionIntervalContextIgnorantAbstractModel -from etna.models.base import PerSegmentModel +from etna.models.base import NonPredictionIntervalContextIgnorantModelMixin +from etna.models.base import PerSegmentModelMixin class _SeasonalMovingAverageModel: @@ -93,7 +94,11 @@ def predict(self, df: pd.DataFrame) -> np.ndarray: return y_pred -class SeasonalMovingAverageModel(PerSegmentModel, NonPredictionIntervalContextIgnorantAbstractModel): +class SeasonalMovingAverageModel( + PerSegmentModelMixin, + NonPredictionIntervalContextIgnorantModelMixin, + NonPredictionIntervalContextIgnorantAbstractModel, +): """ Seasonal moving average. 
@@ -123,7 +128,7 @@ def __init__(self, window: int = 5, seasonality: int = 7): ) @property - def context_size(self) -> int: + def context_size(self) -> int: # type: ignore """Context size of the model.""" return self.window * self.seasonality diff --git a/etna/models/sklearn.py b/etna/models/sklearn.py index d1b902e93..6c2207da9 100644 --- a/etna/models/sklearn.py +++ b/etna/models/sklearn.py @@ -6,9 +6,10 @@ from sklearn.base import RegressorMixin from etna.models.base import BaseAdapter -from etna.models.base import MultiSegmentModel +from etna.models.base import MultiSegmentModelMixin from etna.models.base import NonPredictionIntervalContextIgnorantAbstractModel -from etna.models.base import PerSegmentModel +from etna.models.base import NonPredictionIntervalContextIgnorantModelMixin +from etna.models.base import PerSegmentModelMixin class _SklearnAdapter(BaseAdapter): @@ -73,7 +74,11 @@ def get_model(self) -> RegressorMixin: return self.model -class SklearnPerSegmentModel(PerSegmentModel, NonPredictionIntervalContextIgnorantAbstractModel): +class SklearnPerSegmentModel( + PerSegmentModelMixin, + NonPredictionIntervalContextIgnorantModelMixin, + NonPredictionIntervalContextIgnorantAbstractModel, +): """Class for holding per segment Sklearn model.""" def __init__(self, regressor: RegressorMixin): @@ -88,7 +93,11 @@ def __init__(self, regressor: RegressorMixin): super().__init__(base_model=_SklearnAdapter(regressor=regressor)) -class SklearnMultiSegmentModel(MultiSegmentModel, NonPredictionIntervalContextIgnorantAbstractModel): +class SklearnMultiSegmentModel( + MultiSegmentModelMixin, + NonPredictionIntervalContextIgnorantModelMixin, + NonPredictionIntervalContextIgnorantAbstractModel, +): """Class for holding Sklearn model for all segments.""" def __init__(self, regressor: RegressorMixin): diff --git a/etna/models/tbats.py b/etna/models/tbats.py index 90569f58b..cc4c1e15d 100644 --- a/etna/models/tbats.py +++ b/etna/models/tbats.py @@ -10,8 +10,9 @@ from tbats.tbats.Model import Model from etna.models.base import BaseAdapter -from etna.models.base import PerSegmentModel +from etna.models.base import PerSegmentModelMixin from etna.models.base import PredictionIntervalContextIgnorantAbstractModel +from etna.models.base import PredictionIntervalContextIgnorantModelMixin from etna.models.utils import determine_num_steps @@ -73,7 +74,9 @@ def get_model(self) -> Estimator: return self.model -class BATSModel(PerSegmentModel, PredictionIntervalContextIgnorantAbstractModel): +class BATSModel( + PerSegmentModelMixin, PredictionIntervalContextIgnorantModelMixin, PredictionIntervalContextIgnorantAbstractModel +): """Class for holding segment interval BATS model.""" def __init__( @@ -140,7 +143,9 @@ def __init__( super().__init__(base_model=_TBATSAdapter(self.model)) -class TBATSModel(PerSegmentModel, PredictionIntervalContextIgnorantAbstractModel): +class TBATSModel( + PerSegmentModelMixin, PredictionIntervalContextIgnorantModelMixin, PredictionIntervalContextIgnorantAbstractModel +): """Class for holding segment interval TBATS model.""" def __init__( diff --git a/tests/test_models/test_base.py b/tests/test_models/test_base.py index d920eaa86..9ae93dffc 100644 --- a/tests/test_models/test_base.py +++ b/tests/test_models/test_base.py @@ -141,7 +141,7 @@ def test_deep_base_model_raw_predict_call(dataloader, deep_base_model_mock): def test_deep_base_model_forecast_inverse_transform_call_check(deep_base_model_mock): ts = MagicMock() horizon = 7 - DeepBaseModel._forecast(self=deep_base_model_mock, ts=ts, 
prediction_size=horizon) + DeepBaseModel.forecast(self=deep_base_model_mock, ts=ts, prediction_size=horizon) ts.tsdataset_idx_slice.return_value.inverse_transform.assert_called_once() @@ -156,7 +156,7 @@ def test_deep_base_model_forecast_loop(simple_df, deep_base_model_mock): ts_after_tsdataset_idx_slice.df = simple_df.df.iloc[-horizon:] ts.tsdataset_idx_slice.return_value = ts_after_tsdataset_idx_slice - future = DeepBaseModel._forecast(self=deep_base_model_mock, ts=ts, prediction_size=horizon) + future = DeepBaseModel.forecast(self=deep_base_model_mock, ts=ts, prediction_size=horizon) np.testing.assert_allclose( future.df.loc[:, pd.IndexSlice["A", "target"]], raw_predict[("A", "target")][:horizon, 0] ) From 42420f3349eeb13decedfca544328fdc27bac0f8 Mon Sep 17 00:00:00 2001 From: "d.a.bunin" Date: Wed, 31 Aug 2022 13:27:33 +0300 Subject: [PATCH 08/15] Update inheritance of mixins --- etna/models/base.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/etna/models/base.py b/etna/models/base.py index c0d727a6d..f64b09227 100644 --- a/etna/models/base.py +++ b/etna/models/base.py @@ -258,7 +258,7 @@ def forecast( class ModelForecastMixin(ABC): - """Base for mixins.""" + """Base class for model mixins.""" @abstractmethod def _forecast(self, **kwargs) -> TSDataset: @@ -365,7 +365,7 @@ def forecast( ) -class PerSegmentModelMixin: +class PerSegmentModelMixin(ModelForecastMixin): """Mixin for holding methods for per-segment prediction.""" def __init__(self, base_model: Any): @@ -485,7 +485,7 @@ def _forecast(self, ts: TSDataset, **kwargs) -> TSDataset: return ts -class MultiSegmentModelMixin: +class MultiSegmentModelMixin(ModelForecastMixin): """Mixin for holding methods for multi-segment prediction.""" def __init__(self, base_model: Any): From 755d6a8126f581f6ad8b2a782183e459cbab5d24 Mon Sep 17 00:00:00 2001 From: "d.a.bunin" Date: Wed, 31 Aug 2022 19:14:59 +0300 Subject: [PATCH 09/15] Fix bug with rnn --- etna/models/nn/rnn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/etna/models/nn/rnn.py b/etna/models/nn/rnn.py index 61c322606..3b72324f7 100644 --- a/etna/models/nn/rnn.py +++ b/etna/models/nn/rnn.py @@ -2,10 +2,10 @@ from typing import Dict from typing import Iterator from typing import Optional -from typing import TypedDict import numpy as np import pandas as pd +from typing_extensions import TypedDict from etna import SETTINGS From d22d7ff530e6b174b07c226f1083bf044ba5c8e7 Mon Sep 17 00:00:00 2001 From: "d.a.bunin" Date: Thu, 1 Sep 2022 11:49:38 +0300 Subject: [PATCH 10/15] Fix comments on PR --- etna/models/__init__.py | 5 ++- etna/models/base.py | 12 +++++- etna/models/nn/tft.py | 2 +- etna/pipeline/assembling_pipelines.py | 4 +- etna/pipeline/autoregressive_pipeline.py | 50 +++++++++++------------- etna/pipeline/pipeline.py | 16 +++++--- 6 files changed, 50 insertions(+), 39 deletions(-) diff --git a/etna/models/__init__.py b/etna/models/__init__.py index 2b2dcb9b4..0e3198cb5 100644 --- a/etna/models/__init__.py +++ b/etna/models/__init__.py @@ -1,8 +1,11 @@ from etna import SETTINGS from etna.models.autoarima import AutoARIMAModel +from etna.models.base import AbstractModel from etna.models.base import BaseAdapter -from etna.models.base import BaseModel +from etna.models.base import ContextIgnorantModelType +from etna.models.base import ContextRequiredModelType from etna.models.base import Model +from etna.models.base import ModelType from etna.models.base import PerSegmentModelMixin from etna.models.catboost import CatBoostModelMultiSegment 
from etna.models.catboost import CatBoostModelPerSegment diff --git a/etna/models/base.py b/etna/models/base.py index f64b09227..722ae9b0e 100644 --- a/etna/models/base.py +++ b/etna/models/base.py @@ -916,9 +916,19 @@ def get_model(self) -> "DeepBaseNet": return self.net -BaseModel = Union[ +ModelType = Union[ NonPredictionIntervalContextIgnorantAbstractModel, NonPredictionIntervalContextRequiredAbstractModel, PredictionIntervalContextIgnorantAbstractModel, PredictionIntervalContextRequiredAbstractModel, ] + +ContextRequiredModelType = Union[ + NonPredictionIntervalContextRequiredAbstractModel, + PredictionIntervalContextRequiredAbstractModel, +] + +ContextIgnorantModelType = Union[ + NonPredictionIntervalContextIgnorantAbstractModel, + PredictionIntervalContextIgnorantAbstractModel, +] diff --git a/etna/models/nn/tft.py b/etna/models/nn/tft.py index da45494b3..bcfeff750 100644 --- a/etna/models/nn/tft.py +++ b/etna/models/nn/tft.py @@ -90,7 +90,7 @@ def __init__( quantiles_kwargs: Additional arguments for computing quantiles, look at ``to_quantiles()`` method for your loss. """ - # super().__init__() + super().__init__() if loss is None: loss = QuantileLoss() self.max_epochs = max_epochs diff --git a/etna/pipeline/assembling_pipelines.py b/etna/pipeline/assembling_pipelines.py index 2a7250787..ce93dd1e0 100644 --- a/etna/pipeline/assembling_pipelines.py +++ b/etna/pipeline/assembling_pipelines.py @@ -5,13 +5,13 @@ from typing import Sequence from typing import Union -from etna.models.base import BaseModel +from etna.models.base import ModelType from etna.pipeline.pipeline import Pipeline from etna.transforms import Transform def assemble_pipelines( - models: Union[BaseModel, Sequence[BaseModel]], + models: Union[ModelType, Sequence[ModelType]], transforms: Sequence[Union[Transform, Sequence[Optional[Transform]]]], horizons: Union[int, Sequence[int]], ) -> List[Pipeline]: diff --git a/etna/pipeline/autoregressive_pipeline.py b/etna/pipeline/autoregressive_pipeline.py index c3ad582d9..8debe689f 100644 --- a/etna/pipeline/autoregressive_pipeline.py +++ b/etna/pipeline/autoregressive_pipeline.py @@ -1,12 +1,14 @@ import warnings from typing import Sequence +from typing import cast import pandas as pd +from typing_extensions import get_args from etna.datasets import TSDataset -from etna.models.base import BaseModel -from etna.models.base import NonPredictionIntervalContextRequiredAbstractModel -from etna.models.base import PredictionIntervalContextRequiredAbstractModel +from etna.models.base import ContextIgnorantModelType +from etna.models.base import ContextRequiredModelType +from etna.models.base import ModelType from etna.pipeline.base import BasePipeline from etna.transforms import Transform @@ -51,7 +53,7 @@ class AutoRegressivePipeline(BasePipeline): 2020-04-16 8.00 6.00 2.00 0.00 """ - def __init__(self, model: BaseModel, horizon: int, transforms: Sequence[Transform] = (), step: int = 1): + def __init__(self, model: ModelType, horizon: int, transforms: Sequence[Transform] = (), step: int = 1): """ Create instance of AutoRegressivePipeline with given parameters. 
@@ -124,34 +126,26 @@ def _forecast(self) -> TSDataset: # manually set transforms in current_ts, otherwise make_future won't know about them current_ts.transforms = self.transforms - if isinstance(self.model, NonPredictionIntervalContextRequiredAbstractModel) or isinstance( - self.model, PredictionIntervalContextRequiredAbstractModel - ): - with warnings.catch_warnings(): - warnings.filterwarnings( - message="TSDataset freq can't be inferred", - action="ignore", - ) - warnings.filterwarnings( - message="You probably set wrong freq.", - action="ignore", - ) + with warnings.catch_warnings(): + warnings.filterwarnings( + message="TSDataset freq can't be inferred", + action="ignore", + ) + warnings.filterwarnings( + message="You probably set wrong freq.", + action="ignore", + ) + + if isinstance(self.model, get_args(ContextRequiredModelType)): + self.model = cast(ContextRequiredModelType, self.model) current_ts_forecast = current_ts.make_future( future_steps=current_step, tail_steps=self.model.context_size ) - current_ts_future = self.model.forecast(current_ts_forecast, prediction_size=current_step) - else: - with warnings.catch_warnings(): - warnings.filterwarnings( - message="TSDataset freq can't be inferred", - action="ignore", - ) - warnings.filterwarnings( - message="You probably set wrong freq.", - action="ignore", - ) + current_ts_future = self.model.forecast(current_ts_forecast, prediction_size=current_step) + else: + self.model = cast(ContextIgnorantModelType, self.model) current_ts_forecast = current_ts.make_future(future_steps=current_step) - current_ts_future = self.model.forecast(current_ts_forecast) + current_ts_future = self.model.forecast(current_ts_forecast) prediction_df = prediction_df.combine_first(current_ts_future.to_pandas()[prediction_df.columns]) diff --git a/etna/pipeline/pipeline.py b/etna/pipeline/pipeline.py index ed4955407..adb09f972 100644 --- a/etna/pipeline/pipeline.py +++ b/etna/pipeline/pipeline.py @@ -1,9 +1,13 @@ from typing import Sequence +from typing import cast + +from typing_extensions import get_args from etna.datasets import TSDataset -from etna.models.base import BaseModel +from etna.models.base import ContextIgnorantModelType +from etna.models.base import ContextRequiredModelType from etna.models.base import DeepBaseModel -from etna.models.base import NonPredictionIntervalContextRequiredAbstractModel +from etna.models.base import ModelType from etna.models.base import PredictionIntervalContextIgnorantAbstractModel from etna.models.base import PredictionIntervalContextRequiredAbstractModel from etna.pipeline.base import BasePipeline @@ -13,7 +17,7 @@ class Pipeline(BasePipeline): """Pipeline of transforms with a final estimator.""" - def __init__(self, model: BaseModel, transforms: Sequence[Transform] = (), horizon: int = 1): + def __init__(self, model: ModelType, transforms: Sequence[Transform] = (), horizon: int = 1): """ Create instance of Pipeline with given parameters. 
@@ -56,9 +60,8 @@ def _forecast(self) -> TSDataset: if self.ts is None: raise ValueError("Something went wrong, ts is None!") - if isinstance(self.model, NonPredictionIntervalContextRequiredAbstractModel) or isinstance( - self.model, PredictionIntervalContextRequiredAbstractModel - ): + if isinstance(self.model, get_args(ContextRequiredModelType)): + self.model = cast(ContextRequiredModelType, self.model) if isinstance(self.model, DeepBaseModel): future = self.ts.make_future( future_steps=self.model.decoder_length, tail_steps=self.model.encoder_length @@ -67,6 +70,7 @@ def _forecast(self) -> TSDataset: future = self.ts.make_future(future_steps=self.horizon, tail_steps=self.model.context_size) predictions = self.model.forecast(ts=future, prediction_size=self.horizon) else: + self.model = cast(ContextIgnorantModelType, self.model) future = self.ts.make_future(self.horizon) predictions = self.model.forecast(ts=future) return predictions From 2985c6b3294b0bc41b8a1a7773ca78af137a3496 Mon Sep 17 00:00:00 2001 From: "d.a.bunin" Date: Thu, 1 Sep 2022 16:28:28 +0300 Subject: [PATCH 11/15] Add tests on Pipeline, AutoregressivePipeline --- etna/pipeline/autoregressive_pipeline.py | 16 ++- etna/pipeline/pipeline.py | 4 +- .../test_autoregressive_pipeline.py | 103 ++++++++++++++++++ tests/test_pipeline/test_pipeline.py | 103 ++++++++++++++++++ 4 files changed, 219 insertions(+), 7 deletions(-) diff --git a/etna/pipeline/autoregressive_pipeline.py b/etna/pipeline/autoregressive_pipeline.py index 8debe689f..64c96d97c 100644 --- a/etna/pipeline/autoregressive_pipeline.py +++ b/etna/pipeline/autoregressive_pipeline.py @@ -8,6 +8,7 @@ from etna.datasets import TSDataset from etna.models.base import ContextIgnorantModelType from etna.models.base import ContextRequiredModelType +from etna.models.base import DeepBaseModel from etna.models.base import ModelType from etna.pipeline.base import BasePipeline from etna.transforms import Transform @@ -138,14 +139,19 @@ def _forecast(self) -> TSDataset: if isinstance(self.model, get_args(ContextRequiredModelType)): self.model = cast(ContextRequiredModelType, self.model) - current_ts_forecast = current_ts.make_future( - future_steps=current_step, tail_steps=self.model.context_size - ) - current_ts_future = self.model.forecast(current_ts_forecast, prediction_size=current_step) + if isinstance(self.model, DeepBaseModel): + current_ts_forecast = current_ts.make_future( + future_steps=self.model.decoder_length, tail_steps=self.model.encoder_length + ) + else: + current_ts_forecast = current_ts.make_future( + future_steps=current_step, tail_steps=self.model.context_size + ) + current_ts_future = self.model.forecast(ts=current_ts_forecast, prediction_size=current_step) else: self.model = cast(ContextIgnorantModelType, self.model) current_ts_forecast = current_ts.make_future(future_steps=current_step) - current_ts_future = self.model.forecast(current_ts_forecast) + current_ts_future = self.model.forecast(ts=current_ts_forecast) prediction_df = prediction_df.combine_first(current_ts_future.to_pandas()[prediction_df.columns]) diff --git a/etna/pipeline/pipeline.py b/etna/pipeline/pipeline.py index adb09f972..ec1e70ac4 100644 --- a/etna/pipeline/pipeline.py +++ b/etna/pipeline/pipeline.py @@ -71,7 +71,7 @@ def _forecast(self) -> TSDataset: predictions = self.model.forecast(ts=future, prediction_size=self.horizon) else: self.model = cast(ContextIgnorantModelType, self.model) - future = self.ts.make_future(self.horizon) + future = self.ts.make_future(future_steps=self.horizon) 
predictions = self.model.forecast(ts=future) return predictions @@ -103,7 +103,7 @@ def forecast( self._validate_backtest_n_folds(n_folds=n_folds) if prediction_interval and isinstance(self.model, PredictionIntervalContextIgnorantAbstractModel): - future = self.ts.make_future(self.horizon) + future = self.ts.make_future(future_steps=self.horizon) predictions = self.model.forecast(ts=future, prediction_interval=prediction_interval, quantiles=quantiles) elif prediction_interval and isinstance(self.model, PredictionIntervalContextRequiredAbstractModel): future = self.ts.make_future(future_steps=self.horizon, tail_steps=self.model.context_size) diff --git a/tests/test_pipeline/test_autoregressive_pipeline.py b/tests/test_pipeline/test_autoregressive_pipeline.py index 980e996ec..4e21d0ac3 100644 --- a/tests/test_pipeline/test_autoregressive_pipeline.py +++ b/tests/test_pipeline/test_autoregressive_pipeline.py @@ -1,4 +1,8 @@ from copy import deepcopy +from typing import Optional +from unittest.mock import ANY +from unittest.mock import MagicMock +from unittest.mock import patch import numpy as np import pandas as pd @@ -10,6 +14,11 @@ from etna.models import CatBoostPerSegmentModel from etna.models import LinearPerSegmentModel from etna.models import NaiveModel +from etna.models.base import DeepBaseModel +from etna.models.base import NonPredictionIntervalContextIgnorantAbstractModel +from etna.models.base import NonPredictionIntervalContextRequiredAbstractModel +from etna.models.base import PredictionIntervalContextIgnorantAbstractModel +from etna.models.base import PredictionIntervalContextRequiredAbstractModel from etna.pipeline import AutoRegressivePipeline from etna.transforms import DateFlagsTransform from etna.transforms import LagTransform @@ -26,6 +35,100 @@ def test_fit(example_tsds): pipeline.fit(example_tsds) +""" +TODO: + +Что надо потестить: +1. 
При вызове forecast логика различается для модели с контекстом и модели без контекста + * Проверяем модель без контекста + * Проверяем модель с контекстом + * Проверяем сетки +""" + + +def fake_forecast(ts: TSDataset, prediction_size: Optional[int] = None): + df = ts.to_pandas() + + df.loc[:, pd.IndexSlice[:, "target"]] = 0 + if prediction_size is not None: + df = df.iloc[-prediction_size:] + + ts.df = df + + return TSDataset(df=df, freq=ts.freq) + + +def spy_decorator(method_to_decorate): + mock = MagicMock() + + def wrapper(self, *args, **kwargs): + mock(*args, **kwargs) + return method_to_decorate(self, *args, **kwargs) + + wrapper.mock = mock + return wrapper + + +@pytest.mark.parametrize( + "model_class", [NonPredictionIntervalContextIgnorantAbstractModel, PredictionIntervalContextIgnorantAbstractModel] +) +def test_private_forecast_context_ignorant_model(model_class, example_tsds): + # we should do it this way because we want not to change behavior but have ability to introspect calls + make_future = spy_decorator(TSDataset.make_future) + model = MagicMock(spec=model_class) + model.forecast.side_effect = fake_forecast + + with patch.object(TSDataset, "make_future", make_future): + pipeline = AutoRegressivePipeline(model=model, horizon=5, step=1) + pipeline.fit(example_tsds) + _ = pipeline._forecast() + + assert make_future.mock.call_count == 5 + make_future.mock.assert_called_with(future_steps=pipeline.step) + assert model.forecast.call_count == 5 + model.forecast.assert_called_with(ts=ANY) + + +@pytest.mark.parametrize( + "model_class", [NonPredictionIntervalContextRequiredAbstractModel, PredictionIntervalContextRequiredAbstractModel] +) +def test_private_forecast_context_required_model(model_class, example_tsds): + # we should do it this way because we want not to change behavior but have ability to introspect calls + make_future = spy_decorator(TSDataset.make_future) + model = MagicMock(spec=model_class) + model.context_size = 1 + model.forecast.side_effect = fake_forecast + + with patch.object(TSDataset, "make_future", make_future): + pipeline = AutoRegressivePipeline(model=model, horizon=5, step=1) + pipeline.fit(example_tsds) + _ = pipeline._forecast() + + assert make_future.mock.call_count == 5 + make_future.mock.assert_called_with(future_steps=pipeline.step, tail_steps=model.context_size) + assert model.forecast.call_count == 5 + model.forecast.assert_called_with(ts=ANY, prediction_size=pipeline.step) + + +def test_private_forecast_deep_base_model(example_tsds): + # we should do it this way because we want not to change behavior but have ability to introspect calls + make_future = spy_decorator(TSDataset.make_future) + model = MagicMock(spec=DeepBaseModel) + model.encoder_length = 1 + model.decoder_length = 1 + model.forecast.side_effect = fake_forecast + + with patch.object(TSDataset, "make_future", make_future): + pipeline = AutoRegressivePipeline(model=model, horizon=5) + pipeline.fit(example_tsds) + _ = pipeline._forecast() + + assert make_future.mock.call_count == 5 + make_future.mock.assert_called_with(future_steps=model.decoder_length, tail_steps=model.encoder_length) + assert model.forecast.call_count == 5 + model.forecast.assert_called_with(ts=ANY, prediction_size=pipeline.step) + + def test_forecast_columns(example_reg_tsds): """Test that AutoRegressivePipeline generates all the columns.""" original_ts = deepcopy(example_reg_tsds) diff --git a/tests/test_pipeline/test_pipeline.py b/tests/test_pipeline/test_pipeline.py index d81bab272..2415ec1a3 100644 --- 
a/tests/test_pipeline/test_pipeline.py +++ b/tests/test_pipeline/test_pipeline.py @@ -2,6 +2,8 @@ from datetime import datetime from typing import Dict from typing import List +from unittest.mock import MagicMock +from unittest.mock import patch import numpy as np import pandas as pd @@ -20,6 +22,11 @@ from etna.models import NaiveModel from etna.models import ProphetModel from etna.models import SARIMAXModel +from etna.models.base import DeepBaseModel +from etna.models.base import NonPredictionIntervalContextIgnorantAbstractModel +from etna.models.base import NonPredictionIntervalContextRequiredAbstractModel +from etna.models.base import PredictionIntervalContextIgnorantAbstractModel +from etna.models.base import PredictionIntervalContextRequiredAbstractModel from etna.pipeline import FoldMask from etna.pipeline import Pipeline from etna.transforms import AddConstTransform @@ -76,6 +83,102 @@ def test_fit(example_tsds): assert np.all(original_ts.df.values == pipeline.ts.df.values) +@patch("etna.pipeline.pipeline.Pipeline._forecast") +def test_forecast_without_intervals_calls_private_forecast(private_forecast, example_tsds): + model = LinearPerSegmentModel() + transforms = [AddConstTransform(in_column="target", value=10, inplace=True), DateFlagsTransform()] + pipeline = Pipeline(model=model, transforms=transforms, horizon=5) + pipeline.fit(example_tsds) + _ = pipeline.forecast() + + private_forecast.assert_called() + + +@pytest.mark.parametrize( + "model_class", [NonPredictionIntervalContextIgnorantAbstractModel, PredictionIntervalContextIgnorantAbstractModel] +) +def test_private_forecast_context_ignorant_model(model_class): + ts = MagicMock(spec=TSDataset) + model = MagicMock(spec=model_class) + + pipeline = Pipeline(model=model, horizon=5) + pipeline.fit(ts) + _ = pipeline._forecast() + + ts.make_future.assert_called_with(future_steps=pipeline.horizon) + model.forecast.assert_called_with(ts=ts.make_future()) + + +@pytest.mark.parametrize( + "model_class", [NonPredictionIntervalContextRequiredAbstractModel, PredictionIntervalContextRequiredAbstractModel] +) +def test_private_forecast_context_required_model(model_class): + ts = MagicMock(spec=TSDataset) + model = MagicMock(spec=model_class) + + pipeline = Pipeline(model=model, horizon=5) + pipeline.fit(ts) + _ = pipeline._forecast() + + ts.make_future.assert_called_with(future_steps=pipeline.horizon, tail_steps=model.context_size) + model.forecast.assert_called_with(ts=ts.make_future(), prediction_size=pipeline.horizon) + + +def test_private_forecast_deep_base_model(): + ts = MagicMock(spec=TSDataset) + model = MagicMock(spec=DeepBaseModel) + model.encoder_length = MagicMock() + model.decoder_length = MagicMock() + + pipeline = Pipeline(model=model, horizon=5) + pipeline.fit(ts) + _ = pipeline._forecast() + + ts.make_future.assert_called_with(future_steps=model.decoder_length, tail_steps=model.encoder_length) + model.forecast.assert_called_with(ts=ts.make_future(), prediction_size=pipeline.horizon) + + +def test_forecast_with_intervals_prediction_interval_context_ignorant_model(): + ts = MagicMock(spec=TSDataset) + model = MagicMock(spec=PredictionIntervalContextIgnorantAbstractModel) + + pipeline = Pipeline(model=model, horizon=5) + pipeline.fit(ts) + _ = pipeline.forecast(prediction_interval=True, quantiles=(0.025, 0.975)) + + ts.make_future.assert_called_with(future_steps=pipeline.horizon) + model.forecast.assert_called_with(ts=ts.make_future(), prediction_interval=True, quantiles=(0.025, 0.975)) + + +def 
test_forecast_with_intervals_prediction_interval_context_required_model():
+    ts = MagicMock(spec=TSDataset)
+    model = MagicMock(spec=PredictionIntervalContextRequiredAbstractModel)
+
+    pipeline = Pipeline(model=model, horizon=5)
+    pipeline.fit(ts)
+    _ = pipeline.forecast(prediction_interval=True, quantiles=(0.025, 0.975))
+
+    ts.make_future.assert_called_with(future_steps=pipeline.horizon, tail_steps=model.context_size)
+    model.forecast.assert_called_with(
+        ts=ts.make_future(), prediction_size=pipeline.horizon, prediction_interval=True, quantiles=(0.025, 0.975)
+    )
+
+
+@patch("etna.pipeline.base.BasePipeline.forecast")
+@pytest.mark.parametrize(
+    "model_class",
+    [NonPredictionIntervalContextIgnorantAbstractModel, NonPredictionIntervalContextRequiredAbstractModel],
+)
+def test_forecast_with_intervals_other_model(base_forecast, model_class):
+    ts = MagicMock(spec=TSDataset)
+    model = MagicMock(spec=model_class)
+
+    pipeline = Pipeline(model=model, horizon=5)
+    pipeline.fit(ts)
+    _ = pipeline.forecast(prediction_interval=True, quantiles=(0.025, 0.975))
+    base_forecast.assert_called_with(prediction_interval=True, quantiles=(0.025, 0.975), n_folds=3)
+
+
 def test_forecast(example_tsds):
     """Test that the forecast from the Pipeline is correct."""
     original_ts = deepcopy(example_tsds)

From 3a67784af1b23e92c52e970b3cef284c1749e8f5 Mon Sep 17 00:00:00 2001
From: "d.a.bunin"
Date: Thu, 1 Sep 2022 16:33:02 +0300
Subject: [PATCH 12/15] Remove todos, add link to solution with patching

---
 .../test_autoregressive_pipeline.py           | 20 ++++++--------------
 1 file changed, 6 insertions(+), 14 deletions(-)

diff --git a/tests/test_pipeline/test_autoregressive_pipeline.py b/tests/test_pipeline/test_autoregressive_pipeline.py
index 4e21d0ac3..456b9ba04 100644
--- a/tests/test_pipeline/test_autoregressive_pipeline.py
+++ b/tests/test_pipeline/test_autoregressive_pipeline.py
@@ -35,17 +35,6 @@ def test_fit(example_tsds):
     pipeline.fit(example_tsds)
 
 
-"""
-TODO:
-
-What we need to test:
-1. When forecast is called, the logic differs for a model with context and a model without context
- * Check a model without context
- * Check a model with context
- * Check the neural network models
-"""
-
-
 def fake_forecast(ts: TSDataset, prediction_size: Optional[int] = None):
     df = ts.to_pandas()
 
@@ -73,7 +62,8 @@ def wrapper(self, *args, **kwargs):
     "model_class", [NonPredictionIntervalContextIgnorantAbstractModel, PredictionIntervalContextIgnorantAbstractModel]
 )
 def test_private_forecast_context_ignorant_model(model_class, example_tsds):
-    # we should do it this way because we want not to change behavior but have ability to introspect calls
+    # we should do it this way because we want not to change behavior but have ability to inspect calls
+    # source: https://stackoverflow.com/a/41599695
     make_future = spy_decorator(TSDataset.make_future)
     model = MagicMock(spec=model_class)
     model.forecast.side_effect = fake_forecast
@@ -93,7 +83,8 @@ def test_private_forecast_context_ignorant_model(model_class, example_tsds):
     "model_class", [NonPredictionIntervalContextRequiredAbstractModel, PredictionIntervalContextRequiredAbstractModel]
 )
 def test_private_forecast_context_required_model(model_class, example_tsds):
-    # we should do it this way because we want not to change behavior but have ability to introspect calls
+    # we should do it this way because we want not to change behavior but have ability to inspect calls
+    # source: https://stackoverflow.com/a/41599695
     make_future = spy_decorator(TSDataset.make_future)
     model = MagicMock(spec=model_class)
     model.context_size = 1
@@ -111,7 +102,8 @@ def test_private_forecast_context_required_model(model_class, example_tsds):
 
 
 def test_private_forecast_deep_base_model(example_tsds):
-    # we should do it this way because we want not to change behavior but have ability to introspect calls
+    # we should do it this way because we want not to change behavior but have ability to inspect calls
+    # source: https://stackoverflow.com/a/41599695
     make_future = spy_decorator(TSDataset.make_future)
     model = MagicMock(spec=DeepBaseModel)
     model.encoder_length = 1

From acb2c561a8ca2ae93c5c44ced9313d83ca7b8d74 Mon Sep 17 00:00:00 2001
From: "d.a.bunin"
Date: Fri, 2 Sep 2022 12:37:05 +0300
Subject: [PATCH 13/15] Fix comments on PR

---
 etna/models/__init__.py    |  7 ++++---
 etna/models/base.py        | 16 ++++++++++++++--
 etna/models/deadline_ma.py |  2 +-
 etna/models/seasonal_ma.py |  2 +-
 4 files changed, 20 insertions(+), 7 deletions(-)

diff --git a/etna/models/__init__.py b/etna/models/__init__.py
index 0e3198cb5..fe51d6fe1 100644
--- a/etna/models/__init__.py
+++ b/etna/models/__init__.py
@@ -1,12 +1,13 @@
 from etna import SETTINGS
 from etna.models.autoarima import AutoARIMAModel
-from etna.models.base import AbstractModel
 from etna.models.base import BaseAdapter
 from etna.models.base import ContextIgnorantModelType
 from etna.models.base import ContextRequiredModelType
-from etna.models.base import Model
 from etna.models.base import ModelType
-from etna.models.base import PerSegmentModelMixin
+from etna.models.base import NonPredictionIntervalContextIgnorantAbstractModel
+from etna.models.base import NonPredictionIntervalContextRequiredAbstractModel
+from etna.models.base import PredictionIntervalContextIgnorantAbstractModel
+from etna.models.base import PredictionIntervalContextRequiredAbstractModel
 from etna.models.catboost import CatBoostModelMultiSegment
 from etna.models.catboost import CatBoostModelPerSegment
 from etna.models.catboost import CatBoostMultiSegmentModel
diff --git a/etna/models/base.py 
b/etna/models/base.py index 722ae9b0e..f776c20d9 100644 --- a/etna/models/base.py +++ b/etna/models/base.py @@ -154,7 +154,13 @@ def get_model(self) -> Union[Any, Dict[str, Any]]: class NonPredictionIntervalContextIgnorantAbstractModel(AbstractModel): """Interface for models that don't support prediction intervals and don't need context for prediction.""" - context_size: int = 0 + @property + def context_size(self) -> int: + """Context size of the model. Determines how many history points do we ask to pass to the model. + + Zero for this model. + """ + return 0 @abstractmethod def forecast(self, ts: TSDataset) -> TSDataset: @@ -199,7 +205,13 @@ def forecast(self, ts: TSDataset, prediction_size: int) -> TSDataset: class PredictionIntervalContextIgnorantAbstractModel(AbstractModel): """Interface for models that support prediction intervals and don't need context for prediction.""" - context_size: int = 0 + @property + def context_size(self) -> int: + """Context size of the model. Determines how many history points do we ask to pass to the model. + + Zero for this model. + """ + return 0 @abstractmethod def forecast( diff --git a/etna/models/deadline_ma.py b/etna/models/deadline_ma.py index fe29f174d..c2610bbd7 100644 --- a/etna/models/deadline_ma.py +++ b/etna/models/deadline_ma.py @@ -182,7 +182,7 @@ def __init__(self, window: int = 3, seasonality: str = "month"): ) @property - def context_size(self) -> int: # type: ignore + def context_size(self) -> int: """Upper bound to context size of the model.""" models = self.get_model() model = next(iter(models.values())) diff --git a/etna/models/seasonal_ma.py b/etna/models/seasonal_ma.py index 4bea401d4..6a989062f 100644 --- a/etna/models/seasonal_ma.py +++ b/etna/models/seasonal_ma.py @@ -128,7 +128,7 @@ def __init__(self, window: int = 5, seasonality: int = 7): ) @property - def context_size(self) -> int: # type: ignore + def context_size(self) -> int: """Context size of the model.""" return self.window * self.seasonality From 0fe28c8320fd5123bd37712082ec3fab7e1ac834 Mon Sep 17 00:00:00 2001 From: "d.a.bunin" Date: Fri, 2 Sep 2022 13:00:33 +0300 Subject: [PATCH 14/15] Update changelog --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index bae302dec..e2226c08d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,7 +24,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed - - -- +- Changed hierarchy of base models, enable passing context into models ([#888](https://github.com/tinkoff-ai/etna/pull/888)) - - - Teach AutoARIMAModel to work with out-sample predictions ([#830](https://github.com/tinkoff-ai/etna/pull/830)) From bf3005783c1bbc977812462d9468d0006a675092 Mon Sep 17 00:00:00 2001 From: "d.a.bunin" Date: Fri, 2 Sep 2022 14:46:01 +0300 Subject: [PATCH 15/15] Remove nested ifs in pipelines --- etna/pipeline/autoregressive_pipeline.py | 18 +++++++++--------- etna/pipeline/pipeline.py | 12 +++++------- 2 files changed, 14 insertions(+), 16 deletions(-) diff --git a/etna/pipeline/autoregressive_pipeline.py b/etna/pipeline/autoregressive_pipeline.py index 64c96d97c..63061c171 100644 --- a/etna/pipeline/autoregressive_pipeline.py +++ b/etna/pipeline/autoregressive_pipeline.py @@ -137,16 +137,16 @@ def _forecast(self) -> TSDataset: action="ignore", ) - if isinstance(self.model, get_args(ContextRequiredModelType)): + if isinstance(self.model, DeepBaseModel): + current_ts_forecast = current_ts.make_future( + future_steps=self.model.decoder_length, 
tail_steps=self.model.encoder_length + ) + current_ts_future = self.model.forecast(ts=current_ts_forecast, prediction_size=current_step) + elif isinstance(self.model, get_args(ContextRequiredModelType)): self.model = cast(ContextRequiredModelType, self.model) - if isinstance(self.model, DeepBaseModel): - current_ts_forecast = current_ts.make_future( - future_steps=self.model.decoder_length, tail_steps=self.model.encoder_length - ) - else: - current_ts_forecast = current_ts.make_future( - future_steps=current_step, tail_steps=self.model.context_size - ) + current_ts_forecast = current_ts.make_future( + future_steps=current_step, tail_steps=self.model.context_size + ) current_ts_future = self.model.forecast(ts=current_ts_forecast, prediction_size=current_step) else: self.model = cast(ContextIgnorantModelType, self.model) diff --git a/etna/pipeline/pipeline.py b/etna/pipeline/pipeline.py index ec1e70ac4..aa06b9944 100644 --- a/etna/pipeline/pipeline.py +++ b/etna/pipeline/pipeline.py @@ -60,14 +60,12 @@ def _forecast(self) -> TSDataset: if self.ts is None: raise ValueError("Something went wrong, ts is None!") - if isinstance(self.model, get_args(ContextRequiredModelType)): + if isinstance(self.model, DeepBaseModel): + future = self.ts.make_future(future_steps=self.model.decoder_length, tail_steps=self.model.encoder_length) + predictions = self.model.forecast(ts=future, prediction_size=self.horizon) + elif isinstance(self.model, get_args(ContextRequiredModelType)): self.model = cast(ContextRequiredModelType, self.model) - if isinstance(self.model, DeepBaseModel): - future = self.ts.make_future( - future_steps=self.model.decoder_length, tail_steps=self.model.encoder_length - ) - else: - future = self.ts.make_future(future_steps=self.horizon, tail_steps=self.model.context_size) + future = self.ts.make_future(future_steps=self.horizon, tail_steps=self.model.context_size) predictions = self.model.forecast(ts=future, prediction_size=self.horizon) else: self.model = cast(ContextIgnorantModelType, self.model)
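
For reference, below is a minimal usage sketch of the behavior exercised by the tests and pipeline changes above. It is illustrative only and not part of the patch series; it assumes that SeasonalMovingAverageModel ends up as one of the context-required models after this refactoring (its context_size property above returns window * seasonality), and the toy dataset is hypothetical.

import pandas as pd

from etna.datasets import TSDataset
from etna.models import SeasonalMovingAverageModel
from etna.pipeline import Pipeline

# Toy single-segment dataset (hypothetical data), converted from long to wide format.
df = pd.DataFrame(
    {
        "timestamp": pd.date_range("2021-01-01", periods=100, freq="D"),
        "segment": "segment_a",
        "target": 1.0,
    }
)
ts = TSDataset(df=TSDataset.to_dataset(df), freq="D")

# With window=5 and seasonality=7 the context_size property shown above evaluates to 35.
model = SeasonalMovingAverageModel(window=5, seasonality=7)
pipeline = Pipeline(model=model, horizon=7)
pipeline.fit(ts)
forecast = pipeline.forecast()

# Under the new dispatch, Pipeline._forecast is expected to build the future frame via
#   future = ts.make_future(future_steps=7, tail_steps=model.context_size)
# and then call
#   model.forecast(ts=future, prediction_size=7)
# Interval-capable models would instead be driven through
#   pipeline.forecast(prediction_interval=True, quantiles=(0.025, 0.975))

This mirrors the expectations asserted in test_private_forecast_context_required_model above.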