From 60ad9c5dd2840678688acb97cb643ff2963f600a Mon Sep 17 00:00:00 2001
From: Pedro Cuenca
Date: Mon, 13 May 2024 13:37:30 +0200
Subject: [PATCH] Shift attention mask from `1:`

After discussion with @molbap
---
 src/transformers/models/paligemma/modeling_paligemma.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/models/paligemma/modeling_paligemma.py b/src/transformers/models/paligemma/modeling_paligemma.py
index 46e4b3b3ed0415..a2534a572a6f3f 100644
--- a/src/transformers/models/paligemma/modeling_paligemma.py
+++ b/src/transformers/models/paligemma/modeling_paligemma.py
@@ -463,7 +463,7 @@ def forward(
             if attention_mask.dim() == 4:
                 # take top or bottom row of the 4d mask.
                 # this should only be used in the initial pass with full attention on prefix.
-                shift_attention_mask = attention_mask[:, 0, 0, :-1].squeeze(1) if not left_padding else attention_mask[:, 0, -1, :-1].squeeze(1)
+                shift_attention_mask = attention_mask[:, 0, 0, 1:].squeeze(1) if not left_padding else attention_mask[:, 0, -1, 1:].squeeze(1)
             elif attention_mask.dim() == 2:
                 # take normal slice of the attn mask
                 shift_attention_mask = attention_mask[..., 1:]
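
Note: below is a minimal sketch of why the `1:` slice is the right one, assuming the shifted mask is consumed the same way as in the existing 2D branch (standard causal-LM label shifting). The variable names, shapes, and loss computation here are illustrative only and are not taken from the patch or the surrounding file.

    import torch
    import torch.nn.functional as F

    batch, seq_len, vocab = 2, 6, 32
    logits = torch.randn(batch, seq_len, vocab)         # model outputs (hypothetical)
    labels = torch.randint(0, vocab, (batch, seq_len))  # next-token targets (hypothetical)
    attention_mask = torch.ones(batch, seq_len)
    attention_mask[1, :2] = 0                            # pretend example 1 has 2 padded positions

    # Standard causal-LM shift: position t predicts token t+1.
    shift_logits = logits[..., :-1, :]
    shift_labels = labels[..., 1:]

    # The mask has to line up with shift_labels, so the *first* position is dropped
    # (`[..., 1:]`), not the last (`[..., :-1]`): that is what the patch changes in
    # the 4D-mask branch, matching the existing 2D branch.
    shift_attention_mask = attention_mask[..., 1:]

    active = shift_attention_mask.bool()
    loss = F.cross_entropy(shift_logits[active], shift_labels[active])
    print(loss)

With the previous `:-1` slice, the 4D branch kept the first position and dropped the last one, so its mask was offset by one relative to `shift_labels`; slicing with `1:` keeps both branches consistent.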