Issue with jax_getattr inside jax.scan when the PyTree has multiple leaves #23782

Open
nelhage opened this issue Sep 19, 2024 · 1 comment
Labels
bug Something isn't working

nelhage commented Sep 19, 2024

Description

Using jax_getattr blows up if the following conditions are true:

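  • It is invoked inside a function that is looped over with jax.lax.scan (and probably other loop primitives; the reproducer below uses lax.fori_loop, which is implemented with scan when its bounds are concrete Python integers)
  • The retrieved attribute is a PyTree whose number of leaves is not exactly 1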
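For reference, the leaf counts of the attribute values in the reproducer can be checked directly with jax.tree_util.tree_leaves. The helper name `num_leaves` is hypothetical, and the empty dict is an extra case not in the reproducer: it also satisfies "not exactly 1", but this report does not say whether it fails.

```python
import jax

def num_leaves(tree):
    # Number of leaves after flattening the PyTree (Python ints count as leaves).
    return len(jax.tree_util.tree_leaves(tree))

print(num_leaves(dict(x=1)))       # 1: the working case (no --bug)
print(num_leaves(dict(x=1, z=2)))  # 2: the failing case (--bug)
print(num_leaves(dict()))          # 0: untested here, but also not exactly 1
```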

Reproducer (pass --bug to trigger the problem). It fails both with --jit and with --no-jit, but with slightly different errors:

import jax
from jax import lax
from jax.experimental.attrs import jax_getattr

import argparse


class C:
    def __init__(self):
        self.vals = dict(x=0)


state = C()


def f(y):
    v = jax_getattr(state, "vals")
    return v["x"] + y


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--jit", default=False, action="store_true")
    parser.add_argument("--no-jit", action="store_false", dest="jit")

    parser.add_argument("--bug", default=False, action="store_true")

    args = parser.parse_args()

    def f_iter(i, v):
        return f(v)

    def do_loop():
        return lax.fori_loop(0, 5, f_iter, 0)

    if args.jit:
        do_loop = jax.jit(do_loop)

    if args.bug:
        state.vals = dict(x=1, z=2)
    else:
        state.vals = dict(x=1)

    print(f"running with {state.vals=}")
    print(f"{do_loop()=}")


if __name__ == "__main__":
    main()
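As a control, here is a sketch of the same loop with jax_getattr removed and the two-leaf dict simply closed over. This is an assumption-laden comparison, not a workaround: it drops attribute tracking entirely. It only suggests that a multi-leaf PyTree by itself is not a problem for lax.fori_loop or lax.scan, with or without jit. The names `add_x`, `via_fori`, and `via_scan` are hypothetical.

```python
import jax
from jax import lax

# A plain two-leaf PyTree, closed over directly (no jax_getattr involved).
vals = dict(x=1, z=2)

def add_x(y):
    return vals["x"] + y

def via_fori():
    # Same loop shape as do_loop() in the reproducer.
    return lax.fori_loop(0, 5, lambda i, v: add_x(v), 0)

def via_scan():
    # The equivalent loop written directly with lax.scan; no per-step outputs.
    def body(carry, _):
        return add_x(carry), None
    carry, _ = lax.scan(body, 0, xs=None, length=5)
    return carry

# Both compute 0 + 5 * vals["x"], i.e. 5, with and without jit.
print(via_fori(), via_scan())
print(jax.jit(via_fori)(), jax.jit(via_scan)())
```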

The proximate cause seems to be confusion in loops.py about whether the leaves of tracked PyTrees have already been flattened, but I haven't worked through the details.
For example, this code expects them to be flattened, but perhaps this caller isn't aware of that?
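A minimal model of why such a mix-up would only show up when the leaf count is not 1. This is hypothetical and is not the actual loops.py code: if one side keeps one entry per tracked attribute and the other keeps one entry per flattened leaf, the two lists line up one-to-one only when every attribute is a single leaf. The functions `per_attribute_view`, `flattened_view`, and `views_agree` are illustrative names.

```python
import jax

# `attrs` maps a tracked attribute name to its current value (a PyTree).

def per_attribute_view(attrs):
    # One entry per tracked attribute, with each PyTree left unflattened.
    return list(attrs.values())

def flattened_view(attrs):
    # One entry per leaf, after flattening every attribute's PyTree.
    return jax.tree_util.tree_leaves(list(attrs.values()))

def views_agree(attrs):
    # The lengths match only if every attribute contributes exactly one leaf.
    return len(per_attribute_view(attrs)) == len(flattened_view(attrs))

print(views_agree({"vals": dict(x=1)}))        # True
print(views_agree({"vals": dict(x=1, z=2)}))   # False
```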

System info (python version, jaxlib version, accelerator, etc.)

❯ python -c 'import jax; jax.print_environment_info()'
jax:    0.4.31.dev20240722
jaxlib: 0.4.31.dev20240722
numpy:  1.24.4
python: 3.11.6 | packaged by conda-forge | (main, Oct  3 2023, 10:37:07) [Clang 15.0.7 ]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='Nelson-Elhage-MacBook', release='23.6.0', version='Darwin Kernel Version 23.6.0: Wed Jul 31 20:49:46 PDT 2024; root:xnu-10063.141.1.700.5~1/RELEASE_ARM64_T8103', machine='arm64')

I can't easily test a newer nightly right now, but the relevant code looks unchanged from a quick inspection.

nelhage added the bug label Sep 19, 2024
nelhage commented Sep 20, 2024

Attaching output and stack traces from my machine for easier skimming and in case it doesn't reproduce somehow.
