CUDA out of memory #7

Open
Tronic opened this issue Apr 20, 2023 · 6 comments

Comments

@Tronic

Tronic commented Apr 20, 2023

Trying to load the medium or large model, I get out-of-memory errors. Loading the small model with float16 precision works, but it takes all of my 24 GB of VRAM. Is there any way to limit JAX memory usage? The OpenAI model is far more modest in its requirements. Reducing the model weights to float16 would be a good idea too.

@sanchit-gandhi
Owner

See related: huggingface/transformers#22224

@sanchit-gandhi
Owner

sanchit-gandhi commented Apr 21, 2023

You can also convert the parameters to float16/bfloat16 as follows:

# for fp16
pipeline.params = pipeline.model.to_fp16(pipeline.params)
# for bf16
pipeline.params = pipeline.model.to_bf16(pipeline.params)
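
For context, a rough end-to-end sketch of the fp16 conversion (the FlaxWhisperPipline class name is taken from this repo's README, and the checkpoint name and audio file are placeholders):

from whisper_jax import FlaxWhisperPipline

# load the pipeline (checkpoint name is illustrative)
pipeline = FlaxWhisperPipline("openai/whisper-small")

# cast the Flax parameters to float16, roughly halving the weight memory
pipeline.params = pipeline.model.to_fp16(pipeline.params)

# transcribe as usual, now with half-precision weights
outputs = pipeline("audio.mp3")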

@arnavmehta7

arnavmehta7 commented Apr 21, 2023

@sanchit-gandhi It is a bit concerning that it can take up to 30+ GB of GPU memory during batch inference. What batch size would be ideal to keep usage low, say under 12 GB of VRAM?

@seboslaw

I tried running the medium model on a T4 Colab instance. It took 14 minutes to transcribe a 10-minute audio file. Is this due to the memory constraints and the model paging out, or is it running on the CPU altogether?

@themanyone

themanyone commented Apr 27, 2023

I get this error after updating the video card drivers or kernel and forgetting to reboot afterwards. You can use GreenWithEnvy (gwe), available in most distro repos, to profile NVIDIA cards and see what, if anything, is going on there. Update: gwe seems like a bloated version of nvidia-smi, which already ships with the video drivers, so just use that.

@sanchit-gandhi
Owner

Note that the phenomenon of JAX using 90% of your GPU memory just to load the model is due to JAX's GPU memory allocation: https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html

JAX doesn't actually require all of this memory, but blocks it out to prevent fragmentation.

If you want to disable this, you can do so with the environment variable XLA_PYTHON_CLIENT_PREALLOCATE:

XLA_PYTHON_CLIENT_PREALLOCATE=false python run_benchmark.py
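
If you prefer to set the flag from Python rather than the shell, a minimal sketch (it has to be set before jax is first imported):

import os

# must be set before jax (or any module that imports jax) is loaded
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import jax  # the allocator now grows memory on demand instead of preallocating ~90%
print(jax.devices())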

A more reliable way of monitoring your JAX memory usage is jax-smi: https://github.com/ayaka14732/jax-smi
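
Rough usage sketch for jax-smi, going from that project's README (treat the exact function name as an assumption): call its tracking hook in your script, then run the jax-smi CLI from a second terminal.

from jax_smi import initialise_tracking

# start the background thread that records live device memory usage
initialise_tracking()

# ... run the Whisper JAX workload here, and watch it with `jax-smi` in another terminal ...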

Still working on figuring out how we can load the large-v2 checkpoint on a 16 GB T4 GPU!
