Hi Ram,

The CPU model is quite different from the massively parallel GPU. On a GPU, the 60% that `nvtop` reports is the proportion of time during which the GPU is busy executing a kernel (here, a JAX kernel). It measures time, not how many of the CUDA cores are in use.

You say the calculations are taking a long time. Do you know whether your program is fundamentally compute-bound or memory-bound? Depending on that, you could improve on 60% with some optimisations. The easiest things to try could be to

For now, it'd be good to have a minimum working example of the program you are running, as well as the output of

Cheers,
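One way to answer the compute-bound vs memory-bound question is a back-of-the-envelope roofline estimate: divide the FLOPs an operation performs by the bytes it moves to and from GPU memory, and compare that arithmetic intensity with the hardware's ridge point (peak FLOP/s divided by peak memory bandwidth). The sketch below assumes approximate spec-sheet figures for the NVIDIA L4 in a `g2-standard-8` (about 30 TFLOP/s FP32 and 300 GB/s), which are worth checking against the datasheet. The byte counts also assume each array is read or written exactly once, which is a best case.

```python
# Assumed, approximate spec-sheet figures for one NVIDIA L4 GPU (g2-standard-8).
PEAK_FP32_FLOPS = 30e12  # FLOP/s
PEAK_BANDWIDTH = 300e9   # bytes/s


def ridge_point(peak_flops=PEAK_FP32_FLOPS, peak_bw=PEAK_BANDWIDTH):
    """Arithmetic intensity (FLOP/byte) above which a kernel is compute-bound."""
    return peak_flops / peak_bw


def matmul_intensity(n, bytes_per_elem=4):
    """(n x n) @ (n x n) matmul: 2*n^3 FLOPs; read A and B, write C once each."""
    flops = 2 * n**3
    bytes_moved = 3 * n * n * bytes_per_elem
    return flops / bytes_moved


def elementwise_add_intensity(bytes_per_elem=4):
    """x + y: 1 FLOP per element; read two inputs and write one output."""
    return 1 / (3 * bytes_per_elem)


def classify(intensity, ridge=None):
    """Compare an arithmetic intensity with the ridge point."""
    ridge = ridge_point() if ridge is None else ridge
    return "compute-bound" if intensity >= ridge else "memory-bound"


if __name__ == "__main__":
    print(f"ridge point: {ridge_point():.0f} FLOP/byte")
    for name, ai in [
        ("matmul n=64", matmul_intensity(64)),
        ("matmul n=4096", matmul_intensity(4096)),
        ("elementwise add", elementwise_add_intensity()),
    ]:
        print(f"{name:>16}: {ai:8.2f} FLOP/byte -> {classify(ai)}")
```

With these assumed numbers the ridge point is 100 FLOP/byte, so a large float32 matmul (intensity n/6) is compute-bound, while an elementwise add (1/12 FLOP/byte) is memory-bound no matter how large the arrays are.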
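As for a minimum working example, here is a minimal sketch of what one might look like, assuming a hypothetical workload `step` (a matmul followed by `tanh`) in place of the real program. It times a `jax.jit`-compiled function while excluding compilation time, and calls `block_until_ready()` because JAX dispatches work asynchronously; without that call the timer would stop before the GPU has finished.

```python
import time

import jax
import jax.numpy as jnp


def step(x):
    # Hypothetical workload standing in for the real program:
    # a matmul followed by an elementwise nonlinearity.
    return jnp.tanh(x @ x.T)


step_jit = jax.jit(step)


def time_fn(fn, x, n_iters=10):
    """Return mean wall-clock seconds per call, excluding compilation."""
    fn(x).block_until_ready()  # warm-up call triggers compilation
    start = time.perf_counter()
    for _ in range(n_iters):
        out = fn(x)
    out.block_until_ready()  # wait for the device to finish all queued calls
    return (time.perf_counter() - start) / n_iters


if __name__ == "__main__":
    n = 2048
    key = jax.random.PRNGKey(0)
    x = jax.random.normal(key, (n, n), dtype=jnp.float32)
    secs = time_fn(step_jit, x)
    flops = 2 * n**3  # matmul FLOPs only, ignoring tanh
    print(f"device: {jax.devices()[0]}, {secs * 1e3:.2f} ms/call, "
          f"{flops / secs / 1e12:.2f} TFLOP/s")
```

Watching `nvtop` while a loop like this runs can help separate the two cases. Because the utilization figure is time-based, anything that leaves the GPU idle between kernels, such as Python-side work or host-to-device copies in each iteration, pulls it below 100%.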
Hey,
This is a newbie question about using JAX on GPUs, because I'm new to using GPUs for anything. I noticed that when I run JAX experiments on GPUs (on a `g2-standard-8` on GCP), the GPU usage shown by `nvtop` is around 60%, even when the calculations are taking a long time. Why is that not 100%? I've only run computations on CPUs before, where usage usually goes to 100%. Is there some kind of optimization I would need to do to get JAX to use 100% of the GPU?

Thanks,
Ram.