Skip to content

Commit

Permalink
fix: pandas and arrow to_dummies with nulls
Browse files Browse the repository at this point in the history
  • Loading branch information
FBruzzesi committed Sep 21, 2024
1 parent dc54e7a commit 0061e9b
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 13 deletions.
23 changes: 19 additions & 4 deletions narwhals/_arrow/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,16 +655,31 @@ def to_dummies(
from narwhals._arrow.dataframe import ArrowDataFrame

series = self._native_series
da = series.dictionary_encode().combine_chunks()
name = self._name
da = series.dictionary_encode(null_encoding="encode").combine_chunks()

columns = np.zeros((len(da.dictionary), len(da)), np.uint8)
columns[da.indices, np.arange(len(da))] = 1
names = [f"{self._name}{separator}{v}" for v in da.dictionary]
null_col_pa, null_col_pl = f"{name}{separator}None", f"{name}{separator}null"
cols = [
{null_col_pa: null_col_pl}.get(
f"{name}{separator}{v}", f"{name}{separator}{v}"
)
for v in da.dictionary
]

output_order = (
[
null_col_pl,
*sorted([c for c in cols if c != null_col_pl])[int(drop_first) :],
]
if null_col_pl in cols
else sorted(cols)[int(drop_first) :]
)
return ArrowDataFrame(
pa.Table.from_arrays(columns, names=names),
pa.Table.from_arrays(columns, names=cols),
backend_version=self._backend_version,
).select(*sorted(names)[int(drop_first) :])
).select(*output_order)

def quantile(
self: Self,
Expand Down
26 changes: 20 additions & 6 deletions narwhals/_pandas_like/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -630,13 +630,27 @@ def to_dummies(
plx = self.__native_namespace__()
series = self._native_series
name = str(self._name) if self._name else ""

null_col_pl = f"{name}{separator}null"

result = plx.get_dummies(
series,
prefix=name,
prefix_sep=separator,
drop_first=drop_first,
dummy_na=True, # Adds a null column at the end, even if there aren't any.
dtype=int,
)

*cols, null_col_pd = result.columns
# if there are no nulls, we drop such column, as polars would not add it
drop_null_col = result[null_col_pd].sum() == 0
output_order = [null_col_pd, *cols] if not drop_null_col else cols

return PandasLikeDataFrame(
plx.get_dummies(
series,
prefix=name,
prefix_sep=separator,
drop_first=drop_first,
).astype(int),
result.loc[:, output_order].rename(
columns={null_col_pd: null_col_pl}, errors="ignore"
),
implementation=self._implementation,
backend_version=self._backend_version,
)
Expand Down
31 changes: 28 additions & 3 deletions tests/series_only/to_dummy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@
import narwhals.stable.v1 as nw
from tests.utils import compare_dicts

data = [1, 2, 3]
data = ["x", "y", "z"]
data_na = ["x", "y", None]


@pytest.mark.parametrize("sep", ["_", "-"])
def test_to_dummies(constructor_eager: Any, sep: str) -> None:
s = nw.from_native(constructor_eager({"a": data}), eager_only=True)["a"].alias("a")
result = s.to_dummies(separator=sep)
expected = {f"a{sep}1": [1, 0, 0], f"a{sep}2": [0, 1, 0], f"a{sep}3": [0, 0, 1]}
expected = {f"a{sep}x": [1, 0, 0], f"a{sep}y": [0, 1, 0], f"a{sep}z": [0, 0, 1]}

compare_dicts(result, expected)

Expand All @@ -25,6 +26,30 @@ def test_to_dummies_drop_first(
request.applymarker(pytest.mark.xfail)
s = nw.from_native(constructor_eager({"a": data}), eager_only=True)["a"].alias("a")
result = s.to_dummies(drop_first=True, separator=sep)
expected = {f"a{sep}2": [0, 1, 0], f"a{sep}3": [0, 0, 1]}
expected = {f"a{sep}y": [0, 1, 0], f"a{sep}z": [0, 0, 1]}

compare_dicts(result, expected)


@pytest.mark.parametrize("sep", ["_", "-"])
def test_to_dummies_with_nulls(constructor_eager: Any, sep: str) -> None:
if "pandas_nullable_constructor" not in str(constructor_eager):
pytest.skip()
s = nw.from_native(constructor_eager({"a": data_na}), eager_only=True)["a"].alias("a")
result = s.to_dummies(separator=sep)
expected = {f"a{sep}null": [0, 0, 1], f"a{sep}x": [1, 0, 0], f"a{sep}y": [0, 1, 0]}

compare_dicts(result, expected)


@pytest.mark.parametrize("sep", ["_", "-"])
def test_to_dummies_drop_first_na(
request: pytest.FixtureRequest, constructor_eager: Any, sep: str
) -> None:
if "cudf" in str(constructor_eager):
request.applymarker(pytest.mark.xfail)
s = nw.from_native(constructor_eager({"a": data_na}), eager_only=True)["a"].alias("a")
result = s.to_dummies(drop_first=True, separator=sep)
expected = {f"a{sep}null": [0, 0, 1], f"a{sep}y": [0, 1, 0]}

compare_dicts(result, expected)

0 comments on commit 0061e9b

Please sign in to comment.