Skip to content

Commit

Permalink
Frontend: add onnx GlobalLpPool op (apache#8845)
Browse files Browse the repository at this point in the history
* Frontend: add onnx GlobalLpPool op

* update

* fix for test

Co-authored-by: xp224797 <[email protected]>
  • Loading branch information
2 people authored and ylc committed Sep 29, 2021
1 parent 26b23a0 commit 272b58e
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 0 deletions.
13 changes: 13 additions & 0 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -1008,6 +1008,18 @@ def _impl_v1(cls, inputs, attr, params):
return _op.power(out, reci_p)


class GlobalLpPool(OnnxOpConverter):
"""Operator converter for GlobalLpPool."""

@classmethod
def _impl_v1(cls, inputs, attr, params):
# TODO: GlobalLpPool does not yet support dynamic shapes
in_shape = infer_shape(inputs[0])
attr["kernel_shape"] = in_shape[2:]

return LpPool._impl_v1(inputs, attr, params)


class Mul(Elemwise):
"""Operator converter for Multiply."""

Expand Down Expand Up @@ -4083,6 +4095,7 @@ def _get_convert_map(opset):
# defs/nn
"AveragePool": AveragePool.get_converter(opset),
"LpPool": LpPool.get_converter(opset),
"GlobalLpPool": GlobalLpPool.get_converter(opset),
"MaxPool": MaxPool.get_converter(opset),
"MaxUnpool": MaxUnpool.get_converter(opset),
"Conv": Conv.get_converter(opset),
Expand Down
44 changes: 44 additions & 0 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -3470,6 +3470,49 @@ def verify_lppool(x_shape, kernel_shape, p, strides, pads, out_shape, auto_pad="
)


def verify_global_lppool(x_shape, p, out_shape, target, dev):
pool_node = helper.make_node(
"GlobalLpPool",
inputs=["x"],
outputs=["y"],
p=p,
)

graph = helper.make_graph(
[pool_node],
"global_lppool_test",
inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, list(x_shape))],
outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(out_shape))],
)

model = helper.make_model(graph, producer_name="global_lppool_test")
verify_with_ort(
model, [x_shape], out_shape, use_vm=True, convert_to_static=True, target=target, dev=dev
)


@tvm.testing.parametrize_targets
def test_global_lppool(target, dev):

# LpPool1D
verify_global_lppool(x_shape=[1, 15, 16], p=2, out_shape=[1, 15, 1], target=target, dev=dev)

# LpPool2D
verify_global_lppool(
x_shape=[1, 15, 32, 32], p=2, out_shape=[1, 15, 1, 1], target=target, dev=dev
)

# LpPool2D
verify_global_lppool(
x_shape=[1, 15, 32, 32], p=3, out_shape=[1, 15, 1, 1], target=target, dev=dev
)

# LpPool3D
verify_global_lppool(
x_shape=[1, 15, 3, 32, 32], p=2, out_shape=[1, 15, 1, 1, 1], target=target, dev=dev
)


def verify_rnn(
seq_length,
batch_size,
Expand Down Expand Up @@ -5826,3 +5869,4 @@ def repeat(N, D):
test_random_uniform()
test_convinteger()
test_batch_matmul()
test_global_lppool()

0 comments on commit 272b58e

Please sign in to comment.