diff --git a/guidance/_parser.py b/guidance/_parser.py index e6a22671d..c2c7baa46 100644 --- a/guidance/_parser.py +++ b/guidance/_parser.py @@ -1,6 +1,7 @@ import json +import logging import os -from typing import Any, Generator, Optional, Tuple, Union +from typing import Any, Generator, Optional, Sequence, Tuple, Union import llguidance # type: ignore[import-untyped] import numpy as np @@ -12,6 +13,9 @@ from .models._tokenizer import Tokenizer +logger = logging.getLogger(__name__) + + class TokenParserException(Exception): pass @@ -30,29 +34,11 @@ class TokenParser: def __init__( self, - grammar: Union[GrammarFunction, str], - tokenizer: Tokenizer, - prompt: bytes = b"", - ensure_bos_token: bool = True, + ll_interpreter: llguidance.LLInterpreter, + prompt_tokens: list[int] ): - if isinstance(grammar, GrammarFunction): - # we can't have a terminal as the root - if isinstance(grammar, Terminal): - grammar = Join([grammar]) - serialized_grammar = json.dumps(grammar.ll_serialize()) - else: - serialized_grammar = grammar - - self.tokenizer = tokenizer - self.ll_tokenizer = llguidance.LLTokenizer( - llguidance.TokenizerWrapper(tokenizer) - ) - self.ll_interpreter = llguidance.LLInterpreter( - self.ll_tokenizer, - serialized_grammar, - log_level=int(os.environ.get("LLGUIDANCE_LOG_LEVEL", "1")), - ) - self._generator = self._parse(prompt, ensure_bos_token) + self.ll_interpreter = ll_interpreter + self._generator = self._parse(prompt_tokens) self._done = False def is_accepting(self) -> bool: @@ -70,28 +56,10 @@ def advance( self._done = True return None, e.value - def _process_prompt(self, prompt: bytes, ensure_bos_token: bool) -> list[int]: - prompt_tokens = self.ll_interpreter.process_prompt( - self.tokenizer.encode(prompt) - ) - if ( - ensure_bos_token - and self.tokenizer.bos_token is not None - and prompt_tokens[:1] != [self.tokenizer.bos_token_id] - ): - # add the beginning of sequence token if needed - prompt_tokens = [self.tokenizer.bos_token_id] + prompt_tokens - - return self.tokenizer.recode(prompt_tokens) - - def _parse( self, - prompt: bytes, - ensure_bos_token: bool, + tokens: list[int], ) -> Generator[Tuple[Optional[GenData], EngineCallResponse], Optional[int], EngineCallResponse]: - tokens = self._process_prompt(prompt=prompt, ensure_bos_token=ensure_bos_token) - while True: mask, resp = self.ll_interpreter.mid_process() r = LLInterpreterResponse.model_validate_json(resp) @@ -133,6 +101,57 @@ def _parse( return response +def process_prompt(prompt_tokens: Sequence[int], ll_interpreter: llguidance.LLInterpreter, bos_token_id: Optional[int]=None) -> list[int]: + # Allows ll_interpreter to make adjustments to prompt tokens, such as token healing + processed_tokens = ll_interpreter.process_prompt(prompt_tokens) + if ( + bos_token_id is not None + and prompt_tokens[:1] != [bos_token_id] + ): + # add the beginning of sequence token if needed + processed_tokens = [bos_token_id] + processed_tokens + + return processed_tokens + + +def serialize_grammar(grammar: Union[GrammarFunction, str]) -> str: + if isinstance(grammar, GrammarFunction): + # we can't have a terminal as the root + if isinstance(grammar, Terminal): + grammar = Join([grammar]) + return json.dumps(grammar.ll_serialize()) + else: + return grammar + + +def create_token_parser( + grammar: Union[GrammarFunction, str], + tokenizer: Tokenizer, + prompt: bytes = b"", + ensure_bos_token: bool = True, + trace: bool = False +) -> TokenParser: + serialized_grammar = serialize_grammar(grammar) + ll_tokenizer = llguidance.LLTokenizer( + 
llguidance.TokenizerWrapper(tokenizer) + ) + ll_interpreter = llguidance.LLInterpreter( + ll_tokenizer, + serialized_grammar, + log_level=2 if trace else int(os.environ.get("LLGUIDANCE_LOG_LEVEL", "1")), + ) + if ensure_bos_token: + if tokenizer.bos_token_id is None: + logger.warning("Tokenizer does not have a BOS token, but ensure_bos_token is True") + bos_token_id = tokenizer.bos_token_id + else: + bos_token_id = None + prompt_tokens = tokenizer.encode(prompt) + processed_tokens = process_prompt(prompt_tokens, ll_interpreter, bos_token_id) + processed_tokens = tokenizer.recode(processed_tokens) + return TokenParser(ll_interpreter, processed_tokens) + + class ByteParserException(Exception): def __init__(self, *args, **kwargs): self.current_byte = kwargs.pop("current_byte", None) @@ -149,7 +168,7 @@ def __init__( ensure_bos_token: bool = True, ): self.tokenizer = ByteTokenizer() - self.token_parser = TokenParser(grammar, self.tokenizer, prompt, ensure_bos_token) + self.token_parser = create_token_parser(grammar, self.tokenizer, prompt, ensure_bos_token) self.bytes = b"" self.gen_data: Optional[GenData] = None self.pos = 0 @@ -289,3 +308,4 @@ def _update_capture(self, response: EngineCallResponse): pass self._variables[k] = v self._variables_log_probs[k] = response.capture_group_log_probs[k] + diff --git a/guidance/chat.py b/guidance/chat.py index 7b3bd37f8..bee3f3939 100644 --- a/guidance/chat.py +++ b/guidance/chat.py @@ -214,6 +214,9 @@ def get_role_end(self, role_name=None): phi3_medium_template = "{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}" +# https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/tokenizer_config.json#L397 +phi3_vision_template = "{% for message in messages %}{{'<|' + message['role'] + '|>' + '\n' + message['content'] + '<|end|>\n' }}{% endfor %}{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{- '<|assistant|>\n' -}}{% endif %}" + # Although the templates are different, the roles are the same between medium and small (for now) class Phi3SmallMediumChatTemplate(ChatTemplate): # available_roles = ["user", "assistant"] @@ -230,9 +233,24 @@ def get_role_start(self, role_name): def get_role_end(self, role_name=None): return "<|end|>\n" +class Phi3VisionChatTemplate(ChatTemplate): + template_str = phi3_vision_template + + def get_role_start(self, role_name): + if role_name == "user": + return "<|user|>\n" + elif role_name == "assistant": + return "<|assistant|>\n" + else: + raise UnsupportedRoleException(role_name, self) + + def get_role_end(self, role_name=None): + return "<|end|>\n" CHAT_TEMPLATE_CACHE[phi3_small_template] = Phi3SmallMediumChatTemplate CHAT_TEMPLATE_CACHE[phi3_medium_template] = Phi3SmallMediumChatTemplate +CHAT_TEMPLATE_CACHE[phi3_vision_template] = Phi3VisionChatTemplate + # -------------------------------------------------- # @@@@ Mistral-7B-Instruct-v0.2 @@@@ diff --git a/guidance/library/_image.py b/guidance/library/_image.py index fc1de7d1c..960a8236f 100644 --- a/guidance/library/_image.py +++ b/guidance/library/_image.py @@ -4,6 +4,8 @@ import typing import urllib +from guidance.models._model import Modality + from .._guidance import guidance @@ -29,9 +31,5 @@ def image(lm, src: typing.Union[str, pathlib.Path, bytes], allow_local: bool = T else: raise Exception(f"Unable to load image bytes from 
{src}!") - bytes_id = str(id(bytes_data)) - - # set the image bytes - lm = lm.set(bytes_id, bytes_data) - lm += f"<|_image:{bytes_id}|>" + lm = lm.append_multimodal(bytes_data, Modality.IMAGE) return lm diff --git a/guidance/models/__init__.py b/guidance/models/__init__.py index 32ad16361..c51349071 100644 --- a/guidance/models/__init__.py +++ b/guidance/models/__init__.py @@ -2,6 +2,7 @@ # local models from .transformers._transformers import Transformers, TransformersTokenizer +from .transformers._transformers_phi3v import TransformersPhi3Vision from .llama_cpp import LlamaCpp from ._mock import Mock, MockChat diff --git a/guidance/models/_grammarless.py b/guidance/models/_grammarless.py index 9a622d426..e978186a4 100644 --- a/guidance/models/_grammarless.py +++ b/guidance/models/_grammarless.py @@ -258,7 +258,7 @@ def _reset_shared_data(self, new_data: bytes, temperature: float): self._last_stream_start = self._data def get_next_token( - self, token_ids: list[int], mask: Optional[bytes], temperature: float) -> int: + self, prompt: str, token_ids: list[int], mask: Optional[bytes], temperature: float, media: Optional[dict]=None) -> int: logger.debug( f"Start Grammarless.get_next_token({token_ids=}, {mask=}, {temperature=})" diff --git a/guidance/models/_mock.py b/guidance/models/_mock.py index 0f1e48b41..7790ed86a 100644 --- a/guidance/models/_mock.py +++ b/guidance/models/_mock.py @@ -80,9 +80,9 @@ def __init__(self, tokenizer, byte_patterns, compute_log_probs, force): # seed the random number generator self._rand_generator = np.random.default_rng(seed=42) - def get_next_token(self, token_ids: list[int], mask: Optional[bytes], temperature: float) -> int: + def get_next_token(self, prompt: bytes, token_ids: list[int], mask: Optional[bytes], temperature: float, media: Optional[dict]=None) -> int: self.called_temperatures.append(temperature) - return super().get_next_token(token_ids, mask, temperature) + return super().get_next_token(prompt, token_ids, mask, temperature, media) def get_logits(self, token_ids: list[int]) -> np.ndarray: """Pretends to compute the logits for the given token state.""" diff --git a/guidance/models/_model.py b/guidance/models/_model.py index 74f3c8af0..7a531d1ac 100644 --- a/guidance/models/_model.py +++ b/guidance/models/_model.py @@ -1,5 +1,7 @@ import base64 import copy +from dataclasses import dataclass +from enum import Enum import html import logging import queue @@ -26,7 +28,7 @@ from .._schema import EngineCallResponse, GuidanceEngineMetrics from .._utils import softmax, CaptureEvents -from .._parser import TokenParser +from .._parser import TokenParser, create_token_parser from .._grammar import ( GrammarFunction, string, @@ -49,7 +51,21 @@ r"<\|\|_#NODISP_\|\|>.*?<\|\|_/NODISP_\|\|>", flags=re.DOTALL ) html_pattern = re.compile(r"<\|\|_html:(.*?)_\|\|>", flags=re.DOTALL) -image_pattern = re.compile(r"<\|_image:(.*?)\|>") +image_pattern = re.compile(r"<\|_IMAGE:(.*?)\|>") + + +class Modality(Enum): + TEXT = 1 + IMAGE = 2 + AUDIO = 3 + VIDEO = 4 + IMAGE_URL = 5 + AUDIO_URL = 6 + VIDEO_URL = 7 + +modality_pattern = re.compile( + r"<\|_(" + "|".join(modality.name for modality in Modality) + r"):(.*?)\|>" +) class Engine: @@ -72,7 +88,7 @@ def get_chat_template(self): # TODO [HN]: Add more logic here...should we instan def reset_metrics(self): self.metrics = GuidanceEngineMetrics() - def start(self, prompt, grammar, ensure_bos_token=True) -> TokenParser: + def start(self, prompt, grammar, media: Optional[dict]=None, ensure_bos_token=True) -> TokenParser: 
"""Start processing parser state executed through the grammar. Parameters @@ -107,14 +123,14 @@ def start(self, prompt, grammar, ensure_bos_token=True) -> TokenParser: else: raise Exception("The passed prompt is of an unknown type!") - return TokenParser( + return create_token_parser( grammar=grammar, tokenizer=self.tokenizer, prompt=prompt, ensure_bos_token=ensure_bos_token ) - def __call__(self, prompt, grammar, ensure_bos_token=True) -> Iterator[EngineCallResponse]: + def __call__(self, prompt, grammar, media: Optional[dict]=None, ensure_bos_token=True) -> Iterator[EngineCallResponse]: """Main entry point for the inference-parser loop. Yields EngineCallResponse objects as the parser advances through the grammar. @@ -128,8 +144,10 @@ def __call__(self, prompt, grammar, ensure_bos_token=True) -> Iterator[EngineCal inferencing the model. (TODO: implement full parser extension support) grammar: Grammar This is the grammar we are extending the prompt with. + media: dict + An optional dictionary mapping placeholder IDs in the prompt to multimodal data bytes. """ - parser = self.start(prompt, grammar, ensure_bos_token) + parser = self.start(prompt, grammar, media, ensure_bos_token) token = None while not parser.done(): @@ -141,16 +159,20 @@ def __call__(self, prompt, grammar, ensure_bos_token=True) -> Iterator[EngineCal # but we will treat any "illegal" tokens as EOS, allowing the model to finish gracefully. assert gen_data.mask[self.tokenizer.eos_token_id] token = self.get_next_token( + prompt, token_ids=gen_data.tokens, mask=None, + media=media, temperature=gen_data.temperature ) if not gen_data.mask[token]: token = self.tokenizer.eos_token_id else: token = self.get_next_token( + prompt, token_ids=gen_data.tokens, mask=gen_data.mask, + media=media, temperature=gen_data.temperature ) else: @@ -158,15 +180,15 @@ def __call__(self, prompt, grammar, ensure_bos_token=True) -> Iterator[EngineCal yield response - def get_next_token(self, token_ids: list[int], mask: Optional[bytes], temperature: float) -> int: + def get_next_token(self, prompt: bytes, token_ids: list[int], mask: Optional[bytes], temperature: float, media: Optional[dict]=None) -> int: """Base implementation for getting the next token from the model which calls get_logits and sample_with_temperature. Subclasses may override this method, e.g. if they use external APIs that do not support getting logits directly. """ - logits = self.get_logits(token_ids) + logits = self.get_logits(prompt, token_ids, media) token = self.sample_with_temperature(logits, mask, temperature) return token - def get_logits(self, token_ids: list[int]) -> np.ndarray: + def get_logits(self, prompt: bytes, token_ids: list[int], media: Optional[dict]=None) -> np.ndarray: raise NotImplementedError def sample_with_temperature(self, logits: np.ndarray, mask: Optional[bytes], temperature: float) -> int: @@ -187,7 +209,6 @@ def _report_failed_match(self, prompt): + str(prompt[-40:]) ) - class Model: """The base guidance model object, which represents a model in a given state. @@ -383,6 +404,25 @@ def _current_prompt(self): """The current prompt in bytes (which is the state without the context close tags).""" return format_pattern.sub("", self._state) + + def _create_media_dict(self) -> dict: + """ + Find multimodal placeholders in the prompt string and create a dictionary + containing the multimodal data. 
+ """ + media_dict = {} + + prompt = self._current_prompt() + for match in modality_pattern.finditer(prompt): + # Add the current match + content_key = match.group(2) + content = self.get(content_key) + if content is None: + raise KeyError(f"Model does not contain the multimodal data with id '{content_key}'") + media_dict[content_key] = content + + return media_dict + def __str__(self): """A string representation of the current model object (that includes context closers).""" out = self._current_prompt() @@ -604,6 +644,14 @@ def remove(self, key): else: copy = self return copy + + def append_multimodal(self, data, modality: Modality): + """ + Appends multimodal data to the model's state. + """ + copy = self.set(str(id(data)), data) + copy._inplace_append(f"<|_{modality.name}:{str(id(data))}|>") + return copy def log_prob(self, key, default=None): """Return the log prob of a variable, or a default value if the variable is not present. @@ -676,8 +724,11 @@ def _run_stateless(self, stateless_function, temperature=0.0, top_p=1.0, n=1): # replace ModelVariables with their actual values (note we save what we replaced so we can restore it later) replacements = replace_model_variables(stateless_function, self) + # get the multimodal data + media = self._create_media_dict() + # start the generation stream - gen_obj = self.engine(self._current_prompt(), stateless_function) + gen_obj = self.engine(self._current_prompt(), stateless_function, media) # we will return a new extended version of ourselves, which we track as `lm` lm = self diff --git a/guidance/models/transformers/_transformers.py b/guidance/models/transformers/_transformers.py index d454bda2d..58846855d 100644 --- a/guidance/models/transformers/_transformers.py +++ b/guidance/models/transformers/_transformers.py @@ -3,7 +3,7 @@ import textwrap import warnings -from typing import Sequence, Union +from typing import Optional, Sequence, Union try: import torch @@ -35,6 +35,18 @@ "trust_remote_code", ] +def load_transformers_model(model, **kwargs): + # intantiate the model if needed + if isinstance(model, str): + + # make sure transformers is installed + if not has_transformers: + raise Exception( + "Please install transformers with `pip install transformers` in order to use guidance.models.Transformers!" 
+ ) + model = transformers_package.AutoModelForCausalLM.from_pretrained(model, **kwargs) + return model + class ByteDecoderError(Exception): pass @@ -67,7 +79,7 @@ def __init__( transformers_tokenizer, transformers_package.PreTrainedTokenizerFast ) assert is_ptt or is_ptt_fast - byte_tokens = self._byte_tokens(transformers_tokenizer) + byte_tokens = self._byte_tokens(transformers_tokenizer, **kwargs) self._orig_tokenizer = transformers_tokenizer @@ -98,7 +110,7 @@ def _tokenizer(self, model: str, **kwargs) -> tuple[ tokenizer = transformers_package.AutoTokenizer.from_pretrained( model, use_fast=False, **kwargs ) - byte_tokens = self._byte_tokens(tokenizer) + byte_tokens = self._byte_tokens(tokenizer, **kwargs) except ImportError: # Raise on ImportError because it's likely a missing dependency that the user can install raise @@ -126,6 +138,7 @@ def _byte_tokens( "transformers_package.PreTrainedTokenizer", "transformers_package.PreTrainedTokenizerFast", ], + **kwargs, ) -> list[bytes]: if hasattr(transformers_tokenizer, "byte_decoder"): @@ -144,6 +157,9 @@ def _byte_tokens( if hasattr(transformers_tokenizer, "sp_model"): return self._byte_tokens_from_sp_model(transformers_tokenizer) + if kwargs.get("sp_whitespace", False): + return self._byte_tokens_from_sp_whitespace(transformers_tokenizer) + try: return self._byte_tokens_by_encoding_token_strings(transformers_tokenizer) except ValueError as e: @@ -185,7 +201,7 @@ def _byte_tokens_from_sp_model( transformers_tokenizer: Union[ "transformers_package.PreTrainedTokenizer", "transformers_package.PreTrainedTokenizerFast", - ], + ] ) -> list[bytes]: byte_tokens = [b""] * len(transformers_tokenizer) special_tokens_map = { @@ -204,6 +220,24 @@ def _byte_tokens_from_sp_model( byte_tokens[i] = byte_coded.replace(space_prefix, b" ") return byte_tokens + def _byte_tokens_from_sp_whitespace( + self, + transformers_tokenizer: Union[ + "transformers_package.PreTrainedTokenizer", + "transformers_package.PreTrainedTokenizerFast", + ] + ) -> list[bytes]: + byte_tokens = [b""] * len(transformers_tokenizer) + if hasattr(transformers_tokenizer, "get_vocab"): + space_prefix = "▁".encode() + vocab = transformers_tokenizer.get_vocab() + for token, tok_id in vocab.items(): + byte_coded = token.encode() + byte_tokens[tok_id] = byte_coded.replace(space_prefix, b" ") + else: + raise ValueError("Tokenizer does not have a get_vocab method") + return byte_tokens + def _byte_tokens_by_encoding_token_strings( self, transformers_tokenizer: Union[ @@ -382,7 +416,7 @@ def __init__(self, model, tokenizer, compute_log_probs: bool, chat_template=None except: pass - self.model_obj = self._model(model, **kwargs) + self.model_obj = load_transformers_model(model, **kwargs) if not isinstance(model, str): self.model = model.__class__.__name__ @@ -415,19 +449,7 @@ def __init__(self, model, tokenizer, compute_log_probs: bool, chat_template=None compute_log_probs=compute_log_probs, ) - def _model(self, model, **kwargs): - # intantiate the model if needed - if isinstance(model, str): - - # make sure transformers is installed - if not has_transformers: - raise Exception( - "Please install transformers with `pip install transformers` in order to use guidance.models.Transformers!" - ) - model = transformers_package.AutoModelForCausalLM.from_pretrained(model, **kwargs) - return model - - def get_logits(self, token_ids): + def get_logits(self, prompt: bytes, token_ids: list[int], media: Optional[dict] = None): """Computes the logits for the given token state. 
This overrides a method from the LocalEngine class that is used to get @@ -537,3 +559,4 @@ def __init__( ), echo=echo, ) + diff --git a/guidance/models/transformers/_transformers_phi3v.py b/guidance/models/transformers/_transformers_phi3v.py new file mode 100644 index 000000000..ba61aaffd --- /dev/null +++ b/guidance/models/transformers/_transformers_phi3v.py @@ -0,0 +1,219 @@ +import logging +import io +import os +from typing import Optional + +try: + import torch +except ModuleNotFoundError: + pass + +import llguidance +from transformers import AutoModelForCausalLM, AutoProcessor + +from guidance._parser import TokenParser, process_grammar, process_prompt +from guidance.models._model import ( + Engine, + Model, + modality_pattern, + Modality +) +from guidance.models.transformers._transformers import TransformersTokenizer + +try: + from PIL import Image + has_pillow = True +except ModuleNotFoundError: + has_pillow = False + +logger = logging.getLogger(__name__) + + +class TransformersPhi3VisionEngine(Engine): + def __init__( + self, + model="microsoft/Phi-3-vision-128k-instruct", + compute_log_probs=False, + **kwargs, + ): + if not has_pillow: + raise Exception("Please install pillow with `pip install pillow` to use Phi 3 Vision") + self.model_name = model + # Initialize the underlying Phi 3 Vision model + self.model_obj = AutoModelForCausalLM.from_pretrained(model, **kwargs) + self.device = self.model_obj.device + + # Processor handles tokenization and image processing + self.processor = AutoProcessor.from_pretrained(self.model_name, trust_remote_code=True) + self.tokenizer = TransformersTokenizer(model, self.processor.tokenizer, sp_whitespace=True) + super().__init__(self.tokenizer, compute_log_probs) + + # Cache for past key values + self._past_key_values = None + self._cached_token_ids: list[int] = [] + + + def start(self, prompt: bytes, grammar, media: dict, ensure_bos_token=True) -> TokenParser: + if isinstance(prompt, bytes): + prompt = prompt.decode("utf-8") + elif isinstance(prompt, TokenParser): + raise NotImplementedError( + "Still need to implement support for extending a full Parser state." 
+ ) + elif not isinstance(prompt, str): + raise Exception("The passed prompt is of an unknown type!") + + # Map Guidance placeholders to Phi 3 Vision format + # and make list of images for processing + images = [] + processed_prompt = prompt + matches = {} + for match in modality_pattern.finditer(prompt): + match_str = match.group(0) + modality_type = match.group(1) + if modality_type != Modality.IMAGE.name: + logger.debug("Skipping non-image modality: %s", match_str) + continue + media_id = match.group(2) + if match_str not in matches: + matches[match_str] = media_id + + image_counter = 1 + for match in matches.keys(): + processed_prompt = processed_prompt.replace( + match, f"<|image_{image_counter}|>" + ) + media_key = matches[match] + images.append(Image.open(io.BytesIO(media[media_key]))) + image_counter += 1 + logger.debug("Transformed prompt: %s -> ", prompt, processed_prompt) + + model_inputs = self.processor( + text=processed_prompt, + images=images if len(images) > 0 else None, + return_tensors="pt", + ).to(self.device) + + # We will reuse everything except input_ids (attention_mask, pixel_values, image_sizes) + self.model_inputs = model_inputs + + tokens = model_inputs["input_ids"][0].tolist() + + serialized_grammar = process_grammar(grammar) + ll_tokenizer = llguidance.LLTokenizer( + llguidance.TokenizerWrapper(self.tokenizer) + ) + ll_interpreter = llguidance.LLInterpreter( + ll_tokenizer, + serialized_grammar, + log_level=int(os.environ.get("LLGUIDANCE_LOG_LEVEL", "1")), + ) + if ensure_bos_token and self.tokenizer.bos_token_id is not None: + bos_token_id = self.tokenizer.bos_token_id + else: + bos_token_id = None + + # Find the last multimodal (negative) token in the sequence, if any + # Note: Phi 3 vision uses a convention of negative tokens for multimodal inputs + # Do not assume other models will use this convention + last_multimodal_index = -1 + for i, token in enumerate(reversed(tokens)): + if token < 0: + last_multimodal_index = len(tokens) - i - 1 + break + + # Process tokens and grammar state machine beginning from the last multimodal token + if last_multimodal_index != -1: + processed_tokens = process_prompt(tokens[last_multimodal_index+1:], ll_interpreter, bos_token_id) + prompt_tokens = tokens[:last_multimodal_index+1] + processed_tokens + else: + prompt_tokens = process_prompt(tokens, ll_interpreter, bos_token_id) + + return TokenParser(ll_interpreter, prompt_tokens) + + + def get_logits(self, prompt: bytes, token_ids: list[int], media: Optional[dict]=None): + """Computes the logits for the given token state. + + This overrides a method from the LocalEngine class that is used to get + inference results from the model. + """ + + # make sure we don't run off the end of the model + if len(token_ids) >= getattr(self.model_obj.config, "max_position_embeddings", 1e10): + raise Exception( + f"Attempted to run a transformers model past its maximum context window size of {self.model_obj.config.max_position_embeddings}!" 
+ ) + + # get the number of cache positions we are using + cache_token_ids = self._cached_token_ids + num_cached = 0 + for id in cache_token_ids: + if ( + num_cached >= len(cache_token_ids) + or num_cached >= len(token_ids) + or token_ids[num_cached] != id + ): + break + num_cached += 1 + + # reset the cache length according to that number of positions + past_key_values = self._past_key_values + past_length = past_key_values[0][0].size(-2) if past_key_values is not None else 0 + if past_length > num_cached: + # note we recompute the last token because we don't bother to handle the special case of just computing logits + past_length = max(0, num_cached - 1) + self._past_key_values = tuple( + tuple(p[..., :past_length, :] for p in v) for v in past_key_values + ) + cache_token_ids[past_length:] = [] + + # call the model + new_token_ids = token_ids[past_length:] + if len(new_token_ids) > 0: + self.model_inputs["input_ids"] = torch.tensor(new_token_ids).unsqueeze(0).to(self.device) + self.model_inputs["attention_mask"]=torch.ones(1, past_length + len(new_token_ids)).to(self.device) + position_ids=torch.arange(past_length, past_length + len(new_token_ids)).unsqueeze(0).to(self.device) + with torch.no_grad(): + model_out = self.model_obj( + **self.model_inputs, + position_ids=position_ids, + past_key_values=self._past_key_values, + use_cache=True, + return_dict=True, + output_attentions=False, + output_hidden_states=False, + ) + + # save the results + self._past_key_values = model_out.past_key_values + cache_token_ids.extend(new_token_ids) + # Need to add special truncating logic here for weird models that have a different output size than tokenizer vocab + self._cached_logits = ( + model_out.logits[0, -1, : len(self.tokenizer.tokens)].cpu().numpy() + ) + self.metrics.engine_input_tokens += len(token_ids) + self.metrics.engine_output_tokens += 1 + + return self._cached_logits + + +class TransformersPhi3Vision(Model): + def __init__( + self, + model=None, + echo=True, + compute_log_probs=False, + **kwargs, + ): + """Build a new TransformersPhi3Model object.""" + if model is None or len(model) == 0: + model = "microsoft/Phi-3-vision-128k-instruct" + super().__init__( + TransformersPhi3VisionEngine( + model, + compute_log_probs, + **kwargs, + ), + echo=echo, + ) \ No newline at end of file diff --git a/setup.py b/setup.py index 72c23d468..6a0965832 100644 --- a/setup.py +++ b/setup.py @@ -37,6 +37,7 @@ "openai": ["openai>=1.0"], "schemas": ["jsonschema"], "server": ["fastapi-slim", "uvicorn"], + "image": ["pillow"] } # Create the union of all our requirements diff --git a/tests/model_specific/test_transformers_phi3v.py b/tests/model_specific/test_transformers_phi3v.py new file mode 100644 index 000000000..38cffdced --- /dev/null +++ b/tests/model_specific/test_transformers_phi3v.py @@ -0,0 +1,173 @@ +import pytest +import json + +from guidance import models, gen, select, image +from guidance._grammar import string + +PHI_3_VISION_MODEL = "microsoft/Phi-3-vision-128k-instruct" + +# TODO - tests with regular guidance grammars, no images + +@pytest.fixture(scope="module") +def phi3_vision_model(): + """Load the TransformersPhi3Model with the specified model ID.""" + try: + model_kwargs = { + # "_attn_implementation": "eager", # Uncomment this line if flash attention is not working + "trust_remote_code": True, + } + model = models.TransformersPhi3Vision( + model=PHI_3_VISION_MODEL, **model_kwargs + ) + return model + except ImportError as e: + pytest.skip(f"Error importing Phi 3 vision model: {e}") + + 
+def test_image_loading(phi3_vision_model: models.TransformersPhi3Vision): + """Test basic image loading and placeholder replacement in the prompt.""" + image_url = "https://picsum.photos/200/300" + lm = phi3_vision_model + "This is a test with an image: " + image(image_url) + + # Verify that the image placeholder is correctly inserted + assert "<|image:id|>" in lm._state, f"Hidden state: {lm._state}" + + +def test_basic_generation_with_image( + phi3_vision_model: models.TransformersPhi3Vision, +): + """Test unconstrained generation with an image.""" + image_url = "https://picsum.photos/200/300" + lm = phi3_vision_model + "Describe this image: " + image(image_url) + lm += gen(name="description", max_tokens=10) + + # Verify that the model generated some text + assert len(lm["description"]) > 0 + + +def test_select_with_image(phi3_vision_model: models.TransformersPhi3Vision): + """Test constraint enforcement with select and an image.""" + image_url = "https://picsum.photos/200/300" + lm = phi3_vision_model + "Is this a photo of a cat or a dog: " + image(image_url) + lm += select(["cat", "dog"], name="answer") + + # Verify that the model selected one of the options + assert lm["answer"] in ["cat", "dog"] + + +def test_llguidance_interaction(phi3_vision_model: models.TransformersPhi3Vision): + """Test that llguidance correctly enforces a simple grammar with an image.""" + image_url = "https://picsum.photos/200/300" + lm = phi3_vision_model + "The color of the image is: " + image(image_url) + + # Define a grammar that only allows the words "red", "green", or "blue" + grammar = string("red") | string("green") | string("blue") + + lm += grammar + + # Verify that the generated text is one of the allowed colors + assert str(lm).endswith(("red", "green", "blue")) + + +def test_multiple_images(phi3_vision_model: models.TransformersPhi3Vision): + """Test cache invalidation with multiple images.""" + image_url_1 = "https://picsum.photos/200/300" + image_url_2 = "https://picsum.photos/300/200" + lm = phi3_vision_model + "Image 1: " + image(image_url_1) + ". Image 2: " + image(image_url_2) + lm += gen(name="description", max_tokens=10) + + # Add assertions to verify cache behavior and output (e.g., token count, presence of image tokens) + assert lm.engine._last_image_token_position > 0 + assert len(lm["description"]) > 0 + + +def test_empty_image_token(phi3_vision_model: models.TransformersPhi3Vision): + """Test handling of an image token without corresponding image data.""" + with pytest.raises(KeyError) as exc_info: + lm = ( + phi3_vision_model + + "This is a test with a missing image: " + + image("https://picsum.photos/200/300", id="missing_image") + ) + lm += gen(name="description", max_tokens=10) + # ... (Add assertions to check for expected behavior, e.g., error or default embedding generation) + assert "Model does not contain the multimodal data with id" in str(exc_info.value) + + +def test_invalid_image_url(phi3_vision_model: models.TransformersPhi3Vision): + """Test handling of an invalid image URL.""" + with pytest.raises(Exception) as exc_info: + lm = ( + phi3_vision_model + + "This is a test with an invalid image: " + + image("https://invalid-image-url.com/image.jpg") + ) + lm += gen(name="description", max_tokens=10) + # ... 
(Add assertions to check for expected error handling) + assert "Unable to load image bytes" in str(exc_info.value) + + +def test_complex_grammar(phi3_vision_model: models.TransformersPhi3Vision): + """Test constraint enforcement with a more complex grammar.""" + image_url = "https://picsum.photos/200/300" + lm = phi3_vision_model + "Describe this image: " + image(image_url) + + # Define a more complex grammar, potentially involving recursion or nested structures + grammar = string("This is") | string("There is") + (string(" a ") | string(" an ")) + gen( + name="object" + ) + string(" in the image.") + + lm += grammar + + # Add assertions to verify that llguidance correctly enforces the grammar + assert str(lm).startswith(("This is", "There is")) + assert str(lm).endswith(" in the image.") + + +def test_token_alignment(phi3_vision_model: models.TransformersPhi3Vision): + """Test that token alignment is maintained correctly.""" + image_url_1 = "https://picsum.photos/200/300" + image_url_2 = "https://picsum.photos/300/200" + lm = ( + phi3_vision_model + + "Describe these images: " + + image(image_url_1) + + " and " + + image(image_url_2) + ) + + # Force the model to generate a specific token sequence + grammar = ( + string("Image 1 is ") + + gen(name="color1", regex=r"[a-z]+") + + string(". ") + + string("Image 2 is ") + + gen(name="color2", regex=r"[a-z]+") + + string(".") + ) + + lm += grammar + + # Add assertions to specifically check the token sequence and verify alignment + # For example, check that "color1" and "color2" are generated in the correct positions + # relative to the image tokens. + + +def test_token_count_accuracy(phi3_vision_model: models.TransformersPhi3Vision): + """Test the accuracy of the token count.""" + image_url = "https://picsum.photos/200/300" + lm = phi3_vision_model + "Describe this image: " + image(image_url) + lm += gen(name="description", max_tokens=10) + + # Assert that the reported token count matches the expected number of text and image tokens + # You'll need to calculate the expected number of tokens based on the prompt, image size, and generated text. + + +def test_streaming_behavior(phi3_vision_model: models.TransformersPhi3Vision): + """Test the streaming functionality with images.""" + image_url = "https://picsum.photos/200/300" + lm = phi3_vision_model + "Describe this image: " + image(image_url) + lm += gen(name="description", max_tokens=10) + + # Iterate over the model stream and add assertions to check partial outputs and token counts + # For example, verify that the token count increases with each iteration and that the generated text is accumulated correctly.
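Usage sketch for the multimodal path introduced by this diff, mirroring the patterns in tests/model_specific/test_transformers_phi3v.py. The model id, trust_remote_code flag, and the picsum URL come from the test fixture and are illustrative, not required; this is a minimal sketch assuming guidance with this patch applied plus pillow, transformers, and torch are installed.

# Minimal sketch of the new image + constrained-generation flow.
from guidance import gen, image, models, select

# Same kwargs as the test fixture; the tests note that flash attention can be
# disabled with _attn_implementation="eager" if it is not available.
phi3v = models.TransformersPhi3Vision(
    model="microsoft/Phi-3-vision-128k-instruct",
    trust_remote_code=True,
)

# image() loads the bytes and appends an <|_IMAGE:<id>|> placeholder via
# Model.append_multimodal(); the engine later resolves the placeholder through
# the media dict built by Model._create_media_dict().
lm = phi3v + "Is this a photo of a cat or a dog: " + image("https://picsum.photos/200/300")
lm += select(["cat", "dog"], name="answer")
print(lm["answer"])  # constrained to "cat" or "dog" by llguidance

# Unconstrained generation with an image works the same way.
lm = phi3v + "Describe this image: " + image("https://picsum.photos/200/300")
lm += gen(name="description", max_tokens=10)
print(lm["description"])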