Sync with torchtitan #2

Closed
wants to merge 263 commits into from
Commits (263)
3d1e9ea
update readme (#74)
wanchaol Feb 24, 2024
98a0f79
move config folder to root and adjust options (#83)
wanchaol Feb 24, 2024
629652b
add iter time tracking via cuda events, add data loading times, add c…
lessw2020 Feb 26, 2024
c866a64
Fill missing options in toml file with argparse defaults (#91)
gnadathur Feb 26, 2024
78a1643
support infinite loop over alpaca dataset
tianyu-l Feb 26, 2024
6d9e4e6
Add color to console output if local logging, auto avoid color loggin…
lessw2020 Feb 27, 2024
e987ac3
update GPU metrics logging to GiB (gibibytes) (#95)
lessw2020 Feb 27, 2024
62ff09d
improve TensorBoard instructions in README
tianyu-l Feb 27, 2024
60f6b0d
Enable libUV for torchtrain (#98)
gnadathur Feb 28, 2024
7acab70
use warmup steps for lr scheduler, ban steps == -1 (#99)
wanchaol Feb 29, 2024
d5c27a9
Add llama 7B config (#100)
wanchaol Feb 29, 2024
2c8cec2
add selective activation checkpointing
tianyu-l Feb 29, 2024
452baee
Add job description field in toml (#101)
gnadathur Mar 1, 2024
eb3fdd0
fix 2D parallel crash caused by all-reduce on 2D world_mesh
tianyu-l Mar 2, 2024
2682144
Load missing keys default from argparse (#111)
gnadathur Mar 5, 2024
afbf62a
Add meta_init, enable it as default init process (#84)
lessw2020 Mar 5, 2024
f91f97a
Fix feedback from PR 111 (#113)
gnadathur Mar 5, 2024
1a180ee
fix SP minor issues
tianyu-l Mar 5, 2024
ed04380
enable loss parallel in SP
tianyu-l Mar 6, 2024
41f5172
Float8_experimental option for training (#102)
drisspg Mar 6, 2024
680f1aa
add miniPile dataset for pretraining, 1M entries (solves the 'out of …
lessw2020 Mar 7, 2024
85263f7
add data loading option to load from local file system
tianyu-l Mar 7, 2024
3c51744
add llama 13B configs
wanchaol Mar 9, 2024
649cf0b
add llama 70B toml
wanchaol Mar 9, 2024
ab05f66
set betas and weight decay for optimizers
wanchaol Mar 9, 2024
66c196b
Add c4 dataset (177M, streaming), update multi-node support for lates…
lessw2020 Mar 9, 2024
10229d6
Add openwebtext dataset for larger scale training without shuffling (…
lessw2020 Mar 12, 2024
7fee3cf
[TorchTrain][Checkpoint] Fix TrainState state_dict to unblock loading…
wz337 Mar 12, 2024
7cd2725
improve logging
tianyu-l Mar 13, 2024
3161ffb
use SequenceParallel style in tp/sp (#133)
wanchaol Mar 13, 2024
e39ee7e
support TP-only parallelism
tianyu-l Mar 13, 2024
5d18bf0
disable verbose print from profiling
tianyu-l Mar 13, 2024
0d415d7
add Selective layer activation checkpointing, single control for tur…
lessw2020 Mar 14, 2024
cc2061a
remove per iter synchronize
tianyu-l Mar 14, 2024
3b3362b
Shorten nccl comm timeout and enable flight recorder dumping (#103)
wconstab Mar 15, 2024
9f5a56d
fix up gpu memory monitoring and logging
tianyu-l Mar 15, 2024
9eb6a21
Separate timeout during init and training (#149)
wconstab Mar 15, 2024
6485be9
Update activation check with updates to config manager (#152)
drisspg Mar 20, 2024
fd4c75b
Refactor to clean up parallelisms/__init__.py
wconstab Mar 20, 2024
93c2b7d
enable gc control scheduling to help avoid stragglers (#148)
lessw2020 Mar 20, 2024
9e7920f
Add float8 specific parallel strategies (#153)
drisspg Mar 20, 2024
e5d1b89
add MFU to metrics
tianyu-l Mar 20, 2024
ceebd53
disable buffer reuse for compile for now (#156)
wanchaol Mar 21, 2024
32aa083
refactor config manager and support cmd overrides (#157)
wanchaol Mar 22, 2024
a21645e
Add support for generating debug traces on failure
chauhang Mar 24, 2024
e28832e
rename sequence_parallel to tensor_parallel (#162)
wanchaol Mar 25, 2024
6722657
add basic AC configs for 13B and 70B (#169)
wanchaol Mar 27, 2024
c49cc9e
[TorchTrain][Checkpoint] Update train state to include global_avg_los…
wz337 Mar 27, 2024
2b017fd
Basic integration test infra (#170)
gnadathur Mar 27, 2024
ab5d918
Add 2D integration test (FSDP + TP) (#171)
gnadathur Mar 27, 2024
83c879f
Used per-parameter FSDP (#165)
awgu Mar 28, 2024
f6d9de7
plot losses in loaded TrainState to TensorBoard
tianyu-l Mar 28, 2024
1150944
Removed setting global flag for `swap_tensors` since not needed anymore
awgu Mar 29, 2024
25ee32f
Add integration test with compile enabled (#183)
gnadathur Apr 2, 2024
25f9bff
remove folding and unfolding of sequence dim in model.py
tianyu-l Apr 3, 2024
c233ecd
bump comm.train_timeout_seconds (#189)
wanchaol Apr 4, 2024
bb3919d
fix checkpoint parser
wz337 Apr 5, 2024
4d593d4
support sequence of tests and add checkpoint test
wz337 Apr 5, 2024
5a0995a
Make freqs_cis a persistent buffer for pp init
wconstab Apr 5, 2024
db204f9
Delete grad scaler, which is unsupported/unused
wconstab Apr 5, 2024
859963d
Factor out loss_fn to share code with pipeline par
wconstab Apr 5, 2024
5d2c148
[TorchTrain] Minor fix for #197 (#204)
wz337 Apr 5, 2024
3471165
Add FusedRMSNorm (Triton kernel, +15% eager), Add NPLayerNorm, Enable…
lessw2020 Apr 5, 2024
5b2bb52
remove .item() per iter
tianyu-l Apr 5, 2024
7146841
Removed cache_k and cache_v comments
awgu Apr 10, 2024
c18d760
Some more cleanups
awgu Apr 10, 2024
e62573d
avoid record streams and make color printing a config
tianyu-l Apr 10, 2024
7419d71
fix SAC to use the correct reduce_scatter op (#215)
wanchaol Apr 10, 2024
cfdd4af
Test runner raises exception on failures (#216)
gnadathur Apr 10, 2024
144b229
Revert "Separate TransformerEmbedding layer (#33)"
wconstab Apr 10, 2024
05c181d
Fix 2DParallel test (#219)
gnadathur Apr 10, 2024
b6414aa
Added initial FSDP readme
awgu Apr 10, 2024
07a3ec8
[TorchTrain][Checkpoint] Add model_weights_only option to train_confi…
wz337 Apr 11, 2024
c22d1a8
Rename to torchtitan (#221)
wanchaol Apr 11, 2024
55a0187
[TorchTitan] Add destroy process group at the end of training (#223)
wz337 Apr 12, 2024
2373509
Add 1 sec delay to rank 0 cleanup (#224)
gnadathur Apr 12, 2024
fd5ad5a
[Torchtrain][Checkpoint] Add support to allow dtype conversion (#222)
wz337 Apr 12, 2024
009b14f
[TorchTitan] Remove checkpoint folder at the end in test_runner.py (#…
wz337 Apr 12, 2024
c7d5865
codebase cleanup
tianyu-l Apr 15, 2024
f86bfb2
Update README to reflect positioning (#229)
wanchaol Apr 16, 2024
a10262a
First release readme (#227)
lessw2020 Apr 16, 2024
a0a7ff7
Update licenses and headers (#231)
wanchaol Apr 16, 2024
d8b7c7f
use permalink for logo image (#232)
lessw2020 Apr 16, 2024
6596219
[TorchTitan][Checkpoint] Move checkpoint folder under dump_folder and…
wz337 Apr 16, 2024
1601d35
use combo of html and local file src for logo (#234)
lessw2020 Apr 16, 2024
63d752b
add performance -- infra metrics and loss curves (#237) (#238)
lessw2020 Apr 16, 2024
10b572d
add license section in readme (#239)
wanchaol Apr 16, 2024
7781fd7
[TorchTitan][Checkpoint] Add a step-by-step instruction for checkpoin…
wz337 Apr 16, 2024
441b33f
more license headers (#240)
wanchaol Apr 16, 2024
53dc5eb
Update README (#242)
wanchaol Apr 16, 2024
16701c3
Add torchtune checkpoint link, modify product position statement loca…
lessw2020 Apr 16, 2024
b889f3d
Add pyproject and upgrade version (#236)
wanchaol Apr 16, 2024
b60c6bd
minor doc updates - remove asynch checkpt ref, grammar on prod positi…
lessw2020 Apr 16, 2024
09d0047
Fix multi-line string usage (#244)
gnadathur Apr 16, 2024
c9454d3
polish toml files
tianyu-l Apr 16, 2024
9537825
[torchtitan][checkpoint][doc] Minor fix checkpoint doc (#246)
wz337 Apr 16, 2024
7af51cf
fix default max_seq_len for freq_cis init (#248)
wanchaol Apr 17, 2024
0c655b8
set max_seq_len before training to make it align with input data (#249)
wanchaol Apr 17, 2024
9949284
fix pypi docs
tianyu-l Apr 17, 2024
bfe9998
update dataset to use c4
tianyu-l Apr 18, 2024
f80223b
Add c4_mini, a local 45K dataset (subset of c4) (#253)
lessw2020 Apr 18, 2024
6926922
remove logo, update pre-release date to 4/18 (#254)
lessw2020 Apr 18, 2024
d6f72e2
add intro video (#233)
lessw2020 Apr 18, 2024
395a526
add performance file to show convergence with 64 a100s (#255)
lessw2020 Apr 18, 2024
df2dcc7
Support Llama3 8b/70b (#256)
wanchaol Apr 20, 2024
2db26cf
polish llama 3 setup
tianyu-l Apr 22, 2024
4b60829
reenable integration tests with a test tokenizer (#259)
wanchaol Apr 23, 2024
b2ee158
warn supported dataset checks instead of throw (#260)
wanchaol Apr 24, 2024
3b51460
De-dup repeated `freqs_cis` computation code
awgu Apr 24, 2024
1ea476e
update readme.md and performance.md
tianyu-l Apr 24, 2024
f8863bd
followup changes to allow unsupported datasets
tianyu-l Apr 24, 2024
157a12c
fix ac 'checkpointing' spelling, minor spacing tweaks (#265)
lessw2020 Apr 24, 2024
0891fa3
Update legal terms (#269)
lessw2020 Apr 25, 2024
aea510d
apply less heavy profiling
tianyu-l Apr 25, 2024
e6d0d08
Showcase where the product positioning lies more clearly (#272)
soumith Apr 25, 2024
15057dd
Doc Fixes (#273)
msaroufim Apr 25, 2024
fd01061
fix lr scheduling by checkpointing scheduler
tianyu-l Apr 26, 2024
4333aca
insert barrier to profiler to resolve collectives timeout
tianyu-l Apr 25, 2024
a3b529a
some misc changes (#278)
wanchaol Apr 26, 2024
b898545
inherit stateful protocol where appropriate
tianyu-l Apr 26, 2024
935b572
Fixed docs on HSDP sharding/replication dims
awgu Apr 29, 2024
f61e0ba
Add more Float8 description (#284)
drisspg Apr 29, 2024
8697234
Remove unneeded torchvision/audio deps
wconstab Apr 29, 2024
a6d2625
fix 3d mesh order (#288)
wanchaol Apr 30, 2024
258f608
unify data loading from HF and from disk
tianyu-l Apr 30, 2024
10ef7a6
Add periodic integration test with signal (#289)
gnadathur May 1, 2024
0c6ca90
exclude embedding in MFU computation
tianyu-l Apr 26, 2024
e34d2ac
Add support for seed checkpoint creation for meta-init flow
wconstab May 2, 2024
1480766
remove unnecessary install of torchtitan
tianyu-l May 2, 2024
add0261
Remove unnecessary .to() inside model forward
wconstab May 2, 2024
3e2fa85
Fix the incorrect step log for profiler after resuming from a checkpo…
fegin May 3, 2024
5e84866
turn off dynamic shape for torch.compile (#297)
wanchaol May 3, 2024
8996249
Renamed `bsz` to `bs` for consistency; removed dead code
awgu May 3, 2024
5d63fff
Implement async_checkpoint
fegin May 7, 2024
26ff44f
simplify embedding + first transformer block TP (#314)
wanchaol May 8, 2024
ad46097
Only include checkpoints that have .metadata written (#315)
liangluofb May 10, 2024
99729e9
Refactor freqs_cis slice to be safer for PP
wconstab May 11, 2024
14d422f
Make Transformer tolerate missing layers for PP
wconstab May 11, 2024
ac94484
Use torch generic workflow for CI
wconstab May 15, 2024
41d69d2
[checkpointing] import async checkpoint with pinned memory only when …
tianyu-l May 15, 2024
6ed5237
Add a workflow to build torchtitan-ubuntu-20.04-clang12 Docker image …
huydhn May 16, 2024
2dca85e
Make pip install torch quiet
wconstab May 17, 2024
3baba7b
Make test_runner.py warn on non-empty output dir
wconstab May 17, 2024
5c69c02
Expose mixed_precision dtype arguments
wconstab May 21, 2024
8cc0b38
Use stateful dataloader to checkpoint data iteration order and token …
gokulavasan May 21, 2024
aafe0e8
Add Pipeline Parallel (and 2D PP+FSDP) support
wconstab May 21, 2024
60f58b9
fix i periodic integration test and add helper message on torchdata i…
tianyu-l May 22, 2024
9954e19
torch.compile each TransformerBlock instead of the whole model (#268)
wanchaol May 22, 2024
f47f442
Make test_runner use separate logger with default INFO
wconstab May 22, 2024
93a8053
Fix llama_13b.toml -> llama2_13b.toml in multinode_trainer.slurm (#350)
pbelevich May 22, 2024
0afb276
Fix bug in PP output layer shape
wconstab May 22, 2024
c73a59d
Update pipelining import after change on pytorch
wconstab May 23, 2024
c161119
update .gitignore to screen out slew of new temp files (#359)
lessw2020 May 24, 2024
e593e7d
Add test for PP tracer frontend
wconstab May 24, 2024
0779207
only produce tensorboard logs on rank 0 by default
tianyu-l May 29, 2024
f6ea139
replace old torch dependency in requirements.txt
tianyu-l May 29, 2024
0fff2d2
Add --test option to specify test to run (#368)
kwen2501 May 30, 2024
1877738
use integration test as the badge shown on the homepage
tianyu-l May 29, 2024
c48ae39
keep only latest k checkpoints (#366)
liangluofb May 31, 2024
3227d50
Make seed checkpoint creation work on CPU
wconstab Jun 3, 2024
fbc4aa0
Fix start/stop layer parsing
wconstab Jun 3, 2024
ff3c6e2
Use general way to access and update submodules
kwen2501 Jun 3, 2024
a1f9edb
Make metrics logging work for pipeline parallelism
wconstab Jun 4, 2024
9d25778
[RFC] Allow ModelWrapper and OptimizerWrapper to accept multiple models
fegin Jun 5, 2024
4eb4bfc
Add 3D support
wconstab Jun 4, 2024
40f8fd0
[torchtitan][optim] Add fused as an option in train config (#355)
wz337 Jun 6, 2024
3bbe3d9
[torchtitan] Fix test runner fused optim tests (#384)
wz337 Jun 6, 2024
d953107
Abstract out out optimizer params and update foreach calling conventi…
drisspg Jun 7, 2024
cf37b61
DeviceMesh BC fix (#387)
wanchaol Jun 9, 2024
9acdc6f
BC fix for ManualPipelineStage import (#388)
wanchaol Jun 9, 2024
3e5c0aa
fix missing tb logs
tianyu-l Jun 10, 2024
032b9d1
add the 8-gpu test badge and use correct links for the integration te…
tianyu-l Jun 10, 2024
91937ef
Fix 1D PP tracer test
kwen2501 Jun 10, 2024
e29b6b4
del logits=(bs, seq_len, vocab_size) to save 3.9G memory (#391)
weifengpy Jun 12, 2024
d0b4092
Update contributing.md (#385)
H-Huang Jun 12, 2024
000d43f
update all toml files to use experimental section (#392)
wanchaol Jun 12, 2024
7fcf70d
enable TP fp8 allgather with PrepareFloat8ModuleInput (#393)
wanchaol Jun 13, 2024
a6b585f
Update unit_test_cpu.yaml with cpu nightly (#396)
wanchaol Jun 13, 2024
0bf344c
Fix SAC BC breaking and renaming to ac_freq (#397)
wanchaol Jun 13, 2024
230300b
SAC API follow ups to restore old behavior (#401)
wanchaol Jun 13, 2024
38496a3
enable TritonFusedRMSNorm with local_map annotation (#404)
XilunWu Jun 14, 2024
e99f237
Cosmetic changes to train.py
kwen2501 Jun 14, 2024
a96fb82
Break down parallelize_llama for inference cases
kwen2501 Jun 14, 2024
ae3d2a9
Change debugmodel to have 8 layers
wconstab Jun 17, 2024
f8e17f1
Prepare train.py for model chunks for pipelining
wconstab Jun 17, 2024
71b70b5
dump memory snapshot to analyze OOMs (#395)
weifengpy Jun 19, 2024
6117759
whole_model for fp8 (#414)
weifengpy Jun 20, 2024
04661a6
Add train loop support for looped PP schedules
wconstab Jun 21, 2024
b1340a1
Set `record_shapes=True` for profiler
awgu Jun 24, 2024
be126a6
Improved `repeat_kv` eager perf
awgu Jun 24, 2024
342a07e
Adding FSDP Memory Tracking and Estimation
sanketpurandare Jun 25, 2024
134addd
Adding integration test for FSDP Memory Tracking and Estimation
sanketpurandare Jun 25, 2024
f5171cb
by default disable heavy memory profiling
tianyu-l Jun 26, 2024
1ec2ece
Add the option to turn on async-TP
yifuwang Jun 26, 2024
64d47fd
Modifying memory estimation options and minor changes
sanketpurandare Jul 1, 2024
6655204
add comment pointing to Sequence Parallel optimization example
tianyu-l Jul 4, 2024
8a1aa06
switch float8 logic from Float8DynamicLinear to Float8Linear (#436)
vkuzo Jul 8, 2024
28762c8
Removed `_experimental_support_context_fn_in_torch_utils_checkpoint`
awgu Jul 10, 2024
064730a
Reordered TP parallel plan to follow execution order
awgu Jul 10, 2024
3e3a913
Made some stylistic changes to `apply_dp`
awgu Jul 10, 2024
347ddc0
Refactored activation checkpointing
awgu Jul 10, 2024
3ff7fbb
compiled RMSNorm
tianyu-l Jul 10, 2024
562d7e2
Renamed parallel styles for transformer block weights
awgu Jul 10, 2024
0ddf49b
Added type annotations and more stylistic changes
awgu Jul 10, 2024
535acf6
[Cleanup] Remove libuv from run_llama_train.sh
wconstab Jul 15, 2024
ac72078
[Cleanup] Organize run_llama_train.sh options
wconstab Jul 15, 2024
4b6cdc1
[Cleanup] Split run_llama_train.sh and run_memory_estimation.sh
wconstab Jul 15, 2024
8fa11f0
[Cleanup] Remove unused TRAINER_DIR
wconstab Jul 15, 2024
174c44a
Add educational code pointers to top level README
wconstab Jul 15, 2024
a4b2ee3
enable FSDP2 + fp8 all-gather and fix TP fp8 all-gather (#413)
weifengpy Jul 16, 2024
ae8181b
import float8_experimental only when fp8 is enabled and install it in…
weifengpy Jul 17, 2024
3760bcf
skip fp8 CI on non-H100 GPUs (#465)
weifengpy Jul 17, 2024
69fe8de
clean up float8 configs in torchtitan (#466)
vkuzo Jul 17, 2024
2f989b9
Add support of DDP and experimental CompiledAutograd
fegin Jul 18, 2024
71b8eae
add torch.compile + FSDP2 float8 all-gather in CI (#468)
weifengpy Jul 19, 2024
0c6f9a2
[float8] keep model.output as `nn.Linear` (high precision, not fp8) (…
weifengpy Jul 19, 2024
0a17c26
remove CI for FSDP2 + fp8 all-gather (#470)
weifengpy Jul 20, 2024
0ee573c
dynamically update torch.compile cache config to ensure async tp supp…
lessw2020 Jul 21, 2024
69c9bb2
Fix 8gpu PP failure due to 2D DCP disablement
wconstab Jul 15, 2024
90e2070
update float8 integration after UX changes (#484)
vkuzo Jul 26, 2024
42f4ff5
Re-enable FSDP2 Mem Tracker integration tests
Jul 26, 2024
a48de09
Used `partial` instead of global vars for LR scheduling
awgu Jul 29, 2024
b63e209
[EZ] Add logs for some basic training params so that we can verify in…
fduwjj Jul 30, 2024
91f075a
make float8 scaling type configurable (#489)
vkuzo Jul 30, 2024
9358d70
[PP] add flexible interleaved 1f1b schedule #490 (#493)
H-Huang Jul 30, 2024
239d56f
move float8 callsites to torchao.float8 (#492)
vkuzo Jul 30, 2024
3c77e9f
[BE][1/n] simplify train.py
tianyu-l Jul 31, 2024
bf90710
[BE][2/n] use proper method signatures in parallelize_llama
tianyu-l Jul 31, 2024
40f79d7
[BE][3/n] wrap fp8 logic using Float8Handler
tianyu-l Jul 31, 2024
4871358
Bring LLaMa 3.1 405B to TorchTitan family (#481)
fduwjj Aug 1, 2024
d41d604
[TP] Infer local n_heads instead of ad-hoc model changes
kwen2501 Aug 2, 2024
24aef32
some compile-related updates
tianyu-l Aug 2, 2024
c44cca0
[EZ][405B] Use scientific notation for 405B model lr (#504)
fduwjj Aug 5, 2024
8849580
[BE][4/n] split pipeline_llama into a separate file
tianyu-l Aug 4, 2024
a4d88d1
[fix] float8 should be applied on all model_parts
tianyu-l Aug 5, 2024
1a303b3
Add warning to compile rmsnorm (#505)
wanchaol Aug 6, 2024
b99bc5e
add float8 to README (#509)
weifengpy Aug 7, 2024
fa8cdd4
address TODOs as 2D recompiles is fixed
tianyu-l Aug 7, 2024
d6e3f77
[BE][5/n] simplify pp vs. non-pp set up
tianyu-l Aug 8, 2024
34fa017
[BE][6/n] replace large c4_mini datasets by c4_test with the first 2K…
tianyu-l Aug 8, 2024
9de54a5
Create composability.md (#511)
wconstab Aug 9, 2024
b41b41b
depend on torchdata 0.8.0 instead of nightly
tianyu-l Aug 9, 2024
a4bc948
[PP] Bypass seed checkpoint by init-ing model parts separately (#516)
H-Huang Aug 12, 2024
a47a5a9
[small] format composability.md (#517)
H-Huang Aug 12, 2024
36a0057
Throw warning if users are using old pytorch version that not includi…
XilunWu Aug 13, 2024
1c96a01
Update fsdp.md (#519)
crcrpar Aug 14, 2024
6c16807
remove old torch dependency in requirements.txt
tianyu-l Aug 15, 2024
f339363
Fail when using tracer made without seed checkpoint (#522)
H-Huang Aug 16, 2024
81c555f
uniformly use skip for both (map-style) Dataset and IterableDataset
tianyu-l Aug 15, 2024
57c3400
Merge branch 'main' into sync_with_torchtitan
philippguevorguian Aug 20, 2024
Files changed
3 changes: 1 addition & 2 deletions .ci/docker/requirements.txt
@@ -1,6 +1,5 @@
torch >= 2.3.0
torchdata >= 0.8.0
datasets >= 2.19.0
datasets >= 2.21.0
tomli >= 1.1.0 ; python_version < "3.11"
tensorboard
sentencepiece
27 changes: 8 additions & 19 deletions torchtitan/datasets/hf_datasets.py
@@ -11,18 +11,13 @@
from torch.distributed.checkpoint.stateful import Stateful
from torch.utils.data import IterableDataset

try:
from torchdata.stateful_dataloader import StatefulDataLoader
except ImportError as e:
raise ImportError(
"Please install the latest torchdata nightly to use StatefulDataloader via:"
"pip3 install --pre torchdata --index-url https://download.pytorch.org/whl/nightly"
) from e
from torchdata.stateful_dataloader import StatefulDataLoader


from torchtitan.datasets.tokenizer import Tokenizer
from torchtitan.logging import logger

from datasets import load_dataset
from datasets import Dataset, load_dataset
from datasets.distributed import split_dataset_by_node

# map from dataset name to a local directory, or
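For context on the import change above: StatefulDataLoader (torchdata 0.8.0+) is a drop-in replacement for torch.utils.data.DataLoader that adds state_dict() / load_state_dict(), which is what lets the data-iteration position be checkpointed and resumed. A minimal sketch of that behavior, not part of this diff; the toy dataset and printed values are illustrative only:

# Sketch only: assumes torchdata >= 0.8.0 is installed.
from torchdata.stateful_dataloader import StatefulDataLoader

dataset = list(range(10))  # toy map-style dataset

loader = StatefulDataLoader(dataset, batch_size=2)
it = iter(loader)
print(next(it))              # tensor([0, 1])
state = loader.state_dict()  # capture progress after one batch

# A fresh loader resumes from the saved position instead of restarting.
resumed = StatefulDataLoader(dataset, batch_size=2)
resumed.load_state_dict(state)
print(next(iter(resumed)))   # tensor([2, 3])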
@@ -102,7 +97,7 @@ def __init__(
else:
ds = load_dataset(dataset_path, split="train")

# TODO: support shuffling and checkpointing
# TODO: support shuffling
self.dataset_name = dataset_name
self._data = split_dataset_by_node(ds, rank, world_size)
self._tokenizer = tokenizer
@@ -143,17 +138,11 @@ def _get_data_iter(self):
if self._sample_idx == 0:
return iter(self._data)

# Skip samples
if isinstance(self._data, IterableDataset):
it = iter(self._data)
# Naively iterate through the samples as skip may not be supported
for _ in range(self._sample_idx):
next(it)
return it

# As skipping to the end throws an error in case of map-style dataset, return an empty iterator
if self._sample_idx == len(self._data):
if isinstance(self._data, Dataset) and self._sample_idx == len(self._data):
return iter([])


return iter(self._data.skip(self._sample_idx))

def load_state_dict(self, state_dict):
@@ -179,7 +168,7 @@ def state_dict(self) -> Dict[str, Any]:
return {self._rank_id: pickle.dumps(super().state_dict())}

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
# State being empty is valid, don't log a warning
# State being empty is valid
if not state_dict:
return

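To illustrate the _get_data_iter change above: with the datasets >= 2.21.0 bump in this PR, both map-style Dataset and IterableDataset expose .skip(), so a single code path can fast-forward past already-consumed samples when resuming, and only the "skipped to the very end of a map-style dataset" case needs special handling. A small sketch under that assumption; the helper name and toy data are illustrative:

# Sketch only: assumes the Hugging Face `datasets` library, >= 2.21.0.
from datasets import Dataset


def resume_iter(data, sample_idx):
    # Fresh start: iterate from the beginning.
    if sample_idx == 0:
        return iter(data)
    # Skipping to the end of a map-style dataset raises, so treat it as an exhausted epoch.
    if isinstance(data, Dataset) and sample_idx == len(data):
        return iter([])
    # Otherwise .skip() works for both map-style and iterable datasets.
    return iter(data.skip(sample_idx))


ds = Dataset.from_dict({"text": ["a", "b", "c", "d"]})
print([row["text"] for row in resume_iter(ds, 2)])  # ['c', 'd']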
5 changes: 5 additions & 0 deletions train.py
@@ -204,6 +204,11 @@ def loss_fn(pred, labels):
checkpoint_loaded = checkpoint.load()

if parallel_dims.pp_enabled and not checkpoint_loaded:
if job_config.experimental.pipeline_parallel_split_mode == "tracer":
raise RuntimeError(
"Pipeline parallelism with tracer mode is not supported without a seed checkpoint."
)

# TODO: fix this by allowing each rank to set their own seed
logger.warning(
"Pipeline Parallelism is being used without a seed checkpoint. "