CUDA out of memory #7
See related: huggingface/transformers#22224
You can also convert the parameters to float16/bfloat16 as follows:

```python
# for fp16
pipeline.params = pipeline.model.to_fp16(pipeline.params)

# for bf16
pipeline.params = pipeline.model.to_bf16(pipeline.params)
```
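For context, the snippet above assumes a `pipeline` object such as whisper-jax's `FlaxWhisperPipline`; a minimal sketch of where it could come from (the checkpoint name is illustrative, and `dtype=jnp.float16` loads the weights in half precision from the start):

```python
import jax.numpy as jnp
from whisper_jax import FlaxWhisperPipline

# Load the pipeline directly in half precision (roughly halves parameter memory).
pipeline = FlaxWhisperPipline("openai/whisper-medium", dtype=jnp.float16)

# Equivalent post-hoc conversion of already-loaded float32 params:
pipeline.params = pipeline.model.to_fp16(pipeline.params)
```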
@sanchit-gandhi It is a bit concerning that it can take up to 30+ GB of GPU memory during batch inference. What batch size would be ideal to keep usage low, say under 12 GB of VRAM?
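As a rough lever (not a measured guarantee of staying under 12 GB), peak memory during batch inference scales with the pipeline's `batch_size`, i.e. the number of 30-second audio chunks processed per forward pass; a hedged sketch with an illustrative value:

```python
import jax.numpy as jnp
from whisper_jax import FlaxWhisperPipline

# Fewer chunks per forward pass -> lower peak activation memory; batch_size=4 is illustrative.
pipeline = FlaxWhisperPipline("openai/whisper-medium", dtype=jnp.float16, batch_size=4)

outputs = pipeline("audio.mp3")  # placeholder audio path
print(outputs["text"])
```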
I tried running the medium model on a T4 Colab instance. It took 14 minutes to transcribe a 10-minute audio file. Is this due to the memory constraints and the model paging out? Or is it running on the CPU altogether?
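One quick way to rule out silent CPU fallback (nothing project-specific assumed; this only checks which backend JAX picked up):

```python
import jax

# If this prints CPU devices only, jaxlib was likely installed without CUDA support
# and the model is indeed running on the CPU.
print(jax.devices())
print(jax.default_backend())  # expected: "gpu"
```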
I get this error after updating the video card drivers or the kernel and forgetting to reboot afterwards. You can use GreenWithEnvy (gwe), available in most distro repos, to profile Nvidia cards and see what, if anything, is going on there. Update: gwe seems like a bloated version of nvidia-smi, which already comes with the video drivers, so just use that.
Note that the phenomenon of JAX using 90% of your GPU memory just to load the model is due to JAX's GPU memory allocation strategy: https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html JAX doesn't actually require all of this memory, but blocks it out to prevent fragmentation. If you want to disable this behaviour, you can do so with the environment variable XLA_PYTHON_CLIENT_PREALLOCATE=false.
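For reference, a minimal sketch of the flags documented on that page (they must be set before JAX is imported; the 0.5 fraction is just an example value):

```python
import os

# Disable JAX's default preallocation of ~90% of GPU memory...
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
# ...or, alternatively, cap the preallocated fraction (example value).
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.5"

import jax  # import after setting the variables so they take effect

print(jax.devices())
```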
A more reliable way of monitoring your JAX memory is

Still working on figuring out how we can load the
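One possible approach (a hedged sketch only, not necessarily the method the comment above was referring to) is JAX's device memory profiler:

```python
import jax
import jax.profiler

# Writes a pprof-format snapshot of live device buffers; inspect it with the
# pprof tool to see which arrays are holding on to GPU memory.
jax.profiler.save_device_memory_profile("memory.prof")
```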
Trying to load the medium or large model, I get out-of-memory errors. Loading the small model with float16 precision works but takes all of my 24 GB of VRAM. Is there any way to limit JAX memory usage? The OpenAI model is far more modest in its requirements. Reducing the model weights to float16 seems like a good idea too.