-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
f002de6
commit bb21af1
Showing
10 changed files
with
197 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
""" | ||
Infrequent categories may cause overfitting. | ||
This module groups infrequent categories into a common group to reduce the risk of overfitting | ||
""" | ||
|
||
import logging | ||
from typing import Dict, List, Union | ||
|
||
import pandas as pd | ||
|
||
|
||
class InFrequentCategoryEncoder: | ||
"""Group infrequent categories into common group.""" | ||
|
||
def __init__( | ||
self, | ||
cat_columns: List[Union[str, float, int]], | ||
target_col: Union[str, float, int], | ||
infrequent_threshold: int = 5, | ||
): | ||
self.frequencies: Dict[Union[str, float, int], pd.Series] = {} | ||
self.prediction_mode: bool = False | ||
self.cat_columns = cat_columns | ||
self.target_col = target_col | ||
self.infrequent_threshold = infrequent_threshold | ||
|
||
def fit_transform(self, x: pd.DataFrame, y: pd.Series) -> pd.DataFrame: | ||
"""Find infrequent categories and transform column.""" | ||
logging.info("Start fitting binary target encoder.") | ||
if self.target_col in self.cat_columns: | ||
self.cat_columns.remove(self.target_col) | ||
|
||
for col in self.cat_columns: | ||
self.frequencies[col] = x[col].value_counts() | ||
x[col] = x[col].mask( | ||
x[col].map(self.frequencies[col], na_action="ignore") | ||
< self.infrequent_threshold, | ||
"rare categories", | ||
) | ||
return x.copy() # copy against high fragmentation | ||
|
||
def transform(self, x: pd.DataFrame) -> pd.DataFrame: | ||
"""Transform categories based on already explored frequencies.""" | ||
logging.info("Start transforming categories with binary target encoder.") | ||
for col in self.cat_columns: | ||
x[col] = x[col].mask( | ||
x[col].map(self.frequencies[col], na_action="ignore") | ||
< self.infrequent_threshold, | ||
"rare categories", | ||
) | ||
return x |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
import pandas as pd | ||
import pytest | ||
|
||
from bluecast.preprocessing.infrequent_categories import InFrequentCategoryEncoder | ||
|
||
|
||
@pytest.fixture | ||
def sample_data(): | ||
data = { | ||
"cat1": ["a", "b", "b", "c", "c", "c", "d", "e"], | ||
"cat2": ["w", "w", "x", "y", "y", "z", "z", "z"], | ||
"target": [0, 1, 0, 1, 0, 1, 0, 1], | ||
} | ||
df = pd.DataFrame(data) | ||
return df | ||
|
||
|
||
def test_fit_transform(sample_data): | ||
encoder = InFrequentCategoryEncoder( | ||
cat_columns=["cat1", "cat2"], target_col="target", infrequent_threshold=2 | ||
) | ||
transformed_df = encoder.fit_transform( | ||
sample_data.drop(columns="target"), sample_data["target"] | ||
) | ||
|
||
assert transformed_df["cat1"].tolist() == [ | ||
"rare categories", | ||
"b", | ||
"b", | ||
"c", | ||
"c", | ||
"c", | ||
"rare categories", | ||
"rare categories", | ||
] | ||
assert transformed_df["cat2"].tolist() == [ | ||
"w", | ||
"w", | ||
"rare categories", | ||
"y", | ||
"y", | ||
"z", | ||
"z", | ||
"z", | ||
] | ||
|
||
|
||
def test_transform(sample_data): | ||
encoder = InFrequentCategoryEncoder( | ||
cat_columns=["cat1", "cat2"], target_col="target", infrequent_threshold=2 | ||
) | ||
encoder.fit_transform(sample_data.drop(columns="target"), sample_data["target"]) | ||
|
||
new_data = pd.DataFrame( | ||
{"cat1": ["a", "b", "c", "d", "f"], "cat2": ["w", "x", "y", "z", "a"]} | ||
) | ||
|
||
transformed_new_data = encoder.transform(new_data) | ||
|
||
assert transformed_new_data["cat1"].tolist() == [ | ||
"rare categories", | ||
"b", | ||
"c", | ||
"rare categories", | ||
"f", | ||
] | ||
assert transformed_new_data["cat2"].tolist() == [ | ||
"w", | ||
"rare categories", | ||
"y", | ||
"z", | ||
"a", | ||
] | ||
|
||
|
||
def test_no_infrequent_categories(sample_data): | ||
encoder = InFrequentCategoryEncoder( | ||
cat_columns=["cat1", "cat2"], target_col="target", infrequent_threshold=1 | ||
) | ||
transformed_df = encoder.fit_transform( | ||
sample_data.drop(columns="target"), sample_data["target"] | ||
) | ||
|
||
assert transformed_df["cat1"].tolist() == sample_data["cat1"].tolist() | ||
assert transformed_df["cat2"].tolist() == sample_data["cat2"].tolist() | ||
|
||
|
||
def test_all_infrequent_categories(sample_data): | ||
encoder = InFrequentCategoryEncoder( | ||
cat_columns=["cat1", "cat2"], target_col="target", infrequent_threshold=10 | ||
) | ||
transformed_df = encoder.fit_transform( | ||
sample_data.drop(columns="target"), sample_data["target"] | ||
) | ||
|
||
assert all(val == "rare categories" for val in transformed_df["cat1"].tolist()) | ||
assert all(val == "rare categories" for val in transformed_df["cat2"].tolist()) |
Binary file not shown.
Binary file renamed
BIN
+148 KB
dist/bluecast-1.5.0-py3-none-any.whl → dist/bluecast-1.5.1-py3-none-any.whl
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
[tool.poetry] | ||
name = "bluecast" | ||
version = "1.5.0" | ||
version = "1.5.1" | ||
description = "A lightweight and fast automl framework" | ||
authors = ["Thomas Meißner <[email protected]>"] | ||
license = "GPL-3.0-only" | ||
|