[AIR] Added Ray Logging to MosaicTrainer #29620
Conversation
Because Ray's metrics dataframe will not include new keys that are reported after the very first report call, any logged information with keys not included in the first batch checkpoint would not be retrievable after training. In other words, if
Are users expected to know what these keys are upfront? Looking at the Mosaic code, it seems that these keys are automatically added by Mosaic algorithms and callbacks, so I don't think users are aware of what these keys are in order to provide them here.
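For context, a hypothetical illustration of what listing these keys would look like; the key names are taken from the tests in this PR, and their callback attribution is an assumption:

```python
# Hypothetical illustration: these keys are emitted automatically by Composer's
# monitoring callbacks (likely the LR and gradient monitors), so a user would
# have to know them up front in order to list them here.
log_keys = [
    "lr-DecoupledSGDW/group0",  # learning-rate key asserted in this PR's tests
    "grad_l2_norm/step",        # gradient-norm key asserted in this PR's tests
]
trainer_init_config = {"log_keys": log_keys}
```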
I think we should fix the underlying bug
assert "lr-DecoupledSGDW/group0" in metrics_columns | ||
assert "grad_l2_norm/step" in metrics_columns | ||
|
||
|
Can we make these newly added tests more robust?
- Does the number of rows in the dataframe match what we expect?
- We should add a dummy callback that reports to the logger, and then check that the values in the dataframe match what we expect.
- Are there any other edge cases you can think of?
Tests have been updated to check:
- the number of rows in the dataframe
- the value reported by a dummy callback (see the sketch below)
- whether null values exist for the reported composer monitoring callbacks
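A minimal sketch of such a dummy-callback test, assuming a `trainer_init_per_worker` fixture that forwards `callbacks` from the config to `composer.Trainer`; the names and fixture details are assumptions, not the exact test added in this PR:

```python
from composer.core import Callback, State
from composer.loggers import Logger
from ray.air.config import ScalingConfig
from ray.train.mosaic import MosaicTrainer


class DummyMetricCallback(Callback):
    """Logs a constant metric every epoch so the test can verify it round-trips."""

    def epoch_end(self, state: State, logger: Logger) -> None:
        logger.log_metrics({"dummy_metric": 1.0})


def test_dummy_callback_metric_is_reported(trainer_init_per_worker):
    # trainer_init_per_worker is an assumed fixture that builds a composer.Trainer
    # and passes "callbacks" from the config through to it.
    trainer = MosaicTrainer(
        trainer_init_per_worker=trainer_init_per_worker,
        trainer_init_config={
            "callbacks": [DummyMetricCallback()],
            "log_keys": ["dummy_metric"],
        },
        scaling_config=ScalingConfig(num_workers=2),
    )
    result = trainer.fit()
    df = result.metrics_dataframe

    # The dummy metric should appear in the dataframe with the value we logged.
    assert "dummy_metric" in df.columns
    assert (df["dummy_metric"].dropna() == 1.0).all()
```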
def epoch_checkpoint(self, state: State, logger: Logger) -> None:
    del logger  # unused
    session.report(self.data)
We shouldn't report at both the batch level and the epoch level. Each call to session.report should be one iteration, so if we log at both, we will be double counting.
For now, I would say let's just log only at every epoch. We can see in the future if we want to give users the ability to configure this.
I think we'll definitely have users who want to do it at either level - that was the case with HF, where we started with epochs only and had to add steps too.
Completely agree @Yard1. I’m thinking we can default to epoch for now and then add batch support in a follow-up.
This has been updated!
We are reporting every epoch now, but we also report after the fit call, in case training ends before an epoch checkpoint call could be made. This adds an extra report call in which an epoch checkpoint can be double counted, but we can also make it so that this last call is made only if there are extra batch runs after the last epoch run.
The change mentioned above has been applied (a sketch of the resulting reporting logic is below).
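A hedged sketch of that reporting logic, not the exact implementation in this PR; it assumes Composer's `LoggerDestination` callback event methods and Ray AIR's `session.report`:

```python
from composer.core import State
from composer.loggers import Logger, LoggerDestination
from ray.air import session


class RayLoggerSketch(LoggerDestination):
    """Buffers metrics from Composer and reports them to Ray once per epoch."""

    def __init__(self) -> None:
        self.data: dict = {}
        self._has_unreported_data = False

    def log_metrics(self, metrics: dict, step=None) -> None:
        # Composer algorithms and callbacks funnel their metrics through here.
        self.data.update(metrics)
        self._has_unreported_data = True

    def epoch_checkpoint(self, state: State, logger: Logger) -> None:
        del logger  # unused
        session.report(self.data)
        self._has_unreported_data = False

    def fit_end(self, state: State, logger: Logger) -> None:
        del logger  # unused
        # Report one last time only if batches logged new data after the last
        # epoch checkpoint, so the final epoch is not double counted.
        if self._has_unreported_data:
            session.report(self.data)
```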
Thanks @ilee300a! Left some comments on improving the UX and on the testing.
Can we update the test plan of the PR with a logged data sample and final charts?
)
test_dataset = torch.utils.data.Subset(
    datasets.CIFAR10(
        data_directory, train=False, download=True, transform=cifar10_transforms
    ),
    list(range(64)),
    list(range(2048)),
Why is the batch size for training a constant while this is an inline number?
Updated so that it is BATCH_SIZE * 10, just like the train dataset (see the sketch below).
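A minimal sketch of that sizing, assuming the surrounding test setup; the `BATCH_SIZE`, `data_directory`, and `cifar10_transforms` values here are placeholders, not the exact ones in the test:

```python
import torch
from torchvision import datasets, transforms

BATCH_SIZE = 64  # assumed value of the train-side constant
data_directory = "~/data"  # placeholder path
cifar10_transforms = transforms.ToTensor()  # simplified stand-in transform

# Size the test subset relative to BATCH_SIZE, mirroring the train subset,
# instead of an inline magic number.
test_dataset = torch.utils.data.Subset(
    datasets.CIFAR10(
        data_directory, train=False, download=True, transform=cifar10_transforms
    ),
    list(range(BATCH_SIZE * 10)),
)
```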
Thanks @ilee300a! lgtm overall, just left some minor comments
Thanks @ilee300a! Please ping again once tests are passing for merge.
Added RayLogger to MosaicTrainer to relay all reported information. `RayLogger` is a subclass of `LoggerDestination`, just like all other native composer loggers. The information to be logged is given via the `log_metrics` call and saved in the `RayLogger` object. The logger reports the logged information every batch checkpoint and epoch checkpoint. All other composer loggers besides `RayLogger` are removed from the trainer.

Note that because, at the moment, the result `metrics_dataframe` will only include the keys that are reported in the very first report call, the keys for metrics that are not reported every batch should be passed in via `log_keys` in the `trainer_init_config` so they appear in the final metrics dataframe.
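A hedged usage sketch of that behavior; `trainer_init_per_worker` and the scaling settings are placeholders, and the exact `MosaicTrainer` signature may differ slightly:

```python
from ray.air.config import ScalingConfig
from ray.train.mosaic import MosaicTrainer


def trainer_init_per_worker(config):
    # Placeholder: in real usage this builds and returns a composer.Trainer.
    ...


trainer = MosaicTrainer(
    trainer_init_per_worker=trainer_init_per_worker,
    trainer_init_config={
        # Keys logged by Composer algorithms/callbacks that do not appear in the
        # very first report call must be listed here to survive into the final
        # metrics dataframe.
        "log_keys": ["grad_l2_norm/step"],
    },
    scaling_config=ScalingConfig(num_workers=2),
)
result = trainer.fit()

# Only columns present in the first report call, plus any listed in log_keys,
# end up here.
print(result.metrics_dataframe.columns)
```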
Checks
- I've signed off every commit (by using the -s flag, i.e., `git commit -s`) in this PR.
- I've run `scripts/format.sh` to lint the changes in this PR.