From 621f6a641975cdd1d82d0f3272e231047f92c90e Mon Sep 17 00:00:00 2001 From: Peter Zachares <32395644+zachares@users.noreply.github.com> Date: Wed, 24 May 2023 13:25:05 +0100 Subject: [PATCH] small bugs in LLM message passing (#6) * i thought this was already fixed * fixed edge case * forgot to remove pdb * removed incorrect assert --- src/transformers/models/bloom/modeling_bloom.py | 6 +++--- src/transformers/models/t5/modeling_t5.py | 7 +++---- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py index 9884b0bff7f11b..f2d43079148888 100644 --- a/src/transformers/models/bloom/modeling_bloom.py +++ b/src/transformers/models/bloom/modeling_bloom.py @@ -780,10 +780,10 @@ def forward( else: attention_mask = attention_mask.to(hidden_states.device) - if hasattr(self, "graph_tokens"): + if hasattr(self, "graph_token_ids"): assert input_ids.shape == attention_mask.shape edge_sequences = [ - extract_edge_sequence(t_ids.tolist(), self.graph_tokens) for t_ids in input_ids + extract_edge_sequence(t_ids.tolist(), self.graph_token_ids) for t_ids in input_ids ] if self.message_passing_type == 'nodes': get_matrices = GatedCausalMessagePassingLayer.build_node_information_passing @@ -831,7 +831,7 @@ def custom_forward(*inputs): ) hidden_states = outputs[0] - if i <= self.num_gnn_layers and hasattr(self, 'graph_tokens'): + if hasattr(self, 'graph_token_ids') and i < self.num_gnn_layers: hidden_states = self.graph_information_passing_layers[i]( hidden_states, message_passing_dicts diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index 1f308752dc121a..c34ead09369797 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -1062,11 +1062,10 @@ def forward( hidden_states = self.dropout(inputs_embeds) - if hasattr(self, "graph_tokens"): + if hasattr(self, "graph_token_ids"): assert 
input_ids.shape == attention_mask.shape - assert input_ids[0, 0] == self.graph_tokens['gen_edge'], "Incorrect stating token" edge_sequences = [ - extract_edge_sequence(t_ids.tolist(), self.graph_tokens) for t_ids in input_ids + extract_edge_sequence(t_ids.tolist(), self.graph_token_ids) for t_ids in input_ids ] if self.message_passing_type == 'nodes': get_matrices = GatedCausalMessagePassingLayer.build_node_information_passing @@ -1140,7 +1139,7 @@ def custom_forward(*inputs): hidden_states, present_key_value_state = layer_outputs[:2] - if i <= self.num_gnn_layers and hasattr(self, 'graph_tokens'): + if hasattr(self, 'graph_token_ids') and i < self.num_gnn_layers: hidden_states = self.graph_information_passing_layers[i]( hidden_states, message_passing_dicts