Add support for MPS backend #90

Open
schroedk opened this issue Mar 26, 2024 · 3 comments
Labels: enhancement (New feature or request)

Comments

schroedk (Collaborator) commented Mar 26, 2024

The current class TorchModel has the following init:

class TorchModel(ABC, ToStringMixin):
    """
    sensAI abstraction for torch models, which supports one-line training, allows for convenient model application,
    has basic mechanisms for data scaling, and soundly handles persistence (via pickle).
    An instance wraps a torch.nn.Module, which is constructed on demand during training via the factory method
    createTorchModule.
    """
    log: logging.Logger = log.getChild(__qualname__)

    def __init__(self, cuda=True) -> None:
        self.cuda: bool = cuda
        self.module: Optional[torch.nn.Module] = None
        self.outputScaler: Optional[TensorScaler] = None
        self.inputScaler: Optional[TensorScaler] = None
        self.trainingInfo: Optional[TrainingInfo] = None
        self._gpu: Optional[int] = None
        self._normalisationCheckThreshold: Optional[int] = 5

and is responsible for moving the torch model's inputs to the corresponding device (here):

if self._is_cuda_enabled():
    torch.cuda.set_device(self._gpu)
    inputs = [t.cuda() for t in inputs]

I would like to suggest adding support for different torch backends, in particular the MPS backend for Apple machines.

My first impression is that this could be a breaking change, so let's discuss here.
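
For illustration, a device-agnostic variant of that input placement could look roughly like the sketch below (get_torch_device is a hypothetical helper, not something that exists in sensAI today):

import torch

def get_torch_device(device: str = "auto") -> torch.device:
    """Resolve a device string, falling back to CPU when the requested backend is unavailable."""
    if device == "auto":
        if torch.cuda.is_available():
            return torch.device("cuda")
        if torch.backends.mps.is_available():
            return torch.device("mps")
        return torch.device("cpu")
    return torch.device(device)

# instead of inputs = [t.cuda() for t in inputs], move tensors to whatever device was resolved
device = get_torch_device()
inputs = [torch.randn(4, 3)]
inputs = [t.to(device) for t in inputs]

With a helper along these lines, the MPS backend would be picked up automatically on Apple machines without any CUDA-specific code paths.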

schroedk added the enhancement label on Mar 26, 2024
schroedk (Collaborator, Author)

@opcode81 @MischaPanch what do you think?

MischaPanch (Collaborator) commented Apr 2, 2024

It would be useful for Mac users in the short term, but I feel like we should rather focus on simplifying the use of custom trainers and custom models. In the end, sensai is neither a model nor a trainer library, but rather a wrapper library that allows rapid prototyping with arbitrary models.

Maintaining all kinds of things like improvements to how models are created (torch.compile), parallelized across devices (lightning, jax, other things), exported to different hardware (GPU, TPU, M1/2/3), and generally trained would be near impossible and unnecessary.

Instead, we could generalize the current NNOptimizer and TorchModel and provide support for anything that fulfils basic interfaces.
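
As a rough illustration of that idea (the names below are hypothetical, not an API proposal), such a basic interface could be expressed as a Protocol:

from typing import Any, Protocol

class TrainableModule(Protocol):
    """Minimal contract a user-supplied model wrapper would have to fulfil."""

    def fit(self, inputs: Any, targets: Any) -> None: ...

    def apply(self, inputs: Any) -> Any: ...

Anything satisfying such a protocol could then be plugged into the generalized NNOptimizer/TorchModel machinery, regardless of how it is trained or on which hardware it runs.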

opcode81 (Owner) commented Apr 3, 2024

@schroedk this could presumably be handled by specifying a torch device, right?
sensai's device handling (using the combination of cuda flag and gpu index) dates back to torch 0.4 days. I think we should switch to the more modern use of device strings/objects in order to address this issue and generalise the device handling.

While the use of gpu is local to the optimiser, uses of the cuda flag are scattered throughout (over 100 occurrences); still, I think it's ultimately quick to change. I believe a few of those places might not even need to be aware of the device, but I'm not sure.

And we should make sure that old models can still be loaded/remain compatible.
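
For illustration only, a stripped-down sketch of that switch, including backward-compatible unpickling of the old cuda flag (hypothetical code, not a concrete patch):

import torch
from typing import Optional

class TorchModel:
    def __init__(self, device: Optional[str] = None) -> None:
        if device is None:
            # default: CUDA if available, otherwise CPU (MPS could be added to the auto-selection)
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device: torch.device = torch.device(device)

    def __setstate__(self, state: dict) -> None:
        # old pickles only carry the boolean cuda flag; translate it into a device object
        if "device" not in state:
            state["device"] = torch.device("cuda" if state.pop("cuda", False) else "cpu")
        self.__dict__.update(state)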

@schroedk, what do you think?

As far as Mischa's statements are concerned: I think it makes sense to maintain the current torch wrapper, as it's very convenient to use and provides what is needed in most cases. We could think about adding further, more lightweight wrappers that could enable all sorts of things, though.
