From 23b9c61a686451c9d34b036c41cb6317bdc5cfbb Mon Sep 17 00:00:00 2001 From: Dmitry Razdoburdin Date: Tue, 10 Oct 2023 18:58:51 +0200 Subject: [PATCH] Release 2.0.0 oneapi (#5) * initial, not ready for work * fixes for obj functions * fix some compilation problems * fix some errors * fixes * improve context * plugin compiled and somtimes works * fix the errors. tests passed * fix compilation error wo oneapi * black * README update --------- Co-authored-by: Dmitry Razdoburdin <> --- CMakeLists.txt | 4 + include/xgboost/context.h | 97 +- include/xgboost/linalg.h | 4 +- plugin/CMakeLists.txt | 4 + plugin/updater_oneapi/README.md | 26 +- plugin/updater_oneapi/data_oneapi.h | 273 +++ .../updater_oneapi/device_manager_oneapi.cc | 105 ++ plugin/updater_oneapi/device_manager_oneapi.h | 44 + plugin/updater_oneapi/hist_util_oneapi.cc | 498 ++++++ plugin/updater_oneapi/hist_util_oneapi.h | 377 +++++ .../updater_oneapi/multiclass_obj_oneapi.cc | 285 ++++ plugin/updater_oneapi/param_oneapi.h | 216 +++ plugin/updater_oneapi/predictor_oneapi.cc | 428 +++-- .../updater_oneapi/regression_loss_oneapi.h | 18 +- .../updater_oneapi/regression_obj_oneapi.cc | 93 +- plugin/updater_oneapi/row_set_oneapi.h | 261 +++ .../updater_oneapi/split_evaluator_oneapi.h | 192 +++ .../updater_quantile_hist_oneapi.cc | 1501 +++++++++++++++++ .../updater_quantile_hist_oneapi.h | 611 +++++++ python-package/xgboost/sklearn.py | 7 +- src/CMakeLists.txt | 4 + src/common/linalg_op.cuh | 2 +- src/common/linalg_op.h | 2 +- src/common/numeric.cc | 5 +- src/common/optional_weight.h | 2 +- src/common/ranking_utils.h | 42 +- src/common/stats.cc | 8 +- src/context.cc | 56 +- src/gbm/gbtree.cc | 26 +- src/learner.cc | 5 +- src/metric/elementwise_metric.cu | 53 +- src/objective/adaptive.h | 8 +- src/objective/lambdarank_obj.cc | 10 +- src/objective/objective.cc | 11 +- src/objective/quantile_obj.cu | 48 +- src/objective/regression_obj.cu | 8 +- src/tree/fit_stump.cc | 4 +- tests/python-oneapi/test_oneapi_prediction.py | 151 ++ .../test_oneapi_training_continuation.py | 56 + tests/python-oneapi/test_oneapi_updaters.py | 70 + .../python-oneapi/test_oneapi_with_sklearn.py | 35 + 41 files changed, 5233 insertions(+), 417 deletions(-) create mode 100644 plugin/updater_oneapi/data_oneapi.h create mode 100644 plugin/updater_oneapi/device_manager_oneapi.cc create mode 100644 plugin/updater_oneapi/device_manager_oneapi.h create mode 100644 plugin/updater_oneapi/hist_util_oneapi.cc create mode 100644 plugin/updater_oneapi/hist_util_oneapi.h create mode 100644 plugin/updater_oneapi/multiclass_obj_oneapi.cc create mode 100644 plugin/updater_oneapi/param_oneapi.h create mode 100644 plugin/updater_oneapi/row_set_oneapi.h create mode 100644 plugin/updater_oneapi/split_evaluator_oneapi.h create mode 100644 plugin/updater_oneapi/updater_quantile_hist_oneapi.cc create mode 100644 plugin/updater_oneapi/updater_quantile_hist_oneapi.h create mode 100644 tests/python-oneapi/test_oneapi_prediction.py create mode 100644 tests/python-oneapi/test_oneapi_training_continuation.py create mode 100644 tests/python-oneapi/test_oneapi_updaters.py create mode 100644 tests/python-oneapi/test_oneapi_with_sklearn.py diff --git a/CMakeLists.txt b/CMakeLists.txt index a5eebef2eddd..e524d2aaf7f2 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -170,6 +170,10 @@ if (USE_CUDA) find_package(CUDAToolkit REQUIRED) endif (USE_CUDA) +if (PLUGIN_UPDATER_ONEAPI) + target_compile_definitions(xgboost PRIVATE -DXGBOOST_USE_ONEAPI=1) +endif (PLUGIN_UPDATER_ONEAPI) + if (FORCE_COLORED_OUTPUT 
AND (CMAKE_GENERATOR STREQUAL "Ninja") AND ((CMAKE_CXX_COMPILER_ID STREQUAL "GNU") OR (CMAKE_CXX_COMPILER_ID STREQUAL "Clang"))) diff --git a/include/xgboost/context.h b/include/xgboost/context.h index 262733b220d4..7578b7bfb658 100644 --- a/include/xgboost/context.h +++ b/include/xgboost/context.h @@ -22,19 +22,29 @@ struct CUDAContext; struct DeviceSym { static auto constexpr CPU() { return "cpu"; } static auto constexpr CUDA() { return "cuda"; } + static auto constexpr SYCL_default() { return "sycl"; } + static auto constexpr SYCL_CPU() { return "sycl:cpu"; } + static auto constexpr SYCL_GPU() { return "sycl:gpu"; } }; /** * @brief A type for device ordinal. The type is packed into 32-bit for efficient use in * viewing types like `linalg::TensorView`. */ +constexpr static bst_d_ordinal_t kDefaultOrdinal = -1; struct DeviceOrd { - enum Type : std::int16_t { kCPU = 0, kCUDA = 1 } device{kCPU}; - // CUDA device ordinal. - bst_d_ordinal_t ordinal{-1}; + enum Type : std::int16_t { kCPU = 0, kCUDA = 1, kSyclDefault = 2, kSyclCPU = 3, kSyclGPU = 4} device{kCPU}; + // CUDA or Sycl device ordinal. + bst_d_ordinal_t ordinal{kDefaultOrdinal}; [[nodiscard]] bool IsCUDA() const { return device == kCUDA; } [[nodiscard]] bool IsCPU() const { return device == kCPU; } + [[nodiscard]] bool IsSyclDefault() const { return device == kSyclDefault; } + [[nodiscard]] bool IsSyclCPU() const { return device == kSyclCPU; } + [[nodiscard]] bool IsSyclGPU() const { return device == kSyclGPU; } + [[nodiscard]] bool IsSycl() const { return (IsSyclDefault() || + IsSyclCPU() || + IsSyclGPU()); } DeviceOrd() = default; constexpr DeviceOrd(Type type, bst_d_ordinal_t ord) : device{type}, ordinal{ord} {} @@ -47,7 +57,7 @@ struct DeviceOrd { /** * @brief Constructor for CPU. */ - [[nodiscard]] constexpr static auto CPU() { return DeviceOrd{kCPU, -1}; } + [[nodiscard]] constexpr static auto CPU() { return DeviceOrd{kCPU, kDefaultOrdinal}; } /** * @brief Constructor for CUDA device. * @@ -55,6 +65,27 @@ struct DeviceOrd { */ [[nodiscard]] static auto CUDA(bst_d_ordinal_t ordinal) { return DeviceOrd{kCUDA, ordinal}; } + /** + * @brief Constructor for SYCL. + * + * @param ordinal SYCL device ordinal. + */ + [[nodiscard]] constexpr static auto SYCL_default(bst_d_ordinal_t ordinal = kDefaultOrdinal) { return DeviceOrd{kSyclDefault, ordinal}; } + + /** + * @brief Constructor for SYCL CPU. + * + * @param ordinal SYCL CPU device ordinal. + */ + [[nodiscard]] constexpr static auto SYCL_CPU(bst_d_ordinal_t ordinal = kDefaultOrdinal) { return DeviceOrd{kSyclCPU, ordinal}; } + + /** + * @brief Constructor for SYCL GPU. + * + * @param ordinal SYCL GPU device ordinal. + */ + [[nodiscard]] constexpr static auto SYCL_GPU(bst_d_ordinal_t ordinal = kDefaultOrdinal) { return DeviceOrd{kSyclGPU, ordinal}; } + [[nodiscard]] bool operator==(DeviceOrd const& that) const { return device == that.device && ordinal == that.ordinal; } @@ -68,6 +99,12 @@ struct DeviceOrd { return DeviceSym::CPU(); case DeviceOrd::kCUDA: return DeviceSym::CUDA() + (':' + std::to_string(ordinal)); + case DeviceOrd::kSyclDefault: + return DeviceSym::SYCL_default() + (':' + std::to_string(ordinal)); + case DeviceOrd::kSyclCPU: + return DeviceSym::SYCL_CPU() + (':' + std::to_string(ordinal)); + case DeviceOrd::kSyclGPU: + return DeviceSym::SYCL_GPU() + (':' + std::to_string(ordinal)); default: { LOG(FATAL) << "Unknown device."; return ""; @@ -135,6 +172,25 @@ struct Context : public XGBoostParameter { * @brief Is XGBoost running on a CUDA device? 
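For reference, a minimal host-side sketch (not part of this patch) of how the new SYCL `DeviceOrd` constructors and `Name()` added to `include/xgboost/context.h` relate to the `sycl:cpu:N` / `sycl:gpu:N` device strings used elsewhere in the patch:
```cpp
// Illustrative only. DeviceOrd and its SYCL constructors come from the
// include/xgboost/context.h changes above.
#include <iostream>
#include "xgboost/context.h"

int main() {
  auto gpu0 = xgboost::DeviceOrd::SYCL_GPU(0);
  auto cpu1 = xgboost::DeviceOrd::SYCL_CPU(1);
  std::cout << gpu0.Name() << "\n";  // prints "sycl:gpu:0"
  std::cout << cpu1.Name() << "\n";  // prints "sycl:cpu:1"
  return (gpu0.IsSycl() && !gpu0.IsCUDA()) ? 0 : 1;
}
```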
*/ [[nodiscard]] bool IsCUDA() const { return Device().IsCUDA(); } + /** + * @brief Is XGBoost running on the default SYCL device? + */ + [[nodiscard]] bool IsSyclDefault() const { return Device().IsSyclDefault(); } + /** + * @brief Is XGBoost running on a SYCL CPU? + */ + [[nodiscard]] bool IsSyclCPU() const { return Device().IsSyclCPU(); } + /** + * @brief Is XGBoost running on a SYCL GPU? + */ + [[nodiscard]] bool IsSyclGPU() const { return Device().IsSyclGPU(); } + /** + * @brief Is XGBoost running on any SYCL device? + */ + [[nodiscard]] bool IsSycl() const { return IsSyclDefault() + || IsSyclCPU() + || IsSyclGPU(); } + /** * @brief Get the current device and ordinal. */ @@ -171,6 +227,29 @@ struct Context : public XGBoostParameter { /** * @brief Call function based on the current device. */ + template + decltype(auto) DispatchDevice(CPUFn&& cpu_fn, CUDAFn&& cuda_fn, SYCLFn&& sycl_fn) const { + static_assert(std::is_same_v, std::invoke_result_t>); + switch (this->Device().device) { + case DeviceOrd::kCPU: + return cpu_fn(); + case DeviceOrd::kCUDA: + return cuda_fn(); + case DeviceOrd::kSyclDefault: + return sycl_fn(); + case DeviceOrd::kSyclCPU: + return sycl_fn(); + case DeviceOrd::kSyclGPU: + return sycl_fn(); + default: + // Do not use the device name as this is likely an internal error, the name + // wouldn't be valid. + LOG(FATAL) << "Unknown device type:" + << static_cast>(this->Device().device); + break; + } + return std::invoke_result_t(); + } template decltype(auto) DispatchDevice(CPUFn&& cpu_fn, CUDAFn&& cuda_fn) const { static_assert(std::is_same_v, std::invoke_result_t>); @@ -179,6 +258,12 @@ struct Context : public XGBoostParameter { return cpu_fn(); case DeviceOrd::kCUDA: return cuda_fn(); + case DeviceOrd::kSyclDefault: + LOG(FATAL) << "The requested feature is not implemented for sycl yet"; + case DeviceOrd::kSyclCPU: + LOG(FATAL) << "The requested feature is not implemented for sycl yet"; + case DeviceOrd::kSyclGPU: + LOG(FATAL) << "The requested feature is not implemented for sycl yet"; default: // Do not use the device name as this is likely an internal error, the name // wouldn't be valid. @@ -213,7 +298,9 @@ struct Context : public XGBoostParameter { void SetDeviceOrdinal(Args const& kwargs); Context& SetDevice(DeviceOrd d) { this->device_ = d; - this->gpu_id = d.ordinal; // this can be removed once we move away from `gpu_id`. + if (d.IsCUDA()) { + this->gpu_id = d.ordinal; // this can be removed once we move away from `gpu_id`. + } this->device = d.Name(); return *this; } diff --git a/include/xgboost/linalg.h b/include/xgboost/linalg.h index 6d2b54f84c17..20f6f9d6c3b8 100644 --- a/include/xgboost/linalg.h +++ b/include/xgboost/linalg.h @@ -596,13 +596,13 @@ auto MakeTensorView(Context const *ctx, common::Span data, S &&...shape) { template auto MakeTensorView(Context const *ctx, HostDeviceVector *data, S &&...shape) { - auto span = ctx->IsCPU() ? data->HostSpan() : data->DeviceSpan(); + auto span = ctx->IsCUDA() ? data->DeviceSpan() : data->HostSpan(); return MakeTensorView(ctx->gpu_id, span, std::forward(shape)...); } template auto MakeTensorView(Context const *ctx, HostDeviceVector const *data, S &&...shape) { - auto span = ctx->IsCPU() ? data->ConstHostSpan() : data->ConstDeviceSpan(); + auto span = ctx->IsCUDA() ? 
data->ConstDeviceSpan() : data->ConstHostSpan(); return MakeTensorView(ctx->gpu_id, span, std::forward(shape)...); } diff --git a/plugin/CMakeLists.txt b/plugin/CMakeLists.txt index 7026238e30cf..def1d8b4f0cd 100644 --- a/plugin/CMakeLists.txt +++ b/plugin/CMakeLists.txt @@ -4,7 +4,11 @@ endif (PLUGIN_DENSE_PARSER) if (PLUGIN_UPDATER_ONEAPI) add_library(oneapi_plugin OBJECT + ${xgboost_SOURCE_DIR}/plugin/updater_oneapi/hist_util_oneapi.cc ${xgboost_SOURCE_DIR}/plugin/updater_oneapi/regression_obj_oneapi.cc + ${xgboost_SOURCE_DIR}/plugin/updater_oneapi/multiclass_obj_oneapi.cc + ${xgboost_SOURCE_DIR}/plugin/updater_oneapi/updater_quantile_hist_oneapi.cc + ${xgboost_SOURCE_DIR}/plugin/updater_oneapi/device_manager_oneapi.cc ${xgboost_SOURCE_DIR}/plugin/updater_oneapi/predictor_oneapi.cc) target_include_directories(oneapi_plugin PRIVATE diff --git a/plugin/updater_oneapi/README.md b/plugin/updater_oneapi/README.md index c2faf6574015..ddb05e497925 100755 --- a/plugin/updater_oneapi/README.md +++ b/plugin/updater_oneapi/README.md @@ -2,30 +2,20 @@ This plugin adds support of OneAPI programming model for tree construction and prediction algorithms to XGBoost. ## Usage -Specify the 'objective' parameter as one of the following options to offload computation of objective function on OneAPI device. +Specify the 'device' parameter as one of the following options to offload model training and inference on OneAPI device. ### Algorithms -| objective | Description | +| device | Description | | --- | --- | -reg:squarederror_oneapi | regression with squared loss | -reg:squaredlogerror_oneapi | regression with root mean squared logarithmic loss | -reg:logistic_oneapi | logistic regression for probability regression task | -binary:logistic_oneapi | logistic regression for binary classification task | -binary:logitraw_oneapi | logistic regression for classification, output score before logistic transformation | - -Specify the 'predictor' parameter as one of the following options to offload prediction stage on OneAPI device. - -### Algorithms -| predictor | Description | -| --- | --- | -predictor_oneapi | prediction using OneAPI device | - -Please note that parameter names are not finalized and can be changed during further integration of OneAPI support. +sycl | use default sycl device | +sycl:gpu | use default sycl gpu | +sycl:cpu | use default sycl cpu | +sycl:gpu:N | use sycl gpu number N | +sycl:cpu:N | use sycl cpu number N | Python example: ```python -param['predictor'] = 'predictor_oneapi' -param['objective'] = 'reg:squarederror_oneapi' +param['device'] = 'sycl:gpu:0' ``` ## Dependencies diff --git a/plugin/updater_oneapi/data_oneapi.h b/plugin/updater_oneapi/data_oneapi.h new file mode 100644 index 000000000000..a4d6ade88342 --- /dev/null +++ b/plugin/updater_oneapi/data_oneapi.h @@ -0,0 +1,273 @@ +/*! 
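Below is a hedged sketch (not from the patch) of how a caller might use the three-callable `DispatchDevice` overload added to `include/xgboost/context.h` above; the lambdas are trivial stand-ins for real CPU/CUDA/SYCL code paths. Note that the existing two-callable overload now raises a fatal error when the configured device is a SYCL one.
```cpp
// Illustrative only: all three callables must return the same type.
#include <string>
#include "xgboost/context.h"

std::string BackendName(xgboost::Context const& ctx) {
  return ctx.DispatchDevice(
      [] { return std::string{"cpu"}; },    // DeviceOrd::kCPU
      [] { return std::string{"cuda"}; },   // DeviceOrd::kCUDA
      [] { return std::string{"sycl"}; });  // any of the kSycl* types
}
```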
+ * Copyright by Contributors 2017-2023 + */ +#ifndef XGBOOST_COMMON_DATA_ONEAPI_H_ +#define XGBOOST_COMMON_DATA_ONEAPI_H_ + +#include +#include +#include + +#include "xgboost/base.h" +#include "xgboost/data.h" +#include "xgboost/logging.h" +#include "xgboost/host_device_vector.h" + +#include "../../src/common/threading_utils.h" + +#include "CL/sycl.hpp" + +namespace xgboost { + +enum class MemoryType { shared, on_device}; + + +template +class USMDeleter { +public: + explicit USMDeleter(sycl::queue qu) : qu_(qu) {} + + void operator()(T* data) const { + sycl::free(data, qu_); + } + +private: + sycl::queue qu_; +}; + +/* OneAPI implementation of a HostDeviceVector, storing both host and device memory in a single USM buffer. + Synchronization between host and device is managed by the compiler runtime. */ +template +class USMVector { + static_assert(std::is_standard_layout::value, "USMVector admits only POD types"); + + std::shared_ptr allocate_memory_(sycl::queue& qu, size_t size) { + if constexpr (memory_type == MemoryType::shared) { + return std::shared_ptr(sycl::malloc_shared(size_, qu), USMDeleter(qu)); + } else { + return std::shared_ptr(sycl::malloc_device(size_, qu), USMDeleter(qu)); + } + } + + void copy_vector_to_memory_(sycl::queue& qu, const std::vector &vec) { + if constexpr (memory_type == MemoryType::shared) { + std::copy(vec.begin (), vec.end (), data_.get()); + } else { + qu.memcpy(data_.get(), vec.data(), size_ * sizeof(T)); + } + } + + +public: + USMVector() : size_(0), capacity_(0), data_(nullptr) {} + + USMVector(sycl::queue& qu, size_t size) : size_(size), capacity_(size) { + data_ = allocate_memory_(qu, size_); + } + + USMVector(sycl::queue& qu, size_t size, T v) : size_(size), capacity_(size) { + data_ = allocate_memory_(qu, size_); + qu.fill(data_.get(), v, size_).wait(); + } + + USMVector(sycl::queue& qu, const std::vector &vec) { + size_ = vec.size(); + capacity_ = size_; + data_ = allocate_memory_(qu, size_); + copy_vector_to_memory_(qu, vec); + } + +// Bug. Copy constructor doesn't copy data. 
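A minimal usage sketch of the `USMVector` fill constructor defined above (the accessors it uses, `Size()` and `operator[]`, appear just below); illustrative only, assuming the `USMVector<T, MemoryType>` template signature and the include path used inside the plugin sources.
```cpp
// Illustrative only: shared USM memory is accessible from both host and
// device, so the host-side loop below is legal.
#include <cstddef>
#include <CL/sycl.hpp>
#include "data_oneapi.h"

int main() {
  sycl::queue qu{sycl::default_selector_v};
  // Eight floats, all initialised to 1.5f by the fill constructor.
  xgboost::USMVector<float, xgboost::MemoryType::shared> vec(qu, 8, 1.5f);

  float sum = 0.f;
  for (std::size_t i = 0; i < vec.Size(); ++i) sum += vec[i];
  return (sum == 12.f) ? 0 : 1;
}
```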
+// USMVector(const USMVector& other) : qu_(other.qu_), size_(other.size_), data_(other.data_) { +// } + + ~USMVector() { + } + + USMVector& operator=(const USMVector& other) { + size_ = other.size_; + capacity_ = other.capacity_; + data_ = other.data_; + return *this; + } + + T* Data() { return data_.get(); } + const T* DataConst() const { return data_.get(); } + + size_t Size() const { return size_; } + + size_t Capacity() const { return capacity_; } + + T& operator[] (size_t i) { return data_.get()[i]; } + const T& operator[] (size_t i) const { return data_.get()[i]; } + + T* Begin () const { return data_.get(); } + T* End () const { return data_.get() + size_; } + + bool Empty() const { return (size_ == 0); } + + void Clear() { + data_.reset(); + size_ = 0; + capacity_ = 0; + } + + void Resize(sycl::queue& qu, size_t size_new) { + if (size_new <= capacity_) { + size_ = size_new; + } else { + size_t size_old = size_; + auto data_old = data_; + size_ = size_new; + capacity_ = size_new; + data_ = allocate_memory_(qu, size_);; + if (size_old > 0) { + qu.memcpy(data_.get(), data_old.get(), sizeof(T) * size_old).wait(); + } + } + } + + // T Get(sycl::queue qu, size_t idx, std::vector* events_ptr) const { + // T val; + // auto event = qu.memcpy(&val, data_.get() + idx, sizeof(T), *events_ptr); + // events_ptr->emplace_back(event); + // return val; + // } + + // T Get(sycl::queue& qu, size_t idx) const { + // T val; + // last_event_ = qu.memcpy(&val, data_.get() + idx, sizeof(T)); + // return val; + // } + + // sycl::event GetLastEvent() const { + // return last_event_; + // } + + void Resize(sycl::queue& qu, size_t size_new, T v) { + if (size_new <= size_) { + size_ = size_new; + } else if (size_new <= capacity_) { + qu.fill(data_.get() + size_, v, size_new - size_).wait(); + size_ = size_new; + } else { + size_t size_old = size_; + auto data_old = data_; + size_ = size_new; + capacity_ = size_new; + data_ = allocate_memory_(qu, size_); + if (size_old > 0) { + qu.memcpy(data_.get(), data_old.get(), sizeof(T) * size_old).wait(); + } + qu.fill(data_.get() + size_old, v, size_new - size_old).wait(); + } + } + + sycl::event ResizeAsync(sycl::queue& qu, size_t size_new, T v) { + if (size_new <= size_) { + size_ = size_new; + return sycl::event(); + } else if (size_new <= capacity_) { + auto event = qu.fill(data_.get() + size_, v, size_new - size_); + size_ = size_new; + return event; + } else { + size_t size_old = size_; + auto data_old = data_; + size_ = size_new; + capacity_ = size_new; + data_ = allocate_memory_(qu, size_); + sycl::event event; + if (size_old > 0) { + event = qu.memcpy(data_.get(), data_old.get(), sizeof(T) * size_old); + } + return qu.fill(data_.get() + size_old, v, size_new - size_old, event); + } + } + + sycl::event ResizeAndFill(sycl::queue& qu, size_t size_new, int v) { + if (size_new <= size_) { + size_ = size_new; + return qu.memset(data_.get(), v, size_new * sizeof(T)); + } else if (size_new <= capacity_) { + size_ = size_new; + return qu.memset(data_.get(), v, size_new * sizeof(T)); + } else { + size_t size_old = size_; + auto data_old = data_; + size_ = size_new; + capacity_ = size_new; + data_ = allocate_memory_(qu, size_); + return qu.memset(data_.get(), v, size_new * sizeof(T)); + } + } + + sycl::event Fill(sycl::queue& qu, T v) { + return qu.fill(data_.get(), v, size_); + } + + void Init(sycl::queue& qu, const std::vector &vec) { + size_ = vec.size(); + capacity_ = size_; + data_ = allocate_memory_(qu, size_); + copy_vector_to_memory_(qu, vec); + } + + using 
value_type = T; // NOLINT + +private: + size_t size_; + size_t capacity_; + std::shared_ptr data_; + // mutable sycl::event last_event_; +}; + +/* Wrapper for DMatrix which stores all batches in a single USM buffer */ +struct DeviceMatrixOneAPI { + DMatrix* p_mat; // Pointer to the original matrix on the host + sycl::queue qu_; + USMVector row_ptr; + USMVector data; + size_t total_offset; + + DeviceMatrixOneAPI(sycl::queue qu, DMatrix* dmat) : p_mat(dmat), qu_(qu) { + size_t num_row = 0; + size_t num_nonzero = 0; + for (auto &batch : dmat->GetBatches()) { + const auto& data_vec = batch.data.HostVector(); + const auto& offset_vec = batch.offset.HostVector(); + num_nonzero += data_vec.size(); + num_row += batch.Size(); + } + + row_ptr.Resize(qu_, num_row + 1); + data.Resize(qu_, num_nonzero); + + size_t data_offset = 0; + for (auto &batch : dmat->GetBatches()) { + const auto& data_vec = batch.data.HostVector(); + const auto& offset_vec = batch.offset.HostVector(); + size_t batch_size = batch.Size(); + if (batch_size > 0) { + std::copy(offset_vec.data(), offset_vec.data() + batch_size, + row_ptr.Data() + batch.base_rowid); + if (batch.base_rowid > 0) { + for(size_t i = 0; i < batch_size; i++) + row_ptr[i + batch.base_rowid] += batch.base_rowid; + } + std::copy(data_vec.data(), data_vec.data() + offset_vec[batch_size], + data.Data() + data_offset); + data_offset += offset_vec[batch_size]; + } + } + row_ptr[num_row] = data_offset; + total_offset = data_offset; + } + + ~DeviceMatrixOneAPI() { + } +}; + +} // namespace xgboost + +#endif diff --git a/plugin/updater_oneapi/device_manager_oneapi.cc b/plugin/updater_oneapi/device_manager_oneapi.cc new file mode 100644 index 000000000000..703db82d199b --- /dev/null +++ b/plugin/updater_oneapi/device_manager_oneapi.cc @@ -0,0 +1,105 @@ +/*! + * Copyright 2017-2022 by Contributors + * \file device_manager_oneapi.cc + */ +#include + +#include "./device_manager_oneapi.h" + +namespace xgboost { + +sycl::device DeviceManagerOneAPI::GetDevice(const DeviceOrd& device_spec) const { + bool not_use_default_selector = (device_spec.ordinal != kDefaultOrdinal) || + (rabit::IsDistributed()); + if (not_use_default_selector) { + DeviceRegister& device_register = GetDevicesRegister(); + const int device_idx = rabit::IsDistributed() ? 
rabit::GetRank() : device_spec.ordinal; + if (device_spec.IsSyclDefault()) { + auto& devices = device_register.devices; + CHECK_LT(device_idx, devices.size()); + return devices[device_idx]; + } else if (device_spec.IsSyclCPU()) { + auto& cpu_devices = device_register.cpu_devices; + CHECK_LT(device_idx, cpu_devices.size()); + return cpu_devices[device_idx]; + } else { + auto& gpu_devices = device_register.gpu_devices; + CHECK_LT(device_idx, gpu_devices.size()); + return gpu_devices[device_idx]; + } + } else { + if (device_spec.IsSyclDefault()) { + return sycl::device(sycl::default_selector_v); + } else if(device_spec.IsSyclCPU()) { + return sycl::device(sycl::cpu_selector_v); + } else { + return sycl::device(sycl::gpu_selector_v); + } + } +} + +sycl::queue DeviceManagerOneAPI::GetQueue(const DeviceOrd& device_spec) const { + QueueRegister_t& queue_register = GetQueueRegister(); + if (queue_register.count(device_spec.Name()) > 0) { + return queue_register.at(device_spec.Name()); + } + + bool not_use_default_selector = (device_spec.ordinal != kDefaultOrdinal) || + (rabit::IsDistributed()); + std::lock_guard guard(queue_registering_mutex); + if (not_use_default_selector) { + DeviceRegister& device_register = GetDevicesRegister(); + const int device_idx = rabit::IsDistributed() ? rabit::GetRank() : device_spec.ordinal; + if (device_spec.IsSyclDefault()) { + auto& devices = device_register.devices; + CHECK_LT(device_idx, devices.size()); + queue_register[device_spec.Name()] = sycl::queue(devices[device_idx]); + } else if (device_spec.IsSyclCPU()) { + auto& cpu_devices = device_register.cpu_devices; + CHECK_LT(device_idx, cpu_devices.size()); + queue_register[device_spec.Name()] = sycl::queue(cpu_devices[device_idx]);; + } else if (device_spec.IsSyclGPU()) { + auto& gpu_devices = device_register.gpu_devices; + CHECK_LT(device_idx, gpu_devices.size()); + queue_register[device_spec.Name()] = sycl::queue(gpu_devices[device_idx]); + } + } else { + if (device_spec.IsSyclDefault()) { + queue_register[device_spec.Name()] = sycl::queue(sycl::default_selector_v); + } else if (device_spec.IsSyclCPU()) { + queue_register[device_spec.Name()] = sycl::queue(sycl::cpu_selector_v); + } else if (device_spec.IsSyclGPU()) { + queue_register[device_spec.Name()] = sycl::queue(sycl::gpu_selector_v); + } + } + return queue_register.at(device_spec.Name()); +} + +DeviceManagerOneAPI::DeviceRegister& DeviceManagerOneAPI::GetDevicesRegister() const { + static DeviceRegister device_register; + + if (device_register.devices.size() == 0) { + std::lock_guard guard(device_registering_mutex); + std::vector devices = sycl::device::get_devices(); + for (size_t i = 0; i < devices.size(); i++) { + LOG(INFO) << "device_index = " << i << ", name = " << devices[i].get_info(); + } + + for (size_t i = 0; i < devices.size(); i++) { + device_register.devices.push_back(devices[i]); + if (devices[i].is_cpu()) { + device_register.cpu_devices.push_back(devices[i]); + } else if (devices[i].is_gpu()) { + device_register.gpu_devices.push_back(devices[i]); + } + } + } + return device_register; +} + +DeviceManagerOneAPI::QueueRegister_t& DeviceManagerOneAPI::GetQueueRegister() const { + static QueueRegister_t queue_register; + return queue_register; +} + +} // namespace xgboost \ No newline at end of file diff --git a/plugin/updater_oneapi/device_manager_oneapi.h b/plugin/updater_oneapi/device_manager_oneapi.h new file mode 100644 index 000000000000..92f02939ca1d --- /dev/null +++ b/plugin/updater_oneapi/device_manager_oneapi.h @@ -0,0 +1,44 @@ 
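A short sketch (not from the patch) of the intended entry point: an objective or updater asks the device manager for a queue matching the configured `DeviceOrd`, as the objectives later in this patch do via `device_manager.GetQueue(ctx_->Device())`.
```cpp
// Illustrative only: with the default ordinal, 'sycl:cpu' resolves through
// sycl::cpu_selector_v rather than the enumerated device register.
#include <CL/sycl.hpp>
#include "device_manager_oneapi.h"

int main() {
  xgboost::DeviceManagerOneAPI manager;
  sycl::queue qu = manager.GetQueue(xgboost::DeviceOrd::SYCL_CPU());
  return qu.get_device().is_cpu() ? 0 : 1;
}
```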
+/*! + * Copyright 2017-2022 by Contributors + * \file device_manager_oneapi.h + */ +#ifndef XGBOOST_DEVICE_MANAGER_ONEAPI_H_ +#define XGBOOST_DEVICE_MANAGER_ONEAPI_H_ + +#include +#include +#include + +#include "CL/sycl.hpp" +#include "xgboost/context.h" + +namespace xgboost { + +class DeviceManagerOneAPI { + public: + // DeviceManagerOneAPI(); + + sycl::queue GetQueue(const DeviceOrd& device_spec) const; + + sycl::device GetDevice(const DeviceOrd& device_spec) const; + + private: + using QueueRegister_t = std::unordered_map; + + struct DeviceRegister { + std::vector devices; + std::vector cpu_devices; + std::vector gpu_devices; + }; + + QueueRegister_t& GetQueueRegister() const; + + DeviceRegister& GetDevicesRegister() const; + + mutable std::mutex queue_registering_mutex; + mutable std::mutex device_registering_mutex; +}; + +} // namespace xgboost + +#endif // XGBOOST_DEVICE_MANAGER_ONEAPI_H_ \ No newline at end of file diff --git a/plugin/updater_oneapi/hist_util_oneapi.cc b/plugin/updater_oneapi/hist_util_oneapi.cc new file mode 100644 index 000000000000..cb4e37513407 --- /dev/null +++ b/plugin/updater_oneapi/hist_util_oneapi.cc @@ -0,0 +1,498 @@ +/*! + * Copyright 2017-2023 by Contributors + * \file hist_util_oneapi.cc + */ +#include +#include + +#include "hist_util_oneapi.h" + +#include "CL/sycl.hpp" + +namespace xgboost { +namespace common { + +uint32_t SearchBin(const bst_float* cut_values, const uint32_t* cut_ptrs, Entry const& e) { + auto beg = cut_ptrs[e.index]; + auto end = cut_ptrs[e.index + 1]; + const auto &values = cut_values; + auto it = std::upper_bound(cut_values + beg, cut_values + end, e.fvalue); + uint32_t idx = it - cut_values; + if (idx == end) { + idx -= 1; + } + return idx; +} + +template +void mergeSort(BinIdxType* begin, BinIdxType* end, BinIdxType* buf) { + const size_t total_len = end - begin; + for (size_t block_len = 1; block_len < total_len; block_len <<= 1) { + for (size_t cur_block = 0; cur_block + block_len < total_len; cur_block += 2 * block_len) { + size_t start = cur_block; + size_t mid = start + block_len; + size_t finish = mid + block_len < total_len ? 
mid + block_len : total_len; + size_t left_pos = start; + size_t right_pos = mid; + size_t pos = start; + while (left_pos < mid || right_pos < finish) { + if (left_pos < mid && (right_pos == finish || begin[left_pos] < begin[right_pos])) { + buf[pos++] = begin[left_pos++]; + } else { + buf[pos++] = begin[right_pos++]; + } + } + for (size_t i = start; i < finish; i++) begin[i] = buf[i]; + } + } +} + +template +void GHistIndexMatrixOneAPI::SetIndexData(sycl::queue qu, + common::Span index_data_span, + const DeviceMatrixOneAPI &dmat_device, + size_t nbins, + size_t row_stride, + uint32_t* offsets) { + const xgboost::Entry *data_ptr = dmat_device.data.DataConst(); + const bst_row_t *offset_vec = dmat_device.row_ptr.DataConst(); + const size_t num_rows = dmat_device.row_ptr.Size() - 1; + BinIdxType* index_data = index_data_span.data(); + const bst_float* cut_values = cut_device.Values().DataConst(); + const uint32_t* cut_ptrs = cut_device.Ptrs().DataConst(); + sycl::buffer hit_count_buf(hit_count.data(), hit_count.size()); + + USMVector sort_buf(qu, num_rows * row_stride); + BinIdxType* sort_data = sort_buf.Data(); + + qu.submit([&](sycl::handler& cgh) { + auto hit_count_acc = hit_count_buf.template get_access(cgh); + cgh.parallel_for<>(sycl::range<1>(num_rows), [=](sycl::item<1> pid) { + const size_t i = pid.get_id(0); + const size_t ibegin = offset_vec[i]; + const size_t iend = offset_vec[i + 1]; + const size_t size = iend - ibegin; + const size_t start = i * row_stride; + for (bst_uint j = 0; j < size; ++j) { + uint32_t idx = SearchBin(cut_values, cut_ptrs, data_ptr[ibegin + j]); + index_data[start + j] = offsets ? idx - offsets[j] : idx; + sycl::atomic_fetch_add(hit_count_acc[idx], 1); + } + if (!offsets) { + // Sparse case only + mergeSort(index_data + start, index_data + start + size, sort_data + start); + for (bst_uint j = size; j < row_stride; ++j) { + index_data[start + j] = nbins; + } + } + }); + }).wait(); +} + +void GHistIndexMatrixOneAPI::ResizeIndex(const size_t n_offsets, + const size_t n_index, + const bool isDense) { + if ((max_num_bins - 1 <= static_cast(std::numeric_limits::max())) && isDense) { + index.SetBinTypeSize(kUint8BinsTypeSize); + index.Resize((sizeof(uint8_t)) * n_index); + } else if ((max_num_bins - 1 > static_cast(std::numeric_limits::max()) && + max_num_bins - 1 <= static_cast(std::numeric_limits::max())) && isDense) { + index.SetBinTypeSize(kUint16BinsTypeSize); + index.Resize((sizeof(uint16_t)) * n_index); + } else { + index.SetBinTypeSize(kUint32BinsTypeSize); + index.Resize((sizeof(uint32_t)) * n_index); + } +} + +void GHistIndexMatrixOneAPI::Init(sycl::queue qu, + Context const * ctx, + const DeviceMatrixOneAPI& p_fmat_device, + int max_bins) { + nfeatures = p_fmat_device.p_mat->Info().num_col_; + + cut = SketchOnDMatrix(ctx, p_fmat_device.p_mat, max_bins); + cut_device.Init(qu, cut); + + max_num_bins = max_bins; + const uint32_t nbins = cut.Ptrs().back(); + this->nbins = nbins; + hit_count.resize(nbins, 0); + + this->p_fmat = p_fmat_device.p_mat; + const bool isDense = p_fmat_device.p_mat->IsDense(); + this->isDense_ = isDense; + + row_ptr = std::vector(p_fmat_device.row_ptr.Begin(), p_fmat_device.row_ptr.End()); + row_ptr_device = p_fmat_device.row_ptr; + + index.setQueue(qu); + + row_stride = 0; + for (const auto& batch : p_fmat_device.p_mat->GetBatches()) { + const auto& row_offset = batch.offset.ConstHostVector(); + for (auto i = 1ull; i < row_offset.size(); i++) { + row_stride = std::max(row_stride, static_cast(row_offset[i] - row_offset[i - 1])); + } 
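The binning rule used by `SearchBin` and the `SetIndexData` kernel above can be restated in plain host code as follows (illustration only, not code from the patch):
```cpp
// Illustrative only: within a feature's slice of the cut values, take the
// upper_bound position of the feature value and clamp it into the slice.
#include <algorithm>
#include <cstdint>
#include <vector>

uint32_t SearchBinRef(const std::vector<float>& cut_values,
                      const std::vector<uint32_t>& cut_ptrs,
                      uint32_t feature, float fvalue) {
  const uint32_t beg = cut_ptrs[feature];
  const uint32_t end = cut_ptrs[feature + 1];
  auto it = std::upper_bound(cut_values.begin() + beg, cut_values.begin() + end, fvalue);
  auto idx = static_cast<uint32_t>(it - cut_values.begin());
  return idx == end ? idx - 1 : idx;  // values above the last cut stay in the last bin
}
// Example: cut_ptrs = {0, 3}, cut_values = {0.5f, 1.5f, 2.5f}
//   SearchBinRef(..., 0, 1.0f) == 1 and SearchBinRef(..., 0, 9.0f) == 2.
```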
+ } + + const size_t n_offsets = cut.Ptrs().size() - 1; + const size_t n_rows = p_fmat_device.row_ptr.Size() - 1; + const size_t n_index = n_rows * row_stride; + ResizeIndex(n_offsets, n_index, isDense); + + CHECK_GT(cut.Values().size(), 0U); + + uint32_t* offsets = nullptr; + if (isDense) { + index.ResizeOffset(n_offsets); + offsets = index.Offset(); + qu.memcpy(offsets, cut.Ptrs().data(), sizeof(uint32_t) * n_offsets).wait_and_throw(); + // for (size_t i = 0; i < n_offsets; ++i) { + // offsets[i] = cut.Ptrs()[i]; + // } + } + + if (isDense) { + BinTypeSize curent_bin_size = index.GetBinTypeSize(); + if (curent_bin_size == kUint8BinsTypeSize) { + common::Span index_data_span = {index.data(), + n_index}; + SetIndexData(qu, index_data_span, p_fmat_device, nbins, row_stride, offsets); + + } else if (curent_bin_size == kUint16BinsTypeSize) { + common::Span index_data_span = {index.data(), + n_index}; + SetIndexData(qu, index_data_span, p_fmat_device, nbins, row_stride, offsets); + } else { + CHECK_EQ(curent_bin_size, kUint32BinsTypeSize); + common::Span index_data_span = {index.data(), + n_index}; + SetIndexData(qu, index_data_span, p_fmat_device, nbins, row_stride, offsets); + } + /* For sparse DMatrix we have to store index of feature for each bin + in index field to chose right offset. So offset is nullptr and index is not reduced */ + } else { + common::Span index_data_span = {index.data(), n_index}; + SetIndexData(qu, index_data_span, p_fmat_device, nbins, row_stride, offsets); + } +} + +/*! + * \brief Fill histogram with zeroes + */ +template +void InitHist(sycl::queue qu, GHistRowOneAPI& hist, size_t size) { + qu.fill(hist.Begin(), xgboost::detail::GradientPairInternal(), size); +} +template void InitHist(sycl::queue qu, GHistRowOneAPI& hist, size_t size); +template void InitHist(sycl::queue qu, GHistRowOneAPI& hist, size_t size); + +/*! + * \brief Copy histogram from src to dst + */ +template +void CopyHist(sycl::queue qu, + GHistRowOneAPI& dst, + const GHistRowOneAPI& src, + size_t size) { + GradientSumT* pdst = reinterpret_cast(dst.Data()); + const GradientSumT* psrc = reinterpret_cast(src.DataConst()); + + qu.submit([&](sycl::handler& cgh) { + cgh.parallel_for<>(sycl::range<1>(2 * size), [=](sycl::item<1> pid) { + const size_t i = pid.get_id(0); + pdst[i] = psrc[i]; + }); + }).wait(); +} +template void CopyHist(sycl::queue qu, + GHistRowOneAPI& dst, + const GHistRowOneAPI& src, + size_t size); +template void CopyHist(sycl::queue qu, + GHistRowOneAPI& dst, + const GHistRowOneAPI& src, + size_t size); + +/*! 
+ * \brief Compute Subtraction: dst = src1 - src2 + */ +template +sycl::event SubtractionHist(sycl::queue qu, + GHistRowOneAPI& dst, + const GHistRowOneAPI& src1, + const GHistRowOneAPI& src2, + size_t size, sycl::event event_priv) { + GradientSumT* pdst = reinterpret_cast(dst.Data()); + const GradientSumT* psrc1 = reinterpret_cast(src1.DataConst()); + const GradientSumT* psrc2 = reinterpret_cast(src2.DataConst()); + + auto event_final = qu.submit([&](sycl::handler& cgh) { + cgh.depends_on(event_priv); + cgh.parallel_for<>(sycl::range<1>(2 * size), [pdst, psrc1, psrc2](sycl::item<1> pid) { + const size_t i = pid.get_id(0); + pdst[i] = psrc1[i] - psrc2[i]; + }); + }); + return event_final; +} +template sycl::event SubtractionHist(sycl::queue qu, + GHistRowOneAPI& dst, + const GHistRowOneAPI& src1, + const GHistRowOneAPI& src2, + size_t size, sycl::event event_priv); +template sycl::event SubtractionHist(sycl::queue qu, + GHistRowOneAPI& dst, + const GHistRowOneAPI& src1, + const GHistRowOneAPI& src2, + size_t size, sycl::event event_priv); + +// Kernel with buffer using +template +sycl::event BuildHistKernel(sycl::queue qu, + const USMVector& gpair_device, + const RowSetCollectionOneAPI::Elem& row_indices, + const GHistIndexMatrixOneAPI& gmat, + GHistRowOneAPI& hist, + GHistRowOneAPI& hist_buffer, + sycl::event event_priv) { + const size_t size = row_indices.Size(); + const size_t* rid = row_indices.begin; + const size_t n_columns = isDense ? gmat.nfeatures : gmat.row_stride; + const float* pgh = reinterpret_cast(gpair_device.DataConst()); + const BinIdxType* gradient_index = gmat.index.data(); + const uint32_t* offsets = gmat.index.Offset(); + FPType* hist_data = reinterpret_cast(hist.Data()); + const size_t nbins = gmat.nbins; + + const size_t max_feat_local = qu.get_device().get_info(); + const size_t feat_local = n_columns < max_feat_local ? 
n_columns : max_feat_local; + + const size_t max_nblocks = hist_buffer.Size() / (nbins * 2); + const size_t min_block_size = 128; + size_t nblocks = std::min(max_nblocks, size / min_block_size + !!(size % min_block_size)); + const size_t block_size = size / nblocks + !!(size % nblocks); + FPType* hist_buffer_data = reinterpret_cast(hist_buffer.Data()); + + auto event_fill = qu.fill(hist_buffer_data, FPType(0), nblocks * nbins * 2, event_priv); + auto event_main = qu.submit([&](sycl::handler& cgh) { + cgh.depends_on(event_fill); + cgh.parallel_for<>(sycl::nd_range<2>(sycl::range<2>(nblocks, feat_local), + sycl::range<2>(1, feat_local)), [=](sycl::nd_item<2> pid) { + size_t block = pid.get_global_id(0); + size_t feat = pid.get_global_id(1); + + FPType* hist_local = hist_buffer_data + block * nbins * 2; + for (size_t idx = 0; idx < block_size; ++idx) { + size_t i = block * block_size + idx; + if (i < size) { + const size_t icol_start = n_columns * rid[i]; + const size_t idx_gh = rid[i]; + + pid.barrier(sycl::access::fence_space::local_space); + const BinIdxType* gr_index_local = gradient_index + icol_start; + + for (size_t j = feat; j < n_columns; j += feat_local) { + uint32_t idx_bin = static_cast(gr_index_local[j]); + if constexpr (isDense) { + idx_bin += offsets[j]; + } + if (idx_bin < nbins) { + hist_local[2 * idx_bin] += pgh[2 * idx_gh]; + hist_local[2 * idx_bin+1] += pgh[2 * idx_gh+1]; + } + } + } + } + }); + }); + + auto event_save = qu.submit([&](sycl::handler& cgh) { + cgh.depends_on(event_main); + cgh.parallel_for<>(sycl::range<1>(nbins), [=](sycl::item<1> pid) { + size_t idx_bin = pid.get_id(0); + + FPType gsum = 0.0f; + FPType hsum = 0.0f; + + for (size_t j = 0; j < nblocks; ++j) { + gsum += hist_buffer_data[j * nbins * 2 + 2 * idx_bin]; + hsum += hist_buffer_data[j * nbins * 2 + 2 * idx_bin + 1]; + } + + hist_data[2 * idx_bin] = gsum; + hist_data[2 * idx_bin + 1] = hsum; + }); + }); + return event_save; +} + +// Kernel with atomic using +template +sycl::event BuildHistKernel(sycl::queue qu, + const USMVector& gpair_device, + const RowSetCollectionOneAPI::Elem& row_indices, + const GHistIndexMatrixOneAPI& gmat, + GHistRowOneAPI& hist, + sycl::event event_priv) { + const size_t size = row_indices.Size(); + const size_t* rid = row_indices.begin; + const size_t n_columns = isDense ? gmat.nfeatures : gmat.row_stride; + const float* pgh = reinterpret_cast(gpair_device.DataConst()); + const BinIdxType* gradient_index = gmat.index.data(); + const uint32_t* offsets = gmat.index.Offset(); + FPType* hist_data = reinterpret_cast(hist.Data()); + const size_t nbins = gmat.nbins; + + const size_t max_feat_local = qu.get_device().get_info(); + const size_t feat_local = n_columns < max_feat_local ? 
n_columns : max_feat_local; + + auto event_fill = qu.fill(hist_data, FPType(0), nbins * 2, event_priv); + auto event_main = qu.submit([&](sycl::handler& cgh) { + cgh.depends_on(event_fill); + cgh.parallel_for<>(sycl::range<2>(size, feat_local), + [=](sycl::item<2> pid) { + size_t i = pid.get_id(0); + size_t feat = pid.get_id(1); + + const size_t icol_start = n_columns * rid[i]; + const size_t idx_gh = rid[i]; + + const BinIdxType* gr_index_local = gradient_index + icol_start; + + for (size_t j = feat; j < n_columns; j += feat_local) { + uint32_t idx_bin = static_cast(gr_index_local[j]); + if constexpr (isDense) { + idx_bin += offsets[j]; + } + if (idx_bin < nbins) { + AtomicRef gsum(hist_data[2 * idx_bin]); + AtomicRef hsum(hist_data[2 * idx_bin + 1]); + gsum.fetch_add(pgh[2 * idx_gh]); + hsum.fetch_add(pgh[2 * idx_gh + 1]); + } + } + }); + }); + return event_main; +} + +template +sycl::event BuildHistDispatchKernel(sycl::queue qu, + const USMVector& gpair_device, + const RowSetCollectionOneAPI::Elem& row_indices, + const GHistIndexMatrixOneAPI& gmat, + GHistRowOneAPI& hist, + bool isDense, + GHistRowOneAPI& hist_buffer, + sycl::event events_priv) { + const size_t size = row_indices.Size(); + const size_t n_columns = isDense ? gmat.nfeatures : gmat.row_stride; + const size_t nbins = gmat.nbins; + + const size_t max_feat_local = qu.get_device().get_info(); + const size_t feat_local = n_columns < max_feat_local ? n_columns : max_feat_local; + + // max cycle size, while atomics are still effective + const size_t max_cycle_size_atomics = nbins; + const size_t cycle_size = size; + if (cycle_size > max_cycle_size_atomics) { + if (isDense) { + return BuildHistKernel(qu, gpair_device, row_indices, + gmat, hist, hist_buffer, + events_priv); + } else { + return BuildHistKernel(qu, gpair_device, row_indices, + gmat, hist, hist_buffer, + events_priv); + } + } else { + if (isDense) { + return BuildHistKernel(qu, gpair_device, row_indices, + gmat, hist, events_priv); + } else { + return BuildHistKernel(qu, gpair_device, row_indices, + gmat, hist, events_priv); + } + } +} + +template +sycl::event BuildHistKernel(sycl::queue qu, + const USMVector& gpair_device, + const RowSetCollectionOneAPI::Elem& row_indices, + const GHistIndexMatrixOneAPI& gmat, const bool isDense, + GHistRowOneAPI& hist, + GHistRowOneAPI& hist_buffer, + sycl::event event_priv) { + const bool is_dense = isDense; + switch (gmat.index.GetBinTypeSize()) { + case kUint8BinsTypeSize: + return BuildHistDispatchKernel(qu, gpair_device, row_indices, + gmat, hist, is_dense, hist_buffer, + event_priv); + break; + case kUint16BinsTypeSize: + return BuildHistDispatchKernel(qu, gpair_device, row_indices, + gmat, hist, is_dense, hist_buffer, + event_priv); + break; + case kUint32BinsTypeSize: + return BuildHistDispatchKernel(qu, gpair_device, row_indices, + gmat, hist, is_dense, hist_buffer, + event_priv); + break; + default: + CHECK(false); // no default behavior + } +} + +template +sycl::event GHistBuilderOneAPI::BuildHist(const USMVector& gpair_device, + const RowSetCollectionOneAPI::Elem& row_indices, + const GHistIndexMatrixOneAPI &gmat, + GHistRowT& hist, + bool isDense, + GHistRowT& hist_buffer, + sycl::event event_priv) { + return BuildHistKernel(qu_, gpair_device, row_indices, gmat, isDense, hist, hist_buffer, event_priv); +} + +template +sycl::event GHistBuilderOneAPI::BuildHist(const USMVector& gpair_device, + const RowSetCollectionOneAPI::Elem& row_indices, + const GHistIndexMatrixOneAPI& gmat, + GHistRowOneAPI& hist, + bool isDense, + 
GHistRowOneAPI& hist_buffer, + sycl::event event_priv); +template +sycl::event GHistBuilderOneAPI::BuildHist(const USMVector& gpair_device, + const RowSetCollectionOneAPI::Elem& row_indices, + const GHistIndexMatrixOneAPI& gmat, + GHistRowOneAPI& hist, + bool isDense, + GHistRowOneAPI& hist_buffer, + sycl::event event_priv); + +template +void GHistBuilderOneAPI::SubtractionTrick(GHistRowT& self, + GHistRowT& sibling, + GHistRowT& parent) { + const size_t size = self.Size(); + CHECK_EQ(sibling.Size(), size); + CHECK_EQ(parent.Size(), size); + + SubtractionHist(qu_, self, parent, sibling, size, sycl::event()); +} +template +void GHistBuilderOneAPI::SubtractionTrick(GHistRowOneAPI& self, + GHistRowOneAPI& sibling, + GHistRowOneAPI& parent); +template +void GHistBuilderOneAPI::SubtractionTrick(GHistRowOneAPI& self, + GHistRowOneAPI& sibling, + GHistRowOneAPI& parent); +} // namespace common +} // namespace xgboost diff --git a/plugin/updater_oneapi/hist_util_oneapi.h b/plugin/updater_oneapi/hist_util_oneapi.h new file mode 100644 index 000000000000..95e0fb6801c9 --- /dev/null +++ b/plugin/updater_oneapi/hist_util_oneapi.h @@ -0,0 +1,377 @@ +/*! + * Copyright 2017-2023 by Contributors + * \file hist_util_oneapi.h + */ +#ifndef XGBOOST_COMMON_HIST_UTIL_ONEAPI_H_ +#define XGBOOST_COMMON_HIST_UTIL_ONEAPI_H_ + +#include + +#include "data_oneapi.h" +#include "row_set_oneapi.h" + +#include "../../src/common/hist_util.h" + +#include "CL/sycl.hpp" + +namespace xgboost { +namespace common { + +template +using GHistRowOneAPI = USMVector, memory_type>; + +template +using AtomicRef = sycl::atomic_ref; + +/*! + * \brief OneAPI implementation of HistogramCuts stored in USM buffers to provide access from device kernels + */ +class HistogramCutsOneAPI { +protected: + using BinIdx = uint32_t; + +public: + HistogramCutsOneAPI() {} + + HistogramCutsOneAPI(sycl::queue qu) { + cut_ptrs_.Resize(qu_, 1, 0); + } + + ~HistogramCutsOneAPI() { + } + + void Init(sycl::queue qu, HistogramCuts const& cuts) { + qu_ = qu; + cut_values_.Init(qu_, cuts.cut_values_.HostVector()); + cut_ptrs_.Init(qu_, cuts.cut_ptrs_.HostVector()); + min_vals_.Init(qu_, cuts.min_vals_.HostVector()); + } + + // Getters for USM buffers to pass pointers into device kernels + const USMVector& Ptrs() const { return cut_ptrs_; } + const USMVector& Values() const { return cut_values_; } + const USMVector& MinValues() const { return min_vals_; } + +private: + USMVector cut_values_; + USMVector cut_ptrs_; + USMVector min_vals_; + sycl::queue qu_; +}; + +/*! 
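For context, the semantics of `SubtractionHist` / `SubtractionTrick` above, restated as plain host code: the histogram of one child node is obtained as `parent - other_child` instead of being rebuilt from its rows. In the real kernels the values are interleaved grad/hess floats processed on the SYCL queue.
```cpp
// Illustrative only: per-bin subtraction of gradient/hessian sums.
#include <cstddef>
#include <vector>

struct GradHess { float grad; float hess; };

std::vector<GradHess> SubtractHistRef(const std::vector<GradHess>& parent,
                                      const std::vector<GradHess>& sibling) {
  std::vector<GradHess> self(parent.size());
  for (std::size_t i = 0; i < parent.size(); ++i) {
    self[i].grad = parent[i].grad - sibling[i].grad;
    self[i].hess = parent[i].hess - sibling[i].hess;
  }
  return self;
}
```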
+ * \brief Index data and offsets stored in USM buffers to provide access from device kernels + */ +struct IndexOneAPI { + IndexOneAPI() { + SetBinTypeSize(binTypeSize_); + } + IndexOneAPI(const IndexOneAPI& i) = delete; + IndexOneAPI& operator=(IndexOneAPI i) = delete; + IndexOneAPI(IndexOneAPI&& i) = delete; + IndexOneAPI& operator=(IndexOneAPI&& i) = delete; + uint32_t operator[](size_t i) const { + if (!offset_.Empty()) { + return func_(data_.DataConst(), i) + offset_[i%p_]; + } else { + return func_(data_.DataConst(), i); + } + } + void SetBinTypeSize(BinTypeSize binTypeSize) { + binTypeSize_ = binTypeSize; + switch (binTypeSize) { + case kUint8BinsTypeSize: + func_ = &GetValueFromUint8; + break; + case kUint16BinsTypeSize: + func_ = &GetValueFromUint16; + break; + case kUint32BinsTypeSize: + func_ = &GetValueFromUint32; + break; + default: + CHECK(binTypeSize == kUint8BinsTypeSize || + binTypeSize == kUint16BinsTypeSize || + binTypeSize == kUint32BinsTypeSize); + } + } + BinTypeSize GetBinTypeSize() const { + return binTypeSize_; + } + + template + T* data() { + return reinterpret_cast(data_.Data()); + } + + template + const T* data() const { + return reinterpret_cast(data_.DataConst()); + } + + uint32_t* Offset() { + return offset_.Data(); + } + + const uint32_t* Offset() const { + return offset_.DataConst(); + } + + size_t Size() const { + return data_.Size() / (binTypeSize_); + } + + void Resize(const size_t nBytesData) { + data_.Resize(qu_, nBytesData); + } + + void ResizeOffset(const size_t nDisps) { + offset_.Resize(qu_, nDisps); + p_ = nDisps; + } + + uint8_t* begin() const { + return data_.Begin(); + } + + uint8_t* end() const { + return data_.End(); + } + + void setQueue(sycl::queue qu) { + qu_ = qu; + } + + private: + static uint32_t GetValueFromUint8(const uint8_t* t, size_t i) { + return reinterpret_cast(t)[i]; + } + static uint32_t GetValueFromUint16(const uint8_t* t, size_t i) { + return reinterpret_cast(t)[i]; + } + static uint32_t GetValueFromUint32(const uint8_t* t, size_t i) { + return reinterpret_cast(t)[i]; + } + + using Func = uint32_t (*)(const uint8_t*, size_t); + + USMVector data_; + USMVector offset_; // size of this field is equal to number of features + BinTypeSize binTypeSize_ {kUint8BinsTypeSize}; + size_t p_ {1}; + Func func_; + + sycl::queue qu_; +}; + + +/*! + * \brief Preprocessed global index matrix, in CSR format, stored in USM buffers + * + * Transform floating values to integer index in histogram + */ +struct GHistIndexMatrixOneAPI { + /*! \brief row pointer to rows by element position */ + std::vector row_ptr; + USMVector row_ptr_device; + /*! \brief The index data */ + IndexOneAPI index; + /*! \brief hit count of each index */ + std::vector hit_count; + /*! 
\brief The corresponding cuts */ + HistogramCuts cut; + HistogramCutsOneAPI cut_device; + DMatrix* p_fmat; + size_t max_num_bins; + size_t nbins; + size_t nfeatures; + size_t row_stride; + + // Create a global histogram matrix based on a given DMatrix device wrapper + void Init(sycl::queue qu, Context const * ctx, const DeviceMatrixOneAPI& p_fmat_device, int max_num_bins); + + template + void SetIndexData(sycl::queue qu, common::Span index_data_span, + const DeviceMatrixOneAPI &dmat_device, + size_t nbins, size_t row_stride, uint32_t* offsets); + + void ResizeIndex(const size_t n_offsets, const size_t n_index, + const bool isDense); + + inline void GetFeatureCounts(std::vector& counts) const { + auto nfeature = cut_device.Ptrs().Size() - 1; + for (unsigned fid = 0; fid < nfeature; ++fid) { + auto ibegin = cut_device.Ptrs()[fid]; + auto iend = cut_device.Ptrs()[fid + 1]; + for (auto i = ibegin; i < iend; ++i) { + counts[fid] += hit_count[i]; + } + } + } + inline bool IsDense() const { + return isDense_; + } + + private: + bool isDense_; +}; + +class ColumnMatrixOneAPI; + +/*! + * \brief Fill histogram with zeroes + */ +template +void InitHist(sycl::queue qu, + GHistRowOneAPI& hist, + size_t size); + +/*! + * \brief Copy histogram from src to dst + */ +template +void CopyHist(sycl::queue qu, + GHistRowOneAPI& dst, + const GHistRowOneAPI& src, + size_t size); + +/*! + * \brief Compute subtraction: dst = src1 - src2 + */ +template +sycl::event SubtractionHist(sycl::queue qu, + GHistRowOneAPI& dst, + const GHistRowOneAPI& src1, + const GHistRowOneAPI& src2, + size_t size, sycl::event event_priv); + +/*! + * \brief Histograms of gradient statistics for multiple nodes + */ +template +class HistCollectionOneAPI { + public: + using GHistRowT = GHistRowOneAPI; + + // Access histogram for i-th node + GHistRowT& operator[](bst_uint nid) { + return data_[nid]; + } + + const GHistRowT& operator[](bst_uint nid) const { + return data_[nid]; + } + + // Initialize histogram collection + void Init(sycl::queue qu, uint32_t nbins) { + qu_ = qu; + if (nbins_ != nbins) { + nbins_ = nbins; + data_.clear(); + } + } + + // Reserve the space for hist rows + void Reserve(bst_uint max_nid) { + data_.reserve(max_nid + 1); + } + + // Create an empty histogram for i-th node + sycl::event AddHistRow(bst_uint nid) { + if (nid >= data_.size()) { + data_.resize(nid + 1); + } + return data_[nid].ResizeAsync(qu_, nbins_, xgboost::detail::GradientPairInternal(0, 0)); + } + + void Wait_and_throw() { + qu_.wait_and_throw(); + } + + private: + /*! \brief Number of all bins over all features */ + uint32_t nbins_ = 0; + + std::vector data_; + + sycl::queue qu_; +}; + +/*! + * \brief Stores temporary histograms to compute them in parallel + */ +template +class ParallelGHistBuilderOneAPI { + public: + using GHistRowT = GHistRowOneAPI; + + void Init(sycl::queue qu, size_t nbins) { + qu_ = qu; + if (nbins != nbins_) { + hist_buffer_.Init(qu_, nbins); + nbins_ = nbins; + } + } + + void Reset(size_t nblocks) { + hist_device_buffer_.Resize(qu_, nblocks * nbins_ * 2); + } + + GHistRowT& GetDeviceBuffer() { + return hist_device_buffer_; + } + + protected: + /*! \brief Number of bins in each histogram */ + size_t nbins_ = 0; + /*! \brief Buffers for histograms for all nodes processed */ + HistCollectionOneAPI hist_buffer_; + + /*! \brief Buffer for additional histograms for Parallel processing */ + GHistRowT hist_device_buffer_; + + sycl::queue qu_; +}; + +/*! 
+ * \brief Builder for histograms of gradient statistics + */ +template +class GHistBuilderOneAPI { + public: + template + using GHistRowT = GHistRowOneAPI; + + GHistBuilderOneAPI() = default; + GHistBuilderOneAPI(sycl::queue qu, uint32_t nbins) : qu_{qu}, nbins_{nbins} {} + + // Construct a histogram via histogram aggregation + sycl::event BuildHist(const USMVector& gpair_device, + const RowSetCollectionOneAPI::Elem& row_indices, + const GHistIndexMatrixOneAPI& gmat, + GHistRowT& HistCollectionOneAPI, + bool isDense, + GHistRowT& hist_buffer, + sycl::event evens); + + // Construct a histogram via subtraction trick + void SubtractionTrick(GHistRowT& self, + GHistRowT& sibling, + GHistRowT& parent); + + uint32_t GetNumBins() const { + return nbins_; + } + + private: + /*! \brief Number of all bins over all features */ + uint32_t nbins_ { 0 }; + + sycl::queue qu_; +}; +} // namespace common +} // namespace xgboost +#endif // XGBOOST_COMMON_HIST_UTIL_ONEAPI_H_ diff --git a/plugin/updater_oneapi/multiclass_obj_oneapi.cc b/plugin/updater_oneapi/multiclass_obj_oneapi.cc new file mode 100644 index 000000000000..89f2a6b0a2a1 --- /dev/null +++ b/plugin/updater_oneapi/multiclass_obj_oneapi.cc @@ -0,0 +1,285 @@ +/*! + * Copyright 2015-2023 by Contributors + * \file multiclass_obj_oneapi.cc + * \brief Definition of multi-class classification objectives. + */ +#include +#include +#include +#include +#include + + +#include "xgboost/parameter.h" +#include "xgboost/data.h" +#include "xgboost/logging.h" +#include "xgboost/objective.h" +#include "xgboost/json.h" + +#include "device_manager_oneapi.h" +#include "CL/sycl.hpp" + + +namespace xgboost { +namespace obj { + + +DMLC_REGISTRY_FILE_TAG(multiclass_obj_oneapi); + + +/*! + * \brief Do inplace softmax transformaton on start to end + * + * \tparam Iterator Input iterator type + * + * \param start Start iterator of input + * \param end end iterator of input + */ +template +inline void SoftmaxOneAPI(Iterator start, Iterator end) { + bst_float wmax = *start; + for (Iterator i = start+1; i != end; ++i) { + wmax = sycl::max(*i, wmax); + } + float wsum = 0.0f; + for (Iterator i = start; i != end; ++i) { + *i = sycl::exp(*i - wmax); + wsum += *i; + } + for (Iterator i = start; i != end; ++i) { + *i /= static_cast(wsum); + } +} + + +/*! + * \brief Find the maximum iterator within the iterators + * \param begin The begining iterator. + * \param end The end iterator. + * \return the iterator point to the maximum value. + * \tparam Iterator The type of the iterator. 
+ */ +template +inline Iterator FindMaxIndexOneAPI(Iterator begin, Iterator end) { + Iterator maxit = begin; + for (Iterator it = begin; it != end; ++it) { + if (*it > *maxit) maxit = it; + } + return maxit; +} + + +struct SoftmaxMultiClassParamOneAPI : public XGBoostParameter { + int num_class; + // declare parameters + DMLC_DECLARE_PARAMETER(SoftmaxMultiClassParamOneAPI) { + DMLC_DECLARE_FIELD(num_class).set_lower_bound(1) + .describe("Number of output class in the multi-class classification."); + } +}; + + +class SoftmaxMultiClassObjOneAPI : public ObjFunction { + public: + explicit SoftmaxMultiClassObjOneAPI(bool output_prob) + : output_prob_(output_prob) {} + + + void Configure(Args const& args) override { + param_.UpdateAllowUnknown(args); + qu_ = device_manager.GetQueue(ctx_->Device()); + } + + + void GetGradient(const HostDeviceVector& preds, + const MetaInfo& info, + int iter, + HostDeviceVector* out_gpair) override { + if (info.labels.Size() == 0) { + return; + } + CHECK(preds.Size() == (static_cast(param_.num_class) * info.labels.Size())) + << "SoftmaxMultiClassObjOneAPI: label size and pred size does not match.\n" + << "label.Size() * num_class: " + << info.labels.Size() * static_cast(param_.num_class) << "\n" + << "num_class: " << param_.num_class << "\n" + << "preds.Size(): " << preds.Size(); + + + const int nclass = param_.num_class; + const auto ndata = static_cast(preds.Size() / nclass); + + + out_gpair->Resize(preds.Size()); + + + const bool is_null_weight = info.weights_.Size() == 0; + if (!is_null_weight) { + CHECK_EQ(info.weights_.Size(), ndata) + << "Number of weights should be equal to number of data points."; + } + + + sycl::buffer preds_buf(preds.HostPointer(), preds.Size()); + sycl::buffer labels_buf(info.labels.Data()->HostPointer(), info.labels.Size()); + sycl::buffer out_gpair_buf(out_gpair->HostPointer(), out_gpair->Size()); + sycl::buffer weights_buf(is_null_weight ? NULL : info.weights_.HostPointer(), + is_null_weight ? 1 : info.weights_.Size()); + + + sycl::buffer additional_input_buf(1); + { + auto additional_input_acc = additional_input_buf.template get_access(); + additional_input_acc[0] = 1; // Fill the label_correct flag + } + + + qu_.submit([&](sycl::handler& cgh) { + auto preds_acc = preds_buf.template get_access(cgh); + auto labels_acc = labels_buf.template get_access(cgh); + auto weights_acc = weights_buf.template get_access(cgh); + auto out_gpair_acc = out_gpair_buf.template get_access(cgh); + auto additional_input_acc = additional_input_buf.template get_access(cgh); + cgh.parallel_for<>(sycl::range<1>(ndata), [=](sycl::id<1> pid) { + int idx = pid[0]; + + + bst_float const * point = &preds_acc[idx * nclass]; + + + // Part of Softmax function + bst_float wmax = std::numeric_limits::min(); + for (int k = 0; k < nclass; k++) { wmax = sycl::max(point[k], wmax); } + float wsum = 0.0f; + for (int k = 0; k < nclass; k++) { wsum += sycl::exp(point[k] - wmax); } + auto label = labels_acc[idx]; + if (label < 0 || label >= nclass) { + additional_input_acc[0] = 0; + label = 0; + } + bst_float wt = is_null_weight ? 1.0f : weights_acc[idx]; + for (int k = 0; k < nclass; ++k) { + bst_float p = expf(point[k] - wmax) / static_cast(wsum); + const float eps = 1e-16f; + const bst_float h = sycl::max(2.0f * p * (1.0f - p) * wt, eps); + p = label == k ? 
p - 1.0f : p; + out_gpair_acc[idx * nclass + k] = GradientPair(p * wt, h); + } + }); + }).wait(); + + + int flag = 1; + { + auto additional_input_acc = additional_input_buf.template get_access(); + flag = additional_input_acc[0]; + } + + + if (flag == 0) { + LOG(FATAL) << "SoftmaxMultiClassObjOneAPI: label must be in [0, num_class)."; + } + } + void PredTransform(HostDeviceVector* io_preds) const override { + this->Transform(io_preds, output_prob_); + } + void EvalTransform(HostDeviceVector* io_preds) override { + this->Transform(io_preds, true); + } + const char* DefaultEvalMetric() const override { + return "mlogloss"; + } + + + inline void Transform(HostDeviceVector *io_preds, bool prob) const { + const int nclass = param_.num_class; + const auto ndata = static_cast(io_preds->Size() / nclass); + max_preds_.Resize(ndata); + + + { + sycl::buffer io_preds_buf(io_preds->HostPointer(), io_preds->Size()); + + + if (prob) { + qu_.submit([&](sycl::handler& cgh) { + auto io_preds_acc = io_preds_buf.template get_access(cgh); + cgh.parallel_for<>(sycl::range<1>(ndata), [=](sycl::id<1> pid) { + int idx = pid[0]; + bst_float * point = &io_preds_acc[idx * nclass]; + SoftmaxOneAPI(point, point + nclass); + }); + }).wait(); + } else { + sycl::buffer max_preds_buf(max_preds_.HostPointer(), max_preds_.Size()); + + + qu_.submit([&](sycl::handler& cgh) { + auto io_preds_acc = io_preds_buf.template get_access(cgh); + auto max_preds_acc = max_preds_buf.template get_access(cgh); + cgh.parallel_for<>(sycl::range<1>(ndata), [=](sycl::id<1> pid) { + int idx = pid[0]; + bst_float const * point = &io_preds_acc[idx * nclass]; + max_preds_acc[idx] = FindMaxIndexOneAPI(point, point + nclass) - point; + }); + }).wait(); + } + } + + + if (!prob) { + io_preds->Resize(max_preds_.Size()); + io_preds->Copy(max_preds_); + } + } + + + struct ObjInfo Task() const override {return {ObjInfo::kClassification}; } + + + void SaveConfig(Json* p_out) const override { + auto& out = *p_out; + if (this->output_prob_) { + out["name"] = String("multi:softprob_oneapi"); + } else { + out["name"] = String("multi:softmax_oneapi"); + } + out["softmax_multiclass_param"] = ToJson(param_); + } + + + void LoadConfig(Json const& in) override { + FromJson(in["softmax_multiclass_param"], ¶m_); + } + + + private: + // output probability + bool output_prob_; + // parameter + SoftmaxMultiClassParamOneAPI param_; + // Cache for max_preds + mutable HostDeviceVector max_preds_; + + DeviceManagerOneAPI device_manager; + + mutable sycl::queue qu_; +}; + + +// register the objective functions +DMLC_REGISTER_PARAMETER(SoftmaxMultiClassParamOneAPI); + + +XGBOOST_REGISTER_OBJECTIVE(SoftmaxMultiClassOneAPI, "multi:softmax_oneapi") +.describe("Softmax for multi-class classification, output class index.") +.set_body([]() { return new SoftmaxMultiClassObjOneAPI(false); }); + + +XGBOOST_REGISTER_OBJECTIVE(SoftprobMultiClassOneAPI, "multi:softprob_oneapi") +.describe("Softmax for multi-class classification, output probability distribution.") +.set_body([]() { return new SoftmaxMultiClassObjOneAPI(true); }); + + +} // namespace obj +} // namespace xgboost diff --git a/plugin/updater_oneapi/param_oneapi.h b/plugin/updater_oneapi/param_oneapi.h new file mode 100644 index 000000000000..27ff1132f0cb --- /dev/null +++ b/plugin/updater_oneapi/param_oneapi.h @@ -0,0 +1,216 @@ +/*! 
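A plain host-side restatement (illustration only) of the per-class gradient and hessian computed in `SoftmaxMultiClassObjOneAPI::GetGradient` above: a numerically stable softmax, then g = (p - 1) * w for the labelled class, g = p * w otherwise, and h = max(2p(1-p)w, eps).
```cpp
// Illustrative only; mirrors the arithmetic of the SYCL kernel above on the host.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <utility>
#include <vector>

std::vector<std::pair<float, float>> SoftmaxGradRef(const std::vector<float>& point,
                                                    int label, float w) {
  const float wmax = *std::max_element(point.begin(), point.end());
  float wsum = 0.0f;
  for (float v : point) wsum += std::exp(v - wmax);

  std::vector<std::pair<float, float>> gpair(point.size());  // (grad, hess) per class
  const float eps = 1e-16f;
  for (std::size_t k = 0; k < point.size(); ++k) {
    const float p = std::exp(point[k] - wmax) / wsum;
    const float h = std::max(2.0f * p * (1.0f - p) * w, eps);
    const float g = (static_cast<int>(k) == label ? p - 1.0f : p) * w;
    gpair[k] = {g, h};
  }
  return gpair;
}
```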
+ * Copyright 2014-2023 by Contributors + */ +#ifndef XGBOOST_TREE_PARAM_ONEAPI_H_ +#define XGBOOST_TREE_PARAM_ONEAPI_H_ + + +#include +#include +#include +#include +#include + + +#include "xgboost/parameter.h" +#include "xgboost/data.h" +#include "../src/tree/param.h" + + +namespace xgboost { +namespace tree { + + +/*! \brief Wrapper for necessary training parameters for regression tree to access on device */ +struct TrainParamOneAPI { + float min_child_weight; + float reg_lambda; + float reg_alpha; + float max_delta_step; + + + TrainParamOneAPI() {} + + + TrainParamOneAPI(const TrainParam& param) { + reg_lambda = param.reg_lambda; + reg_alpha = param.reg_alpha; + min_child_weight = param.min_child_weight; + max_delta_step = param.max_delta_step; + } +}; + + +/*! \brief core statistics used for tree construction */ +template +struct GradStatsOneAPI { + /*! \brief sum gradient statistics */ + GradType sum_grad { 0 }; + /*! \brief sum hessian statistics */ + GradType sum_hess { 0 }; + + + public: + GradType GetGrad() const { return sum_grad; } + GradType GetHess() const { return sum_hess; } + + + friend std::ostream& operator<<(std::ostream& os, GradStatsOneAPI s) { + os << s.GetGrad() << "/" << s.GetHess(); + return os; + } + + + GradStatsOneAPI() { + } + + + template + explicit GradStatsOneAPI(const GpairT &sum) + : sum_grad(sum.GetGrad()), sum_hess(sum.GetHess()) {} + explicit GradStatsOneAPI(const GradType grad, const GradType hess) + : sum_grad(grad), sum_hess(hess) {} + /*! + * \brief accumulate statistics + * \param p the gradient pair + */ + inline void Add(GradientPair p) { this->Add(p.GetGrad(), p.GetHess()); } + + + /*! \brief add statistics to the data */ + inline void Add(const GradStatsOneAPI& b) { + sum_grad += b.sum_grad; + sum_hess += b.sum_hess; + } + /*! \brief same as add, reduce is used in All Reduce */ + inline static void Reduce(GradStatsOneAPI& a, const GradStatsOneAPI& b) { // NOLINT(*) + a.Add(b); + } + /*! \brief set current value to a - b */ + inline void SetSubstract(const GradStatsOneAPI& a, const GradStatsOneAPI& b) { + sum_grad = a.sum_grad - b.sum_grad; + sum_hess = a.sum_hess - b.sum_hess; + } + /*! \return whether the statistics is not used yet */ + inline bool Empty() const { return sum_hess == 0.0; } + /*! \brief add statistics to the data */ + inline void Add(GradType grad, GradType hess) { + sum_grad += grad; + sum_hess += hess; + } +}; + + +/*! + * \brief OneAPI implementation of SplitEntryContainer for device compilation. + * Original structure cannot be used due to std::isinf usage, which is not supported + */ +template +struct SplitEntryContainerOneAPI { + /*! \brief loss change after split this node */ + bst_float loss_chg {0.0f}; + /*! \brief split index */ + bst_feature_t sindex{0}; + bst_float split_value{0.0f}; + + + GradientT left_sum; + GradientT right_sum; + + + SplitEntryContainerOneAPI() = default; + + + friend std::ostream& operator<<(std::ostream& os, SplitEntryContainerOneAPI const& s) { + os << "loss_chg: " << s.loss_chg << ", " + << "split index: " << s.SplitIndex() << ", " + << "split value: " << s.split_value << ", " + << "left_sum: " << s.left_sum << ", " + << "right_sum: " << s.right_sum; + return os; + } + /*!\return feature index to split on */ + bst_feature_t SplitIndex() const { return sindex & ((1U << 31) - 1U); } + /*!\return whether missing value goes to left branch */ + bool DefaultLeft() const { return (sindex >> 31) != 0; } + /*! 
+ * \brief decides whether we can replace current entry with the given statistics + * + * This function gives better priority to lower index when loss_chg == new_loss_chg. + * Not the best way, but helps to give consistent result during multi-thread + * execution. + * + * \param new_loss_chg the loss reduction get through the split + * \param split_index the feature index where the split is on + */ + inline bool NeedReplace(bst_float new_loss_chg, unsigned split_index) const { + if (sycl::isinf(new_loss_chg)) { // in some cases new_loss_chg can be NaN or Inf, + // for example when lambda = 0 & min_child_weight = 0 + // skip value in this case + return false; + } else if (this->SplitIndex() <= split_index) { + return new_loss_chg > this->loss_chg; + } else { + return !(this->loss_chg > new_loss_chg); + } + } + /*! + * \brief update the split entry, replace it if e is better + * \param e candidate split solution + * \return whether the proposed split is better and can replace current split + */ + inline bool Update(const SplitEntryContainerOneAPI &e) { + if (this->NeedReplace(e.loss_chg, e.SplitIndex())) { + this->loss_chg = e.loss_chg; + this->sindex = e.sindex; + this->split_value = e.split_value; + this->left_sum = e.left_sum; + this->right_sum = e.right_sum; + return true; + } else { + return false; + } + } + /*! + * \brief update the split entry, replace it if e is better + * \param new_loss_chg loss reduction of new candidate + * \param split_index feature index to split on + * \param new_split_value the split point + * \param default_left whether the missing value goes to left + * \return whether the proposed split is better and can replace current split + */ + bool Update(bst_float new_loss_chg, unsigned split_index, + bst_float new_split_value, bool default_left, + const GradientT &left_sum, + const GradientT &right_sum) { + if (this->NeedReplace(new_loss_chg, split_index)) { + this->loss_chg = new_loss_chg; + if (default_left) { + split_index |= (1U << 31); + } + this->sindex = split_index; + this->split_value = new_split_value; + this->left_sum = left_sum; + this->right_sum = right_sum; + return true; + } else { + return false; + } + } + + + /*! \brief same as update, used by AllReduce*/ + inline static void Reduce(SplitEntryContainerOneAPI &dst, // NOLINT(*) + const SplitEntryContainerOneAPI &src) { // NOLINT(*) + dst.Update(src); + } +}; + + +template +using SplitEntryOneAPI = SplitEntryContainerOneAPI>; + + +} +} +#endif // XGBOOST_TREE_PARAM_H_ diff --git a/plugin/updater_oneapi/predictor_oneapi.cc b/plugin/updater_oneapi/predictor_oneapi.cc index 25a14186c179..86f1877b84c4 100755 --- a/plugin/updater_oneapi/predictor_oneapi.cc +++ b/plugin/updater_oneapi/predictor_oneapi.cc @@ -1,105 +1,111 @@ /*! 
- * Copyright by Contributors 2017-2020 + * Copyright by Contributors 2017-2023 */ -#include // for any #include #include #include +#include + +#include "data_oneapi.h" + +#include "dmlc/registry.h" + +#include "xgboost/tree_model.h" +#include "xgboost/predictor.h" +#include "xgboost/tree_updater.h" -#include "../../src/common/math.h" #include "../../src/data/adapter.h" +#include "../../src/common/math.h" #include "../../src/gbm/gbtree_model.h" + +#include "./device_manager_oneapi.h" + #include "CL/sycl.hpp" -#include "xgboost/base.h" -#include "xgboost/data.h" -#include "xgboost/host_device_vector.h" -#include "xgboost/logging.h" -#include "xgboost/predictor.h" -#include "xgboost/tree_model.h" -#include "xgboost/tree_updater.h" namespace xgboost { namespace predictor { DMLC_REGISTRY_FILE_TAG(predictor_oneapi); -/*! \brief Element from a sparse vector */ -struct EntryOneAPI { - /*! \brief feature index */ - bst_feature_t index; - /*! \brief feature value */ - bst_float fvalue; - /*! \brief default constructor */ - EntryOneAPI() = default; - /*! - * \brief constructor with index and value - * \param index The feature or row index. - * \param fvalue The feature value. - */ - EntryOneAPI(bst_feature_t index, bst_float fvalue) : index(index), fvalue(fvalue) {} - - EntryOneAPI(const Entry& entry) : index(entry.index), fvalue(entry.fvalue) {} - - /*! \brief reversely compare feature values */ - inline static bool CmpValue(const EntryOneAPI& a, const EntryOneAPI& b) { - return a.fvalue < b.fvalue; - } - inline bool operator==(const EntryOneAPI& other) const { - return (this->index == other.index && this->fvalue == other.fvalue); - } -}; +class PredictorOneAPI : public Predictor { + public: + explicit PredictorOneAPI(Context const* context) : + Predictor::Predictor{context} {} + + void Configure(const std::vector>& args) override { + const DeviceOrd device_spec = ctx_->Device(); -struct DeviceMatrixOneAPI { - DMatrix* p_mat; // Pointer to the original matrix on the host - cl::sycl::queue qu_; - size_t* row_ptr; - size_t row_ptr_size; - EntryOneAPI* data; - - DeviceMatrixOneAPI(DMatrix* dmat, cl::sycl::queue qu) : p_mat(dmat), qu_(qu) { - size_t num_row = 0; - size_t num_nonzero = 0; - for (auto &batch : dmat->GetBatches()) { - const auto& data_vec = batch.data.HostVector(); - const auto& offset_vec = batch.offset.HostVector(); - num_nonzero += data_vec.size(); - num_row += batch.Size(); + bool is_cpu; + if (device_spec.IsSycl()) { + sycl::device device = device_manager.GetDevice(device_spec); + is_cpu = device.is_cpu(); + } else { + is_cpu = true; } - row_ptr = cl::sycl::malloc_shared(num_row + 1, qu_); - data = cl::sycl::malloc_shared(num_nonzero, qu_); - - size_t data_offset = 0; - for (auto &batch : dmat->GetBatches()) { - const auto& data_vec = batch.data.HostVector(); - const auto& offset_vec = batch.offset.HostVector(); - size_t batch_size = batch.Size(); - if (batch_size > 0) { - std::copy(offset_vec.data(), offset_vec.data() + batch_size, - row_ptr + batch.base_rowid); - if (batch.base_rowid > 0) { - for(size_t i = 0; i < batch_size; i++) - row_ptr[i + batch.base_rowid] += batch.base_rowid; - } - std::copy(data_vec.data(), data_vec.data() + offset_vec[batch_size], - data + data_offset); - data_offset += offset_vec[batch_size]; - } + LOG(INFO) << "device = " << device_spec.Name() << ", is_cpu = " << int(is_cpu); + + if (is_cpu) { + predictor_backend_.reset(Predictor::Create("cpu_predictor", ctx_)); + } else{ + predictor_backend_.reset(Predictor::Create("oneapi_predictor_backend", ctx_)); } 
- row_ptr[num_row] = data_offset; - row_ptr_size = num_row + 1; + predictor_backend_->Configure(args); } - ~DeviceMatrixOneAPI() { - if (row_ptr) { - cl::sycl::free(row_ptr, qu_); - } - if (data) { - cl::sycl::free(data, qu_); - } + void PredictBatch(DMatrix *dmat, PredictionCacheEntry *predts, + const gbm::GBTreeModel &model, uint32_t tree_begin, + uint32_t tree_end = 0) const override { + predictor_backend_->PredictBatch(dmat, predts, model, tree_begin, tree_end); + } + + bool InplacePredict(std::shared_ptr p_m, + const gbm::GBTreeModel &model, float missing, + PredictionCacheEntry *out_preds, uint32_t tree_begin, + unsigned tree_end) const override { + return predictor_backend_->InplacePredict(p_m, model, missing, out_preds, tree_begin, tree_end); + } + + void PredictInstance(const SparsePage::Inst& inst, + std::vector* out_preds, + const gbm::GBTreeModel& model, unsigned ntree_limit, + bool is_column_split) const override { + predictor_backend_->PredictInstance(inst, out_preds, model, ntree_limit, is_column_split); + } + + void PredictLeaf(DMatrix* p_fmat, HostDeviceVector* out_preds, + const gbm::GBTreeModel& model, unsigned ntree_limit) const override { + predictor_backend_->PredictLeaf(p_fmat, out_preds, model, ntree_limit); + } + + void PredictContribution(DMatrix* p_fmat, HostDeviceVector* out_contribs, + const gbm::GBTreeModel& model, uint32_t ntree_limit, + const std::vector* tree_weights, + bool approximate, int condition, + unsigned condition_feature) const override { + predictor_backend_->PredictContribution(p_fmat, out_contribs, model, ntree_limit, tree_weights, approximate, condition, condition_feature); } + + void PredictInteractionContributions(DMatrix* p_fmat, HostDeviceVector* out_contribs, + const gbm::GBTreeModel& model, unsigned ntree_limit, + const std::vector* tree_weights, + bool approximate) const override { + predictor_backend_->PredictInteractionContributions(p_fmat, out_contribs, model, ntree_limit, tree_weights, approximate); + } + + protected: + void InitOutPredictions(const MetaInfo& info, + HostDeviceVector* out_preds, + const gbm::GBTreeModel& model) const { + predictor_backend_->InitOutPredictions(info, out_preds, model); + } + + private: + DeviceManagerOneAPI device_manager; + std::unique_ptr predictor_backend_; }; +/* Wrapper for descriptor of a tree node */ struct DeviceNodeOneAPI { DeviceNodeOneAPI() : fidx(-1), left_child_idx(-1), right_child_idx(-1) {} @@ -114,7 +120,7 @@ struct DeviceNodeOneAPI { int right_child_idx; NodeValue val; - DeviceNodeOneAPI(const RegTree::Node& n) { // NOLINT + DeviceNodeOneAPI(const RegTree::Node& n) { this->left_child_idx = n.LeftChild(); this->right_child_idx = n.RightChild(); this->fidx = n.SplitIndex(); @@ -148,66 +154,54 @@ struct DeviceNodeOneAPI { float GetWeight() const { return val.leaf_weight; } }; +/* OneAPI implementation of a device model, storing tree structure in USM buffers to provide access from device kernels */ class DeviceModelOneAPI { public: - cl::sycl::queue qu_; - DeviceNodeOneAPI* nodes; - size_t* tree_segments; - int* tree_group; + sycl::queue qu_; + USMVector nodes_; + USMVector tree_segments_; + USMVector tree_group_; size_t tree_beg_; size_t tree_end_; - int num_group; + int num_group_; - DeviceModelOneAPI() : nodes(nullptr), tree_segments(nullptr), tree_group(nullptr) {} + DeviceModelOneAPI() {} - ~DeviceModelOneAPI() { - Reset(); - } + ~DeviceModelOneAPI() {} - void Reset() { - if (nodes) - cl::sycl::free(nodes, qu_); - if (tree_segments) - cl::sycl::free(tree_segments, qu_); - if 
(tree_group) - cl::sycl::free(tree_group, qu_); - } - - void Init(const gbm::GBTreeModel& model, size_t tree_begin, size_t tree_end, cl::sycl::queue qu) { + void Init(sycl::queue qu, const gbm::GBTreeModel& model, size_t tree_begin, size_t tree_end) { qu_ = qu; - CHECK_EQ(model.param.size_leaf_vector, 0); - Reset(); - tree_segments = cl::sycl::malloc_shared((tree_end - tree_begin) + 1, qu_); + tree_segments_.Resize(qu_, (tree_end - tree_begin) + 1); int sum = 0; - tree_segments[0] = sum; + tree_segments_[0] = sum; for (int tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) { sum += model.trees[tree_idx]->GetNodes().size(); - tree_segments[tree_idx - tree_begin + 1] = sum; + tree_segments_[tree_idx - tree_begin + 1] = sum; } - nodes = cl::sycl::malloc_shared(sum, qu_); + nodes_.Resize(qu_, sum); for (int tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) { auto& src_nodes = model.trees[tree_idx]->GetNodes(); for (size_t node_idx = 0; node_idx < src_nodes.size(); node_idx++) - nodes[node_idx + tree_segments[tree_idx - tree_begin]] = src_nodes[node_idx]; + nodes_[node_idx + tree_segments_[tree_idx - tree_begin]] = src_nodes[node_idx]; } - tree_group = cl::sycl::malloc_shared(model.tree_info.size(), qu_); + tree_group_.Resize(qu_, model.tree_info.size()); for (size_t tree_idx = 0; tree_idx < model.tree_info.size(); tree_idx++) - tree_group[tree_idx] = model.tree_info[tree_idx]; + tree_group_[tree_idx] = model.tree_info[tree_idx]; tree_beg_ = tree_begin; tree_end_ = tree_end; - num_group = model.learner_model_param->num_output_group; + num_group_ = model.learner_model_param->num_output_group; } }; -float GetFvalue(int ridx, int fidx, EntryOneAPI* data, size_t* row_ptr, bool& is_missing) { +float GetFvalue(int ridx, int fidx, Entry* data, size_t* row_ptr, bool& is_missing) { // Binary search auto begin_ptr = data + row_ptr[ridx]; auto end_ptr = data + row_ptr[ridx + 1]; - EntryOneAPI* previous_middle = nullptr; + Entry* previous_middle = nullptr; while (end_ptr != begin_ptr) { auto middle = begin_ptr + (end_ptr - begin_ptr) / 2; if (middle == previous_middle) { @@ -229,7 +223,7 @@ float GetFvalue(int ridx, int fidx, EntryOneAPI* data, size_t* row_ptr, bool& is return 0.0; } -float GetLeafWeight(int ridx, const DeviceNodeOneAPI* tree, EntryOneAPI* data, size_t* row_ptr) { +float GetLeafWeight(int ridx, const DeviceNodeOneAPI* tree, Entry* data, size_t* row_ptr) { DeviceNodeOneAPI n = tree[0]; int node_id = 0; bool is_missing; @@ -251,20 +245,68 @@ float GetLeafWeight(int ridx, const DeviceNodeOneAPI* tree, EntryOneAPI* data, s return n.GetWeight(); } -class PredictorOneAPI : public Predictor { +void DevicePredictInternal(sycl::queue qu, + DeviceMatrixOneAPI* dmat, + HostDeviceVector* out_preds, + const gbm::GBTreeModel& model, + size_t tree_begin, + size_t tree_end) { + if (tree_end - tree_begin == 0) { + return; + } + DeviceModelOneAPI device_model; + device_model.Init(qu, model, tree_begin, tree_end); + + auto& out_preds_vec = out_preds->HostVector(); + + DeviceNodeOneAPI* nodes = device_model.nodes_.Data(); + sycl::buffer out_preds_buf(out_preds_vec.data(), out_preds_vec.size()); + size_t* tree_segments = device_model.tree_segments_.Data(); + int* tree_group = device_model.tree_group_.Data(); + size_t* row_ptr = dmat->row_ptr.Data(); + Entry* data = dmat->data.Data(); + int num_features = dmat->p_mat->Info().num_col_; + int num_rows = dmat->row_ptr.Size() - 1; + int num_group = model.learner_model_param->num_output_group; + + qu.submit([&](sycl::handler& cgh) { + auto out_predictions 
= out_preds_buf.template get_access(cgh); + cgh.parallel_for<>(sycl::range<1>(num_rows), [=](sycl::id<1> pid) { + int global_idx = pid[0]; + if (global_idx >= num_rows) return; + if (num_group == 1) { + float sum = 0.0; + for (int tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) { + const DeviceNodeOneAPI* tree = nodes + tree_segments[tree_idx - tree_begin]; + sum += GetLeafWeight(global_idx, tree, data, row_ptr); + } + out_predictions[global_idx] += sum; + } else { + for (int tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) { + const DeviceNodeOneAPI* tree = nodes + tree_segments[tree_idx - tree_begin]; + int out_prediction_idx = global_idx * num_group + tree_group[tree_idx]; + out_predictions[out_prediction_idx] += GetLeafWeight(global_idx, tree, data, row_ptr); + } + } + }); + }).wait(); +} + +class PredictorBackendOneAPI : public Predictor { protected: void InitOutPredictions(const MetaInfo& info, HostDeviceVector* out_preds, const gbm::GBTreeModel& model) const { CHECK_NE(model.learner_model_param->num_output_group, 0); size_t n = model.learner_model_param->num_output_group * info.num_row_; - const auto& base_margin = info.base_margin_.HostVector(); + const auto& base_margin = info.base_margin_.Data()->HostVector(); out_preds->Resize(n); std::vector& out_preds_h = out_preds->HostVector(); if (base_margin.size() == n) { CHECK_EQ(out_preds->Size(), n); std::copy(base_margin.begin(), base_margin.end(), out_preds_h.begin()); } else { + auto base_score = model.learner_model_param->BaseScore(ctx_)(0); if (!base_margin.empty()) { std::ostringstream oss; oss << "Ignoring the base margin, since it has incorrect length. " @@ -277,171 +319,95 @@ class PredictorOneAPI : public Predictor { oss << "[number of data points], i.e. " << info.num_row_ << ". 
"; } oss << "Instead, all data points will use " - << "base_score = " << model.learner_model_param->base_score; + << "base_score = " << base_score; LOG(WARNING) << oss.str(); } - std::fill(out_preds_h.begin(), out_preds_h.end(), - model.learner_model_param->base_score); - } - } - - void DevicePredictInternal(DeviceMatrixOneAPI* dmat, HostDeviceVector* out_preds, - const gbm::GBTreeModel& model, size_t tree_begin, - size_t tree_end) { - if (tree_end - tree_begin == 0) { - return; + std::fill(out_preds_h.begin(), out_preds_h.end(), base_score); } - model_.Init(model, tree_begin, tree_end, qu_); - - auto& out_preds_vec = out_preds->HostVector(); - - DeviceNodeOneAPI* nodes = model_.nodes; - cl::sycl::buffer out_preds_buf(out_preds_vec.data(), out_preds_vec.size()); - size_t* tree_segments = model_.tree_segments; - int* tree_group = model_.tree_group; - size_t* row_ptr = dmat->row_ptr; - EntryOneAPI* data = dmat->data; - int num_features = dmat->p_mat->Info().num_col_; - int num_rows = dmat->row_ptr_size - 1; - int num_group = model.learner_model_param->num_output_group; - - qu_.submit([&](cl::sycl::handler& cgh) { - auto out_predictions = out_preds_buf.get_access(cgh); - cgh.parallel_for(cl::sycl::range<1>(num_rows), [=](cl::sycl::id<1> pid) { - int global_idx = pid[0]; - if (global_idx >= num_rows) return; - if (num_group == 1) { - float sum = 0.0; - for (int tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) { - const DeviceNodeOneAPI* tree = nodes + tree_segments[tree_idx - tree_begin]; - sum += GetLeafWeight(global_idx, tree, data, row_ptr); - } - out_predictions[global_idx] += sum; - } else { - for (int tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) { - const DeviceNodeOneAPI* tree = nodes + tree_segments[tree_idx - tree_begin]; - int out_prediction_idx = global_idx * num_group + tree_group[tree_idx]; - out_predictions[out_prediction_idx] += GetLeafWeight(global_idx, tree, data, row_ptr); - } - } - }); - }).wait(); } public: - explicit PredictorOneAPI(Context const* generic_param) : - Predictor::Predictor{generic_param}, cpu_predictor(Predictor::Create("cpu_predictor", generic_param)) { - cl::sycl::default_selector selector; - qu_ = cl::sycl::queue(selector); + explicit PredictorBackendOneAPI(Context const* context) : + Predictor::Predictor{context}, cpu_predictor(Predictor::Create("cpu_predictor", context)) { + qu_ = device_manager.GetQueue(context->Device()); } - // ntree_limit is a very problematic parameter, as it's ambiguous in the context of - // multi-output and forest. Same problem exists for tree_begin - void PredictBatch(DMatrix* dmat, PredictionCacheEntry* predts, - const gbm::GBTreeModel& model, int tree_begin, - uint32_t const ntree_limit = 0) override { + void PredictBatch(DMatrix *dmat, PredictionCacheEntry *predts, + const gbm::GBTreeModel &model, uint32_t tree_begin, + uint32_t tree_end = 0) const override { + // Existing caching approach is not valid due to the const modifier of the PredictBatch method + /* if (this->device_matrix_cache_.find(dmat) == this->device_matrix_cache_.end()) { this->device_matrix_cache_.emplace( dmat, std::unique_ptr( - new DeviceMatrixOneAPI(dmat, qu_))); + new DeviceMatrixOneAPI(qu_, dmat))); } DeviceMatrixOneAPI* device_matrix = device_matrix_cache_.find(dmat)->second.get(); + */ - // tree_begin is not used, right now we just enforce it to be 0. 
- CHECK_EQ(tree_begin, 0); - auto* out_preds = &predts->predictions; - CHECK_GE(predts->version, tree_begin); - if (out_preds->Size() == 0 && dmat->Info().num_row_ != 0) { - CHECK_EQ(predts->version, 0); - } - if (predts->version == 0) { - // out_preds->Size() can be non-zero as it's initialized here before any tree is - // built at the 0^th iterator. - this->InitOutPredictions(dmat->Info(), out_preds, model); - } - - uint32_t const output_groups = model.learner_model_param->num_output_group; - CHECK_NE(output_groups, 0); - // Right now we just assume ntree_limit provided by users means number of tree layers - // in the context of multi-output model - uint32_t real_ntree_limit = ntree_limit * output_groups; - if (real_ntree_limit == 0 || real_ntree_limit > model.trees.size()) { - real_ntree_limit = static_cast(model.trees.size()); - } + DeviceMatrixOneAPI device_matrix(qu_, dmat); // TODO: remove temporary workaround after cache fix - uint32_t const end_version = (tree_begin + real_ntree_limit) / output_groups; - // When users have provided ntree_limit, end_version can be lesser, cache is violated - if (predts->version > end_version) { - CHECK_NE(ntree_limit, 0); - this->InitOutPredictions(dmat->Info(), out_preds, model); - predts->version = 0; + auto* out_preds = &predts->predictions; + if (tree_end == 0) { + tree_end = model.trees.size(); } - uint32_t const beg_version = predts->version; - CHECK_LE(beg_version, end_version); - if (beg_version < end_version) { - DevicePredictInternal(device_matrix, out_preds, model, - beg_version * output_groups, - end_version * output_groups); + if (tree_begin < tree_end) { + DevicePredictInternal(qu_, &device_matrix, out_preds, model, tree_begin, tree_end); } - - // delta means {size of forest} * {number of newly accumulated layers} - uint32_t delta = end_version - beg_version; - CHECK_LE(delta, model.trees.size()); - predts->Update(delta); - - CHECK(out_preds->Size() == output_groups * dmat->Info().num_row_ || - out_preds->Size() == dmat->Info().num_row_); } - void InplacePredict(std::any const& x, const gbm::GBTreeModel& model, float missing, - PredictionCacheEntry* out_preds, uint32_t tree_begin, + bool InplacePredict(std::shared_ptr p_m, + const gbm::GBTreeModel &model, float missing, + PredictionCacheEntry *out_preds, uint32_t tree_begin, unsigned tree_end) const override { - cpu_predictor->InplacePredict(x, model, missing, out_preds, tree_begin, tree_end); + return cpu_predictor->InplacePredict(p_m, model, missing, out_preds, tree_begin, tree_end); } void PredictInstance(const SparsePage::Inst& inst, std::vector* out_preds, - const gbm::GBTreeModel& model, unsigned ntree_limit) override { - cpu_predictor->PredictInstance(inst, out_preds, model, ntree_limit); + const gbm::GBTreeModel& model, unsigned ntree_limit, + bool is_column_split) const override { + cpu_predictor->PredictInstance(inst, out_preds, model, ntree_limit, is_column_split); } - void PredictLeaf(DMatrix* p_fmat, std::vector* out_preds, - const gbm::GBTreeModel& model, unsigned ntree_limit) override { + void PredictLeaf(DMatrix* p_fmat, HostDeviceVector* out_preds, + const gbm::GBTreeModel& model, unsigned ntree_limit) const override { cpu_predictor->PredictLeaf(p_fmat, out_preds, model, ntree_limit); } - void PredictContribution(DMatrix* p_fmat, std::vector* out_contribs, + void PredictContribution(DMatrix* p_fmat, HostDeviceVector* out_contribs, const gbm::GBTreeModel& model, uint32_t ntree_limit, - std::vector* tree_weights, + const std::vector* tree_weights, bool approximate, int 
condition, - unsigned condition_feature) override { + unsigned condition_feature) const override { cpu_predictor->PredictContribution(p_fmat, out_contribs, model, ntree_limit, tree_weights, approximate, condition, condition_feature); } - void PredictInteractionContributions(DMatrix* p_fmat, std::vector* out_contribs, + void PredictInteractionContributions(DMatrix* p_fmat, HostDeviceVector* out_contribs, const gbm::GBTreeModel& model, unsigned ntree_limit, - std::vector* tree_weights, - bool approximate) override { + const std::vector* tree_weights, + bool approximate) const override { cpu_predictor->PredictInteractionContributions(p_fmat, out_contribs, model, ntree_limit, tree_weights, approximate); } private: - cl::sycl::queue qu_; - DeviceModelOneAPI model_; + DeviceManagerOneAPI device_manager; + sycl::queue qu_; - std::mutex lock_; std::unique_ptr cpu_predictor; - std::unordered_map> - device_matrix_cache_; + std::unordered_map> device_matrix_cache_; }; XGBOOST_REGISTER_PREDICTOR(PredictorOneAPI, "oneapi_predictor") .describe("Make predictions using DPC++.") -.set_body([](Context const* generic_param) { - return new PredictorOneAPI(generic_param); - }); +.set_body([](Context const *ctx) { return new PredictorOneAPI(ctx); }); + +XGBOOST_REGISTER_PREDICTOR(PredictorBackendOneAPI, "oneapi_predictor_backend") +.describe("Make predictions using DPC++.") +.set_body([](Context const* ctx) { return new PredictorBackendOneAPI(ctx); }); + } // namespace predictor } // namespace xgboost diff --git a/plugin/updater_oneapi/regression_loss_oneapi.h b/plugin/updater_oneapi/regression_loss_oneapi.h index b0299ff7f5a3..beab461e2aff 100755 --- a/plugin/updater_oneapi/regression_loss_oneapi.h +++ b/plugin/updater_oneapi/regression_loss_oneapi.h @@ -1,5 +1,5 @@ /*! - * Copyright 2017-2020 XGBoost contributors + * Copyright 2017-2023 XGBoost contributors */ #ifndef XGBOOST_OBJECTIVE_REGRESSION_LOSS_ONEAPI_H_ #define XGBOOST_OBJECTIVE_REGRESSION_LOSS_ONEAPI_H_ @@ -19,7 +19,7 @@ namespace obj { * \return the transformed value. 
*/ inline float SigmoidOneAPI(float x) { - return 1.0f / (1.0f + cl::sycl::exp(-x)); + return 1.0f / (1.0f + sycl::exp(-x)); } // common regressions @@ -38,6 +38,8 @@ struct LinearSquareLossOneAPI { static const char* DefaultEvalMetric() { return "rmse"; } static const char* Name() { return "reg:squarederror_oneapi"; } + + static ObjInfo Info() { return {ObjInfo::kRegression, true, false}; } }; // TODO: DPC++ does not fully support std math inside offloaded kernels @@ -48,12 +50,12 @@ struct SquaredLogErrorOneAPI { } static bst_float FirstOrderGradient(bst_float predt, bst_float label) { predt = std::max(predt, (bst_float)(-1 + 1e-6)); // ensure correct value for log1p - return (cl::sycl::log1p(predt) - cl::sycl::log1p(label)) / (predt + 1); + return (sycl::log1p(predt) - sycl::log1p(label)) / (predt + 1); } static bst_float SecondOrderGradient(bst_float predt, bst_float label) { predt = std::max(predt, (bst_float)(-1 + 1e-6)); - float res = (-cl::sycl::log1p(predt) + cl::sycl::log1p(label) + 1) / - cl::sycl::pow(predt + 1, (bst_float)2); + float res = (-sycl::log1p(predt) + sycl::log1p(label) + 1) / + sycl::pow(predt + 1, (bst_float)2); res = std::max(res, (bst_float)1e-6f); return res; } @@ -64,6 +66,8 @@ struct SquaredLogErrorOneAPI { static const char* DefaultEvalMetric() { return "rmsle"; } static const char* Name() { return "reg:squaredlogerror_oneapi"; } + + static ObjInfo Info() { return ObjInfo::kRegression; } }; // logistic loss for probability regression task @@ -99,6 +103,8 @@ struct LogisticRegressionOneAPI { static const char* DefaultEvalMetric() { return "rmse"; } static const char* Name() { return "reg:logistic_oneapi"; } + + static ObjInfo Info() { return ObjInfo::kRegression; } }; // logistic loss for binary classification task @@ -137,6 +143,8 @@ struct LogisticRawOneAPI : public LogisticRegressionOneAPI { static const char* DefaultEvalMetric() { return "logloss"; } static const char* Name() { return "binary:logitraw_oneapi"; } + + static ObjInfo Info() { return ObjInfo::kRegression; } }; } // namespace obj diff --git a/plugin/updater_oneapi/regression_obj_oneapi.cc b/plugin/updater_oneapi/regression_obj_oneapi.cc index 3ee5741e7c1a..3c157a80e797 100755 --- a/plugin/updater_oneapi/regression_obj_oneapi.cc +++ b/plugin/updater_oneapi/regression_obj_oneapi.cc @@ -3,6 +3,7 @@ #include #include #include +#include #include "xgboost/host_device_vector.h" #include "xgboost/json.h" @@ -11,7 +12,8 @@ #include "../../src/common/transform.h" #include "../../src/common/common.h" -#include "./regression_loss_oneapi.h" +#include "regression_loss_oneapi.h" +#include "device_manager_oneapi.h" #include "CL/sycl.hpp" @@ -39,60 +41,60 @@ class RegLossObjOneAPI : public ObjFunction { void Configure(const std::vector >& args) override { param_.UpdateAllowUnknown(args); - - cl::sycl::default_selector selector; - qu_ = cl::sycl::queue(selector); + qu_ = device_manager.GetQueue(ctx_->Device()); } void GetGradient(const HostDeviceVector& preds, const MetaInfo &info, int iter, HostDeviceVector* out_gpair) override { - if (info.labels_.Size() == 0U) { - LOG(WARNING) << "Label set is empty."; - } - CHECK_EQ(preds.Size(), info.labels_.Size()) - << " " << "labels are not correctly provided" - << "preds.size=" << preds.Size() << ", label.size=" << info.labels_.Size() << ", " - << "Loss: " << Loss::Name(); + if (info.labels.Size() == 0U) { + LOG(WARNING) << "Label set is empty."; + } + CHECK_EQ(preds.Size(), info.labels.Size()) + << " " << "labels are not correctly provided" + << "preds.size=" << 
preds.Size() << ", label.size=" << info.labels.Size() << ", " + << "Loss: " << Loss::Name(); + + size_t const ndata = preds.Size(); + out_gpair->Resize(ndata); - size_t const ndata = preds.Size(); - out_gpair->Resize(ndata); + // TODO: add label_correct check + label_correct_.Resize(1); + label_correct_.Fill(1); - // TODO: add label_correct check - label_correct_.Resize(1); - label_correct_.Fill(1); + bool is_null_weight = info.weights_.Size() == 0; - bool is_null_weight = info.weights_.Size() == 0; + sycl::buffer preds_buf(preds.HostPointer(), preds.Size()); + sycl::buffer labels_buf(info.labels.Data()->HostPointer(), info.labels.Size()); + sycl::buffer out_gpair_buf(out_gpair->HostPointer(), out_gpair->Size()); + sycl::buffer weights_buf(is_null_weight ? NULL : info.weights_.HostPointer(), + is_null_weight ? 1 : info.weights_.Size()); - cl::sycl::buffer preds_buf(preds.HostPointer(), preds.Size()); - cl::sycl::buffer labels_buf(info.labels_.HostPointer(), info.labels_.Size()); - cl::sycl::buffer out_gpair_buf(out_gpair->HostPointer(), out_gpair->Size()); - cl::sycl::buffer weights_buf(is_null_weight ? NULL : info.weights_.HostPointer(), - is_null_weight ? 1 : info.weights_.Size()); + const size_t n_targets = std::max(info.labels.Shape(1), static_cast(1)); - cl::sycl::buffer additional_input_buf(1); + sycl::buffer additional_input_buf(1); { - auto additional_input_acc = additional_input_buf.get_access(); + auto additional_input_acc = additional_input_buf.get_access(); additional_input_acc[0] = 1; // Fill the label_correct flag } auto scale_pos_weight = param_.scale_pos_weight; if (!is_null_weight) { - CHECK_EQ(info.weights_.Size(), ndata) + CHECK_EQ(info.weights_.Size(), info.labels.Shape(0)) << "Number of weights should be equal to number of data points."; } - qu_.submit([&](cl::sycl::handler& cgh) { - auto preds_acc = preds_buf.get_access(cgh); - auto labels_acc = labels_buf.get_access(cgh); - auto weights_acc = weights_buf.get_access(cgh); - auto out_gpair_acc = out_gpair_buf.get_access(cgh); - auto additional_input_acc = additional_input_buf.get_access(cgh); - cgh.parallel_for<>(cl::sycl::range<1>(ndata), [=](cl::sycl::id<1> pid) { + qu_.submit([&](sycl::handler& cgh) { + auto preds_acc = preds_buf.get_access(cgh); + auto labels_acc = labels_buf.get_access(cgh); + auto weights_acc = weights_buf.get_access(cgh); + auto out_gpair_acc = out_gpair_buf.get_access(cgh); + auto additional_input_acc = additional_input_buf.get_access(cgh); + cgh.parallel_for<>(sycl::range<1>(ndata), [=](sycl::id<1> pid) { int idx = pid[0]; bst_float p = Loss::PredTransform(preds_acc[idx]); - bst_float w = is_null_weight ? 1.0f : weights_acc[idx]; + bst_float w = is_null_weight ? 
1.0f : weights_acc[idx/n_targets]; bst_float label = labels_acc[idx]; if (label == 1.0f) { w *= scale_pos_weight; @@ -108,7 +110,7 @@ class RegLossObjOneAPI : public ObjFunction { int flag = 1; { - auto additional_input_acc = additional_input_buf.get_access(); + auto additional_input_acc = additional_input_buf.get_access(); flag = additional_input_acc[0]; } @@ -123,14 +125,13 @@ class RegLossObjOneAPI : public ObjFunction { return Loss::DefaultEvalMetric(); } - void PredTransform(HostDeviceVector *io_preds) override { + void PredTransform(HostDeviceVector *io_preds) const override { size_t const ndata = io_preds->Size(); + sycl::buffer io_preds_buf(io_preds->HostPointer(), io_preds->Size()); - cl::sycl::buffer io_preds_buf(io_preds->HostPointer(), io_preds->Size()); - - qu_.submit([&](cl::sycl::handler& cgh) { - auto io_preds_acc = io_preds_buf.get_access(cgh); - cgh.parallel_for<>(cl::sycl::range<1>(ndata), [=](cl::sycl::id<1> pid) { + qu_.submit([&](sycl::handler& cgh) { + auto io_preds_acc = io_preds_buf.get_access(cgh); + cgh.parallel_for<>(sycl::range<1>(ndata), [=](sycl::id<1> pid) { int idx = pid[0]; io_preds_acc[idx] = Loss::PredTransform(io_preds_acc[idx]); }); @@ -141,6 +142,15 @@ class RegLossObjOneAPI : public ObjFunction { return Loss::ProbToMargin(base_score); } + struct ObjInfo Task() const override { + return Loss::Info(); + }; + + uint32_t Targets(MetaInfo const& info) const override { + // Multi-target regression. + return std::max(static_cast(1), info.labels.Shape(1)); + } + void SaveConfig(Json* p_out) const override { auto& out = *p_out; out["name"] = String(Loss::Name()); @@ -153,8 +163,9 @@ class RegLossObjOneAPI : public ObjFunction { protected: RegLossParamOneAPI param_; + DeviceManagerOneAPI device_manager; - cl::sycl::queue qu_; + mutable sycl::queue qu_; }; // register the objective functions diff --git a/plugin/updater_oneapi/row_set_oneapi.h b/plugin/updater_oneapi/row_set_oneapi.h new file mode 100644 index 000000000000..309c85ad432e --- /dev/null +++ b/plugin/updater_oneapi/row_set_oneapi.h @@ -0,0 +1,261 @@ +/*! + * Copyright 2017-2023 XGBoost contributors + */ +#ifndef XGBOOST_COMMON_ROW_SET_ONEAPI_H_ +#define XGBOOST_COMMON_ROW_SET_ONEAPI_H_ + + +#include +#include +#include +#include + + +#include "data_oneapi.h" + + +#include "CL/sycl.hpp" + + +namespace xgboost { +namespace common { + + +/*! \brief Collection of rowsets stored on device in USM memory */ +class RowSetCollectionOneAPI { + public: + /*! \brief data structure to store an instance set, a subset of + * rows (instances) associated with a particular node in a decision + * tree. */ + struct Elem { + const size_t* begin{nullptr}; + const size_t* end{nullptr}; + bst_node_t node_id{-1}; // id of node associated with this instance set; -1 means uninitialized + Elem() + = default; + Elem(const size_t* begin, + const size_t* end, + bst_node_t node_id = -1) + : begin(begin), end(end), node_id(node_id) {} + + + inline size_t Size() const { + return end - begin; + } + }; + + + inline size_t Size() const { + return elem_of_each_node_.size(); + } + + + /*! \brief return corresponding element set given the node_id */ + inline const Elem& operator[](unsigned node_id) const { + const Elem& e = elem_of_each_node_[node_id]; + CHECK(e.begin != nullptr) + << "access element that is not in the set"; + return e; + } + + + /*! 
\brief return corresponding element set given the node_id */ + inline Elem& operator[](unsigned node_id) { + Elem& e = elem_of_each_node_[node_id]; + return e; + } + + + // clear up things + inline void Clear() { + elem_of_each_node_.clear(); + } + // initialize node id 0->everything + inline void Init() { + CHECK_EQ(elem_of_each_node_.size(), 0U); + + + const size_t* begin = row_indices_.Begin(); + const size_t* end = row_indices_.End(); + elem_of_each_node_.emplace_back(Elem(begin, end, 0)); + } + + + USMVector& Data() { return row_indices_; } + + + // split rowset into two + inline void AddSplit(unsigned node_id, + unsigned left_node_id, + unsigned right_node_id, + size_t n_left, + size_t n_right) { + const Elem e = elem_of_each_node_[node_id]; + CHECK(e.begin != nullptr); + size_t* all_begin = row_indices_.Begin(); + size_t* begin = all_begin + (e.begin - all_begin); + + + CHECK_EQ(n_left + n_right, e.Size()); + CHECK_LE(begin + n_left, e.end); + CHECK_EQ(begin + n_left + n_right, e.end); + + + if (left_node_id >= elem_of_each_node_.size()) { + elem_of_each_node_.resize(left_node_id + 1, Elem(nullptr, nullptr, -1)); + } + if (right_node_id >= elem_of_each_node_.size()) { + elem_of_each_node_.resize(right_node_id + 1, Elem(nullptr, nullptr, -1)); + } + + + elem_of_each_node_[left_node_id] = Elem(begin, begin + n_left, left_node_id); + elem_of_each_node_[right_node_id] = Elem(begin + n_left, e.end, right_node_id); + elem_of_each_node_[node_id] = Elem(nullptr, nullptr, -1); + } + + + private: + // stores the row indexes in the set + USMVector row_indices_; + // vector: node_id -> elements + std::vector elem_of_each_node_; +}; + + +// The builder is required for samples partition to left and rights children for set of nodes +class PartitionBuilderOneAPI { + public: + static constexpr size_t maxLocalSums = 256; + static constexpr size_t subgroupSize = 16; + + + template + void Init(sycl::queue qu, size_t n_nodes, Func funcNTaks) { + qu_ = qu; + nodes_offsets_.resize(n_nodes+1); + result_rows_.resize(2 * n_nodes); + n_nodes_ = n_nodes; + + + nodes_offsets_[0] = 0; + for (size_t i = 1; i < n_nodes+1; ++i) { + nodes_offsets_[i] = nodes_offsets_[i-1] + funcNTaks(i-1); + } + + + if (data_.Size() < nodes_offsets_[n_nodes]) { + data_.Resize(qu_, nodes_offsets_[n_nodes]); + } + prefix_sums_.Resize(qu, maxLocalSums); + } + + + common::Span GetData(int nid) { + return { data_.Data() + nodes_offsets_[nid], nodes_offsets_[nid + 1] - nodes_offsets_[nid] }; + } + + + common::Span GetPrefixSums() { + return { prefix_sums_.Data(), prefix_sums_.Size() }; + } + + + size_t GetLocalSize(const common::Range1d& range) { + size_t range_size = range.end() - range.begin(); + size_t local_subgroups = range_size / (maxLocalSums * subgroupSize) + !!(range_size % (maxLocalSums * subgroupSize)); + return subgroupSize * local_subgroups; + } + + + size_t GetSubgroupSize() { + return subgroupSize; + } + + + // void SetNLeftElems(int nid, size_t n_left) { + // result_left_rows_[nid] = n_left; + // } + + + // void SetNRightElems(int nid, size_t n_right) { + // result_right_rows_[nid] = n_right; + // } + + + // sycl::event SetNLeftRightElems(sycl::queue& qu, const USMVector& parts_size, + // const std::vector& priv_events) { + // auto event = qu.submit([&](sycl::handler& cgh) { + // cgh.depends_on(priv_events); + // cgh.parallel_for<>(sycl::range<1>(n_nodes_), [=](sycl::item<1> nid) { + // const size_t node_in_set = nid.get_id(0); + // result_left_rows_[node_in_set] = parts_size[2 * node_in_set]; + // 
result_right_rows_[node_in_set] = parts_size[2 * node_in_set + 1]; + // }); + // }); + // return event; + // } + + + size_t* GetResultRowsPtr() { + return result_rows_.data(); + } + + + size_t GetNLeftElems(int nid) const { + // return result_left_rows_[nid]; + return result_rows_[2 * nid]; + } + + + size_t GetNRightElems(int nid) const { + // return result_right_rows_[nid]; + return result_rows_[2 * nid + 1]; + } + + + sycl::event MergeToArray(sycl::queue& qu, size_t node_in_set, + size_t* data_result, + sycl::event priv_event) { + size_t n_nodes_total = GetNLeftElems(node_in_set) + GetNRightElems(node_in_set); + if (n_nodes_total > 0) { + const size_t* data = data_.Data() + nodes_offsets_[node_in_set]; + return qu.memcpy(data_result, data, sizeof(size_t) * n_nodes_total, priv_event); + } else { + return sycl::event(); + } + } + + + // void MergeToArray(int nid, size_t* rows_indexes) { + // size_t* data_result = rows_indexes; + + + // const size_t* data = data_.Data() + nodes_offsets_[nid]; + + + // if (result_left_rows_[nid] + result_right_rows_[nid] > 0) qu_.memcpy(data_result, data, sizeof(size_t) * (result_left_rows_[nid] + result_right_rows_[nid])); + // } + + + protected: + std::vector nodes_offsets_; + std::vector result_rows_; + size_t n_nodes_; + + + USMVector data_; + + + USMVector prefix_sums_; + + + sycl::queue qu_; +}; + + +} // namespace common +} // namespace xgboost + + +#endif // XGBOOST_COMMON_ROW_SET_ONEAPI_H_ diff --git a/plugin/updater_oneapi/split_evaluator_oneapi.h b/plugin/updater_oneapi/split_evaluator_oneapi.h new file mode 100644 index 000000000000..593c7a4910d6 --- /dev/null +++ b/plugin/updater_oneapi/split_evaluator_oneapi.h @@ -0,0 +1,192 @@ +/*! + * Copyright 2018-2023 by Contributors + */ + +#ifndef XGBOOST_TREE_SPLIT_EVALUATOR_ONEAPI_H_ +#define XGBOOST_TREE_SPLIT_EVALUATOR_ONEAPI_H_ + +#include +#include +#include +#include +#include + +#include "param_oneapi.h" + +#include "xgboost/tree_model.h" +#include "xgboost/host_device_vector.h" +#include "xgboost/context.h" +#include "../../src/common/transform.h" +#include "../../src/common/math.h" +#include "../../src/tree/param.h" + +#include "CL/sycl.hpp" + +namespace xgboost { +namespace tree { + +/*! \brief OneAPI implementation of TreeEvaluator, with USM memory for temporary buffer to access on device. + * It also contains own implementation of SplitEvaluator for device compilation, because some of the + functions from the original SplitEvaluator are currently not supported + */ + +template +class TreeEvaluatorOneAPI { + // hist and exact use parent id to calculate constraints. 
+ static constexpr bst_node_t kRootParentId = + (-1 & static_cast((1U << 31) - 1)); + + USMVector lower_bounds_; + USMVector upper_bounds_; + USMVector monotone_; + TrainParamOneAPI param_; + sycl::queue qu_; + bool has_constraint_; + + public: + TreeEvaluatorOneAPI(sycl::queue qu, TrainParam const& p, bst_feature_t n_features) { + qu_ = qu; + if (p.monotone_constraints.empty()) { + monotone_.Resize(qu_, n_features, 0); + has_constraint_ = false; + } else { + monotone_ = USMVector(qu_, p.monotone_constraints); + monotone_.Resize(qu_, n_features, 0); + lower_bounds_.Resize(qu_, p.MaxNodes(), -std::numeric_limits::max()); + upper_bounds_.Resize(qu_, p.MaxNodes(), std::numeric_limits::max()); + has_constraint_ = true; + } + param_ = TrainParamOneAPI(p); + } + + struct SplitEvaluator { + int* constraints; + GradType* lower; + GradType* upper; + bool has_constraint; + TrainParamOneAPI param; + + GradType CalcSplitGain(bst_node_t nidx, + bst_feature_t fidx, + const GradStatsOneAPI& left, + const GradStatsOneAPI& right) const { + int constraint = constraints[fidx]; + const GradType negative_infinity = -std::numeric_limits::infinity(); + GradType wleft = this->CalcWeight(nidx, left); + GradType wright = this->CalcWeight(nidx, right); + + GradType gain = this->CalcGainGivenWeight(nidx, left, wleft) + + this->CalcGainGivenWeight(nidx, right, wright); + if (constraint == 0) { + return gain; + } else if (constraint > 0) { + return wleft <= wright ? gain : negative_infinity; + } else { + return wleft >= wright ? gain : negative_infinity; + } + } + + inline GradType ThresholdL1OneAPI(GradType w, GradType alpha) const { + if (w > + alpha) { + return w - alpha; + } + if (w < - alpha) { + return w + alpha; + } + return 0.0; + } + + inline GradType CalcWeightOneAPI(GradType sum_grad, GradType sum_hess) const { + if (sum_hess < param.min_child_weight || sum_hess <= 0.0) { + return 0.0; + } + GradType dw = -this->ThresholdL1OneAPI(sum_grad, param.reg_alpha) / (sum_hess + param.reg_lambda); + if (param.max_delta_step != 0.0f && std::abs(dw) > param.max_delta_step) { + dw = sycl::copysign((GradType)param.max_delta_step, dw); + } + return dw; + } + + inline GradType CalcWeight(bst_node_t nodeid, const GradStatsOneAPI& stats) const { + GradType w = this->CalcWeightOneAPI(stats.GetGrad(), stats.GetHess()); + if (!has_constraint) { + return w; + } + + if (nodeid == kRootParentId) { + return w; + } else if (w < lower[nodeid]) { + return lower[nodeid]; + } else if (w > upper[nodeid]) { + return upper[nodeid]; + } else { + return w; + } + } + + inline GradType Sqr(GradType a) const { return a * a; } + + inline GradType CalcGainGivenWeight(GradType sum_grad, GradType sum_hess, GradType w) const { + return -(2.0f * sum_grad * w + (sum_hess + param.reg_lambda) * this->Sqr(w)); + } + + inline GradType CalcGainGivenWeight(bst_node_t nid, const GradStatsOneAPI& stats, GradType w) const { + if (stats.GetHess() <= 0) { + return .0f; + } + // Avoiding tree::CalcGainGivenWeight can significantly reduce avg floating point error. + if (param.max_delta_step == 0.0f && has_constraint == false) { + return this->Sqr(this->ThresholdL1OneAPI(stats.sum_grad, param.reg_alpha)) / + (stats.sum_hess + param.reg_lambda); + } + return this->CalcGainGivenWeight(stats.sum_grad, stats.sum_hess, w); + } + + GradType CalcGain(bst_node_t nid, const GradStatsOneAPI& stats) const { + return this->CalcGainGivenWeight(nid, stats, this->CalcWeight(nid, stats)); + } + }; + + public: + /* Get a view to the evaluator that can be passed down to device. 
*/ + auto GetEvaluator() { + return SplitEvaluator{monotone_.Data(), + lower_bounds_.Data(), + upper_bounds_.Data(), + has_constraint_, + param_}; + } + + void AddSplit(bst_node_t nodeid, bst_node_t leftid, bst_node_t rightid, + bst_feature_t f, GradType left_weight, GradType right_weight) { + if (!has_constraint_) { + return; + } + GradType* lower = lower_bounds_.Data(); + GradType* upper = upper_bounds_.Data(); + int* monotone = monotone_.Data(); + qu_.submit([&](sycl::handler& cgh) { + cgh.parallel_for<>(sycl::range<1>(1), [=](sycl::item<1> pid) { + lower[leftid] = lower[nodeid]; + upper[leftid] = upper[nodeid]; + + lower[rightid] = lower[nodeid]; + upper[rightid] = upper[nodeid]; + int32_t c = monotone[f]; + GradType mid = (left_weight + right_weight) / 2; + + if (c < 0) { + lower[leftid] = mid; + upper[rightid] = mid; + } else if (c > 0) { + upper[leftid] = mid; + lower[rightid] = mid; + } + }); + }).wait(); + } +}; +} // namespace tree +} // namespace xgboost + +#endif // XGBOOST_TREE_SPLIT_EVALUATOR_ONEAPI_H_ diff --git a/plugin/updater_oneapi/updater_quantile_hist_oneapi.cc b/plugin/updater_oneapi/updater_quantile_hist_oneapi.cc new file mode 100644 index 000000000000..579b903cb54c --- /dev/null +++ b/plugin/updater_oneapi/updater_quantile_hist_oneapi.cc @@ -0,0 +1,1501 @@ +/*! + * Copyright 2017-2023 by Contributors + * \file updater_quantile_hist_oneapi.cc + */ +#include +#include + +#include "xgboost/logging.h" +#include "xgboost/tree_updater.h" + +#include "updater_quantile_hist_oneapi.h" +#include "data_oneapi.h" + +namespace xgboost { +namespace tree { + +using sycl::ext::oneapi::plus; +using sycl::ext::oneapi::minimum; +using sycl::ext::oneapi::maximum; + +DMLC_REGISTRY_FILE_TAG(updater_quantile_hist_oneapi); + +DMLC_REGISTER_PARAMETER(OneAPIHistMakerTrainParam); + +void QuantileHistMakerOneAPI::Configure(const Args& args) { + const DeviceOrd device_spec = ctx_->Device(); + + sycl::device device = device_manager.GetDevice(device_spec); + bool is_cpu = device.is_cpu(); + LOG(INFO) << "device = " << device_spec.Name() << ", is_cpu = " << int(is_cpu); + + if (is_cpu) + { + updater_backend_.reset(TreeUpdater::Create("grow_quantile_histmaker", ctx_, task_)); + updater_backend_->Configure(args); + } + else + { + updater_backend_.reset(TreeUpdater::Create("grow_quantile_histmaker_oneapi_backend", ctx_, task_)); + updater_backend_->Configure(args); + } +} + +void QuantileHistMakerOneAPI::Update(TrainParam const *param, + HostDeviceVector *gpair, + DMatrix *dmat, + common::Span> out_position, + const std::vector &trees) { + updater_backend_->Update(param, gpair, dmat, out_position, trees); +} + +bool QuantileHistMakerOneAPI::UpdatePredictionCache( + const DMatrix* data, + linalg::MatrixView out_preds) { + return updater_backend_->UpdatePredictionCache(data, out_preds); +} + +void QuantileHistMakerOneAPIBackend::Configure(const Args& args) { + const DeviceOrd device_spec = ctx_->Device(); + qu_ = device_manager.GetQueue(device_spec); + + // initialize pruner + if (!pruner_) { + pruner_.reset(TreeUpdater::Create("prune", ctx_, task_)); + } + pruner_->Configure(args); + param_.UpdateAllowUnknown(args); + hist_maker_param_.UpdateAllowUnknown(args); +} + +template +void QuantileHistMakerOneAPIBackend::SetBuilder(std::unique_ptr>* builder, + DMatrix *dmat) { + builder->reset(new Builder( + qu_, + param_, + std::move(pruner_), + int_constraint_, dmat)); + if (rabit::IsDistributed()) { + (*builder)->SetHistSynchronizer(new DistributedHistSynchronizerOneAPI()); + 
(*builder)->SetHistRowsAdder(new DistributedHistRowsAdderOneAPI()); + } else { + (*builder)->SetHistSynchronizer(new BatchHistSynchronizerOneAPI()); + (*builder)->SetHistRowsAdder(new BatchHistRowsAdderOneAPI()); + } +} + +template +void QuantileHistMakerOneAPIBackend::CallBuilderUpdate(const std::unique_ptr>& builder, + TrainParam const *param, + HostDeviceVector *gpair, + DMatrix *dmat, + common::Span> out_position, + const std::vector &trees) { + const std::vector& gpair_h = gpair->ConstHostVector(); + USMVector gpair_device(qu_, gpair_h); + for (auto tree : trees) { + builder->Update(ctx_, param, gmat_, gpair, gpair_device, dmat, out_position, tree); + } +} +void QuantileHistMakerOneAPIBackend::Update(TrainParam const *param, + HostDeviceVector *gpair, + DMatrix *dmat, + common::Span> out_position, + const std::vector &trees) { + if (dmat != p_last_dmat_ || is_gmat_initialized_ == false) { + updater_monitor_.Start("GmatInitialization"); + DeviceMatrixOneAPI dmat_device(qu_, dmat); + gmat_.Init(qu_, ctx_, dmat_device, static_cast(param_.max_bin)); + updater_monitor_.Stop("GmatInitialization"); + is_gmat_initialized_ = true; + } + // rescale learning rate according to size of trees + float lr = param_.learning_rate; + param_.learning_rate = lr / trees.size(); + int_constraint_.Configure(param_, dmat->Info().num_col_); + // build tree + bool has_double_support = qu_.get_device().has(sycl::aspect::fp64); + if (hist_maker_param_.single_precision_histogram || !has_double_support) { + if (!hist_maker_param_.single_precision_histogram) { + LOG(WARNING) << "Target device doesn't support fp64, using single_precision_histogram=True"; + } + if (!float_builder_) { + SetBuilder(&float_builder_, dmat); + } + CallBuilderUpdate(float_builder_, param, gpair, dmat, out_position, trees); + } else { + if (!double_builder_) { + SetBuilder(&double_builder_, dmat); + } + CallBuilderUpdate(double_builder_, param, gpair, dmat, out_position, trees); + } + + param_.learning_rate = lr; + + p_last_dmat_ = dmat; +} + +bool QuantileHistMakerOneAPIBackend::UpdatePredictionCache(const DMatrix* data, + linalg::MatrixView out_preds) { + if (param_.subsample < 1.0f) { + return false; + } else { + bool has_double_support = qu_.get_device().has(sycl::aspect::fp64); + if ((hist_maker_param_.single_precision_histogram || !has_double_support) && float_builder_) { + return float_builder_->UpdatePredictionCache(data, out_preds); + } else if (double_builder_) { + return double_builder_->UpdatePredictionCache(data, out_preds); + } else { + return false; + } + } +} + +template +void BatchHistSynchronizerOneAPI::SyncHistograms(BuilderT *builder, + std::vector& sync_ids, + RegTree *p_tree) { + builder->builder_monitor_.Start("SyncHistograms"); + const size_t nbins = builder->hist_builder_.GetNumBins(); + + hist_sync_events_.resize(builder->nodes_for_explicit_hist_build_.size()); + for (int i = 0; i < builder->nodes_for_explicit_hist_build_.size(); i++) { + const auto entry = builder->nodes_for_explicit_hist_build_[i]; + auto this_hist = builder->hist_[entry.nid]; + + if (!(*p_tree)[entry.nid].IsRoot() && entry.sibling_nid > -1) { + const size_t parent_id = (*p_tree)[entry.nid].Parent(); + auto parent_hist = builder->hist_[parent_id]; + auto sibling_hist = builder->hist_[entry.sibling_nid]; + hist_sync_events_[i] = common::SubtractionHist(builder->qu_, sibling_hist, parent_hist, this_hist, nbins, sycl::event()); + } + } + builder->qu_.wait_and_throw(); + + builder->builder_monitor_.Stop("SyncHistograms"); +} + +template +void 
DistributedHistSynchronizerOneAPI::SyncHistograms(BuilderT* builder, + std::vector& sync_ids, + RegTree *p_tree) { + builder->builder_monitor_.Start("SyncHistograms"); + const size_t nbins = builder->hist_builder_.GetNumBins(); + for (int node = 0; node < builder->nodes_for_explicit_hist_build_.size(); node++) { + const auto entry = builder->nodes_for_explicit_hist_build_[node]; + auto this_hist = builder->hist_[entry.nid]; + // Store posible parent node + auto this_local = builder->hist_local_worker_[entry.nid]; + common::CopyHist(builder->qu_, this_local, this_hist, nbins); + + if (!(*p_tree)[entry.nid].IsRoot() && entry.sibling_nid > -1) { + const size_t parent_id = (*p_tree)[entry.nid].Parent(); + auto parent_hist = builder->hist_local_worker_[parent_id]; + auto sibling_hist = builder->hist_[entry.sibling_nid]; + common::SubtractionHist(builder->qu_, sibling_hist, parent_hist, this_hist, nbins, sycl::event()); + // Store posible parent node + auto sibling_local = builder->hist_local_worker_[entry.sibling_nid]; + common::CopyHist(builder->qu_, sibling_local, sibling_hist, nbins); + } + } + builder->ReduceHists(sync_ids, nbins); + + ParallelSubtractionHist(builder, builder->nodes_for_explicit_hist_build_, p_tree); + ParallelSubtractionHist(builder, builder->nodes_for_subtraction_trick_, p_tree); + + builder->builder_monitor_.Stop("SyncHistograms"); +} + +template +void DistributedHistSynchronizerOneAPI::ParallelSubtractionHist( + BuilderT* builder, + const std::vector& nodes, + const RegTree * p_tree) { + const size_t nbins = builder->hist_builder_.GetNumBins(); + for (int node = 0; node < nodes.size(); node++) { + const auto entry = nodes[node]; + if (!((*p_tree)[entry.nid].IsLeftChild())) { + auto this_hist = builder->hist_[entry.nid]; + + if (!(*p_tree)[entry.nid].IsRoot() && entry.sibling_nid > -1) { + auto parent_hist = builder->hist_[(*p_tree)[entry.nid].Parent()]; + auto sibling_hist = builder->hist_[entry.sibling_nid]; + common::SubtractionHist(builder->qu_, this_hist, parent_hist, sibling_hist, nbins, sycl::event()); + } + } + } +} + +template +void QuantileHistMakerOneAPIBackend::Builder::ReduceHists(std::vector& sync_ids, size_t nbins) { + std::vector reduce_buffer(sync_ids.size() * nbins); + for (size_t i = 0; i < sync_ids.size(); i++) { + auto this_hist = hist_[sync_ids[i]]; + const GradientPairT* psrc = reinterpret_cast(this_hist.DataConst()); + std::copy(psrc, psrc + nbins, reduce_buffer.begin() + i * nbins); + } + collective::Allreduce( + reinterpret_cast(reduce_buffer.data()), + 2 * nbins * sync_ids.size()); + // histred_.Allreduce(reduce_buffer.data(), nbins * sync_ids.size()); + for (size_t i = 0; i < sync_ids.size(); i++) { + auto this_hist = hist_[sync_ids[i]]; + GradientPairT* psrc = reinterpret_cast(this_hist.Data()); + std::copy(reduce_buffer.begin() + i * nbins, reduce_buffer.begin() + (i + 1) * nbins, psrc); + } +} + +template +void BatchHistRowsAdderOneAPI::AddHistRows(BuilderT *builder, + std::vector& sync_ids, + RegTree *p_tree) { + builder->builder_monitor_.Start("AddHistRows"); + + int max_nid = 0; + for (auto const& entry : builder->nodes_for_explicit_hist_build_) { + int nid = entry.nid; + max_nid = nid > max_nid ? nid : max_nid; + } + for (auto const& node : builder->nodes_for_subtraction_trick_) { + max_nid = node.nid > max_nid ? 
node.nid : max_nid; + } + + builder->hist_.Reserve(max_nid); + for (auto const& entry : builder->nodes_for_explicit_hist_build_) { + int nid = entry.nid; + auto event = builder->hist_.AddHistRow(nid); + } + for (auto const& node : builder->nodes_for_subtraction_trick_) { + auto event = builder->hist_.AddHistRow(node.nid); + } + builder->hist_.Wait_and_throw(); + + builder->builder_monitor_.Stop("AddHistRows"); +} + +template +void DistributedHistRowsAdderOneAPI::AddHistRows(BuilderT *builder, + std::vector& sync_ids, + RegTree *p_tree) { + builder->builder_monitor_.Start("AddHistRows"); + const size_t explicit_size = builder->nodes_for_explicit_hist_build_.size(); + const size_t subtraction_size = builder->nodes_for_subtraction_trick_.size(); + std::vector merged_node_ids(explicit_size + subtraction_size); + for (size_t i = 0; i < explicit_size; ++i) { + merged_node_ids[i] = builder->nodes_for_explicit_hist_build_[i].nid; + } + for (size_t i = 0; i < subtraction_size; ++i) { + merged_node_ids[explicit_size + i] = + builder->nodes_for_subtraction_trick_[i].nid; + } + std::sort(merged_node_ids.begin(), merged_node_ids.end()); + sync_ids.clear(); + for (auto const& nid : merged_node_ids) { + if ((*p_tree)[nid].IsLeftChild()) { + builder->hist_.AddHistRow(nid); + builder->hist_local_worker_.AddHistRow(nid); + sync_ids.push_back(nid); + } + } + for (auto const& nid : merged_node_ids) { + if (!((*p_tree)[nid].IsLeftChild())) { + builder->hist_.AddHistRow(nid); + builder->hist_local_worker_.AddHistRow(nid); + } + } + builder->builder_monitor_.Stop("AddHistRows"); +} + +template +void QuantileHistMakerOneAPIBackend::Builder::SetHistSynchronizer( + HistSynchronizerOneAPI *sync) { + hist_synchronizer_.reset(sync); +} + +template +void QuantileHistMakerOneAPIBackend::Builder::SetHistRowsAdder( + HistRowsAdderOneAPI *adder) { + hist_rows_adder_.reset(adder); +} + +template +void QuantileHistMakerOneAPIBackend::Builder::BuildHistogramsLossGuide( + ExpandEntry entry, + const GHistIndexMatrixOneAPI &gmat, + RegTree *p_tree, + const USMVector &gpair_device) { + nodes_for_explicit_hist_build_.clear(); + nodes_for_subtraction_trick_.clear(); + nodes_for_explicit_hist_build_.push_back(entry); + + if (entry.sibling_nid > -1) { + nodes_for_subtraction_trick_.emplace_back(entry.sibling_nid, entry.nid, + p_tree->GetDepth(entry.sibling_nid), 0.0f, 0); + } + + std::vector sync_ids; + + hist_rows_adder_->AddHistRows(this, sync_ids, p_tree); + BuildLocalHistograms(gmat, p_tree, gpair_device); + hist_synchronizer_->SyncHistograms(this, sync_ids, p_tree); +} + +template +void QuantileHistMakerOneAPIBackend::Builder::BuildLocalHistograms( + const GHistIndexMatrixOneAPI &gmat, + RegTree *p_tree, + const USMVector &gpair_device) { + builder_monitor_.Start("BuildLocalHistogramsOneAPI"); + const size_t n_nodes = nodes_for_explicit_hist_build_.size(); + for (auto& event : hist_build_events_) { + event = sycl::event(); + } + + size_t event_idx = 0; + for (size_t i = 0; i < n_nodes; i++) { + const int32_t nid = nodes_for_explicit_hist_build_[i].nid; + + if (row_set_collection_[nid].Size() > 0) { + event_idx = (event_idx + 1) % kNumParallelBuffers; + auto& event = hist_build_events_[event_idx]; + auto& hist_buff = hist_buffers_[event_idx]; + + event = BuildHist(gpair_device, row_set_collection_[nid], gmat, hist_[nid], hist_buff.GetDeviceBuffer(), event); + } else { + common::InitHist(qu_, hist_[nid], hist_[nid].Size()); + } + } + qu_.wait_and_throw(); + builder_monitor_.Stop("BuildLocalHistogramsOneAPI"); +} + 
+template +void QuantileHistMakerOneAPIBackend::Builder::BuildNodeStats( + const GHistIndexMatrixOneAPI &gmat, + DMatrix *p_fmat, + RegTree *p_tree, + const std::vector &gpair) { + builder_monitor_.Start("BuildNodeStats"); + for (auto const& entry : qexpand_depth_wise_) { + int nid = entry.nid; + this->InitNewNode(nid, gmat, gpair, *p_fmat, *p_tree); + // add constraints + if (!(*p_tree)[nid].IsLeftChild() && !(*p_tree)[nid].IsRoot()) { + // it's a right child + auto parent_id = (*p_tree)[nid].Parent(); + auto left_sibling_id = (*p_tree)[parent_id].LeftChild(); + auto parent_split_feature_id = snode_[parent_id].best.SplitIndex(); + tree_evaluator_.AddSplit( + parent_id, left_sibling_id, nid, parent_split_feature_id, + snode_[left_sibling_id].weight, snode_[nid].weight); + interaction_constraints_.Split(parent_id, parent_split_feature_id, + left_sibling_id, nid); + } + } + builder_monitor_.Stop("BuildNodeStats"); +} + +template +void QuantileHistMakerOneAPIBackend::Builder::AddSplitsToTree( + const GHistIndexMatrixOneAPI &gmat, + RegTree *p_tree, + int *num_leaves, + int depth, + unsigned *timestamp, + std::vector* nodes_for_apply_split, + std::vector* temp_qexpand_depth) { + auto evaluator = tree_evaluator_.GetEvaluator(); + for (auto const& entry : qexpand_depth_wise_) { + int nid = entry.nid; + + if (snode_[nid].best.loss_chg < kRtEps || + (param_.max_depth > 0 && depth == param_.max_depth) || + (param_.max_leaves > 0 && (*num_leaves) == param_.max_leaves)) { + (*p_tree)[nid].SetLeaf(snode_[nid].weight * param_.learning_rate); + } else { + nodes_for_apply_split->push_back(entry); + + NodeEntry& e = snode_[nid]; + bst_float left_leaf_weight = + evaluator.CalcWeight(nid, GradStatsOneAPI{e.best.left_sum}) * param_.learning_rate; + bst_float right_leaf_weight = + evaluator.CalcWeight(nid, GradStatsOneAPI{e.best.right_sum}) * param_.learning_rate; + p_tree->ExpandNode(nid, e.best.SplitIndex(), e.best.split_value, + e.best.DefaultLeft(), e.weight, left_leaf_weight, + right_leaf_weight, e.best.loss_chg, e.stats.GetHess(), + e.best.left_sum.GetHess(), e.best.right_sum.GetHess()); + + int left_id = (*p_tree)[nid].LeftChild(); + int right_id = (*p_tree)[nid].RightChild(); + temp_qexpand_depth->push_back(ExpandEntry(left_id, right_id, + p_tree->GetDepth(left_id), 0.0, (*timestamp)++)); + temp_qexpand_depth->push_back(ExpandEntry(right_id, left_id, + p_tree->GetDepth(right_id), 0.0, (*timestamp)++)); + // - 1 parent + 2 new children + (*num_leaves)++; + } + } +} + +template +void QuantileHistMakerOneAPIBackend::Builder::EvaluateAndApplySplits( + const GHistIndexMatrixOneAPI &gmat, + RegTree *p_tree, + int *num_leaves, + int depth, + unsigned *timestamp, + std::vector *temp_qexpand_depth) { + EvaluateSplits(qexpand_depth_wise_, gmat, hist_, *p_tree); + + std::vector nodes_for_apply_split; + AddSplitsToTree(gmat, p_tree, num_leaves, depth, timestamp, + &nodes_for_apply_split, temp_qexpand_depth); + ApplySplit(nodes_for_apply_split, gmat, hist_, p_tree); +} + +// Split nodes to 2 sets depending on amount of rows in each node +// Histograms for small nodes will be built explicitly +// Histograms for big nodes will be built by 'Subtraction Trick' +// Exception: in distributed setting, we always build the histogram for the left child node +// and use 'Subtraction Trick' to built the histogram for the right child node. +// This ensures that the workers operate on the same set of tree nodes. 
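+// A schematic illustration of the 'Subtraction Trick' mentioned above (a minimal
+// sketch with hypothetical names, not part of this updater): once the histogram of
+// the smaller sibling has been built explicitly, the larger sibling's histogram is
+// obtained bin by bin as parent minus small child, avoiding a second pass over the rows:
+//
+//   void SubtractionTrickSketch(const GradientPair* parent, const GradientPair* small_child,
+//                               GradientPair* big_child, size_t nbins) {
+//     for (size_t i = 0; i < nbins; ++i) {
+//       big_child[i] = GradientPair(parent[i].GetGrad() - small_child[i].GetGrad(),
+//                                   parent[i].GetHess() - small_child[i].GetHess());
+//     }
+//   }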
+template +void QuantileHistMakerOneAPIBackend::Builder::SplitSiblings( + const std::vector &nodes, + std::vector *small_siblings, + std::vector *big_siblings, + RegTree *p_tree) { + builder_monitor_.Start("SplitSiblings"); + for (auto const& entry : nodes) { + int nid = entry.nid; + RegTree::Node &node = (*p_tree)[nid]; + if (node.IsRoot()) { + small_siblings->push_back(entry); + } else { + const int32_t left_id = (*p_tree)[node.Parent()].LeftChild(); + const int32_t right_id = (*p_tree)[node.Parent()].RightChild(); + + if (nid == left_id && row_set_collection_[left_id ].Size() < + row_set_collection_[right_id].Size()) { + small_siblings->push_back(entry); + } else if (nid == right_id && row_set_collection_[right_id].Size() <= + row_set_collection_[left_id ].Size()) { + small_siblings->push_back(entry); + } else { + big_siblings->push_back(entry); + } + } + } + builder_monitor_.Stop("SplitSiblings"); +} + +template +void QuantileHistMakerOneAPIBackend::Builder::ExpandWithDepthWise( + const GHistIndexMatrixOneAPI &gmat, + DMatrix *p_fmat, + RegTree *p_tree, + const std::vector &gpair, + const USMVector &gpair_device) { + unsigned timestamp = 0; + int num_leaves = 0; + + // in depth_wise growing, we feed loss_chg with 0.0 since it is not used anyway + qexpand_depth_wise_.emplace_back(ExpandEntry(ExpandEntry::kRootNid, ExpandEntry::kEmptyNid, + p_tree->GetDepth(ExpandEntry::kRootNid), 0.0, timestamp++)); + ++num_leaves; + for (int depth = 0; depth < param_.max_depth + 1; depth++) { + std::vector sync_ids; + std::vector temp_qexpand_depth; + SplitSiblings(qexpand_depth_wise_, &nodes_for_explicit_hist_build_, + &nodes_for_subtraction_trick_, p_tree); + hist_rows_adder_->AddHistRows(this, sync_ids, p_tree); + BuildLocalHistograms(gmat, p_tree, gpair_device); + hist_synchronizer_->SyncHistograms(this, sync_ids, p_tree); + BuildNodeStats(gmat, p_fmat, p_tree, gpair); + + EvaluateAndApplySplits(gmat, p_tree, &num_leaves, depth, ×tamp, + &temp_qexpand_depth); + + // clean up + qexpand_depth_wise_.clear(); + nodes_for_subtraction_trick_.clear(); + nodes_for_explicit_hist_build_.clear(); + if (temp_qexpand_depth.empty()) { + break; + } else { + qexpand_depth_wise_ = temp_qexpand_depth; + temp_qexpand_depth.clear(); + } + } +} + +template +void QuantileHistMakerOneAPIBackend::Builder::ExpandWithLossGuide( + const GHistIndexMatrixOneAPI& gmat, + DMatrix* p_fmat, + RegTree* p_tree, + const std::vector &gpair, + const USMVector &gpair_device) { + builder_monitor_.Start("ExpandWithLossGuide"); + unsigned timestamp = 0; + int num_leaves = 0; + + ExpandEntry node(ExpandEntry::kRootNid, ExpandEntry::kEmptyNid, + p_tree->GetDepth(0), 0.0f, timestamp++); + BuildHistogramsLossGuide(node, gmat, p_tree, gpair_device); + + this->InitNewNode(ExpandEntry::kRootNid, gmat, gpair, *p_fmat, *p_tree); + + this->EvaluateSplits({node}, gmat, hist_, *p_tree); + node.loss_chg = snode_[ExpandEntry::kRootNid].best.loss_chg; + + qexpand_loss_guided_->push(node); + ++num_leaves; + + while (!qexpand_loss_guided_->empty()) { + const ExpandEntry candidate = qexpand_loss_guided_->top(); + const int nid = candidate.nid; + qexpand_loss_guided_->pop(); + if (candidate.IsValid(param_, num_leaves)) { + (*p_tree)[nid].SetLeaf(snode_[nid].weight * param_.learning_rate); + } else { + auto evaluator = tree_evaluator_.GetEvaluator(); + NodeEntry& e = snode_[nid]; + bst_float left_leaf_weight = + evaluator.CalcWeight(nid, GradStatsOneAPI{e.best.left_sum}) * param_.learning_rate; + bst_float right_leaf_weight = + evaluator.CalcWeight(nid, 
GradStatsOneAPI{e.best.right_sum}) * param_.learning_rate; + p_tree->ExpandNode(nid, e.best.SplitIndex(), e.best.split_value, + e.best.DefaultLeft(), e.weight, left_leaf_weight, + right_leaf_weight, e.best.loss_chg, e.stats.GetHess(), + e.best.left_sum.GetHess(), e.best.right_sum.GetHess()); + + this->ApplySplit({candidate}, gmat, hist_, p_tree); + + const int cleft = (*p_tree)[nid].LeftChild(); + const int cright = (*p_tree)[nid].RightChild(); + + ExpandEntry left_node(cleft, cright, p_tree->GetDepth(cleft), + 0.0f, timestamp++); + ExpandEntry right_node(cright, cleft, p_tree->GetDepth(cright), + 0.0f, timestamp++); + + if (row_set_collection_[cleft].Size() < row_set_collection_[cright].Size()) { + BuildHistogramsLossGuide(left_node, gmat, p_tree, gpair_device); + } else { + BuildHistogramsLossGuide(right_node, gmat, p_tree, gpair_device); + } + + this->InitNewNode(cleft, gmat, gpair, *p_fmat, *p_tree); + this->InitNewNode(cright, gmat, gpair, *p_fmat, *p_tree); + bst_uint featureid = snode_[nid].best.SplitIndex(); + tree_evaluator_.AddSplit(nid, cleft, cright, featureid, + snode_[cleft].weight, snode_[cright].weight); + interaction_constraints_.Split(nid, featureid, cleft, cright); + + this->EvaluateSplits({left_node, right_node}, gmat, hist_, *p_tree); + left_node.loss_chg = snode_[cleft].best.loss_chg; + right_node.loss_chg = snode_[cright].best.loss_chg; + + qexpand_loss_guided_->push(left_node); + qexpand_loss_guided_->push(right_node); + + ++num_leaves; // give two and take one, as parent is no longer a leaf + } + } + builder_monitor_.Stop("ExpandWithLossGuide"); +} + +template +void QuantileHistMakerOneAPIBackend::Builder::Update( + Context const * ctx, + TrainParam const *param, + const GHistIndexMatrixOneAPI &gmat, + HostDeviceVector *gpair, + const USMVector& gpair_device, + DMatrix *p_fmat, + common::Span> out_position, + RegTree *p_tree) { + builder_monitor_.Start("Update"); + + const std::vector& gpair_h = gpair->ConstHostVector(); + tree_evaluator_ = TreeEvaluatorOneAPI(qu_, param_, p_fmat->Info().num_col_); + interaction_constraints_.Reset(); + + this->InitData(ctx, gmat, gpair_h, gpair_device, *p_fmat, *p_tree); + if (param_.grow_policy == TrainParam::kLossGuide) { + ExpandWithLossGuide(gmat, p_fmat, p_tree, gpair_h, gpair_device); + } else { + ExpandWithDepthWise(gmat, p_fmat, p_tree, gpair_h, gpair_device); + } + + for (int nid = 0; nid < p_tree->NumNodes(); ++nid) { + p_tree->Stat(nid).loss_chg = snode_[nid].best.loss_chg; + p_tree->Stat(nid).base_weight = snode_[nid].weight; + p_tree->Stat(nid).sum_hess = static_cast(snode_[nid].stats.GetHess()); + } + pruner_->Update(param, gpair, p_fmat, out_position, std::vector{p_tree}); + + builder_monitor_.Stop("Update"); +} + +template +bool QuantileHistMakerOneAPIBackend::Builder::UpdatePredictionCache( + const DMatrix* data, + linalg::MatrixView out_preds) { + // p_last_fmat_ is a valid pointer as long as UpdatePredictionCache() is called in + // conjunction with Update(). 
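+  // What the cache update below amounts to, in pseudocode (a minimal sketch;
+  // leaf_node_ids and rows_in_leaf are hypothetical placeholders for the row-set
+  // bookkeeping kept by this builder):
+  //
+  //   for (int nid : leaf_node_ids)            // leaves of the last grown tree
+  //     for (size_t row : rows_in_leaf[nid])
+  //       out_preds(row) += tree[nid].LeafValue();
+  //
+  // so a full re-prediction of the training data is not needed after each iteration.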
+ if (!p_last_fmat_ || !p_last_tree_ || data != p_last_fmat_) { + return false; + } + builder_monitor_.Start("UpdatePredictionCache"); + CHECK_GT(out_preds.Size(), 0U); + + const size_t stride = out_preds.Stride(0); + const int buffer_size = out_preds.Size()*stride - stride + 1; + sycl::buffer out_preds_buf(&out_preds(0), buffer_size); + + size_t n_nodes = row_set_collection_.Size(); + for (size_t node = 0; node < n_nodes; node++) { + const RowSetCollectionOneAPI::Elem& rowset = row_set_collection_[node]; + if (rowset.begin != nullptr && rowset.end != nullptr && rowset.Size() != 0) { + int nid = rowset.node_id; + bst_float leaf_value; + // if a node is marked as deleted by the pruner, traverse upward to locate + // a non-deleted leaf. + if ((*p_last_tree_)[nid].IsDeleted()) { + while ((*p_last_tree_)[nid].IsDeleted()) { + nid = (*p_last_tree_)[nid].Parent(); + } + CHECK((*p_last_tree_)[nid].IsLeaf()); + } + leaf_value = (*p_last_tree_)[nid].LeafValue(); + + const size_t* rid = rowset.begin; + const size_t num_rows = rowset.Size(); + + qu_.submit([&](sycl::handler& cgh) { + auto out_predictions = out_preds_buf.template get_access(cgh); + cgh.parallel_for<>(sycl::range<1>(num_rows), [=](sycl::item<1> pid) { + out_predictions[rid[pid.get_id(0)]*stride] += leaf_value; + }); + }).wait(); + } + } + + builder_monitor_.Stop("UpdatePredictionCache"); + return true; +} +template +void QuantileHistMakerOneAPIBackend::Builder::InitSampling(const std::vector& gpair, + const USMVector &gpair_device, + const DMatrix& fmat, + USMVector& row_indices_device) { + const auto& info = fmat.Info(); + auto& rnd = common::GlobalRandom(); +#if XGBOOST_CUSTOMIZE_GLOBAL_PRNG + std::bernoulli_distribution coin_flip(param_.subsample); + size_t j = 0; + + std::vector row_indices(row_indices_device.Size()); + qu_.memcpy(row_indices.data(), row_indices_device.DataConst(), row_indices.size() * sizeof(size_t)).wait(); + for (size_t i = 0; i < info.num_row_; ++i) { + if (gpair[i].GetHess() >= 0.0f && coin_flip(rnd)) { + row_indices[j++] = i; + } + } + qu_.memcpy(row_indices_device.Data(), row_indices.data(), row_indices.size() * sizeof(size_t)).wait(); + /* resize row_indices to reduce memory */ + row_indices_device.Resize(qu_, j); +#else + const size_t nthread = this->nthread_; + std::vector row_offsets(nthread, 0); + /* usage of mt19937_64 give 2x speed up for subsampling */ + std::vector rnds(nthread); + /* create engine for each thread */ + for (std::mt19937& r : rnds) { + r = rnd; + } + + std::vector row_indices(row_indices_device.Size()); + qu_.memcpy(row_indices.data(), row_indices_device.DataConst(), row_indices.size() * sizeof(size_t)).wait(); + const size_t discard_size = info.num_row_ / nthread; + #pragma omp parallel num_threads(nthread) + { + const size_t tid = omp_get_thread_num(); + const size_t ibegin = tid * discard_size; + const size_t iend = (tid == (nthread - 1)) ? 
+ info.num_row_ : ibegin + discard_size; + std::bernoulli_distribution coin_flip(param_.subsample); + + rnds[tid].discard(2*discard_size * tid); + for (size_t i = ibegin; i < iend; ++i) { + if (gpair[i].GetHess() >= 0.0f && coin_flip(rnds[tid])) { + row_indices[ibegin + row_offsets[tid]++] = i; + } + } + } + + /* discard global engine */ + rnd = rnds[nthread - 1]; + size_t prefix_sum = row_offsets[0]; + for (size_t i = 1; i < nthread; ++i) { + const size_t ibegin = i * discard_size; + + for (size_t k = 0; k < row_offsets[i]; ++k) { + row_indices[prefix_sum + k] = row_indices[ibegin + k]; + } + prefix_sum += row_offsets[i]; + } + qu_.memcpy(row_indices_device.Data(), row_indices.data(), row_indices.size() * sizeof(size_t)).wait(); + row_indices_device.Resize(qu_, prefix_sum); + + /* + const size_t size = info.num_row_; + const size_t min_block_size = 128; + const size_t nblocks = size / min_block_size + !!(size % min_block_size); + const size_t block_size = size / nblocks + !!(size % nblocks); + + std::vector rnds(nblocks); + std::bernoulli_distribution coin_flip(param_.subsample); + std::vector coin_flips(nblocks); + + #pragma omp parallel for + for (size_t block = 0; block < nblocks; ++block) { + rnds[block] = rnd; + rnds[block].discard(2 * block_size * block); + coin_flips[block] = coin_flip(rnds[block]); + } + rnd = rnds[nblocks - 1]; + + USMVector coin_flips_device(qu_, coin_flips); + USMVector row_offsets(qu_, nblocks, 0); + size_t* offsets_ptr = row_offsets.Data(); + size_t* indices_ptr = row_indices.Data(); + const GradientPair* gpair_ptr = gpair_device.DataConst(); + const uint8_t* coin_flips_ptr = coin_flips_device.DataConst(); + std::vector events; + events.emplace_back(qu_.submit([&](sycl::handler& cgh) { + cgh.parallel_for<>(sycl::range<1>(sycl::range<1>(nblocks)), + [offsets_ptr, indices_ptr, coin_flips_ptr, block_size, size, gpair_ptr](sycl::item<1> pid) { + const size_t block = pid.get_id(0); + + size_t start = block * block_size; + size_t end = (block + 1) * block_size; + if (end > size) { + end = size; + } + for (size_t i = start; i < end; ++i) { + if (gpair_ptr[i].GetHess() >= 0.0f && coin_flips_ptr[block]) { + indices_ptr[start + offsets_ptr[block]++] = i; + } + } + }); + })); + + size_t prefix_sum = row_indices.Get(qu_, 0, &events); + for (size_t i = 1; i < nblocks; ++i) { + const size_t ibegin = i * block_size; + const size_t idx = row_indices.Get(qu_, i, &events); + qu_.submit([&](sycl::handler& cgh) { + cgh.depends_on(events); + cgh.parallel_for<>(sycl::range<1>(sycl::range<1>(idx)), + [indices_ptr, prefix_sum, ibegin](sycl::item<1> pid) { + const size_t k = pid.get_id(0); + indices_ptr[prefix_sum + k] = indices_ptr[ibegin + k]; + }); + prefix_sum += row_indices.Get(qu_, i, &events); + }).wait_and_throw(); + } + */ + /* resize row_indices to reduce memory */ + // row_indices.Resize(qu_, prefix_sum); + +#endif // XGBOOST_CUSTOMIZE_GLOBAL_PRNG +} +template +void QuantileHistMakerOneAPIBackend::Builder::InitData( + Context const * ctx, + const GHistIndexMatrixOneAPI& gmat, + const std::vector& gpair, + const USMVector &gpair_device, + const DMatrix& fmat, + const RegTree& tree) { + CHECK((param_.max_depth > 0 || param_.max_leaves > 0)) + << "max_depth or max_leaves cannot be both 0 (unlimited); " + << "at least one should be a positive quantity."; + if (param_.grow_policy == TrainParam::kDepthWise) { + CHECK(param_.max_depth > 0) << "max_depth cannot be 0 (unlimited) " + << "when grow_policy is depthwise."; + } + builder_monitor_.Start("InitData"); + const auto& info = 
fmat.Info(); + + { + // initialize the row set + row_set_collection_.Clear(); + // initialize histogram collection + uint32_t nbins = gmat.cut.Ptrs().back(); + hist_.Init(qu_, nbins); + hist_local_worker_.Init(qu_, nbins); + for (auto& buffer : hist_buffers_) { + buffer.Init(qu_, nbins); + size_t buffer_size = 2048; + if (buffer_size > info.num_row_ / 128 + 1) { + buffer_size = info.num_row_ / 128 + 1; + } + buffer.Reset(buffer_size); + // buffer.Reset(2048); + } + + // initialize histogram builder +#pragma omp parallel + { + this->nthread_ = omp_get_num_threads(); + } + hist_builder_ = GHistBuilderOneAPI(qu_, nbins); + + USMVector& row_indices = row_set_collection_.Data(); + row_indices.Resize(qu_, info.num_row_); + size_t* p_row_indices = row_indices.Data(); + // mark subsample and build list of member rows + + if (param_.subsample < 1.0f) { + CHECK_EQ(param_.sampling_method, TrainParam::kUniform) + << "Only uniform sampling is supported, " + << "gradient-based sampling is only support by GPU Hist."; + InitSampling(gpair, gpair_device, fmat, row_indices); + } else { + MemStackAllocatorOneAPI buff(this->nthread_); + bool* p_buff = buff.Get(); + std::fill(p_buff, p_buff + this->nthread_, false); + + const size_t block_size = info.num_row_ / this->nthread_ + !!(info.num_row_ % this->nthread_); + + #pragma omp parallel num_threads(this->nthread_) + { + const size_t tid = omp_get_thread_num(); + const size_t ibegin = tid * block_size; + const size_t iend = std::min(static_cast(ibegin + block_size), + static_cast(info.num_row_)); + + for (size_t i = ibegin; i < iend; ++i) { + if (gpair[i].GetHess() < 0.0f) { + p_buff[tid] = true; + break; + } + } + } + + bool has_neg_hess = false; + for (int32_t tid = 0; tid < this->nthread_; ++tid) { + if (p_buff[tid]) { + has_neg_hess = true; + } + } + + if (has_neg_hess) { + size_t j = 0; + std::vector row_indices_buff(row_indices.Size()); + for (size_t i = 0; i < info.num_row_; ++i) { + if (gpair[i].GetHess() >= 0.0f) { + row_indices_buff[j++] = i; + } + } + qu_.memcpy(p_row_indices, row_indices_buff.data(), j * sizeof(size_t)).wait(); + row_indices.Resize(qu_, j); + } else { + qu_.submit([&](sycl::handler& cgh) { + cgh.parallel_for<>(sycl::range<1>(sycl::range<1>(info.num_row_)), + [p_row_indices](sycl::item<1> pid) { + const size_t idx = pid.get_id(0); + p_row_indices[idx] = idx; + }); + }).wait_and_throw(); + } + } + } + + row_set_collection_.Init(); + + { + /* determine layout of data */ + const size_t nrow = info.num_row_; + const size_t ncol = info.num_col_; + const size_t nnz = info.num_nonzero_; + // number of discrete bins for feature 0 + const uint32_t nbins_f0 = gmat.cut.Ptrs()[1] - gmat.cut.Ptrs()[0]; + if (nrow * ncol == nnz) { + // dense data with zero-based indexing + data_layout_ = kDenseDataZeroBased; + } else if (nbins_f0 == 0 && nrow * (ncol - 1) == nnz) { + // dense data with one-based indexing + data_layout_ = kDenseDataOneBased; + } else { + // sparse data + data_layout_ = kSparseData; + } + } + // store a pointer to the tree + p_last_tree_ = &tree; + column_sampler_.Init(ctx, info.num_col_, info.feature_weights.ConstHostVector(), + param_.colsample_bynode, param_.colsample_bylevel, + param_.colsample_bytree); + if (data_layout_ == kDenseDataZeroBased || data_layout_ == kDenseDataOneBased) { + /* specialized code for dense data: + choose the column that has a least positive number of discrete bins. 
+ For dense data (with no missing value), + the sum of gradient histogram is equal to snode[nid] */ + const std::vector& row_ptr = gmat.cut.Ptrs(); + const auto nfeature = static_cast(row_ptr.size() - 1); + uint32_t min_nbins_per_feature = 0; + for (bst_uint i = 0; i < nfeature; ++i) { + const uint32_t nbins = row_ptr[i + 1] - row_ptr[i]; + if (nbins > 0) { + if (min_nbins_per_feature == 0 || min_nbins_per_feature > nbins) { + min_nbins_per_feature = nbins; + fid_least_bins_ = i; + } + } + } + CHECK_GT(min_nbins_per_feature, 0U); + } + { + snode_.Fill(qu_, NodeEntry(param_)); + qu_.wait_and_throw(); + } + { + if (param_.grow_policy == TrainParam::kLossGuide) { + qexpand_loss_guided_.reset(new ExpandQueue(LossGuide)); + } else { + qexpand_depth_wise_.clear(); + } + } + builder_monitor_.Stop("InitData"); +} + +// if sum of statistics for non-missing values in the node +// is equal to sum of statistics for all values: +// then - there are no missing values +// else - there are missing values +template +bool QuantileHistMakerOneAPIBackend::Builder::SplitContainsMissingValues( + const GradStatsOneAPI& e, const NodeEntry& snode) { + if (e.GetGrad() == snode.stats.GetGrad() && e.GetHess() == snode.stats.GetHess()) { + return false; + } else { + return true; + } +} + +// nodes_set - set of nodes to be processed in parallel +template +void QuantileHistMakerOneAPIBackend::Builder::EvaluateSplits( + const std::vector& nodes_set, + const GHistIndexMatrixOneAPI& gmat, + const HistCollectionOneAPI& hist, + const RegTree& tree) { + builder_monitor_.Start("EvaluateSplits"); + + const size_t n_nodes_in_set = nodes_set.size(); + + using FeatureSetType = std::shared_ptr>; + std::vector features_sets(n_nodes_in_set); + + // Generate feature set for each tree node + size_t total_features = 0; + for (size_t nid_in_set = 0; nid_in_set < n_nodes_in_set; ++nid_in_set) { + const int32_t nid = nodes_set[nid_in_set].nid; + features_sets[nid_in_set] = column_sampler_.GetFeatureSet(tree.GetDepth(nid)); + for (size_t idx_in_feature_set = 0; idx_in_feature_set < features_sets[nid_in_set]->Size(); idx_in_feature_set++) { + const auto fid = features_sets[nid_in_set]->ConstHostVector()[idx_in_feature_set]; + if (interaction_constraints_.Query(nid, fid)) { + total_features++; + } + } + } + + split_queries_device_.Clear(); + split_queries_device_.Resize(qu_, total_features); + + size_t pos = 0; + + const size_t local_size = 16; + + for (size_t nid_in_set = 0; nid_in_set < n_nodes_in_set; ++nid_in_set) { + const size_t nid = nodes_set[nid_in_set].nid; + + for (size_t idx_in_feature_set = 0; idx_in_feature_set < features_sets[nid_in_set]->Size(); idx_in_feature_set++) { + const auto fid = features_sets[nid_in_set]->ConstHostVector()[idx_in_feature_set]; + if (interaction_constraints_.Query(nid, fid)) { + split_queries_device_[pos].nid = nid; + split_queries_device_[pos].fid = fid; + split_queries_device_[pos].hist = hist[nid].DataConst(); + split_queries_device_[pos].best = snode_[nid].best; + pos++; + } + } + } + + auto evaluator = tree_evaluator_.GetEvaluator(); + SplitQuery* split_queries_device = split_queries_device_.Data(); + const uint32_t* cut_ptr = gmat.cut_device.Ptrs().DataConst(); + const bst_float* cut_val = gmat.cut_device.Values().DataConst(); + const bst_float* cut_minval = gmat.cut_device.MinValues().DataConst(); + const NodeEntry* snode = snode_.DataConst(); + + TrainParamOneAPI param(param_); + + qu_.submit([&](sycl::handler& cgh) { + cgh.parallel_for<>(sycl::nd_range<2>(sycl::range<2>(total_features, 
local_size), + sycl::range<2>(1, local_size)), [=](sycl::nd_item<2> pid) [[intel::reqd_sub_group_size(16)]] { + TrainParamOneAPI param_device(param); + typename TreeEvaluatorOneAPI::SplitEvaluator evaluator_device = evaluator; + int i = pid.get_global_id(0); + auto sg = pid.get_sub_group(); + int nid = split_queries_device[i].nid; + int fid = split_queries_device[i].fid; + const GradientPairT* hist_data = split_queries_device[i].hist; + auto grad_stats = EnumerateSplit(sg, cut_ptr, cut_val, hist_data, snode[nid], + split_queries_device[i].best, fid, nid, evaluator_device, param_device); + }); + }).wait(); + + for (size_t i = 0; i < total_features; i++) { + int nid = split_queries_device[i].nid; + snode_[nid].best.Update(split_queries_device[i].best); + } + + builder_monitor_.Stop("EvaluateSplits"); +} + +// Enumerate the split values of specific feature. +// Returns the sum of gradients corresponding to the data points that contains a non-missing value +// for the particular feature fid. +template +template +GradStatsOneAPI QuantileHistMakerOneAPIBackend::Builder::EnumerateSplit( + const uint32_t* cut_ptr, + const bst_float* cut_val, + const bst_float* cut_minval, + const GradientPairT* hist_data, + const NodeEntry& snode, + SplitEntryOneAPI& p_best, + bst_uint fid, + bst_uint nodeID, + typename TreeEvaluatorOneAPI::SplitEvaluator const &evaluator_device, + const TrainParamOneAPI& param) { + GradStatsOneAPI c; + GradStatsOneAPI e; + // best split so far + SplitEntryOneAPI best; + + // bin boundaries + // imin: index (offset) of the minimum value for feature fid + // need this for backward enumeration + const auto imin = static_cast(cut_ptr[fid]); + // ibegin, iend: smallest/largest cut points for feature fid + // use int to allow for value -1 + int32_t ibegin, iend; + if (d_step > 0) { + ibegin = static_cast(cut_ptr[fid]); + iend = static_cast(cut_ptr[fid + 1]); + } else { + ibegin = static_cast(cut_ptr[fid + 1]) - 1; + iend = static_cast(cut_ptr[fid]) - 1; + } + + for (int32_t i = ibegin; i != iend; i += d_step) { + e.Add(hist_data[i].GetGrad(), hist_data[i].GetHess()); + if (e.GetHess() >= param.min_child_weight) { + c.SetSubstract(snode.stats, e); + if (c.GetHess() >= param.min_child_weight) { + bst_float loss_chg; + bst_float split_pt; + if (d_step > 0) { + loss_chg = static_cast( + evaluator_device.CalcSplitGain(nodeID, fid, e, c) - snode.root_gain); + split_pt = cut_val[i]; + best.Update(loss_chg, fid, split_pt, d_step == -1, e, c); + } else { + loss_chg = static_cast( + evaluator_device.CalcSplitGain(nodeID, fid, GradStatsOneAPI{c}, GradStatsOneAPI{e}) - snode.root_gain); + if (i == imin) { + split_pt = cut_minval[fid]; + } else { + split_pt = cut_val[i - 1]; + } + best.Update(loss_chg, fid, split_pt, d_step == -1, c, e); + } + } + } + } + p_best.Update(best); + return e; +} + +// Enumerate the split values of specific feature. +// Returns the sum of gradients corresponding to the data points that contains a non-missing value +// for the particular feature fid. 
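+// The gain compared inside EnumerateSplit follows the usual second-order formulation.
+// A minimal standalone sketch (regularisation beyond lambda and the monotone-constraint
+// handling done by SplitEvaluator::CalcSplitGain are deliberately omitted):
+//
+//   double SplitGainSketch(double grad_left, double hess_left,
+//                          double grad_right, double hess_right, double lambda) {
+//     auto term = [lambda](double g, double h) { return g * g / (h + lambda); };
+//     return term(grad_left, hess_left) + term(grad_right, hess_right)
+//            - term(grad_left + grad_right, hess_left + hess_right);
+//   }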
+template +GradStatsOneAPI QuantileHistMakerOneAPIBackend::Builder::EnumerateSplit( + sycl::sub_group& sg, + const uint32_t* cut_ptr, + const bst_float* cut_val, + const GradientPairT* hist_data, + const NodeEntry& snode, + SplitEntryOneAPI& p_best, + bst_uint fid, + bst_uint nodeID, + typename TreeEvaluatorOneAPI::SplitEvaluator const &evaluator_device, + const TrainParamOneAPI& param) { + SplitEntryOneAPI best; + + int32_t ibegin = static_cast(cut_ptr[fid]); + int32_t iend = static_cast(cut_ptr[fid + 1]); + + GradientSumT tot_grad = snode.stats.GetGrad(); + GradientSumT tot_hess = snode.stats.GetHess(); + + GradientSumT sum_grad = 0.0f; + GradientSumT sum_hess = 0.0f; + + int32_t local_size = sg.get_local_range().size(); + + for (int32_t i = ibegin + sg.get_local_id(); i < iend; i += local_size) { + GradientSumT e_grad = sum_grad + sycl::inclusive_scan_over_group(sg, hist_data[i].GetGrad(), std::plus<>()); + GradientSumT e_hess = sum_hess + sycl::inclusive_scan_over_group(sg, hist_data[i].GetHess(), std::plus<>()); + if (e_hess >= param.min_child_weight) { + GradientSumT c_grad = tot_grad - e_grad; + GradientSumT c_hess = tot_hess - e_hess; + if (c_hess >= param.min_child_weight) { + GradStatsOneAPI e(e_grad, e_hess); + GradStatsOneAPI c(c_grad, c_hess); + bst_float loss_chg; + bst_float split_pt; + loss_chg = static_cast( + evaluator_device.CalcSplitGain(nodeID, fid, e, c) - snode.root_gain); + split_pt = cut_val[i]; + best.Update(loss_chg, fid, split_pt, false, e, c); + } + } + sum_grad += sycl::reduce_over_group(sg, hist_data[i].GetGrad(), std::plus<>()); + sum_hess += sycl::reduce_over_group(sg, hist_data[i].GetHess(), std::plus<>()); + } + + bst_float total_loss_chg = sycl::reduce_over_group(sg, best.loss_chg, maximum<>()); + bst_feature_t total_split_index = sycl::reduce_over_group(sg, best.loss_chg == total_loss_chg ? 
best.SplitIndex() : (1U << 31) - 1U, minimum<>()); + if (best.loss_chg == total_loss_chg && best.SplitIndex() == total_split_index) p_best.Update(best); + return GradStatsOneAPI(sum_grad, sum_hess); +} + +// split row indexes (rid_span) to 2 parts (both stored in rid_buf) depending +// on comparison of indexes values (idx_span) and split point (split_cond) +// Handle dense columns +template +inline sycl::event PartitionDenseKernel(sycl::queue& qu, + const GHistIndexMatrixOneAPI& gmat, + const RowSetCollectionOneAPI::Elem& rid_span, + const size_t fid, + const int32_t split_cond, + common::Span& rid_buf, + size_t* parts_size, + sycl::event priv_event) { + const size_t row_stride = gmat.row_stride; + const BinIdxType* gradient_index = gmat.index.data(); + const size_t* rid = rid_span.begin; + const size_t range_size = rid_span.Size(); + const size_t offset = gmat.cut.Ptrs()[fid]; + + size_t* p_rid_buf = rid_buf.data(); + + auto event = qu.submit([&](sycl::handler& cgh) { + cgh.depends_on(priv_event); + cgh.parallel_for<>(sycl::range<1>(range_size), [=](sycl::item<1> nid) { + const size_t id = rid[nid.get_id(0)]; + const int32_t value = static_cast(gradient_index[id * row_stride + fid] + offset); + const bool is_left = value <= split_cond; + if (is_left) { + common::AtomicRef n_left(parts_size[0]); + p_rid_buf[n_left.fetch_add(1)] = id; + } else { + common::AtomicRef n_right(parts_size[1]); + p_rid_buf[range_size - n_right.fetch_add(1) - 1] = id; + } + }); + }); + return event; +} + +// split row indexes (rid_span) to 2 parts (both stored in rid_buf) depending +// on comparison of indexes values (idx_span) and split point (split_cond) +// Handle dense columns +template +inline sycl::event PartitionSparseKernel(sycl::queue& qu, + const GHistIndexMatrixOneAPI& gmat, + const RowSetCollectionOneAPI::Elem& rid_span, + const size_t fid, + const int32_t split_cond, + common::Span& rid_buf, + size_t* parts_size, + sycl::event priv_event) { + const size_t row_stride = gmat.row_stride; + const BinIdxType* gradient_index = gmat.index.data(); + const size_t* rid = rid_span.begin; + const size_t range_size = rid_span.Size(); + const uint32_t* cut_ptrs = gmat.cut_device.Ptrs().DataConst(); + const bst_float* cut_vals = gmat.cut_device.Values().DataConst(); + + size_t* p_rid_buf = rid_buf.data(); + auto event = qu.submit([&](sycl::handler& cgh) { + cgh.depends_on(priv_event); + cgh.parallel_for<>(sycl::range<1>(range_size), [=](sycl::item<1> nid) { + const size_t id = rid[nid.get_id(0)]; + + const BinIdxType* gr_index_local = gradient_index + row_stride * id; + const int32_t fid_local = std::lower_bound(gr_index_local, gr_index_local + row_stride, cut_ptrs[fid]) - gr_index_local; + const bool is_left = (fid_local >= row_stride || gr_index_local[fid_local] >= cut_ptrs[fid + 1]) ? 
default_left : gr_index_local[fid_local] <= split_cond; + if (is_left) { + common::AtomicRef n_left(parts_size[0]); + p_rid_buf[n_left.fetch_add(1)] = id; + } else { + common::AtomicRef n_right(parts_size[1]); + p_rid_buf[range_size - n_right.fetch_add(1) - 1] = id; + } + }); + }); + return event; +} + +template +template +sycl::event QuantileHistMakerOneAPIBackend::Builder::PartitionKernel( + const size_t nid, + const int32_t split_cond, + const GHistIndexMatrixOneAPI& gmat, + const RegTree::Node& node, + common::Span& rid_buf, + size_t* parts_size, + sycl::event priv_event) { + const bst_uint fid = node.SplitIndex(); + const bool default_left = node.DefaultLeft(); + + if (gmat.IsDense()) { + if (default_left) { + return PartitionDenseKernel(qu_, gmat, row_set_collection_[nid], fid, split_cond, rid_buf, parts_size, priv_event); + } else { + return PartitionDenseKernel(qu_, gmat, row_set_collection_[nid], fid, split_cond, rid_buf, parts_size, priv_event); + } + } else { + if (default_left) { + return PartitionSparseKernel(qu_, gmat, row_set_collection_[nid], fid, split_cond, rid_buf, parts_size, priv_event); + } else { + return PartitionSparseKernel(qu_, gmat, row_set_collection_[nid], fid, split_cond, rid_buf, parts_size, priv_event); + } + } +} + +template +void QuantileHistMakerOneAPIBackend::Builder::FindSplitConditions( + const std::vector& nodes, + const RegTree& tree, + const GHistIndexMatrixOneAPI& gmat, + std::vector* split_conditions) { + const size_t n_nodes = nodes.size(); + split_conditions->resize(n_nodes); + + for (size_t i = 0; i < nodes.size(); ++i) { + const int32_t nid = nodes[i].nid; + const bst_uint fid = tree[nid].SplitIndex(); + const bst_float split_pt = tree[nid].SplitCond(); + const uint32_t lower_bound = gmat.cut.Ptrs()[fid]; + const uint32_t upper_bound = gmat.cut.Ptrs()[fid + 1]; + int32_t split_cond = -1; + // convert floating-point split_pt into corresponding bin_id + // split_cond = -1 indicates that split_pt is less than all known cut points + CHECK_LT(upper_bound, + static_cast(std::numeric_limits::max())); + for (uint32_t i = lower_bound; i < upper_bound; ++i) { + if (split_pt == gmat.cut.Values()[i]) { + split_cond = static_cast(i); + } + } + (*split_conditions)[i] = split_cond; + } +} +template +void QuantileHistMakerOneAPIBackend::Builder::AddSplitsToRowSet(const std::vector& nodes, + RegTree* p_tree) { + const size_t n_nodes = nodes.size(); + for (size_t i = 0; i < n_nodes; ++i) { + const int32_t nid = nodes[i].nid; + const size_t n_left = partition_builder_.GetNLeftElems(i); + const size_t n_right = partition_builder_.GetNRightElems(i); + + row_set_collection_.AddSplit(nid, (*p_tree)[nid].LeftChild(), + (*p_tree)[nid].RightChild(), n_left, n_right); + } +} + +template +void QuantileHistMakerOneAPIBackend::Builder::ApplySplit(const std::vector nodes, + const GHistIndexMatrixOneAPI& gmat, + const HistCollectionOneAPI& hist, + RegTree* p_tree) { + builder_monitor_.Start("ApplySplit"); + + const size_t n_nodes = nodes.size(); + std::vector split_conditions; + FindSplitConditions(nodes, *p_tree, gmat, &split_conditions); + + partition_builder_.Init(qu_, n_nodes, [&](size_t node_in_set) { + const int32_t nid = nodes[node_in_set].nid; + return row_set_collection_[nid].Size(); + }); + + // Add resize_and_fill method to save one call + auto event = parts_size_.ResizeAndFill(qu_, 2 * n_nodes, 0); + apply_split_events_.resize(n_nodes); + + for (size_t node_in_set = 0; node_in_set < n_nodes; node_in_set++) { + const int32_t nid = nodes[node_in_set].nid; + 
sycl::event& cur_event = apply_split_events_[node_in_set]; + if (row_set_collection_[nid].Size() > 0) { + const RegTree::Node& node = (*p_tree)[nid]; + common::Span rid_buf = partition_builder_.GetData(node_in_set); + size_t* part_size = parts_size_.Data() + 2 * node_in_set; + int32_t split_condition = split_conditions[node_in_set]; + switch (gmat.index.GetBinTypeSize()) { + case common::kUint8BinsTypeSize: + cur_event = PartitionKernel(nid, split_condition, gmat, node, rid_buf, part_size, event); + break; + case common::kUint16BinsTypeSize: + cur_event = PartitionKernel(nid, split_condition, gmat, node, rid_buf, part_size, event); + break; + case common::kUint32BinsTypeSize: + cur_event = PartitionKernel(nid, split_condition, gmat, node, rid_buf, part_size, event); + break; + default: + CHECK(false); // no default behavior + } + } else { + cur_event = sycl::event(); + } + } + + sycl::event event_cpy = qu_.memcpy(partition_builder_.GetResultRowsPtr(), parts_size_.DataConst(), sizeof(size_t) * 2 * n_nodes, apply_split_events_); + qu_.wait_and_throw(); + merge_to_array_events_.resize(n_nodes); + for (size_t node_in_set = 0; node_in_set < n_nodes; node_in_set++) { + sycl::event& cur_event = merge_to_array_events_[node_in_set]; + const int32_t nid = nodes[node_in_set].nid; + size_t* data_result = const_cast(row_set_collection_[nid].begin); + cur_event = partition_builder_.MergeToArray(qu_, node_in_set, data_result, event_cpy); + } + qu_.wait_and_throw(); + + AddSplitsToRowSet(nodes, p_tree); + + builder_monitor_.Stop("ApplySplit"); +} + +template +void QuantileHistMakerOneAPIBackend::Builder::InitNewNode(int nid, + const GHistIndexMatrixOneAPI& gmat, + const std::vector& gpair, + const DMatrix& fmat, + const RegTree& tree) { + builder_monitor_.Start("InitNewNode"); + { + snode_.Resize(qu_, tree.NumNodes(), NodeEntry(param_)); + } + + { + auto hist = hist_[nid]; + GradientPairT grad_stat; + if (tree[nid].IsRoot()) { + if (data_layout_ == kDenseDataZeroBased || data_layout_ == kDenseDataOneBased) { + const std::vector& row_ptr = gmat.cut.Ptrs(); + const uint32_t ibegin = row_ptr[fid_least_bins_]; + const uint32_t iend = row_ptr[fid_least_bins_ + 1]; + xgboost::detail::GradientPairInternal* begin = + reinterpret_cast*>(hist.Data()); + + std::vector ets(iend - ibegin); + qu_.memcpy(ets.data(), begin + ibegin, (iend - ibegin) * sizeof(GradientPairT)).wait_and_throw(); + for (const auto& et : ets) { + grad_stat.Add(et.GetGrad(), et.GetHess()); + } + } else { + const RowSetCollectionOneAPI::Elem e = row_set_collection_[nid]; + // for (const size_t* it = e.begin; it < e.end; ++it) { + // grad_stat.Add(gpair[*it].GetGrad(), gpair[*it].GetHess()); + // } + std::vector row_idxs(e.Size()); + qu_.memcpy(row_idxs.data(), e.begin, sizeof(size_t) * e.Size()).wait(); + for (const size_t row_idx : row_idxs) { + grad_stat.Add(gpair[row_idx].GetGrad(), gpair[row_idx].GetHess()); + } + } + collective::Allreduce(reinterpret_cast(&grad_stat), 2); + // histred_.Allreduce(&grad_stat, 1); + snode_[nid].stats = GradStatsOneAPI(grad_stat.GetGrad(), grad_stat.GetHess()); + } else { + int parent_id = tree[nid].Parent(); + if (tree[nid].IsLeftChild()) { + snode_[nid].stats = snode_[parent_id].best.left_sum; + } else { + snode_[nid].stats = snode_[parent_id].best.right_sum; + } + } + } + + // calculating the weights + { + auto evaluator = tree_evaluator_.GetEvaluator(); + bst_uint parentid = tree[nid].Parent(); + snode_[nid].weight = static_cast( + evaluator.CalcWeight(parentid, GradStatsOneAPI{snode_[nid].stats})); + 
snode_[nid].root_gain = static_cast( + evaluator.CalcGain(parentid, GradStatsOneAPI{snode_[nid].stats})); + } + builder_monitor_.Stop("InitNewNode"); +} + +template struct QuantileHistMakerOneAPIBackend::Builder; +template struct QuantileHistMakerOneAPIBackend::Builder; +template sycl::event QuantileHistMakerOneAPIBackend::Builder::PartitionKernel( + const size_t nid, const int32_t split_cond, const GHistIndexMatrixOneAPI &gmat, + const RegTree::Node& node, common::Span& rid_buf, size_t* parts_size, sycl::event priv_event); +template sycl::event QuantileHistMakerOneAPIBackend::Builder::PartitionKernel( + const size_t nid, const int32_t split_cond, const GHistIndexMatrixOneAPI &gmat, + const RegTree::Node& node, common::Span& rid_buf, size_t* parts_size, sycl::event priv_event); +template sycl::event QuantileHistMakerOneAPIBackend::Builder::PartitionKernel( + const size_t nid, const int32_t split_cond, const GHistIndexMatrixOneAPI &gmat, + const RegTree::Node& node, common::Span& rid_buf, size_t* parts_size, sycl::event priv_event); +template sycl::event QuantileHistMakerOneAPIBackend::Builder::PartitionKernel( + const size_t nid, const int32_t split_cond, const GHistIndexMatrixOneAPI &gmat, + const RegTree::Node& node, common::Span& rid_buf, size_t* parts_size, sycl::event priv_event); +template sycl::event QuantileHistMakerOneAPIBackend::Builder::PartitionKernel( + const size_t nid, const int32_t split_cond, const GHistIndexMatrixOneAPI &gmat, + const RegTree::Node& node, common::Span& rid_buf, size_t* parts_size, sycl::event priv_event); +template sycl::event QuantileHistMakerOneAPIBackend::Builder::PartitionKernel( + const size_t nid, const int32_t split_cond, const GHistIndexMatrixOneAPI &gmat, + const RegTree::Node& node, common::Span& rid_buf, size_t* parts_size, sycl::event priv_event); + +XGBOOST_REGISTER_TREE_UPDATER(QuantileHistMakerOneAPI, "grow_quantile_histmaker_oneapi") +.describe("Grow tree using quantized histogram with dpc++.") +.set_body( + [](Context const* ctx, ObjInfo const * task) { + return new QuantileHistMakerOneAPI(ctx, task); + }); + +XGBOOST_REGISTER_TREE_UPDATER(QuantileHistMakerOneAPIBackend, "grow_quantile_histmaker_oneapi_backend") +.describe("Grow tree using quantized histogram with dpc++ on GPU.") +.set_body( + [](Context const* ctx, ObjInfo const * task) { + return new QuantileHistMakerOneAPIBackend(ctx, task); + }); +} // namespace tree +} // namespace xgboost diff --git a/plugin/updater_oneapi/updater_quantile_hist_oneapi.h b/plugin/updater_oneapi/updater_quantile_hist_oneapi.h new file mode 100644 index 000000000000..70e13c8c9e98 --- /dev/null +++ b/plugin/updater_oneapi/updater_quantile_hist_oneapi.h @@ -0,0 +1,611 @@ +/*! + * Copyright 2017-2021 by Contributors + * \file updater_quantile_hist_oneapi.h + */ +#ifndef XGBOOST_TREE_UPDATER_QUANTILE_HIST_ONEAPI_H_ +#define XGBOOST_TREE_UPDATER_QUANTILE_HIST_ONEAPI_H_ + +#include +#include +#include + +#include + +#include "hist_util_oneapi.h" +#include "row_set_oneapi.h" +#include "split_evaluator_oneapi.h" +#include "device_manager_oneapi.h" + +#include "xgboost/data.h" +#include "xgboost/json.h" +#include "../../src/tree/constraints.h" +#include "../../src/common/random.h" + +namespace xgboost { + +/*! + * \brief A C-style array with in-stack allocation. + As long as the array is smaller than MaxStackSize, it will be allocated inside the stack. Otherwise, it will be heap-allocated. 
+ Temporary copy of implementation to remove dependency on updater_quantile_hist.h + */ +template +class MemStackAllocatorOneAPI { + public: + explicit MemStackAllocatorOneAPI(size_t required_size): required_size_(required_size) { + } + + T* Get() { + if (!ptr_) { + if (MaxStackSize >= required_size_) { + ptr_ = stack_mem_; + } else { + ptr_ = reinterpret_cast(malloc(required_size_ * sizeof(T))); + do_free_ = true; + } + } + + return ptr_; + } + + ~MemStackAllocatorOneAPI() { + if (do_free_) free(ptr_); + } + + + private: + T* ptr_ = nullptr; + bool do_free_ = false; + size_t required_size_; + T stack_mem_[MaxStackSize]; +}; + +namespace tree { + +using xgboost::common::HistCollectionOneAPI; +using xgboost::common::GHistBuilderOneAPI; +using xgboost::common::GHistIndexMatrixOneAPI; +using xgboost::common::GHistRowOneAPI; +using xgboost::common::RowSetCollectionOneAPI; + +template +class HistSynchronizerOneAPI; + +template +class BatchHistSynchronizerOneAPI; + +template +class DistributedHistSynchronizerOneAPI; + +template +class HistRowsAdderOneAPI; + +template +class BatchHistRowsAdderOneAPI; + +template +class DistributedHistRowsAdderOneAPI; + +// training parameters specific to this algorithm +struct OneAPIHistMakerTrainParam + : public XGBoostParameter { + bool single_precision_histogram = false; + // declare parameters + DMLC_DECLARE_PARAMETER(OneAPIHistMakerTrainParam) { + DMLC_DECLARE_FIELD(single_precision_histogram).set_default(false).describe( + "Use single precision to build histograms."); + } +}; + +/*! \brief construct a tree using quantized feature values with DPC++ interface */ +class QuantileHistMakerOneAPI: public TreeUpdater { + public: + explicit QuantileHistMakerOneAPI(Context const* ctx, ObjInfo const * task) : TreeUpdater(ctx), ctx_(ctx), task_{task} {} + void Configure(const Args& args) override; + + void Update(TrainParam const *param, + HostDeviceVector* gpair, + DMatrix* dmat, + common::Span> out_position, + const std::vector& trees) override; + + bool UpdatePredictionCache(const DMatrix* data, + linalg::MatrixView out_preds) override; + + void LoadConfig(Json const& in) override { + if (updater_backend_) { + updater_backend_->LoadConfig(in); + } else { + auto const& config = get(in); + FromJson(config.at("train_param"), &this->param_); + } + } + + void SaveConfig(Json* p_out) const override { + if (updater_backend_) { + updater_backend_->SaveConfig(p_out); + } else { + auto& out = *p_out; + out["train_param"] = ToJson(param_); + } + } + + char const* Name() const override { + if (updater_backend_) { + return updater_backend_->Name(); + } else { + return "grow_quantile_histmaker_oneapi"; + } + } + + protected: + // training parameter + TrainParam param_; + + DeviceManagerOneAPI device_manager; + + ObjInfo const *task_{nullptr}; + Context const* ctx_; + std::unique_ptr updater_backend_; +}; + +// data structure +template +struct NodeEntry { + /*! \brief statics for node entry */ + GradStatsOneAPI stats; + /*! \brief loss of this node, without split */ + GradType root_gain; + /*! \brief weight calculated related to current data */ + GradType weight; + /*! \brief current best solution */ + SplitEntryOneAPI best; + // constructor + explicit NodeEntry(const TrainParam& param) + : root_gain(0.0f), weight(0.0f) {} +}; +// actual builder that runs the algorithm + +/*! 
\brief construct a tree using quantized feature values with DPC++ backend on GPU*/ +class QuantileHistMakerOneAPIBackend: public TreeUpdater { + public: + explicit QuantileHistMakerOneAPIBackend(Context const* ctx, ObjInfo const * task) : TreeUpdater(ctx), ctx_(ctx), task_{task} { + updater_monitor_.Init("QuantileHistMakerOneAPIBackend"); + } + void Configure(const Args& args) override; + + void Update(TrainParam const *param, + HostDeviceVector* gpair, + DMatrix* dmat, + common::Span> out_position, + const std::vector& trees) override; + + bool UpdatePredictionCache(const DMatrix* data, + linalg::MatrixView out_preds) override; + + void LoadConfig(Json const& in) override { + auto const& config = get(in); + FromJson(config.at("train_param"), &this->param_); + try { + FromJson(config.at("oneapi_hist_train_param"), &this->hist_maker_param_); + } catch (std::out_of_range& e) { + // XGBoost model is from 1.1.x, so 'oneapi_hist_train_param' is missing. + // We add this compatibility check because it's just recently that we (developers) began + // persuading R users away from using saveRDS() for model serialization. Hopefully, one day, + // everyone will be using xgb.save(). + LOG(WARNING) << "Attempted to load internal configuration for a model file that was generated " + << "by a previous version of XGBoost. A likely cause for this warning is that the model " + << "was saved with saveRDS() in R or pickle.dump() in Python. We strongly ADVISE AGAINST " + << "using saveRDS() or pickle.dump() so that the model remains accessible in current and " + << "upcoming XGBoost releases. Please use xgb.save() instead to preserve models for the " + << "long term. For more details and explanation, see " + << "https://xgboost.readthedocs.io/en/latest/tutorials/saving_model.html"; + this->hist_maker_param_.UpdateAllowUnknown(Args{}); + } + } + void SaveConfig(Json* p_out) const override { + auto& out = *p_out; + out["train_param"] = ToJson(param_); + out["oneapi_hist_train_param"] = ToJson(hist_maker_param_); + } + + char const* Name() const override { + return "grow_quantile_histmaker_oneapi_backend"; + } + + protected: + template + friend class HistSynchronizerOneAPI; + template + friend class BatchHistSynchronizerOneAPI; + template + friend class DistributedHistSynchronizerOneAPI; + + template + friend class HistRowsAdderOneAPI; + template + friend class BatchHistRowsAdderOneAPI; + template + friend class DistributedHistRowsAdderOneAPI; + + OneAPIHistMakerTrainParam hist_maker_param_; + // training parameter + TrainParam param_; + // quantized data matrix + GHistIndexMatrixOneAPI gmat_; + // (optional) data matrix with feature grouping + // column accessor + DMatrix const* p_last_dmat_ {nullptr}; + bool is_gmat_initialized_ {false}; + + template + struct Builder { + public: + template + using GHistRowT = GHistRowOneAPI; + using GradientPairT = xgboost::detail::GradientPairInternal; + // constructor + explicit Builder(sycl::queue qu, + const TrainParam& param, + std::unique_ptr pruner, + FeatureInteractionConstraintHost int_constraints_, + DMatrix const* fmat) + : qu_(qu), param_(param), + tree_evaluator_(qu, param, fmat->Info().num_col_), + pruner_(std::move(pruner)), + interaction_constraints_{std::move(int_constraints_)}, + p_last_tree_(nullptr), p_last_fmat_(fmat), + snode_(qu, 1u << (param.max_depth + 1), NodeEntry(param)) { + builder_monitor_.Init("QuantileOneAPI::Builder"); + kernel_monitor_.Init("QuantileOneAPI::Kernels"); + } + // update one tree, growing + void Update(Context const * ctx, + TrainParam 
const *param, + const GHistIndexMatrixOneAPI &gmat, + HostDeviceVector *gpair, + const USMVector& gpair_device, + DMatrix *p_fmat, + common::Span> out_position, + RegTree *p_tree); + + inline sycl::event BuildHist(const USMVector& gpair_device, + const RowSetCollectionOneAPI::Elem row_indices, + const GHistIndexMatrixOneAPI& gmat, + GHistRowT& hist, + GHistRowT& hist_buffer, + sycl::event event_priv) { + return hist_builder_.BuildHist(gpair_device, row_indices, gmat, hist, data_layout_ != kSparseData, hist_buffer, event_priv); + } + + inline void SubtractionTrick(GHistRowT& self, + GHistRowT& sibling, + GHistRowT& parent) { + builder_monitor_.Start("SubtractionTrick"); + hist_builder_.SubtractionTrick(self, sibling, parent); + builder_monitor_.Stop("SubtractionTrick"); + } + + bool UpdatePredictionCache(const DMatrix* data, + linalg::MatrixView p_out_preds); + void SetHistSynchronizer(HistSynchronizerOneAPI* sync); + void SetHistRowsAdder(HistRowsAdderOneAPI* adder); + + // initialize temp data structure + void InitData(Context const * ctx, + const GHistIndexMatrixOneAPI& gmat, + const std::vector& gpair, + const USMVector &gpair_device, + const DMatrix& fmat, + const RegTree& tree); + + protected: + friend class HistSynchronizerOneAPI; + friend class BatchHistSynchronizerOneAPI; + friend class DistributedHistSynchronizerOneAPI; + friend class HistRowsAdderOneAPI; + friend class BatchHistRowsAdderOneAPI; + friend class DistributedHistRowsAdderOneAPI; + + /* tree growing policies */ + struct ExpandEntry { + static const int kRootNid = 0; + static const int kEmptyNid = -1; + int nid; + int sibling_nid; + int depth; + bst_float loss_chg; + unsigned timestamp; + ExpandEntry(int nid, int sibling_nid, int depth, bst_float loss_chg, + unsigned tstmp) + : nid(nid), sibling_nid(sibling_nid), depth(depth), + loss_chg(loss_chg), timestamp(tstmp) {} + + bool IsValid(TrainParam const ¶m, int32_t num_leaves) const { + bool ret = loss_chg <= kRtEps || + (param.max_depth > 0 && this->depth == param.max_depth) || + (param.max_leaves > 0 && num_leaves == param.max_leaves); + return ret; + } + }; + + struct SplitQuery { + int nid; + int fid; + SplitEntryOneAPI best; + const GradientPairT* hist; + }; + + void InitSampling(const std::vector& gpair, + const USMVector &gpair_device, + const DMatrix& fmat, USMVector& row_indices); + + void EvaluateSplits(const std::vector& nodes_set, + const GHistIndexMatrixOneAPI& gmat, + const HistCollectionOneAPI& hist, + const RegTree& tree); + + // Enumerate the split values of specific feature + // Returns the sum of gradients corresponding to the data points that contains a non-missing + // value for the particular feature fid. 
+ template + static GradStatsOneAPI EnumerateSplit( + const uint32_t* cut_ptr,const bst_float* cut_val, const bst_float* cut_minval, const GradientPairT* hist_data, + const NodeEntry &snode, SplitEntryOneAPI& p_best, bst_uint fid, + bst_uint nodeID, + typename TreeEvaluatorOneAPI::SplitEvaluator const &evaluator, const TrainParamOneAPI& param); + + static GradStatsOneAPI EnumerateSplit(sycl::sub_group& sg, + const uint32_t* cut_ptr, const bst_float* cut_val, const GradientPairT* hist_data, + const NodeEntry &snode, SplitEntryOneAPI& p_best, bst_uint fid, + bst_uint nodeID, + typename TreeEvaluatorOneAPI::SplitEvaluator const &evaluator, const TrainParamOneAPI& param); + + void ApplySplit(std::vector nodes, + const GHistIndexMatrixOneAPI& gmat, + const HistCollectionOneAPI& hist, + RegTree* p_tree); + + template + sycl::event PartitionKernel(const size_t nid, + const int32_t split_cond, + const GHistIndexMatrixOneAPI &gmat, + const RegTree::Node& node, + common::Span& rid_buf, + size_t* parts_size, + sycl::event priv_event); + + void AddSplitsToRowSet(const std::vector& nodes, RegTree* p_tree); + + + void FindSplitConditions(const std::vector& nodes, const RegTree& tree, + const GHistIndexMatrixOneAPI& gmat, std::vector* split_conditions); + + void InitNewNode(int nid, + const GHistIndexMatrixOneAPI& gmat, + const std::vector& gpair, + const DMatrix& fmat, + const RegTree& tree); + + // if sum of statistics for non-missing values in the node + // is equal to sum of statistics for all values: + // then - there are no missing values + // else - there are missing values + static bool SplitContainsMissingValues(const GradStatsOneAPI& e, const NodeEntry& snode); + + void ExpandWithDepthWise(const GHistIndexMatrixOneAPI &gmat, + DMatrix *p_fmat, + RegTree *p_tree, + const std::vector &gpair, + const USMVector &gpair_device); + + void BuildLocalHistograms(const GHistIndexMatrixOneAPI &gmat, + RegTree *p_tree, + const USMVector &gpair_device); + + void BuildHistogramsLossGuide( + ExpandEntry entry, + const GHistIndexMatrixOneAPI &gmat, + RegTree *p_tree, + const USMVector &gpair_device); + + // Split nodes to 2 sets depending on amount of rows in each node + // Histograms for small nodes will be built explicitly + // Histograms for big nodes will be built by 'Subtraction Trick' + void SplitSiblings(const std::vector& nodes, + std::vector* small_siblings, + std::vector* big_siblings, + RegTree *p_tree); + + void ParallelSubtractionHist(const common::BlockedSpace2d& space, + const std::vector& nodes, + const RegTree * p_tree); + + void BuildNodeStats(const GHistIndexMatrixOneAPI &gmat, + DMatrix *p_fmat, + RegTree *p_tree, + const std::vector &gpair); + + void EvaluateAndApplySplits(const GHistIndexMatrixOneAPI &gmat, + RegTree *p_tree, + int *num_leaves, + int depth, + unsigned *timestamp, + std::vector *temp_qexpand_depth); + + void AddSplitsToTree( + const GHistIndexMatrixOneAPI &gmat, + RegTree *p_tree, + int *num_leaves, + int depth, + unsigned *timestamp, + std::vector* nodes_for_apply_split, + std::vector* temp_qexpand_depth); + + void ExpandWithLossGuide(const GHistIndexMatrixOneAPI& gmat, + DMatrix* p_fmat, + RegTree* p_tree, + const std::vector &gpair, + const USMVector& gpair_device); + + void ReduceHists(std::vector& sync_ids, size_t nbins); + + inline static bool LossGuide(ExpandEntry lhs, ExpandEntry rhs) { + if (lhs.loss_chg == rhs.loss_chg) { + return lhs.timestamp > rhs.timestamp; // favor small timestamp + } else { + return lhs.loss_chg < rhs.loss_chg; // favor large loss_chg + } + 
} + // --data fields-- + const TrainParam& param_; + // number of omp thread used during training + int nthread_; + common::ColumnSampler column_sampler_; + // the internal row sets + RowSetCollectionOneAPI row_set_collection_; + USMVector split_queries_device_; + /*! \brief TreeNode Data: statistics for each constructed node */ + USMVector> snode_; + /*! \brief culmulative histogram of gradients. */ + HistCollectionOneAPI hist_; + /*! \brief culmulative local parent histogram of gradients. */ + HistCollectionOneAPI hist_local_worker_; + TreeEvaluatorOneAPI tree_evaluator_; + /*! \brief feature with least # of bins. to be used for dense specialization + of InitNewNode() */ + uint32_t fid_least_bins_; + + GHistBuilderOneAPI hist_builder_; + std::unique_ptr pruner_; + FeatureInteractionConstraintHost interaction_constraints_; + + common::PartitionBuilderOneAPI partition_builder_; + + // back pointers to tree and data matrix + const RegTree* p_last_tree_; + DMatrix const* const p_last_fmat_; + + using ExpandQueue = + std::priority_queue, + std::function>; + + std::unique_ptr qexpand_loss_guided_; + std::vector qexpand_depth_wise_; + // key is the node id which should be calculated by Subtraction Trick, value is the node which + // provides the evidence for substracts + std::vector nodes_for_subtraction_trick_; + // list of nodes whose histograms would be built explicitly. + std::vector nodes_for_explicit_hist_build_; + + enum DataLayout { kDenseDataZeroBased, kDenseDataOneBased, kSparseData }; + DataLayout data_layout_; + + common::Monitor builder_monitor_; + common::Monitor kernel_monitor_; + constexpr static size_t kNumParallelBuffers = 1; + std::array, kNumParallelBuffers> hist_buffers_; + std::array hist_build_events_; + USMVector parts_size_; + std::vector parts_size_cpu_; + std::vector apply_split_events_; + std::vector merge_to_array_events_; + // rabit::op::Reducer histred_; + std::unique_ptr> hist_synchronizer_; + std::unique_ptr> hist_rows_adder_; + + sycl::queue qu_; + }; + common::Monitor updater_monitor_; + + template + void SetBuilder(std::unique_ptr>*, DMatrix *dmat); + + template + void CallBuilderUpdate(const std::unique_ptr>& builder, + TrainParam const *param, + HostDeviceVector *gpair, + DMatrix *dmat, + common::Span> out_position, + const std::vector &trees); + + protected: + std::unique_ptr> float_builder_; + std::unique_ptr> double_builder_; + + std::unique_ptr pruner_; + FeatureInteractionConstraintHost int_constraint_; + + sycl::queue qu_; + DeviceManagerOneAPI device_manager; + Context const* ctx_; + ObjInfo const *task_{nullptr}; +}; + +template +class HistSynchronizerOneAPI { + public: + using BuilderT = QuantileHistMakerOneAPIBackend::Builder; + + virtual void SyncHistograms(BuilderT* builder, + std::vector& sync_ids, + RegTree *p_tree) = 0; + virtual ~HistSynchronizerOneAPI() = default; +}; + +template +class BatchHistSynchronizerOneAPI: public HistSynchronizerOneAPI { + public: + using BuilderT = QuantileHistMakerOneAPIBackend::Builder; + void SyncHistograms(BuilderT* builder, + std::vector& sync_ids, + RegTree *p_tree) override; + + std::vector GetEvents() const { + return hist_sync_events_; + } + + private: + std::vector hist_sync_events_; +}; + +template +class DistributedHistSynchronizerOneAPI: public HistSynchronizerOneAPI { + public: + using BuilderT = QuantileHistMakerOneAPIBackend::Builder; + using ExpandEntryT = typename BuilderT::ExpandEntry; + + void SyncHistograms(BuilderT* builder, std::vector& sync_ids, RegTree *p_tree) override; + + void 
ParallelSubtractionHist(BuilderT* builder, + const std::vector& nodes, + const RegTree * p_tree); +}; + +template +class HistRowsAdderOneAPI { + public: + using BuilderT = QuantileHistMakerOneAPIBackend::Builder; + + virtual void AddHistRows(BuilderT* builder, std::vector& sync_ids, RegTree *p_tree) = 0; + virtual ~HistRowsAdderOneAPI() = default; +}; + +template +class BatchHistRowsAdderOneAPI: public HistRowsAdderOneAPI { + public: + using BuilderT = QuantileHistMakerOneAPIBackend::Builder; + void AddHistRows(BuilderT*, std::vector& sync_ids, RegTree *p_tree) override; +}; + +template +class DistributedHistRowsAdderOneAPI: public HistRowsAdderOneAPI { + public: + using BuilderT = QuantileHistMakerOneAPIBackend::Builder; + void AddHistRows(BuilderT*, std::vector& sync_ids, RegTree *p_tree) override; +}; + + +} // namespace tree +} // namespace xgboost + +#endif // XGBOOST_TREE_UPDATER_QUANTILE_HIST_ONEAPI_H_ diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py index e791be51c0e4..6cc3822d760c 100644 --- a/python-package/xgboost/sklearn.py +++ b/python-package/xgboost/sklearn.py @@ -949,7 +949,12 @@ def _duplicated(parameter: str) -> None: def _create_dmatrix(self, ref: Optional[DMatrix], **kwargs: Any) -> DMatrix: # Use `QuantileDMatrix` to save memory. - if _can_use_qdm(self.tree_method) and self.booster != "gblinear": + is_sycl = self.device is not None and self.device.startswith("sycl") + if ( + _can_use_qdm(self.tree_method) + and self.booster != "gblinear" + and not is_sycl + ): try: return QuantileDMatrix( **kwargs, ref=ref, nthread=self.n_jobs, max_bin=self.max_bin diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 4624c643c48c..f6ffc795f2ca 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -16,6 +16,10 @@ if (USE_CUDA) target_sources(objxgboost PRIVATE ${CUDA_SOURCES}) endif (USE_CUDA) +if (PLUGIN_UPDATER_ONEAPI) + target_compile_definitions(objxgboost PRIVATE -DXGBOOST_USE_ONEAPI=1) +endif (PLUGIN_UPDATER_ONEAPI) + target_include_directories(objxgboost PRIVATE ${xgboost_SOURCE_DIR}/include diff --git a/src/common/linalg_op.cuh b/src/common/linalg_op.cuh index 037ad1ff3059..ce9035ba0220 100644 --- a/src/common/linalg_op.cuh +++ b/src/common/linalg_op.cuh @@ -42,7 +42,7 @@ void ElementWiseTransformDevice(linalg::TensorView t, Fn&& fn, cudaStream_ template void ElementWiseKernel(Context const* ctx, linalg::TensorView t, Fn&& fn) { - ctx->IsCPU() ? ElementWiseKernelHost(t, ctx->Threads(), fn) : ElementWiseKernelDevice(t, fn); + ctx->IsCUDA() ? 
ElementWiseKernelDevice(t, fn) : ElementWiseKernelHost(t, ctx->Threads(), fn); } } // namespace linalg } // namespace xgboost diff --git a/src/common/linalg_op.h b/src/common/linalg_op.h index f55927402d31..d89e5a736b6e 100644 --- a/src/common/linalg_op.h +++ b/src/common/linalg_op.h @@ -55,7 +55,7 @@ void ElementWiseTransformDevice(linalg::TensorView, Fn&&, void* = nullptr) template void ElementWiseKernel(Context const* ctx, linalg::TensorView t, Fn&& fn) { - if (!ctx->IsCPU()) { + if (ctx->IsCUDA()) { common::AssertGPUSupport(); } ElementWiseKernelHost(t, ctx->Threads(), fn); diff --git a/src/common/numeric.cc b/src/common/numeric.cc index 240e0234ad25..550c533e0e98 100644 --- a/src/common/numeric.cc +++ b/src/common/numeric.cc @@ -11,13 +11,14 @@ namespace xgboost { namespace common { double Reduce(Context const* ctx, HostDeviceVector const& values) { - if (ctx->IsCPU()) { + if (ctx->IsCUDA()) { + return cuda_impl::Reduce(ctx, values); + } else { auto const& h_values = values.ConstHostVector(); auto result = cpu_impl::Reduce(ctx, h_values.cbegin(), h_values.cend(), 0.0); static_assert(std::is_same::value); return result; } - return cuda_impl::Reduce(ctx, values); } } // namespace common } // namespace xgboost diff --git a/src/common/optional_weight.h b/src/common/optional_weight.h index c2844d73f893..40c10a27e4a8 100644 --- a/src/common/optional_weight.h +++ b/src/common/optional_weight.h @@ -26,7 +26,7 @@ inline OptionalWeights MakeOptionalWeights(Context const* ctx, if (ctx->IsCUDA()) { weights.SetDevice(ctx->gpu_id); } - return OptionalWeights{ctx->IsCPU() ? weights.ConstHostSpan() : weights.ConstDeviceSpan()}; + return OptionalWeights{ctx->IsCUDA() ? weights.ConstDeviceSpan() : weights.ConstHostSpan()}; } } // namespace xgboost::common #endif // XGBOOST_COMMON_OPTIONAL_WEIGHT_H_ diff --git a/src/common/ranking_utils.h b/src/common/ranking_utils.h index 75622bd84b60..8ea19e8694d3 100644 --- a/src/common/ranking_utils.h +++ b/src/common/ranking_utils.h @@ -197,10 +197,10 @@ class RankingCache { CHECK_EQ(info.group_ptr_.back(), info.labels.Size()) << error::GroupSize() << "the size of label."; } - if (ctx->IsCPU()) { - this->InitOnCPU(ctx, info); - } else { + if (ctx->IsCUDA()) { this->InitOnCUDA(ctx, info); + } else { + this->InitOnCPU(ctx, info); } if (!info.weights_.Empty()) { CHECK_EQ(Groups(), info.weights_.Size()) << error::GroupWeight(); @@ -218,7 +218,7 @@ class RankingCache { // Constructed as [1, n_samples] if group ptr is not supplied by the user common::Span DataGroupPtr(Context const* ctx) const { group_ptr_.SetDevice(ctx->gpu_id); - return ctx->IsCPU() ? group_ptr_.ConstHostSpan() : group_ptr_.ConstDeviceSpan(); + return ctx->IsCUDA() ? 
group_ptr_.ConstDeviceSpan() : group_ptr_.ConstHostSpan(); } [[nodiscard]] auto const& Param() const { return param_; } @@ -231,10 +231,10 @@ class RankingCache { sorted_idx_cache_.SetDevice(ctx->gpu_id); sorted_idx_cache_.Resize(predt.size()); } - if (ctx->IsCPU()) { - return this->MakeRankOnCPU(ctx, predt); - } else { + if (ctx->IsCUDA()) { return this->MakeRankOnCUDA(ctx, predt); + } else { + return this->MakeRankOnCPU(ctx, predt); } } // The function simply returns a uninitialized buffer as this is only used by the @@ -307,10 +307,10 @@ class NDCGCache : public RankingCache { public: NDCGCache(Context const* ctx, MetaInfo const& info, LambdaRankParam const& p) : RankingCache{ctx, info, p} { - if (ctx->IsCPU()) { - this->InitOnCPU(ctx, info); - } else { + if (ctx->IsCUDA()) { this->InitOnCUDA(ctx, info); + } else { + this->InitOnCPU(ctx, info); } } @@ -318,7 +318,7 @@ class NDCGCache : public RankingCache { return inv_idcg_.View(ctx->gpu_id); } common::Span Discount(Context const* ctx) const { - return ctx->IsCPU() ? discounts_.ConstHostSpan() : discounts_.ConstDeviceSpan(); + return ctx->IsCUDA() ? discounts_.ConstDeviceSpan() : discounts_.ConstHostSpan(); } linalg::VectorView Dcg(Context const* ctx) { if (dcg_.Size() == 0) { @@ -387,10 +387,10 @@ class PreCache : public RankingCache { public: PreCache(Context const* ctx, MetaInfo const& info, LambdaRankParam const& p) : RankingCache{ctx, info, p} { - if (ctx->IsCPU()) { - this->InitOnCPU(ctx, info); - } else { + if (ctx->IsCUDA()) { this->InitOnCUDA(ctx, info); + } else { + this->InitOnCPU(ctx, info); } } @@ -399,7 +399,7 @@ class PreCache : public RankingCache { pre_.SetDevice(ctx->gpu_id); pre_.Resize(this->Groups()); } - return ctx->IsCPU() ? pre_.HostSpan() : pre_.DeviceSpan(); + return ctx->IsCUDA() ? pre_.DeviceSpan() : pre_.HostSpan(); } }; @@ -418,10 +418,10 @@ class MAPCache : public RankingCache { public: MAPCache(Context const* ctx, MetaInfo const& info, LambdaRankParam const& p) : RankingCache{ctx, info, p}, n_samples_{static_cast(info.num_row_)} { - if (ctx->IsCPU()) { - this->InitOnCPU(ctx, info); - } else { + if (ctx->IsCUDA()) { this->InitOnCUDA(ctx, info); + } else { + this->InitOnCPU(ctx, info); } } @@ -430,21 +430,21 @@ class MAPCache : public RankingCache { n_rel_.SetDevice(ctx->gpu_id); n_rel_.Resize(n_samples_); } - return ctx->IsCPU() ? n_rel_.HostSpan() : n_rel_.DeviceSpan(); + return ctx->IsCUDA() ? n_rel_.DeviceSpan() : n_rel_.HostSpan(); } common::Span Acc(Context const* ctx) { if (acc_.Empty()) { acc_.SetDevice(ctx->gpu_id); acc_.Resize(n_samples_); } - return ctx->IsCPU() ? acc_.HostSpan() : acc_.DeviceSpan(); + return ctx->IsCUDA() ? acc_.DeviceSpan() : acc_.HostSpan(); } common::Span Map(Context const* ctx) { if (map_.Empty()) { map_.SetDevice(ctx->gpu_id); map_.Resize(this->Groups()); } - return ctx->IsCPU() ? map_.HostSpan() : map_.DeviceSpan(); + return ctx->IsCUDA() ? 
map_.DeviceSpan() : map_.HostSpan(); } }; diff --git a/src/common/stats.cc b/src/common/stats.cc index 80fc2c50d5a5..c34ab2c0f238 100644 --- a/src/common/stats.cc +++ b/src/common/stats.cc @@ -19,7 +19,7 @@ namespace xgboost { namespace common { void Median(Context const* ctx, linalg::Tensor const& t, HostDeviceVector const& weights, linalg::Tensor* out) { - if (!ctx->IsCPU()) { + if (ctx->IsCUDA()) { weights.SetDevice(ctx->gpu_id); auto opt_weights = OptionalWeights(weights.ConstDeviceSpan()); auto t_v = t.View(ctx->gpu_id); @@ -50,7 +50,9 @@ void Mean(Context const* ctx, linalg::Vector const& v, linalg::VectorSetDevice(ctx->gpu_id); out->Reshape(1); - if (ctx->IsCPU()) { + if (ctx->IsCUDA()) { + cuda_impl::Mean(ctx, v.View(ctx->gpu_id), out->View(ctx->gpu_id)); + } else { auto h_v = v.HostView(); float n = v.Size(); MemStackAllocator tloc(ctx->Threads(), 0.0f); @@ -58,8 +60,6 @@ void Mean(Context const* ctx, linalg::Vector const& v, linalg::VectorHostView()(0) = ret; - } else { - cuda_impl::Mean(ctx, v.View(ctx->gpu_id), out->View(ctx->gpu_id)); } } } // namespace common diff --git a/src/context.cc b/src/context.cc index 1acaa6443da1..ff8fcb9b084b 100644 --- a/src/context.cc +++ b/src/context.cc @@ -95,9 +95,13 @@ DeviceOrd CUDAOrdinal(DeviceOrd device, bool) { StringView msg{R"(Invalid argument for `device`. Expected to be one of the following: - cpu - cuda -- cuda: # e.g. cuda:0 +- cuda: # e.g. cuda:0 - gpu -- gpu: # e.g. gpu:0 +- gpu: # e.g. gpu:0 +- sycl +- sycl: # e.g. sycl:0 +- sycl: +- sycl:: # e.g. sycl:gpu:0 )"}; auto fatal = [&] { LOG(FATAL) << msg << "Got: `" << input << "`."; }; @@ -105,19 +109,32 @@ DeviceOrd CUDAOrdinal(DeviceOrd device, bool) { // mingw hangs on regex using rtools 430. Basic checks only. CHECK_GE(input.size(), 3) << msg; auto substr = input.substr(0, 3); - bool valid = substr == "cpu" || substr == "cud" || substr == "gpu"; + bool valid = substr == "cpu" || substr == "cud" || substr == "gpu" || substr == "syc"; CHECK(valid) << msg; #else - std::regex pattern{"gpu(:[0-9]+)?|cuda(:[0-9]+)?|cpu"}; + std::regex pattern{"gpu(:[0-9]+)?|cuda(:[0-9]+)?|cpu|sycl(:cpu|:gpu)?(:-1|:[0-9]+)?"}; if (!std::regex_match(input, pattern)) { + LOG(FATAL) << "device doesn't match regex pattern"; fatal(); } #endif // defined(__MINGW32__) // handle alias - std::string s_device = std::regex_replace(input, std::regex{"gpu"}, DeviceSym::CUDA()); + std::string s_device = input; + if (!std::regex_match(s_device, std::regex("sycl(:cpu|:gpu)?(:-1|:[0-9]+)?"))) + s_device = std::regex_replace(s_device, std::regex{"gpu"}, DeviceSym::CUDA()); auto split_it = std::find(s_device.cbegin(), s_device.cend(), ':'); + if (std::regex_match(s_device, std::regex("sycl:(cpu|gpu)?"))) split_it = s_device.cend(); + + // For s_device like "sycl:gpu:1" + if (split_it != s_device.cend()) { + auto second_split_it = std::find(split_it + 1, s_device.cend(), ':'); + if (second_split_it != s_device.cend()) { + split_it = second_split_it; + } + } + DeviceOrd device; device.ordinal = Context::InvalidOrdinal(); // mark it invalid for check. 
if (split_it == s_device.cend()) { @@ -126,31 +143,52 @@ DeviceOrd CUDAOrdinal(DeviceOrd device, bool) { device = DeviceOrd::CPU(); } else if (s_device == DeviceSym::CUDA()) { device = DeviceOrd::CUDA(0); // use 0 as default; + } else if (s_device == DeviceSym::SYCL_default()) { + device = DeviceOrd::SYCL_default(); + } else if (s_device == DeviceSym::SYCL_CPU()) { + device = DeviceOrd::SYCL_CPU(); + } else if (s_device == DeviceSym::SYCL_GPU()) { + device = DeviceOrd::SYCL_GPU(); } else { + LOG(FATAL) << "device doesn't match switch statement"; fatal(); } } else { - // must be CUDA when ordinal is specifed. + // must be CUDA or SYCL when ordinal is specifed. // +1 for colon std::size_t offset = std::distance(s_device.cbegin(), split_it) + 1; // substr StringView s_ordinal = {s_device.data() + offset, s_device.size() - offset}; + StringView s_type = {s_device.data(), offset - 1}; if (s_ordinal.empty()) { + LOG(FATAL) << "Problem" << s_device.cend() - split_it; fatal(); } auto opt_id = ParseInt(s_ordinal); if (!opt_id.has_value()) { + LOG(FATAL) << "Problem2" << s_device.cend() - split_it; fatal(); } CHECK_LE(opt_id.value(), std::numeric_limits::max()) << "Ordinal value too large."; - device = DeviceOrd::CUDA(opt_id.value()); + if (s_type == DeviceSym::SYCL_default()) { + device = DeviceOrd::SYCL_default(opt_id.value()); + } else if (s_type == DeviceSym::SYCL_CPU()) { + device = DeviceOrd::SYCL_CPU(opt_id.value()); + } else if (s_type == DeviceSym::SYCL_GPU()) { + device = DeviceOrd::SYCL_GPU(opt_id.value()); + } else { + device = DeviceOrd::CUDA(opt_id.value()); + } } if (device.ordinal < Context::kCpuId) { + LOG(FATAL) << "Wrong device.ordinal" << device.ordinal; fatal(); } - device = CUDAOrdinal(device, fail_on_invalid_gpu_id); + if (device.IsCUDA()) { + device = CUDAOrdinal(device, fail_on_invalid_gpu_id); + } return device; } @@ -195,7 +233,7 @@ void Context::SetDeviceOrdinal(Args const& kwargs) { if (this->IsCPU()) { CHECK_EQ(this->device_.ordinal, kCpuId); - } else { + } else if (this->IsCUDA()) { CHECK_GT(this->device_.ordinal, kCpuId); } } diff --git a/src/gbm/gbtree.cc b/src/gbm/gbtree.cc index 53ff118439b1..cd1567dc0d7f 100644 --- a/src/gbm/gbtree.cc +++ b/src/gbm/gbtree.cc @@ -52,7 +52,8 @@ std::string MapTreeMethodToUpdaters(Context const* ctx, TreeMethod tree_method) case TreeMethod::kAuto: // Use hist as default in 2.0 case TreeMethod::kHist: { return ctx->DispatchDevice([] { return "grow_quantile_histmaker"; }, - [] { return "grow_gpu_hist"; }); + [] { return "grow_gpu_hist"; }, + [] { return "grow_quantile_histmaker_oneapi"; }); } case TreeMethod::kApprox: { return ctx->DispatchDevice([] { return "grow_histmaker"; }, [] { return "grow_gpu_approx"; }); @@ -115,10 +116,10 @@ void GBTree::Configure(Args const& cfg) { #if defined(XGBOOST_USE_ONEAPI) if (!oneapi_predictor_) { - oneapi_predictor_ = - std::unique_ptr(Predictor::Create("oneapi_predictor", this->ctx_)); + oneapi_predictor_ = + std::unique_ptr(Predictor::Create("oneapi_predictor", this->ctx_)); } - oneapi_predictor_->Configure(cfg); + oneapi_predictor_->Configure(cfg); #endif // defined(XGBOOST_USE_ONEAPI) // `updater` parameter was manually specified @@ -565,11 +566,18 @@ void GBTree::InplacePredict(std::shared_ptr p_m, float missing, if (f_dmat && !f_dmat->SingleColBlock()) { if (ctx_->IsCPU()) { return cpu_predictor_; - } else { + } else if (ctx_->IsCUDA()) { common::AssertGPUSupport(); CHECK(gpu_predictor_); return gpu_predictor_; + } else { +#if defined(XGBOOST_USE_ONEAPI) + common::AssertOneAPISupport(); + 
CHECK(oneapi_predictor_); + return oneapi_predictor_; +#endif // defined(XGBOOST_USE_ONEAPI) } + } // Data comes from Device DMatrix. @@ -603,10 +611,16 @@ void GBTree::InplacePredict(std::shared_ptr p_m, float missing, if (ctx_->IsCPU()) { return cpu_predictor_; - } else { + } else if (ctx_->IsCUDA()) { common::AssertGPUSupport(); CHECK(gpu_predictor_); return gpu_predictor_; + } else { +#if defined(XGBOOST_USE_ONEAPI) + common::AssertOneAPISupport(); + CHECK(oneapi_predictor_); + return oneapi_predictor_; +#endif // defined(XGBOOST_USE_ONEAPI) } return cpu_predictor_; diff --git a/src/learner.cc b/src/learner.cc index 51f86aa67013..838c4d01f1b8 100644 --- a/src/learner.cc +++ b/src/learner.cc @@ -278,7 +278,7 @@ LearnerModelParam::LearnerModelParam(Context const* ctx, LearnerModelParamLegacy std::swap(base_score_, base_margin); // Make sure read access everywhere for thread-safe prediction. std::as_const(base_score_).HostView(); - if (!ctx->IsCPU()) { + if (ctx->IsCUDA()) { std::as_const(base_score_).View(ctx->gpu_id); } CHECK(std::as_const(base_score_).Data()->HostCanRead()); @@ -776,7 +776,8 @@ class LearnerConfiguration : public Learner { void ConfigureObjective(LearnerTrainParam const& old, Args* p_args) { // Once binary IO is gone, NONE of these config is useful. if (cfg_.find("num_class") != cfg_.cend() && cfg_.at("num_class") != "0" && - tparam_.objective != "multi:softprob") { + (tparam_.objective != "multi:softprob") && + (tparam_.objective != "multi:softprob_oneapi")) { cfg_["num_output_group"] = cfg_["num_class"]; if (atoi(cfg_["num_class"].c_str()) > 1 && cfg_.count("objective") == 0) { tparam_.objective = "multi:softmax"; diff --git a/src/metric/elementwise_metric.cu b/src/metric/elementwise_metric.cu index b6888610b586..c0afeca80b3d 100644 --- a/src/metric/elementwise_metric.cu +++ b/src/metric/elementwise_metric.cu @@ -46,7 +46,26 @@ template PackedReduceResult Reduce(Context const* ctx, MetaInfo const& info, Fn&& loss) { PackedReduceResult result; auto labels = info.labels.View(ctx->gpu_id); - if (ctx->IsCPU()) { + if (ctx->IsCUDA()) { + #if defined(XGBOOST_USE_CUDA) + dh::XGBCachingDeviceAllocator alloc; + thrust::counting_iterator begin(0); + thrust::counting_iterator end = begin + labels.Size(); + result = thrust::transform_reduce( + thrust::cuda::par(alloc), begin, end, + [=] XGBOOST_DEVICE(size_t i) { + auto idx = linalg::UnravelIndex(i, labels.Shape()); + auto sample_id = std::get<0>(idx); + auto target_id = std::get<1>(idx); + auto res = loss(i, sample_id, target_id); + float v{std::get<0>(res)}, wt{std::get<1>(res)}; + return PackedReduceResult{v, wt}; + }, + PackedReduceResult{}, thrust::plus()); + #else + common::AssertGPUSupport(); + #endif // defined(XGBOOST_USE_CUDA) + } else { auto n_threads = ctx->Threads(); std::vector score_tloc(n_threads, 0.0); std::vector weight_tloc(n_threads, 0.0); @@ -69,25 +88,6 @@ PackedReduceResult Reduce(Context const* ctx, MetaInfo const& info, Fn&& loss) { double residue_sum = std::accumulate(score_tloc.cbegin(), score_tloc.cend(), 0.0); double weights_sum = std::accumulate(weight_tloc.cbegin(), weight_tloc.cend(), 0.0); result = PackedReduceResult{residue_sum, weights_sum}; - } else { -#if defined(XGBOOST_USE_CUDA) - dh::XGBCachingDeviceAllocator alloc; - thrust::counting_iterator begin(0); - thrust::counting_iterator end = begin + labels.Size(); - result = thrust::transform_reduce( - thrust::cuda::par(alloc), begin, end, - [=] XGBOOST_DEVICE(size_t i) { - auto idx = linalg::UnravelIndex(i, labels.Shape()); - auto sample_id = 
std::get<0>(idx); - auto target_id = std::get<1>(idx); - auto res = loss(i, sample_id, target_id); - float v{std::get<0>(res)}, wt{std::get<1>(res)}; - return PackedReduceResult{v, wt}; - }, - PackedReduceResult{}, thrust::plus()); -#else - common::AssertGPUSupport(); -#endif // defined(XGBOOST_USE_CUDA) } return result; } @@ -185,10 +185,10 @@ class PseudoErrorLoss : public MetricNoCache { CHECK_EQ(info.labels.Shape(0), info.num_row_); auto labels = info.labels.View(ctx_->gpu_id); preds.SetDevice(ctx_->gpu_id); - auto predts = ctx_->IsCPU() ? preds.ConstHostSpan() : preds.ConstDeviceSpan(); + auto predts = ctx_->IsCUDA() ? preds.ConstDeviceSpan() : preds.ConstHostSpan(); info.weights_.SetDevice(ctx_->gpu_id); - common::OptionalWeights weights(ctx_->IsCPU() ? info.weights_.ConstHostSpan() - : info.weights_.ConstDeviceSpan()); + common::OptionalWeights weights(ctx_->IsCUDA() ? info.weights_.ConstDeviceSpan() + : info.weights_.ConstHostSpan()); float slope = this->param_.huber_slope; CHECK_NE(slope, 0.0) << "slope for pseudo huber cannot be 0."; PackedReduceResult result = @@ -349,12 +349,13 @@ struct EvalEWiseBase : public MetricNoCache { if (info.labels.Size() != 0) { CHECK_NE(info.labels.Shape(1), 0); } + LOG(INFO) << "EvalEWiseBase::Eval 0"; auto labels = info.labels.View(ctx_->gpu_id); info.weights_.SetDevice(ctx_->gpu_id); - common::OptionalWeights weights(ctx_->IsCPU() ? info.weights_.ConstHostSpan() - : info.weights_.ConstDeviceSpan()); + common::OptionalWeights weights(ctx_->IsCUDA() ? info.weights_.ConstDeviceSpan() + : info.weights_.ConstHostSpan()); preds.SetDevice(ctx_->gpu_id); - auto predts = ctx_->IsCPU() ? preds.ConstHostSpan() : preds.ConstDeviceSpan(); + auto predts = ctx_->IsCUDA() ? preds.ConstDeviceSpan() : preds.ConstHostSpan(); auto d_policy = policy_; auto result = diff --git a/src/objective/adaptive.h b/src/objective/adaptive.h index ffd3ddec7201..b9fcc1793d3d 100644 --- a/src/objective/adaptive.h +++ b/src/objective/adaptive.h @@ -96,13 +96,13 @@ void UpdateTreeLeafHost(Context const* ctx, std::vector const& posit inline void UpdateTreeLeaf(Context const* ctx, HostDeviceVector const& position, std::int32_t group_idx, MetaInfo const& info, float learning_rate, HostDeviceVector const& predt, float alpha, RegTree* p_tree) { - if (ctx->IsCPU()) { - detail::UpdateTreeLeafHost(ctx, position.ConstHostVector(), group_idx, info, learning_rate, - predt, alpha, p_tree); - } else { + if (ctx->IsCUDA()) { position.SetDevice(ctx->gpu_id); detail::UpdateTreeLeafDevice(ctx, position.ConstDeviceSpan(), group_idx, info, learning_rate, predt, alpha, p_tree); + } else { + detail::UpdateTreeLeafHost(ctx, position.ConstHostVector(), group_idx, info, learning_rate, + predt, alpha, p_tree); } } } // namespace obj diff --git a/src/objective/lambdarank_obj.cc b/src/objective/lambdarank_obj.cc index d0ff5bda5bde..2bec41c6efc9 100644 --- a/src/objective/lambdarank_obj.cc +++ b/src/objective/lambdarank_obj.cc @@ -108,14 +108,14 @@ class LambdaRankObj : public FitIntercept { li_.SetDevice(ctx_->gpu_id); lj_.SetDevice(ctx_->gpu_id); - if (ctx_->IsCPU()) { - cpu_impl::LambdaRankUpdatePositionBias(ctx_, li_full_.View(ctx_->gpu_id), - lj_full_.View(ctx_->gpu_id), &ti_plus_, &tj_minus_, - &li_, &lj_, p_cache_); - } else { + if (ctx_->IsCUDA()) { cuda_impl::LambdaRankUpdatePositionBias(ctx_, li_full_.View(ctx_->gpu_id), lj_full_.View(ctx_->gpu_id), &ti_plus_, &tj_minus_, &li_, &lj_, p_cache_); + } else { + cpu_impl::LambdaRankUpdatePositionBias(ctx_, li_full_.View(ctx_->gpu_id), + 
lj_full_.View(ctx_->gpu_id), &ti_plus_, &tj_minus_, + &li_, &lj_, p_cache_); } li_full_.Data()->Fill(0.0); diff --git a/src/objective/objective.cc b/src/objective/objective.cc index 85cd9803d4ef..88d392ad6309 100644 --- a/src/objective/objective.cc +++ b/src/objective/objective.cc @@ -18,13 +18,20 @@ DMLC_REGISTRY_ENABLE(::xgboost::ObjFunctionReg); namespace xgboost { // implement factory functions ObjFunction* ObjFunction::Create(const std::string& name, Context const* ctx) { - auto *e = ::dmlc::Registry< ::xgboost::ObjFunctionReg>::Get()->Find(name); + std::string replaced_name = name; + if (ctx->IsSycl()) { + auto *e = ::dmlc::Registry< ::xgboost::ObjFunctionReg>::Get()->Find(name + "_oneapi"); + if (e != nullptr) { + replaced_name += "_oneapi"; + } + } + auto *e = ::dmlc::Registry< ::xgboost::ObjFunctionReg>::Get()->Find(replaced_name); if (e == nullptr) { std::stringstream ss; for (const auto& entry : ::dmlc::Registry< ::xgboost::ObjFunctionReg>::List()) { ss << "Objective candidate: " << entry->name << "\n"; } - LOG(FATAL) << "Unknown objective function: `" << name << "`\n" + LOG(FATAL) << "Unknown objective function: `" << replaced_name << "`\n" << ss.str(); } auto pobj = (e->body)(); diff --git a/src/objective/quantile_obj.cu b/src/objective/quantile_obj.cu index f94b5edf0494..167ed82c27d7 100644 --- a/src/objective/quantile_obj.cu +++ b/src/objective/quantile_obj.cu @@ -71,15 +71,15 @@ class QuantileRegression : public ObjFunction { linalg::MakeTensorView(ctx_, out_gpair, info.num_row_, n_alphas, n_targets / n_alphas); info.weights_.SetDevice(ctx_->gpu_id); - common::OptionalWeights weight{ctx_->IsCPU() ? info.weights_.ConstHostSpan() - : info.weights_.ConstDeviceSpan()}; + common::OptionalWeights weight{ctx_->IsCUDA() ? info.weights_.ConstDeviceSpan() + : info.weights_.ConstHostSpan()}; preds.SetDevice(ctx_->gpu_id); auto predt = linalg::MakeVec(&preds); auto n_samples = info.num_row_; alpha_.SetDevice(ctx_->gpu_id); - auto alpha = ctx_->IsCPU() ? alpha_.ConstHostSpan() : alpha_.ConstDeviceSpan(); + auto alpha = ctx_->IsCUDA() ? 
alpha_.ConstDeviceSpan() : alpha_.ConstHostSpan(); linalg::ElementWiseKernel( ctx_, gpair, [=] XGBOOST_DEVICE(std::size_t i, GradientPair const&) mutable { @@ -106,27 +106,7 @@ class QuantileRegression : public ObjFunction { base_score->Reshape(n_targets); double sw{0}; - if (ctx_->IsCPU()) { - auto quantiles = base_score->HostView(); - auto h_weights = info.weights_.ConstHostVector(); - if (info.weights_.Empty()) { - sw = info.num_row_; - } else { - sw = std::accumulate(std::cbegin(h_weights), std::cend(h_weights), 0.0); - } - for (bst_target_t t{0}; t < n_targets; ++t) { - auto alpha = param_.quantile_alpha[t]; - auto h_labels = info.labels.HostView(); - if (h_weights.empty()) { - quantiles(t) = - common::Quantile(ctx_, alpha, linalg::cbegin(h_labels), linalg::cend(h_labels)); - } else { - CHECK_EQ(h_weights.size(), h_labels.Size()); - quantiles(t) = common::WeightedQuantile(ctx_, alpha, linalg::cbegin(h_labels), - linalg::cend(h_labels), std::cbegin(h_weights)); - } - } - } else { + if (ctx_->IsCUDA()) { #if defined(XGBOOST_USE_CUDA) alpha_.SetDevice(ctx_->gpu_id); auto d_alpha = alpha_.ConstDeviceSpan(); @@ -163,6 +143,26 @@ class QuantileRegression : public ObjFunction { #else common::AssertGPUSupport(); #endif // defined(XGBOOST_USE_CUDA) + } else { + auto quantiles = base_score->HostView(); + auto h_weights = info.weights_.ConstHostVector(); + if (info.weights_.Empty()) { + sw = info.num_row_; + } else { + sw = std::accumulate(std::cbegin(h_weights), std::cend(h_weights), 0.0); + } + for (bst_target_t t{0}; t < n_targets; ++t) { + auto alpha = param_.quantile_alpha[t]; + auto h_labels = info.labels.HostView(); + if (h_weights.empty()) { + quantiles(t) = + common::Quantile(ctx_, alpha, linalg::cbegin(h_labels), linalg::cend(h_labels)); + } else { + CHECK_EQ(h_weights.size(), h_labels.Size()); + quantiles(t) = common::WeightedQuantile(ctx_, alpha, linalg::cbegin(h_labels), + linalg::cend(h_labels), std::cbegin(h_weights)); + } + } } // For multiple quantiles, we should extend the base score to a vector instead of diff --git a/src/objective/regression_obj.cu b/src/objective/regression_obj.cu index a1a773f5340a..3a6a459af60e 100644 --- a/src/objective/regression_obj.cu +++ b/src/objective/regression_obj.cu @@ -235,8 +235,8 @@ class PseudoHuberRegression : public FitIntercept { auto predt = linalg::MakeVec(&preds); info.weights_.SetDevice(ctx_->gpu_id); - common::OptionalWeights weight{ctx_->IsCPU() ? info.weights_.ConstHostSpan() - : info.weights_.ConstDeviceSpan()}; + common::OptionalWeights weight{ctx_->IsCUDA() ? info.weights_.ConstDeviceSpan() + : info.weights_.ConstHostSpan()}; linalg::ElementWiseKernel(ctx_, labels, [=] XGBOOST_DEVICE(size_t i, float const y) mutable { auto sample_id = std::get<0>(linalg::UnravelIndex(i, labels.Shape())); @@ -696,8 +696,8 @@ class MeanAbsoluteError : public ObjFunction { preds.SetDevice(ctx_->gpu_id); auto predt = linalg::MakeVec(&preds); info.weights_.SetDevice(ctx_->gpu_id); - common::OptionalWeights weight{ctx_->IsCPU() ? info.weights_.ConstHostSpan() - : info.weights_.ConstDeviceSpan()}; + common::OptionalWeights weight{ctx_->IsCUDA() ? 
info.weights_.ConstDeviceSpan() + : info.weights_.ConstHostSpan()}; linalg::ElementWiseKernel(ctx_, labels, [=] XGBOOST_DEVICE(size_t i, float const y) mutable { auto sign = [](auto x) { diff --git a/src/tree/fit_stump.cc b/src/tree/fit_stump.cc index 3533de772f59..4fd32823002d 100644 --- a/src/tree/fit_stump.cc +++ b/src/tree/fit_stump.cc @@ -74,8 +74,8 @@ void FitStump(Context const* ctx, MetaInfo const& info, HostDeviceVectorgpu_id); auto gpair_t = linalg::MakeTensorView(ctx, &gpair, n_samples, n_targets); - ctx->IsCPU() ? cpu_impl::FitStump(ctx, info, gpair_t, out->HostView()) - : cuda_impl::FitStump(ctx, gpair_t, out->View(ctx->gpu_id)); + ctx->IsCUDA() ? cuda_impl::FitStump(ctx, gpair_t, out->View(ctx->gpu_id)) + : cpu_impl::FitStump(ctx, info, gpair_t, out->HostView()); } } // namespace tree } // namespace xgboost diff --git a/tests/python-oneapi/test_oneapi_prediction.py b/tests/python-oneapi/test_oneapi_prediction.py new file mode 100644 index 000000000000..6eee485b618b --- /dev/null +++ b/tests/python-oneapi/test_oneapi_prediction.py @@ -0,0 +1,151 @@ +import sys +import unittest +import pytest + +import numpy as np +import xgboost as xgb +from hypothesis import given, strategies, assume, settings, note + +from xgboost import testing as tm + +rng = np.random.RandomState(1994) + +shap_parameter_strategy = strategies.fixed_dictionaries({ + 'max_depth': strategies.integers(1, 11), + 'max_leaves': strategies.integers(0, 256), + 'num_parallel_tree': strategies.sampled_from([1, 10]), +}).filter(lambda x: x['max_depth'] > 0 or x['max_leaves'] > 0) + + +class TestOneAPIPredict(unittest.TestCase): + def test_predict(self): + iterations = 10 + np.random.seed(1) + test_num_rows = [10, 1000, 5000] + test_num_cols = [10, 50, 500] + for num_rows in test_num_rows: + for num_cols in test_num_cols: + dtrain = xgb.DMatrix(np.random.randn(num_rows, num_cols), + label=[0, 1] * int(num_rows / 2)) + dval = xgb.DMatrix(np.random.randn(num_rows, num_cols), + label=[0, 1] * int(num_rows / 2)) + dtest = xgb.DMatrix(np.random.randn(num_rows, num_cols), + label=[0, 1] * int(num_rows / 2)) + watchlist = [(dtrain, 'train'), (dval, 'validation')] + res = {} + param = { + "objective": "binary:logistic", + 'eval_metric': 'logloss', + 'tree_method': 'hist', + 'device': 'cpu', + 'max_depth': 1, + 'verbosity': 0 + } + bst = xgb.train(param, dtrain, iterations, evals=watchlist, + evals_result=res) + assert self.non_increasing(res["train"]["logloss"]) + cpu_pred_train = bst.predict(dtrain, output_margin=True) + cpu_pred_test = bst.predict(dtest, output_margin=True) + cpu_pred_val = bst.predict(dval, output_margin=True) + + bst.set_param({"device": "sycl:gpu"}) + oneapi_pred_train = bst.predict(dtrain, output_margin=True) + oneapi_pred_test = bst.predict(dtest, output_margin=True) + oneapi_pred_val = bst.predict(dval, output_margin=True) + + np.testing.assert_allclose(cpu_pred_train, oneapi_pred_train, + rtol=1e-6) + np.testing.assert_allclose(cpu_pred_val, oneapi_pred_val, + rtol=1e-6) + np.testing.assert_allclose(cpu_pred_test, oneapi_pred_test, + rtol=1e-6) + + def non_increasing(self, L): + return all((y - x) < 0.001 for x, y in zip(L, L[1:])) + + @pytest.mark.skipif(**tm.no_sklearn()) + def test_multi_predict(self): + from sklearn.datasets import make_regression + from sklearn.model_selection import train_test_split + + n = 1000 + X, y = make_regression(n, random_state=rng) + X_train, X_test, y_train, y_test = train_test_split(X, y, + random_state=123) + dtrain = xgb.DMatrix(X_train, label=y_train) + dtest = 
xgb.DMatrix(X_test) + + params = {} + params["tree_method"] = "hist" + params["device"] = "cpu" + + bst = xgb.train(params, dtrain) + cpu_predict = bst.predict(dtest) + + bst.set_param({"device": "sycl:gpu"}) + + predict0 = bst.predict(dtest) + predict1 = bst.predict(dtest) + + assert np.allclose(predict0, predict1) + assert np.allclose(predict0, cpu_predict) + + @pytest.mark.skipif(**tm.no_sklearn()) + def test_sklearn(self): + m, n = 15000, 14 + tr_size = 2500 + X = np.random.rand(m, n) + y = 200 * np.matmul(X, np.arange(-3, -3 + n)) + X_train, y_train = X[:tr_size, :], y[:tr_size] + X_test, y_test = X[tr_size:, :], y[tr_size:] + + # First with cpu_predictor + params = {'tree_method': 'hist', + 'device': 'cpu', + 'n_jobs': -1, + 'verbosity' : 0, + 'seed': 123} + m = xgb.XGBRegressor(**params).fit(X_train, y_train) + cpu_train_score = m.score(X_train, y_train) + cpu_test_score = m.score(X_test, y_test) + + # Now with oneapi_predictor + params['device'] = 'sycl:gpu' + m.set_params(**params) + + # m = xgb.XGBRegressor(**params).fit(X_train, y_train) + oneapi_train_score = m.score(X_train, y_train) + # m = xgb.XGBRegressor(**params).fit(X_train, y_train) + oneapi_test_score = m.score(X_test, y_test) + + assert np.allclose(cpu_train_score, oneapi_train_score) + assert np.allclose(cpu_test_score, oneapi_test_score) + + @given(strategies.integers(1, 10), + tm.make_dataset_strategy(), shap_parameter_strategy) + @settings(deadline=None) + def test_shap(self, num_rounds, dataset, param): + param.update({"device": "sycl:gpu"}) + param = dataset.set_params(param) + dmat = dataset.get_dmat() + bst = xgb.train(param, dmat, num_rounds) + test_dmat = xgb.DMatrix(dataset.X, dataset.y, dataset.w, dataset.margin) + shap = bst.predict(test_dmat, pred_contribs=True) + margin = bst.predict(test_dmat, output_margin=True) + assume(len(dataset.y) > 0) + assert np.allclose(np.sum(shap, axis=len(shap.shape) - 1), margin, 1e-3, 1e-3) + + @given(strategies.integers(1, 10), + tm.make_dataset_strategy(), shap_parameter_strategy) + @settings(deadline=None, max_examples=20) + def test_shap_interactions(self, num_rounds, dataset, param): + param.update({"device": "sycl:gpu"}) + param = dataset.set_params(param) + dmat = dataset.get_dmat() + bst = xgb.train(param, dmat, num_rounds) + test_dmat = xgb.DMatrix(dataset.X, dataset.y, dataset.w, dataset.margin) + shap = bst.predict(test_dmat, pred_interactions=True) + margin = bst.predict(test_dmat, output_margin=True) + assume(len(dataset.y) > 0) + assert np.allclose(np.sum(shap, axis=(len(shap.shape) - 1, len(shap.shape) - 2)), margin, + 1e-3, 1e-3) diff --git a/tests/python-oneapi/test_oneapi_training_continuation.py b/tests/python-oneapi/test_oneapi_training_continuation.py new file mode 100644 index 000000000000..2ce809b076ce --- /dev/null +++ b/tests/python-oneapi/test_oneapi_training_continuation.py @@ -0,0 +1,56 @@ +import numpy as np +import xgboost as xgb +import json + +rng = np.random.RandomState(1994) + + +class TestOneAPITrainingContinuation: + def run_training_continuation(self, use_json): + kRows = 64 + kCols = 32 + X = np.random.randn(kRows, kCols) + y = np.random.randn(kRows) + dtrain = xgb.DMatrix(X, y) + params = {'device': 'sycl:gpu', 'max_depth': '2', + 'gamma': '0.1', 'alpha': '0.01', + 'enable_experimental_json_serialization': use_json} + bst_0 = xgb.train(params, dtrain, num_boost_round=64) + dump_0 = bst_0.get_dump(dump_format='json') + + bst_1 = xgb.train(params, dtrain, num_boost_round=32) + bst_1 = xgb.train(params, dtrain, num_boost_round=32, 
xgb_model=bst_1)
+        dump_1 = bst_1.get_dump(dump_format='json')
+
+        def recursive_compare(obj_0, obj_1):
+            if isinstance(obj_0, float):
+                assert np.isclose(obj_0, obj_1, atol=1e-6)
+            elif isinstance(obj_0, str):
+                assert obj_0 == obj_1
+            elif isinstance(obj_0, int):
+                assert obj_0 == obj_1
+            elif isinstance(obj_0, dict):
+                keys_0 = list(obj_0.keys())
+                keys_1 = list(obj_1.keys())
+                values_0 = list(obj_0.values())
+                values_1 = list(obj_1.values())
+                for i in range(len(obj_0.items())):
+                    assert keys_0[i] == keys_1[i]
+                    if list(obj_0.keys())[i] != 'missing':
+                        recursive_compare(values_0[i],
+                                          values_1[i])
+            else:
+                for i in range(len(obj_0)):
+                    recursive_compare(obj_0[i], obj_1[i])
+
+        assert len(dump_0) == len(dump_1)
+        for i in range(len(dump_0)):
+            obj_0 = json.loads(dump_0[i])
+            obj_1 = json.loads(dump_1[i])
+            recursive_compare(obj_0, obj_1)
+
+    def test_oneapi_training_continuation_binary(self):
+        self.run_training_continuation(False)
+
+    def test_oneapi_training_continuation_json(self):
+        self.run_training_continuation(True)
diff --git a/tests/python-oneapi/test_oneapi_updaters.py b/tests/python-oneapi/test_oneapi_updaters.py
new file mode 100644
index 000000000000..282f7cfb2264
--- /dev/null
+++ b/tests/python-oneapi/test_oneapi_updaters.py
@@ -0,0 +1,70 @@
+import numpy as np
+import gc
+import pytest
+import xgboost as xgb
+from hypothesis import given, strategies, assume, settings, note
+
+import sys
+import os
+# sys.path.append("tests/python")
+# import testing as tm
+from xgboost import testing as tm
+
+parameter_strategy = strategies.fixed_dictionaries({
+    'max_depth': strategies.integers(0, 11),
+    'max_leaves': strategies.integers(0, 256),
+    'max_bin': strategies.integers(2, 1024),
+    'grow_policy': strategies.sampled_from(['lossguide', 'depthwise']),
+    'single_precision_histogram': strategies.booleans(),
+    'min_child_weight': strategies.floats(0.5, 2.0),
+    'seed': strategies.integers(0, 10),
+    # We cannot enable subsampling as the training loss can increase
+    # 'subsample': strategies.floats(0.5, 1.0),
+    'colsample_bytree': strategies.floats(0.5, 1.0),
+    'colsample_bylevel': strategies.floats(0.5, 1.0),
+}).filter(lambda x: (x['max_depth'] > 0 or x['max_leaves'] > 0) and (
+    x['max_depth'] > 0 or x['grow_policy'] == 'lossguide'))
+
+
+def train_result(param, dmat, num_rounds):
+    result = {}
+    xgb.train(param, dmat, num_rounds, [(dmat, 'train')], verbose_eval=False,
+              evals_result=result)
+    return result
+
+
+class TestOneAPIUpdaters:
+    @given(parameter_strategy, strategies.integers(1, 5),
+           tm.make_dataset_strategy())
+    @settings(deadline=None)
+    def test_oneapi_hist(self, param, num_rounds, dataset):
+        param['tree_method'] = 'hist'
+        param['device'] = 'sycl:gpu'
+        param['verbosity'] = 0
+        param = dataset.set_params(param)
+        result = train_result(param, dataset.get_dmat(), num_rounds)
+        note(result)
+        assert tm.non_increasing(result['train'][dataset.metric])
+
+    @given(tm.make_dataset_strategy(), strategies.integers(0, 1))
+    @settings(deadline=None)
+    def test_specified_device_id_oneapi_update(self, dataset, device_id):
+        # Read the list of SYCL devices
+        sycl_ls = os.popen('sycl-ls').read()
+        devices = sycl_ls.split('\n')
+
+        # The test should run only on a GPU.
+        # Find GPUs in the device list and use their index
+        # in that list instead of device_id.
+        target_device_type = "opencl:gpu"
+        found_devices = 0
+        for idx in range(len(devices)):
+            if len(devices[idx]) >= len(target_device_type):
+                if devices[idx][1:1+len(target_device_type)] == target_device_type:
+                    if found_devices == device_id:
+                        param = {'device': f"sycl:gpu:{idx}"}
+                        param = dataset.set_params(param)
+                        result = train_result(param, dataset.get_dmat(), 10)
+                        assert tm.non_increasing(result['train'][dataset.metric])
+                    else:
+                        found_devices += 1
diff --git a/tests/python-oneapi/test_oneapi_with_sklearn.py b/tests/python-oneapi/test_oneapi_with_sklearn.py
new file mode 100644
index 000000000000..0f71efee72de
--- /dev/null
+++ b/tests/python-oneapi/test_oneapi_with_sklearn.py
@@ -0,0 +1,35 @@
+import xgboost as xgb
+import pytest
+import sys
+import numpy as np
+
+from xgboost import testing as tm
+sys.path.append("tests/python")
+import test_with_sklearn as twskl  # noqa
+
+pytestmark = pytest.mark.skipif(**tm.no_sklearn())
+
+rng = np.random.RandomState(1994)
+
+
+def test_oneapi_binary_classification():
+    from sklearn.datasets import load_digits
+    from sklearn.model_selection import KFold
+
+    digits = load_digits(n_class=2)
+    y = digits['target']
+    X = digits['data']
+    kf = KFold(n_splits=2, shuffle=True, random_state=rng)
+    for cls in (xgb.XGBClassifier, xgb.XGBRFClassifier):
+        for train_index, test_index in kf.split(X, y):
+            xgb_model = cls(
+                random_state=42, device='sycl:gpu',
+                n_estimators=4).fit(X[train_index], y[train_index])
+            preds = xgb_model.predict(X[test_index])
+            labels = y[test_index]
+            err = sum(1 for i in range(len(preds))
+                      if int(preds[i] > 0.5) != labels[i]) / float(len(preds))
+            print(preds)
+            print(labels)
+            print(err)
+            assert err < 0.1
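
Note for reviewers: below is a minimal usage sketch of the SYCL device selection that the tests above exercise. It assumes an XGBoost build with the oneAPI plugin enabled (PLUGIN_UPDATER_ONEAPI=ON) and a machine where `sycl-ls` reports a GPU; the data and round count are placeholders, not part of the patch.

    import numpy as np
    import xgboost as xgb

    # Synthetic placeholder data; any DMatrix works here.
    X = np.random.randn(256, 16)
    y = np.random.randn(256)
    dtrain = xgb.DMatrix(X, y)

    # Device strings accepted by the new parser in context.cc:
    #   "sycl", "sycl:cpu", "sycl:gpu", or with an ordinal, e.g. "sycl:gpu:0".
    params = {"tree_method": "hist", "device": "sycl:gpu"}
    bst = xgb.train(params, dtrain, num_boost_round=10)

    # An already-trained booster can also be switched to the oneAPI predictor,
    # as done in test_oneapi_prediction.py.
    bst.set_param({"device": "sycl:gpu"})
    preds = bst.predict(dtrain)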