From c8ed222555607399858e92e806d22d78267f39b0 Mon Sep 17 00:00:00 2001 From: LRL-ModelCloud Date: Thu, 26 Sep 2024 10:18:46 +0800 Subject: [PATCH] add liger_kernel support --- gptqmodel/models/auto.py | 2 ++ gptqmodel/models/base.py | 10 ++++++++++ requirements.txt | 1 + 3 files changed, 13 insertions(+) diff --git a/gptqmodel/models/auto.py b/gptqmodel/models/auto.py index db01bc1e..9a350028 100644 --- a/gptqmodel/models/auto.py +++ b/gptqmodel/models/auto.py @@ -102,6 +102,7 @@ def from_pretrained( pretrained_model_name_or_path: str, quantize_config: QuantizeConfig, trust_remote_code: bool = False, + use_liger_kernel: bool = False, **model_init_kwargs, ) -> BaseGPTQModel: model_type = check_and_get_model_type(pretrained_model_name_or_path, trust_remote_code) @@ -109,6 +110,7 @@ def from_pretrained( pretrained_model_name_or_path=pretrained_model_name_or_path, quantize_config=quantize_config, trust_remote_code=trust_remote_code, + use_liger_kernel=use_liger_kernel, **model_init_kwargs, ) diff --git a/gptqmodel/models/base.py b/gptqmodel/models/base.py index 50ed928d..a4bae076 100644 --- a/gptqmodel/models/base.py +++ b/gptqmodel/models/base.py @@ -993,6 +993,7 @@ def from_pretrained( pretrained_model_name_or_path: str, quantize_config: QuantizeConfig, trust_remote_code: bool = False, + use_liger_kernel: bool = False, torch_dtype: [str | torch.dtype] = "auto", **model_init_kwargs, ): @@ -1031,6 +1032,15 @@ def skip(*args, **kwargs): config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **model_init_kwargs) + if use_liger_kernel: + from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN + + apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(config.model_type, None) + if apply_fn is None: + raise ValueError(f"apply_fn is not defined for model type {config.model_type}") + + apply_fn() + if torch_dtype == "auto": torch_dtype = auto_dtype_from_config(config) elif not isinstance(torch_dtype, torch.dtype): diff --git a/requirements.txt b/requirements.txt index 835f702a..bac956cc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -17,3 +17,4 @@ intel_extension_for_transformers>=1.4.2 auto-round==0.3 huggingface-hub>=0.24.2 lm_eval==0.4.3 +liger-kernel>=0.3.0 \ No newline at end of file