Skip to content

Commit

Permalink
small bugs in LLM message passing (#6)
Browse files Browse the repository at this point in the history
* i thought this was already fixed

* fixed edge case

* forgot to remove pdb

* removed incorrect assert
  • Loading branch information
zachares authored May 24, 2023
1 parent b252932 commit 621f6a6
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 7 deletions.
6 changes: 3 additions & 3 deletions src/transformers/models/bloom/modeling_bloom.py
Original file line number Diff line number Diff line change
Expand Up @@ -780,10 +780,10 @@ def forward(
else:
attention_mask = attention_mask.to(hidden_states.device)

if hasattr(self, "graph_tokens"):
if hasattr(self, "graph_token_ids"):
assert input_ids.shape == attention_mask.shape
edge_sequences = [
extract_edge_sequence(t_ids.tolist(), self.graph_tokens) for t_ids in input_ids
extract_edge_sequence(t_ids.tolist(), self.graph_token_ids) for t_ids in input_ids
]
if self.message_passing_type == 'nodes':
get_matrices = GatedCausalMessagePassingLayer.build_node_information_passing
Expand Down Expand Up @@ -831,7 +831,7 @@ def custom_forward(*inputs):
)

hidden_states = outputs[0]
if i <= self.num_gnn_layers and hasattr(self, 'graph_tokens'):
if hasattr(self, 'graph_token_ids') and i < self.num_gnn_layers:
hidden_states = self.graph_information_passing_layers[i](
hidden_states,
message_passing_dicts
Expand Down
7 changes: 3 additions & 4 deletions src/transformers/models/t5/modeling_t5.py
Original file line number Diff line number Diff line change
Expand Up @@ -1062,11 +1062,10 @@ def forward(

hidden_states = self.dropout(inputs_embeds)

if hasattr(self, "graph_tokens"):
if hasattr(self, "graph_token_ids"):
assert input_ids.shape == attention_mask.shape
assert input_ids[0, 0] == self.graph_tokens['gen_edge'], "Incorrect stating token"
edge_sequences = [
extract_edge_sequence(t_ids.tolist(), self.graph_tokens) for t_ids in input_ids
extract_edge_sequence(t_ids.tolist(), self.graph_token_ids) for t_ids in input_ids
]
if self.message_passing_type == 'nodes':
get_matrices = GatedCausalMessagePassingLayer.build_node_information_passing
Expand Down Expand Up @@ -1140,7 +1139,7 @@ def custom_forward(*inputs):

hidden_states, present_key_value_state = layer_outputs[:2]

if i <= self.num_gnn_layers and hasattr(self, 'graph_tokens'):
if hasattr(self, 'graph_token_ids') and i < self.num_gnn_layers:
hidden_states = self.graph_information_passing_layers[i](
hidden_states,
message_passing_dicts
Expand Down

0 comments on commit 621f6a6

Please sign in to comment.