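You could use `jax.pure_callback` to call SciPy's CPU implementation from inside a jitted function:

```python
import jax
import jax.numpy as jnp
import numpy as np
import scipy.linalg

def schur(x):
    # scipy.linalg.schur returns (T, Z), each with the same shape/dtype as x.
    out = (jax.ShapeDtypeStruct(x.shape, x.dtype),) * 2
    # The callback runs on the host CPU with NumPy arrays; cast the results so their dtypes match `out`.
    host = lambda a: tuple(np.asarray(r, a.dtype) for r in scipy.linalg.schur(a))
    return jax.pure_callback(host, out, x)

@jax.jit
def f(x):
    return schur(x)

print(f(jnp.array([[0, 2, 2], [0, 1, 2], [1, 0, 1]], jnp.float32)))
```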
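As a rough sketch of the full workflow described in the question, the callback can sit in the middle of a larger jitted function. The operations before and after it run on the default device (the GPU/TPU when one is available), while JAX copies the callback's operand to the host and copies its results back. The names `host_schur` and `pipeline`, the matrix built from `x`, and the reconstruction step are hypothetical placeholders for real work.

```python
import jax
import jax.numpy as jnp
import numpy as np
import scipy.linalg


def host_schur(a):
    # Hypothetical helper: executed on the host CPU with NumPy inputs.
    t, z = scipy.linalg.schur(a)
    # Cast so the dtypes match the ShapeDtypeStructs declared below.
    return np.asarray(t, a.dtype), np.asarray(z, a.dtype)


@jax.jit
def pipeline(x):
    # Placeholder "accelerator" work before the decomposition.
    a = x @ x.T + jnp.eye(x.shape[0], dtype=x.dtype)
    out = (jax.ShapeDtypeStruct(a.shape, a.dtype),) * 2
    # Host round trip: copy `a` to the CPU, run SciPy, copy (T, Z) back.
    t, z = jax.pure_callback(host_schur, out, a)
    # Continue on the default device: rebuild `a` from its Schur form.
    return t, z, z @ t @ z.T


x = jnp.array([[0, 2, 2], [0, 1, 2], [1, 0, 1]], jnp.float32)
t, z, a_rec = pipeline(x)
# Reconstruction error; should be close to 0.
print(jnp.max(jnp.abs(a_rec - (x @ x.T + jnp.eye(3)))))
```

Each call pays for a device-to-host and host-to-device copy, so this fits best when the decomposition is a small part of a larger computation.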
It is known that when using a GPU or TPU backend, calling `jax.scipy.linalg.schur` raises `NotImplementedError: Schur decomposition is only implemented on the CPU backend`. I wonder whether it is possible to work around this by falling back to the CPU implementation in such cases.

I tried transferring the array back to the CPU with `jax.device_get(X)`, hoping that would trigger the LAPACK backend, but it still raises the same `NotImplementedError`.

To clarify, I only want this particular function to run on the CPU, since there is no GPU/TPU implementation. The idea is to do most of the computation on a GPU/TPU, transfer the array to the CPU, perform the Schur decomposition, transfer the result back to the GPU/TPU, and continue execution there.
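Outside of `jit`, the transfer-and-back plan above can be sketched with explicit device placement. `jax.device_get` returns a NumPy array, and passing a NumPy array to `jax.scipy.linalg.schur` places it on the default device again (the GPU/TPU), which is likely why the error persisted. Committing the array to a CPU device with `jax.device_put` should instead make the decomposition run there, since JAX runs computations on the device their committed inputs live on. This is a minimal sketch assuming a CPU device is available through `jax.devices("cpu")`; `schur_on_cpu` is a hypothetical helper name.

```python
import jax
import jax.numpy as jnp
import jax.scipy.linalg


def schur_on_cpu(x):
    # Hypothetical helper: run jax.scipy.linalg.schur on the CPU and
    # return the results on the default (accelerator) device.
    cpu = jax.devices("cpu")[0]
    # Commit the input to the CPU; computations on committed arrays run there.
    x_cpu = jax.device_put(x, cpu)
    t, z = jax.scipy.linalg.schur(x_cpu)
    # Move the results back to the default device to continue the computation.
    target = jax.devices()[0]
    return jax.device_put(t, target), jax.device_put(z, target)


x = jnp.array([[0, 2, 2], [0, 1, 2], [1, 0, 1]], jnp.float32)  # default device
t, z = schur_on_cpu(x)
print(t)
```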