memory efficient message passing (#4)
zachares authored and vahanhov committed Aug 11, 2023
1 parent 3c1ca82 commit 13affe7
Showing 1 changed file with 4 additions and 46 deletions.
50 changes: 4 additions & 46 deletions src/transformers/models/bloom/causal_message_passing.py
@@ -23,28 +23,6 @@ class GNNLayerFactory(enum.Enum):
     gat = torch_geometric.nn.GATConv
 
 
-def graph_cross_attention(
-    values: torch.Tensor,
-    key_representations: torch.Tensor,
-    query_representations: torch.Tensor,
-    edge_index: torch.Tensor
-) -> torch.Tensor:
-    """ Performs graph attention on a set of prior probabilities using the representation of each
-    node in the graph to calculate the attention weights. The implemented attention is dot
-    product attention as implemented in the transformer architecture
-    """
-    scaling_constant = torch.Tensor(
-        np.sqrt([key_representations.size(1)])
-    ).to(key_representations.device)
-    dot_products = (
-        query_representations[edge_index[1]]
-        * key_representations[edge_index[0]]
-    ).sum(1) / scaling_constant
-    weights = scatter_softmax(src=dot_products, index=edge_index[1], dim=0)
-    weighted_probs = weights.unsqueeze(1) * values[edge_index[0]]
-    return scatter(src=weighted_probs, index=edge_index[1], dim=0)
-
-
 class GatedGraphCrossAttentionLayer(torch.nn.Module):
     """ A module for performing gated cross attention between elements in a graph that
     have been serialized in a sequence of tokens and the token sequence
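For context, a minimal self-contained sketch of the scatter-based dot-product attention that the removed graph_cross_attention computed. The shapes and the example edge_index are illustrative assumptions, not values from this repository:

import torch
from torch_scatter import scatter, scatter_softmax

# Illustrative: 3 graph elements attending into 2 token positions.
values = torch.randn(3, 8)    # per-element value vectors
keys = torch.randn(3, 8)      # per-element key representations
queries = torch.randn(2, 8)   # per-token query representations
edge_index = torch.tensor([[0, 1, 2], [0, 0, 1]])  # row 0: source element, row 1: target token

scale = keys.size(1) ** 0.5
# Scaled dot product per edge, as in standard transformer attention.
scores = (queries[edge_index[1]] * keys[edge_index[0]]).sum(1) / scale
# Softmax over all edges that share a target token, then a weighted sum per token.
weights = scatter_softmax(src=scores, index=edge_index[1], dim=0)
out = scatter(src=weights.unsqueeze(1) * values[edge_index[0]], index=edge_index[1], dim=0)
print(out.shape)  # torch.Size([2, 8])

This materializes one score and one weighted value vector per edge, which is the memory cost the commit eliminates.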
@@ -62,10 +40,6 @@ def __init__(self, gnn_type: str, embedding_size: int):
         super().__init__()
         self.gnn_layer = GNNLayerFactory[gnn_type].value(embedding_size, embedding_size)
         self.gating_message_passing = torch.nn.Parameter(torch.zeros(1))
-        self.gating_linear = torch.nn.Parameter(torch.zeros(1))
-        self.key_embedder = torch.nn.Linear(embedding_size, 64)
-        self.query_embedder = torch.nn.Linear(embedding_size, 64)
-        self.linear_layer = torch.nn.Linear(embedding_size, embedding_size)
 
     def forward(
         self,
@@ -82,15 +56,8 @@ def forward(
                 element_embeddings,
                 message_passing_dict['edge_index']
             )
-            start_idx, end_idx = message_passing_dict['slice_idxs']
-            new_t_embeddings[start_idx:end_idx] = graph_cross_attention(
-                values=element_embeddings,
-                key_representations=self.key_embedder(element_embeddings),
-                query_representations=self.query_embedder(t_embeddings),
-                edge_index=message_passing_dict['elements2tokens']
-            )[start_idx:]
+            new_t_embeddings[message_passing_dict['elements2tokens']] = element_embeddings
             new_t_embeddings = t_embeddings + torch.tanh(self.gating_message_passing) * new_t_embeddings
-            new_t_embeddings = new_t_embeddings + torch.tanh(self.gating_linear) * self.linear_layer(new_t_embeddings)
             new_token_embeddings.append(new_t_embeddings.unsqueeze(0))
         return torch.cat(new_token_embeddings, dim=0)
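This hunk is the heart of the change: each graph element now maps to exactly one token position, so the per-edge attention tensors (scores, softmax weights, key/query projections) are replaced by a single index assignment. A minimal sketch of the new aggregation step, with assumed shapes and a zero-initialized buffer standing in for code hidden above this hunk:

import torch

t_embeddings = torch.randn(10, 16)            # serialized token embeddings
element_embeddings = torch.randn(4, 16)       # one embedding per graph element
elements2tokens = torch.tensor([2, 4, 6, 8])  # token position written by each element

# Direct write-by-index: no attention weights or projections are allocated.
new_t_embeddings = torch.zeros_like(t_embeddings)
new_t_embeddings[elements2tokens] = element_embeddings
# Gated residual, as in the layer's forward pass (gate starts at zero).
gate = torch.zeros(1)
out = t_embeddings + torch.tanh(gate) * new_t_embeddings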

@@ -151,7 +118,6 @@ def build_edge_information_passing(
         )
         for sequenced_edge in edge_sequence[:-1]:
             add_edge(sequenced_edge)
-
         # calculating adjacency matrix between edges (edges in this adjacency matrix always
         # point from edges earlier in the serialized version of the graph to edges later in
         # the graph)
@@ -166,7 +132,7 @@
     @staticmethod
     def get_sequence_end(
         edge_sequence: List[Tuple[SequenceElement, Optional[SequenceElement], Optional[SequenceElement]]],
-    ) -> Tuple[int, int]:
+    ) -> int:
         """ Returns last index + 1 of elements in the serialized graph sequence """
         pred_node, edge, succ_node = edge_sequence[-1]
         if isinstance(succ_node, SequenceElement):
@@ -242,11 +208,8 @@ def add_element_for_information_passing(
         and back
         """
         if start_idx != end_idx:
-            message_passing_dict[f"tokens2elements"].append(start_idx - 1)
-            for sequence_idx in range(start_idx, end_idx):
-                message_passing_dict[f"elements2tokens"].append(
-                    [len(message_passing_dict[f"tokens2elements"]) - 1, sequence_idx]
-                )
+            message_passing_dict["tokens2elements"].append(start_idx - 1)
+            message_passing_dict["elements2tokens"].append(start_idx)
 
     @staticmethod
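A worked example of the new bookkeeping under assumed indices: for an element serialized at token positions 5 through 7 (start_idx=5, end_idx=8), the dict now records a single one-to-one pair instead of one entry per covered token:

message_passing_dict = {"tokens2elements": [], "elements2tokens": []}
start_idx, end_idx = 5, 8  # hypothetical element span in the token sequence
if start_idx != end_idx:
    # The token just before the span feeds information into the element...
    message_passing_dict["tokens2elements"].append(start_idx - 1)  # -> [4]
    # ...and the element writes its embedding back to the span's first token.
    message_passing_dict["elements2tokens"].append(start_idx)      # -> [5]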
def to_torch(
Expand All @@ -260,9 +223,4 @@ def to_torch(
array_dict[key] = torch.from_numpy(np.array(array)).long().to(device)
else:
array_dict[key] = torch.from_numpy(np.array(array).transpose(1, 0)).long().to(device)
if array_dict['elements2tokens'].numel() > 0:
array_dict['slice_idxs'] = torch.from_numpy(np.array([
array_dict['elements2tokens'][1].min().item(),
array_dict['elements2tokens'][1].max().item() + 1
])).long().to(device)
return array_dict
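Because elements2tokens now holds flat integer indices (see add_element_for_information_passing above), it takes the 1-D branch here, and the removed slice_idxs bookkeeping has no remaining consumer. The surviving 2-D branch still transposes lists of index pairs into the (2, num_edges) layout torch_geometric expects; a small check of that behavior with made-up pairs:

import numpy as np
import torch

pairs = [[0, 1], [1, 2], [2, 3]]  # hypothetical [source, target] pairs
edge_index = torch.from_numpy(np.array(pairs).transpose(1, 0)).long()
print(edge_index.shape)  # torch.Size([2, 3])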
