Commit df0625f — Update
xyang16 committed Sep 27, 2024
1 parent 478968c
Showing 8 changed files with 60 additions and 16 deletions.
engines/python/setup/djl_python/chat_completions/chat_properties.py

@@ -10,6 +10,7 @@
 # or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
 # BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
 # the specific language governing permissions and limitations under the License.
+from collections import defaultdict
 from typing import Optional, Union, List, Dict, Any, Tuple
 from pydantic import BaseModel, Field, field_validator, ValidationInfo, ConfigDict
 from PIL.Image import Image
@@ -70,7 +71,7 @@ def validate_content(
             transformed_content.append(ImageInput(url))
         return transformed_content

-    def get_tokenizer_inputs(self, image_token="<image>"):
+    def get_tokenizer_inputs(self, image_token=lambda _: "<image>", model_config=None):
         texts = []
         images = []
         for content in self.content:
@@ -81,9 +82,8 @@ def get_tokenizer_inputs(self, image_token="<image>"):

         prompt_text = '\n'.join(texts)
         if len(images) > 0:
-            placeholders = "\n".join(f"<|image_{i}|>"
-                                     for i, _ in enumerate(images, start=1))
-            prompt_text = f"{placeholders}\n{prompt_text}"
+            prompt_text = self._get_full_multimodal_text_prompt(
+                images, prompt_text, image_token)
         return {
             "role": self.role,
             "content": prompt_text,
@@ -92,6 +92,37 @@ def get_tokenizer_inputs(self, image_token="<image>"):

     def get_images(self) -> List[Image]:
         return [i.image for i in self.content if isinstance(i, ImageInput)]

+    def _get_full_multimodal_text_prompt(self, parts: List[str],
+                                         text_prompt: str,
+                                         image_token) -> str:
+        """
+        Combine multimodal prompts for a multimodal language model.
+        Inspired by https://github.com/vllm-project/vllm/blob/v0.6.2/vllm/entrypoints/chat_utils.py#L334
+        """
+        placeholder_counts: Dict[str, int] = defaultdict(lambda: 0)
+        for i in range(1, len(parts) + 1):
+            placeholder = image_token(i)
+            if placeholder:
+                placeholder_counts[placeholder] += 1
+
+        # Look through the text prompt to check for missing placeholders
+        missing_placeholders: List[str] = []
+        for placeholder in placeholder_counts:
+            # For any existing placeholder in the text prompt, we leave it as is
+            placeholder_counts[placeholder] -= text_prompt.count(placeholder)
+
+            if placeholder_counts[placeholder] < 0:
+                raise ValueError(
+                    f"Found more '{placeholder}' placeholders in input prompt than "
+                    "actual multimodal data items.")
+
+            missing_placeholders.extend([placeholder] *
+                                        placeholder_counts[placeholder])
+
+        # NOTE: For now we always add missing placeholders at the front of
+        # the prompt. This may change to be customizable in the future.
+        return "\n".join(missing_placeholders + [text_prompt])
+
+
 class ChatProperties(BaseModel):
     """
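For reference, a minimal standalone sketch of the placeholder logic added above (hypothetical function name, outside the Message class): placeholders already present in the text are left where they are, and only the missing ones are prepended.

from collections import defaultdict
from typing import Dict, List

def full_multimodal_prompt(num_images: int, text_prompt: str, image_token) -> str:
    # Expect one placeholder per image; indices are 1-based, as in the diff.
    placeholder_counts: Dict[str, int] = defaultdict(int)
    for i in range(1, num_images + 1):
        placeholder = image_token(i)
        if placeholder:
            placeholder_counts[placeholder] += 1

    # Subtract placeholders the user already wrote; prepend whatever is missing.
    missing: List[str] = []
    for placeholder, count in placeholder_counts.items():
        count -= text_prompt.count(placeholder)
        if count < 0:
            raise ValueError(f"Found more '{placeholder}' placeholders in "
                             "input prompt than actual multimodal data items.")
        missing.extend([placeholder] * count)
    return "\n".join(missing + [text_prompt])

# Two images with phi3_v-style indexed tokens; one placeholder already typed:
print(full_multimodal_prompt(2, "<|image_1|> What changed?",
                             lambda i: f"<|image_{i}|>"))
# <|image_2|>
# <|image_1|> What changed?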
engines/python/setup/djl_python/chat_completions/chat_utils.py

@@ -26,6 +26,7 @@ def parse_chat_completions_request(
         tokenizer,
         image_token: Optional[str] = None,
         configs: Properties = None,
+        model_config=None,
 ):
     # Chat completions can either be a rolling batch or no-batching.
     if not (is_rolling_batch or configs.batch_size == 1):
@@ -47,7 +48,8 @@ def parse_chat_completions_request(
     tokenizer_inputs = []
     for message in messages:
         tokenizer_inputs.append(
-            message.get_tokenizer_inputs(image_token=image_token))
+            message.get_tokenizer_inputs(image_token=image_token,
+                                         model_config=model_config))
         images.extend(message.get_images())
     inputs = apply_chat_template(tokenizer, tokenizer_inputs)
     param[
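To see what this loop produces, here is a condensed sketch with a hypothetical stand-in for the real Message class (tokenization via apply_chat_template omitted):

class FakeMessage:
    # Hypothetical stand-in for the Message model in chat_properties.py.
    def __init__(self, role, text, n_images):
        self.role, self.text, self.n_images = role, text, n_images

    def get_tokenizer_inputs(self, image_token=None, model_config=None):
        # model_config is accepted and forwarded as in the diff above, even
        # for model types whose placeholder does not depend on it.
        tokens = [image_token(i) for i in range(1, self.n_images + 1)]
        placeholders = "\n".join(t for t in tokens if t)
        return {"role": self.role, "content": f"{placeholders}\n{self.text}"}

messages = [FakeMessage("user", "Compare these.", 2)]
tokenizer_inputs = [
    m.get_tokenizer_inputs(image_token=lambda i: f"<|image_{i}|>",
                           model_config=None) for m in messages
]
print(tokenizer_inputs)
# [{'role': 'user', 'content': '<|image_1|>\n<|image_2|>\nCompare these.'}]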
engines/python/setup/djl_python/huggingface.py (19 additions & 6 deletions)

@@ -463,7 +463,7 @@ def _read_model_config(self, model_config_path: str):

     def get_image_token(self):
         if self.hf_configs.image_placeholder_token:
-            return self.hf_configs.image_placeholder_token
+            return lambda _: self.hf_configs.image_placeholder_token

         logging.warning(
             "image_placeholder_token is not explicitly set. It is highly recommended to explicitly"
@@ -474,15 +474,28 @@ def get_image_token(self):
         # This is less than ideal, but until there is a good way to obtain this from the tokenizer/model, it's the best way to do so
         model_type = self.model_config.model_type
         if model_type == "phi3_v":
-            # phi3_v does support multiple images, but vllm/lmi-dist can only support 1 per request
-            return "<|image_1|>"
-        if model_type in {"llava", "llava_next", "paligemma"}:
-            return "<image>"
+            return lambda i: f"<|image_{i}|>"
+        if model_type == "minicpmv":
+            return lambda _: "(<image>./</image>)"
+        if model_type in ("blip-2", "chatglm", "fuyu", "paligemma",
+                          "pixtral"):
+            # These models do not use image tokens in the prompt
+            return lambda _: None
+        if model_type == "qwen":
+            return lambda i: f"Picture {i}: <img></img>"
+        if model_type.startswith("llava"):
+            return lambda _: "<image>"
+        if model_type in ("chameleon", "internvl_chat"):
+            return lambda _: "<image>"
+        if model_type == "mllama":
+            return lambda _: "<|image|>"
+        if model_type == "qwen2_vl":
+            return lambda _: "<|vision_start|><|image_pad|><|vision_end|>"

         logging.warning(
             "could not infer image token from the model artifacts. Using <image> as default."
         )
-        return "<image>"
+        return lambda _: "<image>"


 _service = HuggingFaceService()
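With this change get_image_token always returns a callable that takes the 1-based image index: constant tokens ignore the index, while indexed formats such as phi3_v and qwen interpolate it. A condensed sketch of that contract (hypothetical helper, not the full mapping above):

def image_token_for(model_type: str):
    # Condensed from the get_image_token mapping above.
    if model_type == "phi3_v":
        return lambda i: f"<|image_{i}|>"
    if model_type == "qwen":
        return lambda i: f"Picture {i}: <img></img>"
    if model_type == "qwen2_vl":
        return lambda _: "<|vision_start|><|image_pad|><|vision_end|>"
    return lambda _: "<image>"

print([image_token_for("phi3_v")(i) for i in (1, 2)])
# ['<|image_1|>', '<|image_2|>']
print([image_token_for("llava")(i) for i in (1, 2)])
# ['<image>', '<image>']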
engines/python/setup/djl_python/input_parser.py (3 additions & 1 deletion)

@@ -129,6 +129,7 @@ def parse_text_inputs_params(request_input: TextInput, input_item: Input,
     tokenizer = kwargs.get("tokenizer")
     image_token = kwargs.get("image_placeholder_token")
     configs = kwargs.get("configs")
+    model_config = kwargs.get("model_config")
     is_bedrock = False
     if configs is not None:
         is_bedrock = configs.bedrock_compat
@@ -138,7 +139,8 @@ def parse_text_inputs_params(request_input: TextInput, input_item: Input,
             kwargs.get("is_rolling_batch"),
             tokenizer,
             image_token=image_token,
-            configs=configs)
+            configs=configs,
+            model_config=model_config)
     elif is_bedrock:
         inputs, param = parse_3p_request(input_map,
                                          kwargs.get("is_rolling_batch"),
engines/python/setup/djl_python/properties_manager/lmi_dist_rb_properties.py

@@ -64,7 +64,6 @@ class LmiDistRbProperties(Properties):
     enable_prefix_caching: Optional[bool] = False
     disable_sliding_window: Optional[bool] = False
     limit_mm_per_prompt: Optional[Mapping[str, int]] = None
-    mm_processor_kwargs: Optional[dict] = None

     @model_validator(mode='after')
     def validate_mpi(self):
engines/python/setup/djl_python/properties_manager/vllm_rb_properties.py

@@ -59,7 +59,6 @@ class VllmRbProperties(Properties):
     enable_prefix_caching: Optional[bool] = False
     disable_sliding_window: Optional[bool] = False
     limit_mm_per_prompt: Optional[Mapping[str, int]] = None
-    mm_processor_kwargs: Optional[dict] = None

     @field_validator('engine')
     def validate_engine(cls, engine):
tests/integration/llm/client.py (0 additions & 1 deletion)

@@ -1765,7 +1765,6 @@ def get_multimodal_prompt():
         "temperature": 0.9,
         "top_p": 0.6,
         "max_new_tokens": 512,
-        "ignore_eos": True,
     }

tests/integration/llm/prepare.py (0 additions & 1 deletion)

@@ -553,7 +553,6 @@
     "paligemma-3b-mix-448": {
         "option.model_id": "s3://djl-llm/paligemma-3b-mix-448/",
         "option.tensor_parallel_degree": 1,
-        "option.limit_mm_per_prompt": "image=2",
     },
     "phi-3-vision-128k-instruct": {
         "option.model_id": "s3://djl-llm/phi-3-vision-128k-instruct/",
