Support ChatGLM3 (#158)
li-plus authored Oct 29, 2023
1 parent f89f6fd commit dc6a8ba
Showing 13 changed files with 392 additions and 67 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -10,6 +10,7 @@ __pycache__/
*.egg-info/
dist/
*.so
.hypothesis/

# cpp
build/
30 changes: 25 additions & 5 deletions README.md
@@ -6,7 +6,7 @@
![Python](https://img.shields.io/pypi/pyversions/chatglm-cpp)
[![License: MIT](https://img.shields.io/badge/license-MIT-blue)](LICENSE)

C++ implementation of [ChatGLM-6B](https://github.com/THUDM/ChatGLM-6B) and [ChatGLM2-6B](https://github.com/THUDM/ChatGLM2-6B) and more LLMs for real-time chatting on your MacBook.
C++ implementation of [ChatGLM-6B](https://github.com/THUDM/ChatGLM-6B), [ChatGLM2-6B](https://github.com/THUDM/ChatGLM2-6B), [ChatGLM3-6B](https://github.com/THUDM/ChatGLM3) and more LLMs for real-time chatting on your MacBook.

![demo](docs/demo.gif)

@@ -21,7 +21,7 @@ Highlights:
Support Matrix:
* Hardware: x86/arm CPU, NVIDIA GPU, Apple Silicon GPU
* Platforms: Linux, macOS, Windows
* Models: [ChatGLM-6B](https://github.com/THUDM/ChatGLM-6B), [ChatGLM2-6B](https://github.com/THUDM/ChatGLM2-6B), [CodeGeeX2](https://github.com/THUDM/CodeGeeX2), [Baichuan-13B](https://github.com/baichuan-inc/Baichuan-13B), [Baichuan-7B](https://github.com/baichuan-inc/Baichuan-7B), [Baichuan-13B](https://github.com/baichuan-inc/Baichuan-13B), [Baichuan2](https://github.com/baichuan-inc/Baichuan2), [InternLM](https://github.com/InternLM/InternLM)
* Models: [ChatGLM-6B](https://github.com/THUDM/ChatGLM-6B), [ChatGLM2-6B](https://github.com/THUDM/ChatGLM2-6B), [ChatGLM3-6B](https://github.com/THUDM/ChatGLM3), [CodeGeeX2](https://github.com/THUDM/CodeGeeX2), [Baichuan-7B](https://github.com/baichuan-inc/Baichuan-7B), [Baichuan-13B](https://github.com/baichuan-inc/Baichuan-13B), [Baichuan2](https://github.com/baichuan-inc/Baichuan2), [InternLM](https://github.com/InternLM/InternLM)

## Getting Started

@@ -45,14 +45,15 @@ python3 -m pip install -U pip
python3 -m pip install torch tabulate tqdm transformers accelerate sentencepiece
```

Use `convert.py` to transform ChatGLM-6B or ChatGLM2-6B into quantized GGML format. For example, to convert the fp16 original model to q4_0 (quantized int4) GGML model, run:
Use `convert.py` to transform ChatGLM-6B into quantized GGML format. For example, to convert the fp16 original model to q4_0 (quantized int4) GGML model, run:
```sh
python3 chatglm_cpp/convert.py -i THUDM/chatglm-6b -t q4_0 -o chatglm-ggml.bin
```

The original model (`-i <model_name_or_path>`) can be a HuggingFace model name or a local path to your pre-downloaded model. Currently supported models are:
* ChatGLM-6B: `THUDM/chatglm-6b`, `THUDM/chatglm-6b-int8`, `THUDM/chatglm-6b-int4`
* ChatGLM2-6B: `THUDM/chatglm2-6b`, `THUDM/chatglm2-6b-int4`
* ChatGLM3-6B: `THUDM/chatglm3-6b`
* CodeGeeX2: `THUDM/codegeex2-6b`, `THUDM/codegeex2-6b-int4`
* Baichuan & Baichuan2: `baichuan-inc/Baichuan-13B-Chat`, `baichuan-inc/Baichuan2-7B-Chat`, `baichuan-inc/Baichuan2-13B-Chat`
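
The converter also accepts a local directory for `-i`, as noted above. A minimal sketch, assuming ChatGLM3-6B has already been downloaded into `./chatglm3-6b` (the path is hypothetical):

```sh
# Convert a locally downloaded ChatGLM3-6B checkpoint (local path is illustrative)
python3 chatglm_cpp/convert.py -i ./chatglm3-6b -t q4_0 -o chatglm3-ggml.bin
```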

@@ -101,6 +102,16 @@ python3 chatglm_cpp/convert.py -i THUDM/chatglm2-6b -t q4_0 -o chatglm2-ggml.bin
```
</details>

<details open>
<summary>ChatGLM3-6B</summary>

```sh
python3 chatglm_cpp/convert.py -i THUDM/chatglm3-6b -t q4_0 -o chatglm3-ggml.bin
./build/bin/main -m chatglm3-ggml.bin -p 你好 --top_p 0.8 --temp 0.8
# 你好👋!我是人工智能助手 ChatGLM3-6B,很高兴见到你,欢迎问我任何问题。 (Hello! I am the AI assistant ChatGLM3-6B, nice to meet you, feel free to ask me any question.)
```
</details>
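
The quantized model can also be used for multi-turn chat. A sketch, assuming the `-i` interactive flag documented elsewhere in this README applies to ChatGLM3 as well:

```sh
# Interactive chat with the quantized ChatGLM3 model (assumes the -i flag from the main usage section)
./build/bin/main -m chatglm3-ggml.bin -i
```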

<details>
<summary>CodeGeeX2</summary>

@@ -272,6 +283,15 @@ python3 web_demo.py -m ../chatglm2-ggml.bin --temp 0.8 --top_p 0.8 # web demo
```
</details>
<details open>
<summary>ChatGLM3-6B</summary>

```sh
python3 cli_chat.py -m ../chatglm3-ggml.bin -p 你好 --temp 0.8 --top_p 0.8 # CLI demo
python3 web_demo.py -m ../chatglm3-ggml.bin --temp 0.8 --top_p 0.8 # web demo
```
</details>
<details>
<summary>CodeGeeX2</summary>
@@ -473,7 +493,7 @@ ChatGLM-6B:
| file size | 3.3G | 3.7G | 4.0G | 4.4G | 6.2G | 12G |
| mem usage | 4.0G | 4.4G | 4.7G | 5.1G | 6.9G | 13G |
ChatGLM2-6B / CodeGeeX2:
ChatGLM2-6B / ChatGLM3-6B / CodeGeeX2:
| | Q4_0 | Q4_1 | Q5_0 | Q5_1 | Q8_0 | F16 |
|--------------------------------|-------|-------|-------|-------|-------|-------|
@@ -548,4 +568,4 @@ This will print timing for each graph operation when running the model.
## Acknowledgements
* This project is greatly inspired by [@ggerganov](https://github.com/ggerganov)'s [llama.cpp](https://github.com/ggerganov/llama.cpp) and is based on his NN library [ggml](https://github.com/ggerganov/ggml).
* Thank [@THUDM](https://github.com/THUDM) for the amazing [ChatGLM-6B](https://github.com/THUDM/ChatGLM-6B) and [ChatGLM2-6B](https://github.com/THUDM/ChatGLM2-6B) and for releasing the model sources and checkpoints.
* Thank [@THUDM](https://github.com/THUDM) for the amazing [ChatGLM-6B](https://github.com/THUDM/ChatGLM-6B), [ChatGLM2-6B](https://github.com/THUDM/ChatGLM2-6B) and [ChatGLM3-6B](https://github.com/THUDM/ChatGLM3) and for releasing the model sources and checkpoints.
116 changes: 96 additions & 20 deletions chatglm.cpp
@@ -422,6 +422,8 @@ std::string to_string(ModelType model_type) {
return "ChatGLM";
case MODEL_TYPE_CHATGLM2:
return "ChatGLM2";
case MODEL_TYPE_CHATGLM3:
return "ChatGLM3";
case MODEL_TYPE_BAICHUAN7B:
return "Baichuan7B";
case MODEL_TYPE_BAICHUAN13B:
@@ -433,9 +435,8 @@ std::string to_string(ModelType model_type) {
}
}
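
// Editorial note (not part of this commit): the ModelType argument is dropped from the base
// constructor below; the type now travels inside ModelConfig, which Pipeline::Pipeline constructs
// with the model_type it reads when opening the GGML file, so each subclass no longer passes a
// MODEL_TYPE_* constant explicitly.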

BaseModelForCausalLM::BaseModelForCausalLM(ModelType model_type, ModelConfig config, size_t mem_size,
size_t scratch_size, size_t num_weights)
: model_type_(model_type), config(config) {
BaseModelForCausalLM::BaseModelForCausalLM(ModelConfig config, size_t mem_size, size_t scratch_size, size_t num_weights)
: config(config) {
ctx_.dtype = config.dtype;
const size_t ctx_w_size = num_weights * ggml_tensor_overhead();
const size_t ctx_kv_size = 2 * config.num_hidden_layers *
@@ -821,7 +822,7 @@ ggml_tensor *GLMBlock::forward(ModelContext *ctx, ggml_tensor *hidden_states, gg
}

ChatGLMForCausalLM::ChatGLMForCausalLM(const ModelConfig &config)
: BasicModelForCausalLM(MODEL_TYPE_CHATGLM, config, MEM_SIZE, SCRATCH_SIZE, num_weights(config.num_hidden_layers)) {
: BasicModelForCausalLM(config, MEM_SIZE, SCRATCH_SIZE, num_weights(config.num_hidden_layers)) {
state_dict_ = state_dict();
}

@@ -933,8 +934,7 @@ bool ChatGLM2Tokenizer::is_special_id(int id) const {
}

ChatGLM2ForCausalLM::ChatGLM2ForCausalLM(const ModelConfig &config)
: BasicModelForCausalLM(MODEL_TYPE_CHATGLM2, config, MEM_SIZE, SCRATCH_SIZE,
num_weights(config.num_hidden_layers)) {
: BasicModelForCausalLM(config, MEM_SIZE, SCRATCH_SIZE, num_weights(config.num_hidden_layers)) {
state_dict_ = state_dict();
}

@@ -998,6 +998,79 @@ StateDict ChatGLM2ForCausalLM::state_dict() const {
return sd;
}

// ===== ChatGLM3-6B =====

ChatGLM3Tokenizer::ChatGLM3Tokenizer(std::string_view serialized_model_proto) {
const auto status = sp.LoadFromSerializedProto(serialized_model_proto);
CHATGLM_CHECK(status.ok()) << status.ToString();

int special_id = sp.GetPieceSize();
mask_token_id = special_id++;
gmask_token_id = special_id++;
smask_token_id = special_id++;
sop_token_id = special_id++;
eop_token_id = special_id++;
system_token_id = special_id++;
user_token_id = special_id++;
assistant_token_id = special_id++;
observation_token_id = special_id++;
}

std::vector<int> ChatGLM3Tokenizer::encode(const std::string &text, int max_length) const {
std::vector<int> ids;
sp.Encode(text, &ids);
ids.insert(ids.begin(), {gmask_token_id, sop_token_id}); // special prefix
truncate(ids, max_length);
return ids;
}

std::string ChatGLM3Tokenizer::decode(const std::vector<int> &ids) const {
// filter out special tokens
std::vector<int> normal_ids(ids);
normal_ids.erase(std::remove_if(normal_ids.begin(), normal_ids.end(), [this](int id) { return is_special_id(id); }),
normal_ids.end());

std::string text;
sp.Decode(normal_ids, &text);
text = replace_punctuations(text);
return text;
}

std::vector<int> ChatGLM3Tokenizer::encode_history(const std::vector<std::string> &history, int max_length) const {
// TODO: need a new api for system / tools / metadata prompt
std::vector<int> newline_ids;
sp.Encode("\n", &newline_ids);
std::vector<int> input_ids{gmask_token_id, sop_token_id};
for (size_t i = 0; i < history.size(); i++) {
// TODO: support all roles
input_ids.emplace_back((i % 2 == 0) ? user_token_id : assistant_token_id);
// TODO: support metadata
input_ids.insert(input_ids.end(), newline_ids.begin(), newline_ids.end());
std::vector<int> content_ids;
sp.Encode(history[i], &content_ids);
input_ids.insert(input_ids.end(), content_ids.begin(), content_ids.end());
}
input_ids.emplace_back(assistant_token_id);
// NOTE: push '\n' into input_ids to avoid model generating it, saving 2 tokens
input_ids.insert(input_ids.end(), newline_ids.begin(), newline_ids.end());
truncate(input_ids, max_length);
return input_ids;
}

bool ChatGLM3Tokenizer::is_special_id(int id) const {
return id == mask_token_id || id == gmask_token_id || id == smask_token_id || id == sop_token_id ||
id == eop_token_id || id == system_token_id || id == user_token_id || id == assistant_token_id ||
id == observation_token_id;
}

void ChatGLM3Tokenizer::truncate(std::vector<int> &ids, int max_length) {
if ((int)ids.size() > max_length) {
// sliding window: drop the least recent history while keeping the two special prefix tokens
int num_drop = (int)ids.size() - max_length;
ids.erase(ids.begin() + 2, ids.begin() + 2 + num_drop);
}
}
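
// Editorial sketch (not part of this commit): for a two-turn history {q1, a1, q2}, encode_history
// above lays the prompt out roughly as
//
//   [gMASK] sop <user> \n q1 <assistant> \n a1 <user> \n q2 <assistant> \n
//
// where <user> and <assistant> stand for user_token_id and assistant_token_id. The trailing
// assistant token plus newline primes the model to answer, and truncate() keeps the two-token
// prefix while dropping the oldest content once the sequence exceeds max_length.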

// ===== Baichuan =====

BaichuanTokenizer::BaichuanTokenizer(std::string_view serialized_model_proto) {
@@ -1055,8 +1128,7 @@ void BaichuanTokenizer::truncate(std::vector<int> &ids, int max_length) {
// ===== Baichuan-7B =====

Baichuan7BForCausalLM::Baichuan7BForCausalLM(const ModelConfig &config)
: BasicModelForCausalLM(MODEL_TYPE_BAICHUAN7B, config, MEM_SIZE, SCRATCH_SIZE,
num_weights(config.num_hidden_layers)) {
: BasicModelForCausalLM(config, MEM_SIZE, SCRATCH_SIZE, num_weights(config.num_hidden_layers)) {
state_dict_ = state_dict();
}

@@ -1097,8 +1169,7 @@ StateDict Baichuan7BForCausalLM::state_dict() const {
// ===== Baichuan-13B =====

Baichuan13BForCausalLM::Baichuan13BForCausalLM(const ModelConfig &config)
: BasicModelForCausalLM(MODEL_TYPE_BAICHUAN13B, config, MEM_SIZE, SCRATCH_SIZE,
num_weights(config.num_hidden_layers)) {
: BasicModelForCausalLM(config, MEM_SIZE, SCRATCH_SIZE, num_weights(config.num_hidden_layers)) {
state_dict_ = state_dict();
}

@@ -1192,8 +1263,7 @@ std::string InternLMTokenizer::build_prompt(const std::vector<std::string> &hist

template <typename InternLMModel>
InternLMForCausalLM<InternLMModel>::InternLMForCausalLM(const ModelConfig &config)
: BasicModelForCausalLM<InternLMModel>(MODEL_TYPE_INTERNLM, config, MEM_SIZE, SCRATCH_SIZE,
num_weights(config.num_hidden_layers)) {
: BasicModelForCausalLM<InternLMModel>(config, MEM_SIZE, SCRATCH_SIZE, num_weights(config.num_hidden_layers)) {
this->state_dict_ = state_dict();
}

@@ -1258,7 +1328,7 @@ Pipeline::Pipeline(const std::string &path) {
CHATGLM_CHECK(version == 1) << "only support version 1 for now but got " << version;

// load config
ModelConfig config(loader.read_basic<ConfigRecordV1>());
ModelConfig config(model_type, loader.read_basic<ConfigRecordV1>());

// load tokenizer
int proto_size = loader.read_basic<int>();
@@ -1269,26 +1339,32 @@ Pipeline::Pipeline(const std::string &path) {
// load model
model = std::make_unique<ChatGLMForCausalLM>(config);
model->load(loader);
} else if (model_type == MODEL_TYPE_CHATGLM2) {
} else if (model_type == MODEL_TYPE_CHATGLM2 || model_type == MODEL_TYPE_CHATGLM3) {
CHATGLM_CHECK(version == 1) << "only support version 1 for now but got " << version;

// load config
ModelConfig config(loader.read_basic<ConfigRecordV2>());
ModelConfig config(model_type, loader.read_basic<ConfigRecordV2>());

// load tokenizer
int proto_size = loader.read_basic<int>();
std::string_view serialized_model_proto((char *)mapped_file->data + loader.tell(), proto_size);
loader.seek(proto_size, SEEK_CUR);
tokenizer = std::make_unique<ChatGLM2Tokenizer>(serialized_model_proto);

if (model_type == MODEL_TYPE_CHATGLM2) {
tokenizer = std::make_unique<ChatGLM2Tokenizer>(serialized_model_proto);
model = std::make_unique<ChatGLM2ForCausalLM>(config);
} else {
tokenizer = std::make_unique<ChatGLM3Tokenizer>(serialized_model_proto);
model = std::make_unique<ChatGLM3ForCausalLM>(config);
}
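
// Editorial note (not part of this commit): both branches above read the same ConfigRecordV2 and
// share the load path below, so ChatGLM3 presumably reuses the ChatGLM2 GGML weight layout and
// model definition (declared in chatglm.h), differing only in its tokenizer and chat prompt format.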

// load model
model = std::make_unique<ChatGLM2ForCausalLM>(config);
model->load(loader);
} else if (model_type == MODEL_TYPE_BAICHUAN7B) {
CHATGLM_CHECK(version == 1) << "only support version 1 for now but got " << version;

// load config
ModelConfig config(loader.read_basic<ConfigRecordV1>());
ModelConfig config(model_type, loader.read_basic<ConfigRecordV1>());
config.norm_eps = 1e-6;

// load tokenizer
Expand All @@ -1304,7 +1380,7 @@ Pipeline::Pipeline(const std::string &path) {
CHATGLM_CHECK(version == 1) << "only support version 1 for now but got " << version;

// load config
ModelConfig config(loader.read_basic<ConfigRecordV1>());
ModelConfig config(model_type, loader.read_basic<ConfigRecordV1>());
config.norm_eps = 1e-6;

// load tokenizer
Expand All @@ -1320,7 +1396,7 @@ Pipeline::Pipeline(const std::string &path) {
CHATGLM_CHECK(version == 1) << "only support version 1 for now but got " << version;

// load config
ModelConfig config(loader.read_basic<ConfigRecordV1>());
ModelConfig config(model_type, loader.read_basic<ConfigRecordV1>());
config.norm_eps = 1e-6;

// load tokenizer