diff --git a/.gitignore b/.gitignore
index f3961fd..47826ac 100644
--- a/.gitignore
+++ b/.gitignore
@@ -10,6 +10,7 @@ __pycache__/
 *.egg-info/
 dist/
 *.so
+.hypothesis/
 
 # cpp
 build/
diff --git a/README.md b/README.md
index 8c0b4d3..e7bb168 100644
--- a/README.md
+++ b/README.md
@@ -6,7 +6,7 @@
 ![Python](https://img.shields.io/pypi/pyversions/chatglm-cpp)
 [![License: MIT](https://img.shields.io/badge/license-MIT-blue)](LICENSE)
 
-C++ implementation of [ChatGLM-6B](https://github.com/THUDM/ChatGLM-6B) and [ChatGLM2-6B](https://github.com/THUDM/ChatGLM2-6B) and more LLMs for real-time chatting on your MacBook.
+C++ implementation of [ChatGLM-6B](https://github.com/THUDM/ChatGLM-6B), [ChatGLM2-6B](https://github.com/THUDM/ChatGLM2-6B), [ChatGLM3-6B](https://github.com/THUDM/ChatGLM3) and more LLMs for real-time chatting on your MacBook.
 
 ![demo](docs/demo.gif)
 
@@ -21,7 +21,7 @@ Highlights:
 Support Matrix:
 * Hardwares: x86/arm CPU, NVIDIA GPU, Apple Silicon GPU
 * Platforms: Linux, MacOS, Windows
-* Models: [ChatGLM-6B](https://github.com/THUDM/ChatGLM-6B), [ChatGLM2-6B](https://github.com/THUDM/ChatGLM2-6B), [CodeGeeX2](https://github.com/THUDM/CodeGeeX2), [Baichuan-13B](https://github.com/baichuan-inc/Baichuan-13B), [Baichuan-7B](https://github.com/baichuan-inc/Baichuan-7B), [Baichuan-13B](https://github.com/baichuan-inc/Baichuan-13B), [Baichuan2](https://github.com/baichuan-inc/Baichuan2), [InternLM](https://github.com/InternLM/InternLM)
+* Models: [ChatGLM-6B](https://github.com/THUDM/ChatGLM-6B), [ChatGLM2-6B](https://github.com/THUDM/ChatGLM2-6B), [ChatGLM3-6B](https://github.com/THUDM/ChatGLM3), [CodeGeeX2](https://github.com/THUDM/CodeGeeX2), [Baichuan-7B](https://github.com/baichuan-inc/Baichuan-7B), [Baichuan-13B](https://github.com/baichuan-inc/Baichuan-13B), [Baichuan2](https://github.com/baichuan-inc/Baichuan2), [InternLM](https://github.com/InternLM/InternLM)
 
 ## Getting Started
 
@@ -45,7 +45,7 @@ python3 -m pip install -U pip
 python3 -m pip install torch tabulate tqdm transformers accelerate sentencepiece
 ```
 
-Use `convert.py` to transform ChatGLM-6B or ChatGLM2-6B into quantized GGML format. For example, to convert the fp16 original model to q4_0 (quantized int4) GGML model, run:
+Use `convert.py` to transform ChatGLM-6B into quantized GGML format. For example, to convert the fp16 original model to q4_0 (quantized int4) GGML model, run:
 ```sh
 python3 chatglm_cpp/convert.py -i THUDM/chatglm-6b -t q4_0 -o chatglm-ggml.bin
 ```
@@ -53,6 +53,7 @@ python3 chatglm_cpp/convert.py -i THUDM/chatglm-6b -t q4_0 -o chatglm-ggml.bin
 The original model (`-i <model_name_or_path>`) can be a HuggingFace model name or a local path to your pre-downloaded model. Currently supported models are:
 * ChatGLM-6B: `THUDM/chatglm-6b`, `THUDM/chatglm-6b-int8`, `THUDM/chatglm-6b-int4`
 * ChatGLM2-6B: `THUDM/chatglm2-6b`, `THUDM/chatglm2-6b-int4`
+* ChatGLM3-6B: `THUDM/chatglm3-6b`
 * CodeGeeX2: `THUDM/codegeex2-6b`, `THUDM/codegeex2-6b-int4`
 * Baichuan & Baichuan2: `baichuan-inc/Baichuan-13B-Chat`, `baichuan-inc/Baichuan2-7B-Chat`, `baichuan-inc/Baichuan2-13B-Chat`
 
@@ -101,6 +102,16 @@ python3 chatglm_cpp/convert.py -i THUDM/chatglm2-6b -t q4_0 -o chatglm2-ggml.bin
 ```
 
 </details>
+<details open>
+<summary>ChatGLM3-6B</summary>
+
+```sh
+python3 chatglm_cpp/convert.py -i THUDM/chatglm3-6b -t q4_0 -o chatglm3-ggml.bin
+./build/bin/main -m chatglm3-ggml.bin -p 你好 --top_p 0.8 --temp 0.8
+# 你好👋!我是人工智能助手 ChatGLM3-6B,很高兴见到你,欢迎问我任何问题。
+```
+
+</details>
 
 <details open>
 <summary>CodeGeeX2</summary>
@@ -272,6 +283,15 @@ python3 web_demo.py -m ../chatglm2-ggml.bin --temp 0.8 --top_p 0.8 # web demo
 ```
 
 </details>
+<details open>
+<summary>ChatGLM3-6B</summary>
+
+```sh
+python3 cli_chat.py -m ../chatglm3-ggml.bin -p 你好 --temp 0.8 --top_p 0.8 # CLI demo
+python3 web_demo.py -m ../chatglm3-ggml.bin --temp 0.8 --top_p 0.8 # web demo
+```
+
+</details>
 
 <details open>
 <summary>CodeGeeX2</summary>
@@ -473,7 +493,7 @@ ChatGLM-6B:
 | file size | 3.3G | 3.7G | 4.0G | 4.4G | 6.2G | 12G |
 | mem usage | 4.0G | 4.4G | 4.7G | 5.1G | 6.9G | 13G |
 
-ChatGLM2-6B / CodeGeeX2:
+ChatGLM2-6B / ChatGLM3-6B / CodeGeeX2:
 
 | | Q4_0 | Q4_1 | Q5_0 | Q5_1 | Q8_0 | F16 |
 |--------------------------------|-------|-------|-------|-------|-------|-------|
@@ -548,4 +568,4 @@ This will print timing for each graph operation when running the model.
 ## Acknowledgements
 
 * This project is greatly inspired by [@ggerganov](https://github.com/ggerganov)'s [llama.cpp](https://github.com/ggerganov/llama.cpp) and is based on his NN library [ggml](https://github.com/ggerganov/ggml).
-* Thank [@THUDM](https://github.com/THUDM) for the amazing [ChatGLM-6B](https://github.com/THUDM/ChatGLM-6B) and [ChatGLM2-6B](https://github.com/THUDM/ChatGLM2-6B) and for releasing the model sources and checkpoints.
+* Thanks to [@THUDM](https://github.com/THUDM) for the amazing [ChatGLM-6B](https://github.com/THUDM/ChatGLM-6B), [ChatGLM2-6B](https://github.com/THUDM/ChatGLM2-6B) and [ChatGLM3-6B](https://github.com/THUDM/ChatGLM3), and for releasing the model sources and checkpoints.
diff --git a/chatglm.cpp b/chatglm.cpp
index a0a3c9c..0fdcdd2 100644
--- a/chatglm.cpp
+++ b/chatglm.cpp
@@ -422,6 +422,8 @@ std::string to_string(ModelType model_type) {
         return "ChatGLM";
     case MODEL_TYPE_CHATGLM2:
         return "ChatGLM2";
+    case MODEL_TYPE_CHATGLM3:
+        return "ChatGLM3";
     case MODEL_TYPE_BAICHUAN7B:
         return "Baichuan7B";
     case MODEL_TYPE_BAICHUAN13B:
@@ -433,9 +435,8 @@ std::string to_string(ModelType model_type) {
     }
 }
 
-BaseModelForCausalLM::BaseModelForCausalLM(ModelType model_type, ModelConfig config, size_t mem_size,
-                                           size_t scratch_size, size_t num_weights)
-    : model_type_(model_type), config(config) {
+BaseModelForCausalLM::BaseModelForCausalLM(ModelConfig config, size_t mem_size, size_t scratch_size, size_t num_weights)
+    : config(config) {
     ctx_.dtype = config.dtype;
     const size_t ctx_w_size = num_weights * ggml_tensor_overhead();
     const size_t ctx_kv_size = 2 * config.num_hidden_layers *
@@ -821,7 +822,7 @@ ggml_tensor *GLMBlock::forward(ModelContext *ctx, ggml_tensor *hidden_states, gg
 }
 
 ChatGLMForCausalLM::ChatGLMForCausalLM(const ModelConfig &config)
-    : BasicModelForCausalLM(MODEL_TYPE_CHATGLM, config, MEM_SIZE, SCRATCH_SIZE, num_weights(config.num_hidden_layers)) {
+    : BasicModelForCausalLM(config, MEM_SIZE, SCRATCH_SIZE, num_weights(config.num_hidden_layers)) {
     state_dict_ = state_dict();
 }
 
@@ -933,8 +934,7 @@ bool ChatGLM2Tokenizer::is_special_id(int id) const {
 }
 
 ChatGLM2ForCausalLM::ChatGLM2ForCausalLM(const ModelConfig &config)
-    : BasicModelForCausalLM(MODEL_TYPE_CHATGLM2, config, MEM_SIZE, SCRATCH_SIZE,
-                            num_weights(config.num_hidden_layers)) {
+    : BasicModelForCausalLM(config, MEM_SIZE, SCRATCH_SIZE, num_weights(config.num_hidden_layers)) {
     state_dict_ = state_dict();
 }
 
@@ -998,6 +998,79 @@ StateDict ChatGLM2ForCausalLM::state_dict() const {
     return sd;
 }
 
+// ===== ChatGLM3-6B =====
+
+ChatGLM3Tokenizer::ChatGLM3Tokenizer(std::string_view serialized_model_proto) {
+    const auto status = sp.LoadFromSerializedProto(serialized_model_proto);
+    CHATGLM_CHECK(status.ok()) << status.ToString();
+
+    int special_id = sp.GetPieceSize();
+    mask_token_id = special_id++;
+    gmask_token_id = special_id++;
+    smask_token_id = special_id++;
+    sop_token_id = special_id++;
+    eop_token_id = special_id++;
+    system_token_id = special_id++;
+    user_token_id = special_id++;
+    assistant_token_id = special_id++;
+    observation_token_id = special_id++;
+}
+
+std::vector<int> ChatGLM3Tokenizer::encode(const std::string &text, int max_length) const {
+    std::vector<int> ids;
+    sp.Encode(text, &ids);
+    ids.insert(ids.begin(), {gmask_token_id, sop_token_id}); // special prefix
+    truncate(ids, max_length);
+    return ids;
+}
+
+std::string ChatGLM3Tokenizer::decode(const std::vector<int> &ids) const {
+    // filter out special tokens
+    std::vector<int> normal_ids(ids);
+    normal_ids.erase(std::remove_if(normal_ids.begin(), normal_ids.end(), [this](int id) { return is_special_id(id); }),
+                     normal_ids.end());
+
+    std::string text;
+    sp.Decode(normal_ids, &text);
+    text = replace_punctuations(text);
+    return text;
+}
+
+std::vector<int> ChatGLM3Tokenizer::encode_history(const std::vector<std::string> &history, int max_length) const {
+    // TODO: need a new api for system / tools / metadata prompt
+    std::vector<int> newline_ids;
+    sp.Encode("\n", &newline_ids);
+    std::vector<int> input_ids{gmask_token_id, sop_token_id};
+    for (size_t i = 0; i < history.size(); i++) {
+        // TODO: support all roles
+        input_ids.emplace_back((i % 2 == 0) ? user_token_id : assistant_token_id);
+        // TODO: support metadata
+        input_ids.insert(input_ids.end(), newline_ids.begin(), newline_ids.end());
+        std::vector<int> content_ids;
+        sp.Encode(history[i], &content_ids);
+        input_ids.insert(input_ids.end(), content_ids.begin(), content_ids.end());
+    }
+    input_ids.emplace_back(assistant_token_id);
+    // NOTE: push '\n' into input_ids to avoid model generating it, saving 2 tokens
+    input_ids.insert(input_ids.end(), newline_ids.begin(), newline_ids.end());
+    truncate(input_ids, max_length);
+    return input_ids;
+}
+
+bool ChatGLM3Tokenizer::is_special_id(int id) const {
+    return id == mask_token_id || id == gmask_token_id || id == smask_token_id || id == sop_token_id ||
+           id == eop_token_id || id == system_token_id || id == user_token_id || id == assistant_token_id ||
+           id == observation_token_id;
+}
+
+void ChatGLM3Tokenizer::truncate(std::vector<int> &ids, int max_length) {
+    if ((int)ids.size() > max_length) {
+        // sliding window: drop the least recent history while keeping the two special prefix tokens
+        int num_drop = (int)ids.size() - max_length;
+        ids.erase(ids.begin() + 2, ids.begin() + 2 + num_drop);
+    }
+}
+
 // ===== Baichuan =====
 
 BaichuanTokenizer::BaichuanTokenizer(std::string_view serialized_model_proto) {
@@ -1055,8 +1128,7 @@ void BaichuanTokenizer::truncate(std::vector<int> &ids, int max_length) {
 
 // ===== Baichuan-7B =====
 
 Baichuan7BForCausalLM::Baichuan7BForCausalLM(const ModelConfig &config)
-    : BasicModelForCausalLM(MODEL_TYPE_BAICHUAN7B, config, MEM_SIZE, SCRATCH_SIZE,
-                            num_weights(config.num_hidden_layers)) {
+    : BasicModelForCausalLM(config, MEM_SIZE, SCRATCH_SIZE, num_weights(config.num_hidden_layers)) {
     state_dict_ = state_dict();
 }
 
@@ -1097,8 +1169,7 @@ StateDict Baichuan7BForCausalLM::state_dict() const {
 
 // ===== Baichuan-13B =====
 
 Baichuan13BForCausalLM::Baichuan13BForCausalLM(const ModelConfig &config)
-    : BasicModelForCausalLM(MODEL_TYPE_BAICHUAN13B, config, MEM_SIZE, SCRATCH_SIZE,
-                            num_weights(config.num_hidden_layers)) {
+    : BasicModelForCausalLM(config, MEM_SIZE, SCRATCH_SIZE, num_weights(config.num_hidden_layers)) {
     state_dict_ = state_dict();
 }
 
@@ -1192,8 +1263,7 @@ std::string InternLMTokenizer::build_prompt(const std::vector<std::string> &hist
 
 template <typename InternLMModel>
 InternLMForCausalLM<InternLMModel>::InternLMForCausalLM(const ModelConfig &config)
-    : BasicModelForCausalLM(MODEL_TYPE_INTERNLM, config, MEM_SIZE, SCRATCH_SIZE,
-                            num_weights(config.num_hidden_layers)) {
+    : BasicModelForCausalLM(config, MEM_SIZE, SCRATCH_SIZE, num_weights(config.num_hidden_layers)) {
     this->state_dict_ = state_dict();
 }
 
@@ -1258,7 +1328,7 @@ Pipeline::Pipeline(const std::string &path) {
         CHATGLM_CHECK(version == 1) << "only support version 1 for now but got " << version;
 
         // load config
-        ModelConfig config(loader.read_basic<ConfigRecordV1>());
+        ModelConfig config(model_type, loader.read_basic<ConfigRecordV1>());
 
         // load tokenizer
         int proto_size = loader.read_basic<int>();
@@ -1269,26 +1339,32 @@ Pipeline::Pipeline(const std::string &path) {
         // load model
         model = std::make_unique<ChatGLMForCausalLM>(config);
         model->load(loader);
-    } else if (model_type == MODEL_TYPE_CHATGLM2) {
+    } else if (model_type == MODEL_TYPE_CHATGLM2 || model_type == MODEL_TYPE_CHATGLM3) {
         CHATGLM_CHECK(version == 1) << "only support version 1 for now but got " << version;
 
         // load config
-        ModelConfig config(loader.read_basic<ConfigRecordV2>());
+        ModelConfig config(model_type, loader.read_basic<ConfigRecordV2>());
 
         // load tokenizer
         int proto_size = loader.read_basic<int>();
         std::string_view serialized_model_proto((char *)mapped_file->data + loader.tell(), proto_size);
         loader.seek(proto_size, SEEK_CUR);
-        tokenizer = std::make_unique<ChatGLM2Tokenizer>(serialized_model_proto);
+
+        if (model_type == MODEL_TYPE_CHATGLM2) {
+            tokenizer = std::make_unique<ChatGLM2Tokenizer>(serialized_model_proto);
+            model = std::make_unique<ChatGLM2ForCausalLM>(config);
+        } else {
+            tokenizer = std::make_unique<ChatGLM3Tokenizer>(serialized_model_proto);
+            model = std::make_unique<ChatGLM3ForCausalLM>(config);
+        }
 
         // load model
-        model = std::make_unique<ChatGLM2ForCausalLM>(config);
         model->load(loader);
     } else if (model_type == MODEL_TYPE_BAICHUAN7B) {
         CHATGLM_CHECK(version == 1) << "only support version 1 for now but got " << version;
 
         // load config
-        ModelConfig config(loader.read_basic<ConfigRecordV1>());
+        ModelConfig config(model_type, loader.read_basic<ConfigRecordV1>());
         config.norm_eps = 1e-6;
 
         // load tokenizer
@@ -1304,7 +1380,7 @@ Pipeline::Pipeline(const std::string &path) {
         CHATGLM_CHECK(version == 1) << "only support version 1 for now but got " << version;
 
         // load config
-        ModelConfig config(loader.read_basic<ConfigRecordV1>());
+        ModelConfig config(model_type, loader.read_basic<ConfigRecordV1>());
         config.norm_eps = 1e-6;
 
         // load tokenizer
@@ -1320,7 +1396,7 @@ Pipeline::Pipeline(const std::string &path) {
         CHATGLM_CHECK(version == 1) << "only support version 1 for now but got " << version;
 
         // load config
-        ModelConfig config(loader.read_basic<ConfigRecordV1>());
+        ModelConfig config(model_type, loader.read_basic<ConfigRecordV1>());
         config.norm_eps = 1e-6;
 
         // load tokenizer
diff --git a/chatglm.h b/chatglm.h
index 59ba9df..cf3562b 100644
--- a/chatglm.h
+++ b/chatglm.h
@@ -46,6 +46,17 @@ ggml_tensor *tensor_to_device(ggml_tensor *tensor);
 
 ggml_tensor *tensor_to_cpu(ggml_tensor *tensor);
 
+enum ModelType {
+    MODEL_TYPE_CHATGLM = 1,
+    MODEL_TYPE_CHATGLM2 = 2,
+    MODEL_TYPE_CHATGLM3 = 3,
+    MODEL_TYPE_BAICHUAN7B = 1024,
+    MODEL_TYPE_BAICHUAN13B = 1025,
+    MODEL_TYPE_INTERNLM = 1280,
+};
+
+std::string to_string(ModelType model_type);
+
 // For compatibility
 struct ConfigRecordV1 {
     // common attributes
@@ -74,25 +85,28 @@ class ModelConfig {
   public:
     ModelConfig() = default;
 
-    ModelConfig(ggml_type dtype, int vocab_size, int hidden_size, int num_attention_heads, int num_kv_heads,
-                int num_hidden_layers, int intermediate_size, float norm_eps, int max_length, int bos_token_id,
-                int eos_token_id, int pad_token_id, int sep_token_id)
-        : dtype(dtype), vocab_size(vocab_size), hidden_size(hidden_size), num_attention_heads(num_attention_heads),
-          num_kv_heads(num_kv_heads), num_hidden_layers(num_hidden_layers), intermediate_size(intermediate_size),
-          norm_eps(norm_eps), max_length(max_length), bos_token_id(bos_token_id), eos_token_id(eos_token_id),
-          pad_token_id(pad_token_id), sep_token_id(sep_token_id) {}
-
-    ModelConfig(const ConfigRecordV1 &rec)
-        : ModelConfig(rec.dtype, rec.vocab_size, rec.hidden_size, rec.num_attention_heads, rec.num_attention_heads,
+    ModelConfig(ModelType model_type, ggml_type dtype, int vocab_size, int hidden_size, int num_attention_heads,
+                int num_kv_heads, int num_hidden_layers, int intermediate_size, float norm_eps, int max_length,
+                int bos_token_id, int eos_token_id, int pad_token_id, int sep_token_id)
+        : model_type(model_type), dtype(dtype), vocab_size(vocab_size), hidden_size(hidden_size),
+          num_attention_heads(num_attention_heads), num_kv_heads(num_kv_heads), num_hidden_layers(num_hidden_layers),
+          intermediate_size(intermediate_size), norm_eps(norm_eps), max_length(max_length), bos_token_id(bos_token_id),
+          eos_token_id(eos_token_id), pad_token_id(pad_token_id), sep_token_id(sep_token_id) {}
+
+    ModelConfig(ModelType model_type, const ConfigRecordV1 &rec)
+        : ModelConfig(model_type, rec.dtype, rec.vocab_size, rec.hidden_size, rec.num_attention_heads,
+                      rec.num_attention_heads, rec.num_hidden_layers, rec.intermediate_size, 1e-5, rec.max_length,
+                      rec.bos_token_id, rec.eos_token_id, rec.pad_token_id, rec.sep_token_id) {}
+
+    ModelConfig(ModelType model_type, const ConfigRecordV2 &rec)
+        : ModelConfig(model_type, rec.dtype, rec.vocab_size, rec.hidden_size, rec.num_attention_heads, rec.num_kv_heads,
                       rec.num_hidden_layers, rec.intermediate_size, 1e-5, rec.max_length, rec.bos_token_id,
                       rec.eos_token_id, rec.pad_token_id, rec.sep_token_id) {}
 
-    ModelConfig(const ConfigRecordV2 &rec)
-        : ModelConfig(rec.dtype, rec.vocab_size, rec.hidden_size, rec.num_attention_heads, rec.num_kv_heads,
-                      rec.num_hidden_layers, rec.intermediate_size, 1e-5, rec.max_length, rec.bos_token_id,
-                      rec.eos_token_id, rec.pad_token_id, rec.sep_token_id) {}
+    std::string model_type_name() const { return to_string(model_type); }
 
   public:
+    ModelType model_type;
     ggml_type dtype;
     int vocab_size;
     int hidden_size;
@@ -734,19 +748,9 @@ struct GenerationConfig {
           top_p(top_p), temperature(temperature), repetition_penalty(repetition_penalty), num_threads(num_threads) {}
 };
 
-enum ModelType {
-    MODEL_TYPE_CHATGLM = 1,
-    MODEL_TYPE_CHATGLM2 = 2,
-    MODEL_TYPE_BAICHUAN7B = 1024,
-    MODEL_TYPE_BAICHUAN13B = 1025,
-    MODEL_TYPE_INTERNLM = 1280,
-};
-
 int get_num_physical_cores();
 int get_default_num_threads();
 
-std::string to_string(ModelType model_type);
-
 struct TokenIdScore {
     int id;
     float score;
@@ -764,16 +768,12 @@ struct TokenIdScore {
 
 class BaseModelForCausalLM {
   public:
-    BaseModelForCausalLM(ModelType model_type, ModelConfig config, size_t mem_size, size_t scratch_size,
-                         size_t num_weights);
+    BaseModelForCausalLM(ModelConfig config, size_t mem_size, size_t scratch_size, size_t num_weights);
     virtual ~BaseModelForCausalLM() = default;
 
     virtual void load(ModelLoader &loader) = 0;
     virtual ggml_tensor *forward(ModelContext *ctx, ggml_tensor *input_ids, int n_past, int n_ctx) const = 0;
 
-    ModelType type() const { return model_type_; }
-    std::string type_name() const { return to_string(model_type_); }
-
     std::vector<int> generate(const std::vector<int> &input_ids, const GenerationConfig &gen_config,
                               BaseStreamer *streamer = nullptr);
 
@@ -791,7 +791,6 @@ class BaseModelForCausalLM {
     static void sampling_softmax_inplace(TokenIdScore *first, TokenIdScore *last);
 
   protected:
-    ModelType model_type_;
     ModelContext ctx_;
 
   public:
@@ -803,9 +802,8 @@ using StateDict = std::vector<std::pair<std::string, ggml_tensor *>>;
 template <typename Model>
 class BasicModelForCausalLM : public BaseModelForCausalLM {
   protected:
-    BasicModelForCausalLM(ModelType model_type, const ModelConfig &config, size_t mem_size, size_t scratch_size,
-                          size_t num_weights)
-        : BaseModelForCausalLM(model_type, config, mem_size, scratch_size, num_weights), transformer(&ctx_, config),
+    BasicModelForCausalLM(const ModelConfig &config, size_t mem_size, size_t scratch_size, size_t num_weights)
+        : BaseModelForCausalLM(config, mem_size, scratch_size, num_weights), transformer(&ctx_, config),
           lm_head(&ctx_, config.hidden_size, config.vocab_size, false) {
         CHATGLM_CHECK(ggml_used_mem(ctx_.ctx_w.get()) == ggml_get_mem_size(ctx_.ctx_w.get()))
             << "corrupted model weights";
@@ -983,6 +981,40 @@ class ChatGLM2ForCausalLM : public BasicModelForCausalLM<ChatGLM2Model> {
     static constexpr size_t SCRATCH_SIZE = 1280 * MB; // 2k context
 };
 
+// ===== ChatGLM3-6B =====
+
+class ChatGLM3Tokenizer : public BaseTokenizer {
+  public:
+    ChatGLM3Tokenizer(std::string_view serialized_model_proto);
+
+    std::vector<int> encode(const std::string &text, int max_length) const override;
+
+    std::string decode(const std::vector<int> &ids) const override;
+
+    std::vector<int> encode_history(const std::vector<std::string> &history, int max_length) const override;
+
+    bool is_special_id(int id) const;
+
+  protected:
+    static void truncate(std::vector<int> &ids, int max_length);
+
+  public:
+    sentencepiece::SentencePieceProcessor sp;
+    int mask_token_id;
+    int gmask_token_id;
+    int smask_token_id;
+    int sop_token_id;
+    int eop_token_id;
+    int system_token_id;
+    int user_token_id;
+    int assistant_token_id;
+    int observation_token_id;
+};
+
+using ChatGLM3Model = ChatGLM2Model;
+
+using ChatGLM3ForCausalLM = ChatGLM2ForCausalLM;
+
 // ===== Baichuan =====
 
 class BaichuanTokenizer : public BaseTokenizer {
diff --git a/chatglm_cpp/__init__.py b/chatglm_cpp/__init__.py
index f0f7696..a1dc183 100644
--- a/chatglm_cpp/__init__.py
+++ b/chatglm_cpp/__init__.py
@@ -5,7 +5,7 @@
 
 import chatglm_cpp._C as _C
 
-__version__ = "0.2.9"
+__version__ = "0.2.10"
 
 
 class Pipeline(_C.Pipeline):
diff --git a/chatglm_cpp/convert.py b/chatglm_cpp/convert.py
index d06d3e9..dd5bea2 100644
--- a/chatglm_cpp/convert.py
+++ b/chatglm_cpp/convert.py
@@ -41,6 +41,7 @@ class GGMLType(Enum):
 class ModelType(Enum):
     CHATGLM = 1
     CHATGLM2 = 2
+    CHATGLM3 = 3
     BAICHUAN7B = 1024
     BAICHUAN13B = 1025
     INTERNLM = 1280
@@ -324,6 +325,10 @@ def dump_model(f, model, ggml_type):
     dump_state_dict(f, weight_names, model.state_dict(), model.config.quantization_bit, ggml_type)
 
 
+class ChatGLM3Converter(ChatGLM2Converter):
+    MODEL_TYPE = ModelType.CHATGLM3
+
+
 class BaichuanConverter(BaseConverter):
     @staticmethod
     def dump_config(f, config, ggml_type):
@@ -481,7 +486,10 @@ def convert(f: BinaryIO, model_name_or_path: str, lora_model_name_or_path: Optio
 
     if model.config.model_type == "chatglm":
         if hasattr(model.config, "multi_query_attention"):
-            ChatGLM2Converter.convert(f, model, tokenizer, ggml_type)
+            if model.config.seq_length == 32768:
+                ChatGLM2Converter.convert(f, model, tokenizer, ggml_type)
+            else:
+                ChatGLM3Converter.convert(f, model, tokenizer, ggml_type)
         else:
             ChatGLMConverter.convert(f, model, tokenizer, ggml_type)
     elif model.config.model_type == "baichuan":
diff --git a/chatglm_pybind.cpp b/chatglm_pybind.cpp
index 8d56ccd..2fcd7c7 100644
--- a/chatglm_pybind.cpp
+++ b/chatglm_pybind.cpp
@@ -36,6 +36,7 @@ PYBIND11_MODULE(_C, m) {
     m.doc() = "ChatGLM.cpp python binding";
 
     py::class_<ModelConfig>(m, "ModelConfig")
+        .def_readonly("model_type", &ModelConfig::model_type)
         .def_readonly("dtype", &ModelConfig::dtype)
         .def_readonly("vocab_size", &ModelConfig::vocab_size)
        .def_readonly("hidden_size", &ModelConfig::hidden_size)
@@ -48,7 +49,8 @@ PYBIND11_MODULE(_C, m) {
         .def_readonly("bos_token_id", &ModelConfig::bos_token_id)
         .def_readonly("eos_token_id", &ModelConfig::eos_token_id)
         .def_readonly("pad_token_id", &ModelConfig::pad_token_id)
-        .def_readonly("sep_token_id", &ModelConfig::sep_token_id);
+        .def_readonly("sep_token_id", &ModelConfig::sep_token_id)
+        .def_property_readonly("model_type_name", &ModelConfig::model_type_name);
 
     py::class_<BaseTokenizer>(m, "BaseTokenizer")
         .def("encode", &BaseTokenizer::encode)
@@ -56,7 +58,6 @@ PYBIND11_MODULE(_C, m) {
         .def("encode_history", &BaseTokenizer::encode_history);
 
     py::class_<BaseModelForCausalLM>(m, "BaseModelForCausalLM")
-        .def_property_readonly("type_name", &BaseModelForCausalLM::type_name)
         .def("generate_next_token", &BaseModelForCausalLM::generate_next_token)
         .def_readonly("config", &BaseModelForCausalLM::config);
 
@@ -85,6 +86,10 @@ PYBIND11_MODULE(_C, m) {
 
     py::class_<ChatGLM2ForCausalLM, BaseModelForCausalLM>(m, "ChatGLM2ForCausalLM");
 
+    // ===== ChatGLM3 =====
+
+    py::class_<ChatGLM3Tokenizer, BaseTokenizer>(m, "ChatGLM3Tokenizer");
+
     // ===== Baichuan7B/13B =====
 
     py::class_<BaichuanTokenizer, BaseTokenizer>(m, "BaichuanTokenizer");
diff --git a/chatglm_test.cpp b/chatglm_test.cpp
index cc54139..6df3028 100644
--- a/chatglm_test.cpp
+++ b/chatglm_test.cpp
@@ -751,6 +751,40 @@ TEST_F(ChatGLMTest, GLM2Model) {
 //     }
 // }
 
+TEST_F(ChatGLMTest, GLM3Model) {
+    fs::path data_path = fs::path(__FILE__).parent_path() / "tests/data/glm3_model.data";
+
+    ModelConfig config;
+    config.vocab_size = 5;
+    config.hidden_size = 32;
+    config.num_attention_heads = 8;
+    config.num_kv_heads = 2;
+    config.num_hidden_layers = 1;
+    config.intermediate_size = 48;
+    config.norm_eps = 1e-5;
+    config.max_length = 8;
+
+    constexpr int seq_len = 3;
+
+    ChatGLM3Model model(&ctx, config);
+
+    tensor_to_device(model.layers[0].attention.k_cache);
+    tensor_to_device(model.layers[0].attention.v_cache);
+
+    std::vector<ggml_tensor *> all_weights{model.word_embeddings.weight,
+                                           model.layers[0].input_layernorm.weight,
+                                           model.layers[0].attention.query_key_value.weight,
+                                           model.layers[0].attention.query_key_value.bias,
+                                           model.layers[0].attention.dense.weight,
+                                           model.layers[0].post_attention_layernorm.weight,
+                                           model.layers[0].mlp.gate_proj.weight,
+                                           model.layers[0].mlp.up_proj.weight,
+                                           model.layers[0].mlp.down_proj.weight,
+                                           model.final_layernorm.weight};
+
+    test_model(model, config, data_path, seq_len, all_weights);
+}
+
 TEST_F(ChatGLMTest, Baichuan7BModel) {
     fs::path data_path = fs::path(__FILE__).parent_path() / "tests/data/baichuan7b_model.data";
 
@@ -1082,6 +1116,64 @@ TEST(Pipeline, ChatGLM2) {
     }
 }
 
+TEST(Pipeline, ChatGLM3) {
+    fs::path model_path = fs::path(__FILE__).parent_path() / "chatglm3-ggml.bin";
+    if (!fs::exists(model_path)) {
+        GTEST_SKIP() << "Skipping ChatGLM3 e2e test (ggml model not found)";
+    }
+    Pipeline pipeline(model_path.string());
+    EXPECT_TRUE(dynamic_cast<ChatGLM3ForCausalLM *>(pipeline.model.get()));
+
+    // tokenizer
+    {
+        std::vector<TokenizerTestCase> cases{{"你好", {64790, 64792, 36474, 54591}}};
+        check_tokenizer(pipeline.tokenizer.get(), cases);
+
+        {
+            std::vector<std::string> history{"你好"};
+            std::vector<int> input_ids = pipeline.tokenizer->encode_history(history, 2048);
+            std::vector<int> target_ids{64790, 64792, 64795, 30910, 13, 36474, 54591, 64796, 30910, 13};
+            EXPECT_EQ(input_ids, target_ids);
+        }
+        {
+            std::vector<std::string> history{"你好",
+                                             "你好👋!我是人工智能助手 ChatGLM3-6B,很高兴见到你,欢迎问我任何问题。",
+                                             "晚上睡不着应该怎么办"};
+            std::vector<int> input_ids = pipeline.tokenizer->encode_history(history, 2048);
+            std::vector<int> target_ids{64790, 64792, 64795, 30910, 13, 36474, 54591, 64796, 30910, 13,
+                                        36474, 54591, 243, 162, 148, 142, 31404, 33030, 34797, 42481,
+                                        22011, 10461, 30944, 30966, 30941, 30978, 30949, 31123, 48895, 35214,
+                                        54622, 31123, 32616, 39905, 31901, 31639, 31155, 64795, 30910, 13,
+                                        30910, 32820, 54266, 31876, 35153, 64796, 30910, 13};
+            EXPECT_EQ(input_ids, target_ids);
+        }
+    }
+
+    // memory test
+    {
+        GenerationConfig gen_config;
+        gen_config.max_length = 2048;
+        gen_config.max_context_length = gen_config.max_length - 1;
+        gen_config.do_sample = false;
+
+        std::ostringstream oss;
+        for (int i = 0; i < gen_config.max_context_length; i++) {
+            oss << "你好";
+        }
+        std::vector<std::string> history{oss.str()};
+        pipeline.chat(history, gen_config);
+    }
+
+    // chat
+    {
+        GenerationConfig gen_config;
+        gen_config.do_sample = false;
+        std::vector<std::string> history{"你好"};
+        std::string output = pipeline.chat(history, gen_config);
+        EXPECT_EQ(output, "你好👋!我是人工智能助手 ChatGLM3-6B,很高兴见到你,欢迎问我任何问题。");
+    }
+}
+
 TEST(Pipeline, CodeGeeX2) {
     fs::path model_path = fs::path(__FILE__).parent_path() / "codegeex2-ggml.bin";
     if (!fs::exists(model_path)) {
diff --git a/examples/cli_chat.py b/examples/cli_chat.py
index 0fcf8ee..c5cda13 100644
--- a/examples/cli_chat.py
+++ b/examples/cli_chat.py
@@ -70,7 +70,7 @@ def main():
     history = []
     while True:
         try:
-            prompt = input(f"{'Prompt':{len(pipeline.model.type_name)}} > ")
+            prompt = input(f"{'Prompt':{len(pipeline.model.config.model_type_name)}} > ")
         except EOFError:
             break
         if not prompt:
@@ -81,7 +81,7 @@ def main():
             history = []
             continue
         history.append(prompt)
-        print(f"{pipeline.model.type_name} > ", sep="", end="")
+        print(f"{pipeline.model.config.model_type_name} > ", sep="", end="")
         output = ""
         for piece in pipeline.chat(history, **generation_kwargs):
             print(piece, sep="", end="", flush=True)
diff --git a/main.cpp b/main.cpp
index 2e9a1d9..cf305a7 100644
--- a/main.cpp
+++ b/main.cpp
@@ -139,7 +139,7 @@ static void chat(Args &args) {
     chatglm::Pipeline pipeline(args.model_path);
     int64_t end_load_us = ggml_time_us();
 
-    std::string model_name = pipeline.model->type_name();
+    std::string model_name = pipeline.model->config.model_type_name();
 
     auto text_streamer = std::make_shared<chatglm::TextStreamer>(std::cout, pipeline.tokenizer.get());
     auto perf_streamer = std::make_shared<chatglm::PerfStreamer>();
@@ -174,7 +174,7 @@ static void chat(Args &args) {
               << "temperature = " << args.temp << " | "
               << "num_threads = " << args.num_threads << " |\n";
 
-    std::cout << "loaded " << pipeline.model->type_name() << " model from " << args.model_path
+    std::cout << "loaded " << pipeline.model->config.model_type_name() << " model from " << args.model_path
               << " within: " << (end_load_us - start_load_us) / 1000.f << " ms\n";
 
     std::cout << std::endl;
diff --git a/tests/data/glm3_model.data b/tests/data/glm3_model.data
new file mode 100644
index 0000000..d5dc2c6
Binary files /dev/null and b/tests/data/glm3_model.data differ
diff --git a/tests/test_chatglm_cpp.py b/tests/test_chatglm_cpp.py
index 8f63f1e..29b2829 100644
--- a/tests/test_chatglm_cpp.py
+++ b/tests/test_chatglm_cpp.py
@@ -7,6 +7,7 @@
 
 CHATGLM_MODEL_PATH = PROJECT_ROOT / "chatglm-ggml.bin"
 CHATGLM2_MODEL_PATH = PROJECT_ROOT / "chatglm2-ggml.bin"
+CHATGLM3_MODEL_PATH = PROJECT_ROOT / "chatglm3-ggml.bin"
 CODEGEEX2_MODEL_PATH = PROJECT_ROOT / "codegeex2-ggml.bin"
 BAICHUAN13B_MODEL_PATH = PROJECT_ROOT / "baichuan-13b-chat-ggml.bin"
 BAICHUAN2_7B_MODEL_PATH = PROJECT_ROOT / "baichuan2-7b-chat-ggml.bin"
@@ -55,6 +56,24 @@ def test_chatglm2_pipeline():
     assert stream_output == target
 
 
+@pytest.mark.skipif(not CHATGLM3_MODEL_PATH.exists(), reason="model file not found")
+def test_chatglm3_pipeline():
+    history = ["你好"]
+    target = "你好👋!我是人工智能助手 ChatGLM3-6B,很高兴见到你,欢迎问我任何问题。"
+
+    pipeline = chatglm_cpp.Pipeline(CHATGLM3_MODEL_PATH)
+    output = pipeline.chat(history, do_sample=False)
+    assert output == target
+
+    stream_output = pipeline.stream_chat(history, do_sample=False)
+    stream_output = "".join(stream_output)
+    assert stream_output == target
+
+    stream_output = pipeline.chat(history, do_sample=False, stream=True)
+    stream_output = "".join(stream_output)
+    assert stream_output == target
+
+
 @pytest.mark.skipif(not CODEGEEX2_MODEL_PATH.exists(), reason="model file not found")
 def test_codegeex2_pipeline():
     prompt = "# language: Python\n# write a bubble sort function\n"
diff --git a/tests/test_convert.py b/tests/test_convert.py
index ff4ff1d..816b163 100644
--- a/tests/test_convert.py
+++ b/tests/test_convert.py
@@ -384,6 +384,77 @@ def make_data_glm2_model():
     y3.numpy().tofile(f)
 
 
+def make_data_glm3_model():
+    CHATGLM3_MODEL_PATH = Path("./chatglm3-6b").expanduser()
+
+    sys.path.append(str(CHATGLM3_MODEL_PATH))
+    from modeling_chatglm import ChatGLMModel
+    from transformers import AutoConfig
+
+    config = AutoConfig.from_pretrained(CHATGLM3_MODEL_PATH, trust_remote_code=True)
+    config.hidden_size = 32
+    config.num_attention_heads = 8
+    config.num_layers = 1
+    config.padded_vocab_size = 5
+    config.multi_query_group_num = 2
+    config.ffn_hidden_size = 48
+    config.kv_channels = config.hidden_size // config.num_attention_heads
+    config.torch_dtype = torch.float32
+
+    m = ChatGLMModel(config).float().eval()
+    for param in m.parameters():
+        param.data.uniform_(-0.5, 0.5)
+
+    seq_len = 3
+
+    # self attention
+    x1 = torch.arange(seq_len, dtype=torch.int64)[None, :]
+    position_ids = torch.arange(seq_len, dtype=torch.int64)[None, :]
+    attn_mask = torch.ones(1, seq_len, dtype=torch.int64)
+    with torch.no_grad():
+        out = m(x1, position_ids=position_ids, attention_mask=attn_mask, use_cache=True)
+        y1 = out.last_hidden_state
+        kv_cache = out.past_key_values
+
+    # cross attention
+    x2 = torch.tensor([[seq_len]], dtype=torch.int64)
+    position_ids = torch.tensor([[seq_len]], dtype=torch.int64)
+    attn_mask = torch.ones(1, seq_len + 1, dtype=torch.int64)
+    with torch.no_grad():
+        out = m(x2, position_ids=position_ids, attention_mask=attn_mask, past_key_values=kv_cache, use_cache=True)
+        y2 = out.last_hidden_state
+        kv_cache = out.past_key_values
+
+    # cross attention
+    x3 = torch.tensor([[seq_len + 1]], dtype=torch.int64)
+    position_ids = torch.tensor([[seq_len + 1]], dtype=torch.int64)
+    attn_mask = torch.ones(1, seq_len + 2, dtype=torch.int64)
+    with torch.no_grad():
+        out = m(x3, position_ids=position_ids, attention_mask=attn_mask, past_key_values=kv_cache, use_cache=True)
+        y3 = out.last_hidden_state
+        kv_cache = out.past_key_values
+
+    print(m)
+
+    with open(HERE / "data/glm3_model.data", "wb") as f:
+        m.embedding.word_embeddings.weight.data.numpy().tofile(f)
+        m.encoder.layers[0].input_layernorm.weight.data.numpy().tofile(f)
+        m.encoder.layers[0].self_attention.query_key_value.weight.data.numpy().tofile(f)
+        m.encoder.layers[0].self_attention.query_key_value.bias.data.numpy().tofile(f)
+        m.encoder.layers[0].self_attention.dense.weight.data.numpy().tofile(f)
+        m.encoder.layers[0].post_attention_layernorm.weight.data.numpy().tofile(f)
+        m.encoder.layers[0].mlp.dense_h_to_4h.weight.data.numpy().tofile(f)
+        m.encoder.layers[0].mlp.dense_4h_to_h.weight.data.numpy().tofile(f)
+        m.encoder.final_layernorm.weight.data.numpy().tofile(f)
+
+        x1.int().numpy().tofile(f)
+        y1.numpy().tofile(f)
+        x2.int().numpy().tofile(f)
+        y2.numpy().tofile(f)
+        x3.int().numpy().tofile(f)
+        y3.numpy().tofile(f)
+
+
 def _make_data_baichuan_model(model_path, out_name):
     sys.path.append(str(model_path))
     from modeling_baichuan import BaichuanModel
@@ -549,9 +620,10 @@ def main():
     # make_data_rms_norm()
     # make_data_glm_model()
     # make_data_glm2_model()
+    make_data_glm3_model()
     # make_data_baichuan7b_model()
     # make_data_baichuan13b_model()
-    make_internlm_model()
+    # make_internlm_model()
 
 
 if __name__ == "__main__":
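
A quick way to exercise the new ChatGLM3 path from Python once the weights have been converted. This is a minimal sketch, not part of the patch: it assumes a `chatglm3-ggml.bin` file converted as shown in the README sits in the working directory, and the calls mirror `tests/test_chatglm_cpp.py` and `examples/cli_chat.py`.

```python
# Minimal sanity check for the ChatGLM3 pipeline (assumes ./chatglm3-ggml.bin exists locally).
import chatglm_cpp

pipeline = chatglm_cpp.Pipeline("./chatglm3-ggml.bin")

# The new ModelConfig.model_type_name property should report "ChatGLM3".
print(pipeline.model.config.model_type_name)

history = ["你好"]

# Greedy decoding, as in tests/test_chatglm_cpp.py::test_chatglm3_pipeline.
print(pipeline.chat(history, do_sample=False))

# Streaming variant.
for piece in pipeline.stream_chat(history, do_sample=False):
    print(piece, end="", flush=True)
print()
```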
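The prompt format built by `ChatGLM3Tokenizer::encode_history` and the sliding-window behaviour of `truncate` can also be paraphrased in Python for reference. This is an illustrative sketch only, not the real tokenizer: `encode` stands in for the sentencepiece model, and the special-token ids are the ones asserted in `chatglm_test.cpp`.

```python
# Illustrative paraphrase of ChatGLM3Tokenizer::encode_history / truncate (not the real implementation).
GMASK, SOP, USER, ASSISTANT = 64790, 64792, 64795, 64796  # ids taken from the chatglm_test.cpp expectations


def encode_history(history, max_length, encode):
    """`history` alternates user/assistant turns; `encode` maps text to sentencepiece ids."""
    newline_ids = encode("\n")
    input_ids = [GMASK, SOP]  # two-token special prefix, always preserved
    for i, content in enumerate(history):
        input_ids.append(USER if i % 2 == 0 else ASSISTANT)  # role token per turn
        input_ids += newline_ids
        input_ids += encode(content)
    input_ids.append(ASSISTANT)  # the model continues as the assistant role
    input_ids += newline_ids  # emit '\n' here so the model does not have to generate it
    # sliding window: drop the oldest tokens while keeping the two-token prefix
    if len(input_ids) > max_length:
        num_drop = len(input_ids) - max_length
        del input_ids[2:2 + num_drop]
    return input_ids
```

With the real sentencepiece model plugged in as `encode`, a single-turn history of `["你好"]` should reproduce the `{64790, 64792, 64795, 30910, 13, 36474, 54591, 64796, 30910, 13}` sequence asserted in the new pipeline test.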