diff --git a/CHANGELOG.md b/CHANGELOG.md index e07509c5d..8a6818e88 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,7 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - - Masked backtest ([#613](https://github.com/tinkoff-ai/etna/pull/613)) - -- +- Add seasonal_plot ([#628](https://github.com/tinkoff-ai/etna/pull/628)) - - Add plot_periodogram ([#606](https://github.com/tinkoff-ai/etna/pull/606)) - diff --git a/etna/analysis/__init__.py b/etna/analysis/__init__.py index c68e071cb..e89c4c297 100644 --- a/etna/analysis/__init__.py +++ b/etna/analysis/__init__.py @@ -1,10 +1,14 @@ from etna.analysis.change_points_trend import find_change_points +from etna.analysis.eda_utils import SeasonalPlotAggregation +from etna.analysis.eda_utils import SeasonalPlotAlignment +from etna.analysis.eda_utils import SeasonalPlotCycle from etna.analysis.eda_utils import cross_corr_plot from etna.analysis.eda_utils import distribution_plot from etna.analysis.eda_utils import prediction_actual_scatter_plot from etna.analysis.eda_utils import qq_plot from etna.analysis.eda_utils import sample_acf_plot from etna.analysis.eda_utils import sample_pacf_plot +from etna.analysis.eda_utils import seasonal_plot from etna.analysis.eda_utils import stl_plot from etna.analysis.feature_relevance.relevance import ModelRelevanceTable from etna.analysis.feature_relevance.relevance import RelevanceTable diff --git a/etna/analysis/eda_utils.py b/etna/analysis/eda_utils.py index 04ce6912d..0012a24e6 100644 --- a/etna/analysis/eda_utils.py +++ b/etna/analysis/eda_utils.py @@ -1,12 +1,15 @@ import math import warnings +from enum import Enum from itertools import combinations from typing import TYPE_CHECKING from typing import Any +from typing import Callable from typing import Dict from typing import List from typing import Optional from typing import Tuple +from typing import Union import matplotlib.pyplot as plt import numpy as np @@ -19,6 +22,7 @@ from 
# Shared annotation aliases for the public ``seasonal_plot`` parameters.
_SeasonalCycle = Union[Literal["hour", "day", "week", "month", "quarter", "year"], int]
_SeasonalAlignment = Literal["first", "last"]
_SeasonalAggregation = Literal["sum", "mean"]

# pandas 2.2 renamed several offset aliases; map the new spelling back to the
# legacy one so that both spellings hit the same entries in the lookup tables below.
_LEGACY_FREQ_ALIASES: Dict[str, str] = {"min": "T", "h": "H", "ME": "M", "QE": "Q"}


class SeasonalPlotAlignment(str, Enum):
    """Enum for types of alignment in a seasonal plot.

    Class Attributes
    ----------------
    first:
        make first period full, allow last period to have NaNs in the ending
    last:
        make last period full, allow first period to have NaNs in the beginning
    """

    first = "first"
    last = "last"

    @classmethod
    def _missing_(cls, value):
        raise NotImplementedError(
            f"{value} is not a valid {cls.__name__}. Only {', '.join([repr(m.value) for m in cls])} alignments are allowed"
        )


class SeasonalPlotAggregation(str, Enum):
    """Enum for types of aggregation in a seasonal plot."""

    mean = "mean"
    sum = "sum"

    @classmethod
    def _missing_(cls, value):
        raise NotImplementedError(
            f"{value} is not a valid {cls.__name__}. Only {', '.join([repr(m.value) for m in cls])} aggregations are allowed"
        )

    @staticmethod
    def _modified_nansum(series):
        """Sum values ignoring NaNs.

        * If there are some NaNs, they are skipped.

        * If all values are NaN, NaN is returned (plain ``np.nansum`` would return 0).
        """
        if np.all(np.isnan(series)):
            # ``np.NaN`` was removed in NumPy 2.0, ``np.nan`` works on every version
            return np.nan
        return np.nansum(series)

    def get_function(self):
        """Get aggregation function."""
        if self is SeasonalPlotAggregation.mean:
            return np.nanmean
        return self._modified_nansum


class SeasonalPlotCycle(str, Enum):
    """Enum for types of cycles in a seasonal plot."""

    hour = "hour"
    day = "day"
    week = "week"
    month = "month"
    quarter = "quarter"
    year = "year"

    @classmethod
    def _missing_(cls, value):
        raise NotImplementedError(
            f"{value} is not a valid {cls.__name__}. Only {', '.join([repr(m.value) for m in cls])} cycles are allowed"
        )


def _get_seasonal_cycle_name(timestamp: pd.Series, cycle: _SeasonalCycle) -> pd.Series:
    """Get unique name for each cycle in a series with timestamps.

    Parameters
    ----------
    timestamp:
        series with timestamps
    cycle:
        period of seasonality: a name of :class:`SeasonalPlotCycle` or a number of points in one cycle

    Returns
    -------
    result:
        series of strings with the same index as ``timestamp``

    Raises
    ------
    NotImplementedError:
        if ``cycle`` is a string that isn't a valid :class:`SeasonalPlotCycle`
    """
    if isinstance(cycle, int):
        # consecutive blocks of ``cycle`` rows are numbered starting from 1;
        # the index of the input is kept so the result aligns with ``timestamp``
        row_numbers = pd.Series(np.arange(len(timestamp)), index=timestamp.index)
        return (row_numbers // cycle + 1).astype(str)

    cycle_functions: Dict[SeasonalPlotCycle, Callable[[pd.Series], pd.Series]] = {
        SeasonalPlotCycle.hour: lambda x: x.dt.strftime("%Y-%m-%d %H"),
        SeasonalPlotCycle.day: lambda x: x.dt.strftime("%Y-%m-%d"),
        SeasonalPlotCycle.week: lambda x: x.dt.strftime("%Y-%W"),
        SeasonalPlotCycle.month: lambda x: x.dt.strftime("%Y-%b"),
        SeasonalPlotCycle.quarter: lambda x: x.dt.year.astype(str) + "-" + x.dt.quarter.astype(str),
        SeasonalPlotCycle.year: lambda x: x.dt.strftime("%Y"),
    }
    return cycle_functions[SeasonalPlotCycle(cycle)](timestamp)


def _get_seasonal_in_cycle_num(
    timestamp: pd.Series,
    cycle_name: pd.Series,
    cycle: _SeasonalCycle,
    freq: str,
) -> pd.Series:
    """Get number for each point within cycle in a series of timestamps.

    Parameters
    ----------
    timestamp:
        series with timestamps
    cycle_name:
        name of the cycle for each timestamp, positionally matching ``timestamp``
    cycle:
        period of seasonality: a name of :class:`SeasonalPlotCycle` or a number of points in one cycle
    freq:
        frequency of the timestamps

    Returns
    -------
    result:
        series of integers with the same index as ``timestamp``
    """
    cycle_functions: Dict[Tuple[SeasonalPlotCycle, str], Callable[[pd.Series], pd.Series]] = {
        (SeasonalPlotCycle.hour, "T"): lambda x: x.dt.minute,
        (SeasonalPlotCycle.day, "H"): lambda x: x.dt.hour,
        (SeasonalPlotCycle.week, "D"): lambda x: x.dt.weekday,
        (SeasonalPlotCycle.month, "D"): lambda x: x.dt.day,
        (SeasonalPlotCycle.quarter, "D"): lambda x: (x - x.dt.to_period("Q").dt.start_time).dt.days,
        (SeasonalPlotCycle.year, "D"): lambda x: x.dt.dayofyear,
        (SeasonalPlotCycle.year, "Q"): lambda x: x.dt.quarter,
        (SeasonalPlotCycle.year, "QS"): lambda x: x.dt.quarter,
        (SeasonalPlotCycle.year, "M"): lambda x: x.dt.month,
        (SeasonalPlotCycle.year, "MS"): lambda x: x.dt.month,
    }

    if not isinstance(cycle, int):
        key = (SeasonalPlotCycle(cycle), _LEGACY_FREQ_ALIASES.get(freq, freq))
        if key in cycle_functions:
            return cycle_functions[key](timestamp)

    # in all other cases we can use numbers within each group:
    # points of one cycle are enumerated in chronological order
    cycle_df = pd.DataFrame({"timestamp": timestamp.to_numpy(), "cycle_name": cycle_name.to_numpy()})
    in_cycle_num = cycle_df.sort_values("timestamp").groupby("cycle_name").cumcount()
    # restore the original row order and attach the index of the input
    return pd.Series(in_cycle_num.sort_index().to_numpy(), index=timestamp.index)


def _get_seasonal_in_cycle_name(
    timestamp: pd.Series,
    in_cycle_num: pd.Series,
    cycle: _SeasonalCycle,
    freq: str,
) -> pd.Series:
    """Get unique name for each point within the cycle in a series of timestamps.

    Parameters
    ----------
    timestamp:
        series with timestamps
    in_cycle_num:
        number of each point within its cycle
    cycle:
        period of seasonality: a name of :class:`SeasonalPlotCycle` or a number of points in one cycle
    freq:
        frequency of the timestamps

    Returns
    -------
    result:
        series of strings: weekday names for daily data within a week, month names
        for monthly data within a year, stringified ``in_cycle_num`` otherwise
    """
    if not isinstance(cycle, int):
        cycle_enum = SeasonalPlotCycle(cycle)
        legacy_freq = _LEGACY_FREQ_ALIASES.get(freq, freq)
        if cycle_enum is SeasonalPlotCycle.week and legacy_freq == "D":
            return timestamp.dt.strftime("%a")
        if cycle_enum is SeasonalPlotCycle.year and legacy_freq in ("M", "MS"):
            return timestamp.dt.strftime("%b")

    # in all other cases we can use numbers from cycle_num
    return in_cycle_num.astype(str)


def _seasonal_split(timestamp: pd.Series, freq: str, cycle: _SeasonalCycle) -> pd.DataFrame:
    """Create a seasonal split into cycles of a given timestamp.

    Parameters
    ----------
    timestamp:
        series with timestamps
    freq:
        frequency of dataframe
    cycle:
        period of seasonality to capture (see :class:`~etna.analysis.eda_utils.SeasonalPlotCycle`)

    Returns
    -------
    result:
        dataframe with timestamps and corresponding cycle names and in cycle names
    """
    cycles_df = pd.DataFrame({"timestamp": timestamp.tolist()})
    cycles_df["cycle_name"] = _get_seasonal_cycle_name(timestamp=cycles_df["timestamp"], cycle=cycle)
    cycles_df["in_cycle_num"] = _get_seasonal_in_cycle_num(
        timestamp=cycles_df["timestamp"], cycle_name=cycles_df["cycle_name"], cycle=cycle, freq=freq
    )
    cycles_df["in_cycle_name"] = _get_seasonal_in_cycle_name(
        timestamp=cycles_df["timestamp"], in_cycle_num=cycles_df["in_cycle_num"], cycle=cycle, freq=freq
    )
    return cycles_df


def _resample(df: pd.DataFrame, freq: str, aggregation: _SeasonalAggregation) -> pd.DataFrame:
    """Resample a dataframe in wide ETNA format to ``freq`` using the given aggregation.

    Parameters
    ----------
    df:
        dataframe in wide format accepted by ``TSDataset.to_flatten``
    freq:
        frequency to resample to
    aggregation:
        how to aggregate values (see :class:`~etna.analysis.eda_utils.SeasonalPlotAggregation`)

    Returns
    -------
    result:
        resampled dataframe produced by ``TSDataset.to_dataset``
    """
    # imported here, as in the original code, instead of at module level
    from etna.datasets import TSDataset

    agg_function = SeasonalPlotAggregation(aggregation).get_function()
    df_flat = TSDataset.to_flatten(df)
    df_flat = (
        df_flat.set_index("timestamp").groupby(["segment", pd.Grouper(freq=freq)]).agg(agg_function).reset_index()
    )
    return TSDataset.to_dataset(df_flat)


def _prepare_seasonal_plot_df(
    ts: "TSDataset",
    freq: str,
    cycle: _SeasonalCycle,
    alignment: _SeasonalAlignment,
    aggregation: _SeasonalAggregation,
    in_column: str,
    segments: List[str],
):
    """Select, resample and align the data that is going to be drawn by ``seasonal_plot``.

    Parameters
    ----------
    ts:
        dataset with timeseries data
    freq:
        frequency to analyze seasons, resampling is made if it differs from ``ts.freq``
    cycle:
        period of seasonality to capture
    alignment:
        how to align dataframe in case of integer cycle
    aggregation:
        how to aggregate values during resampling
    in_column:
        column to use
    segments:
        segments to use

    Returns
    -------
    result:
        dataframe in wide format with column ``in_column`` renamed to "target";
        for integer ``cycle`` it is padded with NaN rows up to a whole number of cycles
    """
    # for simplicity we will rename our column to target
    df = ts.to_pandas().loc[:, pd.IndexSlice[segments, in_column]]
    # non-inplace rename: the frame above is a selection from the dataset's frame
    df = df.rename(columns={in_column: "target"})

    # remove timestamps with only nans, it is possible if in_column != "target"
    df = df[df.notna().any(axis=1)]

    # make resampling if necessary
    if ts.freq != freq:
        df = _resample(df=df, freq=freq, aggregation=aggregation)

    # process alignment
    if isinstance(cycle, int):
        timestamp = df.index
        num_to_add = -len(timestamp) % cycle
        alignment_enum = SeasonalPlotAlignment(alignment)
        # The boundary point is generated and then masked out: this is what
        # ``closed="right"``/``closed="left"`` did before the argument was removed in pandas 2.0.
        if alignment_enum is SeasonalPlotAlignment.first:
            # if we want align by the first value, then we should append NaNs to timestamp
            boundary = timestamp.max()
            to_add_index = pd.date_range(start=boundary, periods=num_to_add + 1, freq=freq)
            to_add_index = to_add_index[to_add_index > boundary]
        else:
            # if we want to align by the last value, then we should prepend NaNs to timestamp
            boundary = timestamp.min()
            to_add_index = pd.date_range(end=boundary, periods=num_to_add + 1, freq=freq)
            to_add_index = to_add_index[to_add_index < boundary]

        # ``DataFrame.append`` was removed in pandas 2.0; reindexing on the sorted union
        # of timestamps adds the same all-NaN rows
        new_index = timestamp.union(to_add_index).rename(timestamp.name)
        df = df.reindex(new_index)

    return df


def seasonal_plot(
    ts: "TSDataset",
    freq: Optional[str] = None,
    cycle: _SeasonalCycle = "year",
    alignment: _SeasonalAlignment = "last",
    aggregation: _SeasonalAggregation = "sum",
    in_column: str = "target",
    plot_params: Optional[Dict[str, Any]] = None,
    cmap: str = "plasma",
    segments: Optional[List[str]] = None,
    columns_num: int = 2,
    figsize: Tuple[int, int] = (10, 5),
):
    """Plot each season on one canvas for each segment.

    Parameters
    ----------
    ts:
        dataset with timeseries data
    freq:
        frequency to analyze seasons:

        * if isn't set, the frequency of ``ts`` will be used;

        * if set, resampling will be made using ``aggregation`` parameter.
          If given frequency is too low, then the frequency of ``ts`` will be used.

    cycle:
        period of seasonality to capture (see :class:`~etna.analysis.eda_utils.SeasonalPlotCycle`)
    alignment:
        how to align dataframe in case of integer cycle (see :class:`~etna.analysis.eda_utils.SeasonalPlotAlignment`)
    aggregation:
        how to aggregate values during resampling (see :class:`~etna.analysis.eda_utils.SeasonalPlotAggregation`)
    in_column:
        column to use
    cmap:
        name of colormap for plotting different cycles
        (see `Choosing Colormaps in Matplotlib <https://matplotlib.org/stable/users/explain/colors/colormaps.html>`_)
    plot_params:
        dictionary with parameters for plotting, :meth:`matplotlib.axes.Axes.plot` is used
    segments:
        segments to use
    columns_num:
        number of columns in subplots
    figsize:
        size of the figure per subplot with one segment in inches
    """
    if plot_params is None:
        plot_params = {}
    if freq is None:
        freq = ts.freq
    if segments is None:
        segments = sorted(ts.segments)

    df = _prepare_seasonal_plot_df(
        ts=ts,
        freq=freq,
        cycle=cycle,
        alignment=alignment,
        aggregation=aggregation,
        in_column=in_column,
        segments=segments,
    )
    seasonal_df = _seasonal_split(timestamp=df.index.to_series(), freq=freq, cycle=cycle)

    colors = plt.get_cmap(cmap)
    ax = prepare_axes(segments=segments, columns_num=columns_num, figsize=figsize)

    # the split into cycles is the same for every segment, so it is computed once
    cycle_names = seasonal_df["cycle_name"].unique()
    cycle_frames = [(name, seasonal_df[seasonal_df["cycle_name"] == name]) for name in cycle_names]

    # ticks are customized only if the names of points within a cycle are not digits
    ticks = None
    ticks_labels = None
    if not np.all(seasonal_df["in_cycle_name"].str.isnumeric()):
        ticks_dict = dict(zip(seasonal_df["in_cycle_num"], seasonal_df["in_cycle_name"]))
        ticks = sorted(ticks_dict)
        ticks_labels = [ticks_dict[tick] for tick in ticks]

    for i, segment in enumerate(segments):
        segment_df = df.loc[:, pd.IndexSlice[segment, "target"]]
        for j, (cycle_name, cycle_df) in enumerate(cycle_frames):
            ax[i].plot(
                cycle_df["in_cycle_num"],
                segment_df.loc[cycle_df["timestamp"]],
                color=colors(j / len(cycle_names)),
                label=cycle_name,
                **plot_params,
            )

        if ticks is not None:
            # two calls instead of ``set_xticks(..., labels=...)``: the ``labels``
            # argument is only available since matplotlib 3.5
            ax[i].set_xticks(ticks)
            ax[i].set_xticklabels(ticks_labels)
        ax[i].set_xlabel(freq)
        ax[i].set_title(segment)
        ax[i].legend(loc="upper center", bbox_to_anchor=(0.5, -0.12), ncol=6)
@pytest.mark.parametrize(
    "timestamp, values, resample_freq, aggregation, expected_timestamp, expected_values",
    [
        (
            pd.date_range(start="2020-01-01", periods=14, freq="Q"),
            [np.nan, np.nan, np.nan, np.nan, np.nan, 10, 16, 10, 5, 7, 5, 7, 3, 3],
            "Y",
            "sum",
            pd.date_range(start="2020-01-01", periods=4, freq="Y"),
            [np.nan, 36.0, 24.0, 6.0],
        ),
        (
            pd.date_range(start="2020-01-01", periods=14, freq="Q"),
            [np.nan, np.nan, np.nan, np.nan, np.nan, 10, 16, 10, 5, 7, 5, 7, 3, 3],
            "Y",
            "mean",
            pd.date_range(start="2020-01-01", periods=4, freq="Y"),
            [np.nan, 12.0, 6.0, 3.0],
        ),
    ],
)
def test_resample(timestamp, values, resample_freq, aggregation, expected_timestamp, expected_values):
    """Check that ``_resample`` aggregates quarterly data into years for both aggregations.

    The first year consists only of NaNs and is expected to stay NaN after aggregation.
    """
    # ``np.nan`` is used instead of ``np.NaN``: the latter alias was removed in NumPy 2.0
    df = pd.DataFrame({"timestamp": timestamp.tolist(), "target": values, "segment": len(timestamp) * ["segment_0"]})
    df_wide = TSDataset.to_dataset(df)
    df_resampled = _resample(df=df_wide, freq=resample_freq, aggregation=aggregation)
    assert df_resampled.index.tolist() == expected_timestamp.tolist()
    assert (
        df_resampled.loc[:, pd.IndexSlice["segment_0", "target"]]
        .reset_index(drop=True)
        .equals(pd.Series(expected_values))
    )
@pytest.mark.parametrize(
    "freq, cycle, additional_params",
    [
        ("D", 5, {"alignment": "first"}),
        ("D", 5, {"alignment": "last"}),
        ("D", "week", {}),
        ("D", "month", {}),
        ("D", "year", {}),
        ("M", "year", {"aggregation": "sum"}),
        ("M", "year", {"aggregation": "mean"}),
    ],
)
def test_dummy_seasonal_plot(freq, cycle, additional_params, ts_with_different_series_length):
    """Smoke test: ``seasonal_plot`` is only required to run without raising."""
    dataset = ts_with_different_series_length
    seasonal_plot(ts=dataset, freq=freq, cycle=cycle, **additional_params)