Skip to content

Create DirectEnsemble #824

Merged
merged 6 commits into from
Aug 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
-
-
-
-
- `DirectEnsemble` ([#824](https://github.com/tinkoff-ai/etna/pull/824))
-
-
-
Expand Down
1 change: 1 addition & 0 deletions etna/ensembles/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from etna.ensembles.base import EnsembleMixin
from etna.ensembles.direct_ensemble import DirectEnsemble
from etna.ensembles.stacking_ensemble import StackingEnsemble
from etna.ensembles.voting_ensemble import VotingEnsemble
135 changes: 135 additions & 0 deletions etna/ensembles/direct_ensemble.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
from copy import deepcopy
from typing import Any
from typing import Dict
from typing import List
from typing import Optional

import numpy as np
from joblib import Parallel
from joblib import delayed

from etna.datasets import TSDataset
from etna.ensembles import EnsembleMixin
from etna.pipeline.base import BasePipeline


class DirectEnsemble(BasePipeline, EnsembleMixin):
"""DirectEnsemble is a pipeline that forecasts future values merging the forecasts of base pipelines.

Ensemble expects several pipelines during init. These pipelines are expected to have different forecasting horizons.
For each point in the future, forecast of the ensemble is forecast of base pipeline with the shortest horizon,
which covers this point.

Examples
--------
>>> from etna.datasets import generate_ar_df
>>> from etna.datasets import TSDataset
>>> from etna.ensembles import DirectEnsemble
>>> from etna.models import NaiveModel
>>> from etna.models import ProphetModel
>>> from etna.pipeline import Pipeline
>>> df = generate_ar_df(periods=30, start_time="2021-06-01", ar_coef=[1.2], n_segments=3)
>>> df_ts_format = TSDataset.to_dataset(df)
>>> ts = TSDataset(df_ts_format, "D")
>>> prophet_pipeline = Pipeline(model=ProphetModel(), transforms=[], horizon=3)
>>> naive_pipeline = Pipeline(model=NaiveModel(lag=10), transforms=[], horizon=5)
>>> ensemble = DirectEnsemble(pipelines=[prophet_pipeline, naive_pipeline])
>>> _ = ensemble.fit(ts=ts)
>>> forecast = ensemble.forecast()
>>> forecast
segment segment_0 segment_1 segment_2
feature target target target
timestamp
2021-07-01 -10.37 -232.60 163.16
2021-07-02 -10.59 -242.05 169.62
2021-07-03 -11.41 -253.82 177.62
2021-07-04 -5.85 -139.57 96.99
2021-07-05 -6.11 -167.69 116.59
"""

def __init__(
self,
pipelines: List[BasePipeline],
n_jobs: int = 1,
joblib_params: Optional[Dict[str, Any]] = None,
):
"""Init DirectEnsemble.

Parameters
----------
pipelines:
List of pipelines that should be used in ensemble
n_jobs:
Number of jobs to run in parallel
joblib_params:
Additional parameters for :py:class:`joblib.Parallel`

Raises
------
ValueError:
If two or more pipelines have the same horizons.
"""
self._validate_pipeline_number(pipelines=pipelines)
self.pipelines = pipelines
self.n_jobs = n_jobs
if joblib_params is None:
self.joblib_params = dict(verbose=11, backend="multiprocessing", mmap_mode="c")
else:
self.joblib_params = joblib_params
super().__init__(horizon=self._get_horizon(pipelines=pipelines))

@staticmethod
def _get_horizon(pipelines: List[BasePipeline]) -> int:
"""Get ensemble's horizon."""
horizons = {pipeline.horizon for pipeline in pipelines}
if len(horizons) != len(pipelines):
raise ValueError("All the pipelines should have pairwise different horizons.")
return max(horizons)

def fit(self, ts: TSDataset) -> "DirectEnsemble":
"""Fit pipelines in ensemble.

Parameters
----------
ts:
TSDataset to fit ensemble

Returns
-------
self:
Fitted ensemble
"""
self.ts = ts
self.pipelines = Parallel(n_jobs=self.n_jobs, **self.joblib_params)(
delayed(self._fit_pipeline)(pipeline=pipeline, ts=deepcopy(ts)) for pipeline in self.pipelines
)
return self

def _merge(self, forecasts: List[TSDataset]) -> TSDataset:
"""Merge the forecasts of base pipelines according to the direct strategy."""
segments = sorted(forecasts[0].segments)
horizons = [pipeline.horizon for pipeline in self.pipelines]
pipelines_order = np.argsort(horizons)[::-1]
# TODO: Fix slicing with explicit passing the segments in issue #775
forecast_df = forecasts[pipelines_order[0]][:, segments, "target"]
for idx in pipelines_order:
# TODO: Fix slicing with explicit passing the segments in issue #775
horizon, forecast = horizons[idx], forecasts[idx][:, segments, "target"]
forecast_df.iloc[:horizon] = forecast
forecast_dataset = TSDataset(df=forecast_df, freq=forecasts[0].freq)
return forecast_dataset

def _forecast(self) -> TSDataset:
"""Make predictions.

In each point in the future, forecast of the ensemble is forecast of base pipeline with the shortest horizon,
which covers this point.
"""
if self.ts is None:
raise ValueError("Something went wrong, ts is None!")

forecasts = Parallel(n_jobs=self.n_jobs, backend="multiprocessing", verbose=11)(
delayed(self._forecast_pipeline)(pipeline=pipeline) for pipeline in self.pipelines
)
forecast = self._merge(forecasts=forecasts)
return forecast
48 changes: 48 additions & 0 deletions tests/test_ensembles/test_direct_ensemble.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from unittest.mock import Mock

import pandas as pd
import pytest

from etna.datasets import TSDataset
from etna.datasets import generate_from_patterns_df
from etna.ensembles import DirectEnsemble
from etna.models import NaiveModel
from etna.pipeline import Pipeline


@pytest.fixture
def simple_ts_train():
df = generate_from_patterns_df(patterns=[[1, 3, 5], [2, 4, 6], [7, 9, 11]], periods=3, start_time="2000-01-01")
df = TSDataset.to_dataset(df)
ts = TSDataset(df=df, freq="D")
return ts


@pytest.fixture
def simple_ts_forecast():
df = generate_from_patterns_df(patterns=[[5, 3], [6, 4], [11, 9]], periods=2, start_time="2000-01-04")
df = TSDataset.to_dataset(df)
ts = TSDataset(df=df, freq="D")
return ts


def test_get_horizon():
ensemble = DirectEnsemble(pipelines=[Mock(horizon=1), Mock(horizon=2)])
assert ensemble.horizon == 2


def test_get_horizon_raise_error_on_same_horizons():
with pytest.raises(ValueError, match="All the pipelines should have pairwise different horizons."):
_ = DirectEnsemble(pipelines=[Mock(horizon=1), Mock(horizon=1)])


def test_forecast(simple_ts_train, simple_ts_forecast):
ensemble = DirectEnsemble(
pipelines=[
Pipeline(model=NaiveModel(lag=1), transforms=[], horizon=1),
Pipeline(model=NaiveModel(lag=3), transforms=[], horizon=2),
]
)
ensemble.fit(simple_ts_train)
forecast = ensemble.forecast()
pd.testing.assert_frame_equal(forecast.to_pandas(), simple_ts_forecast.to_pandas())