Skip to content

Commit

Permalink
Fix problem with swapped forecast methods in HierarchicalPipeline (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
Mr-Geekman authored May 10, 2023
1 parent fecfef5 commit 634a5c6
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 13 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Add `tsfresh` into optional dependencies, remove instruction about `pip install tsfresh` ([#1246](https://github.com/tinkoff-ai/etna/pull/1246))
- Fix `DeepARModel` and `TFTModel` to work with changed `prediction_size` ([#1251](https://github.com/tinkoff-ai/etna/pull/1251))
- Fix problems with flake8 B023 ([#1252](https://github.com/tinkoff-ai/etna/pull/1252))
- Fix problem with swapped forecast methods in HierarchicalPipeline ([#1259](https://github.com/tinkoff-ai/etna/pull/1259))

## [2.0.0] - 2023-04-11
### Added
Expand Down
25 changes: 12 additions & 13 deletions etna/pipeline/hierarchical_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,25 +312,24 @@ def _forecast_prediction_interval(
self, ts: TSDataset, predictions: TSDataset, quantiles: Sequence[float], n_folds: int
) -> TSDataset:
"""Add prediction intervals to the forecasts."""
# TODO: fix this: what if during backtest KeyboardInterrupt is raised
self.forecast, self.raw_forecast = self.raw_forecast, self.forecast # type: ignore

if self.ts is None:
raise ValueError("Pipeline is not fitted! Fit the Pipeline before calling forecast method.")

# TODO: rework intervals estimation for `BottomUpReconciliator`

with tslogger.disable():
_, forecasts, _ = self.backtest(ts=ts, metrics=[MAE()], n_folds=n_folds)
self.forecast, self.raw_forecast = self.raw_forecast, self.forecast # type: ignore
try:
# TODO: rework intervals estimation for `BottomUpReconciliator`

source_ts = self.reconciliator.aggregate(ts=ts)
self._add_forecast_borders(
ts=source_ts, backtest_forecasts=forecasts, quantiles=quantiles, predictions=predictions
)
with tslogger.disable():
_, forecasts, _ = self.backtest(ts=ts, metrics=[MAE()], n_folds=n_folds)

self.forecast, self.raw_forecast = self.raw_forecast, self.forecast # type: ignore
source_ts = self.reconciliator.aggregate(ts=ts)
self._add_forecast_borders(
ts=source_ts, backtest_forecasts=forecasts, quantiles=quantiles, predictions=predictions
)
return predictions

return predictions
finally:
self.forecast, self.raw_forecast = self.raw_forecast, self.forecast # type: ignore

def save(self, path: pathlib.Path):
"""Save the object.
Expand Down
27 changes: 27 additions & 0 deletions tests/test_pipeline/test_hierarchical_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,33 @@ def test_backtest_w_exog(product_level_constant_hierarchical_ts_with_exog, recon
np.testing.assert_allclose(metrics["MAE"], 0)


@pytest.mark.parametrize(
"reconciliator",
(
TopDownReconciliator(target_level="product", source_level="market", period=1, method="AHP"),
TopDownReconciliator(target_level="product", source_level="market", period=1, method="PHA"),
BottomUpReconciliator(target_level="total", source_level="market"),
),
)
def test_private_forecast_prediction_interval_no_swap_after_error(
product_level_constant_hierarchical_ts_with_exog, reconciliator
):
ts = product_level_constant_hierarchical_ts_with_exog
model = LinearPerSegmentModel()
pipeline = HierarchicalPipeline(reconciliator=reconciliator, model=model, transforms=[], horizon=1)
pipeline.backtest = Mock(side_effect=ValueError("Some error"))
forecast_method = pipeline.forecast
raw_forecast_method = pipeline.raw_forecast

pipeline.fit(ts=ts)
with pytest.raises(ValueError, match="Some error"):
_ = pipeline.forecast(prediction_interval=True, n_folds=1, quantiles=[0.025, 0.5, 0.975])

# check that methods aren't swapped
assert pipeline.forecast == forecast_method
assert pipeline.raw_forecast == raw_forecast_method


@pytest.mark.parametrize(
"reconciliator",
(
Expand Down

1 comment on commit 634a5c6

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please sign in to comment.