From be2f29313945c3cf0132bed34c83d194641bf021 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Fri, 11 Aug 2023 01:58:11 +0800 Subject: [PATCH 1/4] normalize path. --- R-package/src/Makevars.in | 1 + R-package/src/Makevars.win | 1 + R-package/tests/testthat/test_dmatrix.R | 1 + src/common/io.cc | 4 +- src/data/data.cc | 124 +++++++++++++----------- src/data/file_iterator.cc | 50 ++++++++++ src/data/file_iterator.h | 60 +++--------- 7 files changed, 140 insertions(+), 101 deletions(-) create mode 100644 src/data/file_iterator.cc diff --git a/R-package/src/Makevars.in b/R-package/src/Makevars.in index a93f773f944d..9e7cbfed4c18 100644 --- a/R-package/src/Makevars.in +++ b/R-package/src/Makevars.in @@ -47,6 +47,7 @@ OBJECTS= \ $(PKGROOT)/src/data/data.o \ $(PKGROOT)/src/data/sparse_page_raw_format.o \ $(PKGROOT)/src/data/ellpack_page.o \ + $(PKGROOT)/src/data/file_iterator.o \ $(PKGROOT)/src/data/gradient_index.o \ $(PKGROOT)/src/data/gradient_index_page_source.o \ $(PKGROOT)/src/data/gradient_index_format.o \ diff --git a/R-package/src/Makevars.win b/R-package/src/Makevars.win index d2f47b2aaaae..7dfa415a43d1 100644 --- a/R-package/src/Makevars.win +++ b/R-package/src/Makevars.win @@ -47,6 +47,7 @@ OBJECTS= \ $(PKGROOT)/src/data/data.o \ $(PKGROOT)/src/data/sparse_page_raw_format.o \ $(PKGROOT)/src/data/ellpack_page.o \ + $(PKGROOT)/src/data/file_iterator.o \ $(PKGROOT)/src/data/gradient_index.o \ $(PKGROOT)/src/data/gradient_index_page_source.o \ $(PKGROOT)/src/data/gradient_index_format.o \ diff --git a/R-package/tests/testthat/test_dmatrix.R b/R-package/tests/testthat/test_dmatrix.R index 21d39f255b7b..57cc82c170ed 100644 --- a/R-package/tests/testthat/test_dmatrix.R +++ b/R-package/tests/testthat/test_dmatrix.R @@ -72,6 +72,7 @@ test_that("xgb.DMatrix: saving, loading", { tmp <- c("0 1:1 2:1", "1 3:1", "0 1:1") tmp_file <- tempfile(fileext = ".libsvm") writeLines(tmp, tmp_file) + expect_true(file.exists(tmp_file)) dtest4 <- xgb.DMatrix(paste(tmp_file, "?format=libsvm", sep = ""), silent = TRUE) expect_equal(dim(dtest4), c(3, 4)) expect_equal(getinfo(dtest4, 'label'), c(0, 1, 0)) diff --git a/src/common/io.cc b/src/common/io.cc index 1e15c417388c..d2cf58e282f7 100644 --- a/src/common/io.cc +++ b/src/common/io.cc @@ -28,6 +28,7 @@ #include // for size_t #include // for int32_t, uint32_t #include // for memcpy +#include // for weakly_canonical #include // for filesystem #include // for ifstream #include // for distance @@ -154,7 +155,8 @@ std::string LoadSequentialFile(std::string uri, bool stream) { // Open in binary mode so that correct file size can be computed with // seekg(). This accommodates Windows platform: // https://docs.microsoft.com/en-us/cpp/standard-library/basic-istream-class?view=vs-2019#seekg - std::ifstream ifs(std::filesystem::u8path(uri), std::ios_base::binary | std::ios_base::in); + auto path = std::filesystem::weakly_canonical(std::filesystem::u8path(uri)); + std::ifstream ifs(path, std::ios_base::binary | std::ios_base::in); if (!ifs) { // https://stackoverflow.com/a/17338934 OpenErr(); diff --git a/src/data/data.cc b/src/data/data.cc index 7c76c6d25335..e8ecccb81d13 100644 --- a/src/data/data.cc +++ b/src/data/data.cc @@ -4,42 +4,57 @@ */ #include "xgboost/data.h" -#include - -#include -#include -#include - -#include "../collective/communicator-inl.h" -#include "../collective/communicator.h" -#include "../common/algorithm.h" // for StableSort -#include "../common/api_entry.h" // for XGBAPIThreadLocalEntry -#include "../common/common.h" -#include "../common/error_msg.h" // for InfInData, GroupWeight, GroupSize -#include "../common/group_data.h" -#include "../common/io.h" -#include "../common/linalg_op.h" -#include "../common/math.h" -#include "../common/numeric.h" // for Iota -#include "../common/threading_utils.h" -#include "../common/version.h" -#include "../data/adapter.h" -#include "../data/iterative_dmatrix.h" -#include "./sparse_page_dmatrix.h" -#include "./sparse_page_source.h" -#include "dmlc/io.h" -#include "file_iterator.h" -#include "simple_dmatrix.h" -#include "sparse_page_writer.h" -#include "validation.h" -#include "xgboost/c_api.h" -#include "xgboost/context.h" -#include "xgboost/host_device_vector.h" -#include "xgboost/learner.h" -#include "xgboost/linalg.h" // Vector -#include "xgboost/logging.h" -#include "xgboost/string_view.h" -#include "xgboost/version_config.h" +#include // for DMLC_REGISTRY_ENABLE, DMLC_REGISTRY_LINK_TAG + +#include // for copy, max, none_of, min +#include // for atomic +#include // for abs +#include // for uint64_t, int32_t, uint8_t, uint32_t +#include // for size_t, strcmp, memcpy +#include // for exception +#include // for operator<<, basic_ostream, basic_ostream::op... +#include // for map, operator!= +#include // for accumulate, partial_sum +#include // for get, apply +#include // for remove_pointer_t, remove_reference + +#include "../collective/communicator-inl.h" // for GetRank, GetWorldSize, Allreduce, IsFederated +#include "../collective/communicator.h" // for Operation +#include "../common/algorithm.h" // for StableSort +#include "../common/api_entry.h" // for XGBAPIThreadLocalEntry +#include "../common/common.h" // for Split +#include "../common/error_msg.h" // for GroupSize, GroupWeight, InfInData +#include "../common/group_data.h" // for ParallelGroupBuilder +#include "../common/io.h" // for PeekableInStream +#include "../common/linalg_op.h" // for ElementWiseTransformHost +#include "../common/math.h" // for CheckNAN +#include "../common/numeric.h" // for Iota, RunLengthEncode +#include "../common/threading_utils.h" // for ParallelFor +#include "../common/version.h" // for Version +#include "../data/adapter.h" // for COOTuple, FileAdapter, IsValidFunctor +#include "../data/iterative_dmatrix.h" // for IterativeDMatrix +#include "./sparse_page_dmatrix.h" // for SparsePageDMatrix +#include "array_interface.h" // for ArrayInterfaceHandler, ArrayInterface, Dispa... +#include "dmlc/base.h" // for BeginPtr +#include "dmlc/common.h" // for OMPException +#include "dmlc/data.h" // for Parser +#include "dmlc/endian.h" // for ByteSwap, DMLC_IO_NO_ENDIAN_SWAP +#include "dmlc/io.h" // for Stream +#include "dmlc/thread_local.h" // for ThreadLocalStore +#include "ellpack_page.h" // for EllpackPage +#include "file_iterator.h" // for ValidateFileFormat, FileIterator, Next, Reset +#include "gradient_index.h" // for GHistIndexMatrix +#include "simple_dmatrix.h" // for SimpleDMatrix +#include "sparse_page_writer.h" // for SparsePageFormatReg +#include "validation.h" // for LabelsCheck, WeightsCheck, ValidateQueryGroup +#include "xgboost/base.h" // for bst_group_t, bst_row_t, bst_float, bst_ulong +#include "xgboost/context.h" // for Context +#include "xgboost/host_device_vector.h" // for HostDeviceVector +#include "xgboost/learner.h" // for HostDeviceVector +#include "xgboost/linalg.h" // for Tensor, Stack, TensorView, Vector, ArrayInte... +#include "xgboost/logging.h" // for Error, LogCheck_EQ, CHECK, CHECK_EQ, LOG +#include "xgboost/span.h" // for Span, operator!=, SpanIterator +#include "xgboost/string_view.h" // for operator==, operator<<, StringView namespace dmlc { DMLC_REGISTRY_ENABLE(::xgboost::data::SparsePageFormatReg<::xgboost::SparsePage>); @@ -811,10 +826,10 @@ DMatrix::~DMatrix() { } } -DMatrix *TryLoadBinary(std::string fname, bool silent) { - int magic; - std::unique_ptr fi( - dmlc::Stream::Create(fname.c_str(), "r", true)); +namespace { +DMatrix* TryLoadBinary(std::string fname, bool silent) { + std::int32_t magic; + std::unique_ptr fi(dmlc::Stream::Create(fname.c_str(), "r", true)); if (fi != nullptr) { common::PeekableInStream is(fi.get()); if (is.PeekRead(&magic, sizeof(magic)) == sizeof(magic)) { @@ -822,11 +837,10 @@ DMatrix *TryLoadBinary(std::string fname, bool silent) { dmlc::ByteSwap(&magic, sizeof(magic), 1); } if (magic == data::SimpleDMatrix::kMagic) { - DMatrix *dmat = new data::SimpleDMatrix(&is); + DMatrix* dmat = new data::SimpleDMatrix(&is); if (!silent) { - LOG(CONSOLE) << dmat->Info().num_row_ << 'x' << dmat->Info().num_col_ - << " matrix with " << dmat->Info().num_nonzero_ - << " entries loaded from " << fname; + LOG(CONSOLE) << dmat->Info().num_row_ << 'x' << dmat->Info().num_col_ << " matrix with " + << dmat->Info().num_nonzero_ << " entries loaded from " << fname; } return dmat; } @@ -834,6 +848,7 @@ DMatrix *TryLoadBinary(std::string fname, bool silent) { } return nullptr; } +} // namespace DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_split_mode) { auto need_split = false; @@ -845,7 +860,7 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_s } std::string fname, cache_file; - size_t dlm_pos = uri.find('#'); + auto dlm_pos = uri.find('#'); if (dlm_pos != std::string::npos) { cache_file = uri.substr(dlm_pos + 1, uri.length()); fname = uri.substr(0, dlm_pos); @@ -857,14 +872,11 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_s for (size_t i = 0; i < cache_shards.size(); ++i) { size_t pos = cache_shards[i].rfind('.'); if (pos == std::string::npos) { - os << cache_shards[i] - << ".r" << collective::GetRank() - << "-" << collective::GetWorldSize(); + os << cache_shards[i] << ".r" << collective::GetRank() << "-" + << collective::GetWorldSize(); } else { - os << cache_shards[i].substr(0, pos) - << ".r" << collective::GetRank() - << "-" << collective::GetWorldSize() - << cache_shards[i].substr(pos, cache_shards[i].length()); + os << cache_shards[i].substr(0, pos) << ".r" << collective::GetRank() << "-" + << collective::GetWorldSize() << cache_shards[i].substr(pos, cache_shards[i].length()); } if (i + 1 != cache_shards.size()) { os << ':'; @@ -895,12 +907,12 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_s LOG(CONSOLE) << "Load part of data " << partid << " of " << npart << " parts"; } - data::ValidateFileFormat(fname); - DMatrix* dmat {nullptr}; + DMatrix* dmat{nullptr}; if (cache_file.empty()) { - std::unique_ptr> parser( - dmlc::Parser::Create(fname.c_str(), partid, npart, "auto")); + fname = data::ValidateFileFormat(fname); + std::unique_ptr> parser( + dmlc::Parser::Create(fname.c_str(), partid, npart, "auto")); data::FileAdapter adapter(parser.get()); dmat = DMatrix::Create(&adapter, std::numeric_limits::quiet_NaN(), Context{}.Threads(), cache_file, data_split_mode); diff --git a/src/data/file_iterator.cc b/src/data/file_iterator.cc new file mode 100644 index 000000000000..58047c7e13e2 --- /dev/null +++ b/src/data/file_iterator.cc @@ -0,0 +1,50 @@ +/** + * Copyright 2021-2023, XGBoost contributors + */ +#include "file_iterator.h" + +#include // for LogCheck_EQ, LogCheck_LE, CHECK_EQ, CHECK_LE, LOG, LOG_... + +#include // for weakly_canonical, path +#include // for map, operator== +#include // for operator<<, basic_ostream, istringstream +#include // for vector + +#include "../common/common.h" // for Split +#include "xgboost/string_view.h" // for operator<<, StringView + +namespace xgboost::data { +std::string ValidateFileFormat(std::string const& uri) { + std::vector name_args_cache = common::Split(uri, '#'); + CHECK_LE(name_args_cache.size(), 2) + << "Only one `#` is allowed in file path for cachefile specification"; + + std::vector name_args = common::Split(name_args_cache[0], '?'); + StringView msg{"URI parameter `format` is required for loading text data: filename?format=csv"}; + CHECK_EQ(name_args.size(), 2) << msg; + + std::map args; + std::vector arg_list = common::Split(name_args[1], '&'); + for (size_t i = 0; i < arg_list.size(); ++i) { + std::istringstream is(arg_list[i]); + std::pair kv; + CHECK(std::getline(is, kv.first, '=')) << "Invalid uri argument format" + << " for key in arg " << i + 1; + CHECK(std::getline(is, kv.second)) << "Invalid uri argument format" + << " for value in arg " << i + 1; + args.insert(kv); + } + if (args.find("format") == args.cend()) { + LOG(FATAL) << msg; + } + + auto path = common::Split(uri, '?')[0]; + namespace fs = std::filesystem; + name_args[0] = fs::weakly_canonical(path); + if (name_args_cache.size() == 1) { + return name_args[0] + "?" + name_args[1]; + } else { + return name_args[0] + "?" + name_args[1] + '#' + name_args_cache[1]; + } +} +} // namespace xgboost::data diff --git a/src/data/file_iterator.h b/src/data/file_iterator.h index 4d7239677561..c7f23b478879 100644 --- a/src/data/file_iterator.h +++ b/src/data/file_iterator.h @@ -4,46 +4,20 @@ #ifndef XGBOOST_DATA_FILE_ITERATOR_H_ #define XGBOOST_DATA_FILE_ITERATOR_H_ -#include -#include -#include -#include -#include - -#include "array_interface.h" -#include "dmlc/data.h" -#include "xgboost/c_api.h" -#include "xgboost/json.h" -#include "xgboost/linalg.h" - -namespace xgboost { -namespace data { -inline void ValidateFileFormat(std::string const& uri) { - std::vector name_cache = common::Split(uri, '#'); - CHECK_LE(name_cache.size(), 2) - << "Only one `#` is allowed in file path for cachefile specification"; - - std::vector name_args = common::Split(name_cache[0], '?'); - CHECK_LE(name_args.size(), 2) << "only one `?` is allowed in file path."; - - StringView msg{"URI parameter `format` is required for loading text data: filename?format=csv"}; - CHECK_EQ(name_args.size(), 2) << msg; - - std::map args; - std::vector arg_list = common::Split(name_args[1], '&'); - for (size_t i = 0; i < arg_list.size(); ++i) { - std::istringstream is(arg_list[i]); - std::pair kv; - CHECK(std::getline(is, kv.first, '=')) << "Invalid uri argument format" - << " for key in arg " << i + 1; - CHECK(std::getline(is, kv.second)) << "Invalid uri argument format" - << " for value in arg " << i + 1; - args.insert(kv); - } - if (args.find("format") == args.cend()) { - LOG(FATAL) << msg; - } -} +#include // for max_element +#include // for size_t +#include // for uint32_t +#include // for unique_ptr +#include // for string +#include // for move + +#include "dmlc/data.h" // for RowBlock, Parser +#include "xgboost/c_api.h" // for XGDMatrixSetDenseInfo, XGDMatrixFree, XGProxyDMatrixCreate +#include "xgboost/linalg.h" // for ArrayInterfaceStr, MakeVec +#include "xgboost/logging.h" // for CHECK + +namespace xgboost::data { +[[nodiscard]] std::string ValidateFileFormat(std::string const& uri); /** * An iterator for implementing external memory support with file inputs. Users of @@ -72,8 +46,7 @@ class FileIterator { public: FileIterator(std::string uri, unsigned part_index, unsigned num_parts) - : uri_{std::move(uri)}, part_idx_{part_index}, n_parts_{num_parts} { - ValidateFileFormat(uri_); + : uri_{ValidateFileFormat(std::move(uri))}, part_idx_{part_index}, n_parts_{num_parts} { XGProxyDMatrixCreate(&proxy_); } ~FileIterator() { @@ -132,6 +105,5 @@ inline int Next(DataIterHandle self) { return static_cast(self)->Next(); } } // namespace fileiter -} // namespace data -} // namespace xgboost +} // namespace xgboost::data #endif // XGBOOST_DATA_FILE_ITERATOR_H_ From 1fbe3b967681f49599f356951e8314fe2f012ffd Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Fri, 11 Aug 2023 02:27:49 +0800 Subject: [PATCH 2/4] windows. --- src/data/file_iterator.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/data/file_iterator.cc b/src/data/file_iterator.cc index 58047c7e13e2..6b8e11748b10 100644 --- a/src/data/file_iterator.cc +++ b/src/data/file_iterator.cc @@ -40,7 +40,7 @@ std::string ValidateFileFormat(std::string const& uri) { auto path = common::Split(uri, '?')[0]; namespace fs = std::filesystem; - name_args[0] = fs::weakly_canonical(path); + name_args[0] = fs::weakly_canonical(path).string(); if (name_args_cache.size() == 1) { return name_args[0] + "?" + name_args[1]; } else { From e06845f1fc6c7726af265c92c1dbc77c9353bddf Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Fri, 11 Aug 2023 02:39:39 +0800 Subject: [PATCH 3/4] fix. --- src/common/io.cc | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/common/io.cc b/src/common/io.cc index d2cf58e282f7..8dbeba935689 100644 --- a/src/common/io.cc +++ b/src/common/io.cc @@ -28,8 +28,7 @@ #include // for size_t #include // for int32_t, uint32_t #include // for memcpy -#include // for weakly_canonical -#include // for filesystem +#include // for filesystem, weakly_canonical #include // for ifstream #include // for distance #include // for numeric_limits From 5744a06c68904ec86455ecc1106f6f21f8b001cf Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Fri, 11 Aug 2023 15:05:28 +0800 Subject: [PATCH 4/4] u8path. --- src/data/file_iterator.cc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/data/file_iterator.cc b/src/data/file_iterator.cc index 6b8e11748b10..cebfbdc19f65 100644 --- a/src/data/file_iterator.cc +++ b/src/data/file_iterator.cc @@ -5,7 +5,7 @@ #include // for LogCheck_EQ, LogCheck_LE, CHECK_EQ, CHECK_LE, LOG, LOG_... -#include // for weakly_canonical, path +#include // for weakly_canonical, path, u8path #include // for map, operator== #include // for operator<<, basic_ostream, istringstream #include // for vector @@ -39,8 +39,9 @@ std::string ValidateFileFormat(std::string const& uri) { } auto path = common::Split(uri, '?')[0]; + namespace fs = std::filesystem; - name_args[0] = fs::weakly_canonical(path).string(); + name_args[0] = fs::weakly_canonical(fs::u8path(path)).string(); if (name_args_cache.size() == 1) { return name_args[0] + "?" + name_args[1]; } else {