Skip to content

Commit

Permalink
llama : add Mixtral support (ggerganov#4406)
Browse files Browse the repository at this point in the history
* convert : support Mixtral as LLAMA arch

* convert : fix n_ff typo

* llama : model loading

* ggml : sync latest ggml_mul_mat_id

* llama : update graph to support MoE

* llama : fix cur -> cur_expert

* llama : first working version

* llama : fix expert weighting in the FFN

* ggml : ggml_get_rows support 2D indexing [n_tokens, n_experts] (cpu only)

* ggml : add n_as argument to ggml_mul_mat_id

* ggml : fix ggml_get_rows to take into account ne02 / ne11

* metal : add more general support for ggml_get_rows + tests

* llama : add basic support for offloading moe with CUDA

* metal : add/mul/div use general kernel when src1 not cont

* metal : reduce the kernel launches for ggml_mul_mat_id

* ggml : get_rows : support non-contiguos tensors with gaps, generalize up to 3D

* ggml : update get_rows f16 and q

* cuda : support non-contiguous src1 in get_rows

* llama : offload missing ffn_moe_silu

* metal : fix ggml_get_rows to work with non-cont src1

* metal : add indirect mat-vec kernels for all quantization types

* llama : do not quantize expert gating tensors

* llama : add n_expert and n_expert_used to hparams + change quants

* test-backend-ops : add moe test

* cuda : fix get_rows when ncols is odd

* convert : determine n_ctx correctly

* metal : fix ggml_mul_mat_id for F32

* test-backend-ops : make experts more evenly probable (test_moe)

* test-backend-ops : cleanup, add moe test for batches

* test-backend-ops : add cpy from f32 -> all types test

* test-backend-ops : fix dequantize block offset

* llama : fix hard-coded number of experts

* test-backend-ops : simplify and disable slow tests to avoid CI timeout

* test-backend-ops : disable MOE test with thread sanitizer

* cuda : fix mul_mat_id with multi gpu

* convert : use 1e6 rope_freq_base for mixtral

* convert : fix style

* convert : support safetensors format

* gguf-py : bump version

* metal : add cpy f16 -> f32 kernel

* metal : fix binary ops for ne10 % 4 != 0

* test-backend-ops : add one more sum_rows test

* ggml : do not use BLAS with ggml_mul_mat_id

* convert-hf : support for mixtral-instruct (ggerganov#4428)

* convert : typo fix, add additional hyperparameters, use LLaMA arch for Mixtral-instruct

* convert : use sentencepiece tokenizer for Mixtral-instruct

* convert : make flake8 happy

* metal : fix soft_max kernels

ref: ggerganov/ggml@1914017

* metal : limit kernels to not use more than the allowed threads

---------

Co-authored-by: Georgi Gerganov <[email protected]>
Co-authored-by: Radek Pilar <[email protected]>
  • Loading branch information
3 people authored and teleprint-me committed Dec 21, 2023
1 parent 9cfbe94 commit 99db667
Show file tree
Hide file tree
Showing 14 changed files with 2,369 additions and 394 deletions.
5 changes: 5 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,11 @@ ifdef LLAMA_CUBLAS
MK_LDFLAGS += -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/x86_64-linux/lib
OBJS += ggml-cuda.o
NVCCFLAGS = --forward-unknown-to-host-compiler -use_fast_math

ifdef LLAMA_DEBUG
NVCCFLAGS += -lineinfo
endif

ifdef LLAMA_CUDA_NVCC
NVCC = $(LLAMA_CUDA_NVCC)
else
Expand Down
21 changes: 20 additions & 1 deletion convert-hf-to-gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,18 @@ def set_gguf_parameters(self):
self.gguf_writer.add_embedding_length(n_embd)
if (n_ff := self.hparams.get("intermediate_size")) is not None:
self.gguf_writer.add_feed_forward_length(n_ff)
if (n_head := self.hparams.get("num_attention_head")) is not None:
if (n_head := self.hparams.get("num_attention_heads")) is not None:
self.gguf_writer.add_head_count(n_head)
if (n_head_kv := self.hparams.get("num_key_value_heads")) is not None:
self.gguf_writer.add_head_count_kv(n_head_kv)

if (n_rms_eps := self.hparams.get("rms_norm_eps")) is not None:
self.gguf_writer.add_layer_norm_rms_eps(n_rms_eps)
if (n_experts := self.hparams.get("num_local_experts")) is not None:
self.gguf_writer.add_expert_count(n_experts)
if (n_experts_used := self.hparams.get("num_experts_per_tok")) is not None:
self.gguf_writer.add_expert_used_count(n_experts_used)

self.gguf_writer.add_parallel_residual(self.hparams.get("use_parallel_residual", True))

def write_tensors(self):
Expand Down Expand Up @@ -170,6 +180,8 @@ def from_model_architecture(model_architecture):
return StableLMModel
if model_architecture == "QWenLMHeadModel":
return QwenModel
if model_architecture == "MixtralForCausalLM":
return MixtralModel
return Model

def _is_model_safetensors(self) -> bool:
Expand Down Expand Up @@ -207,6 +219,8 @@ def _get_model_architecture(self) -> gguf.MODEL_ARCH:
return gguf.MODEL_ARCH.STABLELM
if arch == "QWenLMHeadModel":
return gguf.MODEL_ARCH.QWEN
if arch == "MixtralForCausalLM":
return gguf.MODEL_ARCH.LLAMA

raise NotImplementedError(f'Architecture "{arch}" not supported!')

Expand Down Expand Up @@ -837,6 +851,11 @@ def set_gguf_parameters(self):
self.gguf_writer.add_layer_norm_eps(1e-5)


class MixtralModel(Model):
def set_vocab(self):
self._set_vocab_sentencepiece()


class QwenModel(Model):
@staticmethod
def token_bytes_to_string(b):
Expand Down
74 changes: 57 additions & 17 deletions convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
ARCH = gguf.MODEL_ARCH.LLAMA

DEFAULT_CONCURRENCY = 8

#
# data types
#
Expand All @@ -62,10 +63,10 @@ class UnquantizedDataType(DataType):
pass


DT_F16 = UnquantizedDataType('F16', dtype = np.dtype(np.float16), valid_conversions = ['F32', 'Q8_0'])
DT_F32 = UnquantizedDataType('F32', dtype = np.dtype(np.float32), valid_conversions = ['F16', 'Q8_0'])
DT_I32 = UnquantizedDataType('I32', dtype = np.dtype(np.int16), valid_conversions = [])
DT_BF16 = UnquantizedDataType('BF16', dtype = np.dtype(np.uint16), valid_conversions = ['F32', 'F16', 'Q8_0'])
DT_F16 = UnquantizedDataType('F16', dtype = np.dtype(np.float16), valid_conversions = ['F32', 'Q8_0'])
DT_F32 = UnquantizedDataType('F32', dtype = np.dtype(np.float32), valid_conversions = ['F16', 'Q8_0'])
DT_I32 = UnquantizedDataType('I32', dtype = np.dtype(np.int16), valid_conversions = [])
DT_BF16 = UnquantizedDataType('BF16', dtype = np.dtype(np.uint16), valid_conversions = ['F32', 'F16', 'Q8_0'])


@dataclass(frozen=True)
Expand Down Expand Up @@ -151,14 +152,16 @@ def type_for_tensor(self, name: str, tensor: LazyTensor) -> DataType:

@dataclass
class Params:
n_vocab: int
n_embd: int
n_layer: int
n_ctx: int
n_ff: int
n_head: int
n_head_kv: int
f_norm_eps: float
n_vocab: int
n_embd: int
n_layer: int
n_ctx: int
n_ff: int
n_head: int
n_head_kv: int
n_experts: int | None = None
n_experts_used: int | None = None
f_norm_eps: float | None = None

rope_scaling_type: gguf.RopeScalingType | None = None
f_rope_freq_base: float | None = None
Expand Down Expand Up @@ -233,6 +236,13 @@ def loadHFTransformerJson(model: LazyModel, config_path: Path) -> Params:
raise Exception("failed to guess 'n_ctx'. This model is unknown or unsupported.\n"
"Suggestion: provide 'config.json' of the model in the same directory containing model files.")

n_experts = None
n_experts_used = None

if "num_local_experts" in config:
n_experts = config["num_local_experts"]
n_experts_used = config["num_experts_per_tok"]

return Params(
n_vocab = config["vocab_size"],
n_embd = config["hidden_size"],
Expand All @@ -241,6 +251,8 @@ def loadHFTransformerJson(model: LazyModel, config_path: Path) -> Params:
n_ff = config["intermediate_size"],
n_head = (n_head := config["num_attention_heads"]),
n_head_kv = config.get("num_key_value_heads", n_head),
n_experts = n_experts,
n_experts_used = n_experts_used,
f_norm_eps = config["rms_norm_eps"],
f_rope_freq_base = config.get("rope_theta"),
rope_scaling_type = rope_scaling_type,
Expand All @@ -255,8 +267,15 @@ def loadHFTransformerJson(model: LazyModel, config_path: Path) -> Params:
def loadOriginalParamsJson(model: LazyModel, config_path: Path) -> Params:
config = json.load(open(config_path))

n_experts = None
n_experts_used = None
f_rope_freq_base = None

# hack to determine LLaMA v1 vs v2 vs CodeLlama
if config.get("rope_theta") == 1000000:
if config.get("moe"):
# Mixtral
n_ctx = 32768
elif config.get("rope_theta") == 1000000:
# CodeLlama
n_ctx = 16384
elif config["norm_eps"] == 1e-05:
Expand All @@ -266,16 +285,27 @@ def loadOriginalParamsJson(model: LazyModel, config_path: Path) -> Params:
# LLaMA v1
n_ctx = 2048

if "layers.0.feed_forward.w1.weight" in model:
n_ff = model["layers.0.feed_forward.w1.weight"].shape[0]

if config.get("moe"):
n_ff = model["layers.0.feed_forward.experts.0.w1.weight"].shape[0]
n_experts = config["moe"]["num_experts"]
n_experts_used = config["moe"]["num_experts_per_tok"]
f_rope_freq_base = 1e6

return Params(
n_vocab = model["tok_embeddings.weight"].shape[0],
n_embd = config["dim"],
n_layer = config["n_layers"],
n_ctx = n_ctx,
n_ff = model["layers.0.feed_forward.w1.weight"].shape[0],
n_ff = n_ff,
n_head = (n_head := config["n_heads"]),
n_head_kv = config.get("n_kv_heads", n_head),
n_experts = n_experts,
n_experts_used = n_experts_used,
f_norm_eps = config["norm_eps"],
f_rope_freq_base = config.get("rope_theta"),
f_rope_freq_base = config.get("rope_theta", f_rope_freq_base),
)

@staticmethod
Expand Down Expand Up @@ -832,7 +862,17 @@ def add_meta_arch(self, params: Params) -> None:
self.gguf.add_rope_dimension_count(params.n_embd // params.n_head)
self.gguf.add_head_count (params.n_head)
self.gguf.add_head_count_kv (params.n_head_kv)
self.gguf.add_layer_norm_rms_eps (params.f_norm_eps)

if params.n_experts:
self.gguf.add_expert_count(params.n_experts)

if params.n_experts_used:
self.gguf.add_expert_used_count(params.n_experts_used)

if params.f_norm_eps:
self.gguf.add_layer_norm_rms_eps(params.f_norm_eps)
else:
raise ValueError('f_norm_eps is None')

if params.f_rope_freq_base is not None:
self.gguf.add_rope_freq_base(params.f_rope_freq_base)
Expand Down Expand Up @@ -956,7 +996,7 @@ def write_all(fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyM


def pick_output_type(model: LazyModel, output_type_str: str | None) -> GGMLFileType:
wq_type = model[gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ATTN_Q].format(bid=0) +".weight"].data_type
wq_type = model[gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ATTN_Q].format(bid=0) + ".weight"].data_type

if output_type_str == "f32" or (output_type_str is None and wq_type == DT_F32):
return GGMLFileType.AllF32
Expand Down
Loading

0 comments on commit 99db667

Please sign in to comment.