Unused vmap GPU memory allocation causes RESOURCE_EXHAUSTED for versions >0.4.14 #23548
Comments
I checked the HLO when using
After commenting out the two lines containing exp these large tensors are not materialized:
I'm not sure why this code runs on JAX <0.4.14... it's possible some optimizations are being done differently. You can inspect the compiled code yourself using:
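As a minimal sketch of that inspection workflow (the function `f` below is a placeholder, not the code from the report), JAX's AOT APIs expose both the lowered and the XLA-optimized program text:

```python
import jax
import jax.numpy as jnp

def f(x):
    # Placeholder function; substitute the function from the report.
    return jnp.exp(x).sum()

x = jnp.ones((8,))
lowered = jax.jit(f).lower(x)

# Pre-optimization StableHLO, as emitted by JAX tracing:
print(lowered.as_text())

# Post-optimization HLO after XLA compilation; this is where you can
# see which intermediate buffers the backend will actually materialize:
print(lowered.compile().as_text())
```

Comparing the compiled text between jax versions is one way to confirm whether the large allocations come from a change in XLA's optimization pipeline rather than from JAX's tracing.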
Thanks for the response. I'm starting to think it is some change in OpenXLA or lower in the stack that is responsible rather than JAX itself. A few questions:
Do you think this seems like a bug, or just an old edge case that no longer works? When using:
Description
Overview
The script below works on an NVIDIA GPU with JAX version 0.4.14, but after upgrading to 0.4.31 (and trying a few other versions in between) it triggers the following error:
E0910 20:24:00.097739 38257 pjrt_stream_executor_client.cc:3067] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate X bytes
jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate X bytes.
where the value of `X` ranges from ~5 GB (e.g. 4843897104 bytes) to 20 GB+ depending on the shape of the `dls` variable (set to 3540 in the script below).

jax<=0.4.14 - no error
jax>0.4.14 - error
Not sure if this is a bug or if there is some code/syntax in the example below that is no longer supported in versions > 0.4.14 that is responsible for this behavior.
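The full script is attached below under "Script". Purely as an illustration (the function `m` and all shapes here are hypothetical stand-ins, not the author's code), the pattern in question is a `vmap` over a function containing `jnp.exp` calls, where XLA may batch the exponential's intermediates into one large buffer:

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the reported setup: `m` contains jnp.exp
# calls and is vmapped over a `dls` axis. In the report the dls axis
# has length 3540 and triggers multi-GB allocations on GPU; the sizes
# here are tiny so the sketch runs anywhere.
def m(dl):
    t = jnp.exp(-dl * jnp.arange(4.0))  # exp call of the kind implicated
    return t / (1.0 + t.sum())

dls = jnp.linspace(0.0, 1.0, 16)  # report uses shape (3540,)
out = jax.vmap(m)(dls)
print(out.shape)  # (16, 4)
```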
Allocation vs. pprof usage
The GPU has 6 GB of memory, and after some trial and error it appears that setting the
dls
variable to a shape of 1590 succeeds and uses only ~500 kB of memory according to pprof (following https://jax.readthedocs.io/en/latest/device_memory_profiling.html), but a shape of 1600 gives an error trying to allocate ~5 GB. If pprof is in fact showing GPU memory usage, this could suggest memory is being allocated but not used.
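For reference, the pprof measurements above follow the linked device-memory-profiling docs; as a sketch, that workflow amounts to saving a device-memory profile from the running script (the array below is just a placeholder live buffer):

```python
import jax
import jax.numpy as jnp
import jax.profiler

# Keep a buffer alive so the profile has something to show.
x = jnp.ones((1024, 1024))

# Writes a pprof-format protobuf of current device allocations.
jax.profiler.save_device_memory_profile("memory.prof")

# Inspect with the pprof tool, e.g.:
#   pprof --web memory.prof
```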
jnp.exp removal
Trial and error also showed that removing the
jnp.exp
calls inside the function `m` seems to resolve the issue. For example, the script below with `dls` shape set to 10000 fails trying to allocate 30 GB, but removing the `jnp.exp` calls succeeds and shows as using only ~2 MB in pprof.
Script
System info (python version, jaxlib version, accelerator, etc.)
Pip versions:
Output of
jax.print_environment_info()
, running inside a container based on nvidia/cuda:12.3.2-base-ubuntu22.04:
Pip versions of the latest version that does not show the error (v0.4.14):
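For anyone reproducing this, the same environment report can be generated programmatically:

```python
import jax

# Prints jax/jaxlib/numpy versions, the Python version, and device info.
jax.print_environment_info()
```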