Refactor ScalingTransformer to Inherit from BaseNumericTransformer and Update Tests #284
Merged: davidhopkinson26 merged 12 commits into main from feature/refactor_ScalingTransformer_tests on Sep 19, 2024
Changes from 2 commits
dd0fb45 refactor scalingtransformer class (favouribude1)
a04936c refactor scalingtransformer class (favouribude1)
878fee8 Refactored ScalingTransformer Test (favouribude1)
84c2393 Refactored ScalingTransformer Test (favouribude1)
8e6e815 Refactored ScalingTransformer Test (favouribude1)
abc889a Updated changelog (favouribude1)
c2d18dc Updated changelog (favouribude1)
4a177a7 Updating test for scalingtransformer (favouribude1)
b937ebb update test (favouribude1)
124bd5b Added test for scalingtransformer (favouribude1)
4f61a07 update changelog (favouribude1)
b71b72a remove unnecessary changelog entry for ScalingTransformer column update (favouribude1)
```diff
@@ -62,6 +62,8 @@ def _check_numeric(self, X: pd.DataFrame) -> None:
             msg = f"{self.classname()}: The following columns are not numeric in X; {non_numeric_columns}"
             raise TypeError(msg)
 
+        return X
+
     def fit(
         self,
         X: pd.DataFrame,
```
```diff
@@ -427,7 +429,7 @@ def transform(self, X: pd.DataFrame) -> pd.DataFrame:
         return X
 
 
-class ScalingTransformer(BaseTransformer):
+class ScalingTransformer(BaseNumericTransformer):
     """Transformer to perform scaling of numeric columns.
 
     Transformer can apply min max scaling, max absolute scaling or standardisation (subtract mean and divide by std).
```

P: thanks for the clean-up here, it is much nicer now!
```diff
@@ -450,6 +452,13 @@ class ScalingTransformer(BaseTransformer):
 
     """
 
+    # Dictionary mapping scaler types to their corresponding sklearn classes
+    scaler_options = {
+        "min_max": MinMaxScaler,
+        "max_abs": MaxAbsScaler,
+        "standard": StandardScaler,
+    }
+
     def __init__(
         self,
         columns: str | list[str] | None,
```
```diff
@@ -459,63 +468,33 @@ def __init__(
     ) -> None:
         if scaler_kwargs is None:
             scaler_kwargs = {}
-        else:
-            if type(scaler_kwargs) is not dict:
-                msg = f"{self.classname()}: scaler_kwargs should be a dict but got type {type(scaler_kwargs)}"
-                raise TypeError(msg)
 
+        # Validate scaler_kwargs type
+        if not isinstance(scaler_kwargs, dict):
+            msg = f"{self.classname()}: scaler_kwargs should be a dict but got type {type(scaler_kwargs)}"
+            raise TypeError(msg)
+
         for i, k in enumerate(scaler_kwargs.keys()):
-            if type(k) is not str:
+            if not isinstance(k, str):
                 msg = f"{self.classname()}: unexpected type ({type(k)}) for scaler_kwargs key in position {i}, must be str"
                 raise TypeError(msg)
 
-        allowed_scaler_values = ["min_max", "max_abs", "standard"]
-
-        if scaler_type not in allowed_scaler_values:
+        # Validate scaler_type
+        if scaler_type not in self.scaler_options:
+            allowed_scaler_values = list(self.scaler_options.keys())
             msg = f"{self.classname()}: scaler_type should be one of; {allowed_scaler_values}"
             raise ValueError(msg)
 
-        if scaler_type == "min_max":
-            self.scaler = MinMaxScaler(**scaler_kwargs)
-
-        elif scaler_type == "max_abs":
-            self.scaler = MaxAbsScaler(**scaler_kwargs)
-
-        elif scaler_type == "standard":
-            self.scaler = StandardScaler(**scaler_kwargs)
-
+        # Initialize scaler using the dictionary
+        self.scaler = self.scaler_options[scaler_type](**scaler_kwargs)
         # This attribute is not for use in any method
         # Here only as a fix to allow string representation of transformer.
         self.scaler_kwargs = scaler_kwargs
         self.scaler_type = scaler_type
 
         super().__init__(columns=columns, **kwargs)
 
-    def check_numeric_columns(self, X: pd.DataFrame) -> pd.DataFrame:
-        """Method to check all columns (specicifed in self.columns) in X are all numeric.
-
-        Parameters
-        ----------
-        X : pd.DataFrame
-            Data containing columns to check.
-
-        """
-        numeric_column_types = X[self.columns].apply(
-            pd.api.types.is_numeric_dtype,
-            axis=0,
-        )
-
-        if not numeric_column_types.all():
-            non_numeric_columns = list(
-                numeric_column_types.loc[~numeric_column_types].index,
-            )
-
-            msg = f"{self.classname()}: The following columns are not numeric in X; {non_numeric_columns}"
-            raise TypeError(msg)
-
-        return X
-
-    def fit(self, X: pd.DataFrame, y: pd.Series | None = None) -> pd.DataFrame:
+    def fit(self, X: pd.DataFrame, y: pd.Series | None = None) -> ScalingTransformer:
         """Fit scaler to input data.
 
         Parameters
```
@@ -528,11 +507,7 @@ def fit(self, X: pd.DataFrame, y: pd.Series | None = None) -> pd.DataFrame: | |
|
||
""" | ||
super().fit(X, y) | ||
|
||
X = self.check_numeric_columns(X) | ||
|
||
self.scaler.fit(X[self.columns]) | ||
|
||
return self | ||
|
||
def transform(self, X: pd.DataFrame) -> pd.DataFrame: | ||
|
```diff
@@ -551,8 +526,6 @@ def transform(self, X: pd.DataFrame) -> pd.DataFrame:
         """
         X = super().transform(X)
 
-        X = self.check_numeric_columns(X)
-
         X[self.columns] = self.scaler.transform(X[self.columns])
 
         return X
```
B: we now set up our test classes to inherit a set of generic tests from other test modules. See the guidance in the issue #275 #268 for an example in this numeric module. You need to set up the test classes for init, fit and transform to inherit from a parent test class and add

```python
@classmethod
def setup_class(cls):
    cls.transformer_name = "ScalingTransformer"
```

to each so that the inherited tests can function (they look up initialised transformers and minimal dataframes from conftest.py using this).

After doing this, check_numeric will be covered by tests in BaseNumericTransformerFitTests and BaseNumericTransformerTransformTests, so these tests can be deleted from here once those tests have been inherited.

You can also delete any implementation tests so we are just testing functionality. There should be far fewer tests in this module when you are done!
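For anyone unfamiliar with the pattern, here is a minimal, self-contained sketch of how inherited tests can resolve a transformer from a class-level name. It is not the tubular test suite: `ToyScaler`, `MINIMAL_TRANSFORMERS`, `GenericNumericFitTests` and `run_inherited_tests` are hypothetical stand-ins, and the registry dict only imitates the role the comment above assigns to conftest.py.

```python
# Hypothetical stand-ins throughout: none of these names come from tubular.


class ToyScaler:
    """Toy transformer that rejects non-numeric columns on fit."""

    def __init__(self, columns):
        self.columns = columns

    def fit(self, data):
        for column in self.columns:
            if not all(isinstance(value, (int, float)) for value in data[column]):
                msg = f"{type(self).__name__}: column {column!r} is not numeric"
                raise TypeError(msg)
        return self


# Imitates the conftest.py lookups: transformer name -> factory that builds
# a minimally initialised transformer.
MINIMAL_TRANSFORMERS = {"ToyScaler": lambda: ToyScaler(columns=["a"])}


class GenericNumericFitTests:
    """Parent test class: each test resolves the transformer by name."""

    transformer_name = None  # set by each child class in setup_class

    def test_fit_returns_self(self):
        transformer = MINIMAL_TRANSFORMERS[self.transformer_name]()
        assert transformer.fit({"a": [1, 2, 3]}) is transformer

    def test_non_numeric_column_raises(self):
        transformer = MINIMAL_TRANSFORMERS[self.transformer_name]()
        try:
            transformer.fit({"a": ["x", "y"]})
        except TypeError:
            return
        raise AssertionError("expected TypeError for a non-numeric column")


class TestFit(GenericNumericFitTests):
    """Child class: inherits both tests and only names its transformer."""

    @classmethod
    def setup_class(cls):
        cls.transformer_name = "ToyScaler"


def run_inherited_tests(test_class):
    """Imitate a test runner: call setup_class, then every test_* method."""
    test_class.setup_class()
    instance = test_class()
    names = sorted(name for name in dir(instance) if name.startswith("test_"))
    for name in names:
        getattr(instance, name)()
    return names


if __name__ == "__main__":
    print(run_inherited_tests(TestFit))
```

Under pytest the small runner at the bottom would be unnecessary, since pytest calls `setup_class` itself before running the tests of a class. The point of the sketch is that the child class contains no test bodies at all, which is why the module should shrink once the generic tests are inherited.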