[WIP] Structured model state #929

Open · wants to merge 15 commits into base: main
4 changes: 0 additions & 4 deletions guidance/_grammar.py
@@ -966,10 +966,6 @@ def _re_with_temperature(grammar, temperature, visited_set):
# return ModelVariable(name)


def active_role_end() -> ModelVariable:
Collaborator Author commented:
The current implementation of active_role_end will actually produce the closing tag twice since it is essentially a stateless grammar and can't pop ContextBlocks off of a model. Furthermore, generating active_role_end requires the model to produce special tags which will be hidden from future prompts... A better implementation of this may exist, but I'm not quite sure yet. Deleting it for now

return ModelVariable("active_role_end")
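For illustration, a toy repro of the failure mode described in the comment above, using plain strings (the tag format is made up; real closers come from the model's chat template):

```python
# Toy repro of the double-closer: the stateless grammar matches/emits the
# role-end text during generation, then the block exit appends it again,
# because a grammar cannot pop the ContextBlock off the model.
state = "<|user|>Hi there"       # made-up tag format for illustration
closer = "<|end|>"

state += closer                  # 1) grammar matches active_role_end
state += closer                  # 2) ContextBlock close appends the closer
assert state.count(closer) == 2  # the closing tag appears twice
```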


def eos_token() -> ModelVariable:
return ModelVariable("eos_token")

4 changes: 2 additions & 2 deletions guidance/library/_gen.py
@@ -8,7 +8,7 @@
from ._any_char import any_char
from .._grammar import capture
from ._regex import regex as regex_grammar
from .._grammar import token_limit, eos_token, active_role_end, with_temperature
from .._grammar import token_limit, eos_token, with_temperature
from ._tool import Tool
from ._block import block

@@ -129,7 +129,7 @@ def gen(
if isinstance(stop, str):
stop = [stop]
if regex is None:
stop = stop + [select([eos_token(), active_role_end()])]
stop = stop + [eos_token()]

if stop_regex is None:
stop_regex = []
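A sketch of the intended behavior after this change, using guidance's public API (the model id is a placeholder):

```python
# Sketch: generation now stops at the model's EOS token; the role closer is
# appended by the block machinery on exit instead of being matched by the
# grammar. The model id below is a placeholder.
from guidance import models, user, assistant, gen

lm = models.Transformers("some-org/some-chat-model")
with user():
    lm += "Tell me a joke."
with assistant():
    lm += gen("joke", max_tokens=50)  # stop set is now just eos_token()
```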
54 changes: 10 additions & 44 deletions guidance/library/_role.py
@@ -1,27 +1,16 @@
from .._guidance import guidance
from ._block import block
from ._block import ContextBlock
from ._set_attribute import set_attribute

nodisp_start = "<||_#NODISP_||>"
nodisp_end = "<||_/NODISP_||>"
span_start = "<||_html:<span style='background-color: rgba(255, 180, 0, 0.3); border-radius: 3px;'>_||>"
span_end = "<||_html:</span>_||>"

@guidance
def role_opener(lm, role_name, **kwargs):
indent = getattr(lm, "indent_roles", True)
class RoleBlock(ContextBlock):
def __init__(self, role_name, opener, closer, name=None):
super().__init__(opener, closer, name=name)
self.role_name = role_name


# Block start container (centers elements)
if indent:
lm += f"<||_html:<div style='display: flex; border-bottom: 1px solid rgba(127, 127, 127, 0.2); justify-content: center; align-items: center;'><div style='flex: 0 0 80px; opacity: 0.5;'>{role_name.lower()}</div><div style='flex-grow: 1; padding: 5px; padding-top: 10px; padding-bottom: 10px; margin-top: 0px; white-space: pre-wrap; margin-bottom: 0px;'>_||>"

# Start of either debug or HTML no disp block
if indent:
lm += nodisp_start
else:
lm += span_start

@guidance
def role_opener(lm, role_name, **kwargs):
# TODO [HN]: Temporary change while I instrument chat_template in transformers only.
# Eventually have all models use chat_template.
if hasattr(lm, "get_role_start"):
@@ -32,48 +21,25 @@ def role_opener(lm, role_name, **kwargs):
raise Exception(
f"You need to use a chat model in order the use role blocks like `with {role_name}():`! Perhaps you meant to use the {type(lm).__name__}Chat class?"
)

# End of either debug or HTML no disp block
if indent:
lm += nodisp_end
else:
lm += span_end

return lm


@guidance
def role_closer(lm, role_name, **kwargs):
indent = getattr(lm, "indent_roles", True)
# Start of either debug or HTML no disp block
if indent:
lm += nodisp_start
else:
lm += span_start

# TODO [HN]: Temporary change while I instrument chat_template in transformers only.
# Eventually have all models use chat_template.
if hasattr(lm, "get_role_end"):
lm += lm.get_role_end(role_name)
elif hasattr(lm, "chat_template"):
lm += lm.chat_template.get_role_end(role_name)

# End of either debug or HTML no disp block
if indent:
lm += nodisp_end
else:
lm += span_end

# End of top container
if indent:
lm += "<||_html:</div></div>_||>"

return lm


# TODO HN: Add a docstring to better describe arbitrary role functions
def role(role_name, text=None, **kwargs):
if text is None:
return block(
return RoleBlock(
role_name=role_name,
opener=role_opener(role_name, **kwargs),
closer=role_closer(role_name, **kwargs),
)
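For reference, a self-contained sketch of the block types involved. `ContextBlock` lives in `guidance/library/_block.py` and is not shown in this diff, so its body here is inferred from how `RoleBlock` constructs it:

```python
# Minimal stand-ins for the block types used above. ContextBlock's real
# definition is in guidance/library/_block.py (not part of this diff); the
# fields are inferred from the super().__init__ call in RoleBlock.
class ContextBlock:
    def __init__(self, opener, closer, name=None):
        self.opener = opener  # grammar appended when the block is entered
        self.closer = closer  # grammar appended when the block is exited
        self.name = name      # optional capture name for the block's text

class RoleBlock(ContextBlock):
    def __init__(self, role_name, opener, closer, name=None):
        super().__init__(opener, closer, name=name)
        self.role_name = role_name  # lets Model.__add__ emit typed role objects

# role("user") now returns a RoleBlock instead of a generic block:
blk = RoleBlock("user", opener="<|user|>", closer="<|end|>")
print(blk.role_name, blk.opener, blk.closer)
```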
151 changes: 77 additions & 74 deletions guidance/models/_model.py
@@ -11,7 +11,7 @@


from pprint import pprint
from typing import Dict, TYPE_CHECKING
from typing import Dict, Optional, TYPE_CHECKING


import numpy as np
@@ -38,6 +38,8 @@
"Failed to load guidance.cpp, falling back to Python mirror implementations..."
)
from .. import _cpp as cpp

from ._model_state import ModelState, Text, Object, RoleOpener, RoleCloser
from ._guidance_engine_metrics import GuidanceEngineMetrics
from .._utils import softmax, CaptureEvents
from .._parser import EarleyCommitParser, Parser
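The new `guidance/models/_model_state.py` module is added by this PR but not shown on this page; the sketch below is inferred purely from its call sites in this diff (`append`, `copy`, slicing, `str()`, `_html()`, and the `Text`/`RoleOpener`/`RoleCloser` constructor arguments), so treat every detail as an assumption:

```python
# Inferred sketch of _model_state (the real module is not shown on this page).
import html
from dataclasses import dataclass
from typing import Optional

@dataclass
class Object:
    text: str
    def __str__(self) -> str:
        return self.text
    def _html(self) -> str:
        return html.escape(self.text)

@dataclass
class Text(Object):
    probability: Optional[float] = None  # set for generated chunks

@dataclass
class RoleOpener(Object):
    role_name: str = ""
    indent: bool = True

@dataclass
class RoleCloser(Object):
    role_name: str = ""
    indent: bool = True

class ModelState:
    """An ordered list of typed objects replacing the old tagged string."""
    def __init__(self, objects=None):
        self._objects = list(objects or [])
    def append(self, obj: Object) -> None:
        self._objects.append(obj)
    def copy(self) -> "ModelState":
        return ModelState(self._objects)
    def __len__(self) -> int:
        return sum(len(o.text) for o in self._objects)
    def __getitem__(self, key):          # supports lm._state[pos:]
        return str(self)[key]
    def __str__(self) -> str:
        return "".join(str(o) for o in self._objects)
    def _html(self) -> str:
        return "".join(o._html() for o in self._objects)
```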
@@ -62,14 +64,6 @@

# define some constants we will reuse many times
_null_grammar = string("")
format_pattern = re.compile(r"<\|\|_.*?_\|\|>", flags=re.DOTALL)
nodisp_pattern = re.compile(
r"&lt;\|\|_#NODISP_\|\|&gt;.*?&lt;\|\|_/NODISP_\|\|&gt;", flags=re.DOTALL
)
html_pattern = re.compile(r"&lt;\|\|_html:(.*?)_\|\|&gt;", flags=re.DOTALL)
image_pattern = re.compile(r"&lt;\|_image:(.*?)\|&gt;")




class EngineCallResponse:
@@ -850,57 +844,19 @@ def __init__(self, engine, echo=True, **kwargs):
self.echo = echo
self.token_count = 0 # tracks how many tokens our byte state represents
self.max_display_rate = 0.2 # this controls how frequently we are allowed to redraw the display (in seconds)
self.opened_blocks = {} # what context blocks have been opened but not closed
self.opened_blocks: dict["ContextBlock", tuple[int, Optional[Object]]] = {} # what context blocks have been opened but not closed
# self.compute_log_probs = compute_log_probs

# private attributes
self._variables = {} # these are the state variables stored with the model
self._variables_log_probs = {} # these are the state variables stored with the model
self._cache_state = {} # mutable caching state used to save computation
self._state = "" # the current bytes that represent the state of the model
self._state = ModelState() # the current bytes that represent the state of the model
self._event_queue = None # TODO: these are for streaming results in code, but that needs implemented
self._event_parent = None
self._last_display = 0 # used to track the last display call to enable throttling
self._last_event_stream = 0 # used to track the last event streaming call to enable throttling

@property
def active_role_end(self):
"""The default end patterns we should use for `gen` calls.
TODO: move this logic into the gen call...we can do with if we allow model_variables to run functions.

These patterns are computed dynamically by the model object because they can depend on
what the current open roles are, which is something
"""

# add any active non-empty role ends. Ignore role ends that are spaces
parts = []
for _, role_end_str in self.opened_blocks.values():
role_end_str = format_pattern.sub("", role_end_str)
if len(role_end_str) > 0 and not re.fullmatch(r"\s+", role_end_str):
parts.append(role_end_str)

return select(parts)

def _html(self):
"""Generate HTML that displays the model object."""
display_out = self._state
for context in reversed(self.opened_blocks):
display_out += self.opened_blocks[context][1]
display_out = html.escape(display_out)
display_out = nodisp_pattern.sub("", display_out)
display_out = html_pattern.sub(lambda x: html.unescape(x.group(1)), display_out)
display_out = image_pattern.sub(
lambda x: '<img src="data:image/png;base64,'
+ base64.b64encode(self[x.groups(1)[0]]).decode()
+ '" style="max-width: 400px; vertical-align: middle; margin: 4px;">',
display_out,
)
display_out = (
"<pre style='margin: 0px; padding: 0px; vertical-align: middle; padding-left: 8px; margin-left: -8px; border-radius: 0px; border-left: 1px solid rgba(127, 127, 127, 0.2); white-space: pre-wrap; font-family: ColfaxAI, Arial; font-size: 15px; line-height: 23px;'>"
+ display_out
+ "</pre>"
)
return display_out

def _send_to_event_queue(self, value):
"""For streaming in code.
@@ -924,6 +880,7 @@ def copy(self):
new_lm._variables = self._variables.copy()
new_lm._variables_log_probs = self._variables_log_probs.copy()
new_lm.opened_blocks = self.opened_blocks.copy()
new_lm._state = self._state.copy()

# create a new clean event queue
new_lm._event_queue = None # we start with no event queue because nobody is listening to us yet
@@ -938,7 +895,7 @@

return new_lm

def _inplace_append(self, value, force_silent=False):
def _inplace_append(self, obj: Object, force_silent: bool = False):
"""This is the base way to add content to the current LM object that is being constructed.

All updates to the model state should eventually use this function.
@@ -951,7 +908,7 @@ def _inplace_append(self, obj: Object, force_silent: bool = False):
"""

# update the byte state
self._state += str(value) # TODO: make _state to be bytes not a string
self._state.append(obj)

# see if we should update the display
if not force_silent:
@@ -995,20 +952,30 @@ def reset(self, clear_variables=True):
self._variables_log_probs = {}
return self

def _html(self):
out = self._state._html()
for context in reversed(self.opened_blocks):
_, closer = self.opened_blocks[context]
if closer is not None:
out += closer._html()
return out

def _repr_html_(self):
if ipython_is_imported:
clear_output(wait=True)
return self._html()

def _current_prompt(self):
def _current_prompt(self) -> str:
"""The current prompt in bytes (which is the state without the context close tags)."""
return format_pattern.sub("", self._state)
return str(self._state)

def __str__(self):
"""A string representation of the current model object (that includes context closers)."""
out = self._current_prompt()
out = str(self._state)
for context in reversed(self.opened_blocks):
out += format_pattern.sub("", self.opened_blocks[context][1])
_, closer = self.opened_blocks[context]
if closer is not None:
out += str(closer)
return out

def __add__(self, value):
@@ -1019,6 +986,8 @@ def __add__(self, value):
value : guidance grammar
The grammar used to extend the current model.
"""
# Import in function to guard against circular import
from ..library._role import RoleBlock
Comment on lines +989 to +990
Collaborator Author commented:
I have a feeling that some may object to the import inside of a function... Is anyone else better at resolving issues with circular imports?

Collaborator replied:
Well, we haven't tried turning on ruff or flake8 yet.....
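For reference, the two standard workarounds, sketched with this diff's module paths (the helper name is ours):

```python
# Two standard ways to break an import cycle.
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Option 1: type-checking-only import; fine when the name is needed
    # solely in annotations, since no runtime dependency is created.
    from ..library._role import RoleBlock

def _is_role_block(context) -> bool:
    # Option 2: the deferred runtime import used in __add__ below; the
    # module is only resolved on first call, after both packages exist.
    from ..library._role import RoleBlock
    return isinstance(context, RoleBlock)
```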


# create the new lm object we will return
# (we need to do this since Model objects are immutable)
@@ -1034,7 +1003,7 @@ def __add__(self, value):
new_blocks.append(context)

# mark this so we don't re-add when computing the opener or closer (even though we don't know the close text yet)
lm.opened_blocks[context] = (0, "")
lm.opened_blocks[context] = (0, None)

# find what old blocks need to be removed
old_blocks = []
@@ -1046,20 +1015,44 @@
del lm.opened_blocks[context]

# close any newly closed contexts
for (pos, close_text), context in old_blocks:
for (pos, closer), context in old_blocks:
assert closer is not None
if context.name is not None:
lm._variables[context.name] = format_pattern.sub(
"", lm._state[pos:]
)
lm += context.closer
# Capture
lm._variables[context.name] = str(lm._state[pos:])
lm._inplace_append(closer)

# apply any newly opened contexts (new from this object's perspective)
for context in new_blocks:
lm += context.opener
if isinstance(context, RoleBlock):
# Apply the opener (a grammar)
with grammar_only():
# TODO: be careful about the temp lm's display? (e.g. with silent())
tmp = lm + context.opener
open_text = str(tmp._state[len(lm._state):]) # get the new state added by calling the opener
# Add that new state as text in a RoleOpener
lm._inplace_append(
RoleOpener(
role_name=context.role_name,
text=open_text,
indent=getattr(lm, "indent_roles", True)
)
)
else:
lm += context.opener
with grammar_only():
# TODO: be careful about the temp lm's display? (e.g. with silent())
tmp = lm + context.closer
close_text = tmp._state[len(lm._state):] # get the new state added by calling the closer
lm.opened_blocks[context] = (len(lm._state), close_text)
close_text = str(tmp._state[len(lm._state):]) # get the new state added by calling the closer
if isinstance(context, RoleBlock):
closer = RoleCloser(
role_name=context.role_name,
text=close_text,
indent=getattr(lm, "indent_roles", True)
)
else:
closer = Text(text=close_text)
lm.opened_blocks[context] = (len(lm._state), closer)

# clear out names that we override
if context.name is not None:
@@ -1075,7 +1068,9 @@

# we have no embedded objects
if len(parts) == 1:
lm._inplace_append(value)
lm._inplace_append(
Text(text=value)
)
out = lm

# if we have embedded objects we have to convert the string to a grammar tree
@@ -1119,7 +1114,8 @@ def __add__(self, value):
)

# this flushes the display
out._inplace_append("")
# TODO: directly call _update_display?
out._inplace_append(Text(text=""))

return out
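The opener/closer capture above relies on a simple length-diff trick; in isolation, with a toy string state rather than a real grammar run under `grammar_only()`:

```python
# The capture-by-slicing trick used in __add__, shown with a toy string
# state instead of a real grammar run under grammar_only().
def capture_appended(state: str, opener: str) -> str:
    tmp = state + opener     # run the opener against a throwaway copy
    return tmp[len(state):]  # everything past the old length is new

open_text = capture_appended("<|system|>Be brief.<|end|>", "<|user|>")
print(open_text)  # "<|user|>", then wrapped in RoleOpener(role_name="user", ...)
```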

@@ -1147,9 +1143,8 @@ def __getitem__(self, key):
else:
for context in list(reversed(self.opened_blocks)):
if context.name == key:
return format_pattern.sub(
"", self._state[self.opened_blocks[context][0] :]
)
pos, _ = self.opened_blocks[context]
return str(self._state[pos:])

raise KeyError(f"Model does not contain the variable '{key}'")

@@ -1328,11 +1323,19 @@ def _run_stateless(self, stateless_function, temperature=0.0, top_p=1.0, n=1):
if len(chunk.new_bytes) > 0:
generated_value += new_text
if chunk.is_generated:
lm += f"<||_html:<span style='background-color: rgba({165*(1-chunk.new_bytes_prob) + 0}, {165*chunk.new_bytes_prob + 0}, 0, {0.15}); border-radius: 3px;' title='{chunk.new_bytes_prob}'>_||>"
lm += new_text
if chunk.is_generated:
lm += "<||_html:</span>_||>"

self._inplace_append(
Text(
text = new_text,
# TODO: this will be slightly wrong if we have a delayed byte string
probability = chunk.new_bytes_prob
)
)
else:
self._inplace_append(
Text(
text = new_text,
)
)
# last_is_generated = chunk.is_generated

if len(chunk.capture_groups) > 0:
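With `probability` now stored on `Text` objects, display code can rebuild the old confidence highlighting outside the state string. The color mix below mirrors the rgba() span removed above; the helper name is ours:

```python
# Rebuilds the old green/red confidence highlight from Text.probability.
# The rgba() mix matches the span removed above; render_chunk is our name.
import html

def render_chunk(text: str, probability: float | None) -> str:
    if probability is None:
        return html.escape(text)  # forced/non-generated text: no highlight
    red = int(165 * (1 - probability))
    green = int(165 * probability)
    return (
        f"<span style='background-color: rgba({red}, {green}, 0, 0.15); "
        f"border-radius: 3px;' title='{probability}'>{html.escape(text)}</span>"
    )

print(render_chunk("hello", 0.92))  # high-probability token: mostly green
```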