Checkpoint tfds data iterator #954

Open · wants to merge 7 commits into main
Conversation

mattdonati commented:
Enables deterministic training across preemptions when using the tfds pipeline by checkpointing the data iterator.

Creates a checkpoint handler for the data iterator that implements orbax.checkpoint.CheckpointHandler, similar to https://github.com/google/grain/blob/main/grain/_src/python/checkpoint_handlers.py. The handler uses tf.train.Checkpoint to save and restore the iterator.
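
For illustration, a minimal sketch of what such a handler could look like (the class name `TfdsIteratorCheckpointHandler` and the exact save/restore signatures are assumptions — orbax's CheckpointHandler interface has changed across versions — not the code in this PR):

```python
import orbax.checkpoint as ocp
import tensorflow as tf


class TfdsIteratorCheckpointHandler(ocp.CheckpointHandler):
  """Hypothetical handler: persists a tf.data iterator via tf.train.Checkpoint."""

  def save(self, directory, item):
    # Wrap the live iterator and serialize its state (including any
    # buffered elements) under the given checkpoint directory.
    tf.train.Checkpoint(iterator=item).write(str(directory / "tfds_iterator"))

  def restore(self, directory, item):
    # Restore state into the caller's iterator in place and return it.
    tf.train.Checkpoint(iterator=item).read(str(directory / "tfds_iterator"))
    return item
```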

Makes checkpointing the data iterator optional, since this method saves large checkpoints. Adds a boolean flag to base.yml.

Async checkpointing is handled at the level of the orbax checkpoint manager.
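
A rough sketch of that wiring, reusing the hypothetical handler above (the item names, the GCS path, and the dict-of-checkpointers CheckpointManager signature are assumptions that vary across orbax versions):

```python
import orbax.checkpoint as ocp

# The train state uses an async checkpointer; the tf.train.Checkpoint-based
# iterator handler is synchronous, so it gets a plain Checkpointer.
checkpointers = {
    "items": ocp.AsyncCheckpointer(ocp.PyTreeCheckpointHandler()),
    "iter": ocp.Checkpointer(TfdsIteratorCheckpointHandler()),
}
manager = ocp.CheckpointManager(
    "gs://base_output_directory/run_name/checkpoints", checkpointers
)
# manager.save(step, {"items": train_state, "iter": data_iterator})
```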

Updates the input pipeline description to reflect the option to checkpoint the tfds iterator.

@aireenmei (Collaborator) left a comment:
Thanks for using and contributing to our repo!
I would like to understand the use case for this feature. You may have seen in our docs that the Grain pipeline is optimized for checkpointing data iterators: it is efficient and saves a very small iterator checkpoint containing only indices, so it is the recommended option for use cases that need data-iterator checkpointing. Is there any difficulty in adopting Grain for your use case?

And I tested the branch on a v4-8 with this command: `python3 MaxText/train.py MaxText/configs/base.yml steps=20 per_device_batch_size=8.0 learning_rate=3e-4 enable_checkpointing=true base_output_directory=gs://aireenmei-multipod/tfds_ckpt dataset_path=gs://maxtext_dataset tfds_iter_checkpointing=True run_name=$(date +%m%d-$H%M) checkpoint_period=10` and got this error (there is no error if tfds_iter_checkpointing=False): https://gist.github.com/aireenmei/42a4c4e0dd8caed0b7ce8182f5ca8292
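
For reference, a rough sketch of the Grain-based setup recommended above (assuming grain's exported `PyGrainCheckpointHandler`; the path and iterator name are placeholders, not code from this PR):

```python
import grain.python as grain
import orbax.checkpoint as ocp

# Grain iterators serialize to a small per-process state (essentially
# indices), so the iterator checkpoint stays tiny regardless of data size.
iter_checkpointer = ocp.Checkpointer(grain.PyGrainCheckpointHandler())
# iter_checkpointer.save("gs://my-bucket/ckpts/0/iter", my_grain_iterator)
```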

@@ -43,6 +43,8 @@ async_checkpointing: True
checkpoint_period: 10_000
# enables one replica to read the ckpt then broadcast to the rest
enable_single_replica_ckpt_restoring: False
# enable checkpointing of tfds data iterator, for fully deterministic training. saves large checkpoints.
tfds_iter_checkpointing: True
Collaborator commented:

Since it saves large checkpoints, it would be better to set the default to False:

```
tfds_iter_checkpointing: True
```
Note that determinism with preemption requires checkpointing the data iterator, and the checkpoints will be larger in size.
Collaborator commented:

Could you add more details on what contributes to the size of the checkpoint? It would be good to provide a way to estimate the size, to help users plan.
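
One way such an estimate could be framed (a back-of-the-envelope sketch; the assumption that the shuffle buffer dominates, and all the numbers, are illustrative):

```python
# The iterator checkpoint serializes every element currently buffered by
# tf.data, so the shuffle buffer usually dominates its size.
shuffle_buffer_size = 10_000        # elements held in the shuffle buffer
bytes_per_example = 2048 * 4        # e.g. 2048 int32 tokens per example
prefetch_elements = 8               # small additional prefetch buffer

est_bytes = (shuffle_buffer_size + prefetch_elements) * bytes_per_example
print(f"~{est_bytes / 1e6:.0f} MB per host")  # ≈ 82 MB per host
```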

Labels: none yet · Projects: none yet · 2 participants