Skip to content

Commit

Permalink
FIX Fixes encoders for string dtypes (scikit-learn#15763)
Browse files Browse the repository at this point in the history
Co-authored-by: Guillaume Lemaitre <[email protected]>
Co-authored-by: Olivier Grisel <[email protected]>
  • Loading branch information
3 people committed Oct 28, 2020
1 parent 954b9bc commit 8e22443
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 2 deletions.
4 changes: 4 additions & 0 deletions doc/whats_new/v0.24.rst
Original file line number Diff line number Diff line change
Expand Up @@ -662,6 +662,10 @@ Changelog
encoded as all zeros. :pr:`14982` by
:user:`Kevin Winata <kwinata>`.

- |Fix| Fix incorrect encoding when using unicode string dtypes in
:class:`preprocessing.OneHotEncoder` and
:class:`preprocessing.OrdinalEncoder`. :pr:`15763` by `Thomas Fan`_.

:mod:`sklearn.svm`
..................

Expand Down
2 changes: 1 addition & 1 deletion sklearn/preprocessing/_encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def _fit(self, X, handle_unknown='error', force_all_finite=True):
cats = _unique(Xi)
else:
cats = np.array(self.categories[i], dtype=Xi.dtype)
if Xi.dtype != object:
if Xi.dtype.kind not in 'OU':
sorted_cats = np.sort(cats)
error_msg = ("Unsorted categories are not "
"supported for numerical categories")
Expand Down
27 changes: 27 additions & 0 deletions sklearn/preprocessing/tests/test_encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -830,6 +830,33 @@ def test_encoders_has_categorical_tags(Encoder):
assert 'categorical' in Encoder()._get_tags()['X_types']


@pytest.mark.parametrize('input_dtype', ['O', 'U'])
@pytest.mark.parametrize('category_dtype', ['O', 'U'])
@pytest.mark.parametrize('array_type', ['list', 'array', 'dataframe'])
def test_encoders_unicode_categories(input_dtype, category_dtype, array_type):
"""Check that encoding work with string and object dtypes.
Non-regression test for:
https://github.com/scikit-learn/scikit-learn/issues/15616
https://github.com/scikit-learn/scikit-learn/issues/15726
"""

X = np.array([['b'], ['a']], dtype=input_dtype)
categories = [np.array(['b', 'a'], dtype=category_dtype)]
ohe = OneHotEncoder(categories=categories, sparse=False).fit(X)

X_test = _convert_container([['a'], ['a'], ['b'], ['a']], array_type)
X_trans = ohe.transform(X_test)

expected = np.array([[0, 1], [0, 1], [1, 0], [0, 1]])
assert_allclose(X_trans, expected)

oe = OrdinalEncoder(categories=categories).fit(X)
X_trans = oe.transform(X_test)

expected = np.array([[1], [1], [0], [1]])
assert_array_equal(X_trans, expected)


@pytest.mark.parametrize("missing_value", [np.nan, None])
def test_ohe_missing_values_get_feature_names(missing_value):
# encoder with missing values with object dtypes
Expand Down
2 changes: 1 addition & 1 deletion sklearn/utils/_encode.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def _encode(values, *, uniques, check_unknown=True):
encoded : ndarray
Encoded values
"""
if values.dtype == object:
if values.dtype.kind in 'OU':
try:
return _map_to_integer(values, uniques)
except KeyError as e:
Expand Down

0 comments on commit 8e22443

Please sign in to comment.