Add rope dynamic linear scaling (NVIDIA#7437)
* Add dynamic linear scaling

Signed-off-by: Cheng-Ping Hsieh <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix bug

Signed-off-by: Cheng-Ping Hsieh <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix

Signed-off-by: Cheng-Ping Hsieh <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix

Signed-off-by: Cheng-Ping Hsieh <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix

Signed-off-by: Cheng-Ping Hsieh <[email protected]>

---------

Signed-off-by: Cheng-Ping Hsieh <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Yang Zhang <[email protected]>
Signed-off-by: Sasha Meister <[email protected]>
3 people authored and ssh-meister committed Oct 5, 2023
1 parent 14ba7f8 commit 32dc1d0
Showing 3 changed files with 22 additions and 7 deletions.
4 changes: 3 additions & 1 deletion examples/nlp/language_modeling/megatron_gpt_continue_training.py
100644 → 100755
@@ -61,7 +61,9 @@ def _modify_config(gpt_cfg, cfg, add_cfg_to_tree=False):
        gpt_cfg.max_position_embeddings = cfg.model.max_position_embeddings
        gpt_cfg.seq_len_interpolation_factor = cfg.model.seq_len_interpolation_factor
        gpt_cfg.use_flash_attention = cfg.model.use_flash_attention

        assert (
            gpt_cfg.encoder_seq_length == gpt_cfg.max_position_embeddings * gpt_cfg.seq_len_interpolation_factor
        ), 'seq_length should be equal to max_position_embedding * seq_len_interpolation_factor'
        # This is needed when modifying a hparam file directly to load `.ckpt` files.
        # This is not needed to modify the cfg in `.nemo` files.
        if add_cfg_to_tree:
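For reference, the assert above pins the extended training length to the pre-trained context window multiplied by the interpolation factor. The standalone sketch below restates the same check with purely illustrative numbers (they are not taken from any shipped NeMo config):

# Illustrative values only; the assert mirrors the one added in _modify_config above.
encoder_seq_length = 16384           # extended sequence length used for continued training
max_position_embeddings = 4096       # context window the base model was pre-trained with
seq_len_interpolation_factor = 4     # linear position-interpolation factor

assert encoder_seq_length == max_position_embeddings * seq_len_interpolation_factor, (
    'seq_length should be equal to max_position_embedding * seq_len_interpolation_factor'
)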
@@ -554,7 +554,9 @@ def __init__(
             if rotary_percentage < 1:
                 rotary_dim = int(rotary_dim * rotary_percentage)
             self.rotary_pos_emb = RotaryEmbedding(
-                rotary_dim, seq_len_interpolation_factor=seq_len_interpolation_factor
+                rotary_dim,
+                seq_len_interpolation_factor=seq_len_interpolation_factor,
+                pretrained_max_position_embeddings=max_position_embeddings,
             )

         elif position_embedding_type == 'alibi':
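As a usage sketch, the call above now forwards the pre-trained context window into the rotary embedding alongside the interpolation factor. The construction below is hypothetical (head dimension, percentage, and window size are invented for illustration):

head_dim = 128
rotary_percentage = 0.5
rotary_dim = int(head_dim * rotary_percentage)  # rotary applied to 64 of the 128 dims

rope = RotaryEmbedding(
    rotary_dim,
    seq_len_interpolation_factor=2,
    pretrained_max_position_embeddings=4096,
)
# rope(max_seq_len) then yields rotary frequencies for max_seq_len positions.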
@@ -25,25 +25,36 @@ class RotaryEmbedding(nn.Module):
     Implements Rotary Position Embedding from https://arxiv.org/abs/2104.09864.
     """

-    def __init__(self, dim: int, seq_len_interpolation_factor: int = None):
+    def __init__(
+        self, dim: int, seq_len_interpolation_factor: int = None, pretrained_max_position_embeddings: int = None
+    ):
         """
         Args:
             dim (int): rotary embedding dimension
             seq_len_interpolation_factor (int): if not None, discrete positions will be interpolated
                 by this factor via the trick in https://arxiv.org/abs/2306.15595.
+            pretrained_max_position_embeddings (int): pre-trained max_position_embeddings before position interpolation.
         """
         super().__init__()
         self.seq_len_interpolation_factor = seq_len_interpolation_factor
         inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
         self.register_buffer('inv_freq', inv_freq)
+        self.pretrained_max_position_embeddings = pretrained_max_position_embeddings

     def forward(self, max_seq_len, offset=0):
         seq = torch.arange(max_seq_len, device=self.inv_freq.device) + offset
-        if self.seq_len_interpolation_factor is not None:
-            seq = seq.type_as(self.inv_freq)
-            seq *= 1 / self.seq_len_interpolation_factor
-        freqs = einsum('i , j -> i j', seq.type_as(self.inv_freq), self.inv_freq)
+        seq = seq.type_as(self.inv_freq)
+
+        if self.pretrained_max_position_embeddings is not None and self.seq_len_interpolation_factor is not None:
+            if max_seq_len > self.pretrained_max_position_embeddings * self.seq_len_interpolation_factor:
+                # dynamic linear scaling (length > position we have learned)
+                seq *= 1 / (max_seq_len / self.pretrained_max_position_embeddings)
+            else:
+                # fixed linear scaling
+                seq *= 1 / self.seq_len_interpolation_factor
+
+        freqs = einsum('i , j -> i j', seq, self.inv_freq)
         # first part even vector components, second part odd vector components,
         # 2 * dim in dimension size
         emb = torch.cat((freqs, freqs), dim=-1)
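To make the new branch concrete: up to pretrained_max_position_embeddings * seq_len_interpolation_factor the positions are compressed by the fixed interpolation factor, and beyond that the compression grows with the sequence length, so the scaled positions never leave the pre-trained range. A minimal sketch of just that rule, with an illustrative helper name and numbers:

# Not NeMo code; a standalone restatement of the scaling rule added in forward().
def effective_scale(max_seq_len: int, pretrained_max: int, interpolation_factor: int) -> float:
    """Factor applied to the position indices before the frequencies are computed."""
    if max_seq_len > pretrained_max * interpolation_factor:
        # dynamic linear scaling: compress to fit the current sequence length
        return 1 / (max_seq_len / pretrained_max)
    # fixed linear scaling (position interpolation, https://arxiv.org/abs/2306.15595)
    return 1 / interpolation_factor

pretrained_max, factor = 4096, 2
for length in (4096, 8192, 16384):
    print(length, effective_scale(length, pretrained_max, factor))
# 4096 -> 0.5 and 8192 -> 0.5 (fixed scaling); 16384 -> 0.25 (dynamic scaling),
# so the largest scaled index stays below pretrained_max = 4096.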