
Update plot_forecast for multi-forecast mode #584

Merged
merged 10 commits into from
Mar 14, 2022

@Mr-Geekman Mr-Geekman commented Mar 5, 2022

IMPORTANT: Please do not create a Pull Request without creating an issue first.

Before submitting (must do checklist)

  • Did you read the contribution guide?
  • Did you update the docs? We use Numpy format for all the methods and classes.
  • Did you write any new necessary tests?
  • Did you update the CHANGELOG?

Type of Change

  • Examples / docs / tutorials / contributors update
  • Bug fix (non-breaking change which fixes an issue)
  • Improvement (non-breaking change which improves an existing feature)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to change)

Proposed Changes

See #533.

Related Issue

#533.

Closing issues

Closes #533.

@Mr-Geekman Mr-Geekman added the enhancement New feature or request label Mar 5, 2022
@Mr-Geekman Mr-Geekman self-assigned this Mar 5, 2022
@Mr-Geekman
Contributor Author

Script for generating the plots:

import matplotlib.pyplot as plt
import pandas as pd

from etna.analysis import plot_forecast
from etna.datasets import TSDataset
from etna.models import ProphetModel
from etna.models import SARIMAXModel
from etna.pipeline import Pipeline

HORIZON = 7


def main():
    df = pd.read_csv("examples/data/example_dataset.csv", parse_dates=["timestamp"])
    ts = TSDataset(df=TSDataset.to_dataset(df), freq="D")
    train_ts, test_ts = ts.train_test_split(test_size=HORIZON)

    # sarimax
    sarimax_pipeline = Pipeline(model=SARIMAXModel(), quantiles=[0.025, 0.1, 0.4, 0.9, 0.975], horizon=HORIZON)
    sarimax_pipeline.fit(ts=train_ts)
    forecast_sarimax = sarimax_pipeline.forecast(prediction_interval=True)

    # prophet
    prophet_pipeline = Pipeline(model=ProphetModel(), quantiles=[0.025, 0.1, 0.6, 0.9, 0.975], horizon=HORIZON)
    prophet_pipeline.fit(ts=train_ts)
    forecast_prophet = prophet_pipeline.forecast(prediction_interval=True)

    # plot
    plot_forecast(
        forecast_ts=forecast_sarimax, train_ts=train_ts, test_ts=test_ts, n_train_samples=30, prediction_intervals=True
    )
    plt.savefig("plot_sarimax.png")

    plot_forecast(
        forecast_ts=forecast_prophet, train_ts=train_ts, test_ts=test_ts, n_train_samples=30, prediction_intervals=True
    )
    plt.savefig("plot_prophet.png")

    plot_forecast(
        forecast_ts={"sarimax": forecast_sarimax, "prophet": forecast_prophet},
        train_ts=train_ts,
        test_ts=test_ts,
        n_train_samples=30,
        prediction_intervals=True,
    )
    plt.savefig("plot_both.png")


if __name__ == "__main__":
    main()

Plot SARIMAX:
plot_sarimax

Plot Prophet:
plot_prophet

Plot both:
plot_both

@codecov-commenter

codecov-commenter commented Mar 5, 2022

Codecov Report

Merging #584 (89a3182) into master (4e50849) will decrease coverage by 32.19%.
The diff coverage is 7.01%.

@@             Coverage Diff             @@
##           master     #584       +/-   ##
===========================================
- Coverage   85.72%   53.53%   -32.20%     
===========================================
  Files         117      117               
  Lines        5759     5789       +30     
===========================================
- Hits         4937     3099     -1838     
- Misses        822     2690     +1868     
Impacted Files                                 Coverage Δ
etna/analysis/plotters.py                      12.12% <7.01%> (-5.00%) ⬇️
etna/commands/__init__.py                      0.00% <0.00%> (-100.00%) ⬇️
etna/commands/backtest_command.py              0.00% <0.00%> (-96.43%) ⬇️
etna/commands/forecast_command.py              0.00% <0.00%> (-92.00%) ⬇️
etna/commands/__main__.py                      0.00% <0.00%> (-87.50%) ⬇️
etna/commands/resolvers.py                     0.00% <0.00%> (-80.00%) ⬇️
etna/analysis/outliers/density_outliers.py     22.44% <0.00%> (-75.52%) ⬇️
etna/datasets/datasets_generation.py           26.47% <0.00%> (-73.53%) ⬇️
etna/transforms/timestamp/time_flags.py        27.02% <0.00%> (-72.98%) ⬇️
etna/transforms/timestamp/fourier.py           28.57% <0.00%> (-71.43%) ⬇️
... and 66 more


@julia-shenshina
Contributor

Could you please add the result of

plot_forecast(
        forecast_ts=[forecast_sarimax, forecast_prophet],
        train_ts=train_ts,
        test_ts=test_ts,
        n_train_samples=30,
        prediction_intervals=True,
    )

?
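
The call above passes a list instead of a dict, so the forecasts have no names to show in the legend. One way such input could be handled is to normalize every accepted form into a name-to-forecast mapping first. The sketch below is only an illustration, not the code from this PR: `normalize_forecasts` is a hypothetical helper, and labelling unnamed forecasts "1", "2", ... is an assumption.

```python
from typing import Any, Dict, List, Union


def normalize_forecasts(forecast_ts: Union[Any, List[Any], Dict[str, Any]]) -> Dict[str, Any]:
    """Map a single forecast, a list, or a dict of forecasts to a name -> forecast dict."""
    if isinstance(forecast_ts, dict):
        if not forecast_ts:
            raise ValueError("forecast_ts must not be empty")
        return dict(forecast_ts)
    if isinstance(forecast_ts, list):
        if not forecast_ts:
            raise ValueError("forecast_ts must not be empty")
        # assumption: unnamed forecasts are labelled "1", "2", ... in input order
        return {str(i): forecast for i, forecast in enumerate(forecast_ts, start=1)}
    # anything else is treated as one unnamed forecast
    return {"1": forecast_ts}


# strings stand in for forecast datasets here
print(normalize_forecasts(["forecast_sarimax", "forecast_prophet"]))
print(normalize_forecasts({"sarimax": "forecast_sarimax"}))
```

With a single mapping, the plotting loop only has to handle one shape of input, and the dict keys double as legend labels.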

if prediction_intervals and quantiles is not None:
    alpha = np.linspace(0, 1 / 2, len(quantiles) // 2 + 2)[1:-1]
    for quantile in range(len(quantiles) // 2):
        values_low = segment_forecast_df["target_" + str(quantiles[quantile])].values
Contributor

target_prefix = 'target_'?

# draw prediction intervals from outer layers to inner ones
if prediction_intervals and quantiles is not None:
    alpha = np.linspace(0, 1 / 2, len(quantiles) // 2 + 2)[1:-1]
    for quantile in range(len(quantiles) // 2):
Contributor

looks like it is not quantile but quantile_idx

Contributor Author

This part of the code wasn't written by me.
OK, I can fix it.

Contributor

thanks!
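
For reference, the indexing discussed in this thread can be sketched in plain Python. The loop variable is an index into the quantile list, not a quantile value, which is why `quantile_idx` reads better. This is only an illustration: `TARGET_PREFIX` and `interval_layers` are hypothetical names, pairing the i-th lowest quantile with the i-th highest is an assumption (the diff context only shows the lower bound), and the alpha values reproduce `np.linspace(0, 1 / 2, len(quantiles) // 2 + 2)[1:-1]` without NumPy.

```python
from typing import List, Tuple

TARGET_PREFIX = "target_"  # hypothetical constant, as suggested in the review


def interval_layers(quantiles: List[float]) -> List[Tuple[str, str, float]]:
    """Return (low_column, high_column, alpha) triples from outer to inner."""
    ordered = sorted(quantiles)
    n_layers = len(ordered) // 2
    layers = []
    for quantile_idx in range(n_layers):
        low = ordered[quantile_idx]
        high = ordered[-quantile_idx - 1]  # assumption: symmetric pairing
        # same values as np.linspace(0, 1 / 2, n_layers + 2)[1:-1]
        alpha = 0.5 * (quantile_idx + 1) / (n_layers + 1)
        layers.append((TARGET_PREFIX + str(low), TARGET_PREFIX + str(high), alpha))
    return layers


for low_column, high_column, alpha in interval_layers([0.025, 0.1, 0.6, 0.9, 0.975]):
    print(low_column, high_column, round(alpha, 3))
```

With an odd number of quantiles, the middle one (0.6 here) has no partner and is not drawn as an interval. Inner layers get a larger alpha, so they appear more opaque than outer ones.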

ax[i].fill_between(

# plot forecast plot for each of given forecasts
colors = plt.cm.Set2.colors
Contributor

I'm not sure it's a good idea to set the color scheme in the method itself.

We were going to set a common scheme for all the plots, so this method will reset all the settings 🤔

Contributor Author

This line doesn't override anything: I just take one particular color scheme and use its colors as a list. I could also take the default color scheme. I think it can be reset from the outside if we choose one particular version.

Should I do it?


def _select_quantiles(forecast_results: Dict[str, "TSDataset"], quantiles: Optional[List[float]]) -> List[float]:
    """Select quantiles from the forecast results."""
    # restrict to use only two quantiles, in other case the plot will be messy
Contributor

Do we really need it?

As a user, I would expect plot_... to plot all the data from ts without any filters.

Contributor Author

Look at the last plots here: #538 (not the first ones).
Do you think it would be good to plot something like this for all the forecasts?

In that case, should we select the quantiles that are common to all the forecasts, or should some other strategy be applied?
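
The "common quantiles" strategy mentioned above can be sketched on plain lists of quantile levels. This is a minimal illustration under stated assumptions, not the `_select_quantiles` implementation from the PR: `common_quantiles` is a hypothetical helper, and it works on lists of floats rather than on forecast datasets.

```python
from typing import Dict, List


def common_quantiles(quantiles_by_forecast: Dict[str, List[float]]) -> List[float]:
    """Return the sorted quantile levels that are present in every forecast."""
    quantile_sets = [set(levels) for levels in quantiles_by_forecast.values()]
    if not quantile_sets:
        return []
    return sorted(set.intersection(*quantile_sets))


# example quantile levels for two forecasts that only partly overlap
available = {
    "sarimax": [0.025, 0.8, 0.975],
    "prophet": [0.025, 0.2, 0.975],
}
print(common_quantiles(available))
```

Only levels shared by every forecast survive, so each forecast is drawn with the same intervals and the plots stay comparable. Levels that exist in just one forecast (0.8 and 0.2 in this example) are dropped unless the caller asks for them explicitly.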

@Mr-Geekman
Contributor Author

Script for generating the plots:

import matplotlib.pyplot as plt
import pandas as pd

from etna.analysis import plot_forecast
from etna.datasets import TSDataset
from etna.models import ProphetModel
from etna.models import SARIMAXModel
from etna.pipeline import Pipeline

HORIZON = 7


def main():
    df = pd.read_csv("examples/data/example_dataset.csv", parse_dates=["timestamp"])
    ts = TSDataset(df=TSDataset.to_dataset(df), freq="D")
    train_ts, test_ts = ts.train_test_split(test_size=HORIZON)

    # sarimax
    sarimax_pipeline = Pipeline(model=SARIMAXModel(), horizon=HORIZON)
    sarimax_pipeline.fit(ts=train_ts)
    forecast_sarimax = sarimax_pipeline.forecast(prediction_interval=True, quantiles=[0.025, 0.1, 0.6, 0.9, 0.975])
    forecast_sarimax_extra = sarimax_pipeline.forecast(prediction_interval=True, quantiles=[0.025, 0.8, 0.975])

    # prophet
    prophet_pipeline = Pipeline(model=ProphetModel(), horizon=HORIZON)
    prophet_pipeline.fit(ts=train_ts)
    forecast_prophet = prophet_pipeline.forecast(prediction_interval=True, quantiles=[0.025, 0.1, 0.6, 0.9, 0.975])
    forecast_prophet_extra = prophet_pipeline.forecast(prediction_interval=True, quantiles=[0.025, 0.2, 0.975])

    plot_forecast(
        forecast_ts=forecast_sarimax, train_ts=train_ts, test_ts=test_ts, n_train_samples=30, prediction_intervals=True
    )
    plt.savefig("plot_sarimax.png")

    plot_forecast(
        forecast_ts=forecast_prophet, train_ts=train_ts, test_ts=test_ts, n_train_samples=30, prediction_intervals=True
    )
    plt.savefig("plot_prophet.png")

    plot_forecast(
        forecast_ts={"sarimax": forecast_sarimax, "prophet": forecast_prophet},
        train_ts=train_ts,
        test_ts=test_ts,
        n_train_samples=30,
        prediction_intervals=True,
    )
    plt.savefig("plot_both_dict.png")

    plot_forecast(
        forecast_ts=[forecast_sarimax, forecast_prophet],
        train_ts=train_ts,
        test_ts=test_ts,
        n_train_samples=30,
        prediction_intervals=True,
    )
    plt.savefig("plot_both_list.png")

    plot_forecast(
        forecast_ts={"sarimax": forecast_sarimax_extra, "prophet": forecast_prophet_extra},
        train_ts=train_ts,
        test_ts=test_ts,
        n_train_samples=30,
        prediction_intervals=True,
    )
    plt.savefig("plot_both_dict_intersection.png")

    plot_forecast(
        forecast_ts={"sarimax": forecast_sarimax_extra, "prophet": forecast_prophet_extra},
        quantiles=[0.025, 0.2, 0.8, 0.975],
        train_ts=train_ts,
        test_ts=test_ts,
        n_train_samples=30,
        prediction_intervals=True,
    )
    plt.savefig("plot_both_dict_extra.png")


if __name__ == "__main__":
    main()

plot_sarimax:
plot_sarimax

plot_prophet:
plot_prophet

plot_both_dict:
plot_both_dict

plot_both_list:
plot_both_list

plot_both_dict_intersection:
plot_both_dict_intersection

plot_both_dict_extra:
plot_both_dict_extra

@Mr-Geekman
Contributor Author

New plots after moving the legend (the script is the same).
plot_sarimax:
plot_sarimax

plot_prophet:
plot_prophet

plot_both_dict:
plot_both_dict

plot_both_list:
plot_both_list

plot_both_dict_intersection:
plot_both_dict_intersection

plot_both_dict_extra:
plot_both_dict_extra

@Mr-Geekman Mr-Geekman merged commit a321f5a into master Mar 14, 2022
@Mr-Geekman Mr-Geekman deleted the issue-533 branch March 14, 2022 15:03
Development

Successfully merging this pull request may close these issues.

Multimodel Forecast Plot
3 participants