Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Added ability to Table.add_row and Table.add_rows to allow new rows with different schemas #342

Closed
38 changes: 36 additions & 2 deletions src/safeds/data/tabular/containers/_row.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from collections.abc import Mapping
import functools
from collections.abc import Callable, Mapping
from typing import TYPE_CHECKING, Any

import pandas as pd
Expand Down Expand Up @@ -263,7 +264,7 @@ def __repr__(self) -> str:
>>> repr(row)
"Row({'a': 1})"
"""
return f"Row({str(self)})"
return f"Row({self!s})"

def __str__(self) -> str:
"""
Expand Down Expand Up @@ -440,6 +441,39 @@ def get_column_type(self, column_name: str) -> ColumnType:
"""
return self._schema.get_column_type(column_name)

# ------------------------------------------------------------------------------------------------------------------
# Transformations
# ------------------------------------------------------------------------------------------------------------------

def sort_columns(
    self,
    comparator: Callable[[tuple, tuple], int] = lambda col1, col2: (col1[0] > col2[0]) - (col1[0] < col2[0]),
) -> Row:
    """
    Return a new `Row` whose columns are ordered according to the given comparator.

    This row is left untouched. The comparator receives two `(column_name, value)` tuples, `col1` and `col2`,
    and returns an integer:

    * A negative number if `col1` should come before `col2`.
    * A positive number if `col1` should come after `col2`.
    * 0 if the current relative order of `col1` and `col2` should be kept.

    Without an explicit comparator, the columns are ordered alphabetically by name.

    Parameters
    ----------
    comparator : Callable[[Tuple, Tuple], int]
        The function used to compare two `(column_name, value)` tuples.

    Returns
    -------
    new_row : Row
        A new row with sorted columns.
    """
    # Adapt the old-style comparison function to the key-based sorting API.
    column_key = functools.cmp_to_key(comparator)

    # Sort the (name, value) pairs; the sort is stable, so pairs that compare equal keep their order.
    column_items = list(self.to_dict().items())
    column_items.sort(key=column_key)

    return Row.from_dict(dict(column_items))

# ------------------------------------------------------------------------------------------------------------------
# Conversion
# ------------------------------------------------------------------------------------------------------------------
Expand Down
76 changes: 53 additions & 23 deletions src/safeds/data/tabular/containers/_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
DuplicateColumnNameError,
IndexOutOfBoundsError,
NonNumericColumnError,
SchemaMismatchError,
UnknownColumnNameError,
WrongFileExtensionError,
)
Expand Down Expand Up @@ -245,23 +244,28 @@ def from_rows(rows: list[Row]) -> Table:

Raises
------
SchemaMismatchError
If any of the row schemas does not match with the others.
UnknownColumnNameError
If any of the row column names does not match with the first row.
"""
if len(rows) == 0:
return Table._from_pandas_dataframe(pd.DataFrame())

schema_compare: Schema = rows[0]._schema
column_names_compare: list = list(rows[0].column_names)
unknown_column_names = set()
row_array: list[pd.DataFrame] = []

for row in rows:
if schema_compare != row._schema:
raise SchemaMismatchError
unknown_column_names.update(set(column_names_compare) - set(row.column_names))
row_array.append(row._data)
if len(unknown_column_names) > 0:
raise UnknownColumnNameError(list(unknown_column_names))

dataframe: DataFrame = pd.concat(row_array, ignore_index=True)
dataframe.columns = schema_compare.column_names
return Table._from_pandas_dataframe(dataframe)
dataframe.columns = column_names_compare

schema = Schema.merge_multiple_schemas([row.schema for row in rows])

return Table._from_pandas_dataframe(dataframe, schema)

@staticmethod
def _from_pandas_dataframe(data: pd.DataFrame, schema: Schema | None = None) -> Table:
Expand Down Expand Up @@ -636,7 +640,8 @@ def add_row(self, row: Row) -> Table:
"""
Add a row to the table.

This table is not modified.
The order of columns of the new row will be adjusted to the order of columns in the table.
This table will contain the merged schema.

Parameters
----------
Expand All @@ -650,21 +655,30 @@ def add_row(self, row: Row) -> Table:

Raises
------
SchemaMismatchError
If the schema of the row does not match the table schema.
UnknownColumnNameError
If the row has different column names than the table.
"""
if self._schema != row.schema:
raise SchemaMismatchError
if self.number_of_columns == 0:
return Table.from_rows([row])

if len(set(self.column_names) - set(row.column_names)) > 0:
raise UnknownColumnNameError(list(set(self.column_names) - set(row.column_names)))

row = row.sort_columns(lambda col1, col2: self.column_names.index(col2[0]) - self.column_names.index(col1[0]))

new_df = pd.concat([self._data, row._data]).infer_objects()
new_df.columns = self.column_names
return Table._from_pandas_dataframe(new_df)

schema = Schema.merge_multiple_schemas([self.schema, row.schema])

return Table._from_pandas_dataframe(new_df, schema)

def add_rows(self, rows: list[Row] | Table) -> Table:
"""
Add multiple rows to a table.

This table is not modified.
The order of columns of the new rows will be adjusted to the order of columns in the table.
This table will contain the merged schema.

Parameters
----------
Expand All @@ -678,21 +692,39 @@ def add_rows(self, rows: list[Row] | Table) -> Table:

Raises
------
SchemaMismatchError
If the schema of one of the rows does not match the table schema.
UnknownColumnNameError
If at least one of the rows has different column names than the table.
"""
if isinstance(rows, Table):
rows = rows.to_rows()
result = self._data

if self.number_of_columns == 0:
return Table.from_rows(rows)

missing_col_names = set()
for row in rows:
if self._schema != row.schema:
raise SchemaMismatchError
missing_col_names.update(set(self.column_names) - set(row.column_names))
if len(missing_col_names) > 0:
raise UnknownColumnNameError(list(missing_col_names))

sorted_rows = []
for row in rows:
sorted_rows.append(
row.sort_columns(
lambda col1, col2: self.column_names.index(col2[0]) - self.column_names.index(col1[0]),
),
)
rows = sorted_rows

result = self._data
row_frames = (row._data for row in rows)

result = pd.concat([result, *row_frames]).infer_objects()
result.columns = self.column_names
return Table._from_pandas_dataframe(result)

schema = Schema.merge_multiple_schemas([self.schema, *[row.schema for row in rows]])

return Table._from_pandas_dataframe(result, schema)

def filter_rows(self, query: Callable[[Row], bool]) -> Table:
"""
Expand Down Expand Up @@ -1025,8 +1057,6 @@ def sort_columns(

If no comparator is given, the columns will be sorted alphabetically by their name.

This table is not modified.

Parameters
----------
comparator : Callable[[Column, Column], int]
Expand Down
4 changes: 4 additions & 0 deletions src/safeds/data/tabular/typing/_column_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@
class ColumnType(ABC):
"""Abstract base class for column types."""

@abstractmethod
def __init__(self, is_nullable: bool = False):
pass

@staticmethod
def _from_numpy_data_type(data_type: np.dtype) -> ColumnType:
"""
Expand Down
60 changes: 59 additions & 1 deletion src/safeds/data/tabular/typing/_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING

from safeds.data.tabular.typing import Anything, Integer, RealNumber
from safeds.data.tabular.typing._column_type import ColumnType
from safeds.exceptions import UnknownColumnNameError

Expand Down Expand Up @@ -95,7 +96,7 @@ def __repr__(self) -> str:
>>> repr(schema)
"Schema({'A': Integer})"
"""
return f"Schema({str(self)})"
return f"Schema({self!s})"

def __str__(self) -> str:
"""
Expand Down Expand Up @@ -220,6 +221,63 @@ def to_dict(self) -> dict[str, ColumnType]:
"""
return dict(self._schema) # defensive copy

@staticmethod
def merge_multiple_schemas(schemas: list[Schema]) -> Schema:
    """
    Merge multiple schemas into one.

    For each type mismatch the new schema will have the least common supertype. A column of the merged schema is
    nullable as soon as its type gets replaced and one of the two types involved is nullable.

    The type hierarchy is as follows:
    * Anything
        * RealNumber
            * Integer
        * Boolean
        * String

    None of the given schemas is modified.

    Parameters
    ----------
    schemas : list[Schema]
        The list of schemas you want to merge. Must contain at least one schema.

    Returns
    -------
    schema : Schema
        The new merged schema. Its columns are in the order of the first schema.

    Raises
    ------
    UnknownColumnNameError
        If not all schemas have the same column names.
    """
    # Work on a copy. `_schema` is the first schema's own dictionary; writing merged types into it directly would
    # silently change that schema and every table or row that shares it.
    merged = dict(schemas[0]._schema)
    expected_names = set(merged)

    # Collect every column name that is not shared by all schemas, whether it is an extra or a missing column.
    mismatched_names = set()
    for schema in schemas:
        mismatched_names.update(expected_names.symmetric_difference(schema.column_names))
    if mismatched_names:
        raise UnknownColumnNameError(list(mismatched_names))

    for schema in schemas:
        if merged == schema._schema:
            continue  # Nothing to reconcile for this schema.

        # Only values of existing keys are replaced below, so iterating over the dictionary stays safe.
        for col_name in merged:
            current = merged[col_name]
            other = schema.get_column_type(col_name)
            nullable = current.is_nullable() or other.is_nullable()

            if isinstance(current, type(other)):
                # Same kind of type: only widen to the nullable variant if the other schema requires it.
                if other.is_nullable() and not current.is_nullable():
                    merged[col_name] = type(other)(nullable)
                continue

            # Integer and RealNumber share RealNumber as supertype; everything else only has Anything in common.
            is_integer_real_pair = (isinstance(current, RealNumber) and isinstance(other, Integer)) or (
                isinstance(current, Integer) and isinstance(other, RealNumber)
            )
            merged[col_name] = RealNumber(nullable) if is_integer_real_pair else Anything(nullable)

    return Schema(merged)

# ------------------------------------------------------------------------------------------------------------------
# IPython Integration
# ------------------------------------------------------------------------------------------------------------------
Expand Down
51 changes: 39 additions & 12 deletions tests/safeds/data/tabular/containers/_table/test_add_row.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,52 @@
import pytest
from _pytest.python_api import raises
from safeds.data.tabular.containers import Row, Table
from safeds.exceptions import SchemaMismatchError
from safeds.data.tabular.typing import Anything, Integer, Schema
from safeds.exceptions import UnknownColumnNameError


@pytest.mark.parametrize(
("table", "row"),
("table", "row", "expected", "expected_schema"),
[
(Table({"col1": [1, 2, 1], "col2": [1, 2, 4]}), Row({"col1": 5, "col2": 6})),
(
Table({"col1": [1, 2, 1], "col2": [1, 2, 4]}),
Row({"col1": 5, "col2": 6}),
Table({"col1": [1, 2, 1, 5], "col2": [1, 2, 4, 6]}),
Schema({"col1": Integer(), "col2": Integer()}),
),
(
Table({"col1": [1, 2, 1], "col2": [1, 2, 4]}),
Row({"col1": "5", "col2": 6}),
Table({"col1": [1, 2, 1, "5"], "col2": [1, 2, 4, 6]}),
Schema({"col1": Anything(), "col2": Integer()}),
),
(
Table({"col1": [1, 2, 1], "col2": [1, 2, 4]}),
Table.from_rows([Row({"col1": "5", "col2": None}), Row({"col1": "5", "col2": 2})]).get_row(0),
Table({"col1": [1, 2, 1, "5"], "col2": [1, 2, 4, None]}),
Schema({"col1": Anything(), "col2": Integer(is_nullable=True)}),
),
],
ids=["added row"],
ids=["added row", "different schemas", "different schemas and nullable"],
)
def test_should_add_row(table: Table, row: Row) -> None:
def test_should_add_row(table: Table, row: Row, expected: Table, expected_schema: Schema) -> None:
table = table.add_row(row)
assert table.number_of_rows == 4
assert table.get_row(3) == row
assert table.schema == row._schema
assert table.schema == expected_schema
assert table == expected


def test_should_raise_error_if_row_schema_invalid() -> None:
table1 = Table({"col1": [1, 2, 1], "col2": [1, 2, 4]})
row = Row({"col1": 5, "col2": "Hallo"})
with raises(SchemaMismatchError, match=r"Failed because at least two schemas didn't match."):
table1.add_row(row)
@pytest.mark.parametrize(
    ("table", "row", "expected_error_msg"),
    [
        (
            Table({"col1": [1, 2, 1], "col2": [1, 2, 4]}),
            Row({"col1": 5, "col3": "Hallo"}),
            r"Could not find column\(s\) 'col2'",
        ),
    ],
    ids=["unknown column col2 in row"],
)
def test_should_raise_error_if_row_column_names_invalid(table: Table, row: Row, expected_error_msg: str) -> None:
    """Adding a row that lacks one of the table's columns must raise `UnknownColumnNameError` naming that column."""
    # `pytest.raises` is the public alias of the helper otherwise imported from `_pytest.python_api`.
    with pytest.raises(UnknownColumnNameError, match=expected_error_msg):
        table.add_row(row)
Loading