
Enabling passing context into Model.forecast v2 #888

Merged
merged 17 commits on Sep 2, 2022
7 changes: 3 additions & 4 deletions etna/models/autoarima.py
@@ -5,7 +5,8 @@
from statsmodels.tools.sm_exceptions import ValueWarning
from statsmodels.tsa.statespace.sarimax import SARIMAXResultsWrapper

from etna.models.base import PerSegmentPredictionIntervalModel
from etna.models.base import PerSegmentModel
from etna.models.base import PredictionIntervalAbstractModel
from etna.models.sarimax import _SARIMAXBaseAdapter

warnings.filterwarnings(
@@ -48,7 +49,7 @@ def _get_fit_results(self, endog: pd.Series, exog: pd.DataFrame) -> SARIMAXResul
return model.arima_res_


class AutoARIMAModel(PerSegmentPredictionIntervalModel):
class AutoARIMAModel(PerSegmentModel, PredictionIntervalAbstractModel):
"""
Class for holding auto arima model.

@@ -57,8 +58,6 @@ class AutoARIMAModel(PerSegmentPredictionIntervalModel):
We use :py:class:`pmdarima.arima.arima.ARIMA`.
"""

context_size = 0

def __init__(
self,
**kwargs,
180 changes: 67 additions & 113 deletions etna/models/base.py
@@ -109,12 +103,6 @@ def _forecast_segment(model, segment: Union[str, List[str]], ts: TSDataset) -> p
class FitAbstractModel(ABC):
"""Interface for model with fit method."""

@property
@abstractmethod
def context_size(self) -> int:
"""Upper bound to context size of the model."""
pass

@abstractmethod
def fit(self, ts: TSDataset) -> "FitAbstractModel":
"""Fit model.
@@ -151,43 +145,68 @@ def get_model(self) -> Union[Any, Dict[str, Any]]:
pass


class PredictionIntervalAbstractModel(ABC):
"""Interface that is used to mark classes that support prediction intervals."""

pass


class ContextRequiredAbstractModel(ABC):
"""Interface that is used to mark classes that need context for prediction."""

@property
@abstractmethod
def context_size(self) -> int:
"""Context size of the model. Determines how many history points do we ask to pass to the model."""
pass


class ForecastAbstractModel(ABC):
"""Interface for model with forecast method."""

@abstractmethod
def forecast(self, ts: TSDataset) -> TSDataset:
"""Make predictions.
def _extract_prediction_interval_params(self, **kwargs) -> Dict[str, Any]:
extracted_params = {}
class_name = self.__class__.__name__

Parameters
----------
ts:
Dataset with features
if isinstance(self, PredictionIntervalAbstractModel):
prediction_interval = kwargs.get("prediction_interval", False)
extracted_params["prediction_interval"] = prediction_interval

Returns
-------
:
Dataset with predictions
"""
pass
quantiles = kwargs.get("quantiles", (0.025, 0.975))
extracted_params["quantiles"] = quantiles
else:
if "prediction_interval" in kwargs or "quantiles" in kwargs:
raise NotImplementedError(f"Model {class_name} doesn't support prediction intervals!")

return extracted_params

def _extract_prediction_size_params(self, **kwargs) -> Dict[str, Any]:
extracted_params = {}
class_name = self.__class__.__name__

class PredictIntervalAbstractModel(ABC):
"""Interface for model with forecast method that creates prediction interval."""
if isinstance(self, ContextRequiredAbstractModel):
prediction_size = kwargs.get("prediction_size")
if prediction_size is None:
raise ValueError(f"Parameter prediction_size is required for {class_name} model!")

if not isinstance(prediction_size, int) or prediction_size <= 0:
raise ValueError("Parameter prediction_size should be a positive integer!")

extracted_params["prediction_size"] = prediction_size
else:
if "prediction_size" in kwargs:
raise NotImplementedError(f"Model {class_name} doesn't support prediction_size parameter!")

return extracted_params
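
For intuition, a self-contained toy sketch of the accept/reject behaviour encoded by the two helpers above (the class below is illustrative, not the etna mixin itself): kwargs that a model cannot honour raise `NotImplementedError` instead of being silently dropped.

```python
from typing import Any, Dict


class ToyModel:
    """Toy model that supports neither prediction intervals nor context."""

    def _extract_prediction_interval_params(self, **kwargs) -> Dict[str, Any]:
        # A model without interval support must reject interval-related kwargs.
        if "prediction_interval" in kwargs or "quantiles" in kwargs:
            raise NotImplementedError(
                f"Model {self.__class__.__name__} doesn't support prediction intervals!"
            )
        return {}

    def _extract_prediction_size_params(self, **kwargs) -> Dict[str, Any]:
        # A model without context support must reject prediction_size.
        if "prediction_size" in kwargs:
            raise NotImplementedError(
                f"Model {self.__class__.__name__} doesn't support prediction_size parameter!"
            )
        return {}

    def forecast(self, **kwargs) -> None:
        forecast_params: Dict[str, Any] = {}
        forecast_params.update(self._extract_prediction_interval_params(**kwargs))
        forecast_params.update(self._extract_prediction_size_params(**kwargs))
        print(f"forecasting with {forecast_params}")


ToyModel().forecast()  # forecasting with {}
try:
    ToyModel().forecast(prediction_interval=True)
except NotImplementedError as error:
    print(error)  # Model ToyModel doesn't support prediction intervals!
```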

@abstractmethod
def forecast(
self, ts: TSDataset, prediction_interval: bool = False, quantiles: Sequence[float] = (0.025, 0.975)
) -> TSDataset:
def forecast(self, ts: TSDataset, **kwargs) -> TSDataset:
"""Make predictions.

Parameters
----------
ts:
Dataset with features
prediction_interval:
If True returns prediction interval for forecast
quantiles:
Levels of prediction distribution. By default 2.5% and 97.5% are taken to form a 95% prediction interval

Returns
-------
@@ -301,74 +320,27 @@ def __init__(self, base_model: Any):
super().__init__(base_model=base_model)

@log_decorator
def forecast(self, ts: TSDataset) -> TSDataset:
def forecast(self, ts: TSDataset, **kwargs) -> TSDataset:
"""Make predictions.

Parameters
----------
ts:
Dataframe with features
Returns
-------
:
Dataset with predictions
"""
result_list = list()
for segment, model in self._get_model().items():
segment_predict = self._forecast_segment(model=model, segment=segment, ts=ts)
result_list.append(segment_predict)

result_df = pd.concat(result_list, ignore_index=True)
result_df = result_df.set_index(["timestamp", "segment"])
df = ts.to_pandas(flatten=True)
df = df.set_index(["timestamp", "segment"])
df = df.combine_first(result_df).reset_index()

df = TSDataset.to_dataset(df)
ts.df = df
ts.inverse_transform()
return ts


class PerSegmentPredictionIntervalModel(PerSegmentBaseModel, PredictIntervalAbstractModel):
"""Class for holding specific models for per-segment prediction which are able to build prediction intervals."""

def __init__(self, base_model: Any):
"""
Init PerSegmentPredictionIntervalModel.

Parameters
----------
base_model:
Internal model which will be used to forecast segments, expected to have fit/predict interface
"""
super().__init__(base_model=base_model)

@log_decorator
def forecast(
self, ts: TSDataset, prediction_interval: bool = False, quantiles: Sequence[float] = (0.025, 0.975)
) -> TSDataset:
"""Make predictions.

Parameters
----------
ts:
Dataset with features
prediction_interval:
If True returns prediction interval for forecast
quantiles:
Levels of prediction distribution. By default 2.5% and 97.5% are taken to form a 95% prediction interval

Returns
-------
:
Dataset with predictions
"""
forecast_params = {}
forecast_params.update(self._extract_prediction_interval_params(**kwargs))
forecast_params.update(self._extract_prediction_size_params(**kwargs))

result_list = list()
for segment, model in self._get_model().items():
segment_predict = self._forecast_segment(
model=model, segment=segment, ts=ts, prediction_interval=prediction_interval, quantiles=quantiles
)
segment_predict = self._forecast_segment(model=model, segment=segment, ts=ts, **forecast_params)

result_list.append(segment_predict)

result_df = pd.concat(result_list, ignore_index=True)
@@ -418,7 +390,7 @@ def fit(self, ts: TSDataset) -> "MultiSegmentModel":
return self

@log_decorator
def forecast(self, ts: TSDataset) -> TSDataset:
def forecast(self, ts: TSDataset, **kwargs) -> TSDataset:
"""Make predictions.

Parameters
@@ -431,9 +403,13 @@ def forecast(self, ts: TSDataset) -> TSDataset:
:
Dataset with predictions
"""
forecast_params = {}
forecast_params.update(self._extract_prediction_interval_params(**kwargs))
forecast_params.update(self._extract_prediction_size_params(**kwargs))

horizon = len(ts.df)
x = ts.to_pandas(flatten=True).drop(["segment"], axis=1)
y = self._base_model.predict(x).reshape(-1, horizon).T
y = self._base_model.predict(x, **forecast_params).reshape(-1, horizon).T
ts.loc[:, pd.IndexSlice[:, "target"]] = y
ts.inverse_transform()
return ts
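
The reshape in the method above relies on the internal model returning one flat array with each segment's horizon stacked consecutively; a tiny numpy illustration with made-up values shows how that flat vector becomes the `(horizon, n_segments)` block assigned into `ts`.

```python
import numpy as np

horizon, n_segments = 3, 2
# Flat predictions as a hypothetical internal model would return them:
# first all points of segment "A", then all points of segment "B".
flat = np.array([1.0, 2.0, 3.0, 10.0, 20.0, 30.0])

wide = flat.reshape(-1, horizon).T  # shape (horizon, n_segments)
print(wide)
# [[ 1. 10.]
#  [ 2. 20.]
#  [ 3. 30.]]
```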
@@ -619,7 +595,7 @@ def validation_step(self, batch: dict, *args, **kwargs): # type: ignore
return loss


class DeepBaseModel(FitAbstractModel, DeepBaseAbstractModel, BaseMixin):
class DeepBaseModel(FitAbstractModel, DeepBaseAbstractModel, ContextRequiredAbstractModel, BaseMixin):
"""Class for partially implemented interfaces for holding deep models."""

def __init__(
@@ -790,15 +766,16 @@ def raw_predict(self, torch_dataset: "Dataset") -> Dict[Tuple[str, str], np.ndar
return predictions_dict

@log_decorator
def forecast(self, ts: "TSDataset", horizon: int) -> "TSDataset":
def forecast(self, ts: "TSDataset", prediction_size: int) -> "TSDataset":
"""Make predictions.

Parameters
----------
ts:
Dataset with features and expected decoder length for context
horizon:
Horizon to predict for
prediction_size:
Number of last timestamps to leave after making prediction.
Previous timestamps will be used as a context.

Returns
-------
@@ -812,9 +789,9 @@ def forecast(self, ts: "TSDataset", horizon: int) -> "TSDataset":
dropna=False,
)
predictions = self.raw_predict(test_dataset)
future_ts = ts.tsdataset_idx_slice(start_idx=self.encoder_length, end_idx=self.encoder_length + horizon)
future_ts = ts.tsdataset_idx_slice(start_idx=self.encoder_length, end_idx=self.encoder_length + prediction_size)
for (segment, feature_nm), value in predictions.items():
future_ts.df.loc[:, pd.IndexSlice[segment, feature_nm]] = value[:horizon, :]
future_ts.df.loc[:, pd.IndexSlice[segment, feature_nm]] = value[:prediction_size, :]

future_ts.inverse_transform()

@@ -831,31 +808,8 @@ def get_model(self) -> "DeepBaseNet":
return self.net


class MultiSegmentPredictionIntervalModel(FitAbstractModel, PredictIntervalAbstractModel, BaseMixin):
"""Class for holding specific models for multi-segment prediction which are able to build prediction intervals."""

def __init__(self):
"""Init MultiSegmentPredictionIntervalModel."""
self.model = None

def get_model(self) -> Any:
"""Get internal model that is used inside etna class.

Internal model is a model that is used inside etna to forecast segments,
e.g. :py:class:`catboost.CatBoostRegressor` or :py:class:`sklearn.linear_model.Ridge`.

Returns
-------
:
Internal model
"""
return self.model


BaseModel = Union[
PerSegmentModel,
PerSegmentPredictionIntervalModel,
MultiSegmentModel,
DeepBaseModel,
MultiSegmentPredictionIntervalModel,
]
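
With this union, calling code can build forecast kwargs from the marker interfaces instead of special-casing concrete classes. A hedged sketch of such a dispatcher follows; the helper function is hypothetical and not part of etna, while the import paths for the interfaces are the ones added in this PR (the `TSDataset` import location is assumed).

```python
from etna.datasets import TSDataset
from etna.models.base import ContextRequiredAbstractModel
from etna.models.base import PredictionIntervalAbstractModel


def forecast_any_model(model, ts: TSDataset, prediction_size: int) -> TSDataset:
    """Hypothetical helper: pass only the kwargs a given model actually accepts."""
    kwargs = {}
    if isinstance(model, ContextRequiredAbstractModel):
        # Context-required models (e.g. DeepBaseModel subclasses) expect prediction_size.
        kwargs["prediction_size"] = prediction_size
    if isinstance(model, PredictionIntervalAbstractModel):
        # Interval-capable models accept prediction_interval / quantiles.
        kwargs["prediction_interval"] = True
    return model.forecast(ts, **kwargs)
```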
8 changes: 0 additions & 8 deletions etna/models/catboost.py
@@ -131,8 +131,6 @@ class CatBoostPerSegmentModel(PerSegmentModel):
2020-04-16 8.00 6.00 2.00 0.00
"""

context_size = 0

def __init__(
self,
iterations: Optional[int] = None,
@@ -255,8 +253,6 @@ class CatBoostMultiSegmentModel(MultiSegmentModel):
2020-04-16 8.00 6.00 2.00 0.00
"""

context_size = 0

def __init__(
self,
iterations: Optional[int] = None,
@@ -387,8 +383,6 @@ class CatBoostModelPerSegment(CatBoostPerSegmentModel):
2020-04-16 8.00 6.00 2.00 0.00
"""

context_size = 0

def __init__(
self,
iterations: Optional[int] = None,
@@ -518,8 +512,6 @@ class CatBoostModelMultiSegment(CatBoostMultiSegmentModel):
2020-04-16 8.00 6.00 2.00 0.00
"""

context_size = 0

def __init__(
self,
iterations: Optional[int] = None,
6 changes: 0 additions & 6 deletions etna/models/holt_winters.py
@@ -285,8 +285,6 @@ class HoltWintersModel(PerSegmentModel):
We use :py:class:`statsmodels.tsa.holtwinters.ExponentialSmoothing` model from statsmodels package.
"""

context_size = 0

def __init__(
self,
trend: Optional[str] = None,
@@ -484,8 +482,6 @@ class HoltModel(HoltWintersModel):
as a restricted version of :py:class:`~statsmodels.tsa.holtwinters.ExponentialSmoothing` model.
"""

context_size = 0

def __init__(
self,
exponential: bool = False,
@@ -583,8 +579,6 @@ class SimpleExpSmoothingModel(HoltWintersModel):
as a restricted version of :py:class:`~statsmodels.tsa.holtwinters.ExponentialSmoothing` model.
"""

context_size = 0

def __init__(
self,
initialization_method: str = "estimated",