Hi, while executing:
torchrun --nproc_per_node gpu -m sae meta-llama/Meta-Llama-3-8B --distribute_modules --batch_size 1 --layers 24 25 --grad_acc_steps 8 --ctx_len 2048 --k 192 --load_in_8bit --micro_acc_steps 2
I found that the training process gets stuck at step=8.
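For context (and only as an illustration, not code from the sae repo): the usual way a torch.distributed job gets silently "stuck" at a fixed step with no error is a mismatched collective, where one rank issues an all_reduce/all_gather for a step that another rank never reaches. Whether that is what happens inside sae's trainer here is not confirmed; the minimal sketch below just reproduces the symptom.

```python
# Illustrative only -- not from the sae repo. Shows how a mismatch in the
# number of collective calls across ranks produces a silent hang at one step.
# Run with: torchrun --nproc_per_node 2 hang_demo.py
import torch
import torch.distributed as dist


def main():
    dist.init_process_group(backend="gloo")  # gloo so the demo needs no GPUs
    rank = dist.get_rank()

    # Assume rank 0 has one more batch than rank 1 (e.g. data that does not
    # split evenly across ranks); the exact numbers are made up for the demo.
    num_steps = 8 if rank == 0 else 7

    x = torch.ones(1)
    for step in range(num_steps):
        dist.all_reduce(x)  # returns only once every rank has called it for this step
        print(f"rank {rank} finished step {step}", flush=True)

    # Rank 1 exits after 7 steps, so rank 0's eighth all_reduce never completes
    # and the job hangs with no traceback.
    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```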
I have not had issues with it getting stuck at an early step like step 8. It has sometimes gotten stuck at the very end of training.
I'm hitting the same problem. I tried pile-10k on gpt2, Gemma-2b, and Llama-3-7b, and every run got stuck at the exact last step. Specifically, the program hangs on line 281 of https://github.com/EleutherAI/sae/blob/main/sae/trainer.py. At that point the loss is tensor(0.3034, device='cuda:0', grad_fn=).
Based on this, I want to ask:
1. Is there any fix for this issue?
2. Is it OK to directly skip the last step and harvest the SAE? (See the sketch below.)
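On the second question: whether the resulting SAE is usable if you skip the hung step depends on what that step actually does, and nothing in this thread settles that. As a generic pattern only (not code from the sae repo; loader_iter, train_step, and device are placeholder names), one way to avoid a last-step mismatch is to make every rank agree, with one extra all_reduce, on whether anyone has run out of data before the step's own collectives are issued:

```python
# Hedged sketch of an "agree to stop together" loop; loader_iter and
# train_step are placeholders, not identifiers from the sae repo.
import torch
import torch.distributed as dist


def run_until_exhausted(loader_iter, train_step, device="cpu"):
    # device must be a CUDA device when the process group uses the NCCL backend
    world_size = dist.get_world_size()
    while True:
        try:
            batch = next(loader_iter)
            have_data = torch.ones(1, device=device)
        except StopIteration:
            batch = None
            have_data = torch.zeros(1, device=device)

        # Every rank reports whether it still has a batch. If any rank is out
        # of data, the sum is below world_size and all ranks stop together,
        # before anyone issues a collective the others will never match.
        dist.all_reduce(have_data)
        if have_data.item() < world_size:
            break

        train_step(batch)  # placeholder for the real forward/backward/optimizer step
```

PyTorch also has built-in mechanisms aimed at the same problem (e.g. DistributedSampler(drop_last=True) so every rank sees the same number of batches); the sketch above is just the idea in its simplest form.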
To follow up on my report above: I debugged and traced the problem to sae/sae/trainer.py, line 431 (commit f60c38d).
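To confirm which line each rank is actually blocked on (and whether it is a collective call), one cheap diagnostic is to have every rank periodically dump its thread stacks; running py-spy dump --pid <worker pid> on each process gives the same information from the outside. The snippet below is a suggestion to paste temporarily near the top of the training entry point, not something the sae CLI already provides:

```python
# Suggested diagnostic, not part of the sae repo: print every thread's stack
# on this rank every 10 minutes, so a hang shows exactly where (for example,
# a blocked all_reduce/all_gather) the process is waiting.
import faulthandler
import sys

faulthandler.dump_traceback_later(timeout=600, repeat=True, file=sys.stderr)
```

Comparing the dumps across ranks shows whether one rank has already left the training loop while the others are still waiting in a collective call.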