Skip to content

Commit

Permalink
is_torchdynamo_compiling -- cast a wide exception net (#32476)
Browse files Browse the repository at this point in the history
* cast a wide net

* make fix-copies with a few manual changes

* add copied from
  • Loading branch information
gante authored and ArthurZucker committed Aug 6, 2024
1 parent 3e93524 commit af61272
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 23 deletions.
114 changes: 92 additions & 22 deletions src/transformers/models/nemotron/modeling_nemotron.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,60 @@
_CONFIG_FOR_DOC = "NemotronConfig"


# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
min_dtype: float,
cache_position: torch.Tensor,
batch_size: int,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to plcae the 4D attention mask on.
min_dtype (`float`):
The minimum value representable with the dtype `dtype`.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)

return causal_mask


def _cast_if_autocast_enabled(*args):
if not torch.is_autocast_enabled():
return args
Expand Down Expand Up @@ -902,27 +956,18 @@ def _update_causal_mask(
else past_seen_tokens + sequence_length + 1
)

if attention_mask is not None and attention_mask.dim() == 4:
# in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
if attention_mask.max() != 0:
raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`")
causal_mask = attention_mask
else:
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
min_dtype=min_dtype,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)

if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
Expand Down Expand Up @@ -1086,11 +1131,36 @@ def prepare_inputs_for_generation(
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]

# This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
position_ids = position_ids.clone(memory_format=torch.contiguous_format)

# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and cache_position[0] == 0:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases
model_inputs = {"input_ids": input_ids}

if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
if inputs_embeds is not None:
batch_size, sequence_length = inputs_embeds.shape
device = inputs_embeds.device
else:
batch_size, sequence_length = input_ids.shape
device = input_ids.device

dtype = self.lm_head.weight.dtype
min_dtype = torch.finfo(dtype).min

attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=past_key_values.get_max_length(),
dtype=dtype,
device=device,
min_dtype=min_dtype,
cache_position=cache_position,
batch_size=batch_size,
)

model_inputs.update(
{
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/utils/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -692,7 +692,7 @@ def is_torchdynamo_compiling():
import torch

return torch.compiler.is_compiling()
except AttributeError:
except Exception:
try:
import torch._dynamo as dynamo # noqa: F401

Expand Down

0 comments on commit af61272

Please sign in to comment.