I'm working on this for my own use, so I can train on Google Cloud TPU instances. So far I've done the minimum adjustments necessary to make it work for me, but I can try to improve it if you are interested in this feature. Some of the topics remaining to be addressed are:
- `load_resume_state`.
- `xmp.MpModelWrapper()`.
- `torch_xla.distributed.parallel_loader` instead of `EnlargedSampler`. I made some preliminary tests and both seem to work fine, so I kept `EnlargedSampler` for simplicity (see the sketch after this list).
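For reference, here is a minimal sketch of the two PyTorch-XLA utilities named in the last two items, assuming a generic data loader; the tiny `Linear` model and the `train_worker` name are made up for illustration:

```python
import torch
import torch.nn.functional as F
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp

# MpModelWrapper keeps a single copy of the model in shared host memory;
# each worker process materializes it on its own TPU core with .to().
WRAPPED_MODEL = xmp.MpModelWrapper(torch.nn.Linear(8, 1))

def train_worker(loader):
    # Runs inside each forked worker process (one per TPU core).
    device = xm.xla_device()
    model = WRAPPED_MODEL.to(device)
    # ParallelLoader streams batches onto this core's device in background
    # threads, rather than sharding indices with a distributed sampler.
    device_loader = pl.ParallelLoader(loader, [device]).per_device_loader(device)
    for x, y in device_loader:
        loss = F.mse_loss(model(x), y)
        # ... backward pass and optimizer step would go here ...
```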
XLA is enabled by using a new `accelerator: xla` option (default: `cuda`) in the configuration file. Setting `num_gpu` to `auto` makes the training script use all the available TPU devices. Note that XLA supports either `1` device or all of them (typically `8`): it is not possible to use more than one but fewer than the number of installed TPU cores. I haven't added a check for that, so the code will crash if you request anything in between.
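For illustration, the relevant options might look like this; every key except `accelerator` and `num_gpu` is a placeholder following the existing option-file format:

```yaml
# Hypothetical excerpt from a training option file.
name: my_xla_experiment      # placeholder
model_type: SRModel          # placeholder
accelerator: xla             # new option added in this PR (default: cuda)
num_gpu: auto                # with xla: use all available TPU cores, or 1
```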
You need to launch the training process without any parallel launcher – the code in this PR forks the worker processes automatically, following the recommended approach I've seen in all the PyTorch-XLA tutorials. So, you launch as usual:
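```sh
# Illustrative command: assumes the repository's usual entry point and a
# placeholder option file name.
python basicsr/train.py -opt options/train/my_xla_experiment.yml
```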
and the script will parallelize across devices nonetheless.
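The automatic fork follows the standard PyTorch-XLA multiprocessing entry point; a minimal self-contained sketch of that pattern (illustrative, not the exact code in this PR):

```python
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

def _mp_fn(index):
    # Each spawned process drives exactly one TPU core.
    device = xm.xla_device()
    print(f'worker {index} running on {device}')

if __name__ == '__main__':
    # nprocs=None uses every available device; XLA otherwise only
    # accepts 1 or the full core count (typically 8).
    xmp.spawn(_mp_fn, nprocs=None)
```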
One limitation is that this only parallelizes training across the TPU cores attached to the same host; it does not work across multiple machines.
I tried to minimize the impact on model architectures for them to support XLA training. So far I have only adapted `sr_model`, and the required change was minimal, but this could vary for other architectures.
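For a sense of what that kind of change looks like, the canonical PyTorch-XLA adaptation (not necessarily the exact diff in `sr_model`; the `use_xla` flag is a made-up stand-in for the new option) replaces a plain `optimizer.step()` with `xm.optimizer_step()`, which all-reduces the gradients across cores before stepping:

```python
import torch_xla.core.xla_model as xm

def step_optimizer(optimizer, use_xla):
    """Step the optimizer; under XLA, also sync gradients across TPU cores."""
    if use_xla:  # e.g. opt['accelerator'] == 'xla'
        xm.optimizer_step(optimizer)
    else:
        optimizer.step()
```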
Overall, I feel a bit unhappy that the parallelization works by special-casing a few places in the existing code. This is caused in part by the peculiarities of PyTorch-XLA, which clash with assumptions of the previous distributed-training paradigm that no longer hold: a separate distributed "launcher", and the use of the `torch.distributed` interface. However, it is the simplest way I could think of without a major refactor of the existing code or completely duplicating the training script. If you can think of a better alternative, by all means please let me know.

Thanks a lot for making this project available!