[jax 0.4.33] XlaRuntimeError: FAILED_PRECONDITION: Buffer Definition Event: Unsupported number of sorted inputs: 17 #23727

Open
PhilipVinc opened this issue Sep 18, 2024 · 6 comments · Fixed by openxla/xla#17360, tensorflow/tensorflow#76055, openxla/xla#17417 or tensorflow/tensorflow#76135
Labels
bug Something isn't working


@PhilipVinc
Contributor

PhilipVinc commented Sep 18, 2024

Description

The following (very confusing) bug appeared in jax/jaxlib 0.4.33 (jax/jaxlib 0.4.31 works fine).

It appears only for specific shapes.

I'm running it on macOS, but I originally saw it on Linux as well, for shape (576, 16).

import jax
import jax.numpy as jnp

S1 = (1, 16)
a = jax.random.randint(jax.random.key(1), S1, 0, 1)
jnp.lexsort(list(a.T)[::-1])
# Array([0], dtype=int32)

S2 = (2, 15)
a = jax.random.randint(jax.random.key(1), S2, 0, 1)
jnp.lexsort(list(a.T)[::-1])
# Array([0, 1], dtype=int32)

S3 = (2, 16)
a = jax.random.randint(jax.random.key(1), S3, 0, 1)
jnp.lexsort(list(a.T)[::-1])  # fails on jax 0.4.33 with the traceback below

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/filippo.vicentini/Documents/pythonenvs/netket/python-3.11.2/lib/python3.11/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^
  File "/Users/filippo.vicentini/Documents/pythonenvs/netket/python-3.11.2/lib/python3.11/site-packages/jax/_src/pjit.py", line 332, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper(
                                                                ^^^^^^^^^^^^^^^^^^^^
  File "/Users/filippo.vicentini/Documents/pythonenvs/netket/python-3.11.2/lib/python3.11/site-packages/jax/_src/pjit.py", line 190, in _python_pjit_helper
    out_flat = pjit_p.bind(*args_flat, **p.params)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/filippo.vicentini/Documents/pythonenvs/netket/python-3.11.2/lib/python3.11/site-packages/jax/_src/core.py", line 2782, in bind
    return self.bind_with_trace(top_trace, args, params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/filippo.vicentini/Documents/pythonenvs/netket/python-3.11.2/lib/python3.11/site-packages/jax/_src/core.py", line 443, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/filippo.vicentini/Documents/pythonenvs/netket/python-3.11.2/lib/python3.11/site-packages/jax/_src/core.py", line 949, in process_primitive
    return primitive.impl(*tracers, **params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/filippo.vicentini/Documents/pythonenvs/netket/python-3.11.2/lib/python3.11/site-packages/jax/_src/pjit.py", line 1739, in _pjit_call_impl
    return xc._xla.pjit(
           ^^^^^^^^^^^^^
  File "/Users/filippo.vicentini/Documents/pythonenvs/netket/python-3.11.2/lib/python3.11/site-packages/jax/_src/pjit.py", line 1721, in call_impl_cache_miss
    out_flat, compiled = _pjit_call_impl_python(
                         ^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/filippo.vicentini/Documents/pythonenvs/netket/python-3.11.2/lib/python3.11/site-packages/jax/_src/pjit.py", line 1675, in _pjit_call_impl_python
    return compiled.unsafe_call(*args), compiled
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/filippo.vicentini/Documents/pythonenvs/netket/python-3.11.2/lib/python3.11/site-packages/jax/_src/profiler.py", line 333, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/filippo.vicentini/Documents/pythonenvs/netket/python-3.11.2/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 1286, in __call__
    results = self.xla_executable.execute_sharded(input_bufs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Unsupported number of sorted inputs: 17

System info (python version, jaxlib version, accelerator, etc.)

In [23]: import jax; jax.print_environment_info()
jax:    0.4.33
jaxlib: 0.4.33
numpy:  2.0.2
python: 3.11.2 (main, Apr  7 2023, 16:35:55) [Clang 14.0.3 (clang-1403.0.22.14.1)]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='mba-10834270.local', release='23.6.0', version='Darwin Kernel Version 23.6.0: Mon Jul 29 21:14:30 PDT 2024; root:xnu-10063.141.2~1/RELEASE_ARM64_T6000', machine='arm64')

EDIT:
It seems that sort_thunk only supports a hardcoded set of input counts:

https://github.com/openxla/xla/blame/edd8e7f610ba15f4b0b3ae87c8e7b7d8f5c3dc9f/xla/backends/cpu/runtime/sort_thunk.cc#L478

But I can't find the commit at which the implementation was switched to this.

Shouldn't there be some sort of fallback?

@PhilipVinc PhilipVinc added the bug Something isn't working label Sep 18, 2024
@hawkinsp
Collaborator

Try XLA_FLAGS=--xla_cpu_use_thunk_runtime=false as a workaround.

@ezhulenev @penpornk

I think I forgot to mention it in the release notes for 0.4.32, but that release included a major upgrade to the CPU backend. The flag above temporarily switches back to the old runtime.
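If you can't easily set the variable in the shell, here is a minimal sketch of applying the same workaround from Python. The helper name `add_xla_flag` is hypothetical; the only thing taken from this thread is the flag itself. To be safe, run it before importing jax so the flag is in place when the CPU backend is initialized. Since the flag is described above as a temporary switch, it may not be recognized by later releases.

```python
import os


def add_xla_flag(flag: str) -> str:
    """Hypothetical helper: append `flag` to XLA_FLAGS unless it is already
    present, preserving any flags the user set. Returns the new value."""
    current = os.environ.get("XLA_FLAGS", "").split()
    if flag not in current:
        current.append(flag)
    os.environ["XLA_FLAGS"] = " ".join(current)
    return os.environ["XLA_FLAGS"]


# Switch back to the pre-thunk CPU runtime (workaround from this thread).
add_xla_flag("--xla_cpu_use_thunk_runtime=false")
# Import jax only after this point.
```

Appending rather than overwriting keeps any other XLA flags you already rely on (for example dump options).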

@PhilipVinc
Contributor Author

Here is a more reduced MWE:

import jax.numpy as jnp
jnp.lexsort(jnp.zeros((16,2)))
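The 17 in the error comes from 16 keys: as far as I can tell from the implementation of `jnp.lexsort`, it passes all keys plus one `broadcasted_iota` index array to `lax.sort`, so n keys become n + 1 sorted operands. This sketch counts the operands of the traced `sort` equation without executing anything, so it should be safe to run even on 0.4.33. `sort_operand_counts` is a hypothetical helper, and it walks jaxpr internals (`eqns`, `primitive.name`, `params`), which are not a stable public API.

```python
import jax
import jax.numpy as jnp


def sort_operand_counts(jaxpr):
    """Return the operand count of every `sort` equation in `jaxpr`,
    recursing into sub-jaxprs (e.g. the body of the jitted lexsort)."""
    counts = []
    for eqn in jaxpr.eqns:
        if eqn.primitive.name == "sort":
            counts.append(len(eqn.invars))
        for param in eqn.params.values():
            sub = getattr(param, "jaxpr", param)  # ClosedJaxpr -> Jaxpr
            if hasattr(sub, "eqns"):
                counts.extend(sort_operand_counts(sub))
    return counts


# Tracing only: no sort is compiled or executed here.
closed = jax.make_jaxpr(jnp.lexsort)(jnp.zeros((16, 2)))
print(sort_operand_counts(closed.jaxpr))  # [17]
```

With 15 keys the same check gives 16 operands, which matches the (2, 15) case in the original report working.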

copybara-service bot pushed a commit to openxla/xla that referenced this issue Sep 19, 2024
Fixes jax-ml/jax#23727
This is a temporary fix. We will add a fallback sort kernel soon.

PiperOrigin-RevId: 676420937
copybara-service bot pushed a commit to openxla/xla that referenced this issue Sep 19, 2024
Fixes google/jax#23727
This is a temporary fix. We will add a fallback sort kernel soon.

PiperOrigin-RevId: 676475983
@penpornk
Collaborator

Thank you for reporting the issue! Yes, the SortThunk needs a fallback kernel. We have filed a bug tracking this work.

In the meantime, I've added a specialization for 17 inputs in openxla/xla@f237cc3. This should be reflected in JAX nightly once JAX updates its XLA commit.
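To check whether a particular jax/jaxlib install (for example a nightly that picked up the XLA change) is affected, a small probe like the following could help. `failing_key_counts` is a hypothetical helper; it relies on `jnp.lexsort` sorting the keys plus one index array, so n keys mean n + 1 sorted inputs. It catches a broad `Exception` because the exact error class has moved between jaxlib versions.

```python
import jax.numpy as jnp


def failing_key_counts(max_keys=30):
    """Return the numbers of keys (1..max_keys) for which jnp.lexsort
    raises on this install. n keys correspond to n + 1 sorted inputs."""
    failures = []
    for num_keys in range(1, max_keys + 1):
        keys = jnp.zeros((num_keys, 2))  # length-2 keys, as in the reduced MWE
        try:
            jnp.lexsort(keys).block_until_ready()
        except Exception:  # XlaRuntimeError on affected versions
            failures.append(num_keys)
    return failures


print(failing_key_counts())
```

Each key count compiles a new executable, so the probe takes a few seconds. On 0.4.33 the list should include 16, per the reduced MWE above.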

@PhilipVinc
Contributor Author

Thanks @penpornk!
Technically, though, 17 was just the count from a minimal reproducer.
For our application we could reasonably need every value up to 25.
Since 25 is already supported, would it be reasonable to add all values up to 25?

At first this was confusing, because we were seeing spurious failures only for some input sizes.
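Until a generic fallback kernel exists, one user-side way to avoid depending on which input counts are specialized is to build the lexicographic order from a sequence of stable single-key argsorts, least-significant key first (the LSD radix-sort idea). This is a swapped-in technique, not what `jnp.lexsort` does internally; as far as I can tell each `jnp.argsort` call sorts only two operands (the key and an index array). `lexsort_via_stable_argsorts` is a hypothetical name.

```python
import jax.numpy as jnp


def lexsort_via_stable_argsorts(keys):
    """Lexicographic sort of a (num_keys, n) array with the same key order as
    jnp.lexsort: the LAST row is the primary key. Ties are broken by
    original index, because every pass is a stable sort."""
    keys = jnp.asarray(keys)
    perm = jnp.arange(keys.shape[-1])
    # Process least-significant key first; each stable pass preserves the
    # order established by the previous (less significant) keys.
    for key in keys:
        perm = perm[jnp.argsort(key[perm], stable=True)]
    return perm


keys = jnp.array([[1, 0, 1, 0],   # secondary key
                  [0, 0, 1, 1]])  # primary key (last row)
print(lexsort_via_stable_argsorts(keys))  # [1 0 3 2]
```

To mirror the original reproducer, call it as `lexsort_via_stable_argsorts(a.T[::-1])`. The trade-off is k sorts plus k gathers for k keys instead of a single sort, so it is a stopgap rather than a replacement.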

copybara-service bot pushed a commit to openxla/xla that referenced this issue Sep 20, 2024
Fixes jax-ml/jax#23727
This is a temporary fix. We will add a fallback sort kernel soon.

PiperOrigin-RevId: 676762330
copybara-service bot pushed a commit to openxla/xla that referenced this issue Sep 20, 2024
Fixes jax-ml/jax#23727
This is a temporary fix. We will add a fallback sort kernel soon.

PiperOrigin-RevId: 676789960
@penpornk
Collaborator

As you already added 25 would it be reasonable to add all values until 25?

@PhilipVinc Sounds good. I've added support for up to 25 inputs in openxla/xla@82deceb

@PhilipVinc
Contributor Author

Thank you!
