Skip to content

Commit

Permalink
Function to transform etna objects to_dict (#834)
Browse files Browse the repository at this point in the history
  • Loading branch information
scanhex12 authored Aug 8, 2022
1 parent 3c036d2 commit c0a21a7
Show file tree
Hide file tree
Showing 4 changed files with 185 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## Unreleased
### Added
-
- Function to transform etna objects to dict ([#818](https://github.com/tinkoff-ai/etna/issues/818))
-
- `DeadlineMovingAverageModel` ([#827](https://github.com/tinkoff-ai/etna/pull/827))
- `DirectEnsemble` ([#824](https://github.com/tinkoff-ai/etna/pull/824))
Expand Down
47 changes: 47 additions & 0 deletions etna/core/mixins.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
import inspect
import warnings
from enum import Enum
from typing import Any
from typing import Dict
from typing import List

from sklearn.base import BaseEstimator


class BaseMixin:
Expand All @@ -26,6 +31,48 @@ def __repr__(self):
args_str_representation += f"{arg} = {repr(value)}, "
return f"{self.__class__.__name__}({args_str_representation})"

@staticmethod
def _get_target_from_class(value: Any):
if value is None:
return None
return str(value.__module__) + "." + str(value.__class__.__name__)

@staticmethod
def _parse_value(value: Any) -> Any:
    """Convert one parameter value into a plain, dict-friendly representation.

    * etna objects (``BaseMixin``) are expanded through their own ``to_dict``;
    * ``BaseEstimator`` instances become a dict with ``_target_`` plus everything
      returned by ``get_params()``;
    * strings and numbers are returned unchanged;
    * lists, tuples and dicts are converted element by element (dict keys are kept as is);
    * any other object collapses to a dict holding only ``_target_`` and a warning is emitted.
    """
    if isinstance(value, BaseMixin):
        return value.to_dict()

    if isinstance(value, BaseEstimator):
        # Estimator parameters are merged after ``_target_``, exactly as a dict update would do.
        return {"_target_": BaseMixin._get_target_from_class(value), **value.get_params()}

    if isinstance(value, (str, float, int)):
        return value

    if isinstance(value, List):
        return list(map(BaseMixin._parse_value, value))

    if isinstance(value, tuple):
        return tuple(map(BaseMixin._parse_value, value))

    if isinstance(value, Dict):
        return {name: BaseMixin._parse_value(entry) for name, entry in value.items()}

    # Unknown external object: only its class path can be recorded.
    fallback = {"_target_": BaseMixin._get_target_from_class(value)}
    warnings.warn("Some of external objects in input parameters could be not written in dict")
    return fallback

def to_dict(self):
    """Collect all information about etna object in dict.

    Each parameter named in the ``__init__`` signature is read from the instance ``__dict__``.
    Parameters whose stored value is ``None`` are left out; the others are converted with
    ``_parse_value``. The class path of the object is stored under the ``_target_`` key.
    """
    collected = {}
    for name in inspect.signature(self.__init__).parameters:
        attribute = self.__dict__[name]
        if attribute is not None:
            collected[name] = BaseMixin._parse_value(value=attribute)
    collected["_target_"] = BaseMixin._get_target_from_class(self)
    return collected


class StringEnumWithRepr(str, Enum):
"""Base class for str enums, that has alternative __repr__ method."""
Expand Down
Empty file added tests/test_core/__init__.py
Empty file.
137 changes: 137 additions & 0 deletions tests/test_core/test_to_dict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
import json
import pickle

import hydra_slayer
import pytest
from ruptures import Binseg
from sklearn.linear_model import LinearRegression

from etna.core import BaseMixin
from etna.ensembles import StackingEnsemble
from etna.ensembles import VotingEnsemble
from etna.metrics import MAE
from etna.metrics import SMAPE
from etna.models import AutoARIMAModel
from etna.models import CatBoostModelPerSegment
from etna.models import LinearPerSegmentModel
from etna.models.nn import DeepARModel
from etna.pipeline import Pipeline
from etna.transforms import AddConstTransform
from etna.transforms import ChangePointsTrendTransform
from etna.transforms import LambdaTransform
from etna.transforms import LogTransform


def ensemble_samples():
    """Build and return two fresh pipelines (horizon 5) used as ensemble members in the tests."""
    catboost_model = CatBoostModelPerSegment()
    catboost_transforms = [
        AddConstTransform(in_column="target", value=10),
        ChangePointsTrendTransform(
            in_column="target", change_point_model=Binseg(), detrend_model=LinearRegression(), n_bkps=50
        ),
    ]
    catboost_pipeline = Pipeline(model=catboost_model, transforms=catboost_transforms, horizon=5)

    linear_model = LinearPerSegmentModel()
    linear_transforms = [
        ChangePointsTrendTransform(
            in_column="target", change_point_model=Binseg(), detrend_model=LinearRegression(), n_bkps=50
        ),
        LogTransform(in_column="target"),
    ]
    linear_pipeline = Pipeline(model=linear_model, transforms=linear_transforms, horizon=5)

    return [catboost_pipeline, linear_pipeline]


@pytest.mark.parametrize(
    "target_object",
    [
        AddConstTransform(in_column="target", value=10),
        ChangePointsTrendTransform(
            in_column="target", change_point_model=Binseg(), detrend_model=LinearRegression(), n_bkps=50
        ),
        pytest.param(
            LambdaTransform(in_column="target", transform_func=lambda x: x - 2, inverse_transform_func=lambda x: x + 2),
            marks=pytest.mark.xfail(reason="some bug"),
        ),
    ],
)
def test_to_dict_transforms(target_object):
    """Check that a transform's ``to_dict`` output is JSON-stable and rebuilds an identical object."""
    serialized = target_object.to_dict()
    restored = hydra_slayer.get_from_params(**serialized)
    # The dict must survive a JSON round trip without any change.
    json_round_trip = json.loads(json.dumps(serialized))
    assert json_round_trip == serialized
    # The rebuilt object must pickle to the same bytes as the source object.
    assert pickle.dumps(restored) == pickle.dumps(target_object)


@pytest.mark.parametrize(
    "target_model",
    [
        pytest.param(DeepARModel(), marks=pytest.mark.xfail(reason="some bug")),
        LinearPerSegmentModel(),
        CatBoostModelPerSegment(),
        AutoARIMAModel(),
    ],
)
def test_to_dict_models(target_model):
    """Check that a model's ``to_dict`` output is JSON-stable and rebuilds an identical object."""
    serialized = target_model.to_dict()
    restored = hydra_slayer.get_from_params(**serialized)
    # The dict must survive a JSON round trip without any change.
    json_round_trip = json.loads(json.dumps(serialized))
    assert json_round_trip == serialized
    # The rebuilt model must pickle to the same bytes as the source model.
    assert pickle.dumps(restored) == pickle.dumps(target_model)


@pytest.mark.parametrize(
    "target_object",
    [
        Pipeline(
            model=CatBoostModelPerSegment(),
            transforms=[
                AddConstTransform(in_column="target", value=10),
                ChangePointsTrendTransform(
                    in_column="target", change_point_model=Binseg(), detrend_model=LinearRegression(), n_bkps=50
                ),
            ],
            horizon=5,
        )
    ],
)
def test_to_dict_pipeline(target_object):
    """Check that a pipeline's ``to_dict`` output is JSON-stable and rebuilds an identical object."""
    serialized = target_object.to_dict()
    restored = hydra_slayer.get_from_params(**serialized)
    # The dict must survive a JSON round trip without any change.
    json_round_trip = json.loads(json.dumps(serialized))
    assert json_round_trip == serialized
    # The rebuilt pipeline must pickle to the same bytes as the source pipeline.
    assert pickle.dumps(restored) == pickle.dumps(target_object)


@pytest.mark.parametrize("target_object", [MAE(mode="macro"), SMAPE()])
def test_to_dict_metrics(target_object):
    """Check that a metric's ``to_dict`` output is JSON-stable and rebuilds an identical object."""
    serialized = target_object.to_dict()
    restored = hydra_slayer.get_from_params(**serialized)
    # The dict must survive a JSON round trip without any change.
    json_round_trip = json.loads(json.dumps(serialized))
    assert json_round_trip == serialized
    # The rebuilt metric must pickle to the same bytes as the source metric.
    assert pickle.dumps(restored) == pickle.dumps(target_object)


@pytest.mark.parametrize(
    "target_ensemble",
    [
        VotingEnsemble(pipelines=ensemble_samples(), weights=[0.4, 0.6]),
        StackingEnsemble(pipelines=ensemble_samples()),
    ],
)
def test_ensembles(target_ensemble):
    """Check that an ensemble's ``to_dict`` output is JSON-stable and rebuilds an identical object."""
    serialized = target_ensemble.to_dict()
    restored = hydra_slayer.get_from_params(**serialized)
    # The dict must survive a JSON round trip without any change.
    json_round_trip = json.loads(json.dumps(serialized))
    assert json_round_trip == serialized
    # The rebuilt ensemble must pickle to the same bytes as the source ensemble.
    assert pickle.dumps(restored) == pickle.dumps(target_ensemble)


class _Dummy:
    """Empty placeholder class, used in the tests as a value ``to_dict`` has no dedicated handling for."""


class _InvalidParsing(BaseMixin):
    """``BaseMixin`` subclass whose only init parameter is a ``_Dummy`` instance."""

    def __init__(self, a: _Dummy):
        # Stored under the same name as the init parameter.
        self.a = a


def test_warnings():
    """Check that ``to_dict`` warns when a parameter holds an object it cannot fully describe."""
    expected_message = "Some of external objects in input parameters could be not written in dict"
    with pytest.warns(Warning, match=expected_message):
        _InvalidParsing(_Dummy()).to_dict()

1 comment on commit c0a21a7

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please sign in to comment.