Large GPU memory allocation for newer versions #17124
Comments
Could you also share the before_optimizations dump? Also, what GPU are you running it on?
Locally I have an NVIDIA GeForce RTX 4050 card. I'm also seeing the issue running on AWS on g4dn instances with 1 or more NVIDIA T4 GPUs. These outputs are with an NVIDIA driver with CUDA 12.3, but upgrading to the latest JAX version (0.4.31) with CUDA 12.6 appears to show the same behavior.

before_optimizations.txt dumps:
jax==0.4.14, xla_client._version=174, module_0000.jit_run.before_optimizations.txt
jax==0.4.16, xla_client._version=194, module_0000.jit_run.before_optimizations.txt
Oh, those are quite old versions; apparently it changed between July and November 2023, and a lot has changed since then. Not sure whether bisection is doable or useful.
There are instructions that output an f32[3540,71,75,71] tensor, which is what is taking 5 GB of VRAM. The older version emits everything as a single kernel (see the graph below), while the newer one splits it into three (causing the materialization of that tensor). Generally, when partitioning the graph into kernels, the compiler doesn't try to optimize for memory usage, only for runtime. However, I doubt that materializing 5 GB twice would make it faster.
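For reference, the size of that intermediate follows directly from its shape; a quick back-of-the-envelope check (plain Python, nothing issue-specific assumed):

```python
import numpy as np

# f32[3540, 71, 75, 71] intermediate mentioned above: 4 bytes per element.
elements = 3540 * 71 * 75 * 71
size_bytes = elements * np.dtype(np.float32).itemsize
print(f"{elements:,} elements -> {size_bytes / 2**30:.2f} GiB")
# 1,338,385,500 elements -> 4.99 GiB
```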
Is there a way to recreate this as an XLA test locally? I'm happy to try myself. It could then be run against a few versions/commits to try to narrow down/verify what change caused it. A flag to disable whatever optimization is causing this would also work. Is XLA able to take into account the availability of memory when optimizing? Typically individual GPUs don't have that much memory. I also experimented with pmap-ing it over multiple GPUs but, as I recall, got the same sized memory error. Obviously in this particular case I'd prefer it to be slightly less time-optimized in order to run/fit on a given GPU.
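One way to get something reproducible outside the full script is to dump the HLO that XLA compiles and keep it for inspection across versions. This is a sketch, assuming the setup from jax-ml/jax#23548; the dump path and the computation are placeholders:

```python
import os

# Must be set before importing jax: dump the HLO modules that XLA compiles,
# as text, both before and after optimizations.
os.environ["XLA_FLAGS"] = "--xla_dump_to=/tmp/xla_dump --xla_dump_hlo_as_text"
# Optional: stop JAX from preallocating most of the GPU up front, so an
# allocation failure points at the actual oversized buffer.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import jax
import jax.numpy as jnp

# Stand-in for the script from jax-ml/jax#23548; any jitted computation
# will produce module_*.before_optimizations.txt files in /tmp/xla_dump.
dls = jnp.ones(shape=(1600, 3))
jax.jit(lambda x: (x * x).sum())(dls)
```

The resulting before/after-optimization text files are the same kind of dumps attached above, and can in principle be replayed against different XLA builds (for example with XLA's run_hlo_module tool) without going through JAX.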
It won't help to know the commit that changed the behavior. The code that builds fusions has been completely rewritten since then anyway, so we'd need to look at whether it's possible to fix the current code so that it does this. I would also be curious to look into the …
Ok, that makes sense, thanks! I generated an XLA dump for the latest version of JAX (v0.4.33) and the before- and after-optimization files were the same for the working and failing cases. I looked at the buffer assignment files and both are allocating close to 5 GB, but the one that works is 4.49 GB and the one that fails is 4.5 GB, so I assume the 1594/1595 threshold is just the limit of what can be allocated locally, with no change in actual fusion behaviour. The JAX docs (https://jax.readthedocs.io/en/latest/device_memory_profiling.html) describe how to profile device memory, but I'm not sure if I'm misinterpreting the output. I've attached the memory profile visualized with pprof for the working case.
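For completeness, a device-memory profile like the one attached can be generated with the workflow from that doc page. This is only a sketch; the computation below is a stand-in, not the actual script from the issue:

```python
import jax
import jax.numpy as jnp
import jax.profiler

# Stand-in computation; replace with the script from jax-ml/jax#23548.
dls = jnp.ones(shape=(1590, 3))
out = (dls @ dls.T).block_until_ready()

# Writes a pprof-compatible snapshot of live device allocations.
jax.profiler.save_device_memory_profile("memory.prof")
# Inspect with: pprof --web memory.prof
```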
This is a more XLA-specific version of jax-ml/jax#23548, encountered using an NVIDIA GPU.
Basically, when using a value of

dls = jnp.ones(shape=(1590, 3))

the program ran successfully and pprof reported ~500 kB of memory usage, but increasing to

dls = jnp.ones(shape=(1600, 3))

fails trying to allocate ~5 GB. So it seems like there might be a difference between requested GPU memory and actual usage.

The JAX script described in that ticket works on JAX v0.4.14 but not on versions v0.4.16 and above. The Python xla_client version for v0.4.14 is 174 and for v0.4.16 is 194 (see the version-check sketch after the dump listing below), so I'm thinking the change in behaviour must have occurred between
5ca49a9 and 326f72f. I see there were some changes/additions to cuDNN fusion logic between those commits, and the last line of module_0000.jit_run.sm_8.9_gpu_after_optimizations.txt seems to reference "fusion" rather than "reduce" as before; not sure if that is relevant.

I've provided the content of the module_0000.jit_run.sm_8.9_gpu_after_optimizations.txt files below. Let me know if there's any other information I can provide.

xla_dump output
jax==0.4.14, xla_client._version=174, module_0000.jit_run.sm_8.9_gpu_after_optimizations.txt
jax==0.4.16, xla_client._version=194, module_0000.jit_run.sm_8.9_gpu_after_optimizations.txt
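For reference, the jax and xla_client versions quoted above can be read out like this (a sketch; xla_client._version is the private attribute already referenced in this issue):

```python
import jax
from jax.lib import xla_client

print(jax.__version__)      # e.g. 0.4.14 or 0.4.16
print(xla_client._version)  # e.g. 174 or 194
```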