Add MODEL_TO_LIGER_KERNEL_PATCHING_FUNC to minimize dependencies from external code #40

Closed · wants to merge 1 commit
4 changes: 1 addition & 3 deletions examples/huggingface/training.py
@@ -13,9 +13,7 @@

 @dataclass
 class CustomArguments:
-    model_name: str = (
-        "meta-llama/Meta-Llama-3-8B"
-    )
+    model_name: str = "meta-llama/Meta-Llama-3-8B"
     dataset: str = "tatsu-lab/alpaca"
     max_seq_length: int = 512
     use_liger: bool = False
3 changes: 2 additions & 1 deletion examples/lightning/training.py
@@ -7,13 +7,14 @@

 import lightning.pytorch as pl
 import torch
 import transformers
-from liger_kernel.transformers import apply_liger_kernel_to_llama
 from lightning.pytorch.strategies import DeepSpeedStrategy, FSDPStrategy
 from torch.distributed.fsdp import BackwardPrefetch, MixedPrecision
 from torch.utils.data import DataLoader
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer
 from trl import DataCollatorForCompletionOnlyLM

+from liger_kernel.transformers import apply_liger_kernel_to_llama
+
 apply_liger_kernel_to_llama(fused_linear_cross_entropy=True, cross_entropy=False)
1 change: 1 addition & 0 deletions src/liger_kernel/transformers/__init__.py
@@ -1,4 +1,5 @@
 from liger_kernel.transformers.monkey_patch import (  # noqa: F401
+    MODEL_TO_LIGER_KERNEL_PATCHING_FUNC,
     apply_liger_kernel_to_gemma,
     apply_liger_kernel_to_llama,
     apply_liger_kernel_to_mistral,
12 changes: 12 additions & 0 deletions src/liger_kernel/transformers/monkey_patch.py
@@ -1,3 +1,5 @@
+from collections import OrderedDict
+
 from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
 from liger_kernel.transformers.geglu import LigerGEGLUMLP
 from liger_kernel.transformers.model.llama import lce_forward
@@ -128,3 +130,13 @@ def apply_liger_kernel_to_gemma(
         modeling_gemma.CrossEntropyLoss = LigerCrossEntropyLoss
     if geglu:
         modeling_gemma.GemmaMLP = LigerGEGLUMLP
+
+
+MODEL_TO_LIGER_KERNEL_PATCHING_FUNC = OrderedDict(
+    [
+        ("LlamaForCausalLM", apply_liger_kernel_to_llama),
+        ("MistralForCausalLM", apply_liger_kernel_to_mistral),
+        ("MixtralForCausalLM", apply_liger_kernel_to_mixtral),
+        ("GemmaForCausalLM", apply_liger_kernel_to_gemma),
+    ]
+)
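
With this mapping exported, downstream code can select the right patching function from a model's reported architecture instead of hardcoding model-specific imports. The sketch below is illustrative only and not part of this PR; the helper name apply_liger_kernel_for and the lookup via config.architectures are assumptions about how a consumer might use the new dict.

from transformers import AutoConfig

from liger_kernel.transformers import MODEL_TO_LIGER_KERNEL_PATCHING_FUNC


def apply_liger_kernel_for(model_name: str, **kernel_kwargs) -> bool:
    """Hypothetical helper: apply the matching Liger kernel patch, if any.

    Returns True when a patching function was found and applied.
    """
    config = AutoConfig.from_pretrained(model_name)
    # config.architectures lists class names such as "LlamaForCausalLM".
    for arch in config.architectures or []:
        patch_fn = MODEL_TO_LIGER_KERNEL_PATCHING_FUNC.get(arch)
        if patch_fn is not None:
            patch_fn(**kernel_kwargs)
            return True
    return False


# Example: patch Llama kernels before the model is instantiated.
apply_liger_kernel_for(
    "meta-llama/Meta-Llama-3-8B",
    fused_linear_cross_entropy=True,
    cross_entropy=False,
)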