Commit 07013a4: support yi

AllentDan committed May 8, 2024
1 parent: a67ddfe

Showing 3 changed files with 15 additions and 5 deletions.
5 changes: 4 additions & 1 deletion lmdeploy/vl/model/builder.py
@@ -60,8 +60,11 @@ def vl_model_with_tokenizer(model_path: str, device: str):
     if arch == 'MultiModalityCausalLM':
         return DeepSeekVisionModel.model_with_tokenizer(model_path, device)
     if arch == 'LlavaLlamaForCausalLM':
+        projector_type = config.get('mm_projector_type', 'linear')
         mm_vision_tower = config.get('mm_vision_tower', '')
-        if 'OpenGVLab' in mm_vision_tower:
+        if '_Norm' in projector_type:
+            return YiVisionModel.model_with_tokenizer(model_path)
+        elif 'OpenGVLab' in mm_vision_tower:
             return InternVLLlavaVisionModel.model_with_tokenizer(model_path)
         else:
             return LlavaVisionModel.model_with_tokenizer(model_path, device)
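Note on the builder.py change: Yi-VL checkpoints report the same 'LlavaLlamaForCausalLM' architecture as plain LLaVA, so the dispatch now also inspects the config's mm_projector_type, which carries a '_Norm' suffix in Yi-VL style configs. A minimal sketch of that selection logic, assuming a plain dict in place of the real config object and made-up config values:

# Illustrative sketch only, not part of the commit.
def pick_vision_model(config: dict) -> str:
    projector_type = config.get('mm_projector_type', 'linear')
    mm_vision_tower = config.get('mm_vision_tower', '')
    if '_Norm' in projector_type:          # Yi-VL style projector
        return 'YiVisionModel'
    elif 'OpenGVLab' in mm_vision_tower:   # InternVL vision tower
        return 'InternVLLlavaVisionModel'
    return 'LlavaVisionModel'

# Example values are assumptions for illustration.
assert pick_vision_model({'mm_projector_type': 'mlp2x_gelu_Norm'}) == 'YiVisionModel'
assert pick_vision_model({'mm_vision_tower': 'OpenGVLab/InternViT-6B'}) == 'InternVLLlavaVisionModel'
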
6 changes: 2 additions & 4 deletions lmdeploy/vl/model/llava.py
@@ -96,12 +96,10 @@ def build_model(self):
         # load weight
         load_model_from_weight_files(model, self.model_path)
         model.to(self.device).eval()
-        if model.dtype != torch.float16:
-            model.half()
 
         self.model = model.model
-        self.vision_tower = model.model.vision_tower
-        self.mm_projector = model.model.mm_projector
+        self.vision_tower = model.model.vision_tower.half()
+        self.mm_projector = model.model.mm_projector.half()
 
     @staticmethod
     def model_with_tokenizer(model_path: str, device='cpu'):
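Note on the llava.py change: instead of checking model.dtype and halving the whole wrapped model, only the vision tower and projector are cast to fp16. A toy sketch of that pattern (not lmdeploy code; the submodule names are the only thing taken from the diff):

import torch
import torch.nn as nn

class ToyLlava(nn.Module):
    """Stand-in for the loaded LLaVA model; the real classes differ."""
    def __init__(self):
        super().__init__()
        self.vision_tower = nn.Linear(8, 8)   # stands in for the vision encoder
        self.mm_projector = nn.Linear(8, 8)   # stands in for the multimodal projector
        self.lm_head = nn.Linear(8, 8)        # left in its original dtype

model = ToyLlava().eval()
model.vision_tower = model.vision_tower.half()   # fp16 only where needed
model.mm_projector = model.mm_projector.half()

assert model.vision_tower.weight.dtype == torch.float16
assert model.lm_head.weight.dtype == torch.float32   # untouched
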
9 changes: 9 additions & 0 deletions lmdeploy/vl/model/yi.py
@@ -97,3 +97,12 @@ def build_model(self):
 
         with init_yi_model(), disable_transformers_logging():
             super().build_model()
+
+    @staticmethod
+    def model_with_tokenizer(model_path: str, device='cpu'):
+        check_llava_install()
+        global _model_path
+        _model_path = model_path
+        with init_yi_model(), disable_transformers_logging():
+            outs = LlavaVisionModel.model_with_tokenizer(model_path, device)
+        return outs
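
Note on the yi.py change: the new staticmethod mirrors the entry point of the other vision models, records the checkpoint path in the module-level _model_path global, and delegates to LlavaVisionModel under the init_yi_model() patch context. A hypothetical call, assuming a local Yi-VL checkpoint path (placeholder) and that the return value follows LlavaVisionModel.model_with_tokenizer, which this diff does not show:

from lmdeploy.vl.model.yi import YiVisionModel

outs = YiVisionModel.model_with_tokenizer('/path/to/Yi-VL-6B', device='cpu')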
