Reproduction
Hi TRL team,
I am hitting OOM errors when fine-tuning a Llama-3.1-70B model with my modified RL trainer.
The error seems to occur when unwrapping the model for generation (my algorithm is on-policy, so every training step generates new sequences).
My machine has 8 H100 80GB GPUs and I am using LoRA. However, unwrap_model_for_generation appears to load the entire model into memory (the traceback shows it gathering all parameters with deepspeed.zero.GatheredParameters), which causes the OOM. Any suggestions?
[rank7]: Traceback (most recent call last):
[rank7]: File "/export/scripts/training.py", line 243, in <module>
[rank7]: trainer.train(resume_from_checkpoint=config.checkpoint_path)
[rank7]: File "/export/trainer/trainer_simple_rloo.py", line 195, in train
[rank7]: with torch.no_grad(), unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
[rank7]: File "/opt/conda/lib/python3.12/contextlib.py", line 137, in __enter__
[rank7]: return next(self.gen)
[rank7]: ^^^^^^^^^^^^^^
[rank7]: File "/export/venv/lib/python3.12/site-packages/trl/models/utils.py", line 162, in unwrap_model_for_generation
[rank7]: with deepspeed.zero.GatheredParameters(model.parameters()):
[rank7]: File "/export/venv/lib/python3.12/site-packages/deepspeed/runtime/zero/partition_parameters.py", line 2224, in __enter__
[rank7]: self.params[0].all_gather(param_list=self.params)
[rank7]: File "/export/venv/lib/python3.12/site-packages/deepspeed/runtime/zero/partition_parameters.py", line 1143, in all_gather
[rank7]: return self._all_gather(param_list, async_op=async_op, hierarchy=hierarchy)
[rank7]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]: File "/export/venv/lib/python3.12/site-packages/deepspeed/utils/nvtx.py", line 18, in wrapped_fn
[rank7]: ret_val = func(*args, **kwargs)
[rank7]: ^^^^^^^^^^^^^^^^^^^^^
[rank7]: File "/export/venv/lib/python3.12/site-packages/deepspeed/runtime/zero/partition_parameters.py", line 1511, in _all_gather
[rank7]: self._allgather_params_coalesced(all_gather_nonquantize_list, hierarchy, quantize=False)
[rank7]: File "/export/venv/lib/python3.12/site-packages/deepspeed/runtime/zero/partition_parameters.py", line 1799, in _allgather_params_coalesced
[rank7]: flat_tensor = torch.empty(tensor_size, dtype=param_list[0].ds_tensor.dtype,
[rank7]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]: torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 448.00 MiB. GPU 7 has a total capacity of 79.11 GiB of which 256.56 MiB is free. Including non-PyTorch memory, this process has 0 bytes memory in use. Of the allocated memory 74.68 GiB is allocated by PyTorch, and 7.22 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
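A rough back-of-the-envelope check of why a full parameter gather cannot fit. This is a sketch under stated assumptions, not a measurement: it assumes a nominal 70e9 parameters stored in bf16 (2 bytes each) and ignores LoRA adapters, optimizer state, activations and the KV cache. The 79.11 GiB capacity is taken from the OOM message above; the helper names are hypothetical.

```python
# Rough weight-memory estimate for gathering a ZeRO-3 sharded model onto one GPU.
# Assumptions (not from the issue): nominal 70e9 parameters, bf16 weights (2 bytes each).
# LoRA adapters, optimizer state, activations and the KV cache are ignored.

GIB = 1024 ** 3


def weight_gib(num_params: float, bytes_per_param: int = 2) -> float:
    """Memory needed to hold every weight of the model, in GiB."""
    return num_params * bytes_per_param / GIB


def shard_gib(num_params: float, world_size: int, bytes_per_param: int = 2) -> float:
    """Per-GPU weight memory when ZeRO-3 partitions the parameters evenly."""
    return weight_gib(num_params, bytes_per_param) / world_size


NUM_PARAMS = 70e9          # assumed nominal size of Llama-3.1-70B
WORLD_SIZE = 8             # 8 x H100 80GB
GPU_CAPACITY_GIB = 79.11   # total capacity reported in the OOM message

full = weight_gib(NUM_PARAMS)
shard = shard_gib(NUM_PARAMS, WORLD_SIZE)
fits = full < GPU_CAPACITY_GIB

print(f"full model weights: {full:.1f} GiB")
print(f"per-GPU ZeRO-3 shard: {shard:.1f} GiB")
print("full gather fits on one GPU:", fits)
```

Under these assumptions the sharded weights take about 16 GiB per GPU, but gathering everything for generation needs roughly 130 GiB on every rank, well above the 79.11 GiB an H100 reports, so the gather fails regardless of how much LoRA saves during the backward pass.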
Expected behavior
OOM issue resolved.
System Info
torch==2.4.0
transformers==4.43.4
trl==0.9.6
tokenizers==0.19.1
accelerate==0.32.0
peft==0.12.0
datasets==2.20.0
deepspeed==0.15.0
bitsandbytes==0.43.3
sentencepiece==0.2.0
flash-attn==2.6.3
gcc version 11.4.0 (Ubuntu 11.4.0-1ubuntu1~22.04)