Initial support for one hot split. #5949

Closed · wants to merge 1 commit
4 changes: 3 additions & 1 deletion include/xgboost/feature_map.h
@@ -82,7 +82,9 @@ class FeatureMap {
if (!strcmp("q", tname)) return kQuantitive;
if (!strcmp("int", tname)) return kInteger;
if (!strcmp("float", tname)) return kFloat;
LOG(FATAL) << "unknown feature type, use i for indicator and q for quantity";
if (!strcmp("categorical", tname)) return kInteger;
LOG(FATAL) << "unknown feature type, use i for indicator, q for quantity "
"and categorical for categorical split.";
return kIndicator;
}
/*! \brief name of the feature */
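Note: the type codes parsed here ("i", "q", "int", "float", and now "categorical") are the same kind of per-feature type strings exposed through the Python package's `feature_types` argument. A minimal sketch of passing them explicitly; the data values below are illustrative, not from this PR:

import numpy as np
import xgboost as xgb

X = np.array([[1.0, 0.5], [0.0, 1.5], [1.0, 2.5]], dtype=np.float32)
y = np.array([0, 1, 0])

# "i" marks an indicator feature, "q" a quantitative one; this PR adds
# "categorical" as a recognized type name.
dtrain = xgb.DMatrix(X, label=y, feature_types=["i", "q"])
print(dtrain.feature_types)  # ['i', 'q']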
17 changes: 15 additions & 2 deletions python-package/xgboost/core.py
@@ -384,7 +384,8 @@ def __init__(self, data, label=None, weight=None, base_margin=None,
silent=False,
feature_names=None,
feature_types=None,
nthread=None):
nthread=None,
enable_categorical=False):
"""Parameters
----------
data : os.PathLike/string/numpy.array/scipy.sparse/pd.DataFrame/
@@ -419,6 +420,17 @@ def __init__(self, data, label=None, weight=None, base_margin=None,
Number of threads to use for loading data when parallelization is
applicable. If -1, uses maximum threads available on the system.

enable_categorical: boolean, optional

.. versionadded:: 1.3.0

Experimental support for specializing categorical features. Do not
set to True unless you are interested in development. Currently it
is only available for the `gpu_hist` tree method with 1 vs rest
(one hot) categorical splits. Also, the JSON serialization format
(`enable_experimental_json_serialization`), `gpu_predictor` and
pandas input are required.

"""
if isinstance(data, list):
raise TypeError('Input data can not be a list.')
@@ -437,7 +449,8 @@ def __init__(self, data, label=None, weight=None, base_margin=None,
data, missing=self.missing,
threads=self.nthread,
feature_names=feature_names,
feature_types=feature_types)
feature_types=feature_types,
enable_categorical=enable_categorical)
assert handle is not None
self.handle = handle

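Taken together with the docstring above, intended usage looks roughly like the sketch below. The frame, labels, and round count are illustrative, and the parameter set simply follows the docstring's stated requirements (`gpu_hist`, `gpu_predictor`, experimental JSON serialization), so exact spellings may vary by release:

import pandas as pd
import xgboost as xgb

# A toy frame where "cat" uses pandas' categorical dtype. Per this PR,
# category values must still be convertible to float, hence numeric codes.
df = pd.DataFrame({
    "cat": pd.Series([1, 2, 1, 2], dtype="category"),
    "num": [1.0, 2.0, 3.0, 4.0],
})
y = [0, 1, 0, 1]

# Without enable_categorical=True, the categorical dtype is rejected.
dtrain = xgb.DMatrix(df, label=y, enable_categorical=True)

params = {
    "tree_method": "gpu_hist",
    "predictor": "gpu_predictor",
    "enable_experimental_json_serialization": True,
}
booster = xgb.train(params, dtrain, num_boost_round=10)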
36 changes: 25 additions & 11 deletions python-package/xgboost/data.py
@@ -175,20 +175,24 @@ def _is_modin_df(data):
}


def _transform_pandas_df(data, feature_names=None, feature_types=None,
def _transform_pandas_df(data, enable_categorical,
feature_names=None, feature_types=None,
meta=None, meta_type=None):
from pandas import MultiIndex, Int64Index
from pandas.api.types import is_sparse
from pandas.api.types import is_sparse, is_categorical

data_dtypes = data.dtypes
if not all(dtype.name in _pandas_dtype_mapper or is_sparse(dtype)
if not all(dtype.name in _pandas_dtype_mapper or is_sparse(dtype) or
(is_categorical(dtype) and enable_categorical)
for dtype in data_dtypes):
bad_fields = [
str(data.columns[i]) for i, dtype in enumerate(data_dtypes)
if dtype.name not in _pandas_dtype_mapper
]

msg = """DataFrame.dtypes for data must be int, float or bool.
Did not expect the data types in fields """
msg = """DataFrame.dtypes for data must be int, float, bool or categorical. When
categorical type is supplied, DMatrix parameter
`enable_categorical` must be set to `True`."""
raise ValueError(msg + ', '.join(bad_fields))

if feature_names is None and meta is None:
@@ -207,6 +211,8 @@ def _transform_pandas_df(data, feature_names=None, feature_types=None,
if is_sparse(dtype):
feature_types.append(_pandas_dtype_mapper[
dtype.subtype.name])
elif is_categorical(dtype) and enable_categorical:
feature_types.append('categorical')
else:
feature_types.append(_pandas_dtype_mapper[dtype.name])

@@ -215,15 +221,21 @@ def _transform_pandas_df(data, feature_names=None, feature_types=None,
'DataFrame for {meta} cannot have multiple columns'.format(
meta=meta))

dtype = meta_type if meta_type else np.float32
data = np.ascontiguousarray(data.values, dtype=dtype)
dtype = meta_type if meta_type else np.float32
try:
data = data.values.astype(dtype)
except ValueError as e:
raise ValueError('Data must be convertible to float, even ' +
'for categorical data.') from e

return data, feature_names, feature_types


def _from_pandas_df(data, missing, nthread, feature_names, feature_types):
def _from_pandas_df(data, enable_categorical, missing, nthread,
feature_names, feature_types):
data, feature_names, feature_types = _transform_pandas_df(
data, feature_names, feature_types)
data, enable_categorical, feature_names, feature_types)
return _from_numpy_array(data, missing, nthread, feature_names,
feature_types)

@@ -498,7 +510,8 @@ def _has_array_protocol(data):


def dispatch_data_backend(data, missing, threads,
feature_names, feature_types):
feature_names, feature_types,
enable_categorical=False):
'''Dispatch data for DMatrix.'''
if _is_scipy_csr(data):
return _from_scipy_csr(data, missing, feature_names, feature_types)
@@ -514,7 +527,7 @@ def dispatch_data_backend(data, missing, threads,
if _is_tuple(data):
return _from_tuple(data, missing, feature_names, feature_types)
if _is_pandas_df(data):
return _from_pandas_df(data, missing, threads,
return _from_pandas_df(data, enable_categorical, missing, threads,
feature_names, feature_types)
if _is_pandas_series(data):
return _from_pandas_series(data, missing, threads, feature_names,
@@ -644,7 +657,8 @@ def dispatch_meta_backend(matrix: DMatrix, data, name: str, dtype: str = None):
_meta_from_numpy(data, name, dtype, handle)
return
if _is_pandas_df(data):
data, _, _ = _transform_pandas_df(data, meta=name, meta_type=dtype)
data, _, _ = _transform_pandas_df(data, False, meta=name,
meta_type=dtype)
_meta_from_numpy(data, name, dtype, handle)
return
if _is_pandas_series(data):
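To see what `_transform_pandas_df` now accepts, here is a standalone sketch of the dtype gating and conversion logic. This is not the PR's code: `is_categorical_dtype` stands in for the `is_categorical` import above, and the toy frame is illustrative:

import numpy as np
import pandas as pd
from pandas.api.types import is_categorical_dtype

df = pd.DataFrame({"c": pd.Series([1, 2, 1], dtype="category"),
                   "f": [0.5, 1.5, 2.5]})

enable_categorical = True
for name, dtype in df.dtypes.items():
    # Categorical columns pass the check only when the flag is set.
    if is_categorical_dtype(dtype) and enable_categorical:
        print(name, "-> feature type 'categorical'")

# The raw category values must still convert to float; string categories
# would raise ValueError here, matching the error path above.
data = df.values.astype(np.float32)
print(data.dtype)  # float32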
16 changes: 16 additions & 0 deletions src/common/device_helpers.cuh
@@ -80,6 +80,11 @@ struct AtomicDispatcher<sizeof(uint64_t)> {
using Type = unsigned long long; // NOLINT
static_assert(sizeof(Type) == sizeof(uint64_t), "Unsigned long long should be of size 64 bits.");
};

template <>
struct AtomicDispatcher<sizeof(uint8_t)> {
using Type = uint8_t; // NOLINT
};
} // namespace detail
} // namespace dh

@@ -536,6 +541,17 @@ void CopyDeviceSpanToVector(std::vector<T> *dst, xgboost::common::Span<const T>
cudaMemcpyDeviceToHost));
}

template <class HContainer, class DContainer>
void CopyToD(HContainer const &h, DContainer *d) {
d->resize(h.size());
using HVT = std::remove_cv_t<typename HContainer::value_type>;
using DVT = std::remove_cv_t<typename DContainer::value_type>;
static_assert(std::is_same<HVT, DVT>::value,
"Host and device containers must have same value type.");
dh::safe_cuda(cudaMemcpyAsync(d->data().get(), h.data(), h.size() * sizeof(HVT),
cudaMemcpyHostToDevice));
}

// Keep track of pinned memory allocation
struct PinnedMemory {
void *temp_storage{nullptr};
2 changes: 1 addition & 1 deletion src/common/hist_util.cu
@@ -178,7 +178,7 @@ void ProcessBatch(int device, MetaInfo const &info, const SparsePage &page,
dh::XGBCachingDeviceAllocator<char> alloc;
const auto& host_data = page.data.ConstHostVector();
dh::device_vector<Entry> sorted_entries(host_data.begin() + begin,
host_data.begin() + end);
host_data.begin() + end);
thrust::sort(thrust::cuda::par(alloc), sorted_entries.begin(),
sorted_entries.end(), detail::EntryCompareOp());

2 changes: 2 additions & 0 deletions src/common/host_device_vector.cc
@@ -10,6 +10,7 @@
#include <cstdint>
#include <memory>
#include <utility>
#include "xgboost/tree_model.h"
#include "xgboost/host_device_vector.h"

namespace xgboost {
@@ -176,6 +177,7 @@ template class HostDeviceVector<FeatureType>;
template class HostDeviceVector<Entry>;
template class HostDeviceVector<uint64_t>; // bst_row_t
template class HostDeviceVector<uint32_t>; // bst_feature_t
template class HostDeviceVector<RegTree::Segment>;

#if defined(__APPLE__)
/*
1 change: 1 addition & 0 deletions src/common/host_device_vector.cu
@@ -404,6 +404,7 @@ template class HostDeviceVector<Entry>;
template class HostDeviceVector<uint64_t>; // bst_row_t
template class HostDeviceVector<uint32_t>; // bst_feature_t
template class HostDeviceVector<RegTree::Node>;
template class HostDeviceVector<RegTree::Segment>;
template class HostDeviceVector<RTreeNodeStat>;

#if defined(__APPLE__)
2 changes: 1 addition & 1 deletion src/data/adapter.h
@@ -68,7 +68,7 @@ namespace data {
/** \brief An adapter can return this value for number of rows or columns
* indicating that this value is currently unknown and should be inferred while
* passing over the data. */
constexpr size_t kAdapterUnknownSize = std::numeric_limits<size_t >::max();
constexpr size_t kAdapterUnknownSize = std::numeric_limits<bst_row_t>::max();

struct COOTuple {
COOTuple() = default;
3 changes: 3 additions & 0 deletions src/data/iterative_device_dmatrix.cu
@@ -98,6 +98,9 @@ void IterativeDeviceDMatrix::Initialize(DataIterHandle iter_handle, float missin
}));
nnz += thrust::reduce(thrust::cuda::par(alloc), row_counts.begin(),
row_counts.end());

this->Info().feature_types.Resize(proxy->Info().feature_types.Size());
this->Info().feature_types.Copy(proxy->Info().feature_types);
batches++;
}
iter.Reset();