Skip to content

Commit

Permalink
Update ruptures version (#141)
Browse files Browse the repository at this point in the history
* Update ruptures version

* Fix docstring

* Upd CHANGELOG
  • Loading branch information
julia-shenshina authored and alex-hse-repository committed Oct 7, 2021
1 parent 6a73abc commit ed53e5f
Show file tree
Hide file tree
Showing 6 changed files with 47 additions and 90 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Update EDA notebook ([#96](https://github.com/tinkoff-ai/etna-ts/pull/96), [#114](https://github.com/tinkoff-ai/etna-ts/pull/114))
- Delete offset from WindowStatisticsTransform ([#111](https://github.com/tinkoff-ai/etna-ts/pull/111))
- Add Pipeline example in Get started notebook ([#115](https://github.com/tinkoff-ai/etna-ts/pull/115))
- Internal implementation of BinsegTrendTransform ([#141](https://github.com/tinkoff-ai/etna-ts/pull/141))

### Fixed
- Add more obvious Exception Error for forecasting with unfitted model ([#102](https://github.com/tinkoff-ai/etna-ts/pull/102))
Expand Down
44 changes: 2 additions & 42 deletions etna/transforms/binseg.py
Original file line number Diff line number Diff line change
@@ -1,55 +1,15 @@
from functools import lru_cache
from typing import Any
from typing import Optional

from ruptures.base import BaseCost
from ruptures.costs import cost_factory
from ruptures.detection import Binseg
from sklearn.linear_model import LinearRegression

from etna.transforms.change_points_trend import ChangePointsTrendTransform
from etna.transforms.change_points_trend import TDetrendModel


class _Binseg(Binseg):
    """Binary segmentation whose single-breakpoint search is memoized via lru_cache."""

    def __init__(
        self,
        model: str = "l2",
        custom_cost: Optional[BaseCost] = None,
        min_size: int = 2,
        jump: int = 5,
        params: Any = None,
    ):
        """Set up the cost function and search parameters.

        The parent initializer is not called; every attribute is assigned here directly.

        Args:
            model (str, optional): segment model, ["l1", "l2", "rbf",...]. Ignored when a valid
                ``custom_cost`` is supplied.
            custom_cost (BaseCost, optional): custom cost function. It is used only if it is a
                ``BaseCost`` instance; otherwise the cost is built from ``model``. Defaults to None.
            min_size (int, optional): minimum segment length. Defaults to 2 samples. The stored
                value is never smaller than the cost's own ``min_size``.
            jump (int, optional): subsample (one every *jump* points). Defaults to 5 samples.
            params (dict, optional): extra keyword arguments forwarded to ``cost_factory``.
        """
        if custom_cost is None or not isinstance(custom_cost, BaseCost):
            # No usable custom cost: build one from the model name, forwarding params if given.
            self.cost = cost_factory(model=model) if params is None else cost_factory(model=model, **params)
        else:
            self.cost = custom_cost

        # Nothing has been fitted yet.
        self.signal = None
        self.n_samples = None

        self.jump = jump
        self.min_size = max(min_size, self.cost.min_size)

    # NOTE: the cache is attached to the class-level function, so keys include ``self`` and
    # cached entries hold a reference to every instance that has called this method.
    @lru_cache(maxsize=None)
    def single_bkp(self, start: int, end: int) -> Any:
        """Return the result of ``_single_bkp`` for ``[start, end)``, cached per argument triple."""
        breakpoint_result = self._single_bkp(start=start, end=end)
        return breakpoint_result


class BinsegTrendTransform(ChangePointsTrendTransform):
"""BinsegTrendTransform uses _Binseg model as a change point detection model in ChangePointsTrendTransform transform."""
"""BinsegTrendTransform uses Binseg model as a change point detection model in ChangePointsTrendTransform transform."""

def __init__(
self,
Expand Down Expand Up @@ -95,7 +55,7 @@ def __init__(
self.epsilon = epsilon
super().__init__(
in_column=in_column,
change_point_model=_Binseg(
change_point_model=Binseg(
model=self.model, custom_cost=self.custom_cost, min_size=self.min_size, jump=self.jump
),
detrend_model=detrend_model,
Expand Down
4 changes: 1 addition & 3 deletions etna/transforms/change_points_trend.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,7 @@ def fit(self, df: pd.DataFrame) -> "OneSegmentChangePointsTransform":
-------
self
"""
# we need copy here because Binseg with CostAR (model="ar") changes given signal inplace; if it is fixed
# @TODO: delete copy
series = df.loc[df[self.in_column].first_valid_index() :, self.in_column].copy(deep=True)
series = df.loc[df[self.in_column].first_valid_index() :, self.in_column]
change_points = self._get_change_points(series=series)
self.intervals = self._build_trend_intervals(change_points=change_points)
self.per_interval_models = self._init_detrend_models(intervals=self.intervals)
Expand Down
64 changes: 31 additions & 33 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ scikit-learn = "^0.24.1"
prophet = "^1.0"
pandas = "^1"
catboost = "^0.25"
ruptures = "1.1.3"
ruptures = "1.1.5"
torch = "1.8.*"
pytorch-forecasting = "0.8.5"
numba = "^0.53.1"
Expand Down
Loading

0 comments on commit ed53e5f

Please sign in to comment.