Skip to content

Commit

Permalink
[QNN] Support CallNode inputs in qnn.concatenate (#5360)
Browse files Browse the repository at this point in the history
* [QNN] Support CallNode inputs in qnn.concatenate

Currently, qnn.concatenate assumes that its 1st arg
(data) is a TupleNode. This may not necessarily be true
if the input is a CallNode which returns a value of
tuple-type. This patch handles the CallNode case by
inserting TupleGetItemNodes.

* Fix lint

* Add test

Change-Id: I40b55517b8b1dabbeca89337f80c0c8e62e34981

* Use isinstance

Change-Id: I731a231113c5214528373ef52b603a9f05ec502a

* isinstance fix

Change-Id: Ib3495532f6e4feb5aae3d3096cedd4dc4676cdb4

* Use elif/else if

Change-Id: Id8123ea2dd9ce3d8267609de7b5602bb84b084fb

* Fix lint

Change-Id: Ib6899bb22260575aa3f5d8b51b5d2a0277ee2b10

* Lint fix

Change-Id: I56cf1930315344e42d956818a6c68e80836ae786

* Spaces

Change-Id: I3edab192e32bafa9ffdc915315791c63279d85dc
  • Loading branch information
mbaret authored May 5, 2020
1 parent 70a5902 commit 32a094c
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 9 deletions.
13 changes: 7 additions & 6 deletions python/tvm/relay/qnn/op/qnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
"""QNN dialect operators."""

from __future__ import absolute_import as _abs
from tvm.relay.expr import Tuple
from tvm.relay.expr import Tuple, TupleWrapper
from tvm.relay.op.nn.util import get_pad_tuple2d
from . import _make

Expand Down Expand Up @@ -156,7 +156,7 @@ def concatenate(data,
Parameters
----------
data : Union(List[relay.Expr], Tuple[relay.Expr])
data : Union(List[relay.Expr], Tuple[relay.Expr], TupleWrapper[relay.Expr])
The list of quantized tensors.
input_scales : List[relay.Expr]
Expand All @@ -180,15 +180,16 @@ def concatenate(data,
The concatenated quantized tensor.
"""

data = list(data)
if not data:
raise ValueError("relay.concatenate requires data to be non-empty.")
if isinstance(data, (list, tuple)):
data = Tuple(data)
elif isinstance(data, TupleWrapper):
data = data.tuple_value
if not isinstance(axis, int):
raise ValueError("For now, we only support integer axis")
input_scales = list(input_scales)
input_zero_points = list(input_zero_points)

return _make.concatenate(Tuple(data),
return _make.concatenate(data,
Tuple(input_scales),
Tuple(input_zero_points),
output_scale,
Expand Down
14 changes: 11 additions & 3 deletions src/relay/qnn/op/concatenate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,16 @@ Expr ConcatenateQnnCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
// If the output qnn params do not match the input qnn params, we can call requantize on the input
// expr first, followed by a concatenate on the requantized input exprs.

auto tuple_data = data.as<TupleNode>();
CHECK(tuple_data != nullptr);
Array<Expr> tuple_exprs;
if (data->IsInstance<TupleNode>()) {
tuple_exprs = data.as<TupleNode>()->fields;
} else if (data->IsInstance<CallNode>()) { // if the data is a CallNode, use TupleGetItems
auto call = Downcast<Call>(data);
for (size_t i = 0; i < tuple_type->fields.size(); i++) {
tuple_exprs.push_back(TupleGetItem(call, i));
}
}
CHECK(!tuple_exprs.empty());

auto tuple_input_scales = input_scales.as<TupleNode>();
CHECK(tuple_input_scales != nullptr);
Expand All @@ -160,7 +168,7 @@ Expr ConcatenateQnnCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,

int idx = 0;
Array<Expr> requantized_exprs;
for (auto quantized_expr : tuple_data->fields) {
for (auto quantized_expr : tuple_exprs) {
// Get the input scale for the idx quantized input tensor.
auto input_scale = tuple_input_scales->fields[idx];

Expand Down
25 changes: 25 additions & 0 deletions tests/python/relay/test_op_qnn_concatenate.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,32 @@ def test_same_i_qnn_params():
op_res = intrp.evaluate(func)(x_data, y_data)
np.testing.assert_equal(op_res.asnumpy(), golden_output)

def test_call_input():
# This tests the case where the input to concatenate is not explicitly a
# tuple node but is instead a call node.
x_data = np.ones(shape=(64,)).astype('uint8')

x = relay.var("x", shape=(64,), dtype='uint8')
x_scale = relay.const(1, 'float32')
y_scale = relay.const(1, 'float32')
x_zero_point = relay.const(0, 'int32')
y_zero_point = relay.const(0, 'int32')

tup = relay.split(x, 2, axis=0)
z = relay.qnn.op.concatenate(tup,
input_scales=(x_scale, y_scale),
input_zero_points=(x_zero_point, y_zero_point),
output_scale=y_scale,
output_zero_point=relay.const(0, 'int32'),
axis=0)
func = relay.Function([x], z)

intrp = relay.create_executor("graph", ctx=tvm.cpu(0), target="llvm")
op_res = intrp.evaluate(func)(x_data)
np.testing.assert_equal(op_res.asnumpy(), x_data)

if __name__ == '__main__':
test_call_input()
test_same_io_qnn_params()
test_different_io_qnn_params()
test_few_same_io_qnn_params()
Expand Down

0 comments on commit 32a094c

Please sign in to comment.