Enable parallel sequence generation with adapters (#436)
calpt authored Oct 27, 2022
1 parent 1bc2da4 commit 7855c9d
Showing 18 changed files with 122 additions and 3 deletions.
2 changes: 1 addition & 1 deletion src/transformers/adapters/context.py
@@ -78,7 +78,7 @@ class ForwardContext:
# thread-local storage that holds a stack of active contexts
storage = threading.local()

context_attributes = ["adapter_gating_scores", "adapter_fusion_attentions"]
context_attributes = ["adapter_gating_scores", "adapter_fusion_attentions", "adapter_input_parallelized"]

def __init__(self, model, *args, **kwargs):
# If the model has a method ``forward_context()``, use it to create the context.
4 changes: 4 additions & 0 deletions src/transformers/adapters/heads/base.py
@@ -31,6 +31,10 @@ class MultiHeadOutput(ModelOutput):
head_outputs: List[ModelOutput] = None
loss: Optional[torch.FloatTensor] = None

@property
def logits(self):
return torch.vstack([outputs["logits"] for outputs in self.head_outputs])

def __getitem__(self, k):
# with number indices the head output at that position is accessed
# e.g output[1] is equivalent to output.head_outputs[1]
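A toy illustration of the new logits property (shapes assumed, not taken from the diff): each entry in head_outputs contributes logits of shape (batch_size, num_labels), and torch.vstack stacks them along dim 0, matching the channel-replicated batch produced by a Parallel block.

import torch

logits_head1 = torch.randn(4, 2)   # (batch_size=4, num_labels=2) from the first head
logits_head2 = torch.randn(4, 2)   # same shape from the second head
combined = torch.vstack([logits_head1, logits_head2])
print(combined.shape)              # torch.Size([8, 2])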
5 changes: 5 additions & 0 deletions src/transformers/adapters/model_mixin.py
@@ -779,6 +779,11 @@ def forward_context(self, context: ForwardContext, *args, **kwargs):
return

context.adapters_parallelized = False
# Check if already parallelized in encoder
adapter_input_parallelized = kwargs.pop("adapter_input_parallelized", None)
if adapter_input_parallelized:
if active_adapters.parallel_channels > 1:
context.adapters_parallelized = True
# Add the shared parameters for the active adapters to the context
context.shared_parameters = {
name: param
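A minimal usage sketch of the new keyword argument for callers outside of generate(), assuming the adapter-transformers API (AutoAdapterModel, add_adapter, add_classification_head, Parallel) and a tiny random-weight BERT config chosen only for illustration: if the batch was already repeated once per parallel channel, the flag marks the context as parallelized so the Parallel block does not repeat the inputs again.

import torch
from transformers import AutoAdapterModel, BertConfig
from transformers.adapters.composition import Parallel

config = BertConfig(hidden_size=32, num_hidden_layers=2, num_attention_heads=2, intermediate_size=64)
model = AutoAdapterModel.from_config(config)
model.add_adapter("a")
model.add_adapter("b")
model.add_classification_head("a")
model.add_classification_head("b")
model.set_active_adapters(Parallel("a", "b"))

input_ids = torch.randint(0, config.vocab_size, (1, 8)).repeat(2, 1)  # one copy of the batch per parallel channel
outputs = model(
    input_ids=input_ids,
    attention_mask=torch.ones_like(input_ids),
    adapter_input_parallelized=True,  # inputs are already replicated, skip in-model replication
)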
2 changes: 2 additions & 0 deletions src/transformers/adapters/models/bart/adapter_model.py
@@ -89,6 +89,7 @@ def forward(
past_key_values=past_key_values,
output_adapter_gating_scores=output_adapter_gating_scores,
output_adapter_fusion_attentions=output_adapter_fusion_attentions,
adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
)
# sequence classification based on last token in sequence
x = outputs[0] # last hidden state
@@ -139,6 +140,7 @@ def prepare_inputs_for_generation(
"decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
"adapter_input_parallelized": kwargs.pop("adapter_input_parallelized", False),
}

# Copied from BartForConditionalGeneration
1 change: 1 addition & 0 deletions src/transformers/adapters/models/beit/adapter_model.py
@@ -48,6 +48,7 @@ def forward(
return_dict=return_dict,
output_adapter_gating_scores=output_adapter_gating_scores,
output_adapter_fusion_attentions=output_adapter_fusion_attentions,
adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
)

# BERT & RoBERTa return the pooled output as second item, we don't need that in these heads
19 changes: 19 additions & 0 deletions src/transformers/adapters/models/bert/adapter_model.py
@@ -72,6 +72,7 @@ def forward(
return_dict=return_dict,
output_adapter_gating_scores=output_adapter_gating_scores,
output_adapter_fusion_attentions=output_adapter_fusion_attentions,
adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
)
# BERT & RoBERTa return the pooled output as second item, we don't need that in these heads
if not return_dict:
@@ -94,6 +95,24 @@ def forward(
# in case no head is used just return the output of the base model (including pooler output)
return outputs

# Copied from BertLMHeadModel
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
input_shape = input_ids.shape
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_shape)

# cut decoder_input_ids if past is used
if past is not None:
input_ids = input_ids[:, -1:]

return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"past_key_values": past,
"adapter_input_parallelized": model_kwargs.pop("adapter_input_parallelized", False),
}

head_types = {
"classification": ClassificationHead,
"multilabel_classification": MultiLabelClassificationHead,
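A hedged illustration of how the flag travels through the decoding loop (tiny random-weight model, names chosen only for illustration): generate() stores it in model_kwargs, the method added above forwards it at every step, and forward() pops it again from **kwargs.

import torch
from transformers import AutoAdapterModel, BertConfig

model = AutoAdapterModel.from_config(BertConfig(hidden_size=32, num_hidden_layers=2, num_attention_heads=2, intermediate_size=64))
input_ids = torch.randint(0, model.config.vocab_size, (2, 4))  # batch already replicated upstream
step_inputs = model.prepare_inputs_for_generation(
    input_ids,
    attention_mask=torch.ones_like(input_ids),
    adapter_input_parallelized=True,
)
assert step_inputs["adapter_input_parallelized"] is True  # forwarded into each decoding step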
1 change: 1 addition & 0 deletions src/transformers/adapters/models/deberta/adapter_model.py
@@ -65,6 +65,7 @@ def forward(
return_dict=return_dict,
output_adapter_gating_scores=output_adapter_gating_scores,
output_adapter_fusion_attentions=output_adapter_fusion_attentions,
adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
)
# BERT & RoBERTa return the pooled output as second item, we don't need that in these heads
if not return_dict:
1 change: 1 addition & 0 deletions (file name not shown)
@@ -68,6 +68,7 @@ def forward(
return_dict=return_dict,
output_adapter_gating_scores=output_adapter_gating_scores,
output_adapter_fusion_attentions=output_adapter_fusion_attentions,
adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
)
# BERT & RoBERTa return the pooled output as second item, we don't need that in these heads
if not return_dict:
19 changes: 19 additions & 0 deletions src/transformers/adapters/models/distilbert/adapter_model.py
@@ -94,6 +94,7 @@ def forward(
return_dict=return_dict,
output_adapter_gating_scores=output_adapter_gating_scores,
output_adapter_fusion_attentions=output_adapter_fusion_attentions,
adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
)

outputs = self.forward_head(
@@ -102,6 +103,24 @@ def forward(

return outputs

# Copied from RobertaForCausalLM
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
input_shape = input_ids.shape
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_shape)

# cut decoder_input_ids if past is used
if past is not None:
input_ids = input_ids[:, -1:]

return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"past_key_values": past,
"adapter_input_parallelized": model_kwargs.pop("adapter_input_parallelized", False),
}

head_types = {
"classification": ClassificationHead,
"multilabel_classification": MultiLabelClassificationHead,
2 changes: 2 additions & 0 deletions src/transformers/adapters/models/gpt2/adapter_model.py
@@ -81,6 +81,7 @@ def forward(
return_dict=return_dict,
output_adapter_gating_scores=output_adapter_gating_scores,
output_adapter_fusion_attentions=output_adapter_fusion_attentions,
adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
)

batch_size = outputs[0].shape[0]
@@ -139,6 +140,7 @@ def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
"position_ids": position_ids,
"attention_mask": attention_mask,
"token_type_ids": token_type_ids,
"adapter_input_parallelized": kwargs.pop("adapter_input_parallelized", False),
}

head_types = {
2 changes: 2 additions & 0 deletions src/transformers/adapters/models/gptj/adapter_model.py
@@ -77,6 +77,7 @@ def forward(
return_dict=return_dict,
output_adapter_gating_scores=output_adapter_gating_scores,
output_adapter_fusion_attentions=output_adapter_fusion_attentions,
adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
)

batch_size = outputs[0].shape[0]
@@ -135,6 +136,7 @@ def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
"position_ids": position_ids,
"attention_mask": attention_mask,
"token_type_ids": token_type_ids,
"adapter_input_parallelized": kwargs.pop("adapter_input_parallelized", False),
}

head_types = {
2 changes: 2 additions & 0 deletions src/transformers/adapters/models/mbart/adapter_model.py
@@ -89,6 +89,7 @@ def forward(
past_key_values=past_key_values,
output_adapter_gating_scores=output_adapter_gating_scores,
output_adapter_fusion_attentions=output_adapter_fusion_attentions,
adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
)
# sequence classification based on last token in sequence
x = outputs[0] # last hidden state
@@ -139,6 +140,7 @@ def prepare_inputs_for_generation(
"decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
"adapter_input_parallelized": kwargs.pop("adapter_input_parallelized", False),
}

# Copied from MBartForConditionalGeneration
19 changes: 19 additions & 0 deletions src/transformers/adapters/models/roberta/adapter_model.py
@@ -77,6 +77,7 @@ def forward(
return_dict=return_dict,
output_adapter_gating_scores=output_adapter_gating_scores,
output_adapter_fusion_attentions=output_adapter_fusion_attentions,
adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
)
# BERT & RoBERTa return the pooled output as second item, we don't need that in these heads
if not return_dict:
@@ -99,6 +100,24 @@ def forward(
# in case no head is used just return the output of the base model (including pooler output)
return outputs

# Copied from RobertaForCausalLM
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
input_shape = input_ids.shape
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_shape)

# cut decoder_input_ids if past is used
if past is not None:
input_ids = input_ids[:, -1:]

return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"past_key_values": past,
"adapter_input_parallelized": model_kwargs.pop("adapter_input_parallelized", False),
}

head_types = {
"classification": ClassificationHead,
"multilabel_classification": MultiLabelClassificationHead,
3 changes: 2 additions & 1 deletion src/transformers/adapters/models/t5/adapter_model.py
@@ -80,6 +80,7 @@ def forward(
return_dict=return_dict,
output_adapter_gating_scores=output_adapter_gating_scores,
output_adapter_fusion_attentions=output_adapter_fusion_attentions,
adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
)
sequence_output = model_output[0]
# ToDo move head to device for parallel forward pass
@@ -118,7 +119,6 @@ def prepare_inputs_for_generation(
encoder_outputs=None,
**kwargs
):

# cut decoder_input_ids if past is used
if past is not None:
input_ids = input_ids[:, -1:]
@@ -132,6 +132,7 @@
"decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
"use_cache": use_cache,
"adapter_input_parallelized": kwargs.pop("adapter_input_parallelized", False),
}

# Copied from T5ForConditionalGeneration
1 change: 1 addition & 0 deletions src/transformers/adapters/models/vit/adapter_model.py
@@ -48,6 +48,7 @@ def forward(
return_dict=return_dict,
output_adapter_gating_scores=output_adapter_gating_scores,
output_adapter_fusion_attentions=output_adapter_fusion_attentions,
adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
)

# BERT & RoBERTa return the pooled output as second item, we don't need that in these heads
2 changes: 1 addition & 1 deletion src/transformers/dependency_versions_table.py
@@ -22,7 +22,7 @@
"fugashi": "fugashi>=1.0",
"GitPython": "GitPython<3.1.19",
"hf-doc-builder": "hf-doc-builder>=0.3.0",
"huggingface-hub": "huggingface-hub>=0.1.0,<0.8.0",
"huggingface-hub": "huggingface-hub>=0.1.0,<1.0",
"importlib_metadata": "importlib_metadata",
"ipadic": "ipadic>=1.0.0,<2.0",
"isort": "isort>=5.5.4",
12 changes: 12 additions & 0 deletions src/transformers/generation_utils.py
@@ -23,6 +23,7 @@
import torch.distributed as dist
from torch import nn

from .adapters.composition import adjust_tensors_for_parallel
from .adapters.context import ForwardContext
from .generation_beam_constraints import Constraint, DisjunctiveConstraint, PhrasalConstraint
from .generation_beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
@@ -1198,6 +1199,17 @@ def generate(
# if decoder-only then inputs_tensor has to be `input_ids`
input_ids = inputs_tensor

# Pre-replicate inputs for parallel adapters to avoid issues within generation code
if (
hasattr(self.config, "adapters")
and self.config.adapters.active_setup
and self.config.adapters.active_setup.parallel_channels > 1
):
input_ids = input_ids.repeat(self.config.adapters.active_setup.parallel_channels, 1)
model_kwargs["adapter_input_parallelized"] = True
(attention_mask,) = adjust_tensors_for_parallel(input_ids, model_kwargs["attention_mask"])
model_kwargs["attention_mask"] = attention_mask

# 5. Prepare `max_length` depending on other stopping criteria.
input_ids_seq_length = input_ids.shape[-1]
if max_length is None and max_new_tokens is None:
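A toy illustration of the pre-replication branch above (shapes assumed; adjust_tensors_for_parallel is assumed to repeat the attention mask along dim 0 so it matches the replicated input_ids):

import torch

parallel_channels = 2
input_ids = torch.tensor([[101, 7592, 2088, 102]])            # (1, 4): a single sample
attention_mask = torch.ones_like(input_ids)

input_ids = input_ids.repeat(parallel_channels, 1)            # (2, 4): one row per adapter channel
attention_mask = attention_mask.repeat(parallel_channels, 1)  # assumed effect of adjust_tensors_for_parallel here
print(input_ids.shape, attention_mask.shape)                  # torch.Size([2, 4]) torch.Size([2, 4])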
28 changes: 28 additions & 0 deletions tests_adapters/test_adapter_composition.py
@@ -15,6 +15,7 @@
Trainer,
TrainingArguments,
)
from transformers.adapters import ADAPTER_MODEL_MAPPING
from transformers.adapters.composition import BatchSplit, Fuse, Parallel, Split, Stack, parse_composition
from transformers.testing_utils import require_torch, torch_device

@@ -245,6 +246,33 @@ def test_batch_split_with_heads(self):
)
)

def test_parallel_generate(self):
if self.config_class not in ADAPTER_MODEL_MAPPING or (
not hasattr(ADAPTER_MODEL_MAPPING[self.config_class], "add_seq2seq_lm_head")
and not hasattr(ADAPTER_MODEL_MAPPING[self.config_class], "add_causal_lm_head")
):
self.skipTest("No seq2seq or causal language model head")

model1 = AutoAdapterModel.from_config(self.config())
model1.add_adapter("adapter1")
model1.add_adapter("adapter2")
if hasattr(model1, "add_seq2seq_lm_head"):
model1.add_seq2seq_lm_head("adapter1")
model1.add_seq2seq_lm_head("adapter2")
else:
model1.add_causal_lm_head("adapter1")
model1.add_causal_lm_head("adapter2")
model1.set_active_adapters(Parallel("adapter1", "adapter2"))
model1.to(torch_device)

seq_output_length = 32

# Finally, also check if generation works properly
input_ids = self.get_input_samples((1, 4), config=model1.config)["input_ids"]
input_ids = input_ids.to(torch_device)
generated = model1.generate(input_ids, max_length=seq_output_length)
self.assertLessEqual(generated.shape, (2, seq_output_length))


class ParallelTrainingMixin:
def create_twin_adapters(self, model, name):
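A hedged reading of the new test's result (the row ordering is an assumption, not something the test asserts): with two parallel adapters and a batch of one, generate() returns two sequences stacked along dim 0, one per adapter channel.

# splitting along dim 0 is assumed to recover the per-adapter sequences
# in the order of the Parallel block ("adapter1" rows first)
seqs_adapter1, seqs_adapter2 = generated.chunk(2, dim=0)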
