-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Relay] Add support for tuple node in operator fusion #2187
Conversation
This looks right, @masahi We also need to fix for opt_level=0, which will break every op into a single function. But we dont want to do that for a single Tuple node and isolate that into a function. A solution can either be:
|
I would recommend handling the Tuple as root case, so that this can be used for future purposes |
@tqchen I handled the cases for opt level = 0 and when tuple is the root of the group. But I'm not sure if I handled them correctly. When tuple is a isolated, it is not put into a function. But if it is fused with other ops before it, I make a function that returns a tuple. Please have a look at the new test case. |
@masahi your implementation is correct, However, it is likely we don't need to maintain a node count.What you can do instead is to quickly check if all the fields in the tuple are params of that GroupInfo and if they are, it means it is an isolated function. Good job in making this happen |
@tqchen let me know if there is a better way to test equality of two |
👍 looks like a good change, this might of been one of the reasons fusion was being blocked for me and @MarisaKirisame on one of our examples for the paper. |
src/relay/pass/fuse_ops.cc
Outdated
if (ret_group == gmap_.at(tuple)) { | ||
bool isolated = true; | ||
for (size_t i = 0; i < new_fields.size(); ++i) { | ||
isolated &= (new_fields[i] == ginfo_[ret_group].params[i]); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
change to new_fields[i].same_as(ginfo_[ret_group].params[i])
, in case == get overloaded in the future
Fixes #2183. It enables fusing a concat node with other ops before it.
@tqchen please review.
Example: Upsampling + Concat