Fix issues with embedding training #386

Merged · 3 commits · Jul 11, 2022
12 changes: 11 additions & 1 deletion src/transformers/adapters/mixins/bart.py
@@ -3,7 +3,13 @@
import torch.nn as nn

from ..layer import AdapterLayer
from ..model_mixin import EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelAdaptersMixin
from ..model_mixin import (
EmbeddingAdaptersMixin,
EmbeddingAdaptersWrapperMixin,
InvertibleAdaptersMixin,
ModelAdaptersMixin,
ModelWithHeadsAdaptersMixin,
)


class BartEncoderLayerAdaptersMixin:
@@ -48,3 +54,7 @@ def _init_adapter_modules(self):
self.enable_invertible_adapters = self.encoder.enable_invertible_adapters
self.invertible_adapters_forward = self.encoder.invertible_adapters_forward
super()._init_adapter_modules()


class BartModelWithHeadsAdaptersMixin(EmbeddingAdaptersWrapperMixin, ModelWithHeadsAdaptersMixin):
pass
12 changes: 11 additions & 1 deletion src/transformers/adapters/mixins/bert.py
@@ -4,7 +4,13 @@
import torch.nn as nn

from ..layer import AdapterLayer
from ..model_mixin import EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelAdaptersMixin
from ..model_mixin import (
EmbeddingAdaptersMixin,
EmbeddingAdaptersWrapperMixin,
InvertibleAdaptersMixin,
ModelAdaptersMixin,
ModelWithHeadsAdaptersMixin,
)


logger = logging.getLogger(__name__)
@@ -32,3 +38,7 @@ class BertModelAdaptersMixin(EmbeddingAdaptersMixin, InvertibleAdaptersMixin, Mo
def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]:
for i, layer in enumerate(self.encoder.layer):
yield i, layer


class BertModelWithHeadsAdaptersMixin(EmbeddingAdaptersWrapperMixin, ModelWithHeadsAdaptersMixin):
pass
12 changes: 11 additions & 1 deletion src/transformers/adapters/mixins/distilbert.py
@@ -3,7 +3,13 @@
import torch.nn as nn

from ..layer import AdapterLayer
from ..model_mixin import EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelAdaptersMixin
from ..model_mixin import (
EmbeddingAdaptersMixin,
EmbeddingAdaptersWrapperMixin,
InvertibleAdaptersMixin,
ModelAdaptersMixin,
ModelWithHeadsAdaptersMixin,
)


class DistilBertTransfomerBlockAdaptersMixin:
@@ -22,3 +28,7 @@ class DistilBertModelAdaptersMixin(EmbeddingAdaptersMixin, InvertibleAdaptersMix
def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]:
for i, layer in enumerate(self.transformer.layer):
yield i, layer


class DistilBertModelWithHeadsAdaptersMixin(EmbeddingAdaptersWrapperMixin, ModelWithHeadsAdaptersMixin):
pass
12 changes: 11 additions & 1 deletion src/transformers/adapters/mixins/gpt2.py
@@ -3,7 +3,13 @@
import torch.nn as nn

from ..layer import AdapterLayer
from ..model_mixin import EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelAdaptersMixin
from ..model_mixin import (
EmbeddingAdaptersMixin,
EmbeddingAdaptersWrapperMixin,
InvertibleAdaptersMixin,
ModelAdaptersMixin,
ModelWithHeadsAdaptersMixin,
)


class GPT2DecoderBlockAdaptersMixin:
@@ -20,3 +26,7 @@ class GPT2ModelAdapterMixin(EmbeddingAdaptersMixin, InvertibleAdaptersMixin, Mod
def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]:
for i, layer in enumerate(self.base_model.h):
yield i, layer


class GPT2ModelWithHeadsAdaptersMixin(EmbeddingAdaptersWrapperMixin, ModelWithHeadsAdaptersMixin):
pass
12 changes: 11 additions & 1 deletion src/transformers/adapters/mixins/t5.py
@@ -3,7 +3,12 @@
import torch.nn as nn

from ..layer import AdapterLayer
from ..model_mixin import EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelAdaptersMixin
from ..model_mixin import (
EmbeddingAdaptersMixin,
InvertibleAdaptersMixin,
ModelAdaptersMixin,
ModelWithHeadsAdaptersMixin,
)


class T5SelfAttentionLayerAdaptersMixin(AdapterLayer):
@@ -45,3 +50,8 @@ def _init_adapter_modules(self):
self.invertible_adapters_forward = self.encoder.invertible_adapters_forward
self.delete_invertible_adapter = self.encoder.delete_invertible_adapter
super()._init_adapter_modules()


# EmbeddingAdaptersWrapperMixin not required here as base and heads model are identical
class T5ModelWithHeadsAdaptersMixin(ModelWithHeadsAdaptersMixin):
pass
63 changes: 29 additions & 34 deletions src/transformers/adapters/model_mixin.py
@@ -152,7 +152,7 @@ def add_embeddings(self, name, tokenizer, reference_embedding=None, reference_to
raise ValueError("An embedding with the name {} already exists".format(name))
if embedding_dim is None:
embedding_dim = self.config.hidden_size
embedding = nn.Embedding(tokenizer.vocab_size, embedding_dim)
embedding = nn.Embedding(len(tokenizer), embedding_dim)
embedding.requires_grad_(False)
if (reference_embedding is not None and reference_tokenizer is None) or (
reference_tokenizer is not None and reference_embedding is None
@@ -167,7 +167,9 @@ def add_embeddings(self, name, tokenizer, reference_embedding=None, reference_to
for t in tokens:
idx_reference = reference_vocab[t]
idx = vocab[t]
embedding.weight[idx] = self.loaded_embeddings[reference_embedding].weight[idx_reference].clone()
embedding.weight[idx] = (
self.loaded_embeddings[reference_embedding].weight[idx_reference].detach().clone()
)
embedding.train(False)
self.loaded_embeddings[name] = embedding
self.set_active_embeddings(name)
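
The two behavioral fixes in this hunk are easiest to see in isolation: the embedding is now sized with len(tokenizer) rather than tokenizer.vocab_size, and reference rows are detached before cloning. A minimal sketch, not part of this PR (checkpoint name, token strings, and the 768-dim size are illustrative):

import torch.nn as nn
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
tokenizer.add_tokens(["<new-tok-1>", "<new-tok-2>"])

# vocab_size does not count tokens added after loading, len() does, so
# sizing the matrix with len(tokenizer) avoids out-of-range indices for
# the added tokens.
print(tokenizer.vocab_size)  # 30522 for this checkpoint
print(len(tokenizer))        # 30524 after the two added tokens

embedding = nn.Embedding(len(tokenizer), 768)  # 768 stands in for config.hidden_size
embedding.requires_grad_(False)

# Detaching before cloning copies raw values only, so the row assignment
# below does not pull the reference embedding's autograd history into the
# frozen target embedding.
reference = nn.Embedding(tokenizer.vocab_size, 768)  # stands in for loaded_embeddings[reference_embedding]
row = reference.weight[3].detach().clone()
assert row.grad_fn is None and not row.requires_grad
embedding.weight[3] = row
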
@@ -224,6 +226,31 @@ def active_embeddings(self):
return self._active_embedding


class EmbeddingAdaptersWrapperMixin:
def load_embeddings(self, path: str, name: str):
return self.base_model.load_embeddings(path, name)

def add_embeddings(self, name, tokenizer, reference_embedding=None, reference_tokenizer=None, embedding_dim=None):
return self.base_model.add_embeddings(name, tokenizer, reference_embedding, reference_tokenizer, embedding_dim)

def delete_embeddings(self, name):
return self.base_model.delete_embeddings(name)

def save_embeddings(self, path, name, tokenizer=None):
return self.base_model.save_embeddings(path, name, tokenizer)

def set_active_embeddings(self, name):
return self.base_model.set_active_embeddings(name)

@property
def active_embeddings(self):
return self.base_model.active_embeddings

@property
def loaded_embeddings(self):
return self.base_model.loaded_embeddings


class ModelAdaptersMixin(PushAdapterToHubMixin, ABC):
"""Mixin for transformer models adding support for loading/ saving adapters."""

@@ -1128,35 +1155,3 @@ def get_adapter(self, name):
return super().get_adapter(name)
else:
return self.base_model.get_adapter(name)

def load_embeddings(self, path: str, name: str):
if self.base_model is self:
return super().load_embeddings(path, name)
else:
return self.base_model.load_embeddings(path, name)

def save_embeddings(self, path, name, tokenizer=None):
if self.base_model is self:
return super().save_embeddings(path, name, tokenizer)
else:
return self.base_model.save_embeddings(path, name, tokenizer)

def add_embeddings(self, name, tokenizer, reference_embedding=None, reference_tokenizer=None, embedding_dim=None):
if self.base_model is None:
return super().add_embeddings(name, tokenizer, reference_embedding, reference_tokenizer, embedding_dim)
else:
return self.base_model.add_embeddings(
name, tokenizer, reference_embedding, reference_tokenizer, embedding_dim
)

def set_active_embeddings(self, name):
if self.base_model is None:
return super().set_active_embeddings(name)
else:
return self.base_model.set_active_embeddings(name)

def delete_embeddings(self, name):
if self.base_model is None:
return super().delete_embeddings(name)
else:
return self.base_model.delete_embeddings(name)
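
Net effect of the changes to this file: the embedding methods removed from ModelWithHeadsAdaptersMixin above are replaced by the new EmbeddingAdaptersWrapperMixin, which forwards every call to self.base_model. A hedged usage sketch, assuming the BertAdapterModel class extended later in this PR, the transformers.adapters import path, and the library convention of registering the original embedding under the name "default":

from transformers import AutoTokenizer
from transformers.adapters import BertAdapterModel  # assumed import path

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = BertAdapterModel.from_pretrained("bert-base-uncased")

# Each call is forwarded to model.base_model, which owns loaded_embeddings;
# the new matrix is sized with len(tokenizer) and seeded from the rows its
# vocabulary shares with the reference tokenizer.
model.add_embeddings(
    "custom",
    tokenizer,
    reference_embedding="default",  # assumed name of the original embedding
    reference_tokenizer=tokenizer,
)
model.set_active_embeddings("custom")

# Heads model and base model report the same state: there is one registry.
assert model.active_embeddings == model.base_model.active_embeddings == "custom"
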
3 changes: 2 additions & 1 deletion src/transformers/adapters/models/bart/adapter_model.py
@@ -19,12 +19,13 @@
QuestionAnsweringHead,
Seq2SeqLMHead,
)
from ...model_mixin import EmbeddingAdaptersWrapperMixin


@add_start_docstrings(
"BART Model with the option to add multiple flexible prediction heads on top.", BART_START_DOCSTRING
)
class BartAdapterModel(ModelWithFlexibleHeadsAdaptersMixin, BartPretrainedModel):
class BartAdapterModel(EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, BartPretrainedModel):
def __init__(self, config: BartConfig, **kwargs):
super().__init__(config, **kwargs)
self.model = BartModel(config)
3 changes: 2 additions & 1 deletion src/transformers/adapters/models/bert/adapter_model.py
@@ -14,13 +14,14 @@
QuestionAnsweringHead,
TaggingHead,
)
from ...model_mixin import EmbeddingAdaptersWrapperMixin


@add_start_docstrings(
"""Bert Model transformer with the option to add multiple flexible heads on top.""",
BERT_START_DOCSTRING,
)
class BertAdapterModel(ModelWithFlexibleHeadsAdaptersMixin, BertPreTrainedModel):
class BertAdapterModel(EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, BertPreTrainedModel):
def __init__(self, config):
super().__init__(config)

3 changes: 2 additions & 1 deletion src/transformers/adapters/models/deberta/adapter_model.py
@@ -9,12 +9,13 @@
QuestionAnsweringHead,
TaggingHead,
)
from ...model_mixin import EmbeddingAdaptersWrapperMixin


@add_start_docstrings(
"""Deberta Model transformer with the option to add multiple flexible heads on top.""",
)
class DebertaAdapterModel(ModelWithFlexibleHeadsAdaptersMixin, DebertaPreTrainedModel):
class DebertaAdapterModel(EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, DebertaPreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"cls.predictions.bias"]

def __init__(self, config):
5 changes: 4 additions & 1 deletion src/transformers/adapters/models/debertaV2/adapter_model.py
@@ -10,12 +10,15 @@
QuestionAnsweringHead,
TaggingHead,
)
from ...model_mixin import EmbeddingAdaptersWrapperMixin


@add_start_docstrings(
"""Deberta v2 Model transformer with the option to add multiple flexible heads on top.""",
)
class DebertaV2AdapterModel(ModelWithFlexibleHeadsAdaptersMixin, DebertaV2PreTrainedModel):
class DebertaV2AdapterModel(
EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, DebertaV2PreTrainedModel
):
_keys_to_ignore_on_load_unexpected = [r"cls.predictions.bias"]

def __init__(self, config):
5 changes: 4 additions & 1 deletion src/transformers/adapters/models/distilbert/adapter_model.py
@@ -20,13 +20,16 @@
QuestionAnsweringHead,
TaggingHead,
)
from ...model_mixin import EmbeddingAdaptersWrapperMixin


@add_start_docstrings(
"""DistilBert Model transformer with the option to add multiple flexible heads on top.""",
DISTILBERT_START_DOCSTRING,
)
class DistilBertAdapterModel(ModelWithFlexibleHeadsAdaptersMixin, DistilBertPreTrainedModel):
class DistilBertAdapterModel(
EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, DistilBertPreTrainedModel
):
def __init__(self, config):
super().__init__(config)
self.distilbert = DistilBertModel(config)
3 changes: 2 additions & 1 deletion src/transformers/adapters/models/gpt2/adapter_model.py
@@ -13,6 +13,7 @@
MultiLabelClassificationHead,
TaggingHead,
)
from ...model_mixin import EmbeddingAdaptersWrapperMixin


logger = logging.getLogger(__name__)
@@ -29,7 +30,7 @@
""",
GPT2_START_DOCSTRING,
)
class GPT2AdapterModel(ModelWithFlexibleHeadsAdaptersMixin, GPT2PreTrainedModel):
class GPT2AdapterModel(EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, GPT2PreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.transformer = GPT2Model(config)
3 changes: 2 additions & 1 deletion src/transformers/adapters/models/mbart/adapter_model.py
@@ -19,12 +19,13 @@
QuestionAnsweringHead,
Seq2SeqLMHead,
)
from ...model_mixin import EmbeddingAdaptersWrapperMixin


@add_start_docstrings(
"MBART Model with the option to add multiple flexible prediction heads on top.", MBART_START_DOCSTRING
)
class MBartAdapterModel(ModelWithFlexibleHeadsAdaptersMixin, MBartPreTrainedModel):
class MBartAdapterModel(EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, MBartPreTrainedModel):
def __init__(self, config: MBartConfig, **kwargs):
super().__init__(config, **kwargs)
self.model = MBartModel(config)
3 changes: 2 additions & 1 deletion src/transformers/adapters/models/roberta/adapter_model.py
@@ -19,13 +19,14 @@
QuestionAnsweringHead,
TaggingHead,
)
from ...model_mixin import EmbeddingAdaptersWrapperMixin


@add_start_docstrings(
"""Roberta Model transformer with the option to add multiple flexible heads on top.""",
ROBERTA_START_DOCSTRING,
)
class RobertaAdapterModel(ModelWithFlexibleHeadsAdaptersMixin, RobertaPreTrainedModel):
class RobertaAdapterModel(EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, RobertaPreTrainedModel):
def __init__(self, config):
super().__init__(config)

3 changes: 2 additions & 1 deletion src/transformers/adapters/models/t5/adapter_model.py
@@ -6,13 +6,14 @@
from ....models.t5.modeling_t5 import T5_INPUTS_DOCSTRING, T5_START_DOCSTRING, T5Model, T5PreTrainedModel
from ....utils import add_start_docstrings, add_start_docstrings_to_model_forward
from ...heads import ModelWithFlexibleHeadsAdaptersMixin, Seq2SeqLMHead
from ...model_mixin import EmbeddingAdaptersWrapperMixin


logger = logging.getLogger(__name__)


@add_start_docstrings("T5 Model with the option to add multiple flexible prediction heads on top.", T5_START_DOCSTRING)
class T5AdapterModel(ModelWithFlexibleHeadsAdaptersMixin, T5PreTrainedModel):
class T5AdapterModel(EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, T5PreTrainedModel):
def __init__(self, config):
super().__init__(config)

11 changes: 6 additions & 5 deletions src/transformers/models/bart/modeling_bart.py
@@ -32,8 +32,9 @@
BartDecoderLayerAdaptersMixin,
BartEncoderLayerAdaptersMixin,
BartModelAdaptersMixin,
BartModelWithHeadsAdaptersMixin,
)
from ...adapters.model_mixin import InvertibleAdaptersMixin, ModelWithHeadsAdaptersMixin
from ...adapters.model_mixin import InvertibleAdaptersMixin
from ...adapters.prefix_tuning import PrefixTuningShim
from ...modeling_outputs import (
BaseModelOutput,
@@ -1302,7 +1303,7 @@ def forward(
@add_start_docstrings(
"The BART Model with a language modeling head. Can be used for summarization.", BART_START_DOCSTRING
)
class BartForConditionalGeneration(ModelWithHeadsAdaptersMixin, BartPretrainedModel):
class BartForConditionalGeneration(BartModelWithHeadsAdaptersMixin, BartPretrainedModel):
base_model_prefix = "model"
_keys_to_ignore_on_load_missing = [r"final_logits_bias", r"lm_head\.weight"]

@@ -1472,7 +1473,7 @@ def _reorder_cache(past, beam_idx):
""",
BART_START_DOCSTRING,
)
class BartForSequenceClassification(ModelWithHeadsAdaptersMixin, BartPretrainedModel):
class BartForSequenceClassification(BartModelWithHeadsAdaptersMixin, BartPretrainedModel):
def __init__(self, config: BartConfig, **kwargs):
super().__init__(config, **kwargs)
self.model = BartModel(config)
@@ -1600,7 +1601,7 @@ def forward(
""",
BART_START_DOCSTRING,
)
class BartForQuestionAnswering(ModelWithHeadsAdaptersMixin, BartPretrainedModel):
class BartForQuestionAnswering(BartModelWithHeadsAdaptersMixin, BartPretrainedModel):
def __init__(self, config):
super().__init__(config)

@@ -1737,7 +1738,7 @@ def get_input_embeddings(self):
return self.decoder.get_input_embeddings()


class BartForCausalLM(ModelWithHeadsAdaptersMixin, BartPretrainedModel):
class BartForCausalLM(BartModelWithHeadsAdaptersMixin, BartPretrainedModel):
def __init__(self, config):
super().__init__(config)
decoder_config = copy.deepcopy(config)
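
Taken together, the BART heads classes above now inherit the embedding API from BartModelWithHeadsAdaptersMixin, so every embedding call is forwarded to the wrapped BartModel. A short hedged check, assuming the mixin is importable from transformers.adapters.mixins.bart and the original embedding is registered as "default" (checkpoint name is a placeholder):

from transformers import BartForConditionalGeneration
from transformers.adapters.mixins.bart import BartModelWithHeadsAdaptersMixin  # assumed import path

# The heads class now carries the wrapper mixin in its MRO.
assert issubclass(BartForConditionalGeneration, BartModelWithHeadsAdaptersMixin)

model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
model.set_active_embeddings("default")  # forwarded to model.model (the BartModel)
assert model.active_embeddings == model.model.active_embeddings == "default"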