Skip to content

Commit

Permalink
Update gather elements implementation (onnx#675)
Browse files Browse the repository at this point in the history
Signed-off-by: Kevin Chen <[email protected]>
  • Loading branch information
kevinch-nv committed Jul 2, 2021
1 parent 5ab7d3e commit e5ee2b5
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 58 deletions.
104 changes: 46 additions & 58 deletions builtin_op_importers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1206,85 +1206,73 @@ DEFINE_BUILTIN_OP_IMPORTER(Gather)

DEFINE_BUILTIN_OP_IMPORTER(GatherElements)
{

// We can treat GatherElements as a regular Gather operation with transformed input and indices tensors.
// Consider a simple example of a 3D tensor with axis = 1.
// The regular forumla of out[i][j][k] = in[i][idx[i][j][k]][k] can be rewritten as out[i][j][k] = in'[idx'[i,j,k]]
// Where in' is a squeezed down 1D representation of the data and idx' is calculated from the following formula:
// idx' = idx[i,j,k] * pitch[1] + bias. The bias is calculated as i*pitch[0] + k*pitch[2].

// clang-format off
/* Example: Data is 3D tensor of shape [2,2,2] with values [[[1,2], [3,4]], [[5,6], [7,8]]]
Indices is a 3D tensor of shape [2,2,1] with values [[[0], [1]], [[0], [1]]]
From the original formula, the output is [[[1], [3]], [[5], [7]]],
Pitch vector of data is [4,2,1].
idx` calculation:
idx`[0, 0, 0] = [idx[0,0,0]](0) * [pitch[axis]](2) + [i(0)*pitch[0](4)](0) + [k(0)*pitch[2](1)](0) = 0
idx`[0, 1, 0] = [idx[0,1,0]](1) * [pitch[axis]](2) + [i(0)*pitch[0](4)](0) + [k(0)*pitch[2](1)](0) = 2
idx`[1, 0, 0] = [idx[1,0,0]](0) * [pitch[axis]](2) + [i(1)*pitch[0](4)](4) + [k(0)*pitch[2](1)](0) = 4
idx`[1, 1, 0] = [idx[1,1,0]](1) * [pitch[axis]](2) + [i(1)*pitch[0](4)](4) + [k(0)*pitch[2](1)](0) = 6
= [[[0], [2]], [[4], [6]]]
After linearizing data to 1D: [1,2,3,4,5,6,7,8], gathering on axis 0 with the new indices gives the same results.
*/
// clang-format on

nvinfer1::ITensor& data = convertToTensor(inputs.at(0), ctx);
nvinfer1::ITensor& index = convertToTensor(inputs.at(1), ctx);

const nvinfer1::Dims& idxDims = index.getDimensions();
const nvinfer1::Dims& dataDims = data.getDimensions();
const nvinfer1::Dims& daDims = data.getDimensions();

// Note the above tranformation requires dimensions to be known at parse time, so check for dynamic shapes
ASSERT(!isDynamic(daDims) && !isDynamic(idxDims)
&& "This version of TenosrRT does not support GatherElements on dynamic shapes!",
ErrorCode::kUNSUPPORTED_NODE);

OnnxAttrs attrs(node, ctx);
int32_t axis = attrs.get<int32_t>("axis", 0);
int32_t dataNbDims = dataDims.nbDims;
int32_t dataNbDims = daDims.nbDims;

TRT_CHECK(convertAxis(axis, dataNbDims));
LOG_VERBOSE("Using Gather axis: " << axis);

// Calculate how many indices
// Calculate data pitches vector, and create axisPitch vector
int64_t nIndx = volume(idxDims);
std::vector<int32_t> pitches = calculatePitches(daDims);
std::vector<int32_t> axisPitch(nIndx, pitches[axis]);

// Calculate pitches of input tensor
int32_t nDataElements = volume(dataDims), pitch = 1;
int32_t pitches[nvinfer1::Dims::MAX_DIMS] = {0};
pitches[dataDims.nbDims-1] = pitch;
for (int32_t i = dataDims.nbDims-2; i >= 0 ; i--)
{
pitch *= dataDims.d[i];
pitches[i] = pitch;
}

// Generate constants based on axis
std::vector<int32_t> sCoeff(nIndx, pitches[axis]);
std::vector<int32_t> aCoeff;

// Transform a 1-d index back to the nDims
for (int32_t i = 0; i < nIndx; i++)
{
std::vector<int32_t> nDimsIdx; //this can be an array
int32_t currI = i;

for (int32_t j = 0; j < dataDims.nbDims; j++)
{
int32_t currIdxVal = currI / pitches[j];
nDimsIdx.push_back(currIdxVal);
currI = currI % pitches[j];
}

int32_t bias = 0;
//calculate the aCoeff
for (size_t j = 0; j < nDimsIdx.size(); j++)
{

if (j == (size_t)axis)
{
continue;
}
bias += nDimsIdx[j] * pitches[j];
}
aCoeff.push_back(bias);
}

auto* sCoeffLayer = addConstant(ctx, sCoeff, ::ONNX_NAMESPACE::TensorProto::INT32, idxDims);
auto* aCoeffLayer = addConstant(ctx, aCoeff, ::ONNX_NAMESPACE::TensorProto::INT32, idxDims);

nvinfer1::ITensor* sCoeffTensor = sCoeffLayer->getOutput(0);
nvinfer1::ITensor* aCoeffTensor = aCoeffLayer->getOutput(0);
auto* mul = ctx->network()->addElementWise(index, *sCoeffTensor, nvinfer1::ElementWiseOperation::kPROD);

nvinfer1::ITensor* mulTensor = mul->getOutput(0);
auto* add = ctx->network()->addElementWise(*mulTensor, *aCoeffTensor, nvinfer1::ElementWiseOperation::kSUM);
// Calculate bias vector
std::vector<int32_t> biasVector = calculateBias(daDims, idxDims, pitches, axis);

nvinfer1::ITensor* addTensor = add->getOutput(0);
// Perform idx` = idx * pitch[axis] + bias calculation.
auto* axisPitchTensor = addConstant(ctx, axisPitch, ::ONNX_NAMESPACE::TensorProto::INT32, idxDims)->getOutput(0);
auto* biasTensor = addConstant(ctx, biasVector, ::ONNX_NAMESPACE::TensorProto::INT32, idxDims)->getOutput(0);

nvinfer1::Dims flattenDataDims{1};
auto* mul
= ctx->network()->addElementWise(index, *axisPitchTensor, nvinfer1::ElementWiseOperation::kPROD)->getOutput(0);
auto* newIndices
= ctx->network()->addElementWise(*mul, *biasTensor, nvinfer1::ElementWiseOperation::kSUM)->getOutput(0);

flattenDataDims.nbDims = 1;
flattenDataDims.d[0] = nDataElements;
nvinfer1::Dims flattenDataDims{1, {static_cast<int32_t>(volume(daDims))}};
auto* reshape = ctx->network()->addShuffle(data);
reshape->setReshapeDimensions(flattenDataDims);
reshape->setZeroIsPlaceholder(false);

nvinfer1::ITensor* flattenData = reshape->getOutput(0);
auto* layer = ctx->network()->addGather(*flattenData, *addTensor, 0);
auto* layer = ctx->network()->addGather(*flattenData, *newIndices, 0);
ctx->registerLayer(layer, getNodeName(node));
RETURN_FIRST_OUTPUT(layer);
}
Expand Down
62 changes: 62 additions & 0 deletions onnx2trt_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,68 @@ Status broadcastTensors(IImporterContext* ctx, nvinfer1::ITensor*& t1, nvinfer1:
return Status::success();
}

// Helper functions for calculateBias:
int32_t getBias(const std::vector<int32_t>& dimension_count, const std::vector<int32_t>& pitches, int32_t axis)
{
int32_t result{0};
for (int32_t i = 0; i < static_cast<int32_t>(dimension_count.size()); i++)
{
if (i != axis)
{
result += dimension_count[i] * pitches[i];
}
}
return result;
}

void incrementOuterDimension(std::vector<int32_t>& dimensionCount, nvinfer1::Dims idxDims)
{
// Start at [x,x,0]. Increment starting from the outer dimension.
int32_t rank = dimensionCount.size();

for (int32_t i = rank - 1; i >= 0; i--)
{
int dimLimit = idxDims.d[i];
// If we're not at the limit, increment current axis and return
if (++dimensionCount[i] != dimLimit)
{
break;
}
// Else, we increment on the next dimension and reset current one
dimensionCount[i] = 0;
}
}

std::vector<int32_t> calculateBias(
const nvinfer1::Dims& daDims, const nvinfer1::Dims& idxDims, const std::vector<int32_t>& pitches, int32_t axis)
{
std::vector<int32_t> biasVector;
std::vector<int32_t> dimensionCount(daDims.nbDims, 0);
int64_t total = volume(idxDims);

for (int64_t i = 0; i < total; i++)
{
int32_t bias = getBias(dimensionCount, pitches, axis);
biasVector.push_back(bias);
incrementOuterDimension(dimensionCount, idxDims);
}
return biasVector;
}

std::vector<int32_t> calculatePitches(const nvinfer1::Dims& inputDims)
{
int32_t pitch = 1;
int32_t nbDims = inputDims.nbDims;
std::vector<int32_t> pitches(nbDims);
pitches[nbDims - 1] = pitch;
for (int32_t i = nbDims - 2; i >= 0; i--)
{
pitch *= inputDims.d[i + 1];
pitches[i] = pitch;
}
return pitches;
}

bool canUseLinearResize(const size_t scaleSize, const float* scaleFactors)
{
// Linear resize supports up to 3D resize on the outermost dimensions.
Expand Down
7 changes: 7 additions & 0 deletions onnx2trt_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,13 @@ Status broadcastTensors(IImporterContext* ctx, nvinfer1::ITensor*& t1, nvinfer1:
// Helper function to broadcast three tensors to the largest one's shape
Status broadcastTensors(IImporterContext* ctx, nvinfer1::ITensor*& t1, nvinfer1::ITensor*& t2, nvinfer1::ITensor*& t3);

// Helper function to calculate the bias tensor for GatherElements.
std::vector<int32_t> calculateBias(
const nvinfer1::Dims& daDims, const nvinfer1::Dims& idxDims, const std::vector<int32_t>& pitches, int32_t axis);

// Helper function to calculate and return a vector representation of the pitches of a given shape
std::vector<int32_t> calculatePitches(const nvinfer1::Dims& inputDims);

// Helper function to check that linear resize can be used
bool canUseLinearResize(const size_t scaleSize, const float* scaleFactors);

Expand Down

0 comments on commit e5ee2b5

Please sign in to comment.