Skip to content

Commit

Permalink
Fix MeanSegmentEncoderTransform to pass inference tests (#1104)
Browse files Browse the repository at this point in the history
  • Loading branch information
Mr-Geekman authored Feb 8, 2023
1 parent 21cdb02 commit 27023dd
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 8 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
-
### Fixed
-
-
- Fix `MeanSegmentEncoderTransform` to work with subset of segments and raise error on new segments ([#1104](https://github.com/tinkoff-ai/etna/pull/1104))
-
- Fix `SegmentEncoderTransform` to work with subset of segments and raise error on new segments ([#1103](https://github.com/tinkoff-ai/etna/pull/1103))
-
Expand Down
32 changes: 27 additions & 5 deletions etna/transforms/encoders/mean_segment_encoder.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import numpy as np
import reprlib
from typing import Dict
from typing import Optional

import pandas as pd

from etna.transforms import Transform
Expand All @@ -13,7 +16,7 @@ class MeanSegmentEncoderTransform(Transform, FutureMixin):

def __init__(self):
self.mean_encoder = MeanTransform(in_column="target", window=-1, out_column="segment_mean")
self.global_means: np.ndarray[float] = None
self.global_means: Optional[Dict[str, float]] = None

def fit(self, df: pd.DataFrame) -> "MeanSegmentEncoderTransform":
"""
Expand All @@ -30,7 +33,9 @@ def fit(self, df: pd.DataFrame) -> "MeanSegmentEncoderTransform":
Fitted transform
"""
self.mean_encoder.fit(df)
self.global_means = df.loc[:, self.idx[:, "target"]].mean().values
mean_values = df.loc[:, self.idx[:, "target"]].mean().to_dict()
mean_values = {key[0]: value for key, value in mean_values.items()}
self.global_means = mean_values
return self

def transform(self, df: pd.DataFrame) -> pd.DataFrame:
Expand All @@ -46,9 +51,26 @@ def transform(self, df: pd.DataFrame) -> pd.DataFrame:
-------
:
result dataframe
Raises
------
ValueError:
If transform isn't fitted.
ValueError:
If there are segments that weren't present during training.
"""
if self.global_means is None:
raise ValueError("The transform isn't fitted!")

segments = df.columns.get_level_values("segment").unique().tolist()
new_segments = set(segments) - self.global_means.keys()
if len(new_segments) > 0:
raise ValueError(
f"This transform can't process segments that weren't present on train data: {reprlib.repr(new_segments)}"
)

df = self.mean_encoder.transform(df)
segment = df.columns.get_level_values("segment").unique()[0]
segment = segments[0]
nan_timestamps = df[df.loc[:, self.idx[segment, "target"]].isna()].index
df.loc[nan_timestamps, self.idx[:, "segment_mean"]] = self.global_means
df.loc[nan_timestamps, self.idx[:, "segment_mean"]] = [self.global_means[x] for x in segments]
return df
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@
from tests.test_transforms.utils import assert_transformation_equals_loaded_original


@pytest.mark.parametrize("expected_global_means", ([[3, 30]]))
@pytest.mark.parametrize("expected_global_means", [{"Moscow": 3, "Omsk": 30}])
def test_mean_segment_encoder_fit(simple_df, expected_global_means):
encoder = MeanSegmentEncoderTransform()
encoder.fit(simple_df)
assert (encoder.global_means == expected_global_means).all()
assert encoder.global_means == expected_global_means


def test_mean_segment_encoder_transform(simple_df, transformed_simple_df):
Expand All @@ -22,6 +22,36 @@ def test_mean_segment_encoder_transform(simple_df, transformed_simple_df):
pd.testing.assert_frame_equal(transformed_df, transformed_simple_df)


def test_subset_segments(simple_df):
train_df = simple_df
test_df = simple_df.loc[:, pd.IndexSlice["Omsk", :]]
transform = MeanSegmentEncoderTransform()

transform.fit(train_df)
transformed_test_df = transform.transform(test_df)

segments = sorted(transformed_test_df.columns.get_level_values("segment").unique())
features = sorted(transformed_test_df.columns.get_level_values("feature").unique())
assert segments == ["Omsk"]
assert features == ["exog", "segment_mean", "target"]


def test_not_fitted_error(simple_df):
encoder = MeanSegmentEncoderTransform()
with pytest.raises(ValueError, match="The transform isn't fitted"):
encoder.transform(simple_df)


def test_new_segments_error(simple_df):
train_df = simple_df.loc[:, pd.IndexSlice["Moscow", :]]
test_df = simple_df.loc[:, pd.IndexSlice["Omsk", :]]
transform = MeanSegmentEncoderTransform()

transform.fit(train_df)
with pytest.raises(ValueError, match="This transform can't process segments that weren't present on train data"):
_ = transform.transform(test_df)


@pytest.fixture
def almost_constant_ts(random_seed) -> TSDataset:
df_1 = pd.DataFrame.from_dict({"timestamp": pd.date_range("2021-06-01", "2021-07-01", freq="D")})
Expand Down

0 comments on commit 27023dd

Please sign in to comment.