Checkpoint tfds data iterator #954
base: main
Conversation
Decided to add a flag because checkpointing the iterator creates large checkpoints and may not be preferred by all users.
Thanks for using and contributing to our repo!
I would like to understand the use case for this feature. You have probably checked out our docs and know about the Grain pipeline, which is optimized for checkpointing data iterators. It is efficient and saves a very small data iterator checkpoint containing only indices, so it is the recommended option for use cases that need to checkpoint the data iterator. We would like to hear whether there is any difficulty in adopting Grain for your use case.
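For reference, the Grain path mentioned above checkpoints only integer indices, so the saved files stay tiny regardless of example size. A rough sketch of that usage, assuming `grain.python` exposes `PyGrainCheckpointHandler` as in the checkpoint_handlers module linked later in this PR (the function and paths here are illustrative, not MaxText code):

```python
import grain.python as grain
import orbax.checkpoint as ocp


def save_and_restore_grain_iterator(data_iterator, ckpt_dir: str):
  """`data_iterator` is assumed to be a PyGrain dataset iterator (e.g. from a grain DataLoader)."""
  ckptr = ocp.Checkpointer(grain.PyGrainCheckpointHandler())
  ckptr.save(ckpt_dir, data_iterator)                 # writes only indices, so the files are tiny
  return ckptr.restore(ckpt_dir, item=data_iterator)  # restores the position in the input stream
```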
I also tested the branch on a v4-8 with this command:
python3 MaxText/train.py MaxText/configs/base.yml steps=20 per_device_batch_size=8.0 learning_rate=3e-4 enable_checkpointing=true base_output_directory=gs://aireenmei-multipod/tfds_ckpt dataset_path=gs://maxtext_dataset tfds_iter_checkpointing=True run_name=$(date +%m%d-$H%M) checkpoint_period=10
and got this error (there is no error with tfds_iter_checkpointing=False): https://gist.github.com/aireenmei/42a4c4e0dd8caed0b7ce8182f5ca8292
```diff
@@ -43,6 +43,8 @@ async_checkpointing: True
 checkpoint_period: 10_000
 # enables one replica to read the ckpt then broadcast to the rest
 enable_single_replica_ckpt_restoring: False
+# enable checkpointing of tfds data iterator, for fully deterministic training. saves large checkpoints.
+tfds_iter_checkpointing: True
```
Since it saves large checkpoints, it would be better to set the default to False.
```
tfds_iter_checkpointing: True
```
Note that determinism with preemption requires checkpointing the data iterator, and the checkpoints will be larger in size.
Could you add more details on what contributes to the size of the checkpoint? It would also be good to provide a way to estimate the size, so users can plan ahead.
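For context: the bulk of a tf.data iterator checkpoint usually comes from serialized buffer state, most notably the shuffle buffer and any prefetched elements, so a rough estimate is the shuffle buffer size times the average serialized example size. A minimal sketch for measuring an existing iterator checkpoint directory (the path layout in the usage comment is hypothetical):

```python
import os
import tensorflow as tf


def dir_size_gb(path: str) -> float:
  """Sum the sizes of all files under `path`; works for gs:// paths via tf.io.gfile."""
  total = 0
  for dirpath, _, filenames in tf.io.gfile.walk(path):
    for name in filenames:
      total += tf.io.gfile.stat(os.path.join(dirpath, name)).length
  return total / 1e9

# e.g. dir_size_gb("gs://my-bucket/my_run/checkpoints/10/data_iter")  # hypothetical layout
```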
Enables deterministic training across preemptions when using the tfds input pipeline by checkpointing the data iterator.
Creates a checkpoint handler for the data iterator that implements orbax.checkpoint.CheckpointHandler, similar to https://github.com/google/grain/blob/main/grain/_src/python/checkpoint_handlers.py. The handler uses tf.train.Checkpoint to save and restore the iterator.
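A minimal sketch of what such a handler could look like (class and file names here are hypothetical; the PR's actual implementation may differ, e.g. in how it handles multi-host sharding and async I/O):

```python
import orbax.checkpoint as ocp
import tensorflow as tf
from etils import epath


class TfdsIteratorCheckpointHandler(ocp.CheckpointHandler):
  """Sketch: persists a tf.data iterator via tf.train.Checkpoint."""

  def save(self, directory: epath.Path, item):
    # `item` is the live tf.data iterator. tf.data iterators are trackable, so
    # tf.train.Checkpoint can serialize their internal state, including shuffle
    # and prefetch buffers (which is what makes these checkpoints large).
    tf.train.Checkpoint(iterator=item).write(str(directory / 'tfds_iterator'))

  def restore(self, directory: epath.Path, item):
    # `item` must be a freshly built iterator from the same input pipeline;
    # its state is overwritten with the saved state.
    tf.train.Checkpoint(iterator=item).read(
        str(directory / 'tfds_iterator')).assert_consumed()
    return item
```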
Makes checkpointing the data iterator optional, since this method saves large checkpoints. Adds a boolean flag to base.yml.
Async checkpointing is handled at the level of the orbax checkpoint manager.
Updates the input pipeline description to reflect the option to checkpoint the tfds iterator.
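For reference, this is roughly how an optional iterator item could be wired into an orbax CheckpointManager alongside the existing async train-state checkpointer; names, item keys, and intervals are illustrative, not the PR's exact code, and `TfdsIteratorCheckpointHandler` refers to the sketch above:

```python
import orbax.checkpoint as ocp


def create_checkpoint_manager(ckpt_dir: str, tfds_iter_checkpointing: bool):
  # The train state keeps using the async checkpointer; the iterator item is
  # only registered when the new flag is set.
  checkpointers = {
      'items': ocp.AsyncCheckpointer(ocp.PyTreeCheckpointHandler()),
  }
  if tfds_iter_checkpointing:
    checkpointers['data_iter'] = ocp.Checkpointer(TfdsIteratorCheckpointHandler())
  return ocp.CheckpointManager(
      ckpt_dir,
      checkpointers=checkpointers,
      options=ocp.CheckpointManagerOptions(save_interval_steps=10_000),
  )

# Saving then passes both items for the same step:
# manager.save(step, {'items': train_state, 'data_iter': data_iterator})
```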