
[AIR] Added Ray Logging to MosaicTrainer #29620

Merged: 82 commits into ray-project:master on Oct 27, 2022

Conversation

ilee300a (Contributor)

Added RayLogger to MosaicTrainer to relay all reported information.

RayLogger is a subclass of LoggerDestination, just like all other native Composer loggers. The information to be logged is passed in via the log_metrics call and saved in the RayLogger object. The logger reports the logged information at every batch checkpoint and epoch checkpoint. All native Composer loggers other than RayLogger are removed from the trainer.

Note that, at the moment, the result metrics_dataframe will only include the keys that are reported in the very first report call. To have metrics that are not reported every batch show up in the final metrics dataframe, the keys should be passed in via 'log_keys' in the trainer_init_config, for example:

Checks

  • I've signed off every commit (by using the -s flag, i.e., git commit -s) in this PR.
  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
  • I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(

ilee300a and others added 30 commits October 11, 2022 14:22
Signed-off-by: ilee300a <[email protected]>

Because ray's metric dataframe will not include new keys that are reported after the
very first report call, any logged information with keys not included in the
first batch checkpoint would not be retrievable after training. In other words, if
Contributor

Are users expected to know what these keys are upfront? Looking at the Mosaic code, it seems that these keys are automatically added by Mosaic algorithms and callbacks, so I don't think users are aware of what these keys are in order to provide them here.

Contributor

I think we should fix the underlying bug

Contributor Author

assert "lr-DecoupledSGDW/group0" in metrics_columns
assert "grad_l2_norm/step" in metrics_columns


Contributor
@amogkam amogkam Oct 26, 2022

Can we make these newly added tests more robust?

  • Does the number of rows in the dataframe match what we expect?
  • We should add a dummy callback that reports to the logger, and then check to make sure the values in the dataframe match what we expect.
  • Are there any other edge cases you can think of?

Contributor Author

Tests have been updated to check:

  • the number of rows in the dataframe
  • the value reported by a dummy callback (a sketch of this check is below)
  • whether null values exist for the reported Composer monitoring callbacks

python/ray/train/tests/test_mosaic_trainer.py (outdated, resolved)

def epoch_checkpoint(self, state: State, logger: Logger) -> None:
    del logger  # unused
    # Relay everything logged so far to Ray Train at each epoch checkpoint.
    session.report(self.data)
Contributor

We shouldn't report at both the batch level and the epoch level. Each call to session.report should be 1 iteration, so if we log at both, we will be double counting.

For now, I would say let's just log only at every epoch. We can see in the future if we want to give users the ability to configure this.

Member

I think we'll definitely have users want to do it on either level - that was the case with HF, where we started with epochs only and had to add steps too.

Contributor

Completely agree @Yard1. I’m thinking we can default to epoch for now and then add batch support in a follow up.

Contributor Author

This has been updated!
We are reporting each epoch now, but we also report after the fit call, in case training ends before an epoch checkpoint call could be made. This adds an extra report call, in which an epoch checkpoint can be double counted. We could also make it so that this last call is made only if there are extra batch runs after the last epoch checkpoint; a sketch of that is below.

Contributor Author
@ilee300a ilee300a Oct 27, 2022

The change mentioned above has been applied.

@amogkam
Contributor

amogkam commented Oct 26, 2022

Thanks @ilee300a! Left some comments on improving the UX and on the testing

Member
@jiaodong jiaodong left a comment

Can we update the test plan of the PR with a logged data sample and final charts?

)
test_dataset = torch.utils.data.Subset(
    datasets.CIFAR10(
        data_directory, train=False, download=True, transform=cifar10_transforms
    ),
    list(range(64)),
    list(range(2048)),
Member

Why is the batch size for training a constant while this is an inline number?

Contributor Author

Updated so that it is BATCH_SIZE * 10, just like the train dataset; see the snippet below.

Co-authored-by: Amog Kamsetty <[email protected]>
Signed-off-by: ilee300a <[email protected]>
Contributor
@amogkam amogkam left a comment

Thanks @ilee300a! LGTM overall, just left some minor comments.

python/ray/train/mosaic/_mosaic_utils.py (resolved)
python/ray/train/mosaic/_mosaic_utils.py (resolved)
@amogkam
Contributor

amogkam commented Oct 27, 2022

Thanks @ilee300a! Please ping again once tests are passing for merge.

@amogkam amogkam merged commit 28e84b8 into ray-project:master Oct 27, 2022
WeichenXu123 pushed a commit to WeichenXu123/ray that referenced this pull request Dec 19, 2022