Hi all, I'm working on an implementation of a simple `LinearOp` class:

```python
from jax import Array
from jax.typing import ArrayLike


class LinearOp:
    def __matmul__(self, x: ArrayLike) -> Array:
        ...

    def __rmatmul__(self, x: ArrayLike) -> Array:
        ...

    def transpose(self) -> "LinearOpT":
        return LinearOpT(...)

    T = property(transpose)

    @property
    def shape(self) -> tuple[int, ...]:
        ...

    # other properties like ndim, size, etc...


class LinearOpT(LinearOp):
    ...
```

It seems to be working out well for my needs so far using basic scaling and matrix multiplication, but I'd like to extend it to be compatible with calls to `jax.numpy` functions such as `jnp.einsum`:

```python
import jax.numpy as jnp

my_data = LinearOp(...)
X, Y = ...  # some jax arrays
X @ my_data  # works
my_data @ Y  # works
A, B = ...  # more jax arrays
A @ my_data.T  # works
my_data.T @ B  # works
jnp.einsum("np,nk,kp->p", X, my_data, Y)  # doesn't work
```

What is not clear to me is what additional special (dunder) methods need to be defined in order to work with `jnp.einsum`.

Cheers!
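A self-contained reproduction of the behavior described in the question might look like the sketch below. `DenseOp` is a hypothetical stand-in for `LinearOp` that simply wraps a dense matrix (the real operator's internals are elided above), and its `dense()` helper is likewise made up for illustration. It shows that `@` works through Python's `__matmul__`/`__rmatmul__` protocol (a `jax.Array` returns `NotImplemented` for operand types it doesn't recognize, so Python falls back to the operator's `__rmatmul__`), while `jnp.einsum` never consults those methods and rejects the object.

```python
import jax.numpy as jnp
from jax import Array
from jax.typing import ArrayLike


class DenseOp:
    """Hypothetical stand-in for LinearOp that just wraps a dense matrix."""

    def __init__(self, mat: ArrayLike):
        self._mat = jnp.asarray(mat)

    def dense(self) -> Array:
        # Hypothetical helper: materialize the operator as a plain array.
        return self._mat

    def __matmul__(self, x: ArrayLike) -> Array:
        # Handles `op @ x`.
        return self._mat @ jnp.asarray(x)

    def __rmatmul__(self, x: ArrayLike) -> Array:
        # Handles `x @ op`: jax.Array.__matmul__ returns NotImplemented for
        # unrecognized operand types, so Python falls back to this method.
        return jnp.asarray(x) @ self._mat

    @property
    def T(self) -> "DenseOp":
        return DenseOp(self._mat.T)

    @property
    def shape(self) -> tuple[int, ...]:
        return self._mat.shape


op = DenseOp(jnp.arange(12.0).reshape(4, 3))  # shape (n, k) = (4, 3)
X = jnp.ones((4, 2))                          # shape (n, p) = (4, 2)
Y = jnp.ones((3, 2))                          # shape (k, p) = (3, 2)

print((X.T @ op).shape)  # (2, 3), dispatched to DenseOp.__rmatmul__
print((op @ Y).shape)    # (4, 2), dispatched to DenseOp.__matmul__

# The same contraction works on the materialized array...
print(jnp.einsum("np,nk,kp->p", X, op.dense(), Y))  # [66. 66.]

# ...but jnp.einsum never looks at the operator-overloading dunders.
try:
    jnp.einsum("np,nk,kp->p", X, op, Y)
except TypeError as err:
    print("einsum rejected DenseOp:", type(err).__name__)
```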
-
Thanks for the question! That seems like a cool thing to build. Unfortunately JAX doesn't currently offer a good way to overload `jax.numpy` function application (and NumPy's own overriding mechanisms, such as `__array_function__`, probably won't work here).

The best current solution I can think of is just to define your own NumPy wrapper API in a package, like `import quattro.numpy as qnp`, and then write functions calling `qnp` (or call it `jnp` or `np` if you prefer).

We've talked about adding a NumPy-overriding API to JAX, but it's never risen to the top of the priority list.
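A minimal sketch of the wrapper-package suggestion above, under a few assumptions: the operator only exposes `@` and `shape`, and everything lives in one file for brevity. `LinearOp` here is a hypothetical stand-in (it stores a dense matrix only so the example runs), and `_WrappedNumPy`/`qnp` are hypothetical names standing in for a real `quattro/numpy.py`-style module. In an actual package, a module-level `__getattr__` (PEP 562) could do the forwarding instead of a class. The `einsum` override handles the question's `"np,nk,kp->p"` pattern matrix-free and otherwise falls back to materializing each operator by applying it to the identity.

```python
import jax.numpy as jnp
from jax import Array


class LinearOp:
    """Hypothetical operator that only knows `op @ x` and `x @ op`."""

    def __init__(self, mat):
        self._mat = jnp.asarray(mat)  # stand-in for a matrix-free implementation

    def __matmul__(self, x):
        return self._mat @ jnp.asarray(x)

    def __rmatmul__(self, x):
        return jnp.asarray(x) @ self._mat

    @property
    def shape(self):
        return self._mat.shape


class _WrappedNumPy:
    """Hypothetical `qnp` namespace: overrides a few functions, forwards the rest."""

    def __getattr__(self, name):
        # Only called for names not defined on this class (qnp.sum, qnp.ones, ...).
        return getattr(jnp, name)

    def einsum(self, subscripts: str, *operands) -> Array:
        # Matrix-free rule for the pattern from the question:
        # sum_{n,k} X[n,p] op[n,k] Y[k,p] == sum_n X[n,p] * (op @ Y)[n,p]
        if subscripts == "np,nk,kp->p" and isinstance(operands[1], LinearOp):
            X, op, Y = operands
            return jnp.sum(X * (op @ Y), axis=0)
        # Generic fallback: materialize each LinearOp via op @ I.
        # Correct, but it gives up the matrix-free advantage for large operators.
        dense = [
            o @ jnp.eye(o.shape[1]) if isinstance(o, LinearOp) else o
            for o in operands
        ]
        return jnp.einsum(subscripts, *dense)


qnp = _WrappedNumPy()

op = LinearOp(jnp.arange(12.0).reshape(4, 3))  # (n, k) = (4, 3)
X = qnp.ones((4, 2))                           # forwarded to jnp.ones
Y = qnp.ones((3, 2))
print(qnp.einsum("np,nk,kp->p", X, op, Y))     # matrix-free rule: [66. 66.]
print(qnp.einsum("nk,k->n", op, qnp.ones(3)))  # dense fallback
```

The design choice is that dispatch lives in your own namespace rather than inside JAX, so you decide per function whether to stay matrix-free or materialize. Passing `LinearOp` instances through `jax.jit` or `jax.grad` would additionally require registering the class as a pytree, e.g. with `jax.tree_util.register_pytree_node_class`.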
What do you think?