Skip to content

Fix behavior of SARIMAXModel if simple_differencing=True is set #837

Merged
merged 23 commits into from
Aug 9, 2022
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
-
-
-
-
- Fix behavior of SARIMAXModel if simple_differencing=True is set ([#837](https://github.com/tinkoff-ai/etna/pull/837))
-
-
-
Expand Down
1 change: 1 addition & 0 deletions etna/libs/pmdarima_utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from etna.libs.pmdarima_utils.arima import seasonal_prediction_with_confidence
153 changes: 153 additions & 0 deletions etna/libs/pmdarima_utils/arima.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
"""
MIT License

Copyright (c) 2017 Taylor G Smith

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""
# Note: Copied from pmdarima package (https://github.com/blue-yonder/tsfresh/blob/https://github.com/alkaline-ml/pmdarima/blob/v1.8.5/pmdarima/arima/arima.py)

import numpy as np
import numpy.polynomial.polynomial as np_polynomial
from sklearn.utils.validation import check_array
from pmdarima.utils import diff
from pmdarima.utils import diff_inv
from pmdarima.utils import check_endog


def ARMAtoMA(ar, ma, max_deg):
r"""
Convert ARMA coefficients to infinite MA coefficients.
Compute coefficients of MA model equivalent to given ARMA model.
MA coefficients are cut off at max_deg.
The same function as ARMAtoMA() in stats library of R
Parameters
----------
ar : array-like, shape=(n_orders,)
The array of AR coefficients.
ma : array-like, shape=(n_orders,)
The array of MA coefficients.
max_deg : int
Coefficients are computed up to the order of max_deg.
Returns
-------
np.ndarray, shape=(max_deg,)
Equivalent MA coefficients.
Notes
-----
Here is the derivation. Suppose ARMA model is defined as
.. math::
x_t - ar_1*x_{t-1} - ar_2*x_{t-2} - ... - ar_p*x_{t-p}\\
= e_t + ma_1*e_{t-1} + ma_2*e_{t-2} + ... + ma_q*e_{t-q}
namely
.. math::
(1 - \sum_{i=1}^p[ar_i*B^i]) x_t = (1 + \sum_{i=1}^q[ma_i*B^i]) e_t
where :math:`B` is a backward operator.
Equivalent MA model is
.. math::
x_t = (1 - \sum_{i=1}^p[ar_i*B^i])^{-1}\\
* (1 + \sum_{i=1}^q[ma_i*B^i]) e_t\\
= (1 + \sum_{i=1}[ema_i*B^i]) e_t
where :math:``ema_i`` is a coefficient of equivalent MA model.
The :math:``ema_i`` satisfies
.. math::
(1 - \sum_{i=1}^p[ar_i*B^i]) * (1 + \sum_{i=1}[ema_i*B^i]) \\
= 1 + \sum_{i=1}^q[ma_i*B^i]
thus
.. math::
\sum_{i=1}[ema_i*B^i] = \sum_{i=1}^p[ar_i*B^i] \\
+ \sum_{i=1}^p[ar_i*B^i] * \sum_{j=1}[ema_j*B^j] \\
+ \Sum_{i=1}^q[ma_i*B^i]
therefore
.. math::
ema_i = ar_i (but 0 if i>p) \\
+ \Sum_{j=1}^{min(i-1,p)}[ar_j*ema_{i-j}] + ma_i(but 0 if i>q) \\
= \sum_{j=1}{min(i,p)}[ar_j*ema_{i-j}(but 1 if j=i)] \\
+ ma_i(but 0 if i>q)
"""
p = len(ar)
q = len(ma)
ema = np.empty(max_deg)
for i in range(0, max_deg):
temp = ma[i] if i < q else 0.0
for j in range(0, min(i + 1, p)):
temp += ar[j] * (ema[i - j - 1] if i - j - 1 >= 0 else 1.0)
ema[i] = temp
return ema


# Note: Originally copied from pmdarima package (https://github.com/blue-yonder/tsfresh/blob/https://github.com/alkaline-ml/pmdarima/blob/v1.8.5/pmdarima/arima/arima.py)
def seasonal_prediction_with_confidence(arima_res,
start,
end,
X,
alpha,
**kwargs):
"""Compute the prediction for a SARIMAX and get a conf interval

Unfortunately, SARIMAX does not really provide a nice way to get the
confidence intervals out of the box, so we have to perform the
``get_prediction`` code here and unpack the confidence intervals manually.
"""
results = arima_res.get_prediction(
start=start,
end=end,
exog=X,
**kwargs)

f = results.predicted_mean
conf_int = results.conf_int(alpha=alpha)
if arima_res.specification['simple_differencing']:
# If simple_differencing == True, statsmodels.get_prediction returns
# mid and confidence intervals on differenced time series.
# We have to invert differencing the mid and confidence intervals
y_org = arima_res.model.orig_endog
d = arima_res.model.orig_k_diff
D = arima_res.model.orig_k_seasonal_diff
period = arima_res.model.seasonal_periods
# Forecast mid: undifferencing non-seasonal part
if d > 0:
y_sdiff = y_org if D == 0 else diff(y_org, period, D)
f_temp = np.append(y_sdiff[-d:], f)
f_temp = diff_inv(f_temp, 1, d)
f = f_temp[(2 * d):]
# Forecast mid: undifferencing seasonal part
if D > 0 and period > 1:
f_temp = np.append(y_org[-(D * period):], f)
f_temp = diff_inv(f_temp, period, D)
f = f_temp[(2 * D * period):]
# confidence interval
ar_poly = arima_res.polynomial_reduced_ar
poly_diff = np_polynomial.polypow(np.array([1., -1.]), d)
sdiff = np.zeros(period + 1)
sdiff[0] = 1.
sdiff[-1] = 1.
poly_sdiff = np_polynomial.polypow(sdiff, D)
ar = -np.polymul(ar_poly, np.polymul(poly_diff, poly_sdiff))[1:]
ma = arima_res.polynomial_reduced_ma[1:]
n_predMinus1 = end - start
ema = ARMAtoMA(ar, ma, n_predMinus1)
sigma2 = arima_res._params_variance[0]
var = np.cumsum(np.append(1., ema * ema)) * sigma2
q = results.dist.ppf(1. - alpha / 2, *results.dist_args)
conf_int[:, 0] = f - q * np.sqrt(var)
conf_int[:, 1] = f + q * np.sqrt(var)

return check_endog(f, dtype=None, copy=False), \
check_array(conf_int, copy=False, dtype=None)
50 changes: 34 additions & 16 deletions etna/models/sarimax.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@
from statsmodels.tools.sm_exceptions import ValueWarning
from statsmodels.tsa.statespace.sarimax import SARIMAX

from etna.libs.pmdarima_utils import seasonal_prediction_with_confidence
from etna.models.base import BaseAdapter
from etna.models.base import PerSegmentPredictionIntervalModel
from etna.models.utils import determine_num_steps

warnings.filterwarnings(
message="No frequency information was provided, so inferred frequency .* will be used",
Expand Down Expand Up @@ -164,6 +166,8 @@ def __init__(
self._model: Optional[SARIMAX] = None
self._result: Optional[SARIMAX] = None
self.regressor_columns: Optional[List[str]] = None
self._freq = None
self._first_train_timestamp = None

def fit(self, df: pd.DataFrame, regressors: List[str]) -> "_SARIMAXAdapter":
"""
Expand Down Expand Up @@ -193,8 +197,8 @@ def fit(self, df: pd.DataFrame, regressors: List[str]) -> "_SARIMAXAdapter":

self._check_df(df)

targets = df["target"]
targets.index = df["timestamp"]
# make it a numpy array for forgetting about indices, it is necessary for _seasonal_prediction_with_confidence
targets = df["target"].values

exog_train = self._select_regressors(df)

Expand All @@ -221,6 +225,14 @@ def fit(self, df: pd.DataFrame, regressors: List[str]) -> "_SARIMAXAdapter":
**self.kwargs,
)
self._result = self._model.fit()

freq = pd.infer_freq(df["timestamp"], warn=False)
if freq is None:
raise ValueError("Can't determine frequency of a given dataframe")
self._freq = freq

self._first_train_timestamp = df["timestamp"].min()

return self

def predict(self, df: pd.DataFrame, prediction_interval: bool, quantiles: Sequence[float]) -> pd.DataFrame:
Expand Down Expand Up @@ -256,30 +268,36 @@ def predict(self, df: pd.DataFrame, prediction_interval: bool, quantiles: Sequen
)

exog_future = self._select_regressors(df)
start_timestamp = df["timestamp"].min()
end_timestamp = df["timestamp"].max()
start_idx = determine_num_steps(
start_timestamp=self._first_train_timestamp, end_timestamp=start_timestamp, freq=self._freq # type: ignore
)
end_idx = determine_num_steps(
start_timestamp=self._first_train_timestamp, end_timestamp=end_timestamp, freq=self._freq # type: ignore
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

May we should add assert there, something like
end_idx-start_idx == len(df), "Check that total number of steps to forecast is equal to total ts lenght and so so on

It's not obvious how number of steps has become index

if prediction_interval:
forecast = self._result.get_prediction(
start=df["timestamp"].min(), end=df["timestamp"].max(), dynamic=False, exog=exog_future
forecast, _ = seasonal_prediction_with_confidence(
arima_res=self._result, start=start_idx, end=end_idx, X=exog_future, alpha=0.05
)
y_pred = forecast.predicted_mean
y_pred.name = "mean"
y_pred = pd.DataFrame(y_pred)
y_pred = pd.DataFrame({"mean": forecast})
for quantile in quantiles:
# set alpha in the way to get a desirable quantile
alpha = min(quantile * 2, (1 - quantile) * 2)
borders = forecast.conf_int(alpha=alpha)
_, borders = seasonal_prediction_with_confidence(
arima_res=self._result, start=start_idx, end=end_idx, X=exog_future, alpha=alpha
)
if quantile < 1 / 2:
series = borders["lower target"]
series = borders[:, 0]
else:
series = borders["upper target"]
series = borders[:, 1]
y_pred[f"mean_{quantile:.4g}"] = series
else:
forecast = self._result.get_prediction(
start=df["timestamp"].min(), end=df["timestamp"].max(), dynamic=False, exog=exog_future
forecast, _ = seasonal_prediction_with_confidence(
arima_res=self._result, start=start_idx, end=end_idx, X=exog_future, alpha=0.05
)
y_pred = forecast.predicted_mean
y_pred.name = "mean"
y_pred = pd.DataFrame(y_pred)
y_pred = y_pred.reset_index(drop=True, inplace=False)
y_pred = pd.DataFrame({"mean": forecast})

rename_dict = {
column: column.replace("mean", "target") for column in y_pred.columns if column.startswith("mean")
}
Expand Down
6 changes: 3 additions & 3 deletions etna/models/tbats.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from etna.models.base import BaseAdapter
from etna.models.base import PerSegmentPredictionIntervalModel
from etna.models.utils import determine_num_steps_to_forecast
from etna.models.utils import determine_num_steps


class _TBATSAdapter(BaseAdapter):
Expand Down Expand Up @@ -43,8 +43,8 @@ def predict(self, df: pd.DataFrame, prediction_interval: bool, quantiles: Iterab
"In-sample predictions aren't supported by current implementation."
)

steps_to_forecast = determine_num_steps_to_forecast(
last_train_timestamp=self._last_train_timestamp, last_test_timestamp=df["timestamp"].max(), freq=self._freq
steps_to_forecast = determine_num_steps(
start_timestamp=self._last_train_timestamp, end_timestamp=df["timestamp"].max(), freq=self._freq
)
steps_to_skip = steps_to_forecast - df.shape[0]

Expand Down
47 changes: 23 additions & 24 deletions etna/models/utils.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,15 @@
import pandas as pd


def determine_num_steps_to_forecast(
last_train_timestamp: pd.Timestamp, last_test_timestamp: pd.Timestamp, freq: str
) -> int:
"""Determine number of steps to make a forecast in future.

It is useful for out-sample forecast with gap if model predicts only on a certain number of steps
in autoregressive manner.
def determine_num_steps(start_timestamp: pd.Timestamp, end_timestamp: pd.Timestamp, freq: str) -> int:
"""Determine how many steps of ``freq`` should we make from ``start_timestamp`` to reach ``end_timestamp``.

Parameters
----------
last_train_timestamp:
last timestamp in train data
last_test_timestamp:
last timestamp in test data, should be after ``last_train_timestamp``
start_timestamp:
timestamp to start counting from
end_timestamp:
timestamp to end counting, should be not earlier than ``start_timestamp``
freq:
pandas frequency string: `Offset aliases <https://pandas.pydata.org/pandas-docs/stable/user_guide/timeseries.html#offset-aliases>`_

Expand All @@ -26,26 +21,30 @@ def determine_num_steps_to_forecast(
Raises
------
ValueError:
Value of last test timestamp is less or equal than last train timestamp
Value of end timestamp is less than start timestamp
ValueError:
Last train timestamp isn't correct according to a given frequency
Start timestamp isn't correct according to a given frequency
ValueError:
Last test timestamps isn't reachable with a given frequency
End timestamp isn't reachable with a given frequency
"""
if last_test_timestamp <= last_train_timestamp:
raise ValueError("Last train timestamp should be less than last test timestamp!")
if start_timestamp > end_timestamp:
raise ValueError("Start train timestamp should be less or equal than end timestamp!")

# check if start_timestamp is normalized
normalized_start_timestamp = pd.date_range(start=start_timestamp, periods=1, freq=freq)
if normalized_start_timestamp != start_timestamp:
raise ValueError(f"Start timestamp isn't correct according to given frequency: {freq}")

# check if last_train_timestamp is normalized
normalized_last_train_timestamp = pd.date_range(start=last_train_timestamp, periods=1, freq=freq)
if normalized_last_train_timestamp != last_train_timestamp:
raise ValueError(f"Last train timestamp isn't correct according to given frequency: {freq}")
# check a simple case
if start_timestamp == end_timestamp:
return 0

# make linear probing, because for complex offsets there is a cycle in `pd.date_range`
cur_value = 1
while True:
timestamps = pd.date_range(start=last_train_timestamp, periods=cur_value + 1, freq=freq)
if timestamps[-1] == last_test_timestamp:
timestamps = pd.date_range(start=start_timestamp, periods=cur_value + 1, freq=freq)
if timestamps[-1] == end_timestamp:
return cur_value
elif timestamps[-1] > last_test_timestamp:
raise ValueError(f"Last test timestamps isn't reachable with freq: {freq}")
elif timestamps[-1] > end_timestamp:
raise ValueError(f"End timestamp isn't reachable with freq: {freq}")
cur_value += 1
30 changes: 30 additions & 0 deletions tests/test_models/test_sarimax_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import numpy as np
import pytest
from statsmodels.tsa.statespace.sarimax import SARIMAX

from etna.datasets import TSDataset
from etna.datasets import generate_ar_df
from etna.models import SARIMAXModel
from etna.pipeline import Pipeline

Expand Down Expand Up @@ -134,3 +137,30 @@ def test_sarimax_forecast_1_point(example_tsds):
assert len(pred.df) == horizon
pred_quantiles = model.forecast(future_ts, prediction_interval=True, quantiles=[0.025, 0.8])
assert len(pred_quantiles.df) == horizon


def test_prediction_simple_differencing():
"""Check that SARIMAX gives similar results with different values of ``simple_differencing``.

We generate dataset from ``generate_ar_df`` with ``ar_coef=[1]`` and it gives us (0, 1, 1) process.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(1,0,1), no?

"""
horizon = 7
df = generate_ar_df(periods=100, n_segments=3, start_time="2020-01-01", ar_coef=[1])
ts = TSDataset(df=TSDataset.to_dataset(df), freq="D")

# prepare prediction from regular model
model_regular = SARIMAXModel(order=(0, 1, 1))
model_regular.fit(ts)
future_ts = ts.make_future(future_steps=horizon)
regular_prediction = model_regular.forecast(future_ts)
regular_prediction = regular_prediction.to_pandas(flatten=True)

# prepare prediction from model with simple differencing
model_simplified = SARIMAXModel(order=(0, 1, 1), simple_differencing=True)
model_simplified.fit(ts)
future_ts = ts.make_future(future_steps=horizon)
simplified_prediction = model_simplified.forecast(future_ts)
simplified_prediction = simplified_prediction.to_pandas(flatten=True)

correlation = np.corrcoef(regular_prediction["target"], simplified_prediction["target"])[0, 1]
assert correlation >= 0.95
Loading