Creating this issue as a reference for implementing the `auto` accelerator for JAX users.
The thing is that we need to be opinionated about how other frameworks decide to support devices and what devices they support. So the semantics of `auto` will need to be framework-specific.
For instance, with JAX you can call `device_put` on GPU, but you don't have `mps`. So we can add JAX with a different implementation of `_choose_gpu_accelerator_backend`.
We should actually have framework-specific `_choose_gpu_accelerator_backend` functions, like `_choose_gpu_accelerator_torch`, etc.
Originally posted by @lantiga in #44 (comment)