Skip to content

Commit

Permalink
Normalize file system path. (#9463)
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Aug 11, 2023
1 parent bdc1a3c commit bb56183
Show file tree
Hide file tree
Showing 7 changed files with 141 additions and 102 deletions.
1 change: 1 addition & 0 deletions R-package/src/Makevars.in
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ OBJECTS= \
$(PKGROOT)/src/data/data.o \
$(PKGROOT)/src/data/sparse_page_raw_format.o \
$(PKGROOT)/src/data/ellpack_page.o \
$(PKGROOT)/src/data/file_iterator.o \
$(PKGROOT)/src/data/gradient_index.o \
$(PKGROOT)/src/data/gradient_index_page_source.o \
$(PKGROOT)/src/data/gradient_index_format.o \
Expand Down
1 change: 1 addition & 0 deletions R-package/src/Makevars.win
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ OBJECTS= \
$(PKGROOT)/src/data/data.o \
$(PKGROOT)/src/data/sparse_page_raw_format.o \
$(PKGROOT)/src/data/ellpack_page.o \
$(PKGROOT)/src/data/file_iterator.o \
$(PKGROOT)/src/data/gradient_index.o \
$(PKGROOT)/src/data/gradient_index_page_source.o \
$(PKGROOT)/src/data/gradient_index_format.o \
Expand Down
1 change: 1 addition & 0 deletions R-package/tests/testthat/test_dmatrix.R
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ test_that("xgb.DMatrix: saving, loading", {
tmp <- c("0 1:1 2:1", "1 3:1", "0 1:1")
tmp_file <- tempfile(fileext = ".libsvm")
writeLines(tmp, tmp_file)
expect_true(file.exists(tmp_file))
dtest4 <- xgb.DMatrix(paste(tmp_file, "?format=libsvm", sep = ""), silent = TRUE)
expect_equal(dim(dtest4), c(3, 4))
expect_equal(getinfo(dtest4, 'label'), c(0, 1, 0))
Expand Down
5 changes: 3 additions & 2 deletions src/common/io.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
#include <cstddef> // for size_t
#include <cstdint> // for int32_t, uint32_t
#include <cstring> // for memcpy
#include <filesystem> // for filesystem
#include <filesystem> // for filesystem, weakly_canonical
#include <fstream> // for ifstream
#include <iterator> // for distance
#include <limits> // for numeric_limits
Expand Down Expand Up @@ -154,7 +154,8 @@ std::string LoadSequentialFile(std::string uri, bool stream) {
// Open in binary mode so that correct file size can be computed with
// seekg(). This accommodates Windows platform:
// https://docs.microsoft.com/en-us/cpp/standard-library/basic-istream-class?view=vs-2019#seekg
std::ifstream ifs(std::filesystem::u8path(uri), std::ios_base::binary | std::ios_base::in);
auto path = std::filesystem::weakly_canonical(std::filesystem::u8path(uri));
std::ifstream ifs(path, std::ios_base::binary | std::ios_base::in);
if (!ifs) {
// https://stackoverflow.com/a/17338934
OpenErr();
Expand Down
124 changes: 68 additions & 56 deletions src/data/data.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,42 +4,57 @@
*/
#include "xgboost/data.h"

#include <dmlc/registry.h>

#include <array>
#include <cstddef>
#include <cstring>

#include "../collective/communicator-inl.h"
#include "../collective/communicator.h"
#include "../common/algorithm.h" // for StableSort
#include "../common/api_entry.h" // for XGBAPIThreadLocalEntry
#include "../common/common.h"
#include "../common/error_msg.h" // for InfInData, GroupWeight, GroupSize
#include "../common/group_data.h"
#include "../common/io.h"
#include "../common/linalg_op.h"
#include "../common/math.h"
#include "../common/numeric.h" // for Iota
#include "../common/threading_utils.h"
#include "../common/version.h"
#include "../data/adapter.h"
#include "../data/iterative_dmatrix.h"
#include "./sparse_page_dmatrix.h"
#include "./sparse_page_source.h"
#include "dmlc/io.h"
#include "file_iterator.h"
#include "simple_dmatrix.h"
#include "sparse_page_writer.h"
#include "validation.h"
#include "xgboost/c_api.h"
#include "xgboost/context.h"
#include "xgboost/host_device_vector.h"
#include "xgboost/learner.h"
#include "xgboost/linalg.h" // Vector
#include "xgboost/logging.h"
#include "xgboost/string_view.h"
#include "xgboost/version_config.h"
#include <dmlc/registry.h> // for DMLC_REGISTRY_ENABLE, DMLC_REGISTRY_LINK_TAG

#include <algorithm> // for copy, max, none_of, min
#include <atomic> // for atomic
#include <cmath> // for abs
#include <cstdint> // for uint64_t, int32_t, uint8_t, uint32_t
#include <cstring> // for size_t, strcmp, memcpy
#include <exception> // for exception
#include <iostream> // for operator<<, basic_ostream, basic_ostream::op...
#include <map> // for map, operator!=
#include <numeric> // for accumulate, partial_sum
#include <tuple> // for get, apply
#include <type_traits> // for remove_pointer_t, remove_reference

#include "../collective/communicator-inl.h" // for GetRank, GetWorldSize, Allreduce, IsFederated
#include "../collective/communicator.h" // for Operation
#include "../common/algorithm.h" // for StableSort
#include "../common/api_entry.h" // for XGBAPIThreadLocalEntry
#include "../common/common.h" // for Split
#include "../common/error_msg.h" // for GroupSize, GroupWeight, InfInData
#include "../common/group_data.h" // for ParallelGroupBuilder
#include "../common/io.h" // for PeekableInStream
#include "../common/linalg_op.h" // for ElementWiseTransformHost
#include "../common/math.h" // for CheckNAN
#include "../common/numeric.h" // for Iota, RunLengthEncode
#include "../common/threading_utils.h" // for ParallelFor
#include "../common/version.h" // for Version
#include "../data/adapter.h" // for COOTuple, FileAdapter, IsValidFunctor
#include "../data/iterative_dmatrix.h" // for IterativeDMatrix
#include "./sparse_page_dmatrix.h" // for SparsePageDMatrix
#include "array_interface.h" // for ArrayInterfaceHandler, ArrayInterface, Dispa...
#include "dmlc/base.h" // for BeginPtr
#include "dmlc/common.h" // for OMPException
#include "dmlc/data.h" // for Parser
#include "dmlc/endian.h" // for ByteSwap, DMLC_IO_NO_ENDIAN_SWAP
#include "dmlc/io.h" // for Stream
#include "dmlc/thread_local.h" // for ThreadLocalStore
#include "ellpack_page.h" // for EllpackPage
#include "file_iterator.h" // for ValidateFileFormat, FileIterator, Next, Reset
#include "gradient_index.h" // for GHistIndexMatrix
#include "simple_dmatrix.h" // for SimpleDMatrix
#include "sparse_page_writer.h" // for SparsePageFormatReg
#include "validation.h" // for LabelsCheck, WeightsCheck, ValidateQueryGroup
#include "xgboost/base.h" // for bst_group_t, bst_row_t, bst_float, bst_ulong
#include "xgboost/context.h" // for Context
#include "xgboost/host_device_vector.h" // for HostDeviceVector
#include "xgboost/learner.h" // for HostDeviceVector
#include "xgboost/linalg.h" // for Tensor, Stack, TensorView, Vector, ArrayInte...
#include "xgboost/logging.h" // for Error, LogCheck_EQ, CHECK, CHECK_EQ, LOG
#include "xgboost/span.h" // for Span, operator!=, SpanIterator
#include "xgboost/string_view.h" // for operator==, operator<<, StringView

namespace dmlc {
DMLC_REGISTRY_ENABLE(::xgboost::data::SparsePageFormatReg<::xgboost::SparsePage>);
Expand Down Expand Up @@ -811,29 +826,29 @@ DMatrix::~DMatrix() {
}
}

DMatrix *TryLoadBinary(std::string fname, bool silent) {
int magic;
std::unique_ptr<dmlc::Stream> fi(
dmlc::Stream::Create(fname.c_str(), "r", true));
namespace {
DMatrix* TryLoadBinary(std::string fname, bool silent) {
std::int32_t magic;
std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(fname.c_str(), "r", true));
if (fi != nullptr) {
common::PeekableInStream is(fi.get());
if (is.PeekRead(&magic, sizeof(magic)) == sizeof(magic)) {
if (!DMLC_IO_NO_ENDIAN_SWAP) {
dmlc::ByteSwap(&magic, sizeof(magic), 1);
}
if (magic == data::SimpleDMatrix::kMagic) {
DMatrix *dmat = new data::SimpleDMatrix(&is);
DMatrix* dmat = new data::SimpleDMatrix(&is);
if (!silent) {
LOG(CONSOLE) << dmat->Info().num_row_ << 'x' << dmat->Info().num_col_
<< " matrix with " << dmat->Info().num_nonzero_
<< " entries loaded from " << fname;
LOG(CONSOLE) << dmat->Info().num_row_ << 'x' << dmat->Info().num_col_ << " matrix with "
<< dmat->Info().num_nonzero_ << " entries loaded from " << fname;
}
return dmat;
}
}
}
return nullptr;
}
} // namespace

DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_split_mode) {
auto need_split = false;
Expand All @@ -845,7 +860,7 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_s
}

std::string fname, cache_file;
size_t dlm_pos = uri.find('#');
auto dlm_pos = uri.find('#');
if (dlm_pos != std::string::npos) {
cache_file = uri.substr(dlm_pos + 1, uri.length());
fname = uri.substr(0, dlm_pos);
Expand All @@ -857,14 +872,11 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_s
for (size_t i = 0; i < cache_shards.size(); ++i) {
size_t pos = cache_shards[i].rfind('.');
if (pos == std::string::npos) {
os << cache_shards[i]
<< ".r" << collective::GetRank()
<< "-" << collective::GetWorldSize();
os << cache_shards[i] << ".r" << collective::GetRank() << "-"
<< collective::GetWorldSize();
} else {
os << cache_shards[i].substr(0, pos)
<< ".r" << collective::GetRank()
<< "-" << collective::GetWorldSize()
<< cache_shards[i].substr(pos, cache_shards[i].length());
os << cache_shards[i].substr(0, pos) << ".r" << collective::GetRank() << "-"
<< collective::GetWorldSize() << cache_shards[i].substr(pos, cache_shards[i].length());
}
if (i + 1 != cache_shards.size()) {
os << ':';
Expand Down Expand Up @@ -895,12 +907,12 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_s
LOG(CONSOLE) << "Load part of data " << partid << " of " << npart << " parts";
}

data::ValidateFileFormat(fname);
DMatrix* dmat {nullptr};
DMatrix* dmat{nullptr};

if (cache_file.empty()) {
std::unique_ptr<dmlc::Parser<uint32_t>> parser(
dmlc::Parser<uint32_t>::Create(fname.c_str(), partid, npart, "auto"));
fname = data::ValidateFileFormat(fname);
std::unique_ptr<dmlc::Parser<std::uint32_t>> parser(
dmlc::Parser<std::uint32_t>::Create(fname.c_str(), partid, npart, "auto"));
data::FileAdapter adapter(parser.get());
dmat = DMatrix::Create(&adapter, std::numeric_limits<float>::quiet_NaN(), Context{}.Threads(),
cache_file, data_split_mode);
Expand Down
51 changes: 51 additions & 0 deletions src/data/file_iterator.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/**
* Copyright 2021-2023, XGBoost contributors
*/
#include "file_iterator.h"

#include <xgboost/logging.h> // for LogCheck_EQ, LogCheck_LE, CHECK_EQ, CHECK_LE, LOG, LOG_...

#include <filesystem> // for weakly_canonical, path, u8path
#include <map> // for map, operator==
#include <ostream> // for operator<<, basic_ostream, istringstream
#include <vector> // for vector

#include "../common/common.h" // for Split
#include "xgboost/string_view.h" // for operator<<, StringView

namespace xgboost::data {
/**
 * @brief Validate a text-data URI and normalize the file path it contains.
 *
 * The accepted shape is `path?key=value[&key=value...][#cache]`. The URI must carry
 * exactly one `?` section, every argument in it must be a non-empty `key=value` pair,
 * and one of the keys must be `format`. At most one `#` (cache file suffix) is allowed.
 *
 * @param uri The URI supplied by the user.
 *
 * @return The URI with its path component replaced by
 *         `std::filesystem::weakly_canonical` of that path; the argument list and the
 *         optional `#cache` suffix are appended back unchanged.
 *
 * Any violation is reported through CHECK / LOG(FATAL).
 */
std::string ValidateFileFormat(std::string const& uri) {
  StringView msg{"URI parameter `format` is required for loading text data: filename?format=csv"};

  std::vector<std::string> name_args_cache = common::Split(uri, '#');
  // Guard the `[0]` access below. NOTE(review): presumably Split yields no element for
  // an empty URI -- confirm against common::Split.
  CHECK(!name_args_cache.empty()) << msg;
  CHECK_LE(name_args_cache.size(), 2)
      << "Only one `#` is allowed in file path for cachefile specification";

  std::vector<std::string> name_args = common::Split(name_args_cache[0], '?');
  // Report a surplus `?` with its own message instead of the "format is required" one.
  CHECK_LE(name_args.size(), 2) << "only one `?` is allowed in file path.";
  CHECK_EQ(name_args.size(), 2) << msg;

  // Every argument must be a `key=value` pair; remember whether `format` is among them.
  bool has_format = false;
  std::vector<std::string> arg_list = common::Split(name_args[1], '&');
  for (size_t i = 0; i < arg_list.size(); ++i) {
    std::istringstream is(arg_list[i]);
    std::string key;
    std::string value;
    CHECK(std::getline(is, key, '=')) << "Invalid uri argument format"
                                      << " for key in arg " << i + 1;
    CHECK(std::getline(is, value)) << "Invalid uri argument format"
                                   << " for value in arg " << i + 1;
    if (key == "format") {
      has_format = true;
    }
  }
  if (!has_format) {
    LOG(FATAL) << msg;
  }

  // `name_args[0]` already holds everything before the `?`; there is no need to split
  // the URI a second time.
  // NOTE(review): weakly_canonical returns the path in normal form, which would collapse
  // the `//` of a scheme-prefixed URI (e.g. `s3://bucket/key`). Confirm whether remote
  // URIs can reach this function.
  namespace fs = std::filesystem;
  name_args[0] = fs::weakly_canonical(fs::u8path(name_args[0])).string();

  std::string normalized = name_args[0] + "?" + name_args[1];
  if (name_args_cache.size() == 2) {
    // Re-attach the cache file specification.
    normalized += '#';
    normalized += name_args_cache[1];
  }
  return normalized;
}
} // namespace xgboost::data
60 changes: 16 additions & 44 deletions src/data/file_iterator.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,46 +4,20 @@
#ifndef XGBOOST_DATA_FILE_ITERATOR_H_
#define XGBOOST_DATA_FILE_ITERATOR_H_

#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "array_interface.h"
#include "dmlc/data.h"
#include "xgboost/c_api.h"
#include "xgboost/json.h"
#include "xgboost/linalg.h"

namespace xgboost {
namespace data {
/**
 * @brief Check that a text-data URI has the form `path?key=value[&...][#cache]` and
 *        names a `format` argument.
 *
 * Failures are reported through CHECK / LOG(FATAL); nothing is returned.
 *
 * @param uri The URI supplied by the user.
 */
inline void ValidateFileFormat(std::string const& uri) {
  // At most one `#`, which separates the optional cache file specification.
  std::vector<std::string> cache_split = common::Split(uri, '#');
  CHECK_LE(cache_split.size(), 2)
      << "Only one `#` is allowed in file path for cachefile specification";

  // Exactly one `?`, which separates the path from the argument list.
  std::vector<std::string> query_split = common::Split(cache_split[0], '?');
  CHECK_LE(query_split.size(), 2) << "only one `?` is allowed in file path.";

  StringView msg{"URI parameter `format` is required for loading text data: filename?format=csv"};
  CHECK_EQ(query_split.size(), 2) << msg;

  // Parse each `key=value` argument; the first occurrence of a key is the one kept.
  std::map<std::string, std::string> parsed;
  std::vector<std::string> tokens = common::Split(query_split[1], '&');
  size_t ordinal = 0;
  for (std::string const& token : tokens) {
    ++ordinal;  // 1-based position used in the diagnostics
    std::istringstream reader(token);
    std::string key;
    std::string value;
    CHECK(std::getline(reader, key, '=')) << "Invalid uri argument format"
                                          << " for key in arg " << ordinal;
    CHECK(std::getline(reader, value)) << "Invalid uri argument format"
                                       << " for value in arg " << ordinal;
    parsed.emplace(key, value);
  }

  if (parsed.count("format") == 0) {
    LOG(FATAL) << msg;
  }
}
#include <algorithm> // for max_element
#include <cstddef> // for size_t
#include <cstdint> // for uint32_t
#include <memory> // for unique_ptr
#include <string> // for string
#include <utility> // for move

#include "dmlc/data.h" // for RowBlock, Parser
#include "xgboost/c_api.h" // for XGDMatrixSetDenseInfo, XGDMatrixFree, XGProxyDMatrixCreate
#include "xgboost/linalg.h" // for ArrayInterfaceStr, MakeVec
#include "xgboost/logging.h" // for CHECK

namespace xgboost::data {
[[nodiscard]] std::string ValidateFileFormat(std::string const& uri);

/**
* An iterator for implementing external memory support with file inputs. Users of
Expand Down Expand Up @@ -72,8 +46,7 @@ class FileIterator {

public:
FileIterator(std::string uri, unsigned part_index, unsigned num_parts)
: uri_{std::move(uri)}, part_idx_{part_index}, n_parts_{num_parts} {
ValidateFileFormat(uri_);
: uri_{ValidateFileFormat(std::move(uri))}, part_idx_{part_index}, n_parts_{num_parts} {
XGProxyDMatrixCreate(&proxy_);
}
~FileIterator() {
Expand Down Expand Up @@ -132,6 +105,5 @@ inline int Next(DataIterHandle self) {
return static_cast<FileIterator*>(self)->Next();
}
} // namespace fileiter
} // namespace data
} // namespace xgboost
} // namespace xgboost::data
#endif // XGBOOST_DATA_FILE_ITERATOR_H_

0 comments on commit bb56183

Please sign in to comment.