[air] Add horovod trainer. #23437
Conversation
LGTM! I think we should add an example and also show how to use it with the predictors.
```python
    scaling_config=scaling_config,
)
trainer.fit()
```
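For context, a minimal sketch of the call site being discussed, assuming the HorovodTrainer API this PR adds (the import path, the config keys, and the dict form of scaling_config are assumptions based on the sibling TorchTrainer):

```python
# Sketch only: import path and config values are assumptions.
from ray.train.horovod import HorovodTrainer

def train_loop_per_worker(config):
    # hvd.init(), model/optimizer setup, and the epoch loop go here.
    ...

trainer = HorovodTrainer(
    train_loop_per_worker=train_loop_per_worker,
    train_loop_config={"num_epochs": 10},
    scaling_config={"num_workers": 2},  # dict form assumed here
)
result = trainer.fit()
```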
Could we also e2e test with TorchPredictor, like what we have in test_torch_trainer?
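A hedged sketch of what such an e2e check could look like, assuming the TorchPredictor flow exercised in test_torch_trainer (the import path may differ in the version this PR targets):

```python
# Sketch: assumes a fitted trainer and the TorchPredictor checkpoint API.
import numpy as np
from ray.train.torch import TorchPredictor

result = trainer.fit()
predictor = TorchPredictor.from_checkpoint(result.checkpoint)
predictions = predictor.predict(np.array([[1.0], [2.0]], dtype=np.float32))
```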
I find it very hard to use the predictor interface for this image classification problem.
For the sake of verifying the training process, I just use the native PyTorch DataLoader and Tensor machinery (and not the predictor).
I may need to add a simple linear training setup if we want to cover the predictor part in an e2e fashion for Horovod.
Although in terms of test coverage, I think the corresponding test_pytorch/tensorflow_trainer already has it covered.
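In case it helps, a minimal sketch of the native-PyTorch verification being described; `model` and `test_dataset` are placeholders for the trained model and test data, not names from this PR:

```python
# Sketch: plain PyTorch evaluation instead of the predictor interface.
import torch
from torch.utils.data import DataLoader

test_loader = DataLoader(test_dataset, batch_size=64)  # test_dataset is a placeholder
model.eval()
correct = 0
with torch.no_grad():
    for data, target in test_loader:
        pred = model(data).argmax(dim=1)
        correct += (pred == target).sum().item()
accuracy = correct / len(test_loader.dataset)
```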
Left a few comments, but overall looks good to me!
```python
num_epochs = config.get("num_epochs", 10)
log_interval = config.get("log_interval", 10)
use_cuda = config.get("use_cuda", False)
save_model_as_dict = config.get("save_model_as_dict", False)
```
Seems like this is always going to be False?
Hmm, I have a test where save_model_as_dict is True, so that we can test that path as well. So it should be taking effect.
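For context, this is the pattern being discussed: the default comes from config.get above, and a test can flip it through the trainer's loop config. A sketch, with the trainer construction details assumed rather than taken from the diff:

```python
# Sketch: overriding the flag via train_loop_config, as the test presumably does.
trainer = HorovodTrainer(
    train_loop_per_worker=train_func,
    train_loop_config={"save_model_as_dict": True},  # exercises the state-dict path
    scaling_config={"num_workers": 2},
)
trainer.fit()
```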
Ah, got it.
```
.. code-block:: python

    class Net(nn.Module):
```
Can we use a simpler example for the docstring? 🙂
I think we can just copy the one from TorchTrainer and add in the hvd.init(), hvd.DistributedOptimizer, etc. lines.
Yes, good point.
I updated it with a simple linear example instead.
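For reference, a sketch of what such a linear example might look like, i.e. a TorchTrainer-style training loop with the Horovod-specific lines added; the exact details of the docstring example are assumptions:

```python
# Sketch: simple linear regression with the Horovod lines added.
import horovod.torch as hvd
import torch
import torch.nn as nn

def train_loop_per_worker(config):
    hvd.init()  # Horovod-specific: set up the Horovod context
    model = nn.Linear(1, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    # Horovod-specific: wrap the optimizer and sync initial weights from rank 0
    optimizer = hvd.DistributedOptimizer(
        optimizer, named_parameters=model.named_parameters()
    )
    hvd.broadcast_parameters(model.state_dict(), root_rank=0)
    loss_fn = nn.MSELoss()
    x = torch.randn(32, 1)
    y = 2 * x + 1
    for _ in range(config.get("num_epochs", 3)):
        optimizer.zero_grad()
        loss = loss_fn(model(x), y)
        loss.backward()
        optimizer.step()
```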
LGTM! Just a few minor comments.
LGTM - I think we can do some extra work down the line to clean up the example, but it doesn't affect the implementation of the Trainer.
Why are these changes needed?
Related issue number
Checks
I've run scripts/format.sh to lint the changes in this PR.