Initial support for one hot categorical split.
trivialfis committed Sep 20, 2020
1 parent 20c95be commit cb77e3a
Showing 42 changed files with 1,038 additions and 249 deletions.
4 changes: 3 additions & 1 deletion include/xgboost/feature_map.h
@@ -82,7 +82,9 @@ class FeatureMap {
if (!strcmp("q", tname)) return kQuantitive;
if (!strcmp("int", tname)) return kInteger;
if (!strcmp("float", tname)) return kFloat;
LOG(FATAL) << "unknown feature type, use i for indicator and q for quantity";
if (!strcmp("categorical", tname)) return kInteger;
LOG(FATAL) << "unknown feature type, use i for indicator, q for quantity "
"and categorical for categorical split.";
return kIndicator;
}
/*! \brief name of the feature */
2 changes: 1 addition & 1 deletion include/xgboost/version_config.h
@@ -5,7 +5,7 @@
#define XGBOOST_VERSION_CONFIG_H_

#define XGBOOST_VER_MAJOR 1
#define XGBOOST_VER_MINOR 2
#define XGBOOST_VER_MINOR 3
#define XGBOOST_VER_PATCH 0

#endif // XGBOOST_VERSION_CONFIG_H_
17 changes: 15 additions & 2 deletions python-package/xgboost/core.py
@@ -382,7 +382,8 @@ def __init__(self, data, label=None, weight=None, base_margin=None,
silent=False,
feature_names=None,
feature_types=None,
nthread=None):
nthread=None,
enable_categorical=False):
"""Parameters
----------
data : os.PathLike/string/numpy.array/scipy.sparse/pd.DataFrame/
@@ -417,6 +418,17 @@ def __init__(self, data, label=None, weight=None, base_margin=None,
Number of threads to use for loading data when parallelization is
applicable. If -1, uses maximum threads available on the system.
enable_categorical: boolean, optional
.. versionadded:: 1.3.0
Experimental support of specializing for categorical features. Do
not set to True unless you are interested in development.
Currently it's only available for `gpu_hist` tree method with 1 vs
rest (one hot) categorical split. Also, JSON serialization format,
`enable_experimental_json_serialization`, `gpu_predictor` and
pandas input are required.
"""
if isinstance(data, list):
raise TypeError('Input data can not be a list.')
@@ -435,7 +447,8 @@ def __init__(self, data, label=None, weight=None, base_margin=None,
data, missing=self.missing,
threads=self.nthread,
feature_names=feature_names,
feature_types=feature_types)
feature_types=feature_types,
enable_categorical=enable_categorical)
assert handle is not None
self.handle = handle

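For orientation, below is a minimal sketch of how the new `enable_categorical` flag could be exercised from the Python side, under the constraints listed in the docstring above (`gpu_hist`, JSON serialization, `gpu_predictor`, pandas input). The data, column names, and parameter values are illustrative only and are not part of this commit.

```python
import pandas as pd
import xgboost as xgb

# Toy frame: "cat_feature" uses pandas' categorical dtype with numeric
# categories, so the values stay convertible to float.
X = pd.DataFrame({
    "cat_feature": pd.Series([0, 1, 2, 1, 0], dtype="category"),
    "num_feature": [0.5, 1.5, 2.5, 3.5, 4.5],
})
y = [0, 1, 1, 0, 1]

# Without enable_categorical=True the categorical dtype is rejected by the
# pandas dispatch path changed in data.py below.
dtrain = xgb.DMatrix(X, label=y, enable_categorical=True)

params = {
    "tree_method": "gpu_hist",                       # only supported tree method
    "predictor": "gpu_predictor",                    # required predictor
    "enable_experimental_json_serialization": True,  # required serialization mode
    "objective": "binary:logistic",
}
booster = xgb.train(params, dtrain, num_boost_round=10)
```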
35 changes: 24 additions & 11 deletions python-package/xgboost/data.py
@@ -168,20 +168,24 @@ def _is_pandas_df(data):
}


def _transform_pandas_df(data, feature_names=None, feature_types=None,
def _transform_pandas_df(data, enable_categorical,
feature_names=None, feature_types=None,
meta=None, meta_type=None):
from pandas import MultiIndex, Int64Index
from pandas.api.types import is_sparse
from pandas.api.types import is_sparse, is_categorical

data_dtypes = data.dtypes
if not all(dtype.name in _pandas_dtype_mapper or is_sparse(dtype)
if not all(dtype.name in _pandas_dtype_mapper or is_sparse(dtype) or
(is_categorical(dtype) and enable_categorical)
for dtype in data_dtypes):
bad_fields = [
str(data.columns[i]) for i, dtype in enumerate(data_dtypes)
if dtype.name not in _pandas_dtype_mapper
]

msg = """DataFrame.dtypes for data must be int, float or bool.
Did not expect the data types in fields """
msg = """DataFrame.dtypes for data must be int, float, bool or categorical. When
categorical type is supplied, DMatrix parameter
`enable_categorical` must be set to `True`."""
raise ValueError(msg + ', '.join(bad_fields))

if feature_names is None and meta is None:
@@ -200,6 +204,8 @@ def _transform_pandas_df(data, feature_names=None, feature_types=None,
if is_sparse(dtype):
feature_types.append(_pandas_dtype_mapper[
dtype.subtype.name])
elif is_categorical(dtype) and enable_categorical:
feature_types.append('categorical')
else:
feature_types.append(_pandas_dtype_mapper[dtype.name])

@@ -209,14 +215,19 @@ def _transform_pandas_df(data, feature_names=None, feature_types=None,
meta=meta))

dtype = meta_type if meta_type else 'float'
data = data.values.astype(dtype)
try:
data = data.values.astype(dtype)
except ValueError as e:
raise ValueError('Data must be convertible to float, even ' +
'for categorical data.') from e

return data, feature_names, feature_types


def _from_pandas_df(data, missing, nthread, feature_names, feature_types):
def _from_pandas_df(data, enable_categorical, missing, nthread,
feature_names, feature_types):
data, feature_names, feature_types = _transform_pandas_df(
data, feature_names, feature_types)
data, enable_categorical, feature_names, feature_types)
return _from_numpy_array(data, missing, nthread, feature_names,
feature_types)

@@ -484,7 +495,8 @@ def _has_array_protocol(data):


def dispatch_data_backend(data, missing, threads,
feature_names, feature_types):
feature_names, feature_types,
enable_categorical=False):
'''Dispatch data for DMatrix.'''
if _is_scipy_csr(data):
return _from_scipy_csr(data, missing, feature_names, feature_types)
@@ -500,7 +512,7 @@ def dispatch_data_backend(data, missing, threads,
if _is_tuple(data):
return _from_tuple(data, missing, feature_names, feature_types)
if _is_pandas_df(data):
return _from_pandas_df(data, missing, threads,
return _from_pandas_df(data, enable_categorical, missing, threads,
feature_names, feature_types)
if _is_pandas_series(data):
return _from_pandas_series(data, missing, threads, feature_names,
@@ -624,7 +636,8 @@ def dispatch_meta_backend(matrix: DMatrix, data, name: str, dtype: str = None):
_meta_from_numpy(data, name, dtype, handle)
return
if _is_pandas_df(data):
data, _, _ = _transform_pandas_df(data, meta=name, meta_type=dtype)
data, _, _ = _transform_pandas_df(data, False, meta=name,
meta_type=dtype)
_meta_from_numpy(data, name, dtype, handle)
return
if _is_pandas_series(data):
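A rough Python sketch of the dtype gate added to `_transform_pandas_df` above: a column passes if it has a plain numeric/bool dtype, a sparse dtype, or a categorical dtype with `enable_categorical` set; accepted categorical columns are tagged with the `'categorical'` feature type but must still convert to float. This is a simplified stand-in, not the code from the commit.

```python
import pandas as pd
from pandas.api.types import is_sparse, is_categorical

# Simplified subset of xgboost's pandas dtype mapping.
_numeric_dtypes = {"int8", "int16", "int32", "int64", "uint8", "uint16",
                   "uint32", "uint64", "float16", "float32", "float64", "bool"}

def column_ok(dtype, enable_categorical):
    """Mirror of the acceptance check in _transform_pandas_df."""
    return (dtype.name in _numeric_dtypes
            or is_sparse(dtype)
            or (is_categorical(dtype) and enable_categorical))

df = pd.DataFrame({"c": pd.Series(["a", "b", "a"], dtype="category")})
print(column_ok(df.dtypes["c"], enable_categorical=True))   # True
print(column_ok(df.dtypes["c"], enable_categorical=False))  # False

# Even when accepted, string categories such as "a"/"b" still fail later at
# data.values.astype(float), which is why the commit wraps that call in a
# try/except with a clearer error message.
```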
16 changes: 16 additions & 0 deletions src/common/device_helpers.cuh
@@ -80,6 +80,11 @@ struct AtomicDispatcher<sizeof(uint64_t)> {
using Type = unsigned long long; // NOLINT
static_assert(sizeof(Type) == sizeof(uint64_t), "Unsigned long long should be of size 64 bits.");
};

template <>
struct AtomicDispatcher<sizeof(uint8_t)> {
using Type = uint8_t; // NOLINT
};
} // namespace detail
} // namespace dh

@@ -522,6 +527,17 @@ void CopyDeviceSpanToVector(std::vector<T> *dst, xgboost::common::Span<const T>
cudaMemcpyDeviceToHost));
}

template <class HContainer, class DContainer>
void CopyToD(HContainer const &h, DContainer *d) {
d->resize(h.size());
using HVT = std::remove_cv_t<typename HContainer::value_type>;
using DVT = std::remove_cv_t<typename DContainer::value_type>;
static_assert(std::is_same<HVT, DVT>::value,
"Host and device containers must have same value type.");
dh::safe_cuda(cudaMemcpyAsync(d->data().get(), h.data(), h.size() * sizeof(HVT),
cudaMemcpyHostToDevice));
}

// Keep track of pinned memory allocation
struct PinnedMemory {
void *temp_storage{nullptr};
73 changes: 59 additions & 14 deletions src/common/hist_util.cu
@@ -24,6 +24,7 @@
#include "hist_util.cuh"
#include "math.h" // NOLINT
#include "quantile.h"
#include "categorical.h"
#include "xgboost/host_device_vector.h"


@@ -36,6 +37,7 @@ namespace detail {

// Count the entries in each column and exclusive scan
void ExtractCutsSparse(int device, common::Span<SketchContainer::OffsetT const> cuts_ptr,
common::Span<FeatureType const> feature_types,
Span<Entry const> sorted_data,
Span<size_t const> column_sizes_scan,
Span<SketchEntry> out_cuts) {
@@ -48,10 +50,16 @@ void ExtractCutsSparse(int device, common::Span<SketchContainer::OffsetT const>
size_t cut_idx = idx - cuts_ptr[column_idx];
Span<Entry const> column_entries =
sorted_data.subspan(column_sizes_scan[column_idx], column_size);
size_t rank = (column_entries.size() * cut_idx) /
static_cast<float>(num_available_cuts);
out_cuts[idx] = WQSketch::Entry(rank, rank + 1, 1,
column_entries[rank].fvalue);
if (IsCat(feature_types, column_idx)) {
size_t rank = cut_idx;
out_cuts[idx] = WQSketch::Entry(rank, rank + 1, 1,
column_entries[rank].fvalue);
} else {
size_t rank = (column_entries.size() * cut_idx) /
static_cast<float>(num_available_cuts);
out_cuts[idx] = WQSketch::Entry(rank, rank + 1, 1,
column_entries[rank].fvalue);
}
});
}

@@ -196,13 +204,13 @@ void SortByWeight(dh::XGBCachingDeviceAllocator<char>* alloc,
}
} // namespace detail

void ProcessBatch(int device, const SparsePage &page, size_t begin, size_t end,
SketchContainer *sketch_container, int num_cuts_per_feature,
size_t num_columns) {
void ProcessBatch(int device, DMatrix const *m, const SparsePage &page,
size_t begin, size_t end, SketchContainer *sketch_container,
int num_cuts_per_feature, size_t num_columns) {
dh::XGBCachingDeviceAllocator<char> alloc;
const auto& host_data = page.data.ConstHostVector();
dh::device_vector<Entry> sorted_entries(host_data.begin() + begin,
host_data.begin() + end);
host_data.begin() + end);
thrust::sort(thrust::cuda::par(alloc), sorted_entries.begin(),
sorted_entries.end(), detail::EntryCompareOp());

@@ -219,13 +227,48 @@ void ProcessBatch(int device, const SparsePage &page, size_t begin, size_t end,
0, sorted_entries.size(),
&cuts_ptr, &column_sizes_scan);

// Removing duplicated entries in categorical features.
dh::caching_device_vector<size_t> new_column_scan(column_sizes_scan.size());
auto d_feature_types = m->Info().feature_types.ConstDeviceSpan();
auto n_uniques = dh::SegmentedUnique(
column_sizes_scan.data().get(),
column_sizes_scan.data().get() + column_sizes_scan.size(),
sorted_entries.begin(), sorted_entries.end(),
new_column_scan.data().get(), sorted_entries.begin(),
[=] __device__(Entry const &l, Entry const &r) {
if (l.index == r.index) {
if (IsCat(d_feature_types, l.index)) {
return l.fvalue == r.fvalue;
}
}
return false;
});

// Renew the column scan and cut scan based on categorical data.
dh::caching_device_vector<SketchContainer::OffsetT> new_cuts_size(num_columns + 1);
auto d_new_cuts_size = dh::ToSpan(new_cuts_size);
auto d_new_columns_ptr = dh::ToSpan(new_column_scan);
auto d_cuts_ptr = cuts_ptr.DeviceSpan();
CHECK_EQ(new_column_scan.size(), new_cuts_size.size());
dh::LaunchN(device, new_column_scan.size() - 1, [=] __device__(size_t idx) {
idx += 1;
if (IsCat(d_feature_types, idx - 1)) {
d_new_cuts_size[idx - 1] =
d_new_columns_ptr[idx] - d_new_columns_ptr[idx - 1];
} else {
d_new_cuts_size[idx - 1] = d_cuts_ptr[idx] - d_cuts_ptr[idx - 1];
}
});
thrust::exclusive_scan(thrust::device, new_cuts_size.cbegin(),
new_cuts_size.cend(), d_cuts_ptr.data());
auto const& h_cuts_ptr = cuts_ptr.ConstHostVector();
sorted_entries.resize(n_uniques);
dh::caching_device_vector<SketchEntry> cuts(h_cuts_ptr.back());
auto d_cuts_ptr = cuts_ptr.ConstDeviceSpan();

CHECK_EQ(d_cuts_ptr.size(), column_sizes_scan.size());
detail::ExtractCutsSparse(device, d_cuts_ptr, dh::ToSpan(sorted_entries),
dh::ToSpan(column_sizes_scan), dh::ToSpan(cuts));

detail::ExtractCutsSparse(device, d_cuts_ptr, d_feature_types,
dh::ToSpan(sorted_entries),
dh::ToSpan(new_column_scan), dh::ToSpan(cuts));

// add cuts into sketches
sorted_entries.clear();
@@ -313,7 +356,9 @@ HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins,
device, num_cuts_per_feature, has_weights);

HistogramCuts cuts;
SketchContainer sketch_container(max_bins, dmat->Info().num_col_,

dmat->Info().feature_types.SetDevice(device);
SketchContainer sketch_container(dmat->Info().feature_types, max_bins, dmat->Info().num_col_,
dmat->Info().num_row_, device);

dmat->Info().weights_.SetDevice(device);
@@ -333,7 +378,7 @@ HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins,
dmat->Info().num_col_,
is_ranking, dh::ToSpan(groups));
} else {
ProcessBatch(device, batch, begin, end, &sketch_container, num_cuts_per_feature,
ProcessBatch(device, dmat, batch, begin, end, &sketch_container, num_cuts_per_feature,
dmat->Info().num_col_);
}
}
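The net effect of the `ProcessBatch` changes above: after `SegmentedUnique` de-duplicates the sorted entries of categorical columns, each remaining unique value becomes its own sketch entry, while numerical columns keep rank-spaced quantile cuts. Below is a small NumPy sketch of that idea; it is conceptual only, not the GPU implementation, and the function name is hypothetical.

```python
import numpy as np

def column_cuts(values, is_categorical, num_available_cuts=16):
    """Conceptual mirror of ExtractCutsSparse for a single column."""
    sorted_vals = np.sort(values)
    if is_categorical:
        # Duplicates were removed (SegmentedUnique), so every unique
        # category value becomes a cut: rank == cut_idx.
        return np.unique(sorted_vals)
    # Numerical: spread num_available_cuts entries evenly by rank.
    n_cuts = min(num_available_cuts, len(sorted_vals))
    ranks = (len(sorted_vals) * np.arange(n_cuts)) // n_cuts
    return sorted_vals[ranks]

print(column_cuts(np.array([2.0, 0.0, 1.0, 2.0, 0.0]), is_categorical=True))
# [0. 1. 2.]  -> one candidate split per category (one-hot style splits)
print(column_cuts(np.linspace(0.0, 1.0, 1000), is_categorical=False,
                  num_available_cuts=4))
# four rank-spaced values from the sorted column
```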
2 changes: 2 additions & 0 deletions src/common/hist_util.cuh
@@ -38,6 +38,7 @@ struct EntryCompareOp {
* \param out_cuts Output cut values
*/
void ExtractCutsSparse(int device, common::Span<SketchContainer::OffsetT const> cuts_ptr,
common::Span<FeatureType const> feature_types,
Span<Entry const> sorted_data,
Span<size_t const> column_sizes_scan,
Span<SketchEntry> out_cuts);
@@ -189,6 +190,7 @@ void ProcessSlidingWindow(AdapterBatch const& batch, int device, size_t columns,
dh::caching_device_vector<SketchEntry> cuts(h_cuts_ptr.back());
// Extract the cuts from all columns concurrently
detail::ExtractCutsSparse(device, d_cuts_ptr,
{},
dh::ToSpan(sorted_entries),
dh::ToSpan(column_sizes_scan),
dh::ToSpan(cuts));
2 changes: 2 additions & 0 deletions src/common/host_device_vector.cc
@@ -10,6 +10,7 @@
#include <cstdint>
#include <memory>
#include <utility>
#include "xgboost/tree_model.h"
#include "xgboost/host_device_vector.h"

namespace xgboost {
@@ -176,6 +177,7 @@ template class HostDeviceVector<FeatureType>;
template class HostDeviceVector<Entry>;
template class HostDeviceVector<uint64_t>; // bst_row_t
template class HostDeviceVector<uint32_t>; // bst_feature_t
template class HostDeviceVector<RegTree::Segment>;

#if defined(__APPLE__)
/*
1 change: 1 addition & 0 deletions src/common/host_device_vector.cu
@@ -404,6 +404,7 @@ template class HostDeviceVector<Entry>;
template class HostDeviceVector<uint64_t>; // bst_row_t
template class HostDeviceVector<uint32_t>; // bst_feature_t
template class HostDeviceVector<RegTree::Node>;
template class HostDeviceVector<RegTree::Segment>;

#if defined(__APPLE__)
/*
2 changes: 1 addition & 1 deletion src/common/observer.h
@@ -71,7 +71,7 @@ class TrainingObserver {

for (size_t i = 0; i < h_vec.size(); ++i) {
OBSERVER_PRINT << h_vec[i] << ", ";
if (i % 8 == 0) {
if (i % 8 == 0 && i != 0) {
OBSERVER_PRINT << OBSERVER_NEWLINE;
}
if ((i + 1) == n) {