Skip to content

Commit

Permalink
Improve _update_causal_mask performance (#29210)
Browse files (browse the repository at this point in the history)
* Fix issue 29206

* Fix style
  • Loading branch information
alessandropalla authored and ArthurZucker committed Mar 1, 2024
1 parent 4f8689e commit 6d02350
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 29 deletions.
25 changes: 15 additions & 10 deletions src/transformers/models/gemma/modeling_gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -959,22 +959,27 @@ def _update_causal_mask(self, attention_mask, input_tensor):
self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)

# We use the current dtype to avoid any overflows
causal_mask = self.causal_mask[None, None, :, :].repeat(batch_size, 1, 1, 1).to(dtype) * torch.finfo(dtype).min
min_dtype = torch.finfo(dtype).min
causal_mask = self.causal_mask[None, None, :, :].repeat(batch_size, 1, 1, 1).to(dtype) * min_dtype

causal_mask = causal_mask.to(dtype=dtype, device=device)
if attention_mask is not None and attention_mask.dim() == 2:
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(
padding_mask, torch.finfo(dtype).min
causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)

if self.config._attn_implementation == "sdpa" and attention_mask is not None:
# TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
is_tracing = (
torch.jit.is_tracing()
or isinstance(input_tensor, torch.fx.Proxy)
or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
)

if self.config._attn_implementation == "sdpa":
is_tracing = torch.jit.is_tracing() or isinstance(input_tensor, torch.fx.Proxy)
if not is_tracing and attention_mask is not None and torch.any(attention_mask != 1):
causal_mask = causal_mask.mul(~torch.all(causal_mask == causal_mask.min(), dim=-1)[..., None]).to(
dtype
)
if not is_tracing and torch.any(attention_mask != 1):
# Attend to all tokens in masked rows from the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
causal_mask = causal_mask.mul(~torch.all(causal_mask == min_dtype, dim=-1, keepdim=True)).to(dtype)

return causal_mask

Expand Down
35 changes: 16 additions & 19 deletions src/transformers/models/llama/modeling_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -1058,31 +1058,28 @@ def _update_causal_mask(self, attention_mask, input_tensor):
causal_mask = torch.full((2 * self.causal_mask.shape[-1], 2 * self.causal_mask.shape[-1]), fill_value=1)
self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)

if hasattr(self, "causal_mask"): # we use the current dtype to avoid any overflows
causal_mask = (
self.causal_mask[None, None, :, :].repeat(batch_size, 1, 1, 1).to(dtype) * torch.finfo(dtype).min
)
else:
mask = torch.full(
(self.config.max_position_embeddings, self.config.max_position_embeddings),
fill_value=torch.finfo(dtype).min,
)
causal_mask = torch.triu(mask, diagonal=1)
# We use the current dtype to avoid any overflows
min_dtype = torch.finfo(dtype).min
causal_mask = self.causal_mask[None, None, :, :].repeat(batch_size, 1, 1, 1).to(dtype) * min_dtype

causal_mask = causal_mask.to(dtype=dtype, device=device)
if attention_mask is not None and attention_mask.dim() == 2:
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(
padding_mask, torch.finfo(dtype).min
causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)

if self.config._attn_implementation == "sdpa" and attention_mask is not None:
# TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
is_tracing = (
torch.jit.is_tracing()
or isinstance(input_tensor, torch.fx.Proxy)
or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
)

if self.config._attn_implementation == "sdpa":
is_tracing = torch.jit.is_tracing() or isinstance(input_tensor, torch.fx.Proxy)
if not is_tracing and attention_mask is not None and torch.any(attention_mask != 1):
causal_mask = causal_mask.mul(~torch.all(causal_mask == causal_mask.min(), dim=-1)[..., None]).to(
dtype
)
if not is_tracing and torch.any(attention_mask != 1):
# Attend to all tokens in masked rows from the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
causal_mask = causal_mask.mul(~torch.all(causal_mask == min_dtype, dim=-1, keepdim=True)).to(dtype)

return causal_mask

Expand Down

0 comments on commit 6d02350

Please sign in to comment.