Skip to content

Commit

Permalink
Fix a bug when target class is a string, but cat encoing via ml algor…
Browse files Browse the repository at this point in the history
…ithm is chosen
  • Loading branch information
ThomasMeissnerDS committed Aug 1, 2024
1 parent 3af3ed5 commit d7f161d
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 6 deletions.
8 changes: 5 additions & 3 deletions bluecast/blueprints/cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,8 +397,9 @@ def fit(self, df: pd.DataFrame, target_col: str) -> None:
)
x_test = self.cat_encoder.transform_target_encode_multiclass(x_test.copy())
elif self.conf_training.cat_encoding_via_ml_algorithm:
x_train[self.cat_columns] = x_train[self.cat_columns].astype("category")
x_test[self.cat_columns] = x_test[self.cat_columns].astype("category")
cat_cols = [col for col in self.cat_columns if col != self.target_column]
x_train[cat_cols] = x_train[cat_cols].astype("category")
x_test[cat_cols] = x_test[cat_cols].astype("category")

Check warning on line 402 in bluecast/blueprints/cast.py

View check run for this annotation

Codecov / codecov/patch

bluecast/blueprints/cast.py#L400-L402

Added lines #L400 - L402 were not covered by tests

if self.custom_last_mile_computation:
x_train, y_train = self.custom_last_mile_computation.fit_transform(
Expand Down Expand Up @@ -587,7 +588,8 @@ def transform_new_data(self, df: pd.DataFrame) -> pd.DataFrame:
):
df = self.cat_encoder.transform_target_encode_multiclass(df.copy())
elif self.conf_training.cat_encoding_via_ml_algorithm:
df[self.cat_columns] = df[self.cat_columns].astype("category")
cat_cols = [col for col in self.cat_columns if col != self.target_column]
df[cat_cols] = df[cat_cols].astype("category")

Check warning on line 592 in bluecast/blueprints/cast.py

View check run for this annotation

Codecov / codecov/patch

bluecast/blueprints/cast.py#L591-L592

Added lines #L591 - L592 were not covered by tests

if self.custom_last_mile_computation:
df, _ = self.custom_last_mile_computation.transform(
Expand Down
8 changes: 5 additions & 3 deletions bluecast/blueprints/cast_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,8 +375,9 @@ def fit(self, df: pd.DataFrame, target_col: str) -> None:
x_test.copy()
)
elif self.conf_training.cat_encoding_via_ml_algorithm:
x_train[self.cat_columns] = x_train[self.cat_columns].astype("category")
x_test[self.cat_columns] = x_test[self.cat_columns].astype("category")
cat_cols = [col for col in self.cat_columns if col != self.target_column]
x_train[cat_cols] = x_train[cat_cols].astype("category")
x_test[cat_cols] = x_test[cat_cols].astype("category")

Check warning on line 380 in bluecast/blueprints/cast_regression.py

View check run for this annotation

Codecov / codecov/patch

bluecast/blueprints/cast_regression.py#L378-L380

Added lines #L378 - L380 were not covered by tests

if self.custom_last_mile_computation:
x_train, y_train = self.custom_last_mile_computation.fit_transform(
Expand Down Expand Up @@ -544,7 +545,8 @@ def transform_new_data(self, df: pd.DataFrame) -> pd.DataFrame:
):
df = self.cat_encoder.transform_target_encode_binary_class(df.copy())
elif self.conf_training.cat_encoding_via_ml_algorithm:
df[self.cat_columns] = df[self.cat_columns].astype("category")
cat_cols = [col for col in self.cat_columns if col != self.target_column]
df[cat_cols] = df[cat_cols].astype("category")

Check warning on line 549 in bluecast/blueprints/cast_regression.py

View check run for this annotation

Codecov / codecov/patch

bluecast/blueprints/cast_regression.py#L548-L549

Added lines #L548 - L549 were not covered by tests

if self.custom_last_mile_computation:
df, _ = self.custom_last_mile_computation.transform(
Expand Down

0 comments on commit d7f161d

Please sign in to comment.