From d4eb488636a42c3b91aa9f1023220076b71e57dd Mon Sep 17 00:00:00 2001
From: Rong Ou
Date: Thu, 16 Mar 2023 10:08:11 -0700
Subject: [PATCH 1/7] handle num_col and get gradients

---
 include/xgboost/data.h                        |  9 +++
 src/data/data.cc                              | 15 ++--
 src/data/iterative_dmatrix.cc                 |  2 +-
 src/data/iterative_dmatrix.cu                 |  2 +-
 src/data/simple_dmatrix.cc                    |  6 +-
 src/data/simple_dmatrix.cu                    |  2 +-
 src/data/sparse_page_dmatrix.cc               |  2 +-
 src/learner.cc                                | 24 ++++++-
 tests/cpp/plugin/helpers.cc                   |  6 ++
 tests/cpp/plugin/helpers.h                    | 47 ++++++++++++-
 tests/cpp/plugin/test_federated_adapter.cu    | 65 ++++-------------
 .../cpp/plugin/test_federated_communicator.cc | 60 +++-------------
 tests/cpp/plugin/test_federated_data.cc       | 70 +++++++++++++++++++
 tests/cpp/plugin/test_federated_server.cc     | 45 ++----------
 14 files changed, 197 insertions(+), 158 deletions(-)
 create mode 100644 tests/cpp/plugin/test_federated_data.cc

diff --git a/include/xgboost/data.h b/include/xgboost/data.h
index ec78c588d5d9..4c96ceebfa8c 100644
--- a/include/xgboost/data.h
+++ b/include/xgboost/data.h
@@ -171,6 +171,15 @@ class MetaInfo {
    */
   void Extend(MetaInfo const& that, bool accumulate_rows, bool check_column);
 
+  /**
+   * @brief Synchronize the number of columns across all workers.
+   *
+   * Normally we just need to find the maximum number of columns across all workers, but
+   * in vertical federated learning, since each worker loads its own list of columns,
+   * we need to sum them.
+   */
+  void SynchronizeNumberOfColumns();
+
  private:
   void SetInfoFromHost(Context const& ctx, StringView key, Json arr);
   void SetInfoFromCUDA(Context const& ctx, StringView key, Json arr);
diff --git a/src/data/data.cc b/src/data/data.cc
index d24048a2ab23..c20197145d3c 100644
--- a/src/data/data.cc
+++ b/src/data/data.cc
@@ -700,6 +700,14 @@ void MetaInfo::Extend(MetaInfo const& that, bool accumulate_rows, bool check_col
   }
 }
 
+void MetaInfo::SynchronizeNumberOfColumns() {
+  if (collective::IsFederated() && data_split_mode == DataSplitMode::kCol) {
+    collective::Allreduce<collective::Operation::kSum>(&num_col_, 1);
+  } else {
+    collective::Allreduce<collective::Operation::kMax>(&num_col_, 1);
+  }
+}
+
 void MetaInfo::Validate(std::int32_t device) const {
   if (group_ptr_.size() != 0 && weights_.Size() != 0) {
     CHECK_EQ(group_ptr_.size(), weights_.Size() + 1)
@@ -903,10 +911,8 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_s
     LOG(FATAL) << "Encountered parser error:\n" << e.what();
   }
 
-  /* sync up number of features after matrix loaded.
-   * partitioned data will fail the train/val validation check
-   * since partitioned data not knowing the real number of features.
-   */
-  collective::Allreduce<collective::Operation::kMax>(&dmat->Info().num_col_, 1);
+  dmat->Info().data_split_mode = data_split_mode;
+  dmat->Info().SynchronizeNumberOfColumns();
 
   if (need_split && data_split_mode == DataSplitMode::kCol) {
     if (!cache_file.empty()) {
@@ -917,7 +923,6 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_s
       delete dmat;
      return sliced;
     } else {
-      dmat->Info().data_split_mode = data_split_mode;
       return dmat;
     }
   }
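For context on the two reduction modes used above, here is a standalone sketch (not part of the patch) of what kMax and kSum compute over num_col_; the per-worker column counts are hypothetical:

```cpp
// Standalone illustration: kMax vs. kSum over num_col_ across three workers.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

int main() {
  std::vector<std::uint64_t> worker_cols{8, 9, 10};  // hypothetical widths
  // Row-split training: every worker sees the same feature space, so the
  // maximum recovers the true global width.
  auto max_cols = *std::max_element(worker_cols.begin(), worker_cols.end());
  // Vertical federated learning: workers hold disjoint column slices, so the
  // global width is the sum of the local widths.
  auto sum_cols =
      std::accumulate(worker_cols.begin(), worker_cols.end(), std::uint64_t{0});
  std::cout << "kMax: " << max_cols << "  kSum: " << sum_cols << "\n";  // 10, 27
}
```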
diff --git a/src/data/iterative_dmatrix.cc b/src/data/iterative_dmatrix.cc
index ae0cfc4a48fa..c7ac492c9a14 100644
--- a/src/data/iterative_dmatrix.cc
+++ b/src/data/iterative_dmatrix.cc
@@ -190,7 +190,7 @@ void IterativeDMatrix::InitFromCPU(DataIterHandle iter_handle, float missing,
   // From here on Info() has the correct data shape
   Info().num_row_ = accumulated_rows;
   Info().num_nonzero_ = nnz;
-  collective::Allreduce<collective::Operation::kMax>(&info_.num_col_, 1);
+  Info().SynchronizeNumberOfColumns();
   CHECK(std::none_of(column_sizes.cbegin(), column_sizes.cend(),
                      [&](auto f) { return f > accumulated_rows; }))
       << "Something went wrong during iteration.";
diff --git a/src/data/iterative_dmatrix.cu b/src/data/iterative_dmatrix.cu
index 2d4a0bb0b123..5e7fc8d4f565 100644
--- a/src/data/iterative_dmatrix.cu
+++ b/src/data/iterative_dmatrix.cu
@@ -166,7 +166,7 @@ void IterativeDMatrix::InitFromCUDA(DataIterHandle iter_handle, float missing,
   iter.Reset();
 
   // Synchronise worker columns
-  collective::Allreduce<collective::Operation::kMax>(&info_.num_col_, 1);
+  info_.SynchronizeNumberOfColumns();
 }
 
 BatchSet<EllpackPage> IterativeDMatrix::GetEllpackBatches(BatchParam const& param) {
diff --git a/src/data/simple_dmatrix.cc b/src/data/simple_dmatrix.cc
index 014b57282d0e..dbdc79e6d558 100644
--- a/src/data/simple_dmatrix.cc
+++ b/src/data/simple_dmatrix.cc
@@ -215,10 +215,6 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) {
     info_.num_col_ = adapter->NumColumns();
   }
 
-  // Synchronise worker columns
-  collective::Allreduce<collective::Operation::kMax>(&info_.num_col_, 1);
-
   if (adapter->NumRows() == kAdapterUnknownSize) {
     using IteratorAdapterT =
         IteratorAdapter<DataIterHandle, XGBCallbackDataIterNext, XGBoostBatchCSR>;
@@ -346,7 +342,7 @@ SimpleDMatrix::SimpleDMatrix(RecordBatchesIterAdapter* adapter, float missing, i
   }
   // Synchronise worker columns
   info_.num_col_ = adapter->NumColumns();
-  collective::Allreduce<collective::Operation::kMax>(&info_.num_col_, 1);
+  info_.SynchronizeNumberOfColumns();
   info_.num_row_ = total_batch_size;
   info_.num_nonzero_ = data_vec.size();
   CHECK_EQ(offset_vec.back(), info_.num_nonzero_);
diff --git a/src/data/simple_dmatrix.cu b/src/data/simple_dmatrix.cu
index 64f308b8c2bd..8295718ba221 100644
--- a/src/data/simple_dmatrix.cu
+++ b/src/data/simple_dmatrix.cu
@@ -35,7 +35,7 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int32_t /*nthread
   info_.num_col_ = adapter->NumColumns();
   info_.num_row_ = adapter->NumRows();
   // Synchronise worker columns
-  collective::Allreduce<collective::Operation::kMax>(&info_.num_col_, 1);
+  info_.SynchronizeNumberOfColumns();
 }
 
 template SimpleDMatrix::SimpleDMatrix(CudfAdapter* adapter, float missing,
diff --git a/src/data/sparse_page_dmatrix.cc b/src/data/sparse_page_dmatrix.cc
index 698e1e5b2967..5e5b622af264 100644
--- a/src/data/sparse_page_dmatrix.cc
+++ b/src/data/sparse_page_dmatrix.cc
@@ -96,7 +96,7 @@ SparsePageDMatrix::SparsePageDMatrix(DataIterHandle iter_handle, DMatrixHandle p
   this->info_.num_col_ = n_features;
   this->info_.num_nonzero_ = nnz;
 
-  collective::Allreduce<collective::Operation::kMax>(&info_.num_col_, 1);
+  info_.SynchronizeNumberOfColumns();
   CHECK_NE(info_.num_col_, 0);
 }
 
diff --git a/src/learner.cc b/src/learner.cc
index 62875ead6cb8..85086b9d7ad1 100644
--- a/src/learner.cc
+++ b/src/learner.cc
@@ -1303,7 +1303,7 @@ class LearnerImpl : public LearnerIO {
     monitor_.Stop("PredictRaw");
 
     monitor_.Start("GetGradient");
-    obj_->GetGradient(predt.predictions, train->Info(), iter, &gpair_);
+    GetGradient(predt.predictions, train->Info(), iter, &gpair_);
     monitor_.Stop("GetGradient");
     TrainingObserver::Instance().Observe(gpair_, "Gradients");
 
@@ -1482,6 +1482,28 @@ class LearnerImpl : public LearnerIO {
   }
 
  private:
+  void GetGradient(HostDeviceVector<bst_float> const& preds, MetaInfo const& info, int iteration,
+                   HostDeviceVector<GradientPair>* out_gpair) {
+    // Special handling for vertical federated learning.
+    if (collective::IsFederated() && info.data_split_mode == DataSplitMode::kCol) {
+      // We assume labels are only available on worker 0, so the gradients are calculated there
+      // and broadcast to the other workers.
+      if (collective::GetRank() == 0) {
+        obj_->GetGradient(preds, info, iteration, out_gpair);
+        collective::Broadcast(out_gpair->HostPointer(), out_gpair->Size() * sizeof(GradientPair),
+                              0);
+      } else {
+        CHECK_EQ(info.labels.Size(), 0)
+            << "In vertical federated learning, labels should only be on the first worker";
+        out_gpair->Resize(preds.Size());
+        collective::Broadcast(out_gpair->HostPointer(), out_gpair->Size() * sizeof(GradientPair),
+                              0);
+      }
+    } else {
+      obj_->GetGradient(preds, info, iteration, out_gpair);
+    }
+  }
+
   /*! \brief random number transformation seed. */
   static int32_t constexpr kRandSeedMagic = 127;
   // gradient pairs
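The rank-0-computes, everyone-else-receives pattern in GetGradient above can be simulated in-process. A sketch with a stand-in for the real collective::Broadcast (all names here are illustrative, not the plugin's API):

```cpp
// Simulation of the broadcast pattern: rank 0 computes, the others receive.
#include <iostream>
#include <vector>

struct GradientPair { float grad{0}, hess{0}; };

// Stand-in for collective::Broadcast: copy the root buffer to all ranks.
void FakeBroadcast(std::vector<std::vector<GradientPair>>* ranks, int root) {
  for (std::size_t r = 0; r < ranks->size(); ++r) {
    if (static_cast<int>(r) != root) (*ranks)[r] = (*ranks)[root];
  }
}

int main() {
  int const world = 3;
  std::size_t const n_samples = 2;
  std::vector<std::vector<GradientPair>> gpair(world);
  gpair[0] = {{0.5f, 1.0f}, {-0.25f, 1.0f}};  // only rank 0 holds labels
  for (int r = 1; r < world; ++r) gpair[r].resize(n_samples);  // others resize
  FakeBroadcast(&gpair, /*root=*/0);
  std::cout << gpair[2][0].grad << "\n";  // 0.5, received from rank 0
}
```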
diff --git a/tests/cpp/plugin/helpers.cc b/tests/cpp/plugin/helpers.cc
index a70479b1bb1c..722c350dd9ad 100644
--- a/tests/cpp/plugin/helpers.cc
+++ b/tests/cpp/plugin/helpers.cc
@@ -17,3 +17,9 @@ int GenerateRandomPort(int low, int high) {
   int port = dist(rng);
   return port;
 }
+
+std::string GetServerAddress() {
+  int port = GenerateRandomPort(50000, 60000);
+  std::string address = std::string("localhost:") + std::to_string(port);
+  return address;
+}
diff --git a/tests/cpp/plugin/helpers.h b/tests/cpp/plugin/helpers.h
index ea72f1538af6..ffa68946cd83 100644
--- a/tests/cpp/plugin/helpers.h
+++ b/tests/cpp/plugin/helpers.h
@@ -1,10 +1,53 @@
 /*!
  * Copyright 2022 XGBoost contributors
  */
-
 #ifndef XGBOOST_TESTS_CPP_PLUGIN_HELPERS_H_
 #define XGBOOST_TESTS_CPP_PLUGIN_HELPERS_H_
 
-int GenerateRandomPort(int low, int high);
+#include <grpcpp/server.h>
+#include <grpcpp/server_builder.h>
+#include <gtest/gtest.h>
+
+#include "../../../plugin/federated/federated_server.h"
+#include "../../../src/collective/communicator-inl.h"
+
+std::string GetServerAddress();
+
+namespace xgboost {
+
+class BaseFederatedTest : public ::testing::Test {
+ protected:
+  void SetUp() override {
+    server_address_ = GetServerAddress();
+    server_thread_.reset(new std::thread([this] {
+      grpc::ServerBuilder builder;
+      xgboost::federated::FederatedService service{kWorldSize};
+      builder.AddListeningPort(server_address_, grpc::InsecureServerCredentials());
+      builder.RegisterService(&service);
+      server_ = builder.BuildAndStart();
+      server_->Wait();
+    }));
+  }
+
+  void TearDown() override {
+    server_->Shutdown();
+    server_thread_->join();
+  }
+
+  void InitCommunicator(int rank) {
+    Json config{JsonObject()};
+    config["xgboost_communicator"] = String("federated");
+    config["federated_server_address"] = String(server_address_);
+    config["federated_world_size"] = kWorldSize;
+    config["federated_rank"] = rank;
+    xgboost::collective::Init(config);
+  }
+
+  static int const kWorldSize{3};
+  std::string server_address_;
+  std::unique_ptr<std::thread> server_thread_;
+  std::unique_ptr<grpc::Server> server_;
+};
+}  // namespace xgboost
 
 #endif  // XGBOOST_TESTS_CPP_PLUGIN_HELPERS_H_
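A sketch of how a test would build on the new BaseFederatedTest fixture; the test name is hypothetical, but the thread-per-rank pattern mirrors the real tests updated below:

```cpp
// Hypothetical test built on the fixture: one thread per federated worker.
#include <thread>
#include <vector>

namespace xgboost {
class ExampleFederatedTest : public BaseFederatedTest {};

TEST_F(ExampleFederatedTest, RunsOnEveryWorker) {
  std::vector<std::thread> threads;
  for (auto rank = 0; rank < kWorldSize; rank++) {
    threads.emplace_back([rank] {
      // Per-worker body: connect to server_address_ with this rank and
      // assert on the collective results.
      (void)rank;
    });
  }
  for (auto& thread : threads) {
    thread.join();
  }
}
}  // namespace xgboost
```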
diff --git a/tests/cpp/plugin/test_federated_adapter.cu b/tests/cpp/plugin/test_federated_adapter.cu
index 794c60909e76..c4816ff18dd3 100644
--- a/tests/cpp/plugin/test_federated_adapter.cu
+++ b/tests/cpp/plugin/test_federated_adapter.cu
@@ -1,56 +1,20 @@
 /*!
  * Copyright 2022 XGBoost contributors
  */
-#include <grpcpp/server_builder.h>
 #include <gtest/gtest.h>
 #include <thrust/device_vector.h>
+#include <thrust/host_vector.h>
 
 #include <string>
 #include <thread>
-#include <vector>
 
-#include "./helpers.h"
 #include "../../../plugin/federated/federated_communicator.h"
-#include "../../../plugin/federated/federated_server.h"
 #include "../../../src/collective/device_communicator_adapter.cuh"
+#include "./helpers.h"
 
-namespace {
-
-std::string GetServerAddress() {
-  int port = GenerateRandomPort(50000, 60000);
-  std::string address = std::string("localhost:") + std::to_string(port);
-  return address;
-}
-
-}  // anonymous namespace
-
-namespace xgboost {
-namespace collective {
+namespace xgboost::collective {
 
-class FederatedAdapterTest : public ::testing::Test {
- protected:
-  void SetUp() override {
-    server_address_ = GetServerAddress();
-    server_thread_.reset(new std::thread([this] {
-      grpc::ServerBuilder builder;
-      federated::FederatedService service{kWorldSize};
-      builder.AddListeningPort(server_address_, grpc::InsecureServerCredentials());
-      builder.RegisterService(&service);
-      server_ = builder.BuildAndStart();
-      server_->Wait();
-    }));
-  }
-
-  void TearDown() override {
-    server_->Shutdown();
-    server_thread_->join();
-  }
-
-  static int const kWorldSize{2};
-  std::string server_address_;
-  std::unique_ptr<std::thread> server_thread_;
-  std::unique_ptr<grpc::Server> server_;
-};
+class FederatedAdapterTest : public BaseFederatedTest {};
 
 TEST(FederatedAdapterSimpleTest, ThrowOnInvalidDeviceOrdinal) {
   auto construct = []() { DeviceCommunicatorAdapter adapter{-1, nullptr}; };
@@ -65,20 +29,20 @@ TEST(FederatedAdapterSimpleTest, ThrowOnInvalidCommunicator) {
 TEST_F(FederatedAdapterTest, DeviceAllReduceSum) {
   std::vector<std::thread> threads;
   for (auto rank = 0; rank < kWorldSize; rank++) {
-    threads.emplace_back(std::thread([rank, server_address=server_address_] {
+    threads.emplace_back([rank, server_address = server_address_] {
      FederatedCommunicator comm{kWorldSize, rank, server_address};
       // Assign device 0 to all workers, since we run gtest in a single-GPU machine
       DeviceCommunicatorAdapter adapter{0, &comm};
-      int const count = 3;
+      int count = 3;
       thrust::device_vector<double> buffer(count, 0);
       thrust::sequence(buffer.begin(), buffer.end());
       adapter.AllReduceSum(buffer.data().get(), count);
       thrust::host_vector<double> host_buffer = buffer;
       EXPECT_EQ(host_buffer.size(), count);
       for (auto i = 0; i < count; i++) {
-        EXPECT_EQ(host_buffer[i], i * 2);
+        EXPECT_EQ(host_buffer[i], i * kWorldSize);
       }
-    }));
+    });
   }
   for (auto& thread : threads) {
     thread.join();
@@ -88,7 +52,7 @@ TEST_F(FederatedAdapterTest, DeviceAllReduceSum) {
 TEST_F(FederatedAdapterTest, DeviceAllGatherV) {
   std::vector<std::thread> threads;
   for (auto rank = 0; rank < kWorldSize; rank++) {
-    threads.emplace_back(std::thread([rank, server_address=server_address_] {
+    threads.emplace_back([rank, server_address = server_address_] {
       FederatedCommunicator comm{kWorldSize, rank, server_address};
       // Assign device 0 to all workers, since we run gtest in a single-GPU machine
       DeviceCommunicatorAdapter adapter{0, &comm};
@@ -104,17 +68,16 @@ TEST_F(FederatedAdapterTest, DeviceAllGatherV) {
       EXPECT_EQ(segments[0], 2);
       EXPECT_EQ(segments[1], 3);
       thrust::host_vector<double> host_buffer = receive_buffer;
-      EXPECT_EQ(host_buffer.size(), 5);
-      int expected[] = {0, 1, 0, 1, 2};
-      for (auto i = 0; i < 5; i++) {
+      EXPECT_EQ(host_buffer.size(), 9);
+      int expected[] = {0, 1, 0, 1, 2, 0, 1, 2, 3};
+      for (auto i = 0; i < 9; i++) {
         EXPECT_EQ(host_buffer[i], expected[i]);
       }
-    }));
+    });
   }
   for (auto& thread : threads) {
     thread.join();
   }
 }
 
-}  // namespace collective
-}  // namespace xgboost
+}  // namespace xgboost::collective
diff --git a/tests/cpp/plugin/test_federated_communicator.cc b/tests/cpp/plugin/test_federated_communicator.cc
index f5d72e5f4972..5177187c5d50 100644
--- a/tests/cpp/plugin/test_federated_communicator.cc
+++ b/tests/cpp/plugin/test_federated_communicator.cc
@@ -2,65 +2,34 @@
  * Copyright 2022 XGBoost contributors
  */
 #include <gtest/gtest.h>
-#include <grpcpp/server_builder.h>
 
 #include <string>
 #include <thread>
 #include <vector>
-#include <memory>
 
-#include "helpers.h"
 #include "../../../plugin/federated/federated_communicator.h"
-#include "../../../plugin/federated/federated_server.h"
+#include "helpers.h"
 
-namespace {
-
-std::string GetServerAddress() {
-  int port = GenerateRandomPort(50000, 60000);
-  std::string address = std::string("localhost:") + std::to_string(port);
-  return address;
-}
-
-}  // anonymous namespace
-
-namespace xgboost {
-namespace collective {
+namespace xgboost::collective {
 
-class FederatedCommunicatorTest : public ::testing::Test {
+class FederatedCommunicatorTest : public BaseFederatedTest {
  public:
-  static void VerifyAllgather(int rank, const std::string& server_address) {
+  static void VerifyAllgather(int rank, const std::string &server_address) {
     FederatedCommunicator comm{kWorldSize, rank, server_address};
     CheckAllgather(comm, rank);
   }
 
-  static void VerifyAllreduce(int rank, const std::string& server_address) {
+  static void VerifyAllreduce(int rank, const std::string &server_address) {
     FederatedCommunicator comm{kWorldSize, rank, server_address};
     CheckAllreduce(comm);
   }
 
-  static void VerifyBroadcast(int rank, const std::string& server_address) {
+  static void VerifyBroadcast(int rank, const std::string &server_address) {
     FederatedCommunicator comm{kWorldSize, rank, server_address};
     CheckBroadcast(comm, rank);
   }
 
 protected:
-  void SetUp() override {
-    server_address_ = GetServerAddress();
-    server_thread_.reset(new std::thread([this] {
-      grpc::ServerBuilder builder;
-      federated::FederatedService service{kWorldSize};
-      builder.AddListeningPort(server_address_, grpc::InsecureServerCredentials());
-      builder.RegisterService(&service);
-      server_ = builder.BuildAndStart();
-      server_->Wait();
-    }));
-  }
-
-  void TearDown() override {
-    server_->Shutdown();
-    server_thread_->join();
-  }
-
   static void CheckAllgather(FederatedCommunicator &comm, int rank) {
     int buffer[kWorldSize] = {0, 0, 0};
     buffer[rank] = rank;
@@ -90,11 +59,6 @@ class FederatedCommunicatorTest : public ::testing::Test {
       EXPECT_EQ(buffer, "hello");
     }
   }
-
-  static int const kWorldSize{3};
-  std::string server_address_;
-  std::unique_ptr<std::thread> server_thread_;
-  std::unique_ptr<grpc::Server> server_;
 };
 
 TEST(FederatedCommunicatorSimpleTest, ThrowOnWorldSizeTooSmall) {
@@ -161,8 +125,7 @@ TEST(FederatedCommunicatorSimpleTest, IsDistributed) {
 TEST_F(FederatedCommunicatorTest, Allgather) {
   std::vector<std::thread> threads;
   for (auto rank = 0; rank < kWorldSize; rank++) {
-    threads.emplace_back(
-        std::thread(&FederatedCommunicatorTest::VerifyAllgather, rank, server_address_));
+    threads.emplace_back(&FederatedCommunicatorTest::VerifyAllgather, rank, server_address_);
   }
   for (auto &thread : threads) {
     thread.join();
@@ -172,8 +135,7 @@ TEST_F(FederatedCommunicatorTest, Allgather) {
 TEST_F(FederatedCommunicatorTest, Allreduce) {
   std::vector<std::thread> threads;
   for (auto rank = 0; rank < kWorldSize; rank++) {
-    threads.emplace_back(
-        std::thread(&FederatedCommunicatorTest::VerifyAllreduce, rank, server_address_));
+    threads.emplace_back(&FederatedCommunicatorTest::VerifyAllreduce, rank, server_address_);
   }
   for (auto &thread : threads) {
     thread.join();
@@ -183,12 +145,10 @@ TEST_F(FederatedCommunicatorTest, Allreduce) {
 TEST_F(FederatedCommunicatorTest, Broadcast) {
   std::vector<std::thread> threads;
   for (auto rank = 0; rank < kWorldSize; rank++) {
-    threads.emplace_back(
-        std::thread(&FederatedCommunicatorTest::VerifyBroadcast, rank, server_address_));
+    threads.emplace_back(&FederatedCommunicatorTest::VerifyBroadcast, rank, server_address_);
   }
   for (auto &thread : threads) {
     thread.join();
   }
 }
-}  // namespace collective
-}  // namespace xgboost
+}  // namespace xgboost::collective
diff --git a/tests/cpp/plugin/test_federated_data.cc b/tests/cpp/plugin/test_federated_data.cc
new file mode 100644
index 000000000000..3789c34a7c17
--- /dev/null
+++ b/tests/cpp/plugin/test_federated_data.cc
@@ -0,0 +1,70 @@
+/*!
+ * Copyright 2023 XGBoost contributors
+ */
+#include <grpcpp/server_builder.h>
+#include <gtest/gtest.h>
+#include <xgboost/data.h>
+
+#include <fstream>
+#include <string>
+#include <thread>
+
+#include "../../../plugin/federated/federated_server.h"
+#include "../../../src/collective/communicator-inl.h"
+#include "../filesystem.h"
+#include "helpers.h"
+
+namespace xgboost {
+
+class FederatedDataTest : public BaseFederatedTest {
+ public:
+  void VerifyLoadUri(int rank) {
+    InitCommunicator(rank);
+
+    size_t constexpr kRows{16};
+    size_t const kCols = 8 + rank;
+    std::vector<float> data(kRows * kCols);
+
+    for (size_t i = 0; i < kRows * kCols; ++i) {
+      data[i] = i;
+    }
+
+    dmlc::TemporaryDirectory tmpdir;
+    std::string path = tmpdir.path + "/small" + std::to_string(rank) + ".csv";
+
+    std::ofstream fout(path);
+    size_t i = 0;
+    for (size_t r = 0; r < kRows; ++r) {
+      for (size_t c = 0; c < kCols; ++c) {
+        fout << data[i];
+        i++;
+        if (c != kCols - 1) {
+          fout << ",";
+        }
+      }
+      fout << "\n";
+    }
+    fout.flush();
+    fout.close();
+
+    std::unique_ptr<DMatrix> dmat;
+    std::string uri = path + "?format=csv";
+    dmat.reset(DMatrix::Load(uri, false, DataSplitMode::kCol));
+
+    ASSERT_EQ(dmat->Info().num_col_, 8 * kWorldSize + 3);
+    ASSERT_EQ(dmat->Info().num_row_, kRows);
+
+    xgboost::collective::Finalize();
+  }
+};
+
+TEST_F(FederatedDataTest, LoadUri) {
+  std::vector<std::thread> threads;
+  for (auto rank = 0; rank < kWorldSize; rank++) {
+    threads.emplace_back(&FederatedDataTest_LoadUri_Test::VerifyLoadUri, this, rank);
+  }
+  for (auto& thread : threads) {
+    thread.join();
+  }
+}
+}  // namespace xgboost
diff --git a/tests/cpp/plugin/test_federated_server.cc b/tests/cpp/plugin/test_federated_server.cc
index fa9c272d2903..79e06bf5f07c 100644
--- a/tests/cpp/plugin/test_federated_server.cc
+++ b/tests/cpp/plugin/test_federated_server.cc
@@ -1,30 +1,17 @@
 /*!
  * Copyright 2017-2020 XGBoost contributors
  */
-#include <grpcpp/server_builder.h>
 #include <gtest/gtest.h>
-#include <memory>
 
 #include <string>
 #include <thread>
 
 #include "federated_client.h"
-#include "federated_server.h"
 #include "helpers.h"
 
-namespace {
-
-std::string GetServerAddress() {
-  int port = GenerateRandomPort(50000, 60000);
-  std::string address = std::string("localhost:") + std::to_string(port);
-  return address;
-}
-
-}  // anonymous namespace
-
 namespace xgboost {
 
-class FederatedServerTest : public ::testing::Test {
+class FederatedServerTest : public BaseFederatedTest {
  public:
   static void VerifyAllgather(int rank, const std::string& server_address) {
     federated::FederatedClient client{server_address, rank};
@@ -51,23 +38,6 @@ class FederatedServerTest : public ::testing::Test {
   }
 
  protected:
-  void SetUp() override {
-    server_address_ = GetServerAddress();
-    server_thread_.reset(new std::thread([this] {
-      grpc::ServerBuilder builder;
-      federated::FederatedService service{kWorldSize};
-      builder.AddListeningPort(server_address_, grpc::InsecureServerCredentials());
-      builder.RegisterService(&service);
-      server_ = builder.BuildAndStart();
-      server_->Wait();
-    }));
-  }
-
-  void TearDown() override {
-    server_->Shutdown();
-    server_thread_->join();
-  }
-
   static void CheckAllgather(federated::FederatedClient& client, int rank) {
     int data[kWorldSize] = {0, 0, 0};
     data[rank] = rank;
@@ -98,17 +68,12 @@ class FederatedServerTest : public ::testing::Test {
     auto reply = client.Broadcast(send_buffer, 0);
     EXPECT_EQ(reply, "hello broadcast") << "rank " << rank;
   }
-
-  static int const kWorldSize{3};
-  std::string server_address_;
-  std::unique_ptr<std::thread> server_thread_;
-  std::unique_ptr<grpc::Server> server_;
 };
 
 TEST_F(FederatedServerTest, Allgather) {
   std::vector<std::thread> threads;
   for (auto rank = 0; rank < kWorldSize; rank++) {
-    threads.emplace_back(std::thread(&FederatedServerTest::VerifyAllgather, rank, server_address_));
+    threads.emplace_back(&FederatedServerTest::VerifyAllgather, rank, server_address_);
   }
   for (auto& thread : threads) {
     thread.join();
@@ -118,7 +83,7 @@ TEST_F(FederatedServerTest, Allgather) {
 TEST_F(FederatedServerTest, Allreduce) {
   std::vector<std::thread> threads;
   for (auto rank = 0; rank < kWorldSize; rank++) {
-    threads.emplace_back(std::thread(&FederatedServerTest::VerifyAllreduce, rank, server_address_));
+    threads.emplace_back(&FederatedServerTest::VerifyAllreduce, rank, server_address_);
   }
   for (auto& thread : threads) {
     thread.join();
@@ -128,7 +93,7 @@ TEST_F(FederatedServerTest, Allreduce) {
 TEST_F(FederatedServerTest, Broadcast) {
   std::vector<std::thread> threads;
   for (auto rank = 0; rank < kWorldSize; rank++) {
-    threads.emplace_back(std::thread(&FederatedServerTest::VerifyBroadcast, rank, server_address_));
+    threads.emplace_back(&FederatedServerTest::VerifyBroadcast, rank, server_address_);
   }
   for (auto& thread : threads) {
     thread.join();
@@ -138,7 +103,7 @@ TEST_F(FederatedServerTest, Broadcast) {
 TEST_F(FederatedServerTest, Mixture) {
   std::vector<std::thread> threads;
   for (auto rank = 0; rank < kWorldSize; rank++) {
-    threads.emplace_back(std::thread(&FederatedServerTest::VerifyMixture, rank, server_address_));
+    threads.emplace_back(&FederatedServerTest::VerifyMixture, rank, server_address_);
   }
   for (auto& thread : threads) {
     thread.join();
From 858bdedfe84f8f0d51ed862c41d827ae2d70fbaf Mon Sep 17 00:00:00 2001
From: Rong Ou
Date: Thu, 16 Mar 2023 12:13:39 -0700
Subject: [PATCH 2/7] fix base score estimation

---
 include/xgboost/data.h           | 15 +++++++++++++++
 src/data/data.cc                 | 16 ++++++++++++++++
 src/data/iterative_dmatrix.h     |  3 +++
 src/data/proxy_dmatrix.h         |  3 +++
 src/data/simple_dmatrix.cc       |  7 +++++++
 src/data/simple_dmatrix.h        |  1 +
 src/data/sparse_page_dmatrix.h   |  3 +++
 src/learner.cc                   | 21 ++++++++++++++++++++-
 src/objective/init_estimation.cc |  2 +-
 src/tree/fit_stump.cc            | 15 ++++++++++-----
 src/tree/fit_stump.h             |  3 ++-
 tests/cpp/tree/test_fit_stump.cc |  3 ++-
 12 files changed, 83 insertions(+), 9 deletions(-)

diff --git a/include/xgboost/data.h b/include/xgboost/data.h
index 4c96ceebfa8c..dc7007760260 100644
--- a/include/xgboost/data.h
+++ b/include/xgboost/data.h
@@ -334,6 +334,10 @@ class SparsePage {
    * \brief Check wether the column index is sorted.
    */
   bool IsIndicesSorted(int32_t n_threads) const;
+  /**
+   * \brief Reindex the column index with an offset.
+   */
+  void Reindex(uint64_t feature_offset, int32_t n_threads);
 
   void SortRows(int32_t n_threads);
 
@@ -641,6 +645,17 @@ class DMatrix {
    */
   virtual DMatrix *SliceCol(int num_slices, int slice_id) = 0;
 
+  /**
+   * \brief Reindex the features based on a global view.
+   *
+   * In some cases (e.g. vertical federated learning), features are loaded locally with indices
+   * starting from 0. However, all the algorithms assume the features are globally indexed, so we
+   * reindex the features based on the offset needed to obtain the global view.
+   *
+   * \param offset The offset to be added to the feature index
+   */
+  virtual void ReindexFeatures(uint64_t offset) = 0;
+
  protected:
   virtual BatchSet<SparsePage> GetRowBatches() = 0;
   virtual BatchSet<CSCPage> GetColumnBatches() = 0;
diff --git a/src/data/data.cc b/src/data/data.cc
index c20197145d3c..7267f0c78f40 100644
--- a/src/data/data.cc
+++ b/src/data/data.cc
@@ -911,6 +911,15 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_s
     LOG(FATAL) << "Encountered parser error:\n" << e.what();
   }
 
+  if (collective::IsFederated() && data_split_mode == DataSplitMode::kCol) {
+    std::vector<uint64_t> buffer(collective::GetWorldSize());
+    buffer[collective::GetRank()] = dmat->Info().num_col_;
+    collective::Allgather(buffer.data(), buffer.size() * sizeof(uint64_t));
+    auto offset = std::accumulate(buffer.cbegin(), buffer.cbegin() + collective::GetRank(), 0);
+    dmat->ReindexFeatures(offset);
+  }
+
   dmat->Info().data_split_mode = data_split_mode;
   dmat->Info().SynchronizeNumberOfColumns();
 
@@ -1053,6 +1062,13 @@ void SparsePage::SortIndices(int32_t n_threads) {
   });
 }
 
+void SparsePage::Reindex(uint64_t feature_offset, int32_t n_threads) {
+  auto& h_data = this->data.HostVector();
+  common::ParallelFor(this->Size(), n_threads, [&](auto i) {
+    h_data[i].index += feature_offset;
+  });
+}
+
 void SparsePage::SortRows(int32_t n_threads) {
   auto& h_offset = this->offset.HostVector();
   auto& h_data = this->data.HostVector();
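The std::accumulate call above computes an exclusive prefix sum of the allgathered column counts: each worker's offset is the total width of the workers before it. A standalone check (the widths match the 8/9/10-column workers used by the new test):

```cpp
// Standalone check of the feature-offset computation used in DMatrix::Load.
#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

int main() {
  std::vector<std::uint64_t> cols{8, 9, 10};  // allgathered per-worker widths
  for (int rank = 0; rank < 3; ++rank) {
    // Exclusive prefix sum: total width of all workers before this one.
    auto offset =
        std::accumulate(cols.cbegin(), cols.cbegin() + rank, std::uint64_t{0});
    std::cout << "rank " << rank << " offset " << offset << "\n";  // 0, 8, 17
  }
}
```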
diff --git a/src/data/iterative_dmatrix.h b/src/data/iterative_dmatrix.h
index 28c4087c419a..f70be62595c9 100644
--- a/src/data/iterative_dmatrix.h
+++ b/src/data/iterative_dmatrix.h
@@ -90,6 +90,9 @@ class IterativeDMatrix : public DMatrix {
     LOG(FATAL) << "Slicing DMatrix columns is not supported for Quantile DMatrix.";
     return nullptr;
   }
+  void ReindexFeatures(uint64_t offset) override {
+    LOG(FATAL) << "Reindexing features is not supported for Quantile DMatrix.";
+  }
   BatchSet<SparsePage> GetRowBatches() override {
     LOG(FATAL) << "Not implemented.";
     return BatchSet<SparsePage>(BatchIterator<SparsePage>(nullptr));
   }
diff --git a/src/data/proxy_dmatrix.h b/src/data/proxy_dmatrix.h
index 7a15d6498745..ec550cce3984 100644
--- a/src/data/proxy_dmatrix.h
+++ b/src/data/proxy_dmatrix.h
@@ -89,6 +89,9 @@ class DMatrixProxy : public DMatrix {
     LOG(FATAL) << "Slicing DMatrix columns is not supported for Proxy DMatrix.";
     return nullptr;
   }
+  void ReindexFeatures(uint64_t offset) override {
+    LOG(FATAL) << "Reindexing features is not supported for Proxy DMatrix.";
+  }
   BatchSet<SparsePage> GetRowBatches() override {
     LOG(FATAL) << "Not implemented.";
     return BatchSet<SparsePage>(BatchIterator<SparsePage>(nullptr));
   }
diff --git a/src/data/simple_dmatrix.cc b/src/data/simple_dmatrix.cc
index dbdc79e6d558..27c4bdce721e 100644
--- a/src/data/simple_dmatrix.cc
+++ b/src/data/simple_dmatrix.cc
@@ -73,6 +73,13 @@ DMatrix* SimpleDMatrix::SliceCol(int num_slices, int slice_id) {
   return out;
 }
 
+void SimpleDMatrix::ReindexFeatures(uint64_t offset) {
+  if (offset == 0) {
+    return;
+  }
+  sparse_page_->Reindex(offset, Ctx()->Threads());
+}
+
 BatchSet<SparsePage> SimpleDMatrix::GetRowBatches() {
   // since csr is the default data structure so `source_` is always available.
   auto begin_iter = BatchIterator<SparsePage>(
diff --git a/src/data/simple_dmatrix.h b/src/data/simple_dmatrix.h
index 897abfcf0a9d..fb60ebc95ff1 100644
--- a/src/data/simple_dmatrix.h
+++ b/src/data/simple_dmatrix.h
@@ -36,6 +36,7 @@ class SimpleDMatrix : public DMatrix {
   bool SingleColBlock() const override { return true; }
   DMatrix* Slice(common::Span<int32_t const> ridxs) override;
   DMatrix* SliceCol(int num_slices, int slice_id) override;
+  void ReindexFeatures(uint64_t offset) override;
 
   /*! \brief magic number used to identify SimpleDMatrix binary files */
   static const int kMagic = 0xffffab01;
diff --git a/src/data/sparse_page_dmatrix.h b/src/data/sparse_page_dmatrix.h
index aa0be69845aa..03d43d9a9a20 100644
--- a/src/data/sparse_page_dmatrix.h
+++ b/src/data/sparse_page_dmatrix.h
@@ -111,6 +111,9 @@ class SparsePageDMatrix : public DMatrix {
     LOG(FATAL) << "Slicing DMatrix columns is not supported for external memory.";
     return nullptr;
   }
+  void ReindexFeatures(uint64_t offset) override {
+    LOG(FATAL) << "Reindexing features is not supported for external memory.";
+  }
 
  private:
   BatchSet<SparsePage> GetRowBatches() override;
diff --git a/src/learner.cc b/src/learner.cc
index 85086b9d7ad1..eca8d5762683 100644
--- a/src/learner.cc
+++ b/src/learner.cc
@@ -440,7 +440,7 @@ class LearnerConfiguration : public Learner {
     info.Validate(Ctx()->gpu_id);
     // We estimate it from input data.
     linalg::Tensor<float, 1> base_score;
-    UsePtr(obj_)->InitEstimation(info, &base_score);
+    InitEstimation(info, &base_score);
     CHECK_EQ(base_score.Size(), 1);
     mparam_.base_score = base_score(0);
     CHECK(!std::isnan(mparam_.base_score));
@@ -857,6 +857,25 @@ class LearnerConfiguration : public Learner {
       mparam_.num_target = n_targets;
     }
   }
+
+  void InitEstimation(MetaInfo const& info, linalg::Tensor<float, 1>* base_score) {
+    // Special handling for vertical federated learning.
+    if (collective::IsFederated() && info.data_split_mode == DataSplitMode::kCol) {
+      // We assume labels are only available on worker 0, so the estimation is calculated there
+      // and broadcast to the other workers.
+      if (collective::GetRank() == 0) {
+        UsePtr(obj_)->InitEstimation(info, base_score);
+        collective::Broadcast(base_score->Data()->HostPointer(),
+                              sizeof(bst_float) * base_score->Size(), 0);
+      } else {
+        base_score->Reshape(1);
+        collective::Broadcast(base_score->Data()->HostPointer(),
+                              sizeof(bst_float) * base_score->Size(), 0);
+      }
+    } else {
+      UsePtr(obj_)->InitEstimation(info, base_score);
+    }
+  }
 };
 
 std::string const LearnerConfiguration::kEvalMetric {"eval_metric"};  // NOLINT
diff --git a/src/objective/init_estimation.cc b/src/objective/init_estimation.cc
index 96fd5d65379c..938ceb59d973 100644
--- a/src/objective/init_estimation.cc
+++ b/src/objective/init_estimation.cc
@@ -33,7 +33,7 @@ void FitIntercept::InitEstimation(MetaInfo const& info, linalg::Vector<float>* b
   new_obj->GetGradient(dummy_predt, info, 0, &gpair);
   bst_target_t n_targets = this->Targets(info);
   linalg::Vector<float> leaf_weight;
-  tree::FitStump(this->ctx_, gpair, n_targets, &leaf_weight);
+  tree::FitStump(this->ctx_, info, gpair, n_targets, &leaf_weight);
   // workaround, we don't support multi-target due to binary model serialization for
   // base margin.
diff --git a/src/tree/fit_stump.cc b/src/tree/fit_stump.cc
index ad0253d22be4..5131f9284268 100644
--- a/src/tree/fit_stump.cc
+++ b/src/tree/fit_stump.cc
@@ -21,7 +21,8 @@ namespace xgboost {
 namespace tree {
 namespace cpu_impl {
-void FitStump(Context const* ctx, linalg::TensorView<GradientPair const, 2> gpair,
+void FitStump(Context const* ctx, MetaInfo const& info,
+              linalg::TensorView<GradientPair const, 2> gpair,
               linalg::VectorView<float> out) {
   auto n_targets = out.Size();
   CHECK_EQ(n_targets, gpair.Shape(1));
@@ -43,8 +44,12 @@ void FitStump(Context const* ctx, linalg::TensorView<GradientPair const, 2> gpai
     }
   }
   CHECK(h_sum.CContiguous());
-  collective::Allreduce<collective::Operation::kSum>(
-      reinterpret_cast<double*>(h_sum.Values().data()), h_sum.Size() * 2);
+
+  // In vertical federated learning, only worker 0 needs to call this, no need to do an allreduce.
+  if (!collective::IsFederated() || info.data_split_mode != DataSplitMode::kCol) {
+    collective::Allreduce<collective::Operation::kSum>(
+        reinterpret_cast<double*>(h_sum.Values().data()), h_sum.Size() * 2);
+  }
 
   for (std::size_t i = 0; i < h_sum.Size(); ++i) {
     out(i) = static_cast<float>(CalcUnregularizedWeight(h_sum(i).GetGrad(), h_sum(i).GetHess()));
@@ -64,7 +69,7 @@ inline void FitStump(Context const*, linalg::TensorView<GradientPair const, 2>,
 #endif  // !defined(XGBOOST_USE_CUDA)
 }  // namespace cuda_impl
 
-void FitStump(Context const* ctx, HostDeviceVector<GradientPair> const& gpair,
+void FitStump(Context const* ctx, MetaInfo const& info, HostDeviceVector<GradientPair> const& gpair,
               bst_target_t n_targets, linalg::Vector<float>* out) {
   out->SetDevice(ctx->gpu_id);
   out->Reshape(n_targets);
@@ -72,7 +77,7 @@ void FitStump(Context const* ctx, HostDeviceVector<GradientPair> const& gpair,
 
   gpair.SetDevice(ctx->gpu_id);
   auto gpair_t = linalg::MakeTensorView(ctx, &gpair, n_samples, n_targets);
-  ctx->IsCPU() ? cpu_impl::FitStump(ctx, gpair_t, out->HostView())
+  ctx->IsCPU() ? cpu_impl::FitStump(ctx, info, gpair_t, out->HostView())
               : cuda_impl::FitStump(ctx, gpair_t, out->View(ctx->gpu_id));
 }
 }  // namespace tree
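What the (now conditionally skipped) allreduce feeds into: the stump weight is the unregularized Newton step per target, computed by CalcUnregularizedWeight above. A tiny standalone check with made-up gradient totals:

```cpp
// Standalone check of the unregularized stump weight: -sum_grad / sum_hess.
#include <iostream>

double CalcUnregularizedWeight(double sum_grad, double sum_hess) {
  return -sum_grad / sum_hess;
}

int main() {
  // Hypothetical global totals after the (possibly skipped) allreduce.
  double sum_grad = -3.2, sum_hess = 16.0;
  std::cout << CalcUnregularizedWeight(sum_grad, sum_hess) << "\n";  // 0.2
}
```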
diff --git a/src/tree/fit_stump.h b/src/tree/fit_stump.h
index 1f5cd60b4928..4778ecfc5dec 100644
--- a/src/tree/fit_stump.h
+++ b/src/tree/fit_stump.h
@@ -16,6 +16,7 @@
 #include "../common/common.h"            // AssertGPUSupport
 #include "xgboost/base.h"                // GradientPair
 #include "xgboost/context.h"             // Context
+#include "xgboost/data.h"                // MetaInfo
 #include "xgboost/host_device_vector.h"  // HostDeviceVector
 #include "xgboost/linalg.h"              // TensorView
 
@@ -30,7 +31,7 @@ XGBOOST_DEVICE inline double CalcUnregularizedWeight(T sum_grad, T sum_hess) {
 /**
  * @brief Fit a tree stump as an estimation of base_score.
  */
-void FitStump(Context const* ctx, HostDeviceVector<GradientPair> const& gpair,
+void FitStump(Context const* ctx, MetaInfo const& info, HostDeviceVector<GradientPair> const& gpair,
               bst_target_t n_targets, linalg::Vector<float>* out);
 }  // namespace tree
 }  // namespace xgboost
diff --git a/tests/cpp/tree/test_fit_stump.cc b/tests/cpp/tree/test_fit_stump.cc
index ef608e5757d9..35a6af9943c0 100644
--- a/tests/cpp/tree/test_fit_stump.cc
+++ b/tests/cpp/tree/test_fit_stump.cc
@@ -21,7 +21,8 @@ void TestFitStump(Context const *ctx) {
     }
   }
   linalg::Vector<float> out;
-  FitStump(ctx, gpair, kTargets, &out);
+  MetaInfo info;
+  FitStump(ctx, info, gpair, kTargets, &out);
   auto h_out = out.HostView();
   for (auto it = linalg::cbegin(h_out); it != linalg::cend(h_out); ++it) {
     // sum_hess == kRows

From e1f67f25848617b9922653935600ed4934be5466 Mon Sep 17 00:00:00 2001
From: Rong Ou
Date: Thu, 16 Mar 2023 17:37:56 -0700
Subject: [PATCH 3/7] fix reindexing

---
 src/data/data.cc                        |  2 +-
 tests/cpp/plugin/test_federated_data.cc | 13 +++++++++++++
 2 files changed, 14 insertions(+), 1 deletion(-)

diff --git a/src/data/data.cc b/src/data/data.cc
index 7267f0c78f40..9e669f5a7c3e 100644
--- a/src/data/data.cc
+++ b/src/data/data.cc
@@ -1064,7 +1064,7 @@ void SparsePage::SortIndices(int32_t n_threads) {
 
 void SparsePage::Reindex(uint64_t feature_offset, int32_t n_threads) {
   auto& h_data = this->data.HostVector();
-  common::ParallelFor(this->Size(), n_threads, [&](auto i) {
+  common::ParallelFor(h_data.size(), n_threads, [&](auto i) {
     h_data[i].index += feature_offset;
   });
 }
diff --git a/tests/cpp/plugin/test_federated_data.cc b/tests/cpp/plugin/test_federated_data.cc
index 3789c34a7c17..67194fcde320 100644
--- a/tests/cpp/plugin/test_federated_data.cc
+++ b/tests/cpp/plugin/test_federated_data.cc
@@ -54,6 +54,19 @@ class FederatedDataTest : public BaseFederatedTest {
     ASSERT_EQ(dmat->Info().num_col_, 8 * kWorldSize + 3);
     ASSERT_EQ(dmat->Info().num_row_, kRows);
 
+    for (auto const& page : dmat->GetBatches<SparsePage>()) {
+      auto entries = page.GetView().data;
+      auto index = 0;
+      int offsets[] = {0, 8, 17};
+      int offset = offsets[rank];
+      for (auto row = 0; row < kRows; row++) {
+        for (auto col = 0; col < kCols; col++) {
+          EXPECT_EQ(entries[index].index, col + offset);
+          index++;
+        }
+      }
+    }
+
     xgboost::collective::Finalize();
   }
 };
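Why PATCH 3's one-line fix matters: SparsePage::Size() counts rows, while the reindex loop must touch every stored entry, so a dense 16-row-by-8-column page has 128 entries but only 16 rows. A serial sketch of the corrected loop (a stand-in for common::ParallelFor):

```cpp
// Serial stand-in for the fixed reindex loop: iterate entries, not rows.
#include <cstdint>
#include <iostream>
#include <vector>

struct Entry { std::uint32_t index{0}; float fvalue{0}; };

int main() {
  std::size_t const rows = 16, cols = 8;
  std::vector<Entry> h_data(rows * cols);     // dense page: 128 entries
  std::uint64_t const feature_offset = 8;     // this worker starts at column 8
  for (std::size_t i = 0; i < h_data.size(); ++i) {  // h_data.size(), not rows
    h_data[i].index += feature_offset;
  }
  std::cout << h_data.back().index << "\n";   // 8: every entry was shifted
}
```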
From 1e0c070ec4932d27a726af3508b62e773966f6f2 Mon Sep 17 00:00:00 2001
From: Rong Ou
Date: Fri, 17 Mar 2023 11:07:10 -0700
Subject: [PATCH 4/7] remove plugin helpers.cc

---
 tests/cpp/plugin/helpers.cc | 25 -------------------------
 tests/cpp/plugin/helpers.h  | 28 ++++++++++++++++++++------
 2 files changed, 22 insertions(+), 31 deletions(-)
 delete mode 100644 tests/cpp/plugin/helpers.cc

diff --git a/tests/cpp/plugin/helpers.cc b/tests/cpp/plugin/helpers.cc
deleted file mode 100644
index 722c350dd9ad..000000000000
--- a/tests/cpp/plugin/helpers.cc
+++ /dev/null
@@ -1,25 +0,0 @@
-#include <chrono>
-#include <random>
-#include <string>
-#include <thread>
-
-#include "helpers.h"
-
-using namespace std::chrono_literals;
-
-int GenerateRandomPort(int low, int high) {
-  // Ensure unique timestamp by introducing a small artificial delay
-  std::this_thread::sleep_for(100ms);
-  auto timestamp = static_cast<uint64_t>(std::chrono::duration_cast<std::chrono::seconds>(
-      std::chrono::system_clock::now().time_since_epoch()).count());
-  std::mt19937_64 rng(timestamp);
-  std::uniform_int_distribution<int> dist(low, high);
-  int port = dist(rng);
-  return port;
-}
-
-std::string GetServerAddress() {
-  int port = GenerateRandomPort(50000, 60000);
-  std::string address = std::string("localhost:") + std::to_string(port);
-  return address;
-}
diff --git a/tests/cpp/plugin/helpers.h b/tests/cpp/plugin/helpers.h
index ffa68946cd83..0ac6746f8916 100644
--- a/tests/cpp/plugin/helpers.h
+++ b/tests/cpp/plugin/helpers.h
@@ -1,17 +1,35 @@
 /*!
- * Copyright 2022 XGBoost contributors
+ * Copyright 2022-2023 XGBoost contributors
  */
-#ifndef XGBOOST_TESTS_CPP_PLUGIN_HELPERS_H_
-#define XGBOOST_TESTS_CPP_PLUGIN_HELPERS_H_
+#pragma once

 #include <grpcpp/server.h>
 #include <grpcpp/server_builder.h>
 #include <gtest/gtest.h>
 
+#include <random>
+
 #include "../../../plugin/federated/federated_server.h"
 #include "../../../src/collective/communicator-inl.h"
 
-std::string GetServerAddress();
+inline int GenerateRandomPort(int low, int high) {
+  using namespace std::chrono_literals;
+  // Ensure unique timestamp by introducing a small artificial delay
+  std::this_thread::sleep_for(100ms);
+  auto timestamp = static_cast<uint64_t>(std::chrono::duration_cast<std::chrono::seconds>(
+                                             std::chrono::system_clock::now().time_since_epoch())
+                                             .count());
+  std::mt19937_64 rng(timestamp);
+  std::uniform_int_distribution<int> dist(low, high);
+  int port = dist(rng);
+  return port;
+}
+
+inline std::string GetServerAddress() {
+  int port = GenerateRandomPort(50000, 60000);
+  std::string address = std::string("localhost:") + std::to_string(port);
+  return address;
+}
 
 namespace xgboost {
 
@@ -49,5 +67,3 @@ class BaseFederatedTest : public ::testing::Test {
   std::unique_ptr<grpc::Server> server_;
 };
 }  // namespace xgboost
-
-#endif  // XGBOOST_TESTS_CPP_PLUGIN_HELPERS_H_
From 5f8c7e28f05adf47ab21f5675fac45c481008888 Mon Sep 17 00:00:00 2001
From: Rong Ou
Date: Fri, 17 Mar 2023 19:01:26 -0700
Subject: [PATCH 5/7] extract method to create test csv

---
 tests/cpp/data/test_data.cc             | 27 ++++---------------------
 tests/cpp/helpers.cc                    | 23 +++++++++++++++++++++
 tests/cpp/helpers.h                     |  2 ++
 tests/cpp/plugin/test_federated_data.cc | 22 ++------------------
 4 files changed, 31 insertions(+), 43 deletions(-)

diff --git a/tests/cpp/data/test_data.cc b/tests/cpp/data/test_data.cc
index c3732819287c..99cd72cc09a0 100644
--- a/tests/cpp/data/test_data.cc
+++ b/tests/cpp/data/test_data.cc
@@ -112,31 +112,12 @@ TEST(SparsePage, SortIndices) {
 }
 
 TEST(DMatrix, Uri) {
-  size_t constexpr kRows {16};
-  size_t constexpr kCols {8};
-  std::vector<float> data (kRows * kCols);
-
-  for (size_t i = 0; i < kRows * kCols; ++i) {
-    data[i] = i;
-  }
+  auto constexpr kRows {16};
+  auto constexpr kCols {8};
 
   dmlc::TemporaryDirectory tmpdir;
-  std::string path = tmpdir.path + "/small.csv";
-
-  std::ofstream fout(path);
-  size_t i = 0;
-  for (size_t r = 0; r < kRows; ++r) {
-    for (size_t c = 0; c < kCols; ++c) {
-      fout << data[i];
-      i++;
-      if (c != kCols - 1) {
-        fout << ",";
-      }
-    }
-    fout << "\n";
-  }
-  fout.flush();
-  fout.close();
+  auto const path = tmpdir.path + "/small.csv";
+  CreateTestCSV(path, kRows, kCols);
 
   std::unique_ptr<DMatrix> dmat;
   // FIXME(trivialfis): Enable the following test by restricting csv parser in dmlc-core.
diff --git a/tests/cpp/helpers.cc b/tests/cpp/helpers.cc
index 9236f569fb2c..49813f1d04de 100644
--- a/tests/cpp/helpers.cc
+++ b/tests/cpp/helpers.cc
@@ -65,6 +65,29 @@ void CreateBigTestData(const std::string& filename, size_t n_entries, bool zero_
   }
 }
 
+void CreateTestCSV(std::string const& path, size_t rows, size_t cols) {
+  std::vector<float> data(rows * cols);
+
+  for (size_t i = 0; i < rows * cols; ++i) {
+    data[i] = i;
+  }
+
+  std::ofstream fout(path);
+  size_t i = 0;
+  for (size_t r = 0; r < rows; ++r) {
+    for (size_t c = 0; c < cols; ++c) {
+      fout << data[i];
+      i++;
+      if (c != cols - 1) {
+        fout << ",";
+      }
+    }
+    fout << "\n";
+  }
+  fout.flush();
+  fout.close();
+}
+
 void CheckObjFunctionImpl(std::unique_ptr<xgboost::ObjFunction> const& obj,
                           std::vector<xgboost::bst_float> preds,
                           std::vector<xgboost::bst_float> labels,
diff --git a/tests/cpp/helpers.h b/tests/cpp/helpers.h
index 279e3f75951e..a059f0436117 100644
--- a/tests/cpp/helpers.h
+++ b/tests/cpp/helpers.h
@@ -59,6 +59,8 @@ void CreateSimpleTestData(const std::string& filename);
 // 0-based indexing.
 void CreateBigTestData(const std::string& filename, size_t n_entries, bool zero_based = true);
 
+void CreateTestCSV(std::string const& path, size_t rows, size_t cols);
+
 void CheckObjFunction(std::unique_ptr<xgboost::ObjFunction> const& obj,
                       std::vector<xgboost::bst_float> preds,
                       std::vector<xgboost::bst_float> labels,
diff --git a/tests/cpp/plugin/test_federated_data.cc b/tests/cpp/plugin/test_federated_data.cc
index 67194fcde320..8ac89e887a7a 100644
--- a/tests/cpp/plugin/test_federated_data.cc
+++ b/tests/cpp/plugin/test_federated_data.cc
@@ -12,6 +12,7 @@
 #include "../../../plugin/federated/federated_server.h"
 #include "../../../src/collective/communicator-inl.h"
 #include "../filesystem.h"
+#include "../helpers.h"
 #include "helpers.h"
 
 namespace xgboost {
@@ -23,29 +24,10 @@ class FederatedDataTest : public BaseFederatedTest {
 
     size_t constexpr kRows{16};
     size_t const kCols = 8 + rank;
-    std::vector<float> data(kRows * kCols);
-
-    for (size_t i = 0; i < kRows * kCols; ++i) {
-      data[i] = i;
-    }
 
     dmlc::TemporaryDirectory tmpdir;
     std::string path = tmpdir.path + "/small" + std::to_string(rank) + ".csv";
-
-    std::ofstream fout(path);
-    size_t i = 0;
-    for (size_t r = 0; r < kRows; ++r) {
-      for (size_t c = 0; c < kCols; ++c) {
-        fout << data[i];
-        i++;
-        if (c != kCols - 1) {
-          fout << ",";
-        }
-      }
-      fout << "\n";
-    }
-    fout.flush();
-    fout.close();
+    CreateTestCSV(path, kRows, kCols);
 
     std::unique_ptr<DMatrix> dmat;
     std::string uri = path + "?format=csv";

From 6ac73e3dbc2c6d305e7a91cfc4027c30a073ae57 Mon Sep 17 00:00:00 2001
From: Rong Ou
Date: Mon, 20 Mar 2023 11:55:22 -0700
Subject: [PATCH 6/7] fix distributed tests

---
 include/xgboost/data.h         | 24 +++++---------
 src/data/data.cc               | 48 ++++++++++++++--------------
 src/data/data.cu               |  8 ++---
 src/data/iterative_dmatrix.h   |  3 ---
 src/data/proxy_dmatrix.h       |  3 ---
 src/data/simple_dmatrix.cc     | 57 ++++++++++++++++++++++++----------
 src/data/simple_dmatrix.cu     |  9 ++++--
 src/data/simple_dmatrix.h      | 13 ++++++--
 src/data/sparse_page_dmatrix.h |  3 ---
 9 files changed, 92 insertions(+), 76 deletions(-)

diff --git a/include/xgboost/data.h b/include/xgboost/data.h
index dc7007760260..57f8a0e36594 100644
--- a/include/xgboost/data.h
+++ b/include/xgboost/data.h
@@ -572,17 +572,18 @@ class DMatrix {
    * \brief Creates a new DMatrix from an external data adapter.
    *
    * \tparam AdapterT  Type of the adapter.
-   * \param [in,out] adapter      View onto an external data.
-   * \param          missing      Values to count as missing.
-   * \param          nthread      Number of threads for construction.
-   * \param          cache_prefix (Optional) The cache prefix for external memory.
-   * \param          page_size    (Optional) Size of the page.
+   * \param [in,out] adapter         View onto an external data.
+   * \param          missing         Values to count as missing.
+   * \param          nthread         Number of threads for construction.
+   * \param          cache_prefix    (Optional) The cache prefix for external memory.
+   * \param          data_split_mode (Optional) Data split mode.
    *
    * \return a Created DMatrix.
    */
   template <typename AdapterT>
   static DMatrix* Create(AdapterT* adapter, float missing, int nthread,
-                         const std::string& cache_prefix = "");
+                         const std::string& cache_prefix = "",
+                         DataSplitMode data_split_mode = DataSplitMode::kRow);
@@ -645,17 +646,6 @@ class DMatrix {
    */
   virtual DMatrix *SliceCol(int num_slices, int slice_id) = 0;
 
-  /**
-   * \brief Reindex the features based on a global view.
-   *
-   * In some cases (e.g. vertical federated learning), features are loaded locally with indices
-   * starting from 0. However, all the algorithms assume the features are globally indexed, so we
-   * reindex the features based on the offset needed to obtain the global view.
-   *
-   * \param offset The offset to be added to the feature index
-   */
-  virtual void ReindexFeatures(uint64_t offset) = 0;
-
  protected:
   virtual BatchSet<SparsePage> GetRowBatches() = 0;
   virtual BatchSet<CSCPage> GetColumnBatches() = 0;
diff --git a/src/data/data.cc b/src/data/data.cc
index a0ffd75248be..6f5d52817c50 100644
--- a/src/data/data.cc
+++ b/src/data/data.cc
@@ -878,7 +878,7 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_s
         dmlc::Parser<uint32_t>::Create(fname.c_str(), partid, npart, file_format.c_str()));
     data::FileAdapter adapter(parser.get());
     dmat = DMatrix::Create(&adapter, std::numeric_limits<float>::quiet_NaN(), Context{}.Threads(),
-                           cache_file);
+                           cache_file, data_split_mode);
   } else {
     data::FileIterator iter{fname, static_cast<uint32_t>(partid), static_cast<uint32_t>(npart),
                             file_format};
@@ -914,18 +914,6 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_s
     LOG(FATAL) << "Encountered parser error:\n" << e.what();
   }
 
-  if (collective::IsFederated() && data_split_mode == DataSplitMode::kCol) {
-    std::vector<uint64_t> buffer(collective::GetWorldSize());
-    buffer[collective::GetRank()] = dmat->Info().num_col_;
-    collective::Allgather(buffer.data(), buffer.size() * sizeof(uint64_t));
-    auto offset = std::accumulate(buffer.cbegin(), buffer.cbegin() + collective::GetRank(), 0);
-    dmat->ReindexFeatures(offset);
-  }
-
-  dmat->Info().data_split_mode = data_split_mode;
-  dmat->Info().SynchronizeNumberOfColumns();
-
   if (need_split && data_split_mode == DataSplitMode::kCol) {
     if (!cache_file.empty()) {
       LOG(FATAL) << "Column-wise data split is not support for external memory.";
@@ -971,39 +959,49 @@ template DMatrix *DMatrix::Create
 
 template <typename AdapterT>
-DMatrix* DMatrix::Create(AdapterT* adapter, float missing, int nthread, const std::string&) {
-  return new data::SimpleDMatrix(adapter, missing, nthread);
+DMatrix* DMatrix::Create(AdapterT* adapter, float missing, int nthread, const std::string&,
+                         DataSplitMode data_split_mode) {
+  return new data::SimpleDMatrix(adapter, missing, nthread, data_split_mode);
 }
 
 template DMatrix* DMatrix::Create(data::DenseAdapter* adapter, float missing,
                                   std::int32_t nthread,
-                                  const std::string& cache_prefix);
+                                  const std::string& cache_prefix,
+                                  DataSplitMode data_split_mode);
 template DMatrix* DMatrix::Create(data::ArrayAdapter* adapter, float missing,
                                   std::int32_t nthread,
-                                  const std::string& cache_prefix);
+                                  const std::string& cache_prefix,
+                                  DataSplitMode data_split_mode);
 template DMatrix* DMatrix::Create(data::CSRAdapter* adapter, float missing,
                                   std::int32_t nthread,
-                                  const std::string& cache_prefix);
+                                  const std::string& cache_prefix,
+                                  DataSplitMode data_split_mode);
 template DMatrix* DMatrix::Create(data::CSCAdapter* adapter, float missing,
                                   std::int32_t nthread,
-                                  const std::string& cache_prefix);
+                                  const std::string& cache_prefix,
+                                  DataSplitMode data_split_mode);
 template DMatrix* DMatrix::Create(data::DataTableAdapter* adapter, float missing,
                                   std::int32_t nthread,
-                                  const std::string& cache_prefix);
+                                  const std::string& cache_prefix,
+                                  DataSplitMode data_split_mode);
 template DMatrix* DMatrix::Create(data::FileAdapter* adapter, float missing,
                                   std::int32_t nthread,
-                                  const std::string& cache_prefix);
+                                  const std::string& cache_prefix,
+                                  DataSplitMode data_split_mode);
 template DMatrix* DMatrix::Create(data::CSRArrayAdapter* adapter, float missing,
                                   std::int32_t nthread,
-                                  const std::string& cache_prefix);
+                                  const std::string& cache_prefix,
+                                  DataSplitMode data_split_mode);
 template DMatrix* DMatrix::Create(data::CSCArrayAdapter* adapter, float missing,
                                   std::int32_t nthread,
-                                  const std::string& cache_prefix);
+                                  const std::string& cache_prefix,
+                                  DataSplitMode data_split_mode);
 template DMatrix* DMatrix::Create(
     data::IteratorAdapter<DataIterHandle, XGBCallbackDataIterNext, XGBoostBatchCSR>* adapter,
-    float missing, int nthread, const std::string& cache_prefix);
+    float missing, int nthread, const std::string& cache_prefix, DataSplitMode data_split_mode);
 template DMatrix* DMatrix::Create(
-    data::RecordBatchesIterAdapter* adapter, float missing, int nthread, const std::string&);
+    data::RecordBatchesIterAdapter* adapter, float missing, int nthread, const std::string&,
+    DataSplitMode data_split_mode);
 
 SparsePage SparsePage::GetTranspose(int num_columns, int32_t n_threads) const {
   SparsePage transpose;
diff --git a/src/data/data.cu b/src/data/data.cu
index 4dedc7d24c4e..eccbe7567193 100644
--- a/src/data/data.cu
+++ b/src/data/data.cu
@@ -170,17 +170,17 @@ void MetaInfo::SetInfoFromCUDA(Context const& ctx, StringView key, Json array) {
 
 template <typename AdapterT>
 DMatrix* DMatrix::Create(AdapterT* adapter, float missing, int nthread,
-                         const std::string& cache_prefix) {
+                         const std::string& cache_prefix, DataSplitMode data_split_mode) {
   CHECK_EQ(cache_prefix.size(), 0)
       << "Device memory construction is not currently supported with external "
         "memory.";
-  return new data::SimpleDMatrix(adapter, missing, nthread);
+  return new data::SimpleDMatrix(adapter, missing, nthread, data_split_mode);
 }
 
 template DMatrix* DMatrix::Create(
     data::CudfAdapter* adapter, float missing, int nthread,
-    const std::string& cache_prefix);
+    const std::string& cache_prefix, DataSplitMode data_split_mode);
 template DMatrix* DMatrix::Create(
     data::CupyAdapter* adapter, float missing, int nthread,
-    const std::string& cache_prefix);
+    const std::string& cache_prefix, DataSplitMode data_split_mode);
 }  // namespace xgboost
diff --git a/src/data/iterative_dmatrix.h b/src/data/iterative_dmatrix.h
index f70be62595c9..28c4087c419a 100644
--- a/src/data/iterative_dmatrix.h
+++ b/src/data/iterative_dmatrix.h
@@ -90,9 +90,6 @@ class IterativeDMatrix : public DMatrix {
     LOG(FATAL) << "Slicing DMatrix columns is not supported for Quantile DMatrix.";
     return nullptr;
   }
-  void ReindexFeatures(uint64_t offset) override {
-    LOG(FATAL) << "Reindexing features is not supported for Quantile DMatrix.";
-  }
   BatchSet<SparsePage> GetRowBatches() override {
     LOG(FATAL) << "Not implemented.";
     return BatchSet<SparsePage>(BatchIterator<SparsePage>(nullptr));
diff --git a/src/data/proxy_dmatrix.h b/src/data/proxy_dmatrix.h
index ec550cce3984..7a15d6498745 100644
--- a/src/data/proxy_dmatrix.h
+++ b/src/data/proxy_dmatrix.h
@@ -89,9 +89,6 @@ class DMatrixProxy : public DMatrix {
     LOG(FATAL) << "Slicing DMatrix columns is not supported for Proxy DMatrix.";
     return nullptr;
   }
-  void ReindexFeatures(uint64_t offset) override {
-    LOG(FATAL) << "Reindexing features is not supported for Proxy DMatrix.";
-  }
   BatchSet<SparsePage> GetRowBatches() override {
     LOG(FATAL) << "Not implemented.";
     return BatchSet<SparsePage>(BatchIterator<SparsePage>(nullptr));
diff --git a/src/data/simple_dmatrix.cc b/src/data/simple_dmatrix.cc
index 27c4bdce721e..098c3c4f2585 100644
--- a/src/data/simple_dmatrix.cc
+++ b/src/data/simple_dmatrix.cc
@@ -73,11 +73,17 @@ DMatrix* SimpleDMatrix::SliceCol(int num_slices, int slice_id) {
   return out;
 }
 
-void SimpleDMatrix::ReindexFeatures(uint64_t offset) {
-  if (offset == 0) {
-    return;
+void SimpleDMatrix::ReindexFeatures() {
+  if (collective::IsFederated() && info_.data_split_mode == DataSplitMode::kCol) {
+    std::vector<uint64_t> buffer(collective::GetWorldSize());
+    buffer[collective::GetRank()] = info_.num_col_;
+    collective::Allgather(buffer.data(), buffer.size() * sizeof(uint64_t));
+    auto offset = std::accumulate(buffer.cbegin(), buffer.cbegin() + collective::GetRank(), 0);
+    if (offset == 0) {
+      return;
+    }
+    sparse_page_->Reindex(offset, ctx_.Threads());
   }
-  sparse_page_->Reindex(offset, Ctx()->Threads());
 }
 
 BatchSet<SparsePage> SimpleDMatrix::GetRowBatches() {
@@ -158,7 +164,8 @@ BatchSet<ExtSparsePage> SimpleDMatrix::GetExtBatches(BatchParam const&) {
 }
 
 template <typename AdapterT>
-SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) {
+SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread,
+                             DataSplitMode data_split_mode) {
   this->ctx_.nthread = nthread;
 
   std::vector<uint64_t> qids;
@@ -222,6 +229,12 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) {
     info_.num_col_ = adapter->NumColumns();
   }
 
+  // Synchronise worker columns
+  info_.data_split_mode = data_split_mode;
+  ReindexFeatures();
+  info_.SynchronizeNumberOfColumns();
+
   if (adapter->NumRows() == kAdapterUnknownSize) {
     using IteratorAdapterT =
         IteratorAdapter<DataIterHandle, XGBCallbackDataIterNext, XGBoostBatchCSR>;
@@ -275,22 +288,31 @@ void SimpleDMatrix::SaveToLocalFile(const std::string& fname) {
   fo->Write(sparse_page_->data.HostVector());
 }
 
-template SimpleDMatrix::SimpleDMatrix(DenseAdapter* adapter, float missing, int nthread);
-template SimpleDMatrix::SimpleDMatrix(ArrayAdapter* adapter, float missing, int nthread);
-template SimpleDMatrix::SimpleDMatrix(CSRAdapter* adapter, float missing, int nthread);
-template SimpleDMatrix::SimpleDMatrix(CSRArrayAdapter* adapter, float missing, int nthread);
-template SimpleDMatrix::SimpleDMatrix(CSCArrayAdapter* adapter, float missing, int nthread);
-template SimpleDMatrix::SimpleDMatrix(CSCAdapter* adapter, float missing, int nthread);
-template SimpleDMatrix::SimpleDMatrix(DataTableAdapter* adapter, float missing, int nthread);
-template SimpleDMatrix::SimpleDMatrix(FileAdapter* adapter, float missing, int nthread);
+template SimpleDMatrix::SimpleDMatrix(DenseAdapter* adapter, float missing, int nthread,
+                                      DataSplitMode data_split_mode);
+template SimpleDMatrix::SimpleDMatrix(ArrayAdapter* adapter, float missing, int nthread,
+                                      DataSplitMode data_split_mode);
+template SimpleDMatrix::SimpleDMatrix(CSRAdapter* adapter, float missing, int nthread,
+                                      DataSplitMode data_split_mode);
+template SimpleDMatrix::SimpleDMatrix(CSRArrayAdapter* adapter, float missing, int nthread,
+                                      DataSplitMode data_split_mode);
+template SimpleDMatrix::SimpleDMatrix(CSCArrayAdapter* adapter, float missing, int nthread,
+                                      DataSplitMode data_split_mode);
+template SimpleDMatrix::SimpleDMatrix(CSCAdapter* adapter, float missing, int nthread,
+                                      DataSplitMode data_split_mode);
+template SimpleDMatrix::SimpleDMatrix(DataTableAdapter* adapter, float missing, int nthread,
+                                      DataSplitMode data_split_mode);
+template SimpleDMatrix::SimpleDMatrix(FileAdapter* adapter, float missing, int nthread,
+                                      DataSplitMode data_split_mode);
 template SimpleDMatrix::SimpleDMatrix(
     IteratorAdapter<DataIterHandle, XGBCallbackDataIterNext, XGBoostBatchCSR> *adapter,
-    float missing, int nthread);
+    float missing, int nthread, DataSplitMode data_split_mode);
 
 template <>
-SimpleDMatrix::SimpleDMatrix(RecordBatchesIterAdapter* adapter, float missing, int nthread) {
-  ctx_.nthread = nthread;
+SimpleDMatrix::SimpleDMatrix(RecordBatchesIterAdapter* adapter, float missing, int nthread,
+                             DataSplitMode data_split_mode) {
+  ctx_.nthread = nthread;
 
   auto& offset_vec = sparse_page_->offset.HostVector();
   auto& data_vec = sparse_page_->data.HostVector();
@@ -349,7 +371,10 @@ SimpleDMatrix::SimpleDMatrix(RecordBatchesIterAdapter* adapter, float missing, i
   }
   // Synchronise worker columns
   info_.num_col_ = adapter->NumColumns();
+  info_.data_split_mode = data_split_mode;
+  ReindexFeatures();
   info_.SynchronizeNumberOfColumns();
+
   info_.num_row_ = total_batch_size;
   info_.num_nonzero_ = data_vec.size();
   CHECK_EQ(offset_vec.back(), info_.num_nonzero_);
diff --git a/src/data/simple_dmatrix.cu b/src/data/simple_dmatrix.cu
index 8295718ba221..89a91ab57bb3 100644
--- a/src/data/simple_dmatrix.cu
+++ b/src/data/simple_dmatrix.cu
@@ -15,7 +15,8 @@ namespace data {
 // Current implementation assumes a single batch. More batches can
 // be supported in future. Does not currently support inferring row/column size
 template <typename AdapterT>
-SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int32_t /*nthread*/) {
+SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int32_t /*nthread*/,
+                             DataSplitMode data_split_mode) {
   auto device = (adapter->DeviceIdx() < 0 || adapter->NumRows() == 0) ?
                    dh::CurrentDevice() : adapter->DeviceIdx();
   CHECK_GE(device, 0);
@@ -35,12 +36,14 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int32_t /*nthread
   info_.num_col_ = adapter->NumColumns();
   info_.num_row_ = adapter->NumRows();
   // Synchronise worker columns
+  info_.data_split_mode = data_split_mode;
+  ReindexFeatures();
   info_.SynchronizeNumberOfColumns();
 }
 
 template SimpleDMatrix::SimpleDMatrix(CudfAdapter* adapter, float missing,
-                                      int nthread);
+                                      int nthread, DataSplitMode data_split_mode);
 template SimpleDMatrix::SimpleDMatrix(CupyAdapter* adapter, float missing,
-                                      int nthread);
+                                      int nthread, DataSplitMode data_split_mode);
 }  // namespace data
 }  // namespace xgboost
diff --git a/src/data/simple_dmatrix.h b/src/data/simple_dmatrix.h
index fb60ebc95ff1..853e765af72e 100644
--- a/src/data/simple_dmatrix.h
+++ b/src/data/simple_dmatrix.h
@@ -22,7 +22,8 @@ class SimpleDMatrix : public DMatrix {
  public:
   SimpleDMatrix() = default;
   template <typename AdapterT>
-  explicit SimpleDMatrix(AdapterT* adapter, float missing, int nthread);
+  explicit SimpleDMatrix(AdapterT* adapter, float missing, int nthread,
+                         DataSplitMode data_split_mode = DataSplitMode::kRow);
   explicit SimpleDMatrix(dmlc::Stream* in_stream);
   ~SimpleDMatrix() override = default;
 
@@ -36,7 +37,6 @@ class SimpleDMatrix : public DMatrix {
   bool SingleColBlock() const override { return true; }
   DMatrix* Slice(common::Span<int32_t const> ridxs) override;
   DMatrix* SliceCol(int num_slices, int slice_id) override;
-  void ReindexFeatures(uint64_t offset) override;
 
   /*! \brief magic number used to identify SimpleDMatrix binary files */
   static const int kMagic = 0xffffab01;
@@ -62,6 +62,15 @@ class SimpleDMatrix : public DMatrix {
   bool GHistIndexExists() const override { return static_cast<bool>(gradient_index_); }
   bool SparsePageExists() const override { return true; }
 
+  /**
+   * \brief Reindex the features based on a global view.
+   *
+   * In some cases (e.g. vertical federated learning), features are loaded locally with indices
+   * starting from 0. However, all the algorithms assume the features are globally indexed, so we
+   * reindex the features based on the offset needed to obtain the global view.
+   */
+  void ReindexFeatures();
+
  private:
   Context ctx_;
 };
diff --git a/src/data/sparse_page_dmatrix.h b/src/data/sparse_page_dmatrix.h
index 03d43d9a9a20..aa0be69845aa 100644
--- a/src/data/sparse_page_dmatrix.h
+++ b/src/data/sparse_page_dmatrix.h
@@ -111,9 +111,6 @@ class SparsePageDMatrix : public DMatrix {
     LOG(FATAL) << "Slicing DMatrix columns is not supported for external memory.";
     return nullptr;
   }
-  void ReindexFeatures(uint64_t offset) override {
-    LOG(FATAL) << "Reindexing features is not supported for external memory.";
-  }
 
  private:
   BatchSet<SparsePage> GetRowBatches() override;
From 919dadb250c1ce3452b2f3e41e7a57624cb099be Mon Sep 17 00:00:00 2001
From: Rong Ou
Date: Tue, 21 Mar 2023 17:46:10 -0700
Subject: [PATCH 7/7] no reindexing in gpu code

---
 src/data/simple_dmatrix.cu | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/data/simple_dmatrix.cu b/src/data/simple_dmatrix.cu
index 89a91ab57bb3..fc09f52c457d 100644
--- a/src/data/simple_dmatrix.cu
+++ b/src/data/simple_dmatrix.cu
@@ -17,6 +17,8 @@ namespace data {
 template <typename AdapterT>
 SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int32_t /*nthread*/,
                              DataSplitMode data_split_mode) {
+  CHECK(data_split_mode != DataSplitMode::kCol)
+      << "Column-wise data split is currently not supported on the GPU.";
   auto device = (adapter->DeviceIdx() < 0 || adapter->NumRows() == 0) ?
                    dh::CurrentDevice() : adapter->DeviceIdx();
@@ -37,7 +39,6 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int32_t /*nthread
   info_.num_row_ = adapter->NumRows();
   // Synchronise worker columns
   info_.data_split_mode = data_split_mode;
-  ReindexFeatures();
   info_.SynchronizeNumberOfColumns();
 }
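Taken together, a sketch of the load path this series enables. It assumes an XGBoost build with the federated plugin, initialized communicators, and a per-worker CSV at `path`; the helper name is hypothetical, but the call mirrors the one in test_federated_data.cc:

```cpp
// Usage sketch: each federated worker loads its own CSV column slice.
#include <xgboost/data.h>

#include <memory>
#include <string>

std::unique_ptr<xgboost::DMatrix> LoadColumnSplit(std::string const& path) {
  std::string uri = path + "?format=csv";
  // With DataSplitMode::kCol, Load() shifts this worker's feature indices by
  // the combined width of the workers before it, then sums the widths across
  // workers so every worker sees the same global num_col_.
  return std::unique_ptr<xgboost::DMatrix>(
      xgboost::DMatrix::Load(uri, /*silent=*/false, xgboost::DataSplitMode::kCol));
}
```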