Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Keeping no columns with Table.keep_only_columns results in an empty Table with a row count above 0 #386

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions src/safeds/data/tabular/containers/_table.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import copy
import functools
import io
import warnings
Expand Down Expand Up @@ -440,8 +441,8 @@ def __eq__(self, other: Any) -> bool:
return NotImplemented
if self is other:
return True
if self.number_of_rows == 0 and other.number_of_rows == 0:
return self.column_names == other.column_names
if self.number_of_columns == 0 and other.number_of_columns == 0:
return True
table1 = self.sort_columns()
table2 = other.sort_columns()
if table1.number_of_rows == 0 and table2.number_of_rows == 0:
Expand Down Expand Up @@ -1093,9 +1094,9 @@ def keep_only_columns(self, column_names: list[str]) -> Table:
if len(invalid_columns) != 0:
raise UnknownColumnNameError(invalid_columns)

transformed_data = self._data[column_names]
transformed_data.columns = column_names
return Table._from_pandas_dataframe(transformed_data)
clone = copy.deepcopy(self)
clone = clone.remove_columns(list(set(self.column_names) - set(column_names)))
return clone

def remove_columns(self, column_names: list[str]) -> Table:
"""
Expand Down Expand Up @@ -1138,6 +1139,10 @@ def remove_columns(self, column_names: list[str]) -> Table:

transformed_data = self._data.drop(labels=column_names, axis="columns")
transformed_data.columns = [name for name in self._schema.column_names if name not in column_names]

if len(transformed_data.columns) == 0:
return Table()

return Table._from_pandas_dataframe(transformed_data)

def remove_columns_with_missing_values(self) -> Table:
Expand Down
6 changes: 4 additions & 2 deletions tests/safeds/data/tabular/containers/_table/test_eq.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,16 @@
("table1", "table2", "expected"),
[
(Table(), Table(), True),
(Table({"a": [], "b": []}), Table({"a": [], "b": []}), True),
(Table({"col1": [1]}), Table({"col1": [1]}), True),
(Table({"col1": [1]}), Table({"col2": [1]}), False),
(Table({"col1": [1, 2, 3]}), Table({"col1": [1, 1, 3]}), False),
(Table({"col1": [1, 2, 3]}), Table({"col1": ["1", "2", "3"]}), False),
],
ids=[
"empty Table",
"equal Tables",
"empty table",
"rowless table",
"equal tables",
"different column names",
"different values",
"different types",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
(
Table({"A": [1], "B": [2]}),
[],
Table({}),
Table(),
),
(
Table({"A": [1], "B": [2]}),
Expand Down Expand Up @@ -44,6 +44,8 @@ def test_should_keep_only_listed_columns(table: Table, column_names: list[str],
transformed_table = table.keep_only_columns(column_names)
assert transformed_table.schema == expected.schema
assert transformed_table == expected
if len(column_names) == 0:
assert expected.number_of_rows == 0


@pytest.mark.parametrize("table", [Table({"A": [1], "B": [2]}), Table()], ids=["table", "empty"])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,20 @@


@pytest.mark.parametrize(
("table1", "expected", "columns"),
("table", "expected", "columns"),
[
(Table({"col1": [1, 2, 1], "col2": ["a", "b", "c"]}), Table({"col1": [1, 2, 1]}), ["col2"]),
(Table({"col1": [1, 2, 1], "col2": [1, 2, 4]}), Table({}), ["col1", "col2"]),
(Table({"col1": [1, 2, 1], "col2": [1, 2, 4]}), Table(), ["col1", "col2"]),
(Table({"col1": [1, 2, 1], "col2": [1, 2, 4]}), Table({"col1": [1, 2, 1], "col2": [1, 2, 4]}), []),
(Table(), Table(), []),
],
ids=["one column", "multiple columns", "no columns", "empty"],
)
def test_should_remove_table_columns(table1: Table, expected: Table, columns: list[str]) -> None:
table1 = table1.remove_columns(columns)
assert table1.schema == expected.schema
assert table1 == expected
def test_should_remove_table_columns(table: Table, expected: Table, columns: list[str]) -> None:
table = table.remove_columns(columns)
assert table.schema == expected.schema
assert table == expected
assert table.number_of_rows == expected.number_of_rows


@pytest.mark.parametrize("table", [Table({"A": [1], "B": [2]}), Table()], ids=["normal", "empty"])
Expand Down