Skip to content

Update plot_forecast for multi-forecast mode #584

Merged
merged 10 commits into from
Mar 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
-
- Add table option to ConsoleLogger ([#544](https://github.com/tinkoff-ai/etna/pull/544))
- Installation instruction ([#526](https://github.com/tinkoff-ai/etna/pull/526))
-
- Update plot_forecast for multi-forecast mode ([#584](https://github.com/tinkoff-ai/etna/pull/584))
- Trainer kwargs for deep models ([#540](https://github.com/tinkoff-ai/etna/pull/540))
- Update CONTRIBUTING.md ([#536](https://github.com/tinkoff-ai/etna/pull/536))
-
Expand Down
190 changes: 131 additions & 59 deletions etna/analysis/plotters.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import List
from typing import Optional
from typing import Sequence
from typing import Set
from typing import Tuple
from typing import Union

Expand Down Expand Up @@ -45,24 +46,71 @@ def prepare_axes(segments: List[str], columns_num: int, figsize: Tuple[int, int]
return ax


def _get_existing_quantiles(ts: "TSDataset") -> Set[float]:
"""Get quantiles that are present inside the TSDataset."""
cols = [col for col in ts.columns.get_level_values("feature").unique().tolist() if col.startswith("target_0.")]
existing_quantiles = {float(col[len("target_") :]) for col in cols}
return existing_quantiles


def _select_quantiles(forecast_results: Dict[str, "TSDataset"], quantiles: Optional[List[float]]) -> List[float]:
    """Choose the quantiles to draw, keeping only those present in every forecast.

    Parameters
    ----------
    forecast_results:
        mapping from forecast name to forecasted dataset
    quantiles:
        requested quantiles; if None, all quantiles shared by every forecast are taken

    Returns
    -------
    :
        selected quantiles in ascending order

    Warns
    -----
    A warning is issued when some requested quantiles are missing from at least one forecast;
    such quantiles are dropped from the result.
    """
    per_forecast_quantiles = [_get_existing_quantiles(forecast) for forecast in forecast_results.values()]
    common_quantiles = set.intersection(*per_forecast_quantiles)

    if quantiles is None:
        return sorted(common_quantiles)

    requested_quantiles = set(quantiles)
    non_existent = requested_quantiles - common_quantiles
    if non_existent:
        warnings.warn(f"Quantiles {non_existent} do not exist in each forecast dataset. They will be dropped.")
    return sorted(requested_quantiles & common_quantiles)


def _prepare_forecast_results(
    forecast_ts: Union["TSDataset", List["TSDataset"], Dict[str, "TSDataset"]]
) -> Dict[str, "TSDataset"]:
    """Normalize the supported forecast inputs into a name -> forecast dictionary.

    Parameters
    ----------
    forecast_ts:
        one of:

        1. a single TSDataset, stored under the name ``"1"``
        2. a non-empty list of TSDatasets, named ``"1"``, ``"2"``, ... by position
        3. a non-empty dictionary, returned as is

    Returns
    -------
    :
        dictionary with forecast names as keys and forecasts as values

    Raises
    ------
    ValueError:
        if ``forecast_ts`` is none of the above; an empty list or an empty dictionary
        is rejected with the same error
    """
    # imported inside the function, as in the original block
    from etna.datasets import TSDataset

    if isinstance(forecast_ts, TSDataset):
        return {"1": forecast_ts}

    is_non_empty_container = isinstance(forecast_ts, (list, dict)) and len(forecast_ts) > 0
    if not is_non_empty_container:
        raise ValueError("Unknown type of `forecast_ts`")

    if isinstance(forecast_ts, dict):
        return forecast_ts
    return {str(position): forecast for position, forecast in enumerate(forecast_ts, start=1)}


def plot_forecast(
forecast_ts: "TSDataset",
forecast_ts: Union["TSDataset", List["TSDataset"], Dict[str, "TSDataset"]],
test_ts: Optional["TSDataset"] = None,
train_ts: Optional["TSDataset"] = None,
segments: Optional[List[str]] = None,
n_train_samples: Optional[int] = None,
columns_num: int = 2,
figsize: Tuple[int, int] = (10, 5),
prediction_intervals: bool = False,
quantiles: Optional[Sequence[float]] = None,
quantiles: Optional[List[float]] = None,
):
"""
Plot of prediction for forecast pipeline.

Parameters
----------
forecast_ts:
forecasted TSDataset with timeseries data
there are several options:
1. Forecasted TSDataset with timeseries data, single-forecast mode
2. List of forecasted TSDatasets, multi-forecast mode
3. Dictionary with forecasted TSDatasets, multi-forecast mode
test_ts:
TSDataset with timeseries data
train_ts:
Expand All @@ -78,33 +126,32 @@ def plot_forecast(
prediction_intervals:
if True prediction intervals will be drawn
quantiles:
list of quantiles to draw
List of quantiles to draw, if isn't set then quantiles from a given dataset will be used.
In multi-forecast mode, only quantiles present in each forecast will be used.

Raises
------
ValueError:
if the format of `forecast_ts` is unknown
"""
forecast_results = _prepare_forecast_results(forecast_ts)
num_forecasts = len(forecast_results.keys())

if not segments:
segments = list(set(forecast_ts.columns.get_level_values("segment")))
unique_segments = set()
for forecast in forecast_results.values():
unique_segments.update(forecast.segments)
segments = list(unique_segments)

ax = prepare_axes(segments=segments, columns_num=columns_num, figsize=figsize)

if prediction_intervals:
cols = [
col
for col in forecast_ts.columns.get_level_values("feature").unique().tolist()
if col.startswith("target_0.")
]
existing_quantiles = [float(col[7:]) for col in cols]
if quantiles is None:
quantiles = sorted(existing_quantiles)
else:
non_existent = set(quantiles) - set(existing_quantiles)
if len(non_existent):
warnings.warn(f"Quantiles {non_existent} do not exist in forecast dataset. They will be dropped.")
quantiles = sorted(list(set(quantiles).intersection(set(existing_quantiles))))
quantiles = _select_quantiles(forecast_results, quantiles)

if train_ts is not None:
train_ts.df.sort_values(by="timestamp", inplace=True)
if test_ts is not None:
test_ts.df.sort_values(by="timestamp", inplace=True)
forecast_ts.df.sort_values(by="timestamp", inplace=True)

for i, segment in enumerate(segments):
if train_ts is not None:
Expand All @@ -124,54 +171,79 @@ def plot_forecast(
else:
plot_df = pd.DataFrame(columns=["timestamp", "target", "segment"])

segment_forecast_df = forecast_ts[:, segment, :][segment]

if (train_ts is not None) and (n_train_samples != 0):
ax[i].plot(plot_df.index.values, plot_df.target.values, label="train")
if test_ts is not None:
ax[i].plot(segment_test_df.index.values, segment_test_df.target.values, color="purple", label="test")
ax[i].plot(segment_forecast_df.index.values, segment_forecast_df.target.values, color="r", label="forecast")

if prediction_intervals and quantiles is not None:
alpha = np.linspace(0, 1, len(quantiles) // 2 + 2)[1:-1]
for quantile in range(len(quantiles) // 2):
values_low = segment_forecast_df["target_" + str(quantiles[quantile])].values
values_high = segment_forecast_df["target_" + str(quantiles[-quantile - 1])].values
if quantile == len(quantiles) // 2 - 1:
ax[i].fill_between(
segment_forecast_df.index.values,
values_low,
values_high,
facecolor="g",
alpha=alpha[quantile],
label=f"{quantiles[quantile]}-{quantiles[-quantile-1]} prediction interval",
)
else:
values_next = segment_forecast_df["target_" + str(quantiles[quantile + 1])].values
ax[i].fill_between(

# plot forecast plot for each of given forecasts
quantile_prefix = "target_"
for j, (forecast_name, forecast) in enumerate(forecast_results.items()):
legend_prefix = f"{forecast_name}: " if num_forecasts > 1 else ""

segment_forecast_df = forecast[:, segment, :][segment].sort_values(by="timestamp")
line = ax[i].plot(
segment_forecast_df.index.values,
segment_forecast_df.target.values,
linewidth=1,
label=f"{legend_prefix}forecast",
)
forecast_color = line[0].get_color()

# draw prediction intervals from outer layers to inner ones
if prediction_intervals and quantiles is not None:
alpha = np.linspace(0, 1 / 2, len(quantiles) // 2 + 2)[1:-1]
for quantile_idx in range(len(quantiles) // 2):
# define upper and lower border for this iteration
low_quantile = quantiles[quantile_idx]
high_quantile = quantiles[-quantile_idx - 1]
values_low = segment_forecast_df[f"{quantile_prefix}{low_quantile}"].values
values_high = segment_forecast_df[f"{quantile_prefix}{high_quantile}"].values
# if (low_quantile, high_quantile) is the smallest interval
if quantile_idx == len(quantiles) // 2 - 1:
ax[i].fill_between(
segment_forecast_df.index.values,
values_low,
values_high,
facecolor=forecast_color,
alpha=alpha[quantile_idx],
label=f"{legend_prefix}{low_quantile}-{high_quantile}",
)
# if there is some interval inside (low_quantile, high_quantile) we should plot around it
else:
low_next_quantile = quantiles[quantile_idx + 1]
high_prev_quantile = quantiles[-quantile_idx - 2]
values_next = segment_forecast_df[f"{quantile_prefix}{low_next_quantile}"].values
ax[i].fill_between(
segment_forecast_df.index.values,
values_low,
values_next,
facecolor=forecast_color,
alpha=alpha[quantile_idx],
label=f"{legend_prefix}{low_quantile}-{high_quantile}",
)
values_prev = segment_forecast_df[f"{quantile_prefix}{high_prev_quantile}"].values
ax[i].fill_between(
segment_forecast_df.index.values,
values_high,
values_prev,
facecolor=forecast_color,
alpha=alpha[quantile_idx],
)
# when we can't find pair quantile, we plot it separately
if len(quantiles) % 2 != 0:
remaining_quantile = quantiles[len(quantiles) // 2]
values = segment_forecast_df[f"{quantile_prefix}{remaining_quantile}"].values
ax[i].plot(
segment_forecast_df.index.values,
values_low,
values_next,
facecolor="g",
alpha=alpha[quantile],
label=f"{quantiles[quantile]}-{quantiles[-quantile-1]} prediction interval",
values,
"--",
color=forecast_color,
label=f"{legend_prefix}{remaining_quantile}",
)
values_prev = segment_forecast_df["target_" + str(quantiles[-quantile - 2])].values
ax[i].fill_between(
segment_forecast_df.index.values, values_high, values_prev, facecolor="g", alpha=alpha[quantile]
)
if len(quantiles) % 2 != 0:
values = segment_forecast_df["target_" + str(quantiles[len(quantiles) // 2])].values
ax[i].plot(
segment_forecast_df.index.values,
values,
"--",
c="orange",
label=f"{quantiles[len(quantiles)//2]} quantile",
)
ax[i].set_title(segment)
ax[i].tick_params("x", rotation=45)
ax[i].legend()
ax[i].legend(loc="upper left")


def plot_backtest(
Expand Down