[jax 0.4.33] XlaRuntimeError: FAILED_PRECONDITION: Buffer Definition Event: Unsupported number of sorted inputs: 17 #23727
Comments
Try …
I think I forgot to add it to the release notes for 0.4.32, but there was a major upgrade to the CPU backend. The flag above temporarily switches back to the old version.
I have a more reduced MWE:

    import jax.numpy as jnp
    jnp.lexsort(jnp.zeros((16, 2)))
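Why this MWE hits exactly 17 inputs: `jnp.lexsort` treats the leading axis of its argument as the list of keys, so `jnp.zeros((16, 2))` is 16 keys of length 2. As far as we can tell from its source, `jnp.lexsort` also appends one index (iota) operand to carry the resulting permutation, giving 16 + 1 = 17 operands to a single sort. The pure-Python model below is a sketch of those semantics only; `lexsort_model` is a hypothetical name, not a JAX function.

```python
from typing import List, Sequence, Tuple


def lexsort_model(keys: Sequence[Sequence[float]]) -> Tuple[List[int], int]:
    """Hypothetical pure-Python model of lexsort as ONE multi-operand sort.

    keys: k sequences of equal length n; the LAST key is the primary key
    (NumPy/JAX lexsort convention). Returns (permutation, number of operands).
    """
    n = len(keys[0])
    iota = list(range(n))  # index operand that carries the permutation
    # Primary key first, then secondary keys, then the index operand last.
    operands = [list(k) for k in reversed(keys)] + [iota]
    # Each element becomes a tuple across all operands; tuple comparison is
    # lexicographic, and the trailing index breaks ties like a stable sort.
    rows = sorted(zip(*operands))
    perm = [row[-1] for row in rows]
    return perm, len(operands)


# The MWE's input: 16 keys of length 2.
perm, num_operands = lexsort_model([[0.0, 0.0]] * 16)
print(perm, num_operands)  # [0, 1] 17
```

The point is the operand count, not the result: every extra key row adds one more operand to the same sort call.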
Fixes jax-ml/jax#23727 This is a temporary fix. We will add a fallback sort kernel soon. PiperOrigin-RevId: 676420937
Fixes google/jax#23727 This is a temporary fix. We will add a fallback sort kernel soon. PiperOrigin-RevId: 676475983
Thank you for reporting the issue! Yes, the SortThunk needs a fallback kernel. We have filed a bug to track this work. In the meantime, I've added the specialization for 17 inputs in openxla/xla@f237cc3. This should be reflected in JAX nightly once JAX updates its XLA commit.
Thanks @penpornk! At first this was confusing because we were seeing spurious failures only for some input sizes.
Fixes jax-ml/jax#23727 This is a temporary fix. We will add a fallback sort kernel soon. PiperOrigin-RevId: 676762330
Fixes jax-ml/jax#23727 This is a temporary fix. We will add a fallback sort kernel soon. PiperOrigin-RevId: 676789960
@PhilipVinc Sounds good. I've added support for up to 25 inputs in openxla/xla@82deceb.
Thank you!
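Until a generic fallback lands, one possible user-side workaround for key counts beyond what a given jaxlib build specializes is to replace the single wide sort with one stable argsort per key, least-significant key first (the classic LSD radix-sort argument). This is a sketch, not the fix that landed in XLA; `lexsort_stable_passes` is a hypothetical name, and it assumes `jnp.argsort` lowers to a sort over just the key plus an index operand.

```python
import jax.numpy as jnp


def lexsort_stable_passes(keys):
    """Hypothetical workaround: lexsort via one stable argsort per key.

    keys: array of shape (k, n); as in jnp.lexsort, the LAST row is the
    primary key. Each pass sorts a single key (plus argsort's index operand),
    so no single sort call sees k + 1 operands.
    """
    n = keys.shape[1]
    perm = jnp.arange(n)
    # Least-significant key first; stability keeps the order established by
    # earlier passes among elements that tie on the current key.
    for key in keys:
        order = jnp.argsort(key[perm], stable=True)
        perm = perm[order]
    return perm


print(lexsort_stable_passes(jnp.zeros((16, 2))))  # [0 1]
```

The trade-off is k sequential sorts instead of one, so it is slower for many keys; it only sidesteps the operand-count limit.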
Description
The following (very confusing) bug appeared in jax/jaxlib 0.4.33 (jax/jaxlib 0.4.31 works fine).
It only appears for specific shapes.
I'm running it on a Mac, but I originally saw it on Linux as well, for shape (576, 16).
System info (python version, jaxlib version, accelerator, etc.)
EDIT:
It seems that sort_thunk only supports a hard-coded set of input counts:
https://github.com/openxla/xla/blame/edd8e7f610ba15f4b0b3ae87c8e7b7d8f5c3dc9f/xla/backends/cpu/runtime/sort_thunk.cc#L478
But I can't find the commit in which the implementation was switched to this.
Shouldn't there be some sort of fallback?
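The fallback asked for here does not conceptually need one kernel per operand count: a generic kernel can sort a permutation of indices with the comparator and then gather every operand through that permutation. The sketch below models the idea in pure Python; it is not XLA's SortThunk code, and `generic_sort` and `lexicographic_less` are hypothetical names.

```python
from functools import cmp_to_key
from typing import Callable, List, Sequence

Row = List[object]


def generic_sort(operands: Sequence[Sequence[object]],
                 less_than: Callable[[Row, Row], bool]) -> List[list]:
    """Sort any number of equal-length operands with one comparator.

    less_than receives one 'row' (the i-th element of every operand) per
    side, similar in spirit to a sort comparator over all operands.
    """
    n = len(operands[0])
    if any(len(op) != n for op in operands):
        raise ValueError("all operands must have the same length")

    def row(i: int) -> Row:
        return [op[i] for op in operands]

    def compare(i: int, j: int) -> int:
        a, b = row(i), row(j)
        if less_than(a, b):
            return -1
        if less_than(b, a):
            return 1
        return 0  # equivalent rows: Python's sort is stable, so input order is kept

    perm = sorted(range(n), key=cmp_to_key(compare))
    # Gather every operand through the same permutation.
    return [[op[i] for i in perm] for op in operands]


def lexicographic_less(num_keys: int) -> Callable[[Row, Row], bool]:
    """Comparator that orders rows by their first num_keys entries."""
    return lambda a, b: a[:num_keys] < b[:num_keys]


# 16 zero keys plus an index operand: 17 operands, no special case needed.
operands = [[0.0, 0.0]] * 16 + [[0, 1]]
result = generic_sort(operands, lexicographic_less(16))
print(len(result), result[-1])  # 17 [0, 1]
```

A real kernel would trade some speed for this generality (indirect comparisons instead of a specialized tuple type per operand count), which is presumably why specializations exist alongside it.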