From cb77e3a56abbf77a26d3f215b0884ba92455082e Mon Sep 17 00:00:00 2001 From: fis Date: Fri, 24 Jul 2020 17:11:17 +0800 Subject: [PATCH] Initial support for one hot categorical split. --- include/xgboost/feature_map.h | 4 +- include/xgboost/version_config.h | 2 +- python-package/xgboost/core.py | 17 +- python-package/xgboost/data.py | 35 ++-- src/common/device_helpers.cuh | 16 ++ src/common/hist_util.cu | 73 ++++++-- src/common/hist_util.cuh | 2 + src/common/host_device_vector.cc | 2 + src/common/host_device_vector.cu | 1 + src/common/observer.h | 2 +- src/common/quantile.cu | 36 +++- src/common/quantile.cuh | 13 +- src/data/adapter.h | 2 +- src/data/ellpack_page.cu | 35 ++-- src/data/ellpack_page.cuh | 4 +- src/data/iterative_device_dmatrix.cu | 10 +- src/predictor/gpu_predictor.cu | 171 ++++++++++++++---- src/tree/gpu_hist/evaluate_splits.cu | 129 ++++++++++--- src/tree/gpu_hist/evaluate_splits.cuh | 1 + src/tree/gpu_hist/feature_groups.cu | 6 +- src/tree/gpu_hist/feature_groups.cuh | 22 +-- src/tree/tree_model.cc | 2 +- src/tree/updater_gpu_common.cuh | 4 + src/tree/updater_gpu_hist.cu | 148 +++++++++++---- tests/cpp/common/test_hist_util.cu | 43 ++++- tests/cpp/common/test_hist_util.h | 39 ++-- tests/cpp/common/test_quantile.cc | 15 +- tests/cpp/common/test_quantile.cu | 46 +++-- tests/cpp/common/test_quantile.h | 33 ++++ tests/cpp/data/test_ellpack_page.cu | 63 ++++++- tests/cpp/helpers.cc | 9 +- tests/cpp/helpers.h | 25 +++ tests/cpp/predictor/test_gpu_predictor.cu | 6 + tests/cpp/predictor/test_predictor.cc | 55 +++++- tests/cpp/predictor/test_predictor.h | 2 + .../cpp/tree/gpu_hist/test_evaluate_splits.cu | 25 ++- tests/cpp/tree/gpu_hist/test_histogram.cu | 80 +++++++- tests/cpp/tree/test_gpu_hist.cu | 30 ++- tests/python-gpu/test_gpu_updaters.py | 46 +++++ tests/python-gpu/test_gpu_with_dask.py | 3 +- tests/python/test_with_pandas.py | 19 +- tests/python/testing.py | 11 ++ 42 files changed, 1038 insertions(+), 249 deletions(-) diff --git a/include/xgboost/feature_map.h b/include/xgboost/feature_map.h index a48e28ba1bfa..d5ff431d64eb 100644 --- a/include/xgboost/feature_map.h +++ b/include/xgboost/feature_map.h @@ -82,7 +82,9 @@ class FeatureMap { if (!strcmp("q", tname)) return kQuantitive; if (!strcmp("int", tname)) return kInteger; if (!strcmp("float", tname)) return kFloat; - LOG(FATAL) << "unknown feature type, use i for indicator and q for quantity"; + if (!strcmp("categorical", tname)) return kInteger; + LOG(FATAL) << "unknown feature type, use i for indicator, q for quantity " + "and categorical for categorical split."; return kIndicator; } /*! 
\brief name of the feature */ diff --git a/include/xgboost/version_config.h b/include/xgboost/version_config.h index efab14e17d92..9c8585369b80 100644 --- a/include/xgboost/version_config.h +++ b/include/xgboost/version_config.h @@ -5,7 +5,7 @@ #define XGBOOST_VERSION_CONFIG_H_ #define XGBOOST_VER_MAJOR 1 -#define XGBOOST_VER_MINOR 2 +#define XGBOOST_VER_MINOR 3 #define XGBOOST_VER_PATCH 0 #endif // XGBOOST_VERSION_CONFIG_H_ diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index c8d0460825e5..02c12b9d9ea7 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -382,7 +382,8 @@ def __init__(self, data, label=None, weight=None, base_margin=None, silent=False, feature_names=None, feature_types=None, - nthread=None): + nthread=None, + enable_categorical=False): """Parameters ---------- data : os.PathLike/string/numpy.array/scipy.sparse/pd.DataFrame/ @@ -417,6 +418,17 @@ def __init__(self, data, label=None, weight=None, base_margin=None, Number of threads to use for loading data when parallelization is applicable. If -1, uses maximum threads available on the system. + enable_categorical: boolean, optional + + .. versionadded:: 1.3.0 + + Experimental support of specializing for categorical features. Do + not set to True unless you are interested in development. + Currently it's only available for `gpu_hist` tree method with 1 vs + rest (one hot) categorical split. Also, JSON serialization format, + `enable_experimental_json_serialization`, `gpu_predictor` and + pandas input are required. + """ if isinstance(data, list): raise TypeError('Input data can not be a list.') @@ -435,7 +447,8 @@ def __init__(self, data, label=None, weight=None, base_margin=None, data, missing=self.missing, threads=self.nthread, feature_names=feature_names, - feature_types=feature_types) + feature_types=feature_types, + enable_categorical=enable_categorical) assert handle is not None self.handle = handle diff --git a/python-package/xgboost/data.py b/python-package/xgboost/data.py index e4c05dcc244e..2384bd97f441 100644 --- a/python-package/xgboost/data.py +++ b/python-package/xgboost/data.py @@ -168,20 +168,24 @@ def _is_pandas_df(data): } -def _transform_pandas_df(data, feature_names=None, feature_types=None, +def _transform_pandas_df(data, enable_categorical, + feature_names=None, feature_types=None, meta=None, meta_type=None): from pandas import MultiIndex, Int64Index - from pandas.api.types import is_sparse + from pandas.api.types import is_sparse, is_categorical + data_dtypes = data.dtypes - if not all(dtype.name in _pandas_dtype_mapper or is_sparse(dtype) + if not all(dtype.name in _pandas_dtype_mapper or is_sparse(dtype) or + (is_categorical(dtype) and enable_categorical) for dtype in data_dtypes): bad_fields = [ str(data.columns[i]) for i, dtype in enumerate(data_dtypes) if dtype.name not in _pandas_dtype_mapper ] - msg = """DataFrame.dtypes for data must be int, float or bool. - Did not expect the data types in fields """ + msg = """DataFrame.dtypes for data must be int, float, bool or categorical. 
When + categorical type is supplied, DMatrix parameter + `enable_categorical` must be set to `True`.""" raise ValueError(msg + ', '.join(bad_fields)) if feature_names is None and meta is None: @@ -200,6 +204,8 @@ def _transform_pandas_df(data, feature_names=None, feature_types=None, if is_sparse(dtype): feature_types.append(_pandas_dtype_mapper[ dtype.subtype.name]) + elif is_categorical(dtype) and enable_categorical: + feature_types.append('categorical') else: feature_types.append(_pandas_dtype_mapper[dtype.name]) @@ -209,14 +215,19 @@ def _transform_pandas_df(data, feature_names=None, feature_types=None, meta=meta)) dtype = meta_type if meta_type else 'float' - data = data.values.astype(dtype) + try: + data = data.values.astype(dtype) + except ValueError as e: + raise ValueError('Data must be convertable to float, even ' + + 'for categorical data.') from e return data, feature_names, feature_types -def _from_pandas_df(data, missing, nthread, feature_names, feature_types): +def _from_pandas_df(data, enable_categorical, missing, nthread, + feature_names, feature_types): data, feature_names, feature_types = _transform_pandas_df( - data, feature_names, feature_types) + data, enable_categorical, feature_names, feature_types) return _from_numpy_array(data, missing, nthread, feature_names, feature_types) @@ -484,7 +495,8 @@ def _has_array_protocol(data): def dispatch_data_backend(data, missing, threads, - feature_names, feature_types): + feature_names, feature_types, + enable_categorical=False): '''Dispatch data for DMatrix.''' if _is_scipy_csr(data): return _from_scipy_csr(data, missing, feature_names, feature_types) @@ -500,7 +512,7 @@ def dispatch_data_backend(data, missing, threads, if _is_tuple(data): return _from_tuple(data, missing, feature_names, feature_types) if _is_pandas_df(data): - return _from_pandas_df(data, missing, threads, + return _from_pandas_df(data, enable_categorical, missing, threads, feature_names, feature_types) if _is_pandas_series(data): return _from_pandas_series(data, missing, threads, feature_names, @@ -624,7 +636,8 @@ def dispatch_meta_backend(matrix: DMatrix, data, name: str, dtype: str = None): _meta_from_numpy(data, name, dtype, handle) return if _is_pandas_df(data): - data, _, _ = _transform_pandas_df(data, meta=name, meta_type=dtype) + data, _, _ = _transform_pandas_df(data, False, meta=name, + meta_type=dtype) _meta_from_numpy(data, name, dtype, handle) return if _is_pandas_series(data): diff --git a/src/common/device_helpers.cuh b/src/common/device_helpers.cuh index 5e4f1eae0d6b..52e906972531 100644 --- a/src/common/device_helpers.cuh +++ b/src/common/device_helpers.cuh @@ -80,6 +80,11 @@ struct AtomicDispatcher { using Type = unsigned long long; // NOLINT static_assert(sizeof(Type) == sizeof(uint64_t), "Unsigned long long should be of size 64 bits."); }; + +template <> +struct AtomicDispatcher { + using Type = uint8_t; // NOLINT +}; } // namespace detail } // namespace dh @@ -522,6 +527,17 @@ void CopyDeviceSpanToVector(std::vector *dst, xgboost::common::Span cudaMemcpyDeviceToHost)); } +template +void CopyToD(HContainer const &h, DContainer *d) { + d->resize(h.size()); + using HVT = std::remove_cv_t; + using DVT = std::remove_cv_t; + static_assert(std::is_same::value, + "Host and device containers must have same value type."); + dh::safe_cuda(cudaMemcpyAsync(d->data().get(), h.data(), h.size() * sizeof(HVT), + cudaMemcpyHostToDevice)); +} + // Keep track of pinned memory allocation struct PinnedMemory { void *temp_storage{nullptr}; diff --git 
a/src/common/hist_util.cu b/src/common/hist_util.cu index ebd38b7aecd9..4cf5511d9d98 100644 --- a/src/common/hist_util.cu +++ b/src/common/hist_util.cu @@ -24,6 +24,7 @@ #include "hist_util.cuh" #include "math.h" // NOLINT #include "quantile.h" +#include "categorical.h" #include "xgboost/host_device_vector.h" @@ -36,6 +37,7 @@ namespace detail { // Count the entries in each column and exclusive scan void ExtractCutsSparse(int device, common::Span cuts_ptr, + common::Span feature_types, Span sorted_data, Span column_sizes_scan, Span out_cuts) { @@ -48,10 +50,16 @@ void ExtractCutsSparse(int device, common::Span size_t cut_idx = idx - cuts_ptr[column_idx]; Span column_entries = sorted_data.subspan(column_sizes_scan[column_idx], column_size); - size_t rank = (column_entries.size() * cut_idx) / - static_cast(num_available_cuts); - out_cuts[idx] = WQSketch::Entry(rank, rank + 1, 1, - column_entries[rank].fvalue); + if (IsCat(feature_types, column_idx)) { + size_t rank = cut_idx; + out_cuts[idx] = WQSketch::Entry(rank, rank + 1, 1, + column_entries[rank].fvalue); + } else { + size_t rank = (column_entries.size() * cut_idx) / + static_cast(num_available_cuts); + out_cuts[idx] = WQSketch::Entry(rank, rank + 1, 1, + column_entries[rank].fvalue); + } }); } @@ -196,13 +204,13 @@ void SortByWeight(dh::XGBCachingDeviceAllocator* alloc, } } // namespace detail -void ProcessBatch(int device, const SparsePage &page, size_t begin, size_t end, - SketchContainer *sketch_container, int num_cuts_per_feature, - size_t num_columns) { +void ProcessBatch(int device, DMatrix const *m, const SparsePage &page, + size_t begin, size_t end, SketchContainer *sketch_container, + int num_cuts_per_feature, size_t num_columns) { dh::XGBCachingDeviceAllocator alloc; const auto& host_data = page.data.ConstHostVector(); dh::device_vector sorted_entries(host_data.begin() + begin, - host_data.begin() + end); + host_data.begin() + end); thrust::sort(thrust::cuda::par(alloc), sorted_entries.begin(), sorted_entries.end(), detail::EntryCompareOp()); @@ -219,13 +227,48 @@ void ProcessBatch(int device, const SparsePage &page, size_t begin, size_t end, 0, sorted_entries.size(), &cuts_ptr, &column_sizes_scan); + // Removing duplicated entries in categorical features. + dh::caching_device_vector new_column_scan(column_sizes_scan.size()); + auto d_feature_types = m->Info().feature_types.ConstDeviceSpan(); + auto n_uniques = dh::SegmentedUnique( + column_sizes_scan.data().get(), + column_sizes_scan.data().get() + column_sizes_scan.size(), + sorted_entries.begin(), sorted_entries.end(), + new_column_scan.data().get(), sorted_entries.begin(), + [=] __device__(Entry const &l, Entry const &r) { + if (l.index == r.index) { + if (IsCat(d_feature_types, l.index)) { + return l.fvalue == r.fvalue; + } + } + return false; + }); + + // Renew the column scan and cut scan based on categorical data. 
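+  // For a categorical column every surviving (unique) entry becomes its own cut, so its
+  // cut count equals the number of distinct categories seen in this batch; numeric
+  // columns keep the quantile-based cut count computed above. E.g. a categorical column
+  // holding the values {0, 1, 1, 3, 3} contributes exactly the three cuts {0, 1, 3}.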
+ dh::caching_device_vector new_cuts_size(num_columns + 1); + auto d_new_cuts_size = dh::ToSpan(new_cuts_size); + auto d_new_columns_ptr = dh::ToSpan(new_column_scan); + auto d_cuts_ptr = cuts_ptr.DeviceSpan(); + CHECK_EQ(new_column_scan.size(), new_cuts_size.size()); + dh::LaunchN(device, new_column_scan.size() - 1, [=] __device__(size_t idx) { + idx += 1; + if (IsCat(d_feature_types, idx - 1)) { + d_new_cuts_size[idx - 1] = + d_new_columns_ptr[idx] - d_new_columns_ptr[idx - 1]; + } else { + d_new_cuts_size[idx - 1] = d_cuts_ptr[idx] - d_cuts_ptr[idx - 1]; + } + }); + thrust::exclusive_scan(thrust::device, new_cuts_size.cbegin(), + new_cuts_size.cend(), d_cuts_ptr.data()); auto const& h_cuts_ptr = cuts_ptr.ConstHostVector(); + sorted_entries.resize(n_uniques); dh::caching_device_vector cuts(h_cuts_ptr.back()); - auto d_cuts_ptr = cuts_ptr.ConstDeviceSpan(); - CHECK_EQ(d_cuts_ptr.size(), column_sizes_scan.size()); - detail::ExtractCutsSparse(device, d_cuts_ptr, dh::ToSpan(sorted_entries), - dh::ToSpan(column_sizes_scan), dh::ToSpan(cuts)); + + detail::ExtractCutsSparse(device, d_cuts_ptr, d_feature_types, + dh::ToSpan(sorted_entries), + dh::ToSpan(new_column_scan), dh::ToSpan(cuts)); // add cuts into sketches sorted_entries.clear(); @@ -313,7 +356,9 @@ HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins, device, num_cuts_per_feature, has_weights); HistogramCuts cuts; - SketchContainer sketch_container(max_bins, dmat->Info().num_col_, + + dmat->Info().feature_types.SetDevice(device); + SketchContainer sketch_container(dmat->Info().feature_types, max_bins, dmat->Info().num_col_, dmat->Info().num_row_, device); dmat->Info().weights_.SetDevice(device); @@ -333,7 +378,7 @@ HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins, dmat->Info().num_col_, is_ranking, dh::ToSpan(groups)); } else { - ProcessBatch(device, batch, begin, end, &sketch_container, num_cuts_per_feature, + ProcessBatch(device, dmat, batch, begin, end, &sketch_container, num_cuts_per_feature, dmat->Info().num_col_); } } diff --git a/src/common/hist_util.cuh b/src/common/hist_util.cuh index f1034040c1ab..7adabab6978c 100644 --- a/src/common/hist_util.cuh +++ b/src/common/hist_util.cuh @@ -38,6 +38,7 @@ struct EntryCompareOp { * \param out_cuts Output cut values */ void ExtractCutsSparse(int device, common::Span cuts_ptr, + common::Span feature_types, Span sorted_data, Span column_sizes_scan, Span out_cuts); @@ -189,6 +190,7 @@ void ProcessSlidingWindow(AdapterBatch const& batch, int device, size_t columns, dh::caching_device_vector cuts(h_cuts_ptr.back()); // Extract the cuts from all columns concurrently detail::ExtractCutsSparse(device, d_cuts_ptr, + {}, dh::ToSpan(sorted_entries), dh::ToSpan(column_sizes_scan), dh::ToSpan(cuts)); diff --git a/src/common/host_device_vector.cc b/src/common/host_device_vector.cc index f9974f8ecfaf..a16154966588 100644 --- a/src/common/host_device_vector.cc +++ b/src/common/host_device_vector.cc @@ -10,6 +10,7 @@ #include #include #include +#include "xgboost/tree_model.h" #include "xgboost/host_device_vector.h" namespace xgboost { @@ -176,6 +177,7 @@ template class HostDeviceVector; template class HostDeviceVector; template class HostDeviceVector; // bst_row_t template class HostDeviceVector; // bst_feature_t +template class HostDeviceVector; #if defined(__APPLE__) /* diff --git a/src/common/host_device_vector.cu b/src/common/host_device_vector.cu index 39a0fbe9efb0..e8ed1d5e6b83 100644 --- a/src/common/host_device_vector.cu +++ b/src/common/host_device_vector.cu @@ 
-404,6 +404,7 @@ template class HostDeviceVector; template class HostDeviceVector; // bst_row_t template class HostDeviceVector; // bst_feature_t template class HostDeviceVector; +template class HostDeviceVector; #if defined(__APPLE__) /* diff --git a/src/common/observer.h b/src/common/observer.h index 1af16d45dbd4..397b565ed271 100644 --- a/src/common/observer.h +++ b/src/common/observer.h @@ -71,7 +71,7 @@ class TrainingObserver { for (size_t i = 0; i < h_vec.size(); ++i) { OBSERVER_PRINT << h_vec[i] << ", "; - if (i % 8 == 0) { + if (i % 8 == 0 && i != 0) { OBSERVER_PRINT << OBSERVER_NEWLINE; } if ((i + 1) == n) { diff --git a/src/common/quantile.cu b/src/common/quantile.cu index 52d0e37e97be..6c9ec1bf6c60 100644 --- a/src/common/quantile.cu +++ b/src/common/quantile.cu @@ -16,6 +16,7 @@ #include "hist_util.h" #include "device_helpers.cuh" #include "common.h" +#include "categorical.h" namespace xgboost { namespace common { @@ -304,9 +305,13 @@ void SketchContainer::Prune(size_t to) { this->Unique(); OffsetT to_total = 0; HostDeviceVector new_columns_ptr{to_total}; + auto const& h_feature_types = feature_types_.ConstHostSpan(); for (bst_feature_t i = 0; i < num_columns_; ++i) { size_t length = this->Column(i).size(); length = std::min(length, to); + if (IsCat(h_feature_types, i)) { + length = this->Column(i).size(); + } to_total += length; new_columns_ptr.HostVector().emplace_back(to_total); } @@ -317,6 +322,7 @@ void SketchContainer::Prune(size_t to) { auto d_columns_ptr_out = new_columns_ptr.ConstDeviceSpan(); auto out = dh::ToSpan(this->Other()); auto in = dh::ToSpan(this->Current()); + auto ft = this->feature_types_.ConstDeviceSpan(); dh::LaunchN(0, to_total, [=] __device__(size_t idx) { size_t column_id = dh::SegmentId(d_columns_ptr_out, idx); auto out_column = out.subspan(d_columns_ptr_out[column_id], @@ -326,10 +332,11 @@ void SketchContainer::Prune(size_t to) { d_columns_ptr_in[column_id + 1] - d_columns_ptr_in[column_id]); idx -= d_columns_ptr_out[column_id]; + auto is_cat = IsCat(ft, column_id); // Input has lesser columns than `to`, just copy them to the output. This is correct // as the new output size is calculated based on both the size of `to` and current // column. - if (in_column.size() <= to) { + if (in_column.size() <= to || is_cat) { out_column[idx] = in_column[idx]; return; } @@ -473,7 +480,8 @@ void SketchContainer::AllReduce() { } // Merge them into a new sketch. 
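  // The merged sketch is given the same feature_types_ as this container, so categorical
  // columns keep every unique entry through Prune()/Merge() instead of being reduced to
  // `to` quantile entries (Prune above copies categorical columns through unchanged).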
- SketchContainer new_sketch(num_bins_, this->num_columns_, global_sum_rows, + SketchContainer new_sketch(this->feature_types_, num_bins_, + this->num_columns_, global_sum_rows, this->device_); for (size_t i = 0; i < allworkers.size(); ++i) { auto worker = allworkers[i]; @@ -513,11 +521,16 @@ void SketchContainer::MakeCuts(HistogramCuts* p_cuts) { auto& h_out_columns_ptr = p_cuts->cut_ptrs_.HostVector(); h_out_columns_ptr.clear(); h_out_columns_ptr.push_back(0); + auto const& h_feature_types = this->feature_types_.ConstHostSpan(); for (bst_feature_t i = 0; i < num_columns_; ++i) { - h_out_columns_ptr.push_back( - std::min(static_cast(std::max(static_cast(1ul), - this->Column(i).size())), - static_cast(num_bins_))); + size_t column_size = std::max(static_cast(1ul), + this->Column(i).size()); + if (IsCat(h_feature_types, i)) { + h_out_columns_ptr.push_back(static_cast(column_size)); + } else { + h_out_columns_ptr.push_back(std::min(static_cast(column_size), + static_cast(num_bins_))); + } } std::partial_sum(h_out_columns_ptr.begin(), h_out_columns_ptr.end(), h_out_columns_ptr.begin()); @@ -528,6 +541,7 @@ void SketchContainer::MakeCuts(HistogramCuts* p_cuts) { p_cuts->cut_values_.SetDevice(device_); p_cuts->cut_values_.Resize(total_bins); auto out_cut_values = p_cuts->cut_values_.DeviceSpan(); + auto d_ft = feature_types_.ConstDeviceSpan(); dh::LaunchN(0, total_bins, [=] __device__(size_t idx) { auto column_id = dh::SegmentId(d_out_columns_ptr, idx); @@ -550,11 +564,17 @@ void SketchContainer::MakeCuts(HistogramCuts* p_cuts) { return; } - // First thread is responsible for setting min values. - if (idx == 0) { + if (idx == 0 && !IsCat(d_ft, column_id)) { auto mval = in_column[idx].value; d_min_values[column_id] = mval - (fabs(mval) + 1e-5); } + + if (IsCat(d_ft, column_id)) { + assert(out_column.size() == in_column.size()); + out_column[idx] = in_column[idx].value; + return; + } + // Last thread is responsible for setting a value that's greater than other cuts. if (idx == out_column.size() - 1) { const bst_float cpt = in_column.back().value; diff --git a/src/common/quantile.cuh b/src/common/quantile.cuh index cd5833914db5..bfe706cc25d8 100644 --- a/src/common/quantile.cuh +++ b/src/common/quantile.cuh @@ -4,6 +4,7 @@ #include #include "xgboost/span.h" +#include "xgboost/data.h" #include "device_helpers.cuh" #include "quantile.h" #include "timer.h" @@ -28,6 +29,7 @@ class SketchContainer { private: Monitor timer_; std::unique_ptr reducer_; + HostDeviceVector feature_types_; bst_row_t num_rows_; bst_feature_t num_columns_; int32_t num_bins_; @@ -80,12 +82,19 @@ class SketchContainer { * \param num_rows Total number of rows in known dataset (typically the rows in current worker). * \param device GPU ID. */ - SketchContainer(int32_t max_bin, bst_feature_t num_columns, bst_row_t num_rows, int32_t device) : - num_rows_{num_rows}, num_columns_{num_columns}, num_bins_{max_bin}, device_{device} { + SketchContainer(HostDeviceVector const& feature_types, + int32_t max_bin, + bst_feature_t num_columns, bst_row_t num_rows, + int32_t device) + : num_rows_{num_rows}, + num_columns_{num_columns}, num_bins_{max_bin}, device_{device} { // Initialize Sketches for this dmatrix this->columns_ptr_.SetDevice(device_); this->columns_ptr_.Resize(num_columns + 1); CHECK_GE(device, 0); + this->feature_types_.Resize(feature_types.Size()); + this->feature_types_.Copy(feature_types); + this->feature_types_.SetDevice(device); timer_.Init(__func__); } /* \brief Return GPU ID for this container. 
*/ diff --git a/src/data/adapter.h b/src/data/adapter.h index c3981c24fffd..aa7d44f2e609 100644 --- a/src/data/adapter.h +++ b/src/data/adapter.h @@ -68,7 +68,7 @@ namespace data { /** \brief An adapter can return this value for number of rows or columns * indicating that this value is currently unknown and should be inferred while * passing over the data. */ -constexpr size_t kAdapterUnknownSize = std::numeric_limits::max(); +constexpr size_t kAdapterUnknownSize = std::numeric_limits::max(); struct COOTuple { COOTuple() = default; diff --git a/src/data/ellpack_page.cu b/src/data/ellpack_page.cu index 39e845f2d765..942f35d9fbe6 100644 --- a/src/data/ellpack_page.cu +++ b/src/data/ellpack_page.cu @@ -5,6 +5,7 @@ #include #include #include +#include "../common/categorical.h" #include "../common/hist_util.cuh" #include "../common/random.h" #include "./ellpack_page.cuh" @@ -33,6 +34,7 @@ __global__ void CompressBinEllpackKernel( const Entry* __restrict__ entries, // One batch of input data const float* __restrict__ cuts, // HistogramCuts::cut_values_ const uint32_t* __restrict__ cut_rows, // HistogramCuts::cut_ptrs_ + common::Span feature_types, size_t base_row, // batch_row_begin size_t n_rows, size_t row_stride, @@ -51,11 +53,19 @@ __global__ void CompressBinEllpackKernel( // {feature_cuts, ncuts} forms the array of cuts of `feature'. const float* feature_cuts = &cuts[cut_rows[feature]]; int ncuts = cut_rows[feature + 1] - cut_rows[feature]; + bool is_cat = common::IsCat(feature_types, ifeature); // Assigning the bin in current entry. // S.t.: fvalue < feature_cuts[bin] - bin = thrust::upper_bound(thrust::seq, feature_cuts, feature_cuts + ncuts, - fvalue) - - feature_cuts; + if (is_cat) { + auto it = dh::MakeTransformIterator( + feature_cuts, [](float v) { return common::AsCat(v); }); + bin = thrust::lower_bound(thrust::seq, it, it + ncuts, common::AsCat(fvalue)) - it; + } else { + bin = thrust::upper_bound(thrust::seq, feature_cuts, feature_cuts + ncuts, + fvalue) - + feature_cuts; + } + if (bin >= ncuts) { bin = ncuts - 1; } @@ -90,7 +100,7 @@ EllpackPageImpl::EllpackPageImpl(int device, common::HistogramCuts cuts, n_rows(page.Size()), row_stride(row_stride) { this->InitCompressedData(device); - this->CreateHistIndices(device, page); + this->CreateHistIndices(device, page, {}); } // Construct an ELLPACK matrix in memory. @@ -108,12 +118,14 @@ EllpackPageImpl::EllpackPageImpl(DMatrix* dmat, const BatchParam& param) monitor_.Stop("Quantiles"); monitor_.Start("InitCompressedData"); - InitCompressedData(param.gpu_id); + this->InitCompressedData(param.gpu_id); monitor_.Stop("InitCompressedData"); + dmat->Info().feature_types.SetDevice(param.gpu_id); + auto ft = dmat->Info().feature_types.ConstDeviceSpan(); monitor_.Start("BinningCompression"); for (const auto& batch : dmat->GetBatches()) { - CreateHistIndices(param.gpu_id, batch); + CreateHistIndices(param.gpu_id, batch, ft); } monitor_.Stop("BinningCompression"); } @@ -365,7 +377,8 @@ void EllpackPageImpl::InitCompressedData(int device) { // Compress a CSR page into ELLPACK. void EllpackPageImpl::CreateHistIndices(int device, - const SparsePage& row_batch) { + const SparsePage& row_batch, + common::Span feature_types) { if (row_batch.Size() == 0) return; unsigned int null_gidx_value = NumSymbols() - 1; @@ -397,9 +410,9 @@ void EllpackPageImpl::CreateHistIndices(int device, size_t n_entries = ent_cnt_end - ent_cnt_begin; dh::device_vector entries_d(n_entries); // copy data entries to device. 
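    // The compression kernel launched below also receives the feature-type span:
    // categorical values are matched to the bin of their exact category code via
    // lower_bound over the integer-cast cut values, while numeric values keep the
    // upper_bound search over the floating-point cuts.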
- dh::safe_cuda(cudaMemcpy(entries_d.data().get(), - data_vec.data() + ent_cnt_begin, - n_entries * sizeof(Entry), cudaMemcpyDefault)); + dh::safe_cuda(cudaMemcpyAsync(entries_d.data().get(), + data_vec.data() + ent_cnt_begin, + n_entries * sizeof(Entry), cudaMemcpyDefault)); const dim3 block3(32, 8, 1); // 256 threads const dim3 grid3(common::DivRoundUp(batch_nrows, block3.x), common::DivRoundUp(row_stride, block3.y), 1); @@ -408,7 +421,7 @@ void EllpackPageImpl::CreateHistIndices(int device, CompressBinEllpackKernel, common::CompressedBufferWriter(NumSymbols()), gidx_buffer.DevicePointer(), row_ptrs.data().get(), entries_d.data().get(), device_accessor.gidx_fvalue_map.data(), - device_accessor.feature_segments.data(), + device_accessor.feature_segments.data(), feature_types, row_batch.base_rowid + batch_row_begin, batch_nrows, row_stride, null_gidx_value); } diff --git a/src/data/ellpack_page.cuh b/src/data/ellpack_page.cuh index 8cb0162fb2c7..0e83f7e6bb9d 100644 --- a/src/data/ellpack_page.cuh +++ b/src/data/ellpack_page.cuh @@ -212,8 +212,8 @@ class EllpackPageImpl { * @param row_batch The CSR page. */ void CreateHistIndices(int device, - const SparsePage& row_batch - ); + const SparsePage& row_batch, + common::Span feature_types); /*! * \brief Initialize the buffer to store compressed features. */ diff --git a/src/data/iterative_device_dmatrix.cu b/src/data/iterative_device_dmatrix.cu index b99f99590bc0..9ff944925a20 100644 --- a/src/data/iterative_device_dmatrix.cu +++ b/src/data/iterative_device_dmatrix.cu @@ -79,7 +79,8 @@ void IterativeDeviceDMatrix::Initialize(DataIterHandle iter_handle, float missin } else { CHECK_EQ(cols, num_cols()) << "Inconsistent number of columns."; } - sketch_containers.emplace_back(batch_param_.max_bin, cols, num_rows(), get_device()); + sketch_containers.emplace_back(proxy->Info().feature_types, batch_param_.max_bin, cols, + num_rows(), get_device()); auto* p_sketch = &sketch_containers.back(); proxy->Info().weights_.SetDevice(get_device()); Dispatch(proxy, [&](auto const &value) { @@ -97,11 +98,16 @@ void IterativeDeviceDMatrix::Initialize(DataIterHandle iter_handle, float missin })); nnz += thrust::reduce(thrust::cuda::par(alloc), row_counts.begin(), row_counts.end()); + + this->Info().feature_types.Resize(proxy->Info().feature_types.Size()); + this->Info().feature_types.Copy(proxy->Info().feature_types); batches++; } iter.Reset(); dh::safe_cuda(cudaSetDevice(get_device())); - common::SketchContainer final_sketch(batch_param_.max_bin, cols, accumulated_rows, get_device()); + common::SketchContainer final_sketch(this->Info().feature_types, + batch_param_.max_bin, cols, accumulated_rows, get_device()); + for (auto const& sketch : sketch_containers) { final_sketch.Merge(sketch.ColumnsPtr(), sketch.Data()); final_sketch.FixError(); diff --git a/src/predictor/gpu_predictor.cu b/src/predictor/gpu_predictor.cu index a36a131fa3bb..f3a6dea9bb8c 100644 --- a/src/predictor/gpu_predictor.cu +++ b/src/predictor/gpu_predictor.cu @@ -18,6 +18,8 @@ #include "../data/ellpack_page.cuh" #include "../data/device_adapter.cuh" #include "../common/common.h" +#include "../common/bitfield.h" +#include "../common/categorical.h" #include "../common/device_helpers.cuh" namespace xgboost { @@ -168,33 +170,49 @@ struct DeviceAdapterLoader { template __device__ float GetLeafWeight(bst_uint ridx, const RegTree::Node* tree, + common::Span split_types, + common::Span d_cat_ptrs, + common::Span d_categories, Loader* loader) { - RegTree::Node n = tree[0]; + bst_node_t nidx = 0; + 
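// Walk the tree from the root: a categorical node branches on whether the integer-cast
+  // feature value is contained in the node's category bitfield (Decision), a numeric node
+  // keeps the `fvalue < SplitCond()` test, and missing values follow the default child.
+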
RegTree::Node n = tree[nidx]; while (!n.IsLeaf()) { float fvalue = loader->GetElement(ridx, n.SplitIndex()); // Missing value - if (isnan(fvalue)) { - n = tree[n.DefaultChild()]; + if (common::CheckNAN(fvalue)) { + nidx = n.DefaultChild(); } else { - if (fvalue < n.SplitCond()) { - n = tree[n.LeftChild()]; + bool go_left = true; + if (common::IsCat(split_types, nidx)) { + auto categories = d_categories.subspan(d_cat_ptrs[nidx].beg, + d_cat_ptrs[nidx].size); + go_left = Decision(categories, common::AsCat(fvalue)); } else { - n = tree[n.RightChild()]; + go_left = fvalue < n.SplitCond(); + } + if (go_left) { + nidx = n.LeftChild(); + } else { + nidx = n.RightChild(); } } + n = tree[nidx]; } - return n.LeafValue(); + return tree[nidx].LeafValue(); } template -__global__ void PredictKernel(Data data, - common::Span d_nodes, - common::Span d_out_predictions, - common::Span d_tree_segments, - common::Span d_tree_group, - size_t tree_begin, size_t tree_end, size_t num_features, - size_t num_rows, size_t entry_start, - bool use_shared, int num_group) { +__global__ void +PredictKernel(Data data, common::Span d_nodes, + common::Span d_out_predictions, + common::Span d_tree_segments, + common::Span d_tree_group, + common::Span d_tree_split_types, + common::Span d_cat_tree_segments, + common::Span d_cat_node_segments, + common::Span d_categories, size_t tree_begin, + size_t tree_end, size_t num_features, size_t num_rows, + size_t entry_start, bool use_shared, int num_group) { bst_uint global_idx = blockDim.x * blockIdx.x + threadIdx.x; Loader loader(data, use_shared, num_features, num_rows, entry_start); if (global_idx >= num_rows) return; @@ -203,7 +221,18 @@ __global__ void PredictKernel(Data data, for (int tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) { const RegTree::Node* d_tree = &d_nodes[d_tree_segments[tree_idx - tree_begin]]; - float leaf = GetLeafWeight(global_idx, d_tree, &loader); + auto tree_cat_ptrs = d_cat_node_segments.subspan( + d_tree_segments[tree_idx - tree_begin], + d_tree_segments[tree_idx - tree_begin + 1] - + d_tree_segments[tree_idx - tree_begin]); + auto tree_categories = + d_categories.subspan(d_cat_tree_segments[tree_idx - tree_begin], + d_cat_tree_segments[tree_idx - tree_begin + 1] - + d_cat_tree_segments[tree_idx - tree_begin]); + float leaf = GetLeafWeight(global_idx, d_tree, d_tree_split_types, + tree_cat_ptrs, + tree_categories, + &loader); sum += leaf; } d_out_predictions[global_idx] += sum; @@ -213,8 +242,19 @@ __global__ void PredictKernel(Data data, const RegTree::Node* d_tree = &d_nodes[d_tree_segments[tree_idx - tree_begin]]; bst_uint out_prediction_idx = global_idx * num_group + tree_group; + auto tree_cat_ptrs = d_cat_node_segments.subspan( + d_tree_segments[tree_idx - tree_begin], + d_tree_segments[tree_idx - tree_begin + 1] - + d_tree_segments[tree_idx - tree_begin]); + auto tree_categories = + d_categories.subspan(d_cat_tree_segments[tree_idx - tree_begin], + d_cat_tree_segments[tree_idx - tree_begin + 1] - + d_cat_tree_segments[tree_idx - tree_begin]); d_out_predictions[out_prediction_idx] += - GetLeafWeight(global_idx, d_tree, &loader); + GetLeafWeight(global_idx, d_tree, d_tree_split_types, + tree_cat_ptrs, + tree_categories, + &loader); } } } @@ -222,9 +262,15 @@ __global__ void PredictKernel(Data data, class DeviceModel { public: // Need to lazily construct the vectors because GPU id is only known at runtime - HostDeviceVector nodes; HostDeviceVector tree_segments; + HostDeviceVector nodes; HostDeviceVector tree_group; + HostDeviceVector 
split_types; + + HostDeviceVector categories; + HostDeviceVector categories_tree_segments; + HostDeviceVector categories_node_segments; + size_t tree_beg_; // NOLINT size_t tree_end_; // NOLINT int num_group; @@ -256,6 +302,41 @@ class DeviceModel { tree_group = std::move(HostDeviceVector(model.tree_info.size(), 0, gpu_id)); auto& h_tree_group = tree_group.HostVector(); std::memcpy(h_tree_group.data(), model.tree_info.data(), sizeof(int) * model.tree_info.size()); + + // Initialize categorical splits. + split_types.SetDevice(gpu_id); + std::vector& h_split_types = split_types.HostVector(); + h_split_types.resize(h_tree_segments.back()); + for (auto tree_idx = tree_begin; tree_idx < tree_end; ++tree_idx) { + auto const& src_st = model.trees.at(tree_idx)->GetSplitTypes(); + std::copy(src_st.cbegin(), src_st.cend(), + h_split_types.begin() + h_tree_segments[tree_idx - tree_begin]); + } + + categories = HostDeviceVector({}, gpu_id); + categories_tree_segments = HostDeviceVector(1, 0, gpu_id); + std::vector &h_categories = categories.HostVector(); + std::vector& h_split_cat_segments = categories_tree_segments.HostVector(); + for (auto tree_idx = tree_begin; tree_idx < tree_end; ++tree_idx) { + auto const& src_cats = model.trees.at(tree_idx)->GetSplitCategories(); + size_t orig_size = h_categories.size(); + h_categories.resize(orig_size + src_cats.size()); + std::copy(src_cats.cbegin(), src_cats.cend(), + h_categories.begin() + orig_size); + h_split_cat_segments.push_back(h_categories.size()); + } + + categories_node_segments = + HostDeviceVector(h_tree_segments.back(), {}, gpu_id); + std::vector &h_categories_node_segments = + categories_node_segments.HostVector(); + for (auto tree_idx = tree_begin; tree_idx < tree_end; ++tree_idx) { + auto const &src_cats_ptr = model.trees.at(tree_idx)->GetSplitCategoriesPtr(); + std::copy(src_cats_ptr.cbegin(), src_cats_ptr.cend(), + h_categories_node_segments.begin() + + h_tree_segments[tree_idx - tree_begin]); + } + this->tree_beg_ = tree_begin; this->tree_end_ = tree_end; this->num_group = model.learner_model_param->num_output_group; @@ -264,7 +345,8 @@ class DeviceModel { class GPUPredictor : public xgboost::Predictor { private: - void PredictInternal(const SparsePage& batch, size_t num_features, + void PredictInternal(const SparsePage& batch, + size_t num_features, HostDeviceVector* predictions, size_t batch_offset) { batch.offset.SetDevice(generic_param_->gpu_id); @@ -284,14 +366,18 @@ class GPUPredictor : public xgboost::Predictor { SparsePageView data(batch.data.DeviceSpan(), batch.offset.DeviceSpan(), num_features); dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS, shared_memory_bytes} ( - PredictKernel, - data, - model_.nodes.DeviceSpan(), predictions->DeviceSpan().subspan(batch_offset), - model_.tree_segments.DeviceSpan(), model_.tree_group.DeviceSpan(), - model_.tree_beg_, model_.tree_end_, num_features, num_rows, - entry_start, use_shared, model_.num_group); + PredictKernel, data, + model_.nodes.ConstDeviceSpan(), + predictions->DeviceSpan().subspan(batch_offset), + model_.tree_segments.ConstDeviceSpan(), model_.tree_group.ConstDeviceSpan(), + model_.split_types.ConstDeviceSpan(), + model_.categories_tree_segments.ConstDeviceSpan(), + model_.categories_node_segments.ConstDeviceSpan(), + model_.categories.ConstDeviceSpan(), model_.tree_beg_, model_.tree_end_, + num_features, num_rows, entry_start, use_shared, model_.num_group); } - void PredictInternal(EllpackDeviceAccessor const& batch, HostDeviceVector* out_preds, + void 
PredictInternal(EllpackDeviceAccessor const& batch, + HostDeviceVector* out_preds, size_t batch_offset) { const uint32_t BLOCK_THREADS = 256; size_t num_rows = batch.n_rows; @@ -300,12 +386,15 @@ class GPUPredictor : public xgboost::Predictor { bool use_shared = false; size_t entry_start = 0; dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS} ( - PredictKernel, - batch, - model_.nodes.DeviceSpan(), out_preds->DeviceSpan().subspan(batch_offset), - model_.tree_segments.DeviceSpan(), model_.tree_group.DeviceSpan(), - model_.tree_beg_, model_.tree_end_, batch.NumFeatures(), num_rows, - entry_start, use_shared, model_.num_group); + PredictKernel, batch, + model_.nodes.ConstDeviceSpan(), out_preds->DeviceSpan().subspan(batch_offset), + model_.tree_segments.ConstDeviceSpan(), model_.tree_group.ConstDeviceSpan(), + model_.split_types.ConstDeviceSpan(), + model_.categories_tree_segments.ConstDeviceSpan(), + model_.categories_node_segments.ConstDeviceSpan(), + model_.categories.ConstDeviceSpan(), model_.tree_beg_, model_.tree_end_, + batch.NumFeatures(), num_rows, entry_start, use_shared, + model_.num_group); } void DevicePredictInternal(DMatrix* dmat, HostDeviceVector* out_preds, @@ -317,6 +406,7 @@ class GPUPredictor : public xgboost::Predictor { } model_.Init(model, tree_begin, tree_end, generic_param_->gpu_id); out_preds->SetDevice(generic_param_->gpu_id); + auto const& info = dmat->Info(); if (dmat->PageExists()) { size_t batch_offset = 0; @@ -329,7 +419,8 @@ class GPUPredictor : public xgboost::Predictor { size_t batch_offset = 0; for (auto const& page : dmat->GetBatches()) { this->PredictInternal( - page.Impl()->GetDeviceAccessor(generic_param_->gpu_id), out_preds, + page.Impl()->GetDeviceAccessor(generic_param_->gpu_id), + out_preds, batch_offset); batch_offset += page.Impl()->n_rows; } @@ -432,12 +523,14 @@ class GPUPredictor : public xgboost::Predictor { size_t entry_start = 0; dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS, shared_memory_bytes} ( - PredictKernel, - m->Value(), - d_model.nodes.DeviceSpan(), out_preds->predictions.DeviceSpan(), - d_model.tree_segments.DeviceSpan(), d_model.tree_group.DeviceSpan(), - tree_begin, tree_end, m->NumColumns(), info.num_row_, - entry_start, use_shared, output_groups); + PredictKernel, m->Value(), + d_model.nodes.ConstDeviceSpan(), out_preds->predictions.DeviceSpan(), + d_model.tree_segments.ConstDeviceSpan(), d_model.tree_group.ConstDeviceSpan(), + d_model.split_types.ConstDeviceSpan(), + d_model.categories_tree_segments.ConstDeviceSpan(), + d_model.categories_node_segments.ConstDeviceSpan(), + d_model.categories.ConstDeviceSpan(), tree_begin, tree_end, m->NumColumns(), + info.num_row_, entry_start, use_shared, output_groups); } void InplacePredict(dmlc::any const &x, const gbm::GBTreeModel &model, diff --git a/src/tree/gpu_hist/evaluate_splits.cu b/src/tree/gpu_hist/evaluate_splits.cu index 7a843c2bfb7c..4006068b3af0 100644 --- a/src/tree/gpu_hist/evaluate_splits.cu +++ b/src/tree/gpu_hist/evaluate_splits.cu @@ -1,8 +1,9 @@ /*! 
* Copyright 2020 by XGBoost Contributors */ -#include "evaluate_splits.cuh" #include +#include "evaluate_splits.cuh" +#include "../../common/categorical.h" namespace xgboost { namespace tree { @@ -65,15 +66,86 @@ ReduceFeature(common::Span feature_histogram, if (threadIdx.x == 0) { shared_sum = local_sum; } - __syncthreads(); + cub::CTA_SYNC(); return shared_sum; } +template struct OneHotBin { + GradientSumT __device__ operator()( + bool thread_active, uint32_t scan_begin, + SumCallbackOp*, + GradientSumT const &missing, + EvaluateSplitInputs const &inputs, TempStorageT *) { + GradientSumT bin = thread_active + ? inputs.gradient_histogram[scan_begin + threadIdx.x] + : GradientSumT(); + auto rest = inputs.parent_sum - bin - missing; + return rest; + } +}; + +template +struct UpdateOneHot { + void __device__ operator()(bool missing_left, uint32_t scan_begin, float gain, + bst_feature_t fidx, GradientSumT const &missing, + GradientSumT const &bin, + EvaluateSplitInputs const &inputs, + DeviceSplitCandidate *best_split) { + int split_gidx = (scan_begin + threadIdx.x); + float fvalue = inputs.feature_values[split_gidx]; + GradientSumT left = missing_left ? bin + missing : bin; + GradientSumT right = inputs.parent_sum - left; + best_split->Update(gain, missing_left ? kLeftDir : kRightDir, fvalue, fidx, + GradientPair(left), GradientPair(right), true, + inputs.param); + } +}; + +template +struct NumericBin { + GradientSumT __device__ operator()(bool thread_active, uint32_t scan_begin, + SumCallbackOp* prefix_callback, + GradientSumT const &missing, + EvaluateSplitInputs inputs, + TempStorageT *temp_storage) { + GradientSumT bin = thread_active + ? inputs.gradient_histogram[scan_begin + threadIdx.x] + : GradientSumT(); + ScanT(temp_storage->scan).ExclusiveScan(bin, bin, cub::Sum(), *prefix_callback); + return bin; + } +}; + +template +struct UpdateNumeric { + void __device__ operator()(bool missing_left, uint32_t scan_begin, float gain, + bst_feature_t fidx, GradientSumT const &missing, + GradientSumT const &bin, + EvaluateSplitInputs const &inputs, + DeviceSplitCandidate *best_split) { + // Use pointer from cut to indicate begin and end of bins for each feature. + uint32_t gidx_begin = inputs.feature_segments[fidx]; // begining bin + int split_gidx = (scan_begin + threadIdx.x) - 1; + float fvalue; + if (split_gidx < static_cast(gidx_begin)) { + fvalue = inputs.min_fvalue[fidx]; + } else { + fvalue = inputs.feature_values[split_gidx]; + } + GradientSumT left = missing_left ? bin + missing : bin; + GradientSumT right = inputs.parent_sum - left; + best_split->Update(gain, missing_left ? kLeftDir : kRightDir, fvalue, + fidx, GradientPair(left), GradientPair(right), + false, inputs.param); + } +}; + /*! \brief Find the thread with best gain. 
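 *
 * BinFn / UpdateFn select between the two policies defined above: for a categorical
 * (one-hot) feature each histogram bin is a single category, so OneHotBin returns
 * parent_sum - bin - missing (the gradient sum of all the other categories) and
 * UpdateOneHot records a one-vs-rest split without any exclusive scan over bins;
 * for numeric features NumericBin keeps the cumulative ExclusiveScan and
 * UpdateNumeric records the usual threshold split at the scanned bin boundary.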
*/ template + typename MaxReduceT, typename TempStorageT, typename GradientSumT, + typename BinFn, typename UpdateFn> __device__ void EvaluateFeature( - int fidx, EvaluateSplitInputs inputs, + bst_feature_t fidx, EvaluateSplitInputs inputs, DeviceSplitCandidate* best_split, // shared memory storing best split TempStorageT* temp_storage // temp memory for cub operations ) { @@ -81,12 +153,14 @@ __device__ void EvaluateFeature( uint32_t gidx_begin = inputs.feature_segments[fidx]; // begining bin uint32_t gidx_end = inputs.feature_segments[fidx + 1]; // end bin for i^th feature + auto feature_hist = inputs.gradient_histogram.subspan(gidx_begin, gidx_end - gidx_begin); + auto bin_fn = BinFn(); + auto update_fn = UpdateFn(); // Sum histogram bins for current feature GradientSumT const feature_sum = ReduceFeature( - inputs.gradient_histogram.subspan(gidx_begin, gidx_end - gidx_begin), - temp_storage); + feature_hist, temp_storage); GradientSumT const missing = inputs.parent_sum - feature_sum; float const null_gain = -std::numeric_limits::infinity(); @@ -95,12 +169,7 @@ __device__ void EvaluateFeature( for (int scan_begin = gidx_begin; scan_begin < gidx_end; scan_begin += BLOCK_THREADS) { bool thread_active = (scan_begin + threadIdx.x) < gidx_end; - - // Gradient value for current bin. - GradientSumT bin = thread_active - ? inputs.gradient_histogram[scan_begin + threadIdx.x] - : GradientSumT(); - ScanT(temp_storage->scan).ExclusiveScan(bin, bin, cub::Sum(), prefix_op); + auto bin = bin_fn(thread_active, scan_begin, &prefix_op, missing, inputs, temp_storage); // Whether the gradient of missing values is put to the left side. bool missing_left = true; @@ -123,24 +192,14 @@ __device__ void EvaluateFeature( block_max = best; } - __syncthreads(); + cub::CTA_SYNC(); // Best thread updates split if (threadIdx.x == block_max.key) { - int split_gidx = (scan_begin + threadIdx.x) - 1; - float fvalue; - if (split_gidx < static_cast(gidx_begin)) { - fvalue = inputs.min_fvalue[fidx]; - } else { - fvalue = inputs.feature_values[split_gidx]; - } - GradientSumT left = missing_left ? bin + missing : bin; - GradientSumT right = inputs.parent_sum - left; - best_split->Update(gain, missing_left ? kLeftDir : kRightDir, fvalue, - fidx, GradientPair(left), GradientPair(right), - inputs.param); + update_fn(missing_left, scan_begin, gain, fidx, missing, bin, inputs, + best_split); } - __syncthreads(); + cub::CTA_SYNC(); } } @@ -181,11 +240,21 @@ __global__ void EvaluateSplitsKernel( // One block for each feature. Features are sampled, so fidx != blockIdx.x int fidx = inputs.feature_set[is_left ? 
blockIdx.x : blockIdx.x - left.feature_set.size()]; + if (common::IsCat(inputs.feature_types, fidx)) { + EvaluateFeature, + UpdateOneHot>(fidx, inputs, &best_split, + &temp_storage); + } else { + EvaluateFeature, + UpdateNumeric>(fidx, inputs, &best_split, + &temp_storage); + } - EvaluateFeature( - fidx, inputs, &best_split, &temp_storage); - - __syncthreads(); + cub::CTA_SYNC(); if (threadIdx.x == 0) { // Record best loss for each feature diff --git a/src/tree/gpu_hist/evaluate_splits.cuh b/src/tree/gpu_hist/evaluate_splits.cuh index ed175ae721e2..8ba177d8acdd 100644 --- a/src/tree/gpu_hist/evaluate_splits.cuh +++ b/src/tree/gpu_hist/evaluate_splits.cuh @@ -17,6 +17,7 @@ struct EvaluateSplitInputs { GradientSumT parent_sum; GPUTrainingParam param; common::Span feature_set; + common::Span feature_types; common::Span feature_segments; common::Span feature_values; common::Span min_fvalue; diff --git a/src/tree/gpu_hist/feature_groups.cu b/src/tree/gpu_hist/feature_groups.cu index 5a2c8ee6cbd8..9bb9d816283a 100644 --- a/src/tree/gpu_hist/feature_groups.cu +++ b/src/tree/gpu_hist/feature_groups.cu @@ -23,13 +23,13 @@ FeatureGroups::FeatureGroups(const common::HistogramCuts& cuts, bool is_dense, return; } - std::vector& feature_segments_h = feature_segments.HostVector(); + std::vector& feature_segments_h = feature_segments.HostVector(); std::vector& bin_segments_h = bin_segments.HostVector(); feature_segments_h.push_back(0); bin_segments_h.push_back(0); const std::vector& cut_ptrs = cuts.Ptrs(); - int max_shmem_bins = shm_size / bin_size; + size_t max_shmem_bins = shm_size / bin_size; max_group_bins = 0; for (size_t i = 2; i < cut_ptrs.size(); ++i) { @@ -49,7 +49,7 @@ FeatureGroups::FeatureGroups(const common::HistogramCuts& cuts, bool is_dense, } void FeatureGroups::InitSingle(const common::HistogramCuts& cuts) { - std::vector& feature_segments_h = feature_segments.HostVector(); + std::vector& feature_segments_h = feature_segments.HostVector(); feature_segments_h.push_back(0); feature_segments_h.push_back(cuts.Ptrs().size() - 1); diff --git a/src/tree/gpu_hist/feature_groups.cuh b/src/tree/gpu_hist/feature_groups.cuh index 3af230c2ccf6..a0fc765a6b4a 100644 --- a/src/tree/gpu_hist/feature_groups.cuh +++ b/src/tree/gpu_hist/feature_groups.cuh @@ -20,7 +20,7 @@ namespace tree { consecutive feature indices, and also contains a range of all bin indices associated with those features. */ struct FeatureGroup { - __host__ __device__ FeatureGroup(int start_feature_, int num_features_, + __host__ __device__ FeatureGroup(size_t start_feature_, size_t num_features_, int start_bin_, int num_bins_) : start_feature(start_feature_), num_features(num_features_), start_bin(start_bin_), num_bins(num_bins_) {} @@ -36,24 +36,24 @@ struct FeatureGroup { /** \brief FeatureGroupsAccessor is a non-owning accessor for FeatureGroups. */ struct FeatureGroupsAccessor { - FeatureGroupsAccessor(common::Span feature_segments_, + FeatureGroupsAccessor(common::Span feature_segments_, common::Span bin_segments_, int max_group_bins_) : feature_segments(feature_segments_), bin_segments(bin_segments_), max_group_bins(max_group_bins_) {} - - common::Span feature_segments; + + common::Span feature_segments; common::Span bin_segments; int max_group_bins; - + /** \brief Gets the number of feature groups. */ - __host__ __device__ int NumGroups() const { + __host__ __device__ size_t NumGroups() const { return feature_segments.size() - 1; } /** \brief Gets the information about a feature group with index i. 
*/ __host__ __device__ FeatureGroup operator[](int i) const { return {feature_segments[i], feature_segments[i + 1] - feature_segments[i], - bin_segments[i], bin_segments[i + 1] - bin_segments[i]}; + bin_segments[i], bin_segments[i + 1] - bin_segments[i]}; } }; @@ -78,13 +78,13 @@ struct FeatureGroupsAccessor { */ struct FeatureGroups { /** Group cuts for features. Size equals to (number of groups + 1). */ - HostDeviceVector feature_segments; + HostDeviceVector feature_segments; /** Group cuts for bins. Size equals to (number of groups + 1) */ HostDeviceVector bin_segments; /** Maximum number of bins in a group. Useful to compute the amount of dynamic shared memory when launching a kernel. */ int max_group_bins; - + /** Creates feature groups by splitting features into groups. \param cuts Histogram cuts that given the number of bins per feature. \param is_dense Whether the data matrix is dense. @@ -106,12 +106,12 @@ struct FeatureGroups { feature_segments.SetDevice(device); bin_segments.SetDevice(device); return {feature_segments.ConstDeviceSpan(), bin_segments.ConstDeviceSpan(), - max_group_bins}; + max_group_bins}; } private: void InitSingle(const common::HistogramCuts& cuts); -}; +}; } // namespace tree } // namespace xgboost diff --git a/src/tree/tree_model.cc b/src/tree/tree_model.cc index 27521c68c7fb..b45b9534addb 100644 --- a/src/tree/tree_model.cc +++ b/src/tree/tree_model.cc @@ -842,6 +842,7 @@ void RegTree::LoadModel(Json const& in) { auto cat = common::AsCat(get(j_cat)); max_cat = std::max(max_cat, cat); } + max_cat = max_cat == 0 ? 1 : max_cat; size_t size = max_cat == std::numeric_limits::min() ? 0 : common::KCatBitField::ComputeStorageSize(max_cat); @@ -900,7 +901,6 @@ void RegTree::SaveModel(Json* p_out) const { std::vector split_type(n_nodes); std::vector categories(n_nodes); - for (bst_node_t i = 0; i < n_nodes; ++i) { auto const& s = stats_[i]; loss_changes[i] = s.loss_chg; diff --git a/src/tree/updater_gpu_common.cuh b/src/tree/updater_gpu_common.cuh index 63da94ada5cb..a7dbfb237536 100644 --- a/src/tree/updater_gpu_common.cuh +++ b/src/tree/updater_gpu_common.cuh @@ -57,6 +57,7 @@ struct DeviceSplitCandidate { DefaultDirection dir {kLeftDir}; int findex {-1}; float fvalue {0}; + bool is_cat { false }; GradientPair left_sum; GradientPair right_sum; @@ -77,6 +78,7 @@ struct DeviceSplitCandidate { float fvalue_in, int findex_in, GradientPair left_sum_in, GradientPair right_sum_in, + bool cat, const GPUTrainingParam& param) { if (loss_chg_in > loss_chg && left_sum_in.GetHess() >= param.min_child_weight && @@ -84,6 +86,7 @@ struct DeviceSplitCandidate { loss_chg = loss_chg_in; dir = dir_in; fvalue = fvalue_in; + is_cat = cat; left_sum = left_sum_in; right_sum = right_sum_in; findex = findex_in; @@ -96,6 +99,7 @@ struct DeviceSplitCandidate { << "dir: " << c.dir << ", " << "findex: " << c.findex << ", " << "fvalue: " << c.fvalue << ", " + << "is_cat: " << c.is_cat << ", " << "left sum: " << c.left_sum << ", " << "right sum: " << c.right_sum << std::endl; return os; diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index 3535a59d6f85..3e884bf4ac1e 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -20,7 +20,9 @@ #include "../common/io.h" #include "../common/device_helpers.cuh" #include "../common/hist_util.h" +#include "../common/bitfield.h" #include "../common/timer.h" +#include "../common/categorical.h" #include "../data/ellpack_page.cuh" #include "param.h" @@ -175,6 +177,7 @@ template struct GPUHistMakerDevice { int 
device_id; EllpackPageImpl* page; + common::Span feature_types; BatchParam batch_param; std::unique_ptr row_partitioner; @@ -205,9 +208,12 @@ struct GPUHistMakerDevice { std::unique_ptr sampler; std::unique_ptr feature_groups; + // Storing split categories for 1 node. + dh::caching_device_vector node_categories; GPUHistMakerDevice(int _device_id, EllpackPageImpl* _page, + common::Span _feature_types, bst_uint _n_rows, TrainParam _param, uint32_t column_sampler_seed, @@ -216,6 +222,7 @@ struct GPUHistMakerDevice { BatchParam _batch_param) : device_id(_device_id), page(_page), + feature_types{_feature_types}, param(std::move(_param)), column_sampler(column_sampler_seed), interaction_constraints(param, n_features), @@ -300,11 +307,17 @@ struct GPUHistMakerDevice { common::Span feature_set = interaction_constraints.Query(sampled_features->DeviceSpan(), nidx); auto matrix = page->GetDeviceAccessor(device_id); + + auto root_hist = hist.GetNodeHistogram(nidx); + std::vector h_hist(root_hist.size()); + dh::CopyDeviceSpanToVector(&h_hist, root_hist); + EvaluateSplitInputs inputs{ nidx, {root_sum.GetGrad(), root_sum.GetHess()}, gpu_param, feature_set, + feature_types, matrix.feature_segments, matrix.gidx_fvalue_map, matrix.min_fvalue, @@ -343,6 +356,7 @@ struct GPUHistMakerDevice { candidate.split.left_sum.GetHess()}, gpu_param, left_feature_set, + feature_types, matrix.feature_segments, matrix.gidx_fvalue_map, matrix.min_fvalue, @@ -355,6 +369,7 @@ struct GPUHistMakerDevice { candidate.split.right_sum.GetHess()}, gpu_param, right_feature_set, + feature_types, matrix.feature_segments, matrix.gidx_fvalue_map, matrix.min_fvalue, @@ -405,8 +420,11 @@ struct GPUHistMakerDevice { hist.HistogramExists(nidx_parent); } - void UpdatePosition(int nidx, RegTree::Node split_node) { + void UpdatePosition(int nidx, RegTree* p_tree) { + RegTree::Node split_node = (*p_tree)[nidx]; + auto split_type = p_tree->NodeSplitType(nidx); auto d_matrix = page->GetDeviceAccessor(device_id); + auto node_cats = dh::ToSpan(node_categories); row_partitioner->UpdatePosition( nidx, split_node.LeftChild(), split_node.RightChild(), @@ -415,11 +433,17 @@ struct GPUHistMakerDevice { bst_float cut_value = d_matrix.GetFvalue(ridx, split_node.SplitIndex()); // Missing value - int new_position = 0; + bst_node_t new_position = 0; if (isnan(cut_value)) { new_position = split_node.DefaultChild(); } else { - if (cut_value <= split_node.SplitCond()) { + bool go_left = true; + if (split_type == FeatureType::kCategorical) { + go_left = common::Decision(node_cats, common::AsCat(cut_value)); + } else { + go_left = cut_value <= split_node.SplitCond(); + } + if (go_left) { new_position = split_node.LeftChild(); } else { new_position = split_node.RightChild(); @@ -434,48 +458,77 @@ struct GPUHistMakerDevice { // prediction cache void FinalisePosition(RegTree const* p_tree, DMatrix* p_fmat) { dh::TemporaryArray d_nodes(p_tree->GetNodes().size()); - dh::safe_cuda(cudaMemcpy(d_nodes.data().get(), p_tree->GetNodes().data(), - d_nodes.size() * sizeof(RegTree::Node), - cudaMemcpyHostToDevice)); + dh::safe_cuda(cudaMemcpyAsync(d_nodes.data().get(), p_tree->GetNodes().data(), + d_nodes.size() * sizeof(RegTree::Node), + cudaMemcpyHostToDevice)); + auto const& h_split_types = p_tree->GetSplitTypes(); + auto const& categories = p_tree->GetSplitCategories(); + auto const& categories_segments = p_tree->GetSplitCategoriesPtr(); + + dh::device_vector d_split_types; + dh::device_vector d_categories; + dh::device_vector d_categories_segments; + + 
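// Stage the host-side split metadata on the device so the position-finalising kernel can
+    // follow categorical splits: the per-node split types, the flat array of category
+    // bitfield words, and each node's [beg, size) segment into that array.
+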
dh::CopyToD(h_split_types, &d_split_types); + dh::CopyToD(categories, &d_categories); + dh::CopyToD(categories_segments, &d_categories_segments); if (row_partitioner->GetRows().size() != p_fmat->Info().num_row_) { row_partitioner.reset(); // Release the device memory first before reallocating row_partitioner.reset(new RowPartitioner(device_id, p_fmat->Info().num_row_)); } if (page->n_rows == p_fmat->Info().num_row_) { - FinalisePositionInPage(page, dh::ToSpan(d_nodes)); + FinalisePositionInPage(page, dh::ToSpan(d_nodes), + dh::ToSpan(d_split_types), dh::ToSpan(d_categories), + dh::ToSpan(d_categories_segments)); } else { for (auto& batch : p_fmat->GetBatches(batch_param)) { - FinalisePositionInPage(batch.Impl(), dh::ToSpan(d_nodes)); + FinalisePositionInPage(batch.Impl(), dh::ToSpan(d_nodes), + dh::ToSpan(d_split_types), dh::ToSpan(d_categories), + dh::ToSpan(d_categories_segments)); } } } - void FinalisePositionInPage(EllpackPageImpl* page, const common::Span d_nodes) { + void FinalisePositionInPage(EllpackPageImpl *page, + const common::Span d_nodes, + common::Span d_feature_types, + common::Span categories, + common::Span categories_segments) { auto d_matrix = page->GetDeviceAccessor(device_id); row_partitioner->FinalisePosition( [=] __device__(size_t row_id, int position) { - if (!d_matrix.IsInRange(row_id)) { - return RowPartitioner::kIgnoredTreePosition; - } - auto node = d_nodes[position]; + // What happens if user prune the tree? + if (!d_matrix.IsInRange(row_id)) { + return RowPartitioner::kIgnoredTreePosition; + } + auto node = d_nodes[position]; - while (!node.IsLeaf()) { - bst_float element = d_matrix.GetFvalue(row_id, node.SplitIndex()); - // Missing value - if (isnan(element)) { - position = node.DefaultChild(); - } else { - if (element <= node.SplitCond()) { - position = node.LeftChild(); - } else { - position = node.RightChild(); + while (!node.IsLeaf()) { + bst_float element = d_matrix.GetFvalue(row_id, node.SplitIndex()); + // Missing value + if (isnan(element)) { + position = node.DefaultChild(); + } else { + bool go_left = true; + if (common::IsCat(d_feature_types, position)) { + auto node_cats = + categories.subspan(categories_segments[position].beg, + categories_segments[position].size); + go_left = common::Decision(node_cats, common::AsCat(element)); + } else { + go_left = element <= node.SplitCond(); + } + if (go_left) { + position = node.LeftChild(); + } else { + position = node.RightChild(); + } + } + node = d_nodes[position]; } - } - node = d_nodes[position]; - } - return position; - }); + return position; + }); } void UpdatePredictionCache(bst_float* out_preds_d) { @@ -507,7 +560,7 @@ struct GPUHistMakerDevice { weight * param_d.learning_rate; }); - dh::safe_cuda(cudaMemcpy( + dh::safe_cuda(cudaMemcpyAsync( out_preds_d, prediction_cache.data().get(), prediction_cache.size() * sizeof(bst_float), cudaMemcpyDefault)); row_partitioner.reset(); @@ -570,11 +623,32 @@ struct GPUHistMakerDevice { auto right_weight = node_value_constraints[candidate.nid].CalcWeight( param, candidate.split.right_sum) * param.learning_rate; - tree.ExpandNode(candidate.nid, candidate.split.findex, - candidate.split.fvalue, candidate.split.dir == kLeftDir, - base_weight, left_weight, right_weight, - candidate.split.loss_chg, parent_sum.GetHess(), - candidate.split.left_sum.GetHess(), candidate.split.right_sum.GetHess()); + + auto is_cat = candidate.split.is_cat; + if (is_cat) { + auto cat = common::AsCat(candidate.split.fvalue); + std::vector 
split_cats(LBitField32::ComputeStorageSize(std::max(cat+1, 1))); + LBitField32 cats_bits(split_cats); + cats_bits.Set(cat); + node_categories.resize(split_cats.size()); + dh::safe_cuda(cudaMemcpyAsync( + node_categories.data().get(), split_cats.data(), + split_cats.size() * sizeof(uint32_t), cudaMemcpyHostToDevice)); + tree.ExpandCategorical( + candidate.nid, candidate.split.findex, split_cats, + candidate.split.dir == kLeftDir, base_weight, left_weight, + right_weight, candidate.split.loss_chg, parent_sum.GetHess(), + candidate.split.left_sum.GetHess(), + candidate.split.right_sum.GetHess()); + } else { + tree.ExpandNode(candidate.nid, candidate.split.findex, + candidate.split.fvalue, candidate.split.dir == kLeftDir, + base_weight, left_weight, right_weight, + candidate.split.loss_chg, parent_sum.GetHess(), + candidate.split.left_sum.GetHess(), + candidate.split.right_sum.GetHess()); + } + // Set up child constraints node_value_constraints.resize(tree.GetNodes().size()); node_value_constraints[candidate.nid].SetChild( @@ -655,9 +729,9 @@ struct GPUHistMakerDevice { int right_child_nidx = tree[candidate.nid].RightChild(); // Only create child entries if needed if (ExpandEntry::ChildIsValid(param, tree.GetDepth(left_child_nidx), - num_leaves)) { + num_leaves)) { monitor.Start("UpdatePosition"); - this->UpdatePosition(candidate.nid, (*p_tree)[candidate.nid]); + this->UpdatePosition(candidate.nid, p_tree); monitor.Stop("UpdatePosition"); monitor.Start("BuildHist"); @@ -746,8 +820,10 @@ class GPUHistMakerSpecialised { }; auto page = (*dmat->GetBatches(batch_param).begin()).Impl(); dh::safe_cuda(cudaSetDevice(device_)); + info_->feature_types.SetDevice(device_); maker.reset(new GPUHistMakerDevice(device_, page, + info_->feature_types.ConstDeviceSpan(), info_->num_row_, param_, column_sampling_seed, diff --git a/tests/cpp/common/test_hist_util.cu b/tests/cpp/common/test_hist_util.cu index b225acb2039d..5b35c537dc71 100644 --- a/tests/cpp/common/test_hist_util.cu +++ b/tests/cpp/common/test_hist_util.cu @@ -108,7 +108,7 @@ TEST(HistUtil, DeviceSketchDeterminism) { } } -TEST(HistUtil, DeviceSketchCategorical) { +TEST(HistUtil, DeviceSketchCategoricalAsNumeric) { int categorical_sizes[] = {2, 6, 8, 12}; int num_bins = 256; int sizes[] = {25, 100, 1000}; @@ -122,6 +122,33 @@ TEST(HistUtil, DeviceSketchCategorical) { } } +void TestCategoricalSketch(size_t n, size_t num_categories, int32_t num_bins) { + auto x = GenerateRandomCategoricalSingleColumn(n, num_categories); + auto dmat = GetDMatrixFromData(x, n, 1); + dmat->Info().feature_types.HostVector().push_back(FeatureType::kCategorical); + ASSERT_EQ(dmat->Info().feature_types.Size(), 1); + auto cuts = DeviceSketch(0, dmat.get(), num_bins); + std::sort(x.begin(), x.end()); + auto n_uniques = std::unique(x.begin(), x.end()) - x.begin(); + ASSERT_NE(n_uniques, x.size()); + ASSERT_EQ(cuts.TotalBins(), n_uniques); + ASSERT_EQ(n_uniques, num_categories); + + auto& values = cuts.cut_values_.HostVector(); + ASSERT_TRUE(std::is_sorted(values.cbegin(), values.cend())); + auto is_unique = (std::unique(values.begin(), values.end()) - values.begin()) == n_uniques; + ASSERT_TRUE(is_unique); + + x.resize(n_uniques); + for (size_t i = 0; i < n_uniques; ++i) { + ASSERT_EQ(x[i], values[i]); + } +} + +TEST(HistUtil, DeviceSketchCategoricalFeatures) { + TestCategoricalSketch(1000, 256, 32); +} + TEST(HistUtil, DeviceSketchMultipleColumns) { int bin_sizes[] = {2, 16, 256, 512}; int sizes[] = {100, 1000, 1500}; @@ -237,7 +264,8 @@ TEST(HistUtil, 
DeviceSketchExternalMemoryWithWeights) { template auto MakeUnweightedCutsForTest(Adapter adapter, int32_t num_bins, float missing, size_t batch_size = 0) { common::HistogramCuts batched_cuts; - SketchContainer sketch_container(num_bins, adapter.NumColumns(), adapter.NumRows(), 0); + HostDeviceVector ft; + SketchContainer sketch_container(ft, num_bins, adapter.NumColumns(), adapter.NumRows(), 0); MetaInfo info; AdapterDeviceSketch(adapter.Value(), num_bins, info, std::numeric_limits::quiet_NaN(), &sketch_container); @@ -305,7 +333,8 @@ TEST(HistUtil, AdapterSketchSlidingWindowMemory) { dh::GlobalMemoryLogger().Clear(); ConsoleLogger::Configure({{"verbosity", "3"}}); common::HistogramCuts batched_cuts; - SketchContainer sketch_container(num_bins, num_columns, num_rows, 0); + HostDeviceVector ft; + SketchContainer sketch_container(ft, num_bins, num_columns, num_rows, 0); AdapterDeviceSketch(adapter.Value(), num_bins, info, std::numeric_limits::quiet_NaN(), &sketch_container); HistogramCuts cuts; @@ -332,10 +361,12 @@ TEST(HistUtil, AdapterSketchSlidingWindowWeightedMemory) { dh::GlobalMemoryLogger().Clear(); ConsoleLogger::Configure({{"verbosity", "3"}}); common::HistogramCuts batched_cuts; - SketchContainer sketch_container(num_bins, num_columns, num_rows, 0); + HostDeviceVector ft; + SketchContainer sketch_container(ft, num_bins, num_columns, num_rows, 0); AdapterDeviceSketch(adapter.Value(), num_bins, info, std::numeric_limits::quiet_NaN(), &sketch_container); + HistogramCuts cuts; sketch_container.MakeCuts(&cuts); ConsoleLogger::Configure({{"verbosity", "0"}}); @@ -477,9 +508,11 @@ void TestAdapterSketchFromWeights(bool with_group) { data::CupyAdapter adapter(m); auto const& batch = adapter.Value(); - SketchContainer sketch_container(kBins, kCols, kRows, 0); + HostDeviceVector ft; + SketchContainer sketch_container(ft, kBins, kCols, kRows, 0); AdapterDeviceSketch(adapter.Value(), kBins, info, std::numeric_limits::quiet_NaN(), &sketch_container); + common::HistogramCuts cuts; sketch_container.MakeCuts(&cuts); diff --git a/tests/cpp/common/test_hist_util.h b/tests/cpp/common/test_hist_util.h index d025e5ea60bf..c2c9cdb296ec 100644 --- a/tests/cpp/common/test_hist_util.h +++ b/tests/cpp/common/test_hist_util.h @@ -6,6 +6,7 @@ #include #include #include "../../../src/common/hist_util.h" +#include "../../../src/common/categorical.h" #include "../../../src/data/simple_dmatrix.h" #include "../../../src/data/adapter.h" @@ -60,26 +61,6 @@ inline data::CupyAdapter AdapterFromData(const thrust::device_vector &x, } #endif -inline std::vector GenerateRandomCategoricalSingleColumn(int n, - int num_categories) { - std::vector x(n); - std::mt19937 rng(0); - std::uniform_int_distribution dist(0, num_categories - 1); - std::generate(x.begin(), x.end(), [&]() { return dist(rng); }); - // Make sure each category is present - for(auto i = 0; i < num_categories; i++) { - x[i] = i; - } - return x; -} - -inline std::shared_ptr -GetDMatrixFromData(const std::vector &x, int num_rows, int num_columns) { - data::DenseAdapter adapter(x.data(), num_rows, num_columns); - return std::shared_ptr(new data::SimpleDMatrix( - &adapter, std::numeric_limits::quiet_NaN(), 1)); -} - inline std::shared_ptr GetExternalMemoryDMatrixFromData( const std::vector& x, int num_rows, int num_columns, size_t page_size, const dmlc::TemporaryDirectory& tempdir) { @@ -147,12 +128,14 @@ inline void TestRank(const std::vector &column_cuts, inline void ValidateColumn(const HistogramCuts& cuts, int column_idx, const std::vector& sorted_column, 
const std::vector& sorted_weights, - size_t num_bins) { - + size_t num_bins, bool is_cat = false) { // Check the endpoints are correct CHECK_GT(sorted_column.size(), 0); - EXPECT_LT(cuts.MinValues().at(column_idx), sorted_column.front()); - EXPECT_GT(cuts.Values()[cuts.Ptrs()[column_idx]], sorted_column.front()); + if (is_cat) { + EXPECT_EQ(cuts.Values()[cuts.Ptrs()[column_idx]], sorted_column.front()); + } else { + EXPECT_GT(cuts.Values()[cuts.Ptrs()[column_idx]], sorted_column.front()); + } EXPECT_GE(cuts.Values()[cuts.Ptrs()[column_idx+1]-1], sorted_column.back()); // Check the cuts are sorted @@ -187,7 +170,9 @@ inline void ValidateColumn(const HistogramCuts& cuts, int column_idx, inline void ValidateCuts(const HistogramCuts& cuts, DMatrix* dmat, int num_bins) { // Collect data into columns - std::vector> columns(dmat->Info().num_col_); + auto const& info = dmat->Info(); + auto const& ft = info.feature_types.ConstHostSpan(); + std::vector> columns(info.num_col_); for (auto& batch : dmat->GetBatches()) { ASSERT_GT(batch.Size(), 0ul); for (auto i = 0ull; i < batch.Size(); i++) { @@ -197,7 +182,7 @@ inline void ValidateCuts(const HistogramCuts& cuts, DMatrix* dmat, } } // Sort - for (auto i = 0ull; i < columns.size(); i++) { + for (auto i = 0ul; i < columns.size(); i++) { auto& col = columns.at(i); const auto& w = dmat->Info().weights_.HostVector(); std::vector index(col.size()); @@ -214,7 +199,7 @@ inline void ValidateCuts(const HistogramCuts& cuts, DMatrix* dmat, } } - ValidateColumn(cuts, i, sorted_column, sorted_weights, num_bins); + ValidateColumn(cuts, i, sorted_column, sorted_weights, num_bins, IsCat(ft, i)); } } diff --git a/tests/cpp/common/test_quantile.cc b/tests/cpp/common/test_quantile.cc index fa748de1cc6c..345bfe5d4c4c 100644 --- a/tests/cpp/common/test_quantile.cc +++ b/tests/cpp/common/test_quantile.cc @@ -1,11 +1,12 @@ #include +#include "test_hist_util.h" #include "test_quantile.h" + #include "../../../src/common/quantile.h" #include "../../../src/common/hist_util.h" namespace xgboost { namespace common { - TEST(Quantile, LoadBalance) { size_t constexpr kRows = 1000, kCols = 100; auto m = RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(); @@ -183,5 +184,17 @@ TEST(Quantile, SameOnAllWorkers) { rabit::Finalize(); #endif // defined(__unix__) } + +TEST(CPUQuantile, FromOneHot) { + std::vector x = BasicOneHotEncodedData(); + auto m = GetDMatrixFromData(x, 5, 3); + + int32_t max_bins = 16; + HistogramCuts cuts = SketchOnDMatrix(m.get(), max_bins); + + std::vector const& h_cuts_ptr = cuts.Ptrs(); + std::vector h_cuts_values = cuts.Values(); + ValidateBasicOneHot(h_cuts_ptr, h_cuts_values); +} } // namespace common } // namespace xgboost diff --git a/tests/cpp/common/test_quantile.cu b/tests/cpp/common/test_quantile.cu index f7c7e22e3650..65d61d532ec6 100644 --- a/tests/cpp/common/test_quantile.cu +++ b/tests/cpp/common/test_quantile.cu @@ -1,6 +1,7 @@ #include #include "test_quantile.h" #include "../helpers.h" +#include "test_quantile.h" #include "../../../src/common/hist_util.cuh" #include "../../../src/common/quantile.cuh" @@ -8,7 +9,8 @@ namespace xgboost { namespace common { TEST(GPUQuantile, Basic) { constexpr size_t kRows = 1000, kCols = 100, kBins = 256; - SketchContainer sketch(kBins, kCols, kRows, 0); + HostDeviceVector ft; + SketchContainer sketch(ft, kBins, kCols, kRows, 0); dh::caching_device_vector entries; dh::device_vector cuts_ptr(kCols+1); thrust::fill(cuts_ptr.begin(), cuts_ptr.end(), 0); @@ -20,7 +22,8 @@ TEST(GPUQuantile, Basic) { void 
TestSketchUnique(float sparsity) { constexpr size_t kRows = 1000, kCols = 100; RunWithSeedsAndBins(kRows, [kRows, kCols, sparsity](int32_t seed, size_t n_bins, MetaInfo const& info) { - SketchContainer sketch(n_bins, kCols, kRows, 0); + HostDeviceVector ft; + SketchContainer sketch(ft, n_bins, kCols, kRows, 0); HostDeviceVector storage; std::string interface_str = RandomDataGenerator{kRows, kCols, sparsity} @@ -94,7 +97,8 @@ void TestQuantileElemRank(int32_t device, Span in, TEST(GPUQuantile, Prune) { constexpr size_t kRows = 1000, kCols = 100; RunWithSeedsAndBins(kRows, [=](int32_t seed, size_t n_bins, MetaInfo const& info) { - SketchContainer sketch(n_bins, kCols, kRows, 0); + HostDeviceVector ft; + SketchContainer sketch(ft, n_bins, kCols, kRows, 0); HostDeviceVector storage; std::string interface_str = RandomDataGenerator{kRows, kCols, 0} @@ -127,7 +131,8 @@ TEST(GPUQuantile, Prune) { TEST(GPUQuantile, MergeEmpty) { constexpr size_t kRows = 1000, kCols = 100; size_t n_bins = 10; - SketchContainer sketch_0(n_bins, kCols, kRows, 0); + HostDeviceVector ft; + SketchContainer sketch_0(ft, n_bins, kCols, kRows, 0); HostDeviceVector storage_0; std::string interface_str_0 = RandomDataGenerator{kRows, kCols, 0}.Device(0).GenerateArrayInterface( @@ -166,7 +171,8 @@ TEST(GPUQuantile, MergeEmpty) { TEST(GPUQuantile, MergeBasic) { constexpr size_t kRows = 1000, kCols = 100; RunWithSeedsAndBins(kRows, [=](int32_t seed, size_t n_bins, MetaInfo const& info) { - SketchContainer sketch_0(n_bins, kCols, kRows, 0); + HostDeviceVector ft; + SketchContainer sketch_0(ft, n_bins, kCols, kRows, 0); HostDeviceVector storage_0; std::string interface_str_0 = RandomDataGenerator{kRows, kCols, 0} .Device(0) @@ -176,7 +182,7 @@ TEST(GPUQuantile, MergeBasic) { AdapterDeviceSketch(adapter_0.Value(), n_bins, info, std::numeric_limits::quiet_NaN(), &sketch_0); - SketchContainer sketch_1(n_bins, kCols, kRows * kRows, 0); + SketchContainer sketch_1(ft, n_bins, kCols, kRows * kRows, 0); HostDeviceVector storage_1; std::string interface_str_1 = RandomDataGenerator{kRows, kCols, 0} .Device(0) @@ -212,7 +218,8 @@ TEST(GPUQuantile, MergeBasic) { void TestMergeDuplicated(int32_t n_bins, size_t cols, size_t rows, float frac) { MetaInfo info; int32_t seed = 0; - SketchContainer sketch_0(n_bins, cols, rows, 0); + HostDeviceVector ft; + SketchContainer sketch_0(ft, n_bins, cols, rows, 0); HostDeviceVector storage_0; std::string interface_str_0 = RandomDataGenerator{rows, cols, 0} .Device(0) @@ -224,7 +231,7 @@ void TestMergeDuplicated(int32_t n_bins, size_t cols, size_t rows, float frac) { &sketch_0); size_t f_rows = rows * frac; - SketchContainer sketch_1(n_bins, cols, f_rows, 0); + SketchContainer sketch_1(ft, n_bins, cols, f_rows, 0); HostDeviceVector storage_1; std::string interface_str_1 = RandomDataGenerator{f_rows, cols, 0} .Device(0) @@ -288,7 +295,8 @@ TEST(GPUQuantile, AllReduceBasic) { constexpr size_t kRows = 1000, kCols = 100; RunWithSeedsAndBins(kRows, [=](int32_t seed, size_t n_bins, MetaInfo const& info) { // Set up single node version; - SketchContainer sketch_on_single_node(n_bins, kCols, kRows, 0); + HostDeviceVector ft; + SketchContainer sketch_on_single_node(ft, n_bins, kCols, kRows, 0); size_t intermediate_num_cuts = std::min(kRows * world, static_cast(n_bins * WQSketch::kFactor)); @@ -300,7 +308,8 @@ TEST(GPUQuantile, AllReduceBasic) { .Seed(rank + seed) .GenerateArrayInterface(&storage); data::CupyAdapter adapter(interface_str); - containers.emplace_back(n_bins, kCols, kRows, 0); + HostDeviceVector ft; + 
containers.emplace_back(ft, n_bins, kCols, kRows, 0); AdapterDeviceSketch(adapter.Value(), n_bins, info, std::numeric_limits::quiet_NaN(), &containers.back()); @@ -317,7 +326,7 @@ TEST(GPUQuantile, AllReduceBasic) { // Set up distributed version. We rely on using rank as seed to generate // the exact same copy of data. auto rank = rabit::GetRank(); - SketchContainer sketch_distributed(n_bins, kCols, kRows, 0); + SketchContainer sketch_distributed(ft, n_bins, kCols, kRows, 0); HostDeviceVector storage; std::string interface_str = RandomDataGenerator{kRows, kCols, 0} .Device(0) @@ -376,7 +385,8 @@ TEST(GPUQuantile, SameOnAllWorkers) { RunWithSeedsAndBins(kRows, [=](int32_t seed, size_t n_bins, MetaInfo const &info) { auto rank = rabit::GetRank(); - SketchContainer sketch_distributed(n_bins, kCols, kRows, 0); + HostDeviceVector ft; + SketchContainer sketch_distributed(ft, n_bins, kCols, kRows, 0); HostDeviceVector storage; std::string interface_str = RandomDataGenerator{kRows, kCols, 0} .Device(0) @@ -433,5 +443,17 @@ TEST(GPUQuantile, SameOnAllWorkers) { return; #endif // !defined(__linux__) && defined(XGBOOST_USE_NCCL) } + +TEST(GPUQuantile, FromOneHot) { + std::vector x = BasicOneHotEncodedData(); + auto m = GetDMatrixFromData(x, 5, 3); + int32_t max_bins = 16; + auto cuts = DeviceSketch(0, m.get(), max_bins); + + std::vector const& h_cuts_ptr = cuts.Ptrs(); + std::vector h_cuts_values = cuts.Values(); + + ValidateBasicOneHot(h_cuts_ptr, h_cuts_values); +} } // namespace common } // namespace xgboost diff --git a/tests/cpp/common/test_quantile.h b/tests/cpp/common/test_quantile.h index e91f19ef84a8..e17465752803 100644 --- a/tests/cpp/common/test_quantile.h +++ b/tests/cpp/common/test_quantile.h @@ -1,4 +1,9 @@ +#ifndef XGBOOST_TEST_QUANTILE_H_ +#define XGBOOST_TEST_QUANTILE_H_ + #include +#include + #include #include #include @@ -50,5 +55,33 @@ template void RunWithSeedsAndBins(size_t rows, Fn fn) { } } } +inline auto BasicOneHotEncodedData() { + std::vector x { + 0, 1, 0, + 0, 1, 0, + 0, 1, 0, + 0, 0, 1, + 1, 0, 0 + }; + return x; +} + +inline void ValidateBasicOneHot(std::vector const &h_cuts_ptr, + std::vector const &h_cuts_values) { + size_t const cols = 3; + ASSERT_EQ(h_cuts_ptr.size(), cols + 1); + ASSERT_EQ(h_cuts_values.size(), cols * 2); + + for (size_t i = 1; i < h_cuts_ptr.size(); ++i) { + auto feature = + common::Span(h_cuts_values) + .subspan(h_cuts_ptr[i - 1], h_cuts_ptr[i] - h_cuts_ptr[i - 1]); + EXPECT_EQ(feature.size(), 2); + // 0 is discarded as min value. 
+ EXPECT_EQ(feature[0], 1.0f); + EXPECT_GT(feature[1], 1.0f); + } +} } // namespace common } // namespace xgboost +#endif // XGBOOST_TEST_QUANTILE_H_ diff --git a/tests/cpp/data/test_ellpack_page.cu b/tests/cpp/data/test_ellpack_page.cu index f92e6a7f9a4d..5d8b80c0b5a9 100644 --- a/tests/cpp/data/test_ellpack_page.cu +++ b/tests/cpp/data/test_ellpack_page.cu @@ -7,10 +7,14 @@ #include "../helpers.h" #include "../histogram_helpers.h" +#include "../common/test_quantile.h" #include "gtest/gtest.h" +#include "../../../src/common/categorical.h" #include "../../../src/common/hist_util.h" #include "../../../src/data/ellpack_page.cuh" +#include "../../../src/data/adapter.h" +#include "../../../src/data/simple_dmatrix.h" namespace xgboost { @@ -77,6 +81,64 @@ TEST(EllpackPage, BuildGidxSparse) { } } +TEST(EllpackPage, FromCategoricalBasic) { + using common::AsCat; + size_t constexpr kRows = 1000, kCats = 13, kCols = 1; + size_t max_bins = 8; + auto x = GenerateRandomCategoricalSingleColumn(kRows, kCats); + auto m = GetDMatrixFromData(x, kRows, 1); + auto& h_ft = m->Info().feature_types.HostVector(); + h_ft.resize(kCols, FeatureType::kCategorical); + + BatchParam p(0, max_bins); + auto ellpack = EllpackPage(m.get(), p); + auto accessor = ellpack.Impl()->GetDeviceAccessor(0); + ASSERT_EQ(kCats, accessor.NumBins()); + + auto x_copy = x; + std::sort(x_copy.begin(), x_copy.end()); + auto n_uniques = std::unique(x_copy.begin(), x_copy.end()) - x_copy.begin(); + ASSERT_EQ(n_uniques, kCats); + + std::vector h_cuts_ptr(accessor.feature_segments.size()); + dh::CopyDeviceSpanToVector(&h_cuts_ptr, accessor.feature_segments); + std::vector h_cuts_values(accessor.gidx_fvalue_map.size()); + dh::CopyDeviceSpanToVector(&h_cuts_values, accessor.gidx_fvalue_map); + + ASSERT_EQ(h_cuts_ptr.size(), 2); + ASSERT_EQ(h_cuts_values.size(), kCats); + + std::vector const &h_gidx_buffer = + ellpack.Impl()->gidx_buffer.HostVector(); + auto h_gidx_iter = common::CompressedIterator( + h_gidx_buffer.data(), accessor.NumSymbols()); + + for (size_t i = 0; i < x.size(); ++i) { + auto bin = h_gidx_iter[i]; + auto bin_value = h_cuts_values.at(bin); + ASSERT_EQ(AsCat(x[i]), AsCat(bin_value)); + } +} + +TEST(EllpackPage, FromOneHot) { + std::vector x = common::BasicOneHotEncodedData(); + auto m = GetDMatrixFromData(x, 5, 3); + int32_t max_bins = 16; + BatchParam p(0, max_bins); + auto ellpack = EllpackPage(m.get(), p); + auto accessor = ellpack.Impl()->GetDeviceAccessor(0); + + std::vector h_cuts_ptr(accessor.feature_segments.size()); + dh::CopyDeviceSpanToVector(&h_cuts_ptr, accessor.feature_segments); + std::vector h_cuts_values(accessor.gidx_fvalue_map.size()); + dh::CopyDeviceSpanToVector(&h_cuts_values, accessor.gidx_fvalue_map); + + size_t const cols = 3; + ASSERT_EQ(h_cuts_ptr.size(), cols + 1); + ASSERT_EQ(h_cuts_values.size(), cols * 2); + common::ValidateBasicOneHot(h_cuts_ptr, h_cuts_values); +} + struct ReadRowFunction { EllpackDeviceAccessor matrix; int row; @@ -194,5 +256,4 @@ TEST(EllpackPage, Compact) { } } } - } // namespace xgboost diff --git a/tests/cpp/helpers.cc b/tests/cpp/helpers.cc index 858b651981fb..1b319a8873ff 100644 --- a/tests/cpp/helpers.cc +++ b/tests/cpp/helpers.cc @@ -17,6 +17,7 @@ #include "helpers.h" #include "xgboost/c_api.h" #include "../../src/data/adapter.h" +#include "../../src/data/simple_dmatrix.h" #include "../../src/gbm/gbtree_model.h" #include "xgboost/predictor.h" @@ -350,6 +351,13 @@ RandomDataGenerator::GenerateDMatrix(bool with_label, bool float_label, return out; } +std::shared_ptr 
+GetDMatrixFromData(const std::vector &x, int num_rows, int num_columns){ + data::DenseAdapter adapter(x.data(), num_rows, num_columns); + return std::shared_ptr(new data::SimpleDMatrix( + &adapter, std::numeric_limits::quiet_NaN(), 1)); +} + std::unique_ptr CreateSparsePageDMatrix( size_t n_entries, size_t page_size, std::string tmp_file) { // Create sufficiently large data to make two row pages @@ -539,5 +547,4 @@ RMMAllocatorPtr SetUpRMMResourceForCppTests(int argc, char** argv) { return RMMAllocatorPtr(nullptr, DeleteRMMResource); } #endif // !defined(XGBOOST_USE_RMM) || XGBOOST_USE_RMM != 1 - } // namespace xgboost diff --git a/tests/cpp/helpers.h b/tests/cpp/helpers.h index 5d4ce6cefa68..be8356c86bac 100644 --- a/tests/cpp/helpers.h +++ b/tests/cpp/helpers.h @@ -42,6 +42,12 @@ struct LearnerModelParam; class GradientBooster; } +template +Float RelError(Float l, Float r) { + static_assert(std::is_floating_point::value, ""); + return std::abs(1.0f - l / r); +} + bool FileExists(const std::string& filename); int64_t GetFileSize(const std::string& filename); @@ -254,6 +260,22 @@ class RandomDataGenerator { #endif }; +inline std::vector +GenerateRandomCategoricalSingleColumn(int n, size_t num_categories) { + std::vector x(n); + std::mt19937 rng(0); + std::uniform_int_distribution dist(0, num_categories - 1); + std::generate(x.begin(), x.end(), [&]() { return dist(rng); }); + // Make sure each category is present + for(size_t i = 0; i < num_categories; i++) { + x[i] = i; + } + return x; +} + +std::shared_ptr +GetDMatrixFromData(const std::vector &x, int num_rows, int num_columns); + std::unique_ptr CreateSparsePageDMatrix( size_t n_entries, size_t page_size, std::string tmp_file); @@ -308,6 +330,9 @@ inline HostDeviceVector GenerateRandomGradients(const size_t n_row return gpair; } +std::shared_ptr GetDMatrixFromData(const std::vector &x, + int num_rows, int num_columns); + typedef void *DMatrixHandle; // NOLINT(*); class CudaArrayIterForTest { diff --git a/tests/cpp/predictor/test_gpu_predictor.cu b/tests/cpp/predictor/test_gpu_predictor.cu index 585acf1790b6..c928c2886d91 100644 --- a/tests/cpp/predictor/test_gpu_predictor.cu +++ b/tests/cpp/predictor/test_gpu_predictor.cu @@ -160,9 +160,11 @@ TEST(GPUPredictor, MGPU_InplacePredict) { // NOLINT dmlc::Error); } + TEST(GpuPredictor, LesserFeatures) { TestPredictionWithLesserFeatures("gpu_predictor"); } + // Very basic test of empty model TEST(GPUPredictor, ShapStump) { cudaSetDevice(0); @@ -189,6 +191,7 @@ TEST(GPUPredictor, ShapStump) { EXPECT_EQ(phis[4], 0.0); EXPECT_EQ(phis[5], param.base_score); } + TEST(GPUPredictor, Shap) { LearnerModelParam param; param.num_feature = 1; @@ -219,5 +222,8 @@ TEST(GPUPredictor, Shap) { } } +TEST(GPUPredictor, CategoricalPrediction) { + TestCategoricalPrediction("gpu_predictor"); +} } // namespace predictor } // namespace xgboost diff --git a/tests/cpp/predictor/test_predictor.cc b/tests/cpp/predictor/test_predictor.cc index 3005d585f5a6..7f6de563125b 100644 --- a/tests/cpp/predictor/test_predictor.cc +++ b/tests/cpp/predictor/test_predictor.cc @@ -12,6 +12,8 @@ #include "../helpers.h" #include "../../../src/common/io.h" +#include "../../../src/common/categorical.h" +#include "../../../src/common/bitfield.h" namespace xgboost { TEST(Predictor, PredictionCache) { @@ -27,7 +29,7 @@ TEST(Predictor, PredictionCache) { }; add_cache(); - ASSERT_EQ(container.Container().size(), 0); + ASSERT_EQ(container.Container().size(), 0ul); add_cache(); EXPECT_ANY_THROW(container.Entry(m)); } @@ -174,4 +176,55 @@ void 
TestPredictionWithLesserFeatures(std::string predictor_name) { } #endif // defined(XGBOOST_USE_CUDA) } + +void TestCategoricalPrediction(std::string name) { + size_t constexpr kCols = 10; + PredictionCacheEntry out_predictions; + + LearnerModelParam param; + param.num_feature = kCols; + param.num_output_group = 1; + param.base_score = 0.5; + + gbm::GBTreeModel model(¶m); + + std::vector> trees; + trees.push_back(std::unique_ptr(new RegTree)); + auto& p_tree = trees.front(); + + uint32_t split_ind = 3; + bst_cat_t split_cat = 4; + float left_weight = 1.3f; + float right_weight = 1.7f; + + std::vector split_cats(LBitField32::ComputeStorageSize(split_cat)); + LBitField32 cats_bits(split_cats); + cats_bits.Set(split_cat); + + p_tree->ExpandCategorical(0, split_ind, split_cats, true, 1.5f, + left_weight, right_weight, + 3.0f, 2.2f, 7.0f, 9.0f); + model.CommitModel(std::move(trees), 0); + + GenericParameter runtime; + runtime.gpu_id = 0; + std::unique_ptr predictor{ + Predictor::Create(name.c_str(), &runtime)}; + + std::vector row(kCols); + row[split_ind] = split_cat; + auto m = GetDMatrixFromData(row, 1, kCols); + + predictor->PredictBatch(m.get(), &out_predictions, model, 0); + ASSERT_EQ(out_predictions.predictions.Size(), 1ul); + ASSERT_EQ(out_predictions.predictions.HostVector()[0], + right_weight + param.base_score); + + row[split_ind] = split_cat + 1; + m = GetDMatrixFromData(row, 1, kCols); + out_predictions.version = 0; + predictor->PredictBatch(m.get(), &out_predictions, model, 0); + ASSERT_EQ(out_predictions.predictions.HostVector()[0], + left_weight + param.base_score); +} } // namespace xgboost diff --git a/tests/cpp/predictor/test_predictor.h b/tests/cpp/predictor/test_predictor.h index b6a3180111f2..68e034e0a581 100644 --- a/tests/cpp/predictor/test_predictor.h +++ b/tests/cpp/predictor/test_predictor.h @@ -61,6 +61,8 @@ void TestInplacePrediction(dmlc::any x, std::string predictor, int32_t device = -1); void TestPredictionWithLesserFeatures(std::string preditor_name); + +void TestCategoricalPrediction(std::string name); } // namespace xgboost #endif // XGBOOST_TEST_PREDICTOR_H_ diff --git a/tests/cpp/tree/gpu_hist/test_evaluate_splits.cu b/tests/cpp/tree/gpu_hist/test_evaluate_splits.cu index 7ec925f185a3..4b9670d83ee6 100644 --- a/tests/cpp/tree/gpu_hist/test_evaluate_splits.cu +++ b/tests/cpp/tree/gpu_hist/test_evaluate_splits.cu @@ -6,7 +6,7 @@ namespace xgboost { namespace tree { -TEST(GpuHist, EvaluateSingleSplit) { +void TestEvaluateSingleSplit(bool is_categorical) { thrust::device_vector out_splits(1); GradientPair parent_sum(0.0, 1.0); GPUTrainingParam param{}; @@ -23,17 +23,26 @@ TEST(GpuHist, EvaluateSingleSplit) { thrust::device_vector feature_histogram = std::vector{ {-0.5, 0.5}, {0.5, 0.5}, {-1.0, 0.5}, {1.0, 0.5}}; + thrust::device_vector monotonic_constraints(feature_set.size(), 0); + dh::device_vector feature_types(feature_set.size(), + FeatureType::kCategorical); + common::Span d_feature_types; + if (is_categorical) { + d_feature_types = dh::ToSpan(feature_types); + } EvaluateSplitInputs input{1, parent_sum, param, dh::ToSpan(feature_set), + d_feature_types, dh::ToSpan(feature_segments), dh::ToSpan(feature_values), dh::ToSpan(feature_min_values), dh::ToSpan(feature_histogram), ValueConstraint(), dh::ToSpan(monotonic_constraints)}; + EvaluateSingleSplit(dh::ToSpan(out_splits), input); DeviceSplitCandidate result = out_splits[0]; @@ -45,6 +54,14 @@ TEST(GpuHist, EvaluateSingleSplit) { parent_sum.GetHess()); } +TEST(GpuHist, EvaluateSingleSplit) { + 
TestEvaluateSingleSplit(false); +} + +TEST(GpuHist, EvaluateCategoricalSplit) { + TestEvaluateSingleSplit(true); +} + TEST(GpuHist, EvaluateSingleSplitMissing) { thrust::device_vector out_splits(1); GradientPair parent_sum(1.0, 1.5); @@ -63,6 +80,7 @@ TEST(GpuHist, EvaluateSingleSplitMissing) { parent_sum, param, dh::ToSpan(feature_set), + {}, dh::ToSpan(feature_segments), dh::ToSpan(feature_values), dh::ToSpan(feature_min_values), @@ -115,6 +133,7 @@ TEST(GpuHist, EvaluateSingleSplitFeatureSampling) { parent_sum, param, dh::ToSpan(feature_set), + {}, dh::ToSpan(feature_segments), dh::ToSpan(feature_values), dh::ToSpan(feature_min_values), @@ -152,6 +171,7 @@ TEST(GpuHist, EvaluateSingleSplitBreakTies) { parent_sum, param, dh::ToSpan(feature_set), + {}, dh::ToSpan(feature_segments), dh::ToSpan(feature_values), dh::ToSpan(feature_min_values), @@ -190,6 +210,7 @@ TEST(GpuHist, EvaluateSplits) { parent_sum, param, dh::ToSpan(feature_set), + {}, dh::ToSpan(feature_segments), dh::ToSpan(feature_values), dh::ToSpan(feature_min_values), @@ -201,6 +222,7 @@ TEST(GpuHist, EvaluateSplits) { parent_sum, param, dh::ToSpan(feature_set), + {}, dh::ToSpan(feature_segments), dh::ToSpan(feature_values), dh::ToSpan(feature_min_values), @@ -217,6 +239,5 @@ TEST(GpuHist, EvaluateSplits) { EXPECT_EQ(result_right.findex, 0); EXPECT_EQ(result_right.fvalue, 1.0); } - } // namespace tree } // namespace xgboost diff --git a/tests/cpp/tree/gpu_hist/test_histogram.cu b/tests/cpp/tree/gpu_hist/test_histogram.cu index 99cc4b835fec..1f88f09aa086 100644 --- a/tests/cpp/tree/gpu_hist/test_histogram.cu +++ b/tests/cpp/tree/gpu_hist/test_histogram.cu @@ -3,6 +3,7 @@ #include "../../helpers.h" #include "../../../../src/tree/gpu_hist/row_partitioner.cuh" #include "../../../../src/tree/gpu_hist/histogram.cuh" +#include "../../../../src/common/categorical.h" namespace xgboost { namespace tree { @@ -30,7 +31,7 @@ void TestDeterministicHistogram(bool is_dense, int shm_size) { FeatureGroups feature_groups(page->Cuts(), page->is_dense, shm_size, sizeof(Gradient)); - + auto rounding = CreateRoundingFactor(gpair.DeviceSpan()); BuildGradientHistogram(page->GetDeviceAccessor(0), feature_groups.DeviceAccessor(0), gpair.DeviceSpan(), @@ -67,7 +68,7 @@ void TestDeterministicHistogram(bool is_dense, int shm_size) { // Use a single feature group to compute the baseline. FeatureGroups single_group(page->Cuts()); - + dh::device_vector baseline(num_bins); BuildGradientHistogram(page->GetDeviceAccessor(0), single_group.DeviceAccessor(0), @@ -97,5 +98,80 @@ TEST(Histogram, GPUDeterministic) { } } } + +std::vector OneHotEncodeFeature(std::vector x, size_t num_cat) { + std::vector ret(x.size() * num_cat, 0); + size_t n_rows = x.size(); + for (size_t r = 0; r < n_rows; ++r) { + bst_cat_t cat = common::AsCat(x[r]); + ret.at(num_cat * r + cat) = 1; + } + return ret; +} + +// Test 1 vs rest categorical histogram is equivalent to one hot encoded data. 
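Concretely, the invariant checked here is: for every category c, the one-hot encoded column's "1" bin accumulates exactly the gradient statistics of the rows whose category is c, while its "0" bin accumulates everything else, so encode_hist[2*c + 1] should match cat_hist[c] and encode_hist[2*c] should match the column total minus cat_hist[c]. A small CPU-only sketch of this identity (plain doubles standing in for GradientPair, and an illustrative absolute tolerance instead of the RelError/kRtEps comparison used in the test below):

    #include <cassert>
    #include <cmath>
    #include <cstddef>
    #include <vector>

    int main() {
      // Toy data: one categorical column with 3 categories and per-row gradients.
      std::vector<int> cat = {0, 2, 1, 2, 0, 1};
      std::vector<double> grad = {0.1, 0.4, -0.2, 0.3, 0.5, -0.1};
      size_t n_cats = 3;

      // 1-vs-rest histogram: one bin per category.
      std::vector<double> cat_hist(n_cats, 0.0);
      double total = 0.0;
      for (size_t i = 0; i < cat.size(); ++i) {
        cat_hist[cat[i]] += grad[i];
        total += grad[i];
      }

      // One-hot encoded histogram: each encoded column has two bins {0, 1}.
      for (size_t c = 0; c < n_cats; ++c) {
        double one = 0.0, zero = 0.0;
        for (size_t i = 0; i < cat.size(); ++i) {
          if (cat[i] == static_cast<int>(c)) {
            one += grad[i];
          } else {
            zero += grad[i];
          }
        }
        assert(std::abs(one - cat_hist[c]) < 1e-12);             // "1" bin == chosen category
        assert(std::abs(zero - (total - cat_hist[c])) < 1e-12);  // "0" bin == everything else
      }
      return 0;
    }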
+void TestGPUHistogramCategorical(size_t num_categories) { + size_t constexpr kRows = 340; + size_t constexpr kBins = 256; + auto x = GenerateRandomCategoricalSingleColumn(kRows, num_categories); + auto cat_m = GetDMatrixFromData(x, kRows, 1); + cat_m->Info().feature_types.HostVector().push_back(FeatureType::kCategorical); + BatchParam batch_param{0, static_cast(kBins), 0}; + tree::RowPartitioner row_partitioner(0, kRows); + auto ridx = row_partitioner.GetRows(0); + dh::device_vector cat_hist(num_categories); + auto gpair = GenerateRandomGradients(kRows, 0, 2); + gpair.SetDevice(0); + auto rounding = CreateRoundingFactor(gpair.DeviceSpan()); + // Generate hist with cat data. + for (auto const &batch : cat_m->GetBatches(batch_param)) { + auto* page = batch.Impl(); + FeatureGroups single_group(page->Cuts()); + BuildGradientHistogram(page->GetDeviceAccessor(0), + single_group.DeviceAccessor(0), + gpair.DeviceSpan(), ridx, dh::ToSpan(cat_hist), + rounding); + } + + // Generate hist with one hot encoded data. + auto x_encoded = OneHotEncodeFeature(x, num_categories); + auto encode_m = GetDMatrixFromData(x_encoded, kRows, num_categories); + dh::device_vector encode_hist(2 * num_categories); + for (auto const &batch : encode_m->GetBatches(batch_param)) { + auto* page = batch.Impl(); + FeatureGroups single_group(page->Cuts()); + BuildGradientHistogram(page->GetDeviceAccessor(0), + single_group.DeviceAccessor(0), + gpair.DeviceSpan(), ridx, dh::ToSpan(encode_hist), + rounding); + } + + std::vector h_cat_hist(cat_hist.size()); + thrust::copy(cat_hist.begin(), cat_hist.end(), h_cat_hist.begin()); + auto cat_sum = std::accumulate(h_cat_hist.begin(), h_cat_hist.end(), GradientPairPrecise{}); + + std::vector h_encode_hist(encode_hist.size()); + thrust::copy(encode_hist.begin(), encode_hist.end(), h_encode_hist.begin()); + + for (size_t c = 0; c < num_categories; ++c) { + auto zero = h_encode_hist[c * 2]; + auto one = h_encode_hist[c * 2 + 1]; + + auto chosen = h_cat_hist[c]; + auto not_chosen = cat_sum - chosen; + + ASSERT_LE(RelError(zero.GetGrad(), not_chosen.GetGrad()), kRtEps); + ASSERT_LE(RelError(zero.GetHess(), not_chosen.GetHess()), kRtEps); + + ASSERT_LE(RelError(one.GetGrad(), chosen.GetGrad()), kRtEps); + ASSERT_LE(RelError(one.GetHess(), chosen.GetHess()), kRtEps); + } +} + +TEST(Histogram, GPUHistCategorical) { + for (size_t num_categories = 2; num_categories < 8; ++num_categories) { + TestGPUHistogramCategorical(num_categories); + } +} } // namespace tree } // namespace xgboost diff --git a/tests/cpp/tree/test_gpu_hist.cu b/tests/cpp/tree/test_gpu_hist.cu index 5199a27d26e8..b969e79bf339 100644 --- a/tests/cpp/tree/test_gpu_hist.cu +++ b/tests/cpp/tree/test_gpu_hist.cu @@ -80,7 +80,7 @@ void TestBuildHist(bool use_shared_memory_histograms) { param.Init(args); auto page = BuildEllpackPage(kNRows, kNCols); BatchParam batch_param{}; - GPUHistMakerDevice maker(0, page.get(), kNRows, param, kNCols, kNCols, + GPUHistMakerDevice maker(0, page.get(), {}, kNRows, param, kNCols, kNCols, true, batch_param); xgboost::SimpleLCG gen; xgboost::SimpleRealUniformDistribution dist(0.0f, 1.0f); @@ -154,19 +154,18 @@ TEST(GpuHist, EvaluateRootSplit) { TrainParam param; - std::vector> args { - {"max_depth", "1"}, - {"max_leaves", "0"}, - - // Disable all other parameters. 
- {"colsample_bynode", "1"}, - {"colsample_bylevel", "1"}, - {"colsample_bytree", "1"}, - {"min_child_weight", "0.01"}, - {"reg_alpha", "0"}, - {"reg_lambda", "0"}, - {"max_delta_step", "0"} - }; + std::vector> args{ + {"max_depth", "1"}, + {"max_leaves", "0"}, + + // Disable all other parameters. + {"colsample_bynode", "1"}, + {"colsample_bylevel", "1"}, + {"colsample_bytree", "1"}, + {"min_child_weight", "0.01"}, + {"reg_alpha", "0"}, + {"reg_lambda", "0"}, + {"max_delta_step", "0"}}; param.Init(args); for (size_t i = 0; i < kNCols; ++i) { param.monotone_constraints.emplace_back(0); @@ -178,7 +177,7 @@ TEST(GpuHist, EvaluateRootSplit) { auto page = BuildEllpackPage(kNRows, kNCols); BatchParam batch_param{}; GPUHistMakerDevice - maker(0, page.get(), kNRows, param, kNCols, kNCols, true, batch_param); + maker(0, page.get(), {}, kNRows, param, kNCols, kNCols, true, batch_param); // Initialize GPUHistMakerDevice::node_sum_gradients maker.node_sum_gradients = {}; @@ -261,7 +260,6 @@ void TestHistogramIndexImpl() { ASSERT_EQ(maker->page->Cuts().TotalBins(), maker_ext->page->Cuts().TotalBins()); ASSERT_EQ(maker->page->gidx_buffer.Size(), maker_ext->page->gidx_buffer.Size()); - } TEST(GpuHist, TestHistogramIndex) { diff --git a/tests/python-gpu/test_gpu_updaters.py b/tests/python-gpu/test_gpu_updaters.py index ce555bd6a5a6..a1b607865305 100644 --- a/tests/python-gpu/test_gpu_updaters.py +++ b/tests/python-gpu/test_gpu_updaters.py @@ -41,6 +41,52 @@ def test_gpu_hist(self, param, num_rounds, dataset): note(result) assert tm.non_increasing(result['train'][dataset.metric]) + def run_categorical_basic(self, cat, onehot, label, rounds): + by_etl_results = {} + by_builtin_results = {} + + parameters = {'tree_method': 'gpu_hist', + 'predictor': 'gpu_predictor', + 'enable_experimental_json_serialization': True} + + m = xgb.DMatrix(onehot, label, enable_categorical=True) + xgb.train(parameters, m, + num_boost_round=rounds, + evals=[(m, 'Train')], evals_result=by_etl_results) + + m = xgb.DMatrix(cat, label, enable_categorical=True) + xgb.train(parameters, m, + num_boost_round=rounds, + evals=[(m, 'Train')], evals_result=by_builtin_results) + np.testing.assert_allclose( + np.array(by_etl_results['Train']['rmse']), + np.array(by_builtin_results['Train']['rmse']), + rtol=1e-4) + assert tm.non_increasing(by_builtin_results['Train']['rmse']) + + @given(strategies.integers(10, 400), strategies.integers(5, 10), + strategies.integers(1, 6), strategies.integers(4, 8)) + @settings(deadline=None) + @pytest.mark.skipif(**tm.no_pandas()) + def test_categorical(self, rows, cols, rounds, cats): + import pandas as pd + rng = np.random.RandomState(1994) + + pd_dict = {} + for i in range(cols): + c = rng.randint(low=0, high=cats+1, size=rows) + pd_dict[str(i)] = pd.Series(c, dtype=np.int64) + + df = pd.DataFrame(pd_dict) + label = df.iloc[:, 0] + for i in range(0, cols-1): + label += df.iloc[:, i] + label += 1 + df = df.astype('category') + x = pd.get_dummies(df) + + self.run_categorical_basic(df, x, label, rounds) + @pytest.mark.skipif(**tm.no_cupy()) @given(parameter_strategy, strategies.integers(1, 20), tm.dataset_strategy) diff --git a/tests/python-gpu/test_gpu_with_dask.py b/tests/python-gpu/test_gpu_with_dask.py index a06bfc28361f..0b9d68491974 100644 --- a/tests/python-gpu/test_gpu_with_dask.py +++ b/tests/python-gpu/test_gpu_with_dask.py @@ -165,7 +165,8 @@ def test_dask_dataframe(self, local_cuda_cluster): @settings(deadline=duration(seconds=120)) @pytest.mark.skipif(**tm.no_dask()) 
@pytest.mark.skipif(**tm.no_dask_cuda()) - @pytest.mark.parametrize('local_cuda_cluster', [{'n_workers': 2}], indirect=['local_cuda_cluster']) + @pytest.mark.parametrize('local_cuda_cluster', [{'n_workers': 2}], + indirect=['local_cuda_cluster']) @pytest.mark.mgpu def test_gpu_hist(self, params, num_rounds, dataset, local_cuda_cluster): with Client(local_cuda_cluster) as client: diff --git a/tests/python/test_with_pandas.py b/tests/python/test_with_pandas.py index 04f8c9510629..bd0c36280598 100644 --- a/tests/python/test_with_pandas.py +++ b/tests/python/test_with_pandas.py @@ -67,7 +67,8 @@ def test_pandas(self): # 0 1 1 0 0 # 1 2 0 1 0 # 2 3 0 0 1 - result, _, _ = xgb.data._transform_pandas_df(dummies) + result, _, _ = xgb.data._transform_pandas_df(dummies, + enable_categorical=False) exp = np.array([[1., 1., 0., 0.], [2., 0., 1., 0.], [3., 0., 0., 1.]]) @@ -109,6 +110,16 @@ def test_pandas(self): assert dm.num_row() == 2 assert dm.num_col() == 6 + def test_pandas_categorical(self): + rng = np.random.RandomState(1994) + rows = 100 + X = rng.randint(3, 7, size=rows) + X = pd.Series(X, dtype="category") + X = pd.DataFrame({'f0': X}) + y = rng.randn(rows) + m = xgb.DMatrix(X, y, enable_categorical=True) + assert m.feature_types[0] == 'categorical' + def test_pandas_sparse(self): import pandas as pd rows = 100 @@ -129,15 +140,15 @@ def test_pandas_label(self): # label must be a single column df = pd.DataFrame({'A': ['X', 'Y', 'Z'], 'B': [1, 2, 3]}) self.assertRaises(ValueError, xgb.data._transform_pandas_df, df, - None, None, 'label', 'float') + False, None, None, 'label', 'float') # label must be supported dtype df = pd.DataFrame({'A': np.array(['a', 'b', 'c'], dtype=object)}) self.assertRaises(ValueError, xgb.data._transform_pandas_df, df, - None, None, 'label', 'float') + False, None, None, 'label', 'float') df = pd.DataFrame({'A': np.array([1, 2, 3], dtype=int)}) - result, _, _ = xgb.data._transform_pandas_df(df, None, None, + result, _, _ = xgb.data._transform_pandas_df(df, False, None, None, 'label', 'float') np.testing.assert_array_equal(result, np.array([[1.], [2.], [3.]], dtype=float)) diff --git a/tests/python/testing.py b/tests/python/testing.py index a81e7ea87048..f24b0f357842 100644 --- a/tests/python/testing.py +++ b/tests/python/testing.py @@ -44,6 +44,17 @@ def no_dt(): 'reason': 'Datatable is not installed.'} +def no_cupy(): + reason = 'cupy is not installed.' + try: + import cupy # noqa + return {'condition': False, + 'reason': reason} + except ImportError: + return {'condition': True, + 'reason': reason} + + def no_matplotlib(): reason = 'Matplotlib is not installed.' try: