Split FIL infer_k into phases to speed up compilation (when a patch is applied) #4148

Merged: 8 commits, Aug 11, 2021
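The compile-time saving comes from outlining the heavy phases of the kernel: marking a __device__ function __noinline__ makes the compiler emit it as a separate function, so ptxas optimizes its body once per instantiation of the phase instead of re-optimizing it inside every variant of infer_k that would otherwise inline it. A minimal sketch of the pattern, with hypothetical heavy_phase and small_kernel names standing in for the real FIL code:

// Sketch only: heavy_phase and small_kernel are illustrative names, not FIL code.
#define NOINLINE_FAST_COMPILE __noinline__

// A heavyweight phase kept out of the kernel body: __noinline__ makes the
// compiler emit it as a standalone device function, so its code is not
// duplicated and re-optimized inside every kernel variant that calls it.
template <int NITEMS>
__device__ NOINLINE_FAST_COMPILE void heavy_phase(float* data, int n)
{
  for (int i = threadIdx.x; i < n; i += blockDim.x)
    data[i] *= NITEMS;  // stand-in for the real per-phase work
}

template <int NITEMS>
__global__ void small_kernel(float* data, int n)
{
  heavy_phase<NITEMS>(data, n);  // kernel body stays small and cheap to compile
}

In the diff below, the same macro (NOINLINE_FAST_COMPILE) is applied to the tree_aggregator_t::finalize overloads and to the new load_data helper.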
110 changes: 61 additions & 49 deletions cpp/src/fil/infer.cu
@@ -23,6 +23,8 @@
#include <fil/internal.cuh>
#include "common.cuh"

+#define NOINLINE_FAST_COMPILE __noinline__
+
namespace ML {
namespace fil {

@@ -278,12 +280,12 @@ struct tree_aggregator_t {
acc += single_tree_prediction;
}

-  __device__ __forceinline__ void finalize(float* block_out,
-                                           int block_num_rows,
-                                           int output_stride,
-                                           output_t transform,
-                                           int num_trees,
-                                           int log2_threads_per_tree)
+  __device__ NOINLINE_FAST_COMPILE void finalize(float* block_out,
+                                                  int block_num_rows,
+                                                  int output_stride,
+                                                  output_t transform,
+                                                  int num_trees,
+                                                  int log2_threads_per_tree)
{
if (FIL_TPB != 1 << log2_threads_per_tree) { // anything to reduce?
// ensure input columns can be overwritten (no threads traversing trees)
@@ -460,12 +462,12 @@ struct tree_aggregator_t<NITEMS, GROVE_PER_CLASS_FEW_CLASSES> {
acc += single_tree_prediction;
}

-  __device__ __forceinline__ void finalize(float* out,
-                                           int num_rows,
-                                           int num_outputs,
-                                           output_t transform,
-                                           int num_trees,
-                                           int log2_threads_per_tree)
+  __device__ NOINLINE_FAST_COMPILE void finalize(float* out,
+                                                  int num_rows,
+                                                  int num_outputs,
+                                                  output_t transform,
+                                                  int num_trees,
+                                                  int log2_threads_per_tree)
{
__syncthreads(); // free up input row in case it was in shared memory
// load margin into shared memory
@@ -532,12 +534,12 @@ struct tree_aggregator_t<NITEMS, GROVE_PER_CLASS_MANY_CLASSES> {
__syncthreads();
}

-  __device__ __forceinline__ void finalize(float* out,
-                                           int num_rows,
-                                           int num_outputs,
-                                           output_t transform,
-                                           int num_trees,
-                                           int log2_threads_per_tree)
+  __device__ NOINLINE_FAST_COMPILE void finalize(float* out,
+                                                  int num_rows,
+                                                  int num_outputs,
+                                                  output_t transform,
+                                                  int num_trees,
+                                                  int log2_threads_per_tree)
{
class_margins_to_global_memory(per_class_margin,
per_class_margin + num_classes,
@@ -631,12 +633,12 @@ struct tree_aggregator_t<NITEMS, VECTOR_LEAF> {
}
}
}
-  __device__ __forceinline__ void finalize(float* out,
-                                           int num_rows,
-                                           int num_outputs,
-                                           output_t transform,
-                                           int num_trees,
-                                           int log2_threads_per_tree)
+  __device__ NOINLINE_FAST_COMPILE void finalize(float* out,
+                                                  int num_rows,
+                                                  int num_outputs,
+                                                  output_t transform,
+                                                  int num_trees,
+                                                  int log2_threads_per_tree)
{
if (num_classes < blockDim.x) {
__syncthreads();
@@ -728,12 +730,12 @@ struct tree_aggregator_t<NITEMS, CATEGORICAL_LEAF> {
out[row] = best_class;
}
}
-  __device__ __forceinline__ void finalize(float* out,
-                                           int num_rows,
-                                           int num_outputs,
-                                           output_t transform,
-                                           int num_trees,
-                                           int log2_threads_per_tree)
+  __device__ NOINLINE_FAST_COMPILE void finalize(float* out,
+                                                  int num_rows,
+                                                  int num_outputs,
+                                                  output_t transform,
+                                                  int num_trees,
+                                                  int log2_threads_per_tree)
{
if (num_outputs > 1) {
// only supporting num_outputs == num_classes
@@ -744,6 +746,33 @@ struct tree_aggregator_t<NITEMS, CATEGORICAL_LEAF> {
}
};

+__device__ NOINLINE_FAST_COMPILE void load_data(float* sdata,
+                                                 const float* block_input,
+                                                 predict_params params,
+                                                 int rows_per_block,
+                                                 int block_num_rows)
+{
+  int num_cols = params.num_cols;
+  int sdata_stride = params.sdata_stride();
+  // cache the row for all threads to reuse
+  // 2021: latest SMs still do not have >256KiB of shared memory/block required to
+  // exceed the uint16_t
+#pragma unroll
+  for (uint16_t input_idx = threadIdx.x; input_idx < block_num_rows * num_cols;
+       input_idx += blockDim.x) {
+    // for even num_cols, we need to pad sdata_stride to reduce bank conflicts
+    // assuming here that sdata_stride == num_cols + 1
+    // then, idx / num_cols * sdata_stride + idx % num_cols == idx + idx / num_cols
+    uint16_t sdata_idx =
+      sdata_stride == num_cols ? input_idx : input_idx + input_idx / (uint16_t)num_cols;
+    sdata[sdata_idx] = block_input[input_idx];
+  }
+#pragma unroll
+  for (int idx = block_num_rows * sdata_stride; idx < rows_per_block * sdata_stride;
+       idx += blockDim.x)
+    sdata[idx] = 0.0f;
+}
+
template <int NITEMS, leaf_algo_t leaf_algo, bool cols_in_shmem, class storage_type>
__global__ void infer_k(storage_type forest, predict_params params)
{
@@ -758,25 +787,8 @@ __global__ void infer_k(storage_type forest, predict_params params)
int block_num_rows =
max(0, (int)min((int64_t)rows_per_block, (int64_t)params.num_rows - block_row0));
const float* block_input = params.data + block_row0 * num_cols;
-  if (cols_in_shmem) {
-    // cache the row for all threads to reuse
-    // 2021: latest SMs still do not have >256KiB of shared memory/block required to
-    // exceed the uint16_t
-#pragma unroll
-    for (uint16_t input_idx = threadIdx.x; input_idx < block_num_rows * num_cols;
-         input_idx += blockDim.x) {
-      // for even num_cols, we need to pad sdata_stride to reduce bank conflicts
-      // assuming here that sdata_stride == num_cols + 1
-      // then, idx / num_cols * sdata_stride + idx % num_cols == idx + idx / num_cols
-      uint16_t sdata_idx =
-        sdata_stride == num_cols ? input_idx : input_idx + input_idx / (uint16_t)num_cols;
-      sdata[sdata_idx] = block_input[input_idx];
-    }
-#pragma unroll
-    for (int idx = block_num_rows * sdata_stride; idx < rows_per_block * sdata_stride;
-         idx += blockDim.x)
-      sdata[idx] = 0.0f;
-  }
+  if constexpr (cols_in_shmem)
+    load_data(sdata, block_input, params, rows_per_block, block_num_rows);

tree_aggregator_t<NITEMS, leaf_algo> acc(
params, (char*)sdata + params.cols_shmem_size(), sdata, forest.vector_leaf_);
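The padding arithmetic in load_data's comment can be checked in isolation: with sdata_stride == num_cols + 1, writing input_idx as row * num_cols + col gives row * sdata_stride + col == input_idx + row == input_idx + input_idx / num_cols, which is exactly the expression used for sdata_idx. A small standalone host-side check of that identity (an illustration, not part of the PR):

#include <cassert>

int main()
{
  // With one padding element per row (sdata_stride == num_cols + 1):
  //   (idx / num_cols) * sdata_stride + idx % num_cols == idx + idx / num_cols
  for (int num_cols = 1; num_cols <= 64; ++num_cols) {
    const int sdata_stride = num_cols + 1;
    for (int idx = 0; idx < 4096; ++idx) {
      int padded = (idx / num_cols) * sdata_stride + idx % num_cols;
      assert(padded == idx + idx / num_cols);
    }
  }
  return 0;
}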