From 02bf892501f3ec6a89b5ea95fa5be16819c500f7 Mon Sep 17 00:00:00 2001
From: Mr-Geekman <36005824+Mr-Geekman@users.noreply.github.com>
Date: Tue, 14 Mar 2023 11:38:29 +0300
Subject: [PATCH] Add `refit` parameter into `backtest` (#1159)

---
 CHANGELOG.md                           |   2 +-
 etna/pipeline/base.py                  | 189 ++++++++++++++++++++++---
 tests/test_loggers/test_file_logger.py |  20 +--
 tests/test_pipeline/conftest.py        |  13 +-
 tests/test_pipeline/test_pipeline.py   | 164 ++++++++++++++++++++-
 5 files changed, 352 insertions(+), 36 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 7c0645712..6026ef750 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -13,7 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Add saving/loading for transforms, models, pipelines, ensembles; tutorial for saving/loading ([#1068](https://github.com/tinkoff-ai/etna/pull/1068))
 -
 -
--
+- Add `refit` parameter into `backtest` ([#1159](https://github.com/tinkoff-ai/etna/pull/1159))
 - Add optional parameter `ts` into `forecast` method of pipelines ([#1071](https://github.com/tinkoff-ai/etna/pull/1071))
 - Add tests on `transform` method of transforms on subset of segments, on new segments, on future with gap ([#1094](https://github.com/tinkoff-ai/etna/pull/1094))
 - Add tests on `inverse_transform` method of transforms on subset of segments, on new segments, on future with gap ([#1127](https://github.com/tinkoff-ai/etna/pull/1127))
diff --git a/etna/pipeline/base.py b/etna/pipeline/base.py
index ceb706254..5879fc4e8 100644
--- a/etna/pipeline/base.py
+++ b/etna/pipeline/base.py
@@ -1,3 +1,4 @@
+import math
 from abc import abstractmethod
 from copy import deepcopy
 from enum import Enum
@@ -15,6 +16,7 @@ from joblib import Parallel
 from joblib import delayed
 from scipy.stats import norm
+from typing_extensions import TypedDict
 
 from etna.core import AbstractSaveable
 from etna.core import BaseMixin
@@ -210,11 +212,14 @@ def backtest(
         mode: str = "expand",
         aggregate_metrics: bool = False,
         n_jobs: int = 1,
+        refit: Union[bool, int] = True,
         joblib_params: Optional[Dict[str, Any]] = None,
         forecast_params: Optional[Dict[str, Any]] = None,
     ) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
         """Run backtest with the pipeline.
 
+        If ``refit != True`` and some component of the pipeline doesn't support forecasting with a gap, that component will raise an exception.
+
         Parameters
         ----------
         ts:
@@ -229,6 +234,15 @@
             If True aggregate metrics above folds, return raw metrics otherwise
         n_jobs:
             Number of jobs to run in parallel
+        refit:
+            Determines how often the pipeline should be retrained during iteration over folds.
+
+            * If ``True``: the pipeline is retrained on each fold.
+
+            * If ``False``: the pipeline is trained only on the first fold.
+
+            * If ``value: int``: the pipeline is retrained every ``value`` folds starting from the first.
+
         joblib_params:
             Additional parameters for :py:class:`joblib.Parallel`
         forecast_params:
             Additional parameters for
@@ -241,6 +255,15 @@
         """
 
 
+class FoldParallelGroup(TypedDict):
+    """Group for parallel fold processing."""
+
+    train_fold_number: int
+    train_mask: FoldMask
+    forecast_fold_numbers: List[int]
+    forecast_masks: List[FoldMask]
+
+
 class BasePipeline(AbstractPipeline, BaseMixin):
     """Base class for pipeline."""
@@ -522,21 +545,39 @@ def _compute_metrics(metrics: List[Metric], y_true: TSDataset, y_pred: TSDataset
             metrics_values[metric.name] = metric(y_true=y_true, y_pred=y_pred)  # type: ignore
         return metrics_values
 
-    def _run_fold(
+    def _fit_backtest_pipeline(
         self,
+        ts: TSDataset,
+        fold_number: int,
+    ) -> "BasePipeline":
+        """Fit the pipeline on the given data in backtest."""
+        tslogger.start_experiment(job_type="training", group=str(fold_number))
+        pipeline = deepcopy(self)
+        pipeline.fit(ts=ts)
+        tslogger.finish_experiment()
+        return pipeline
+
+    def _forecast_backtest_pipeline(
+        self, pipeline: "BasePipeline", ts: TSDataset, fold_number: int, forecast_params: Dict[str, Any]
+    ) -> TSDataset:
+        """Make a forecast with a given pipeline in backtest."""
+        tslogger.start_experiment(job_type="forecasting", group=str(fold_number))
+        forecast = pipeline.forecast(ts=ts, **forecast_params)
+        tslogger.finish_experiment()
+        return forecast
+
+    def _process_fold_forecast(
+        self,
+        forecast: TSDataset,
         train: TSDataset,
         test: TSDataset,
         fold_number: int,
         mask: FoldMask,
         metrics: List[Metric],
-        forecast_params: Dict[str, Any],
     ) -> Dict[str, Any]:
-        """Run fit-forecast pipeline of model for one fold."""
+        """Process the forecast made for one fold."""
         tslogger.start_experiment(job_type="crossval", group=str(fold_number))
-        pipeline = deepcopy(self)
-        pipeline.fit(ts=train)
-        forecast = pipeline.forecast(**forecast_params)
         fold: Dict[str, Any] = {}
         for stage_name, stage_df in zip(("train", "test"), (train, test)):
             fold[f"{stage_name}_timerange"] = {}
@@ -620,6 +661,108 @@ def _prepare_fold_masks(self, ts: TSDataset, masks: Union[int, List[FoldMask]],
             mask.validate_on_dataset(ts=ts, horizon=self.horizon)
         return masks
 
+    @staticmethod
+    def _make_backtest_fold_groups(masks: List[FoldMask], refit: Union[bool, int]) -> List[FoldParallelGroup]:
+        """Make groups of folds for backtest."""
+        if not refit:
+            refit = len(masks)
+
+        grouped_folds = []
+        num_groups = math.ceil(len(masks) / refit)
+        for group_id in range(num_groups):
+            train_fold_number = group_id * refit
+            forecast_fold_numbers = [train_fold_number + i for i in range(refit) if train_fold_number + i < len(masks)]
+            cur_group: FoldParallelGroup = {
+                "train_fold_number": train_fold_number,
+                "train_mask": masks[train_fold_number],
+                "forecast_fold_numbers": forecast_fold_numbers,
+                "forecast_masks": [masks[i] for i in forecast_fold_numbers],
+            }
+            grouped_folds.append(cur_group)
+
+        return grouped_folds
+
+    def _run_all_folds(
+        self,
+        masks: List[FoldMask],
+        ts: TSDataset,
+        metrics: List[Metric],
+        n_jobs: int,
+        refit: Union[bool, int],
+        joblib_params: Dict[str, Any],
+        forecast_params: Dict[str, Any],
+    ) -> Dict[int, Any]:
+        """Run pipeline on all folds."""
+        fold_groups = self._make_backtest_fold_groups(masks=masks, refit=refit)
+
+        with Parallel(n_jobs=n_jobs, **joblib_params) as parallel:
+            # fitting
+            fit_masks = [group["train_mask"] for group in fold_groups]
+            fit_datasets = (
+                train for train, _ in self._generate_folds_datasets(ts=ts, masks=fit_masks, horizon=self.horizon)
+            )
+            pipelines = parallel(
+                delayed(self._fit_backtest_pipeline)(ts=fit_ts, fold_number=fold_groups[group_idx]["train_fold_number"])
+                for group_idx, fit_ts in enumerate(fit_datasets)
+            )
+
+            # forecasting
+            forecast_masks = [group["forecast_masks"] for group in fold_groups]
+            forecast_datasets = (
+                (
+                    train
+                    for train, _ in self._generate_folds_datasets(
+                        ts=ts, masks=group_forecast_masks, horizon=self.horizon
+                    )
+                )
+                for group_forecast_masks in forecast_masks
+            )
+            forecasts_flat = parallel(
+                delayed(self._forecast_backtest_pipeline)(
+                    ts=forecast_ts,
+                    pipeline=pipelines[group_idx],
+                    fold_number=fold_groups[group_idx]["forecast_fold_numbers"][idx],
+                    forecast_params=forecast_params,
+                )
+                for group_idx, group_forecast_datasets in enumerate(forecast_datasets)
+                for idx, forecast_ts in enumerate(group_forecast_datasets)
+            )
+
+            # processing forecasts
+            fold_process_train_datasets = (
+                train for train, _ in self._generate_folds_datasets(ts=ts, masks=fit_masks, horizon=self.horizon)
+            )
+            fold_process_test_datasets = (
+                (
+                    test
+                    for _, test in self._generate_folds_datasets(
+                        ts=ts, masks=group_forecast_masks, horizon=self.horizon
+                    )
+                )
+                for group_forecast_masks in forecast_masks
+            )
+            fold_results_flat = parallel(
+                delayed(self._process_fold_forecast)(
+                    forecast=forecasts_flat[group_idx * refit + idx],
+                    train=train,
+                    test=test,
+                    fold_number=fold_groups[group_idx]["forecast_fold_numbers"][idx],
+                    mask=fold_groups[group_idx]["forecast_masks"][idx],
+                    metrics=metrics,
+                )
+                for group_idx, (train, group_fold_process_test_datasets) in enumerate(
+                    zip(fold_process_train_datasets, fold_process_test_datasets)
+                )
+                for idx, test in enumerate(group_fold_process_test_datasets)
+            )
+
+        results = {
+            fold_number: fold_results_flat[group_idx * refit + idx]
+            for group_idx in range(len(fold_groups))
+            for idx, fold_number in enumerate(fold_groups[group_idx]["forecast_fold_numbers"])
+        }
+        return results
+
     def backtest(
         self,
         ts: TSDataset,
@@ -628,11 +771,14 @@
         mode: str = "expand",
         aggregate_metrics: bool = False,
         n_jobs: int = 1,
+        refit: Union[bool, int] = True,
         joblib_params: Optional[Dict[str, Any]] = None,
         forecast_params: Optional[Dict[str, Any]] = None,
     ) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
         """Run backtest with the pipeline.
 
+        If ``refit != True`` and some component of the pipeline doesn't support forecasting with a gap, that component will raise an exception.
+
         Parameters
         ----------
         ts:
@@ -647,6 +793,15 @@
             If True aggregate metrics above folds, return raw metrics otherwise
         n_jobs:
             Number of jobs to run in parallel
+        refit:
+            Determines how often the pipeline should be retrained during iteration over folds.
+
+            * If ``True``: the pipeline is retrained on each fold.
+
+            * If ``False``: the pipeline is trained only on the first fold.
+
+            * If ``value: int``: the pipeline is retrained every ``value`` folds starting from the first.
+
         joblib_params:
             Additional parameters for :py:class:`joblib.Parallel`
         forecast_params:
             Additional parameters for
@@ -666,21 +821,15 @@
         self._init_backtest()
         self._validate_backtest_metrics(metrics=metrics)
         masks = self._prepare_fold_masks(ts=ts, masks=n_folds, mode=mode)
-
-        folds = Parallel(n_jobs=n_jobs, **joblib_params)(
-            delayed(self._run_fold)(
-                train=train,
-                test=test,
-                fold_number=fold_number,
-                mask=masks[fold_number],
-                metrics=metrics,
-                forecast_params=forecast_params,
-            )
-            for fold_number, (train, test) in enumerate(
-                self._generate_folds_datasets(ts=ts, masks=masks, horizon=self.horizon)
-            )
+        self._folds = self._run_all_folds(
+            masks=masks,
+            ts=ts,
+            metrics=metrics,
+            n_jobs=n_jobs,
+            refit=refit,
+            joblib_params=joblib_params,
+            forecast_params=forecast_params,
         )
-        self._folds = {i: fold for i, fold in enumerate(folds)}
 
         metrics_df = self._get_backtest_metrics(aggregate_metrics=aggregate_metrics)
         forecast_df = self._get_backtest_forecasts()
diff --git a/tests/test_loggers/test_file_logger.py b/tests/test_loggers/test_file_logger.py
index b1603611e..ac4d7ea83 100644
--- a/tests/test_loggers/test_file_logger.py
+++ b/tests/test_loggers/test_file_logger.py
@@ -249,11 +249,12 @@ def test_local_file_logger_with_stacking_ensemble(example_df):
         assert len(list(cur_dir.iterdir())) == 1, "we've run one experiment"
         current_experiment_dir = list(cur_dir.iterdir())[0]
-        assert len(list(current_experiment_dir.iterdir())) == 2, "crossval and crossval_results folders"
+        assert len(list(current_experiment_dir.iterdir())) == 4, "training, forecasting, crossval, crossval_results"
 
-        assert (
-            len(list((current_experiment_dir / "crossval").iterdir())) == n_folds
-        ), "crossval should have `n_folds` runs"
+        for folder in ["training", "forecasting", "crossval"]:
+            assert (
+                len(list((current_experiment_dir / folder).iterdir())) == n_folds
+            ), f"{folder} should have `n_folds` runs"
 
         tslogger.remove(idx)
@@ -281,11 +282,14 @@ def test_local_file_logger_with_empirical_prediction_interval(example_df):
         assert len(list(cur_dir.iterdir())) == 1, "we've run one experiment"
         current_experiment_dir = list(cur_dir.iterdir())[0]
-        assert len(list(current_experiment_dir.iterdir())) == 2, "crossval and crossval_results folders"
 
         assert (
-            len(list((current_experiment_dir / "crossval").iterdir())) == n_folds
-        ), "crossval should have `n_folds` runs"
+            len(list(current_experiment_dir.iterdir())) == 4
+        ), "training, forecasting, crossval, crossval_results folders"
+
+        for folder in ["training", "forecasting", "crossval"]:
+            assert (
+                len(list((current_experiment_dir / folder).iterdir())) == n_folds
+            ), f"{folder} should have `n_folds` runs"
 
         tslogger.remove(idx)
diff --git a/tests/test_pipeline/conftest.py b/tests/test_pipeline/conftest.py
index e3dbb3464..eb88f93a4 100644
--- a/tests/test_pipeline/conftest.py
+++ b/tests/test_pipeline/conftest.py
@@ -8,6 +8,7 @@
 from etna.datasets import TSDataset
 from etna.models import CatBoostPerSegmentModel
+from etna.models import NaiveModel
 from etna.pipeline import Pipeline
 from etna.transforms import LagTransform
@@ -25,6 +26,16 @@ def catboost_pipeline() -> Pipeline:
     return pipeline
 
 
+@pytest.fixture
+def naive_pipeline() -> Pipeline:
+    """Generate pipeline with NaiveModel."""
+    pipeline = Pipeline(
+        model=NaiveModel(lag=7),
+        horizon=7,
+    )
+    return pipeline
+
+
 @pytest.fixture
 def catboost_pipeline_big() -> Pipeline:
     """Generate pipeline with CatBoostPerSegmentModel."""
@@ -218,7 +229,7 @@ def masked_ts() -> TSDataset:
 
 
 @pytest.fixture
-def ts_run_fold() -> TSDataset:
+def ts_process_fold_forecast() -> TSDataset:
     timerange = pd.date_range(start="2020-01-01", periods=11).to_list()
     df = pd.DataFrame({"timestamp": timerange + timerange})
     df["segment"] = ["segment_0"] * 11 + ["segment_1"] * 11
diff --git a/tests/test_pipeline/test_pipeline.py b/tests/test_pipeline/test_pipeline.py
index fb7382dd4..9f593192f 100644
--- a/tests/test_pipeline/test_pipeline.py
+++ b/tests/test_pipeline/test_pipeline.py
@@ -32,6 +32,7 @@
 from etna.pipeline import Pipeline
 from etna.transforms import AddConstTransform
 from etna.transforms import DateFlagsTransform
+from etna.transforms import DifferencingTransform
 from etna.transforms import FilterFeaturesTransform
 from etna.transforms import LagTransform
 from etna.transforms import LogTransform
@@ -439,15 +440,80 @@ def test_get_fold_info_interface_hours(catboost_pipeline: Pipeline, example_tsdf
     assert expected_columns == sorted(info_df.columns)
 
 
+def test_get_fold_info_refit_true(example_tsdf: TSDataset):
+    """Check that Pipeline.backtest returns an info dataframe with correct train timeranges with regular refit."""
+    n_folds = 5
+    pipeline = Pipeline(model=NaiveModel(lag=7), horizon=7)
+    _, _, info_df = pipeline.backtest(ts=example_tsdf, n_jobs=1, metrics=DEFAULT_METRICS, n_folds=n_folds, refit=True)
+    assert info_df["train_start_time"].nunique() == 1
+    assert info_df["train_end_time"].nunique() == n_folds
+    assert info_df["test_start_time"].nunique() == n_folds
+    assert info_df["test_end_time"].nunique() == n_folds
+
+
+def test_get_fold_info_refit_false(example_tsdf: TSDataset):
+    """Check that Pipeline.backtest returns an info dataframe with correct train timeranges with no refit."""
+    n_folds = 5
+    pipeline = Pipeline(model=NaiveModel(lag=7), horizon=7)
+    _, _, info_df = pipeline.backtest(ts=example_tsdf, n_jobs=1, metrics=DEFAULT_METRICS, n_folds=n_folds, refit=False)
+    assert info_df["train_start_time"].nunique() == 1
+    assert info_df["train_end_time"].nunique() == 1
+    assert info_df["test_start_time"].nunique() == n_folds
+    assert info_df["test_end_time"].nunique() == n_folds
+
+
+@pytest.mark.parametrize(
+    "n_folds, refit, expected_refits",
+    [
+        (1, 1, 1),
+        (1, 2, 1),
+        (3, 1, 3),
+        (3, 2, 2),
+        (3, 3, 1),
+        (3, 4, 1),
+        (4, 1, 4),
+        (4, 2, 2),
+        (4, 3, 2),
+        (4, 4, 1),
+        (4, 5, 1),
+    ],
+)
+def test_get_fold_info_refit_int(n_folds, refit, expected_refits, example_tsdf: TSDataset):
+    """Check that Pipeline.backtest returns an info dataframe with correct train timeranges with rare refit."""
+    pipeline = Pipeline(model=NaiveModel(lag=7), horizon=7)
+    _, _, info_df = pipeline.backtest(ts=example_tsdf, n_jobs=1, metrics=DEFAULT_METRICS, n_folds=n_folds, refit=refit)
+    assert info_df["train_start_time"].nunique() == 1
+    assert info_df["train_end_time"].nunique() == expected_refits
+    assert info_df["test_start_time"].nunique() == n_folds
+    assert info_df["test_end_time"].nunique() == n_folds
+
+
+def test_backtest_refit_success(catboost_pipeline: Pipeline, big_example_tsdf: TSDataset):
+    """Check that backtest without refit works on a pipeline that supports it."""
+    _ = catboost_pipeline.backtest(ts=big_example_tsdf, n_jobs=1, metrics=DEFAULT_METRICS, n_folds=3, refit=False)
+
+
+def test_backtest_refit_fail(big_example_tsdf: TSDataset):
+    """Check that backtest without refit doesn't work on a pipeline that doesn't support it."""
+    pipeline = Pipeline(
+        model=NaiveModel(lag=7),
+        transforms=[DifferencingTransform(in_column="target", inplace=True)],
+        horizon=7,
+    )
+    with pytest.raises(ValueError, match="Test should go after the train without gaps"):
+        _ = pipeline.backtest(ts=big_example_tsdf, n_jobs=1, metrics=DEFAULT_METRICS, n_folds=3, refit=False)
+
+
 @pytest.mark.long_1
-def test_backtest_with_n_jobs(catboost_pipeline: Pipeline, big_example_tsdf: TSDataset):
+@pytest.mark.parametrize("refit", [True, False, 2])
+def test_backtest_with_n_jobs(refit, catboost_pipeline: Pipeline, big_example_tsdf: TSDataset):
     """Check that Pipeline.backtest gives the same results in case of single and multiple jobs modes."""
     ts1 = deepcopy(big_example_tsdf)
     ts2 = deepcopy(big_example_tsdf)
     pipeline_1 = deepcopy(catboost_pipeline)
     pipeline_2 = deepcopy(catboost_pipeline)
-    _, forecast_1, _ = pipeline_1.backtest(ts=ts1, n_jobs=1, metrics=DEFAULT_METRICS)
-    _, forecast_2, _ = pipeline_2.backtest(ts=ts2, n_jobs=3, metrics=DEFAULT_METRICS)
+    _, forecast_1, _ = pipeline_1.backtest(ts=ts1, n_jobs=1, n_folds=4, metrics=DEFAULT_METRICS, refit=refit)
+    _, forecast_2, _ = pipeline_2.backtest(ts=ts2, n_jobs=3, n_folds=4, metrics=DEFAULT_METRICS, refit=refit)
     assert (forecast_1 == forecast_2).all().all()
@@ -564,17 +630,103 @@ def test_generate_folds_datasets_without_first_date(ts_name, mask, request):
         (FoldMask("2020-01-01", "2020-01-07", ["2020-01-08", "2020-01-11"]), {"segment_0": 95.5, "segment_1": 5}),
     ),
 )
-def test_run_fold(ts_run_fold: TSDataset, mask: FoldMask, expected: Dict[str, List[float]]):
-    train, test = ts_run_fold.train_test_split(
+def test_process_fold_forecast(ts_process_fold_forecast, mask: FoldMask, expected: Dict[str, List[float]]):
+    train, test = ts_process_fold_forecast.train_test_split(
         train_start=mask.first_train_timestamp, train_end=mask.last_train_timestamp
     )
 
     pipeline = Pipeline(model=NaiveModel(lag=5), transforms=[], horizon=4)
-    fold = pipeline._run_fold(train, test, 1, mask, [MAE()], forecast_params=dict())
+    pipeline = pipeline.fit(ts=train)
+    forecast = pipeline.forecast()
+    fold = pipeline._process_fold_forecast(
+        forecast=forecast, train=train, test=test, fold_number=1, mask=mask, metrics=[MAE()]
+    )
 
     for seg in fold["metrics"]["MAE"].keys():
         assert fold["metrics"]["MAE"][seg] == expected[seg]
 
 
+def test_make_backtest_fold_groups_refit_true():
+    masks = [MagicMock() for _ in range(2)]
+    obtained_results = Pipeline._make_backtest_fold_groups(masks=masks, refit=True)
+    expected_results = [
+        {
+            "train_fold_number": 0,
+            "train_mask": masks[0],
+            "forecast_fold_numbers": [0],
+            "forecast_masks": [masks[0]],
+        },
+        {
+            "train_fold_number": 1,
+            "train_mask": masks[1],
+            "forecast_fold_numbers": [1],
+            "forecast_masks": [masks[1]],
+        },
+    ]
+    assert obtained_results == expected_results
+
+
+def test_make_backtest_fold_groups_refit_false():
+    masks = [MagicMock() for _ in range(2)]
+    obtained_results = Pipeline._make_backtest_fold_groups(masks=masks, refit=False)
+    expected_results = [
+        {
+            "train_fold_number": 0,
+            "train_mask": masks[0],
+            "forecast_fold_numbers": [0, 1],
+            "forecast_masks": [masks[0], masks[1]],
+        }
+    ]
+    assert obtained_results == expected_results
+
+
+def test_make_backtest_fold_groups_refit_int():
+    masks = [MagicMock() for _ in range(5)]
+    obtained_results = Pipeline._make_backtest_fold_groups(masks=masks, refit=2)
+    expected_results = [
+        {
+            "train_fold_number": 0,
+            "train_mask": masks[0],
+            "forecast_fold_numbers": [0, 1],
+            "forecast_masks": [masks[0], masks[1]],
+        },
+        {
+            "train_fold_number": 2,
+            "train_mask": masks[2],
+            "forecast_fold_numbers": [2, 3],
+            "forecast_masks": [masks[2], masks[3]],
+        },
+        {
+            "train_fold_number": 4,
+            "train_mask": masks[4],
+            "forecast_fold_numbers": [4],
+            "forecast_masks": [masks[4]],
+        },
+    ]
+    assert obtained_results == expected_results
+
+
+@pytest.mark.parametrize(
+    "n_folds, refit, expected_refits",
+    [
+        (1, 1, 1),
+        (1, 2, 1),
+        (3, 1, 3),
+        (3, 2, 2),
+        (3, 3, 1),
+        (3, 4, 1),
+        (4, 1, 4),
+        (4, 2, 2),
+        (4, 3, 2),
+        (4, 4, 1),
+        (4, 5, 1),
+    ],
+)
+def test_make_backtest_fold_groups_length_refit_int(n_folds, refit, expected_refits):
+    masks = [MagicMock() for _ in range(n_folds)]
+    obtained_results = Pipeline._make_backtest_fold_groups(masks=masks, refit=refit)
+    assert len(obtained_results) == expected_refits
+
+
 @pytest.mark.parametrize(
     "lag,expected", ((5, {"segment_0": 76.923077, "segment_1": 90.909091}), (6, {"segment_0": 100, "segment_1": 120}))
)