Gemma2 (#709)
* Update mapper.py

* Update loader.py

* Update llama.py

* Update tokenizer_utils.py

* info

* edits

* Create chat template

* Fix tokenizer

* Update tokenizer_utils.py

* fix case where gguf saving fails due to first_conversion dtype (#630)

* Support revision parameter in FastLanguageModel.from_pretrained (#629)

* support `revision` parameter

* match unsloth formatting of named parameters

* clears any selected_adapters before calling internal_model.save_pretrained (#609)

* Update __init__.py (#602)

Check for incompatible modules before importing unsloth

* Fixed unsloth/tokenizer_utils.py for chat training (#604)

* Add GGML saving option to Unsloth for easier Ollama model creation and testing. (#345)

* Add save to llama.cpp GGML to save.py.

* Fix conversion command and path of convert to GGML function.

* Add autosaving lora to the GGML function

* Create lora save function for conversion to GGML

* Test fix #2 for saving lora

* Test fix #3 to save the lora adapters to convert to GGML

* Removed unwanted tokenizer saving for conversion to GGML and added a few print statements.

* Needed the tokenizer for saving, so added it back; also made it more Unsloth-style by using positional arguments, and added a few messages.

* Positional arguments didn't work out, so reverted to an older version of the code and added a few comments.

* Test fix 1 for arch

* Test fix 2 new Mistral error.

* Test fix 3

* Revert to old version for testing.

* Upload issue test fix 1

* Fix 2 uploading ggml

* Positional args added.

* Temporarily remove positional args

* Fix upload again!!!

* Add print statements and fix link

* Make the calling name better

* Create local saving for GGML

* Add choosing directory to save local GGML.

* Fix lil variable error in the save_to_custom_dir func

* docs: Add LoraConfig parameters documentation (#619)

* llama.cpp failing (#371)

llama.cpp is failing to generate quantized versions of the trained models.

Error:

```bash
You might have to compile llama.cpp yourself, then run this again.
You do not need to close this Python program. Run the following commands in a new terminal:
You must run this in the same folder as you're saving your model.
git clone https://github.com/ggerganov/llama.cpp
cd llama.cpp && make clean && LLAMA_CUDA=1 make all -j
Once that's done, redo the quantization.
```

But when I clone this recursively, it works.

Co-authored-by: Daniel Han <[email protected]>

* fix libcuda_dirs import for triton 3.0 (#227)

* fix libcuda_dirs import for triton 3.0

* Update __init__.py

* Update __init__.py

---------

Co-authored-by: Daniel Han <[email protected]>

* Update save.py

* Update __init__.py

* Update fast_lora.py

* Update save.py

* Update save.py

* Update save.py

* Update loader.py

* Update save.py

* Update save.py

* quantize now llama-quantize

* Update chat_templates.py

* Update loader.py

* Update mapper.py

* Update __init__.py

* embedding size

* Update qwen2.py

* docs

* Update README.md

* Update qwen2.py

* README: Fix minor typo. (#559)

* README: Fix minor typo.

One-character typo fix while reading.

* Update README.md

---------

Co-authored-by: Daniel Han <[email protected]>

* Update mistral.py

* Update qwen2.py

* Update qwen2.py

* Update qwen2.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update README.md

* FastMistralModel

* Update mistral.py

* Update mistral.py

* Update mistral.py

* Update mistral.py

* Update mistral.py

* Auto check rope scaling

* Update llama.py

* Update llama.py

* Update llama.py

* GPU support

* Typo

* Update gemma.py

* gpu

* Multiple GGUF saving

* Update save.py

* Update save.py

* check PEFT and base

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update chat_templates.py

* Fix breaking bug in save.py with interpreting quantization_method as a string when saving to gguf (#651)

* Nightly (#649)

* Update llama.py

* offload

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* continued pretraining trainer

* Update trainer.py

* Update trainer.py

* Update trainer.py

* Update trainer.py

* is_bfloat16_supported

* Update __init__.py

* Update README.md

* Update llama.py

* is_bfloat16_supported

* Update __init__.py

* Mistral v3

* Phi 3 medium

* Update chat_templates.py

* Update chat_templates.py

* Phi-3

* Update save.py

* Update README.md

Mistral v3 to Mistral v0.3

* Untrained tokens

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update llama.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update save.py

* Update save.py

* Update save.py

* checkpoint

* Update _utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update llama.py

* accelerate

* Update _utils.py

* Update _utils.py

* Update _utils.py

* Update _utils.py

* Update _utils.py

* Update _utils.py

* Update _utils.py

* Update tokenizer_utils.py

* train_dataloader

* Update llama.py

* Update llama.py

* Update llama.py

* use_fast_convert

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* remove_special_tokens

* Ollama

* Update chat_templates.py

* Update chat_templates.py

* Update chat_templates.py

* Update llama.py

* Update chat_templates.py

* Support bfloat16 GGUF

* Update save.py

* Update llama.py

* fast_forward_inference

---------

Co-authored-by: Michael Han <[email protected]>
Co-authored-by: Eliot Hall <[email protected]>
Co-authored-by: Rickard Edén <[email protected]>
Co-authored-by: XiaoYang <[email protected]>
Co-authored-by: Oseltamivir <[email protected]>
Co-authored-by: mahiatlinux <[email protected]>
Co-authored-by: Sébastien De Greef <[email protected]>
Co-authored-by: Alberto Ferrer <[email protected]>
Co-authored-by: Thomas Viehmann <[email protected]>
Co-authored-by: Walter Korman <[email protected]>

* Fix bug in save.py with interpreting quantization_method as a string that prevents GGUF from saving

* Implemented better list management but forgot to actually use the new list variable; fixed

* Check type of given quantization method and return type error if not list or string

* Update save.py

---------

Co-authored-by: Daniel Han <[email protected]>
Co-authored-by: Michael Han <[email protected]>
Co-authored-by: Eliot Hall <[email protected]>
Co-authored-by: Rickard Edén <[email protected]>
Co-authored-by: XiaoYang <[email protected]>
Co-authored-by: Oseltamivir <[email protected]>
Co-authored-by: mahiatlinux <[email protected]>
Co-authored-by: Sébastien De Greef <[email protected]>
Co-authored-by: Alberto Ferrer <[email protected]>
Co-authored-by: Thomas Viehmann <[email protected]>
Co-authored-by: Walter Korman <[email protected]>

* Revert "Fix breaking bug in save.py with interpreting quantization_method as …" (#652)

This reverts commit 30605de.

* Revert "Revert "Fix breaking bug in save.py with interpreting quantization_me…" (#653)

This reverts commit e2b2083.

* Update llama.py

* peft

* patch

* Update loader.py

* retrain

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* offload

* Update llama.py

* Create a starter script for command-line training to integrate in ML ops pipelines. (#623)

* Update chat_templates.py

* Ollama

* Update chat_templates.py

* Update chat_templates.py

* Update chat_templates.py

* Update chat_templates.py

* Update chat_templates.py

* Update chat_templates.py

* Update chat_templates.py

* Update chat_templates.py

* Update chat_templates.py

* Update chat_templates.py

* Ollama

* Update chat_templates.py

* ollama

* Update mapper.py

* Update chat_templates.py

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* Update chat_templates.py

* Update chat_templates.py

* Update chat_templates.py

* Update chat_templates.py

* Update llama.py

* Fixes

* clearer messages

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update llama.py

* Update llama.py

* Update llama.py

* log

* Update __init__.py

* Update llama.py

* Update __init__.py

* Create Merge.png

* Create ollama.png

* Gemma2

* Update llama.py

* Update loader.py

* Update pyproject.toml

* Update pyproject.toml

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update _utils.py

* Revert Gemma2

* Update gemma2.py

* Update gemma2.py

* Update gemma2.py

* Update gemma2.py

* Update gemma2.py

* Update gemma2.py

* Update gemma2.py

* Update gemma2.py

* Update rms_layernorm.py

* Update gemma2.py

* logit softcapping

* Update cross_entropy_loss.py

* Update llama.py

* Update llama.py

* Update gemma2.py

* Update gemma2.py

* Update cross_entropy_loss.py

* Update llama.py

* Update llama.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update llama.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update gemma2.py

* Update gemma2.py

* Update gemma2.py

* Update gemma2.py

* Update gemma2.py

* Update gemma2.py

* Update gemma2.py

* Update gemma2.py

* Update gemma2.py

* Update gemma2.py

* Update llama.py

* Update gemma2.py

* Update llama.py

* Update llama.py

* Update gemma2.py

* Update gemma2.py

* Update llama.py

* Update gemma2.py

* Update gemma2.py

* Update gemma2.py

* Update gemma2.py

* Update gemma2.py

* Update gemma2.py

* Update gemma2.py

* Update gemma2.py

* Update gemma2.py

* Update gemma2.py

* Update gemma2.py

* Update gemma2.py

* Update gemma2.py

* Update gemma2.py

* Update gemma2.py

* Update gemma2.py

* Update gemma2.py

* Update gemma2.py

* Update _utils.py

* Update _utils.py

* Update gemma2.py

* compile flags

* Update _utils.py

* Update _utils.py

* Update _utils.py

* Update _utils.py

* Update _utils.py

* Update _utils.py

* Update _utils.py

* Update _utils.py

* Update _utils.py

* Update gemma2.py

* Update gemma2.py

* fixes

* Update _utils.py

* Fix generation

* Update llama.py

* Update llama.py

* Update _utils.py

* Update _utils.py

* Update _utils.py

* pad token

* Update gemma2.py

* pad token

* Update _utils.py

* Update llama.py

* Update gemma2.py

* edit warning

* Update tokenizer_utils.py

---------

Co-authored-by: Eliot Hall <[email protected]>
Co-authored-by: Rickard Edén <[email protected]>
Co-authored-by: XiaoYang <[email protected]>
Co-authored-by: Oseltamivir <[email protected]>
Co-authored-by: mahiatlinux <[email protected]>
Co-authored-by: Sébastien De Greef <[email protected]>
Co-authored-by: Alberto Ferrer <[email protected]>
Co-authored-by: Thomas Viehmann <[email protected]>
Co-authored-by: Walter Korman <[email protected]>
Co-authored-by: ArcadaLabs-Jason <[email protected]>
Co-authored-by: Michael Han <[email protected]>
12 people committed Jul 3, 2024
1 parent 933d9fe commit 499635a
Showing 17 changed files with 772 additions and 60 deletions.
Binary file added images/Merge.png
Binary file added images/ollama.png
6 changes: 3 additions & 3 deletions pyproject.toml
@@ -34,7 +34,7 @@ exclude = ["images*"]
[project.optional-dependencies]
huggingface = [
"tyro",
"transformers>=4.38.2",
"transformers>=4.42.3",
"datasets>=2.16.0",
"sentencepiece>=0.2.0",
"tqdm",
@@ -185,9 +185,9 @@ colab-ampere-torch220 = [
]
colab-new = [
"tyro",
"transformers>=4.38.2",
"transformers>=4.42.3",
"datasets>=2.16.0",
"sentencepiece",
"sentencepiece>=0.2.0",
"tqdm",
"psutil",
"wheel>=0.42.0",
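The dependency bump above (transformers 4.38.2 to 4.42.3) is what pulls in Gemma 2 support; the sentencepiece floor is also aligned to >=0.2.0. A quick, optional pre-flight check along these lines can catch an outdated install before loading the model (illustrative snippet, not part of the diff):

```python
# Hypothetical sanity check: Gemma 2 classes only ship in transformers >= 4.42.
from packaging.version import Version
import transformers

if Version(transformers.__version__) < Version("4.42.3"):
    raise RuntimeError(
        f"transformers {transformers.__version__} is too old for Gemma 2; "
        "upgrade to >= 4.42.3, matching this pyproject change."
    )
```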
102 changes: 73 additions & 29 deletions unsloth/kernels/cross_entropy_loss.py
@@ -19,14 +19,17 @@
from transformers.models.llama.modeling_llama import logger


@triton.heuristics({"DO_SOFTCAPPING": lambda args: args["DO_SOFTCAPPING"],})
@triton.jit
def _cross_entropy_forward(
logits_ptr, logits_row_stride,
loss_ptr,
logsumexp_ptr,
labels_ptr,
VOCAB_SIZE : tl.constexpr,
BLOCK_SIZE : tl.constexpr,
VOCAB_SIZE : tl.constexpr,
BLOCK_SIZE : tl.constexpr,
DO_SOFTCAPPING : tl.constexpr,
SOFTCAP : tl.constexpr,
):
"""
Cross Entropy Loss = 1/n sum [ -yi log(Pi) ]
@@ -58,29 +61,38 @@ def _cross_entropy_forward(
mask = col_offsets < VOCAB_SIZE

label_idx = tl.load(labels_ptr).to(tl.int32)
logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")).to(tl.float32)
logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf"))
# Do logit softcapping for Gemma 2: t * tanh(1/t * x)
if DO_SOFTCAPPING: logits = SOFTCAP * tl.math.tanh(logits / SOFTCAP)

logits = logits.to(tl.float32)
c = tl.max(logits, 0)
logsumexp = c + tl.log(tl.sum(tl.exp(logits - c), 0))

if label_idx != -100:
x = tl.load(logits_ptr + label_idx).to(tl.float32)
loss = logsumexp - x
x = tl.load(logits_ptr + label_idx)
# Do logit softcapping for Gemma 2: t * tanh(1/t * x)
if DO_SOFTCAPPING: x = SOFTCAP * tl.math.tanh(x / SOFTCAP)
loss = logsumexp - x.to(tl.float32)
else:
loss = 0.0
tl.store(logsumexp_ptr, logsumexp)
tl.store(loss_ptr, loss)
pass


@triton.heuristics({"DO_SOFTCAPPING": lambda args: args["DO_SOFTCAPPING"],})
@triton.jit
def _chunked_cross_entropy_forward(
logits_ptr, logits_row_stride,
loss_ptr,
logsumexp_ptr,
labels_ptr,
VOCAB_SIZE : tl.constexpr,
N_CHUNKS : tl.constexpr,
BLOCK_SIZE : tl.constexpr,
VOCAB_SIZE : tl.constexpr,
N_CHUNKS : tl.constexpr,
BLOCK_SIZE : tl.constexpr,
DO_SOFTCAPPING : tl.constexpr,
SOFTCAP : tl.constexpr,
):
"""
256K vocab divided in 4 chunks
@@ -117,7 +129,11 @@ def _chunked_cross_entropy_forward(
mask = col_offsets < VOCAB_SIZE

label_idx = tl.load(labels_ptr).to(tl.int32)
logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")).to(tl.float32)
logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf"))
# Do logit softcapping for Gemma 2: t * tanh(1/t * x)
if DO_SOFTCAPPING: logits = SOFTCAP * tl.math.tanh(logits / SOFTCAP)

logits = logits.to(tl.float32)
c = tl.max(logits, 0)
logsumexp = c + tl.log(tl.sum(tl.exp(logits - c), 0))

@@ -126,7 +142,9 @@
# Do the -x separately
if label_idx != -100:
x = tl.load(logits_ptr + label_idx).to(tl.float32)
loss = -1.0 * x
# Do logit softcapping for Gemma 2: t * tanh(1/t * x)
if DO_SOFTCAPPING: x = SOFTCAP * tl.math.tanh(x / SOFTCAP)
loss = -1.0 * x.to(tl.float32)
else:
loss = 0.0
tl.store(loss_ptr, loss)
@@ -135,14 +153,17 @@
pass


@triton.heuristics({"DO_SOFTCAPPING": lambda args: args["DO_SOFTCAPPING"],})
@triton.jit
def _cross_entropy_backward(
logits_ptr, logits_row_stride,
dloss_ptr, dloss_row_stride,
logsumexp_ptr,
labels_ptr,
VOCAB_SIZE : tl.constexpr,
BLOCK_SIZE : tl.constexpr,
VOCAB_SIZE : tl.constexpr,
BLOCK_SIZE : tl.constexpr,
DO_SOFTCAPPING : tl.constexpr,
SOFTCAP : tl.constexpr,
):
"""
CE_i = -y log(P) = y * (log[sum(exp(x))] - x)
@@ -173,15 +194,27 @@ def _cross_entropy_backward(
else:
dloss = 0.0

x = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")).to(tl.float32)
x = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf"))
# Do logit softcapping for Gemma 2: t * tanh(1/t * x)
if DO_SOFTCAPPING:
# d/dx [t * tanh(1/t * x)] = 1 - tanh^2(1/t * x)
partial = tl.math.tanh(x / SOFTCAP)
x = SOFTCAP * partial
pass

logsumexp = tl.load(logsumexp_ptr + row_idx)
y = tl.exp(x - logsumexp)
y = tl.exp(x.to(tl.float32) - logsumexp)
y = tl.where(
col_offsets == label_idx,
y - 1.0, # exp(x - logsumexp) - 1
y, # exp(x - logsumexp)
)

if DO_SOFTCAPPING:
# d/dx [t * tanh(1/t * x)] = 1 - tanh^2(1/t * x)
y = y * (1.0 - partial*partial)
pass

# If y == 0: dC/dx = 0 ==> we already masked it to be = 0, so dloss = 0.
tl.store(logits_ptr + col_offsets, dloss * y, mask = mask)
pass
@@ -191,40 +224,46 @@

class Fast_CrossEntropyLoss(torch.autograd.Function):
@staticmethod
def forward(ctx, logits, labels):
def forward(ctx, logits, labels, logit_softcapping = 0):
n_rows, vocab_size = logits.shape

div, mod = divmod(vocab_size, MAX_FUSED_SIZE)
n_chunks = div + (mod != 0)
losses = torch.empty(n_rows, dtype = torch.float32, device = "cuda")
losses = torch.empty(n_rows, dtype = torch.float32, device = "cuda:0")

DO_SOFTCAPPING = (logit_softcapping != 0)

if n_chunks == 1:
# For small vocabs <= 65336 like Llama, Mistral
BLOCK_SIZE, num_warps = calculate_settings(vocab_size)
logsumexp = torch.empty(n_rows, dtype = torch.float32, device = "cuda")
logsumexp = torch.empty(n_rows, dtype = torch.float32, device = "cuda:0")

_cross_entropy_forward[(n_rows,)](
logits, logits.stride(0),
losses,
logsumexp,
labels,
VOCAB_SIZE = vocab_size,
BLOCK_SIZE = BLOCK_SIZE,
num_warps = num_warps,
VOCAB_SIZE = vocab_size,
BLOCK_SIZE = BLOCK_SIZE,
DO_SOFTCAPPING = DO_SOFTCAPPING,
SOFTCAP = logit_softcapping,
num_warps = num_warps,
)
else:
# For large vocabs > 65336 like Gemma 256K
logsumexp = torch.empty((n_rows, n_chunks,), dtype = torch.float32, device = "cuda")
logsumexp = torch.empty((n_rows, n_chunks,), dtype = torch.float32, device = "cuda:0")

_chunked_cross_entropy_forward[(n_rows, n_chunks,)](
logits, logits.stride(0),
losses,
logsumexp,
labels,
VOCAB_SIZE = vocab_size,
N_CHUNKS = n_chunks,
BLOCK_SIZE = MAX_FUSED_SIZE,
num_warps = 32,
VOCAB_SIZE = vocab_size,
N_CHUNKS = n_chunks,
BLOCK_SIZE = MAX_FUSED_SIZE,
DO_SOFTCAPPING = DO_SOFTCAPPING,
SOFTCAP = logit_softcapping,
num_warps = 32,
)
# logsumexp(chunked_logsumexp) - x
# Do the -x separately
@@ -234,6 +273,8 @@ def forward(ctx, logits, labels):
pass

ctx.save_for_backward(logits, logsumexp, labels)
ctx.DO_SOFTCAPPING = DO_SOFTCAPPING
ctx.logit_softcapping = logit_softcapping
return losses
pass

@@ -251,16 +292,18 @@ def backward(ctx, dlosses):
dlosses, dlosses.stride(0),
logsumexp,
labels,
VOCAB_SIZE = vocab_size,
BLOCK_SIZE = BLOCK_SIZE,
num_warps = 8,
VOCAB_SIZE = vocab_size,
BLOCK_SIZE = BLOCK_SIZE,
DO_SOFTCAPPING = ctx.DO_SOFTCAPPING,
SOFTCAP = ctx.logit_softcapping,
num_warps = 8,
)
return logits, None, None,
pass
pass


def fast_cross_entropy_loss(logits, labels):
def fast_cross_entropy_loss(logits, labels, logit_softcapping = 0):
"""
Arguments:
logits: (batch, seq_len, vocab_size)
@@ -274,6 +317,7 @@ def fast_cross_entropy_loss(logits, labels):
loss = Fast_CrossEntropyLoss.apply(
logits.view(batch*seq_len, d),
labels.view(-1),
logit_softcapping,
)
n_items = torch.count_nonzero(labels != -100)
return loss.sum() / n_items
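The new DO_SOFTCAPPING/SOFTCAP path above squashes each logit with t * tanh(x / t) before the log-sum-exp, and the backward kernel scales the gradient by the matching 1 - tanh^2(x / t) factor. As a rough correctness reference for the fused Triton kernel, the same loss can be written in a few lines of plain PyTorch; the helper name and test values below are illustrative, not part of the repo:

```python
import torch
import torch.nn.functional as F

def softcapped_cross_entropy(logits, labels, softcap=30.0):
    # Gemma 2 style final-logit softcap: t * tanh(x / t), computed in float32.
    capped = softcap * torch.tanh(logits.float() / softcap)
    # Standard CE with label -100 ignored, matching the kernel's label_idx != -100 path;
    # autograd supplies the 1 - tanh^2(x / t) factor that the backward kernel hard-codes.
    return F.cross_entropy(capped, labels, ignore_index=-100)

# Illustrative check on random data (small vocab so it runs anywhere).
logits = torch.randn(4, 1024, requires_grad=True)
labels = torch.tensor([3, -100, 17, 250])
softcapped_cross_entropy(logits, labels).backward()
```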
4 changes: 2 additions & 2 deletions unsloth/kernels/geglu.py
@@ -41,7 +41,7 @@ def _exact_forward_kernel(e, g, h, n_elements, BLOCK_SIZE : tl.constexpr,):
def geglu_exact_forward_kernel(gate, up):
batch, seq_len, hd = gate.shape
n_elements = gate.numel()
out = torch.empty((batch, seq_len, hd), dtype = gate.dtype, device = "cuda")
out = torch.empty((batch, seq_len, hd), dtype = gate.dtype, device = "cuda:0")
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
_exact_forward_kernel[grid](gate, up, out, n_elements, BLOCK_SIZE = 1024,)
return out
@@ -133,7 +133,7 @@ def _approx_forward_kernel(e, g, h, n_elements, BLOCK_SIZE : tl.constexpr,):
def geglu_approx_forward_kernel(gate, up):
batch, seq_len, hd = gate.shape
n_elements = gate.numel()
out = torch.empty((batch, seq_len, hd), dtype = gate.dtype, device = "cuda")
out = torch.empty((batch, seq_len, hd), dtype = gate.dtype, device = "cuda:0")
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
_approx_forward_kernel[grid](gate, up, out, n_elements, BLOCK_SIZE = 1024,)
return out
6 changes: 3 additions & 3 deletions unsloth/kernels/rms_layernorm.py
@@ -119,7 +119,7 @@ def _gemma_rms_layernorm_forward(
W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32)

row_var = tl.sum(X_row * X_row, axis = 0) / n_cols
inv_var = 1.0 / tl.sqrt(row_var + eps) # Must be 1/sqrt to match Deepmind's impl
inv_var = tl.math.rsqrt(row_var + eps)
tl.store(r, inv_var)
normed = X_row * inv_var
output = normed * (W_row + 1.0)
@@ -137,8 +137,8 @@ def forward(ctx, X, W, eps, gemma = False):
n_rows, n_cols = X.shape
BLOCK_SIZE, num_warps = calculate_settings(n_cols)

Y = torch.empty((n_rows, n_cols), dtype = X.dtype, device = "cuda")
r = torch.empty(n_rows, dtype = torch.float32, device = "cuda")
Y = torch.empty((n_rows, n_cols), dtype = X.dtype, device = "cuda:0")
r = torch.empty(n_rows, dtype = torch.float32, device = "cuda:0")

fx = _gemma_rms_layernorm_forward if gemma else _rms_layernorm_forward
fx[(n_rows,)](
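For context, the Gemma flavour of RMSNorm that this kernel implements differs from the Llama one: statistics are taken in float32, the inverse variance now uses tl.math.rsqrt, and the learned weight is applied as (W + 1). A minimal PyTorch sketch of the same math, for reference only:

```python
import torch

def gemma_rms_layernorm_ref(X, W, eps=1e-6):
    # Compute in float32, mirroring the kernel; rsqrt matches the updated line above.
    X32 = X.float()
    inv_var = torch.rsqrt(X32.pow(2).mean(dim=-1, keepdim=True) + eps)
    # Gemma scales by (weight + 1) rather than weight.
    return (X32 * inv_var * (W.float() + 1.0)).to(X.dtype)
```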
2 changes: 1 addition & 1 deletion unsloth/kernels/swiglu.py
@@ -41,7 +41,7 @@ def _fg_kernel(e, g, h, n_elements, BLOCK_SIZE : tl.constexpr,):
def swiglu_fg_kernel(e, g):
batch, seq_len, hd = e.shape
n_elements = e.numel()
h = torch.empty((batch, seq_len, hd), dtype = e.dtype, device = "cuda")
h = torch.empty((batch, seq_len, hd), dtype = e.dtype, device = "cuda:0")
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
_fg_kernel[grid](e, g, h, n_elements, BLOCK_SIZE = 1024,)
return h
8 changes: 4 additions & 4 deletions unsloth/kernels/utils.py
@@ -105,14 +105,14 @@ def fast_dequantize(W, quant_state = None, out = None):

# Create weight matrix
if out is None:
out = torch.empty(shape, dtype = dtype, device = "cuda")
out = torch.empty(shape, dtype = dtype, device = "cuda:0")
else:
assert(out.shape == shape)
assert(out.dtype == dtype)

# NF4 dequantization of statistics
n_elements_absmax = absmax.numel()
out_absmax = torch.empty(n_elements_absmax, dtype = torch.float32, device = "cuda")
out_absmax = torch.empty(n_elements_absmax, dtype = torch.float32, device = "cuda:0")

# Do dequantization
ptr_out_absmax = get_ptr(out_absmax)
@@ -161,7 +161,7 @@ def fast_gemv(X, W, quant_state, out = None):
bout = shape[0]

if out is None:
out = torch.empty((1, 1, bout,), dtype = dtype, device = "cuda")
out = torch.empty((1, 1, bout,), dtype = dtype, device = "cuda:0")
# else:
# assert(out.shape == (1, 1, bout,))
# pass
@@ -179,7 +179,7 @@
ldb = ctypes.c_int32(ldb)
ldc = ctypes.c_int32(ldc)

df = torch.empty(absmax.shape, dtype = torch.float32, device = "cuda")
df = torch.empty(absmax.shape, dtype = torch.float32, device = "cuda:0")
cdequantize_blockwise_fp32(
get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), get_ptr(df),
ctypes.c_int(blocksize2), ctypes.c_int(df.numel()),
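A smaller change repeated across these kernel files is pinning temporary buffers to device = "cuda:0" instead of device = "cuda". The distinction only matters on multi-GPU machines, where "cuda" follows whatever device was last selected; a short illustration, assuming at least one visible GPU:

```python
import torch

# "cuda" resolves to the current device set via torch.cuda.set_device(...),
# while "cuda:0" always means the first GPU.
if torch.cuda.is_available():
    a = torch.empty(8, device="cuda")    # follows the current CUDA device
    b = torch.empty(8, device="cuda:0")  # always GPU 0
    print(a.device, b.device)
```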
