[multimodal] Support multi-image input for vision language models
xyang16 committed Sep 26, 2024
1 parent 8e8b75e commit 54b5c9d
Showing 9 changed files with 83 additions and 15 deletions.
@@ -81,7 +81,9 @@ def get_tokenizer_inputs(self, image_token="<image>"):

prompt_text = '\n'.join(texts)
if len(images) > 0:
prompt_text = f"{image_token}\n{prompt_text}"
placeholders = "\n".join(f"<|image_{i}|>"
for i, _ in enumerate(images, start=1))
prompt_text = f"{placeholders}\n{prompt_text}"
return {
"role": self.role,
"content": prompt_text,
@@ -11,9 +11,9 @@
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.
from enum import Enum
from typing import Optional
from typing import Optional, Mapping, Dict

from pydantic import model_validator
from pydantic import model_validator, field_validator

from djl_python.properties_manager.properties import Properties

@@ -63,6 +63,8 @@ class LmiDistRbProperties(Properties):
cpu_offload_gb_per_gpu: Optional[int] = 0
enable_prefix_caching: Optional[bool] = False
disable_sliding_window: Optional[bool] = False
limit_mm_per_prompt: Optional[Mapping[str, int]] = None
mm_processor_kwargs: Optional[dict] = None

@model_validator(mode='after')
def validate_mpi(self):
@@ -78,3 +80,24 @@ def validate_speculative_and_lora(self):
f"Cannot enable lora and speculative decoding at the same time"
)
return self

@field_validator('limit_mm_per_prompt', mode="before")
def set_limit_mm_per_prompt(cls, val) -> Mapping[str, int]:
out_dict: Dict[str, int] = {}
for item in val.split(","):
kv_parts = [part.lower().strip() for part in item.split("=")]
if len(kv_parts) != 2:
raise ValueError("Each item should be in the form key=value")
key, value = kv_parts

try:
parsed_value = int(value)
except ValueError as e:
raise ValueError(
f"Failed to parse value of item {key}={value}") from e

if key in out_dict and out_dict[key] != parsed_value:
raise ValueError(
f"Conflicting values specified for key: {key}")
out_dict[key] = parsed_value
return out_dict
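As a usage illustration, the validator above roughly performs the following transformation on the raw option string. This is a simplified, standalone sketch (it omits the duplicate-key conflict check the real validator performs), not the actual pydantic hook.

```python
from typing import Dict

def parse_limit_mm_per_prompt(val: str) -> Dict[str, int]:
    # Simplified replica of set_limit_mm_per_prompt, for illustration only;
    # the real validator also rejects conflicting duplicate keys.
    out: Dict[str, int] = {}
    for item in val.split(","):
        parts = [part.lower().strip() for part in item.split("=")]
        if len(parts) != 2:
            raise ValueError("Each item should be in the form key=value")
        key, value = parts
        out[key] = int(value)
    return out

print(parse_limit_mm_per_prompt("image=16,video=2"))
# {'image': 16, 'video': 2}
```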
@@ -11,7 +11,7 @@
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.
from enum import Enum
from typing import Optional, Any
from typing import Optional, Any, Mapping, Dict

from pydantic import field_validator

@@ -58,10 +58,33 @@ class VllmRbProperties(Properties):
cpu_offload_gb_per_gpu: Optional[int] = 0
enable_prefix_caching: Optional[bool] = False
disable_sliding_window: Optional[bool] = False
limit_mm_per_prompt: Optional[Mapping[str, int]] = None
mm_processor_kwargs: Optional[dict] = None

@field_validator('engine')
def validate_engine(cls, engine):
if engine != "Python":
raise AssertionError(
f"Need python engine to start vLLM RollingBatcher")
return engine

@field_validator('limit_mm_per_prompt', mode="before")
def set_limit_mm_per_prompt(cls, val) -> Mapping[str, int]:
out_dict: Dict[str, int] = {}
for item in val.split(","):
kv_parts = [part.lower().strip() for part in item.split("=")]
if len(kv_parts) != 2:
raise ValueError("Each item should be in the form key=value")
key, value = kv_parts

try:
parsed_value = int(value)
except ValueError as e:
raise ValueError(
f"Failed to parse value of item {key}={value}") from e

if key in out_dict and out_dict[key] != parsed_value:
raise ValueError(
f"Conflicting values specified for key: {key}")
out_dict[key] = parsed_value
return out_dict
@@ -85,6 +85,7 @@ def __init__(self, model_id_or_path: str, properties: dict, **kwargs):
cpu_offload_gb=self.lmi_dist_config.cpu_offload_gb_per_gpu,
enable_prefix_caching=self.lmi_dist_config.enable_prefix_caching,
disable_sliding_window=self.lmi_dist_config.disable_sliding_window,
limit_mm_per_prompt=self.lmi_dist_config.limit_mm_per_prompt,
**engine_kwargs)

kwargs = {}
@@ -164,7 +165,7 @@ def inference(self, new_requests: List[Request]) -> List:
# step 0: register new requests to engine
for request in new_requests:
request_id = str(request.id)
llm_input = get_prompt_inputs(request)
llm_input = get_prompt_inputs(request, self.lmi_dist_config)
params = self.translate_lmi_dist_params(request.parameters)
request_params = RequestParams(**params)
lora_request_params = get_lora_request_params(
@@ -269,19 +269,21 @@ def get_engine_args_from_config(config: VllmRbProperties) -> EngineArgs:
)


def get_multi_modal_data(request: Request) -> dict:
def get_multi_modal_data(request: Request, config: VllmRbProperties) -> dict:
parameters = request.parameters
images = parameters.pop("images", None)
multi_modal_data = None
if images:
# vLLM only supports one image per request.
multi_modal_data = {"image": images[0]}
# The number of data instances allowed per modality is restricted by limit_mm_per_prompt
limit = config.limit_mm_per_prompt.get(
"image", 1) if config.limit_mm_per_prompt is not None else 1
multi_modal_data = {"image": images[:limit]}
return multi_modal_data


def get_prompt_inputs(request: Request):
def get_prompt_inputs(request: Request, config: VllmRbProperties):
prompt_inputs: PromptInputs = {"prompt": request.request_input.input_text}
multi_modal_data = get_multi_modal_data(request)
multi_modal_data = get_multi_modal_data(request, config)
if multi_modal_data:
prompt_inputs["multi_modal_data"] = multi_modal_data
return prompt_inputs
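To illustrate the new capping behavior in `get_multi_modal_data`: with `option.limit_mm_per_prompt=image=2`, any images beyond the configured limit are silently truncated before the request reaches vLLM. The values below are hypothetical and stand in for the parsed config and the image list popped from `request.parameters`.

```python
# Hypothetical inputs for illustration: the parsed limit config and the
# image list popped from request.parameters.
limit_mm_per_prompt = {"image": 2}          # parsed from "image=2"
images = ["img_1", "img_2", "img_3"]        # request carries three images

limit = (limit_mm_per_prompt.get("image", 1)
         if limit_mm_per_prompt is not None else 1)
multi_modal_data = {"image": images[:limit]}

print(multi_modal_data)
# {'image': ['img_1', 'img_2']}
```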
@@ -121,7 +121,7 @@ def inference(self, new_requests: List[Request]) -> List:
# step 0: register new requests to engine
for request in new_requests:
request_id = random_uuid()
prompt_inputs = get_prompt_inputs(request)
prompt_inputs = get_prompt_inputs(request, self.vllm_configs)
params = self.translate_vllm_params(request.parameters)
sampling_params = SamplingParams(**params)
request_params = get_lora_request_params(request, self.lora_ids)
14 changes: 11 additions & 3 deletions serving/docs/lmi/user_guides/vision_language_models.md
@@ -21,9 +21,10 @@ You can read more about the supported format in the [chat completions doc](chat_
## Deploying with LMI

Deploying Vision Language Models with LMI is very similar to deploying Text Generation Models.
There is an additional, optional config that is exposed, `option.image_placeholder_token` that we recommend you set.
This config specifies the image placeholder token, which is then used by the model's processor and tokenizer to determine where to place the image content in the prompt.
We recommend you set this value explicitly because it is challenging to determine from the model artifacts.

There are some additional, optional configs that are exposed:
* `option.image_placeholder_token`: Specifies the image placeholder token, which is then used by the model's processor and tokenizer to determine where to place the image content in the prompt. We recommend you set this value explicitly because it is challenging to determine from the model artifacts.
* `option.limit_mm_per_prompt`: For each multimodal plugin, limits how many input instances are allowed per prompt. Expects a comma-separated list of `key=value` items, e.g. `image=16,video=2` allows a maximum of 16 images and 2 videos per prompt. Defaults to 1 for each modality.

Example SageMaker deployment code:

@@ -34,6 +35,7 @@ model = DJLModel(
model_id="llava-hf/llava-v1.6-mistral-7b-hf",
env={
"OPTION_IMAGE_PLACEHOLDER_TOKEN": "<image>",
"OPTION_LIMIT_MM_PER_PROMPT": "image=2",
}
)

@@ -51,6 +53,12 @@ messages = {
"image_url": {
"url": "https://resources.djl.ai/images/dog_bike_car.jpg"
}
},
{
"type": "image_url",
"image_url": {
"url": "https://resources.djl.ai/images/kitten.jpg"
}
}
]
}
6 changes: 6 additions & 0 deletions tests/integration/llm/client.py
@@ -1753,13 +1753,19 @@ def get_multimodal_prompt():
"image_url": {
"url": "https://resources.djl.ai/images/dog_bike_car.jpg",
}
}, {
"type": "image_url",
"image_url": {
"url": "https://resources.djl.ai/images/kitten.jpg",
}
}]
}]
return {
"messages": messages,
"temperature": 0.9,
"top_p": 0.6,
"max_new_tokens": 512,
"ignore_eos": True,
}


3 changes: 3 additions & 0 deletions tests/integration/llm/prepare.py
@@ -548,13 +548,16 @@
},
"llava_v1.6-mistral": {
"option.model_id": "s3://djl-llm/llava-v1.6-mistral-7b-hf/",
"option.limit_mm_per_prompt": "image=2",
},
"paligemma-3b-mix-448": {
"option.model_id": "s3://djl-llm/paligemma-3b-mix-448/",
"option.tensor_parallel_degree": 1,
"option.limit_mm_per_prompt": "image=2",
},
"phi-3-vision-128k-instruct": {
"option.model_id": "s3://djl-llm/phi-3-vision-128k-instruct/",
"option.limit_mm_per_prompt": "image=2",
"option.trust_remote_code": True,
},
"llama-3.1-8b": {