From 7fddcaa145904b664ade781ed38da6b0a0afc336 Mon Sep 17 00:00:00 2001 From: Kevin Chen Date: Tue, 5 Apr 2022 12:04:24 -0700 Subject: [PATCH] Fix dynamic argmax/min when select_last_index is set --- onnx2trt_utils.cpp | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/onnx2trt_utils.cpp b/onnx2trt_utils.cpp index 9138a2c4..593c450a 100644 --- a/onnx2trt_utils.cpp +++ b/onnx2trt_utils.cpp @@ -108,21 +108,25 @@ NodeImportResult argMinMaxHelper(IImporterContext* ctx, const ::ONNX_NAMESPACE:: // We don't care about the TopK values, just the indices. nvinfer1::ITensor* indices = layer->getOutput(1); indices->setType(nvinfer1::DataType::kINT32); + + // If selectLastIndex is true, the TopK operation was performed on reversed data on the provided axis. + // Convert reversed indices back to forward indices by calculating the following: + // indices = shape(tensor)[axis] - indices - 1 if (selectLastIndex) { - nvinfer1::Dims dims = tensor.getDimensions(); - int dimOnAxis = dims.d[axis]; - nvinfer1::Dims resultShape(dims); - resultShape.d[axis] = 1; - ShapedWeights shapeWeights = ctx->createTempWeights(::ONNX_NAMESPACE::TensorProto::INT32, resultShape); - std::vector tempData(shapeWeights.count(), dimOnAxis); - std::memcpy(shapeWeights.values, tempData.data(), shapeWeights.count() * sizeof(int)); + // Use shapeTensor semantics to support dynamic shapes + auto const dims = shapeOf(tensor); + auto const indicesDims = shapeOf(*indices); + auto const axisTensor = shapeVector(axis); + auto const dimOnAxis = gather(ctx, dims, axisTensor); + + // Create constant of shape indicesDims with values tensor.shape[axis] + auto const tensorDimOnAxis = constantOfShape(ctx, node, &dimOnAxis.tensor(ctx), &indicesDims.tensor(ctx)); - ShapedWeights weightOfOnes = ctx->createTempWeights(::ONNX_NAMESPACE::TensorProto::INT32, resultShape); - std::vector ones(shapeWeights.count(), 1); - std::memcpy(weightOfOnes.values, ones.data(), weightOfOnes.count() * sizeof(int)); + // Create constant of shape indicesDims with values of 1 + auto const ones = constantOfShape(ctx, node, &shapeVector(1).tensor(ctx), &indicesDims.tensor(ctx)); - std::vector newInputs{shapeWeights, indices, weightOfOnes}; + std::vector newInputs{tensorDimOnAxis, indices, ones}; indices = &elementwiseHelper(ctx, node, newInputs, nvinfer1::ElementWiseOperation::kSUB).value().at(0).tensor(); } if (keepdims)