Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Infer label names for training for flex head models #367

Merged
merged 3 commits into from
Jun 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/transformers/adapters/heads/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,9 @@ def build(self, model):
def get_output_embeddings(self):
return None # override for heads with output embeddings

def get_label_names(self):
return ["labels"]


class ClassificationHead(PredictionHead):
def __init__(
Expand Down Expand Up @@ -405,6 +408,9 @@ def forward(self, outputs, cls_output=None, attention_mask=None, return_dict=Fal
outputs = (total_loss,) + outputs
return outputs

def get_label_names(self):
return ["start_positions", "end_positions"]


class ModelWithFlexibleHeadsAdaptersMixin(ModelWithHeadsAdaptersMixin):
"""
Expand Down
36 changes: 6 additions & 30 deletions src/transformers/adapters/trainer.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
import inspect
import os
import re
from typing import Callable, Dict, List, Optional, Tuple, Union

import datasets
import torch
from packaging import version
from torch import nn
from torch.utils.data.dataset import Dataset

Expand Down Expand Up @@ -86,6 +83,11 @@ def __init__(
"Expected a model with an active adapter setup."
"If you want to fully finetune the model use the Trainer class."
)
if (self.label_names is None or len(self.label_names) < 1) and model.active_head is not None:
all_label_names = set()
for head in model._active_heads:
all_label_names |= set(model.heads[head].get_label_names())
self.label_names = list(all_label_names)

def create_optimizer(self):
"""
Expand Down Expand Up @@ -215,33 +217,6 @@ def _load_heads(self, resume_from_checkpoint):
):
self.model.load_head(os.path.join(resume_from_checkpoint, file_name))

def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None):
if not self.args.remove_unused_columns:
return dataset
if self._signature_columns is None:
# Inspect model forward signature to keep only the arguments it accepts.
signature = inspect.signature(self.model.forward)
self._signature_columns = list(signature.parameters.keys())
# Labels may be named label or label_ids, the default data collator handles that.
self._signature_columns += ["label", "label_ids"]
self._signature_columns += self.label_names
columns = [k for k in self._signature_columns if k in dataset.column_names]
ignored_columns = list(set(dataset.column_names) - set(self._signature_columns))
if len(ignored_columns) > 0:
dset_description = "" if description is None else f"in the {description} set "
logger.info(
f"The following columns {dset_description} don't have a corresponding argument in "
f"`{self.model.__class__.__name__}.forward` and have been ignored: {', '.join(ignored_columns)}."
)

if version.parse(datasets.__version__) < version.parse("1.4.0"):
dataset.set_format(
type=dataset.format["type"], columns=columns, format_kwargs=dataset.format["format_kwargs"]
)
return dataset
else:
return dataset.remove_columns(ignored_columns)


class AdapterTrainerCallback(TrainerCallback):
def __init__(self, trainer):
Expand Down Expand Up @@ -275,6 +250,7 @@ def on_train_end(self, args: TrainingArguments, state: TrainerState, control: Tr
fusion_dir = os.path.join(state.best_model_checkpoint, fusion)
if os.path.exists(fusion_dir):
model.load_adapter_fusion(fusion_dir)
model.to(args.device)

def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
# apply adapter fusion weight regularization on the value matrix
Expand Down