
Commit

[MXNET]broadcast and logical op support (apache#5461)
* [MXNET]broadcast and logical op support

* Review comment fixed
siju-samuel authored and trevor-m committed Jun 18, 2020
1 parent ae43022 commit cbb1607
Showing 2 changed files with 97 additions and 11 deletions.
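
For context, a minimal usage sketch of what this commit enables: an MXNet symbol built from one of the newly supported ops can now be imported through the Relay frontend. The shapes and variable names below are illustrative and assume a contemporary MXNet 1.x / TVM install; they are not part of the commit.

import mxnet as mx
from tvm import relay

a = mx.sym.var("a")
b = mx.sym.var("b")
# broadcast_logical_and is one of the ops this commit adds support for
sym = mx.sym.broadcast_logical_and(a, b)

shape_dict = {"a": (3, 4, 5), "b": (4, 5)}
mod, params = relay.frontend.from_mxnet(sym, shape_dict)
print(mod)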
37 changes: 36 additions & 1 deletion python/tvm/relay/frontend/mxnet.py
@@ -1712,6 +1712,33 @@ def _get_bias_requantize_scale(_inputs, _data_scale, _kernel_scale):
        res = _op.nn.relu(res)
    return res


def _mx_broadcast_to(inputs, attrs):
    data = inputs[0]
    tgt_shape = attrs.get_int_tuple("shape", [])

    return _op.broadcast_to(data, tgt_shape)


def _mx_logical_not(inputs, input_types):
    data = inputs[0]
    dtype = _infer_type(data).checked_type.dtype
    data = _op.cast(data, "bool") if dtype != "bool" else data

    return _op.cast(_op.logical_not(data), dtype)


def _mx_broadcast_logical(logical_op):
    def impl(inputs, input_types):
        lhs_type = _infer_type(inputs[0]).checked_type.dtype
        rhs_type = _infer_type(inputs[1]).checked_type.dtype
        lhs = _op.cast(inputs[0], "bool") if lhs_type != "bool" else inputs[0]
        rhs = _op.cast(inputs[1], "bool") if rhs_type != "bool" else inputs[1]

        return _op.cast(logical_op(lhs, rhs), lhs_type)
    return impl
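
The helper above follows a cast-to-bool, apply-op, cast-back pattern so the result keeps the operand dtype, matching how MXNet returns logical results in the input type. For reference, a standalone sketch of the same expression built directly against the Relay API (variable names and shapes here are illustrative, not part of the commit):

from tvm import relay

# float32 operands, as an MXNet graph would typically supply
lhs = relay.var("lhs", shape=(2, 2), dtype="float32")
rhs = relay.var("rhs", shape=(2, 2), dtype="float32")

# same steps as impl(): cast to bool, apply the logical op, cast back
out = relay.cast(
    relay.logical_or(relay.cast(lhs, "bool"), relay.cast(rhs, "bool")),
    "float32")
func = relay.Function([lhs, rhs], out)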


# Note: due to attribute conversion constraint
# ops in the identity set must be attribute free
_identity_list = [
@@ -1738,19 +1765,27 @@ def _get_bias_requantize_scale(_inputs, _data_scale, _kernel_scale):
"_copy" : _rename(_op.copy),
"relu" : _rename(_op.nn.relu),
"broadcast_add" : _rename(_op.add),
"broadcast_plus" : _rename(_op.add),
"broadcast_sub" : _rename(_op.subtract),
"broadcast_minus" : _rename(_op.subtract),
"broadcast_mul" : _rename(_op.multiply),
"broadcast_div" : _rename(_op.divide),
"broadcast_mod" : _rename(_op.mod),
"broadcast_maximum" : _rename(_op.maximum),
"broadcast_minimum" : _rename(_op.minimum),
"broadcast_power" : _rename(_op.power),
"arctan" : _rename(_op.atan),
"broadcast_equal" : _mx_compare(_op.equal, _rename),
"broadcast_not_equal" : _mx_compare(_op.not_equal, _rename),
"broadcast_greater" : _mx_compare(_op.greater, _rename),
"broadcast_greater_equal": _mx_compare(_op.greater_equal, _rename),
"broadcast_lesser" : _mx_compare(_op.less, _rename),
"broadcast_lesser_equal" : _mx_compare(_op.less_equal, _rename),
"broadcast_logical_or" : _mx_broadcast_logical(_op.logical_or),
"broadcast_logical_and" : _mx_broadcast_logical(_op.logical_and),
"broadcast_logical_xor" : _mx_broadcast_logical(_op.logical_xor),
"broadcast_to" : _mx_broadcast_to,
"logical_not" : _mx_logical_not,
"_equal" : _mx_compare(_op.equal, _rename),
"_not_equal" : _mx_compare(_op.not_equal, _rename),
"_greater" : _mx_compare(_op.greater, _rename),
@@ -1860,6 +1895,7 @@ def _get_bias_requantize_scale(_inputs, _data_scale, _kernel_scale):
"reverse" : _mx_reverse,
"squeeze" : _mx_squeeze,
"broadcast_axis": _mx_broadcast_axis,
"broadcast_axes": _mx_broadcast_axis,
"BlockGrad" : _mx_BlockGrad,
"shape_array" : _mx_shape_array,
"Embedding" : _mx_embedding,
@@ -1897,7 +1933,6 @@ def _get_bias_requantize_scale(_inputs, _data_scale, _kernel_scale):
    # List of missing operators that are present in NNVMv1
    # TODO(tvm-tvm): support all operators.
    #
    # "broadcast_to",
    # "contrib_fifo_buffer": _mx_contrib_fifo_buffer,
    "ring_buffer": _mx_contrib_fifo_buffer,
    # Qnn ops
71 changes: 61 additions & 10 deletions tests/python/frontend/mxnet/test_forward.py
@@ -301,11 +301,25 @@ def _mx_symbol(F, op_name, inputs):
    return op(*inputs)

def test_forward_broadcast_ops():
for op in ["broadcast_add", "broadcast_sub", "broadcast_mul",
"broadcast_div", "broadcast_mod", "broadcast_maximum",
"broadcast_minimum", "broadcast_equal", "broadcast_not_equal",
"broadcast_greater", "broadcast_greater_equal",
"broadcast_lesser", "broadcast_lesser_equal"]:
for op in ["broadcast_add",
"broadcast_plus",
"broadcast_sub",
"broadcast_minus",
"broadcast_mul",
"broadcast_div",
"broadcast_mod",
"broadcast_maximum",
"broadcast_minimum",
"broadcast_equal",
"broadcast_not_equal",
"broadcast_greater",
"broadcast_greater_equal",
"broadcast_lesser",
"broadcast_lesser_equal",
"broadcast_power",
"broadcast_logical_or",
"broadcast_logical_and",
"broadcast_logical_xor"]:
        a_shape = (3, 4, 5)
        b_shape = (4, 5)
        if op == "broadcast_mod":
@@ -462,16 +476,51 @@ def verify(shape, axis):
def test_forward_broadcast_axis():
    def verify(shape, axis, size):
        x_np = np.random.uniform(size=shape).astype("float32")
        ref_res = mx.nd.broadcast_axis(mx.nd.array(x_np), axis=axis, size=size)
        mx_sym = mx.sym.broadcast_axis(mx.sym.var("x"), axis=axis, size=size)
        mod, _ = relay.frontend.from_mxnet(mx_sym, {"x": shape})
        for op in ["broadcast_axis",
                   "broadcast_axes"]:
            mx_sym = _mx_symbol(mx.sym, op, [mx.sym.var('x'), axis, size])
            ref_res = _mx_symbol(mx.nd, op, [mx.nd.array(x_np), axis, size])
            mod, _ = relay.frontend.from_mxnet(mx_sym, {"x": shape})
            for target, ctx in ctx_list():
                for kind in ["graph", "debug"]:
                    intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
                    op_res = intrp.evaluate()(x_np)
                    tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())

    verify((1, 2, 1), 2, 3)
    verify((1, 2, 1), (0, 2), (2, 3))


def test_forward_broadcast_to():
    def verify(input_shape, shape):
        x_np = np.random.uniform(size=input_shape).astype("float32")
        ref_res = mx.nd.broadcast_to(mx.nd.array(x_np), shape=shape)
        mx_sym = mx.sym.broadcast_to(mx.sym.var("x"), shape=shape)
        mod, _ = relay.frontend.from_mxnet(mx_sym, {"x": input_shape})
        for target, ctx in ctx_list():
            for kind in ["graph", "debug"]:
                intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
                op_res = intrp.evaluate()(x_np)
                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())

    verify((1, 2, 3), (3, 2, 3))
    verify((4, 1, 32, 32), (4, 8, 32, 32))


def test_forward_logical_not():
    a_shape = (3, 4, 5)
    dtype = 'float32'
    a_np = np.random.uniform(size=a_shape).astype(dtype)
    mx_sym = mx.sym.logical_not(mx.sym.var('a'))
    ref_res = mx.nd.logical_not(mx.nd.array(a_np))
    shapes = {'a': a_shape}
    mod, _ = relay.frontend.from_mxnet(mx_sym, shapes, dtype)
    for target, ctx in ctx_list():
        for kind in ["graph", "debug"]:
            intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
            op_res = intrp.evaluate()(a_np)
            tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())


def test_forward_full():
    def verify(val, shape, dtype):
@@ -1061,6 +1110,8 @@ def verify(shape, blocksize=2):
    test_forward_where()
    test_forward_arange()
    test_forward_broadcast_ops()
    test_forward_broadcast_to()
    test_forward_logical_not()
    test_forward_elemwise_ops()
    test_forward_scalar_ops()
    test_forward_slice_like()
