Skip to content

Enabling passing context into Model.forecast v2 #888

Merged
merged 17 commits into from
Sep 2, 2022
2 changes: 1 addition & 1 deletion etna/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from etna.models.base import BaseAdapter
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't we want to add also AbstractModels?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, let's do it.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about MultiSegmentModelMixin and the other abstract models?

from etna.models.base import BaseModel
from etna.models.base import Model
from etna.models.base import PerSegmentModel
from etna.models.base import PerSegmentModelMixin
from etna.models.catboost import CatBoostModelMultiSegment
from etna.models.catboost import CatBoostModelPerSegment
from etna.models.catboost import CatBoostMultiSegmentModel
Expand Down
10 changes: 6 additions & 4 deletions etna/models/autoarima.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
from statsmodels.tools.sm_exceptions import ValueWarning
from statsmodels.tsa.statespace.sarimax import SARIMAXResultsWrapper

from etna.models.base import PerSegmentPredictionIntervalModel
from etna.models.base import PerSegmentModelMixin
from etna.models.base import PredictionIntervalContextIgnorantAbstractModel
from etna.models.base import PredictionIntervalContextIgnorantModelMixin
from etna.models.sarimax import _SARIMAXBaseAdapter

warnings.filterwarnings(
Expand Down Expand Up @@ -48,7 +50,9 @@ def _get_fit_results(self, endog: pd.Series, exog: pd.DataFrame) -> SARIMAXResul
return model.arima_res_


class AutoARIMAModel(PerSegmentPredictionIntervalModel):
class AutoARIMAModel(
PerSegmentModelMixin, PredictionIntervalContextIgnorantModelMixin, PredictionIntervalContextIgnorantAbstractModel
):
"""
Class for holding auto arima model.

Expand All @@ -57,8 +61,6 @@ class AutoARIMAModel(PerSegmentPredictionIntervalModel):
We use :py:class:`pmdarima.arima.arima.ARIMA`.
"""

context_size = 0

def __init__(
self,
**kwargs,
Expand Down
Loading