diff --git a/CHANGELOG.md b/CHANGELOG.md index a207ff58b..c9e804f6b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -25,7 +25,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - - ### Fixed -- +- Fix inference tests on new segments for `DeepARModel` and `TFTModel` ([#1109](https://github.com/tinkoff-ai/etna/pull/1109)) - Fix `MeanSegmentEncoderTransform` to work with subset of segments and raise error on new segments ([#1104](https://github.com/tinkoff-ai/etna/pull/1104)) - - Fix `SegmentEncoderTransform` to work with subset of segments and raise error on new segments ([#1103](https://github.com/tinkoff-ai/etna/pull/1103)) diff --git a/tests/test_models/test_inference/test_forecast.py b/tests/test_models/test_inference/test_forecast.py index 85cf65b3e..7108927c5 100644 --- a/tests/test_models/test_inference/test_forecast.py +++ b/tests/test_models/test_inference/test_forecast.py @@ -5,6 +5,7 @@ import pytest from pandas.util.testing import assert_frame_equal from pytorch_forecasting.data import GroupNormalizer +from pytorch_forecasting.data import NaNLabelEncoder from typing_extensions import get_args from etna.datasets import TSDataset @@ -879,15 +880,6 @@ def _test_forecast_new_segments(self, ts, model, transforms, train_segments, pre MLPModel(input_size=2, hidden_size=[10], decoder_length=7, trainer_params=dict(max_epochs=1)), [LagTransform(in_column="target", lags=[5, 6])], ), - ], - ) - def test_forecast_new_segments(self, model, transforms, example_tsds): - self._test_forecast_new_segments(example_tsds, model, transforms, train_segments=["segment_1"]) - - @to_be_fixed(raises=KeyError, match="Unknown category") - @pytest.mark.parametrize( - "model, transforms", - [ ( DeepARModel(max_epochs=1, learning_rate=[0.01]), [ @@ -896,6 +888,7 @@ def test_forecast_new_segments(self, model, transforms, example_tsds): max_prediction_length=5, time_varying_known_reals=["time_idx"], time_varying_unknown_reals=["target"], + categorical_encoders={"segment": NaNLabelEncoder(add_nan=True, warn=False)}, target_normalizer=GroupNormalizer(groups=["segment"]), ) ], @@ -909,6 +902,7 @@ def test_forecast_new_segments(self, model, transforms, example_tsds): max_prediction_length=5, time_varying_known_reals=["time_idx"], time_varying_unknown_reals=["target"], + categorical_encoders={"segment": NaNLabelEncoder(add_nan=True, warn=False)}, static_categoricals=["segment"], target_normalizer=None, ) @@ -916,7 +910,7 @@ def test_forecast_new_segments(self, model, transforms, example_tsds): ), ], ) - def test_forecast_new_segments_failed_encoding_error(self, model, transforms, example_tsds): + def test_forecast_new_segments(self, model, transforms, example_tsds): self._test_forecast_new_segments(example_tsds, model, transforms, train_segments=["segment_1"]) @to_be_fixed(raises=NotImplementedError, match="Per-segment models can't make predictions on new segments") diff --git a/tests/test_models/test_inference/test_predict.py b/tests/test_models/test_inference/test_predict.py index e05f0e95c..ca2b9a0ac 100644 --- a/tests/test_models/test_inference/test_predict.py +++ b/tests/test_models/test_inference/test_predict.py @@ -5,6 +5,7 @@ import pytest from pandas.util.testing import assert_frame_equal from pytorch_forecasting.data import GroupNormalizer +from pytorch_forecasting.data import NaNLabelEncoder from etna.datasets import TSDataset from etna.models import AutoARIMAModel @@ -782,7 +783,7 @@ def _test_predict_new_segments(self, ts, model, transforms, train_segments, num_ def test_predict_new_segments(self, model, transforms, example_tsds): self._test_predict_new_segments(example_tsds, model, transforms, train_segments=["segment_1"]) - @to_be_fixed(raises=KeyError, match="Unknown category") + @to_be_fixed(raises=NotImplementedError, match="Method predict isn't currently implemented") @pytest.mark.parametrize( "model, transforms", [ @@ -794,6 +795,7 @@ def test_predict_new_segments(self, model, transforms, example_tsds): max_prediction_length=5, time_varying_known_reals=["time_idx"], time_varying_unknown_reals=["target"], + categorical_encoders={"segment": NaNLabelEncoder(add_nan=True, warn=False)}, target_normalizer=GroupNormalizer(groups=["segment"]), ) ], @@ -807,20 +809,12 @@ def test_predict_new_segments(self, model, transforms, example_tsds): max_prediction_length=5, time_varying_known_reals=["time_idx"], time_varying_unknown_reals=["target"], + categorical_encoders={"segment": NaNLabelEncoder(add_nan=True, warn=False)}, static_categoricals=["segment"], target_normalizer=None, ) ], ), - ], - ) - def test_predict_new_segments_failed_encoding_error(self, model, transforms, example_tsds): - self._test_predict_new_segments(example_tsds, model, transforms, train_segments=["segment_1"]) - - @to_be_fixed(raises=NotImplementedError, match="Method predict isn't currently implemented") - @pytest.mark.parametrize( - "model, transforms", - [ (RNNModel(input_size=1, encoder_length=7, decoder_length=7, trainer_params=dict(max_epochs=1)), []), ( MLPModel(input_size=2, hidden_size=[10], decoder_length=7, trainer_params=dict(max_epochs=1)),