
XLA support #512

Draft · wants to merge 3 commits into master

Conversation

@pcuenca pcuenca commented Feb 23, 2022

I'm working on this for my own use, so I can train on Google Cloud TPU instances. So far I've made the minimum adjustments necessary to make it work for me, but I can try to improve it if you are interested in this feature. Some of the topics remaining to be addressed are:

  • State restoration in `load_resume_state`.
  • Net and weights loading optimization: we should load the network and any pre-trained weights just once in the main process, then copy them to the rest of the devices using `xmp.MpModelWrapper()`.
  • Maybe use `torch_xla.distributed.parallel_loader` instead of `EnlargedSampler`. I made some preliminary tests and both seem to work fine, so I kept `EnlargedSampler` for simplicity. (A rough sketch of both APIs follows this list.)
  • Adapt tests / inference code. So far I only need the training loop for my purposes.
  • Documentation updates.
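
Here is the rough sketch mentioned in the list, showing how those two PyTorch-XLA helpers are typically combined. It is illustrative only, not the code in this PR; `build_network` and `train_set` are placeholders for whatever the training script actually constructs.

```python
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp

# Build the network (and load any pre-trained weights) once in the main process.
# MpModelWrapper keeps a single host copy and moves it to each XLA device lazily
# inside the forked workers, instead of loading it once per core.
wrapped_net = xmp.MpModelWrapper(build_network())  # build_network(): placeholder

def train_worker(index):
    device = xm.xla_device()
    net = wrapped_net.to(device)

    loader = torch.utils.data.DataLoader(train_set, batch_size=16, shuffle=True)
    # MpDeviceLoader (from parallel_loader) preloads batches onto the XLA device
    # in a background thread, so the training loop is not host-bound.
    device_loader = pl.MpDeviceLoader(loader, device)
    for batch in device_loader:
        ...  # forward / backward / optimizer step on `net`
```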

XLA is enabled through a new `accelerator: xla` option in the configuration file (the default is `cuda`). Setting `num_gpu` to `auto` makes the training script use all the available TPU devices. Note that XLA supports either 1 device or all of them (typically 8): it is not possible to use more than 1 but fewer than the number of installed TPU cores. I haven't added a check for that, so the code will crash if you request anything between 1 and the maximum.
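
For reference, the relevant part of an option file would look something like this; only the two keys discussed above are shown, everything else stays as in a regular basicsr configuration:

```yaml
# excerpt of a training option file (illustrative)
num_gpu: auto      # with XLA, "auto" means all available TPU cores (typically 8)
accelerator: xla   # new option introduced by this PR; defaults to cuda
```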

You need to launch the training process without any parallel launcher – the code in this PR forks the process automatically, following the recommended approach I've seen in all PyTorch-XLA tutorials. So, you launch as usual:

```
python basicsr/train.py -opt <your_configuration.yml>
```

and the script will parallelize across devices nonetheless.
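
For context, the forking approach from the PyTorch-XLA tutorials that this follows looks roughly like the snippet below. Again, this is a sketch rather than the actual diff; `parse_options` and `train_pipeline` stand in for the existing basicsr entry points.

```python
import torch_xla.distributed.xla_multiprocessing as xmp

def _mp_fn(index, opt):
    # `index` is the ordinal of the forked process / TPU core.
    # `opt` stands for the already-parsed options (placeholder).
    train_pipeline(opt)  # placeholder for the per-process training entry point

if __name__ == '__main__':
    opt = parse_options()  # placeholder for the existing option parsing
    # nprocs must be either 1 or the total number of cores (typically 8).
    xmp.spawn(_mp_fn, args=(opt,), nprocs=8, start_method='fork')
```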

One limitation is that this only parallelizes training across the TPU devices attached to a single host; it does not support multi-machine distributed training.

I tried to minimize the impact on model architectures so that they can support XLA training. So far I have only adapted `sr_model`, and the required change was minimal, but this could vary for other architectures.
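
The gist of that change (my own reconstruction from the commit note below, not the exact diff) is routing the optimizer step through a small helper, so the XLA path can use `xm.optimizer_step`, which all-reduces gradients across cores before stepping, while the CUDA path keeps the plain call:

```python
import torch_xla.core.xla_model as xm

def optimizer_step(optimizer, use_xla):
    """Hypothetical helper mirroring the new `optimizer_step` method in SRModel.

    `use_xla` is an assumed flag derived from the `accelerator: xla` option.
    """
    if use_xla:
        # All-reduces gradients across TPU cores, then steps the optimizer.
        xm.optimizer_step(optimizer)
    else:
        optimizer.step()
```

In `sr_model`, the existing direct call to the generator optimizer's `.step()` would then go through this helper instead.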

Overall, I feel a bit unhappy that the parallelization works by special-casing a few places in the existing code. This is caused in part by the peculiarities of PyTorch-XLA, which clash with the assumptions the existing distributed code was built on, namely that distributed training always goes through a common paradigm: an external "launcher" plus the torch.distributed interface. That assumption no longer holds here. Still, this is the simplest approach I could think of that avoids a major refactor of the existing code or a complete duplication of the training script. If you can think of a better alternative, by all means please let me know.

Thanks a lot for making this project available!

Note: the call to the new `optimizer_step` method has only been added to
`SRModel`.
This is still hackish and being tested.
@pcuenca pcuenca marked this pull request as draft February 23, 2022 17:33
lgtm-com bot commented Feb 23, 2022

This pull request introduces 5 alerts when merging 36dca17 into 6697f41 - view on LGTM.com

new alerts:

  • 3 for Except block handles 'BaseException'
  • 2 for Unused import
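
(For reference, the "Except block handles 'BaseException'" alert refers to bare `except:` clauses, which also catch `KeyboardInterrupt` and `SystemExit`. The example below is illustrative, not code from this PR.)

```python
def read_state(path):
    try:
        with open(path) as f:
            return f.read()
    except:          # flagged: a bare except handles BaseException
        return None

def read_state_fixed(path):
    try:
        with open(path) as f:
            return f.read()
    except OSError:  # narrowed to the failures we actually expect
        return None
```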
