fix mini gemini
dongchunyu committed May 10, 2024
1 parent e6468a6 commit f6c5fe4
Showing 2 changed files with 27 additions and 8 deletions.
4 changes: 3 additions & 1 deletion lmdeploy/lite/apis/auto_awq.py
@@ -45,9 +45,11 @@ def save_vl_model(vl_model, model_path, dst_path):
         tmp_path = osp.join(model_path, name)
         if osp.exists(tmp_path):
             shutil.copy(tmp_path, osp.join(dst_path, name))
+
+    safe_serialization = type(vl_model).__name__ == 'MGMLlamaForCausalLM'
     vl_model.save_pretrained(dst_path,
                              max_shard_size='2GB',
-                             safe_serialization=False)
+                             safe_serialization=safe_serialization)
 
 
 def auto_awq(model: str,
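
Note (annotation, not part of the commit): safe_serialization is a standard
transformers save_pretrained flag that selects the checkpoint format written
to disk. A minimal sketch of what the new gate changes, using a stand-in
checkpoint name rather than a real MiniGemini model:

    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained('facebook/opt-125m')
    # True  -> model-*.safetensors shards are written
    # False -> legacy pytorch_model-*.bin shards are written
    safe_serialization = type(model).__name__ == 'MGMLlamaForCausalLM'
    model.save_pretrained('/tmp/dst', max_shard_size='2GB',
                          safe_serialization=safe_serialization)

With this change, only MGMLlamaForCausalLM checkpoints are saved as
safetensors; every other VL model keeps the previous .bin behavior.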
31 changes: 24 additions & 7 deletions lmdeploy/vl/model/mini_gemeni.py
@@ -137,6 +137,16 @@ def _openclip_vision_tower_load_model(self):
         self.vision_stages.requires_grad_(False)
 
 
+old_func = torch.nn.Module.load_state_dict
+
+
+def _load_state_dict(self,
+                     state_dict,
+                     strict: bool = True,
+                     assign: bool = False):
+    return old_func(self, state_dict, strict=False, assign=assign)
+
+
 @contextmanager
 def init_mini_gemini_model():
     origin_func_path = [
@@ -145,7 +155,8 @@ def init_mini_gemini_model():
         'mgm.model.multimodal_encoder.clip_encoder.CLIPVisionTower.__init__',  # noqa: E501
         'mgm.model.multimodal_encoder.clip_encoder.CLIPVisionTower.load_model',  # noqa: E501
         'mgm.model.multimodal_encoder.openclip_encoder.OpenCLIPVisionTower.__init__',  # noqa: E501
-        'mgm.model.multimodal_encoder.openclip_encoder.OpenCLIPVisionTower.load_model'  # noqa: E501
+        'mgm.model.multimodal_encoder.openclip_encoder.OpenCLIPVisionTower.load_model',  # noqa: E501
+        'torch.nn.Module.load_state_dict',
     ]
     rewrite_func = [
         _build_vision_tower,
@@ -154,6 +165,7 @@ def init_mini_gemini_model():
         _clip_vision_tower_load_model,
         _openclip_vision_tower__init__,
         _openclip_vision_tower_load_model,
+        _load_state_dict,
     ]
     from lmdeploy.vl.model.utils import rewrite_ctx
     with rewrite_ctx(origin_func_path, rewrite_func):
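
Note (annotation, not part of the commit): the fix works by registering
_load_state_dict as a temporary replacement for torch.nn.Module.load_state_dict
while the model is built. rewrite_ctx lives in lmdeploy.vl.model.utils; a rough
sketch of the swap-and-restore pattern it implements (an assumed
implementation, not lmdeploy's actual code):

    import importlib
    from contextlib import contextmanager


    def _resolve(path):
        # Walk past the longest importable module prefix so that class
        # attributes such as 'torch.nn.Module.load_state_dict' resolve too.
        parts = path.split('.')
        for i in range(len(parts) - 1, 0, -1):
            try:
                owner = importlib.import_module('.'.join(parts[:i]))
            except ImportError:
                continue
            for name in parts[i:-1]:
                owner = getattr(owner, name)
            return owner, parts[-1]
        raise ImportError(path)


    @contextmanager
    def rewrite_ctx(origin_func_path, rewrite_func):
        saved = []
        for path, new_func in zip(origin_func_path, rewrite_func):
            owner, attr = _resolve(path)
            saved.append((owner, attr, getattr(owner, attr)))
            setattr(owner, attr, new_func)  # install the replacement
        try:
            yield
        finally:
            for owner, attr, old in reversed(saved):
                setattr(owner, attr, old)  # restore the original on exit

The effect of the patched method itself is easy to demonstrate: with
strict=True (PyTorch's default) a partial state dict raises a RuntimeError,
while strict=False merely reports the gaps:

    import torch

    net = torch.nn.Linear(4, 4)
    partial = {'weight': torch.zeros(4, 4)}  # 'bias' is missing
    result = net.load_state_dict(partial, strict=False)
    print(result.missing_keys)  # ['bias']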
@@ -175,17 +187,25 @@ def build_model(self):
         # empty init
         from accelerate import init_empty_weights
         from mgm.mm_utils import process_images
-        from mgm.model import MGMLlamaForCausalLM
+        from mgm.model import MGMLlamaForCausalLM  # noqa
+        from mgm.model.language_model.mgm_llama import MGMConfig
+        from transformers import AutoModelForCausalLM
         with init_empty_weights(), disable_transformers_logging(
         ), hack_import_with(['deepspeed']):
             warnings.simplefilter('ignore')
-            model = MGMLlamaForCausalLM.from_pretrained(self.model_path)
+            config = MGMConfig.from_pretrained(self.model_path,
+                                               trust_remote_code=True)
+            setattr(config, 'quantization_config', {})
+            setattr(config, 'model_path', self.model_path)
+            model = AutoModelForCausalLM.from_config(config,
+                                                     trust_remote_code=True)
             if not self.with_llm:
                 del model.lm_head
                 del model.model.embed_tokens
                 del model.model.layers
                 del model.model.norm
+            else:
+                model.config.use_cache = False
             self.vl_model = model
 
         # # load weight
@@ -198,11 +218,8 @@ def build_model(self):
             vision_tower_aux.is_loaded = False
             vision_tower_aux.load_model()
             load_model_from_weight_files(model, self.model_path)
-        model.to(self.device).eval()
-        model.model.vision_tower.half()
-        model.model.vision_tower_aux.half()
+        model.to(self.device).eval().half()
 
-        setattr(model.config, 'model_path', self.model_path)
         model.get_model().initialize_uni_modules(model.config, for_eval=True)
 
         self.model = model
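
Note (annotation, not part of the commit): switching from
MGMLlamaForCausalLM.from_pretrained to AutoModelForCausalLM.from_config under
init_empty_weights builds the module tree on the meta device, so no weight
tensors are allocated or read at construction time; the checkpoint is then
loaded once by load_model_from_weight_files. A minimal sketch of that pattern
with a stand-in model:

    from accelerate import init_empty_weights
    from transformers import AutoConfig, AutoModelForCausalLM

    config = AutoConfig.from_pretrained('facebook/opt-125m')  # stand-in
    with init_empty_weights():
        # Parameters live on the 'meta' device: shapes only, no storage.
        model = AutoModelForCausalLM.from_config(config)
    assert next(model.parameters()).is_meta
    # Real weights are loaded afterwards, then the model is moved to the
    # target device and cast, as the diff above does with
    # model.to(self.device).eval().half().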
