Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Added method Table.inverse_transform_table which returns the original table #227

Merged
merged 9 commits into from
Apr 21, 2023
42 changes: 42 additions & 0 deletions src/safeds/data/tabular/containers/_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
if TYPE_CHECKING:
from collections.abc import Callable, Iterable

from safeds.data.tabular.transformation import InvertibleTableTransformer

from ._tagged_table import TaggedTable


Expand Down Expand Up @@ -991,6 +993,46 @@ def transform_column(self, name: str, transformer: Callable[[Row], Any]) -> Tabl
return self.replace_column(name, result)
raise UnknownColumnNameError([name])

def inverse_transform_table(self, transformer: InvertibleTableTransformer) -> Table:
    """
    Invert the transformation applied by the given transformer.

    This is a convenience wrapper: the call is delegated to `transformer.inverse_transform` with this table as the
    argument.

    Parameters
    ----------
    transformer : InvertibleTableTransformer
        A transformer that was fitted with columns that are all present in the table.

    Returns
    -------
    table : Table
        The original table.

    Raises
    ------
    TransformerNotFittedError
        If the transformer has not been fitted yet.

    Examples
    --------
    >>> from safeds.data.tabular.transformation import OneHotEncoder
    >>> from safeds.data.tabular.containers import Table
    >>> transformer = OneHotEncoder()
    >>> table = Table.from_dict({"col1": [1, 2, 1], "col2": [1, 2, 4]})
    >>> transformer = transformer.fit(table, None)
    >>> transformed_table = transformer.transform(table)
    >>> transformed_table.inverse_transform_table(transformer)
       col1  col2
    0     1     1
    1     2     2
    2     1     4
    >>> transformer.inverse_transform(transformed_table)
       col1  col2
    0     1     1
    1     2     2
    2     1     4
    """
    # All work (including the "not fitted" check) is done by the transformer itself.
    return transformer.inverse_transform(self)

# ------------------------------------------------------------------------------------------------------------------
# Plotting
# ------------------------------------------------------------------------------------------------------------------
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
import pytest
from safeds.data.tabular.containers import Table
from safeds.data.tabular.exceptions import TransformerNotFittedError
from safeds.data.tabular.transformation import OneHotEncoder


class TestInverseTransformTableOnOneHotEncoder:
    """Tests for `Table.inverse_transform_table` when used with a `OneHotEncoder`."""

    @pytest.mark.parametrize(
        ("table_to_fit", "column_names", "table_to_transform"),
        [
            pytest.param(
                Table.from_dict(
                    {
                        "a": [1.0, 0.0, 0.0, 0.0],
                        "b": ["a", "b", "b", "c"],
                        "c": [0.0, 0.0, 0.0, 1.0],
                    },
                ),
                ["b"],
                Table.from_dict(
                    {
                        "a": [1.0, 0.0, 0.0, 0.0],
                        "b": ["a", "b", "b", "c"],
                        "c": [0.0, 0.0, 0.0, 1.0],
                    },
                ),
                id="same table to fit and transform",
            ),
            pytest.param(
                Table.from_dict(
                    {
                        "a": [1.0, 0.0, 0.0, 0.0],
                        "b": ["a", "b", "b", "c"],
                        "c": [0.0, 0.0, 0.0, 1.0],
                    },
                ),
                ["b"],
                Table.from_dict(
                    {
                        "c": [0.0, 0.0, 0.0, 1.0],
                        "b": ["a", "b", "b", "c"],
                        "a": [1.0, 0.0, 0.0, 0.0],
                    },
                ),
                id="different tables to fit and transform",
            ),
            pytest.param(
                Table.from_dict(
                    {
                        "a": [1.0, 0.0, 0.0, 0.0],
                        "b": ["a", "b", "b", "c"],
                        "bb": ["a", "b", "b", "c"],
                    },
                ),
                ["b", "bb"],
                Table.from_dict(
                    {
                        "a": [1.0, 0.0, 0.0, 0.0],
                        "b": ["a", "b", "b", "c"],
                        "bb": ["a", "b", "b", "c"],
                    },
                ),
                id="one column name is a prefix of another column name",
            ),
        ],
    )
    def test_should_return_original_table(
        self,
        table_to_fit: Table,
        column_names: list[str],
        table_to_transform: Table,
    ) -> None:
        """Transforming a table and inverting the result must reproduce the input table."""
        encoder = OneHotEncoder().fit(table_to_fit, column_names)
        encoded = encoder.transform(table_to_transform)

        restored = encoded.inverse_transform_table(encoder)

        # Column order must survive the round trip.
        assert restored.column_names == table_to_transform.column_names
        # Implied by the equality check below, but a schema mismatch gives a clearer failure message.
        assert restored.schema == table_to_transform.schema
        assert restored == table_to_transform

    def test_should_not_change_transformed_table(self) -> None:
        """Inverting must leave the table it is called on untouched."""
        source = Table.from_dict(
            {
                "col1": ["a", "b", "b", "c"],
            },
        )
        encoder = OneHotEncoder().fit(source, None)
        transformed_table = encoder.transform(source)

        # The result is deliberately discarded; only the receiver's state matters here.
        transformed_table.inverse_transform_table(encoder)

        expected = Table.from_dict(
            {
                "col1_a": [1.0, 0.0, 0.0, 0.0],
                "col1_b": [0.0, 1.0, 1.0, 0.0],
                "col1_c": [0.0, 0.0, 0.0, 1.0],
            },
        )
        assert transformed_table == expected

    def test_should_raise_if_not_fitted(self) -> None:
        """An unfitted transformer must be rejected with `TransformerNotFittedError`."""
        unfitted_encoder = OneHotEncoder()
        table = Table.from_dict(
            {
                "a": [1.0, 0.0, 0.0, 0.0],
                "b": [0.0, 1.0, 1.0, 0.0],
                "c": [0.0, 0.0, 0.0, 1.0],
            },
        )

        with pytest.raises(TransformerNotFittedError):
            table.inverse_transform_table(unfitted_encoder)