JAX GPU behaviour changes when using debug printing #22145
-
Hi, while re-testing some old code (during a review round for the related paper) I noticed a regression. It likely relates to Cholesky decompositions / `sym_pos` solves being less numerically stable in float32 than they used to be. I was trying to isolate the problem and build an MRE, but, weirdly enough, the issue disappears when I print the output of the Cholesky solver! Could someone explain why this happens? That would likely help me create the MRE. Thanks! (As a side quest: if anyone has ideas about where regressions in Cholesky and other linear algebra routines could come from, send them my way!)
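A sketch of one way to frame such an MRE. It assumes the affected code goes through `jax.scipy.linalg.cho_factor`/`cho_solve` or through `jax.scipy.linalg.solve` with `assume_a="pos"` (the SciPy-style option that superseded the old `sym_pos` flag). It measures the float32 error of both paths against a float64 NumPy reference as the condition number grows. `make_spd`, `cho_path`, `pos_path` and `float32_errors` are hypothetical names, and the matrices are synthetic rather than taken from the original code.

```python
# Hypothetical MRE harness: all helper names below are illustrative.
import numpy as np
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve, solve


def make_spd(n, cond, seed=0):
    # Hypothetical helper: random SPD matrix Q diag(lam) Q^T whose
    # eigenvalues span [1, cond], so its condition number is cond.
    rng = np.random.default_rng(seed)
    q, _ = np.linalg.qr(rng.standard_normal((n, n)))
    lam = np.logspace(0.0, np.log10(cond), n)
    a = (q * lam) @ q.T
    return (a + a.T) / 2.0  # symmetrize away rounding error


@jax.jit
def cho_path(a, b):
    # Explicit Cholesky factorization followed by triangular solves
    return cho_solve(cho_factor(a, lower=True), b)


@jax.jit
def pos_path(a, b):
    # Generic solver told that the matrix is symmetric positive-definite
    return solve(a, b, assume_a="pos")


def float32_errors(a64, b64):
    # Relative error of the float32 JAX solves against a float64 NumPy reference
    x_ref = np.linalg.solve(a64, b64)
    a32 = jnp.asarray(a64, dtype=jnp.float32)
    b32 = jnp.asarray(b64, dtype=jnp.float32)
    errors = {}
    for name, fn in (("cho_solve", cho_path), ("solve_pos", pos_path)):
        x = np.asarray(fn(a32, b32), dtype=np.float64)
        errors[name] = float(np.linalg.norm(x - x_ref) / np.linalg.norm(x_ref))
    return errors


if __name__ == "__main__":
    b = np.ones(64)
    for cond in (1e1, 1e3, 1e5):
        print(f"cond={cond:.0e}", float32_errors(make_spd(64, cond), b))
```

Running the same script under the old and new JAX/jaxlib versions, on the GPU where the regression appears, would show whether the float32 error itself changed, independently of any printing.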
-
This kind of issue is often related to compiler fusions. When you `jax.debug.print` an intermediate value, it can prevent the compiler from fusing the operations that produce that value, because inside the fused operation the intermediate value may never actually be computed as a standalone result. To investigate, you could use the ahead-of-time lowering tools to see how the compiler optimizes the computation around the outputs of your decomposition. That might give a clue as to which optimization pass is causing the issue you're seeing.
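A minimal sketch of the ahead-of-time approach. It assumes a Cholesky-based solve via `jax.scipy.linalg.cho_factor`/`cho_solve`, since the original code isn't shown; `solve`, `compiled_hlo` and `hlo_diff` are hypothetical names. `jax.jit(...).lower(...)` traces and lowers the function, and `.compile().as_text()` returns the HLO after XLA's optimization passes. Diffing the variant with and without `jax.debug.print` should therefore show which fusions the print breaks up.

```python
import difflib

import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve


def solve(a, b, debug=False):
    # Hypothetical stand-in for the original code.
    # Factor the symmetric positive-definite matrix as a = L @ L.T
    c, lower = cho_factor(a, lower=True)
    if debug:
        # Printing forces the factor to exist as a standalone value,
        # which can stop XLA from fusing it with neighbouring operations.
        jax.debug.print("cholesky factor:\n{}", c)
    # Solve a @ x = b with the triangular factor
    return cho_solve((c, lower), b)


# `debug` only changes what gets traced, so it must be a static argument.
solve_jit = jax.jit(solve, static_argnames="debug")


def compiled_hlo(a, b, debug):
    # lower() traces the function to StableHLO; compile() runs XLA's
    # optimization passes (including fusion) for the current backend.
    return solve_jit.lower(a, b, debug=debug).compile().as_text()


def hlo_diff(a, b):
    # Line-by-line diff of the optimized HLO without vs. with the print.
    return list(
        difflib.unified_diff(
            compiled_hlo(a, b, False).splitlines(),
            compiled_hlo(a, b, True).splitlines(),
            fromfile="no_print",
            tofile="with_print",
            lineterm="",
        )
    )


if __name__ == "__main__":
    m = jax.random.normal(jax.random.PRNGKey(0), (4, 4), dtype=jnp.float32)
    a = m @ m.T + 4.0 * jnp.eye(4, dtype=jnp.float32)  # well-conditioned SPD
    b = jnp.ones((4,), dtype=jnp.float32)

    x_plain = solve_jit(a, b, debug=False)
    x_print = solve_jit(a, b, debug=True)
    print("max |x_plain - x_print| =", float(jnp.max(jnp.abs(x_plain - x_print))))
    print("\n".join(hlo_diff(a, b)))
```

Fusion decisions are made per backend, so the diff is most informative when run on the same GPU setup where the regression shows up.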