Add support for MPS backend #90
Comments
@opcode81 @MischaPanch what do you think?
It would be useful for Mac users in the short term, but I feel we should rather focus on simplifying the use of custom trainers and custom models. In the end, sensai is neither a model nor a trainer library, but a wrapper library enabling rapid prototyping with arbitrary models. Maintaining all kinds of concerns, such as how models are created (torch.compile), parallelized across devices (lightning, jax, other things), exported to different hardware (GPU, TPU, M1/2/3), and generally trained, would be near impossible and unnecessary. Instead, we could generalize the current …
@schroedk this could presumably be handled by specifying a torch device, right? While the use of … And we should make sure that old models can still be loaded/remain compatible. @schroedk, what do you think? As far as Mischa's statements are concerned: I think it makes sense to maintain the current torch wrapper, as it is very convenient to use and provides what is needed in most cases. We could, however, think about adding further, more lightweight wrappers that could enable all sorts of things.
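To make the "specify a torch device" idea concrete, here is a minimal sketch of how device selection could look, preferring CUDA, then MPS, then CPU. The helper names (`pick_device_name`, `pick_torch_device`) are hypothetical and not part of the current sensai API; the MPS check uses torch's documented `torch.backends.mps.is_available()`:

```python
def pick_device_name(cuda_available: bool, mps_available: bool) -> str:
    """Pure decision logic, separated from torch so it is trivially testable.

    Preference order: CUDA > MPS (Apple Silicon) > CPU.
    """
    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    return "cpu"


def pick_torch_device():
    """Resolve the preferred torch.device at runtime (requires torch).

    The getattr guard keeps this working on older torch versions
    that predate the MPS backend.
    """
    import torch

    mps_ok = (
        getattr(torch.backends, "mps", None) is not None
        and torch.backends.mps.is_available()
    )
    return torch.device(pick_device_name(torch.cuda.is_available(), mps_ok))
```

Keeping the decision logic separate from the torch calls would also make it easy to let users override the choice explicitly, which fits the idea of simply specifying a torch device.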
The current class TorchModel has the following init:
and is responsible for putting the inputs of the torch model on the corresponding device (here):
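For reference, that input transfer could be sketched roughly as follows. This is a hypothetical `to_device` helper, not the actual sensai implementation; it recursively moves (possibly nested) batches via the standard `.to(device)` method that torch tensors provide:

```python
def to_device(batch, device):
    """Recursively move a (possibly nested) batch to the given device.

    Lists and tuples are traversed and rebuilt with the same container
    type; leaves are assumed to expose a `.to(device)` method (as torch
    tensors do), so torch is not imported at module level.
    """
    if isinstance(batch, (list, tuple)):
        return type(batch)(to_device(item, device) for item in batch)
    return batch.to(device)
```

A device-agnostic helper like this is one way the MPS case could be supported without any MPS-specific code paths: the backend choice stays confined to wherever the device is selected.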
I would like to suggest including support for different torch backends, in particular the MPS backend for Apple machines.
My first impression is that this could be a breaking change, so let's discuss it here.