diff --git a/lmdeploy/serve/vl_async_engine.py b/lmdeploy/serve/vl_async_engine.py index dff19ddecb..c293cd71c8 100644 --- a/lmdeploy/serve/vl_async_engine.py +++ b/lmdeploy/serve/vl_async_engine.py @@ -102,10 +102,7 @@ async def _get_prompt_input(self, if isinstance(self.vl_prompt_template, MolmoChatTemplateWrapper): - results['input_ids'] = features[0] - results['input_embeddings'] = features[1] - results['input_embedding_range'] = features[2] - return results + return features[0] features = [x.cpu().numpy() for x in features] input_ids = [] diff --git a/lmdeploy/vl/model/molmo.py b/lmdeploy/vl/model/molmo.py index 0ca6c77943..91e69e6ada 100644 --- a/lmdeploy/vl/model/molmo.py +++ b/lmdeploy/vl/model/molmo.py @@ -74,6 +74,7 @@ def forward(self, prompts = '' for message in messages: if 'images' in message.keys(): + prompts += message['content'] # preprocess images. The output is a dict inputs = self.processor.process(images=message['images'], text=message['content']) @@ -118,8 +119,8 @@ def forward(self, # print(f'batch_idx[valid]: {batch_idx[valid]}') embeddings[batch_idx[valid], image_input_idx[valid]] += image_features[valid] - results.append(input_ids.flatten().tolist(), - embeddings.flatten()) + assert embeddings.shape[:2] == (batch_size, seq_len) + results.append((input_ids.flatten().tolist(), embeddings)) else: role = message['role'] content = message['content'] @@ -137,15 +138,20 @@ def forward(self, prompts += prompt # concat input_ids from results, calculate the range in the input_ids # where embeddings will be copied to - # import pdb; pdb.set_trace() input_ids = [] input_embeddings = [] input_embedding_ranges = [] - for result in results: - input_ids += result[0] - if results[1] is not None: - input_embeddings.append(results[1]) - start = len(input_ids) - end = start + result[1].shape[0] + start = 0 + for _input_ids, _embeddings in results: + if _embeddings is not None: + input_embeddings.append(_embeddings.cpu()) + end = start + len(_input_ids) input_embedding_ranges.append((start, end)) - return (prompts, input_ids, input_embeddings, input_embedding_ranges) + input_ids += _input_ids + start += len(_input_ids) + return [ + dict(prompt=prompts, + input_ids=input_ids, + input_embeddings=input_embeddings, + input_embedding_ranges=input_embedding_ranges) + ]