Skip to content

Commit

Permalink
Support pandas 2.1.0. (#9557)
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis authored Sep 11, 2023
1 parent 66a0832 commit 9027686
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 27 deletions.
66 changes: 43 additions & 23 deletions python-package/xgboost/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,6 @@ def pandas_feature_info(
) -> Tuple[Optional[FeatureNames], Optional[FeatureTypes]]:
"""Handle feature info for pandas dataframe."""
import pandas as pd
from pandas.api.types import is_categorical_dtype, is_sparse

# handle feature names
if feature_names is None and meta is None:
Expand All @@ -332,10 +331,10 @@ def pandas_feature_info(
if feature_types is None and meta is None:
feature_types = []
for dtype in data.dtypes:
if is_sparse(dtype):
if is_pd_sparse_dtype(dtype):
feature_types.append(_pandas_dtype_mapper[dtype.subtype.name])
elif (
is_categorical_dtype(dtype) or is_pa_ext_categorical_dtype(dtype)
is_pd_cat_dtype(dtype) or is_pa_ext_categorical_dtype(dtype)
) and enable_categorical:
feature_types.append(CAT_T)
else:
Expand All @@ -345,18 +344,13 @@ def pandas_feature_info(

def is_nullable_dtype(dtype: PandasDType) -> bool:
"""Whether dtype is a pandas nullable type."""
from pandas.api.types import (
is_bool_dtype,
is_categorical_dtype,
is_float_dtype,
is_integer_dtype,
)
from pandas.api.types import is_bool_dtype, is_float_dtype, is_integer_dtype

is_int = is_integer_dtype(dtype) and dtype.name in pandas_nullable_mapper
# np.bool has alias `bool`, while pd.BooleanDtype has `boolean`.
is_bool = is_bool_dtype(dtype) and dtype.name == "boolean"
is_float = is_float_dtype(dtype) and dtype.name in pandas_nullable_mapper
return is_int or is_bool or is_float or is_categorical_dtype(dtype)
return is_int or is_bool or is_float or is_pd_cat_dtype(dtype)


def is_pa_ext_dtype(dtype: Any) -> bool:
Expand All @@ -371,17 +365,48 @@ def is_pa_ext_categorical_dtype(dtype: Any) -> bool:
)


def is_pd_cat_dtype(dtype: PandasDType) -> bool:
"""Wrapper for testing pandas category type."""
import pandas as pd

if hasattr(pd.util, "version") and hasattr(pd.util.version, "Version"):
Version = pd.util.version.Version
if Version(pd.__version__) >= Version("2.1.0"):
from pandas import CategoricalDtype

return isinstance(dtype, CategoricalDtype)

from pandas.api.types import is_categorical_dtype

return is_categorical_dtype(dtype)


def is_pd_sparse_dtype(dtype: PandasDType) -> bool:
"""Wrapper for testing pandas sparse type."""
import pandas as pd

if hasattr(pd.util, "version") and hasattr(pd.util.version, "Version"):
Version = pd.util.version.Version
if Version(pd.__version__) >= Version("2.1.0"):
from pandas import SparseDtype

return isinstance(dtype, SparseDtype)

from pandas.api.types import is_sparse

return is_sparse(dtype)


def pandas_cat_null(data: DataFrame) -> DataFrame:
"""Handle categorical dtype and nullable extension types from pandas."""
import pandas as pd
from pandas.api.types import is_categorical_dtype

# handle category codes and nullable.
cat_columns = []
nul_columns = []
# avoid an unnecessary conversion if possible
for col, dtype in zip(data.columns, data.dtypes):
if is_categorical_dtype(dtype):
if is_pd_cat_dtype(dtype):
cat_columns.append(col)
elif is_pa_ext_categorical_dtype(dtype):
raise ValueError(
Expand All @@ -398,7 +423,7 @@ def pandas_cat_null(data: DataFrame) -> DataFrame:
transformed = data

def cat_codes(ser: pd.Series) -> pd.Series:
if is_categorical_dtype(ser.dtype):
if is_pd_cat_dtype(ser.dtype):
return ser.cat.codes
assert is_pa_ext_categorical_dtype(ser.dtype)
# Not yet supported, the index is not ordered for some reason. Alternately:
Expand Down Expand Up @@ -454,14 +479,12 @@ def _transform_pandas_df(
meta: Optional[str] = None,
meta_type: Optional[NumpyDType] = None,
) -> Tuple[np.ndarray, Optional[FeatureNames], Optional[FeatureTypes]]:
from pandas.api.types import is_categorical_dtype, is_sparse

pyarrow_extension = False
for dtype in data.dtypes:
if not (
(dtype.name in _pandas_dtype_mapper)
or is_sparse(dtype)
or (is_categorical_dtype(dtype) and enable_categorical)
or is_pd_sparse_dtype(dtype)
or (is_pd_cat_dtype(dtype) and enable_categorical)
or is_pa_ext_dtype(dtype)
):
_invalid_dataframe_dtype(data)
Expand Down Expand Up @@ -515,9 +538,8 @@ def _meta_from_pandas_series(
) -> None:
"""Help transform pandas series for meta data like labels"""
data = data.values.astype("float")
from pandas.api.types import is_sparse

if is_sparse(data):
if is_pd_sparse_dtype(getattr(data, "dtype", data)):
data = data.to_dense() # type: ignore
assert len(data.shape) == 1 or data.shape[1] == 0 or data.shape[1] == 1
_meta_from_numpy(data, name, dtype, handle)
Expand All @@ -539,13 +561,11 @@ def _from_pandas_series(
feature_names: Optional[FeatureNames],
feature_types: Optional[FeatureTypes],
) -> DispatchedDataBackendReturnType:
from pandas.api.types import is_categorical_dtype

if (data.dtype.name not in _pandas_dtype_mapper) and not (
is_categorical_dtype(data.dtype) and enable_categorical
is_pd_cat_dtype(data.dtype) and enable_categorical
):
_invalid_dataframe_dtype(data)
if enable_categorical and is_categorical_dtype(data.dtype):
if enable_categorical and is_pd_cat_dtype(data.dtype):
data = data.cat.codes
return _from_numpy_array(
data.values.reshape(data.shape[0], 1).astype("float"),
Expand Down
8 changes: 4 additions & 4 deletions tests/python/test_with_pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def test_pandas_weight(self):
y = np.random.randn(kRows)
w = np.random.uniform(size=kRows).astype(np.float32)
w_pd = pd.DataFrame(w)
data = xgb.DMatrix(X, y, w_pd)
data = xgb.DMatrix(X, y, weight=w_pd)

assert data.num_row() == kRows
assert data.num_col() == kCols
Expand Down Expand Up @@ -301,14 +301,14 @@ def test_cv_as_pandas(self):

@pytest.mark.parametrize("DMatrixT", [xgb.DMatrix, xgb.QuantileDMatrix])
def test_nullable_type(self, DMatrixT) -> None:
from pandas.api.types import is_categorical_dtype
from xgboost.data import is_pd_cat_dtype

for orig, df in pd_dtypes():
if hasattr(df.dtypes, "__iter__"):
enable_categorical = any(is_categorical_dtype for dtype in df.dtypes)
enable_categorical = any(is_pd_cat_dtype(dtype) for dtype in df.dtypes)
else:
# series
enable_categorical = is_categorical_dtype(df.dtype)
enable_categorical = is_pd_cat_dtype(df.dtype)

f0_orig = orig[orig.columns[0]] if isinstance(orig, pd.DataFrame) else orig
f0 = df[df.columns[0]] if isinstance(df, pd.DataFrame) else df
Expand Down

0 comments on commit 9027686

Please sign in to comment.