Skip to content

Commit

Permalink
Fix SegmentEncoderTransform to pass inference tests (#1103)
Browse files Browse the repository at this point in the history
  • Loading branch information
Mr-Geekman authored Feb 8, 2023
1 parent 892945e commit 21cdb02
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 4 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
-
-
-
-
- Fix `SegmentEncoderTransform` to work with subset of segments and raise error on new segments ([#1103](https://github.com/tinkoff-ai/etna/pull/1103))
-
-
## [1.14.0] - 2022-12-16
Expand Down
28 changes: 25 additions & 3 deletions etna/transforms/encoders/segment_encoder.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import reprlib

import numpy as np
import pandas as pd
from sklearn import preprocessing

Expand Down Expand Up @@ -44,12 +47,31 @@ def transform(self, df: pd.DataFrame) -> pd.DataFrame:
-------
:
result dataframe
Raises
------
ValueError:
If transform isn't fitted.
ValueError:
If there are segments that weren't present during training.
"""
encoded_matrix = self._le.transform(self._le.classes_)
encoded_matrix = encoded_matrix.reshape(len(self._le.classes_), -1).repeat(len(df), axis=1).T
segments = df.columns.get_level_values("segment").unique().tolist()

try:
new_segments = set(segments) - set(self._le.classes_)
except AttributeError:
raise ValueError("The transform isn't fitted!")

if len(new_segments) > 0:
raise ValueError(
f"This transform can't process segments that weren't present on train data: {reprlib.repr(new_segments)}"
)

encoded_matrix = self._le.transform(segments)
encoded_matrix = np.tile(encoded_matrix, (len(df), 1))
encoded_df = pd.DataFrame(
encoded_matrix,
columns=pd.MultiIndex.from_product([self._le.classes_, ["segment_code"]], names=("segment", "feature")),
columns=pd.MultiIndex.from_product([segments, ["segment_code"]], names=("segment", "feature")),
index=df.index,
)
encoded_df = encoded_df.astype("category")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np
import pandas as pd
import pytest

from etna.transforms import SegmentEncoderTransform
from tests.test_transforms.utils import assert_transformation_equals_loaded_original
Expand All @@ -21,6 +22,38 @@ def test_segment_encoder_transform(dummy_df):
assert codes == {0, 1}, "Codes are not 0 and 1"


def test_subset_segments(dummy_df):
train_df = dummy_df
test_df = dummy_df.loc[:, pd.IndexSlice["Omsk", :]]
transform = SegmentEncoderTransform()

transform.fit(train_df)
transformed_test_df = transform.transform(test_df)

segments = sorted(transformed_test_df.columns.get_level_values("segment").unique())
features = sorted(transformed_test_df.columns.get_level_values("feature").unique())
assert segments == ["Omsk"]
assert features == ["segment_code", "target"]
values = transformed_test_df.loc[:, pd.IndexSlice[:, "segment_code"]]
assert np.all(values == values.iloc[0])


def test_not_fitted_error(dummy_df):
encoder = SegmentEncoderTransform()
with pytest.raises(ValueError, match="The transform isn't fitted"):
encoder.transform(dummy_df)


def test_new_segments_error(dummy_df):
train_df = dummy_df.loc[:, pd.IndexSlice["Moscow", :]]
test_df = dummy_df.loc[:, pd.IndexSlice["Omsk", :]]
transform = SegmentEncoderTransform()

transform.fit(train_df)
with pytest.raises(ValueError, match="This transform can't process segments that weren't present on train data"):
_ = transform.transform(test_df)


def test_save_load(example_tsds):
transform = SegmentEncoderTransform()
assert_transformation_equals_loaded_original(transform=transform, ts=example_tsds)

0 comments on commit 21cdb02

Please sign in to comment.