keep LlavaLlamaForCausalLM/LlavaMistralForCausalLM to llama
deepindeed2022 committed Nov 8, 2024
1 parent e5005f5 commit 3383c6e
Showing 2 changed files with 8 additions and 8 deletions.
12 changes: 6 additions & 6 deletions lmdeploy/turbomind/deploy/source_model/llava.py
@@ -30,18 +30,18 @@ def __init__(self, model_path: str, tokenizer_path: str, **kwargs):
         super().__init__(model_path, tokenizer_path, **kwargs)
         from transformers import AutoConfig
         config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
+        config = getattr(config, 'text_config', config)
         arch = config.architectures[0]
-        _readers = dict(LlavaForConditionalGeneration=LlavaReader,
-                        LlavaMistralForCausalLM=LlamaReader,
-                        LlavaLlamaForCausalLM=LlamaReader)
+        _readers = dict(Qwen2ForCausalLM=LlavaReader,
+                        LlamaForCausalLM=LlavaReader)
         self.Reader = _readers[arch]
         self.arch = arch

     def model_info(self):
-        if self.arch in ['LlavaMistralForCausalLM', 'LlavaLlamaForCausalLM']:
-            return super().model_info()
         """Read model info for LlavaForConditionalGeneration.
-        https://huggingface.co/llava-hf/llava-interleave-qwen-7b-hf"""
+
+        https://huggingface.co/llava-hf/llava-interleave-qwen-7b-hf
+        """
         params_path = osp.join(self.model_path, 'config.json')
         with open(params_path) as f:
             model_arg = json.load(f)['text_config']
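For context: after this change, the reader lookup keys on the language-model architecture found inside text_config rather than on the outer multimodal class. A minimal sketch of what that resolution yields for an llava-hf checkpoint (the checkpoint name matches the URL cited in the docstring above; the values in the comments are what the mapping in this diff expects, not verified output):

    from transformers import AutoConfig

    # Checkpoint cited in the docstring above; used here for illustration.
    model_path = 'llava-hf/llava-interleave-qwen-7b-hf'

    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    # Outer, multimodal architecture name.
    print(config.architectures[0])  # expected: 'LlavaForConditionalGeneration'

    # llava-hf configs nest the language model under `text_config`;
    # getattr falls back to the top-level config for plain LLM checkpoints.
    text_config = getattr(config, 'text_config', config)
    print(text_config.architectures[0])  # expected: 'Qwen2ForCausalLM'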
4 changes: 2 additions & 2 deletions lmdeploy/turbomind/supported_models.py
@@ -23,8 +23,8 @@
     # mistral
     MistralForCausalLM='llama',
     # llava
-    LlavaLlamaForCausalLM='llava',
-    LlavaMistralForCausalLM='llava',
+    LlavaLlamaForCausalLM='llama',
+    LlavaMistralForCausalLM='llama',
     LlavaForConditionalGeneration='llava',
     # xcomposer2
     InternLMXComposer2ForCausalLM='xcomposer2',
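The second file's change routes the original (non-HF) LLaVA classes through the plain llama converter instead of the llava one. A minimal sketch of how such an architecture-to-family table can be consulted (the excerpt is trimmed from the diff above; the lookup helper is hypothetical, not lmdeploy's API):

    # Trimmed excerpt of the mapping as it reads after this commit.
    SUPPORTED_ARCHS = dict(
        # mistral
        MistralForCausalLM='llama',
        # llava
        LlavaLlamaForCausalLM='llama',      # previously 'llava'
        LlavaMistralForCausalLM='llama',    # previously 'llava'
        LlavaForConditionalGeneration='llava',
    )


    def model_family(arch: str) -> str:
        """Hypothetical helper: map an HF architecture name to a TurboMind
        source-model family, failing loudly for unknown architectures."""
        if arch not in SUPPORTED_ARCHS:
            raise KeyError(f'unsupported architecture: {arch}')
        return SUPPORTED_ARCHS[arch]


    assert model_family('LlavaLlamaForCausalLM') == 'llama'
    assert model_family('LlavaForConditionalGeneration') == 'llava'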
