From 8c91f15ae576a5f33559aa243218199697195279 Mon Sep 17 00:00:00 2001 From: Folco Bertini Baldassini <46280006+folbaeni@users.noreply.github.com> Date: Tue, 7 Nov 2023 17:26:15 +0100 Subject: [PATCH] Resolve AttributeError by utilizing device calculation at the start of the forward function (#27347) This commit addresses the 'NoneType' object AttributeError within the IdeficsModel forward function. Previously, the 'device' attribute was accessed directly from input_ids, resulting in a potential 'NoneType' error. Now, the device is properly calculated at the beginning of the forward function and utilized consistently throughout, ensuring the 'image_hidden_states' are derived from the correct device. This modification enables smoother processing and compatibility, ensuring the correct device attribution for 'image_encoder_embeddings' in the IdeficsModel forward pass. --- src/transformers/models/idefics/modeling_idefics.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/transformers/models/idefics/modeling_idefics.py b/src/transformers/models/idefics/modeling_idefics.py index aba0b43f695b3e..f7881ddd39ed77 100644 --- a/src/transformers/models/idefics/modeling_idefics.py +++ b/src/transformers/models/idefics/modeling_idefics.py @@ -1161,7 +1161,6 @@ def forward( position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) elif position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device position_ids = torch.arange( past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device ) @@ -1186,7 +1185,7 @@ def forward( elif image_encoder_embeddings is not None: batch_size, num_images, image_seq_len, image_hidden_size = image_encoder_embeddings.size() - image_hidden_states = image_encoder_embeddings.to(dtype=self.dtype, device=input_ids.device) + image_hidden_states = image_encoder_embeddings.to(dtype=self.dtype, device=device) image_hidden_states = image_hidden_states.view(batch_size * num_images, image_seq_len, image_hidden_size) if self.config.use_resampler: