Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor ScalingTransformer to Inherit from BaseNumericTransformer and Update Tests #284

Merged
merged 12 commits into from
Sep 19, 2024
16 changes: 8 additions & 8 deletions tests/numeric/test_ScalingTransformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def test_super_init_called(self, mocker):


class TestCheckNumericColumns:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

B: we now set up our test classes to inherit a set of generic tests from other test modules. See the guidance in the issue #275 #268 for an example in this numeric module. You need to set up the test classes for init, fit and transform to inherit from a parent test class and add

@classmethod def setup_class(cls): cls.transformer_name = "ScalingTransformer"

to each so that the inherited tests can function (they look up initialised transformers and minimal dataframes from conftest.py using this)

After doing this check_numeric will be covered by tests in BaseNumericTransformerFitTests and BaseNumericTransformerTransformTests so these tests can be deleted form here once those tests have been inherited.

You can also delete any implementation tests so we are just testing functionality. There should be far fewer tests in this module when you are done!

"""Tests for the check_numeric_columns method."""
"""Tests for the _check_numeric method."""

def test_exception_raised(self):
"""Test an exception is raised if non numeric columns are passed in X."""
Expand All @@ -135,15 +135,15 @@ def test_exception_raised(self):
TypeError,
match=r"""ScalingTransformer: The following columns are not numeric in X; \['b', 'c'\]""",
):
x.check_numeric_columns(df)
x._check_numeric(df)

def test_X_returned(self):
"""Test that the input X is returned from the method."""
df = d.create_df_2()

x = ScalingTransformer(columns=["a"], scaler_type="standard")

df_returned = x.check_numeric_columns(df)
df_returned = x._check_numeric(df)

ta.equality.assert_equal_dispatch(
expected=df,
Expand Down Expand Up @@ -171,8 +171,8 @@ def test_super_fit_call(self, mocker):
):
x.fit(df)

def test_check_numeric_columns_call(self, mocker):
"""Test the call to ScalingTransformer.check_numeric_columns."""
def test_check_numeric_call(self, mocker):
"""Test the call to ScalingTransformer._check_numeric."""
df = d.create_df_2()

x = ScalingTransformer(columns=["a"], scaler_type="standard")
Expand All @@ -182,7 +182,7 @@ def test_check_numeric_columns_call(self, mocker):
with ta.functions.assert_function_call(
mocker,
tubular.numeric.ScalingTransformer,
"check_numeric_columns",
"_check_numeric",
expected_call_args,
return_value=d.create_df_2(),
):
Expand Down Expand Up @@ -265,7 +265,7 @@ def test_super_transform_called(self, mocker):
x.transform(df)

def test_check_numeric_columns_call(self, mocker):
"""Test the call to ScalingTransformer.check_numeric_columns."""
"""Test the call to ScalingTransformer._check_numeric."""
df = d.create_df_2()

x = ScalingTransformer(columns=["a"], scaler_type="standard")
Expand All @@ -276,7 +276,7 @@ def test_check_numeric_columns_call(self, mocker):

with ta.functions.assert_function_call(
mocker,
tubular.base.BaseTransformer,
tubular.numeric.BaseNumericTransformer,
"transform",
expected_call_args,
return_value=d.create_df_2(),
Expand Down
71 changes: 22 additions & 49 deletions tubular/numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ def _check_numeric(self, X: pd.DataFrame) -> None:
msg = f"{self.classname()}: The following columns are not numeric in X; {non_numeric_columns}"
raise TypeError(msg)

return X

def fit(
self,
X: pd.DataFrame,
Expand Down Expand Up @@ -427,7 +429,7 @@ def transform(self, X: pd.DataFrame) -> pd.DataFrame:
return X


class ScalingTransformer(BaseTransformer):
class ScalingTransformer(BaseNumericTransformer):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P: thanks for the clean up here it is much nicer now!

"""Transformer to perform scaling of numeric columns.

Transformer can apply min max scaling, max absolute scaling or standardisation (subtract mean and divide by std).
Expand All @@ -450,6 +452,13 @@ class ScalingTransformer(BaseTransformer):

"""

# Dictionary mapping scaler types to their corresponding sklearn classes
scaler_options = {
"min_max": MinMaxScaler,
"max_abs": MaxAbsScaler,
"standard": StandardScaler,
}

def __init__(
self,
columns: str | list[str] | None,
Expand All @@ -459,63 +468,33 @@ def __init__(
) -> None:
if scaler_kwargs is None:
scaler_kwargs = {}
else:
if type(scaler_kwargs) is not dict:
msg = f"{self.classname()}: scaler_kwargs should be a dict but got type {type(scaler_kwargs)}"
raise TypeError(msg)

# Validate scaler_kwargs type
if not isinstance(scaler_kwargs, dict):
msg = f"{self.classname()}: scaler_kwargs should be a dict but got type {type(scaler_kwargs)}"
raise TypeError(msg)

for i, k in enumerate(scaler_kwargs.keys()):
if type(k) is not str:
if not isinstance(k, str):
msg = f"{self.classname()}: unexpected type ({type(k)}) for scaler_kwargs key in position {i}, must be str"
raise TypeError(msg)

allowed_scaler_values = ["min_max", "max_abs", "standard"]

if scaler_type not in allowed_scaler_values:
# Validate scaler_type
if scaler_type not in self.scaler_options:
allowed_scaler_values = list(self.scaler_options.keys())
msg = f"{self.classname()}: scaler_type should be one of; {allowed_scaler_values}"
raise ValueError(msg)

if scaler_type == "min_max":
self.scaler = MinMaxScaler(**scaler_kwargs)

elif scaler_type == "max_abs":
self.scaler = MaxAbsScaler(**scaler_kwargs)

elif scaler_type == "standard":
self.scaler = StandardScaler(**scaler_kwargs)

# Initialize scaler using the dictionary
self.scaler = self.scaler_options[scaler_type](**scaler_kwargs)
# This attribute is not for use in any method
# Here only as a fix to allow string representation of transformer.
self.scaler_kwargs = scaler_kwargs
self.scaler_type = scaler_type

super().__init__(columns=columns, **kwargs)

def check_numeric_columns(self, X: pd.DataFrame) -> pd.DataFrame:
"""Method to check all columns (specicifed in self.columns) in X are all numeric.

Parameters
----------
X : pd.DataFrame
Data containing columns to check.

"""
numeric_column_types = X[self.columns].apply(
pd.api.types.is_numeric_dtype,
axis=0,
)

if not numeric_column_types.all():
non_numeric_columns = list(
numeric_column_types.loc[~numeric_column_types].index,
)

msg = f"{self.classname()}: The following columns are not numeric in X; {non_numeric_columns}"
raise TypeError(msg)

return X

def fit(self, X: pd.DataFrame, y: pd.Series | None = None) -> pd.DataFrame:
def fit(self, X: pd.DataFrame, y: pd.Series | None = None) -> ScalingTransformer:
"""Fit scaler to input data.

Parameters
Expand All @@ -528,11 +507,7 @@ def fit(self, X: pd.DataFrame, y: pd.Series | None = None) -> pd.DataFrame:

"""
super().fit(X, y)

X = self.check_numeric_columns(X)

self.scaler.fit(X[self.columns])

return self

def transform(self, X: pd.DataFrame) -> pd.DataFrame:
Expand All @@ -551,8 +526,6 @@ def transform(self, X: pd.DataFrame) -> pd.DataFrame:
"""
X = super().transform(X)

X = self.check_numeric_columns(X)

X[self.columns] = self.scaler.transform(X[self.columns])

return X
Expand Down
Loading