[RAY AIR][DOC][TorchTrainer] Rewrote the TorchTrainer code snippet as a working example #30492
Conversation
…nd use testcode, and ignore long output from train Signed-off-by: Jules Damji <[email protected]>
@@ -22,13 +22,14 @@ class TorchTrainer(DataParallelTrainer):
The ``train_loop_per_worker`` function is expected to take in either 0 or 1
In the paragraph above this, it says "already" twice in the sentence -- it would be great to also fix this :)
Good catch. Fixed.
from typing import Dict
def train_loop_per_worker(config: Dict):
Ideally this would have a bit more typing, like Dict[str, Any] (not sure what exactly the format here is), and also link to the format of the dict if possible :)
Yeah, we can add some typing.
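A minimal sketch of what the more explicit annotation could look like; the config keys and defaults below are illustrative assumptions, not part of the snippet under review:

from typing import Any, Dict

def train_loop_per_worker(config: Dict[str, Any]):
    # `config` is the dict passed to TorchTrainer via `train_loop_config`.
    # The keys and defaults here are placeholders.
    lr = config.get("lr", 1e-3)
    num_epochs = config.get("num_epochs", 10)
    # ... training logic would go here ...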
@@ -45,32 +46,33 @@ def train_loop_per_worker(config: Dict):
Inside the ``train_loop_per_worker`` function, you can use any of the
Ideally there would also be an example for the above paragraph somewhere; we can do that in another PR.
(You can discard this; I saw the usage is already shown in the example below -- maybe add "(see example below)".)
def train_loop_per_worker():
    # Report intermediate results for callbacks or logging and
    # checkpoint data.
    #
I feel like it was better without this line but if you prefer feel free to keep it :)
Since the code is incomplete (session.report(...) and session.get_checkpoint()), it's nice to explain it with a comment.
session.report(...)

# Returns dict of last saved checkpoint.
# Session returns dict of last saved checkpoint.
Say "Get dict of last saved checkpoint." here (same below)? "Session returns" is a little confusing I think, since technically session is a Python module here and it doesn't return anything :)
Yes, "Get x" makes more sense than "returns", since it's an explicit method call to session.get_xxx.
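For context, a rough sketch of how the two calls fit together inside the training loop, with the comment phrased as "Get ..." per the suggestion above; the metric names and the toy checkpoint payload are placeholders, not the PR's actual example:

from ray.air import session
from ray.air.checkpoint import Checkpoint

def train_loop_per_worker():
    # Get dict of last saved checkpoint, if any (e.g. when resuming training).
    checkpoint = session.get_checkpoint()
    if checkpoint:
        state = checkpoint.to_dict()  # restore model/optimizer state from it

    for epoch in range(2):
        loss = 0.0  # placeholder for the real forward/backward pass
        # Report intermediate results for callbacks or logging, and
        # attach checkpoint data.
        session.report(
            {"loss": loss, "epoch": epoch},
            checkpoint=Checkpoint.from_dict({"epoch": epoch}),
        )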
        self.layer2 = nn.Linear(layer_size, output_size)

    def forward(self, input):
        return self.layer2(self.relu(self.layer1(input)))
I would either keep the ReLU layer here or have only one linear layer -- composing two linear layers doesn't do anything and it would likely be confusing to users :)
Keeping ReLU does not make sense. Why add non-linearity to a linear data relationship? With ReLU, the model does not converge; it oscillates like a seesaw. Having two linear layers is not uncommon. We can put in a comment that you can also use one layer if the relationship between your data and the outcome (target) is linear.
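To make the trade-off concrete, a small sketch of the two alternatives being discussed (layer sizes here are arbitrary assumptions): without an activation in between, two stacked linear layers still compute a purely linear map, so a single layer is enough when the data-to-target relationship is linear.

import torch.nn as nn

# Option 1: a single linear layer, sufficient for a linear relationship.
single_layer_model = nn.Linear(1, 1)

# Option 2: two linear layers with no activation in between; the composition
# is still a linear map overall, just parameterized differently.
class TwoLayerLinear(nn.Module):
    def __init__(self, input_size=1, layer_size=16, output_size=1):
        super().__init__()
        self.layer1 = nn.Linear(input_size, layer_size)
        self.layer2 = nn.Linear(layer_size, output_size)

    def forward(self, x):
        return self.layer2(self.layer1(x))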
# Report and record metrics, checkpoint model at end of each
# epoch
session.report({"loss": loss.item(), "epoch": epoch},
This is confusing since "epoch" is both here and below. @amogkam can you recommend how to do this? Most users will follow the example, so we should make sure we do this well :)
One is reporting the loss per epoch as metrics; the other is there for the checkpoint per epoch. It's nice to have those metrics per epoch. If @amogkam feels strongly that we should not include "epoch" in the metrics to report, then I can remove that entry.
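A hedged sketch of what this reply describes: the same epoch counter feeds both the reported metrics and the per-epoch checkpoint. The model, the loop bound, and the zero loss below are placeholders, not the PR's actual example:

import torch
import torch.nn as nn
from ray.air import session
from ray.train.torch import TorchCheckpoint

def train_loop_per_worker():
    model = nn.Linear(1, 1)
    for epoch in range(3):
        loss = torch.tensor(0.0)  # placeholder for the real training step
        session.report(
            # "epoch" here is a metric, so the per-epoch loss shows up in
            # result.metrics and in callbacks/loggers.
            {"loss": loss.item(), "epoch": epoch},
            # The checkpoint separately records the model state at this epoch;
            # it is what session.get_checkpoint() returns on a restore.
            checkpoint=TorchCheckpoint.from_state_dict(model.state_dict()),
        )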
result = trainer.fit()

# Get the loss metric from TorchCheckpoint tuple data dictionary
best_checkpoint_loss = result.metrics['loss']
# print(f"best loss: {best_checkpoint_loss:.4f}")
Should you remove the "#" here?
Yeah, it's a bit redundant since the code is self-explanatory.
train_loop_per_worker: The training function to execute.
    This can either take in no arguments or a ``config`` dict.
train_loop_config: Configurations to pass into
train_loop_config: Configurations to pass into
The indentation should be kept here, right? Otherwise it won't render correctly :)
The "<parameter_name>:" entries under "Args:" should be indented on the same level. That is:

Args:
    arg_1: ...
    arg_2: ...
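A minimal sketch of the Google-style layout being described, assuming the docstring in question looks roughly like this (the function name and summary line are placeholders):

def torch_trainer_args_example(train_loop_per_worker, train_loop_config=None):
    """Placeholder summary line.

    Args:
        train_loop_per_worker: The training function to execute.
            This can either take in no arguments or a ``config`` dict.
        train_loop_config: Configurations to pass into
            ``train_loop_per_worker`` if it accepts an argument.
    """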
Thanks for doing this, this is great! There are a few small comments you should address before merging :)
Signed-off-by: Jules Damji <[email protected]>
…ppress output, all tests pass, incorporated most feedback Signed-off-by: Jules Damji <[email protected]>
Signed-off-by: Philipp Moritz <[email protected]>
Looks like some of the Hugging Face servers are down, which is independent of this PR; we can merge it after the tests have run.
Signed-off-by: Jules Damji [email protected]
… a working example (ray-project#30492)

Signed-off-by: Jules Damji [email protected]

- Rewrote the code snippet as it was not working
- Removed python-code block directives; instead use testcode and testoutput. This will test the code when it runs in CI
- Ignore the output since we get loads of output from the three workers
- Assert that the loss converges with the training data within the specified epochs
- Tested code end-to-end

Signed-off-by: Weichen Xu <[email protected]>
Signed-off-by: tmynn <[email protected]>
Signed-off-by: Jules Damji [email protected]
Checks
- I've signed off every commit (by using the -s flag, i.e., git commit -s) in this PR.
- I've run scripts/format.sh to lint the changes in this PR.