diff --git a/candle-transformers/src/models/qwen2.rs b/candle-transformers/src/models/qwen2.rs index 3dce5c6a6a..187ea98a10 100644 --- a/candle-transformers/src/models/qwen2.rs +++ b/candle-transformers/src/models/qwen2.rs @@ -361,7 +361,7 @@ pub struct ModelForCausalLM { impl ModelForCausalLM { pub fn new(cfg: &Config, vb: VarBuilder) -> Result { let base_model = Model::new(cfg, vb.clone())?; - let lm_head = if vb.contains_tensor("lm_head") { + let lm_head = if vb.contains_tensor("lm_head.weight") { linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))? } else { Linear::from_weights(base_model.embed_tokens.embeddings().clone(), None)