Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SYCL. Add functional for sycl implementation of RowSetCollection. #10057

Merged
merged 1 commit into from
Feb 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 123 additions & 0 deletions plugin/sycl/common/row_set.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
/*!
* Copyright 2017-2023 XGBoost contributors
*/
#ifndef PLUGIN_SYCL_COMMON_ROW_SET_H_
#define PLUGIN_SYCL_COMMON_ROW_SET_H_

#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wtautological-constant-compare"
#pragma GCC diagnostic ignored "-W#pragma-messages"
#include <xgboost/data.h>
#pragma GCC diagnostic pop
#include <algorithm>
#include <vector>
#include <utility>

#include "../data.h"

#include <CL/sycl.hpp>

namespace xgboost {
namespace sycl {
namespace common {


/*! \brief Collection of rowsets stored on device in USM memory */
class RowSetCollection {
public:
/*! \brief data structure to store an instance set, a subset of
* rows (instances) associated with a particular node in a decision
* tree. */
struct Elem {
const size_t* begin{nullptr};
const size_t* end{nullptr};
bst_node_t node_id{-1}; // id of node associated with this instance set; -1 means uninitialized
Elem()
= default;
Elem(const size_t* begin,
const size_t* end,
bst_node_t node_id = -1)
: begin(begin), end(end), node_id(node_id) {}


inline size_t Size() const {
return end - begin;
}
};

inline size_t Size() const {
return elem_of_each_node_.size();
}

/*! \brief return corresponding element set given the node_id */
inline const Elem& operator[](unsigned node_id) const {
const Elem& e = elem_of_each_node_[node_id];
CHECK(e.begin != nullptr)
<< "access element that is not in the set";
return e;
}

/*! \brief return corresponding element set given the node_id */
inline Elem& operator[](unsigned node_id) {
Elem& e = elem_of_each_node_[node_id];
return e;
}

// clear up things
inline void Clear() {
elem_of_each_node_.clear();
}
// initialize node id 0->everything
inline void Init() {
CHECK_EQ(elem_of_each_node_.size(), 0U);

const size_t* begin = row_indices_.Begin();
const size_t* end = row_indices_.End();
elem_of_each_node_.emplace_back(Elem(begin, end, 0));
}

auto& Data() { return row_indices_; }

// split rowset into two
inline void AddSplit(unsigned node_id,
unsigned left_node_id,
unsigned right_node_id,
size_t n_left,
size_t n_right) {
const Elem e = elem_of_each_node_[node_id];
CHECK(e.begin != nullptr);
size_t* all_begin = row_indices_.Begin();
size_t* begin = all_begin + (e.begin - all_begin);


CHECK_EQ(n_left + n_right, e.Size());
CHECK_LE(begin + n_left, e.end);
CHECK_EQ(begin + n_left + n_right, e.end);


if (left_node_id >= elem_of_each_node_.size()) {
elem_of_each_node_.resize(left_node_id + 1, Elem(nullptr, nullptr, -1));
}
if (right_node_id >= elem_of_each_node_.size()) {
elem_of_each_node_.resize(right_node_id + 1, Elem(nullptr, nullptr, -1));
}


elem_of_each_node_[left_node_id] = Elem(begin, begin + n_left, left_node_id);
elem_of_each_node_[right_node_id] = Elem(begin + n_left, e.end, right_node_id);
elem_of_each_node_[node_id] = Elem(nullptr, nullptr, -1);
}

private:
// stores the row indexes in the set
USMVector<size_t, MemoryType::on_device> row_indices_;
// vector: node_id -> elements
std::vector<Elem> elem_of_each_node_;
};

} // namespace common
} // namespace sycl
} // namespace xgboost


#endif // PLUGIN_SYCL_COMMON_ROW_SET_H_
78 changes: 78 additions & 0 deletions tests/cpp/plugin/test_sycl_row_set_collection.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
/**
* Copyright 2020-2023 by XGBoost contributors
*/
#include <gtest/gtest.h>

#include <string>
#include <utility>
#include <vector>

#include "../../../plugin/sycl/common/row_set.h"
#include "../../../plugin/sycl/device_manager.h"
#include "../helpers.h"

namespace xgboost::sycl::common {
TEST(SyclRowSetCollection, AddSplits) {
const size_t num_rows = 16;

DeviceManager device_manager;
auto qu = device_manager.GetQueue(DeviceOrd::SyclDefault());

RowSetCollection row_set_collection;

auto& row_indices = row_set_collection.Data();
row_indices.Resize(&qu, num_rows);
size_t* p_row_indices = row_indices.Data();

qu.submit([&](::sycl::handler& cgh) {
cgh.parallel_for<>(::sycl::range<1>(num_rows),
[p_row_indices](::sycl::item<1> pid) {
const size_t idx = pid.get_id(0);
p_row_indices[idx] = idx;
});
}).wait_and_throw();
row_set_collection.Init();

CHECK_EQ(row_set_collection.Size(), 1);
{
size_t nid_test = 0;
auto& elem = row_set_collection[nid_test];
CHECK_EQ(elem.begin, row_indices.Begin());
CHECK_EQ(elem.end, row_indices.End());
CHECK_EQ(elem.node_id , 0);
}

size_t nid = 0;
size_t nid_left = 1;
size_t nid_right = 2;
size_t n_left = 4;
size_t n_right = num_rows - n_left;
row_set_collection.AddSplit(nid, nid_left, nid_right, n_left, n_right);
CHECK_EQ(row_set_collection.Size(), 3);

{
size_t nid_test = 0;
auto& elem = row_set_collection[nid_test];
CHECK_EQ(elem.begin, nullptr);
CHECK_EQ(elem.end, nullptr);
CHECK_EQ(elem.node_id , -1);
}

{
size_t nid_test = 1;
auto& elem = row_set_collection[nid_test];
CHECK_EQ(elem.begin, row_indices.Begin());
CHECK_EQ(elem.end, row_indices.Begin() + n_left);
CHECK_EQ(elem.node_id , nid_test);
}

{
size_t nid_test = 2;
auto& elem = row_set_collection[nid_test];
CHECK_EQ(elem.begin, row_indices.Begin() + n_left);
CHECK_EQ(elem.end, row_indices.End());
CHECK_EQ(elem.node_id , nid_test);
}

}
} // namespace xgboost::sycl::common
Loading