Skip to content

Commit

Permalink
gguf-py: Refactor and allow reading/modifying existing GGUF files (#3981
Browse files Browse the repository at this point in the history
)

* gguf-py: Refactor and add file reading support

* Replay changes from #3871

Credit to @cebtenzzre for that pull

* Various type annotation fixes.

* sort imports with isort (again)

* Fix missing return statement in add_tensor

* style cleanup with flake8

* fix NamedTuple and Enum usage

* Fix an issue with state init in GGUFReader

Move examples to an examples/ directory

Clean up examples

Add an example of modifying keys in a GGUF file

Update documentation with info on examples

Try to support people importing gguf/gguf.py directly

* Damagage is not a word.

* Clean up gguf-py/examples/modify_gguf.py whitespace

Co-authored-by: Jared Van Bortel <[email protected]>

* Update gguf-py/examples/modify_gguf.py formatting

Co-authored-by: Jared Van Bortel <[email protected]>

* Update gguf-py/gguf/gguf_reader.py type hint

Co-authored-by: Jared Van Bortel <[email protected]>

* Make examples executable, formatting changes

* Add more information to GGUFReader and examples comments

* Include a gguf Python package version bump

* Add convert-gguf-endian.py script

* cleanup

* gguf-py : bump minor version

* Reorganize scripts

* Make GGUFReader endian detection less arbitrary

* Add JSON dumping support to gguf-dump.py

Which I kind of regret now

* A few for gguf-dump.py cleanups

* Murder accidental tuple in gguf-py/scripts/gguf-dump.py

Co-authored-by: Jared Van Bortel <[email protected]>

* cleanup

* constants : remove unneeded type annotations

* fix python 3.8 compat

* Set up gguf- scripts in pyproject.toml

* And include scripts/__init__.py, derp

* convert.py: We can't currently support Q8_0 on big endian.

* gguf-py: SpecialVocab: Always try available sources for special token ids

gguf-py: SpecialVocab: Try to load merges from merges.txt if not in tokenizer.json

gguf-py: SpecialVocab: Add 'add_bos_token' type bools to GGUF metadata
u

* cleanup

* Promote add_X_token to GGUF metadata for BOS and EOS

---------

Co-authored-by: Jared Van Bortel <[email protected]>
Co-authored-by: Jared Van Bortel <[email protected]>
  • Loading branch information
3 people authored Nov 11, 2023
1 parent 4a4fd3e commit 34b0a08
Show file tree
Hide file tree
Showing 20 changed files with 1,982 additions and 1,176 deletions.
2 changes: 1 addition & 1 deletion convert-baichuan-hf-to-gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from sentencepiece import SentencePieceProcessor # type: ignore[import]

if 'NO_LOCAL_GGUF' not in os.environ:
sys.path.insert(1, str(Path(__file__).parent / 'gguf-py' / 'gguf'))
sys.path.insert(1, str(Path(__file__).parent / 'gguf-py'))
import gguf


Expand Down
24 changes: 2 additions & 22 deletions convert-llama-ggml-to-gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,29 +12,9 @@

import os
if 'NO_LOCAL_GGUF' not in os.environ:
sys.path.insert(1, str(Path(__file__).parent / 'gguf-py' / 'gguf'))
sys.path.insert(1, str(Path(__file__).parent / 'gguf-py'))
import gguf

# Note: Does not support GGML_QKK_64
QK_K = 256
# Items here are (block size, type size)
GGML_QUANT_SIZES = {
gguf.GGMLQuantizationType.F32 : (1, 4),
gguf.GGMLQuantizationType.F16 : (1, 2),
gguf.GGMLQuantizationType.Q4_0 : (32, 2 + 16),
gguf.GGMLQuantizationType.Q4_1 : (32, 2 + 2 + 16),
gguf.GGMLQuantizationType.Q5_0 : (32, 2 + 4 + 16),
gguf.GGMLQuantizationType.Q5_1 : (32, 2 + 2 + 4 + 16),
gguf.GGMLQuantizationType.Q8_0 : (32, 2 + 32),
gguf.GGMLQuantizationType.Q8_1 : (32, 4 + 4 + 32),
gguf.GGMLQuantizationType.Q2_K : (256, 2 + 2 + QK_K // 16 + QK_K // 4),
gguf.GGMLQuantizationType.Q3_K : (256, 2 + QK_K // 4 + QK_K // 8 + 12),
gguf.GGMLQuantizationType.Q4_K : (256, 2 + 2 + QK_K // 2 + 12),
gguf.GGMLQuantizationType.Q5_K : (256, 2 + 2 + QK_K // 2 + QK_K // 8 + 12),
gguf.GGMLQuantizationType.Q6_K : (256, 2 + QK_K // 2 + QK_K // 4 + QK_K // 16),
gguf.GGMLQuantizationType.Q8_K : (256, 4 + QK_K + QK_K // 8),
}

class GGMLFormat(IntEnum):
GGML = 0
GGMF = 1
Expand Down Expand Up @@ -125,7 +105,7 @@ def load(self, data, offset):
(n_dims, name_len, dtype) = struct.unpack('<3I', data[offset:offset + 12])
assert n_dims >= 0 and n_dims <= 4, f'Invalid tensor dimensions {n_dims}'
assert name_len < 4096, 'Absurd tensor name length'
quant = GGML_QUANT_SIZES.get(dtype)
quant = gguf.GGML_QUANT_SIZES.get(dtype)
assert quant is not None, 'Unknown tensor type'
(blksize, tysize) = quant
offset += 12
Expand Down
2 changes: 1 addition & 1 deletion convert-persimmon-to-gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from pathlib import Path
from sentencepiece import SentencePieceProcessor
if 'NO_LOCAL_GGUF' not in os.environ:
sys.path.insert(1, str(Path(__file__).parent / 'gguf-py' / 'gguf'))
sys.path.insert(1, str(Path(__file__).parent / 'gguf-py'))
import gguf

def _flatten_dict(dct, tensors, prefix=None):
Expand Down
16 changes: 9 additions & 7 deletions convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,9 @@

import argparse
import concurrent.futures
import copy
import enum
import faulthandler
import functools
import io
import itertools
import json
import math
Expand All @@ -23,14 +21,14 @@
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from dataclasses import dataclass
from pathlib import Path
from typing import IO, TYPE_CHECKING, Any, Callable, Generator, Iterable, Literal, Sequence, TypeVar
from typing import IO, TYPE_CHECKING, Any, Callable, Iterable, Literal, TypeVar

import numpy as np
from sentencepiece import SentencePieceProcessor

import os
if 'NO_LOCAL_GGUF' not in os.environ:
sys.path.insert(1, str(Path(__file__).parent / 'gguf-py' / 'gguf'))
sys.path.insert(1, str(Path(__file__).parent / 'gguf-py'))
import gguf

if TYPE_CHECKING:
Expand Down Expand Up @@ -851,7 +849,7 @@ def add_meta_vocab(self, vocab: Vocab) -> None:
elif isinstance(vocab, BpeVocab):
self.gguf.add_tokenizer_model("gpt2")
else:
raise ValueError(f'Unknown vocab type: Not BpeVocab or SentencePieceVocab')
raise ValueError('Unknown vocab type: Not BpeVocab or SentencePieceVocab')
self.gguf.add_token_list(tokens)
self.gguf.add_token_scores(scores)
self.gguf.add_token_types(toktypes)
Expand Down Expand Up @@ -905,7 +903,7 @@ def maybe_do_quantize(item: tuple[DataType, NDArray]) -> NDArray:
return dt.quantize(arr)

@staticmethod
def write_all(fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: Vocab, svocab: gguf.SpecialVocab, concurrency: int = DEFAULT_CONCURRENCY, endianess=gguf.GGUFEndian.LITTLE) -> None:
def write_all(fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: Vocab, svocab: gguf.SpecialVocab, concurrency: int = DEFAULT_CONCURRENCY, endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE) -> None:
check_vocab_size(params, vocab)

of = OutputFile(fname_out, endianess=endianess)
Expand Down Expand Up @@ -1114,11 +1112,15 @@ def do_dump_model(model_plus: ModelPlus) -> None:


def main(args_in: list[str] | None = None) -> None:
output_choices = ["f32", "f16"]
if np.uint32(1) == np.uint32(1).newbyteorder("<"):
# We currently only support Q8_0 output on little endian systems.
output_choices.append("q8_0")
parser = argparse.ArgumentParser(description="Convert a LLaMa model to a GGML compatible file")
parser.add_argument("--dump", action="store_true", help="don't convert, just show what's in the model")
parser.add_argument("--dump-single", action="store_true", help="don't convert, just show what's in a single model file")
parser.add_argument("--vocab-only", action="store_true", help="extract only the vocab")
parser.add_argument("--outtype", choices=["f32", "f16", "q8_0"], help="output format - note: q8_0 may be very slow (default: f16 or f32 based on input)")
parser.add_argument("--outtype", choices=output_choices, help="output format - note: q8_0 may be very slow (default: f16 or f32 based on input)")
parser.add_argument("--vocab-dir", type=Path, help="directory containing tokenizer.model, if separate from model file")
parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input")
parser.add_argument("model", type=Path, help="directory containing model file, or model file itself (*.pth, *.pt, *.bin)")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from pathlib import Path

if 'NO_LOCAL_GGUF' not in os.environ:
sys.path.insert(1, str(Path(__file__).parent / '..' / '..' / 'gguf-py' / 'gguf'))
sys.path.insert(1, str(Path(__file__).parent / '..' / '..' / 'gguf-py'))
import gguf

# gguf constants
Expand Down
10 changes: 10 additions & 0 deletions gguf-py/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,16 @@ as an example for its usage.
pip install gguf
```

## API Examples/Simple Tools

[examples/writer.py](https://github.com/ggerganov/llama.cpp/blob/master/gguf-py/examples/writer.py) — Generates `example.gguf` in the current directory to demonstrate generating a GGUF file. Note that this file cannot be used as a model.

[scripts/gguf-dump.py](https://github.com/ggerganov/llama.cpp/blob/master/gguf-py/scripts/gguf-dump.py) — Dumps a GGUF file's metadata to the console.

[scripts/gguf-set-metadata.py](https://github.com/ggerganov/llama.cpp/blob/master/gguf-py/scripts/gguf-set-metadata.py) — Allows changing simple metadata values in a GGUF file by key.

[scripts/gguf-convert-endian.py](https://github.com/ggerganov/llama.cpp/blob/master/gguf-py/scripts/gguf-convert-endian.py) — Allows converting the endianness of GGUF files.

## Development
Maintainers who participate in development of this package are advised to install it in editable mode:

Expand Down
40 changes: 40 additions & 0 deletions gguf-py/examples/writer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
#!/usr/bin/env python3
import sys
from pathlib import Path

import numpy as np

# Necessary to load the local gguf package
sys.path.insert(0, str(Path(__file__).parent.parent))

from gguf import GGUFWriter # noqa: E402


# Example usage:
def writer_example() -> None:
# Example usage with a file
gguf_writer = GGUFWriter("example.gguf", "llama")

gguf_writer.add_architecture()
gguf_writer.add_block_count(12)
gguf_writer.add_uint32("answer", 42) # Write a 32-bit integer
gguf_writer.add_float32("answer_in_float", 42.0) # Write a 32-bit float
gguf_writer.add_custom_alignment(64)

tensor1 = np.ones((32,), dtype=np.float32) * 100.0
tensor2 = np.ones((64,), dtype=np.float32) * 101.0
tensor3 = np.ones((96,), dtype=np.float32) * 102.0

gguf_writer.add_tensor("tensor1", tensor1)
gguf_writer.add_tensor("tensor2", tensor2)
gguf_writer.add_tensor("tensor3", tensor3)

gguf_writer.write_header_to_file()
gguf_writer.write_kv_data_to_file()
gguf_writer.write_tensors_to_file()

gguf_writer.close()


if __name__ == '__main__':
writer_example()
6 changes: 5 additions & 1 deletion gguf-py/gguf/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,5 @@
from .gguf import *
from .constants import *
from .gguf_reader import *
from .gguf_writer import *
from .tensor_mapping import *
from .vocab import *
Loading

0 comments on commit 34b0a08

Please sign in to comment.