From 686c37c79020ec303c736a7f662139dc367135ba Mon Sep 17 00:00:00 2001 From: Julia Shenshina Date: Wed, 12 Jan 2022 16:27:43 +0300 Subject: [PATCH] Add test for MADTransform x NaN --- .../test_math/test_statistics_transform.py | 35 ++++++++++++++----- 1 file changed, 26 insertions(+), 9 deletions(-) diff --git a/tests/test_transforms/test_math/test_statistics_transform.py b/tests/test_transforms/test_math/test_statistics_transform.py index d3cf13009..31997d294 100644 --- a/tests/test_transforms/test_math/test_statistics_transform.py +++ b/tests/test_transforms/test_math/test_statistics_transform.py @@ -4,6 +4,7 @@ import pandas as pd import pytest +from etna.datasets import TSDataset from etna.transforms.math import MADTransform from etna.transforms.math import MaxTransform from etna.transforms.math import MeanTransform @@ -19,11 +20,7 @@ def simple_df_for_agg() -> pd.DataFrame: df = pd.DataFrame({"timestamp": pd.date_range("2020-01-01", periods=n)}) df["target"] = list(range(n)) df["segment"] = "segment_1" - - df = df.pivot(index="timestamp", columns="segment") - df = df.reorder_levels([1, 0], axis=1) - df = df.sort_index(axis=1) - df.columns.names = ["segment", "feature"] + df = TSDataset.to_dataset(df) return df @@ -33,11 +30,17 @@ def df_for_agg() -> pd.DataFrame: df = pd.DataFrame({"timestamp": pd.date_range("2020-01-01", periods=n)}) df["target"] = [-1, 1, 3, 2, 4, 9, 8, 5, 6, 0] df["segment"] = "segment_1" + df = TSDataset.to_dataset(df) + return df + - df = df.pivot(index="timestamp", columns="segment") - df = df.reorder_levels([1, 0], axis=1) - df = df.sort_index(axis=1) - df.columns.names = ["segment", "feature"] +@pytest.fixture +def df_for_agg_with_nan() -> pd.DataFrame: + n = 10 + df = pd.DataFrame({"timestamp": pd.date_range("2020-01-01", periods=n)}) + df["target"] = [-1, 1, 3, None, 4, 9, 8, 5, 6, 0] + df["segment"] = "segment_1" + df = TSDataset.to_dataset(df) return df @@ -197,6 +200,20 @@ def test_mad_transform(df_for_agg: pd.DataFrame, window: int, periods: int, fill np.testing.assert_array_almost_equal(expected, res["segment_1"]["result"]) +@pytest.mark.parametrize( + "window,periods,fill_na,expected", + ((3, 3, -17, [-17, -17, -17, 4 / 3, -17, -17, -17, 2, 14 / 9, 10 / 9]),), +) +def test_mad_transform_with_nans( + df_for_agg_with_nan: pd.DataFrame, window: int, periods: int, fill_na: float, expected: np.ndarray +): + transform = MADTransform( + window=window, min_periods=periods, fillna=fill_na, in_column="target", out_column="result" + ) + res = transform.fit_transform(df_for_agg_with_nan) + np.testing.assert_array_almost_equal(expected, res["segment_1"]["result"]) + + @pytest.mark.parametrize( "transform", (