Add find_change_points
and tests for it
#521
Merged
8 commits
d9ce401
Add find_change_points and tests for it
c66add8
Update changelog
2f2ba3a
Update changelog
22bdceb
Merge remote-tracking branch 'origin/master' into issue-485
29b29a9
Rename function for one segment
a2c4bce
Use finding change points from analysis
fda9fdc
Make consistent call of _prepare_signal, rename change trend points -…
f7f2a6f
Merge remote-tracking branch 'origin/master' into issue-485
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from etna.analysis.change_points_trend.find_change_points import find_change_points |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
from typing import Dict | ||
from typing import List | ||
|
||
import numpy as np | ||
import pandas as pd | ||
from ruptures.base import BaseEstimator | ||
from ruptures.costs import CostLinear | ||
|
||
from etna.datasets import TSDataset | ||
|
||
|
||
def _prepare_signal(series: pd.Series, model: BaseEstimator) -> np.ndarray: | ||
"""Prepare series for change point model.""" | ||
signal = series.to_numpy() | ||
if isinstance(model.cost, CostLinear): | ||
signal = signal.reshape((-1, 1)) | ||
return signal | ||
|
||
|
||
def _find_change_points_segment( | ||
series: pd.Series, change_point_model: BaseEstimator, **model_predict_params | ||
) -> List[pd.Timestamp]: | ||
"""Find trend change points within one segment.""" | ||
signal = _prepare_signal(series=series, model=change_point_model) | ||
timestamp = series.index | ||
change_point_model.fit(signal=signal) | ||
# last point in change points is the first index after the series | ||
change_points_indices = change_point_model.predict(**model_predict_params)[:-1] | ||
change_points = [timestamp[idx] for idx in change_points_indices] | ||
return change_points | ||
|
||
|
||
def find_change_points( | ||
ts: TSDataset, in_column: str, change_point_model: BaseEstimator, **model_predict_params | ||
) -> Dict[str, List[pd.Timestamp]]: | ||
"""Find trend change points using ruptures models. | ||
|
||
Parameters | ||
---------- | ||
ts: | ||
dataset to work with | ||
in_column: | ||
name of column to work with | ||
change_point_model: | ||
ruptures model to get trend change points | ||
model_predict_params: | ||
params for change_point_model predict method | ||
|
||
Returns | ||
------- | ||
result: | ||
dictionary with list of trend change points for each segment | ||
""" | ||
result: Dict[str, List[pd.Timestamp]] = {} | ||
df = ts.to_pandas() | ||
for segment in ts.segments: | ||
df_segment = df[segment] | ||
raw_series = df_segment[in_column] | ||
series = raw_series.loc[raw_series.first_valid_index() : raw_series.last_valid_index()] | ||
result[segment] = _find_change_points_segment( | ||
series=series, change_point_model=change_point_model, **model_predict_params | ||
) | ||
return result |
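The two data-handling steps in the file above (trimming leading and trailing NaNs, then dropping the final breakpoint and mapping the remaining indices to timestamps) can be tried without etna or ruptures. The following is only a minimal sketch: `trim_nans` and `indices_to_timestamps` are hypothetical helper names, and the `breakpoints` list is hard-coded to imitate the shape of a ruptures-style `predict` result (last entry equal to the signal length) rather than produced by a real model.

```python
import numpy as np
import pandas as pd


def trim_nans(series: pd.Series) -> pd.Series:
    """Keep only the span between the first and last non-NaN values."""
    return series.loc[series.first_valid_index() : series.last_valid_index()]


def indices_to_timestamps(series: pd.Series, breakpoints: list) -> list:
    """Map breakpoint positions to timestamps of the series index.

    The last breakpoint is assumed to equal len(series), i.e. the first
    position after the series, so it is dropped before the lookup.
    """
    return [series.index[idx] for idx in breakpoints[:-1]]


index = pd.date_range("2020-01-01", periods=10, freq="D")
values = [np.nan, np.nan, 1.0, 2.0, 3.0, 10.0, 20.0, 30.0, np.nan, np.nan]
raw_series = pd.Series(values, index=index)

# Two NaNs are cut from each end, leaving six observations.
series = trim_nans(raw_series)

# Hand-written stand-in for a model's predict output on the trimmed signal.
breakpoints = [3, len(series)]
change_points = indices_to_timestamps(series, breakpoints)
print(change_points)
```

Because the positions are resolved against the trimmed series, a breakpoint at position 3 refers to the fourth valid observation, not the fourth row of the original frame.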
47 changes: 47 additions & 0 deletions
tests/test_analysis/test_change_points_trend/test_find_change_points.py
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
from typing import Dict | ||
from typing import List | ||
|
||
import numpy as np | ||
import pandas as pd | ||
import pytest | ||
from ruptures import Binseg | ||
|
||
from etna.analysis import find_change_points | ||
from etna.datasets import TSDataset | ||
|
||
|
||
def check_change_points(change_points: Dict[str, List[pd.Timestamp]], segments: List[str], num_points: int): | ||
"""Check change points on validity.""" | ||
assert isinstance(change_points, dict) | ||
assert set(change_points.keys()) == set(segments) | ||
for segment in segments: | ||
change_points_segment = change_points[segment] | ||
assert len(change_points_segment) == num_points | ||
for point in change_points_segment: | ||
assert isinstance(point, pd.Timestamp) | ||
|
||
|
||
@pytest.mark.parametrize("n_bkps", [5, 10, 12, 27]) | ||
def test_find_change_points_simple(multitrend_df: pd.DataFrame, n_bkps: int): | ||
"""Test that find_change_points works fine with multitrend example.""" | ||
ts = TSDataset(df=multitrend_df, freq="D") | ||
change_points = find_change_points(ts=ts, in_column="target", change_point_model=Binseg(), n_bkps=n_bkps) | ||
check_change_points(change_points, segments=ts.segments, num_points=n_bkps) | ||
|
||
|
||
@pytest.mark.parametrize("n_bkps", [5, 10, 12, 27]) | ||
def test_find_change_points_nans_head(multitrend_df: pd.DataFrame, n_bkps: int): | ||
"""Test that find_change_points works fine with nans at the beginning of the series.""" | ||
multitrend_df.iloc[:5, :] = np.nan | ||
ts = TSDataset(df=multitrend_df, freq="D") | ||
change_points = find_change_points(ts=ts, in_column="target", change_point_model=Binseg(), n_bkps=n_bkps) | ||
check_change_points(change_points, segments=ts.segments, num_points=n_bkps) | ||
|
||
|
||
@pytest.mark.parametrize("n_bkps", [5, 10, 12, 27]) | ||
def test_find_change_points_nans_tail(multitrend_df: pd.DataFrame, n_bkps: int): | ||
"""Test that find_change_points works fine with nans at the end of the series.""" | ||
multitrend_df.iloc[-5:, :] = np.nan | ||
ts = TSDataset(df=multitrend_df, freq="D") | ||
change_points = find_change_points(ts=ts, in_column="target", change_point_model=Binseg(), n_bkps=n_bkps) | ||
check_change_points(change_points, segments=ts.segments, num_points=n_bkps) |
Make the order of parameters in the docstring consistent with the order of parameters in the signature (or vice versa).
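One way to keep this from regressing is to compare the two orders mechanically. The sketch below is not part of this PR: `documented_params` is a hypothetical helper that assumes a numpy-style `Parameters` section with unindented `name:` entries, and `example` is a toy stand-in that mirrors the signature and docstring layout of `find_change_points`.

```python
import inspect
import re
from typing import Callable, List


def documented_params(func: Callable) -> List[str]:
    """Return parameter names in the order they appear in the docstring."""
    lines = (inspect.getdoc(func) or "").splitlines()
    names: List[str] = []
    in_section = False
    for idx, line in enumerate(lines):
        next_line = lines[idx + 1] if idx + 1 < len(lines) else ""
        # A section header is a non-empty line underlined with dashes.
        is_header = bool(line.strip()) and set(next_line.strip()) == {"-"}
        if is_header:
            in_section = line.strip() == "Parameters"
            continue
        # Parameter entries are unindented "name:" lines inside the section.
        match = re.match(r"^(\w+)\s*:", line)
        if in_section and match:
            names.append(match.group(1))
    return names


def example(ts, in_column, change_point_model, **model_predict_params):
    """Toy stand-in mirroring the signature from the diff.

    Parameters
    ----------
    ts:
        dataset to work with
    in_column:
        name of column to work with
    change_point_model:
        ruptures model to get trend change points
    model_predict_params:
        params for change_point_model predict method

    Returns
    -------
    result:
        dictionary with list of trend change points for each segment
    """
    return {}


signature_order = list(inspect.signature(example).parameters)
docstring_order = documented_params(example)
print(signature_order == docstring_order)
```

The comparison treats `**model_predict_params` like any other name, since `inspect.signature` reports it without the leading asterisks.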