Support ChatGLM3 #158

Merged 4 commits on Oct 29, 2023

1 change: 1 addition & 0 deletions .gitignore
@@ -10,6 +10,7 @@ __pycache__/
*.egg-info/
dist/
*.so
.hypothesis/

# cpp
build/
30 changes: 25 additions & 5 deletions README.md
@@ -6,7 +6,7 @@
![Python](https://img.shields.io/pypi/pyversions/chatglm-cpp)
[![License: MIT](https://img.shields.io/badge/license-MIT-blue)](LICENSE)

C++ implementation of [ChatGLM-6B](https://github.com/THUDM/ChatGLM-6B) and [ChatGLM2-6B](https://github.com/THUDM/ChatGLM2-6B) and more LLMs for real-time chatting on your MacBook.
C++ implementation of [ChatGLM-6B](https://github.com/THUDM/ChatGLM-6B), [ChatGLM2-6B](https://github.com/THUDM/ChatGLM2-6B), [ChatGLM3-6B](https://github.com/THUDM/ChatGLM3) and more LLMs for real-time chatting on your MacBook.

![demo](docs/demo.gif)

@@ -21,7 +21,7 @@ Highlights:
Support Matrix:
* Hardware: x86/arm CPU, NVIDIA GPU, Apple Silicon GPU
* Platforms: Linux, macOS, Windows
* Models: [ChatGLM-6B](https://github.com/THUDM/ChatGLM-6B), [ChatGLM2-6B](https://github.com/THUDM/ChatGLM2-6B), [CodeGeeX2](https://github.com/THUDM/CodeGeeX2), [Baichuan-13B](https://github.com/baichuan-inc/Baichuan-13B), [Baichuan-7B](https://github.com/baichuan-inc/Baichuan-7B), [Baichuan-13B](https://github.com/baichuan-inc/Baichuan-13B), [Baichuan2](https://github.com/baichuan-inc/Baichuan2), [InternLM](https://github.com/InternLM/InternLM)
* Models: [ChatGLM-6B](https://github.com/THUDM/ChatGLM-6B), [ChatGLM2-6B](https://github.com/THUDM/ChatGLM2-6B), [ChatGLM3-6B](https://github.com/THUDM/ChatGLM3), [CodeGeeX2](https://github.com/THUDM/CodeGeeX2), [Baichuan-7B](https://github.com/baichuan-inc/Baichuan-7B), [Baichuan-13B](https://github.com/baichuan-inc/Baichuan-13B), [Baichuan2](https://github.com/baichuan-inc/Baichuan2), [InternLM](https://github.com/InternLM/InternLM)

## Getting Started

@@ -45,14 +45,15 @@ python3 -m pip install -U pip
python3 -m pip install torch tabulate tqdm transformers accelerate sentencepiece
```

Use `convert.py` to transform ChatGLM-6B or ChatGLM2-6B into quantized GGML format. For example, to convert the fp16 original model to q4_0 (quantized int4) GGML model, run:
Use `convert.py` to transform ChatGLM-6B into quantized GGML format. For example, to convert the fp16 original model to q4_0 (quantized int4) GGML model, run:
```sh
python3 chatglm_cpp/convert.py -i THUDM/chatglm-6b -t q4_0 -o chatglm-ggml.bin
```

The original model (`-i <model_name_or_path>`) can be a HuggingFace model name or a local path to your pre-downloaded model. Currently supported models are:
* ChatGLM-6B: `THUDM/chatglm-6b`, `THUDM/chatglm-6b-int8`, `THUDM/chatglm-6b-int4`
* ChatGLM2-6B: `THUDM/chatglm2-6b`, `THUDM/chatglm2-6b-int4`
* ChatGLM3-6B: `THUDM/chatglm3-6b`
* CodeGeeX2: `THUDM/codegeex2-6b`, `THUDM/codegeex2-6b-int4`
* Baichuan & Baichuan2: `baichuan-inc/Baichuan-13B-Chat`, `baichuan-inc/Baichuan2-7B-Chat`, `baichuan-inc/Baichuan2-13B-Chat`

@@ -101,6 +102,16 @@ python3 chatglm_cpp/convert.py -i THUDM/chatglm2-6b -t q4_0 -o chatglm2-ggml.bin
```
</details>

<details open>
<summary>ChatGLM3-6B</summary>

```sh
python3 chatglm_cpp/convert.py -i THUDM/chatglm3-6b -t q4_0 -o chatglm3-ggml.bin
./build/bin/main -m chatglm3-ggml.bin -p 你好 --top_p 0.8 --temp 0.8
# 你好👋!我是人工智能助手 ChatGLM3-6B,很高兴见到你,欢迎问我任何问题。
```
</details>

<details>
<summary>CodeGeeX2</summary>

@@ -272,6 +283,15 @@ python3 web_demo.py -m ../chatglm2-ggml.bin --temp 0.8 --top_p 0.8 # web demo
```
</details>

<details open>
<summary>ChatGLM3-6B</summary>

```sh
python3 cli_chat.py -m ../chatglm3-ggml.bin -p 你好 --temp 0.8 --top_p 0.8 # CLI demo
python3 web_demo.py -m ../chatglm3-ggml.bin --temp 0.8 --top_p 0.8 # web demo
```
</details>
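
The same converted model can also be driven directly from the Python binding. Below is a minimal sketch, assuming the `chatglm_cpp` package built from this repo is installed and that `Pipeline.chat` accepts a history list plus `top_p`/`temperature` keyword arguments (these names are not shown in this diff, so treat them as assumptions):

```python
import chatglm_cpp

# Load the ChatGLM3 GGML model produced by convert.py (path is illustrative).
pipeline = chatglm_cpp.Pipeline("../chatglm3-ggml.bin")

# chat() takes the conversation history as a list of strings and returns the reply.
print(pipeline.chat(["你好"], temperature=0.8, top_p=0.8))
```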

<details>
<summary>CodeGeeX2</summary>

@@ -473,7 +493,7 @@ ChatGLM-6B:
| file size | 3.3G | 3.7G | 4.0G | 4.4G | 6.2G | 12G |
| mem usage | 4.0G | 4.4G | 4.7G | 5.1G | 6.9G | 13G |

ChatGLM2-6B / CodeGeeX2:
ChatGLM2-6B / ChatGLM3-6B / CodeGeeX2:

| | Q4_0 | Q4_1 | Q5_0 | Q5_1 | Q8_0 | F16 |
|--------------------------------|-------|-------|-------|-------|-------|-------|
@@ -548,4 +568,4 @@ This will print timing for each graph operation when running the model.
## Acknowledgements

* This project is greatly inspired by [@ggerganov](https://github.com/ggerganov)'s [llama.cpp](https://github.com/ggerganov/llama.cpp) and is based on his NN library [ggml](https://github.com/ggerganov/ggml).
* Thank [@THUDM](https://github.com/THUDM) for the amazing [ChatGLM-6B](https://github.com/THUDM/ChatGLM-6B) and [ChatGLM2-6B](https://github.com/THUDM/ChatGLM2-6B) and for releasing the model sources and checkpoints.
* Thank [@THUDM](https://github.com/THUDM) for the amazing [ChatGLM-6B](https://github.com/THUDM/ChatGLM-6B), [ChatGLM2-6B](https://github.com/THUDM/ChatGLM2-6B) and [ChatGLM3-6B](https://github.com/THUDM/ChatGLM3) and for releasing the model sources and checkpoints.
116 changes: 96 additions & 20 deletions chatglm.cpp
@@ -422,6 +422,8 @@ std::string to_string(ModelType model_type) {
return "ChatGLM";
case MODEL_TYPE_CHATGLM2:
return "ChatGLM2";
case MODEL_TYPE_CHATGLM3:
return "ChatGLM3";
case MODEL_TYPE_BAICHUAN7B:
return "Baichuan7B";
case MODEL_TYPE_BAICHUAN13B:
@@ -433,9 +435,8 @@ std::string to_string(ModelType model_type) {
}
}

BaseModelForCausalLM::BaseModelForCausalLM(ModelType model_type, ModelConfig config, size_t mem_size,
size_t scratch_size, size_t num_weights)
: model_type_(model_type), config(config) {
BaseModelForCausalLM::BaseModelForCausalLM(ModelConfig config, size_t mem_size, size_t scratch_size, size_t num_weights)
: config(config) {
ctx_.dtype = config.dtype;
const size_t ctx_w_size = num_weights * ggml_tensor_overhead();
const size_t ctx_kv_size = 2 * config.num_hidden_layers *
@@ -821,7 +822,7 @@ ggml_tensor *GLMBlock::forward(ModelContext *ctx, ggml_tensor *hidden_states, gg
}

ChatGLMForCausalLM::ChatGLMForCausalLM(const ModelConfig &config)
: BasicModelForCausalLM(MODEL_TYPE_CHATGLM, config, MEM_SIZE, SCRATCH_SIZE, num_weights(config.num_hidden_layers)) {
: BasicModelForCausalLM(config, MEM_SIZE, SCRATCH_SIZE, num_weights(config.num_hidden_layers)) {
state_dict_ = state_dict();
}

@@ -933,8 +934,7 @@ bool ChatGLM2Tokenizer::is_special_id(int id) const {
}

ChatGLM2ForCausalLM::ChatGLM2ForCausalLM(const ModelConfig &config)
: BasicModelForCausalLM(MODEL_TYPE_CHATGLM2, config, MEM_SIZE, SCRATCH_SIZE,
num_weights(config.num_hidden_layers)) {
: BasicModelForCausalLM(config, MEM_SIZE, SCRATCH_SIZE, num_weights(config.num_hidden_layers)) {
state_dict_ = state_dict();
}

@@ -998,6 +998,79 @@ StateDict ChatGLM2ForCausalLM::state_dict() const {
return sd;
}

// ===== ChatGLM3-6B =====

ChatGLM3Tokenizer::ChatGLM3Tokenizer(std::string_view serialized_model_proto) {
const auto status = sp.LoadFromSerializedProto(serialized_model_proto);
CHATGLM_CHECK(status.ok()) << status.ToString();

int special_id = sp.GetPieceSize();
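// Special token ids are not stored in the serialized sentencepiece proto; they are assigned
// consecutive ids starting right after the sentencepiece vocabulary (sp.GetPieceSize()).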
mask_token_id = special_id++;
gmask_token_id = special_id++;
smask_token_id = special_id++;
sop_token_id = special_id++;
eop_token_id = special_id++;
system_token_id = special_id++;
user_token_id = special_id++;
assistant_token_id = special_id++;
observation_token_id = special_id++;
}

std::vector<int> ChatGLM3Tokenizer::encode(const std::string &text, int max_length) const {
std::vector<int> ids;
sp.Encode(text, &ids);
ids.insert(ids.begin(), {gmask_token_id, sop_token_id}); // special prefix
truncate(ids, max_length);
return ids;
}

std::string ChatGLM3Tokenizer::decode(const std::vector<int> &ids) const {
// filter out special tokens
std::vector<int> normal_ids(ids);
normal_ids.erase(std::remove_if(normal_ids.begin(), normal_ids.end(), [this](int id) { return is_special_id(id); }),
normal_ids.end());

std::string text;
sp.Decode(normal_ids, &text);
text = replace_punctuations(text);
return text;
}

std::vector<int> ChatGLM3Tokenizer::encode_history(const std::vector<std::string> &history, int max_length) const {
// TODO: need a new api for system / tools / metadata prompt
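// Token layout built below: [gMASK] sop, then for each history turn a role token
// (<|user|> for even turns, <|assistant|> for odd ones) followed by "\n" and the turn content,
// ending with a trailing <|assistant|> "\n" to cue the model's reply.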
std::vector<int> newline_ids;
sp.Encode("\n", &newline_ids);
std::vector<int> input_ids{gmask_token_id, sop_token_id};
for (size_t i = 0; i < history.size(); i++) {
// TODO: support all roles
input_ids.emplace_back((i % 2 == 0) ? user_token_id : assistant_token_id);
// TODO: support metadata
input_ids.insert(input_ids.end(), newline_ids.begin(), newline_ids.end());
std::vector<int> content_ids;
sp.Encode(history[i], &content_ids);
input_ids.insert(input_ids.end(), content_ids.begin(), content_ids.end());
}
input_ids.emplace_back(assistant_token_id);
// NOTE: push '\n' into input_ids to avoid model generating it, saving 2 tokens
input_ids.insert(input_ids.end(), newline_ids.begin(), newline_ids.end());
truncate(input_ids, max_length);
return input_ids;
}

bool ChatGLM3Tokenizer::is_special_id(int id) const {
return id == mask_token_id || id == gmask_token_id || id == smask_token_id || id == sop_token_id ||
id == eop_token_id || id == system_token_id || id == user_token_id || id == assistant_token_id ||
id == observation_token_id;
}

void ChatGLM3Tokenizer::truncate(std::vector<int> &ids, int max_length) {
if ((int)ids.size() > max_length) {
// sliding window: drop the least recent history while keeping the two special prefix tokens
int num_drop = (int)ids.size() - max_length;
ids.erase(ids.begin() + 2, ids.begin() + 2 + num_drop);
}
}

// ===== Baichuan =====

BaichuanTokenizer::BaichuanTokenizer(std::string_view serialized_model_proto) {
@@ -1055,8 +1128,7 @@ void BaichuanTokenizer::truncate(std::vector<int> &ids, int max_length) {
// ===== Baichuan-7B =====

Baichuan7BForCausalLM::Baichuan7BForCausalLM(const ModelConfig &config)
: BasicModelForCausalLM(MODEL_TYPE_BAICHUAN7B, config, MEM_SIZE, SCRATCH_SIZE,
num_weights(config.num_hidden_layers)) {
: BasicModelForCausalLM(config, MEM_SIZE, SCRATCH_SIZE, num_weights(config.num_hidden_layers)) {
state_dict_ = state_dict();
}

@@ -1097,8 +1169,7 @@ StateDict Baichuan7BForCausalLM::state_dict() const {
// ===== Baichuan-13B =====

Baichuan13BForCausalLM::Baichuan13BForCausalLM(const ModelConfig &config)
: BasicModelForCausalLM(MODEL_TYPE_BAICHUAN13B, config, MEM_SIZE, SCRATCH_SIZE,
num_weights(config.num_hidden_layers)) {
: BasicModelForCausalLM(config, MEM_SIZE, SCRATCH_SIZE, num_weights(config.num_hidden_layers)) {
state_dict_ = state_dict();
}

@@ -1192,8 +1263,7 @@ std::string InternLMTokenizer::build_prompt(const std::vector<std::string> &hist

template <typename InternLMModel>
InternLMForCausalLM<InternLMModel>::InternLMForCausalLM(const ModelConfig &config)
: BasicModelForCausalLM<InternLMModel>(MODEL_TYPE_INTERNLM, config, MEM_SIZE, SCRATCH_SIZE,
num_weights(config.num_hidden_layers)) {
: BasicModelForCausalLM<InternLMModel>(config, MEM_SIZE, SCRATCH_SIZE, num_weights(config.num_hidden_layers)) {
this->state_dict_ = state_dict();
}

@@ -1258,7 +1328,7 @@ Pipeline::Pipeline(const std::string &path) {
CHATGLM_CHECK(version == 1) << "only support version 1 for now but got " << version;

// load config
ModelConfig config(loader.read_basic<ConfigRecordV1>());
ModelConfig config(model_type, loader.read_basic<ConfigRecordV1>());

// load tokenizer
int proto_size = loader.read_basic<int>();
@@ -1269,26 +1339,32 @@ Pipeline::Pipeline(const std::string &path) {
// load model
model = std::make_unique<ChatGLMForCausalLM>(config);
model->load(loader);
} else if (model_type == MODEL_TYPE_CHATGLM2) {
} else if (model_type == MODEL_TYPE_CHATGLM2 || model_type == MODEL_TYPE_CHATGLM3) {
CHATGLM_CHECK(version == 1) << "only support version 1 for now but got " << version;

// load config
ModelConfig config(loader.read_basic<ConfigRecordV2>());
ModelConfig config(model_type, loader.read_basic<ConfigRecordV2>());

// load tokenizer
int proto_size = loader.read_basic<int>();
std::string_view serialized_model_proto((char *)mapped_file->data + loader.tell(), proto_size);
loader.seek(proto_size, SEEK_CUR);
tokenizer = std::make_unique<ChatGLM2Tokenizer>(serialized_model_proto);

if (model_type == MODEL_TYPE_CHATGLM2) {
tokenizer = std::make_unique<ChatGLM2Tokenizer>(serialized_model_proto);
model = std::make_unique<ChatGLM2ForCausalLM>(config);
} else {
tokenizer = std::make_unique<ChatGLM3Tokenizer>(serialized_model_proto);
model = std::make_unique<ChatGLM3ForCausalLM>(config);
}
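// Both branches above read the same ConfigRecordV2-based ModelConfig; model_type only
// selects which tokenizer and model class to instantiate.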

// load model
model = std::make_unique<ChatGLM2ForCausalLM>(config);
model->load(loader);
} else if (model_type == MODEL_TYPE_BAICHUAN7B) {
CHATGLM_CHECK(version == 1) << "only support version 1 for now but got " << version;

// load config
ModelConfig config(loader.read_basic<ConfigRecordV1>());
ModelConfig config(model_type, loader.read_basic<ConfigRecordV1>());
config.norm_eps = 1e-6;

// load tokenizer
Expand All @@ -1304,7 +1380,7 @@ Pipeline::Pipeline(const std::string &path) {
CHATGLM_CHECK(version == 1) << "only support version 1 for now but got " << version;

// load config
ModelConfig config(loader.read_basic<ConfigRecordV1>());
ModelConfig config(model_type, loader.read_basic<ConfigRecordV1>());
config.norm_eps = 1e-6;

// load tokenizer
Expand All @@ -1320,7 +1396,7 @@ Pipeline::Pipeline(const std::string &path) {
CHATGLM_CHECK(version == 1) << "only support version 1 for now but got " << version;

// load config
ModelConfig config(loader.read_basic<ConfigRecordV1>());
ModelConfig config(model_type, loader.read_basic<ConfigRecordV1>());
config.norm_eps = 1e-6;

// load tokenizer