Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

updated tests for ColumnDtypeSetter (also updated name to fit convent… #219

Merged
merged 3 commits into from
Apr 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ Changed
- Updated DataFrameMethodTransformer tests to have an inheritable init class that can be used by other test files.
- Moved BaseTransformer, DataFrameMethodTransformer, BaseMappingTransformer, BaseMappingTransformerMixin, CrossColumnMappingTransformer and MappingTransformer over to the new testing framework.
- Refactored MappingTransformer by removing redundant init method.
- Updated tests for
- Refactored tests for ColumnDtypeSetter and renamed the transformer (previously SetColumnDtype)
- Refactored tests for SetValueTransformer
- Refactored ArbitraryImputer by removing redundant fillna call in transform method. This should increase tubular's efficiency and maintainability.
- Refactored ArbitraryImputer and BaseImputer tests in new format.
- Refactored MedianImputer tests in new format.
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def minimal_attribute_dict():
"value": 1,
"columns": ["a"],
},
"SetColumnDtype": {
"ColumnDtypeSetter": {
"columns": ["a"],
"dtype": str,
},
Expand Down
93 changes: 39 additions & 54 deletions tests/misc/test_SetColumnDtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,73 +3,46 @@
import pytest
import test_aide as ta

import tests.test_data as d
import tubular
from tubular.misc import SetColumnDtype
from tests.base_tests import (
ColumnStrListInitTests,
GenericFitTests,
GenericTransformTests,
OtherBaseBehaviourTests,
)
from tubular.misc import ColumnDtypeSetter


class TestSetColumnDtypeInit:
"""Tests for SetColumnDtype custom transformer."""
class TestInit(ColumnStrListInitTests):
"""Generic tests for ColumnDtypeSetter.init()."""

def test_tubular_base_transformer_super_init_called(self, mocker):
"""Test that init calls tubular BaseTransformer.init."""
expected_call_args = {
0: {
"args": (["a"],),
"kwargs": {},
},
}
with ta.functions.assert_function_call(
mocker,
tubular.base.BaseTransformer,
"__init__",
expected_call_args,
):
SetColumnDtype(columns=["a"], dtype=float)

def test_dtype_attribute_set(self):
"""Test that the value passed in the value arg is set as an attribute of the same name."""
x = SetColumnDtype(columns=["a"], dtype=str)

assert x.dtype == str, "unexpected value set to dtype atttribute"
@classmethod
def setup_class(cls):
cls.transformer_name = "ColumnDtypeSetter"

@pytest.mark.parametrize(
"invalid_dtype",
["STRING", "misc_invalid", "np.int", 0],
)
def test_invalid_dtype_error(self, invalid_dtype):
msg = f"SetColumnDtype: data type '{invalid_dtype}' not understood as a valid dtype"
msg = f"ColumnDtypeSetter: data type '{invalid_dtype}' not understood as a valid dtype"
with pytest.raises(TypeError, match=msg):
SetColumnDtype(columns=["a"], dtype=invalid_dtype)


class TestSetColumnDtypeTransform:
def test_transform_arguments(self):
"""Test that transform has expected arguments."""
ta.functions.test_function_arguments(
func=SetColumnDtype.transform,
expected_arguments=[
"self",
"X",
],
)
ColumnDtypeSetter(columns=["a"], dtype=invalid_dtype)


def test_super_transform_called(self, mocker):
"""Test that BaseTransformer.transform called."""
df = d.create_df_3()
class TestFit(GenericFitTests):
"""Generic tests for ColumnDtypeSetter.fit()"""

x = SetColumnDtype(columns=["a"], dtype=float)
@classmethod
def setup_class(cls):
cls.transformer_name = "ColumnDtypeSetter"

expected_call_args = {0: {"args": (d.create_df_3(),), "kwargs": {}}}

with ta.functions.assert_function_call(
mocker,
tubular.base.BaseTransformer,
"transform",
expected_call_args,
return_value=d.create_df_3(),
):
x.transform(df)
class TestTransform(GenericTransformTests):
"""Tests for ColumnDtypeSetter.transform."""

@classmethod
def setup_class(cls):
cls.transformer_name = "ColumnDtypeSetter"

def base_df():
"""Input dataframe from test_expected_output."""
Expand Down Expand Up @@ -106,7 +79,7 @@ def test_expected_output(self, df, expected, dtype):
df["c"] = df["c"].astype(int)
df["d"] = df["d"].astype(str)

x = SetColumnDtype(columns=["a", "b", "c", "d"], dtype=dtype)
x = ColumnDtypeSetter(columns=["a", "b", "c", "d"], dtype=dtype)

df_transformed = x.transform(df)

Expand All @@ -115,3 +88,15 @@ def test_expected_output(self, df, expected, dtype):
actual=df_transformed,
msg="Check values correctly converted to float",
)


class TestOtherBaseBehaviour(OtherBaseBehaviourTests):
    """Run the inherited behaviour tests for ColumnDtypeSetter.

    These cover behaviour outside the three standard methods. Individual
    inherited tests may need to be overwritten in this class if the tested
    transformer modifies that behaviour.
    """

    @classmethod
    def setup_class(cls):
        """Set the transformer name used by this test class."""
        cls.transformer_name = "ColumnDtypeSetter"
11 changes: 8 additions & 3 deletions tubular/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def transform(self, X: pd.DataFrame) -> pd.DataFrame:
return X


class SetColumnDtype(BaseTransformer):
class ColumnDtypeSetter(BaseTransformer):
"""Transformer to set transform columns in a dataframe to a dtype.

Parameters
Expand All @@ -68,8 +68,13 @@ class SetColumnDtype(BaseTransformer):
e.g. float or 'float'
"""

def __init__(self, columns: str | list[str], dtype: type | str) -> None:
super().__init__(columns)
def __init__(
self,
columns: str | list[str],
dtype: type | str,
**kwargs: dict[str, bool],
) -> None:
super().__init__(columns, **kwargs)

self.__validate_dtype(dtype)

Expand Down
Loading