Skip to content

Inherit all PerSegment models from PerSegmentModel #543

Merged
merged 7 commits into from
Feb 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,23 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Add find_change_points function ([#521](https://github.com/tinkoff-ai/etna/pull/521))
-
- Add plot_residuals ([#539](https://github.com/tinkoff-ai/etna/pull/539))

-
- Create `PerSegmentBaseModel`, `PerSegmentPredictionIntervalModel` ([#537](https://github.com/tinkoff-ai/etna/pull/537))
-
### Changed
- Change the way `ProphetModel` works with regressors ([#383](https://github.com/tinkoff-ai/etna/pull/383))
- Change the way `SARIMAXModel` works with regressors ([#380](https://github.com/tinkoff-ai/etna/pull/380))
- Change the way `Sklearn` models works with regressors ([#440](https://github.com/tinkoff-ai/etna/pull/440))
- Change the way `FeatureSelectionTransform` works with regressors, rename variables replacing the "regressor" to "feature" ([#522](https://github.com/tinkoff-ai/etna/pull/522))
-
-
-
-
- Installation instruction ([#526](https://github.com/tinkoff-ai/etna/pull/526))
-
-
- Trainer kwargs for deep models ([#540](https://github.com/tinkoff-ai/etna/pull/540))
- Update CONTRIBUTING.md ([#536](https://github.com/tinkoff-ai/etna/pull/536))

-
- Rename `_CatBoostModel`, `_HoltWintersModel`, `_SklearnModel` ([#543](https://github.com/tinkoff-ai/etna/pull/543))
-
### Fixed
- Fix `TSDataset._update_regressors` logic removing the regressors ([#489](https://github.com/tinkoff-ai/etna/pull/489))
- Fix `TSDataset.info`, `TSDataset.describe` methods ([#519](https://github.com/tinkoff-ai/etna/pull/519))
Expand Down
2 changes: 1 addition & 1 deletion etna/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def fit(self, ts: TSDataset) -> "PerSegmentBaseModel":
segment_features = segment_features.dropna()
segment_features = segment_features.droplevel("segment", axis=1)
segment_features = segment_features.reset_index()
model.fit(df=segment_features)
model.fit(df=segment_features, regressors=ts.regressors)
return self

def get_model(self) -> Dict[str, Any]:
Expand Down
11 changes: 6 additions & 5 deletions etna/models/catboost.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import List
from typing import Optional

import numpy as np
Expand All @@ -11,7 +12,7 @@
from etna.models.base import log_decorator


class _CatBoostModel:
class _CatBoostAdapter:
def __init__(
self,
iterations: Optional[int] = None,
Expand All @@ -34,7 +35,7 @@ def __init__(
)
self._categorical = None

def fit(self, df: pd.DataFrame) -> "_CatBoostModel":
def fit(self, df: pd.DataFrame, regressors: List[str]) -> "_CatBoostAdapter":
features = df.drop(columns=["timestamp", "target"])
target = df["target"]
self._categorical = features.select_dtypes(include=["category"]).columns.to_list()
Expand Down Expand Up @@ -150,7 +151,7 @@ def __init__(
self.thread_count = thread_count
self.kwargs = kwargs
super(CatBoostModelPerSegment, self).__init__(
base_model=_CatBoostModel(
base_model=_CatBoostAdapter(
iterations=iterations,
depth=depth,
learning_rate=learning_rate,
Expand Down Expand Up @@ -263,7 +264,7 @@ def __init__(
self.thread_count = thread_count
self.kwargs = kwargs
super(CatBoostModelMultiSegment, self).__init__()
self._base_model = _CatBoostModel(
self._base_model = _CatBoostAdapter(
iterations=iterations,
depth=depth,
learning_rate=learning_rate,
Expand All @@ -279,7 +280,7 @@ def fit(self, ts: TSDataset) -> "CatBoostModelMultiSegment":
df = ts.to_pandas(flatten=True)
df = df.dropna()
df = df.drop(columns="segment")
self._base_model.fit(df=df)
self._base_model.fit(df=df, regressors=ts.regressors)
return self

@log_decorator
Expand Down
9 changes: 5 additions & 4 deletions etna/models/holt_winters.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import warnings
from datetime import datetime
from typing import Dict
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
Expand All @@ -13,7 +14,7 @@
from etna.models.base import PerSegmentModel


class _HoltWintersModel:
class _HoltWintersAdapter:
"""
Class for holding Holt-Winters' exponential smoothing model.

Expand Down Expand Up @@ -168,7 +169,7 @@ def __init__(
self._model: Optional[ExponentialSmoothing] = None
self._result: Optional[HoltWintersResults] = None

def fit(self, df: pd.DataFrame) -> "_HoltWintersModel":
def fit(self, df: pd.DataFrame, regressors: List[str]) -> "_HoltWintersAdapter":
"""
Fits a Holt-Winters' model.

Expand All @@ -179,7 +180,7 @@ def fit(self, df: pd.DataFrame) -> "_HoltWintersModel":

Returns
-------
self: _HoltWintersModel
self: _HoltWintersAdapter
fitted model
"""
self._check_df(df)
Expand Down Expand Up @@ -396,7 +397,7 @@ def __init__(
self.damping_trend = damping_trend
self.fit_kwargs = fit_kwargs
super().__init__(
base_model=_HoltWintersModel(
base_model=_HoltWintersAdapter(
trend=self.trend,
damped_trend=self.damped_trend,
seasonal=self.seasonal,
Expand Down
3 changes: 2 additions & 1 deletion etna/models/seasonal_ma.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import warnings
from typing import List

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -33,7 +34,7 @@ def __init__(self, window: int = 5, seasonality: int = 7):
self.seasonality = seasonality
self.shift = self.window * self.seasonality

def fit(self, df: pd.DataFrame) -> "_SeasonalMovingAverageModel":
def fit(self, df: pd.DataFrame, regressors: List[str]) -> "_SeasonalMovingAverageModel":
"""
Fitting simple model on given series.

Expand Down
23 changes: 4 additions & 19 deletions etna/models/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@
from etna.models.base import log_decorator


class _SklearnModel:
class _SklearnAdapter:
def __init__(self, regressor: RegressorMixin):
self.model = regressor
self.regressor_columns: Optional[List[str]] = None

def fit(self, df: pd.DataFrame, regressors: List[str]) -> "_SklearnModel":
def fit(self, df: pd.DataFrame, regressors: List[str]) -> "_SklearnAdapter":
self.regressor_columns = regressors
try:
features = df[self.regressor_columns].apply(pd.to_numeric)
Expand Down Expand Up @@ -47,22 +47,7 @@ def __init__(self, regressor: RegressorMixin):
regressor:
sklearn model for regression
"""
super().__init__(base_model=_SklearnModel(regressor=regressor))

@log_decorator
def fit(self, ts: TSDataset) -> "SklearnPerSegmentModel":
"""Fit model."""
self._segments = ts.segments
self._build_models()

for segment in self._segments:
model = self._models[segment] # type: ignore
segment_features = ts[:, segment, :]
segment_features = segment_features.dropna()
segment_features = segment_features.droplevel("segment", axis=1)
segment_features = segment_features.reset_index()
model.fit(df=segment_features, regressors=ts.regressors)
return self
super().__init__(base_model=_SklearnAdapter(regressor=regressor))


class SklearnMultiSegmentModel(Model):
Expand All @@ -78,7 +63,7 @@ def __init__(self, regressor: RegressorMixin):
sklearn model for regression
"""
super().__init__()
self._base_model = _SklearnModel(regressor=regressor)
self._base_model = _SklearnAdapter(regressor=regressor)

@log_decorator
def fit(self, ts: TSDataset) -> "SklearnMultiSegmentModel":
Expand Down
3 changes: 1 addition & 2 deletions tests/test_models/test_linear_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,6 @@ def linear_segments_ts_common(random_seed):
return linear_segments_by_parameters(alpha_values, intercept_values)


@pytest.mark.xfail
@pytest.mark.parametrize("model", (LinearPerSegmentModel(), ElasticPerSegmentModel()))
def test_not_fitted(model, linear_segments_ts_unique):
"""Check exception when trying to forecast with unfitted model."""
Expand All @@ -87,7 +86,7 @@ def test_not_fitted(model, linear_segments_ts_unique):
train.fit_transform([lags])

to_forecast = train.make_future(3)
with pytest.raises(ValueError, match="model is not fitted"):
with pytest.raises(ValueError, match="not fitted model!"):
model.forecast(to_forecast)


Expand Down