Skip to content

Commit

Permalink
use text reader to load file content, and update model version
Browse files Browse the repository at this point in the history
  • Loading branch information
chjinche committed Nov 15, 2021
1 parent af70998 commit f622583
Show file tree
Hide file tree
Showing 11 changed files with 43 additions and 42 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ MLflow (experiment tracking, model monitoring framework): https://github.com/mlf

`{mlr3extralearners}` (R `{mlr3}`-compliant interface): https://github.com/mlr-org/mlr3extralearners

lightgbm-transform (transformation binding): https://github.com/microsoft/lightgbm-transform
lightgbm-transform (feature transformation binding): https://github.com/microsoft/lightgbm-transform

Support
-------
Expand Down
2 changes: 1 addition & 1 deletion docs/Parameters.rst
Original file line number Diff line number Diff line change
Expand Up @@ -856,7 +856,7 @@ Dataset Parameters

- see `lightgbm-transform <https://github.com/microsoft/lightgbm-transform>`__ for usage examples

- **Note**: ``lightgbm-transform`` is not maintained by LightGBM's maintainers. Bug reports or feature requests should go to `issue <https://github.com/microsoft/lightgbm-transform/issues>`__
- **Note**: ``lightgbm-transform`` is not maintained by LightGBM's maintainers. Bug reports or feature requests should go to `issue page <https://github.com/microsoft/lightgbm-transform/issues>`__

Predict Parameters
~~~~~~~~~~~~~~~~~~
Expand Down
2 changes: 1 addition & 1 deletion include/LightGBM/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -723,7 +723,7 @@ struct Config {

// desc = path to a ``.json`` file that specifies customized parser initialized configuration
// desc = see `lightgbm-transform <https://github.com/microsoft/lightgbm-transform>`__ for usage examples
// desc = **Note**: ``lightgbm-transform`` is not maintained by LightGBM's maintainers. Bug reports or feature requests should go to `issue <https://github.com/microsoft/lightgbm-transform/issues>`__
// desc = **Note**: ``lightgbm-transform`` is not maintained by LightGBM's maintainers. Bug reports or feature requests should go to `issue page <https://github.com/microsoft/lightgbm-transform/issues>`__
std::string parser_config_file = "";

#pragma endregion
Expand Down
16 changes: 0 additions & 16 deletions include/LightGBM/utils/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
#include <unordered_map>
#include <utility>
#include <vector>
#include <fstream>

#if (!((defined(sun) || defined(__sun)) && (defined(__SVR4) || defined(__svr4__))))
#define FMT_HEADER_ONLY
Expand Down Expand Up @@ -204,21 +203,6 @@ inline static std::vector<std::string> Split(const char* c_str, const char* deli
return ret;
}

/*!
 * \brief Read up to row_num lines of a text file and return them as one string
 * \param filename Path of the file to read; a null pointer or an empty string yields ""
 * \param row_num Maximum number of lines to read, INT_MAX by default
 * \return The lines that were read, each one followed by "\n"; "" when no line could be read
 */
inline static std::string LoadStringFromFile(const char* filename, int row_num = INT_MAX) {
  if (filename == nullptr || *filename == '\0') {
    return "";
  }
  std::stringstream ss;
  Common::C_stringstream(ss);
  std::ifstream fin(filename);
  std::string line;
  // Check the limit before reading: no line past row_num is consumed and then discarded,
  // and the counter can never be incremented beyond row_num (no signed overflow at INT_MAX).
  for (int i = 0; i < row_num && std::getline(fin, line); ++i) {
    ss << line << "\n";
  }
  return ss.str();
}

inline static std::string GetFromParserConfig(std::string config_str, std::string key) {
// parser config should follow json format.
std::string err;
Expand Down
11 changes: 11 additions & 0 deletions include/LightGBM/utils/text_reader.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,17 @@ class TextReader {
* \return Text data, store in std::vector by line
*/
inline std::vector<std::string>& Lines() { return lines_; }
/*!
 * \brief Get the text data held in lines_ joined into a single string
 * \return Text data as one std::string in which every line, including the last one, is followed by "\n"
 */
inline std::string JoinedLines() {
  std::stringstream ss;
  // Iterate by const reference; iterating by value copied every line before appending it.
  for (const auto& line : lines_) {
    ss << line << "\n";
  }
  return ss.str();
}

INDEX_T ReadAllAndProcess(const std::function<void(INDEX_T, const char*, size_t)>& process_fun) {
last_line_ = "";
Expand Down
1 change: 1 addition & 0 deletions src/boosting/gbdt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -732,6 +732,7 @@ void GBDT::ResetTrainingData(const Dataset* train_data, const ObjectiveFunction*
label_idx_ = train_data_->label_idx();
feature_names_ = train_data_->feature_names();
feature_infos_ = train_data_->feature_infos();
parser_config_str_ = train_data_->parser_config_str();

tree_learner_->ResetTrainingData(train_data, is_constant_hessian_);
ResetBaggingConfig(config_.get(), true);
Expand Down
2 changes: 1 addition & 1 deletion src/boosting/gbdt_model_text.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

namespace LightGBM {

const char* kModelVersion = "v3";
const char* kModelVersion = "v4";

std::string GBDT::DumpModel(int start_iteration, int num_iteration, int feature_importance_type) const {
std::stringstream str_buf;
Expand Down
16 changes: 10 additions & 6 deletions src/io/dataset_loader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,17 @@ void DatasetLoader::SetHeader(const char* filename) {
if (config_.header) {
std::string first_line = text_reader.first_line();
feature_names_ = Common::Split(first_line.c_str(), "\t,");
} else if (!Common::LoadStringFromFile(config_.parser_config_file.c_str()).empty()) {
} else if (!config_.parser_config_file.empty()) {
// support to get header from parser config, so could utilize following label name to id mapping logic.
std::string header_in_parser_config = Common::GetFromParserConfig(
Common::LoadStringFromFile(config_.parser_config_file.c_str()), "header");
if (!header_in_parser_config.empty()) {
Log::Info("Get raw column names from parser config.");
feature_names_ = Common::Split(header_in_parser_config.c_str(), "\t,");
TextReader<data_size_t> parser_config_reader(config_.parser_config_file.c_str(), false);
parser_config_reader.ReadAllLines();
std::string parser_config_str = parser_config_reader.JoinedLines();
if (!parser_config_str.empty()) {
std::string header_in_parser_config = Common::GetFromParserConfig(parser_config_str, "header");
if (!header_in_parser_config.empty()) {
Log::Info("Get raw column names from parser config.");
feature_names_ = Common::Split(header_in_parser_config.c_str(), "\t,");
}
}
}

Expand Down
7 changes: 5 additions & 2 deletions src/io/parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -298,11 +298,14 @@ Parser* Parser::CreateParser(const char* filename, bool header, int num_features
}

std::string Parser::GenerateParserConfigStr(const char* filename, const char* parser_config_filename, bool header, int label_idx) {
std::string parser_config_str = Common::LoadStringFromFile(parser_config_filename);
TextReader<data_size_t> parser_config_reader(parser_config_filename, false);
parser_config_reader.ReadAllLines();
std::string parser_config_str = parser_config_reader.JoinedLines();
if (!parser_config_str.empty()) {
// save header to parser config in case needed.
if (header && Common::GetFromParserConfig(parser_config_str, "header").empty()) {
parser_config_str = Common::SaveToParserConfig(parser_config_str, "header", Common::LoadStringFromFile(filename, 1));
TextReader<data_size_t> text_reader(filename, header);
parser_config_str = Common::SaveToParserConfig(parser_config_str, "header", text_reader.first_line());
}
// save label id to parser config in case needed.
if (Common::GetFromParserConfig(parser_config_str, "labelId").empty()) {
Expand Down
12 changes: 12 additions & 0 deletions tests/python_package_test/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,3 +557,15 @@ def test_init_score_for_multiclass_classification(init_score_type):
ds = lgb.Dataset(data, init_score=init_score).construct()
np.testing.assert_equal(ds.get_field('init_score'), init_score)
np.testing.assert_equal(ds.init_score, init_score)


def test_smoke_custom_parser(tmp_path):
    """Constructing a Dataset whose parser config names the class 'dummy' must raise LightGBMError."""
    base_dir = Path(__file__).absolute().parents[2]
    data_path = base_dir / 'examples' / 'binary_classification' / 'binary.train'

    # Write a minimal parser config that points at the parser class "dummy".
    parser_config_file = tmp_path / 'parser.ini'
    parser_config_file.write_text('{"className": "dummy", "id": "1"}')

    expected_error = "Cannot find parser class 'dummy', please register first or check config format"
    data = lgb.Dataset(data_path, params={"parser_config_file": parser_config_file})
    with pytest.raises(lgb.basic.LightGBMError, match=expected_error):
        data.construct()
14 changes: 0 additions & 14 deletions tests/python_package_test/test_utilities.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
# coding: utf-8
import logging
from pathlib import Path

import numpy as np
import pytest

import lightgbm as lgb

Expand Down Expand Up @@ -99,15 +97,3 @@ def dummy_metric(_, __):
actual_log_wo_gpu_stuff.append(line)

assert "\n".join(actual_log_wo_gpu_stuff) == expected_log


def test_smoke_custom_parser(tmp_path):
    """Training on a Dataset whose parser config names the class 'dummy' must raise LightGBMError."""
    tests_root = Path(__file__).absolute().parents[2]
    data_path = tests_root / 'examples/binary_classification/binary.train'

    # Write a minimal parser config that points at the parser class "dummy".
    parser_config_file = tmp_path / 'parser.ini'
    parser_config_file.write_text("{\"className\": \"dummy\", \"id\":\"1\"}")

    error_pattern = "Cannot find parser class 'dummy', please register first or check config format"
    data = lgb.Dataset(data_path, params={"parser_config_file": parser_config_file})
    with pytest.raises(lgb.basic.LightGBMError, match=error_pattern):
        lgb.train({}, data)

0 comments on commit f622583

Please sign in to comment.