Skip to content

Add seasonal_plot #628

Merged
merged 8 commits into from
Apr 1, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
-
- Masked backtest ([#613](https://github.com/tinkoff-ai/etna/pull/613))
-
-
- Add seasonal_plot ([#628](https://github.com/tinkoff-ai/etna/pull/628))
-
- Add plot_periodogram ([#606](https://github.com/tinkoff-ai/etna/pull/606))
-
Expand Down
4 changes: 4 additions & 0 deletions etna/analysis/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
from etna.analysis.change_points_trend import find_change_points
from etna.analysis.eda_utils import SeasonalPlotAggregation
from etna.analysis.eda_utils import SeasonalPlotAlignment
from etna.analysis.eda_utils import SeasonalPlotCycle
from etna.analysis.eda_utils import cross_corr_plot
from etna.analysis.eda_utils import distribution_plot
from etna.analysis.eda_utils import prediction_actual_scatter_plot
from etna.analysis.eda_utils import qq_plot
from etna.analysis.eda_utils import sample_acf_plot
from etna.analysis.eda_utils import sample_pacf_plot
from etna.analysis.eda_utils import seasonal_plot
from etna.analysis.eda_utils import stl_plot
from etna.analysis.feature_relevance.relevance import ModelRelevanceTable
from etna.analysis.feature_relevance.relevance import RelevanceTable
Expand Down
338 changes: 338 additions & 0 deletions etna/analysis/eda_utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import math
import warnings
from enum import Enum
from itertools import combinations
from typing import TYPE_CHECKING
from typing import Any
from typing import Callable
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

import matplotlib.pyplot as plt
import numpy as np
Expand All @@ -19,6 +22,7 @@
from statsmodels.graphics import utils
from statsmodels.graphics.gofplots import qqplot
from statsmodels.tsa.seasonal import STL
from typing_extensions import Literal

from etna.analysis.utils import prepare_axes

Expand Down Expand Up @@ -411,3 +415,337 @@ def prediction_actual_scatter_plot(
ax[i].set_xlim(*xlim)
ax[i].set_ylim(*ylim)
ax[i].legend()


class SeasonalPlotAlignment(str, Enum):
"""Enum for types of alignment in a seasonal plot.

Class Attributes
----------------
first:
make first period full, allow last period to have NaNs in the ending
last:
make last period full, allow first period to have NaNs in the beginning
"""

first = "first"
last = "last"

@classmethod
def _missing_(cls, value):
raise NotImplementedError(
f"{value} is not a valid {cls.__name__}. Only {', '.join([repr(m.value) for m in cls])} alignments are allowed"
)


class SeasonalPlotAggregation(str, Enum):
"""Enum for types of aggregation in a seasonal plot."""

mean = "mean"
sum = "sum"

@classmethod
def _missing_(cls, value):
raise NotImplementedError(
f"{value} is not a valid {cls.__name__}. Only {', '.join([repr(m.value) for m in cls])} aggregations are allowed"
)

@staticmethod
def _modified_nansum(series):
"""Sum values with ignoring of NaNs.

* If there some nan: we skip them.

* If all values equal to nan we return nan.
"""
if np.all(np.isnan(series)):
return np.NaN
else:
return np.nansum(series)

def get_function(self):
"""Get aggregation function."""
if self.value == "mean":
return np.nanmean
elif self.value == "sum":
return self._modified_nansum


class SeasonalPlotCycle(str, Enum):
"""Enum for types of cycles in a seasonal plot."""

hour = "hour"
day = "day"
week = "week"
month = "month"
quarter = "quarter"
year = "year"

@classmethod
def _missing_(cls, value):
raise NotImplementedError(
f"{value} is not a valid {cls.__name__}. Only {', '.join([repr(m.value) for m in cls])} cycles are allowed"
)


def _get_seasonal_cycle_name(
timestamp: pd.Series,
cycle: Union[
Literal["hour"], Literal["day"], Literal["week"], Literal["month"], Literal["quarter"], Literal["year"], int
],
) -> pd.Series:
"""Get unique name for each cycle in a series with timestamps."""
cycle_functions: Dict[SeasonalPlotCycle, Callable[[pd.Series], pd.Series]] = {
SeasonalPlotCycle.hour: lambda x: x.dt.strftime("%Y-%m-%d %H"),
SeasonalPlotCycle.day: lambda x: x.dt.strftime("%Y-%m-%d"),
SeasonalPlotCycle.week: lambda x: x.dt.strftime("%Y-%W"),
SeasonalPlotCycle.month: lambda x: x.dt.strftime("%Y-%b"),
SeasonalPlotCycle.quarter: lambda x: x.apply(lambda x: f"{x.year}-{x.quarter}"),
SeasonalPlotCycle.year: lambda x: x.dt.strftime("%Y"),
}

if isinstance(cycle, int):
row_numbers = pd.Series(np.arange(len(timestamp)))
return (row_numbers // cycle + 1).astype(str)
else:
return cycle_functions[SeasonalPlotCycle(cycle)](timestamp)


def _get_seasonal_in_cycle_num(
timestamp: pd.Series,
cycle_name: pd.Series,
cycle: Union[
Literal["hour"], Literal["day"], Literal["week"], Literal["month"], Literal["quarter"], Literal["year"], int
],
freq: str,
) -> pd.Series:
"""Get number for each point within cycle in a series of timestamps."""
cycle_functions: Dict[Tuple[SeasonalPlotCycle, str], Callable[[pd.Series], pd.Series]] = {
(SeasonalPlotCycle.hour, "T"): lambda x: x.dt.minute,
(SeasonalPlotCycle.day, "H"): lambda x: x.dt.hour,
(SeasonalPlotCycle.week, "D"): lambda x: x.dt.weekday,
(SeasonalPlotCycle.month, "D"): lambda x: x.dt.day,
(SeasonalPlotCycle.quarter, "D"): lambda x: (x - pd.PeriodIndex(x, freq="Q").start_time).dt.days,
(SeasonalPlotCycle.year, "D"): lambda x: x.dt.dayofyear,
(SeasonalPlotCycle.year, "Q"): lambda x: x.dt.quarter,
(SeasonalPlotCycle.year, "QS"): lambda x: x.dt.quarter,
(SeasonalPlotCycle.year, "M"): lambda x: x.dt.month,
(SeasonalPlotCycle.year, "MS"): lambda x: x.dt.month,
}

if isinstance(cycle, int):
pass
else:
key = (SeasonalPlotCycle(cycle), freq)
if key in cycle_functions:
return cycle_functions[key](timestamp)

# in all other cases we can use numbers within each group
cycle_df = pd.DataFrame({"timestamp": timestamp.tolist(), "cycle_name": cycle_name.tolist()})
return cycle_df.sort_values("timestamp").groupby("cycle_name").cumcount()


def _get_seasonal_in_cycle_name(
timestamp: pd.Series,
in_cycle_num: pd.Series,
cycle: Union[
Literal["hour"], Literal["day"], Literal["week"], Literal["month"], Literal["quarter"], Literal["year"], int
],
freq: str,
) -> pd.Series:
"""Get unique name for each point within the cycle in a series of timestamps."""
if isinstance(cycle, int):
pass
elif SeasonalPlotCycle(cycle) == SeasonalPlotCycle.week:
if freq == "D":
return timestamp.dt.strftime("%a")
elif SeasonalPlotCycle(cycle) == SeasonalPlotCycle.year:
if freq == "M" or freq == "MS":
return timestamp.dt.strftime("%b")

# in all other cases we can use numbers from cycle_num
return in_cycle_num.astype(str)


def _seasonal_split(
timestamp: pd.Series,
freq: str,
cycle: Union[
Literal["hour"], Literal["day"], Literal["week"], Literal["month"], Literal["quarter"], Literal["year"], int
],
) -> pd.DataFrame:
"""Create a seasonal split into cycles of a given timestamp.

Parameters
----------
timestamp:
series with timestamps
freq:
frequency of dataframe
cycle:
period of seasonality to capture (see :class:`~etna.analysis.SeasonalPlotCycle`)

Returns
-------
result:
dataframe with timestamps and corresponding cycle names and in cycle names
"""
cycles_df = pd.DataFrame({"timestamp": timestamp.tolist()})
cycles_df["cycle_name"] = _get_seasonal_cycle_name(timestamp=cycles_df["timestamp"], cycle=cycle)
cycles_df["in_cycle_num"] = _get_seasonal_in_cycle_num(
timestamp=cycles_df["timestamp"], cycle_name=cycles_df["cycle_name"], cycle=cycle, freq=freq
)
cycles_df["in_cycle_name"] = _get_seasonal_in_cycle_name(
timestamp=cycles_df["timestamp"], in_cycle_num=cycles_df["in_cycle_num"], cycle=cycle, freq=freq
)
return cycles_df


def _resample(df: pd.DataFrame, freq: str, aggregation: Union[Literal["sum"], Literal["mean"]]) -> pd.DataFrame:
martins0n marked this conversation as resolved.
Show resolved Hide resolved
from etna.datasets import TSDataset

agg_enum = SeasonalPlotAggregation(aggregation)
df_flat = TSDataset.to_flatten(df)
df_flat = (
df_flat.set_index("timestamp")
.groupby(["segment", pd.Grouper(freq=freq)])
.agg(agg_enum.get_function())
.reset_index()
)
df = TSDataset.to_dataset(df_flat)
return df


def _prepare_seasonal_plot_df(
ts: "TSDataset",
freq: str,
cycle: Union[
Literal["hour"], Literal["day"], Literal["week"], Literal["month"], Literal["quarter"], Literal["year"], int
],
alignment: Union[Literal["first"], Literal["last"]],
aggregation: Union[Literal["sum"], Literal["mean"]],
in_column: str,
segments: List[str],
):
# for simplicity we will rename our column to target
df = ts.to_pandas().loc[:, pd.IndexSlice[segments, in_column]]
df.rename(columns={in_column: "target"}, inplace=True)

# remove timestamps with only nans, it is possible if in_column != "target"
df = df[(~df.isna()).sum(axis=1) > 0]

# make resampling if necessary
if ts.freq != freq:
df = _resample(df=df, freq=freq, aggregation=aggregation)

# process alignment
if isinstance(cycle, int):
timestamp = df.index
num_to_add = -len(timestamp) % cycle
# if we want align by the first value, then we should append NaNs to timestamp
to_add_index = None
if SeasonalPlotAlignment(alignment) == SeasonalPlotAlignment.first:
to_add_index = pd.date_range(start=timestamp.max(), periods=num_to_add + 1, closed="right", freq=freq)
# if we want to align by the last value, then we should prepend NaNs to timestamp
elif SeasonalPlotAlignment(alignment) == SeasonalPlotAlignment.last:
to_add_index = pd.date_range(end=timestamp.min(), periods=num_to_add + 1, closed="left", freq=freq)

df = df.append(pd.DataFrame(None, index=to_add_index)).sort_index()

return df


def seasonal_plot(
ts: "TSDataset",
freq: Optional[str] = None,
cycle: Union[
Literal["hour"], Literal["day"], Literal["week"], Literal["month"], Literal["quarter"], Literal["year"], int
] = "year",
alignment: Union[Literal["first"], Literal["last"]] = "last",
aggregation: Union[Literal["sum"], Literal["mean"]] = "sum",
in_column: str = "target",
plot_params: Optional[Dict[str, Any]] = None,
cmap: str = "plasma",
segments: Optional[List[str]] = None,
columns_num: int = 2,
figsize: Tuple[int, int] = (10, 5),
):
"""Plot each season on one canvas for each segment.

Parameters
----------
ts:
dataset with timeseries data
freq:
frequency to analyze seasons:

* if isn't set, the frequency of ``ts`` will be used;

* if set, resampling will be made using ``aggregation`` parameter.
If given frequency is too low, then the frequency of ``ts`` will be used.

cycle:
period of seasonality to capture (see :class:`~etna.analysis.eda_utils.SeasonalPlotCycle`)
alignment:
how to align dataframe in case of integer cycle (see :class:`~etna.analysis.eda_utils.SeasonalPlotAlignment`)
aggregation:
how to aggregate values during resampling (see :class:`~etna.analysis.eda_utils.SeasonalPlotAggregation`)
in_column:
column to use
cmap:
name of colormap for plotting different cycles
(see `Choosing Colormaps in Matplotlib <https://matplotlib.org/3.5.1/tutorials/colors/colormaps.html>`_)
plot_params:
dictionary with parameters for plotting, :meth:`matplotlib.axes.Axes.plot` is used
segments:
segments to use
columns_num:
number of columns in subplots
figsize:
size of the figure per subplot with one segment in inches
"""
if plot_params is None:
plot_params = {}
if freq is None:
freq = ts.freq
if segments is None:
segments = sorted(ts.segments)

df = _prepare_seasonal_plot_df(
ts=ts,
freq=freq,
cycle=cycle,
alignment=alignment,
aggregation=aggregation,
in_column=in_column,
segments=segments,
)
seasonal_df = _seasonal_split(timestamp=df.index.to_series(), freq=freq, cycle=cycle)

colors = plt.get_cmap(cmap)
ax = prepare_axes(segments=segments, columns_num=columns_num, figsize=figsize)
for i, segment in enumerate(segments):
segment_df = df.loc[:, pd.IndexSlice[segment, "target"]]
cycle_names = seasonal_df["cycle_name"].unique()
for j, cycle_name in enumerate(cycle_names):
color = colors(j / len(cycle_names))
cycle_df = seasonal_df[seasonal_df["cycle_name"] == cycle_name]
segment_cycle_df = segment_df.loc[cycle_df["timestamp"]]
ax[i].plot(
cycle_df["in_cycle_num"],
segment_cycle_df[cycle_df["timestamp"]],
color=color,
label=cycle_name,
**plot_params,
)

# draw ticks if they are not digits
if not np.all(seasonal_df["in_cycle_name"].str.isnumeric()):
ticks_dict = {key: value for key, value in zip(seasonal_df["in_cycle_num"], seasonal_df["in_cycle_name"])}
ticks = np.array(list(ticks_dict.keys()))
ticks_labels = np.array(list(ticks_dict.values()))
idx_sort = np.argsort(ticks)
ax[i].set_xticks(ticks=ticks[idx_sort], labels=ticks_labels[idx_sort])
ax[i].set_xlabel(freq)
ax[i].set_title(segment)
ax[i].legend(loc="upper center", bbox_to_anchor=(0.5, -0.12), ncol=6)
Loading