Multi-gpu support for prefix-tuning (#359)
Co-authored-by: Alexander Li <[email protected]>
alexanderhanboli and Alexander Li committed Jul 5, 2022
1 parent d56e9a5 commit 74dd021
Showing 1 changed file with 8 additions and 9 deletions.
src/transformers/adapters/prefix_tuning.py: 17 changes (8 additions, 9 deletions)
@@ -3,14 +3,15 @@
 import torch
 from torch import nn
 
+from ..modeling_utils import ModuleUtilsMixin
 from .composition import AdapterCompositionBlock
 from .configuration import PrefixTuningConfig
 from .context import AdapterSetup, ForwardContext
 from .layer import AdapterLayerBase
 from .modeling import Activation_Function_Class
 
 
-class PrefixTuning(nn.Module):
+class PrefixTuning(nn.Module, ModuleUtilsMixin):
     def __init__(
         self,
         n_layers: int,
@@ -25,7 +26,6 @@ def __init__(
         self.n_embd_per_head = self.input_size // self.n_heads
         self.config = config
 
-        self.input_tokens = torch.arange(self.config.prefix_length).long()
         self.wte = nn.Embedding(self.config.prefix_length, self.input_size)
         self.control_trans = nn.Sequential(
             nn.Linear(self.input_size, self.config.bottleneck_size),
@@ -35,8 +35,8 @@
         self.dropout = nn.Dropout(self.config.dropout)
 
     def eject(self):
-        device = next(self.parameters()).device
-        input_tokens = self.input_tokens.unsqueeze(0).expand(1, -1).to(device)
+        input_tokens = torch.arange(self.config.prefix_length).long()
+        input_tokens = input_tokens.unsqueeze(0).expand(1, -1).to(self.device)
         embs = self.wte(input_tokens)
         key_values = self.control_trans(embs)  # batch_size x prefix_length x n_layers*2*input_size
         key_values = key_values.view(
@@ -46,8 +46,8 @@ def eject(self):
         return key_values
 
     def forward(self, batch_size):
-        device = next(self.parameters()).device
-        input_tokens = self.input_tokens.unsqueeze(0).expand(batch_size, -1).to(device)
+        input_tokens = torch.arange(self.config.prefix_length).long()
+        input_tokens = input_tokens.unsqueeze(0).expand(batch_size, -1).to(self.device)
         embs = self.wte(input_tokens)
         key_values = self.control_trans(embs)  # batch_size x prefix_length x n_layers*2*input_size
         key_values = key_values.view(
@@ -60,7 +60,7 @@ def forward(self, batch_size):
         return key_values
 
 
-class FlatPrefixTuning(nn.Module):
+class FlatPrefixTuning(nn.Module, ModuleUtilsMixin):
     def __init__(
         self,
         n_layers: int,
@@ -80,12 +80,11 @@ def __init__(
         self.dropout = nn.Dropout(self.config.dropout)
 
     def forward(self, batch_size):
-        device = next(self.parameters()).device
         key_values = (
             self.control_trans.unsqueeze(0)
             .expand(batch_size, -1)
             .view(batch_size, self.config.prefix_length, self.n_layers * 2, self.n_heads, self.n_embd_per_head)
-            .to(device)
+            .to(self.device)
         )  # *2 for key and value
         key_values = self.dropout(key_values)
         # n_layers * (2 x batch_size x n_heads x prefix_length x n_embd_per_head)
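
The pattern behind this change, in brief: before the fix, the prefix indices lived as a plain CPU tensor attribute created in __init__ and the target device was looked up with next(self.parameters()).device, which is fragile for nn.DataParallel replicas; after the fix, the index tensor is built on the fly inside forward/eject and moved with the device property inherited from ModuleUtilsMixin, which is more robust in that setting. Below is a minimal, self-contained sketch of the same pattern, assuming a recent transformers install; the ToyPrefix class, its sizes, and the __main__ check are illustrative only and are not part of this commit.

import torch
from torch import nn
from transformers.modeling_utils import ModuleUtilsMixin


class ToyPrefix(nn.Module, ModuleUtilsMixin):
    """Illustrative module mirroring the device-handling pattern of the fix."""

    def __init__(self, prefix_length: int = 4, hidden_size: int = 8):
        super().__init__()
        self.prefix_length = prefix_length
        # Learnable prefix embeddings; they live on whatever device the module is moved to.
        self.wte = nn.Embedding(prefix_length, hidden_size)

    def forward(self, batch_size: int) -> torch.Tensor:
        # Build the index tensor per call and place it via ModuleUtilsMixin's
        # `device` property, instead of keeping a CPU tensor attribute from __init__.
        input_tokens = torch.arange(self.prefix_length).long()
        input_tokens = input_tokens.unsqueeze(0).expand(batch_size, -1).to(self.device)
        return self.wte(input_tokens)  # batch_size x prefix_length x hidden_size


if __name__ == "__main__":
    prefix = ToyPrefix()
    print(prefix(batch_size=2).shape)  # torch.Size([2, 4, 8])

Wrapped in torch.nn.DataParallel, each replica should resolve self.device to its own GPU, so the index tensor ends up alongside that replica's copy of wte rather than on the default device.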
