diff --git a/CHANGELOG.md b/CHANGELOG.md index 0ecf93714..48f62f687 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -35,7 +35,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - - ### Fixed -- +- Type hints for `Pipeline.model` match `models.nn`([#768](https://github.com/tinkoff-ai/etna/pull/840)) - - - Fix behavior of SARIMAXModel if simple_differencing=True is set ([#837](https://github.com/tinkoff-ai/etna/pull/837)) diff --git a/etna/models/base.py b/etna/models/base.py index 6a87ca624..33e631181 100644 --- a/etna/models/base.py +++ b/etna/models/base.py @@ -820,4 +820,31 @@ def get_model(self) -> "DeepBaseNet": return self.net -BaseModel = Union[PerSegmentModel, PerSegmentPredictionIntervalModel, MultiSegmentModel, DeepBaseModel] +class MultiSegmentPredictionIntervalModel(FitAbstractModel, PredictIntervalAbstractModel, BaseMixin): + """Class for holding specific models for multi-segment prediction which are able to build prediction intervals.""" + + def __init__(self): + """Init MultiSegmentPredictionIntervalModel.""" + self.model = None + + def get_model(self) -> Any: + """Get internal model that is used inside etna class. + + Internal model is a model that is used inside etna to forecast segments, + e.g. :py:class:`catboost.CatBoostRegressor` or :py:class:`sklearn.linear_model.Ridge`. + + Returns + ------- + : + Internal model + """ + return self.model + + +BaseModel = Union[ + PerSegmentModel, + PerSegmentPredictionIntervalModel, + MultiSegmentModel, + DeepBaseModel, + MultiSegmentPredictionIntervalModel, +] diff --git a/etna/models/nn/deepar.py b/etna/models/nn/deepar.py index 0f04380d7..74cb6698e 100644 --- a/etna/models/nn/deepar.py +++ b/etna/models/nn/deepar.py @@ -10,8 +10,7 @@ from etna import SETTINGS from etna.datasets.tsdataset import TSDataset from etna.loggers import tslogger -from etna.models.base import Model -from etna.models.base import PredictIntervalAbstractModel +from etna.models.base import MultiSegmentPredictionIntervalModel from etna.models.base import log_decorator from etna.models.nn.utils import _DeepCopyMixin from etna.transforms import PytorchForecastingTransform @@ -25,7 +24,7 @@ from pytorch_lightning import LightningModule -class DeepARModel(Model, PredictIntervalAbstractModel, _DeepCopyMixin): +class DeepARModel(MultiSegmentPredictionIntervalModel, _DeepCopyMixin): """Wrapper for :py:class:`pytorch_forecasting.models.deepar.DeepAR`. Notes @@ -84,6 +83,7 @@ def __init__( quantiles_kwargs: Additional arguments for computing quantiles, look at ``to_quantiles()`` method for your loss. """ + super().__init__() if loss is None: loss = NormalDistributionLoss() self.max_epochs = max_epochs diff --git a/etna/models/nn/tft.py b/etna/models/nn/tft.py index b3ff52b09..427ad85f3 100644 --- a/etna/models/nn/tft.py +++ b/etna/models/nn/tft.py @@ -11,8 +11,7 @@ from etna import SETTINGS from etna.datasets.tsdataset import TSDataset from etna.loggers import tslogger -from etna.models.base import Model -from etna.models.base import PredictIntervalAbstractModel +from etna.models.base import MultiSegmentPredictionIntervalModel from etna.models.base import log_decorator from etna.models.nn.utils import _DeepCopyMixin from etna.transforms import PytorchForecastingTransform @@ -26,7 +25,7 @@ from pytorch_lightning import LightningModule -class TFTModel(Model, PredictIntervalAbstractModel, _DeepCopyMixin): +class TFTModel(MultiSegmentPredictionIntervalModel, _DeepCopyMixin): """Wrapper for :py:class:`pytorch_forecasting.models.temporal_fusion_transformer.TemporalFusionTransformer`. Notes @@ -89,6 +88,7 @@ def __init__( quantiles_kwargs: Additional arguments for computing quantiles, look at ``to_quantiles()`` method for your loss. """ + super().__init__() if loss is None: loss = QuantileLoss() self.max_epochs = max_epochs