Skip to content

Update plot_forecast for multi-forecast mode #584

Merged
merged 10 commits into from
Mar 14, 2022
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
-
- Add table option to ConsoleLogger ([#544](https://github.com/tinkoff-ai/etna/pull/544))
- Installation instruction ([#526](https://github.com/tinkoff-ai/etna/pull/526))
-
- Update plot_forecast for multi-forecast mode ([#584](https://github.com/tinkoff-ai/etna/pull/584))
- Trainer kwargs for deep models ([#540](https://github.com/tinkoff-ai/etna/pull/540))
- Update CONTRIBUTING.md ([#536](https://github.com/tinkoff-ai/etna/pull/536))
-
Expand Down
188 changes: 130 additions & 58 deletions etna/analysis/plotters.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import List
from typing import Optional
from typing import Sequence
from typing import Set
from typing import Tuple
from typing import Union

Expand Down Expand Up @@ -45,24 +46,71 @@ def prepare_axes(segments: List[str], columns_num: int, figsize: Tuple[int, int]
return ax


def _get_existing_quantiles(ts: "TSDataset") -> Set[float]:
"""Get quantiles that are present inside the TSDataset."""
cols = [col for col in ts.columns.get_level_values("feature").unique().tolist() if col.startswith("target_0.")]
existing_quantiles = {float(col[len("target_") :]) for col in cols}
return existing_quantiles


def _select_quantiles(forecast_results: Dict[str, "TSDataset"], quantiles: Optional[List[float]]) -> List[float]:
"""Select quantiles from the forecast results.

Selected quantiles exist in each forecast.
"""
intersection_quantiles_set = set.intersection(
*[_get_existing_quantiles(forecast) for forecast in forecast_results.values()]
)
intersection_quantiles = sorted(list(intersection_quantiles_set))

if quantiles is None:
selected_quantiles = intersection_quantiles
else:
selected_quantiles = sorted(list(set(quantiles) & intersection_quantiles_set))
non_existent = set(quantiles) - intersection_quantiles_set
if non_existent:
warnings.warn(f"Quantiles {non_existent} do not exist in each forecast dataset. They will be dropped.")

return selected_quantiles


def _prepare_forecast_results(
forecast_ts: Union["TSDataset", List["TSDataset"], Dict[str, "TSDataset"]]
) -> Dict[str, "TSDataset"]:
"""Prepare dictionary with forecasts results."""
from etna.datasets import TSDataset

if isinstance(forecast_ts, TSDataset):
return {"1": forecast_ts}
elif isinstance(forecast_ts, list) and len(forecast_ts) > 0:
return {str(i + 1): forecast for i, forecast in enumerate(forecast_ts)}
elif isinstance(forecast_ts, dict) and len(forecast_ts) > 0:
return forecast_ts
else:
raise ValueError("Unknown type of `forecast_ts`")


def plot_forecast(
forecast_ts: "TSDataset",
forecast_ts: Union["TSDataset", List["TSDataset"], Dict[str, "TSDataset"]],
test_ts: Optional["TSDataset"] = None,
train_ts: Optional["TSDataset"] = None,
segments: Optional[List[str]] = None,
n_train_samples: Optional[int] = None,
columns_num: int = 2,
figsize: Tuple[int, int] = (10, 5),
prediction_intervals: bool = False,
quantiles: Optional[Sequence[float]] = None,
quantiles: Optional[List[float]] = None,
):
"""
Plot of prediction for forecast pipeline.

Parameters
----------
forecast_ts:
forecasted TSDataset with timeseries data
there are several options:
1. Forecasted TSDataset with timeseries data, single-forecast mode
2. List of forecasted TSDatasets, multi-forecast mode
3. Dictionary with forecasted TSDatasets, multi-forecast mode
test_ts:
TSDataset with timeseries data
train_ts:
Expand All @@ -78,33 +126,32 @@ def plot_forecast(
prediction_intervals:
if True prediction intervals will be drawn
quantiles:
list of quantiles to draw
List of quantiles to draw, if isn't set then quantiles from a given dataset will be used.
In multi-forecast mode, only quantiles present in each forecast will be used.

Raises
------
ValueError:
if the format of `forecast_ts` is unknown
"""
forecast_results = _prepare_forecast_results(forecast_ts)
num_forecasts = len(forecast_results.keys())

if not segments:
segments = list(set(forecast_ts.columns.get_level_values("segment")))
unique_segments = set()
for forecast in forecast_results.values():
unique_segments.update(forecast.segments)
segments = list(unique_segments)

ax = prepare_axes(segments=segments, columns_num=columns_num, figsize=figsize)

if prediction_intervals:
cols = [
col
for col in forecast_ts.columns.get_level_values("feature").unique().tolist()
if col.startswith("target_0.")
]
existing_quantiles = [float(col[7:]) for col in cols]
if quantiles is None:
quantiles = sorted(existing_quantiles)
else:
non_existent = set(quantiles) - set(existing_quantiles)
if len(non_existent):
warnings.warn(f"Quantiles {non_existent} do not exist in forecast dataset. They will be dropped.")
quantiles = sorted(list(set(quantiles).intersection(set(existing_quantiles))))
quantiles = _select_quantiles(forecast_results, quantiles)

if train_ts is not None:
train_ts.df.sort_values(by="timestamp", inplace=True)
if test_ts is not None:
test_ts.df.sort_values(by="timestamp", inplace=True)
forecast_ts.df.sort_values(by="timestamp", inplace=True)

for i, segment in enumerate(segments):
if train_ts is not None:
Expand All @@ -124,51 +171,76 @@ def plot_forecast(
else:
plot_df = pd.DataFrame(columns=["timestamp", "target", "segment"])

segment_forecast_df = forecast_ts[:, segment, :][segment]

if (train_ts is not None) and (n_train_samples != 0):
ax[i].plot(plot_df.index.values, plot_df.target.values, label="train")
if test_ts is not None:
ax[i].plot(segment_test_df.index.values, segment_test_df.target.values, color="purple", label="test")
ax[i].plot(segment_forecast_df.index.values, segment_forecast_df.target.values, color="r", label="forecast")

if prediction_intervals and quantiles is not None:
alpha = np.linspace(0, 1, len(quantiles) // 2 + 2)[1:-1]
for quantile in range(len(quantiles) // 2):
values_low = segment_forecast_df["target_" + str(quantiles[quantile])].values
values_high = segment_forecast_df["target_" + str(quantiles[-quantile - 1])].values
if quantile == len(quantiles) // 2 - 1:
ax[i].fill_between(
segment_forecast_df.index.values,
values_low,
values_high,
facecolor="g",
alpha=alpha[quantile],
label=f"{quantiles[quantile]}-{quantiles[-quantile-1]} prediction interval",
)
else:
values_next = segment_forecast_df["target_" + str(quantiles[quantile + 1])].values
ax[i].fill_between(

# plot forecast plot for each of given forecasts
quantile_prefix = "target_"
for j, (forecast_name, forecast) in enumerate(forecast_results.items()):
legend_prefix = f"{forecast_name}: " if num_forecasts > 1 else ""

segment_forecast_df = forecast[:, segment, :][segment].sort_values(by="timestamp")
line = ax[i].plot(
segment_forecast_df.index.values,
segment_forecast_df.target.values,
linewidth=1,
label=f"{legend_prefix}forecast",
)
forecast_color = line[0].get_color()

# draw prediction intervals from outer layers to inner ones
if prediction_intervals and quantiles is not None:
alpha = np.linspace(0, 1 / 2, len(quantiles) // 2 + 2)[1:-1]
for quantile_idx in range(len(quantiles) // 2):
# define upper and lower border for this iteration
low_quantile = quantiles[quantile_idx]
high_quantile = quantiles[-quantile_idx - 1]
values_low = segment_forecast_df[f"{quantile_prefix}{low_quantile}"].values
values_high = segment_forecast_df[f"{quantile_prefix}{high_quantile}"].values
# if (low_quantile, high_quantile) is the smallest interval
if quantile_idx == len(quantiles) // 2 - 1:
ax[i].fill_between(
segment_forecast_df.index.values,
values_low,
values_high,
facecolor=forecast_color,
alpha=alpha[quantile_idx],
label=f"{legend_prefix}{low_quantile}-{high_quantile}",
)
# if there is some interval inside (low_quantile, high_quantile) we should plot around it
else:
low_next_quantile = quantiles[quantile_idx + 1]
high_prev_quantile = quantiles[-quantile_idx - 2]
values_next = segment_forecast_df[f"{quantile_prefix}{low_next_quantile}"].values
ax[i].fill_between(
segment_forecast_df.index.values,
values_low,
values_next,
facecolor=forecast_color,
alpha=alpha[quantile_idx],
label=f"{legend_prefix}{low_quantile}-{high_quantile}",
)
values_prev = segment_forecast_df[f"{quantile_prefix}{high_prev_quantile}"].values
ax[i].fill_between(
segment_forecast_df.index.values,
values_high,
values_prev,
facecolor=forecast_color,
alpha=alpha[quantile_idx],
)
# when we can't find pair quantile, we plot it separately
if len(quantiles) % 2 != 0:
remaining_quantile = quantiles[len(quantiles) // 2]
values = segment_forecast_df[f"{quantile_prefix}{remaining_quantile}"].values
ax[i].plot(
segment_forecast_df.index.values,
values_low,
values_next,
facecolor="g",
alpha=alpha[quantile],
label=f"{quantiles[quantile]}-{quantiles[-quantile-1]} prediction interval",
)
values_prev = segment_forecast_df["target_" + str(quantiles[-quantile - 2])].values
ax[i].fill_between(
segment_forecast_df.index.values, values_high, values_prev, facecolor="g", alpha=alpha[quantile]
values,
"--",
color=forecast_color,
label=f"{legend_prefix}{remaining_quantile}",
)
if len(quantiles) % 2 != 0:
values = segment_forecast_df["target_" + str(quantiles[len(quantiles) // 2])].values
ax[i].plot(
segment_forecast_df.index.values,
values,
"--",
c="orange",
label=f"{quantiles[len(quantiles)//2]} quantile",
)
ax[i].set_title(segment)
ax[i].tick_params("x", rotation=45)
ax[i].legend()
Expand Down