From e5452ddfd6e9a08d5e15bd81a010934550b9b507 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 30 Nov 2023 20:03:58 -0800 Subject: [PATCH] Normalize head weights for Baichuan 2 (#1876) --- README.md | 2 +- docs/source/models/supported_models.rst | 2 +- vllm/model_executor/models/baichuan.py | 11 +++++++++++ 3 files changed, 13 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index ac3c79ec9e6a0..9cc325e924f77 100644 --- a/README.md +++ b/README.md @@ -47,7 +47,7 @@ vLLM is flexible and easy to use with: vLLM seamlessly supports many Hugging Face models, including the following architectures: - Aquila & Aquila2 (`BAAI/AquilaChat2-7B`, `BAAI/AquilaChat2-34B`, `BAAI/Aquila-7B`, `BAAI/AquilaChat-7B`, etc.) -- Baichuan (`baichuan-inc/Baichuan-7B`, `baichuan-inc/Baichuan-13B-Chat`, etc.) +- Baichuan & Baichuan2 (`baichuan-inc/Baichuan2-13B-Chat`, `baichuan-inc/Baichuan-7B`, etc.) - BLOOM (`bigscience/bloom`, `bigscience/bloomz`, etc.) - ChatGLM (`THUDM/chatglm2-6b`, `THUDM/chatglm3-6b`, etc.) - Falcon (`tiiuae/falcon-7b`, `tiiuae/falcon-40b`, `tiiuae/falcon-rw-7b`, etc.) diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index bebec8f9bfc6c..f56d6eaccfddc 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -19,7 +19,7 @@ Alongside each architecture, we include some popular models that use it. - :code:`BAAI/Aquila-7B`, :code:`BAAI/AquilaChat-7B`, etc. * - :code:`BaiChuanForCausalLM` - Baichuan - - :code:`baichuan-inc/Baichuan-7B`, :code:`baichuan-inc/Baichuan-13B-Chat`, etc. + - :code:`baichuan-inc/Baichuan2-13B-Chat`, :code:`baichuan-inc/Baichuan-7B`, etc. * - :code:`ChatGLMModel` - ChatGLM - :code:`THUDM/chatglm2-6b`, :code:`THUDM/chatglm3-6b`, etc. diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index d4a32e8e21a6d..3b56b9e137021 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -341,6 +341,17 @@ def load_weights(self, model_name_or_path, cache_dir, load_format, revision): if "rotary_emb.inv_freq" in name: continue + if name == "lm_head.weight": + # Unlike Baichuan, Baichuan2 normalizes the head weights. Refer to: + # https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat/blob/84603cde5ebffb6084e476cfaeceaf0b8b91fe54/modeling_baichuan.py#L508 + # Distinguish between Baichuan and Baichuan2 by checking the + # vocab size. This is suggested by + # https://github.com/vllm-project/vllm/pull/1022#discussion_r1325652704 + is_baichuan2 = self.config.vocab_size == 125696 + if is_baichuan2: + loaded_weight = torch.nn.functional.normalize( + loaded_weight) + for (param_name, weight_name, shard_id) in stacked_params_mapping: if weight_name not in name: continue