keep LlavaLlamaForCausalLM/LlavaMistralForCausalLM to llama
deepindeed2022 committed Nov 8, 2024
1 parent e5005f5 commit 48d1a5c
Showing 2 changed files with 8 additions and 9 deletions.
12 changes: 6 additions & 6 deletions lmdeploy/turbomind/deploy/source_model/llava.py
@@ -30,18 +30,18 @@ def __init__(self, model_path: str, tokenizer_path: str, **kwargs):
         super().__init__(model_path, tokenizer_path, **kwargs)
         from transformers import AutoConfig
         config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
+        config = getattr(config, 'text_config', config)
         arch = config.architectures[0]
-        _readers = dict(LlavaForConditionalGeneration=LlavaReader,
-                        LlavaMistralForCausalLM=LlamaReader,
-                        LlavaLlamaForCausalLM=LlamaReader)
+        _readers = dict(Qwen2ForCausalLM=LlavaReader,
+                        LlamaForCausalLM=LlavaReader)
         self.Reader = _readers[arch]
         self.arch = arch
 
     def model_info(self):
-        if self.arch in ['LlavaMistralForCausalLM', 'LlavaLlamaForCausalLM']:
-            return super().model_info()
         """Read model info for LlavaForConditionalGeneration.
-        https://huggingface.co/llava-hf/llava-interleave-qwen-7b-hf"""
+
+        https://huggingface.co/llava-hf/llava-interleave-qwen-7b-hf
+        """
         params_path = osp.join(self.model_path, 'config.json')
         with open(params_path) as f:
             model_arg = json.load(f)['text_config']
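For reference, a minimal sketch of how the new dispatch resolves an HF-style checkpoint such as llava-hf/llava-interleave-qwen-7b-hf. Only AutoConfig and the text_config/architectures fields come from the diff above; the helper function below is illustrative and not part of lmdeploy.

from transformers import AutoConfig

def resolve_text_arch(model_path: str) -> str:
    # Illustrative helper: returns the language-model architecture that the
    # `_readers` dict above keys on.
    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    # LlavaForConditionalGeneration configs nest the LLM settings under
    # `text_config`; plain causal-LM configs do not, so fall back to the
    # top-level config.
    config = getattr(config, 'text_config', config)
    return config.architectures[0]

# e.g. resolve_text_arch('llava-hf/llava-interleave-qwen-7b-hf') is expected to
# return 'Qwen2ForCausalLM', which selects LlavaReader in `_readers` above.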
5 changes: 2 additions & 3 deletions lmdeploy/turbomind/supported_models.py
@@ -23,8 +23,8 @@
     # mistral
     MistralForCausalLM='llama',
     # llava
-    LlavaLlamaForCausalLM='llava',
-    LlavaMistralForCausalLM='llava',
+    LlavaLlamaForCausalLM='llama',
+    LlavaMistralForCausalLM='llama',
     LlavaForConditionalGeneration='llava',
     # xcomposer2
     InternLMXComposer2ForCausalLM='xcomposer2',
@@ -95,7 +95,6 @@ def _is_head_dim_supported(cfg):
                     # glm-4v-9b not supported
                     support_by_turbomind = False
             elif arch == 'InternVLChatModel':
-                # internvl2-4b,internlm2-1b are not working yet
                 support_by_turbomind = _is_head_dim_supported(cfg.llm_config)
             elif arch == 'LlavaForConditionalGeneration':
                 sub_arch = cfg.text_config.architectures[0]
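The net effect in supported_models.py is that llava checkpoints exposing LlavaLlamaForCausalLM or LlavaMistralForCausalLM are now routed to the 'llama' source model, while only HF-style LlavaForConditionalGeneration stays on 'llava'. Below is a rough sketch of the resulting lookup, assuming a SUPPORTED_ARCHS dict shaped like the excerpt above; the helper function is illustrative, not the actual is_supported implementation.

# Illustrative only: mirrors the mapping excerpted in the diff above.
SUPPORTED_ARCHS = dict(
    MistralForCausalLM='llama',
    LlavaLlamaForCausalLM='llama',          # was 'llava' before this commit
    LlavaMistralForCausalLM='llama',        # was 'llava' before this commit
    LlavaForConditionalGeneration='llava',
)

def source_model_name(arch: str) -> str:
    # Return the turbomind source-model name registered for an architecture.
    return SUPPORTED_ARCHS[arch]

assert source_model_name('LlavaLlamaForCausalLM') == 'llama'
assert source_model_name('LlavaForConditionalGeneration') == 'llava'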
