Fix GEMM importer #828

Merged
merged 1 commit into from Apr 11, 2022
96 changes: 7 additions & 89 deletions builtin_op_importers.cpp
@@ -1523,87 +1523,12 @@ DEFINE_BUILTIN_OP_IMPORTER(Gemm)
     bool transA = attrs.get("transA", false);
     bool transB = attrs.get("transB", false);
     nvinfer1::ITensor& inputA = convertToTensor(inputs.at(0), ctx);
+    nvinfer1::ITensor& inputB = convertToTensor(inputs.at(1), ctx);
     // Validate inputs
-    ASSERT(inputs.at(0).shape().nbDims == 2 && inputs.at(1).shape().nbDims == 2 && "GEMM must have 2D inputs!", ErrorCode::kINVALID_NODE);
+    ASSERT(inputA.getDimensions().nbDims == 2 && inputB.getDimensions().nbDims == 2 && "GEMM must have 2D inputs!", ErrorCode::kINVALID_NODE);
     // TRT does not support INT32 input types for this node
-    ASSERT(!inputs.at(0).isInt32() && !inputs.at(1).isInt32()
-        && "TensorRT doesn't support INT32 inputs for GEMM!", ErrorCode::kUNSUPPORTED_NODE);
-    // Use FC if it is likely to be faster - which is usually when no Shuffles are required.
-    bool canUseFC = inputs.at(0).is_tensor() && inputs.at(1).is_weights() && alpha == 1.f
-        && beta == 1.f && inputs.at(0).tensor().getDimensions().nbDims == 2 && inputs.at(1).weights().shape.nbDims == 2;
-    canUseFC &= inputs.size() < 3 || (inputs.at(2).is_weights() && inputs.at(2).weights().shape.nbDims == 1);
-    if (canUseFC)
-    {
-        LOG_VERBOSE("GEMM: using FC layer instead of MM because all criteria were met.");
-        const std::vector<int> axesInput{2, 3};
-        nvinfer1::ITensor* inputAExtendDim = unsqueezeTensor(ctx, node, inputA, axesInput);
-
-        ShapedWeights weights = inputs.at(1).weights();
-        if (!transB)
-        {
-            auto transposedWeights = ctx->createTempWeights(weights.type, weights.shape);
-            ASSERT(transposeWeights(weights, {1, 0}, &transposedWeights, ctx), ErrorCode::kUNSUPPORTED_NODE);
-            weights = transposedWeights;
-        }
-        ShapedWeights biases{};
-        if (inputs.size() > 2)
-        {
-            biases = inputs.at(2).weights();
-        }
-        nvinfer1::IFullyConnectedLayer* fc = ctx->network()->addFullyConnected(*inputAExtendDim, biases.shape.d[0], weights, biases);
-        // Register layer, along with refittable kernel weights and bias weights (if any)
-        ctx->registerLayer(fc, getNodeName(node));
-        ctx->network()->setWeightsName(weights, weights.getName());
-        if (inputs.size() == 3)
-        {
-            ctx->network()->setWeightsName(biases, inputs.at(2).weights().getName());
-        }
-        const std::vector<int> axesOutput{2, 3};
-        return {{squeezeTensor(ctx, node, *fc->getOutput(0), axesOutput)}};
-    }
-
-    nvinfer1::ITensor* inputB {nullptr};
-
-    // If input B is a constant, we transpose at parse time if necessary,
-    // because In some cases, A * Bt is much slower than A * B.
-    if (inputs.at(1).is_weights())
-    {
-        ShapedWeights weights = inputs.at(1).weights();
-        if (transB)
-        {
-            auto transposedWeights = ctx->createTempWeights(weights.type, weights.shape);
-            ASSERT(transposeWeights(weights, {1, 0}, &transposedWeights, ctx) && "Failed to transpose input tensor B.", ErrorCode::kUNSUPPORTED_NODE);
-            weights = transposedWeights;
-            // Since we've already transposed now, we can set transpose to false.
-            transB = false;
-        }
-        nvinfer1::IConstantLayer* weightsLayer
-            = ctx->network()->addConstant(weights.shape, static_cast<nvinfer1::Weights>(weights));
-        // Map the constant layer to the weights name.
-        ctx->registerLayer(weightsLayer, node.input(1));
-        ctx->network()->setWeightsName(weights, weights.getName());
-        inputB = weightsLayer->getOutput(0);
-    }
-    else
-    {
-        inputB = &inputs.at(1).tensor();
-    }
-
-    nvinfer1::ITensor* inputASqueezed = &inputA;
-    nvinfer1::Dims newDims = squeeze_trailing_dims(inputA.getDimensions());
-    // When A has more than 2 dimensions, it needs to be flattened.
-    if (newDims.nbDims > 2)
-    {
-        newDims = nvinfer1::Dims{1, {-1}};
-    }
-    // Due to other TRT layers, inputA may sometimes have trailing 1s that need to be removed.
-    if (newDims.nbDims < inputA.getDimensions().nbDims)
-    {
-        nvinfer1::IShuffleLayer* squeeze = ctx->network()->addShuffle(inputA);
-        squeeze->setReshapeDimensions(newDims);
-        squeeze->setZeroIsPlaceholder(false);
-        inputASqueezed = squeeze->getOutput(0);
-    }
+    ASSERT(!inputs.at(0).isInt32() && !inputs.at(1).isInt32() && "TensorRT doesn't support INT32 inputs for GEMM!",
+        ErrorCode::kUNSUPPORTED_NODE);

     const auto getMatrixOp = [](const nvinfer1::ITensor& input, bool transpose) {
         if (input.getDimensions().nbDims == 1)
@@ -1617,13 +1542,12 @@ DEFINE_BUILTIN_OP_IMPORTER(Gemm)
         return nvinfer1::MatrixOperation::kNONE;
     };

-    nvinfer1::MatrixOperation opA = getMatrixOp(*inputASqueezed, transA);
-    nvinfer1::MatrixOperation opB = getMatrixOp(*inputB, transB);
+    nvinfer1::MatrixOperation opA = getMatrixOp(inputA, transA);
+    nvinfer1::MatrixOperation opB = getMatrixOp(inputB, transB);

     LOG_VERBOSE("Using opA: " << static_cast<int>(opA) << " opB: " << static_cast<int>(opB));
-    LOG_VERBOSE("GEMM: A, after squeezing: " << inputASqueezed->getDimensions());

-    nvinfer1::IMatrixMultiplyLayer* matmul = ctx->network()->addMatrixMultiply(*inputASqueezed, opA, *inputB, opB);
+    nvinfer1::IMatrixMultiplyLayer* matmul = ctx->network()->addMatrixMultiply(inputA, opA, inputB, opB);
     ctx->registerLayer(matmul, getNodeName(node));
     nvinfer1::ITensor* matmulTensor = matmul->getOutput(0);

@@ -1655,12 +1579,6 @@ DEFINE_BUILTIN_OP_IMPORTER(Gemm)
                 *betaConstantTensor, *biasTensor, nvinfer1::ElementWiseOperation::kPROD);
             biasTensor = scaledBias->getOutput(0);
         }
-        // A*B may be lower rank than C in TRT, so need to squeeze C.
-        if (ctx->getOpsetVersion() < 7 && !attrs.get("broadcast", false))
-        {
-            nvinfer1::Dims squeezeDims = squeeze_leading_dims(biasTensor->getDimensions());
-            biasTensor = reshapeTensor(ctx, *biasTensor, squeezeDims);
-        }
         CHECK(broadcastTensors(ctx, matmulTensor, biasTensor));
         nvinfer1::IElementWiseLayer* biasAdd
             = ctx->network()->addElementWise(*matmulTensor, *biasTensor, nvinfer1::ElementWiseOperation::kSUM);
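In ONNX terms, Gemm computes Y = alpha * op(A) * op(B) + beta * C, with op() selecting a transpose via transA/transB. The net effect of this file's diff is that the importer now hands A and B directly to TensorRT's matrix-multiply layer and expresses any transpose through the matrix-operation flags, instead of special-casing IFullyConnectedLayer and squeezing/flattening A first. A minimal sketch of that lowering is below; buildGemmMatMul is an illustrative helper (not part of the parser), only the nvinfer1 calls are real API, and it omits the 1-D kVECTOR handling that the importer's getMatrixOp lambda adds:

#include <NvInfer.h>

// Sketch of the simplified GEMM lowering: feed both operands to IMatrixMultiplyLayer
// and let MatrixOperation express the transposes.
nvinfer1::ITensor* buildGemmMatMul(nvinfer1::INetworkDefinition& network,
    nvinfer1::ITensor& inputA, nvinfer1::ITensor& inputB, bool transA, bool transB)
{
    const auto opFor = [](bool transpose) {
        return transpose ? nvinfer1::MatrixOperation::kTRANSPOSE : nvinfer1::MatrixOperation::kNONE;
    };
    // Both operands go in as-is: no FC fast path, no parse-time weight transpose,
    // and no squeeze/flatten shuffle on A.
    nvinfer1::IMatrixMultiplyLayer* matmul
        = network.addMatrixMultiply(inputA, opFor(transA), inputB, opFor(transB));
    return matmul->getOutput(0);
}

The alpha and beta scaling still happens afterwards with constant and elementwise layers, as in the unchanged tail of the importer above.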
25 changes: 0 additions & 25 deletions trt_utils.hpp
@@ -102,31 +102,6 @@ inline nvinfer1::Permutation remove_first_dim(nvinfer1::Permutation const& perm)
     return new_perm;
 }

-inline nvinfer1::Dims squeeze_trailing_dims(nvinfer1::Dims const& dims)
-{
-    nvinfer1::Dims new_dims = dims;
-    // Note: TRT requires at least one dimension, so we don't squeeze [1]->[]
-    while (new_dims.nbDims > 1 && new_dims.d[new_dims.nbDims - 1] == 1)
-    {
-        --new_dims.nbDims;
-    }
-    return new_dims;
-}
-
-inline nvinfer1::Dims squeeze_leading_dims(const nvinfer1::Dims& dims)
-{
-    nvinfer1::Dims newDims;
-    // Copy dims only if a non-1 has been seen already.
-    bool non1Seen{false};
-    newDims.nbDims = std::copy_if(dims.d, dims.d + dims.nbDims, newDims.d,
-        [&non1Seen](int x) {
-            non1Seen = (x != 1) ? true : non1Seen;
-            return non1Seen;
-        })
-        - newDims.d;
-    return newDims;
-}
-
 inline nvinfer1::DimsHW operator-(nvinfer1::DimsHW dims)
 {
     return nvinfer1::DimsHW(-dims.h(), -dims.w());
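For context on the deletions above: both helpers only normalized nvinfer1::Dims by dropping size-1 dimensions, and the GEMM importer was the code path that needed them. A small, hedged illustration of squeeze_leading_dims follows (the definition is copied from the removed lines; the main function and its sample values are just an example):

#include <NvInfer.h>
#include <algorithm>
#include <cassert>

// Copied from the removed trt_utils.hpp lines: drop leading size-1 dimensions.
inline nvinfer1::Dims squeeze_leading_dims(const nvinfer1::Dims& dims)
{
    nvinfer1::Dims newDims;
    // Copy dims only if a non-1 has been seen already.
    bool non1Seen{false};
    newDims.nbDims = std::copy_if(dims.d, dims.d + dims.nbDims, newDims.d,
        [&non1Seen](int x) {
            non1Seen = (x != 1) ? true : non1Seen;
            return non1Seen;
        })
        - newDims.d;
    return newDims;
}

int main()
{
    nvinfer1::Dims dims{4, {1, 1, 3, 1}};
    nvinfer1::Dims squeezed = squeeze_leading_dims(dims);
    // Leading 1s are dropped, trailing ones are kept: {1, 1, 3, 1} -> {3, 1}.
    assert(squeezed.nbDims == 2 && squeezed.d[0] == 3 && squeezed.d[1] == 1);
    return 0;
}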