diff --git a/CHANGELOG.md b/CHANGELOG.md index 4bd1b0b90..38d07f07f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -32,7 +32,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - - Make native prediction intervals for DeepAR ([#761](https://github.com/tinkoff-ai/etna/pull/761)) - Make native prediction intervals for TFTModel ([#770](https://github.com/tinkoff-ai/etna/pull/770)) -- +- Test cases for testing inference of models ([#794](https://github.com/tinkoff-ai/etna/pull/794)) - ### Fixed - diff --git a/tests/test_models/test_inference.py b/tests/test_models/test_inference.py index 60e2e6240..9d9285ea3 100644 --- a/tests/test_models/test_inference.py +++ b/tests/test_models/test_inference.py @@ -116,10 +116,10 @@ def _test_forecast_out_sample_suffix(ts, model, transforms): (HoltWintersModel(), []), (SimpleExpSmoothingModel(), []), (MovingAverageModel(window=3), []), - (NaiveModel(), []), + (NaiveModel(lag=3), []), (SeasonalMovingAverageModel(), []), - (BATSModel(), []), - (TBATSModel(), []), + (BATSModel(use_trend=True), []), + (TBATSModel(use_trend=True), []), ], ) def test_forecast_in_sample_full(model, transforms, example_tsds): @@ -183,10 +183,10 @@ def test_forecast_in_sample_full_failed(model, transforms, example_tsds): (HoltWintersModel(), []), (SimpleExpSmoothingModel(), []), (MovingAverageModel(window=3), []), - (NaiveModel(), []), + (NaiveModel(lag=3), []), (SeasonalMovingAverageModel(), []), - (BATSModel(), []), - (TBATSModel(), []), + (BATSModel(use_trend=True), []), + (TBATSModel(use_trend=True), []), ], ) def test_forecast_in_sample_suffix(model, transforms, example_tsds): @@ -246,9 +246,9 @@ def test_forecast_in_sample_suffix_failed(model, transforms, example_tsds): (SimpleExpSmoothingModel(), []), (MovingAverageModel(window=3), []), (SeasonalMovingAverageModel(), []), - (NaiveModel(), []), - (BATSModel(), []), - (TBATSModel(), []), + (NaiveModel(lag=3), []), + (BATSModel(use_trend=True), []), + (TBATSModel(use_trend=True), []), ], ) def test_forecast_out_sample_prefix(model, transforms, example_tsds): @@ -305,9 +305,6 @@ def test_forecast_out_sample_prefix_failed(model, transforms, example_tsds): (HoltModel(), []), (HoltWintersModel(), []), (SimpleExpSmoothingModel(), []), - (NaiveModel(), []), - (BATSModel(), []), - (TBATSModel(), []), ( TFTModel(max_epochs=1, learning_rate=[0.01]), [ @@ -335,6 +332,9 @@ def test_forecast_out_sample_suffix(model, transforms, example_tsds): (AutoARIMAModel(), []), (MovingAverageModel(window=3), []), (SeasonalMovingAverageModel(), []), + (NaiveModel(lag=3), []), + (BATSModel(use_trend=True), []), + (TBATSModel(use_trend=True), []), ( DeepARModel(max_epochs=5, learning_rate=[0.01]), [