
Update mypy and make AbstractSaveable.load return Self type #1261

Merged · 8 commits · May 23, 2023
3 changes: 2 additions & 1 deletion CHANGELOG.md
@@ -13,7 +13,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
-
### Changed
- Set the default value of `final_model` to `LinearRegression(positive=True)` in the constructor of `StackingEnsemble` ([#1238](https://github.com/tinkoff-ai/etna/pull/1238))
-- Add microseconds to `FileLogger`'s directory name ([#947](https://github.com/tinkoff-ai/etna/pull/947))
+- Add microseconds to `FileLogger`'s directory name ([#1264](https://github.com/tinkoff-ai/etna/pull/1264))
+- Inherit `SaveMixin` from `AbstractSaveable` for mypy checker ([#1261](https://github.com/tinkoff-ai/etna/pull/1261))
-
-
### Fixed
9 changes: 6 additions & 3 deletions etna/core/mixins.py
@@ -15,6 +15,9 @@

import hydra_slayer
from sklearn.base import BaseEstimator
+from typing_extensions import Self
+
+from etna.core.saving import AbstractSaveable


class BaseMixin:
@@ -163,7 +166,7 @@ def get_etna_version() -> Tuple[int, int, int]:
return result


-class SaveMixin:
+class SaveMixin(AbstractSaveable):
PR author comment: "It won't work without this somehow."

"""Basic implementation of ``AbstractSaveable`` abstract class.

It saves object to the zip archive with 2 files:
@@ -223,12 +226,12 @@ def _validate_metadata(cls, metadata: Dict[str, Any]):
)

@classmethod
-def _load_state(cls, archive: zipfile.ZipFile) -> Any:
+def _load_state(cls, archive: zipfile.ZipFile) -> Self:
with archive.open("object.pkl", "r") as input_file:
return pickle.load(input_file)

@classmethod
-def load(cls, path: pathlib.Path) -> Any:
+def load(cls, path: pathlib.Path) -> Self:
"""Load an object.

Parameters
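For context on the comment above, here is a minimal sketch (class names mirror the etna ones, bodies elided) of the structure this hunk creates: once `SaveMixin` subclasses `AbstractSaveable`, its `Self`-typed `load` acts as the override of the abstract method, and downstream classes can drop the explicit `AbstractSaveable` base, as the `AbstractModel` and `Transform` hunks below do.

```python
import pathlib
from abc import ABC
from abc import abstractmethod

from typing_extensions import Self


class AbstractSaveable(ABC):
    """Abstract contract: save to a path, load back the same (sub)class."""

    @abstractmethod
    def save(self, path: pathlib.Path):
        pass

    @classmethod
    @abstractmethod
    def load(cls, path: pathlib.Path) -> Self:
        pass


class SaveMixin(AbstractSaveable):
    """Concrete zip/pickle logic lives here in etna; elided in this sketch."""

    def save(self, path: pathlib.Path):
        raise NotImplementedError("sketch only")

    @classmethod
    def load(cls, path: pathlib.Path) -> Self:
        raise NotImplementedError("sketch only")


# Before this PR: class AbstractModel(SaveMixin, AbstractSaveable, ABC, BaseMixin)
# After: the explicit AbstractSaveable base is redundant, because the mixin
# already carries (and satisfies) the abstract contract.
class AbstractModel(SaveMixin, ABC):
    pass
```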
5 changes: 3 additions & 2 deletions etna/core/saving.py
@@ -1,7 +1,8 @@
import pathlib
from abc import ABC
from abc import abstractmethod
-from typing import Any
+
+from typing_extensions import Self


class AbstractSaveable(ABC):
@@ -20,7 +21,7 @@ def save(self, path: pathlib.Path):

@classmethod
@abstractmethod
-def load(cls, path: pathlib.Path) -> Any:
+def load(cls, path: pathlib.Path) -> Self:
"""Load an object.

Parameters
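A standalone sketch of the practical effect of the `Self` annotation (this is not the etna implementation: `ToySaveable`, `ToyModel`, and the plain-pickle format are made up for illustration):

```python
import pathlib
import pickle
import tempfile

from typing_extensions import Self


class ToySaveable:
    """Stand-in for SaveMixin: pickle the whole object to a file."""

    def save(self, path: pathlib.Path):
        with path.open("wb") as output_file:
            pickle.dump(self, output_file)

    @classmethod
    def load(cls, path: pathlib.Path) -> Self:
        with path.open("rb") as input_file:
            return pickle.load(input_file)


class ToyModel(ToySaveable):
    def horizon(self) -> int:
        return 7


with tempfile.TemporaryDirectory() as tmp_dir:
    model_path = pathlib.Path(tmp_dir) / "model.pkl"
    ToyModel().save(model_path)
    # With ``-> Self`` mypy infers ``ToyModel`` for ``model`` (the previous
    # ``-> Any`` annotation left it unchecked), so the call below is verified.
    model = ToyModel.load(model_path)
    print(model.horizon())
```

The same inference applies to the `load` and `_load_state` overrides in the ensemble, pipeline, and neural-network mixins changed below.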
4 changes: 2 additions & 2 deletions etna/ensembles/mixins.py
@@ -2,11 +2,11 @@
import tempfile
import zipfile
from copy import deepcopy
-from typing import Any
from typing import List
from typing import Optional

import pandas as pd
+from typing_extensions import Self

from etna.core import SaveMixin
from etna.core import load
@@ -113,7 +113,7 @@ def save(self, path: pathlib.Path):
archive.write(pipeline_save_path, f"pipelines/{save_name}")

@classmethod
-def load(cls, path: pathlib.Path, ts: Optional[TSDataset] = None) -> Any:
+def load(cls, path: pathlib.Path, ts: Optional[TSDataset] = None) -> Self:
"""Load an object.

Parameters
3 changes: 3 additions & 0 deletions etna/experimental/classification/feature_extraction/weasel.py
@@ -213,6 +213,9 @@ def transform(self, x: List[np.ndarray]) -> np.ndarray:
:
Transformed input data.
"""
+if self._padding_expected_len is None:
+    raise ValueError("Transform is not fitted!")
+
n_samples = len(x)
window_sizes, window_steps = self._check_params(self._min_series_len)
for i in range(len(x)):
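The added guard follows the usual fail-fast pattern for unfitted transforms: refuse to transform until `fit` has populated the learned state. A generic sketch of the idea (the padding logic and class name here are illustrative, not the WEASEL internals):

```python
from typing import List
from typing import Optional

import numpy as np


class PaddingFeatureExtractor:
    """Toy transform that pads series to the maximum length seen during fit."""

    def __init__(self, padding_value: float = 0.0):
        self.padding_value = padding_value
        self._padding_expected_len: Optional[int] = None

    def fit(self, x: List[np.ndarray]) -> "PaddingFeatureExtractor":
        self._padding_expected_len = max(len(series) for series in x)
        return self

    def transform(self, x: List[np.ndarray]) -> np.ndarray:
        # Same guard as in the hunk above: fail fast with a clear message.
        if self._padding_expected_len is None:
            raise ValueError("Transform is not fitted!")
        padded = [
            np.pad(
                series[: self._padding_expected_len],
                (0, max(0, self._padding_expected_len - len(series))),
                constant_values=self.padding_value,
            )
            for series in x
        ]
        return np.stack(padded)


extractor = PaddingFeatureExtractor()
extractor.fit([np.array([1.0, 2.0, 3.0]), np.array([4.0, 5.0])])
print(extractor.transform([np.array([6.0])]))  # [[6. 0. 0.]]
```

The new `test_not_fitted` test at the bottom of this PR exercises exactly this branch on the real `WEASELFeatureExtractor`.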
3 changes: 1 addition & 2 deletions etna/models/base.py
@@ -14,7 +14,6 @@
import pandas as pd

from etna import SETTINGS
-from etna.core import AbstractSaveable
from etna.core import SaveMixin
from etna.core.mixins import BaseMixin
from etna.datasets.tsdataset import TSDataset
@@ -36,7 +35,7 @@
SaveNNMixin = Mock # type: ignore


-class AbstractModel(SaveMixin, AbstractSaveable, ABC, BaseMixin):
+class AbstractModel(SaveMixin, ABC, BaseMixin):
"""Interface for model with fit method."""

@property
3 changes: 2 additions & 1 deletion etna/models/mixins.py
@@ -11,6 +11,7 @@
import dill
import numpy as np
import pandas as pd
+from typing_extensions import Self

from etna.core.mixins import SaveMixin
from etna.datasets.tsdataset import TSDataset
@@ -645,7 +646,7 @@ def _save_state(self, archive: zipfile.ZipFile):
torch.save(self, output_file, pickle_module=dill)

@classmethod
-def _load_state(cls, archive: zipfile.ZipFile) -> Any:
+def _load_state(cls, archive: zipfile.ZipFile) -> Self:
import torch

with archive.open("object.pt", "r") as input_file:
4 changes: 2 additions & 2 deletions etna/pipeline/mixins.py
@@ -2,12 +2,12 @@
import tempfile
import zipfile
from copy import deepcopy
-from typing import Any
from typing import Optional
from typing import Sequence

import numpy as np
import pandas as pd
+from typing_extensions import Self
from typing_extensions import get_args

from etna.core import SaveMixin
@@ -169,7 +169,7 @@ def save(self, path: pathlib.Path):
archive.write(transform_save_path, f"transforms/{save_name}")

@classmethod
-def load(cls, path: pathlib.Path, ts: Optional[TSDataset] = None) -> Any:
+def load(cls, path: pathlib.Path, ts: Optional[TSDataset] = None) -> Self:
"""Load an object.

Parameters
3 changes: 1 addition & 2 deletions etna/transforms/base.py
@@ -10,7 +10,6 @@
import pandas as pd
from typing_extensions import Literal

-from etna.core import AbstractSaveable
from etna.core import BaseMixin
from etna.core import SaveMixin
from etna.datasets import TSDataset
@@ -21,7 +20,7 @@ class FutureMixin:
"""Mixin for transforms that can convert non-regressor column to a regressor one."""


-class Transform(SaveMixin, AbstractSaveable, BaseMixin):
+class Transform(SaveMixin, BaseMixin):
"""Base class to create any transforms to apply to data."""

def __init__(self, required_features: Union[Literal["all"], List[str]]):
58 changes: 31 additions & 27 deletions poetry.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -100,7 +100,7 @@ pep8-naming = {version = "^0.12.1", optional = true}
flake8-bugbear = {version = "^22.4.25", optional = true}
flake8-comprehensions = {version = "^3.9.0", optional = true}
flake8-docstrings = {version = "^1.6.0", optional = true}
mypy = {version = "^0.950", optional = true}
mypy = {version = ">=0.950,<2", optional = true}
types-PyYAML = {version = "^6.0.0", optional = true}
codespell = {version = "^2.0.0", optional = true}

@@ -55,6 +55,13 @@ def test_windowed_view(many_time_series, window_size, window_step, expected, req
np.testing.assert_array_equal(n_windows_per_sample_cum, n_windows_per_sample_cum_expected)


+def test_not_fitted(many_time_series):
+    x, y = many_time_series
+    feature_extractor = WEASELFeatureExtractor(padding_value=0, window_sizes=[10, 15])
+    with pytest.raises(ValueError, match="Transform is not fitted"):
+        _ = feature_extractor.transform(x)


def test_preprocessor_and_classifier(many_time_series_big):
x, y = many_time_series_big
model = LogisticRegression()