feat: remove steps from pytorch callbacks [DET-3526] #831
Changes from 3 commits
@@ -16,28 +16,32 @@ class PyTorchCallback:
     .. warning::
        If distributed training is enabled, every GPU will execute a copy of this callback
-       (except for :meth:`on_validation_step_end` and :meth:`on_checkpoint_end`). To
+       (except for :meth:`on_validation_end` and :meth:`on_checkpoint_end`). To
        configure a callback implementation to execute on a subset of GPUs, please condition
        your implementation on ``trial.context.distributed.get_rank()``.
     """

-    def on_train_step_start(self, step_id: int) -> None:
+    def on_batch_start(self, batch_idx: int) -> None:
         """
-        Run before every training step begins.
+        Run before every batch is trained.
         """
         pass

-    def on_train_step_end(self, step_id: int, metrics: Dict[str, Any]) -> None:
+    def on_batch_end(self, batch_idx: int, metrics: Dict[str, Any]) -> None:
         """
-        Run after every training step ends.
+        Run after every batch is trained.
Review comment (blocking): you need to add a warning about metrics here. Additionally, we need to decide if we want to average metrics here every batch if …

Reply: I feel like …

Reply: I agree with that; the tricky part here is that we may not always have the metrics for the entire epoch (if the training was resumed mid-epoch). I would propose that we just don't provide averaged metrics in the training callbacks for now, and if/when the need for them arises, we can decide on the proper mechanism to do so.
         """
         pass

-        .. warning::
-           If distributed training is enabled, every GPU will execute a copy of
-           this callback at the end of every training step. If
-           ``optimizations.average_training_metrics`` is enabled, then the
-           ``metrics`` will be averaged across all GPUs before the callback is
-           executed. If ``optimizations.average_training_metrics`` is
-           disabled, then the ``metrics`` will be local to the GPU.
+    def on_epoch_start(self, epoch_idx: int) -> None:
+        """
+        Run before every epoch begins.
+        """
+        pass
+
+    def on_epoch_end(self, epoch_idx: int, metrics: Dict[str, Any]) -> None:
Review comment (blocking): need to be more descriptive about what these metrics are.
+        """
+        Run after every epoch ends.
Review comment (blocking): We should be more descriptive here about the timing. Often …

Reply: hadn't thought that would be the expected behavior. Thanks.
+        """
+        pass
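For orientation, here is a minimal sketch of what a user-defined callback could look like against the renamed hooks. The class name, the logging bodies, and the `determined.pytorch.PyTorchCallback` import path are assumptions for illustration, not part of this diff; how such a callback gets registered with a trial is also outside the scope of this change.

```python
from typing import Any, Dict

from determined.pytorch import PyTorchCallback  # assumed import path


class MetricsLoggingCallback(PyTorchCallback):
    """Illustrative user callback built on the renamed per-batch and per-epoch hooks."""

    def on_batch_end(self, batch_idx: int, metrics: Dict[str, Any]) -> None:
        # Per the review thread above, these metrics may be local to each GPU
        # unless optimizations.average_training_metrics is enabled.
        if batch_idx % 100 == 0:
            print(f"batch {batch_idx}: {metrics}")

    def on_epoch_end(self, epoch_idx: int, metrics: Dict[str, Any]) -> None:
        # What exactly these metrics cover is still under discussion in this PR.
        print(f"epoch {epoch_idx} done: {metrics}")
```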
@@ -50,15 +54,15 @@ def on_before_optimizer_step(self, parameters: Iterator) -> None:
         # TODO(DET-3267): deprecate this when releasing pytorch flexible primitives.
         pass

-    def on_validation_step_start(self) -> None:
+    def on_validation_start(self) -> None:
         """
-        Run before every validation step begins.
+        Run before every validation begins.
         """
         pass

-    def on_validation_step_end(self, metrics: Dict[str, Any]) -> None:
+    def on_validation_end(self, metrics: Dict[str, Any]) -> None:
         """
-        Run after every validation step ends.
+        Run after every validation ends.

         .. warning::
            This callback only executes on the chief GPU when doing distributed training.
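The class-level warning earlier in the diff says that most hooks run on every GPU, while `on_validation_end` runs only on the chief. A hedged sketch of conditioning on `trial.context.distributed.get_rank()`, as that warning suggests, might look like the following; how the context object reaches the callback's constructor is an assumption, not something this PR specifies.

```python
from typing import Any, Dict

from determined.pytorch import PyTorchCallback  # assumed import path


class ChiefOnlyLogger(PyTorchCallback):
    """Hypothetical callback that restricts side effects to the chief GPU."""

    def __init__(self, context) -> None:
        # `context` is assumed to be the trial context referenced by the class
        # docstring (``trial.context``); the trial code building the callback
        # would pass it in.
        self.context = context

    def on_batch_end(self, batch_idx: int, metrics: Dict[str, Any]) -> None:
        # Every GPU executes this hook, so gate non-idempotent work on rank 0.
        if self.context.distributed.get_rank() == 0:
            print(f"[chief] batch {batch_idx}: {metrics}")

    def on_validation_end(self, metrics: Dict[str, Any]) -> None:
        # Per the warning above, this hook already runs only on the chief GPU,
        # so no explicit rank check is needed here.
        print(f"validation: {metrics}")
```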
Review comment: This is really different behavior than what we had before, and it doesn't seem like it would be terribly useful to me, since we already call train_batch on every batch. I'm not sure what use cases we were trying to solve before, though. Are those use cases still valid? Git blame says @yoavz wrote these hooks; maybe @shiyuann or @aaron276h know the answer. And if those use cases are still valid, how would we address them after removing steps from the UX?

Reply: This was done so that users can make adjustments to the optimizer and model before training. I don't remember the exact use cases. I do think it makes sense to leave this callback in.

Reply: I chose these changes because they are exactly what is in the ERD; I assumed they were discussed/decided already. I understand that on_batch_start / on_batch_end seem pretty not useful, though. I think most of the context on the original decision is somewhere in the #ml-ag Slack channel.
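To make the "adjust the optimizer before training" use case mentioned above concrete, here is a purely hypothetical sketch that uses `on_batch_start` for learning-rate warmup. The constructor arguments and the way the optimizer handle is wired in are assumptions for illustration; nothing here is prescribed by this PR.

```python
from determined.pytorch import PyTorchCallback  # assumed import path


class LinearWarmup(PyTorchCallback):
    """Hypothetical 'adjust the optimizer before training' callback."""

    def __init__(self, optimizer, base_lr: float, warmup_batches: int = 500) -> None:
        # The optimizer handle is assumed to be a torch.optim.Optimizer passed
        # in by whatever trial code constructs this callback.
        self.optimizer = optimizer
        self.base_lr = base_lr
        self.warmup_batches = warmup_batches

    def on_batch_start(self, batch_idx: int) -> None:
        # Linearly ramp the learning rate over the first `warmup_batches` batches.
        if batch_idx < self.warmup_batches:
            scale = float(batch_idx + 1) / self.warmup_batches
            for group in self.optimizer.param_groups:
                group["lr"] = self.base_lr * scale
```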