[PatternLang] Lift constant nodes to partitioned function arguments #5662
Comments
Hey Cody. This might just be a design misalignment. When writing the partition function, I explicitly dropped the constants from the inputs because I assumed we'd want to propagate constants through the body of the function. If that was an incorrect assumption, this change:
Gives me this result:
where the only differences I see are the naming of the function variables and the fact that partition doesn't do type inference by default. Do you think that's the correct behavior? It breaks a couple of other unit tests; I'll see if I can fix them.
Thanks for the advice!
If we treat … Meanwhile, I would like to ask @tqchen @zhiics @yzhliu to double-confirm :)
#5663 merged. Thanks @mbrookhart :)
In #5656, we found that `pattern.partition` does not lift bound constant nodes to the partitioned function's arguments. This results in an argument mismatch and could become a problem when partitioning is applied to op fusion. Here is an illustrative example:
Output:
We can see that the function generated by op fusion keeps the original arguments and passes the constant node in at the function call. The partitioned function, however, accesses the constant node directly from inside its body. Ideally, the partitioned function should be the same as the fused function.
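The difference between the two behaviors can be sketched with a toy expression IR in plain Python. This is not TVM's API: `Var`, `Const`, `Call`, `Function`, `partition` and the `lift_constants` flag are hypothetical stand-ins, chosen only to contrast keeping constants inside the function body with lifting them to parameters that are passed at the call site (as the fused function does).

```python
from dataclasses import dataclass
from typing import Dict, List, Tuple, Union


# Hypothetical toy IR nodes (not TVM classes).
@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class Const:
    value: float


@dataclass(frozen=True)
class Call:
    op: str
    args: Tuple["Expr", ...]


Expr = Union[Var, Const, Call]


@dataclass(frozen=True)
class Function:
    params: Tuple[Var, ...]
    body: Expr


def partition(expr: Expr, lift_constants: bool) -> Tuple[Function, List[Expr]]:
    """Wrap a matched subgraph in a Function.

    Returns the function plus the arguments to call it with. Leaf inputs
    become parameters; constants are lifted too only when lift_constants
    is True, otherwise they stay embedded in the function body.
    """
    params: List[Var] = []
    call_args: List[Expr] = []
    memo: Dict[Expr, Var] = {}  # reuse one parameter per distinct input

    def visit(e: Expr) -> Expr:
        if isinstance(e, Call):
            return Call(e.op, tuple(visit(a) for a in e.args))
        if isinstance(e, Const) and not lift_constants:
            return e  # constant is accessed from inside the body
        if e not in memo:
            p = Var(f"p{len(params)}")
            params.append(p)
            call_args.append(e)  # the caller passes the original input
            memo[e] = p
        return memo[e]

    body = visit(expr)
    return Function(tuple(params), body), call_args


if __name__ == "__main__":
    x = Var("x")
    w = Const(2.0)  # stands in for a bound weight constant
    expr = Call("nn.conv2d", (x, w))

    lifted, lifted_args = partition(expr, lift_constants=True)
    inlined, inlined_args = partition(expr, lift_constants=False)
    print(lifted, lifted_args)    # two params; constant passed at the call
    print(inlined, inlined_args)  # one param; constant stays in the body
```

With `lift_constants=True` the function signature matches the original call's arguments, which is the property the fused function has; with `False` the caller passes fewer arguments than the original operator took, which is the mismatch described above.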
cc @mbrookhart @masahi @zhiics