[Bart] Move CLS rep extraction from EOS tokens to head classes (#624)
Fixes #494 and fixes #563.

`BartAdapterModel` currently tries to extract CLS token representations from EOS tokens independently of the heads in use (if any). This causes confusing issues for models that don't use CLS-token-based heads.

This PR moves CLS token rep extraction into the specific head classes so it is only computed when needed. Two (re-usable) new args to `forward_head()` are introduced (see the usage sketch below):
- `get_cls_from_eos_tokens`: indicates that a model uses CLS reps from EOS tokens
- `eos_mask`: token mask used to extract CLS reps (required when passing `get_cls_from_eos_tokens=True`)
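
As a call-site sketch (mirroring the updated `BartAdapterModel` code in the diff below; `outputs`, `head`, `input_ids`, `attention_mask`, `return_dict`, and `kwargs` are the usual values available inside a flex-head model's `forward()`):

```python
# Sketch only: inside a flex-head model's forward(), after running the base model.
# `get_cls_from_eos_tokens=True` tells the active head to pull its CLS representation
# from the last <eos> token; `eos_mask` marks the <eos> positions per example.
head_outputs = self.forward_head(
    outputs,
    head_name=head,
    attention_mask=attention_mask,
    return_dict=return_dict,
    get_cls_from_eos_tokens=True,
    eos_mask=input_ids.eq(self.config.eos_token_id) if input_ids is not None else None,
    **kwargs,
)
```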
calpt authored Jan 7, 2024
1 parent fea1684 commit 2e2f302
Showing 2 changed files with 34 additions and 29 deletions.
48 changes: 31 additions & 17 deletions src/adapters/heads/base.py
@@ -18,7 +18,13 @@
)
from transformers.utils import ModelOutput

-from ..composition import AdapterCompositionBlock, BatchSplit, Parallel, parse_heads_from_composition
+from ..composition import (
+    AdapterCompositionBlock,
+    BatchSplit,
+    Parallel,
+    adjust_tensors_for_parallel,
+    parse_heads_from_composition,
+)
from ..context import AdapterSetup, ForwardContext
from ..loading import PredictionHeadLoader
from ..methods.modeling import Activation_Function_Class
@@ -105,6 +111,21 @@ def get_output_embeddings(self):
    def get_label_names(self):
        return ["labels"]

+    def _get_cls_output(self, outputs, **kwargs):
+        if self.config["use_pooler"]:
+            cls_output = kwargs.pop("pooled_output")
+        elif kwargs.get("get_cls_from_eos_tokens", False):
+            x = outputs[0]  # last hidden state
+            eos_mask = kwargs.get("eos_mask")
+            (eos_mask,) = adjust_tensors_for_parallel(x, eos_mask)
+            if len(torch.unique(eos_mask.sum(1))) > 1:
+                raise ValueError("All examples must have the same number of <eos> tokens.")
+            cls_output = x[eos_mask, :].view(x.size(0), -1, x.size(-1))[:, -1, :]
+        else:
+            cls_output = outputs[0][:, 0]
+
+        return cls_output
+

class ClassificationHead(PredictionHead):
    def __init__(
@@ -134,10 +155,7 @@ def __init__(

    def forward(self, outputs, cls_output=None, attention_mask=None, return_dict=False, **kwargs):
        if cls_output is None:
-            if self.config["use_pooler"]:
-                cls_output = kwargs.pop("pooled_output")
-            else:
-                cls_output = outputs[0][:, 0]
+            cls_output = self._get_cls_output(outputs, **kwargs)
        logits = super().forward(cls_output)
        loss = None
        labels = kwargs.pop("labels", None)
@@ -205,10 +223,7 @@ def __init__(

    def forward(self, outputs, cls_output=None, attention_mask=None, return_dict=False, **kwargs):
        if cls_output is None:
-            if self.config["use_pooler"]:
-                cls_output = kwargs.pop("pooled_output")
-            else:
-                cls_output = outputs[0][:, 0]
+            cls_output = self._get_cls_output(outputs, **kwargs)
        logits = super().forward(cls_output)
        loss = None
        labels = kwargs.pop("labels", None)
@@ -271,10 +286,7 @@ def __init__(

    def forward(self, outputs, cls_output=None, attention_mask=None, return_dict=None, **kwargs):
        if cls_output is None:
-            if self.config["use_pooler"]:
-                cls_output = kwargs.pop("pooled_output")
-            else:
-                cls_output = outputs[0][:, 0]
+            cls_output = self._get_cls_output(outputs, **kwargs)
        logits = super().forward(cls_output)
        logits = logits.view(-1, self.config["num_choices"])
        loss = None
@@ -476,10 +488,7 @@ def __init__(

    def forward(self, outputs, cls_output=None, attention_mask=None, return_dict=False, **kwargs):
        if cls_output is None:
-            if self.config["use_pooler"]:
-                cls_output = kwargs.pop("pooled_output")
-            else:
-                cls_output = outputs[0][:, 0]
+            cls_output = self._get_cls_output(outputs, **kwargs)
        logits = super().forward(cls_output)
        loss = None
        labels = kwargs.pop("labels", None)
@@ -800,6 +809,9 @@ def forward_head(
            cls_output (torch.Tensor, optional): The classification output of the model.
            attention_mask (torch.Tensor, optional): The attention mask of the model.
            return_dict (bool): Whether or not to return a ``ModelOutput`` instead of a plain tuple.
+            get_cls_from_eos_tokens (bool):
+                If set to True, retrieve classifier token representations from the last <eos> token in the sequence.
+                Setting to True requires `eos_mask` to be passed as well.
            **kwargs: Additional keyword arguments passed to the forward pass of the head.
        """
        used_head_modules = self._get_used_heads(head_name)
@@ -846,10 +858,12 @@ def _get_head_input(outputs, cls_out, batch):
                )
            head_outputs = []
            labels = kwargs.pop("labels", None)
+            eos_mask = kwargs.pop("eos_mask", None)
            for i, head in enumerate(self.active_head):
                head_module = self.heads[head]
                batch_idx = range(sum(self.active_head.batch_sizes[:i]), sum(self.active_head.batch_sizes[: i + 1]))
                kwargs["labels"] = labels[batch_idx] if labels is not None else None
+                kwargs["eos_mask"] = eos_mask[batch_idx] if eos_mask is not None else None
                head_inputs, head_cls_input = _get_head_input(all_outputs, cls_output, batch_idx)
                # head_attention = attention_mask[batch_idx] if attention_mask is not None else None
                head_output = head_module(head_inputs, head_cls_input, attention_mask, return_dict, **kwargs)
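For reference, a minimal, self-contained sketch (with made-up tensor shapes) of the indexing `_get_cls_output` uses to keep only the last `<eos>` hidden state per example:

```python
import torch

# Dummy last hidden state: (batch=2, seq_len=5, hidden=4)
x = torch.randn(2, 5, 4)
# Boolean <eos> mask; every example must contain the same number of <eos> tokens
eos_mask = torch.tensor([
    [False, False, False, False, True],
    [False, False, False, True, False],
])

# Gather all <eos> states, regroup per example, keep the last one per example
cls_output = x[eos_mask, :].view(x.size(0), -1, x.size(-1))[:, -1, :]
print(cls_output.shape)  # torch.Size([2, 4])
```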
15 changes: 3 additions & 12 deletions src/adapters/models/bart/adapter_model.py
@@ -10,7 +10,6 @@
)
from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward

-from ...composition import adjust_tensors_for_parallel
from ...heads import (
    ClassificationHead,
    ModelWithFlexibleHeadsAdaptersMixin,
@@ -102,23 +101,15 @@ def forward(
        )
        # required e.g. for prompt tuning in all models
        kwargs["context"] = context
-        # sequence classification based on last token in sequence
-        x = outputs[0]  # last hidden state
-        if input_ids is not None and x.shape[1] == input_ids.shape[1]:
-            eos_mask = input_ids.eq(self.config.eos_token_id)
-            (eos_mask,) = adjust_tensors_for_parallel(x, eos_mask)
-            if len(torch.unique(eos_mask.sum(1))) > 1:
-                raise ValueError("All examples must have the same number of <eos> tokens.")
-            cls_representation = x[eos_mask, :].view(x.size(0), -1, x.size(-1))[:, -1, :]
-        else:
-            cls_representation = x

        head_outputs = self.forward_head(
            outputs,
            head_name=head,
-            cls_output=cls_representation,
            attention_mask=attention_mask,
            return_dict=return_dict,
+            get_cls_from_eos_tokens=True,
+            # `get_cls_from_eos_tokens` requires passing eos mask
+            eos_mask=input_ids.eq(self.config.eos_token_id) if input_ids is not None else None,
            **kwargs,
        )

