diff --git a/src/safeds/data/tabular/containers/_table.py b/src/safeds/data/tabular/containers/_table.py index d7cd139bc..5d0c58cc2 100644 --- a/src/safeds/data/tabular/containers/_table.py +++ b/src/safeds/data/tabular/containers/_table.py @@ -1,5 +1,6 @@ from __future__ import annotations +import copy import functools import io import warnings @@ -440,8 +441,8 @@ def __eq__(self, other: Any) -> bool: return NotImplemented if self is other: return True - if self.number_of_rows == 0 and other.number_of_rows == 0: - return self.column_names == other.column_names + if self.number_of_columns == 0 and other.number_of_columns == 0: + return True table1 = self.sort_columns() table2 = other.sort_columns() if table1.number_of_rows == 0 and table2.number_of_rows == 0: @@ -1093,9 +1094,9 @@ def keep_only_columns(self, column_names: list[str]) -> Table: if len(invalid_columns) != 0: raise UnknownColumnNameError(invalid_columns) - transformed_data = self._data[column_names] - transformed_data.columns = column_names - return Table._from_pandas_dataframe(transformed_data) + clone = copy.deepcopy(self) + clone = clone.remove_columns(list(set(self.column_names) - set(column_names))) + return clone def remove_columns(self, column_names: list[str]) -> Table: """ @@ -1138,6 +1139,10 @@ def remove_columns(self, column_names: list[str]) -> Table: transformed_data = self._data.drop(labels=column_names, axis="columns") transformed_data.columns = [name for name in self._schema.column_names if name not in column_names] + + if len(transformed_data.columns) == 0: + return Table() + return Table._from_pandas_dataframe(transformed_data) def remove_columns_with_missing_values(self) -> Table: diff --git a/tests/safeds/data/tabular/containers/_table/test_eq.py b/tests/safeds/data/tabular/containers/_table/test_eq.py index d0c5425a8..21f185340 100644 --- a/tests/safeds/data/tabular/containers/_table/test_eq.py +++ b/tests/safeds/data/tabular/containers/_table/test_eq.py @@ -8,14 +8,16 @@ ("table1", "table2", "expected"), [ (Table(), Table(), True), + (Table({"a": [], "b": []}), Table({"a": [], "b": []}), True), (Table({"col1": [1]}), Table({"col1": [1]}), True), (Table({"col1": [1]}), Table({"col2": [1]}), False), (Table({"col1": [1, 2, 3]}), Table({"col1": [1, 1, 3]}), False), (Table({"col1": [1, 2, 3]}), Table({"col1": ["1", "2", "3"]}), False), ], ids=[ - "empty Table", - "equal Tables", + "empty table", + "rowless table", + "equal tables", "different column names", "different values", "different types", diff --git a/tests/safeds/data/tabular/containers/_table/test_keep_only_columns.py b/tests/safeds/data/tabular/containers/_table/test_keep_only_columns.py index 84a5b27ac..522e0e068 100644 --- a/tests/safeds/data/tabular/containers/_table/test_keep_only_columns.py +++ b/tests/safeds/data/tabular/containers/_table/test_keep_only_columns.py @@ -9,7 +9,7 @@ ( Table({"A": [1], "B": [2]}), [], - Table({}), + Table(), ), ( Table({"A": [1], "B": [2]}), @@ -44,6 +44,8 @@ def test_should_keep_only_listed_columns(table: Table, column_names: list[str], transformed_table = table.keep_only_columns(column_names) assert transformed_table.schema == expected.schema assert transformed_table == expected + if len(column_names) == 0: + assert expected.number_of_rows == 0 @pytest.mark.parametrize("table", [Table({"A": [1], "B": [2]}), Table()], ids=["table", "empty"]) diff --git a/tests/safeds/data/tabular/containers/_table/test_remove_columns.py b/tests/safeds/data/tabular/containers/_table/test_remove_columns.py index 01fdf7af8..ec29837a9 100644 --- a/tests/safeds/data/tabular/containers/_table/test_remove_columns.py +++ b/tests/safeds/data/tabular/containers/_table/test_remove_columns.py @@ -4,19 +4,20 @@ @pytest.mark.parametrize( - ("table1", "expected", "columns"), + ("table", "expected", "columns"), [ (Table({"col1": [1, 2, 1], "col2": ["a", "b", "c"]}), Table({"col1": [1, 2, 1]}), ["col2"]), - (Table({"col1": [1, 2, 1], "col2": [1, 2, 4]}), Table({}), ["col1", "col2"]), + (Table({"col1": [1, 2, 1], "col2": [1, 2, 4]}), Table(), ["col1", "col2"]), (Table({"col1": [1, 2, 1], "col2": [1, 2, 4]}), Table({"col1": [1, 2, 1], "col2": [1, 2, 4]}), []), (Table(), Table(), []), ], ids=["one column", "multiple columns", "no columns", "empty"], ) -def test_should_remove_table_columns(table1: Table, expected: Table, columns: list[str]) -> None: - table1 = table1.remove_columns(columns) - assert table1.schema == expected.schema - assert table1 == expected +def test_should_remove_table_columns(table: Table, expected: Table, columns: list[str]) -> None: + table = table.remove_columns(columns) + assert table.schema == expected.schema + assert table == expected + assert table.number_of_rows == expected.number_of_rows @pytest.mark.parametrize("table", [Table({"A": [1], "B": [2]}), Table()], ids=["normal", "empty"])