diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py
index 848216885a36ef..29aa95a9815cac 100644
--- a/src/transformers/models/bloom/modeling_bloom.py
+++ b/src/transformers/models/bloom/modeling_bloom.py
@@ -775,10 +775,10 @@ def forward(
         else:
             attention_mask = attention_mask.to(hidden_states.device)
 
-        if hasattr(self, "graph_tokens"):
+        if hasattr(self, "graph_token_ids"):
             assert input_ids.shape == attention_mask.shape
             edge_sequences = [
-                extract_edge_sequence(t_ids.tolist(), self.graph_tokens) for t_ids in input_ids
+                extract_edge_sequence(t_ids.tolist(), self.graph_token_ids) for t_ids in input_ids
             ]
             if self.message_passing_type == 'nodes':
                 get_matrices = GatedCausalMessagePassingLayer.build_node_information_passing
@@ -826,7 +826,7 @@ def custom_forward(*inputs):
             )
 
             hidden_states = outputs[0]
-            if i <= self.num_gnn_layers and hasattr(self, 'graph_tokens'):
+            if hasattr(self, 'graph_token_ids') and i < self.num_gnn_layers:
                 hidden_states = self.graph_information_passing_layers[i](
                     hidden_states,
                     message_passing_dicts
diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py
index 1def99a3a09ecb..e3bc02421f64ca 100644
--- a/src/transformers/models/t5/modeling_t5.py
+++ b/src/transformers/models/t5/modeling_t5.py
@@ -1099,11 +1099,10 @@ def forward(
 
         hidden_states = self.dropout(inputs_embeds)
 
-        if hasattr(self, "graph_tokens"):
+        if hasattr(self, "graph_token_ids"):
             assert input_ids.shape == attention_mask.shape
-            assert input_ids[0, 0] == self.graph_tokens['gen_edge'], "Incorrect stating token"
             edge_sequences = [
-                extract_edge_sequence(t_ids.tolist(), self.graph_tokens) for t_ids in input_ids
+                extract_edge_sequence(t_ids.tolist(), self.graph_token_ids) for t_ids in input_ids
             ]
             if self.message_passing_type == 'nodes':
                 get_matrices = GatedCausalMessagePassingLayer.build_node_information_passing
@@ -1177,7 +1176,7 @@ def custom_forward(*inputs):
 
             hidden_states, present_key_value_state = layer_outputs[:2]
-            if i <= self.num_gnn_layers and hasattr(self, 'graph_tokens'):
+            if hasattr(self, 'graph_token_ids') and i < self.num_gnn_layers:
                 hidden_states = self.graph_information_passing_layers[i](
                     hidden_states,
                     message_passing_dicts