Skip to content

Commit

Permalink
Cast BOOL concats to INT32 (#620)
Browse files Browse the repository at this point in the history
Signed-off-by: Kevin Chen <[email protected]>
  • Loading branch information
kevinch-nv authored Jan 11, 2021
1 parent 21f62e9 commit 17c6d89
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 3 deletions.
18 changes: 15 additions & 3 deletions builtin_op_importers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -354,11 +354,17 @@ DEFINE_BUILTIN_OP_IMPORTER(Clip)
DEFINE_BUILTIN_OP_IMPORTER(Concat)
{
std::vector<nvinfer1::ITensor*> tensors;
// Cast boolean inputs to INT32
bool isBool = false;
for (auto& input : inputs)
{
// TRT does not support BOOL input types for this node
ASSERT(!input.isBool(), ErrorCode::kUNSUPPORTED_NODE);
tensors.push_back(&convertToTensor(input, ctx));
auto* tensorPtr = &convertToTensor(input, ctx);
if (tensorPtr->getType() == nvinfer1::DataType::kBOOL)
{
tensorPtr = castHelper(ctx, tensorPtr, nvinfer1::DataType::kINT32);
isBool = true;
}
tensors.push_back(tensorPtr);
}
OnnxAttrs attrs(node, ctx);
int axis = attrs.get<int>("axis");
Expand All @@ -368,6 +374,12 @@ DEFINE_BUILTIN_OP_IMPORTER(Concat)
ctx->registerLayer(layer, node.name());
ASSERT(layer, ErrorCode::kUNSUPPORTED_NODE);
layer->setAxis(axis);

if (isBool)
{
return {{castHelper(ctx, layer->getOutput(0), nvinfer1::DataType::kBOOL)}};
}

RETURN_FIRST_OUTPUT(layer);
}

Expand Down
7 changes: 7 additions & 0 deletions onnx2trt_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,13 @@ bool canUseLinearResize(const size_t scaleSize, const float* scaleFactors)
return true;
}

nvinfer1::ITensor* castHelper(IImporterContext* ctx, nvinfer1::ITensor* input, nvinfer1::DataType dtype)
{
nvinfer1::IIdentityLayer* cast = ctx->network()->addIdentity(*input);
cast->setOutputType(0, dtype);
return cast->getOutput(0);
}

nvinfer1::ITensor* constantOfShape(IImporterContext* ctx, const ::ONNX_NAMESPACE::NodeProto& node, nvinfer1::ITensor* constant, nvinfer1::ITensor* shape)
{
int rank = shape->getDimensions().d[0];
Expand Down
3 changes: 3 additions & 0 deletions onnx2trt_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,9 @@ Status broadcastTensors(IImporterContext* ctx, nvinfer1::ITensor*& t1, nvinfer1:
// Helper function to check that linear resize can be used
bool canUseLinearResize(const size_t scaleSize, const float* scaleFactors);

// Helper function to add a Cast layer in the network
nvinfer1::ITensor* castHelper(IImporterContext* ctx, nvinfer1::ITensor* input, nvinfer1::DataType dtype);

// Helper function for constantOfShape operator. Input shape must be a shape tensor
nvinfer1::ITensor* constantOfShape(IImporterContext* ctx, const ::ONNX_NAMESPACE::NodeProto& node, nvinfer1::ITensor* constant, nvinfer1::ITensor* shape);

Expand Down

0 comments on commit 17c6d89

Please sign in to comment.