Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
lvhan028 committed Nov 5, 2024
1 parent e3c7e77 commit 2e1aea5
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 14 deletions.
5 changes: 1 addition & 4 deletions lmdeploy/serve/vl_async_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,7 @@ async def _get_prompt_input(self,

if isinstance(self.vl_prompt_template,
MolmoChatTemplateWrapper):
results['input_ids'] = features[0]
results['input_embeddings'] = features[1]
results['input_embedding_range'] = features[2]
return results
return features[0]

features = [x.cpu().numpy() for x in features]
input_ids = []
Expand Down
26 changes: 16 additions & 10 deletions lmdeploy/vl/model/molmo.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def forward(self,
prompts = ''
for message in messages:
if 'images' in message.keys():
prompts += message['content']
# preprocess images. The output is a dict
inputs = self.processor.process(images=message['images'],
text=message['content'])
Expand Down Expand Up @@ -118,8 +119,8 @@ def forward(self,
# print(f'batch_idx[valid]: {batch_idx[valid]}')
embeddings[batch_idx[valid],
image_input_idx[valid]] += image_features[valid]
results.append(input_ids.flatten().tolist(),
embeddings.flatten())
assert embeddings.shape[:2] == (batch_size, seq_len)
results.append((input_ids.flatten().tolist(), embeddings))
else:
role = message['role']
content = message['content']
Expand All @@ -137,15 +138,20 @@ def forward(self,
prompts += prompt
# concat input_ids from results, calculate the range in the input_ids
# where embeddings will be copied to
# import pdb; pdb.set_trace()
input_ids = []
input_embeddings = []
input_embedding_ranges = []
for result in results:
input_ids += result[0]
if results[1] is not None:
input_embeddings.append(results[1])
start = len(input_ids)
end = start + result[1].shape[0]
start = 0
for _input_ids, _embeddings in results:
if _embeddings is not None:
input_embeddings.append(_embeddings.cpu())
end = start + len(_input_ids)
input_embedding_ranges.append((start, end))
return (prompts, input_ids, input_embeddings, input_embedding_ranges)
input_ids += _input_ids
start += len(_input_ids)
return [
dict(prompt=prompts,
input_ids=input_ids,
input_embeddings=input_embeddings,
input_embedding_ranges=input_embedding_ranges)
]

0 comments on commit 2e1aea5

Please sign in to comment.