From 4c55c8dc4d750ee7026b45ab6ff563e5138249e0 Mon Sep 17 00:00:00 2001 From: Willow Date: Fri, 8 Nov 2024 06:49:57 +0000 Subject: [PATCH] fix attn_bias default value --- lmdeploy/turbomind/deploy/source_model/llava.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/lmdeploy/turbomind/deploy/source_model/llava.py b/lmdeploy/turbomind/deploy/source_model/llava.py index 7d0a1ff058..3b4d82c37b 100644 --- a/lmdeploy/turbomind/deploy/source_model/llava.py +++ b/lmdeploy/turbomind/deploy/source_model/llava.py @@ -33,7 +33,7 @@ def __init__(self, model_path: str, tokenizer_path: str, **kwargs): config = getattr(config, 'text_config', config) arch = config.architectures[0] _readers = dict(Qwen2ForCausalLM=LlavaReader, - LlamaForCausalL=LlavaReader) + LlamaForCausalLM=LlavaReader) self.Reader = _readers[arch] self.arch = arch @@ -63,7 +63,9 @@ def model_info(self): hidden_units = model_arg.get('hidden_size', 4096) vocab_size = model_arg.get('vocab_size', 152000) intermediate_size = model_arg.get('intermediate_size', 11008) - attn_bias = int(model_arg.get('attn_bias', 1)) + attn_bias = 1 if model_arg['architectures'][0] \ + == 'Qwen2ForCausalLM' else 0 + attn_bias = int(model_arg.get('attn_bias', attn_bias)) use_logn_attn = int(model_arg.get('use_logn_attn', 0)) if isinstance(rope_scaling, dict):