
Commit

update
lvhan028 committed Nov 5, 2024
1 parent c155963 commit 8d8f8b9
Showing 3 changed files with 18 additions and 14 deletions.
8 changes: 6 additions & 2 deletions lmdeploy/turbomind/deploy/config.py
@@ -35,8 +35,12 @@ class ModelConfig:
kv_head_num: int = None
hidden_units: int = None
vocab_size: int = None
# In molmo, embedding.shape is [vocab_size + 128, hidden_units].
# Therefore, we add a new attr "embedding_size" to represent it
# Turbomind used to assume token_embedding and lm_head have the same size
# at the vocab dim, i.e. `vocab_size`.
# But in molmo, embedding.shape is [vocab_size + 128, hidden_units]
# while lm_head shape is [hidden_units, vocab_size].
# Therefore, we add a new attr "embedding_size" to represent the vocab dim
# of token_embedding
embedding_size: int = 0
num_layer: int = None
inter_size: int = None
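For illustration, here is a minimal sketch of the shape mismatch the new comment describes; the dimension values are hypothetical, not molmo's real ones:

    import torch

    hidden_units = 4096
    vocab_size = 50000
    embedding_size = vocab_size + 128  # hypothetical: molmo-style padded embedding

    # token_embedding and lm_head no longer share the vocab dimension
    token_embedding = torch.nn.Embedding(embedding_size, hidden_units)
    lm_head = torch.nn.Linear(hidden_units, vocab_size, bias=False)

    assert token_embedding.weight.shape == (embedding_size, hidden_units)
    # nn.Linear stores its weight as (out_features, in_features)
    assert lm_head.weight.shape == (vocab_size, hidden_units)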
1 change: 0 additions & 1 deletion lmdeploy/turbomind/deploy/source_model/internvl.py
@@ -62,7 +62,6 @@ def model_info(self):
num_layer = model_arg['num_hidden_layers']
norm_eps = model_arg['rms_norm_eps']
hidden_units = model_arg['hidden_size']
inter_size = model_arg['intermediate_size']
attn_head_num = model_arg['num_attention_heads']
vocab_size = model_arg['vocab_size']
inter_size = model_arg['intermediate_size']
23 changes: 12 additions & 11 deletions lmdeploy/vl/model/molmo.py
@@ -21,17 +21,15 @@ class MolmoVisionModel(VisonModel):

def build_model(self):
"""Load model."""
# import pdb; pdb.set_trace()
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
with init_empty_weights():
config = self.hf_config
model = AutoModelForCausalLM.from_config(config,
trust_remote_code=True)
if not self.with_llm:
# Remove nn modules other than embedding from the LLM model
for key in ['emb_drop', 'ln_f', 'blocks', 'ff_out']:
del model.model.transformer[key]
# get `wte.new_embedding` parameters, which will be
# used to perform image token embedding later on
self.token_embedding = model.model.transformer.wte
else:
self.vl_model = model
@@ -59,13 +57,20 @@ def build_model(self):
@torch.no_grad()
def forward(self,
images: List[Image],
params: List[Dict] = None) -> List[torch.Tensor]:
params: List[Dict] = None) -> List[Dict]:
"""forward the model with given input.
Args:
images (List): [None]
messages (List):
"""
images (List): [None]; it is not used
params (List): the inputs after processing GPT4V messages in
`MolmoChatTemplateWrapper`. Its format is as follows:
[[
{'role': 'user', 'content': 'user prompt'},
{'role': 'assistant', 'content': 'assistant prompt'},
{'role': 'user', 'content': 'user prompt', 'images': [PIL image list]},
...
]]
""" # noqa

messages = params[0]
assert isinstance(messages, List)
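For reference, a hypothetical call that matches the docstring's `params` layout (the image here is a blank placeholder, and the invocation is an assumption, not the library's documented API):

    from PIL import Image

    placeholder = Image.new('RGB', (336, 336))
    params = [[
        {'role': 'user', 'content': 'describe this image', 'images': [placeholder]},
    ]]
    # model.forward(images=[None], params=params)  # hypothetical invocation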
@@ -113,10 +118,6 @@ def forward(self,
batch_idx = torch.tile(batch_idx[:, None],
[1, image_features.shape[1]])
image_features = image_features.to(embeddings.device)
# print(f'>> molmo forward image ...')
# print(f'image_features.shape: {image_features.shape}')
# print(f'image_input_idx.shape: {image_input_idx.shape}')
# print(f'batch_idx[valid]: {batch_idx[valid]}')
embeddings[batch_idx[valid],
image_input_idx[valid]] += image_features[valid]
assert embeddings.shape[:2] == (batch_size, seq_len)
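As context for the last hunk, a self-contained sketch of the index scatter it performs; the sizes are made up, and `valid` is assumed to mask the non-negative token positions:

    import torch

    batch_size, seq_len, hidden = 2, 16, 8
    num_images, tokens_per_image = 2, 3

    embeddings = torch.zeros(batch_size, seq_len, hidden)
    image_features = torch.randn(num_images, tokens_per_image, hidden)
    # where each image token lands in the sequence; -1 marks padding
    image_input_idx = torch.tensor([[3, 4, -1], [7, 8, 9]])
    batch_idx = torch.arange(num_images)  # assumption: one image per batch row
    batch_idx = torch.tile(batch_idx[:, None], [1, image_features.shape[1]])

    valid = image_input_idx >= 0
    embeddings[batch_idx[valid], image_input_idx[valid]] += image_features[valid]
    assert embeddings.shape[:2] == (batch_size, seq_len)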
