
[PatternLang] Lift constant nodes to partitioned function arguments #5662

Closed
comaniac opened this issue May 23, 2020 · 3 comments

comaniac commented May 23, 2020

In #5656, we found that pattern.partition does not lift bound constant nodes to the partitioned function's arguments. This results in an argument mismatch and could be a problem when combined with op fusion.

Here is an illustrative example:

import tvm
from tvm import relay
from tvm.relay.dataflow_pattern import *
from tvm.relay.build_module import bind_params_by_name
import numpy as np

x = relay.var('x', shape=(1, 3, 224, 224))
w = relay.var('w', shape=(3, 3, 3, 3))
b = relay.var('b', shape=(3,))

conv2d = relay.op.nn.conv2d(x, w)
out = relay.op.nn.bias_add(conv2d, b)
func = relay.Function([x, w, b], out)
mod = tvm.IRModule.from_expr(func)

mod["main"] = bind_params_by_name(mod["main"],
                                  {'w': tvm.nd.array(np.ones(shape=(3, 3, 3, 3)))})
print('=== Fuse ====')
print(relay.transform.FuseOps()(mod)['main'].body)

conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
pattern = is_op('nn.bias_add')(conv2d, wildcard())
print('=== Partition ===')
print(pattern.partition(mod['main'].body, {'Composite': 'aa'}))

Output:

=== Fuse ====
free_var %x: Tensor[(1, 3, 224, 224), float32]
free_var %b: Tensor[(3), float32]
%1 = fn (%p0: Tensor[(1, 3, 224, 224), float32], %p1: Tensor[(3, 3, 3, 3), float64], %p2: Tensor[(3), float32], Primitive=1) -> Tensor[(1, 3, 222, 222), float32] {
  %0 = nn.conv2d(%p0, %p1, padding=[0, 0, 0, 0]) /* ty=Tensor[(1, 3, 222, 222), float32] */;
  nn.bias_add(%0, %p2) /* ty=Tensor[(1, 3, 222, 222), float32] */
};
%1(%x, meta[relay.Constant][0] /* ty=Tensor[(3, 3, 3, 3), float64] */ /* ty=Tensor[(3, 3, 3, 3), float64] */, %b) /* ty=Tensor[(1, 3, 222, 222), float32] */
// meta data omitted. you can use show_meta_data=True to include meta data

=== Partition ===
free_var %x: Tensor[(1, 3, 224, 224), float32]
free_var %b: Tensor[(3), float32]
%1 = fn (%FunctionVar_0_0, %FunctionVar_0_1, Composite="aa", PartitionedFromPattern="nn.conv2d_nn.bias_add_") {
  %0 = nn.conv2d(%FunctionVar_0_0, meta[relay.Constant][0] /* ty=Tensor[(3, 3, 3, 3), float64] */ /* ty=Tensor[(3, 3, 3, 3), float64] */, padding=[0, 0, 0, 0]);
  nn.bias_add(%0, %FunctionVar_0_1)
};
%1(%x, %b)
// meta data omitted. you can use show_meta_data=True to include meta data

We can see that the function generated by op fusion keeps the original arguments and passes the constant node in at the call site. The partitioned function, however, accesses the constant node directly inside its body. Ideally, the partitioned function should be the same as the fused function.

cc @mbrookhart @masahi @zhiics


mbrookhart commented May 23, 2020

Hey Cody.

This might just be a design misalignment. When writing the partition function, I explicitly dropped the constants from the inputs because I assumed we'd want to propagate constants through the body of the function.

If that was an incorrect assumption, this change:

diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc
index 980935c34..f89116e34 100644
--- a/src/relay/ir/dataflow_matcher.cc
+++ b/src/relay/ir/dataflow_matcher.cc
@@ -544,7 +544,7 @@ class PatternGrouper : protected MixedModeVisitor {
           auto matches = node_map[node->ref_];
           for (auto match : matches) {
             if (fuzzy_matches.count(match) == 0 && match.as<OpNode>() == nullptr &&
-                match.as<FunctionNode>() == nullptr && match.as<ConstantNode>() == nullptr) {
+                match.as<FunctionNode>() == nullptr) {
               inputs[match] = Var(
                   "FunctionVar_" + std::to_string(graph_number_) + "_" + std::to_string(var_number),
                   NullValue<Type>());

Gives me this result:

=== Fuse ====
free_var %x: Tensor[(1, 3, 224, 224), float32]
free_var %b: Tensor[(3), float32]
%1 = fn (%p0: Tensor[(1, 3, 224, 224), float32], %p1: Tensor[(3, 3, 3, 3), float64], %p2: Tensor[(3), float32], Primitive=1) -> Tensor[(1, 3, 222, 222), float32] {
  %0 = nn.conv2d(%p0, %p1, padding=[0, 0, 0, 0]) /* ty=Tensor[(1, 3, 222, 222), float32] */;
  nn.bias_add(%0, %p2) /* ty=Tensor[(1, 3, 222, 222), float32] */
};
%1(%x, meta[relay.Constant][0] /* ty=Tensor[(3, 3, 3, 3), float64] */ /* ty=Tensor[(3, 3, 3, 3), float64] */, %b) /* ty=Tensor[(1, 3, 222, 222), float32] */
// meta data omitted. you can use show_meta_data=True to include meta data
=== Partition ===
free_var %x: Tensor[(1, 3, 224, 224), float32]
free_var %b: Tensor[(3), float32]
%1 = fn (%FunctionVar_0_0, %FunctionVar_0_1, %FunctionVar_0_2, Composite="aa", PartitionedFromPattern="nn.conv2d_nn.bias_add_") {
  %0 = nn.conv2d(%FunctionVar_0_0, %FunctionVar_0_1, padding=[0, 0, 0, 0]);
  nn.bias_add(%0, %FunctionVar_0_2)
};
%1(%x, meta[relay.Constant][0] /* ty=Tensor[(3, 3, 3, 3), float64] */ /* ty=Tensor[(3, 3, 3, 3), float64] */, %b)
// meta data omitted. you can use show_meta_data=True to include meta data

The only differences I see are the naming of the function variables and the fact that partition doesn't run type inference by default.

Do you think that's the correct behavior? It breaks a couple of other unit tests; I'll see if I can fix them.
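If it helps, the effect of that one-line change can be sketched in plain Python on a toy expression tree. The `Call`/`Const`/`Var` classes below are hypothetical stand-ins, not TVM's IR: every leaf input of the matched region is lifted to a fresh function parameter whether or not it is a constant, and the constant is then supplied at the call site.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class Const:
    value: float


@dataclass(frozen=True)
class Call:
    op: str
    args: tuple


def lift_inputs(expr, matched_ops):
    """Replace every leaf input of the matched region (vars AND constants)
    with a fresh parameter; return (new_body, params, call_args)."""
    params, call_args = [], []

    def rewrite(e):
        if isinstance(e, Call) and e.op in matched_ops:
            return Call(e.op, tuple(rewrite(a) for a in e.args))
        # Leaf input: lift it to a function parameter, even if it is a Const.
        p = Var(f"FunctionVar_0_{len(params)}")
        params.append(p)
        call_args.append(e)
        return p

    return rewrite(expr), params, call_args


# bias_add(conv2d(x, const_w), b), mirroring the example above
expr = Call("bias_add", (Call("conv2d", (Var("x"), Const(1.0))), Var("b")))
body, params, args = lift_inputs(expr, {"conv2d", "bias_add"})
print([p.name for p in params])  # three params: x, the constant, and b
print(args)                      # the constant is passed at the call site
```

With the old `ConstantNode` check in place, the `Const(1.0)` leaf would have been skipped and the constant left embedded in the body; dropping the check is what makes partition mirror the fusion behavior.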


comaniac commented May 23, 2020

Thanks for the advice!
This difference affects how constant values are used at runtime, and I think Relay functions should behave the same way as Relay ops. For example, we may have

%0 = nn.conv2d(%data, meta[relay.Constant][0]);

If we treat nn.conv2d as a Relay function, we can see that even if one of its arguments is a constant, we still maintain its function signature and pass the constant node via the call. Since this concept applies to both fused and composite functions, I think the pattern language should have the same behavior.
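As a plain-Python analogy (not TVM code), the two behaviors compute the same value and differ only in where the constant lives, which is exactly what matters to a backend consuming the partitioned function:

```python
def conv2d(data, weight):
    """Stand-in for a Relay op: fixed two-argument signature."""
    return [d * weight for d in data]


W = 2.0  # a constant bound into the module

# Keep the full signature and pass the constant at the call site,
# as op fusion does (and as partition should):
out_lifted = conv2d([1.0, 2.0], W)

# Capture the constant inside the body, shrinking the signature,
# as partition did before the fix:
def conv2d_captured(data):
    return [d * W for d in data]


out_captured = conv2d_captured([1.0, 2.0])

assert out_lifted == out_captured == [2.0, 4.0]
```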

Meanwhile, I would like to ask @tqchen @zhiics @yzhliu to double-confirm :)

@comaniac
#5663 merged. Thanks @mbrookhart :)
