diff --git a/docker/peft-gpu/Dockerfile b/docker/peft-gpu/Dockerfile
index c82fc14239..37e5b76fc9 100644
--- a/docker/peft-gpu/Dockerfile
+++ b/docker/peft-gpu/Dockerfile
@@ -52,6 +52,10 @@ RUN apt-get update && \
     apt-get clean && \
     rm -rf /var/lib/apt/lists*
 
+# Add eetq for quantization testing
+RUN source activate peft && \
+    python3 -m pip install git+https://github.com/NetEase-FuXi/EETQ.git
+
 # Activate the conda env and install transformers + accelerate from source
 RUN source activate peft && \
     python3 -m pip install -U --no-cache-dir \
diff --git a/docs/source/developer_guides/quantization.md b/docs/source/developer_guides/quantization.md
index 93a6aaacbc..702dee963a 100644
--- a/docs/source/developer_guides/quantization.md
+++ b/docs/source/developer_guides/quantization.md
@@ -128,6 +128,42 @@ quantized_model = get_peft_model(quantized_model, peft_config)
 
 You can refer to the [Google Colab](https://colab.research.google.com/drive/12GTp1FCj5_0SnnNQH18h_2XFh9vS_guX?usp=sharing) example for an overview of AQLM+LoRA finetuning.
 
+## EETQ quantization
+
+You can also perform LoRA fine-tuning on EETQ-quantized models. The [EETQ](https://github.com/NetEase-FuXi/EETQ) package offers a simple and efficient way to perform 8-bit quantization, which is claimed to be faster than the `LLM.int8()` algorithm. First, make sure that you have a transformers version that is compatible with EETQ (e.g. by installing it from the latest PyPI release or from source).
+
+```py
+import torch
+from transformers import EetqConfig
+
+config = EetqConfig("int8")
+```
+
+Pass the `config` to the [`~transformers.AutoModelForCausalLM.from_pretrained`] method.
+
+```py
+from transformers import AutoModelForCausalLM
+
+model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", quantization_config=config)
+```
+
+Then create a `LoraConfig` and pass it to `get_peft_model`:
+
+```py
+from peft import LoraConfig, get_peft_model
+
+config = LoraConfig(
+    r=16,
+    lora_alpha=8,
+    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
+    lora_dropout=0.05,
+    bias="none",
+    task_type="CAUSAL_LM"
+)
+
+model = get_peft_model(model, config)
+```
+
 ## Next steps
 
 If you're interested in learning more about quantization, the following may be helpful:
diff --git a/src/peft/import_utils.py b/src/peft/import_utils.py
index 6799058e0c..a1acf484ab 100644
--- a/src/peft/import_utils.py
+++ b/src/peft/import_utils.py
@@ -77,3 +77,8 @@ def is_aqlm_available():
 @lru_cache
 def is_auto_awq_available():
     return importlib.util.find_spec("awq") is not None
+
+
+@lru_cache
+def is_eetq_available():
+    return importlib.util.find_spec("eetq") is not None
diff --git a/src/peft/tuners/lora/__init__.py b/src/peft/tuners/lora/__init__.py
index 3115fff724..2a0bce2a5f 100644
--- a/src/peft/tuners/lora/__init__.py
+++ b/src/peft/tuners/lora/__init__.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from peft.import_utils import is_bnb_4bit_available, is_bnb_available
+from peft.import_utils import is_bnb_4bit_available, is_bnb_available, is_eetq_available
 
 from .config import LoftQConfig, LoraConfig
 from .gptq import QuantLinear
@@ -34,4 +34,9 @@ def __getattr__(name):
 
         return Linear4bit
 
+    if (name == "EetqLoraLinear") and is_eetq_available():
+        from .eetq import EetqLoraLinear
+
+        return EetqLoraLinear
+
     raise AttributeError(f"module {__name__} has no attribute {name}")
diff --git a/src/peft/tuners/lora/eetq.py b/src/peft/tuners/lora/eetq.py
new file mode 100644
index 0000000000..6bf42c6814
--- /dev/null
+++ b/src/peft/tuners/lora/eetq.py
@@ -0,0 +1,104 @@
+# Copyright 2024-present the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any, List, Optional
+
+import torch
+
+from peft.import_utils import is_eetq_available
+from peft.tuners.lora.layer import LoraLayer
+from peft.tuners.tuners_utils import BaseTunerLayer
+
+
+if is_eetq_available():
+    from eetq import EetqLinear
+
+    class EetqLoraLinear(torch.nn.Module, LoraLayer):
+        def __init__(
+            self,
+            base_layer,
+            adapter_name,
+            r: int = 0,
+            lora_alpha: int = 1,
+            lora_dropout: float = 0.0,
+            init_lora_weights: bool = True,
+            use_rslora: bool = False,
+            **kwargs,
+        ):
+            super().__init__()
+            LoraLayer.__init__(self, base_layer)
+
+            # self.base_layer and self.quant_linear_module are the same; we need the former for consistency and the latter
+            # for backwards compatibility
+            self.quant_linear_module = base_layer
+
+            self._active_adapter = adapter_name
+            self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora)
+
+        def forward(self, x: torch.Tensor):
+            result = self.quant_linear_module(x)
+
+            if self.disable_adapters:
+                return result
+
+            for active_adapter in self.active_adapters:
+                if active_adapter not in self.lora_A.keys():
+                    continue
+                lora_A = self.lora_A[active_adapter]
+                lora_B = self.lora_B[active_adapter]
+                dropout = self.lora_dropout[active_adapter]
+                scaling = self.scaling[active_adapter]
+
+                requires_conversion = not torch.is_autocast_enabled()
+                if requires_conversion:
+                    expected_dtype = result.dtype
+                    x = x.to(lora_A.weight.dtype)
+
+                output = lora_B(lora_A(dropout(x)))
+                if requires_conversion:
+                    output = output.to(expected_dtype)
+                output = output * scaling
+                result = result + output
+            return result
+
+        def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None:
+            raise AttributeError("Merging LoRA layers is not supported for Eetq layers.")
+
+        def unmerge(self) -> None:
+            raise AttributeError("Unmerging LoRA layers is not supported for Eetq layers.")
+
+        def __repr__(self) -> str:
+            rep = super().__repr__()
+            return "lora." + rep
+
+
+def dispatch_eetq(
+    target: torch.nn.Module,
+    adapter_name: str,
+    **kwargs: Any,
+) -> Optional[torch.nn.Module]:
+    new_module = None
+
+    if isinstance(target, BaseTunerLayer):
+        target_base_layer = target.get_base_layer()
+    else:
+        target_base_layer = target
+
+    if is_eetq_available() and isinstance(target_base_layer, EetqLinear):
+        new_module = EetqLoraLinear(target, adapter_name, **kwargs)
+        target.weight = target_base_layer.weight
+
+        if hasattr(target, "bias"):
+            target.bias = target_base_layer.bias
+
+    return new_module
diff --git a/src/peft/tuners/lora/layer.py b/src/peft/tuners/lora/layer.py
index 024257d182..689b921371 100644
--- a/src/peft/tuners/lora/layer.py
+++ b/src/peft/tuners/lora/layer.py
@@ -77,6 +77,9 @@ def __init__(self, base_layer: nn.Module, **kwargs) -> None:
         elif hasattr(base_layer, "w_bit") and base_layer.__class__.__name__ == "WQLinear_GEMM":
             # Awq layers
             in_features, out_features = base_layer.in_features, base_layer.out_features
+        elif base_layer.__class__.__name__ == "EetqLinear":
+            # Eetq layers
+            in_features, out_features = base_layer.in_features, base_layer.out_features
         else:
             raise ValueError(f"Unsupported layer type {type(base_layer)}")
 
diff --git a/src/peft/tuners/lora/model.py b/src/peft/tuners/lora/model.py
index 4ed41012d8..6b3fbc6a69 100644
--- a/src/peft/tuners/lora/model.py
+++ b/src/peft/tuners/lora/model.py
@@ -48,6 +48,7 @@
 from .aqlm import dispatch_aqlm
 from .awq import dispatch_awq
 from .config import LoraConfig
+from .eetq import dispatch_eetq
 from .gptq import dispatch_gptq
 from .layer import Conv2d, LoraLayer, dispatch_default
 from .tp_layer import dispatch_megatron
@@ -288,7 +289,9 @@ def _create_new_module(lora_config, adapter_name, target, **kwargs):
 
             dispatchers.append(dispatch_bnb_4bit)
 
-        dispatchers.extend([dispatch_aqlm, dispatch_awq, dispatch_gptq, dispatch_megatron, dispatch_default])
+        dispatchers.extend(
+            [dispatch_eetq, dispatch_aqlm, dispatch_awq, dispatch_gptq, dispatch_megatron, dispatch_default]
+        )
 
         new_module = None
         for dispatcher in dispatchers:
diff --git a/src/peft/utils/other.py b/src/peft/utils/other.py
index 503db7d9d3..03f6507975 100644
--- a/src/peft/utils/other.py
+++ b/src/peft/utils/other.py
@@ -95,6 +95,8 @@ def prepare_model_for_kbit_training(model, use_gradient_checkpointing=True, grad
     loaded_in_kbit = getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False)
     is_gptq_quantized = getattr(model, "quantization_method", None) == "gptq"
     is_aqlm_quantized = getattr(model, "quantization_method", None) == "aqlm"
+    is_eetq_quantized = getattr(model, "quantization_method", None) == "eetq"
+
     if gradient_checkpointing_kwargs is None:
         gradient_checkpointing_kwargs = {}
 
@@ -102,7 +104,7 @@ def prepare_model_for_kbit_training(model, use_gradient_checkpointing=True, grad
         # freeze base model's layers
         param.requires_grad = False
 
-    if not is_gptq_quantized and not is_aqlm_quantized:
+    if not is_gptq_quantized and not is_aqlm_quantized and not is_eetq_quantized:
         # cast all non INT8 parameters to fp32
         for param in model.parameters():
             if (
@@ -110,7 +112,7 @@ def prepare_model_for_kbit_training(model, use_gradient_checkpointing=True, grad
             ) and param.__class__.__name__ != "Params4bit":
                 param.data = param.data.to(torch.float32)
 
-    if (loaded_in_kbit or is_gptq_quantized or is_aqlm_quantized) and use_gradient_checkpointing:
+    if (loaded_in_kbit or is_gptq_quantized or is_aqlm_quantized or is_eetq_quantized) and use_gradient_checkpointing:
         # When having `use_reentrant=False` + gradient_checkpointing, there is no need for this hack
         if "use_reentrant" not in gradient_checkpointing_kwargs or gradient_checkpointing_kwargs["use_reentrant"]:
             # For backward compatibility
diff --git a/tests/test_gpu_examples.py b/tests/test_gpu_examples.py
index a77b916220..91b09a0fbd 100644
--- a/tests/test_gpu_examples.py
+++ b/tests/test_gpu_examples.py
@@ -60,6 +60,7 @@
     require_auto_awq,
     require_auto_gptq,
     require_bitsandbytes,
+    require_eetq,
     require_optimum,
     require_torch_gpu,
     require_torch_multi_gpu,
@@ -2072,6 +2073,153 @@ def test_causal_lm_training_multi_gpu(self):
             assert trainer.state.log_history[-1]["train_loss"] is not None
 
 
+@require_torch_gpu
+@require_eetq
+class PeftEetqGPUTests(unittest.TestCase):
+    r"""
+    EETQ + peft tests
+    """
+
+    def setUp(self):
+        self.causal_lm_model_id = "facebook/opt-125m"
+        self.tokenizer = AutoTokenizer.from_pretrained(self.causal_lm_model_id)
+
+    def tearDown(self):
+        r"""
+        Efficient mechanism to free GPU memory after each test. Based on
+        https://github.com/huggingface/transformers/issues/21094
+        """
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def _check_inference_finite(self, model, batch):
+        # try inference without Trainer class
+        training = model.training
+        model.eval()
+        output = model(**batch.to(model.device))
+        assert torch.isfinite(output.logits).all()
+        model.train(training)
+
+    @pytest.mark.single_gpu_tests
+    def test_causal_lm_training_eetq(self):
+        r"""
+        Test the CausalLM training on a single GPU device. The test would simply fail if the adapters are not set
+        correctly.
+        """
+        from transformers import EetqConfig
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            quantization_config = EetqConfig("int8")
+
+            model = AutoModelForCausalLM.from_pretrained(
+                self.causal_lm_model_id, device_map="auto", quantization_config=quantization_config
+            )
+
+            model = prepare_model_for_kbit_training(model)
+
+            config = LoraConfig(
+                r=16,
+                lora_alpha=32,
+                target_modules=["q_proj", "v_proj"],
+                lora_dropout=0.05,
+                bias="none",
+                task_type="CAUSAL_LM",
+            )
+            model = get_peft_model(model, config)
+
+            data = load_dataset("ybelkada/english_quotes_copy")
+            data = data.map(lambda samples: self.tokenizer(samples["quote"]), batched=True)
+
+            trainer = Trainer(
+                model=model,
+                train_dataset=data["train"],
+                args=TrainingArguments(
+                    per_device_train_batch_size=4,
+                    gradient_accumulation_steps=4,
+                    warmup_steps=2,
+                    max_steps=3,
+                    learning_rate=2e-4,
+                    logging_steps=1,
+                    output_dir=tmp_dir,
+                ),
+                data_collator=DataCollatorForLanguageModeling(self.tokenizer, mlm=False),
+            )
+            model.config.use_cache = False
+            trainer.train()
+
+            model.cpu().save_pretrained(tmp_dir)
+
+            assert "adapter_config.json" in os.listdir(tmp_dir)
+            assert SAFETENSORS_WEIGHTS_NAME in os.listdir(tmp_dir)
+
+            # assert loss is not None
+            assert trainer.state.log_history[-1]["train_loss"] is not None
+
+    @pytest.mark.multi_gpu_tests
+    @require_torch_multi_gpu
+    def test_causal_lm_training_multi_gpu_eetq(self):
+        r"""
+        Test the CausalLM training on a multi-GPU device. The test would simply fail if the adapters are not set
+        correctly.
+ """ + from transformers import EetqConfig + + with tempfile.TemporaryDirectory() as tmp_dir: + quantization_config = EetqConfig("int8") + + model = AutoModelForCausalLM.from_pretrained( + self.causal_lm_model_id, + device_map="auto", + quantization_config=quantization_config, + ) + + assert set(model.hf_device_map.values()) == set(range(torch.cuda.device_count())) + + model = prepare_model_for_kbit_training(model) + + setattr(model, "model_parallel", True) + setattr(model, "is_parallelizable", True) + + config = LoraConfig( + r=16, + lora_alpha=32, + target_modules=["q_proj", "v_proj"], + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", + ) + + model = get_peft_model(model, config) + + data = load_dataset("Abirate/english_quotes") + data = data.map(lambda samples: self.tokenizer(samples["quote"]), batched=True) + + trainer = Trainer( + model=model, + train_dataset=data["train"], + args=TrainingArguments( + per_device_train_batch_size=4, + gradient_accumulation_steps=4, + warmup_steps=2, + max_steps=3, + learning_rate=2e-4, + logging_steps=1, + output_dir=tmp_dir, + ), + data_collator=DataCollatorForLanguageModeling(self.tokenizer, mlm=False), + ) + model.config.use_cache = False + trainer.train() + + model.cpu().save_pretrained(tmp_dir) + + assert "adapter_config.json" in os.listdir(tmp_dir) + assert SAFETENSORS_WEIGHTS_NAME in os.listdir(tmp_dir) + + # assert loss is not None + assert trainer.state.log_history[-1]["train_loss"] is not None + + PRECISIONS = [(torch.float32), (torch.float16), (torch.bfloat16)] LORA_PARAMS = { diff --git a/tests/testing_utils.py b/tests/testing_utils.py index eee33b5a67..28ba3a2b32 100644 --- a/tests/testing_utils.py +++ b/tests/testing_utils.py @@ -18,7 +18,13 @@ import pytest import torch -from peft.import_utils import is_aqlm_available, is_auto_awq_available, is_auto_gptq_available, is_optimum_available +from peft.import_utils import ( + is_aqlm_available, + is_auto_awq_available, + is_auto_gptq_available, + is_eetq_available, + is_optimum_available, +) def require_torch_gpu(test_case): @@ -75,6 +81,13 @@ def require_auto_awq(test_case): return unittest.skipUnless(is_auto_awq_available(), "test requires auto-awq")(test_case) +def require_eetq(test_case): + """ + Decorator marking a test that requires eetq. These tests are skipped when eetq isn't installed. + """ + return unittest.skipUnless(is_eetq_available(), "test requires eetq")(test_case) + + def require_optimum(test_case): """ Decorator marking a test that requires optimum. These tests are skipped when optimum isn't installed.