Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Don't throw error if ForwardContext is not available #730

Merged
merged 3 commits into from
Aug 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions src/adapters/methods/bottleneck.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,13 +204,15 @@ def mean(self, states: List[BottleneckState], weights: torch.Tensor) -> Bottlene
def compose_single(self, adapter_setup: str, state: BottleneckState, lvl: int = 0) -> BottleneckState:
adapter_layer = self.adapters[adapter_setup]
context = ForwardContext.get_context()
output_gating = context.output_adapter_gating_scores if context is not None else False
layer_output = adapter_layer(
state.hidden_states,
residual_input=state.adapter_residual,
output_gating=context.output_adapter_gating_scores,
output_gating=output_gating,
)
hidden_states, up = layer_output[0], layer_output[2]
self._store_gating_score(adapter_setup, layer_output[-1])
if output_gating:
self._store_gating_score(adapter_setup, layer_output[-1])

return state._replace(hidden_states=hidden_states, bottleneck_up=up, last=adapter_setup)

Expand Down Expand Up @@ -246,14 +248,15 @@ def compose_fuse(self, adapter_setup: Fuse, state: BottleneckState, lvl: int = 0
up_list = torch.stack([state.bottleneck_up for state in children_states])
up_list = up_list.permute(1, 2, 0, 3)

output_fusion_attns = context.output_adapter_fusion_attentions if context is not None else False
fusion_output = self.adapter_fusion_layer[adapter_setup.name](
query,
up_list,
up_list,
state.adapter_residual,
output_attentions=context.output_adapter_fusion_attentions,
output_attentions=output_fusion_attns,
)
if context.output_adapter_fusion_attentions:
if output_fusion_attns:
hidden_states = fusion_output[0]
self._store_fusion_attentions(adapter_setup.name, fusion_output[-1])
else:
Expand Down
4 changes: 3 additions & 1 deletion src/adapters/models/llama/modeling_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,10 +398,11 @@ def forward(
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
past_key_value: Optional[Cache] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Expand Down Expand Up @@ -437,6 +438,7 @@ def forward(
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = self.attention_adapters(hidden_states, residual, None)
Expand Down
5 changes: 5 additions & 0 deletions src/adapters/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ def __init__(
optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = None,
):
if model is not None:
model_quantized = getattr(model, "is_quantized", False)
model.is_quantized = False
super().__init__(
model,
args,
Expand All @@ -55,6 +58,8 @@ def __init__(
optimizers=optimizers,
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
)
if model is not None:
model.is_quantized = model_quantized

if adapter_names is not None:
self.model.set_active_adapters(adapter_names)
Expand Down
86 changes: 85 additions & 1 deletion tests/test_adapter_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,27 @@
from tempfile import TemporaryDirectory

import torch
from datasets import Dataset

import adapters
from adapters import AutoAdapterModel
from adapters.composition import Fuse, Stack
from adapters.trainer import AdapterTrainer, logger
from parameterized import parameterized
from transformers import (
AutoModelForCausalLM,
AutoModelForSequenceClassification,
AutoTokenizer,
BertConfig,
BertForSequenceClassification,
BitsAndBytesConfig,
DataCollatorForLanguageModeling,
GlueDataset,
GlueDataTrainingArguments,
Trainer,
TrainingArguments,
)
from transformers.testing_utils import require_ray, slow
from transformers.testing_utils import require_bitsandbytes, require_ray, slow, torch_device


class TestAdapterTrainer(unittest.TestCase):
Expand Down Expand Up @@ -536,6 +541,85 @@ def model_init(trail=None):

trainer.hyperparameter_search(direction="minimize", hp_space=hp_space, backend="ray", n_trials=2)

@parameterized.expand(["lora", "seq_bn"])
@require_bitsandbytes
def test_quantized_training(self, config):
model_name = "HuggingFaceM4/tiny-random-LlamaForCausalLM"

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token

dataset = Dataset.from_dict({"text": ["Hello, I'm a single sentence!", "This is another sentence."]})

def tokenize(element):
return tokenizer(
element["text"],
truncation=True,
max_length=512, # can set to longer values such as 2048
add_special_tokens=False,
)

dataset_tokenized = dataset.map(tokenize, batched=True, remove_columns=["text"])

model = AutoModelForCausalLM.from_pretrained(
model_name,
device_map="auto",
quantization_config=BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_use_double_quant=True,
bnb_4bit_compute_dtype=torch.bfloat16,
),
torch_dtype=torch.bfloat16,
)
model.config.use_cache = False

adapters.init(model)
model.add_adapter("task")
model.train_adapter("task")

model.adapter_to("task", device=torch_device)

for param in model.parameters():
if param.ndim == 1:
# cast the small parameters (e.g. layernorm) to fp32 for stability
param.data = param.data.to(torch.float32)

model.gradient_checkpointing_enable()
model.enable_input_require_grads()

class CastOutputToFloat(torch.nn.Sequential):
def forward(self, x):
return super().forward(x).to(torch.float32)

model.lm_head = CastOutputToFloat(model.lm_head)

self.assertEqual(Stack("task"), model.active_adapters)
with TemporaryDirectory() as tempdir:
training_args = TrainingArguments(
output_dir=tempdir,
per_device_train_batch_size=1,
per_device_eval_batch_size=1,
evaluation_strategy="steps",
logging_steps=10,
max_steps=5,
lr_scheduler_type="constant",
optim="paged_adamw_32bit",
learning_rate=0.0002,
group_by_length=True,
bf16=True,
max_grad_norm=0.3,
)
trainer = AdapterTrainer(
model=model,
tokenizer=tokenizer,
data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
train_dataset=dataset_tokenized,
args=training_args,
)

trainer.train()


if __name__ == "__main__":
unittest.main()
Loading