From 60b68b887e0683c94127c42ca4678409fa5d9af1 Mon Sep 17 00:00:00 2001 From: "d.a.bunin" Date: Thu, 23 Mar 2023 10:42:21 +0300 Subject: [PATCH] refactor: replace selt.ts with ts, remove redundant checks of self.ts --- etna/pipeline/base.py | 5 +---- etna/pipeline/hierarchical_pipeline.py | 2 +- tests/test_pipeline/test_hierarchical_pipeline.py | 4 ++-- 3 files changed, 4 insertions(+), 7 deletions(-) diff --git a/etna/pipeline/base.py b/etna/pipeline/base.py index c099c102d..b8fc003a1 100644 --- a/etna/pipeline/base.py +++ b/etna/pipeline/base.py @@ -325,10 +325,7 @@ def _add_forecast_borders( self, ts: TSDataset, backtest_forecasts: pd.DataFrame, quantiles: Sequence[float], predictions: TSDataset ) -> None: """Estimate prediction intervals and add to the forecasts.""" - if self.ts is None: - raise ValueError("Pipeline is not fitted!") - - backtest_forecasts = TSDataset(df=backtest_forecasts, freq=self.ts.freq) + backtest_forecasts = TSDataset(df=backtest_forecasts, freq=ts.freq) residuals = ( backtest_forecasts.loc[:, pd.IndexSlice[:, "target"]] - ts[backtest_forecasts.index.min() : backtest_forecasts.index.max(), :, "target"] diff --git a/etna/pipeline/hierarchical_pipeline.py b/etna/pipeline/hierarchical_pipeline.py index 04ed84d79..c2752c9b0 100644 --- a/etna/pipeline/hierarchical_pipeline.py +++ b/etna/pipeline/hierarchical_pipeline.py @@ -116,7 +116,7 @@ def raw_forecast( freq=forecast.freq, df_exog=forecast.df_exog, known_future=forecast.known_future, - hierarchical_structure=self.ts.hierarchical_structure, # type: ignore + hierarchical_structure=ts.hierarchical_structure, # type: ignore ) return hierarchical_forecast diff --git a/tests/test_pipeline/test_hierarchical_pipeline.py b/tests/test_pipeline/test_hierarchical_pipeline.py index 834f6e42c..e9b902cd1 100644 --- a/tests/test_pipeline/test_hierarchical_pipeline.py +++ b/tests/test_pipeline/test_hierarchical_pipeline.py @@ -1,5 +1,4 @@ import pathlib -from copy import deepcopy from unittest.mock import Mock from unittest.mock import patch @@ -23,7 +22,8 @@ from etna.transforms import LinearTrendTransform from etna.transforms import MeanTransform from tests.test_pipeline.utils import assert_pipeline_equals_loaded_original -from tests.test_pipeline.utils import assert_pipeline_forecasts_given_ts_with_prediction_intervals, assert_pipeline_forecasts_given_ts +from tests.test_pipeline.utils import assert_pipeline_forecasts_given_ts +from tests.test_pipeline.utils import assert_pipeline_forecasts_given_ts_with_prediction_intervals from tests.utils import to_be_fixed