[BUG] Old and new backends diverge on a 10-row toy example #3188
Note. Make sure to apply the patch #3186 to obtain the gain values.
Observations so far:
```json
{
  "nodeid": 3,
  "split_feature": 1,
  "split_threshold": -8.30485821,
  "gain": 2.38418579e-07,
  "yes": 5,
  "no": 6,
  "children": [
    {
      "nodeid": 5,
      "leaf_value": 2
    },
    {
      "nodeid": 6,
      "leaf_value": 2
    }
  ]
}
```
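For scale: the reported gain 2.38418579e-07 is exactly 2**-22, i.e. two float32 ulps of 1.0 (np.finfo(np.float32).eps is 2**-23, about 1.19e-07), which already hints at a rounding artifact rather than a genuine impurity decrease. A quick check:

```python
import numpy as np

# The spurious split's "gain" is exactly two float32 ulps of 1.0.
assert np.float32(2.38418579e-07) == 2.0 ** -22
assert np.finfo(np.float32).eps == 2.0 ** -23
```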
I just found a silly mistake in my toy decision tree implementation and fixed it now.
Minimal reproducer for the gratuitous split when a node has rows with identical labels:

```python
from cuml.ensemble import RandomForestRegressor as rfr
import numpy as np
import json

X = np.array([[-1, 0], [0, 1], [2, 0], [0, 3], [-2, 0]], dtype=np.float32)
y = np.array([1, 1, 1, 1, 1], dtype=np.float32)
kwargs = {'max_features': 1.0, 'n_estimators': 1, 'max_depth': 1,
          'bootstrap': False, 'n_bins': 5}
for use_experimental in [True, False]:
    clf = rfr(use_experimental_backend=use_experimental, **kwargs)
    clf.fit(X, y)
    json_obj = json.loads(clf.dump_as_json())
    print(f'=========use_experimental={use_experimental}=========')
    print(json.dumps(json_obj, indent=4))
```

Output:
I found the reason why the new backend was adding the unnecessary split in the example #3188 (comment). It's because of the order in which three floating-point numbers are added together in the computation of the gain value:

```python
"""catastrophic_cancellation_in_gain.py"""
import numpy as np

X = np.array([[-1, 0], [0, 1], [2, 0], [0, 3], [-2, 0]], dtype=np.float32)
y = np.array([1, 1, 1, 1, 1], dtype=np.float32)
# Split on feature 1 at threshold 0.0: three rows go left, two go right
ind = (X[:, 1] <= 0.0)
y_left, y_right = y[ind], y[~ind]
invlen = np.float32(1.0) / np.float32(5.0)
invLeft = np.float32(1.0) / np.float32(3.0)
invRight = np.float32(1.0) / np.float32(2.0)
valL = np.sum(y_left)
valR = np.sum(y_right)
valP = (valL + valR) * invlen
# A, B, C are the three terms of the MSE gain; their exact sum is 0
A = -valP * valP
assert isinstance(A, np.float32)
print(f'A = {A}')
B = valL * invlen * valL * invLeft
assert isinstance(B, np.float32)
print(f'B = {B}')
C = valR * invlen * valR * invRight
assert isinstance(C, np.float32)
print(f'C = {C}')
gain = (A + B) + C
print('=' * 60)
print(f'A + B = {A + B}')
print(f'A + C = {A + C}')
print(f'B + C = {B + C}')
print('=' * 60)
print(f'(A + B) + C = {(A + B) + C}')
print(f'(A + C) + B = {(A + C) + B}')
print(f'(B + C) + A = {(B + C) + A}')
print('=' * 60)
print(f'gain = (A + B) + C = {gain}')
```

which prints:
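(The verbatim output was not preserved here; working through the float32 arithmetic by hand, the script should print the following, modulo minor shortest-repr formatting differences across NumPy versions.)

```
A = -1.0
B = 0.6
C = 0.4
============================================================
A + B = -0.39999998
A + C = -0.6
B + C = 1.0
============================================================
(A + B) + C = 2.9802322e-08
(A + C) + B = 0.0
(B + C) + A = 0.0
============================================================
gain = (A + B) + C = 2.9802322e-08
```

The true gain is 0, and two of the three summation orders do return exactly 0.0, but the order the kernel uses, (A + B) + C, leaves a one-ulp residue (2**-25, about 2.98e-08) after the large terms cancel. That spurious positive gain is enough to make the new backend commit to the split.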
This is probably a bit of a digression from the topic of this issue, but I feel like what I'm writing here might be helpful in (at least partly) fixing this issue in our new backend... @hcho3

```cpp
template <typename DataT, typename IdxT>
DI void mseGain(DataT* spred, IdxT* scount, DataT* sbins,
                Split<DataT, IdxT>& sp, IdxT col, IdxT len, IdxT nbins,
                IdxT min_samples_leaf, DataT min_val,
                DataT min_impurity_decrease) {
  auto invlen = DataT(1.0) / len;
  for (IdxT i = threadIdx.x; i < nbins; i += blockDim.x) {
    auto nLeft = scount[i];
    auto nRight = len - nLeft;
    DataT gain;
    // if there aren't enough samples in this split, don't bother!
    if (nLeft < min_samples_leaf || nRight < min_samples_leaf) {
      // IOW, -std::numeric_limits<DataT>::max()
      // (but this constant is not accessible from device methods!)
      gain = min_val;
    } else {
      auto invLeft = DataT(1.0) / nLeft;
      auto invRight = DataT(1.0) / nRight;
      auto valL = spred[i];
      auto valR = spred[nbins + i];
      // parent sum is basically sum of its left and right children
      auto valP = (valL + valR) * invlen;
      gain = -valP * valP;
      gain += valL * invlen * valL * invLeft;
      gain += valR * invlen * valR * invRight;
    }
    // if the gain is not "enough", don't bother!
    if (gain <= min_impurity_decrease) {
      gain = min_val;
    }
    sp.update({sbins[i], col, gain, nLeft});
  }
}
```

In short, if we find that certain splits are not going to satisfy the `min_samples_leaf` or `min_impurity_decrease` constraints, we can exclude those candidates outright instead of carrying them along with `gain = min_val`.
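A minimal Python sketch of that filtering idea (names and signature are illustrative, not cuML's actual API):

```python
def mse_gain_or_none(sum_left, sum_right, n_left, n_right,
                     min_samples_leaf, min_impurity_decrease):
    """Score one split candidate; return None if it is ineligible.

    Returning None instead of a sentinel gain (min_val) means ineligible
    candidates never compete for the best split at all.
    """
    n = n_left + n_right
    # Constraint 1: both children must have enough samples.
    if n_left < min_samples_leaf or n_right < min_samples_leaf:
        return None
    val_p = (sum_left + sum_right) / n
    gain = -val_p * val_p
    gain += sum_left / n * sum_left / n_left
    gain += sum_right / n * sum_right / n_right
    # Constraint 2: the gain must clear the minimum impurity decrease.
    if gain <= min_impurity_decrease:
        return None
    return gain

# Ineligible candidates (here: an empty right child) are filtered out
# before the argmax over candidates, instead of being scored as -FLT_MAX.
candidates = [(0.0, 4.0, 2, 2), (5.0, 0.0, 5, 0)]
best = None
for cand in candidates:
    g = mse_gain_or_none(*cand, min_samples_leaf=1, min_impurity_decrease=0.0)
    if g is not None and (best is None or g > best):
        best = g
print(best)  # 1.0
```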
I filed #3216 according to your suggestion. While it does not fix the issue in this particular toy example, I believe it should be merged on its own merit, namely ignoring ineligible split candidates.
…tical (#3243)

Closes #3231
Closes #3128
Partially addresses #3188

The degenerate case (labels all identical in a node) is now robustly handled by computing the MSE metric separately for each of the three nodes (the parent node, the left child node, and the right child node). Doing so ensures that the gain is 0 for the degenerate case. The degenerate case may occur in some real-world regression problems, e.g. house price data where the price label is rounded up to the nearest 100k. As a result, the MSE gain is now computed very similarly to the MAE gain. Disadvantage: we now always make two passes over the data to compute the gain.

cc @teju85 @vinaydes @JohnZed

Authors:
- Hyunsu Cho <[email protected]>
- Philip Hyunsu Cho <[email protected]>

Approvers:
- Thejaswi Rao
- John Zedlewski

URL: #3243
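A NumPy sketch of the approach described above (an illustration of the idea, not the actual CUDA kernel):

```python
import numpy as np

def robust_mse_gain(y, mask):
    """Gain from per-node MSEs: parent, left child, right child.

    Each MSE sums squared deviations from that node's own mean, so a node
    whose labels are all identical contributes exactly 0.0. No large terms
    are subtracted from one another, hence no catastrophic cancellation.
    """
    def mse(v):
        return ((v - v.mean()) ** 2).mean() if len(v) else 0.0

    n = len(y)
    y_left, y_right = y[mask], y[~mask]
    return (mse(y)
            - (len(y_left) / n) * mse(y_left)
            - (len(y_right) / n) * mse(y_right))

y = np.ones(5, dtype=np.float32)                   # degenerate: identical labels
mask = np.array([True, False, True, False, True])  # the split from the reproducer
assert robust_mse_gain(y, mask) == 0.0             # exactly zero: no spurious split
```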
Fixed in #3243.
The old and new backends of RF produce different models for a toy example (given below).
Minimal reproducer: this reproducer consists of a toy example with 10 rows and 2 columns, generated by `make_blobs`. The reproducer also includes a toy implementation of the Decision Tree algorithm:
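(The script itself is not reproduced in this thread; the sketch below is a stand-in with the same shape, an exhaustive-search regression tree using the MSE criterion, with all names illustrative.)

```python
import numpy as np

def mse(v):
    """Mean squared deviation of a node's labels from the node mean."""
    return float(((v - v.mean()) ** 2).mean()) if len(v) else 0.0

def build_tree(X, y, depth=0, max_depth=1):
    """Greedily grow a regression tree; split only on strictly positive gain."""
    node = {'n_samples': int(len(y)), 'leaf_value': float(y.mean())}
    if depth >= max_depth:
        return node
    best = None  # (gain, feature, threshold)
    for f in range(X.shape[1]):
        for t in np.unique(X[:, f])[:-1]:  # thresholds between distinct values
            left = X[:, f] <= t
            gain = (mse(y)
                    - left.mean() * mse(y[left])
                    - (1 - left.mean()) * mse(y[~left]))
            if best is None or gain > best[0]:
                best = (gain, f, float(t))
    if best is None or best[0] <= 0.0:  # no split improves MSE, so stay a leaf
        return node
    gain, f, t = best
    left = X[:, f] <= t
    node.update({'split_feature': f, 'split_threshold': t, 'gain': gain,
                 'children': [build_tree(X[left], y[left], depth + 1, max_depth),
                              build_tree(X[~left], y[~left], depth + 1, max_depth)]})
    return node
```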
Here are the results with varying values for `max_depth`:
- `max_depth=1`
- `max_depth=2`
- `max_depth=3`