
Jaxify numpy function #2774

Closed
sursu opened this issue Apr 20, 2020 · 4 comments
Labels
question Questions for the JAX team


@sursu
Contributor

sursu commented Apr 20, 2020

I have not used JAX much yet.

As I understand it, to use JAX (for instance, to get the gradient of a function), one needs to define the function using jax.numpy functions instead of numpy ones.
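For example, in this minimal sketch (assuming a recent JAX release), jax.grad can trace through a function written with jax.numpy, while the same function written with plain numpy fails, because NumPy tries to convert JAX's tracer objects into concrete arrays:

```python
import numpy as np
import jax
import jax.numpy as jnp

def f_np(x):
    # Plain NumPy version: np.sin needs a concrete array.
    return np.sum(np.sin(x) ** 2)

def f_jnp(x):
    # Same computation with jax.numpy, which JAX can trace and differentiate.
    return jnp.sum(jnp.sin(x) ** 2)

x = jnp.arange(3.0)
grad_jnp = jax.grad(f_jnp)(x)  # analytically 2*sin(x)*cos(x) = sin(2x)

try:
    jax.grad(f_np)(x)
    np_version_failed = False
except Exception:  # JAX raises a tracer-conversion error here
    np_version_failed = True
```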

How do I avoid declaring each function twice, once with numpy and once with jax.numpy?
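One common way to avoid the duplication, sketched here as an alternative to rewriting source code (the parameter name `xp` is just a convention chosen for illustration), is to pass the array module in as an argument:

```python
import functools
import numpy as np
import jax
import jax.numpy as jnp

def sum_of_squares(x, xp=np):
    # `xp` is the array module to use: NumPy by default, jax.numpy when a
    # JAX transformation such as jax.grad is needed.
    return xp.sum(xp.square(x))

x = np.arange(4.0)
value_np = sum_of_squares(x)  # runs on NumPy

# Bind xp=jnp so jax.grad only sees the array argument.
grad_fn = jax.grad(functools.partial(sum_of_squares, xp=jnp))
grad_jax = grad_fn(jnp.asarray(x))  # d/dx sum(x**2) = 2*x
```

This only helps if every NumPy call in the function goes through `xp`, and it does nothing about behavioral differences between the two libraries.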

Some would suggest using only JAX, but, aside from performance and correctness concerns, I have concerns about behavior: do jax.numpy functions behave identically to numpy functions? What happens when NumPy updates a function to change its behavior?
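The two libraries do not behave identically in every case. One documented difference, shown in this minimal sketch: JAX arrays are immutable, so NumPy-style in-place assignment raises an error, and JAX offers the functional `.at[...]` update syntax instead.

```python
import numpy as np
import jax.numpy as jnp

a = np.zeros(3)
a[0] = 1.0  # NumPy: mutates `a` in place

b = jnp.zeros(3)  # float32 by default unless 64-bit mode is enabled
try:
    b[0] = 1.0  # JAX arrays do not support item assignment
    assignment_failed = False
except TypeError:
    assignment_failed = True

c = b.at[0].set(1.0)  # returns a new array; `b` is left unchanged
```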

Wouldn't it be possible to create a JAX version of a function with a helper that inspects its code and replaces the numpy or scipy calls with their JAX analogs?

@jakevdp
Collaborator

jakevdp commented Apr 20, 2020

That's a really interesting idea, and it could be quite useful, although it would be a challenge to implement it well.

I was thinking about whether this could be done without changing the JAX source. This is a terrible and quite brittle hack, but something like this will work in simple cases:

import inspect
import numpy as np
from functools import wraps

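def jaxify(func):
  import jax.numpy
  # Start from a copy of the function's module globals...
  namespace = func.__globals__.copy()
  # ...but point both `np` and `numpy` at jax.numpy.
  namespace['np'] = namespace['numpy'] = jax.numpy
  # The retrieved source still carries the @jaxify decorator; make it a
  # no-op so re-executing the source doesn't apply jaxify a second time.
  namespace['jaxify'] = lambda func: func
  # Re-run the function's source in the patched namespace (this fails for
  # lambdas, nested functions, and code typed at the interactive prompt).
  source = inspect.getsource(func)
  exec(source, namespace)
  return wraps(func)(namespace[func.__name__])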

@jaxify
def my_func(N):
  return np.arange(N).sum()

my_func(10)
# DeviceArray(45, dtype=int32)
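A variant of the same hack that avoids `inspect.getsource` and `exec` (a sketch, not part of JAX; the name `jaxify_globals` is hypothetical) rebuilds the function object around its existing code with a patched globals dict. Like the version above, it only swaps the names `np` and `numpy` in the function's own module namespace, so helper functions it calls, or names imported with `from numpy import ...`, still use NumPy; unlike it, it also works for functions defined at the interactive prompt.

```python
import functools
import types
import numpy as np
import jax
import jax.numpy as jnp

def jaxify_globals(func):
    # Copy the module globals and rebind `np`/`numpy` to jax.numpy.
    namespace = dict(func.__globals__)
    namespace['np'] = namespace['numpy'] = jnp
    # Reuse the original bytecode, defaults, and closure with the new globals.
    new_func = types.FunctionType(
        func.__code__, namespace, func.__name__,
        func.__defaults__, func.__closure__)
    new_func.__kwdefaults__ = func.__kwdefaults__
    return functools.wraps(func)(new_func)

def my_func(N):
    return np.arange(N).sum()

def loss(x):
    return np.sum(np.sin(x) ** 2)

jax_my_func = jaxify_globals(my_func)  # my_func itself still uses NumPy
jax_loss = jaxify_globals(loss)
grad_loss = jax.grad(jax_loss)(jnp.arange(3.0))  # equals sin(2x)
```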

@jekbradbury
Contributor

An alternative approach to overloading NumPy code in-place was explored in #1565 and prototyped in #611. If you're interested in that functionality, please chime in on #1565, since a major reason it wasn't merged was a relative lack of interested users (compared to the added complexity).

@jekbradbury jekbradbury added the question Questions for the JAX team label Apr 21, 2020
@shoyer
Collaborator

shoyer commented Apr 21, 2020

Numba does tricks like this inside numba.jit and it seems to work pretty well for their users.

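That said, I think this would be very hard to do in JAX because people sometimes intentionally use the original NumPy inside JAX functions (for example, to compute static values such as shapes), and a blanket substitution would change what that code does. It's also not very explicit or composable.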
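A minimal sketch of why a blanket rename would break such code, adapted from the static-versus-traced discussion in the JAX documentation: inside `jax.jit`, NumPy is useful for values that must stay static, such as shapes, while the equivalent jax.numpy computation is staged out as a traced value.

```python
import numpy as np
import jax
import jax.numpy as jnp

@jax.jit
def flatten_static(x):
    # np.prod runs on the static shape tuple while tracing, so the new
    # shape is a concrete integer that reshape accepts.
    return x.reshape((np.prod(x.shape),))

@jax.jit
def flatten_traced(x):
    # jnp operations are staged out by jit, so the new shape is a tracer
    # and reshape rejects it.
    return x.reshape(jnp.array(x.shape).prod())

x = jnp.ones((2, 3))
flat = flatten_static(x)
try:
    flatten_traced(x)
    traced_version_failed = False
except Exception:  # JAX reports that the shape is not concrete
    traced_version_failed = True
```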

I do like the idea of trying to support overrides of NumPy's API via NumPy's own protocols (#1565), which would at least solve most of the composability issues.
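To make the protocol idea concrete, here is a toy sketch (not JAX's API; the class name `JaxBacked` is hypothetical) of a wrapper type that uses NumPy's `__array_function__` protocol from NEP 18 to redirect calls like `np.sum` to the jax.numpy function of the same name. It only unwraps top-level arguments and only handles functions that exist under the same name in `jax.numpy`.

```python
import numpy as np
import jax.numpy as jnp

class JaxBacked:
    """Toy wrapper that forwards NumPy API calls to jax.numpy."""

    def __init__(self, value):
        self.value = jnp.asarray(value)

    def __array_function__(self, func, types, args, kwargs):
        # Look up the jax.numpy function with the same name as the NumPy one.
        jax_func = getattr(jnp, func.__name__, None)
        if jax_func is None:
            return NotImplemented  # NumPy then raises TypeError

        def unwrap(obj):
            return obj.value if isinstance(obj, JaxBacked) else obj

        args = [unwrap(a) for a in args]
        kwargs = {k: unwrap(v) for k, v in kwargs.items()}
        return JaxBacked(jax_func(*args, **kwargs))

x = JaxBacked([1.0, 2.0, 3.0])
total = np.sum(x)  # dispatched to jnp.sum via __array_function__
mean = np.mean(x)  # dispatched to jnp.mean
```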

@jakevdp
Collaborator

jakevdp commented Jun 24, 2022

I'm going to close this as out of scope for the project.

@jakevdp jakevdp closed this as completed Jun 24, 2022
4 participants