Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REVIEW] Parallelize Treelite to FIL conversion over trees #3396

Merged
merged 12 commits into from
Jan 26, 2021
43 changes: 29 additions & 14 deletions cpp/src/fil/fil.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2020, NVIDIA CORPORATION.
* Copyright (c) 2019-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -16,6 +16,7 @@

/** @file fil.cu implements forest inference */

#include <omp.h>
#include <thrust/device_ptr.h>
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
Expand Down Expand Up @@ -433,7 +434,10 @@ inline int max_depth(const tl::Tree<T, L>& tree) {
template <typename T, typename L>
int max_depth(const tl::ModelImpl<T, L>& model) {
int depth = 0;
for (const auto& tree : model.trees) {
const auto& trees = model.trees;
#pragma omp parallel for reduction(max : depth)
for (size_t i = 0; i < trees.size(); ++i) {
const auto& tree = trees[i];
depth = std::max(depth, max_depth(tree));
}
return depth;
Expand Down Expand Up @@ -543,12 +547,12 @@ void tree2fil_dense(std::vector<dense_node>* pnodes, int root,
}

template <typename fil_node_t, typename T, typename L>
int tree2fil_sparse(std::vector<fil_node_t>* pnodes, const tl::Tree<T, L>& tree,
int tree2fil_sparse(std::vector<fil_node_t>& nodes, int root,
const tl::Tree<T, L>& tree,
const forest_params_t& forest_params) {
typedef std::pair<int, int> pair_t;
std::stack<pair_t> stack;
int root = pnodes->size();
pnodes->push_back(fil_node_t());
int built_index = root + 1;
stack.push(pair_t(tree_root(tree), 0));
while (!stack.empty()) {
const pair_t& top = stack.top();
Expand All @@ -572,10 +576,9 @@ int tree2fil_sparse(std::vector<fil_node_t>* pnodes, const tl::Tree<T, L>& tree,
// reserve space for child nodes
// left is the offset of the left child node relative to the tree root
// in the array of all nodes of the FIL sparse forest
int left = pnodes->size() - root;
pnodes->push_back(fil_node_t());
pnodes->push_back(fil_node_t());
(*pnodes)[root + cur] =
int left = built_index - root;
built_index += 2;
nodes[root + cur] =
fil_node_t(val_t{.f = 0}, threshold, tree.SplitIndex(node_id),
default_left, false, left);

Expand All @@ -587,8 +590,8 @@ int tree2fil_sparse(std::vector<fil_node_t>* pnodes, const tl::Tree<T, L>& tree,
}

// leaf node
(*pnodes)[root + cur] = fil_node_t(val_t{.f = NAN}, NAN, 0, false, true, 0);
tl2fil_leaf_payload(&(*pnodes)[root + cur], tree, node_id, forest_params);
nodes[root + cur] = fil_node_t(val_t{.f = NAN}, NAN, 0, false, true, 0);
tl2fil_leaf_payload(&nodes[root + cur], tree, node_id, forest_params);
}

return root;
Expand Down Expand Up @@ -751,11 +754,23 @@ void tl2fil_sparse(std::vector<int>* ptrees, std::vector<fil_node_t>* pnodes,
tl2fil_common(params, model, tl_params);
tl2fil_sparse_check_t<fil_node_t>::check(model);

size_t num_trees = model.trees.size();

ptrees->reserve(num_trees);
wphicks marked this conversation as resolved.
Show resolved Hide resolved
ptrees->push_back(0);
for (size_t i = 0; i < num_trees - 1; ++i) {
ptrees->push_back(model.trees[i].num_nodes + ptrees->back());
}
size_t total_nodes = ptrees->back() + model.trees.back().num_nodes;

pnodes->resize(total_nodes);

// convert the nodes
for (int i = 0; i < model.trees.size(); ++i) {
int root = tree2fil_sparse(pnodes, model.trees[i], *params);
ptrees->push_back(root);
#pragma omp parallel for
for (int i = 0; i < num_trees; ++i) {
tree2fil_sparse(*pnodes, (*ptrees)[i], model.trees[i], *params);
}

params->num_nodes = pnodes->size();
}

Expand Down