From 18998f2c2709150a28f08d8df34f12825b73542a Mon Sep 17 00:00:00 2001
From: mbaret <55580676+mbaret@users.noreply.github.com>
Date: Tue, 5 May 2020 22:38:40 +0100
Subject: [PATCH] [QNN] Support CallNode inputs in qnn.concatenate (#5360)

* [QNN] Support CallNode inputs in qnn.concatenate

Currently, qnn.concatenate assumes that its first argument (data) is a
TupleNode. This is not necessarily true: the input may instead be a
CallNode that returns a tuple-typed value. This patch handles the
CallNode case by inserting TupleGetItemNodes.

* Fix lint

* Add test

Change-Id: I40b55517b8b1dabbeca89337f80c0c8e62e34981

* Use isinstance

Change-Id: I731a231113c5214528373ef52b603a9f05ec502a

* isinstance fix

Change-Id: Ib3495532f6e4feb5aae3d3096cedd4dc4676cdb4

* Use elif/else if

Change-Id: Id8123ea2dd9ce3d8267609de7b5602bb84b084fb

* Fix lint

Change-Id: Ib6899bb22260575aa3f5d8b51b5d2a0277ee2b10

* Lint fix

Change-Id: I56cf1930315344e42d956818a6c68e80836ae786

* Spaces

Change-Id: I3edab192e32bafa9ffdc915315791c63279d85dc
---
 python/tvm/relay/qnn/op/qnn.py                | 13 +++++----
 src/relay/qnn/op/concatenate.cc               | 14 ++++++++---
 tests/python/relay/test_op_qnn_concatenate.py | 25 +++++++++++++++++++
 3 files changed, 43 insertions(+), 9 deletions(-)

diff --git a/python/tvm/relay/qnn/op/qnn.py b/python/tvm/relay/qnn/op/qnn.py
index 5c1baef4db94..5a3106d1e787 100644
--- a/python/tvm/relay/qnn/op/qnn.py
+++ b/python/tvm/relay/qnn/op/qnn.py
@@ -18,7 +18,7 @@
 """QNN dialect operators."""
 
 from __future__ import absolute_import as _abs
-from tvm.relay.expr import Tuple
+from tvm.relay.expr import Tuple, TupleWrapper
 from tvm.relay.op.nn.util import get_pad_tuple2d
 from . import _make
 
@@ -156,7 +156,7 @@ def concatenate(data,
 
     Parameters
     ----------
-    data : Union(List[relay.Expr], Tuple[relay.Expr])
+    data : Union(List[relay.Expr], Tuple[relay.Expr], TupleWrapper[relay.Expr])
         The list of quantized tensors.
 
     input_scales : List[relay.Expr]
@@ -180,15 +180,16 @@ def concatenate(data,
         The concatenated quantized tensor.
     """
 
-    data = list(data)
-    if not data:
-        raise ValueError("relay.concatenate requires data to be non-empty.")
+    if isinstance(data, (list, tuple)):
+        data = Tuple(data)
+    elif isinstance(data, TupleWrapper):
+        data = data.tuple_value
     if not isinstance(axis, int):
         raise ValueError("For now, we only support integer axis")
     input_scales = list(input_scales)
     input_zero_points = list(input_zero_points)
 
-    return _make.concatenate(Tuple(data),
+    return _make.concatenate(data,
                              Tuple(input_scales),
                              Tuple(input_zero_points),
                              output_scale,
diff --git a/src/relay/qnn/op/concatenate.cc b/src/relay/qnn/op/concatenate.cc
index 650dcb962d44..338e7a1ff6ad 100644
--- a/src/relay/qnn/op/concatenate.cc
+++ b/src/relay/qnn/op/concatenate.cc
@@ -149,8 +149,16 @@ Expr ConcatenateQnnCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
   // If the output qnn params do not match the input qnn params, we can call requantize on the
   // input expr first, followed by a concatenate on the requantized input exprs.
 
-  auto tuple_data = data.as<TupleNode>();
-  CHECK(tuple_data != nullptr);
+  Array<Expr> tuple_exprs;
+  if (data->IsInstance<TupleNode>()) {
+    tuple_exprs = data.as<TupleNode>()->fields;
+  } else if (data->IsInstance<CallNode>()) {  // if the data is a CallNode, use TupleGetItems
+    auto call = Downcast<Call>(data);
+    for (size_t i = 0; i < tuple_type->fields.size(); i++) {
+      tuple_exprs.push_back(TupleGetItem(call, i));
+    }
+  }
+  CHECK(!tuple_exprs.empty());
 
   auto tuple_input_scales = input_scales.as<TupleNode>();
   CHECK(tuple_input_scales != nullptr);
@@ -160,7 +168,7 @@ Expr ConcatenateQnnCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
 
   int idx = 0;
   Array<Expr> requantized_exprs;
-  for (auto quantized_expr : tuple_data->fields) {
+  for (auto quantized_expr : tuple_exprs) {
     // Get the input scale for the idx quantized input tensor.
     auto input_scale = tuple_input_scales->fields[idx];
 
diff --git a/tests/python/relay/test_op_qnn_concatenate.py b/tests/python/relay/test_op_qnn_concatenate.py
index 03ab9eeb1321..fb60e9805206 100644
--- a/tests/python/relay/test_op_qnn_concatenate.py
+++ b/tests/python/relay/test_op_qnn_concatenate.py
@@ -144,7 +144,32 @@ def test_same_i_qnn_params():
     op_res = intrp.evaluate(func)(x_data, y_data)
     np.testing.assert_equal(op_res.asnumpy(), golden_output)
 
+def test_call_input():
+    # This tests the case where the input to concatenate is not explicitly a
+    # tuple node but is instead a call node.
+    x_data = np.ones(shape=(64,)).astype('uint8')
+
+    x = relay.var("x", shape=(64,), dtype='uint8')
+    x_scale = relay.const(1, 'float32')
+    y_scale = relay.const(1, 'float32')
+    x_zero_point = relay.const(0, 'int32')
+    y_zero_point = relay.const(0, 'int32')
+
+    tup = relay.split(x, 2, axis=0)
+    z = relay.qnn.op.concatenate(tup,
+                                 input_scales=(x_scale, y_scale),
+                                 input_zero_points=(x_zero_point, y_zero_point),
+                                 output_scale=y_scale,
+                                 output_zero_point=relay.const(0, 'int32'),
+                                 axis=0)
+    func = relay.Function([x], z)
+
+    intrp = relay.create_executor("graph", ctx=tvm.cpu(0), target="llvm")
+    op_res = intrp.evaluate(func)(x_data)
+    np.testing.assert_equal(op_res.asnumpy(), x_data)
+
 if __name__ == '__main__':
+    test_call_input()
     test_same_io_qnn_params()
     test_different_io_qnn_params()
     test_few_same_io_qnn_params()
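
Note (not part of the patch): the sketch below illustrates in Python the
transformation the C++ canonicalization now performs when `data` is a
tuple-typed CallNode, i.e. accessing each element through an explicit
TupleGetItem node rather than assuming a TupleNode. It is a minimal sketch
assuming the Relay Python API of this era (relay.split returning a
TupleWrapper, TupleWrapper.astuple, relay.TupleGetItem); variable names are
illustrative only.

    import tvm
    from tvm import relay

    x = relay.var("x", shape=(64,), dtype="uint8")
    # relay.split returns a TupleWrapper; astuple() exposes the underlying
    # expression, here a CallNode whose checked type is a 2-field TupleType.
    call = relay.split(x, 2, axis=0).astuple()

    # Equivalent of the new C++ loop over tuple_type->fields: wrap each
    # field in a TupleGetItem node so each tensor can be requantized
    # individually before the concatenate.
    tuple_exprs = [relay.TupleGetItem(call, i) for i in range(2)]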