From faa95d3261b3a5efbc1c948d68294967a75c87b4 Mon Sep 17 00:00:00 2001
From: Matthew Brookhart
Date: Thu, 10 Dec 2020 01:26:05 -0700
Subject: [PATCH] Fix QNN type inference (#7074)

* check for incomplete types in QNN Relation functions

* add regression test from #7067

* respond to review comments
---
 src/relay/qnn/op/concatenate.cc           | 36 ++++++++++++++++----
 src/relay/qnn/op/convolution.cc           | 13 +++++--
 src/relay/qnn/op/convolution_transpose.cc | 11 ++++--
 src/relay/qnn/op/dense.cc                 | 15 ++++++---
 src/relay/qnn/op/op_common.h              | 13 +++++++
 src/relay/qnn/op/requantize.cc            |  7 ++++
 tests/python/frontend/pytorch/qnn_test.py | 41 +++++++++++++++++++++++
 7 files changed, 121 insertions(+), 15 deletions(-)

diff --git a/src/relay/qnn/op/concatenate.cc b/src/relay/qnn/op/concatenate.cc
index 7a716a1ec498..59a519d66436 100644
--- a/src/relay/qnn/op/concatenate.cc
+++ b/src/relay/qnn/op/concatenate.cc
@@ -38,29 +38,53 @@ namespace qnn {
 bool QnnConcatenateRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                        const TypeReporter& reporter) {
+  // Expected Types: data, input_scales, input_zero_points, output_scale, output_zero_point,
+  // out_type
   ICHECK_EQ(types.size(), 6);
+  if (types[0].as<IncompleteTypeNode>()) {
+    return false;
+  }

   // Check the scale and zero point types
   const auto* input_scales_tuple = types[1].as<TupleTypeNode>();
   if (input_scales_tuple == nullptr) {
-    throw Error(ErrorBuilder()
-                << "qnn concatenate requires a tuple of scales as the second argument, found "
-                << PrettyPrint(types[1]));
+    if (types[1].as<IncompleteTypeNode>()) {
+      return false;
+    } else {
+      throw Error(ErrorBuilder()
+                  << "qnn concatenate requires a tuple of scales as the second argument, found "
+                  << PrettyPrint(types[1]));
+    }
   }
   for (const auto& input_scale : input_scales_tuple->fields) {
+    if (input_scale.as<IncompleteTypeNode>()) {
+      return false;
+    }
     ICHECK(IsScalarType(input_scale, DataType::Float(32)));  // input_scales[idx]
   }

   const auto* input_zero_points_tuple = types[2].as<TupleTypeNode>();
   if (input_zero_points_tuple == nullptr) {
-    throw Error(ErrorBuilder()
-                << "qnn concatenate requires a tuple of zero_points as the third argument, found "
-                << PrettyPrint(types[2]));
+    if (types[2].as<IncompleteTypeNode>()) {
+      return false;
+    } else {
+      throw Error(ErrorBuilder()
+                  << "qnn concatenate requires a tuple of zero_points as the third argument, found "
+                  << PrettyPrint(types[2]));
+    }
   }
   for (const auto& input_zero_point : input_zero_points_tuple->fields) {
+    if (input_zero_point.as<IncompleteTypeNode>()) {
+      return false;
+    }
     ICHECK(IsScalarType(input_zero_point, DataType::Int(32)));  // input_zero_points[idx]
   }

+  for (size_t i = 3; i < 5; ++i) {
+    if (types[i].as<IncompleteTypeNode>()) {
+      return false;
+    }
+  }
   ICHECK(IsScalarType(types[3], DataType::Float(32)));  // output_scale
   ICHECK(IsScalarType(types[4], DataType::Int(32)));    // output_zero_point

diff --git a/src/relay/qnn/op/convolution.cc b/src/relay/qnn/op/convolution.cc
index a9f2f361f2b3..21335ec2fb34 100644
--- a/src/relay/qnn/op/convolution.cc
+++ b/src/relay/qnn/op/convolution.cc
@@ -42,6 +42,8 @@ namespace qnn {
 bool QnnConv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                   const TypeReporter& reporter) {
+  // Expected Types: data, weight, input_zero_point, weight_zero_point, input_scale, weight_scale,
+  // out_type
   ICHECK_EQ(types.size(), 7);
   const auto* data = types[0].as<TensorTypeNode>();
   const auto* weight = types[1].as<TensorTypeNode>();
@@ -57,14 +59,19 @@ bool QnnConv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   ICHECK(param->out_dtype.bits() > 0) << "Output dtype bits should be greater than 0.";

   // Check the types of scale and zero points.
+  for (size_t i = 2; i < 5; ++i) {
+    if (types[i].as<IncompleteTypeNode>()) {
+      return false;
+    }
+  }
   ICHECK(IsScalarType(types[2], DataType::Int(32)));    // input_zero_point
-  ICHECK(IsScalarType(types[3], DataType::Int(32)));    // kernel_zero_point
+  ICHECK(IsScalarType(types[3], DataType::Int(32)));    // weight_zero_point
   ICHECK(IsScalarType(types[4], DataType::Float(32)));  // input_scale
   // Kernel scale can be a vector of length output_channels or a scalar.
   if (param->groups == 1) {
     size_t axis = param->kernel_layout.operator std::string().find('O');
     ICHECK(axis != std::string::npos) << "Kernel layout attribute is not defined";
-    AssignType(types[5], DataType::Float(32), weight->shape[axis], reporter);  // kernel scale
+    AssignType(types[5], DataType::Float(32), weight->shape[axis], reporter);  // weight_scale
   } else {
     // Here, total number of output channels depend on depth multiplier.
     size_t o_axis = param->kernel_layout.operator std::string().find('O');
@@ -72,7 +79,7 @@ bool QnnConv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
     ICHECK(o_axis != std::string::npos || i_axis != std::string::npos)
         << "Kernel layout attribute is not defined";
     AssignType(types[5], DataType::Float(32), weight->shape[i_axis] * weight->shape[o_axis],
-               reporter);  // kernel scale
+               reporter);  // weight_scale
   }

   // Collect the input tensor and output tensor devoid of scale and zero points to reuse Relay
diff --git a/src/relay/qnn/op/convolution_transpose.cc b/src/relay/qnn/op/convolution_transpose.cc
index c7515b5904f1..bde398df5e33 100644
--- a/src/relay/qnn/op/convolution_transpose.cc
+++ b/src/relay/qnn/op/convolution_transpose.cc
@@ -81,6 +81,8 @@ Array<Array<Layout>> QnnConvTransposeInferCorrectLayout(
 bool QnnConv2DTransposeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                            const TypeReporter& reporter) {
+  // Expected Types: data, weight, input_zero_point, weight_zero_point, input_scale, weight_scale,
+  // out_type
   ICHECK_EQ(types.size(), 7);
   const auto* data = types[0].as<TensorTypeNode>();
   const auto* weight = types[1].as<TensorTypeNode>();
@@ -96,14 +98,19 @@ bool QnnConv2DTransposeRel(const Array<Type>& types, int num_inputs, const Attrs
   ICHECK(param->out_dtype.bits() > 0) << "Output dtype bits should be greater than 0.";

   // Check the types of scale and zero points.
+  for (size_t i = 2; i < 5; ++i) {
+    if (types[i].as<IncompleteTypeNode>()) {
+      return false;
+    }
+  }
   ICHECK(IsScalarType(types[2], DataType::Int(32)));    // input_zero_point
-  ICHECK(IsScalarType(types[3], DataType::Int(32)));    // kernel_zero_point
+  ICHECK(IsScalarType(types[3], DataType::Int(32)));    // weight_zero_point
   ICHECK(IsScalarType(types[4], DataType::Float(32)));  // input_scale
   // Kernel scale can be a vector of length output_channels or a scalar.
   if (param->groups == 1) {
     size_t axis = param->kernel_layout.find('O');
     ICHECK(axis != std::string::npos) << "Kernel layout attribute is not defined";
-    AssignType(types[5], DataType::Float(32), weight->shape[axis], reporter);  // kernel scale
+    AssignType(types[5], DataType::Float(32), weight->shape[axis], reporter);  // weight_scale
   } else {
     // Here, total number of output channels depend on depth multiplier.
     size_t o_axis = param->kernel_layout.find('O');
diff --git a/src/relay/qnn/op/dense.cc b/src/relay/qnn/op/dense.cc
index 3602995b8f16..6284524bff27 100644
--- a/src/relay/qnn/op/dense.cc
+++ b/src/relay/qnn/op/dense.cc
@@ -39,6 +39,8 @@ namespace qnn {
 bool QnnDenseRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                  const TypeReporter& reporter) {
+  // Expected Types: data, weight, input_zero_point, weight_zero_point, input_scale, weight_scale,
+  // out_type
   ICHECK_EQ(types.size(), 7);
   const auto* data = types[0].as<TensorTypeNode>();
   const auto* weight = types[1].as<TensorTypeNode>();
@@ -53,10 +55,15 @@ bool QnnDenseRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
       << "Expected quantized dense type(int32) for output but was " << param->out_dtype;

   // Check the types of scale and zero points.
-  ICHECK(IsScalarType(types[2], DataType::Int(32)));    // input_zero_point
-  ICHECK(IsScalarType(types[3], DataType::Int(32)));    // kernel_zero_point
-  ICHECK(IsScalarType(types[4], DataType::Float(32)));  // input_scale
-  AssignType(types[5], DataType::Float(32), param->units, reporter);
+  for (size_t i = 2; i < 5; ++i) {
+    if (types[i].as<IncompleteTypeNode>()) {
+      return false;
+    }
+  }
+  ICHECK(IsScalarType(types[2], DataType::Int(32)));    // input_zero_point
+  ICHECK(IsScalarType(types[3], DataType::Int(32)));    // weight_zero_point
+  ICHECK(IsScalarType(types[4], DataType::Float(32)));  // input_scale
+  AssignType(types[5], DataType::Float(32), param->units, reporter);  // weight_scale

   ICHECK(param->out_dtype.bits() > 0) << "Output dtype bits should be greater than 0.";

diff --git a/src/relay/qnn/op/op_common.h b/src/relay/qnn/op/op_common.h
index 330802c4c9b1..0f77db4f501a 100644
--- a/src/relay/qnn/op/op_common.h
+++ b/src/relay/qnn/op/op_common.h
@@ -168,9 +168,22 @@ inline Array<Array<Layout> > QnnBinaryBroadcastLayout(const Attrs& attrs,
 static inline bool QnnBroadcastRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                                    const TypeReporter& reporter) {
+  // Expected Types: lhs, rhs, lhs_scale, lhs_zero_point, rhs_scale, rhs_zero_point, output_scale,
+  // output_zero_point, out_type
   ICHECK_EQ(types.size(), kNumQnnBinaryOpArgTypes);

+  // Check the lhs and rhs types
+  for (size_t i = 0; i < 2; ++i) {
+    if (types[i].as<IncompleteTypeNode>()) {
+      return false;
+    }
+  }
   // Check the scale and zero point types
+  for (size_t i = 2; i < 8; ++i) {
+    if (types[i].as<IncompleteTypeNode>()) {
+      return false;
+    }
+  }
   ICHECK(IsScalarType(types[2], DataType::Float(32)));  // lhs_scale
   ICHECK(IsScalarType(types[3], DataType::Int(32)));    // lhs_zero_point
   ICHECK(IsScalarType(types[4], DataType::Float(32)));  // rhs_scale
diff --git a/src/relay/qnn/op/requantize.cc b/src/relay/qnn/op/requantize.cc
index 8e9b31e6fc39..2ae879595659 100644
--- a/src/relay/qnn/op/requantize.cc
+++ b/src/relay/qnn/op/requantize.cc
@@ -256,6 +256,7 @@ Expr RequantizeQnnCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
  */
 bool RequantizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                    const TypeReporter& reporter) {
+  // Expected Types: data, input_scale, input_zero_point, output_scale, output_zero_point, output
   ICHECK_EQ(types.size(), 6);
   const auto* data = types[0].as<TensorTypeNode>();
@@ -263,6 +264,12 @@ bool RequantizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
     return false;
   }

+  // Check the scale and zero point types
+  for (size_t i = 3; i < 5; ++i) {
+    if (types[i].as<IncompleteTypeNode>()) {
+      return false;
+    }
+  }
   const auto in_dtype = data->dtype;
   ICHECK(in_dtype == DataType::Int(8) || in_dtype == DataType::UInt(8) ||
          in_dtype == DataType::Int(32))
diff --git a/tests/python/frontend/pytorch/qnn_test.py b/tests/python/frontend/pytorch/qnn_test.py
index 9781eb5d57c4..4b7395922efb 100644
--- a/tests/python/frontend/pytorch/qnn_test.py
+++ b/tests/python/frontend/pytorch/qnn_test.py
@@ -32,6 +32,10 @@
 from tvm.relay.frontend.pytorch_utils import is_version_greater_than
 from tvm.contrib.download import download_testdata

+from tvm.relay.dataflow_pattern import wildcard, is_op
+from tvm.relay.op.contrib.register import register_pattern_table
+from tvm.relay.op.contrib.register import get_pattern_table
+

 def torch_version_check():
     from packaging import version
@@ -39,10 +43,47 @@ def torch_version_check():
     return version.parse(torch.__version__) > version.parse("1.4.0")


+def make_qnn_add_pattern():
+    lhs = wildcard()
+    rhs = wildcard()
+    lhs_scale = wildcard()
+    lhs_zero_point = wildcard()
+    rhs_scale = wildcard()
+    rhs_zero_point = wildcard()
+    output_scale = wildcard()
+    output_zero_point = wildcard()
+    qadd = is_op("qnn.add")(
+        lhs,
+        rhs,
+        lhs_scale,
+        lhs_zero_point,
+        rhs_scale,
+        rhs_zero_point,
+        output_scale,
+        output_zero_point,
+    )
+    return qadd.optional(is_op("clip"))
+
+
+@register_pattern_table("test_table")
+def pattern_table():
+    return [
+        ("qnn_add", make_qnn_add_pattern()),
+    ]
+
+
 def get_tvm_runtime(script_module, input_name, ishape):

     input_shapes = [(input_name, ishape)]
     mod, params = relay.frontend.from_pytorch(script_module, input_shapes)

+    pattern_table = get_pattern_table("test_table")
+    with tvm.transform.PassContext(opt_level=3):
+        pass_list = [
+            tvm.relay.transform.SimplifyInference(),
+            tvm.relay.transform.MergeComposite(pattern_table),
+        ]
+        composite_partition = tvm.transform.Sequential(pass_list)
+        partitioned = composite_partition(mod)

     with tvm.transform.PassContext(opt_level=3):
         # test on only cpu for now, torch cannot run quant models on cuda
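
Note on the guard added throughout the C++ hunks above: a Relay type relation can be invoked while some argument types are still unresolved IncompleteTypes, because type inference runs incrementally, and passes such as MergeComposite type-check partially rewritten expressions (the crash reported in #7067). Hard-failing with ICHECK in that state aborts inference; returning false instead tells the inferencer that the relation cannot be decided yet and should be retried once more types are known. A minimal sketch of the pattern, using an illustrative relation name that is not part of this patch:

  // Sketch only: defer type inference while any argument type is unresolved.
  bool ExampleQnnRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                     const TypeReporter& reporter) {
    // `types` holds the argument types followed by the output type.
    for (int i = 0; i < num_inputs; ++i) {
      if (types[i].as<IncompleteTypeNode>()) {
        return false;  // not resolved yet; the inferencer will retry this relation
      }
    }
    // ... the op-specific IsScalarType / AssignType checks go here ...
    return true;
  }

The regression test above exercises this path: the qnn_add composite pattern makes MergeComposite run type inference on expressions whose scale and zero-point arguments are not yet typed, which previously tripped the ICHECKs in the QNN relations.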