Skip to content

Commit

Permalink
[option] precise_float_parser: precise float number parsing for text …
Browse files Browse the repository at this point in the history
…input.
  • Loading branch information
cyfdecyf committed Apr 15, 2021
1 parent 2efde4a commit 724872b
Show file tree
Hide file tree
Showing 10 changed files with 48 additions and 34 deletions.
6 changes: 6 additions & 0 deletions docs/Parameters.rst
Original file line number Diff line number Diff line change
Expand Up @@ -820,6 +820,12 @@ Dataset Parameters

- **Note**: can be used only in CLI version; for language-specific packages you can use the correspondent function

- ``precise_float_parser`` :raw-html:`<a id="precise_float_parser" title="Permalink to this parameter" href="#precise_float_parser">&#x1F517;&#xFE0E;</a>`, default = ``false``, type = bool

- Use precise floating point number parsing for text parser (e.g. CSV, TSV, LibSVM input).

- **Note**: setting this to ``true`` may lead to much slower text parsing.

Predict Parameters
~~~~~~~~~~~~~~~~~~

Expand Down
5 changes: 5 additions & 0 deletions include/LightGBM/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -712,6 +712,11 @@ struct Config {
// desc = **Note**: can be used only in CLI version; for language-specific packages you can use the correspondent function
bool save_binary = false;

// [no-save]
// desc = Use precise floating point number parsing for text parser (e.g. CSV, TSV, LibSVM input).
// desc = **Note**: setting this to ``true`` may lead to much slower text parsing.
bool precise_float_parser = false;

#pragma endregion

#pragma region Predict Parameters
Expand Down
5 changes: 4 additions & 1 deletion include/LightGBM/dataset.h
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,8 @@ class Metadata {
/*! \brief Interface for Parser */
class Parser {
public:
typedef const char* (*AtofFunc)(const char* p, double* out);

/*! \brief virtual destructor */
virtual ~Parser() {}

Expand All @@ -271,9 +273,10 @@ class Parser {
* \param filename One Filename of data
* \param num_features Pass num_features of this data file if you know, <=0 means don't know
* \param label_idx index of label column
* \param precise_float_parser using precise floating point number parsing if true
* \return Object of parser
*/
static Parser* CreateParser(const char* filename, bool header, int num_features, int label_idx);
static Parser* CreateParser(const char* filename, bool header, int num_features, int label_idx, bool precise_float_parser);
};

/*! \brief The main class of data set,
Expand Down
6 changes: 4 additions & 2 deletions src/application/application.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,8 @@ void Application::Predict() {
if (config_.task == TaskType::KRefitTree) {
// create predictor
Predictor predictor(boosting_.get(), 0, -1, false, true, false, false, 1, 1);
predictor.Predict(config_.data.c_str(), config_.output_result.c_str(), config_.header, config_.predict_disable_shape_check);
predictor.Predict(config_.data.c_str(), config_.output_result.c_str(), config_.header, config_.predict_disable_shape_check,
config_.precise_float_parser);
TextReader<int> result_reader(config_.output_result.c_str(), false);
result_reader.ReadAllLines();
std::vector<std::vector<int>> pred_leaf(result_reader.Lines().size());
Expand Down Expand Up @@ -251,7 +252,8 @@ void Application::Predict() {
config_.pred_early_stop, config_.pred_early_stop_freq,
config_.pred_early_stop_margin);
predictor.Predict(config_.data.c_str(),
config_.output_result.c_str(), config_.header, config_.predict_disable_shape_check);
config_.output_result.c_str(), config_.header, config_.predict_disable_shape_check,
config_.precise_float_parser);
Log::Info("Finished prediction");
}
}
Expand Down
5 changes: 3 additions & 2 deletions src/application/predictor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,13 +160,14 @@ class Predictor {
* \param data_filename Filename of data
* \param result_filename Filename of output result
*/
void Predict(const char* data_filename, const char* result_filename, bool header, bool disable_shape_check) {
void Predict(const char* data_filename, const char* result_filename, bool header, bool disable_shape_check, bool precise_float_parser) {
auto writer = VirtualFileWriter::Make(result_filename);
if (!writer->Init()) {
Log::Fatal("Prediction results file %s cannot be created", result_filename);
}
auto label_idx = header ? -1 : boosting_->LabelIdx();
auto parser = std::unique_ptr<Parser>(Parser::CreateParser(data_filename, header, boosting_->MaxFeatureIdx() + 1, label_idx));
auto parser = std::unique_ptr<Parser>(Parser::CreateParser(data_filename, header, boosting_->MaxFeatureIdx() + 1, label_idx,
precise_float_parser));

if (parser == nullptr) {
Log::Fatal("Could not recognize the data format of data file %s", data_filename);
Expand Down
3 changes: 2 additions & 1 deletion src/c_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -709,7 +709,8 @@ class Booster {
Predictor predictor(boosting_.get(), start_iteration, num_iteration, is_raw_score, is_predict_leaf, predict_contrib,
config.pred_early_stop, config.pred_early_stop_freq, config.pred_early_stop_margin);
bool bool_data_has_header = data_has_header > 0 ? true : false;
predictor.Predict(data_filename, result_filename, bool_data_has_header, config.predict_disable_shape_check);
predictor.Predict(data_filename, result_filename, bool_data_has_header, config.predict_disable_shape_check,
config.precise_float_parser);
}

void GetPredictAt(int data_idx, double* out_result, int64_t* out_len) const {
Expand Down
3 changes: 3 additions & 0 deletions src/io/config_auto.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,7 @@ const std::unordered_set<std::string>& Config::parameter_set() {
"categorical_feature",
"forcedbins_filename",
"save_binary",
"precise_float_parser",
"start_iteration_predict",
"num_iteration_predict",
"predict_raw_score",
Expand Down Expand Up @@ -525,6 +526,8 @@ void Config::GetMembersFromString(const std::unordered_map<std::string, std::str

GetBool(params, "save_binary", &save_binary);

GetBool(params, "precise_float_parser", &precise_float_parser);

GetInt(params, "start_iteration_predict", &start_iteration_predict);

GetInt(params, "num_iteration_predict", &num_iteration_predict);
Expand Down
6 changes: 4 additions & 2 deletions src/io/dataset_loader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,8 @@ Dataset* DatasetLoader::LoadFromFile(const char* filename, int rank, int num_mac
auto bin_filename = CheckCanLoadFromBin(filename);
bool is_load_from_binary = false;
if (bin_filename.size() == 0) {
auto parser = std::unique_ptr<Parser>(Parser::CreateParser(filename, config_.header, 0, label_idx_));
auto parser = std::unique_ptr<Parser>(Parser::CreateParser(filename, config_.header, 0, label_idx_,
config_.precise_float_parser));
if (parser == nullptr) {
Log::Fatal("Could not recognize data format of %s", filename);
}
Expand Down Expand Up @@ -267,7 +268,8 @@ Dataset* DatasetLoader::LoadFromFileAlignWithOtherDataset(const char* filename,
}
auto bin_filename = CheckCanLoadFromBin(filename);
if (bin_filename.size() == 0) {
auto parser = std::unique_ptr<Parser>(Parser::CreateParser(filename, config_.header, 0, label_idx_));
auto parser = std::unique_ptr<Parser>(Parser::CreateParser(filename, config_.header, 0, label_idx_,
config_.precise_float_parser));
if (parser == nullptr) {
Log::Fatal("Could not recognize data format of %s", filename);
}
Expand Down
12 changes: 5 additions & 7 deletions src/io/parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,6 @@

#include <string>
#include <algorithm>
#include <fstream>
#include <functional>
#include <iostream>
#include <memory>

namespace LightGBM {
Expand Down Expand Up @@ -232,7 +229,7 @@ DataType GetDataType(const char* filename, bool header,
return type;
}

Parser* Parser::CreateParser(const char* filename, bool header, int num_features, int label_idx) {
Parser* Parser::CreateParser(const char* filename, bool header, int num_features, int label_idx, bool precise_float_parser) {
const int n_read_line = 32;
auto lines = ReadKLineFromFile(filename, header, n_read_line);
int num_col = 0;
Expand All @@ -242,15 +239,16 @@ Parser* Parser::CreateParser(const char* filename, bool header, int num_features
}
std::unique_ptr<Parser> ret;
int output_label_index = -1;
AtofFunc atof = precise_float_parser ? Common::AtofPrecise : Common::Atof;
if (type == DataType::LIBSVM) {
output_label_index = GetLabelIdxForLibsvm(lines[0], num_features, label_idx);
ret.reset(new LibSVMParser(output_label_index, num_col));
ret.reset(new LibSVMParser(output_label_index, num_col, atof));
} else if (type == DataType::TSV) {
output_label_index = GetLabelIdxForTSV(lines[0], num_features, label_idx);
ret.reset(new TSVParser(output_label_index, num_col));
ret.reset(new TSVParser(output_label_index, num_col, atof));
} else if (type == DataType::CSV) {
output_label_index = GetLabelIdxForCSV(lines[0], num_features, label_idx);
ret.reset(new CSVParser(output_label_index, num_col));
ret.reset(new CSVParser(output_label_index, num_col, atof));
}

if (output_label_index < 0 && label_idx >= 0) {
Expand Down
31 changes: 12 additions & 19 deletions src/io/parser.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,10 @@

namespace LightGBM {

#ifdef USE_PRECISE_TEXT_PARSER
static const char* TextParserAtof(const char* p, double* out) {
return Common::AtofPrecise(p, out);
}
#else
static const char* TextParserAtof(const char* p, double* out) {
return Common::Atof(p, out);
}
#endif

class CSVParser: public Parser {
public:
explicit CSVParser(int label_idx, int total_columns)
:label_idx_(label_idx), total_columns_(total_columns) {
explicit CSVParser(int label_idx, int total_columns, AtofFunc atof)
:label_idx_(label_idx), total_columns_(total_columns), atof_(atof) {
}
inline void ParseOneLine(const char* str,
std::vector<std::pair<int, double>>* out_features, double* out_label) const override {
Expand All @@ -37,7 +27,7 @@ class CSVParser: public Parser {
int offset = 0;
*out_label = 0.0f;
while (*str != '\0') {
str = TextParserAtof(str, &val);
str = atof_(str, &val);
if (idx == label_idx_) {
*out_label = val;
offset = -1;
Expand All @@ -60,20 +50,21 @@ class CSVParser: public Parser {
private:
int label_idx_ = 0;
int total_columns_ = -1;
AtofFunc atof_;
};

class TSVParser: public Parser {
public:
explicit TSVParser(int label_idx, int total_columns)
:label_idx_(label_idx), total_columns_(total_columns) {
explicit TSVParser(int label_idx, int total_columns, AtofFunc atof)
:label_idx_(label_idx), total_columns_(total_columns), atof_(atof) {
}
inline void ParseOneLine(const char* str,
std::vector<std::pair<int, double>>* out_features, double* out_label) const override {
int idx = 0;
double val = 0.0f;
int offset = 0;
while (*str != '\0') {
str = TextParserAtof(str, &val);
str = atof_(str, &val);
if (idx == label_idx_) {
*out_label = val;
offset = -1;
Expand All @@ -96,12 +87,13 @@ class TSVParser: public Parser {
private:
int label_idx_ = 0;
int total_columns_ = -1;
AtofFunc atof_;
};

class LibSVMParser: public Parser {
public:
explicit LibSVMParser(int label_idx, int total_columns)
:label_idx_(label_idx), total_columns_(total_columns) {
explicit LibSVMParser(int label_idx, int total_columns, AtofFunc atof)
:label_idx_(label_idx), total_columns_(total_columns), atof_(atof) {
if (label_idx > 0) {
Log::Fatal("Label should be the first column in a LibSVM file");
}
Expand All @@ -111,7 +103,7 @@ class LibSVMParser: public Parser {
int idx = 0;
double val = 0.0f;
if (label_idx_ == 0) {
str = TextParserAtof(str, &val);
str = atof_(str, &val);
*out_label = val;
str = Common::SkipSpaceAndTab(str);
}
Expand All @@ -136,6 +128,7 @@ class LibSVMParser: public Parser {
private:
int label_idx_ = 0;
int total_columns_ = -1;
AtofFunc atof_;
};

} // namespace LightGBM
Expand Down

0 comments on commit 724872b

Please sign in to comment.