
Commit

Merge branch 'mbrookhart/dynamic_onnx' into electriclilies/dynamic_onnx
Lily Orth-Smith authored Aug 26, 2020
2 parents 27fd728 + 6a027a3 commit df62e85
Showing 23 changed files with 600 additions and 349 deletions.
34 changes: 34 additions & 0 deletions include/tvm/relay/feature.h
@@ -29,6 +29,7 @@
#include <tvm/relay/expr.h>

#include <bitset>
#include <string>

namespace tvm {
namespace relay {
@@ -124,6 +125,11 @@ class FeatureSet {
*/
bool is_subset_of(const FeatureSet& rhs) const { return ((*this) - rhs).bs_.none(); }

/*!
* \brief return a string representation.
*/
std::string ToString() const;

private:
std::bitset<feature_count> bs_;
FeatureSet() = default;
@@ -160,6 +166,34 @@ inline FeatureSet DetectFeature(const Expr& expr, const IRModule& mod) {
return DetectFeature(expr) + DetectFeature(mod);
}

/*!
* \brief Check the feature of the program.
*
* \param expr The expression.
* \param fs The feature set of the program.
*/
void CheckFeature(const RelayExpr& expr, const FeatureSet& fs);

/*!
* \brief Check the feature of the program.
*
* \param mod The module.
* \param fs The feature set of the program.
*/
void CheckFeature(const IRModule& mod, const FeatureSet& fs);

/*!
* \brief Check the feature of the program.
*
* \param expr The expression.
* \param mod The module.
* \param fs The feature set of the program.
*/
inline void CheckFeature(const RelayExpr& expr, const IRModule& mod, const FeatureSet& fs) {
CheckFeature(expr, fs);
CheckFeature(mod, fs);
}

} // namespace relay
} // namespace tvm

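The new CheckFeature overloads assert that a program only uses Relay features contained in a given FeatureSet. A minimal sketch of the same idea from the Python side, using detect_feature; the subset check below is an illustration of the intent, not the committed C++ API, and the Feature member names are assumed to follow the Relay feature enum:

import tvm
from tvm import relay
from tvm.relay.analysis import detect_feature
from tvm.relay.analysis.feature import Feature

# Build a small program and detect which Relay features it uses.
x = relay.var("x", shape=(3,))
mod = tvm.IRModule.from_expr(relay.Function([x], relay.add(x, x)))
detected = detect_feature(mod)

# Emulate CheckFeature: fail if the program uses anything outside the allowed set.
allowed = {Feature.fVar, Feature.fGlobalVar, Feature.fConstant,
           Feature.fOp, Feature.fCall, Feature.fFunction}
assert detected.issubset(allowed), "unsupported features: %s" % (detected - allowed)
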
9 changes: 9 additions & 0 deletions include/tvm/relay/transform.h
@@ -147,6 +147,15 @@ TVM_DLL Pass ToBasicBlockNormalForm();
*/
TVM_DLL Pass ToANormalForm();

/*!
* \brief ToANormalForm, but on an incomplete graph.
*
* \param expr the graph.
*
* \return The transformed program.
*/
TVM_DLL Expr ToANormalForm(const Expr& expr);

/*!
* \brief Turn an expression into continuation passing style (CPS).
*
5 changes: 3 additions & 2 deletions python/tvm/relay/frontend/onnx.py
@@ -645,10 +645,10 @@ def _impl_v11(cls, inputs, attr, params):
value = _op.take(inputs[2], _op.const(0))
else:
value = 0
attr["pad_value"] = value
pads_shape = infer_shape(pads)
dims = int(pads_shape[0] / 2)
pad_width_expr = _op.transpose(_op.reshape(pads, (2, dims)))

pad_mode = attr.get('mode', b'constant').decode('utf-8')
if pad_mode in ['constant', 'edge', 'reflect']:
attr['pad_mode'] = pad_mode
@@ -891,7 +891,7 @@ def _impl_v9(cls, inputs, attr, params):
'Value {} in attribute "mode" of operator Upsample is not valid.'.format(mode))

attr = {'method': method}
# in 3d case, we use the purely static op

if dims == 5:
if isinstance(scales, Call):
scale_h = _op.take(scales, _op.const(-2))
@@ -918,6 +918,7 @@ def _impl_v9(cls, inputs, attr, params):
return _op.nn.upsampling(inputs[0], scale_h, scale_w, layout=layout, method=method, align_corners=True)



class Shape(OnnxOpConverter):
""" Operator converter for Shape.
"""
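For reference, the Pad handling above turns ONNX's flat pads tensor [x1_begin, x2_begin, ..., x1_end, x2_end, ...] into the (before, after) pairs that Relay's pad op expects, by reshaping to (2, dims) and transposing. A small numpy illustration of that reshuffle (not part of the diff):

import numpy as np

pads = np.array([1, 2, 3, 4])              # 2-D input: begins [1, 2], ends [3, 4]
dims = pads.shape[0] // 2
pad_width = np.transpose(pads.reshape(2, dims))
print(pad_width.tolist())                  # [[1, 3], [2, 4]] -> ((1, 3), (2, 4))
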
34 changes: 14 additions & 20 deletions python/tvm/relay/op/dyn/image/_image.py
@@ -40,37 +40,31 @@ def compute_resize(attrs, inputs, out_type):

reg.register_injective_schedule("dyn.image.resize")


@script
def _NCHW_resize_shape_func(dshape, size, ndim):
def _resize_shape_func(dshape, size, ndim, height_axis, width_axis):
out = output_tensor((ndim, ), "int64")
for i in const_range(ndim):
out[i] = int64(dshape[i])
out[2] = int64(size[0])
out[3] = int64(size[1])
out[height_axis] = int64(size[0])
out[width_axis] = int64(size[1])
return out


@script
def _NHWC_resize_shape_func(dshape, size, ndim):
out = output_tensor((ndim, ), "int64")
for i in const_range(ndim):
out[i] = int64(dshape[i])
out[1] = int64(size[0])
out[2] = int64(size[1])
return out


@reg.register_shape_func("dyn.image.resize", True)
def resize_shape_func(attrs, inputs, _):
"""
Shape function for dyn.image.resize op.
"""
layout = attrs.layout
if layout == 'NHWC':
out = [_NHWC_resize_shape_func(inputs[0].shape, inputs[1], convert(len(inputs[0].shape)))]
elif (layout == 'NCHW') or nchw_pack_layout(layout) or nchw_xc_layout(layout):
out = [_NCHW_resize_shape_func(inputs[0].shape, inputs[1], convert(len(inputs[0].shape)))]
if nchw_pack_layout(layout) or nchw_xc_layout(layout):
out = [_resize_shape_func(inputs[0].shape, inputs[1], convert(len(inputs[0].shape)),
convert(2), convert(3))]
else:
raise ValueError("Resize Unsupported Layout", layout)
height_axis = width_axis = 1
for i, letter in enumerate(layout):
if letter == "H":
height_axis = i
if letter == "W":
width_axis = i
out = [_resize_shape_func(inputs[0].shape, inputs[1], convert(len(inputs[0].shape)),
convert(height_axis), convert(width_axis))]
return out
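
The two layout-specific shape functions are folded into one: the H and W axes are now located by scanning the layout string, which also covers packed layouts such as NCHW4c without a dedicated kernel. A standalone illustration of the lookup (hypothetical helper mirroring the loop above):

def find_hw_axes(layout):
    # Default to axis 1, then pick up the positions of 'H' and 'W' in the layout.
    height_axis = width_axis = 1
    for i, letter in enumerate(layout):
        if letter == "H":
            height_axis = i
        if letter == "W":
            width_axis = i
    return height_axis, width_axis

print(find_hw_axes("NHWC"))    # (1, 2)
print(find_hw_axes("NCHW"))    # (2, 3)
print(find_hw_axes("NCHW4c"))  # (2, 3)
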
16 changes: 9 additions & 7 deletions python/tvm/relay/prelude.py
@@ -17,6 +17,7 @@
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
"""A prelude containing useful global functions and ADT definitions."""
from tvm.ir import IRModule, TypeCall
from tvm.relay.transform import ToANormalFormExpr

from .ty import GlobalTypeVar, TensorType, Any, scalar_type
from .expr import Var, GlobalVar, If, const
@@ -204,7 +205,6 @@ def define_tensor_concatenate(self):
self.prelude.mod[concat_var] = \
Function([x, y], Match(x, [case], False), tensor_type_var(), [])


def define_tensor_expand_dims(self):
"""Defines a function to grow a tensor_t's rank by adding one dimension in front
of the original tensor_t.
@@ -511,8 +511,9 @@ def define_tensor_array_stack(self):
self.prelude.hd(tensor_array_expand_dims),
self.prelude.tl(tensor_array_expand_dims))
output_tensor_type_var, _ = self._get_adt_by_shape(output_shape)
self.prelude.mod[stack_var] = Function([tensor_array], tensors,
output_tensor_type_var(), [])
self.prelude.mod[stack_var] = \
Function([tensor_array], tensors,
output_tensor_type_var(), [])

def define_tensor_array_gather(self):
"""Defines a function to return the selected values in a tensor array as tensor_t.
@@ -809,7 +810,7 @@ def define_tensor_concat(self):
tensor4_var(op.concatenate([t41, t42], axis=0)))],
False))
# op.concatenate does not support tensor with rank higher than 4
self.prelude.mod[concat_var] =\
self.prelude.mod[concat_var] = \
Function([x, y], Match(x, [tensor1_case,
tensor2_case,
tensor3_case,
@@ -1167,7 +1168,7 @@ def define_tensor_array_gather(self):
current = Var("current", scalar_type('int32'))
limit = Var("limit", scalar_type('int32'))
indices_ = Var('indices_', TensorType([Any()], 'int32'))
helper_body =\
helper_body = \
If(equal(current, const(0)),
stack_var(accu),
helper_var(
@@ -1187,7 +1188,7 @@ def define_tensor_array_gather(self):
indices_shape = op.shape_of(indices)
limit = op.take(indices_shape, const(0))
body = helper_var(tensor_array, self.prelude.nil(), limit, limit, indices)
self.prelude.mod[gather_var] =\
self.prelude.mod[gather_var] = \
Function([tensor_array, indices], body, tensor_type_var(), [])

def define_tensor_array_stack(self):
Expand All @@ -1205,7 +1206,8 @@ def define_tensor_array_stack(self):
tensors = self.prelude.foldl(concat_var,
self.prelude.hd(tensor_array_expand_dims),
self.prelude.tl(tensor_array_expand_dims))
self.prelude.mod[stack_var] = Function([tensor_array], tensors, tensor_type_var(), [])
self.prelude.mod[stack_var] = \
Function([tensor_array], ToANormalFormExpr(tensors), tensor_type_var(), [])

def register(self):
"""Register all tensor array ops in Prelude"""
17 changes: 16 additions & 1 deletion python/tvm/relay/transform/transform.py
@@ -510,11 +510,26 @@ def ToANormalForm():
Returns
-------
ret: Union[tvm.transform.Pass, tvm.relay.Expr]
ret : Union[tvm.transform.Pass, tvm.relay.Expr]
The registered pass that transforms an expression into A Normal Form.
"""
return _ffi_api.ToANormalForm()

def ToANormalFormExpr(e):
"""ToANormalForm, but on expression level.
Parameters
----------
e : Expr
The graph expression.
Returns
-------
ret : Expr
The transformed expression.
"""
return _ffi_api.ToANormalFormExpr(e)

def ToBasicBlockNormalForm():
"""Turn an expression to Basic Block Normal Form.
We define a block as a group of expressions implied by the scope structure.
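ToANormalFormExpr is the Python wrapper for the expression-level ToANormalForm(const Expr&) declared in include/tvm/relay/transform.h above; the Prelude uses it to normalize the tensor-array stack body. A minimal usage sketch (assumes only the API added in this diff plus standard Relay builders):

import tvm
from tvm import relay
from tvm.relay.transform import ToANormalFormExpr

x = relay.var("x", shape=(2, 2))
y = relay.add(x, x)
expr = relay.multiply(y, y)      # graph form: `y` is shared rather than let-bound
anf = ToANormalFormExpr(expr)    # shared values become explicit let-bindings
print(anf)
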
64 changes: 49 additions & 15 deletions python/tvm/topi/arm_cpu/conv2d_gemm.py
@@ -20,9 +20,10 @@
import tvm
from tvm import te
from tvm.topi import nn
from ..util import get_const_tuple
from tvm.autotvm.task.space import AnnotateEntity, ReorderEntity, OtherOptionEntity
from ..util import get_const_tuple, get_const_int
from ..nn.util import get_pad_tuple
from .tensor_intrin import gemv_quantized, gemv_quantized_impl
from .tensor_intrin import gemm_quantized, gemm_quantized_impl

def is_aarch64_arm():
""" Checks whether we are compiling for an AArch64 target. """
@@ -38,15 +39,15 @@ def compute_conv2d_gemm_without_weight_transform(cfg,
executing GEMM and transforming the output back"""
batches, IH, IW, IC = get_const_tuple(data.shape)

KH, KW = kernel_size
OC = output_channels
KH, KW = get_const_tuple(kernel_size)
OC = get_const_int(output_channels)

K_AREA = KH * KW

if isinstance(dilation, int):
dilation_h = dilation_w = dilation
else:
dilation_h, dilation_w = dilation
dilation_h, dilation_w = get_const_tuple(dilation)

dilated_kernel_h = (KH - 1) * dilation_h + 1
dilated_kernel_w = (KW - 1) * dilation_w + 1
@@ -126,6 +127,28 @@ def compute_conv2d_gemm_without_weight_transform(cfg,
out = te.compute(out_shape, lambda b, x, y, z: C(b, y + OW * x, z),
name='conv2d_gemm_output')


# Configuration space
x, y = cfg.axis(M_padded // 4), cfg.axis(K_padded // 16)
cfg.define_reorder('reorder_gemm',
[x, y],
policy='candidate',
candidate=[[x, y],
[y, x]])

outer_loop, inner_loop = cfg.axis(4), cfg.axis(16)
cfg.define_annotate("A_interleaved_unroll_vec",
[outer_loop, inner_loop],
policy="try_unroll_vec")
cfg.define_knob('gemm_quantized_unroll', [True, False])
cfg.define_knob('gemm_quantized_interleave', [True, False])

# Fallback configuration
if cfg.is_fallback:
cfg['reorder_gemm'] = ReorderEntity([0, 1])
cfg['A_interleaved_unroll_vec'] = AnnotateEntity(["unroll", "vec"])
cfg['gemm_quantized_unroll'] = OtherOptionEntity(False)
cfg['gemm_quantized_interleave'] = OtherOptionEntity(True)
return out

# Schedules
@@ -150,33 +173,44 @@ def schedule_conv2d_gemm(cfg, s, out, final_out):
n_outer, n_inner = s[data_im2col].split(n, 16)
s[data_im2col].unroll(n_outer)
s[data_im2col].vectorize(n_inner)
b_m_fused = s[data_im2col].fuse(b, m)
s[data_im2col].parallel(b_m_fused)
else:
s[data_im2col].compute_inline()

# Computation(through tensorize)
b, xo, yo, xi, yi = C_interleaved.op.axis
s[C_interleaved].reorder(xo, yo, yi, xi)
s[C_interleaved].parallel(xo)
s[A_interleaved].compute_at(s[C_interleaved], xo)
s[A_interleaved].vectorize(A_interleaved.op.axis[4])
outer_gemm, inner_gemm = cfg['reorder_gemm'].apply(s, C_interleaved, [xo, yo])
s[C_interleaved].reorder(yi, xi)
b_outer_gemm_fused = s[C_interleaved].fuse(b, outer_gemm)
s[C_interleaved].parallel(b_outer_gemm_fused)
s[A_interleaved].compute_at(s[C_interleaved], b_outer_gemm_fused)
_, _, _, outer_A_interleaved, inner_A_interleaved = A_interleaved.op.axis
cfg['A_interleaved_unroll_vec'].apply(s,
A_interleaved,
[outer_A_interleaved, inner_A_interleaved])

in_type = A_interleaved.dtype
out_type = C.dtype
if is_aarch64_arm() and out_type == 'int32':
K = A_interleaved_input.shape[2]
_, M, N = C.shape
assert in_type in ['int8', 'uint8'], "Only int8 and uint8 gemm are supported"

gem_v_dotprod = gemv_quantized(M, N, K, in_type, out_type)
s[C_interleaved].pragma(xo, "import_llvm", gemv_quantized_impl(M, N, in_type))
s[C_interleaved].tensorize(yi, gem_v_dotprod)
unroll = cfg['gemm_quantized_unroll'].val
interleave = cfg['gemm_quantized_interleave'].val
gemm = gemm_quantized(M, N, K, unroll, interleave, in_type, out_type)
s[C_interleaved].pragma(b_outer_gemm_fused, "import_llvm", gemm_quantized_impl(M,
N,
K,
unroll,
interleave,
in_type))
s[C_interleaved].tensorize(yi, gemm)

# Output transform
if out != final_out:
n, h, w, c = out.op.axis
_, inner = s[out].split(c, 4)
s[C].compute_at(s[out], inner)
s[out].vectorize(inner)


return s
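
The tuning knobs added to compute_conv2d_gemm_without_weight_transform follow the usual autotvm pattern: the compute declares a configuration space, and the schedule reads whatever entity the tuner (or the fallback above) selected. A rough sketch of that space on its own, outside any task, assuming the public ConfigSpace API; the axis lengths are placeholders for M_padded // 4 and K_padded // 16:

from tvm.autotvm.task.space import ConfigSpace

cfg = ConfigSpace()
x, y = cfg.axis(4), cfg.axis(16)                     # stand-ins for M_padded // 4, K_padded // 16
cfg.define_reorder('reorder_gemm', [x, y],
                   policy='candidate', candidate=[[x, y], [y, x]])
outer_loop, inner_loop = cfg.axis(4), cfg.axis(16)
cfg.define_annotate('A_interleaved_unroll_vec',
                    [outer_loop, inner_loop], policy='try_unroll_vec')
cfg.define_knob('gemm_quantized_unroll', [True, False])
cfg.define_knob('gemm_quantized_interleave', [True, False])
print(len(cfg))                                      # number of candidate configurations
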
2 changes: 2 additions & 0 deletions python/tvm/topi/arm_cpu/conv2d_int8.py
@@ -140,8 +140,10 @@ def schedule_conv2d_NHWC_quantized(cfg, outs):
# Vectorize the output and then inline all the rest
out = outs[0]
n, h, w, c = out.op.axis
n_h_fused = s[out].fuse(n, h)
outer, inner = s[out].split(c, 4)
s[out].vectorize(inner)
s[out].parallel(n_h_fused)

def _callback(op):
"""Traverse operators from computation graph"""
(Diffs for the remaining 15 changed files are not shown in this view.)
