Skip to content

Commit

Permalink
feat: raise if remove_colums is called with unknown column by defau…
Browse files Browse the repository at this point in the history
…lt (#852)

Closes #807

### Summary of Changes

- Added an optional, keyword-only parameter ignore_unknown_names: bool =
False to the remove_columns method.
- This parameter controls whether an error is raised when attempting to
remove non-existent columns.
- If ignore_unknown_names is set to False, the method checks for the
existence of specified columns using the - ----
- If ignore_unknown_names is set to True, non-existent columns are
ignored, and no error is raised.

---------

Co-authored-by: Sardar <[email protected]>
Co-authored-by: megalinter-bot <[email protected]>
Co-authored-by: Tim Locke <[email protected]>
Co-authored-by: Lars Reimann <[email protected]>
  • Loading branch information
5 people committed Jun 28, 2024
1 parent 9880fe0 commit 8f78163
Show file tree
Hide file tree
Showing 7 changed files with 65 additions and 18 deletions.
14 changes: 10 additions & 4 deletions src/safeds/data/tabular/containers/_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,20 +642,23 @@ def remove_columns(
self,
names: str | list[str],
/,
*,
ignore_unknown_names: bool = False,
) -> Table:
"""
Return a new table without the specified columns.
**Notes:**
- The original table is not modified.
- This method does not raise if a column does not exist. You can use it to ensure that the resulting table does
not contain certain columns.
Parameters
----------
names:
The names of the columns to remove.
ignore_unknown_names:
If set to True, columns that are not present in the table will be ignored.
If set to False, an error will be raised if any of the specified columns do not exist.
Returns
-------
Expand All @@ -677,7 +680,7 @@ def remove_columns(
| 6 |
+-----+
>>> table.remove_columns(["c"])
>>> table.remove_columns(["c"], ignore_unknown_names=True)
+-----+-----+
| a | b |
| --- | --- |
Expand All @@ -691,6 +694,9 @@ def remove_columns(
if isinstance(names, str):
names = [names]

if not ignore_unknown_names:
_check_columns_exist(self, names)

return Table._from_polars_lazy_frame(
self._lazy_frame.drop(names),
)
Expand Down Expand Up @@ -931,7 +937,7 @@ def replace_column(
_check_columns_dont_exist(self, [column.name for column in new_columns], old_name=old_name)

if len(new_columns) == 0:
return self.remove_columns(old_name)
return self.remove_columns(old_name, ignore_unknown_names=True)

if len(new_columns) == 1:
new_column = new_columns[0]
Expand Down
2 changes: 1 addition & 1 deletion src/safeds/ml/classical/_supervised_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,7 @@ def _predict_with_sklearn_model(
with warnings.catch_warnings():
warnings.filterwarnings("ignore", message="X does not have valid feature names")
predicted_target_vector = model.predict(features._data_frame)
output = dataset.remove_columns(target_name).add_columns(
output = dataset.remove_columns(target_name, ignore_unknown_names=True).add_columns(
Column(target_name, predicted_target_vector),
)

Expand Down
2 changes: 1 addition & 1 deletion src/safeds/ml/classical/regression/_arima.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def predict(self, time_series: TimeSeriesDataset) -> Table:
# make a table without
forecast_horizon = len(time_series.target._series.to_numpy())
result_table = time_series.to_table()
result_table = result_table.remove_columns([time_series.target.name])
result_table = result_table.remove_columns([time_series.target.name], ignore_unknown_names=True)
# Validation
if not self.is_fitted or self._arima is None:
raise ModelNotFittedError
Expand Down
59 changes: 50 additions & 9 deletions tests/safeds/data/tabular/containers/_table/test_remove_columns.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,67 @@
import pytest
from safeds.data.tabular.containers import Table
from safeds.exceptions import ColumnNotFoundError


# Test cases where no exception is expected
@pytest.mark.parametrize(
("table", "expected", "columns"),
("table", "expected", "columns", "ignore_unknown_names"),
[
(Table({"col1": [1, 2, 1], "col2": ["a", "b", "c"]}), Table({"col1": [1, 2, 1]}), ["col2"]),
(Table({"col1": [1, 2, 1], "col2": [1, 2, 4]}), Table(), ["col1", "col2"]),
(Table({"col1": [1, 2, 1], "col2": [1, 2, 4]}), Table({"col1": [1, 2, 1], "col2": [1, 2, 4]}), []),
(Table(), Table(), []),
(Table(), Table(), ["col1"]),
(Table({"col1": [1, 2, 1], "col2": ["a", "b", "c"]}), Table({"col1": [1, 2, 1]}), ["col2"], True),
(Table({"col1": [1, 2, 1], "col2": [1, 2, 4]}), Table(), ["col1", "col2"], True),
(Table({"col1": [1, 2, 1], "col2": [1, 2, 4]}), Table({"col1": [1, 2, 1], "col2": [1, 2, 4]}), [], True),
(Table(), Table(), [], True),
(Table(), Table(), ["col1"], True),
(Table({"col1": [1, 2, 1], "col2": ["a", "b", "c"]}), Table({"col1": [1, 2, 1]}), ["col2"], False),
(Table({"col1": [1, 2, 1], "col2": [1, 2, 4]}), Table(), ["col1", "col2"], False),
(
Table({"col1": [1, 2, 1], "col2": [1, 2, 4]}),
Table({"col1": [1, 2, 1], "col2": [1, 2, 4]}),
[],
False,
),
(Table(), Table(), [], False),
],
ids=[
"one column, ignore unknown names",
"multiple columns, ignore unknown names",
"no columns, ignore unknown names",
"empty, ignore unknown names",
"missing columns, ignore unknown names",
"one column",
"multiple columns",
"no columns",
"empty",
"missing columns",
],
)
def test_should_remove_table_columns(table: Table, expected: Table, columns: list[str]) -> None:
table = table.remove_columns(columns)
def test_should_remove_table_columns_no_exception(
table: Table,
expected: Table,
columns: list[str],
ignore_unknown_names: bool,
) -> None:
table = table.remove_columns(columns, ignore_unknown_names=ignore_unknown_names)
assert table.schema == expected.schema
assert table == expected
assert table.row_count == expected.row_count


# Test cases where an exception is expected
@pytest.mark.parametrize(
("table", "columns", "ignore_unknown_names"),
[
(Table(), ["col1"], False),
(Table(), ["col12"], False),
],
ids=[
"missing columns",
"missing columns",
],
)
def test_should_raise_error_for_unknown_columns(
table: Table,
columns: list[str],
ignore_unknown_names: bool,
) -> None:
with pytest.raises(ColumnNotFoundError):
table.remove_columns(columns, ignore_unknown_names=ignore_unknown_names)
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def test_should_raise_if_not_fitted(self, classifier: Classifier, valid_data: Ta
def test_should_raise_if_dataset_misses_features(self, classifier: Classifier, valid_data: TabularDataset) -> None:
fitted_classifier = classifier.fit(valid_data)
with pytest.raises(DatasetMissesFeaturesError, match="[feat1, feat2]"):
fitted_classifier.predict(valid_data.features.remove_columns(["feat1", "feat2"]))
fitted_classifier.predict(valid_data.features.remove_columns(["feat1", "feat2"], ignore_unknown_names=True))

@pytest.mark.parametrize(
("invalid_data", "expected_error", "expected_error_msg"),
Expand Down
2 changes: 1 addition & 1 deletion tests/safeds/ml/classical/regression/test_regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ def test_should_raise_if_not_fitted(self, regressor: Regressor, valid_data: Tabu
def test_should_raise_if_dataset_misses_features(self, regressor: Regressor, valid_data: TabularDataset) -> None:
fitted_regressor = regressor.fit(valid_data)
with pytest.raises(DatasetMissesFeaturesError, match="[feat1, feat2]"):
fitted_regressor.predict(valid_data.features.remove_columns(["feat1", "feat2"]))
fitted_regressor.predict(valid_data.features.remove_columns(["feat1", "feat2"], ignore_unknown_names=True))

@pytest.mark.parametrize(
("invalid_data", "expected_error", "expected_error_msg"),
Expand Down
2 changes: 1 addition & 1 deletion tests/safeds/ml/nn/test_forward_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def test_forward_model(device: Device) -> None:
table_1 = Table.from_csv_file(
path=resolve_resource_path(_inflation_path),
)
table_1 = table_1.remove_columns(["date"])
table_1 = table_1.remove_columns(["date"], ignore_unknown_names=True)
table_2 = table_1.slice_rows(start=0, length=table_1.row_count - 14)
table_2 = table_2.add_columns([(table_1.slice_rows(start=14)).get_column("value").rename("target")])
train_table, test_table = table_2.split_rows(0.8)
Expand Down

0 comments on commit 8f78163

Please sign in to comment.