diff --git a/python/tvm/relay/qnn/op/qnn.py b/python/tvm/relay/qnn/op/qnn.py
index 5c1baef4db94..5a3106d1e787 100644
--- a/python/tvm/relay/qnn/op/qnn.py
+++ b/python/tvm/relay/qnn/op/qnn.py
@@ -18,7 +18,7 @@
 """QNN dialect operators."""
 
 from __future__ import absolute_import as _abs
-from tvm.relay.expr import Tuple
+from tvm.relay.expr import Tuple, TupleWrapper
 from tvm.relay.op.nn.util import get_pad_tuple2d
 from . import _make
 
@@ -156,7 +156,7 @@ def concatenate(data,
 
     Parameters
     ----------
-    data : Union(List[relay.Expr], Tuple[relay.Expr])
+    data : Union(List[relay.Expr], Tuple[relay.Expr], TupleWrapper[relay.Expr])
        The list of quantized tensors.
 
    input_scales : List[relay.Expr]
@@ -180,15 +180,16 @@ def concatenate(data,
        The concatenated quantized tensor.
    """
 
-    data = list(data)
-    if not data:
-        raise ValueError("relay.concatenate requires data to be non-empty.")
+    if isinstance(data, (list, tuple)):
+        data = Tuple(data)
+    elif isinstance(data, TupleWrapper):
+        data = data.tuple_value
     if not isinstance(axis, int):
         raise ValueError("For now, we only support integer axis")
     input_scales = list(input_scales)
     input_zero_points = list(input_zero_points)
 
-    return _make.concatenate(Tuple(data),
+    return _make.concatenate(data,
                              Tuple(input_scales),
                              Tuple(input_zero_points),
                              output_scale,
diff --git a/src/relay/qnn/op/concatenate.cc b/src/relay/qnn/op/concatenate.cc
index 650dcb962d44..338e7a1ff6ad 100644
--- a/src/relay/qnn/op/concatenate.cc
+++ b/src/relay/qnn/op/concatenate.cc
@@ -149,8 +149,16 @@ Expr ConcatenateQnnCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
   // If the output qnn params do not match the input qnn params, we can call requantize on the input
   // expr first, followed by a concatenate on the requantized input exprs.
 
-  auto tuple_data = data.as<TupleNode>();
-  CHECK(tuple_data != nullptr);
+  Array<Expr> tuple_exprs;
+  if (data->IsInstance<TupleNode>()) {
+    tuple_exprs = data.as<TupleNode>()->fields;
+  } else if (data->IsInstance<CallNode>()) {  // if the data is a CallNode, use TupleGetItems
+    auto call = Downcast<Call>(data);
+    for (size_t i = 0; i < tuple_type->fields.size(); i++) {
+      tuple_exprs.push_back(TupleGetItem(call, i));
+    }
+  }
+  CHECK(!tuple_exprs.empty());
 
   auto tuple_input_scales = input_scales.as<TupleNode>();
   CHECK(tuple_input_scales != nullptr);
@@ -160,7 +168,7 @@ Expr ConcatenateQnnCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
 
   int idx = 0;
   Array<Expr> requantized_exprs;
-  for (auto quantized_expr : tuple_data->fields) {
+  for (auto quantized_expr : tuple_exprs) {
     // Get the input scale for the idx quantized input tensor.
     auto input_scale = tuple_input_scales->fields[idx];
 
diff --git a/tests/python/relay/test_op_qnn_concatenate.py b/tests/python/relay/test_op_qnn_concatenate.py
index 03ab9eeb1321..fb60e9805206 100644
--- a/tests/python/relay/test_op_qnn_concatenate.py
+++ b/tests/python/relay/test_op_qnn_concatenate.py
@@ -144,7 +144,32 @@ def test_same_i_qnn_params():
     op_res = intrp.evaluate(func)(x_data, y_data)
     np.testing.assert_equal(op_res.asnumpy(), golden_output)
 
+def test_call_input():
+    # This tests the case where the input to concatenate is not explicitly a
+    # tuple node but is instead a call node.
+    x_data = np.ones(shape=(64,)).astype('uint8')
+
+    x = relay.var("x", shape=(64,), dtype='uint8')
+    x_scale = relay.const(1, 'float32')
+    y_scale = relay.const(1, 'float32')
+    x_zero_point = relay.const(0, 'int32')
+    y_zero_point = relay.const(0, 'int32')
+
+    tup = relay.split(x, 2, axis=0)
+    z = relay.qnn.op.concatenate(tup,
+                                 input_scales=(x_scale, y_scale),
+                                 input_zero_points=(x_zero_point, y_zero_point),
+                                 output_scale=y_scale,
+                                 output_zero_point=relay.const(0, 'int32'),
+                                 axis=0)
+    func = relay.Function([x], z)
+
+    intrp = relay.create_executor("graph", ctx=tvm.cpu(0), target="llvm")
+    op_res = intrp.evaluate(func)(x_data)
+    np.testing.assert_equal(op_res.asnumpy(), x_data)
+
 if __name__ == '__main__':
+    test_call_input()
     test_same_io_qnn_params()
     test_different_io_qnn_params()
     test_few_same_io_qnn_params()