I saw the Pallas concept in the latest official JAX docs and followed the Pallas quickstart section.
I installed the latest jax and jaxlib from the GitHub head.
I encountered the following error:
Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/sh0416/research/scripts/pallas_quickstart.py", line 20, in <module>
    print(add_vectors(jnp.arange(8), jnp.arange(8)))
  File "/home/sh0416/research/scripts/pallas_quickstart.py", line 17, in add_vectors
    return pl.pallas_call(add_vectors_kernel, out_shape=out_shape)(x, y)
  File "/home/sh0416/research/venv/lib/python3.10/site-packages/jax/_src/pallas/pallas_call.py", line 353, in wrapped
    out_flat = pallas_call_p.bind(
jax._src.source_info_util.JaxStackTraceBeforeTransformation: NotImplementedError: MLIR translation rule for primitive 'pallas_call' not found for platform cuda
The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
--------------------
The above exception was the direct cause of the following exception:
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/sh0416/research/scripts/pallas_quickstart.py", line 20, in <module>
    print(add_vectors(jnp.arange(8), jnp.arange(8)))
NotImplementedError: MLIR translation rule for primitive 'pallas_call' not found for platform cuda
Do I have to install something different, or has Pallas not been fully upstreamed yet?
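For reference, below is a minimal sketch of the quickstart script. Only lines 17 and 20 of `pallas_quickstart.py` are visible in the traceback, so the rest (the kernel body and how `out_shape` is built) is an assumption based on the quickstart's vector-add example. One deliberate change: it passes `interpret=True` to `pl.pallas_call`, which runs the kernel through Pallas's interpreter instead of lowering it for the GPU. This does not fix the missing `pallas_call` translation rule for CUDA; it only separates "is the kernel correct?" from "does this jaxlib build have the GPU lowering?".

```python
# Sketch of the Pallas quickstart vector add, reconstructed from the traceback.
# interpret=True is an addition for debugging: the kernel is evaluated by the
# Pallas interpreter instead of being compiled for the GPU backend.
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def add_vectors_kernel(x_ref, y_ref, o_ref):
    # With no grid or BlockSpecs, each ref covers the whole input/output array.
    o_ref[...] = x_ref[...] + y_ref[...]


def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:
    # The output has the same shape and dtype as the inputs.
    out_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return pl.pallas_call(
        add_vectors_kernel, out_shape=out_shape, interpret=True
    )(x, y)


if __name__ == "__main__":
    print(add_vectors(jnp.arange(8), jnp.arange(8)))
```

If this version runs but the same call without `interpret=True` still raises the `NotImplementedError` above, that would point at the CUDA lowering path for `pallas_call` in this build rather than at the kernel itself.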