Make operator TreeEnsemble 5x faster for batches of size 100.000 #5965

Merged: 11 commits, Dec 3, 2020
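This PR speeds up TreeEnsembleClassifier and TreeEnsembleRegressor on the CPU provider for large batches (the title's 100.000 is 100,000 rows). Two things change in the diffs below: the per-tree parallelization threshold passed to tree_ensemble_ drops from 100 to 80 (the second argument, 50, the batch-size threshold, is unchanged context), and ComputeAgg in tree_ensemble_common.h replaces its one-task-per-row TryBatchParallelFor with a two-stage reduction over at most num_threads tasks. Below is a condensed sketch of the resulting dispatch for the single-target case; the section labels come from the new code, while the mapping of 80 and 50 onto parallel_tree_ and parallel_N_ is inferred from the constructor argument order. A standalone sketch of the section D pattern itself follows the tree_ensemble_common.h diff.

// Sketch only: the branch structure of ComputeAgg after this PR (one target).
if (N == 1) {
  if (n_trees_ <= parallel_tree_) {
    // section A: one row, few trees: sequential loop over trees
  } else {
    // section B: one row, many trees: parallelize across trees, then merge
  }
} else if (N <= parallel_N_) {
  // section C: small batch: sequential loops over rows and trees
} else {
  // section D: large batch: stage 1 partitions trees across threads,
  // stage 2 partitions rows across threads and merges the partial scores
}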
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/cpu/ml/treeensembleclassifier.cc

@@ -139,7 +139,7 @@ template <typename T>
 TreeEnsembleClassifier<T>::TreeEnsembleClassifier(const OpKernelInfo& info)
     : OpKernel(info),
       tree_ensemble_(
-          100,
+          80,
           50,
           info.GetAttrOrDefault<std::string>("aggregate_function", "SUM"),
           info.GetAttrsOrDefault<float>("base_values"),
183 changes: 116 additions & 67 deletions onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h

@@ -7,6 +7,9 @@
 #include "core/platform/ort_mutex.h"
 #include "core/platform/threadpool.h"
 
+#define TREEENSEMBLE_BATCHSIZE 256
+#define TREEENSEMBLE_MAXSIZE 8589934592
+
 namespace onnxruntime {
 namespace ml {
 namespace detail {
@@ -266,11 +269,11 @@ void TreeEnsembleCommon<ITYPE, OTYPE>::ComputeAgg(concurrency::ThreadPool* ttp,
   if (n_targets_or_classes_ == 1) {
     if (N == 1) {
       ScoreValue<OTYPE> score = {0, 0};
-      if (n_trees_ <= parallel_tree_) {
+      if (n_trees_ <= parallel_tree_) { /* section A */
         for (int64_t j = 0; j < n_trees_; ++j) {
           agg.ProcessTreeNodePrediction1(score, *ProcessTreeNodeLeave(roots_[j], x_data));
         }
-      } else {
+      } else { /* section B */
         std::vector<ScoreValue<OTYPE>> scores_t(n_trees_, {0, 0});
         concurrency::ThreadPool::TryBatchParallelFor(
             ttp,
@@ -284,46 +287,86 @@ void TreeEnsembleCommon<ITYPE, OTYPE>::ComputeAgg(concurrency::ThreadPool* ttp,
           agg.MergePrediction1(score, *it);
         }
       }
 
       agg.FinalizeScores1(z_data, score, label_data);
-    } else {
-      if (N <= parallel_N_) {
-        ScoreValue<OTYPE> score;
-        size_t j;
-
-        for (int64_t i = 0; i < N; ++i) {
-          score = {0, 0};
-          for (j = 0; j < static_cast<size_t>(n_trees_); ++j) {
-            agg.ProcessTreeNodePrediction1(score, *ProcessTreeNodeLeave(roots_[j], x_data + i * stride));
-          }
-
-          agg.FinalizeScores1(z_data + i * n_targets_or_classes_, score,
-                              label_data == nullptr ? nullptr : (label_data + i));
-        }
-      } else {
-        concurrency::ThreadPool::TryBatchParallelFor(
-            ttp,
-            SafeInt<int32_t>(N),
-            [this, &agg, x_data, z_data, stride, label_data](ptrdiff_t i) {
-              ScoreValue<OTYPE> score = {0, 0};
-              for (size_t j = 0; j < static_cast<size_t>(n_trees_); ++j) {
-                agg.ProcessTreeNodePrediction1(score, *ProcessTreeNodeLeave(roots_[j], x_data + i * stride));
-              }
-
-              agg.FinalizeScores1(z_data + i * n_targets_or_classes_, score,
-                                  label_data == nullptr ? nullptr : (label_data + i));
-            },
-            0);
-      }
+    } else if (N <= parallel_N_) { /* section C */
+      ScoreValue<OTYPE> score;
+      size_t j;
+
+      for (int64_t i = 0; i < N; ++i) {
+        score = {0, 0};
+        for (j = 0; j < static_cast<size_t>(n_trees_); ++j) {
+          agg.ProcessTreeNodePrediction1(score, *ProcessTreeNodeLeave(roots_[j], x_data + i * stride));
+        }
+
+        agg.FinalizeScores1(z_data + i, score,
+                            label_data == nullptr ? nullptr : (label_data + i));
+      }
+    } else { /* section D */
+      /*
+      // Parallelization by trees.
+      // This could use an array N * nth where nth is the number of threads.
+      // It would require the functions omp_get_thread_num and omp_get_max_threads.
+      std::vector<ScoreValue<OTYPE>> scores_t(n_trees_ * N, {0, 0});
+      concurrency::ThreadPool::TryBatchParallelFor(
+          ttp,
+          SafeInt<int32_t>(n_trees_),
+          [this, &scores_t, &agg, x_data, N, stride](ptrdiff_t j) {
+            for (int64_t i = 0; i < N; ++i) {
+              agg.ProcessTreeNodePrediction1(scores_t[j * N + i], *ProcessTreeNodeLeave(roots_[j], x_data + i * stride));
+            }
+          },
+          0);
+
+      concurrency::ThreadPool::TryBatchParallelFor(
+          ttp,
+          SafeInt<int32_t>(N),
+          [this, &scores_t, &agg, z_data, label_data, N](ptrdiff_t i) {
+            for (int64_t j = 1; j < this->n_trees_; ++j) {
+              agg.MergePrediction1(scores_t[i], scores_t[j * N + i]);
+            }
+            agg.FinalizeScores1(z_data + i, scores_t[i], label_data == nullptr ? nullptr : (label_data + i));
+          },
+          0);
+      */
+      auto num_threads = std::min<int32_t>(concurrency::ThreadPool::DegreeOfParallelism(ttp), SafeInt<int32_t>(n_trees_));
+      std::vector<ScoreValue<OTYPE>> scores(num_threads * N);
+      concurrency::ThreadPool::TrySimpleParallelFor(
+          ttp,
+          num_threads,
+          [this, &agg, &scores, num_threads, x_data, N, stride](ptrdiff_t batch_num) {
+            auto work = concurrency::ThreadPool::PartitionWork(batch_num, num_threads, this->n_trees_);
+            for (int64_t i = 0; i < N; ++i) {
+              scores[batch_num * N + i] = {0, 0};
+            }
+            for (auto j = work.start; j < work.end; ++j) {
+              for (int64_t i = 0; i < N; ++i) {
+                agg.ProcessTreeNodePrediction1(scores[batch_num * N + i], *ProcessTreeNodeLeave(roots_[j], x_data + i * stride));
+              }
+            }
+          });
+
+      concurrency::ThreadPool::TrySimpleParallelFor(
+          ttp,
+          num_threads,
+          [&agg, &scores, num_threads, label_data, z_data, N](ptrdiff_t batch_num) {
+            auto work = concurrency::ThreadPool::PartitionWork(batch_num, num_threads, N);
+            for (auto i = work.start; i < work.end; ++i) {
+              for (int64_t j = 1; j < num_threads; ++j) {
+                agg.MergePrediction1(scores[i], scores[j * N + i]);
+              }
+              agg.FinalizeScores1(z_data + i, scores[i],
+                                  label_data == nullptr ? nullptr : (label_data + i));
+            }
+          });
     }
   } else {
     if (N == 1) {
       std::vector<ScoreValue<OTYPE>> scores(n_targets_or_classes_, {0, 0});
-      if (n_trees_ <= parallel_tree_) {
+      if (n_trees_ <= parallel_tree_) { /* section A2 */
         for (int64_t j = 0; j < n_trees_; ++j) {
           agg.ProcessTreeNodePrediction(scores, *ProcessTreeNodeLeave(roots_[j], x_data));
         }
-      } else {
+      } else { /* section B2 */
         // split the work into one block per thread so we can re-use the 'private_scores' vector as much as possible
         // TODO: Refine the number of threads used
         auto num_threads = std::min<int32_t>(concurrency::ThreadPool::DegreeOfParallelism(ttp), SafeInt<int32_t>(n_trees_));
@@ -344,44 +387,50 @@ void TreeEnsembleCommon<ITYPE, OTYPE>::ComputeAgg(concurrency::ThreadPool* ttp,
       }
 
       agg.FinalizeScores(scores, z_data, -1, label_data);
-    } else {
-      if (N <= parallel_N_) {
-        std::vector<ScoreValue<OTYPE>> scores(n_targets_or_classes_);
-        size_t j;
-
-        for (int64_t i = 0; i < N; ++i) {
-          std::fill(scores.begin(), scores.end(), ScoreValue<OTYPE>({0, 0}));
-          for (j = 0; j < roots_.size(); ++j) {
-            agg.ProcessTreeNodePrediction(scores, *ProcessTreeNodeLeave(roots_[j], x_data + i * stride));
-          }
-
-          agg.FinalizeScores(scores, z_data + i * n_targets_or_classes_, -1,
-                             label_data == nullptr ? nullptr : (label_data + i));
-        }
-      } else {
-        // split the work into one block per thread so we can re-use the 'scores' vector as much as possible
-        // TODO: Refine the number of threads used.
-        auto num_threads = std::min<int32_t>(concurrency::ThreadPool::DegreeOfParallelism(ttp), SafeInt<int32_t>(N));
-        concurrency::ThreadPool::TrySimpleParallelFor(
-            ttp,
-            num_threads,
-            [this, &agg, num_threads, x_data, z_data, label_data, N, stride](ptrdiff_t batch_num) {
-              size_t j;
-              std::vector<ScoreValue<OTYPE>> scores(n_targets_or_classes_);
-              auto work = concurrency::ThreadPool::PartitionWork(batch_num, num_threads, N);
-
-              for (auto i = work.start; i < work.end; ++i) {
-                std::fill(scores.begin(), scores.end(), ScoreValue<OTYPE>({0, 0}));
-                for (j = 0; j < roots_.size(); ++j) {
-                  agg.ProcessTreeNodePrediction(scores, *ProcessTreeNodeLeave(roots_[j], x_data + i * stride));
-                }
-
-                agg.FinalizeScores(scores,
-                                   z_data + i * n_targets_or_classes_, -1,
-                                   label_data == nullptr ? nullptr : (label_data + i));
-              }
-            });
-      }
+    } else if (N <= parallel_N_) { /* section C2 */
+      std::vector<ScoreValue<OTYPE>> scores(n_targets_or_classes_);
+      size_t j;
+
+      for (int64_t i = 0; i < N; ++i) {
+        std::fill(scores.begin(), scores.end(), ScoreValue<OTYPE>({0, 0}));
+        for (j = 0; j < roots_.size(); ++j) {
+          agg.ProcessTreeNodePrediction(scores, *ProcessTreeNodeLeave(roots_[j], x_data + i * stride));
+        }
+
+        agg.FinalizeScores(scores, z_data + i * n_targets_or_classes_, -1,
+                           label_data == nullptr ? nullptr : (label_data + i));
+      }
+    } else {
+      auto num_threads = std::min<int32_t>(concurrency::ThreadPool::DegreeOfParallelism(ttp), SafeInt<int32_t>(n_trees_));
+      std::vector<std::vector<ScoreValue<OTYPE>>> scores(num_threads * N);
+      concurrency::ThreadPool::TrySimpleParallelFor(
+          ttp,
+          num_threads,
+          [this, &agg, &scores, num_threads, x_data, N, stride](ptrdiff_t batch_num) {
+            auto work = concurrency::ThreadPool::PartitionWork(batch_num, num_threads, this->n_trees_);
+            for (int64_t i = 0; i < N; ++i) {
+              scores[batch_num * N + i].resize(n_targets_or_classes_, {0, 0});
+            }
+            for (auto j = work.start; j < work.end; ++j) {
+              for (int64_t i = 0; i < N; ++i) {
+                agg.ProcessTreeNodePrediction(scores[batch_num * N + i], *ProcessTreeNodeLeave(roots_[j], x_data + i * stride));
+              }
+            }
+          });
+
+      concurrency::ThreadPool::TrySimpleParallelFor(
+          ttp,
+          num_threads,
+          [this, &agg, &scores, num_threads, label_data, z_data, N](ptrdiff_t batch_num) {
+            auto work = concurrency::ThreadPool::PartitionWork(batch_num, num_threads, N);
+            for (auto i = work.start; i < work.end; ++i) {
+              for (int64_t j = 1; j < num_threads; ++j) {
+                agg.MergePrediction(scores[i], scores[j * N + i]);
+              }
+              agg.FinalizeScores(scores[i], z_data + i * this->n_targets_or_classes_, -1,
+                                 label_data == nullptr ? nullptr : (label_data + i));
+            }
+          });
     }
   }
 } // namespace detail
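The new section D (and its unlabeled multi-target twin at the end of the previous hunk) is the heart of the PR: instead of spawning one TryBatchParallelFor task per row, stage 1 hands each thread a contiguous block of trees and a private row-length slice of scores, and stage 2 hands each thread a contiguous block of rows and folds the per-thread slices together before FinalizeScores1. Below is a minimal, self-contained approximation of that pattern, assuming a SUM aggregate: it uses std::thread in place of onnxruntime's thread pool, and Work, PartitionWork, TwoStageScore and score_tree are illustrative stand-ins (the real PartitionWork is a ThreadPool member, and a plain double stands in for ScoreValue<OTYPE>).

#include <algorithm>
#include <cstdint>
#include <thread>
#include <vector>

struct Work { int64_t start, end; };

// Split `total` units of work into `num_batches` near-equal contiguous ranges.
Work PartitionWork(int64_t batch, int64_t num_batches, int64_t total) {
  const int64_t per_batch = total / num_batches;
  const int64_t remainder = total % num_batches;
  const int64_t start = batch * per_batch + std::min(batch, remainder);
  return {start, start + per_batch + (batch < remainder ? 1 : 0)};
}

// score_tree(j, i) stands in for scoring row i with tree j, i.e. for
// ProcessTreeNodePrediction1(..., *ProcessTreeNodeLeave(roots_[j], x_data + i * stride)).
template <typename ScoreTree>
std::vector<double> TwoStageScore(int64_t n_trees, int64_t N,
                                  int64_t num_threads, ScoreTree score_tree) {
  std::vector<double> scores(static_cast<size_t>(num_threads * N), 0.0);
  std::vector<std::thread> pool;

  // Stage 1: partition *trees* across threads; thread t accumulates every
  // row's partial sum into its private slice scores[t*N .. t*N+N).
  for (int64_t t = 0; t < num_threads; ++t) {
    pool.emplace_back([&, t] {
      const Work w = PartitionWork(t, num_threads, n_trees);
      for (int64_t j = w.start; j < w.end; ++j)
        for (int64_t i = 0; i < N; ++i)
          scores[static_cast<size_t>(t * N + i)] += score_tree(j, i);
    });
  }
  for (auto& th : pool) th.join();
  pool.clear();

  // Stage 2: partition *rows* across threads; fold every other thread's
  // partial into slice 0 (scores[i] is the same cell as scores[0 * N + i]).
  for (int64_t t = 0; t < num_threads; ++t) {
    pool.emplace_back([&, t] {
      const Work w = PartitionWork(t, num_threads, N);
      for (int64_t i = w.start; i < w.end; ++i)
        for (int64_t j = 1; j < num_threads; ++j)
          scores[static_cast<size_t>(i)] += scores[static_cast<size_t>(j * N + i)];
    });
  }
  for (auto& th : pool) th.join();

  scores.resize(static_cast<size_t>(N));  // row i's total now lives in scores[i]
  return scores;
}

The same shape appears twice in the diff above: the single-target version accumulates one ScoreValue<OTYPE> per (thread, row) cell, while the multi-target version makes each cell a vector of n_targets_or_classes_ scores merged with MergePrediction instead of MergePrediction1.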
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/cpu/ml/treeregressor.cc
@@ -24,7 +24,7 @@ template <typename T>
 TreeEnsembleRegressor<T>::TreeEnsembleRegressor(const OpKernelInfo& info)
     : OpKernel(info),
       tree_ensemble_(
-          100,
+          80,
           50,
           info.GetAttrOrDefault<std::string>("aggregate_function", "SUM"),
           info.GetAttrsOrDefault<float>("base_values"),
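A toy sanity check for the TwoStageScore sketch above (illustrative names again, not onnxruntime APIs): 500 "trees" that each add j * 0.001 to every one of 100,000 rows, so every row's score must come out to 0.001 * (499 * 500 / 2) = 124.75 no matter how the trees were partitioned across threads.

#include <cstdio>

int main() {
  auto toy_tree = [](int64_t j, int64_t /*i*/) {
    return static_cast<double>(j) * 0.001;  // tree j's contribution to any row
  };
  const std::vector<double> out =
      TwoStageScore(/*n_trees=*/500, /*N=*/100000, /*num_threads=*/8, toy_tree);
  std::printf("row 0: %.2f  row 99999: %.2f\n", out[0], out[99999]);  // both 124.75
  return 0;
}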