Renamed bsz to bs for consistency; removed dead code
ghstack-source-id: 0b273e8f81013c1c632f0c505b7229d51af3e488
Pull Request resolved: #299
awgu committed May 3, 2024
1 parent 17cda29 commit fda5059
Showing 1 changed file with 6 additions and 7 deletions.
13 changes: 6 additions & 7 deletions torchtitan/models/llama/model.py
@@ -132,7 +132,6 @@ class Attention(nn.Module):
     Attributes:
         n_kv_heads (int): Number of key and value heads.
         n_heads (int): Number of query heads.
-        n_local_kv_heads (int): Number of local key and value heads.
         n_rep (int): Number of repetitions for local heads.
         head_dim (int): Dimension size of each attention head.
         wq (Linear): Linear transformation for queries.
@@ -183,12 +182,12 @@ def forward(
             torch.Tensor: Output tensor after attention.

         """
-        bsz, seqlen, _ = x.shape
+        bs, seqlen, _ = x.shape
         xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

-        xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim)
-        xk = xk.view(bsz, seqlen, self.n_kv_heads, self.head_dim)
-        xv = xv.view(bsz, seqlen, self.n_kv_heads, self.head_dim)
+        xq = xq.view(bs, seqlen, self.n_heads, self.head_dim)
+        xk = xk.view(bs, seqlen, self.n_kv_heads, self.head_dim)
+        xv = xv.view(bs, seqlen, self.n_kv_heads, self.head_dim)

         xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

@@ -205,7 +204,7 @@ def forward(
         output = output.transpose(
             1, 2
         ).contiguous()  # (bs, seqlen, n_local_heads, head_dim)
-        output = output.view(bsz, seqlen, -1)
+        output = output.view(bs, seqlen, -1)
         return self.wo(output)


@@ -421,7 +420,7 @@ def forward(self, tokens: torch.Tensor):
             torch.Tensor: Output logits after applying the Transformer model.

         """
-        _bsz, seqlen = tokens.shape
+        seqlen = tokens.shape[1]
         h = self.tok_embeddings(tokens)
         self.freqs_cis = self.freqs_cis.to(h.device)
         freqs_cis = self.freqs_cis[0:seqlen]
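
For reference, a minimal standalone sketch of the (bs, seqlen, ...) reshape flow this diff renames. The sizes (bs=2, seqlen=8, n_heads=4, head_dim=16) are hypothetical and the four Linear layers stand in for wq/wk/wv/wo; this illustrates the view/transpose pattern only, not the torchtitan module itself.

import torch
import torch.nn.functional as F

# Hypothetical sizes, for illustration only.
bs, seqlen, n_heads, head_dim = 2, 8, 4, 16
dim = n_heads * head_dim

x = torch.randn(bs, seqlen, dim)
wq, wk, wv, wo = (torch.nn.Linear(dim, dim, bias=False) for _ in range(4))

# Split heads: (bs, seqlen, dim) -> (bs, seqlen, n_heads, head_dim)
xq = wq(x).view(bs, seqlen, n_heads, head_dim)
xk = wk(x).view(bs, seqlen, n_heads, head_dim)
xv = wv(x).view(bs, seqlen, n_heads, head_dim)

# Scaled dot-product attention expects (bs, n_heads, seqlen, head_dim)
output = F.scaled_dot_product_attention(
    xq.transpose(1, 2), xk.transpose(1, 2), xv.transpose(1, 2), is_causal=True
)

# Transpose back and merge heads: (bs, seqlen, n_heads * head_dim)
output = output.transpose(1, 2).contiguous().view(bs, seqlen, -1)
print(wo(output).shape)  # torch.Size([2, 8, 64])

The final view(bs, seqlen, -1) merges the head dimensions back into the model dimension, which is why the transpose must be made contiguous first.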
