Commit ba1f1dc

Updated Trainer's liger-kernel integration to call correct patching API (#33502)

* Updated liger-kernel integration in Trainer to call correct patching API

* Fixed styling
shimizust committed Sep 17, 2024
1 parent 4ba531c commit ba1f1dc
Showing 3 changed files with 25 additions and 20 deletions.
13 changes: 6 additions & 7 deletions src/transformers/trainer.py
@@ -468,19 +468,18 @@ def __init__(
 
         if self.args.use_liger_kernel:
             if is_liger_kernel_available():
-                from liger_kernel.transformers.trainer_integration import _apply_liger_kernel
+                from liger_kernel.transformers import _apply_liger_kernel_to_instance
 
-                model_type = getattr(model, "config", None) and getattr(model.config, "model_type", None)
-                if model_type:
-                    # Monkey patch the model with liger kernels. Use the default kernel configurations.
-                    _apply_liger_kernel(model_type=model_type)
+                if isinstance(model, PreTrainedModel):
+                    # Patch the model with liger kernels. Use the default kernel configurations.
+                    _apply_liger_kernel_to_instance(model=model)
                 else:
                     logger.warning(
-                        "The model does not have a valid `model_type` specified. No liger kernels will be applied."
+                        "The model is not an instance of PreTrainedModel. No liger kernels will be applied."
                     )
             else:
                 raise ImportError(
-                    "You have set `use_liger_kernel` to `True` but liger-kernel >= 0.1.0 is not available. "
+                    "You have set `use_liger_kernel` to `True` but liger-kernel >= 0.3.0 is not available. "
                     "Please install it with `pip install liger-kernel`"
                 )

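For context, a minimal sketch of calling the new instance-level entry point directly, outside of Trainer. The function name and `model=` keyword are taken from the diff above; the tiny Llama config is illustrative only, and the sketch assumes liger-kernel >= 0.3.0 is installed.

from liger_kernel.transformers import _apply_liger_kernel_to_instance

from transformers import LlamaConfig, LlamaForCausalLM

# A small, randomly initialized Llama model (illustrative sizes only).
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
model = LlamaForCausalLM(config)

# Patch this model instance with liger kernels, using the default kernel
# configuration for its model type.
_apply_liger_kernel_to_instance(model=model)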
2 changes: 1 addition & 1 deletion src/transformers/utils/import_utils.py
@@ -1187,7 +1187,7 @@ def is_liger_kernel_available():
     if not _liger_kernel_available:
         return False
 
-    return version.parse(importlib.metadata.version("liger_kernel")) >= version.parse("0.1.0")
+    return version.parse(importlib.metadata.version("liger_kernel")) >= version.parse("0.3.0")
 
 
 # docstyle-ignore
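The availability helper gates on both importability and version. A self-contained sketch of that pattern follows; it is illustrative, not the exact transformers helper, which defines `_liger_kernel_available` elsewhere in import_utils.py.

import importlib.metadata
import importlib.util

from packaging import version

# Illustrative stand-in for the module-level flag used above.
_liger_kernel_available = importlib.util.find_spec("liger_kernel") is not None


def is_liger_kernel_available():
    if not _liger_kernel_available:
        return False
    # The version floor moves from 0.1.0 to 0.3.0 so that the instance-level
    # patching API used by Trainer is guaranteed to exist.
    return version.parse(importlib.metadata.version("liger_kernel")) >= version.parse("0.3.0")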
30 changes: 18 additions & 12 deletions tests/trainer/test_trainer.py
@@ -1344,22 +1344,28 @@ def test_get_eval_dataloader_with_persistent_workers(self):
 
     @require_liger_kernel
     def test_use_liger_kernel_patching(self):
-        # Test that the model code actually gets patched with Liger kernel
-        from liger_kernel.transformers.rms_norm import LigerRMSNorm
+        # Ensure any monkey patching is cleaned up for subsequent tests
+        with patch("transformers.models.llama.modeling_llama"):
+            from liger_kernel.transformers import LigerRMSNorm, liger_rotary_pos_emb
 
-        from transformers.models.llama import modeling_llama
+            from transformers.models.llama import modeling_llama
 
-        config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
-        tiny_llama = LlamaForCausalLM(config)
+            config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
+            tiny_llama = LlamaForCausalLM(config)
 
-        args = TrainingArguments(
-            "./test",
-            use_liger_kernel=True,
-        )
-        Trainer(tiny_llama, args)
+            # Spot check that modeling code and model instance variables are not yet patched
+            self.assertNotEqual(modeling_llama.apply_rotary_pos_emb, liger_rotary_pos_emb)
+            self.assertFalse(isinstance(tiny_llama.model.norm, LigerRMSNorm))
 
-        # Check that one of the Llama model layers has been correctly patched with Liger kernel
-        self.assertEqual(modeling_llama.LlamaRMSNorm, LigerRMSNorm)
+            args = TrainingArguments(
+                "./test",
+                use_liger_kernel=True,
+            )
+            Trainer(tiny_llama, args)
+
+            # Spot check that modeling code and model instance variables are patched
+            self.assertEqual(modeling_llama.apply_rotary_pos_emb, liger_rotary_pos_emb)
+            self.assertTrue(isinstance(tiny_llama.model.norm, LigerRMSNorm))
 
     @require_liger_kernel
     @require_torch_gpu
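The test now runs entirely inside `patch("transformers.models.llama.modeling_llama")`, so the module-level monkey patching performed by liger-kernel lands on a mock and is discarded when the block exits. A minimal, illustrative sketch of that cleanup mechanism (not part of this commit):

from unittest.mock import patch

import transformers.models.llama as llama_pkg

real_module = llama_pkg.modeling_llama
with patch("transformers.models.llama.modeling_llama") as mocked:
    # Inside the block the package attribute refers to a MagicMock, so any
    # module-level patching is applied to the mock, not the real module.
    assert llama_pkg.modeling_llama is mocked
# On exit, patch restores the original module object, untouched.
assert llama_pkg.modeling_llama is real_module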
