Skip to content

Add find_change_points and tests for it #521

Merged
merged 8 commits into from
Feb 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
-
-
-
-
- Add find_change_points function ([#521](https://github.com/tinkoff-ai/etna/pull/521))
-
-

Expand Down
1 change: 1 addition & 0 deletions etna/analysis/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from etna.analysis.change_points_trend import find_change_points
from etna.analysis.eda_utils import cross_corr_plot
from etna.analysis.eda_utils import distribution_plot
from etna.analysis.eda_utils import sample_acf_plot
Expand Down
1 change: 1 addition & 0 deletions etna/analysis/change_points_trend/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from etna.analysis.change_points_trend.find_change_points import find_change_points
63 changes: 63 additions & 0 deletions etna/analysis/change_points_trend/find_change_points.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from typing import Dict
from typing import List

import numpy as np
import pandas as pd
from ruptures.base import BaseEstimator
from ruptures.costs import CostLinear

from etna.datasets import TSDataset


def _prepare_signal(series: pd.Series, model: BaseEstimator) -> np.ndarray:
"""Prepare series for change point model."""
signal = series.to_numpy()
if isinstance(model.cost, CostLinear):
signal = signal.reshape((-1, 1))
return signal


def _find_change_points_segment(
series: pd.Series, change_point_model: BaseEstimator, **model_predict_params
) -> List[pd.Timestamp]:
"""Find trend change points within one segment."""
signal = _prepare_signal(series=series, model=change_point_model)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Make the order of parameters consistent with the order of parameters in signature(or vice versa)

timestamp = series.index
change_point_model.fit(signal=signal)
# last point in change points is the first index after the series
change_points_indices = change_point_model.predict(**model_predict_params)[:-1]
change_points = [timestamp[idx] for idx in change_points_indices]
return change_points


def find_change_points(
ts: TSDataset, in_column: str, change_point_model: BaseEstimator, **model_predict_params
) -> Dict[str, List[pd.Timestamp]]:
"""Find trend change points using ruptures models.

Parameters
----------
ts:
dataset to work with
in_column:
name of column to work with
change_point_model:
ruptures model to get trend change points
model_predict_params:
params for change_point_model predict method

Returns
-------
result:
dictionary with list of trend change points for each segment
"""
result: Dict[str, List[pd.Timestamp]] = {}
df = ts.to_pandas()
for segment in ts.segments:
df_segment = df[segment]
raw_series = df_segment[in_column]
series = raw_series.loc[raw_series.first_valid_index() : raw_series.last_valid_index()]
result[segment] = _find_change_points_segment(
series=series, change_point_model=change_point_model, **model_predict_params
)
return result
23 changes: 4 additions & 19 deletions etna/transforms/decomposition/change_points_trend.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
import numpy as np
import pandas as pd
from ruptures.base import BaseEstimator
from ruptures.costs import CostLinear
from sklearn.base import RegressorMixin

from etna.analysis.change_points_trend.find_change_points import _find_change_points_segment
from etna.transforms.base import PerSegmentWrapper
from etna.transforms.base import Transform
from etna.transforms.utils import match_target_quantiles
Expand Down Expand Up @@ -50,23 +50,6 @@ def __init__(
self.intervals: Optional[List[TTimestampInterval]] = None
self.change_point_model_predict_params = change_point_model_predict_params

def _prepare_signal(self, series: pd.Series) -> np.ndarray:
"""Prepare series for change point model."""
signal = series.to_numpy()
if isinstance(self.change_point_model.cost, CostLinear):
signal = signal.reshape((-1, 1))
return signal

def _get_change_points(self, series: pd.Series) -> List[pd.Timestamp]:
"""Fit change point model with series data and predict trends change points."""
signal = self._prepare_signal(series=series)
timestamp = series.index
self.change_point_model.fit(signal=signal)
# last point in change points is the first index after the series
change_points_indices = self.change_point_model.predict(**self.change_point_model_predict_params)[:-1]
change_points = [timestamp[idx] for idx in change_points_indices]
return change_points

@staticmethod
def _build_trend_intervals(change_points: List[pd.Timestamp]) -> List[TTimestampInterval]:
"""Create list of stable trend intervals from list of change points."""
Expand Down Expand Up @@ -132,7 +115,9 @@ def fit(self, df: pd.DataFrame) -> "_OneSegmentChangePointsTrendTransform":
series = df.loc[df[self.in_column].first_valid_index() : df[self.in_column].last_valid_index(), self.in_column]
if series.isnull().values.any():
raise ValueError("The input column contains NaNs in the middle of the series! Try to use the imputer.")
change_points = self._get_change_points(series=series)
change_points = _find_change_points_segment(
series=series, change_point_model=self.change_point_model, **self.change_point_model_predict_params
)
self.intervals = self._build_trend_intervals(change_points=change_points)
self.per_interval_models = self._init_detrend_models(intervals=self.intervals)
self._fit_per_interval_model(series=series)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from typing import Dict
from typing import List

import numpy as np
import pandas as pd
import pytest
from ruptures import Binseg

from etna.analysis import find_change_points
from etna.datasets import TSDataset


def check_change_points(change_points: Dict[str, List[pd.Timestamp]], segments: List[str], num_points: int):
"""Check change points on validity."""
assert isinstance(change_points, dict)
assert set(change_points.keys()) == set(segments)
for segment in segments:
change_points_segment = change_points[segment]
assert len(change_points_segment) == num_points
for point in change_points_segment:
assert isinstance(point, pd.Timestamp)


@pytest.mark.parametrize("n_bkps", [5, 10, 12, 27])
def test_find_change_points_simple(multitrend_df: pd.DataFrame, n_bkps: int):
"""Test that find_change_points works fine with multitrend example."""
ts = TSDataset(df=multitrend_df, freq="D")
change_points = find_change_points(ts=ts, in_column="target", change_point_model=Binseg(), n_bkps=n_bkps)
check_change_points(change_points, segments=ts.segments, num_points=n_bkps)


@pytest.mark.parametrize("n_bkps", [5, 10, 12, 27])
def test_find_change_points_nans_head(multitrend_df: pd.DataFrame, n_bkps: int):
"""Test that find_change_points works fine with nans at the beginning of the series."""
multitrend_df.iloc[:5, :] = np.NaN
ts = TSDataset(df=multitrend_df, freq="D")
change_points = find_change_points(ts=ts, in_column="target", change_point_model=Binseg(), n_bkps=n_bkps)
check_change_points(change_points, segments=ts.segments, num_points=n_bkps)


@pytest.mark.parametrize("n_bkps", [5, 10, 12, 27])
def test_find_change_points_nans_tail(multitrend_df: pd.DataFrame, n_bkps: int):
"""Test that find_change_points works fine with nans at the end of the series."""
multitrend_df.iloc[-5:, :] = np.NaN
ts = TSDataset(df=multitrend_df, freq="D")
change_points = find_change_points(ts=ts, in_column="target", change_point_model=Binseg(), n_bkps=n_bkps)
check_change_points(change_points, segments=ts.segments, num_points=n_bkps)
Original file line number Diff line number Diff line change
Expand Up @@ -38,19 +38,6 @@ def multitrend_df_with_nans_in_tails(multitrend_df):
return multitrend_df


@pytest.mark.parametrize("n_bkps", (5, 10, 12, 27))
def test_get_change_points(multitrend_df: pd.DataFrame, n_bkps: int):
"""Check that _get_change_points method return correct number of points in correct format."""
bs = _OneSegmentChangePointsTrendTransform(
in_column="target", change_point_model=Binseg(), detrend_model=LinearRegression(), n_bkps=n_bkps
)
change_points = bs._get_change_points(multitrend_df["segment_1"]["target"])
assert isinstance(change_points, list)
assert len(change_points) == n_bkps
for point in change_points:
assert isinstance(point, pd.Timestamp)


def test_build_trend_intervals():
"""Check correctness of intervals generation with list of change points."""
change_points = [pd.Timestamp("2020-01-01"), pd.Timestamp("2020-01-18"), pd.Timestamp("2020-02-24")]
Expand Down