Jaxify numpy function #2774
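That's a really interesting idea, and it could be quite useful, although it would be a challenge to implement it well. I was thinking about whether this could be done with no change to the JAX source. This is a terrible hack that is quite brittle, but something like this will work in simple cases:

```python
import inspect
from functools import wraps

import numpy as np


def jaxify(func):
    import jax.numpy

    # Re-execute the function's source in a copy of its globals in which
    # the names `np` and `numpy` refer to jax.numpy instead of NumPy.
    namespace = func.__globals__.copy()
    namespace['np'] = namespace['numpy'] = jax.numpy
    # The source still carries the @jaxify decorator; make it a no-op so
    # the re-executed definition does not recurse back into jaxify.
    namespace['jaxify'] = lambda func: func
    # getsource needs the function's source file, so this can fail for
    # functions defined interactively.
    source = inspect.getsource(func)
    exec(source, namespace)
    return wraps(func)(namespace[func.__name__])


@jaxify
def my_func(N):
    return np.arange(N).sum()


my_func(10)
# DeviceArray(45, dtype=int32)  (newer JAX versions print Array(45, dtype=int32))
```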
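A variant of the same global-rebinding trick can skip `inspect.getsource` and `exec` by building a new function object from the original's bytecode with `types.FunctionType`. This is a minimal sketch, assuming only module-level names such as `np` need swapping; `rebind_globals` is a hypothetical helper, not part of JAX, and it shares the hack's limits: only names the function looks up in its own module globals change, and the copied globals are a snapshot.

```python
import functools
import types

import numpy as np


def rebind_globals(func, **replacements):
    """Return a copy of `func` that resolves some global names differently.

    Hypothetical helper (not a JAX API). It reuses func.__code__ instead of
    re-executing source text, so no source file is needed. Only names looked
    up in func's module globals are affected; closures, default arguments and
    functions defined in other modules still see the originals. The globals
    are copied once, so names added to the module later are not visible.
    """
    new_globals = dict(func.__globals__)  # shallow copy; the module is untouched
    new_globals.update(replacements)
    new_func = types.FunctionType(
        func.__code__, new_globals, func.__name__, func.__defaults__, func.__closure__
    )
    new_func.__kwdefaults__ = func.__kwdefaults__
    # Copy __name__, __doc__, __qualname__, etc. and set __wrapped__ = func.
    return functools.update_wrapper(new_func, func)


def my_func(N):
    return np.arange(N).sum()
```

With JAX installed, `rebind_globals(my_func, np=jax.numpy, numpy=jax.numpy)` would play the role of `jaxify` above; the concern about functions that mix original NumPy with JAX still applies.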
Numba does tricks like this inside […]. That said, I think this would be very hard to do in JAX, because we occasionally see people using the original NumPy inside JAX functions, so a blanket substitution would also rewrite calls that are meant to stay NumPy. It's also not very explicit or composable. I do like the idea of trying to support overrides of NumPy's API via NumPy's own protocols (#1565), which would at least solve most of the composability issues.
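One of NumPy's override protocols is `__array_function__` (NEP 18); assuming that is the kind of protocol meant here, the sketch below shows how unmodified code calling `np.sum(x)` can be routed to another implementation based on the argument's type, with no source rewriting. `WrappedArray`, `HANDLED_FUNCTIONS` and `implements` are hypothetical names modeled on the pattern in NumPy's documentation; plain NumPy stands in for a JAX backend so the example runs without JAX.

```python
import numpy as np

# Hypothetical registry mapping NumPy API functions to replacement implementations.
HANDLED_FUNCTIONS = {}


def implements(numpy_function):
    """Register a replacement for `numpy_function` on WrappedArray."""
    def decorator(func):
        HANDLED_FUNCTIONS[numpy_function] = func
        return func
    return decorator


class WrappedArray:
    """Toy duck array that intercepts NumPy calls through __array_function__."""

    def __init__(self, data):
        self.data = np.asarray(data)

    def __array_function__(self, func, types, args, kwargs):
        if func not in HANDLED_FUNCTIONS:
            return NotImplemented  # NumPy then raises TypeError
        if not all(issubclass(t, WrappedArray) for t in types):
            return NotImplemented
        return HANDLED_FUNCTIONS[func](*args, **kwargs)


@implements(np.sum)
def _wrapped_sum(a, axis=None):
    # A JAX-backed container would call jax.numpy.sum here instead.
    return WrappedArray(np.sum(a.data, axis=axis))
```

Dispatch only happens when an argument carries the protocol, so a call such as `np.arange(N)` in the earlier example, which takes no array argument, has nothing to dispatch on.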
I'm going to close this as out of scope for the project.
I have not used JAX enough.
In my understanding, to use JAX (for instance, to get the gradient of a function), one needs to define the function using `jax.numpy` functions instead of `numpy` functions. How do I avoid double declarations?

Some would suggest using only JAX, but, aside from performance and correctness concerns, I have concerns about behavior: do `jax.numpy` functions behave identically to `numpy` functions? What happens when `numpy` functions are updated to change some behaviors?

Wouldn't it be possible to create a JAX version of the function with a helper that inspects the code and replaces the `numpy` or `scipy` methods with their JAX analogs?
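A common way to avoid double declarations without any source rewriting is to write the function once against an array-namespace argument and pass either `numpy` or `jax.numpy`. This is a sketch, not a JAX API: `softplus_sum` and its `xp` parameter are hypothetical names following the array-namespace convention.

```python
import numpy as np


def softplus_sum(x, xp=np):
    """Sum of log(1 + exp(x)), written once against an array namespace `xp`.

    `xp` is a hypothetical parameter: pass numpy (the default) or jax.numpy.
    Every array operation goes through `xp`, so the same body serves both.
    """
    return xp.sum(xp.log1p(xp.exp(x)))
```

With JAX installed, `jax.grad(lambda x: softplus_sum(x, xp=jax.numpy))(jax.numpy.array([0.0, 0.0]))` should give 0.5 in each position, since the derivative of log(1 + exp(x)) is the logistic sigmoid. This removes the duplicate definition but not the behavioral differences between the libraries; for example, JAX defaults to 32-bit floats unless 64-bit mode is enabled.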