feat: remove steps from pytorch callbacks [DET-3526] #831
Conversation
    configure a callback implementation to execute on a subset of GPUs, please condition
    your implementation on ``trial.context.distributed.get_rank()``.
    """

-   def on_train_step_start(self, step_id: int) -> None:
+   def on_batch_start(self, batch_idx: int) -> None:
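For context, here is a minimal sketch (not from this PR) of the rank-conditioning pattern the docstring describes; the class name, constructor wiring, and the assumption that the callback is handed the trial's context object are all illustrative:

from determined.pytorch import PyTorchCallback

class RankZeroLogger(PyTorchCallback):
    """Hypothetical callback that only acts on the chief GPU (rank 0)."""

    def __init__(self, context) -> None:
        self.context = context

    def on_batch_start(self, batch_idx: int) -> None:
        # Do nothing on every GPU except rank 0, per the docstring's guidance.
        if self.context.distributed.get_rank() != 0:
            return
        print(f"rank 0: starting batch {batch_idx}")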
This is really different behavior than what we had before, and it doesn't seem like it would be terribly useful to me, since we already call train_batch on every batch.
I'm not sure what use cases we were trying to solve before though? Are those use cases still valid?
Git blame says @yoavz wrote these hooks, maybe @shiyuann or @aaron276h know the answer though?
And if those use cases are still valid, how would we address them after removing steps from the UX?
This was done so that users can make adjustments to the optimizer and model before training. Don't remember the exact use cases. I do think it makes sense to leave this callback in.
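One plausible use of on_batch_start along those lines is a per-batch learning-rate warmup. The sketch below is illustrative only; the optimizer handle and warmup length are assumptions, not from this PR:

from determined.pytorch import PyTorchCallback

class WarmupCallback(PyTorchCallback):
    """Hypothetical callback that linearly warms up the learning rate."""

    def __init__(self, optimizer, base_lr: float = 1e-3, warmup_batches: int = 100) -> None:
        self.optimizer = optimizer
        self.base_lr = base_lr
        self.warmup_batches = warmup_batches

    def on_batch_start(self, batch_idx: int) -> None:
        # Ramp the learning rate from 0 to base_lr over the first warmup_batches.
        if batch_idx < self.warmup_batches:
            scale = (batch_idx + 1) / self.warmup_batches
            for group in self.optimizer.param_groups:
                group["lr"] = self.base_lr * scale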
I chose these changes because they're exactly what's in the ERD; I assumed they were discussed/decided already. I understand on_batch_start/on_batch_end don't seem very useful, though. I think most of the context on the original decision is somewhere in #ml-ag Slack.
+   def on_epoch_end(self, epoch_idx: int, metrics: Dict[str, Any]) -> None:
+       """
+       Run after every epoch ends.
blocking: We should be more descriptive here about the timing. Often epoch_end is considered to mean the end of training and evaluating on a full dataset, which is not the case for us. It might even be worth renaming this callback to on_train_epoch_end().
Hadn't thought that would be the expected behavior. Thanks.
""" | ||
Run after every training step ends. | ||
Run after every batch is trained. |
blocking: you need to add a warning about metrics here. Additionally, we need to decide if we want to average metrics here every batch if optimizations.average_training_metrics is enabled and if on_batch_end is used.
I feel like for on_batch_end we shouldn't, but for on_train_epoch_end we should?
I agree with that, the tricky part here is we may not always have the metrics for the entire epoch (if the training was resumed mid-epoch). I would propose that we just don't provide averaged metrics in the training callbacks for now, and if/when the need for them arises, we can decide on the proper mechanism to do so.
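For a sense of the cost being weighed here: averaging a training metric across GPUs on every batch adds a synchronization step per batch. A rough hand-rolled sketch with torch.distributed (an assumption about the mechanism, not Determined's actual implementation):

import torch
import torch.distributed as dist

def average_metric(value: float) -> float:
    # Assumes the default process group is already initialized.
    # The all_reduce is a per-call synchronization point across all workers.
    t = torch.tensor(value, dtype=torch.float32)
    dist.all_reduce(t, op=dist.ReduceOp.SUM)
    return (t / dist.get_world_size()).item()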
""" | ||
pass | ||
|
||
def on_epoch_end(self, epoch_idx: int, metrics: Dict[str, Any]) -> None: |
blocking: need to be more descriptive about what these metrics are.
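Something along these lines would address that; the docstring wording is an illustrative assumption (following the proposal above that training callbacks get raw, unaveraged metrics), not text from the PR:

from typing import Any, Dict

def on_epoch_end(self, epoch_idx: int, metrics: Dict[str, Any]) -> None:
    """
    Run after the training epoch with index epoch_idx ends.

    metrics contains the training metrics reported by train_batch during
    this epoch; they are raw per-batch values, not averaged across GPUs
    or over the epoch.
    """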
@shiyuann after our discussion, it seems like there are no concerns with breaking these interfaces. The only question is, do you want to try to break it in sync with your pytorch changes?
@stoksc This break in Pytorch callbacks should be orthogonal to my change.
@shiyuann Yeah, @aaron276h was just thinking maybe if we're making breaking changes, we should ship them together.
Doesn't really make sense to me to have
Looks even better the second time around
""" | ||
# TODO(DET-3267): deprecate this when releasing pytorch flexible primitives. | ||
pass | ||
|
||
def on_validation_step_start(self) -> None: |
non-blocking: might be worth adding a comment that this should be removed in the future
Looks good!
Description
As part of #remove-steps, we also need to remove the concept of steps from the PytorchCallbacks API.

Test Plan
Commentary (optional)
Currently, this looks to be the only breaking change of #remove-steps. Everything else has been done by allowing the new and old interfaces to coexist and deprecating the old ones. Open to suggestions on when/how to land this, or on leaving the old interfaces intact and marking them as deprecated for at least a while.
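For illustration, a sketch of the coexist-and-deprecate option mentioned above, where the old step-based hook stays on the base class but warns; the shim and warning text are assumptions (the PR itself removes the old hooks outright), and during a transition the harness would presumably invoke both hooks so existing overrides still fire:

import warnings

class PyTorchCallback:
    def on_batch_start(self, batch_idx: int) -> None:
        """New hook: runs before each batch is trained."""
        pass

    def on_train_step_start(self, step_id: int) -> None:
        """Old step-based hook, kept temporarily so existing trials keep running."""
        warnings.warn(
            "on_train_step_start is deprecated and will be removed; "
            "use on_batch_start instead.",
            FutureWarning,
        )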