Skip to content

Commit

Permalink
adjusted changes to match main
Browse files Browse the repository at this point in the history
  • Loading branch information
limlam96 committed Oct 16, 2024
1 parent 104dd07 commit e88c6a5
Showing 1 changed file with 36 additions and 25 deletions.
61 changes: 36 additions & 25 deletions tests/base_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,14 +492,15 @@ def test_fit_returns_self_weighted(
minimal_attribute_dict,
uninitialized_transformers,
minimal_dataframe_lookup,
narwhalified_transformers_dict,
):
"""Test fit returns self?."""
df = minimal_dataframe_lookup[self.transformer_name]

# skip polars test if not narwhalified
narwhalified = narwhalified_transformers_dict[self.transformer_name]
if not narwhalified and isinstance(df, pl.DataFrame):
uninitialized_transformer = uninitialized_transformers[self.transformer_name]
if not uninitialized_transformer.polars_compatible and isinstance(
df,
pl.DataFrame,
):
return

df = nw.from_native(df)
Expand All @@ -519,7 +520,7 @@ def test_fit_returns_self_weighted(

df = nw.to_native(df)

transformer = uninitialized_transformers[self.transformer_name](**args)
transformer = uninitialized_transformer(**args)

x_fitted = transformer.fit(df, df["a"])

Expand All @@ -537,14 +538,16 @@ def test_fit_not_changing_data_weighted(
minimal_attribute_dict,
uninitialized_transformers,
minimal_dataframe_lookup,
narwhalified_transformers_dict,
):
"""Test fit does not change X - when weights are used."""
df = minimal_dataframe_lookup[self.transformer_name]
uninitialized_transformer = uninitialized_transformers[self.transformer_name]

# skip polars test if not narwhalified
narwhalified = narwhalified_transformers_dict[self.transformer_name]
if not narwhalified and isinstance(df, pl.DataFrame):
if not uninitialized_transformer.polars_compatible and isinstance(
df,
pl.DataFrame,
):
return

df = nw.from_native(df)
Expand All @@ -564,7 +567,7 @@ def test_fit_not_changing_data_weighted(
args = minimal_attribute_dict[self.transformer_name].copy()
args["weights_column"] = weight_column

transformer = uninitialized_transformers[self.transformer_name](**args)
transformer = uninitialized_transformer(**args)

df = nw.to_native(df)
original_df = nw.to_native(original_df)
Expand Down Expand Up @@ -594,15 +597,17 @@ def test_bad_values_in_weights_error(
minimal_attribute_dict,
uninitialized_transformers,
minimal_dataframe_lookup,
narwhalified_transformers_dict,
):
"""Test that an exception is raised if there are negative/nan/inf values in sample_weight."""

df = minimal_dataframe_lookup[self.transformer_name]
uninitialized_transformer = uninitialized_transformers[self.transformer_name]

# skip polars test if not narwhalified
narwhalified = narwhalified_transformers_dict[self.transformer_name]
if not narwhalified and isinstance(df, pl.DataFrame):
if not uninitialized_transformer.polars_compatible and isinstance(
df,
pl.DataFrame,
):
return

df = nw.from_native(df)
Expand All @@ -622,7 +627,7 @@ def test_bad_values_in_weights_error(

df = nw.to_native(df)

transformer = uninitialized_transformers[self.transformer_name](**args)
transformer = uninitialized_transformer(**args)

with pytest.raises(ValueError, match=expected_message):
transformer.fit(df, df["a"])
Expand Down Expand Up @@ -651,15 +656,17 @@ def test_weight_col_non_numeric(
uninitialized_transformers,
minimal_attribute_dict,
minimal_dataframe_lookup,
narwhalified_transformers_dict,
):
"""Test an error is raised if weight is not numeric."""

df = minimal_dataframe_lookup[self.transformer_name]
uninitialized_transformer = uninitialized_transformers[self.transformer_name]

# skip polars test if not narwhalified
narwhalified = narwhalified_transformers_dict[self.transformer_name]
if not narwhalified and isinstance(df, pl.DataFrame):
if not uninitialized_transformer.polars_compatible and isinstance(
df,
pl.DataFrame,
):
return

df = nw.from_native(df)
Expand All @@ -678,7 +685,7 @@ def test_weight_col_non_numeric(
args = minimal_attribute_dict[self.transformer_name].copy()
args["weights_column"] = weight_column

transformer = uninitialized_transformers[self.transformer_name](**args)
transformer = uninitialized_transformer(**args)
transformer.fit(df, df["a"])

@pytest.mark.parametrize(
Expand All @@ -691,15 +698,17 @@ def test_weight_not_in_X_error(
uninitialized_transformers,
minimal_attribute_dict,
minimal_dataframe_lookup,
narwhalified_transformers_dict,
):
"""Test an error is raised if weight is not in X"""

df = minimal_dataframe_lookup[self.transformer_name]
uninitialized_transformer = uninitialized_transformers[self.transformer_name]

# skip polars test if not narwhalified
narwhalified = narwhalified_transformers_dict[self.transformer_name]
if not narwhalified and isinstance(df, pl.DataFrame):
if not uninitialized_transformer.polars_compatible and isinstance(
df,
pl.DataFrame,
):
return

weight_column = "weight_column"
Expand All @@ -714,7 +723,7 @@ def test_weight_not_in_X_error(
args = minimal_attribute_dict[self.transformer_name].copy()
args["weights_column"] = weight_column

transformer = uninitialized_transformers[self.transformer_name](**args)
transformer = uninitialized_transformer(**args)
transformer.fit(df, df["a"])

@pytest.mark.parametrize(
Expand All @@ -727,7 +736,6 @@ def test_zero_total_weight_error(
minimal_attribute_dict,
uninitialized_transformers,
minimal_dataframe_lookup,
narwhalified_transformers_dict,
):
"""Test that an exception is raised if the total sample weights are 0."""

Expand All @@ -736,17 +744,20 @@ def test_zero_total_weight_error(
args["weights_column"] = weight_column

df = minimal_dataframe_lookup[self.transformer_name]
uninitialized_transformer = uninitialized_transformers[self.transformer_name]

# skip polars test if not narwhalified
narwhalified = narwhalified_transformers_dict[self.transformer_name]
if not narwhalified and isinstance(df, pl.DataFrame):
if not uninitialized_transformer.polars_compatible and isinstance(
df,
pl.DataFrame,
):
return

df = nw.from_native(df)
df = df.with_columns(nw.lit(0).alias("weight_column"))
df = nw.to_native(df)

transformer = uninitialized_transformers[self.transformer_name](**args)
transformer = uninitialized_transformer(**args)
with pytest.raises(
ValueError,
match="total sample weights are not greater than 0",
Expand Down

0 comments on commit e88c6a5

Please sign in to comment.