[QNN] Support CallNode inputs in qnn.concatenate #5360

Merged May 5, 2020 · 9 commits · Changes from 8 commits
13 changes: 7 additions & 6 deletions python/tvm/relay/qnn/op/qnn.py
```diff
@@ -18,7 +18,7 @@
 """QNN dialect operators."""
 
 from __future__ import absolute_import as _abs
-from tvm.relay.expr import Tuple
+from tvm.relay.expr import Tuple, TupleWrapper
 from tvm.relay.op.nn.util import get_pad_tuple2d
 from . import _make
 
@@ -156,7 +156,7 @@ def concatenate(data,
 
     Parameters
     ----------
-    data : Union(List[relay.Expr], Tuple[relay.Expr])
+    data : Union(List[relay.Expr], Tuple[relay.Expr], TupleWrapper[relay.Expr])
         The list of quantized tensors.
 
     input_scales : List[relay.Expr]
@@ -180,15 +180,16 @@ def concatenate(data,
        The concatenated quantized tensor.
    """
 
-    data = list(data)
-    if not data:
-        raise ValueError("relay.concatenate requires data to be non-empty.")
+    if isinstance(data, (list, tuple)):
+        data = Tuple(data)
+    elif isinstance(data, TupleWrapper):
+        data = data.tuple_value
     if not isinstance(axis, int):
         raise ValueError("For now, we only support integer axis")
     input_scales = list(input_scales)
     input_zero_points = list(input_zero_points)
 
-    return _make.concatenate(Tuple(data),
+    return _make.concatenate(data,
                              Tuple(input_scales),
                              Tuple(input_zero_points),
                              output_scale,
```
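With this change, `qnn.concatenate` normalizes every accepted input form to a Relay `Tuple` before calling into `_make.concatenate`: a Python list or tuple is wrapped in `Tuple(...)`, while a `TupleWrapper` (e.g. the result of `relay.split`) is unwrapped to its underlying expression, which may be a `CallNode`. A minimal sketch of the two call styles; the shapes and quantization parameters here are illustrative assumptions, not taken from the PR:

```python
from tvm import relay

x = relay.var("x", shape=(4,), dtype="uint8")
y = relay.var("y", shape=(4,), dtype="uint8")
scales = (relay.const(1.0, "float32"), relay.const(1.0, "float32"))
zps = (relay.const(0, "int32"), relay.const(0, "int32"))

# Style 1: a plain Python list of exprs (already supported).
z1 = relay.qnn.op.concatenate([x, y],
                              input_scales=scales,
                              input_zero_points=zps,
                              output_scale=relay.const(1.0, "float32"),
                              output_zero_point=relay.const(0, "int32"),
                              axis=0)

# Style 2: a TupleWrapper, e.g. from relay.split (the new case).
tup = relay.split(x, 2, axis=0)  # TupleWrapper around a split CallNode
z2 = relay.qnn.op.concatenate(tup,
                              input_scales=scales,
                              input_zero_points=zps,
                              output_scale=relay.const(1.0, "float32"),
                              output_zero_point=relay.const(0, "int32"),
                              axis=0)
```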
14 changes: 11 additions & 3 deletions src/relay/qnn/op/concatenate.cc
```diff
@@ -149,8 +149,16 @@ Expr ConcatenateQnnCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
   // If the output qnn params do not match the input qnn params, we can call requantize on the input
   // expr first, followed by a concatenate on the requantized input exprs.
 
-  auto tuple_data = data.as<TupleNode>();
-  CHECK(tuple_data != nullptr);
+  Array<Expr> tuple_exprs;
+  if (data->IsInstance<TupleNode>()) {
+    tuple_exprs = data.as<TupleNode>()->fields;
+  } else if (data->IsInstance<CallNode>()) {  // if the data is a CallNode, use TupleGetItems
+    auto call = Downcast<Call>(data);
+    for (size_t i = 0; i < tuple_type->fields.size(); i++) {
+      tuple_exprs.push_back(TupleGetItem(call, i));
+    }
+  }
+  CHECK(!tuple_exprs.empty());
 
   auto tuple_input_scales = input_scales.as<TupleNode>();
   CHECK(tuple_input_scales != nullptr);
@@ -160,7 +168,7 @@ Expr ConcatenateQnnCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
 
   int idx = 0;
   Array<Expr> requantized_exprs;
-  for (auto quantized_expr : tuple_data->fields) {
+  for (auto quantized_expr : tuple_exprs) {
     // Get the input scale for the idx quantized input tensor.
     auto input_scale = tuple_input_scales->fields[idx];
```
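The canonicalizer now handles two shapes of `data`: a `TupleNode` exposes its fields directly, while a `CallNode` of tuple type (such as a `split`) is flattened by indexing each field with `TupleGetItem`, using `tuple_type` (derived from `data`'s checked type earlier in the function) for the field count. A hypothetical Python analogue of that flattening, for illustration only; the helper name and explicit field-count argument are assumptions:

```python
from tvm import relay

def flatten_concat_input(data, num_fields):
    """Return the per-tensor exprs feeding a concatenate.

    Mirrors the C++ branch above: a relay.Tuple exposes its fields
    directly; any other expr of tuple type (e.g. the CallNode produced
    by relay.split) is indexed field by field with TupleGetItem.
    """
    if isinstance(data, relay.Tuple):
        return list(data.fields)
    return [relay.TupleGetItem(data, i) for i in range(num_fields)]
```

With `data` being the `CallNode` from a two-way split, this yields `[TupleGetItem(call, 0), TupleGetItem(call, 1)]`, which the requantize loop below iterates over exactly as it previously iterated over tuple fields.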
25 changes: 25 additions & 0 deletions tests/python/relay/test_op_qnn_concatenate.py
```diff
@@ -144,7 +144,32 @@ def test_same_i_qnn_params():
     op_res = intrp.evaluate(func)(x_data, y_data)
     np.testing.assert_equal(op_res.asnumpy(), golden_output)
 
+def test_call_input():
+    # This tests the case where the input to concatenate is not explicitly a
+    # tuple node but is instead a call node.
+    x_data = np.ones(shape=(64,)).astype('uint8')
+
+    x = relay.var("x", shape=(64,), dtype='uint8')
+    x_scale = relay.const(1, 'float32')
+    y_scale = relay.const(1, 'float32')
+    x_zero_point = relay.const(0, 'int32')
+    y_zero_point = relay.const(0, 'int32')
+
+    tup = relay.split(x, 2, axis=0)
+    z = relay.qnn.op.concatenate(tup,
+                                 input_scales=(x_scale, y_scale),
+                                 input_zero_points=(x_zero_point, y_zero_point),
+                                 output_scale=y_scale,
+                                 output_zero_point=relay.const(0, 'int32'),
+                                 axis=0)
+    func = relay.Function([x], z)
+
+    intrp = relay.create_executor("graph", ctx=tvm.cpu(0), target="llvm")
+    op_res = intrp.evaluate(func)(x_data)
+    np.testing.assert_equal(op_res.asnumpy(), x_data)
+
 if __name__ == '__main__':
+    test_call_input()
     test_same_io_qnn_params()
     test_different_io_qnn_params()
     test_few_same_io_qnn_params()
```
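To see the new path end to end, one can run QNN canonicalization on the test's function and inspect the lowered IR. This is a sketch under the assumption that `relay.qnn.transform.CanonicalizeOps` applies here as it does to other QNN ops; the exact printed IR varies across TVM versions:

```python
import tvm
from tvm import relay

# Rebuild the function from test_call_input() above.
x = relay.var("x", shape=(64,), dtype="uint8")
scale = relay.const(1, "float32")
zp = relay.const(0, "int32")
tup = relay.split(x, 2, axis=0)  # TupleWrapper over a split CallNode
z = relay.qnn.op.concatenate(tup,
                             input_scales=(scale, scale),
                             input_zero_points=(zp, zp),
                             output_scale=scale,
                             output_zero_point=zp,
                             axis=0)
func = relay.Function([x], z)

mod = tvm.IRModule.from_expr(func)
mod = relay.transform.InferType()(mod)   # canonicalization needs checked types
mod = relay.qnn.transform.CanonicalizeOps()(mod)
print(mod)  # qnn.concatenate is lowered; the split output is consumed
            # through TupleGetItem nodes, as constructed in concatenate.cc
```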