diff --git a/tests/numeric/test_CutTransformer.py b/tests/numeric/test_CutTransformer.py index 25f7f77..fca7d12 100644 --- a/tests/numeric/test_CutTransformer.py +++ b/tests/numeric/test_CutTransformer.py @@ -1,47 +1,21 @@ -import re - -import pandas import pandas as pd import pytest import test_aide as ta import tests.test_data as d -import tubular +from tests.numeric.test_BaseNumericTransformer import ( + BaseNumericTransformerInitTests, + BaseNumericTransformerTransformTests, +) from tubular.numeric import CutTransformer -class TestInit: +class TestInit(BaseNumericTransformerInitTests): """Tests for CutTransformer.init().""" - def test_super_init_called(self, mocker): - """Test that init calls BaseTransformer.init.""" - expected_call_args = { - 0: { - "args": (), - "kwargs": {"columns": ["a"], "verbose": False}, - }, - } - - with ta.functions.assert_function_call( - mocker, - tubular.base.BaseTransformer, - "__init__", - expected_call_args, - ): - CutTransformer(column="a", new_column_name="b", verbose=False) - - def test_column_type_error(self): - """Test that an exception is raised if column is not a str.""" - with pytest.raises( - TypeError, - match=re.escape( - "CutTransformer: column arg (name of column) should be a single str giving the column to discretise", - ), - ): - CutTransformer( - column=["a"], - new_column_name="a", - ) + @classmethod + def setup_class(cls): + cls.transformer_name = "CutTransformer" def test_new_column_name_type_error(self): """Test that an exception is raised if new_column_name is not a str.""" @@ -71,29 +45,14 @@ def test_cut_kwargs_key_type_error(self): cut_kwargs={"a": 1, 2: "b"}, ) - def test_inputs_set_to_attribute(self): - """Test that the values passed in init are set to attributes.""" - x = CutTransformer( - column="b", - new_column_name="a", - cut_kwargs={"a": 1, "b": 2}, - ) - ta.classes.test_object_attributes( - obj=x, - expected_attributes={ - "column": "b", - "columns": ["b"], - "new_column_name": "a", - "cut_kwargs": {"a": 1, "b": 2}, - }, - msg="Attributes for CutTransformer set in init", - ) - - -class TestTransform: +class TestTransform(BaseNumericTransformerTransformTests): """Tests for CutTransformer.transform().""" + @classmethod + def setup_class(cls): + cls.transformer_name = "CutTransformer" + def expected_df_1(): """Expected output for test_expected_output.""" df = d.create_df_9() @@ -106,49 +65,6 @@ def expected_df_1(): return df - def test_super_transform_call(self, mocker): - """Test the call to BaseTransformer.transform is as expected.""" - df = d.create_df_9() - - x = CutTransformer(column="a", new_column_name="Y", cut_kwargs={"bins": 3}) - - expected_call_args = {0: {"args": (d.create_df_9(),), "kwargs": {}}} - - with ta.functions.assert_function_call( - mocker, - tubular.base.BaseTransformer, - "transform", - expected_call_args, - return_value=d.create_df_9(), - ): - x.transform(df) - - def test_pd_cut_call(self, mocker): - """Test the call to pd.cut is as expected.""" - df = d.create_df_9() - - x = CutTransformer( - column="a", - new_column_name="a_cut", - cut_kwargs={"bins": 3, "right": False, "precision": 2}, - ) - - expected_call_args = { - 0: { - "args": (d.create_df_9()["a"].to_numpy(),), - "kwargs": {"bins": 3, "right": False, "precision": 2}, - }, - } - - with ta.functions.assert_function_call( - mocker, - pandas, - "cut", - expected_call_args, - return_value=[1, 2, 3, 4, 5, 6], - ): - x.transform(df) - def test_output_from_cut_assigned_to_column(self, mocker): """Test that the output from pd.cut is assigned to column with name new_column_name.""" df = d.create_df_9() @@ -187,15 +103,3 @@ def test_expected_output(self, df, expected): actual=df_transformed, msg="CutTransformer.transform output", ) - - def test_non_numeric_column_error(self): - """Test that an exception is raised if the column to discretise is not numeric.""" - df = d.create_df_8() - - x = CutTransformer(column="b", new_column_name="d") - - with pytest.raises( - TypeError, - match="CutTransformer: b should be a numeric dtype but got object", - ): - x.transform(df) diff --git a/tubular/numeric.py b/tubular/numeric.py index e84337c..3883eea 100644 --- a/tubular/numeric.py +++ b/tubular/numeric.py @@ -238,7 +238,7 @@ def transform(self, X: pd.DataFrame) -> pd.DataFrame: return X -class CutTransformer(BaseTransformer): +class CutTransformer(BaseNumericTransformer): """Class to bin a column into discrete intervals. Class simply uses the [pd.cut](https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.cut.html) @@ -307,10 +307,6 @@ def transform(self, X: pd.DataFrame) -> pd.DataFrame: """ X = super().transform(X) - if not pd.api.types.is_numeric_dtype(X[self.columns[0]]): - msg = f"{self.classname()}: {self.columns[0]} should be a numeric dtype but got {X[self.columns[0]].dtype}" - raise TypeError(msg) - X[self.new_column_name] = pd.cut( X[self.columns[0]].to_numpy(), **self.cut_kwargs,