Skip to content

Enabling passing context into Model.forecast v2 #888

Merged
merged 17 commits into from
Sep 2, 2022
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Changed
-
-
-
- Changed hierarchy of base models, enable passing context into models ([#888](https://github.com/tinkoff-ai/etna/pull/888))
-
-
- Teach AutoARIMAModel to work with out-sample predictions ([#830](https://github.com/tinkoff-ai/etna/pull/830))
Expand Down
10 changes: 7 additions & 3 deletions etna/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
from etna import SETTINGS
from etna.models.autoarima import AutoARIMAModel
from etna.models.base import BaseAdapter
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't we want to add also AbstractModels?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, let's do it.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about MultiSegmentModelMixin and the other abstract models?

from etna.models.base import BaseModel
from etna.models.base import Model
from etna.models.base import PerSegmentModel
from etna.models.base import ContextIgnorantModelType
from etna.models.base import ContextRequiredModelType
from etna.models.base import ModelType
from etna.models.base import NonPredictionIntervalContextIgnorantAbstractModel
from etna.models.base import NonPredictionIntervalContextRequiredAbstractModel
from etna.models.base import PredictionIntervalContextIgnorantAbstractModel
from etna.models.base import PredictionIntervalContextRequiredAbstractModel
from etna.models.catboost import CatBoostModelMultiSegment
from etna.models.catboost import CatBoostModelPerSegment
from etna.models.catboost import CatBoostMultiSegmentModel
Expand Down
10 changes: 6 additions & 4 deletions etna/models/autoarima.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
from statsmodels.tools.sm_exceptions import ValueWarning
from statsmodels.tsa.statespace.sarimax import SARIMAXResultsWrapper

from etna.models.base import PerSegmentPredictionIntervalModel
from etna.models.base import PerSegmentModelMixin
from etna.models.base import PredictionIntervalContextIgnorantAbstractModel
from etna.models.base import PredictionIntervalContextIgnorantModelMixin
from etna.models.sarimax import _SARIMAXBaseAdapter

warnings.filterwarnings(
Expand Down Expand Up @@ -48,7 +50,9 @@ def _get_fit_results(self, endog: pd.Series, exog: pd.DataFrame) -> SARIMAXResul
return model.arima_res_


class AutoARIMAModel(PerSegmentPredictionIntervalModel):
class AutoARIMAModel(
PerSegmentModelMixin, PredictionIntervalContextIgnorantModelMixin, PredictionIntervalContextIgnorantAbstractModel
):
"""
Class for holding auto arima model.

Expand All @@ -57,8 +61,6 @@ class AutoARIMAModel(PerSegmentPredictionIntervalModel):
We use :py:class:`pmdarima.arima.arima.ARIMA`.
"""

context_size = 0

def __init__(
self,
**kwargs,
Expand Down
Loading