JAX-Triton currently doesn't work. Running the add example produces the following error, which I suspect is caused by CUDA-specific code in triton_lib.py.
I'm happy to run any other tests requested here to help make progress on this.
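For reference, this is roughly the add example I'm running; it's a minimal sketch reconstructed from the jax-triton README, so the exact kernel body and `jt.triton_call` arguments may differ slightly from the upstream example:

```python
import jax
import jax.numpy as jnp
import jax_triton as jt
import triton
import triton.language as tl


@triton.jit
def add_kernel(x_ptr, y_ptr, output_ptr, block_size: tl.constexpr):
    """Adds two length-8 vectors, one block per program instance."""
    pid = tl.program_id(axis=0)
    block_start = pid * block_size
    offsets = block_start + tl.arange(0, block_size)
    mask = offsets < 8  # hard-coded length, as in the README example
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(output_ptr + offsets, x + y, mask=mask)


def add(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
    out_shape = jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype)
    block_size = 8
    return jt.triton_call(
        x, y,
        kernel=add_kernel,
        out_shape=out_shape,
        grid=(x.size // block_size,),
        block_size=block_size)


x_val = jnp.arange(8, dtype=jnp.float32)
y_val = jnp.arange(8, 16, dtype=jnp.float32)
print(add(x_val, y_val))
```

The error appears when `jt.triton_call` is invoked, which is why I suspect the CUDA-specific paths in triton_lib.py rather than the kernel itself.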
There is now support for FlashAttention-2 on AMD GPUs with PyTorch, which uses Triton kernels for this:
https://github.com/ROCmSoftwarePlatform/flash-attention