Skip to content

Add example for ProphetModel #178

Merged
merged 4 commits into from
Oct 12, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Pipeline.backtest method ([#161](https://github.com/tinkoff-ai/etna-ts/pull/161))
- STLTransform class ([#158](https://github.com/tinkoff-ai/etna-ts/pull/158))
- NN_examples notebook ([#159](https://github.com/tinkoff-ai/etna-ts/pull/159))
- Example for ProphetModel ([#178](https://github.com/tinkoff-ai/etna-ts/pull/178))
- Instruction notebook for custom model and transform creation ([#180](https://github.com/tinkoff-ai/etna-ts/pull/180))

### Changed
Expand Down
43 changes: 41 additions & 2 deletions etna/models/prophet.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,44 @@ def predict(self, df: pd.DataFrame, confidence_interval: bool = False):


class ProphetModel(PerSegmentModel):
"""Class for holding Prophet model."""
"""Class for holding Prophet model.

Examples
--------
>>> from etna.datasets import generate_periodic_df
>>> from etna.datasets import TSDataset
>>> from etna.models import ProphetModel
>>> classic_df = generate_periodic_df(
... periods=100,
... start_time="2020-01-01",
... n_segments=4,
... period=7,
... sigma=3
... )
>>> df = TSDataset.to_dataset(df=classic_df)
>>> ts = TSDataset(df, freq="D")
>>> future = ts.make_future(7)
>>> model = ProphetModel(growth="flat")
>>> model.fit(ts=ts)
ProphetModel(growth = 'flat', changepoints = None, n_changepoints = 25,
changepoint_range = 0.8, yearly_seasonality = 'auto', weekly_seasonality = 'auto',
daily_seasonality = 'auto', holidays = None, seasonality_mode = 'additive',
seasonality_prior_scale = 10.0, holidays_prior_scale = 10.0, mcmc_samples = 0,
interval_width = 0.8, uncertainty_samples = 1000, stan_backend = None,
additional_seasonality_params = (), )
>>> forecast = model.forecast(future)
>>> forecast
segment segment_0 segment_1 segment_2 segment_3
feature target target target target
timestamp
2020-04-10 9.00 9.00 4.00 6.00
2020-04-11 5.00 2.00 7.00 9.00
2020-04-12 0.00 4.00 7.00 9.00
2020-04-13 0.00 5.00 9.00 7.00
2020-04-14 1.00 2.00 1.00 6.00
2020-04-15 5.00 7.00 4.00 7.00
2020-04-16 8.00 6.00 2.00 0.00
"""

def __init__(
self,
Expand All @@ -152,6 +189,7 @@ def __init__(
):
"""
Create instance of Prophet model.

Parameters
----------
growth:
Expand Down Expand Up @@ -282,6 +320,7 @@ def _forecast_segment(
@log_decorator
def forecast(self, ts: TSDataset, confidence_interval: bool = False) -> TSDataset:
"""Make predictions.

Parameters
----------
ts:
Expand All @@ -294,7 +333,7 @@ def forecast(self, ts: TSDataset, confidence_interval: bool = False) -> TSDatase
Models result
Notes
-----
The width of the confidence interval is specified in the constructor of ProphetModel setting the interval_width
The width of the confidence interval is specified in the constructor of ProphetModel setting the interval_width.
"""
if self._segments is None:
raise ValueError("The model is not fitted yet, use fit() to train it")
Expand Down