-
Notifications
You must be signed in to change notification settings - Fork 80
Update plot_forecast
for multi-forecast mode
#584
Conversation
Script for generation plots: import matplotlib.pyplot as plt
import pandas as pd
from etna.analysis import plot_forecast
from etna.datasets import TSDataset
from etna.models import ProphetModel
from etna.models import SARIMAXModel
from etna.pipeline import Pipeline
HORIZON = 7
def main():
df = pd.read_csv("examples/data/example_dataset.csv", parse_dates=["timestamp"])
ts = TSDataset(df=TSDataset.to_dataset(df), freq="D")
train_ts, test_ts = ts.train_test_split(test_size=HORIZON)
# sarimax
sarimax_pipeline = Pipeline(model=SARIMAXModel(), quantiles=[0.025, 0.1, 0.4, 0.9, 0.975], horizon=HORIZON)
sarimax_pipeline.fit(ts=train_ts)
forecast_sarimax = sarimax_pipeline.forecast(prediction_interval=True)
# prophet
prophet_pipeline = Pipeline(model=ProphetModel(), quantiles=[0.025, 0.1, 0.6, 0.9, 0.975], horizon=HORIZON)
prophet_pipeline.fit(ts=train_ts)
forecast_prophet = prophet_pipeline.forecast(prediction_interval=True)
# plot
plot_forecast(
forecast_ts=forecast_sarimax, train_ts=train_ts, test_ts=test_ts, n_train_samples=30, prediction_intervals=True
)
plt.savefig("plot_sarimax.png")
plot_forecast(
forecast_ts=forecast_prophet, train_ts=train_ts, test_ts=test_ts, n_train_samples=30, prediction_intervals=True
)
plt.savefig("plot_prophet.png")
plot_forecast(
forecast_ts={"sarimax": forecast_sarimax, "prophet": forecast_prophet},
train_ts=train_ts,
test_ts=test_ts,
n_train_samples=30,
prediction_intervals=True,
)
plt.savefig("plot_both.png")
if __name__ == "__main__":
main() |
Codecov Report
@@ Coverage Diff @@
## master #584 +/- ##
===========================================
- Coverage 85.72% 53.53% -32.20%
===========================================
Files 117 117
Lines 5759 5789 +30
===========================================
- Hits 4937 3099 -1838
- Misses 822 2690 +1868
📣 Codecov can now indicate which changes are the most critical in Pull Requests. Learn more |
could you pls add the result of
? |
etna/analysis/plotters.py
Outdated
if prediction_intervals and quantiles is not None: | ||
alpha = np.linspace(0, 1 / 2, len(quantiles) // 2 + 2)[1:-1] | ||
for quantile in range(len(quantiles) // 2): | ||
values_low = segment_forecast_df["target_" + str(quantiles[quantile])].values |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
target_prefix = 'target_'
?
etna/analysis/plotters.py
Outdated
# draw prediction intervals from outer layers to inner ones | ||
if prediction_intervals and quantiles is not None: | ||
alpha = np.linspace(0, 1 / 2, len(quantiles) // 2 + 2)[1:-1] | ||
for quantile in range(len(quantiles) // 2): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
looks like it is not quantile
but quantile_idx
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This part of code wasn't written by me.
Ok, I can fix it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thanks!
etna/analysis/plotters.py
Outdated
ax[i].fill_between( | ||
|
||
# plot forecast plot for each of given forecasts | ||
colors = plt.cm.Set2.colors |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
i'm not sure it's a good idea to set color scheme in the method itself
we were going to set common scheme for all the plots, so this method will reset all the settings 🤔
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This line don't override anything, here I just take one particular scheme and use colors in it as a list. I also can take the default color scheme. I think that it can be reset from the outside if we will choose one particular version.
Should I do it?
etna/analysis/plotters.py
Outdated
|
||
def _select_quantiles(forecast_results: Dict[str, "TSDataset"], quantiles: Optional[List[float]]) -> List[float]: | ||
"""Select quantiles from the forecast results.""" | ||
# restrict to use only two quantiles, in other case the plot will be messy |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do we really need it?
as a user I would wait plot_...
to plot all the data from ts without any filters
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Look at the last plots here: #538 (not the first ones)
Do you think it will be good to plot something like this for all the forecasts?
Should we in this case select quantiles, that are common for all the forecasts or some other strategy should be applied?
Script for plots generation: import matplotlib.pyplot as plt
import pandas as pd
from etna.analysis import plot_forecast
from etna.datasets import TSDataset
from etna.models import ProphetModel
from etna.models import SARIMAXModel
from etna.pipeline import Pipeline
HORIZON = 7
def main():
df = pd.read_csv("examples/data/example_dataset.csv", parse_dates=["timestamp"])
ts = TSDataset(df=TSDataset.to_dataset(df), freq="D")
train_ts, test_ts = ts.train_test_split(test_size=HORIZON)
# sarimax
sarimax_pipeline = Pipeline(model=SARIMAXModel(), horizon=HORIZON)
sarimax_pipeline.fit(ts=train_ts)
forecast_sarimax = sarimax_pipeline.forecast(prediction_interval=True, quantiles=[0.025, 0.1, 0.6, 0.9, 0.975])
forecast_sarimax_extra = sarimax_pipeline.forecast(prediction_interval=True, quantiles=[0.025, 0.8, 0.975])
# prophet
prophet_pipeline = Pipeline(model=ProphetModel(), horizon=HORIZON)
prophet_pipeline.fit(ts=train_ts)
forecast_prophet = prophet_pipeline.forecast(prediction_interval=True, quantiles=[0.025, 0.1, 0.6, 0.9, 0.975])
forecast_prophet_extra = prophet_pipeline.forecast(prediction_interval=True, quantiles=[0.025, 0.2, 0.975])
plot_forecast(
forecast_ts=forecast_sarimax, train_ts=train_ts, test_ts=test_ts, n_train_samples=30, prediction_intervals=True
)
plt.savefig("plot_sarimax.png")
plot_forecast(
forecast_ts=forecast_prophet, train_ts=train_ts, test_ts=test_ts, n_train_samples=30, prediction_intervals=True
)
plt.savefig("plot_prophet.png")
plot_forecast(
forecast_ts={"sarimax": forecast_sarimax, "prophet": forecast_prophet},
train_ts=train_ts,
test_ts=test_ts,
n_train_samples=30,
prediction_intervals=True,
)
plt.savefig("plot_both_dict.png")
plot_forecast(
forecast_ts=[forecast_sarimax, forecast_prophet],
train_ts=train_ts,
test_ts=test_ts,
n_train_samples=30,
prediction_intervals=True,
)
plt.savefig("plot_both_list.png")
plot_forecast(
forecast_ts={"sarimax": forecast_sarimax_extra, "prophet": forecast_prophet_extra},
train_ts=train_ts,
test_ts=test_ts,
n_train_samples=30,
prediction_intervals=True,
)
plt.savefig("plot_both_dict_intersection.png")
plot_forecast(
forecast_ts={"sarimax": forecast_sarimax_extra, "prophet": forecast_prophet_extra},
quantiles=[0.025, 0.2, 0.8, 0.975],
train_ts=train_ts,
test_ts=test_ts,
n_train_samples=30,
prediction_intervals=True,
)
plt.savefig("plot_both_dict_extra.png")
if __name__ == "__main__":
main() |
IMPORTANT: Please do not create a Pull Request without creating an issue first.
Before submitting (must do checklist)
Type of Change
Proposed Changes
Look #533.
Related Issue
#533.
Closing issues
Closes #533.