-
Notifications
You must be signed in to change notification settings - Fork 80
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add BinsegTrendTransform, ChangePointsTrendTransform (#87)
- Loading branch information
1 parent
9f0ff82
commit 014e6b6
Showing
7 changed files
with
581 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
from functools import lru_cache | ||
from typing import Any | ||
from typing import Optional | ||
|
||
from ruptures.base import BaseCost | ||
from ruptures.costs import cost_factory | ||
from ruptures.detection import Binseg | ||
from sklearn.linear_model import LinearRegression | ||
|
||
from etna.transforms.change_points_trend import ChangePointsTrendTransform | ||
from etna.transforms.change_points_trend import TDetrendModel | ||
|
||
|
||
class _Binseg(Binseg): | ||
"""Binary segmentation with lru_cache.""" | ||
|
||
def __init__( | ||
self, | ||
model: str = "l2", | ||
custom_cost: Optional[BaseCost] = None, | ||
min_size: int = 2, | ||
jump: int = 5, | ||
params: Any = None, | ||
): | ||
"""Initialize a Binseg instance. | ||
Args: | ||
model (str, optional): segment model, ["l1", "l2", "rbf",...]. Not used if ``'custom_cost'`` is not None. | ||
custom_cost (BaseCost, optional): custom cost function. Defaults to None. | ||
min_size (int, optional): minimum segment length. Defaults to 2 samples. | ||
jump (int, optional): subsample (one every *jump* points). Defaults to 5 samples. | ||
params (dict, optional): a dictionary of parameters for the cost instance. | ||
""" | ||
if custom_cost is not None and isinstance(custom_cost, BaseCost): | ||
self.cost = custom_cost | ||
elif params is None: | ||
self.cost = cost_factory(model=model) | ||
else: | ||
self.cost = cost_factory(model=model, **params) | ||
self.min_size = max(min_size, self.cost.min_size) | ||
self.jump = jump | ||
self.n_samples = None | ||
self.signal = None | ||
|
||
@lru_cache(maxsize=None) | ||
def single_bkp(self, start: int, end: int) -> Any: | ||
"""Run _single_bkp with lru_cache decorator.""" | ||
return self._single_bkp(start=start, end=end) | ||
|
||
|
||
class BinsegTrendTransform(ChangePointsTrendTransform): | ||
"""BinsegTrendTransform uses _Binseg model as a change point detection model in ChangePointsTrendTransform transform.""" | ||
|
||
def __init__( | ||
self, | ||
in_column: str, | ||
detrend_model: TDetrendModel = LinearRegression(), | ||
model: str = "ar", | ||
custom_cost: Optional[BaseCost] = None, | ||
min_size: int = 2, | ||
jump: int = 1, | ||
n_bkps: int = 5, | ||
pen: Optional[float] = None, | ||
epsilon: Optional[float] = None, | ||
): | ||
"""Init BinsegTrendTransform. | ||
Parameters | ||
---------- | ||
in_column: | ||
name of column to apply transform to | ||
detrend_model: | ||
model to get trend in data | ||
model: | ||
binseg segment model, ["l1", "l2", "rbf",...]. Not used if 'custom_cost' is not None. | ||
custom_cost: | ||
binseg custom cost function | ||
min_size: | ||
minimum segment length necessary to decide it is a stable trend segment | ||
jump: | ||
jump value can speed up computations: if jump==k, the algo will use every k-th value for change points search. | ||
n_bkps: | ||
number of change points to find | ||
pen: | ||
penalty value (>0) | ||
epsilon: | ||
reconstruction budget (>0) | ||
""" | ||
self.model = model | ||
self.custom_cost = custom_cost | ||
self.min_size = min_size | ||
self.jump = jump | ||
self.n_bkps = n_bkps | ||
self.pen = pen | ||
self.epsilon = epsilon | ||
super().__init__( | ||
in_column=in_column, | ||
change_point_model=_Binseg( | ||
model=self.model, custom_cost=self.custom_cost, min_size=self.min_size, jump=self.jump | ||
), | ||
detrend_model=detrend_model, | ||
n_bkps=self.n_bkps, | ||
pen=self.pen, | ||
epsilon=self.epsilon, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,208 @@ | ||
from copy import deepcopy | ||
from typing import Dict | ||
from typing import List | ||
from typing import Optional | ||
from typing import Tuple | ||
|
||
import numpy as np | ||
import pandas as pd | ||
from ruptures.base import BaseEstimator | ||
from ruptures.costs import CostLinear | ||
from sklearn.base import RegressorMixin | ||
|
||
from etna.transforms.base import PerSegmentWrapper | ||
from etna.transforms.base import Transform | ||
|
||
TTimestampInterval = Tuple[pd.Timestamp, pd.Timestamp] | ||
TDetrendModel = RegressorMixin | ||
|
||
|
||
class _OneSegmentChangePointsTrendTransform(Transform): | ||
"""_OneSegmentChangePointsTransform subtracts multiple linear trend from series.""" | ||
|
||
def __init__( | ||
self, | ||
in_column: str, | ||
change_point_model: BaseEstimator, | ||
detrend_model: TDetrendModel, | ||
**change_point_model_predict_params, | ||
): | ||
"""Init _OneSegmentChangePointsTrendTransform. | ||
Parameters | ||
---------- | ||
in_column: | ||
name of column to apply transform to | ||
change_point_model: | ||
model to get trend change points | ||
detrend_model: | ||
model to get trend in data | ||
change_point_model_predict_params: | ||
params for change_point_model predict method | ||
""" | ||
self.in_column = in_column | ||
self.out_columns = in_column | ||
self.change_point_model = change_point_model | ||
self.detrend_model = detrend_model | ||
self.per_interval_models: Optional[Dict[TTimestampInterval, TDetrendModel]] = None | ||
self.intervals: Optional[List[TTimestampInterval]] = None | ||
self.change_point_model_predict_params = change_point_model_predict_params | ||
|
||
def _prepare_signal(self, series: pd.Series) -> np.array: | ||
"""Prepare series for change point model.""" | ||
signal = series.to_numpy() | ||
if isinstance(self.change_point_model.cost, CostLinear): | ||
signal = signal.reshape((-1, 1)) | ||
return signal | ||
|
||
def _get_change_points(self, series: pd.Series) -> List[pd.Timestamp]: | ||
"""Fit change point model with series data and predict trens change points.""" | ||
signal = self._prepare_signal(series=series) | ||
timestamp = series.index | ||
self.change_point_model.fit(signal=signal) | ||
# last point in change points is the first index after the series | ||
change_points_indices = self.change_point_model.predict(**self.change_point_model_predict_params)[:-1] | ||
change_points = [timestamp[idx] for idx in change_points_indices] | ||
return change_points | ||
|
||
@staticmethod | ||
def _build_trend_intervals(change_points: List[pd.Timestamp]) -> List[TTimestampInterval]: | ||
"""Create list of stable trend intervals from list of change points.""" | ||
change_points = sorted(change_points) | ||
left_border = pd.Timestamp.min | ||
intervals = [] | ||
for point in change_points: | ||
right_border = point | ||
intervals.append((left_border, right_border)) | ||
left_border = right_border | ||
intervals.append((left_border, pd.Timestamp.max)) | ||
return intervals | ||
|
||
def _init_detrend_models( | ||
self, intervals: List[TTimestampInterval] | ||
) -> Dict[Tuple[pd.Timestamp, pd.Timestamp], TDetrendModel]: | ||
"""Create copy of detrend model for each timestamp interval.""" | ||
per_interval_models = {interval: deepcopy(self.detrend_model) for interval in intervals} | ||
return per_interval_models | ||
|
||
def _get_timestamps(self, series: pd.Series) -> np.ndarray: | ||
"""Convert ETNA timestamp-index to a list of timestamps to fit regression models.""" | ||
timestamps = series.index | ||
timestamps = np.array([[ts.timestamp()] for ts in timestamps]) | ||
return timestamps | ||
|
||
def _fit_per_interval_model(self, series: pd.Series): | ||
"""Fit per-interval models with corresponding data from series.""" | ||
for interval in self.intervals: | ||
tmp_series = series[interval[0] : interval[1]] | ||
x = self._get_timestamps(series=tmp_series) | ||
y = tmp_series.values | ||
self.per_interval_models[interval].fit(x, y) | ||
|
||
def _predict_per_interval_model(self, series: pd.Series) -> pd.Series: | ||
"""Apply per-interval detrending to series.""" | ||
trend_series = pd.Series(index=series.index) | ||
for interval in self.intervals: | ||
tmp_series = series[interval[0] : interval[1]] | ||
if tmp_series.empty: | ||
continue | ||
x = self._get_timestamps(series=tmp_series) | ||
trend = self.per_interval_models[interval].predict(x) | ||
trend_series[tmp_series.index] = trend | ||
return trend_series | ||
|
||
def fit(self, df: pd.DataFrame) -> "OneSegmentChangePointsTransform": | ||
"""Fit OneSegmentChangePointsTransform: find thrend change points in df, fit detrend models with data from intervals of stable trend. | ||
Parameters | ||
---------- | ||
df: | ||
one segment dataframe indexed with timestamp | ||
Returns | ||
------- | ||
self | ||
""" | ||
# we need copy here because Binseg with CostAR (model="ar") changes given signal inplace; if it is fixed | ||
# @TODO: delete copy | ||
series = df.loc[df[self.in_column].first_valid_index() :, self.in_column].copy(deep=True) | ||
change_points = self._get_change_points(series=series) | ||
self.intervals = self._build_trend_intervals(change_points=change_points) | ||
self.per_interval_models = self._init_detrend_models(intervals=self.intervals) | ||
self._fit_per_interval_model(series=series) | ||
return self | ||
|
||
def transform(self, df: pd.DataFrame) -> pd.DataFrame: | ||
"""Split df to intervals of stable trend and subtract trend from each one. | ||
Parameters | ||
---------- | ||
df: | ||
one segment dataframe to subtract trend | ||
Returns | ||
------- | ||
detrended df: pd.DataFrame | ||
df with detrended in_column series | ||
""" | ||
df._is_copy = False | ||
series = df.loc[df[self.in_column].first_valid_index() :, self.in_column] | ||
trend_series = self._predict_per_interval_model(series=series) | ||
df.loc[:, self.in_column] -= trend_series | ||
return df | ||
|
||
def inverse_transform(self, df: pd.DataFrame) -> pd.DataFrame: | ||
"""Split df to intervals of stable trend according to previous change point detection and add trend to each one. | ||
Parameters | ||
---------- | ||
df: | ||
one segment dataframe to turn trend back | ||
Returns | ||
------- | ||
df: pd.DataFrame | ||
df with restored trend in in_column | ||
""" | ||
df._is_copy = False | ||
series = df.loc[df[self.in_column].first_valid_index() :, self.in_column] | ||
trend_series = self._predict_per_interval_model(series=series) | ||
df.loc[:, self.in_column] += trend_series | ||
return df | ||
|
||
|
||
class ChangePointsTrendTransform(PerSegmentWrapper): | ||
"""ChangePointsTrendTransform subtracts multiple linear trend from series.""" | ||
|
||
def __init__( | ||
self, | ||
in_column: str, | ||
change_point_model: BaseEstimator, | ||
detrend_model: TDetrendModel, | ||
**change_point_model_predict_params, | ||
): | ||
"""Init ChangePointsTrendTransform. | ||
Parameters | ||
---------- | ||
in_column: | ||
name of column to apply transform to | ||
change_point_model: | ||
model to get trend change points | ||
detrend_model: | ||
model to get trend in data | ||
change_point_model_predict_params: | ||
params for change_point_model predict method | ||
""" | ||
self.in_column = in_column | ||
self.change_point_model = change_point_model | ||
self.detrend_model = detrend_model | ||
self.change_point_model_predict_params = change_point_model_predict_params | ||
super().__init__( | ||
transform=_OneSegmentChangePointsTrendTransform( | ||
in_column=self.in_column, | ||
change_point_model=self.change_point_model, | ||
detrend_model=self.detrend_model, | ||
**self.change_point_model_predict_params, | ||
) | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.