Skip to content

Commit

Permalink
model: support arch DbrxForCausalLM (ggerganov#6515)
Browse files Browse the repository at this point in the history
* model: dbrx convert to gguf
ggerganov#6344

* llama: support dbrx
ggerganov#6344

* doc: dbrx: add the model as supported

* scripts: get-wikitext-2 add unzip

* llama: increase maximum experts allowed

* llama: factorize moe graph implementation between grok, mixtral and dbrx


---------

Co-authored-by: Megha Agarwal <[email protected]>
  • Loading branch information
2 people authored and tybalex committed Apr 17, 2024
1 parent 81c3461 commit e4255c3
Show file tree
Hide file tree
Showing 7 changed files with 427 additions and 147 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ Typically finetunes of the base models below are supported as well.
- [x] LLaMA 2 🦙🦙
- [X] [Mistral 7B](https://huggingface.co/mistralai/Mistral-7B-v0.1)
- [x] [Mixtral MoE](https://huggingface.co/models?search=mistral-ai/Mixtral)
- [x] [DBRX](https://huggingface.co/databricks/dbrx-instruct)
- [X] Falcon
- [X] [Chinese LLaMA / Alpaca](https://github.com/ymcui/Chinese-LLaMA-Alpaca) and [Chinese LLaMA-2 / Alpaca-2](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2)
- [X] [Vigogne (French)](https://github.com/bofenghuang/vigogne)
Expand Down
96 changes: 96 additions & 0 deletions convert-hf-to-gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -1427,6 +1427,102 @@ def write_tensors(self):
self.gguf_writer.add_tensor(new_name, data)


@Model.register("DbrxForCausalLM")
class DbrxModel(Model):
model_arch = gguf.MODEL_ARCH.DBRX

def set_gguf_parameters(self):
ffn_config = self.hparams["ffn_config"]
attn_config = self.hparams["attn_config"]
self.gguf_writer.add_name(self.hparams["model_type"])
self.gguf_writer.add_block_count(self.hparams["n_layers"])

self.gguf_writer.add_context_length(self.hparams["max_seq_len"])
self.gguf_writer.add_embedding_length(self.hparams["d_model"])
self.gguf_writer.add_feed_forward_length(ffn_config["ffn_hidden_size"])

self.gguf_writer.add_head_count(self.hparams["n_heads"])
self.gguf_writer.add_head_count_kv(attn_config["kv_n_heads"])

self.gguf_writer.add_rope_freq_base(attn_config["rope_theta"])

self.gguf_writer.add_clamp_kqv(attn_config["clip_qkv"])
self.gguf_writer.add_file_type(self.ftype)

self.gguf_writer.add_expert_count(ffn_config["moe_num_experts"])
self.gguf_writer.add_expert_used_count(ffn_config["moe_top_k"])

self.gguf_writer.add_layer_norm_eps(1e-5)

self.gguf_writer.add_file_type(self.ftype)
print(f"gguf: file type = {self.ftype}")

def write_tensors(self):
block_count = self.hparams.get("n_layers")
tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
for name, data_torch in self.get_tensors():
n_expert = self.hparams["ffn_config"]["moe_num_experts"]
n_ff = self.hparams["ffn_config"]["ffn_hidden_size"]
n_embd = self.hparams["d_model"]

# Specific behavior for experts tensors: suffix .weight, view as 3D and transpose
# original implementation expects (n_expert, n_ff, n_embd) for all experts weights
# But llama.cpp moe graph works differently
# AND the dimensions in ggml are typically in the reverse order of the pytorch dimensions
# so (n_expert, n_ff, n_embd) in pytorch is {n_embd, n_ff, n_expert} in ggml_tensor
exp_tensor_names = {"ffn.experts.mlp.w1": None, # LLM_TENSOR_FFN_GATE_EXPS ggml_tensor->ne{n_embd, n_ff, n_expert}
"ffn.experts.mlp.w2": (0, 2, 1), # LLM_TENSOR_FFN_DOWN_EXPS ggml_tensor->ne{n_ff, n_embd, n_expert}
"ffn.experts.mlp.v1": None} # LLM_TENSOR_FFN_UP_EXPS ggml_tensor->ne{n_embd, n_ff, n_expert}
experts = False
for exp_tensor_name in exp_tensor_names.keys():
if name.find(exp_tensor_name) != -1 and name.find(".weight") == -1:
experts = True
data_torch = data_torch.view(n_expert, n_ff, n_embd)
if (permute_tensor := exp_tensor_names[exp_tensor_name]) is not None:
data_torch = data_torch.permute(*permute_tensor)
break

old_dtype = data_torch.dtype

# convert any unsupported data types to float32
if data_torch.dtype not in (torch.float16, torch.float32):
data_torch = data_torch.to(torch.float32)

data = data_torch.squeeze().numpy()

# map tensor names
# In MoE models the ffn tensors are typically most of the model weights,
# and need to be quantizable. Quantize expects tensor names to be suffixed by .weight.
# Every other model has the weight names ending in .weight,
# let's assume that is the convention which is not the case for dbrx:
# https://huggingface.co/databricks/dbrx-instruct/blob/main/model.safetensors.index.json#L15
new_name = tensor_map.get_name(name if not experts else name + ".weight", try_suffixes=(".weight",))
if new_name is None:
print(f"Can not map tensor {name!r}")
sys.exit()

n_dims = len(data.shape)
data_dtype = data.dtype

# Most of the codebase that takes in 1D tensors only handles F32 tensors
# and most of the outputs tensors are F32.
if data_dtype != np.float32 and n_dims == 1:
print(f"Can not map tensor {name!r}: all 1D tensors must be F32")
sys.exit()

# if f32 desired, convert any float16 to float32
if self.ftype == 0 and data_dtype == np.float16:
data = data.astype(np.float32)

# if f16 desired, convert any float32 2-dim weight tensors to float16
if self.ftype == 1 and data_dtype == np.float32 and n_dims > 1:
data = data.astype(np.float16)

print(f"{new_name}, n_dims = {n_dims}, shape = {data.shape}, {old_dtype} --> {data.dtype}")

self.gguf_writer.add_tensor(new_name, data)


@Model.register("MiniCPMForCausalLM")
class MiniCPMModel(Model):
model_arch = gguf.MODEL_ARCH.MINICPM
Expand Down
26 changes: 18 additions & 8 deletions examples/eval-callback/eval-callback.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,27 @@ static std::string ggml_ne_string(const ggml_tensor * t) {
}

static void ggml_print_tensor(uint8_t * data, ggml_type type, const int64_t * ne, const size_t * nb, int64_t n) {
GGML_ASSERT(n > 0);
float sum = 0;
for (int64_t i3 = 0; i3 < ne[3]; i3++) {
printf(" [\n");
for (int64_t i2 = 0; i2 < ne[2] && i2 < n; i2++) {
for (int64_t i2 = 0; i2 < ne[2]; i2++) {
if (i2 == n && ne[2] > 2*n) {
printf(" ..., \n");
i2 = ne[2] - n;
}
printf(" [\n");
for (int64_t i1 = 0; i1 < ne[1] && i1 < n; i1++) {
for (int64_t i1 = 0; i1 < ne[1]; i1++) {
if (i1 == n && ne[1] > 2*n) {
printf(" ..., \n");
i1 = ne[1] - n;
}
printf(" [");
for (int64_t i0 = 0; i0 < ne[0] && i0 < n; i0++) {
for (int64_t i0 = 0; i0 < ne[0]; i0++) {
if (i0 == n && ne[0] > 2*n) {
printf("..., ");
i0 = ne[0] - n;
}
size_t i = i3 * nb[3] + i2 * nb[2] + i1 * nb[1] + i0 * nb[0];
float v;
if (type == GGML_TYPE_F16) {
Expand All @@ -51,17 +64,14 @@ static void ggml_print_tensor(uint8_t * data, ggml_type type, const int64_t * ne
} else {
GGML_ASSERT(false);
}
printf("%8.4f", v);
printf("%12.4f", v);
sum += v;
if (i0 < ne[0] - 1 && i0 < n - 1) printf(", ");
if (i0 < ne[0] - 1) printf(", ");
}
if (ne[0] > n) printf(", ...");
printf("],\n");
}
if (ne[1] > n) printf(" ...\n");
printf(" ],\n");
}
if (ne[2] > n) printf(" ...\n");
printf(" ]\n");
printf(" sum = %f\n", sum);
}
Expand Down
15 changes: 15 additions & 0 deletions gguf-py/gguf/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ class MODEL_ARCH(IntEnum):
MAMBA = auto()
XVERSE = auto()
COMMAND_R = auto()
DBRX = auto()


class MODEL_TENSOR(IntEnum):
Expand Down Expand Up @@ -195,6 +196,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_ARCH.MAMBA: "mamba",
MODEL_ARCH.XVERSE: "xverse",
MODEL_ARCH.COMMAND_R: "command-r",
MODEL_ARCH.DBRX: "dbrx",
}

TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
Expand Down Expand Up @@ -642,6 +644,19 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.ATTN_K_NORM,
MODEL_TENSOR.ATTN_Q_NORM,
],
MODEL_ARCH.DBRX: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_QKV,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.ATTN_OUT_NORM,
MODEL_TENSOR.FFN_GATE_INP,
MODEL_TENSOR.FFN_GATE_EXP,
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
],
# TODO
}

Expand Down
58 changes: 33 additions & 25 deletions gguf-py/gguf/tensor_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class TensorNameMap:
# Token embeddings
MODEL_TENSOR.TOKEN_EMBD: (
"gpt_neox.embed_in", # gptneox
"transformer.wte", # gpt2 gpt-j mpt refact qwen
"transformer.wte", # gpt2 gpt-j mpt refact qwen dbrx
"transformer.word_embeddings", # falcon
"word_embeddings", # bloom
"model.embed_tokens", # llama-hf
Expand Down Expand Up @@ -48,7 +48,7 @@ class TensorNameMap:
# Output
MODEL_TENSOR.OUTPUT: (
"embed_out", # gptneox
"lm_head", # gpt2 mpt falcon llama-hf baichuan qwen mamba
"lm_head", # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx
"output", # llama-pth bloom internlm2
"word_embeddings_for_head", # persimmon
"lm_head.linear", # phi2
Expand All @@ -60,7 +60,7 @@ class TensorNameMap:
"transformer.ln_f", # gpt2 gpt-j falcon
"model.norm", # llama-hf baichuan internlm2
"norm", # llama-pth
"transformer.norm_f", # mpt
"transformer.norm_f", # mpt dbrx
"ln_f", # refact bloom qwen gpt2
"language_model.encoder.final_layernorm", # persimmon
"model.final_layernorm", # persimmon
Expand Down Expand Up @@ -96,6 +96,7 @@ class TensorNameMap:
"model.layers.{bid}.norm", # mamba-qbert
"backbone.layers.{bid}.norm", # mamba
"transformer.decoder_layer.{bid}.rms_norm", # Grok
"transformer.blocks.{bid}.norm_attn_norm.norm_1", # dbrx
),

# Attention norm 2
Expand All @@ -108,6 +109,7 @@ class TensorNameMap:
"gpt_neox.layers.{bid}.attention.query_key_value", # gptneox
"transformer.h.{bid}.attn.c_attn", # gpt2 qwen
"transformer.blocks.{bid}.attn.Wqkv", # mpt
"transformer.blocks.{bid}.norm_attn_norm.attn.Wqkv", # dbrx
"transformer.h.{bid}.self_attention.query_key_value", # falcon
"h.{bid}.self_attention.query_key_value", # bloom
"language_model.encoder.layers.{bid}.self_attention.query_key_value", # persimmon
Expand Down Expand Up @@ -152,30 +154,32 @@ class TensorNameMap:

# Attention output
MODEL_TENSOR.ATTN_OUT: (
"gpt_neox.layers.{bid}.attention.dense", # gptneox
"transformer.h.{bid}.attn.c_proj", # gpt2 refact qwen
"transformer.blocks.{bid}.attn.out_proj", # mpt
"transformer.h.{bid}.self_attention.dense", # falcon
"h.{bid}.self_attention.dense", # bloom
"model.layers.{bid}.self_attn.o_proj", # llama-hf
"layers.{bid}.attention.wo", # llama-pth
"encoder.layer.{bid}.attention.output.dense", # bert
"transformer.h.{bid}.attn.out_proj", # gpt-j
"language_model.encoder.layers.{bid}.self_attention.dense", # persimmon
"model.layers.{bid}.self_attn.dense", # persimmon
"h.{bid}.attn.c_proj", # gpt2
"transformer.h.{bid}.mixer.out_proj", # phi2
"model.layers.layers.{bid}.self_attn.o_proj", # plamo
"model.layers.{bid}.attention.wo", # internlm2
"encoder.layers.{bid}.attn.out_proj", # nomic-bert
"transformer.decoder_layer.{bid}.multi_head_attention.linear"# Grok
"gpt_neox.layers.{bid}.attention.dense", # gptneox
"transformer.h.{bid}.attn.c_proj", # gpt2 refact qwen
"transformer.blocks.{bid}.attn.out_proj", # mpt
"transformer.h.{bid}.self_attention.dense", # falcon
"h.{bid}.self_attention.dense", # bloom
"model.layers.{bid}.self_attn.o_proj", # llama-hf
"layers.{bid}.attention.wo", # llama-pth
"encoder.layer.{bid}.attention.output.dense", # bert
"transformer.h.{bid}.attn.out_proj", # gpt-j
"language_model.encoder.layers.{bid}.self_attention.dense", # persimmon
"model.layers.{bid}.self_attn.dense", # persimmon
"h.{bid}.attn.c_proj", # gpt2
"transformer.h.{bid}.mixer.out_proj", # phi2
"model.layers.layers.{bid}.self_attn.o_proj", # plamo
"model.layers.{bid}.attention.wo", # internlm2
"encoder.layers.{bid}.attn.out_proj", # nomic-bert
"transformer.decoder_layer.{bid}.multi_head_attention.linear", # Grok
"transformer.blocks.{bid}.norm_attn_norm.attn.out_proj", # dbrx
),

# Attention output norm
MODEL_TENSOR.ATTN_OUT_NORM: (
"encoder.layer.{bid}.attention.output.LayerNorm", # bert
"encoder.layers.{bid}.norm1", # nomic-bert
"transformer.decoder_layer.{bid}.rms_norm_1", # Grok
"transformer.blocks.{bid}.norm_attn_norm.norm_2", # dbrx
),

# Rotary embeddings
Expand All @@ -202,9 +206,10 @@ class TensorNameMap:
),

MODEL_TENSOR.FFN_GATE_INP: (
"layers.{bid}.feed_forward.gate", # mixtral
"model.layers.{bid}.block_sparse_moe.gate", # mixtral
"transformer.decoder_layer.{bid}.router" # Grok
"layers.{bid}.feed_forward.gate", # mixtral
"model.layers.{bid}.block_sparse_moe.gate", # mixtral
"transformer.decoder_layer.{bid}.router", # Grok
"transformer.blocks.{bid}.ffn.router.layer", # dbrx
),

# Feed-forward up
Expand Down Expand Up @@ -233,6 +238,7 @@ class TensorNameMap:
MODEL_TENSOR.FFN_UP_EXP: (
"layers.{bid}.feed_forward.experts.w3", # mixtral (merged)
"transformer.decoder_layer.{bid}.moe.linear_v", # Grok (merged)
"transformer.blocks.{bid}.ffn.experts.mlp.v1", # dbrx
),

# AWQ-activation gate
Expand All @@ -251,8 +257,9 @@ class TensorNameMap:
),

MODEL_TENSOR.FFN_GATE_EXP: (
"layers.{bid}.feed_forward.experts.w1", # mixtral (merged)
"transformer.decoder_layer.{bid}.moe.linear" # Grok (merged)
"layers.{bid}.feed_forward.experts.w1", # mixtral (merged)
"transformer.decoder_layer.{bid}.moe.linear", # Grok (merged)
"transformer.blocks.{bid}.ffn.experts.mlp.w1", # dbrx
),

# Feed-forward down
Expand Down Expand Up @@ -280,6 +287,7 @@ class TensorNameMap:
MODEL_TENSOR.FFN_DOWN_EXP: (
"layers.{bid}.feed_forward.experts.w2", # mixtral (merged)
"transformer.decoder_layer.{bid}.moe.linear_1", # Grok (merged)
"transformer.blocks.{bid}.ffn.experts.mlp.w2", # dbrx
),

MODEL_TENSOR.ATTN_Q_NORM: (
Expand Down
Loading

0 comments on commit e4255c3

Please sign in to comment.