Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FEAT: Add EETQ support in PEFT #1675

Merged
merged 23 commits into from
Apr 26, 2024
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions docker/peft-gpu/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ RUN apt-get update && \
apt-get clean && \
rm -rf /var/lib/apt/lists*

# Add eetq for quantization testing
RUN source activate peft && \
python3 -m pip install git+https://github.com/NetEase-FuXi/EETQ.git

# Activate the conda env and install transformers + accelerate from source
RUN source activate peft && \
python3 -m pip install -U --no-cache-dir \
Expand Down
36 changes: 36 additions & 0 deletions docs/source/developer_guides/quantization.md
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,42 @@ quantized_model = get_peft_model(quantized_model, peft_config)

You can refer to the [Google Colab](https://colab.research.google.com/drive/12GTp1FCj5_0SnnNQH18h_2XFh9vS_guX?usp=sharing) example for an overview of AQLM+LoRA finetuning.

## EETQ quantization

You can also perform LoRA fine-tuning on EETQ quantized models. [EETQ](https://github.com/NetEase-FuXi/EETQ) package offers simple and efficient way to perform 8-bit quantization, which is claimed to be faster than the LLM.int8() algorithm.
younesbelkada marked this conversation as resolved.
Show resolved Hide resolved

```py
import torch
from transformers import EetqConfig

config = EetqConfig("int8")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This probably requires the latest transformers, right? Maybe worth adding the min version?

```

Pass the `config` to the [`~transformers.AutoModelForCausalLM.from_pretrained`] method.

```py
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", quantization_config=config)
```

and create a `LoraConfig` and pass it to `get_peft_model`:

```py
from peft import LoraConfig, get_peft_model

config = LoraConfig(
r=16,
lora_alpha=8,
target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM"
)

model = get_peft_model(model, config)
```

## Next steps

If you're interested in learning more about quantization, the following may be helpful:
Expand Down
5 changes: 5 additions & 0 deletions src/peft/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,8 @@ def is_aqlm_available():
@lru_cache
def is_auto_awq_available():
return importlib.util.find_spec("awq") is not None


@lru_cache
def is_eetq_available():
return importlib.util.find_spec("eetq") is not None
7 changes: 6 additions & 1 deletion src/peft/tuners/lora/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from peft.import_utils import is_bnb_4bit_available, is_bnb_available
from peft.import_utils import is_bnb_4bit_available, is_bnb_available, is_eetq_available

from .config import LoftQConfig, LoraConfig
from .gptq import QuantLinear
Expand All @@ -34,4 +34,9 @@ def __getattr__(name):

return Linear4bit

if (name == "EetqLoraLinear") and is_eetq_available():
from .eetq import EetqLoraLinear

return EetqLoraLinear

raise AttributeError(f"module {__name__} has no attribute {name}")
104 changes: 104 additions & 0 deletions src/peft/tuners/lora/eetq.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
# Copyright 2024-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Optional

import torch

from peft.import_utils import is_eetq_available
from peft.tuners.lora.layer import LoraLayer
from peft.tuners.tuners_utils import BaseTunerLayer


if is_eetq_available():
from eetq import EetqLinear
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we have lazy import as for bnb or is it not necessary for EETQ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Make sense!

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What I mean is should we indent all the code below to be inside of if is_eetq_available():? Or is it not necessary because, unlike bnb, EETQ does not initialize cuda?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it does, let's indent it to be on the safe zon


class EetqLoraLinear(torch.nn.Module, LoraLayer):
def __init__(
self,
base_layer,
adapter_name,
r: int = 0,
lora_alpha: int = 1,
lora_dropout: float = 0.0,
init_lora_weights: bool = True,
use_rslora: bool = False,
**kwargs,
):
super().__init__()
LoraLayer.__init__(self, base_layer)

# self.base_layer and self.quant_linear_module are the same; we need the former for consistency and the latter
# for backwards compatibility
self.quant_linear_module = base_layer

self._active_adapter = adapter_name
self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora)

def forward(self, x: torch.Tensor):
result = self.quant_linear_module(x)

if self.disable_adapters:
return result

for active_adapter in self.active_adapters:
if active_adapter not in self.lora_A.keys():
continue
lora_A = self.lora_A[active_adapter]
lora_B = self.lora_B[active_adapter]
dropout = self.lora_dropout[active_adapter]
scaling = self.scaling[active_adapter]

requires_conversion = not torch.is_autocast_enabled()
if requires_conversion:
expected_dtype = result.dtype
x = x.to(lora_A.weight.dtype)

output = lora_B(lora_A(dropout(x)))
if requires_conversion:
output = output.to(expected_dtype)
output = output * scaling
result = result + output
return result

def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None:
raise AttributeError("Merging LoRA layers is not supported for Eetq layers.")

def unmerge(self) -> None:
raise AttributeError("Unmerging LoRA layers is not supported for Eetq layers.")

def __repr__(self) -> str:
rep = super().__repr__()
return "lora." + rep


def dispatch_eetq(
target: torch.nn.Module,
adapter_name: str,
**kwargs: Any,
) -> Optional[torch.nn.Module]:
new_module = None

if isinstance(target, BaseTunerLayer):
target_base_layer = target.get_base_layer()
else:
target_base_layer = target

if is_eetq_available() and isinstance(target_base_layer, EetqLinear):
new_module = EetqLoraLinear(target, adapter_name, **kwargs)
target.weight = target_base_layer.weight

if hasattr(target, "bias"):
target.bias = target_base_layer.bias

return new_module
3 changes: 3 additions & 0 deletions src/peft/tuners/lora/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ def __init__(self, base_layer: nn.Module, **kwargs) -> None:
elif hasattr(base_layer, "w_bit") and base_layer.__class__.__name__ == "WQLinear_GEMM":
# Awq layers
in_features, out_features = base_layer.in_features, base_layer.out_features
elif base_layer.__class__.__name__ == "EetqLinear":
# Eetq layers
in_features, out_features = base_layer.in_features, base_layer.out_features
else:
raise ValueError(f"Unsupported layer type {type(base_layer)}")

Expand Down
5 changes: 4 additions & 1 deletion src/peft/tuners/lora/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
from .aqlm import dispatch_aqlm
from .awq import dispatch_awq
from .config import LoraConfig
from .eetq import dispatch_eetq
from .gptq import dispatch_gptq
from .layer import Conv2d, LoraLayer, dispatch_default
from .tp_layer import dispatch_megatron
Expand Down Expand Up @@ -288,7 +289,9 @@ def _create_new_module(lora_config, adapter_name, target, **kwargs):

dispatchers.append(dispatch_bnb_4bit)

dispatchers.extend([dispatch_aqlm, dispatch_awq, dispatch_gptq, dispatch_megatron, dispatch_default])
dispatchers.extend(
[dispatch_eetq, dispatch_aqlm, dispatch_awq, dispatch_gptq, dispatch_megatron, dispatch_default]
)

new_module = None
for dispatcher in dispatchers:
Expand Down
6 changes: 4 additions & 2 deletions src/peft/utils/other.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,22 +95,24 @@ def prepare_model_for_kbit_training(model, use_gradient_checkpointing=True, grad
loaded_in_kbit = getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False)
is_gptq_quantized = getattr(model, "quantization_method", None) == "gptq"
is_aqlm_quantized = getattr(model, "quantization_method", None) == "aqlm"
is_eetq_quantized = getattr(model, "quantization_method", None) == "eetq"

if gradient_checkpointing_kwargs is None:
gradient_checkpointing_kwargs = {}

for name, param in model.named_parameters():
# freeze base model's layers
param.requires_grad = False

if not is_gptq_quantized and not is_aqlm_quantized:
if not is_gptq_quantized and not is_aqlm_quantized and not is_eetq_quantized:
# cast all non INT8 parameters to fp32
for param in model.parameters():
if (
(param.dtype == torch.float16) or (param.dtype == torch.bfloat16)
) and param.__class__.__name__ != "Params4bit":
param.data = param.data.to(torch.float32)

if (loaded_in_kbit or is_gptq_quantized or is_aqlm_quantized) and use_gradient_checkpointing:
if (loaded_in_kbit or is_gptq_quantized or is_aqlm_quantized or is_eetq_quantized) and use_gradient_checkpointing:
# When having `use_reentrant=False` + gradient_checkpointing, there is no need for this hack
if "use_reentrant" not in gradient_checkpointing_kwargs or gradient_checkpointing_kwargs["use_reentrant"]:
# For backward compatibility
Expand Down
Loading
Loading