Skip to content

Commit

Permalink
Add test for MADTransform x NaN
Browse files Browse the repository at this point in the history
  • Loading branch information
julia-shenshina committed Jan 12, 2022
1 parent 443cbe7 commit 686c37c
Showing 1 changed file with 26 additions and 9 deletions.
35 changes: 26 additions & 9 deletions tests/test_transforms/test_math/test_statistics_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pandas as pd
import pytest

from etna.datasets import TSDataset
from etna.transforms.math import MADTransform
from etna.transforms.math import MaxTransform
from etna.transforms.math import MeanTransform
Expand All @@ -19,11 +20,7 @@ def simple_df_for_agg() -> pd.DataFrame:
df = pd.DataFrame({"timestamp": pd.date_range("2020-01-01", periods=n)})
df["target"] = list(range(n))
df["segment"] = "segment_1"

df = df.pivot(index="timestamp", columns="segment")
df = df.reorder_levels([1, 0], axis=1)
df = df.sort_index(axis=1)
df.columns.names = ["segment", "feature"]
df = TSDataset.to_dataset(df)
return df


Expand All @@ -33,11 +30,17 @@ def df_for_agg() -> pd.DataFrame:
df = pd.DataFrame({"timestamp": pd.date_range("2020-01-01", periods=n)})
df["target"] = [-1, 1, 3, 2, 4, 9, 8, 5, 6, 0]
df["segment"] = "segment_1"
df = TSDataset.to_dataset(df)
return df


df = df.pivot(index="timestamp", columns="segment")
df = df.reorder_levels([1, 0], axis=1)
df = df.sort_index(axis=1)
df.columns.names = ["segment", "feature"]
@pytest.fixture
def df_for_agg_with_nan() -> pd.DataFrame:
n = 10
df = pd.DataFrame({"timestamp": pd.date_range("2020-01-01", periods=n)})
df["target"] = [-1, 1, 3, None, 4, 9, 8, 5, 6, 0]
df["segment"] = "segment_1"
df = TSDataset.to_dataset(df)
return df


Expand Down Expand Up @@ -197,6 +200,20 @@ def test_mad_transform(df_for_agg: pd.DataFrame, window: int, periods: int, fill
np.testing.assert_array_almost_equal(expected, res["segment_1"]["result"])


@pytest.mark.parametrize(
"window,periods,fill_na,expected",
((3, 3, -17, [-17, -17, -17, 4 / 3, -17, -17, -17, 2, 14 / 9, 10 / 9]),),
)
def test_mad_transform_with_nans(
df_for_agg_with_nan: pd.DataFrame, window: int, periods: int, fill_na: float, expected: np.ndarray
):
transform = MADTransform(
window=window, min_periods=periods, fillna=fill_na, in_column="target", out_column="result"
)
res = transform.fit_transform(df_for_agg_with_nan)
np.testing.assert_array_almost_equal(expected, res["segment_1"]["result"])


@pytest.mark.parametrize(
"transform",
(
Expand Down

0 comments on commit 686c37c

Please sign in to comment.