fix: load_best_model_at_end error when load_in_8bit is True
    Ref: huggingface/peft#394
    Loading a quantized checkpoint into a non-quantized Linear8bitLt module is not supported,
    so call module.cuda() before module.load_state_dict().
dkqkxx committed May 18, 2023
1 parent a8732e0 · commit d1f13f6
Showing 1 changed file with 2 additions and 0 deletions.
src/transformers/trainer.py (2 additions, 0 deletions)
@@ -2235,6 +2235,8 @@ def _load_best_model(self):
                 # If the model is on the GPU, it still works!
                 # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
                 # which takes *args instead of **kwargs
+                if model._is_int8_training_enabled or model.is_8bit_serializable:
+                    model.cuda()
                 load_result = model.load_state_dict(state_dict, False)
                 if not is_sagemaker_mp_enabled():
                     self._issue_warnings_after_load(load_result)
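
For context, a minimal sketch of the ordering this change enforces. The helper name and the is_loaded_in_8bit attribute check below are illustrative assumptions, not part of the commit: bitsandbytes' Linear8bitLt layers only take on their quantized form once the module is moved to a CUDA device, so a quantized checkpoint has to be restored after model.cuda(), not before.

import torch

def restore_best_checkpoint(model: torch.nn.Module, state_dict: dict):
    # Hypothetical flag: stands in for the 8-bit attributes the commit checks.
    if getattr(model, "is_loaded_in_8bit", False):
        # Moving to CUDA converts Linear8bitLt layers to their quantized form,
        # so the quantized state dict can then be loaded into matching modules.
        model.cuda()
    # strict=False mirrors model.load_state_dict(state_dict, False) in trainer.py.
    return model.load_state_dict(state_dict, strict=False)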
