Skip to content

Commit

Permalink
[ONNX] Add Einsum converter (#8985)
Browse files Browse the repository at this point in the history
* einsum

* address review

* move files around

* use generic topi op

* TODO comment

* jostle ci

* jostle ci
  • Loading branch information
anwang2009 committed Sep 15, 2021
1 parent 2aebd33 commit e44f6c0
Show file tree
Hide file tree
Showing 12 changed files with 251 additions and 6 deletions.
11 changes: 10 additions & 1 deletion include/tvm/relay/attrs/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -475,7 +475,7 @@ struct ScanopAttrs : public tvm::AttrsNode<ScanopAttrs> {
.describe("The first element is not included")
.set_default(Bool(false));
}
};
}; // struct ScanopAttrs

/*! \brief Attributes used in unique operator */
struct UniqueAttrs : public tvm::AttrsNode<UniqueAttrs> {
Expand All @@ -489,6 +489,15 @@ struct UniqueAttrs : public tvm::AttrsNode<UniqueAttrs> {
}
}; // struct UniqueAttrs

/*! \brief Attributes used in einsum operator */
struct EinsumAttrs : public tvm::AttrsNode<EinsumAttrs> {
String equation;

TVM_DECLARE_ATTRS(EinsumAttrs, "relay.attrs.EinsumAttrs") {
TVM_ATTR_FIELD(equation).describe("The einsum expression string");
}
}; // struct EinsumAttrs

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_TRANSFORM_H_
10 changes: 10 additions & 0 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -3501,6 +3501,15 @@ def _impl_v11(cls, inputs, attr, params):
return _expr.TupleWrapper(_expr.Tuple([unique_vals, indices, inverse_indices, counts]), 4)


class Einsum(OnnxOpConverter):
"""Operator converter for Einsum"""

@classmethod
def _impl_v12(cls, inputs, attr, params):
equation = attr["equation"].decode("utf-8")
return _op.einsum(inputs, equation)


class RandomUniform(OnnxOpConverter):
"""Operator converter for random_uniform"""

Expand Down Expand Up @@ -3864,6 +3873,7 @@ def _get_convert_map(opset):
"Range": Range.get_converter(opset),
"CumSum": CumSum.get_converter(opset),
"Unique": Unique.get_converter(opset),
"Einsum": Einsum.get_converter(opset),
# defs/control_flow
"Loop": Loop.get_converter(opset),
"If": If.get_converter(opset),
Expand Down
1 change: 1 addition & 0 deletions python/tvm/relay/op/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
from . import _transform
from . import _reduce
from . import _algorithm
from . import _math


def _register_op_make():
Expand Down
22 changes: 22 additions & 0 deletions python/tvm/relay/op/_math.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Backend compiler related feature registration"""
from . import op as _reg
from . import strategy

# einsum
_reg.register_strategy("einsum", strategy.einsum_strategy)
1 change: 1 addition & 0 deletions python/tvm/relay/op/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ def compute_unique(attrs, inputs, output_type):
_reg.register_strategy("invert_permutation", strategy.invert_permutation_strategy)
_reg.register_shape_func("invert_permutation", False, elemwise_shape_func)


#####################
# Shape functions #
#####################
Expand Down
13 changes: 13 additions & 0 deletions python/tvm/relay/op/strategy/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -1215,3 +1215,16 @@ def invert_permutation_strategy_cuda(attrs, inputs, out_type, target):
name="invert_permutation.cuda",
)
return strategy


@einsum_strategy.register(["cuda", "gpu"])
def einsum_strategy_cuda(attrs, inputs, out_type, target):
"""einsum cuda strategy"""
strategy = _op.OpStrategy()
# TODO: Add cuda-specific op implementation for einsum
strategy.add_implementation(
wrap_compute_einsum(topi.einsum),
wrap_topi_schedule(topi.generic.schedule_extern),
name="einsum.cuda",
)
return strategy
21 changes: 21 additions & 0 deletions python/tvm/relay/op/strategy/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1669,3 +1669,24 @@ def invert_permutation_strategy(attrs, inputs, out_type, target):
name="invert_permutation.generic",
)
return strategy


def wrap_compute_einsum(topi_compute):
"""Wrap einsum topi compute"""

def _compute_einsum(attrs, inputs, _):
return [topi_compute(attrs.equation, *inputs)]

return _compute_einsum


@override_native_generic_func("einsum_strategy")
def einsum_strategy(attrs, inputs, out_type, target):
"""einsum generic strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_einsum(topi.einsum),
wrap_topi_schedule(topi.generic.schedule_einsum),
name="einsum.generic",
)
return strategy
23 changes: 23 additions & 0 deletions python/tvm/relay/op/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1104,6 +1104,29 @@ def concatenate(data, axis):
return _make.concatenate(Tuple(data), axis)


def einsum(data, equation):
"""Evaluates the Einstein summation convention on data
Parameters
----------
data : Union(List[relay.Expr], Tuple[relay.Expr])
A list of tensors.
equation : str
The einsum expression string.
Returns
-------
result : relay.Expr
The output tensor from the einsum op.
"""
data = list(data)
if not data:
raise ValueError("relay.einsum requires data to be non-empty.")
if not isinstance(equation, str):
raise ValueError("einsum `equation` must be a str")
return _make.einsum(Tuple(data), equation)


def stack(data, axis):
"""Join a sequence of arrays along a new axis.
Expand Down
1 change: 1 addition & 0 deletions python/tvm/topi/generic/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,4 @@
from .sort import *
from .search import *
from .image import *
from .math import *
34 changes: 34 additions & 0 deletions python/tvm/topi/generic/math.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Generic math operators"""
from .default import default_schedule as _default_schedule


def schedule_einsum(outs):
"""Schedule for einsum operator.
Parameters
----------
outs: Array of Tensor
The computation graph description of einsum.
Returns
-------
s: Schedule
The computation schedule for the op.
"""
return _default_schedule(outs, False)
115 changes: 115 additions & 0 deletions src/relay/op/tensor/math.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file math.cc
* \brief Math operators.
*/
#include <tvm/relay/expr.h>
#include <tvm/relay/op.h>
#include <tvm/topi/einsum.h>

#include "../make_op.h"
#include "../op_common.h"
#include "../type_relations.h"

namespace tvm {
namespace relay {

// relay.einsum
TVM_REGISTER_NODE_TYPE(EinsumAttrs);

bool EinsumRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
// Check attrs
const EinsumAttrs* param = attrs.as<EinsumAttrs>();
if (param == nullptr) {
reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
<< "the call attributes are not defined");
return false;
}

// types: [data, result]
ICHECK_EQ(types.size(), 2) << "the arity of einsum is 2, not " << types.size();

// Check input type is a tuple.
const auto* tensor_tuple = types[0].as<TupleTypeNode>();
if (tensor_tuple == nullptr) {
reporter->GetDiagCtx().EmitFatal(
Diagnostic::Error(reporter->GetSpan())
<< "einsum requires a tuple of tensors as the first argument, found "
<< PrettyPrint(types[0]));
return false;
}

// Check the input tuple consists of tensors with consistent dtype.
const auto& first = Downcast<TensorType>(tensor_tuple->fields[0]);
const DataType dtype = first->dtype;
std::vector<Array<PrimExpr>> input_shapes;
for (const Type& ele : tensor_tuple->fields) {
if (ele.as<IncompleteTypeNode>()) {
return false;
}

const auto& e = Downcast<TensorType>(ele);

const DataType& e_dtype = e->dtype;
if (e_dtype != dtype) {
throw Error("relay.einsum requires all tensors have the same dtype");
}
input_shapes.push_back(e->shape);
}

// Calculate output shape
Array<IndexExpr> oshape = topi::NumpyEinsumShape(param->equation, input_shapes);

auto rtype = TensorType(oshape, dtype);
reporter->Assign(types[1], rtype);
return true;
}

Array<te::Tensor> EinsumCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_type) {
const EinsumAttrs* param = attrs.as<EinsumAttrs>();
ICHECK(param != nullptr);
return Array<te::Tensor>{topi::einsum(param->equation, inputs)};
}

Expr MakeEinsum(Expr data, String equation) {
auto attrs = make_object<EinsumAttrs>();
attrs->equation = std::move(equation);
static const Op& op = Op::Get("einsum");
return Call(op, {data}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relay.op._make.einsum").set_body_typed(MakeEinsum);

RELAY_REGISTER_OP("einsum")
.describe(R"doc(Evaluates the Einstein summation convention
on the operands)doc" TVM_ADD_FILELINE)
.set_attrs_type<EinsumAttrs>()
.set_num_inputs(1)
.add_argument("data", "Tuple of Tensors", "The input list of tensors.")
.set_support_level(11)
.add_type_rel("Einsum", EinsumRel)
.set_attr<FTVMCompute>("FTVMCompute", EinsumCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);

} // namespace relay
} // namespace tvm
5 changes: 0 additions & 5 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -4735,11 +4735,6 @@ def verify_eyelike(indata):
"test_dropout_default_mask",
"test_dropout_default_mask_ratio",
"test_dropout_default_ratio",
"test_einsum_batch_diagonal",
"test_einsum_batch_matmul",
"test_einsum_inner_prod",
"test_einsum_sum",
"test_einsum_transpose",
"test_greater_equal",
"test_greater_equal_bcast",
"test_if_seq",
Expand Down

0 comments on commit e44f6c0

Please sign in to comment.