jax.pure_callback crashes on TPU VM #12260

Closed
nalzok opened this issue Sep 7, 2022 · 5 comments

nalzok commented Sep 7, 2022

Description

This is a follow-up to discussion #12245. The solution suggested there by a project collaborator crashed the Python interpreter on my TPU VM, although it should work cross-platform.

Python 3.8.10 (default, Jun 22 2022, 20:18:18)
[GCC 9.4.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax
>>> import jax.numpy as jnp
>>> import scipy.linalg
>>>
>>> def schur(x):
...   return jax.pure_callback(scipy.linalg.schur, (x, x), x)
...
>>> @jax.jit
... def f(x):
...   return schur(x)
...
>>> print(f(jnp.array([[0, 2, 2], [0, 1, 2], [1, 0, 1]], jnp.float32)))
F0907 14:29:33.631891  939012 host_command_dispatcher.cc:83] Check failed: !handlers_by_run_ids_[queue_id].empty() Host command 50331929 triggered but no handler was registered, run id: 1
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/qys/Research/label-shift/.venv/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/qys/Research/label-shift/.venv/lib/python3.8/site-packages/jax/_src/api.py", line 528, in cache_miss
    out_flat = xla.xla_call(
  File "/home/qys/Research/label-shift/.venv/lib/python3.8/site-packages/jax/core.py", line 1963, in bind
    return call_bind(self, fun, *args, **params)
  File "/home/qys/Research/label-shift/.venv/lib/python3.8/site-packages/jax/core.py", line 1979, in call_bind
    outs = top_trace.process_call(primitive, fun_, tracers, params)
  File "/home/qys/Research/label-shift/.venv/lib/python3.8/site-packages/jax/core.py", line 689, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/home/qys/Research/label-shift/.venv/lib/python3.8/site-packages/jax/_src/dispatch.py", line 236, in _xla_call_impl
    return compiled_fun(*args)
  File "/home/qys/Research/label-shift/.venv/lib/python3.8/site-packages/jax/_src/dispatch.py", line 841, in _execute_compiled
    out_bufs = token_handler(out_bufs, runtime_token)
  File "/home/qys/Research/label-shift/.venv/lib/python3.8/site-packages/jax/_src/dispatch.py", line 808, in _remove_tokens
    output_token_buf, *token_bufs = token_bufs
jax._src.traceback_util.UnfilteredStackTrace: ValueError: not enough values to unpack (expected at least 1, got 0)

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
ValueError: not enough values to unpack (expected at least 1, got 0)
>>> *** Check failure stack trace: ***
    @     0x7f07a33b98e4  (unknown)
    @     0x7f07a33b93ca  (unknown)
    @     0x7f07a33b9c49  (unknown)
    @     0x7f079ff89e79  (unknown)
    @     0x7f079ff70f66  (unknown)
    @     0x7f079dfa17c3  (unknown)
    @     0x7f07a3297669  (unknown)
    @     0x7f07a32930fe  (unknown)
    @     0x7f0e8753b609  start_thread
https://symbolize.stripped_domain/r/?trace=7f07a33b98e4,7f07a33b93c9,7f07a33b9c48,7f079ff89e78,7f079ff70f65,7f079dfa17c2,7f07a3297668,7f07a32930fd,7f0e8753b608&map=068bf80b76f830987166dd8847d0248f:7f078dddc000-7f07a370ede0
https://symbolize.stripped_domain/r/?trace=7f0e8759900b,7f0e8759908f,7f07a33b9943,7f07a33b93c9,7f07a33b9c48,7f079ff89e78,7f079ff70f65,7f079dfa17c2,7f07a3297668,7f07a32930fd,7f0e8753b608&map=068bf80b76f830987166dd8847d0248f:7f078dddc000-7f07a370ede0
*** SIGABRT received by PID 937881 (TID 939012) on cpu 81 from PID 937881; ***
E0907 14:29:33.636138  939012 coredump_hook.cc:370] RAW: Remote crash data gathering hook invoked.
E0907 14:29:33.636153  939012 coredump_hook.cc:416] RAW: Skipping coredump since rlimit was 0 at process start.
E0907 14:29:33.636161  939012 client.cc:242] RAW: Coroner client retries enabled (b/136286901), will retry for up to 30 sec.
E0907 14:29:33.636166  939012 coredump_hook.cc:477] RAW: Sending fingerprint to remote end.
E0907 14:29:33.636174  939012 coredump_socket.cc:118] RAW: Stat failed errno=2 on socket /var/google/services/logmanagerd/remote_coredump.socket
E0907 14:29:33.636184  939012 coredump_hook.cc:481] RAW: Cannot send fingerprint to Coroner: [NOT_FOUND] Missing crash reporting socket. Is the listener running?
E0907 14:29:33.636190  939012 coredump_hook.cc:555] RAW: Discarding core.
F0907 14:29:33.631891  939012 host_command_dispatcher.cc:83] Check failed: !handlers_by_run_ids_[queue_id].empty() Host command 50331929 triggered but no handler was registered, run id: 1
E0907 14:29:33.898010  939012 process_state.cc:774] RAW: Raising signal 6 with default behavior
fish: “pipenv run python3” terminated by signal SIGABRT (Abort)

What jax/jaxlib version are you using?

jax v0.3.17, jaxlib v0.3.15

Which accelerator(s) are you using?

TPU v3-8 with libtpu v1.3.0

Additional System Info

Python 3.8.10, TPU VM on GCP running Ubuntu 20.04 (Linux t1v-n-e307e167-w-0 5.13.0-1023-gcp #28~20.04.1-Ubuntu SMP Wed Mar 30 03:51:07 UTC 2022 x86_64 x86_64 x86_64 GNU/Linux)
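
For anyone who wants to re-run this outside the REPL, here is a minimal standalone sketch of the same repro. It assumes a JAX version that exposes `jax.pure_callback` and `jax.ShapeDtypeStruct`, and the helper name `_host_schur` is made up for illustration. It declares the output shapes explicitly and casts the SciPy results so they match the declared dtypes; it is not necessarily how the collaborator's original suggestion was written.

```python
import jax
import jax.numpy as jnp
import numpy as np
import scipy.linalg


def _host_schur(x):
    # Hypothetical helper: runs on the host, outside of the compiled program.
    # Real Schur decomposition, so that A = Z @ T @ Z.T.
    t, z = scipy.linalg.schur(np.asarray(x))
    # Cast so the results match the dtypes declared to pure_callback.
    return t.astype(x.dtype), z.astype(x.dtype)


@jax.jit
def schur(x):
    # Declare the shape and dtype of both outputs (T and Z) up front.
    out = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(_host_schur, (out, out), x)


a = jnp.array([[0, 2, 2], [0, 1, 2], [1, 0, 1]], jnp.float32)
t, z = schur(a)
print(t)
print(z)
```

On CPU this should run to completion; on the TPU VM setup described above it is expected to hit the same check failure.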

@nalzok nalzok added the bug Something isn't working label Sep 7, 2022
@sharadmv sharadmv self-assigned this Sep 8, 2022

sharadmv commented Sep 8, 2022

It seems callbacks do not currently work on Cloud TPU VMs because they use the older stream_executor runtime. They will soon switch to a newer runtime that does support callbacks. I'll monitor this and update the issue when callbacks work.


dlwh commented Apr 17, 2023

@sharadmv any movement on this by any chance? Is there maybe an experimental vm image we can use? Thanks!

sharadmv commented

I think they should work now. Cc: @skye


skye commented Apr 17, 2023

Ah yeah, this should work on Cloud TPU as of jax 0.4.8. I'm gonna close this issue, but please comment or reopen if you find things still aren't working!
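
A quick way to tell whether an environment is new enough is to compare the installed version against the 0.4.8 threshold mentioned above. The sketch below is only an illustration: `parse_version`, `tpu_callbacks_expected`, and `MIN_TPU_CALLBACK_VERSION` are hypothetical names, not part of JAX, and the check only looks at the version string, not at the TPU runtime itself.

```python
# Threshold taken from the comment above: callbacks should work on
# Cloud TPU as of jax 0.4.8.
MIN_TPU_CALLBACK_VERSION = (0, 4, 8)


def parse_version(version):
    # Keep only the leading numeric components,
    # e.g. "0.4.8.dev20230401" -> (0, 4, 8).
    parts = []
    for piece in version.split("."):
        if not piece.isdigit():
            break
        parts.append(int(piece))
    return tuple(parts)


def tpu_callbacks_expected(version):
    # True if the version string is at or above the 0.4.8 threshold.
    return parse_version(version) >= MIN_TPU_CALLBACK_VERSION
```

In practice you would call it as `tpu_callbacks_expected(jax.__version__)` after `import jax`. The v0.3.17 install from the original report would come back `False`.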

@skye skye closed this as completed Apr 17, 2023

dlwh commented Apr 17, 2023

oops, yeah! i was running an older version of jax still. Thank you!
