From 45dc760f70d2e81ed25a1d49901eb0b5ed44be13 Mon Sep 17 00:00:00 2001 From: "d.a.bunin" Date: Sat, 5 Mar 2022 11:39:27 +0300 Subject: [PATCH 1/8] Update plot_forecast for multi-forecast mode --- etna/analysis/plotters.py | 207 +++++++++++++++++++++++++++----------- 1 file changed, 149 insertions(+), 58 deletions(-) diff --git a/etna/analysis/plotters.py b/etna/analysis/plotters.py index 9b067fdce..9b6015ed8 100644 --- a/etna/analysis/plotters.py +++ b/etna/analysis/plotters.py @@ -7,6 +7,7 @@ from typing import List from typing import Optional from typing import Sequence +from typing import Set from typing import Tuple from typing import Union @@ -37,8 +38,74 @@ def prepare_axes(segments: List[str], columns_num: int, figsize: Tuple[int, int] return ax +def _get_existing_quantiles(ts: "TSDataset") -> Set[float]: + """Get quantiles that are present inside the TSDataset.""" + cols = [col for col in ts.columns.get_level_values("feature").unique().tolist() if col.startswith("target_0.")] + existing_quantiles = {float(col[len("target_") :]) for col in cols} + return existing_quantiles + + +def _select_quantiles(forecast_results: Dict[str, "TSDataset"], quantiles: Optional[List[float]]) -> List[float]: + """Select quantiles from the forecast results.""" + # restrict to use only two quantiles, in other case the plot will be messy + if len(forecast_results) > 1: + if quantiles is None: + unique_quantiles = set() + for forecast in forecast_results.values(): + unique_quantiles.update(_get_existing_quantiles(forecast)) + sorted_quantiles = sorted(list(unique_quantiles)) + + else: + sorted_quantiles = sorted(list(set(quantiles))) + + if len(sorted_quantiles) == 1: + selected_quantiles = [] + warnings.warn("In multi-forecast mode the minimum number of quantiles is two.") + elif len(sorted_quantiles) > 2: + selected_quantiles = [sorted_quantiles[0], sorted_quantiles[-1]] + warnings.warn(f"In multi-forecast mode only two quantiles will be used: {quantiles}") + else: + selected_quantiles = sorted_quantiles + + # check if selected quantiles present in every forecast + for forecast_name, forecast in forecast_results.items(): + existing_quantiles = _get_existing_quantiles(forecast) + if not set(selected_quantiles).issubset(existing_quantiles): + raise ValueError(f"Chosen quantiles {quantiles} isn't present in forecast {forecast_name}") + + return selected_quantiles + + # in this case we have only one forecast + else: + forecast = list(forecast_results.values())[0] + existing_quantiles = _get_existing_quantiles(forecast) + if quantiles is None: + return sorted(list(existing_quantiles)) + else: + non_existent = set(quantiles) - existing_quantiles + if len(non_existent): + warnings.warn(f"Quantiles {non_existent} do not exist in forecast dataset. They will be dropped.") + return sorted(list(set(quantiles).intersection(existing_quantiles))) + + +def _prepare_forecast_results( + forecast_ts: Union["TSDataset", List["TSDataset"], Dict[str, "TSDataset"]] +) -> Dict[str, "TSDataset"]: + """Prepare dictionary with forecasts results.""" + from etna.datasets import TSDataset + + if isinstance(forecast_ts, TSDataset): + return {"1": forecast_ts} + elif isinstance(forecast_ts, list) and len(forecast_ts) > 0: + return {str(i + 1): forecast for i, forecast in enumerate(forecast_ts)} + elif isinstance(forecast_ts, dict) and len(forecast_ts) > 0: + return forecast_ts + else: + raise ValueError("Unknown type of `forecast_ts`") + + def plot_forecast( - forecast_ts: "TSDataset", + forecast_ts: Union["TSDataset", List["TSDataset"], Dict[str, "TSDataset"]], test_ts: Optional["TSDataset"] = None, train_ts: Optional["TSDataset"] = None, segments: Optional[List[str]] = None, @@ -46,7 +113,7 @@ def plot_forecast( columns_num: int = 2, figsize: Tuple[int, int] = (10, 5), prediction_intervals: bool = False, - quantiles: Optional[Sequence[float]] = None, + quantiles: Optional[List[float]] = None, ): """ Plot of prediction for forecast pipeline. @@ -54,7 +121,10 @@ def plot_forecast( Parameters ---------- forecast_ts: - forecasted TSDataset with timeseries data + there are several options: + 1. Forecasted TSDataset with timeseries data, single-forecast mode + 2. List of forecasted TSDatasets, multi-forecast mode + 3. Dictionary with forecasted TSDatasets, multi-forecast mode test_ts: TSDataset with timeseries data train_ts: @@ -70,33 +140,36 @@ def plot_forecast( prediction_intervals: if True prediction intervals will be drawn quantiles: - list of quantiles to draw + List of quantiles to draw, if isn't set then quantiles from a given dataset will be used. + * In a single-forecast mode intervals will be drawn in layers. + The remaining quantile (if present) will be drawn as a line. + * In a multi-forecast mode only the highest and the lowest pair of quantiles will be used. + + Raises + ------ + ValueError: + if the format of `forecast_ts` is unknown + ValueError: + if in multi-forecast mode given quantiles isn't present in each dataset """ + forecast_results = _prepare_forecast_results(forecast_ts) + num_forecasts = len(forecast_results.keys()) + if not segments: - segments = list(set(forecast_ts.columns.get_level_values("segment"))) + unique_segments = set() + for forecast in forecast_results.values(): + unique_segments.update(forecast.segments) + segments = list(unique_segments) ax = prepare_axes(segments=segments, columns_num=columns_num, figsize=figsize) if prediction_intervals: - cols = [ - col - for col in forecast_ts.columns.get_level_values("feature").unique().tolist() - if col.startswith("target_0.") - ] - existing_quantiles = [float(col[7:]) for col in cols] - if quantiles is None: - quantiles = sorted(existing_quantiles) - else: - non_existent = set(quantiles) - set(existing_quantiles) - if len(non_existent): - warnings.warn(f"Quantiles {non_existent} do not exist in forecast dataset. They will be dropped.") - quantiles = sorted(list(set(quantiles).intersection(set(existing_quantiles)))) + quantiles = _select_quantiles(forecast_results, quantiles) if train_ts is not None: train_ts.df.sort_values(by="timestamp", inplace=True) if test_ts is not None: test_ts.df.sort_values(by="timestamp", inplace=True) - forecast_ts.df.sort_values(by="timestamp", inplace=True) for i, segment in enumerate(segments): if train_ts is not None: @@ -116,51 +189,69 @@ def plot_forecast( else: plot_df = pd.DataFrame(columns=["timestamp", "target", "segment"]) - segment_forecast_df = forecast_ts[:, segment, :][segment] - if (train_ts is not None) and (n_train_samples != 0): ax[i].plot(plot_df.index.values, plot_df.target.values, label="train") if test_ts is not None: ax[i].plot(segment_test_df.index.values, segment_test_df.target.values, color="purple", label="test") - ax[i].plot(segment_forecast_df.index.values, segment_forecast_df.target.values, color="r", label="forecast") - - if prediction_intervals and quantiles is not None: - alpha = np.linspace(0, 1, len(quantiles) // 2 + 2)[1:-1] - for quantile in range(len(quantiles) // 2): - values_low = segment_forecast_df["target_" + str(quantiles[quantile])].values - values_high = segment_forecast_df["target_" + str(quantiles[-quantile - 1])].values - if quantile == len(quantiles) // 2 - 1: - ax[i].fill_between( - segment_forecast_df.index.values, - values_low, - values_high, - facecolor="g", - alpha=alpha[quantile], - label=f"{quantiles[quantile]}-{quantiles[-quantile-1]} prediction interval", - ) - else: - values_next = segment_forecast_df["target_" + str(quantiles[quantile + 1])].values - ax[i].fill_between( + + # plot forecast plot for each of given forecasts + colors = plt.cm.Set2.colors + for j, (forecast_name, forecast) in enumerate(forecast_results.items()): + legend_prefix = f"{forecast_name}: " if num_forecasts > 1 else "" + forecast_color = colors[j % len(colors)] + + segment_forecast_df = forecast[:, segment, :][segment].sort_values(by="timestamp") + ax[i].plot( + segment_forecast_df.index.values, + segment_forecast_df.target.values, + color=forecast_color, + linewidth=1, + label=f"{legend_prefix}forecast", + ) + + # draw prediction intervals from outer layers to inner ones + if prediction_intervals and quantiles is not None: + alpha = np.linspace(0, 1 / 2, len(quantiles) // 2 + 2)[1:-1] + for quantile in range(len(quantiles) // 2): + values_low = segment_forecast_df["target_" + str(quantiles[quantile])].values + values_high = segment_forecast_df["target_" + str(quantiles[-quantile - 1])].values + if quantile == len(quantiles) // 2 - 1: + ax[i].fill_between( + segment_forecast_df.index.values, + values_low, + values_high, + facecolor=forecast_color, + alpha=alpha[quantile], + label=f"{legend_prefix}{quantiles[quantile]}-{quantiles[-quantile-1]} prediction interval", + ) + else: + values_next = segment_forecast_df["target_" + str(quantiles[quantile + 1])].values + ax[i].fill_between( + segment_forecast_df.index.values, + values_low, + values_next, + facecolor=forecast_color, + alpha=alpha[quantile], + label=f"{legend_prefix}{quantiles[quantile]}-{quantiles[-quantile-1]} prediction interval", + ) + values_prev = segment_forecast_df["target_" + str(quantiles[-quantile - 2])].values + ax[i].fill_between( + segment_forecast_df.index.values, + values_high, + values_prev, + facecolor=forecast_color, + alpha=alpha[quantile], + ) + # this isn't a possible option for `num_forecasts > 1` + if len(quantiles) % 2 != 0: + values = segment_forecast_df["target_" + str(quantiles[len(quantiles) // 2])].values + ax[i].plot( segment_forecast_df.index.values, - values_low, - values_next, - facecolor="g", - alpha=alpha[quantile], - label=f"{quantiles[quantile]}-{quantiles[-quantile-1]} prediction interval", + values, + "--", + c="orange", + label=f"{legend_prefix}{quantiles[len(quantiles)//2]} quantile", ) - values_prev = segment_forecast_df["target_" + str(quantiles[-quantile - 2])].values - ax[i].fill_between( - segment_forecast_df.index.values, values_high, values_prev, facecolor="g", alpha=alpha[quantile] - ) - if len(quantiles) % 2 != 0: - values = segment_forecast_df["target_" + str(quantiles[len(quantiles) // 2])].values - ax[i].plot( - segment_forecast_df.index.values, - values, - "--", - c="orange", - label=f"{quantiles[len(quantiles)//2]} quantile", - ) ax[i].set_title(segment) ax[i].tick_params("x", rotation=45) ax[i].legend() From 3329be9c05624c0128bff3da4574b3a959588e9d Mon Sep 17 00:00:00 2001 From: "d.a.bunin" Date: Sat, 5 Mar 2022 11:41:41 +0300 Subject: [PATCH 2/8] Update changelog --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6f26e6f93..44f327679 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -37,7 +37,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - - Add table option to ConsoleLogger ([#544](https://github.com/tinkoff-ai/etna/pull/544)) - Installation instruction ([#526](https://github.com/tinkoff-ai/etna/pull/526)) -- +- Update plot_forecast for multi-forecast mode ([#584](https://github.com/tinkoff-ai/etna/pull/584)) - Trainer kwargs for deep models ([#540](https://github.com/tinkoff-ai/etna/pull/540)) - Update CONTRIBUTING.md ([#536](https://github.com/tinkoff-ai/etna/pull/536)) - From fe147af42ee57bff222516a74f6f42d42e960415 Mon Sep 17 00:00:00 2001 From: "d.a.bunin" Date: Thu, 10 Mar 2022 13:40:20 +0300 Subject: [PATCH 3/8] Fix issues from PR --- etna/analysis/plotters.py | 103 ++++++++++++++++---------------------- 1 file changed, 42 insertions(+), 61 deletions(-) diff --git a/etna/analysis/plotters.py b/etna/analysis/plotters.py index ae6c64078..1ec20b2a2 100644 --- a/etna/analysis/plotters.py +++ b/etna/analysis/plotters.py @@ -54,46 +54,24 @@ def _get_existing_quantiles(ts: "TSDataset") -> Set[float]: def _select_quantiles(forecast_results: Dict[str, "TSDataset"], quantiles: Optional[List[float]]) -> List[float]: - """Select quantiles from the forecast results.""" - # restrict to use only two quantiles, in other case the plot will be messy - if len(forecast_results) > 1: - if quantiles is None: - unique_quantiles = set() - for forecast in forecast_results.values(): - unique_quantiles.update(_get_existing_quantiles(forecast)) - sorted_quantiles = sorted(list(unique_quantiles)) + """Select quantiles from the forecast results. - else: - sorted_quantiles = sorted(list(set(quantiles))) - - if len(sorted_quantiles) == 1: - selected_quantiles = [] - warnings.warn("In multi-forecast mode the minimum number of quantiles is two.") - elif len(sorted_quantiles) > 2: - selected_quantiles = [sorted_quantiles[0], sorted_quantiles[-1]] - warnings.warn(f"In multi-forecast mode only two quantiles will be used: {quantiles}") - else: - selected_quantiles = sorted_quantiles - - # check if selected quantiles present in every forecast - for forecast_name, forecast in forecast_results.items(): - existing_quantiles = _get_existing_quantiles(forecast) - if not set(selected_quantiles).issubset(existing_quantiles): - raise ValueError(f"Chosen quantiles {quantiles} isn't present in forecast {forecast_name}") - - return selected_quantiles + Selected quantiles exist in each forecast. + """ + intersection_quantiles_set = set.intersection( + *[_get_existing_quantiles(forecast) for forecast in forecast_results.values()] + ) + intersection_quantiles = sorted(list(intersection_quantiles_set)) - # in this case we have only one forecast + if quantiles is None: + selected_quantiles = intersection_quantiles else: - forecast = list(forecast_results.values())[0] - existing_quantiles = _get_existing_quantiles(forecast) - if quantiles is None: - return sorted(list(existing_quantiles)) - else: - non_existent = set(quantiles) - existing_quantiles - if len(non_existent): - warnings.warn(f"Quantiles {non_existent} do not exist in forecast dataset. They will be dropped.") - return sorted(list(set(quantiles).intersection(existing_quantiles))) + selected_quantiles = sorted(list(set(quantiles) & intersection_quantiles_set)) + non_existent = set(quantiles) - intersection_quantiles_set + if non_existent: + warnings.warn(f"Quantiles {non_existent} do not exist in each forecast dataset. They will be dropped.") + + return selected_quantiles def _prepare_forecast_results( @@ -149,16 +127,12 @@ def plot_forecast( if True prediction intervals will be drawn quantiles: List of quantiles to draw, if isn't set then quantiles from a given dataset will be used. - * In a single-forecast mode intervals will be drawn in layers. - The remaining quantile (if present) will be drawn as a line. - * In a multi-forecast mode only the highest and the lowest pair of quantiles will be used. + In multi-forecast mode, only quantiles present in each forecast will be used. Raises ------ ValueError: if the format of `forecast_ts` is unknown - ValueError: - if in multi-forecast mode given quantiles isn't present in each dataset """ forecast_results = _prepare_forecast_results(forecast_ts) num_forecasts = len(forecast_results.keys()) @@ -203,62 +177,69 @@ def plot_forecast( ax[i].plot(segment_test_df.index.values, segment_test_df.target.values, color="purple", label="test") # plot forecast plot for each of given forecasts - colors = plt.cm.Set2.colors + quantile_prefix = "target_" for j, (forecast_name, forecast) in enumerate(forecast_results.items()): legend_prefix = f"{forecast_name}: " if num_forecasts > 1 else "" - forecast_color = colors[j % len(colors)] segment_forecast_df = forecast[:, segment, :][segment].sort_values(by="timestamp") - ax[i].plot( + line = ax[i].plot( segment_forecast_df.index.values, segment_forecast_df.target.values, - color=forecast_color, linewidth=1, label=f"{legend_prefix}forecast", ) + forecast_color = line[0].get_color() # draw prediction intervals from outer layers to inner ones if prediction_intervals and quantiles is not None: alpha = np.linspace(0, 1 / 2, len(quantiles) // 2 + 2)[1:-1] - for quantile in range(len(quantiles) // 2): - values_low = segment_forecast_df["target_" + str(quantiles[quantile])].values - values_high = segment_forecast_df["target_" + str(quantiles[-quantile - 1])].values - if quantile == len(quantiles) // 2 - 1: + for quantile_idx in range(len(quantiles) // 2): + # define upper and lower border for this iteration + low_quantile = quantiles[quantile_idx] + high_quantile = quantiles[-quantile_idx - 1] + values_low = segment_forecast_df[f"{quantile_prefix}{low_quantile}"].values + values_high = segment_forecast_df[f"{quantile_prefix}{high_quantile}"].values + # if (low_quantile, high_quantile) is the smallest interval + if quantile_idx == len(quantiles) // 2 - 1: ax[i].fill_between( segment_forecast_df.index.values, values_low, values_high, facecolor=forecast_color, - alpha=alpha[quantile], - label=f"{legend_prefix}{quantiles[quantile]}-{quantiles[-quantile-1]} prediction interval", + alpha=alpha[quantile_idx], + label=f"{legend_prefix}{low_quantile}-{high_quantile} prediction interval", ) + # if there is some interval inside (low_quantile, high_quantile) we should plot around it else: - values_next = segment_forecast_df["target_" + str(quantiles[quantile + 1])].values + low_next_quantile = quantiles[quantile_idx + 1] + high_prev_quantile = quantiles[-quantile_idx - 2] + values_next = segment_forecast_df[f"{quantile_prefix}{low_next_quantile}"].values ax[i].fill_between( segment_forecast_df.index.values, values_low, values_next, facecolor=forecast_color, - alpha=alpha[quantile], - label=f"{legend_prefix}{quantiles[quantile]}-{quantiles[-quantile-1]} prediction interval", + alpha=alpha[quantile_idx], + label=f"{legend_prefix}{low_quantile}-{high_quantile} prediction interval", ) - values_prev = segment_forecast_df["target_" + str(quantiles[-quantile - 2])].values + values_prev = segment_forecast_df[f"{quantile_prefix}{high_prev_quantile}"].values ax[i].fill_between( segment_forecast_df.index.values, values_high, values_prev, facecolor=forecast_color, - alpha=alpha[quantile], + alpha=alpha[quantile_idx], ) - # this isn't a possible option for `num_forecasts > 1` + # when we can't find pair quantile, we plot it separately if len(quantiles) % 2 != 0: - values = segment_forecast_df["target_" + str(quantiles[len(quantiles) // 2])].values + remaining_quantile = quantiles[len(quantiles) // 2] + values = segment_forecast_df[f"{quantile_prefix}{remaining_quantile}"].values ax[i].plot( segment_forecast_df.index.values, values, "--", - c="orange", - label=f"{legend_prefix}{quantiles[len(quantiles)//2]} quantile", + color=forecast_color, + label=f"{legend_prefix}{remaining_quantile} quantile", ) ax[i].set_title(segment) ax[i].tick_params("x", rotation=45) From ac4e0d610cd512ea1cde2b2033fda0cac07e787e Mon Sep 17 00:00:00 2001 From: "d.a.bunin" Date: Mon, 14 Mar 2022 16:30:57 +0300 Subject: [PATCH 4/8] Move legend to the right of the plot --- etna/analysis/plotters.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/etna/analysis/plotters.py b/etna/analysis/plotters.py index 1ec20b2a2..115e4dbbc 100644 --- a/etna/analysis/plotters.py +++ b/etna/analysis/plotters.py @@ -243,7 +243,7 @@ def plot_forecast( ) ax[i].set_title(segment) ax[i].tick_params("x", rotation=45) - ax[i].legend() + ax[i].legend(bbox_to_anchor=(1.02, 1.0)) def plot_backtest( From 775c4da2ea127a0e7125cf3bf301faed5223face Mon Sep 17 00:00:00 2001 From: "d.a.bunin" Date: Mon, 14 Mar 2022 16:39:45 +0300 Subject: [PATCH 5/8] Fix location --- etna/analysis/plotters.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/etna/analysis/plotters.py b/etna/analysis/plotters.py index 115e4dbbc..9bb7bbf16 100644 --- a/etna/analysis/plotters.py +++ b/etna/analysis/plotters.py @@ -243,7 +243,7 @@ def plot_forecast( ) ax[i].set_title(segment) ax[i].tick_params("x", rotation=45) - ax[i].legend(bbox_to_anchor=(1.02, 1.0)) + ax[i].legend(bbox_to_anchor=(1.02, 1.0), loc="upper left") def plot_backtest( From 25fcee07095ae88ab5ab2412e1c43fed9f5ef90b Mon Sep 17 00:00:00 2001 From: "d.a.bunin" Date: Mon, 14 Mar 2022 17:00:06 +0300 Subject: [PATCH 6/8] Get legend back --- etna/analysis/plotters.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/etna/analysis/plotters.py b/etna/analysis/plotters.py index 9bb7bbf16..1ec20b2a2 100644 --- a/etna/analysis/plotters.py +++ b/etna/analysis/plotters.py @@ -243,7 +243,7 @@ def plot_forecast( ) ax[i].set_title(segment) ax[i].tick_params("x", rotation=45) - ax[i].legend(bbox_to_anchor=(1.02, 1.0), loc="upper left") + ax[i].legend() def plot_backtest( From 8617e355304432eb935383b3340a4f61e003db6c Mon Sep 17 00:00:00 2001 From: "d.a.bunin" Date: Mon, 14 Mar 2022 17:11:17 +0300 Subject: [PATCH 7/8] Remove suffix about quantile and prediction interval from the legend --- etna/analysis/plotters.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/etna/analysis/plotters.py b/etna/analysis/plotters.py index 1ec20b2a2..d2db6b5f7 100644 --- a/etna/analysis/plotters.py +++ b/etna/analysis/plotters.py @@ -207,7 +207,7 @@ def plot_forecast( values_high, facecolor=forecast_color, alpha=alpha[quantile_idx], - label=f"{legend_prefix}{low_quantile}-{high_quantile} prediction interval", + label=f"{legend_prefix}{low_quantile}-{high_quantile}", ) # if there is some interval inside (low_quantile, high_quantile) we should plot around it else: @@ -220,7 +220,7 @@ def plot_forecast( values_next, facecolor=forecast_color, alpha=alpha[quantile_idx], - label=f"{legend_prefix}{low_quantile}-{high_quantile} prediction interval", + label=f"{legend_prefix}{low_quantile}-{high_quantile}", ) values_prev = segment_forecast_df[f"{quantile_prefix}{high_prev_quantile}"].values ax[i].fill_between( @@ -239,7 +239,7 @@ def plot_forecast( values, "--", color=forecast_color, - label=f"{legend_prefix}{remaining_quantile} quantile", + label=f"{legend_prefix}{remaining_quantile}", ) ax[i].set_title(segment) ax[i].tick_params("x", rotation=45) From 34dbcf7008fb3fe60d9e0ab730dc702f7598305e Mon Sep 17 00:00:00 2001 From: "d.a.bunin" Date: Mon, 14 Mar 2022 17:22:20 +0300 Subject: [PATCH 8/8] Put legend to upper left corner --- etna/analysis/plotters.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/etna/analysis/plotters.py b/etna/analysis/plotters.py index d2db6b5f7..054ed508d 100644 --- a/etna/analysis/plotters.py +++ b/etna/analysis/plotters.py @@ -243,7 +243,7 @@ def plot_forecast( ) ax[i].set_title(segment) ax[i].tick_params("x", rotation=45) - ax[i].legend() + ax[i].legend(loc="upper left") def plot_backtest(