From dd9f6fc880c6dc4acc9eaf99b7b1d2f4b094f816 Mon Sep 17 00:00:00 2001 From: Kevin Chen Date: Mon, 11 Jan 2021 11:17:54 -0800 Subject: [PATCH] Cast BOOL concats to INT32 Signed-off-by: Kevin Chen --- builtin_op_importers.cpp | 18 +++++++++++++++--- onnx2trt_utils.cpp | 7 +++++++ onnx2trt_utils.hpp | 3 +++ 3 files changed, 25 insertions(+), 3 deletions(-) diff --git a/builtin_op_importers.cpp b/builtin_op_importers.cpp index 401db0a1..899d3fc3 100644 --- a/builtin_op_importers.cpp +++ b/builtin_op_importers.cpp @@ -354,11 +354,17 @@ DEFINE_BUILTIN_OP_IMPORTER(Clip) DEFINE_BUILTIN_OP_IMPORTER(Concat) { std::vector tensors; + // Cast boolean inputs to INT32 + bool isBool = false; for (auto& input : inputs) { - // TRT does not support BOOL input types for this node - ASSERT(!input.isBool(), ErrorCode::kUNSUPPORTED_NODE); - tensors.push_back(&convertToTensor(input, ctx)); + auto* tensorPtr = &convertToTensor(input, ctx); + if (tensorPtr->getType() == nvinfer1::DataType::kBOOL) + { + tensorPtr = castHelper(ctx, tensorPtr, nvinfer1::DataType::kINT32); + isBool = true; + } + tensors.push_back(tensorPtr); } OnnxAttrs attrs(node, ctx); int axis = attrs.get("axis"); @@ -368,6 +374,12 @@ DEFINE_BUILTIN_OP_IMPORTER(Concat) ctx->registerLayer(layer, node.name()); ASSERT(layer, ErrorCode::kUNSUPPORTED_NODE); layer->setAxis(axis); + + if (isBool) + { + return {{castHelper(ctx, layer->getOutput(0), nvinfer1::DataType::kBOOL)}}; + } + RETURN_FIRST_OUTPUT(layer); } diff --git a/onnx2trt_utils.cpp b/onnx2trt_utils.cpp index 69cd90d8..491ee95e 100644 --- a/onnx2trt_utils.cpp +++ b/onnx2trt_utils.cpp @@ -152,6 +152,13 @@ bool canUseLinearResize(const size_t scaleSize, const float* scaleFactors) return true; } +nvinfer1::ITensor* castHelper(IImporterContext* ctx, nvinfer1::ITensor* input, nvinfer1::DataType dtype) +{ + nvinfer1::IIdentityLayer* cast = ctx->network()->addIdentity(*input); + cast->setOutputType(0, dtype); + return cast->getOutput(0); +} + nvinfer1::ITensor* constantOfShape(IImporterContext* ctx, const ::ONNX_NAMESPACE::NodeProto& node, nvinfer1::ITensor* constant, nvinfer1::ITensor* shape) { int rank = shape->getDimensions().d[0]; diff --git a/onnx2trt_utils.hpp b/onnx2trt_utils.hpp index c189ace7..a1fc7692 100644 --- a/onnx2trt_utils.hpp +++ b/onnx2trt_utils.hpp @@ -154,6 +154,9 @@ Status broadcastTensors(IImporterContext* ctx, nvinfer1::ITensor*& t1, nvinfer1: // Helper function to check that linear resize can be used bool canUseLinearResize(const size_t scaleSize, const float* scaleFactors); +// Helper function to add a Cast layer in the network +nvinfer1::ITensor* castHelper(IImporterContext* ctx, nvinfer1::ITensor* input, nvinfer1::DataType dtype); + // Helper function for constantOfShape operator. Input shape must be a shape tensor nvinfer1::ITensor* constantOfShape(IImporterContext* ctx, const ::ONNX_NAMESPACE::NodeProto& node, nvinfer1::ITensor* constant, nvinfer1::ITensor* shape);