Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cast BOOL concats to INT32 #620

Merged
merged 1 commit into from
Jan 11, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions builtin_op_importers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -354,11 +354,17 @@ DEFINE_BUILTIN_OP_IMPORTER(Clip)
DEFINE_BUILTIN_OP_IMPORTER(Concat)
{
std::vector<nvinfer1::ITensor*> tensors;
// Cast boolean inputs to INT32
bool isBool = false;
for (auto& input : inputs)
{
// TRT does not support BOOL input types for this node
ASSERT(!input.isBool(), ErrorCode::kUNSUPPORTED_NODE);
tensors.push_back(&convertToTensor(input, ctx));
auto* tensorPtr = &convertToTensor(input, ctx);
if (tensorPtr->getType() == nvinfer1::DataType::kBOOL)
{
tensorPtr = castHelper(ctx, tensorPtr, nvinfer1::DataType::kINT32);
isBool = true;
}
tensors.push_back(tensorPtr);
}
OnnxAttrs attrs(node, ctx);
int axis = attrs.get<int>("axis");
Expand All @@ -368,6 +374,12 @@ DEFINE_BUILTIN_OP_IMPORTER(Concat)
ctx->registerLayer(layer, node.name());
ASSERT(layer, ErrorCode::kUNSUPPORTED_NODE);
layer->setAxis(axis);

if (isBool)
{
return {{castHelper(ctx, layer->getOutput(0), nvinfer1::DataType::kBOOL)}};
}

RETURN_FIRST_OUTPUT(layer);
}

Expand Down
7 changes: 7 additions & 0 deletions onnx2trt_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,13 @@ bool canUseLinearResize(const size_t scaleSize, const float* scaleFactors)
return true;
}

nvinfer1::ITensor* castHelper(IImporterContext* ctx, nvinfer1::ITensor* input, nvinfer1::DataType dtype)
{
nvinfer1::IIdentityLayer* cast = ctx->network()->addIdentity(*input);
cast->setOutputType(0, dtype);
return cast->getOutput(0);
}

nvinfer1::ITensor* constantOfShape(IImporterContext* ctx, const ::ONNX_NAMESPACE::NodeProto& node, nvinfer1::ITensor* constant, nvinfer1::ITensor* shape)
{
int rank = shape->getDimensions().d[0];
Expand Down
3 changes: 3 additions & 0 deletions onnx2trt_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,9 @@ Status broadcastTensors(IImporterContext* ctx, nvinfer1::ITensor*& t1, nvinfer1:
// Helper function to check that linear resize can be used
bool canUseLinearResize(const size_t scaleSize, const float* scaleFactors);

// Helper function to add a Cast layer in the network
nvinfer1::ITensor* castHelper(IImporterContext* ctx, nvinfer1::ITensor* input, nvinfer1::DataType dtype);

// Helper function for constantOfShape operator. Input shape must be a shape tensor
nvinfer1::ITensor* constantOfShape(IImporterContext* ctx, const ::ONNX_NAMESPACE::NodeProto& node, nvinfer1::ITensor* constant, nvinfer1::ITensor* shape);

Expand Down