diff --git a/src/safeds/_validation/_check_columns_are_numeric.py b/src/safeds/_validation/_check_columns_are_numeric.py index 0dae84220..84ef66fec 100644 --- a/src/safeds/_validation/_check_columns_are_numeric.py +++ b/src/safeds/_validation/_check_columns_are_numeric.py @@ -7,10 +7,35 @@ if TYPE_CHECKING: from collections.abc import Container - from safeds.data.tabular.containers import Table + from safeds.data.tabular.containers import Column, Table from safeds.data.tabular.typing import Schema +def _check_column_is_numeric( + column: Column, + *, + operation: str = "do a numeric operation", +) -> None: + """ + Check if the column is numeric and raise an error if it is not. + + Parameters + ---------- + column: + The column to check. + operation: + The operation that is performed on the column. This is used in the error message. + + Raises + ------ + ColumnTypeError + If the column is not numeric. + """ + if not column.type.is_numeric: + message = _build_error_message([column.name], operation) + raise ColumnTypeError(message) + + def _check_columns_are_numeric( table_or_schema: Table | Schema, column_names: str | list[str], diff --git a/src/safeds/data/tabular/containers/_column.py b/src/safeds/data/tabular/containers/_column.py index ec514bea2..9a013cbf7 100644 --- a/src/safeds/data/tabular/containers/_column.py +++ b/src/safeds/data/tabular/containers/_column.py @@ -4,13 +4,13 @@ from typing import TYPE_CHECKING, Any, Literal, TypeVar, overload from safeds._utils import _structural_hash +from safeds._validation._check_columns_are_numeric import _check_column_is_numeric from safeds.data.tabular.plotting import ColumnPlotter from safeds.data.tabular.typing._polars_data_type import _PolarsDataType from safeds.exceptions import ( ColumnLengthMismatchError, IndexOutOfBoundsError, MissingValuesColumnError, - NonNumericColumnError, ) from ._lazy_cell import _LazyCell @@ -223,7 +223,7 @@ def get_distinct_values( def get_value(self, index: int) -> T_co: """ - Return the column value at specified index. + Return the column value at specified index. Equivalent to the `[]` operator (indexed access). Nonnegative indices are counted from the beginning (starting at 0), negative indices from the end (starting at -1). @@ -249,6 +249,9 @@ def get_value(self, index: int) -> T_co: >>> column = Column("test", [1, 2, 3]) >>> column.get_value(1) 2 + + >>> column[1] + 2 """ if index < -self.row_count or index >= self.row_count: raise IndexOutOfBoundsError(index) @@ -434,7 +437,7 @@ def count_if( """ Return how many values in the column satisfy the predicate. - The predicate can return one of three values: + The predicate can return one of three results: * True, if the value satisfies the predicate. * False, if the value does not satisfy the predicate. @@ -458,11 +461,6 @@ def count_if( count: The number of values in the column that satisfy the predicate. - Raises - ------ - TypeError - If the predicate does not return a boolean cell. - Examples -------- >>> from safeds.data.tabular.containers import Column @@ -764,8 +762,9 @@ def correlation_with(self, other: Column) -> float: """ import polars as pl - if not self.is_numeric or not other.is_numeric: - raise NonNumericColumnError("") # TODO: Add column names to error message + _check_column_is_numeric(self, operation="calculate the correlation") + _check_column_is_numeric(other, operation="calculate the correlation") + if self.row_count != other.row_count: raise ColumnLengthMismatchError("") # TODO: Add column names to error message if self.missing_value_count() > 0 or other.missing_value_count() > 0: @@ -881,8 +880,7 @@ def mean(self) -> T_co: >>> column.mean() 2.0 """ - if not self.is_numeric: - raise NonNumericColumnError("") # TODO: Add column name to error message + _check_column_is_numeric(self, operation="calculate the mean") return self._series.mean() @@ -910,8 +908,7 @@ def median(self) -> T_co: >>> column.median() 2.0 """ - if not self.is_numeric: - raise NonNumericColumnError("") # TODO: Add column name to error message + _check_column_is_numeric(self, operation="calculate the median") return self._series.median() @@ -1087,8 +1084,7 @@ def standard_deviation(self) -> float: >>> column.standard_deviation() 1.0 """ - if not self.is_numeric: - raise NonNumericColumnError("") # TODO: Add column name to error message + _check_column_is_numeric(self, operation="calculate the standard deviation") return self._series.std() @@ -1116,8 +1112,7 @@ def variance(self) -> float: >>> column.variance() 1.0 """ - if not self.is_numeric: - raise NonNumericColumnError("") # TODO: Add column name to error message + _check_column_is_numeric(self, operation="calculate the variance") return self._series.var() diff --git a/src/safeds/data/tabular/containers/_row.py b/src/safeds/data/tabular/containers/_row.py index 3d818c010..b21aed44e 100644 --- a/src/safeds/data/tabular/containers/_row.py +++ b/src/safeds/data/tabular/containers/_row.py @@ -68,7 +68,7 @@ def schema(self) -> Schema: @abstractmethod def get_value(self, name: str) -> Cell: """ - Get the value of the specified column. + Get the value of the specified column. This is equivalent to using the `[]` operator (indexed access). Parameters ---------- @@ -84,6 +84,29 @@ def get_value(self, name: str) -> Cell: ------ ColumnNotFoundError If the column name does not exist. + + Examples + -------- + >>> from safeds.data.tabular.containers import Table + >>> table = Table({"col1": [1, 2], "col2": [3, 4]}) + >>> table.remove_rows(lambda row: row.get_value("col1") == 1) + +------+------+ + | col1 | col2 | + | --- | --- | + | i64 | i64 | + +=============+ + | 2 | 4 | + +------+------+ + + + >>> table.remove_rows(lambda row: row["col1"] == 1) + +------+------+ + | col1 | col2 | + | --- | --- | + | i64 | i64 | + +=============+ + | 2 | 4 | + +------+------+ """ @abstractmethod diff --git a/src/safeds/data/tabular/containers/_table.py b/src/safeds/data/tabular/containers/_table.py index bf6998d71..476791915 100644 --- a/src/safeds/data/tabular/containers/_table.py +++ b/src/safeds/data/tabular/containers/_table.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Literal +from typing import TYPE_CHECKING, Any, Literal, overload from safeds._config import _get_device, _init_default_device from safeds._config._polars import _get_polars_config @@ -1008,6 +1008,73 @@ def transform_column( # Row operations # ------------------------------------------------------------------------------------------------------------------ + @overload + def count_row_if( + self, + predicate: Callable[[Row], Cell[bool | None]], + *, + ignore_unknown: Literal[True] = ..., + ) -> int: ... + + @overload + def count_row_if( + self, + predicate: Callable[[Row], Cell[bool | None]], + *, + ignore_unknown: bool, + ) -> int | None: ... + + def count_row_if( + self, + predicate: Callable[[Row], Cell[bool | None]], + *, + ignore_unknown: bool = True, + ) -> int | None: + """ + Return how many rows in the table satisfy the predicate. + + The predicate can return one of three results: + + * True, if the row satisfies the predicate. + * False, if the row does not satisfy the predicate. + * None, if the truthiness of the predicate is unknown, e.g. due to missing values. + + By default, cases where the truthiness of the predicate is unknown are ignored and this method returns how + often the predicate returns True. + + You can instead enable Kleene logic by setting `ignore_unknown=False`. In this case, this method returns None if + the predicate returns None at least once. Otherwise, it still returns how often the predicate returns True. + + Parameters + ---------- + predicate: + The predicate to apply to each row. + ignore_unknown: + Whether to ignore cases where the truthiness of the predicate is unknown. + + Returns + ------- + count: + The number of rows in the table that satisfy the predicate. + + Examples + -------- + >>> from safeds.data.tabular.containers import Table + >>> table = Table({"col1": [1, 2, 3], "col2": [1, 3, 3]}) + >>> table.count_row_if(lambda row: row["col1"] == row["col2"]) + 2 + + >>> table.count_row_if(lambda row: row["col1"] > row["col2"]) + 0 + """ + expression = predicate(_LazyVectorizedRow(self))._polars_expression + series = self._lazy_frame.select(expression.alias("count")).collect().get_column("count") + + if ignore_unknown or series.null_count() == 0: + return series.sum() + else: + return None + # TODO: Rethink group_rows/group_rows_by_column. They should not return a dict. def remove_duplicate_rows(self) -> Table: diff --git a/src/safeds/data/tabular/plotting/_column_plotter.py b/src/safeds/data/tabular/plotting/_column_plotter.py index fd3ea5f10..e9692c7da 100644 --- a/src/safeds/data/tabular/plotting/_column_plotter.py +++ b/src/safeds/data/tabular/plotting/_column_plotter.py @@ -3,7 +3,7 @@ from typing import TYPE_CHECKING from safeds._utils import _figure_to_image -from safeds.exceptions import NonNumericColumnError +from safeds._validation._check_columns_are_numeric import _check_column_is_numeric if TYPE_CHECKING: from safeds.data.image.containers import Image @@ -49,9 +49,8 @@ def box_plot(self) -> Image: >>> column = Column("test", [1, 2, 3]) >>> boxplot = column.plot.box_plot() """ - if self._column.row_count > 0 and not self._column.is_numeric: - # TODO better error message - raise NonNumericColumnError(f"{self._column.name} is of type {self._column.type}.") + if self._column.row_count > 0: + _check_column_is_numeric(self._column, operation="create a box plot") import matplotlib.pyplot as plt @@ -115,9 +114,8 @@ def lag_plot(self, lag: int) -> Image: >>> column = Column("values", [1, 2, 3, 4]) >>> image = column.plot.lag_plot(2) """ - if self._column.row_count > 0 and not self._column.is_numeric: - # TODO better error message - raise NonNumericColumnError("This time series target contains non-numerical columns.") + if self._column.row_count > 0: + _check_column_is_numeric(self._column, operation="create a lag plot") import matplotlib.pyplot as plt diff --git a/tests/safeds/data/tabular/containers/_column/test_correlation_with.py b/tests/safeds/data/tabular/containers/_column/test_correlation_with.py index 656956c59..e62fe92d9 100644 --- a/tests/safeds/data/tabular/containers/_column/test_correlation_with.py +++ b/tests/safeds/data/tabular/containers/_column/test_correlation_with.py @@ -1,6 +1,10 @@ import pytest from safeds.data.tabular.containers import Column -from safeds.exceptions import ColumnLengthMismatchError, MissingValuesColumnError, NonNumericColumnError +from safeds.exceptions import ( + ColumnLengthMismatchError, + ColumnTypeError, + MissingValuesColumnError, +) @pytest.mark.parametrize( @@ -38,7 +42,7 @@ def test_should_return_correlation_between_two_columns(values1: list, values2: l def test_should_raise_if_columns_are_not_numeric(values1: list, values2: list) -> None: column1 = Column("A", values1) column2 = Column("B", values2) - with pytest.raises(NonNumericColumnError): + with pytest.raises(ColumnTypeError): column1.correlation_with(column2) diff --git a/tests/safeds/data/tabular/containers/_column/test_mean.py b/tests/safeds/data/tabular/containers/_column/test_mean.py index ee9eb5e73..382de2380 100644 --- a/tests/safeds/data/tabular/containers/_column/test_mean.py +++ b/tests/safeds/data/tabular/containers/_column/test_mean.py @@ -1,6 +1,6 @@ import pytest from safeds.data.tabular.containers import Column -from safeds.exceptions import NonNumericColumnError +from safeds.exceptions import ColumnTypeError @pytest.mark.parametrize( @@ -36,5 +36,5 @@ def test_should_return_mean_value(values: list, expected: int) -> None: ) def test_should_raise_if_column_is_not_numeric(values: list) -> None: column = Column("col", values) - with pytest.raises(NonNumericColumnError): + with pytest.raises(ColumnTypeError): column.mean() diff --git a/tests/safeds/data/tabular/containers/_column/test_median.py b/tests/safeds/data/tabular/containers/_column/test_median.py index efc5c8b75..9a0ceb637 100644 --- a/tests/safeds/data/tabular/containers/_column/test_median.py +++ b/tests/safeds/data/tabular/containers/_column/test_median.py @@ -1,6 +1,6 @@ import pytest from safeds.data.tabular.containers import Column -from safeds.exceptions import NonNumericColumnError +from safeds.exceptions import ColumnTypeError @pytest.mark.parametrize( @@ -36,5 +36,5 @@ def test_should_return_median_value(values: list, expected: int) -> None: ) def test_should_raise_if_column_is_not_numeric(values: list) -> None: column = Column("A", values) - with pytest.raises(NonNumericColumnError): + with pytest.raises(ColumnTypeError): column.median() diff --git a/tests/safeds/data/tabular/containers/_column/test_plot_box_plot.py b/tests/safeds/data/tabular/containers/_column/test_plot_box_plot.py index cbbe51791..fbfbd21ad 100644 --- a/tests/safeds/data/tabular/containers/_column/test_plot_box_plot.py +++ b/tests/safeds/data/tabular/containers/_column/test_plot_box_plot.py @@ -1,6 +1,6 @@ import pytest from safeds.data.tabular.containers import Column -from safeds.exceptions import NonNumericColumnError +from safeds.exceptions import ColumnTypeError from syrupy import SnapshotAssertion @@ -24,5 +24,5 @@ def test_should_match_snapshot(column: Column, snapshot_png_image: SnapshotAsser def test_should_raise_if_column_contains_non_numerical_values() -> None: column = Column("a", ["A", "B", "C"]) - with pytest.raises(NonNumericColumnError): + with pytest.raises(ColumnTypeError): column.plot.box_plot() diff --git a/tests/safeds/data/tabular/containers/_column/test_plot_lag_plot.py b/tests/safeds/data/tabular/containers/_column/test_plot_lag_plot.py index 7dd27f969..aa5abbfa0 100644 --- a/tests/safeds/data/tabular/containers/_column/test_plot_lag_plot.py +++ b/tests/safeds/data/tabular/containers/_column/test_plot_lag_plot.py @@ -1,6 +1,6 @@ import pytest from safeds.data.tabular.containers import Column -from safeds.exceptions import NonNumericColumnError +from safeds.exceptions import ColumnTypeError from syrupy import SnapshotAssertion @@ -24,5 +24,5 @@ def test_should_match_snapshot(column: Column, snapshot_png_image: SnapshotAsser def test_should_raise_if_column_contains_non_numerical_values() -> None: column = Column("a", ["A", "B", "C"]) - with pytest.raises(NonNumericColumnError): + with pytest.raises(ColumnTypeError): column.plot.lag_plot(1) diff --git a/tests/safeds/data/tabular/containers/_column/test_standard_deviation.py b/tests/safeds/data/tabular/containers/_column/test_standard_deviation.py index 48aaf0e4a..924614286 100644 --- a/tests/safeds/data/tabular/containers/_column/test_standard_deviation.py +++ b/tests/safeds/data/tabular/containers/_column/test_standard_deviation.py @@ -1,6 +1,6 @@ import pytest from safeds.data.tabular.containers import Column -from safeds.exceptions import NonNumericColumnError +from safeds.exceptions import ColumnTypeError @pytest.mark.parametrize( @@ -34,5 +34,5 @@ def test_should_return_standard_deviation(values: list, expected: int) -> None: ) def test_should_raise_if_column_is_not_numeric(values: list) -> None: column = Column("A", values) - with pytest.raises(NonNumericColumnError): + with pytest.raises(ColumnTypeError): column.standard_deviation() diff --git a/tests/safeds/data/tabular/containers/_column/test_variance.py b/tests/safeds/data/tabular/containers/_column/test_variance.py index 9bddbe400..10b19c217 100644 --- a/tests/safeds/data/tabular/containers/_column/test_variance.py +++ b/tests/safeds/data/tabular/containers/_column/test_variance.py @@ -1,6 +1,6 @@ import pytest from safeds.data.tabular.containers import Column -from safeds.exceptions import NonNumericColumnError +from safeds.exceptions import ColumnTypeError @pytest.mark.parametrize( @@ -34,5 +34,5 @@ def test_should_return_variance(values: list, expected: int) -> None: ) def test_should_raise_if_column_is_not_numeric(values: list) -> None: column = Column("A", values) - with pytest.raises(NonNumericColumnError): + with pytest.raises(ColumnTypeError): column.variance() diff --git a/tests/safeds/data/tabular/containers/_table/test_count_row_if.py b/tests/safeds/data/tabular/containers/_table/test_count_row_if.py new file mode 100644 index 000000000..b4ab7f289 --- /dev/null +++ b/tests/safeds/data/tabular/containers/_table/test_count_row_if.py @@ -0,0 +1,64 @@ +import pytest +from safeds.data.tabular.containers import Table + + +@pytest.mark.parametrize( + ("values", "expected"), + [ + ([], 0), + ([1], 1), + ([2], 0), + ([None], 0), + ([1, None], 1), + ([2, None], 0), + ([1, 2], 1), + ([1, 2, None], 1), + ], + ids=[ + "empty", + "always true", + "always false", + "always unknown", + "true and unknown", + "false and unknown", + "true and false", + "true and false and unknown", + ], +) +def test_should_handle_boolean_logic( + values: list, + expected: int, +) -> None: + table = Table({"a": values}) + assert table.count_row_if(lambda row: row["a"] < 2) == expected + + +@pytest.mark.parametrize( + ("values", "expected"), + [ + ([], 0), + ([1], 1), + ([2], 0), + ([None], None), + ([1, None], None), + ([2, None], None), + ([1, 2], 1), + ([1, 2, None], None), + ], + ids=[ + "empty", + "always true", + "always false", + "always unknown", + "true and unknown", + "false and unknown", + "true and false", + "true and false and unknown", + ], +) +def test_should_handle_kleene_logic( + values: list, + expected: int | None, +) -> None: + table = Table({"a": values}) + assert table.count_row_if(lambda row: row["a"] < 2, ignore_unknown=False) == expected