From 2acf8d6fc020241feddc461b6b5c2ca8331b1092 Mon Sep 17 00:00:00 2001 From: Julia Shenshina Date: Tue, 12 Oct 2021 13:24:29 +0300 Subject: [PATCH 1/3] Add example for ProphetModel --- etna/models/prophet.py | 42 ++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 40 insertions(+), 2 deletions(-) diff --git a/etna/models/prophet.py b/etna/models/prophet.py index 6b291a89e..de0c155c1 100644 --- a/etna/models/prophet.py +++ b/etna/models/prophet.py @@ -129,7 +129,44 @@ def predict(self, df: pd.DataFrame, confidence_interval: bool = False): class ProphetModel(PerSegmentModel): - """Class for holding Prophet model.""" + """Class for holding Prophet model. + + Examples + -------- + >>> from etna.datasets import generate_periodic_df + >>> from etna.datasets import TSDataset + >>> from etna.models import ProphetModel + >>> classic_df = generate_periodic_df( + ... periods=100, + ... start_time="2020-01-01", + ... n_segments=4, + ... period=7, + ... sigma=3 + ... ) + >>> df = TSDataset.to_dataset(df=classic_df) + >>> ts = TSDataset(df, freq="D") + >>> future = ts.make_future(7) + >>> model = ProphetModel(growth="flat") + >>> model.fit(ts=ts) + ProphetModel(growth = 'flat', changepoints = None, n_changepoints = 25, + changepoint_range = 0.8, yearly_seasonality = 'auto', weekly_seasonality = 'auto', + daily_seasonality = 'auto', holidays = None, seasonality_mode = 'additive', + seasonality_prior_scale = 10.0, holidays_prior_scale = 10.0, mcmc_samples = 0, + interval_width = 0.8, uncertainty_samples = 1000, stan_backend = None, + additional_seasonality_params = (), ) + >>> forecast = model.forecast(future) + >>> forecast + segment segment_0 segment_1 segment_2 segment_3 + feature target target target target + timestamp + 2020-04-10 9.00 9.00 4.00 6.00 + 2020-04-11 5.00 2.00 7.00 9.00 + 2020-04-12 0.00 4.00 7.00 9.00 + 2020-04-13 0.00 5.00 9.00 7.00 + 2020-04-14 1.00 2.00 1.00 6.00 + 2020-04-15 5.00 7.00 4.00 7.00 + 2020-04-16 8.00 6.00 2.00 0.00 + """ def __init__( self, @@ -282,6 +319,7 @@ def _forecast_segment( @log_decorator def forecast(self, ts: TSDataset, confidence_interval: bool = False) -> TSDataset: """Make predictions. + Parameters ---------- ts: @@ -294,7 +332,7 @@ def forecast(self, ts: TSDataset, confidence_interval: bool = False) -> TSDatase Models result Notes ----- - The width of the confidence interval is specified in the constructor of ProphetModel setting the interval_width + The width of the confidence interval is specified in the constructor of ProphetModel setting the interval_width. """ if self._segments is None: raise ValueError("The model is not fitted yet, use fit() to train it") From a7025c877d63a23685e334d7e886feb4c812488a Mon Sep 17 00:00:00 2001 From: Julia Shenshina Date: Tue, 12 Oct 2021 13:31:03 +0300 Subject: [PATCH 2/3] Fix docs --- etna/models/prophet.py | 1 + 1 file changed, 1 insertion(+) diff --git a/etna/models/prophet.py b/etna/models/prophet.py index de0c155c1..ffcaea1a7 100644 --- a/etna/models/prophet.py +++ b/etna/models/prophet.py @@ -189,6 +189,7 @@ def __init__( ): """ Create instance of Prophet model. + Parameters ---------- growth: From 4278e6c5c57aad82a522a6b80ede19a1d8b8086b Mon Sep 17 00:00:00 2001 From: Julia Shenshina Date: Tue, 12 Oct 2021 14:34:23 +0300 Subject: [PATCH 3/3] Upd CHANGELOG --- CHANGELOG.md | 1 + etna/models/prophet.py | 18 +++++++++--------- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f8a893c74..f78f48d9a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -27,6 +27,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Cluster plotter to EDA ([#169](https://github.com/tinkoff-ai/etna-ts/pull/169)) - STLTransform class ([#158](https://github.com/tinkoff-ai/etna-ts/pull/158)) - NN_examples notebook ([#159](https://github.com/tinkoff-ai/etna-ts/pull/159)) +- Example for ProphetModel ([#178](https://github.com/tinkoff-ai/etna-ts/pull/178)) ### Changed - Delete offset from WindowStatisticsTransform ([#111](https://github.com/tinkoff-ai/etna-ts/pull/111)) diff --git a/etna/models/prophet.py b/etna/models/prophet.py index ffcaea1a7..1164c5624 100644 --- a/etna/models/prophet.py +++ b/etna/models/prophet.py @@ -156,16 +156,16 @@ class ProphetModel(PerSegmentModel): additional_seasonality_params = (), ) >>> forecast = model.forecast(future) >>> forecast - segment segment_0 segment_1 segment_2 segment_3 - feature target target target target + segment segment_0 segment_1 segment_2 segment_3 + feature target target target target timestamp - 2020-04-10 9.00 9.00 4.00 6.00 - 2020-04-11 5.00 2.00 7.00 9.00 - 2020-04-12 0.00 4.00 7.00 9.00 - 2020-04-13 0.00 5.00 9.00 7.00 - 2020-04-14 1.00 2.00 1.00 6.00 - 2020-04-15 5.00 7.00 4.00 7.00 - 2020-04-16 8.00 6.00 2.00 0.00 + 2020-04-10 9.00 9.00 4.00 6.00 + 2020-04-11 5.00 2.00 7.00 9.00 + 2020-04-12 0.00 4.00 7.00 9.00 + 2020-04-13 0.00 5.00 9.00 7.00 + 2020-04-14 1.00 2.00 1.00 6.00 + 2020-04-15 5.00 7.00 4.00 7.00 + 2020-04-16 8.00 6.00 2.00 0.00 """ def __init__(