From 8702c2608929e32d9a5ea748137ee5d97d2f595d Mon Sep 17 00:00:00 2001
From: calpt
Date: Wed, 15 Jun 2022 10:24:12 +0200
Subject: [PATCH] Infer label names for training for flex head models (#367)

---
 src/transformers/adapters/heads/base.py |  6 ++++++
 src/transformers/adapters/trainer.py    | 36 ++++++------------------------------
 2 files changed, 12 insertions(+), 30 deletions(-)

diff --git a/src/transformers/adapters/heads/base.py b/src/transformers/adapters/heads/base.py
index 87dcf9772d..3ad8f3e5ae 100644
--- a/src/transformers/adapters/heads/base.py
+++ b/src/transformers/adapters/heads/base.py
@@ -91,6 +91,9 @@ def build(self, model):
     def get_output_embeddings(self):
         return None  # override for heads with output embeddings
 
+    def get_label_names(self):
+        return ["labels"]
+
 
 class ClassificationHead(PredictionHead):
     def __init__(
@@ -405,6 +408,9 @@ def forward(self, outputs, cls_output=None, attention_mask=None, return_dict=Fal
             outputs = (total_loss,) + outputs
         return outputs
 
+    def get_label_names(self):
+        return ["start_positions", "end_positions"]
+
 
 class ModelWithFlexibleHeadsAdaptersMixin(ModelWithHeadsAdaptersMixin):
     """
diff --git a/src/transformers/adapters/trainer.py b/src/transformers/adapters/trainer.py
index 39374c5517..f5e5a42049 100644
--- a/src/transformers/adapters/trainer.py
+++ b/src/transformers/adapters/trainer.py
@@ -1,11 +1,8 @@
-import inspect
 import os
 import re
 from typing import Callable, Dict, List, Optional, Tuple, Union
 
-import datasets
 import torch
-from packaging import version
 from torch import nn
 from torch.utils.data.dataset import Dataset
 
@@ -86,6 +83,11 @@ def __init__(
                 "Expected a model with an active adapter setup."
                 "If you want to fully finetune the model use the Trainer class."
             )
+        if (self.label_names is None or len(self.label_names) < 1) and model.active_head is not None:
+            all_label_names = set()
+            for head in model._active_heads:
+                all_label_names |= set(model.heads[head].get_label_names())
+            self.label_names = list(all_label_names)
 
     def create_optimizer(self):
         """
@@ -215,33 +217,6 @@ def _load_heads(self, resume_from_checkpoint):
             ):
                 self.model.load_head(os.path.join(resume_from_checkpoint, file_name))
 
-    def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None):
-        if not self.args.remove_unused_columns:
-            return dataset
-        if self._signature_columns is None:
-            # Inspect model forward signature to keep only the arguments it accepts.
-            signature = inspect.signature(self.model.forward)
-            self._signature_columns = list(signature.parameters.keys())
-            # Labels may be named label or label_ids, the default data collator handles that.
-            self._signature_columns += ["label", "label_ids"]
-            self._signature_columns += self.label_names
-        columns = [k for k in self._signature_columns if k in dataset.column_names]
-        ignored_columns = list(set(dataset.column_names) - set(self._signature_columns))
-        if len(ignored_columns) > 0:
-            dset_description = "" if description is None else f"in the {description} set "
-            logger.info(
-                f"The following columns {dset_description} don't have a corresponding argument in "
-                f"`{self.model.__class__.__name__}.forward` and have been ignored: {', '.join(ignored_columns)}."
-            )
-
-        if version.parse(datasets.__version__) < version.parse("1.4.0"):
-            dataset.set_format(
-                type=dataset.format["type"], columns=columns, format_kwargs=dataset.format["format_kwargs"]
-            )
-            return dataset
-        else:
-            return dataset.remove_columns(ignored_columns)
-
 
 class AdapterTrainerCallback(TrainerCallback):
     def __init__(self, trainer):
@@ -275,6 +250,7 @@ def on_train_end(self, args: TrainingArguments, state: TrainerState, control: Tr
                     fusion_dir = os.path.join(state.best_model_checkpoint, fusion)
                     if os.path.exists(fusion_dir):
                         model.load_adapter_fusion(fusion_dir)
+            model.to(args.device)
 
     def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
         # apply adapter fusion weight regularization on the value matrix