Make operator TreeEnsemble 5x faster for batches of size 100.000 #5965

Merged (11 commits) on Dec 3, 2020
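The diff below reworks TreeEnsembleCommon::ComputeAgg so that large batches can be split across trees as well as across rows, with per-thread partial scores merged in a second pass. As a minimal sketch of the branch selection for the single-output path, assuming the two leading constructor arguments changed in this PR map to the parallel_tree_ and parallel_N_ thresholds; the Dispatch helper and Strategy enum below are illustrative only and do not exist in onnxruntime, and the multi-output path (A2-E2) mirrors this logic, using >= for its tree check:

#include <cstdint>
#include <iostream>

// Illustrative only: which branch of the reworked ComputeAgg handles a call,
// following the section labels A-E added in the diff (single-output path).
enum class Strategy { SerialSingleRow, TreesParallelSingleRow,
                      SerialSmallBatch, ParallelByTrees, ParallelByRows };

Strategy Dispatch(int64_t N, int64_t n_trees, int64_t max_num_threads,
                  int64_t parallel_tree = 80, int64_t parallel_N = 50) {
  if (N == 1) {                                       // sections A / B
    return n_trees <= parallel_tree ? Strategy::SerialSingleRow
                                    : Strategy::TreesParallelSingleRow;
  }
  if (N <= parallel_N) return Strategy::SerialSmallBatch;           // section C
  if (n_trees > max_num_threads) return Strategy::ParallelByTrees;  // section D
  return Strategy::ParallelByRows;                                  // section E
}

int main() {
  // A 100,000-row batch on a 100-tree model with an 8-thread pool lands in
  // section D: parallelize over trees first, then merge and finalize per row.
  std::cout << static_cast<int>(Dispatch(100000, 100, 8)) << "\n";  // prints 3
  return 0;
}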
@@ -139,7 +139,7 @@ template <typename T>
TreeEnsembleClassifier<T>::TreeEnsembleClassifier(const OpKernelInfo& info)
: OpKernel(info),
tree_ensemble_(
100,
80,
50,
info.GetAttrOrDefault<std::string>("aggregate_function", "SUM"),
info.GetAttrsOrDefault<float>("base_values"),
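Per the @@ -139,7 +139,7 @@ header, a single line changes here: 100 is the removed value and 80 its replacement. The two leading numeric arguments to tree_ensemble_ appear to be the parallel_tree_ and parallel_N_ thresholds consulted in tree_ensemble_common.h below (this mapping is inferred, not shown in the hunk), so the per-tree parallelization threshold drops from 100 to 80 while the row threshold stays at 50.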
224 changes: 139 additions & 85 deletions onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h
@@ -262,126 +262,180 @@ void TreeEnsembleCommon<ITYPE, OTYPE>::ComputeAgg(concurrency::ThreadPool* ttp,
const ITYPE* x_data = X->template Data<ITYPE>();
OTYPE* z_data = Z->template MutableData<OTYPE>();
int64_t* label_data = label == nullptr ? nullptr : label->template MutableData<int64_t>();
auto max_num_threads = concurrency::ThreadPool::DegreeOfParallelism(ttp);

if (n_targets_or_classes_ == 1) {
if (N == 1) {
ScoreValue<OTYPE> score = {0, 0};
if (n_trees_ <= parallel_tree_) {
if (n_trees_ <= parallel_tree_) { /* section A: 1 output, 1 row and not enough trees to parallelize */
for (int64_t j = 0; j < n_trees_; ++j) {
agg.ProcessTreeNodePrediction1(score, *ProcessTreeNodeLeave(roots_[j], x_data));
}
} else {
std::vector<ScoreValue<OTYPE>> scores_t(n_trees_, {0, 0});
} else { /* section B: 1 output, 1 row and enough trees to parallelize */
std::vector<ScoreValue<OTYPE>> scores(n_trees_, {0, 0});
concurrency::ThreadPool::TryBatchParallelFor(
ttp,
SafeInt<int32_t>(n_trees_),
[this, &scores_t, &agg, x_data](ptrdiff_t j) {
agg.ProcessTreeNodePrediction1(scores_t[j], *ProcessTreeNodeLeave(roots_[j], x_data));
[this, &scores, &agg, x_data](ptrdiff_t j) {
agg.ProcessTreeNodePrediction1(scores[j], *ProcessTreeNodeLeave(roots_[j], x_data));
},
0);

for (auto it = scores_t.cbegin(); it != scores_t.cend(); ++it) {
for (auto it = scores.cbegin(); it != scores.cend(); ++it) {
agg.MergePrediction1(score, *it);
}
}

agg.FinalizeScores1(z_data, score, label_data);
} else {
if (N <= parallel_N_) {
ScoreValue<OTYPE> score;
size_t j;

for (int64_t i = 0; i < N; ++i) {
score = {0, 0};
for (j = 0; j < static_cast<size_t>(n_trees_); ++j) {
agg.ProcessTreeNodePrediction1(score, *ProcessTreeNodeLeave(roots_[j], x_data + i * stride));
}

agg.FinalizeScores1(z_data + i * n_targets_or_classes_, score,
label_data == nullptr ? nullptr : (label_data + i));
} else if (N <= parallel_N_) { /* section C: 1 output, 2+ rows but not enough rows to parallelize */
ScoreValue<OTYPE> score;
size_t j;

for (int64_t i = 0; i < N; ++i) {
score = {0, 0};
for (j = 0; j < static_cast<size_t>(n_trees_); ++j) {
agg.ProcessTreeNodePrediction1(score, *ProcessTreeNodeLeave(roots_[j], x_data + i * stride));
}
} else {
concurrency::ThreadPool::TryBatchParallelFor(
ttp,
SafeInt<int32_t>(N),
[this, &agg, x_data, z_data, stride, label_data](ptrdiff_t i) {
ScoreValue<OTYPE> score = {0, 0};
for (size_t j = 0; j < static_cast<size_t>(n_trees_); ++j) {
agg.ProcessTreeNodePrediction1(score, *ProcessTreeNodeLeave(roots_[j], x_data + i * stride));
}

agg.FinalizeScores1(z_data + i * n_targets_or_classes_, score,
label_data == nullptr ? nullptr : (label_data + i));
},
0);
agg.FinalizeScores1(z_data + i, score,
label_data == nullptr ? nullptr : (label_data + i));
}
} else if (n_trees_ > max_num_threads) { /* section D: 1 output, 2+ rows and enough trees to parallelize */
auto num_threads = std::min<int32_t>(max_num_threads, SafeInt<int32_t>(n_trees_));
std::vector<ScoreValue<OTYPE>> scores(num_threads * N);
concurrency::ThreadPool::TrySimpleParallelFor(
ttp,
num_threads,
[this, &agg, &scores, num_threads, x_data, N, stride](ptrdiff_t batch_num) {
auto work = concurrency::ThreadPool::PartitionWork(batch_num, num_threads, this->n_trees_);
for (int64_t i = 0; i < N; ++i) {
scores[batch_num * N + i] = {0, 0};
}
for (auto j = work.start; j < work.end; ++j) {
for (int64_t i = 0; i < N; ++i) {
agg.ProcessTreeNodePrediction1(scores[batch_num * N + i], *ProcessTreeNodeLeave(roots_[j], x_data + i * stride));
}
}
});
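// First pass (above): each thread takes a contiguous slice of trees via
// PartitionWork over n_trees_ and accumulates partial scores for every row
// into scores[batch_num * N + i]. Second pass (below): threads partition the
// rows instead, merge the per-thread partials for each row, and finalize.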

concurrency::ThreadPool::TrySimpleParallelFor(
ttp,
num_threads,
[&agg, &scores, num_threads, label_data, z_data, N](ptrdiff_t batch_num) {
auto work = concurrency::ThreadPool::PartitionWork(batch_num, num_threads, N);
for (auto i = work.start; i < work.end; ++i) {
for (int64_t j = 1; j < num_threads; ++j) {
agg.MergePrediction1(scores[i], scores[j * N + i]);
}
agg.FinalizeScores1(z_data + i, scores[i],
label_data == nullptr ? nullptr : (label_data + i));
}
});
} else { /* section E: 1 output, 2+ rows, parallelization by rows */
concurrency::ThreadPool::TryBatchParallelFor(
ttp,
SafeInt<int32_t>(N),
[this, &agg, x_data, z_data, stride, label_data](ptrdiff_t i) {
ScoreValue<OTYPE> score = {0, 0};
for (size_t j = 0; j < static_cast<size_t>(n_trees_); ++j) {
agg.ProcessTreeNodePrediction1(score, *ProcessTreeNodeLeave(roots_[j], x_data + i * stride));
}

agg.FinalizeScores1(z_data + i, score,
label_data == nullptr ? nullptr : (label_data + i));
},
0);
}
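/* The multi-output branch below mirrors the same layout, with sections labelled
   A2-E2 and a vector of ScoreValue per row instead of a single score. */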
} else {
if (N == 1) {
std::vector<ScoreValue<OTYPE>> scores(n_targets_or_classes_, {0, 0});
if (n_trees_ <= parallel_tree_) {
if (N == 1) { /* section A2: 2+ outputs, 1 row, not enough trees to parallelize */
if (n_trees_ <= parallel_tree_) { /* section A2 */
std::vector<ScoreValue<OTYPE>> scores(n_targets_or_classes_, {0, 0});
for (int64_t j = 0; j < n_trees_; ++j) {
agg.ProcessTreeNodePrediction(scores, *ProcessTreeNodeLeave(roots_[j], x_data));
}
} else {
// split the work into one block per thread so we can re-use the 'private_scores' vector as much as possible
// TODO: Refine the number of threads used
auto num_threads = std::min<int32_t>(concurrency::ThreadPool::DegreeOfParallelism(ttp), SafeInt<int32_t>(n_trees_));
OrtMutex merge_mutex;
agg.FinalizeScores(scores, z_data, -1, label_data);
} else { /* section B2: 2+ outputs, 1 row, enough trees to parallelize */
auto num_threads = std::min<int32_t>(max_num_threads, SafeInt<int32_t>(n_trees_));
std::vector<std::vector<ScoreValue<OTYPE>>> scores(num_threads);
concurrency::ThreadPool::TrySimpleParallelFor(
ttp,
num_threads,
[this, &agg, &scores, &merge_mutex, num_threads, x_data](ptrdiff_t batch_num) {
std::vector<ScoreValue<OTYPE>> private_scores(n_targets_or_classes_, {0, 0});
[this, &agg, &scores, num_threads, x_data](ptrdiff_t batch_num) {
scores[batch_num].resize(n_targets_or_classes_, {0, 0});
auto work = concurrency::ThreadPool::PartitionWork(batch_num, num_threads, n_trees_);
for (auto j = work.start; j < work.end; ++j) {
agg.ProcessTreeNodePrediction(private_scores, *ProcessTreeNodeLeave(roots_[j], x_data));
agg.ProcessTreeNodePrediction(scores[batch_num], *ProcessTreeNodeLeave(roots_[j], x_data));
}

std::lock_guard<OrtMutex> lock(merge_mutex);
agg.MergePrediction(scores, private_scores);
});
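// Each thread filled its own scores[batch_num] vector above; the partials are
// now merged serially into scores[0]. This replaces the OrtMutex-guarded merge
// that every thread previously performed inside the parallel loop (the
// lock_guard/MergePrediction lines above are the removed code).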
for (size_t i = 1; i < scores.size(); ++i) {
agg.MergePrediction(scores[0], scores[i]);
}
agg.FinalizeScores(scores[0], z_data, -1, label_data);
}

agg.FinalizeScores(scores, z_data, -1, label_data);
} else {
if (N <= parallel_N_) {
std::vector<ScoreValue<OTYPE>> scores(n_targets_or_classes_);
size_t j;

for (int64_t i = 0; i < N; ++i) {
std::fill(scores.begin(), scores.end(), ScoreValue<OTYPE>({0, 0}));
for (j = 0; j < roots_.size(); ++j) {
agg.ProcessTreeNodePrediction(scores, *ProcessTreeNodeLeave(roots_[j], x_data + i * stride));
}

agg.FinalizeScores(scores, z_data + i * n_targets_or_classes_, -1,
label_data == nullptr ? nullptr : (label_data + i));
} else if (N <= parallel_N_) { /* section C2: 2+ outputs, 2+ rows, not enough rows to parallelize */
std::vector<ScoreValue<OTYPE>> scores(n_targets_or_classes_);
size_t j;

for (int64_t i = 0; i < N; ++i) {
std::fill(scores.begin(), scores.end(), ScoreValue<OTYPE>({0, 0}));
for (j = 0; j < roots_.size(); ++j) {
agg.ProcessTreeNodePrediction(scores, *ProcessTreeNodeLeave(roots_[j], x_data + i * stride));
}
} else {
// split the work into one block per thread so we can re-use the 'scores' vector as much as possible
// TODO: Refine the number of threads used.
auto num_threads = std::min<int32_t>(concurrency::ThreadPool::DegreeOfParallelism(ttp), SafeInt<int32_t>(N));
concurrency::ThreadPool::TrySimpleParallelFor(
ttp,
num_threads,
[this, &agg, num_threads, x_data, z_data, label_data, N, stride](ptrdiff_t batch_num) {
size_t j;
std::vector<ScoreValue<OTYPE>> scores(n_targets_or_classes_);
auto work = concurrency::ThreadPool::PartitionWork(batch_num, num_threads, N);

for (auto i = work.start; i < work.end; ++i) {
std::fill(scores.begin(), scores.end(), ScoreValue<OTYPE>({0, 0}));
for (j = 0; j < roots_.size(); ++j) {
agg.ProcessTreeNodePrediction(scores, *ProcessTreeNodeLeave(roots_[j], x_data + i * stride));
}

agg.FinalizeScores(scores,
z_data + i * n_targets_or_classes_, -1,
label_data == nullptr ? nullptr : (label_data + i));
}
});

agg.FinalizeScores(scores, z_data + i * n_targets_or_classes_, -1,
label_data == nullptr ? nullptr : (label_data + i));
}
} else if (n_trees_ >= max_num_threads) { /* section D2: 2+ outputs, 2+ rows, enough trees to parallelize */
auto num_threads = std::min<int32_t>(max_num_threads, SafeInt<int32_t>(n_trees_));
std::vector<std::vector<ScoreValue<OTYPE>>> scores(num_threads * N);
concurrency::ThreadPool::TrySimpleParallelFor(
ttp,
num_threads,
[this, &agg, &scores, num_threads, x_data, N, stride](ptrdiff_t batch_num) {
auto work = concurrency::ThreadPool::PartitionWork(batch_num, num_threads, this->n_trees_);
for (int64_t i = 0; i < N; ++i) {
scores[batch_num * N + i].resize(n_targets_or_classes_, {0, 0});
}
for (auto j = work.start; j < work.end; ++j) {
for (int64_t i = 0; i < N; ++i) {
agg.ProcessTreeNodePrediction(scores[batch_num * N + i], *ProcessTreeNodeLeave(roots_[j], x_data + i * stride));
}
}
});
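// As in section D: the first pass accumulated, per thread, a partial score
// vector for every row over that thread's slice of trees; the pass below
// partitions the rows, merges the num_threads partials per row, and finalizes.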

concurrency::ThreadPool::TrySimpleParallelFor(
ttp,
num_threads,
[this, &agg, &scores, num_threads, label_data, z_data, N](ptrdiff_t batch_num) {
auto work = concurrency::ThreadPool::PartitionWork(batch_num, num_threads, N);
for (auto i = work.start; i < work.end; ++i) {
for (int64_t j = 1; j < num_threads; ++j) {
agg.MergePrediction(scores[i], scores[j * N + i]);
}
agg.FinalizeScores(scores[i], z_data + i * this->n_targets_or_classes_, -1,
label_data == nullptr ? nullptr : (label_data + i));
}
});
} else { /* section E2: 2+ outputs, 2+ rows, parallelization by rows */
auto num_threads = std::min<int32_t>(max_num_threads, SafeInt<int32_t>(N));
concurrency::ThreadPool::TrySimpleParallelFor(
ttp,
num_threads,
[this, &agg, num_threads, x_data, z_data, label_data, N, stride](ptrdiff_t batch_num) {
size_t j;
std::vector<ScoreValue<OTYPE>> scores(n_targets_or_classes_);
auto work = concurrency::ThreadPool::PartitionWork(batch_num, num_threads, N);

for (auto i = work.start; i < work.end; ++i) {
std::fill(scores.begin(), scores.end(), ScoreValue<OTYPE>({0, 0}));
for (j = 0; j < roots_.size(); ++j) {
agg.ProcessTreeNodePrediction(scores, *ProcessTreeNodeLeave(roots_[j], x_data + i * stride));
}

agg.FinalizeScores(scores,
z_data + i * n_targets_or_classes_, -1,
label_data == nullptr ? nullptr : (label_data + i));
}
});
}
}
} // namespace detail
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/cpu/ml/treeregressor.cc
@@ -24,7 +24,7 @@ template <typename T>
TreeEnsembleRegressor<T>::TreeEnsembleRegressor(const OpKernelInfo& info)
: OpKernel(info),
tree_ensemble_(
100,
80,
50,
info.GetAttrOrDefault<std::string>("aggregate_function", "SUM"),
info.GetAttrsOrDefault<float>("base_values"),
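The regressor constructor receives the same change as the classifier above: the first threshold passed to tree_ensemble_ goes from 100 to 80 and the second stays at 50.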
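For context on the title's claim, here is a minimal sketch of driving a tree-ensemble model over a 100,000-row batch with the onnxruntime C++ API. The model file "model.onnx", the input/output names "X" and "Y", the feature count, and the Linux-style path are all assumptions; the intra-op thread pool configured here is the one ComputeAgg's TrySimpleParallelFor draws on.

#include <onnxruntime_cxx_api.h>
#include <cstdint>
#include <vector>

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "tree_ensemble_batch");
  Ort::SessionOptions opts;
  opts.SetIntraOpNumThreads(8);  // size of the intra-op pool used by ComputeAgg
  // "model.onnx" stands in for any model containing a TreeEnsembleRegressor
  // or TreeEnsembleClassifier node.
  Ort::Session session(env, "model.onnx", opts);

  const int64_t N = 100000, n_features = 10;  // hypothetical feature count
  std::vector<float> x(N * n_features, 0.5f);
  std::vector<int64_t> shape{N, n_features};

  Ort::MemoryInfo mem = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
  Ort::Value input = Ort::Value::CreateTensor<float>(
      mem, x.data(), x.size(), shape.data(), shape.size());

  const char* in_names[] = {"X"};   // assumed input name
  const char* out_names[] = {"Y"};  // assumed output name
  auto outputs = session.Run(Ort::RunOptions{nullptr}, in_names, &input, 1, out_names, 1);
  float* scores = outputs[0].GetTensorMutableData<float>();
  (void)scores;  // the first N values hold the per-row scores
  return 0;
}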