diff --git a/CHANGELOG.md b/CHANGELOG.md index 28bf87c83..285f73d15 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,7 +13,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - ### Changed - Set the default value of `final_model` to `LinearRegression(positive=True)` in the constructor of `StackingEnsemble` ([#1238](https://github.com/tinkoff-ai/etna/pull/1238)) -- Add microseconds to `FileLogger`'s directory name ([#947](https://github.com/tinkoff-ai/etna/pull/947)) +- Add microseconds to `FileLogger`'s directory name ([#1264](https://github.com/tinkoff-ai/etna/pull/1264)) +- Inherit `SaveMixin` from `AbstractSaveable` for mypy checker ([#1261](https://github.com/tinkoff-ai/etna/pull/1261)) - - ### Fixed diff --git a/etna/core/mixins.py b/etna/core/mixins.py index 538fa2d4d..02d1d0b73 100644 --- a/etna/core/mixins.py +++ b/etna/core/mixins.py @@ -15,6 +15,9 @@ import hydra_slayer from sklearn.base import BaseEstimator +from typing_extensions import Self + +from etna.core.saving import AbstractSaveable class BaseMixin: @@ -163,7 +166,7 @@ def get_etna_version() -> Tuple[int, int, int]: return result -class SaveMixin: +class SaveMixin(AbstractSaveable): """Basic implementation of ``AbstractSaveable`` abstract class. It saves object to the zip archive with 2 files: @@ -223,12 +226,12 @@ def _validate_metadata(cls, metadata: Dict[str, Any]): ) @classmethod - def _load_state(cls, archive: zipfile.ZipFile) -> Any: + def _load_state(cls, archive: zipfile.ZipFile) -> Self: with archive.open("object.pkl", "r") as input_file: return pickle.load(input_file) @classmethod - def load(cls, path: pathlib.Path) -> Any: + def load(cls, path: pathlib.Path) -> Self: """Load an object. Parameters diff --git a/etna/core/saving.py b/etna/core/saving.py index 25c84ae53..f17d0f57c 100644 --- a/etna/core/saving.py +++ b/etna/core/saving.py @@ -1,7 +1,8 @@ import pathlib from abc import ABC from abc import abstractmethod -from typing import Any + +from typing_extensions import Self class AbstractSaveable(ABC): @@ -20,7 +21,7 @@ def save(self, path: pathlib.Path): @classmethod @abstractmethod - def load(cls, path: pathlib.Path) -> Any: + def load(cls, path: pathlib.Path) -> Self: """Load an object. Parameters diff --git a/etna/ensembles/mixins.py b/etna/ensembles/mixins.py index 81e8f02c9..00fb4c9e7 100644 --- a/etna/ensembles/mixins.py +++ b/etna/ensembles/mixins.py @@ -2,11 +2,11 @@ import tempfile import zipfile from copy import deepcopy -from typing import Any from typing import List from typing import Optional import pandas as pd +from typing_extensions import Self from etna.core import SaveMixin from etna.core import load @@ -113,7 +113,7 @@ def save(self, path: pathlib.Path): archive.write(pipeline_save_path, f"pipelines/{save_name}") @classmethod - def load(cls, path: pathlib.Path, ts: Optional[TSDataset] = None) -> Any: + def load(cls, path: pathlib.Path, ts: Optional[TSDataset] = None) -> Self: """Load an object. Parameters diff --git a/etna/experimental/classification/feature_extraction/weasel.py b/etna/experimental/classification/feature_extraction/weasel.py index 8a4df84fe..5a19bf3e8 100644 --- a/etna/experimental/classification/feature_extraction/weasel.py +++ b/etna/experimental/classification/feature_extraction/weasel.py @@ -213,6 +213,9 @@ def transform(self, x: List[np.ndarray]) -> np.ndarray: : Transformed input data. """ + if self._padding_expected_len is None: + raise ValueError("Transform is not fitted!") + n_samples = len(x) window_sizes, window_steps = self._check_params(self._min_series_len) for i in range(len(x)): diff --git a/etna/models/base.py b/etna/models/base.py index 601ad41fe..ffa9902a9 100644 --- a/etna/models/base.py +++ b/etna/models/base.py @@ -14,7 +14,6 @@ import pandas as pd from etna import SETTINGS -from etna.core import AbstractSaveable from etna.core import SaveMixin from etna.core.mixins import BaseMixin from etna.datasets.tsdataset import TSDataset @@ -36,7 +35,7 @@ SaveNNMixin = Mock # type: ignore -class AbstractModel(SaveMixin, AbstractSaveable, ABC, BaseMixin): +class AbstractModel(SaveMixin, ABC, BaseMixin): """Interface for model with fit method.""" @property diff --git a/etna/models/mixins.py b/etna/models/mixins.py index 09757d465..baf999479 100644 --- a/etna/models/mixins.py +++ b/etna/models/mixins.py @@ -11,6 +11,7 @@ import dill import numpy as np import pandas as pd +from typing_extensions import Self from etna.core.mixins import SaveMixin from etna.datasets.tsdataset import TSDataset @@ -645,7 +646,7 @@ def _save_state(self, archive: zipfile.ZipFile): torch.save(self, output_file, pickle_module=dill) @classmethod - def _load_state(cls, archive: zipfile.ZipFile) -> Any: + def _load_state(cls, archive: zipfile.ZipFile) -> Self: import torch with archive.open("object.pt", "r") as input_file: diff --git a/etna/pipeline/mixins.py b/etna/pipeline/mixins.py index a91d5497a..0a49fa88f 100644 --- a/etna/pipeline/mixins.py +++ b/etna/pipeline/mixins.py @@ -2,12 +2,12 @@ import tempfile import zipfile from copy import deepcopy -from typing import Any from typing import Optional from typing import Sequence import numpy as np import pandas as pd +from typing_extensions import Self from typing_extensions import get_args from etna.core import SaveMixin @@ -169,7 +169,7 @@ def save(self, path: pathlib.Path): archive.write(transform_save_path, f"transforms/{save_name}") @classmethod - def load(cls, path: pathlib.Path, ts: Optional[TSDataset] = None) -> Any: + def load(cls, path: pathlib.Path, ts: Optional[TSDataset] = None) -> Self: """Load an object. Parameters diff --git a/etna/transforms/base.py b/etna/transforms/base.py index bd2e03ee9..d56de52da 100644 --- a/etna/transforms/base.py +++ b/etna/transforms/base.py @@ -10,7 +10,6 @@ import pandas as pd from typing_extensions import Literal -from etna.core import AbstractSaveable from etna.core import BaseMixin from etna.core import SaveMixin from etna.datasets import TSDataset @@ -21,7 +20,7 @@ class FutureMixin: """Mixin for transforms that can convert non-regressor column to a regressor one.""" -class Transform(SaveMixin, AbstractSaveable, BaseMixin): +class Transform(SaveMixin, BaseMixin): """Base class to create any transforms to apply to data.""" def __init__(self, required_features: Union[Literal["all"], List[str]]): diff --git a/poetry.lock b/poetry.lock index f04e6480a..fd8685be6 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2942,44 +2942,48 @@ files = [ [[package]] name = "mypy" -version = "0.950" +version = "1.2.0" description = "Optional static typing for Python" category = "main" optional = true -python-versions = ">=3.6" +python-versions = ">=3.7" files = [ - {file = "mypy-0.950-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:cf9c261958a769a3bd38c3e133801ebcd284ffb734ea12d01457cb09eacf7d7b"}, - {file = "mypy-0.950-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b5b5bd0ffb11b4aba2bb6d31b8643902c48f990cc92fda4e21afac658044f0c0"}, - {file = "mypy-0.950-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:5e7647df0f8fc947388e6251d728189cfadb3b1e558407f93254e35abc026e22"}, - {file = "mypy-0.950-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:eaff8156016487c1af5ffa5304c3e3fd183edcb412f3e9c72db349faf3f6e0eb"}, - {file = "mypy-0.950-cp310-cp310-win_amd64.whl", hash = "sha256:563514c7dc504698fb66bb1cf897657a173a496406f1866afae73ab5b3cdb334"}, - {file = "mypy-0.950-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:dd4d670eee9610bf61c25c940e9ade2d0ed05eb44227275cce88701fee014b1f"}, - {file = "mypy-0.950-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:ca75ecf2783395ca3016a5e455cb322ba26b6d33b4b413fcdedfc632e67941dc"}, - {file = "mypy-0.950-cp36-cp36m-win_amd64.whl", hash = "sha256:6003de687c13196e8a1243a5e4bcce617d79b88f83ee6625437e335d89dfebe2"}, - {file = "mypy-0.950-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:4c653e4846f287051599ed8f4b3c044b80e540e88feec76b11044ddc5612ffed"}, - {file = "mypy-0.950-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:e19736af56947addedce4674c0971e5dceef1b5ec7d667fe86bcd2b07f8f9075"}, - {file = "mypy-0.950-cp37-cp37m-win_amd64.whl", hash = "sha256:ef7beb2a3582eb7a9f37beaf38a28acfd801988cde688760aea9e6cc4832b10b"}, - {file = "mypy-0.950-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:0112752a6ff07230f9ec2f71b0d3d4e088a910fdce454fdb6553e83ed0eced7d"}, - {file = "mypy-0.950-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:ee0a36edd332ed2c5208565ae6e3a7afc0eabb53f5327e281f2ef03a6bc7687a"}, - {file = "mypy-0.950-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:77423570c04aca807508a492037abbd72b12a1fb25a385847d191cd50b2c9605"}, - {file = "mypy-0.950-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:5ce6a09042b6da16d773d2110e44f169683d8cc8687e79ec6d1181a72cb028d2"}, - {file = "mypy-0.950-cp38-cp38-win_amd64.whl", hash = "sha256:5b231afd6a6e951381b9ef09a1223b1feabe13625388db48a8690f8daa9b71ff"}, - {file = "mypy-0.950-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:0384d9f3af49837baa92f559d3fa673e6d2652a16550a9ee07fc08c736f5e6f8"}, - {file = "mypy-0.950-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:1fdeb0a0f64f2a874a4c1f5271f06e40e1e9779bf55f9567f149466fc7a55038"}, - {file = "mypy-0.950-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:61504b9a5ae166ba5ecfed9e93357fd51aa693d3d434b582a925338a2ff57fd2"}, - {file = "mypy-0.950-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:a952b8bc0ae278fc6316e6384f67bb9a396eb30aced6ad034d3a76120ebcc519"}, - {file = "mypy-0.950-cp39-cp39-win_amd64.whl", hash = "sha256:eaea21d150fb26d7b4856766e7addcf929119dd19fc832b22e71d942835201ef"}, - {file = "mypy-0.950-py3-none-any.whl", hash = "sha256:a4d9898f46446bfb6405383b57b96737dcfd0a7f25b748e78ef3e8c576bba3cb"}, - {file = "mypy-0.950.tar.gz", hash = "sha256:1b333cfbca1762ff15808a0ef4f71b5d3eed8528b23ea1c3fb50543c867d68de"}, + {file = "mypy-1.2.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:701189408b460a2ff42b984e6bd45c3f41f0ac9f5f58b8873bbedc511900086d"}, + {file = "mypy-1.2.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:fe91be1c51c90e2afe6827601ca14353bbf3953f343c2129fa1e247d55fd95ba"}, + {file = "mypy-1.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8d26b513225ffd3eacece727f4387bdce6469192ef029ca9dd469940158bc89e"}, + {file = "mypy-1.2.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:3a2d219775a120581a0ae8ca392b31f238d452729adbcb6892fa89688cb8306a"}, + {file = "mypy-1.2.0-cp310-cp310-win_amd64.whl", hash = "sha256:2e93a8a553e0394b26c4ca683923b85a69f7ccdc0139e6acd1354cc884fe0128"}, + {file = "mypy-1.2.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:3efde4af6f2d3ccf58ae825495dbb8d74abd6d176ee686ce2ab19bd025273f41"}, + {file = "mypy-1.2.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:695c45cea7e8abb6f088a34a6034b1d273122e5530aeebb9c09626cea6dca4cb"}, + {file = "mypy-1.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d0e9464a0af6715852267bf29c9553e4555b61f5904a4fc538547a4d67617937"}, + {file = "mypy-1.2.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:8293a216e902ac12779eb7a08f2bc39ec6c878d7c6025aa59464e0c4c16f7eb9"}, + {file = "mypy-1.2.0-cp311-cp311-win_amd64.whl", hash = "sha256:f46af8d162f3d470d8ffc997aaf7a269996d205f9d746124a179d3abe05ac602"}, + {file = "mypy-1.2.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:031fc69c9a7e12bcc5660b74122ed84b3f1c505e762cc4296884096c6d8ee140"}, + {file = "mypy-1.2.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:390bc685ec209ada4e9d35068ac6988c60160b2b703072d2850457b62499e336"}, + {file = "mypy-1.2.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:4b41412df69ec06ab141808d12e0bf2823717b1c363bd77b4c0820feaa37249e"}, + {file = "mypy-1.2.0-cp37-cp37m-win_amd64.whl", hash = "sha256:4e4a682b3f2489d218751981639cffc4e281d548f9d517addfd5a2917ac78119"}, + {file = "mypy-1.2.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:a197ad3a774f8e74f21e428f0de7f60ad26a8d23437b69638aac2764d1e06a6a"}, + {file = "mypy-1.2.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:c9a084bce1061e55cdc0493a2ad890375af359c766b8ac311ac8120d3a472950"}, + {file = "mypy-1.2.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eaeaa0888b7f3ccb7bcd40b50497ca30923dba14f385bde4af78fac713d6d6f6"}, + {file = "mypy-1.2.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:bea55fc25b96c53affab852ad94bf111a3083bc1d8b0c76a61dd101d8a388cf5"}, + {file = "mypy-1.2.0-cp38-cp38-win_amd64.whl", hash = "sha256:4c8d8c6b80aa4a1689f2a179d31d86ae1367ea4a12855cc13aa3ba24bb36b2d8"}, + {file = "mypy-1.2.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:70894c5345bea98321a2fe84df35f43ee7bb0feec117a71420c60459fc3e1eed"}, + {file = "mypy-1.2.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:4a99fe1768925e4a139aace8f3fb66db3576ee1c30b9c0f70f744ead7e329c9f"}, + {file = "mypy-1.2.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:023fe9e618182ca6317ae89833ba422c411469156b690fde6a315ad10695a521"}, + {file = "mypy-1.2.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:4d19f1a239d59f10fdc31263d48b7937c585810288376671eaf75380b074f238"}, + {file = "mypy-1.2.0-cp39-cp39-win_amd64.whl", hash = "sha256:2de7babe398cb7a85ac7f1fd5c42f396c215ab3eff731b4d761d68d0f6a80f48"}, + {file = "mypy-1.2.0-py3-none-any.whl", hash = "sha256:d8e9187bfcd5ffedbe87403195e1fc340189a68463903c39e2b63307c9fa0394"}, + {file = "mypy-1.2.0.tar.gz", hash = "sha256:f70a40410d774ae23fcb4afbbeca652905a04de7948eaf0b1789c8d1426b72d1"}, ] [package.dependencies] -mypy-extensions = ">=0.4.3" +mypy-extensions = ">=1.0.0" tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} typing-extensions = ">=3.10" [package.extras] dmypy = ["psutil (>=4.0)"] +install-types = ["pip"] python2 = ["typed-ast (>=1.4.0,<2)"] reports = ["lxml"] @@ -6126,4 +6130,4 @@ wandb = ["wandb"] [metadata] lock-version = "2.0" python-versions = ">=3.8.0, <3.11.0" -content-hash = "8a416d6f5d0fa31ed68dd4b84d2a6d4f543f2ab8c67d33b31e91f288f3380063" +content-hash = "732a798394e2e7b6bdc5ad5524c6f171c240c69ea30fecc0244a16b7f33eaa78" diff --git a/pyproject.toml b/pyproject.toml index 0c6536072..a0893a8ea 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -100,7 +100,7 @@ pep8-naming = {version = "^0.12.1", optional = true} flake8-bugbear = {version = "^22.4.25", optional = true} flake8-comprehensions = {version = "^3.9.0", optional = true} flake8-docstrings = {version = "^1.6.0", optional = true} -mypy = {version = "^0.950", optional = true} +mypy = {version = ">=0.950,<2", optional = true} types-PyYAML = {version = "^6.0.0", optional = true} codespell = {version = "^2.0.0", optional = true} diff --git a/tests/test_experimental/test_classification/test_feature_extraction/test_weasel.py b/tests/test_experimental/test_classification/test_feature_extraction/test_weasel.py index 1bf432ede..c04e28f12 100644 --- a/tests/test_experimental/test_classification/test_feature_extraction/test_weasel.py +++ b/tests/test_experimental/test_classification/test_feature_extraction/test_weasel.py @@ -55,6 +55,13 @@ def test_windowed_view(many_time_series, window_size, window_step, expected, req np.testing.assert_array_equal(n_windows_per_sample_cum, n_windows_per_sample_cum_expected) +def test_not_fitted(many_time_series): + x, y = many_time_series + feature_extractor = WEASELFeatureExtractor(padding_value=0, window_sizes=[10, 15]) + with pytest.raises(ValueError, match="Transform is not fitted"): + _ = feature_extractor.transform(x) + + def test_preprocessor_and_classifier(many_time_series_big): x, y = many_time_series_big model = LogisticRegression()