OOM when loading 300B models with AutoModelForCausalLM.from_pretrained and BitsAndBytesConfig quantization. #31577

Open
2 of 4 tasks
Neo9061 opened this issue Jun 24, 2024 · 8 comments

Comments

@Neo9061

Neo9061 commented Jun 24, 2024

System Info

My goal is to follow the distributed fine-tuning blog post with FSDP to test distributed fine-tuning on a larger model, such as the 300B Grok-1.

For context, I have tried g5.48xlarge (8 GPUs with 192 GB of GPU memory in total, and 768 GB of CPU memory) and p4d.24xlarge (8 GPUs with 320 GB of GPU memory in total, and 1152 GB of CPU memory). I ran into the two issues described below.

Transformers version: transformers==4.40.0


Issue 1
When I tried to load the model with 4-bit quantization using the code below (WITHOUT FSDP, purely on a g5.48xlarge EC2 instance), the total GPU memory required should be around 150 GB (since the model is the ~300B Grok-1), which is less than the 192 GB of GPU memory on g5.48xlarge, but I hit OOM. If I set low_cpu_mem_usage=True, the model loads successfully on CPU on the g5.48xlarge instance. The same error happens on p4d.24xlarge, where loading with 4-bit quantization also fails.

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    set_seed,
)
import torch
import os
   
torch_dtype = torch.bfloat16
quant_storage_dtype = torch.bfloat16

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch_dtype,
    bnb_4bit_quant_storage=quant_storage_dtype,
)

model = AutoModelForCausalLM.from_pretrained(
    "keyfan/grok-1-hf",
    quantization_config=quantization_config,
    torch_dtype=quant_storage_dtype,
    use_cache=False,
    trust_remote_code=True,
)
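
As a rough check of the 150 GB estimate above, the following sketch computes the weight-only footprint at 4 bits per parameter and compares it with the aggregate and per-GPU memory. The 300e9 parameter count and the even split of 192 GB across 8 GPUs are assumptions based on the figures quoted in this issue; quantization constants and CUDA overhead are ignored.

```python
def weight_bytes(num_params: float, bits_per_param: float) -> float:
    """Weight-only footprint in bytes (ignores quantization constants and overhead)."""
    return num_params * bits_per_param / 8


NUM_PARAMS = 300e9   # assumption: ~300B parameters, as stated in the issue
TOTAL_GPU_GB = 192   # aggregate GPU memory of g5.48xlarge, as stated in the issue
NUM_GPUS = 8

weights_gb = weight_bytes(NUM_PARAMS, 4) / 1e9
per_gpu_gb = TOTAL_GPU_GB / NUM_GPUS  # assumption: memory is split evenly across GPUs

print(f"4-bit weights: {weights_gb:.0f} GB")
print(f"fits in aggregate GPU memory: {weights_gb < TOTAL_GPU_GB}")
print(f"fits on a single GPU ({per_gpu_gb:.0f} GB): {weights_gb < per_gpu_gb}")
```

Under these assumptions the weights fit only when spread across all eight GPUs, so a load path that stages everything on one device would be expected to run out of memory.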

Issue 2

Continuing from Issue 1, I think I found a path forward: loading the model into CPU memory by setting low_cpu_mem_usage=True. Following the blog post above, I started a SageMaker training job and tried to load this model using the default qlora_fsdp script shown in the blog. I also disabled quantization (since quantization loads the model onto the GPUs, which failed in Issue 1). When FSDP is enabled, low_cpu_mem_usage=True is used by default according to this line. However, I hit a timeout even after I set the training argument ddp_timeout to 10800.

The model checkpoints are loaded twice, and the run fails during the second load.

  return self.fget.__get__(instance, owner)()
Loading checkpoint shards:   5%|▌         | 1/19 [00:00<00:03,  5.36it/s]
Loading checkpoint shards:  11%|█         | 2/19 [00:00<00:03,  5.28it/s]
Loading checkpoint shards:  16%|█▌        | 3/19 [00:00<00:03,  5.24it/s]
Loading checkpoint shards:  21%|██        | 4/19 [00:00<00:02,  5.23it/s]
Loading checkpoint shards:  26%|██▋       | 5/19 [00:00<00:02,  5.29it/s]
Loading checkpoint shards:  32%|███▏      | 6/19 [00:01<00:02,  5.27it/s]
Loading checkpoint shards:  37%|███▋      | 7/19 [00:01<00:02,  5.25it/s]
Loading checkpoint shards:  42%|████▏     | 8/19 [00:01<00:02,  5.25it/s]
Loading checkpoint shards:  47%|████▋     | 9/19 [00:01<00:01,  5.23it/s]
Loading checkpoint shards:  53%|█████▎    | 10/19 [00:01<00:01,  5.21it/s]
Loading checkpoint shards:  58%|█████▊    | 11/19 [00:02<00:01,  5.20it/s]
Loading checkpoint shards:  63%|██████▎   | 12/19 [00:02<00:01,  5.20it/s]
Loading checkpoint shards:  68%|██████▊   | 13/19 [00:02<00:01,  5.19it/s]
Loading checkpoint shards:  74%|███████▎  | 14/19 [00:02<00:00,  5.21it/s]
Loading checkpoint shards:  79%|███████▉  | 15/19 [00:02<00:00,  5.20it/s]
Loading checkpoint shards:  84%|████████▍ | 16/19 [00:03<00:00,  5.19it/s]
Loading checkpoint shards:  89%|████████▉ | 17/19 [00:03<00:00,  5.19it/s]
Loading checkpoint shards:  95%|█████████▍| 18/19 [00:03<00:00,  5.24it/s]
Loading checkpoint shards: 100%|██████████| 19/19 [00:03<00:00,  5.27it/s]
Loading checkpoint shards: 100%|██████████| 19/19 [00:03<00:00,  5.23it/s]
Generating train split: 0 examples [00:00, ? examples/s]
Generating train split: 1 examples [00:00,  8.21 examples/s]
Generating train split: 827 examples [00:00, 1985.28 examples/s]
Generating train split: 1641 examples [00:00, 3495.17 examples/s]
Generating train split: 2496 examples [00:00, 3041.11 examples/s]
Generating train split: 3324 examples [00:01, 3366.71 examples/s]
Generating train split: 4001 examples [00:01, 3996.93 examples/s]
Generating train split: 4797 examples [00:01, 4292.10 examples/s]
Generating train split: 5698 examples [00:01, 4238.81 examples/s]
Generating train split: 6060 examples [00:01, 3625.63 examples/s]
Generating train split: 0 examples [00:00, ? examples/s]
Generating train split: 324 examples [00:00, 4303.12 examples/s]
Loading checkpoint shards:   5%|▌         | 1/19 [01:51<33:35, 111.97s/it]
Loading checkpoint shards:   5%|▌         | 1/19 [01:56<34:51, 116.18s/it]
Loading checkpoint shards:   5%|▌         | 1/19 [01:55<34:45, 115.86s/it]
Loading checkpoint shards:   5%|▌         | 1/19 [01:55<34:46, 115.93s/it]
Loading checkpoint shards:   5%|▌         | 1/19 [01:55<34:46, 115.89s/it]
Loading checkpoint shards:   5%|▌         | 1/19 [01:55<34:47, 115.98s/it]
Loading checkpoint shards:   5%|▌         | 1/19 [01:57<35:21, 117.86s/it]
Loading checkpoint shards:  11%|█         | 2/19 [04:45<42:02, 148.38s/it]
Loading checkpoint shards:  11%|█         | 2/19 [04:49<42:21, 149.50s/it]
Loading checkpoint shards:  11%|█         | 2/19 [04:48<42:19, 149.37s/it]
Loading checkpoint shards:  11%|█         | 2/19 [04:48<42:20, 149.42s/it]
Loading checkpoint shards:  11%|█         | 2/19 [04:48<42:19, 149.39s/it]
Loading checkpoint shards:  11%|█         | 2/19 [04:50<42:32, 150.16s/it]
Loading checkpoint shards:  11%|█         | 2/19 [04:51<42:45, 150.92s/it]
Loading checkpoint shards:  16%|█▌        | 3/19 [07:27<40:58, 153.63s/it]
Loading checkpoint shards:  16%|█▌        | 3/19 [07:27<41:10, 154.41s/it]
Loading checkpoint shards:  16%|█▌        | 3/19 [07:28<41:01, 153.85s/it]
Loading checkpoint shards:  16%|█▌        | 3/19 [07:27<41:00, 153.78s/it]
Loading checkpoint shards:  16%|█▌        | 3/19 [07:27<41:00, 153.78s/it]
Loading checkpoint shards:  16%|█▌        | 3/19 [07:29<41:13, 154.57s/it]
Loading checkpoint shards:  16%|█▌        | 3/19 [07:31<41:20, 155.04s/it]
Loading checkpoint shards:  21%|██        | 4/19 [10:21<40:24, 161.67s/it]
Loading checkpoint shards:  21%|██        | 4/19 [10:21<40:24, 161.63s/it]
Loading checkpoint shards:  21%|██        | 4/19 [10:22<40:28, 161.92s/it]
Loading checkpoint shards:  21%|██        | 4/19 [10:22<40:36, 162.40s/it]
Loading checkpoint shards:  21%|██        | 4/19 [10:21<40:28, 161.87s/it]
Loading checkpoint shards:  21%|██        | 4/19 [10:23<40:28, 161.92s/it]
Loading checkpoint shards:  21%|██        | 4/19 [10:26<40:43, 162.87s/it]
Loading checkpoint shards:  26%|██▋       | 5/19 [13:01<37:33, 160.99s/it]
Loading checkpoint shards:  26%|██▋       | 5/19 [13:01<37:27, 160.56s/it]
Loading checkpoint shards:  26%|██▋       | 5/19 [13:02<37:37, 161.28s/it]
Loading checkpoint shards:  26%|██▋       | 5/19 [13:02<37:37, 161.22s/it]
Loading checkpoint shards:  26%|██▋       | 5/19 [13:02<37:37, 161.26s/it]
Loading checkpoint shards:  26%|██▋       | 5/19 [13:02<37:45, 161.83s/it]
Loading checkpoint shards:  26%|██▋       | 5/19 [13:06<37:46, 161.88s/it]
Loading checkpoint shards:  32%|███▏      | 6/19 [15:56<35:54, 165.76s/it]
Loading checkpoint shards:  32%|███▏      | 6/19 [15:56<35:53, 165.62s/it]
Loading checkpoint shards:  32%|███▏      | 6/19 [15:56<35:53, 165.67s/it]
Loading checkpoint shards:  32%|███▏      | 6/19 [15:57<36:00, 166.18s/it]
Loading checkpoint shards:  32%|███▏      | 6/19 [15:57<35:56, 165.90s/it]
Loading checkpoint shards:  32%|███▏      | 6/19 [15:57<36:00, 166.17s/it]
Loading checkpoint shards:  32%|███▏      | 6/19 [16:01<36:01, 166.23s/it]
Loading checkpoint shards:  37%|███▋      | 7/19 [18:36<32:47, 164.00s/it]
Loading checkpoint shards:  37%|███▋      | 7/19 [18:38<32:54, 164.56s/it]
Loading checkpoint shards:  37%|███▋      | 7/19 [18:38<32:55, 164.67s/it]
Loading checkpoint shards:  37%|███▋      | 7/19 [18:38<32:53, 164.44s/it]
Loading checkpoint shards:  37%|███▋      | 7/19 [18:39<32:55, 164.63s/it]
Loading checkpoint shards:  37%|███▋      | 7/19 [18:40<33:00, 165.05s/it]
Loading checkpoint shards:  37%|███▋      | 7/19 [18:45<33:05, 165.47s/it]
Loading checkpoint shards:  42%|████▏     | 8/19 [21:36<30:58, 168.96s/it]
Loading checkpoint shards:  42%|████▏     | 8/19 [21:36<30:55, 168.64s/it]
Loading checkpoint shards:  42%|████▏     | 8/19 [21:36<30:56, 168.80s/it]
Loading checkpoint shards:  42%|████▏     | 8/19 [21:36<30:56, 168.78s/it]
Loading checkpoint shards:  42%|████▏     | 8/19 [21:36<30:57, 168.91s/it]
Loading checkpoint shards:  42%|████▏     | 8/19 [21:38<31:01, 169.27s/it]
Loading checkpoint shards:  42%|████▏     | 8/19 [21:44<31:09, 169.91s/it]
Loading checkpoint shards:  47%|████▋     | 9/19 [24:19<27:49, 166.91s/it]
Loading checkpoint shards:  47%|████▋     | 9/19 [24:21<27:57, 167.72s/it]
Loading checkpoint shards:  47%|████▋     | 9/19 [24:21<27:57, 167.71s/it]
Loading checkpoint shards:  47%|████▋     | 9/19 [24:21<27:56, 167.63s/it]
Loading checkpoint shards:  47%|████▋     | 9/19 [24:21<27:56, 167.69s/it]
Loading checkpoint shards:  47%|████▋     | 9/19 [24:22<27:56, 167.63s/it]
Loading checkpoint shards:  47%|████▋     | 9/19 [24:27<27:55, 167.55s/it]
Loading checkpoint shards:  53%|█████▎    | 10/19 [27:17<25:30, 170.07s/it]
Loading checkpoint shards:  53%|█████▎    | 10/19 [27:17<25:30, 170.01s/it]
Loading checkpoint shards:  53%|█████▎    | 10/19 [27:17<25:31, 170.16s/it]
Loading checkpoint shards:  53%|█████▎    | 10/19 [27:17<25:31, 170.15s/it]
Loading checkpoint shards:  53%|█████▎    | 10/19 [27:21<25:38, 170.90s/it]
Loading checkpoint shards:  53%|█████▎    | 10/19 [27:21<25:41, 171.30s/it]
Loading checkpoint shards:  53%|█████▎    | 10/19 [27:24<25:35, 170.64s/it]
Loading checkpoint shards:  58%|█████▊    | 11/19 [29:54<22:10, 166.27s/it]
Loading checkpoint shards:  58%|█████▊    | 11/19 [29:54<22:10, 166.29s/it]
Loading checkpoint shards:  58%|█████▊    | 11/19 [29:54<22:10, 166.26s/it]
Loading checkpoint shards:  58%|█████▊    | 11/19 [29:57<22:13, 166.71s/it]
Loading checkpoint shards:  58%|█████▊    | 11/19 [29:57<22:16, 167.09s/it]
Loading checkpoint shards:  58%|█████▊    | 11/19 [30:00<22:19, 167.44s/it]
Loading checkpoint shards:  58%|█████▊    | 11/19 [30:03<22:16, 167.03s/it]
[E ProcessGroupNCCL.cpp:474] [Rank 0] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=2, OpType=ALLREDUCE, NumelIn=1, NumelOut=1, Timeout(ms)=1800000) ran for 1800772 milliseconds before timing out.
/opt/conda/lib/python3.10/site-packages/trl/trainer/sft_trainer.py:318: UserWarning: You passed a tokenizer with `padding_side` not equal to `right` to the SFTTrainer. This might lead to some unexpected behaviour due to overflow issues when training a model in half-precision. You might consider adding `tokenizer.padding_side = 'right'` to your code.
  warnings.warn(
[E ProcessGroupNCCL.cpp:488] Some NCCL operations have failed or timed out. Due to the asynchronous nature of CUDA kernels, subsequent GPU operations might run on corrupted/incomplete data.
[E ProcessGroupNCCL.cpp:494] To avoid data inconsistency, we are taking the entire process down.
[E ProcessGroupNCCL.cpp:915] [Rank 0] NCCL watchdog thread terminated with exception: [Rank 0] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=2, OpType=ALLREDUCE, NumelIn=1, NumelOut=1, Timeout(ms)=1800000) ran for 1800772 milliseconds before timing out.
terminate called after throwing an instance of 'std::runtime_error'
  what():  [Rank 0] NCCL watchdog thread terminated with exception: [Rank 0] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=2, OpType=ALLREDUCE, NumelIn=1, NumelOut=1, Timeout(ms)=1800000) ran for 1800772 milliseconds before timing out.
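
The log above allows a rough sanity check on the timeout. The sketch below assumes the second load continues at about 165 s per shard (an approximation read off the progress bars) and compares the projected total with the 1800000 ms timeout reported by the NCCL watchdog and with the requested ddp_timeout of 10800 s.

```python
NUM_SHARDS = 19                      # from "x/19" in the progress bars
SECONDS_PER_SHARD = 165.0            # assumption: approximate rate read off the log
NCCL_TIMEOUT_S = 1_800_000 / 1000    # Timeout(ms)=1800000 in the watchdog message
REQUESTED_DDP_TIMEOUT_S = 10_800     # ddp_timeout passed as a training argument

projected_load_s = NUM_SHARDS * SECONDS_PER_SHARD

print(f"projected load time: {projected_load_s / 60:.1f} min")
print(f"timeout seen in the log: {NCCL_TIMEOUT_S / 60:.0f} min")
print(f"exceeds timeout seen in the log: {projected_load_s > NCCL_TIMEOUT_S}")
print(f"exceeds requested ddp_timeout: {projected_load_s > REQUESTED_DDP_TIMEOUT_S}")
```

Under these assumptions the load needs roughly 52 minutes, which is longer than the 30-minute timeout in the log but well within 10800 s. This suggests that the ddp_timeout value did not apply to the collective that timed out.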

Who can help?

@philschmid @SunMarc @lewtun @sgugger @ArthurZucker @pacman100

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Same as above

Expected behavior

Should be no OOM

@Titus-von-Koeller
Contributor

Hi @Neo9061, thanks for reporting this issue and providing the detailed context. I'm a maintainer of bitsandbytes and have a few thoughts:

Regarding issue 1)

From the code you provided, you don't seem to be using any parallel training approach, like FSDP. Is that right? In that case, you cannot use the combined memory of all GPUs, so running out of memory would be expected.

You mentioned that it "should run". To qualify this with a concrete estimate, I did the following back-of-the-envelope math:

Memory consists of gradient, parameters, and optimizer state (all parameter-dependent) and activation memory. Activation memory peaks at the memory for a transformer block when using gradient checkpointing, as that is the granularity of gradient checkpointing. With LoRA, gradient and optimizer state is approximately 1% of the parameter size, and parameters in 4-bit precision are just half a byte.

For parameter-related memory:
(gradient+optimizer)*0.01 + (parameters) = (p*4 bytes + p*8 bytes)*0.01 + p*0.5 bytes

Activation memory is more complex, but an approximation with a good implementation is:
seqlen*hidden_size*2bytes*batch_size*8

With a suboptimal implementation, the upper bound is approximately:
seqlen*hidden_size*2bytes*batch_size*16

Calculating the parameter-related memory for a 300 billion parameter model:
(300 billion * 4 bytes + 300 billion * 8 bytes) * 0.01 + 300 billion * 0.5 bytes = 186 GB or 173.2 GiB
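
The calculation above can be reproduced with a short sketch. The function name is hypothetical; the byte sizes (4 for gradients, 8 for optimizer state, 0.5 for 4-bit weights) and the 1% LoRA fraction are the ones stated in this comment.

```python
def lora_param_memory_bytes(num_params: float, trainable_fraction: float = 0.01) -> float:
    """Parameter-related memory following the formula in the comment above.

    Gradients (4 bytes) and optimizer state (8 bytes) are scaled by the LoRA
    trainable fraction; the 4-bit base weights take 0.5 bytes per parameter.
    """
    grad_and_optimizer = (num_params * 4 + num_params * 8) * trainable_fraction
    weights = num_params * 0.5
    return grad_and_optimizer + weights


total_bytes = lora_param_memory_bytes(300e9)
print(f"{total_bytes / 1e9:.0f} GB")      # decimal gigabytes
print(f"{total_bytes / 2**30:.1f} GiB")   # binary gibibytes
```

Activation memory is not included here, so this is a lower bound on fine-tuning memory rather than a full estimate.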

However, 186 GB doesn't yet account for the activation: I'm not certain what values to plug in for the activation calculation, but with 186 GB used out of the total 192 GB combined capacity of your 8 GPU AWS instance (g5.48xlarge), this leaves relatively little wriggle room. Could you please find out the missing values in your training configuration and we can refine the estimate from there?

Optimizer states and gradients can also potentially have lower precisions, so it would be important to take that into account (i.e. confirm our calculation assumption or use less precision to save memory in one of your tests).


On a side-note, not necessarily related but I want to mention it anyways:

We have observed an issue with QLoRA and long seqlen leading to higher than expected memory usage when using BNB quantization and suspect a memory leak. However, I don't see the seqlen specified in your minimal reproducer code. Could you please confirm what sequence length you are using?

I investigated this briefly after we became aware that there might be a memory leak leading to excessive memory consumption for high seqlen during the FSDP + QLoRA release, but couldn't easily reproduce the problem. Due to limited resources and other high-priority items, we temporarily deprioritized further investigation. Our assumption was that this only affects very few users, but we would like to know if it's perceived as a blocker so that we can prioritize it.

We have a new engineer, @matthew Douglas, joining the BNB team in July. Once he's on board, we plan to reassess the importance and urgency of this issue. It would be helpful if you could look into this a bit and, if you think it's a blocker, ping us again.


Regarding issue 2)

I think this is more for @philschmid and the others to answer, as nothing immediately catches my eye.

@Neo9061
Author

Neo9061 commented Jun 26, 2024

Hi @Titus-von-Koeller, thanks for answering my questions!

To follow up, my first issue is OOM during the model loading stage, not the fine-tuning stage. I followed the blog post https://www.philschmid.de/sagemaker-train-deploy-llama3, which initializes FSDP, and here is the training script.

Q1. By this line

new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
    model_to_load,
    state_dict,
    loaded_keys,
    start_prefix,
    expected_keys,
    device_map=device_map,
    offload_folder=offload_folder,
    offload_index=offload_index,
    state_dict_folder=state_dict_folder,
    state_dict_index=state_dict_index,
    dtype=dtype,
    hf_quantizer=hf_quantizer,
    is_safetensors=is_safetensors,
    keep_in_fp32_modules=keep_in_fp32_modules,
    unexpected_keys=unexpected_keys,
I feel that if the model is quantized, it is loaded using the GPU rather than into CPU memory with low_cpu_mem_usage=True. Is that true? It also seems to use a single GPU (rank 0) to load the entire quantized model.

On the other hand, using @philschmid's blog post and code, I am able to load and train two models:

  1. llama-3 70B on 4 instances of g5.24xlarge, and
  2. mixtral 8x22 on 4 instances of g5.8xlarge (failed at model merging stage).

Neither g5.12xlarge (4 GPUs with 96 GB in total, and 192 GB CPU) nor g5.16xlarge (1 GPU with 24 GB and 256 GB CPU) has enough memory on a single GPU to load the model, so I suspect the weights are being offloaded to CPU memory rather than loaded by a single GPU (rank 0). But that does not explain why the Grok-1 model fails at the loading stage with 4-bit quantization.

Q2. Is the memory you computed for model loading or for model fine-tuning? My understanding is that memory for activations, optimizer state, and gradients is not required at the model loading stage by the .from_pretrained method. Is my understanding correct?

Q3. For the activation formula seqlen*hidden_size*2bytes*batch_size*16: if I have a seqlen of 1024, a hidden_size of 8000, and a batch size of 1, then the total memory is 262144000 / (10^9) = 0.26 GB, which is negligible unless we use a large batch size. Is my understanding correct?
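
The arithmetic in Q3 can be checked with a small sketch of the activation formula from the earlier comment. The function name is hypothetical; the factors 8 and 16 are the "good implementation" and "suboptimal implementation" multipliers quoted above, and 2 bytes per value is the precision assumed there.

```python
def activation_bytes(seqlen: int, hidden_size: int, batch_size: int,
                     factor: int, bytes_per_value: int = 2) -> int:
    """Approximate peak activation memory: seqlen * hidden_size * bytes * batch_size * factor."""
    return seqlen * hidden_size * bytes_per_value * batch_size * factor


# Values from Q3: seqlen 1024, hidden_size 8000, batch size 1.
upper = activation_bytes(1024, 8000, 1, factor=16)  # suboptimal implementation
lower = activation_bytes(1024, 8000, 1, factor=8)   # good implementation

print(f"upper bound: {upper} bytes = {upper / 1e9:.2f} GB")
print(f"lower estimate: {lower} bytes = {lower / 1e9:.2f} GB")
```

The estimate scales linearly with both batch size and sequence length, so it only stays small while both remain modest.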

@Titus-von-Koeller
Contributor

cc @matthewdouglas who has taken over the lead on this task

@thepowerfuldeez

Up! Are there any suggestions on this issue??

@matthewdouglas
Member

matthewdouglas commented Jul 28, 2024

@Neo9061 @thepowerfuldeez See the PR on #32276. The observation here is that weights would be offloaded to CPU memory for all ranks instead of just one (e.g. 8x CPU memory requirement on the g5.48xlarge and p4d.24xlarge mentioned in the original issue). This usage goes back down after the model is loaded, so a temporary workaround could be to create additional swap space on local NVMe storage.
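
To illustrate the observation above, the sketch below estimates peak host memory when every rank stages its own copy of the weights versus only one rank doing so. The 2 bytes per parameter (a bf16 checkpoint) and the 300e9 parameter count are assumptions for illustration; the CPU memory figures are those quoted in the original issue.

```python
def peak_host_gb(num_params: float, bytes_per_param: int, ranks_holding_copy: int) -> float:
    """Host memory needed if `ranks_holding_copy` processes each stage a full copy."""
    return num_params * bytes_per_param * ranks_holding_copy / 1e9


NUM_PARAMS = 300e9   # assumption: ~300B parameters
BF16_BYTES = 2       # assumption: weights staged on CPU in bf16
RANKS = 8            # one process per GPU on both instance types
CPU_GB = {"g5.48xlarge": 768, "p4d.24xlarge": 1152}

all_ranks_gb = peak_host_gb(NUM_PARAMS, BF16_BYTES, RANKS)
one_rank_gb = peak_host_gb(NUM_PARAMS, BF16_BYTES, 1)

for name, cpu_gb in CPU_GB.items():
    print(f"{name}: all ranks need {all_ranks_gb:.0f} GB, "
          f"one rank needs {one_rank_gb:.0f} GB, available {cpu_gb} GB")
```

Under these assumptions a single staged copy fits in CPU memory on both instance types, while eight copies do not, which is consistent with the swap-space workaround being needed only during loading.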

In addition to this, I'm testing out some further changes to enable the usage of prequantized checkpoints with FSDP+QLoRA.

@ArthurZucker
Collaborator

Since the PR was merged but then reverted, @matthewdouglas is there another PR we can follow for this feature?

@matthewdouglas
Member

New PR: #33154


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.
