Skip to content

Commit

Permalink
Add interactive plot for anomalies (#95)
Browse files Browse the repository at this point in the history
  • Loading branch information
alex-hse-repository authored Sep 27, 2021
1 parent ef5cd19 commit ebed9c9
Show file tree
Hide file tree
Showing 5 changed files with 443 additions and 13 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased]
### Added
- BinsegTrendTransform, ChangePointsTrendTransform ([#87](https://github.com/tinkoff-ai/etna-ts/pull/87))
- Interactive plot for anomalies (#[95](https://github.com/tinkoff-ai/etna-ts/pull/95))

### Changed
- SklearnTransform out column names ([#99](https://github.com/tinkoff-ai/etna-ts/pull/99))
Expand Down
1 change: 1 addition & 0 deletions etna/analysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from etna.analysis.outliers.median_outliers import get_anomalies_median
from etna.analysis.plotters import get_correlation_matrix
from etna.analysis.plotters import plot_anomalies
from etna.analysis.plotters import plot_anomalies_interactive
from etna.analysis.plotters import plot_backtest
from etna.analysis.plotters import plot_correlation_matrix
from etna.analysis.plotters import plot_forecast
74 changes: 74 additions & 0 deletions etna/analysis/plotters.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import math
from typing import TYPE_CHECKING
from typing import Callable
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

import matplotlib.pyplot as plt
import numpy as np
Expand Down Expand Up @@ -244,3 +247,74 @@ def plot_correlation_matrix(
ax.set_xticklabels(labels, rotation=45, horizontalalignment="right")
ax.set_yticklabels(labels, rotation=0, horizontalalignment="right")
ax.set_title("Correlation Heatmap")


def plot_anomalies_interactive(
ts: "TSDataset",
segment: str,
method: Callable[..., Dict[str, List[pd.Timestamp]]],
params_bounds: Dict[str, Tuple[Union[int, float], Union[int, float], Union[int, float]]],
):
"""Plot a time series with indicated anomalies.
Anomalies are obtained using the specified method. The method parameters values
can be changed using the corresponding sliders.
Parameters
----------
ts:
TSDataset with timeseries data
segment:
Segment to plot
method:
Method for outliers detection
params_bounds:
Parameters ranges of the outliers detection method. Bounds for the parameter are (min,max,step)
Examples
--------
>>> from etna.datasets import TSDataset
>>> from etna.datasets import generate_ar_df
>>> from etna.analysis import plot_anomalies_interactive, get_anomalies_density
>>> classic_df = generate_ar_df(periods=1000, start_time="2021-08-01", n_segments=2)
>>> df = TSDataset.to_dataset(classic_df)
>>> ts = TSDataset(df, "D")
>>> params_bounds = {"window_size": (5, 20, 1), "distance_coef": (0.1, 3, 0.25)}
>>> method = get_anomalies_density
>>> plot_anomalies_interactive(ts=ts, segment="segment_1", method=method, params_bounds=params_bounds)
"""
from ipywidgets import FloatSlider
from ipywidgets import IntSlider
from ipywidgets import interact

from etna.datasets import TSDataset

df = ts[:, segment, "target"]
ts = TSDataset(ts[:, segment, :], ts.freq)
x, y = df.index.values, df.values
cache = {}

sliders = dict()
style = {"description_width": "initial"}
for param, bounds in params_bounds.items():
min_, max_, step = bounds
if isinstance(min_, float) or isinstance(max_, float) or isinstance(step, float):
sliders[param] = FloatSlider(min=min_, max=max_, step=step, continuous_update=False, style=style)
else:
sliders[param] = IntSlider(min=min_, max=max_, step=step, continuous_update=False, style=style)

def update(**kwargs):
key = "_".join([str(val) for val in kwargs.values()])
if key not in cache:
anomalies = method(ts, **kwargs)[segment]
anomalies = sorted(anomalies)
cache[key] = anomalies
else:
anomalies = cache[key]
plt.figure(figsize=(20, 10))
plt.cla()
plt.plot(x, y)
plt.scatter(anomalies, y[pd.to_datetime(x).isin(anomalies)], c="r")
plt.xticks(rotation=45)
plt.show()

interact(update, **sliders)
Loading

0 comments on commit ebed9c9

Please sign in to comment.