[Train] Add support for metrics aggregation #22099
Conversation
@amogkam @matthewdeng gentle ping 😄
cc @Yard1
Can we implement this as a result preprocessor? Also, there would have to be some way for users to specify which keys to average, and how many samples per key each worker is processing. Some workers might process more batches than others, so we can't just take an even average across all the workers.
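The weighted-average concern above can be illustrated with a short sketch. The function and key names here (`weighted_average`, `_num_samples`) are purely illustrative, not the PR's actual API:

```python
# Illustrative sketch: a plain mean over per-worker metrics is wrong when
# workers process different numbers of samples; weight by sample count instead.
def weighted_average(results, key, weight_key="_num_samples"):
    """Average `key` across worker result dicts, weighted by sample count."""
    total_weight = sum(r[weight_key] for r in results)
    return sum(r[key] * r[weight_key] for r in results) / total_weight


results = [
    {"loss": 0.5, "_num_samples": 100},  # worker 0 saw 100 samples
    {"loss": 1.0, "_num_samples": 300},  # worker 1 saw 300 samples
]
# An unweighted mean would give 0.75; the sample-weighted mean is
# (0.5*100 + 1.0*300) / 400 = 0.875.
print(weighted_average(results, "loss"))
```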
@amogkam wouldn't implementing this as a result preprocessor require users to create their own callback subclasses? Furthermore, each callback would have to rerun the same code unless we implement some sort of caching for result preprocessors.
@Yard1 right, we would also have to provide a way for users to easily configure the preprocessors used by the callbacks.
@jwyyy I would recommend the below as the implementation:
Let me know if this makes sense!
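The recommended implementation was not preserved in this export, but as a rough illustration of the results-preprocessor approach under discussion, a minimal sketch might look like this. The class name and `preprocess` signature are simplified assumptions, not the PR's actual code:

```python
# Hypothetical sketch of an averaging results preprocessor: it receives the
# list of per-worker result dicts and appends the averaged value for each
# requested key to every dict.
from typing import Dict, List


class AverageResultsPreprocessor:
    """Appends an averaged value for each requested key to worker results."""

    def __init__(self, keys: List[str]):
        self.keys = keys

    def preprocess(self, results: List[Dict]) -> List[Dict]:
        for key in self.keys:
            avg = sum(r[key] for r in results) / len(results)
            for r in results:
                r[f"avg({key})"] = avg
        return results


results = [{"loss": 1.0}, {"loss": 3.0}]
processed = AverageResultsPreprocessor(["loss"]).preprocess(results)
# Each worker dict now carries avg(loss) == 2.0 alongside its own loss.
```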
Hi @amogkam @Yard1 @matthewdeng, I revised the PR based on the previous discussion. Average metrics are not appended to the […]. Please let me know your comments and feedback. Thank you very much! (The failed tests seem unrelated to this PR.)
Implementation makes sense to me for the averaging use-case, tagging @Yard1 for some input on the usability/extensibility aspect!
I updated the PR with support for a default list of aggregated metrics and customized aggregation methods. Looking forward to more discussion on where to save the aggregated metrics.
Hi @matthewdeng @Yard1, thank you for your suggestions! I have revised the implementation based on them. One modification: the average weight is obtained before calling […]. Currently, aggregated results are still added to the existing result dicts. Have you reached agreement on how to store them? I assume we want to implement a customized dict/list class (link), but the integration with callbacks wasn't completely sorted out. Can we add a new argument to callbacks and allow them to use aggregated metrics directly, just like […]? Please let me know your comments when you have time to review it again. Thanks!
This is looking great - thanks a ton for the many iterations!
Regarding where to store these, after thinking about it some more it seems like the current options worth discussing are:
- Add aggregate metrics to all workers.
- Add aggregate metrics to 0th worker.
- Change the entire pattern of passing reported metrics to use a more complex dictionary as opposed to the current list.
In the long run I do think we'll move towards (3), but it's out of scope for this PR. There are some additional considerations we'll need to think more about such as the general usability of Callbacks and integration with Tune (currently we only pass worker 0 results and don't support Train Callbacks, but we'll eventually need to support passing aggregate metrics to Tune).
For (1) vs (2), the tradeoff is verbosity (e.g. if the user wants to log all results to a JSON file) vs. configurability (e.g. if the user wants to log only worker 2 results + aggregate results to TensorBoard). My initial hunch is to go with (2) and see if users request this type of customizability. @jwyyy @Yard1 thoughts?
Also cc @amogkam since this is worth considering for the API redesign.
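Option (2) above can be sketched in a few lines. The helper name and the `avg(...)` key format are illustrative assumptions, not the PR's actual code:

```python
# Sketch of option (2): merge aggregate metrics into the 0th worker's result
# dict only, leaving the other workers' dicts untouched. This keeps per-worker
# logs terse while still exposing the aggregates to callbacks that read rank 0.
def attach_aggregates_to_rank_zero(results, aggregates):
    """Return a copy of `results` with `aggregates` merged into worker 0."""
    results = [dict(r) for r in results]  # avoid mutating the caller's dicts
    results[0].update(aggregates)
    return results


results = [{"loss": 1.0}, {"loss": 3.0}]
out = attach_aggregates_to_rank_zero(results, {"avg(loss)": 2.0})
# out[0] now has both "loss" and "avg(loss)"; out[1] only has "loss".
```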
@matthewdeng Thank you very much for your comments! I will address all issues by tomorrow.
I think in […]. Also, callbacks have […].
I think for now we should go with option (1), due to what @jwyyy has pointed out. That being said, it would be great if we could get a follow-up PR quickly to move to a more complex data structure, as that would be the best. I still like my list subclass idea 😂
If there is an agreement, I can follow up with another PR to implement the data structure idea 😄
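The list-subclass idea can be sketched as follows. The class name `TrainingResults` and the `aggregates` attribute are hypothetical, not anything from the PR:

```python
# Sketch of the "list subclass" idea: a list of per-worker result dicts that
# also carries aggregate metrics as an attribute. Existing callbacks that
# iterate over the results keep working unchanged, while new callbacks can
# opt in to reading the aggregates from the same object.
class TrainingResults(list):
    def __init__(self, worker_results, aggregates=None):
        super().__init__(worker_results)
        self.aggregates = aggregates or {}


results = TrainingResults([{"loss": 1.0}, {"loss": 3.0}], {"avg(loss)": 2.0})
# Existing callbacks still see a plain list of worker dicts...
assert results[0]["loss"] == 1.0
# ...while aggregate-aware callbacks read results.aggregates.
```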
Functionally this looks good to me! Could you add tests and docs for these new changes?
@amogkam can you take a pass?
Sure! I will add some tests and update the docs soon. Thanks a lot!
Thanks a lot for the work on this @jwyyy! Left some comments.
Can we also add tests for this in test_results_preprocessors?
One more point I want to clarify: what is the behavior if some workers report valid results for a key but some workers do not? Should we ignore the key in its entirety, or still do the aggregation, but only over the workers with valid values? Also cc @matthewdeng @Yard1 for thoughts on this.
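The second behavior (aggregate only over workers with valid values, and flag the discrepancy) might look like the sketch below. It uses Python's `warnings` module and an illustrative `average_valid` name; the PR's actual code may differ:

```python
import math
import warnings


def average_valid(results, key):
    """Average `key` over workers that report a present, numeric, non-NaN value.

    Workers missing the key (or reporting NaN / non-numeric values) are
    skipped, and a warning is emitted so the discrepancy is visible.
    """
    valid = [
        r[key]
        for r in results
        if key in r
        and isinstance(r[key], (int, float))
        and not math.isnan(r[key])
    ]
    if len(valid) < len(results):
        warnings.warn(
            f"{len(results) - len(valid)} worker(s) reported no valid value "
            f"for '{key}'; averaging over the remaining {len(valid)}."
        )
    return sum(valid) / len(valid) if valid else None


results = [{"loss": 1.0}, {"loss": float("nan")}, {}]
# Only worker 0 has a valid loss, so the average is 1.0 (with a warning).
```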
Thanks again for the updates @jwyyy! I think this should be our final round, left some minor comments!
Thanks @jwyyy! Just one minor nit, but other than that this lgtm!
[(AverageResultsPreprocessor, 2.0), (MaxResultsPreprocessor, 3.0)],
)
def test_warning_in_aggregate_results_preprocessors(
    caplog, results_preprocessor, expected_value
Nice, I never knew about the caplog fixture :)
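For readers who also haven't met it: `caplog` is pytest's built-in fixture for capturing log records emitted during a test. A minimal standalone example (the `process` function and message text are illustrative, not from this PR):

```python
# Minimal sketch of pytest's caplog fixture: run code that logs a warning,
# then assert on the captured log text.
import logging


def process(logger):
    # Stand-in for a preprocessor that warns about workers with missing values.
    logger.warning("1 worker(s) reported no value for 'loss'.")


def test_warning_is_logged(caplog):
    with caplog.at_level(logging.WARNING):
        process(logging.getLogger(__name__))
    assert "reported no value" in caplog.text
```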
Why are these changes needed?
This PR allows users to aggregate metrics returned from all workers.
Related issue number
Checks
- I've run scripts/format.sh to lint the changes in this PR.