[BYOC][TensorRT] Add TensorRT own int8 calibration support to TensorRT BYOC integration #8808

Merged on Sep 15, 2021 (60 commits; the changes shown below are from all commits)

Commits
2375c0b  update trt (Aug 20, 2021)
66c5fe2  Merge https://github.com/apache/tvm into pr_trt_int8 (Aug 23, 2021)
9113659  clean codes (Aug 24, 2021)
6671366  tetsing running trt (Aug 24, 2021)
99c0a57  clean data (Aug 24, 2021)
525af93  clean codes? (Aug 26, 2021)
0eda372  remove env func (Aug 26, 2021)
7ec0586  fix num_bings (Aug 26, 2021)
a39a5a1  add buildfromjson func (Aug 26, 2021)
eada412  change condition (Aug 27, 2021)
0a22eff  reset input and output func (Aug 28, 2021)
3f9fec2  re-config func (Aug 30, 2021)
7f7343c  re-added trt version check (Aug 30, 2021)
c77b433  checking sanity (Aug 30, 2021)
8566cc6  try to fix sanity issue (Aug 30, 2021)
7a1f3ff  checking sainity (Aug 30, 2021)
8914349  fixing sanity issue (Aug 30, 2021)
cb8fe8f  fixing sainity issue (Aug 30, 2021)
6aa6051  fixing sanity (Aug 30, 2021)
f51ba11  clang format fixed (Aug 30, 2021)
19e151b  clang format fixing (Aug 30, 2021)
3b6ef10  clean trt cali (Aug 30, 2021)
ecb43e0  try to fix clang format (Aug 30, 2021)
17bb566  fixed some comments (Sep 1, 2021)
dbd1594  remove double destroy engine codes (Sep 2, 2021)
411504f  modify comments (Sep 2, 2021)
9ec455e  add checking function (Sep 2, 2021)
55ead8b  add trt int8 test (Sep 7, 2021)
c613c45  update trt int8 test file (Sep 7, 2021)
0afbcc7  Update test_tensorrt_int8_exp.py (tiandiao123, Sep 7, 2021)
2640ab3  update trt int8 fikle (Sep 8, 2021)
06b9f77  update trt int8 test file (Sep 8, 2021)
2e4293a  change a little (Sep 8, 2021)
7282762  upate trt int8 file (Sep 8, 2021)
5645dad  upate trt int8 file (Sep 8, 2021)
9756f7b  fixing ci (Sep 8, 2021)
fddbd43  fixing ci (Sep 8, 2021)
4b56ac8  fixing ci (Sep 8, 2021)
06797a0  fixing ci (Sep 8, 2021)
947d22d  fixing ci (Sep 8, 2021)
d77dda0  fixing ci issue (Sep 8, 2021)
9a515f8  fixing ci issue (Sep 8, 2021)
53802b5  fixing ci (Sep 8, 2021)
ccb74c7  fixing ci issue (Sep 8, 2021)
c1f0faf  fixing ci (Sep 8, 2021)
81761b4  fixing ci problem (Sep 8, 2021)
bf30e8e  fixing ci (Sep 8, 2021)
63512ad  upate trt python int8 test file (Sep 8, 2021)
d3ac8c9  fixed ci (Sep 8, 2021)
ac58979  fixed ci (Sep 8, 2021)
90eabe1  fix gpu build (Sep 8, 2021)
89c8eeb  fixed ci (Sep 8, 2021)
30648d6  update trt int8 test file (Sep 8, 2021)
3c85b9a  fix bug (Sep 8, 2021)
41be39a  fix bug (Sep 8, 2021)
c7d2bcc  update trtint8 file (Sep 8, 2021)
2d8ce42  reformat (Sep 8, 2021)
63cf965  update trt int8 file (Sep 8, 2021)
5685800  update (Sep 8, 2021)
c0e931d  modify (Sep 8, 2021)
19 changes: 17 additions & 2 deletions src/runtime/contrib/tensorrt/tensorrt_builder.cc
@@ -40,14 +40,16 @@ namespace contrib {
 TensorRTBuilder::TensorRTBuilder(TensorRTLogger* logger,
                                  const std::vector<const DLTensor*>& data_entry,
                                  size_t max_workspace_size, bool use_implicit_batch, bool use_fp16,
-                                 int batch_size)
+                                 int batch_size, nvinfer1::IInt8Calibrator* calibrator)
     : data_entry_(data_entry),
       max_workspace_size_(max_workspace_size),
       use_implicit_batch_(use_implicit_batch),
       use_fp16_(use_fp16),
       batch_size_(batch_size) {
   // Create TRT builder and network.
   builder_ = nvinfer1::createInferBuilder(*logger);
+  use_int8_ = false;
+
 #if TRT_VERSION_GE(6, 0, 1)
   // Use INetworkV2.
   auto flags =
@@ -56,9 +58,15 @@ TensorRTBuilder::TensorRTBuilder(TensorRTLogger* logger,
     flags = 0U;
     builder_->setMaxBatchSize(batch_size_);
   }
+  this->calibrator_ = calibrator;
+  if (calibrator != nullptr) {
+    use_int8_ = true;
+    builder_->setFp16Mode(true);
+    builder_->setInt8Mode(true);
+    builder_->setInt8Calibrator(calibrator);
+  }
   network_ = builder_->createNetworkV2(flags);
 #else
   // Use INetwork with implicit batch.
   builder_->setMaxBatchSize(batch_size_);
   builder_->setMaxWorkspaceSize(max_workspace_size_);
   builder_->setFp16Mode(use_fp16_);
@@ -158,6 +166,13 @@ TensorRTEngineAndContext TensorRTBuilder::BuildEngine() {
   if (use_fp16_) {
     config_->setFlag(nvinfer1::BuilderFlag::kFP16);
   }
+
+  if (use_int8_) {
+    config_->setFlag(nvinfer1::BuilderFlag::kINT8);
+    config_->setInt8Calibrator(calibrator_);
+    LOG(INFO) << "config finishes setting up calibrator as INT8 mode ... ";
+  }
+
   // Add profiles.
   if (!use_implicit_batch_) {
     auto profile = builder_->createOptimizationProfile();
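Note: for readers unfamiliar with the TensorRT build API used in these hunks, the following standalone sketch (not part of the PR; the function name and overall structure are illustrative) shows the TensorRT 6+ int8 build pattern the diff follows:

#include "NvInfer.h"

// Minimal sketch of the int8 build path, assuming the caller supplies the
// logger, the calibrator, and the code that populates the network.
nvinfer1::ICudaEngine* BuildInt8Engine(nvinfer1::ILogger& logger,
                                       nvinfer1::IInt8Calibrator* calibrator) {
  nvinfer1::IBuilder* builder = nvinfer1::createInferBuilder(logger);
  nvinfer1::INetworkDefinition* network = builder->createNetworkV2(0U);  // implicit batch
  nvinfer1::IBuilderConfig* config = builder->createBuilderConfig();
  // ... populate `network` with layers and mark its outputs here ...
  config->setFlag(nvinfer1::BuilderFlag::kFP16);  // layers without int8 kernels may fall back
  config->setFlag(nvinfer1::BuilderFlag::kINT8);
  config->setInt8Calibrator(calibrator);  // TensorRT pulls calibration batches from this object
  return builder->buildEngineWithConfig(*network, *config);
}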
11 changes: 9 additions & 2 deletions src/runtime/contrib/tensorrt/tensorrt_builder.h
@@ -72,8 +72,8 @@ class TensorRTBuilder {
    * \param batch_size If use_implicit_batch,
    */
   TensorRTBuilder(TensorRTLogger* logger, const std::vector<const DLTensor*>& data_entry,
-                  size_t max_workspace_size, bool use_implicit_batch, bool use_fp16,
-                  int batch_size);
+                  size_t max_workspace_size, bool use_implicit_batch, bool use_fp16, int batch_size,
+                  nvinfer1::IInt8Calibrator* calibrator = nullptr);
 
   /*!
    * \brief Add TensorRT input(s) for input node in network definition.
@@ -153,6 +153,9 @@ class TensorRTBuilder {
   /*! \brief Whether to automatically convert model to 16-bit floating point precision. */
   bool use_fp16_;
 
+  /*! \brief whether to automatically convert model to int8 precision */
+  bool use_int8_;
+
   /*! \brief Batch size to optimize for. */
   int batch_size_;
 
[Inline review comment on use_int8_ (Contributor; marked resolved by jcf94): "IIUC, use_int8_ is exclusive with use_fp16_? If so, we should combine them to be a single variable like target_dtype."]

@@ -161,6 +164,10 @@ class TensorRTBuilder {
 
   /*! \brief Output names. */
   std::vector<std::string> network_output_names_;
+
+  /*! \brief calibrator pointer to add batch data when using int8 mode */
+  /*! \brief pointer will be nullptr when it is fp16 or fp32 precision */
+  nvinfer1::IInt8Calibrator* calibrator_;
 };
 
 }  // namespace contrib
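Since the new calibrator parameter defaults to nullptr, existing fp32/fp16 call sites are unaffected; only callers that pass a calibrator opt in to int8. A hedged sketch of the two call shapes (variable names are illustrative, not taken from the PR):

// Illustrative only; logger, data_entry, input_names, etc. are assumed to be in scope.
tvm::runtime::TensorRTCalibrator calibrator(batch_size, input_names);

// int8 path: passing a calibrator switches the builder into int8 mode.
tvm::runtime::contrib::TensorRTBuilder int8_builder(&logger, data_entry, max_workspace_size,
                                                    /*use_implicit_batch=*/true,
                                                    /*use_fp16=*/false, batch_size, &calibrator);

// fp32/fp16 path: omitting the trailing argument keeps the old behavior.
tvm::runtime::contrib::TensorRTBuilder fp32_builder(&logger, data_entry, max_workspace_size,
                                                    /*use_implicit_batch=*/true,
                                                    /*use_fp16=*/false, batch_size);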
130 changes: 130 additions & 0 deletions src/runtime/contrib/tensorrt/tensorrt_calibrator.h
@@ -0,0 +1,130 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file runtime/contrib/tensorrt/tensorrt_calibrator.h
 * \brief Contains the TensorRTCalibrator class, which supplies calibration
 * batches to TensorRT when building an int8 engine.
 */

#ifndef TVM_RUNTIME_CONTRIB_TENSORRT_TENSORRT_CALIBRATOR_H_
#define TVM_RUNTIME_CONTRIB_TENSORRT_TENSORRT_CALIBRATOR_H_

#include <string>
#include <vector>

#include "../../cuda/cuda_common.h"
#include "NvInfer.h"

namespace tvm {
namespace runtime {

class TensorRTCalibrator : public nvinfer1::IInt8EntropyCalibrator2 {
 public:
  TensorRTCalibrator(int batch_size, const std::vector<std::string>& input_names)
      : batch_size_(batch_size), num_batches_calibrated_(0), input_names_(input_names) {}

  ~TensorRTCalibrator() {
    // Free calibration data
    for (auto& inputs : data_) {
      for (size_t i = 0; i < inputs.size(); ++i) {
        delete[] inputs[i];
      }
    }
    // Free buffers
    for (size_t i = 0; i < buffers_.size(); ++i) {
      CUDA_CALL(cudaFree(buffers_[i]));
    }
  }

  void AddBatchData(const std::vector<void*>& bindings, const std::vector<size_t>& binding_sizes) {
    // Copy data from GPU
    std::vector<float*> data_host(bindings.size(), nullptr);
    for (size_t i = 0; i < bindings.size(); ++i) {
      data_host[i] = new float[batch_size_ * binding_sizes[i]];
      CUDA_CALL(cudaMemcpy(static_cast<void*>(data_host[i]), bindings[i],
                           batch_size_ * binding_sizes[i] * sizeof(float), cudaMemcpyDeviceToHost));
    }
    data_.push_back(data_host);
    data_sizes_.push_back(binding_sizes);
  }

  int getBatchSize() const override { return batch_size_; }

  /*!
   * \brief TensorRT will call this method to get next batch of data to
   * calibrate with.
   */
  bool getBatch(void* bindings[], const char* names[], int nbBindings) override {
    AllocateBuffersIfNotAllocated();
    CHECK_EQ(input_names_.size(), nbBindings);
    for (size_t i = 0; i < input_names_.size(); ++i) {
      CHECK_EQ(input_names_[i], names[i]);
      CUDA_CALL(cudaMemcpy(buffers_[i], data_[num_batches_calibrated_][i],
                           batch_size_ * data_sizes_[num_batches_calibrated_][i] * sizeof(float),
                           cudaMemcpyHostToDevice));
      bindings[i] = buffers_[i];
    }
    num_batches_calibrated_++;
    // TODO(trevmorr): Free data from previous batch?
    return (num_batches_calibrated_ < data_.size());
  }

  const void* readCalibrationCache(size_t& length) override {
    if (calibration_cache_.empty()) return nullptr;
    length = calibration_cache_.size();
    return calibration_cache_.data();
  }

  void writeCalibrationCache(const void* cache, size_t length) override {
    calibration_cache_.assign(static_cast<const char*>(cache), length);
  }

 private:
  /*! \brief Batch size. */
  int batch_size_;
  /*! \brief Number of batches already fed to calibrator. */
  int num_batches_calibrated_;
  /*! \brief Storage for calibration cache. */
  std::string calibration_cache_;

  /*! \brief Data to be used for calibration. */
  std::vector<std::vector<float*>> data_;
  /*! \brief Number of elements for data to be used for calibration. */
  std::vector<std::vector<size_t>> data_sizes_;

  /*! \brief Device buffers to be used for calibration. */
  std::vector<void*> buffers_;

  /*! \brief Names of inputs */
  const std::vector<std::string> input_names_;

  /*! \brief Allocate device memory buffers. data_sizes_ must already have one
   * entry. */
  void AllocateBuffersIfNotAllocated() {
    if (!buffers_.empty()) return;
    CHECK_GE(data_sizes_.size(), 1);
    const int num_inputs = data_sizes_[0].size();
    buffers_.assign(num_inputs, nullptr);
    for (int i = 0; i < num_inputs; ++i) {
      CUDA_CALL(cudaMalloc(&buffers_[i], data_sizes_[0][i] * sizeof(float)));
    }
  }
};

}  // namespace runtime
}  // namespace tvm
#endif  // TVM_RUNTIME_CONTRIB_TENSORRT_TENSORRT_CALIBRATOR_H_
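Taken together with the builder changes above, the intended calibration flow appears to be: run several batches through a non-int8 engine first, snapshot each batch's input buffers with AddBatchData, then rebuild the engine with the calibrator attached so that TensorRT can drain the stored batches through getBatch and derive the int8 scaling factors. A sketch under those assumptions (buffer names, shapes, and the batch loop are invented for illustration):

// d_input is assumed to be a device buffer already filled with one batch of real input data.
const int batch_size = 8;
tvm::runtime::TensorRTCalibrator calibrator(batch_size, /*input_names=*/{"input_0"});

std::vector<void*> bindings = {d_input};      // one device pointer per network input
std::vector<size_t> sizes = {3 * 224 * 224};  // elements per instance for each input

for (int i = 0; i < 10; ++i) {
  // ... run one fp32 inference so d_input holds representative activations ...
  calibrator.AddBatchData(bindings, sizes);   // copies this batch from device to host storage
}
// During the int8 engine build, TensorRT repeatedly calls calibrator.getBatch(),
// which copies each stored batch back to device memory; calibration ends once
// all stored batches have been consumed.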