Skip to content

Commit

Permalink
[python] Faster categorical column names selection (#4787)
Browse files Browse the repository at this point in the history
* Faster categorical column names selection (#1)

* Faster categorical column names selection

Replace the slow and redundant `DataFrame.select_dtypes` query with a list comprehension over `DataFrame.dtypes`

* Update compat with CategoricalDtype

* sort imports

* import CategoricalDtype from pandas.api.types

* add categorical import try/except
  • Loading branch information
Neronuser authored Nov 12, 2021
1 parent 3b6ebd7 commit 6cbb358
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 2 deletions.
5 changes: 3 additions & 2 deletions python-package/lightgbm/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
import numpy as np
import scipy.sparse

from .compat import PANDAS_INSTALLED, concat, dt_DataTable, is_dtype_sparse, pd_DataFrame, pd_Series
from .compat import (PANDAS_INSTALLED, concat, dt_DataTable, is_dtype_sparse, pd_CategoricalDtype, pd_DataFrame,
pd_Series)
from .libpath import find_lib_path

ZERO_THRESHOLD = 1e-35
Expand Down Expand Up @@ -567,7 +568,7 @@ def _data_from_pandas(data, feature_name, categorical_feature, pandas_categorica
raise ValueError('Input data must be 2 dimensional and non empty.')
if feature_name == 'auto' or feature_name is None:
data = data.rename(columns=str)
cat_cols = list(data.select_dtypes(include=['category']).columns)
cat_cols = [col for col, dtype in zip(data.columns, data.dtypes) if isinstance(dtype, pd_CategoricalDtype)]
cat_cols_not_ordered = [col for col in cat_cols if not data[col].cat.ordered]
if pandas_categorical is None: # train dataset
pandas_categorical = [list(data[col].cat.categories) for col in cat_cols]
Expand Down
10 changes: 10 additions & 0 deletions python-package/lightgbm/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@
from pandas import Series as pd_Series
from pandas import concat
from pandas.api.types import is_sparse as is_dtype_sparse
try:
from pandas import CategoricalDtype as pd_CategoricalDtype
except ImportError:
from pandas.api.types import CategoricalDtype as pd_CategoricalDtype
PANDAS_INSTALLED = True
except ImportError:
PANDAS_INSTALLED = False
Expand All @@ -23,6 +27,12 @@ class pd_DataFrame: # type: ignore
def __init__(self, *args, **kwargs):
pass

class pd_CategoricalDtype:
    """Dummy class for pandas.CategoricalDtype.

    Placeholder bound to the ``pd_CategoricalDtype`` name so that code
    referring to it (for example ``isinstance(dtype, pd_CategoricalDtype)``
    checks) can still be imported and evaluated when the real pandas class
    is not available. Instances carry no state.
    """

    def __init__(self, *args, **kwargs):
        # Accept and ignore any arguments: this placeholder has nothing to configure.
        pass

concat = None
is_dtype_sparse = None

Expand Down

0 comments on commit 6cbb358

Please sign in to comment.