Hi jax team,
I am trying to solve the time-dependent heat equation as fast as possible. Below is a simplified version of my program.
The following is a summary of the speed-up for the jitted function with increasing problem size (i.e. increasing number of nodes/Gauss points `numGPs` in the domain):

- `numGPs=25`, 110x
- `numGPs=250`, 11x
- `numGPs=2500`, 1.9x
- `numGPs=25000`, 1.1x

I would like to know if the diminishing speed-up is the expected behaviour or if I am doing something silly. If this is the expected behaviour, would anyone have any suggestions for improving the speed-up? I think it would be great to use vmap(); however, as the temperature at time step $t^n$ is a function of the temperature at the previous time step $t^{n-1}$, I'm not sure how I could implement it.
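For anyone reproducing speed-up figures like these, here is a minimal sketch of one way to time the same function with and without `jax.jit`. It is not the program from this post: the 1D explicit diffusion update, the coefficient `0.1`, and the names `step`, `solve`, and `time_call` are all assumptions made for illustration. The sketch warms up each function once so compilation is excluded, and calls `block_until_ready()` so JAX's asynchronous dispatch does not end the timer early.

```python
import time

import jax
import jax.numpy as jnp


def step(T, _):
    # Hypothetical explicit update: 1D diffusion, end values held fixed.
    lap = T[:-2] - 2.0 * T[1:-1] + T[2:]
    T_new = T.at[1:-1].add(0.1 * lap)
    return T_new, None


def solve(T0, num_steps=100):
    # Sequential time stepping: each step depends on the previous one.
    T_final, _ = jax.lax.scan(step, T0, None, length=num_steps)
    return T_final


solve_jit = jax.jit(solve, static_argnames="num_steps")


def time_call(fn, *args, repeats=5):
    # Warm-up call so tracing/compilation is not included in the timing.
    fn(*args).block_until_ready()
    start = time.perf_counter()
    for _ in range(repeats):
        # Wait for the result; dispatch is asynchronous.
        fn(*args).block_until_ready()
    return (time.perf_counter() - start) / repeats


for num_gps in (25, 250, 2500):
    T0 = jnp.linspace(0.0, 1.0, num_gps)
    t_plain = time_call(solve, T0)
    t_jit = time_call(solve_jit, T0)
    print(f"numGPs={num_gps}: speed-up {t_plain / t_jit:.1f}x")
```

The printed ratios depend on the machine and the JAX version, so they are not expected to match the figures above.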
I have opted for scan() over, for example, fori_loop() because I ultimately need the gradients. Thus far, I have only been running on CPU. I'm running Python 3.11, jax 0.4.1, jaxlib 0.4.1.
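To make the scan()/vmap() question concrete, here is a minimal sketch, assuming a periodic 1D explicit update rather than the actual solver from this post. The names `make_step`, `solve`, `loss`, and the diffusion number `alpha` are hypothetical. It shows `jax.grad` differentiating through `jax.lax.scan`, and `jax.vmap` applied across a batch of independent initial conditions. The time loop itself stays sequential, since $t^n$ depends on $t^{n-1}$.

```python
import jax
import jax.numpy as jnp


def make_step(alpha):
    def step(T, _):
        # Hypothetical periodic 1D Laplacian stencil.
        lap = jnp.roll(T, 1) - 2.0 * T + jnp.roll(T, -1)
        return T + alpha * lap, None
    return step


def solve(alpha, T0, num_steps):
    # scan carries the temperature field from one time step to the next.
    T_final, _ = jax.lax.scan(make_step(alpha), T0, None, length=num_steps)
    return T_final


def loss(alpha, T0, num_steps):
    # Illustrative scalar objective on the final temperature field.
    return jnp.sum(solve(alpha, T0, num_steps) ** 2)


# Gradient with respect to alpha, taken through the whole scan.
grad_loss = jax.jit(jax.grad(loss), static_argnums=2)

# vmap over a batch of initial conditions; time stepping is still sequential.
solve_batch = jax.jit(jax.vmap(solve, in_axes=(None, 0, None)), static_argnums=2)

T0 = jnp.sin(jnp.linspace(0.0, 2.0 * jnp.pi, 64, endpoint=False))
print(grad_loss(0.1, T0, 50))
print(solve_batch(0.1, jnp.stack([T0, 2.0 * T0]), 50).shape)
```

Whether this kind of batching helps depends on whether the real problem has independent cases to batch over; it does not parallelise the time loop.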
Thanks!