[REVIEW] Avoid unnecessary split for degenerate case where all labels are identical #3243

Merged · 5 commits · Dec 4, 2020

Showing changes from all commits.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -99,6 +99,7 @@
 - PR #3214: Correct flaky silhouette score test by setting atol
 - PR #3216: Ignore splits that do not satisfy constraints
 - PR #3239: Fix intermittent dask random forest failure
+- PR #3243: Avoid unnecessary split for degenerate case where all labels are identical
 - PR #3245: Rename `rows_sample` -> `max_samples` to be consistent with sklearn's RF

 # cuML 0.16.0 (23 Oct 2020)
11 changes: 4 additions & 7 deletions cpp/src/decisiontree/batched-levelalgo/builder_base.cuh
@@ -165,7 +165,7 @@ struct Builder {
     maxNodes = 8191;
   }

-  if (isRegression() && params.split_criterion == CRITERION::MAE) {
+  if (isRegression()) {
     dim3 grid(n_blks_for_rows, n_col_blks, max_batch);
     block_sync_size = MLCommon::GridSync::computeWorkspaceSize(
       grid, MLCommon::SyncType::ACROSS_X, false);
@@ -469,12 +469,9 @@ struct RegTraits {

   CUDA_CHECK(
     cudaMemsetAsync(b.pred, 0, sizeof(DataT) * b.nPredCounts * 2, s));
-  if (splitType == CRITERION::MAE) {
-    CUDA_CHECK(
-      cudaMemsetAsync(b.pred2, 0, sizeof(DataT) * b.nPredCounts * 2, s));
-    CUDA_CHECK(
-      cudaMemsetAsync(b.pred2P, 0, sizeof(DataT) * b.nPredCounts, s));
-  }
+  CUDA_CHECK(
+    cudaMemsetAsync(b.pred2, 0, sizeof(DataT) * b.nPredCounts * 2, s));
+  CUDA_CHECK(cudaMemsetAsync(b.pred2P, 0, sizeof(DataT) * b.nPredCounts, s));
   CUDA_CHECK(
     cudaMemsetAsync(b.pred_count, 0, sizeof(IdxT) * b.nPredCounts, s));
   computeSplitRegressionKernel<DataT, DataT, IdxT, TPB_DEFAULT>
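With the MAE special case removed, every regression build now sizes the grid-sync workspace and the split kernel always makes two passes over the data. For readers unfamiliar with that pattern, here is a minimal, self-contained CUDA sketch of a grid-wide barrier between two passes. It uses cooperative groups rather than cuML's internal `MLCommon::GridSync` helper, and all names and sizes in it are illustrative, not taken from this PR:

```cpp
// Illustrative sketch (not the PR's code): pass 1 accumulates a global sum,
// a grid-wide barrier makes it visible to every block, and pass 2 uses the
// resulting mean. grid.sync() is only legal with a cooperative launch; for
// brevity this assumes 64 blocks can be co-resident on the device.
#include <cooperative_groups.h>
#include <cstdio>
namespace cg = cooperative_groups;

__global__ void twoPassVariance(const float* data, int n, float* sum,
                                float* ssd) {
  cg::grid_group grid = cg::this_grid();
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int stride = gridDim.x * blockDim.x;
  for (int i = tid; i < n; i += stride) atomicAdd(sum, data[i]);
  grid.sync();  // wait for ALL blocks, not just this block's threads
  float mean = *sum / n;
  for (int i = tid; i < n; i += stride) {
    float d = data[i] - mean;
    atomicAdd(ssd, d * d);  // sum of squared deviations from the mean
  }
}

int main() {
  int n = 1 << 20;
  float *data, *sum, *ssd;
  cudaMallocManaged(&data, n * sizeof(float));
  cudaMallocManaged(&sum, sizeof(float));
  cudaMallocManaged(&ssd, sizeof(float));
  for (int i = 0; i < n; ++i) data[i] = 1.0f;  // constant data
  *sum = 0.f; *ssd = 0.f;
  void* args[] = {&data, &n, &sum, &ssd};
  cudaLaunchCooperativeKernel((void*)twoPassVariance, dim3(64), dim3(256),
                              args, 0, nullptr);
  cudaDeviceSynchronize();
  printf("variance = %f\n", *ssd / n);  // 0 for constant data
  return 0;
}
```

The point mirrored from the kernel below: pass 2 may only read the per-bin sums once every block has finished pass 1, which a block-local `__syncthreads()` cannot guarantee.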
122 changes: 61 additions & 61 deletions cpp/src/decisiontree/batched-levelalgo/kernels.cuh
@@ -358,8 +358,6 @@ __global__ void computeSplitRegressionKernel(
   auto* spred = alignPointer<DataT>(smem);
   auto* scount = alignPointer<int>(spred + len);
   auto* sbins = alignPointer<DataT>(scount + nbins);
-
-  // used only for MAE criterion
   auto* spred2 = alignPointer<DataT>(sbins + nbins);
   auto* spred2P = alignPointer<DataT>(spred2 + len);
   auto* spredP = alignPointer<DataT>(spred2P + nbins);
@@ -372,8 +370,6 @@
   }
   for (IdxT i = threadIdx.x; i < nbins; i += blockDim.x) {
     scount[i] = 0;
-    // printf("indexing from sbins: %p to %p, sizeof: %d (spred: %p)\n", sbins,
-    //        &sbins[i], (int)sizeof(DataT*), spred);
     sbins[i] = input.quantiles[col * nbins + i];
   }
   __syncthreads();
@@ -403,35 +399,37 @@
   }
   __threadfence();  // for commit guarantee
   __syncthreads();
-  // for MAE computation, we'd need a 2nd pass over data :(
+
+  /* Make a second pass over the data to compute gain */
+  // Wait until all blockIdx.x's are done
+  MLCommon::GridSync gs(workspace, MLCommon::SyncType::ACROSS_X, false);
+  gs.sync();
+  // now, compute the mean value to be used for metric update
+  for (IdxT i = threadIdx.x; i < nbins; i += blockDim.x) {
+    scount[i] = count[gcOffset + i];

A review thread is attached to the line above:

> **Contributor:** In the next revision, I would love to have these variable names be a bit more explicit. They all look similar and it was a bit hard to parse the formulas.
>
> **Contributor (author):** I agree. I had to read the code very carefully to deduce their meaning.
>
> **Member:** Agree with John. When this was written initially, readability was not kept as a priority item! @hcho3 can you please file an issue so that we don't forget about this?

+    spred2P[i] = DataT(0.0);
+  }
+  for (IdxT i = threadIdx.x; i < len; i += blockDim.x) {
+    spred[i] = pred[gOffset + i];
+    spred2[i] = DataT(0.0);
+  }
+  __syncthreads();
+  for (IdxT i = threadIdx.x; i < nbins; i += blockDim.x) {
+    spredP[i] = spred[i] + spred[i + nbins];
+  }
+  __syncthreads();
+  auto invlen = DataT(1.0) / range_len;
+  for (IdxT i = threadIdx.x; i < nbins; i += blockDim.x) {
+    auto cnt_l = DataT(scount[i]);
+    auto cnt_r = DataT(range_len - scount[i]);
+    spred[i] /= cnt_l;
+    spred[i + nbins] /= cnt_r;
+    spredP[i] *= invlen;
+  }
+  __syncthreads();
+
+  // 2nd pass over data to compute partial metric across blockIdx.x's
   if (splitType == CRITERION::MAE) {
-    // wait until all blockIdx.x's are done
-    MLCommon::GridSync gs(workspace, MLCommon::SyncType::ACROSS_X, false);
-    gs.sync();
-    // now, compute the mean value to be used for MAE update
-    for (IdxT i = threadIdx.x; i < nbins; i += blockDim.x) {
-      scount[i] = count[gcOffset + i];
-      spred2P[i] = DataT(0.0);
-    }
-    for (IdxT i = threadIdx.x; i < len; i += blockDim.x) {
-      spred[i] = pred[gOffset + i];
-      spred2[i] = DataT(0.0);
-    }
-    __syncthreads();
-    for (IdxT i = threadIdx.x; i < nbins; i += blockDim.x) {
-      spredP[i] = spred[i] + spred[i + nbins];
-    }
-    __syncthreads();
-    auto invlen = DataT(1.0) / range_len;
-    for (IdxT i = threadIdx.x; i < nbins; i += blockDim.x) {
-      auto cnt_l = DataT(scount[i]);
-      auto cnt_r = DataT(range_len - scount[i]);
-      spred[i] /= cnt_l;
-      spred[i + nbins] /= cnt_r;
-      spredP[i] *= invlen;
-    }
-    __syncthreads();
-    // 2nd pass over data to compute partial MAE's across blockIdx.x's
     for (auto i = range_start + tid; i < end; i += stride) {
       auto row = input.rowids[i];
       auto d = input.data[row + coloffset];
@@ -444,17 +442,31 @@
         atomicAdd(spred2P + b, raft::myAbs(label - spredP[b]));
       }
     }
-    __syncthreads();
-    // update the corresponding global location
-    for (IdxT i = threadIdx.x; i < nbins; i += blockDim.x) {
-      atomicAdd(pred2P + gcOffset + i, spred2P[i]);
-    }
-    for (IdxT i = threadIdx.x; i < len; i += blockDim.x) {
-      atomicAdd(pred2 + gOffset + i, spred2[i]);
+  } else {
+    for (auto i = range_start + tid; i < end; i += stride) {
+      auto row = input.rowids[i];
+      auto d = input.data[row + coloffset];
+      auto label = input.labels[row];
+      for (IdxT b = 0; b < nbins; ++b) {
+        auto isRight = d > sbins[b];  // no divergence
+        auto offset = isRight * nbins + b;
+        auto diff = label - (isRight ? spred[nbins + b] : spred[b]);
+        auto diff2 = label - spredP[b];
+        atomicAdd(spred2 + offset, (diff * diff));
+        atomicAdd(spred2P + b, (diff2 * diff2));
+      }
     }
-    __threadfence();  // for commit guarantee
-    __syncthreads();
   }
+  __syncthreads();
+  // update the corresponding global location
+  for (IdxT i = threadIdx.x; i < nbins; i += blockDim.x) {
+    atomicAdd(pred2P + gcOffset + i, spred2P[i]);
+  }
+  for (IdxT i = threadIdx.x; i < len; i += blockDim.x) {
+    atomicAdd(pred2 + gOffset + i, spred2[i]);
+  }
+  __threadfence();  // for commit guarantee
+  __syncthreads();
   // last threadblock will go ahead and compute the best split
   bool last = true;
   if (gridDim.x > 1) {
@@ -466,27 +478,15 @@
     // last block computes the final gain
     Split<DataT, IdxT> sp;
     sp.init();
-    if (splitType == CRITERION::MSE) {
-      for (IdxT i = threadIdx.x; i < len; i += blockDim.x) {
-        spred[i] = pred[gOffset + i];
-      }
-      for (IdxT i = threadIdx.x; i < nbins; i += blockDim.x) {
-        scount[i] = count[gcOffset + i];
-      }
-      __syncthreads();
-      mseGain(spred, scount, sbins, sp, col, range_len, nbins, min_samples_leaf,
-              min_impurity_decrease);
-    } else {
-      for (IdxT i = threadIdx.x; i < len; i += blockDim.x) {
-        spred2[i] = pred2[gOffset + i];
-      }
-      for (IdxT i = threadIdx.x; i < nbins; i += blockDim.x) {
-        spred2P[i] = pred2P[gcOffset + i];
-      }
-      __syncthreads();
-      maeGain(spred2, spred2P, scount, sbins, sp, col, range_len, nbins,
-              min_samples_leaf, min_impurity_decrease);
-    }
+    for (IdxT i = threadIdx.x; i < len; i += blockDim.x) {
+      spred2[i] = pred2[gOffset + i];
+    }
+    for (IdxT i = threadIdx.x; i < nbins; i += blockDim.x) {
+      spred2P[i] = pred2P[gcOffset + i];
+    }
+    __syncthreads();
+    regressionMetricGain(spred2, spred2P, scount, sbins, sp, col, range_len,
+                         nbins, min_samples_leaf, min_impurity_decrease);
     __syncthreads();
     sp.evalBestSplit(smem, splits + nid, mutex + nid);
   }
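The review thread above asks for more explicit variable names. As a reader's aid, here is one possible gloss of the kernel's shared-memory buffers, deduced from the diff itself; treat it as an interpretation, not as documentation from the authors:

```cpp
// Shared-memory buffers in computeSplitRegressionKernel, as I read them:
//   sbins[b]          - candidate split threshold (quantile) for bin b
//   scount[b]         - number of samples falling left of sbins[b]
//   spred[b],         - left-child prediction sum for bin b; after the
//   spred[nbins + b]    divisions by cnt_l / cnt_r, the left/right child means
//   spredP[b]         - parent prediction sum for bin b, scaled by 1/len
//                       into the parent mean
//   spred2[b],        - partial split metric for the left/right child:
//   spred2[nbins + b]   sum of squared (MSE) or absolute (MAE) residuals
//   spred2P[b]        - the same partial metric measured against the
//                       parent mean
```

Under that reading, the gain computed in `regressionMetricGain` below is simply the parent metric minus the two child metrics, scaled by `1/len`.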
68 changes: 6 additions & 62 deletions cpp/src/decisiontree/batched-levelalgo/metrics.cuh
@@ -185,62 +185,7 @@ DI void entropyGain(int* shist, DataT* sbins, Split<DataT, IdxT>& sp, IdxT col,
 }

 /**
- * @brief Compute gain based on MSE
- *
- * @param[in]    spred                  left/right child mean prediction for all
- *                                      bins [dim = 2 x bins]
- * @param[in]    scount                 left child count for all bins
- *                                      [len = nbins]
- * @param[in]    sbins                  quantiles for the current column
- *                                      [len = nbins]
- * @param[inout] sp                     will contain the per-thread best split
- *                                      so far
- * @param[in]    col                    current column
- * @param[in]    len                    total number of samples for the current
- *                                      node to be split
- * @param[in]    nbins                  number of bins
- * @param[in]    min_samples_leaf       minimum number of samples per each leaf.
- *                                      Any splits that lead to a leaf node with
- *                                      samples fewer than min_samples_leaf will
- *                                      be ignored.
- * @param[in]    min_impurity_decrease  minimum improvement in MSE metric. Any
- *                                      splits that do not improve (decrease)
- *                                      the MSE metric at least by this amount
- *                                      will be ignored.
- */
-template <typename DataT, typename IdxT>
-DI void mseGain(DataT* spred, IdxT* scount, DataT* sbins,
-                Split<DataT, IdxT>& sp, IdxT col, IdxT len, IdxT nbins,
-                IdxT min_samples_leaf, DataT min_impurity_decrease) {
-  auto invlen = DataT(1.0) / len;
-  for (IdxT i = threadIdx.x; i < nbins; i += blockDim.x) {
-    auto nLeft = scount[i];
-    auto nRight = len - nLeft;
-    DataT gain;
-    // if there aren't enough samples in this split, don't bother!
-    if (nLeft < min_samples_leaf || nRight < min_samples_leaf) {
-      gain = -NumericLimits<DataT>::kMax;
-    } else {
-      auto invLeft = DataT(1.0) / nLeft;
-      auto invRight = DataT(1.0) / nRight;
-      auto valL = spred[i];
-      auto valR = spred[nbins + i];
-      // parent sum is basically sum of its left and right children
-      auto valP = (valL + valR) * invlen;
-      gain = -valP * valP;
-      gain += valL * invlen * valL * invLeft;
-      gain += valR * invlen * valR * invRight;
-    }
-    // if the gain is not "enough", don't bother!
-    if (gain <= min_impurity_decrease) {
-      gain = -NumericLimits<DataT>::kMax;
-    }
-    sp.update({sbins[i], col, gain, nLeft});
-  }
-}
-
-/**
- * @brief Compute gain based on MAE
+ * @brief Compute gain based on MSE or MAE
  *
  * @param[in]    spred                  left/right child sum of abs diff of
  *                                      prediction for all bins [dim = 2 x bins]
@@ -266,9 +211,10 @@ DI void mseGain(DataT* spred, IdxT* scount, DataT* sbins,
  *                                      will be ignored.
  */
 template <typename DataT, typename IdxT>
-DI void maeGain(DataT* spred, DataT* spredP, IdxT* scount, DataT* sbins,
-                Split<DataT, IdxT>& sp, IdxT col, IdxT len, IdxT nbins,
-                IdxT min_samples_leaf, DataT min_impurity_decrease) {
+DI void regressionMetricGain(DataT* spred, DataT* spredP, IdxT* scount,
+                             DataT* sbins, Split<DataT, IdxT>& sp, IdxT col,
+                             IdxT len, IdxT nbins, IdxT min_samples_leaf,
+                             DataT min_impurity_decrease) {
   auto invlen = DataT(1.0) / len;
   for (IdxT i = threadIdx.x; i < nbins; i += blockDim.x) {
     auto nLeft = scount[i];
@@ -278,9 +224,7 @@ DI void maeGain(DataT* spred, DataT* spredP, IdxT* scount, DataT* sbins,
     if (nLeft < min_samples_leaf || nRight < min_samples_leaf) {
       gain = -NumericLimits<DataT>::kMax;
     } else {
-      gain = spredP[i];
-      gain -= spred[i];
-      gain -= spred[i + nbins];
+      gain = spredP[i] - spred[i] - spred[i + nbins];
       gain *= invlen;
     }
     // if the gain is not "enough", don't bother!
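A closing note on why the degenerate case is fixed, as I read the diff: when every label in a node is identical, the true gain of any split is exactly zero, and the `gain <= min_impurity_decrease` check should reject it. The removed closed-form `mseGain`, however, subtracted three large, nearly equal terms, so floating-point cancellation could leave a tiny positive residue that slipped past the check and produced a pointless split. The surviving residual-based form sums terms that are each approximately zero, leaving nothing to cancel. A standalone C++ sketch of the two computations (illustrative values, not code from the PR):

```cpp
// Illustration: split gain computed two ways for a node whose labels are all
// the same value. The closed form mirrors the removed mseGain formula; the
// residual form mirrors the new two-pass regressionMetricGain path. Exact
// printed values are platform-dependent, but only the closed form is prone
// to cancellation error.
#include <cstdio>

int main() {
  const int n = 1000, nLeft = 297, nRight = n - nLeft;
  const float label = 7.3f;  // every sample in the node has this label

  // Closed form (old mseGain): -meanP^2 + sumL^2/(n*nLeft) + sumR^2/(n*nRight)
  float sumL = nLeft * label, sumR = nRight * label;
  float meanP = (sumL + sumR) / n;
  float closedGain = -meanP * meanP + sumL / n * sumL / nLeft
                   + sumR / n * sumR / nRight;  // true value 0; round-off
                                                // may leave a small residue

  // Residual form (new path): (parentSSE - leftSSE - rightSSE) / n, where
  // every residual (label - mean) is itself ~0 before squaring.
  float meanL = sumL / nLeft, meanR = sumR / nRight;
  float sseL = 0.f, sseR = 0.f, sseP = 0.f;
  for (int i = 0; i < nLeft; ++i) sseL += (label - meanL) * (label - meanL);
  for (int i = 0; i < nRight; ++i) sseR += (label - meanR) * (label - meanR);
  for (int i = 0; i < n; ++i) sseP += (label - meanP) * (label - meanP);
  float residGain = (sseP - sseL - sseR) / n;

  printf("closed-form gain = %g, residual gain = %g\n", closedGain, residGain);
  // If the closed form rounds to a tiny positive number, it exceeds a
  // min_impurity_decrease of 0.0 and the pointless split is accepted;
  // the residual form stays at ~0 and is rejected.
  return 0;
}
```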