Skip to content

Commit

Permalink
[ONNX] Make the ONNX Importer More Static (#7429)
Browse files Browse the repository at this point in the history
* Construct static Ops if inputs are Constant

* Expose FoldConstant as a function in addition to the pass

* refactor onnx importer to do more static imports by constant folding

fix pylint

* fix test regressions

* fix style, two bugs

* pipe freeze_params through sub_graphs when importing loops and control flow
  • Loading branch information
Matthew Brookhart committed Feb 13, 2021
1 parent b8a8340 commit 4e211a7
Show file tree
Hide file tree
Showing 9 changed files with 180 additions and 89 deletions.
6 changes: 6 additions & 0 deletions python/tvm/relay/frontend/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,12 @@ def infer_type(node, mod=None):
return ret


def fold_constant(node, mod=None):
if mod is None:
mod = IRModule.from_expr(node)
return _transform.FoldConstantExpr(node, mod)


def infer_channels(inputs, transpose=False):
"""A hack for getting 'channels' or 'units' since caffe2 does not provide
these attributes. We check the shape of weights provided to get the number.
Expand Down
198 changes: 114 additions & 84 deletions python/tvm/relay/frontend/onnx.py

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion python/tvm/relay/op/image/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
"""Image operations."""
from . import _make
from ..dyn.image import _make as _dyn_make
from ...expr import Expr
from ...expr import Expr, Constant


def resize(
Expand Down Expand Up @@ -66,6 +66,8 @@ def resize(
result: relay.Expr
The resized result.
"""
if isinstance(size, Constant):
size = list(size.data.asnumpy().astype("int32"))
if isinstance(size, Expr):
return _dyn_make.resize(
data, size, layout, method, coordinate_transformation_mode, out_dtype
Expand Down
16 changes: 15 additions & 1 deletion python/tvm/relay/op/nn/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from . import _make
from ..dyn.nn import _make as _dyn_make
from .utils import get_pad_tuple1d, get_pad_tuple2d, get_pad_tuple3d
from ...expr import const, Expr
from ...expr import const, Expr, Constant


def conv1d(
Expand Down Expand Up @@ -1279,6 +1279,10 @@ def upsampling(
result : tvm.relay.Expr
The computed result.
"""
if isinstance(scale_h, Constant):
scale_h = scale_h.data.asnumpy().item()
if isinstance(scale_w, Constant):
scale_w = scale_w.data.asnumpy().item()
if isinstance(scale_h, Expr) or isinstance(scale_w, Expr):
if not isinstance(scale_h, Expr):
scale_h = const(scale_h, "float64")
Expand Down Expand Up @@ -1338,6 +1342,12 @@ def upsampling3d(
result : tvm.relay.Expr
The computed result.
"""
if isinstance(scale_d, Constant):
scale_d = scale_d.data.asnumpy().item()
if isinstance(scale_h, Constant):
scale_h = scale_h.data.asnumpy().item()
if isinstance(scale_w, Constant):
scale_w = scale_w.data.asnumpy().item()
if isinstance(scale_d, Expr) or isinstance(scale_h, Expr) or isinstance(scale_w, Expr):
if not isinstance(scale_d, Expr):
scale_d = const(scale_d, "float64")
Expand Down Expand Up @@ -1596,6 +1606,10 @@ def pad(data, pad_width, pad_value=0, pad_mode="constant"):
result : tvm.relay.Expr
The computed result.
"""
if isinstance(pad_value, Constant):
pad_value = pad_value.data.asnumpy().item()
if isinstance(pad_width, Constant):
pad_width = [list(i) for i in pad_width.data.asnumpy()]
if isinstance(pad_width, Expr) or (isinstance(pad_value, Expr)):
if not isinstance(pad_width, Expr):
pad_width = const(list(pad_width))
Expand Down
6 changes: 5 additions & 1 deletion python/tvm/relay/op/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

from . import _make
from .dyn import _make as _dyn_make
from ..expr import Tuple, Expr
from ..expr import Tuple, Expr, Constant
from . import op as reg


Expand Down Expand Up @@ -960,6 +960,8 @@ def zeros(shape, dtype):
result : relay.Expr
The resulting tensor.
"""
if isinstance(shape, Constant):
shape = list(shape.data.asnumpy())
if isinstance(shape, Expr):
return _dyn_make.zeros(shape, dtype)
if isinstance(shape, int):
Expand Down Expand Up @@ -1001,6 +1003,8 @@ def ones(shape, dtype):
result : relay.Expr
The resulting tensor.
"""
if isinstance(shape, Constant):
shape = list(shape.data.asnumpy())
if isinstance(shape, Expr):
return _dyn_make.ones(shape, dtype)
if isinstance(shape, int):
Expand Down
18 changes: 17 additions & 1 deletion python/tvm/relay/op/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from . import _make
from .dyn import _make as _dyn_make
from .tensor import shape_of
from ..expr import TupleWrapper, const, Expr, Tuple
from ..expr import TupleWrapper, const, Constant, Expr, Tuple
from ...tir import expr as _expr


Expand Down Expand Up @@ -216,6 +216,8 @@ def reshape(data, newshape):
result : relay.Expr
The reshaped result.
"""
if isinstance(newshape, Constant):
newshape = list(newshape.data.asnumpy())
if isinstance(newshape, Expr):
return _dyn_make.reshape(data, newshape)
if isinstance(newshape, int):
Expand Down Expand Up @@ -431,6 +433,8 @@ def full(fill_value, shape=(), dtype=""):
result : relay.Expr
The resulting tensor.
"""
if isinstance(shape, Constant):
shape = list(shape.data.asnumpy())
if isinstance(shape, Expr):
return _dyn_make.full(fill_value, shape, dtype)
if isinstance(shape, int):
Expand Down Expand Up @@ -614,6 +618,8 @@ def tile(data, reps):
data is promoted to be d-dimensional by prepending new axes.
If data.ndim >= d, reps is promoted to a.ndim by pre-pending 1's to it.
"""
if isinstance(reps, Constant):
reps = list(reps.data.asnumpy())
if isinstance(reps, Expr):
return _dyn_make.tile(data, reps)
return _make.tile(data, reps)
Expand Down Expand Up @@ -753,6 +759,8 @@ def broadcast_to(data, shape):
result : relay.Expr
The resulting tensor.
"""
if isinstance(shape, Constant):
shape = list(shape.data.asnumpy())
if isinstance(shape, Expr):
return _dyn_make.broadcast_to(data, shape)
if isinstance(shape, int):
Expand Down Expand Up @@ -884,6 +892,12 @@ def strided_slice(data, begin, end, strides=None, slice_mode="end"):
The computed result.
"""
strides = strides or [1]
if isinstance(begin, Constant):
begin = list(begin.data.asnumpy())
if isinstance(end, Constant):
end = list(end.data.asnumpy())
if isinstance(strides, Constant):
strides = list(strides.data.asnumpy())
if isinstance(begin, Expr) or isinstance(end, Expr) or isinstance(strides, Expr):
if isinstance(begin, (tuple, list)):
begin = const(list(begin))
Expand Down Expand Up @@ -1170,6 +1184,8 @@ def one_hot(indices, on_value, off_value, depth, axis, dtype):
[0, 1, 0],
[0, 0, 1]]
"""
if isinstance(depth, Constant):
depth = depth.data.asnumpy().item()
if isinstance(depth, Expr):
return _dyn_make.one_hot(indices, on_value, off_value, depth, axis, dtype)
return _make.one_hot(indices, on_value, off_value, depth, axis, dtype)
Expand Down
17 changes: 17 additions & 0 deletions python/tvm/relay/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,23 @@ def LazyGradientInit():
return _ffi_api.LazyGradientInit()


def FoldConstantExpr(expr, mod):
"""Fold the constant expressions in a Relay program.
Parameters
----------
expr: Expr
The expression to fold
mod: IRModule
The module the expr lives in (for global calls)
Returns
-------
new_expr: Expr
The expr after Constant Folding
"""
return _ffi_api.FoldConstantExpr(expr, mod)


def FoldConstant():
"""Fold the constant expressions in a Relay program.
Expand Down
2 changes: 2 additions & 0 deletions src/relay/transforms/fold_constant.cc
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,8 @@ Expr FoldConstant(const Expr& expr, const IRModule& mod) {
return ConstantFolder(mod).Mutate(expr);
}

TVM_REGISTER_GLOBAL("relay._transform.FoldConstantExpr").set_body_typed(FoldConstant);

namespace transform {

Pass FoldConstant() {
Expand Down
2 changes: 1 addition & 1 deletion tests/python/relay/test_op_grad_level3.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def test_zeros_ones_grad_const_ints():

def test_zeros_ones_grad_const_expr():
# when shape is static (i.e. not an input), there is no gradient at all
shape_const = relay.const(np.array([2, 3, 4]), dtype="int32")
shape_const = relay.const(np.array([2, 3, 4]), dtype="int32") * relay.const(1, dtype="int32")
static_ty = relay.TensorType([2, 3, 4], dtype="float32")
dyn_ty = relay.TensorType([relay.Any(), relay.Any(), relay.Any()], dtype="float32")
expected_ty_static = relay.TupleType([static_ty, relay.TupleType([])])
Expand Down

0 comments on commit 4e211a7

Please sign in to comment.