[Bart] Move CLS rep extraction from EOS tokens to head classes #624

Merged
merged 1 commit into from Jan 7, 2024
48 changes: 31 additions & 17 deletions src/adapters/heads/base.py
@@ -18,7 +18,13 @@
 )
 from transformers.utils import ModelOutput
 
-from ..composition import AdapterCompositionBlock, BatchSplit, Parallel, parse_heads_from_composition
+from ..composition import (
+    AdapterCompositionBlock,
+    BatchSplit,
+    Parallel,
+    adjust_tensors_for_parallel,
+    parse_heads_from_composition,
+)
 from ..context import AdapterSetup, ForwardContext
 from ..loading import PredictionHeadLoader
 from ..methods.modeling import Activation_Function_Class
@@ -105,6 +111,21 @@ def get_output_embeddings(self):
     def get_label_names(self):
         return ["labels"]
 
+    def _get_cls_output(self, outputs, **kwargs):
+        if self.config["use_pooler"]:
+            cls_output = kwargs.pop("pooled_output")
+        elif kwargs.get("get_cls_from_eos_tokens", False):
+            x = outputs[0]  # last hidden state
+            eos_mask = kwargs.get("eos_mask")
+            (eos_mask,) = adjust_tensors_for_parallel(x, eos_mask)
+            if len(torch.unique(eos_mask.sum(1))) > 1:
+                raise ValueError("All examples must have the same number of <eos> tokens.")
+            cls_output = x[eos_mask, :].view(x.size(0), -1, x.size(-1))[:, -1, :]
+        else:
+            cls_output = outputs[0][:, 0]
+
+        return cls_output
+
 
 class ClassificationHead(PredictionHead):
     def __init__(
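For readers unfamiliar with the indexing used in `_get_cls_output` above, here is a minimal standalone sketch (not part of the diff) of how the EOS-mask selection picks out the hidden state of the last <eos> token per sequence; the tensor shapes and the token id 2 used as the eos id are made up for illustration:

import torch

batch_size, seq_len, hidden = 2, 5, 4
x = torch.randn(batch_size, seq_len, hidden)   # stand-in for outputs[0], the last hidden state
input_ids = torch.tensor([[5, 9, 2, 1, 1],     # 2 = hypothetical eos_token_id, 1 = padding
                          [7, 8, 3, 2, 1]])
eos_mask = input_ids.eq(2)                     # True wherever an <eos> token sits

# The reshape below only works if every example has the same number of <eos> tokens,
# which is exactly what the ValueError in _get_cls_output guards against.
assert len(torch.unique(eos_mask.sum(1))) == 1

# Boolean indexing flattens all masked positions; view() regroups them per example,
# and [:, -1, :] keeps only the representation of the last <eos> token.
cls_output = x[eos_mask, :].view(x.size(0), -1, x.size(-1))[:, -1, :]
print(cls_output.shape)  # torch.Size([2, 4])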
@@ -134,10 +155,7 @@ def __init__(
 
     def forward(self, outputs, cls_output=None, attention_mask=None, return_dict=False, **kwargs):
         if cls_output is None:
-            if self.config["use_pooler"]:
-                cls_output = kwargs.pop("pooled_output")
-            else:
-                cls_output = outputs[0][:, 0]
+            cls_output = self._get_cls_output(outputs, **kwargs)
         logits = super().forward(cls_output)
         loss = None
         labels = kwargs.pop("labels", None)
@@ -205,10 +223,7 @@ def __init__(
 
     def forward(self, outputs, cls_output=None, attention_mask=None, return_dict=False, **kwargs):
         if cls_output is None:
-            if self.config["use_pooler"]:
-                cls_output = kwargs.pop("pooled_output")
-            else:
-                cls_output = outputs[0][:, 0]
+            cls_output = self._get_cls_output(outputs, **kwargs)
         logits = super().forward(cls_output)
         loss = None
         labels = kwargs.pop("labels", None)
@@ -271,10 +286,7 @@ def __init__(
 
     def forward(self, outputs, cls_output=None, attention_mask=None, return_dict=None, **kwargs):
         if cls_output is None:
-            if self.config["use_pooler"]:
-                cls_output = kwargs.pop("pooled_output")
-            else:
-                cls_output = outputs[0][:, 0]
+            cls_output = self._get_cls_output(outputs, **kwargs)
         logits = super().forward(cls_output)
         logits = logits.view(-1, self.config["num_choices"])
         loss = None
@@ -476,10 +488,7 @@ def __init__(
 
     def forward(self, outputs, cls_output=None, attention_mask=None, return_dict=False, **kwargs):
         if cls_output is None:
-            if self.config["use_pooler"]:
-                cls_output = kwargs.pop("pooled_output")
-            else:
-                cls_output = outputs[0][:, 0]
+            cls_output = self._get_cls_output(outputs, **kwargs)
         logits = super().forward(cls_output)
         loss = None
         labels = kwargs.pop("labels", None)
@@ -800,6 +809,9 @@ def forward_head(
             cls_output (torch.Tensor, optional): The classification output of the model.
             attention_mask (torch.Tensor, optional): The attention mask of the model.
             return_dict (bool): Whether or not to return a ``ModelOutput`` instead of a plain tuple.
+            get_cls_from_eos_tokens (bool):
+                If set to True, retrieve classifier token representations from the last <eos> token in the sequence.
+                Setting to True requires `eos_mask` to be passed as well.
             **kwargs: Additional keyword arguments passed to the forward pass of the head.
         """
         used_head_modules = self._get_used_heads(head_name)
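As a quick illustration of what callers of `forward_head` must supply alongside the new flag, the following self-contained sketch builds an `eos_mask` from tokenizer output; the facebook/bart-base checkpoint and the example sentences are only illustrative:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
batch = tokenizer(["a short sentence", "another, slightly longer sentence"], padding=True, return_tensors="pt")

# One True per row, at the position of the </s> (eos) token; padding uses a different
# token id, so padded positions are not picked up by the mask.
eos_mask = batch["input_ids"].eq(tokenizer.eos_token_id)
print(eos_mask.sum(1))  # tensor([1, 1])

# The mask is then handed to the head, roughly:
#   forward_head(outputs, head_name=..., get_cls_from_eos_tokens=True, eos_mask=eos_mask, ...)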
@@ -846,10 +858,12 @@ def _get_head_input(outputs, cls_out, batch):
                 )
             head_outputs = []
             labels = kwargs.pop("labels", None)
+            eos_mask = kwargs.pop("eos_mask", None)
             for i, head in enumerate(self.active_head):
                 head_module = self.heads[head]
                 batch_idx = range(sum(self.active_head.batch_sizes[:i]), sum(self.active_head.batch_sizes[: i + 1]))
                 kwargs["labels"] = labels[batch_idx] if labels is not None else None
+                kwargs["eos_mask"] = eos_mask[batch_idx] if eos_mask is not None else None
                 head_inputs, head_cls_input = _get_head_input(all_outputs, cls_output, batch_idx)
                 # head_attention = attention_mask[batch_idx] if attention_mask is not None else None
                 head_output = head_module(head_inputs, head_cls_input, attention_mask, return_dict, **kwargs)
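For the BatchSplit bookkeeping in the hunk above, here is a hedged toy sketch (not library code) of how the per-head index ranges slice labels and eos_mask; the head names and batch sizes are invented:

import torch

batch_sizes = [2, 3]                     # e.g. a BatchSplit over two heads with these sizes
head_names = ["head_a", "head_b"]        # invented names
eos_mask = torch.zeros(5, 7, dtype=torch.bool)

for i, head in enumerate(head_names):
    # same range arithmetic as in forward_head: cumulative offsets into the batch
    batch_idx = range(sum(batch_sizes[:i]), sum(batch_sizes[: i + 1]))
    head_eos_mask = eos_mask[batch_idx]  # rows 0-1 for head_a, rows 2-4 for head_b
    print(head, head_eos_mask.shape)     # head_a torch.Size([2, 7]) / head_b torch.Size([3, 7])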
15 changes: 3 additions & 12 deletions src/adapters/models/bart/adapter_model.py
@@ -10,7 +10,6 @@
 )
 from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward
 
-from ...composition import adjust_tensors_for_parallel
 from ...heads import (
     ClassificationHead,
     ModelWithFlexibleHeadsAdaptersMixin,
@@ -102,23 +101,15 @@ def forward(
         )
         # required e.g. for prompt tuning in all models
         kwargs["context"] = context
-        # sequence classification based on last token in sequence
-        x = outputs[0]  # last hidden state
-        if input_ids is not None and x.shape[1] == input_ids.shape[1]:
-            eos_mask = input_ids.eq(self.config.eos_token_id)
-            (eos_mask,) = adjust_tensors_for_parallel(x, eos_mask)
-            if len(torch.unique(eos_mask.sum(1))) > 1:
-                raise ValueError("All examples must have the same number of <eos> tokens.")
-            cls_representation = x[eos_mask, :].view(x.size(0), -1, x.size(-1))[:, -1, :]
-        else:
-            cls_representation = x
 
         head_outputs = self.forward_head(
             outputs,
             head_name=head,
-            cls_output=cls_representation,
             attention_mask=attention_mask,
             return_dict=return_dict,
+            get_cls_from_eos_tokens=True,
+            # `get_cls_from_eos_tokens` requires passing eos mask
+            eos_mask=input_ids.eq(self.config.eos_token_id) if input_ids is not None else None,
             **kwargs,
         )
 
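Finally, a hedged end-to-end sketch of the user-facing behaviour after this change: nothing changes for callers, the EOS-based CLS extraction simply happens inside the head now. The checkpoint name, head name and label count are illustrative, and the snippet assumes the usual adapters flex-head API:

import torch
from adapters import BartAdapterModel
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
model = BartAdapterModel.from_pretrained("facebook/bart-base")
model.add_classification_head("sentiment", num_labels=2)  # adds and activates a classification head

batch = tokenizer(["great movie", "terrible movie"], padding=True, return_tensors="pt")
with torch.no_grad():
    out = model(**batch)   # the head now derives cls_output from the last <eos> token itself
print(out.logits.shape)    # expected: torch.Size([2, 2])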