diff --git a/src/safeds/data/tabular/containers/_table.py b/src/safeds/data/tabular/containers/_table.py index 2bb4c6d2b..176585411 100644 --- a/src/safeds/data/tabular/containers/_table.py +++ b/src/safeds/data/tabular/containers/_table.py @@ -797,6 +797,22 @@ def summary(self) -> Table: # Transformations # ------------------------------------------------------------------------------------------------------------------ + # This method is meant as a way to "cast" instances of subclasses of `Table` to a proper `Table`, dropping any + # additional constraints that might have to hold in the subclass. + # Override accordingly in subclasses. + def _as_table(self: Table) -> Table: + """ + Transform the table to an instance of the Table class. + + The original table is not modified. + + Returns + ------- + table: Table + The table, as an instance of the Table class. + """ + return self + def add_column(self, column: Column) -> Table: """ Return the original table with the provided column attached at the end. @@ -888,8 +904,9 @@ def add_row(self, row: Row) -> Table: """ Add a row to the table. + If the table happens to be empty beforehand, respective columns will be added automatically. + This table is not modified. - If the table happens to be empty beforehand, respective features will be added automatically. Parameters ---------- @@ -1077,6 +1094,8 @@ def keep_only_columns(self, column_names: list[str]) -> Table: ------ UnknownColumnNameError If any of the given columns does not exist. + IllegalSchemaModificationError + If removing the columns would violate an invariant in the subclass. Examples -------- @@ -1120,6 +1139,8 @@ def remove_columns(self, column_names: list[str]) -> Table: ------ UnknownColumnNameError If any of the given columns does not exist. + IllegalSchemaModificationError + If removing the columns would violate an invariant in the subclass. Examples -------- @@ -1158,6 +1179,11 @@ def remove_columns_with_missing_values(self) -> Table: table : Table A table without the columns that contain missing values. + Raises + ------ + IllegalSchemaModificationError + If removing the columns would violate an invariant in the subclass. + Examples -------- >>> from safeds.data.tabular.containers import Table @@ -1182,6 +1208,11 @@ def remove_columns_with_non_numerical_values(self) -> Table: table : Table A table without the columns that contain non-numerical values. + Raises + ------ + IllegalSchemaModificationError + If removing the columns would violate an invariant in the subclass. + Examples -------- >>> from safeds.data.tabular.containers import Table @@ -1331,7 +1362,9 @@ def rename_column(self, old_name: str, new_name: str) -> Table: def replace_column(self, old_column_name: str, new_columns: list[Column]) -> Table: """ - Return a copy of the table with the specified old column replaced by a list of new columns. Keeps the order of columns. + Return a copy of the table with the specified old column replaced by a list of new columns. + + The order of columns is kept. This table is not modified. @@ -1352,12 +1385,12 @@ def replace_column(self, old_column_name: str, new_columns: list[Column]) -> Tab ------ UnknownColumnNameError If the old column does not exist. - DuplicateColumnNameError If at least one of the new columns already exists and the existing column is not affected by the replacement. - ColumnSizeError If the size of at least one of the new columns does not match the amount of rows. + IllegalSchemaModificationError + If replacing the column would violate an invariant in the subclass. Examples -------- @@ -1475,7 +1508,7 @@ def sort_columns( """ Sort the columns of a `Table` with the given comparator and return a new `Table`. - The original table is not modified. The comparator is a function that takes two columns `col1` and `col2` and + The comparator is a function that takes two columns `col1` and `col2` and returns an integer: * If `col1` should be ordered before `col2`, the function should return a negative number. @@ -1519,7 +1552,7 @@ def sort_rows(self, comparator: Callable[[Row, Row], int]) -> Table: """ Sort the rows of a `Table` with the given comparator and return a new `Table`. - The original table is not modified. The comparator is a function that takes two rows `row1` and `row2` and + The comparator is a function that takes two rows `row1` and `row2` and returns an integer: * If `row1` should be ordered before `row2`, the function should return a negative number. @@ -1695,6 +1728,8 @@ def transform_table(self, transformer: TableTransformer) -> Table: ------ TransformerNotFittedError If the transformer has not been fitted yet. + IllegalSchemaModificationError + If replacing the column would violate an invariant in the subclass. Examples -------- diff --git a/src/safeds/data/tabular/containers/_tagged_table.py b/src/safeds/data/tabular/containers/_tagged_table.py index 33ab14d67..0f3c6e674 100644 --- a/src/safeds/data/tabular/containers/_tagged_table.py +++ b/src/safeds/data/tabular/containers/_tagged_table.py @@ -3,11 +3,11 @@ import copy from typing import TYPE_CHECKING -from safeds.data.tabular.containers import Column, Table -from safeds.exceptions import UnknownColumnNameError +from safeds.data.tabular.containers import Column, Row, Table +from safeds.exceptions import ColumnIsTargetError, IllegalSchemaModificationError, UnknownColumnNameError if TYPE_CHECKING: - from collections.abc import Mapping, Sequence + from collections.abc import Callable, Mapping, Sequence from typing import Any @@ -82,6 +82,7 @@ def _from_table( >>> table = Table({"col1": ["a", "b", "c", "a"], "col2": [1, 2, 3, 4]}) >>> tagged_table = TaggedTable._from_table(table, "col2", ["col1"]) """ + table = table._as_table() if target_name not in table.column_names: raise UnknownColumnNameError([target_name]) @@ -101,8 +102,8 @@ def _from_table( result._data = table._data result._schema = table.schema - result._features = result.keep_only_columns(feature_names) - result._target = result.get_column(target_name) + result._features = table.keep_only_columns(feature_names) + result._target = table.get_column(target_name) return result @@ -143,10 +144,11 @@ def __init__( >>> table = TaggedTable({"a": [1, 2, 3], "b": [4, 5, 6]}, "b", ["a"]) """ super().__init__(data) + _data = Table(data) # If no feature names are specified, use all columns except the target column if feature_names is None: - feature_names = self.column_names + feature_names = _data.column_names if target_name in feature_names: feature_names.remove(target_name) @@ -156,8 +158,8 @@ def __init__( if len(feature_names) == 0: raise ValueError("At least one feature column must be specified.") - self._features: Table = self.keep_only_columns(feature_names) - self._target: Column = self.get_column(target_name) + self._features: Table = _data.keep_only_columns(feature_names) + self._target: Column = _data.get_column(target_name) # ------------------------------------------------------------------------------------------------------------------ # Properties @@ -185,3 +187,630 @@ def _copy(self) -> TaggedTable: The copy of this tagged table. """ return copy.deepcopy(self) + + # ------------------------------------------------------------------------------------------------------------------ + # Specific methods from TaggedTable class: + # ------------------------------------------------------------------------------------------------------------------ + + def add_column_as_feature(self, column: Column) -> TaggedTable: + """ + Return the original table with the provided column attached at the end, as a feature column. + + This table is not modified. + + Returns + ------- + result : TaggedTable + The table with the attached feature column. + + Raises + ------ + DuplicateColumnNameError + If the new column already exists. + ColumnSizeError + If the size of the column does not match the amount of rows. + """ + return TaggedTable._from_table( + super().add_column(column), + target_name=self.target.name, + feature_names=[*self.features.column_names, column.name], + ) + + def add_columns_as_features(self, columns: list[Column] | Table) -> TaggedTable: + """ + Return the original table with the provided columns attached at the end, as feature columns. + + This table is not modified. + + Returns + ------- + result : TaggedTable + The table with the attached feature columns. + + Raises + ------ + DuplicateColumnNameError + If the new column already exists. + ColumnSizeError + If the size of the column does not match the amount of rows. + """ + return TaggedTable._from_table( + super().add_columns(columns), + target_name=self.target.name, + feature_names=self.features.column_names + + [col.name for col in (columns.to_columns() if isinstance(columns, Table) else columns)], + ) + + # ------------------------------------------------------------------------------------------------------------------ + # Overriden methods from Table class: + # ------------------------------------------------------------------------------------------------------------------ + + def _as_table(self: TaggedTable) -> Table: + """ + Remove the tagging from a TaggedTable. + + The original TaggedTable is not modified. + + Parameters + ---------- + self: TaggedTable + The TaggedTable. + + Returns + ------- + table: Table + The table as an untagged Table, i.e. without the information about which columns are features or target. + + """ + return Table.from_columns(super().to_columns()) + + def add_column(self, column: Column) -> TaggedTable: + """ + Return the original table with the provided column attached at the end, as neither target nor feature column. + + This table is not modified. + + Returns + ------- + result : TaggedTable + The table with the column attached as neither target nor feature column. + + Raises + ------ + DuplicateColumnNameError + If the new column already exists. + ColumnSizeError + If the size of the column does not match the amount of rows. + """ + return TaggedTable._from_table( + super().add_column(column), + target_name=self.target.name, + feature_names=self.features.column_names, + ) + + def add_columns(self, columns: list[Column] | Table) -> TaggedTable: + """ + Add multiple columns to the table, as neither target nor feature columns. + + This table is not modified. + + Parameters + ---------- + columns : list[Column] or Table + The columns to be added. + + Returns + ------- + result: TaggedTable + A new table combining the original table and the given columns as neither target nor feature columns. + + Raises + ------ + ColumnSizeError + If at least one of the column sizes from the provided column list does not match the table. + DuplicateColumnNameError + If at least one column name from the provided column list already exists in the table. + """ + return TaggedTable._from_table( + super().add_columns(columns), + target_name=self.target.name, + feature_names=self.features.column_names, + ) + + def add_row(self, row: Row) -> TaggedTable: + """ + Add a row to the table. + + This table is not modified. + + Parameters + ---------- + row : Row + The row to be added. + + Returns + ------- + table : TaggedTable + A new table with the added row at the end. + + Raises + ------ + SchemaMismatchError + If the schema of the row does not match the table schema. + """ + return TaggedTable._from_table(super().add_row(row), target_name=self.target.name) + + def add_rows(self, rows: list[Row] | Table) -> TaggedTable: + """ + Add multiple rows to the table. + + This table is not modified. + + Parameters + ---------- + rows : list[Row] or Table + The rows to be added. + + Returns + ------- + result : TaggedTable + A new table which combines the original table and the given rows. + + Raises + ------ + SchemaMismatchError + If the schema of on of the row does not match the table schema. + """ + return TaggedTable._from_table(super().add_rows(rows), target_name=self.target.name) + + def filter_rows(self, query: Callable[[Row], bool]) -> TaggedTable: + """ + Return a table containing only rows that match the given Callable (e.g. lambda function). + + This table is not modified. + + Parameters + ---------- + query : lambda function + A Callable that is applied to all rows. + + Returns + ------- + table : TaggedTable + A table containing only the rows to match the query. + """ + return TaggedTable._from_table( + super().filter_rows(query), + target_name=self.target.name, + feature_names=self.features.column_names, + ) + + def keep_only_columns(self, column_names: list[str]) -> TaggedTable: + """ + Return a table with only the given column(s). + + This table is not modified. + + Parameters + ---------- + column_names : list[str] + A list containing only the columns to be kept. + + Returns + ------- + table : TaggedTable + A table containing only the given column(s). + + Raises + ------ + UnknownColumnNameError + If any of the given columns does not exist. + IllegalSchemaModificationError + If none of the given columns is the target column or any of the feature columns. + """ + if self.target.name not in column_names: + raise IllegalSchemaModificationError("Must keep the target column.") + if len(set(self.features.column_names).intersection(set(column_names))) == 0: + raise IllegalSchemaModificationError("Must keep at least one feature column.") + return TaggedTable._from_table( + super().keep_only_columns(column_names), + target_name=self.target.name, + feature_names=sorted( + set(self.features.column_names).intersection(set(column_names)), + key={val: ix for ix, val in enumerate(self.features.column_names)}.__getitem__, + ), + ) + + def remove_columns(self, column_names: list[str]) -> TaggedTable: + """ + Remove the given column(s) from the table. + + This table is not modified. + + Parameters + ---------- + column_names : list[str] + The names of all columns to be dropped. + + Returns + ------- + table : TaggedTable + A table without the given columns. + + Raises + ------ + UnknownColumnNameError + If any of the given columns does not exist. + ColumnIsTargetError + If any of the given columns is the target column. + IllegalSchemaModificationError + If the given columns contain all the feature columns. + """ + if self.target.name in column_names: + raise ColumnIsTargetError(self.target.name) + if len(set(self.features.column_names) - set(column_names)) == 0: + raise IllegalSchemaModificationError("You cannot remove every feature column.") + return TaggedTable._from_table( + super().remove_columns(column_names), + target_name=self.target.name, + feature_names=sorted( + set(self.features.column_names) - set(column_names), + key={val: ix for ix, val in enumerate(self.features.column_names)}.__getitem__, + ), + ) + + def remove_columns_with_missing_values(self) -> TaggedTable: + """ + Remove every column that misses values. + + This table is not modified. + + Returns + ------- + table : TaggedTable + A table without the columns that contain missing values. + + Raises + ------ + ColumnIsTargetError + If any of the columns to be removed is the target column. + IllegalSchemaModificationError + If the columns to remove contain all the feature columns. + """ + table = super().remove_columns_with_missing_values() + if self.target.name not in table.column_names: + raise ColumnIsTargetError(self.target.name) + if len(set(self.features.column_names).intersection(set(table.column_names))) == 0: + raise IllegalSchemaModificationError("You cannot remove every feature column.") + return TaggedTable._from_table( + table, + self.target.name, + feature_names=sorted( + set(self.features.column_names).intersection(set(table.column_names)), + key={val: ix for ix, val in enumerate(self.features.column_names)}.__getitem__, + ), + ) + + def remove_columns_with_non_numerical_values(self) -> TaggedTable: + """ + Remove every column that contains non-numerical values. + + This table is not modified. + + Returns + ------- + table : TaggedTable + A table without the columns that contain non-numerical values. + + Raises + ------ + ColumnIsTargetError + If any of the columns to be removed is the target column. + IllegalSchemaModificationError + If the columns to remove contain all the feature columns. + """ + table = super().remove_columns_with_non_numerical_values() + if self.target.name not in table.column_names: + raise ColumnIsTargetError(self.target.name) + if len(set(self.features.column_names).intersection(set(table.column_names))) == 0: + raise IllegalSchemaModificationError("You cannot remove every feature column.") + return TaggedTable._from_table( + table, + self.target.name, + feature_names=sorted( + set(self.features.column_names).intersection(set(table.column_names)), + key={val: ix for ix, val in enumerate(self.features.column_names)}.__getitem__, + ), + ) + + def remove_duplicate_rows(self) -> TaggedTable: + """ + Remove all row duplicates. + + This table is not modified. + + Returns + ------- + result : TaggedTable + The table with the duplicate rows removed. + """ + return TaggedTable._from_table( + super().remove_duplicate_rows(), + target_name=self.target.name, + feature_names=self.features.column_names, + ) + + def remove_rows_with_missing_values(self) -> TaggedTable: + """ + Return a table without the rows that contain missing values. + + This table is not modified. + + Returns + ------- + table : TaggedTable + A table without the rows that contain missing values. + """ + return TaggedTable._from_table( + super().remove_rows_with_missing_values(), + target_name=self.target.name, + feature_names=self.features.column_names, + ) + + def remove_rows_with_outliers(self) -> TaggedTable: + """ + Remove all rows from the table that contain at least one outlier. + + We define an outlier as a value that has a distance of more than 3 standard deviations from the column mean. + Missing values are not considered outliers. They are also ignored during the calculation of the standard + deviation. + + This table is not modified. + + Returns + ------- + new_table : TaggedTable + A new table without rows containing outliers. + """ + return TaggedTable._from_table( + super().remove_rows_with_outliers(), + target_name=self.target.name, + feature_names=self.features.column_names, + ) + + def rename_column(self, old_name: str, new_name: str) -> TaggedTable: + """ + Rename a single column. + + This table is not modified. + + Parameters + ---------- + old_name : str + The old name of the target column + new_name : str + The new name of the target column + + Returns + ------- + table : TaggedTable + The Table with the renamed column. + + Raises + ------ + UnknownColumnNameError + If the specified old target column name does not exist. + DuplicateColumnNameError + If the specified new target column name already exists. + """ + return TaggedTable._from_table( + super().rename_column(old_name, new_name), + target_name=new_name if self.target.name == old_name else self.target.name, + feature_names=( + self.features.column_names + if old_name not in self.features.column_names + else [ + column_name if column_name != old_name else new_name for column_name in self.features.column_names + ] + ), + ) + + def replace_column(self, old_column_name: str, new_columns: list[Column]) -> TaggedTable: + """ + Replace the specified old column by a list of new columns. + + The order of columns is kept. + + If the column to be replaced is the target column, it must be replaced by exactly one column. That column becomes the new target column. + If the column to be replaced is a feature column, the new columns that replace it all become feature columns. + This table is not modified. + + Parameters + ---------- + old_column_name : str + The name of the column to be replaced. + new_columns : list[Column] + The new columns replacing the old column. + + Returns + ------- + result : TaggedTable + A table with the old column replaced by the new column. + + Raises + ------ + UnknownColumnNameError + If the old column does not exist. + DuplicateColumnNameError + If the new column already exists and the existing column is not affected by the replacement. + ColumnSizeError + If the size of the column does not match the amount of rows. + IllegalSchemaModificationError + If the target column would be removed or replaced by more than one column. + """ + if old_column_name == self.target.name: + if len(new_columns) != 1: + raise IllegalSchemaModificationError( + f'Target column "{self.target.name}" can only be replaced by exactly one new column.', + ) + else: + return TaggedTable._from_table( + super().replace_column(old_column_name, new_columns), + target_name=new_columns[0].name, + feature_names=self.features.column_names, + ) + else: + return TaggedTable._from_table( + super().replace_column(old_column_name, new_columns), + target_name=self.target.name, + feature_names=( + self.features.column_names + if old_column_name not in self.features.column_names + else self.features.column_names[: self.features.column_names.index(old_column_name)] + + [col.name for col in new_columns] + + self.features.column_names[self.features.column_names.index(old_column_name) + 1 :] + ), + ) + + def shuffle_rows(self) -> TaggedTable: + """ + Shuffle the table randomly. + + This table is not modified. + + Returns + ------- + result : TaggedTable + The shuffled Table. + """ + return TaggedTable._from_table( + super().shuffle_rows(), + target_name=self.target.name, + feature_names=self.features.column_names, + ) + + def slice_rows( + self, + start: int | None = None, + end: int | None = None, + step: int = 1, + ) -> TaggedTable: + """ + Slice a part of the table into a new table. + + This table is not modified. + + Parameters + ---------- + start : int + The first index of the range to be copied into a new table, None by default. + end : int + The last index of the range to be copied into a new table, None by default. + step : int + The step size used to iterate through the table, 1 by default. + + Returns + ------- + result : TaggedTable + The resulting table. + + Raises + ------ + IndexOutOfBoundsError + If the index is out of bounds. + """ + return TaggedTable._from_table( + super().slice_rows(start, end, step), + target_name=self.target.name, + feature_names=self.features.column_names, + ) + + def sort_columns( + self, + comparator: Callable[[Column, Column], int] = lambda col1, col2: (col1.name > col2.name) + - (col1.name < col2.name), + ) -> TaggedTable: + """ + Sort the columns of a `TaggedTable` with the given comparator and return a new `TaggedTable`. + + The comparator is a function that takes two columns `col1` and `col2` and + returns an integer: + + * If the function returns a negative number, `col1` will be ordered before `col2`. + * If the function returns a positive number, `col1` will be ordered after `col2`. + * If the function returns 0, the original order of `col1` and `col2` will be kept. + + If no comparator is given, the columns will be sorted alphabetically by their name. + + This table is not modified. + + Parameters + ---------- + comparator : Callable[[Column, Column], int] + The function used to compare two columns. + + Returns + ------- + new_table : TaggedTable + A new table with sorted columns. + """ + sorted_table = super().sort_columns(comparator) + return TaggedTable._from_table( + sorted_table, + target_name=self.target.name, + feature_names=sorted( + set(sorted_table.column_names).intersection(self.features.column_names), + key={val: ix for ix, val in enumerate(sorted_table.column_names)}.__getitem__, + ), + ) + + def sort_rows(self, comparator: Callable[[Row, Row], int]) -> TaggedTable: + """ + Sort the rows of a `TaggedTable` with the given comparator and return a new `TaggedTable`. + + The comparator is a function that takes two rows `row1` and `row2` and + returns an integer: + + * If the function returns a negative number, `row1` will be ordered before `row2`. + * If the function returns a positive number, `row1` will be ordered after `row2`. + * If the function returns 0, the original order of `row1` and `row2` will be kept. + + This table is not modified. + + Parameters + ---------- + comparator : Callable[[Row, Row], int] + The function used to compare two rows. + + Returns + ------- + new_table : TaggedTable + A new table with sorted rows. + """ + return TaggedTable._from_table( + super().sort_rows(comparator), + target_name=self.target.name, + feature_names=self.features.column_names, + ) + + def transform_column(self, name: str, transformer: Callable[[Row], Any]) -> TaggedTable: + """ + Transform provided column by calling provided transformer. + + This table is not modified. + + Returns + ------- + result : TaggedTable + The table with the transformed column. + + Raises + ------ + UnknownColumnNameError + If the column does not exist. + """ + return TaggedTable._from_table( + super().transform_column(name, transformer), + target_name=self.target.name, + feature_names=self.features.column_names, + ) diff --git a/src/safeds/data/tabular/transformation/_one_hot_encoder.py b/src/safeds/data/tabular/transformation/_one_hot_encoder.py index 1070478fe..9a198ea0e 100644 --- a/src/safeds/data/tabular/transformation/_one_hot_encoder.py +++ b/src/safeds/data/tabular/transformation/_one_hot_encoder.py @@ -103,11 +103,17 @@ def fit(self, table: Table, column_names: list[str] | None) -> OneHotEncoder: if table.number_of_rows == 0: raise ValueError("The OneHotEncoder cannot be fitted because the table contains 0 rows") - if table.keep_only_columns(column_names).remove_columns_with_non_numerical_values().number_of_columns > 0: + if ( + table._as_table() + .keep_only_columns(column_names) + .remove_columns_with_non_numerical_values() + .number_of_columns + > 0 + ): warnings.warn( ( "The columns" - f" {table.keep_only_columns(column_names).remove_columns_with_non_numerical_values().column_names} contain" + f" {table._as_table().keep_only_columns(column_names).remove_columns_with_non_numerical_values().column_names} contain" " numerical data. The OneHotEncoder is designed to encode non-numerical values into numerical" " values" ), @@ -271,7 +277,7 @@ def inverse_transform(self, transformed_table: Table) -> Table: if len(missing_columns) > 0: raise UnknownColumnNameError(missing_columns) - if transformed_table.keep_only_columns( + if transformed_table._as_table().keep_only_columns( _transformed_column_names, ).remove_columns_with_non_numerical_values().number_of_columns < len(_transformed_column_names): raise NonNumericColumnError( diff --git a/src/safeds/exceptions/__init__.py b/src/safeds/exceptions/__init__.py index 4b5fcddfd..cf9b10066 100644 --- a/src/safeds/exceptions/__init__.py +++ b/src/safeds/exceptions/__init__.py @@ -1,9 +1,11 @@ """Custom exceptions that can be raised by Safe-DS.""" from safeds.exceptions._data import ( + ColumnIsTargetError, ColumnLengthMismatchError, ColumnSizeError, DuplicateColumnNameError, + IllegalSchemaModificationError, IndexOutOfBoundsError, MissingValuesColumnError, NonNumericColumnError, @@ -25,9 +27,11 @@ __all__ = [ # Data exceptions + "ColumnIsTargetError", "ColumnLengthMismatchError", "ColumnSizeError", "DuplicateColumnNameError", + "IllegalSchemaModificationError", "IndexOutOfBoundsError", "MissingValuesColumnError", "NonNumericColumnError", diff --git a/src/safeds/exceptions/_data.py b/src/safeds/exceptions/_data.py index 845055daf..f11c7a334 100644 --- a/src/safeds/exceptions/_data.py +++ b/src/safeds/exceptions/_data.py @@ -138,3 +138,17 @@ def __init__(self, file: str | Path, file_extension: str | list[str]) -> None: f" {file_extension}" ), ) + + +class IllegalSchemaModificationError(Exception): + """Exception raised when modifying a schema in a way that is inconsistent with the subclass's requirements.""" + + def __init__(self, msg: str) -> None: + super().__init__(f"Illegal schema modification: {msg}") + + +class ColumnIsTargetError(IllegalSchemaModificationError): + """Exception raised in overriden methods of the Table class when removing tagged Columns from a TaggedTable.""" + + def __init__(self, column_name: str) -> None: + super().__init__(f'Column "{column_name}" is the target column and cannot be removed.') diff --git a/src/safeds/ml/classical/_util_sklearn.py b/src/safeds/ml/classical/_util_sklearn.py index 19c13e5b3..8cd2c0229 100644 --- a/src/safeds/ml/classical/_util_sklearn.py +++ b/src/safeds/ml/classical/_util_sklearn.py @@ -126,6 +126,8 @@ def predict(model: Any, dataset: Table, feature_names: list[str] | None, target_ missing_feature_names = [feature_name for feature_name in feature_names if not dataset.has_column(feature_name)] if missing_feature_names: raise DatasetMissesFeaturesError(missing_feature_names) + if isinstance(dataset, TaggedTable): + dataset = dataset.features # Cast to Table type, so Python will call the right methods... if dataset.number_of_rows == 0: raise DatasetMissesDataError diff --git a/tests/helpers/__init__.py b/tests/helpers/__init__.py index f7b114a0b..cbd8637a1 100644 --- a/tests/helpers/__init__.py +++ b/tests/helpers/__init__.py @@ -1,4 +1,4 @@ -from ._assertions import assert_that_tables_are_close +from ._assertions import assert_that_tables_are_close, assert_that_tagged_tables_are_equal from ._resources import resolve_resource_path -__all__ = ["assert_that_tables_are_close", "resolve_resource_path"] +__all__ = ["assert_that_tables_are_close", "assert_that_tagged_tables_are_equal", "resolve_resource_path"] diff --git a/tests/helpers/_assertions.py b/tests/helpers/_assertions.py index ecd93c1b4..2dcbcd1e0 100644 --- a/tests/helpers/_assertions.py +++ b/tests/helpers/_assertions.py @@ -1,5 +1,5 @@ import pytest -from safeds.data.tabular.containers import Table +from safeds.data.tabular.containers import Table, TaggedTable def assert_that_tables_are_close(table1: Table, table2: Table) -> None: @@ -22,3 +22,20 @@ def assert_that_tables_are_close(table1: Table, table2: Table) -> None: entry_1 = table1.get_column(column_name).get_value(i) entry_2 = table2.get_column(column_name).get_value(i) assert entry_1 == pytest.approx(entry_2) + + +def assert_that_tagged_tables_are_equal(table1: TaggedTable, table2: TaggedTable) -> None: + """ + Assert that two tagged tables are equal. + + Parameters + ---------- + table1: TaggedTable + The first table. + table2: TaggedTable + The table to compare the first table to. + """ + assert table1.schema == table2.schema + assert table1.features == table2.features + assert table1.target == table2.target + assert table1 == table2 diff --git a/tests/safeds/data/tabular/containers/_table/_tagged_table/__init__.py b/tests/safeds/data/tabular/containers/_table/_tagged_table/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/safeds/data/tabular/containers/_table/_tagged_table/test_add_column.py b/tests/safeds/data/tabular/containers/_table/_tagged_table/test_add_column.py new file mode 100644 index 000000000..22bd87605 --- /dev/null +++ b/tests/safeds/data/tabular/containers/_table/_tagged_table/test_add_column.py @@ -0,0 +1,34 @@ +import pytest +from safeds.data.tabular.containers import Column, TaggedTable + +from tests.helpers import assert_that_tagged_tables_are_equal + + +@pytest.mark.parametrize( + ("tagged_table", "column", "expected_tagged_table"), + [ + ( + TaggedTable( + { + "feature_1": [0, 1, 2], + "target": [3, 4, 5], + }, + "target", + None, + ), + Column("other", [6, 7, 8]), + TaggedTable( + { + "feature_1": [0, 1, 2], + "target": [3, 4, 5], + "other": [6, 7, 8], + }, + "target", + ["feature_1"], + ), + ), + ], + ids=["add_column_as_non_feature"], +) +def test_should_add_column(tagged_table: TaggedTable, column: Column, expected_tagged_table: TaggedTable) -> None: + assert_that_tagged_tables_are_equal(tagged_table.add_column(column), expected_tagged_table) diff --git a/tests/safeds/data/tabular/containers/_table/_tagged_table/test_add_column_as_feature.py b/tests/safeds/data/tabular/containers/_table/_tagged_table/test_add_column_as_feature.py new file mode 100644 index 000000000..325df54db --- /dev/null +++ b/tests/safeds/data/tabular/containers/_table/_tagged_table/test_add_column_as_feature.py @@ -0,0 +1,37 @@ +import pytest +from safeds.data.tabular.containers import Column, Table, TaggedTable + +from tests.helpers import assert_that_tagged_tables_are_equal + + +@pytest.mark.parametrize( + ("tagged_table", "column", "tagged_table_with_new_column"), + [ + ( + Table({"f1": [1, 2], "target": [2, 3]}).tag_columns(target_name="target", feature_names=["f1"]), + Column("f2", [4, 5]), + Table({"f1": [1, 2], "target": [2, 3], "f2": [4, 5]}).tag_columns( + target_name="target", + feature_names=["f1", "f2"], + ), + ), + ( + Table({"f1": [1, 2], "target": [2, 3], "other": [0, -1]}).tag_columns( + target_name="target", + feature_names=["f1"], + ), + Column("f2", [4, 5]), + Table({"f1": [1, 2], "target": [2, 3], "other": [0, -1], "f2": [4, 5]}).tag_columns( + target_name="target", + feature_names=["f1", "f2"], + ), + ), + ], + ids=["new column as feature", "table contains a non feature/target column"], +) +def test_add_column_as_feature( + tagged_table: TaggedTable, + column: Column, + tagged_table_with_new_column: TaggedTable, +) -> None: + assert_that_tagged_tables_are_equal(tagged_table.add_column_as_feature(column), tagged_table_with_new_column) diff --git a/tests/safeds/data/tabular/containers/_table/_tagged_table/test_add_columns.py b/tests/safeds/data/tabular/containers/_table/_tagged_table/test_add_columns.py new file mode 100644 index 000000000..8773e3695 --- /dev/null +++ b/tests/safeds/data/tabular/containers/_table/_tagged_table/test_add_columns.py @@ -0,0 +1,39 @@ +import pytest +from safeds.data.tabular.containers import Column, TaggedTable + +from tests.helpers import assert_that_tagged_tables_are_equal + + +@pytest.mark.parametrize( + ("tagged_table", "columns", "expected_tagged_table"), + [ + ( + TaggedTable( + { + "feature_1": [0, 1, 2], + "target": [3, 4, 5], + }, + "target", + None, + ), + [Column("other", [6, 7, 8]), Column("other2", [9, 6, 3])], + TaggedTable( + { + "feature_1": [0, 1, 2], + "target": [3, 4, 5], + "other": [6, 7, 8], + "other2": [9, 6, 3], + }, + "target", + ["feature_1"], + ), + ), + ], + ids=["add_columns_as_non_feature"], +) +def test_should_add_columns( + tagged_table: TaggedTable, + columns: list[Column], + expected_tagged_table: TaggedTable, +) -> None: + assert_that_tagged_tables_are_equal(tagged_table.add_columns(columns), expected_tagged_table) diff --git a/tests/safeds/data/tabular/containers/_table/_tagged_table/test_add_columns_as_features.py b/tests/safeds/data/tabular/containers/_table/_tagged_table/test_add_columns_as_features.py new file mode 100644 index 000000000..f1e7716b8 --- /dev/null +++ b/tests/safeds/data/tabular/containers/_table/_tagged_table/test_add_columns_as_features.py @@ -0,0 +1,45 @@ +import pytest +from safeds.data.tabular.containers import Column, Table, TaggedTable + +from tests.helpers import assert_that_tagged_tables_are_equal + + +@pytest.mark.parametrize( + ("tagged_table", "columns", "tagged_table_with_new_columns"), + [ + ( + Table({"f1": [1, 2], "target": [2, 3]}).tag_columns(target_name="target", feature_names=["f1"]), + [Column("f2", [4, 5]), Column("f3", [6, 7])], + Table({"f1": [1, 2], "target": [2, 3], "f2": [4, 5], "f3": [6, 7]}).tag_columns( + target_name="target", + feature_names=["f1", "f2", "f3"], + ), + ), + ( + Table({"f1": [1, 2], "target": [2, 3]}).tag_columns(target_name="target", feature_names=["f1"]), + Table.from_columns([Column("f2", [4, 5]), Column("f3", [6, 7])]), + Table({"f1": [1, 2], "target": [2, 3], "f2": [4, 5], "f3": [6, 7]}).tag_columns( + target_name="target", + feature_names=["f1", "f2", "f3"], + ), + ), + ( + Table({"f1": [1, 2], "target": [2, 3], "other": [0, -1]}).tag_columns( + target_name="target", + feature_names=["f1"], + ), + Table.from_columns([Column("f2", [4, 5]), Column("f3", [6, 7])]), + Table({"f1": [1, 2], "target": [2, 3], "other": [0, -1], "f2": [4, 5], "f3": [6, 7]}).tag_columns( + target_name="target", + feature_names=["f1", "f2", "f3"], + ), + ), + ], + ids=["new columns as feature", "table added as features", "table contains a non feature/target column"], +) +def test_add_columns_as_features( + tagged_table: TaggedTable, + columns: list[Column] | Table, + tagged_table_with_new_columns: TaggedTable, +) -> None: + assert_that_tagged_tables_are_equal(tagged_table.add_columns_as_features(columns), tagged_table_with_new_columns) diff --git a/tests/safeds/data/tabular/containers/_table/_tagged_table/test_add_row.py b/tests/safeds/data/tabular/containers/_table/_tagged_table/test_add_row.py new file mode 100644 index 000000000..2badeec11 --- /dev/null +++ b/tests/safeds/data/tabular/containers/_table/_tagged_table/test_add_row.py @@ -0,0 +1,36 @@ +import pytest +from safeds.data.tabular.containers import Row, TaggedTable + +from tests.helpers import assert_that_tagged_tables_are_equal + + +@pytest.mark.parametrize( + ("table", "row", "expected"), + [ + ( + TaggedTable( + { + "feature": [0, 1], + "target": [3, 4], + }, + "target", + ), + Row( + { + "feature": 2, + "target": 5, + }, + ), + TaggedTable( + { + "feature": [0, 1, 2], + "target": [3, 4, 5], + }, + "target", + ), + ), + ], + ids=["add_row"], +) +def test_should_add_row(table: TaggedTable, row: Row, expected: TaggedTable) -> None: + assert_that_tagged_tables_are_equal(table.add_row(row), expected) diff --git a/tests/safeds/data/tabular/containers/_table/_tagged_table/test_add_rows.py b/tests/safeds/data/tabular/containers/_table/_tagged_table/test_add_rows.py new file mode 100644 index 000000000..da8c37a5a --- /dev/null +++ b/tests/safeds/data/tabular/containers/_table/_tagged_table/test_add_rows.py @@ -0,0 +1,39 @@ +import pytest +from safeds.data.tabular.containers import Row, TaggedTable + +from tests.helpers import assert_that_tagged_tables_are_equal + + +@pytest.mark.parametrize( + ("table", "rows", "expected"), + [ + ( + TaggedTable( + { + "feature": [0, 1], + "target": [4, 5], + }, + "target", + ), + [ + Row( + { + "feature": 2, + "target": 6, + }, + ), + Row({"feature": 3, "target": 7}), + ], + TaggedTable( + { + "feature": [0, 1, 2, 3], + "target": [4, 5, 6, 7], + }, + "target", + ), + ), + ], + ids=["add_rows"], +) +def test_should_add_rows(table: TaggedTable, rows: list[Row], expected: TaggedTable) -> None: + assert_that_tagged_tables_are_equal(table.add_rows(rows), expected) diff --git a/tests/safeds/data/tabular/containers/_table/_tagged_table/test_as_table.py b/tests/safeds/data/tabular/containers/_table/_tagged_table/test_as_table.py new file mode 100644 index 000000000..6c9480671 --- /dev/null +++ b/tests/safeds/data/tabular/containers/_table/_tagged_table/test_as_table.py @@ -0,0 +1,52 @@ +import pytest +from safeds.data.tabular.containers import Table, TaggedTable + + +@pytest.mark.parametrize( + ("tagged_table", "expected"), + [ + ( + TaggedTable( + { + "feature_1": [3, 9, 6], + "feature_2": [6, 12, 9], + "target": [1, 3, 2], + }, + "target", + ["feature_1", "feature_2"], + ), + Table( + { + "feature_1": [3, 9, 6], + "feature_2": [6, 12, 9], + "target": [1, 3, 2], + }, + ), + ), + ( + TaggedTable( + { + "feature_1": [3, 9, 6], + "feature_2": [6, 12, 9], + "other": [3, 9, 12], + "target": [1, 3, 2], + }, + "target", + ["feature_1", "feature_2"], + ), + Table( + { + "feature_1": [3, 9, 6], + "feature_2": [6, 12, 9], + "other": [3, 9, 12], + "target": [1, 3, 2], + }, + ), + ), + ], + ids=["normal", "table_with_column_as_non_feature"], +) +def test_should_return_table(tagged_table: TaggedTable, expected: Table) -> None: + table = tagged_table._as_table() + assert table.schema == expected.schema + assert table == expected diff --git a/tests/safeds/data/tabular/containers/_table/_tagged_table/test_copy.py b/tests/safeds/data/tabular/containers/_table/_tagged_table/test_copy.py new file mode 100644 index 000000000..8819aff05 --- /dev/null +++ b/tests/safeds/data/tabular/containers/_table/_tagged_table/test_copy.py @@ -0,0 +1,22 @@ +import pytest +from safeds.data.tabular.containers import TaggedTable + + +@pytest.mark.parametrize( + "tagged_table", + [ + TaggedTable({"a": [], "b": []}, target_name="b", feature_names=["a"]), + TaggedTable({"a": ["a", 3, 0.1], "b": [True, False, None]}, target_name="b", feature_names=["a"]), + TaggedTable( + {"a": ["a", 3, 0.1], "b": [True, False, None], "c": ["a", "b", "c"]}, + target_name="b", + feature_names=["a"], + ), + TaggedTable({"a": [], "b": [], "c": []}, target_name="b", feature_names=["a"]), + ], + ids=["empty-rows", "normal", "column_as_non_feature", "column_as_non_feature_with_empty_rows"], +) +def test_should_copy_tagged_table(tagged_table: TaggedTable) -> None: + copied = tagged_table._copy() + assert copied == tagged_table + assert copied is not tagged_table diff --git a/tests/safeds/data/tabular/containers/_table/_tagged_table/test_features.py b/tests/safeds/data/tabular/containers/_table/_tagged_table/test_features.py new file mode 100644 index 000000000..2af13d37c --- /dev/null +++ b/tests/safeds/data/tabular/containers/_table/_tagged_table/test_features.py @@ -0,0 +1,37 @@ +import pytest +from safeds.data.tabular.containers import Table, TaggedTable + + +@pytest.mark.parametrize( + ("tagged_table", "features"), + [ + ( + TaggedTable( + { + "A": [1, 4], + "B": [2, 5], + "C": [3, 6], + "T": [0, 1], + }, + target_name="T", + ), + Table({"A": [1, 4], "B": [2, 5], "C": [3, 6]}), + ), + ( + TaggedTable( + { + "A": [1, 4], + "B": [2, 5], + "C": [3, 6], + "T": [0, 1], + }, + target_name="T", + feature_names=["A", "C"], + ), + Table({"A": [1, 4], "C": [3, 6]}), + ), + ], + ids=["only_target_and_features", "target_features_and_other"], +) +def test_should_return_features(tagged_table: TaggedTable, features: Table) -> None: + assert tagged_table.features == features diff --git a/tests/safeds/data/tabular/containers/_table/_tagged_table/test_filter_rows.py b/tests/safeds/data/tabular/containers/_table/_tagged_table/test_filter_rows.py new file mode 100644 index 000000000..98ba1bba6 --- /dev/null +++ b/tests/safeds/data/tabular/containers/_table/_tagged_table/test_filter_rows.py @@ -0,0 +1,105 @@ +from collections.abc import Callable + +import pytest +from safeds.data.tabular.containers import Row, TaggedTable + +from tests.helpers import assert_that_tagged_tables_are_equal + + +@pytest.mark.parametrize( + ("table", "expected", "query"), + [ + ( + TaggedTable( + { + "feature_1": [3, 9, 6], + "feature_2": [6, 12, 9], + "target": [1, 3, 2], + }, + "target", + ), + TaggedTable( + { + "feature_1": [3, 6], + "feature_2": [6, 9], + "target": [1, 2], + }, + "target", + ), + lambda row: all(row.get_value(col) < 10 for col in row.column_names), + ), + ( + TaggedTable( + { + "feature_1": [3, 9, 6, 2], + "feature_2": [6, 12, 9, 3], + "other": [1, 2, 3, 10], + "target": [1, 3, 2, 4], + }, + "target", + ["feature_1", "feature_2"], + ), + TaggedTable( + { + "feature_1": [3, 6], + "feature_2": [6, 9], + "other": [1, 3], + "target": [1, 2], + }, + "target", + ["feature_1", "feature_2"], + ), + lambda row: all(row.get_value(col) < 10 for col in row.column_names), + ), + ( + TaggedTable( + { + "feature_1": [3, 9, 6], + "feature_2": [6, 12, 9], + "target": [1, 3, 2], + }, + "target", + ), + TaggedTable( + { + "feature_1": [3, 9, 6], + "feature_2": [6, 12, 9], + "target": [1, 3, 2], + }, + "target", + ), + lambda row: all(row.get_value(col) < 20 for col in row.column_names), + ), + ( + TaggedTable( + { + "feature_1": [3, 9, 6, 2], + "feature_2": [6, 12, 9, 3], + "other": [1, 2, 3, 10], + "target": [1, 3, 2, 4], + }, + "target", + ["feature_1", "feature_2"], + ), + TaggedTable( + { + "feature_1": [3, 9, 6, 2], + "feature_2": [6, 12, 9, 3], + "other": [1, 2, 3, 10], + "target": [1, 3, 2, 4], + }, + "target", + ["feature_1", "feature_2"], + ), + lambda row: all(row.get_value(col) < 20 for col in row.column_names), + ), + ], + ids=[ + "remove_rows_with_values_greater_9", + "remove_rows_with_values_greater_9_non_feature_columns", + "remove_no_rows", + "remove_no_rows_non_feature_columns", + ], +) +def test_should_filter_rows(table: TaggedTable, expected: TaggedTable, query: Callable[[Row], bool]) -> None: + assert_that_tagged_tables_are_equal(table.filter_rows(query), expected) diff --git a/tests/safeds/data/tabular/containers/_table/_tagged_table/test_from_table.py b/tests/safeds/data/tabular/containers/_table/_tagged_table/test_from_table.py new file mode 100644 index 000000000..fc99ff58d --- /dev/null +++ b/tests/safeds/data/tabular/containers/_table/_tagged_table/test_from_table.py @@ -0,0 +1,145 @@ +import pytest +from safeds.data.tabular.containers import Table, TaggedTable +from safeds.exceptions import UnknownColumnNameError + + +@pytest.mark.parametrize( + ("table", "target_name", "feature_names", "error", "error_msg"), + [ + ( + Table( + { + "A": [1, 4], + "B": [2, 5], + "C": [3, 6], + "T": [0, 1], + }, + ), + "T", + ["A", "B", "C", "D", "E"], + UnknownColumnNameError, + r"Could not find column\(s\) 'D, E'", + ), + ( + Table( + { + "A": [1, 4], + "B": [2, 5], + "C": [3, 6], + "T": [0, 1], + }, + ), + "D", + ["A", "B", "C"], + UnknownColumnNameError, + r"Could not find column\(s\) 'D'", + ), + ( + Table( + { + "A": [1, 4], + "B": [2, 5], + "C": [3, 6], + "T": [0, 1], + }, + ), + "A", + ["A", "B", "C"], + ValueError, + r"Column 'A' cannot be both feature and target.", + ), + ( + Table( + { + "A": [1, 4], + "B": [2, 5], + "C": [3, 6], + "T": [0, 1], + }, + ), + "A", + [], + ValueError, + r"At least one feature column must be specified.", + ), + ( + Table( + { + "A": [1, 4], + }, + ), + "A", + None, + ValueError, + r"At least one feature column must be specified.", + ), + ], + ids=[ + "feature_does_not_exist", + "target_does_not_exist", + "target_and_feature_overlap", + "features_are_empty-explicitly", + "features_are_empty_implicitly", + ], +) +def test_should_raise_error( + table: Table, + target_name: str, + feature_names: list[str] | None, + error: type[Exception], + error_msg: str, +) -> None: + with pytest.raises(error, match=error_msg): + TaggedTable._from_table(table, target_name=target_name, feature_names=feature_names) + + +@pytest.mark.parametrize( + ("table", "target_name", "feature_names"), + [ + ( + Table( + { + "A": [1, 4], + "B": [2, 5], + "C": [3, 6], + "T": [0, 1], + }, + ), + "T", + ["A", "B", "C"], + ), + ( + Table( + { + "A": [1, 4], + "B": [2, 5], + "C": [3, 6], + "T": [0, 1], + }, + ), + "T", + ["A", "C"], + ), + ( + Table( + { + "A": [1, 4], + "B": [2, 5], + "C": [3, 6], + "T": [0, 1], + }, + ), + "T", + None, + ), + ], + ids=["create_tagged_table", "tagged_table_not_all_columns_are_features", "tagged_table_with_feature_names_as_None"], +) +def test_should_create_a_tagged_table(table: Table, target_name: str, feature_names: list[str] | None) -> None: + tagged_table = TaggedTable._from_table(table, target_name=target_name, feature_names=feature_names) + feature_names = feature_names if feature_names is not None else table.remove_columns([target_name]).column_names + assert isinstance(tagged_table, TaggedTable) + assert tagged_table._features.column_names == feature_names + assert tagged_table._target.name == target_name + assert tagged_table._features == table.keep_only_columns(feature_names) + assert tagged_table._target == table.get_column(target_name) diff --git a/tests/safeds/data/tabular/containers/_table/_tagged_table/test_init.py b/tests/safeds/data/tabular/containers/_table/_tagged_table/test_init.py new file mode 100644 index 000000000..723dfc990 --- /dev/null +++ b/tests/safeds/data/tabular/containers/_table/_tagged_table/test_init.py @@ -0,0 +1,135 @@ +import pytest +from safeds.data.tabular.containers import Table, TaggedTable +from safeds.exceptions import UnknownColumnNameError + + +@pytest.mark.parametrize( + ("data", "target_name", "feature_names", "error", "error_msg"), + [ + ( + { + "A": [1, 4], + "B": [2, 5], + "C": [3, 6], + "T": [0, 1], + }, + "T", + ["A", "B", "C", "D", "E"], + UnknownColumnNameError, + r"Could not find column\(s\) 'D, E'", + ), + ( + { + "A": [1, 4], + "B": [2, 5], + "C": [3, 6], + "T": [0, 1], + }, + "D", + ["A", "B", "C"], + UnknownColumnNameError, + r"Could not find column\(s\) 'D'", + ), + ( + { + "A": [1, 4], + "B": [2, 5], + "C": [3, 6], + "T": [0, 1], + }, + "A", + ["A", "B", "C"], + ValueError, + r"Column 'A' cannot be both feature and target.", + ), + ( + { + "A": [1, 4], + "B": [2, 5], + "C": [3, 6], + "T": [0, 1], + }, + "D", + [], + ValueError, + r"At least one feature column must be specified.", + ), + ( + { + "A": [1, 4], + }, + "A", + None, + ValueError, + r"At least one feature column must be specified.", + ), + ], + ids=[ + "feature_does_not_exist", + "target_does_not_exist", + "target_and_feature_overlap", + "features_are_empty-explicitly", + "features_are_empty_implicitly", + ], +) +def test_should_raise_error( + data: dict[str, list[int]], + target_name: str, + feature_names: list[str] | None, + error: type[Exception], + error_msg: str, +) -> None: + with pytest.raises(error, match=error_msg): + TaggedTable(data, target_name=target_name, feature_names=feature_names) + + +@pytest.mark.parametrize( + ("data", "target_name", "feature_names"), + [ + ( + { + "A": [1, 4], + "B": [2, 5], + "C": [3, 6], + "T": [0, 1], + }, + "T", + ["A", "B", "C"], + ), + ( + { + "A": [1, 4], + "B": [2, 5], + "C": [3, 6], + "T": [0, 1], + }, + "T", + ["A", "C"], + ), + ( + { + "A": [1, 4], + "B": [2, 5], + "C": [3, 6], + "T": [0, 1], + }, + "T", + None, + ), + ], + ids=["create_tagged_table", "tagged_table_not_all_columns_are_features", "tagged_table_with_feature_names_as_None"], +) +def test_should_create_a_tagged_table( + data: dict[str, list[int]], + target_name: str, + feature_names: list[str] | None, +) -> None: + tagged_table = TaggedTable(data, target_name=target_name, feature_names=feature_names) + if feature_names is None: + feature_names = list(data.keys()) + feature_names.remove(target_name) + assert isinstance(tagged_table, TaggedTable) + assert tagged_table._features.column_names == feature_names + assert tagged_table._target.name == target_name + assert tagged_table._features == Table(data).keep_only_columns(feature_names) + assert tagged_table._target == Table(data).get_column(target_name) diff --git a/tests/safeds/data/tabular/containers/_table/_tagged_table/test_keep_only_columns.py b/tests/safeds/data/tabular/containers/_table/_tagged_table/test_keep_only_columns.py new file mode 100644 index 000000000..975dca9e9 --- /dev/null +++ b/tests/safeds/data/tabular/containers/_table/_tagged_table/test_keep_only_columns.py @@ -0,0 +1,131 @@ +import pytest +from safeds.data.tabular.containers import Table, TaggedTable +from safeds.exceptions import IllegalSchemaModificationError + +from tests.helpers import assert_that_tagged_tables_are_equal + + +@pytest.mark.parametrize( + ("table", "column_names", "expected"), + [ + ( + TaggedTable._from_table( + Table( + { + "feat1": [1, 2, 3], + "feat2": [4, 5, 6], + "target": [7, 8, 9], + }, + ), + "target", + ), + ["feat1", "target"], + TaggedTable._from_table( + Table( + { + "feat1": [1, 2, 3], + "target": [7, 8, 9], + }, + ), + "target", + ), + ), + ( + TaggedTable._from_table( + Table( + { + "feat1": [1, 2, 3], + "feat2": [4, 5, 6], + "other": [3, 4, 5], + "target": [7, 8, 9], + }, + ), + "target", + ), + ["feat1", "other", "target"], + TaggedTable._from_table( + Table( + { + "feat1": [1, 2, 3], + "other": [3, 4, 5], + "target": [7, 8, 9], + }, + ), + "target", + ), + ), + ( + TaggedTable._from_table( + Table( + { + "feat1": [1, 2, 3], + "feat2": [4, 5, 6], + "other": [3, 4, 5], + "target": [7, 8, 9], + }, + ), + "target", + ), + ["feat1", "target"], + TaggedTable._from_table( + Table( + { + "feat1": [1, 2, 3], + "target": [7, 8, 9], + }, + ), + "target", + ), + ), + ], + ids=["keep_feature_and_target_column", "keep_non_feature_column", "don't_keep_non_feature_column"], +) +def test_should_return_table(table: TaggedTable, column_names: list[str], expected: TaggedTable) -> None: + new_table = table.keep_only_columns(column_names) + assert_that_tagged_tables_are_equal(new_table, expected) + + +@pytest.mark.parametrize( + ("table", "column_names", "error_msg"), + [ + ( + TaggedTable._from_table( + Table( + { + "feat1": [1, 2, 3], + "feat2": [4, 5, 6], + "other": [3, 5, 7], + "target": [7, 8, 9], + }, + ), + "target", + ["feat1", "feat2"], + ), + ["feat1", "feat2"], + r"Illegal schema modification: Must keep the target column.", + ), + ( + TaggedTable._from_table( + Table( + { + "feat1": [1, 2, 3], + "feat2": [4, 5, 6], + "other": [3, 5, 7], + "target": [7, 8, 9], + }, + ), + "target", + ["feat1", "feat2"], + ), + ["target", "other"], + r"Illegal schema modification: Must keep at least one feature column.", + ), + ], + ids=["table_remove_target", "table_remove_all_features"], +) +def test_should_raise_illegal_schema_modification(table: TaggedTable, column_names: list[str], error_msg: str) -> None: + with pytest.raises( + IllegalSchemaModificationError, + match=error_msg, + ): + table.keep_only_columns(column_names) diff --git a/tests/safeds/data/tabular/containers/_table/_tagged_table/test_remove_columns.py b/tests/safeds/data/tabular/containers/_table/_tagged_table/test_remove_columns.py new file mode 100644 index 000000000..9e8435885 --- /dev/null +++ b/tests/safeds/data/tabular/containers/_table/_tagged_table/test_remove_columns.py @@ -0,0 +1,181 @@ +import pytest +from safeds.data.tabular.containers import Table, TaggedTable +from safeds.exceptions import ColumnIsTargetError, IllegalSchemaModificationError + +from tests.helpers import assert_that_tagged_tables_are_equal + + +@pytest.mark.parametrize( + ("table", "columns", "expected"), + [ + ( + TaggedTable._from_table( + Table( + { + "feat_1": [1, 2, 3], + "feat_2": [4, 5, 6], + "non_feat_1": [2, 4, 6], + "non_feat_2": [3, 6, 9], + "target": [7, 8, 9], + }, + ), + "target", + ["feat_1", "feat_2"], + ), + ["feat_2"], + TaggedTable._from_table( + Table({"feat_1": [1, 2, 3], "non_feat_1": [2, 4, 6], "non_feat_2": [3, 6, 9], "target": [7, 8, 9]}), + "target", + ["feat_1"], + ), + ), + ( + TaggedTable._from_table( + Table( + { + "feat_1": [1, 2, 3], + "feat_2": [4, 5, 6], + "non_feat_1": [2, 4, 6], + "non_feat_2": [3, 6, 9], + "target": [7, 8, 9], + }, + ), + "target", + ["feat_1", "feat_2"], + ), + ["non_feat_2"], + TaggedTable._from_table( + Table({"feat_1": [1, 2, 3], "feat_2": [4, 5, 6], "non_feat_1": [2, 4, 6], "target": [7, 8, 9]}), + "target", + ["feat_1", "feat_2"], + ), + ), + ( + TaggedTable._from_table( + Table( + { + "feat_1": [1, 2, 3], + "feat_2": [4, 5, 6], + "non_feat_1": [2, 4, 6], + "non_feat_2": [3, 6, 9], + "target": [7, 8, 9], + }, + ), + "target", + ["feat_1", "feat_2"], + ), + ["non_feat_1", "non_feat_2"], + TaggedTable._from_table( + Table({"feat_1": [1, 2, 3], "feat_2": [4, 5, 6], "target": [7, 8, 9]}), + "target", + ["feat_1", "feat_2"], + ), + ), + ( + TaggedTable._from_table( + Table( + { + "feat_1": [1, 2, 3], + "feat_2": [4, 5, 6], + "non_feat_1": [2, 4, 6], + "non_feat_2": [3, 6, 9], + "target": [7, 8, 9], + }, + ), + "target", + ["feat_1", "feat_2"], + ), + ["feat_2", "non_feat_2"], + TaggedTable._from_table( + Table({"feat_1": [1, 2, 3], "non_feat_1": [2, 4, 6], "target": [7, 8, 9]}), + "target", + ["feat_1"], + ), + ), + ( + TaggedTable._from_table( + Table( + { + "feat_1": [1, 2, 3], + "non_feat_1": [2, 4, 6], + "target": [7, 8, 9], + }, + ), + "target", + ["feat_1"], + ), + [], + TaggedTable._from_table( + Table({"feat_1": [1, 2, 3], "non_feat_1": [2, 4, 6], "target": [7, 8, 9]}), + "target", + ["feat_1"], + ), + ), + ], + ids=[ + "remove_feature", + "remove_non_feature", + "remove_all_non_features", + "remove_some_feat_and_some_non_feat", + "remove_nothing", + ], +) +def test_should_remove_columns(table: TaggedTable, columns: list[str], expected: TaggedTable) -> None: + new_table = table.remove_columns(columns) + assert_that_tagged_tables_are_equal(new_table, expected) + + +@pytest.mark.parametrize( + ("table", "columns", "error", "error_msg"), + [ + ( + TaggedTable._from_table( + Table({"feat": [1, 2, 3], "non_feat": [1, 2, 3], "target": [4, 5, 6]}), + "target", + ["feat"], + ), + ["target"], + ColumnIsTargetError, + r'Illegal schema modification: Column "target" is the target column and cannot be removed.', + ), + ( + TaggedTable._from_table( + Table({"feat": [1, 2, 3], "non_feat": [1, 2, 3], "target": [4, 5, 6]}), + "target", + ["feat"], + ), + ["non_feat", "target"], + ColumnIsTargetError, + r'Illegal schema modification: Column "target" is the target column and cannot be removed.', + ), + ( + TaggedTable._from_table( + Table({"feat": [1, 2, 3], "non_feat": [1, 2, 3], "target": [4, 5, 6]}), + "target", + ["feat"], + ), + ["feat"], + IllegalSchemaModificationError, + r"Illegal schema modification: You cannot remove every feature column.", + ), + ( + TaggedTable._from_table( + Table({"feat": [1, 2, 3], "non_feat": [1, 2, 3], "target": [4, 5, 6]}), + "target", + ["feat"], + ), + ["feat", "non_feat"], + IllegalSchemaModificationError, + r"Illegal schema modification: You cannot remove every feature column.", + ), + ], + ids=["remove_only_target", "remove_non_feat_and_target", "remove_all_features", "remove_non_feat_and_all_features"], +) +def test_should_raise_in_remove_columns( + table: TaggedTable, + columns: list[str], + error: type[Exception], + error_msg: str, +) -> None: + with pytest.raises(error, match=error_msg): + table.remove_columns(columns) diff --git a/tests/safeds/data/tabular/containers/_table/_tagged_table/test_remove_columns_with_missing_values.py b/tests/safeds/data/tabular/containers/_table/_tagged_table/test_remove_columns_with_missing_values.py new file mode 100644 index 000000000..b442fe0d0 --- /dev/null +++ b/tests/safeds/data/tabular/containers/_table/_tagged_table/test_remove_columns_with_missing_values.py @@ -0,0 +1,181 @@ +import pytest +from safeds.data.tabular.containers import TaggedTable +from safeds.exceptions import ColumnIsTargetError, IllegalSchemaModificationError + +from tests.helpers import assert_that_tagged_tables_are_equal + + +@pytest.mark.parametrize( + ("table", "expected"), + [ + ( + TaggedTable( + { + "feature_complete": [0, 1, 2], + "feature_incomplete": [3, None, 5], + "non_feature_complete": [7, 8, 9], + "target": [3, 4, 5], + }, + "target", + ["feature_complete", "feature_incomplete"], + ), + TaggedTable( + { + "feature_complete": [0, 1, 2], + "non_feature_complete": [7, 8, 9], + "target": [3, 4, 5], + }, + "target", + ["feature_complete"], + ), + ), + ( + TaggedTable( + { + "feature_complete": [0, 1, 2], + "non_feature_complete": [7, 8, 9], + "non_feature_incomplete": [3, None, 5], + "target": [3, 4, 5], + }, + "target", + ["feature_complete"], + ), + TaggedTable( + { + "feature_complete": [0, 1, 2], + "non_feature_complete": [7, 8, 9], + "target": [3, 4, 5], + }, + "target", + ["feature_complete"], + ), + ), + ( + TaggedTable( + { + "feature_complete": [0, 1, 2], + "non_feature_complete": [7, 8, 9], + "target": [3, 4, 5], + }, + "target", + ["feature_complete"], + ), + TaggedTable( + { + "feature_complete": [0, 1, 2], + "non_feature_complete": [7, 8, 9], + "target": [3, 4, 5], + }, + "target", + ["feature_complete"], + ), + ), + ], + ids=["incomplete_feature", "incomplete_non_feature", "all_complete"], +) +def test_should_remove_columns_with_non_numerical_values(table: TaggedTable, expected: TaggedTable) -> None: + new_table = table.remove_columns_with_missing_values() + assert_that_tagged_tables_are_equal(new_table, expected) + + +@pytest.mark.parametrize( + ("table", "error", "error_msg"), + [ + ( + TaggedTable( + { + "feature": [0, 1, 2], + "non_feature": [1, 2, 3], + "target": [3, None, 5], + }, + "target", + ["feature"], + ), + ColumnIsTargetError, + 'Illegal schema modification: Column "target" is the target column and cannot be removed.', + ), + ( + TaggedTable( + { + "feature": [0, None, 2], + "non_feature": [1, 2, 3], + "target": [None, 4, 5], + }, + "target", + ["feature"], + ), + ColumnIsTargetError, + 'Illegal schema modification: Column "target" is the target column and cannot be removed.', + ), + ( + TaggedTable( + { + "feature": [0, 1, 2], + "non_feature": [1, None, 3], + "target": [3, 4, None], + }, + "target", + ["feature"], + ), + ColumnIsTargetError, + 'Illegal schema modification: Column "target" is the target column and cannot be removed.', + ), + ( + TaggedTable( + { + "feature": [0, None, 2], + "non_feature": [1, None, 3], + "target": [3, None, 5], + }, + "target", + ["feature"], + ), + ColumnIsTargetError, + 'Illegal schema modification: Column "target" is the target column and cannot be removed.', + ), + ( + TaggedTable( + { + "feature": [0, None, 2], + "non_feature": [1, 2, 3], + "target": [3, 2, 5], + }, + "target", + ["feature"], + ), + IllegalSchemaModificationError, + "Illegal schema modification: You cannot remove every feature column.", + ), + ( + TaggedTable( + { + "feature": [0, None, 2], + "non_feature": [1, None, 3], + "target": [3, 2, 5], + }, + "target", + ["feature"], + ), + IllegalSchemaModificationError, + "Illegal schema modification: You cannot remove every feature column.", + ), + ], + ids=[ + "only_target_incomplete", + "also_feature_incomplete", + "also_non_feature_incomplete", + "all_incomplete", + "all_features_incomplete", + "all_features_and_non_feature_incomplete", + ], +) +def test_should_raise_in_remove_columns_with_missing_values( + table: TaggedTable, + error: type[Exception], + error_msg: str, +) -> None: + with pytest.raises( + error, + match=error_msg, + ): + table.remove_columns_with_missing_values() diff --git a/tests/safeds/data/tabular/containers/_table/_tagged_table/test_remove_columns_with_non_numerical_values.py b/tests/safeds/data/tabular/containers/_table/_tagged_table/test_remove_columns_with_non_numerical_values.py new file mode 100644 index 000000000..5a6251d20 --- /dev/null +++ b/tests/safeds/data/tabular/containers/_table/_tagged_table/test_remove_columns_with_non_numerical_values.py @@ -0,0 +1,178 @@ +import pytest +from safeds.data.tabular.containers import TaggedTable +from safeds.exceptions import ColumnIsTargetError, IllegalSchemaModificationError + +from tests.helpers import assert_that_tagged_tables_are_equal + + +@pytest.mark.parametrize( + ("table", "expected"), + [ + ( + TaggedTable( + { + "feature_numerical": [0, 1, 2], + "feature_non_numerical": ["a", "b", "c"], + "non_feature_numerical": [7, 8, 9], + "target": [3, 4, 5], + }, + "target", + ["feature_numerical", "feature_non_numerical"], + ), + TaggedTable( + { + "feature_numerical": [0, 1, 2], + "non_feature_numerical": [7, 8, 9], + "target": [3, 4, 5], + }, + "target", + ["feature_numerical"], + ), + ), + ( + TaggedTable( + { + "feature_numerical": [0, 1, 2], + "non_feature_numerical": [7, 8, 9], + "non_feature_non_numerical": ["a", "b", "c"], + "target": [3, 4, 5], + }, + "target", + ["feature_numerical"], + ), + TaggedTable( + { + "feature_numerical": [0, 1, 2], + "non_feature_numerical": [7, 8, 9], + "target": [3, 4, 5], + }, + "target", + ["feature_numerical"], + ), + ), + ( + TaggedTable( + { + "feature_numerical": [0, 1, 2], + "non_feature_numerical": [7, 8, 9], + "target": [3, 4, 5], + }, + "target", + ["feature_numerical"], + ), + TaggedTable( + { + "feature_numerical": [0, 1, 2], + "non_feature_numerical": [7, 8, 9], + "target": [3, 4, 5], + }, + "target", + ["feature_numerical"], + ), + ), + ], + ids=["non_numerical_feature", "non_numerical_non_feature", "all_numerical"], +) +def test_should_remove_columns_with_non_numerical_values(table: TaggedTable, expected: TaggedTable) -> None: + new_table = table.remove_columns_with_non_numerical_values() + assert_that_tagged_tables_are_equal(new_table, expected) + + +@pytest.mark.parametrize( + ("table", "error", "error_msg"), + [ + ( + TaggedTable( + { + "feature": [0, 1, 2], + "non_feature": [1, 2, 3], + "target": ["a", "b", "c"], + }, + "target", + ["feature"], + ), + ColumnIsTargetError, + r'Illegal schema modification: Column "target" is the target column and cannot be removed.', + ), + ( + TaggedTable( + { + "feature": [0, "x", 2], + "non_feature": [1, 2, 3], + "target": ["a", "b", "c"], + }, + "target", + ["feature"], + ), + ColumnIsTargetError, + r'Illegal schema modification: Column "target" is the target column and cannot be removed.', + ), + ( + TaggedTable( + { + "feature": [0, 1, 2], + "non_feature": [1, "x", 3], + "target": ["a", "b", "c"], + }, + "target", + ["feature"], + ), + ColumnIsTargetError, + r'Illegal schema modification: Column "target" is the target column and cannot be removed.', + ), + ( + TaggedTable( + { + "feature": [0, "x", 2], + "non_feature": [1, "x", 3], + "target": ["a", "b", "c"], + }, + "target", + ["feature"], + ), + ColumnIsTargetError, + r'Illegal schema modification: Column "target" is the target column and cannot be removed.', + ), + ( + TaggedTable( + { + "feature": [0, "a", 2], + "non_feature": [1, 2, 3], + "target": [3, 2, 5], + }, + "target", + ["feature"], + ), + IllegalSchemaModificationError, + r"Illegal schema modification: You cannot remove every feature column.", + ), + ( + TaggedTable( + { + "feature": [0, "a", 2], + "non_feature": [1, "b", 3], + "target": [3, 2, 5], + }, + "target", + ["feature"], + ), + IllegalSchemaModificationError, + r"Illegal schema modification: You cannot remove every feature column.", + ), + ], + ids=[ + "only_target_non_numerical", + "also_feature_non_numerical", + "also_non_feature_non_numerical", + "all_non_numerical", + "all_features_non_numerical", + "all_features_and_non_feature_non_numerical", + ], +) +def test_should_raise_in_remove_columns_with_non_numerical_values( + table: TaggedTable, + error: type[Exception], + error_msg: str, +) -> None: + with pytest.raises(error, match=error_msg): + table.remove_columns_with_non_numerical_values() diff --git a/tests/safeds/data/tabular/containers/_table/_tagged_table/test_remove_duplicate_rows.py b/tests/safeds/data/tabular/containers/_table/_tagged_table/test_remove_duplicate_rows.py new file mode 100644 index 000000000..1cc6936e3 --- /dev/null +++ b/tests/safeds/data/tabular/containers/_table/_tagged_table/test_remove_duplicate_rows.py @@ -0,0 +1,47 @@ +import pytest +from safeds.data.tabular.containers import TaggedTable + +from tests.helpers import assert_that_tagged_tables_are_equal + + +@pytest.mark.parametrize( + ("table", "expected"), + [ + ( + TaggedTable( + { + "feature": [0, 0, 1], + "target": [2, 2, 3], + }, + "target", + ), + TaggedTable( + { + "feature": [0, 1], + "target": [2, 3], + }, + "target", + ), + ), + ( + TaggedTable( + { + "feature": [0, 1, 2], + "target": [2, 2, 3], + }, + "target", + ), + TaggedTable( + { + "feature": [0, 1, 2], + "target": [2, 2, 3], + }, + "target", + ), + ), + ], + ids=["with_duplicate_rows", "without_duplicate_rows"], +) +def test_should_remove_duplicate_rows(table: TaggedTable, expected: TaggedTable) -> None: + new_table = table.remove_duplicate_rows() + assert_that_tagged_tables_are_equal(new_table, expected) diff --git a/tests/safeds/data/tabular/containers/_table/_tagged_table/test_remove_rows_with_missing_values.py b/tests/safeds/data/tabular/containers/_table/_tagged_table/test_remove_rows_with_missing_values.py new file mode 100644 index 000000000..2f22f1489 --- /dev/null +++ b/tests/safeds/data/tabular/containers/_table/_tagged_table/test_remove_rows_with_missing_values.py @@ -0,0 +1,47 @@ +import pytest +from safeds.data.tabular.containers import TaggedTable + +from tests.helpers import assert_that_tagged_tables_are_equal + + +@pytest.mark.parametrize( + ("table", "expected"), + [ + ( + TaggedTable( + { + "feature": [0.0, None, 2.0], + "target": [3.0, 4.0, 5.0], + }, + "target", + ), + TaggedTable( + { + "feature": [0.0, 2.0], + "target": [3.0, 5.0], + }, + "target", + ), + ), + ( + TaggedTable( + { + "feature": [0.0, 1.0, 2.0], + "target": [3.0, 4.0, 5.0], + }, + "target", + ), + TaggedTable( + { + "feature": [0.0, 1.0, 2.0], + "target": [3.0, 4.0, 5.0], + }, + "target", + ), + ), + ], + ids=["with_missing_values", "without_missing_values"], +) +def test_should_remove_rows_with_missing_values(table: TaggedTable, expected: TaggedTable) -> None: + new_table = table.remove_rows_with_missing_values() + assert_that_tagged_tables_are_equal(new_table, expected) diff --git a/tests/safeds/data/tabular/containers/_table/_tagged_table/test_remove_rows_with_outliers.py b/tests/safeds/data/tabular/containers/_table/_tagged_table/test_remove_rows_with_outliers.py new file mode 100644 index 000000000..59a5704f0 --- /dev/null +++ b/tests/safeds/data/tabular/containers/_table/_tagged_table/test_remove_rows_with_outliers.py @@ -0,0 +1,47 @@ +import pytest +from safeds.data.tabular.containers import TaggedTable + +from tests.helpers import assert_that_tagged_tables_are_equal + + +@pytest.mark.parametrize( + ("table", "expected"), + [ + ( + TaggedTable( + { + "feature": [1.0, 11.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + "target": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + }, + "target", + ), + TaggedTable( + { + "feature": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + "target": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + }, + "target", + ), + ), + ( + TaggedTable( + { + "feature": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + "target": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + }, + "target", + ), + TaggedTable( + { + "feature": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + "target": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + }, + "target", + ), + ), + ], + ids=["with_outliers", "no_outliers"], +) +def test_should_remove_rows_with_outliers(table: TaggedTable, expected: TaggedTable) -> None: + new_table = table.remove_rows_with_outliers() + assert_that_tagged_tables_are_equal(new_table, expected) diff --git a/tests/safeds/data/tabular/containers/_table/_tagged_table/test_rename_column.py b/tests/safeds/data/tabular/containers/_table/_tagged_table/test_rename_column.py new file mode 100644 index 000000000..051c7fb90 --- /dev/null +++ b/tests/safeds/data/tabular/containers/_table/_tagged_table/test_rename_column.py @@ -0,0 +1,86 @@ +import pytest +from safeds.data.tabular.containers import TaggedTable + +from tests.helpers import assert_that_tagged_tables_are_equal + + +@pytest.mark.parametrize( + ("original_table", "old_column_name", "new_column_name", "result_table"), + [ + ( + TaggedTable( + { + "feature_old": [0, 1, 2], + "no_feature": [2, 3, 4], + "target": [3, 4, 5], + }, + target_name="target", + feature_names=["feature_old"], + ), + "feature_old", + "feature_new", + TaggedTable( + { + "feature_new": [0, 1, 2], + "no_feature": [2, 3, 4], + "target": [3, 4, 5], + }, + target_name="target", + feature_names=["feature_new"], + ), + ), + ( + TaggedTable( + { + "feature": [0, 1, 2], + "no_feature": [2, 3, 4], + "target_old": [3, 4, 5], + }, + target_name="target_old", + feature_names=["feature"], + ), + "target_old", + "target_new", + TaggedTable( + { + "feature": [0, 1, 2], + "no_feature": [2, 3, 4], + "target_new": [3, 4, 5], + }, + target_name="target_new", + feature_names=["feature"], + ), + ), + ( + TaggedTable( + { + "feature": [0, 1, 2], + "no_feature_old": [2, 3, 4], + "target": [3, 4, 5], + }, + target_name="target", + feature_names=["feature"], + ), + "no_feature_old", + "no_feature_new", + TaggedTable( + { + "feature": [0, 1, 2], + "no_feature_new": [2, 3, 4], + "target": [3, 4, 5], + }, + target_name="target", + feature_names=["feature"], + ), + ), + ], + ids=["rename_feature_column", "rename_target_column", "rename_non_feature_column"], +) +def test_should_add_column( + original_table: TaggedTable, + old_column_name: str, + new_column_name: str, + result_table: TaggedTable, +) -> None: + new_table = original_table.rename_column(old_column_name, new_column_name) + assert_that_tagged_tables_are_equal(new_table, result_table) diff --git a/tests/safeds/data/tabular/containers/_table/_tagged_table/test_replace_column.py b/tests/safeds/data/tabular/containers/_table/_tagged_table/test_replace_column.py new file mode 100644 index 000000000..72b773adc --- /dev/null +++ b/tests/safeds/data/tabular/containers/_table/_tagged_table/test_replace_column.py @@ -0,0 +1,179 @@ +import pytest +from safeds.data.tabular.containers import Column, TaggedTable +from safeds.exceptions import IllegalSchemaModificationError + +from tests.helpers import assert_that_tagged_tables_are_equal + + +@pytest.mark.parametrize( + ("original_table", "new_columns", "column_name_to_be_replaced", "result_table"), + [ + ( + TaggedTable( + { + "feature_old": [0, 1, 2], + "no_feature_old": [2, 3, 4], + "target_old": [3, 4, 5], + }, + "target_old", + ["feature_old"], + ), + [Column("feature_new", [2, 1, 0])], + "feature_old", + TaggedTable( + { + "feature_new": [2, 1, 0], + "no_feature_old": [2, 3, 4], + "target_old": [3, 4, 5], + }, + "target_old", + ["feature_new"], + ), + ), + ( + TaggedTable( + { + "feature_old": [0, 1, 2], + "no_feature_old": [2, 3, 4], + "target_old": [3, 4, 5], + }, + "target_old", + ["feature_old"], + ), + [Column("feature_new_a", [2, 1, 0]), Column("feature_new_b", [4, 2, 0])], + "feature_old", + TaggedTable( + { + "feature_new_a": [2, 1, 0], + "feature_new_b": [4, 2, 0], + "no_feature_old": [2, 3, 4], + "target_old": [3, 4, 5], + }, + "target_old", + ["feature_new_a", "feature_new_b"], + ), + ), + ( + TaggedTable( + { + "feature_old": [0, 1, 2], + "no_feature_old": [2, 3, 4], + "target_old": [3, 4, 5], + }, + "target_old", + ["feature_old"], + ), + [Column("no_feature_new", [2, 1, 0])], + "no_feature_old", + TaggedTable( + { + "feature_old": [0, 1, 2], + "no_feature_new": [2, 1, 0], + "target_old": [3, 4, 5], + }, + "target_old", + ["feature_old"], + ), + ), + ( + TaggedTable( + { + "feature_old": [0, 1, 2], + "no_feature_old": [2, 3, 4], + "target_old": [3, 4, 5], + }, + "target_old", + ["feature_old"], + ), + [Column("no_feature_new_a", [2, 1, 0]), Column("no_feature_new_b", [4, 2, 0])], + "no_feature_old", + TaggedTable( + { + "feature_old": [0, 1, 2], + "no_feature_new_a": [2, 1, 0], + "no_feature_new_b": [4, 2, 0], + "target_old": [3, 4, 5], + }, + "target_old", + ["feature_old"], + ), + ), + ( + TaggedTable( + { + "feature_old": [0, 1, 2], + "no_feature_old": [2, 3, 4], + "target_old": [3, 4, 5], + }, + "target_old", + ["feature_old"], + ), + [Column("target_new", [2, 1, 0])], + "target_old", + TaggedTable( + { + "feature_old": [0, 1, 2], + "no_feature_old": [2, 3, 4], + "target_new": [2, 1, 0], + }, + "target_new", + ["feature_old"], + ), + ), + ], + ids=[ + "replace_feature_column_with_one", + "replace_feature_column_with_multiple", + "replace_non_feature_column_with_one", + "replace_non_feature_column_with_multiple", + "replace_target_column", + ], +) +def test_should_replace_column( + original_table: TaggedTable, + new_columns: list[Column], + column_name_to_be_replaced: str, + result_table: TaggedTable, +) -> None: + new_table = original_table.replace_column(column_name_to_be_replaced, new_columns) + assert_that_tagged_tables_are_equal(new_table, result_table) + + +@pytest.mark.parametrize( + ("original_table", "new_columns", "column_name_to_be_replaced"), + [ + ( + TaggedTable( + { + "feature_old": [0, 1, 2], + "target_old": [3, 4, 5], + }, + "target_old", + ), + [], + "target_old", + ), + ( + TaggedTable( + { + "feature_old": [0, 1, 2], + "target_old": [3, 4, 5], + }, + "target_old", + ), + [Column("target_new_a", [2, 1, 0]), Column("target_new_b"), [4, 2, 0]], + "target_old", + ), + ], + ids=["zero_columns", "multiple_columns"], +) +def test_should_throw_illegal_schema_modification( + original_table: TaggedTable, + new_columns: list[Column], + column_name_to_be_replaced: str, +) -> None: + with pytest.raises( + IllegalSchemaModificationError, + match='Target column "target_old" can only be replaced by exactly one new column.', + ): + original_table.replace_column(column_name_to_be_replaced, new_columns) diff --git a/tests/safeds/data/tabular/containers/_table/_tagged_table/test_shuffle_rows.py b/tests/safeds/data/tabular/containers/_table/_tagged_table/test_shuffle_rows.py new file mode 100644 index 000000000..767a6c14f --- /dev/null +++ b/tests/safeds/data/tabular/containers/_table/_tagged_table/test_shuffle_rows.py @@ -0,0 +1,37 @@ +import pytest +from safeds.data.tabular.containers import Row, Table, TaggedTable + +from tests.helpers import assert_that_tagged_tables_are_equal + + +@pytest.mark.parametrize( + ("rows", "target_name", "feature_names"), + [ + ( + [ + Row({"feature_a": 0, "feature_b": 3, "no_feature": 6, "target": 9}), + Row({"feature_a": 1, "feature_b": 4, "no_feature": 7, "target": 10}), + Row({"feature_a": 2, "feature_b": 5, "no_feature": 8, "target": 11}), + ], + "target", + ["feature_a", "feature_b"], + ), + ], + ids=["table"], +) +def test_should_shuffle_rows(rows: list[Row], target_name: str, feature_names: list[str]) -> None: + table = TaggedTable._from_table(Table.from_rows(rows), target_name=target_name, feature_names=feature_names) + shuffled = table.shuffle_rows() + assert table.schema == shuffled.schema + assert table.features.column_names == shuffled.features.column_names + assert table.target.name == shuffled.target.name + + # Check that shuffled contains the original rows: + for i in range(table.number_of_rows): + assert shuffled.get_row(i) in rows + + # Assert that table and shuffled are equal again after sorting: + def comparator(r1: Row, r2: Row) -> int: + return 1 if r1.__repr__() < r2.__repr__() else -1 + + assert_that_tagged_tables_are_equal(table.sort_rows(comparator), shuffled.sort_rows(comparator)) diff --git a/tests/safeds/data/tabular/containers/_table/_tagged_table/test_slice_rows.py b/tests/safeds/data/tabular/containers/_table/_tagged_table/test_slice_rows.py new file mode 100644 index 000000000..c4dcfeee2 --- /dev/null +++ b/tests/safeds/data/tabular/containers/_table/_tagged_table/test_slice_rows.py @@ -0,0 +1,55 @@ +import pytest +from _pytest.python_api import raises +from safeds.data.tabular.containers import TaggedTable +from safeds.exceptions import IndexOutOfBoundsError + +from tests.helpers import assert_that_tagged_tables_are_equal + + +@pytest.mark.parametrize( + ("table", "test_table", "second_test_table"), + [ + ( + TaggedTable( + {"feature": [1, 2, 1], "non_feature": [0, 2, 4], "target": [1, 2, 4]}, + target_name="target", + feature_names=["non_feature"], + ), + TaggedTable( + {"feature": [1, 2], "non_feature": [0, 2], "target": [1, 2]}, + target_name="target", + feature_names=["non_feature"], + ), + TaggedTable( + {"feature": [1, 1], "non_feature": [0, 4], "target": [1, 4]}, + target_name="target", + feature_names=["non_feature"], + ), + ), + ], + ids=["Table with three rows"], +) +def test_should_slice_rows(table: TaggedTable, test_table: TaggedTable, second_test_table: TaggedTable) -> None: + new_table = table.slice_rows(0, 2, 1) + second_new_table = table.slice_rows(0, 3, 2) + third_new_table = table.slice_rows() + assert_that_tagged_tables_are_equal(new_table, test_table) + assert_that_tagged_tables_are_equal(second_new_table, second_test_table) + assert_that_tagged_tables_are_equal(third_new_table, table) + + +@pytest.mark.parametrize( + ("start", "end", "step", "error_message"), + [ + (3, 2, 1, r"There is no element in the range \[3, 2\]"), + (4, 0, 1, r"There is no element in the range \[4, 0\]"), + (0, 4, 1, r"There is no element at index '4'"), + (-4, 0, 1, r"There is no element at index '-4'"), + (0, -4, 1, r"There is no element in the range \[0, -4\]"), + ], +) +def test_should_raise_if_index_out_of_bounds(start: int, end: int, step: int, error_message: str) -> None: + table = TaggedTable({"feature": [1, 2, 1], "target": [1, 2, 4]}, "target") + + with raises(IndexOutOfBoundsError, match=error_message): + table.slice_rows(start, end, step) diff --git a/tests/safeds/data/tabular/containers/_table/_tagged_table/test_sort_columns.py b/tests/safeds/data/tabular/containers/_table/_tagged_table/test_sort_columns.py new file mode 100644 index 000000000..4ecdb78a7 --- /dev/null +++ b/tests/safeds/data/tabular/containers/_table/_tagged_table/test_sort_columns.py @@ -0,0 +1,55 @@ +from collections.abc import Callable + +import pytest +from safeds.data.tabular.containers import Column, TaggedTable + + +@pytest.mark.parametrize( + ("query", "col1", "col2", "col3", "col4"), + [ + (None, 0, 1, 2, 3), + ( + lambda col1, col2: (col1.name < col2.name) - (col1.name > col2.name), + 3, + 2, + 1, + 0, + ), + ], + ids=["no query", "with query"], +) +def test_should_return_sorted_table( + query: Callable[[Column, Column], int], + col1: int, + col2: int, + col3: int, + col4: int, +) -> None: + columns = [ + Column("col1", ["A", "B", "C", "A", "D"]), + Column("col2", ["Test1", "Test1", "Test3", "Test1", "Test4"]), + Column("col3", [1, 2, 3, 4, 5]), + Column("col4", [2, 3, 1, 4, 6]), + ] + table1 = TaggedTable( + { + "col2": ["Test1", "Test1", "Test3", "Test1", "Test4"], + "col3": [1, 2, 3, 4, 5], + "col4": [2, 3, 1, 4, 6], + "col1": ["A", "B", "C", "A", "D"], + }, + target_name="col1", + feature_names=["col4", "col3"], + ) + if query is not None: + table_sorted = table1.sort_columns(query) + else: + table_sorted = table1.sort_columns() + table_sorted_columns = table_sorted.to_columns() + assert table_sorted.schema == table1.schema + assert table_sorted_columns[0] == columns[col1] + assert table_sorted_columns[1] == columns[col2] + assert table_sorted_columns[2] == columns[col3] + assert table_sorted_columns[3] == columns[col4] + assert table_sorted.features == table1.features + assert table_sorted.target == table1.target diff --git a/tests/safeds/data/tabular/containers/_table/_tagged_table/test_sort_rows.py b/tests/safeds/data/tabular/containers/_table/_tagged_table/test_sort_rows.py new file mode 100644 index 000000000..1cd86c6b4 --- /dev/null +++ b/tests/safeds/data/tabular/containers/_table/_tagged_table/test_sort_rows.py @@ -0,0 +1,56 @@ +from collections.abc import Callable + +import pytest +from safeds.data.tabular.containers import Row, TaggedTable + +from tests.helpers import assert_that_tagged_tables_are_equal + + +@pytest.mark.parametrize( + ("table", "comparator", "expected"), + [ + ( + TaggedTable({"feature": [3, 2, 1], "non_feature": [1, 1, 1], "target": [0, 0, 0]}, "target"), + lambda row1, row2: row1["feature"] - row2["feature"], + TaggedTable({"feature": [1, 2, 3], "non_feature": [1, 1, 1], "target": [0, 0, 0]}, "target"), + ), + ( + TaggedTable({"feature": [1, 2, 3], "non_feature": [1, 1, 1], "target": [0, 0, 0]}, "target"), + lambda row1, row2: row1["feature"] - row2["feature"], + TaggedTable({"feature": [1, 2, 3], "non_feature": [1, 1, 1], "target": [0, 0, 0]}, "target"), + ), + ], + ids=["unsorted", "already_sorted"], +) +def test_should_sort_table( + table: TaggedTable, + comparator: Callable[[Row, Row], int], + expected: TaggedTable, +) -> None: + table_sorted = table.sort_rows(comparator) + assert_that_tagged_tables_are_equal(table_sorted, expected) + + +@pytest.mark.parametrize( + ("table", "comparator", "table_copy"), + [ + ( + TaggedTable({"feature": [3, 2, 1], "non_feature": [1, 1, 1], "target": [0, 0, 0]}, "target"), + lambda row1, row2: row1["feature"] - row2["feature"], + TaggedTable({"feature": [3, 2, 1], "non_feature": [1, 1, 1], "target": [0, 0, 0]}, "target"), + ), + ( + TaggedTable({"feature": [1, 2, 3], "non_feature": [1, 1, 1], "target": [0, 0, 0]}, "target"), + lambda row1, row2: row1["feature"] - row2["feature"], + TaggedTable({"feature": [1, 2, 3], "non_feature": [1, 1, 1], "target": [0, 0, 0]}, "target"), + ), + ], + ids=["unsorted", "already_sorted"], +) +def test_should_not_modify_original_table( + table: TaggedTable, + comparator: Callable[[Row, Row], int], + table_copy: TaggedTable, +) -> None: + table.sort_rows(comparator) + assert_that_tagged_tables_are_equal(table, table_copy) diff --git a/tests/safeds/data/tabular/containers/_table/_tagged_table/test_target.py b/tests/safeds/data/tabular/containers/_table/_tagged_table/test_target.py new file mode 100644 index 000000000..755721123 --- /dev/null +++ b/tests/safeds/data/tabular/containers/_table/_tagged_table/test_target.py @@ -0,0 +1,24 @@ +import pytest +from safeds.data.tabular.containers import Column, TaggedTable + + +@pytest.mark.parametrize( + ("tagged_table", "target_column"), + [ + ( + TaggedTable( + { + "A": [1, 4], + "B": [2, 5], + "C": [3, 6], + "T": [0, 1], + }, + target_name="T", + ), + Column("T", [0, 1]), + ), + ], + ids=["target"], +) +def test_should_return_target(tagged_table: TaggedTable, target_column: Column) -> None: + assert tagged_table.target == target_column diff --git a/tests/safeds/data/tabular/containers/_table/_tagged_table/test_transform_column.py b/tests/safeds/data/tabular/containers/_table/_tagged_table/test_transform_column.py new file mode 100644 index 000000000..efcc6bc1c --- /dev/null +++ b/tests/safeds/data/tabular/containers/_table/_tagged_table/test_transform_column.py @@ -0,0 +1,73 @@ +import pytest +from safeds.data.tabular.containers import TaggedTable +from safeds.exceptions import UnknownColumnNameError + +from tests.helpers import assert_that_tagged_tables_are_equal + + +@pytest.mark.parametrize( + ("table", "column_name", "table_transformed"), + [ + ( + TaggedTable({"feature_a": [1, 2, 3], "feature_b": [4, 5, 6], "target": [1, 2, 3]}, "target"), + "feature_a", + TaggedTable({"feature_a": [2, 4, 6], "feature_b": [4, 5, 6], "target": [1, 2, 3]}, "target"), + ), + ( + TaggedTable({"feature_a": [1, 2, 3], "feature_b": [4, 5, 6], "target": [1, 2, 3]}, "target"), + "target", + TaggedTable({"feature_a": [1, 2, 3], "feature_b": [4, 5, 6], "target": [2, 4, 6]}, "target"), + ), + ( + TaggedTable( + {"feature_a": [1, 2, 3], "b": [4, 5, 6], "target": [1, 2, 3]}, + target_name="target", + feature_names=["feature_a"], + ), + "b", + TaggedTable( + {"feature_a": [1, 2, 3], "b": [8, 10, 12], "target": [1, 2, 3]}, + target_name="target", + feature_names=["feature_a"], + ), + ), + ], + ids=["transform_feature_column", "transform_target_column", "transform_column_that_is_neither"], +) +def test_should_transform_column(table: TaggedTable, column_name: str, table_transformed: TaggedTable) -> None: + result = table.transform_column(column_name, lambda row: row.get_value(column_name) * 2) + assert_that_tagged_tables_are_equal(result, table_transformed) + + +@pytest.mark.parametrize( + ("table", "column_name"), + [ + ( + TaggedTable( + { + "A": [1, 2, 3], + "B": [4, 5, 6], + "C": ["a", "b", "c"], + }, + "C", + ), + "D", + ), + ( + TaggedTable( + { + "A": [1, 2, 3], + "B": [4, 5, 6], + "C": ["a", "b", "c"], + }, + target_name="C", + feature_names=["A"], + ), + "D", + ), + ], + ids=["has_only_features_and_target", "has_columns_that_are_neither"], +) +def test_should_raise_if_column_not_found(table: TaggedTable, column_name: str) -> None: + with pytest.raises(UnknownColumnNameError, match=rf"Could not find column\(s\) '{column_name}'"): + table.transform_column(column_name, lambda row: row.get_value("A") * 2) diff --git a/tests/safeds/data/tabular/containers/_table/test_as_table.py b/tests/safeds/data/tabular/containers/_table/test_as_table.py new file mode 100644 index 000000000..7f72e76e8 --- /dev/null +++ b/tests/safeds/data/tabular/containers/_table/test_as_table.py @@ -0,0 +1,9 @@ +import pytest +from safeds.data.tabular.containers import Table + + +@pytest.mark.parametrize("table", [Table({"col1": [1, 2], "col2:": [3, 4]}), Table()], ids=["table", "empty"]) +def should_return_table(table: Table) -> None: + new_table = table._as_table() + assert table.schema == new_table.schema + assert table == new_table diff --git a/tests/safeds/data/tabular/containers/test_tagged_table.py b/tests/safeds/data/tabular/containers/test_tagged_table.py deleted file mode 100644 index e7636c0f3..000000000 --- a/tests/safeds/data/tabular/containers/test_tagged_table.py +++ /dev/null @@ -1,102 +0,0 @@ -import pytest -from safeds.data.tabular.containers import Column, Table, TaggedTable -from safeds.exceptions import UnknownColumnNameError - - -@pytest.fixture() -def data() -> dict[str, list[int]]: - return { - "A": [1, 4], - "B": [2, 5], - "C": [3, 6], - "T": [0, 1], - } - - -@pytest.fixture() -def table(data: dict[str, list[int]]) -> Table: - return Table(data) - - -@pytest.fixture() -def tagged_table(table: Table) -> TaggedTable: - return table.tag_columns(target_name="T") - - -class TestFromTable: - def test_should_raise_if_a_feature_does_not_exist(self, table: Table) -> None: - with pytest.raises(UnknownColumnNameError): - TaggedTable._from_table(table, target_name="T", feature_names=["A", "B", "C", "D"]) - - def test_should_raise_if_target_does_not_exist(self, table: Table) -> None: - with pytest.raises(UnknownColumnNameError): - TaggedTable._from_table(table, target_name="D") - - def test_should_raise_if_features_and_target_overlap(self, table: Table) -> None: - with pytest.raises(ValueError, match="Column 'A' cannot be both feature and target."): - TaggedTable._from_table(table, target_name="A", feature_names=["A", "B", "C"]) - - def test_should_raise_if_features_are_empty_explicitly(self, table: Table) -> None: - with pytest.raises(ValueError, match="At least one feature column must be specified."): - TaggedTable._from_table(table, target_name="A", feature_names=[]) - - def test_should_raise_if_features_are_empty_implicitly(self) -> None: - table = Table({"A": [1, 4]}) - - with pytest.raises(ValueError, match="At least one feature column must be specified."): - TaggedTable._from_table(table, target_name="A") - - -class TestInit: - def test_should_raise_if_a_feature_does_not_exist(self, data: dict[str, list[int]]) -> None: - with pytest.raises(UnknownColumnNameError): - TaggedTable(data, target_name="T", feature_names=["A", "B", "C", "D"]) - - def test_should_raise_if_target_does_not_exist(self, data: dict[str, list[int]]) -> None: - with pytest.raises(UnknownColumnNameError): - TaggedTable(data, target_name="D") - - def test_should_raise_if_features_and_target_overlap(self, data: dict[str, list[int]]) -> None: - with pytest.raises(ValueError, match="Column 'A' cannot be both feature and target."): - TaggedTable(data, target_name="A", feature_names=["A", "B", "C"]) - - def test_should_raise_if_features_are_empty_explicitly(self, data: dict[str, list[int]]) -> None: - with pytest.raises(ValueError, match="At least one feature column must be specified."): - TaggedTable(data, target_name="A", feature_names=[]) - - def test_should_raise_if_features_are_empty_implicitly(self) -> None: - data = {"A": [1, 4]} - - with pytest.raises(ValueError, match="At least one feature column must be specified."): - TaggedTable(data, target_name="A") - - -class TestFeatures: - def test_should_return_features(self, tagged_table: TaggedTable) -> None: - assert tagged_table.features == Table( - { - "A": [1, 4], - "B": [2, 5], - "C": [3, 6], - }, - ) - - -class TestTarget: - def test_should_return_target(self, tagged_table: TaggedTable) -> None: - assert tagged_table.target == Column("T", [0, 1]) - - -class TestCopy: - @pytest.mark.parametrize( - "tagged_table", - [ - TaggedTable({"a": [], "b": []}, target_name="b", feature_names=["a"]), - TaggedTable({"a": ["a", 3, 0.1], "b": [True, False, None]}, target_name="b", feature_names=["a"]), - ], - ids=["empty-rows", "normal"], - ) - def test_should_copy_tagged_table(self, tagged_table: TaggedTable) -> None: - copied = tagged_table._copy() - assert copied == tagged_table - assert copied is not tagged_table diff --git a/tests/safeds/ml/classical/classification/test_classifier.py b/tests/safeds/ml/classical/classification/test_classifier.py index d568c8cfd..714fc8ed2 100644 --- a/tests/safeds/ml/classical/classification/test_classifier.py +++ b/tests/safeds/ml/classical/classification/test_classifier.py @@ -169,8 +169,8 @@ def test_should_include_features_of_input_table(self, classifier: Classifier, va def test_should_include_complete_input_table(self, classifier: Classifier, valid_data: TaggedTable) -> None: fitted_regressor = classifier.fit(valid_data) - prediction = fitted_regressor.predict(valid_data.remove_columns(["target"])) - assert prediction.remove_columns(["target"]) == valid_data.remove_columns(["target"]) + prediction = fitted_regressor.predict(valid_data.features) + assert prediction.features == valid_data.features def test_should_set_correct_target_name(self, classifier: Classifier, valid_data: TaggedTable) -> None: fitted_classifier = classifier.fit(valid_data) @@ -196,7 +196,7 @@ def test_should_raise_if_dataset_contains_target(self, classifier: Classifier, v def test_should_raise_if_dataset_misses_features(self, classifier: Classifier, valid_data: TaggedTable) -> None: fitted_classifier = classifier.fit(valid_data) with pytest.raises(DatasetMissesFeaturesError, match="[feat1, feat2]"): - fitted_classifier.predict(valid_data.remove_columns(["feat1", "feat2", "target"])) + fitted_classifier.predict(valid_data.features.remove_columns(["feat1", "feat2"])) @pytest.mark.parametrize( ("invalid_data", "expected_error", "expected_error_msg"), diff --git a/tests/safeds/ml/classical/regression/test_regressor.py b/tests/safeds/ml/classical/regression/test_regressor.py index 5f75a399a..0191ffb32 100644 --- a/tests/safeds/ml/classical/regression/test_regressor.py +++ b/tests/safeds/ml/classical/regression/test_regressor.py @@ -170,8 +170,8 @@ def test_should_include_features_of_input_table(self, regressor: Regressor, vali def test_should_include_complete_input_table(self, regressor: Regressor, valid_data: TaggedTable) -> None: fitted_regressor = regressor.fit(valid_data) - prediction = fitted_regressor.predict(valid_data.remove_columns(["target"])) - assert prediction.remove_columns(["target"]) == valid_data.remove_columns(["target"]) + prediction = fitted_regressor.predict(valid_data.features) + assert prediction.features == valid_data.features def test_should_set_correct_target_name(self, regressor: Regressor, valid_data: TaggedTable) -> None: fitted_regressor = regressor.fit(valid_data) @@ -197,7 +197,7 @@ def test_should_raise_if_dataset_contains_target(self, regressor: Regressor, val def test_should_raise_if_dataset_misses_features(self, regressor: Regressor, valid_data: TaggedTable) -> None: fitted_regressor = regressor.fit(valid_data) with pytest.raises(DatasetMissesFeaturesError, match="[feat1, feat2]"): - fitted_regressor.predict(valid_data.remove_columns(["feat1", "feat2", "target"])) + fitted_regressor.predict(valid_data.features.remove_columns(["feat1", "feat2"])) @pytest.mark.parametrize( ("invalid_data", "expected_error", "expected_error_msg"),