Separate TransformerEmbedding layer
Make it easier to chop the Transformer into pieces for pipeline parallelism (PP).
wconstab committed Feb 2, 2024
1 parent 705e3e0 commit d310f97
Showing 1 changed file with 36 additions and 13 deletions.
49 changes: 36 additions & 13 deletions torchtrain/models/llama/model.py
@@ -278,6 +278,40 @@ def forward(self, x):
         return self.w2(F.silu(self.w1(x)) * self.w3(x))
 
 
+class TransformerEmbedding(nn.Module):
+    def __init__(self, params: ModelArgs):
+        """
+        Initialize the embedding module.
+        """
+        super().__init__()
+        self.params = params
+        self.tok_embeddings = nn.Embedding(
+            params.vocab_size, params.dim
+        )
+
+        self.freqs_cis = precompute_freqs_cis(
+            # Note that self.params.max_seq_len is multiplied by 2 because the token limit for the Llama 2 generation of models is 4096.
+            # Adding this multiplier instead of using 4096 directly allows for dynamism of token lengths while training or fine-tuning.
+            self.params.dim // self.params.n_heads, self.params.max_seq_len * 2
+        )
+
+    def forward(self, tokens: torch.Tensor):
+        """
+        Perform a forward pass through the embedding module.
+
+        Args:
+            tokens (torch.Tensor): Input tensor.
+
+        Returns:
+            Tuple(torch.Tensor, torch.Tensor): Embedded input tensor and freqs_cis
+        """
+        _bsz, seqlen = tokens.shape
+        h = self.tok_embeddings(tokens)
+        self.freqs_cis = self.freqs_cis.to(h.device)
+        freqs_cis = self.freqs_cis[0 : seqlen]
+        return h, freqs_cis
+
+
 class TransformerBlock(nn.Module):
     def __init__(self, layer_id: int, args: ModelArgs):
         """
@@ -360,9 +394,7 @@ def __init__(self, params: ModelArgs):
         self.vocab_size = params.vocab_size
         self.n_layers = params.n_layers
 
-        self.tok_embeddings = nn.Embedding(
-            params.vocab_size, params.dim
-        )
+        self.embeddings = TransformerEmbedding(params)
 
         self.layers = torch.nn.ModuleList()
         for layer_id in range(params.n_layers):
@@ -373,12 +405,6 @@ def __init__(self, params: ModelArgs):
             params.dim, params.vocab_size, bias=False
         )
 
-        self.freqs_cis = precompute_freqs_cis(
-            # Note that self.params.max_seq_len is multiplied by 2 because the token limit for the Llama 2 generation of models is 4096.
-            # Adding this multiplier instead of using 4096 directly allows for dynamism of token lengths while training or fine-tuning.
-            self.params.dim // self.params.n_heads, self.params.max_seq_len * 2
-        )
-
     def forward(self, tokens: torch.Tensor):
         """
         Perform a forward pass through the Transformer model.
@@ -390,10 +416,7 @@
             torch.Tensor: Output logits after applying the Transformer model.
 
         """
-        _bsz, seqlen = tokens.shape
-        h = self.tok_embeddings(tokens)
-        self.freqs_cis = self.freqs_cis.to(h.device)
-        freqs_cis = self.freqs_cis[0 : seqlen]
+        h, freqs_cis = self.embeddings(tokens)
 
         for layer in self.layers:
             h = layer(h, freqs_cis)
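With the token embedding and its freqs_cis table factored out into TransformerEmbedding, the Transformer becomes a plain sequence of submodules (embeddings, the TransformerBlock list, and the output head), which is what makes it straightforward to chop into pipeline-parallel stages. The sketch below illustrates one way that split could look, assuming a manual two-stage partition; the PipelineStage0/PipelineStage1 wrappers, the split index, and the norm attribute (standard in the Llama Transformer but outside the context shown in this diff) are illustrative assumptions, not part of this commit.

# Hypothetical sketch: manually slicing the refactored Transformer into two
# pipeline-parallel stage modules. Attribute names (embeddings, layers, output)
# follow this file; norm, the two-stage split, and the wrapper classes are
# assumptions for illustration only.
import torch
import torch.nn as nn


class PipelineStage0(nn.Module):
    """First stage: TransformerEmbedding plus the first `split` blocks."""

    def __init__(self, model: nn.Module, split: int):
        super().__init__()
        self.embeddings = model.embeddings      # the new TransformerEmbedding
        self.layers = model.layers[:split]      # ModuleList slice -> ModuleList

    def forward(self, tokens: torch.Tensor):
        h, freqs_cis = self.embeddings(tokens)
        for layer in self.layers:
            h = layer(h, freqs_cis)
        # Both the activations and freqs_cis are handed to the next stage.
        return h, freqs_cis


class PipelineStage1(nn.Module):
    """Second stage: remaining blocks, final norm, and output projection."""

    def __init__(self, model: nn.Module, split: int):
        super().__init__()
        self.layers = model.layers[split:]
        self.norm = model.norm                  # assumed final norm on Transformer
        self.output = model.output

    def forward(self, h: torch.Tensor, freqs_cis: torch.Tensor):
        for layer in self.layers:
            h = layer(h, freqs_cis)
        return self.output(self.norm(h))

With this shape, a pipeline runtime only has to ship h and freqs_cis from stage 0 to stage 1 for each microbatch. Before this change, tok_embeddings and freqs_cis lived directly on Transformer, so any stage-wise partition had to special-case them instead of simply slicing a list of child modules.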
