Skip to content

Commit

Permalink
Fix training T5 adapter models with Trainer (#599)
Browse files Browse the repository at this point in the history
  • Loading branch information
calpt authored Nov 15, 2023
1 parent 6ec40c8 commit a6055d0
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 6 deletions.
11 changes: 8 additions & 3 deletions src/adapters/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,12 @@
from .gpt2.mixin_gpt2 import GPT2ModelAdapterMixin
from .gptj.mixin_gptj import GPTJMLPAdaptersMixin, GPTJModelAdapterMixin
from .llama.mixin_llama import LlamaModelAdapterMixin
from .t5.mixin_t5 import T5BlockAdaptersMixin, T5ModelAdaptersMixin, T5ModelAdaptersWithHeadsMixin
from .t5.mixin_t5 import (
T5BlockAdaptersMixin,
T5ForCondiditionalGenerationWithHeadsMixin,
T5ForQuestionAnsweringWithHeadsMixin,
T5ModelAdaptersMixin,
)
from .vit.mixin_vit import ViTIntermediateAdaptersMixin, ViTModelAdaptersMixin
from .xmod.mixin_xmod import XmodModelAdaptersMixin

Expand Down Expand Up @@ -57,8 +62,8 @@
"RobertaModel": BertModelAdaptersMixin,
"T5Block": T5BlockAdaptersMixin,
"T5Model": T5ModelAdaptersMixin,
"T5ForConditionalGeneration": T5ModelAdaptersWithHeadsMixin,
"T5ForQuestionAnswering": T5ModelAdaptersWithHeadsMixin,
"T5ForConditionalGeneration": T5ForCondiditionalGenerationWithHeadsMixin,
"T5ForQuestionAnswering": T5ForQuestionAnsweringWithHeadsMixin,
"T5EncoderModel": T5ModelAdaptersMixin,
"ViTIntermediate": ViTIntermediateAdaptersMixin,
"ViTModel": ViTModelAdaptersMixin,
Expand Down
30 changes: 27 additions & 3 deletions src/adapters/models/t5/mixin_t5.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Iterable, Tuple
from typing import Iterable, Optional, Tuple

import torch
import torch.nn as nn

from ...methods.bottleneck import BottleneckLayer
Expand Down Expand Up @@ -99,5 +100,28 @@ def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]:
yield i, layer


class T5ModelAdaptersWithHeadsMixin(ModelWithHeadsAdaptersMixin, T5ModelAdaptersMixin):
    """Combine ``ModelWithHeadsAdaptersMixin`` and ``T5ModelAdaptersMixin``; defines no members of its own."""
# Naming "labels" and "input_ids" explicitly in the signature is required for training with the Trainer class.
class T5ForCondiditionalGenerationWithHeadsMixin(ModelWithHeadsAdaptersMixin, T5ModelAdaptersMixin):
    """Adapter mixin for T5 conditional-generation models with heads.

    ``forward`` is overridden only so that ``input_ids`` and ``labels`` appear
    by name in its signature; the call is otherwise delegated to the next
    ``forward`` in the method resolution order.
    """

    def forward(
        self,
        *args,
        input_ids: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        **kwargs,
    ):
        """Delegate to the parent ``forward`` with ``input_ids`` and ``labels`` named explicitly.

        Args:
            *args: Positional arguments passed through unchanged.
            input_ids: Forwarded as a keyword argument when provided.
            labels: Forwarded as a keyword argument.
            **kwargs: Remaining keyword arguments passed through unchanged.

        Returns:
            Whatever the parent ``forward`` returns.
        """
        if args and input_ids is None:
            # The caller used positional arguments and did not set the
            # ``input_ids`` keyword. Presumably the first positional argument
            # of the parent ``forward`` is ``input_ids`` (TODO confirm against
            # the wrapped model); forwarding ``input_ids=None`` as well would
            # then fail with "got multiple values for argument 'input_ids'".
            return super().forward(*args, labels=labels, **kwargs)
        return super().forward(*args, input_ids=input_ids, labels=labels, **kwargs)


# Naming "start_positions"/"end_positions" and "input_ids" explicitly in the signature is required for training with the Trainer class.
class T5ForQuestionAnsweringWithHeadsMixin(ModelWithHeadsAdaptersMixin, T5ModelAdaptersMixin):
    """Adapter mixin for T5 question-answering models with heads.

    ``forward`` is overridden only so that ``input_ids``, ``start_positions``
    and ``end_positions`` appear by name in its signature; the call is
    otherwise delegated to the next ``forward`` in the method resolution order.
    """

    def forward(
        self,
        *args,
        input_ids: Optional[torch.LongTensor] = None,
        start_positions: Optional[torch.LongTensor] = None,
        end_positions: Optional[torch.LongTensor] = None,
        **kwargs,
    ):
        """Delegate to the parent ``forward`` with the span-label arguments named explicitly.

        Args:
            *args: Positional arguments passed through unchanged.
            input_ids: Forwarded as a keyword argument when provided.
            start_positions: Forwarded as a keyword argument.
            end_positions: Forwarded as a keyword argument.
            **kwargs: Remaining keyword arguments passed through unchanged.

        Returns:
            Whatever the parent ``forward`` returns.
        """
        if args and input_ids is None:
            # The caller used positional arguments and did not set the
            # ``input_ids`` keyword. Presumably the first positional argument
            # of the parent ``forward`` is ``input_ids`` (TODO confirm against
            # the wrapped model); forwarding ``input_ids=None`` as well would
            # then fail with "got multiple values for argument 'input_ids'".
            return super().forward(
                *args, start_positions=start_positions, end_positions=end_positions, **kwargs
            )
        return super().forward(
            *args, input_ids=input_ids, start_positions=start_positions, end_positions=end_positions, **kwargs
        )

0 comments on commit a6055d0

Please sign in to comment.