
Crash running FSDP on BF16-prequantized models #1310

Closed
dmitrii-palisaderesearch opened this issue Aug 7, 2024 · 4 comments
Assignees
Labels
huggingface-related (a bug likely due to the interaction between bnb and HF libs: transformers, accelerate, peft), likely not a BNB issue

Comments

@dmitrii-palisaderesearch

System Info

$ python --version
Python 3.10.12

# pip install accelerate transformers bitsandbytes datasets trl peft setuptools
# using latest PyPI versions 
$ pip list
Package                  Version
------------------------ -----------
accelerate               0.33.0
aiohappyeyeballs         2.3.4
aiohttp                  3.10.1
aiosignal                1.3.1
async-timeout            4.0.3
attrs                    24.2.0
bitsandbytes             0.43.3
certifi                  2024.7.4
charset-normalizer       3.3.2
datasets                 2.20.0
dill                     0.3.8
docstring-parser         0.16
filelock                 3.15.4
frozenlist               1.4.1
fsspec                   2024.5.0
huggingface-hub          0.24.5
idna                     3.7
jinja2                   3.1.4
markdown-it-py           3.0.0
markupsafe               2.1.5
mdurl                    0.1.2
mpmath                   1.3.0
multidict                6.0.5
multiprocess             0.70.16
networkx                 3.3
numpy                    1.26.4
nvidia-cublas-cu12       12.1.3.1
nvidia-cuda-cupti-cu12   12.1.105
nvidia-cuda-nvrtc-cu12   12.1.105
nvidia-cuda-runtime-cu12 12.1.105
nvidia-cudnn-cu12        9.1.0.70
nvidia-cufft-cu12        11.0.2.54
nvidia-curand-cu12       10.3.2.106
nvidia-cusolver-cu12     11.4.5.107
nvidia-cusparse-cu12     12.1.0.106
nvidia-nccl-cu12         2.20.5
nvidia-nvjitlink-cu12    12.6.20
nvidia-nvtx-cu12         12.1.105
packaging                24.1
pandas                   2.2.2
peft                     0.12.0
psutil                   6.0.0
pyarrow                  17.0.0
pyarrow-hotfix           0.6
pygments                 2.18.0
python-dateutil          2.9.0.post0
pytz                     2024.1
pyyaml                   6.0.2
regex                    2024.7.24
requests                 2.32.3
rich                     13.7.1
safetensors              0.4.4
setuptools               72.1.0
shtab                    1.7.1
six                      1.16.0
sympy                    1.13.1
tokenizers               0.19.1
torch                    2.4.0
tqdm                     4.66.5
transformers             4.44.0
triton                   3.0.0
trl                      0.9.6
typing-extensions        4.12.2
tyro                     0.8.5
tzdata                   2024.1
urllib3                  2.2.2
xxhash                   3.4.1
yarl                     1.9.4

Reproduction

A DP run goes through fine:

$ accelerate launch --config-file dp.yml main.py
# OK!

An FSDP run crashes:

$ accelerate launch --config-file fsdp.yml main.py
FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.
[rank0]: Traceback (most recent call last):
[rank0]:   File "/home/ubuntu/repro-bnb-quant_state/main.py", line 70, in <module>
[rank0]:     trainer.train()
[rank0]:   File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/trl/trainer/sft_trainer.py", line 451, in train
[rank0]:     output = super().train(*args, **kwargs)
[rank0]:   File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/transformers/trainer.py", line 1948, in train
[rank0]:     return inner_training_loop(
[rank0]:   File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/transformers/trainer.py", line 2289, in _inner_training_loop
[rank0]:     tr_loss_step = self.training_step(model, inputs)
[rank0]:   File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/transformers/trainer.py", line 3328, in training_step
[rank0]:     loss = self.compute_loss(model, inputs)
[rank0]:   File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/transformers/trainer.py", line 3373, in compute_loss
[rank0]:     outputs = model(**inputs)
[rank0]:   File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/accelerate/utils/operations.py", line 819, in forward
[rank0]:     return model_forward(*args, **kwargs)
[rank0]:   File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/accelerate/utils/operations.py", line 807, in __call__
[rank0]:     return convert_to_fp32(self.model_forward(*args, **kwargs))
[rank0]:   File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/torch/amp/autocast_mode.py", line 43, in decorate_autocast
[rank0]:     return func(*args, **kwargs)
[rank0]:   File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 863, in forward
[rank0]:     output = self._fsdp_wrapped_module(*args, **kwargs)
[rank0]:   File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/accelerate/utils/operations.py", line 819, in forward
[rank0]:     return model_forward(*args, **kwargs)
[rank0]:   File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/accelerate/utils/operations.py", line 807, in __call__
[rank0]:     return convert_to_fp32(self.model_forward(*args, **kwargs))
[rank0]:   File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/torch/amp/autocast_mode.py", line 43, in decorate_autocast
[rank0]:     return func(*args, **kwargs)
[rank0]:   File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/peft/peft_model.py", line 1577, in forward
[rank0]:     return self.base_model(
[rank0]:   File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/peft/tuners/tuners_utils.py", line 188, in forward
[rank0]:     return self.model.forward(*args, **kwargs)
[rank0]:   File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 1189, in forward
[rank0]:     outputs = self.model(
[rank0]:   File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 1001, in forward
[rank0]:     layer_outputs = decoder_layer(
[rank0]:   File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 863, in forward
[rank0]:     output = self._fsdp_wrapped_module(*args, **kwargs)
[rank0]:   File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 734, in forward
[rank0]:     hidden_states, self_attn_weights, present_key_value = self.self_attn(
[rank0]:   File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 617, in forward
[rank0]:     query_states = self.q_proj(hidden_states)
[rank0]:   File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/peft/tuners/lora/bnb.py", line 467, in forward
[rank0]:     result = self.base_layer(x, *args, **kwargs)
[rank0]:   File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/bitsandbytes/nn/modules.py", line 477, in forward
[rank0]:     out = bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state)
[rank0]: AttributeError: 'Tensor' object has no attribute 'quant_state'
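The final frame points at the likely cause: by the time the bitsandbytes forward runs, `self.weight` is a plain `Tensor` rather than the `Params4bit` subclass that carries `quant_state`, which is consistent with FSDP's flat-parameter handling reconstructing shards as plain tensors. A minimal pure-Python stand-in (hypothetical classes, not the real torch/bitsandbytes types) illustrates how subclass attributes get lost:

```python
# Stand-in classes only: `Tensor` mimics torch.Tensor and `Params4bit`
# mimics the bitsandbytes parameter subclass that carries .quant_state.
class Tensor:
    pass

class Params4bit(Tensor):
    def __init__(self, quant_state):
        # the quantization metadata needed to dequantize at forward time
        self.quant_state = quant_state

p = Params4bit(quant_state={"quant_type": "nf4"})
flat = Tensor()  # what a flattening step hands back: a plain tensor

p_has = hasattr(p, "quant_state")      # True: the subclass carries it
flat_has = hasattr(flat, "quant_state")  # False: the plain tensor does not
```

When `matmul_4bit` then reads `self.weight.quant_state` on the plain tensor, it raises the `AttributeError` seen above.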

You can find the repro files in this gist.
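For orientation, the `fsdp.yml` passed to `accelerate launch` presumably resembles the following sketch (hypothetical values; the actual file is in the linked gist):

```yaml
# Hypothetical accelerate FSDP config sketch, not the gist's exact file
compute_environment: LOCAL_MACHINE
distributed_type: FSDP
mixed_precision: bf16
num_processes: 2
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_offload_params: false
```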

Expected behavior

After #1295, running FSDP on models prequantized with BNB to NF4 and stored in BF16 should work.

@Titus-von-Koeller
Collaborator

Thanks @dmitrii-palisaderesearch for raising this and providing detailed error logs and repro instructions. We (the bitsandbytes team) have very tight bandwidth at the moment, so I can't guarantee a prompt response. Please keep us updated if anything changes.

Mentioning this to @matthewdouglas, as he is the one who has recently been dealing with FSDP and prequantized weights.

@matthewdouglas
Member

Hi @dmitrii-palisaderesearch, thank you for reporting!

This issue exists on the transformers side. We weren't able to get the required changes in ahead of the v4.40 release, but they should be merged soon. The PR to track is huggingface/transformers#32276.

@matthewdouglas matthewdouglas self-assigned this Aug 8, 2024
@matthewdouglas matthewdouglas added likely not a BNB issue huggingface-related A bug that is likely due to the interaction between bnb and HF libs (transformers, accelerate, peft) labels Aug 8, 2024
@dmitrii-palisaderesearch
Author

Yes, I was tracking that PR, but it was then reverted in huggingface/transformers#32477, which confused me. Thanks, I'll keep an eye on transformers.

@matthewdouglas
Member

Since the PR has been merged on the transformers side, I'm going to go ahead and close this.
