Skip to content

Commit

Permalink
Fix dynamic argmax/min when select_last_index is set
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinch-nv committed Apr 5, 2022
1 parent b4a7461 commit 7fddcaa
Showing 1 changed file with 15 additions and 11 deletions.
26 changes: 15 additions & 11 deletions onnx2trt_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,21 +108,25 @@ NodeImportResult argMinMaxHelper(IImporterContext* ctx, const ::ONNX_NAMESPACE::
// We don't care about the TopK values, just the indices.
nvinfer1::ITensor* indices = layer->getOutput(1);
indices->setType(nvinfer1::DataType::kINT32);

// If selectLastIndex is true, the TopK operation was performed on reversed data on the provided axis.
// Convert reversed indices back to forward indices by calculating the following:
// indices = shape(tensor)[axis] - indices - 1
if (selectLastIndex)
{
nvinfer1::Dims dims = tensor.getDimensions();
int dimOnAxis = dims.d[axis];
nvinfer1::Dims resultShape(dims);
resultShape.d[axis] = 1;
ShapedWeights shapeWeights = ctx->createTempWeights(::ONNX_NAMESPACE::TensorProto::INT32, resultShape);
std::vector<int> tempData(shapeWeights.count(), dimOnAxis);
std::memcpy(shapeWeights.values, tempData.data(), shapeWeights.count() * sizeof(int));
// Use shapeTensor semantics to support dynamic shapes
auto const dims = shapeOf(tensor);
auto const indicesDims = shapeOf(*indices);
auto const axisTensor = shapeVector(axis);
auto const dimOnAxis = gather(ctx, dims, axisTensor);

// Create constant of shape indicesDims with values tensor.shape[axis]
auto const tensorDimOnAxis = constantOfShape(ctx, node, &dimOnAxis.tensor(ctx), &indicesDims.tensor(ctx));

ShapedWeights weightOfOnes = ctx->createTempWeights(::ONNX_NAMESPACE::TensorProto::INT32, resultShape);
std::vector<int> ones(shapeWeights.count(), 1);
std::memcpy(weightOfOnes.values, ones.data(), weightOfOnes.count() * sizeof(int));
// Create constant of shape indicesDims with values of 1
auto const ones = constantOfShape(ctx, node, &shapeVector(1).tensor(ctx), &indicesDims.tensor(ctx));

std::vector<TensorOrWeights> newInputs{shapeWeights, indices, weightOfOnes};
std::vector<TensorOrWeights> newInputs{tensorDimOnAxis, indices, ones};
indices = &elementwiseHelper(ctx, node, newInputs, nvinfer1::ElementWiseOperation::kSUB).value().at(0).tensor();
}
if (keepdims)
Expand Down

0 comments on commit 7fddcaa

Please sign in to comment.