rebase from HF
vahanhov committed Aug 11, 2023
1 parent 13affe7 commit 4eb3009
Showing 5 changed files with 135 additions and 287 deletions.
118 changes: 78 additions & 40 deletions src/transformers/models/bloom/modeling_bloom.py
@@ -16,7 +16,7 @@

import math
import warnings
from typing import List, Optional, Tuple, Union
from typing import Dict, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
@@ -35,9 +35,8 @@
from ...modeling_utils import PreTrainedModel
from ...utils import logging
from .configuration_bloom import BloomConfig
from .desequence_graph_ids import extract_edge_sequence, SequenceElement
from .permutation_invariant_positions import build_alibi_tensor
from .causal_message_passing import GatedGraphCrossAttentionLayer
from ..processing_graphs_within_model.desequence_graph_ids import extract_edge_sequence
from ..processing_graphs_within_model.causal_message_passing import GatedCausalMessagePassingLayer

logger = logging.get_logger(__name__)

@@ -85,6 +84,50 @@ def _expand_mask(mask: torch.Tensor, tgt_length: int) -> torch.BoolTensor:
return expanded_mask.expand(batch_size, 1, tgt_length, src_length)


def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
"""
Link to paper: https://arxiv.org/abs/2108.12409 Alibi tensor is not causal as the original paper mentions, it
relies on a translation invariance of softmax for quick implementation: with l being a tensor, and a fixed value
`softmax(l+a) = softmax(l)`. Based on
https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742
TODO @thomasw21 this doesn't work as nicely due to the masking strategy, and so masking varies slightly.
Args:
Returns tensor shaped (batch_size * num_heads, 1, max_seq_len)
attention_mask (`torch.Tensor`):
Token-wise attention mask, this should be of shape (batch_size, max_seq_len).
num_heads (`int`, *required*):
number of heads
dtype (`torch.dtype`, *optional*, default=`torch.bfloat16`):
dtype of the output tensor
"""
batch_size, seq_length = attention_mask.shape
closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
base = torch.tensor(
2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32
)
powers = torch.arange(1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32)
slopes = torch.pow(base, powers)
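# For example, with num_heads = 8 the closest power of 2 is 8, so base = 2 ** -(2 ** -(3 - 3)) = 2 ** -1 = 0.5
# and the per-head slopes are 0.5, 0.25, 0.125, ..., 1 / 256, the geometric sequence used in the ALiBi paper.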

if closest_power_of_2 != num_heads:
extra_base = torch.tensor(
2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32
)
num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=attention_mask.device, dtype=torch.int32)
slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)

# Note: alibi will be added to the attention bias that will be applied to the query-key product of attention
# => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length)
# => here we set (batch_size=1, num_heads=num_heads, query_length=1, key_length=max_length)
# => the query_length dimension will then be broadcasted correctly
# This is more or less identical to T5's relative position bias:
# https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527
arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :]
alibi = slopes[..., None] * arange_tensor
return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype)
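# A minimal shape check of this helper (the batch size, sequence length, and head count below are hypothetical):
#     mask = torch.ones(2, 5, dtype=torch.int64)                       # (batch_size=2, max_seq_len=5)
#     alibi = build_alibi_tensor(mask, num_heads=4, dtype=torch.float32)
#     assert alibi.shape == (2 * 4, 1, 5)                              # (batch_size * num_heads, 1, max_seq_len)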


def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
"""
Dropout add function
@@ -578,8 +621,6 @@ def __init__(self, config: BloomConfig):

self.embed_dim = config.hidden_size
self.num_heads = config.n_head
self.graph_tokens = {}
self.position_type = 'normal'

# Embedding + LN Embedding
self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim)
@@ -598,21 +639,11 @@ def __init__(self, config: BloomConfig):

def build_alibi_tensor(
self,
token_ids: torch.Tensor,
edge_sequences: List[List[Tuple[SequenceElement, Optional[SequenceElement], Optional[SequenceElement]]]],
attention_mask: torch.Tensor,
num_heads: int,
dtype: torch.dtype
) -> torch.Tensor:
return build_alibi_tensor(
token_ids=token_ids,
edge_sequences=edge_sequences,
attention_mask=attention_mask,
num_heads=num_heads,
dtype=dtype,
graph_tokens=self.graph_tokens,
position_type=self.position_type
)
return build_alibi_tensor(attention_mask=attention_mask, num_heads=num_heads, dtype=dtype)

def get_input_embeddings(self):
return self.word_embeddings
@@ -642,12 +673,25 @@ def _prepare_attn_mask(
def set_input_embeddings(self, new_embeddings: torch.Tensor):
self.word_embeddings = new_embeddings

def init_graph_information_passing(self, gnn_type: str, element_type: str):
def init_graph_information_passing(
self,
gnn_type: str,
element_type: str,
graph_token_ids: Dict[str, int]
):
""" Initializes a set of message passing layers to perform message passing of between
graph elements described in an input token id sequence
"""
assert element_type in ['nodes', 'edges'], 'unsupported message passing type'
self.message_passing_type = element_type
self.graph_token_ids = graph_token_ids
self.num_gnn_layers = (
self.config.num_layers - 1
if hasattr(self.config, 'num_layers') else self.config.n_layer - 1
)
self.graph_information_passing_layers = torch.nn.ModuleList([
GatedGraphCrossAttentionLayer(gnn_type, self.config.hidden_size)
for _ in range(self.config.n_layer - 1)
GatedCausalMessagePassingLayer(gnn_type, self.config.hidden_size)
for _ in range(self.num_gnn_layers)
])
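# Hypothetical usage, assuming `model` is a BloomForCausalLM wrapping this module
# (the gnn_type value and the special-token ids are placeholders):
#     model.transformer.init_graph_information_passing(
#         gnn_type="gat",
#         element_type="edges",
#         graph_token_ids={"<node>": 50257, "<edge>": 50258},
#     )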

@add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
@@ -659,7 +703,6 @@ def init_graph_information_passing(self, gnn_type: str, element_type: str):
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
full_input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.LongTensor] = None,
@@ -732,24 +775,19 @@ def forward(
else:
attention_mask = attention_mask.to(hidden_states.device)

token_ids: torch.Tensor = full_input_ids if full_input_ids is not None else input_ids
edge_sequences = [
extract_edge_sequence(t_ids.tolist(), self.graph_tokens) for t_ids in token_ids
]
alibi = self.build_alibi_tensor(
token_ids=token_ids,
edge_sequences=edge_sequences,
attention_mask=attention_mask,
num_heads=self.num_heads,
dtype=hidden_states.dtype
)
if hasattr(self, 'message_passing_type'):
if hasattr(self, "graph_tokens"):
assert input_ids.shape == attention_mask.shape
edge_sequences = [
extract_edge_sequence(t_ids.tolist(), self.graph_tokens) for t_ids in input_ids
]
if self.message_passing_type == 'nodes':
get_matrices = GatedGraphCrossAttentionLayer.build_node_information_passing
get_matrices = GatedCausalMessagePassingLayer.build_node_information_passing
else:
get_matrices = GatedGraphCrossAttentionLayer.build_edge_information_passing
get_matrices = GatedCausalMessagePassingLayer.build_edge_information_passing
message_passing_dicts = get_matrices(edge_sequences, self.device)
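# The edge sequences extracted above are converted into message passing matrices once per
# forward pass; the layer loop below hands them to the graph_information_passing_layers
# applied between transformer blocks.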

alibi = self.build_alibi_tensor(attention_mask, self.num_heads, hidden_states.dtype)

causal_mask = self._prepare_attn_mask(
attention_mask,
input_shape=(batch_size, seq_length),
@@ -788,7 +826,7 @@ def custom_forward(*inputs):
)

hidden_states = outputs[0]
if i != len(self.h) - 1 and hasattr(self, 'message_passing_type'):
if i <= self.num_gnn_layers and hasattr(self, 'graph_tokens'):
hidden_states = self.graph_information_passing_layers[i](
hidden_states,
message_passing_dicts
@@ -850,10 +888,10 @@ def prepare_inputs_for_generation(
**kwargs,
) -> dict:
return {
"input_ids": input_ids,
"past_key_values": None,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask
"input_ids": input_ids,
"past_key_values": None,
"use_cache": False,
"attention_mask": attention_mask
}
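# During generation, this hook is invoked once per decoding step and its output is unpacked
# into forward(); a minimal (hypothetical) illustration:
#     inputs = model.prepare_inputs_for_generation(input_ids, attention_mask=attention_mask)
#     outputs = model(**inputs)
# Because past_key_values is always returned as None, every step re-encodes the full sequence
# instead of reusing the key/value cache.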

@add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
