Fix RoPE extension #846

Merged
14 commits merged on Jul 31, 2024
unsloth/models/gemma.py: 3 changes (2 additions, 1 deletion)
@@ -14,6 +14,7 @@

from .llama import *
from ._utils import __version__
+import math

try:
from transformers.models.gemma.modeling_gemma import (
@@ -256,7 +257,7 @@ def forward(self, x, position_ids=None, seq_len=None):
def extend_rope_embedding(self, x, seq_len):
if seq_len <= self.current_rope_size: return
# Iteratively grow by increments of 8192
-self.current_rope_size = int(round(seq_len / 8192)) * 8192
+self.current_rope_size = math.ceil(seq_len / 8192) * 8192
self._set_cos_sin_cache(self.current_rope_size, device = "cuda:0", dtype = x.dtype)
pass
pass
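
The same one-line change is applied to each rotary-embedding class touched below. For context, here is a minimal standalone comparison of the old and new rounding (the helper names are illustrative, not from the diff): rounding to the nearest multiple of 8192 can round down and leave the cos/sin cache shorter than the requested sequence length, whereas `math.ceil` always rounds up.

```python
import math

def old_rope_size(seq_len: int) -> int:
    # Previous behaviour: round to the NEAREST multiple of 8192,
    # which can round down below seq_len.
    return int(round(seq_len / 8192)) * 8192

def new_rope_size(seq_len: int) -> int:
    # Fixed behaviour: always round UP to the next multiple of 8192,
    # so the extended cache is never shorter than seq_len.
    return math.ceil(seq_len / 8192) * 8192

for seq_len in (8192, 9000, 12288, 16385):
    print(seq_len, old_rope_size(seq_len), new_rope_size(seq_len))
# 8192  -> old 8192,             new 8192
# 9000  -> old 8192 (too short), new 16384
# 12288 -> old 16384,            new 16384
# 16385 -> old 16384 (too short), new 24576
```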
unsloth/models/llama.py: 7 changes (4 additions, 3 deletions)
@@ -14,6 +14,7 @@

import torch
import gc
+import math
from typing import Optional, Tuple, List, Union
from ._utils import *
from ._utils import __version__
@@ -1036,7 +1037,7 @@ def forward(self, x, position_ids=None, seq_len=None):
def extend_rope_embedding(self, x, seq_len):
if seq_len <= self.current_rope_size: return
# Iteratively grow by increments of 8192
-self.current_rope_size = int(round(seq_len / 8192)) * 8192
+self.current_rope_size = math.ceil(seq_len / 8192) * 8192
self._set_cos_sin_cache(self.current_rope_size, device = "cuda:0", dtype = x.dtype)
pass
pass
@@ -1109,7 +1110,7 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
# in FP32. They are applied (multiplied) in FP32 as well.
self.current_rope_size = seq_len

t = torch.arange(self.current_rope_size, device="cpu", dtype=torch.int64).float()
t = torch.arange(self.current_rope_size, device=self.inv_freq.device, dtype=torch.int64).float()

freqs = torch.outer(t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
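
This hunk creates the position indices on the same device as `inv_freq` rather than always on the CPU, so `torch.outer` never has to combine tensors from different devices once the embedding module lives on the GPU. A minimal sketch of the cache-building step after the change (a free function for illustration, not the library's class):

```python
import torch

def build_cos_sin_cache(inv_freq: torch.Tensor, seq_len: int, dtype: torch.dtype):
    # Positions are created on inv_freq's device, and the frequency product
    # is kept in FP32, matching the comment in the diff above.
    t = torch.arange(seq_len, device=inv_freq.device, dtype=torch.int64).float()
    freqs = torch.outer(t, inv_freq)           # (seq_len, dim // 2)
    emb = torch.cat((freqs, freqs), dim=-1)    # duplicate halves, same permutation trick
    return emb.cos().to(dtype), emb.sin().to(dtype)
```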
@@ -1158,7 +1159,7 @@ def forward(self, x, position_ids=None, seq_len=None):
def extend_rope_embedding(self, x, seq_len):
if seq_len <= self.current_rope_size: return
# Iteratively grow by increments of 8192
-self.current_rope_size = int(round(seq_len / 8192)) * 8192
+self.current_rope_size = math.ceil(seq_len / 8192) * 8192
self._set_cos_sin_cache(self.current_rope_size, device = "cuda:0", dtype = x.dtype)
pass
pass