Skip to content

Commit

Permalink
fix: Keeping no columns with Table.keep_only_columns results in an empty Table with a row count above 0 (#386)
Browse files Browse the repository at this point in the history

Closes #318 

Fixes bug:

Calling `Table.keep_only_columns` with an empty list on a table that
contains at least one row returned a table with no columns but a row
count above 0. The result is now an empty table with 0 rows.

### Summary of Changes

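- `Table.keep_only_columns` now works on a deep copy of the table and
  delegates to `Table.remove_columns`, removing every column that is not
  listed in `column_names`.
- `Table.remove_columns` returns an empty `Table()` when no columns remain
  after the removal.
- `Table.__eq__` now treats two tables without columns as equal; this
  replaces the earlier pre-sort check that compared the column names of two
  tables without rows.
- Tests: added a "rowless table" case to `test_eq.py`; the
  `keep_only_columns` and `remove_columns` tests now expect `Table()` and
  additionally check the row count.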

---------

Co-authored-by: patrikguempel <[email protected]>
Co-authored-by: megalinter-bot <[email protected]>
Co-authored-by: Simon Breuer <[email protected]>
  • Loading branch information
4 people authored Jun 24, 2023
1 parent 2979f24 commit 15dab06
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 14 deletions.
15 changes: 10 additions & 5 deletions src/safeds/data/tabular/containers/_table.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import copy
import functools
import io
import warnings
Expand Down Expand Up @@ -440,8 +441,8 @@ def __eq__(self, other: Any) -> bool:
return NotImplemented
if self is other:
return True
if self.number_of_rows == 0 and other.number_of_rows == 0:
return self.column_names == other.column_names
if self.number_of_columns == 0 and other.number_of_columns == 0:
return True
table1 = self.sort_columns()
table2 = other.sort_columns()
if table1.number_of_rows == 0 and table2.number_of_rows == 0:
Expand Down Expand Up @@ -1093,9 +1094,9 @@ def keep_only_columns(self, column_names: list[str]) -> Table:
if len(invalid_columns) != 0:
raise UnknownColumnNameError(invalid_columns)

transformed_data = self._data[column_names]
transformed_data.columns = column_names
return Table._from_pandas_dataframe(transformed_data)
clone = copy.deepcopy(self)
clone = clone.remove_columns(list(set(self.column_names) - set(column_names)))
return clone

def remove_columns(self, column_names: list[str]) -> Table:
"""
Expand Down Expand Up @@ -1138,6 +1139,10 @@ def remove_columns(self, column_names: list[str]) -> Table:

transformed_data = self._data.drop(labels=column_names, axis="columns")
transformed_data.columns = [name for name in self._schema.column_names if name not in column_names]

if len(transformed_data.columns) == 0:
return Table()

return Table._from_pandas_dataframe(transformed_data)

def remove_columns_with_missing_values(self) -> Table:
Expand Down
6 changes: 4 additions & 2 deletions tests/safeds/data/tabular/containers/_table/test_eq.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,16 @@
("table1", "table2", "expected"),
[
(Table(), Table(), True),
(Table({"a": [], "b": []}), Table({"a": [], "b": []}), True),
(Table({"col1": [1]}), Table({"col1": [1]}), True),
(Table({"col1": [1]}), Table({"col2": [1]}), False),
(Table({"col1": [1, 2, 3]}), Table({"col1": [1, 1, 3]}), False),
(Table({"col1": [1, 2, 3]}), Table({"col1": ["1", "2", "3"]}), False),
],
ids=[
"empty Table",
"equal Tables",
"empty table",
"rowless table",
"equal tables",
"different column names",
"different values",
"different types",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
(
Table({"A": [1], "B": [2]}),
[],
Table({}),
Table(),
),
(
Table({"A": [1], "B": [2]}),
Expand Down Expand Up @@ -44,6 +44,8 @@ def test_should_keep_only_listed_columns(table: Table, column_names: list[str],
transformed_table = table.keep_only_columns(column_names)
assert transformed_table.schema == expected.schema
assert transformed_table == expected
if len(column_names) == 0:
assert expected.number_of_rows == 0


@pytest.mark.parametrize("table", [Table({"A": [1], "B": [2]}), Table()], ids=["table", "empty"])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,20 @@


@pytest.mark.parametrize(
("table1", "expected", "columns"),
("table", "expected", "columns"),
[
(Table({"col1": [1, 2, 1], "col2": ["a", "b", "c"]}), Table({"col1": [1, 2, 1]}), ["col2"]),
(Table({"col1": [1, 2, 1], "col2": [1, 2, 4]}), Table({}), ["col1", "col2"]),
(Table({"col1": [1, 2, 1], "col2": [1, 2, 4]}), Table(), ["col1", "col2"]),
(Table({"col1": [1, 2, 1], "col2": [1, 2, 4]}), Table({"col1": [1, 2, 1], "col2": [1, 2, 4]}), []),
(Table(), Table(), []),
],
ids=["one column", "multiple columns", "no columns", "empty"],
)
def test_should_remove_table_columns(table1: Table, expected: Table, columns: list[str]) -> None:
table1 = table1.remove_columns(columns)
assert table1.schema == expected.schema
assert table1 == expected
def test_should_remove_table_columns(table: Table, expected: Table, columns: list[str]) -> None:
table = table.remove_columns(columns)
assert table.schema == expected.schema
assert table == expected
assert table.number_of_rows == expected.number_of_rows


@pytest.mark.parametrize("table", [Table({"A": [1], "B": [2]}), Table()], ids=["normal", "empty"])
Expand Down

0 comments on commit 15dab06

Please sign in to comment.