FEAT: Add EETQ support in PEFT (#1675)
* v1

* fix tests

* fix unneeded change

* fix unneeded change

* fix unneeded change

* fix

* fix CI

* fix docker image

* fix docker image

* add docs

* lazy import

* raise when merge

* raise when merge

* Update eetq.py

* merge

* style

* add unmerge

* indent

* Update docs/source/developer_guides/quantization.md

Co-authored-by: Benjamin Bossan <[email protected]>

* add details about transformers

---------

Co-authored-by: Benjamin Bossan <[email protected]>
younesbelkada and BenjaminBossan committed Apr 26, 2024
1 parent b1d6c77 commit d0fa70a
Showing 10 changed files with 328 additions and 5 deletions.
4 changes: 4 additions & 0 deletions docker/peft-gpu/Dockerfile
@@ -52,6 +52,10 @@ RUN apt-get update && \
    apt-get clean && \
    rm -rf /var/lib/apt/lists*

# Add eetq for quantization testing
RUN source activate peft && \
    python3 -m pip install git+https://github.com/NetEase-FuXi/EETQ.git

# Activate the conda env and install transformers + accelerate from source
RUN source activate peft && \
    python3 -m pip install -U --no-cache-dir \
36 changes: 36 additions & 0 deletions docs/source/developer_guides/quantization.md
@@ -128,6 +128,42 @@ quantized_model = get_peft_model(quantized_model, peft_config)

You can refer to the [Google Colab](https://colab.research.google.com/drive/12GTp1FCj5_0SnnNQH18h_2XFh9vS_guX?usp=sharing) example for an overview of AQLM+LoRA finetuning.

## EETQ quantization

You can also perform LoRA fine-tuning on EETQ-quantized models. The [EETQ](https://github.com/NetEase-FuXi/EETQ) package offers a simple and efficient way to perform 8-bit quantization and is claimed to be faster than the `LLM.int8()` algorithm. First, make sure that you have a `transformers` version that is compatible with EETQ (e.g. by installing it from the latest PyPI release or from source).

```py
import torch
from transformers import EetqConfig

config = EetqConfig("int8")
```

Pass the `config` to the [`~transformers.AutoModelForCausalLM.from_pretrained`] method.

```py
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", quantization_config=config)
```

Then create a `LoraConfig` and pass it to `get_peft_model`:

```py
from peft import LoraConfig, get_peft_model

config = LoraConfig(
    r=16,
    lora_alpha=8,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)

model = get_peft_model(model, config)
```
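
From here, the wrapped model trains like any other PEFT model. The following is a minimal training sketch; the `Trainer` hyperparameters and the `train_dataset` placeholder are illustrative and not part of the EETQ integration itself.

```py
from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir="mistral-eetq-lora",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    learning_rate=2e-4,
    num_train_epochs=1,
    logging_steps=10,
)

# `train_dataset` stands in for any tokenized causal language modeling dataset.
trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset)
trainer.train()

# Only the LoRA adapter weights are saved; the quantized base model is untouched.
model.save_pretrained("mistral-eetq-lora")
```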

## Next steps

If you're interested in learning more about quantization, the following may be helpful:
5 changes: 5 additions & 0 deletions src/peft/import_utils.py
@@ -77,3 +77,8 @@ def is_aqlm_available():
@lru_cache
def is_auto_awq_available():
    return importlib.util.find_spec("awq") is not None


@lru_cache
def is_eetq_available():
    return importlib.util.find_spec("eetq") is not None
7 changes: 6 additions & 1 deletion src/peft/tuners/lora/__init__.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-from peft.import_utils import is_bnb_4bit_available, is_bnb_available
+from peft.import_utils import is_bnb_4bit_available, is_bnb_available, is_eetq_available

from .config import LoftQConfig, LoraConfig
from .gptq import QuantLinear
@@ -34,4 +34,9 @@ def __getattr__(name):

        return Linear4bit

    if (name == "EetqLoraLinear") and is_eetq_available():
        from .eetq import EetqLoraLinear

        return EetqLoraLinear

    raise AttributeError(f"module {__name__} has no attribute {name}")
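
The module-level `__getattr__` above mirrors the existing bitsandbytes pattern: `EetqLoraLinear` is only exposed when `eetq` is installed, so importing `peft.tuners.lora` never fails on machines without it. A short usage sketch (assumed calling pattern, not part of this diff):

```py
from peft.import_utils import is_eetq_available

if is_eetq_available():
    # Resolved through the lazy __getattr__ above; on machines without eetq
    # this attribute simply does not exist.
    from peft.tuners.lora import EetqLoraLinear
```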
104 changes: 104 additions & 0 deletions src/peft/tuners/lora/eetq.py
@@ -0,0 +1,104 @@
# Copyright 2024-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Optional

import torch

from peft.import_utils import is_eetq_available
from peft.tuners.lora.layer import LoraLayer
from peft.tuners.tuners_utils import BaseTunerLayer


if is_eetq_available():
    from eetq import EetqLinear

    class EetqLoraLinear(torch.nn.Module, LoraLayer):
        def __init__(
            self,
            base_layer,
            adapter_name,
            r: int = 0,
            lora_alpha: int = 1,
            lora_dropout: float = 0.0,
            init_lora_weights: bool = True,
            use_rslora: bool = False,
            **kwargs,
        ):
            super().__init__()
            LoraLayer.__init__(self, base_layer)

            # self.base_layer and self.quant_linear_module are the same; we need the former for consistency and the latter
            # for backwards compatibility
            self.quant_linear_module = base_layer

            self._active_adapter = adapter_name
            self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora)

        def forward(self, x: torch.Tensor):
            result = self.quant_linear_module(x)

            if self.disable_adapters:
                return result

            for active_adapter in self.active_adapters:
                if active_adapter not in self.lora_A.keys():
                    continue
                lora_A = self.lora_A[active_adapter]
                lora_B = self.lora_B[active_adapter]
                dropout = self.lora_dropout[active_adapter]
                scaling = self.scaling[active_adapter]

                requires_conversion = not torch.is_autocast_enabled()
                if requires_conversion:
                    expected_dtype = result.dtype
                    x = x.to(lora_A.weight.dtype)

                output = lora_B(lora_A(dropout(x)))
                if requires_conversion:
                    output = output.to(expected_dtype)
                output = output * scaling
                result = result + output
            return result

        def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None:
            raise AttributeError("Merging LoRA layers is not supported for Eetq layers.")

        def unmerge(self) -> None:
            raise AttributeError("Unmerging LoRA layers is not supported for Eetq layers.")

        def __repr__(self) -> str:
            rep = super().__repr__()
            return "lora." + rep


def dispatch_eetq(
    target: torch.nn.Module,
    adapter_name: str,
    **kwargs: Any,
) -> Optional[torch.nn.Module]:
    new_module = None

    if isinstance(target, BaseTunerLayer):
        target_base_layer = target.get_base_layer()
    else:
        target_base_layer = target

    if is_eetq_available() and isinstance(target_base_layer, EetqLinear):
        new_module = EetqLoraLinear(target, adapter_name, **kwargs)
        target.weight = target_base_layer.weight

        if hasattr(target, "bias"):
            target.bias = target_base_layer.bias

    return new_module
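
For intuition, the `forward` above computes `y = W_q(x) + scaling * lora_B(lora_A(dropout(x)))`, leaving the quantized weights untouched. A self-contained numerical sketch of that composition, with plain `nn.Linear` layers standing in for the quantized `EetqLinear` base layer:

```py
import torch
import torch.nn as nn

base = nn.Linear(16, 16, bias=False)    # stand-in for the frozen quantized layer
lora_A = nn.Linear(16, 8, bias=False)   # r = 8
lora_B = nn.Linear(8, 16, bias=False)
nn.init.zeros_(lora_B.weight)           # LoRA init: the delta starts at zero
scaling = 16 / 8                        # lora_alpha / r

x = torch.randn(2, 16)
y = base(x) + scaling * lora_B(lora_A(x))
assert torch.allclose(y, base(x))       # no effect until lora_B is trained
```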
3 changes: 3 additions & 0 deletions src/peft/tuners/lora/layer.py
@@ -77,6 +77,9 @@ def __init__(self, base_layer: nn.Module, **kwargs) -> None:
        elif hasattr(base_layer, "w_bit") and base_layer.__class__.__name__ == "WQLinear_GEMM":
            # Awq layers
            in_features, out_features = base_layer.in_features, base_layer.out_features
        elif base_layer.__class__.__name__ == "EetqLinear":
            # Eetq layers
            in_features, out_features = base_layer.in_features, base_layer.out_features
        else:
            raise ValueError(f"Unsupported layer type {type(base_layer)}")
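
The shapes detected here size the adapter pair, which maps `in_features -> r -> out_features` with ordinary linear layers; the quantized base layer only needs to report its dimensions. An illustrative sketch (not the PEFT internals verbatim):

```py
import torch.nn as nn

in_features, out_features, r = 4096, 4096, 16
lora_A = nn.Linear(in_features, r, bias=False)
lora_B = nn.Linear(r, out_features, bias=False)
```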

5 changes: 4 additions & 1 deletion src/peft/tuners/lora/model.py
@@ -48,6 +48,7 @@
from .aqlm import dispatch_aqlm
from .awq import dispatch_awq
from .config import LoraConfig
from .eetq import dispatch_eetq
from .gptq import dispatch_gptq
from .layer import Conv2d, LoraLayer, dispatch_default
from .tp_layer import dispatch_megatron
@@ -288,7 +289,9 @@ def _create_new_module(lora_config, adapter_name, target, **kwargs):

        dispatchers.append(dispatch_bnb_4bit)

-        dispatchers.extend([dispatch_aqlm, dispatch_awq, dispatch_gptq, dispatch_megatron, dispatch_default])
+        dispatchers.extend(
+            [dispatch_eetq, dispatch_aqlm, dispatch_awq, dispatch_gptq, dispatch_megatron, dispatch_default]
+        )

        new_module = None
        for dispatcher in dispatchers:
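
`_create_new_module` walks this list in order and uses the first dispatcher that returns a module, which is why `dispatch_eetq` is placed before `dispatch_default`. A simplified sketch of the dispatch pattern (function name hypothetical):

```py
def first_matching_module(target, adapter_name, dispatchers, **kwargs):
    # Each dispatcher inspects the target layer and returns either a wrapped
    # LoRA module or None; the first non-None result wins.
    for dispatcher in dispatchers:
        new_module = dispatcher(target, adapter_name, **kwargs)
        if new_module is not None:
            return new_module
    return None
```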
6 changes: 4 additions & 2 deletions src/peft/utils/other.py
@@ -95,22 +95,24 @@ def prepare_model_for_kbit_training(model, use_gradient_checkpointing=True, grad
    loaded_in_kbit = getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False)
    is_gptq_quantized = getattr(model, "quantization_method", None) == "gptq"
    is_aqlm_quantized = getattr(model, "quantization_method", None) == "aqlm"
+    is_eetq_quantized = getattr(model, "quantization_method", None) == "eetq"

    if gradient_checkpointing_kwargs is None:
        gradient_checkpointing_kwargs = {}

    for name, param in model.named_parameters():
        # freeze base model's layers
        param.requires_grad = False

-    if not is_gptq_quantized and not is_aqlm_quantized:
+    if not is_gptq_quantized and not is_aqlm_quantized and not is_eetq_quantized:
        # cast all non INT8 parameters to fp32
        for param in model.parameters():
            if (
                (param.dtype == torch.float16) or (param.dtype == torch.bfloat16)
            ) and param.__class__.__name__ != "Params4bit":
                param.data = param.data.to(torch.float32)

-    if (loaded_in_kbit or is_gptq_quantized or is_aqlm_quantized) and use_gradient_checkpointing:
+    if (loaded_in_kbit or is_gptq_quantized or is_aqlm_quantized or is_eetq_quantized) and use_gradient_checkpointing:
        # When having `use_reentrant=False` + gradient_checkpointing, there is no need for this hack
        if "use_reentrant" not in gradient_checkpointing_kwargs or gradient_checkpointing_kwargs["use_reentrant"]:
            # For backward compatibility
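
With these checks in place, an EETQ-quantized model can go through the usual k-bit preparation step before wrapping. A minimal sketch, reusing `model` and `config` from the documentation example above:

```py
from peft import get_peft_model, prepare_model_for_kbit_training

# Freezes the base weights and enables gradient checkpointing; the fp32 upcast
# is skipped because quantization_method == "eetq".
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)
model = get_peft_model(model, config)
```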
