From 3383c6eb1eabbdac8d1a9972e489e1e06209e52e Mon Sep 17 00:00:00 2001 From: Willow Date: Fri, 8 Nov 2024 03:13:40 +0000 Subject: [PATCH] keep LlavaLlamaForCausalLM/LlavaMistralForCausalLM to llama --- lmdeploy/turbomind/deploy/source_model/llava.py | 12 ++++++------ lmdeploy/turbomind/supported_models.py | 4 ++-- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/lmdeploy/turbomind/deploy/source_model/llava.py b/lmdeploy/turbomind/deploy/source_model/llava.py index 0902468a77..7d0a1ff058 100644 --- a/lmdeploy/turbomind/deploy/source_model/llava.py +++ b/lmdeploy/turbomind/deploy/source_model/llava.py @@ -30,18 +30,18 @@ def __init__(self, model_path: str, tokenizer_path: str, **kwargs): super().__init__(model_path, tokenizer_path, **kwargs) from transformers import AutoConfig config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) + config = getattr(config, 'text_config', config) arch = config.architectures[0] - _readers = dict(LlavaForConditionalGeneration=LlavaReader, - LlavaMistralForCausalLM=LlamaReader, - LlavaLlamaForCausalLM=LlamaReader) + _readers = dict(Qwen2ForCausalLM=LlavaReader, + LlamaForCausalL=LlavaReader) self.Reader = _readers[arch] self.arch = arch def model_info(self): - if self.arch in ['LlavaMistralForCausalLM', 'LlavaLlamaForCausalLM']: - return super().model_info() """Read model info for LlavaForConditionalGeneration. - https://huggingface.co/llava-hf/llava-interleave-qwen-7b-hf""" + + https://huggingface.co/llava-hf/llava-interleave-qwen-7b-hf + """ params_path = osp.join(self.model_path, 'config.json') with open(params_path) as f: model_arg = json.load(f)['text_config'] diff --git a/lmdeploy/turbomind/supported_models.py b/lmdeploy/turbomind/supported_models.py index 88ca22717d..402f340cc3 100644 --- a/lmdeploy/turbomind/supported_models.py +++ b/lmdeploy/turbomind/supported_models.py @@ -23,8 +23,8 @@ # mistral MistralForCausalLM='llama', # llava - LlavaLlamaForCausalLM='llava', - LlavaMistralForCausalLM='llava', + LlavaLlamaForCausalLM='llama', + LlavaMistralForCausalLM='llama', LlavaForConditionalGeneration='llava', # xcomposer2 InternLMXComposer2ForCausalLM='xcomposer2',