diff --git a/CHANGELOG.md b/CHANGELOG.md index d82ce1423..621970f67 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -41,7 +41,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - - Add table option to ConsoleLogger ([#544](https://github.com/tinkoff-ai/etna/pull/544)) - Installation instruction ([#526](https://github.com/tinkoff-ai/etna/pull/526)) -- +- Update plot_forecast for multi-forecast mode ([#584](https://github.com/tinkoff-ai/etna/pull/584)) - Trainer kwargs for deep models ([#540](https://github.com/tinkoff-ai/etna/pull/540)) - Update CONTRIBUTING.md ([#536](https://github.com/tinkoff-ai/etna/pull/536)) - diff --git a/etna/analysis/plotters.py b/etna/analysis/plotters.py index daf5ac2c8..054ed508d 100644 --- a/etna/analysis/plotters.py +++ b/etna/analysis/plotters.py @@ -8,6 +8,7 @@ from typing import List from typing import Optional from typing import Sequence +from typing import Set from typing import Tuple from typing import Union @@ -45,8 +46,52 @@ def prepare_axes(segments: List[str], columns_num: int, figsize: Tuple[int, int] return ax +def _get_existing_quantiles(ts: "TSDataset") -> Set[float]: + """Get quantiles that are present inside the TSDataset.""" + cols = [col for col in ts.columns.get_level_values("feature").unique().tolist() if col.startswith("target_0.")] + existing_quantiles = {float(col[len("target_") :]) for col in cols} + return existing_quantiles + + +def _select_quantiles(forecast_results: Dict[str, "TSDataset"], quantiles: Optional[List[float]]) -> List[float]: + """Select quantiles from the forecast results. + + Selected quantiles exist in each forecast. + """ + intersection_quantiles_set = set.intersection( + *[_get_existing_quantiles(forecast) for forecast in forecast_results.values()] + ) + intersection_quantiles = sorted(list(intersection_quantiles_set)) + + if quantiles is None: + selected_quantiles = intersection_quantiles + else: + selected_quantiles = sorted(list(set(quantiles) & intersection_quantiles_set)) + non_existent = set(quantiles) - intersection_quantiles_set + if non_existent: + warnings.warn(f"Quantiles {non_existent} do not exist in each forecast dataset. They will be dropped.") + + return selected_quantiles + + +def _prepare_forecast_results( + forecast_ts: Union["TSDataset", List["TSDataset"], Dict[str, "TSDataset"]] +) -> Dict[str, "TSDataset"]: + """Prepare dictionary with forecasts results.""" + from etna.datasets import TSDataset + + if isinstance(forecast_ts, TSDataset): + return {"1": forecast_ts} + elif isinstance(forecast_ts, list) and len(forecast_ts) > 0: + return {str(i + 1): forecast for i, forecast in enumerate(forecast_ts)} + elif isinstance(forecast_ts, dict) and len(forecast_ts) > 0: + return forecast_ts + else: + raise ValueError("Unknown type of `forecast_ts`") + + def plot_forecast( - forecast_ts: "TSDataset", + forecast_ts: Union["TSDataset", List["TSDataset"], Dict[str, "TSDataset"]], test_ts: Optional["TSDataset"] = None, train_ts: Optional["TSDataset"] = None, segments: Optional[List[str]] = None, @@ -54,7 +99,7 @@ def plot_forecast( columns_num: int = 2, figsize: Tuple[int, int] = (10, 5), prediction_intervals: bool = False, - quantiles: Optional[Sequence[float]] = None, + quantiles: Optional[List[float]] = None, ): """ Plot of prediction for forecast pipeline. @@ -62,7 +107,10 @@ def plot_forecast( Parameters ---------- forecast_ts: - forecasted TSDataset with timeseries data + there are several options: + 1. Forecasted TSDataset with timeseries data, single-forecast mode + 2. List of forecasted TSDatasets, multi-forecast mode + 3. Dictionary with forecasted TSDatasets, multi-forecast mode test_ts: TSDataset with timeseries data train_ts: @@ -78,33 +126,32 @@ def plot_forecast( prediction_intervals: if True prediction intervals will be drawn quantiles: - list of quantiles to draw + List of quantiles to draw, if isn't set then quantiles from a given dataset will be used. + In multi-forecast mode, only quantiles present in each forecast will be used. + + Raises + ------ + ValueError: + if the format of `forecast_ts` is unknown """ + forecast_results = _prepare_forecast_results(forecast_ts) + num_forecasts = len(forecast_results.keys()) + if not segments: - segments = list(set(forecast_ts.columns.get_level_values("segment"))) + unique_segments = set() + for forecast in forecast_results.values(): + unique_segments.update(forecast.segments) + segments = list(unique_segments) ax = prepare_axes(segments=segments, columns_num=columns_num, figsize=figsize) if prediction_intervals: - cols = [ - col - for col in forecast_ts.columns.get_level_values("feature").unique().tolist() - if col.startswith("target_0.") - ] - existing_quantiles = [float(col[7:]) for col in cols] - if quantiles is None: - quantiles = sorted(existing_quantiles) - else: - non_existent = set(quantiles) - set(existing_quantiles) - if len(non_existent): - warnings.warn(f"Quantiles {non_existent} do not exist in forecast dataset. They will be dropped.") - quantiles = sorted(list(set(quantiles).intersection(set(existing_quantiles)))) + quantiles = _select_quantiles(forecast_results, quantiles) if train_ts is not None: train_ts.df.sort_values(by="timestamp", inplace=True) if test_ts is not None: test_ts.df.sort_values(by="timestamp", inplace=True) - forecast_ts.df.sort_values(by="timestamp", inplace=True) for i, segment in enumerate(segments): if train_ts is not None: @@ -124,54 +171,79 @@ def plot_forecast( else: plot_df = pd.DataFrame(columns=["timestamp", "target", "segment"]) - segment_forecast_df = forecast_ts[:, segment, :][segment] - if (train_ts is not None) and (n_train_samples != 0): ax[i].plot(plot_df.index.values, plot_df.target.values, label="train") if test_ts is not None: ax[i].plot(segment_test_df.index.values, segment_test_df.target.values, color="purple", label="test") - ax[i].plot(segment_forecast_df.index.values, segment_forecast_df.target.values, color="r", label="forecast") - - if prediction_intervals and quantiles is not None: - alpha = np.linspace(0, 1, len(quantiles) // 2 + 2)[1:-1] - for quantile in range(len(quantiles) // 2): - values_low = segment_forecast_df["target_" + str(quantiles[quantile])].values - values_high = segment_forecast_df["target_" + str(quantiles[-quantile - 1])].values - if quantile == len(quantiles) // 2 - 1: - ax[i].fill_between( - segment_forecast_df.index.values, - values_low, - values_high, - facecolor="g", - alpha=alpha[quantile], - label=f"{quantiles[quantile]}-{quantiles[-quantile-1]} prediction interval", - ) - else: - values_next = segment_forecast_df["target_" + str(quantiles[quantile + 1])].values - ax[i].fill_between( + + # plot forecast plot for each of given forecasts + quantile_prefix = "target_" + for j, (forecast_name, forecast) in enumerate(forecast_results.items()): + legend_prefix = f"{forecast_name}: " if num_forecasts > 1 else "" + + segment_forecast_df = forecast[:, segment, :][segment].sort_values(by="timestamp") + line = ax[i].plot( + segment_forecast_df.index.values, + segment_forecast_df.target.values, + linewidth=1, + label=f"{legend_prefix}forecast", + ) + forecast_color = line[0].get_color() + + # draw prediction intervals from outer layers to inner ones + if prediction_intervals and quantiles is not None: + alpha = np.linspace(0, 1 / 2, len(quantiles) // 2 + 2)[1:-1] + for quantile_idx in range(len(quantiles) // 2): + # define upper and lower border for this iteration + low_quantile = quantiles[quantile_idx] + high_quantile = quantiles[-quantile_idx - 1] + values_low = segment_forecast_df[f"{quantile_prefix}{low_quantile}"].values + values_high = segment_forecast_df[f"{quantile_prefix}{high_quantile}"].values + # if (low_quantile, high_quantile) is the smallest interval + if quantile_idx == len(quantiles) // 2 - 1: + ax[i].fill_between( + segment_forecast_df.index.values, + values_low, + values_high, + facecolor=forecast_color, + alpha=alpha[quantile_idx], + label=f"{legend_prefix}{low_quantile}-{high_quantile}", + ) + # if there is some interval inside (low_quantile, high_quantile) we should plot around it + else: + low_next_quantile = quantiles[quantile_idx + 1] + high_prev_quantile = quantiles[-quantile_idx - 2] + values_next = segment_forecast_df[f"{quantile_prefix}{low_next_quantile}"].values + ax[i].fill_between( + segment_forecast_df.index.values, + values_low, + values_next, + facecolor=forecast_color, + alpha=alpha[quantile_idx], + label=f"{legend_prefix}{low_quantile}-{high_quantile}", + ) + values_prev = segment_forecast_df[f"{quantile_prefix}{high_prev_quantile}"].values + ax[i].fill_between( + segment_forecast_df.index.values, + values_high, + values_prev, + facecolor=forecast_color, + alpha=alpha[quantile_idx], + ) + # when we can't find pair quantile, we plot it separately + if len(quantiles) % 2 != 0: + remaining_quantile = quantiles[len(quantiles) // 2] + values = segment_forecast_df[f"{quantile_prefix}{remaining_quantile}"].values + ax[i].plot( segment_forecast_df.index.values, - values_low, - values_next, - facecolor="g", - alpha=alpha[quantile], - label=f"{quantiles[quantile]}-{quantiles[-quantile-1]} prediction interval", + values, + "--", + color=forecast_color, + label=f"{legend_prefix}{remaining_quantile}", ) - values_prev = segment_forecast_df["target_" + str(quantiles[-quantile - 2])].values - ax[i].fill_between( - segment_forecast_df.index.values, values_high, values_prev, facecolor="g", alpha=alpha[quantile] - ) - if len(quantiles) % 2 != 0: - values = segment_forecast_df["target_" + str(quantiles[len(quantiles) // 2])].values - ax[i].plot( - segment_forecast_df.index.values, - values, - "--", - c="orange", - label=f"{quantiles[len(quantiles)//2]} quantile", - ) ax[i].set_title(segment) ax[i].tick_params("x", rotation=45) - ax[i].legend() + ax[i].legend(loc="upper left") def plot_backtest(