Skip to content

Commit

Permalink
Implement predict method in SARIMAXModel, AutoARIMAModel, `Seas…
Browse files Browse the repository at this point in the history
…onalMovingAverageModel`, `DeadlineMovingAverageModel` (#948)
  • Loading branch information
Mr-Geekman authored Sep 21, 2022
1 parent 583c0ca commit d5eb926
Show file tree
Hide file tree
Showing 11 changed files with 485 additions and 195 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `DirectEnsemble` ([#824](https://github.com/tinkoff-ai/etna/pull/824))
-
-
-
- Implement predict method in `SARIMAXModel`, `AutoARIMAModel`, `SeasonalMovingAverageModel`, `DeadlineMovingAverageModel` ([#948](https://github.com/tinkoff-ai/etna/pull/948))
- Make `SeasonalMovingAverageModel` and `DeadlineMovingAverageModel` work with context ([#917](https://github.com/tinkoff-ai/etna/pull/917))
-
- Add `predict` method to models ([#935](https://github.com/tinkoff-ai/etna/pull/935))
Expand Down
2 changes: 1 addition & 1 deletion etna/analysis/outliers/prediction_interval_outliers.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def get_anomalies_prediction_interval(
model_instance = model(**model_params)
model_instance.fit(ts_inner)
lower_p, upper_p = [(1 - interval_width) / 2, (1 + interval_width) / 2]
prediction_interval = model_instance.forecast(
prediction_interval = model_instance.predict(
deepcopy(ts_inner), prediction_interval=True, quantiles=[lower_p, upper_p]
)
for segment in ts_inner.segments:
Expand Down
100 changes: 72 additions & 28 deletions etna/models/deadline_ma.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,14 +135,32 @@ def _get_context_beginning(

return first_index

def _make_predictions(self, result_template: pd.Series, context: pd.Series, prediction_size: int) -> np.ndarray:
    """Fill the tail of ``result_template`` with seasonal averages taken from ``context``.

    For each of the last ``prediction_size`` timestamps of ``result_template`` the values
    of ``context`` located 1, 2, ..., ``self.window`` months (or years, depending on
    ``self.seasonality``) earlier are added to the value already stored in the template,
    and the total is divided by ``self.window``.

    Parameters
    ----------
    result_template:
        Series indexed by timestamps; its last ``prediction_size`` entries are
        overwritten in place with the computed predictions.
    context:
        Series indexed by timestamps that is looked up for previous values.
        It may be the very same object as ``result_template``; in that case
        predictions written earlier in the loop are visible to later lookups.
    prediction_size:
        Number of trailing timestamps of ``result_template`` to predict.

    Returns
    -------
    :
        Array with the last ``prediction_size`` values of ``result_template``.
    """
    timestamps = result_template.index
    total_length = len(result_template)
    for position in range(total_length - prediction_size, total_length):
        current = timestamps[position]
        # Start from whatever the template already holds for this timestamp.
        accumulated = result_template.loc[current]
        for lag in range(1, self.window + 1):
            if self.seasonality == SeasonalityMode.month:
                step_back = pd.DateOffset(months=lag)
            elif self.seasonality == SeasonalityMode.year:
                step_back = pd.DateOffset(years=lag)

            # Every looked-up date is strictly earlier than ``current``, so the
            # row being computed is never read while it is being accumulated.
            accumulated = accumulated + context.loc[current - step_back]

        result_template.loc[current] = accumulated / self.window

    return result_template.values[-prediction_size:]

def forecast(self, df: pd.DataFrame, prediction_size: int) -> np.ndarray:
"""
Compute predictions from a DeadlineMovingAverageModel.
"""Compute autoregressive forecasts.
Parameters
----------
df: pd.DataFrame
Used only for getting the horizon of forecast and timestamps.
df:
Features dataframe.
prediction_size:
Number of last timestamps to leave after making prediction.
Previous timestamps will be used as a context for models that require it.
Expand All @@ -156,41 +174,67 @@ def forecast(self, df: pd.DataFrame, prediction_size: int) -> np.ndarray:
------
ValueError:
if context isn't big enough
ValueError:
if forecast context contains NaNs
"""
context_beginning = self._get_context_beginning(
df=df, prediction_size=prediction_size, seasonality=self.seasonality, window=self.window
)

df = df.set_index("timestamp")
df_history = df.iloc[:-prediction_size]
history_targets = df_history["target"]
history_timestamps = df_history["timestamp"]
history_targets = history_targets.loc[history_timestamps >= context_beginning]
history_timestamps = history_timestamps.loc[history_timestamps >= context_beginning]
future_timestamps = df["timestamp"].iloc[-prediction_size:]
history = df_history["target"]
history = history[history.index >= context_beginning]
if np.any(history.isnull()):
raise ValueError("There are NaNs in a forecast context, forecast method required context to filled!")

index = pd.date_range(start=context_beginning, end=df.index[-1], freq=self._freq)
result_template = np.append(history.values, np.zeros(prediction_size))
result_template = pd.Series(result_template, index=index)
result_values = self._make_predictions(
result_template=result_template, context=result_template, prediction_size=prediction_size
)
return result_values

index = pd.date_range(start=context_beginning, end=future_timestamps.iloc[-1])
res = np.append(history_targets.values, np.zeros(prediction_size))
res = pd.DataFrame(res)
res.index = index
for i in range(len(history_targets), len(res)):
for w in range(1, self.window + 1):
if self.seasonality == SeasonalityMode.month:
prev_date = res.index[i] - pd.DateOffset(months=w)
elif self.seasonality == SeasonalityMode.year:
prev_date = res.index[i] - pd.DateOffset(years=w)
def predict(self, df: pd.DataFrame, prediction_size: int) -> np.ndarray:
"""Compute predictions using true target data as context.
Parameters
----------
df:
Features dataframe.
prediction_size:
Number of last timestamps to leave after making prediction.
Previous timestamps will be used as a context for models that require it.
if prev_date <= history_timestamps.iloc[-1]:
res.loc[index[i]] += history_targets.loc[history_timestamps == prev_date].values
else:
res.loc[index[i]] += res.loc[prev_date].values
Returns
-------
:
Array with predictions.
res.loc[index[i]] = res.loc[index[i]] / self.window
Raises
------
ValueError:
if context isn't big enough
ValueError:
if there are NaNs in a target column on timestamps that are required to make predictions
"""
context_beginning = self._get_context_beginning(
df=df, prediction_size=prediction_size, seasonality=self.seasonality, window=self.window
)

res = res.values.ravel()[-prediction_size:]
return res
df = df.set_index("timestamp")
context = df["target"]
context = context[context.index >= context_beginning]
if np.any(np.isnan(context)):
raise ValueError("There are NaNs in a target column, predict method requires target to be filled!")

def predict(self, df: pd.DataFrame, prediction_size: int) -> pd.DataFrame:
raise NotImplementedError("Method predict isn't currently implemented!")
index = pd.date_range(start=df.index[-prediction_size], end=df.index[-1], freq=self._freq)
result_template = pd.Series(np.zeros(prediction_size), index=index)
result_values = self._make_predictions(
result_template=result_template, context=context, prediction_size=prediction_size
)
return result_values

@property
def context_size(self) -> int:
Expand Down
55 changes: 33 additions & 22 deletions etna/models/sarimax.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,24 +67,10 @@ def fit(self, df: pd.DataFrame, regressors: List[str]) -> "_SARIMAXBaseAdapter":

return self

def forecast(self, df: pd.DataFrame, prediction_interval: bool, quantiles: Sequence[float]) -> pd.DataFrame:
"""
Compute autoregressive predictions from a SARIMAX model.
Parameters
----------
df:
Features dataframe
prediction_interval:
If True returns prediction interval for forecast
quantiles:
Levels of prediction distribution
Returns
-------
:
DataFrame with predictions
"""
def _make_prediction(
self, df: pd.DataFrame, prediction_interval: bool, quantiles: Sequence[float], dynamic: bool
) -> pd.DataFrame:
"""Make predictions taking into account ``dynamic`` parameter."""
if self._fit_results is None:
raise ValueError("Model is not fitted! Fit the model before calling predict method!")

Expand All @@ -106,14 +92,19 @@ def forecast(self, df: pd.DataFrame, prediction_interval: bool, quantiles: Seque

if prediction_interval:
forecast, _ = seasonal_prediction_with_confidence(
arima_res=self._fit_results, start=start_idx, end=end_idx, X=exog_future, alpha=0.05
arima_res=self._fit_results, start=start_idx, end=end_idx, X=exog_future, alpha=0.05, dynamic=dynamic
)
y_pred = pd.DataFrame({"mean": forecast})
for quantile in quantiles:
# set alpha in the way to get a desirable quantile
alpha = min(quantile * 2, (1 - quantile) * 2)
_, borders = seasonal_prediction_with_confidence(
arima_res=self._fit_results, start=start_idx, end=end_idx, X=exog_future, alpha=alpha
arima_res=self._fit_results,
start=start_idx,
end=end_idx,
X=exog_future,
alpha=alpha,
dynamic=dynamic,
)
if quantile < 1 / 2:
series = borders[:, 0]
Expand All @@ -122,7 +113,7 @@ def forecast(self, df: pd.DataFrame, prediction_interval: bool, quantiles: Seque
y_pred[f"mean_{quantile:.4g}"] = series
else:
forecast, _ = seasonal_prediction_with_confidence(
arima_res=self._fit_results, start=start_idx, end=end_idx, X=exog_future, alpha=0.05
arima_res=self._fit_results, start=start_idx, end=end_idx, X=exog_future, alpha=0.05, dynamic=dynamic
)
y_pred = pd.DataFrame({"mean": forecast})

Expand All @@ -132,6 +123,26 @@ def forecast(self, df: pd.DataFrame, prediction_interval: bool, quantiles: Seque
y_pred = y_pred.rename(rename_dict, axis=1)
return y_pred

def forecast(self, df: pd.DataFrame, prediction_interval: bool, quantiles: Sequence[float]) -> pd.DataFrame:
    """Compute autoregressive predictions from a SARIMAX model.

    Delegates to ``_make_prediction`` with ``dynamic=True``.

    Parameters
    ----------
    df:
        Features dataframe
    prediction_interval:
        If True returns prediction interval for forecast
    quantiles:
        Levels of prediction distribution

    Returns
    -------
    :
        DataFrame with predictions
    """
    y_pred = self._make_prediction(
        df=df,
        prediction_interval=prediction_interval,
        quantiles=quantiles,
        dynamic=True,
    )
    return y_pred

def predict(self, df: pd.DataFrame, prediction_interval: bool, quantiles: Sequence[float]) -> pd.DataFrame:
"""
Compute predictions from a SARIMAX model and use true in-sample data as lags if possible.
Expand All @@ -150,7 +161,7 @@ def predict(self, df: pd.DataFrame, prediction_interval: bool, quantiles: Sequen
:
DataFrame with predictions
"""
return self.forecast(df=df, prediction_interval=prediction_interval, quantiles=quantiles)
return self._make_prediction(df=df, prediction_interval=prediction_interval, quantiles=quantiles, dynamic=False)

@abstractmethod
def _get_fit_results(self, endog: pd.Series, exog: pd.DataFrame) -> SARIMAXResultsWrapper:
Expand Down
49 changes: 44 additions & 5 deletions etna/models/seasonal_ma.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,12 @@ def fit(self, df: pd.DataFrame, regressors: List[str]) -> "_SeasonalMovingAverag
return self

def forecast(self, df: pd.DataFrame, prediction_size: int) -> np.ndarray:
"""
Compute predictions from a SeasonalMovingAverage model.
"""Compute autoregressive forecasts.
Parameters
----------
df:
Used only for getting the horizon of forecast
Features dataframe.
prediction_size:
Number of last timestamps to leave after making prediction.
Previous timestamps will be used as a context for models that require it.
Expand All @@ -83,6 +82,8 @@ def forecast(self, df: pd.DataFrame, prediction_size: int) -> np.ndarray:
------
ValueError:
if context isn't big enough
ValueError:
if forecast context contains NaNs
"""
expected_length = prediction_size + self.shift
if len(df) < expected_length:
Expand All @@ -91,14 +92,52 @@ def forecast(self, df: pd.DataFrame, prediction_size: int) -> np.ndarray:
)

history = df["target"][-expected_length:-prediction_size]
if np.any(history.isnull()):
raise ValueError("There are NaNs in a forecast context, forecast method required context to filled!")

res = np.append(history, np.zeros(prediction_size))
for i in range(self.shift, len(res)):
res[i] = res[i - self.shift : i : self.seasonality].mean()
y_pred = res[-prediction_size:]
return y_pred

def predict(self, df: pd.DataFrame, prediction_size: int) -> pd.DataFrame:
raise NotImplementedError("Method predict isn't currently implemented!")
def predict(self, df: pd.DataFrame, prediction_size: int) -> np.ndarray:
    """Compute predictions using true target data as context.

    Each prediction is the mean of the ``target`` values found in the ``self.shift``
    rows directly preceding the predicted row, taken with a step of ``self.seasonality``.

    Parameters
    ----------
    df:
        Features dataframe; only its ``target`` column is read.
    prediction_size:
        Number of last timestamps to leave after making prediction.
        Previous timestamps will be used as a context for models that require it.

    Returns
    -------
    :
        Array with predictions.

    Raises
    ------
    ValueError:
        if context isn't big enough
    ValueError:
        if there are NaNs in a target column on timestamps that are required to make predictions
    """
    shift = self.shift
    required_rows = prediction_size + shift
    if required_rows > len(df):
        raise ValueError(
            "Given context isn't big enough, try to decrease context_size, prediction_size of increase length of given dataframe!"
        )

    # Only the trailing rows (context followed by the predicted timestamps) are relevant.
    target_tail = df["target"][-required_rows:].values
    if np.isnan(target_tail).any():
        raise ValueError("There are NaNs in a target column, predict method requires target to be filled!")

    step = self.seasonality
    predictions = np.zeros(prediction_size)
    for offset in range(len(target_tail) - shift):
        # Window of ``shift`` true values right before the predicted row.
        seasonal_values = target_tail[offset : offset + shift : step]
        predictions[offset] = seasonal_values.mean()
    return predictions


class SeasonalMovingAverageModel(
Expand Down
Loading

0 comments on commit d5eb926

Please sign in to comment.