From 6b7bac20ea10b82df3810e43e3a227a678943d4e Mon Sep 17 00:00:00 2001 From: corehalt Date: Thu, 10 Dec 2020 23:26:59 +0900 Subject: [PATCH] Add test for MergeComposite on a QNN graph --- tests/python/frontend/pytorch/qnn_test.py | 104 +++++++++++++--------- 1 file changed, 63 insertions(+), 41 deletions(-) diff --git a/tests/python/frontend/pytorch/qnn_test.py b/tests/python/frontend/pytorch/qnn_test.py index 4b7395922efb..07e52b7079e8 100644 --- a/tests/python/frontend/pytorch/qnn_test.py +++ b/tests/python/frontend/pytorch/qnn_test.py @@ -31,10 +31,7 @@ from tvm import relay from tvm.relay.frontend.pytorch_utils import is_version_greater_than from tvm.contrib.download import download_testdata - -from tvm.relay.dataflow_pattern import wildcard, is_op -from tvm.relay.op.contrib.register import register_pattern_table -from tvm.relay.op.contrib.register import get_pattern_table +from tvm.relay.op.contrib.register import register_pattern_table, get_pattern_table def torch_version_check(): @@ -43,47 +40,10 @@ def torch_version_check(): return version.parse(torch.__version__) > version.parse("1.4.0") -def make_qnn_add_pattern(): - lhs = wildcard() - rhs = wildcard() - lhs_scale = wildcard() - lhs_zero_point = wildcard() - rhs_scale = wildcard() - rhs_zero_point = wildcard() - output_scale = wildcard() - output_zero_point = wildcard() - qadd = is_op("qnn.add")( - lhs, - rhs, - lhs_scale, - lhs_zero_point, - rhs_scale, - rhs_zero_point, - output_scale, - output_zero_point, - ) - return qadd.optional(is_op("clip")) - - -@register_pattern_table("test_table") -def pattern_table(): - return [ - ("qnn_add", make_qnn_add_pattern()), - ] - - def get_tvm_runtime(script_module, input_name, ishape): input_shapes = [(input_name, ishape)] mod, params = relay.frontend.from_pytorch(script_module, input_shapes) - pattern_table = get_pattern_table("test_table") - with tvm.transform.PassContext(opt_level=3): - pass_list = [ - tvm.relay.transform.SimplifyInference(), - tvm.relay.transform.MergeComposite(pattern_table), - ] - composite_partition = tvm.transform.Sequential(pass_list) - partitioned = composite_partition(mod) with tvm.transform.PassContext(opt_level=3): # test on only cpu for now, torch cannot run quant models on cuda @@ -587,3 +547,65 @@ def forward(self, inp): # Outputs from v1.6 seem reliable. TVM's outputs are always the same if is_version_greater_than("1.5.1"): tvm.testing.assert_allclose(tvm_result, pt_result, rtol=1e-4, atol=1e-4) + + +def make_qnn_add_pattern(): + from tvm.relay.dataflow_pattern import wildcard, is_op + + lhs = wildcard() + rhs = wildcard() + lhs_scale = wildcard() + lhs_zero_point = wildcard() + rhs_scale = wildcard() + rhs_zero_point = wildcard() + output_scale = wildcard() + output_zero_point = wildcard() + qadd = is_op("qnn.add")( + lhs, + rhs, + lhs_scale, + lhs_zero_point, + rhs_scale, + rhs_zero_point, + output_scale, + output_zero_point, + ) + return qadd.optional(is_op("clip")) + + +@register_pattern_table("test_table") +def pattern_table(): + return [ + ("qnn_add", make_qnn_add_pattern()), + ] + + +def run_qnn_mergecomposite(script_module, input_name, ishape): + input_shapes = [(input_name, ishape)] + mod, params = relay.frontend.from_pytorch(script_module, input_shapes) + pattern_table = get_pattern_table("test_table") + with tvm.transform.PassContext(opt_level=3): + pass_list = [ + tvm.relay.transform.SimplifyInference(), + tvm.relay.transform.MergeComposite(pattern_table), + ] + composite_partition = tvm.transform.Sequential(pass_list) + partitioned = composite_partition(mod) + + +def test_qnn_mergecomposite(): + from torchvision.models.quantization import resnet as qresnet + + model = qresnet.resnet18(pretrained=True) + model.eval() + + inp = torch.zeros((1, 3, 224, 224)) + model.fuse_model() + model.qconfig = torch.quantization.get_default_qconfig("fbgemm") + torch.quantization.prepare(model, inplace=True) + model(inp) + torch.quantization.convert(model, inplace=True) + script_module = torch.jit.trace(model, inp).eval() + + input_name = "image" + run_qnn_mergecomposite(script_module, input_name, inp.shape)