Issue with jax_getattr inside jax.scan when the PyTree has multiple leaves #23782
Labels: bug
Description
Using jax_getattr blows up if the following conditions are true:
- … jax.lax.scan'd over (and probably other loops)

Reproducer (pass --bug to trigger the problem):

Both --jit and --no-jit fail, but slightly differently.

The proximate issue seems to be a confusion in loops.py about whether the leaves of tracked PyTrees are flattened or not, but I haven't worked through the details. For example, this code expects flattening, but perhaps this caller isn't aware of that?
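To make the suspected flattened-vs-unflattened mismatch concrete, here is a minimal sketch using only public jax.tree_util and jax.lax.scan calls. It is not the reproducer from this issue and does not call jax_getattr. The names state and step are hypothetical. It shows why a single-leaf value can hide the bug (flattened or not, it is one leaf), while a multi-leaf PyTree only works if the loop consistently carries flat leaves and unflattens them with the saved treedef.

```python
import jax
import jax.numpy as jnp
from jax import tree_util

# Hypothetical tracked state: a PyTree with two leaves (a dict of arrays).
state = {"a": jnp.zeros(3), "b": jnp.ones(2)}
leaves, treedef = tree_util.tree_flatten(state)  # dict leaves come out in sorted key order: "a", "b"

# A bare array is a PyTree with exactly one leaf, so code that forgets to
# flatten (treating the whole value as a single leaf) happens to work for it.
single_leaves, single_treedef = tree_util.tree_flatten(jnp.zeros(3))

def step(flat_carry, x):
    # Consistent convention (an assumption for this sketch): the scan carry
    # holds the flat leaves, and each step unflattens with the saved treedef.
    carry = tree_util.tree_unflatten(treedef, flat_carry)
    carry = {"a": carry["a"] + x, "b": carry["b"] * 2.0}
    new_leaves, new_treedef = tree_util.tree_flatten(carry)
    # Checked at trace time: the PyTree structure must not change across iterations.
    assert new_treedef == treedef
    return new_leaves, None

final_leaves, _ = jax.lax.scan(step, leaves, jnp.arange(3.0))
final_state = tree_util.tree_unflatten(treedef, final_leaves)
```

With two leaves, mixing up "one value" and "list of leaves" changes the carry's length and structure, which lax.scan rejects. With one leaf the two conventions are much harder to tell apart.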
System info (python version, jaxlib version, accelerator, etc.)
I can't easily test a newer nightly right now, but the relevant code looks unchanged from a quick inspection.