From de7d9fc858410065720d268e028b61843c5b0a62 Mon Sep 17 00:00:00 2001 From: Kevin Chen Date: Mon, 11 Apr 2022 15:00:48 -0700 Subject: [PATCH] Fix GEMM importer Signed-off-by: Kevin Chen --- builtin_op_importers.cpp | 96 +++------------------------------------- trt_utils.hpp | 25 ----------- 2 files changed, 7 insertions(+), 114 deletions(-) diff --git a/builtin_op_importers.cpp b/builtin_op_importers.cpp index feea620a..69778047 100644 --- a/builtin_op_importers.cpp +++ b/builtin_op_importers.cpp @@ -1523,87 +1523,12 @@ DEFINE_BUILTIN_OP_IMPORTER(Gemm) bool transA = attrs.get("transA", false); bool transB = attrs.get("transB", false); nvinfer1::ITensor& inputA = convertToTensor(inputs.at(0), ctx); + nvinfer1::ITensor& inputB = convertToTensor(inputs.at(1), ctx); // Validate inputs - ASSERT(inputs.at(0).shape().nbDims == 2 && inputs.at(1).shape().nbDims == 2 && "GEMM must have 2D inputs!", ErrorCode::kINVALID_NODE); + ASSERT(inputA.getDimensions().nbDims == 2 && inputB.getDimensions().nbDims == 2 && "GEMM must have 2D inputs!", ErrorCode::kINVALID_NODE); // TRT does not support INT32 input types for this node - ASSERT(!inputs.at(0).isInt32() && !inputs.at(1).isInt32() - && "TensorRT doesn't support INT32 inputs for GEMM!", ErrorCode::kUNSUPPORTED_NODE); - // Use FC if it is likely to be faster - which is usually when no Shuffles are required. - bool canUseFC = inputs.at(0).is_tensor() && inputs.at(1).is_weights() && alpha == 1.f - && beta == 1.f && inputs.at(0).tensor().getDimensions().nbDims == 2 && inputs.at(1).weights().shape.nbDims == 2; - canUseFC &= inputs.size() < 3 || (inputs.at(2).is_weights() && inputs.at(2).weights().shape.nbDims == 1); - if (canUseFC) - { - LOG_VERBOSE("GEMM: using FC layer instead of MM because all criteria were met."); - const std::vector axesInput{2, 3}; - nvinfer1::ITensor* inputAExtendDim = unsqueezeTensor(ctx, node, inputA, axesInput); - - ShapedWeights weights = inputs.at(1).weights(); - if (!transB) - { - auto transposedWeights = ctx->createTempWeights(weights.type, weights.shape); - ASSERT(transposeWeights(weights, {1, 0}, &transposedWeights, ctx), ErrorCode::kUNSUPPORTED_NODE); - weights = transposedWeights; - } - ShapedWeights biases{}; - if (inputs.size() > 2) - { - biases = inputs.at(2).weights(); - } - nvinfer1::IFullyConnectedLayer* fc = ctx->network()->addFullyConnected(*inputAExtendDim, biases.shape.d[0], weights, biases); - // Register layer, along with refittable kernel weights and bias weights (if any) - ctx->registerLayer(fc, getNodeName(node)); - ctx->network()->setWeightsName(weights, weights.getName()); - if (inputs.size() == 3) - { - ctx->network()->setWeightsName(biases, inputs.at(2).weights().getName()); - } - const std::vector axesOutput{2, 3}; - return {{squeezeTensor(ctx, node, *fc->getOutput(0), axesOutput)}}; - } - - nvinfer1::ITensor* inputB {nullptr}; - - // If input B is a constant, we transpose at parse time if necessary, - // because In some cases, A * Bt is much slower than A * B. - if (inputs.at(1).is_weights()) - { - ShapedWeights weights = inputs.at(1).weights(); - if (transB) - { - auto transposedWeights = ctx->createTempWeights(weights.type, weights.shape); - ASSERT(transposeWeights(weights, {1, 0}, &transposedWeights, ctx) && "Failed to transpose input tensor B.", ErrorCode::kUNSUPPORTED_NODE); - weights = transposedWeights; - // Since we've already transposed now, we can set transpose to false. - transB = false; - } - nvinfer1::IConstantLayer* weightsLayer - = ctx->network()->addConstant(weights.shape, static_cast(weights)); - // Map the constant layer to the weights name. - ctx->registerLayer(weightsLayer, node.input(1)); - ctx->network()->setWeightsName(weights, weights.getName()); - inputB = weightsLayer->getOutput(0); - } - else - { - inputB = &inputs.at(1).tensor(); - } - - nvinfer1::ITensor* inputASqueezed = &inputA; - nvinfer1::Dims newDims = squeeze_trailing_dims(inputA.getDimensions()); - // When A has more than 2 dimensions, it needs to be flattened. - if (newDims.nbDims > 2) - { - newDims = nvinfer1::Dims{1, {-1}}; - } - // Due to other TRT layers, inputA may sometimes have trailing 1s that need to be removed. - if (newDims.nbDims < inputA.getDimensions().nbDims) - { - nvinfer1::IShuffleLayer* squeeze = ctx->network()->addShuffle(inputA); - squeeze->setReshapeDimensions(newDims); - squeeze->setZeroIsPlaceholder(false); - inputASqueezed = squeeze->getOutput(0); - } + ASSERT(!inputs.at(0).isInt32() && !inputs.at(1).isInt32() && "TensorRT doesn't support INT32 inputs for GEMM!", + ErrorCode::kUNSUPPORTED_NODE); const auto getMatrixOp = [](const nvinfer1::ITensor& input, bool transpose) { if (input.getDimensions().nbDims == 1) @@ -1617,13 +1542,12 @@ DEFINE_BUILTIN_OP_IMPORTER(Gemm) return nvinfer1::MatrixOperation::kNONE; }; - nvinfer1::MatrixOperation opA = getMatrixOp(*inputASqueezed, transA); - nvinfer1::MatrixOperation opB = getMatrixOp(*inputB, transB); + nvinfer1::MatrixOperation opA = getMatrixOp(inputA, transA); + nvinfer1::MatrixOperation opB = getMatrixOp(inputB, transB); LOG_VERBOSE("Using opA: " << static_cast(opA) << " opB: " << static_cast(opB)); - LOG_VERBOSE("GEMM: A, after squeezing: " << inputASqueezed->getDimensions()); - nvinfer1::IMatrixMultiplyLayer* matmul = ctx->network()->addMatrixMultiply(*inputASqueezed, opA, *inputB, opB); + nvinfer1::IMatrixMultiplyLayer* matmul = ctx->network()->addMatrixMultiply(inputA, opA, inputB, opB); ctx->registerLayer(matmul, getNodeName(node)); nvinfer1::ITensor* matmulTensor = matmul->getOutput(0); @@ -1655,12 +1579,6 @@ DEFINE_BUILTIN_OP_IMPORTER(Gemm) *betaConstantTensor, *biasTensor, nvinfer1::ElementWiseOperation::kPROD); biasTensor = scaledBias->getOutput(0); } - // A*B may be lower rank than C in TRT, so need to squeeze C. - if (ctx->getOpsetVersion() < 7 && !attrs.get("broadcast", false)) - { - nvinfer1::Dims squeezeDims = squeeze_leading_dims(biasTensor->getDimensions()); - biasTensor = reshapeTensor(ctx, *biasTensor, squeezeDims); - } CHECK(broadcastTensors(ctx, matmulTensor, biasTensor)); nvinfer1::IElementWiseLayer* biasAdd = ctx->network()->addElementWise(*matmulTensor, *biasTensor, nvinfer1::ElementWiseOperation::kSUM); diff --git a/trt_utils.hpp b/trt_utils.hpp index 25ad0a7c..1d3a301d 100644 --- a/trt_utils.hpp +++ b/trt_utils.hpp @@ -102,31 +102,6 @@ inline nvinfer1::Permutation remove_first_dim(nvinfer1::Permutation const& perm) return new_perm; } -inline nvinfer1::Dims squeeze_trailing_dims(nvinfer1::Dims const& dims) -{ - nvinfer1::Dims new_dims = dims; - // Note: TRT requires at least one dimension, so we don't squeeze [1]->[] - while (new_dims.nbDims > 1 && new_dims.d[new_dims.nbDims - 1] == 1) - { - --new_dims.nbDims; - } - return new_dims; -} - -inline nvinfer1::Dims squeeze_leading_dims(const nvinfer1::Dims& dims) -{ - nvinfer1::Dims newDims; - // Copy dims only if a non-1 has been seen already. - bool non1Seen{false}; - newDims.nbDims = std::copy_if(dims.d, dims.d + dims.nbDims, newDims.d, - [&non1Seen](int x) { - non1Seen = (x != 1) ? true : non1Seen; - return non1Seen; - }) - - newDims.d; - return newDims; -} - inline nvinfer1::DimsHW operator-(nvinfer1::DimsHW dims) { return nvinfer1::DimsHW(-dims.h(), -dims.w());