Commit fd821a3

Fix tensor device while attention_mask is not None (huggingface#23538)
* Fix tensor device while attention_mask is not None
zspo authored and sheonhan committed Jun 1, 2023
1 parent 5e397fd commit fd821a3
Showing 4 changed files with 12 additions and 4 deletions.
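
All four files receive the same change: the scalar clamp tensor holding the dtype's minimum value is now created on attn_weights.device instead of the default (CPU) device, so both arguments to torch.max live on the same device. The standalone sketch below illustrates the before/after pattern; it is not taken from the commit, and the tensor shape, the CUDA check, and the final print are illustrative assumptions.

import torch

# Stand-in for real attention scores; moved to the GPU when one is available.
attn_weights = torch.randn(2, 4, 8, 8)
if torch.cuda.is_available():
    attn_weights = attn_weights.to("cuda")

# Old pattern: torch.tensor(...) defaults to the CPU, so the clamp value does not
# necessarily live on the same device as attn_weights.
# attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))

# New pattern from this commit: the clamp value is allocated on attn_weights.device,
# so both operands of torch.max share a device.
attn_weights = torch.max(
    attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min, device=attn_weights.device)
)
print(attn_weights.device, attn_weights.dtype)
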
src/transformers/models/llama/modeling_llama.py (3 additions, 1 deletion)

@@ -226,7 +226,9 @@ def forward(
                     f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                 )
             attn_weights = attn_weights + attention_mask
-            attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
+            attn_weights = torch.max(
+                attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min, device=attn_weights.device)
+            )
 
         # upcast attention to fp32
         attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
src/transformers/models/open_llama/modeling_open_llama.py (3 additions, 1 deletion)

@@ -249,7 +249,9 @@ def forward(
                     f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                 )
             attn_weights = attn_weights + attention_mask
-            attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
+            attn_weights = torch.max(
+                attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min, device=attn_weights.device)
+            )
 
         # upcast attention to fp32
         attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
src/transformers/models/opt/modeling_opt.py (3 additions, 1 deletion)

@@ -222,7 +222,9 @@ def forward(
                     f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                 )
             attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
-            attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
+            attn_weights = torch.max(
+                attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min, device=attn_weights.device)
+            )
             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
 
         # upcast to fp32 if the weights are in fp16. Please see https://github.com/huggingface/transformers/pull/17437
src/transformers/models/xglm/modeling_xglm.py (3 additions, 1 deletion)

@@ -315,7 +315,9 @@ def forward(
                     f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                 )
             attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
-            attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
+            attn_weights = torch.max(
+                attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min, device=attn_weights.device)
+            )
             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
 
         # upcast to fp32 if the weights are in fp16. Please see https://github.com/huggingface/transformers/pull/17437
