JAX GPU behaviour changes when using debug printing #22145

Answered by jakevdp
AdrienCorenflos asked this question in Q&A
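This kind of issue is often related to compiler fusions. When you `jax.debug.print` an intermediate value, the compiler has to materialize that value, which can prevent it from fusing the operations that produce it: in the fused operation the intermediate value may never actually be computed. Because fusion can change how floating-point operations are ordered and rounded, the results with and without the print can differ.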

To investigate, you could use the ahead-of-time lowering tools (`jax.jit(f).lower(...)` followed by `.compile()`) to see how the compiler is optimizing the outputs of your decomposition. Comparing the compiled HLO with and without the print might give a clue as to which optimization pass is causing the final issue you're seeing.
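As a rough sketch of that kind of investigation, the snippet below compiles the same computation with and without `jax.debug.print` and dumps the optimized HLO for each. The functions `f`, `f_debug`, and `optimized_hlo` are hypothetical stand-ins, not the code from the original question, and the exact HLO text depends on your JAX version and backend.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Hypothetical stand-in for the real computation.
    y = jnp.sin(x) * 2.0  # intermediate value the compiler is free to fuse away
    return jnp.sum(y ** 2)


def f_debug(x):
    # Same computation, but printing the intermediate forces it to be materialized.
    y = jnp.sin(x) * 2.0
    jax.debug.print("y = {}", y)
    return jnp.sum(y ** 2)


def optimized_hlo(fn, *args):
    # Ahead-of-time path: trace and lower, then compile, then dump
    # the HLO after the compiler's optimization passes have run.
    lowered = jax.jit(fn).lower(*args)
    compiled = lowered.compile()
    return compiled.as_text()


x = jnp.arange(4.0)
hlo_plain = optimized_hlo(f, x)
hlo_debug = optimized_hlo(f_debug, x)

# Inspect or diff the two dumps to see where fusion decisions differ.
print(hlo_plain)
print(hlo_debug)
```

Diffing the two dumps (for example with `difflib` or by saving them to files) should show the extra callback in the debug version and any fusions that were split around it. `jax.jit(fn).lower(*args).as_text()` gives the program before optimization, which can help separate tracing differences from compiler differences.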

Answer selected by AdrienCorenflos