Enable parallel sequence generation with adapters #436

Merged: 3 commits, Oct 27, 2022
2 changes: 1 addition & 1 deletion src/transformers/adapters/context.py
@@ -78,7 +78,7 @@ class ForwardContext:
# thread-local storage that holds a stack of active contexts
storage = threading.local()

context_attributes = ["adapter_gating_scores", "adapter_fusion_attentions"]
context_attributes = ["adapter_gating_scores", "adapter_fusion_attentions", "adapter_input_parallelized"]

def __init__(self, model, *args, **kwargs):
# If the model has a method ``forward_context()``, use it to create the context.
4 changes: 4 additions & 0 deletions src/transformers/adapters/heads/base.py
@@ -31,6 +31,10 @@ class MultiHeadOutput(ModelOutput):
head_outputs: List[ModelOutput] = None
loss: Optional[torch.FloatTensor] = None

@property
def logits(self):
return torch.vstack([outputs["logits"] for outputs in self.head_outputs])

def __getitem__(self, k):
# with number indices the head output at that position is accessed
# e.g output[1] is equivalent to output.head_outputs[1]
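The new `logits` property above flattens the per-head outputs into a single tensor, so generation code that expects one `logits` attribute also works when several heads run in parallel. A minimal sketch of the stacking behaviour, with hypothetical shapes (not part of the PR):

```python
import torch

# Two parallel heads over the same batch of 1 prompt, sequence length 4,
# vocabulary size 5 (all sizes hypothetical).
head_outputs = [
    {"logits": torch.randn(1, 4, 5)},
    {"logits": torch.randn(1, 4, 5)},
]

# torch.vstack concatenates along the batch dimension, so the combined
# tensor has shape (num_heads * batch, seq_len, vocab_size), matching what
# generate() expects from a single-head model.
logits = torch.vstack([out["logits"] for out in head_outputs])
assert logits.shape == (2, 4, 5)
```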
5 changes: 5 additions & 0 deletions src/transformers/adapters/model_mixin.py
@@ -779,6 +779,11 @@ def forward_context(self, context: ForwardContext, *args, **kwargs):
return

context.adapters_parallelized = False
# Check if already parallelized in encoder
adapter_input_parallelized = kwargs.pop("adapter_input_parallelized", None)
if adapter_input_parallelized:
if active_adapters.parallel_channels > 1:
context.adapters_parallelized = True
# Add the shared parameters for the active adapters to the context
context.shared_parameters = {
name: param for name, param in self.shared_parameters.items() if name in active_adapters.flatten()
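The flag exists to prevent the inputs from being replicated twice: `generate()` (see the change in `generation_utils.py` below) repeats the prompt once per parallel channel, so the `Parallel` composition block must not repeat it again during the forward pass. A rough illustration of the batch-size arithmetic, with hypothetical shapes:

```python
import torch

parallel_channels = 2
input_ids = torch.randint(0, 100, (1, 4))  # a single prompt of length 4

# generate() pre-replicates the prompt once per parallel adapter channel ...
input_ids = input_ids.repeat(parallel_channels, 1)  # shape (2, 4)

# ... so if the Parallel block replicated again, the model would see a batch
# of 4 while generate() still tracks a batch of 2, and shapes would no longer
# line up. Setting context.adapters_parallelized = True skips that second
# replication.
assert input_ids.shape == (2, 4)
```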
2 changes: 2 additions & 0 deletions src/transformers/adapters/models/bart/adapter_model.py
@@ -89,6 +89,7 @@ def forward(
past_key_values=past_key_values,
output_adapter_gating_scores=output_adapter_gating_scores,
output_adapter_fusion_attentions=output_adapter_fusion_attentions,
adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
)
# sequence classification based on last token in sequence
x = outputs[0] # last hidden state
@@ -139,6 +140,7 @@ def prepare_inputs_for_generation(
"decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
"adapter_input_parallelized": kwargs.pop("adapter_input_parallelized", False),
}

# Copied from BartForConditionalGeneration
1 change: 1 addition & 0 deletions src/transformers/adapters/models/beit/adapter_model.py
@@ -48,6 +48,7 @@ def forward(
return_dict=return_dict,
output_adapter_gating_scores=output_adapter_gating_scores,
output_adapter_fusion_attentions=output_adapter_fusion_attentions,
adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
)

# BERT & RoBERTa return the pooled output as second item, we don't need that in these heads
19 changes: 19 additions & 0 deletions src/transformers/adapters/models/bert/adapter_model.py
@@ -72,6 +72,7 @@ def forward(
return_dict=return_dict,
output_adapter_gating_scores=output_adapter_gating_scores,
output_adapter_fusion_attentions=output_adapter_fusion_attentions,
adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
)
# BERT & RoBERTa return the pooled output as second item, we don't need that in these heads
if not return_dict:
@@ -94,6 +95,24 @@ def forward(
# in case no head is used just return the output of the base model (including pooler output)
return outputs

# Copied from BertLMHeadModel
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
input_shape = input_ids.shape
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_shape)

# cut decoder_input_ids if past is used
if past is not None:
input_ids = input_ids[:, -1:]

return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"past_key_values": past,
"adapter_input_parallelized": model_kwargs.pop("adapter_input_parallelized", False),
}

head_types = {
"classification": ClassificationHead,
"multilabel_classification": MultiLabelClassificationHead,
1 change: 1 addition & 0 deletions src/transformers/adapters/models/deberta/adapter_model.py
@@ -65,6 +65,7 @@ def forward(
return_dict=return_dict,
output_adapter_gating_scores=output_adapter_gating_scores,
output_adapter_fusion_attentions=output_adapter_fusion_attentions,
adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
)
# BERT & RoBERTa return the pooled output as second item, we don't need that in these heads
if not return_dict:
@@ -68,6 +68,7 @@ def forward(
return_dict=return_dict,
output_adapter_gating_scores=output_adapter_gating_scores,
output_adapter_fusion_attentions=output_adapter_fusion_attentions,
adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
)
# BERT & RoBERTa return the pooled output as second item, we don't need that in these heads
if not return_dict:
19 changes: 19 additions & 0 deletions src/transformers/adapters/models/distilbert/adapter_model.py
@@ -94,6 +94,7 @@ def forward(
return_dict=return_dict,
output_adapter_gating_scores=output_adapter_gating_scores,
output_adapter_fusion_attentions=output_adapter_fusion_attentions,
adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
)

outputs = self.forward_head(
@@ -102,6 +103,24 @@

return outputs

# Copied from RobertaForCausalLM
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
input_shape = input_ids.shape
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_shape)

# cut decoder_input_ids if past is used
if past is not None:
input_ids = input_ids[:, -1:]

return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"past_key_values": past,
"adapter_input_parallelized": model_kwargs.pop("adapter_input_parallelized", False),
}

head_types = {
"classification": ClassificationHead,
"multilabel_classification": MultiLabelClassificationHead,
2 changes: 2 additions & 0 deletions src/transformers/adapters/models/gpt2/adapter_model.py
@@ -81,6 +81,7 @@ def forward(
return_dict=return_dict,
output_adapter_gating_scores=output_adapter_gating_scores,
output_adapter_fusion_attentions=output_adapter_fusion_attentions,
adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
)

batch_size = outputs[0].shape[0]
@@ -139,6 +140,7 @@ def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
"position_ids": position_ids,
"attention_mask": attention_mask,
"token_type_ids": token_type_ids,
"adapter_input_parallelized": kwargs.pop("adapter_input_parallelized", False),
}

head_types = {
2 changes: 2 additions & 0 deletions src/transformers/adapters/models/gptj/adapter_model.py
@@ -77,6 +77,7 @@ def forward(
return_dict=return_dict,
output_adapter_gating_scores=output_adapter_gating_scores,
output_adapter_fusion_attentions=output_adapter_fusion_attentions,
adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
)

batch_size = outputs[0].shape[0]
@@ -135,6 +136,7 @@ def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
"position_ids": position_ids,
"attention_mask": attention_mask,
"token_type_ids": token_type_ids,
"adapter_input_parallelized": kwargs.pop("adapter_input_parallelized", False),
}

head_types = {
2 changes: 2 additions & 0 deletions src/transformers/adapters/models/mbart/adapter_model.py
@@ -89,6 +89,7 @@ def forward(
past_key_values=past_key_values,
output_adapter_gating_scores=output_adapter_gating_scores,
output_adapter_fusion_attentions=output_adapter_fusion_attentions,
adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
)
# sequence classification based on last token in sequence
x = outputs[0] # last hidden state
@@ -139,6 +140,7 @@ def prepare_inputs_for_generation(
"decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
"adapter_input_parallelized": kwargs.pop("adapter_input_parallelized", False),
}

# Copied from MBartForConditionalGeneration
19 changes: 19 additions & 0 deletions src/transformers/adapters/models/roberta/adapter_model.py
@@ -77,6 +77,7 @@ def forward(
return_dict=return_dict,
output_adapter_gating_scores=output_adapter_gating_scores,
output_adapter_fusion_attentions=output_adapter_fusion_attentions,
adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
)
# BERT & RoBERTa return the pooled output as second item, we don't need that in these heads
if not return_dict:
@@ -99,6 +100,24 @@ def forward(
# in case no head is used just return the output of the base model (including pooler output)
return outputs

# Copied from RobertaForCausalLM
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
input_shape = input_ids.shape
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_shape)

# cut decoder_input_ids if past is used
if past is not None:
input_ids = input_ids[:, -1:]

return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"past_key_values": past,
"adapter_input_parallelized": model_kwargs.pop("adapter_input_parallelized", False),
}

head_types = {
"classification": ClassificationHead,
"multilabel_classification": MultiLabelClassificationHead,
3 changes: 2 additions & 1 deletion src/transformers/adapters/models/t5/adapter_model.py
@@ -80,6 +80,7 @@ def forward(
return_dict=return_dict,
output_adapter_gating_scores=output_adapter_gating_scores,
output_adapter_fusion_attentions=output_adapter_fusion_attentions,
adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
)
sequence_output = model_output[0]
# ToDo move head to device for parallel forward pass
@@ -118,7 +119,6 @@ def prepare_inputs_for_generation(
encoder_outputs=None,
**kwargs
):

# cut decoder_input_ids if past is used
if past is not None:
input_ids = input_ids[:, -1:]
@@ -132,6 +132,7 @@ def prepare_inputs_for_generation(
"decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
"use_cache": use_cache,
"adapter_input_parallelized": kwargs.pop("adapter_input_parallelized", False),
}

# Copied from T5ForConditionalGeneration
1 change: 1 addition & 0 deletions src/transformers/adapters/models/vit/adapter_model.py
@@ -48,6 +48,7 @@ def forward(
return_dict=return_dict,
output_adapter_gating_scores=output_adapter_gating_scores,
output_adapter_fusion_attentions=output_adapter_fusion_attentions,
adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
)

# BERT & RoBERTa return the pooled output as second item, we don't need that in these heads
2 changes: 1 addition & 1 deletion src/transformers/dependency_versions_table.py
@@ -22,7 +22,7 @@
"fugashi": "fugashi>=1.0",
"GitPython": "GitPython<3.1.19",
"hf-doc-builder": "hf-doc-builder>=0.3.0",
"huggingface-hub": "huggingface-hub>=0.1.0,<0.8.0",
"huggingface-hub": "huggingface-hub>=0.1.0,<1.0",
"importlib_metadata": "importlib_metadata",
"ipadic": "ipadic>=1.0.0,<2.0",
"isort": "isort>=5.5.4",
12 changes: 12 additions & 0 deletions src/transformers/generation_utils.py
@@ -23,6 +23,7 @@
import torch.distributed as dist
from torch import nn

from .adapters.composition import adjust_tensors_for_parallel
from .adapters.context import ForwardContext
from .generation_beam_constraints import Constraint, DisjunctiveConstraint, PhrasalConstraint
from .generation_beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
@@ -1198,6 +1199,17 @@ def generate(
# if decoder-only then inputs_tensor has to be `input_ids`
input_ids = inputs_tensor

# Pre-replicate inputs for parallel adapters to avoid issues within generation code
if (
hasattr(self.config, "adapters")
and self.config.adapters.active_setup
and self.config.adapters.active_setup.parallel_channels > 1
):
input_ids = input_ids.repeat(self.config.adapters.active_setup.parallel_channels, 1)
model_kwargs["adapter_input_parallelized"] = True
(attention_mask,) = adjust_tensors_for_parallel(input_ids, model_kwargs["attention_mask"])
model_kwargs["attention_mask"] = attention_mask

# 5. Prepare `max_length` depending on other stopping criteria.
input_ids_seq_length = input_ids.shape[-1]
if max_length is None and max_new_tokens is None:
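The pre-replication in `generate()` is a plain `repeat` along the batch dimension, with `adjust_tensors_for_parallel` used to bring the attention mask to the same batch size. A small self-contained sketch of the same idea (a simple `repeat` stands in for the adapter helper; shapes are hypothetical):

```python
import torch

def replicate_for_parallel(input_ids: torch.Tensor,
                           attention_mask: torch.Tensor,
                           parallel_channels: int):
    """Hypothetical helper mirroring the pre-replication done in generate()."""
    input_ids = input_ids.repeat(parallel_channels, 1)
    # adjust_tensors_for_parallel repeats other tensors so their batch
    # dimension matches input_ids; a plain repeat has the same effect here.
    attention_mask = attention_mask.repeat(parallel_channels, 1)
    return input_ids, attention_mask

ids = torch.tensor([[101, 7592, 2088, 102]])
mask = torch.ones_like(ids)
ids, mask = replicate_for_parallel(ids, mask, parallel_channels=2)
assert ids.shape == (2, 4) and mask.shape == (2, 4)
```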
28 changes: 28 additions & 0 deletions tests_adapters/test_adapter_composition.py
@@ -15,6 +15,7 @@
Trainer,
TrainingArguments,
)
from transformers.adapters import ADAPTER_MODEL_MAPPING
from transformers.adapters.composition import BatchSplit, Fuse, Parallel, Split, Stack, parse_composition
from transformers.testing_utils import require_torch, torch_device

@@ -245,6 +246,33 @@ def test_batch_split_with_heads(self):
)
)

def test_parallel_generate(self):
if self.config_class not in ADAPTER_MODEL_MAPPING or (
not hasattr(ADAPTER_MODEL_MAPPING[self.config_class], "add_seq2seq_lm_head")
and not hasattr(ADAPTER_MODEL_MAPPING[self.config_class], "add_causal_lm_head")
):
self.skipTest("No seq2seq or causal language model head")

model1 = AutoAdapterModel.from_config(self.config())
model1.add_adapter("adapter1")
model1.add_adapter("adapter2")
if hasattr(model1, "add_seq2seq_lm_head"):
model1.add_seq2seq_lm_head("adapter1")
model1.add_seq2seq_lm_head("adapter2")
else:
model1.add_causal_lm_head("adapter1")
model1.add_causal_lm_head("adapter2")
model1.set_active_adapters(Parallel("adapter1", "adapter2"))
model1.to(torch_device)

seq_output_length = 32

# Finally, also check if generation works properly
input_ids = self.get_input_samples((1, 4), config=model1.config)["input_ids"]
input_ids = input_ids.to(torch_device)
generated = model1.generate(input_ids, max_length=seq_output_length)
self.assertLessEqual(generated.shape, (2, seq_output_length))


class ParallelTrainingMixin:
def create_twin_adapters(self, model, name):
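For reference, a minimal end-to-end usage sketch of what this PR enables: generating with two adapters in a `Parallel` block and getting one sequence per adapter back. Model and adapter names are placeholders, and the import paths assume the adapter-transformers fork this PR targets:

```python
from transformers import AutoTokenizer
from transformers.adapters import AutoAdapterModel
from transformers.adapters.composition import Parallel

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoAdapterModel.from_pretrained("gpt2")

# Two independent adapters, each with its own causal LM head.
model.add_adapter("adapter1")
model.add_adapter("adapter2")
model.add_causal_lm_head("adapter1")
model.add_causal_lm_head("adapter2")
model.set_active_adapters(Parallel("adapter1", "adapter2"))

# One prompt in, one generated sequence per parallel channel out.
input_ids = tokenizer("Hello, my name is", return_tensors="pt").input_ids
generated = model.generate(input_ids, max_length=20)
print(generated.shape)  # expected batch dimension of 2: one sequence per adapter
```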