Skip to content

Commit

Permalink
feat: enhance replace_column to accept a list of new columns (#312)
Browse files Browse the repository at this point in the history
Closes #301.

### Summary of Changes

`replace_column` should accept a list of `new_columns` (instead of the
original `new_column` parameter) and replace the original column with
the new columns now.

All new columns should be inserted at the position of the old column.
New columns should be ordered as defined in the list.

---------

Co-authored-by: megalinter-bot <[email protected]>
Co-authored-by: Severin Paul Höfer <[email protected]>
  • Loading branch information
3 people authored Jun 9, 2023
1 parent 498999f commit d50c5b5
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 41 deletions.
46 changes: 22 additions & 24 deletions src/safeds/data/tabular/containers/_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -720,7 +720,7 @@ def add_row(self, row: Row) -> Table:
result = Table._from_pandas_dataframe(new_df)

for column in int_columns:
result = result.replace_column(column, result.get_column(column).transform(lambda it: int(it)))
result = result.replace_column(column, [result.get_column(column).transform(lambda it: int(it))])

return result

Expand Down Expand Up @@ -768,7 +768,7 @@ def add_rows(self, rows: list[Row] | Table) -> Table:
result = Table._from_pandas_dataframe(new_df)

for column in int_columns:
result = result.replace_column(column, result.get_column(column).transform(lambda it: int(it)))
result = result.replace_column(column, [result.get_column(column).transform(lambda it: int(it))])

return result

Expand Down Expand Up @@ -1001,9 +1001,9 @@ def rename_column(self, old_name: str, new_name: str) -> Table:
new_df.columns = self._schema.column_names
return Table._from_pandas_dataframe(new_df.rename(columns={old_name: new_name}))

def replace_column(self, old_column_name: str, new_column: Column) -> Table:
def replace_column(self, old_column_name: str, new_columns: list[Column]) -> Table:
"""
Return a copy of the table with the specified old column replaced by a new column. Keeps the order of columns.
Return a copy of the table with the specified old column replaced by a list of new columns. Keeps the order of columns.
This table is not modified.
Expand All @@ -1012,44 +1012,42 @@ def replace_column(self, old_column_name: str, new_column: Column) -> Table:
old_column_name : str
The name of the column to be replaced.
new_column : Column
The new column replacing the old column.
new_columns : list[Column]
The list of new columns replacing the old column.
Returns
-------
result : Table
A table with the old column replaced by the new column.
A table with the old column replaced by the new columns.
Raises
------
UnknownColumnNameError
If the old column does not exist.
DuplicateColumnNameError
If the new column already exists and the existing column is not affected by the replacement.
If at least one of the new columns already exists and the existing column is not affected by the replacement.
ColumnSizeError
If the size of the column does not match the amount of rows.
If the size of at least one of the new columns does not match the amount of rows.
"""
if old_column_name not in self._schema.column_names:
raise UnknownColumnNameError([old_column_name])

if new_column.name in self._schema.column_names and new_column.name != old_column_name:
raise DuplicateColumnNameError(new_column.name)

if self.number_of_rows != new_column._data.size:
raise ColumnSizeError(str(self.number_of_rows), str(new_column._data.size))
columns = list[Column]()
for old_column in self.column_names:
if old_column == old_column_name:
for new_column in new_columns:
if new_column.name in self.column_names and new_column.name != old_column_name:
raise DuplicateColumnNameError(new_column.name)

if old_column_name != new_column.name:
renamed_table = self.rename_column(old_column_name, new_column.name)
result = renamed_table._data
result.columns = renamed_table._schema.column_names
else:
result = self._data.copy()
result.columns = self._schema.column_names
if self.number_of_rows != new_column.number_of_rows:
raise ColumnSizeError(str(self.number_of_rows), str(new_column.number_of_rows))
columns.append(new_column)
else:
columns.append(self.get_column(old_column))

result[new_column.name] = new_column._data
return Table._from_pandas_dataframe(result)
return Table.from_columns(columns)

def shuffle_rows(self) -> Table:
"""
Expand Down Expand Up @@ -1251,7 +1249,7 @@ def transform_column(self, name: str, transformer: Callable[[Row], Any]) -> Tabl
"""
if self.has_column(name):
items: list = [transformer(item) for item in self.to_rows()]
result: Column = Column(name, items)
result: list[Column] = [Column(name, items)]
return self.replace_column(name, result)
raise UnknownColumnNameError([name])

Expand Down
44 changes: 27 additions & 17 deletions tests/safeds/data/tabular/containers/_table/test_replace_column.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@


@pytest.mark.parametrize(
("table", "column_name", "column", "expected"),
("table", "column_name", "columns", "expected"),
[
(
Table(
Expand All @@ -18,13 +18,14 @@
"C": ["a", "b", "c"],
},
),
"C",
Column("C", ["d", "e", "f"]),
"B",
[Column("B", ["d", "e", "f"]), Column("D", [3, 4, 5])],
Table(
{
"A": [1, 2, 3],
"B": [4, 5, 6],
"C": ["d", "e", "f"],
"B": ["d", "e", "f"],
"D": [3, 4, 5],
"C": ["a", "b", "c"],
},
),
),
Expand All @@ -37,7 +38,7 @@
},
),
"C",
Column("D", ["d", "e", "f"]),
[Column("D", ["d", "e", "f"])],
Table(
{
"A": [1, 2, 3],
Expand All @@ -47,26 +48,36 @@
),
),
],
ids=["multiple Columns", "one Column"],
)
def test_should_replace_column(table: Table, column_name: str, column: Column, expected: Table) -> None:
result = table.replace_column(column_name, column)
assert result.schema == expected.schema
def test_should_replace_column(table: Table, column_name: str, columns: list[Column], expected: Table) -> None:
result = table.replace_column(column_name, columns)
assert result._schema == expected._schema
assert result == expected


@pytest.mark.parametrize(
("old_column_name", "column_values", "column_name", "error", "error_message"),
("old_column_name", "column", "error", "error_message"),
[
("D", ["d", "e", "f"], "C", UnknownColumnNameError, r"Could not find column\(s\) 'D'"),
("C", ["d", "e", "f"], "B", DuplicateColumnNameError, r"Column 'B' already exists."),
("C", ["d", "e"], "D", ColumnSizeError, r"Expected a column of size 3 but got column of size 2."),
("D", [Column("C", ["d", "e", "f"])], UnknownColumnNameError, r"Could not find column\(s\) 'D'"),
(
"C",
[Column("B", ["d", "e", "f"]), Column("D", [3, 2, 1])],
DuplicateColumnNameError,
r"Column 'B' already exists.",
),
(
"C",
[Column("D", [7, 8]), Column("E", ["c", "b"])],
ColumnSizeError,
r"Expected a column of size 3 but got column of size 2.",
),
],
ids=["UnknownColumnNameError", "DuplicateColumnNameError", "ColumnSizeError"],
)
def test_should_raise_error(
old_column_name: str,
column_values: list[str],
column_name: str,
column: list[Column],
error: type[Exception],
error_message: str,
) -> None:
Expand All @@ -77,12 +88,11 @@ def test_should_raise_error(
"C": ["a", "b", "c"],
},
)
column = Column(column_name, column_values)

with pytest.raises(error, match=error_message):
input_table.replace_column(old_column_name, column)


def test_should_fail_on_empty_table() -> None:
with pytest.raises(UnknownColumnNameError):
Table().replace_column("col", Column("a", [1, 2]))
Table().replace_column("col", [Column("a", [1, 2])])

0 comments on commit d50c5b5

Please sign in to comment.