Skip to content

Commit

Permalink
🚨 Llama: update rope scaling to match static cache changes (#29143)
Browse files Browse the repository at this point in the history
  • Loading branch information
gante authored and ArthurZucker committed Feb 21, 2024
1 parent 7a4bec6 commit 476957b
Show file tree
Hide file tree
Showing 7 changed files with 38 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def forward(self, x, seq_len=None):
)


# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->OpenLlama
# Copied from transformers.models.falcon.modeling_falcon.FalconLinearScalingRotaryEmbedding with Falcon->OpenLlama
class OpenLlamaLinearScalingRotaryEmbedding(OpenLlamaRotaryEmbedding):
"""OpenLlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""

Expand All @@ -120,7 +120,7 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)


# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->OpenLlama
# Copied from transformers.models.falcon.modeling_falcon.FalconDynamicNTKScalingRotaryEmbedding with Falcon->OpenLlama
class OpenLlamaDynamicNTKScalingRotaryEmbedding(OpenLlamaRotaryEmbedding):
"""OpenLlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""

Expand Down
6 changes: 4 additions & 2 deletions src/transformers/models/falcon/modeling_falcon.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,8 @@ def forward(self, x, seq_len=None):
)


# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->Falcon
# copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->Falcon
# TODO @joao no longer copied from LLama after static cache, fix me (copied -> Copied)
class FalconLinearScalingRotaryEmbedding(FalconRotaryEmbedding):
"""FalconRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""

Expand All @@ -187,7 +188,8 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)


# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Falcon
# copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Falcon
# TODO @joao no longer copied from LLama after static cache, fix me (copied -> Copied)
class FalconDynamicNTKScalingRotaryEmbedding(FalconRotaryEmbedding):
"""FalconRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""

Expand Down
59 changes: 26 additions & 33 deletions src/transformers/models/llama/modeling_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,6 @@ def forward(self, hidden_states):
class LlamaRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()

self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
Expand All @@ -118,6 +117,9 @@ def cos_cached(self):
return self._cos_cached

def forward(self, x, position_ids, seq_len=None):
if seq_len is not None:
logger.warning_once("The `seq_len` argument is deprecated and unused. It will be removed in v4.40.")

# x: [bs, num_attention_heads, seq_len, head_size]
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
Expand All @@ -138,16 +140,11 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, s
self.scaling_factor = scaling_factor
super().__init__(dim, max_position_embeddings, base, device)

def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
t = t / self.scaling_factor

freqs = torch.outer(t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
def forward(self, x, position_ids, seq_len=None):
# difference to the original RoPE: a scaling factor is aplied to the position ids
position_ids = position_ids.float() / self.scaling_factor
cos, sin = super().forward(x, position_ids, seq_len)
return cos, sin


class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
Expand All @@ -157,23 +154,20 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, s
self.scaling_factor = scaling_factor
super().__init__(dim, max_position_embeddings, base, device)

def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len

def forward(self, x, position_ids, seq_len=None):
# difference to the original RoPE: inv_freq is recomputed when the sequence length > original length
seq_len = torch.max(position_ids) + 1
if seq_len > self.max_position_embeddings:
base = self.base * (
(self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
) ** (self.dim / (self.dim - 2))
inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)

t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
inv_freq = 1.0 / (
base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(x.device) / self.dim)
)
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: this may break with compilation

freqs = torch.outer(t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
cos, sin = super().forward(x, position_ids, seq_len)
return cos, sin


def rotate_half(x):
Expand All @@ -183,17 +177,16 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`):
The position indices of the tokens corresponding to the query and key tensors. For example, this can be
used to pass offsetted position ids when working with a KV-cache.
position_ids (`torch.Tensor`, *optional*):
Deprecated and unused.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
Expand Down Expand Up @@ -360,8 +353,8 @@ def forward(
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

past_key_value = getattr(self, "past_key_value", past_key_value)
cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)
cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

if past_key_value is not None:
# sin and cos are specific to RoPE models; position_ids needed for the static cache
Expand Down Expand Up @@ -447,8 +440,8 @@ def forward(
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)
cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

past_key_value = getattr(self, "past_key_value", past_key_value)

Expand Down Expand Up @@ -645,8 +638,8 @@ def forward(
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)
cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

past_key_value = getattr(self, "past_key_value", past_key_value)

Expand Down
4 changes: 2 additions & 2 deletions src/transformers/models/persimmon/modeling_persimmon.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def forward(self, x, seq_len=None):
)


# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->Persimmon
# Copied from transformers.models.falcon.modeling_falcon.FalconLinearScalingRotaryEmbedding with Falcon->Persimmon
class PersimmonLinearScalingRotaryEmbedding(PersimmonRotaryEmbedding):
"""PersimmonRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""

Expand All @@ -97,7 +97,7 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)


# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Persimmon
# Copied from transformers.models.falcon.modeling_falcon.FalconDynamicNTKScalingRotaryEmbedding with Falcon->Persimmon
class PersimmonDynamicNTKScalingRotaryEmbedding(PersimmonRotaryEmbedding):
"""PersimmonRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""

Expand Down
4 changes: 2 additions & 2 deletions src/transformers/models/phi/modeling_phi.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def forward(self, x, seq_len=None):
)


# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->Phi
# Copied from transformers.models.falcon.modeling_falcon.FalconLinearScalingRotaryEmbedding with Falcon->Phi
class PhiLinearScalingRotaryEmbedding(PhiRotaryEmbedding):
"""PhiRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""

Expand All @@ -135,7 +135,7 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)


# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Phi
# Copied from transformers.models.falcon.modeling_falcon.FalconDynamicNTKScalingRotaryEmbedding with Falcon->Phi
class PhiDynamicNTKScalingRotaryEmbedding(PhiRotaryEmbedding):
"""PhiRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""

Expand Down
4 changes: 2 additions & 2 deletions src/transformers/models/stablelm/modeling_stablelm.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def forward(self, x, seq_len=None):
)


# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->StableLm
# Copied from transformers.models.falcon.modeling_falcon.FalconLinearScalingRotaryEmbedding with Falcon->StableLm
class StableLmLinearScalingRotaryEmbedding(StableLmRotaryEmbedding):
"""StableLmRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""

Expand All @@ -123,7 +123,7 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)


# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->StableLm
# Copied from transformers.models.falcon.modeling_falcon.FalconDynamicNTKScalingRotaryEmbedding with Falcon->StableLm
class StableLmDynamicNTKScalingRotaryEmbedding(StableLmRotaryEmbedding):
"""StableLmRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""

Expand Down
1 change: 0 additions & 1 deletion tests/models/llama/test_modeling_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,6 @@ def test_save_load_fast_init_from_base(self):
pass

@parameterized.expand([("linear",), ("dynamic",)])
@unittest.skip("TODO @gante fix this for Llama")
def test_model_rope_scaling(self, scaling_type):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
short_input = ids_tensor([1, 10], config.vocab_size)
Expand Down

0 comments on commit 476957b

Please sign in to comment.