Skip to content

Commit

Permalink
Add MXnNet parser for box_decode
Browse files Browse the repository at this point in the history
  • Loading branch information
Trevor Morris committed Jun 30, 2020
1 parent 9efe119 commit 1bac7ed
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 0 deletions.
37 changes: 37 additions & 0 deletions python/tvm/relay/frontend/mxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -978,6 +978,42 @@ def _mx_box_nms(inputs, attrs):
return nms_out


def _mx_box_decode(inputs, attrs):
std0 = relay.const(attrs.get_float('std0', 1), "float32")
std1 = relay.const(attrs.get_float('std1', 1), "float32")
std2 = relay.const(attrs.get_float('std2', 1), "float32")
std3 = relay.const(attrs.get_float('std3', 1), "float32")
clip = attrs.get_float('clip', -1)
in_format = attrs.get_str('format', 'corner')

anchors = inputs[1] # (1, N, 4) encoded in corner or center
a = _op.split(anchors, indices_or_sections=4, axis=-1)
# Convert to format "center".
if in_format == "corner":
a_width = a[2] - a[0]
a_height = a[3] - a[1]
a_x = a[0] + a_width * relay.const(0.5, "float32")
a_y = a[1] + a_height * relay.const(0.5, "float32")
else:
a_x, a_y, a_width, a_height = a
data = inputs[0] # (B, N, 4) predicted bbox offset
p = _op.split(data, indices_or_sections=4, axis=-1)
ox = p[0] * std0 * a_width + a_x
oy = p[1] * std1 * a_height + a_y
dw = p[2] * std2
dh = p[3] * std3
if clip > 0:
clip = relay.const(clip, "float32")
dw = _op.minimum(dw, clip)
dh = _op.minimum(dh, clip)
dw = _op.exp(dw)
dh = _op.exp(dh)
ow = dw * a_width * relay.const(0.5, "float32")
oh = dh * a_height * relay.const(0.5, "float32")
out = _op.concatenate([ox - ow, oy - oh, ox + ow, oy + oh], axis=-1)
return out


def _mx_l2_normalize(inputs, attrs):
new_attrs = {}
mode = attrs.get_str('mode', 'instance')
Expand Down Expand Up @@ -2220,6 +2256,7 @@ def impl(inputs, input_types):
"_contrib_Proposal" : _mx_proposal,
"_contrib_MultiProposal" : _mx_proposal,
"_contrib_box_nms" : _mx_box_nms,
"_contrib_box_decode" : _mx_box_decode,
"_contrib_DeformableConvolution" : _mx_deformable_convolution,
"_contrib_AdaptiveAvgPooling2D" : _mx_adaptive_avg_pooling,
"GridGenerator" : _mx_grid_generator,
Expand Down
23 changes: 23 additions & 0 deletions tests/python/frontend/mxnet/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1306,6 +1306,28 @@ def verify(batch, seq_length, num_heads, head_dim):
verify(3, 10, 6, 8)


def test_forward_box_decode():
def verify(data_shape, anchor_shape, stds=[1, 1, 1, 1], clip=-1, in_format="corner"):
dtype = "float32"
data = np.random.uniform(low=-2, high=2, size=data_shape).astype(dtype)
anchors = np.random.uniform(low=-2, high=2, size=anchor_shape).astype(dtype)
ref_res = mx.nd.contrib.box_decode(mx.nd.array(data), mx.nd.array(anchors), stds[0], stds[1], stds[2], stds[3], clip, in_format)
mx_sym = mx.sym.contrib.box_decode(mx.sym.var("data"), mx.sym.var("anchors"), stds[0], stds[1], stds[2], stds[3], clip, in_format)
shape_dict = {"data": data_shape, "anchors": anchor_shape}
mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict)
for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
op_res = intrp.evaluate()(data, anchors)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-3, atol=1e-5)

verify((1, 10, 4), (1, 10, 4))
verify((4, 10, 4), (1, 10, 4))
verify((1, 10, 4), (1, 10, 4), stds=[2, 3, 0.5, 1.5])
verify((1, 10, 4), (1, 10, 4), clip=1)
verify((1, 10, 4), (1, 10, 4), in_format="center")


if __name__ == '__main__':
test_forward_mlp()
test_forward_vgg()
Expand Down Expand Up @@ -1379,3 +1401,4 @@ def verify(batch, seq_length, num_heads, head_dim):
test_forward_arange_like()
test_forward_interleaved_matmul_selfatt_qk()
test_forward_interleaved_matmul_selfatt_valatt()
test_forward_box_decode()

0 comments on commit 1bac7ed

Please sign in to comment.