
Commit

vision embedding
lvhan028 committed Nov 5, 2024
1 parent d548056 commit e3c7e77
Showing 2 changed files with 132 additions and 79 deletions.
148 changes: 101 additions & 47 deletions lmdeploy/vl/model/molmo.py
@@ -2,10 +2,9 @@

from typing import Dict, List

import numpy as np
import torch
from PIL.Image import Image
from transformers import AutoModel, AutoProcessor
from transformers import AutoModelForCausalLM, AutoProcessor

from lmdeploy.utils import get_logger
from lmdeploy.vl.model.base import VISION_MODELS, VisonModel
@@ -22,21 +21,21 @@ class MolmoVisionModel(VisonModel):

def build_model(self):
"""Load model."""
from accelerate import init_empty_weights
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
with init_empty_weights():
config = self.hf_config
model = AutoModel.from_config(config, trust_remote_code=True)
model = AutoModelForCausalLM.from_config(config,
trust_remote_code=True)
if not self.with_llm:
for key in ['emb_drop', 'ln_f', 'blocks']:
for key in ['emb_drop', 'ln_f', 'blocks', 'ff_out']:
del model.model.transformer[key]
# get `wte.new_embedding` parameters, which will be
# used to perform image token embedding later on
self.token_embedding = model.model.transformer.wte
else:
self.vl_model = model
model.half()

from accelerate import load_checkpoint_and_dispatch
with disable_logging():
load_checkpoint_and_dispatch(
model=model,
@@ -45,53 +44,108 @@ def build_model(self):
max_memory=self.max_memory,
no_split_module_classes=[
'ResidualAttentionBlock', 'Embedding'
],
dtype=torch.half)
])

# We need eval mode to freeze the weights in the model and thus
# avoid randomness in inference.
self.model = model.eval()
self.config = config
# TODO: get embedding model

processor = AutoProcessor.from_pretrained(self.model_path,
trust_remote_code=True,
torch_dtype='auto',
device_map='auto')
self.image_processor = processor.image_processor

def preprocess(self, images: List[Image], params: List[Dict] = None):
images = [np.array(x.convert('RGB')) for x in images]
image_idx = [-1] * len(images)

DEFAULT_IMAGE_PATCH_TOKEN = '<im_patch>'
DEFAULT_IM_START_TOKEN = '<im_start>'
DEFAULT_IM_END_TOKEN = '<im_end>'
DEFAULT_IM_COL_TOKEN = '<im_col>'

image_patch_token_id = self.image_processor.special_token_ids[
DEFAULT_IMAGE_PATCH_TOKEN]
image_col_token_id = self.image_processor.special_token_ids[
DEFAULT_IM_COL_TOKEN]
image_start_token_id = self.image_processor.special_token_ids[
DEFAULT_IM_START_TOKEN]
image_end_token_id = self.image_processor.special_token_ids[
DEFAULT_IM_END_TOKEN]
out = self.image_processor.multimodal_preprocess(
images=images,
image_idx=image_idx,
tokens=np.asarray([]).astype(np.int32),
sequence_length=0, # unused parameter
image_patch_token_id=image_patch_token_id,
image_col_token_id=image_col_token_id,
image_start_token_id=image_start_token_id,
image_end_token_id=image_end_token_id,
)
return out

self.processor = AutoProcessor.from_pretrained(self.model_path,
trust_remote_code=True,
torch_dtype='auto',
device_map='auto')

@torch.no_grad()
def forward(self,
images: List[Image],
params: List[Dict] = None) -> List[torch.Tensor]:
self.preprocess(images)
"""forward the model with given input.
Args:
images (List): [None]
messages (List):
"""

messages = params[0]
assert isinstance(messages, List)

results = []
prompts = ''
for message in messages:
if 'images' in message.keys():
# preprocess images. The output is a dict
inputs = self.processor.process(images=message['images'],
text=message['content'])
inputs = {
k: v.to(self.model.device).unsqueeze(0)
for k, v in inputs.items()
}
input_ids = inputs['input_ids']
images = inputs[
'images'] # (batch_size, num_image, num_patch, d_model)
image_input_idx = inputs[
'image_input_idx'] # (batch_size, num_image, num_patch)
image_masks = inputs['image_masks']
batch_size, seq_len = input_ids.size()
assert batch_size == 1

# Get embeddings of input.
if input_ids is not None:
input_ids = input_ids * (input_ids != -1).to(
input_ids.dtype)
embeddings = self.model.model.transformer.wte(input_ids)
image_features, _ = self.model.model.vision_backbone(
images, image_masks)
num_image, num_patch = image_features.shape[1:3]
assert image_input_idx.shape == (batch_size, num_image,
num_patch)

# insert the image feature into the embedding.
image_features = image_features.view(batch_size,
num_image * num_patch, -1)
image_input_idx = image_input_idx.view(batch_size,
num_image * num_patch)

valid = image_input_idx >= 0
batch_idx = torch.arange(batch_size, device=embeddings.device)
batch_idx = torch.tile(batch_idx[:, None],
[1, image_features.shape[1]])
image_features = image_features.to(embeddings.device)
embeddings[batch_idx[valid],
image_input_idx[valid]] += image_features[valid]
# keep (token_ids, merged embeddings of shape (seq_len, hidden_size))
results.append((input_ids.flatten().tolist(), embeddings.squeeze(0)))
else:
role = message['role']
content = message['content']
assert isinstance(content, str)
prompt = ''
if role == 'user':
prompt = f'User: {content} '
elif role == 'assistant':
prompt = f'Assistant:{content}'
else:
assert 0, f'molmo does not support role {role}, message is {message}' # noqa
input_ids = self.processor.tokenizer.encode(
prompt, add_special_tokens=False)
results.append((input_ids, None))
prompts += prompt
# concatenate the input_ids from results and compute the ranges in the
# concatenated input_ids that the embeddings will be copied to
# (see the consumption sketch after this diff)
input_ids = []
input_embeddings = []
input_embedding_ranges = []
for result in results:
    start = len(input_ids)
    input_ids += result[0]
    if result[1] is not None:
        input_embeddings.append(result[1])
        end = start + result[1].shape[0]
        input_embedding_ranges.append((start, end))
return (prompts, input_ids, input_embeddings, input_embedding_ranges)
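
A minimal consumption sketch (not part of this commit): how a caller might turn the returned (prompts, input_ids, input_embeddings, input_embedding_ranges) tuple into a single prompt-embedding tensor. The name token_embedder is a hypothetical stand-in for the LLM's token-embedding layer (e.g. model.model.transformer.wte).

import torch


def assemble_prompt_embeddings(input_ids, input_embeddings,
                               input_embedding_ranges, token_embedder):
    # embed every token id first, then overwrite the ranges that already
    # carry precomputed (vision-merged) embeddings
    ids = torch.as_tensor(input_ids, device=token_embedder.weight.device)
    with torch.no_grad():
        embeds = token_embedder(ids)  # (seq_len, hidden_size)
        for emb, (start, end) in zip(input_embeddings,
                                     input_embedding_ranges):
            assert emb.shape[0] == end - start
            embeds[start:end] = emb.to(embeds.device, embeds.dtype)
    return embeds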
63 changes: 31 additions & 32 deletions lmdeploy/vl/templates.py
@@ -431,58 +431,57 @@ class GLM4VChatTemplateWrapper(VLChatTemplateWrapper):
class MolmoChatTemplateWrapper(VLChatTemplateWrapper):

async def async_collect_pil_images(
self, messages: Dict) -> List[Tuple[PIL.Image.Image, Dict]]:
self, messages: List[Dict]) -> List[Tuple[PIL.Image.Image, Dict]]:
"""collect images from messages.
Args:
messages (Dict): a user request of GPT4V message format
messages (List[Dict]): a user request in the GPT4V message format
"""
images_with_kwargs = []
for message in messages:
role = message['role']
content = message['content']
if isinstance(messages, Dict):
messages = [messages]
assert isinstance(messages, List)

out_messages = []

def _inner_call(i, in_messages, out_messages):
role = in_messages[i]['role']
content = in_messages[i]['content']
if role != 'user' or isinstance(content, str):
# means message is user's prompt input or assistant's prompt
images_with_kwargs.append([None, message])

# If the role is a user and the content is not a string, it
# indicates the message is composed of ONE user prompt and a list
# of images
image_prompt = [
item['content'] for item in content if item['type'] == 'text'
]
if len(image_prompt) != 0:
raise RuntimeError(f'invalid format {message}')
images_with_kwargs.append('image_prompt', image_prompt)
# means message is user's prompt input or assistant's response,
# returning it directly
out_messages.append(in_messages[i])
return
# the role is a user and the content is a list
assert isinstance(content, List)
message = dict(role=role, content='', images=[])
for item in content:
# 'image_url': means url or local path to image.
# 'image_data': means PIL.Image.Image object.
if item['type'] == 'image_url':
item_copy = item['image_url'].copy()
try:
url = item_copy.pop('url')
images_with_kwargs.append([url, item_copy])
image = load_image(item['image_url']['url'])
message['images'].append(image)
except KeyError:
logger.error(f'invalid format {message}')
elif item['type'] == 'image_data':
item_copy = item['image_data'].copy()
try:
data = item_copy.pop('data')
images_with_kwargs.append([data, item_copy])
image = load_image(item['image_data']['data'])
message['images'].append(image)
except KeyError:
logger.error(f'invalid format {message}')

def _inner_call(i, images):
url_or_data = images[i][0]
images[i][0] = load_image(url_or_data)
elif item['type'] == 'text':
message['content'] = item['text']
else:
logger.error(f'unexpected content type {message}')
out_messages.append(message)

await asyncio.gather(*[
asyncio.get_event_loop().run_in_executor(None, _inner_call, i,
images_with_kwargs)
for i in range(len(images_with_kwargs))
messages, out_messages)
for i in range(len(messages))
])

return images_with_kwargs
return [(None, out_messages)]
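
For reference, a minimal sketch (not part of this commit) of the message reshaping performed above; the request values below are purely illustrative:

# a hypothetical GPT4V-style user message (illustrative values only)
in_messages = [{
    'role': 'user',
    'content': [
        dict(type='text', text='describe this image'),
        dict(type='image_url',
             image_url=dict(url='https://example.com/cat.png')),
    ]
}]
# async_collect_pil_images returns [(None, out_messages)], where the user
# message has been flattened into a plain prompt plus loaded PIL images:
# out_messages == [dict(role='user',
#                       content='describe this image',
#                       images=[<PIL.Image.Image>])]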

def messages2prompt(self, messages, sequence_start=True, **kwargs) -> str:
"""Return a placeholder "IMAGE_TOKEN" so that
