-
Notifications
You must be signed in to change notification settings - Fork 80
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add VotingEnsemble * Upd CHANGELOG * Fix docs * Fix
- Loading branch information
1 parent
ffbc70e
commit aa463ed
Showing
8 changed files
with
312 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,6 +12,7 @@ API | |
datasets | ||
metrics | ||
transforms | ||
ensembles | ||
analysis | ||
model_selection | ||
loggers | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
Ensembles | ||
========= | ||
|
||
.. _ensembles: | ||
|
||
.. currentmodule:: etna | ||
|
||
Details of ETNA Ensembles | ||
------------------------- | ||
|
||
See the API documentation for further details on ensembles: | ||
|
||
.. currentmodule:: etna | ||
|
||
.. moduleautosummary:: | ||
:toctree: api/ | ||
:template: custom-module-template.rst | ||
:recursive: | ||
|
||
etna.ensembles |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from etna.ensembles.voting_ensemble import VotingEnsemble |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,143 @@ | ||
from copy import deepcopy | ||
from typing import Iterable | ||
from typing import List | ||
from typing import Optional | ||
|
||
from joblib import Parallel | ||
from joblib import delayed | ||
|
||
from etna.datasets import TSDataset | ||
from etna.loggers import tslogger | ||
from etna.pipeline import Pipeline | ||
|
||
|
||
class VotingEnsemble(Pipeline): | ||
"""VotingEnsemble is a pipeline that forecast future values with weighted averaging of it's pipelines forecasts. | ||
Examples | ||
-------- | ||
>>> from etna.datasets import generate_ar_df | ||
>>> from etna.datasets import TSDataset | ||
>>> from etna.ensembles import VotingEnsemble | ||
>>> from etna.models import NaiveModel | ||
>>> from etna.models import ProphetModel | ||
>>> from etna.pipeline import Pipeline | ||
>>> df = generate_ar_df(periods=30, start_time="2021-06-01", ar_coef=[1.2], n_segments=3) | ||
>>> df_ts_format = TSDataset.to_dataset(df) | ||
>>> ts = TSDataset(df_ts_format, "D") | ||
>>> prophet_pipeline = Pipeline(model=ProphetModel(), transforms=[], horizon=7) | ||
>>> naive_pipeline = Pipeline(model=NaiveModel(lag=10), transforms=[], horizon=7) | ||
>>> ensemble = VotingEnsemble( | ||
... pipelines=[prophet_pipeline, naive_pipeline], | ||
... weights=[0.7, 0.3] | ||
... ) | ||
>>> ensemble.fit(ts=ts) | ||
>>> forecast = ensemble.forecast() | ||
>>> forecast | ||
segment segment_0 segment_1 segment_2 | ||
feature target target target | ||
timestamp | ||
2021-07-01 -8.84 -186.67 130.99 | ||
2021-07-02 -8.96 -198.16 138.81 | ||
2021-07-03 -9.57 -212.48 148.48 | ||
2021-07-04 -10.48 -229.16 160.13 | ||
2021-07-05 -11.20 -248.93 174.39 | ||
2021-07-06 -12.47 -281.90 197.82 | ||
2021-07-07 -13.51 -307.02 215.73 | ||
""" | ||
|
||
def __init__(self, pipelines: List[Pipeline], weights: Optional[List[float]] = None, n_jobs: int = 1): | ||
"""Init VotingEnsemble. | ||
Parameters | ||
---------- | ||
pipelines: | ||
list of pipelines that should be used in ensemble | ||
weights: | ||
list of pipelines' weights; weights will be normalized automatically. | ||
n_jobs: | ||
number of jobs to run in parallel | ||
""" | ||
self._validate_pipeline_number(pipelines=pipelines) | ||
self.horizon = self._get_horizon(pipelines=pipelines) | ||
self.weights = self._process_weights(weights=weights, pipelines_number=len(pipelines)) | ||
self.pipelines = pipelines | ||
self.n_jobs = n_jobs | ||
|
||
@staticmethod | ||
def _validate_pipeline_number(pipelines: List[Pipeline]): | ||
"""Check that given valid number of pipelines.""" | ||
if len(pipelines) < 2: | ||
raise ValueError("At least two pipelines are expected.") | ||
|
||
@staticmethod | ||
def _get_horizon(pipelines: List[Pipeline]) -> int: | ||
"""Get ensemble's horizon.""" | ||
horizons = list(set([pipeline.horizon for pipeline in pipelines])) | ||
if len(horizons) > 1: | ||
raise ValueError("All the pipelines should have the same horizon.") | ||
return horizons[0] | ||
|
||
@staticmethod | ||
def _process_weights(weights: Optional[Iterable[float]], pipelines_number: int) -> List[float]: | ||
"""Process weights: if weights are not given, set them with default values, normalize weights.""" | ||
if weights is None: | ||
weights = [1 / pipelines_number for _ in range(pipelines_number)] | ||
elif len(weights) != pipelines_number: | ||
raise ValueError("Weights size should be equal to pipelines number.") | ||
common_weight = sum(weights) | ||
weights = [w / common_weight for w in weights] | ||
return weights | ||
|
||
@staticmethod | ||
def _fit_pipeline(pipeline: Pipeline, ts: TSDataset) -> Pipeline: | ||
"""Fit given pipeline with ts.""" | ||
tslogger.log(msg=f"Start fitting {pipeline.__repr__()}.") | ||
pipeline.fit(ts=ts) | ||
tslogger.log(msg=f"Pipeline {pipeline.__repr__()} is fitted.") | ||
return pipeline | ||
|
||
def fit(self, ts: TSDataset) -> "VotingEnsemble": | ||
"""Fit pipelines in ensemble. | ||
Parameters | ||
---------- | ||
ts: | ||
TSDataset to fit ensemble | ||
Returns | ||
------- | ||
VotingEnsemble: | ||
fitted ensemble | ||
""" | ||
self.pipelines = Parallel(n_jobs=self.n_jobs, backend="multiprocessing", verbose=11)( | ||
delayed(self._fit_pipeline)(pipeline=pipeline, ts=deepcopy(ts)) for pipeline in self.pipelines | ||
) | ||
|
||
@staticmethod | ||
def _forecast_pipeline(pipeline: Pipeline) -> TSDataset: | ||
"""Make forecast with given pipeline.""" | ||
tslogger.log(msg=f"Start forecasting with {pipeline.__repr__()}.") | ||
forecast = pipeline.forecast() | ||
tslogger.log(msg=f"Forecast is done with {pipeline.__repr__()}.") | ||
return forecast | ||
|
||
def _vote(self, forecasts: List[TSDataset]) -> TSDataset: | ||
"""Get average forecast.""" | ||
forecast_df = sum([forecast[:, :, "target"] * weight for forecast, weight in zip(forecasts, self.weights)]) | ||
forecast_dataset = TSDataset(df=forecast_df, freq=forecasts[0].freq) | ||
return forecast_dataset | ||
|
||
def forecast(self) -> TSDataset: | ||
"""Forecast with ensemble: compute weighted average of pipelines' forecasts. | ||
Returns | ||
------- | ||
TSDataset: | ||
dataset with forecasts | ||
""" | ||
forecasts = Parallel(n_jobs=self.n_jobs, backend="multiprocessing", verbose=11)( | ||
delayed(self._forecast_pipeline)(pipeline=pipeline) for pipeline in self.pipelines | ||
) | ||
forecast = self._vote(forecasts=forecasts) | ||
return forecast |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
from copy import deepcopy | ||
from typing import List | ||
from typing import Optional | ||
|
||
import numpy as np | ||
import pytest | ||
|
||
from etna.datasets import TSDataset | ||
from etna.ensembles.voting_ensemble import VotingEnsemble | ||
from etna.models import CatBoostModelPerSegment | ||
from etna.models import NaiveModel | ||
from etna.models import ProphetModel | ||
from etna.pipeline import Pipeline | ||
from etna.transforms import LagTransform | ||
|
||
HORIZON = 7 | ||
|
||
|
||
@pytest.fixture | ||
def catboost_pipeline() -> Pipeline: | ||
"""Generate pipeline with CatBoostModelMultiSegment.""" | ||
pipeline = Pipeline( | ||
model=CatBoostModelPerSegment(), | ||
transforms=[LagTransform(in_column="target", lags=[10, 11, 12])], | ||
horizon=HORIZON, | ||
) | ||
return pipeline | ||
|
||
|
||
@pytest.fixture | ||
def prophet_pipeline() -> Pipeline: | ||
"""Generate pipeline with ProphetModel.""" | ||
pipeline = Pipeline(model=ProphetModel(), transforms=[], horizon=HORIZON) | ||
return pipeline | ||
|
||
|
||
@pytest.fixture | ||
def naive_pipeline() -> Pipeline: | ||
"""Generate pipeline with NaiveModel.""" | ||
pipeline = Pipeline(model=NaiveModel(20), transforms=[], horizon=2 * HORIZON) | ||
return pipeline | ||
|
||
|
||
@pytest.fixture | ||
def naive_pipeline_1() -> Pipeline: | ||
"""Generate pipeline with NaiveModel(1).""" | ||
pipeline = Pipeline(model=NaiveModel(1), transforms=[], horizon=HORIZON) | ||
return pipeline | ||
|
||
|
||
@pytest.fixture | ||
def naive_pipeline_2() -> Pipeline: | ||
"""Generate pipeline with NaiveModel(2).""" | ||
pipeline = Pipeline(model=NaiveModel(2), transforms=[], horizon=HORIZON) | ||
return pipeline | ||
|
||
|
||
def test_invalid_pipelines_number(catboost_pipeline: Pipeline): | ||
"""Test VotingEnsemble behavior in case of invalid pipelines number.""" | ||
with pytest.raises(ValueError): | ||
_ = VotingEnsemble(pipelines=[catboost_pipeline]) | ||
|
||
|
||
def test_get_horizon_pass(catboost_pipeline: Pipeline, prophet_pipeline: Pipeline): | ||
"""Check that VotingEnsemble._get horizon works correctly in case of valid pipelines list.""" | ||
horizon = VotingEnsemble._get_horizon(pipelines=[catboost_pipeline, prophet_pipeline]) | ||
assert horizon == HORIZON | ||
|
||
|
||
def test_get_horizon_fail(catboost_pipeline: Pipeline, naive_pipeline: Pipeline): | ||
"""Check that VotingEnsemble._get horizon works correctly in case of invalid pipelines list.""" | ||
with pytest.raises(ValueError): | ||
_ = VotingEnsemble._get_horizon(pipelines=[catboost_pipeline, naive_pipeline]) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"weights,pipelines_number,expected", | ||
((None, 5, [0.2, 0.2, 0.2, 0.2, 0.2]), ([0.2, 0.3, 0.5], 3, [0.2, 0.3, 0.5]), ([1, 1, 2], 3, [0.25, 0.25, 0.5])), | ||
) | ||
def test_process_weights_pass( | ||
weights: Optional[List[float]], | ||
pipelines_number: int, | ||
expected: List[float], | ||
catboost_pipeline: Pipeline, | ||
prophet_pipeline: Pipeline, | ||
): | ||
"""Check that VotingEnsemble._process_weights processes weights correctly in case of valid args sets.""" | ||
result = VotingEnsemble._process_weights(weights=weights, pipelines_number=pipelines_number) | ||
assert isinstance(result, list) | ||
assert all([x == y for x, y in zip(result, expected)]) | ||
|
||
|
||
def test_process_weights_fail(): | ||
"""Check that VotingEnsemble._process_weights processes weights correctly in case of invalid args sets.""" | ||
with pytest.raises(ValueError): | ||
_ = VotingEnsemble._process_weights(weights=[0.3, 0.4, 0.3], pipelines_number=2) | ||
|
||
|
||
def test_forecast_interface(example_tsds: TSDataset, catboost_pipeline: Pipeline, prophet_pipeline: Pipeline): | ||
"""Check that VotingEnsemble.forecast returns TSDataset of correct length.""" | ||
ensemble = VotingEnsemble(pipelines=[catboost_pipeline, prophet_pipeline]) | ||
ensemble.fit(ts=example_tsds) | ||
forecast = ensemble.forecast() | ||
assert isinstance(forecast, TSDataset) | ||
assert len(forecast.df) == HORIZON | ||
|
||
|
||
def test_forecast_values_default_weights(simple_df: TSDataset, naive_pipeline_1: Pipeline, naive_pipeline_2: Pipeline): | ||
"""Check that VotingEnsemble gets average.""" | ||
ensemble = VotingEnsemble(pipelines=[naive_pipeline_1, naive_pipeline_2]) | ||
ensemble.fit(ts=simple_df) | ||
forecast = ensemble.forecast() | ||
np.testing.assert_array_equal(forecast[:, "A", "target"].values, [47.5, 48, 47.5, 48, 47.5, 48, 47.5]) | ||
np.testing.assert_array_equal(forecast[:, "B", "target"].values, [11, 12, 11, 12, 11, 12, 11]) | ||
|
||
|
||
def test_forecast_values_custom_weights(simple_df: TSDataset, naive_pipeline_1: Pipeline, naive_pipeline_2: Pipeline): | ||
"""Check that VotingEnsemble gets average.""" | ||
ensemble = VotingEnsemble(pipelines=[naive_pipeline_1, naive_pipeline_2], weights=[1, 3]) | ||
ensemble.fit(ts=simple_df) | ||
forecast = ensemble.forecast() | ||
np.testing.assert_array_equal(forecast[:, "A", "target"].values, [47.25, 48, 47.25, 48, 47.25, 48, 47.25]) | ||
np.testing.assert_array_equal(forecast[:, "B", "target"].values, [10.5, 12, 10.5, 12, 10.5, 12, 10.5]) | ||
|
||
|
||
@pytest.mark.long | ||
def test_multiprocessing_ensemples( | ||
simple_df: TSDataset, | ||
catboost_pipeline: Pipeline, | ||
prophet_pipeline: Pipeline, | ||
naive_pipeline_1: Pipeline, | ||
naive_pipeline_2: Pipeline, | ||
): | ||
"""Check that VotingEnsemble works the same in case of multi and single jobs modes.""" | ||
pipelines = [catboost_pipeline, prophet_pipeline, naive_pipeline_1, naive_pipeline_2] | ||
single_jobs_ensemble = VotingEnsemble(pipelines=deepcopy(pipelines), n_jobs=1) | ||
multi_jobs_ensemble = VotingEnsemble(pipelines=deepcopy(pipelines), n_jobs=3) | ||
|
||
single_jobs_ensemble.fit(ts=deepcopy(simple_df)) | ||
multi_jobs_ensemble.fit(ts=deepcopy(simple_df)) | ||
|
||
single_jobs_forecast = single_jobs_ensemble.forecast() | ||
multi_jobs_forecast = multi_jobs_ensemble.forecast() | ||
|
||
assert (single_jobs_forecast.df == multi_jobs_forecast.df).all().all() |