Skip to content

Commit

Permalink
fix(tensorrt): update tensorrt code of tensorrt_yolo
Browse files Browse the repository at this point in the history
Signed-off-by: M. Fatih Cırıt <[email protected]>
  • Loading branch information
xmfcx committed Nov 18, 2022
1 parent 25c1836 commit d756050
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 12 deletions.
4 changes: 3 additions & 1 deletion perception/tensorrt_yolo/lib/include/trt_yolo.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,13 +138,15 @@ class Net
cuda::unique_ptr<float[]> out_scores_d_ = nullptr;
cuda::unique_ptr<float[]> out_boxes_d_ = nullptr;
cuda::unique_ptr<float[]> out_classes_d_ = nullptr;
std::string name_tensor_in_;
std::string name_tensor_out_;

void load(const std::string & path);
bool prepare();
std::vector<float> preprocess(
const cv::Mat & in_img, const int c, const int h, const int w) const;
// Infer using pre-allocated GPU buffers {data, scores, boxes}
void infer(std::vector<void *> & buffers, const int batch_size);
void infer(const int batch_size);
};

bool set_cuda_device(int gpu_id)
Expand Down
26 changes: 15 additions & 11 deletions perception/tensorrt_yolo/lib/src/trt_yolo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,8 @@ Net::Net(const std::string & path, bool verbose)
std::cout << "Fail to prepare engine" << std::endl;
return;
}
name_tensor_in_ = engine_->getIOTensorName(0);
name_tensor_out_ = engine_->getIOTensorName(engine_->getNbIOTensors() - 1);
}

Net::~Net()
Expand Down Expand Up @@ -267,6 +269,8 @@ Net::Net(
std::cout << "Fail to create engine" << std::endl;
return;
}
name_tensor_in_ = engine_->getIOTensorName(0);
name_tensor_out_ = engine_->getIOTensorName(engine_->getNbIOTensors() - 1);
}

void Net::save(const std::string & path) const
Expand All @@ -276,15 +280,16 @@ void Net::save(const std::string & path) const
file.write(reinterpret_cast<const char *>(plan_->data()), plan_->size());
}

// Runs inference for `batch_size` images on stream_ and waits for the stream to finish.
// Throws std::runtime_error when context_ is null.
//
// NOTE: this is a rendered diff without +/- markers, so both versions of the function are
// shown together. The lines using binding indices (two-argument signature,
// getBindingDimensions, setBindingDimensions, enqueueV2) appear to be the removed version;
// the lines using name_tensor_in_ (one-argument signature, getTensorShape, setInputShape,
// enqueueV3) appear to be the added version.
void Net::infer(std::vector<void *> & buffers, const int batch_size)
void Net::infer(const int batch_size)
{
if (!context_) {
throw std::runtime_error("Fail to create context");
}
// Binding-index version: take the engine's input dimensions, replace dimension 0 with
// batch_size, and enqueue using the caller-supplied device buffer pointers.
auto input_dims = engine_->getBindingDimensions(0);
context_->setBindingDimensions(
0, nvinfer1::Dims4(batch_size, input_dims.d[1], input_dims.d[2], input_dims.d[3]));
context_->enqueueV2(buffers.data(), stream_, nullptr);
// Tensor-name version: same shape override, addressed by the input tensor's name.
// NOTE(review): the results of setInputShape and enqueueV3 are not checked here, so a
// rejected shape or a failed enqueue would go unnoticed - confirm this is intended.
// NOTE(review): enqueueV3 receives no buffer pointers; presumably the device addresses are
// registered on the context elsewhere (not visible in this diff) - verify.
const auto input_dims = engine_->getTensorShape(name_tensor_in_.c_str());
context_->setInputShape(
name_tensor_in_.c_str(),
nvinfer1::Dims4(batch_size, input_dims.d[1], input_dims.d[2], input_dims.d[3]));
context_->enqueueV3(stream_);
// Block the host until all work queued on stream_ has completed.
cudaStreamSynchronize(stream_);
}

Expand All @@ -294,10 +299,8 @@ bool Net::detect(const cv::Mat & in_img, float * out_scores, float * out_boxes,
const auto input = preprocess(in_img, input_dims.at(0), input_dims.at(2), input_dims.at(1));
CHECK_CUDA_ERROR(
cudaMemcpy(input_d_.get(), input.data(), input.size() * sizeof(float), cudaMemcpyHostToDevice));
std::vector<void *> buffers = {
input_d_.get(), out_scores_d_.get(), out_boxes_d_.get(), out_classes_d_.get()};
try {
infer(buffers, 1);
infer(1);
} catch (const std::runtime_error & e) {
return false;
}
Expand All @@ -316,13 +319,14 @@ bool Net::detect(const cv::Mat & in_img, float * out_scores, float * out_boxes,

// Returns dimensions 1, 2 and 3 of the engine's input tensor; dimension 0 is left out
// (it is the one infer() overrides with the batch size).
// NOTE: rendered diff without +/- markers - the getBindingDimensions(0) line appears to be the
// removed version and the getTensorShape(name_tensor_in_) line its replacement.
std::vector<int> Net::getInputDims() const
{
auto dims = engine_->getBindingDimensions(0);
const auto dims = engine_->getTensorShape(name_tensor_in_.c_str());
return {dims.d[1], dims.d[2], dims.d[3]};
}

// Returns dimension 0 of the kMAX shape that optimization profile 0 records for the input.
// NOTE: rendered diff without +/- markers - the getProfileDimensions(0, 0, ...) return appears
// to be the removed version and the getProfileShape(name_tensor_in_, 0, ...) return its
// replacement.
int Net::getMaxBatchSize() const
{
return engine_->getProfileDimensions(0, 0, nvinfer1::OptProfileSelector::kMAX).d[0];
return engine_->getProfileShape(name_tensor_in_.c_str(), 0, nvinfer1::OptProfileSelector::kMAX)
.d[0];
}

int Net::getInputSize() const
Expand All @@ -333,6 +337,6 @@ int Net::getInputSize() const
return input_size;
}

// Returns dimension 1 of an output tensor's shape.
// NOTE: rendered diff without +/- markers - the first definition (binding index 1) appears to be
// the removed version and the second (name_tensor_out_) its replacement.
// NOTE(review): name_tensor_out_ is assigned from the last I/O tensor in the constructors,
// whereas the earlier code queried binding 1 - confirm both tensors share the same d[1].
int Net::getMaxDetections() const { return engine_->getBindingDimensions(1).d[1]; }
int Net::getMaxDetections() const { return engine_->getTensorShape(name_tensor_out_.c_str()).d[1]; }

} // namespace yolo

0 comments on commit d756050

Please sign in to comment.