Skip to content

Commit

Permalink
Fix QNN type inference (apache#7074)
Browse files Browse the repository at this point in the history
* check for incomplete types in QNN Relation functions

* add regression test from apache#7067

* respond to review comments
  • Loading branch information
Matthew Brookhart authored and Tushar Dey committed Jan 20, 2021
1 parent 85c7304 commit faa95d3
Show file tree
Hide file tree
Showing 7 changed files with 121 additions and 15 deletions.
36 changes: 30 additions & 6 deletions src/relay/qnn/op/concatenate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,29 +38,53 @@ namespace qnn {

bool QnnConcatenateRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
// Expected Types: data, input_scales, input_zero_points, output_scale, output_zero_point,
// out_type
ICHECK_EQ(types.size(), 6);

if (types[0].as<IncompleteTypeNode>()) {
return false;
}
// Check the scale and zero point types
const auto* input_scales_tuple = types[1].as<TupleTypeNode>();
if (input_scales_tuple == nullptr) {
throw Error(ErrorBuilder()
<< "qnn concatenate requires a tuple of scales as the second argument, found "
<< PrettyPrint(types[1]));
if (types[1].as<IncompleteTypeNode>()) {
return false;
} else {
throw Error(ErrorBuilder()
<< "qnn concatenate requires a tuple of scales as the second argument, found "
<< PrettyPrint(types[1]));
}
}
for (const auto& input_scale : input_scales_tuple->fields) {
if (input_scale.as<IncompleteTypeNode>()) {
return false;
}
ICHECK(IsScalarType(input_scale, DataType::Float(32))); // input_scales[idx]
}

const auto* input_zero_points_tuple = types[2].as<TupleTypeNode>();
if (input_zero_points_tuple == nullptr) {
throw Error(ErrorBuilder()
<< "qnn concatenate requires a tuple of zero_points as the third argument, found "
<< PrettyPrint(types[2]));
if (types[2].as<IncompleteTypeNode>()) {
return false;
} else {
throw Error(ErrorBuilder()
<< "qnn concatenate requires a tuple of zero_points as the third argument, found "
<< PrettyPrint(types[2]));
}
}
for (const auto& input_zero_point : input_zero_points_tuple->fields) {
if (input_zero_point.as<IncompleteTypeNode>()) {
return false;
}
ICHECK(IsScalarType(input_zero_point, DataType::Int(32))); // input_zero_points[idx]
}

for (size_t i = 3; i < 5; ++i) {
if (types[i].as<IncompleteTypeNode>()) {
return false;
}
}
ICHECK(IsScalarType(types[3], DataType::Float(32))); // output_scale
ICHECK(IsScalarType(types[4], DataType::Int(32))); // output_zero_point

Expand Down
13 changes: 10 additions & 3 deletions src/relay/qnn/op/convolution.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ namespace qnn {

bool QnnConv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
// Expected Types: data, weight, input_zero_point, weight_zero_point, input_scale, weight_scale,
// out_type
ICHECK_EQ(types.size(), 7);
const auto* data = types[0].as<TensorTypeNode>();
const auto* weight = types[1].as<TensorTypeNode>();
Expand All @@ -57,22 +59,27 @@ bool QnnConv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
ICHECK(param->out_dtype.bits() > 0) << "Output dtype bits should be greater than 0.";

// Check the types of scale and zero points.
for (size_t i = 2; i < 5; ++i) {
if (types[i].as<IncompleteTypeNode>()) {
return false;
}
}
ICHECK(IsScalarType(types[2], DataType::Int(32))); // input_zero_point
ICHECK(IsScalarType(types[3], DataType::Int(32))); // kernel_zero_point
ICHECK(IsScalarType(types[3], DataType::Int(32))); // weight_zero_point
ICHECK(IsScalarType(types[4], DataType::Float(32))); // input_scale
// Kernel scale can be a vector of length output_channels or a scalar.
if (param->groups == 1) {
size_t axis = param->kernel_layout.operator std::string().find('O');
ICHECK(axis != std::string::npos) << "Kernel layout attribute is not defined";
AssignType(types[5], DataType::Float(32), weight->shape[axis], reporter); // kernel scale
AssignType(types[5], DataType::Float(32), weight->shape[axis], reporter); // weight_scale
} else {
// Here, total number of output channels depend on depth multiplier.
size_t o_axis = param->kernel_layout.operator std::string().find('O');
size_t i_axis = param->kernel_layout.operator std::string().find('I');
ICHECK(o_axis != std::string::npos || i_axis != std::string::npos)
<< "Kernel layout attribute is not defined";
AssignType(types[5], DataType::Float(32), weight->shape[i_axis] * weight->shape[o_axis],
reporter); // kernel scale
reporter); // weight_scale
}

// Collect the input tensor and output tensor devoid of scale and zero points to reuse Relay
Expand Down
11 changes: 9 additions & 2 deletions src/relay/qnn/op/convolution_transpose.cc
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ Array<Array<Layout>> QnnConvTransposeInferCorrectLayout(

bool QnnConv2DTransposeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
// Expected Types: data, weight, input_zero_point, weight_zero_point, input_scale, weight_scale,
// out_type
ICHECK_EQ(types.size(), 7);
const auto* data = types[0].as<TensorTypeNode>();
const auto* weight = types[1].as<TensorTypeNode>();
Expand All @@ -96,14 +98,19 @@ bool QnnConv2DTransposeRel(const Array<Type>& types, int num_inputs, const Attrs
ICHECK(param->out_dtype.bits() > 0) << "Output dtype bits should be greater than 0.";

// Check the types of scale and zero points.
for (size_t i = 2; i < 5; ++i) {
if (types[i].as<IncompleteTypeNode>()) {
return false;
}
}
ICHECK(IsScalarType(types[2], DataType::Int(32))); // input_zero_point
ICHECK(IsScalarType(types[3], DataType::Int(32))); // kernel_zero_point
ICHECK(IsScalarType(types[3], DataType::Int(32))); // weight_zero_point
ICHECK(IsScalarType(types[4], DataType::Float(32))); // input_scale
// Kernel scale can be a vector of length output_channels or a scalar.
if (param->groups == 1) {
size_t axis = param->kernel_layout.find('O');
ICHECK(axis != std::string::npos) << "Kernel layout attribute is not defined";
AssignType(types[5], DataType::Float(32), weight->shape[axis], reporter); // kernel scale
AssignType(types[5], DataType::Float(32), weight->shape[axis], reporter); // weight_scale
} else {
// Here, total number of output channels depend on depth multiplier.
size_t o_axis = param->kernel_layout.find('O');
Expand Down
15 changes: 11 additions & 4 deletions src/relay/qnn/op/dense.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ namespace qnn {

bool QnnDenseRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
// Expected Types: data, weight, input_zero_point, weight_zero_point, input_scale, weight_scale,
// out_type
ICHECK_EQ(types.size(), 7);
const auto* data = types[0].as<TensorTypeNode>();
const auto* weight = types[1].as<TensorTypeNode>();
Expand All @@ -53,10 +55,15 @@ bool QnnDenseRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
<< "Expected quantized dense type(int32) for output but was " << param->out_dtype;

// Check the types of scale and zero points.
ICHECK(IsScalarType(types[2], DataType::Int(32))); // input_zero_point
ICHECK(IsScalarType(types[3], DataType::Int(32))); // kernel_zero_point
ICHECK(IsScalarType(types[4], DataType::Float(32))); // input_scale
AssignType(types[5], DataType::Float(32), param->units, reporter);
for (size_t i = 2; i < 5; ++i) {
if (types[i].as<IncompleteTypeNode>()) {
return false;
}
}
ICHECK(IsScalarType(types[2], DataType::Int(32))); // input_zero_point
ICHECK(IsScalarType(types[3], DataType::Int(32))); // weight_zero_point
ICHECK(IsScalarType(types[4], DataType::Float(32))); // input_scale
AssignType(types[5], DataType::Float(32), param->units, reporter); // weight_scale

ICHECK(param->out_dtype.bits() > 0) << "Output dtype bits should be greater than 0.";

Expand Down
13 changes: 13 additions & 0 deletions src/relay/qnn/op/op_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -168,9 +168,22 @@ inline Array<Array<Layout> > QnnBinaryBroadcastLayout(const Attrs& attrs,

static inline bool QnnBroadcastRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
// Expected Types: lhs, rhs, lhs_scale, lhs_zero_point, rhs_scale, rhs_zero_point, output_scale,
// output_zero_point, out_type
ICHECK_EQ(types.size(), kNumQnnBinaryOpArgTypes);

// Check the lhs and rhs types
for (size_t i = 0; i < 2; ++i) {
if (types[i].as<IncompleteTypeNode>()) {
return false;
}
}
// Check the scale and zero point types
for (size_t i = 2; i < 8; ++i) {
if (types[i].as<IncompleteTypeNode>()) {
return false;
}
}
ICHECK(IsScalarType(types[2], DataType::Float(32))); // lhs_scale
ICHECK(IsScalarType(types[3], DataType::Int(32))); // lhs_zero_point
ICHECK(IsScalarType(types[4], DataType::Float(32))); // rhs_scale
Expand Down
7 changes: 7 additions & 0 deletions src/relay/qnn/op/requantize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -256,13 +256,20 @@ Expr RequantizeQnnCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
*/
bool RequantizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
// Expected Types: data, input_scale, input_zero_point, output_scale, output_zero_point, output
ICHECK_EQ(types.size(), 6);
const auto* data = types[0].as<TensorTypeNode>();

if (data == nullptr) {
return false;
}

// Check the scale and zero point types
for (size_t i = 3; i < 5; ++i) {
if (types[i].as<IncompleteTypeNode>()) {
return false;
}
}
const auto in_dtype = data->dtype;
ICHECK(in_dtype == DataType::Int(8) || in_dtype == DataType::UInt(8) ||
in_dtype == DataType::Int(32))
Expand Down
41 changes: 41 additions & 0 deletions tests/python/frontend/pytorch/qnn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,17 +32,58 @@
from tvm.relay.frontend.pytorch_utils import is_version_greater_than
from tvm.contrib.download import download_testdata

from tvm.relay.dataflow_pattern import wildcard, is_op
from tvm.relay.op.contrib.register import register_pattern_table
from tvm.relay.op.contrib.register import get_pattern_table


def torch_version_check():
    """Return True when the installed torch release is newer than 1.4.0."""
    from packaging import version

    installed = version.parse(torch.__version__)
    threshold = version.parse("1.4.0")
    return installed > threshold


def make_qnn_add_pattern():
    """Build a dataflow pattern that matches qnn.add, optionally followed by clip.

    qnn.add takes eight positional arguments -- lhs, rhs, lhs_scale,
    lhs_zero_point, rhs_scale, rhs_zero_point, output_scale,
    output_zero_point -- each of which is matched with a wildcard here.
    """
    # Eight wildcards, one per qnn.add operand, in positional order.
    operands = [wildcard() for _ in range(8)]
    qnn_add = is_op("qnn.add")(*operands)
    # The quantized add may be fused with a trailing clip (e.g. ReLU clamp).
    return qnn_add.optional(is_op("clip"))


@register_pattern_table("test_table")
def pattern_table():
    """Pattern table for MergeComposite: maps the composite name
    "qnn_add" to the qnn.add(+clip) pattern."""
    entries = [("qnn_add", make_qnn_add_pattern())]
    return entries


def get_tvm_runtime(script_module, input_name, ishape):

input_shapes = [(input_name, ishape)]
mod, params = relay.frontend.from_pytorch(script_module, input_shapes)
pattern_table = get_pattern_table("test_table")
with tvm.transform.PassContext(opt_level=3):
pass_list = [
tvm.relay.transform.SimplifyInference(),
tvm.relay.transform.MergeComposite(pattern_table),
]
composite_partition = tvm.transform.Sequential(pass_list)
partitioned = composite_partition(mod)

with tvm.transform.PassContext(opt_level=3):
# test on only cpu for now, torch cannot run quant models on cuda
Expand Down

0 comments on commit faa95d3

Please sign in to comment.