Integrate experimental FIL in the FIL backend #366

Open · wants to merge 7 commits into base: main · showing changes from all commits
15 changes: 5 additions & 10 deletions CMakeLists.txt
@@ -113,7 +113,7 @@ else()
   rapids_cuda_init_architectures(RAPIDS_TRITON_BACKEND)
   project(RAPIDS_TRITON_BACKEND VERSION 22.10.00 LANGUAGES CXX CUDA)
 else()
-  project(RAPIDS_TRITON_BACKEND VERSION 22.10.00 LANGUAGES CXX)
+  project(RAPIDS_TRITON_BACKEND VERSION 22.10.00 LANGUAGES CXX CUDA)
 endif()

##############################################################################
@@ -166,11 +166,7 @@ else()
 set(RAPIDS_TRITON_MIN_VERSION_rapids_projects "${RAPIDS_DEPENDENCIES_VERSION}.00")
 set(RAPIDS_TRITON_BRANCH_VERSION_rapids_projects "${RAPIDS_DEPENDENCIES_VERSION}")

-if(TRITON_ENABLE_GPU)
-  include(cmake/thirdparty/get_cuml.cmake)
-else()
-  include(cmake/thirdparty/get_treelite.cmake)
-endif()
+include(cmake/thirdparty/get_cuml.cmake)
 include(cmake/thirdparty/get_rapids-triton.cmake)

 if(BUILD_TESTS)
@@ -197,6 +193,7 @@ else()
       POSITION_INDEPENDENT_CODE ON
       INTERFACE_POSITION_INDEPENDENT_CODE ON
   )
+  target_sources(${BACKEND_TARGET} PRIVATE src/detail/postprocess_gpu.cu)
 else()
   set_target_properties(${BACKEND_TARGET}
     PROPERTIES BUILD_RPATH "\$ORIGIN"
@@ -229,7 +226,7 @@ else()

 target_link_libraries(${BACKEND_TARGET}
   PRIVATE
-    $<$<BOOL:${TRITON_ENABLE_GPU}>:cuml++>
+    cuml++
     ${TREELITE_LIBS}
     rapids_triton::rapids_triton
     triton-core-serverstub
@@ -239,9 +236,7 @@ else()
     OpenMP::OpenMP_CXX
 )

-if(TRITON_ENABLE_GPU)
-  list(APPEND BACKEND_TARGET "cuml++")
-endif()
+list(APPEND BACKEND_TARGET "cuml++")

 if(NOT TRITON_FIL_USE_TREELITE_STATIC)
   list(APPEND BACKEND_TARGET ${TREELITE_LIBS_NO_PREFIX})
9 changes: 5 additions & 4 deletions cmake/thirdparty/get_cuml.cmake
@@ -16,7 +16,7 @@

 function(find_and_configure_cuml)

-    set(oneValueArgs VERSION FORK PINNED_TAG USE_TREELITE_STATIC)
+    set(oneValueArgs VERSION FORK PINNED_TAG USE_TREELITE_STATIC TRITON_ENABLE_GPU)
     cmake_parse_arguments(PKG "${options}" "${oneValueArgs}"
                           "${multiValueArgs}" ${ARGN} )

@@ -43,6 +43,7 @@ function(find_and_configure_cuml)
             "BUILD_CUML_STD_COMMS OFF"
             "BUILD_SHARED_LIBS ON"
             "CUML_USE_TREELITE_STATIC ${PKG_USE_TREELITE_STATIC}"
+            "CUML_ENABLE_GPU ${PKG_TRITON_ENABLE_GPU}"
             "USE_CCACHE ON"
             "RAFT_COMPILE_LIBRARIES OFF"
             "RAFT_ENABLE_NN_DEPENDENCIES OFF"
@@ -56,7 +57,7 @@ endfunction()
 # To use a different RAFT locally, set the CMake variable
 # CPM_raft_SOURCE=/path/to/local/raft
 find_and_configure_cuml(VERSION ${RAPIDS_TRITON_MIN_VERSION_rapids_projects}
-                        FORK rapidsai
-                        PINNED_TAG branch-23.08
+                        FORK hcho3
+                        PINNED_TAG fix_cpu_fil
                         USE_TREELITE_STATIC ${TRITON_FIL_USE_TREELITE_STATIC}
-                        )
+                        TRITON_ENABLE_GPU ${TRITON_ENABLE_GPU})
90 changes: 88 additions & 2 deletions src/cpu_forest_model.h
@@ -16,32 +16,118 @@

 #pragma once

+#include <detail/postprocess_cpu.h>
 #include <forest_model.h>
 #include <names.h>
 #include <tl_model.h>

+#include <algorithm>
+#include <cstddef>
+#include <cuml/experimental/fil/constants.hpp>
+#include <cuml/experimental/fil/detail/index_type.hpp>
+#include <cuml/experimental/fil/detail/raft_proto/device_type.hpp>
+#include <cuml/experimental/fil/detail/raft_proto/handle.hpp>
+#include <cuml/experimental/fil/forest_model.hpp>
+#include <cuml/experimental/fil/treelite_importer.hpp>
+#include <herring/omp_helpers.hpp>
 #include <memory>
+#include <optional>
 #include <rapids_triton/memory/buffer.hpp>
 #include <rapids_triton/memory/types.hpp>

 namespace triton { namespace backend { namespace NAMESPACE {

+namespace filex = ML::experimental::fil;
+
 template <>
 struct ForestModel<rapids::HostMemory> {
   ForestModel() = default;
-  ForestModel(std::shared_ptr<TreeliteModel> tl_model) : tl_model_{tl_model} {}
+  using device_id_t = int;
+  ForestModel(
+      device_id_t device_id, cudaStream_t stream,
+      std::shared_ptr<TreeliteModel> tl_model, bool use_new_fil)
+      : device_id_{device_id}, tl_model_{tl_model},
+        new_fil_model_{[this, use_new_fil]() {
+          auto result = std::optional<filex::forest_model>{};
+          if (use_new_fil) {
+            try {
+              result = filex::import_from_treelite_model(
+                  *tl_model_->base_tl_model(), filex::preferred_tree_layout,
+                  filex::index_type{}, std::nullopt,
+                  raft_proto::device_type::cpu);
+              rapids::log_info(__FILE__, __LINE__)
+                  << "Loaded model to new FIL format";
+            }
+            catch (filex::model_import_error const& ex) {
+              result = std::nullopt;
+              auto log_stream = rapids::log_info(__FILE__, __LINE__);
+              log_stream << "Experimental FIL load failed with error \"";
+              log_stream << ex.what();
+              log_stream << "\"; falling back to current FIL";
+            }
+          }
+          return result;
+        }()},
+        class_encoder_{int(thread_count(tl_model_->config().cpu_nthread))}
+  {
+  }

   void predict(
       rapids::Buffer<float>& output, rapids::Buffer<float const> const& input,
       std::size_t samples, bool predict_proba) const
   {
-    tl_model_->predict(output, input, samples, predict_proba);
+    if (new_fil_model_) {
+      // Create a non-owning Buffer over the same memory as `output`
+      auto output_buffer = rapids::Buffer<float>{
+          output.data(), output.size(), output.mem_type(), output.device(),
+          output.stream()};
+      auto output_size = output.size();
+      // New FIL expects a buffer of size samples * num_classes for
+      // multi-class classifiers, but the output buffer may be smaller, so we
+      // need a temporary buffer
+      auto const num_classes = tl_model_->num_classes();
+      if (!predict_proba && tl_model_->config().output_class &&
+          num_classes > 1) {
+        output_size = samples * num_classes;
+        if (output_size != output.size()) {
+          // If the expected output size differs from the size of `output`,
+          // create a temporary buffer of the correct size
+          output_buffer =
+              rapids::Buffer<float>{output_size, rapids::HostMemory};
+        }
+      }
+
+      // TODO(hcho3): Revise new FIL so that it takes in (const io_t*) type
+      // for the input buffer
+      new_fil_model_->predict(
+          raft_proto::handle_t{}, output_buffer.data(),
+          const_cast<float*>(input.data()), samples,
+          get_raft_proto_device_type(output.mem_type()),
+          get_raft_proto_device_type(input.mem_type()),
+          filex::infer_kind::default_kind);
+
+      if (!predict_proba && tl_model_->config().output_class &&
+          num_classes > 1) {
+        // Multi-class: collapse per-class probabilities into one class index
+        class_encoder_.argmax_for_multiclass(
+            output, output_buffer, samples, num_classes);
+      } else if (
+          !predict_proba && tl_model_->config().output_class &&
+          num_classes == 1) {
+        // Binary: threshold the returned probabilities in place
+        class_encoder_.threshold_inplace(
+            output, samples, tl_model_->config().threshold);
+      }
+    } else {
+      tl_model_->predict(output, input, samples, predict_proba);
+    }
   }

 private:
   std::shared_ptr<TreeliteModel> tl_model_;
+  device_id_t device_id_;
+  // TODO(hcho3): Make filex::forest_model::predict() a const method
+  mutable std::optional<filex::forest_model> new_fil_model_;
+  ClassEncoder<rapids::HostMemory> class_encoder_;
 };

 }}} // namespace triton::backend::NAMESPACE
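A note on the `new_fil_model_` initializer above: the member is an std::optional populated by an immediately invoked lambda, so a failed experimental-FIL import degrades to an empty optional (and the current-FIL path) instead of failing the constructor. A minimal self-contained sketch of the same idiom, using toy stand-ins rather than the backend's actual types:

#include <iostream>
#include <optional>
#include <stdexcept>
#include <string>

// Stand-ins for filex::forest_model and filex::model_import_error.
struct toy_model {
  std::string name;
};
struct import_error : std::runtime_error {
  using std::runtime_error::runtime_error;
};

toy_model import_model(bool should_fail)
{
  if (should_fail) { throw import_error{"unsupported model"}; }
  return toy_model{"ok"};
}

struct wrapper {
  wrapper(bool use_new, bool should_fail)
      // Immediately invoked lambda: try the new path, fall back to empty.
      : model_{[&]() {
          auto result = std::optional<toy_model>{};
          if (use_new) {
            try {
              result = import_model(should_fail);
            }
            catch (import_error const&) {
              result = std::nullopt;  // fall back to the old implementation
            }
          }
          return result;
        }()}
  {
  }
  bool uses_new_path() const { return model_.has_value(); }

 private:
  std::optional<toy_model> model_;
};

int main()
{
  std::cout << wrapper{true, false}.uses_new_path() << '\n';  // 1
  std::cout << wrapper{true, true}.uses_new_path() << '\n';   // 0 (fell back)
}

The payoff is that `predict()` can branch on `if (new_fil_model_)` at run time with no extra state flag.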
49 changes: 49 additions & 0 deletions src/detail/postprocess.h
@@ -0,0 +1,49 @@
/*
 * Copyright (c) 2023, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include <names.h>

#include <cstddef>
#include <rapids_triton/exceptions.hpp>  // rapids::TritonException, rapids::Error
#include <rapids_triton/memory/buffer.hpp>
#include <rapids_triton/memory/types.hpp>

namespace triton { namespace backend { namespace NAMESPACE {

/* This struct defines a unified interface for converting probabilities
 * to integer class outputs, on CPU and GPU targets. The primary template
 * rejects any memory type without a matching specialization; the host and
 * device implementations live in postprocess_cpu.h and postprocess_gpu.h. */
template <rapids::MemoryType M>
struct ClassEncoder {
  ClassEncoder() = default;
  void argmax_for_multiclass(
      rapids::Buffer<float>& output, rapids::Buffer<float>& input,
      std::size_t samples, std::size_t num_classes) const
  {
    throw rapids::TritonException(
        rapids::Error::Unsupported,
        "ClassEncoder invoked with a memory type unsupported by this build");
  }
  void threshold_inplace(
      rapids::Buffer<float>& output, std::size_t samples, float threshold) const
  {
    throw rapids::TritonException(
        rapids::Error::Unsupported,
        "ClassEncoder invoked with a memory type unsupported by this build");
  }
};

}}} // namespace triton::backend::NAMESPACE
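The primary template above exists only to fail loudly when a build lacks the matching specialization (for example, a CPU-only build that somehow routes a device buffer here). A toy sketch of the same compile-time dispatch pattern, with stand-in types instead of the rapids_triton ones:

#include <cstdio>
#include <stdexcept>

enum class memory_type { host, device };

// Primary template: any memory type without a specialization throws.
template <memory_type M>
struct encoder {
  void threshold_inplace(float* out, int n, float t) const
  {
    throw std::runtime_error{"unsupported memory type in this build"};
  }
};

// Host specialization: a real implementation.
template <>
struct encoder<memory_type::host> {
  void threshold_inplace(float* out, int n, float t) const
  {
    for (int i = 0; i < n; ++i) { out[i] = (out[i] > t) ? 1.0f : 0.0f; }
  }
};

int main()
{
  float probs[] = {0.2f, 0.7f, 0.5f};
  encoder<memory_type::host>{}.threshold_inplace(probs, 3, 0.5f);
  std::printf("%g %g %g\n", probs[0], probs[1], probs[2]);  // 0 1 0
  // encoder<memory_type::device>{}.threshold_inplace(...) would compile but
  // throw at runtime, mirroring the Unsupported TritonException above.
}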
68 changes: 68 additions & 0 deletions src/detail/postprocess_cpu.h
@@ -0,0 +1,68 @@
/*
 * Copyright (c) 2023, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include <detail/postprocess.h>
#include <names.h>

#include <cstddef>
#include <rapids_triton/memory/buffer.hpp>
#include <rapids_triton/memory/types.hpp>

namespace triton { namespace backend { namespace NAMESPACE {

template <>
struct ClassEncoder<rapids::HostMemory> {
  ClassEncoder() : nthread_(1) {}
  ClassEncoder(int nthread) : nthread_(nthread) {}

  void argmax_for_multiclass(
      rapids::Buffer<float>& output, rapids::Buffer<float>& input,
      std::size_t samples, std::size_t num_classes) const
  {
    // Perform argmax for multi-class classification
    auto* dest = output.data();
    auto* src = input.data();
#pragma omp parallel for num_threads(nthread_)
    for (std::size_t i = 0; i < samples; ++i) {
      float max_prob = 0.0f;
      int max_class = 0;
      for (std::size_t j = 0; j < num_classes; ++j) {
        if (src[i * num_classes + j] > max_prob) {
          max_prob = src[i * num_classes + j];
          max_class = j;
        }
      }
      dest[i] = max_class;
    }
  }

  void threshold_inplace(
      rapids::Buffer<float>& output, std::size_t samples, float threshold) const
  {
    // Perform thresholding in-place for binary classification
    auto* out = output.data();
#pragma omp parallel for num_threads(nthread_)
    for (std::size_t i = 0; i < samples; ++i) {
      out[i] = (out[i] > threshold) ? 1.0f : 0.0f;
    }
  }

 private:
  int nthread_;
};

}}} // namespace triton::backend::NAMESPACE
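To make the row-major layout concrete: for samples = 2 and num_classes = 3, the encoder reads probabilities at src[i * 3 + j] and writes one class id per sample. A small self-contained check of the same logic using plain arrays, without the rapids_triton buffer types:

#include <cstddef>
#include <cstdio>

// Same argmax encoding as ClassEncoder<rapids::HostMemory>, on raw arrays.
void argmax_encode(
    float* dest, float const* src, std::size_t samples,
    std::size_t num_classes)
{
  for (std::size_t i = 0; i < samples; ++i) {
    float max_prob = 0.0f;
    int max_class = 0;
    for (std::size_t j = 0; j < num_classes; ++j) {
      if (src[i * num_classes + j] > max_prob) {
        max_prob = src[i * num_classes + j];
        max_class = static_cast<int>(j);
      }
    }
    dest[i] = static_cast<float>(max_class);  // class id stored as float
  }
}

int main()
{
  // Two samples, three classes, row-major probabilities.
  float probs[] = {0.1f, 0.7f, 0.2f,   // sample 0 -> class 1
                   0.5f, 0.2f, 0.3f};  // sample 1 -> class 0
  float classes[2];
  argmax_encode(classes, probs, 2, 3);
  std::printf("%g %g\n", classes[0], classes[1]);  // prints: 1 0
}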
65 changes: 65 additions & 0 deletions src/detail/postprocess_gpu.cu
@@ -0,0 +1,65 @@
/*
 * Copyright (c) 2023, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <detail/postprocess_gpu.h>
#include <names.h>
#include <thrust/execution_policy.h>
#include <thrust/for_each.h>
#include <thrust/iterator/counting_iterator.h>

#include <cstddef>
#include <cstdint>

namespace triton { namespace backend { namespace NAMESPACE {

void
ClassEncoder<rapids::DeviceMemory>::argmax_for_multiclass(
    rapids::Buffer<float>& output, rapids::Buffer<float>& input,
    std::size_t samples, std::size_t num_classes) const
{
  // Perform argmax for multi-class classification
  thrust::counting_iterator<std::size_t> cnt_iter =
      thrust::make_counting_iterator<std::size_t>(0);
  thrust::for_each(
      thrust::device, cnt_iter, cnt_iter + samples,
      [dest = output.data(), src = input.data(),
       num_classes] __device__(std::size_t i) {
        float max_prob = 0.0f;
        int max_class = 0;
        for (std::size_t j = 0; j < num_classes; ++j) {
          if (src[i * num_classes + j] > max_prob) {
            max_prob = src[i * num_classes + j];
            max_class = j;
          }
        }
        dest[i] = max_class;
      });
}

void
ClassEncoder<rapids::DeviceMemory>::threshold_inplace(
    rapids::Buffer<float>& output, std::size_t samples, float threshold) const
{
  // Perform thresholding in-place for binary classification. The element
  // must be assigned through the reference; thrust::for_each discards the
  // functor's return value.
  thrust::for_each(
      thrust::device, output.data(), output.data() + samples,
      [threshold] __device__(float& e) {
        e = (e > threshold) ? 1.0f : 0.0f;
      });
}

}}} // namespace triton::backend::NAMESPACE
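A side note on threshold_inplace above: since thrust::for_each ignores the functor's return value, the element has to be assigned through the reference. An equivalent formulation, sketched here as an alternative rather than what the PR ships, is thrust::transform writing back into the same range; with a plain functor it also avoids the need for extended device lambdas:

#include <thrust/device_vector.h>
#include <thrust/execution_policy.h>
#include <thrust/host_vector.h>
#include <thrust/transform.h>

#include <cstdio>
#include <vector>

// Functor form of the thresholding step; works without --extended-lambda.
struct threshold_op {
  float threshold;
  __host__ __device__ float operator()(float e) const
  {
    return (e > threshold) ? 1.0f : 0.0f;
  }
};

int main()
{
  std::vector<float> h{0.2f, 0.7f, 0.5f};
  thrust::device_vector<float> out(h.begin(), h.end());
  // Writing the result back to the input range makes the transform in-place.
  thrust::transform(
      thrust::device, out.begin(), out.end(), out.begin(),
      threshold_op{0.5f});
  thrust::host_vector<float> res = out;
  std::printf("%g %g %g\n", res[0], res[1], res[2]);  // prints: 0 1 0
}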