Add rope dynamic linear scaling #7437

Merged · 15 commits · Sep 18, 2023
Changes from 11 commits
examples/nlp/language_modeling/megatron_gpt_continue_training.py (4 changes: 3 additions & 1 deletion; mode 100644 → 100755)
@@ -61,7 +61,9 @@ def _modify_config(gpt_cfg, cfg, add_cfg_to_tree=False):
        gpt_cfg.max_position_embeddings = cfg.model.max_position_embeddings
        gpt_cfg.seq_len_interpolation_factor = cfg.model.seq_len_interpolation_factor
        gpt_cfg.use_flash_attention = cfg.model.use_flash_attention
-
+       assert (
+           gpt_cfg.encoder_seq_length == gpt_cfg.max_position_embeddings * gpt_cfg.seq_len_interpolation_factor
+       ), 'seq_length should be equal to max_position_embedding * seq_len_interpolation_factor'
        # This is needed when modifying a hparam file directly to load `.ckpt` files.
        # This is not needed to modify the cfg in `.nemo` files.
        if add_cfg_to_tree:
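For context, the new assertion ties three config values together: the continued-training sequence length must equal the pretrained context length times the RoPE interpolation factor. A minimal sketch of that relationship, with made-up values (2048 and 4 are illustrative, not taken from this PR):

# Illustrative only: hypothetical values showing the relationship the assert enforces.
max_position_embeddings = 2048      # context length the base model was pretrained with
seq_len_interpolation_factor = 4    # linear RoPE scaling factor
encoder_seq_length = 8192           # sequence length used for continued training

assert encoder_seq_length == max_position_embeddings * seq_len_interpolation_factor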
@@ -554,7 +554,9 @@ def __init__(
            if rotary_percentage < 1:
                rotary_dim = int(rotary_dim * rotary_percentage)
            self.rotary_pos_emb = RotaryEmbedding(
-               rotary_dim, seq_len_interpolation_factor=seq_len_interpolation_factor
+               rotary_dim,
+               seq_len_interpolation_factor=seq_len_interpolation_factor,
+               pretrained_max_position_embeddings=max_position_embeddings,
            )

        elif position_embedding_type == 'alibi':
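As a rough usage sketch (not part of the PR), the call site above only threads the pretrained context length through to RotaryEmbedding. Assuming the class as updated in the next diff is importable, it could be constructed directly like this, with illustrative values (head dim 64, pretrained context 2048, factor 2):

# Hypothetical values; the thresholds follow from pretrained_max_position_embeddings * seq_len_interpolation_factor.
rope = RotaryEmbedding(
    64,
    seq_len_interpolation_factor=2,
    pretrained_max_position_embeddings=2048,
)
emb_4k = rope(4096)   # 4096 <= 2048 * 2, so fixed linear scaling applies
emb_8k = rope(8192)   # 8192 > 2048 * 2, so dynamic linear scaling applies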
@@ -25,7 +25,9 @@ class RotaryEmbedding(nn.Module):
    Implements Rotary Position Embedding from https://arxiv.org/abs/2104.09864.
    """

-   def __init__(self, dim: int, seq_len_interpolation_factor: int = None):
+   def __init__(
+       self, dim: int, seq_len_interpolation_factor: int = None, pretrained_max_position_embeddings: int = None
+   ):
"""
Args:

Expand All @@ -34,16 +36,24 @@ def __init__(self, dim: int, seq_len_interpolation_factor: int = None):
by this factor via the trick in https://arxiv.org/abs/2306.15595.
"""
super().__init__()
self.seq_len_interpolation_factor = seq_len_interpolation_factor
self.pretrained_max_position_embeddings = pretrained_max_position_embeddings
self.seq_len_interpolation_factor = 1 if seq_len_interpolation_factor is None else seq_len_interpolation_factor
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer('inv_freq', inv_freq)

def forward(self, max_seq_len, offset=0):
seq = torch.arange(max_seq_len, device=self.inv_freq.device) + offset
if self.seq_len_interpolation_factor is not None:
seq = seq.type_as(self.inv_freq)
seq *= 1 / self.seq_len_interpolation_factor
freqs = einsum('i , j -> i j', seq.type_as(self.inv_freq), self.inv_freq)
seq = seq.type_as(self.inv_freq)

if self.pretrained_max_position_embeddings is not None:
+           if max_seq_len > self.pretrained_max_position_embeddings * self.seq_len_interpolation_factor:
+               # dynamic linear scaling
+               seq *= 1 / (max_seq_len / self.pretrained_max_position_embeddings)
+           else:
+               # fixed linear scaling
+               seq *= 1 / self.seq_len_interpolation_factor
+
+       freqs = einsum('i , j -> i j', seq, self.inv_freq)
        # first part even vector components, second part odd vector components,
        # 2 * dim in dimension size
        emb = torch.cat((freqs, freqs), dim=-1)
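To make the two branches above concrete, here is a small standalone sketch (not part of the PR; the helper name position_scale and the numbers are made up) that reproduces only the position-scaling decision from the forward pass:

# Standalone illustration of the scaling logic in RotaryEmbedding.forward (hypothetical values).
def position_scale(max_seq_len, pretrained_max_position_embeddings, seq_len_interpolation_factor):
    """Return the factor that the position indices are multiplied by."""
    if max_seq_len > pretrained_max_position_embeddings * seq_len_interpolation_factor:
        # dynamic linear scaling: compression grows with the sequence length
        return 1 / (max_seq_len / pretrained_max_position_embeddings)
    # fixed linear scaling: always compress by the configured factor
    return 1 / seq_len_interpolation_factor

print(position_scale(3000, 2048, 2))   # 0.5  -> fixed scaling, 3000 <= 2048 * 2
print(position_scale(8192, 2048, 2))   # 0.25 -> dynamic scaling beyond 4096

With the dynamic branch, the scaled positions never exceed the pretrained range: the largest index max_seq_len - 1 is multiplied by pretrained_max_position_embeddings / max_seq_len, which keeps it just under pretrained_max_position_embeddings.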