diff --git a/CHANGELOG.md b/CHANGELOG.md index db776cfd8..921fdc1e8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -26,7 +26,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - ### Fixed - -- +- Fix `MeanSegmentEncoderTransform` to work with subset of segments and raise error on new segments ([#1104](https://github.com/tinkoff-ai/etna/pull/1104)) - - Fix `SegmentEncoderTransform` to work with subset of segments and raise error on new segments ([#1103](https://github.com/tinkoff-ai/etna/pull/1103)) - diff --git a/etna/transforms/encoders/mean_segment_encoder.py b/etna/transforms/encoders/mean_segment_encoder.py index 8f518441c..6a05bec3e 100644 --- a/etna/transforms/encoders/mean_segment_encoder.py +++ b/etna/transforms/encoders/mean_segment_encoder.py @@ -1,4 +1,7 @@ -import numpy as np +import reprlib +from typing import Dict +from typing import Optional + import pandas as pd from etna.transforms import Transform @@ -13,7 +16,7 @@ class MeanSegmentEncoderTransform(Transform, FutureMixin): def __init__(self): self.mean_encoder = MeanTransform(in_column="target", window=-1, out_column="segment_mean") - self.global_means: np.ndarray[float] = None + self.global_means: Optional[Dict[str, float]] = None def fit(self, df: pd.DataFrame) -> "MeanSegmentEncoderTransform": """ @@ -30,7 +33,9 @@ def fit(self, df: pd.DataFrame) -> "MeanSegmentEncoderTransform": Fitted transform """ self.mean_encoder.fit(df) - self.global_means = df.loc[:, self.idx[:, "target"]].mean().values + mean_values = df.loc[:, self.idx[:, "target"]].mean().to_dict() + mean_values = {key[0]: value for key, value in mean_values.items()} + self.global_means = mean_values return self def transform(self, df: pd.DataFrame) -> pd.DataFrame: @@ -46,9 +51,26 @@ def transform(self, df: pd.DataFrame) -> pd.DataFrame: ------- : result dataframe + + Raises + ------ + ValueError: + If transform isn't fitted. + ValueError: + If there are segments that weren't present during training. """ + if self.global_means is None: + raise ValueError("The transform isn't fitted!") + + segments = df.columns.get_level_values("segment").unique().tolist() + new_segments = set(segments) - self.global_means.keys() + if len(new_segments) > 0: + raise ValueError( + f"This transform can't process segments that weren't present on train data: {reprlib.repr(new_segments)}" + ) + df = self.mean_encoder.transform(df) - segment = df.columns.get_level_values("segment").unique()[0] + segment = segments[0] nan_timestamps = df[df.loc[:, self.idx[segment, "target"]].isna()].index - df.loc[nan_timestamps, self.idx[:, "segment_mean"]] = self.global_means + df.loc[nan_timestamps, self.idx[:, "segment_mean"]] = [self.global_means[x] for x in segments] return df diff --git a/tests/test_transforms/test_encoders/test_mean_segment_encoder_transform.py b/tests/test_transforms/test_encoders/test_mean_segment_encoder_transform.py index 0ac9bced2..49da49e4b 100644 --- a/tests/test_transforms/test_encoders/test_mean_segment_encoder_transform.py +++ b/tests/test_transforms/test_encoders/test_mean_segment_encoder_transform.py @@ -9,11 +9,11 @@ from tests.test_transforms.utils import assert_transformation_equals_loaded_original -@pytest.mark.parametrize("expected_global_means", ([[3, 30]])) +@pytest.mark.parametrize("expected_global_means", [{"Moscow": 3, "Omsk": 30}]) def test_mean_segment_encoder_fit(simple_df, expected_global_means): encoder = MeanSegmentEncoderTransform() encoder.fit(simple_df) - assert (encoder.global_means == expected_global_means).all() + assert encoder.global_means == expected_global_means def test_mean_segment_encoder_transform(simple_df, transformed_simple_df): @@ -22,6 +22,36 @@ def test_mean_segment_encoder_transform(simple_df, transformed_simple_df): pd.testing.assert_frame_equal(transformed_df, transformed_simple_df) +def test_subset_segments(simple_df): + train_df = simple_df + test_df = simple_df.loc[:, pd.IndexSlice["Omsk", :]] + transform = MeanSegmentEncoderTransform() + + transform.fit(train_df) + transformed_test_df = transform.transform(test_df) + + segments = sorted(transformed_test_df.columns.get_level_values("segment").unique()) + features = sorted(transformed_test_df.columns.get_level_values("feature").unique()) + assert segments == ["Omsk"] + assert features == ["exog", "segment_mean", "target"] + + +def test_not_fitted_error(simple_df): + encoder = MeanSegmentEncoderTransform() + with pytest.raises(ValueError, match="The transform isn't fitted"): + encoder.transform(simple_df) + + +def test_new_segments_error(simple_df): + train_df = simple_df.loc[:, pd.IndexSlice["Moscow", :]] + test_df = simple_df.loc[:, pd.IndexSlice["Omsk", :]] + transform = MeanSegmentEncoderTransform() + + transform.fit(train_df) + with pytest.raises(ValueError, match="This transform can't process segments that weren't present on train data"): + _ = transform.transform(test_df) + + @pytest.fixture def almost_constant_ts(random_seed) -> TSDataset: df_1 = pd.DataFrame.from_dict({"timestamp": pd.date_range("2021-06-01", "2021-07-01", freq="D")})