In my case, the JAX version of the program below is much slower than the autograd version.

autograd (ex1.py):

```python
import autograd.numpy as np
from autograd import elementwise_grad as egrad

dx, dy = 0, 1

def nabla4(w):
    def fn(x, y):
        return (
            egrad(egrad(egrad(egrad(w, dx), dx), dx), dx)(x, y)
            + 2 * egrad(egrad(egrad(egrad(w, dx), dx), dy), dy)(x, y)
            + egrad(egrad(egrad(egrad(w, dy), dy), dy), dy)(x, y)
        )
    return fn

def f(x, y):
    return x**4 + 2 * x**2 * y**2 + y**4

x = np.arange(10_000, dtype=np.float64)
y = np.arange(10_000, dtype=np.float64)

w = [f] * 100  # In a real program, the elements of the list are various functions.
r = [nabla4(f)(x, y) for f in w]
```
jax (ex2.py):

```python
import jax
from jax import grad, vmap
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

dx, dy = 0, 1

def nabla4(w):
    def fn(x, y):
        return (
            vmap(grad(grad(grad(grad(w, dx), dx), dx), dx))(x, y)
            + 2 * vmap(grad(grad(grad(grad(w, dx), dx), dy), dy))(x, y)
            + vmap(grad(grad(grad(grad(w, dy), dy), dy), dy))(x, y)
        )
    return fn

def f(x, y):
    return x**4 + 2 * x**2 * y**2 + y**4

x = jnp.arange(10_000, dtype=jnp.float64)
y = jnp.arange(10_000, dtype=jnp.float64)

w = [f] * 100  # In a real program, the elements of the list are various functions.
r = [nabla4(f)(x, y) for f in w]
```
The program using JAX is almost 9x slower than the version using autograd. In more complicated programs the difference is much greater. I would be grateful for suggestions on how to achieve similar (or better) performance using JAX.
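A minimal timing sketch for such a comparison (assuming the JAX script above; `block_until_ready` is needed because JAX dispatches work asynchronously, so the loop would otherwise finish before the computation does):

```python
import time

start = time.perf_counter()
r = [nabla4(f)(x, y) for f in w]
for a in r:
    a.block_until_ready()  # wait for JAX's asynchronous work to finish
print(f"elapsed: {time.perf_counter() - start:.3f} s")
```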
---
This is expected, I think. If you replace "NumPy" with "autograd" in the answer here, the discussion holds true and is relevant to this question: https://jax.readthedocs.io/en/latest/faq.html#is-jax-faster-than-numpy

The problem is that you're performing microbenchmarks of repeated short function calls, which is not a domain that JAX was designed for, and is a domain that NumPy & autograd are designed for: they don't have any ability to compile and fuse sequences of operations, so they are instead optimized for low per-call dispatch overhead.

In your benchmark, I think you'll find that if you jit-compile your repeated function call, JAX will be far faster:

```python
jit_nabla4_f = jax.jit(nabla4(f))
r = [jit_nabla4_f(x, y) for f in w]
```
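A note on measuring this: the first call to a jit-compiled function includes tracing and XLA compilation, while subsequent calls with the same argument shapes and dtypes reuse the compiled code. A sketch (assuming `nabla4`, `f`, `x`, `y` from the scripts above):

```python
import time

jit_nabla4_f = jax.jit(nabla4(f))

t0 = time.perf_counter()
jit_nabla4_f(x, y).block_until_ready()  # first call: traces and compiles
t1 = time.perf_counter()
jit_nabla4_f(x, y).block_until_ready()  # later calls hit the compilation cache
t2 = time.perf_counter()
print(f"first call: {t1 - t0:.3f} s, second call: {t2 - t1:.3f} s")
```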
---
Thank you @jakevdp. In your snippet the same compiled function is reused for every element of `w`; in my program each element is a different function, so I built a list of jit-compiled functions instead:

```python
# create list of jit-compiled functions
jit_nabla4_w = [jax.jit(nabla4(f)) for f in w]
# calculate values using jit-compiled functions in (x, y) coordinates
r = [f(x, y) for f in jit_nabla4_w]
```

But this is even slower than the previous version. This is probably also expected, since each of these jit-compiled functions is called only once to get the results.
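If the usage pattern allows it, compilation can still pay off when each compiled function is reused, e.g. across several coordinate batches. A hypothetical sketch (the extra batches here are made up for illustration):

```python
# Reuse pattern: each jit-compiled function is called many times with the
# same shapes, so its one-off compilation cost is amortized.
batches = [(x + i, y + i) for i in range(10)]  # hypothetical extra (x, y) batches
r = [g(xb, yb) for g in jit_nabla4_w for xb, yb in batches]
```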
---
Assuming the elements of `w` are different functions, e.g.:

```python
def f_1(x, y):
    return x**4 + 2 * x**2 * y**2 + y**4

def f_2(x, y):
    return jnp.sin(x)**2 + jnp.cos(y)**2

def f_3(x, y):
    return -jnp.sin(x) * jnp.sin(y)

def f_4(x, y):
    return jnp.sin(y) - 1/9 * x**3 + 1/2

w = [f_1, f_2, f_3, f_4]
r = [nabla4(f)(x, y) for f in w]
```

the script using JAX executes in 2851 milliseconds without jit. Using jit as below:

```python
w = [f_1, f_2, f_3, f_4]
jit_nabla4_w = [jax.jit(nabla4(f)) for f in w]
r = [f(x, y) for f in jit_nabla4_w]
```

the script executes in 1376 milliseconds. The equivalent autograd version is still faster than either variant. In my real program the functions in `w` are more numerous and more complicated. These microbenchmarks appear to confirm the information contained in the JAX documentation (FAQ). It looks like it will be difficult to take advantage of JAX in my case. Given the execution time and CPU load, the current implementation is essentially impossible to use in practice as a replacement for autograd.
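One more variant that might be worth trying (a sketch, not benchmarked here): since the list of functions is fixed, the whole loop can be traced into a single jit-compiled program, so XLA compiles and dispatches once rather than once per function:

```python
@jax.jit
def nabla4_all(x, y):
    # w is closed over, so all of its functions are traced into one XLA program
    return [nabla4(f)(x, y) for f in w]

r = nabla4_all(x, y)
```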
---

Yes, this seems consistent with the discussion at the link in my first response above. These microbenchmarks are in a regime where it's not surprising that autograd is faster than JAX: i.e. individually dispatched small array operations on CPU.
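To illustrate the per-dispatch overhead in question, a rough self-contained sketch (numbers will vary by machine):

```python
import time
import numpy as onp
import jax.numpy as jnp

a = onp.ones(100)
b = jnp.ones(100)

def per_call_seconds(fn, n=10_000):
    start = time.perf_counter()
    for _ in range(n):
        fn()
    return (time.perf_counter() - start) / n

# Each iteration dispatches one tiny add; NumPy's per-call overhead is much
# smaller than un-jitted JAX's, which is the regime these benchmarks hit.
t_numpy = per_call_seconds(lambda: a + a)
t_jax = per_call_seconds(lambda: (b + b).block_until_ready())
print(f"NumPy: {t_numpy * 1e6:.1f} us/op, JAX (no jit): {t_jax * 1e6:.1f} us/op")
```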