-
Notifications
You must be signed in to change notification settings - Fork 3.8k
/
prediction_early_stop.cpp
91 lines (74 loc) · 2.49 KB
/
prediction_early_stop.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
/*!
* Copyright (c) 2017 Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See LICENSE file in the project root for license information.
*/
#include <LightGBM/prediction_early_stop.h>
#include <LightGBM/utils/log.h>
#include <limits>
#include <algorithm>
#include <cmath>
#include <vector>
namespace LightGBM {
/*!
 * \brief Build an early-stopping instance that never requests a stop.
 * \param config Not used; the parameter exists so this factory can be called
 *        exactly like the other Create* factories by the dispatcher.
 * \return Instance whose callback always returns false, paired with
 *         std::numeric_limits<int>::max() as its period.
 */
PredictionEarlyStopInstance CreateNone(const PredictionEarlyStopConfig& /*config*/) {
  // Using the largest representable period means the callback is almost never consulted.
  const int largest_period = std::numeric_limits<int>::max();
  auto never_stop = [](const double* /*pred*/, int /*sz*/) {
    return false;
  };
  return PredictionEarlyStopInstance{never_stop, largest_period};
}
/*!
 * \brief Build an early-stopping instance for multiclass prediction.
 *
 * The callback receives one score per class. It returns true when the gap
 * between the largest and second-largest scores is strictly greater than
 * config.margin_threshold.
 * \param config Supplies margin_threshold (captured) and round_period.
 * \return Instance whose callback calls Log::Fatal when given fewer than two scores.
 */
PredictionEarlyStopInstance CreateMulticlass(const PredictionEarlyStopConfig& config) {
  // Captured by value so the callback does not depend on config outliving it.
  const double margin_threshold = config.margin_threshold;
  return PredictionEarlyStopInstance{
    [margin_threshold](const double* pred, int sz) {
      if (sz < 2) {
        Log::Fatal("Multiclass early stopping needs predictions to be of length two or larger");
      }
      // Find the two largest scores in one pass. This replaces copying into a
      // temporary vector + partial_sort, so no heap allocation happens per call.
      // Ties are kept: e.g. {5, 1, 5} yields top1 == top2 == 5, margin 0.
      double top1 = std::max(pred[0], pred[1]);
      double top2 = std::min(pred[0], pred[1]);
      for (int i = 2; i < sz; ++i) {
        const double p = pred[i];
        if (p > top1) {
          top2 = top1;
          top1 = p;
        } else if (p > top2) {
          top2 = p;
        }
      }
      const double margin = top1 - top2;
      return margin > margin_threshold;
    },
    config.round_period
  };
}
/*!
 * \brief Build an early-stopping instance for binary prediction.
 *
 * The callback receives exactly one raw score s. It returns true when
 * 2 * |s| is strictly greater than config.margin_threshold.
 * \param config Supplies margin_threshold (captured) and round_period.
 * \return Instance whose callback calls Log::Fatal unless exactly one score is given.
 */
PredictionEarlyStopInstance CreateBinary(const PredictionEarlyStopConfig& config) {
  // Captured by value so the callback does not depend on config outliving it.
  const double margin_threshold = config.margin_threshold;
  return PredictionEarlyStopInstance{
    [margin_threshold](const double* pred, int sz) {
      if (sz != 1) {
        Log::Fatal("Binary early stopping needs predictions to be of length one");
      }
      // NOTE(review): the factor 2 presumably treats s as the class-score pair
      // (s, -s), whose gap is 2|s|, to mirror the multiclass top-two gap — confirm.
      // std::fabs is used because <cmath> only guarantees the std:: name.
      const double margin = 2.0 * std::fabs(pred[0]);
      return margin > margin_threshold;
    },
    config.round_period
  };
}
/*!
 * \brief Create an early-stopping instance by type name.
 * \param type "none", "multiclass" or "binary"; matched in that order.
 * \param config Passed unchanged to the selected factory.
 * \return The instance built by the matching factory. Any other name is
 *         reported through Log::Fatal.
 */
PredictionEarlyStopInstance CreatePredictionEarlyStopInstance(const std::string& type,
                                                              const PredictionEarlyStopConfig& config) {
  using Factory = PredictionEarlyStopInstance (*)(const PredictionEarlyStopConfig&);
  struct NamedFactory {
    const char* name;
    Factory make;
  };
  static const NamedFactory kFactories[] = {
    {"none", CreateNone},
    {"multiclass", CreateMulticlass},
    {"binary", CreateBinary},
  };
  for (const auto& entry : kFactories) {
    if (type == entry.name) {
      return entry.make(config);
    }
  }
  Log::Fatal("Unknown early stopping type: %s", type.c_str());
  // Only reached if Log::Fatal returns. It keeps compilers from warning that
  // control reaches the end of a non-void function.
  return CreateNone(config);
}
} // namespace LightGBM