diff --git a/python/ray/data/preprocessors/encoder.py b/python/ray/data/preprocessors/encoder.py index f9a0990e3989..dc3a156f77c5 100644 --- a/python/ray/data/preprocessors/encoder.py +++ b/python/ray/data/preprocessors/encoder.py @@ -2,6 +2,7 @@ from typing import List, Dict, Optional from collections import Counter, OrderedDict +import numpy as np import pandas as pd import pandas.api.types @@ -324,7 +325,9 @@ def _transform_pandas(self, df: pd.DataFrame): _validate_df(df, *self.columns) def encode_list(element: list, *, name: str): - if not isinstance(element, list): + if isinstance(element, np.ndarray): + element = element.tolist() + elif not isinstance(element, list): element = [element] stats = self.stats_[f"unique_values({name})"] counter = Counter(element) @@ -509,6 +512,7 @@ def _get_unique_value_indices( encode_lists: bool = True, ) -> Dict[str, Dict[str, int]]: """If drop_na_values is True, will silently drop NA values.""" + if max_categories is None: max_categories = {} for column in max_categories: @@ -601,5 +605,5 @@ def _is_series_composed_of_lists(series: pd.Series) -> bool: (element for element in series if element is not None), None ) return pandas.api.types.is_object_dtype(series.dtype) and isinstance( - first_not_none_element, list + first_not_none_element, (list, np.ndarray) ) diff --git a/python/ray/data/tests/preprocessors/test_encoder.py b/python/ray/data/tests/preprocessors/test_encoder.py index eacc586612c8..a9e1d374b1f4 100644 --- a/python/ray/data/tests/preprocessors/test_encoder.py +++ b/python/ray/data/tests/preprocessors/test_encoder.py @@ -416,6 +416,14 @@ def test_multi_hot_encoder(): null_encoder.transform_batch(null_df) null_encoder.transform_batch(nonnull_df) + # Verify that `fit` and `transform` work with ndarrays. + df = pd.DataFrame({"column": [np.array(["A"]), np.array(["A", "B"])]}) + ds = ray.data.from_pandas(df) + encoder = MultiHotEncoder(["column"]) + transformed = encoder.fit_transform(ds) + encodings = [record["column"] for record in transformed.take_all()] + assert encodings == [[1, 0], [1, 1]] + def test_multi_hot_encoder_with_max_categories(): """Tests basic MultiHotEncoder functionality with limit."""