From 2e2f3024685b6d761c8bf8ff087a426e035b7acd Mon Sep 17 00:00:00 2001
From: calpt
Date: Sun, 7 Jan 2024 23:18:02 +0100
Subject: [PATCH] [Bart] Move CLS rep extraction from EOS tokens to head classes (#624)

Fixes #494 and fixes #563.

`BartAdapterModel` currently tries to extract CLS token representations from EOS tokens independently of the heads being used (if any). This causes unexpected issues for models that do not use CLS-token-based heads.

This PR moves CLS token rep extraction into the specific head classes so that it is only computed when needed. Two new (re-usable) args to `forward_head()` are introduced:
- `get_cls_from_eos_tokens`: indicates that a model uses CLS reps extracted from EOS tokens
- `eos_mask`: the token mask used to extract the CLS reps (required when passing `get_cls_from_eos_tokens=True`)

A standalone sketch of the extraction logic is included after the patch below.
---
 src/adapters/heads/base.py                | 48 +++++++++++++++--------
 src/adapters/models/bart/adapter_model.py | 15 ++-----
 2 files changed, 34 insertions(+), 29 deletions(-)

diff --git a/src/adapters/heads/base.py b/src/adapters/heads/base.py
index d45897df10..fc4f6990b7 100644
--- a/src/adapters/heads/base.py
+++ b/src/adapters/heads/base.py
@@ -18,7 +18,13 @@
 )
 from transformers.utils import ModelOutput
 
-from ..composition import AdapterCompositionBlock, BatchSplit, Parallel, parse_heads_from_composition
+from ..composition import (
+    AdapterCompositionBlock,
+    BatchSplit,
+    Parallel,
+    adjust_tensors_for_parallel,
+    parse_heads_from_composition,
+)
 from ..context import AdapterSetup, ForwardContext
 from ..loading import PredictionHeadLoader
 from ..methods.modeling import Activation_Function_Class
@@ -105,6 +111,21 @@ def get_output_embeddings(self):
     def get_label_names(self):
         return ["labels"]
 
+    def _get_cls_output(self, outputs, **kwargs):
+        if self.config["use_pooler"]:
+            cls_output = kwargs.pop("pooled_output")
+        elif kwargs.get("get_cls_from_eos_tokens", False):
+            x = outputs[0]  # last hidden state
+            eos_mask = kwargs.get("eos_mask")
+            (eos_mask,) = adjust_tensors_for_parallel(x, eos_mask)
+            if len(torch.unique(eos_mask.sum(1))) > 1:
+                raise ValueError("All examples must have the same number of <eos> tokens.")
+            cls_output = x[eos_mask, :].view(x.size(0), -1, x.size(-1))[:, -1, :]
+        else:
+            cls_output = outputs[0][:, 0]
+
+        return cls_output
+
 
 class ClassificationHead(PredictionHead):
     def __init__(
@@ -134,10 +155,7 @@ def __init__(
 
     def forward(self, outputs, cls_output=None, attention_mask=None, return_dict=False, **kwargs):
         if cls_output is None:
-            if self.config["use_pooler"]:
-                cls_output = kwargs.pop("pooled_output")
-            else:
-                cls_output = outputs[0][:, 0]
+            cls_output = self._get_cls_output(outputs, **kwargs)
         logits = super().forward(cls_output)
         loss = None
         labels = kwargs.pop("labels", None)
@@ -205,10 +223,7 @@ def __init__(
 
    def forward(self, outputs, cls_output=None, attention_mask=None, return_dict=False, **kwargs):
         if cls_output is None:
-            if self.config["use_pooler"]:
-                cls_output = kwargs.pop("pooled_output")
-            else:
-                cls_output = outputs[0][:, 0]
+            cls_output = self._get_cls_output(outputs, **kwargs)
         logits = super().forward(cls_output)
         loss = None
         labels = kwargs.pop("labels", None)
@@ -271,10 +286,7 @@ def __init__(
 
     def forward(self, outputs, cls_output=None, attention_mask=None, return_dict=None, **kwargs):
         if cls_output is None:
-            if self.config["use_pooler"]:
-                cls_output = kwargs.pop("pooled_output")
-            else:
-                cls_output = outputs[0][:, 0]
+            cls_output = self._get_cls_output(outputs, **kwargs)
         logits = super().forward(cls_output)
         logits = logits.view(-1, self.config["num_choices"])
         loss = None
@@ -476,10 +488,7 @@ def __init__(
 
     def forward(self, outputs, cls_output=None, attention_mask=None, return_dict=False, **kwargs):
         if cls_output is None:
-            if self.config["use_pooler"]:
-                cls_output = kwargs.pop("pooled_output")
-            else:
-                cls_output = outputs[0][:, 0]
+            cls_output = self._get_cls_output(outputs, **kwargs)
         logits = super().forward(cls_output)
         loss = None
         labels = kwargs.pop("labels", None)
@@ -800,6 +809,9 @@ def forward_head(
             cls_output (torch.Tensor, optional): The classification output of the model.
             attention_mask (torch.Tensor, optional): The attention mask of the model.
             return_dict (bool): Whether or not to return a ``ModelOutput`` instead of a plain tuple.
+            get_cls_from_eos_tokens (bool):
+                If set to True, retrieve classifier token representations from the last <eos> token in the sequence.
+                Setting to True requires `eos_mask` to be passed as well.
             **kwargs: Additional keyword arguments passed to the forward pass of the head.
         """
         used_head_modules = self._get_used_heads(head_name)
@@ -846,10 +858,12 @@ def _get_head_input(outputs, cls_out, batch):
                 )
             head_outputs = []
             labels = kwargs.pop("labels", None)
+            eos_mask = kwargs.pop("eos_mask", None)
             for i, head in enumerate(self.active_head):
                 head_module = self.heads[head]
                 batch_idx = range(sum(self.active_head.batch_sizes[:i]), sum(self.active_head.batch_sizes[: i + 1]))
                 kwargs["labels"] = labels[batch_idx] if labels is not None else None
+                kwargs["eos_mask"] = eos_mask[batch_idx] if eos_mask is not None else None
                 head_inputs, head_cls_input = _get_head_input(all_outputs, cls_output, batch_idx)
                 # head_attention = attention_mask[batch_idx] if attention_mask is not None else None
                 head_output = head_module(head_inputs, head_cls_input, attention_mask, return_dict, **kwargs)
diff --git a/src/adapters/models/bart/adapter_model.py b/src/adapters/models/bart/adapter_model.py
index ddb94e6fe9..ad75324fd1 100644
--- a/src/adapters/models/bart/adapter_model.py
+++ b/src/adapters/models/bart/adapter_model.py
@@ -10,7 +10,6 @@
 )
 from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward
 
-from ...composition import adjust_tensors_for_parallel
 from ...heads import (
     ClassificationHead,
     ModelWithFlexibleHeadsAdaptersMixin,
@@ -102,23 +101,15 @@ def forward(
         )
         # required e.g. for prompt tuning in all models
         kwargs["context"] = context
 
-        # sequence classification based on last token in sequence
-        x = outputs[0]  # last hidden state
-        if input_ids is not None and x.shape[1] == input_ids.shape[1]:
-            eos_mask = input_ids.eq(self.config.eos_token_id)
-            (eos_mask,) = adjust_tensors_for_parallel(x, eos_mask)
-            if len(torch.unique(eos_mask.sum(1))) > 1:
-                raise ValueError("All examples must have the same number of <eos> tokens.")
-            cls_representation = x[eos_mask, :].view(x.size(0), -1, x.size(-1))[:, -1, :]
-        else:
-            cls_representation = x
 
         head_outputs = self.forward_head(
             outputs,
             head_name=head,
-            cls_output=cls_representation,
             attention_mask=attention_mask,
             return_dict=return_dict,
+            get_cls_from_eos_tokens=True,
+            # `get_cls_from_eos_tokens` requires passing eos mask
+            eos_mask=input_ids.eq(self.config.eos_token_id) if input_ids is not None else None,
             **kwargs,
         )
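
What `_get_cls_output()` does when `get_cls_from_eos_tokens=True` boils down to masking the EOS positions of the last hidden state and keeping the final one per example. Below is a minimal standalone sketch of that logic, assuming illustrative values for the batch size, sequence length, hidden size, and `eos_token_id` (none of these values come from the patch), and skipping the parallel-composition adjustment of the mask:

```python
import torch

# Illustrative values only; not taken from the patch.
batch_size, seq_len, hidden_size = 2, 6, 8
eos_token_id = 2  # hypothetical EOS id

input_ids = torch.randint(4, 100, (batch_size, seq_len))
input_ids[:, -1] = eos_token_id  # every example ends with exactly one EOS token
hidden_states = torch.rand(batch_size, seq_len, hidden_size)  # stand-in for outputs[0]

# Same check as `_get_cls_output`: every example must contain the same number of EOS tokens.
eos_mask = input_ids.eq(eos_token_id)
if len(torch.unique(eos_mask.sum(1))) > 1:
    raise ValueError("All examples must have the same number of <eos> tokens.")

# Keep only the hidden states at EOS positions, then take the last EOS per example.
cls_output = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))[:, -1, :]
print(cls_output.shape)  # torch.Size([2, 8])
```

With one EOS per example, `cls_output` has shape `(batch_size, hidden_size)`, i.e. the same shape the classification heads expect from a BERT-style `[CLS]` token, which is why the extraction can live inside the head classes instead of the model.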