Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: return fitted transformer and transformed table from fit_and_transform #724

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions src/safeds/data/tabular/transformation/_imputer.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,6 @@ def is_fitted(self) -> bool:
"""Whether the transformer is fitted."""
return self._wrapped_transformer is not None

# noinspection PyProtectedMember
def fit(self, table: Table, column_names: list[str] | None) -> Imputer:
"""
Learn a transformation for a set of columns in a table.
Expand Down
Expand Up @@ -199,7 +198,6 @@ def fit(self, table: Table, column_names: list[str] | None) -> Imputer:

return result

# noinspection PyProtectedMember
def transform(self, table: Table) -> Table:
"""
Apply the learned transformation to a table.
Expand Down
1 change: 0 additions & 1 deletion src/safeds/data/tabular/transformation/_label_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from sklearn.preprocessing import OrdinalEncoder as sk_OrdinalEncoder


# noinspection PyProtectedMember
class LabelEncoder(InvertibleTableTransformer):
"""The LabelEncoder encodes one or more given columns into labels."""

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class OneHotEncoder(InvertibleTableTransformer):
>>> from safeds.data.tabular.transformation import OneHotEncoder
>>> table = Table({"col1": ["a", "b", "c", "a"]})
>>> transformer = OneHotEncoder()
>>> transformer.fit_and_transform(table, ["col1"])
>>> transformer.fit_and_transform(table, ["col1"])[1]
col1__a col1__b col1__c
0 1.0 0.0 0.0
1 0.0 1.0 0.0
Expand All @@ -65,7 +65,6 @@ def __init__(self) -> None:
# Maps nan values (str of old column) to corresponding new column name
self._value_to_column_nans: dict[str, str] | None = None

# noinspection PyProtectedMember
def fit(self, table: Table, column_names: list[str] | None) -> OneHotEncoder:
"""
Learn a transformation for a set of columns in a table.
Expand Down
Expand Up @@ -150,7 +149,6 @@ def fit(self, table: Table, column_names: list[str] | None) -> OneHotEncoder:

return result

# noinspection PyProtectedMember
def transform(self, table: Table) -> Table:
"""
Apply the learned transformation to a table.
Expand Down
Expand Up @@ -238,7 +236,6 @@ def transform(self, table: Table) -> Table:
# Apply sorting and return:
return table.sort_columns(lambda col1, col2: column_names.index(col1.name) - column_names.index(col2.name))

# noinspection PyProtectedMember
def inverse_transform(self, transformed_table: Table) -> Table:
"""
Undo the learned transformation.
Expand Down
42 changes: 14 additions & 28 deletions src/safeds/data/tabular/transformation/_table_transformer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Self

from safeds._utils import _structural_hash

Expand All @@ -26,8 +26,13 @@ def __hash__(self) -> int:
removed = self.get_names_of_removed_columns() if self.is_fitted else []
return _structural_hash(self.__class__.__qualname__, self.is_fitted, added, changed, removed)

@property
@abstractmethod
def is_fitted(self) -> bool:
"""Whether the transformer is fitted."""

@abstractmethod
def fit(self, table: Table, column_names: list[str] | None) -> TableTransformer:
def fit(self, table: Table, column_names: list[str] | None) -> Self:
"""
Learn a transformation for a set of columns in a table.

Expand Down Expand Up @@ -117,16 +122,11 @@ def get_names_of_removed_columns(self) -> list[str]:
If the transformer has not been fitted yet.
"""

@property
@abstractmethod
def is_fitted(self) -> bool:
"""Whether the transformer is fitted."""

def fit_and_transform(self, table: Table, column_names: list[str] | None = None) -> Table:
def fit_and_transform(self, table: Table, column_names: list[str] | None = None) -> tuple[Self, Table]:
"""
Learn a transformation for a set of columns in a table and apply the learned transformation to the same table.

The table is not modified. If you also need the fitted transformer, use `fit` and `transform` separately.
Neither the transformer nor the table is modified.

Parameters
----------
Expand All @@ -137,33 +137,19 @@ def fit_and_transform(self, table: Table, column_names: list[str] | None = None)

Returns
-------
fitted_transformer:
The fitted transformer.
transformed_table:
The transformed table.
"""
return self.fit(table, column_names).transform(table)
fitted_transformer = self.fit(table, column_names)
transformed_table = fitted_transformer.transform(table)
return fitted_transformer, transformed_table


class InvertibleTableTransformer(TableTransformer):
"""A `TableTransformer` that can also undo the learned transformation after it has been applied."""

@abstractmethod
def fit(self, table: Table, column_names: list[str] | None) -> InvertibleTableTransformer:
"""
Learn a transformation for a set of columns in a table.

Parameters
----------
table:
The table used to fit the transformer.
column_names:
The list of columns from the table used to fit the transformer. If `None`, all columns are used.

Returns
-------
fitted_transformer:
The fitted transformer.
"""

@abstractmethod
def inverse_transform(self, transformed_table: Table) -> Table:
"""
Expand Down
10 changes: 7 additions & 3 deletions tests/safeds/data/tabular/transformation/test_discretizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,13 +197,15 @@ class TestFitAndTransform:
],
ids=["None", "col1"],
)
def test_should_return_transformed_table(
def test_should_return_fitted_transformer_and_transformed_table(
self,
table: Table,
column_names: list[str] | None,
expected: Table,
) -> None:
assert Discretizer().fit_and_transform(table, column_names) == expected
fitted_transformer, transformed_table = Discretizer().fit_and_transform(table, column_names)
assert fitted_transformer.is_fitted
assert transformed_table == expected

@pytest.mark.parametrize(
("table", "number_of_bins", "expected"),
Expand Down Expand Up @@ -243,7 +245,9 @@ def test_should_return_transformed_table_with_correct_number_of_bins(
number_of_bins: int,
expected: Table,
) -> None:
assert Discretizer(number_of_bins).fit_and_transform(table, ["col1"]) == expected
fitted_transformer, transformed_table = Discretizer(number_of_bins).fit_and_transform(table, ["col1"])
assert fitted_transformer.is_fitted
assert transformed_table == expected

def test_should_not_change_original_table(self) -> None:
table = Table(
Expand Down
28 changes: 13 additions & 15 deletions tests/safeds/data/tabular/transformation/test_imputer.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,29 +413,27 @@ class TestFitAndTransform:
"other value to replace",
],
)
def test_should_return_transformed_table(
def test_should_return_fitted_transformer_and_transformed_table(
self,
table: Table,
column_names: list[str] | None,
strategy: Imputer.Strategy,
value_to_replace: float | str | None,
expected: Table,
) -> None:
if isinstance(strategy, _Mode):
with warnings.catch_warnings():
warnings.filterwarnings(
action="ignore",
message=r"There are multiple most frequent values in a column given to the Imputer\..*",
category=UserWarning,
)
assert (
Imputer(strategy, value_to_replace=value_to_replace).fit_and_transform(table, column_names)
== expected
)
else:
assert (
Imputer(strategy, value_to_replace=value_to_replace).fit_and_transform(table, column_names) == expected
with warnings.catch_warnings():
warnings.filterwarnings(
action="ignore",
message=r"There are multiple most frequent values in a column given to the Imputer\..*",
category=UserWarning,
)
fitted_transformer, transformed_table = Imputer(
strategy,
value_to_replace=value_to_replace,
).fit_and_transform(table, column_names)

assert fitted_transformer.is_fitted
assert transformed_table == expected

@pytest.mark.parametrize("strategy", strategies(), ids=lambda x: x.__class__.__name__)
def test_should_not_change_original_table(self, strategy: Imputer.Strategy) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,13 +132,15 @@ class TestFitAndTransform:
],
ids=["no_column_names", "with_column_names"],
)
def test_should_return_transformed_table(
def test_should_return_fitted_transformer_and_transformed_table(
self,
table: Table,
column_names: list[str] | None,
expected: Table,
) -> None:
assert LabelEncoder().fit_and_transform(table, column_names) == expected
fitted_transformer, transformed_table = LabelEncoder().fit_and_transform(table, column_names)
assert fitted_transformer.is_fitted
assert transformed_table == expected

def test_should_not_change_original_table(self) -> None:
table = Table(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -263,13 +263,15 @@ class TestFitAndTransform:
"column with nans",
],
)
def test_should_return_transformed_table(
def test_should_return_fitted_transformer_and_transformed_table(
self,
table: Table,
column_names: list[str] | None,
expected: Table,
) -> None:
assert OneHotEncoder().fit_and_transform(table, column_names) == expected
fitted_transformer, transformed_table = OneHotEncoder().fit_and_transform(table, column_names)
assert fitted_transformer.is_fitted
assert transformed_table == expected

def test_should_not_change_original_table(self) -> None:
table = Table(
Expand Down
15 changes: 11 additions & 4 deletions tests/safeds/data/tabular/transformation/test_range_scaler.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,13 +144,15 @@ class TestFitAndTransform:
],
ids=["one_column", "two_columns"],
)
def test_should_return_transformed_table(
def test_should_return_fitted_transformer_and_transformed_table(
self,
table: Table,
column_names: list[str] | None,
expected: Table,
) -> None:
assert RangeScaler().fit_and_transform(table, column_names) == expected
fitted_transformer, transformed_table = RangeScaler().fit_and_transform(table, column_names)
assert fitted_transformer.is_fitted
assert transformed_table == expected

@pytest.mark.parametrize(
("table", "column_names", "expected"),
Expand Down Expand Up @@ -186,13 +188,18 @@ def test_should_return_transformed_table(
],
ids=["one_column", "two_columns"],
)
def test_should_return_transformed_table_with_correct_range(
def test_should_return_fitted_transformer_and_transformed_table_with_correct_range(
self,
table: Table,
column_names: list[str] | None,
expected: Table,
) -> None:
assert RangeScaler(minimum=-10.0, maximum=10.0).fit_and_transform(table, column_names) == expected
fitted_transformer, transformed_table = RangeScaler(minimum=-10.0, maximum=10.0).fit_and_transform(
table,
column_names,
)
assert fitted_transformer.is_fitted
assert transformed_table == expected

def test_should_not_change_original_table(self) -> None:
table = Table(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,9 @@ def test_should_return_true_after_fitting(self) -> None:
assert fitted_transformer.is_fitted


class TestFitAndTransformOnMultipleTables:
class TestFitAndTransform:
@pytest.mark.parametrize(
("fit_and_transform_table", "only_transform_table", "column_names", "expected_1", "expected_2"),
("table", "column_names", "expected"),
[
(
Table(
Expand All @@ -122,43 +122,27 @@ class TestFitAndTransformOnMultipleTables:
"col2": [0.0, 0.0, 1.0, 1.0],
},
),
Table(
{
"col1": [2],
"col2": [2],
},
),
None,
Table(
{
"col1": [-1.0, -1.0, 1.0, 1.0],
"col2": [-1.0, -1.0, 1.0, 1.0],
},
),
Table(
{
"col1": [3.0],
"col2": [3.0],
},
),
),
],
ids=["two_columns"],
)
def test_should_return_transformed_tables(
def test_should_return_fitted_transformer_and_transformed_table(
self,
fit_and_transform_table: Table,
only_transform_table: Table,
table: Table,
column_names: list[str] | None,
expected_1: Table,
expected_2: Table,
expected: Table,
) -> None:
s = StandardScaler().fit(fit_and_transform_table, column_names)
assert s.fit_and_transform(fit_and_transform_table, column_names) == expected_1
assert s.transform(only_transform_table) == expected_2

fitted_transformer, transformed_table = StandardScaler().fit_and_transform(table, column_names)
assert fitted_transformer.is_fitted
assert_that_tables_are_close(transformed_table, expected)

class TestFitAndTransform:
def test_should_not_change_original_table(self) -> None:
table = Table(
{
Expand Down
Loading