diff --git a/realhf/api/from_hf/qwen2.py b/realhf/api/from_hf/qwen2.py index ad4f5d40..7e0243e0 100644 --- a/realhf/api/from_hf/qwen2.py +++ b/realhf/api/from_hf/qwen2.py @@ -52,6 +52,7 @@ def convert_config_qwen2( apply_rotary=True, rotary_base=hf_config.rope_theta, rotary_interleaved=False, + tied_embedding=hf_config.tie_word_embeddings, ) @@ -70,6 +71,7 @@ def convert_config_back_qwen2( hidden_act=config.activation_function, attention_dropout=config.attn_pdrop, rope_theta=config.rotary_base, + tie_word_embeddings=config.tied_embedding, architectures=["Qwen2ForCausalLM"], )