`DynamicNTKScalingRotaryEmbedding` is originally designed to update `base` and `inv_freq` with the dynamic scaling ratio (`seq_len / self.max_position_embeddings`) for every input sequence. However, the current implementation updates `base` and `inv_freq` only when `seq_len > self.max_seq_len_cached`.
```python
import torch
from torch import nn


class LlamaRotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)

        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )
```
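For this base class the lazy cache is harmless: `inv_freq` never changes, so slicing a larger cached table gives exactly the cos/sin a fresh computation would. A small sketch to illustrate (shapes are illustrative, not from the issue):

```python
rope = LlamaRotaryEmbedding(dim=128, max_position_embeddings=2048)
x = torch.randn(1, 1, 4096, 128)  # [bs, num_attention_heads, seq_len, head_size]

cos_long, sin_long = rope(x, seq_len=4096)               # grows the cache to 4096
cos_short, sin_short = rope(x[:, :, :512], seq_len=512)  # served by slicing the larger cache
# Same result a fresh length-512 cache would give, because inv_freq is constant.
```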
```python
class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
    """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len

        if seq_len > self.max_position_embeddings:
            base = self.base * (
                (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
            ) ** (self.dim / (self.dim - 2))
            inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
            self.register_buffer("inv_freq", inv_freq, persistent=False)

        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)

        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
```
Thus, it needs to be fixed like the following:
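A minimal sketch of one possible fix (identifiers follow the snippets above; this illustrates the idea rather than a verbatim patch): rebuild the cache whenever the sequence length changes, and restore the unscaled `base` when the sequence fits within `max_position_embeddings`:

```python
class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
    """LlamaRotaryEmbedding with Dynamic NTK scaling, rebuilt for every new sequence length."""

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len

        if seq_len > self.max_position_embeddings:
            # NTK-aware rescaling for sequences longer than the training length.
            base = self.base * (
                (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
            ) ** (self.dim / (self.dim - 2))
        else:
            # Restore the unscaled base so shorter sequences do not inherit
            # a scaling computed for an earlier, longer sequence.
            base = self.base
        inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # Rebuild whenever seq_len differs from the cached length, not only when it
        # grows, so base/inv_freq always match the current sequence.
        if seq_len != self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )
```

Note the trade-off: rebuilding on every length change recomputes the cache at each step of incremental decoding, so a production implementation may want to memoize per-length results rather than recompute unconditionally.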
cc @ArthurZucker @younesbelkada
Expected behavior
When using DynamicNTK for Llama, `cos` and `sin` should be updated correctly for every input sequence length.
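A quick check of that expectation against the fix sketched above: after a long sequence, a subsequent short call should rebuild the cache and restore the unscaled frequencies.

```python
rope = LlamaDynamicNTKScalingRotaryEmbedding(dim=128, max_position_embeddings=2048)
inv_freq_unscaled = rope.inv_freq.clone()
x = torch.randn(1, 1, 4096, 128)

rope(x, seq_len=4096)               # rescaled for 4096
rope(x[:, :, :1024], seq_len=1024)  # rebuilt: unscaled again
assert torch.equal(rope.inv_freq, inv_freq_unscaled)
```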