From 6c7943d9cb4d4625603994f3c69bb9c271cebb13 Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Fri, 22 May 2020 23:35:30 -0700 Subject: [PATCH 1/2] remove constants from partitioned functions --- src/relay/ir/dataflow_matcher.cc | 6 ++--- tests/python/relay/test_dataflow_pattern.py | 29 ++++++++++++++------- 2 files changed, 22 insertions(+), 13 deletions(-) diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc index 2f25733b6cb9..dd9d80648d90 100644 --- a/src/relay/ir/dataflow_matcher.cc +++ b/src/relay/ir/dataflow_matcher.cc @@ -557,7 +557,7 @@ class PatternGrouper { auto matches = node_map[node->ref_]; for (auto match : matches) { if (fuzzy_matches.count(match) == 0 && match.as() == nullptr && - match.as() == nullptr && match.as() == nullptr) { + match.as() == nullptr) { inputs[match] = Var( "FunctionVar_" + std::to_string(graph_number_) + "_" + std::to_string(var_number), NullValue()); @@ -577,8 +577,8 @@ class PatternGrouper { auto extractor = MatchExtractor(inputs); auto body = extractor.Mutate(expr); - // Verify the pattern still holds - CHECK(DFPatternMatcher(body).Match(pattern_, body)); + // Verify the pattern still holds, no longer valid if we're not embedding constants in the + // graph, keep here for future debug CHECK(DFPatternMatcher(body).Match(pattern_, body)); group.function = Function(params, body, NullValue(), Array()); group.name = extractor.GetName(); // Check to make sure we aren't overlapping with another group diff --git a/tests/python/relay/test_dataflow_pattern.py b/tests/python/relay/test_dataflow_pattern.py index 17c8df45db4c..5d69521ec479 100644 --- a/tests/python/relay/test_dataflow_pattern.py +++ b/tests/python/relay/test_dataflow_pattern.py @@ -878,8 +878,8 @@ def nested_diamond(inp, weight): ) assert tvm.ir.structural_equal(partitioned, reference) -def get_BN(x, var, mean, beta, gamma, eps = 1e-5): - return gamma * (x - mean)/relay.op.sqrt(var + relay.const(eps)) + beta +def get_BN(x, var, mean, beta, gamma, eps): + return gamma * (x - mean)/relay.op.sqrt(var + eps) + beta def test_partition_batchnorm(): x = relay.var('x') @@ -887,7 +887,8 @@ def test_partition_batchnorm(): mean = relay.var('mean') beta = relay.var('beta') gamma = relay.var('gamma') - BN = get_BN(x, var, mean, beta, gamma) + eps = relay.const(1e-5) + BN = get_BN(x, var, mean, beta, gamma, eps) xf = relay.var('xf') @@ -895,11 +896,14 @@ def test_partition_batchnorm(): meanf = relay.var('meanf') betaf = relay.var('betaf') gammaf = relay.var('gammaf') + epsf = relay.var('epsf') # Put the arguments in toplogological order for the reference - f = relay.Function([gammaf, xf, meanf, varf, betaf], get_BN(xf, varf, meanf, betaf, gammaf)).with_attr("PartitionedFromPattern","subtract_multiply_add_sqrt_divide_add_") + f = relay.Function([gammaf, xf, meanf, varf, epsf, betaf], get_BN(xf, varf, meanf, betaf, gammaf, epsf)).with_attr("PartitionedFromPattern","subtract_multiply_add_sqrt_divide_add_") partitioned = BatchnormCallback().pattern.partition(BN) - assert tvm.ir.structural_equal(partitioned, f(gamma, x, mean, var, beta)) + print(partitioned) + print(f(gamma, x, mean, var, beta)) + assert tvm.ir.structural_equal(partitioned, f(gamma, x, mean, var, eps, beta)) def test_partition_double_batchnorm(): x = relay.var('x') @@ -907,26 +911,31 @@ def test_partition_double_batchnorm(): mean = relay.var('mean') beta = relay.var('beta') gamma = relay.var('gamma') + eps = relay.const(1e-5) - BN = gamma * (x - mean)/relay.op.sqrt(var + relay.const(1e-5)) + beta - BN2 = gamma * (BN - mean)/relay.op.sqrt(var + relay.const(1e-5)) + beta + BN = gamma * (x - mean)/relay.op.sqrt(var + eps) + beta + BN2 = gamma * (BN - mean)/relay.op.sqrt(var + eps) + beta xf = relay.var('xf') varf = relay.var('varf') meanf = relay.var('meanf') betaf = relay.var('betaf') gammaf = relay.var('gammaf') - f1 = relay.Function([gammaf, xf, meanf, varf, betaf], get_BN(xf, varf, meanf, betaf, gammaf)).with_attr("PartitionedFromPattern","subtract_multiply_add_sqrt_divide_add_") + epsf = relay.var('epsf') + f1 = relay.Function([gammaf, xf, meanf, varf, epsf, betaf], get_BN(xf, varf, meanf, betaf, gammaf, epsf)).with_attr("PartitionedFromPattern","subtract_multiply_add_sqrt_divide_add_") # The partitioner doesn't replace duplicates, so we use two copies of the function xf2 = relay.var('xf2') varf2 = relay.var('varf2') meanf2 = relay.var('meanf2') betaf2 = relay.var('betaf2') gammaf2 = relay.var('gammaf2') - f2 = relay.Function([gammaf2, xf2, meanf2, varf2, betaf2], get_BN(xf2, varf2, meanf2, betaf2, gammaf2)).with_attr("PartitionedFromPattern","subtract_multiply_add_sqrt_divide_add_") + epsf2 = relay.var('epsf2') + f2 = relay.Function([gammaf2, xf2, meanf2, varf2, epsf2, betaf2], get_BN(xf2, varf2, meanf2, betaf2, gammaf2, epsf2)).with_attr("PartitionedFromPattern","subtract_multiply_add_sqrt_divide_add_") partitioned = BatchnormCallback().pattern.partition(BN2) - reference = f2(gamma, f1(gamma, x, mean, var, beta), mean, var, beta) + reference = f2(gamma, f1(gamma, x, mean, var, eps, beta), mean, var, eps, beta) + print(partitioned) + print(reference) assert tvm.ir.structural_equal(partitioned, reference) def test_partition_check(): From 547849231c463b5c3338da1cc3dbeaae2f40e251 Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Fri, 22 May 2020 23:56:55 -0700 Subject: [PATCH 2/2] remove print statements --- tests/python/relay/test_dataflow_pattern.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/python/relay/test_dataflow_pattern.py b/tests/python/relay/test_dataflow_pattern.py index 5d69521ec479..ed90873421f3 100644 --- a/tests/python/relay/test_dataflow_pattern.py +++ b/tests/python/relay/test_dataflow_pattern.py @@ -901,8 +901,6 @@ def test_partition_batchnorm(): f = relay.Function([gammaf, xf, meanf, varf, epsf, betaf], get_BN(xf, varf, meanf, betaf, gammaf, epsf)).with_attr("PartitionedFromPattern","subtract_multiply_add_sqrt_divide_add_") partitioned = BatchnormCallback().pattern.partition(BN) - print(partitioned) - print(f(gamma, x, mean, var, beta)) assert tvm.ir.structural_equal(partitioned, f(gamma, x, mean, var, eps, beta)) def test_partition_double_batchnorm(): @@ -934,8 +932,6 @@ def test_partition_double_batchnorm(): partitioned = BatchnormCallback().pattern.partition(BN2) reference = f2(gamma, f1(gamma, x, mean, var, eps, beta), mean, var, eps, beta) - print(partitioned) - print(reference) assert tvm.ir.structural_equal(partitioned, reference) def test_partition_check():