From 4e211a735221a9b9d188422025e2d464e37b3c96 Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Fri, 12 Feb 2021 21:14:56 -0700 Subject: [PATCH] [ONNX] Make the ONNX Importer More Static (#7429) * Construct static Ops if inputs are Constant * Expose FoldConstant as a function in addition to the pass * refactor onnx importer to do more static imports by constant folding fix pylint * fix test regressions * fix style, two bugs * pipe freeze_params through sub_graphs when importing loops and control flow --- python/tvm/relay/frontend/common.py | 6 + python/tvm/relay/frontend/onnx.py | 198 +++++++++++++--------- python/tvm/relay/op/image/image.py | 4 +- python/tvm/relay/op/nn/nn.py | 16 +- python/tvm/relay/op/tensor.py | 6 +- python/tvm/relay/op/transform.py | 18 +- python/tvm/relay/transform/transform.py | 17 ++ src/relay/transforms/fold_constant.cc | 2 + tests/python/relay/test_op_grad_level3.py | 2 +- 9 files changed, 180 insertions(+), 89 deletions(-) diff --git a/python/tvm/relay/frontend/common.py b/python/tvm/relay/frontend/common.py index 6323c63ab9b3..2db420a40992 100644 --- a/python/tvm/relay/frontend/common.py +++ b/python/tvm/relay/frontend/common.py @@ -491,6 +491,12 @@ def infer_type(node, mod=None): return ret +def fold_constant(node, mod=None): + if mod is None: + mod = IRModule.from_expr(node) + return _transform.FoldConstantExpr(node, mod) + + def infer_channels(inputs, transpose=False): """A hack for getting 'channels' or 'units' since caffe2 does not provide these attributes. We check the shape of weights provided to get the number. diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index c9140d782a2d..fb3d1c923561 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -34,7 +34,7 @@ from .. import ty as _ty from .common import AttrCvt, Renamer -from .common import get_relay_op, new_var, infer_shape, infer_channels, infer_value +from .common import get_relay_op, new_var, infer_shape, infer_channels, infer_value, fold_constant from .common import infer_type, get_name @@ -364,7 +364,7 @@ def autopad(data, strides, kernel_shape, dilations, ndim, pad_type="constant", d ), dtype="int64", ) - shape = _op.strided_slice(_op.shape_of(data, dtype="int64"), [2], [ndim]) + shape = _op.strided_slice(shape_of(data, dtype="int64"), [2], [ndim]) # get input shape # set up integer constants @@ -545,9 +545,9 @@ class MatMul(OnnxOpConverter): def _impl_v1(cls, inputs, attr, params): assert len(inputs) == 2, "MatMul op take 2 inputs, {} given".format(len(inputs)) # Need to check input shape as batch matmul must be supported. - a_shape = _op.shape_of(inputs[0]) + a_shape = shape_of(inputs[0]) a_rank = infer_shape(a_shape)[0] - b_shape = _op.shape_of(inputs[1]) + b_shape = shape_of(inputs[1]) b_rank = infer_shape(b_shape)[0] # When performing a batch matmul, we need to properly handle N-dim shapes. if a_rank > 2 or b_rank > 2: @@ -555,9 +555,13 @@ def _impl_v1(cls, inputs, attr, params): def flatten_to_3d(x, x_shape): ndims = infer_shape(x_shape)[0] newshape = _op.concatenate( - [_expr.const([-1]), _op.strided_slice(x_shape, [ndims - 2], [ndims])], 0 + [ + _expr.const([-1], dtype=infer_type(x_shape).checked_type.dtype), + _op.strided_slice(x_shape, [ndims - 2], [ndims]), + ], + 0, ) - out = _op.reshape(x, newshape) + out = _op.reshape(x, fold_constant(newshape)) return out # Convert a and b into 3 dimensional tensors. @@ -598,7 +602,7 @@ def flatten_to_3d(x, x_shape): ], 0, ) - return _op.reshape(output, final_shape) + return _op.reshape(output, fold_constant(final_shape)) # Otherwise a simple dense op will get the job done. input_1_t = _op.transpose(inputs[1], axes=(1, 0)) return _op.nn.dense(inputs[0], input_1_t) @@ -646,7 +650,7 @@ def _impl_v11(cls, inputs, attr, params): multiplier = _op.concatenate( [_expr.const([1, 1], dtype="int64"), _expr.const(list(strides), dtype="int64")], axis=0 ) - total_output_shape = multiplier * _op.shape_of(data, dtype="int64") + total_output_shape = multiplier * shape_of(data, dtype="int64") # Add extra dimensions from kernel size and stride mismatch total_output_shape += _op.concatenate( [_expr.const([0, 0], "int64"), _expr.const(list(kernel_shape), "int64")], axis=0 @@ -792,11 +796,11 @@ def _impl_v2(cls, inputs, attr, params): def _impl_v11(cls, inputs, attr, params): pads = inputs[1] if len(inputs) == 3: - value = _op.take(inputs[2], _op.const(0)) + value = fold_constant(_op.take(inputs[2], _op.const(0))) else: value = 0 - pad_width_expr = _op.transpose(_op.reshape(pads, (2, -1))) + pad_width_expr = fold_constant(_op.transpose(_op.reshape(pads, (2, -1)))) pad_mode = attr.get("mode", b"constant").decode("utf-8") if not pad_mode in ["constant", "edge", "reflect"]: @@ -823,7 +827,7 @@ class Prelu(OnnxOpConverter): @classmethod def _impl_v1(cls, inputs, attr, params): assert len(inputs) == 2, "Prelu need 2 inputs, {} given".format(len(inputs)) - input_shape = _op.shape_of(inputs[0]) + input_shape = shape_of(inputs[0]) alpha = _op.broadcast_to_like(inputs[1], inputs[0]) alpha = _op.reshape(alpha, [-1]) output = _op.nn.prelu(_op.reshape(inputs[0], [-1]), alpha, axis=0) @@ -875,7 +879,6 @@ class DepthToSpace(OnnxOpConverter): @classmethod def _impl_v11(cls, inputs, attr, params): - block_size = int(attr["blocksize"]) mode = attr.get("mode", b"DCR").decode("utf-8") return _op.nn.depth_to_space(inputs[0], block_size, mode=mode) @@ -1015,8 +1018,9 @@ def _impl_v9(cls, inputs, attr, params): scales = params[inputs[1].name_hint].asnumpy() else: scales = inputs[1] - - if not isinstance(scales, _expr.Call): + if isinstance(scales, _expr.Constant): + scales = list(scales.data.asnumpy()) + if not isinstance(scales, _expr.Expr): assert scales[0] == 1.0 and scales[1] == 1.0 mode = attr.get("mode") @@ -1067,12 +1071,20 @@ def _impl_v9(cls, inputs, attr, params): return out +def shape_of(x, dtype="int64"): + ttype = infer_type(x).checked_type + if not _ty.is_dynamic(ttype): + shape = list(ttype.shape) + return _expr.const(shape, dtype) + return _op.shape_of(x, dtype) + + class Shape(OnnxOpConverter): """Operator converter for Shape.""" @classmethod def _impl_v1(cls, inputs, attr, params): - return _op.shape_of(inputs[0], "int64") + return shape_of(inputs[0], "int64") class CumSum(OnnxOpConverter): @@ -1204,7 +1216,7 @@ def _impl_v10(cls, inputs, attr, params): # Update the starts and ends according to axes if required. if axes is not None: - data_shape = _op.shape_of(inputs[0], dtype=infer_type(ends).checked_type.dtype) + data_shape = shape_of(inputs[0], dtype=infer_type(ends).checked_type.dtype) starts = _op.scatter( _op.const([0] * data_rank, dtype=infer_type(starts).checked_type.dtype), axes, @@ -1223,7 +1235,9 @@ def _impl_v10(cls, inputs, attr, params): if steps is None: steps = _op.const([1] * data_rank, dtype=infer_type(starts).checked_type.dtype) - return _op.strided_slice(inputs[0], starts, ends, steps) + return _op.strided_slice( + inputs[0], fold_constant(starts), fold_constant(ends), fold_constant(steps) + ) class Gather(OnnxOpConverter): @@ -1531,6 +1545,19 @@ def _impl_v9(cls, inputs, attr, params): return output +class Constant(OnnxOpConverter): + """Operator converter for ConstantOfShape.""" + + @classmethod + def _impl_v9(cls, inputs, attr, params): + if "value" not in attr: + raise "No Value in Constant" + np_value = get_numpy(attr.pop("value")) + dtype = np_value.dtype.name + value = _expr.const(np_value, dtype) + return value + + class Sign(OnnxOpConverter): """Operator converter for Sign.""" @@ -1591,12 +1618,14 @@ def _impl_v9(cls, inputs, attr, params): # to that shape. max_rank = max(ranks) max_rank_idxs = [i for i, x in enumerate(ranks) if x == max_rank] - broadcast_shape = _op.shape_of(inputs[max_rank_idxs[0]]) + broadcast_shape = shape_of(inputs[max_rank_idxs[0]]) # If two or more inputs have the same rank, compute the broadcast # shape by taking the maximum value of each dimensions. if len(max_rank_idxs) > 1: for idx in max_rank_idxs: - broadcast_shape = _op.maximum(broadcast_shape, _op.shape_of(inputs[idx])) + broadcast_shape = _op.maximum(broadcast_shape, shape_of(inputs[idx])) + + broadcast_shape = fold_constant(broadcast_shape) condition = _op.broadcast_to(inputs[0], broadcast_shape) x = _op.broadcast_to(inputs[1], broadcast_shape) @@ -1618,7 +1647,7 @@ class Expand(OnnxOpConverter): @classmethod def _impl_v8(cls, inputs, attr, params): dtype = infer_type(inputs[1]).checked_type.dtype - in_shape = _op.shape_of(inputs[0], dtype=dtype) + in_shape = shape_of(inputs[0], dtype=dtype) shape = inputs[1] # Currently 'op.broadcast_to' expect the rank of the given 'shape' @@ -1667,7 +1696,7 @@ def expand_shape(in_shape, shape): new_shape = _op.maximum(in_shape, shape) return new_shape - shape = expand_shape(in_shape, shape) + shape = fold_constant(expand_shape(in_shape, shape)) return _op.broadcast_to(inputs[0], shape=shape) @@ -1942,10 +1971,9 @@ def _impl_v10(cls, inputs, attr, params): ) scale = inputs[1] - size = _op.cast(_op.shape_of(inputs[0]), infer_type(scale).checked_type.dtype) * scale - + size = _op.cast(shape_of(inputs[0]), infer_type(scale).checked_type.dtype) * scale layout = "NCHW" # ONNX assumes NCHW layout - out_size = _op.strided_slice(size, [2], [4]) + out_size = fold_constant(_op.strided_slice(size, [2], [4])) return _op.image.resize(inputs[0], out_size, layout, method, "asymmetric") @classmethod @@ -1969,7 +1997,7 @@ def _impl_v11(cls, inputs, attr, params): size = inputs[3] else: assert len(scale_shape) != 0, "One of scale or size should be passed." - size = _op.cast(_op.shape_of(inputs[0]), infer_type(scale).checked_type.dtype) * scale + size = _op.cast(shape_of(inputs[0]), infer_type(scale).checked_type.dtype) * scale coord_trans = attr.get("coordinate_transformation_mode") if coord_trans in [b"pytorch_half_pixel", b"half_pixel"]: @@ -1983,7 +2011,7 @@ def _impl_v11(cls, inputs, attr, params): "Unsupported coordinate_transformation_mode: {}".format(coord_trans) ) layout = "NCHW" # ONNX assumes NCHW layout - out_size = _op.strided_slice(size, [2], [4]) + out_size = fold_constant(_op.strided_slice(size, [2], [4])) return _op.image.resize(inputs[0], out_size, layout, method, coord_trans) @@ -2152,7 +2180,9 @@ def cond_fn(*loop_inputs): # Get the current graph proto and create a clone for the subgraph graph_scope = GraphProto.current - subgraph_scope = GraphProto(graph_scope._shape, graph_scope._dtype) + subgraph_scope = GraphProto( + graph_scope._shape, graph_scope._dtype, graph_scope._freeze_params + ) # Load nodes from outer graph into inner graph. subgraph_scope._nodes = graph_scope._nodes.copy() @@ -2246,7 +2276,7 @@ def body_fn(*loop_inputs): expand_scan = _op.expand_dims(new_scan, axis=0) # For non scalar outputs we need to broadcast the initial value. if rank > 0: - new_scan_shape = _op.shape_of(new_scan, dtype=iter_dtype) + new_scan_shape = shape_of(new_scan, dtype=iter_dtype) scan_broadcast = _op.concatenate( [_op.reshape(loop_count, [1]), new_scan_shape], axis=0 ) @@ -2264,7 +2294,7 @@ def body_fn(*loop_inputs): return [loop_count, max_count, new_cond] + new_loop_vars + combined_scan_outputs # Create the loop function. - loop = _loops.while_loop(cond_fn, loop_vars + scan_output_vars, body_fn) + loop = fold_constant(_loops.while_loop(cond_fn, loop_vars + scan_output_vars, body_fn)) # Now need to run initial values through the graph. init_count = _expr.const(0, dtype=iter_dtype) @@ -2287,6 +2317,7 @@ def body_fn(*loop_inputs): # Update outer graph with constants found in the subgraph. free_vars = analysis.free_vars(loop) graph_scope._params.update(subgraph_scope._params) + graph_scope._nodes.update(subgraph_scope._nodes) for var in free_vars: graph_scope._nodes.update({var.name_hint: var}) return outputs @@ -2307,9 +2338,9 @@ def _impl_v1(cls, inputs, attr, params): # Create graph converters for both branches. graph_scope = GraphProto.current - then_graph = GraphProto(graph_scope._shape, graph_scope._dtype) + then_graph = GraphProto(graph_scope._shape, graph_scope._dtype, graph_scope._freeze_params) then_graph._nodes = graph_scope._nodes.copy() - else_graph = GraphProto(graph_scope._shape, graph_scope._dtype) + else_graph = GraphProto(graph_scope._shape, graph_scope._dtype, graph_scope._freeze_params) else_graph._nodes = graph_scope._nodes.copy() # Convert each branch to a relay expression. @@ -2320,10 +2351,12 @@ def _impl_v1(cls, inputs, attr, params): # Add constants from both branches to parent graph. graph_scope._params.update(then_graph._params) + graph_scope._nodes.update(then_graph._nodes) then_free_vars = analysis.free_vars(then_expr) for var in then_free_vars: graph_scope._nodes.update({var.name_hint: var}) graph_scope._params.update(else_graph._params) + graph_scope._nodes.update(else_graph._nodes) else_free_vars = analysis.free_vars(else_expr) for var in else_free_vars: graph_scope._nodes.update({var.name_hint: var}) @@ -2468,9 +2501,9 @@ def _first_body( # partially prepare ONNX output format by labeling batch_num, class_id nms_padded_out = _op.expand_dims(nms_ret[0], -1, 1) batch_num = _op.expand_dims(_op.arange(_op.squeeze(B, [0]), dtype="int64"), -1, 1) - batch_num = _op.broadcast_to(batch_num, _op.shape_of(nms_ret[0], dtype="int64")) + batch_num = _op.broadcast_to(batch_num, shape_of(nms_ret[0], dtype="int64")) batch_num = _op.expand_dims(batch_num, -1, 1) - class_num = _op.broadcast_to(i, _op.shape_of(nms_padded_out, dtype="int64")) + class_num = _op.broadcast_to(i, shape_of(nms_padded_out, dtype="int64")) new_onnx_out = _op.concatenate( [batch_num, class_num, _op.cast(nms_padded_out, "int64")], -1 ) @@ -2570,7 +2603,7 @@ def _outer_body(i, B, C, onnx_out, nms_size_out, out): ) # Call the first loop, perform NMS - B, C, S = _op.split(_op.shape_of(scores, dtype="int64"), 3) + B, C, S = _op.split(shape_of(scores, dtype="int64"), 3) init_count = _op.const(np.array([0]), dtype="int64") init_onnx_out = _op.const([1], dtype="int64") init_onnx_out = _op.broadcast_to(init_onnx_out, _op.concatenate([B, one, S, three], 0)) @@ -2617,6 +2650,7 @@ def _get_convert_map(opset): "ThresholdedRelu": ThresholdedRelu.get_converter(opset), "ScaledTanh": ScaledTanh.get_converter(opset), "ParametricSoftplus": ParametricSoftPlus.get_converter(opset), + "Constant": Constant.get_converter(opset), "ConstantOfShape": ConstantOfShape.get_converter(opset), # 'GivenTensorFill' "FC": AttrCvt("dense", ignores=["axis", "axis_w"]), @@ -2776,11 +2810,19 @@ class GraphProto: dtype : str or dict of str to str The input types to the graph + + freeze_params: bool + If this parameter is true, the importer will take any provided + onnx input values (weights, shapes, etc) and embed them into the relay model + as Constants instead of variables. This allows more aggressive optimizations + at compile time and helps in making models static if certain inputs represent + attributes relay would traditionally consider compile-time constants. + """ current = None - def __init__(self, shape, dtype): + def __init__(self, shape, dtype, freeze_params=False): self._nodes = {} self._params = {} self._inputs = {} @@ -2790,6 +2832,7 @@ def __init__(self, shape, dtype): self._shape = shape if shape else {} self._dtype = dtype self.opset = None + self._freeze_params = freeze_params def __enter__(self): self._old_manager = GraphProto.current @@ -2808,7 +2851,7 @@ def freeze(self, func, params): fn = _function.Function(analysis.free_vars(body), body) return fn, {} - def from_onnx(self, graph, opset, freeze_params=False, get_output_expr=False): + def from_onnx(self, graph, opset, get_output_expr=False): """Construct Relay expression from ONNX graph. Onnx graph is a python protobuf object. @@ -2825,13 +2868,6 @@ def from_onnx(self, graph, opset, freeze_params=False, get_output_expr=False): opset : opset version - freeze_params: bool - If this parameter is true, the importer will take any provided - onnx input values (weights, shapes, etc) and embed them into the relay model - as Constants instead of variables. This allows more aggressive optimizations - at compile time and helps in making models static if certain inputs represent - attributes relay would traditionally consider compile-time constants. - get_output_expr: bool If set to true, this conversion will return each output expression rather than a packaged module. This can be useful when converting subgraphs to @@ -2850,12 +2886,16 @@ def from_onnx(self, graph, opset, freeze_params=False, get_output_expr=False): for init_tensor in graph.initializer: if not init_tensor.name.strip(): raise ValueError("Tensor's name is required.") - self._params[init_tensor.name] = self._parse_array(init_tensor) - self._nodes[init_tensor.name] = new_var( - init_tensor.name, - shape=self._params[init_tensor.name].shape, - dtype=self._params[init_tensor.name].dtype, - ) + array = self._parse_array(init_tensor) + if self._freeze_params: + self._nodes[init_tensor.name] = _expr.const(array) + else: + self._params[init_tensor.name] = array + self._nodes[init_tensor.name] = new_var( + init_tensor.name, + shape=self._params[init_tensor.name].shape, + dtype=self._params[init_tensor.name].dtype, + ) for i in graph.input: # from onnx v0.2, GraphProto.input has type ValueInfoProto, # and the name is 'i.name' @@ -2867,6 +2907,8 @@ def from_onnx(self, graph, opset, freeze_params=False, get_output_expr=False): self._nodes[i_name] = new_var( i_name, shape=self._params[i_name].shape, dtype=self._params[i_name].dtype ) + elif i_name in self._nodes: + continue else: self._num_input += 1 if i_name in self._shape: @@ -2909,37 +2951,28 @@ def from_onnx(self, graph, opset, freeze_params=False, get_output_expr=False): for i in node.input: if i != "": inputs[i] = self._nodes[self._renames.get(i, i)] - if op_name == "Constant": - t_proto = self._parse_attr(node.attribute)["value"] - self._num_param += 1 - # We should convert scalar integers to int32, to normalize. - array = self._parse_array(t_proto) - self._params[node.output[0]] = array - self._nodes[node.output[0]] = new_var( - node.output[0], shape=list(t_proto.dims), dtype=array.dtype - ) + i_name = self._parse_value_proto(node) + node_output = self._fix_outputs(op_name, node.output) + attr["tvm_custom"] = {} + attr["tvm_custom"]["name"] = i_name + attr["tvm_custom"]["num_outputs"] = len(node_output) + + op = self._convert_operator(op_name, inputs, attr, opset) + if not isinstance(op, _expr.TupleWrapper): + outputs_num = 1 else: - i_name = self._parse_value_proto(node) - node_output = self._fix_outputs(op_name, node.output) - attr["tvm_custom"] = {} - attr["tvm_custom"]["name"] = i_name - attr["tvm_custom"]["num_outputs"] = len(node_output) - - op = self._convert_operator(op_name, inputs, attr, opset) - if not isinstance(op, _expr.TupleWrapper): - outputs_num = 1 - else: - outputs_num = len(op) - assert ( - len(node_output) == outputs_num - ), "Number of output mismatch {} vs {} in {}.".format( - len(node_output), outputs_num, op_name - ) - if outputs_num == 1: - self._nodes[node_output[0]] = op - else: - for k, i in zip(list(node_output), range(len(node_output))): - self._nodes[k] = op[i] + outputs_num = len(op) + assert ( + len(node_output) == outputs_num + ), "Number of output mismatch {} vs {} in {}.".format( + len(node_output), outputs_num, op_name + ) + if outputs_num == 1: + self._nodes[node_output[0]] = fold_constant(op) + else: + op = _expr.TupleWrapper(fold_constant(op.astuple()), len(op)) + for k, i in zip(list(node_output), range(len(node_output))): + self._nodes[k] = op[i] # now return the outputs outputs = [self._nodes[self._parse_value_proto(i)] for i in graph.output] @@ -2957,9 +2990,6 @@ def from_onnx(self, graph, opset, freeze_params=False, get_output_expr=False): self._inputs[i_name] = self._nodes[i_name] # Create a function from our output expression and all input variables. func = _function.Function([v for k, v in self._inputs.items()], outputs) - if freeze_params: - func, params = self.freeze(func, self._params) - return IRModule.from_expr(func), params return IRModule.from_expr(func), self._params def _parse_value_proto(self, value_proto): @@ -3100,7 +3130,7 @@ def from_onnx(model, shape=None, dtype="float32", opset=None, freeze_params=Fals warnings.warn(str(e)) except ImportError: pass - g = GraphProto(shape, dtype) + g = GraphProto(shape, dtype, freeze_params) graph = model.graph if opset is None: try: @@ -3109,5 +3139,5 @@ def from_onnx(model, shape=None, dtype="float32", opset=None, freeze_params=Fals opset = 1 # Use the graph proto as a scope so that ops can access other nodes if needed. with g: - mod, params = g.from_onnx(graph, opset, freeze_params) + mod, params = g.from_onnx(graph, opset) return mod, params diff --git a/python/tvm/relay/op/image/image.py b/python/tvm/relay/op/image/image.py index a3f3a3e8cb92..153439b1e20c 100644 --- a/python/tvm/relay/op/image/image.py +++ b/python/tvm/relay/op/image/image.py @@ -17,7 +17,7 @@ """Image operations.""" from . import _make from ..dyn.image import _make as _dyn_make -from ...expr import Expr +from ...expr import Expr, Constant def resize( @@ -66,6 +66,8 @@ def resize( result: relay.Expr The resized result. """ + if isinstance(size, Constant): + size = list(size.data.asnumpy().astype("int32")) if isinstance(size, Expr): return _dyn_make.resize( data, size, layout, method, coordinate_transformation_mode, out_dtype diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py index 0c233a6e3b53..5135ac74de25 100644 --- a/python/tvm/relay/op/nn/nn.py +++ b/python/tvm/relay/op/nn/nn.py @@ -21,7 +21,7 @@ from . import _make from ..dyn.nn import _make as _dyn_make from .utils import get_pad_tuple1d, get_pad_tuple2d, get_pad_tuple3d -from ...expr import const, Expr +from ...expr import const, Expr, Constant def conv1d( @@ -1279,6 +1279,10 @@ def upsampling( result : tvm.relay.Expr The computed result. """ + if isinstance(scale_h, Constant): + scale_h = scale_h.data.asnumpy().item() + if isinstance(scale_w, Constant): + scale_w = scale_w.data.asnumpy().item() if isinstance(scale_h, Expr) or isinstance(scale_w, Expr): if not isinstance(scale_h, Expr): scale_h = const(scale_h, "float64") @@ -1338,6 +1342,12 @@ def upsampling3d( result : tvm.relay.Expr The computed result. """ + if isinstance(scale_d, Constant): + scale_d = scale_d.data.asnumpy().item() + if isinstance(scale_h, Constant): + scale_h = scale_h.data.asnumpy().item() + if isinstance(scale_w, Constant): + scale_w = scale_w.data.asnumpy().item() if isinstance(scale_d, Expr) or isinstance(scale_h, Expr) or isinstance(scale_w, Expr): if not isinstance(scale_d, Expr): scale_d = const(scale_d, "float64") @@ -1596,6 +1606,10 @@ def pad(data, pad_width, pad_value=0, pad_mode="constant"): result : tvm.relay.Expr The computed result. """ + if isinstance(pad_value, Constant): + pad_value = pad_value.data.asnumpy().item() + if isinstance(pad_width, Constant): + pad_width = [list(i) for i in pad_width.data.asnumpy()] if isinstance(pad_width, Expr) or (isinstance(pad_value, Expr)): if not isinstance(pad_width, Expr): pad_width = const(list(pad_width)) diff --git a/python/tvm/relay/op/tensor.py b/python/tvm/relay/op/tensor.py index 75e298786ddd..5b011043f588 100644 --- a/python/tvm/relay/op/tensor.py +++ b/python/tvm/relay/op/tensor.py @@ -22,7 +22,7 @@ from . import _make from .dyn import _make as _dyn_make -from ..expr import Tuple, Expr +from ..expr import Tuple, Expr, Constant from . import op as reg @@ -960,6 +960,8 @@ def zeros(shape, dtype): result : relay.Expr The resulting tensor. """ + if isinstance(shape, Constant): + shape = list(shape.data.asnumpy()) if isinstance(shape, Expr): return _dyn_make.zeros(shape, dtype) if isinstance(shape, int): @@ -1001,6 +1003,8 @@ def ones(shape, dtype): result : relay.Expr The resulting tensor. """ + if isinstance(shape, Constant): + shape = list(shape.data.asnumpy()) if isinstance(shape, Expr): return _dyn_make.ones(shape, dtype) if isinstance(shape, int): diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index d42ef477499f..cda417cad239 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -21,7 +21,7 @@ from . import _make from .dyn import _make as _dyn_make from .tensor import shape_of -from ..expr import TupleWrapper, const, Expr, Tuple +from ..expr import TupleWrapper, const, Constant, Expr, Tuple from ...tir import expr as _expr @@ -216,6 +216,8 @@ def reshape(data, newshape): result : relay.Expr The reshaped result. """ + if isinstance(newshape, Constant): + newshape = list(newshape.data.asnumpy()) if isinstance(newshape, Expr): return _dyn_make.reshape(data, newshape) if isinstance(newshape, int): @@ -431,6 +433,8 @@ def full(fill_value, shape=(), dtype=""): result : relay.Expr The resulting tensor. """ + if isinstance(shape, Constant): + shape = list(shape.data.asnumpy()) if isinstance(shape, Expr): return _dyn_make.full(fill_value, shape, dtype) if isinstance(shape, int): @@ -614,6 +618,8 @@ def tile(data, reps): data is promoted to be d-dimensional by prepending new axes. If data.ndim >= d, reps is promoted to a.ndim by pre-pending 1's to it. """ + if isinstance(reps, Constant): + reps = list(reps.data.asnumpy()) if isinstance(reps, Expr): return _dyn_make.tile(data, reps) return _make.tile(data, reps) @@ -753,6 +759,8 @@ def broadcast_to(data, shape): result : relay.Expr The resulting tensor. """ + if isinstance(shape, Constant): + shape = list(shape.data.asnumpy()) if isinstance(shape, Expr): return _dyn_make.broadcast_to(data, shape) if isinstance(shape, int): @@ -884,6 +892,12 @@ def strided_slice(data, begin, end, strides=None, slice_mode="end"): The computed result. """ strides = strides or [1] + if isinstance(begin, Constant): + begin = list(begin.data.asnumpy()) + if isinstance(end, Constant): + end = list(end.data.asnumpy()) + if isinstance(strides, Constant): + strides = list(strides.data.asnumpy()) if isinstance(begin, Expr) or isinstance(end, Expr) or isinstance(strides, Expr): if isinstance(begin, (tuple, list)): begin = const(list(begin)) @@ -1170,6 +1184,8 @@ def one_hot(indices, on_value, off_value, depth, axis, dtype): [0, 1, 0], [0, 0, 1]] """ + if isinstance(depth, Constant): + depth = depth.data.asnumpy().item() if isinstance(depth, Expr): return _dyn_make.one_hot(indices, on_value, off_value, depth, axis, dtype) return _make.one_hot(indices, on_value, off_value, depth, axis, dtype) diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index c6df8c1e6ea2..f02f8352de9e 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -240,6 +240,23 @@ def LazyGradientInit(): return _ffi_api.LazyGradientInit() +def FoldConstantExpr(expr, mod): + """Fold the constant expressions in a Relay program. + Parameters + ---------- + expr: Expr + The expression to fold + mod: IRModule + The module the expr lives in (for global calls) + + Returns + ------- + new_expr: Expr + The expr after Constant Folding + """ + return _ffi_api.FoldConstantExpr(expr, mod) + + def FoldConstant(): """Fold the constant expressions in a Relay program. diff --git a/src/relay/transforms/fold_constant.cc b/src/relay/transforms/fold_constant.cc index 4454c9c0459a..9416b0ec4580 100644 --- a/src/relay/transforms/fold_constant.cc +++ b/src/relay/transforms/fold_constant.cc @@ -374,6 +374,8 @@ Expr FoldConstant(const Expr& expr, const IRModule& mod) { return ConstantFolder(mod).Mutate(expr); } +TVM_REGISTER_GLOBAL("relay._transform.FoldConstantExpr").set_body_typed(FoldConstant); + namespace transform { Pass FoldConstant() { diff --git a/tests/python/relay/test_op_grad_level3.py b/tests/python/relay/test_op_grad_level3.py index 904576a181f6..d43744b38e3e 100644 --- a/tests/python/relay/test_op_grad_level3.py +++ b/tests/python/relay/test_op_grad_level3.py @@ -146,7 +146,7 @@ def test_zeros_ones_grad_const_ints(): def test_zeros_ones_grad_const_expr(): # when shape is static (i.e. not an input), there is no gradient at all - shape_const = relay.const(np.array([2, 3, 4]), dtype="int32") + shape_const = relay.const(np.array([2, 3, 4]), dtype="int32") * relay.const(1, dtype="int32") static_ty = relay.TensorType([2, 3, 4], dtype="float32") dyn_ty = relay.TensorType([relay.Any(), relay.Any(), relay.Any()], dtype="float32") expected_ty_static = relay.TupleType([static_ty, relay.TupleType([])])