Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

multi-gpu support for prefix-tuning #359

Merged
merged 7 commits into from
Jul 5, 2022
Merged
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 7 additions & 8 deletions src/transformers/adapters/prefix_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@
from .context import AdapterSetup, ForwardContext
from .layer import AdapterLayerBase
from .modeling import Activation_Function_Class
from ..modeling_utils import ModuleUtilsMixin


class PrefixTuning(nn.Module):
class PrefixTuning(nn.Module, ModuleUtilsMixin):
def __init__(
self,
n_layers: int,
Expand All @@ -25,7 +26,6 @@ def __init__(
self.n_embd_per_head = self.input_size // self.n_heads
self.config = config

self.input_tokens = torch.arange(self.config.prefix_length).long()
self.wte = nn.Embedding(self.config.prefix_length, self.input_size)
self.control_trans = nn.Sequential(
nn.Linear(self.input_size, self.config.bottleneck_size),
Expand All @@ -35,8 +35,8 @@ def __init__(
self.dropout = nn.Dropout(self.config.dropout)

def eject(self):
device = next(self.parameters()).device
input_tokens = self.input_tokens.unsqueeze(0).expand(1, -1).to(device)
input_tokens = torch.arange(self.config.prefix_length).long()
input_tokens = input_tokens.unsqueeze(0).expand(1, -1).to(self.device)
embs = self.wte(input_tokens)
key_values = self.control_trans(embs) # batch_size x prefix_length x n_layers*2*input_size
key_values = key_values.view(
Expand All @@ -46,8 +46,8 @@ def eject(self):
return key_values

def forward(self, batch_size):
device = next(self.parameters()).device
input_tokens = self.input_tokens.unsqueeze(0).expand(batch_size, -1).to(device)
input_tokens = torch.arange(self.config.prefix_length).long()
input_tokens = input_tokens.unsqueeze(0).expand(batch_size, -1).to(self.device)
embs = self.wte(input_tokens)
key_values = self.control_trans(embs) # batch_size x prefix_length x n_layers*2*input_size
key_values = key_values.view(
Expand Down Expand Up @@ -80,12 +80,11 @@ def __init__(
self.dropout = nn.Dropout(self.config.dropout)

def forward(self, batch_size):
device = next(self.parameters()).device
key_values = (
self.control_trans.unsqueeze(0)
.expand(batch_size, -1)
.view(batch_size, self.config.prefix_length, self.n_layers * 2, self.n_heads, self.n_embd_per_head)
.to(device)
.to(self.device)
alexanderhanboli marked this conversation as resolved.
Show resolved Hide resolved
) # *2 for key and value
key_values = self.dropout(key_values)
# n_layers * (2 x batch_size x n_heads x prefix_length x n_embd_per_head)
Expand Down