Add a standalone regression test for running MergeComposite on a QNN graph #7080

Merged (1 commit) on Dec 11, 2020
104 changes: 63 additions & 41 deletions tests/python/frontend/pytorch/qnn_test.py
@@ -31,10 +31,7 @@
from tvm import relay
from tvm.relay.frontend.pytorch_utils import is_version_greater_than
from tvm.contrib.download import download_testdata

from tvm.relay.dataflow_pattern import wildcard, is_op
from tvm.relay.op.contrib.register import register_pattern_table
from tvm.relay.op.contrib.register import get_pattern_table
from tvm.relay.op.contrib.register import register_pattern_table, get_pattern_table


def torch_version_check():
@@ -43,47 +40,10 @@ def torch_version_check():
return version.parse(torch.__version__) > version.parse("1.4.0")


def make_qnn_add_pattern():
lhs = wildcard()
rhs = wildcard()
lhs_scale = wildcard()
lhs_zero_point = wildcard()
rhs_scale = wildcard()
rhs_zero_point = wildcard()
output_scale = wildcard()
output_zero_point = wildcard()
qadd = is_op("qnn.add")(
lhs,
rhs,
lhs_scale,
lhs_zero_point,
rhs_scale,
rhs_zero_point,
output_scale,
output_zero_point,
)
return qadd.optional(is_op("clip"))


@register_pattern_table("test_table")
def pattern_table():
return [
("qnn_add", make_qnn_add_pattern()),
]


def get_tvm_runtime(script_module, input_name, ishape):

input_shapes = [(input_name, ishape)]
mod, params = relay.frontend.from_pytorch(script_module, input_shapes)
pattern_table = get_pattern_table("test_table")
with tvm.transform.PassContext(opt_level=3):
pass_list = [
tvm.relay.transform.SimplifyInference(),
tvm.relay.transform.MergeComposite(pattern_table),
]
composite_partition = tvm.transform.Sequential(pass_list)
partitioned = composite_partition(mod)

with tvm.transform.PassContext(opt_level=3):
        # Test on CPU only for now; torch cannot run quantized models on CUDA
@@ -587,3 +547,65 @@ def forward(self, inp):
# Outputs from v1.6 seem reliable. TVM's outputs are always the same
if is_version_greater_than("1.5.1"):
tvm.testing.assert_allclose(tvm_result, pt_result, rtol=1e-4, atol=1e-4)


def make_qnn_add_pattern():
from tvm.relay.dataflow_pattern import wildcard, is_op

lhs = wildcard()
rhs = wildcard()
lhs_scale = wildcard()
lhs_zero_point = wildcard()
rhs_scale = wildcard()
rhs_zero_point = wildcard()
output_scale = wildcard()
output_zero_point = wildcard()
qadd = is_op("qnn.add")(
lhs,
rhs,
lhs_scale,
lhs_zero_point,
rhs_scale,
rhs_zero_point,
output_scale,
output_zero_point,
)
return qadd.optional(is_op("clip"))


@register_pattern_table("test_table")
def pattern_table():
return [
("qnn_add", make_qnn_add_pattern()),
]


def run_qnn_mergecomposite(script_module, input_name, ishape):
input_shapes = [(input_name, ishape)]
mod, params = relay.frontend.from_pytorch(script_module, input_shapes)
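    # Look up the pattern table registered above via @register_pattern_table("test_table")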
pattern_table = get_pattern_table("test_table")
with tvm.transform.PassContext(opt_level=3):
pass_list = [
tvm.relay.transform.SimplifyInference(),
tvm.relay.transform.MergeComposite(pattern_table),
]
composite_partition = tvm.transform.Sequential(pass_list)
partitioned = composite_partition(mod)


def test_qnn_mergecomposite():
from torchvision.models.quantization import resnet as qresnet

model = qresnet.resnet18(pretrained=True)
model.eval()

inp = torch.zeros((1, 3, 224, 224))
model.fuse_model()
model.qconfig = torch.quantization.get_default_qconfig("fbgemm")
torch.quantization.prepare(model, inplace=True)
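    # One forward pass on representative input calibrates the observers inserted by prepare()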
model(inp)
torch.quantization.convert(model, inplace=True)
script_module = torch.jit.trace(model, inp).eval()

input_name = "image"
run_qnn_mergecomposite(script_module, input_name, inp.shape)
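
For readers extending this test: run_qnn_mergecomposite runs MergeComposite but does not yet assert anything about the result. Below is a minimal sketch, not part of this PR, of how the merge could be verified; it assumes run_qnn_mergecomposite were changed to return `partitioned`, and `assert_qnn_add_merged` is a hypothetical helper.

from tvm import relay


def assert_qnn_add_merged(partitioned_mod):
    """Assert that MergeComposite produced at least one "qnn_add" composite call."""
    composites = []

    def visit(expr):
        # MergeComposite wraps each matched subgraph in a relay.Function whose
        # "Composite" attribute holds the pattern name from the pattern table.
        if isinstance(expr, relay.Function) and expr.attrs and "Composite" in expr.attrs:
            composites.append(str(expr.attrs["Composite"]))

    relay.analysis.post_order_visit(partitioned_mod["main"], visit)
    assert "qnn_add" in composites

With a check like this, the regression test would fail not only on conversion errors but also if the qnn.add + clip pattern stopped matching after a frontend change.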